"""Gradio demo that explains InceptionV3 image classifications with LIME."""

import os
import keras
from keras.applications import inception_v3 as inc_net
from keras.preprocessing import image
from skimage.segmentation import mark_boundaries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import gradio as gr
from lime import lime_image

# Load the pre-trained InceptionV3 model once at import time so every request
# reuses the same weights.
inet_model = inc_net.InceptionV3()


def transform_img_fn(img_path):
    """Load an image file and preprocess it for InceptionV3.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The output of ``inc_net.preprocess_input`` for a batch holding the
        single image resized to 299x299.
    """
    img = image.load_img(img_path, target_size=(299, 299))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch axis the model expects
    return inc_net.preprocess_input(x)


def _format_top_predictions(predictions):
    """Return the "Top 5 Predictions" text for a batch of model outputs."""
    # decode_predictions is called once; each entry is (class_id, label, prob).
    decoded = inc_net.decode_predictions(predictions, top=5)[0]
    lines = [
        f"{rank}. {label}: {prob:.4f}\n"
        for rank, (_, label, prob) in enumerate(decoded, start=1)
    ]
    return "Top 5 Predictions:\n" + "".join(lines)


def _plot_boundaries(temp, mask):
    """Return a figure showing the LIME superpixel boundaries on the image."""
    # Figure objects are created directly instead of through pyplot: pyplot
    # keeps every figure alive in a global registry (a leak in a long-running
    # server) and its "current figure" state is not thread-safe.
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    # temp / 2 + 0.5 maps the preprocessed pixel values back to a displayable
    # range (assumes preprocess_input scaled them to [-1, 1]).
    ax.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    ax.set_title('Pros (Green) vs Cons (Red)')
    ax.axis('off')
    fig.tight_layout()
    return fig


def _plot_heatmap(explanation, label):
    """Return a figure with the per-superpixel LIME weights for ``label``."""
    weights = dict(explanation.local_exp[label])
    # Segments without a weight are drawn as neutral (0.0) rather than
    # producing None entries in the array.
    heatmap = np.vectorize(
        lambda segment: weights.get(segment, 0.0), otypes=[float]
    )(explanation.segments)

    # Symmetric colour limits around zero based on the largest magnitude, so
    # the scale stays valid even when every weight is negative or zero.
    limit = float(np.abs(heatmap).max()) or 1.0

    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    heatmap_plot = ax.imshow(heatmap, cmap='RdBu', vmin=-limit, vmax=limit)
    fig.colorbar(heatmap_plot, ax=ax)
    ax.set_title('Heatmap Explanation')
    ax.axis('off')
    fig.tight_layout()
    return fig


def explain_image(img_path):
    """Generate the LIME explanation, top-5 predictions and heatmap.

    Args:
        img_path: Path of the uploaded image, or ``None`` when the user
            submitted the form without an image.

    Returns:
        A tuple ``(fig, predictions_str, fig_heatmap)``: the boundary
        visualization figure, the formatted top-5 prediction text, and the
        heatmap figure.

    Raises:
        gr.Error: If no image was provided.
    """
    if img_path is None:
        raise gr.Error("Please upload an image first.")

    # Preprocess image
    processed_img = transform_img_fn(img_path)

    # Generate explanation
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        processed_img[0].astype('double'),
        inet_model.predict,
        top_labels=5,
        hide_color=0,
        num_samples=1000,
    )

    # Image and mask for the first of the explained top labels
    top_label = explanation.top_labels[0]
    temp, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=False,
        num_features=10,
        hide_rest=False,
    )

    predictions = inet_model.predict(processed_img)
    predictions_str = _format_top_predictions(predictions)

    fig = _plot_boundaries(temp, mask)
    fig_heatmap = _plot_heatmap(explanation, top_label)

    return fig, predictions_str, fig_heatmap


# Create Gradio interface
demo = gr.Interface(
    fn=explain_image,
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=[
        gr.Plot(label="Explanation"),
        gr.Textbox(label="Top 5 Predictions"),
        gr.Plot(label="Heatmap Explanation"),
    ],
    title="LIME Image Classifier Explainer",
    description="Upload an image to see which areas positively (green) and negatively (red) influence the classification, the top 5 predictions, and a heatmap explanation.",
)

# Launch the app
if __name__ == "__main__":
    demo.launch()