Spaces:
Sleeping
Sleeping
| import os | |
| import keras | |
| from keras.applications import inception_v3 as inc_net | |
| from keras.preprocessing import image | |
| from skimage.segmentation import mark_boundaries | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from lime import lime_image | |
# Pre-trained InceptionV3 ImageNet classifier (full network including the
# classification head). It is created once when this module is imported and
# is reused by every call to the explanation function below.
inet_model = inc_net.InceptionV3(include_top=True, weights="imagenet")
def transform_img_fn(img_path):
    """Load the image at *img_path* as a one-image batch ready for InceptionV3.

    The file is resized to 299x299 on load, converted to a float array,
    given a leading batch axis of length 1, and then passed through
    ``inc_net.preprocess_input``.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The preprocessed array with shape ``(1, 299, 299, channels)``.
    """
    pil_image = image.load_img(img_path, target_size=(299, 299))
    single_batch = image.img_to_array(pil_image)[np.newaxis, ...]
    return inc_net.preprocess_input(single_batch)
def _explanation_figure(explanation, label):
    """Return a figure overlaying LIME's top superpixels for *label* on the image."""
    temp, mask = explanation.get_image_and_mask(
        label,
        positive_only=False,
        num_features=10,
        hide_rest=False,
    )
    fig, ax = plt.subplots(figsize=(6, 6))
    # InceptionV3's preprocess_input maps pixels into [-1, 1]; rescale to
    # [0, 1] so imshow renders the image correctly.
    ax.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    ax.set_title('Pros (Green) vs Cons (Red)')
    ax.axis('off')
    fig.tight_layout()
    # Remove the figure from pyplot's global registry so figures do not pile
    # up across requests in the long-running server; the Figure object stays
    # fully usable for rendering by the caller.
    plt.close(fig)
    return fig


def _top_predictions_text(batch, top=5):
    """Return a numbered summary of the model's *top* classes for ``batch[0]``."""
    predictions = inet_model.predict(batch)
    # decode_predictions yields, per batch item, (class_id, class_name, prob) tuples.
    decoded = inc_net.decode_predictions(predictions, top=top)[0]
    body = "".join(
        f"{rank}. {label}: {prob:.4f}\n"
        for rank, (_, label, prob) in enumerate(decoded, start=1)
    )
    return f"Top {top} Predictions:\n" + body


def _heatmap_figure(explanation, label):
    """Return a figure colouring every superpixel by its LIME weight for *label*."""
    weights = dict(explanation.local_exp[label])
    # Segments absent from the explanation get weight 0 rather than None,
    # which would otherwise produce an object array imshow cannot draw.
    heatmap = np.vectorize(lambda seg: weights.get(seg, 0.0), otypes=[float])(
        explanation.segments
    )
    # Symmetric limits around zero so the diverging colormap's midpoint means
    # "no influence". Use the largest *absolute* weight: using max() alone
    # clips strong negative regions and gives vmin > vmax when all weights
    # are negative. Fall back to 1.0 when every weight is zero.
    limit = float(np.abs(heatmap).max()) or 1.0

    fig_heatmap, ax_heatmap = plt.subplots(figsize=(6, 6))
    heatmap_plot = ax_heatmap.imshow(heatmap, cmap='RdBu', vmin=-limit, vmax=limit)
    fig_heatmap.colorbar(heatmap_plot, ax=ax_heatmap)
    ax_heatmap.set_title('Heatmap Explanation')
    ax_heatmap.axis('off')
    fig_heatmap.tight_layout()
    # Same reasoning as in _explanation_figure: avoid leaking pyplot figures.
    plt.close(fig_heatmap)
    return fig_heatmap


def explain_image(img_path):
    """Classify an image with InceptionV3 and explain the top label with LIME.

    Args:
        img_path: Path of the uploaded image (the Gradio input uses
            ``type="filepath"``), or ``None`` if nothing was uploaded.

    Returns:
        A tuple ``(fig, predictions_str, fig_heatmap)``:
        - ``fig``: matplotlib figure with LIME's 10 highest-ranked
          superpixels for the top label marked on the image.
        - ``predictions_str``: text listing the 5 most probable classes and
          their probabilities.
        - ``fig_heatmap``: matplotlib figure colouring each superpixel by
          its LIME weight for the top label.

    Raises:
        gr.Error: if no image was provided.
    """
    if img_path is None:
        raise gr.Error("Please upload an image first.")

    processed_img = transform_img_fn(img_path)

    # A fresh explainer per call; LIME perturbs the image num_samples times
    # and queries the model on each perturbation.
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        processed_img[0].astype('double'),
        inet_model.predict,
        top_labels=5,
        hide_color=0,
        num_samples=1000,
    )
    top_label = explanation.top_labels[0]

    fig = _explanation_figure(explanation, top_label)
    predictions_str = _top_predictions_text(processed_img, top=5)
    fig_heatmap = _heatmap_figure(explanation, top_label)
    return fig, predictions_str, fig_heatmap
def _build_demo():
    """Create the Gradio interface that wires ``explain_image`` to the UI.

    The image input is created first, followed by the three outputs in the
    order ``explain_image`` returns them: explanation plot, prediction text,
    heatmap plot.
    """
    image_input = gr.Image(type="filepath", label="Input Image")
    output_components = [
        gr.Plot(label="Explanation"),
        gr.Textbox(label="Top 5 Predictions"),
        gr.Plot(label="Heatmap Explanation"),
    ]
    return gr.Interface(
        fn=explain_image,
        inputs=image_input,
        outputs=output_components,
        title="LIME Image Classifier Explainer",
        description="Upload an image to see which areas positively (green) and negatively (red) influence the classification, the top 5 predictions, and a heatmap explanation.",
    )


# The interface object is built at import time; the server is started only
# when this file is executed directly.
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()