"""
HTML visualization generator for UncheatableEval.
Generates interactive HTML visualizations comparing byte-level losses between two models.
"""
import base64
import json
import math
import re
from typing import List, Tuple, Optional, Set
import numpy as np
from core.helpers import TokenizerBytesConverter
# Converts a per-byte loss in nats to a compressed-size percentage:
# nats -> bits via 1/ln(2), bits -> bytes via 0.125, fraction -> percent via 100.
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0
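# Worked example: a per-byte loss of ln(2) ≈ 0.693 nats (exactly 1 bit per byte)
# gives 0.693 * COMPRESSION_RATE_FACTOR ≈ 12.5%, i.e. the text is compressed to
# one eighth of its original size.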
# Global tokenizers (lazy loaded)
_qwen_tokenizer = None
_rwkv_tokenizer = None
def get_qwen_tokenizer():
"""Lazy load Qwen tokenizer."""
global _qwen_tokenizer
if _qwen_tokenizer is None:
_qwen_tokenizer = TokenizerBytesConverter("Qwen/Qwen3-0.6B-Base")
return _qwen_tokenizer
def get_rwkv_tokenizer():
"""Lazy load RWKV tokenizer."""
global _rwkv_tokenizer
if _rwkv_tokenizer is None:
from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
import os
script_dir = os.path.dirname(os.path.abspath(__file__))
vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt")
_rwkv_tokenizer = TRIE_TOKENIZER(vocab_path)
return _rwkv_tokenizer
def get_tokenizer_boundaries(text: str, tokenizer, is_rwkv: bool = False) -> Set[int]:
"""Get token boundaries (byte positions) for a given text."""
boundaries = set()
boundaries.add(0)
if is_rwkv:
tokenized = tokenizer.encode(text)
if hasattr(tokenized, "ids"):
token_ids = tokenized.ids
else:
token_ids = tokenized
byte_pos = 0
for token_id in token_ids:
token_bytes = tokenizer.decodeBytes([token_id])
byte_pos += len(token_bytes)
boundaries.add(byte_pos)
else:
token_bytes_list = tokenizer.encode_to_bytes(text)
byte_pos = 0
for token_bytes in token_bytes_list:
byte_pos += len(token_bytes)
boundaries.add(byte_pos)
return boundaries
def get_token_info_for_text(text: str) -> dict:
"""Get detailed token information for each byte position."""
qwen_tokenizer = get_qwen_tokenizer()
rwkv_tokenizer = get_rwkv_tokenizer()
# Get Qwen tokens with positions
qwen_tokens = []
byte_to_qwen = {}
# Keep both token id (vocab id) and decoded bytes so the tooltip can show true token ids.
qwen_id_and_bytes = qwen_tokenizer.encode_to_ids_and_bytes(text)
byte_pos = 0
for idx, (token_id, token_bytes) in enumerate(qwen_id_and_bytes):
start = byte_pos
end = byte_pos + len(token_bytes)
try:
token_str = bytes(token_bytes).decode("utf-8")
except UnicodeDecodeError:
token_str = repr(bytes(token_bytes))
qwen_tokens.append((start, end, token_id, token_str))
byte_to_qwen[start] = idx
byte_pos = end
# Get RWKV tokens with positions
rwkv_tokens = []
byte_to_rwkv = {}
tokenized = rwkv_tokenizer.encode(text)
if hasattr(tokenized, "ids"):
token_ids = tokenized.ids
else:
token_ids = tokenized
byte_pos = 0
for idx, token_id in enumerate(token_ids):
token_bytes = rwkv_tokenizer.decodeBytes([token_id])
start = byte_pos
end = byte_pos + len(token_bytes)
try:
token_str = token_bytes.decode("utf-8")
except UnicodeDecodeError:
token_str = repr(token_bytes)
rwkv_tokens.append((start, end, token_id, token_str))
byte_to_rwkv[start] = idx
byte_pos = end
# Get common boundaries
qwen_boundaries = set([0] + [t[1] for t in qwen_tokens])
rwkv_boundaries = set([0] + [t[1] for t in rwkv_tokens])
common_boundaries = sorted(qwen_boundaries & rwkv_boundaries)
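# Byte positions present in both sets are points where the two tokenizations
# align, so per-byte losses can later be grouped into spans both models agree on.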
return {
"common_boundaries": common_boundaries,
"qwen_tokens": qwen_tokens,
"rwkv_tokens": rwkv_tokens,
"byte_to_qwen": byte_to_qwen,
"byte_to_rwkv": byte_to_rwkv,
}
def generate_comparison_html(
text: str,
byte_losses_a: List[float],
byte_losses_b: List[float],
model_a_name: str,
model_b_name: str,
topk_predictions_a: Optional[List] = None,
topk_predictions_b: Optional[List] = None,
tokenizer_a=None,
tokenizer_b=None,
model_type_a: str = "hf",
model_type_b: str = "rwkv7",
) -> str:
"""
Generate an interactive HTML visualization comparing two models.
Args:
text: The input text that was evaluated
byte_losses_a: Per-byte losses from model A
byte_losses_b: Per-byte losses from model B
model_a_name: Display name for model A
model_b_name: Display name for model B
topk_predictions_a: Top-k predictions from model A
topk_predictions_b: Top-k predictions from model B
tokenizer_a: Tokenizer for model A
tokenizer_b: Tokenizer for model B
model_type_a: Type of model A ("hf" or "rwkv7")
model_type_b: Type of model B ("hf" or "rwkv7")
Returns:
HTML string with interactive visualization
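Example (illustrative call; model names and loss values are placeholders):
html = generate_comparison_html(
text="hello world",
byte_losses_a=[0.9] * 11,  # one loss per UTF-8 byte of `text`
byte_losses_b=[1.1] * 11,
model_a_name="model-a",
model_b_name="model-b",
)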
"""
def decode_token(token_id: int, tokenizer, model_type: str) -> str:
"""Decode a single token ID to text using the appropriate tokenizer."""
if tokenizer is None:
return f"[{token_id}]"
try:
if model_type in ["rwkv", "rwkv7"]:
# RWKV tokenizer uses decode method
decoded = tokenizer.decode([token_id])
return decoded if decoded else f"[{token_id}]"
else:
# HuggingFace tokenizer
decoded = tokenizer.decode([token_id])
return decoded if decoded else f"[{token_id}]"
except Exception as e:
print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}")
return f"[{token_id}]"
def build_byte_to_token_map(text: str, tokenizer, model_type: str):
"""Build mapping from byte position to token index using the correct tokenizer.
Returns a list of (start, end, token_idx) tuples for range-based lookup."""
if tokenizer is None:
return []
token_ranges = []
try:
if model_type in ["rwkv", "rwkv7"]:
# RWKV tokenizer
tokenized = tokenizer.encode(text)
if hasattr(tokenized, "ids"):
token_ids = tokenized.ids
else:
token_ids = tokenized
byte_pos = 0
for idx, token_id in enumerate(token_ids):
try:
token_bytes = tokenizer.decodeBytes([token_id])
token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
byte_pos += len(token_bytes)
except Exception as e:
print(f"Warning: Failed to decode RWKV token {token_id}: {e}")
else:
# HuggingFace tokenizer - use TokenizerBytesConverter
tokenizer_name = getattr(tokenizer, "name_or_path", None)
if tokenizer_name:
converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
token_bytes_list = converter.encode_to_bytes(text)
byte_pos = 0
for idx, token_bytes in enumerate(token_bytes_list):
token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
byte_pos += len(token_bytes)
else:
print(f"Warning: Could not get tokenizer name for HF model")
except Exception as e:
print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
return []
return token_ranges
def find_token_for_byte(byte_pos: int, token_ranges):
for start, end, idx in token_ranges:
if start <= byte_pos < end:
return idx
return None
# Calculate deltas
deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
avg_delta = sum(deltas) / len(deltas) if deltas else 0
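# A positive delta on a byte means model A's loss is higher (worse) than model B's there.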
# Calculate average compression rates
avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
avg_delta_compression = avg_delta * COMPRESSION_RATE_FACTOR
# Get token info
text_bytes = text.encode("utf-8")
token_info = get_token_info_for_text(text)
common_boundaries = token_info["common_boundaries"]
qwen_tokens = token_info["qwen_tokens"]
rwkv_tokens = token_info["rwkv_tokens"]
# Build byte position to token index mapping
model_a_token_ranges = build_byte_to_token_map(text, tokenizer_a, model_type_a)
model_b_token_ranges = build_byte_to_token_map(text, tokenizer_b, model_type_b)
def get_tokens_for_range(byte_start, byte_end, token_list):
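"""Return (token_id, token_str) pairs whose byte range overlaps [byte_start, byte_end)."""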
result = []
for t_start, t_end, token_id, t_str in token_list:
if t_start < byte_end and t_end > byte_start:
result.append((token_id, t_str))
return result
# Build tokens based on common boundaries
tokens = []
token_count = 0
for i in range(len(common_boundaries) - 1):
start_byte = common_boundaries[i]
end_byte = common_boundaries[i + 1]
token_bytes = text_bytes[start_byte:end_byte]
try:
token_text = token_bytes.decode("utf-8")
except UnicodeDecodeError:
continue
qwen_toks = get_tokens_for_range(start_byte, end_byte, qwen_tokens)
rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
if re.search(r"\w", token_text, re.UNICODE):
tokens.append(
{
"type": "word",
"text": token_text,
"byte_start": start_byte,
"byte_end": end_byte,
"word_lower": token_text.lower(),
"qwen_tokens": qwen_toks,
"rwkv_tokens": rwkv_toks,
}
)
else:
tokens.append(
{
"type": "non-word",
"text": token_text,
"byte_start": start_byte,
"byte_end": end_byte,
"qwen_tokens": qwen_toks,
"rwkv_tokens": rwkv_toks,
}
)
# Track word occurrences
word_occurrences = {}
word_id_counter = 0
for i, token in enumerate(tokens):
if token["type"] == "word":
word_lower = token["word_lower"]
if word_lower not in word_occurrences:
word_occurrences[word_lower] = []
word_occurrences[word_lower].append(i)
token["word_id"] = word_id_counter
word_id_counter += 1
# Build HTML content
html_content = []
def escape_for_attr(s):
# Escape all characters that could break HTML attributes
# Order matters: & must be escaped first
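# e.g. escape_for_attr('say "hi" & <bye>') -> 'say &quot;hi&quot; &amp; &lt;bye&gt;'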
return (
s.replace("&", "&")
.replace('"', """)
.replace("'", "'")
.replace("<", "<")
.replace(">", ">")
.replace("\n", "
")
.replace("\r", "
")
.replace("\t", " ")
)
for token in tokens:
token_text = token["text"]
byte_start = token["byte_start"]
byte_end = token["byte_end"]
# Get actual model token IDs for this byte range
model_a_token_idx = find_token_for_byte(byte_start, model_a_token_ranges)
model_b_token_idx = find_token_for_byte(byte_start, model_b_token_ranges)
# Build token info strings showing all tokens that overlap this byte range.
# Model A is paired with the RWKV tokenization and model B with the Qwen
# tokenization; each entry is a (token_id, token_str) pair.
model_a_info = ", ".join([f"[{tid}] {repr(s)}" for tid, s in token["rwkv_tokens"]])
model_b_info = ", ".join([f"[{tid}] {repr(s)}" for tid, s in token["qwen_tokens"]])
raw_bytes = list(text_bytes[byte_start:byte_end])
losses_a = byte_losses_a[byte_start:byte_end]
losses_b = byte_losses_b[byte_start:byte_end]
bytes_str = " ".join([f"{b:02x}" for b in raw_bytes])
compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])
# Calculate average compression rate for this token
avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0
avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0
topk_a_json = ""
topk_b_json = ""
if topk_predictions_a is not None and model_a_token_ranges:
if model_a_token_idx is not None and model_a_token_idx < len(topk_predictions_a):
pred = topk_predictions_a[model_a_token_idx]
try:
decoded_pred = [
pred[0],
pred[1],
[[tid, prob, decode_token(tid, tokenizer_a, model_type_a)] for tid, prob in pred[2]],
]
# Base64-encode the JSON payload to avoid HTML attribute escaping issues.
topk_a_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
except Exception:
pass
if topk_predictions_b is not None and model_b_token_ranges:
if model_b_token_idx is not None and model_b_token_idx < len(topk_predictions_b):
pred = topk_predictions_b[model_b_token_idx]
try:
decoded_pred = [pred[0], pred[1], [[tid, prob, decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]]
# Base64-encode the JSON payload to avoid HTML attribute escaping issues.
topk_b_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
except Exception:
pass
token_count += 1
token_deltas = deltas[byte_start:byte_end]
avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
tuned_delta = avg_token_delta - avg_delta
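# Centering the token's delta on the document-wide average lets the coloring
# (applied later in JavaScript) highlight tokens that deviate from the overall gap.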
# Initial rendering uses white color, JavaScript will apply colors based on slider
r, g, b = 255, 255, 255
token_html_parts = []
for char in token_text:
# Escape characters that would break the HTML, and make whitespace visible
# (<br> for newlines, non-breaking spaces for spaces and tabs).
if char == "<":
escaped_char = "&lt;"
elif char == ">":
escaped_char = "&gt;"
elif char == "&":
escaped_char = "&amp;"
elif char == "\n":
escaped_char = "<br>"
elif char == " ":
escaped_char = "&nbsp;"
elif char == "\t":
escaped_char = "&nbsp;&nbsp;&nbsp;&nbsp;"
else:
escaped_char = char
token_html_parts.append(escaped_char)
token_span_content = "".join(token_html_parts)
data_attrs = (
f'data-model-a="{escape_for_attr(model_a_info)}" '
f'data-model-b="{escape_for_attr(model_b_info)}" '
f'data-bytes="{escape_for_attr(bytes_str)}" '
f'data-compression-a="{escape_for_attr(compression_a_str)}" '
f'data-compression-b="{escape_for_attr(compression_b_str)}" '
f'data-avg-compression-a="{avg_compression_a_token:.2f}" '
f'data-avg-compression-b="{avg_compression_b_token:.2f}" '
f'data-tuned-delta="{tuned_delta:.6f}" '
f'data-topk-a="{escape_for_attr(topk_a_json)}" '
f'data-topk-b="{escape_for_attr(topk_b_json)}"'
)
style_attr = f'style="background-color: rgb({r},{g},{b})"'
if token["type"] == "word":
word_lower = token["word_lower"]
occurrences = word_occurrences[word_lower]
if len(occurrences) > 1:
word_id = token["word_id"]
# Repeated words carry a data-word-id attribute; the span class names must
# match the selectors used in the HTML template below.
html_content.append(
f'<span class="token word" data-word-id="{word_id}" {data_attrs} {style_attr}>'
+ token_span_content
+ "</span>"
)
else:
html_content.append(f'<span class="token" {data_attrs} {style_attr}>{token_span_content}</span>')
else:
html_content.append(f'<span class="token" {data_attrs} {style_attr}>{token_span_content}</span>')
delta_color = "#64ff64" if avg_delta < 0 else "#ff6464"
html = f"""