# Spaces: Sleeping | File size: 1,689 Bytes | 1dde15e a9dcdc2 1dde15e
# app/main.py
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import io
import torch
import torchvision.transforms as transforms
from torchvision import models
app = FastAPI()

# Enable CORS for every origin so the API can be called from pages hosted
# elsewhere. Credentials are deliberately disabled: a wildcard origin must not
# be combined with credentialed requests (otherwise any site could issue
# cookie-bearing calls), and nothing visible in this file reads cookies or
# auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the contents of the local "static" directory under the /static path.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def read_root():
    """Serve the front-end entry page for requests to the site root."""
    index_page = "static/index.html"
    return FileResponse(index_page)
# Load the model without downloading: build the bare ResNet-50 architecture
# and fill it from a checkpoint file on local disk.
model = models.resnet50()
# NOTE(review): torch.load unpickles the file, so the checkpoint must come from
# a trusted source; consider passing weights_only=True once the minimum
# supported torch version is confirmed to accept it.
model.load_state_dict(torch.load("resnet50_weights.pth", map_location="cpu"))
model.eval()  # inference behaviour for batch-norm layers

# Load labels: one class name per line. The zero-based line index is what the
# /predict handler uses to translate the model's class index into a name, so
# lines are stripped but never dropped or reordered.
# The encoding is explicit so decoding does not depend on the host locale.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

# Preprocessing applied to every uploaded image before it reaches the model:
# resize, centre-crop to 224x224, convert to a tensor, then normalise each
# channel with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Pillow signals undecodable data with ``OSError`` (which includes
    ``UnidentifiedImageError``) or ``ValueError``; those propagate to the
    caller unchanged.
    """
    # The context manager releases the decoder once the RGB copy exists.
    with Image.open(io.BytesIO(image_bytes)) as raw:
        return raw.convert("RGB")


def _predict_label(img):
    """Run the module-level classifier on an RGB image; return the top-1 label."""
    img_tensor = transform(img).unsqueeze(0)  # add the leading batch dimension
    # no_grad is thread-local, so it is entered here, inside the worker thread.
    with torch.no_grad():
        outputs = model(img_tensor)
    _, predicted = torch.max(outputs, 1)
    return labels[predicted.item()]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return its most likely label.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``JSONResponse`` with one of:
        * ``{"prediction": <label>}`` (HTTP 200) on success;
        * ``{"error": <message>}`` (HTTP 400) when the upload cannot be
          decoded as an image (empty, truncated or not an image at all);
        * ``{"error": <message>}`` (HTTP 500) for any other failure.
    """
    # Imported locally so this handler is self-contained.
    from fastapi.concurrency import run_in_threadpool

    try:
        image_bytes = await file.read()

        # Decoding and inference are CPU-bound; running them in the thread
        # pool keeps the event loop free to serve other requests meanwhile.
        try:
            img = await run_in_threadpool(_decode_image, image_bytes)
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # A bad upload is the client's mistake, not a server fault.
            return JSONResponse(content={"error": str(e)}, status_code=400)

        label = await run_in_threadpool(_predict_label, img)
        return JSONResponse(content={"prediction": label})
    except Exception as e:
        # Top-level boundary: always answer with JSON rather than a bare 500.
        return JSONResponse(content={"error": str(e)}, status_code=500)
|