Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Model, load_model | |
| from tensorflow.keras.layers import Input | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| from tensorflow.keras.applications.densenet import DenseNet121, preprocess_input | |
| import numpy as np | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| import io | |
| import uuid | |
| from typing import Dict | |
| from datetime import datetime, timedelta | |
| import os | |
# Configuration
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB — upload cap enforced by /analyze
HEATMAP_EXPIRY = 300  # 5 minutes in seconds — lifetime of cached heatmap PNGs

# Initialize FastAPI with rate limiting
app = FastAPI(
    title="ChexNet Medical Imaging API",
    description="API for chest X-ray analysis with Grad-CAM visualization",
    version="2.3.0"
)

# Rate limiter setup
# NOTE(review): the limiter is attached to app.state but no SlowAPIMiddleware
# or RateLimitExceeded handler is visible in this file — confirm limits are
# actually enforced on the routes.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")

# CORS configuration
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (wildcard origin cannot be
# credentialed) — consider an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Session storage for heatmaps:
# session_id -> {'image': PNG bytes, 'timestamp': datetime};
# entries are pruned by cleanup_expired_heatmaps(). In-memory only, so a
# restart (or multiple workers) loses/splits sessions.
heatmap_store: Dict[str, dict] = {}

# Model configuration: the last concatenation layer of DenseNet121 is the
# Grad-CAM target; 15 output labels (14 pathologies + 'No Finding').
layer_name = 'conv5_block16_concat'
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
    'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
]
def build_model():
    """Recreate the ChexNet classifier architecture in code.

    A DenseNet121 backbone (no pretrained weights, no classification top,
    variable spatial input) is globally average-pooled and fed into a
    sigmoid dense layer — one independent probability per entry in
    ``class_names`` (multi-label output).
    """
    backbone = DenseNet121(weights=None, include_top=False, input_shape=(None, None, 3))
    pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
    head = tf.keras.layers.Dense(len(class_names), activation='sigmoid')(pooled)
    return Model(inputs=backbone.input, outputs=head)
def load_model_with_fallback():
    """Load the ChexNet model, trying progressively cheaper strategies.

    Strategy 1 deserializes the full architecture from 'Densenet.h5' and then
    applies 'pretrained_model.h5' weights. If that fails, strategy 2 rebuilds
    the architecture in code and loads the same weights file.

    Raises:
        RuntimeError: when both strategies fail.
    """
    # Strategy 1: deserialize the saved architecture directly.
    try:
        loaded = load_model('Densenet.h5', compile=False)
        loaded.load_weights('pretrained_model.h5')
        return loaded
    except Exception as e:
        print(f"Direct loading failed: {e}")

    # Strategy 2: rebuild the architecture from code, then load weights only.
    try:
        rebuilt = build_model()
        rebuilt.load_weights('pretrained_model.h5')
        return rebuilt
    except Exception as e:
        print(f"Architecture rebuild failed: {e}")
        raise RuntimeError("All model loading strategies failed")
# Load model eagerly at import time so the first request doesn't pay the
# cost; a failure here deliberately aborts app startup (re-raise).
try:
    model = load_model_with_fallback()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Model loading failed: {e}")
    raise
def cleanup_expired_heatmaps():
    """Evict cached heatmaps older than HEATMAP_EXPIRY seconds."""
    check_time = datetime.now()
    # Snapshot the items so the dict can be mutated while scanning.
    for session_id, entry in list(heatmap_store.items()):
        age_seconds = (check_time - entry['timestamp']).total_seconds()
        if age_seconds > HEATMAP_EXPIRY:
            del heatmap_store[session_id]
def generate_gradcam(img):
    """Generate a Grad-CAM heatmap blended over the input image.

    Args:
        img: RGB uint8 numpy array (H, W, 3) — the output of
            preprocess_image (assumed; confirm against caller).

    Returns:
        PIL.Image: the input image blended 50/50 with the jet-colored
        class-activation map for the top-predicted class.
    """
    # Build the model input: batch dimension + DenseNet preprocessing.
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    # Model exposing both the target conv feature maps and the predictions.
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # FIX: differentiate the top class's score only. The original code
        # took the gradient of the full prediction vector (summing gradients
        # over all classes) and left class_idx unused — not Grad-CAM. The
        # slice must happen inside the tape so the op is recorded.
        class_score = predictions[:, class_idx]

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]

    # Guided gradients: keep only positive activations with positive grads.
    guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)

    heatmap = np.maximum(cam, 0)
    max_val = np.max(heatmap)
    if max_val > 0:  # FIX: avoid 0/0 -> NaN when the CAM is all zeros
        heatmap /= max_val

    # Colorize via the jet colormap (drop alpha), upscale to the original
    # image size, and alpha-blend at 50%.
    heatmap_img = plt.cm.jet(heatmap)[..., :3]
    original_img = Image.fromarray(img)
    heatmap_img = Image.fromarray((heatmap_img * 255).astype(np.uint8))
    heatmap_img = heatmap_img.resize(original_img.size)
    return Image.blend(original_img, heatmap_img, 0.5)
def process_predictions(predictions, top_k=4):
    """Format raw model outputs as (class_name, probability) pairs.

    Args:
        predictions: iterable of per-image score vectors, each aligned with
            ``class_names`` (length 15).
        top_k: number of highest-scoring classes to keep per image. Defaults
            to 4, matching the previously hard-coded behavior.

    Returns:
        list[list[tuple[str, float]]]: per image, the top_k
        (label, probability) pairs sorted best-first.
    """
    decoded = []
    for pred in predictions:
        # argsort is ascending; take the last top_k indices and reverse.
        top_indices = pred.argsort()[-top_k:][::-1]
        decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
    return decoded
def preprocess_image(file_bytes):
    """Decode uploaded image bytes into a 540x540 RGB uint8 array.

    Args:
        file_bytes: raw bytes of the uploaded image file.

    Returns:
        numpy array of shape (540, 540, 3) in RGB channel order.

    Raises:
        ValueError: if the bytes cannot be decoded as an image. (FIX:
            cv2.imdecode signals failure by returning None, not raising;
            previously the None fell through to cvtColor, which raised a
            cryptic cv2.error.)
    """
    img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("Could not decode uploaded bytes as an image")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # INTER_AREA is the preferred interpolation for downscaling.
    return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
async def root():
    """Landing endpoint: service banner plus a map of useful routes."""
    endpoints = {
        "docs": "/docs",
        "health": "/health",
        "analyze": "POST /analyze",
    }
    return {"message": "ChexNet API is operational", "endpoints": endpoints}
async def health_check():
    """Liveness probe: reports whether the global model object is loaded."""
    loaded = bool(model)
    return {
        "status": "healthy" if loaded else "unhealthy",
        "timestamp": datetime.now().isoformat(),
        "model_loaded": loaded,
    }
async def get_class_names():
    """Expose the full list of pathology labels the model predicts."""
    return dict(classes=class_names)
async def analyze_image(request: Request, file: UploadFile = File(...)):
    """Analyze an uploaded chest X-ray.

    Runs the ChexNet model, returns the top-4 class probabilities, and
    caches a Grad-CAM heatmap PNG under a fresh session id for later
    retrieval via the heatmap endpoint.

    Raises:
        HTTPException 400: upload is not declared as an image.
        HTTPException 413: payload exceeds MAX_FILE_SIZE.
        HTTPException 500: any downstream processing failure.
    """
    # FIX: content_type may be None when the client omits the header; the
    # bare startswith() call previously raised AttributeError (-> 500).
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(400, "Only image files are accepted")
    # FIX: file.size may be None for streamed uploads; compare defensively
    # and re-check the real byte count after reading.
    if file.size is not None and file.size > MAX_FILE_SIZE:
        raise HTTPException(413, f"Max file size is {MAX_FILE_SIZE//(1024*1024)}MB")
    contents = await file.read()
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(413, f"Max file size is {MAX_FILE_SIZE//(1024*1024)}MB")
    try:
        img = preprocess_image(contents)
        # Prepare input tensor: batch dim + DenseNet preprocessing.
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)
        # Multi-label sigmoid scores -> top (label, probability) pairs.
        predictions = model.predict(img_array)
        decoded = process_predictions(predictions)
        # Grad-CAM overlay for the top-scoring class.
        heatmap = generate_gradcam(img)
        # Cache the PNG bytes in memory under a fresh session id.
        session_id = str(uuid.uuid4())
        img_bytes = io.BytesIO()
        heatmap.save(img_bytes, format='PNG')
        heatmap_store[session_id] = {
            'image': img_bytes.getvalue(),
            'timestamp': datetime.now()
        }
        cleanup_expired_heatmaps()
        return {
            "session_id": session_id,
            "predictions": decoded[0],
            "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
        }
    except Exception as e:
        raise HTTPException(500, f"Processing failed: {str(e)}")
async def get_heatmap(session_id: str):
    """Stream the cached Grad-CAM PNG for a session; 404 when expired."""
    entry = heatmap_store.get(session_id)
    if entry is None:
        raise HTTPException(404, "Session expired or invalid")
    return StreamingResponse(
        io.BytesIO(entry['image']),
        media_type="image/png",
        headers={"Cache-Control": "max-age=300"},
    )
async def model_info():
    """Static metadata describing the deployed model."""
    info = {
        "model_type": "DenseNet121",
        "input_size": "540x540",
        "classes": len(class_names),
        "gradcam_layer": layer_name,
        "rate_limit": "5 requests/minute",
    }
    return info
async def http_handler(request: Request, exc: HTTPException):
    """Render HTTPExceptions as an {"error": detail} JSON body."""
    payload = {"error": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=payload)
async def generic_handler(request: Request, exc: Exception):
    """Catch-all handler: hide internal details behind a generic 500."""
    return JSONResponse(
        content={"error": "Internal server error"},
        status_code=500,
    )