Spaces:
Running
Running
Commit
·
fa470f3
1
Parent(s):
452ae9b
update
Browse files- visualization/html_generator.py +66 -98
visualization/html_generator.py
CHANGED
|
@@ -127,28 +127,6 @@ def get_token_info_for_text(text: str) -> dict:
|
|
| 127 |
}
|
| 128 |
|
| 129 |
|
| 130 |
-
def delta_to_color(delta: float, avg_delta: float, max_deviation: float) -> Tuple[int, int, int]:
|
| 131 |
-
"""Map a delta value to an RGB color based on deviation from average."""
|
| 132 |
-
if max_deviation == 0:
|
| 133 |
-
return (255, 255, 255)
|
| 134 |
-
|
| 135 |
-
deviation = delta - avg_delta
|
| 136 |
-
normalized = max(-1, min(1, deviation / max_deviation))
|
| 137 |
-
|
| 138 |
-
if normalized < 0:
|
| 139 |
-
intensity = -normalized
|
| 140 |
-
r = int(255 * (1 - intensity * 0.7))
|
| 141 |
-
g = 255
|
| 142 |
-
b = int(255 * (1 - intensity * 0.7))
|
| 143 |
-
else:
|
| 144 |
-
intensity = normalized
|
| 145 |
-
r = 255
|
| 146 |
-
g = int(255 * (1 - intensity * 0.7))
|
| 147 |
-
b = int(255 * (1 - intensity * 0.7))
|
| 148 |
-
|
| 149 |
-
return (r, g, b)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
def generate_comparison_html(
|
| 153 |
text: str,
|
| 154 |
byte_losses_a: List[float],
|
|
@@ -253,12 +231,6 @@ def generate_comparison_html(
|
|
| 253 |
deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
|
| 254 |
avg_delta = sum(deltas) / len(deltas) if deltas else 0
|
| 255 |
|
| 256 |
-
# Calculate max deviation
|
| 257 |
-
deviations = [d - avg_delta for d in deltas]
|
| 258 |
-
abs_deviations = [abs(dev) for dev in deviations]
|
| 259 |
-
max_deviation = float(np.percentile(abs_deviations, 100)) if abs_deviations else 0
|
| 260 |
-
max_deviation = max(max_deviation, 1e-6)
|
| 261 |
-
|
| 262 |
# Calculate average compression rates
|
| 263 |
avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
|
| 264 |
avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
|
|
@@ -375,6 +347,10 @@ def generate_comparison_html(
|
|
| 375 |
compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
|
| 376 |
compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])
|
| 377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
topk_a_json = ""
|
| 379 |
topk_b_json = ""
|
| 380 |
if topk_predictions_a is not None and model_a_token_ranges:
|
|
@@ -410,15 +386,10 @@ def generate_comparison_html(
|
|
| 410 |
|
| 411 |
token_deltas = deltas[byte_start:byte_end]
|
| 412 |
avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
|
|
|
|
| 413 |
|
| 414 |
-
#
|
| 415 |
-
|
| 416 |
-
power_n = 3
|
| 417 |
-
sign = 1 if avg_token_delta >= 0 else -1
|
| 418 |
-
avg_token_delta_powered = sign * (abs(avg_token_delta) ** power_n)
|
| 419 |
-
|
| 420 |
-
color = delta_to_color(avg_token_delta_powered, avg_delta, max_deviation)
|
| 421 |
-
r, g, b = color
|
| 422 |
|
| 423 |
token_html_parts = []
|
| 424 |
for char in token_text:
|
|
@@ -445,7 +416,9 @@ def generate_comparison_html(
|
|
| 445 |
f'data-bytes="{escape_for_attr(bytes_str)}" '
|
| 446 |
f'data-compression-a="{escape_for_attr(compression_a_str)}" '
|
| 447 |
f'data-compression-b="{escape_for_attr(compression_b_str)}" '
|
| 448 |
-
f'data-
|
|
|
|
|
|
|
| 449 |
f'data-topk-a="{escape_for_attr(topk_a_json)}" '
|
| 450 |
f'data-topk-b="{escape_for_attr(topk_b_json)}"'
|
| 451 |
)
|
|
@@ -689,9 +662,9 @@ def generate_comparison_html(
|
|
| 689 |
<span>RWKV worse than avg</span>
|
| 690 |
</div>
|
| 691 |
<div class="legend-item" style="margin-left: 20px;">
|
| 692 |
-
<span style="color: #aaa;">
|
| 693 |
-
<input type="range" id="
|
| 694 |
-
<span id="
|
| 695 |
</div>
|
| 696 |
</div>
|
| 697 |
</div>
|
|
@@ -782,6 +755,8 @@ def generate_comparison_html(
|
|
| 782 |
const bytes = token.getAttribute('data-bytes') || '';
|
| 783 |
const compressionA = token.getAttribute('data-compression-a') || '';
|
| 784 |
const compressionB = token.getAttribute('data-compression-b') || '';
|
|
|
|
|
|
|
| 785 |
const top5A = token.getAttribute('data-topk-a') || '';
|
| 786 |
const top5B = token.getAttribute('data-topk-b') || '';
|
| 787 |
|
|
@@ -827,8 +802,8 @@ def generate_comparison_html(
|
|
| 827 |
|
| 828 |
let tooltipHtml = `
|
| 829 |
<div><span class="label">Bytes:</span> <span class="bytes">${{bytes || '(empty)'}}</span></div>
|
| 830 |
-
<div><span class="label">RWKV Compression Rate:</span> <span class="loss-a">${{compressionA || '(empty)'}}</span></div>
|
| 831 |
-
<div><span class="label">Qwen Compression Rate:</span> <span class="loss-b">${{compressionB || '(empty)'}}</span></div>
|
| 832 |
<hr style="border-color: #555; margin: 6px 0;">
|
| 833 |
<div><span class="label">RWKV:</span> <span class="model-a">${{modelA || '(empty)'}}</span></div>
|
| 834 |
<div><span class="label">Qwen:</span> <span class="model-b">${{modelB || '(empty)'}}</span></div>
|
|
@@ -869,85 +844,78 @@ def generate_comparison_html(
|
|
| 869 |
}});
|
| 870 |
}});
|
| 871 |
|
| 872 |
-
const
|
| 873 |
-
const
|
| 874 |
-
const saturationValue = document.getElementById('saturation-value');
|
| 875 |
-
const powerN = 3; // Must match Python's power_n
|
| 876 |
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
| 880 |
-
|
|
|
|
|
|
|
|
|
|
| 881 |
}});
|
| 882 |
|
| 883 |
-
//
|
| 884 |
-
|
| 885 |
-
const sign = delta >= 0 ? 1 : -1;
|
| 886 |
-
return sign * Math.pow(Math.abs(delta), powerN);
|
| 887 |
-
}}
|
| 888 |
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
if (lower === upper) return sorted[lower];
|
| 895 |
-
return sorted[lower] + (sorted[upper] - sorted[lower]) * (idx - lower);
|
| 896 |
-
}}
|
| 897 |
|
| 898 |
-
function
|
| 899 |
-
|
| 900 |
-
const
|
| 901 |
-
let normalized = Math.max(-1, Math.min(1, deviation / maxDeviation));
|
| 902 |
let r, g, b;
|
| 903 |
if (normalized < 0) {{
|
|
|
|
| 904 |
const intensity = -normalized;
|
| 905 |
-
r = Math.round(255 * (1 - intensity * 0.
|
| 906 |
g = 255;
|
| 907 |
-
b = Math.round(255 * (1 - intensity * 0.
|
| 908 |
}} else {{
|
|
|
|
| 909 |
const intensity = normalized;
|
| 910 |
r = 255;
|
| 911 |
-
g = Math.round(255 * (1 - intensity * 0.
|
| 912 |
-
b = Math.round(255 * (1 - intensity * 0.
|
| 913 |
}}
|
| 914 |
return `rgb(${{r}}, ${{g}}, ${{b}})`;
|
| 915 |
}}
|
| 916 |
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
tokenSpans.forEach((token, idx) => {{
|
| 938 |
-
const delta = parseFloat(token.getAttribute('data-delta'));
|
| 939 |
-
if (!isNaN(delta)) {{
|
| 940 |
-
const deltaPowered = applyPower(delta);
|
| 941 |
-
token.style.backgroundColor = deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation);
|
| 942 |
}}
|
| 943 |
}});
|
| 944 |
}}
|
| 945 |
|
| 946 |
slider.addEventListener('input', (e) => {{
|
| 947 |
-
const val = parseInt(e.target.value)
|
| 948 |
-
|
| 949 |
updateColors(val);
|
| 950 |
}});
|
|
|
|
|
|
|
|
|
|
| 951 |
</script>
|
| 952 |
</body>
|
| 953 |
</html>
|
|
|
|
| 127 |
}
|
| 128 |
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
def generate_comparison_html(
|
| 131 |
text: str,
|
| 132 |
byte_losses_a: List[float],
|
|
|
|
| 231 |
deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
|
| 232 |
avg_delta = sum(deltas) / len(deltas) if deltas else 0
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
# Calculate average compression rates
|
| 235 |
avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
|
| 236 |
avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
|
|
|
|
| 347 |
compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
|
| 348 |
compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])
|
| 349 |
|
| 350 |
+
# Calculate average compression rate for this token
|
| 351 |
+
avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0
|
| 352 |
+
avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0
|
| 353 |
+
|
| 354 |
topk_a_json = ""
|
| 355 |
topk_b_json = ""
|
| 356 |
if topk_predictions_a is not None and model_a_token_ranges:
|
|
|
|
| 386 |
|
| 387 |
token_deltas = deltas[byte_start:byte_end]
|
| 388 |
avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
|
| 389 |
+
tuned_delta = avg_token_delta - avg_delta
|
| 390 |
|
| 391 |
+
# Initial rendering uses white color, JavaScript will apply colors based on slider
|
| 392 |
+
r, g, b = 255, 255, 255
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
token_html_parts = []
|
| 395 |
for char in token_text:
|
|
|
|
| 416 |
f'data-bytes="{escape_for_attr(bytes_str)}" '
|
| 417 |
f'data-compression-a="{escape_for_attr(compression_a_str)}" '
|
| 418 |
f'data-compression-b="{escape_for_attr(compression_b_str)}" '
|
| 419 |
+
f'data-avg-compression-a="{avg_compression_a_token:.2f}" '
|
| 420 |
+
f'data-avg-compression-b="{avg_compression_b_token:.2f}" '
|
| 421 |
+
f'data-tuned-delta="{tuned_delta:.6f}" '
|
| 422 |
f'data-topk-a="{escape_for_attr(topk_a_json)}" '
|
| 423 |
f'data-topk-b="{escape_for_attr(topk_b_json)}"'
|
| 424 |
)
|
|
|
|
| 662 |
<span>RWKV worse than avg</span>
|
| 663 |
</div>
|
| 664 |
<div class="legend-item" style="margin-left: 20px;">
|
| 665 |
+
<span style="color: #aaa;">Color Range:</span>
|
| 666 |
+
<input type="range" id="color-range-slider" min="0" max="100" value="50" step="1" style="width: 200px; vertical-align: middle;">
|
| 667 |
+
<span id="color-range-value" style="color: #fff; min-width: 45px; display: inline-block;">50%</span>
|
| 668 |
</div>
|
| 669 |
</div>
|
| 670 |
</div>
|
|
|
|
| 755 |
const bytes = token.getAttribute('data-bytes') || '';
|
| 756 |
const compressionA = token.getAttribute('data-compression-a') || '';
|
| 757 |
const compressionB = token.getAttribute('data-compression-b') || '';
|
| 758 |
+
const avgCompressionA = token.getAttribute('data-avg-compression-a') || '';
|
| 759 |
+
const avgCompressionB = token.getAttribute('data-avg-compression-b') || '';
|
| 760 |
const top5A = token.getAttribute('data-topk-a') || '';
|
| 761 |
const top5B = token.getAttribute('data-topk-b') || '';
|
| 762 |
|
|
|
|
| 802 |
|
| 803 |
let tooltipHtml = `
|
| 804 |
<div><span class="label">Bytes:</span> <span class="bytes">${{bytes || '(empty)'}}</span></div>
|
| 805 |
+
<div><span class="label">RWKV Compression Rate:</span> <span class="loss-a">${{compressionA || '(empty)'}}${{avgCompressionA ? ' (avg: ' + avgCompressionA + '%)' : ''}}</span></div>
|
| 806 |
+
<div><span class="label">Qwen Compression Rate:</span> <span class="loss-b">${{compressionB || '(empty)'}}${{avgCompressionB ? ' (avg: ' + avgCompressionB + '%)' : ''}}</span></div>
|
| 807 |
<hr style="border-color: #555; margin: 6px 0;">
|
| 808 |
<div><span class="label">RWKV:</span> <span class="model-a">${{modelA || '(empty)'}}</span></div>
|
| 809 |
<div><span class="label">Qwen:</span> <span class="model-b">${{modelB || '(empty)'}}</span></div>
|
|
|
|
| 844 |
}});
|
| 845 |
}});
|
| 846 |
|
| 847 |
+
const slider = document.getElementById('color-range-slider');
|
| 848 |
+
const rangeValue = document.getElementById('color-range-value');
|
|
|
|
|
|
|
| 849 |
|
| 850 |
+
// Collect all tuned_delta values
|
| 851 |
+
const tokenData = [];
|
| 852 |
+
tokenSpans.forEach((token, idx) => {{
|
| 853 |
+
const tunedDelta = parseFloat(token.getAttribute('data-tuned-delta'));
|
| 854 |
+
if (!isNaN(tunedDelta)) {{
|
| 855 |
+
tokenData.push({{ token, tunedDelta, absDelta: Math.abs(tunedDelta) }});
|
| 856 |
+
}}
|
| 857 |
}});
|
| 858 |
|
| 859 |
+
// Calculate max_abs_tuned_delta for normalization
|
| 860 |
+
const maxAbsDelta = Math.max(...tokenData.map(d => d.absDelta), 1e-9);
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
+
// Sort by |tuned_delta| to get rankings
|
| 863 |
+
const sortedByAbs = [...tokenData].sort((a, b) => b.absDelta - a.absDelta);
|
| 864 |
+
sortedByAbs.forEach((item, rank) => {{
|
| 865 |
+
item.rank = rank; // rank 0 = largest deviation
|
| 866 |
+
}});
|
|
|
|
|
|
|
|
|
|
| 867 |
|
| 868 |
+
function tunedDeltaToColor(tunedDelta, maxAbsDelta) {{
|
| 869 |
+
// Normalize to [-1, 1]
|
| 870 |
+
const normalized = Math.max(-1, Math.min(1, tunedDelta / maxAbsDelta));
|
|
|
|
| 871 |
let r, g, b;
|
| 872 |
if (normalized < 0) {{
|
| 873 |
+
// Green (RWKV better)
|
| 874 |
const intensity = -normalized;
|
| 875 |
+
r = Math.round(255 * (1 - intensity * 0.85));
|
| 876 |
g = 255;
|
| 877 |
+
b = Math.round(255 * (1 - intensity * 0.85));
|
| 878 |
}} else {{
|
| 879 |
+
// Red (RWKV worse)
|
| 880 |
const intensity = normalized;
|
| 881 |
r = 255;
|
| 882 |
+
g = Math.round(255 * (1 - intensity * 0.85));
|
| 883 |
+
b = Math.round(255 * (1 - intensity * 0.85));
|
| 884 |
}}
|
| 885 |
return `rgb(${{r}}, ${{g}}, ${{b}})`;
|
| 886 |
}}
|
| 887 |
|
| 888 |
+
function updateColors(colorRangePercent) {{
|
| 889 |
+
// colorRangePercent: 0-100, represents the proportion of tokens to color
|
| 890 |
+
const colorCount = Math.round(tokenData.length * colorRangePercent / 100);
|
| 891 |
+
|
| 892 |
+
// Calculate max deviation within the colored range
|
| 893 |
+
let maxAbsDeltaInRange = 1e-9;
|
| 894 |
+
tokenData.forEach(item => {{
|
| 895 |
+
if (item.rank < colorCount) {{
|
| 896 |
+
maxAbsDeltaInRange = Math.max(maxAbsDeltaInRange, item.absDelta);
|
| 897 |
+
}}
|
| 898 |
+
}});
|
| 899 |
+
|
| 900 |
+
tokenData.forEach(item => {{
|
| 901 |
+
if (item.rank < colorCount) {{
|
| 902 |
+
// Use dynamic normalization based on colored range
|
| 903 |
+
item.token.style.backgroundColor = tunedDeltaToColor(item.tunedDelta, maxAbsDeltaInRange);
|
| 904 |
+
}} else {{
|
| 905 |
+
// Outside color range, white
|
| 906 |
+
item.token.style.backgroundColor = 'rgb(255, 255, 255)';
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
}}
|
| 908 |
}});
|
| 909 |
}}
|
| 910 |
|
| 911 |
slider.addEventListener('input', (e) => {{
|
| 912 |
+
const val = parseInt(e.target.value);
|
| 913 |
+
rangeValue.textContent = val + '%';
|
| 914 |
updateColors(val);
|
| 915 |
}});
|
| 916 |
+
|
| 917 |
+
// Apply default color range on page load
|
| 918 |
+
updateColors(50);
|
| 919 |
</script>
|
| 920 |
</body>
|
| 921 |
</html>
|