Jellyfish042 committed on
Commit
452ae9b
·
1 Parent(s): 04ab453
Files changed (1) hide show
  1. visualization/html_generator.py +75 -35
visualization/html_generator.py CHANGED
@@ -36,6 +36,7 @@ def get_rwkv_tokenizer():
36
  if _rwkv_tokenizer is None:
37
  from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
38
  import os
 
39
  script_dir = os.path.dirname(os.path.abspath(__file__))
40
  vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt")
41
  _rwkv_tokenizer = TRIE_TOKENIZER(vocab_path)
@@ -297,24 +298,28 @@ def generate_comparison_html(
297
  rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
298
 
299
  if re.search(r"\w", token_text, re.UNICODE):
300
- tokens.append({
301
- "type": "word",
302
- "text": token_text,
303
- "byte_start": start_byte,
304
- "byte_end": end_byte,
305
- "word_lower": token_text.lower(),
306
- "qwen_tokens": qwen_toks,
307
- "rwkv_tokens": rwkv_toks,
308
- })
 
 
309
  else:
310
- tokens.append({
311
- "type": "non-word",
312
- "text": token_text,
313
- "byte_start": start_byte,
314
- "byte_end": end_byte,
315
- "qwen_tokens": qwen_toks,
316
- "rwkv_tokens": rwkv_toks,
317
- })
 
 
318
 
319
  # Track word occurrences
320
  word_occurrences = {}
@@ -335,14 +340,16 @@ def generate_comparison_html(
335
  def escape_for_attr(s):
336
  # Escape all characters that could break HTML attributes
337
  # Order matters: & must be escaped first
338
- return (s.replace("&", "&amp;")
339
- .replace('"', "&quot;")
340
- .replace("'", "&#39;")
341
- .replace("<", "&lt;")
342
- .replace(">", "&gt;")
343
- .replace("\n", "&#10;")
344
- .replace("\r", "&#13;")
345
- .replace("\t", "&#9;"))
 
 
346
 
347
  for token in tokens:
348
  token_text = token["text"]
@@ -382,7 +389,8 @@ def generate_comparison_html(
382
  ]
383
  # Use base64 encoding to avoid escaping issues
384
  import base64
385
- topk_a_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode('utf-8')).decode('ascii')
 
386
  except Exception as e:
387
  pass
388
  if topk_predictions_b is not None and model_b_token_ranges:
@@ -393,7 +401,8 @@ def generate_comparison_html(
393
  decoded_pred = [pred[0], pred[1], [[tid, prob, decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]]
394
  # Use base64 encoding to avoid escaping issues
395
  import base64
396
- topk_b_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode('utf-8')).decode('ascii')
 
397
  except Exception as e:
398
  pass
399
 
@@ -402,7 +411,13 @@ def generate_comparison_html(
402
  token_deltas = deltas[byte_start:byte_end]
403
  avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
404
 
405
- color = delta_to_color(avg_token_delta, avg_delta, max_deviation)
 
 
 
 
 
 
406
  r, g, b = color
407
 
408
  token_html_parts = []
@@ -857,6 +872,7 @@ def generate_comparison_html(
857
  const avgDelta = {avg_delta_compression};
858
  const slider = document.getElementById('saturation-slider');
859
  const saturationValue = document.getElementById('saturation-value');
 
860
 
861
  const allDeltas = [];
862
  tokenSpans.forEach(token => {{
@@ -864,6 +880,12 @@ def generate_comparison_html(
864
  if (!isNaN(delta)) allDeltas.push(delta);
865
  }});
866
 
 
 
 
 
 
 
867
  function percentile(arr, p) {{
868
  const sorted = [...arr].sort((a, b) => a - b);
869
  const idx = (p / 100) * (sorted.length - 1);
@@ -873,9 +895,9 @@ def generate_comparison_html(
873
  return sorted[lower] + (sorted[upper] - sorted[lower]) * (idx - lower);
874
  }}
875
 
876
- function deltaToColor(delta, avgDelta, maxDeviation) {{
877
  if (maxDeviation === 0) return 'rgb(255, 255, 255)';
878
- const deviation = delta - avgDelta;
879
  let normalized = Math.max(-1, Math.min(1, deviation / maxDeviation));
880
  let r, g, b;
881
  if (normalized < 0) {{
@@ -892,13 +914,31 @@ def generate_comparison_html(
892
  return `rgb(${{r}}, ${{g}}, ${{b}})`;
893
  }}
894
 
895
- function updateColors(percentileValue) {{
896
- const deviations = allDeltas.map(d => Math.abs(d - avgDelta));
897
- const maxDeviation = Math.max(percentile(deviations, percentileValue), 1e-6);
898
- tokenSpans.forEach(token => {{
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
899
  const delta = parseFloat(token.getAttribute('data-delta'));
900
  if (!isNaN(delta)) {{
901
- token.style.backgroundColor = deltaToColor(delta, avgDelta, maxDeviation);
 
902
  }}
903
  }});
904
  }}
 
36
  if _rwkv_tokenizer is None:
37
  from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
38
  import os
39
+
40
  script_dir = os.path.dirname(os.path.abspath(__file__))
41
  vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt")
42
  _rwkv_tokenizer = TRIE_TOKENIZER(vocab_path)
 
298
  rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
299
 
300
  if re.search(r"\w", token_text, re.UNICODE):
301
+ tokens.append(
302
+ {
303
+ "type": "word",
304
+ "text": token_text,
305
+ "byte_start": start_byte,
306
+ "byte_end": end_byte,
307
+ "word_lower": token_text.lower(),
308
+ "qwen_tokens": qwen_toks,
309
+ "rwkv_tokens": rwkv_toks,
310
+ }
311
+ )
312
  else:
313
+ tokens.append(
314
+ {
315
+ "type": "non-word",
316
+ "text": token_text,
317
+ "byte_start": start_byte,
318
+ "byte_end": end_byte,
319
+ "qwen_tokens": qwen_toks,
320
+ "rwkv_tokens": rwkv_toks,
321
+ }
322
+ )
323
 
324
  # Track word occurrences
325
  word_occurrences = {}
 
340
  def escape_for_attr(s):
341
  # Escape all characters that could break HTML attributes
342
  # Order matters: & must be escaped first
343
+ return (
344
+ s.replace("&", "&amp;")
345
+ .replace('"', "&quot;")
346
+ .replace("'", "&#39;")
347
+ .replace("<", "&lt;")
348
+ .replace(">", "&gt;")
349
+ .replace("\n", "&#10;")
350
+ .replace("\r", "&#13;")
351
+ .replace("\t", "&#9;")
352
+ )
353
 
354
  for token in tokens:
355
  token_text = token["text"]
 
389
  ]
390
  # Use base64 encoding to avoid escaping issues
391
  import base64
392
+
393
+ topk_a_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
394
  except Exception as e:
395
  pass
396
  if topk_predictions_b is not None and model_b_token_ranges:
 
401
  decoded_pred = [pred[0], pred[1], [[tid, prob, decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]]
402
  # Use base64 encoding to avoid escaping issues
403
  import base64
404
+
405
+ topk_b_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
406
  except Exception as e:
407
  pass
408
 
 
411
  token_deltas = deltas[byte_start:byte_end]
412
  avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
413
 
414
+ # Apply power transformation to enhance color differentiation
415
+ # Preserve the sign and apply power to the absolute value
416
+ power_n = 3
417
+ sign = 1 if avg_token_delta >= 0 else -1
418
+ avg_token_delta_powered = sign * (abs(avg_token_delta) ** power_n)
419
+
420
+ color = delta_to_color(avg_token_delta_powered, avg_delta, max_deviation)
421
  r, g, b = color
422
 
423
  token_html_parts = []
 
872
  const avgDelta = {avg_delta_compression};
873
  const slider = document.getElementById('saturation-slider');
874
  const saturationValue = document.getElementById('saturation-value');
875
+ const powerN = 3; // Must match Python's power_n
876
 
877
  const allDeltas = [];
878
  tokenSpans.forEach(token => {{
 
880
  if (!isNaN(delta)) allDeltas.push(delta);
881
  }});
882
 
883
+ // Apply power transformation to delta values (matching Python's logic)
884
+ function applyPower(delta) {{
885
+ const sign = delta >= 0 ? 1 : -1;
886
+ return sign * Math.pow(Math.abs(delta), powerN);
887
+ }}
888
+
889
  function percentile(arr, p) {{
890
  const sorted = [...arr].sort((a, b) => a - b);
891
  const idx = (p / 100) * (sorted.length - 1);
 
895
  return sorted[lower] + (sorted[upper] - sorted[lower]) * (idx - lower);
896
  }}
897
 
898
+ function deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation) {{
899
  if (maxDeviation === 0) return 'rgb(255, 255, 255)';
900
+ const deviation = deltaPowered - avgDeltaPowered;
901
  let normalized = Math.max(-1, Math.min(1, deviation / maxDeviation));
902
  let r, g, b;
903
  if (normalized < 0) {{
 
914
  return `rgb(${{r}}, ${{g}}, ${{b}})`;
915
  }}
916
 
917
+ // Pre-compute powered deltas and avgDeltaPowered
918
+ const allDeltasPowered = allDeltas.map(d => applyPower(d));
919
+ const avgDeltaPowered = applyPower(avgDelta);
920
+
921
+ // Pre-compute min and max deviations for logarithmic interpolation
922
+ const allDeviations = allDeltasPowered.map(d => Math.abs(d - avgDeltaPowered));
923
+ const minDeviation = Math.max(percentile(allDeviations, 1), 1e-9); // Use 1st percentile to avoid extreme outliers
924
+ const maxDeviationFull = Math.max(percentile(allDeviations, 100), 1e-6);
925
+
926
+ function updateColors(sliderValue) {{
927
+ // Use logarithmic interpolation for smoother perceptual control
928
+ // sliderValue: 50-100, maps to maxDeviation from minDeviation to maxDeviationFull
929
+ // Lower slider value = lower maxDeviation = more saturation (more colors hit the clamp)
930
+ // Higher slider value = higher maxDeviation = less saturation (fewer colors hit the clamp)
931
+ const t = (sliderValue - 50) / 50; // Normalize to 0-1
932
+ // Logarithmic interpolation: exp(lerp(log(min), log(max), t))
933
+ const logMin = Math.log(minDeviation);
934
+ const logMax = Math.log(maxDeviationFull);
935
+ const maxDeviation = Math.exp(logMin + t * (logMax - logMin));
936
+
937
+ tokenSpans.forEach((token, idx) => {{
938
  const delta = parseFloat(token.getAttribute('data-delta'));
939
  if (!isNaN(delta)) {{
940
+ const deltaPowered = applyPower(delta);
941
+ token.style.backgroundColor = deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation);
942
  }}
943
  }});
944
  }}