sparsh007 commited on
Commit
0e00745
·
verified ·
1 Parent(s): de59526

Upload 2 files

Browse files
Files changed (2) hide show
  1. app-4.py +98 -0
  2. requirements-3.txt +7 -0
app-4.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import keras
3
+ from keras.applications import inception_v3 as inc_net
4
+ from keras.preprocessing import image
5
+ from skimage.segmentation import mark_boundaries
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import gradio as gr
9
+ from lime import lime_image
10
+
11
+ # Load the pre-trained InceptionV3 model
12
+ inet_model = inc_net.InceptionV3()
13
+
14
+ def transform_img_fn(img_path):
15
+ """Preprocess image for InceptionV3"""
16
+ img = image.load_img(img_path, target_size=(299, 299))
17
+ x = image.img_to_array(img)
18
+ x = np.expand_dims(x, axis=0)
19
+ return inc_net.preprocess_input(x)
20
+
21
+ def explain_image(img_path):
22
+ """Generate LIME explanation and visualization"""
23
+ # Preprocess image
24
+ processed_img = transform_img_fn(img_path)
25
+
26
+ # Create LIME explainer
27
+ explainer = lime_image.LimeImageExplainer()
28
+
29
+ # Generate explanation
30
+ explanation = explainer.explain_instance(
31
+ processed_img[0].astype('double'),
32
+ inet_model.predict,
33
+ top_labels=5,
34
+ hide_color=0,
35
+ num_samples=1000
36
+ )
37
+
38
+ # Get image and mask
39
+ temp, mask = explanation.get_image_and_mask(
40
+ explanation.top_labels[0],
41
+ positive_only=False,
42
+ num_features=10,
43
+ hide_rest=False
44
+ )
45
+
46
+ # Get top 5 predictions
47
+ predictions = inet_model.predict(processed_img)
48
+ top_5_indices = np.argsort(predictions[0])[-5:][::-1]
49
+ top_5_labels = [inc_net.decode_predictions(predictions, top=5)[0][i][1] for i in range(5)]
50
+ top_5_probs = [inc_net.decode_predictions(predictions, top=5)[0][i][2] for i in range(5)]
51
+
52
+ # Create visualization
53
+ fig, ax = plt.subplots(figsize=(6, 6))
54
+
55
+ # Explanation visualization
56
+ ax.imshow(mark_boundaries(temp / 2 + 0.5, mask))
57
+ ax.set_title('Pros (Green) vs Cons (Red)')
58
+ ax.axis('off')
59
+
60
+ plt.tight_layout()
61
+
62
+ # Create a string for the top 5 predictions
63
+ predictions_str = "Top 5 Predictions:\n"
64
+ for i, (label, prob) in enumerate(zip(top_5_labels, top_5_probs)):
65
+ predictions_str += f"{i+1}. {label}: {prob:.4f}\n"
66
+
67
+ # Generate heatmap
68
+ ind = explanation.top_labels[0]
69
+ dict_heatmap = dict(explanation.local_exp[ind])
70
+ heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
71
+
72
+ # Plot heatmap
73
+ fig_heatmap, ax_heatmap = plt.subplots(figsize=(6, 6))
74
+ heatmap_plot = ax_heatmap.imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())
75
+ plt.colorbar(heatmap_plot, ax=ax_heatmap)
76
+ ax_heatmap.set_title('Heatmap Explanation')
77
+ ax_heatmap.axis('off')
78
+
79
+ plt.tight_layout()
80
+
81
+ return fig, predictions_str, fig_heatmap
82
+
83
+ # Create Gradio interface
84
+ demo = gr.Interface(
85
+ fn=explain_image,
86
+ inputs=gr.Image(type="filepath", label="Input Image"),
87
+ outputs=[
88
+ gr.Plot(label="Explanation"),
89
+ gr.Textbox(label="Top 5 Predictions"),
90
+ gr.Plot(label="Heatmap Explanation")
91
+ ],
92
+ title="LIME Image Classifier Explainer",
93
+ description="Upload an image to see which areas positively (green) and negatively (red) influence the classification, the top 5 predictions, and a heatmap explanation."
94
+ )
95
+
96
+ # Launch the app
97
+ if __name__ == "__main__":
98
+ demo.launch()
requirements-3.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ keras
3
+ tensorflow
4
+ lime
5
+ numpy
6
+ matplotlib
7
+ scikit-image