File size: 2,500 Bytes
a65c9ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import io
import numpy as np
import torch
from PIL import Image
from fastapi import APIRouter, Depends, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from typing import Optional

from ..state import app_state
from ..utils import normalize_age

from fastapi_limiter.depends import RateLimiter

from app.config import RATE_TIMES, RATE_SECONDS

# Router carrying the /health and /predict endpoints defined below; presumably
# included into the FastAPI app elsewhere (e.g. app.include_router) — confirm.
router = APIRouter()

@router.get("/health")
def health():
    """Report liveness plus basic runtime/model information.

    The response always carries ``status="ok"``, the configured device
    rendered as a string, the ``id2label`` mapping held by ``app_state``, and
    whatever ``app_state.is_model_loaded()`` reports.
    """
    device_name = str(app_state.device)
    label_map = app_state.id2label
    loaded = app_state.is_model_loaded()
    return dict(
        status="ok",
        device=device_name,
        classes=label_map,
        model_loaded=loaded,
    )


@router.post(
    "/predict",
    dependencies=[Depends(RateLimiter(times=RATE_TIMES, seconds=RATE_SECONDS))],
)
async def predict(
    file: UploadFile = File(..., description="RGB lesion image"),
    age: Optional[float] = Form(None),
    localization: Optional[str] = Form("unknown"),
    top_k: Optional[int] = Form(3),
):
    """Classify an uploaded lesion image together with tabular metadata.

    The image is preprocessed with ``app_state.image_processor``. The
    localization is encoded with ``app_state.loc_encoder``, and the age is
    normalized with ``normalize_age`` using ``app_state.age_stats``. The
    tabular features are the encoded localization followed by the normalized
    age. Image and tabular tensors are passed to ``app_state.model``, and the
    softmax over its logits is reduced to the ``top_k`` most probable labels.

    Args:
        file: Uploaded image in any format PIL can open; converted to RGB.
        age: Optional age, forwarded to ``normalize_age`` as-is (may be None).
        localization: Body-site string; stripped and lower-cased, with
            None/empty treated as ``"unknown"``.
        top_k: Number of labels to return. ``None`` means 3; any other value
            is clamped into ``[1, number of classes]``.

    Returns:
        JSONResponse with ``{"top": [{"label": ..., "probability": float}, ...]}``
        ordered by descending probability.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 400 if the image
            cannot be read/decoded or the localization is rejected by the
            encoder.
    """
    if not app_state.is_model_loaded():
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Read and decode the image. Image.open is lazy; convert() forces the
    # decode, so corrupt data is caught here. The `with` closes the source
    # image; the converted copy is independent of it.
    try:
        img_bytes = await file.read()
        with Image.open(io.BytesIO(img_bytes)) as src:
            img = src.convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # NOTE(review): preprocessing and the forward pass below run synchronously
    # inside this async handler, so they block the event loop (and every other
    # request, including /health) for the duration of inference. Offloading to
    # a threadpool would fix that, but would also allow concurrent forwards on
    # the shared model, so check that this is acceptable (GPU memory, model
    # thread-safety) before changing it.
    px = app_state.image_processor(img, return_tensors="pt")["pixel_values"].to(app_state.device)

    # Tabular vector: encoded localization followed by the normalized age.
    loc = (localization or "unknown").strip().lower()
    try:
        loc_oh = app_state.loc_encoder.transform(np.array([loc]).reshape(-1, 1))  # (1, L)
    except ValueError as e:
        # NOTE(review): assumes an sklearn-style encoder that raises ValueError
        # for categories unseen during fit (handle_unknown="error"). Without
        # this guard, bad client input surfaced as a 500.
        raise HTTPException(status_code=400, detail=f"Unsupported localization: {loc!r}") from e
    norm_age = normalize_age(
        age,
        app_state.age_stats["age_min"],
        app_state.age_stats["age_max"],
        app_state.age_stats["age_mean"],
    )
    tab = np.concatenate([loc_oh, np.array([[norm_age]])], axis=1).astype("float32")
    tab_t = torch.tensor(tab, dtype=torch.float32, device=app_state.device)

    # Forward pass; only the first (single) row of the softmax output is used.
    with torch.no_grad():
        logits = app_state.model(pixel_values=px, tabular_features=tab_t)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    # Only an omitted value falls back to the default of 3. An explicit 0 or a
    # negative value is clamped up to 1 rather than silently becoming 3.
    requested = 3 if top_k is None else int(top_k)
    k = max(1, min(requested, len(probs)))
    idxs = np.argsort(-probs)[:k]
    top = [{"label": app_state.id2label[int(i)], "probability": float(probs[i])} for i in idxs]

    return JSONResponse(content={"top": top})