# PaMIC-Flow / app.py
# Uploaded by fudan-renjun (commit 10cd136, verified).
from __future__ import annotations

import html
import math
import tempfile
import traceback
from pathlib import Path

import gradio as gr
import pandas as pd

from pa_mic_predictor import PAMICPredictor
# Page heading (rendered as "# TITLE") and browser-tab title of the Gradio app.
TITLE = "P. aeruginosa WGS-MIC Predictor"
# Research-use disclaimer rendered as Markdown directly under the heading.
DISCLAIMER = """
This Space is a research-use demonstrator for assembled *Pseudomonas aeruginosa*
FASTA files. It predicts log2 MIC-derived MIC values and S/NS categories for
eight antibiotics using the HMT-CORAL v6 model. Results are not intended for
clinical diagnosis or treatment decisions.
"""
# Single module-level predictor shared by runtime_status() and every predict() call.
# NOTE(review): constructed at import time, so whatever PAMICPredictor.__init__ loads
# (model weights, reference data?) is paid once on startup — confirm in pa_mic_predictor.
predictor = PAMICPredictor()
def runtime_status() -> str:
    """Summarize whether the external command-line tools needed by the predictor exist.

    Returns:
        A one-line, user-facing message. When ``predictor.check_runtime()``
        reports missing tools, they are listed by name; otherwise a "passed"
        message is returned.
    """
    missing_tools = predictor.check_runtime()
    if not missing_tools:
        return "Runtime check passed: RGI and Snippy are available."
    tool_list = ", ".join(missing_tools)
    return (
        f"Runtime is not fully configured. Missing command-line tools: {tool_list}"
        ". Build the Docker Space image with RGI, CARD, and Snippy support."
    )
# Columns of the detailed prediction table (and the exported CSV), in display order.
_PREDICTION_COLUMNS = [
    "antibiotic",
    "predicted_MIC_mg_L",
    "predicted_S_NS",
    "confidence",
    "reliability",
    "model_breakpoint",
    "local_validation_CA",
    "local_validation_EA",
    "interpretation",
]

# Progress-bar fraction shown when the predictor reports each known step message.
_STEP_PROGRESS = {
    "Copying uploaded FASTA": 0.05,
    "Running basic genome QC": 0.10,
    "Running RGI/CARD annotation. This is usually a few minutes on CPU.": 0.25,
    "Running Snippy against PAO1. This is usually the slowest step.": 0.60,
    "Building the 442-feature model input": 0.85,
    "Running HMT-CORAL v6 inference": 0.93,
    "Formatting report": 0.98,
}


def _format_qc(qc: dict) -> str:
    """Render the predictor's genome-QC dict as a Markdown bullet list."""
    lines = [
        "### Genome QC",
        f"- Contigs: `{qc['contigs']}`",
        f"- Total length: `{qc['total_bp']:,}` bp",
        f"- GC content: `{qc['gc_percent']}%`",
        f"- N content: `{qc['n_percent']}%`",
        f"- N50: `{qc['n50']:,}` bp",
    ]
    if qc.get("warning"):
        lines.append(f"- Warning: {qc['warning']}")
    return "\n".join(lines)


def _format_status(messages: list[str]) -> str:
    """Render pipeline step messages as a Markdown "Pipeline status" section."""
    return "### Pipeline status\n\n" + "\n".join(f"- {m}" for m in messages)


def predict(uploaded_file, progress=gr.Progress(track_tqdm=True)):
    """Run the full prediction pipeline on one uploaded FASTA and format every output.

    Args:
        uploaded_file: Value of the ``gr.File`` input. This is either a path
            string or an object exposing the path as ``.name``. It is ``None``
            when nothing was uploaded.
        progress: Gradio progress tracker used to report per-step progress.

    Returns:
        A 6-tuple matching the click outputs: QC Markdown, overview HTML,
        pipeline-status Markdown, prediction DataFrame, mechanism HTML and the
        CSV path (or ``None``). On any pipeline error the first element is an
        error report with a truncated traceback. The status element then lists
        the steps that had completed.
    """
    if uploaded_file is None:
        return "Please upload one assembled FASTA file.", "", "", pd.DataFrame(), "", None

    status_messages: list[str] = []
    last_fraction = 0.0

    def on_progress(message: str) -> None:
        nonlocal last_fraction
        status_messages.append(message)
        # Unrecognized messages keep the bar where it is, and the bar never moves backwards.
        last_fraction = max(last_fraction, _STEP_PROGRESS.get(message, last_fraction))
        progress(last_fraction, desc=message)

    try:
        # Gradio may pass a plain path string or a tempfile-like object exposing .name.
        file_path = Path(getattr(uploaded_file, "name", uploaded_file))
        result = predictor.predict_fasta(file_path, progress_callback=on_progress)
        progress(1.0, desc="Report ready")

        predictions = result["predictions"][_PREDICTION_COLUMNS].copy()
        mechanisms = result["mechanisms"].copy()

        out_csv = Path(tempfile.gettempdir()) / f"pamic_flow_prediction_{file_path.stem}.csv"
        predictions.to_csv(out_csv, index=False)

        return (
            _format_qc(result["qc"]),
            render_prediction_overview(predictions),
            _format_status(status_messages + ["Report ready"]),
            predictions,
            render_mechanisms(mechanisms),
            str(out_csv),
        )
    except Exception as exc:  # UI boundary: report any failure to the user instead of crashing.
        message = "### Prediction failed\n\n"
        message += f"`{type(exc).__name__}: {exc}`\n\n"
        message += "<details><summary>Traceback</summary>\n\n"
        message += "```text\n" + traceback.format_exc()[-4000:] + "\n```\n</details>"
        # Keep the steps that did complete so the user can see where the pipeline stopped.
        status = _format_status(status_messages + ["Failed"]) if status_messages else ""
        return message, "", status, pd.DataFrame(), "", None
# Preferred card order; antibiotics not listed here follow in their table order.
_DRUG_ORDER = ("AMK", "IPM", "TOB", "FEP", "MEM", "CAZ", "TZP", "LVX")
# Endpoints summarized in the one-line "Main readout" above the card grid.
_PRIMARY_DRUGS = ("AMK", "IPM")


def _format_metric(value) -> str:
    """Format a validation metric with two decimals, or "n/a" if missing or non-numeric."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return "n/a"
    return "n/a" if math.isnan(number) else f"{number:.2f}"


def _render_drug_card(row: pd.Series) -> str:
    """Render one antibiotic's prediction row as an HTML card (all text HTML-escaped)."""
    drug = str(row["antibiotic"])
    tier = str(row["reliability"])
    sns = str(row["predicted_S_NS"])
    sns_class = "sns-ns" if sns == "NS" else "sns-s"
    mic = str(row["predicted_MIC_mg_L"])
    cutoff = clean_cutoff(str(row["model_breakpoint"]))
    interpretation = html.escape(str(row["interpretation"]))
    ca = _format_metric(row["local_validation_CA"])
    ea = _format_metric(row["local_validation_EA"])
    return f"""
<div class="drug-card {tier_to_class(tier)}">
  <div class="drug-card-head">
    <div class="drug-name">{html.escape(drug)}</div>
    <div class="sns-badge {sns_class}">{html.escape(sns)}</div>
  </div>
  <div class="mic-line"><span class="mic-value">{html.escape(mic)}</span><span class="mic-unit">mg/L</span></div>
  <div class="bar-track"><div class="bar-fill {sns_class}" style="width: {mic_to_width(mic)}%"></div></div>
  <div class="cutoff">S/NS cutoff used: {html.escape(cutoff)}</div>
  <div class="tier">{html.escape(tier)}</div>
  <div class="metrics">Local CA {ca} · EA {ea}</div>
  <details><summary>Interpretation</summary><p>{interpretation}</p></details>
</div>
"""


def render_prediction_overview(predictions: pd.DataFrame) -> str:
    """Build the HTML overview: a "Main readout" line plus one card per antibiotic.

    Args:
        predictions: Prediction table with at least the columns read here
            (antibiotic, predicted_MIC_mg_L, predicted_S_NS, reliability,
            model_breakpoint, local_validation_CA, local_validation_EA,
            interpretation).

    Returns:
        An HTML fragment styled by the app's CSS classes. Cards follow
        ``_DRUG_ORDER``. Antibiotics outside that list are appended in table
        order rather than dropped.
    """
    rank = {drug: position for position, drug in enumerate(_DRUG_ORDER)}
    rows = [row for _, row in predictions.iterrows()]
    # list.sort is stable, so unlisted antibiotics keep their original relative order.
    rows.sort(key=lambda row: rank.get(row["antibiotic"], len(rank)))
    cards = "".join(_render_drug_card(row) for row in rows)

    primary = []
    for drug in _PRIMARY_DRUGS:
        matches = predictions[predictions["antibiotic"] == drug]
        if matches.empty:
            continue
        row = matches.iloc[0]
        primary.append(
            f"<b>{drug}</b>: {html.escape(str(row['predicted_MIC_mg_L']))} mg/L, "
            f"<span class='inline-sns'>{html.escape(str(row['predicted_S_NS']))}</span>"
        )

    return f"""
<section class="main-readout">
<h3>Main readout</h3>
<p>{' &nbsp; | &nbsp; '.join(primary)}</p>
<p class="note">The card color reflects endpoint reliability synthesized from internal validation, published public external validation, and local MIC-only validation.</p>
</section>
<section class="drug-grid">
{cards}
</section>
"""
# Cell values (after strip + lower-case) that mark a mechanism feature as detected.
_DETECTED_VALUES = frozenset({"1", "1.0", "true"})


def render_mechanisms(mechanisms: pd.DataFrame, *, max_hits: int = 40) -> str:
    """Render resistance-mechanism features as an HTML panel.

    Args:
        mechanisms: Table with ``feature`` and ``value`` columns. The row whose
            feature is "Top RGI hits" holds a comma-separated list of hit names.
            Every other row is a detected/not-detected flag.
        max_hits: Maximum number of RGI hit badges to show. Any remainder is
            summarized as a single "+N more" badge instead of being silently dropped.

    Returns:
        An HTML fragment (a ``mechanism-panel`` div), with all data text HTML-escaped.
    """
    limit = max(max_hits, 0)
    rows = []
    for _, row in mechanisms.iterrows():
        feature = html.escape(str(row["feature"]))
        raw = row["value"]
        # Missing cells would otherwise render as the literal text "nan".
        value = "" if pd.api.types.is_scalar(raw) and pd.isna(raw) else str(raw).strip()
        if feature == "Top RGI hits":
            hits = [h.strip() for h in value.split(",") if h.strip()]
            shown = hits[:limit]
            badges = "".join(f"<span class='hit-badge'>{html.escape(h)}</span>" for h in shown)
            hidden = len(hits) - len(shown)
            if hidden > 0:
                badges += f"<span class='hit-badge hit-more'>+{hidden} more</span>"
            rows.append(f"<div class='mech-row wide'><b>{feature}</b><div class='hit-wrap'>{badges}</div></div>")
        else:
            active = value.lower() in _DETECTED_VALUES
            cls = "mech-on" if active else "mech-off"
            label = "Detected" if active else "Not detected"
            rows.append(f"<div class='mech-row'><span>{feature}</span><span class='{cls}'>{label}</span></div>")
    return "<div class='mechanism-panel'>" + "".join(rows) + "</div>"
def tier_to_class(tier: str) -> str:
    """Map a reliability-tier label to the CSS class that colors its card border.

    Precedence: an exact "Recommended endpoint" label is green. Otherwise the
    first matching substring rule applies: IPM / high-resistance labels are
    blue, then warning / supportive labels are amber. Anything else is red.
    """
    if tier == "Recommended endpoint":
        return "tier-green"
    substring_rules = (
        ("tier-blue", ("IPM", "high-resistance")),
        ("tier-amber", ("warning", "Supportive", "supportive")),
    )
    for css_class, markers in substring_rules:
        if any(marker in tier for marker in markers):
            return css_class
    return "tier-red"
def mic_to_width(mic: str) -> int:
    """Convert a predicted MIC (mg/L, given as text) into a bar-width percentage.

    The MIC is placed on a log2 scale clamped to 0.25-64 mg/L (log2 -2..6)
    and mapped linearly onto 12-94 %. A leading comparator such as ``<=``,
    ``>`` or ``≤`` is ignored, so censored values like ``">64"`` land at the
    matching end of the scale.

    Returns:
        The width in percent. It is 20 for text that is not a number (or is NaN),
        and 10 for non-positive values, which have no logarithm.
    """
    try:
        val = float(str(mic).strip().lstrip("<>=≤≥~ "))
    except ValueError:
        return 20
    if math.isnan(val):
        # NaN would otherwise slip through min/max clamping and get a full-width bar.
        return 20
    if val <= 0:
        return 10
    log_val = max(-2.0, min(6.0, math.log2(val)))
    return int(12 + (log_val + 2.0) / 8.0 * 82)
# Annotation attached to model breakpoint labels; stripped before display on the drug cards.
_MODEL_THRESHOLD_SUFFIX = " (model v6 threshold)"


def clean_cutoff(cutoff: str) -> str:
    """Return *cutoff* with every " (model v6 threshold)" annotation removed."""
    pieces = cutoff.split(_MODEL_THRESHOLD_SUFFIX)
    return "".join(pieces)
# Stylesheet for the HTML fragments built by render_prediction_overview (readout,
# drug cards, S/NS badges, MIC bars, tier border colors) and render_mechanisms
# (mechanism rows, RGI hit badges). Passed to launch(css=CSS) in the __main__ block.
CSS = """
.main-readout {
border: 1px solid #d8e2dc;
border-radius: 8px;
padding: 14px 16px;
background: #f7fbf9;
margin: 8px 0 14px 0;
}
.main-readout h3 { margin: 0 0 8px 0; }
.main-readout .note { color: #4b5563; font-size: 0.92rem; margin-bottom: 0; }
.drug-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
gap: 12px;
}
.drug-card {
border: 1px solid #d6dbe1;
border-left-width: 5px;
border-radius: 8px;
padding: 12px;
background: white;
}
.drug-card-head { display: flex; justify-content: space-between; align-items: center; gap: 8px; }
.drug-name { font-weight: 700; font-size: 1.08rem; }
.sns-badge {
border-radius: 999px;
padding: 3px 9px;
font-weight: 700;
font-size: 0.86rem;
}
.sns-s { background: #e8f7ef; color: #137547; }
.sns-ns { background: #fde8e8; color: #b42318; }
.mic-line { margin-top: 10px; }
.mic-value { font-size: 1.75rem; font-weight: 760; }
.mic-unit { margin-left: 4px; color: #4b5563; }
.bar-track { height: 8px; border-radius: 999px; background: #eef2f7; overflow: hidden; margin: 8px 0; }
.bar-fill { height: 100%; border-radius: 999px; }
.bar-fill.sns-s { background: #37b26c; }
.bar-fill.sns-ns { background: #e45b5b; }
.cutoff, .metrics { color: #4b5563; font-size: 0.88rem; }
.tier { font-weight: 650; margin-top: 8px; }
.tier-green { border-left-color: #1f9d55; }
.tier-blue { border-left-color: #2f80ed; }
.tier-amber { border-left-color: #f59e0b; }
.tier-red { border-left-color: #ef4444; }
.drug-card details { margin-top: 8px; color: #374151; }
.drug-card summary { cursor: pointer; font-weight: 600; }
.mechanism-panel {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
gap: 8px;
}
.mech-row {
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 8px 10px;
display: flex;
justify-content: space-between;
gap: 10px;
background: white;
}
.mech-row.wide { grid-column: 1 / -1; display: block; }
.mech-on { color: #b42318; font-weight: 700; }
.mech-off { color: #6b7280; }
.hit-wrap { margin-top: 8px; display: flex; flex-wrap: wrap; gap: 6px; }
.hit-badge {
border: 1px solid #d6dbe1;
border-radius: 999px;
padding: 3px 8px;
background: #f9fafb;
font-size: 0.86rem;
}
"""
# Gradio UI. Components render top-to-bottom in creation order, so the statement
# order below *is* the page layout.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DISCLAIMER)
    # Evaluated once while the UI is built at import time, not per request.
    gr.Markdown(f"**Runtime status:** {runtime_status()}")
    # NOTE(review): indentation was lost in this copy of the file; the Row is assumed to
    # hold only the upload widget and the Predict button — confirm against the original.
    with gr.Row():
        fasta_input = gr.File(
            label="Upload one assembled FASTA file",
            file_types=[".fasta", ".fa", ".fna"],
        )
        run_button = gr.Button("Predict", variant="primary")
    # Output components, in the same order as the 6-tuple returned by predict().
    qc_output = gr.Markdown(label="Genome QC")
    overview_output = gr.HTML(label="Prediction overview")
    status_output = gr.Markdown(label="Pipeline status")
    prediction_output = gr.Dataframe(
        label="Detailed prediction table",
        wrap=True,
        interactive=False,
    )
    mechanism_output = gr.HTML(
        label="Detected resistance mechanisms",
    )
    csv_output = gr.File(label="Download prediction CSV")
    # The outputs list must stay aligned with predict()'s return order.
    run_button.click(
        predict,
        inputs=fasta_input,
        outputs=[qc_output, overview_output, status_output, prediction_output, mechanism_output, csv_output],
    )
    gr.Markdown(
        """
### Endpoint reliability
The app reports endpoint-specific reliability flags derived from internal
validation, published public external validation, and local MIC-only validation.
In the current model version, AMK is the most consistent endpoint; IPM is
usable for MIC/high-resistance screening; TOB is public-supported but carries a
local-calibration warning; remaining endpoints are supportive or exploratory.
"""
    )
if __name__ == "__main__":
    # One queued job at a time; presumably because the RGI/Snippy pipeline is
    # CPU-heavy (NOTE(review): confirm this is the intent of the limit).
    queued_app = demo.queue(default_concurrency_limit=1)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        css=CSS,
    )