# NOTE: the lines below were web-UI residue captured when this file was copied
# from its hosting page; kept as a comment so the module remains valid Python.
# Source: eye-api/retina_api_multi.py — uploaded by hadi6681, commit 16baeec (verified).
#!/usr/bin/env python3
# Retina/eye multi-task inference API (single-port, Torch-optional)
import io
import os
import base64
import glob
import hashlib
import tempfile
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Any
import requests
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image
# -------------------- Torch / Torchvision (optional) --------------------
# Torch/torchvision are imported best-effort: without them the API still
# serves requests by proxying to a remote backend (see the _proxy_* helpers).
TORCH_AVAILABLE = False    # True iff `import torch` succeeded
_TV_WEIGHTS_ENUM = False   # True iff torchvision exposes the new Weights enums
try:
    import torch  # type: ignore
    TORCH_AVAILABLE = True
    try:
        # import torchvision only if torch is OK
        from torchvision import transforms as T  # type: ignore
        from torchvision.models import resnet50, mobilenet_v3_large  # type: ignore
        try:
            # Newer torchvision uses `weights=` enums instead of `pretrained=`.
            from torchvision.models import ResNet50_Weights, MobileNet_V3_Large_Weights  # type: ignore
            _TV_WEIGHTS_ENUM = True
        except Exception:
            ResNet50_Weights = None  # type: ignore
            MobileNet_V3_Large_Weights = None  # type: ignore
            _TV_WEIGHTS_ENUM = False
    except Exception:
        # torchvision was not available either
        T = None  # type: ignore
        resnet50 = mobilenet_v3_large = None  # type: ignore
except Exception:
    torch = None  # type: ignore
    T = None  # type: ignore
    resnet50 = mobilenet_v3_large = None  # type: ignore
# -------------------- Defaults per task --------------------
DEFAULT_TASKS = ["dr"]
TASK_DEFAULT_CLASSES_FA: Dict[str, List[str]] = {
"dr": ["بدون DR", "خفیف", "متوسط", "شدید", "پرولیفراکتیو"],
"oct_cme": ["بدون CME", "CME"],
"oct_csr": ["بدون CSR", "CSR"],
"oct_amd": ["بدون AMD", "خشک", "تر"],
"glaucoma": ["نرمال", "گلوکوم"],
"keratoconus": ["نرمال", "کراتوکونوس"],
}
TASK_DEFAULT_CLASSES_EN: Dict[str, List[str]] = {
"dr": ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"],
"oct_cme": ["No CME", "CME"],
"oct_csr": ["No CSR", "CSR"],
"oct_amd": ["No AMD", "Dry", "Wet"],
"glaucoma": ["Normal", "Glaucoma"],
"keratoconus": ["Normal", "Keratoconus"],
}
TASK_DEFAULT_IMG: Dict[str, int] = {
"dr": 448,
"oct_cme": 416,
"oct_csr": 416,
"oct_amd": 416,
"glaucoma": 416,
"keratoconus": 416,
}
TASK_DEFAULT_MODEL: Dict[str, str] = {
"dr": "resnet50",
"oct_cme": "resnet50",
"oct_csr": "resnet50",
"oct_amd": "resnet50",
"glaucoma": "resnet50",
"keratoconus": "resnet50",
}
# -------------------- Weights: autodiscovery / optional download --------------------
DEFAULT_WEIGHTS_DIR = os.getenv("RETINA_WEIGHTS_DIR", "/app/models")
WEIGHT_PATTERNS = {
"dr": ["runs_k80/phase2/best.pth", "dr/*.pth", "*.pth"],
"oct_cme": ["oct_cme/best.pth", "oct_cme/*.pth", "*.pth"],
"oct_csr": ["oct_csr/best.pth", "oct_csr/*.pth", "*.pth"],
"oct_amd": ["oct_amd/best.pth", "oct_amd/*.pth", "*.pth"],
"glaucoma": ["glaucoma/best.pth", "glaucoma/*.pth", "*.pth"],
"keratoconus": ["keratoconus/best.pth", "keratoconus/*.pth", "*.pth"],
}
def _find_candidate_weights(task: str) -> List[str]:
    """Glob candidate .pth files for *task* under the weights dir, newest first."""
    root = Path(DEFAULT_WEIGHTS_DIR)
    hits: set = set()
    for pattern in WEIGHT_PATTERNS.get(task, ["*.pth"]):
        hits.update(glob.glob(str(root / pattern)))

    def _mtime(candidate: str) -> float:
        # Missing paths sort last (mtime 0) instead of raising.
        cp = Path(candidate)
        return cp.stat().st_mtime if cp.exists() else 0

    newest_first = sorted(hits, key=_mtime, reverse=True)
    return [c for c in newest_first if Path(c).is_file()]
def _download(url: str, dest: Path, sha256: Optional[str] = None) -> Path:
    """Stream *url* into *dest* atomically; optionally verify a SHA-256 digest.

    Downloads into a sibling ".part" temp file so a partial transfer never
    clobbers an existing destination, then renames into place.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        digest = hashlib.sha256()
        with tempfile.NamedTemporaryFile(delete=False, dir=str(dest.parent), suffix=".part") as part:
            for block in resp.iter_content(chunk_size=1024 * 1024):
                if block:  # skip keep-alive chunks
                    part.write(block)
                    digest.update(block)
        part_path = Path(part.name)
        if sha256 and digest.hexdigest().lower() != sha256.lower():
            part_path.unlink(missing_ok=True)
            raise RuntimeError(f"SHA256 mismatch for {url}")
        part_path.replace(dest)
    return dest
def _pick_weight(task: str) -> Tuple[Optional[str], List[str]]:
    """Resolve the weights file for *task*.

    Resolution order:
      1. explicit path in RETINA_WEIGHTS_<task> (must exist on disk),
      2. newest on-disk candidate from _find_candidate_weights,
      3. one-shot download from RETINA_WEIGHTS_URL_<task>, optionally
         verified against RETINA_WEIGHTS_SHA256_<task>.

    Returns (chosen_path_or_None, all_candidate_paths).
    """
    env_path = os.getenv(f"RETINA_WEIGHTS_{task}")
    if env_path and Path(env_path).is_file():
        return env_path, [env_path]
    cands = _find_candidate_weights(task)
    if cands:
        return cands[0], cands
    url = os.getenv(f"RETINA_WEIGHTS_URL_{task}")
    sha = os.getenv(f"RETINA_WEIGHTS_SHA256_{task}")
    if url:
        dest = Path(DEFAULT_WEIGHTS_DIR) / task / "best.pth"
        try:
            # FIX: the log line used to concatenate {url}{dest} with no
            # separator, producing an unreadable message.
            print(f"[weights] downloading {task} from {url} -> {dest}")
            got = _download(url, dest, sha256=sha)
            return str(got), [str(got)]
        except Exception as e:
            # Best-effort: a failed download degrades to remote-proxy mode.
            print(f"[weights] download failed for {task}: {e}")
    return None, []
# -------------------- Utils (Torch-aware) --------------------
def device_setup() -> str:
    """Pick the inference device string; disables cuDNN when CUDA is selected."""
    if not (TORCH_AVAILABLE and torch.cuda.is_available()):  # type: ignore
        return "cpu"
    # cuDNN is deliberately disabled here; presumably for determinism or a
    # known kernel issue — TODO confirm the original motivation.
    torch.backends.cudnn.enabled = False  # type: ignore
    return "cuda"
def build_model(name: str, num_classes: int):
    """Instantiate an untrained backbone with a `num_classes`-way classification head.

    Supports resnet50 and mobilenet_v3_large; raises RuntimeError without
    torch/torchvision and ValueError for an unknown architecture name.
    """
    if not (TORCH_AVAILABLE and resnet50 and mobilenet_v3_large):
        raise RuntimeError("Torch/torchvision not available in this runtime.")
    import torch.nn as nn  # local import (only when torch exists)

    arch = name.lower()
    if arch in ("resnet50", "resnet"):
        # Newer torchvision rejects `pretrained=`; pick the supported kwarg.
        net = resnet50(weights=None) if _TV_WEIGHTS_ENUM else resnet50(pretrained=False)  # type: ignore
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net
    if arch in ("mobilenetv3", "mobilenet_v3", "mbv3"):
        net = mobilenet_v3_large(weights=None) if _TV_WEIGHTS_ENUM else mobilenet_v3_large(pretrained=False)  # type: ignore
        net.classifier[3] = nn.Linear(net.classifier[3].in_features, num_classes)
        return net
    raise ValueError(f"Unknown model: {arch}")
def make_transform(img_size: int):
    """Return the eval-time preprocessing pipeline (identity when Torch is absent)."""
    if not (TORCH_AVAILABLE and T):
        # Without Torch this path is never actually exercised; keep a passthrough.
        return lambda x: x
    return T.Compose([
        T.Resize(int(img_size * 1.15)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        # ImageNet normalization constants, matching the pretrained backbones.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
def load_state(model, weights_path: str):
    """Load a checkpoint into *model*, tolerating DataParallel `module.` prefixes.

    Accepts either a raw state_dict or a dict wrapping it under "model".
    Returns (missing_keys, unexpected_keys) from the non-strict load.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("Torch not available for loading state.")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    ckpt = torch.load(weights_path, map_location='cpu')  # type: ignore
    state = ckpt.get("model", ckpt)
    cleaned = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in state.items()
    }
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    return list(missing), list(unexpected)
@dataclass
class TaskModel:
    """Everything needed to serve one task, whether locally or via remote proxy."""
    name: str                    # task identifier, e.g. "dr"
    model: Optional[Any]         # torch module when loaded locally, else None (proxy mode)
    device: str                  # "cuda" or "cpu"
    img_size: int                # input resolution fed to the transform
    classes_fa: List[str]        # Persian class labels, index-aligned with model output
    classes_en: List[str]        # English class labels
    weights_path: Optional[str]  # checkpoint actually used, if any
    missing_keys: List[str]      # state_dict keys absent from the checkpoint
    unexpected_keys: List[str]   # checkpoint keys the model did not expect
    transform: Any               # preprocessing callable (see make_transform)
def env_list(key: str, default: Optional[List[str]] = None) -> List[str]:
    """Read env var *key* as a comma-separated list; fall back to *default* (or [])."""
    raw = os.getenv(key)
    if not raw:
        return default or []
    pieces = (piece.strip() for piece in raw.split(","))
    return [piece for piece in pieces if piece]
def parse_classes_env(task: str) -> Optional[List[str]]:
    """Parse RETINA_CLASSES_<task> into class names, or None if unset/empty."""
    raw = os.getenv(f"RETINA_CLASSES_{task}")
    if not raw:
        return None
    names = [part.strip() for part in raw.split(",") if part.strip()]
    return names if names else None
def prepare_task(task: str, device: str) -> TaskModel:
    """Build the TaskModel for *task*: resolve config/classes/weights and, when
    Torch plus a weights file are available, load the local model onto *device*.
    Otherwise returns a model-less TaskModel that will be served via proxy."""
    model_name = os.getenv(f"RETINA_MODEL_{task}", TASK_DEFAULT_MODEL.get(task, "resnet50"))
    img_size = int(os.getenv(f"RETINA_IMG_SIZE_{task}", str(TASK_DEFAULT_IMG.get(task, 416))))
    classes_en = parse_classes_env(task) or TASK_DEFAULT_CLASSES_EN.get(task, ["Negative","Positive"])
    classes_fa_default = TASK_DEFAULT_CLASSES_FA.get(task, ["منفی","مثبت"])
    # If classes were overridden via env, keep the Persian defaults only while
    # their length still matches the English list; otherwise mirror English.
    classes_fa = classes_fa_default if not parse_classes_env(task) else (
        classes_fa_default if len(classes_fa_default)==len(classes_en) else classes_en
    )
    weights, all_cands = _pick_weight(task)
    # If torch/torchvision is unavailable or no weights exist, do not load a local model.
    if (not TORCH_AVAILABLE) or (not weights) or (not os.path.isfile(weights)):
        tm = TaskModel(task, None, device, img_size, classes_fa, classes_en, weights if weights else None,
                       [], [], make_transform(img_size))
        tm._all_weight_candidates = all_cands  # type: ignore
        return tm
    m = build_model(model_name, num_classes=len(classes_en))
    missing, unexpected = load_state(m, weights)
    m.eval().to(device)
    if device == 'cuda':
        # channels_last can speed up convolutions on GPU.
        m.to(memory_format=torch.channels_last)  # type: ignore
    tm = TaskModel(task, m, device, img_size, classes_fa, classes_en, weights, missing, unexpected, make_transform(img_size))
    tm._all_weight_candidates = all_cands  # type: ignore
    return tm
def predict_with_task(task_obj: TaskModel, pil_im: Image.Image) -> List[float]:
    """Run one image through the task's local model; returns softmax probabilities."""
    if not TORCH_AVAILABLE or task_obj.model is None:
        raise RuntimeError("Local model not available.")
    batch = task_obj.transform(pil_im.convert("RGB")).unsqueeze(0)
    batch = batch.to(task_obj.device, non_blocking=True)
    with torch.no_grad():  # type: ignore
        logits = task_obj.model(batch)
        return torch.softmax(logits, dim=1)[0].detach().cpu().numpy().tolist()  # type: ignore
# -------------------- Remote proxy helpers --------------------
def _remote_base_for(task: str) -> Optional[str]:
return os.getenv(f"RETINA_REMOTE_{task}")
def _remote_auth_header_for(task: str) -> dict:
token = os.getenv(f"RETINA_REMOTE_AUTH_{task}") or os.getenv("RETINA_REMOTE_AUTH") or ""
return {"Authorization": token} if token.strip() else {}
def _remote_verify_ssl() -> bool:
v = (os.getenv("RETINA_REMOTE_VERIFY_SSL") or "true").strip().lower()
return v not in ("0", "false", "no")
def _remote_timeout() -> int:
try:
return int(os.getenv("RETINA_REMOTE_TIMEOUT", "90"))
except Exception:
return 90
def _remote_url(task: str, mode: str) -> Optional[str]:
    """Full remote endpoint URL for *task* and *mode* ("predict" or "report").

    The env var may hold either a base URL (endpoint path gets appended) or a
    complete /predict_task or /report_task URL, which is returned verbatim —
    note that in the verbatim case *mode* is ignored by design.
    """
    base = _remote_base_for(task)
    if not base:
        return None
    base = base.strip()
    if base.endswith(("/predict_task", "/report_task")):
        return base
    endpoint = "predict_task" if mode == "predict" else "report_task"
    return f"{base.rstrip('/')}/{endpoint}?task={task}"
def _proxy_predict_task(task: str, file_bytes: bytes, filename: str = "image.jpg") -> JSONResponse:
    """Forward a prediction request for *task* to its configured remote backend.

    Raises 501 when no remote is configured, 502 on transport failure, and
    mirrors the remote's own status code on non-2xx responses.
    """
    url = _remote_url(task, "predict")
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")
    try:
        resp = requests.post(
            url,
            files={"file": (filename, file_bytes, "image/jpeg")},
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {exc}")
    if not (200 <= resp.status_code < 300):
        raise HTTPException(status_code=resp.status_code, detail=f"Remote error: {resp.text}")
    try:
        return JSONResponse(resp.json())
    except Exception:
        # Remote replied with non-JSON content; pass it through raw.
        return JSONResponse({"remote_raw": resp.text})
def _proxy_report_task(task: str, file_bytes: bytes, form: dict, filename: str = "image.jpg") -> JSONResponse:
    """Forward a report request (image plus patient form fields) to the remote backend.

    Raises 501 when no remote is configured, 502 on transport failure, and
    mirrors the remote's own status code on non-2xx responses.
    """
    url = _remote_url(task, "report")
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")
    try:
        resp = requests.post(
            url,
            files={"file": (filename, file_bytes, "image/jpeg")},
            data=form,
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {exc}")
    if not (200 <= resp.status_code < 300):
        raise HTTPException(status_code=resp.status_code, detail=f"Remote error: {resp.text}")
    try:
        return JSONResponse(resp.json())
    except Exception:
        # Non-JSON body from the remote — return it verbatim.
        return JSONResponse({"remote_raw": resp.text})
# -------------------- App --------------------
app = FastAPI(title="Retina Multi-Task Inference API (Unified)", version="1.3.1")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# very permissive CORS policy — tighten for production if credentials are used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)
_DEVICE = device_setup()  # "cuda" or "cpu"
_TASKS = env_list("RETINA_TASKS", DEFAULT_TASKS)  # tasks to expose
# Eagerly prepare every task at import time (may download/load weights).
_TASK_MODELS: Dict[str, TaskModel] = {t: prepare_task(t, _DEVICE) for t in _TASKS}
# NOTE(review): read once at startup, while _get_fallback_task() re-reads the
# env var per request — the two can diverge if the environment changes later.
DEFAULT_FALLBACK_TASK = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
# -------------------- Helpers for QC/format --------------------
def _simple_qc(im: Image.Image) -> dict:
    """Cheap quality checks: resolution plus mean luma; degrades gracefully
    (luma reported as None, no warnings) when numpy is unavailable."""
    width, height = im.size
    try:
        import numpy as np  # lazy
    except Exception:
        return {"width": width, "height": height, "mean_luma": None, "warnings": [], "ok": True}
    luma = float(np.array(im.convert("L")).mean())
    flags: List[str] = []
    if min(width, height) < 512:
        flags.append("low_resolution")
    if luma < 25:
        flags.append("too_dark")
    if luma > 230:
        flags.append("too_bright")
    return {"width": width, "height": height, "mean_luma": round(luma, 1), "warnings": flags, "ok": not flags}
def _items_from_probs(task: str, probs: List[float]):
    """Pair each probability with its EN/FA class names.

    Returns (items sorted by descending probability, the top-1 item).
    """
    tm = _TASK_MODELS[task]
    entries = []
    for idx, p in enumerate(probs):
        entries.append({
            "index": idx,
            "class_en": tm.classes_en[idx],
            "class_fa": tm.classes_fa[idx],
            "prob": float(p),
        })
    entries.sort(key=lambda entry: entry["prob"], reverse=True)
    return entries, entries[0]
def _format_report(task: str, probs: List[float], patient_name: str = "", exam_date: str = "", eye: str = "") -> str:
    """Render the Persian-language clinical report for *task* from class probabilities.

    Includes patient metadata, the top-1 prediction, the full probability
    distribution, and a task-specific disclaimer note. All emitted strings are
    runtime output and intentionally remain in Persian.
    """
    tm = _TASK_MODELS[task]  # NOTE(review): `tm` is assigned but unused here
    items, top = _items_from_probs(task, probs)
    # Report title per task (runtime strings — not comments, do not translate).
    title_map = {
        "dr": "گزارش رتینوپاتی دیابتی (DR)",
        "oct_cme": "گزارش OCT - CME",
        "oct_csr": "گزارش OCT - CSR",
        "oct_amd": "گزارش OCT - AMD",
        "glaucoma": "گزارش گلوکوم",
        "keratoconus": "گزارش کراتوکونوس",
    }
    title = title_map.get(task, f"گزارش {task}")
    lines: List[str] = []
    lines.append(f"👁 {title} برای بیمار: {patient_name or '—'}")
    lines.append(f"📅 تاریخ معاینه: {exam_date or '—'}")
    if eye: lines.append(f"👓 چشم: {eye}")
    lines.append("________________________________________")
    lines.append("📌 نتیجه الگوریتم (Top-1):")
    lines.append(f"• {top['class_fa']} ({top['class_en']}) — احتمال {top['prob']:.3f}")
    lines.append("📊 توزیع احتمالات:")
    for it in items:
        lines.append(f"• {it['class_fa']} ({it['class_en']}) — {it['prob']:.4f}")
    # Task-specific disclaimer footer.
    if task == "dr":
        lines.append("🧠 یادداشت: نتیجه برای کمک به تصمیم‌گیری است؛ در موارد مثبت معاینه بالینی/تصویربرداری تکمیلی توصیه می‌شود.")
    elif task.startswith("oct_"):
        lines.append("🧠 یادداشت: تفسیر نهایی با همبستگی بالینی و تصاویر مکمل.")
    elif task in ("glaucoma", "keratoconus"):
        lines.append("🧠 یادداشت: جایگزین تشخیص پزشک نیست و باید با پاراکلینیک تلفیق شود.")
    return "\n".join(lines)
# -------------------- Pages --------------------
@app.get("/", response_class=HTMLResponse)
def root():
    """Landing page: device/task status plus quick multipart upload forms."""
    li = "".join([f"<li>{t} — loaded={_TASK_MODELS[t].model is not None} — img={_TASK_MODELS[t].img_size}</li>" for t in _TASKS])
    return f"""
<html><head><meta charset="utf-8"><title>Retina Unified API</title></head>
<body style="font-family:Tahoma,Arial,sans-serif">
<h2>Retina Multi-Task Predictor (Single Port)</h2>
<p>Device: <b>{_DEVICE}</b> | Tasks: {", ".join(_TASKS)}</p>
<ul>{li}</ul>
<h3>Quick Forms</h3>
<form action="/predict" method="post" enctype="multipart/form-data">
<div><b>Back-compat /predict (RETINA_DEFAULT_TASK = {DEFAULT_FALLBACK_TASK})</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict</button>
</form>
<hr/>
<form action="/predict_task?task=oct_cme" method="post" enctype="multipart/form-data">
<div><b>OCT - CME</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict_task?task=oct_cme</button>
</form>
</body></html>
"""
# -------------------- Meta --------------------
@app.get("/tasks")
def tasks():
    """Per-task status: load state, classes, weights provenance, and remote URL."""
    return {
        task_name: {
            "loaded": tm.model is not None,
            "img_size": tm.img_size,
            "classes_en": tm.classes_en,
            "classes_fa": tm.classes_fa,
            "weights_used": tm.weights_path,
            "weights_candidates": getattr(tm, "_all_weight_candidates", []),
            "missing_keys": tm.missing_keys,
            "unexpected_keys": tm.unexpected_keys,
            "remote_url": _remote_url(task_name, "predict"),
        }
        for task_name, tm in _TASK_MODELS.items()
    }
@app.get("/health")
def health():
    """Liveness snapshot: device, CUDA/cuDNN state, and per-task load status."""
    cuda_ok = bool(TORCH_AVAILABLE and torch and torch.cuda.is_available())  # type: ignore
    cudnn_on = bool(TORCH_AVAILABLE and torch and torch.backends.cudnn.enabled)  # type: ignore
    return {
        "device": _DEVICE,
        "cuda": cuda_ok,
        "cudnn_enabled": cudnn_on,
        "tasks": list(_TASK_MODELS.keys()),
        "loaded": {name: tm.model is not None for name, tm in _TASK_MODELS.items()},
    }
# -------------------- API: multi-task --------------------
@app.post("/predict_task")
def predict_task(
    task: str = Query(..., description="dr, oct_cme, oct_csr, oct_amd, glaucoma, keratoconus"),
    file: UploadFile = File(...)
):
    """Classify an uploaded image with the local model for *task*.

    Falls back to the configured remote backend when no local model is loaded.
    Raises 404 for unknown tasks and 400 for unreadable files/images.
    """
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    tm = _TASK_MODELS[task]
    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")
    if tm.model is None:
        # No local model for this task: proxy the raw bytes to the remote backend.
        return _proxy_predict_task(task, raw, filename=getattr(file, "filename", "image.jpg"))
    try:
        im = Image.open(io.BytesIO(raw))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    qc = _simple_qc(im)  # lightweight resolution/brightness checks
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return JSONResponse({
        "task": task,
        "qc": qc,
        "top1": top1,
        "probs": items_sorted,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })
@app.post("/report_task")
def report_task(
    task: str = Query(...),
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("")
):
    """Classify an uploaded image for *task* and return a formatted FA report.

    Proxies image and form fields to the remote backend when no local model is
    loaded. Raises 404 for unknown tasks and 400 for bad files/images.
    """
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    tm = _TASK_MODELS[task]
    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")
    if tm.model is None:
        # Forward both the image and the patient form to the remote backend.
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", "image.jpg"))
    try:
        im = Image.open(io.BytesIO(raw))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    report_fa = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
    return JSONResponse({
        "task": task,
        "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
        "qc": qc, "top1": top1, "probs": items_sorted,
        "report": report_fa,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })
# -------------------- Back-compat --------------------
class PredictJsonReq(BaseModel):
    """Body for /predict_json: the image as a base64-encoded string."""
    image_b64: str  # base64-encoded image bytes (no data: URI prefix)
def _get_fallback_task() -> str:
    """Task used by the legacy single-task endpoints (env RETINA_DEFAULT_TASK)."""
    fallback = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
    if fallback in _TASK_MODELS:
        return fallback
    raise HTTPException(status_code=404, detail=f"Unknown default task: {fallback}")
@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Back-compat endpoint: classify with the default task (RETINA_DEFAULT_TASK)."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    try:
        payload = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")
    if tm.model is None:
        # No local model: forward the raw bytes to the configured remote backend.
        return _proxy_predict_task(task, payload, filename=getattr(file, "filename", "image.jpg"))
    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    qc = _simple_qc(image)
    ranked, best = _items_from_probs(task, predict_with_task(tm, image))
    return {"task": task, "qc": qc, "top1": best, "probs": ranked}
@app.post("/predict_json")
def predict_json(req: PredictJsonReq):
    """Back-compat JSON endpoint: classify a base64-encoded image with the default task."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    try:
        data = base64.b64decode(req.image_b64)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image")
    if tm.model is None:
        # No local model: forward the decoded bytes to the remote backend.
        return _proxy_predict_task(task, data, filename="image.jpg")
    try:
        im = Image.open(io.BytesIO(data))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image data")
    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}
@app.post("/report")
def report(
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("OD")  # NOTE(review): default differs from /report_task (which uses "")
):
    """Back-compat report endpoint for the default task (RETINA_DEFAULT_TASK)."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    try:
        raw = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")
    if tm.model is None:
        # Forward the image plus patient form to the remote backend.
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", "image.jpg"))
    try:
        im = Image.open(io.BytesIO(raw))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    rep = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
    return {"task": task,
            "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
            "qc": qc, "top1": top1, "probs": items_sorted,
            "report": rep}
@app.post("/predict_strict")
def predict_strict(file: UploadFile = File(...), tta: int = 1):
    """Back-compat alias for /predict.

    NOTE(review): `tta` is accepted but ignored — no test-time augmentation is
    performed; confirm whether any caller relies on it before removing.
    """
    return predict(file)