Spaces:
Sleeping
Sleeping
import io
import json

import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import transforms
# --------------------------------------------------
# Load model ONCE at startup
# --------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Both paths are relative, so they resolve against the process's working
# directory, not against this file's location.
WEIGHTS_PATH = "mobilenetv3_gender_weights.pth"
METADATA_PATH = "metadata.json"

# Load metadata. Keys read elsewhere in this module: "model_name",
# "num_classes", "classifier_config" (in_features / hidden_dim / dropout)
# and "class_names".
# JSON text is UTF-8 by spec, so decode it explicitly rather than relying on
# the platform's locale default (e.g. cp1252 on Windows would garble
# non-ASCII class names).
with open(METADATA_PATH, "r", encoding="utf-8") as f:
    metadata = json.load(f)
# Build the backbone from timm. pretrained=False: all weights come from the
# local checkpoint loaded below, so nothing is downloaded.
model = timm.create_model(
    metadata["model_name"],
    pretrained=False,
    num_classes=metadata["num_classes"],
)

# Replace the default head with the MLP classifier described in metadata.
# NOTE(review): presumably this mirrors the head used at training time; the
# strict load_state_dict below will fail loudly if the keys/shapes differ.
config = metadata["classifier_config"]
model.classifier = nn.Sequential(
    nn.Linear(config["in_features"], config["hidden_dim"]),
    nn.ReLU(),
    nn.Dropout(config["dropout"]),
    nn.Linear(config["hidden_dim"], metadata["num_classes"]),
)

# Load weights safely: weights_only=True restricts unpickling to tensors and
# plain containers instead of arbitrary Python objects.
# Deserialize on CPU and copy into the (still CPU-resident) model. Then drop
# the temporary dict, so only one copy of the weights stays alive for the
# lifetime of the process, and move the model to DEVICE once.
state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
del state_dict
model.to(DEVICE)
model.eval()  # inference mode for layers such as Dropout
# Preprocessing applied to every uploaded image before the forward pass:
# stretch to a fixed 224x224 input, convert to a float CHW tensor, then
# normalise each channel with the standard ImageNet mean/std values.
# NOTE(review): these steps presumably match the training-time pipeline —
# confirm before changing any of them.
_preprocessing_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
# ---------------------------------------------------------------------------
# HTTP application object; the OpenAPI title/description/version shown in the
# generated docs come from these arguments.
# ---------------------------------------------------------------------------
app = FastAPI(
    version="1.0",
    title="Gender Classification API",
    description="MobileNetV3 Gender Prediction",
)
# NOTE(review): no route decorator was present, so this handler was never
# registered with the app. "/" is the assumed path — confirm against clients.
@app.get("/")
def root():
    """Index endpoint: report that the service is up, plus model details.

    Returns:
        dict with a status message, the timm model name and the list of
        class labels, the latter two read from the loaded metadata.
    """
    return {
        # The emoji was mojibake ("๐") in the original literal; 🚀 is the
        # assumed intended character.
        "message": "Gender Classification API is running 🚀",
        "model": metadata["model_name"],
        "classes": metadata["class_names"],
    }
def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes are not a decodable image.
    """
    try:
        # Image.open is lazy; convert() forces the full decode, so corrupt or
        # truncated data fails here rather than later inside the transform.
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers PIL's UnidentifiedImageError (non-image bytes, empty
        # upload) and truncated-file errors raised during decoding.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _classify(image_bytes: bytes) -> dict:
    """Decode the image, run the model, and build the JSON response body.

    This is blocking CPU/GPU work and is meant to run off the event loop.
    """
    image = _decode_image(image_bytes)
    image_tensor = transform(image).unsqueeze(0).to(DEVICE)  # add batch dim

    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    class_names = metadata["class_names"]
    return {
        "predicted_class": class_names[predicted.item()],
        # Percentages in [0, 100], rounded to two decimals.
        "confidence": round(confidence.item() * 100, 2),
        # One entry per class, in metadata order. For a two-class model this
        # yields exactly the previous [0]/[1] entries; it also stays correct
        # for any other class count.
        "probabilities": {
            name: round(p * 100, 2)
            for name, p in zip(class_names, probs[0].tolist())
        },
    }


# NOTE(review): no route decorator was present, so this handler was never
# registered with the app. "/predict" is the assumed path — confirm.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        dict with "predicted_class", "confidence" (percent) and
        "probabilities" (class name -> percent).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_bytes = await file.read()
    # Image decoding and the forward pass are CPU-bound. Run them in the
    # worker threadpool so one request does not stall the event loop for
    # every other client.
    return await run_in_threadpool(_classify, image_bytes)