Update app.py
Browse files
app.py
CHANGED
|
@@ -149,7 +149,7 @@ def color_token(token, weight):
|
|
| 149 |
return f'<span style="background-color: rgba({color[0]}, {color[1]}, {color[2]}, 0.6); padding:2px; border-radius:3px;">{token}</span>'
|
| 150 |
|
| 151 |
def to_html(attention, clean_text_tokens):
|
| 152 |
-
ca =
|
| 153 |
weights = ca / ca.max()
|
| 154 |
html_text = " ".join(
|
| 155 |
color_token(tok, w) for tok, w in zip(clean_text_tokens, weights)
|
|
@@ -305,20 +305,24 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
|
|
| 305 |
|
| 306 |
# Text Heatmap
|
| 307 |
text_tokens = res_json['text_tokens']
|
| 308 |
-
text_attentions = res_json['text_attentions']
|
| 309 |
-
|
| 310 |
if text_encoder == "Qwen":
|
| 311 |
clean_text_tokens = [t.replace("Ġ", "") for t in text_tokens]
|
|
|
|
|
|
|
|
|
|
| 312 |
elif text_encoder == "LLaMA":
|
| 313 |
clean_text_tokens = [t.replace("▁", "") for t in text_tokens]
|
|
|
|
| 314 |
else:
|
| 315 |
pass
|
|
|
|
| 316 |
|
| 317 |
vision_heatmap_html_text = '<div class="gallery-container"><div class="grid-wrap svelte-1atirkn"><div class="grid-container svelte-1atirkn pt-6">'
|
| 318 |
-
for i in range(
|
| 319 |
vision_heatmap_html_text += f'<button class="thumbnail-item thumbnail-lg svelte-1atirkn"><div class="svelte-1pijsyv">'
|
| 320 |
-
vision_heatmap_html_text += to_html(text_attentions[i:i+
|
| 321 |
-
vision_heatmap_html_text += f'</div><div class="caption-label svelte-1atirkn">Heatmap from Layer{i}:{i+
|
| 322 |
vision_heatmap_html_text += '</div></div></div>'
|
| 323 |
|
| 324 |
# Time Series Heatmap
|
|
|
|
| 149 |
return f'<span style="background-color: rgba({color[0]}, {color[1]}, {color[2]}, 0.6); padding:2px; border-radius:3px;">{token}</span>'
|
| 150 |
|
| 151 |
def to_html(attention, clean_text_tokens):
|
| 152 |
+
ca = attention.sum(axis=0)
|
| 153 |
weights = ca / ca.max()
|
| 154 |
html_text = " ".join(
|
| 155 |
color_token(tok, w) for tok, w in zip(clean_text_tokens, weights)
|
|
|
|
| 305 |
|
| 306 |
# Text Heatmap
|
| 307 |
text_tokens = res_json['text_tokens']
|
| 308 |
+
text_attentions = np.array(res_json['text_attentions'])
|
|
|
|
| 309 |
if text_encoder == "Qwen":
|
| 310 |
clean_text_tokens = [t.replace("Ġ", "") for t in text_tokens]
|
| 311 |
+
text_attentions = text_attentions.mean(axis=1) # 28 * seq * seq
|
| 312 |
+
text_attentions = text_attentions * np.arange(1, 1+text_attentions.shape[1]).reshape((text_attentions.shape[1], 1)) # 28 * seq * seq
|
| 313 |
+
text_attentions = text_attentions.mean(axis=1)[:, -len(text_tokens):]
|
| 314 |
elif text_encoder == "LLaMA":
|
| 315 |
clean_text_tokens = [t.replace("▁", "") for t in text_tokens]
|
| 316 |
+
text_attentions = text_attentions.mean(axis=(1,2))[:, -len(text_tokens):]
|
| 317 |
else:
|
| 318 |
pass
|
| 319 |
+
text_attentions_stride = text_attentions.shape[0]//4
|
| 320 |
|
| 321 |
vision_heatmap_html_text = '<div class="gallery-container"><div class="grid-wrap svelte-1atirkn"><div class="grid-container svelte-1atirkn pt-6">'
|
| 322 |
+
for i in range(4):
|
| 323 |
vision_heatmap_html_text += f'<button class="thumbnail-item thumbnail-lg svelte-1atirkn"><div class="svelte-1pijsyv">'
|
| 324 |
+
vision_heatmap_html_text += to_html(text_attentions[text_attentions_stride*i:text_attentions_stride*(i+1)], clean_text_tokens)
|
| 325 |
+
vision_heatmap_html_text += f'</div><div class="caption-label svelte-1atirkn">Heatmap from Layer{text_attentions_stride*i}:{text_attentions_stride*(i+1)}</div></button>'
|
| 326 |
vision_heatmap_html_text += '</div></div></div>'
|
| 327 |
|
| 328 |
# Time Series Heatmap
|