"""FastAPI service that classifies uploaded images as "Normal" or "Pneumonia".

At import time this module builds a ResNet18 with a 2-way head, loads its
weights from ``pneumonia_weights.pth`` and exposes:

* ``GET /``         -- renders ``templates/index.html``.
* ``POST /predict`` -- accepts an image upload and returns ``{"result": <label>}``.
  On failure the same ``"result"`` key carries an error message, with HTTP
  status 400 for undecodable uploads and 500 for unexpected server errors.
"""
import io
import logging

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from torchvision import models, transforms

logger = logging.getLogger(__name__)

app = FastAPI()

# CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin requests. Nothing in this file uses
# cookies or auth, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files and templates (relative paths, resolved against the
# process working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Device: inference always runs on the CPU.
device = torch.device("cpu")

# Model output index -> label returned to the client (index 1 is "Pneumonia",
# as in the original code).
# NOTE(review): confirm this order matches the class order used during training.
CLASS_NAMES = ("Normal", "Pneumonia")

# Model (assumes ResNet18 was used for training), final layer replaced by a
# 2-class linear head.
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, len(CLASS_NAMES))
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code -- only ship trusted weight files. On torch >= 1.13 passing
# weights_only=True restricts loading to tensors and suffices for a state_dict.
model.load_state_dict(torch.load("pneumonia_weights.pth", map_location=device))
model.to(device)
model.eval()  # BatchNorm layers use running statistics during inference

# Transform: resize to 224x224 and convert to a float tensor (values in [0, 1]).
# NOTE(review): no Normalize() step is applied; this must match the
# preprocessing used when the weights were trained -- confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _load_image_tensor(contents: bytes) -> torch.Tensor:
    """Decode raw upload bytes into a ``(1, 3, 224, 224)`` batch on ``device``.

    Raises:
        OSError: the bytes are not a decodable image (``PIL.UnidentifiedImageError``
            is a subclass) or the image data is truncated/corrupt.
        PIL.Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    return transform(image).unsqueeze(0).to(device)


def _classify(batch: torch.Tensor) -> str:
    """Run the model on a single-image batch and return its label from CLASS_NAMES."""
    with torch.no_grad():  # no autograd graph needed for inference
        output = model(batch)
    predicted = torch.argmax(output, dim=1).item()
    return CLASS_NAMES[predicted]


# Routes
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``templates/index.html``."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        ``{"result": "Normal" | "Pneumonia"}`` with status 200 on success.
        ``{"result": "Error during prediction: ..."}`` with status 400 when the
        upload cannot be decoded as an image, or status 500 (details logged,
        not sent to the client) when inference fails unexpectedly.
    """
    contents = await file.read()

    # Decoding and inference are CPU-bound; run them in the thread pool so the
    # event loop keeps serving other requests meanwhile.
    try:
        batch = await run_in_threadpool(_load_image_tensor, contents)
    except (OSError, Image.DecompressionBombError) as exc:
        # Client error: not a usable image (also covers an empty upload).
        return JSONResponse(
            status_code=400,
            content={"result": f"Error during prediction: {exc}"},
        )

    try:
        result = await run_in_threadpool(_classify, batch)
    except Exception:
        # Server-side failure: log the traceback, but don't leak internals.
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(
            status_code=500,
            content={"result": "Error during prediction: internal error"},
        )

    return {"result": result}