hadi6681 committed on
Commit
16baeec
·
verified ·
1 Parent(s): a5c80aa

Update retina_api_multi.py

Browse files
Files changed (1) hide show
  1. retina_api_multi.py +620 -650
retina_api_multi.py CHANGED
@@ -1,650 +1,620 @@
1
- #!/usr/bin/env python3
2
- # Retina/eye multi-task inference API (Windows-friendly, single-port)
3
- #
4
- # ✅ یک پورت برای همه‌ی تحلیل‌ها
5
- # ✅ /predict و /report (سازگاری با کلاینت‌های قدیمی – پیش‌فرض task=dr مگر این‌که با ENV عوض کنید)
6
- # ✅ /predict_task و /report_task برای همه‌ی task ها
7
- # ✅ /tasks و /health برای دیباگ + فهرست همه‌ی وزن‌های کاندید
8
- # ✅ کشف خودکار وزن‌ها از پوشهٔ مدل‌ها؛ ENVها بالاترین اولویت را دارند؛ در صورت نیاز دانلود از URL
9
- # ✅ fallback به ریموت وقتی وزن لوکال نداریم
10
- #
11
- # اجرای نمونه (پورت 8000):
12
- # python -m uvicorn app.retina.retina_api_multi:app --host 0.0.0.0 --port 8000 --workers 1
13
- #
14
- # ENV نمونه:
15
- # RETINA_TASKS="dr,oct_cme,oct_csr,oct_amd,glaucoma,keratoconus"
16
- # RETINA_WEIGHTS_DIR="C:\\namavaran_server\\models"
17
- # RETINA_WEIGHTS_dr="C:\\namavaran_server\\models\\runs_k80\\phase2\\best.pth"
18
- # RETINA_WEIGHTS_URL_oct_cme="https://your.host/oct_cme_best.pth" # اختیاری (دانلود)
19
- # RETINA_WEIGHTS_SHA256_oct_cme="<sha256-hex>" # اختیاری
20
- # RETINA_DEFAULT_TASK="dr"
21
- # RETINA_REMOTE_oct_cme="http://85.208.254.231:8001" # fallback ریموت
22
- # RETINA_REMOTE_AUTH="Bearer <token>" # اختیاری
23
- # RETINA_REMOTE_VERIFY_SSL="false" # فقط تست
24
- # RETINA_REMOTE_TIMEOUT="120"
25
-
26
- import io
27
- import os
28
- import base64
29
- import glob
30
- import hashlib
31
- import tempfile
32
- from pathlib import Path
33
- from dataclasses import dataclass
34
- from typing import Dict, List, Optional, Tuple
35
-
36
- import requests
37
- from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
38
- from fastapi.middleware.cors import CORSMiddleware
39
- from fastapi.responses import HTMLResponse, JSONResponse
40
- from pydantic import BaseModel
41
-
42
- import torch
43
- import torch.nn as nn
44
- import torchvision.transforms as T
45
- from PIL import Image
46
-
47
# ---- torchvision compat (weights API) ----
# Newer torchvision exposes Weights enums; older releases only accept `pretrained=`.
try:
    from torchvision.models import resnet50, ResNet50_Weights
    from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
    _TV_WEIGHTS_ENUM = True
except Exception:
    # Legacy torchvision: import the bare constructors and mark the enums absent.
    from torchvision.models import resnet50, mobilenet_v3_large  # type: ignore
    ResNet50_Weights = None  # type: ignore
    MobileNet_V3_Large_Weights = None  # type: ignore
    _TV_WEIGHTS_ENUM = False
57
-
58
-
59
# ---------- defaults per task ----------
# Tasks served when RETINA_TASKS is not set.
DEFAULT_TASKS = ["dr"]

# Persian display labels per task (index-aligned with the English labels below).
TASK_DEFAULT_CLASSES_FA: Dict[str, List[str]] = {
    "dr": ["بدون DR", "خفیف", "متوسط", "شدید", "پرولیفراکتیو"],
    "oct_cme": ["بدون CME", "CME"],
    "oct_csr": ["بدون CSR", "CSR"],
    "oct_amd": ["بدون AMD", "خشک", "تر"],
    "glaucoma": ["نرمال", "گلوکوم"],
    "keratoconus": ["نرمال", "کراتوکونوس"],
}

# English class labels per task.
TASK_DEFAULT_CLASSES_EN: Dict[str, List[str]] = {
    "dr": ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"],
    "oct_cme": ["No CME", "CME"],
    "oct_csr": ["No CSR", "CSR"],
    "oct_amd": ["No AMD", "Dry", "Wet"],
    "glaucoma": ["Normal", "Glaucoma"],
    "keratoconus": ["Normal", "Keratoconus"],
}

# Default input resolution per task (overridable via RETINA_IMG_SIZE_<task>).
TASK_DEFAULT_IMG: Dict[str, int] = {
    "dr": 448,
    "oct_cme": 416,
    "oct_csr": 416,
    "oct_amd": 416,
    "glaucoma": 416,
    "keratoconus": 416,
}

# Default backbone per task (overridable via RETINA_MODEL_<task>).
TASK_DEFAULT_MODEL: Dict[str, str] = {
    "dr": "resnet50",
    "oct_cme": "resnet50",
    "oct_csr": "resnet50",
    "oct_amd": "resnet50",
    "glaucoma": "resnet50",
    "keratoconus": "resnet50",
}


# ---------- weights: autodiscovery / optional download ----------
# Root folder scanned for checkpoints; ENV always wins over the default.
DEFAULT_WEIGHTS_DIR = os.getenv("RETINA_WEIGHTS_DIR", r"C:\namavaran_server\models")

# Glob patterns tried per task, most specific first.
WEIGHT_PATTERNS = {
    "dr": ["runs_k80/phase2/best.pth", "dr/*.pth", "*.pth"],
    "oct_cme": ["oct_cme/best.pth", "oct_cme/*.pth", "*.pth"],
    "oct_csr": ["oct_csr/best.pth", "oct_csr/*.pth", "*.pth"],
    "oct_amd": ["oct_amd/best.pth", "oct_amd/*.pth", "*.pth"],
    "glaucoma": ["glaucoma/best.pth", "glaucoma/*.pth", "*.pth"],
    "keratoconus": ["keratoconus/best.pth", "keratoconus/*.pth", "*.pth"],
}
105
-
106
def _find_candidate_weights(task: str) -> List[str]:
    """Glob candidate checkpoint files for *task* under DEFAULT_WEIGHTS_DIR, newest first."""
    root = Path(DEFAULT_WEIGHTS_DIR)
    hits: List[str] = []
    for pattern in WEIGHT_PATTERNS.get(task, ["*.pth"]):
        hits.extend(glob.glob(str(root / pattern)))
    # De-duplicate, then order by modification time (most recent wins).
    ordered = sorted(
        set(hits),
        key=lambda f: Path(f).stat().st_mtime if Path(f).exists() else 0,
        reverse=True,
    )
    return [f for f in ordered if Path(f).is_file()]

def _download(url: str, dest: Path, sha256: Optional[str] = None) -> Path:
    """Stream *url* into *dest* atomically, optionally verifying its SHA-256 digest.

    Downloads into a sibling ``.part`` temp file first so a failed transfer never
    leaves a half-written checkpoint at *dest*.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        digest = hashlib.sha256()
        with tempfile.NamedTemporaryFile(delete=False, dir=str(dest.parent), suffix=".part") as tmp:
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    tmp.write(chunk)
                    digest.update(chunk)
        part = Path(tmp.name)
        if sha256 and digest.hexdigest().lower() != sha256.lower():
            part.unlink(missing_ok=True)
            raise RuntimeError(f"SHA256 mismatch for {url}")
        part.replace(dest)
    return dest

def _pick_weight(task: str) -> Tuple[Optional[str], List[str]]:
    """Resolve a checkpoint path for *task*.

    Priority: explicit RETINA_WEIGHTS_<task> env path, then newest discovered
    file, then an optional download via RETINA_WEIGHTS_URL_<task>.
    Returns (chosen_path_or_None, all_candidates).
    """
    explicit = os.getenv(f"RETINA_WEIGHTS_{task}")
    if explicit and Path(explicit).is_file():
        return explicit, [explicit]

    discovered = _find_candidate_weights(task)
    if discovered:
        return discovered[0], discovered

    url = os.getenv(f"RETINA_WEIGHTS_URL_{task}")
    sha = os.getenv(f"RETINA_WEIGHTS_SHA256_{task}")
    if url:
        dest = Path(DEFAULT_WEIGHTS_DIR) / task / "best.pth"
        try:
            print(f"[weights] downloading {task} from {url} → {dest}")
            got = _download(url, dest, sha256=sha)
            return str(got), [str(got)]
        except Exception as e:
            # Best-effort: a failed download just means "no local weights".
            print(f"[weights] download failed for {task}: {e}")
    return None, []
160
-
161
-
162
# -------- utils --------
def device_setup() -> str:
    """Select the inference device; disables cuDNN on CUDA for legacy GPUs (e.g. K80)."""
    if not torch.cuda.is_available():
        return 'cpu'
    # Compatibility with older GPUs (e.g. K80): run CUDA without cuDNN kernels.
    torch.backends.cudnn.enabled = False
    return 'cuda'
169
-
170
def build_model(name: str, num_classes: int) -> nn.Module:
    """Construct an untrained backbone with a fresh *num_classes*-way classifier head.

    Supported names: resnet50/resnet, mobilenetv3/mobilenet_v3/mbv3.
    Raises ValueError for anything else.
    """
    arch = name.lower()
    if arch in ("resnet50", "resnet"):
        net = resnet50(weights=None) if _TV_WEIGHTS_ENUM else resnet50(pretrained=False)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net
    elif arch in ("mobilenetv3", "mobilenet_v3", "mbv3"):
        net = mobilenet_v3_large(weights=None) if _TV_WEIGHTS_ENUM else mobilenet_v3_large(pretrained=False)
        net.classifier[3] = nn.Linear(net.classifier[3].in_features, num_classes)
        return net
    else:
        raise ValueError(f"Unknown model: {name}")

def make_transform(img_size: int) -> T.Compose:
    """Eval pipeline: resize (with 15% margin) → center-crop → tensor → ImageNet normalisation."""
    return T.Compose([
        T.Resize(int(img_size * 1.15)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def load_state(model: nn.Module, weights_path: str):
    """Load a checkpoint (raw state_dict or {'model': state_dict}) into *model*.

    DataParallel 'module.' prefixes are stripped; loading is non-strict.
    Returns (missing_keys, unexpected_keys) as lists.
    """
    ckpt = torch.load(weights_path, map_location='cpu')
    state = ckpt.get("model", ckpt)  # accept either wrapped or bare state_dict
    cleaned = {(k[7:] if k.startswith("module.") else k): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    return list(missing), list(unexpected)
200
-
201
@dataclass
class TaskModel:
    """Runtime bundle for one inference task (model may be None when unloaded)."""
    name: str                       # task identifier, e.g. "dr"
    model: Optional[nn.Module]      # None => serve via remote fallback only
    device: str                     # 'cpu' or 'cuda'
    img_size: int                   # square input resolution
    classes_fa: List[str]           # Persian labels, index-aligned with probs
    classes_en: List[str]           # English labels, index-aligned with probs
    weights_path: Optional[str]     # checkpoint actually loaded, if any
    missing_keys: List[str]         # state_dict keys absent from the checkpoint
    unexpected_keys: List[str]      # checkpoint keys the model did not expect
    transform: T.Compose            # preprocessing pipeline for this task
213
-
214
-
215
def env_list(key: str, default: Optional[List[str]] = None) -> List[str]:
    """Split the env var *key* on commas into trimmed, non-empty items.

    Returns *default* (or []) when the variable is unset or blank.
    """
    raw = os.getenv(key)
    if raw:
        return [item.strip() for item in raw.split(",") if item.strip()]
    return default or []

def parse_classes_env(task: str) -> Optional[List[str]]:
    """Read RETINA_CLASSES_<task> as a comma-separated class list, or None when absent/empty."""
    raw = os.getenv(f"RETINA_CLASSES_{task}")
    if not raw:
        return None
    values = [item.strip() for item in raw.split(",") if item.strip()]
    return values or None
228
-
229
def prepare_task(task: str, device: str) -> TaskModel:
    """Resolve configuration and weights for *task*, returning a ready TaskModel.

    When no weights can be found, the TaskModel is returned with model=None so
    request handlers can fall back to the configured remote endpoint.
    """
    model_name = os.getenv(f"RETINA_MODEL_{task}", TASK_DEFAULT_MODEL.get(task, "resnet50"))
    img_size = int(os.getenv(f"RETINA_IMG_SIZE_{task}", str(TASK_DEFAULT_IMG.get(task, 416))))

    env_classes = parse_classes_env(task)
    classes_en = env_classes or TASK_DEFAULT_CLASSES_EN.get(task, ["Negative", "Positive"])
    fa_default = TASK_DEFAULT_CLASSES_FA.get(task, ["منفی", "مثبت"])
    # If an env override changes the class count, fall back to English labels
    # so Persian labels never mis-align with the probability vector.
    if env_classes and len(fa_default) != len(classes_en):
        classes_fa = classes_en
    else:
        classes_fa = fa_default

    weights, all_cands = _pick_weight(task)

    if not weights:
        # No local checkpoint anywhere: expose an unloaded task.
        tm = TaskModel(task, None, device, img_size, classes_fa, classes_en, None, [], [], make_transform(img_size))
        tm._all_weight_candidates = all_cands  # type: ignore
        return tm

    net = build_model(model_name, num_classes=len(classes_en))
    missing, unexpected = load_state(net, weights)
    net.eval().to(device)
    if device == 'cuda':
        # channels_last tends to be faster for conv nets on GPU
        net.to(memory_format=torch.channels_last)

    tm = TaskModel(task, net, device, img_size, classes_fa, classes_en, weights, missing, unexpected, make_transform(img_size))
    tm._all_weight_candidates = all_cands  # type: ignore
    return tm
254
-
255
-
256
def predict_with_task(task_obj: TaskModel, pil_im: Image.Image) -> List[float]:
    """Run one image through the task's model; return softmax probabilities as a list.

    Raises RuntimeError when the task has no loaded model.
    """
    if task_obj.model is None:
        raise RuntimeError("Task is not loaded (no weights).")
    batch = task_obj.transform(pil_im.convert("RGB")).unsqueeze(0)
    batch = batch.to(task_obj.device, non_blocking=True)
    with torch.no_grad():
        logits = task_obj.model(batch)
        return torch.softmax(logits, dim=1)[0].detach().cpu().numpy().tolist()
265
-
266
-
267
- # ---------- Remote proxy helpers (fallback when model is not loaded) ----------
268
- def _remote_base_for(task: str) -> Optional[str]:
269
- return os.getenv(f"RETINA_REMOTE_{task}")
270
-
271
- def _remote_auth_header_for(task: str) -> dict:
272
- token = os.getenv(f"RETINA_REMOTE_AUTH_{task}") or os.getenv("RETINA_REMOTE_AUTH") or ""
273
- return {"Authorization": token} if token.strip() else {}
274
-
275
- def _remote_verify_ssl() -> bool:
276
- v = (os.getenv("RETINA_REMOTE_VERIFY_SSL") or "true").strip().lower()
277
- return v not in ("0", "false", "no")
278
-
279
- def _remote_timeout() -> int:
280
- try:
281
- return int(os.getenv("RETINA_REMOTE_TIMEOUT", "90"))
282
- except Exception:
283
- return 90
284
-
285
- def _remote_url(task: str, mode: str) -> Optional[str]:
286
- """
287
- mode: 'predict' | 'report'
288
- اگر مقدار محیطی مستقیماً به endpoint اشاره کند همان استفاده می‌شود.
289
- اگر base باشد، endpoint استاندارد ساخته می‌شود.
290
- """
291
- base = _remote_base_for(task)
292
- if not base:
293
- return None
294
- base = base.strip()
295
- if base.endswith("/predict_task") or base.endswith("/report_task"):
296
- return base
297
- if mode == "predict":
298
- return f"{base.rstrip('/')}/predict_task?task={task}"
299
- else:
300
- return f"{base.rstrip('/')}/report_task?task={task}"
301
-
302
def _proxy_post(task: str, mode: str, file_bytes: bytes, filename: str, form: Optional[dict] = None) -> JSONResponse:
    # Shared proxy: forward the uploaded image (plus optional form fields) to
    # the configured remote endpoint and relay its response.
    url = _remote_url(task, mode)
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")
    headers = _remote_auth_header_for(task)
    try:
        resp = requests.post(
            url,
            files={"file": (filename, file_bytes, "image/jpeg")},
            data=form,
            headers=headers,
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
        if resp.status_code < 200 or resp.status_code >= 300:
            raise HTTPException(status_code=resp.status_code, detail=f"Remote error: {resp.text}")
        try:
            return JSONResponse(resp.json())
        except Exception:
            # Remote answered non-JSON: pass the raw body through.
            return JSONResponse({"remote_raw": resp.text})
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {e}")

def _proxy_predict_task(task: str, file_bytes: bytes, filename: str = "image.jpg") -> JSONResponse:
    """Forward a prediction request for *task* to its remote fallback."""
    return _proxy_post(task, "predict", file_bytes, filename)

def _proxy_report_task(task: str, file_bytes: bytes, form: dict, filename: str = "image.jpg") -> JSONResponse:
    """Forward a report request (image + patient form fields) to the remote fallback."""
    return _proxy_post(task, "report", file_bytes, filename, form=form)
346
-
347
-
348
# ---------- App ----------
app = FastAPI(title="Retina Multi-Task Inference API (Unified)", version="1.3.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level init: pick device, read task list, eagerly build one TaskModel per task.
_DEVICE = device_setup()
_TASKS = env_list("RETINA_TASKS", DEFAULT_TASKS)
_TASK_MODELS: Dict[str, TaskModel] = {t: prepare_task(t, _DEVICE) for t in _TASKS}
# Task served by the legacy /predict, /predict_json and /report endpoints.
DEFAULT_FALLBACK_TASK = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
360
-
361
-
362
# ---------- Helpers ----------
def _simple_qc(im: Image.Image) -> dict:
    """Cheap quality check: image dimensions plus mean-luminance warnings."""
    try:
        import numpy as np  # lazy import
    except Exception:
        # numpy unavailable: report dimensions only and assume the image is OK.
        w, h = im.size
        return {"width": w, "height": h, "mean_luma": None, "warnings": [], "ok": True}
    w, h = im.size
    mean_luma = float(np.array(im.convert("L")).mean())
    warns: List[str] = []
    if min(w, h) < 512:
        warns.append("low_resolution")
    if mean_luma < 25:
        warns.append("too_dark")
    if mean_luma > 230:
        warns.append("too_bright")
    return {"width": w, "height": h, "mean_luma": round(mean_luma, 1), "warnings": warns, "ok": not warns}

def _items_from_probs(task: str, probs: List[float]):
    """Pair each probability with its labels; return (items sorted desc, top-1 item)."""
    tm = _TASK_MODELS[task]
    items = [
        {"index": i, "class_en": tm.classes_en[i], "class_fa": tm.classes_fa[i], "prob": float(p)}
        for i, p in enumerate(probs)
    ]
    ranked = sorted(items, key=lambda d: d["prob"], reverse=True)
    return ranked, ranked[0]

def _format_report(task: str, probs: List[float], patient_name: str = "", exam_date: str = "", eye: str = "") -> str:
    """Render the Persian free-text report for one prediction."""
    items, top = _items_from_probs(task, probs)
    title_map = {
        "dr": "گزارش رتینوپاتی دیابتی (DR)",
        "oct_cme": "گزارش OCT - CME",
        "oct_csr": "گزارش OCT - CSR",
        "oct_amd": "گزارش OCT - AMD",
        "glaucoma": "گزارش گلوکوم",
        "keratoconus": "گزارش کراتوکونوس",
    }
    title = title_map.get(task, f"گزارش {task}")
    lines: List[str] = [
        f"👁 {title} برای بیمار: {patient_name or '—'}",
        f"📅 تاریخ معاینه: {exam_date or '—'}",
    ]
    if eye:
        lines.append(f"👓 چشم: {eye}")
    lines.append("________________________________________")
    lines.append("📌 نتیجه الگوریتم (Top-1):")
    lines.append(f"{top['class_fa']} ({top['class_en']})احتمال {top['prob']:.3f}")
    lines.append("📊 توزیع احتمالات:")
    for it in items:
        lines.append(f"• {it['class_fa']} ({it['class_en']}) — {it['prob']:.4f}")
    # Task-specific disclaimer footer.
    if task == "dr":
        lines.append("🧠 یادداشت: نتیجه برای کمک به تصمیم‌گیری است؛ در موارد مثبت معاینه بالینی/تصویربرداری تکمیلی توصیه می‌شود.")
    elif task.startswith("oct_"):
        lines.append("🧠 یادداشت: تفسیر نهایی با همبستگی بالینی و تصاویر مکمل.")
    elif task in ("glaucoma", "keratoconus"):
        lines.append("🧠 یادداشت: جایگزین تشخیص پزشک نیست و باید با پاراکلینیک تلفیق شود.")
    return "\n".join(lines)
416
-
417
-
418
# ---------- Pages ----------
@app.get("/", response_class=HTMLResponse)
def root():
    """Landing page: device/task status plus quick multipart upload forms."""
    li = "".join(
        f"<li>{t} — loaded={_TASK_MODELS[t].model is not None} — img={_TASK_MODELS[t].img_size}</li>"
        for t in _TASKS
    )
    return f"""
<html><head><meta charset="utf-8"><title>Retina Unified API</title></head>
<body style="font-family:Tahoma,Arial,sans-serif">
<h2>Retina Multi-Task Predictor (Single Port)</h2>
<p>Device: <b>{_DEVICE}</b> | Tasks: {", ".join(_TASKS)}</p>
<ul>{li}</ul>
<h3>Quick Forms</h3>
<form action="/predict" method="post" enctype="multipart/form-data">
<div><b>Back-compat /predict (RETINA_DEFAULT_TASK = {DEFAULT_FALLBACK_TASK})</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict</button>
</form>
<hr/>
<form action="/predict_task?task=oct_cme" method="post" enctype="multipart/form-data">
<div><b>OCT - CME</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict_task?task=oct_cme</button>
</form>
</body></html>
"""
442
-
443
-
444
# ---------- Meta ----------
@app.get("/tasks")
def tasks():
    """Debug view: per-task load state, classes, weight candidates and remote URL."""
    return {
        t: {
            "loaded": tm.model is not None,
            "img_size": tm.img_size,
            "classes_en": tm.classes_en,
            "classes_fa": tm.classes_fa,
            "weights_used": tm.weights_path,
            "weights_candidates": getattr(tm, "_all_weight_candidates", []),
            "missing_keys": tm.missing_keys,
            "unexpected_keys": tm.unexpected_keys,
            "remote_url": _remote_url(t, "predict"),
        }
        for t, tm in _TASK_MODELS.items()
    }

@app.get("/health")
def health():
    """Liveness endpoint: device, CUDA/cuDNN state and per-task load flags."""
    return {
        "device": _DEVICE,
        "cuda": bool(torch.cuda.is_available()),
        "cudnn_enabled": bool(torch.backends.cudnn.enabled),
        "tasks": list(_TASK_MODELS.keys()),
        "loaded": {t: (_TASK_MODELS[t].model is not None) for t in _TASK_MODELS},
    }
471
-
472
-
473
# ---------- API: multi-task ----------
@app.post("/predict_task")
def predict_task(
    task: str = Query(..., description="dr, oct_cme, oct_csr, oct_amd, glaucoma, keratoconus"),
    file: UploadFile = File(...)
):
    """Classify an uploaded image for *task*; proxies to the remote when unloaded."""
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    tm = _TASK_MODELS[task]

    try:
        payload = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        # No local weights for this task: remote fallback.
        return _proxy_predict_task(task, payload, filename=getattr(file, "filename", "image.jpg"))

    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(image)
    probs = predict_with_task(tm, image)
    ranked, top1 = _items_from_probs(task, probs)
    return JSONResponse({
        "task": task,
        "qc": qc,
        "top1": top1,
        "probs": ranked,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })

@app.post("/report_task")
def report_task(
    task: str = Query(...),
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("")
):
    """Like /predict_task, plus a formatted Persian report with patient details."""
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    tm = _TASK_MODELS[task]

    try:
        payload = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        # Remote fallback: forward the image and the report form fields.
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        return _proxy_report_task(task, payload, form, filename=getattr(file, "filename", "image.jpg"))

    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(image)
    probs = predict_with_task(tm, image)
    ranked, top1 = _items_from_probs(task, probs)
    report_fa = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)

    return JSONResponse({
        "task": task,
        "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
        "qc": qc,
        "top1": top1,
        "probs": ranked,
        "report": report_fa,
        "weights_used": tm.weights_path,
        "weights_candidates": getattr(tm, "_all_weight_candidates", []),
    })
553
-
554
-
555
# ---------- Back-compat: /predict, /predict_json, /report, /predict_strict ----------
class PredictJsonReq(BaseModel):
    # Base64-encoded image payload for the legacy JSON endpoint.
    image_b64: str

def _get_fallback_task() -> str:
    """Return the configured default task, validating that it is actually served."""
    t = DEFAULT_FALLBACK_TASK
    if t not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown default task: {t}")
    return t

@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Legacy endpoint: classify against the default task (RETINA_DEFAULT_TASK)."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]

    try:
        payload = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        # Legacy callers also benefit from the remote fallback.
        return _proxy_predict_task(task, payload, filename=getattr(file, "filename", "image.jpg"))

    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(image)
    probs = predict_with_task(tm, image)
    ranked, top1 = _items_from_probs(task, probs)
    return {"task": task, "qc": qc, "top1": top1, "probs": ranked}

@app.post("/predict_json")
def predict_json(req: PredictJsonReq):
    """Legacy JSON endpoint: base64 image in, default-task prediction out."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]
    try:
        payload = base64.b64decode(req.image_b64)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image")

    if tm.model is None:
        return _proxy_predict_task(task, payload, filename="image.jpg")

    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image data")

    qc = _simple_qc(image)
    probs = predict_with_task(tm, image)
    ranked, top1 = _items_from_probs(task, probs)
    return {"task": task, "qc": qc, "top1": top1, "probs": ranked}

@app.post("/report")
def report(
    file: UploadFile = File(...),
    patient_name: str = Form(""),
    exam_date: str = Form(""),
    eye: str = Form("OD")
):
    """Legacy report endpoint against the default task."""
    task = _get_fallback_task()
    tm = _TASK_MODELS[task]

    try:
        payload = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    if tm.model is None:
        form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
        return _proxy_report_task(task, payload, form, filename=getattr(file, "filename", "image.jpg"))

    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    qc = _simple_qc(image)
    probs = predict_with_task(tm, image)
    ranked, top1 = _items_from_probs(task, probs)
    rep = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
    return {
        "task": task,
        "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
        "qc": qc,
        "top1": top1,
        "probs": ranked,
        "report": rep,
    }

@app.post("/predict_strict")
def predict_strict(file: UploadFile = File(...), tta: int = 1):
    """Compatibility alias: behaves like /predict (the `tta` parameter is ignored)."""
    return predict(file)
 
1
+ #!/usr/bin/env python3
2
+ # Retina/eye multi-task inference API (single-port, Torch-optional)
3
+
4
+ import io
5
+ import os
6
+ import base64
7
+ import glob
8
+ import hashlib
9
+ import tempfile
10
+ from pathlib import Path
11
+ from dataclasses import dataclass
12
+ from typing import Dict, List, Optional, Tuple, Any
13
+
14
+ import requests
15
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.responses import HTMLResponse, JSONResponse
18
+ from pydantic import BaseModel
19
+ from PIL import Image
20
+
21
# -------------------- Torch / Torchvision (optional) --------------------
# This build must import even on hosts without torch installed; the flags
# below tell the rest of the module whether local inference is possible.
TORCH_AVAILABLE = False
_TV_WEIGHTS_ENUM = False
try:
    import torch  # type: ignore
    TORCH_AVAILABLE = True
    try:
        # torchvision is only attempted once torch itself imported cleanly.
        from torchvision import transforms as T  # type: ignore
        from torchvision.models import resnet50, mobilenet_v3_large  # type: ignore
        try:
            from torchvision.models import ResNet50_Weights, MobileNet_V3_Large_Weights  # type: ignore
            _TV_WEIGHTS_ENUM = True
        except Exception:
            # Older torchvision without the Weights-enum API.
            ResNet50_Weights = None  # type: ignore
            MobileNet_V3_Large_Weights = None  # type: ignore
            _TV_WEIGHTS_ENUM = False
    except Exception:
        # torchvision unavailable even though torch imported.
        T = None  # type: ignore
        resnet50 = mobilenet_v3_large = None  # type: ignore
except Exception:
    torch = None  # type: ignore
    T = None  # type: ignore
    resnet50 = mobilenet_v3_large = None  # type: ignore
46
+
47
# -------------------- Defaults per task --------------------
# Tasks served when RETINA_TASKS is not set.
DEFAULT_TASKS = ["dr"]

# Persian display labels per task (index-aligned with the English labels).
TASK_DEFAULT_CLASSES_FA: Dict[str, List[str]] = {
    "dr": ["بدون DR", "خفیف", "متوسط", "شدید", "پرولیفراکتیو"],
    "oct_cme": ["بدون CME", "CME"],
    "oct_csr": ["بدون CSR", "CSR"],
    "oct_amd": ["بدون AMD", "خشک", "تر"],
    "glaucoma": ["نرمال", "گلوکوم"],
    "keratoconus": ["نرمال", "کراتوکونوس"],
}

# English class labels per task.
TASK_DEFAULT_CLASSES_EN: Dict[str, List[str]] = {
    "dr": ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"],
    "oct_cme": ["No CME", "CME"],
    "oct_csr": ["No CSR", "CSR"],
    "oct_amd": ["No AMD", "Dry", "Wet"],
    "glaucoma": ["Normal", "Glaucoma"],
    "keratoconus": ["Normal", "Keratoconus"],
}

# Default input resolution per task.
TASK_DEFAULT_IMG: Dict[str, int] = {
    "dr": 448,
    "oct_cme": 416,
    "oct_csr": 416,
    "oct_amd": 416,
    "glaucoma": 416,
    "keratoconus": 416,
}

# Default backbone per task.
TASK_DEFAULT_MODEL: Dict[str, str] = {
    "dr": "resnet50",
    "oct_cme": "resnet50",
    "oct_csr": "resnet50",
    "oct_amd": "resnet50",
    "glaucoma": "resnet50",
    "keratoconus": "resnet50",
}

# -------------------- Weights: autodiscovery / optional download --------------------
# Container-friendly default; overridable via RETINA_WEIGHTS_DIR.
DEFAULT_WEIGHTS_DIR = os.getenv("RETINA_WEIGHTS_DIR", "/app/models")

# Glob patterns tried per task, most specific first.
WEIGHT_PATTERNS = {
    "dr": ["runs_k80/phase2/best.pth", "dr/*.pth", "*.pth"],
    "oct_cme": ["oct_cme/best.pth", "oct_cme/*.pth", "*.pth"],
    "oct_csr": ["oct_csr/best.pth", "oct_csr/*.pth", "*.pth"],
    "oct_amd": ["oct_amd/best.pth", "oct_amd/*.pth", "*.pth"],
    "glaucoma": ["glaucoma/best.pth", "glaucoma/*.pth", "*.pth"],
    "keratoconus": ["keratoconus/best.pth", "keratoconus/*.pth", "*.pth"],
}
92
+
93
def _find_candidate_weights(task: str) -> List[str]:
    """Glob candidate checkpoints for *task* under DEFAULT_WEIGHTS_DIR, newest first."""
    root = Path(DEFAULT_WEIGHTS_DIR)
    hits: List[str] = []
    for pattern in WEIGHT_PATTERNS.get(task, ["*.pth"]):
        hits.extend(glob.glob(str(root / pattern)))
    # De-duplicate, then order by modification time (most recent wins).
    ordered = sorted(
        set(hits),
        key=lambda f: Path(f).stat().st_mtime if Path(f).exists() else 0,
        reverse=True,
    )
    return [f for f in ordered if Path(f).is_file()]

def _download(url: str, dest: Path, sha256: Optional[str] = None) -> Path:
    """Stream *url* into *dest* atomically, optionally verifying its SHA-256."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        digest = hashlib.sha256()
        # Write to a sibling .part file so failures never leave a torn dest.
        with tempfile.NamedTemporaryFile(delete=False, dir=str(dest.parent), suffix=".part") as tmp:
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    tmp.write(chunk)
                    digest.update(chunk)
        part = Path(tmp.name)
        if sha256 and digest.hexdigest().lower() != sha256.lower():
            part.unlink(missing_ok=True)
            raise RuntimeError(f"SHA256 mismatch for {url}")
        part.replace(dest)
    return dest

def _pick_weight(task: str) -> Tuple[Optional[str], List[str]]:
    """Resolve a checkpoint for *task*: env path, then discovery, then optional URL download.

    Returns (chosen_path_or_None, all_candidates).
    """
    explicit = os.getenv(f"RETINA_WEIGHTS_{task}")
    if explicit and Path(explicit).is_file():
        return explicit, [explicit]
    discovered = _find_candidate_weights(task)
    if discovered:
        return discovered[0], discovered
    url = os.getenv(f"RETINA_WEIGHTS_URL_{task}")
    sha = os.getenv(f"RETINA_WEIGHTS_SHA256_{task}")
    if url:
        dest = Path(DEFAULT_WEIGHTS_DIR) / task / "best.pth"
        try:
            print(f"[weights] downloading {task} from {url} → {dest}")
            got = _download(url, dest, sha256=sha)
            return str(got), [str(got)]
        except Exception as e:
            # Best-effort: a failed download just means "no local weights".
            print(f"[weights] download failed for {task}: {e}")
    return None, []
142
+
143
# -------------------- Utils (Torch-aware) --------------------
def device_setup() -> str:
    """Pick 'cuda' when torch + CUDA are usable (disabling cuDNN for legacy GPUs), else 'cpu'."""
    if TORCH_AVAILABLE and torch.cuda.is_available():  # type: ignore
        torch.backends.cudnn.enabled = False  # type: ignore
        return "cuda"
    return "cpu"

def build_model(name: str, num_classes: int):
    """Construct an untrained backbone with a *num_classes*-way head.

    Raises RuntimeError when torch/torchvision are missing and ValueError for
    an unknown architecture name.
    """
    if not (TORCH_AVAILABLE and resnet50 and mobilenet_v3_large):
        raise RuntimeError("Torch/torchvision not available in this runtime.")
    import torch.nn as nn  # deferred: only valid once torch exists
    arch = name.lower()
    if arch in ("resnet50", "resnet"):
        if _TV_WEIGHTS_ENUM:
            net = resnet50(weights=None)  # type: ignore
        else:
            net = resnet50(pretrained=False)  # type: ignore
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net
    elif arch in ("mobilenetv3", "mobilenet_v3", "mbv3"):
        if _TV_WEIGHTS_ENUM:
            net = mobilenet_v3_large(weights=None)  # type: ignore
        else:
            net = mobilenet_v3_large(pretrained=False)  # type: ignore
        net.classifier[3] = nn.Linear(net.classifier[3].in_features, num_classes)
        return net
    else:
        raise ValueError(f"Unknown model: {name}")

def make_transform(img_size: int):
    """Eval transform (resize → crop → tensor → normalise); identity when torch is absent.

    The identity fallback keeps the TaskModel interface uniform — it is never
    actually invoked without torch, since inference is then proxied remotely.
    """
    if not (TORCH_AVAILABLE and T):
        def _identity(x):
            return x
        return _identity
    return T.Compose([
        T.Resize(int(img_size * 1.15)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])

def load_state(model, weights_path: str):
    """Load a checkpoint (raw state_dict or {'model': ...}) into *model*, non-strict.

    DataParallel 'module.' prefixes are stripped.
    Returns (missing_keys, unexpected_keys) as lists.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("Torch not available for loading state.")
    ckpt = torch.load(weights_path, map_location='cpu')  # type: ignore
    state = ckpt.get("model", ckpt)
    cleaned = {(k[7:] if k.startswith("module.") else k): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    return list(missing), list(unexpected)
196
+
197
@dataclass
class TaskModel:
    """Runtime bundle for one diagnostic task; `model is None` means the task
    is served via the remote proxy instead of a local network."""
    name: str                      # task key, e.g. "dr", "oct_cme"
    model: Optional[Any]           # loaded torch module, or None when unavailable
    device: str                    # "cuda" or "cpu"
    img_size: int                  # square input resolution fed to the network
    classes_fa: List[str]          # Persian class labels (index-aligned with classes_en)
    classes_en: List[str]          # English class labels
    weights_path: Optional[str]    # checkpoint file actually used, if any
    missing_keys: List[str]        # keys absent from the checkpoint (strict=False load)
    unexpected_keys: List[str]     # checkpoint keys the model did not expect
    transform: Any                 # preprocessing callable (PIL image → tensor)
209
+
210
def env_list(key: str, default: Optional[List[str]] = None) -> List[str]:
    """Parse a comma-separated env var into a trimmed list; fall back to
    `default` (or an empty list) when the variable is unset or empty."""
    raw = os.getenv(key)
    if raw:
        return [item.strip() for item in raw.split(",") if item.strip()]
    return default or []
215
+
216
def parse_classes_env(task: str) -> Optional[List[str]]:
    """Read RETINA_CLASSES_<task> as a comma-separated label list.

    Returns None when the variable is unset or contains nothing but
    separators/whitespace.
    """
    raw = os.getenv(f"RETINA_CLASSES_{task}")
    if not raw:
        return None
    labels = [part.strip() for part in raw.split(",") if part.strip()]
    return labels if labels else None
223
+
224
def prepare_task(task: str, device: str) -> TaskModel:
    """Assemble the TaskModel for one task: resolve model/size/classes from
    env (with per-task defaults), locate weights, and load the network when
    torch and a local checkpoint are both available; otherwise register the
    task unloaded so requests fall back to the remote proxy."""
    model_name = os.getenv(f"RETINA_MODEL_{task}", TASK_DEFAULT_MODEL.get(task, "resnet50"))
    img_size = int(os.getenv(f"RETINA_IMG_SIZE_{task}", str(TASK_DEFAULT_IMG.get(task, 416))))

    env_classes = parse_classes_env(task)
    classes_en = env_classes or TASK_DEFAULT_CLASSES_EN.get(task, ["Negative", "Positive"])
    classes_fa = TASK_DEFAULT_CLASSES_FA.get(task, ["منفی", "مثبت"])
    # Env-provided class lists may not line up with the Persian defaults;
    # reuse the English names in that case so indices stay aligned.
    if env_classes and len(classes_fa) != len(classes_en):
        classes_fa = classes_en

    weights, candidates = _pick_weight(task)
    transform = make_transform(img_size)

    if (not TORCH_AVAILABLE) or (not weights) or (not os.path.isfile(weights)):
        tm = TaskModel(task, None, device, img_size, classes_fa, classes_en,
                       weights if weights else None, [], [], transform)
        tm._all_weight_candidates = candidates  # type: ignore
        return tm

    net = build_model(model_name, num_classes=len(classes_en))
    missing, unexpected = load_state(net, weights)
    net.eval().to(device)
    if device == 'cuda':
        net.to(memory_format=torch.channels_last)  # type: ignore

    tm = TaskModel(task, net, device, img_size, classes_fa, classes_en,
                   weights, missing, unexpected, transform)
    tm._all_weight_candidates = candidates  # type: ignore
    return tm
251
+
252
def predict_with_task(task_obj: TaskModel, pil_im: Image.Image) -> List[float]:
    """Return softmax class probabilities for one PIL image under `task_obj`.

    Raises RuntimeError when no local model is loaded for the task.
    """
    if (not TORCH_AVAILABLE) or (task_obj.model is None):
        raise RuntimeError("Local model not available.")
    batch = task_obj.transform(pil_im.convert("RGB")).unsqueeze(0)
    batch = batch.to(task_obj.device, non_blocking=True)
    with torch.no_grad():  # type: ignore
        logits = task_obj.model(batch)
    return torch.softmax(logits, dim=1)[0].detach().cpu().numpy().tolist()  # type: ignore
261
+
262
+ # -------------------- Remote proxy helpers --------------------
263
+ def _remote_base_for(task: str) -> Optional[str]:
264
+ return os.getenv(f"RETINA_REMOTE_{task}")
265
+
266
+ def _remote_auth_header_for(task: str) -> dict:
267
+ token = os.getenv(f"RETINA_REMOTE_AUTH_{task}") or os.getenv("RETINA_REMOTE_AUTH") or ""
268
+ return {"Authorization": token} if token.strip() else {}
269
+
270
+ def _remote_verify_ssl() -> bool:
271
+ v = (os.getenv("RETINA_REMOTE_VERIFY_SSL") or "true").strip().lower()
272
+ return v not in ("0", "false", "no")
273
+
274
+ def _remote_timeout() -> int:
275
+ try:
276
+ return int(os.getenv("RETINA_REMOTE_TIMEOUT", "90"))
277
+ except Exception:
278
+ return 90
279
+
280
def _remote_url(task: str, mode: str) -> Optional[str]:
    """Build the remote endpoint URL for `task`; `mode` is "predict" or "report".

    The base comes from RETINA_REMOTE_<task> and may be a host root or may
    already end in /predict_task or /report_task.

    Fixes over the previous version: a base pinned to one endpoint was
    returned verbatim even when the *other* mode was requested, and without
    the mandatory `task` query parameter (the remote endpoints declare
    `task: str = Query(...)`, so omitting it fails validation).  We now
    always normalise to the endpoint matching `mode` and always append
    `?task=...`.
    """
    base = _remote_base_for(task)
    if not base:
        return None
    endpoint = "predict_task" if mode == "predict" else "report_task"
    base = base.strip().rstrip("/")
    # Strip a pre-pinned endpoint so the correct one for `mode` is used.
    for suffix in ("/predict_task", "/report_task"):
        if base.endswith(suffix):
            base = base[: -len(suffix)]
            break
    return f"{base}/{endpoint}?task={task}"
288
+
289
def _proxy_predict_task(task: str, file_bytes: bytes, filename: str = "image.jpg") -> JSONResponse:
    """Forward a prediction request to the remote backend configured for `task`.

    Raises 501 when no remote is configured, 502 on transport failure, and
    mirrors the remote's own status code on a non-2xx reply.
    """
    url = _remote_url(task, "predict")
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")
    try:
        resp = requests.post(
            url,
            files={"file": (filename, file_bytes, "image/jpeg")},
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {exc}")
    if not (200 <= resp.status_code < 300):
        raise HTTPException(status_code=resp.status_code, detail=f"Remote error: {resp.text}")
    try:
        payload = resp.json()
    except Exception:
        # Remote answered 2xx with a non-JSON body: pass it through raw.
        return JSONResponse({"remote_raw": resp.text})
    return JSONResponse(payload)
310
+
311
def _proxy_report_task(task: str, file_bytes: bytes, form: dict, filename: str = "image.jpg") -> JSONResponse:
    """Forward a report request (image + form metadata) to the remote backend.

    Error semantics match _proxy_predict_task: 501 without a remote, 502 on
    transport failure, remote status mirrored on non-2xx replies.
    """
    url = _remote_url(task, "report")
    if not url:
        raise HTTPException(status_code=501, detail=f"Task '{task}' not loaded and no remote set (RETINA_REMOTE_{task}).")
    try:
        resp = requests.post(
            url,
            files={"file": (filename, file_bytes, "image/jpeg")},
            data=form,
            headers=_remote_auth_header_for(task),
            timeout=_remote_timeout(),
            verify=_remote_verify_ssl(),
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail=f"Remote proxy failed: {exc}")
    if not (200 <= resp.status_code < 300):
        raise HTTPException(status_code=resp.status_code, detail=f"Remote error: {resp.text}")
    try:
        payload = resp.json()
    except Exception:
        return JSONResponse({"remote_raw": resp.text})
    return JSONResponse(payload)
333
+
334
# -------------------- App --------------------
app = FastAPI(title="Retina Multi-Task Inference API (Unified)", version="1.3.1")
# CORS is wide open: the API is meant to be called from browser clients on
# arbitrary origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)

# Global inference state, built once at import time.
_DEVICE = device_setup()  # "cuda" or "cpu"
_TASKS = env_list("RETINA_TASKS", DEFAULT_TASKS)
# One TaskModel per configured task; tasks without torch/weights stay unloaded
# (model=None) and are served through the remote proxy instead.
_TASK_MODELS: Dict[str, TaskModel] = {t: prepare_task(t, _DEVICE) for t in _TASKS}
# Task used by the legacy /predict, /predict_json and /report endpoints.
DEFAULT_FALLBACK_TASK = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
346
+
347
+ # -------------------- Helpers for QC/format --------------------
348
+ def _simple_qc(im: Image.Image) -> dict:
349
+ try:
350
+ import numpy as np # lazy
351
+ except Exception:
352
+ w, h = im.size
353
+ return {"width": w, "height": h, "mean_luma": None, "warnings": [], "ok": True}
354
+ w, h = im.size
355
+ mean_luma = float(np.array(im.convert("L")).mean())
356
+ warns: List[str] = []
357
+ if min(w, h) < 512: warns.append("low_resolution")
358
+ if mean_luma < 25: warns.append("too_dark")
359
+ if mean_luma > 230: warns.append("too_bright")
360
+ return {"width": w, "height": h, "mean_luma": round(mean_luma,1), "warnings": warns, "ok": len(warns)==0}
361
+
362
def _items_from_probs(task: str, probs: List[float]):
    """Pair each probability with its EN/FA class names for `task`.

    Returns (items sorted by descending probability, top-1 item).
    """
    tm = _TASK_MODELS[task]
    ranked = sorted(
        (
            {"index": i,
             "class_en": tm.classes_en[i],
             "class_fa": tm.classes_fa[i],
             "prob": float(p)}
            for i, p in enumerate(probs)
        ),
        key=lambda item: item["prob"],
        reverse=True,
    )
    return ranked, ranked[0]
371
+
372
def _format_report(task: str, probs: List[float], patient_name: str = "", exam_date: str = "", eye: str = "") -> str:
    """Render the Persian narrative report for one task from raw probabilities.

    Includes patient header, top-1 result, the full probability distribution,
    and a task-specific advisory note.  Returns a newline-joined string.
    """
    # NOTE(review): `tm` is never used below — its only effect is a KeyError
    # for an unknown task (callers already validate membership); items come
    # from _items_from_probs.
    tm = _TASK_MODELS[task]
    items, top = _items_from_probs(task, probs)
    # Persian report titles per task; unknown tasks get a generic title.
    title_map = {
        "dr": "گزارش رتینوپاتی دیابتی (DR)",
        "oct_cme": "گزارش OCT - CME",
        "oct_csr": "گزارش OCT - CSR",
        "oct_amd": "گزارش OCT - AMD",
        "glaucoma": "گزارش گلوکوم",
        "keratoconus": "گزارش کراتوکونوس",
    }
    title = title_map.get(task, f"گزارش {task}")
    lines: List[str] = []
    lines.append(f"👁 {title} برای بیمار: {patient_name or '—'}")
    lines.append(f"📅 تاریخ معاینه: {exam_date or '—'}")
    if eye: lines.append(f"👓 چشم: {eye}")
    lines.append("________________________________________")
    lines.append("📌 نتیجه الگوریتم (Top-1):")
    lines.append(f"• {top['class_fa']} ({top['class_en']}) — احتمال {top['prob']:.3f}")
    lines.append("📊 توزیع احتمالات:")
    for it in items:
        lines.append(f" {it['class_fa']} ({it['class_en']}) {it['prob']:.4f}")
    # Task-specific advisory footer (decision-support disclaimer).
    if task == "dr":
        lines.append("🧠 یادداشت: نتیجه برای کمک به تصمیم‌گیری است؛ در موارد مثبت معاینه بالینی/تصویربرداری تکمیلی توصیه می‌شود.")
    elif task.startswith("oct_"):
        lines.append("🧠 یادداشت: تفسیر نهایی با همبستگی بالینی و تصاویر مکمل.")
    elif task in ("glaucoma", "keratoconus"):
        lines.append("🧠 یادداشت: جایگزین تشخیص پزشک نیست و باید با پاراکلینیک تلفیق شود.")
    return "\n".join(lines)
401
+
402
# -------------------- Pages --------------------
@app.get("/", response_class=HTMLResponse)
def root():
    """Human-facing landing page: device/tasks summary plus quick upload forms
    for the legacy /predict endpoint and one /predict_task example."""
    # One <li> per configured task showing load state and input size.
    li = "".join([f"<li>{t} — loaded={_TASK_MODELS[t].model is not None} — img={_TASK_MODELS[t].img_size}</li>" for t in _TASKS])
    return f"""
<html><head><meta charset="utf-8"><title>Retina Unified API</title></head>
<body style="font-family:Tahoma,Arial,sans-serif">
<h2>Retina Multi-Task Predictor (Single Port)</h2>
<p>Device: <b>{_DEVICE}</b> | Tasks: {", ".join(_TASKS)}</p>
<ul>{li}</ul>
<h3>Quick Forms</h3>
<form action="/predict" method="post" enctype="multipart/form-data">
<div><b>Back-compat /predict (RETINA_DEFAULT_TASK = {DEFAULT_FALLBACK_TASK})</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict</button>
</form>
<hr/>
<form action="/predict_task?task=oct_cme" method="post" enctype="multipart/form-data">
<div><b>OCT - CME</b></div>
<input type="file" name="file" accept="image/*" required />
<button type="submit">/predict_task?task=oct_cme</button>
</form>
</body></html>
"""
426
+
427
# -------------------- Meta --------------------
@app.get("/tasks")
def tasks():
    """Per-task status: load state, classes, weight resolution details and the
    remote fallback URL (if configured)."""
    def _describe(name: str, tm) -> dict:
        return {
            "loaded": tm.model is not None,
            "img_size": tm.img_size,
            "classes_en": tm.classes_en,
            "classes_fa": tm.classes_fa,
            "weights_used": tm.weights_path,
            "weights_candidates": getattr(tm, "_all_weight_candidates", []),
            "missing_keys": tm.missing_keys,
            "unexpected_keys": tm.unexpected_keys,
            "remote_url": _remote_url(name, "predict"),
        }
    return {name: _describe(name, tm) for name, tm in _TASK_MODELS.items()}
444
+
445
+ @app.get("/health")
446
+ def health():
447
+ return {
448
+ "device": _DEVICE,
449
+ "cuda": bool(TORCH_AVAILABLE and torch and torch.cuda.is_available()), # type: ignore
450
+ "cudnn_enabled": bool(TORCH_AVAILABLE and torch and torch.backends.cudnn.enabled), # type: ignore
451
+ "tasks": list(_TASK_MODELS.keys()),
452
+ "loaded": {t: (_TASK_MODELS[t].model is not None) for t in _TASK_MODELS},
453
+ }
454
+
455
# -------------------- API: multi-task --------------------
@app.post("/predict_task")
def predict_task(
    task: str = Query(..., description="dr, oct_cme, oct_csr, oct_amd, glaucoma, keratoconus"),
    file: UploadFile = File(...)
):
    """Run one configured task on an uploaded image; proxies to the remote
    backend when no local model is loaded for that task."""
    task = task.strip().lower()
    if task not in _TASK_MODELS:
        raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
    entry = _TASK_MODELS[task]

    try:
        payload = file.file.read()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file")

    # No local weights → forward the raw bytes to the configured remote.
    if entry.model is None:
        return _proxy_predict_task(task, payload, filename=getattr(file, "filename", "image.jpg"))

    try:
        image = Image.open(io.BytesIO(payload))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")

    quality = _simple_qc(image)
    ranked, best = _items_from_probs(task, predict_with_task(entry, image))
    return JSONResponse({
        "task": task,
        "qc": quality,
        "top1": best,
        "probs": ranked,
        "weights_used": entry.weights_path,
        "weights_candidates": getattr(entry, "_all_weight_candidates", []),
    })
490
+
491
+ @app.post("/report_task")
492
+ def report_task(
493
+ task: str = Query(...),
494
+ file: UploadFile = File(...),
495
+ patient_name: str = Form(""),
496
+ exam_date: str = Form(""),
497
+ eye: str = Form("")
498
+ ):
499
+ task = task.strip().lower()
500
+ if task not in _TASK_MODELS:
501
+ raise HTTPException(status_code=404, detail=f"Unknown task: {task}")
502
+ tm = _TASK_MODELS[task]
503
+
504
+ try:
505
+ raw = file.file.read()
506
+ except Exception:
507
+ raise HTTPException(status_code=400, detail="Invalid file")
508
+
509
+ if tm.model is None:
510
+ form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
511
+ return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", "image.jpg"))
512
+
513
+ try:
514
+ im = Image.open(io.BytesIO(raw))
515
+ except Exception:
516
+ raise HTTPException(status_code=400, detail="Invalid image")
517
+
518
+ qc = _simple_qc(im)
519
+ probs = predict_with_task(tm, im)
520
+ items_sorted, top1 = _items_from_probs(task, probs)
521
+ report_fa = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
522
+
523
+ return JSONResponse({
524
+ "task": task,
525
+ "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
526
+ "qc": qc, "top1": top1, "probs": items_sorted,
527
+ "report": report_fa,
528
+ "weights_used": tm.weights_path,
529
+ "weights_candidates": getattr(tm, "_all_weight_candidates", []),
530
+ })
531
+
532
# -------------------- Back-compat --------------------
class PredictJsonReq(BaseModel):
    """Legacy JSON request body for /predict_json."""
    # Base64-encoded image bytes (JPEG/PNG, etc.).
    image_b64: str
535
+
536
def _get_fallback_task() -> str:
    """Resolve RETINA_DEFAULT_TASK (lowercased, default "dr") and verify it is
    among the configured tasks; 404 otherwise."""
    name = os.getenv("RETINA_DEFAULT_TASK", "dr").strip().lower()
    if name in _TASK_MODELS:
        return name
    raise HTTPException(status_code=404, detail=f"Unknown default task: {name}")
541
+
542
+ @app.post("/predict")
543
+ def predict(file: UploadFile = File(...)):
544
+ task = _get_fallback_task()
545
+ tm = _TASK_MODELS[task]
546
+ try:
547
+ raw = file.file.read()
548
+ except Exception:
549
+ raise HTTPException(status_code=400, detail="Invalid file")
550
+
551
+ if tm.model is None:
552
+ return _proxy_predict_task(task, raw, filename=getattr(file, "filename", "image.jpg"))
553
+
554
+ try:
555
+ im = Image.open(io.BytesIO(raw))
556
+ except Exception:
557
+ raise HTTPException(status_code=400, detail="Invalid image")
558
+
559
+ qc = _simple_qc(im)
560
+ probs = predict_with_task(tm, im)
561
+ items_sorted, top1 = _items_from_probs(task, probs)
562
+ return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}
563
+
564
+ @app.post("/predict_json")
565
+ def predict_json(req: PredictJsonReq):
566
+ task = _get_fallback_task()
567
+ tm = _TASK_MODELS[task]
568
+ try:
569
+ data = base64.b64decode(req.image_b64)
570
+ except Exception:
571
+ raise HTTPException(status_code=400, detail="Invalid base64 image")
572
+
573
+ if tm.model is None:
574
+ return _proxy_predict_task(task, data, filename="image.jpg")
575
+
576
+ try:
577
+ im = Image.open(io.BytesIO(data))
578
+ except Exception:
579
+ raise HTTPException(status_code=400, detail="Invalid image data")
580
+
581
+ qc = _simple_qc(im)
582
+ probs = predict_with_task(tm, im)
583
+ items_sorted, top1 = _items_from_probs(task, probs)
584
+ return {"task": task, "qc": qc, "top1": top1, "probs": items_sorted}
585
+
586
+ @app.post("/report")
587
+ def report(
588
+ file: UploadFile = File(...),
589
+ patient_name: str = Form(""),
590
+ exam_date: str = Form(""),
591
+ eye: str = Form("OD")
592
+ ):
593
+ task = _get_fallback_task()
594
+ tm = _TASK_MODELS[task]
595
+ try:
596
+ raw = file.file.read()
597
+ except Exception:
598
+ raise HTTPException(status_code=400, detail="Invalid file")
599
+
600
+ if tm.model is None:
601
+ form = {"patient_name": patient_name, "exam_date": exam_date, "eye": eye}
602
+ return _proxy_report_task(task, raw, form, filename=getattr(file, "filename", "image.jpg"))
603
+
604
+ try:
605
+ im = Image.open(io.BytesIO(raw))
606
+ except Exception:
607
+ raise HTTPException(status_code=400, detail="Invalid image")
608
+
609
+ qc = _simple_qc(im)
610
+ probs = predict_with_task(tm, im)
611
+ items_sorted, top1 = _items_from_probs(task, probs)
612
+ rep = _format_report(task, probs, patient_name=patient_name, exam_date=exam_date, eye=eye)
613
+ return {"task": task,
614
+ "patient": {"name": patient_name, "exam_date": exam_date, "eye": eye},
615
+ "qc": qc, "top1": top1, "probs": items_sorted,
616
+ "report": rep}
617
+
618
+ @app.post("/predict_strict")
619
+ def predict_strict(file: UploadFile = File(...), tta: int = 1):
620
+ return predict(file)