# CheXNetDeep / app.py
# Author: thehammadishaq (commit a97d9e5, "Create app.py"; 7.35 kB)
from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import io
import uuid
from typing import Dict
from datetime import datetime, timedelta
import os
# Configuration
MAX_FILE_SIZE = 10 << 20  # upload size cap in bytes (10 MiB)
HEATMAP_EXPIRY = 5 * 60  # lifetime of a stored heatmap, in seconds (5 minutes)
RATE_LIMIT = "5/minute"  # rate-limit spec for /analyze: 5 requests per minute
# The FastAPI application; title/description/version are the metadata shown
# in the generated OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.1.0",
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
)
# Mount static files.
# Starlette matches routes in registration order, and a Mount at "/static"
# claims every "/static/..." path. Mounting at import time, before the
# "/static/heatmap/{session_id}" route further down this module is declared,
# would make that route unreachable (StaticFiles would answer with 404).
# StaticFiles is still built here so a missing directory fails fast at import;
# the mount itself is deferred to startup, after every route is registered.
_static_files = StaticFiles(directory="static")


@app.on_event("startup")
async def _mount_static_files():
    """Mount the static directory behind all API routes (idempotent)."""
    if not any(getattr(route, "name", None) == "static" for route in app.routes):
        app.mount("/static", _static_files, name="static")
# CORS configuration
# Every origin, method and header is permitted.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True means
# any website can make credentialed cross-origin requests (browsers reject a
# literal "*" with credentials, so Starlette echoes the caller's Origin
# instead). Nothing in this module uses cookies or auth, so
# allow_credentials=False or an explicit origin list is likely sufficient —
# confirm with the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the rate limiter used by the /analyze RateLimiter dependency.
# NOTE(review): fastapi_limiter is Redis-backed and has no in-memory mode;
# FastAPILimiter.init() takes a required Redis connection as its first
# argument, so this call is expected to fail at startup as written. Pass an
# async Redis client, e.g. FastAPILimiter.init(redis_client) — TODO confirm
# against the installed fastapi_limiter version.
@app.on_event("startup")
async def startup():
    """Startup hook: initialize the global FastAPILimiter."""
    await FastAPILimiter.init()
# Session storage for heatmaps.
# Maps a session id (uuid4 string created by /analyze) to
# {'image': PNG-encoded bytes, 'timestamp': datetime.now() at creation}.
# Stale entries are evicted by cleanup_expired_heatmaps(). The dict lives in
# this process only: it is lost on restart and not shared across workers.
heatmap_store: Dict[str, dict] = {}
# Load model.
# The architecture is loaded from Densenet.h5, then its weights are replaced
# with those in pretrained_model.h5. Any failure aborts module import, so the
# API never starts without a model.
try:
    model = load_model('Densenet.h5')
    model.load_weights("pretrained_model.h5")
except Exception as e:
    # Chain the cause so the original traceback survives in the error report.
    raise RuntimeError(f"Failed to load model: {e}") from e
# Model configuration
# Layer whose activations are weighted by generate_gradcam().
layer_name = 'conv5_block16_concat'

# Output labels: position i names the model's i-th output score
# (process_predictions indexes this list by prediction index). The order
# presumably must match the label order used in training — confirm before editing.
class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural_Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
    'No Finding',
]
def cleanup_expired_heatmaps():
    """Evict every stored heatmap older than HEATMAP_EXPIRY seconds.

    Ages are measured against one datetime.now() snapshot taken on entry.
    Stale ids are collected first and removed afterwards, because a dict
    must not change size while it is being iterated.
    """
    snapshot = datetime.now()

    def _is_stale(entry):
        age = snapshot - entry['timestamp']
        return age.total_seconds() > HEATMAP_EXPIRY

    stale_ids = [key for key, entry in heatmap_store.items() if _is_stale(entry)]
    for key in stale_ids:
        heatmap_store.pop(key)
def generate_gradcam(model, img, layer_name):
    """Build a Grad-CAM heatmap for the model's top class and overlay it on ``img``.

    Args:
        model: Keras model to explain; must contain a layer named ``layer_name``.
        img: Image array accepted by both ``img_to_array`` and
            ``PIL.Image.fromarray`` (analyze_image passes the uint8 RGB array
            returned by preprocess_image).
        layer_name: Name of the layer whose activations are weighted.

    Returns:
        PIL.Image: ``img`` blended 50/50 with a jet-colormapped heatmap resized
        to the image's dimensions.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    # Auxiliary model exposing both the target layer's activations and the
    # final predictions, so the gradient between them can be taken.
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Grad-CAM explains one class: differentiate only the top class's
        # score. The original differentiated the sum of all class scores,
        # leaving class_idx unused.
        class_score = predictions[:, class_idx]

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]
    # Keep only positive gradients at positively-activated positions.
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    # One weight per channel, averaged over the spatial axes; assumes
    # channels-last (H, W, C) activations — TODO confirm for other models.
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)

    heatmap = np.maximum(cam.numpy(), 0)
    peak = heatmap.max()
    # An all-zero CAM would otherwise divide by zero and yield NaNs.
    if peak > 0:
        heatmap /= peak

    heatmap_img = plt.cm.jet(heatmap)[..., :3]  # drop the alpha channel
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_img * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)
def process_predictions(predictions, class_labels, top_k=4):
    """Pair each prediction row's highest-scoring classes with their labels.

    Args:
        predictions: Iterable of 1-D score sequences (e.g. the 2-D array
            returned by ``model.predict``), one row per image.
        class_labels: Sequence mapping class index -> label.
        top_k: Number of classes to keep per row (default 4). Values <= 0
            yield empty lists; values above the class count keep every class.

    Returns:
        A list with one entry per row, each a list of ``(label, score)``
        tuples ordered from highest to lowest score.
    """
    # Clamp so top_k <= 0 gives [] instead of the `[-0:]` pitfall
    # (a negative-zero slice returns the whole array).
    keep = max(top_k, 0)
    return [
        [(class_labels[i], float(row[i])) for i in np.argsort(row)[::-1][:keep]]
        for row in predictions
    ]
def preprocess_image(file_bytes, size=(540, 540)):
    """Decode uploaded image bytes into an RGB array resized for the model.

    Args:
        file_bytes: Raw encoded image bytes (any format OpenCV can decode).
        size: Target ``(width, height)`` passed to ``cv2.resize``; defaults to
            the 540x540 input size reported by /model/info.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB and
        resized with area interpolation.

    Raises:
        ValueError: If the bytes are empty or cannot be decoded as an image.
    """
    buffer = np.frombuffer(file_bytes, np.uint8)
    # cv2.imdecode rejects an empty buffer, so report that case explicitly.
    if buffer.size == 0:
        raise ValueError("Empty image data")
    img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    # imdecode signals undecodable data by returning None rather than raising.
    if img is None:
        raise ValueError("Could not decode image data")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, size, interpolation=cv2.INTER_AREA)
@app.get("/", include_in_schema=False)
async def root():
    """Landing endpoint: confirm the service is up and point at the docs."""
    payload = {"message": "ChexNet API is operational"}
    payload["docs"] = "/docs"
    return payload
@app.get("/health")
async def health_check():
    """Report service health, whether the model object exists, and server time."""
    loaded = model is not None
    now_iso = datetime.now().isoformat()
    return dict(status="healthy", model_loaded=loaded, timestamp=now_iso)
@app.get("/model/classes")
async def get_class_names():
    """Return the label list, indexed in the same order as the model's outputs."""
    return dict(classes=class_names)
def _rate_limiter_kwargs(spec):
    """Convert a ``"<count>/<unit>"`` spec such as ``"5/minute"`` into
    keyword arguments for ``fastapi_limiter.depends.RateLimiter``.

    RateLimiter expects an integer ``times`` plus a window length
    (``seconds``/``minutes``/``hours``); it does not parse strings, so the
    spec must be translated rather than passed through as ``times``.

    Raises:
        ValueError: If the count is not an integer or the unit is unknown.
    """
    count, _, unit = spec.partition("/")
    windows = {"second": "seconds", "minute": "minutes", "hour": "hours"}
    window = windows.get(unit.strip().lower().rstrip("s"))
    if window is None:
        raise ValueError(f"Unsupported rate limit spec: {spec!r}")
    return {"times": int(count), window: 1}


@app.post("/analyze",
          dependencies=[Depends(RateLimiter(**_rate_limiter_kwargs(RATE_LIMIT)))])
async def analyze_image(request: Request, file: UploadFile = File(...)):
    """
    Analyze chest X-ray image and return predictions with Grad-CAM visualization
    Parameters:
    - file: Upload JPEG/PNG image (max 10MB)
    Returns:
    - session_id: Key under which the heatmap is stored
    - predictions: Top 4 diagnoses with confidence scores
    - heatmap_url: URL to retrieve Grad-CAM visualization
    Errors:
    - 400 if the upload is not an image or cannot be decoded
    - 413 if the upload exceeds the size limit
    - 500 if inference or heatmap generation fails
    """
    # Validate input. content_type is None when the client sends no header.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(400, "Only image files are accepted")
    too_large = f"Maximum file size is {MAX_FILE_SIZE//(1024*1024)}MB"
    # UploadFile.size may be None (unknown), so use it only as an early reject.
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > MAX_FILE_SIZE:
        raise HTTPException(413, too_large)
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading all of it into memory.
    file_bytes = await file.read(MAX_FILE_SIZE + 1)
    if len(file_bytes) > MAX_FILE_SIZE:
        raise HTTPException(413, too_large)

    # Undecodable/corrupt uploads are client errors, not server failures.
    try:
        img = preprocess_image(file_bytes)
    except Exception as e:
        raise HTTPException(400, "Could not decode image") from e

    try:
        # Prepare input tensor
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)
        # Get predictions
        predictions = model.predict(img_array)
        decoded = process_predictions(predictions, class_names)
        # Generate Grad-CAM
        heatmap = generate_gradcam(model, img, layer_name)
        # Store heatmap (PNG bytes) under a fresh session ID
        session_id = str(uuid.uuid4())
        img_bytes = io.BytesIO()
        heatmap.save(img_bytes, format='PNG')
        heatmap_store[session_id] = {
            'image': img_bytes.getvalue(),
            'timestamp': datetime.now()
        }
        cleanup_expired_heatmaps()
        return {
            "session_id": session_id,
            "predictions": decoded[0],
            "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
        }
    except Exception as e:
        raise HTTPException(500, f"Processing failed: {str(e)}") from e
@app.get("/static/heatmap/{session_id}")
async def get_heatmap(session_id: str):
    """Retrieve Grad-CAM visualization by session ID.

    The store is only purged when /analyze runs, so the entry's age is
    checked here too. Without that check, an expired heatmap stays
    downloadable until the next analysis request.

    Raises:
        HTTPException: 404 if the session is unknown or older than
            HEATMAP_EXPIRY seconds (an expired entry is also evicted).
    """
    entry = heatmap_store.get(session_id)
    if entry is not None:
        age = (datetime.now() - entry['timestamp']).total_seconds()
        if age > HEATMAP_EXPIRY:
            heatmap_store.pop(session_id, None)
            entry = None
    if entry is None:
        raise HTTPException(404, "Session expired or invalid")
    return StreamingResponse(
        io.BytesIO(entry['image']),
        media_type="image/png",
        # Keep client caching in step with the server-side lifetime.
        headers={"Cache-Control": f"max-age={HEATMAP_EXPIRY}"},
    )
@app.get("/model/info")
async def model_info():
    """Describe the served model: architecture, input size, class count,
    Grad-CAM layer, and the /analyze rate-limit spec."""
    info = {"model_type": "DenseNet121", "input_size": "540x540"}
    info["classes"] = len(class_names)
    info["gradcam_layer"] = layer_name
    info["rate_limit"] = RATE_LIMIT
    return info
# Error handlers
@app.exception_handler(HTTPException)
async def handle_http_exception(request, exc):
    """Render an HTTPException as ``{"error": <detail>}`` with its status code.

    Headers attached to the exception (FastAPI's HTTPException accepts a
    ``headers`` argument, e.g. for Retry-After or WWW-Authenticate) are
    forwarded; the previous version silently dropped them.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def handle_generic_exception(request, exc):
    """Convert any unhandled exception into an opaque JSON 500.

    The exception's details are deliberately not included in the response body.
    """
    body = {"error": "Internal server error"}
    return JSONResponse(content=body, status_code=500)