ezAhmed committed
Commit c2b4eaa · verified · 1 Parent(s): ad8b011

Upload 5 files

Files changed (1):
  app.py +235 -194

app.py CHANGED
@@ -190,60 +190,102 @@ def predict(image):
 
 def make_gradcam_heatmap(img_array, model, pred_index=None):
     """
-    Generate GradCAM heatmap for the model prediction.
     """
     try:
-        # Find the last convolutional layer
-        last_conv_layer = None
         for layer in reversed(model.layers):
-            if 'conv' in layer.name.lower() or isinstance(layer, tf.keras.layers.Conv2D):
-                last_conv_layer = layer
                 break
 
-        if last_conv_layer is None:
-            # If no conv layer, try to use patch encoder or dense layers
             for layer in reversed(model.layers):
-                if 'dense' in layer.name.lower() or 'patch' in layer.name.lower():
-                    last_conv_layer = layer
                     break
 
-        if last_conv_layer is None:
-            return None
 
-        # Create a model that maps the input to the activations of the last conv layer and the output predictions
         grad_model = tf.keras.models.Model(
-            [model.inputs],
-            [last_conv_layer.output, model.output]
         )
 
-        # Compute the gradient of the top predicted class with respect to the output feature map
         with tf.GradientTape() as tape:
-            conv_outputs, predictions = grad_model(img_array)
             if pred_index is None:
                 pred_index = tf.argmax(predictions[0])
             class_channel = predictions[:, pred_index]
 
-        # Gradient of the output neuron with respect to the output feature map
-        grads = tape.gradient(class_channel, conv_outputs)
-
-        # Vector of mean intensity of the gradient over a specific feature map channel
-        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
-
-        # Multiply each channel in the feature map array by importance
-        conv_outputs = conv_outputs[0]
-        heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
-        heatmap = tf.squeeze(heatmap)
 
-        # Normalize the heatmap
-        heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
-        return heatmap.numpy()
 
     except Exception as e:
         print(f"GradCAM error: {e}")
-        return None
 
 
-def apply_gradcam(image, heatmap):
     """
     Apply GradCAM heatmap overlay on the original image.
     """
@@ -251,21 +293,30 @@ def apply_gradcam(image, heatmap):
     if heatmap is None:
         return image
 
-    # Resize heatmap to match input image size
-    heatmap = cv2.resize(heatmap, (image_size, image_size))
-
-    # Convert heatmap to RGB
-    heatmap = np.uint8(255 * heatmap)
-    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
-
     # Convert image to numpy array
     if isinstance(image, Image.Image):
        img_array = np.array(image.resize((image_size, image_size)))
     else:
        img_array = image
 
     # Superimpose the heatmap on original image
-    superimposed_img = cv2.addWeighted(img_array, 0.6, heatmap, 0.4, 0)
 
     return Image.fromarray(superimposed_img)
@@ -279,7 +330,7 @@ def generate_gradcam(image):
     Generate GradCAM visualization.
     """
     if model is None or image is None:
-        return None, "Model not loaded or no image provided"
 
     try:
         # Preprocess image
@@ -288,23 +339,21 @@ def generate_gradcam(image):
         # Make prediction
         predictions = model.predict(processed_image, verbose=0)
         pred_class = np.argmax(predictions[0])
-        confidence = predictions[0][pred_class]
 
         # Generate heatmap
-        heatmap = make_gradcam_heatmap(processed_image, model, pred_class)
 
         if heatmap is None:
-            return None, "GradCAM not available for this model architecture"
 
         # Apply heatmap
-        gradcam_image = apply_gradcam(image, heatmap)
-
-        info = f"**Predicted Class:** {class_names[pred_class]}\n**Confidence:** {confidence:.2%}"
 
-        return gradcam_image, info
 
     except Exception as e:
-        return None, f"Error generating GradCAM: {str(e)}"
 
 
 # -----------------------------
@@ -317,10 +366,10 @@ def generate_shap(image):
     Generate SHAP explanation visualization.
     """
     if not SHAP_AVAILABLE:
-        return None, "SHAP library not available. Please install shap package."
 
     if model is None or image is None:
-        return None, "Model not loaded or no image provided"
 
     try:
         # Preprocess image
@@ -362,18 +411,11 @@ def generate_shap(image):
         shap_image = Image.open(buf)
         plt.close()
 
-        # Get prediction info
-        prediction = model_predict(img_array[np.newaxis, ...])[0]
-        pred_class = np.argmax(prediction)
-        confidence = prediction[pred_class]
-
-        info = f"**Predicted Class:** {class_names[pred_class]}\n**Confidence:** {confidence:.2%}"
-
-        return shap_image, info
 
     except Exception as e:
         print(f"SHAP error: {e}")
-        return None, f"Error generating SHAP: {str(e)}"
 
 
 # -----------------------------
@@ -386,10 +428,10 @@ def generate_lime(image):
     Generate LIME explanation visualization.
     """
     if not LIME_AVAILABLE or not SKIMAGE_AVAILABLE:
-        return None, None, "LIME or scikit-image library not available. Please install lime and scikit-image packages."
 
     if model is None or image is None:
-        return None, None, "Model not loaded or no image provided"
 
     try:
         # Preprocess image
@@ -414,12 +456,6 @@ def generate_lime(image):
             batch_size=32
         )
 
-        # Get prediction
-        prediction = model.predict(
-            img_normalized.reshape(1, image_size, image_size, 3))
-        pred_class = np.argmax(prediction)
-        confidence = np.max(prediction)
-
         # Create visualizations
         # Positive features only
         temp_positive, mask_positive = explanation.get_image_and_mask(
@@ -444,13 +480,43 @@ def generate_lime(image):
             (lime_positive * 255).astype(np.uint8))
         lime_both_img = Image.fromarray((lime_both * 255).astype(np.uint8))
 
-        info = f"**Predicted Class:** {class_names[pred_class]}\n**Confidence:** {confidence:.2%}"
-
-        return lime_positive_img, lime_both_img, info
 
     except Exception as e:
         print(f"LIME error: {e}")
-        return None, None, f"Error generating LIME: {str(e)}"
 
  # -----------------------------
@@ -461,8 +527,8 @@ description = """
 <div style="text-align: center; padding: 20px;">
     <h2 style="color: #2E86AB;">Advanced Medical Image Analysis with Explainable AI</h2>
     <p style="font-size: 16px; color: #555;">
-        Upload an endoscopic image to classify using a <b>Lightweight Vision Transformer</b> model
-        and explore <b>multiple explainability methods</b> to understand the model's decision-making process.
     </p>
     <div style="display: flex; justify-content: center; gap: 20px; margin-top: 15px;">
         <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white;">
@@ -484,20 +550,6 @@ custom_css = """
 .gradio-container {
     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
 }
-.tab-nav button {
-    font-size: 16px !important;
-    font-weight: bold !important;
-    padding: 12px 24px !important;
-    border-radius: 10px 10px 0 0 !important;
-    transition: all 0.3s ease !important;
-}
-.tab-nav button.selected {
-    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
-    color: white !important;
-}
-.tab-nav button:hover {
-    transform: translateY(-2px) !important;
-}
 h1 {
     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
     -webkit-background-clip: text;
@@ -505,12 +557,6 @@ h1 {
     font-size: 2.5em !important;
     text-align: center !important;
 }
-.output-class {
-    font-size: 18px !important;
-    padding: 10px !important;
-    border-radius: 8px !important;
-    background: linear-gradient(135deg, #e0f7fa 0%, #e1bee7 100%) !important;
-}
 .button-primary {
     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
     border: none !important;
@@ -518,6 +564,7 @@ h1 {
     font-weight: bold !important;
     padding: 12px 30px !important;
     border-radius: 25px !important;
     transition: all 0.3s ease !important;
 }
 .button-primary:hover {
@@ -526,103 +573,98 @@ h1 {
 }
 """
 
-examples = []  # Add example image paths if available
-
 # Create Gradio interface using Blocks with creative design
 with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
     gr.HTML(f"<h1>{title}</h1>")
     gr.HTML(description)
 
-    with gr.Tabs():
-        # Tab 1: Basic Prediction
-        with gr.Tab("📊 Classification"):
-            with gr.Row():
-                with gr.Column(scale=1):
-                    input_image_basic = gr.Image(
-                        type="pil", label="Upload Endoscopic Image")
-                    classify_btn = gr.Button(
-                        "🔍 Classify Image", variant="primary", elem_classes="button-primary")
-                    gr.Markdown("""
-                    <div style="background: #f0f4f8; padding: 15px; border-radius: 10px; margin-top: 10px;">
-                    <b>ℹ️ Instructions:</b>
-                    <ul>
-                        <li>Upload an endoscopic image (JPG, PNG)</li>
-                        <li>Click "Classify Image" to get predictions</li>
-                        <li>View confidence scores for each class</li>
-                    </ul>
-                    </div>
-                    """)
-
-                with gr.Column(scale=1):
-                    output_label_basic = gr.Label(
-                        num_top_classes=4, label="📈 Prediction Results")
-
-        # Tab 2: GradCAM Explanation
-        with gr.Tab("🔥 GradCAM"):
-            with gr.Row():
-                with gr.Column(scale=1):
-                    input_image_gradcam = gr.Image(
-                        type="pil", label="Upload Endoscopic Image")
-                    gradcam_btn = gr.Button(
-                        "🔥 Generate GradCAM", variant="primary", elem_classes="button-primary")
-                    gr.Markdown("""
-                    <div style="background: #fff3e0; padding: 15px; border-radius: 10px; margin-top: 10px;">
-                    <b>🔥 GradCAM (Gradient-weighted Class Activation Mapping):</b>
-                    <p>Highlights the regions of the image that are most important for the model's prediction.
-                    Red areas indicate high importance, while blue areas indicate low importance.</p>
-                    </div>
-                    """)
-
-                with gr.Column(scale=1):
-                    output_gradcam = gr.Image(label="🗺️ GradCAM Heatmap")
-                    output_gradcam_info = gr.Markdown(label="Prediction Info")
-
-        # Tab 3: SHAP Explanation
-        with gr.Tab("🎯 SHAP"):
-            with gr.Row():
-                with gr.Column(scale=1):
-                    input_image_shap = gr.Image(
-                        type="pil", label="Upload Endoscopic Image")
-                    shap_btn = gr.Button(
-                        "🎯 Generate SHAP", variant="primary", elem_classes="button-primary")
-                    gr.Markdown("""
-                    <div style="background: #e8f5e9; padding: 15px; border-radius: 10px; margin-top: 10px;">
-                    <b>🎯 SHAP (SHapley Additive exPlanations):</b>
-                    <p>Shows how each pixel contributes to the prediction. Red pixels push the prediction
-                    towards the predicted class, while blue pixels push it away.</p>
-                    <p><i>⚠️ Note: SHAP generation may take 30-60 seconds.</i></p>
-                    </div>
-                    """)
-
-                with gr.Column(scale=1):
-                    output_shap = gr.Image(label="📊 SHAP Explanation")
-                    output_shap_info = gr.Markdown(label="Prediction Info")
-
-        # Tab 4: LIME Explanation
-        with gr.Tab("🍋 LIME"):
-            with gr.Row():
-                with gr.Column(scale=1):
-                    input_image_lime = gr.Image(
-                        type="pil", label="Upload Endoscopic Image")
-                    lime_btn = gr.Button(
-                        "🍋 Generate LIME", variant="primary", elem_classes="button-primary")
-                    gr.Markdown("""
-                    <div style="background: #fce4ec; padding: 15px; border-radius: 10px; margin-top: 10px;">
-                    <b>🍋 LIME (Local Interpretable Model-agnostic Explanations):</b>
-                    <p>Segments the image into superpixels and shows which segments are most important.
-                    Green boundaries indicate positive contributions to the prediction.</p>
-                    <p><i>⚠️ Note: LIME generation may take 30-60 seconds.</i></p>
-                    </div>
-                    """)
-
-                with gr.Column(scale=1):
-                    gr.Markdown("### 🟢 Positive Features Only")
-                    output_lime_positive = gr.Image(
-                        label="Important Regions (Positive)")
-                    gr.Markdown("### 🔴 Positive & Negative Features")
-                    output_lime_both = gr.Image(
-                        label="All Contributing Regions")
-                    output_lime_info = gr.Markdown(label="Prediction Info")
 
     # Footer
     gr.Markdown("""
@@ -636,15 +678,14 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
     </div>
     """)
 
-    # Connect buttons to functions
-    classify_btn.click(fn=predict, inputs=input_image_basic,
-                       outputs=output_label_basic, api_name="predict")
-    gradcam_btn.click(fn=generate_gradcam, inputs=input_image_gradcam, outputs=[
-        output_gradcam, output_gradcam_info], api_name="gradcam")
-    shap_btn.click(fn=generate_shap, inputs=input_image_shap, outputs=[
-        output_shap, output_shap_info], api_name="shap")
-    lime_btn.click(fn=generate_lime, inputs=input_image_lime, outputs=[
-        output_lime_positive, output_lime_both, output_lime_info], api_name="lime")
 
 # Launch with error reporting enabled
 if __name__ == "__main__":
 
 
 def make_gradcam_heatmap(img_array, model, pred_index=None):
     """
+    Generate Grad-CAM heatmap for lightweight ViT model
+    Uses the transformer output before global pooling
     """
     try:
+        # Find the layer before GlobalAveragePooling (typically the last Add or LayerNormalization)
+        target_layer = None
         for layer in reversed(model.layers):
+            # Look for the last Add layer (from transformer blocks)
+            if isinstance(layer, tf.keras.layers.Add):
+                target_layer = layer
+                break
+            # Or the LayerNormalization before classification head
+            if isinstance(layer, tf.keras.layers.LayerNormalization) and 'representation' not in layer.name:
+                target_layer = layer
                 break
 
+        if target_layer is None:
+            # Fallback: find any layer with 3D output (batch, seq_len, features)
             for layer in reversed(model.layers):
+                if hasattr(layer, 'output_shape') and len(layer.output_shape) == 3:
+                    target_layer = layer
                     break
 
+        if target_layer is None:
+            print("Warning: No suitable layer found for Grad-CAM")
+            return None, pred_index
 
+        # Create a model that outputs both the target layer output and final predictions
         grad_model = tf.keras.models.Model(
+            inputs=model.inputs,
+            outputs=[model.get_layer(target_layer.name).output, model.output]
         )
 
+        # Compute gradients
         with tf.GradientTape() as tape:
+            layer_output, predictions = grad_model(img_array, training=False)
             if pred_index is None:
                 pred_index = tf.argmax(predictions[0])
             class_channel = predictions[:, pred_index]
 
+        # Get gradients of the predicted class with respect to the layer output
+        grads = tape.gradient(class_channel, layer_output)
+
+        if grads is None:
+            print("Warning: Gradients are None. Using simple attention map.")
+            # Fallback: use attention weights
+            layer_output_np = layer_output[0].numpy()
+            heatmap = np.mean(np.abs(layer_output_np), axis=-1)
+            # Reshape to 2D grid
+            num_patches = heatmap.shape[0]
+            grid_size = int(np.sqrt(num_patches))
+            heatmap = heatmap[:grid_size * grid_size].reshape(grid_size, grid_size)
+            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-10)
+            return heatmap, int(pred_index)
+
+        # Global average pooling on gradients
+        if len(grads.shape) == 3:  # (batch, seq_len, features)
+            pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
+            layer_output = layer_output[0]
+
+            # Weight the sequence by the gradients
+            heatmap = layer_output @ pooled_grads[..., tf.newaxis]
+            heatmap = tf.squeeze(heatmap)
+
+            # Reshape to 2D grid
+            num_patches = heatmap.shape[0]
+            grid_size = int(np.sqrt(num_patches))
+            if grid_size * grid_size != num_patches:
+                # sqrt is not exact: exclude the class token if present
+                grid_size = int(np.sqrt(num_patches - 1))
+                heatmap = heatmap[1:grid_size*grid_size+1]  # Skip class token
+            else:
+                heatmap = heatmap[:grid_size*grid_size]
+            heatmap = tf.reshape(heatmap, (grid_size, grid_size))
+        else:
+            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+            layer_output = layer_output[0]
+            heatmap = layer_output @ pooled_grads[..., tf.newaxis]
+            heatmap = tf.squeeze(heatmap)
 
+        # Normalize between 0 and 1
+        heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-10)
+        return heatmap.numpy(), int(pred_index)
 
     except Exception as e:
         print(f"GradCAM error: {e}")
+        import traceback
+        traceback.print_exc()
+        return None, pred_index
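
The token-to-grid reshape is the step that distinguishes this ViT-style Grad-CAM from the convolutional version it replaces. A minimal self-contained sketch of just that step, assuming a toy 197-token map (1 class token plus a 14×14 patch grid), which may not match this app's actual model:

```python
# Sketch: per-token Grad-CAM scores -> 2D heatmap (assumed 197-token ViT).
import numpy as np

token_scores = np.random.rand(197)      # stand-in for the squeezed heatmap
num_patches = token_scores.shape[0]
grid_size = int(np.sqrt(num_patches))   # sqrt(197) ~= 14.04 -> 14
if grid_size * grid_size != num_patches:
    grid_size = int(np.sqrt(num_patches - 1))                  # 14
    patch_scores = token_scores[1:grid_size * grid_size + 1]   # drop class token
else:
    patch_scores = token_scores[:grid_size * grid_size]
heatmap = patch_scores.reshape(grid_size, grid_size)
heatmap = np.maximum(heatmap, 0) / (heatmap.max() + 1e-10)     # normalize to [0, 1]
print(heatmap.shape)  # (14, 14)
```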
 
 
+def apply_gradcam(image, heatmap, alpha=0.4):
     """
     Apply GradCAM heatmap overlay on the original image.
     """
     if heatmap is None:
         return image
 
     # Convert image to numpy array
     if isinstance(image, Image.Image):
         img_array = np.array(image.resize((image_size, image_size)))
     else:
         img_array = image
 
+    # Resize heatmap to match input image size
+    heatmap_resized = cv2.resize(
+        heatmap, (img_array.shape[1], img_array.shape[0]))
+
+    # Convert heatmap to RGB
+    heatmap_uint8 = np.uint8(255 * heatmap_resized)
+    heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
+    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
+
+    # Normalize image if needed
+    if img_array.max() <= 1.0:
+        img_uint8 = (img_array * 255).astype('uint8')
+    else:
+        img_uint8 = img_array.astype('uint8')
+
     # Superimpose the heatmap on original image
+    superimposed_img = heatmap_colored * alpha + img_uint8 * (1 - alpha)
+    superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')
 
     return Image.fromarray(superimposed_img)
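
The manual blend above replaces the old cv2.addWeighted call; once both inputs are same-shape uint8 RGB arrays the two are equivalent, as this small sketch with stand-in arrays illustrates:

```python
# Sketch: alpha blend via cv2.addWeighted, equivalent to the manual version.
import cv2
import numpy as np

img_uint8 = np.zeros((224, 224, 3), dtype=np.uint8)        # stand-in image
heatmap_colored = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in heatmap
alpha = 0.4

blended = cv2.addWeighted(heatmap_colored, alpha, img_uint8, 1 - alpha, 0)
# Same result as clip(heatmap_colored * alpha + img_uint8 * (1 - alpha), 0, 255)
```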
 
 
     Generate GradCAM visualization.
     """
     if model is None or image is None:
+        return None
 
     try:
         # Preprocess image
 
         # Make prediction
         predictions = model.predict(processed_image, verbose=0)
         pred_class = np.argmax(predictions[0])
 
         # Generate heatmap
+        heatmap, _ = make_gradcam_heatmap(processed_image, model, pred_class)
 
         if heatmap is None:
+            return None
 
         # Apply heatmap
+        gradcam_image = apply_gradcam(image, heatmap, alpha=0.4)
 
+        return gradcam_image
 
     except Exception as e:
+        print(f"Error generating GradCAM: {e}")
+        return None
 
 
 # -----------------------------
 
     Generate SHAP explanation visualization.
     """
     if not SHAP_AVAILABLE:
+        return None
 
     if model is None or image is None:
+        return None
 
     try:
         # Preprocess image
 
         shap_image = Image.open(buf)
         plt.close()
 
+        return shap_image
 
     except Exception as e:
         print(f"SHAP error: {e}")
+        return None
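
The middle of generate_shap, where the SHAP values are actually computed and plotted, is unchanged and therefore hidden by this diff. For orientation, a common shap image-explanation flow looks roughly like the sketch below; the masker type, image_size, and class count are assumptions, not taken from app.py:

```python
# Sketch of a typical SHAP image pipeline (may differ from app.py's hidden code).
import numpy as np
import shap

image_size = 72  # assumption; app.py defines its own image_size

def model_predict(batch):
    # Stand-in for the real classifier: returns uniform 4-class probabilities.
    return np.ones((batch.shape[0], 4)) / 4.0

masker = shap.maskers.Image("inpaint_telea", (image_size, image_size, 3))
explainer = shap.Explainer(model_predict, masker)
img = np.random.rand(1, image_size, image_size, 3)
shap_values = explainer(img, max_evals=200, batch_size=32)
shap.image_plot(shap_values, show=False)  # draws onto the current matplotlib figure
```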
 
 
 # -----------------------------
 
     Generate LIME explanation visualization.
     """
     if not LIME_AVAILABLE or not SKIMAGE_AVAILABLE:
+        return None, None
 
     if model is None or image is None:
+        return None, None
 
     try:
         # Preprocess image
 
             batch_size=32
         )
 
         # Create visualizations
         # Positive features only
         temp_positive, mask_positive = explanation.get_image_and_mask(
 
             (lime_positive * 255).astype(np.uint8))
         lime_both_img = Image.fromarray((lime_both * 255).astype(np.uint8))
 
+        return lime_positive_img, lime_both_img
 
     except Exception as e:
         print(f"LIME error: {e}")
+        return None, None
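
Only fragments of the LIME call survive in this hunk (the explain_instance arguments and the get_image_and_mask unpacking). The usual lime_image flow, reconstructed as a sketch with assumed argument values apart from the batch_size=32 visible above:

```python
# Sketch of the standard LIME image pipeline (argument values are assumptions).
import numpy as np
from lime import lime_image
from skimage.segmentation import mark_boundaries

def predict_fn(batch):
    # Stand-in classifier: batch of HxWx3 images -> class probabilities.
    return np.ones((batch.shape[0], 4)) / 4.0

explainer = lime_image.LimeImageExplainer()
img = np.random.rand(72, 72, 3)
explanation = explainer.explain_instance(
    img, predict_fn, top_labels=4, hide_color=0,
    num_samples=1000, batch_size=32)
temp, mask = explanation.get_image_and_mask(
    explanation.top_labels[0], positive_only=True,
    num_features=5, hide_rest=False)
overlay = mark_boundaries(temp, mask)  # floats in [0, 1], like lime_positive above
```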
+
+
+# -----------------------------
+# Unified Prediction with XAI
+# -----------------------------
+
+
+def predict_with_xai(image):
+    """
+    Make prediction and generate all XAI explanations at once.
+    """
+    if model is None or image is None:
+        return {class_name: 0.0 for class_name in class_names}, None, None, None, None
+
+    try:
+        # Make prediction
+        prediction_results = predict(image)
+
+        # Generate GradCAM
+        gradcam_img = generate_gradcam(image)
+
+        # Generate SHAP (can be slow)
+        shap_img = generate_shap(image)
+
+        # Generate LIME (can be slow)
+        lime_positive, lime_both = generate_lime(image)
+
+        return prediction_results, gradcam_img, shap_img, lime_positive, lime_both
+
+    except Exception as e:
+        print(f"Error in predict_with_xai: {e}")
+        return {class_name: 0.0 for class_name in class_names}, None, None, None, None
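
For quick local testing outside the Gradio UI, predict_with_xai can be driven directly from a PIL image; a usage sketch, where the file name is hypothetical:

```python
# Usage sketch for predict_with_xai (hypothetical file path).
from PIL import Image

img = Image.open("sample_endoscopy.jpg").convert("RGB")
scores, gradcam_img, shap_img, lime_pos, lime_both = predict_with_xai(img)
print(scores)  # class-name -> probability mapping, as returned by predict()
if gradcam_img is not None:
    gradcam_img.save("gradcam_overlay.png")
```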
 
 
 # -----------------------------
 
 <div style="text-align: center; padding: 20px;">
     <h2 style="color: #2E86AB;">Advanced Medical Image Analysis with Explainable AI</h2>
     <p style="font-size: 16px; color: #555;">
+        Upload an endoscopic image to classify using a <b>Lightweight Vision Transformer</b> model.
+        Get predictions with <b>three explainability methods</b> to understand the AI's decision.
     </p>
     <div style="display: flex; justify-content: center; gap: 20px; margin-top: 15px;">
         <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white;">
 
 .gradio-container {
     font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
 }
 h1 {
     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
     -webkit-background-clip: text;
     font-size: 2.5em !important;
     text-align: center !important;
 }
 .button-primary {
     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
     border: none !important;
     font-weight: bold !important;
     padding: 12px 30px !important;
     border-radius: 25px !important;
+    font-size: 16px !important;
     transition: all 0.3s ease !important;
 }
 .button-primary:hover {
 }
 """
 
 # Create Gradio interface using Blocks with creative design
 with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
     gr.HTML(f"<h1>{title}</h1>")
     gr.HTML(description)
 
+    with gr.Row():
+        with gr.Column(scale=1):
+            input_image = gr.Image(
+                type="pil", label="📤 Upload Endoscopic Image")
+            predict_btn = gr.Button(
+                "🔍 Classify & Explain", variant="primary", elem_classes="button-primary", size="lg")
+
+            gr.Markdown("""
+            <div style="background: #f0f4f8; padding: 15px; border-radius: 10px; margin-top: 10px;">
+            <b>ℹ️ Instructions:</b>
+            <ul>
+                <li>Upload an endoscopic image (JPG, PNG)</li>
+                <li>Click "Classify & Explain" to get results</li>
+                <li>View prediction + XAI explanations below</li>
+                <li><i>Note: SHAP and LIME may take 30-60 seconds</i></li>
+            </ul>
+            </div>
+            """)
+
+        with gr.Column(scale=1):
+            output_label = gr.Label(
+                num_top_classes=4, label="📊 Prediction Results", show_label=True)
+
+    # Explanations Section
+    gr.Markdown("""
+    <div style="text-align: center; margin-top: 30px; margin-bottom: 20px;">
+        <h2 style="color: #2E86AB;">🎯 Explainable AI Visualizations</h2>
+        <p style="color: #666;">Understanding how the model makes its predictions</p>
+    </div>
+    """)
+
+    with gr.Row():
+        # GradCAM
+        with gr.Column(scale=1):
+            gr.Markdown("""
+            <div style="background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
+            <h3 style="margin: 0; color: #e65100;">🔥 Grad-CAM</h3>
+            <p style="margin: 5px 0 0 0; font-size: 14px;">
+            <b>Gradient-weighted Class Activation Mapping</b><br>
+            Highlights regions most important for prediction. Red = high importance.
+            </p>
+            </div>
+            """)
+            output_gradcam = gr.Image(
+                label="Grad-CAM Heatmap", show_label=False)
+
+    with gr.Row():
+        # SHAP
+        with gr.Column(scale=1):
+            gr.Markdown("""
+            <div style="background: linear-gradient(135deg, #e8f5e9 0%, #c8e6c9 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
+            <h3 style="margin: 0; color: #2e7d32;">🎯 SHAP</h3>
+            <p style="margin: 5px 0 0 0; font-size: 14px;">
+            <b>SHapley Additive exPlanations</b><br>
+            Red pixels push toward predicted class, blue pixels push away.
+            </p>
+            </div>
+            """)
+            output_shap = gr.Image(label="SHAP Explanation", show_label=False)
+
+    with gr.Row():
+        # LIME
+        with gr.Column(scale=1):
+            gr.Markdown("""
+            <div style="background: linear-gradient(135deg, #fce4ec 0%, #f8bbd0 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
+            <h3 style="margin: 0; color: #c2185b;">🍋 LIME - Positive Features</h3>
+            <p style="margin: 5px 0 0 0; font-size: 14px;">
+            <b>Local Interpretable Model-agnostic Explanations</b><br>
+            Green boundaries show regions supporting the prediction.
+            </p>
+            </div>
+            """)
+            output_lime_positive = gr.Image(
+                label="LIME Positive", show_label=False)
+
+        with gr.Column(scale=1):
+            gr.Markdown("""
+            <div style="background: linear-gradient(135deg, #e1f5fe 0%, #b3e5fc 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
+            <h3 style="margin: 0; color: #01579b;">🍋 LIME - All Features</h3>
+            <p style="margin: 5px 0 0 0; font-size: 14px;">
+            <b>Positive & Negative Contributions</b><br>
+            Shows both supporting and opposing regions.
+            </p>
+            </div>
+            """)
+            output_lime_both = gr.Image(
+                label="LIME Positive & Negative", show_label=False)
 
 
 
 
     # Footer
     gr.Markdown("""
 
     </div>
     """)
 
+    # Connect button to unified function
+    predict_btn.click(
+        fn=predict_with_xai,
+        inputs=input_image,
+        outputs=[output_label, output_gradcam, output_shap,
+                 output_lime_positive, output_lime_both],
+        api_name="predict"
+    )
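
Since the handler is exposed with api_name="predict", the Space should also be callable programmatically; a sketch with gradio_client, where the Space id and file path are placeholders:

```python
# Sketch: calling the /predict endpoint remotely (Space id is a placeholder).
from gradio_client import Client, handle_file

client = Client("user/space-name")  # placeholder Space id
result = client.predict(
    handle_file("sample_endoscopy.jpg"),  # hypothetical local image
    api_name="/predict",
)
# Expected: (label data, Grad-CAM image, SHAP image, LIME positive, LIME both)
print(result[0])
```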
 
 
 # Launch with error reporting enabled
 if __name__ == "__main__":