File size: 2,972 Bytes
9207b26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import io
import json

import torch
import torch.nn as nn
import timm

from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms


# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
# Everything in this section runs at import time, so the metadata and the
# weights are read a single time and the resulting objects are shared by
# every request handler in this module.

# Prefer the GPU when torch reports one, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

WEIGHTS_PATH = "mobilenetv3_large_100_deploy.pth"
METADATA_PATH = "metadata.json"

# Decode the metadata explicitly as UTF-8 instead of relying on the
# platform's default text encoding, so non-ASCII class names are read the
# same way on every host.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Required keys: a missing one raises KeyError here, at startup.
MODEL_NAME = metadata["model_name"]
NUM_CLASSES = metadata["num_classes"]
CLASS_NAMES = metadata["class_names"]

# Optional keys: fall back to the hard-coded defaults below when absent.
mean = metadata.get("normalize", {}).get("mean", [0.485, 0.456, 0.406])
std = metadata.get("normalize", {}).get("std", [0.229, 0.224, 0.225])
input_size = metadata.get("input_size", [224, 224])

# Build the architecture named in the metadata with a NUM_CLASSES-way head.
# pretrained=False because the weights come from WEIGHTS_PATH just below.
model = timm.create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=NUM_CLASSES
)

# Load checkpoint safely.
# torch.load weights_only=True restricts unpickling to safer objects.
ckpt = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)

# The checkpoint may be a dict that wraps the weights under "state_dict"
# (alongside other training info) or a bare state_dict; support both.
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt

model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode for layers such as dropout / batch norm

# Preprocessing (must match training normalization)
transform = transforms.Compose([
    transforms.Resize((input_size[0], input_size[1])),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


# --------------------------------------------------
# FastAPI app
# --------------------------------------------------

# Application object that the route decorators in this module register
# against; the keyword arguments are the app's title, description and
# version strings.
app = FastAPI(
    version="1.0",
    title="Age Group Classification API",
    description="MobileNetV3 Age-Group Prediction (A/B/C/D)",
)

# Info endpoint: returns a status message together with the model name,
# class labels, input size and normalization values loaded at startup.
@app.get("/")
def root():
    normalization = {"mean": mean, "std": std}
    service_info = {
        "message": "Age Group Classification API is running",
        "model": MODEL_NAME,
        "classes": CLASS_NAMES,
        "input_size": input_size,
        "normalize": normalization,
    }
    return service_info

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and report per-class probabilities.

    Args:
        file: The uploaded image. Its bytes are decoded with Pillow and
            converted to RGB before the module-level ``transform`` is applied.

    Returns:
        A dict with ``predicted_class`` (name of the highest-probability
        class), ``confidence`` (that class's softmax probability as a
        percentage rounded to two decimals) and ``probabilities`` (class
        name -> percentage for the first ``NUM_CLASSES`` class names).

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded
            as an image (empty, corrupt, unsupported or oversized data).
    """
    image_bytes = await file.read()

    # The upload is untrusted input: report undecodable data as a client
    # error instead of letting the decoder's exception surface as a 500.
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # unsqueeze(0) adds the leading batch dimension for a single image.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # NOTE(review): preprocessing and the forward pass run synchronously
    # inside this async handler — confirm that is acceptable under
    # concurrent load.
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)

    conf, pred = torch.max(probs, dim=1)
    pred_idx = pred.item()

    # build probabilities dict for all classes (not hard-coded to 2)
    prob_dict = {
        CLASS_NAMES[i]: round(probs[0, i].item() * 100, 2)
        for i in range(NUM_CLASSES)
    }

    return {
        "predicted_class": CLASS_NAMES[pred_idx],
        "confidence": round(conf.item() * 100, 2),
        "probabilities": prob_dict
    }