Spaces:
Sleeping
Sleeping
import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHTS_PATH = "mobilenetv3_large_100_deploy.pth"
METADATA_PATH = "metadata.json"

# metadata.json names the timm architecture, the class labels, and the
# preprocessing constants; everything below is derived from it.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

MODEL_NAME = metadata["model_name"]
NUM_CLASSES = metadata["num_classes"]
CLASS_NAMES = metadata["class_names"]

# The prediction endpoint looks up CLASS_NAMES[i] for each of the NUM_CLASSES
# model outputs. Fail fast at startup instead of raising IndexError on every
# request when the label list is too short.
if len(CLASS_NAMES) < NUM_CLASSES:
    raise ValueError(
        f"{METADATA_PATH}: num_classes={NUM_CLASSES} but only "
        f"{len(CLASS_NAMES)} class_names were provided"
    )

# Normalization constants; fall back to the ImageNet mean/std when absent.
_normalize_cfg = metadata.get("normalize", {})
mean = _normalize_cfg.get("mean", [0.485, 0.456, 0.406])
std = _normalize_cfg.get("std", [0.229, 0.224, 0.225])

# Resize target as [H, W]. Accept a single int (square), [H, W], or a
# timm-style [C, H, W] (only the trailing two spatial dims are kept).
_raw_input_size = metadata.get("input_size", [224, 224])
if isinstance(_raw_input_size, int):
    _raw_input_size = [_raw_input_size, _raw_input_size]
input_size = list(_raw_input_size)[-2:]
if len(input_size) != 2:
    raise ValueError(
        f"{METADATA_PATH}: input_size must be an int, [H, W] or [C, H, W]; "
        f"got {_raw_input_size!r}"
    )
# Build the architecture named in metadata.json with a classifier head sized
# to NUM_CLASSES. pretrained=False: every weight comes from the checkpoint below.
model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=NUM_CLASSES,
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code during load.
# Map to the CPU: load_state_dict copies the values into the model's own
# parameters, and `ckpt` / `state_dict` stay referenced at module scope, so
# mapping them to DEVICE would keep a second copy of the weights in GPU memory.
ckpt = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)

# The training script is documented as saving a dict with keys
# "model_name", "num_classes", "mean", "std", "state_dict"; a bare state_dict
# is accepted as well.
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
model.load_state_dict(state_dict)

# Move to the inference device once and disable dropout / batch-norm updates.
model.to(DEVICE)
model.eval()
# Inference-time preprocessing. Resize and normalization must mirror the
# training pipeline, otherwise the model sees a shifted input distribution.
transform = transforms.Compose(
    [
        transforms.Resize(size=(input_size[0], input_size[1])),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
# Title, description and version appear in the generated OpenAPI schema / docs.
app = FastAPI(
    version="1.0",
    title="Age Group Classification API",
    description="MobileNetV3 Age-Group Prediction (A/B/C/D)",
)
@app.get("/")
def root():
    """Report that the service is up and how it was configured at startup.

    Returns:
        A JSON-serializable dict containing a status message, the timm model
        name, the class labels, the resize target, and the normalization
        constants loaded from metadata.json.
    """
    return {
        "message": "Age Group Classification API is running",
        "model": MODEL_NAME,
        "classes": CLASS_NAMES,
        "input_size": input_size,
        "normalize": {"mean": mean, "std": std},
    }
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of the configured age groups.

    Args:
        file: Multipart upload containing an image in any format PIL decodes.

    Returns:
        A dict with ``predicted_class`` (label from CLASS_NAMES),
        ``confidence`` (top-1 softmax probability in percent, rounded to 2
        decimals) and ``probabilities`` (label -> percent for every class).

    Raises:
        HTTPException: 400 if the upload is empty or is not a decodable image.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # PIL raises UnidentifiedImageError (an OSError subclass) for non-images
    # and OSError for truncated data; both are client errors, not server 500s.
    # The context manager closes the source file object; convert() returns an
    # independent image.
    try:
        with Image.open(io.BytesIO(image_bytes)) as raw_image:
            image = raw_image.convert("RGB")
    except OSError as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # Add the batch dimension expected by the model.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # NOTE(review): inference runs synchronously inside an async handler and
    # blocks the event loop while it executes; consider offloading it to a
    # thread pool if concurrent throughput matters.
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)[0]  # single image in the batch

    conf, pred = torch.max(probs, dim=0)
    pred_idx = pred.item()

    # One device->host transfer for all classes instead of an .item() per class.
    prob_dict = {
        name: round(p * 100, 2)
        for name, p in zip(CLASS_NAMES, probs.tolist())
    }

    return {
        "predicted_class": CLASS_NAMES[pred_idx],
        "confidence": round(conf.item() * 100, 2),
        "probabilities": prob_dict,
    }