Spaces:
Runtime error
Runtime error
Martijn van Beers committed on
Commit ·
8f3d1af
1
Parent(s): cf1865f
Remove code for jupyter notebooks
Browse files
There was some partially commented-out code to create matplotlib
figures. Remove it altogether.
- CLIP_explainability/utils.py +7 -25
- app.py +2 -2
CLIP_explainability/utils.py
CHANGED
|
@@ -69,7 +69,7 @@ def interpret(image, texts, model, device):
|
|
| 69 |
return text_relevance, image_relevance
|
| 70 |
|
| 71 |
|
| 72 |
-
def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
| 73 |
# create heatmap from mask on image
|
| 74 |
def show_cam_on_image(img, mask):
|
| 75 |
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
|
@@ -78,15 +78,6 @@ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
|
| 78 |
cam = cam / np.max(cam)
|
| 79 |
return cam
|
| 80 |
|
| 81 |
-
# plt.axis('off')
|
| 82 |
-
# f, axarr = plt.subplots(1,2)
|
| 83 |
-
# axarr[0].imshow(orig_image)
|
| 84 |
-
|
| 85 |
-
if show:
|
| 86 |
-
fig, axs = plt.subplots(1, 2)
|
| 87 |
-
axs[0].imshow(orig_image);
|
| 88 |
-
axs[0].axis('off');
|
| 89 |
-
|
| 90 |
image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
| 91 |
image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
|
| 92 |
image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
|
|
@@ -97,16 +88,10 @@ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
|
| 97 |
vis = np.uint8(255 * vis)
|
| 98 |
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
| 99 |
|
| 100 |
-
if show:
|
| 101 |
-
# axar[1].imshow(vis)
|
| 102 |
-
axs[1].imshow(vis);
|
| 103 |
-
axs[1].axis('off');
|
| 104 |
-
# plt.imshow(vis)
|
| 105 |
-
|
| 106 |
return image_relevance
|
| 107 |
|
| 108 |
|
| 109 |
-
def show_heatmap_on_text(text, text_encoding, R_text, show=True):
|
| 110 |
CLS_idx = text_encoding.argmax(dim=-1)
|
| 111 |
R_text = R_text[CLS_idx, 1:CLS_idx]
|
| 112 |
text_scores = R_text / R_text.sum()
|
|
@@ -115,19 +100,16 @@ def show_heatmap_on_text(text, text_encoding, R_text, show=True):
|
|
| 115 |
text_tokens=_tokenizer.encode(text)
|
| 116 |
text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
|
| 117 |
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
|
| 118 |
-
|
| 119 |
-
if show:
|
| 120 |
-
visualization.visualize_text(vis_data_records)
|
| 121 |
|
| 122 |
return text_scores, text_tokens_decoded
|
| 123 |
|
| 124 |
|
| 125 |
-
def show_img_heatmap(image_relevance, image, orig_image, device
|
| 126 |
-
return show_image_relevance(image_relevance, image, orig_image, device
|
| 127 |
|
| 128 |
|
| 129 |
-
def show_txt_heatmap(text, text_encoding, R_text
|
| 130 |
-
return show_heatmap_on_text(text, text_encoding, R_text
|
| 131 |
|
| 132 |
|
| 133 |
def load_dataset():
|
|
@@ -149,4 +131,4 @@ class color:
|
|
| 149 |
RED = '\033[91m'
|
| 150 |
BOLD = '\033[1m'
|
| 151 |
UNDERLINE = '\033[4m'
|
| 152 |
-
END = '\033[0m'
|
|
|
|
| 69 |
return text_relevance, image_relevance
|
| 70 |
|
| 71 |
|
| 72 |
+
def show_image_relevance(image_relevance, image, orig_image, device):
|
| 73 |
# create heatmap from mask on image
|
| 74 |
def show_cam_on_image(img, mask):
|
| 75 |
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
|
|
|
| 78 |
cam = cam / np.max(cam)
|
| 79 |
return cam
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
| 82 |
image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
|
| 83 |
image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
|
|
|
|
| 88 |
vis = np.uint8(255 * vis)
|
| 89 |
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
return image_relevance
|
| 92 |
|
| 93 |
|
| 94 |
+
def show_heatmap_on_text(text, text_encoding, R_text):
|
| 95 |
CLS_idx = text_encoding.argmax(dim=-1)
|
| 96 |
R_text = R_text[CLS_idx, 1:CLS_idx]
|
| 97 |
text_scores = R_text / R_text.sum()
|
|
|
|
| 100 |
text_tokens=_tokenizer.encode(text)
|
| 101 |
text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
|
| 102 |
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
return text_scores, text_tokens_decoded
|
| 105 |
|
| 106 |
|
| 107 |
+
def show_img_heatmap(image_relevance, image, orig_image, device):
|
| 108 |
+
return show_image_relevance(image_relevance, image, orig_image, device)
|
| 109 |
|
| 110 |
|
| 111 |
+
def show_txt_heatmap(text, text_encoding, R_text):
|
| 112 |
+
return show_heatmap_on_text(text, text_encoding, R_text)
|
| 113 |
|
| 114 |
|
| 115 |
def load_dataset():
|
|
|
|
| 131 |
RED = '\033[91m'
|
| 132 |
BOLD = '\033[1m'
|
| 133 |
UNDERLINE = '\033[4m'
|
| 134 |
+
END = '\033[0m'
|
app.py
CHANGED
|
@@ -59,10 +59,10 @@ def run_demo(image, text):
|
|
| 59 |
|
| 60 |
R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
|
| 61 |
|
| 62 |
-
image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device
|
| 63 |
overlapped = overlay_relevance_map_on_image(image, image_relevance)
|
| 64 |
|
| 65 |
-
text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0]
|
| 66 |
|
| 67 |
highlighted_text = []
|
| 68 |
for i, token in enumerate(text_tokens_decoded):
|
|
|
|
| 59 |
|
| 60 |
R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
|
| 61 |
|
| 62 |
+
image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device)
|
| 63 |
overlapped = overlay_relevance_map_on_image(image, image_relevance)
|
| 64 |
|
| 65 |
+
text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0])
|
| 66 |
|
| 67 |
highlighted_text = []
|
| 68 |
for i, token in enumerate(text_tokens_decoded):
|