"""
AMR-Guard — Gradio Interface (ZeroGPU compatible)
Infection Lifecycle Orchestrator · Multi-Agent Clinical Decision Support
"""
import json
import logging
import os
import subprocess
import sys
import traceback
from io import BytesIO
from pathlib import Path
# Directory containing this file. It is prepended to sys.path so the
# `src.*` imports further down resolve regardless of the working directory.
PROJECT_ROOT = Path(__file__).parent
sys.path.insert(0, str(PROJECT_ROOT))
# Configure logging early so all module-level loggers emit to stdout.
# force=True reconfigures the root logger even if already set by an import.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    force=True,
)
# Module logger used by the handlers below.
logger = logging.getLogger(__name__)
# ── huggingface_hub compatibility shim ───────────────────────────────────────
# Older gradio versions (pre-5.7) import HfFolder from huggingface_hub in
# oauth.py. HfFolder was removed in huggingface_hub >= 0.25. Patch it back
# in-memory before importing gradio so the old oauth.py can find it.
try:
    # Probe only: if HfFolder still exists, no shim is needed.
    from huggingface_hub import HfFolder as _check  # noqa: F401
except ImportError:
    import huggingface_hub as _hfh

    class _HfFolder:
        """Minimal stand-in for the removed ``huggingface_hub.HfFolder``.

        Provides only the three static methods defined here; saving and
        deleting tokens are deliberately no-ops.
        """

        @staticmethod
        def get_token():
            """Return ``$HF_TOKEN`` if set and non-empty, else ``huggingface_hub.get_token()``."""
            return os.environ.get("HF_TOKEN") or _hfh.get_token()

        @staticmethod
        def save_token(token: str) -> None:  # noqa: ARG004
            """No-op; the token argument is ignored."""
            pass

        @staticmethod
        def delete_token() -> None:
            """No-op."""
            pass

    # Attach the stand-in to the module object so `from huggingface_hub import
    # HfFolder` succeeds for code imported after this point (gradio is
    # imported below).
    _hfh.HfFolder = _HfFolder
# ── HuggingFace Spaces: auto-build knowledge base on first boot ───────────────
# Location of the SQLite knowledge base; MEDIC_DATA_DIR (default "data") is
# resolved relative to PROJECT_ROOT.
_DB_PATH = PROJECT_ROOT / os.getenv("MEDIC_DATA_DIR", "data") / "amr_guard.db"
if os.environ.get("SPACE_ID") and not _DB_PATH.exists():
    # Best-effort build on a Space's first boot: the UI still starts if this
    # fails, but the failure is logged so an empty knowledge base is not
    # served silently.
    _setup_proc = subprocess.run(
        [sys.executable, str(PROJECT_ROOT / "setup_demo.py")], check=False
    )
    if _setup_proc.returncode != 0:
        logger.warning(
            "setup_demo.py exited with code %s; knowledge base at %s may be missing.",
            _setup_proc.returncode,
            _DB_PATH,
        )
import gradio as gr
import pandas as pd
# ── Gradio boolean-schema safety patch ───────────────────────────────────────
# Gradio <5.7 walks JSON Schemas and does `if "const" in schema:` without
# guarding against boolean schemas (valid in JSON Schema spec but not a dict).
# sdk_version is now >=5.25.0 (bug fixed upstream) but keep this as a guard.
def _wrap_schema_fn(fn):
    """Wrap a Gradio JSON-Schema helper so non-dict schemas short-circuit.

    JSON Schema allows bare booleans (``true``/``false``) as schemas, while the
    wrapped helpers expect a dict. Any non-dict *schema* is answered with
    ``"other"``; dict schemas are forwarded to *fn* unchanged, together with any
    extra positional/keyword arguments.
    """
    def _safe(schema, *args, **kwargs):
        if not isinstance(schema, dict):
            return "other"
        return fn(schema, *args, **kwargs)

    return _safe


try:
    import gradio.utils as _gr_utils
    _orig_get_type = getattr(_gr_utils, "get_type", None)
    if _orig_get_type:
        _gr_utils.get_type = _wrap_schema_fn(_orig_get_type)
except Exception:  # best-effort guard: never block app start-up, but leave a trace
    logging.getLogger(__name__).debug("gradio.utils schema patch skipped", exc_info=True)
try:
    import gradio.route_utils as _gr_ru
    for _fn_name in ("get_type", "_json_schema_to_python_type", "json_schema_to_python_type"):
        _fn = getattr(_gr_ru, _fn_name, None)
        if _fn:
            setattr(_gr_ru, _fn_name, _wrap_schema_fn(_fn))
except Exception:  # best-effort guard: never block app start-up, but leave a trace
    logging.getLogger(__name__).debug("gradio.route_utils schema patch skipped", exc_info=True)
from src.config import get_settings
from src.form_config import CREATININE_PROMINENT_SITES, SITE_SPECIFIC_FIELDS, SUSPECTED_SOURCE_OPTIONS
from src.loader import run_inference # noqa: F401 – triggers spaces import / ZeroGPU registration at startup
# ── Single GPU session for the full multi-agent pipeline ──────────────────────
# Each run_inference call uses lru_cache'd model weights. ZeroGPU frees GPU
# memory when a @spaces.GPU function returns, so wrapping every individual
# inference call in its own GPU session would invalidate the cached model
# between agents, causing model.generate() to hang on freed CUDA memory.
# Wrapping the *entire* pipeline in one session keeps the CUDA context alive
# for all four agents, so the model is loaded once and stays valid throughout.
def _run_pipeline_gpu(patient_data: dict, labs_raw_text):
    """Run the full multi-agent pipeline (``src.graph.run_pipeline``) in one call.

    When running on a HuggingFace Space with the ``spaces`` package available,
    this name is re-bound below to a ``spaces.GPU(duration=200)``-wrapped
    version, so all agents share a single GPU session.

    Args:
        patient_data: Patient record assembled by the UI.
        labs_raw_text: Extracted/pasted lab report text, or ``None``.

    Returns:
        Whatever ``src.graph.run_pipeline`` returns (consumed as a dict by the UI).
    """
    from src.graph import run_pipeline
    return run_pipeline(patient_data, labs_raw_text)


if os.environ.get("SPACE_ID"):
    try:
        import spaces as _spaces_ui
    except ImportError:
        # Previously a silent fallback; make it visible that no ZeroGPU
        # session will wrap the pipeline.
        logger.warning(
            "SPACE_ID is set but the `spaces` package is unavailable; "
            "running the pipeline without a ZeroGPU session."
        )
    else:
        _run_pipeline_gpu = _spaces_ui.GPU(duration=200)(_run_pipeline_gpu)
from src.tools import (
calculate_mic_trend,
get_empirical_therapy_guidance,
get_most_effective_antibiotics,
interpret_mic_value,
screen_antibiotic_safety,
search_clinical_guidelines,
)
# ── CSS ────────────────────────────────────────────────────────────────────────
# Stylesheet passed to gr.Blocks(css=...). It defines the classes used by the
# HTML snippets in this file (banner, section titles, stat cards, agent steps,
# risk badges, prescription card, disclaimer).
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
/* ── Banner ── */
.med-banner {
background: linear-gradient(135deg, #020d1f 0%, #0b2545 55%, #102f62 100%);
padding: 24px 32px; border-radius: 14px; margin-bottom: 22px;
border: 1px solid #1e4a80;
box-shadow: 0 0 48px rgba(26,74,138,0.45), inset 0 1px 0 rgba(158,196,240,0.12);
}
.med-banner h1 {
color: #ffffff; font-size: 1.95rem; font-weight: 700; margin: 0;
text-shadow: 0 0 24px rgba(96,196,255,0.45);
}
.med-banner p { color: #7eb8e8; font-size: 0.95rem; margin: 5px 0 0; }
/* ── Section titles ── */
.section-title {
font-size: 0.8rem; font-weight: 700; color: #60b4ff;
border-bottom: 1px solid #1e3f72; padding-bottom: 6px; margin: 18px 0 13px;
text-transform: uppercase; letter-spacing: 0.1em;
}
/* ── Stat cards ── */
.stat-cards {
display: grid; grid-template-columns: repeat(4, 1fr); gap: 16px; margin-bottom: 22px;
}
.stat-card {
background: linear-gradient(160deg, #0b1e3d 0%, #0e2a56 100%);
border: 1px solid #1e4a80; border-top: 3px solid #3b82f6;
border-radius: 11px; padding: 18px 20px; text-align: center;
box-shadow: 0 4px 18px rgba(0,0,0,0.35);
}
.stat-card .label {
color: #7eaadb; font-size: 0.78rem; font-weight: 600;
text-transform: uppercase; letter-spacing: 0.05em;
}
.stat-card .value { color: #60c8ff; font-size: 1.65rem; font-weight: 700; margin-top: 5px; }
.stat-card .sub { color: #a8cce8; font-size: 0.75rem; margin-top: 3px; }
/* ── Agent steps ── */
.agent-step {
background: linear-gradient(135deg, #091a36 0%, #0d2450 100%);
border: 1px solid #1e4278; border-left: 4px solid #3b82f6;
border-radius: 8px; padding: 14px 16px; margin-bottom: 10px;
}
.agent-step .num { color: #60b4ff; font-weight: 700; font-size: 0.82rem; letter-spacing: 0.04em; }
.agent-step .name { color: #dceeff; font-weight: 600; }
.agent-step .desc { color: #8ab4d8; font-size: 0.85rem; margin-top: 4px; }
/* ── Status badges — dark backgrounds, high-contrast text ── */
.badge-high {
background: #1e0707; border-left: 4px solid #dc2626; color: #fca5a5;
padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
.badge-moderate {
background: #1c1200; border-left: 4px solid #d97706; color: #fcd34d;
padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
.badge-low {
background: #021a0e; border-left: 4px solid #16a34a; color: #86efac;
padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
.badge-info {
background: #071428; border-left: 4px solid #2563eb; color: #93c5fd;
padding: 10px 14px; border-radius: 7px; margin-bottom: 6px;
}
/* ── Prescription card ── */
.rx-card {
background: linear-gradient(145deg, #081730 0%, #0c2248 100%);
border: 1px solid #1e4a80; border-radius: 12px;
padding: 24px 26px; font-size: 0.9rem; line-height: 1.75; color: #cce3ff;
box-shadow: 0 6px 28px rgba(0,0,0,0.45), 0 0 0 1px rgba(59,130,246,0.18);
}
.rx-card .rx-symbol {
font-size: 2.2rem; color: #60c8ff; font-weight: 700;
text-shadow: 0 0 14px rgba(96,200,255,0.55);
}
.rx-card .rx-drug { font-size: 1.25rem; font-weight: 700; color: #ffffff; }
.rx-card strong { color: #a8d4ff; }
.rx-card ul { color: #cce3ff; }
/* ── Badge child elements inherit color ── */
.badge-high strong, .badge-high em, .badge-high span { color: inherit; }
.badge-moderate strong,.badge-moderate em,.badge-moderate span { color: inherit; }
.badge-low strong, .badge-low em, .badge-low span { color: inherit; }
.badge-info strong, .badge-info em, .badge-info span { color: inherit; }
/* ── Disclaimer ── */
.disclaimer {
background: #150e00; border: 1px solid #78450e; border-radius: 8px;
padding: 12px 16px; font-size: 0.78rem; color: #fbbf24; margin-top: 20px;
}
"""
# Header rendered at the top of the app via gr.HTML.
# NOTE(review): presumably styled by the .med-banner rules above — confirm the
# markup in this literal is intact.
BANNER_HTML = """
⚕ AMR-Guard
Infection Lifecycle Orchestrator · Multi-Agent Clinical Decision Support
"""
# Order matters: the site-group creation loop, update_site_ui's visibility
# updates and the infection_site.change outputs all iterate this list and must
# line up positionally.
INFECTION_SITES = ["urinary", "respiratory", "bloodstream", "skin", "intra-abdominal", "CNS", "other"]
# ── HTML result builders ───────────────────────────────────────────────────────
def _parse_notes(raw):
    """Decode an agent notes field from a pipeline result into Python data.

    Notes may arrive already structured (dict/list) or as a JSON string.
    Empty values, the known "nothing to report" placeholder strings, and
    anything that is not valid JSON all map to ``None``.

    Args:
        raw: The raw notes value (``str``, ``dict``, ``list`` or ``None``).

    Returns:
        The decoded JSON value (normally a dict or list), or ``None``.
    """
    if isinstance(raw, (dict, list)):
        # Empty containers count as "no notes", matching the truthiness check below.
        return raw or None
    if not raw or raw in ("No lab data provided", "No MIC data available for trend analysis"):
        return None
    try:
        return json.loads(raw)
    except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError; TypeError for non-str input
        return None
def _build_rec_html(result: dict) -> str:
    """Render the Recommendation tab from ``result["recommendation"]``.

    Missing regimen fields are shown as "—"; the alternative, rationale and
    references sections are emitted only when present.
    """
    rec = result.get("recommendation") or {}
    if not rec:
        return 'No recommendation generated.
'
    # NOTE(review): these values come straight from the pipeline output and are
    # interpolated into HTML unescaped — consider html.escape() before rendering.
    primary = rec.get("primary_antibiotic", "—")
    dose = rec.get("dose", "—")
    route = rec.get("route", "—")
    freq = rec.get("frequency", "—")
    duration = rec.get("duration", "—")
    alt = rec.get("backup_antibiotic", "")
    rationale = rec.get("rationale", "")
    refs = rec.get("references", [])
    alt_html = f"
Alternative: {alt}" if alt else ""
    rat_html = f"
Clinical rationale
{rationale}" if rationale else ""
    ref_html = ""
    if refs:
        # NOTE(review): `items` is not referenced by the visible ref_html text —
        # confirm the references list markup is intact.
        items = "".join(f"{r}" for r in refs)
        ref_html = f"
References"
    return f"""
℞
{primary}
Dose: {dose} ·
Route: {route} ·
Frequency: {freq} ·
Duration: {duration}
{alt_html}{rat_html}{ref_html}
"""
def _build_intake_html(result: dict) -> str:
    """Render the Patient Summary tab from the Intake Historian's notes.

    Falls back to a CrCl-only line when the notes are not a dict but a
    top-level CrCl exists, and to a "not available" message otherwise.
    """
    intake = _parse_notes(result.get("intake_notes", ""))
    crcl = result.get("creatinine_clearance_ml_min")
    html = ""
    if isinstance(intake, dict):
        # The top-level CrCl takes precedence over the value inside the notes.
        # NOTE(review): float(v) raises if either source holds a non-numeric string.
        v = crcl or intake.get("creatinine_clearance_ml_min", 0)
        sev = intake.get("infection_severity", "")
        pathway = intake.get("recommended_stage", "")
        cells = ""
        if v:
            cells += f"CrCl {float(v):.1f} mL/min | "
        if sev:
            cells += f"Severity {sev.capitalize()} | "
        if pathway:
            cells += f"Pathway {pathway.capitalize()} | "
        # NOTE(review): the appended f-string does not reference `cells` as shown —
        # confirm the summary-row markup is intact.
        if cells:
            html += f""
        if intake.get("patient_summary"):
            html += f'{intake["patient_summary"]}
'
        if intake.get("renal_dose_adjustment_needed"):
            html += '⚠ Renal dose adjustment required
'
        if intake.get("identified_risk_factors"):
            items = "".join(f"{rf}" for rf in intake["identified_risk_factors"])
            html += f"
Identified risk factors"
    elif crcl:
        html = f"CrCl: {float(crcl):.1f} mL/min"
    else:
        html = 'Intake summary not available.
'
    return html
def _build_lab_html_and_df(result: dict) -> tuple[str, pd.DataFrame]:
    """Render the Lab Analysis tab.

    Returns:
        ``(html, df)``. ``html`` summarises the Vision Specialist's extraction
        (specimen, confidence, organisms) and the Trend Analyst's notes.
        ``df`` is the susceptibility table; it is empty when no
        susceptibility results were extracted.
    """
    vision = _parse_notes(result.get("vision_notes", ""))
    trend = _parse_notes(result.get("trend_notes", ""))
    html = ""
    df = pd.DataFrame()
    if vision is None:
        html += 'No lab data processed. Provide lab results to activate the targeted pathway.
'
    else:
        # Non-dict vision notes (e.g. a bare list) are treated as having no fields.
        v = vision if isinstance(vision, dict) else {}
        if v.get("specimen_type"):
            html += f"Specimen: {v['specimen_type'].capitalize()}
"
        if v.get("extraction_confidence") is not None:
            # Colour thresholds: >=85% green, >=60% amber, otherwise red.
            conf = float(v["extraction_confidence"])
            color = "#86efac" if conf >= 0.85 else "#fcd34d" if conf >= 0.6 else "#fca5a5"
            html += (f'Extraction confidence: '
                     f'{conf:.0%}
')
        orgs = v.get("identified_organisms", [])
        if orgs:
            items = "".join(
                f"{o.get('organism_name','?')}"
                + (f" — {o.get('significance','')}" if o.get("significance") else "")
                + ""
                for o in orgs
            )
            html += f"
Identified organisms"
        sus = v.get("susceptibility_results", [])
        if sus:
            # One table row per organism/antibiotic result; MIC is stringified so
            # values like ">32" and 0.25 share a column.
            rows = [
                {
                    "Organism": e.get("organism", ""),
                    "Antibiotic": e.get("antibiotic", ""),
                    "MIC (mg/L)": str(e.get("mic_value", "")),
                    "Result": e.get("interpretation", ""),
                }
                for e in sus
            ]
            df = pd.DataFrame(rows)
    if trend:
        html += "
MIC Trend Analysis
"
        items = trend if isinstance(trend, list) else [trend]
        for item in items:
            if not isinstance(item, dict):
                html += f"{item}
"
                continue
            # NOTE(review): any risk other than HIGH/MODERATE — including UNKNOWN —
            # gets the green "✓" badge; .upper() also fails if risk_level is None.
            risk = item.get("risk_level", "UNKNOWN").upper()
            css = {"HIGH": "badge-high", "MODERATE": "badge-moderate"}.get(risk, "badge-low")
            icon = {"HIGH": "🚨", "MODERATE": "⚠"}.get(risk, "✓")
            org = item.get("organism", "")
            ab = item.get("antibiotic", "")
            label = f"{org} / {ab} — " if (org or ab) else ""
            html += (f'{icon} {label}{risk}
'
                     f'{item.get("recommendation","")}
')
    return html, df
def _build_safety_html(result: dict) -> str:
    """Render the Safety tab: one entry per safety warning, then one per pipeline error.

    When there are no warnings, an "all clear" line is shown instead.
    """
    warnings = result.get("safety_warnings", [])
    errors = result.get("errors", [])
    html = "".join(f'⚠ {w}
' for w in warnings)
    # NOTE(review): the "no concerns" line is shown even when errors follow,
    # i.e. when the safety screen may not have run — confirm this is intended.
    if not warnings:
        html = '✓ No safety concerns identified.
'
    html += "".join(f'Error: {e}
' for e in errors)
    return html
def _demo_result(patient_data: dict, has_labs: bool) -> dict:
    """Build a canned result shaped like the real pipeline's output.

    The clinical content is fixed placeholder data: CrCl 58.3 mL/min, a
    ciprofloxacin regimen and, when *has_labs* is true, an E. coli
    susceptibility panel plus a LOW-risk trend note. From *patient_data*, only
    age, sex, suspected source and comorbidities are echoed into the intake
    notes.
    """
    pathway = "targeted" if has_labs else "empirical"
    placeholder_crcl = 58.3

    summary = (
        f"{patient_data.get('age_years')}-year-old {patient_data.get('sex')} "
        f"· {patient_data.get('suspected_source', 'infection')}"
    )
    intake_notes = {
        "patient_summary": summary,
        "creatinine_clearance_ml_min": placeholder_crcl,
        "renal_dose_adjustment_needed": True,
        "identified_risk_factors": patient_data.get("comorbidities", []),
        "infection_severity": "moderate",
        "recommended_stage": pathway,
    }
    regimen = {
        "primary_antibiotic": "Ciprofloxacin",
        "dose": "500 mg",
        "route": "Oral",
        "frequency": "Every 12 hours",
        "duration": "7 days",
        "backup_antibiotic": "Nitrofurantoin 100 mg MR BD × 5 days",
        "rationale": (
            "Community-acquired UTI with moderate renal impairment (CrCl 58 mL/min). "
            "Ciprofloxacin provides broad Gram-negative coverage. "
            "No dose adjustment required above CrCl 30 mL/min."
        ),
        "references": ["IDSA UTI Guidelines 2024", "EUCAST Breakpoint Tables v16.0"],
    }
    demo = {
        "stage": pathway,
        "creatinine_clearance_ml_min": placeholder_crcl,
        "intake_notes": json.dumps(intake_notes),
        "recommendation": regimen,
        "safety_warnings": [],
        "errors": [],
    }
    if not has_labs:
        return demo

    panel = [
        {"organism": "E. coli", "antibiotic": "Ciprofloxacin", "mic_value": 0.25, "interpretation": "S"},
        {"organism": "E. coli", "antibiotic": "Nitrofurantoin", "mic_value": 16, "interpretation": "S"},
        {"organism": "E. coli", "antibiotic": "Ampicillin", "mic_value": ">32", "interpretation": "R"},
    ]
    vision_notes = {
        "specimen_type": "urine",
        "identified_organisms": [{"organism_name": "Escherichia coli", "significance": "pathogen"}],
        "susceptibility_results": panel,
        "extraction_confidence": 0.95,
    }
    trend_notes = [
        {
            "organism": "E. coli",
            "antibiotic": "Ciprofloxacin",
            "risk_level": "LOW",
            "recommendation": "No MIC creep detected.",
        }
    ]
    demo["vision_notes"] = json.dumps(vision_notes)
    demo["trend_notes"] = json.dumps(trend_notes)
    return demo
# ── Site change / lab method handlers ─────────────────────────────────────────
def update_site_ui(site):
    """Refresh the Patient Analysis form after the infection site changes.

    Returns one visibility update per entry in INFECTION_SITES (only the group
    for *site* is shown). It then returns updates for the suspected-source
    dropdown, the prominent creatinine input, the renal-impairment checkbox and
    the optional creatinine input, in the order the ``infection_site.change``
    outputs list expects.
    """
    group_updates = [gr.update(visible=(s == site)) for s in INFECTION_SITES]
    source_choices = SUSPECTED_SOURCE_OPTIONS.get(site, []) or ["Other"]
    prominent = site in CREATININE_PROMINENT_SITES
    return (
        *group_updates,
        gr.update(choices=source_choices, value=source_choices[0]),
        gr.update(visible=prominent),  # creatinine_main
        # Also untick the renal flag. run_pipeline_ui reads creatinine_optional
        # whenever the flag is set, so a stale tick would submit a value from an
        # input that is being hidden here.
        gr.update(visible=not prominent, value=False),  # renal_flag
        gr.update(visible=False),  # creatinine_optional (reset hidden)
    )
def toggle_optional_creatinine(flag):
    """Show the optional creatinine input only while the renal-impairment box is ticked."""
    show_input = True if flag else False
    return gr.update(visible=show_input)
def toggle_lab_inputs(method):
    """Return visibility updates for (file upload, paste box) matching the chosen input *method*."""
    show_upload = method == "Upload file (PDF / image)"
    show_paste = method == "Paste lab text"
    return gr.update(visible=show_upload), gr.update(visible=show_paste)
# ── Pipeline function ──────────────────────────────────────────────────────────
# Site-specific field order (matches component creation order in the Blocks):
# urinary : sf0 sf1 sf2 (3 fields)
# respiratory : sf3 sf4 sf5 sf6 (4 fields)
# bloodstream : sf7 sf8 sf9 sf10 sf11 sf12 sf13 (7 fields)
# skin : sf14 sf15 sf16 sf17 (4 fields)
# intra-abdom : sf18 sf19 sf20 sf21 (4 fields)
# CNS : sf22 sf23 sf24 sf25 (4 fields)
def _vital_text(value) -> str:
    """Stringify a text/number/dropdown widget value; falsy values become ``""``."""
    return str(value or "")


def _vital_flag(flag) -> str:
    """Render a checkbox value as ``"Yes"`` or ``"No"``."""
    return "Yes" if flag else "No"


def _vital_list(values) -> str:
    """Join a multiselect value with ``", "``; an empty selection becomes ``""``."""
    return ", ".join(values) if values else ""


def _collect_site_vitals(infection_site: str, sf: tuple) -> dict:
    """Map the flat site-specific widget values *sf* (sf0..sf25) to a vitals dict.

    Only the slice of *sf* belonging to *infection_site* is read. Sites without
    extra fields yield an empty dict.
    """
    if infection_site == "urinary":
        return {
            "catheter_status": _vital_text(sf[0]),
            "urinary_symptoms": _vital_list(sf[1]),
            "urine_appearance": _vital_text(sf[2]),
        }
    if infection_site == "respiratory":
        return {
            "o2_saturation": _vital_text(sf[3]),
            "ventilation_status": _vital_text(sf[4]),
            "cough_type": _vital_text(sf[5]),
            "sputum_character": _vital_text(sf[6]),
        }
    if infection_site == "bloodstream":
        return {
            "central_line_present": _vital_flag(sf[7]),
            "temperature_c": _vital_text(sf[8]),
            "heart_rate_bpm": _vital_text(sf[9]),
            "respiratory_rate": _vital_text(sf[10]),
            "wbc_count": _vital_text(sf[11]),
            "lactate_mmol": _vital_text(sf[12]),
            "shock_status": _vital_text(sf[13]),
        }
    if infection_site == "skin":
        return {
            "wound_type": _vital_text(sf[14]),
            "cellulitis_extent": _vital_text(sf[15]),
            "abscess_present": _vital_flag(sf[16]),
            "foreign_body": _vital_flag(sf[17]),
        }
    if infection_site == "intra-abdominal":
        return {
            "abdominal_pain_location": _vital_text(sf[18]),
            "peritonitis_signs": _vital_list(sf[19]),
            "perforation_suspected": _vital_flag(sf[20]),
            "ascites": _vital_flag(sf[21]),
        }
    if infection_site == "CNS":
        return {
            "csf_obtained": _vital_flag(sf[22]),
            "neuro_symptoms": _vital_list(sf[23]),
            "recent_neurosurgery": _vital_flag(sf[24]),
            "gcs_score": _vital_text(sf[25]),
        }
    return {}


def _extract_pdf_text(file_bytes: bytes) -> str:
    """Return the PDF's text layer, or ``""`` if it has none or pypdf cannot read it."""
    try:
        import pypdf
        reader = pypdf.PdfReader(BytesIO(file_bytes))
        return "\n".join(page.extract_text() or "" for page in reader.pages).strip()
    except Exception:  # deliberate fallback: the raw bytes are sent on instead
        logger.warning(
            "PDF text extraction failed; passing the file on as labs_image_bytes.",
            exc_info=True,
        )
        return ""


def _read_lab_input(lab_method, lab_file, lab_paste):
    """Return ``(labs_raw_text, labs_image_bytes)`` for the selected lab input method.

    PDFs with a text layer and pasted reports become raw text. Images, and
    PDFs without extractable text, are passed through as bytes. Returns
    ``(None, None)`` when no lab input was supplied.
    """
    if lab_method == "Upload file (PDF / image)" and lab_file is not None:
        file_path = Path(lab_file if isinstance(lab_file, str) else lab_file.name)
        file_bytes = file_path.read_bytes()
        if file_path.suffix.lower() == ".pdf":
            extracted = _extract_pdf_text(file_bytes)
            if extracted:
                return extracted, None
        return None, file_bytes
    if lab_method == "Paste lab text" and lab_paste:
        return lab_paste.strip() or None, None
    return None, None


def _pipeline_error_result(has_labs: bool, error: Exception) -> dict:
    """Build the minimal result rendered when the pipeline raises.

    It deliberately carries no clinical values. Placeholder CrCl,
    susceptibility or trend data would otherwise be presented as findings
    derived from this patient.
    """
    return {
        "stage": "targeted" if has_labs else "empirical",
        "recommendation": {},
        "safety_warnings": [],
        "errors": [f"Pipeline error: {error}"],
    }


def run_pipeline_ui(
    age, weight, height, sex,
    creatinine_main, renal_flag, creatinine_optional,
    infection_site, suspected_source,
    # urinary
    sf0, sf1, sf2,
    # respiratory
    sf3, sf4, sf5, sf6,
    # bloodstream
    sf7, sf8, sf9, sf10, sf11, sf12, sf13,
    # skin
    sf14, sf15, sf16, sf17,
    # intra-abdominal
    sf18, sf19, sf20, sf21,
    # CNS
    sf22, sf23, sf24, sf25,
    # medical history
    medications, allergies, comorbidities, risk_factors,
    # lab
    lab_method, lab_file, lab_paste,
    progress=gr.Progress(),
):
    """Gradio handler for "Run Agent Pipeline".

    Assembles the patient record from the form, runs the multi-agent pipeline
    in a single call, and renders the results.

    Returns:
        ``(rec_html, intake_html, lab_html, lab_df, safety_html, results_group_update)``,
        matching the run button's ``outputs``.
    """
    # Creatinine: the prominent input for sites that need it; otherwise the
    # optional input, and only while the renal-impairment box is ticked.
    if infection_site in CREATININE_PROMINENT_SITES:
        creatinine = creatinine_main
    else:
        creatinine = creatinine_optional if renal_flag else None

    site_vitals = _collect_site_vitals(
        infection_site,
        (sf0, sf1, sf2, sf3, sf4, sf5, sf6, sf7, sf8, sf9, sf10, sf11, sf12,
         sf13, sf14, sf15, sf16, sf17, sf18, sf19, sf20, sf21, sf22, sf23,
         sf24, sf25),
    )
    labs_raw_text, labs_image_bytes = _read_lab_input(lab_method, lab_file, lab_paste)

    patient_data = {
        # `is None` rather than `or`: an age of 0 (a neonate) is a valid entry
        # and must not be silently replaced by 65.
        "age_years": float(age) if age is not None else 65.0,
        "weight_kg": float(weight or 70),
        "height_cm": float(height or 170),
        "sex": sex or "male",
        "serum_creatinine_mg_dl": float(creatinine) if creatinine else None,
        "infection_site": infection_site,
        "suspected_source": suspected_source or f"{infection_site} infection",
        "medications": [m.strip() for m in (medications or "").split("\n") if m.strip()],
        "allergies": [a.strip() for a in (allergies or "").split("\n") if a.strip()],
        "comorbidities": list(comorbidities or []) + list(risk_factors or []),
        "vitals": site_vitals,
        "labs_image_bytes": labs_image_bytes,
    }
    has_labs = bool(labs_raw_text or labs_image_bytes)
    stages = (
        ["Intake Historian", "Vision Specialist", "Trend Analyst", "Clinical Pharmacologist"]
        if has_labs else ["Intake Historian", "Clinical Pharmacologist"]
    )
    # The whole pipeline is one call, so per-agent progress is not observable
    # here. Announce the planned sequence once instead of stepping through
    # stage names before any of them has run.
    progress(0.05, desc=f"Running: {' → '.join(stages)}…")
    try:
        result = _run_pipeline_gpu(patient_data, labs_raw_text)
    except Exception as e:
        logger.exception("Pipeline failed.")
        result = _pipeline_error_result(has_labs, e)
    progress(1.0, desc="Complete")
    rec_html = _build_rec_html(result)
    intake_html = _build_intake_html(result)
    lab_html, lab_df = _build_lab_html_and_df(result)
    safety_html = _build_safety_html(result)
    return rec_html, intake_html, lab_html, lab_df, safety_html, gr.update(visible=True)
# ── Clinical Tools handlers ────────────────────────────────────────────────────
def switch_tool(tool):
    """Show only the Clinical Tools panel named *tool*; updates are returned in panel order."""
    panels = ("Empirical Advisor", "MIC Interpreter", "MIC Trend Analysis", "Drug Safety Check")
    updates = []
    for panel in panels:
        updates.append(gr.update(visible=(panel == tool)))
    return updates
def run_empirical(infection_type, pathogen, risk):
    """Empirical Advisor tool: guideline excerpts plus optional local resistance data.

    Shows the top three guideline excerpts for *infection_type* and the
    selected *risk* factors. If a *pathogen* is given, it also lists up to six
    antibiotics returned by get_most_effective_antibiotics with
    min_susceptibility=70.
    """
    guidance = get_empirical_therapy_guidance(infection_type, list(risk or []))
    html = ""
    for i, rec in enumerate(guidance.get("recommendations", [])[:3], 1):
        score = rec.get("relevance_score", 0)
        content = rec.get("content", "")
        source = rec.get("source", "IDSA Guidelines 2024")
        html += (f'Excerpt {i}'
                 f' (relevance {score:.2f})
{content}
Source: {source}
')
    if pathogen:
        effective = get_most_effective_antibiotics(pathogen, min_susceptibility=70)
        if effective:
            items = "".join(
                f"{ab.get('antibiotic')}"
                f" — {ab.get('avg_susceptibility', 0):.1f}% susceptible"
                for ab in effective[:6]
            )
            html += f"
Resistance data — {pathogen}"
        else:
            html += 'No resistance data available for this pathogen.
'
    return html or 'No results found.
'
def run_mic_interpret(pathogen, antibiotic, mic):
    """MIC Interpreter tool: classify a single MIC value as S / R / I via interpret_mic_value."""
    if not pathogen or not antibiotic:
        return 'Enter pathogen and antibiotic.
'
    # An empty/zero MIC input falls back to 1.0 mg/L.
    result = interpret_mic_value(pathogen, antibiotic, float(mic or 1.0))
    interp = result.get("interpretation", "UNKNOWN")
    msg = result.get("message", "")
    if interp == "SUSCEPTIBLE":
        return f'Susceptible (S) — {msg}
'
    if interp == "RESISTANT":
        return f'Resistant (R) — {msg}
'
    # NOTE(review): every other value, including the "UNKNOWN" default, is shown
    # as Intermediate — consider a separate branch for unknown interpretations.
    return f'Intermediate (I) — {msg}
'
def update_mic_inputs(n):
    """Show the first *n* of the six MIC reading inputs and hide the rest."""
    shown = int(n)
    updates = []
    for slot in range(6):
        updates.append(gr.update(visible=(slot < shown)))
    return updates
def run_mic_trend(n, m0, m1, m2, m3, m4, m5):
    """MIC Trend Analysis tool: summarise the first *n* MIC readings via calculate_mic_trend.

    Readings are labelled T0..T(n-1) in input order; an empty/zero reading
    falls back to 1.0 mg/L.
    """
    vals = [m0, m1, m2, m3, m4, m5][: int(n)]
    mic_values = [{"date": f"T{i}", "mic_value": float(v or 1.0)} for i, v in enumerate(vals)]
    result = calculate_mic_trend(mic_values)
    risk = result.get("risk_level", "UNKNOWN")
    alert = result.get("alert", "")
    # NOTE(review): `css` is not referenced by the template as shown — confirm the
    # badge markup is intact. Any risk other than HIGH/MODERATE gets "✓".
    css = {"HIGH": "badge-high", "MODERATE": "badge-moderate"}.get(risk, "badge-low")
    icon = {"HIGH": "🚨", "MODERATE": "⚠"}.get(risk, "✓")
    base = result.get("baseline_mic", "—")
    curr = result.get("current_mic", "—")
    ratio = result.get("ratio", "—")
    return f"""
{icon} {risk} RISK — {alert}
Baseline MIC {base} mg/L |
Current MIC {curr} mg/L |
Fold change {ratio}× |
"""
def run_drug_safety(ab, meds, allergies_txt):
    """Drug Safety Check tool: screen antibiotic *ab* against medications and allergies.

    *meds* and *allergies_txt* are newline-separated; blank lines are ignored.
    Shows an overall verdict from ``safe_to_use`` followed by one line per alert.
    """
    if not ab:
        return 'Enter an antibiotic to check.
'
    med_list = [m.strip() for m in (meds or "").split("\n") if m.strip()]
    allergy_list = [a.strip() for a in (allergies_txt or "").split("\n") if a.strip()]
    result = screen_antibiotic_safety(ab, med_list, allergy_list)
    if result.get("safe_to_use"):
        html = '✓ No critical safety concerns identified.
'
    else:
        html = '⚠ Safety concerns identified — review required.
'
    html += "".join(
        f'⚠ {a.get("message","")}
'
        for a in result.get("alerts", [])
    )
    return html
def run_guidelines_search(query, pathogen_filter):
    """Guidelines tab: return the top five guideline passages for *query*.

    A *pathogen_filter* of "All" disables filtering.
    """
    if not query:
        return 'Enter a search query.
'
    filt = None if pathogen_filter == "All" else pathogen_filter
    results = search_clinical_guidelines(query, pathogen_filter=filt, n_results=5)
    if not results:
        return ('No results found. Try broader search terms or '
                'check that the knowledge base has been initialised.
')
    html = ""
    for i, r in enumerate(results, 1):
        score = r.get("relevance_score", 0)
        content = r.get("content", "")
        source = r.get("source", "")
        src_str = f"
Source: {source}" if source else ""
        html += (f'Result {i}'
                 f' · relevance {score:.2f}
{content}{src_str}
')
    return html
# ── Widget factory for site-specific fields ────────────────────────────────────
def _make_site_widget(field):
    """Create the Gradio input component for one site-specific form *field*.

    *field* is a dict with ``"type"`` and ``"label"``. Selectbox and
    multiselect fields also carry ``"options"``. Number inputs may carry
    ``"default"``, ``"min"`` and ``"max"``, and checkboxes may carry
    ``"default"``. Unknown types fall back to a plain textbox.
    """
    ftype = field["type"]
    label = field["label"]
    if ftype == "selectbox":
        options = field["options"]
        # Guard an empty option list instead of raising IndexError on options[0].
        return gr.Dropdown(choices=options, value=options[0] if options else None, label=label)
    if ftype == "multiselect":
        return gr.CheckboxGroup(choices=field["options"], label=label)
    if ftype == "number_input":
        # A default of 0 is sent to the pipeline as "" (see run_pipeline_ui), i.e. "not given".
        # NOTE(review): if a field sets "min" > 0 without a "default", the initial 0
        # lies outside the range — confirm the field config.
        return gr.Number(
            value=field.get("default", 0), label=label,
            minimum=field.get("min"), maximum=field.get("max"),
        )
    if ftype == "checkbox":
        return gr.Checkbox(value=field.get("default", False), label=label)
    return gr.Textbox(label=label)
# ── Models table (build-time) ─────────────────────────────────────────────────
_s = get_settings()
# Model ids shown in the Overview tab; each falls back to a default id when
# the corresponding setting is empty.
_overview_reasoning_model = _s.medgemma_4b_model or "google/medgemma-4b-it"
_overview_trend_model = _s.medgemma_27b_model or "google/medgemma-27b-text-it"
_overview_safety_model = _s.txgemma_9b_model or "google/txgemma-9b-predict"
# Markdown table rendered once at import time.
OVERVIEW_MODELS_MD = f"""
| Agent | Role | Model |
|---|---|---|
| 1, 2, 4 | Clinical reasoning | `{_overview_reasoning_model}` |
| 3 | Trend analysis | `{_overview_trend_model}` |
| 4 (safety) | Pharmacology check | `{_overview_safety_model}` |
| — | Semantic retrieval | `{_s.embedding_model_name}` |
| — | Inference backend | HuggingFace Transformers · {_s.quantization} quant |
"""
# ── Gradio Blocks ─────────────────────────────────────────────────────────────
# Top-level UI: four tabs (Overview, Patient Analysis, Clinical Tools,
# Guidelines). All components and event wiring are created at import time;
# `demo` is launched by the __main__ guard at the bottom of the file.
with gr.Blocks(theme=gr.themes.Soft(), css=CSS, title="AMR-Guard") as demo:
    gr.HTML(BANNER_HTML)
    with gr.Tabs():
        # ── Tab 1: Overview ────────────────────────────────────────────────────
        with gr.Tab("Overview"):
            gr.HTML("""
System Overview
WHO AWaRe
264
antibiotics classified
EUCAST
v16.0
breakpoint tables
IDSA
2024
treatment guidelines
DDInter
191K+
drug interactions
Agent Pipeline
""")
            with gr.Row():
                with gr.Column():
                    gr.HTML("""
Stage 1 — Empirical (no lab results yet)
Agent 01
Intake Historian
Parses patient data, calculates CrCl, identifies MDR risk factors
Agent 04
Clinical Pharmacologist
Empirical antibiotic selection · WHO AWaRe · safety screening
""")
                with gr.Column():
                    gr.HTML("""
Stage 2 — Targeted (culture / sensitivity available)
Agent 01
Intake Historian
Same as Stage 1
Agent 02
Vision Specialist
Extracts structured data from lab reports (any language / format)
Agent 03
Trend Analyst
Detects MIC creep · calculates resistance velocity
Agent 04
Clinical Pharmacologist
Targeted recommendation informed by susceptibility data
""")
            gr.HTML('AI Models (Local)
')
            gr.Markdown(OVERVIEW_MODELS_MD)
            gr.HTML(
                '⚠ Research demo only. '
                "Not validated for clinical use. All recommendations must be reviewed "
                "by a licensed clinician before any patient-care decision.
"
            )
        # ── Tab 2: Patient Analysis ────────────────────────────────────────────
        with gr.Tab("Patient Analysis"):
            gr.HTML('Patient Analysis Pipeline
')
            # Demographics row
            with gr.Row():
                with gr.Column(scale=1):
                    age = gr.Number(value=65, label="Age (years)", minimum=0, maximum=120, precision=0)
                    weight = gr.Number(value=70.0, label="Weight (kg)", minimum=1.0, maximum=300.0)
                    height = gr.Number(value=170.0,label="Height (cm)", minimum=50.0,maximum=250.0)
                with gr.Column(scale=1):
                    sex = gr.Dropdown(choices=["male", "female"], value="male", label="Biological sex")
                    # Two creatinine inputs: creatinine_main is shown for sites in
                    # CREATININE_PROMINENT_SITES; creatinine_optional appears only
                    # behind renal_flag (see update_site_ui / toggle_optional_creatinine).
                    creatinine_main = gr.Number(value=1.2, label="Serum Creatinine (mg/dL)",
                                                minimum=0.1, maximum=20.0, visible=True)
                    renal_flag = gr.Checkbox(label="Known renal impairment / CKD?", visible=False)
                    creatinine_optional = gr.Number(value=1.2, label="Serum Creatinine (mg/dL)",
                                                    minimum=0.1, maximum=20.0, visible=False)
                with gr.Column(scale=1):
                    infection_site = gr.Dropdown(choices=INFECTION_SITES, value="urinary",
                                                 label="Primary infection site")
                    _init_src = SUSPECTED_SOURCE_OPTIONS.get("urinary", [])
                    suspected_source = gr.Dropdown(choices=_init_src,
                                                   value=_init_src[0] if _init_src else None,
                                                   label="Suspected source")
            # Site-specific field groups (pre-rendered, one per site)
            site_groups: dict = {}
            # Component lists per site (in field declaration order)
            u_comps: list = []   # 3 components
            r_comps: list = []   # 4 components
            b_comps: list = []   # 7 components
            sk_comps: list = []  # 4 components
            ia_comps: list = []  # 4 components
            cn_comps: list = []  # 4 components
            for site in INFECTION_SITES:
                fields = SITE_SPECIFIC_FIELDS.get(site, [])
                with gr.Group(visible=(site == "urinary")) as grp:
                    if fields:
                        gr.HTML(f'{site.title()} — Assessment
')
                        with gr.Row():
                            for field in fields:
                                comp = _make_site_widget(field)
                                if site == "urinary":
                                    u_comps.append(comp)
                                elif site == "respiratory":
                                    r_comps.append(comp)
                                elif site == "bloodstream":
                                    b_comps.append(comp)
                                elif site == "skin":
                                    sk_comps.append(comp)
                                elif site == "intra-abdominal":
                                    ia_comps.append(comp)
                                elif site == "CNS":
                                    cn_comps.append(comp)
                site_groups[site] = grp
            # Flatten all site components in fixed order for fn inputs.
            # This order must match run_pipeline_ui's sf0..sf25 parameters; the
            # per-site counts noted above are assumed to match SITE_SPECIFIC_FIELDS.
            all_site_inputs = u_comps + r_comps + b_comps + sk_comps + ia_comps + cn_comps
            # Medical history
            gr.HTML('Medical History
')
            with gr.Row():
                with gr.Column():
                    medications = gr.Textbox(
                        label="Current medications (one per line)",
                        placeholder="Metformin\nLisinopril", lines=4,
                    )
                    allergies = gr.Textbox(
                        label="Drug allergies (one per line)",
                        placeholder="Penicillin\nSulfa", lines=3,
                    )
                with gr.Column():
                    comorbidities = gr.CheckboxGroup(
                        choices=["Diabetes", "CKD", "Heart Failure", "COPD",
                                 "Immunocompromised", "Recent Surgery", "Malignancy", "Liver Disease"],
                        label="Comorbidities",
                    )
                    risk_factors = gr.CheckboxGroup(
                        choices=["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated",
                                 "Recent hospitalisation", "Nursing home", "Prior MDR infection"],
                        label="MDR risk factors",
                    )
            # Lab input
            gr.HTML('Lab / Culture Results '
                    '(optional — triggers targeted pathway)
')
            # The radio labels below are compared verbatim in toggle_lab_inputs and
            # run_pipeline_ui; keep them in sync.
            lab_method = gr.Radio(
                choices=["None — empirical pathway only", "Upload file (PDF / image)", "Paste lab text"],
                value="None — empirical pathway only",
                label="Input method",
            )
            lab_file = gr.File(
                label="Lab report",
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"],
                visible=False,
            )
            lab_paste = gr.Textbox(
                label="Lab report text",
                placeholder=(
                    "Organism: Escherichia coli\n"
                    "Ciprofloxacin: S MIC 0.25\n"
                    "Nitrofurantoin: S MIC 16\n"
                    "Ampicillin: R MIC >32"
                ),
                lines=5, visible=False,
            )
            run_btn = gr.Button("Run Agent Pipeline", variant="primary")
            # Results (hidden until pipeline completes)
            with gr.Group(visible=False) as results_group:
                gr.HTML('Results
')
                with gr.Tabs():
                    with gr.Tab("Recommendation"):
                        rec_out = gr.HTML()
                    with gr.Tab("Patient Summary"):
                        intake_out = gr.HTML()
                    with gr.Tab("Lab Analysis"):
                        lab_html_out = gr.HTML()
                        lab_df_out = gr.DataFrame(label="Susceptibility Table", wrap=True)
                    with gr.Tab("Safety"):
                        safety_out = gr.HTML()
            # ── Wiring ──
            # Output order must match update_site_ui's return tuple.
            infection_site.change(
                fn=update_site_ui,
                inputs=[infection_site],
                outputs=[
                    *[site_groups[s] for s in INFECTION_SITES],
                    suspected_source,
                    creatinine_main,
                    renal_flag,
                    creatinine_optional,
                ],
            )
            renal_flag.change(
                fn=toggle_optional_creatinine,
                inputs=[renal_flag],
                outputs=[creatinine_optional],
            )
            lab_method.change(
                fn=toggle_lab_inputs,
                inputs=[lab_method],
                outputs=[lab_file, lab_paste],
            )
            _loading_html = '⏳ Pipeline running — please wait…
'
            # Two-step click: first (unqueued) fill every result pane with a
            # loading placeholder and reveal the group, then run the pipeline.
            run_btn.click(
                fn=lambda: (
                    _loading_html, _loading_html, _loading_html,
                    pd.DataFrame(), _loading_html,
                    gr.update(visible=True),
                ),
                inputs=[],
                outputs=[rec_out, intake_out, lab_html_out, lab_df_out, safety_out, results_group],
                queue=False,
            ).then(
                fn=run_pipeline_ui,
                inputs=[
                    age, weight, height, sex,
                    creatinine_main, renal_flag, creatinine_optional,
                    infection_site, suspected_source,
                    *all_site_inputs,
                    medications, allergies, comorbidities, risk_factors,
                    lab_method, lab_file, lab_paste,
                ],
                outputs=[rec_out, intake_out, lab_html_out, lab_df_out, safety_out, results_group],
            )
        # ── Tab 3: Clinical Tools ──────────────────────────────────────────────
        with gr.Tab("Clinical Tools"):
            gr.HTML('Clinical Tools
')
            # Choice order must match the panel order in switch_tool and in the
            # tool_sel.change outputs below.
            tool_sel = gr.Dropdown(
                choices=["Empirical Advisor", "MIC Interpreter", "MIC Trend Analysis", "Drug Safety Check"],
                value="Empirical Advisor",
                label="Select tool",
            )
            # Empirical Advisor
            with gr.Group(visible=True) as grp_ea:
                with gr.Row():
                    with gr.Column(scale=3):
                        ea_infection = gr.Dropdown(
                            choices=["Urinary Tract Infection", "Pneumonia", "Sepsis",
                                     "Skin / Soft Tissue", "Intra-abdominal", "Meningitis"],
                            value="Urinary Tract Infection", label="Infection type",
                        )
                        ea_pathogen = gr.Textbox(
                            label="Suspected pathogen (optional)",
                            placeholder="e.g., Klebsiella pneumoniae",
                        )
                        ea_risk = gr.CheckboxGroup(
                            choices=["Prior MRSA", "Recent antibiotics (<90 d)", "Healthcare-associated",
                                     "Immunocompromised", "Renal impairment", "Prior MDR"],
                            label="Risk factors",
                        )
                    with gr.Column(scale=1):
                        gr.HTML("""
WHO AWaRe
● Access — first-line
● Watch — second-line
● Reserve — last resort
""")
                ea_btn = gr.Button("Get recommendation", variant="primary")
                ea_out = gr.HTML()
            # MIC Interpreter
            with gr.Group(visible=False) as grp_mi:
                with gr.Row():
                    with gr.Column():
                        mi_pathogen = gr.Textbox(label="Pathogen", placeholder="e.g., Escherichia coli")
                        mi_antibiotic= gr.Textbox(label="Antibiotic", placeholder="e.g., Ciprofloxacin")
                        mi_mic = gr.Number(value=1.0, label="MIC value (mg/L)", minimum=0.001, maximum=1024.0)
                    with gr.Column():
                        gr.HTML("""
Interpretation guide
S Susceptible — antibiotic is effective
I Intermediate — effective at higher doses
R Resistant — do not use
""")
                mi_btn = gr.Button("Interpret", variant="primary")
                mi_out = gr.HTML()
            # MIC Trend Analysis
            with gr.Group(visible=False) as grp_mt:
                mt_n = gr.Slider(minimum=2, maximum=6, value=3, step=1,
                                 label="Number of historical readings")
                with gr.Row():
                    # Six inputs pre-created (default values 1, 2, 4, ... mg/L);
                    # the first three start visible to match the slider default.
                    mt_m = [
                        gr.Number(value=float(2 ** i), label=f"MIC {i+1} (mg/L)",
                                  minimum=0.001, maximum=256.0, visible=(i < 3))
                        for i in range(6)
                    ]
                mt_btn = gr.Button("Analyse trend", variant="primary")
                mt_out = gr.HTML()
                mt_n.change(fn=update_mic_inputs, inputs=[mt_n], outputs=mt_m)
            # Drug Safety Check
            with gr.Group(visible=False) as grp_ds:
                with gr.Row():
                    with gr.Column():
                        ds_ab = gr.Textbox(label="Antibiotic to check",
                                           placeholder="e.g., Ciprofloxacin")
                        ds_meds = gr.Textbox(label="Concurrent medications",
                                             placeholder="Warfarin\nMetformin\nAmlodipine", lines=4)
                    with gr.Column():
                        ds_allergies = gr.Textbox(label="Known allergies",
                                                  placeholder="Penicillin\nSulfa", lines=3)
                ds_btn = gr.Button("Check safety", variant="primary")
                ds_out = gr.HTML()
            tool_sel.change(
                fn=switch_tool, inputs=[tool_sel],
                outputs=[grp_ea, grp_mi, grp_mt, grp_ds],
            )
            ea_btn.click(fn=run_empirical, inputs=[ea_infection, ea_pathogen, ea_risk], outputs=[ea_out])
            mi_btn.click(fn=run_mic_interpret, inputs=[mi_pathogen, mi_antibiotic, mi_mic], outputs=[mi_out])
            mt_btn.click(fn=run_mic_trend, inputs=[mt_n, *mt_m], outputs=[mt_out])
            ds_btn.click(fn=run_drug_safety, inputs=[ds_ab, ds_meds, ds_allergies], outputs=[ds_out])
        # ── Tab 4: Guidelines ──────────────────────────────────────────────────
        with gr.Tab("Guidelines"):
            gr.HTML('Clinical Guidelines Search
')
            with gr.Row():
                gl_query = gr.Textbox(
                    label="Search query",
                    placeholder="e.g., ESBL E. coli UTI treatment carbapenems",
                    scale=3,
                )
                gl_filter = gr.Dropdown(
                    choices=["All", "ESBL-E", "CRE", "CRAB", "DTR-PA"],
                    value="All", label="Filter by pathogen", scale=1,
                )
            gl_btn = gr.Button("Search", variant="primary")
            gl_out = gr.HTML()
            gr.HTML(
                'Sources: IDSA Treatment Guidelines 2024 · '
                "EUCAST Breakpoint Tables v16.0 · WHO EML · DDInter drug interaction database.
"
            )
            gl_btn.click(fn=run_guidelines_search, inputs=[gl_query, gl_filter], outputs=[gl_out])
def _launch() -> None:
    """Start the Gradio server for the ``demo`` app when this file is run as a script."""
    demo.launch()


if __name__ == "__main__":
    _launch()