File size: 24,023 Bytes
16f5a6f
 
 
 
ad8b011
 
 
 
 
 
8f56225
 
 
5073ec7
ad8b011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f56225
 
 
 
8736948
 
 
1e86112
 
8736948
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e86112
 
8736948
 
 
 
 
 
 
 
 
 
 
8f56225
 
 
 
 
 
 
 
 
8736948
8f56225
 
 
8736948
8f56225
 
 
8736948
 
8f56225
8736948
8f56225
 
8736948
 
 
 
 
 
 
8f56225
 
 
 
8736948
 
8f56225
 
 
 
d71a6b0
8f56225
d71a6b0
 
8f56225
d71a6b0
8f56225
 
 
 
 
 
 
 
 
8736948
8f56225
 
 
8736948
8f56225
8736948
 
8f56225
 
 
 
8736948
 
8f56225
 
 
8736948
3f7fc3e
 
 
 
 
 
8f56225
8736948
8f56225
8736948
8f56225
 
8736948
8f56225
 
8736948
 
3f7fc3e
 
 
 
 
 
 
 
 
8736948
8f56225
 
 
3f7fc3e
 
 
8f56225
 
16f5a6f
 
 
 
 
 
 
c2b4eaa
 
16f5a6f
 
c2b4eaa
 
16f5a6f
c2b4eaa
 
 
 
 
 
 
16f5a6f
 
c2b4eaa
 
16f5a6f
c2b4eaa
 
16f5a6f
 
c2b4eaa
 
 
16f5a6f
c2b4eaa
16f5a6f
c2b4eaa
 
16f5a6f
 
c2b4eaa
16f5a6f
c2b4eaa
16f5a6f
 
 
 
c2b4eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16f5a6f
c2b4eaa
 
 
 
16f5a6f
 
 
c2b4eaa
 
 
16f5a6f
 
c2b4eaa
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
c2b4eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16f5a6f
c2b4eaa
 
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
c2b4eaa
16f5a6f
 
 
 
 
 
 
 
 
 
c2b4eaa
16f5a6f
 
c2b4eaa
16f5a6f
 
c2b4eaa
16f5a6f
c2b4eaa
16f5a6f
 
c2b4eaa
 
16f5a6f
 
 
 
 
 
 
 
 
 
 
ad8b011
c2b4eaa
ad8b011
16f5a6f
c2b4eaa
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2b4eaa
16f5a6f
 
 
c2b4eaa
16f5a6f
 
 
 
 
 
 
 
 
 
 
ad8b011
c2b4eaa
ad8b011
16f5a6f
c2b4eaa
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2b4eaa
16f5a6f
 
 
c2b4eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16f5a6f
 
8f56225
 
 
16f5a6f
8f56225
16f5a6f
 
 
c2b4eaa
 
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
8736948
 
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2b4eaa
16f5a6f
 
 
 
 
 
 
d52543f
16f5a6f
 
 
 
 
c2b4eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16f5a6f
 
 
 
 
 
 
 
 
 
 
 
 
c2b4eaa
 
 
 
 
 
 
 
d52543f
8f56225
8736948
8f56225
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
import io
import os

# Intended to disable the Gradio queue for direct REST API access.
# These must be set BEFORE gradio is imported, so they are assigned ahead of
# the third-party imports below.
# NOTE(review): confirm the installed gradio version honours these names.
os.environ["GRADIO_QUEUE"] = "false"
os.environ["HF_HUB_DISABLE_GRADIO_QUEUE"] = "1"

import cv2
import matplotlib

# Select the non-interactive backend before pyplot is imported (headless server).
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as L, models
from PIL import Image

# Optional XAI libraries: each *_AVAILABLE flag records whether the import
# succeeded so the corresponding explanation can be skipped gracefully.
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print("Warning: SHAP not available")

try:
    from lime import lime_image
    LIME_AVAILABLE = True
except ImportError:
    LIME_AVAILABLE = False
    print("Warning: LIME not available")

try:
    from skimage.segmentation import mark_boundaries
    SKIMAGE_AVAILABLE = True
except ImportError:
    SKIMAGE_AVAILABLE = False
    print("Warning: scikit-image not available")


# -----------------------------
# Model Architecture Components
# -----------------------------


class Patches(L.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: side length (in pixels) of each square patch; also used
            as the stride, so patches do not overlap.

    Input is a 4-D image batch (batch, height, width, channels); output is
    (batch, num_patches, patch_size * patch_size * channels).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The last dimension holds one flattened patch and is known statically.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the layer config including ``patch_size`` so it can be re-created."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


class PatchEncoder(L.Layer):
    """Linearly project patches and add a learned position embedding.

    Args:
        num_patches: number of patches per image; also the size of the
            position-embedding table.
        projection_dim: width of the projected patch embeddings.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Stored so get_config can serialize it.
        self.projection_dim = projection_dim
        self.projection = L.Dense(units=projection_dim)
        self.position_embedding = L.Embedding(
            input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        # One position id per patch; broadcast-added to every item in the batch.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the layer config including its constructor arguments."""
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config


# -----------------------------
# Model Configuration
# -----------------------------
# ViT hyper-parameters. Only image_size is read elsewhere in this file; the
# others document the architecture.
# NOTE(review): confirm they match the checkpoint that is loaded below.
image_size = 224           # inputs are resized to image_size x image_size
patch_size = 8             # side length of each square patch, in pixels
projection_dim = 64        # patch-embedding width
transformer_layers = 4
num_heads = 4
mlp_head_units = [128, 64]

# Label names in the order of the classifier's output indices
# (predict() maps probability i to class_names[i]). Update for your dataset.
class_names = [
    'GERD',
    'GERD NORMAL',
    'POLYP',
    'POLYP_NORMAL',
]

# -----------------------------
# Load Model
# -----------------------------
# The custom layers are registered through custom_objects so Keras can
# rebuild them from the saved HDF5 config. If loading fails the app still
# starts with model = None, and the prediction/XAI functions return empty
# results instead of raising.
try:
    model = tf.keras.models.load_model(
        'best_fold_model.h5',
        custom_objects={
            'Patches': Patches,
            'PatchEncoder': PatchEncoder,
        },
    )
    print("✓ Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

# -----------------------------
# Preprocessing Function
# -----------------------------


def preprocess_image(image):
    """Convert an input image into a batch ready for ``model.predict``.

    Args:
        image: a file path, a ``PIL.Image.Image``, or an array-like image
            (uint8 in 0-255, or float in either [0, 1] or 0-255).

    Returns:
        float32 NumPy array of shape (1, image_size, image_size, 3) with
        values scaled to [0, 1].
    """
    if isinstance(image, str):
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixels to load so the handle can be closed right away.
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    elif not isinstance(image, Image.Image):
        array = np.asarray(image)
        # Image.fromarray cannot build an RGB image from float arrays, so
        # bring floats into the 0-255 uint8 range first.
        if np.issubdtype(array.dtype, np.floating):
            if array.size and array.max() <= 1.0:
                array = array * 255.0
            array = np.clip(array, 0, 255).astype(np.uint8)
        image = Image.fromarray(array)

    # Grayscale / RGBA / palette inputs become 3-channel RGB.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    image = image.resize((image_size, image_size))

    img_array = np.asarray(image, dtype=np.float32) / 255.0  # scale to [0, 1]

    # Add the batch dimension expected by the model.
    return np.expand_dims(img_array, axis=0)

# -----------------------------
# Prediction Function
# -----------------------------


def predict(image):
    """Classify an image and return a confidence per class.

    Args:
        image: any input accepted by ``preprocess_image`` (typically the PIL
            image supplied by the Gradio upload component), or None.

    Returns:
        dict mapping every name in ``class_names`` to a Python float. Every
        value is 0.0 when the model is not loaded, no image is given, or
        prediction raises; individual NaN/inf probabilities become 0.0.
    """
    zero_scores = {class_name: 0.0 for class_name in class_names}

    if model is None or image is None:
        return zero_scores

    try:
        processed_image = preprocess_image(image)
        probabilities = model.predict(processed_image, verbose=0)[0]

        results = {}
        for i, class_name in enumerate(class_names):
            prob = float(probabilities[i])
            # Check after converting to float: NumPy scalars such as float32
            # are not instances of Python float, so an isinstance() guard
            # would let NaN/inf through to the UI.
            results[class_name] = prob if np.isfinite(prob) else 0.0
        return results

    except Exception as e:
        print(f"Prediction error: {e}")
        return zero_scores


# -----------------------------
# GradCAM Implementation
# -----------------------------


def _find_gradcam_target_layer(model):
    """Return the layer whose activations Grad-CAM should explain, or None."""
    # Walk backwards from the head: take the first Add layer (typically the
    # residual output of the last transformer block) or a LayerNormalization
    # whose name does not contain 'representation'.
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Add):
            return layer
        if (isinstance(layer, tf.keras.layers.LayerNormalization)
                and 'representation' not in layer.name):
            return layer
    # Fallback: the last layer reporting a 3-D (batch, seq_len, features) output.
    for layer in reversed(model.layers):
        output_shape = getattr(layer, 'output_shape', None)
        if output_shape is not None and len(output_shape) == 3:
            return layer
    return None


def _tokens_to_grid(token_scores):
    """Reshape a 1-D per-token score vector into a square 2-D tensor.

    If the token count is not a perfect square, the first token is treated
    as a class token and dropped before reshaping.
    """
    num_tokens = int(token_scores.shape[0])
    grid_size = int(np.sqrt(num_tokens))
    if grid_size * grid_size != num_tokens:
        grid_size = int(np.sqrt(num_tokens - 1))
        token_scores = token_scores[1:grid_size * grid_size + 1]
    else:
        token_scores = token_scores[:grid_size * grid_size]
    return tf.reshape(token_scores, (grid_size, grid_size))


def make_gradcam_heatmap(img_array, model, pred_index=None):
    """Generate a Grad-CAM heatmap for the lightweight ViT model.

    Uses the transformer output before global pooling as the target layer.

    Args:
        img_array: preprocessed input batch (e.g. the output of
            ``preprocess_image``).
        model: the Keras classifier.
        pred_index: class index to explain. May be a Python int, a NumPy
            integer, or None to use the model's top prediction.

    Returns:
        ``(heatmap, class_index)``. ``heatmap`` is a 2-D float array scaled
        to [0, 1] and ``class_index`` is the explained class as a Python int.
        On failure ``(None, pred_index)`` is returned.
    """
    try:
        target_layer = _find_gradcam_target_layer(model)
        if target_layer is None:
            print("Warning: No suitable layer found for Grad-CAM")
            return None, pred_index

        # Model that exposes both the target activations and the predictions.
        grad_model = tf.keras.models.Model(
            inputs=model.inputs,
            outputs=[target_layer.output, model.output]
        )

        with tf.GradientTape() as tape:
            layer_output, predictions = grad_model(img_array, training=False)
            if pred_index is None:
                pred_index = tf.argmax(predictions[0])
            class_channel = predictions[:, pred_index]

        grads = tape.gradient(class_channel, layer_output)

        # int() accepts Python ints, NumPy integers and eager TF scalars alike;
        # calling .numpy() only works on the TF tensor case.
        class_index = int(pred_index)

        if grads is None:
            print("Warning: Gradients are None. Using simple attention map.")
            # Fallback: mean absolute activation per position, min-max scaled.
            activation = np.mean(np.abs(layer_output[0].numpy()), axis=-1)
            if activation.ndim == 1:
                activation = _tokens_to_grid(activation).numpy()
            activation = (activation - activation.min()) / \
                (activation.max() - activation.min() + 1e-10)
            return activation, class_index

        if len(grads.shape) == 3:  # (batch, seq_len, features)
            # Channel weights = gradients averaged over batch and tokens.
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1))
            token_scores = tf.squeeze(
                layer_output[0] @ pooled_grads[..., tf.newaxis], axis=-1)
            heatmap = _tokens_to_grid(token_scores)
        else:  # spatial feature map (batch, H, W, channels)
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            heatmap = tf.squeeze(
                layer_output[0] @ pooled_grads[..., tf.newaxis], axis=-1)

        # Keep positive evidence only, then scale by the clipped maximum so
        # the result is in [0, 1] even when every score is <= 0.
        heatmap = tf.maximum(heatmap, 0)
        heatmap = heatmap / (tf.math.reduce_max(heatmap) + 1e-10)
        return heatmap.numpy(), class_index

    except Exception as e:
        print(f"GradCAM error: {e}")
        import traceback
        traceback.print_exc()
        return None, pred_index


def apply_gradcam(image, heatmap, alpha=0.4):
    """Overlay a Grad-CAM heatmap (JET colormap) on an image.

    Args:
        image: a PIL image (converted to RGB and resized to image_size x
            image_size) or a NumPy image used at its own size.
        heatmap: 2-D array of scores, expected in [0, 1]; resized to the image.
        alpha: weight of the heatmap in the blend; the image gets 1 - alpha.

    Returns:
        PIL image of the blend, or ``image`` unchanged when ``heatmap`` is
        None or blending fails.
    """
    try:
        if heatmap is None:
            return image

        if isinstance(image, Image.Image):
            # Force 3 channels so RGBA / grayscale uploads can be blended.
            img_array = np.array(
                image.convert('RGB').resize((image_size, image_size)))
        else:
            img_array = np.asarray(image)
            if img_array.ndim == 2:
                img_array = np.stack([img_array] * 3, axis=-1)
            elif img_array.shape[-1] == 4:
                img_array = img_array[..., :3]

        # cv2.resize takes (width, height); it needs float32/float64 input.
        heatmap_resized = cv2.resize(
            np.asarray(heatmap, dtype=np.float32),
            (img_array.shape[1], img_array.shape[0]))
        # Guard the uint8 cast below against NaN or out-of-range scores.
        heatmap_resized = np.clip(np.nan_to_num(heatmap_resized), 0.0, 1.0)

        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        # Only float images are treated as [0, 1]; integer images are already
        # 0-255 (a dark uint8 image must not be multiplied by 255).
        if np.issubdtype(img_array.dtype, np.floating) and img_array.max() <= 1.0:
            img_uint8 = (img_array * 255).astype('uint8')
        else:
            img_uint8 = np.clip(img_array, 0, 255).astype('uint8')

        superimposed_img = heatmap_colored * alpha + img_uint8 * (1 - alpha)
        superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')

        return Image.fromarray(superimposed_img)

    except Exception as e:
        print(f"Apply GradCAM error: {e}")
        return image


def generate_gradcam(image):
    """Generate a Grad-CAM overlay for the model's top predicted class.

    Args:
        image: any input accepted by ``preprocess_image``, or None.

    Returns:
        The overlay image produced by ``apply_gradcam``, or None when the
        model/image is missing, no heatmap could be computed, or an error
        occurs.
    """
    if model is None or image is None:
        return None

    try:
        processed_image = preprocess_image(image)

        # pred_index is left as None so make_gradcam_heatmap explains the
        # argmax of its own forward pass. This avoids a second, redundant
        # model.predict call, and avoids handing it a NumPy integer index.
        heatmap, _ = make_gradcam_heatmap(processed_image, model)

        if heatmap is None:
            return None

        return apply_gradcam(image, heatmap, alpha=0.4)

    except Exception as e:
        print(f"Error generating GradCAM: {e}")
        return None


# -----------------------------
# SHAP Implementation
# -----------------------------


def generate_shap(image):
    """Render a SHAP explanation for the model's top predicted class.

    Args:
        image: a PIL image (converted to RGB and resized to image_size x
            image_size) or a NumPy image used as-is.

    Returns:
        PIL image of the ``shap.image_plot`` figure, or None when SHAP is not
        installed, the model/image is missing, or the explanation fails.
    """
    if not SHAP_AVAILABLE:
        return None

    if model is None or image is None:
        return None

    try:
        if isinstance(image, Image.Image):
            # Force 3 channels so RGBA / grayscale uploads match the model input.
            img_array = np.array(
                image.convert('RGB').resize((image_size, image_size)))
        else:
            img_array = image

        # Work in the 0-255 uint8 range (model_predict divides by 255).
        # Clip before casting so out-of-range values cannot wrap around.
        if img_array.dtype != np.uint8:
            scaled = img_array * 255 if img_array.max() <= 1 else img_array
            img_array = np.clip(scaled, 0, 255).astype(np.uint8)

        def model_predict(x):
            # Same [0, 1] float32 scaling as preprocess_image.
            batch = np.asarray(x, dtype=np.float32) / 255.0
            return model(tf.convert_to_tensor(batch)).numpy()

        masker = shap.maskers.Image("inpaint_telea", img_array.shape)
        explainer = shap.Explainer(
            model_predict, masker, output_names=class_names)

        # Explain only the highest-scoring output.
        shap_values = explainer(
            img_array[np.newaxis, ...], outputs=shap.Explanation.argsort.flip[:1])

        # Close every figure opened while plotting (including any that
        # shap.image_plot creates itself), even on error; a bare plt.close()
        # only closes the current figure and leaks the rest on each request.
        existing_figures = set(plt.get_fignums())
        try:
            plt.figure(figsize=(10, 8))
            shap.image_plot(shap_values, img_array[np.newaxis, ...], show=False)
            buf = io.BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        finally:
            for fig_num in set(plt.get_fignums()) - existing_figures:
                plt.close(fig_num)

        buf.seek(0)
        return Image.open(buf)

    except Exception as e:
        print(f"SHAP error: {e}")
        return None


# -----------------------------
# LIME Implementation
# -----------------------------


def generate_lime(image):
    """Generate LIME explanations for the model's top predicted label.

    Args:
        image: a PIL image (converted to RGB and resized to image_size x
            image_size) or a NumPy image used as-is.

    Returns:
        ``(positive_only, positive_and_negative)`` PIL images with superpixel
        boundaries drawn, or ``(None, None)`` when LIME / scikit-image are not
        installed, the model/image is missing, or the explanation fails.
    """
    if not LIME_AVAILABLE or not SKIMAGE_AVAILABLE:
        return None, None

    if model is None or image is None:
        return None, None

    try:
        if isinstance(image, Image.Image):
            # Force 3 channels so RGBA / grayscale uploads match the model input.
            img_array = np.array(
                image.convert('RGB').resize((image_size, image_size)))
        else:
            img_array = image

        # Scale to [0, 1] to match preprocess_image.
        img_normalized = img_array / 255.0 if img_array.max() > 1 else img_array

        def classifier_fn(batch):
            # LIME calls this once per batch of perturbed images; verbose=0
            # keeps Keras from printing a progress bar for every batch.
            return model.predict(batch, verbose=0)

        explainer = lime_image.LimeImageExplainer()
        explanation = explainer.explain_instance(
            img_normalized.astype('float64'),
            classifier_fn,
            top_labels=3,
            hide_color=0,
            num_samples=1000,
            batch_size=32
        )

        def render(positive_only):
            # Draw boundaries of the 10 strongest superpixels for the top label.
            temp, mask = explanation.get_image_and_mask(
                explanation.top_labels[0],
                positive_only=positive_only,
                num_features=10,
                hide_rest=False
            )
            marked = mark_boundaries(temp, mask)
            return Image.fromarray(
                np.clip(marked * 255, 0, 255).astype(np.uint8))

        lime_positive_img = render(positive_only=True)
        lime_both_img = render(positive_only=False)

        return lime_positive_img, lime_both_img

    except Exception as e:
        print(f"LIME error: {e}")
        return None, None


# -----------------------------
# Unified Prediction with XAI
# -----------------------------


def predict_with_xai(image):
    """Run the classifier and every explanation method on one image.

    Returns a 5-tuple ``(scores, gradcam, shap, lime_positive, lime_both)``:
    ``scores`` maps each class name to a confidence and the rest are whatever
    the corresponding generator returns (None when unavailable). If the model
    or image is missing, or anything raises, all scores are 0.0 and every
    image slot is None.
    """
    def empty_result():
        return {name: 0.0 for name in class_names}, None, None, None, None

    if image is None or model is None:
        return empty_result()

    try:
        scores = predict(image)
        heatmap_img = generate_gradcam(image)
        shap_plot = generate_shap(image)            # can be slow
        lime_pos, lime_all = generate_lime(image)   # can be slow
    except Exception as e:
        print(f"Error in predict_with_xai: {e}")
        return empty_result()

    return scores, heatmap_img, shap_plot, lime_pos, lime_all


# -----------------------------
# Gradio Interface
# -----------------------------
# Page heading and HTML intro. The emoji and the × sign were mis-encoded
# (e.g. "πŸ”¬", "Γ—"), which rendered as garbage in the UI.
title = "🔬 GERD Lightweight Vision Transformer with XAI"
description = """

<div style="text-align: center; padding: 20px;">

    <h2 style="color: #2E86AB;">Advanced Medical Image Analysis with Explainable AI</h2>

    <p style="font-size: 16px; color: #555;">

        Upload an endoscopic image to classify using a <b>Lightweight Vision Transformer</b> model.

        Get predictions with <b>three explainability methods</b> to understand the AI's decision.

    </p>

    <div style="display: flex; justify-content: center; gap: 20px; margin-top: 15px;">

        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white;">

            <b>📊 Model Architecture</b><br>

            Image: 224×224 | Patches: 8×8<br>

            Projection: 64 | Layers: 4 | Heads: 4

        </div>

        <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 15px; border-radius: 10px; color: white;">

            <b>🎯 XAI Methods</b><br>

            GradCAM | SHAP | LIME<br>

            Visual Explanations

        </div>

    </div>

</div>

"""

# Custom stylesheet passed to gr.Blocks(css=...). The `.button-primary`
# rules target the "Classify & Explain" button (created with
# elem_classes="button-primary"); the `h1` rule styles the title that is
# rendered via gr.HTML("<h1>...</h1>").
custom_css = """

.gradio-container {

    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;

}

h1 {

    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);

    -webkit-background-clip: text;

    -webkit-text-fill-color: transparent;

    font-size: 2.5em !important;

    text-align: center !important;

}

.button-primary {

    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;

    border: none !important;

    color: white !important;

    font-weight: bold !important;

    padding: 12px 30px !important;

    border-radius: 25px !important;

    font-size: 16px !important;

    transition: all 0.3s ease !important;

}

.button-primary:hover {

    transform: scale(1.05) !important;

    box-shadow: 0 8px 15px rgba(102, 126, 234, 0.4) !important;

}

"""

# Create Gradio interface using Blocks with creative design.
# Emoji in labels were mis-encoded (e.g. "πŸ“Š") and are restored here.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML(f"<h1>{title}</h1>")
    gr.HTML(description)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil", label="📤 Upload Endoscopic Image")
            predict_btn = gr.Button(
                "🔍 Classify & Explain", variant="primary", elem_classes="button-primary", size="lg")

            gr.Markdown("""

            <div style="background: #f0f4f8; padding: 15px; border-radius: 10px; margin-top: 10px;">

                <b>ℹ️ Instructions:</b>

                <ul>

                    <li>Upload an endoscopic image (JPG, PNG)</li>

                    <li>Click "Classify & Explain" to get results</li>

                    <li>View prediction + XAI explanations below</li>

                    <li><i>Note: SHAP and LIME may take 30-60 seconds</i></li>

                </ul>

            </div>

            """)

        with gr.Column(scale=1):
            # Show every class the model can predict, derived from class_names
            # so the component stays in sync if the label list changes.
            output_label = gr.Label(
                num_top_classes=len(class_names), label="📊 Prediction Results", show_label=True)

    # Explanations Section
    gr.Markdown("""

    <div style="text-align: center; margin-top: 30px; margin-bottom: 20px;">

        <h2 style="color: #2E86AB;">🎯 Explainable AI Visualizations</h2>

        <p style="color: #666;">Understanding how the model makes its predictions</p>

    </div>

    """)

    with gr.Row():
        # GradCAM
        with gr.Column(scale=1):
            gr.Markdown("""

            <div style="background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">

                <h3 style="margin: 0; color: #e65100;">🔥 Grad-CAM</h3>

                <p style="margin: 5px 0 0 0; font-size: 14px;">

                    <b>Gradient-weighted Class Activation Mapping</b><br>

                    Highlights regions most important for prediction. Red = high importance.

                </p>

            </div>

            """)
            output_gradcam = gr.Image(
                label="Grad-CAM Heatmap", show_label=False)

    with gr.Row():
        # SHAP
        with gr.Column(scale=1):
            gr.Markdown("""

            <div style="background: linear-gradient(135deg, #e8f5e9 0%, #c8e6c9 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">

                <h3 style="margin: 0; color: #2e7d32;">🎯 SHAP</h3>

                <p style="margin: 5px 0 0 0; font-size: 14px;">

                    <b>SHapley Additive exPlanations</b><br>

                    Red pixels push toward predicted class, blue pixels push away.

                </p>

            </div>

            """)
            output_shap = gr.Image(label="SHAP Explanation", show_label=False)

    with gr.Row():
        # LIME
        with gr.Column(scale=1):
            gr.Markdown("""

            <div style="background: linear-gradient(135deg, #fce4ec 0%, #f8bbd0 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">

                <h3 style="margin: 0; color: #c2185b;">🍋 LIME - Positive Features</h3>

                <p style="margin: 5px 0 0 0; font-size: 14px;">

                    <b>Local Interpretable Model-agnostic Explanations</b><br>

                    Green boundaries show regions supporting the prediction.

                </p>

            </div>

            """)
            output_lime_positive = gr.Image(
                label="LIME Positive", show_label=False)

        with gr.Column(scale=1):
            gr.Markdown("""

            <div style="background: linear-gradient(135deg, #e1f5fe 0%, #b3e5fc 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">

                <h3 style="margin: 0; color: #01579b;">🍋 LIME - All Features</h3>

                <p style="margin: 5px 0 0 0; font-size: 14px;">

                    <b>Positive & Negative Contributions</b><br>

                    Shows both supporting and opposing regions.

                </p>

            </div>

            """)
            output_lime_both = gr.Image(
                label="LIME Positive & Negative", show_label=False)

    # Footer. The class list is built from class_names (underscores shown as
    # spaces) so it cannot drift from the model's labels.
    footer_classes = ", ".join(name.replace("_", " ") for name in class_names)
    gr.Markdown(f"""

    <div style="text-align: center; margin-top: 30px; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;">

        <h3>🏥 Medical AI with Transparency</h3>

        <p>This tool combines state-of-the-art Vision Transformer technology with explainable AI methods 

        to provide transparent and interpretable medical image analysis.</p>

        <p style="font-size: 12px; margin-top: 10px;">

            <b>Classes:</b> {footer_classes}

        </p>

    </div>

    """)

    # One click runs the prediction plus all XAI methods; also exposed as the
    # "predict" API endpoint.
    predict_btn.click(
        fn=predict_with_xai,
        inputs=input_image,
        outputs=[output_label, output_gradcam, output_shap,
                 output_lime_positive, output_lime_both],
        api_name="predict"
    )

# Launch the app when run as a script. show_error=True asks Gradio to report
# handler exceptions in the browser UI as well as the server log
# (NOTE(review): exact display behaviour depends on the Gradio version).
if __name__ == "__main__":
    demo.launch(show_error=True)