AgePredict / app.py
n0v33n
Add initial commit
9207b26
import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both paths are relative, so they resolve against the process working directory.
WEIGHTS_PATH = "mobilenetv3_large_100_deploy.pth"
METADATA_PATH = "metadata.json"

# JSON text is UTF-8; pin the encoding instead of inheriting the platform
# default, which differs between operating systems and locales.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Required keys: a missing one raises KeyError here, at import time.
MODEL_NAME = metadata["model_name"]
NUM_CLASSES = metadata["num_classes"]
CLASS_NAMES = metadata["class_names"]

# /predict indexes CLASS_NAMES for every i in range(NUM_CLASSES); fail at
# startup rather than with an IndexError on the first request.
if len(CLASS_NAMES) < NUM_CLASSES:
    raise ValueError(
        f"{METADATA_PATH}: 'class_names' has {len(CLASS_NAMES)} entries "
        f"but 'num_classes' is {NUM_CLASSES}"
    )

# Optional keys. The fallback mean/std are the commonly used ImageNet
# statistics; the fallback input size is 224x224.
_normalize = metadata.get("normalize", {})
mean = _normalize.get("mean", [0.485, 0.456, 0.406])
std = _normalize.get("std", [0.229, 0.224, 0.225])
input_size = metadata.get("input_size", [224, 224])
# Instantiate the architecture named in the metadata. pretrained=False: the
# trained weights come from the local checkpoint loaded below, and the
# classifier head is sized by NUM_CLASSES.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)

# weights_only=True restricts unpickling to a safer subset of objects, and
# map_location places the loaded tensors directly on DEVICE.
ckpt = torch.load(WEIGHTS_PATH, weights_only=True, map_location=DEVICE)

# Two checkpoint layouts are accepted: a wrapper dict that keeps the weights
# under "state_dict" (alongside fields such as model_name / num_classes /
# mean / std), or a bare state dict.
if isinstance(ckpt, dict) and "state_dict" in ckpt:
    state_dict = ckpt["state_dict"]
else:
    state_dict = ckpt

model.load_state_dict(state_dict)
# Move to the target device and switch to inference mode.
model.to(DEVICE).eval()
# Preprocessing applied to every uploaded image: resize to the configured
# input size, convert to a tensor, then normalize with the configured
# mean/std (these must match the values used during training).
_resize_to = (input_size[0], input_size[1])
transform = transforms.Compose(
    [
        transforms.Resize(_resize_to),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
# Application object the route decorators below attach to; title,
# description and version are passed through as API metadata.
app = FastAPI(
    version="1.0",
    title="Age Group Classification API",
    description="MobileNetV3 Age-Group Prediction (A/B/C/D)",
)
# Status endpoint: confirms the service is up and echoes the configuration the
# model was loaded with (model name, class names, input size, normalization).
# Documented with a comment rather than a docstring so the generated OpenAPI
# description for this route stays unchanged.
@app.get("/")
def root():
    normalization = {"mean": mean, "std": std}
    status = {
        "message": "Age Group Classification API is running",
        "model": MODEL_NAME,
        "classes": CLASS_NAMES,
        "input_size": input_size,
        "normalize": normalization,
    }
    return status
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of the configured classes.

    Args:
        file: The uploaded image, sent as the multipart form field ``file``.

    Returns:
        A dict with ``predicted_class`` (name of the most probable class),
        ``confidence`` (that class's probability as a percentage, rounded to
        two decimals) and ``probabilities`` (percentage for each class name).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()

    # Empty, truncated or non-image uploads make Pillow raise while opening or
    # decoding; report that as a client error instead of an opaque 500.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from err

    # Add a leading batch dimension of 1 and move the input to the model's device.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # NOTE(review): the forward pass runs synchronously inside this async
    # handler, so the event loop is blocked while it executes.
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, dim=1)

    pred_idx = pred.item()

    # Copy the single row of probabilities to Python floats in one call rather
    # than one .item() call per class.
    class_probs = probs[0].tolist()
    prob_dict = {
        CLASS_NAMES[i]: round(class_probs[i] * 100, 2)
        for i in range(NUM_CLASSES)
    }

    return {
        "predicted_class": CLASS_NAMES[pred_idx],
        "confidence": round(conf.item() * 100, 2),
        "probabilities": prob_dict,
    }