Spaces:
Sleeping
Sleeping
| # app/main.py | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from PIL import Image | |
| import io | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torchvision import models | |
# The single FastAPI application object that every route and mount below attaches to.
app = FastAPI()

# Cross-origin policy: any origin, method and header is accepted.
# NOTE(review): browsers refuse a literal "*" Access-Control-Allow-Origin on
# credentialed requests, so allow_credentials=True together with "*" depends on
# the middleware echoing the caller's Origin — confirm credentials are needed here.
_cors_options = dict(
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Expose the "static" directory (resolved against the process working directory)
# under the /static URL prefix.
_static_app = StaticFiles(directory="static")
app.mount("/static", _static_app, name="static")
@app.get("/")
def read_root():
    """Serve ``static/index.html`` as the response for the site root.

    Without the ``@app.get("/")`` registration this function is never routed,
    so requests to ``/`` would return 404.

    Returns:
        A ``FileResponse`` streaming ``static/index.html`` (path resolved
        against the process working directory).
    """
    return FileResponse("static/index.html")
# Build a ResNet-50 with no pretrained weights requested, then fill it from the
# local checkpoint, so start-up does not download anything.
model = models.resnet50()
# The checkpoint is loaded into load_state_dict, so it only needs tensors in
# plain containers. weights_only=True restricts unpickling to exactly that,
# instead of executing arbitrary pickled objects. map_location="cpu" keeps all
# tensors on the CPU regardless of the device they were saved from.
model.load_state_dict(
    torch.load("resnet50_weights.pth", map_location="cpu", weights_only=True)
)
model.eval()  # inference mode: BatchNorm layers use their running statistics
# Class names, one per line, stripped of surrounding whitespace.
# predict() indexes this list with the model's arg-max output, so line order
# must match the model's output indices.
# Read as UTF-8 explicitly rather than relying on the platform's locale encoding.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]
# Per-channel (R, G, B) normalization constants: the standard ImageNet mean/std
# used by torchvision's ResNet recipes.
# Presumably these match how resnet50_weights.pth was trained — confirm.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every uploaded PIL image before inference.
_preprocess_steps = [
    transforms.Resize(256),       # scale so the shorter edge is 256 px
    transforms.CenterCrop(224),   # keep the central 224x224 region
    transforms.ToTensor(),        # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)
def _classify_image_bytes(image_bytes: bytes) -> str:
    """Decode *image_bytes* as an image and return the model's top-1 label.

    Runs synchronously and is CPU-bound (image decode plus ResNet-50 forward
    pass), so callers on the event loop should run it in an executor.

    Raises:
        OSError: PIL could not identify or fully decode the image.
        PIL.Image.DecompressionBombError: the image exceeds PIL's pixel limit.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img_tensor = transform(img).unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():
        outputs = model(img_tensor)
    _, predicted = torch.max(outputs, 1)  # index of the highest score
    return labels[predicted.item()]


# NOTE(review): the route decorator is absent from this copy of the file;
# "/predict" is assumed — match it to the path the front end posts to.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return its top-1 label as JSON.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"prediction": <label>}`` on success.
        ``{"error": <message>}`` with status 400 when the upload is not a
        decodable image.
        ``{"error": <message>}`` with status 500 for any other failure.
    """
    import asyncio
    import logging

    try:
        image_bytes = await file.read()
        # Decoding and inference are CPU-bound. Running them in the default
        # thread pool keeps this coroutine from stalling every other request
        # served by the same event loop.
        loop = asyncio.get_running_loop()
        label = await loop.run_in_executor(None, _classify_image_bytes, image_bytes)
    except (OSError, Image.DecompressionBombError) as e:
        # PIL raises these for unidentifiable, truncated or oversized images:
        # a bad upload, not a server fault.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        logging.getLogger(__name__).exception("prediction failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content={"prediction": label})