Jellyfish042 commited on
Commit
fa470f3
·
1 Parent(s): 452ae9b
Files changed (1) hide show
  1. visualization/html_generator.py +66 -98
visualization/html_generator.py CHANGED
@@ -127,28 +127,6 @@ def get_token_info_for_text(text: str) -> dict:
127
  }
128
 
129
 
130
- def delta_to_color(delta: float, avg_delta: float, max_deviation: float) -> Tuple[int, int, int]:
131
- """Map a delta value to an RGB color based on deviation from average."""
132
- if max_deviation == 0:
133
- return (255, 255, 255)
134
-
135
- deviation = delta - avg_delta
136
- normalized = max(-1, min(1, deviation / max_deviation))
137
-
138
- if normalized < 0:
139
- intensity = -normalized
140
- r = int(255 * (1 - intensity * 0.7))
141
- g = 255
142
- b = int(255 * (1 - intensity * 0.7))
143
- else:
144
- intensity = normalized
145
- r = 255
146
- g = int(255 * (1 - intensity * 0.7))
147
- b = int(255 * (1 - intensity * 0.7))
148
-
149
- return (r, g, b)
150
-
151
-
152
  def generate_comparison_html(
153
  text: str,
154
  byte_losses_a: List[float],
@@ -253,12 +231,6 @@ def generate_comparison_html(
253
  deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
254
  avg_delta = sum(deltas) / len(deltas) if deltas else 0
255
 
256
- # Calculate max deviation
257
- deviations = [d - avg_delta for d in deltas]
258
- abs_deviations = [abs(dev) for dev in deviations]
259
- max_deviation = float(np.percentile(abs_deviations, 100)) if abs_deviations else 0
260
- max_deviation = max(max_deviation, 1e-6)
261
-
262
  # Calculate average compression rates
263
  avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
264
  avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
@@ -375,6 +347,10 @@ def generate_comparison_html(
375
  compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
376
  compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])
377
 
 
 
 
 
378
  topk_a_json = ""
379
  topk_b_json = ""
380
  if topk_predictions_a is not None and model_a_token_ranges:
@@ -410,15 +386,10 @@ def generate_comparison_html(
410
 
411
  token_deltas = deltas[byte_start:byte_end]
412
  avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
 
413
 
414
- # Apply power transformation to enhance color differentiation
415
- # Preserve the sign and apply power to the absolute value
416
- power_n = 3
417
- sign = 1 if avg_token_delta >= 0 else -1
418
- avg_token_delta_powered = sign * (abs(avg_token_delta) ** power_n)
419
-
420
- color = delta_to_color(avg_token_delta_powered, avg_delta, max_deviation)
421
- r, g, b = color
422
 
423
  token_html_parts = []
424
  for char in token_text:
@@ -445,7 +416,9 @@ def generate_comparison_html(
445
  f'data-bytes="{escape_for_attr(bytes_str)}" '
446
  f'data-compression-a="{escape_for_attr(compression_a_str)}" '
447
  f'data-compression-b="{escape_for_attr(compression_b_str)}" '
448
- f'data-delta="{avg_token_delta * COMPRESSION_RATE_FACTOR:.4f}" '
 
 
449
  f'data-topk-a="{escape_for_attr(topk_a_json)}" '
450
  f'data-topk-b="{escape_for_attr(topk_b_json)}"'
451
  )
@@ -689,9 +662,9 @@ def generate_comparison_html(
689
  <span>RWKV worse than avg</span>
690
  </div>
691
  <div class="legend-item" style="margin-left: 20px;">
692
- <span style="color: #aaa;">Saturation:</span>
693
- <input type="range" id="saturation-slider" min="500" max="1000" value="1000" step="1" style="width: 200px; vertical-align: middle;">
694
- <span id="saturation-value" style="color: #fff; min-width: 45px; display: inline-block;">100.0%</span>
695
  </div>
696
  </div>
697
  </div>
@@ -782,6 +755,8 @@ def generate_comparison_html(
782
  const bytes = token.getAttribute('data-bytes') || '';
783
  const compressionA = token.getAttribute('data-compression-a') || '';
784
  const compressionB = token.getAttribute('data-compression-b') || '';
 
 
785
  const top5A = token.getAttribute('data-topk-a') || '';
786
  const top5B = token.getAttribute('data-topk-b') || '';
787
 
@@ -827,8 +802,8 @@ def generate_comparison_html(
827
 
828
  let tooltipHtml = `
829
  <div><span class="label">Bytes:</span> <span class="bytes">${{bytes || '(empty)'}}</span></div>
830
- <div><span class="label">RWKV Compression Rate:</span> <span class="loss-a">${{compressionA || '(empty)'}}</span></div>
831
- <div><span class="label">Qwen Compression Rate:</span> <span class="loss-b">${{compressionB || '(empty)'}}</span></div>
832
  <hr style="border-color: #555; margin: 6px 0;">
833
  <div><span class="label">RWKV:</span> <span class="model-a">${{modelA || '(empty)'}}</span></div>
834
  <div><span class="label">Qwen:</span> <span class="model-b">${{modelB || '(empty)'}}</span></div>
@@ -869,85 +844,78 @@ def generate_comparison_html(
869
  }});
870
  }});
871
 
872
- const avgDelta = {avg_delta_compression};
873
- const slider = document.getElementById('saturation-slider');
874
- const saturationValue = document.getElementById('saturation-value');
875
- const powerN = 3; // Must match Python's power_n
876
 
877
- const allDeltas = [];
878
- tokenSpans.forEach(token => {{
879
- const delta = parseFloat(token.getAttribute('data-delta'));
880
- if (!isNaN(delta)) allDeltas.push(delta);
 
 
 
881
  }});
882
 
883
- // Apply power transformation to delta values (matching Python's logic)
884
- function applyPower(delta) {{
885
- const sign = delta >= 0 ? 1 : -1;
886
- return sign * Math.pow(Math.abs(delta), powerN);
887
- }}
888
 
889
- function percentile(arr, p) {{
890
- const sorted = [...arr].sort((a, b) => a - b);
891
- const idx = (p / 100) * (sorted.length - 1);
892
- const lower = Math.floor(idx);
893
- const upper = Math.ceil(idx);
894
- if (lower === upper) return sorted[lower];
895
- return sorted[lower] + (sorted[upper] - sorted[lower]) * (idx - lower);
896
- }}
897
 
898
- function deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation) {{
899
- if (maxDeviation === 0) return 'rgb(255, 255, 255)';
900
- const deviation = deltaPowered - avgDeltaPowered;
901
- let normalized = Math.max(-1, Math.min(1, deviation / maxDeviation));
902
  let r, g, b;
903
  if (normalized < 0) {{
 
904
  const intensity = -normalized;
905
- r = Math.round(255 * (1 - intensity * 0.7));
906
  g = 255;
907
- b = Math.round(255 * (1 - intensity * 0.7));
908
  }} else {{
 
909
  const intensity = normalized;
910
  r = 255;
911
- g = Math.round(255 * (1 - intensity * 0.7));
912
- b = Math.round(255 * (1 - intensity * 0.7));
913
  }}
914
  return `rgb(${{r}}, ${{g}}, ${{b}})`;
915
  }}
916
 
917
- // Pre-compute powered deltas and avgDeltaPowered
918
- const allDeltasPowered = allDeltas.map(d => applyPower(d));
919
- const avgDeltaPowered = applyPower(avgDelta);
920
-
921
- // Pre-compute min and max deviations for logarithmic interpolation
922
- const allDeviations = allDeltasPowered.map(d => Math.abs(d - avgDeltaPowered));
923
- const minDeviation = Math.max(percentile(allDeviations, 1), 1e-9); // Use 1st percentile to avoid extreme outliers
924
- const maxDeviationFull = Math.max(percentile(allDeviations, 100), 1e-6);
925
-
926
- function updateColors(sliderValue) {{
927
- // Use logarithmic interpolation for smoother perceptual control
928
- // sliderValue: 50-100, maps to maxDeviation from minDeviation to maxDeviationFull
929
- // Lower slider value = lower maxDeviation = more saturation (more colors hit the clamp)
930
- // Higher slider value = higher maxDeviation = less saturation (fewer colors hit the clamp)
931
- const t = (sliderValue - 50) / 50; // Normalize to 0-1
932
- // Logarithmic interpolation: exp(lerp(log(min), log(max), t))
933
- const logMin = Math.log(minDeviation);
934
- const logMax = Math.log(maxDeviationFull);
935
- const maxDeviation = Math.exp(logMin + t * (logMax - logMin));
936
-
937
- tokenSpans.forEach((token, idx) => {{
938
- const delta = parseFloat(token.getAttribute('data-delta'));
939
- if (!isNaN(delta)) {{
940
- const deltaPowered = applyPower(delta);
941
- token.style.backgroundColor = deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation);
942
  }}
943
  }});
944
  }}
945
 
946
  slider.addEventListener('input', (e) => {{
947
- const val = parseInt(e.target.value) / 10;
948
- saturationValue.textContent = val.toFixed(1) + '%';
949
  updateColors(val);
950
  }});
 
 
 
951
  </script>
952
  </body>
953
  </html>
 
127
  }
128
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def generate_comparison_html(
131
  text: str,
132
  byte_losses_a: List[float],
 
231
  deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
232
  avg_delta = sum(deltas) / len(deltas) if deltas else 0
233
 
 
 
 
 
 
 
234
  # Calculate average compression rates
235
  avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
236
  avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
 
347
  compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
348
  compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])
349
 
350
+ # Calculate average compression rate for this token
351
+ avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0
352
+ avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0
353
+
354
  topk_a_json = ""
355
  topk_b_json = ""
356
  if topk_predictions_a is not None and model_a_token_ranges:
 
386
 
387
  token_deltas = deltas[byte_start:byte_end]
388
  avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
389
+ tuned_delta = avg_token_delta - avg_delta
390
 
391
+ # Initial rendering uses white color, JavaScript will apply colors based on slider
392
+ r, g, b = 255, 255, 255
 
 
 
 
 
 
393
 
394
  token_html_parts = []
395
  for char in token_text:
 
416
  f'data-bytes="{escape_for_attr(bytes_str)}" '
417
  f'data-compression-a="{escape_for_attr(compression_a_str)}" '
418
  f'data-compression-b="{escape_for_attr(compression_b_str)}" '
419
+ f'data-avg-compression-a="{avg_compression_a_token:.2f}" '
420
+ f'data-avg-compression-b="{avg_compression_b_token:.2f}" '
421
+ f'data-tuned-delta="{tuned_delta:.6f}" '
422
  f'data-topk-a="{escape_for_attr(topk_a_json)}" '
423
  f'data-topk-b="{escape_for_attr(topk_b_json)}"'
424
  )
 
662
  <span>RWKV worse than avg</span>
663
  </div>
664
  <div class="legend-item" style="margin-left: 20px;">
665
+ <span style="color: #aaa;">Color Range:</span>
666
+ <input type="range" id="color-range-slider" min="0" max="100" value="50" step="1" style="width: 200px; vertical-align: middle;">
667
+ <span id="color-range-value" style="color: #fff; min-width: 45px; display: inline-block;">50%</span>
668
  </div>
669
  </div>
670
  </div>
 
755
  const bytes = token.getAttribute('data-bytes') || '';
756
  const compressionA = token.getAttribute('data-compression-a') || '';
757
  const compressionB = token.getAttribute('data-compression-b') || '';
758
+ const avgCompressionA = token.getAttribute('data-avg-compression-a') || '';
759
+ const avgCompressionB = token.getAttribute('data-avg-compression-b') || '';
760
  const top5A = token.getAttribute('data-topk-a') || '';
761
  const top5B = token.getAttribute('data-topk-b') || '';
762
 
 
802
 
803
  let tooltipHtml = `
804
  <div><span class="label">Bytes:</span> <span class="bytes">${{bytes || '(empty)'}}</span></div>
805
+ <div><span class="label">RWKV Compression Rate:</span> <span class="loss-a">${{compressionA || '(empty)'}}${{avgCompressionA ? ' (avg: ' + avgCompressionA + '%)' : ''}}</span></div>
806
+ <div><span class="label">Qwen Compression Rate:</span> <span class="loss-b">${{compressionB || '(empty)'}}${{avgCompressionB ? ' (avg: ' + avgCompressionB + '%)' : ''}}</span></div>
807
  <hr style="border-color: #555; margin: 6px 0;">
808
  <div><span class="label">RWKV:</span> <span class="model-a">${{modelA || '(empty)'}}</span></div>
809
  <div><span class="label">Qwen:</span> <span class="model-b">${{modelB || '(empty)'}}</span></div>
 
844
  }});
845
  }});
846
 
847
+ const slider = document.getElementById('color-range-slider');
848
+ const rangeValue = document.getElementById('color-range-value');
 
 
849
 
850
+ // Collect all tuned_delta values
851
+ const tokenData = [];
852
+ tokenSpans.forEach((token, idx) => {{
853
+ const tunedDelta = parseFloat(token.getAttribute('data-tuned-delta'));
854
+ if (!isNaN(tunedDelta)) {{
855
+ tokenData.push({{ token, tunedDelta, absDelta: Math.abs(tunedDelta) }});
856
+ }}
857
  }});
858
 
859
+ // Calculate max_abs_tuned_delta for normalization
860
+ const maxAbsDelta = Math.max(...tokenData.map(d => d.absDelta), 1e-9);
 
 
 
861
 
862
+ // Sort by |tuned_delta| to get rankings
863
+ const sortedByAbs = [...tokenData].sort((a, b) => b.absDelta - a.absDelta);
864
+ sortedByAbs.forEach((item, rank) => {{
865
+ item.rank = rank; // rank 0 = largest deviation
866
+ }});
 
 
 
867
 
868
+ function tunedDeltaToColor(tunedDelta, maxAbsDelta) {{
869
+ // Normalize to [-1, 1]
870
+ const normalized = Math.max(-1, Math.min(1, tunedDelta / maxAbsDelta));
 
871
  let r, g, b;
872
  if (normalized < 0) {{
873
+ // Green (RWKV better)
874
  const intensity = -normalized;
875
+ r = Math.round(255 * (1 - intensity * 0.85));
876
  g = 255;
877
+ b = Math.round(255 * (1 - intensity * 0.85));
878
  }} else {{
879
+ // Red (RWKV worse)
880
  const intensity = normalized;
881
  r = 255;
882
+ g = Math.round(255 * (1 - intensity * 0.85));
883
+ b = Math.round(255 * (1 - intensity * 0.85));
884
  }}
885
  return `rgb(${{r}}, ${{g}}, ${{b}})`;
886
  }}
887
 
888
+ function updateColors(colorRangePercent) {{
889
+ // colorRangePercent: 0-100, represents the proportion of tokens to color
890
+ const colorCount = Math.round(tokenData.length * colorRangePercent / 100);
891
+
892
+ // Calculate max deviation within the colored range
893
+ let maxAbsDeltaInRange = 1e-9;
894
+ tokenData.forEach(item => {{
895
+ if (item.rank < colorCount) {{
896
+ maxAbsDeltaInRange = Math.max(maxAbsDeltaInRange, item.absDelta);
897
+ }}
898
+ }});
899
+
900
+ tokenData.forEach(item => {{
901
+ if (item.rank < colorCount) {{
902
+ // Use dynamic normalization based on colored range
903
+ item.token.style.backgroundColor = tunedDeltaToColor(item.tunedDelta, maxAbsDeltaInRange);
904
+ }} else {{
905
+ // Outside color range, white
906
+ item.token.style.backgroundColor = 'rgb(255, 255, 255)';
 
 
 
 
 
 
907
  }}
908
  }});
909
  }}
910
 
911
  slider.addEventListener('input', (e) => {{
912
+ const val = parseInt(e.target.value);
913
+ rangeValue.textContent = val + '%';
914
  updateColors(val);
915
  }});
916
+
917
+ // Apply default color range on page load
918
+ updateColors(50);
919
  </script>
920
  </body>
921
  </html>