thehammadishaq commited on
Commit
1c32437
·
verified ·
1 Parent(s): 08043e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -68
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse, StreamingResponse
4
  from fastapi.staticfiles import StaticFiles
@@ -6,7 +6,7 @@ from slowapi import Limiter
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
8
  from tensorflow.keras.models import Model, load_model
9
- from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
10
  from tensorflow.keras.applications import DenseNet121
11
  from tensorflow.keras.preprocessing.image import img_to_array
12
  from tensorflow.keras.applications.densenet import preprocess_input
@@ -18,23 +18,29 @@ import io
18
  import uuid
19
  from typing import Dict
20
  from datetime import datetime, timedelta
 
21
  import os
22
 
23
  # Configuration
24
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
25
  HEATMAP_EXPIRY = 300 # 5 minutes in seconds
 
26
 
27
  # Initialize FastAPI with rate limiting
28
  app = FastAPI(
29
  title="ChexNet Medical Imaging API",
30
  description="API for chest X-ray analysis with Grad-CAM visualization",
31
- version="3.0.0"
32
  )
33
 
34
  # Rate limiter setup
35
  limiter = Limiter(key_func=get_remote_address)
36
  app.state.limiter = limiter
37
 
 
 
 
 
38
  # Mount static files
39
  app.mount("/static", StaticFiles(directory="static"), name="static")
40
 
@@ -47,7 +53,7 @@ app.add_middleware(
47
  allow_headers=["*"],
48
  )
49
 
50
- # Session storage for heatmaps
51
  heatmap_store: Dict[str, dict] = {}
52
 
53
  # Model configuration
@@ -58,8 +64,8 @@ class_names = [
58
  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
59
  ]
60
 
61
- def build_custom_model():
62
- """Build model with correct output shape matching your weights"""
63
  base_model = DenseNet121(
64
  weights=None,
65
  include_top=False,
@@ -67,54 +73,50 @@ def build_custom_model():
67
  )
68
  x = base_model.output
69
  x = GlobalAveragePooling2D()(x)
70
- # Match the output shape in your pretrained weights (14 classes)
71
- predictions = Dense(14, activation='sigmoid')(x)
72
  return Model(inputs=base_model.input, outputs=predictions)
73
 
74
- def load_model_with_retry():
75
- """Enhanced model loading with shape compatibility handling"""
76
  try:
77
- # First try loading with custom architecture
78
- model = build_custom_model()
79
  model.load_weights('pretrained_model.h5')
80
  return model
81
  except Exception as e:
82
- print(f"Loading with custom architecture failed: {e}")
83
  try:
84
- # Fallback to direct loading with compile=False
85
  model = load_model('Densenet.h5', compile=False)
86
- # Ensure output layer matches our class names
87
- if model.layers[-1].output_shape[-1] != len(class_names):
88
- print("Adjusting output layer to match class names")
89
- x = model.layers[-2].output
90
- predictions = Dense(len(class_names), activation='sigmoid')(x)
91
- model = Model(inputs=model.input, outputs=predictions)
92
  return model
93
  except Exception as e:
94
- print(f"All loading attempts failed: {e}")
95
- raise RuntimeError(f"Could not load model: {str(e)}")
96
 
97
  # Load model
98
  try:
99
- model = load_model_with_retry()
100
  print("✅ Model loaded successfully!")
 
101
  print(f"Model output shape: {model.output_shape}")
102
  except Exception as e:
103
  print(f"❌ Model loading failed: {e}")
104
  raise
105
 
106
- def cleanup_expired_heatmaps():
107
- """Remove heatmaps older than HEATMAP_EXPIRY seconds"""
108
  now = datetime.now()
109
- expired = [
110
- sid for sid, data in heatmap_store.items()
111
- if (now - data['timestamp']).total_seconds() > HEATMAP_EXPIRY
112
- ]
113
- for sid in expired:
114
- del heatmap_store[sid]
 
 
115
 
116
  def generate_gradcam(img):
117
- """Generate Grad-CAM heatmap overlay"""
118
  img_array = img_to_array(img)
119
  img_array = np.expand_dims(img_array, axis=0)
120
  img_array = preprocess_input(img_array)
@@ -144,24 +146,23 @@ def generate_gradcam(img):
144
  return Image.blend(original_img, heatmap_img, 0.5)
145
 
146
  def process_predictions(predictions):
147
- """Format predictions with top classes"""
148
  decoded = []
149
  for pred in predictions:
150
- # Get indices sorted by probability (descending)
151
  top_indices = np.argsort(pred)[::-1][:len(class_names)]
152
  decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
153
  return decoded
154
 
155
  def preprocess_image(file_bytes):
156
- """Convert uploaded file to processed image array"""
157
  img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
158
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
159
  return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
160
 
161
- @app.get("/", include_in_schema=False)
162
  async def root():
163
  return {
164
- "message": "ChexNet API is operational",
165
  "endpoints": {
166
  "docs": "/docs",
167
  "health": "/health",
@@ -172,10 +173,11 @@ async def root():
172
  @app.get("/health")
173
  async def health_check():
174
  return {
175
- "status": "healthy" if model else "unhealthy",
 
176
  "timestamp": datetime.now().isoformat(),
177
- "model_loaded": bool(model),
178
- "model_output_shape": str(model.output_shape) if model else "N/A"
179
  }
180
 
181
  @app.get("/model/classes")
@@ -184,13 +186,17 @@ async def get_class_names():
184
 
185
  @app.post("/analyze")
186
  @limiter.limit("5/minute")
187
- async def analyze_image(request: Request, file: UploadFile = File(...)):
 
 
 
 
188
  """Analyze chest X-ray image"""
189
  if not file.content_type.startswith('image/'):
190
- raise HTTPException(400, "Only image files are accepted")
191
 
192
  if file.size > MAX_FILE_SIZE:
193
- raise HTTPException(413, f"Max file size is {MAX_FILE_SIZE//(1024*1024)}MB")
194
 
195
  try:
196
  contents = await file.read()
@@ -210,56 +216,49 @@ async def analyze_image(request: Request, file: UploadFile = File(...)):
210
 
211
  # Store heatmap with session ID
212
  session_id = str(uuid.uuid4())
213
- img_bytes = io.BytesIO()
214
- heatmap.save(img_bytes, format='PNG')
215
 
216
- heatmap_store[session_id] = {
217
- 'image': img_bytes.getvalue(),
218
- 'timestamp': datetime.now()
219
- }
220
- cleanup_expired_heatmaps()
221
 
222
  return {
223
  "session_id": session_id,
224
  "predictions": decoded[0],
225
- "heatmap_url": f"{request.base_url}static/heatmap/{session_id}"
226
  }
227
  except Exception as e:
228
- raise HTTPException(500, f"Processing failed: {str(e)}")
229
-
230
- @app.get("/static/heatmap/{session_id}")
231
- async def get_heatmap(session_id: str):
232
- """Retrieve Grad-CAM visualization"""
233
- if session_id not in heatmap_store:
234
- raise HTTPException(404, "Session expired or invalid")
235
- return StreamingResponse(
236
- io.BytesIO(heatmap_store[session_id]['image']),
237
- media_type="image/png",
238
- headers={"Cache-Control": "max-age=300"}
239
- )
240
 
241
  @app.get("/model/info")
242
  async def model_info():
243
  """Get model metadata"""
244
  return {
245
  "model_type": "DenseNet121",
246
- "input_size": "540x540",
247
- "classes": len(class_names),
248
  "output_shape": str(model.output_shape),
 
249
  "gradcam_layer": layer_name,
250
  "rate_limit": "5 requests/minute"
251
  }
252
 
253
  @app.exception_handler(HTTPException)
254
- async def http_handler(request: Request, exc: HTTPException):
255
  return JSONResponse(
256
  status_code=exc.status_code,
257
  content={"error": exc.detail}
258
  )
259
 
260
  @app.exception_handler(Exception)
261
- async def generic_handler(request: Request, exc: Exception):
262
  return JSONResponse(
263
  status_code=500,
264
  content={"error": "Internal server error"}
265
- )
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Request, BackgroundTasks
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse, StreamingResponse
4
  from fastapi.staticfiles import StaticFiles
 
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
8
  from tensorflow.keras.models import Model, load_model
9
+ from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
10
  from tensorflow.keras.applications import DenseNet121
11
  from tensorflow.keras.preprocessing.image import img_to_array
12
  from tensorflow.keras.applications.densenet import preprocess_input
 
18
  import uuid
19
  from typing import Dict
20
  from datetime import datetime, timedelta
21
+ from pathlib import Path
22
  import os
23
 
24
  # Configuration
25
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
26
  HEATMAP_EXPIRY = 300 # 5 minutes in seconds
27
+ PORT = 7860 # Hugging Face Spaces requires port 7860
28
 
29
  # Initialize FastAPI with rate limiting
30
  app = FastAPI(
31
  title="ChexNet Medical Imaging API",
32
  description="API for chest X-ray analysis with Grad-CAM visualization",
33
+ version="4.0.0"
34
  )
35
 
36
  # Rate limiter setup
37
  limiter = Limiter(key_func=get_remote_address)
38
  app.state.limiter = limiter
39
 
40
+ # Create static/heatmap directory
41
+ heatmap_dir = Path("static/heatmap")
42
+ heatmap_dir.mkdir(parents=True, exist_ok=True)
43
+
44
  # Mount static files
45
  app.mount("/static", StaticFiles(directory="static"), name="static")
46
 
 
53
  allow_headers=["*"],
54
  )
55
 
56
+ # Session storage for heatmaps (now using file system)
57
  heatmap_store: Dict[str, dict] = {}
58
 
59
  # Model configuration
 
64
  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
65
  ]
66
 
67
+ def build_model():
68
+ """Build DenseNet121 model with correct output shape"""
69
  base_model = DenseNet121(
70
  weights=None,
71
  include_top=False,
 
73
  )
74
  x = base_model.output
75
  x = GlobalAveragePooling2D()(x)
76
+ predictions = Dense(14, activation='sigmoid')(x) # 14 classes in pretrained weights
 
77
  return Model(inputs=base_model.input, outputs=predictions)
78
 
79
+ def load_model_with_fallback():
80
+ """Robust model loading with multiple fallback strategies"""
81
  try:
82
+ # Strategy 1: Build model and load weights
83
+ model = build_model()
84
  model.load_weights('pretrained_model.h5')
85
  return model
86
  except Exception as e:
87
+ print(f"Primary loading failed: {e}")
88
  try:
89
+ # Strategy 2: Try direct loading
90
  model = load_model('Densenet.h5', compile=False)
 
 
 
 
 
 
91
  return model
92
  except Exception as e:
93
+ print(f"Fallback loading failed: {e}")
94
+ raise RuntimeError("All model loading strategies failed")
95
 
96
  # Load model
97
  try:
98
+ model = load_model_with_fallback()
99
  print("✅ Model loaded successfully!")
100
+ print(f"Model input shape: {model.input_shape}")
101
  print(f"Model output shape: {model.output_shape}")
102
  except Exception as e:
103
  print(f"❌ Model loading failed: {e}")
104
  raise
105
 
106
+ async def cleanup_old_heatmaps():
107
+ """Delete heatmap files older than HEATMAP_EXPIRY seconds"""
108
  now = datetime.now()
109
+ for file in heatmap_dir.glob("*.png"):
110
+ file_time = datetime.fromtimestamp(file.stat().st_mtime)
111
+ if (now - file_time) > timedelta(seconds=HEATMAP_EXPIRY):
112
+ try:
113
+ file.unlink()
114
+ print(f"Deleted expired heatmap: {file.name}")
115
+ except Exception as e:
116
+ print(f"Error deleting {file.name}: {e}")
117
 
118
  def generate_gradcam(img):
119
+ """Generate Grad-CAM heatmap visualization"""
120
  img_array = img_to_array(img)
121
  img_array = np.expand_dims(img_array, axis=0)
122
  img_array = preprocess_input(img_array)
 
146
  return Image.blend(original_img, heatmap_img, 0.5)
147
 
148
  def process_predictions(predictions):
149
+ """Format model predictions with confidence scores"""
150
  decoded = []
151
  for pred in predictions:
 
152
  top_indices = np.argsort(pred)[::-1][:len(class_names)]
153
  decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
154
  return decoded
155
 
156
  def preprocess_image(file_bytes):
157
+ """Process uploaded image file"""
158
  img = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
159
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
160
  return cv2.resize(img, (540, 540), interpolation=cv2.INTER_AREA)
161
 
162
+ @app.get("/")
163
  async def root():
164
  return {
165
+ "message": "ChexNet API is running",
166
  "endpoints": {
167
  "docs": "/docs",
168
  "health": "/health",
 
173
  @app.get("/health")
174
  async def health_check():
175
  return {
176
+ "status": "healthy",
177
+ "model_loaded": True,
178
  "timestamp": datetime.now().isoformat(),
179
+ "port": PORT,
180
+ "heatmap_files": len(list(heatmap_dir.glob("*.png")))
181
  }
182
 
183
  @app.get("/model/classes")
 
186
 
187
  @app.post("/analyze")
188
  @limiter.limit("5/minute")
189
+ async def analyze_image(
190
+ request: Request,
191
+ background_tasks: BackgroundTasks,
192
+ file: UploadFile = File(...)
193
+ ):
194
  """Analyze chest X-ray image"""
195
  if not file.content_type.startswith('image/'):
196
+ raise HTTPException(400, "Only image files accepted")
197
 
198
  if file.size > MAX_FILE_SIZE:
199
+ raise HTTPException(413, f"File too large (max {MAX_FILE_SIZE//1024//1024}MB)")
200
 
201
  try:
202
  contents = await file.read()
 
216
 
217
  # Store heatmap with session ID
218
  session_id = str(uuid.uuid4())
219
+ heatmap_path = heatmap_dir / f"{session_id}.png"
220
+ heatmap.save(heatmap_path)
221
 
222
+ # Add cleanup task
223
+ background_tasks.add_task(cleanup_old_heatmaps)
224
+
225
+ # Generate HTTPS URL
226
+ heatmap_url = f"https://{request.url.hostname}/static/heatmap/{session_id}.png"
227
 
228
  return {
229
  "session_id": session_id,
230
  "predictions": decoded[0],
231
+ "heatmap_url": heatmap_url
232
  }
233
  except Exception as e:
234
+ raise HTTPException(500, f"Analysis failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  @app.get("/model/info")
237
  async def model_info():
238
  """Get model metadata"""
239
  return {
240
  "model_type": "DenseNet121",
241
+ "input_shape": str(model.input_shape),
 
242
  "output_shape": str(model.output_shape),
243
+ "classes": len(class_names),
244
  "gradcam_layer": layer_name,
245
  "rate_limit": "5 requests/minute"
246
  }
247
 
248
  @app.exception_handler(HTTPException)
249
+ async def http_exception_handler(request, exc):
250
  return JSONResponse(
251
  status_code=exc.status_code,
252
  content={"error": exc.detail}
253
  )
254
 
255
  @app.exception_handler(Exception)
256
+ async def generic_exception_handler(request, exc):
257
  return JSONResponse(
258
  status_code=500,
259
  content={"error": "Internal server error"}
260
+ )
261
+
262
+ if __name__ == "__main__":
263
+ import uvicorn
264
+ uvicorn.run(app, host="0.0.0.0", port=PORT)