thehammadishaq commited on
Commit
87cc4a2
·
verified ·
1 Parent(s): ca49012

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -40
app.py CHANGED
@@ -6,8 +6,9 @@ from slowapi import Limiter
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
8
  from tensorflow.keras.models import Model, load_model
 
9
  from tensorflow.keras.preprocessing.image import img_to_array
10
- from tensorflow.keras.applications.densenet import preprocess_input
11
  import numpy as np
12
  from PIL import Image
13
  import matplotlib.pyplot as plt
@@ -26,7 +27,7 @@ HEATMAP_EXPIRY = 300 # 5 minutes in seconds
26
  app = FastAPI(
27
  title="ChexNet Medical Imaging API",
28
  description="API for chest X-ray analysis with Grad-CAM visualization",
29
- version="2.1.0"
30
  )
31
 
32
  # Rate limiter setup
@@ -48,13 +49,6 @@ app.add_middleware(
48
  # Session storage for heatmaps
49
  heatmap_store: Dict[str, dict] = {}
50
 
51
- # Load model
52
- try:
53
- model = load_model('Densenet.h5')
54
- model.load_weights("pretrained_model.h5")
55
- except Exception as e:
56
- raise RuntimeError(f"Failed to load model: {str(e)}")
57
-
58
  # Model configuration
59
  layer_name = 'conv5_block16_concat'
60
  class_names = [
@@ -63,6 +57,41 @@ class_names = [
63
  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
64
  ]
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def cleanup_expired_heatmaps():
67
  """Remove heatmaps older than HEATMAP_EXPIRY seconds"""
68
  now = datetime.now()
@@ -73,7 +102,7 @@ def cleanup_expired_heatmaps():
73
  for sid in expired:
74
  del heatmap_store[sid]
75
 
76
- def generate_gradcam(model, img, layer_name):
77
  """Generate Grad-CAM heatmap overlay"""
78
  img_array = img_to_array(img)
79
  img_array = np.expand_dims(img_array, axis=0)
@@ -103,12 +132,12 @@ def generate_gradcam(model, img, layer_name):
103
  heatmap_img = heatmap_img.resize(original_img.size)
104
  return Image.blend(original_img, heatmap_img, 0.5)
105
 
106
- def process_predictions(predictions, class_labels):
107
  """Format predictions with top 4 classes"""
108
  decoded = []
109
  for pred in predictions:
110
- top_indices = pred.argsort()[-4:][::-1] # Top 4 predictions
111
- decoded.append([(class_labels[i], float(pred[i])) for i in top_indices])
112
  return decoded
113
 
114
  def preprocess_image(file_bytes):
@@ -119,14 +148,21 @@ def preprocess_image(file_bytes):
119
 
120
  @app.get("/", include_in_schema=False)
121
  async def root():
122
- return {"message": "ChexNet API is operational", "docs": "/docs"}
 
 
 
 
 
 
 
123
 
124
  @app.get("/health")
125
  async def health_check():
126
  return {
127
- "status": "healthy",
128
- "model_loaded": model is not None,
129
- "timestamp": datetime.now().isoformat()
130
  }
131
 
132
  @app.get("/model/classes")
@@ -136,26 +172,15 @@ async def get_class_names():
136
  @app.post("/analyze")
137
  @limiter.limit("5/minute")
138
  async def analyze_image(request: Request, file: UploadFile = File(...)):
139
- """
140
- Analyze chest X-ray image and return predictions with Grad-CAM visualization
141
-
142
- Parameters:
143
- - file: Upload JPEG/PNG image (max 10MB)
144
-
145
- Returns:
146
- - predictions: Top 4 diagnoses with confidence scores
147
- - heatmap_url: URL to retrieve Grad-CAM visualization
148
- """
149
- # Validate input
150
  if not file.content_type.startswith('image/'):
151
  raise HTTPException(400, "Only image files are accepted")
152
 
153
  if file.size > MAX_FILE_SIZE:
154
- raise HTTPException(413, f"Maximum file size is {MAX_FILE_SIZE//(1024*1024)}MB")
155
 
156
  try:
157
- # Process image
158
- img = preprocess_image(await file.read())
159
 
160
  # Prepare input tensor
161
  img_array = img_to_array(img)
@@ -164,10 +189,10 @@ async def analyze_image(request: Request, file: UploadFile = File(...)):
164
 
165
  # Get predictions
166
  predictions = model.predict(img_array)
167
- decoded = process_predictions(predictions, class_names)
168
 
169
  # Generate Grad-CAM
170
- heatmap = generate_gradcam(model, img, layer_name)
171
 
172
  # Store heatmap with session ID
173
  session_id = str(uuid.uuid4())
@@ -185,16 +210,13 @@ async def analyze_image(request: Request, file: UploadFile = File(...)):
185
  "predictions": decoded[0],
186
  "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
187
  }
188
-
189
  except Exception as e:
190
  raise HTTPException(500, f"Processing failed: {str(e)}")
191
 
192
  @app.get("/static/heatmap/{session_id}")
193
  async def get_heatmap(session_id: str):
194
- """Retrieve Grad-CAM visualization by session ID"""
195
  if session_id not in heatmap_store:
196
  raise HTTPException(404, "Session expired or invalid")
197
-
198
  return StreamingResponse(
199
  io.BytesIO(heatmap_store[session_id]['image']),
200
  media_type="image/png",
@@ -203,7 +225,6 @@ async def get_heatmap(session_id: str):
203
 
204
  @app.get("/model/info")
205
  async def model_info():
206
- """Get model metadata"""
207
  return {
208
  "model_type": "DenseNet121",
209
  "input_size": "540x540",
@@ -212,16 +233,15 @@ async def model_info():
212
  "rate_limit": "5 requests/minute"
213
  }
214
 
215
- # Error handlers
216
  @app.exception_handler(HTTPException)
217
- async def handle_http_exception(request, exc):
218
  return JSONResponse(
219
  status_code=exc.status_code,
220
  content={"error": exc.detail}
221
  )
222
 
223
  @app.exception_handler(Exception)
224
- async def handle_generic_exception(request, exc):
225
  return JSONResponse(
226
  status_code=500,
227
  content={"error": "Internal server error"}
 
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
8
  from tensorflow.keras.models import Model, load_model
9
+ from tensorflow.keras.layers import Input
10
  from tensorflow.keras.preprocessing.image import img_to_array
11
+ from tensorflow.keras.applications.densenet import DenseNet121, preprocess_input
12
  import numpy as np
13
  from PIL import Image
14
  import matplotlib.pyplot as plt
 
27
  app = FastAPI(
28
  title="ChexNet Medical Imaging API",
29
  description="API for chest X-ray analysis with Grad-CAM visualization",
30
+ version="2.3.0"
31
  )
32
 
33
  # Rate limiter setup
 
49
  # Session storage for heatmaps
50
  heatmap_store: Dict[str, dict] = {}
51
 
 
 
 
 
 
 
 
52
  # Model configuration
53
  layer_name = 'conv5_block16_concat'
54
  class_names = [
 
57
  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
58
  ]
59
 
60
+ def build_model():
61
+ """Build DenseNet121 model architecture"""
62
+ base_model = DenseNet121(weights=None, include_top=False, input_shape=(None, None, 3))
63
+ x = base_model.output
64
+ x = tf.keras.layers.GlobalAveragePooling2D()(x)
65
+ predictions = tf.keras.layers.Dense(len(class_names), activation='sigmoid')(x)
66
+ return Model(inputs=base_model.input, outputs=predictions)
67
+
68
+ def load_model_with_fallback():
69
+ """Attempt multiple strategies to load the model"""
70
+ try:
71
+ # Strategy 1: Try direct loading
72
+ model = load_model('Densenet.h5', compile=False)
73
+ model.load_weights('pretrained_model.h5')
74
+ return model
75
+ except Exception as e:
76
+ print(f"Direct loading failed: {e}")
77
+
78
+ try:
79
+ # Strategy 2: Build architecture and load weights
80
+ model = build_model()
81
+ model.load_weights('pretrained_model.h5')
82
+ return model
83
+ except Exception as e:
84
+ print(f"Architecture rebuild failed: {e}")
85
+ raise RuntimeError("All model loading strategies failed")
86
+
87
+ # Load model
88
+ try:
89
+ model = load_model_with_fallback()
90
+ print("✅ Model loaded successfully!")
91
+ except Exception as e:
92
+ print(f"❌ Model loading failed: {e}")
93
+ raise
94
+
95
  def cleanup_expired_heatmaps():
96
  """Remove heatmaps older than HEATMAP_EXPIRY seconds"""
97
  now = datetime.now()
 
102
  for sid in expired:
103
  del heatmap_store[sid]
104
 
105
+ def generate_gradcam(img):
106
  """Generate Grad-CAM heatmap overlay"""
107
  img_array = img_to_array(img)
108
  img_array = np.expand_dims(img_array, axis=0)
 
132
  heatmap_img = heatmap_img.resize(original_img.size)
133
  return Image.blend(original_img, heatmap_img, 0.5)
134
 
135
+ def process_predictions(predictions):
136
  """Format predictions with top 4 classes"""
137
  decoded = []
138
  for pred in predictions:
139
+ top_indices = pred.argsort()[-4:][::-1]
140
+ decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
141
  return decoded
142
 
143
  def preprocess_image(file_bytes):
 
148
 
149
  @app.get("/", include_in_schema=False)
150
  async def root():
151
+ return {
152
+ "message": "ChexNet API is operational",
153
+ "endpoints": {
154
+ "docs": "/docs",
155
+ "health": "/health",
156
+ "analyze": "POST /analyze"
157
+ }
158
+ }
159
 
160
  @app.get("/health")
161
  async def health_check():
162
  return {
163
+ "status": "healthy" if model else "unhealthy",
164
+ "timestamp": datetime.now().isoformat(),
165
+ "model_loaded": bool(model)
166
  }
167
 
168
  @app.get("/model/classes")
 
172
  @app.post("/analyze")
173
  @limiter.limit("5/minute")
174
  async def analyze_image(request: Request, file: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
 
 
175
  if not file.content_type.startswith('image/'):
176
  raise HTTPException(400, "Only image files are accepted")
177
 
178
  if file.size > MAX_FILE_SIZE:
179
+ raise HTTPException(413, f"Max file size is {MAX_FILE_SIZE//(1024*1024)}MB")
180
 
181
  try:
182
+ contents = await file.read()
183
+ img = preprocess_image(contents)
184
 
185
  # Prepare input tensor
186
  img_array = img_to_array(img)
 
189
 
190
  # Get predictions
191
  predictions = model.predict(img_array)
192
+ decoded = process_predictions(predictions)
193
 
194
  # Generate Grad-CAM
195
+ heatmap = generate_gradcam(img)
196
 
197
  # Store heatmap with session ID
198
  session_id = str(uuid.uuid4())
 
210
  "predictions": decoded[0],
211
  "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
212
  }
 
213
  except Exception as e:
214
  raise HTTPException(500, f"Processing failed: {str(e)}")
215
 
216
  @app.get("/static/heatmap/{session_id}")
217
  async def get_heatmap(session_id: str):
 
218
  if session_id not in heatmap_store:
219
  raise HTTPException(404, "Session expired or invalid")
 
220
  return StreamingResponse(
221
  io.BytesIO(heatmap_store[session_id]['image']),
222
  media_type="image/png",
 
225
 
226
  @app.get("/model/info")
227
  async def model_info():
 
228
  return {
229
  "model_type": "DenseNet121",
230
  "input_size": "540x540",
 
233
  "rate_limit": "5 requests/minute"
234
  }
235
 
 
236
  @app.exception_handler(HTTPException)
237
+ async def http_handler(request: Request, exc: HTTPException):
238
  return JSONResponse(
239
  status_code=exc.status_code,
240
  content={"error": exc.detail}
241
  )
242
 
243
  @app.exception_handler(Exception)
244
+ async def generic_handler(request: Request, exc: Exception):
245
  return JSONResponse(
246
  status_code=500,
247
  content={"error": "Internal server error"}