Spaces:
Sleeping
Sleeping
| import io | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
# --- App setup ----------------------------------------------------------------
app = FastAPI(version="1.0.0", title="ISL Recognition API")

# Accept cross-origin browser requests for every method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins to the deployed (Vercel) front-end URL
# before production use.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# --- Model loader -------------------------------------------------------------
def build_model(arch: str, num_classes: int) -> nn.Module:
    """Create an untrained network whose output layer has ``num_classes`` units.

    ``arch`` is matched case-insensitively against ``resnet18``,
    ``mobilenet_v2``, ``efficientnet_b0``, ``vgg16``, ``cnn`` and
    ``cnn_dropout``. Torchvision backbones are built with ``weights=None``.
    Only the final layer is replaced. The module layout therefore determines
    the parameter names a later ``load_state_dict`` call expects.

    Raises:
        ValueError: if ``arch`` names none of the supported architectures.
    """
    key = arch.lower()

    def _resnet18():
        net = models.resnet18(weights=None)
        net.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(net.fc.in_features, num_classes),
        )
        return net

    def _swap_head(net, slot):
        # Replace one entry of ``net.classifier`` with a fresh num_classes-way Linear.
        net.classifier[slot] = nn.Linear(net.classifier[slot].in_features, num_classes)
        return net

    def _custom_cnn(dropout):
        class _CNN(nn.Module):
            """Simple custom CNN: four conv/BN/ReLU/max-pool stages, global pool, linear head."""

            def __init__(self, n, dropout=False):
                super().__init__()
                widths = (3, 32, 64, 128, 256)
                stages = []
                for c_in, c_out in zip(widths[:-1], widths[1:]):
                    stages.extend([
                        nn.Conv2d(c_in, c_out, 3, padding=1),
                        nn.BatchNorm2d(c_out),
                        nn.ReLU(True),
                        nn.MaxPool2d(2),
                    ])
                self.features = nn.Sequential(*stages)
                head = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
                if dropout:
                    head.append(nn.Dropout(0.5))
                head.append(nn.Linear(widths[-1], n))
                self.classifier = nn.Sequential(*head)

            def forward(self, x):
                return self.classifier(self.features(x))

        return _CNN(num_classes, dropout=dropout)

    # Builders are zero-argument callables so only the requested network is constructed.
    builders = {
        "resnet18": _resnet18,
        "mobilenet_v2": lambda: _swap_head(models.mobilenet_v2(weights=None), 1),
        "efficientnet_b0": lambda: _swap_head(models.efficientnet_b0(weights=None), 1),
        "vgg16": lambda: _swap_head(models.vgg16(weights=None), 6),
        "cnn": lambda: _custom_cnn(False),
        "cnn_dropout": lambda: _custom_cnn(True),
    }
    builder = builders.get(key)
    if builder is None:
        raise ValueError(f"Unknown architecture: {key}")
    return builder()
# --- Load checkpoint on startup -----------------------------------------------
# The checkpoint dict must provide "arch", "num_classes", "class_names" and
# "state_dict"; "val_acc" is optional and is only reported.
MODEL_PATH = "isl_best_model.pth"
device = torch.device("cpu")

# NOTE(review): depending on the PyTorch version, torch.load may unpickle
# arbitrary objects; only point MODEL_PATH at a trusted checkpoint.
checkpoint = torch.load(MODEL_PATH, map_location=device)
ARCH, NUM_CLASSES, CLASS_NAMES = (
    checkpoint["arch"],
    checkpoint["num_classes"],
    checkpoint["class_names"],
)

# Rebuild the network skeleton, then restore the trained weights into it.
model = build_model(ARCH, NUM_CLASSES)
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # inference mode: dropout off, batch-norm uses running stats
model.to(device)
print(f"β Loaded model: {ARCH} | Classes: {NUM_CLASSES} | Val Acc: {checkpoint.get('val_acc', 'N/A')}")
# --- Inference transform (matches val_transform in notebook) -------------------
# Resize to a fixed 224x224 input, convert the PIL image to a float tensor, then
# normalise each RGB channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# --- Routes -------------------------------------------------------------------
# NOTE(review): no @app.get/@app.post decorators are visible on the handlers in
# this copy, so FastAPI would not register them as endpoints; confirm the
# decorators and their paths against the deployed app.
def root():
    """Report that the service is up, plus the loaded model's name, class count and val accuracy."""
    return dict(
        message="ISL Recognition API is running π€",
        model=ARCH,
        num_classes=NUM_CLASSES,
        val_acc=checkpoint.get("val_acc"),
    )
def health():
    """Return a constant ``{"status": "ok"}`` payload."""
    status = "ok"
    return dict(status=status)
def get_classes():
    """Return the checkpoint's class names under the ``classes`` key."""
    return dict(classes=CLASS_NAMES)
# NOTE(review): as with the other handlers, no route decorators are visible on
# predict/live_predict in this copy; confirm their registration and paths.

# Content types accepted by both image-upload handlers.
_ALLOWED_IMAGE_TYPES = frozenset({"image/jpeg", "image/png", "image/jpg", "image/webp"})


async def _read_rgb_image(file: UploadFile) -> Image.Image:
    """Validate an upload's declared content type and decode it into an RGB image.

    Raises:
        HTTPException: 400 when the content type is not in
            ``_ALLOWED_IMAGE_TYPES`` or PIL cannot decode the uploaded bytes.
    """
    if file.content_type not in _ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Only JPEG/PNG/WEBP images are accepted.")
    try:
        contents = await file.read()
        # Image.open is lazy; convert() forces the pixel data to be decoded,
        # so corrupt or non-image payloads fail inside this try.
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Could not read image file.") from exc


def _classify(image: Image.Image) -> dict:
    """Run the loaded model on one RGB image and build the prediction payload.

    Returns a dict with the top-1 ``prediction`` and ``confidence`` (a
    percentage rounded to 2 decimals), up to five ranked ``top5`` entries of
    ``{"label", "confidence"}``, and the architecture name under ``model_used``.
    """
    tensor = transform(image).unsqueeze(0).to(device)  # [1, 3, 224, 224]
    with torch.no_grad():
        probs = torch.softmax(model(tensor), dim=1)[0]
    top_probs, top_idx = torch.topk(probs, k=min(5, NUM_CLASSES))
    ranked = [
        {"label": CLASS_NAMES[idx], "confidence": round(prob * 100, 2)}
        for prob, idx in zip(top_probs.tolist(), top_idx.tolist())
    ]
    return {
        "prediction": ranked[0]["label"],
        "confidence": ranked[0]["confidence"],
        "top5": ranked,
        "model_used": ARCH,
    }


async def _predict_upload(file: UploadFile) -> JSONResponse:
    """Shared implementation of ``predict`` and ``live_predict``."""
    from fastapi.concurrency import run_in_threadpool

    image = await _read_rgb_image(file)
    # Model inference is synchronous and CPU-bound; running it in the worker
    # thread pool keeps the event loop free to serve other requests meanwhile.
    payload = await run_in_threadpool(_classify, image)
    return JSONResponse(payload)


async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return top-1 plus top-5 predictions.

    Raises:
        HTTPException: 400 for an unsupported content type or an undecodable image.
    """
    return await _predict_upload(file)


async def live_predict(file: UploadFile = File(...)):
    """Classify an uploaded image; same validation, inference and payload as ``predict``.

    Raises:
        HTTPException: 400 for an unsupported content type or an undecodable image.
    """
    return await _predict_upload(file)