Spaces:
Running
Running
| """ | |
| EchoPrime FastAPI Backend | |
| ========================= | |
| Bridges the EchoPrime model to the React frontend. | |
| Model assets are downloaded automatically from HuggingFace on first run. | |
| Routes | |
| ------ | |
| POST /api/upload — Accept DICOM files, start async analysis job | |
| GET /api/status/{job_id} — Poll job status + progress | |
| GET /api/study/{study_id} — Fetch full study results (views + report + metrics) | |
| GET /api/view/{study_id}/{view_index}/frame — Serve labelled first-frame image | |
| DELETE /api/study/{study_id} — Delete a study from the job store | |
| Run | |
| --- | |
| uvicorn main:app --host 0.0.0.0 --port 8000 --reload | |
| """ | |
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import math | |
| import os | |
| import shutil | |
| import sys | |
| import tempfile | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from dotenv import load_dotenv | |
| from huggingface_hub import snapshot_download | |
| # --------------------------------------------------------------------------- | |
| # Load .env (HF_TOKEN, ECHOPRIME_LANG, etc.) | |
| # --------------------------------------------------------------------------- | |
| load_dotenv() | |
| # --------------------------------------------------------------------------- | |
| # Download model repo from HuggingFace (cached after first run) | |
| # --------------------------------------------------------------------------- | |
HF_REPO = os.getenv("HF_REPO", "amn23/echo-prime")
HF_TOKEN = os.getenv("HF_TOKEN")
print(f"[Setup] Downloading/verifying model from HuggingFace: {HF_REPO} ...")

# Retry policy for the snapshot download: only errors whose text mentions
# "429" (rate limiting) are retried, with a linearly growing pause. Any other
# error, or a 429 on the final attempt, propagates and aborts start-up.
_HF_MAX_ATTEMPTS = 5
_HF_BACKOFF_STEP_SECONDS = 150

for attempt in range(_HF_MAX_ATTEMPTS):
    try:
        ECHOPRIME_ROOT = Path(
            snapshot_download(
                repo_id=HF_REPO,
                repo_type="model",
                token=HF_TOKEN,
            )
        )
        break
    except Exception as e:
        out_of_attempts = attempt >= _HF_MAX_ATTEMPTS - 1
        if "429" not in str(e) or out_of_attempts:
            raise
        wait = _HF_BACKOFF_STEP_SECONDS * (attempt + 1)
        # N attempts leave room for N-1 retries, so report the count against that.
        print(
            f"[Setup] Rate limited, waiting {wait}s before retry "
            f"{attempt + 1}/{_HF_MAX_ATTEMPTS - 1}..."
        )
        time.sleep(wait)
print(f"[Setup] Model files at: {ECHOPRIME_ROOT}")
| # --------------------------------------------------------------------------- | |
| # Set working directory so utils.py can find assets/ via relative paths | |
| # --------------------------------------------------------------------------- | |
# Run from inside the downloaded repo so code that opens assets/ via relative
# paths resolves them, and expose the location through the environment too.
os.chdir(ECHOPRIME_ROOT)
os.environ["ECHOPRIME_ROOT_OVERRIDE"] = str(ECHOPRIME_ROOT)
# Resulting search order: "/app" first, then the model repo, then the rest.
sys.path[:0] = ["/app", str(ECHOPRIME_ROOT)]
| from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, Response | |
| import utils | |
| from model import EchoPrime | |
| from gradcam import MViTGradCAM, apply_heatmap_to_video | |
| from attribution import compute_section_attributions | |
| # --------------------------------------------------------------------------- | |
| # App & CORS | |
| # --------------------------------------------------------------------------- | |
app = FastAPI(title="EchoPrime API", version="1.0.0")

# NOTE(review): every origin is allowed together with credentials; tighten
# allow_origins to the real frontend origin(s) in production.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
| # --------------------------------------------------------------------------- | |
| # Global model instance (loaded once at startup) | |
| # --------------------------------------------------------------------------- | |
# Process-wide EchoPrime instance; stays None until load_model() has run.
_model: Optional[EchoPrime] = None


async def load_model() -> None:
    """Instantiate EchoPrime and publish it in the module-level ``_model``.

    The report language comes from the ``ECHOPRIME_LANG`` environment
    variable and defaults to ``"en"``.

    NOTE(review): no startup hook/decorator is visible on this coroutine in
    this view — confirm it is actually registered with the app.
    """
    global _model
    language = os.getenv("ECHOPRIME_LANG", "en")
    print(f"[Startup] Loading EchoPrime (lang={language})…")
    _model = EchoPrime(lang=language)
    print("[Startup] Model ready.")
def get_model() -> EchoPrime:
    """Return the loaded model.

    Raises:
        HTTPException: 503 while ``_model`` is still ``None``.
    """
    if _model is not None:
        return _model
    raise HTTPException(status_code=503, detail="Model not initialised yet.")
| # --------------------------------------------------------------------------- | |
| # File-based job store (avoids multi-replica memory isolation on HF Spaces) | |
| # --------------------------------------------------------------------------- | |
# One JSON file per job lives here; created eagerly at import time.
JOBS_DIR = Path("/tmp") / "echoprime_jobs"
JOBS_DIR.mkdir(exist_ok=True, parents=True)
def _save_job(job_id: str, job: Dict[str, Any]) -> None:
    """Persist *job* as ``JOBS_DIR/<job_id>.json``.

    The payload is written to a uniquely named scratch file in the same
    directory and then moved into place with ``os.replace``, so a reader
    (e.g. a status poll running while the analysis thread updates progress)
    sees either the previous or the new document, never a truncated one.

    Values that ``json`` cannot serialise natively are stringified
    (``default=str``).
    """
    target = JOBS_DIR / f"{job_id}.json"
    scratch = target.with_name(f".{target.name}.{uuid.uuid4().hex}.tmp")
    try:
        scratch.write_text(json.dumps(job, default=str))
        os.replace(scratch, target)
    finally:
        # Only still present if the write or the rename failed.
        scratch.unlink(missing_ok=True)
def _load_job(job_id: str) -> Optional[Dict[str, Any]]:
    """Read the stored job document, or return ``None`` if no file exists."""
    job_file = JOBS_DIR / f"{job_id}.json"
    return json.loads(job_file.read_text()) if job_file.exists() else None
def _delete_job(job_id: str) -> None:
    """Remove the job's JSON file; a missing file is not an error."""
    job_file = JOBS_DIR / f"{job_id}.json"
    if not job_file.exists():
        return
    job_file.unlink()
| # --------------------------------------------------------------------------- | |
| # Stack cache — stores video tensors on disk for GradCAM requests | |
| # --------------------------------------------------------------------------- | |
# Holds one "<study_id>.pt" tensor file per study (plus section JSON files).
STACKS_DIR = Path("/tmp") / "echoprime_stacks"
STACKS_DIR.mkdir(exist_ok=True, parents=True)
def _save_stack(study_id: str, stack) -> None:
    """Serialise *stack* (moved to CPU first) to ``STACKS_DIR/<study_id>.pt``."""
    destination = STACKS_DIR / f"{study_id}.pt"
    cpu_stack = stack.cpu()
    torch.save(cpu_stack, destination)
def _load_stack(study_id: str):
    """Load the cached stack for *study_id* onto the CPU, or ``None`` if absent.

    NOTE(review): ``torch.load`` unpickles the file; that is only safe because
    the file is produced by ``_save_stack`` in this same service.
    """
    stack_file = STACKS_DIR / f"{study_id}.pt"
    if stack_file.exists():
        return torch.load(stack_file, map_location="cpu")
    return None
def _delete_stack(study_id: str) -> None:
    """Remove the cached stack file; a missing file is not an error."""
    stack_file = STACKS_DIR / f"{study_id}.pt"
    if not stack_file.exists():
        return
    stack_file.unlink()
# Lazily created, process-wide GradCAM wrapper (see _get_gradcam).
_gradcam_instance: Optional[MViTGradCAM] = None


def _get_gradcam(ep: EchoPrime) -> MViTGradCAM:
    """Return the shared ``MViTGradCAM``, building it on first use.

    The wrapper is built around a deep copy of ``ep.echo_encoder`` that has
    been moved to the CPU and put in eval mode, so the encoder held by *ep*
    itself is left untouched.
    """
    global _gradcam_instance
    if _gradcam_instance is not None:
        return _gradcam_instance

    import copy

    encoder_copy = copy.deepcopy(ep.echo_encoder).cpu()
    encoder_copy.eval()
    _gradcam_instance = MViTGradCAM(encoder_copy)
    return _gradcam_instance
| # Mapping from view label → report section(s) it informs | |
| # Mirrors the logic already in study-results.tsx | |
_VIEW_SECTION_MAP: Dict[str, List[str]] = {
    "PSAX Pap Musc": ["lv"],
    "PSAX Apex": ["lv"],
    "PLAX": ["lv", "valves"],
    "PLAX Zoomed Out": ["valves", "pericardium"],
    "PLAX AV/MV": ["valves"],
    "A4C": ["rh", "valves"],
    "A4C LV": ["lv", "la"],
    "A2C": ["lv", "la"],
    "A2C LV": ["lv"],
    "Subcostal": ["rh", "pericardium"],
    "Suprasternal": ["la"],
}

# UI colour name for each report bucket.
_COLOR_MAP: Dict[str, str] = {
    "lv": "cyan",
    "la": "blue",
    "rh": "green",
    "valves": "purple",
    "pericardium": "amber",
}


def _sections_for_view(label: str) -> List[str]:
    """Return the report sections linked to a view label (default ``['lv']``).

    Matching is case-insensitive and happens in two passes:

    1. an exact match on the label;
    2. a substring match in either direction, trying the longest keys first.

    Both passes exist because several keys are prefixes of others ("A4C" vs
    "A4C LV", "PLAX" vs "PLAX Zoomed Out"); a plain first-substring scan
    would hand the specific views the sections of their shorter namesake.
    """
    needle = label.strip().lower()
    if not needle:
        return ["lv"]

    for key, sections in _VIEW_SECTION_MAP.items():
        if key.lower() == needle:
            return sections

    for key in sorted(_VIEW_SECTION_MAP, key=len, reverse=True):
        lowered_key = key.lower()
        if lowered_key in needle or needle in lowered_key:
            return _VIEW_SECTION_MAP[key]
    return ["lv"]


def _color_for_sections(sections: List[str]) -> str:
    """Return the colour of the first section, or ``"cyan"`` if none/unknown."""
    if not sections:
        return "cyan"
    return _COLOR_MAP.get(sections[0], "cyan")
| # --------------------------------------------------------------------------- | |
| # Report → section mapping | |
| # --------------------------------------------------------------------------- | |
| # EchoPrime returns a flat string with sections separated by [SEP]. | |
| # The section names match COARSE_VIEWS / MIL section names. | |
| # We map them to the five UI buckets: lv, la, rh, valves, pericardium. | |
_SECTION_BUCKET_MAP: Dict[str, str] = {
    "left ventricle": "lv",
    "resting segmental": "lv",
    "right ventricle": "rh",
    "left atrium": "la",
    "right atrium": "rh",
    "atrial septum": "rh",
    "mitral valve": "valves",
    "aortic valve": "valves",
    "tricuspid valve": "valves",
    "pulmonic valve": "valves",
    "pericardium": "pericardium",
    "aorta": "pericardium",
    "ivc": "rh",
    "pulmonary artery": "valves",
    "pulmonary veins": "la",
    "postoperative": "pericardium",
    # Russian translations
    "левый желудочек": "lv",
    "анализ сегментарной": "lv",
    "правый желудочек": "rh",
    "левое предсердие": "la",
    "правое предсердие": "rh",
    "межпредсердная перегородка": "rh",
    "митральный клапан": "valves",
    "аортальный клапан": "valves",
    "трёхстворчатый клапан": "valves",
    "клапан лёгочной": "valves",
    "перикард": "pericardium",
    "аорта": "pericardium",
    "нижняя полая вена": "rh",
    "лёгочная артерия": "valves",
    "лёгочные вены": "la",
    "послеоперационные": "pericardium",
}

_BUCKET_LABELS_EN: Dict[str, str] = {
    "lv": "Left Ventricle",
    "la": "Left Atrium",
    "rh": "Right Heart",
    "valves": "Valves",
    "pericardium": "Pericardium",
}


def _parse_report_to_sections(raw_report: str) -> Dict[str, str]:
    """Split the flat report text into the five UI buckets.

    The report is processed line by line (blank lines are skipped). For each
    line the section keyword that occurs *earliest* in the text decides the
    bucket, because the section header normally opens the paragraph; a
    keyword mentioned later (e.g. "left atrium" inside a mitral-valve
    paragraph) must not steal it. A line without any keyword stays in the
    bucket of the preceding line, starting with ``"lv"``.

    Returns:
        A dict with every key of ``_BUCKET_LABELS_EN``; each value is the
        bucket's lines joined with single spaces, or ``""`` if it got none.
    """
    buckets: Dict[str, List[str]] = {name: [] for name in _BUCKET_LABELS_EN}
    current_bucket = "lv"  # default until a header has been recognised

    for line in raw_report.split("\n"):
        para = line.strip()
        if not para:
            continue
        lowered = para.lower()

        # Rank hits by (position, -keyword length): earliest wins, and on a
        # tie the more specific (longer) keyword is preferred.
        best_rank = None
        best_bucket = None
        for keyword, bucket in _SECTION_BUCKET_MAP.items():
            position = lowered.find(keyword)
            if position < 0:
                continue
            rank = (position, -len(keyword))
            if best_rank is None or rank < best_rank:
                best_rank, best_bucket = rank, bucket

        if best_bucket is not None:
            current_bucket = best_bucket
        buckets[current_bucket].append(para)

    return {name: " ".join(lines) for name, lines in buckets.items()}
| # --------------------------------------------------------------------------- | |
| # Metrics → key cardiac values for the summary cards | |
| # --------------------------------------------------------------------------- | |
# (card id, accepted metric names in priority order, unit, normal range text)
_METRIC_CARD_KEYS = [
    ("ef", ["EF", "LVEF", "Ejection Fraction"], "%", "55-70%"),
    ("lvedv", ["LVEDV", "LV EDV"], "mL", "65-240 mL"),
    ("lvesv", ["LVESV", "LV ESV"], "mL", "16-143 mL"),
]


def _normalise_metric_name(name: str) -> str:
    """Lower-case *name* and drop spaces/underscores for tolerant matching."""
    return name.lower().replace(" ", "").replace("_", "")


def _extract_card_metrics(metrics: Dict[str, float]) -> List[Dict]:
    """Return up to three card-metric dicts for the UI summary strip.

    For every entry of ``_METRIC_CARD_KEYS`` the candidate names are tried in
    order against the metric keys (ignoring case, spaces and underscores).
    The first candidate with a usable value produces the card; its spelling
    becomes the card label and the value is rounded to one decimal.

    A value is usable when it converts to a finite ``float``. The conversion
    is done *before* the NaN test so that numpy/torch scalars — which are not
    ``float`` instances — are screened as well; letting NaN through would
    produce a payload that strict JSON encoders reject.
    """
    # Index once; setdefault keeps the first key when several normalise alike.
    by_name: Dict[str, Any] = {}
    for raw_key, raw_value in metrics.items():
        by_name.setdefault(_normalise_metric_name(raw_key), raw_value)

    cards = []
    for card_id, candidates, unit, normal in _METRIC_CARD_KEYS:
        for cand in candidates:
            found = by_name.get(_normalise_metric_name(cand))
            if found is None:
                continue
            try:
                value = float(found)
            except (TypeError, ValueError):
                continue
            if not math.isfinite(value):
                continue
            cards.append({
                "id": card_id,
                "label": cand,
                "value": round(value, 1),
                "unit": unit,
                "normal": normal,
            })
            break
    return cards
| # --------------------------------------------------------------------------- | |
| # Frame image cache { "<study_id>:<view_index>": base64_png_str } (in-memory, per process) | |
| # --------------------------------------------------------------------------- | |
# Keys are "<study_id>:<view_index>" strings, values are base64 PNG text.
_frame_cache: Dict[str, str] = {}


def _encode_frame_b64(
    img_tensor: torch.Tensor,
    label: str,
) -> str:
    """Render channel 0 of *img_tensor* as a labelled PNG, base64-encoded.

    The channel is min-max scaled to 0-255 (an all-constant image stays
    black), expanded to three channels, and *label* (underscores shown as
    spaces) is drawn near the top-left corner.
    """
    plane = img_tensor[0].cpu().numpy()
    plane = plane - plane.min()
    peak = plane.max()
    if peak > 0:
        plane = plane / peak
    plane_u8 = np.ascontiguousarray((plane * 255).astype(np.uint8))

    canvas = cv2.cvtColor(plane_u8, cv2.COLOR_GRAY2RGB)
    caption = label.replace("_", " ")
    cv2.putText(canvas, caption, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 220, 255), 2)

    # The canvas is treated as RGB; convert to OpenCV's BGR order for encoding.
    png = cv2.imencode(".png", cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))[1]
    return base64.b64encode(png.tobytes()).decode()
| # --------------------------------------------------------------------------- | |
| # Background analysis task | |
| # --------------------------------------------------------------------------- | |
def _ep_section_to_bucket(ep_sec: str) -> Optional[str]:
    """Map an EchoPrime section name to a UI bucket via ``_SECTION_BUCKET_MAP``."""
    low = ep_sec.lower()
    for kw, bucket in _SECTION_BUCKET_MAP.items():
        if kw in low:
            return bucket
    return None


def _metric_to_json(value: Any) -> Optional[float]:
    """Round a metric to 3 decimals, or return ``None`` if it is NaN/inf.

    Converting to ``float`` first means numpy/torch scalars are screened too.
    """
    number = float(value)
    return round(number, 3) if math.isfinite(number) else None


def _build_clip_attribution(ep, encoded, threshold: float = 0.35) -> List[Dict[str, Any]]:
    """Link every clip to UI buckets using the raw MIL section weights.

    Per the notes kept with this code, ``ep.section_weights`` mirrors
    MIL_weights.csv: rows are EP sections, columns are COARSE_VIEWS, values
    are raw (not view-normalised) attention weights in 0.0-1.0 expressing how
    important a view *type* is for a section. A clip's view type is the
    argmax of the columns of ``encoded`` from index 512 on (presumably the
    one-hot view encoding appended to the embedding — confirm).

    For each clip the weight of a bucket is the maximum over the EP sections
    that map to it; buckets at or above *threshold* are linked, strongest
    first. If none qualifies the single strongest bucket is kept, and if no
    section maps to any bucket the clip falls back to ``"lv"``.
    """
    clip_views = torch.argmax(encoded[:, 512:].cpu(), dim=1).numpy()

    # The section -> bucket mapping does not depend on the clip; resolve once.
    section_buckets = []
    for s_idx, ep_sec in enumerate(ep.non_empty_sections):
        bucket = _ep_section_to_bucket(str(ep_sec))
        if bucket is not None:
            section_buckets.append((s_idx, bucket))

    attributions: List[Dict[str, Any]] = []
    for raw_view in clip_views:
        view = int(raw_view)
        bucket_weights: Dict[str, float] = {}
        for s_idx, bucket in section_buckets:
            w = float(ep.section_weights[s_idx, view])
            if bucket not in bucket_weights or w > bucket_weights[bucket]:
                bucket_weights[bucket] = w

        linked = [
            b for b, w in sorted(bucket_weights.items(), key=lambda kv: -kv[1])
            if w >= threshold
        ]
        if not linked:
            if bucket_weights:
                linked = [max(bucket_weights, key=bucket_weights.get)]
            else:
                linked = ["lv"]

        attributions.append({
            "linkedSections": linked,
            # .get: the "lv" fallback has no weight when nothing was mapped.
            "sectionWeights": {b: round(bucket_weights.get(b, 0.0), 3) for b in linked},
        })
    return attributions


def _run_analysis(job_id: str, tmp_dir: str) -> None:
    """Run the full analysis for an uploaded study (background task).

    Progress, the current step and a stage message are written to the job
    file after every stage so that the status endpoint can report them. On
    success the job gets ``status="done"`` and a ``result`` payload; on any
    failure ``status="error"``, ``error`` and ``progress=0`` are stored and
    the exception is re-raised. The upload directory is removed either way.

    Side effects besides the job file: first-frame PNGs go into
    ``_frame_cache``, the video stack is saved through ``_save_stack`` and the
    per-section embeddings are written to ``<study_id>_sections.json`` in
    ``STACKS_DIR``.

    Args:
        job_id: Job identifier; also used as the study id.
        tmp_dir: Directory holding the uploaded files.
    """
    job = _load_job(job_id)
    if job is None:
        # The record vanished before we started; there is nothing to update.
        print(f"[Analysis] Job {job_id} not found; skipping analysis.")
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return

    def upd(progress: int, step: int, msg: str) -> None:
        job["progress"] = progress
        job["step"] = step
        job["stage_msg"] = msg
        _save_job(job_id, job)

    try:
        # Inside the try so that "model not ready" marks the job as failed
        # instead of leaving it queued forever.
        ep = get_model()

        job["status"] = "processing"
        _save_job(job_id, job)

        # Step 1 — already done (upload)
        upd(20, 1, "Preprocessing DICOM files…")

        # Step 2 — DICOM → tensor
        stack = ep.process_dicoms(tmp_dir)
        if stack.numel() == 0:
            raise ValueError("No valid DICOM video clips found in upload.")
        upd(40, 2, "Detecting cardiac views…")

        # Step 3 — view classification on each clip's first frame
        first_frames = stack[:, :, 0, :, :].to(ep.device)
        with torch.no_grad():
            logits = ep.view_classifier(first_frames)
        view_indices = torch.argmax(logits, dim=1)
        view_labels = [utils.COARSE_VIEWS[v] for v in view_indices]
        # Confidence = softmax probability of the winning class
        probs = torch.softmax(logits, dim=1)
        confidences = [float(probs[i, v]) for i, v in enumerate(view_indices)]

        # Cache labelled first frames as base64 PNGs
        study_id = job_id
        for idx, label in enumerate(view_labels):
            _frame_cache[f"{study_id}:{idx}"] = _encode_frame_b64(first_frames[idx].cpu(), label)
        upd(55, 3, "Encoding study…")

        # Step 3b — encode study, and keep the stack for GradCAM requests
        encoded = ep.encode_study(stack, visualize=False)
        _save_stack(job_id, stack)
        upd(70, 4, "Predicting cardiac metrics…")

        # Step 4 — metrics (NaN/inf become None so the payload is valid JSON)
        raw_metrics = ep.predict_metrics(encoded)
        metrics_json = {k: _metric_to_json(v) for k, v in raw_metrics.items()}
        upd(85, 5, "Generating clinical report…")

        # Step 5 — report
        raw_report = ep.generate_report(encoded)
        clean_report = raw_report.replace("[SEP]", "\n").replace(" .", ".")
        sections = _parse_report_to_sections(clean_report)

        # Per-clip view → section attribution from the raw MIL weights
        clip_attribution = _build_clip_attribution(ep, encoded)
        views_payload = []
        for idx, (label, conf) in enumerate(zip(view_labels, confidences)):
            linked = clip_attribution[idx]["linkedSections"]
            views_payload.append({
                "id": str(idx),
                "label": label.replace("_", " "),
                "confidence": round(conf, 4),
                "linkedSections": linked,
                "sectionWeights": clip_attribution[idx]["sectionWeights"],
                "colorCode": _color_for_sections(linked),
                "frameIndex": idx,
            })

        # Section embeddings consumed later by the GradCAM endpoint
        section_attrs = compute_section_attributions(
            study_embedding=encoded,
            candidate_embeddings=ep.candidate_embeddings,
            section_weights=ep.section_weights,
            non_empty_sections=ep.non_empty_sections,
            view_labels=view_labels,
            k=50,
        )
        section_embeddings = {
            sec: attrs.section_embedding.tolist()
            for sec, attrs in section_attrs.items()
        }
        (STACKS_DIR / f"{study_id}_sections.json").write_text(json.dumps(section_embeddings))

        # Publish completion only once everything above has succeeded, and in
        # a single write so pollers never see "100 %" without a result.
        job["result"] = {
            "studyId": study_id,
            "views": views_payload,
            "report": sections,
            "rawReport": clean_report,
            "metrics": metrics_json,
            "cardMetrics": _extract_card_metrics(raw_metrics),
            "createdAt": time.strftime("%Y-%m-%d %H:%M"),
        }
        job["progress"] = 100
        job["step"] = 5
        job["stage_msg"] = "Analysis complete."
        job["status"] = "done"
        _save_job(job_id, job)
    except Exception as exc:
        job["status"] = "error"
        # HTTPException keeps its message in .detail; prefer it when present.
        job["error"] = str(getattr(exc, "detail", None) or exc)
        job["progress"] = 0
        _save_job(job_id, job)
        raise
    finally:
        # Clean up uploaded files (best effort)
        shutil.rmtree(tmp_dir, ignore_errors=True)
| # --------------------------------------------------------------------------- | |
| # Routes | |
| # --------------------------------------------------------------------------- | |
def _is_safe_job_id(candidate: str) -> bool:
    """Return True if *candidate* is safe to embed in file names.

    Accepts 1-128 ASCII letters/digits plus ``-``, ``_`` and ``.``, and must
    not start with a dot. This excludes path separators and ``..``.
    """
    if not 0 < len(candidate) <= 128:
        return False
    if not candidate.isascii() or candidate.startswith("."):
        return False
    return all(ch.isalnum() or ch in "-_." for ch in candidate)


def _safe_upload_path(filename: Optional[str], fallback: str) -> Path:
    """Turn a client-supplied file name into a relative path without escapes.

    Both ``/`` and ``\\`` are treated as separators; empty, ``.`` and ``..``
    components are dropped, which also makes absolute names relative.
    Sub-folders are kept so same-named files from different folders do not
    collide. *fallback* is used when nothing usable remains.
    """
    parts = [
        part
        for part in (filename or "").replace("\\", "/").split("/")
        if part not in ("", ".", "..")
    ]
    return Path(*parts) if parts else Path(fallback)


async def upload_study(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    external_id: Optional[str] = Form(None),
) -> JSONResponse:
    """
    Accept DICOM files (or a ZIP), save to a temp folder, and kick off
    background analysis. Returns {jobId, studyId} immediately (HTTP 202) so
    the frontend can start polling the job status.

    ``external_id``, when given, becomes the job id; because that id is used
    in file names it is validated first. Client file names are sanitised so
    every upload lands inside the job's temp folder.

    Raises:
        HTTPException: 400 if no files were sent or ``external_id`` is unsafe.

    NOTE(review): no route decorator is visible on this handler in this view;
    the module docstring lists it as POST /api/upload — confirm registration.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided.")

    if external_id:
        if not _is_safe_job_id(external_id):
            raise HTTPException(status_code=400, detail="Invalid external_id.")
        job_id = external_id
    else:
        job_id = str(uuid.uuid4())

    tmp_dir = tempfile.mkdtemp(prefix=f"echoprime_{job_id}_")
    tmp_root = Path(tmp_dir)

    # Save all uploads; drop the folder again if anything goes wrong.
    try:
        for index, f in enumerate(files):
            dest = tmp_root / _safe_upload_path(f.filename, f"upload_{index}.dcm")
            dest.parent.mkdir(parents=True, exist_ok=True)
            dest.write_bytes(await f.read())
    except Exception:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    _save_job(job_id, {
        "status": "queued",
        "progress": 5,
        "step": 1,
        "stage_msg": "Files received, queuing analysis…",
        "study_id": job_id,
        "error": None,
        "result": None,
        "tmp_dir": tmp_dir,
    })
    background_tasks.add_task(_run_analysis, job_id, tmp_dir)
    return JSONResponse({"jobId": job_id, "studyId": job_id}, status_code=202)
async def get_status(job_id: str) -> JSONResponse:
    """
    Report the current state of an analysis job.

    Response shape:
    {
        "status": "queued" | "processing" | "done" | "error",
        "progress": 0-100,
        "step": 1-5,
        "stageMsg": str,
        "studyId": str,
        "error": str | null
    }

    Raises:
        HTTPException: 404 if no job with this id is stored.

    NOTE(review): no route decorator is visible on this handler in this view;
    the module docstring lists it as GET /api/status/{job_id} — confirm.
    """
    job = _load_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found.")

    # (response key, stored key)
    exposed = (
        ("status", "status"),
        ("progress", "progress"),
        ("step", "step"),
        ("stageMsg", "stage_msg"),
        ("studyId", "study_id"),
        ("error", "error"),
    )
    return JSONResponse({public: job[stored] for public, stored in exposed})
async def get_study(study_id: str) -> JSONResponse:
    """
    Return the stored result payload of a finished study.

    Response shape:
    {
        "studyId": str,
        "createdAt": str,
        "views": [
            {
                "id": str, "label": str, "confidence": float,
                "linkedSections": [str], "colorCode": str, "frameIndex": int
            }, …
        ],
        "report": {
            "lv": str, "la": str, "rh": str, "valves": str, "pericardium": str
        },
        "rawReport": str,
        "metrics": { phenotype: value|null, … },
        "cardMetrics": [
            { "id": str, "label": str, "value": float, "unit": str, "normal": str },
            …
        ]
    }

    Raises:
        HTTPException: 404 for an unknown study, 500 (with the stored error
            text) for a failed job, 202 while the job is not yet done.

    NOTE(review): no route decorator is visible on this handler in this view;
    the module docstring lists it as GET /api/study/{study_id} — confirm.
    """
    job = _load_job(study_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Study not found.")

    state = job["status"]
    if state == "done":
        return JSONResponse(job["result"])
    if state == "error":
        raise HTTPException(status_code=500, detail=job["error"])
    raise HTTPException(status_code=202, detail="Analysis still in progress.")
async def get_view_frame(study_id: str, frame_index: int) -> Response:
    """
    Serve the labelled first-frame image of one view as a raw PNG, suitable
    for use as <img src="/api/view/{studyId}/{i}/frame" />.

    The image is looked up in the in-process ``_frame_cache``.

    Raises:
        HTTPException: 404 if the frame is not in the cache.

    NOTE(review): no route decorator is visible on this handler in this view
    — confirm it is registered.
    """
    cached = _frame_cache.get(f"{study_id}:{frame_index}")
    if cached is None:
        raise HTTPException(status_code=404, detail="Frame not found.")
    return Response(content=base64.b64decode(cached), media_type="image/png")
async def delete_study(study_id: str) -> JSONResponse:
    """Delete everything stored for a study.

    Removes the job file, the cached video stack, the section-embedding JSON
    and all of the study's entries in the in-process frame cache.

    Raises:
        HTTPException: 404 if no job file exists for *study_id*.

    NOTE(review): no route decorator is visible on this handler in this view;
    the module docstring lists it as DELETE /api/study/{study_id} — confirm.
    """
    if _load_job(study_id) is None:
        raise HTTPException(status_code=404, detail="Study not found.")

    _delete_job(study_id)
    _delete_stack(study_id)
    (STACKS_DIR / f"{study_id}_sections.json").unlink(missing_ok=True)

    # Snapshot the matching keys first; the dict is mutated while dropping.
    prefix = f"{study_id}:"
    for cache_key in [key for key in _frame_cache if key.startswith(prefix)]:
        _frame_cache.pop(cache_key, None)

    return JSONResponse({"deleted": study_id})
async def compute_gradcam(study_id: str, view_idx: int, section: str) -> JSONResponse:
    """Compute a GradCAM overlay for one clip of a study and one section.

    Uses the video stack and section embeddings cached by the analysis task.
    The response carries up to 16 JPEG frames (base64) with the heatmap
    blended over the clip.

    Raises:
        HTTPException: 404 if the cached stack or section embeddings are
            missing, 400 for an unknown section or an out-of-range
            ``view_idx``, 500 if GradCAM itself fails.

    NOTE(review): no route decorator is visible on this handler in this view
    — confirm it is registered. The computation runs inline in this
    coroutine, so other requests wait while it is in progress.
    """
    ep = get_model()
    stack = _load_stack(study_id)
    if stack is None:
        raise HTTPException(status_code=404, detail="Study video data not found. Re-upload the study.")
    sections_path = STACKS_DIR / f"{study_id}_sections.json"
    if not sections_path.exists():
        raise HTTPException(status_code=404, detail="Section embeddings not found.")
    section_embeddings = json.loads(sections_path.read_text())
    if section not in section_embeddings:
        available = list(section_embeddings.keys())
        raise HTTPException(status_code=400, detail=f"Section not found. Available: {available}")
    if view_idx < 0 or view_idx >= stack.shape[0]:
        raise HTTPException(status_code=400, detail=f"view_idx {view_idx} out of range")

    # Per-channel statistics used to map the stored clip back to 0-255 pixel
    # values (presumably the ones applied during preprocessing — confirm).
    mean = np.array([29.110628, 28.076836, 29.096405])
    std = np.array([47.989223, 46.456997, 47.20083])
    max_frames = 16

    try:
        target = torch.tensor(section_embeddings[section], dtype=torch.float32)
        video = stack[view_idx:view_idx + 1].cpu()
        gradcam = _get_gradcam(ep)

        # Gradients are switched on only for the duration of this call.
        for p in gradcam.model.parameters():
            p.requires_grad_(True)
        try:
            heatmap = gradcam.generate(video, target, output_size=(16, 224, 224))
        finally:
            for p in gradcam.model.parameters():
                p.requires_grad_(False)

        raw_video = stack[view_idx].cpu()
        n_frames = min(raw_video.shape[1], max_frames)
        frames = [
            np.clip(
                raw_video[:, t, :, :].permute(1, 2, 0).numpy() * std + mean, 0, 255
            ).astype(np.uint8)
            for t in range(n_frames)
        ]
        frames_np = np.stack(frames)
        overlay = apply_heatmap_to_video(frames_np, heatmap[:len(frames)], alpha=0.4)

        encoded_frames = []
        for t in range(overlay.shape[0]):
            _, buf = cv2.imencode(
                ".jpg",
                cv2.cvtColor(overlay[t], cv2.COLOR_RGB2BGR),
                [cv2.IMWRITE_JPEG_QUALITY, 85],
            )
            encoded_frames.append(base64.b64encode(buf.tobytes()).decode())

        return JSONResponse({
            "studyId": study_id,
            "viewIdx": view_idx,
            "section": section,
            "numFrames": len(encoded_frames),
            "frames": encoded_frames,
        })
    except Exception as exc:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"GradCAM failed: {exc}") from exc
async def list_gradcam_sections(study_id: str) -> JSONResponse:
    """List the section names for which GradCAM embeddings were stored.

    Raises:
        HTTPException: 404 if the study has no section-embedding file.

    NOTE(review): no route decorator is visible on this handler in this view
    — confirm it is registered.
    """
    sections_path = STACKS_DIR / f"{study_id}_sections.json"
    if not sections_path.exists():
        raise HTTPException(status_code=404, detail="Study not found.")
    stored = json.loads(sections_path.read_text())
    return JSONResponse({"sections": list(stored.keys())})
async def health() -> JSONResponse:
    """Liveness probe; ``modelReady`` tells whether the model has been loaded.

    NOTE(review): no route decorator is visible on this handler in this view
    — confirm it is registered.
    """
    model_ready = _model is not None
    return JSONResponse({"status": "ok", "modelReady": model_ready})