File size: 4,503 Bytes
6276d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
FastAPI backend for the Saudi Date Classifier.

Serves a custom static site (static/index.html) and a /api/predict endpoint
that runs the ResNet/EfficientNet/ViT ensemble on uploaded images and returns
predictions, confidence breakdown, heritage info, and a Grad-CAM overlay.

Run: python server.py
"""

import base64
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from src.dataset import get_val_transforms
from src.ensemble import (
    CHECKPOINTS,
    CLASS_NAMES,
    HF_REPO_ID,
    PretrainedViTClassifier,
    build_efficientnet,
    build_resnet50,
    load_checkpoint,
)
from src.explainability import reshape_transform
from src.utils import HERITAGE_INFO, get_device, load_config

# Filesystem locations, resolved relative to the directory containing this
# file rather than hard-coded.
ROOT = Path(__file__).parent
# Front-end assets: mounted at the /static URL prefix, and index.html is
# returned for "/".
STATIC_DIR = ROOT / "static"
# Generated artifacts; the /tsne.png route reads results/tsne.png from here.
RESULTS_DIR = ROOT / "results"

# Everything below runs at import time: config, preprocessing and all model
# checkpoints are loaded before the app object is created and served.
print("Loading models...")
config = load_config()
device = get_device()
# Validation-time preprocessing. predict() calls it as transform(image=...)
# and reads the "image" entry of the result.
transform = get_val_transforms(config)

# Download every checkpoint listed in CHECKPOINTS (model name -> filename)
# from the Hugging Face Hub repo HF_REPO_ID; values are local file paths.
paths = {
    name: hf_hub_download(repo_id=HF_REPO_ID, filename=fname)
    for name, fname in CHECKPOINTS.items()
}

# The ensemble members. The keys double as the accepted values of the
# /api/predict `model` form field (besides "ensemble").
# NOTE(review): num_classes=9 is hard-coded; presumably it must equal
# len(CLASS_NAMES) -- confirm. load_checkpoint presumably also moves each
# model to `device` and sets eval mode -- verify in src.ensemble.
models_dict = {
    "resnet": load_checkpoint(build_resnet50(num_classes=9), paths["resnet"], device),
    "efficientnet": load_checkpoint(build_efficientnet(num_classes=9), paths["efficientnet"], device),
    "vit": load_checkpoint(PretrainedViTClassifier(num_classes=9), paths["vit"], device),
}
print(f"All models loaded on {device}")

# Grad-CAM is attached only to the ViT member. The target is the
# `layernorm_before` module of its last encoder layer (in Hugging Face ViT,
# the LayerNorm applied before self-attention).
vit_target_layer = models_dict["vit"].backbone.vit.encoder.layer[-1].layernorm_before
# A single shared Grad-CAM instance, reused by every request.
# reshape_transform (src.explainability) presumably turns the ViT token
# sequence into a 2-D spatial map Grad-CAM can use -- verify.
gradcam = GradCAM(
    model=models_dict["vit"],
    target_layers=[vit_target_layer],
    reshape_transform=reshape_transform,
)

# The ASGI application; every route below is registered on it.
app = FastAPI(title="Saudi Date Classifier")
# Serve the contents of static/ under the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


@app.get("/")
def root():
    """Serve the single-page front end (static/index.html)."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)


@app.get("/tsne.png")
def tsne_image():
    """Return the pre-rendered t-SNE plot from results/tsne.png.

    Raises:
        HTTPException: 404 when the plot file has not been generated.
    """
    tsne_path = RESULTS_DIR / "tsne.png"
    if tsne_path.exists():
        return FileResponse(tsne_path)
    raise HTTPException(status_code=404, detail="t-SNE image not generated yet")


def _encode_png(rgb: np.ndarray) -> str:
    """Encode an RGB pixel array as a base64 PNG ``data:`` URI.

    The array is cast to uint8 before encoding, so values are expected to
    already lie in 0..255.
    """
    with BytesIO() as png_buffer:
        Image.fromarray(rgb.astype(np.uint8)).save(png_buffer, format="PNG")
        payload = base64.b64encode(png_buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{payload}"


def _predict_probs(model_name: str, input_tensor: torch.Tensor) -> torch.Tensor:
    """Run inference and return the class probabilities for one image.

    Args:
        model_name: "ensemble" to average the softmax outputs of every model
            in ``models_dict``, or a single key of ``models_dict``. The caller
            is responsible for rejecting any other value.
        input_tensor: Preprocessed batch containing exactly one image.

    Returns:
        1-D tensor with one probability per model output logit.
    """
    with torch.no_grad():
        if model_name == "ensemble":
            # Soft voting: unweighted mean of each member's class probabilities.
            member_probs = [F.softmax(m(input_tensor), dim=1)[0] for m in models_dict.values()]
            return torch.stack(member_probs).mean(dim=0)
        return F.softmax(models_dict[model_name](input_tensor), dim=1)[0]


def _gradcam_overlay(image: np.ndarray, input_tensor: torch.Tensor, class_idx: int):
    """Return a PNG data URI of the ViT Grad-CAM heatmap for ``class_idx``, or None.

    The heatmap always comes from the ViT member, the only model the shared
    ``gradcam`` is attached to, whichever model produced the prediction.
    This is best-effort: any failure is printed and None is returned so the
    prediction itself still succeeds.
    """
    try:
        # Explain the class actually being reported. With targets=None,
        # Grad-CAM explains the ViT's own top class, which can differ from the
        # ensemble's (or another member's) prediction.
        grayscale_cam = gradcam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(class_idx)],
        )[0]
        # Size the background to the CAM actually produced instead of assuming
        # 224x224. cv2.resize takes dsize as (width, height).
        # NOTE(review): if get_val_transforms crops rather than only resizing,
        # this overlay is not pixel-aligned with the model input -- confirm.
        cam_h, cam_w = grayscale_cam.shape[:2]
        resized_rgb = cv2.resize(image, (cam_w, cam_h)).astype(np.float32) / 255.0
        cam_image = show_cam_on_image(resized_rgb, grayscale_cam, use_rgb=True)
        return _encode_png(cam_image)
    except Exception as e:  # deliberate best-effort: never fail the request over Grad-CAM
        print(f"Grad-CAM failed: {e}")
        return None


@app.post("/api/predict")
async def predict(file: UploadFile = File(...), model: str = Form("ensemble")):
    """Classify an uploaded date image.

    Args:
        file: Uploaded image in any format PIL can open; converted to RGB.
        model: "ensemble" (default) or one of the keys of ``models_dict``.

    Returns:
        JSON with the top ``variety`` and its ``confidence``, per-class
        ``confidences``, ``heritage`` info for the top variety (empty dict if
        none), a ``gradcam`` PNG data URI for the reported class (null if
        Grad-CAM failed), and the ``model`` used.

    Raises:
        HTTPException: 400 for an unknown ``model`` or an undecodable image.
    """
    # NOTE(review): everything after the upload is read is synchronous, so it
    # runs on the event loop and requests are handled one at a time. That also
    # keeps the shared `gradcam` (presumably not thread-safe) from being used
    # concurrently; guard it with a lock before moving this to a thread pool.

    # Reject a bad model choice before paying for decoding and preprocessing.
    if model != "ensemble" and model not in models_dict:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    data = await file.read()
    try:
        pil = Image.open(BytesIO(data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    image = np.array(pil)
    input_tensor = transform(image=image)["image"].unsqueeze(0).to(device)

    probs = _predict_probs(model, input_tensor)
    # One entry per model output instead of a hard-coded class count.
    confidences = {name: float(p) for name, p in zip(CLASS_NAMES, probs.tolist())}
    top_idx = int(probs.argmax().item())
    top_variety = CLASS_NAMES[top_idx]
    top_conf = float(probs[top_idx].item())

    return JSONResponse(
        {
            "variety": top_variety,
            "confidence": top_conf,
            "confidences": confidences,
            "heritage": HERITAGE_INFO.get(top_variety, {}),
            "gradcam": _gradcam_overlay(image, input_tensor, top_idx),
            "model": model,
        }
    )


if __name__ == "__main__":
    # Development entry point for `python server.py`: bind every interface on port 7864.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7864)