# GenderPredict / app.py
# n0v33n
# update changes
# bafe5eb
import io
import json
from pathlib import Path

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import transforms
# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve model artifacts relative to this file instead of the process's
# current working directory, so the app starts no matter where the server
# is launched from.
BASE_DIR = Path(__file__).resolve().parent
WEIGHTS_PATH = BASE_DIR / "mobilenetv3_gender_weights.pth"
METADATA_PATH = BASE_DIR / "metadata.json"

# Load metadata: model name, class names, number of classes and the
# classifier-head configuration used to rebuild the network below.
# Explicit UTF-8 so decoding does not depend on the host locale.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)
# Recreate the network architecture without downloading pretrained weights;
# the trained parameters are restored from WEIGHTS_PATH further down.
model = timm.create_model(
    metadata["model_name"],
    num_classes=metadata["num_classes"],
    pretrained=False,
)

# Replace the default head with the custom MLP head described in
# metadata["classifier_config"]. Its layer layout must match the saved
# weights, because load_state_dict() below runs with its default strict=True.
config = metadata["classifier_config"]
model.classifier = nn.Sequential(
    nn.Linear(in_features=config["in_features"], out_features=config["hidden_dim"]),
    nn.ReLU(),
    nn.Dropout(p=config["dropout"]),
    nn.Linear(in_features=config["hidden_dim"], out_features=metadata["num_classes"]),
)

# weights_only=True restricts unpickling to tensors/primitive containers,
# so a tampered checkpoint cannot execute arbitrary code on load.
state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)

# Move to the target device and switch to inference mode (disables dropout).
model.to(DEVICE).eval()
# Preprocessing applied to every uploaded image: resize to a fixed 224x224
# (aspect ratio is not preserved), convert to a float tensor, then normalise
# per channel. The mean/std values are the commonly used ImageNet statistics;
# presumably they match the training pipeline — confirm if retraining.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
# Application object; title/description/version populate the generated
# OpenAPI schema and the interactive docs page.
app = FastAPI(
    version="1.0",
    title="Gender Classification API",
    description="MobileNetV3 Gender Prediction",
)
@app.get("/")
def root():
    """Report that the service is up, together with basic model information.

    Returns:
        dict with a status ``message``, the timm ``model`` name and the
        ``classes`` list, both read from the loaded metadata.
    """
    return {
        "message": "Gender Classification API is running 🚀",
        "model": metadata["model_name"],
        "classes": metadata["class_names"],
    }
def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image
            (unknown format, truncated data, empty upload, decompression bomb).
    """
    try:
        # Image.open is lazy; convert() forces the full decode so corrupt
        # data fails here rather than later inside the transform pipeline.
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError is an OSError subclass, so it is covered too.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _class_probabilities(image: Image.Image) -> torch.Tensor:
    """Run the model on a single image and return a 1-D softmax probability row."""
    # unsqueeze(0) adds the batch dimension for this single image.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(image_tensor)
    return torch.softmax(logits, dim=1)[0]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: multipart upload containing the image to classify.

    Returns:
        dict with ``predicted_class`` (arg-max class name), ``confidence``
        (its probability as a percentage rounded to 2 decimals) and
        ``probabilities`` mapping every class name to its percentage.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()

    # Decoding and inference are blocking, compute-heavy calls; running them
    # in the threadpool keeps the event loop free to serve other requests.
    image = await run_in_threadpool(_decode_image, image_bytes)
    probs = await run_in_threadpool(_class_probabilities, image)

    confidence, predicted = torch.max(probs, dim=0)
    class_names = metadata["class_names"]
    return {
        "predicted_class": class_names[predicted.item()],
        "confidence": round(confidence.item() * 100, 2),
        # One entry per class rather than hard-coded indices 0 and 1, so the
        # response stays correct for models with any number of classes.
        "probabilities": {
            name: round(prob * 100, 2)
            for name, prob in zip(class_names, probs.tolist())
        },
    }