"""Gender classification inference service.

Builds a timm classifier described by ``metadata.json``, loads its weights
from ``mobilenetv3_gender_weights.pth`` once at import time, and serves
predictions through a FastAPI app.
"""

import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import transforms

# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WEIGHTS_PATH = "mobilenetv3_gender_weights.pth"
METADATA_PATH = "metadata.json"

# Load metadata: model name, number/names of classes, classifier-head config.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)

# Build the backbone; weights come from WEIGHTS_PATH below, not timm's hub.
model = timm.create_model(
    metadata["model_name"],
    pretrained=False,
    num_classes=metadata["num_classes"],
)

# Rebuild the custom classifier head so its parameter names and shapes match
# the saved state_dict.
config = metadata["classifier_config"]
model.classifier = nn.Sequential(
    nn.Linear(config["in_features"], config["hidden_dim"]),
    nn.ReLU(),
    nn.Dropout(config["dropout"]),
    nn.Linear(config["hidden_dim"], metadata["num_classes"]),
)

# Load weights safely: weights_only=True restricts unpickling to tensors and
# plain containers instead of arbitrary Python objects.
state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode for Dropout (and any norm layers)

# Image preprocessing: fixed 224x224 resize, then ImageNet mean/std
# normalization. NOTE(review): presumably matches the training pipeline —
# confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
app = FastAPI(
    title="Gender Classification API",
    description="MobileNetV3 Gender Prediction",
    version="1.0",
)


@app.get("/")
def root():
    """Health/info endpoint: report the loaded model name and class labels."""
    return {
        "message": "Gender Classification API is running 🚀",
        "model": metadata["model_name"],
        "classes": metadata["class_names"],
    }


def _predict_bytes(image_bytes: bytes) -> dict:
    """Decode an uploaded image, run the model, and build the response payload.

    This is synchronous and blocking (image decoding plus model forward pass),
    so the async endpoint runs it in a worker thread.

    Args:
        image_bytes: Raw bytes of the uploaded file.

    Returns:
        A dict with ``predicted_class``, ``confidence`` (percent, 2 decimals)
        and ``probabilities`` (class name -> percent, 2 decimals).

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # Image.open is lazy; convert() forces the actual decode, so both
        # must be inside the try. The context manager closes the source image.
        with Image.open(io.BytesIO(image_bytes)) as img:
            image = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers UnidentifiedImageError (not an image / empty upload)
        # and truncated-image errors; both are client errors, not 500s.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    image_tensor = transform(image).unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    class_names = metadata["class_names"]
    return {
        "predicted_class": class_names[predicted.item()],
        "confidence": round(confidence.item() * 100, 2),
        # One entry per class instead of hard-coding exactly two classes.
        "probabilities": {
            name: round(p * 100, 2)
            for name, p in zip(class_names, probs[0].tolist())
        },
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns the predicted class, its confidence and the per-class
    probabilities, all as percentages. Responds with HTTP 400 when the upload
    is not a decodable image.
    """
    image_bytes = await file.read()
    # Decoding and inference block; run them in the thread pool so the event
    # loop keeps serving other requests meanwhile.
    return await run_in_threadpool(_predict_bytes, image_bytes)