Spaces:
Running
Running
| import os | |
| import sys | |
| import shutil | |
| import uuid | |
| import zipfile | |
| import gradio as gr | |
# Ensure the directory containing this file is importable (e.g. on Spaces),
# so the sibling kmer_predict module resolves regardless of the working
# directory. abspath guards against a relative or empty dirname(__file__),
# which would otherwise be interpreted relative to the current cwd.
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
import kmer_predict  # must be in repo root

# Parent directory for per-run working folders (uploads, predictions, ZIP).
PERSIST_BASE = "/tmp/kmer_predict_runs"
# Lowercase file suffixes treated as FASTA when ingesting uploads/ZIP members.
FASTA_EXTS = (".fa", ".fasta", ".fas", ".fna")
| def _zip_dir(folder: str, zip_path: str) -> None: | |
| """Zip the contents of folder into zip_path.""" | |
| with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z: | |
| for root, _, files in os.walk(folder): | |
| for fn in files: | |
| full = os.path.join(root, fn) | |
| rel = os.path.relpath(full, folder) | |
| z.write(full, rel) | |
| def _safe_extract_zip(zip_path: str, dst_dir: str) -> None: | |
| """Safely extract only FASTA files from ZIP (prevents zip-slip).""" | |
| with zipfile.ZipFile(zip_path, "r") as z: | |
| for member in z.infolist(): | |
| if member.is_dir(): | |
| continue | |
| # Zip-slip protection | |
| target = os.path.normpath(os.path.join(dst_dir, member.filename)) | |
| if not target.startswith(os.path.abspath(dst_dir) + os.sep): | |
| continue | |
| # Only FASTA-like files | |
| if not member.filename.lower().endswith(FASTA_EXTS): | |
| continue | |
| os.makedirs(os.path.dirname(target), exist_ok=True) | |
| with z.open(member) as src, open(target, "wb") as out: | |
| shutil.copyfileobj(src, out) | |
def _unique_dest(directory: str, filename: str) -> str:
    """Return a path for *filename* in *directory* that does not exist yet.

    On collision a numeric suffix is inserted before the extension:
    ``a.fasta`` -> ``a_1.fasta`` -> ``a_2.fasta`` ...
    """
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(directory, filename)
    n = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{n}{ext}")
        n += 1
    return candidate


def _ingest_unknown_uploads(unknown_uploads, unknown_dir: str) -> None:
    """
    Stage unknown sequences into *unknown_dir*.

    Accept unknown sequences as:
      - FASTA files, and/or
      - ZIP files containing FASTA files.

    ZIP archives are handed to ``_safe_extract_zip``, which extracts only
    FASTA-like members. Uploads ending in a FASTA suffix are copied under
    their original base name. Any other upload is copied as
    ``unknown_<n>.fasta``, where <n> is its 1-based position in the list.

    A copied file never overwrites a file already in *unknown_dir* (e.g. two
    uploads sharing a base name, or a name produced by an earlier ZIP). A
    numeric suffix is added instead. Note that members extracted from a
    later ZIP are written by ``_safe_extract_zip`` and can still replace an
    earlier file at the same relative path.

    Args:
        unknown_uploads: Iterable of uploaded files (objects exposing
            ``.path``/``.name``, or plain path strings). May be empty/None.
        unknown_dir: Destination directory; created if missing.

    Raises:
        zipfile.BadZipFile: Propagated from ``_safe_extract_zip`` for a
            corrupt archive.
    """
    os.makedirs(unknown_dir, exist_ok=True)
    if not unknown_uploads:
        return
    for idx, f in enumerate(unknown_uploads, start=1):
        # Resolve the on-disk path and the user's original file name from
        # whichever attributes the upload object provides.
        src = getattr(f, "path", None) or getattr(f, "name", None) or str(f)
        orig = (
            getattr(f, "orig_name", None)
            or getattr(f, "filename", None)
            or os.path.basename(src)
        )
        lower = str(orig).lower()
        # ZIP -> extract FASTA files
        if lower.endswith(".zip") or str(src).lower().endswith(".zip"):
            _safe_extract_zip(src, unknown_dir)
            continue
        # FASTA -> copy
        if lower.endswith(FASTA_EXTS):
            dst_name = os.path.basename(str(orig))
        else:
            # Presumably a temp name without a recognised extension; give it
            # a readable FASTA name.
            dst_name = f"unknown_{idx}.fasta"
        shutil.copy(src, _unique_dest(unknown_dir, dst_name))
def run_prediction(unknown_uploads, kmer_zip, seqtype, mode, identity, coverage, fdr):
    """Gradio click handler: stage uploads, run kmer_predict, collect outputs.

    Each call works in a fresh ``PERSIST_BASE/run_<id>/`` folder containing
    ``unknown/`` (staged FASTA), ``predictions/`` (kmer_predict output) and
    ``prediction_outputs.zip`` (a ZIP of ``predictions/``).

    Args:
        unknown_uploads: Uploaded FASTA files and/or ZIPs of FASTA files.
        kmer_zip: The single k-mer results ZIP produced by Space 1.
        seqtype: Sequence type forwarded to kmer_predict (UI offers
            "dna"/"protein").
        mode: Prediction mode forwarded to kmer_predict (UI offers
            "fast"/"full").
        identity, coverage, fdr: Numeric thresholds, forwarded as floats.

    Returns:
        ``(plot_path, csv_path, zip_path)`` for the image, CSV and ZIP
        output components. ``plot_path``/``csv_path`` are ``None`` when
        kmer_predict did not write that file (e.g. possibly in some modes),
        so the component is cleared instead of being handed a missing path.

    Raises:
        gr.Error: For missing/invalid uploads, a corrupt upload ZIP, a blank
            numeric field, or when no FASTA files remain after ingestion.
    """
    if not unknown_uploads:
        raise gr.Error("Please upload unknown FASTA files or a ZIP containing FASTA files.")
    if not kmer_zip:
        raise gr.Error("Please upload the k-mer results ZIP from Space 1.")

    # Validate cheap inputs up front, before copying/extracting anything.
    # K-mer ZIP path (ZIP-only)
    kmer_zip_path = getattr(kmer_zip, "path", None) or getattr(kmer_zip, "name", None) or str(kmer_zip)
    if not str(kmer_zip_path).lower().endswith(".zip"):
        raise gr.Error("K-mer input must be a .zip file produced by Space 1.")
    # A cleared gr.Number arrives as None; float(None) would raise TypeError.
    for label, value in (("Identity", identity), ("Coverage", coverage), ("FDR alpha", fdr)):
        if value is None:
            raise gr.Error(f"Please enter a value for {label}.")

    os.makedirs(PERSIST_BASE, exist_ok=True)
    run_id = uuid.uuid4().hex[:10]
    run_dir = os.path.join(PERSIST_BASE, f"run_{run_id}")
    unknown_dir = os.path.join(run_dir, "unknown")
    outdir = os.path.join(run_dir, "predictions")
    os.makedirs(unknown_dir, exist_ok=True)
    os.makedirs(outdir, exist_ok=True)

    # Ingest unknown uploads (FASTA and/or ZIP)
    try:
        _ingest_unknown_uploads(unknown_uploads, unknown_dir)
    except zipfile.BadZipFile as exc:
        raise gr.Error("One of the uploaded ZIP files is not a valid ZIP archive.") from exc

    # Lightweight check: at least one FASTA-like file was staged.
    found_any = any(
        fn.lower().endswith(FASTA_EXTS)
        for _, _, files in os.walk(unknown_dir)
        for fn in files
    )
    if not found_any:
        raise gr.Error("No FASTA files were found after processing your uploads. Please check your ZIP contents.")

    # Run prediction
    kmer_predict.predict(
        unknown=unknown_dir,
        kmer_input=kmer_zip_path,
        output_dir=outdir,
        seqtype=seqtype,
        mode=mode,
        identity_threshold=float(identity),
        min_coverage=float(coverage),
        fdr_alpha=float(fdr),
        group_regex=kmer_predict.DEFAULT_GROUP_REGEX,
    )

    plot_path = os.path.join(outdir, "predicted_results_summary.png")
    csv_path = os.path.join(outdir, "predictions_by_alignment.csv")
    zip_path = os.path.join(run_dir, "prediction_outputs.zip")
    _zip_dir(outdir, zip_path)
    # Only hand existing files to the output components.
    return (
        plot_path if os.path.isfile(plot_path) else None,
        csv_path if os.path.isfile(csv_path) else None,
        zip_path,
    )
with gr.Blocks() as demo:
    # ---- Header ----
    gr.Markdown("# K-mer Sequence Predictor")
    gr.Markdown(
        "Upload **unknown sequences** (FASTA files or ZIP containing FASTA) and the **kmer_results.zip** from Space 1."
    )

    # ---- Inputs: sequences to classify and the k-mer results archive ----
    unknown_uploads = gr.File(
        file_count="multiple",
        file_types=[*FASTA_EXTS, ".zip"],
        label="Unknown sequences (FASTA files or ZIP containing FASTA)",
    )
    kmer_zip = gr.File(
        file_types=[".zip"],
        file_count="single",
        label="kmer_results.zip (from Space 1)",
    )

    # ---- Options ----
    with gr.Row():
        seqtype = gr.Radio(choices=["dna", "protein"], value="dna", label="Sequence type")
        mode = gr.Radio(choices=["fast", "full"], value="fast", label="Mode")
    with gr.Row():
        identity = gr.Number(label="Identity (full mode)", value=0.90, precision=2)
        coverage = gr.Number(label="Coverage (full mode)", value=0.80, precision=2)
        fdr = gr.Number(label="FDR alpha (full mode)", value=0.05, precision=3)

    run_btn = gr.Button("Run prediction")

    # ---- Outputs ----
    out_plot = gr.Image(label="Prediction summary plot")
    out_csv = gr.File(label="Predictions CSV")
    out_zip = gr.File(label="Download all outputs (ZIP)")

    # Input order must match run_prediction's positional parameters; output
    # order matches its returned (plot, csv, zip) tuple.
    prediction_inputs = [unknown_uploads, kmer_zip, seqtype, mode, identity, coverage, fdr]
    prediction_outputs = [out_plot, out_csv, out_zip]
    run_btn.click(run_prediction, inputs=prediction_inputs, outputs=prediction_outputs)


if __name__ == "__main__":
    demo.launch()