Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Request | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import torch | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import io | |
app = FastAPI()

# CORS: accept requests from any origin, with any method and any header.
# NOTE(review): allow_credentials=True combined with wildcard origins is
# discouraged by the CORS spec; no cookie/auth handling is visible in this
# file, so allow_credentials=False may be sufficient — confirm with the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Expose the ./static directory under the /static URL prefix, then set up
# Jinja2 rendering from ./templates (both paths relative to the working dir).
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# All inference happens on the CPU.
device = torch.device("cpu")

# Rebuild the network the checkpoint is assumed to come from: an untrained
# ResNet18 whose final fully-connected layer is replaced by a 2-output head.
# NOTE(review): output index 0 is later read as "Normal" and 1 as "Pneumonia";
# confirm this matches the label order used during training.
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=2)

# Load the saved parameters, remapping any tensors onto `device`.
# NOTE(review): torch.load unpickles the file; if it is a plain state dict,
# consider weights_only=True (torch >= 1.13) so a tampered file cannot run code.
model.load_state_dict(
    torch.load("pneumonia_weights.pth", map_location=device)
)

# Module.to() and Module.eval() both return the module itself, so this rebinds
# `model` to the same object, now on `device` and in inference mode.
model = model.to(device).eval()
# Preprocessing applied to every uploaded image: stretch to a fixed 224x224,
# then convert the PIL image to a tensor.
# NOTE(review): there is no transforms.Normalize step; confirm this matches the
# preprocessing the weights were trained with.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
| # Routes | |
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page from ``templates/index.html``.

    The route decorator registers this handler with FastAPI; without it the
    function is never exposed. ``response_class=HTMLResponse`` marks the
    endpoint as returning HTML in the generated OpenAPI schema.
    """
    # Jinja2Templates expects the request object in the template context.
    return templates.TemplateResponse("index.html", {"request": request})
def _classify_image(contents: bytes) -> str:
    """Decode raw image bytes and return the model's label ("Pneumonia"/"Normal").

    Synchronous and CPU-bound (image decode + ResNet forward pass); ``predict``
    runs it in a worker thread so the event loop stays responsive.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    # Add a leading batch dimension of 1 before moving to the inference device.
    batch = transform(image).unsqueeze(0).to(device)
    # Grad mode is thread-local, so no_grad must be entered in the same thread
    # that executes the forward pass.
    with torch.no_grad():
        output = model(batch)
    predicted = torch.argmax(output, dim=1).item()
    # NOTE(review): assumes class index 1 = Pneumonia, 0 = Normal — confirm
    # against the training label mapping.
    return "Pneumonia" if predicted == 1 else "Normal"


# NOTE(review): route path assumed to be /predict — confirm against the form or
# fetch call in templates/index.html.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray image.

    Args:
        file: the uploaded image (any format PIL can open).

    Returns:
        ``{"result": "Pneumonia" | "Normal"}`` on success, or
        ``{"result": "Error during prediction: <message>"}`` if reading,
        decoding or inference fails.
    """
    # Local import keeps this fix self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        contents = await file.read()
        # Run the blocking decode + inference off the event loop so other
        # requests are not stalled for the duration of the forward pass.
        result = await run_in_threadpool(_classify_image, contents)
    except Exception as e:
        # Best-effort endpoint: report the failure in the same JSON shape as
        # the success path, which the frontend presumably reads.
        return {"result": f"Error during prediction: {e}"}
    return {"result": result}