"""Gradio app that runs ``kmer_predict.predict`` on uploaded unknown sequences.

Users upload unknown sequences (FASTA files and/or ZIPs containing FASTA files)
plus the ``kmer_results.zip`` produced by Space 1. Each run gets its own
directory under ``PERSIST_BASE``; the app returns the summary plot, the
alignment CSV, and a ZIP of everything ``kmer_predict.predict`` wrote.
"""

import os
import shutil
import sys
import uuid
import zipfile

import gradio as gr

# Ensure repo root is importable on Spaces (abspath guards against a relative __file__).
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

import kmer_predict  # noqa: E402  # must be in repo root

# Per-run working directories are created under this path.
# NOTE(review): nothing in this file deletes old runs; /tmp can grow on long-lived Spaces.
PERSIST_BASE = "/tmp/kmer_predict_runs"

# Extensions treated as FASTA (matched case-insensitively against lowercased names).
FASTA_EXTS = (".fa", ".fasta", ".fas", ".fna")


def _zip_dir(folder: str, zip_path: str) -> None:
    """Write every file under ``folder`` into a new deflated ZIP at ``zip_path``.

    Archive member names are paths relative to ``folder``. ``zip_path`` should
    not lie inside ``folder``, otherwise the walk could pick up the archive itself.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        for root, _, files in os.walk(folder):
            for fn in files:
                full = os.path.join(root, fn)
                z.write(full, os.path.relpath(full, folder))


def _is_macos_metadata(name: str) -> bool:
    """Return True for macOS archive metadata (``__MACOSX/`` entries or ``._*`` files).

    Finder-created ZIPs contain AppleDouble files such as ``__MACOSX/._seq.fasta``.
    They carry a FASTA extension but hold binary resource-fork data.
    """
    parts = name.replace("\\", "/").split("/")
    return "__MACOSX" in parts or parts[-1].startswith("._")


def _unique_path(path: str) -> str:
    """Return ``path`` if it does not exist, else ``<stem>_<N><ext>`` for the smallest free N >= 1."""
    if not os.path.exists(path):
        return path
    stem, ext = os.path.splitext(path)
    n = 1
    while os.path.exists(f"{stem}_{n}{ext}"):
        n += 1
    return f"{stem}_{n}{ext}"


def _safe_extract_zip(zip_path: str, dst_dir: str) -> None:
    """Extract only the FASTA-like members of ``zip_path`` into ``dst_dir``.

    The directory structure inside the archive is preserved. A member is skipped when:
      - it is a directory,
      - it would resolve outside ``dst_dir`` (zip-slip protection),
      - it is macOS metadata, or
      - it lacks a FASTA extension.
    A member whose destination already exists is written under a suffixed name
    rather than overwriting the existing file.

    Raises:
        zipfile.BadZipFile: if ``zip_path`` is not a readable ZIP archive.
    """
    # Compare absolute paths on both sides; otherwise a relative dst_dir
    # would make every member fail the containment check.
    base = os.path.abspath(dst_dir)
    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.infolist():
            if member.is_dir():
                continue
            name = member.filename
            # Only FASTA-like files, and never macOS AppleDouble junk.
            if not name.lower().endswith(FASTA_EXTS) or _is_macos_metadata(name):
                continue
            target = os.path.normpath(os.path.join(base, name))
            if not target.startswith(base + os.sep):
                continue
            target = _unique_path(target)
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with z.open(member) as src, open(target, "wb") as out:
                shutil.copyfileobj(src, out)


def _ingest_unknown_uploads(unknown_uploads, unknown_dir: str) -> None:
    """Copy or extract uploaded unknown sequences into ``unknown_dir``.

    Accepts, in any mix:
      - FASTA files, which are copied under their original base name, and
      - ZIP files, from which only FASTA members are extracted.

    Uploads with neither a ZIP nor a FASTA extension are still copied, as
    ``unknown_<index>.fasta``, so Gradio temp names without an extension stay
    readable. A name clash gets a numeric suffix instead of overwriting.

    Items in ``unknown_uploads`` may be plain paths or Gradio file objects:
    ``.path`` / ``.name`` give the temp path, and ``.orig_name`` / ``.filename``
    give the user-visible name.

    Raises:
        gr.Error: if an uploaded ``.zip`` cannot be read as a ZIP archive.
    """
    os.makedirs(unknown_dir, exist_ok=True)
    if not unknown_uploads:
        return

    for idx, f in enumerate(unknown_uploads, start=1):
        src = getattr(f, "path", None) or getattr(f, "name", None) or str(f)
        orig = (
            getattr(f, "orig_name", None)
            or getattr(f, "filename", None)
            or os.path.basename(src)
        )
        lower = str(orig).lower()

        # ZIP → extract FASTA files
        if lower.endswith(".zip") or str(src).lower().endswith(".zip"):
            try:
                _safe_extract_zip(src, unknown_dir)
            except zipfile.BadZipFile as err:
                raise gr.Error(f"{orig} is not a valid ZIP archive ({err}).") from err
            continue

        # FASTA → copy; anything else keeps a readable .fasta name
        if lower.endswith(FASTA_EXTS):
            dst_name = os.path.basename(orig)
        else:
            dst_name = f"unknown_{idx}.fasta"
        shutil.copy(src, _unique_path(os.path.join(unknown_dir, dst_name)))


def run_prediction(unknown_uploads, kmer_zip, seqtype, mode, identity, coverage, fdr):
    """Gradio callback: ingest the uploads and run ``kmer_predict.predict``.

    Args:
        unknown_uploads: uploaded FASTA and/or ZIP files.
        kmer_zip: the k-mer results ZIP produced by Space 1.
        seqtype: value of the "Sequence type" radio ("dna" or "protein").
        mode: value of the "Mode" radio ("fast" or "full").
        identity, coverage, fdr: numeric thresholds, forwarded as
            ``identity_threshold``, ``min_coverage`` and ``fdr_alpha``.

    Returns:
        ``(plot_path, csv_path, zip_path)``. ``plot_path`` and ``csv_path`` are
        ``None`` when ``kmer_predict.predict`` did not produce those files.

    Raises:
        gr.Error: on missing or invalid inputs, or if prediction fails.
    """
    if not unknown_uploads:
        raise gr.Error("Please upload unknown FASTA files or a ZIP containing FASTA files.")
    if not kmer_zip:
        raise gr.Error("Please upload the k-mer results ZIP from Space 1.")

    # Validate the cheap inputs before creating a run directory or copying data.
    # K-mer ZIP path (ZIP-only)
    kmer_zip_path = (
        getattr(kmer_zip, "path", None) or getattr(kmer_zip, "name", None) or str(kmer_zip)
    )
    if not str(kmer_zip_path).lower().endswith(".zip"):
        raise gr.Error("K-mer input must be a .zip file produced by Space 1.")

    # A cleared gr.Number yields None, which would crash float() below.
    thresholds = {"Identity": identity, "Coverage": coverage, "FDR alpha": fdr}
    missing = [label for label, value in thresholds.items() if value is None]
    if missing:
        raise gr.Error(f"Please enter a value for: {', '.join(missing)}.")

    run_dir = os.path.join(PERSIST_BASE, f"run_{uuid.uuid4().hex[:10]}")
    unknown_dir = os.path.join(run_dir, "unknown")
    outdir = os.path.join(run_dir, "predictions")
    os.makedirs(unknown_dir, exist_ok=True)
    os.makedirs(outdir, exist_ok=True)

    # Ingest unknown uploads (FASTA and/or ZIP)
    _ingest_unknown_uploads(unknown_uploads, unknown_dir)

    # Lightweight check: at least one FASTA-like file must have arrived.
    found_any = any(
        fn.lower().endswith(FASTA_EXTS)
        for _, _, files in os.walk(unknown_dir)
        for fn in files
    )
    if not found_any:
        raise gr.Error(
            "No FASTA files were found after processing your uploads. "
            "Please check your ZIP contents."
        )

    try:
        kmer_predict.predict(
            unknown=unknown_dir,
            kmer_input=kmer_zip_path,
            output_dir=outdir,
            seqtype=seqtype,
            mode=mode,
            identity_threshold=float(identity),
            min_coverage=float(coverage),
            fdr_alpha=float(fdr),
            group_regex=kmer_predict.DEFAULT_GROUP_REGEX,
        )
    except gr.Error:
        raise
    except Exception as err:
        # Gradio only shows the message of gr.Error; surface the real reason to the user.
        raise gr.Error(f"Prediction failed: {err}") from err

    plot_path = os.path.join(outdir, "predicted_results_summary.png")
    csv_path = os.path.join(outdir, "predictions_by_alignment.csv")
    zip_path = os.path.join(run_dir, "prediction_outputs.zip")
    _zip_dir(outdir, zip_path)

    # Return None rather than a dangling path so the output components simply stay empty.
    return (
        plot_path if os.path.isfile(plot_path) else None,
        csv_path if os.path.isfile(csv_path) else None,
        zip_path,
    )


with gr.Blocks() as demo:
    gr.Markdown("# K-mer Sequence Predictor")
    gr.Markdown(
        "Upload **unknown sequences** (FASTA files or ZIP containing FASTA) and the **kmer_results.zip** from Space 1."
    )

    unknown_uploads = gr.File(
        label="Unknown sequences (FASTA files or ZIP containing FASTA)",
        file_count="multiple",
        file_types=[*FASTA_EXTS, ".zip"],  # same extensions the ingest code accepts
    )

    kmer_zip = gr.File(
        label="kmer_results.zip (from Space 1)",
        file_count="single",
        file_types=[".zip"],
    )

    with gr.Row():
        seqtype = gr.Radio(["dna", "protein"], value="dna", label="Sequence type")
        mode = gr.Radio(["fast", "full"], value="fast", label="Mode")

    with gr.Row():
        identity = gr.Number(value=0.90, precision=2, label="Identity (full mode)")
        coverage = gr.Number(value=0.80, precision=2, label="Coverage (full mode)")
        fdr = gr.Number(value=0.05, precision=3, label="FDR alpha (full mode)")

    run_btn = gr.Button("Run prediction")

    out_plot = gr.Image(label="Prediction summary plot")
    out_csv = gr.File(label="Predictions CSV")
    out_zip = gr.File(label="Download all outputs (ZIP)")

    run_btn.click(
        fn=run_prediction,
        inputs=[unknown_uploads, kmer_zip, seqtype, mode, identity, coverage, fdr],
        outputs=[out_plot, out_csv, out_zip],
    )


if __name__ == "__main__":
    demo.launch()