Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from azure.storage.blob import BlobServiceClient | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torchvision import transforms | |
| import torchvision.models as models | |
| import io | |
| import os | |
app = FastAPI()

# Origins allowed to call this API from a browser (e.g. the React front end),
# given as a comma-separated list in ALLOWED_ORIGINS. Defaults to any origin.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Credentialed CORS (cookies / Authorization headers) must not be combined
# with a wildcard origin: Starlette then echoes the caller's Origin back for
# credentialed requests, letting any site make authenticated cross-origin
# calls. Enable credentials only when origins are listed explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Module-level state filled in by load_model(); both stay None until then.
model = None
transform = None

# Maps a model output index to its label: the letters A-Z, then three
# non-letter classes. The order is assumed to match the label order used in
# training — TODO confirm against the training script.
ASL_CLASSES = [chr(code) for code in range(ord("A"), ord("Z") + 1)]
ASL_CLASSES += ["del", "nothing", "space"]
class ASLEfficientNet(nn.Module):
    """EfficientNet-B3 backbone with a custom two-layer classification head.

    The stock torchvision classifier is swapped for
    Dropout(0.3) -> Linear(in, 512) -> ReLU -> Dropout(0.2) -> Linear(512, n).
    The backbone lives under the attribute ``model`` and the head layout is
    fixed, so state-dict keys (e.g. ``model.classifier.1.weight``) match the
    checkpoint that ``load_model`` loads into it.
    """

    def __init__(self, num_classes=29):
        super().__init__()
        backbone = models.efficientnet_b3(weights=None)
        # Width of the features that feed the stock classifier's Linear layer.
        feature_dim = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        """Return raw (pre-softmax) class scores for the batch ``x``."""
        return self.model(x)
@app.on_event("startup")
async def load_model():
    """Download the checkpoint from Azure Blob Storage and set up the globals.

    The connection string is read from ``AZURE_STORAGE_CONNECTION_STRING``.
    The blob location comes from ``AZURE_MODEL_CONTAINER`` and
    ``AZURE_MODEL_BLOB``, defaulting to ``models`` / ``deep_model.pth``.
    On success the module globals ``model`` (in eval mode) and ``transform``
    are assigned together. If any step fails, neither global is modified.

    Raises:
        RuntimeError: if ``AZURE_STORAGE_CONNECTION_STRING`` is not set.
        Errors from the Azure SDK, ``torch.load`` or ``load_state_dict``
        propagate unchanged.
    """
    global model, transform

    connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
    if not connection_string:
        raise RuntimeError(
            "AZURE_STORAGE_CONNECTION_STRING is not set; cannot download the model"
        )
    container = os.getenv("AZURE_MODEL_CONTAINER", "models")
    blob_name = os.getenv("AZURE_MODEL_BLOB", "deep_model.pth")

    print("Downloading model from Azure...")
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name)
    # Download into memory instead of a fixed /tmp path: with several workers
    # each running this at startup, a shared file on disk would race.
    buffer = io.BytesIO()
    blob_client.download_blob().readinto(buffer)
    buffer.seek(0)

    print("Loading model...")
    # NOTE(security): torch.load unpickles the file, so only point this at a
    # trusted blob. Consider weights_only=True once the checkpoint's contents
    # are confirmed to be compatible with it.
    checkpoint = torch.load(buffer, map_location="cpu")
    # Accept a training checkpoint that wraps the weights under
    # 'model_state_dict', or a bare state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # The head size follows the label list, so the two cannot drift apart.
    net = ASLEfficientNet(num_classes=len(ASL_CLASSES))
    net.load_state_dict(state_dict)
    net.eval()

    # Resize to 224x224, convert to a tensor, normalise with ImageNet mean/std.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Publish both globals only after every step above has succeeded.
    model, transform = net, preprocess
    print("Model loaded successfully!")
@app.get("/")
def root():
    """Return a static message confirming the API process is running."""
    return {"message": "ASL API is running"}
def _classify(image_bytes, net, preprocess, top_k=5):
    """Decode ``image_bytes`` and classify it with ``net``.

    This is synchronous and CPU-bound; ``predict`` runs it in a worker thread.

    Args:
        image_bytes: raw content of the uploaded file.
        net: the loaded classifier, called on a batch of one image.
        preprocess: transform turning a PIL RGB image into an input tensor.
        top_k: number of ranked predictions to include, capped at the number
            of model outputs.

    Returns:
        The JSON-serialisable response dict that ``predict`` returns.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Pillow raises OSError subclasses (e.g. UnidentifiedImageError) for
        # empty, non-image or truncated uploads — that is a client error.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # Add a batch dimension of 1 before the forward pass.
    input_tensor = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        probabilities = torch.softmax(net(input_tensor), dim=1)

    confidence, predicted_idx = probabilities.max(1)
    k = min(top_k, probabilities.shape[1])
    top_prob, top_idx = probabilities.topk(k, dim=1)

    predicted_class = predicted_idx.item()
    return {
        "predicted_class": predicted_class,
        "predicted_letter": ASL_CLASSES[predicted_class],
        "confidence": confidence.item(),
        "top5_predictions": [
            {"class": idx, "letter": ASL_CLASSES[idx], "confidence": prob}
            for idx, prob in zip(top_idx[0].tolist(), top_prob[0].tolist())
        ],
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image as an ASL sign.

    The response holds the top-1 class index, its label and its probability.
    It also holds a ranked ``top5_predictions`` list of
    ``{"class", "letter", "confidence"}`` entries.

    Raises:
        HTTPException: 503 if the model has not been loaded, 400 if the
            upload is not a decodable image, 500 for any other failure.
    """
    from fastapi.concurrency import run_in_threadpool

    # Snapshot the globals so the worker thread uses one consistent pair.
    net, preprocess = model, transform
    if net is None or preprocess is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        image_bytes = await file.read()
        # Decoding and the forward pass are CPU-bound; running them in the
        # threadpool keeps the event loop free to serve other requests.
        return await run_in_threadpool(_classify, image_bytes, net, preprocess)
    except HTTPException:
        # Already carries the right status (e.g. 400 for a bad image).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e