# FtsI_Classifier / app.py
# Last change: "Update app.py" by Muhamed-Kheir (commit af876ca, verified)
import os
import sys
import shutil
import uuid
import zipfile
import gradio as gr
# Ensure the repo root (the directory containing this file) is importable on
# Spaces. abspath() guards against __file__ being a bare relative filename, in
# which case dirname() would return "" and the import below would silently
# depend on the current working directory instead.
ROOT = os.path.dirname(os.path.abspath(__file__))
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

import kmer_predict  # noqa: E402  (must live in the repo root, next to app.py)

# Parent directory under which every prediction run gets its own sub-folder.
PERSIST_BASE = "/tmp/kmer_predict_runs"
# File extensions treated as FASTA; callers lower-case names before matching.
FASTA_EXTS = (".fa", ".fasta", ".fas", ".fna")
def _zip_dir(folder: str, zip_path: str) -> None:
    """Write every file found under *folder* (recursively) into *zip_path*.

    Archive member names are relative to *folder*, so the archive unpacks
    without the leading directory. Entries are DEFLATE-compressed, empty
    directories are not recorded, and an existing file at *zip_path* is
    overwritten.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for dirpath, _subdirs, filenames in os.walk(folder):
            paths = (os.path.join(dirpath, name) for name in filenames)
            for path in paths:
                archive.write(path, os.path.relpath(path, folder))
def _safe_extract_zip(zip_path: str, dst_dir: str, *, exts=None) -> None:
    """Extract only FASTA-like files from a ZIP archive into *dst_dir*.

    Members keep their relative paths from inside the archive. A member is
    skipped when it is a directory entry, when its resolved path would land
    outside *dst_dir* (zip-slip, e.g. ``../x.fasta`` or an absolute name), when
    it is macOS resource-fork metadata (under ``__MACOSX/`` or named ``._*``),
    or when its name does not end with one of the accepted extensions.

    Args:
        zip_path: Path of the archive to read.
        dst_dir: Destination directory; may be absolute or relative.
            Directories are created on demand, so nothing is created when no
            member qualifies.
        exts: Tuple of lower-case extensions to accept. Defaults to
            ``FASTA_EXTS`` when omitted.

    Raises:
        zipfile.BadZipFile: if *zip_path* is not a valid ZIP archive.
    """
    if exts is None:
        exts = FASTA_EXTS
    # Resolve once. The trailing separator stops "/dst" from matching
    # sibling paths such as "/dst2/...".
    root = os.path.join(os.path.abspath(dst_dir), "")

    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.infolist():
            if member.is_dir():
                continue
            name = member.filename

            # Zip-slip protection. Both sides are absolute, so a relative
            # dst_dir is compared correctly instead of rejecting every member.
            target = os.path.abspath(os.path.join(root, name))
            if not target.startswith(root):
                continue

            # Archives made by macOS Finder add AppleDouble "._<name>" files
            # (binary metadata) that share the real file's extension.
            parts = name.replace("\\", "/").split("/")
            if "__MACOSX" in parts or parts[-1].startswith("._"):
                continue

            # Only FASTA-like files
            if not name.lower().endswith(exts):
                continue

            os.makedirs(os.path.dirname(target), exist_ok=True)
            with z.open(member) as src, open(target, "wb") as out:
                shutil.copyfileobj(src, out)
def _unique_path(directory: str, filename: str) -> str:
    """Return ``directory/filename``, inserting ``_1``, ``_2``, ... before the extension if taken."""
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(directory, filename)
    counter = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{counter}{ext}")
        counter += 1
    return candidate


def _ingest_unknown_uploads(unknown_uploads, unknown_dir: str) -> None:
    """
    Accept unknown sequences as:
      - FASTA files, and/or
      - ZIP files containing FASTA files.
    Copies/extracts into unknown_dir.

    Each upload may be an object exposing ``path``/``name`` (temp location)
    and ``orig_name``/``filename`` (user's filename), or anything whose
    ``str()`` is a path. An upload is treated as a ZIP when either its
    original name or its temp path ends in ``.zip``; its FASTA members are
    extracted with ``_safe_extract_zip``. Every other upload is copied:
    recognised FASTA names are kept, anything else is stored as
    ``unknown_<n>.fasta`` (n = 1-based upload position).

    Directly-copied files never overwrite an existing file in *unknown_dir*.
    If the name is already taken (e.g. two uploads both called
    ``seqs.fasta``), a numeric suffix is added so no sequences are silently
    lost. Files extracted from ZIPs are written as ``_safe_extract_zip``
    decides and may still replace same-named files.
    """
    os.makedirs(unknown_dir, exist_ok=True)
    if not unknown_uploads:
        return

    for idx, f in enumerate(unknown_uploads, start=1):
        src = getattr(f, "path", None) or getattr(f, "name", None) or str(f)
        orig = (
            getattr(f, "orig_name", None)
            or getattr(f, "filename", None)
            or os.path.basename(src)
        )
        lower = str(orig).lower()

        # ZIP → extract FASTA files
        if lower.endswith(".zip") or str(src).lower().endswith(".zip"):
            _safe_extract_zip(src, unknown_dir)
            continue

        # FASTA → copy, keeping the user's filename when it is recognisable.
        if lower.endswith(FASTA_EXTS):
            dst_name = os.path.basename(str(orig))
        else:
            # Temp names may lack an extension; give the copy a readable one.
            dst_name = f"unknown_{idx}.fasta"
        shutil.copy(src, _unique_path(unknown_dir, dst_name))
def _as_float(value, label: str) -> float:
    """Return *value* as a float, or raise gr.Error naming the empty/invalid field."""
    try:
        return float(value)
    except (TypeError, ValueError):
        # NOTE(review): a cleared gr.Number presumably arrives as None; without
        # this, float(None) would surface as an opaque TypeError in the UI.
        raise gr.Error(f"Please enter a numeric value for {label}.") from None


def _existing_or_none(path: str):
    """Return *path* if it is an existing file, otherwise None."""
    return path if os.path.isfile(path) else None


def run_prediction(unknown_uploads, kmer_zip, seqtype, mode, identity, coverage, fdr):
    """Gradio callback: ingest uploads, run ``kmer_predict.predict`` and package outputs.

    Each call works in a fresh ``PERSIST_BASE/run_<id>`` directory that holds
    ``unknown/`` (ingested sequences) and ``predictions/`` (predictor output).
    If any step fails, the run directory is deleted again. On success it is
    kept so Gradio can serve the returned files.

    Args:
        unknown_uploads: Uploaded FASTA and/or ZIP files.
        kmer_zip: The single k-mer results ZIP produced by Space 1.
        seqtype: Sequence type from the UI radio (``"dna"`` or ``"protein"``).
        mode: Prediction mode from the UI radio (``"fast"`` or ``"full"``).
        identity, coverage, fdr: Numeric thresholds, forwarded as floats.

    Returns:
        ``(plot_path, csv_path, zip_path)``. The plot and CSV entries are
        ``None`` when the predictor did not write that file. ``zip_path``
        always points to an archive of the whole ``predictions/`` folder.

    Raises:
        gr.Error: for missing or invalid inputs, when no FASTA file survives
            ingestion, or when the predictor itself fails.
    """
    # Validate all cheap-to-check inputs before touching the filesystem.
    if not unknown_uploads:
        raise gr.Error("Please upload unknown FASTA files or a ZIP containing FASTA files.")
    if not kmer_zip:
        raise gr.Error("Please upload the k-mer results ZIP from Space 1.")

    # K-mer ZIP path (ZIP-only)
    kmer_zip_path = getattr(kmer_zip, "path", None) or getattr(kmer_zip, "name", None) or str(kmer_zip)
    if not str(kmer_zip_path).lower().endswith(".zip"):
        raise gr.Error("K-mer input must be a .zip file produced by Space 1.")

    identity_threshold = _as_float(identity, "Identity")
    min_coverage = _as_float(coverage, "Coverage")
    fdr_alpha = _as_float(fdr, "FDR alpha")

    os.makedirs(PERSIST_BASE, exist_ok=True)
    run_id = uuid.uuid4().hex[:10]
    run_dir = os.path.join(PERSIST_BASE, f"run_{run_id}")
    unknown_dir = os.path.join(run_dir, "unknown")
    outdir = os.path.join(run_dir, "predictions")
    os.makedirs(unknown_dir, exist_ok=True)
    os.makedirs(outdir, exist_ok=True)

    plot_path = os.path.join(outdir, "predicted_results_summary.png")
    csv_path = os.path.join(outdir, "predictions_by_alignment.csv")
    zip_path = os.path.join(run_dir, "prediction_outputs.zip")

    try:
        # Ingest unknown uploads (FASTA and/or ZIP)
        _ingest_unknown_uploads(unknown_uploads, unknown_dir)

        # Lightweight check: at least one FASTA-like file must be present.
        found_any = any(
            fn.lower().endswith(FASTA_EXTS)
            for _, _, files in os.walk(unknown_dir)
            for fn in files
        )
        if not found_any:
            raise gr.Error("No FASTA files were found after processing your uploads. Please check your ZIP contents.")

        try:
            kmer_predict.predict(
                unknown=unknown_dir,
                kmer_input=kmer_zip_path,
                output_dir=outdir,
                seqtype=seqtype,
                mode=mode,
                identity_threshold=identity_threshold,
                min_coverage=min_coverage,
                fdr_alpha=fdr_alpha,
                group_regex=kmer_predict.DEFAULT_GROUP_REGEX,
            )
        except gr.Error:
            raise
        except Exception as exc:
            # Surface the cause in the UI instead of a generic error banner.
            raise gr.Error(f"Prediction failed: {exc}") from exc

        _zip_dir(outdir, zip_path)
    except BaseException:
        # Don't leave orphaned, half-populated run folders behind.
        shutil.rmtree(run_dir, ignore_errors=True)
        raise

    # NOTE(review): the predictor may not write every file in every mode (the
    # CSV name suggests alignment output); returning None keeps Gradio from
    # failing on a missing path. All outputs are still in the ZIP.
    return _existing_or_none(plot_path), _existing_or_none(csv_path), zip_path
# Gradio UI: components render in creation order, so the layout below is
# top-to-bottom exactly as written.
with gr.Blocks() as demo:
    gr.Markdown("# K-mer Sequence Predictor")
    gr.Markdown(
        "Upload **unknown sequences** (FASTA files or ZIP containing FASTA) and the **kmer_results.zip** from Space 1."
    )

    # Derive accepted types from FASTA_EXTS (the same tuple the ingest and
    # extraction filters use) so the file picker cannot drift from them.
    unknown_uploads = gr.File(
        label="Unknown sequences (FASTA files or ZIP containing FASTA)",
        file_count="multiple",
        file_types=[*FASTA_EXTS, ".zip"],
    )
    kmer_zip = gr.File(
        label="kmer_results.zip (from Space 1)",
        file_count="single",
        file_types=[".zip"],
    )

    with gr.Row():
        seqtype = gr.Radio(["dna", "protein"], value="dna", label="Sequence type")
        mode = gr.Radio(["fast", "full"], value="fast", label="Mode")

    # Thresholds labelled "(full mode)"; all three are passed to
    # run_prediction regardless of the selected mode.
    with gr.Row():
        identity = gr.Number(value=0.90, precision=2, label="Identity (full mode)")
        coverage = gr.Number(value=0.80, precision=2, label="Coverage (full mode)")
        fdr = gr.Number(value=0.05, precision=3, label="FDR alpha (full mode)")

    run_btn = gr.Button("Run prediction")

    out_plot = gr.Image(label="Prediction summary plot")
    out_csv = gr.File(label="Predictions CSV")
    out_zip = gr.File(label="Download all outputs (ZIP)")

    run_btn.click(
        fn=run_prediction,
        inputs=[unknown_uploads, kmer_zip, seqtype, mode, identity, coverage, fdr],
        outputs=[out_plot, out_csv, out_zip],
    )

if __name__ == "__main__":
    demo.launch()