PneumoniaAPI / app.py
GitHub Actions
Auto-deploy from GitHub: 65670d5535a66243456ef2e8a5ccc8eafed0b120
e642110
"""
FastAPI for Pneumonia Detection - Hugging Face Spaces Deployment
CI/CD enabled - auto-deploys from GitHub
"""
import io
import time
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# =============================================================================
# Configuration
# =============================================================================
# Side length (pixels) of the square model input: get_transforms() resizes
# every image to IMAGE_SIZE x IMAGE_SIZE before inference.
IMAGE_SIZE = 224
# Per-channel (R, G, B) normalization statistics applied after ToTensor().
# These are the standard ImageNet values; presumably they match the
# preprocessing used at training time — TODO confirm against training code.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Output labels: predict() returns CLASS_NAMES[1] when the sigmoid of the
# model's single logit exceeds 0.5, otherwise CLASS_NAMES[0].
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
# Checkpoint loaded by the startup hook. This is a relative path, so it is
# resolved against the process's current working directory, not this file.
MODEL_PATH = Path("models/best_model.pt")
# =============================================================================
# Model Definition
# =============================================================================
class PneumoniaClassifier(nn.Module):
    """EfficientNet-B0 whose head is replaced by a single-logit binary classifier.

    The backbone is built with ``weights=None`` (random initialization); the
    trained weights are loaded afterwards from a checkpoint through
    ``load_state_dict``. ``forward`` returns raw logits with one output unit
    per image; callers apply ``torch.sigmoid`` to get a probability.

    Keep the ``backbone`` attribute name and the two-element head layout
    (Dropout, Linear) unchanged so existing checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b0(weights=None)
        # Size the new head from the input width of the stock head's layer
        # at index 1.
        head_width = net.classifier[1].in_features
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(in_features=head_width, out_features=1),
        )
        self.backbone = net

    def forward(self, x):
        logits = self.backbone(x)
        return logits
# =============================================================================
# Response Models
# =============================================================================
class HealthResponse(BaseModel):
    """Response body for GET /health."""

    # "healthy" when the global classifier is loaded, else "model_not_loaded".
    status: str
    # Whether the global classifier has been loaded (model is not None).
    model_loaded: bool
class PredictionResponse(BaseModel):
    """Response body for POST /predict."""

    # Predicted label, one of CLASS_NAMES.
    prediction: str
    # Probability assigned to the predicted label (always >= 0.5).
    confidence: float
    # Raw sigmoid output of the model, i.e. the score for CLASS_NAMES[1].
    probability: float
    # Time spent in preprocessing + inference, in milliseconds, rounded to
    # 2 decimals; reading/decoding the upload is not included.
    processing_time_ms: float
# =============================================================================
# App Setup
# =============================================================================
app = FastAPI(
    title="Pneumonia Detection API",
    description="Deep learning API for detecting pneumonia from chest X-rays",
    version="1.0.0"
)
# Open CORS for any origin, but without credentials. The API has no
# authentication and uses no cookies, so credentialed CORS is not needed.
# Wildcard origins must not be combined with allow_credentials=True (see the
# FastAPI CORS docs). Starlette would then reflect the caller's Origin and
# let any website send credentialed requests on a visitor's behalf.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =============================================================================
# Model Loading
# =============================================================================
# Populated by load_model() at startup; model stays None if no checkpoint.
model = None
device = None
@app.on_event("startup")
async def load_model():
    """Load the classifier checkpoint into the module-level ``model`` global.

    Selects CUDA when available, otherwise CPU, and stores it in ``device``.
    If ``MODEL_PATH`` does not exist, a warning is printed and ``model`` is
    left as ``None``. In that case ``/health`` reports ``model_not_loaded``
    and ``/predict`` answers 503.

    The checkpoint may be a dict holding the weights under
    ``'model_state_dict'`` (the original format) or a bare ``state_dict``.
    """
    global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if not MODEL_PATH.exists():
        print(f"Warning: Model not found at {MODEL_PATH}")
        return
    # Build and populate the network in a local first. The global is only
    # published once the weights are loaded and eval mode is set, so a failing
    # load can never leave a randomly-initialized model visible to handlers.
    net = PneumoniaClassifier()
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because the checkpoint ships with the app. Consider weights_only=True
    # once the minimum supported torch version allows it.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    model = net
    print("Model loaded successfully")
# =============================================================================
# Helper Functions
# =============================================================================
def get_transforms():
    """Build the inference preprocessing pipeline.

    Resizes to IMAGE_SIZE x IMAGE_SIZE, converts the PIL image to a tensor,
    then normalizes each channel with IMAGENET_MEAN / IMAGENET_STD.
    """
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
async def read_image(file: UploadFile) -> Image.Image:
    """Read an uploaded file and decode it into an RGB PIL image.

    Args:
        file: The uploaded file; its full contents are read into memory.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image. This
            covers an empty upload, corrupt or truncated data, an unsupported
            format, or an image large enough to trip Pillow's
            decompression-bomb guard.
    """
    contents = await file.read()
    try:
        # Image.open is lazy; convert() forces the full decode, so decoding
        # errors are raised inside this try block rather than later.
        return Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-file errors are OSError
        # subclasses. Without this handler they would surface as HTTP 500.
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file") from err
def predict(image: Image.Image):
    """Run the loaded classifier on a single PIL image.

    Returns:
        A ``(label, confidence, probability)`` tuple:
        ``probability`` is the sigmoid of the model's single logit;
        ``label`` is CLASS_NAMES[1] when probability > 0.5, else CLASS_NAMES[0];
        ``confidence`` is the probability of the returned label.
    """
    # unsqueeze(0) adds the batch dimension expected by the model.
    batch = get_transforms()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        prob = torch.sigmoid(model(batch)).item()
    if prob > 0.5:
        return CLASS_NAMES[1], prob, prob
    return CLASS_NAMES[0], 1 - prob, prob
# =============================================================================
# Endpoints
# =============================================================================
@app.get("/")
async def root():
    """Return a small service banner pointing clients at the interactive docs."""
    payload = {"message": "Pneumonia Detection API"}
    payload["docs"] = "/docs"
    return payload
@app.get("/health", response_model=HealthResponse)
async def health():
    """Report service liveness and whether the classifier has been loaded."""
    # Use an explicit None check rather than the truthiness of an nn.Module
    # (container modules such as nn.Sequential define __len__). Computing it
    # once also keeps `status` and `model_loaded` consistent.
    loaded = model is not None
    return HealthResponse(
        status="healthy" if loaded else "model_not_loaded",
        model_loaded=loaded,
    )
@app.post("/predict", response_model=PredictionResponse)
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray image.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        PredictionResponse with the label, its confidence, the raw sigmoid
        probability, and the preprocessing + inference time in milliseconds.
        Reading and decoding the upload is not included in that time.

    Raises:
        HTTPException: 503 if no model is loaded; 400 if the upload has a
            missing or non-image content type. Errors raised by read_image
            while decoding propagate unchanged.
    """
    from fastapi.concurrency import run_in_threadpool

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # content_type is None when the client omits the part's Content-Type
    # header. Treat that as "not an image" instead of raising AttributeError
    # (which would be returned as a 500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    image = await read_image(file)
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    start_time = time.perf_counter()
    # Inference is synchronous and compute-bound. Running it in the worker
    # thread pool keeps it from blocking the event loop, which would stall
    # every other request (including /health) meanwhile.
    pred_class, confidence, prob = await run_in_threadpool(predict, image)
    processing_time = (time.perf_counter() - start_time) * 1000
    return PredictionResponse(
        prediction=pred_class,
        confidence=confidence,
        probability=prob,
        processing_time_ms=round(processing_time, 2)
    )