thehammadishaq commited on
Commit
a9ea101
·
verified ·
1 Parent(s): bd9b8a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -81
app.py CHANGED
@@ -16,7 +16,6 @@ import matplotlib.pyplot as plt
16
  import cv2
17
  import io
18
  import uuid
19
- from typing import Dict
20
  from datetime import datetime, timedelta
21
  from pathlib import Path
22
  import os
@@ -24,25 +23,26 @@ import os
24
  # Configuration
25
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
26
  HEATMAP_EXPIRY = 300 # 5 minutes in seconds
27
- PORT = 7860 # Hugging Face Spaces requires port 7860
 
28
 
29
- # Initialize FastAPI with rate limiting
30
  app = FastAPI(
31
  title="ChexNet Medical Imaging API",
32
  description="API for chest X-ray analysis with Grad-CAM visualization",
33
- version="4.0.0"
34
  )
35
 
36
  # Rate limiter setup
37
  limiter = Limiter(key_func=get_remote_address)
38
  app.state.limiter = limiter
39
 
40
- # Create static/heatmap directory
41
- heatmap_dir = Path("static/heatmap")
42
  heatmap_dir.mkdir(parents=True, exist_ok=True)
43
 
44
- # Mount static files
45
- app.mount("/static", StaticFiles(directory="static"), name="static")
46
 
47
  # CORS configuration
48
  app.add_middleware(
@@ -53,9 +53,6 @@ app.add_middleware(
53
  allow_headers=["*"],
54
  )
55
 
56
- # Session storage for heatmaps (now using file system)
57
- heatmap_store: Dict[str, dict] = {}
58
-
59
  # Model configuration
60
  layer_name = 'conv5_block16_concat'
61
  class_names = [
@@ -65,7 +62,6 @@ class_names = [
65
  ]
66
 
67
  def build_model():
68
- """Build DenseNet121 model with correct output shape"""
69
  base_model = DenseNet121(
70
  weights=None,
71
  include_top=False,
@@ -73,20 +69,17 @@ def build_model():
73
  )
74
  x = base_model.output
75
  x = GlobalAveragePooling2D()(x)
76
- predictions = Dense(14, activation='sigmoid')(x) # 14 classes in pretrained weights
77
  return Model(inputs=base_model.input, outputs=predictions)
78
 
79
  def load_model_with_fallback():
80
- """Robust model loading with multiple fallback strategies"""
81
  try:
82
- # Strategy 1: Build model and load weights
83
  model = build_model()
84
  model.load_weights('pretrained_model.h5')
85
  return model
86
  except Exception as e:
87
  print(f"Primary loading failed: {e}")
88
  try:
89
- # Strategy 2: Try direct loading
90
  model = load_model('Densenet.h5', compile=False)
91
  return model
92
  except Exception as e:
@@ -97,26 +90,21 @@ def load_model_with_fallback():
97
  try:
98
  model = load_model_with_fallback()
99
  print("✅ Model loaded successfully!")
100
- print(f"Model input shape: {model.input_shape}")
101
- print(f"Model output shape: {model.output_shape}")
102
  except Exception as e:
103
  print(f"❌ Model loading failed: {e}")
104
  raise
105
 
106
  async def cleanup_old_heatmaps():
107
- """Delete heatmap files older than HEATMAP_EXPIRY seconds"""
108
  now = datetime.now()
109
  for file in heatmap_dir.glob("*.png"):
110
  file_time = datetime.fromtimestamp(file.stat().st_mtime)
111
  if (now - file_time) > timedelta(seconds=HEATMAP_EXPIRY):
112
  try:
113
  file.unlink()
114
- print(f"Deleted expired heatmap: {file.name}")
115
  except Exception as e:
116
  print(f"Error deleting {file.name}: {e}")
117
 
118
  def generate_gradcam(img):
119
- """Generate Grad-CAM heatmap visualization"""
120
  img_array = img_to_array(img)
121
  img_array = np.expand_dims(img_array, axis=0)
122
  img_array = preprocess_input(img_array)
@@ -146,7 +134,6 @@ def generate_gradcam(img):
146
  return Image.blend(original_img, heatmap_img, 0.5)
147
 
148
  def process_predictions(predictions):
149
- """Format model predictions with confidence scores"""
150
  decoded = []
151
  for pred in predictions:
152
  top_indices = np.argsort(pred)[::-1][:len(class_names)]
@@ -154,36 +141,10 @@ def process_predictions(predictions):
154
  return decoded
155
 
156
  def preprocess_image(file_bytes):
157
- """Process uploaded image file"""
158
  img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
159
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
160
  return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
161
 
162
- @app.get("/")
163
- async def root():
164
- return {
165
- "message": "ChexNet API is running",
166
- "endpoints": {
167
- "docs": "/docs",
168
- "health": "/health",
169
- "analyze": "POST /analyze"
170
- }
171
- }
172
-
173
- @app.get("/health")
174
- async def health_check():
175
- return {
176
- "status": "healthy",
177
- "model_loaded": True,
178
- "timestamp": datetime.now().isoformat(),
179
- "port": PORT,
180
- "heatmap_files": len(list(heatmap_dir.glob("*.png")))
181
- }
182
-
183
- @app.get("/model/classes")
184
- async def get_class_names():
185
- return {"classes": class_names}
186
-
187
  @app.post("/analyze")
188
  @limiter.limit("5/minute")
189
  async def analyze_image(
@@ -191,7 +152,6 @@ async def analyze_image(
191
  background_tasks: BackgroundTasks,
192
  file: UploadFile = File(...)
193
  ):
194
- """Analyze chest X-ray image"""
195
  if not file.content_type.startswith('image/'):
196
  raise HTTPException(400, "Only image files accepted")
197
 
@@ -202,63 +162,37 @@ async def analyze_image(
202
  contents = await file.read()
203
  img = preprocess_image(contents)
204
 
205
- # Prepare input tensor
206
  img_array = img_to_array(img)
207
  img_array = np.expand_dims(img_array, axis=0)
208
  img_array = preprocess_input(img_array)
209
 
210
- # Get predictions
211
  predictions = model.predict(img_array)
212
  decoded = process_predictions(predictions)
213
 
214
- # Generate Grad-CAM
215
  heatmap = generate_gradcam(img)
216
 
217
- # Store heatmap with session ID
218
  session_id = str(uuid.uuid4())
219
  heatmap_path = heatmap_dir / f"{session_id}.png"
220
  heatmap.save(heatmap_path)
221
 
222
- # Add cleanup task
223
  background_tasks.add_task(cleanup_old_heatmaps)
224
 
225
- # Generate HTTPS URL
226
- heatmap_url = f"https://{request.url.hostname}/static/heatmap/{session_id}.png"
227
-
228
  return {
229
  "session_id": session_id,
230
  "predictions": decoded[0],
231
- "heatmap_url": heatmap_url
232
  }
233
  except Exception as e:
234
  raise HTTPException(500, f"Analysis failed: {str(e)}")
235
 
236
- @app.get("/model/info")
237
- async def model_info():
238
- """Get model metadata"""
239
  return {
240
- "model_type": "DenseNet121",
241
- "input_shape": str(model.input_shape),
242
- "output_shape": str(model.output_shape),
243
- "classes": len(class_names),
244
- "gradcam_layer": layer_name,
245
- "rate_limit": "5 requests/minute"
246
  }
247
 
248
- @app.exception_handler(HTTPException)
249
- async def http_exception_handler(request, exc):
250
- return JSONResponse(
251
- status_code=exc.status_code,
252
- content={"error": exc.detail}
253
- )
254
-
255
- @app.exception_handler(Exception)
256
- async def generic_exception_handler(request, exc):
257
- return JSONResponse(
258
- status_code=500,
259
- content={"error": "Internal server error"}
260
- )
261
-
262
  if __name__ == "__main__":
263
  import uvicorn
264
  uvicorn.run(app, host="0.0.0.0", port=PORT)
 
16
  import cv2
17
  import io
18
  import uuid
 
19
  from datetime import datetime, timedelta
20
  from pathlib import Path
21
  import os
 
23
  # Configuration
24
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
25
  HEATMAP_EXPIRY = 300 # 5 minutes in seconds
26
+ PORT = 7860
27
+ HEATMAP_DIR = "/tmp/heatmaps" # Changed to writable /tmp directory
28
 
29
+ # Initialize FastAPI
30
  app = FastAPI(
31
  title="ChexNet Medical Imaging API",
32
  description="API for chest X-ray analysis with Grad-CAM visualization",
33
+ version="4.1.0"
34
  )
35
 
36
  # Rate limiter setup
37
  limiter = Limiter(key_func=get_remote_address)
38
  app.state.limiter = limiter
39
 
40
+ # Create heatmap directory (in /tmp which is writable)
41
+ heatmap_dir = Path(HEATMAP_DIR)
42
  heatmap_dir.mkdir(parents=True, exist_ok=True)
43
 
44
+ # Mount static files from /tmp
45
+ app.mount("/static/heatmap", StaticFiles(directory=HEATMAP_DIR), name="heatmaps")
46
 
47
  # CORS configuration
48
  app.add_middleware(
 
53
  allow_headers=["*"],
54
  )
55
 
 
 
 
56
  # Model configuration
57
  layer_name = 'conv5_block16_concat'
58
  class_names = [
 
62
  ]
63
 
64
  def build_model():
 
65
  base_model = DenseNet121(
66
  weights=None,
67
  include_top=False,
 
69
  )
70
  x = base_model.output
71
  x = GlobalAveragePooling2D()(x)
72
+ predictions = Dense(14, activation='sigmoid')(x)
73
  return Model(inputs=base_model.input, outputs=predictions)
74
 
75
  def load_model_with_fallback():
 
76
  try:
 
77
  model = build_model()
78
  model.load_weights('pretrained_model.h5')
79
  return model
80
  except Exception as e:
81
  print(f"Primary loading failed: {e}")
82
  try:
 
83
  model = load_model('Densenet.h5', compile=False)
84
  return model
85
  except Exception as e:
 
90
  try:
91
  model = load_model_with_fallback()
92
  print("✅ Model loaded successfully!")
 
 
93
  except Exception as e:
94
  print(f"❌ Model loading failed: {e}")
95
  raise
96
 
97
  async def cleanup_old_heatmaps():
 
98
  now = datetime.now()
99
  for file in heatmap_dir.glob("*.png"):
100
  file_time = datetime.fromtimestamp(file.stat().st_mtime)
101
  if (now - file_time) > timedelta(seconds=HEATMAP_EXPIRY):
102
  try:
103
  file.unlink()
 
104
  except Exception as e:
105
  print(f"Error deleting {file.name}: {e}")
106
 
107
  def generate_gradcam(img):
 
108
  img_array = img_to_array(img)
109
  img_array = np.expand_dims(img_array, axis=0)
110
  img_array = preprocess_input(img_array)
 
134
  return Image.blend(original_img, heatmap_img, 0.5)
135
 
136
  def process_predictions(predictions):
 
137
  decoded = []
138
  for pred in predictions:
139
  top_indices = np.argsort(pred)[::-1][:len(class_names)]
 
141
  return decoded
142
 
143
  def preprocess_image(file_bytes):
 
144
  img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
145
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
146
  return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  @app.post("/analyze")
149
  @limiter.limit("5/minute")
150
  async def analyze_image(
 
152
  background_tasks: BackgroundTasks,
153
  file: UploadFile = File(...)
154
  ):
 
155
  if not file.content_type.startswith('image/'):
156
  raise HTTPException(400, "Only image files accepted")
157
 
 
162
  contents = await file.read()
163
  img = preprocess_image(contents)
164
 
 
165
  img_array = img_to_array(img)
166
  img_array = np.expand_dims(img_array, axis=0)
167
  img_array = preprocess_input(img_array)
168
 
 
169
  predictions = model.predict(img_array)
170
  decoded = process_predictions(predictions)
171
 
 
172
  heatmap = generate_gradcam(img)
173
 
 
174
  session_id = str(uuid.uuid4())
175
  heatmap_path = heatmap_dir / f"{session_id}.png"
176
  heatmap.save(heatmap_path)
177
 
 
178
  background_tasks.add_task(cleanup_old_heatmaps)
179
 
 
 
 
180
  return {
181
  "session_id": session_id,
182
  "predictions": decoded[0],
183
+ "heatmap_url": f"https://{request.url.hostname}/static/heatmap/{session_id}.png"
184
  }
185
  except Exception as e:
186
  raise HTTPException(500, f"Analysis failed: {str(e)}")
187
 
188
+ @app.get("/health")
189
+ async def health_check():
 
190
  return {
191
+ "status": "healthy",
192
+ "timestamp": datetime.now().isoformat(),
193
+ "heatmap_files": len(list(heatmap_dir.glob("*.png")))
 
 
 
194
  }
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  if __name__ == "__main__":
197
  import uvicorn
198
  uvicorn.run(app, host="0.0.0.0", port=PORT)