thehammadishaq commited on
Commit
b9aef46
·
verified ·
1 Parent(s): 67849a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -23
app.py CHANGED
@@ -6,9 +6,10 @@ from slowapi import Limiter
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
8
  from tensorflow.keras.models import Model, load_model
9
- from tensorflow.keras.layers import Input
 
10
  from tensorflow.keras.preprocessing.image import img_to_array
11
- from tensorflow.keras.applications.densenet import DenseNet121, preprocess_input
12
  import numpy as np
13
  from PIL import Image
14
  import matplotlib.pyplot as plt
@@ -27,7 +28,7 @@ HEATMAP_EXPIRY = 300 # 5 minutes in seconds
27
  app = FastAPI(
28
  title="ChexNet Medical Imaging API",
29
  description="API for chest X-ray analysis with Grad-CAM visualization",
30
- version="2.3.0"
31
  )
32
 
33
  # Rate limiter setup
@@ -57,37 +58,47 @@ class_names = [
57
  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
58
  ]
59
 
60
- def build_model():
61
- """Build DenseNet121 model architecture"""
62
- base_model = DenseNet121(weights=None, include_top=False, input_shape=(None, None, 3))
 
 
 
 
63
  x = base_model.output
64
- x = tf.keras.layers.GlobalAveragePooling2D()(x)
65
- predictions = tf.keras.layers.Dense(len(class_names), activation='sigmoid')(x)
 
66
  return Model(inputs=base_model.input, outputs=predictions)
67
 
68
- def load_model_with_fallback():
69
- """Attempt multiple strategies to load the model"""
70
  try:
71
- # Strategy 1: Try direct loading
72
- model = load_model('Densenet.h5', compile=False)
73
  model.load_weights('pretrained_model.h5')
74
  return model
75
  except Exception as e:
76
- print(f"Direct loading failed: {e}")
77
-
78
  try:
79
- # Strategy 2: Build architecture and load weights
80
- model = build_model()
81
- model.load_weights('pretrained_model.h5')
 
 
 
 
 
82
  return model
83
  except Exception as e:
84
- print(f"Architecture rebuild failed: {e}")
85
- raise RuntimeError("All model loading strategies failed")
86
 
87
  # Load model
88
  try:
89
- model = load_model_with_fallback()
90
  print("✅ Model loaded successfully!")
 
91
  except Exception as e:
92
  print(f"❌ Model loading failed: {e}")
93
  raise
@@ -133,10 +144,11 @@ def generate_gradcam(img):
133
  return Image.blend(original_img, heatmap_img, 0.5)
134
 
135
  def process_predictions(predictions):
136
- """Format predictions with top 4 classes"""
137
  decoded = []
138
  for pred in predictions:
139
- top_indices = pred.argsort()[-4:][::-1]
 
140
  decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
141
  return decoded
142
 
@@ -162,7 +174,8 @@ async def health_check():
162
  return {
163
  "status": "healthy" if model else "unhealthy",
164
  "timestamp": datetime.now().isoformat(),
165
- "model_loaded": bool(model)
 
166
  }
167
 
168
  @app.get("/model/classes")
@@ -172,6 +185,7 @@ async def get_class_names():
172
  @app.post("/analyze")
173
  @limiter.limit("5/minute")
174
  async def analyze_image(request: Request, file: UploadFile = File(...)):
 
175
  if not file.content_type.startswith('image/'):
176
  raise HTTPException(400, "Only image files are accepted")
177
 
@@ -215,6 +229,7 @@ async def analyze_image(request: Request, file: UploadFile = File(...)):
215
 
216
  @app.get("/static/heatmap/{session_id}")
217
  async def get_heatmap(session_id: str):
 
218
  if session_id not in heatmap_store:
219
  raise HTTPException(404, "Session expired or invalid")
220
  return StreamingResponse(
@@ -225,10 +240,12 @@ async def get_heatmap(session_id: str):
225
 
226
  @app.get("/model/info")
227
  async def model_info():
 
228
  return {
229
  "model_type": "DenseNet121",
230
  "input_size": "540x540",
231
  "classes": len(class_names),
 
232
  "gradcam_layer": layer_name,
233
  "rate_limit": "5 requests/minute"
234
  }
 
6
  from slowapi.util import get_remote_address
7
  import tensorflow as tf
8
  from tensorflow.keras.models import Model, load_model
9
+ from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
10
+ from tensorflow.keras.applications import DenseNet121
11
  from tensorflow.keras.preprocessing.image import img_to_array
12
+ from tensorflow.keras.applications.densenet import preprocess_input
13
  import numpy as np
14
  from PIL import Image
15
  import matplotlib.pyplot as plt
 
28
  app = FastAPI(
29
  title="ChexNet Medical Imaging API",
30
  description="API for chest X-ray analysis with Grad-CAM visualization",
31
+ version="3.0.0"
32
  )
33
 
34
  # Rate limiter setup
 
58
  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
59
  ]
60
 
61
+ def build_custom_model():
62
+ """Build model with correct output shape matching your weights"""
63
+ base_model = DenseNet121(
64
+ weights=None,
65
+ include_top=False,
66
+ input_shape=(None, None, 3)
67
+ )
68
  x = base_model.output
69
+ x = GlobalAveragePooling2D()(x)
70
+ # Match the output shape in your pretrained weights (14 classes)
71
+ predictions = Dense(14, activation='sigmoid')(x)
72
  return Model(inputs=base_model.input, outputs=predictions)
73
 
74
+ def load_model_with_retry():
75
+ """Enhanced model loading with shape compatibility handling"""
76
  try:
77
+ # First try loading with custom architecture
78
+ model = build_custom_model()
79
  model.load_weights('pretrained_model.h5')
80
  return model
81
  except Exception as e:
82
+ print(f"Loading with custom architecture failed: {e}")
 
83
  try:
84
+ # Fallback to direct loading with compile=False
85
+ model = load_model('Densenet.h5', compile=False)
86
+ # Ensure output layer matches our class names
87
+ if model.layers[-1].output_shape[-1] != len(class_names):
88
+ print("Adjusting output layer to match class names")
89
+ x = model.layers[-2].output
90
+ predictions = Dense(len(class_names), activation='sigmoid')(x)
91
+ model = Model(inputs=model.input, outputs=predictions)
92
  return model
93
  except Exception as e:
94
+ print(f"All loading attempts failed: {e}")
95
+ raise RuntimeError(f"Could not load model: {str(e)}")
96
 
97
  # Load model
98
  try:
99
+ model = load_model_with_retry()
100
  print("✅ Model loaded successfully!")
101
+ print(f"Model output shape: {model.output_shape}")
102
  except Exception as e:
103
  print(f"❌ Model loading failed: {e}")
104
  raise
 
144
  return Image.blend(original_img, heatmap_img, 0.5)
145
 
146
  def process_predictions(predictions):
147
+ """Format predictions with top classes"""
148
  decoded = []
149
  for pred in predictions:
150
+ # Get indices sorted by probability (descending)
151
+ top_indices = np.argsort(pred)[::-1][:len(class_names)]
152
  decoded.append([(class_names[i], float(pred[i])) for i in top_indices])
153
  return decoded
154
 
 
174
  return {
175
  "status": "healthy" if model else "unhealthy",
176
  "timestamp": datetime.now().isoformat(),
177
+ "model_loaded": bool(model),
178
+ "model_output_shape": str(model.output_shape) if model else "N/A"
179
  }
180
 
181
  @app.get("/model/classes")
 
185
  @app.post("/analyze")
186
  @limiter.limit("5/minute")
187
  async def analyze_image(request: Request, file: UploadFile = File(...)):
188
+ """Analyze chest X-ray image"""
189
  if not file.content_type.startswith('image/'):
190
  raise HTTPException(400, "Only image files are accepted")
191
 
 
229
 
230
  @app.get("/static/heatmap/{session_id}")
231
  async def get_heatmap(session_id: str):
232
+ """Retrieve Grad-CAM visualization"""
233
  if session_id not in heatmap_store:
234
  raise HTTPException(404, "Session expired or invalid")
235
  return StreamingResponse(
 
240
 
241
  @app.get("/model/info")
242
  async def model_info():
243
+ """Get model metadata"""
244
  return {
245
  "model_type": "DenseNet121",
246
  "input_size": "540x540",
247
  "classes": len(class_names),
248
+ "output_shape": str(model.output_shape),
249
  "gradcam_layer": layer_name,
250
  "rate_limit": "5 requests/minute"
251
  }