# app/main.py
"""FastAPI service that classifies an uploaded image with a local ResNet-50.

At import time the module loads model weights from ``resnet50_weights.pth`` and
class labels from ``imagenet_classes.txt`` (both resolved relative to the
process working directory), serves ``static/index.html`` at ``/`` and exposes
``POST /predict`` for inference.
"""
import io
import logging

import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from torchvision import models

logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS.
# NOTE(review): wildcard origins together with allow_credentials=True is
# discouraged by the FastAPI/Starlette CORS documentation; list the real
# front-end origins here once they are known. Left unchanged so existing
# clients keep working.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def read_root():
    """Serve the single-page front end."""
    return FileResponse("static/index.html")


# Load model without downloading: build the bare architecture, then read the
# weights from a local file.
# NOTE: torch.load unpickles the file, so it must only ever point at a trusted
# weights file.
model = models.resnet50()
model.load_state_dict(torch.load("resnet50_weights.pth", map_location="cpu"))
model.eval()

# Load labels: one class name per line; the line index is used as the class id.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

# Preprocessing applied to every upload before inference: resize, center-crop
# to 224x224, convert to a tensor and normalize each channel (these mean/std
# values are the ones commonly used with ImageNet-trained torchvision models).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


class InvalidImageError(ValueError):
    """Raised when the uploaded bytes cannot be decoded as an image."""


def _classify(image_bytes: bytes) -> str:
    """Decode *image_bytes*, run the model and return the top-1 label.

    Raises:
        InvalidImageError: if the bytes are not a readable image.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # Pillow reports unidentified or truncated images as OSError subclasses.
        raise InvalidImageError(
            f"Could not decode uploaded image: {exc}"
        ) from exc

    batch = transform(img).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        outputs = model(batch)
    _, predicted = torch.max(outputs, 1)
    return labels[predicted.item()]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the uploaded image.

    Returns ``{"prediction": <label>}`` on success, ``{"error": <message>}``
    with status 400 when the upload is not a decodable image, and
    ``{"error": <message>}`` with status 500 for any other failure.
    """
    try:
        image_bytes = await file.read()
        # Decoding and the forward pass are CPU-bound; run them in a worker
        # thread so the event loop keeps serving other requests.
        label = await run_in_threadpool(_classify, image_bytes)
    except InvalidImageError as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=400)
    except Exception as exc:
        # Top-level boundary: log the traceback, keep the original 500 shape.
        logger.exception("Prediction failed")
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return JSONResponse(content={"prediction": label})