|
|
|
|
|
|
|
|
|
|
|
import io
import os
import base64
import glob
import hashlib
import tempfile
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any

import requests
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image
|
|
|
|
|
|
|
|
# Optional ML stack. When torch (or torchvision) cannot be imported the module
# still loads; the affected names are set to None and build_model() /
# make_transform() / load_state() check for that, so tasks end up without a
# local model.
TORCH_AVAILABLE = False
# True when this torchvision exposes the *_Weights enums (newer API). build_model()
# uses it to choose between `weights=None` and the older `pretrained=False`.
_TV_WEIGHTS_ENUM = False
try:
    import torch
    # Set as soon as torch imports, even if the torchvision imports below fail.
    TORCH_AVAILABLE = True
    try:
        from torchvision import transforms as T
        from torchvision.models import resnet50, mobilenet_v3_large
        try:
            # Imported only to detect the enum-based API; the enums themselves are
            # not referenced elsewhere in this file.
            from torchvision.models import ResNet50_Weights, MobileNet_V3_Large_Weights
            _TV_WEIGHTS_ENUM = True
        except Exception:
            ResNet50_Weights = None
            MobileNet_V3_Large_Weights = None
            _TV_WEIGHTS_ENUM = False
    except Exception:
        # torch without a usable torchvision.
        T = None
        resnet50 = mobilenet_v3_large = None
except Exception:
    # Broad catch: importing torch can fail with errors other than ImportError.
    torch = None
    T = None
    resnet50 = mobilenet_v3_large = None
|
|
|
|
|
|
|
|
# Tasks served when RETINA_TASKS is not set.
DEFAULT_TASKS = ["dr"]

# Default class labels per task. The Persian and English tables are
# index-aligned: entry i of one is the translation of entry i of the other.
TASK_DEFAULT_CLASSES_FA: Dict[str, List[str]] = {
    "dr": ["بدون DR", "خفیف", "متوسط", "شدید", "پرولیفراکتیو"],
    "oct_cme": ["بدون CME", "CME"],
    "oct_csr": ["بدون CSR", "CSR"],
    "oct_amd": ["بدون AMD", "خشک", "تر"],
    "glaucoma": ["نرمال", "گلوکوم"],
    "keratoconus": ["نرمال", "کراتوکونوس"],
}

TASK_DEFAULT_CLASSES_EN: Dict[str, List[str]] = {
    "dr": ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"],
    "oct_cme": ["No CME", "CME"],
    "oct_csr": ["No CSR", "CSR"],
    "oct_amd": ["No AMD", "Dry", "Wet"],
    "glaucoma": ["Normal", "Glaucoma"],
    "keratoconus": ["Normal", "Keratoconus"],
}

# Default input crop size (pixels): 448 for DR, 416 for every other task.
# Key order follows TASK_DEFAULT_CLASSES_EN.
TASK_DEFAULT_IMG: Dict[str, int] = {
    task: (448 if task == "dr" else 416) for task in TASK_DEFAULT_CLASSES_EN
}

# Every task defaults to the ResNet-50 backbone.
TASK_DEFAULT_MODEL: Dict[str, str] = dict.fromkeys(TASK_DEFAULT_CLASSES_EN, "resnet50")
|
|
|
|
|
|
|
|
# Root directory searched for local checkpoints.
DEFAULT_WEIGHTS_DIR = os.getenv("RETINA_WEIGHTS_DIR", "/app/models")

# Glob patterns (relative to DEFAULT_WEIGHTS_DIR) tried for each task. Each task
# has a preferred best.pth, then any .pth in its own folder, then any .pth in
# the root. DR's preferred file lives under a training-run directory instead.
WEIGHT_PATTERNS = {
    task: [
        "runs_k80/phase2/best.pth" if task == "dr" else f"{task}/best.pth",
        f"{task}/*.pth",
        "*.pth",
    ]
    for task in ("dr", "oct_cme", "oct_csr", "oct_amd", "glaucoma", "keratoconus")
}
|
|
|
|
|
def _find_candidate_weights(task: str) -> List[str]:
    """Return existing weight files for ``task``, newest (by mtime) first.

    Every pattern in ``WEIGHT_PATTERNS[task]`` (default ``["*.pth"]``) is globbed
    under ``DEFAULT_WEIGHTS_DIR``. Duplicates are removed and non-files dropped.

    NOTE(review): because the result is re-sorted by modification time, the
    order of the patterns does not express priority.
    """
    base_dir = Path(DEFAULT_WEIGHTS_DIR)
    hits: List[str] = []
    for pattern in WEIGHT_PATTERNS.get(task, ["*.pth"]):
        hits.extend(glob.glob(str(base_dir / pattern)))

    def _mtime(candidate: str):
        as_path = Path(candidate)
        return as_path.stat().st_mtime if as_path.exists() else 0

    newest_first = sorted(set(hits), key=_mtime, reverse=True)
    return list(filter(lambda candidate: Path(candidate).is_file(), newest_first))
|
|
|
|
|
def _download(url: str, dest: Path, sha256: Optional[str] = None) -> Path:
    """Stream ``url`` into ``dest``, optionally verifying its SHA-256 digest.

    The payload goes to a ``.part`` temporary file in ``dest.parent``. It is
    moved onto ``dest`` only after the download (and checksum, if requested)
    succeeded, so ``dest`` is never left half-written. The temporary file is
    removed on any failure.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final file path; parent directories are created as needed.
        sha256: Expected hex digest; case and surrounding whitespace are ignored.

    Returns:
        ``dest``.

    Raises:
        requests.RequestException: On connection or HTTP errors.
        RuntimeError: If the computed digest does not match ``sha256``.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    digest = hashlib.sha256()
    tmp_path: Optional[Path] = None
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with tempfile.NamedTemporaryFile(delete=False, dir=str(dest.parent), suffix=".part") as tmp:
                # Record the path right away so a failure mid-stream can clean it up.
                tmp_path = Path(tmp.name)
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    tmp.write(chunk)
                    digest.update(chunk)
        if sha256 and digest.hexdigest().lower() != sha256.strip().lower():
            raise RuntimeError(f"SHA256 mismatch for {url}")
        tmp_path.replace(dest)
    except BaseException:
        # Previously only the checksum path cleaned up; a network error left a stray .part file.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
        raise
    return dest
|
|
|
|
|
def _pick_weight(task: str) -> Tuple[Optional[str], List[str]]:
    """Choose the weights file for ``task``.

    Resolution order:
      1. ``RETINA_WEIGHTS_<task>``, if it names an existing file.
      2. The newest local match from ``_find_candidate_weights``.
      3. A download from ``RETINA_WEIGHTS_URL_<task>`` into
         ``<DEFAULT_WEIGHTS_DIR>/<task>/best.pth``. It is checked against
         ``RETINA_WEIGHTS_SHA256_<task>`` when that is set.

    Returns:
        ``(chosen_path_or_None, all_candidates)``. A failed download is printed
        and yields ``(None, [])``.
    """
    override = os.getenv(f"RETINA_WEIGHTS_{task}")
    if override and Path(override).is_file():
        return override, [override]

    local = _find_candidate_weights(task)
    if local:
        return local[0], local

    remote = os.getenv(f"RETINA_WEIGHTS_URL_{task}")
    checksum = os.getenv(f"RETINA_WEIGHTS_SHA256_{task}")
    if not remote:
        return None, []

    target = Path(DEFAULT_WEIGHTS_DIR) / task / "best.pth"
    try:
        print(f"[weights] downloading {task} from {remote} → {target}")
        fetched = str(_download(remote, target, sha256=checksum))
    except Exception as exc:
        print(f"[weights] download failed for {task}: {exc}")
        return None, []
    return fetched, [fetched]
|
|
|
|
|
|
|
|
def device_setup() -> str:
    """Return ``"cuda"`` when torch sees a GPU, else ``"cpu"``.

    Side effect: on the CUDA path cuDNN is globally disabled.
    NOTE(review): the reason is not visible here (presumably a workaround) — confirm.
    """
    use_cuda = TORCH_AVAILABLE and torch.cuda.is_available()
    if not use_cuda:
        return "cpu"
    torch.backends.cudnn.enabled = False
    return "cuda"
|
|
|
|
|
def build_model(name: str, num_classes: int):
    """Build an untrained torchvision backbone with a ``num_classes`` output head.

    Args:
        name: ``"resnet50"``/``"resnet"`` or ``"mobilenetv3"``/``"mobilenet_v3"``/``"mbv3"``
            (case-insensitive).
        num_classes: Size of the replaced final linear layer.

    Raises:
        RuntimeError: If torch/torchvision are unavailable.
        ValueError: For an unrecognised ``name``.
    """
    if not (TORCH_AVAILABLE and resnet50 and mobilenet_v3_large):
        raise RuntimeError("Torch/torchvision not available in this runtime.")

    arch = name.lower()
    # Newer torchvision takes `weights=`, older takes `pretrained=`; both mean "no pretrained weights".
    init_kwargs = {"weights": None} if _TV_WEIGHTS_ENUM else {"pretrained": False}

    if arch in ("resnet50", "resnet"):
        import torch.nn as nn
        net = resnet50(**init_kwargs)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net

    if arch in ("mobilenetv3", "mobilenet_v3", "mbv3"):
        import torch.nn as nn
        net = mobilenet_v3_large(**init_kwargs)
        old_head = net.classifier[3]
        net.classifier[3] = nn.Linear(old_head.in_features, num_classes)
        return net

    raise ValueError(f"Unknown model: {arch}")
|
|
|
|
|
def make_transform(img_size: int):
    """Return the inference preprocessing pipeline for ``img_size``.

    The image is resized to 1.15x ``img_size``, center-cropped to ``img_size``,
    converted to a tensor and normalised with ImageNet mean/std. Without
    torch/torchvision an identity function is returned instead.
    """
    if TORCH_AVAILABLE and T:
        resize_to = int(img_size * 1.15)
        steps = [
            T.Resize(resize_to),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        return T.Compose(steps)

    def _noop(x):
        return x

    return _noop
|
|
|
|
|
def load_state(model, weights_path: str):
    """Load the checkpoint at ``weights_path`` into ``model`` (non-strict).

    Accepted checkpoint layouts:
      - a bare state dict;
      - a dict wrapping it under ``"model"`` (checked first);
      - a dict wrapping it under ``"state_dict"``.

    A leading ``"module."`` prefix is stripped from every key (presumably added
    by DataParallel/DDP wrappers).

    Returns:
        ``(missing_keys, unexpected_keys)`` as lists, from ``load_state_dict``.

    Raises:
        RuntimeError: If torch is unavailable, or (raised by ``load_state_dict``
            even with ``strict=False``) when a key's tensor shape does not match.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("Torch not available for loading state.")
    ckpt = torch.load(weights_path, map_location='cpu')

    state = ckpt
    if isinstance(ckpt, dict):
        if "model" in ckpt:
            state = ckpt["model"]
        elif "state_dict" in ckpt:
            # Previously this layout loaded nothing (strict=False) and silently left
            # the network randomly initialised.
            state = ckpt["state_dict"]

    new_state = {(k[7:] if k.startswith("module.") else k): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(new_state, strict=False)
    return list(missing), list(unexpected)
|
|
|
|
|
@dataclass
class TaskModel:
    """Everything needed to serve one task: model, preprocessing and metadata.

    ``model`` is ``None`` when no local model was loaded; the endpoints then
    proxy requests for the task to the configured remote service.
    """

    name: str  # task key, e.g. "dr" or "oct_cme"
    model: Optional[Any]  # torch module in eval mode, or None
    device: str  # "cuda" or "cpu", as returned by device_setup()
    img_size: int  # center-crop size used by `transform`
    classes_fa: List[str]  # Persian labels, index-aligned with classes_en
    classes_en: List[str]  # English labels; their count sizes the model head
    weights_path: Optional[str]  # checkpoint path chosen for this task, if any
    missing_keys: List[str]  # from load_state_dict(strict=False)
    unexpected_keys: List[str]  # from load_state_dict(strict=False)
    transform: Any  # preprocessing callable applied to the RGB PIL image
    # All weight files considered for the task. Declared here (instead of only
    # being attached after construction) so it always exists with a sane default.
    _all_weight_candidates: List[str] = field(default_factory=list)
|
|
|
|
|
def env_list(key: str, default: Optional[List[str]] = None) -> List[str]:
    """Read a comma-separated list from environment variable ``key``.

    Items are stripped and empty items dropped. When the variable is unset or
    empty, a *copy* of ``default`` (or ``[]``) is returned. Previously the caller's
    list object itself was returned, aliasing module constants such as DEFAULT_TASKS.
    """
    raw = os.getenv(key)
    if not raw:
        return list(default) if default else []
    return [x.strip() for x in raw.split(",") if x.strip()]
|
|
|
|
|
def parse_classes_env(task: str) -> Optional[List[str]]:
    """Return the class labels from ``RETINA_CLASSES_<task>``, or ``None``.

    The variable is a comma-separated list. Items are stripped and blanks
    ignored. ``None`` is returned when the variable is unset, empty, or yields
    no labels.
    """
    raw = os.getenv(f"RETINA_CLASSES_{task}")
    if not raw:
        return None
    labels = list(filter(None, (part.strip() for part in raw.split(","))))
    return labels if labels else None
|
|
|
|
|
def prepare_task(task: str, device: str) -> TaskModel:
    """Resolve the configuration for ``task`` and load its local model if possible.

    Reads ``RETINA_MODEL_<task>``, ``RETINA_IMG_SIZE_<task>`` and
    ``RETINA_CLASSES_<task>``, falling back to the module defaults, then picks
    weights via ``_pick_weight``. The returned ``TaskModel`` has ``model=None``
    when any of the following holds, and the endpoints then proxy the task
    remotely:
      - torch is unavailable;
      - no weights file exists;
      - building or loading the model fails.

    Args:
        task: Task key, e.g. ``"dr"``.
        device: ``"cuda"`` or ``"cpu"``.

    Returns:
        The populated ``TaskModel``, with ``_all_weight_candidates`` set.

    Raises:
        ValueError: If ``RETINA_IMG_SIZE_<task>`` is not an integer.
    """
    model_name = os.getenv(f"RETINA_MODEL_{task}", TASK_DEFAULT_MODEL.get(task, "resnet50"))
    img_size = int(os.getenv(f"RETINA_IMG_SIZE_{task}", str(TASK_DEFAULT_IMG.get(task, 416))))

    env_classes = parse_classes_env(task)
    classes_en = env_classes or TASK_DEFAULT_CLASSES_EN.get(task, ["Negative", "Positive"])
    classes_fa_default = TASK_DEFAULT_CLASSES_FA.get(task, ["منفی", "مثبت"])
    # Custom English classes keep the default Persian labels only when the counts
    # match; otherwise the English labels are reused as the Persian ones.
    if env_classes and len(classes_fa_default) != len(classes_en):
        classes_fa = classes_en
    else:
        classes_fa = classes_fa_default

    weights, all_cands = _pick_weight(task)

    def _unloaded() -> TaskModel:
        # Placeholder without a local model.
        tm = TaskModel(task, None, device, img_size, classes_fa, classes_en, weights or None,
                       [], [], make_transform(img_size))
        tm._all_weight_candidates = all_cands
        return tm

    if (not TORCH_AVAILABLE) or (not weights) or (not os.path.isfile(weights)):
        return _unloaded()

    try:
        m = build_model(model_name, num_classes=len(classes_en))
        missing, unexpected = load_state(m, weights)
        m.eval().to(device)
        if device == 'cuda':
            m.to(memory_format=torch.channels_last)
    except Exception as e:
        # A corrupt or incompatible checkpoint used to raise at import time and
        # take down every task; now only this task loses its local model.
        print(f"[weights] failed to load {task} from {weights}: {e}")
        return _unloaded()

    tm = TaskModel(task, m, device, img_size, classes_fa, classes_en, weights, missing, unexpected,
                   make_transform(img_size))
    tm._all_weight_candidates = all_cands
    return tm
|
|
|
|
|
def predict_with_task(task_obj: TaskModel, pil_im: Image.Image) -> List[float]:
    """Run the task's local model on ``pil_im`` and return softmax probabilities.

    The image is converted to RGB and passed through ``task_obj.transform``. It
    is batched as a single sample and moved to ``task_obj.device``.

    Returns:
        One probability per class, in the model's output order (aligned with
        ``task_obj.classes_en``).

    Raises:
        RuntimeError: If torch is unavailable or the task has no local model.
    """
    if (not TORCH_AVAILABLE) or (task_obj.model is None):
        raise RuntimeError("Local model not available.")
    x = task_obj.transform(pil_im.convert("RGB")).unsqueeze(0)
    x = x.to(task_obj.device, non_blocking=True)
    with torch.no_grad():
        logits = task_obj.model(x)
        probs = torch.softmax(logits, dim=1)[0]
    # Tensor.tolist() yields the same Python floats as .numpy().tolist(), without
    # requiring numpy (which _simple_qc treats as optional).
    return probs.cpu().tolist()
|
|
|
|
|
|
|
|
def _remote_base_for(task: str) -> Optional[str]:
    """Return the raw ``RETINA_REMOTE_<task>`` value (unstripped), or ``None`` if unset."""
    env_key = f"RETINA_REMOTE_{task}"
    return os.environ.get(env_key)
|
|
|
|
|
def _remote_auth_header_for(task: str) -> dict:
    """Build the ``Authorization`` header for proxying ``task``.

    ``RETINA_REMOTE_AUTH_<task>`` takes precedence over ``RETINA_REMOTE_AUTH``.
    The value is sent verbatim apart from surrounding whitespace, which is
    stripped. Previously a trailing newline (common with secret files) was sent
    and made requests reject the header. Returns ``{}`` when no token is configured.
    """
    token = (os.getenv(f"RETINA_REMOTE_AUTH_{task}") or os.getenv("RETINA_REMOTE_AUTH") or "").strip()
    return {"Authorization": token} if token else {}
|
|
|
|
|
def _remote_verify_ssl() -> bool:
    """Whether proxied HTTPS calls verify certificates (``RETINA_REMOTE_VERIFY_SSL``).

    Verification is on unless the variable is ``0``, ``false`` or ``no``
    (case-insensitive, whitespace ignored). Unset or empty means on.
    """
    flag = os.getenv("RETINA_REMOTE_VERIFY_SSL") or "true"
    disabled_values = {"0", "false", "no"}
    return flag.strip().lower() not in disabled_values
|
|
|
|
|
def _remote_timeout() -> int:
    """Timeout in seconds for proxied calls (``RETINA_REMOTE_TIMEOUT``).

    The value must parse as an int; anything else falls back to 90.
    """
    raw_value = os.getenv("RETINA_REMOTE_TIMEOUT", "90")
    try:
        seconds = int(raw_value)
    except Exception:
        seconds = 90
    return seconds
|
|
|
|
|
def _remote_url(task: str, mode: str) -> Optional[str]:
    """Build the remote endpoint URL for ``task``.

    Args:
        task: Task key, sent as the ``task`` query parameter.
        mode: ``"predict"`` selects ``/predict_task``; anything else selects ``/report_task``.

    Returns:
        ``<base>/<endpoint>?task=<task>``, or ``None`` when ``RETINA_REMOTE_<task>``
        is unset or blank.

    The base may be given with or without a trailing ``/predict_task`` or
    ``/report_task`` (and trailing slashes); that suffix is replaced by the
    endpoint for ``mode``. Previously such a base was returned verbatim, so
    report calls hit the predict endpoint and the ``task`` parameter was lost.
    """
    base = _remote_base_for(task)
    if not base:
        return None
    base = base.strip().rstrip("/")
    if not base:
        return None
    for suffix in ("/predict_task", "/report_task"):
        if base.endswith(suffix):
            base = base[: -len(suffix)].rstrip("/")
            break
    endpoint = "predict_task" if mode == "predict" else "report_task"
    return f"{base}/{endpoint}?task={task}"
|
|
|
|
|
def _proxy_predict_task(task: str, file_bytes: bytes, filename: str = "image.jpg") -> JSONResponse:
    """Forward a prediction request for ``task`` to its remote service.

    The image is posted as multipart field ``file``. Its content type is
    guessed from ``filename``, defaulting to ``image/jpeg``. A JSON reply is
    returned as-is; a non-JSON 2xx reply is wrapped as ``{"remote_raw": text}``.

    Raises:
        HTTPException: 501 if no remote is configured; the remote's status code
            on a non-2xx reply; 502 if the request itself fails.
    """
    import mimetypes

    url = _remote_url(task, "predict")
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")

    # Callers pass UploadFile.filename, which may be None/empty. requests would
    # then send the part without a filename and the remote would not treat it as a file.
    filename = filename or "image.jpg"
    content_type = mimetypes.guess_type(filename)[0] or "image/jpeg"
    try:
        r = requests.post(
            url,
            files={"file": (filename, file_bytes, content_type)},
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {e}") from e

    if not (200 <= r.status_code < 300):
        raise HTTPException(status_code=r.status_code, detail=f"Remote error: {r.text}")
    try:
        return JSONResponse(r.json())
    except ValueError:
        # Not JSON (or not re-serialisable, e.g. NaN): hand back the raw body.
        return JSONResponse({"remote_raw": r.text})
|
|
|
|
|
def _proxy_report_task(task: str, file_bytes: bytes, form: dict, filename: str = "image.jpg") -> JSONResponse:
    """Forward a report request for ``task`` to its remote service.

    The image is posted as multipart field ``file``, with ``form`` sent as
    extra form fields. The content type is guessed from ``filename``, defaulting
    to ``image/jpeg``. A JSON reply is returned as-is; a non-JSON 2xx reply is
    wrapped as ``{"remote_raw": text}``.

    Raises:
        HTTPException: 501 if no remote is configured; the remote's status code
            on a non-2xx reply; 502 if the request itself fails.
    """
    import mimetypes

    url = _remote_url(task, "report")
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")

    # Callers pass UploadFile.filename, which may be None/empty. requests would
    # then send the part without a filename and the remote would not treat it as a file.
    filename = filename or "image.jpg"
    content_type = mimetypes.guess_type(filename)[0] or "image/jpeg"
    try:
        r = requests.post(
            url,
            files={"file": (filename, file_bytes, content_type)},
            data=form,
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {e}") from e

    if not (200 <= r.status_code < 300):
        raise HTTPException(status_code=r.status_code, detail=f"Remote error: {r.text}")
    try:
        return JSONResponse(r.json())
    except ValueError:
        # Not JSON (or not re-serialisable, e.g. NaN): hand back the raw body.
        return JSONResponse({"remote_raw": r.text})
|
|
|
|
|
|
|
|
app = FastAPI(title="Retina Multi-Task Inference API (Unified)", version="1.3.1")

# NOTE(review): wildcard origins together with allow_credentials=True may allow
# credentialed cross-site browser requests; confirm against Starlette's
# CORSMiddleware semantics and tighten if cookies/auth are ever relied on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)

# Module state built once at import time: device choice and one TaskModel per task.
_DEVICE = device_setup()
# Endpoints look tasks up via task.strip().lower(), so configured names are
# normalised the same way, and duplicates are dropped (first-seen order kept).
# Previously e.g. RETINA_TASKS="DR" loaded a model no request could ever reach.
_TASKS = list(dict.fromkeys(t.lower() for t in env_list("RETINA_TASKS", DEFAULT_TASKS)))
_TASK_MODELS: Dict[str, TaskModel] = {t: prepare_task(t, _DEVICE) for t in _TASKS}
# Shown on the index page only; the fallback endpoints re-read RETINA_DEFAULT_TASK per request.
DEFAULT_FALLBACK_TASK = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
|
|
|
|
|
|
|
|
def _simple_qc(im: Image.Image) -> dict:
    """Cheap image quality check: resolution and mean brightness.

    Returns a dict with ``width``, ``height``, ``mean_luma`` (mean of the
    grayscale image, rounded to 0.1), ``warnings`` and ``ok``. Possible warnings:
      - ``low_resolution``: shorter side < 512;
      - ``too_dark``: mean luma < 25;
      - ``too_bright``: mean luma > 230.

    Without numpy, brightness is skipped: ``mean_luma`` is ``None`` and ``ok`` is ``True``.
    """
    width, height = im.size
    try:
        import numpy as np
    except Exception:
        return {"width": width, "height": height, "mean_luma": None, "warnings": [], "ok": True}

    luma = float(np.array(im.convert("L")).mean())
    checks = (
        (min(width, height) < 512, "low_resolution"),
        (luma < 25, "too_dark"),
        (luma > 230, "too_bright"),
    )
    found = [label for failed, label in checks if failed]
    return {"width": width, "height": height, "mean_luma": round(luma, 1), "warnings": found, "ok": not found}
|
|
|
|
|
def _items_from_probs(task: str, probs: List[float]):
    """Label ``probs`` with the task's class names and rank them.

    ``probs[i]`` is taken to belong to class ``i`` of both label lists.

    Returns:
        ``(items_sorted, top1)``. ``items_sorted`` is a list of
        ``{"index", "class_en", "class_fa", "prob"}`` dicts, highest probability
        first; ties keep their class order. ``top1`` is its first element.

    Raises:
        KeyError: Unknown ``task``.
        IndexError: If ``probs`` is empty or longer than the label lists.
    """
    model_info = _TASK_MODELS[task]
    labelled = (
        {
            "index": idx,
            "class_en": model_info.classes_en[idx],
            "class_fa": model_info.classes_fa[idx],
            "prob": float(p),
        }
        for idx, p in enumerate(probs)
    )
    ranked = sorted(labelled, key=lambda entry: entry["prob"], reverse=True)
    return ranked, ranked[0]
|
|
|
|
|
def _format_report(task: str, probs: List[float], patient_name: str = "", exam_date: str = "", eye: str = "") -> str:
    """Render a Persian plain-text report for one prediction.

    The report contains, in order:
      - a per-task title with the patient's name and the exam date ('—' when empty);
      - the eye line, only if ``eye`` is given;
      - the top-1 class;
      - the full probability list, highest first;
      - a task-family disclaimer, for DR, OCT, glaucoma and keratoconus tasks.

    Raises:
        KeyError: Unknown ``task`` (raised by ``_items_from_probs``).
    """
    # The unused `tm = _TASK_MODELS[task]` lookup was removed; _items_from_probs
    # performs the same lookup (and raises the same KeyError).
    items, top = _items_from_probs(task, probs)
    title_map = {
        "dr": "گزارش رتینوپاتی دیابتی (DR)",
        "oct_cme": "گزارش OCT - CME",
        "oct_csr": "گزارش OCT - CSR",
        "oct_amd": "گزارش OCT - AMD",
        "glaucoma": "گزارش گلوکوم",
        "keratoconus": "گزارش کراتوکونوس",
    }
    title = title_map.get(task, f"گزارش {task}")
    lines: List[str] = []
    lines.append(f"👁 {title} برای بیمار: {patient_name or '—'}")
    lines.append(f"📅 تاریخ معاینه: {exam_date or '—'}")
    if eye:
        lines.append(f"👓 چشم: {eye}")
    lines.append("________________________________________")
    lines.append("📌 نتیجه الگوریتم (Top-1):")
    lines.append(f"• {top['class_fa']} ({top['class_en']}) — احتمال {top['prob']:.3f}")
    lines.append("📊 توزیع احتمالات:")
    for it in items:
        lines.append(f"• {it['class_fa']} ({it['class_en']}) — {it['prob']:.4f}")
    # Task-family disclaimer (decision-support only).
    if task == "dr":
        lines.append("🧠 یادداشت: نتیجه برای کمک به تصمیمگیری است؛ در موارد مثبت معاینه بالینی/تصویربرداری تکمیلی توصیه میشود.")
    elif task.startswith("oct_"):
        lines.append("🧠 یادداشت: تفسیر نهایی با همبستگی بالینی و تصاویر مکمل.")
    elif task in ("glaucoma", "keratoconus"):
        lines.append("🧠 یادداشت: جایگزین تشخیص پزشک نیست و باید با پاراکلینیک تلفیق شود.")
    return "\n".join(lines)
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def root():
    """Render a minimal HTML status page with upload forms.

    It lists every configured task with its load state, a back-compat
    ``/predict`` form, and one ``/predict_task`` form per configured task.
    Previously a hard-coded oct_cme form was shown, which returned 404 unless
    oct_cme was listed in RETINA_TASKS. Environment-derived names are
    HTML-escaped, and URL-quoted inside form actions.
    """
    from html import escape
    from urllib.parse import quote

    li = "".join(
        f"<li>{escape(t)} — loaded={_TASK_MODELS[t].model is not None} — img={_TASK_MODELS[t].img_size}</li>"
        for t in _TASKS
    )
    task_forms = "".join(
        f"""
<hr/>
<form action="/predict_task?task={escape(quote(t))}" method="post" enctype="multipart/form-data">
<div><b>{escape(t)}</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict_task?task={escape(t)}</button>
</form>"""
        for t in _TASKS
    )
    return f"""
<html><head><meta charset="utf-8"><title>Retina Unified API</title></head>
<body style="font-family:Tahoma,Arial,sans-serif">
<h2>Retina Multi-Task Predictor (Single Port)</h2>
<p>Device: <b>{escape(_DEVICE)}</b> | Tasks: {escape(", ".join(_TASKS))}</p>
<ul>{li}</ul>
<h3>Quick Forms</h3>
<form action="/predict" method="post" enctype="multipart/form-data">
<div><b>Back-compat /predict (RETINA_DEFAULT_TASK = {escape(DEFAULT_FALLBACK_TASK)})</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict</button>
</form>
{task_forms}
</body></html>
"""
|
|
|
|
|
|
|
|
@app.get("/tasks")
def tasks():
    """Describe every configured task: load state, labels, weights and remote URL."""

    def _describe(task_name, info) -> dict:
        # One task's summary; key order matches the public JSON shape.
        return {
            "loaded": info.model is not None,
            "img_size": info.img_size,
            "classes_en": info.classes_en,
            "classes_fa": info.classes_fa,
            "weights_used": info.weights_path,
            "weights_candidates": getattr(info, "_all_weight_candidates", []),
            "missing_keys": info.missing_keys,
            "unexpected_keys": info.unexpected_keys,
            "remote_url": _remote_url(task_name, "predict"),
        }

    return {task_name: _describe(task_name, info) for task_name, info in _TASK_MODELS.items()}
|
|
|
|
|
@app.get("/health")
def health():
    """Report the device, CUDA/cuDNN status, and which tasks have a local model."""
    has_torch = TORCH_AVAILABLE and torch is not None
    cuda_ok = bool(has_torch and torch.cuda.is_available())
    cudnn_on = bool(has_torch and torch.backends.cudnn.enabled)
    loaded = {name: info.model is not None for name, info in _TASK_MODELS.items()}
    return {
        "device": _DEVICE,
        "cuda": cuda_ok,
        "cudnn_enabled": cudnn_on,
        "tasks": list(_TASK_MODELS),
        "loaded": loaded,
    }
|
|
|
|
|
|
|
|
@app.post("/predict_task")
def predict_task(
    task: str = Query(..., description="dr, oct_cme, oct_csr, oct_amd, glaucoma, keratoconus"),
    file: UploadFile = File(...)
):
    """Classify an uploaded image with the model for ``task``.

    Tasks without a local model are proxied to ``RETINA_REMOTE_<task>``.

    Raises:
        HTTPException: 404 for an unknown task; 400 if the upload cannot be
            read or decoded as an image; proxy errors otherwise.
    """
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    tm = _TASK_MODELS[task]

    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        # UploadFile.filename may be None/empty; never forward that.
        return _proxy_predict_task(task, raw, filename=getattr(file, "filename", None) or "image.jpg")

    try:
        im = Image.open(io.BytesIO(raw))
        # Image.open is lazy: decode now so corrupt/truncated data is a 400 here
        # rather than a 500 raised later inside the model call.
        im.load()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return JSONResponse({
        "task": task,
        "qc": qc,
        "top1": top1,
        "probs": items_sorted,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })
|
|
|
|
|
@app.post("/report_task")
def report_task(
    task: str = Query(...),
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("")
):
    """Classify an uploaded image for ``task`` and add a Persian text report.

    Tasks without a local model are proxied, with the patient form fields, to
    ``RETINA_REMOTE_<task>``.

    Raises:
        HTTPException: 404 for an unknown task; 400 if the upload cannot be
            read or decoded as an image; proxy errors otherwise.
    """
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    tm = _TASK_MODELS[task]

    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        # UploadFile.filename may be None/empty; never forward that.
        return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", None) or "image.jpg")

    try:
        im = Image.open(io.BytesIO(raw))
        # Image.open is lazy: decode now so corrupt/truncated data is a 400 here
        # rather than a 500 raised later inside the model call.
        im.load()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    report_fa = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)

    return JSONResponse({
        "task": task,
        "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
        "qc": qc, "top1": top1, "probs": items_sorted,
        "report": report_fa,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })
|
|
|
|
|
|
|
|
class PredictJsonReq(BaseModel):
    """Request body for ``/predict_json``."""

    image_b64: str  # base64-encoded image file bytes
|
|
|
|
|
def _get_fallback_task() -> str:
    """Return the task used by the legacy endpoints (``RETINA_DEFAULT_TASK``, default ``dr``).

    The variable is re-read on every call and normalised with strip/lower.

    Raises:
        HTTPException: 404 if that task is not configured.
    """
    name = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
    if name in _TASK_MODELS:
        return name
    raise HTTPException(status_code=404, detail=f"Unknown default task: {name}")
|
|
|
|
|
@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Back-compat prediction on the default task (``RETINA_DEFAULT_TASK``).

    Raises:
        HTTPException: 404 if the default task is not configured; 400 if the
            upload cannot be read or decoded as an image; proxy errors otherwise.
    """
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        # UploadFile.filename may be None/empty; never forward that.
        return _proxy_predict_task(task, raw, filename=getattr(file, "filename", None) or "image.jpg")

    try:
        im = Image.open(io.BytesIO(raw))
        # Image.open is lazy: decode now so corrupt/truncated data is a 400 here
        # rather than a 500 raised later inside the model call.
        im.load()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}
|
|
|
|
|
@app.post("/predict_json")
def predict_json(req: PredictJsonReq):
    """Back-compat prediction on the default task from a base64 JSON payload.

    ``image_b64`` may be bare base64 or a ``data:<mime>;base64,<payload>`` URL.

    Raises:
        HTTPException: 404 if the default task is not configured; 400 for
            undecodable base64 or image data; proxy errors otherwise.
    """
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]

    payload = req.image_b64
    # Strip a data-URL prefix. b64decode (non-validating) would otherwise fold
    # the prefix's alphabet characters into the payload and corrupt it.
    if payload.lstrip().startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        data = base64.b64decode(payload)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image")

    if tm.model is None:
        return _proxy_predict_task(task, data, filename="image.jpg")

    try:
        im = Image.open(io.BytesIO(data))
        # Image.open is lazy: decode now so corrupt/truncated data is a 400 here
        # rather than a 500 raised later inside the model call.
        im.load()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image data")

    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}
|
|
|
|
|
@app.post("/report")
def report(
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("OD")
):
    """Back-compat report on the default task (``RETINA_DEFAULT_TASK``).

    Returns the prediction plus a Persian text report. Without a local model
    the request (including the patient form fields) is proxied remotely.

    Raises:
        HTTPException: 404 if the default task is not configured; 400 if the
            upload cannot be read or decoded as an image; proxy errors otherwise.
    """
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        # UploadFile.filename may be None/empty; never forward that.
        return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", None) or "image.jpg")

    try:
        im = Image.open(io.BytesIO(raw))
        # Image.open is lazy: decode now so corrupt/truncated data is a 400 here
        # rather than a 500 raised later inside the model call.
        im.load()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    rep = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
    return {"task": task,
            "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
            "qc": qc, "top1": top1, "probs": items_sorted,
            "report": rep}
|
|
|
|
|
@app.post("/predict_strict")
def predict_strict(file: UploadFile = File(...), tta: int = 1):
    """Alias of ``/predict`` kept for older clients.

    ``tta`` is accepted so existing callers keep working, but it is ignored:
    the request is handled exactly like ``/predict``.
    """
    response = predict(file)
    return response
|
|
|