Spaces:
Sleeping
Sleeping
File size: 2,498 Bytes
c80fed4 bafe5eb c80fed4 e2d87ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both paths are relative, so they are resolved against the working directory
# of the process at import time.
WEIGHTS_PATH = "mobilenetv3_gender_weights.pth"
METADATA_PATH = "metadata.json"

# Load metadata
# The encoding is given explicitly so that parsing the JSON file does not
# depend on the platform's locale default.
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)
# Build model
# Instantiate the architecture named in the metadata without downloading
# pretrained weights; the trained parameters are loaded from WEIGHTS_PATH
# further down.
model = timm.create_model(
    metadata["model_name"],
    num_classes=metadata["num_classes"],
    pretrained=False,
)

# Rebuild classifier
# Replace the head with the stack described by metadata["classifier_config"]:
# Linear -> ReLU -> Dropout -> Linear.
config = metadata["classifier_config"]
model.classifier = nn.Sequential(
    nn.Linear(config["in_features"], config["hidden_dim"]),
    nn.ReLU(),
    nn.Dropout(config["dropout"]),
    nn.Linear(config["hidden_dim"], metadata["num_classes"]),
)

# Load weights safely
# Tensors are mapped straight onto DEVICE while the checkpoint is read.
state_dict = torch.load(WEIGHTS_PATH, weights_only=True, map_location=DEVICE)
model.load_state_dict(state_dict)

# Move the parameters to DEVICE and switch the module to evaluation mode.
model.to(DEVICE).eval()
# Image preprocessing
# Per-channel normalisation constants applied after ToTensor().
# NOTE(review): these look like the usual ImageNet statistics; confirm they
# match what the model was trained with.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to 224x224, convert to a tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
# Application object that the route decorators in this file register against.
app = FastAPI(
    version="1.0",
    title="Gender Classification API",
    description="MobileNetV3 Gender Prediction",
)
@app.get("/")
def root():
    """Return a small status payload describing the running service.

    Returns:
        A dict holding a fixed status message plus the model name and the
        class names read from the loaded metadata.
    """
    status = {"message": "Gender Classification API is running 🚀"}
    status["model"] = metadata["model_name"]
    status["classes"] = metadata["class_names"]
    return status
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the per-class probabilities.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``predicted_class`` (name of the top-scoring class),
        ``confidence`` (its probability as a percentage, rounded to two
        decimals) and ``probabilities`` (class name -> percentage for every
        class listed in the metadata).

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # Image.open() only reads the header; convert() forces the full
        # decode, so truncated or corrupt files fail here instead of later.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError is an OSError subclass, so unknown formats
        # and broken files both end up here. Report a client error rather than
        # letting the exception surface as an HTTP 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # unsqueeze(0) adds the leading batch dimension for a single image.
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # NOTE(review): the forward pass runs synchronously inside this coroutine,
    # so the event loop is occupied while it executes. Consider offloading it
    # if concurrent throughput matters.
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    class_names = metadata["class_names"]
    # Report every class rather than hard-coding indices 0 and 1, so the
    # response stays correct for any number of classes.
    class_probs = probs[0].tolist()

    return {
        "predicted_class": class_names[predicted.item()],
        "confidence": round(confidence.item() * 100, 2),
        "probabilities": {
            name: round(prob * 100, 2)
            for name, prob in zip(class_names, class_probs)
        },
    }