import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from torchvision import transforms

# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

WEIGHTS_PATH = "mobilenetv3_large_100_deploy.pth"
METADATA_PATH = "metadata.json"

# JSON is UTF-8 by specification; pin the encoding so class names containing
# non-ASCII characters are not mis-decoded under a different locale default.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

MODEL_NAME = metadata["model_name"]
NUM_CLASSES = metadata["num_classes"]
CLASS_NAMES = metadata["class_names"]

# Fail fast on inconsistent metadata: with fewer names than classes, every
# prediction would otherwise die with an IndexError at request time.
if len(CLASS_NAMES) < NUM_CLASSES:
    raise ValueError(
        f"{METADATA_PATH}: num_classes={NUM_CLASSES} but only "
        f"{len(CLASS_NAMES)} class_names were provided"
    )

# Normalisation statistics and input resolution are optional in the metadata;
# the fallbacks are the commonly used ImageNet mean/std and a 224x224 input.
_normalize_cfg = metadata.get("normalize", {})
mean = _normalize_cfg.get("mean", [0.485, 0.456, 0.406])
std = _normalize_cfg.get("std", [0.229, 0.224, 0.225])
input_size = metadata.get("input_size", [224, 224])

# Build the model the same way as in training: a timm architecture with a
# NUM_CLASSES-way classification head. pretrained=False because the weights
# come from the local checkpoint rather than a download.
model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# Load checkpoint safely.
# torch.load(..., weights_only=True) restricts unpickling to tensors,
# primitive types and plain containers, so a tampered checkpoint file cannot
# execute arbitrary code while being loaded.
ckpt = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)

# The checkpoint may be a wrapper dict of the form
#   {"model_name": ..., "num_classes": ..., "mean": ..., "std": ..., "state_dict": ...}
# or a bare state_dict; support both layouts.
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout / batch norm

# Preprocessing (must match training normalization)
transform = transforms.Compose([
    transforms.Resize((input_size[0], input_size[1])),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
app = FastAPI(
    title="Age Group Classification API",
    description="MobileNetV3 Age-Group Prediction (A/B/C/D)",
    version="1.0"
)


@app.get("/")
def root():
    """Report that the service is up, together with the loaded model's configuration."""
    return {
        "message": "Age Group Classification API is running",
        "model": MODEL_NAME,
        "classes": CLASS_NAMES,
        "input_size": input_size,
        "normalize": {"mean": mean, "std": std},
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A dict with the predicted class name, its confidence, and the
        probability of every class; all values are percentages rounded to
        two decimals.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image (empty, corrupt, truncated or unsupported file).
    """
    # Needed only for the error path below; imported here so the endpoint is
    # self-contained with respect to the module's import list.
    from fastapi import HTTPException

    image_bytes = await file.read()

    # A bad upload is a client error: answer 400 instead of letting the PIL
    # exception bubble up as an opaque 500.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from err

    # Add the batch dimension the model expects and move to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # NOTE(review): inference runs synchronously inside this coroutine, so it
    # occupies the event loop for the duration of the forward pass.
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)

    # One device-to-host transfer for the whole row instead of one per class.
    class_probs = probs[0].tolist()
    pred_idx = int(torch.argmax(probs, dim=1).item())

    # Build the probabilities dict for every class the model outputs.
    prob_dict = {
        CLASS_NAMES[i]: round(p * 100, 2)
        for i, p in enumerate(class_probs)
    }

    return {
        "predicted_class": CLASS_NAMES[pred_idx],
        "confidence": round(class_probs[pred_idx] * 100, 2),
        "probabilities": prob_dict,
    }