Compression-Lens / visualization /html_generator.py
Jellyfish042's picture
update
257183f
"""
HTML visualization generator for UncheatableEval.
Generates interactive HTML visualizations comparing byte-level losses between two models.
"""
import json
import math
import re
from typing import List, Tuple, Optional, Set
import numpy as np
from core.helpers import TokenizerBytesConverter
# Compression rate conversion factor: multiplying a per-byte loss by this gives
# a percentage. 1/ln(2) converts a natural-log loss to bits, 0.125 converts bits
# to bytes (8 bits per byte) and 100 turns the ratio into a percentage.
# NOTE(review): assumes the per-byte losses are natural-log (nat) losses — confirm
# against the evaluator that produces them.
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0
# Global tokenizers (lazy loaded): module-level caches populated on first call to
# get_qwen_tokenizer() / get_rwkv_tokenizer() and reused afterwards.
_qwen_tokenizer = None
_rwkv_tokenizer = None
def get_qwen_tokenizer():
    """Return the shared Qwen3 byte-level tokenizer, constructing it on first use.

    The instance is cached in the module-level ``_qwen_tokenizer`` so the
    (potentially slow) construction happens only once per process.
    """
    global _qwen_tokenizer
    if _qwen_tokenizer is not None:
        return _qwen_tokenizer
    _qwen_tokenizer = TokenizerBytesConverter("Qwen/Qwen3-0.6B-Base")
    return _qwen_tokenizer
def get_rwkv_tokenizer():
    """Return the shared RWKV TRIE tokenizer, loading its vocabulary on first use.

    The vocabulary file is ``support/rwkv_vocab_v20230424.txt`` in the directory
    one level above this module's directory. The instance is cached in the
    module-level ``_rwkv_tokenizer``.
    """
    global _rwkv_tokenizer
    if _rwkv_tokenizer is not None:
        return _rwkv_tokenizer

    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
    import os

    module_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(module_dir)
    vocab_file = os.path.join(parent_dir, "support", "rwkv_vocab_v20230424.txt")
    _rwkv_tokenizer = TRIE_TOKENIZER(vocab_file)
    return _rwkv_tokenizer
def get_tokenizer_boundaries(text: str, tokenizer, is_rwkv: bool = False) -> Set[int]:
    """Return the byte offsets at which *tokenizer* splits *text*.

    The result always contains 0 plus the end offset of every token. RWKV
    tokenizers are driven through ``encode`` (whose result may expose ``.ids``)
    and ``decodeBytes``; any other tokenizer must provide ``encode_to_bytes``.
    """
    if is_rwkv:
        encoded = tokenizer.encode(text)
        ids = encoded.ids if hasattr(encoded, "ids") else encoded
        # Lazily decode each id so decodeBytes runs in token order, one id at a time.
        pieces = (tokenizer.decodeBytes([tid]) for tid in ids)
    else:
        pieces = tokenizer.encode_to_bytes(text)

    offsets = {0}
    position = 0
    for piece in pieces:
        position += len(piece)
        offsets.add(position)
    return offsets
def get_token_info_for_text(text: str) -> dict:
    """Tokenize *text* with the Qwen and RWKV tokenizers and align both by byte offset.

    Returns a dict with these keys:
        common_boundaries: sorted byte offsets (0 included) at which both
            tokenizers end a token.
        qwen_tokens / rwkv_tokens: ordered ``(start, end, token_id, token_str)``
            tuples; ``token_str`` is the token's bytes decoded as UTF-8, or the
            ``repr`` of those bytes when they are not valid UTF-8 on their own.
        byte_to_qwen / byte_to_rwkv: each token's start offset -> its index.
    """
    qwen = get_qwen_tokenizer()
    rwkv = get_rwkv_tokenizer()

    def lay_out(id_byte_pairs):
        """Turn ``(token_id, token_bytes)`` pairs into contiguous byte spans and a start->index map."""
        spans = []
        start_to_index = {}
        cursor = 0
        for index, (token_id, raw) in enumerate(id_byte_pairs):
            stop = cursor + len(raw)
            try:
                label = raw.decode("utf-8")
            except UnicodeDecodeError:
                label = repr(raw)
            spans.append((cursor, stop, token_id, label))
            start_to_index[cursor] = index
            cursor = stop
        return spans, start_to_index

    # The Qwen converter yields the vocab id together with the token bytes, so the
    # tooltip can show true token ids; bytes() normalises the byte container.
    qwen_tokens, byte_to_qwen = lay_out(
        (token_id, bytes(token_bytes)) for token_id, token_bytes in qwen.encode_to_ids_and_bytes(text)
    )

    encoded = rwkv.encode(text)
    rwkv_ids = encoded.ids if hasattr(encoded, "ids") else encoded
    rwkv_tokens, byte_to_rwkv = lay_out((token_id, rwkv.decodeBytes([token_id])) for token_id in rwkv_ids)

    qwen_ends = {0} | {span[1] for span in qwen_tokens}
    rwkv_ends = {0} | {span[1] for span in rwkv_tokens}

    return {
        "common_boundaries": sorted(qwen_ends & rwkv_ends),
        "qwen_tokens": qwen_tokens,
        "rwkv_tokens": rwkv_tokens,
        "byte_to_qwen": byte_to_qwen,
        "byte_to_rwkv": byte_to_rwkv,
    }
def generate_comparison_html(
    text: str,
    byte_losses_a: List[float],
    byte_losses_b: List[float],
    model_a_name: str,
    model_b_name: str,
    topk_predictions_a: Optional[List] = None,
    topk_predictions_b: Optional[List] = None,
    tokenizer_a=None,
    tokenizer_b=None,
    model_type_a: str = "hf",
    model_type_b: str = "rwkv7",
) -> str:
    """
    Generate an interactive HTML visualization comparing two models.

    The text is cut into display segments at the byte offsets where both the
    Qwen and the RWKV tokenizer end a token (see ``get_token_info_for_text``).
    A segment whose boundaries fall inside a multi-byte UTF-8 character is
    extended to the next common boundary, so no text is lost. Each segment is
    rendered as a ``<span>`` whose data attributes carry its bytes, per-byte
    compression rates for both models, the overlapping tokens of both
    tokenizers and, when available, base64-encoded top-k predictions. Inline
    JavaScript colors the spans by how far the segment's loss delta (A - B)
    deviates from the text-wide average delta, links repeated words and shows
    tooltips.

    Note: page labels and per-segment token lists are hard-wired to model A
    being the RWKV model and model B being the Qwen model.

    Args:
        text: The input text that was evaluated
        byte_losses_a: Per-byte losses from model A, indexed by UTF-8 byte offset of ``text``
        byte_losses_b: Per-byte losses from model B, indexed by UTF-8 byte offset of ``text``
        model_a_name: Display name for model A (HTML-escaped before rendering)
        model_b_name: Display name for model B (HTML-escaped before rendering)
        topk_predictions_a: Top-k predictions from model A, one entry per model-A token,
            each ``(actual_id, rank, [(token_id, prob), ...])``
        topk_predictions_b: Top-k predictions from model B, same layout as model A
        tokenizer_a: Tokenizer for model A (used to map bytes to model-A token indices)
        tokenizer_b: Tokenizer for model B (used to map bytes to model-B token indices)
        model_type_a: Type of model A ("hf" or "rwkv7")
        model_type_b: Type of model B ("hf" or "rwkv7")

    Returns:
        HTML string with interactive visualization
    """
    import base64
    import bisect

    def decode_token(token_id: int, tokenizer, model_type: str) -> str:
        """Decode one token id to text; fall back to "[id]" when unavailable, empty or failing."""
        if tokenizer is None:
            return f"[{token_id}]"
        try:
            # RWKV and HuggingFace tokenizers both expose decode(list_of_ids).
            decoded = tokenizer.decode([token_id])
        except Exception as e:
            print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}")
            return f"[{token_id}]"
        return decoded if decoded else f"[{token_id}]"

    def build_byte_to_token_map(text: str, tokenizer, model_type: str):
        """Return ascending ``(start, end, token_idx)`` byte ranges of *text* under *tokenizer*.

        An empty list is returned when no tokenizer is given or tokenization fails.
        """
        if tokenizer is None:
            return []
        token_ranges = []
        try:
            if model_type in ("rwkv", "rwkv7"):
                tokenized = tokenizer.encode(text)
                token_ids = tokenized.ids if hasattr(tokenized, "ids") else tokenized
                byte_pos = 0
                for idx, token_id in enumerate(token_ids):
                    try:
                        token_bytes = tokenizer.decodeBytes([token_id])
                    except Exception as e:
                        # NOTE(review): byte_pos is not advanced for an undecodable token,
                        # so later ranges may be shifted.
                        print(f"Warning: Failed to decode RWKV token {token_id}: {e}")
                        continue
                    token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
                    byte_pos += len(token_bytes)
            else:
                # HuggingFace tokenizer - recover per-token bytes via TokenizerBytesConverter.
                tokenizer_name = getattr(tokenizer, "name_or_path", None)
                if tokenizer_name:
                    converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
                    byte_pos = 0
                    for idx, token_bytes in enumerate(converter.encode_to_bytes(text)):
                        token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
                        byte_pos += len(token_bytes)
                else:
                    print("Warning: Could not get tokenizer name for HF model")
        except Exception as e:
            print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
            return []
        return token_ranges

    def make_token_finder(token_ranges):
        """Return ``find(byte_pos)`` -> index of the token whose range contains byte_pos, or None.

        The ranges are contiguous and ascending, so a binary search over their
        start offsets replaces a linear scan per segment.
        """
        starts = [start for start, _, _ in token_ranges]

        def find(byte_pos):
            i = bisect.bisect_right(starts, byte_pos) - 1
            if i >= 0:
                start, end, idx = token_ranges[i]
                if start <= byte_pos < end:
                    return idx
            return None

        return find

    def make_overlap_finder(token_list):
        """Return ``overlapping(start, end)`` -> ``(token_id, token_str)`` of tokens overlapping [start, end)."""
        ends = [t[1] for t in token_list]

        def overlapping(byte_start, byte_end):
            # Tokens are contiguous and ordered, so both starts and ends are
            # non-decreasing: skip every token ending at or before byte_start, then
            # walk forward until tokens begin at or after byte_end.
            result = []
            i = bisect.bisect_right(ends, byte_start)
            while i < len(token_list) and token_list[i][0] < byte_end:
                _, _, token_id, token_str = token_list[i]
                result.append((token_id, token_str))
                i += 1
            return result

        return overlapping

    # Characters that could break out of a double-quoted HTML attribute (or, for
    # whitespace, be normalised by the parser) mapped to their entities.
    attr_escapes = str.maketrans(
        {
            "&": "&amp;",
            '"': "&quot;",
            "'": "&#39;",
            "<": "&lt;",
            ">": "&gt;",
            "\n": "&#10;",
            "\r": "&#13;",
            "\t": "&#9;",
        }
    )

    def escape_for_attr(s):
        """Escape *s* for use inside an HTML attribute value (also safe for element text)."""
        return str(s).translate(attr_escapes)

    # Visible segment text: escape markup and keep whitespace/newlines visible.
    text_escapes = str.maketrans(
        {
            "<": "&lt;",
            ">": "&gt;",
            "&": "&amp;",
            "\n": "<br>",
            " ": "&nbsp;",
            "\t": "&nbsp;&nbsp;&nbsp;&nbsp;",
        }
    )

    warned_topk = set()

    def encode_topk(predictions, token_idx, tokenizer, model_type, label):
        """Return base64(JSON ``[actual_id, rank, [[id, prob, text], ...]]``) for one token, or ""."""
        if predictions is None or token_idx is None or token_idx >= len(predictions):
            return ""
        pred = predictions[token_idx]
        try:
            decoded_pred = [
                pred[0],
                pred[1],
                [[tid, prob, decode_token(tid, tokenizer, model_type)] for tid, prob in pred[2]],
            ]
            payload = json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")
        except Exception as e:
            # Best effort: a bad entry only hides this tooltip section; report it once per model.
            if label not in warned_topk:
                warned_topk.add(label)
                print(f"Warning: Could not encode top-k predictions for {label} (token {token_idx}): {e}")
            return ""
        # base64 keeps the JSON payload free of characters that need attribute escaping.
        return base64.b64encode(payload).decode("ascii")

    # Per-byte deltas (positive = model A spends more loss on the byte than model B).
    deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
    avg_delta = sum(deltas) / len(deltas) if deltas else 0

    # Average compression rates over the whole text.
    avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
    avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
    avg_delta_compression = avg_delta * COMPRESSION_RATE_FACTOR

    text_bytes = text.encode("utf-8")
    token_info = get_token_info_for_text(text)
    common_boundaries = token_info["common_boundaries"]
    qwen_overlapping = make_overlap_finder(token_info["qwen_tokens"])
    rwkv_overlapping = make_overlap_finder(token_info["rwkv_tokens"])

    # Byte position -> token index of each model's own tokenizer (for top-k lookups).
    find_token_a = make_token_finder(build_byte_to_token_map(text, tokenizer_a, model_type_a))
    find_token_b = make_token_finder(build_byte_to_token_map(text, tokenizer_b, model_type_b))

    # Build display segments between common boundaries.
    word_char = re.compile(r"\w")
    tokens = []
    seg_start = common_boundaries[0] if common_boundaries else 0
    for boundary in common_boundaries[1:]:
        try:
            seg_text = text_bytes[seg_start:boundary].decode("utf-8")
        except UnicodeDecodeError:
            # Both tokenizers split inside a multi-byte character here; extend the
            # segment to the next common boundary instead of dropping its bytes.
            continue
        token = {
            "type": "word" if word_char.search(seg_text) else "non-word",
            "text": seg_text,
            "byte_start": seg_start,
            "byte_end": boundary,
            "qwen_tokens": qwen_overlapping(seg_start, boundary),
            "rwkv_tokens": rwkv_overlapping(seg_start, boundary),
        }
        if token["type"] == "word":
            token["word_lower"] = seg_text.lower()
        tokens.append(token)
        seg_start = boundary

    # Number words in order and count case-insensitive occurrences; only words that
    # occur more than once become linkable "word" spans in the page.
    word_counts = {}
    next_word_id = 0
    for token in tokens:
        if token["type"] == "word":
            token["word_id"] = next_word_id
            next_word_id += 1
            word_counts[token["word_lower"]] = word_counts.get(token["word_lower"], 0) + 1

    html_content = []
    for token in tokens:
        byte_start = token["byte_start"]
        byte_end = token["byte_end"]

        # Tooltip token lists: model A is shown with RWKV tokens, model B with Qwen tokens.
        model_a_info = ", ".join(f"[{idx}] {s!r}" for idx, s in token["rwkv_tokens"])
        model_b_info = ", ".join(f"[{idx}] {s!r}" for idx, s in token["qwen_tokens"])

        losses_a = byte_losses_a[byte_start:byte_end]
        losses_b = byte_losses_b[byte_start:byte_end]
        bytes_str = " ".join(f"{b:02x}" for b in text_bytes[byte_start:byte_end])
        compression_a_str = " ".join(f"{loss * COMPRESSION_RATE_FACTOR:.2f}%" for loss in losses_a)
        compression_b_str = " ".join(f"{loss * COMPRESSION_RATE_FACTOR:.2f}%" for loss in losses_b)

        # Average compression rate for this segment.
        avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0
        avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0

        # Top-k predictions come from each model's own token covering the segment's first byte.
        topk_a_json = encode_topk(topk_predictions_a, find_token_a(byte_start), tokenizer_a, model_type_a, "model A")
        topk_b_json = encode_topk(topk_predictions_b, find_token_b(byte_start), tokenizer_b, model_type_b, "model B")

        token_deltas = deltas[byte_start:byte_end]
        avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
        # Deviation from the text-wide average delta; JavaScript maps it to a color.
        tuned_delta = avg_token_delta - avg_delta

        token_span_content = token["text"].translate(text_escapes)

        data_attrs = (
            f'data-model-a="{escape_for_attr(model_a_info)}" '
            f'data-model-b="{escape_for_attr(model_b_info)}" '
            f'data-bytes="{escape_for_attr(bytes_str)}" '
            f'data-compression-a="{escape_for_attr(compression_a_str)}" '
            f'data-compression-b="{escape_for_attr(compression_b_str)}" '
            f'data-avg-compression-a="{avg_compression_a_token:.2f}" '
            f'data-avg-compression-b="{avg_compression_b_token:.2f}" '
            f'data-tuned-delta="{tuned_delta:.6f}" '
            f'data-topk-a="{escape_for_attr(topk_a_json)}" '
            f'data-topk-b="{escape_for_attr(topk_b_json)}"'
        )
        # Spans start white; JavaScript recolors them according to the slider.
        style_attr = 'style="background-color: rgb(255,255,255)"'

        if token["type"] == "word" and word_counts[token["word_lower"]] > 1:
            word_attr = escape_for_attr(token["word_lower"])
            word_id_attr = token["word_id"]
            html_content.append(
                f'<span class="token word" {data_attrs} {style_attr} data-word="{word_attr}" data-word-id="{word_id_attr}">'
                + token_span_content
                + "</span>"
            )
        else:
            html_content.append(f'<span class="token" {data_attrs} {style_attr}>{token_span_content}</span>')

    content_html = "".join(html_content)
    safe_model_a_name = escape_for_attr(model_a_name)
    safe_model_b_name = escape_for_attr(model_b_name)
    delta_color = "#64ff64" if avg_delta < 0 else "#ff6464"

    html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Model Comparison</title>
<style>
body {{
font-family: Consolas, 'Courier New', monospace;
margin: 0;
padding: 0;
background-color: #f5f5f5;
}}
.header {{
background-color: #333;
color: white;
padding: 20px;
position: sticky;
top: 0;
z-index: 100;
}}
.header h1 {{
margin: 0 0 15px 0;
font-size: 18px;
}}
.meta {{
display: flex;
flex-wrap: wrap;
gap: 20px;
font-size: 12px;
color: #c8c8c8;
}}
.legend {{
display: flex;
gap: 15px;
margin-top: 10px;
}}
.legend-item {{
display: flex;
align-items: center;
gap: 5px;
}}
.legend-box {{
width: 20px;
height: 12px;
border: 1px solid #666;
}}
.content {{
background-color: white;
margin: 10px;
padding: 15px;
border: 1px solid #ccc;
font-size: 14px;
line-height: 1.8;
word-wrap: break-word;
position: relative;
}}
.content span {{
padding: 1px 0;
}}
.word {{
cursor: pointer;
position: relative;
}}
.word:hover {{
outline: 2px solid #007bff;
outline-offset: 1px;
}}
.word.highlighted {{
outline: 2px solid #ff6b6b;
outline-offset: 1px;
}}
#svg-overlay {{
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
pointer-events: none;
z-index: 1000;
}}
.link-line {{
stroke: #007bff;
stroke-width: 2;
fill: none;
opacity: 0.7;
}}
.link-dot {{
fill: #007bff;
opacity: 0.8;
}}
.token {{
position: relative;
cursor: help;
}}
.token:hover {{
outline: 1px dashed #666;
}}
#tooltip {{
position: fixed;
background-color: rgba(0, 0, 0, 0.9);
color: white;
padding: 10px 14px;
border-radius: 6px;
font-size: 12px;
max-width: 500px;
z-index: 2000;
pointer-events: none;
display: none;
line-height: 1.6;
box-shadow: 0 2px 10px rgba(0,0,0,0.3);
}}
#tooltip .label {{
color: #aaa;
font-weight: bold;
}}
#tooltip .bytes {{
color: #a5f3fc;
font-family: monospace;
}}
#tooltip .loss-a {{
color: #86efac;
font-family: monospace;
}}
#tooltip .loss-b {{
color: #fca5a5;
font-family: monospace;
}}
#tooltip .model-a {{
color: #fcd34d;
}}
#tooltip .model-b {{
color: #7dd3fc;
}}
#tooltip .topk-section {{
margin-top: 8px;
padding-top: 8px;
border-top: 1px solid #555;
}}
#tooltip .topk-container {{
display: flex;
gap: 16px;
}}
#tooltip .topk-column {{
flex: 1;
min-width: 180px;
}}
#tooltip .topk-title {{
color: #aaa;
font-weight: bold;
margin-bottom: 4px;
font-size: 11px;
}}
#tooltip .topk-title.model-a {{
color: #86efac;
}}
#tooltip .topk-title.model-b {{
color: #fca5a5;
}}
#tooltip .topk-list {{
font-size: 11px;
}}
#tooltip .topk-item {{
display: flex;
gap: 4px;
padding: 1px 0;
align-items: center;
}}
#tooltip .topk-rank {{
color: #888;
min-width: 18px;
}}
#tooltip .topk-rank.hit {{
color: #ffd700;
}}
#tooltip .topk-token {{
color: #a5f3fc;
max-width: 100px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
font-family: monospace;
}}
#tooltip .topk-prob {{
color: #86efac;
min-width: 45px;
text-align: right;
}}
#tooltip .topk-hit {{
color: #22c55e;
}}
#tooltip .topk-miss {{
color: #ef4444;
font-style: italic;
}}
</style>
</head>
<body>
<svg id="svg-overlay"></svg>
<div id="tooltip"></div>
<div class="header">
<div class="meta">
<div>Model A: {safe_model_a_name}</div>
<div>Model B: {safe_model_b_name}</div>
<div>RWKV Compression: {avg_compression_a:.2f}%</div>
<div>Qwen Compression: {avg_compression_b:.2f}%</div>
<div style="color: {delta_color}">Avg Delta: {avg_delta_compression:+.2f}%</div>
</div>
<div class="legend">
<div class="legend-item">
<div class="legend-box" style="background-color: rgb(77, 255, 77)"></div>
<span>RWKV better than avg</span>
</div>
<div class="legend-item">
<div class="legend-box" style="background-color: rgb(255, 255, 255)"></div>
<span>Equal to avg</span>
</div>
<div class="legend-item">
<div class="legend-box" style="background-color: rgb(255, 77, 77)"></div>
<span>RWKV worse than avg</span>
</div>
<div class="legend-item" style="margin-left: 20px;">
<span style="color: #aaa;">Color Range:</span>
<input type="range" id="color-range-slider" min="0" max="100" value="10" step="0.1" style="width: 200px; vertical-align: middle;">
<span id="color-range-value" style="color: #fff; min-width: 45px; display: inline-block;">10%</span>
</div>
</div>
</div>
<div class="content">
{content_html}
</div>
<script>
const svgOverlay = document.getElementById('svg-overlay');
const words = document.querySelectorAll('.word');
const wordGroups = {{}};
words.forEach(word => {{
const wordText = word.getAttribute('data-word');
if (!wordGroups[wordText]) {{
wordGroups[wordText] = [];
}}
wordGroups[wordText].push(word);
}});
function clearLines() {{
svgOverlay.innerHTML = '';
words.forEach(w => w.classList.remove('highlighted'));
}}
function drawLines(hoveredWord) {{
clearLines();
const wordText = hoveredWord.getAttribute('data-word');
const wordId = parseInt(hoveredWord.getAttribute('data-word-id'));
const sameWords = wordGroups[wordText] || [];
const previousWords = sameWords.filter(w => {{
const id = parseInt(w.getAttribute('data-word-id'));
return id < wordId;
}});
if (previousWords.length === 0) return;
sameWords.forEach(w => w.classList.add('highlighted'));
const hoveredRect = hoveredWord.getBoundingClientRect();
const hoveredX = hoveredRect.left + hoveredRect.width / 2;
const hoveredY = hoveredRect.top + hoveredRect.height / 2;
previousWords.forEach(prevWord => {{
const prevRect = prevWord.getBoundingClientRect();
const prevX = prevRect.left + prevRect.width / 2;
const prevY = prevRect.top + prevRect.height / 2;
const midX = (hoveredX + prevX) / 2;
const midY = Math.min(hoveredY, prevY) - 30;
const path = document.createElementNS('http://www.w3.org/2000/svg', 'path');
path.setAttribute('class', 'link-line');
path.setAttribute('d', `M ${{prevX}} ${{prevY}} Q ${{midX}} ${{midY}} ${{hoveredX}} ${{hoveredY}}`);
svgOverlay.appendChild(path);
const dot1 = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
dot1.setAttribute('class', 'link-dot');
dot1.setAttribute('cx', prevX);
dot1.setAttribute('cy', prevY);
dot1.setAttribute('r', 4);
svgOverlay.appendChild(dot1);
const dot2 = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
dot2.setAttribute('class', 'link-dot');
dot2.setAttribute('cx', hoveredX);
dot2.setAttribute('cy', hoveredY);
dot2.setAttribute('r', 4);
svgOverlay.appendChild(dot2);
}});
}}
words.forEach(word => {{
word.addEventListener('mouseenter', () => drawLines(word));
word.addEventListener('mouseleave', clearLines);
}});
window.addEventListener('scroll', clearLines);
const tooltip = document.getElementById('tooltip');
// Escape text before inserting it into innerHTML (token text may contain markup).
function escapeHtml(value) {{
return String(value)
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;');
}}
const tokenSpans = document.querySelectorAll('.token');
tokenSpans.forEach(token => {{
token.addEventListener('mouseenter', (e) => {{
const modelA = token.getAttribute('data-model-a') || 'N/A';
const modelB = token.getAttribute('data-model-b') || 'N/A';
const bytes = token.getAttribute('data-bytes') || '';
const compressionA = token.getAttribute('data-compression-a') || '';
const compressionB = token.getAttribute('data-compression-b') || '';
const avgCompressionA = token.getAttribute('data-avg-compression-a') || '';
const avgCompressionB = token.getAttribute('data-avg-compression-b') || '';
const top5A = token.getAttribute('data-topk-a') || '';
const top5B = token.getAttribute('data-topk-b') || '';
function formatTopkColumn(topkBase64, modelName, titleClass) {{
if (!topkBase64) return '<div class="topk-column"><div class="topk-title ' + titleClass + '">' + modelName + '</div><div class="topk-list">N/A</div></div>';
try {{
// Decode base64 to UTF-8 string (atob() doesn't support UTF-8, need proper decoding)
const binaryString = atob(topkBase64);
const byteArray = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {{
byteArray[i] = binaryString.charCodeAt(i);
}}
const topkJson = new TextDecoder('utf-8').decode(byteArray);
const data = JSON.parse(topkJson);
const [actualId, rank, topkList] = data;
let html = '<div class="topk-column">';
html += '<div class="topk-title ' + titleClass + '">' + modelName + '</div>';
html += '<div class="topk-list">';
topkList.forEach((item, idx) => {{
const [tokenId, prob, tokenText] = item;
const isHit = tokenId === actualId;
const rankClass = isHit ? 'topk-rank hit' : 'topk-rank';
const displayText = tokenText || '[' + tokenId + ']';
const escapedText = escapeHtml(displayText);
html += '<div class="topk-item">';
html += '<span class="' + rankClass + '">' + (idx + 1) + '.</span>';
html += '<span class="topk-token" title="ID: ' + escapeHtml(tokenId) + '">' + escapedText + '</span>';
html += '<span class="topk-prob">' + (prob * 100).toFixed(1) + '%</span>';
if (isHit) html += '<span class="topk-hit">✓</span>';
html += '</div>';
}});
if (rank > 10) {{
html += '<div class="topk-item topk-miss">Actual rank: ' + escapeHtml(rank) + '</div>';
}}
html += '</div></div>';
return html;
}} catch (e) {{
console.error('Error in formatTopkColumn for ' + modelName + ':', e);
console.error('topkBase64:', topkBase64);
return '<div class="topk-column"><div class="topk-title ' + titleClass + '">' + modelName + '</div><div class="topk-list">Error: ' + escapeHtml(e.message) + '</div></div>';
}}
}}
let tooltipHtml = `
<div><span class="label">Bytes:</span> <span class="bytes">${{escapeHtml(bytes || '(empty)')}}</span></div>
<div><span class="label">RWKV Compression Rate:</span> <span class="loss-a">${{escapeHtml(compressionA || '(empty)')}}${{avgCompressionA ? ' (avg: ' + escapeHtml(avgCompressionA) + '%)' : ''}}</span></div>
<div><span class="label">Qwen Compression Rate:</span> <span class="loss-b">${{escapeHtml(compressionB || '(empty)')}}${{avgCompressionB ? ' (avg: ' + escapeHtml(avgCompressionB) + '%)' : ''}}</span></div>
<hr style="border-color: #555; margin: 6px 0;">
<div><span class="label">RWKV:</span> <span class="model-a">${{escapeHtml(modelA || '(empty)')}}</span></div>
<div><span class="label">Qwen:</span> <span class="model-b">${{escapeHtml(modelB || '(empty)')}}</span></div>
`;
if (top5A || top5B) {{
tooltipHtml += '<div class="topk-section"><div class="topk-container">';
tooltipHtml += formatTopkColumn(top5A, 'RWKV Top10', 'model-a');
tooltipHtml += formatTopkColumn(top5B, 'Qwen Top10', 'model-b');
tooltipHtml += '</div></div>';
}}
tooltip.innerHTML = tooltipHtml;
tooltip.style.display = 'block';
}});
token.addEventListener('mousemove', (e) => {{
const tooltipRect = tooltip.getBoundingClientRect();
const viewportWidth = window.innerWidth;
const viewportHeight = window.innerHeight;
let x = e.clientX + 15;
let y = e.clientY + 15;
if (x + tooltipRect.width > viewportWidth - 10) {{
x = e.clientX - tooltipRect.width - 15;
}}
if (y + tooltipRect.height > viewportHeight - 10) {{
y = e.clientY - tooltipRect.height - 15;
}}
if (x < 10) x = 10;
if (y < 10) y = 10;
tooltip.style.left = x + 'px';
tooltip.style.top = y + 'px';
}});
token.addEventListener('mouseleave', () => {{
tooltip.style.display = 'none';
}});
}});
const slider = document.getElementById('color-range-slider');
const rangeValue = document.getElementById('color-range-value');
// Collect all tuned_delta values
const tokenData = [];
tokenSpans.forEach((token, idx) => {{
const tunedDelta = parseFloat(token.getAttribute('data-tuned-delta'));
if (!isNaN(tunedDelta)) {{
tokenData.push({{ token, tunedDelta, absDelta: Math.abs(tunedDelta) }});
}}
}});
// Calculate max_abs_tuned_delta for normalization
const maxAbsDelta = Math.max(...tokenData.map(d => d.absDelta), 1e-9);
// Sort by |tuned_delta| to get rankings
const sortedByAbs = [...tokenData].sort((a, b) => b.absDelta - a.absDelta);
sortedByAbs.forEach((item, rank) => {{
item.rank = rank; // rank 0 = largest deviation
}});
function tunedDeltaToColor(tunedDelta, maxAbsDelta, exponent) {{
// Normalize to [-1, 1]
const normalized = Math.max(-1, Math.min(1, tunedDelta / maxAbsDelta));
let r, g, b;
if (normalized < 0) {{
// Green (RWKV better)
const intensity = Math.pow(-normalized, exponent);
r = Math.round(255 * (1 - intensity * 0.85));
g = 255;
b = Math.round(255 * (1 - intensity * 0.85));
}} else {{
// Red (RWKV worse)
const intensity = Math.pow(normalized, exponent);
r = 255;
g = Math.round(255 * (1 - intensity * 0.85));
b = Math.round(255 * (1 - intensity * 0.85));
}}
return `rgb(${{r}}, ${{g}}, ${{b}})`;
}}
function updateColors(colorRangePercent) {{
// colorRangePercent: 0-100, represents the proportion of tokens to color
const colorCount = Math.round(tokenData.length * colorRangePercent / 100);
// Calculate exponent: 100% -> 0.5, 0% -> 1.0
const exponent = 1 - (colorRangePercent / 100) * 0.5;
// Calculate max deviation within the colored range
let maxAbsDeltaInRange = 1e-9;
tokenData.forEach(item => {{
if (item.rank < colorCount) {{
maxAbsDeltaInRange = Math.max(maxAbsDeltaInRange, item.absDelta);
}}
}});
tokenData.forEach(item => {{
if (item.rank < colorCount) {{
// Use dynamic normalization based on colored range
item.token.style.backgroundColor = tunedDeltaToColor(item.tunedDelta, maxAbsDeltaInRange, exponent);
}} else {{
// Outside color range, white
item.token.style.backgroundColor = 'rgb(255, 255, 255)';
}}
}});
}}
slider.addEventListener('input', (e) => {{
const val = parseFloat(e.target.value);
rangeValue.textContent = val.toFixed(1) + '%';
updateColors(val);
}});
// Apply default color range on page load
updateColors(10);
</script>
</body>
</html>
"""
    return html