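# NOTE: "Spaces: Sleeping" - Hugging Face Spaces page status text captured
# together with this file; it is not part of the program.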
import io
import os
import traceback

# Disable Gradio queue for direct REST API access - MUST be set before the
# gradio import below, so these assignments come ahead of it.
os.environ["GRADIO_QUEUE"] = "false"
os.environ["HF_HUB_DISABLE_GRADIO_QUEUE"] = "1"

import cv2
import gradio as gr
import matplotlib

matplotlib.use('Agg')  # Use non-interactive backend (selected before pyplot is imported)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras import layers as L, models

# Import XAI libraries with error handling
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print("Warning: SHAP not available")
try:
    from lime import lime_image
    LIME_AVAILABLE = True
except ImportError:
    LIME_AVAILABLE = False
    print("Warning: LIME not available")
try:
    from skimage.segmentation import mark_boundaries
    SKIMAGE_AVAILABLE = True
except ImportError:
    SKIMAGE_AVAILABLE = False
    print("Warning: scikit-image not available")
# -----------------------------
# Model Architecture Components
# -----------------------------
class Patches(L.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length, in pixels, of each square patch. It is used
            as both the window size and the stride, so patches do not overlap.
        **kwargs: Forwarded to the base Keras ``Layer`` constructor.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the patches of ``images`` as ``(batch, num_patches, patch_dims)``.

        ``padding="VALID"`` means partial patches at the right/bottom edges
        are dropped rather than padded.
        """
        # Dynamic batch size: the static one may be None while tracing.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The last axis already holds one flattened patch; collapse the
        # spatial grid of patches into a single sequence axis.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-serialized."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config
class PatchEncoder(L.Layer):
    """Linearly project patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the size of the
            position-embedding table.
        projection_dim: Output width of both the dense projection and the
            position embedding.
        **kwargs: Forwarded to the base Keras ``Layer`` constructor.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept only so get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = L.Dense(units=projection_dim)
        self.position_embedding = L.Embedding(
            input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        """Return ``projection(patch)`` plus the embedding of positions 0..num_patches-1."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The position embedding has no batch axis; the addition broadcasts it.
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-serialized."""
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config
# -----------------------------
# Model Configuration
# -----------------------------
# Side length (pixels) that images are resized to before being fed to the model
# and before the XAI visualizations are computed.
image_size = 224
# NOTE(review): the five values below are not referenced by the code visible in
# this file (the network itself comes from the saved .h5 file); presumably they
# record the hyper-parameters the model was trained with - confirm.
patch_size = 8
projection_dim = 64
transformer_layers = 4
num_heads = 4
mlp_head_units = [128, 64]
# Class names (update based on your dataset)
# Order matters: predict() pairs class_names[i] with the i-th model output.
class_names = ['GERD', 'GERD NORMAL', 'POLYP',
               'POLYP_NORMAL']  # Update with actual class names
# -----------------------------
# Load Model
# -----------------------------
# Custom layer classes Keras needs in order to rebuild the saved network.
_CUSTOM_LAYERS = {
    'Patches': Patches,
    'PatchEncoder': PatchEncoder,
}

# Load once at import time. On any failure `model` is left as None, which the
# prediction and explanation functions treat as "model unavailable".
try:
    model = tf.keras.models.load_model(
        'best_fold_model.h5', custom_objects=_CUSTOM_LAYERS)
    print("β Model loaded successfully")
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = None
# -----------------------------
# Preprocessing Function
# -----------------------------
def preprocess_image(image):
    """
    Turn an input image into a normalized, batched array for the model.

    Accepts a string (handed to ``Image.open``), a PIL image, or anything
    ``Image.fromarray`` accepts. The image is converted to RGB, resized to
    ``image_size`` x ``image_size``, scaled to [0, 1] as float32 and given a
    leading batch axis of length 1.
    """
    # Normalize the input type to a PIL image first.
    if isinstance(image, str):
        pil_image = Image.open(image)
    elif isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.fromarray(image)

    # The model expects three colour channels.
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    resized = pil_image.resize((image_size, image_size))

    # uint8 pixel values -> float32 in [0, 1].
    pixels = np.array(resized, dtype=np.float32) / 255.0

    # Prepend the batch axis.
    return pixels[np.newaxis, ...]
# -----------------------------
# Prediction Function
# -----------------------------
def predict(image):
    """
    Classify ``image`` and return a ``{class name: probability}`` mapping.

    Every class in ``class_names`` is always present in the result. A
    confidence of 0.0 is reported for all classes when the model is not
    loaded, when no image is given, or when prediction raises; individual
    non-finite outputs (NaN / Inf) are also replaced by 0.0 so the result is
    always safe to serialize.
    """
    if model is None or image is None:
        # Nothing to run: report zero confidence for every class.
        return {class_name: 0.0 for class_name in class_names}
    try:
        processed_image = preprocess_image(image)
        predictions = model.predict(processed_image, verbose=0)
        # First (only) sample of the batch, as plain float64 values.
        probabilities = np.asarray(predictions[0], dtype=np.float64)
        if probabilities.shape[0] < len(class_names):
            raise ValueError(
                f"model returned {probabilities.shape[0]} outputs for "
                f"{len(class_names)} classes")
        # np.isfinite works for every NumPy float width; an
        # isinstance(prob, float) test does not match np.float32 values,
        # so NaN/Inf would slip through it.
        return {
            class_name: float(prob) if np.isfinite(prob) else 0.0
            for class_name, prob in zip(class_names, probabilities)
        }
    except Exception as e:
        print(f"Prediction error: {e}")
        # Return zero confidence for all classes on error
        return {class_name: 0.0 for class_name in class_names}
# -----------------------------
# GradCAM Implementation
# -----------------------------
def _tokens_to_grid(token_scores):
    """Fold per-token scores into a square 2-D grid.

    If the number of tokens is not a perfect square, the first entry is
    treated as a class token and skipped; entries that still do not fit the
    square are dropped from the end. Returns None when no grid can be formed.
    """
    flat = np.reshape(token_scores, (-1,))
    num_tokens = flat.shape[0]
    grid_size = int(np.sqrt(num_tokens))
    offset = 0
    if grid_size * grid_size != num_tokens:
        # Exclude class token if present
        grid_size = int(np.sqrt(num_tokens - 1))
        offset = 1
    if grid_size == 0:
        return None
    return flat[offset:offset + grid_size * grid_size].reshape(grid_size, grid_size)


def make_gradcam_heatmap(img_array, model, pred_index=None):
    """
    Generate a Grad-CAM heatmap for the lightweight ViT model.

    The target layer is the last ``Add`` layer, or the last
    ``LayerNormalization`` whose name does not contain 'representation',
    whichever comes first when walking the layers backwards; failing that,
    the last layer that reports a rank-3 output shape.

    Args:
        img_array: Preprocessed input batch accepted by ``model``.
        model: Keras model to explain.
        pred_index: Class index to explain. May be a Python int, a NumPy
            integer or a scalar tensor; when None the arg-max class of the
            forward pass is used.

    Returns:
        ``(heatmap, class_index)`` where ``heatmap`` is a 2-D NumPy array in
        [0, 1] and ``class_index`` is a Python int, or ``(None, pred_index)``
        when no heatmap could be produced.
    """
    try:
        target_layer = None
        for layer in reversed(model.layers):
            # Look for the last Add layer (from transformer blocks)
            if isinstance(layer, tf.keras.layers.Add):
                target_layer = layer
                break
            # Or the LayerNormalization before classification head
            if (isinstance(layer, tf.keras.layers.LayerNormalization)
                    and 'representation' not in layer.name):
                target_layer = layer
                break
        if target_layer is None:
            # Fallback: find any layer with 3D output (batch, seq_len, features)
            for layer in reversed(model.layers):
                if hasattr(layer, 'output_shape') and len(layer.output_shape) == 3:
                    target_layer = layer
                    break
        if target_layer is None:
            print("Warning: No suitable layer found for Grad-CAM")
            return None, pred_index

        # Model that exposes both the target activations and the predictions.
        grad_model = tf.keras.models.Model(
            inputs=model.inputs,
            outputs=[target_layer.output, model.output]
        )

        with tf.GradientTape() as tape:
            layer_output, predictions = grad_model(img_array, training=False)
            if pred_index is None:
                pred_index = tf.argmax(predictions[0])
            # Normalize to a plain int: callers may pass a NumPy integer,
            # which has no .numpy() method, as well as a tensor.
            pred_index = int(pred_index)
            class_channel = predictions[:, pred_index]

        # Gradients of the chosen class score w.r.t. the target activations.
        grads = tape.gradient(class_channel, layer_output)
        activations = layer_output[0].numpy()

        if grads is None:
            print("Warning: Gradients are None. Using simple attention map.")
            # Fallback: mean absolute activation per token, min-max scaled.
            heatmap = _tokens_to_grid(np.mean(np.abs(activations), axis=-1))
            if heatmap is None:
                return None, pred_index
            heatmap = (heatmap - heatmap.min()) / \
                (heatmap.max() - heatmap.min() + 1e-10)
            return heatmap, pred_index

        grads_np = grads.numpy()
        if grads_np.ndim == 3:  # (batch, seq_len, features)
            # Global average pooling on gradients, then weight each token.
            pooled_grads = grads_np.mean(axis=(0, 1))
            heatmap = _tokens_to_grid(activations @ pooled_grads)
            if heatmap is None:
                print("Warning: No suitable layer found for Grad-CAM")
                return None, pred_index
        else:
            # Spatial feature map: pool over batch and both spatial axes.
            pooled_grads = grads_np.mean(axis=(0, 1, 2))
            heatmap = activations @ pooled_grads

        # Keep positive evidence only, then scale by the post-ReLU maximum so
        # an all-negative map becomes zeros instead of dividing by a negative.
        heatmap = np.maximum(heatmap, 0)
        heatmap = heatmap / (heatmap.max() + 1e-10)
        return heatmap.astype(np.float32), pred_index
    except Exception as e:
        print(f"GradCAM error: {e}")
        traceback.print_exc()
        return None, pred_index
def apply_gradcam(image, heatmap, alpha=0.4):
    """
    Overlay a Grad-CAM heatmap on an image.

    Args:
        image: PIL image (converted to RGB and resized to ``image_size``) or
            an array-like image. Float arrays with values in [0, 1] are
            rescaled to 0-255; grayscale and RGBA arrays are reduced to three
            channels.
        heatmap: 2-D array of importance values, expected in [0, 1]; values
            outside that range and non-finite values are clamped.
        alpha: Blend weight of the heatmap (``1 - alpha`` goes to the image).

    Returns:
        The blended PIL image, or ``image`` unchanged when ``heatmap`` is None
        or the overlay fails.
    """
    try:
        if heatmap is None:
            return image

        if isinstance(image, Image.Image):
            # RGB conversion guarantees a 3-channel array for the blend below.
            img_array = np.array(
                image.convert('RGB').resize((image_size, image_size)))
        else:
            img_array = np.asarray(image)

        # Only float images can be "normalized to [0, 1]"; a dark uint8 image
        # whose max happens to be <= 1 must not be rescaled.
        if np.issubdtype(img_array.dtype, np.floating) and img_array.max() <= 1.0:
            img_uint8 = (img_array * 255).astype('uint8')
        else:
            img_uint8 = np.clip(img_array, 0, 255).astype('uint8')

        # Bring array inputs to exactly three channels.
        if img_uint8.ndim == 2:
            img_uint8 = np.stack([img_uint8] * 3, axis=-1)
        elif img_uint8.shape[-1] == 4:
            img_uint8 = img_uint8[..., :3]

        # Sanitize before the uint8 cast: NaN/Inf or out-of-range values would
        # otherwise wrap around to arbitrary colours.
        heatmap = np.clip(
            np.nan_to_num(np.asarray(heatmap, dtype=np.float32)), 0.0, 1.0)

        # Resize heatmap to match the image (cv2 takes (width, height)).
        heatmap_resized = cv2.resize(
            heatmap, (img_uint8.shape[1], img_uint8.shape[0]))

        # Colourize; OpenCV colour maps are BGR, PIL expects RGB.
        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

        # Superimpose the heatmap on the original image.
        superimposed_img = heatmap_colored * alpha + img_uint8 * (1 - alpha)
        superimposed_img = np.clip(superimposed_img, 0, 255).astype('uint8')
        return Image.fromarray(superimposed_img)
    except Exception as e:
        print(f"Apply GradCAM error: {e}")
        return image
def generate_gradcam(image):
    """
    Generate the Grad-CAM visualization for ``image``.

    Returns the overlay produced by ``apply_gradcam``, or None when the model
    or image is missing, no heatmap could be computed, or an error occurs.
    """
    if model is None or image is None:
        return None
    try:
        processed_image = preprocess_image(image)
        # Let make_gradcam_heatmap pick the arg-max class from its own forward
        # pass: this avoids a second, redundant model.predict call and keeps
        # the type of the class index under the helper's control.
        heatmap, _ = make_gradcam_heatmap(processed_image, model)
        if heatmap is None:
            return None
        return apply_gradcam(image, heatmap, alpha=0.4)
    except Exception as e:
        print(f"Error generating GradCAM: {e}")
        return None
# -----------------------------
# SHAP Implementation
# -----------------------------
def generate_shap(image):
    """
    Generate the SHAP explanation plot for the top predicted class.

    Returns a PIL image of the rendered plot, or None when SHAP is not
    installed, the model or image is missing, or an error occurs.
    """
    if not SHAP_AVAILABLE:
        return None
    if model is None or image is None:
        return None

    fig = None
    try:
        if isinstance(image, Image.Image):
            # Force three channels so RGBA / grayscale uploads match the model input.
            img_array = np.array(
                image.convert('RGB').resize((image_size, image_size)))
        else:
            img_array = np.asarray(image)
        # The masker and the prediction wrapper below work on 0-255 pixels.
        if img_array.dtype != np.uint8:
            img_array = np.uint8(
                img_array * 255 if img_array.max() <= 1 else img_array)

        def model_predict(x):
            # Normalize to [0, 1] as float32 before prediction.
            batch = np.asarray(x, dtype=np.float32) / 255.0
            return model(tf.convert_to_tensor(batch), training=False).numpy()

        masker = shap.maskers.Image("inpaint_telea", img_array.shape)
        explainer = shap.Explainer(
            model_predict, masker, output_names=class_names)
        # Explain only the top predicted class.
        shap_values = explainer(
            img_array[np.newaxis, ...], outputs=shap.Explanation.argsort.flip[:1])

        # image_plot draws on a figure of its own, so no figure is created
        # here; grab the one it made so exactly that figure can be closed.
        shap.image_plot(shap_values, img_array[np.newaxis, ...], show=False)
        fig = plt.gcf()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        buf.seek(0)
        shap_image = Image.open(buf)
        # Decode now; Image.open is lazy and would otherwise depend on `buf`.
        shap_image.load()
        return shap_image
    except Exception as e:
        print(f"SHAP error: {e}")
        return None
    finally:
        # Release the figure on success and on failure alike.
        if fig is not None:
            plt.close(fig)
# -----------------------------
# LIME Implementation
# -----------------------------
def generate_lime(image):
    """
    Generate LIME explanation images for the top predicted class.

    Returns:
        ``(positive_only, positive_and_negative)`` PIL images with superpixel
        boundaries drawn, or ``(None, None)`` when LIME / scikit-image is not
        installed, the model or image is missing, or an error occurs.
    """
    if not LIME_AVAILABLE or not SKIMAGE_AVAILABLE:
        return None, None
    if model is None or image is None:
        return None, None
    try:
        if isinstance(image, Image.Image):
            # Force three channels so RGBA / grayscale uploads match the model input.
            img_array = np.array(
                image.convert('RGB').resize((image_size, image_size)))
        else:
            img_array = np.asarray(image)
        # LIME perturbs the image in the same [0, 1] range the model consumes.
        img_normalized = img_array / 255.0 if img_array.max() > 1 else img_array

        def classifier_fn(batch):
            # LIME supplies float64 batches; cast for the model and silence
            # the per-batch progress output.
            return model.predict(np.asarray(batch, dtype=np.float32), verbose=0)

        explainer = lime_image.LimeImageExplainer()
        explanation = explainer.explain_instance(
            img_normalized.astype('float64'),
            classifier_fn,
            top_labels=3,
            hide_color=0,
            num_samples=1000,
            batch_size=32
        )
        top_label = explanation.top_labels[0]

        # Positive features only
        temp_positive, mask_positive = explanation.get_image_and_mask(
            top_label,
            positive_only=True,
            num_features=10,
            hide_rest=False
        )
        lime_positive = mark_boundaries(temp_positive, mask_positive)

        # Positive and negative features
        temp_both, mask_both = explanation.get_image_and_mask(
            top_label,
            positive_only=False,
            num_features=10,
            hide_rest=False
        )
        lime_both = mark_boundaries(temp_both, mask_both)

        # Convert to PIL images; clip first so rounding overshoot cannot wrap
        # around in the uint8 cast.
        lime_positive_img = Image.fromarray(
            np.clip(lime_positive * 255, 0, 255).astype(np.uint8))
        lime_both_img = Image.fromarray(
            np.clip(lime_both * 255, 0, 255).astype(np.uint8))
        return lime_positive_img, lime_both_img
    except Exception as e:
        print(f"LIME error: {e}")
        return None, None
# -----------------------------
# Unified Prediction with XAI
# -----------------------------
def predict_with_xai(image):
    """
    Run the classifier and every explanation method on one image.

    Returns a 5-tuple: the class-probability mapping, the Grad-CAM image, the
    SHAP image, the LIME positive-only image and the LIME positive+negative
    image. When the model or image is missing, or anything raises, the
    mapping holds 0.0 for every class and all four images are None.
    """
    def _empty_result():
        # Zero confidence everywhere and no visualizations.
        return {class_name: 0.0 for class_name in class_names}, None, None, None, None

    if model is None or image is None:
        return _empty_result()

    try:
        scores = predict(image)
        gradcam_view = generate_gradcam(image)
        shap_view = generate_shap(image)  # can be slow
        lime_pos_view, lime_all_view = generate_lime(image)  # can be slow
    except Exception as e:
        print(f"Error in predict_with_xai: {e}")
        return _empty_result()

    return scores, gradcam_view, shap_view, lime_pos_view, lime_all_view
# -----------------------------
# Gradio Interface
# -----------------------------
# Page heading and introductory HTML banner shown at the top of the UI.
# NOTE(review): the non-ASCII glyphs in these strings (e.g. the character
# between "224" and "224") look mis-encoded - presumably emoji / a
# multiplication sign damaged by a bad decode; restore from the original file.
title = "π¬ GERD Lightweight Vision Transformer with XAI"
description = """
<div style="text-align: center; padding: 20px;">
<h2 style="color: #2E86AB;">Advanced Medical Image Analysis with Explainable AI</h2>
<p style="font-size: 16px; color: #555;">
Upload an endoscopic image to classify using a <b>Lightweight Vision Transformer</b> model.
Get predictions with <b>three explainability methods</b> to understand the AI's decision.
</p>
<div style="display: flex; justify-content: center; gap: 20px; margin-top: 15px;">
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 15px; border-radius: 10px; color: white;">
<b>π Model Architecture</b><br>
Image: 224Γ224 | Patches: 8Γ8<br>
Projection: 64 | Layers: 4 | Heads: 4
</div>
<div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 15px; border-radius: 10px; color: white;">
<b>π― XAI Methods</b><br>
GradCAM | SHAP | LIME<br>
Visual Explanations
</div>
</div>
</div>
"""
# Custom CSS for creative styling
# Passed to gr.Blocks(css=...). Covers the container font, a gradient <h1>,
# and the ".button-primary" class that the classify button opts into through
# elem_classes.
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
h1 {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 2.5em !important;
text-align: center !important;
}
.button-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
padding: 12px 30px !important;
border-radius: 25px !important;
font-size: 16px !important;
transition: all 0.3s ease !important;
}
.button-primary:hover {
transform: scale(1.05) !important;
box-shadow: 0 8px 15px rgba(102, 126, 234, 0.4) !important;
}
"""
# Create Gradio interface using Blocks with creative design
# Layout, top to bottom: heading and banner; a row with the upload column and
# the prediction-label column; then one row each for Grad-CAM, SHAP and the
# two LIME views; a footer; and finally the button wiring.
# NOTE(review): the nesting below was reconstructed from a copy that had lost
# its indentation - confirm against the original file. Multi-line string
# contents are kept flush-left exactly as they appear in that copy.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML(f"<h1>{title}</h1>")
    gr.HTML(description)
    with gr.Row():
        # Left column: image upload, action button and usage instructions.
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil", label="π€ Upload Endoscopic Image")
            predict_btn = gr.Button(
                "π Classify & Explain", variant="primary", elem_classes="button-primary", size="lg")
            gr.Markdown("""
<div style="background: #f0f4f8; padding: 15px; border-radius: 10px; margin-top: 10px;">
<b>βΉοΈ Instructions:</b>
<ul>
<li>Upload an endoscopic image (JPG, PNG)</li>
<li>Click "Classify & Explain" to get results</li>
<li>View prediction + XAI explanations below</li>
<li><i>Note: SHAP and LIME may take 30-60 seconds</i></li>
</ul>
</div>
""")
        # Right column: per-class confidences.
        with gr.Column(scale=1):
            output_label = gr.Label(
                num_top_classes=4, label="π Prediction Results", show_label=True)
    # Explanations Section
    gr.Markdown("""
<div style="text-align: center; margin-top: 30px; margin-bottom: 20px;">
<h2 style="color: #2E86AB;">π― Explainable AI Visualizations</h2>
<p style="color: #666;">Understanding how the model makes its predictions</p>
</div>
""")
    with gr.Row():
        # GradCAM
        with gr.Column(scale=1):
            gr.Markdown("""
<div style="background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
<h3 style="margin: 0; color: #e65100;">π₯ Grad-CAM</h3>
<p style="margin: 5px 0 0 0; font-size: 14px;">
<b>Gradient-weighted Class Activation Mapping</b><br>
Highlights regions most important for prediction. Red = high importance.
</p>
</div>
""")
            output_gradcam = gr.Image(
                label="Grad-CAM Heatmap", show_label=False)
    with gr.Row():
        # SHAP
        with gr.Column(scale=1):
            gr.Markdown("""
<div style="background: linear-gradient(135deg, #e8f5e9 0%, #c8e6c9 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
<h3 style="margin: 0; color: #2e7d32;">π― SHAP</h3>
<p style="margin: 5px 0 0 0; font-size: 14px;">
<b>SHapley Additive exPlanations</b><br>
Red pixels push toward predicted class, blue pixels push away.
</p>
</div>
""")
            output_shap = gr.Image(label="SHAP Explanation", show_label=False)
    with gr.Row():
        # LIME
        with gr.Column(scale=1):
            gr.Markdown("""
<div style="background: linear-gradient(135deg, #fce4ec 0%, #f8bbd0 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
<h3 style="margin: 0; color: #c2185b;">π LIME - Positive Features</h3>
<p style="margin: 5px 0 0 0; font-size: 14px;">
<b>Local Interpretable Model-agnostic Explanations</b><br>
Green boundaries show regions supporting the prediction.
</p>
</div>
""")
            output_lime_positive = gr.Image(
                label="LIME Positive", show_label=False)
        with gr.Column(scale=1):
            gr.Markdown("""
<div style="background: linear-gradient(135deg, #e1f5fe 0%, #b3e5fc 100%); padding: 15px; border-radius: 10px; margin-bottom: 10px;">
<h3 style="margin: 0; color: #01579b;">π LIME - All Features</h3>
<p style="margin: 5px 0 0 0; font-size: 14px;">
<b>Positive & Negative Contributions</b><br>
Shows both supporting and opposing regions.
</p>
</div>
""")
            output_lime_both = gr.Image(
                label="LIME Positive & Negative", show_label=False)
    # Footer
    gr.Markdown("""
<div style="text-align: center; margin-top: 30px; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;">
<h3>π₯ Medical AI with Transparency</h3>
<p>This tool combines state-of-the-art Vision Transformer technology with explainable AI methods
to provide transparent and interpretable medical image analysis.</p>
<p style="font-size: 12px; margin-top: 10px;">
<b>Classes:</b> GERD, GERD NORMAL, POLYP, POLYP NORMAL
</p>
</div>
""")
    # Connect button to unified function
    # The order of `outputs` must match the order of the 5-tuple returned by
    # predict_with_xai: label, Grad-CAM, SHAP, LIME positive, LIME both.
    predict_btn.click(
        fn=predict_with_xai,
        inputs=input_image,
        outputs=[output_label, output_gradcam, output_shap,
                 output_lime_positive, output_lime_both],
        api_name="predict"
    )
# Launch with error reporting enabled
# Only starts the server when the file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(show_error=True)