Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import traceback | |
| import tempfile | |
| import html | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
| from pa_mic_predictor import PAMICPredictor | |
# Page title: passed to gr.Blocks(title=...) and rendered as the top-level heading.
TITLE = "P. aeruginosa WGS-MIC Predictor"

# Research-use disclaimer, rendered as Markdown right after the heading.
DISCLAIMER = """
This Space is a research-use demonstrator for assembled *Pseudomonas aeruginosa*
FASTA files. It predicts log2 MIC-derived MIC values and S/NS categories for
eight antibiotics using the HMT-CORAL v6 model. Results are not intended for
clinical diagnosis or treatment decisions.
"""

# One module-level predictor instance, created at import time and reused by
# runtime_status() and predict().
predictor = PAMICPredictor()
def runtime_status() -> str:
    """Report whether the command-line tools the predictor needs are available.

    Returns:
        A one-line message. When ``predictor.check_runtime()`` returns a
        non-empty collection, the entries are listed comma-separated together
        with a hint about rebuilding the Docker image; otherwise a
        "check passed" message is returned.
    """
    absent_tools = predictor.check_runtime()
    if not absent_tools:
        return "Runtime check passed: RGI and Snippy are available."

    tool_list = ", ".join(absent_tools)
    return (
        f"Runtime is not fully configured. Missing command-line tools: {tool_list}"
        ". Build the Docker Space image with RGI, CARD, and Snippy support."
    )
def predict(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """Run the prediction pipeline on one uploaded FASTA and build the UI outputs.

    Args:
        uploaded_file: Value delivered by the ``gr.File`` input. May be a plain
            path (``str`` / ``Path``) or an object exposing the path as
            ``.name``; ``None`` when nothing was uploaded.
        progress: Gradio progress tracker, supplied by Gradio through the
            default argument.

    Returns:
        A 6-tuple in the order of the ``outputs`` wired to the Predict button:
        (QC markdown, overview HTML, pipeline-status markdown, prediction
        table, mechanism HTML, path of the CSV download or ``None``).
        When no file was uploaded, or when any step raises, the first element
        carries the message and the remaining outputs are empty.
    """
    if uploaded_file is None:
        return "Please upload one assembled FASTA file.", "", "", pd.DataFrame(), "", None
    try:
        # Depending on the Gradio version / ``type`` setting the handler gets
        # either a path string or a wrapper whose path lives in ``.name``.
        # Accept both instead of failing with AttributeError on a plain str.
        file_path = Path(getattr(uploaded_file, "name", uploaded_file))

        # Progress fraction shown for each message the predictor reports.
        step_map = {
            "Copying uploaded FASTA": 0.05,
            "Running basic genome QC": 0.10,
            "Running RGI/CARD annotation. This is usually a few minutes on CPU.": 0.25,
            "Running Snippy against PAO1. This is usually the slowest step.": 0.60,
            "Building the 442-feature model input": 0.85,
            "Running HMT-CORAL v6 inference": 0.93,
            "Formatting report": 0.98,
        }
        status_messages = []

        def on_progress(message: str) -> None:
            # Record every message for the status panel; unknown messages
            # fall back to the midpoint of the progress bar.
            status_messages.append(message)
            progress(step_map.get(message, 0.5), desc=message)

        result = predictor.predict_fasta(file_path, progress_callback=on_progress)
        progress(1.0, desc="Report ready")

        qc = result["qc"]
        qc_lines = [
            "### Genome QC",
            f"- Contigs: `{qc['contigs']}`",
            f"- Total length: `{qc['total_bp']:,}` bp",
            f"- GC content: `{qc['gc_percent']}%`",
            f"- N content: `{qc['n_percent']}%`",
            f"- N50: `{qc['n50']:,}` bp",
        ]
        if qc.get("warning"):
            qc_lines.append(f"- Warning: {qc['warning']}")

        # Keep only the columns shown in the table and written to the CSV,
        # in display order.
        predictions = result["predictions"].copy()
        predictions = predictions[
            [
                "antibiotic",
                "predicted_MIC_mg_L",
                "predicted_S_NS",
                "confidence",
                "reliability",
                "model_breakpoint",
                "local_validation_CA",
                "local_validation_EA",
                "interpretation",
            ]
        ]
        mechanisms = result["mechanisms"].copy()

        out_csv = Path(tempfile.gettempdir()) / f"pamic_flow_prediction_{file_path.stem}.csv"
        predictions.to_csv(out_csv, index=False)

        overview = render_prediction_overview(predictions)
        status = "### Pipeline status\n\n" + "\n".join(f"- {m}" for m in status_messages + ["Report ready"])
        mechanism_html = render_mechanisms(mechanisms)
        return "\n".join(qc_lines), overview, status, predictions, mechanism_html, str(out_csv)
    except Exception as exc:
        # UI boundary: surface the failure in the first output instead of
        # letting the exception propagate. Only the tail of the traceback is
        # included to keep the message bounded.
        message = "### Prediction failed\n\n"
        message += f"`{type(exc).__name__}: {exc}`\n\n"
        message += "<details><summary>Traceback</summary>\n\n"
        message += "```text\n" + traceback.format_exc()[-4000:] + "\n```\n</details>"
        return message, "", "", pd.DataFrame(), "", None
def render_prediction_overview(predictions: pd.DataFrame) -> str:
    """Render the headline AMK/IPM readout and one result card per antibiotic.

    Args:
        predictions: Table with at least the columns ``antibiotic``,
            ``predicted_MIC_mg_L``, ``predicted_S_NS``, ``reliability``,
            ``model_breakpoint``, ``local_validation_CA``,
            ``local_validation_EA`` and ``interpretation``.

    Returns:
        An HTML fragment: a ``main-readout`` section followed by a
        ``drug-grid`` section. Cards follow a fixed drug order; drugs absent
        from ``predictions`` are skipped. Text taken from the table is
        HTML-escaped before being embedded.
    """
    display_order = ["AMK", "IPM", "TOB", "FEP", "MEM", "CAZ", "TZP", "LVX"]
    reported = predictions["antibiotic"].values
    shown = [code for code in display_order if code in reported]
    by_drug = predictions.set_index("antibiotic").loc[shown]

    def build_card(drug, row) -> str:
        # One card: name + S/NS badge, MIC value with bar, cutoff, tier, metrics.
        tier = str(row["reliability"])
        tier_class = tier_to_class(tier)
        sns = str(row["predicted_S_NS"])
        sns_class = "sns-s" if sns != "NS" else "sns-ns"
        mic = str(row["predicted_MIC_mg_L"])
        mic_width = mic_to_width(mic)
        cutoff = clean_cutoff(str(row["model_breakpoint"]))
        interpretation = html.escape(str(row["interpretation"]))
        return f"""
<div class="drug-card {tier_class}">
<div class="drug-card-head">
<div class="drug-name">{html.escape(drug)}</div>
<div class="sns-badge {sns_class}">{html.escape(sns)}</div>
</div>
<div class="mic-line"><span class="mic-value">{html.escape(mic)}</span><span class="mic-unit">mg/L</span></div>
<div class="bar-track"><div class="bar-fill {sns_class}" style="width: {mic_width}%"></div></div>
<div class="cutoff">S/NS cutoff used: {html.escape(cutoff)}</div>
<div class="tier">{html.escape(tier)}</div>
<div class="metrics">Local CA {float(row['local_validation_CA']):.2f} · EA {float(row['local_validation_EA']):.2f}</div>
<details><summary>Interpretation</summary><p>{interpretation}</p></details>
</div>
"""

    card_html = [build_card(drug, row) for drug, row in by_drug.iterrows()]

    # Headline line: the first matching row for each of the two main drugs.
    headline = []
    for code in ("AMK", "IPM"):
        if code not in reported:
            continue
        first_hit = predictions[predictions["antibiotic"] == code].iloc[0]
        mic_text = html.escape(str(first_hit["predicted_MIC_mg_L"]))
        sns_text = html.escape(str(first_hit["predicted_S_NS"]))
        headline.append(
            f"<b>{code}</b>: {mic_text} mg/L, "
            f"<span class='inline-sns'>{sns_text}</span>"
        )

    return f"""
<section class="main-readout">
<h3>Main readout</h3>
<p>{' | '.join(headline)}</p>
<p class="note">The card color reflects endpoint reliability synthesized from internal validation, published public external validation, and local MIC-only validation.</p>
</section>
<section class="drug-grid">
{''.join(card_html)}
</section>
"""
def render_mechanisms(mechanisms: pd.DataFrame) -> str:
    """Render the mechanism table as an HTML panel.

    Args:
        mechanisms: Table with ``feature`` and ``value`` columns.

    Returns:
        A ``mechanism-panel`` div with one row per table row. The row whose
        feature is "Top RGI hits" is rendered as a wide row of badges built
        from its comma-separated value (blank entries dropped, at most 40
        shown). Every other row shows "Detected" when the stringified value
        is one of ``1``, ``1.0``, ``True``, ``true`` and "Not detected"
        otherwise. Feature names and hit names are HTML-escaped.
    """
    truthy_values = {"1", "1.0", "True", "true"}
    fragments = []
    for _, record in mechanisms.iterrows():
        name = html.escape(str(record["feature"]))
        raw = str(record["value"])

        if name != "Top RGI hits":
            if raw in truthy_values:
                css_class, text = "mech-on", "Detected"
            else:
                css_class, text = "mech-off", "Not detected"
            fragments.append(
                f"<div class='mech-row'><span>{name}</span><span class='{css_class}'>{text}</span></div>"
            )
            continue

        stripped = (part.strip() for part in raw.split(","))
        hit_names = [part for part in stripped if part]
        badge_html = "".join(
            f"<span class='hit-badge'>{html.escape(hit)}</span>" for hit in hit_names[:40]
        )
        fragments.append(
            f"<div class='mech-row wide'><b>{name}</b><div class='hit-wrap'>{badge_html}</div></div>"
        )
    return "<div class='mechanism-panel'>" + "".join(fragments) + "</div>"
def tier_to_class(tier: str) -> str:
    """Map a reliability tier label to the CSS class used for the card border.

    An exact "Recommended endpoint" label maps to ``tier-green``. Otherwise
    the label is matched by substring, in order: "IPM" or "high-resistance"
    gives ``tier-blue``; "warning", "Supportive" or "supportive" gives
    ``tier-amber``. Anything else maps to ``tier-red``.
    """
    if tier == "Recommended endpoint":
        return "tier-green"

    # Ordered rules: the first group with any keyword present wins.
    keyword_rules = (
        (("IPM", "high-resistance"), "tier-blue"),
        (("warning", "Supportive", "supportive"), "tier-amber"),
    )
    for keywords, css_class in keyword_rules:
        if any(word in tier for word in keywords):
            return css_class
    return "tier-red"
def mic_to_width(mic: str) -> int:
    """Convert a MIC value given as text into a bar width in percent.

    Args:
        mic: MIC in mg/L as a string, e.g. ``"4.0"``.

    Returns:
        ``20`` when ``mic`` cannot be parsed as a number or is NaN, ``10``
        when it is zero or negative, otherwise a width between 12 and 94
        that grows linearly with ``log2(mic)`` over the clamped range
        0.25-64 mg/L.
    """
    import math  # function-scope import keeps this block self-contained

    try:
        val = float(mic)
    except (TypeError, ValueError, OverflowError):
        return 20
    if math.isnan(val):
        # A missing MIC stringifies to "nan", which float() accepts. Without
        # this guard the clamp below would turn it into a full-width bar.
        return 20
    if val <= 0:
        return 10
    # Map 0.25-64 mg/L (log2 -2..6) onto widths 12..94; clamp values outside.
    log_val = max(-2.0, min(6.0, math.log2(val)))
    return int(12 + (log_val + 2.0) / 8.0 * 82)
def clean_cutoff(cutoff: str) -> str:
    """Return ``cutoff`` with every " (model v6 threshold)" annotation removed."""
    annotation = " (model v6 threshold)"
    # Splitting on the annotation and re-joining drops all of its occurrences.
    return "".join(cutoff.split(annotation))
# Stylesheet for the HTML produced by render_prediction_overview()
# (.main-readout, .drug-grid, .drug-card, .sns-*, .bar-*, .tier-*) and
# render_mechanisms() (.mechanism-panel, .mech-*, .hit-*). It is handed to
# Gradio through launch(css=CSS) in the __main__ guard.
CSS = """
.main-readout {
border: 1px solid #d8e2dc;
border-radius: 8px;
padding: 14px 16px;
background: #f7fbf9;
margin: 8px 0 14px 0;
}
.main-readout h3 { margin: 0 0 8px 0; }
.main-readout .note { color: #4b5563; font-size: 0.92rem; margin-bottom: 0; }
.drug-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
gap: 12px;
}
.drug-card {
border: 1px solid #d6dbe1;
border-left-width: 5px;
border-radius: 8px;
padding: 12px;
background: white;
}
.drug-card-head { display: flex; justify-content: space-between; align-items: center; gap: 8px; }
.drug-name { font-weight: 700; font-size: 1.08rem; }
.sns-badge {
border-radius: 999px;
padding: 3px 9px;
font-weight: 700;
font-size: 0.86rem;
}
.sns-s { background: #e8f7ef; color: #137547; }
.sns-ns { background: #fde8e8; color: #b42318; }
.mic-line { margin-top: 10px; }
.mic-value { font-size: 1.75rem; font-weight: 760; }
.mic-unit { margin-left: 4px; color: #4b5563; }
.bar-track { height: 8px; border-radius: 999px; background: #eef2f7; overflow: hidden; margin: 8px 0; }
.bar-fill { height: 100%; border-radius: 999px; }
.bar-fill.sns-s { background: #37b26c; }
.bar-fill.sns-ns { background: #e45b5b; }
.cutoff, .metrics { color: #4b5563; font-size: 0.88rem; }
.tier { font-weight: 650; margin-top: 8px; }
.tier-green { border-left-color: #1f9d55; }
.tier-blue { border-left-color: #2f80ed; }
.tier-amber { border-left-color: #f59e0b; }
.tier-red { border-left-color: #ef4444; }
.drug-card details { margin-top: 8px; color: #374151; }
.drug-card summary { cursor: pointer; font-weight: 600; }
.mechanism-panel {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
gap: 8px;
}
.mech-row {
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 8px 10px;
display: flex;
justify-content: space-between;
gap: 10px;
background: white;
}
.mech-row.wide { grid-column: 1 / -1; display: block; }
.mech-on { color: #b42318; font-weight: 700; }
.mech-off { color: #6b7280; }
.hit-wrap { margin-top: 8px; display: flex; flex-wrap: wrap; gap: 6px; }
.hit-badge {
border: 1px solid #d6dbe1;
border-radius: 999px;
padding: 3px 8px;
background: #f9fafb;
font-size: 0.86rem;
}
"""
# Page layout: heading, disclaimer, runtime status, upload + button, then the
# six outputs that predict() fills, followed by a static reliability note.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DISCLAIMER)
    # runtime_status() is called once here, while the layout is being built,
    # so the message reflects the state at import time.
    gr.Markdown(f"**Runtime status:** {runtime_status()}")
    # NOTE(review): the source's indentation was lost; the upload widget and
    # the button are assumed to be the only children of this row - confirm.
    with gr.Row():
        fasta_input = gr.File(
            label="Upload one assembled FASTA file",
            file_types=[".fasta", ".fa", ".fna"],
        )
        run_button = gr.Button("Predict", variant="primary")
    qc_output = gr.Markdown(label="Genome QC")
    overview_output = gr.HTML(label="Prediction overview")
    status_output = gr.Markdown(label="Pipeline status")
    prediction_output = gr.Dataframe(
        label="Detailed prediction table",
        wrap=True,
        interactive=False,
    )
    mechanism_output = gr.HTML(
        label="Detected resistance mechanisms",
    )
    csv_output = gr.File(label="Download prediction CSV")
    # The order of `outputs` must match the 6-tuple returned by predict().
    run_button.click(
        predict,
        inputs=fasta_input,
        outputs=[qc_output, overview_output, status_output, prediction_output, mechanism_output, csv_output],
    )
    gr.Markdown(
        """
### Endpoint reliability
The app reports endpoint-specific reliability flags derived from internal
validation, published public external validation, and local MIC-only validation.
In the current model version, AMK is the most consistent endpoint; IPM is
usable for MIC/high-resistance screening; TOB is public-supported but carries a
local-calibration warning; remaining endpoints are supportive or exploratory.
"""
    )

# Script entry point: enable the queue with a default concurrency limit of 1
# and serve on all interfaces, port 7860, with the custom stylesheet.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", server_port=7860, css=CSS)