# Deployed by Rashidbm — initial deployment (commit 6276d4c)
"""
FastAPI backend for the Saudi Date Classifier.
Serves a custom static site (static/index.html) and a /api/predict endpoint
that runs the ResNet/EfficientNet/ViT ensemble on uploaded images and returns
predictions, confidence breakdown, heritage info, and a Grad-CAM overlay.
Run: python server.py
"""
import base64
from io import BytesIO
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import hf_hub_download
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from src.dataset import get_val_transforms
from src.ensemble import (
CHECKPOINTS,
CLASS_NAMES,
HF_REPO_ID,
PretrainedViTClassifier,
build_efficientnet,
build_resnet50,
load_checkpoint,
)
from src.explainability import reshape_transform
from src.utils import HERITAGE_INFO, get_device, load_config
# --- Paths -----------------------------------------------------------------
ROOT = Path(__file__).parent          # directory containing this server file
STATIC_DIR = ROOT / "static"          # custom frontend (index.html, assets)
RESULTS_DIR = ROOT / "results"        # training artifacts (e.g. tsne.png)

# --- Model startup (runs once at import time) ------------------------------
# Downloads the three checkpoints from the Hugging Face Hub (cached locally
# by hf_hub_download on subsequent runs) and loads them onto the best device.
print("Loading models...")
config = load_config()
device = get_device()
transform = get_val_transforms(config)  # validation-time preprocessing pipeline
paths = {
name: hf_hub_download(repo_id=HF_REPO_ID, filename=fname)
for name, fname in CHECKPOINTS.items()
}
# NOTE(review): num_classes=9 is hard-coded here; presumably len(CLASS_NAMES)
# == 9 — verify they stay in sync if the class list ever changes.
models_dict = {
"resnet": load_checkpoint(build_resnet50(num_classes=9), paths["resnet"], device),
"efficientnet": load_checkpoint(build_efficientnet(num_classes=9), paths["efficientnet"], device),
"vit": load_checkpoint(PretrainedViTClassifier(num_classes=9), paths["vit"], device),
}
print(f"All models loaded on {device}")

# Grad-CAM is wired to the ViT only: hook the layernorm before the last
# encoder block and use reshape_transform to turn token activations back
# into a 2-D spatial map.
vit_target_layer = models_dict["vit"].backbone.vit.encoder.layer[-1].layernorm_before
gradcam = GradCAM(
model=models_dict["vit"],
target_layers=[vit_target_layer],
reshape_transform=reshape_transform,
)

# --- Web app ---------------------------------------------------------------
app = FastAPI(title="Saudi Date Classifier")
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.get("/")
def root():
    """Serve the single-page frontend shell."""
    index_path = STATIC_DIR / "index.html"
    return FileResponse(index_path)
@app.get("/tsne.png")
def tsne_image():
    """Serve the pre-rendered t-SNE plot; 404 until training produces it."""
    tsne_path = RESULTS_DIR / "tsne.png"
    if tsne_path.exists():
        return FileResponse(tsne_path)
    raise HTTPException(status_code=404, detail="t-SNE image not generated yet")
def _encode_png(rgb: np.ndarray) -> str:
    """Encode an RGB image array as a base64 PNG data URI string."""
    buffer = BytesIO()
    Image.fromarray(rgb.astype(np.uint8)).save(buffer, format="PNG")
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return "data:image/png;base64," + payload
@app.post("/api/predict")
async def predict(file: UploadFile = File(...), model: str = Form("ensemble")):
    """Classify an uploaded date image.

    Args:
        file: image upload in any format PIL can decode.
        model: "ensemble" (default, averages softmax over all members) or an
            individual model key ("resnet", "efficientnet", "vit").

    Returns:
        JSON with the top variety, overall confidence, per-class confidences,
        heritage info, and a base64 Grad-CAM overlay (null if Grad-CAM fails).

    Raises:
        HTTPException: 400 for an undecodable image or unknown model name.
    """
    # Guard clause: reject unknown model names before decoding/preprocessing.
    if model != "ensemble" and model not in models_dict:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model}")

    data = await file.read()
    try:
        pil = Image.open(BytesIO(data)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    image = np.array(pil)
    transformed = transform(image=image)
    input_tensor = transformed["image"].unsqueeze(0).to(device)

    # Derive the class count from CLASS_NAMES instead of hard-coding 9, so
    # this endpoint stays correct if the class list ever changes.
    num_classes = len(CLASS_NAMES)
    with torch.no_grad():
        if model == "ensemble":
            # Uniform average of member softmax probabilities.
            probs_sum = torch.zeros(num_classes, device=device)
            for m in models_dict.values():
                probs_sum += F.softmax(m(input_tensor), dim=1)[0]
            probs = probs_sum / len(models_dict)
        else:
            probs = F.softmax(models_dict[model](input_tensor), dim=1)[0]

    confidences = {CLASS_NAMES[i]: float(probs[i].item()) for i in range(num_classes)}
    top_idx = int(probs.argmax().item())
    top_variety = CLASS_NAMES[top_idx]
    top_conf = float(probs[top_idx].item())

    # Grad-CAM is best-effort: predictions are still returned if it fails.
    # NOTE: the overlay always comes from the ViT, regardless of `model`.
    gradcam_b64 = None
    try:
        grayscale_cam = gradcam(input_tensor=input_tensor, targets=None)[0]
        resized_rgb = cv2.resize(image, (224, 224)).astype(np.float32) / 255.0
        cam_image = show_cam_on_image(resized_rgb, grayscale_cam, use_rgb=True)
        gradcam_b64 = _encode_png(cam_image)
    except Exception as e:
        print(f"Grad-CAM failed: {e}")

    return JSONResponse(
        {
            "variety": top_variety,
            "confidence": top_conf,
            "confidences": confidences,
            "heritage": HERITAGE_INFO.get(top_variety, {}),
            "gradcam": gradcam_b64,
            "model": model,
        }
    )
if __name__ == "__main__":
    # Dev entry point: bind all interfaces (reachable from containers/VMs)
    # on port 7864. uvicorn is imported lazily so importing this module for
    # tests/WSGI servers doesn't require it.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7864)