# app.py — FastAPI service that classifies chest X-ray images as Normal vs Pneumonia.
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from torchvision import transforms, models
app = FastAPI()
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files and templates
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Device
device = torch.device("cpu")
# Model (assumes ResNet18 was used)
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2) # Assuming 2 classes: Normal & Pneumonia
model.load_state_dict(torch.load("pneumonia_weights.pth", map_location=device))
model.to(device)
model.eval()
# Transform
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# Routes
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
try:
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
image = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(image)
predicted = torch.argmax(output, dim=1).item()
result = "Pneumonia" if predicted == 1 else "Normal"
return {"result": result}
except Exception as e:
return {"result": f"Error during prediction: {str(e)}"}