# NOTE: lines below were scraped from a Hugging Face Spaces page (Space status: "Sleeping");
# this header is page chrome, not part of the application code.
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi_limiter import FastAPILimiter | |
| from fastapi_limiter.depends import RateLimiter | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Model, load_model | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| from tensorflow.keras.applications.densenet import preprocess_input | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import io | |
| import uuid | |
| from typing import Dict | |
| from datetime import datetime, timedelta | |
| import os | |
# --- Configuration ---
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB — upload cap enforced in analyze_image
HEATMAP_EXPIRY = 300  # 5 minutes in seconds — heatmap retention before lazy cleanup
# Advertised rate limit (reported by model_info). NOTE(review): RateLimiter is
# imported at the top of the file but no route wiring is visible in this chunk —
# confirm the limit is actually enforced on the endpoints.
RATE_LIMIT = "5/minute"  # 5 requests per minute
app = FastAPI(
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
    version="1.1.0"
)

# Mount static files. The "static" directory must exist when the module is
# imported, or mounting raises at startup.
app.mount("/static", StaticFiles(directory="static"), name="static")

# CORS configuration: fully open.
# NOTE(review): per the Fetch/CORS spec, browsers reject a wildcard
# Access-Control-Allow-Origin combined with credentials; if credentialed
# requests are actually needed, list explicit origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize rate limiter (in-memory).
# NOTE(review): no @app.on_event("startup") / add_event_handler registration is
# visible in this chunk — confirm this coroutine is actually wired up as a
# startup hook. Also, FastAPILimiter.init() conventionally takes a Redis
# connection; verify the zero-argument call against the installed
# fastapi-limiter version.
async def startup():
    await FastAPILimiter.init()
# Session storage for heatmaps: session_id -> {'image': <PNG bytes>,
# 'timestamp': <datetime>}. Per-process and in-memory only — entries are pruned
# lazily via cleanup_expired_heatmaps() and are NOT shared across workers.
heatmap_store: Dict[str, dict] = {}
# --- Model loading (runs as a module-import side effect) ---
# Load the DenseNet architecture, then overlay the pretrained weights.
try:
    model = load_model('Densenet.h5')
    model.load_weights("pretrained_model.h5")
except Exception as e:
    # BUGFIX: chain the original exception so the underlying Keras/IO error
    # (missing file, bad HDF5, shape mismatch, ...) is preserved in tracebacks.
    raise RuntimeError(f"Failed to load model: {str(e)}") from e

# Model configuration: last DenseNet121 concat layer, used as the Grad-CAM target.
layer_name = 'conv5_block16_concat'

# Output labels in the order the model's final layer emits them
# (14 pathology classes plus 'No Finding').
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
    'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
]
def cleanup_expired_heatmaps():
    """Drop every heatmap session older than HEATMAP_EXPIRY seconds."""
    reference_time = datetime.now()
    # Iterate over a snapshot of the keys so deletion is safe mid-loop.
    for session_id in list(heatmap_store):
        age_seconds = (reference_time - heatmap_store[session_id]['timestamp']).total_seconds()
        if age_seconds > HEATMAP_EXPIRY:
            del heatmap_store[session_id]
def generate_gradcam(model, img, layer_name):
    """Generate a Grad-CAM overlay for the model's top-predicted class.

    Parameters:
        model: Keras model accepting DenseNet-preprocessed input.
        img: HxWx3 uint8 RGB array (as produced by preprocess_image).
        layer_name: name of the convolutional layer to attribute against.

    Returns:
        PIL.Image: the original image blended 50/50 with a jet-colormapped heatmap.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # BUGFIX: Grad-CAM needs the gradient of the *target class score*.
        # The original differentiated the whole prediction vector, which sums
        # gradients across all classes and blurs the attribution.
        class_score = tf.gather(predictions, class_idx, axis=1)

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]

    # Guided masking: keep only positive activations with positive gradients.
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)

    heatmap = np.maximum(cam, 0)
    # BUGFIX: guard the normalization — when the CAM is all zeros the original
    # divided by zero and produced a NaN heatmap.
    peak = np.max(heatmap)
    if peak > 0:
        heatmap /= peak

    heatmap_img = plt.cm.jet(heatmap)[..., :3]
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_img * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)
def process_predictions(predictions, class_labels):
    """For each prediction row, return the top-4 (label, score) pairs, best first."""
    return [
        [(class_labels[idx], float(row[idx])) for idx in row.argsort()[::-1][:4]]
        for row in predictions
    ]
def preprocess_image(file_bytes):
    """Decode uploaded image bytes into a 540x540 RGB uint8 array.

    Parameters:
        file_bytes: raw bytes of a JPEG/PNG upload.

    Returns:
        numpy array of shape (540, 540, 3) in RGB channel order.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    # BUGFIX: cv2.imdecode returns None (it does not raise) on undecodable
    # input; fail explicitly instead of crashing inside cvtColor.
    if img is None:
        raise ValueError("Uploaded bytes could not be decoded as an image")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # INTER_AREA is OpenCV's recommended interpolation for downscaling.
    return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
async def root():
    """Landing endpoint: confirm the service is up and point callers at the docs."""
    payload = {"message": "ChexNet API is operational", "docs": "/docs"}
    return payload
async def health_check():
    """Liveness probe: report model availability and the current server time."""
    model_is_loaded = model is not None
    return {
        "status": "healthy",
        "model_loaded": model_is_loaded,
        "timestamp": datetime.now().isoformat(),
    }
async def get_class_names():
    """Expose the ordered list of diagnosis labels the model can emit."""
    payload = {"classes": class_names}
    return payload
async def analyze_image(request: Request, file: UploadFile = File(...)):
    """
    Analyze a chest X-ray image and return predictions with Grad-CAM visualization.

    Parameters:
        request: incoming request, used to build an absolute heatmap URL.
        file: uploaded JPEG/PNG image (max 10MB).

    Returns:
        dict with:
        - session_id: opaque id for retrieving the heatmap later
        - predictions: top-4 diagnoses with confidence scores
        - heatmap_url: URL to retrieve the Grad-CAM visualization

    Raises:
        HTTPException: 400 for non-images, 413 for oversized uploads,
        500 for any processing failure.
    """
    # Validate input. BUGFIX: content_type and size can be None depending on
    # the client and Starlette version — guard before dereferencing (the
    # original raised TypeError, surfacing as an opaque 500).
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(400, "Only image files are accepted")
    if file.size is not None and file.size > MAX_FILE_SIZE:
        raise HTTPException(413, f"Maximum file size is {MAX_FILE_SIZE//(1024*1024)}MB")
    try:
        # Decode + resize to the model's 540x540 RGB input.
        img = preprocess_image(await file.read())

        # Prepare input tensor: batch of 1, DenseNet preprocessing.
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        # Get predictions and keep the top-4 labels per image.
        predictions = model.predict(img_array)
        decoded = process_predictions(predictions, class_names)

        # Generate the Grad-CAM overlay.
        heatmap = generate_gradcam(model, img, layer_name)

        # Store the rendered PNG under a fresh session ID.
        session_id = str(uuid.uuid4())
        img_bytes = io.BytesIO()
        heatmap.save(img_bytes, format='PNG')
        heatmap_store[session_id] = {
            'image': img_bytes.getvalue(),
            'timestamp': datetime.now()
        }
        # Opportunistically prune stale sessions on every request.
        cleanup_expired_heatmaps()

        # NOTE(review): this URL assumes the heatmap endpoint is routed at
        # /static/heatmap/{session_id}; the actual route decoration for
        # get_heatmap() is not visible in this chunk — confirm they match.
        return {
            "session_id": session_id,
            "predictions": decoded[0],
            "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
        }
    except HTTPException:
        # BUGFIX: never re-wrap a deliberate HTTP error as an opaque 500.
        raise
    except Exception as e:
        # Chain the cause so server logs retain the real traceback.
        raise HTTPException(500, f"Processing failed: {str(e)}") from e
async def get_heatmap(session_id: str):
    """Stream the stored Grad-CAM PNG for a session, or 404 if it expired."""
    entry = heatmap_store.get(session_id)
    if entry is None:
        raise HTTPException(404, "Session expired or invalid")
    png_stream = io.BytesIO(entry['image'])
    return StreamingResponse(
        png_stream,
        media_type="image/png",
        headers={"Cache-Control": "max-age=300"},
    )
async def model_info():
    """Report static metadata about the deployed model and API limits."""
    info = {
        "model_type": "DenseNet121",
        "input_size": "540x540",
        "classes": len(class_names),
        "gradcam_layer": layer_name,
        "rate_limit": RATE_LIMIT,
    }
    return info
# Error handlers
async def handle_http_exception(request, exc):
    """Render an HTTPException as {"error": detail} with its own status code."""
    body = {"error": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=body)
async def handle_generic_exception(request, exc):
    """Catch-all handler: mask unexpected failures behind a generic 500 payload."""
    return JSONResponse(
        content={"error": "Internal server error"},
        status_code=500,
    )