thehammadishaq commited on
Commit
a97d9e5
·
verified ·
1 Parent(s): 5b66acb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Depends
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse, StreamingResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from fastapi_limiter import FastAPILimiter
6
+ from fastapi_limiter.depends import RateLimiter
7
+ import tensorflow as tf
8
+ from tensorflow.keras.models import Model, load_model
9
+ from tensorflow.keras.preprocessing.image import img_to_array
10
+ from tensorflow.keras.applications.densenet import preprocess_input
11
+ import numpy as np
12
+ from PIL import Image
13
+ import matplotlib.pyplot as plt
14
+ import cv2
15
+ import io
16
+ import uuid
17
+ from typing import Dict
18
+ from datetime import datetime, timedelta
19
+ import os
20
+
21
+ # Configuration
22
+ MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
23
+ HEATMAP_EXPIRY = 300 # 5 minutes in seconds
24
+ RATE_LIMIT = "5/minute" # 5 requests per minute
25
+
26
+ app = FastAPI(
27
+ title="ChexNet Medical Imaging API",
28
+ description="API for chest X-ray analysis with Grad-CAM visualization",
29
+ version="1.1.0"
30
+ )
31
+
32
+ # Mount static files
33
+ app.mount("/static", StaticFiles(directory="static"), name="static")
34
+
35
+ # CORS configuration
36
+ app.add_middleware(
37
+ CORSMiddleware,
38
+ allow_origins=["*"],
39
+ allow_credentials=True,
40
+ allow_methods=["*"],
41
+ allow_headers=["*"],
42
+ )
43
+
44
+ # Initialize rate limiter (in-memory)
45
+ @app.on_event("startup")
46
+ async def startup():
47
+ await FastAPILimiter.init()
48
+
49
+ # Session storage for heatmaps
50
+ heatmap_store: Dict[str, dict] = {}
51
+
52
+ # Load model
53
+ try:
54
+ model = load_model('Densenet.h5')
55
+ model.load_weights("pretrained_model.h5")
56
+ except Exception as e:
57
+ raise RuntimeError(f"Failed to load model: {str(e)}")
58
+
59
+ # Model configuration
60
+ layer_name = 'conv5_block16_concat'
61
+ class_names = [
62
+ 'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
63
+ 'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
64
+ 'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
65
+ ]
66
+
67
+ def cleanup_expired_heatmaps():
68
+ """Remove heatmaps older than HEATMAP_EXPIRY seconds"""
69
+ now = datetime.now()
70
+ expired = [
71
+ sid for sid, data in heatmap_store.items()
72
+ if (now - data['timestamp']).total_seconds() > HEATMAP_EXPIRY
73
+ ]
74
+ for sid in expired:
75
+ del heatmap_store[sid]
76
+
77
+ def generate_gradcam(model, img, layer_name):
78
+ """Generate Grad-CAM heatmap overlay"""
79
+ img_array = img_to_array(img)
80
+ img_array = np.expand_dims(img_array, axis=0)
81
+ img_array = preprocess_input(img_array)
82
+
83
+ grad_model = Model(
84
+ inputs=model.inputs,
85
+ outputs=[model.get_layer(layer_name).output, model.output]
86
+ )
87
+
88
+ with tf.GradientTape() as tape:
89
+ conv_outputs, predictions = grad_model(img_array)
90
+ class_idx = tf.argmax(predictions[0])
91
+
92
+ output = conv_outputs[0]
93
+ grads = tape.gradient(predictions, conv_outputs)[0]
94
+ guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
95
+
96
+ weights = tf.reduce_mean(guided_grads, axis=(0, 1))
97
+ cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)
98
+ heatmap = np.maximum(cam, 0)
99
+ heatmap /= np.max(heatmap)
100
+ heatmap_img = plt.cm.jet(heatmap)[..., :3]
101
+
102
+ original_img = Image.fromarray(img)
103
+ heatmap_img = Image.fromarray((heatmap_img * 255).astype(np.uint8))
104
+ heatmap_img = heatmap_img.resize(original_img.size)
105
+ return Image.blend(original_img, heatmap_img, 0.5)
106
+
107
+ def process_predictions(predictions, class_labels):
108
+ """Format predictions with top 4 classes"""
109
+ decoded = []
110
+ for pred in predictions:
111
+ top_indices = pred.argsort()[-4:][::-1] # Top 4 predictions
112
+ decoded.append([(class_labels[i], float(pred[i])) for i in top_indices])
113
+ return decoded
114
+
115
+ def preprocess_image(file_bytes):
116
+ """Convert uploaded file to processed image array"""
117
+ img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
118
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
119
+ return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
120
+
121
+ @app.get("/", include_in_schema=False)
122
+ async def root():
123
+ return {"message": "ChexNet API is operational", "docs": "/docs"}
124
+
125
+ @app.get("/health")
126
+ async def health_check():
127
+ return {
128
+ "status": "healthy",
129
+ "model_loaded": model is not None,
130
+ "timestamp": datetime.now().isoformat()
131
+ }
132
+
133
+ @app.get("/model/classes")
134
+ async def get_class_names():
135
+ return {"classes": class_names}
136
+
137
+ @app.post("/analyze",
138
+ dependencies=[Depends(RateLimiter(times=RATE_LIMIT))])
139
+ async def analyze_image(request: Request, file: UploadFile = File(...)):
140
+ """
141
+ Analyze chest X-ray image and return predictions with Grad-CAM visualization
142
+
143
+ Parameters:
144
+ - file: Upload JPEG/PNG image (max 10MB)
145
+
146
+ Returns:
147
+ - predictions: Top 4 diagnoses with confidence scores
148
+ - heatmap_url: URL to retrieve Grad-CAM visualization
149
+ """
150
+ # Validate input
151
+ if not file.content_type.startswith('image/'):
152
+ raise HTTPException(400, "Only image files are accepted")
153
+
154
+ if file.size > MAX_FILE_SIZE:
155
+ raise HTTPException(413, f"Maximum file size is {MAX_FILE_SIZE//(1024*1024)}MB")
156
+
157
+ try:
158
+ # Process image
159
+ img = preprocess_image(await file.read())
160
+
161
+ # Prepare input tensor
162
+ img_array = img_to_array(img)
163
+ img_array = np.expand_dims(img_array, axis=0)
164
+ img_array = preprocess_input(img_array)
165
+
166
+ # Get predictions
167
+ predictions = model.predict(img_array)
168
+ decoded = process_predictions(predictions, class_names)
169
+
170
+ # Generate Grad-CAM
171
+ heatmap = generate_gradcam(model, img, layer_name)
172
+
173
+ # Store heatmap with session ID
174
+ session_id = str(uuid.uuid4())
175
+ img_bytes = io.BytesIO()
176
+ heatmap.save(img_bytes, format='PNG')
177
+
178
+ heatmap_store[session_id] = {
179
+ 'image': img_bytes.getvalue(),
180
+ 'timestamp': datetime.now()
181
+ }
182
+ cleanup_expired_heatmaps()
183
+
184
+ return {
185
+ "session_id": session_id,
186
+ "predictions": decoded[0],
187
+ "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
188
+ }
189
+
190
+ except Exception as e:
191
+ raise HTTPException(500, f"Processing failed: {str(e)}")
192
+
193
+ @app.get("/static/heatmap/{session_id}")
194
+ async def get_heatmap(session_id: str):
195
+ """Retrieve Grad-CAM visualization by session ID"""
196
+ if session_id not in heatmap_store:
197
+ raise HTTPException(404, "Session expired or invalid")
198
+
199
+ return StreamingResponse(
200
+ io.BytesIO(heatmap_store[session_id]['image']),
201
+ media_type="image/png",
202
+ headers={"Cache-Control": "max-age=300"}
203
+ )
204
+
205
+ @app.get("/model/info")
206
+ async def model_info():
207
+ """Get model metadata"""
208
+ return {
209
+ "model_type": "DenseNet121",
210
+ "input_size": "540x540",
211
+ "classes": len(class_names),
212
+ "gradcam_layer": layer_name,
213
+ "rate_limit": RATE_LIMIT
214
+ }
215
+
216
+ # Error handlers
217
+ @app.exception_handler(HTTPException)
218
+ async def handle_http_exception(request, exc):
219
+ return JSONResponse(
220
+ status_code=exc.status_code,
221
+ content={"error": exc.detail}
222
+ )
223
+
224
+ @app.exception_handler(Exception)
225
+ async def handle_generic_exception(request, exc):
226
+ return JSONResponse(
227
+ status_code=500,
228
+ content={"error": "Internal server error"}
229
+ )