Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
feat:workbench page
Browse files- app.py +268 -95
- closest_sample.py +56 -3
- explanations.py +47 -20
- inference_resnet.py +1 -1
app.py
CHANGED
|
@@ -21,7 +21,8 @@ from inference_resnet import get_triplet_model
|
|
| 21 |
from inference_beit import get_triplet_model_beit
|
| 22 |
import pathlib
|
| 23 |
import tensorflow as tf
|
| 24 |
-
from closest_sample import get_images
|
|
|
|
| 25 |
|
| 26 |
if not os.path.exists('images'):
|
| 27 |
REPO_ID='Serrelab/image_examples_gradio'
|
|
@@ -35,6 +36,57 @@ if not os.path.exists('dataset'):
|
|
| 35 |
print("warning! A read token in env variables is needed for authentication.")
|
| 36 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def get_model(model_name):
|
| 39 |
|
| 40 |
|
|
@@ -61,6 +113,13 @@ def get_model(model_name):
|
|
| 61 |
embedding_depth = 2,
|
| 62 |
n_classes = n_classes)
|
| 63 |
model.load_weights('model_classification/fossil-142.h5')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
else:
|
| 65 |
raise ValueError(f"Model name '{model_name}' is not recognized")
|
| 66 |
return model,n_classes
|
|
@@ -82,7 +141,12 @@ def classify_image(input_image, model_name):
|
|
| 82 |
model, n_classes= get_model(model_name)
|
| 83 |
result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
|
| 84 |
return result
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
from inference_beit import inference_resnet_finer_beit
|
| 87 |
model,n_classes = get_model(model_name)
|
| 88 |
result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
|
|
@@ -100,7 +164,12 @@ def get_embeddings(input_image,model_name):
|
|
| 100 |
model, n_classes= get_model(model_name)
|
| 101 |
result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
|
| 102 |
return result
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
from inference_beit import inference_resnet_embedding_beit
|
| 105 |
model,n_classes = get_model(model_name)
|
| 106 |
result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
|
|
@@ -114,30 +183,103 @@ def find_closest(input_image,model_name):
|
|
| 114 |
#outputs = classes+paths
|
| 115 |
return classes,paths
|
| 116 |
|
| 117 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
model,n_classes= get_model(model_name)
|
| 119 |
-
if model_name=='Fossils 142':
|
| 120 |
size = 384
|
| 121 |
else:
|
| 122 |
size = 600
|
| 123 |
#saliency, integrated, smoothgrad,
|
| 124 |
-
exp_list = explain(model,input_image,size = size, n_classes=n_classes)
|
| 125 |
#original = saliency + integrated + smoothgrad
|
| 126 |
print('done')
|
| 127 |
-
sobol1,sobol2,sobol3,sobol4,sobol5 = exp_list[0],exp_list[1],exp_list[2],exp_list[3],exp_list[4]
|
| 128 |
-
rise1,rise2,rise3,rise4,rise5 = exp_list[5],exp_list[6],exp_list[7],exp_list[8],exp_list[9]
|
| 129 |
-
hsic1,hsic2,hsic3,hsic4,hsic5 = exp_list[10],exp_list[11],exp_list[12],exp_list[13],exp_list[14]
|
| 130 |
-
saliency1,saliency2,saliency3,saliency4,saliency5 = exp_list[15],exp_list[16],exp_list[17],exp_list[18],exp_list[19]
|
| 131 |
-
return sobol1,sobol2,sobol3,sobol4,sobol5,rise1,rise2,rise3,rise4,rise5,hsic1,hsic2,hsic3,hsic4,hsic5,saliency1,saliency2,saliency3,saliency4,saliency5
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
#minimalist theme
|
| 134 |
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
|
| 135 |
|
| 136 |
with gr.Tab(" Florrissant Fossils"):
|
| 137 |
-
|
| 138 |
with gr.Row():
|
| 139 |
with gr.Column():
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
classify_image_button = gr.Button("Classify Image")
|
| 142 |
|
| 143 |
# with gr.Column():
|
|
@@ -148,21 +290,101 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
|
|
| 148 |
|
| 149 |
with gr.Column():
|
| 150 |
model_name = gr.Dropdown(
|
| 151 |
-
["Mummified 170", "Rock 170","Fossils 142"],
|
| 152 |
multiselect=False,
|
| 153 |
-
value="Fossils
|
| 154 |
label="Model",
|
| 155 |
interactive=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
with gr.Row():
|
| 160 |
-
|
| 161 |
-
paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
|
| 162 |
-
samples=[[path.as_posix()] for path in paths if 'fossils' in str(path) ][:19]
|
| 163 |
-
examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=10,label='Fossils Examples from the dataset')
|
| 164 |
-
samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
|
| 165 |
-
examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
|
| 166 |
|
| 167 |
# with gr.Accordion("Using Diffuser"):
|
| 168 |
# with gr.Column():
|
|
@@ -173,80 +395,20 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
|
|
| 173 |
# class_predicted2 = gr.Label(label='Class Predicted from diffuser')
|
| 174 |
# classify_button = gr.Button("Classify Image")
|
| 175 |
|
| 176 |
-
|
| 177 |
-
with gr.Accordion("Explanations "):
|
| 178 |
-
gr.Markdown("Computing Explanations from the model")
|
| 179 |
-
with gr.Column():
|
| 180 |
-
with gr.Row():
|
| 181 |
-
|
| 182 |
-
#original_input = gr.Image(label="Original Frame")
|
| 183 |
-
#saliency = gr.Image(label="saliency")
|
| 184 |
-
#gradcam = gr.Image(label='integraged gradients')
|
| 185 |
-
#guided_gradcam = gr.Image(label='gradcam')
|
| 186 |
-
#guided_backprop = gr.Image(label='guided backprop')
|
| 187 |
-
sobol1 = gr.Image(label = 'Sobol1')
|
| 188 |
-
sobol2= gr.Image(label = 'Sobol2')
|
| 189 |
-
sobol3= gr.Image(label = 'Sobol3')
|
| 190 |
-
sobol4= gr.Image(label = 'Sobol4')
|
| 191 |
-
sobol5= gr.Image(label = 'Sobol5')
|
| 192 |
-
|
| 193 |
-
with gr.Row():
|
| 194 |
-
rise1 = gr.Image(label = 'Rise1')
|
| 195 |
-
rise2 = gr.Image(label = 'Rise2')
|
| 196 |
-
rise3 = gr.Image(label = 'Rise3')
|
| 197 |
-
rise4 = gr.Image(label = 'Rise4')
|
| 198 |
-
rise5 = gr.Image(label = 'Rise5')
|
| 199 |
-
|
| 200 |
-
with gr.Row():
|
| 201 |
-
hsic1 = gr.Image(label = 'HSIC1')
|
| 202 |
-
hsic2 = gr.Image(label = 'HSIC2')
|
| 203 |
-
hsic3 = gr.Image(label = 'HSIC3')
|
| 204 |
-
hsic4 = gr.Image(label = 'HSIC4')
|
| 205 |
-
hsic5 = gr.Image(label = 'HSIC5')
|
| 206 |
-
|
| 207 |
-
with gr.Row():
|
| 208 |
-
saliency1 = gr.Image(label = 'Saliency1')
|
| 209 |
-
saliency2 = gr.Image(label = 'Saliency2')
|
| 210 |
-
saliency3 = gr.Image(label = 'Saliency3')
|
| 211 |
-
saliency4 = gr.Image(label = 'Saliency4')
|
| 212 |
-
saliency5 = gr.Image(label = 'Saliency5')
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
generate_explanations = gr.Button("Generate Explanations")
|
| 216 |
-
|
| 217 |
-
# with gr.Accordion('Closest Images'):
|
| 218 |
-
# gr.Markdown("Finding the closest images in the dataset")
|
| 219 |
-
# with gr.Row():
|
| 220 |
-
# with gr.Column():
|
| 221 |
-
# label_closest_image_0 = gr.Markdown('')
|
| 222 |
-
# closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
|
| 223 |
-
# with gr.Column():
|
| 224 |
-
# label_closest_image_1 = gr.Markdown('')
|
| 225 |
-
# closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
|
| 226 |
-
# with gr.Column():
|
| 227 |
-
# label_closest_image_2 = gr.Markdown('')
|
| 228 |
-
# closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
|
| 229 |
-
# with gr.Column():
|
| 230 |
-
# label_closest_image_3 = gr.Markdown('')
|
| 231 |
-
# closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
|
| 232 |
-
# with gr.Column():
|
| 233 |
-
# label_closest_image_4 = gr.Markdown('')
|
| 234 |
-
# closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
|
| 235 |
-
# find_closest_btn = gr.Button("Find Closest Images")
|
| 236 |
-
with gr.Accordion('Closest Images'):
|
| 237 |
-
gr.Markdown("Finding the closest images in the dataset")
|
| 238 |
-
|
| 239 |
-
with gr.Row():
|
| 240 |
-
gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
|
| 241 |
-
#.style(grid=[1, 5], height=200, width=200)
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
#find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
|
| 249 |
-
def
|
| 250 |
labels, images = find_closest(input_image,model_name)
|
| 251 |
#labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
|
| 252 |
#labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
|
|
@@ -255,8 +417,19 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
|
|
| 255 |
image_caption.append((images[i],labels[i]))
|
| 256 |
return image_caption
|
| 257 |
|
| 258 |
-
find_closest_btn.click(fn=
|
| 259 |
#classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
demo.queue() # manage multiple incoming requests
|
| 262 |
|
|
|
|
| 21 |
from inference_beit import get_triplet_model_beit
|
| 22 |
import pathlib
|
| 23 |
import tensorflow as tf
|
| 24 |
+
from closest_sample import get_images,get_diagram
|
| 25 |
+
|
| 26 |
|
| 27 |
if not os.path.exists('images'):
|
| 28 |
REPO_ID='Serrelab/image_examples_gradio'
|
|
|
|
| 36 |
print("warning! A read token in env variables is needed for authentication.")
|
| 37 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
|
| 38 |
|
| 39 |
+
HEADER = '''
|
| 40 |
+
<h2><b>Official Gradio Demo</b></h2><h2><a href='https://huggingface.co/spaces/Serrelab/fossil_app' target='_blank'><b>Identifying Florissant Leaf Fossils to Family using Deep Neural Networks </b></a></h2>
|
| 41 |
+
Code: <a href='https://github.com/orgs/serre-lab/projects/2' target='_blank'>GitHub</a>. Paper: <a href='' target='_blank'>ArXiv</a>.
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
'''
|
| 45 |
+
|
| 46 |
+
"""
|
| 47 |
+
**Fossil** a brief intro to the project.
|
| 48 |
+
# ❗️❗️❗️**Important Notes:**
|
| 49 |
+
# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users .
|
| 50 |
+
# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users.
|
| 51 |
+
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
USER_GUIDE = """
|
| 55 |
+
<div style='background-color: #f0f0f0; padding: 20px; border-radius: 10px;'>
|
| 56 |
+
<h2>❗️ User Guide</h2>
|
| 57 |
+
<p>Welcome to the interactive fossil exploration tool. Here's how to get started:</p>
|
| 58 |
+
<ul>
|
| 59 |
+
<li><strong>Upload an Image:</strong> Drag and drop or choose from given samples to upload images of fossils.</li>
|
| 60 |
+
<li><strong>Process Image:</strong> After uploading, click the 'Process Image' button to analyze the image.</li>
|
| 61 |
+
<li><strong>Explore Results:</strong> Switch to the 'Workbench' tab to check out detailed analysis and results.</li>
|
| 62 |
+
</ul>
|
| 63 |
+
<h3>Tips</h3>
|
| 64 |
+
<ul>
|
| 65 |
+
<li>Zoom into images on the workbench for finer details.</li>
|
| 66 |
+
<li>Use the examples below as references for what types of images to upload.</li>
|
| 67 |
+
</ul>
|
| 68 |
+
<p>Enjoy exploring! 🌟</p>
|
| 69 |
+
</div>
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
TIPS = """
|
| 73 |
+
## Tips
|
| 74 |
+
- Zoom into images on the workbench for finer details.
|
| 75 |
+
- Use the examples below as references for what types of images to upload.
|
| 76 |
+
|
| 77 |
+
Enjoy exploring!
|
| 78 |
+
"""
|
| 79 |
+
CITATION = '''
|
| 80 |
+
📧 **Contact** <br>
|
| 81 |
+
If you have any questions, feel free to contact us at <b>ivan_felipe_rodriguez@brown.edu</b>.
|
| 82 |
+
'''
|
| 83 |
+
"""
|
| 84 |
+
📝 **Citation**
|
| 85 |
+
cite using this bibtex:...
|
| 86 |
+
```
|
| 87 |
+
```
|
| 88 |
+
📋 **License**
|
| 89 |
+
"""
|
| 90 |
def get_model(model_name):
|
| 91 |
|
| 92 |
|
|
|
|
| 113 |
embedding_depth = 2,
|
| 114 |
n_classes = n_classes)
|
| 115 |
model.load_weights('model_classification/fossil-142.h5')
|
| 116 |
+
elif model_name == 'Fossils new':
|
| 117 |
+
n_classes = 142
|
| 118 |
+
model = get_triplet_model_beit(input_shape = (384, 384, 3),
|
| 119 |
+
embedding_units = 256,
|
| 120 |
+
embedding_depth = 2,
|
| 121 |
+
n_classes = n_classes)
|
| 122 |
+
model.load_weights('model_classification/fossil-new.h5')
|
| 123 |
else:
|
| 124 |
raise ValueError(f"Model name '{model_name}' is not recognized")
|
| 125 |
return model,n_classes
|
|
|
|
| 141 |
model, n_classes= get_model(model_name)
|
| 142 |
result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
|
| 143 |
return result
|
| 144 |
+
elif 'Fossils 142' ==model_name:
|
| 145 |
+
from inference_beit import inference_resnet_finer_beit
|
| 146 |
+
model,n_classes = get_model(model_name)
|
| 147 |
+
result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
|
| 148 |
+
return result
|
| 149 |
+
elif 'Fossils new' ==model_name:
|
| 150 |
from inference_beit import inference_resnet_finer_beit
|
| 151 |
model,n_classes = get_model(model_name)
|
| 152 |
result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
|
|
|
|
| 164 |
model, n_classes= get_model(model_name)
|
| 165 |
result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
|
| 166 |
return result
|
| 167 |
+
elif 'Fossils 142' ==model_name:
|
| 168 |
+
from inference_beit import inference_resnet_embedding_beit
|
| 169 |
+
model,n_classes = get_model(model_name)
|
| 170 |
+
result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
|
| 171 |
+
return result
|
| 172 |
+
elif 'Fossils new' ==model_name:
|
| 173 |
from inference_beit import inference_resnet_embedding_beit
|
| 174 |
model,n_classes = get_model(model_name)
|
| 175 |
result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
|
|
|
|
| 183 |
#outputs = classes+paths
|
| 184 |
return classes,paths
|
| 185 |
|
| 186 |
+
def generate_diagram_closest(input_image,model_name,top_k):
|
| 187 |
+
embedding = get_embeddings(input_image,model_name)
|
| 188 |
+
diagram_path = get_diagram(embedding,top_k)
|
| 189 |
+
return diagram_path
|
| 190 |
+
|
| 191 |
+
def explain_image(input_image,model_name,explain_method,nb_samples):
|
| 192 |
model,n_classes= get_model(model_name)
|
| 193 |
+
if model_name=='Fossils 142' or 'Fossils new':
|
| 194 |
size = 384
|
| 195 |
else:
|
| 196 |
size = 600
|
| 197 |
#saliency, integrated, smoothgrad,
|
| 198 |
+
classes,exp_list = explain(model,input_image,explain_method,nb_samples,size = size, n_classes=n_classes)
|
| 199 |
#original = saliency + integrated + smoothgrad
|
| 200 |
print('done')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
return classes,exp_list
|
| 203 |
+
|
| 204 |
+
def setup_examples():
|
| 205 |
+
paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
|
| 206 |
+
samples = [path.as_posix() for path in paths if 'fossils' in str(path)][:19]
|
| 207 |
+
examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Fossils Examples from the dataset')
|
| 208 |
+
samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
|
| 209 |
+
examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
|
| 210 |
+
return examples_fossils,examples_leaves
|
| 211 |
+
|
| 212 |
+
def preprocess_image(image, output_size=(300, 300)):
|
| 213 |
+
#shape (height, width, channels)
|
| 214 |
+
h, w = image.shape[:2]
|
| 215 |
+
|
| 216 |
+
#padding
|
| 217 |
+
if h > w:
|
| 218 |
+
padding = (h - w) // 2
|
| 219 |
+
image_padded = cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
|
| 220 |
+
else:
|
| 221 |
+
padding = (w - h) // 2
|
| 222 |
+
image_padded = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
|
| 223 |
+
|
| 224 |
+
# resize
|
| 225 |
+
image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
|
| 226 |
+
|
| 227 |
+
return image_resized
|
| 228 |
+
|
| 229 |
+
def update_display(image):
|
| 230 |
+
processed_image = preprocess_image(image)
|
| 231 |
+
instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."
|
| 232 |
+
model_name = gr.Dropdown(
|
| 233 |
+
["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
|
| 234 |
+
multiselect=False,
|
| 235 |
+
value="Fossils new", # default option
|
| 236 |
+
label="Model",
|
| 237 |
+
interactive=True,
|
| 238 |
+
info="Choose the model you'd like to use"
|
| 239 |
+
)
|
| 240 |
+
explain_method = gr.Dropdown(
|
| 241 |
+
["Sobol", "HSIC","Rise","Saliency"],
|
| 242 |
+
multiselect=False,
|
| 243 |
+
value="Rise", # default option
|
| 244 |
+
label="Explain method",
|
| 245 |
+
interactive=True,
|
| 246 |
+
info="Choose one method to explain the model"
|
| 247 |
+
)
|
| 248 |
+
sampling_size = gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise",interactive=True,visible=True,
|
| 249 |
+
info="Choose between 1 and 5000")
|
| 250 |
+
|
| 251 |
+
top_k = gr.Slider(10,200,value=50,label="Number of Closest Samples for Distribution Chart",interactive=True,info="Choose between 10 and 200")
|
| 252 |
+
class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
|
| 253 |
+
exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
|
| 254 |
+
closest_gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
|
| 255 |
+
diagram= gr.Image(label = 'Bar Chart')
|
| 256 |
+
return processed_image,processed_image,instruction,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_gallery,diagram
|
| 257 |
+
def update_slider_visibility(explain_method):
|
| 258 |
+
bool = explain_method=="Rise"
|
| 259 |
+
return {sampling_size: gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise", visible=bool, interactive=True)}
|
| 260 |
+
|
| 261 |
#minimalist theme
|
| 262 |
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
|
| 263 |
|
| 264 |
with gr.Tab(" Florrissant Fossils"):
|
| 265 |
+
gr.Markdown(HEADER)
|
| 266 |
with gr.Row():
|
| 267 |
with gr.Column():
|
| 268 |
+
gr.Markdown(USER_GUIDE)
|
| 269 |
+
with gr.Column(scale=2):
|
| 270 |
+
with gr.Column(scale=2):
|
| 271 |
+
instruction_text = gr.Textbox(label="Instructions", value="Upload/Choose an image and click 'Process Image'.")
|
| 272 |
+
input_image = gr.Image(label="Input",width="100%",container=True)
|
| 273 |
+
process_button = gr.Button("Process Image")
|
| 274 |
+
with gr.Column(scale=1):
|
| 275 |
+
examples_fossils,examples_leaves = setup_examples()
|
| 276 |
+
|
| 277 |
+
gr.Markdown(CITATION)
|
| 278 |
+
|
| 279 |
+
with gr.Tab("Specimen Workbench"):
|
| 280 |
+
with gr.Row():
|
| 281 |
+
with gr.Column():
|
| 282 |
+
workbench_image = gr.Image(label="Workbench Image")
|
| 283 |
classify_image_button = gr.Button("Classify Image")
|
| 284 |
|
| 285 |
# with gr.Column():
|
|
|
|
| 290 |
|
| 291 |
with gr.Column():
|
| 292 |
model_name = gr.Dropdown(
|
| 293 |
+
["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
|
| 294 |
multiselect=False,
|
| 295 |
+
value="Fossils new", # default option
|
| 296 |
label="Model",
|
| 297 |
interactive=True,
|
| 298 |
+
info="Choose the model you'd like to use"
|
| 299 |
+
)
|
| 300 |
+
explain_method = gr.Dropdown(
|
| 301 |
+
["Sobol", "HSIC","Rise","Saliency"],
|
| 302 |
+
multiselect=False,
|
| 303 |
+
value="Rise", # default option
|
| 304 |
+
label="Explain method",
|
| 305 |
+
interactive=True,
|
| 306 |
+
info="Choose one method to explain the model"
|
| 307 |
)
|
| 308 |
+
# explain_method = gr.CheckboxGroup(["Sobol", "HSIC","Rise","Saliency"],
|
| 309 |
+
# label="explain method",
|
| 310 |
+
# value="Rise",
|
| 311 |
+
# multiselect=False,
|
| 312 |
+
# interactive=True,)
|
| 313 |
+
sampling_size = gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise",interactive=True,visible=True,
|
| 314 |
+
info="Choose between 1 and 5000")
|
| 315 |
+
|
| 316 |
+
top_k = gr.Slider(10,200,value=50,label="Number of Closest Samples for Distribution Chart",interactive=True,info="Choose between 10 and 200")
|
| 317 |
+
explain_method.change(
|
| 318 |
+
fn=update_slider_visibility,
|
| 319 |
+
inputs=explain_method,
|
| 320 |
+
outputs=sampling_size
|
| 321 |
+
)
|
| 322 |
+
with gr.Row():
|
| 323 |
+
with gr.Column(scale=1):
|
| 324 |
class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
|
| 325 |
+
with gr.Column(scale=4):
|
| 326 |
+
with gr.Accordion("Explanations "):
|
| 327 |
+
gr.Markdown("Computing Explanations from the model")
|
| 328 |
+
with gr.Column():
|
| 329 |
+
with gr.Row():
|
| 330 |
+
|
| 331 |
+
#original_input = gr.Image(label="Original Frame")
|
| 332 |
+
#saliency = gr.Image(label="saliency")
|
| 333 |
+
#gradcam = gr.Image(label='integraged gradients')
|
| 334 |
+
#guided_gradcam = gr.Image(label='gradcam')
|
| 335 |
+
#guided_backprop = gr.Image(label='guided backprop')
|
| 336 |
+
# exp1 = gr.Image(label = 'Class_name1')
|
| 337 |
+
# exp2= gr.Image(label = 'Class_name2')
|
| 338 |
+
# exp3= gr.Image(label = 'Class_name3')
|
| 339 |
+
# exp4= gr.Image(label = 'Class_name4')
|
| 340 |
+
# exp5= gr.Image(label = 'Class_name5')
|
| 341 |
+
|
| 342 |
+
exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
|
| 343 |
+
|
| 344 |
+
generate_explanations = gr.Button("Generate Explanations")
|
| 345 |
+
|
| 346 |
+
# with gr.Accordion('Closest Images'):
|
| 347 |
+
# gr.Markdown("Finding the closest images in the dataset")
|
| 348 |
+
# with gr.Row():
|
| 349 |
+
# with gr.Column():
|
| 350 |
+
# label_closest_image_0 = gr.Markdown('')
|
| 351 |
+
# closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
|
| 352 |
+
# with gr.Column():
|
| 353 |
+
# label_closest_image_1 = gr.Markdown('')
|
| 354 |
+
# closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
|
| 355 |
+
# with gr.Column():
|
| 356 |
+
# label_closest_image_2 = gr.Markdown('')
|
| 357 |
+
# closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
|
| 358 |
+
# with gr.Column():
|
| 359 |
+
# label_closest_image_3 = gr.Markdown('')
|
| 360 |
+
# closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
|
| 361 |
+
# with gr.Column():
|
| 362 |
+
# label_closest_image_4 = gr.Markdown('')
|
| 363 |
+
# closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
|
| 364 |
+
# find_closest_btn = gr.Button("Find Closest Images")
|
| 365 |
+
with gr.Accordion('Closest Fossil Images'):
|
| 366 |
+
gr.Markdown("Finding the closest images in the dataset")
|
| 367 |
+
|
| 368 |
+
with gr.Row():
|
| 369 |
+
closest_gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
|
| 370 |
+
#.style(grid=[1, 5], height=200, width=200)
|
| 371 |
+
|
| 372 |
+
find_closest_btn = gr.Button("Find Closest Images")
|
| 373 |
+
|
| 374 |
+
#segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
|
| 375 |
+
classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
|
| 376 |
+
# generate_exp.click(exp_image, inputs=[input_image,model_name,explain_method,sampling_size], outputs=[exp1,exp2,exp3,exp4,exp5]) #
|
| 377 |
+
with gr.Accordion('Closest Leaves Images'):
|
| 378 |
+
gr.Markdown("5 closest leaves")
|
| 379 |
+
with gr.Accordion("Class Distribution of Closest Samples "):
|
| 380 |
+
gr.Markdown("Visualize class distribution of top-k closest samples in our dataset")
|
| 381 |
+
with gr.Column():
|
| 382 |
+
with gr.Row():
|
| 383 |
+
diagram= gr.Image(label = 'Bar Chart')
|
| 384 |
+
|
| 385 |
+
generate_diagram = gr.Button("Generate Diagram")
|
| 386 |
+
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
# with gr.Accordion("Using Diffuser"):
|
| 390 |
# with gr.Column():
|
|
|
|
| 395 |
# class_predicted2 = gr.Label(label='Class Predicted from diffuser')
|
| 396 |
# classify_button = gr.Button("Classify Image")
|
| 397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
|
| 399 |
+
def update_exp_outputs(input_image,model_name,explain_method,nb_samples):
|
| 400 |
+
labels, images = explain_image(input_image,model_name,explain_method,nb_samples)
|
| 401 |
+
#labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
|
| 402 |
+
#labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
|
| 403 |
+
image_caption=[]
|
| 404 |
+
for i in range(5):
|
| 405 |
+
image_caption.append((images[i],"Predicted Class "+str(i)+": "+labels[i]))
|
| 406 |
+
return image_caption
|
| 407 |
+
|
| 408 |
+
generate_explanations.click(fn=update_exp_outputs, inputs=[input_image,model_name,explain_method,sampling_size], outputs=[exp_gallery])
|
| 409 |
+
|
| 410 |
#find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
|
| 411 |
+
def update_closest_outputs(input_image,model_name):
|
| 412 |
labels, images = find_closest(input_image,model_name)
|
| 413 |
#labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
|
| 414 |
#labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
|
|
|
|
| 417 |
image_caption.append((images[i],labels[i]))
|
| 418 |
return image_caption
|
| 419 |
|
| 420 |
+
find_closest_btn.click(fn=update_closest_outputs, inputs=[input_image,model_name], outputs=[closest_gallery])
|
| 421 |
#classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
|
| 422 |
+
|
| 423 |
+
generate_diagram.click(generate_diagram_closest, inputs=[input_image,model_name,top_k], outputs=diagram)
|
| 424 |
+
|
| 425 |
+
process_button.click(
|
| 426 |
+
fn=update_display,
|
| 427 |
+
inputs=input_image,
|
| 428 |
+
outputs=[input_image,workbench_image,instruction_text,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_gallery,diagram]
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
|
| 433 |
|
| 434 |
demo.queue() # manage multiple incoming requests
|
| 435 |
|
closest_sample.py
CHANGED
|
@@ -5,6 +5,8 @@ import pandas as pd
|
|
| 5 |
import os
|
| 6 |
from huggingface_hub import snapshot_download
|
| 7 |
import requests
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
|
|
@@ -23,7 +25,7 @@ embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
|
|
| 23 |
|
| 24 |
fossils_pd= pd.read_csv('fossils_paths.csv')
|
| 25 |
|
| 26 |
-
def pca_distance(pca,sample,embedding):
|
| 27 |
"""
|
| 28 |
Args:
|
| 29 |
pca:fitted PCA model
|
|
@@ -35,7 +37,7 @@ def pca_distance(pca,sample,embedding):
|
|
| 35 |
s = pca.transform(sample.reshape(1,-1))
|
| 36 |
all = pca.transform(embedding[:,-1])
|
| 37 |
distances = np.linalg.norm(all - s, axis=1)
|
| 38 |
-
return np.argsort(distances)[:
|
| 39 |
|
| 40 |
def return_paths(argsorted,files):
|
| 41 |
paths= []
|
|
@@ -56,7 +58,7 @@ def get_images(embedding):
|
|
| 56 |
|
| 57 |
#pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
|
| 58 |
|
| 59 |
-
pca_d =pca_distance(pca_fossils,embedding,embedding_fossils)
|
| 60 |
|
| 61 |
fossils_paths = fossils_pd['file_name'].values
|
| 62 |
|
|
@@ -87,3 +89,54 @@ def get_images(embedding):
|
|
| 87 |
# '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
|
| 88 |
|
| 89 |
return classes, local_paths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import os
|
| 6 |
from huggingface_hub import snapshot_download
|
| 7 |
import requests
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from collections import Counter
|
| 10 |
|
| 11 |
|
| 12 |
pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
|
|
|
|
| 25 |
|
| 26 |
fossils_pd= pd.read_csv('fossils_paths.csv')
|
| 27 |
|
| 28 |
+
def pca_distance(pca,sample,embedding,top_k):
|
| 29 |
"""
|
| 30 |
Args:
|
| 31 |
pca:fitted PCA model
|
|
|
|
| 37 |
s = pca.transform(sample.reshape(1,-1))
|
| 38 |
all = pca.transform(embedding[:,-1])
|
| 39 |
distances = np.linalg.norm(all - s, axis=1)
|
| 40 |
+
return np.argsort(distances)[:top_k]
|
| 41 |
|
| 42 |
def return_paths(argsorted,files):
|
| 43 |
paths= []
|
|
|
|
| 58 |
|
| 59 |
#pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
|
| 60 |
|
| 61 |
+
pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=5)
|
| 62 |
|
| 63 |
fossils_paths = fossils_pd['file_name'].values
|
| 64 |
|
|
|
|
| 89 |
# '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
|
| 90 |
|
| 91 |
return classes, local_paths
|
| 92 |
+
|
| 93 |
+
def get_diagram(embedding,top_k):
|
| 94 |
+
|
| 95 |
+
#pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
|
| 96 |
+
|
| 97 |
+
pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=top_k)
|
| 98 |
+
|
| 99 |
+
fossils_paths = fossils_pd['file_name'].values
|
| 100 |
+
|
| 101 |
+
paths = return_paths(pca_d,fossils_paths)
|
| 102 |
+
#print(paths)
|
| 103 |
+
|
| 104 |
+
folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
|
| 105 |
+
folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
|
| 106 |
+
|
| 107 |
+
classes = []
|
| 108 |
+
for i, path in enumerate(paths):
|
| 109 |
+
local_file_path = f'image_{i}.jpg'
|
| 110 |
+
if 'Florissant_Fossil/512/full/jpg/' in path:
|
| 111 |
+
public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
|
| 112 |
+
elif 'General_Fossil/512/full/jpg/' in path:
|
| 113 |
+
public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
|
| 114 |
+
else:
|
| 115 |
+
print("no match found")
|
| 116 |
+
print(public_path)
|
| 117 |
+
#download_public_image(public_path, local_file_path)
|
| 118 |
+
parts = [part for part in public_path.split('/') if part]
|
| 119 |
+
part = parts[-2]
|
| 120 |
+
classes.append(part)
|
| 121 |
+
#local_paths.append(local_file_path)
|
| 122 |
+
#paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
|
| 123 |
+
# '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
|
| 124 |
+
class_counts = Counter(classes)
|
| 125 |
+
|
| 126 |
+
sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
|
| 127 |
+
sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
|
| 128 |
+
colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
|
| 129 |
+
fig, ax = plt.subplots()
|
| 130 |
+
ax.bar(sorted_classes, sorted_frequencies,color=colors)
|
| 131 |
+
ax.set_xlabel('Class Label')
|
| 132 |
+
ax.set_ylabel('Frequency')
|
| 133 |
+
ax.set_title('Distribution of '+str(top_k) +' Closest Sample Classes')
|
| 134 |
+
ax.set_xticklabels(class_counts.keys(), rotation=45, ha='right')
|
| 135 |
+
|
| 136 |
+
# Save the diagram to a file
|
| 137 |
+
diagram_path = 'class_distribution_chart.png'
|
| 138 |
+
plt.tight_layout() # Adjust layout to make room for rotated x-axis labels
|
| 139 |
+
plt.savefig(diagram_path)
|
| 140 |
+
plt.close() # Close the figure to free up memory
|
| 141 |
+
|
| 142 |
+
return diagram_path
|
explanations.py
CHANGED
|
@@ -7,6 +7,7 @@ from xplique.attributions.global_sensitivity_analysis import LatinHypercube
|
|
| 7 |
import numpy as np
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
|
|
|
|
| 10 |
BATCH_SIZE = 1
|
| 11 |
|
| 12 |
def show(img, p=False, **kwargs):
|
|
@@ -35,7 +36,7 @@ def show(img, p=False, **kwargs):
|
|
| 35 |
|
| 36 |
|
| 37 |
|
| 38 |
-
def explain(model, input_image,size=600, n_classes=171) :
|
| 39 |
"""
|
| 40 |
Generate explanations for a given model and dataset.
|
| 41 |
:param model: The model to explain.
|
|
@@ -45,31 +46,55 @@ def explain(model, input_image,size=600, n_classes=171) :
|
|
| 45 |
:param batch_size: The batch size to use.
|
| 46 |
:return: The explanations.
|
| 47 |
"""
|
| 48 |
-
|
| 49 |
# we only need the classification part of the model
|
| 50 |
class_model = tf.keras.Model(model.input, model.output[1])
|
| 51 |
|
| 52 |
-
explainers = [
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
|
| 58 |
-
Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
|
| 59 |
-
preservation_probability=0.5),
|
| 60 |
-
HsicAttributionMethod(class_model,
|
| 61 |
grid_size=7, nb_design=1500,
|
| 62 |
-
sampler = LatinHypercube(binary=True))
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
cropped,repetitions = _clever_crop(input_image,(size,size))
|
| 68 |
-
size_repetitions = int(size//(repetitions.numpy()+1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
X = preprocess(cropped,size=size)
|
| 70 |
predictions = class_model.predict(np.array([X]))
|
| 71 |
#Y = np.argmax(predictions)
|
| 72 |
top_5_indices = np.argsort(predictions[0])[-5:][::-1]
|
|
|
|
|
|
|
|
|
|
| 73 |
#print(top_5_indices)
|
| 74 |
X = np.expand_dims(X, 0)
|
| 75 |
explanations = []
|
|
@@ -81,8 +106,10 @@ def explain(model, input_image,size=600, n_classes=171) :
|
|
| 81 |
phi = np.abs(explainer(X, Y))[0]
|
| 82 |
if len(phi.shape) == 3:
|
| 83 |
phi = np.mean(phi, -1)
|
| 84 |
-
show(X[0]
|
| 85 |
-
show(phi
|
|
|
|
|
|
|
| 86 |
plt.savefig(f'phi_{e}{i}.png')
|
| 87 |
explanations.append(f'phi_{e}{i}.png')
|
| 88 |
# avg=[]
|
|
@@ -101,4 +128,4 @@ def explain(model, input_image,size=600, n_classes=171) :
|
|
| 101 |
if len(explanations)==1:
|
| 102 |
explanations = explanations[0]
|
| 103 |
# return explanations,avg
|
| 104 |
-
return explanations
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
|
| 10 |
+
from labels import lookup_140
|
| 11 |
BATCH_SIZE = 1
|
| 12 |
|
| 13 |
def show(img, p=False, **kwargs):
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
|
| 39 |
+
def explain(model, input_image,explain_method,nb_samples,size=600, n_classes=171) :
|
| 40 |
"""
|
| 41 |
Generate explanations for a given model and dataset.
|
| 42 |
:param model: The model to explain.
|
|
|
|
| 46 |
:param batch_size: The batch size to use.
|
| 47 |
:return: The explanations.
|
| 48 |
"""
|
| 49 |
+
print('using explain_method:',explain_method)
|
| 50 |
# we only need the classification part of the model
|
| 51 |
class_model = tf.keras.Model(model.input, model.output[1])
|
| 52 |
|
| 53 |
+
explainers = []
|
| 54 |
+
if explain_method=="Sobol":
|
| 55 |
+
explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
|
| 56 |
+
if explain_method=="HSIC":
|
| 57 |
+
explainers.append(HsicAttributionMethod(class_model,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
grid_size=7, nb_design=1500,
|
| 59 |
+
sampler = LatinHypercube(binary=True)))
|
| 60 |
+
if explain_method=="Rise":
|
| 61 |
+
explainers.append(Rise(class_model,nb_samples = nb_samples, batch_size = BATCH_SIZE,grid_size=15,
|
| 62 |
+
preservation_probability=0.5))
|
| 63 |
+
if explain_method=="Saliency":
|
| 64 |
+
explainers.append(Saliency(class_model))
|
| 65 |
+
|
| 66 |
+
# explainers = [
|
| 67 |
+
# #Sobol, RISE, HSIC, Saliency
|
| 68 |
+
# #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
|
| 69 |
+
# #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
|
| 70 |
+
# #GradCAM(class_model),
|
| 71 |
+
# SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
|
| 72 |
+
# HsicAttributionMethod(class_model,
|
| 73 |
+
# grid_size=7, nb_design=1500,
|
| 74 |
+
# sampler = LatinHypercube(binary=True)),
|
| 75 |
+
# Saliency(class_model),
|
| 76 |
+
# Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
|
| 77 |
+
# preservation_probability=0.5),
|
| 78 |
+
# #
|
| 79 |
+
# ]
|
| 80 |
+
|
| 81 |
cropped,repetitions = _clever_crop(input_image,(size,size))
|
| 82 |
+
# size_repetitions = int(size//(repetitions.numpy()+1))
|
| 83 |
+
# print(size)
|
| 84 |
+
# print(type(input_image))
|
| 85 |
+
# print(input_image.shape)
|
| 86 |
+
# size_repetitions = int(size//(repetitions+1))
|
| 87 |
+
# print(type(repetitions))
|
| 88 |
+
# print(repetitions)
|
| 89 |
+
# print(size_repetitions)
|
| 90 |
+
# print(type(size_repetitions))
|
| 91 |
X = preprocess(cropped,size=size)
|
| 92 |
predictions = class_model.predict(np.array([X]))
|
| 93 |
#Y = np.argmax(predictions)
|
| 94 |
top_5_indices = np.argsort(predictions[0])[-5:][::-1]
|
| 95 |
+
classes = []
|
| 96 |
+
for index in top_5_indices:
|
| 97 |
+
classes.append(lookup_140[index])
|
| 98 |
#print(top_5_indices)
|
| 99 |
X = np.expand_dims(X, 0)
|
| 100 |
explanations = []
|
|
|
|
| 106 |
phi = np.abs(explainer(X, Y))[0]
|
| 107 |
if len(phi.shape) == 3:
|
| 108 |
phi = np.mean(phi, -1)
|
| 109 |
+
show(X[0])
|
| 110 |
+
show(phi, p=1, alpha=0.4)
|
| 111 |
+
# show(X[0][:,size_repetitions:2*size_repetitions,:])
|
| 112 |
+
# show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
|
| 113 |
plt.savefig(f'phi_{e}{i}.png')
|
| 114 |
explanations.append(f'phi_{e}{i}.png')
|
| 115 |
# avg=[]
|
|
|
|
| 128 |
if len(explanations)==1:
|
| 129 |
explanations = explanations[0]
|
| 130 |
# return explanations,avg
|
| 131 |
+
return classes,explanations
|
inference_resnet.py
CHANGED
|
@@ -7,7 +7,7 @@ else:
|
|
| 7 |
|
| 8 |
from keras.applications import resnet
|
| 9 |
import tensorflow.keras.layers as L
|
| 10 |
-
import os
|
| 11 |
|
| 12 |
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
|
| 13 |
import matplotlib.pyplot as plt
|
|
|
|
| 7 |
|
| 8 |
from keras.applications import resnet
|
| 9 |
import tensorflow.keras.layers as L
|
| 10 |
+
import os
|
| 11 |
|
| 12 |
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
|
| 13 |
import matplotlib.pyplot as plt
|