# ASL_Backend / app.py
# (initial commit 76fbd01 by mg643 — header recovered from scraped page text)
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from azure.storage.blob import BlobServiceClient
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import torchvision.models as models
import io
import os
# The FastAPI application served by this module.
app = FastAPI()

# Permit the React frontend to call this API from the browser.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive — consider pinning the frontend's origin in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Globals populated once by the startup hook (load_model); None until then.
model = None
transform = None

# Index -> label mapping for the classifier head: the 26 letters A-Z
# followed by the three control tokens used by the ASL alphabet dataset.
ASL_CLASSES = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + ['del', 'nothing', 'space']
class ASLEfficientNet(nn.Module):
    """EfficientNet-B3 backbone with a custom two-layer classifier head.

    The stock classifier is replaced by Dropout -> Linear(..., 512) -> ReLU
    -> Dropout -> Linear(512, num_classes). The backbone is stored under the
    attribute name ``model`` so checkpoint state_dict keys line up.
    """

    def __init__(self, num_classes=29):
        super().__init__()
        backbone = models.efficientnet_b3(weights=None)  # weights come from the checkpoint
        head_in = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(head_in, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )
        self.model = backbone

    def forward(self, x):
        """Run the backbone; returns raw logits of shape (batch, num_classes)."""
        return self.model(x)
@app.on_event("startup")
async def load_model():
    """Download the checkpoint from Azure Blob Storage and initialize the model.

    Populates the module-level ``model`` and ``transform`` globals used by the
    ``/predict`` endpoint. Runs once at application startup.

    Raises:
        RuntimeError: if AZURE_STORAGE_CONNECTION_STRING is not set.
    """
    global model, transform
    print("Downloading model from Azure...")

    # Fail fast with a clear message instead of an opaque azure-sdk error
    # when the connection string is missing.
    connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
    if not connection_string:
        raise RuntimeError(
            "AZURE_STORAGE_CONNECTION_STRING environment variable is not set"
        )

    # Download the checkpoint blob to a local temp file.
    blob_service_client = BlobServiceClient.from_connection_string(connection_string)
    blob_client = blob_service_client.get_blob_client(
        container="models",
        blob="deep_model.pth"
    )
    with open("/tmp/model.pth", "wb") as f:
        download_stream = blob_client.download_blob()
        f.write(download_stream.readall())

    print("Loading model...")
    # NOTE(security): torch.load unpickles arbitrary objects, so the blob must
    # be trusted. Consider weights_only=True once the checkpoint format allows.
    checkpoint = torch.load("/tmp/model.pth", map_location="cpu")

    model = ASLEfficientNet(num_classes=29)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout for deterministic inference

    # Preprocessing matching ImageNet normalization stats used by EfficientNet.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    print("Model loaded successfully!")
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    return dict(message="ASL API is running")
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded hand-sign image into one of the 29 ASL classes.

    Returns the top prediction plus the five highest-probability classes.

    Raises:
        HTTPException: 503 before the startup model load completes,
            400 when the upload cannot be decoded as an image,
            500 for unexpected inference failures.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Decode the upload first: a corrupt/non-image file is a client error
    # (400), not a server error — previously it fell into the 500 catch-all.
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    try:
        # Preprocess to a (1, 3, 224, 224) batch tensor.
        input_tensor = transform(image).unsqueeze(0)

        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.softmax(output, dim=1)
            confidence, predicted_idx = probabilities.max(1)
            top5_prob, top5_idx = probabilities.topk(5, dim=1)

        predicted_class = predicted_idx.item()
        return {
            "predicted_class": predicted_class,
            "predicted_letter": ASL_CLASSES[predicted_class],
            "confidence": confidence.item(),
            "top5_predictions": [
                {
                    "class": int(idx),
                    "letter": ASL_CLASSES[int(idx)],
                    "confidence": float(prob),
                }
                for prob, idx in zip(top5_prob[0], top5_idx[0])
            ],
        }
    except Exception as e:
        # Unexpected inference failure — surface as a 500.
        raise HTTPException(status_code=500, detail=str(e))