#!/usr/bin/env python3
# Retina/eye multi-task inference API (single-port, Torch-optional)
"""Retina/eye multi-task inference API.

Serves several classification tasks (DR, OCT CME/CSR/AMD, glaucoma,
keratoconus) from one FastAPI app. A task runs locally when Torch,
torchvision and a weights file are available; otherwise its requests are
proxied to a remote service configured via ``RETINA_REMOTE_<task>``.
"""
import io
import os
import base64
import glob
import hashlib
import tempfile
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any

import requests
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image

# -------------------- Torch / Torchvision (optional) --------------------
TORCH_AVAILABLE = False
_TV_WEIGHTS_ENUM = False  # True when torchvision exposes the *_Weights enums (new API)
try:
    import torch  # type: ignore
    TORCH_AVAILABLE = True
    try:
        # import torchvision only if torch is OK
        from torchvision import transforms as T  # type: ignore
        from torchvision.models import resnet50, mobilenet_v3_large  # type: ignore
        try:
            from torchvision.models import ResNet50_Weights, MobileNet_V3_Large_Weights  # type: ignore
            _TV_WEIGHTS_ENUM = True
        except Exception:
            ResNet50_Weights = None  # type: ignore
            MobileNet_V3_Large_Weights = None  # type: ignore
            _TV_WEIGHTS_ENUM = False
    except Exception:
        # torchvision is not available either
        T = None  # type: ignore
        resnet50 = mobilenet_v3_large = None  # type: ignore
except Exception:
    torch = None  # type: ignore
    T = None  # type: ignore
    resnet50 = mobilenet_v3_large = None  # type: ignore

# -------------------- Defaults per task --------------------
DEFAULT_TASKS = ["dr"]

TASK_DEFAULT_CLASSES_FA: Dict[str, List[str]] = {
    "dr": ["بدون DR", "خفیف", "متوسط", "شدید", "پرولیفراکتیو"],
    "oct_cme": ["بدون CME", "CME"],
    "oct_csr": ["بدون CSR", "CSR"],
    "oct_amd": ["بدون AMD", "خشک", "تر"],
    "glaucoma": ["نرمال", "گلوکوم"],
    "keratoconus": ["نرمال", "کراتوکونوس"],
}
TASK_DEFAULT_CLASSES_EN: Dict[str, List[str]] = {
    "dr": ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"],
    "oct_cme": ["No CME", "CME"],
    "oct_csr": ["No CSR", "CSR"],
    "oct_amd": ["No AMD", "Dry", "Wet"],
    "glaucoma": ["Normal", "Glaucoma"],
    "keratoconus": ["Normal", "Keratoconus"],
}
TASK_DEFAULT_IMG: Dict[str, int] = {
    "dr": 448,
    "oct_cme": 416,
    "oct_csr": 416,
    "oct_amd": 416,
    "glaucoma": 416,
    "keratoconus": 416,
}
TASK_DEFAULT_MODEL: Dict[str, str] = {
    "dr": "resnet50",
    "oct_cme": "resnet50",
    "oct_csr": "resnet50",
    "oct_amd": "resnet50",
    "glaucoma": "resnet50",
    "keratoconus": "resnet50",
}

# -------------------- Weights: autodiscovery / optional download --------------------
DEFAULT_WEIGHTS_DIR = os.getenv("RETINA_WEIGHTS_DIR", "/app/models")

# Glob patterns (relative to DEFAULT_WEIGHTS_DIR) searched per task.
# NOTE(review): every task falls back to root-level "*.pth", so a task without
# dedicated weights can pick up another task's checkpoint; prepare_task()
# tolerates the resulting load failure instead of crashing the app.
WEIGHT_PATTERNS = {
    "dr": ["runs_k80/phase2/best.pth", "dr/*.pth", "*.pth"],
    "oct_cme": ["oct_cme/best.pth", "oct_cme/*.pth", "*.pth"],
    "oct_csr": ["oct_csr/best.pth", "oct_csr/*.pth", "*.pth"],
    "oct_amd": ["oct_amd/best.pth", "oct_amd/*.pth", "*.pth"],
    "glaucoma": ["glaucoma/best.pth", "glaucoma/*.pth", "*.pth"],
    "keratoconus": ["keratoconus/best.pth", "keratoconus/*.pth", "*.pth"],
}


def _find_candidate_weights(task: str) -> List[str]:
    """Return existing weight files matching the task's patterns, newest mtime first."""
    root = Path(DEFAULT_WEIGHTS_DIR)
    found = {
        hit
        for pat in WEIGHT_PATTERNS.get(task, ["*.pth"])
        for hit in glob.glob(str(root / pat))
    }
    files = [f for f in found if Path(f).is_file()]

    def _mtime(p: str) -> float:
        # A file may vanish between glob and stat; treat it as oldest.
        try:
            return Path(p).stat().st_mtime
        except OSError:
            return 0

    return sorted(files, key=_mtime, reverse=True)


def _download(url: str, dest: Path, sha256: Optional[str] = None) -> Path:
    """Stream *url* to *dest* atomically, optionally verifying its SHA-256.

    The body is written to a ``.part`` temp file in dest's directory and moved
    into place only after a successful (and, if requested, verified) download.
    The temp file is removed on any failure.

    Raises:
        requests.RequestException: network/HTTP failure.
        RuntimeError: SHA-256 mismatch.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp_path: Optional[Path] = None
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            h = hashlib.sha256()
            with tempfile.NamedTemporaryFile(delete=False, dir=str(dest.parent), suffix=".part") as tmp:
                tmp_path = Path(tmp.name)
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if not chunk:
                        continue
                    tmp.write(chunk)
                    h.update(chunk)
        if sha256 and h.hexdigest().lower() != sha256.strip().lower():
            raise RuntimeError(f"SHA256 mismatch for {url}")
        tmp_path.replace(dest)
        tmp_path = None  # moved into place; nothing left to clean up
        return dest
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)


def _pick_weight(task: str) -> Tuple[Optional[str], List[str]]:
    """Choose the weights file for *task*.

    Resolution order: ``RETINA_WEIGHTS_<task>`` (an existing file), then
    autodiscovered files (newest first), then a download from
    ``RETINA_WEIGHTS_URL_<task>`` (checked against ``RETINA_WEIGHTS_SHA256_<task>``
    when set).

    Returns:
        (chosen path or None, all candidate paths considered).
    """
    env_path = os.getenv(f"RETINA_WEIGHTS_{task}")
    if env_path and Path(env_path).is_file():
        return env_path, [env_path]
    cands = _find_candidate_weights(task)
    if cands:
        return cands[0], cands
    url = os.getenv(f"RETINA_WEIGHTS_URL_{task}")
    sha = os.getenv(f"RETINA_WEIGHTS_SHA256_{task}")
    if url:
        dest = Path(DEFAULT_WEIGHTS_DIR) / task / "best.pth"
        try:
            print(f"[weights] downloading {task} from {url} → {dest}")
            got = _download(url, dest, sha256=sha)
            return str(got), [str(got)]
        except Exception as e:
            print(f"[weights] download failed for {task}: {e}")
    return None, []


# -------------------- Utils (Torch-aware) --------------------
def device_setup() -> str:
    """Return ``"cuda"`` when Torch sees a GPU, else ``"cpu"``.

    Side effect: on CUDA, cuDNN is disabled globally.
    """
    if TORCH_AVAILABLE and torch.cuda.is_available():  # type: ignore
        # NOTE(review): cuDNN is switched off deliberately; presumably for
        # compatibility with the GPUs the weights target (cf. "runs_k80") — confirm.
        torch.backends.cudnn.enabled = False  # type: ignore
        return "cuda"
    return "cpu"


def build_model(name: str, num_classes: int):
    """Build an untrained torchvision backbone with a ``num_classes``-way head.

    Raises:
        RuntimeError: Torch/torchvision is not importable.
        ValueError: *name* is not a supported architecture.
    """
    if not (TORCH_AVAILABLE and resnet50 and mobilenet_v3_large):
        raise RuntimeError("Torch/torchvision not available in this runtime.")
    import torch.nn as nn  # local import (only when torch exists)

    name = name.lower()
    # Newer torchvision takes ``weights=``; older releases only know ``pretrained=``.
    init_kw = {"weights": None} if _TV_WEIGHTS_ENUM else {"pretrained": False}
    if name in ("resnet50", "resnet"):
        m = resnet50(**init_kw)  # type: ignore
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if name in ("mobilenetv3", "mobilenet_v3", "mbv3"):
        m = mobilenet_v3_large(**init_kw)  # type: ignore
        m.classifier[3] = nn.Linear(m.classifier[3].in_features, num_classes)
        return m
    raise ValueError(f"Unknown model: {name}")


def make_transform(img_size: int):
    """Return the eval preprocessing pipeline (resize, center-crop, ImageNet normalize).

    Without torchvision an identity function is returned; that path is never
    used for inference because no local model is loaded in that case.
    """
    if not (TORCH_AVAILABLE and T):
        def _noop(x):
            return x
        return _noop
    return T.Compose([
        T.Resize(int(img_size * 1.15)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_state(model, weights_path: str):
    """Load a checkpoint into *model* with ``strict=False``.

    Accepts a raw state dict, a dict wrapping it under ``"model"`` or
    ``"state_dict"``, or a pickled module. A leading ``module.`` key prefix
    (typically left by DataParallel-style wrappers) is stripped.

    Returns:
        (missing_keys, unexpected_keys) as lists.

    Raises:
        RuntimeError: Torch missing, unrecognised checkpoint, or tensor shape
            mismatch (PyTorch raises on size mismatch even when non-strict).
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("Torch not available for loading state.")
    # NOTE(review): torch.load unpickles the file; only load weights from trusted
    # locations (this includes files fetched via RETINA_WEIGHTS_URL_<task>).
    ckpt = torch.load(weights_path, map_location="cpu")  # type: ignore
    if isinstance(ckpt, dict):
        state = ckpt.get("model", ckpt.get("state_dict", ckpt))
    elif hasattr(ckpt, "state_dict"):
        state = ckpt.state_dict()
    else:
        raise RuntimeError(f"Unrecognised checkpoint format in {weights_path}")
    new_state = {(k[7:] if k.startswith("module.") else k): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(new_state, strict=False)
    return list(missing), list(unexpected)


@dataclass
class TaskModel:
    """Runtime state for one task; ``model`` is None when it must be proxied."""

    name: str
    model: Optional[Any]
    device: str
    img_size: int
    classes_fa: List[str]
    classes_en: List[str]
    weights_path: Optional[str]
    missing_keys: List[str]
    unexpected_keys: List[str]
    transform: Any
    # Every weights file considered for this task (reported by /tasks and predictions).
    _all_weight_candidates: List[str] = field(default_factory=list)


def env_list(key: str, default: Optional[List[str]] = None) -> List[str]:
    """Parse a comma-separated env var into a list of non-empty, stripped items."""
    raw = os.getenv(key)
    if not raw:
        return default or []
    return [x.strip() for x in raw.split(",") if x.strip()]


def parse_classes_env(task: str) -> Optional[List[str]]:
    """Return class names from ``RETINA_CLASSES_<task>``, or None when unset/empty."""
    raw = os.getenv(f"RETINA_CLASSES_{task}")
    if not raw:
        return None
    vals = [v.strip() for v in raw.split(",") if v.strip()]
    return vals or None


def prepare_task(task: str, device: str) -> TaskModel:
    """Resolve config and weights for *task* and load its local model if possible.

    Any failure to build or load the local model (torchvision missing, bad
    architecture name, incompatible checkpoint) is logged and yields a
    ``TaskModel`` with ``model=None``, so the task falls back to the remote
    proxy instead of taking the whole service down at import time.
    """
    model_name = os.getenv(f"RETINA_MODEL_{task}", TASK_DEFAULT_MODEL.get(task, "resnet50"))
    img_size = int(os.getenv(f"RETINA_IMG_SIZE_{task}", str(TASK_DEFAULT_IMG.get(task, 416))))
    env_classes = parse_classes_env(task)
    classes_en = env_classes or TASK_DEFAULT_CLASSES_EN.get(task, ["Negative", "Positive"])
    classes_fa_default = TASK_DEFAULT_CLASSES_FA.get(task, ["منفی", "مثبت"])
    # Custom English classes keep the default Persian labels only when the counts match.
    if env_classes and len(classes_fa_default) != len(classes_en):
        classes_fa = classes_en
    else:
        classes_fa = classes_fa_default

    weights, all_cands = _pick_weight(task)
    transform = make_transform(img_size)

    def _without_model(weights_path: Optional[str]) -> TaskModel:
        return TaskModel(task, None, device, img_size, classes_fa, classes_en,
                         weights_path, [], [], transform, all_cands)

    # No torch or no weights file -> do not load a local model.
    if (not TORCH_AVAILABLE) or (not weights) or (not os.path.isfile(weights)):
        return _without_model(weights if weights else None)

    try:
        m = build_model(model_name, num_classes=len(classes_en))
        missing, unexpected = load_state(m, weights)
        m.eval().to(device)
        if device == "cuda":
            m.to(memory_format=torch.channels_last)  # type: ignore
    except Exception as e:
        print(f"[weights] could not load local model for {task} from {weights}: {e}")
        return _without_model(weights)

    return TaskModel(task, m, device, img_size, classes_fa, classes_en,
                     weights, missing, unexpected, transform, all_cands)


def predict_with_task(task_obj: TaskModel, pil_im: Image.Image) -> List[float]:
    """Run the task's local model on *pil_im*; return per-class softmax probabilities."""
    if (not TORCH_AVAILABLE) or (task_obj.model is None):
        raise RuntimeError("Local model not available.")
    x = task_obj.transform(pil_im.convert("RGB")).unsqueeze(0)
    x = x.to(task_obj.device, non_blocking=True)
    with torch.no_grad():  # type: ignore
        logits = task_obj.model(x)
        # Tensor.tolist() avoids a hard NumPy dependency (NumPy is optional elsewhere).
        probs = torch.softmax(logits, dim=1)[0].cpu().tolist()  # type: ignore
    return probs


# -------------------- Remote proxy helpers --------------------
def _remote_base_for(task: str) -> Optional[str]:
    """Return the configured remote base URL (``RETINA_REMOTE_<task>``), if any."""
    return os.getenv(f"RETINA_REMOTE_{task}")


def _remote_auth_header_for(task: str) -> dict:
    """Build the Authorization header from the per-task or global token env var."""
    token = (os.getenv(f"RETINA_REMOTE_AUTH_{task}") or os.getenv("RETINA_REMOTE_AUTH") or "").strip()
    return {"Authorization": token} if token else {}


def _remote_verify_ssl() -> bool:
    """TLS verification is on unless RETINA_REMOTE_VERIFY_SSL is 0/false/no."""
    v = (os.getenv("RETINA_REMOTE_VERIFY_SSL") or "true").strip().lower()
    return v not in ("0", "false", "no")


def _remote_timeout() -> int:
    """Remote request timeout in seconds (RETINA_REMOTE_TIMEOUT, default 90)."""
    try:
        return int(os.getenv("RETINA_REMOTE_TIMEOUT", "90"))
    except ValueError:
        return 90


def _remote_url(task: str, mode: str) -> Optional[str]:
    """Return the remote endpoint URL for *mode* ("predict" or "report"), or None.

    A bare base URL gets ``/<endpoint>?task=<task>`` appended. If a full
    endpoint URL was configured, it is re-pointed at the endpoint matching
    *mode*, so report calls never land on the predict endpoint (and vice versa).
    """
    base = (_remote_base_for(task) or "").strip()
    if not base:
        return None
    endpoint = "predict_task" if mode == "predict" else "report_task"
    for known in ("/predict_task", "/report_task"):
        if base.endswith(known):
            return base[: -len(known)] + "/" + endpoint
    return f"{base.rstrip('/')}/{endpoint}?task={task}"


def _proxy_post(task: str, mode: str, file_bytes: bytes, filename: Optional[str],
                form: Optional[dict] = None) -> JSONResponse:
    """POST the image (and optional form fields) to the remote and relay its reply.

    Raises:
        HTTPException: 501 if no remote is configured, the remote's status for
            non-2xx replies, 502 on connection/transport failure.
    """
    url = _remote_url(task, mode)
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")
    try:
        r = requests.post(
            url,
            files={"file": (filename or "image.jpg", file_bytes, "image/jpeg")},
            data=form,
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {e}") from e
    if not (200 <= r.status_code < 300):
        raise HTTPException(status_code=r.status_code, detail=f"Remote error: {r.text}")
    try:
        return JSONResponse(r.json())
    except ValueError:
        # Non-JSON body: pass it through verbatim.
        return JSONResponse({"remote_raw": r.text})


def _proxy_predict_task(task: str, file_bytes: bytes, filename: str = "image.jpg") -> JSONResponse:
    """Forward a prediction request for *task* to its remote service."""
    return _proxy_post(task, "predict", file_bytes, filename)


def _proxy_report_task(task: str, file_bytes: bytes, form: dict, filename: str = "image.jpg") -> JSONResponse:
    """Forward a report request (image + patient form fields) to the remote service."""
    return _proxy_post(task, "report", file_bytes, filename, form=form)


# -------------------- App --------------------
app = FastAPI(title="Retina Multi-Task Inference API (Unified)", version="1.3.1")
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed requests; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_DEVICE = device_setup()
# Routes lowercase the requested task name, so normalise (and de-duplicate) here too.
_TASKS = list(dict.fromkeys(t.lower() for t in env_list("RETINA_TASKS", DEFAULT_TASKS)))
_TASK_MODELS: Dict[str, TaskModel] = {t: prepare_task(t, _DEVICE) for t in _TASKS}
DEFAULT_FALLBACK_TASK = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()


# -------------------- Helpers for QC/format --------------------
def _simple_qc(im: Image.Image) -> dict:
    """Cheap image quality flags: low resolution (<512 px short side), too dark/bright.

    Without NumPy only the dimensions are reported and the image is marked ok.
    """
    try:
        import numpy as np  # lazy
    except ImportError:
        w, h = im.size
        return {"width": w, "height": h, "mean_luma": None, "warnings": [], "ok": True}
    w, h = im.size
    mean_luma = float(np.array(im.convert("L")).mean())
    warns: List[str] = []
    if min(w, h) < 512:
        warns.append("low_resolution")
    if mean_luma < 25:
        warns.append("too_dark")
    if mean_luma > 230:
        warns.append("too_bright")
    return {"width": w, "height": h, "mean_luma": round(mean_luma, 1), "warnings": warns, "ok": len(warns) == 0}


def _items_from_probs(task: str, probs: List[float]):
    """Pair probabilities with class labels; return (items sorted desc by prob, top-1 item)."""
    tm = _TASK_MODELS[task]
    items = [
        {"index": i, "class_en": tm.classes_en[i], "class_fa": tm.classes_fa[i], "prob": float(p)}
        for i, p in enumerate(probs)
    ]
    items_sorted = sorted(items, key=lambda d: d["prob"], reverse=True)
    return items_sorted, items_sorted[0]


def _format_report(task: str, probs: List[float], patient_name: str = "",
                   exam_date: str = "", eye: str = "") -> str:
    """Render a Persian plain-text report: header, top-1 result, full distribution, task note."""
    items, top = _items_from_probs(task, probs)
    title_map = {
        "dr": "گزارش رتینوپاتی دیابتی (DR)",
        "oct_cme": "گزارش OCT - CME",
        "oct_csr": "گزارش OCT - CSR",
        "oct_amd": "گزارش OCT - AMD",
        "glaucoma": "گزارش گلوکوم",
        "keratoconus": "گزارش کراتوکونوس",
    }
    title = title_map.get(task, f"گزارش {task}")
    lines: List[str] = []
    lines.append(f"👁 {title} برای بیمار: {patient_name or '—'}")
    lines.append(f"📅 تاریخ معاینه: {exam_date or '—'}")
    if eye:
        lines.append(f"👓 چشم: {eye}")
    lines.append("________________________________________")
    lines.append("📌 نتیجه الگوریتم (Top-1):")
    lines.append(f"• {top['class_fa']} ({top['class_en']}) — احتمال {top['prob']:.3f}")
    lines.append("📊 توزیع احتمالات:")
    for it in items:
        lines.append(f"• {it['class_fa']} ({it['class_en']}) — {it['prob']:.4f}")
    if task == "dr":
        lines.append("🧠 یادداشت: نتیجه برای کمک به تصمیم‌گیری است؛ در موارد مثبت معاینه بالینی/تصویربرداری تکمیلی توصیه می‌شود.")
    elif task.startswith("oct_"):
        lines.append("🧠 یادداشت: تفسیر نهایی با همبستگی بالینی و تصاویر مکمل.")
    elif task in ("glaucoma", "keratoconus"):
        lines.append("🧠 یادداشت: جایگزین تشخیص پزشک نیست و باید با پاراکلینیک تلفیق شود.")
    return "\n".join(lines)
# -------------------- Pages --------------------
@app.get("/", response_class=HTMLResponse)
def root():
    """Render a small HTML status page with upload forms for quick manual tests."""
    import html  # local import: only this page needs HTML escaping

    # NOTE(review): the markup below was rebuilt from the page's surviving text
    # (title, headings, per-task status lines); confirm it matches the intended
    # template before relying on the exact HTML.
    items = "".join(
        f"<li>{html.escape(t)} — loaded={_TASK_MODELS[t].model is not None}"
        f" — img={_TASK_MODELS[t].img_size}</li>"
        for t in _TASKS
    )
    tasks_txt = html.escape(", ".join(_TASKS))
    fallback = html.escape(DEFAULT_FALLBACK_TASK)
    return f"""<!doctype html>
<html>
<head>
  <meta charset="utf-8">
  <title>Retina Unified API</title>
</head>
<body>
  <h2>Retina Multi-Task Predictor (Single Port)</h2>
  <p>Device: {_DEVICE} | Tasks: {tasks_txt}</p>
  <ul>{items}</ul>
  <h3>Quick Forms</h3>
  <h4>Back-compat /predict (RETINA_DEFAULT_TASK = {fallback})</h4>
  <form action="/predict" method="post" enctype="multipart/form-data">
    <input type="file" name="file" accept="image/*" required>
    <button type="submit">Predict</button>
  </form>
  <h4>OCT - CME</h4>
  <form action="/predict_task?task=oct_cme" method="post" enctype="multipart/form-data">
    <input type="file" name="file" accept="image/*" required>
    <button type="submit">Predict</button>
  </form>
</body>
</html>
"""


# -------------------- Meta --------------------
@app.get("/tasks")
def tasks():
    """Describe every configured task: load state, labels, weights and remote URL."""
    out = {}
    for t, tm in _TASK_MODELS.items():
        out[t] = {
            "loaded": tm.model is not None,
            "img_size": tm.img_size,
            "classes_en": tm.classes_en,
            "classes_fa": tm.classes_fa,
            "weights_used": tm.weights_path,
            "weights_candidates": getattr(tm, "_all_weight_candidates", []),
            "missing_keys": tm.missing_keys,
            "unexpected_keys": tm.unexpected_keys,
            "remote_url": _remote_url(t, "predict"),
        }
    return out


@app.get("/health")
def health():
    """Report device/CUDA state and which tasks have a local model loaded."""
    return {
        "device": _DEVICE,
        "cuda": bool(TORCH_AVAILABLE and torch and torch.cuda.is_available()),  # type: ignore
        "cudnn_enabled": bool(TORCH_AVAILABLE and torch and torch.backends.cudnn.enabled),  # type: ignore
        "tasks": list(_TASK_MODELS.keys()),
        "loaded": {t: (_TASK_MODELS[t].model is not None) for t in _TASK_MODELS},
    }


# -------------------- API: multi-task --------------------
def _resolve_task(task: str) -> Tuple[str, TaskModel]:
    """Normalise *task* and look it up; 404 if it is not configured."""
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    return task, _TASK_MODELS[task]


def _read_upload(file: UploadFile) -> bytes:
    """Read the whole upload body; 400 if the stream cannot be read."""
    try:
        return file.file.read()
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid file") from e


def _upload_name(file: UploadFile) -> str:
    """Upload filename, defaulting when the client sent none."""
    return getattr(file, "filename", None) or "image.jpg"


def _decode_image(raw: bytes) -> Image.Image:
    """Decode *raw* into a PIL image; 400 on undecodable or truncated data.

    ``Image.open`` is lazy, so ``load()`` forces decoding here; otherwise a
    corrupt body would only fail later inside inference as a 500.
    """
    try:
        im = Image.open(io.BytesIO(raw))
        im.load()
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid image") from e
    return im


def _run_local(task: str, tm: TaskModel, raw: bytes):
    """Decode, QC and classify locally; return (qc, probs, items_sorted, top1)."""
    im = _decode_image(raw)
    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return qc, probs, items_sorted, top1


@app.post("/predict_task")
def predict_task(
    task: str = Query(..., description="dr, oct_cme, oct_csr, oct_amd, glaucoma, keratoconus"),
    file: UploadFile = File(...),
):
    """Classify an uploaded image for *task*, locally or via the task's remote."""
    task, tm = _resolve_task(task)
    raw = _read_upload(file)
    if tm.model is None:
        return _proxy_predict_task(task, raw, filename=_upload_name(file))
    qc, _probs, items_sorted, top1 = _run_local(task, tm, raw)
    return JSONResponse({
        "task": task,
        "qc": qc,
        "top1": top1,
        "probs": items_sorted,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })


@app.post("/report_task")
def report_task(
    task: str = Query(...),
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form(""),
):
    """Classify an uploaded image for *task* and attach a Persian text report."""
    task, tm = _resolve_task(task)
    raw = _read_upload(file)
    if tm.model is None:
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        return _proxy_report_task(task, raw, form, filename=_upload_name(file))
    qc, probs, items_sorted, top1 = _run_local(task, tm, raw)
    report_fa = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
    return JSONResponse({
        "task": task,
        "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
        "qc": qc,
        "top1": top1,
        "probs": items_sorted,
        "report": report_fa,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })
# -------------------- Back-compat --------------------
class PredictJsonReq(BaseModel):
    """Body for /predict_json: a base64 image, optionally as a ``data:`` URL."""

    image_b64: str


def _get_fallback_task() -> str:
    """Task served by the legacy endpoints (RETINA_DEFAULT_TASK, default "dr"); 404 if not configured."""
    t = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
    if t not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown default task: {t}")
    return t


def _compat_read_upload(file: UploadFile) -> bytes:
    """Read the whole upload body; 400 if the stream cannot be read."""
    try:
        return file.file.read()
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid file") from e


def _compat_decode_b64(payload: str) -> bytes:
    """Decode a base64 image, accepting ``data:<mime>;base64,<data>`` URLs.

    A data-URL prefix must be removed first: ``b64decode`` (non-validating)
    would otherwise keep its alphabet characters and corrupt the bytes.
    """
    s = payload.strip()
    if s.startswith("data:") and "," in s:
        s = s.split(",", 1)[1]
    try:
        return base64.b64decode(s)
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid base64 image") from e


def _compat_local_predict(task: str, tm: TaskModel, raw: bytes, bad_image_detail: str):
    """Decode, QC and classify locally; return (qc, probs, items_sorted, top1).

    ``load()`` forces PIL's lazy decoder so corrupt data is a 400 here rather
    than a 500 during inference.
    """
    try:
        im = Image.open(io.BytesIO(raw))
        im.load()
    except Exception as e:
        raise HTTPException(status_code=400, detail=bad_image_detail) from e
    qc = _simple_qc(im)
    probs = predict_with_task(tm, im)
    items_sorted, top1 = _items_from_probs(task, probs)
    return qc, probs, items_sorted, top1


@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Legacy single-task prediction on the fallback task."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    raw = _compat_read_upload(file)
    if tm.model is None:
        return _proxy_predict_task(task, raw, filename=getattr(file, "filename", None) or "image.jpg")
    qc, _probs, items_sorted, top1 = _compat_local_predict(task, tm, raw, "Invalid image")
    return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}


@app.post("/predict_json")
def predict_json(req: PredictJsonReq):
    """Legacy prediction from a base64 JSON payload on the fallback task."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    data = _compat_decode_b64(req.image_b64)
    if tm.model is None:
        return _proxy_predict_task(task, data, filename="image.jpg")
    qc, _probs, items_sorted, top1 = _compat_local_predict(task, tm, data, "Invalid image data")
    return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}


@app.post("/report")
def report(
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("OD"),
):
    """Legacy prediction plus Persian text report on the fallback task."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    raw = _compat_read_upload(file)
    if tm.model is None:
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", None) or "image.jpg")
    qc, probs, items_sorted, top1 = _compat_local_predict(task, tm, raw, "Invalid image")
    rep = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
    return {"task": task, "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
            "qc": qc, "top1": top1, "probs": items_sorted, "report": rep}


@app.post("/predict_strict")
def predict_strict(file: UploadFile = File(...), tta: int = 1):
    """Alias of /predict; ``tta`` is accepted for compatibility but currently ignored."""
    return predict(file)