# image-classifier / main.py
# Brightsun10's picture
# Update main.py
# a9dcdc2 verified
# app/main.py
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import io
import torch
import torchvision.transforms as transforms
from torchvision import models
app = FastAPI()

# Enable CORS for every origin, method and header.
# Credentials are disabled on purpose: a wildcard origin must not be combined
# with allow_credentials=True (the CORS spec forbids "*" together with
# credentialed requests, and the FastAPI docs require explicit origins in that
# case). Nothing in this service relies on cookies or HTTP auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the files under ./static at the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def read_root():
    """Serve the front-end entry page for requests to the site root."""
    index_page = "static/index.html"
    return FileResponse(index_page)
# ✅ Load model without downloading: build the bare ResNet-50 architecture and
# fill it from the checkpoint file shipped alongside the app.
model = models.resnet50()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while being loaded.
model.load_state_dict(
    torch.load("resnet50_weights.pth", map_location="cpu", weights_only=True)
)
# Evaluation mode: batch-norm layers use their stored running statistics.
model.eval()

# Load labels: one class name per line; the line position is used as the
# class index when looking up a prediction.
# An explicit encoding keeps the result independent of the host locale.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

# Preprocessing applied to every uploaded image before inference:
# resize the shorter side to 256, take the central 224x224 crop, convert to a
# float tensor, then normalize each RGB channel with the given mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        ValueError: If the payload is empty or cannot be decoded as an image.
    """
    if not image_bytes:
        raise ValueError("Uploaded file is empty")
    try:
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, SyntaxError) as err:
        # Pillow reports unreadable/truncated images as OSError subclasses
        # (including UnidentifiedImageError); some broken files surface as
        # SyntaxError while the pixel data is loaded.
        raise ValueError(f"Could not decode image: {err}") from err


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return its top-1 label.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        ``{"prediction": <label>}`` on success.
        ``{"error": <message>}`` with status 400 when the upload is empty or
        is not a decodable image.
        ``{"error": <message>}`` with status 500 for any other failure
        (reading the upload, preprocessing, inference, label lookup).
    """
    # NOTE(review): the model call below runs synchronously inside this
    # coroutine, so the event loop is blocked while inference executes.
    # Consider offloading it to a worker thread if concurrency matters.
    try:
        image_bytes = await file.read()

        # A bad payload is the client's mistake, not a server fault.
        try:
            img = _decode_image(image_bytes)
        except ValueError as e:
            return JSONResponse(content={"error": str(e)}, status_code=400)

        # Add the leading batch dimension the model expects.
        img_tensor = transform(img).unsqueeze(0)
        with torch.no_grad():
            outputs = model(img_tensor)
        # Index of the highest-scoring class for the single image in the batch.
        predicted = outputs.argmax(dim=1)
        label = labels[predicted.item()]
        return JSONResponse(content={"prediction": label})
    except Exception as e:
        # Top-level boundary: report unexpected failures as JSON rather than
        # letting the framework emit a bare 500 page.
        return JSONResponse(content={"error": str(e)}, status_code=500)