# echo-prime-demo / main.py
# Last commit: "Update main.py" (b8603ef, verified) by amn23
"""
EchoPrime FastAPI Backend
=========================
Bridges the EchoPrime model to the React frontend.
Model assets are downloaded automatically from HuggingFace on first run.
Routes
------
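POST /api/upload — Accept DICOM files, start async analysis job
GET /api/status/{job_id} — Poll job status + progress
GET /api/study/{study_id} — Fetch full study results (views + report + metrics)
GET /api/view/{study_id}/{frame_index}/frame — Serve labelled first-frame image
DELETE /api/study/{study_id} — Delete a study from the job store
POST /api/gradcam/{study_id}/{view_idx}/{section} — Compute GradCAM overlay frames for one view and report section
GET /api/gradcam/sections/{study_id} — List report sections available for GradCAM
GET /api/health — Liveness check and model-ready flag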
Run
---
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
"""
import asyncio
import base64
import io
import json
import math
import os
import shutil
import sys
import tempfile
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional
import cv2
import numpy as np
import torch
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
# ---------------------------------------------------------------------------
# Load .env (HF_TOKEN, ECHOPRIME_LANG, etc.)
# ---------------------------------------------------------------------------
load_dotenv()

# ---------------------------------------------------------------------------
# Download model repo from HuggingFace (cached after first run)
# ---------------------------------------------------------------------------
# Both values come from the environment (optionally populated by .env above);
# HF_TOKEN may be None.
HF_REPO = os.getenv("HF_REPO", "amn23/echo-prime")
HF_TOKEN = os.getenv("HF_TOKEN")

print(f"[Setup] Downloading/verifying model from HuggingFace: {HF_REPO} ...")
# `time` is already imported at the top of the module; this repeat is a no-op.
import time

# Up to five attempts. Only failures whose message contains "429" are retried,
# with a linearly growing pause (150 s, 300 s, ...). Any other failure, or a
# "429" on the final attempt, propagates and aborts start-up.
for attempt in range(5):
    try:
        ECHOPRIME_ROOT = Path(
            snapshot_download(repo_id=HF_REPO, repo_type="model", token=HF_TOKEN)
        )
    except Exception as e:
        if "429" not in str(e) or attempt >= 4:
            raise
        wait = 150 * (attempt + 1)
        print(f"[Setup] Rate limited, waiting {wait}s before retry {attempt+1}/5...")
        time.sleep(wait)
    else:
        # Download (or cache hit) succeeded; stop retrying.
        break

print(f"[Setup] Model files at: {ECHOPRIME_ROOT}")
# ---------------------------------------------------------------------------
# Set working directory so utils.py can find assets/ via relative paths
# ---------------------------------------------------------------------------
# Publish the snapshot directory through the environment.
# NOTE(review): presumably read by the model code imported below — confirm.
os.environ["ECHOPRIME_ROOT_OVERRIDE"] = str(ECHOPRIME_ROOT)
# Process-wide change of working directory; every relative path resolved after
# this point (in this module or in imported ones) is relative to the snapshot.
os.chdir(ECHOPRIME_ROOT)
# Each insert goes to the front of sys.path, so after both calls "/app" is
# searched first and the snapshot directory second.
sys.path.insert(0, str(ECHOPRIME_ROOT))
sys.path.insert(0, "/app")
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
import utils
from model import EchoPrime
from gradcam import MViTGradCAM, apply_heatmap_to_video
from attribution import compute_section_attributions
# ---------------------------------------------------------------------------
# App & CORS
# ---------------------------------------------------------------------------
app = FastAPI(title="EchoPrime API", version="1.0.0")

# Allowed browser origins, read from CORS_ALLOW_ORIGINS as a comma-separated
# list (e.g. "https://a.example,https://b.example"). Unset, empty or "*" keeps
# the permissive development default of allowing every origin.
_cors_origins: List[str] = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed requests must not be combined with a wildcard origin, so
    # credentials are only enabled once explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Global model instance (loaded once at startup)
# ---------------------------------------------------------------------------
# Single shared EchoPrime instance; stays None until the startup hook has run.
_model: Optional[EchoPrime] = None


@app.on_event("startup")
async def load_model() -> None:
    """Instantiate EchoPrime once at application start-up.

    The report language is taken from the ECHOPRIME_LANG environment
    variable and defaults to "en". The instance is stored in the
    module-level ``_model``.
    """
    global _model

    lang = os.getenv("ECHOPRIME_LANG", "en")
    print(f"[Startup] Loading EchoPrime (lang={lang})…")

    _model = EchoPrime(lang=lang)

    print("[Startup] Model ready.")
def get_model() -> EchoPrime:
    """Return the shared EchoPrime instance.

    Raises:
        HTTPException: 503 if the startup hook has not populated the model yet.
    """
    model = _model
    if model is not None:
        return model
    raise HTTPException(status_code=503, detail="Model not initialised yet.")
# ---------------------------------------------------------------------------
# File-based job store (avoids multi-replica memory isolation on HF Spaces)
# ---------------------------------------------------------------------------
JOBS_DIR = Path("/tmp/echoprime_jobs")
JOBS_DIR.mkdir(parents=True, exist_ok=True)
def _save_job(job_id: str, job: Dict[str, Any]) -> None:
(JOBS_DIR / f"{job_id}.json").write_text(json.dumps(job, default=str))
def _load_job(job_id: str) -> Optional[Dict[str, Any]]:
p = JOBS_DIR / f"{job_id}.json"
if not p.exists():
return None
return json.loads(p.read_text())
def _delete_job(job_id: str) -> None:
p = JOBS_DIR / f"{job_id}.json"
if p.exists():
p.unlink()
# ---------------------------------------------------------------------------
# Stack cache — stores video tensors on disk for GradCAM requests
# ---------------------------------------------------------------------------
STACKS_DIR = Path("/tmp/echoprime_stacks")
STACKS_DIR.mkdir(parents=True, exist_ok=True)
def _save_stack(study_id: str, stack) -> None:
torch.save(stack.cpu(), STACKS_DIR / f"{study_id}.pt")
def _load_stack(study_id: str):
p = STACKS_DIR / f"{study_id}.pt"
if not p.exists():
return None
return torch.load(p, map_location="cpu")
def _delete_stack(study_id: str) -> None:
p = STACKS_DIR / f"{study_id}.pt"
if p.exists():
p.unlink()
# Lazily created GradCAM wrapper, shared by all requests.
_gradcam_instance: Optional[MViTGradCAM] = None


def _get_gradcam(ep: EchoPrime) -> MViTGradCAM:
    """Return the shared GradCAM wrapper, building it on first use.

    The wrapper is built around a deep copy of ``ep.echo_encoder`` that is
    moved to CPU and put in eval mode, so the encoder held by *ep* itself is
    left untouched.
    """
    global _gradcam_instance

    if _gradcam_instance is not None:
        return _gradcam_instance

    import copy

    encoder_copy = copy.deepcopy(ep.echo_encoder).cpu()
    encoder_copy.eval()
    _gradcam_instance = MViTGradCAM(encoder_copy)
    return _gradcam_instance
# Mapping from view label → report section(s) it informs
# Mirrors the logic already in study-results.tsx
_VIEW_SECTION_MAP: Dict[str, List[str]] = {
"PSAX Pap Musc": ["lv"],
"PSAX Apex": ["lv"],
"PLAX": ["lv", "valves"],
"PLAX Zoomed Out": ["valves", "pericardium"],
"PLAX AV/MV": ["valves"],
"A4C": ["rh", "valves"],
"A4C LV": ["lv", "la"],
"A2C": ["lv", "la"],
"A2C LV": ["lv"],
"Subcostal": ["rh", "pericardium"],
"Suprasternal": ["la"],
}
_COLOR_MAP: Dict[str, str] = {
"lv": "cyan",
"la": "blue",
"rh": "green",
"valves": "purple",
"pericardium": "amber",
}
def _sections_for_view(label: str) -> List[str]:
"""Return linked report sections for a view label, defaulting to ['lv']."""
for key, sections in _VIEW_SECTION_MAP.items():
if key.lower() in label.lower() or label.lower() in key.lower():
return sections
return ["lv"]
def _color_for_sections(sections: List[str]) -> str:
if sections:
return _COLOR_MAP.get(sections[0], "cyan")
return "cyan"
# ---------------------------------------------------------------------------
# Report → section mapping
# ---------------------------------------------------------------------------
# EchoPrime returns a flat string with sections separated by [SEP].
# The section names match COARSE_VIEWS / MIL section names.
# We map them to the five UI buckets: lv, la, rh, valves, pericardium.
_SECTION_BUCKET_MAP: Dict[str, str] = {
"left ventricle": "lv",
"resting segmental": "lv",
"right ventricle": "rh",
"left atrium": "la",
"right atrium": "rh",
"atrial septum": "rh",
"mitral valve": "valves",
"aortic valve": "valves",
"tricuspid valve": "valves",
"pulmonic valve": "valves",
"pericardium": "pericardium",
"aorta": "pericardium",
"ivc": "rh",
"pulmonary artery": "valves",
"pulmonary veins": "la",
"postoperative": "pericardium",
# Russian translations
"левый желудочек": "lv",
"анализ сегментарной": "lv",
"правый желудочек": "rh",
"левое предсердие": "la",
"правое предсердие": "rh",
"межпредсердная перегородка":"rh",
"митральный клапан": "valves",
"аортальный клапан": "valves",
"трёхстворчатый клапан": "valves",
"клапан лёгочной": "valves",
"перикард": "pericardium",
"аорта": "pericardium",
"нижняя полая вена": "rh",
"лёгочная артерия": "valves",
"лёгочные вены": "la",
"послеоперационные": "pericardium",
}
_BUCKET_LABELS_EN: Dict[str, str] = {
"lv": "Left Ventricle",
"la": "Left Atrium",
"rh": "Right Heart",
"valves": "Valves",
"pericardium": "Pericardium",
}
def _parse_report_to_sections(raw_report: str) -> Dict[str, str]:
"""
Split the flat EchoPrime report string into the five UI buckets.
Each bucket accumulates text from the relevant clinical sections.
"""
buckets: Dict[str, List[str]] = {k: [] for k in _BUCKET_LABELS_EN}
# EchoPrime joins sections with \n (after [SEP] replacement in app.py).
# The section header is usually the first sentence of each paragraph.
paragraphs = [p.strip() for p in raw_report.split("\n") if p.strip()]
current_bucket = "lv" # default if we can't determine
for para in paragraphs:
lower = para.lower()
matched = False
for keyword, bucket in _SECTION_BUCKET_MAP.items():
if keyword in lower:
current_bucket = bucket
matched = True
break
buckets[current_bucket].append(para)
# Collapse each bucket to a single string
return {k: " ".join(v) if v else "" for k, v in buckets.items()}
# ---------------------------------------------------------------------------
# Metrics → key cardiac values for the summary cards
# ---------------------------------------------------------------------------
_METRIC_CARD_KEYS = [
("ef", ["EF", "LVEF", "Ejection Fraction"], "%", "55-70%"),
("lvedv", ["LVEDV", "LV EDV"], "mL", "65-240 mL"),
("lvesv", ["LVESV", "LV ESV"], "mL", "16-143 mL"),
]
def _extract_card_metrics(metrics: Dict[str, float]) -> List[Dict]:
"""Return up to three card-metric dicts for the UI summary strip."""
cards = []
for key, candidates, unit, normal in _METRIC_CARD_KEYS:
for cand in candidates:
# Case-insensitive key search
found = next(
(v for k, v in metrics.items()
if k.lower().replace(" ", "").replace("_", "") ==
cand.lower().replace(" ", "").replace("_", "")),
None,
)
if found is not None and not (isinstance(found, float) and math.isnan(found)):
cards.append({
"id": key,
"label": cand,
"value": round(float(found), 1),
"unit": unit,
"normal": normal,
})
break
return cards
# ---------------------------------------------------------------------------
# Frame image cache { (study_id, view_index): base64_png_str }
# ---------------------------------------------------------------------------
# In-process cache of labelled frames.
# Keys are "<study_id>:<view_index>" strings, values are base64 PNG text.
_frame_cache: Dict[str, str] = {}


def _encode_frame_b64(
    img_tensor: torch.Tensor,
    label: str,
) -> str:
    """Render one frame as a base64 PNG with the view label drawn on it.

    Only the first channel (``img_tensor[0]``) is used. It is min-max scaled
    to 0-255, converted to an RGB image, and annotated near the top-left
    corner with *label* (underscores shown as spaces).
    """
    gray = img_tensor[0].cpu().numpy()

    # Min-max scale; a constant image has a zero peak and stays all-zero.
    gray = gray - gray.min()
    peak = gray.max()
    if peak > 0:
        gray = gray / peak
    gray = np.ascontiguousarray((gray * 255).astype(np.uint8))

    canvas = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    cv2.putText(
        canvas,
        label.replace("_", " "),
        (10, 25),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 220, 255),
        2,
    )

    # The canvas is RGB; swap channels so the encoder receives BGR.
    _, png = cv2.imencode(".png", cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))
    return base64.b64encode(png.tobytes()).decode()
# ---------------------------------------------------------------------------
# Background analysis task
# ---------------------------------------------------------------------------
# Minimum raw MIL weight for a view type to count as informing a UI bucket.
_ATTRIBUTION_THRESHOLD = 0.35


def _json_safe_number(value: Any, ndigits: int = 3) -> Optional[float]:
    """Return *value* as a finite float rounded to *ndigits*, else None.

    NaN, infinities and values that cannot be converted to float become
    None, whatever numeric type carries them (built-in float, NumPy scalar,
    0-dim tensor, ...). This keeps the stored job record strict-JSON safe.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    if not math.isfinite(number):
        return None
    return round(number, ndigits)


def _ep_section_to_bucket(ep_sec: str) -> Optional[str]:
    """Map an EchoPrime section name to a UI bucket via _SECTION_BUCKET_MAP."""
    low = ep_sec.lower()
    for kw, bucket in _SECTION_BUCKET_MAP.items():
        if kw in low:
            return bucket
    return None


def _build_clip_attribution(ep: EchoPrime, encoded: Any) -> List[Dict[str, Any]]:
    """Link every clip to the UI buckets its view type informs.

    MIL_weights.csv structure (per the original author's notes):
      rows    = EP sections (Left Ventricle, Right Ventricle, …)
      columns = COARSE_VIEWS (A2C, A3C, A4C, …)
      values  = raw attention weight 0.0–1.0 (NOT normalised across views)
      meaning: "how important is this view TYPE for this section"

    For each clip the view type is the argmax over the columns of *encoded*
    from index 512 onwards (NOTE(review): presumably the one-hot view
    encoding appended to a 512-d embedding — confirm against encode_study).
    The raw weight of that view for every section is read from
    ``ep.section_weights``; when several EP sections map to the same bucket
    the maximum is kept. Buckets with weight >= _ATTRIBUTION_THRESHOLD are
    linked, strongest first. The author's notes give as examples:
      A4C       → LV(0.65), RV(0.71), LA(1.0), RA(0.87), MV(0.86)…
      SSN       → Pericardium(1.0), Aorta(1.0), PA(1.0)
      Subcostal → RV(1.0), RA(1.0), IVC(1.0), PA(0.78)

    Returns one dict per clip with "linkedSections" and "sectionWeights".
    """
    # The section -> bucket mapping does not depend on the clip; resolve once.
    section_buckets = []
    for s_idx, ep_sec in enumerate(ep.non_empty_sections):
        bucket = _ep_section_to_bucket(str(ep_sec))
        if bucket is not None:
            section_buckets.append((s_idx, bucket))

    view_enc = encoded[:, 512:].cpu()
    clip_views = torch.argmax(view_enc, dim=1).numpy()

    clip_attribution: List[Dict[str, Any]] = []
    for i in range(encoded.shape[0]):
        v = int(clip_views[i])  # view type index for this clip
        bucket_weights: Dict[str, float] = {}
        for s_idx, bucket in section_buckets:
            # raw weight: how important is view v for section s_idx
            w = float(ep.section_weights[s_idx, v])
            # keep max weight if multiple EP sections map to same UI bucket
            if bucket not in bucket_weights or w > bucket_weights[bucket]:
                bucket_weights[bucket] = w

        # linked = buckets where this view type has meaningful attention
        linked = [
            b for b, w in sorted(bucket_weights.items(), key=lambda x: -x[1])
            if w >= _ATTRIBUTION_THRESHOLD
        ]
        # always keep at least the top bucket
        if not linked and bucket_weights:
            linked = [max(bucket_weights, key=bucket_weights.get)]
        elif not linked:
            linked = ["lv"]

        clip_attribution.append({
            "linkedSections": linked,
            # .get(): the "lv" fallback above has no entry in bucket_weights.
            "sectionWeights": {
                b: round(bucket_weights.get(b, 0.0), 3) for b in linked
            },
        })
    return clip_attribution


def _run_analysis(job_id: str, tmp_dir: str) -> None:
    """
    Run the full analysis pipeline for one uploaded study.

    Runs in a thread pool via BackgroundTasks. Progress, step and stage
    message are written to the file-based job store after every stage.
    On success the job record receives the "result" payload and status
    "done". On any failure (including the model not being ready) the record
    is set to status "error" with the exception text, and the exception is
    re-raised. The upload directory *tmp_dir* is always removed at the end.
    """
    job = _load_job(job_id)
    if job is None:
        # The job record vanished before the task started (e.g. it was
        # deleted); nothing to report to, so just discard the upload.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return

    def upd(progress: int, step: int, msg: str) -> None:
        job["progress"] = progress
        job["step"] = step
        job["stage_msg"] = msg
        _save_job(job_id, job)

    try:
        # Inside the try block so that a not-yet-loaded model marks the job
        # as failed instead of leaving it "queued" forever.
        ep = get_model()

        job["status"] = "processing"
        _save_job(job_id, job)

        # Step 1 — already done (upload)
        upd(20, 1, "Preprocessing DICOM files…")

        # Step 2 — DICOM → tensor
        stack = ep.process_dicoms(tmp_dir)
        if stack.numel() == 0:
            raise ValueError("No valid DICOM video clips found in upload.")
        upd(40, 2, "Detecting cardiac views…")

        # Step 3 — view classification (for gallery labels)
        first_frames = stack[:, :, 0, :, :].to(ep.device)
        with torch.no_grad():
            logits = ep.view_classifier(first_frames)
        view_indices = torch.argmax(logits, dim=1)
        view_labels = [utils.COARSE_VIEWS[v] for v in view_indices]
        # Confidence = softmax probability of top class
        probs = torch.softmax(logits, dim=1)
        confidences = [
            float(probs[i, view_indices[i]]) for i in range(len(view_labels))
        ]

        # Cache frames as base64 PNGs
        study_id = job_id
        for idx, label in enumerate(view_labels):
            _frame_cache[f"{study_id}:{idx}"] = _encode_frame_b64(
                first_frames[idx].cpu(), label
            )
        upd(55, 3, "Encoding study…")

        # Step 3b — encode study
        encoded = ep.encode_study(stack, visualize=False)
        # Cache stack for GradCAM requests
        _save_stack(job_id, stack)
        upd(70, 4, "Predicting cardiac metrics…")

        # Step 4 — metrics
        raw_metrics = ep.predict_metrics(encoded)
        # NaN / inf / non-numeric → None so the record stays strict JSON.
        metrics_json = {k: _json_safe_number(v, 3) for k, v in raw_metrics.items()}
        # Unrounded finite values for the summary cards (rounded there).
        finite_metrics = {
            k: float(raw_metrics[k])
            for k, v in metrics_json.items()
            if v is not None
        }
        upd(85, 5, "Generating clinical report…")

        # Step 5 — report
        raw_report = ep.generate_report(encoded)
        clean_report = raw_report.replace("[SEP]", "\n").replace(" .", ".")
        sections = _parse_report_to_sections(clean_report)

        # Per-study view-to-section attribution using the MIL weights
        clip_attribution = _build_clip_attribution(ep, encoded)

        # Build views payload using the attribution
        views_payload = []
        for idx, (label, conf) in enumerate(zip(view_labels, confidences)):
            attr = clip_attribution[idx]
            linked = attr["linkedSections"]
            views_payload.append({
                "id": str(idx),
                "label": label.replace("_", " "),
                "confidence": round(conf, 4),
                "linkedSections": linked,
                "sectionWeights": attr["sectionWeights"],
                "colorCode": _color_for_sections(linked),
                "frameIndex": idx,
            })

        # Compute section attributions for GradCAM
        section_attrs = compute_section_attributions(
            study_embedding=encoded,
            candidate_embeddings=ep.candidate_embeddings,
            section_weights=ep.section_weights,
            non_empty_sections=ep.non_empty_sections,
            view_labels=view_labels,
            k=50,
        )
        section_embeddings = {
            sec: attrs.section_embedding.tolist()
            for sec, attrs in section_attrs.items()
        }
        (STACKS_DIR / f"{study_id}_sections.json").write_text(
            json.dumps(section_embeddings)
        )

        job["result"] = {
            "studyId": study_id,
            "views": views_payload,
            "report": sections,
            "rawReport": clean_report,
            "metrics": metrics_json,
            "cardMetrics": _extract_card_metrics(finite_metrics),
            "createdAt": time.strftime("%Y-%m-%d %H:%M"),
        }
        # Report completion only once everything is computed, and publish
        # result + final status in a single write so pollers never see
        # "100 %" or a result while the status is still "processing".
        job["progress"] = 100
        job["step"] = 5
        job["stage_msg"] = "Analysis complete."
        job["status"] = "done"
        _save_job(job_id, job)
    except Exception as exc:
        job["status"] = "error"
        job["error"] = str(exc)
        job["progress"] = 0
        _save_job(job_id, job)
        raise
    finally:
        # Clean up uploaded files (best effort).
        shutil.rmtree(tmp_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
def _safe_upload_dest(base_dir: Path, filename: Optional[str], index: int) -> Path:
    """Return a destination for an uploaded file that stays inside *base_dir*.

    The client-supplied *filename* may contain sub-directories (both "/" and
    "\\" are accepted as separators); these are kept. Empty, "." and ".."
    components and NUL characters are dropped, so absolute paths and parent
    traversal cannot escape *base_dir*. If nothing usable remains the file is
    named "upload_<index>.dcm".
    """
    raw = (filename or "").replace("\\", "/").replace("\x00", "")
    parts = [part for part in raw.split("/") if part not in ("", ".", "..")]
    if not parts:
        parts = [f"upload_{index}.dcm"]
    return base_dir.joinpath(*parts)


@app.post("/api/upload")
async def upload_study(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    external_id: Optional[str] = Form(None),
) -> JSONResponse:
    """
    Accept uploaded study files, save them to a temp folder, and kick off
    background analysis. Returns {jobId, studyId} immediately (HTTP 202) so
    the frontend can start polling /api/status/{jobId}.

    The files are stored as received; no archive extraction happens here.
    NOTE(review): whether ZIP uploads are supported depends on
    ``EchoPrime.process_dicoms`` — confirm.

    Args:
        files: One or more uploaded files.
        external_id: Optional caller-chosen job id. Because it becomes part
            of file names it may only contain letters, digits, '-', '_' and
            '.', must not start with '.', and is limited to 128 characters.
            A random UUID is used when omitted.

    Raises:
        HTTPException: 400 if no files were sent or *external_id* is invalid.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided.")

    if external_id:
        id_ok = (
            len(external_id) <= 128
            and not external_id.startswith(".")
            and all(ch.isalnum() or ch in "-_." for ch in external_id)
        )
        if not id_ok:
            raise HTTPException(
                status_code=400,
                detail=(
                    "external_id may only contain letters, digits, '-', '_' "
                    "and '.', must not start with '.', and must be at most "
                    "128 characters."
                ),
            )

    job_id = external_id or str(uuid.uuid4())
    tmp_dir = tempfile.mkdtemp(prefix=f"echoprime_{job_id}_")

    # Save all uploads; remove the temp folder again if anything goes wrong
    # so failed uploads do not pile up on disk.
    try:
        for index, f in enumerate(files):
            dest = _safe_upload_dest(Path(tmp_dir), f.filename, index)
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_bytes(await f.read())
    except Exception:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    _save_job(job_id, {
        "status": "queued",
        "progress": 5,
        "step": 1,
        "stage_msg": "Files received, queuing analysis…",
        "study_id": job_id,
        "error": None,
        "result": None,
        "tmp_dir": tmp_dir,
    })
    background_tasks.add_task(_run_analysis, job_id, tmp_dir)
    return JSONResponse({"jobId": job_id, "studyId": job_id}, status_code=202)
@app.get("/api/status/{job_id}")
async def get_status(job_id: str) -> JSONResponse:
    """
    Poll analysis progress.

    Response shape:
    {
      "status": "queued" | "processing" | "done" | "error",
      "progress": 0-100,
      "step": 1-5,
      "stageMsg": str,
      "studyId": str,
      "error": str | null
    }

    Raises:
        HTTPException: 404 if no job with this id is stored.
    """
    job = _load_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found.")

    # (response field, stored job field) — the response uses camelCase.
    field_map = (
        ("status", "status"),
        ("progress", "progress"),
        ("step", "step"),
        ("stageMsg", "stage_msg"),
        ("studyId", "study_id"),
        ("error", "error"),
    )
    return JSONResponse({out_key: job[src_key] for out_key, src_key in field_map})
@app.get("/api/study/{study_id}")
async def get_study(study_id: str) -> JSONResponse:
    """
    Return full study results once analysis is complete.

    Response shape:
    {
      "studyId": str,
      "createdAt": str,
      "views": [
        {
          "id": str, "label": str, "confidence": float,
          "linkedSections": [str], "sectionWeights": {section: float},
          "colorCode": str, "frameIndex": int
        }, …
      ],
      "report": {
        "lv": str, "la": str, "rh": str, "valves": str, "pericardium": str
      },
      "rawReport": str,
      "metrics": { phenotype: value|null, … },
      "cardMetrics": [
        { "id": str, "label": str, "value": float, "unit": str, "normal": str },
      ]
    }

    Raises:
        HTTPException: 404 if the study is unknown, 500 (with the stored
            error text) if analysis failed, 202 while analysis is still
            running.
    """
    job = _load_job(study_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Study not found.")

    status = job["status"]
    if status == "error":
        raise HTTPException(status_code=500, detail=job["error"])
    if status == "done":
        return JSONResponse(job["result"])
    raise HTTPException(status_code=202, detail="Analysis still in progress.")
def _rebuild_frame_cache(study_id: str) -> None:
    """Re-create the labelled frames of a study from the on-disk stores.

    Uses the view labels saved in the job result and the first frame of each
    clip in the cached video stack, i.e. the same inputs the analysis task
    used when it filled the cache. Does nothing if either is unavailable.
    """
    job = _load_job(study_id)
    result = (job or {}).get("result") or {}
    views = result.get("views") or []
    if not views:
        return
    stack = _load_stack(study_id)
    if stack is None:
        return
    # Guard against a stack with fewer clips than stored views.
    for idx, view in enumerate(views[: stack.shape[0]]):
        _frame_cache[f"{study_id}:{idx}"] = _encode_frame_b64(
            stack[idx, :, 0, :, :], str(view.get("label", ""))
        )


@app.get("/api/view/{study_id}/{frame_index}/frame")
async def get_view_frame(study_id: str, frame_index: int) -> Response:
    """
    Return the labelled first-frame PNG for a specific view as a raw PNG.
    The frontend can use this as <img src="/api/view/{studyId}/{i}/frame" />.

    Frames are cached in process memory, while job records and video stacks
    live on disk. After a restart, or in a different worker, the memory
    cache is empty although the study still exists; in that case the frames
    of the study are rebuilt from disk before answering.

    Raises:
        HTTPException: 404 if the frame is neither cached nor rebuildable.
    """
    cache_key = f"{study_id}:{frame_index}"
    b64 = _frame_cache.get(cache_key)
    if b64 is None:
        _rebuild_frame_cache(study_id)
        b64 = _frame_cache.get(cache_key)
    if b64 is None:
        raise HTTPException(status_code=404, detail="Frame not found.")
    img_bytes = base64.b64decode(b64)
    return Response(content=img_bytes, media_type="image/png")
@app.delete("/api/study/{study_id}")
async def delete_study(study_id: str) -> JSONResponse:
    """Delete everything stored for a study.

    Removes the job record, the cached video stack and the section-embedding
    file from disk, and drops the study's frames from the in-memory cache.

    Raises:
        HTTPException: 404 if no job record exists for *study_id*.
    """
    if _load_job(study_id) is None:
        raise HTTPException(status_code=404, detail="Study not found.")

    # On-disk artefacts
    _delete_job(study_id)
    _delete_stack(study_id)
    sections_file = STACKS_DIR / f"{study_id}_sections.json"
    sections_file.unlink(missing_ok=True)

    # In-memory frames: snapshot the keys first, the dict changes while popping.
    prefix = f"{study_id}:"
    for stale_key in [key for key in _frame_cache if key.startswith(prefix)]:
        _frame_cache.pop(stale_key, None)

    return JSONResponse({"deleted": study_id})
@app.post("/api/gradcam/{study_id}/{view_idx}/{section}")
async def compute_gradcam(study_id: str, view_idx: int, section: str) -> JSONResponse:
    """Compute a GradCAM overlay for one clip with respect to one section.

    Loads the cached video stack and the stored section embeddings of the
    study, generates a heatmap for clip *view_idx* against the embedding of
    *section*, blends it over (at most) the first 16 frames and returns the
    frames as base64-encoded JPEGs.

    Raises:
        HTTPException: 503 if the model is not loaded; 404 if the stack or
            the section embeddings are missing; 400 for an unknown section
            or an out-of-range *view_idx*; 500 if GradCAM itself fails.
    """
    ep = get_model()

    stack = _load_stack(study_id)
    if stack is None:
        raise HTTPException(status_code=404, detail="Study video data not found. Re-upload the study.")

    sections_path = STACKS_DIR / f"{study_id}_sections.json"
    if not sections_path.exists():
        raise HTTPException(status_code=404, detail="Section embeddings not found.")
    section_embeddings = json.loads(sections_path.read_text())

    if section not in section_embeddings:
        available = list(section_embeddings.keys())
        raise HTTPException(status_code=400, detail=f"Section not found. Available: {available}")
    if not 0 <= view_idx < stack.shape[0]:
        raise HTTPException(status_code=400, detail=f"view_idx {view_idx} out of range")

    try:
        target = torch.tensor(section_embeddings[section], dtype=torch.float32)
        video = stack[view_idx:view_idx + 1].cpu()

        gradcam = _get_gradcam(ep)
        # Gradients are switched on only for the duration of the heatmap
        # computation and always switched off again afterwards.
        for param in gradcam.model.parameters():
            param.requires_grad_(True)
        try:
            heatmap = gradcam.generate(video, target, output_size=(16, 224, 224))
        finally:
            for param in gradcam.model.parameters():
                param.requires_grad_(False)

        # Undo the per-channel normalisation to get displayable 8-bit frames.
        # NOTE(review): mean/std are presumably the statistics used when the
        # stack was normalised during preprocessing — confirm.
        mean = np.array([29.110628, 28.076836, 29.096405])
        std = np.array([47.989223, 46.456997, 47.20083])
        raw_video = stack[view_idx].cpu()
        T = raw_video.shape[1]
        frames = []
        for t in range(min(T, 16)):
            channels_last = raw_video[:, t, :, :].permute(1, 2, 0).numpy()
            restored = channels_last * std + mean
            frames.append(np.clip(restored, 0, 255).astype(np.uint8))
        frames_np = np.stack(frames)

        overlay = apply_heatmap_to_video(frames_np, heatmap[:len(frames)], alpha=0.4)

        jpeg_params = [cv2.IMWRITE_JPEG_QUALITY, 85]
        encoded_frames = []
        for t in range(overlay.shape[0]):
            bgr = cv2.cvtColor(overlay[t], cv2.COLOR_RGB2BGR)
            _, buf = cv2.imencode(".jpg", bgr, jpeg_params)
            encoded_frames.append(base64.b64encode(buf.tobytes()).decode())

        return JSONResponse({
            "studyId": study_id,
            "viewIdx": view_idx,
            "section": section,
            "numFrames": len(encoded_frames),
            "frames": encoded_frames,
        })
    except Exception as exc:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"GradCAM failed: {exc}")
@app.get("/api/gradcam/sections/{study_id}")
async def list_gradcam_sections(study_id: str) -> JSONResponse:
    """List the section names for which a GradCAM target embedding is stored.

    Raises:
        HTTPException: 404 if the study has no stored section embeddings.
    """
    sections_path = STACKS_DIR / f"{study_id}_sections.json"
    if not sections_path.exists():
        raise HTTPException(status_code=404, detail="Study not found.")

    stored = json.loads(sections_path.read_text())
    return JSONResponse({"sections": list(stored.keys())})
@app.get("/api/health")
async def health() -> JSONResponse:
    """Report that the service is up and whether the model has been loaded."""
    model_ready = _model is not None
    return JSONResponse({"status": "ok", "modelReady": model_ready})