# Provenance: uploaded by AyuCS — "Update main.py" (commit d9dcb69, verified)
# ==============================================================================
# Phase 2: AI-Enabled Healthcare Diagnostic Tool - Backend API (Corrected)
# ==============================================================================
import io
import logging
import os
from datetime import datetime

# The Hugging Face cache locations must be exported BEFORE `transformers` is
# imported: the library resolves these environment variables at import time,
# so setting them after the import (as the original code did) is not
# guaranteed to take effect. The explicit `cache_dir=` arguments used at
# load time still work either way; this makes the env-var route reliable too.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['HF_HUB_DISABLE_SYMLINKS'] = '1'

import torch
import torch.nn.functional as F
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image

# Configure logging — one module-level logger for startup and request logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- 1. Application Setup ---
app = FastAPI(
    title="Pneumonia Detection API",
    description="An API to detect pneumonia from chest X-ray images using a Vision Transformer model.",
    version="1.0.0"
)
# NOTE(review): wildcard origins together with allow_credentials=True is the
# most permissive CORS posture; browsers will not actually send credentials to
# a "*" origin — tighten allow_origins for production. TODO confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- 2. Configuration ---
# Inference device: GPU when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Path of the fine-tuned state dict loaded on top of the base model at startup.
MODEL_SAVE_PATH = 'pneumonia_detection_model.pth'
# Using a smaller base model to save storage
BASE_MODEL = "WinKawaks/vit-tiny-patch16-224"  # Much smaller than the original
# Index order must match the label order used when the weights were trained.
CLASS_NAMES = ['NORMAL', 'PNEUMONIA']
# Globals populated by the startup hook; None means "not ready" (see /health).
model = None
processor = None
# --- 3. Model Loading ---
@app.on_event("startup")
async def load_model():
    """Load the ViT processor, base model, and fine-tuned weights at startup.

    Populates the module globals ``model`` and ``processor``. On any failure
    both are reset to ``None`` so ``/health`` and ``/predict/`` report an
    unhealthy state instead of the application crashing at boot.
    """
    global model, processor
    try:
        logger.info(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
        logger.info(f"Device: {DEVICE}")
        logger.info(f"Cache directory: /tmp/cache")
        # Create cache directory if it doesn't exist
        os.makedirs('/tmp/cache', exist_ok=True)
        logger.info(f"Loading processor and model from Hugging Face Hub: {BASE_MODEL}")
        # Load processor from Hugging Face
        processor = ViTImageProcessor.from_pretrained(
            BASE_MODEL,
            cache_dir='/tmp/cache'
        )
        logger.info("Processor loaded successfully.")
        # Load model architecture from Hugging Face; the classifier head is
        # re-sized to len(CLASS_NAMES), hence ignore_mismatched_sizes.
        model = ViTForImageClassification.from_pretrained(
            BASE_MODEL,
            num_labels=len(CLASS_NAMES),
            ignore_mismatched_sizes=True,
            cache_dir='/tmp/cache'
        )
        logger.info("Base model loaded successfully.")
        # Load the fine-tuned weights on top of the base architecture.
        logger.info(f"Loading trained weights from {MODEL_SAVE_PATH}...")
        # weights_only=True refuses arbitrary pickled objects (security).
        state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True)
        # strict=False tolerates head-size mismatches; surface what differed.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            logger.warning(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            logger.warning(f"Unexpected keys: {unexpected_keys}")
        logger.info("Trained weights loaded successfully.")
        model.to(DEVICE)
        model.eval()
        logger.info("✓ Model and processor loaded and ready!")
    except Exception:
        # Best-effort startup: logger.exception logs the full traceback
        # (replacing the original manual `import traceback` dance), and the
        # globals stay None so /health reports "unhealthy".
        logger.exception("Error loading model or processor")
        model = None
        processor = None
# --- 4. API Endpoints ---
@app.get("/")
def read_root():
    """Landing endpoint: advertise service status and the available routes."""
    available_endpoints = {
        "docs": "/docs",
        "health": "/health",
        "predict": "/predict/",
    }
    return {
        "message": "Welcome to the Pneumonia Detection API",
        "status": "running",
        "endpoints": available_endpoints,
    }
@app.get("/health")
def health_check():
    """Report whether the model/processor pair is ready to serve predictions."""
    ready = model is not None and processor is not None
    if not ready:
        # Startup failed (or has not finished) — tell callers why.
        return {
            "status": "unhealthy",
            "reason": "Model not loaded",
            "device": str(DEVICE),
        }
    return {
        "status": "healthy",
        "model_loaded": True,
        "device": str(DEVICE),
        "class_names": CLASS_NAMES,
    }
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray image as NORMAL or PNEUMONIA.

    Returns the predicted class, its confidence, and the per-class
    probabilities (all formatted to 4 decimal places).

    Raises:
        HTTPException 400: the upload is not a decodable image.
        HTTPException 500: the model is not loaded, or inference failed.
    """
    # Explicit None checks, consistent with health_check — truthiness of an
    # nn.Module is not what we mean here.
    if model is None or processor is None:
        raise HTTPException(
            status_code=500,
            detail="Model is not loaded. Check server logs or visit /health endpoint."
        )
    # Read file contents
    contents = await file.read()
    # Validate and load image; convert to RGB since ViT expects 3 channels.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image file: {str(e)}"
        )
    # Process image and make prediction
    try:
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs).logits
        probabilities = F.softmax(outputs, dim=1)[0]
        predicted_class_idx = torch.argmax(probabilities).item()
        predicted_class = CLASS_NAMES[predicted_class_idx]
        confidence = probabilities[predicted_class_idx].item()
        return {
            "filename": file.filename,
            "prediction": predicted_class,
            "confidence": f"{confidence:.4f}",
            # Generalized from two hard-coded indices: works for any number
            # of classes while producing identical output for the current two.
            "probabilities": {
                name: f"{probabilities[i].item():.4f}"
                for i, name in enumerate(CLASS_NAMES)
            }
        }
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Prediction error")
        raise HTTPException(
            status_code=500,
            detail=f"Error during prediction: {str(e)}"
        )
# --- 5. (Optional) For running directly ---
if __name__ == "__main__":
    import uvicorn
    # Bind to all interfaces so a container's mapped port is reachable.
    uvicorn.run(app, host="0.0.0.0", port=8000)