adnlp commited on
Commit
38518f4
·
verified ·
1 Parent(s): e9e368b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -149,7 +149,7 @@ def color_token(token, weight):
149
  return f'<span style="background-color: rgba({color[0]}, {color[1]}, {color[2]}, 0.6); padding:2px; border-radius:3px;">{token}</span>'
150
 
151
  def to_html(attention, clean_text_tokens):
152
- ca = np.array(attention).mean(axis=0)
153
  weights = ca / ca.max()
154
  html_text = " ".join(
155
  color_token(tok, w) for tok, w in zip(clean_text_tokens, weights)
@@ -305,20 +305,24 @@ def predict(dataset, text, example_index, file, vision_encoder, text_encoder, ts
305
 
306
  # Text Heatmap
307
  text_tokens = res_json['text_tokens']
308
- text_attentions = res_json['text_attentions']
309
-
310
  if text_encoder == "Qwen":
311
  clean_text_tokens = [t.replace("Ġ", "") for t in text_tokens]
 
 
 
312
  elif text_encoder == "LLaMA":
313
  clean_text_tokens = [t.replace("▁", "") for t in text_tokens]
 
314
  else:
315
  pass
 
316
 
317
  vision_heatmap_html_text = '<div class="gallery-container"><div class="grid-wrap svelte-1atirkn"><div class="grid-container svelte-1atirkn pt-6">'
318
- for i in range(0, 12, 3):
319
  vision_heatmap_html_text += f'<button class="thumbnail-item thumbnail-lg svelte-1atirkn"><div class="svelte-1pijsyv">'
320
- vision_heatmap_html_text += to_html(text_attentions[i:i+3], clean_text_tokens)
321
- vision_heatmap_html_text += f'</div><div class="caption-label svelte-1atirkn">Heatmap from Layer{i}:{i+3}</div></button>'
322
  vision_heatmap_html_text += '</div></div></div>'
323
 
324
  # Time Series Heatmap
 
149
  return f'<span style="background-color: rgba({color[0]}, {color[1]}, {color[2]}, 0.6); padding:2px; border-radius:3px;">{token}</span>'
150
 
151
  def to_html(attention, clean_text_tokens):
152
+ ca = attention.sum(axis=0)
153
  weights = ca / ca.max()
154
  html_text = " ".join(
155
  color_token(tok, w) for tok, w in zip(clean_text_tokens, weights)
 
305
 
306
  # Text Heatmap
307
  text_tokens = res_json['text_tokens']
308
+ text_attentions = np.array(res_json['text_attentions'])
 
309
  if text_encoder == "Qwen":
310
  clean_text_tokens = [t.replace("Ġ", "") for t in text_tokens]
311
+ text_attentions = text_attentions.mean(axis=1) # 28 * seq * seq
312
+ text_attentions = text_attentions * np.arange(1, 1+text_attentions.shape[1]).reshape((text_attentions.shape[1], 1)) # 28 * seq * seq
313
+ text_attentions = text_attentions.mean(axis=1)[:, -len(text_tokens):]
314
  elif text_encoder == "LLaMA":
315
  clean_text_tokens = [t.replace("▁", "") for t in text_tokens]
316
+ text_attentions = text_attentions.mean(axis=(1,2))[:, -len(text_tokens):]
317
  else:
318
  pass
319
+ text_attentions_stride = text_attentions.shape[0]//4
320
 
321
  vision_heatmap_html_text = '<div class="gallery-container"><div class="grid-wrap svelte-1atirkn"><div class="grid-container svelte-1atirkn pt-6">'
322
+ for i in range(4):
323
  vision_heatmap_html_text += f'<button class="thumbnail-item thumbnail-lg svelte-1atirkn"><div class="svelte-1pijsyv">'
324
+ vision_heatmap_html_text += to_html(text_attentions[text_attentions_stride*i:text_attentions_stride*(i+1)], clean_text_tokens)
325
+ vision_heatmap_html_text += f'</div><div class="caption-label svelte-1atirkn">Heatmap from Layer{text_attentions_stride*i}:{text_attentions_stride*(i+1)}</div></button>'
326
  vision_heatmap_html_text += '</div></div></div>'
327
 
328
  # Time Series Heatmap