""" EchoPrime FastAPI Backend ========================= Bridges the EchoPrime model to the React frontend. Model assets are downloaded automatically from HuggingFace on first run. Routes ------ POST /api/upload — Accept DICOM files, start async analysis job GET /api/status/{job_id} — Poll job status + progress GET /api/study/{study_id}— Fetch full study results (views + report + metrics) GET /api/view/{study_id}/{view_index}/frame — Serve labelled first-frame image DELETE /api/study/{study_id} — Delete a study from the job store Run --- uvicorn main:app --host 0.0.0.0 --port 8000 --reload """ import asyncio import base64 import io import json import math import os import shutil import sys import tempfile import time import uuid from pathlib import Path from typing import Any, Dict, List, Optional import cv2 import numpy as np import torch from dotenv import load_dotenv from huggingface_hub import snapshot_download # --------------------------------------------------------------------------- # Load .env (HF_TOKEN, ECHOPRIME_LANG, etc.) 
# ---------------------------------------------------------------------------
load_dotenv()

# ---------------------------------------------------------------------------
# Download model repo from HuggingFace (cached after first run)
# ---------------------------------------------------------------------------
HF_REPO = os.getenv("HF_REPO", "amn23/echo-prime")
HF_TOKEN = os.getenv("HF_TOKEN")

# Retry policy used when the download fails with an error mentioning "429":
# up to _HF_MAX_ATTEMPTS tries, sleeping _HF_BACKOFF_STEP_S * attempt seconds
# between them (150s, 300s, 450s, 600s).
_HF_MAX_ATTEMPTS = 5
_HF_BACKOFF_STEP_S = 150


def _download_model_repo() -> Path:
    """Download (or reuse the cached copy of) the model repo and return its path.

    Only errors whose message contains "429" are retried; any other error, and
    a "429" error on the final attempt, is re-raised unchanged.
    """
    for attempt in range(1, _HF_MAX_ATTEMPTS + 1):
        try:
            return Path(
                snapshot_download(
                    repo_id=HF_REPO,
                    repo_type="model",
                    token=HF_TOKEN,
                )
            )
        except Exception as exc:
            if "429" not in str(exc) or attempt == _HF_MAX_ATTEMPTS:
                raise
            wait = _HF_BACKOFF_STEP_S * attempt
            print(f"[Setup] Rate limited, waiting {wait}s before retry {attempt}/{_HF_MAX_ATTEMPTS}...")
            time.sleep(wait)
    # Not reachable: the last iteration either returns or raises.
    raise RuntimeError("model download retry loop exited unexpectedly")


print(f"[Setup] Downloading/verifying model from HuggingFace: {HF_REPO} ...")
ECHOPRIME_ROOT = _download_model_repo()
print(f"[Setup] Model files at: {ECHOPRIME_ROOT}")

# ---------------------------------------------------------------------------
# Set working directory so utils.py can find assets/ via relative paths
# ---------------------------------------------------------------------------
os.environ["ECHOPRIME_ROOT_OVERRIDE"] = str(ECHOPRIME_ROOT)
os.chdir(ECHOPRIME_ROOT)
# Insertion order matters: "/app" is inserted last, so it is searched first.
sys.path.insert(0, str(ECHOPRIME_ROOT))
sys.path.insert(0, "/app")

# These imports must stay below the sys.path / chdir setup above: utils, model,
# gradcam and attribution are resolved through the paths just inserted.
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response

import utils
from model import EchoPrime
from gradcam import MViTGradCAM, apply_heatmap_to_video
from attribution import compute_section_attributions

# ---------------------------------------------------------------------------
# App & CORS
# ---------------------------------------------------------------------------
app = FastAPI(title="EchoPrime API", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Global model instance (loaded once at startup)
# ---------------------------------------------------------------------------
_model: Optional[EchoPrime] = None


@app.on_event("startup")
async def load_model() -> None:
    """Instantiate EchoPrime once at application startup.

    The language is read from the ECHOPRIME_LANG environment variable
    (default "en") and passed to the EchoPrime constructor.
    """
    global _model
    lang = os.getenv("ECHOPRIME_LANG", "en")
    print(f"[Startup] Loading EchoPrime (lang={lang})…")
    _model = EchoPrime(lang=lang)
    print("[Startup] Model ready.")


def get_model() -> EchoPrime:
    """Return the loaded model.

    Raises:
        HTTPException: 503 if the startup hook has not set the model yet.
    """
    if _model is None:
        raise HTTPException(status_code=503, detail="Model not initialised yet.")
    return _model


# ---------------------------------------------------------------------------
# File-based job store (avoids multi-replica memory isolation on HF Spaces)
# ---------------------------------------------------------------------------
JOBS_DIR = Path("/tmp/echoprime_jobs")
JOBS_DIR.mkdir(parents=True, exist_ok=True)


def _save_job(job_id: str, job: Dict[str, Any]) -> None:
    """Persist *job* as JSON under JOBS_DIR, replacing any previous version.

    The JSON is written to a temporary file in the same directory and then
    moved into place with os.replace, so a concurrent _load_job (the status
    endpoint polls while the analysis task keeps saving) sees either the old
    or the new file, never a partially written one.
    """
    target = JOBS_DIR / f"{job_id}.json"
    # default=str stringifies anything json cannot encode natively.
    payload = json.dumps(job, default=str)
    fd, tmp_name = tempfile.mkstemp(dir=str(JOBS_DIR), prefix=".job-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(payload)
        os.replace(tmp_name, target)
    except BaseException:
        # Do not leave the temporary file behind when the write or rename fails.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise


def _load_job(job_id: str) -> Optional[Dict[str, Any]]:
    """Return the stored job dict, or None when no file exists for *job_id*."""
    path = JOBS_DIR / f"{job_id}.json"
    try:
        # Read first, handle absence after: the file may be deleted between an
        # exists() check and the read.
        raw = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return None
    return json.loads(raw)


def _delete_job(job_id: str) -> None:
    """Remove the job file for *job_id*; a missing file is not an error."""
    (JOBS_DIR / f"{job_id}.json").unlink(missing_ok=True)


# ---------------------------------------------------------------------------
# Stack cache — stores video tensors on disk for GradCAM requests
# ---------------------------------------------------------------------------
STACKS_DIR = Path("/tmp/echoprime_stacks")
STACKS_DIR.mkdir(parents=True, exist_ok=True)


def _save_stack(study_id: str, stack) -> None:
    """Save *stack* (moved to CPU) as ``<study_id>.pt`` under STACKS_DIR."""
    torch.save(stack.cpu(), STACKS_DIR / f"{study_id}.pt")


def _load_stack(study_id: str):
    """Load the cached stack for *study_id* onto CPU, or return None if absent."""
    p = STACKS_DIR / f"{study_id}.pt"
    if not p.exists():
        return None
    # NOTE(review): torch.load unpickles the file. The only writer visible here
    # is _save_stack, but the directory lives under /tmp; consider passing
    # weights_only=True if the deployed torch version supports it.
    return torch.load(p, map_location="cpu")


def _delete_stack(study_id: str) -> None:
    """Remove the cached stack for *study_id*; a missing file is not an error."""
    (STACKS_DIR / f"{study_id}.pt").unlink(missing_ok=True)


_gradcam_instance: Optional[MViTGradCAM] = None


def _get_gradcam(ep: EchoPrime) -> MViTGradCAM:
    """Return the shared MViTGradCAM instance, building it on first use.

    The instance wraps a deep copy of ``ep.echo_encoder`` that has been moved
    to CPU and put in eval mode, so the encoder held by *ep* is not modified.
    """
    global _gradcam_instance
    if _gradcam_instance is None:
        import copy

        cpu_encoder = copy.deepcopy(ep.echo_encoder).cpu()
        cpu_encoder.eval()
        _gradcam_instance = MViTGradCAM(cpu_encoder)
    return _gradcam_instance


# Mapping from view label → report section(s) it informs
# Mirrors the logic already in study-results.tsx
_VIEW_SECTION_MAP: Dict[str, List[str]] = {
    "PSAX Pap Musc": ["lv"],
    "PSAX Apex": ["lv"],
    "PLAX": ["lv", "valves"],
    "PLAX Zoomed Out": ["valves", "pericardium"],
    "PLAX AV/MV": ["valves"],
    "A4C": ["rh", "valves"],
    "A4C LV": ["lv", "la"],
    "A2C": ["lv", "la"],
    "A2C LV": ["lv"],
    "Subcostal": ["rh", "pericardium"],
    "Suprasternal": ["la"],
}

_COLOR_MAP: Dict[str, str] = {
    "lv": "cyan",
    "la": "blue",
    "rh": "green",
    "valves": "purple",
    "pericardium": "amber",
}


def _sections_for_view(label: str) -> List[str]:
    """Return linked report sections for a view label, defaulting to ['lv'].

    Matching is case-insensitive and works in both directions (the key is a
    substring of the label, or the label is a substring of the key); the first
    entry of _VIEW_SECTION_MAP that matches wins.
    """
    label_lower = label.lower()
    for key, sections in _VIEW_SECTION_MAP.items():
        key_lower = key.lower()
        if key_lower in label_lower or label_lower in key_lower:
            return sections
    return ["lv"]


def _color_for_sections(sections: List[str]) -> str:
    """Return the colour of the first section, or "cyan" if empty or unknown."""
    if sections:
        return _COLOR_MAP.get(sections[0], "cyan")
    return "cyan"


# ---------------------------------------------------------------------------
# Report → section mapping
# ---------------------------------------------------------------------------
# EchoPrime returns a flat string with sections separated by [SEP].
# The section names match COARSE_VIEWS / MIL section names.
# We map them to the five UI buckets: lv, la, rh, valves, pericardium.
_SECTION_BUCKET_MAP: Dict[str, str] = {
    "left ventricle": "lv",
    "resting segmental": "lv",
    "right ventricle": "rh",
    "left atrium": "la",
    "right atrium": "rh",
    "atrial septum": "rh",
    "mitral valve": "valves",
    "aortic valve": "valves",
    "tricuspid valve": "valves",
    "pulmonic valve": "valves",
    "pericardium": "pericardium",
    "aorta": "pericardium",
    "ivc": "rh",
    "pulmonary artery": "valves",
    "pulmonary veins": "la",
    "postoperative": "pericardium",
    # Russian translations
    "левый желудочек": "lv",
    "анализ сегментарной": "lv",
    "правый желудочек": "rh",
    "левое предсердие": "la",
    "правое предсердие": "rh",
    "межпредсердная перегородка": "rh",
    "митральный клапан": "valves",
    "аортальный клапан": "valves",
    "трёхстворчатый клапан": "valves",
    "клапан лёгочной": "valves",
    "перикард": "pericardium",
    "аорта": "pericardium",
    "нижняя полая вена": "rh",
    "лёгочная артерия": "valves",
    "лёгочные вены": "la",
    "послеоперационные": "pericardium",
}

_BUCKET_LABELS_EN: Dict[str, str] = {
    "lv": "Left Ventricle",
    "la": "Left Atrium",
    "rh": "Right Heart",
    "valves": "Valves",
    "pericardium": "Pericardium",
}


def _bucket_for_text(text: str) -> Optional[str]:
    """Return the UI bucket of the first _SECTION_BUCKET_MAP keyword found in *text*.

    Matching is a case-insensitive substring test in dict order; returns None
    when no keyword occurs.
    """
    lowered = text.lower()
    for keyword, bucket in _SECTION_BUCKET_MAP.items():
        if keyword in lowered:
            return bucket
    return None


def _parse_report_to_sections(raw_report: str) -> Dict[str, str]:
    """
    Split the flat EchoPrime report string into the five UI buckets.
    Each bucket accumulates text from the relevant clinical sections.

    Returns a dict with one entry per key of _BUCKET_LABELS_EN; buckets that
    received no paragraph map to "".
    """
    buckets: Dict[str, List[str]] = {k: [] for k in _BUCKET_LABELS_EN}

    # EchoPrime joins sections with \n (after [SEP] replacement in app.py).
    # The section header is usually the first sentence of each paragraph.
    paragraphs = [p.strip() for p in raw_report.split("\n") if p.strip()]

    current_bucket = "lv"  # default if we can't determine
    for para in paragraphs:
        matched_bucket = _bucket_for_text(para)
        if matched_bucket is not None:
            # A paragraph with no keyword stays in the previous paragraph's bucket.
            current_bucket = matched_bucket
        buckets[current_bucket].append(para)

    # Collapse each bucket to a single string ("" for an empty bucket)
    return {k: " ".join(v) for k, v in buckets.items()}


# ---------------------------------------------------------------------------
# Metrics → key cardiac values for the summary cards
# ---------------------------------------------------------------------------
_METRIC_CARD_KEYS = [
    ("ef", ["EF", "LVEF", "Ejection Fraction"], "%", "55-70%"),
    ("lvedv", ["LVEDV", "LV EDV"], "mL", "65-240 mL"),
    ("lvesv", ["LVESV", "LV ESV"], "mL", "16-143 mL"),
]


def _normalise_metric_name(name: str) -> str:
    """Lower-case *name* and drop spaces and underscores for loose comparison."""
    return name.lower().replace(" ", "").replace("_", "")


def _finite_rounded(value: Any, ndigits: int) -> Optional[float]:
    """Return ``round(float(value), ndigits)``, or None for None/NaN/inf.

    The value is converted with float() before the check, so NaN is detected
    whatever numeric type carries it, not only in builtin floats. Values that
    float() cannot convert raise TypeError/ValueError.
    """
    if value is None:
        return None
    number = float(value)
    if not math.isfinite(number):
        return None
    return round(number, ndigits)


def _extract_card_metrics(metrics: Dict[str, float]) -> List[Dict]:
    """Return up to three card-metric dicts for the UI summary strip.

    For each entry of _METRIC_CARD_KEYS the candidate names are tried in order
    (ignoring case, spaces and underscores); the first candidate that has a
    finite value produces the card, and its name becomes the card label.
    """
    # Normalise the metric names once; setdefault keeps the first key of each
    # normalised form, matching a first-match scan over metrics.items().
    by_name: Dict[str, Any] = {}
    for name, value in metrics.items():
        by_name.setdefault(_normalise_metric_name(name), value)

    cards = []
    for key, candidates, unit, normal in _METRIC_CARD_KEYS:
        for cand in candidates:
            value = _finite_rounded(by_name.get(_normalise_metric_name(cand)), 1)
            if value is None:
                continue
            cards.append({
                "id": key,
                "label": cand,
                "value": value,
                "unit": unit,
                "normal": normal,
            })
            break
    return cards


# ---------------------------------------------------------------------------
# Frame image cache  { "<study_id>:<view_index>": base64_png_str }
# ---------------------------------------------------------------------------
# NOTE(review): unlike the job store this cache is an in-process dict, so it is
# only visible to the process that ran the analysis — confirm that is
# acceptable when several replicas serve requests.
_frame_cache: Dict[str, str] = {}


def _encode_frame_b64(
    img_tensor: torch.Tensor,
    label: str,
) -> str:
    """Convert a (C,H,W) tensor to a base64 PNG with a burnt-in label.

    Only channel 0 is used. It is min-max scaled to 0-255 (left at zero when
    the frame is constant), and the label is drawn with underscores replaced
    by spaces.
    """
    img = img_tensor[0].cpu().numpy()
    img = img - img.min()
    mx = img.max()
    if mx > 0:
        img = img / mx
    img = (img * 255).astype(np.uint8)
    img = np.ascontiguousarray(img)

    display = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    label_text = label.replace("_", " ")
    cv2.putText(
        display,
        label_text,
        (10, 25),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (0, 220, 255),
        2,
    )
    # The image is handled as RGB above; imencode expects BGR.
    _, buf = cv2.imencode(".png", cv2.cvtColor(display, cv2.COLOR_RGB2BGR))
    return base64.b64encode(buf.tobytes()).decode()


# ---------------------------------------------------------------------------
# Background analysis task
# ---------------------------------------------------------------------------
def _run_analysis(job_id: str, tmp_dir: str) -> None:
    """
    Run the full analysis for one upload; scheduled via BackgroundTasks.

    Progress, step and stage message are written to the file-based job store
    after every stage. On success the result is stored together with
    status "done" and progress 100 in a single save, so a client never sees
    100 % before the result exists. On failure the job is marked "error" and
    the exception is re-raised. *tmp_dir* is removed in every case.
    """
    job = _load_job(job_id)
    if job is None:
        # The job record is gone (e.g. deleted before the task started):
        # there is nothing to update, just drop the uploaded files.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return

    def upd(progress: int, step: int, msg: str) -> None:
        job["progress"] = progress
        job["step"] = step
        job["stage_msg"] = msg
        _save_job(job_id, job)

    try:
        # Inside the try so that a not-yet-loaded model marks the job as
        # failed instead of leaving it queued forever.
        ep = get_model()

        job["status"] = "processing"
        _save_job(job_id, job)

        # Step 1 — already done (upload)
        upd(20, 1, "Preprocessing DICOM files…")

        # Step 2 — DICOM → tensor
        stack = ep.process_dicoms(tmp_dir)
        if stack.numel() == 0:
            raise ValueError("No valid DICOM video clips found in upload.")

        upd(40, 2, "Detecting cardiac views…")

        # Step 3 — view classification (for gallery labels)
        first_frames = stack[:, :, 0, :, :].to(ep.device)
        with torch.no_grad():
            logits = ep.view_classifier(first_frames)
        view_indices = torch.argmax(logits, dim=1)
        view_labels = [utils.COARSE_VIEWS[v] for v in view_indices]

        # Confidence = softmax probability of top class
        probs = torch.softmax(logits, dim=1)
        confidences = [float(probs[i, view_indices[i]]) for i in range(len(view_labels))]

        # Cache frames as base64 PNGs
        study_id = job_id
        for idx, label in enumerate(view_labels):
            _frame_cache[f"{study_id}:{idx}"] = _encode_frame_b64(first_frames[idx].cpu(), label)

        upd(55, 3, "Encoding study…")

        # Step 3b — encode study
        encoded = ep.encode_study(stack, visualize=False)

        # Cache stack for GradCAM requests
        _save_stack(job_id, stack)

        upd(70, 4, "Predicting cardiac metrics…")

        # Step 4 — metrics
        raw_metrics = ep.predict_metrics(encoded)
        # Convert NaN/inf to None for JSON serialisation
        metrics_json = {k: _finite_rounded(v, 3) for k, v in raw_metrics.items()}

        upd(85, 5, "Generating clinical report…")

        # Step 5 — report
        raw_report = ep.generate_report(encoded)
        clean_report = raw_report.replace("[SEP]", "\n").replace(" .", ".")
        sections = _parse_report_to_sections(clean_report)

        # ── Per-study view-to-section attribution using real MIL weights ──────
        #
        # MIL_weights.csv structure (confirmed from actual file):
        #   rows    = EP sections (Left Ventricle, Right Ventricle, …)
        #   columns = COARSE_VIEWS (A2C, A3C, A4C, …)
        #   values  = raw attention weight 0.0–1.0 (NOT normalised across views)
        #   meaning: "how important is this view TYPE for this section"
        #
        # For each clip we know its view type from the one-hot encoding.
        # We look up the raw weight for every section directly from the CSV.
        # Threshold = 0.35 (moderate-to-high importance) based on the CSV values.
        # This gives: A4C → LV(0.65), RV(0.71), LA(1.0), RA(0.87), MV(0.86)…
        #             SSN → Pericardium(1.0), Aorta(1.0), PA(1.0)
        #             Subcostal → RV(1.0), RA(1.0), IVC(1.0), PA(0.78)
        ATTRIBUTION_THRESHOLD = 0.35

        # Map EP section name → UI bucket (reuse _SECTION_BUCKET_MAP). This
        # does not depend on the clip, so it is resolved once, not per clip.
        section_buckets = []
        for s_idx, ep_sec in enumerate(ep.non_empty_sections):
            bucket = _bucket_for_text(str(ep_sec))
            if bucket is not None:
                section_buckets.append((s_idx, bucket))

        n_clips = encoded.shape[0]
        view_enc = encoded[:, 512:].cpu()                    # (N, 11)
        clip_views = torch.argmax(view_enc, dim=1).numpy()   # (N,) view index 0-10

        # Build per-clip attribution from raw CSV weights
        clip_attribution: List[Dict[str, Any]] = []
        for i in range(n_clips):
            v = int(clip_views[i])  # view type index for this clip
            bucket_weights: Dict[str, float] = {}
            for s_idx, bucket in section_buckets:
                # raw weight from CSV: how important is view v for section s_idx
                w = float(ep.section_weights[s_idx, v])
                # keep max weight if multiple EP sections map to same UI bucket
                if bucket not in bucket_weights or w > bucket_weights[bucket]:
                    bucket_weights[bucket] = w

            # linked = buckets where this view type has meaningful attention
            linked = [
                b for b, w in sorted(bucket_weights.items(), key=lambda x: -x[1])
                if w >= ATTRIBUTION_THRESHOLD
            ]
            # always keep at least the top bucket
            if not linked and bucket_weights:
                linked = [max(bucket_weights, key=bucket_weights.get)]
            elif not linked:
                linked = ["lv"]

            clip_attribution.append({
                "linkedSections": linked,
                "sectionWeights": {b: round(bucket_weights[b], 3) for b in linked},
            })

        # Build views payload using real attribution
        views_payload = []
        for idx, (label, conf) in enumerate(zip(view_labels, confidences)):
            attr = clip_attribution[idx]
            linked = attr["linkedSections"]
            views_payload.append({
                "id": str(idx),
                "label": label.replace("_", " "),
                "confidence": round(conf, 4),
                "linkedSections": linked,
                "sectionWeights": attr["sectionWeights"],
                "colorCode": _color_for_sections(linked),
                "frameIndex": idx,
            })

        # Compute section attributions for GradCAM
        section_attrs = compute_section_attributions(
            study_embedding=encoded,
            candidate_embeddings=ep.candidate_embeddings,
            section_weights=ep.section_weights,
            non_empty_sections=ep.non_empty_sections,
            view_labels=view_labels,
            k=50,
        )
        section_embeddings = {
            sec: attrs.section_embedding.tolist()
            for sec, attrs in section_attrs.items()
        }
        (STACKS_DIR / f"{study_id}_sections.json").write_text(json.dumps(section_embeddings))

        job["result"] = {
            "studyId": study_id,
            "views": views_payload,
            "report": sections,
            "rawReport": clean_report,
            "metrics": metrics_json,
            "cardMetrics": _extract_card_metrics(raw_metrics),
            "createdAt": time.strftime("%Y-%m-%d %H:%M"),
        }
        # Publish completion and the result together.
        job["progress"] = 100
        job["step"] = 5
        job["stage_msg"] = "Analysis complete."
        job["status"] = "done"
        _save_job(job_id, job)

    except Exception as exc:
        job["status"] = "error"
        job["error"] = str(exc)
        job["progress"] = 0
        _save_job(job_id, job)
        raise
    finally:
        # Clean up uploaded files
        shutil.rmtree(tmp_dir, ignore_errors=True)


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
_MAX_JOB_ID_LEN = 128


def _is_safe_job_id(value: str) -> bool:
    """Return True if *value* can safely be used inside job/stack file names.

    Accepts 1-128 ASCII letters, digits, "-", "_" and "." (so UUIDs and dotted
    UIDs pass) and rejects a leading ".", which rules out path separators and
    "." / "..".
    """
    if not value or len(value) > _MAX_JOB_ID_LEN or value.startswith("."):
        return False
    return value.isascii() and all(ch.isalnum() or ch in "-_." for ch in value)


@app.post("/api/upload")
async def upload_study(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    external_id: Optional[str] = Form(None),
) -> JSONResponse:
    """
    Accept DICOM files (or a ZIP), save to a temp folder, and kick off
    background analysis. Returns {jobId, studyId} immediately so the
    frontend can start polling /api/status/{jobId}.

    Raises:
        HTTPException: 400 if no files are given, if external_id is not a safe
            identifier, or if a file name would resolve outside the upload
            folder.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided.")

    job_id = external_id or str(uuid.uuid4())
    # job_id becomes part of file names in the job store and stack cache, and
    # external_id is client-controlled, so it must not carry path components.
    if not _is_safe_job_id(job_id):
        raise HTTPException(status_code=400, detail="Invalid external_id.")

    tmp_dir = tempfile.mkdtemp(prefix=f"echoprime_{job_id}_")
    upload_root = Path(tmp_dir).resolve()

    # Save all uploads
    try:
        for index, f in enumerate(files):
            name = f.filename or f"upload_{index}.dcm"
            # Relative sub-folders in the name are kept, but the resolved
            # destination must stay inside the upload folder (rejects absolute
            # paths and ".." components).
            dest = (upload_root / name).resolve()
            if upload_root not in dest.parents:
                raise HTTPException(status_code=400, detail="Invalid file name in upload.")
            dest.parent.mkdir(parents=True, exist_ok=True)
            content = await f.read()
            dest.write_bytes(content)
    except Exception:
        # No background task will run for this upload, so clean up here.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    _save_job(job_id, {
        "status": "queued",
        "progress": 5,
        "step": 1,
        "stage_msg": "Files received, queuing analysis…",
        "study_id": job_id,
        "error": None,
        "result": None,
        "tmp_dir": tmp_dir,
    })

    background_tasks.add_task(_run_analysis, job_id, tmp_dir)
    return JSONResponse({"jobId": job_id, "studyId": job_id}, status_code=202)


@app.get("/api/status/{job_id}")
async def get_status(job_id: str) -> JSONResponse:
    """
    Poll analysis progress.

    Response shape:
    {
        "status": "queued" | "processing" | "done" | "error",
        "progress": 0-100,
        "step": 1-5,
        "stageMsg": str,
        "studyId": str,
        "error": str | null
    }
    """
    job = _load_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found.")
    return JSONResponse({
        "status": job["status"],
        "progress": job["progress"],
        "step": job["step"],
        "stageMsg": job["stage_msg"],
        "studyId": job["study_id"],
        "error": job["error"],
    })


@app.get("/api/study/{study_id}")
async def get_study(study_id: str) -> JSONResponse:
    """
    Return full study results once analysis is complete.

    Response shape:
    {
        "studyId": str,
        "createdAt": str,
        "views": [ { "id": str, "label": str, "confidence": float,
                     "linkedSections": [str], "sectionWeights": {str: float},
                     "colorCode": str, "frameIndex": int }, … ],
        "report": { "lv": str, "la": str, "rh": str, "valves": str, "pericardium": str },
        "rawReport": str,
        "metrics": { phenotype: value|null, … },
        "cardMetrics": [ { "id": str, "label": str, "value": float,
                           "unit": str, "normal": str }, … ]
    }

    Raises:
        HTTPException: 404 if the study is unknown, 500 (carrying the stored
            error text) if analysis failed, 202 while analysis is still running.
    """
    job = _load_job(study_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Study not found.")
    if job["status"] == "error":
        raise HTTPException(status_code=500, detail=job["error"])
    if job["status"] != "done":
        # A 2xx status raised as an exception is unusual; kept as-is because
        # clients may already rely on it.
        raise HTTPException(status_code=202, detail="Analysis still in progress.")
    return JSONResponse(job["result"])


@app.get("/api/view/{study_id}/{frame_index}/frame")
async def get_view_frame(study_id: str, frame_index: int) -> Response:
    """
    Return the labelled first-frame image for a specific view as a raw PNG,
    so the frontend can use this URL directly as an image source.

    Raises:
        HTTPException: 404 if no frame is cached for this study/index.
    """
    b64 = _frame_cache.get(f"{study_id}:{frame_index}")
    if b64 is None:
        raise HTTPException(status_code=404, detail="Frame not found.")
    img_bytes = base64.b64decode(b64)
    return Response(content=img_bytes, media_type="image/png")


@app.delete("/api/study/{study_id}")
async def delete_study(study_id: str) -> JSONResponse:
    """Remove the job file, cached stack, section embeddings and cached frames."""
    if _load_job(study_id) is None:
        raise HTTPException(status_code=404, detail="Study not found.")
    _delete_job(study_id)
    _delete_stack(study_id)
    (STACKS_DIR / f"{study_id}_sections.json").unlink(missing_ok=True)
    # Remove all cached frames for this study
    keys_to_delete = [k for k in _frame_cache if k.startswith(f"{study_id}:")]
    for k in keys_to_delete:
        _frame_cache.pop(k, None)
    return JSONResponse({"deleted": study_id})


@app.post("/api/gradcam/{study_id}/{view_idx}/{section}")
async def compute_gradcam(study_id: str, view_idx: int, section: str) -> JSONResponse:
    """Compute a GradCAM overlay for one clip of a study against one section.

    Returns JSON with up to 16 base64-encoded JPEG frames of the heatmap
    blended over the de-normalised clip.

    Raises:
        HTTPException: 404 if the cached stack or section embeddings are
            missing, 400 for an unknown section or out-of-range view_idx,
            500 if GradCAM generation or encoding fails.
    """
    ep = get_model()
    stack = _load_stack(study_id)
    if stack is None:
        raise HTTPException(status_code=404, detail="Study video data not found. Re-upload the study.")
    sections_path = STACKS_DIR / f"{study_id}_sections.json"
    if not sections_path.exists():
        raise HTTPException(status_code=404, detail="Section embeddings not found.")
    section_embeddings = json.loads(sections_path.read_text())
    if section not in section_embeddings:
        available = list(section_embeddings.keys())
        raise HTTPException(status_code=400, detail=f"Section not found. Available: {available}")
    if view_idx < 0 or view_idx >= stack.shape[0]:
        raise HTTPException(status_code=400, detail=f"view_idx {view_idx} out of range")
    try:
        target = torch.tensor(section_embeddings[section], dtype=torch.float32)
        video = stack[view_idx:view_idx + 1].cpu()
        gradcam = _get_gradcam(ep)
        # Gradients are enabled only for the duration of the GradCAM pass.
        for p in gradcam.model.parameters():
            p.requires_grad_(True)
        try:
            heatmap = gradcam.generate(video, target, output_size=(16, 224, 224))
        finally:
            for p in gradcam.model.parameters():
                p.requires_grad_(False)

        raw_video = stack[view_idx].cpu()
        T = raw_video.shape[1]
        # Per-channel statistics used to map frames back to 0-255 pixel values
        # (presumably the inverse of the preprocessing normalisation — confirm).
        mean = np.array([29.110628, 28.076836, 29.096405])
        std = np.array([47.989223, 46.456997, 47.20083])
        frames = []
        for t in range(min(T, 16)):
            frame = raw_video[:, t, :, :].permute(1, 2, 0).numpy()
            frame = frame * std + mean
            frame = np.clip(frame, 0, 255).astype(np.uint8)
            frames.append(frame)
        frames_np = np.stack(frames)

        overlay = apply_heatmap_to_video(frames_np, heatmap[:len(frames)], alpha=0.4)

        encoded_frames = []
        for t in range(overlay.shape[0]):
            _, buf = cv2.imencode(
                ".jpg",
                cv2.cvtColor(overlay[t], cv2.COLOR_RGB2BGR),
                [cv2.IMWRITE_JPEG_QUALITY, 85],
            )
            encoded_frames.append(base64.b64encode(buf.tobytes()).decode())

        return JSONResponse({
            "studyId": study_id,
            "viewIdx": view_idx,
            "section": section,
            "numFrames": len(encoded_frames),
            "frames": encoded_frames,
        })
    except Exception as exc:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"GradCAM failed: {exc}")


@app.get("/api/gradcam/sections/{study_id}")
async def list_gradcam_sections(study_id: str) -> JSONResponse:
    """List the section names for which GradCAM targets were stored for a study."""
    sections_path = STACKS_DIR / f"{study_id}_sections.json"
    if not sections_path.exists():
        raise HTTPException(status_code=404, detail="Study not found.")
    sections = list(json.loads(sections_path.read_text()).keys())
    return JSONResponse({"sections": sections})


@app.get("/api/health")
async def health() -> JSONResponse:
    """Report liveness and whether the model has finished loading."""
    return JSONResponse({"status": "ok", "modelReady": _model is not None})