import os
import keras
from keras.applications import inception_v3 as inc_net
from keras.preprocessing import image
from skimage.segmentation import mark_boundaries
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from lime import lime_image
# Load the pre-trained InceptionV3 model (ImageNet weights) once at module
# import time, so every Gradio request reuses the same model instance.
inet_model = inc_net.InceptionV3()
def transform_img_fn(img_path):
    """Load an image file and preprocess it for InceptionV3.

    Resizes to the 299x299 input size InceptionV3 expects, adds a
    leading batch axis, and applies the model's own preprocessing
    (which maps pixel values into [-1, 1]).

    Args:
        img_path: Path to an image file on disk.

    Returns:
        A preprocessed array of shape (1, 299, 299, 3).
    """
    loaded = image.load_img(img_path, target_size=(299, 299))
    pixels = image.img_to_array(loaded)
    batch = pixels[np.newaxis, ...]
    return inc_net.preprocess_input(batch)
def explain_image(img_path):
    """Generate a LIME explanation and visualizations for one image.

    Args:
        img_path: Path to the uploaded image file.

    Returns:
        A 3-tuple of:
        - matplotlib Figure showing superpixel boundaries for the top
          predicted label (positive regions green, negative red),
        - string listing the top-5 predicted labels with probabilities,
        - matplotlib Figure with a per-superpixel heatmap of LIME weights.
    """
    processed_img = transform_img_fn(img_path)

    # Fit a local surrogate model around this image's prediction.
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        processed_img[0].astype('double'),
        inet_model.predict,
        top_labels=5,
        hide_color=0,
        num_samples=1000,
    )

    # Image + mask highlighting the most influential superpixels for the
    # top predicted label.
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False,
        num_features=10,
        hide_rest=False,
    )

    # Decode the top-5 predictions ONCE; each entry is
    # (class_id, human_readable_label, probability).
    predictions = inet_model.predict(processed_img)
    decoded = inc_net.decode_predictions(predictions, top=5)[0]
    predictions_str = "Top 5 Predictions:\n"
    for rank, (_, label, prob) in enumerate(decoded, start=1):
        predictions_str += f"{rank}. {label}: {prob:.4f}\n"

    # Boundary visualization. `temp / 2 + 0.5` undoes InceptionV3's
    # [-1, 1] preprocessing so the image displays correctly.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    ax.set_title('Pros (Green) vs Cons (Red)')
    ax.axis('off')
    fig.tight_layout()

    # Heatmap of LIME weights per superpixel; symmetric color limits so
    # zero influence maps to the center of the diverging colormap.
    ind = explanation.top_labels[0]
    dict_heatmap = dict(explanation.local_exp[ind])
    heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
    fig_heatmap, ax_heatmap = plt.subplots(figsize=(6, 6))
    heatmap_plot = ax_heatmap.imshow(
        heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max()
    )
    # Use the explicit figure's colorbar/tight_layout rather than the
    # pyplot "current figure" state, which is fragile with two live figures.
    fig_heatmap.colorbar(heatmap_plot, ax=ax_heatmap)
    ax_heatmap.set_title('Heatmap Explanation')
    ax_heatmap.axis('off')
    fig_heatmap.tight_layout()

    return fig, predictions_str, fig_heatmap
# Create Gradio interface: one image-file input, three outputs
# (boundary plot, top-5 text, heatmap plot) matching explain_image's
# return tuple in order.
demo = gr.Interface(
    fn=explain_image,
    # type="filepath" hands explain_image a path on disk, which is what
    # transform_img_fn's image.load_img expects.
    inputs=gr.Image(type="filepath", label="Input Image"),
    outputs=[
        gr.Plot(label="Explanation"),
        gr.Textbox(label="Top 5 Predictions"),
        gr.Plot(label="Heatmap Explanation")
    ],
    title="LIME Image Classifier Explainer",
    description="Upload an image to see which areas positively (green) and negatively (red) influence the classification, the top 5 predictions, and a heatmap explanation."
)
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()