"""Gradio front end for the P. aeruginosa WGS-MIC predictor.

Accepts one assembled FASTA upload, runs it through ``PAMICPredictor`` and
renders genome QC, per-antibiotic prediction cards, a detailed table, a
resistance-mechanism panel and a downloadable CSV.
"""

from __future__ import annotations

import html
import math
import tempfile
import traceback
from pathlib import Path

import gradio as gr
import pandas as pd

from pa_mic_predictor import PAMICPredictor

TITLE = "P. aeruginosa WGS-MIC Predictor"

DISCLAIMER = """
This Space is a research-use demonstrator for assembled *Pseudomonas aeruginosa* FASTA files.
It predicts log2 MIC-derived MIC values and S/NS categories for eight antibiotics using the HMT-CORAL v6 model.
Results are not intended for clinical diagnosis or treatment decisions.
"""

# Progress-bar fraction shown for each status message passed to the
# ``progress_callback`` of ``predict_fasta``. Keys presumably have to match the
# predictor's messages exactly — verify against pa_mic_predictor.
_STEP_PROGRESS = {
    "Copying uploaded FASTA": 0.05,
    "Running basic genome QC": 0.10,
    "Running RGI/CARD annotation. This is usually a few minutes on CPU.": 0.25,
    "Running Snippy against PAO1. This is usually the slowest step.": 0.60,
    "Building the 442-feature model input": 0.85,
    "Running HMT-CORAL v6 inference": 0.93,
    "Formatting report": 0.98,
}

# Columns (and their order) shown in the detail table and written to the CSV.
_DETAIL_COLUMNS = [
    "antibiotic",
    "predicted_MIC_mg_L",
    "predicted_S_NS",
    "confidence",
    "reliability",
    "model_breakpoint",
    "local_validation_CA",
    "local_validation_EA",
    "interpretation",
]

# Display order of the antibiotic cards; antibiotics not listed are not shown.
_CARD_ORDER = ["AMK", "IPM", "TOB", "FEP", "MEM", "CAZ", "TZP", "LVX"]

# Antibiotics summarised in the "Main readout" box.
_PRIMARY_DRUGS = ("AMK", "IPM")

# Mechanism values that count as "Detected".
_TRUTHY_FLAGS = frozenset({"1", "1.0", "True", "true"})

# Upper bound on RGI hit badges rendered, to keep the panel readable.
_MAX_RGI_BADGES = 40

predictor = PAMICPredictor()


def runtime_status() -> str:
    """Return a one-line description of whether the external tools are available.

    ``predictor.check_runtime()`` is treated as returning the names of missing
    command-line tools; an empty result means the runtime is ready.
    """
    missing = predictor.check_runtime()
    if missing:
        return (
            "Runtime is not fully configured. Missing command-line tools: "
            + ", ".join(missing)
            + ". Build the Docker Space image with RGI, CARD, and Snippy support."
        )
    return "Runtime check passed: RGI and Snippy are available."


def _format_qc(qc: dict) -> str:
    """Render the predictor's genome-QC dict as a Markdown bullet list."""
    lines = [
        "### Genome QC",
        f"- Contigs: `{qc['contigs']}`",
        f"- Total length: `{qc['total_bp']:,}` bp",
        f"- GC content: `{qc['gc_percent']}%`",
        f"- N content: `{qc['n_percent']}%`",
        f"- N50: `{qc['n50']:,}` bp",
    ]
    if qc.get("warning"):
        lines.append(f"- Warning: {qc['warning']}")
    return "\n".join(lines)


def predict(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """Run the full pipeline on one uploaded FASTA and build every UI output.

    Returns a 6-tuple in the order of the Predict button's ``outputs``:
    (QC markdown, overview HTML, status markdown, detail DataFrame,
    mechanisms HTML, CSV path or None). Any exception is caught and reported
    in the first slot so the app keeps running.
    """
    if uploaded_file is None:
        return "Please upload one assembled FASTA file.", "", "", pd.DataFrame(), "", None

    try:
        # Newer Gradio versions hand over a path string; older ones a
        # tempfile-like object exposing ``.name``. Accept either.
        file_path = Path(getattr(uploaded_file, "name", uploaded_file))

        status_messages: list[str] = []
        last_fraction = 0.0

        def on_progress(message: str) -> None:
            nonlocal last_fraction
            status_messages.append(message)
            # Unrecognised messages keep the last known fraction instead of a
            # fixed fallback, so the bar never jumps backwards.
            last_fraction = _STEP_PROGRESS.get(message, last_fraction)
            progress(last_fraction, desc=message)

        result = predictor.predict_fasta(file_path, progress_callback=on_progress)
        progress(1.0, desc="Report ready")

        predictions = result["predictions"][_DETAIL_COLUMNS].copy()
        mechanisms = result["mechanisms"].copy()

        # One fresh directory per run. A shared temp path keyed only on the
        # upload's file name would let a later run with the same name
        # overwrite an earlier user's download with different results.
        out_dir = Path(tempfile.mkdtemp(prefix="pamic_"))
        out_csv = out_dir / f"pamic_flow_prediction_{file_path.stem}.csv"
        predictions.to_csv(out_csv, index=False)

        status = "### Pipeline status\n\n" + "\n".join(
            f"- {m}" for m in [*status_messages, "Report ready"]
        )
        return (
            _format_qc(result["qc"]),
            render_prediction_overview(predictions),
            status,
            predictions,
            render_mechanisms(mechanisms),
            str(out_csv),
        )
    except Exception as exc:  # UI boundary: report the failure instead of crashing the app
        message = "### Prediction failed\n\n"
        message += f"`{type(exc).__name__}: {exc}`\n\n"
        message += "<details><summary>Traceback</summary>\n\n"
        # Keep only the tail of long tracebacks.
        message += "```text\n" + traceback.format_exc()[-4000:] + "\n```\n</details>"
        return message, "", "", pd.DataFrame(), "", None


def _format_metric(value) -> str:
    """Format a validation metric with two decimals, or ``n/a`` if not numeric."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return "n/a"
    return "n/a" if math.isnan(number) else f"{number:.2f}"


def _render_card(drug: str, row: pd.Series) -> str:
    """Render one antibiotic's prediction card as an HTML fragment.

    All predictor-supplied text is HTML-escaped before insertion.
    """
    tier = str(row["reliability"])
    sns = str(row["predicted_S_NS"])
    mic = str(row["predicted_MIC_mg_L"])
    sns_class = "sns-ns" if sns == "NS" else "sns-s"
    tier_class = tier_to_class(tier)
    mic_width = mic_to_width(mic)
    cutoff = html.escape(clean_cutoff(str(row["model_breakpoint"])))
    ca = _format_metric(row["local_validation_CA"])
    ea = _format_metric(row["local_validation_EA"])
    interpretation = html.escape(str(row["interpretation"]))
    return f"""
<div class="drug-card {tier_class}">
  <div class="drug-card-head">
    <span class="drug-name">{html.escape(drug)}</span>
    <span class="sns-badge {sns_class}">{html.escape(sns)}</span>
  </div>
  <div class="mic-line"><span class="mic-value">{html.escape(mic)}</span><span class="mic-unit">mg/L</span></div>
  <div class="bar-track"><div class="bar-fill {sns_class}" style="width: {mic_width}%;"></div></div>
  <div class="cutoff">S/NS cutoff used: {cutoff}</div>
  <div class="tier">{html.escape(tier)}</div>
  <div class="metrics">Local CA {ca} · EA {ea}</div>
  <details>
    <summary>Interpretation</summary>
    <p>{interpretation}</p>
  </details>
</div>
"""


def render_prediction_overview(predictions: pd.DataFrame) -> str:
    """Build the overview HTML: a "Main readout" box plus one card per antibiotic.

    Cards follow ``_CARD_ORDER``; antibiotics missing from *predictions* are
    skipped. The main readout lists MIC and S/NS for ``_PRIMARY_DRUGS``.
    """
    present = set(predictions["antibiotic"])
    selected = [drug for drug in _CARD_ORDER if drug in present]
    table = predictions.set_index("antibiotic").loc[selected]
    cards = "".join(_render_card(drug, row) for drug, row in table.iterrows())

    primary = []
    for drug in _PRIMARY_DRUGS:
        if drug not in present:
            continue
        row = predictions[predictions["antibiotic"] == drug].iloc[0]
        primary.append(
            f"{drug}: {html.escape(str(row['predicted_MIC_mg_L']))} mg/L, "
            f"{html.escape(str(row['predicted_S_NS']))}"
        )
    readout = "   |   ".join(primary)

    return f"""
<div class="main-readout">
  <h3>Main readout</h3>
  <p>{readout}</p>
  <p class="note">The card color reflects endpoint reliability synthesized from internal validation, published public external validation, and local MIC-only validation.</p>
</div>
<div class="drug-grid">
{cards}
</div>
"""


def render_mechanisms(mechanisms: pd.DataFrame) -> str:
    """Render the mechanism table (``feature``/``value`` columns) as HTML.

    The ``Top RGI hits`` feature is shown as up to ``_MAX_RGI_BADGES`` badges
    parsed from its comma-separated value. Every other feature is shown as
    "Detected" when its value is one of ``_TRUTHY_FLAGS``.
    """
    rows = []
    for _, row in mechanisms.iterrows():
        raw_feature = str(row["feature"])
        feature = html.escape(raw_feature)
        value = str(row["value"])
        if raw_feature == "Top RGI hits":
            hits = [h.strip() for h in value.split(",") if h.strip()]
            badges = "".join(
                f'<span class="hit-badge">{html.escape(h)}</span>'
                for h in hits[:_MAX_RGI_BADGES]
            )
            rows.append(
                f'<div class="mech-row wide"><strong>{feature}</strong>'
                f'<div class="hit-wrap">{badges}</div></div>'
            )
        else:
            active = value in _TRUTHY_FLAGS
            cls = "mech-on" if active else "mech-off"
            label = "Detected" if active else "Not detected"
            rows.append(
                f'<div class="mech-row"><span>{feature}</span>'
                f'<span class="{cls}">{label}</span></div>'
            )
    return '<div class="mechanism-panel">' + "".join(rows) + "</div>"


def tier_to_class(tier: str) -> str:
    """Map a reliability label to the CSS class that colours the card border.

    Checks run in order: the exact "Recommended endpoint" label is green; labels
    mentioning IPM or high-resistance are blue; warning/supportive labels are
    amber; anything else is red.
    """
    if tier == "Recommended endpoint":
        return "tier-green"
    if "IPM" in tier or "high-resistance" in tier:
        return "tier-blue"
    if "warning" in tier or "Supportive" in tier or "supportive" in tier:
        return "tier-amber"
    return "tier-red"


def mic_to_width(mic: str) -> int:
    """Map a MIC string (mg/L) to a bar width percentage on a log2 scale.

    0.25 mg/L or less maps to 12 and 64 mg/L or more to 94, linear in log2 in
    between. Non-positive values give 10; unparsable or NaN values give 20.
    """
    try:
        val = float(mic)
    except (TypeError, ValueError):
        return 20
    # NaN would otherwise slip through the clamps below as a full-width bar.
    if math.isnan(val):
        return 20
    if val <= 0:
        return 10
    log_val = max(-2.0, min(6.0, math.log2(val)))
    return int(12 + (log_val + 2.0) / 8.0 * 82)


def clean_cutoff(cutoff: str) -> str:
    """Drop the ``(model v6 threshold)`` annotation from a breakpoint label."""
    return cutoff.replace(" (model v6 threshold)", "")


CSS = """
.main-readout { border: 1px solid #d8e2dc; border-radius: 8px; padding: 14px 16px; background: #f7fbf9; margin: 8px 0 14px 0; }
.main-readout h3 { margin: 0 0 8px 0; }
.main-readout .note { color: #4b5563; font-size: 0.92rem; margin-bottom: 0; }
.drug-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); gap: 12px; }
.drug-card { border: 1px solid #d6dbe1; border-left-width: 5px; border-radius: 8px; padding: 12px; background: white; }
.drug-card-head { display: flex; justify-content: space-between; align-items: center; gap: 8px; }
.drug-name { font-weight: 700; font-size: 1.08rem; }
.sns-badge { border-radius: 999px; padding: 3px 9px; font-weight: 700; font-size: 0.86rem; }
.sns-s { background: #e8f7ef; color: #137547; }
.sns-ns { background: #fde8e8; color: #b42318; }
.mic-line { margin-top: 10px; }
.mic-value { font-size: 1.75rem; font-weight: 760; }
.mic-unit { margin-left: 4px; color: #4b5563; }
.bar-track { height: 8px; border-radius: 999px; background: #eef2f7; overflow: hidden; margin: 8px 0; }
.bar-fill { height: 100%; border-radius: 999px; }
.bar-fill.sns-s { background: #37b26c; }
.bar-fill.sns-ns { background: #e45b5b; }
.cutoff, .metrics { color: #4b5563; font-size: 0.88rem; }
.tier { font-weight: 650; margin-top: 8px; }
.tier-green { border-left-color: #1f9d55; }
.tier-blue { border-left-color: #2f80ed; }
.tier-amber { border-left-color: #f59e0b; }
.tier-red { border-left-color: #ef4444; }
.drug-card details { margin-top: 8px; color: #374151; }
.drug-card summary { cursor: pointer; font-weight: 600; }
.mechanism-panel { display: grid; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); gap: 8px; }
.mech-row { border: 1px solid #e5e7eb; border-radius: 8px; padding: 8px 10px; display: flex; justify-content: space-between; gap: 10px; background: white; }
.mech-row.wide { grid-column: 1 / -1; display: block; }
.mech-on { color: #b42318; font-weight: 700; }
.mech-off { color: #6b7280; }
.hit-wrap { margin-top: 8px; display: flex; flex-wrap: wrap; gap: 6px; }
.hit-badge { border: 1px solid #d6dbe1; border-radius: 999px; padding: 3px 8px; background: #f9fafb; font-size: 0.86rem; }
"""


with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DISCLAIMER)
    gr.Markdown(f"**Runtime status:** {runtime_status()}")

    with gr.Row():
        fasta_input = gr.File(
            label="Upload one assembled FASTA file",
            file_types=[".fasta", ".fa", ".fna"],
        )
        run_button = gr.Button("Predict", variant="primary")

    qc_output = gr.Markdown(label="Genome QC")
    overview_output = gr.HTML(label="Prediction overview")
    status_output = gr.Markdown(label="Pipeline status")
    prediction_output = gr.Dataframe(
        label="Detailed prediction table",
        wrap=True,
        interactive=False,
    )
    mechanism_output = gr.HTML(
        label="Detected resistance mechanisms",
    )
    csv_output = gr.File(label="Download prediction CSV")

    # Output order must match the tuple returned by predict().
    run_button.click(
        predict,
        inputs=fasta_input,
        outputs=[
            qc_output,
            overview_output,
            status_output,
            prediction_output,
            mechanism_output,
            csv_output,
        ],
    )

    gr.Markdown(
        """
### Endpoint reliability

The app reports endpoint-specific reliability flags derived from internal validation, published public external validation, and local MIC-only validation. In the current model version, AMK is the most consistent endpoint; IPM is usable for MIC/high-resistance screening; TOB is public-supported but carries a local-calibration warning; remaining endpoints are supportive or exploratory.
"""
    )


if __name__ == "__main__":
    # Concurrency limit 1: the pipeline runs heavy external tools (RGI, Snippy).
    # NOTE(review): `css` as a launch() argument matches recent Gradio (6.x);
    # older releases take it on gr.Blocks(css=...) instead — confirm the pinned version.
    demo.queue(default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860, css=CSS
    )