Spaces:
Sleeping
Sleeping
File size: 2,972 Bytes
9207b26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
# --------------------------------------------------
# Deployment configuration (read ONCE at import time)
# --------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHTS_PATH = "mobilenetv3_large_100_deploy.pth"
METADATA_PATH = "metadata.json"

# JSON is UTF-8 by specification; pass the encoding explicitly so that
# non-ASCII class names decode identically regardless of platform locale.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Required keys: a missing one raises KeyError here, at startup.
MODEL_NAME = metadata["model_name"]
NUM_CLASSES = metadata["num_classes"]
CLASS_NAMES = metadata["class_names"]

# Optional keys: fall back to the widely used ImageNet mean/std and a
# 224x224 input when the metadata file does not specify them.
_normalize = metadata.get("normalize", {})
mean = _normalize.get("mean", [0.485, 0.456, 0.406])
std = _normalize.get("std", [0.229, 0.224, 0.225])
input_size = metadata.get("input_size", [224, 224])
# Instantiate the architecture named in the metadata without fetching
# pretrained weights; the trained parameters come from WEIGHTS_PATH below.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)

# weights_only=True restricts unpickling to a safer subset of objects.
ckpt = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)

# Two checkpoint layouts are accepted:
#   * a wrapper dict that stores the weights under "state_dict"
#   * a bare state_dict
if isinstance(ckpt, dict) and "state_dict" in ckpt:
    state_dict = ckpt["state_dict"]
else:
    state_dict = ckpt

model.load_state_dict(state_dict)
model.to(DEVICE)
# Switch to evaluation mode for inference.
model.eval()
# Inference-time preprocessing: resize, convert to tensor, then normalize
# with the mean/std taken from the metadata (must match training).
# NOTE(review): Resize reads a 2-tuple as (height, width) - confirm that
# metadata "input_size" is stored in that order.
_target_size = (input_size[0], input_size[1])
transform = transforms.Compose(
    [
        transforms.Resize(_target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
# --------------------------------------------------
# FastAPI application object
# --------------------------------------------------
# Metadata passed to the FastAPI constructor.
_API_INFO = {
    "title": "Age Group Classification API",
    "description": "MobileNetV3 Age-Group Prediction (A/B/C/D)",
    "version": "1.0",
}
app = FastAPI(**_API_INFO)
@app.get("/")
def root():
    """Return a status message together with the active model configuration."""
    normalization = {"mean": mean, "std": std}
    info = {
        "message": "Age Group Classification API is running",
        "model": MODEL_NAME,
        "classes": CLASS_NAMES,
        "input_size": input_size,
        "normalize": normalization,
    }
    return info
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report per-class probabilities.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A dict with ``predicted_class`` (the top class name), ``confidence``
        (its probability as a percentage rounded to 2 decimals) and
        ``probabilities`` (class name -> percentage for every class).

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as
            an image.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # convert() forces a full decode, so truncated files fail here too.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow's UnidentifiedImageError is a subclass of OSError; report
        # bad uploads as a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # Add a leading batch dimension and move the input to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, dim=1)

    pred_idx = pred.item()

    # Copy the single row of probabilities to Python floats once rather than
    # calling .item() per class.
    class_probs = probs[0].tolist()
    prob_dict = {
        name: round(p * 100, 2)
        for name, p in zip(CLASS_NAMES, class_probs)
    }

    return {
        "predicted_class": CLASS_NAMES[pred_idx],
        "confidence": round(conf.item() * 100, 2),
        "probabilities": prob_dict,
    }
|