File size: 2,498 Bytes
c80fed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bafe5eb
c80fed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2d87ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import transforms

# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): both paths are resolved against the process working
# directory, not this file's location — start the server from the folder
# that contains them.
WEIGHTS_PATH = "mobilenetv3_gender_weights.pth"
METADATA_PATH = "metadata.json"

# Load metadata. JSON is UTF-8 by spec; without an explicit encoding, open()
# would decode with the platform's locale encoding (e.g. cp1252 on Windows).
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Build the architecture only; trained weights are loaded from WEIGHTS_PATH
# below, so timm must not download pretrained ones.
model = timm.create_model(
    metadata["model_name"],
    pretrained=False,
    num_classes=metadata["num_classes"],
)

# Replace the default head with the Linear -> ReLU -> Dropout -> Linear head
# described by metadata["classifier_config"]. Its layout must match the saved
# state dict exactly, since load_state_dict() below is strict by default.
config = metadata["classifier_config"]
model.classifier = nn.Sequential(
    nn.Linear(config["in_features"], config["hidden_dim"]),
    nn.ReLU(),
    nn.Dropout(config["dropout"]),
    nn.Linear(config["hidden_dim"], metadata["num_classes"]),
)

# weights_only=True limits unpickling to tensors and primitive containers
# instead of arbitrary pickled objects.
state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference behaviour, e.g. Dropout becomes a no-op

# Image preprocessing: fixed 224x224 resize, tensor conversion, then
# per-channel normalization with the standard ImageNet mean/std
# (presumably the same values used in training — confirm).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# --------------------------------------------------
# FastAPI app
# --------------------------------------------------

app = FastAPI(
    title="Gender Classification API",
    description="MobileNetV3 Gender Prediction",
    version="1.0",
)

@app.get("/")
def root():
    """Report that the service is up, plus the loaded model name and labels."""
    info = {"message": "Gender Classification API is running 🚀"}
    info["model"] = metadata["model_name"]
    info["classes"] = metadata["class_names"]
    return info

def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
    """
    try:
        # convert() forces the pixel data to load, so truncated files fail
        # here instead of later inside the transform pipeline.
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError (unrecognised format) subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _classify(image: Image.Image) -> dict:
    """Run the model on one RGB image and build the response payload.

    All values are percentages rounded to two decimals. "probabilities" has
    one entry per name in metadata["class_names"], in model-output order.
    """
    class_names = metadata["class_names"]
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        probs = torch.softmax(model(image_tensor), dim=1)[0]

    confidence, predicted = torch.max(probs, dim=0)
    return {
        "predicted_class": class_names[predicted.item()],
        "confidence": round(confidence.item() * 100, 2),
        # Iterate over every class rather than hard-coding indices 0 and 1,
        # so models with any number of classes are reported correctly.
        "probabilities": {
            name: round(p * 100, 2)
            for name, p in zip(class_names, probs.tolist())
        },
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns the predicted class, its confidence, and every class's
    probability as percentages. Responds with HTTP 400 when the upload is
    not a decodable image.
    """
    image_bytes = await file.read()
    # Decoding and the forward pass are blocking compute; running them inline
    # in this coroutine would stall the event loop for all other requests.
    image = await run_in_threadpool(_decode_image, image_bytes)
    return await run_in_threadpool(_classify, image)