CheXNetDeep / app.py
thehammadishaq's picture
Update app.py
87cc4a2 verified
raw
history blame
8.04 kB
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter
from slowapi.util import get_remote_address
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import DenseNet121, preprocess_input
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import io
import uuid
from typing import Dict
from datetime import datetime, timedelta
import os
# --- Configuration ---------------------------------------------------------
MAX_FILE_SIZE = 10 * 1024 * 1024  # largest accepted upload, in bytes (10MB)
HEATMAP_EXPIRY = 5 * 60  # seconds a stored heatmap remains retrievable

# --- Application -----------------------------------------------------------
app = FastAPI(
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
    version="2.3.0",
)

# Per-client rate limiting, keyed by the caller's remote address. The limiter
# is also published on app.state, as slowapi's setup expects.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# Serve the ./static directory under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site issue credentialed cross-origin calls - confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Process-local store for generated overlays, keyed by session id.
# Each value looks like {'image': <PNG bytes>, 'timestamp': <creation datetime>}.
heatmap_store: Dict[str, dict] = {}

# --- Model configuration ---------------------------------------------------
# Layer whose activations feed the Grad-CAM computation.
layer_name = 'conv5_block16_concat'
# Output labels, in the same order as the model's output units.
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
    'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding',
]
def build_model():
    """Assemble the classifier architecture without any trained weights.

    The layers are a DenseNet121 backbone (no top, arbitrary spatial size,
    3 channels), then global average pooling, then one sigmoid unit per entry
    of ``class_names``.
    """
    backbone = DenseNet121(weights=None, include_top=False, input_shape=(None, None, 3))
    pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
    head = tf.keras.layers.Dense(len(class_names), activation='sigmoid')
    return Model(inputs=backbone.input, outputs=head(pooled))
def load_model_with_fallback():
    """Load the classifier, trying two strategies in order.

    1. Deserialize the full model from ``Densenet.h5``, then load the weights
       in ``pretrained_model.h5`` into it.
    2. Rebuild the architecture with ``build_model()``, then load
       ``pretrained_model.h5`` into it.

    Each failed strategy is reported with ``print``.

    Raises:
        RuntimeError: if both strategies fail.
    """
    weights_file = 'pretrained_model.h5'

    def _from_saved_model():
        # Strategy 1: full saved model plus a separate weights file.
        net = load_model('Densenet.h5', compile=False)
        net.load_weights(weights_file)
        return net

    def _from_rebuilt_architecture():
        # Strategy 2: architecture defined in code plus the weights file.
        net = build_model()
        net.load_weights(weights_file)
        return net

    try:
        return _from_saved_model()
    except Exception as direct_err:
        print(f"Direct loading failed: {direct_err}")
        # Strategy 2 runs inside this handler so its failure stays chained to
        # the first one in the traceback.
        try:
            return _from_rebuilt_architecture()
        except Exception as rebuild_err:
            print(f"Architecture rebuild failed: {rebuild_err}")
            raise RuntimeError("All model loading strategies failed")
# Load the model once at import time. The service cannot work without it, so
# a failure is reported and then re-raised, aborting module import.
try:
    model = load_model_with_fallback()
except Exception as load_err:
    print(f"❌ Model loading failed: {load_err}")
    raise
else:
    print("✅ Model loaded successfully!")
def cleanup_expired_heatmaps():
    """Evict every entry of ``heatmap_store`` older than HEATMAP_EXPIRY seconds.

    Age is measured from each entry's ``'timestamp'`` to a single "now" taken
    at the start of the call. The store is mutated in place; returns None.
    """
    reference_time = datetime.now()
    # Iterate over a snapshot so entries can be removed while looping.
    for session_key, entry in list(heatmap_store.items()):
        age_seconds = (reference_time - entry['timestamp']).total_seconds()
        if age_seconds > HEATMAP_EXPIRY:
            heatmap_store.pop(session_key)
def generate_gradcam(img, class_idx=None):
    """Return ``img`` blended 50/50 with a (guided) Grad-CAM heatmap.

    Args:
        img: RGB image array as produced by ``preprocess_image`` (presumably
            uint8, since it is passed to ``Image.fromarray``). It is
            preprocessed here the same way ``analyze_image`` does before being
            fed to the model.
        class_idx: index into ``class_names`` of the class to explain.
            Defaults to the model's highest-scoring class for this image.

    Returns:
        A ``PIL.Image`` of the same size as ``img``.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    # Sub-model exposing both the Grad-CAM layer's activations and the scores.
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if class_idx is None:
            class_idx = tf.argmax(predictions[0])
        # Differentiate only the target class score, so the map explains that
        # class. Differentiating the whole output vector would mix every class.
        class_score = predictions[:, class_idx]

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]
    # Guided weighting: keep gradients only where both the activation and the
    # gradient are positive.
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)

    heatmap = np.maximum(cam, 0)
    peak = np.max(heatmap)
    # An all-zero CAM would otherwise be divided by zero, yielding NaNs that
    # turn into garbage pixels when cast to uint8.
    if peak > 0:
        heatmap = heatmap / peak

    # Colour-map intensities in [0, 1] with 'jet' and drop the alpha channel.
    heatmap_rgb = plt.cm.jet(heatmap)[..., :3]
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_rgb * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)
def process_predictions(predictions, top_k=4):
    """Return the ``top_k`` highest-scoring classes for each row of scores.

    Args:
        predictions: 2-D array-like of per-class scores, one row per image,
            with columns aligned to ``class_names``.
        top_k: number of classes to keep per row (default 4). Values <= 0
            yield empty lists.

    Returns:
        One list per input row. Each list holds ``(class_name, score)`` tuples
        ordered from highest to lowest score, with ``score`` as a Python float.
    """
    keep = max(top_k, 0)
    decoded = []
    for row in predictions:
        # argsort is ascending: reverse it, then take the first `keep` indices.
        # Slicing with [-keep:] instead would return every index when keep == 0.
        best = np.argsort(row)[::-1][:keep]
        decoded.append([(class_names[i], float(row[i])) for i in best])
    return decoded
def preprocess_image(file_bytes):
    """Decode raw uploaded bytes into a 540x540 RGB image array.

    The bytes are decoded as a 3-channel colour image (IMREAD_COLOR), converted
    from OpenCV's BGR order to RGB, and resized with area interpolation.

    Raises:
        ValueError: if ``file_bytes`` is empty or cannot be decoded as an image.
    """
    if not file_bytes:
        raise ValueError("Uploaded file is empty")
    img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals undecodable data by returning None rather than
        # raising. Without this check, cvtColor fails with an opaque assertion.
        raise ValueError("Could not decode the uploaded file as an image")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
@app.get("/", include_in_schema=False)
async def root():
    """Landing endpoint: confirm the service is up and list the main routes."""
    endpoint_map = {
        "docs": "/docs",
        "health": "/health",
        "analyze": "POST /analyze",
    }
    return {"message": "ChexNet API is operational", "endpoints": endpoint_map}
@app.get("/health")
async def health_check():
    """Report liveness, the current local time, and whether ``model`` is truthy."""
    loaded = bool(model)
    return {
        "status": "healthy" if loaded else "unhealthy",
        "timestamp": datetime.now().isoformat(),
        "model_loaded": loaded,
    }
@app.get("/model/classes")
async def get_class_names():
    """List the label names, in the order of the model's output units."""
    payload = {"classes": class_names}
    return payload
@app.post("/analyze")
@limiter.limit("5/minute")
async def analyze_image(request: Request, file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray and store a Grad-CAM overlay for it.

    Limited to 5 requests per minute by ``limiter``. The ``request`` parameter
    is required by the limiter decorator.

    Returns:
        A dict with ``session_id``, ``predictions`` (top classes as
        ``(class_name, score)`` pairs) and ``heatmap_url``, from which the PNG
        overlay can be fetched until it is older than HEATMAP_EXPIRY seconds.

    Raises:
        HTTPException(400): the upload is not declared as an image.
        HTTPException(413): the upload exceeds MAX_FILE_SIZE bytes.
        HTTPException(500): decoding, inference or heatmap generation failed.
    """
    too_large = f"Max file size is {MAX_FILE_SIZE//(1024*1024)}MB"
    # content_type is None when the client omits the part's Content-Type.
    if not (file.content_type or "").startswith('image/'):
        raise HTTPException(400, "Only image files are accepted")
    # size may be unknown (None, or absent on older Starlette). The bounded
    # read below enforces the limit either way.
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > MAX_FILE_SIZE:
        raise HTTPException(413, too_large)
    try:
        # Read at most one byte past the limit, so oversize uploads are
        # detected without loading all of them into memory.
        contents = await file.read(MAX_FILE_SIZE + 1)
        if len(contents) > MAX_FILE_SIZE:
            raise HTTPException(413, too_large)
        img = preprocess_image(contents)

        # Prepare input tensor
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        # Get predictions
        predictions = model.predict(img_array)
        decoded = process_predictions(predictions)

        # Generate Grad-CAM
        heatmap = generate_gradcam(img)

        # Store heatmap with session ID
        session_id = str(uuid.uuid4())
        img_bytes = io.BytesIO()
        heatmap.save(img_bytes, format='PNG')
        heatmap_store[session_id] = {
            'image': img_bytes.getvalue(),
            'timestamp': datetime.now()
        }
        cleanup_expired_heatmaps()

        return {
            "session_id": session_id,
            "predictions": decoded[0],
            "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
        }
    except HTTPException:
        # Deliberate client errors (e.g. 413) must not be rewrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing failed: {str(e)}") from e
@app.get("/static/heatmap/{session_id}")
async def get_heatmap(session_id: str):
    """Stream the PNG overlay that /analyze stored under ``session_id``.

    Raises:
        HTTPException(404): the id is unknown, or its entry is older than
            HEATMAP_EXPIRY seconds.
    """
    # Evict stale entries first. Otherwise they would only be evicted on the
    # next /analyze call, and an idle server would keep serving them.
    cleanup_expired_heatmaps()
    entry = heatmap_store.get(session_id)
    if entry is None:
        raise HTTPException(404, "Session expired or invalid")
    return StreamingResponse(
        io.BytesIO(entry['image']),
        media_type="image/png",
        headers={"Cache-Control": "max-age=300"},
    )


# Starlette dispatches each request to the first matching route in
# registration order. The StaticFiles mount at "/static" is registered earlier
# in this module, so it would capture every "/static/heatmap/..." request and
# answer 404 from the static directory. Move the route just registered above
# to the front of the routing table so it takes precedence.
app.router.routes.insert(0, app.router.routes.pop())
@app.get("/model/info")
async def model_info():
    """Describe the served model, its Grad-CAM layer and the API's rate limit."""
    info = dict(
        model_type="DenseNet121",
        input_size="540x540",
        classes=len(class_names),
        gradcam_layer=layer_name,
        rate_limit="5 requests/minute",
    )
    return info
@app.exception_handler(HTTPException)
async def http_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"error": detail}`` with its status code.

    Headers attached to the exception (e.g. ``Retry-After`` or
    ``WWW-Authenticate``) are forwarded, as FastAPI's built-in handler does.
    Without this, replacing that handler silently drops them.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def generic_handler(request: Request, exc: Exception):
    """Fallback for unhandled errors: a fixed 500 body that exposes no details."""
    body = {"error": "Internal server error"}
    return JSONResponse(content=body, status_code=500)