Spaces:
Running
Running
| """ | |
| HTML visualization generator for UncheatableEval. | |
| Generates interactive HTML visualizations comparing byte-level losses between two models. | |
| """ | |
| import json | |
| import math | |
| import re | |
| from typing import List, Tuple, Optional, Set | |
| import numpy as np | |
| from core.helpers import TokenizerBytesConverter | |
# Factor converting a per-byte loss in nats into a compression-rate
# percentage: 1/ln(2) turns nats into bits, 0.125 turns bits per byte into
# bytes per byte, and 100 expresses the result as a percent.
COMPRESSION_RATE_FACTOR = (1.0 / math.log(2.0)) * 0.125 * 100.0
# Module-level tokenizer singletons, created lazily by the get_* accessors
# below so importing this module stays cheap.
_qwen_tokenizer = None
_rwkv_tokenizer = None
def get_qwen_tokenizer():
    """Return the shared Qwen tokenizer, constructing it on first use."""
    global _qwen_tokenizer
    if _qwen_tokenizer is not None:
        return _qwen_tokenizer
    _qwen_tokenizer = TokenizerBytesConverter("Qwen/Qwen3-0.6B-Base")
    return _qwen_tokenizer
def get_rwkv_tokenizer():
    """Return the shared RWKV tokenizer, building it on first access.

    The vocab file is resolved relative to this module:
    <parent-of-this-dir>/support/rwkv_vocab_v20230424.txt
    """
    global _rwkv_tokenizer
    if _rwkv_tokenizer is not None:
        return _rwkv_tokenizer
    from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
    import os
    this_dir = os.path.dirname(os.path.abspath(__file__))
    vocab_file = os.path.join(os.path.dirname(this_dir), "support", "rwkv_vocab_v20230424.txt")
    _rwkv_tokenizer = TRIE_TOKENIZER(vocab_file)
    return _rwkv_tokenizer
def get_tokenizer_boundaries(text: str, tokenizer, is_rwkv: bool = False) -> Set[int]:
    """Return the set of byte offsets where tokens start/end, including 0.

    For RWKV-style tokenizers (``is_rwkv=True``) the byte length of each
    token is recovered via ``decodeBytes``; otherwise the tokenizer is
    expected to expose ``encode_to_bytes`` returning per-token byte strings.
    """
    positions = {0}
    if is_rwkv:
        encoded = tokenizer.encode(text)
        # Some tokenizers return an object carrying `.ids`, others a plain list.
        ids = encoded.ids if hasattr(encoded, "ids") else encoded
        offset = 0
        for tid in ids:
            offset += len(tokenizer.decodeBytes([tid]))
            positions.add(offset)
    else:
        offset = 0
        for piece in tokenizer.encode_to_bytes(text):
            offset += len(piece)
            positions.add(offset)
    return positions
def get_token_info_for_text(text: str) -> dict:
    """Tokenize *text* with both tokenizers and report per-token byte spans.

    Returns a dict containing:
      - "common_boundaries": sorted byte offsets shared by both tokenizations
      - "qwen_tokens"/"rwkv_tokens": lists of (start, end, token_id, token_str)
      - "byte_to_qwen"/"byte_to_rwkv": start byte offset -> token index maps
    """
    qwen_tok = get_qwen_tokenizer()
    rwkv_tok = get_rwkv_tokenizer()

    # Qwen tokens: keep both the vocab id and the decoded bytes so the
    # tooltip can show true token ids.
    qwen_tokens = []
    byte_to_qwen = {}
    cursor = 0
    for position, (tok_id, tok_bytes) in enumerate(qwen_tok.encode_to_ids_and_bytes(text)):
        span_start, span_end = cursor, cursor + len(tok_bytes)
        try:
            shown = bytes(tok_bytes).decode("utf-8")
        except UnicodeDecodeError:
            shown = repr(bytes(tok_bytes))
        qwen_tokens.append((span_start, span_end, tok_id, shown))
        byte_to_qwen[span_start] = position
        cursor = span_end

    # RWKV tokens: recover each token's byte length via decodeBytes.
    rwkv_tokens = []
    byte_to_rwkv = {}
    encoded = rwkv_tok.encode(text)
    ids = encoded.ids if hasattr(encoded, "ids") else encoded
    cursor = 0
    for position, tok_id in enumerate(ids):
        raw = rwkv_tok.decodeBytes([tok_id])
        span_start, span_end = cursor, cursor + len(raw)
        try:
            shown = raw.decode("utf-8")
        except UnicodeDecodeError:
            shown = repr(raw)
        rwkv_tokens.append((span_start, span_end, tok_id, shown))
        byte_to_rwkv[span_start] = position
        cursor = span_end

    # Boundaries present in BOTH tokenizations (offset 0 always included).
    qwen_edges = set([0] + [t[1] for t in qwen_tokens])
    rwkv_edges = set([0] + [t[1] for t in rwkv_tokens])
    return {
        "common_boundaries": sorted(qwen_edges & rwkv_edges),
        "qwen_tokens": qwen_tokens,
        "rwkv_tokens": rwkv_tokens,
        "byte_to_qwen": byte_to_qwen,
        "byte_to_rwkv": byte_to_rwkv,
    }
def generate_comparison_html(
    text: str,
    byte_losses_a: List[float],
    byte_losses_b: List[float],
    model_a_name: str,
    model_b_name: str,
    topk_predictions_a: Optional[List] = None,
    topk_predictions_b: Optional[List] = None,
    tokenizer_a=None,
    tokenizer_b=None,
    model_type_a: str = "hf",
    model_type_b: str = "rwkv7",
) -> str:
    """
    Generate an interactive HTML visualization comparing two models.

    Args:
        text: The input text that was evaluated
        byte_losses_a: Per-byte losses from model A
        byte_losses_b: Per-byte losses from model B
        model_a_name: Display name for model A
        model_b_name: Display name for model B
        topk_predictions_a: Top-k predictions from model A
        topk_predictions_b: Top-k predictions from model B
        tokenizer_a: Tokenizer for model A
        tokenizer_b: Tokenizer for model B
        model_type_a: Type of model A ("hf" or "rwkv7")
        model_type_b: Type of model B ("hf" or "rwkv7")

    Returns:
        HTML string with interactive visualization
    """
    # Hoisted: previously `import base64` executed inside the per-token loop
    # on every iteration that had top-k predictions.
    import base64

    def decode_token(token_id: int, tokenizer, model_type: str) -> str:
        """Decode a single token ID to text; fall back to "[id]" on failure."""
        if tokenizer is None:
            return f"[{token_id}]"
        try:
            # Both RWKV and HuggingFace tokenizers expose decode([ids]);
            # the previous RWKV/HF branches here were byte-identical.
            decoded = tokenizer.decode([token_id])
            return decoded if decoded else f"[{token_id}]"
        except Exception as e:
            print(f"Warning: Failed to decode token {token_id} ({model_type}): {e}")
            return f"[{token_id}]"

    def build_byte_to_token_map(text: str, tokenizer, model_type: str):
        """Build mapping from byte position to token index using the correct tokenizer.

        Returns a list of (start, end, token_idx) tuples for range-based lookup.
        Failures are logged and yield an empty map (visualization degrades
        gracefully without per-token top-k info)."""
        if tokenizer is None:
            return []
        token_ranges = []
        try:
            if model_type in ["rwkv", "rwkv7"]:
                # RWKV tokenizer: byte length of each token via decodeBytes.
                tokenized = tokenizer.encode(text)
                if hasattr(tokenized, "ids"):
                    token_ids = tokenized.ids
                else:
                    token_ids = tokenized
                byte_pos = 0
                for idx, token_id in enumerate(token_ids):
                    try:
                        token_bytes = tokenizer.decodeBytes([token_id])
                        token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
                        byte_pos += len(token_bytes)
                    except Exception as e:
                        # Best-effort: skip tokens that fail to decode.
                        print(f"Warning: Failed to decode RWKV token {token_id}: {e}")
            else:
                # HuggingFace tokenizer - use TokenizerBytesConverter to get
                # exact per-token byte spans.
                tokenizer_name = getattr(tokenizer, "name_or_path", None)
                if tokenizer_name:
                    converter = TokenizerBytesConverter(tokenizer_name, trust_remote_code=True)
                    token_bytes_list = converter.encode_to_bytes(text)
                    byte_pos = 0
                    for idx, token_bytes in enumerate(token_bytes_list):
                        token_ranges.append((byte_pos, byte_pos + len(token_bytes), idx))
                        byte_pos += len(token_bytes)
                else:
                    print(f"Warning: Could not get tokenizer name for HF model")
        except Exception as e:
            print(f"Warning: Could not build byte-to-token map ({model_type}): {e}")
            return []
        return token_ranges

    def find_token_for_byte(byte_pos: int, token_ranges):
        """Return the token index whose [start, end) range contains byte_pos."""
        for start, end, idx in token_ranges:
            if start <= byte_pos < end:
                return idx
        return None

    # Per-byte loss deltas (A - B) and overall averages.
    deltas = [a - b for a, b in zip(byte_losses_a, byte_losses_b)]
    avg_delta = sum(deltas) / len(deltas) if deltas else 0
    # Average compression rates (percent), guarding against empty inputs.
    avg_compression_a = sum(byte_losses_a) / len(byte_losses_a) * COMPRESSION_RATE_FACTOR if byte_losses_a else 0
    avg_compression_b = sum(byte_losses_b) / len(byte_losses_b) * COMPRESSION_RATE_FACTOR if byte_losses_b else 0
    avg_delta_compression = avg_delta * COMPRESSION_RATE_FACTOR

    # Token layout info from both tokenizers.
    text_bytes = text.encode("utf-8")
    token_info = get_token_info_for_text(text)
    common_boundaries = token_info["common_boundaries"]
    qwen_tokens = token_info["qwen_tokens"]
    rwkv_tokens = token_info["rwkv_tokens"]

    # Byte position -> model-specific token index maps (for top-k lookup).
    model_a_token_ranges = build_byte_to_token_map(text, tokenizer_a, model_type_a)
    model_b_token_ranges = build_byte_to_token_map(text, tokenizer_b, model_type_b)

    def get_tokens_for_range(byte_start, byte_end, token_list):
        """Return (token_id, token_str) for every token overlapping the byte range."""
        result = []
        for t_start, t_end, token_id, t_str in token_list:
            if t_start < byte_end and t_end > byte_start:
                result.append((token_id, t_str))
        return result

    # Build display segments from the common boundaries so both models'
    # tokenizations align on every segment edge.
    tokens = []
    for i in range(len(common_boundaries) - 1):
        start_byte = common_boundaries[i]
        end_byte = common_boundaries[i + 1]
        token_bytes = text_bytes[start_byte:end_byte]
        try:
            token_text = token_bytes.decode("utf-8")
        except UnicodeDecodeError:
            # Skip segments that are not valid UTF-8 on their own.
            continue
        qwen_toks = get_tokens_for_range(start_byte, end_byte, qwen_tokens)
        rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
        if re.search(r"\w", token_text, re.UNICODE):
            tokens.append(
                {
                    "type": "word",
                    "text": token_text,
                    "byte_start": start_byte,
                    "byte_end": end_byte,
                    "word_lower": token_text.lower(),
                    "qwen_tokens": qwen_toks,
                    "rwkv_tokens": rwkv_toks,
                }
            )
        else:
            tokens.append(
                {
                    "type": "non-word",
                    "text": token_text,
                    "byte_start": start_byte,
                    "byte_end": end_byte,
                    "qwen_tokens": qwen_toks,
                    "rwkv_tokens": rwkv_toks,
                }
            )

    # Track repeated words (case-insensitive) so the JS can draw link lines
    # between occurrences on hover.
    word_occurrences = {}
    word_id_counter = 0
    for i, token in enumerate(tokens):
        if token["type"] == "word":
            word_lower = token["word_lower"]
            if word_lower not in word_occurrences:
                word_occurrences[word_lower] = []
            word_occurrences[word_lower].append(i)
            token["word_id"] = word_id_counter
            word_id_counter += 1

    # Rendered <span> fragments for the content area.
    html_content = []

    def escape_for_attr(s):
        # Escape all characters that could break HTML attributes.
        # Order matters: & must be escaped first.
        return (
            s.replace("&", "&amp;")
            .replace('"', "&quot;")
            .replace("'", "&#39;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace("\n", "&#10;")
            .replace("\r", "&#13;")
            .replace("\t", "&#9;")
        )

    for token in tokens:
        token_text = token["text"]
        byte_start = token["byte_start"]
        byte_end = token["byte_end"]
        # Map the segment's first byte to each model's own token index.
        # Computed once here; previously these were recomputed for the
        # top-k lookups further down.
        model_a_token_idx = find_token_for_byte(byte_start, model_a_token_ranges)
        model_b_token_idx = find_token_for_byte(byte_start, model_b_token_ranges)
        # Build token info strings showing all tokens in this byte range
        # Model A (RWKV7) - show all tokens that overlap with this byte range
        model_a_info = ", ".join([f"[{idx}] {repr(s)}" for idx, s in token["rwkv_tokens"]])
        # Model B (Qwen3) - show all tokens that overlap with this byte range
        model_b_info = ", ".join([f"[{idx}] {repr(s)}" for idx, s in token["qwen_tokens"]])
        raw_bytes = list(text_bytes[byte_start:byte_end])
        losses_a = byte_losses_a[byte_start:byte_end]
        losses_b = byte_losses_b[byte_start:byte_end]
        bytes_str = " ".join([f"{b:02x}" for b in raw_bytes])
        compression_a_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_a])
        compression_b_str = " ".join([f"{l * COMPRESSION_RATE_FACTOR:.2f}%" for l in losses_b])
        # Average compression rate for this segment only.
        avg_compression_a_token = sum(losses_a) / len(losses_a) * COMPRESSION_RATE_FACTOR if losses_a else 0
        avg_compression_b_token = sum(losses_b) / len(losses_b) * COMPRESSION_RATE_FACTOR if losses_b else 0
        topk_a_json = ""
        topk_b_json = ""
        if topk_predictions_a is not None and model_a_token_ranges:
            if model_a_token_idx is not None and model_a_token_idx < len(topk_predictions_a):
                pred = topk_predictions_a[model_a_token_idx]
                try:
                    decoded_pred = [
                        pred[0],
                        pred[1],
                        [[tid, prob, decode_token(tid, tokenizer_a, model_type_a)] for tid, prob in pred[2]],
                    ]
                    # Base64-encode the JSON to avoid attribute escaping issues.
                    topk_a_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
                except Exception:
                    # Best-effort: a malformed prediction entry just means no
                    # top-k tooltip for this segment.
                    pass
        if topk_predictions_b is not None and model_b_token_ranges:
            if model_b_token_idx is not None and model_b_token_idx < len(topk_predictions_b):
                pred = topk_predictions_b[model_b_token_idx]
                try:
                    decoded_pred = [pred[0], pred[1], [[tid, prob, decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]]
                    # Base64-encode the JSON to avoid attribute escaping issues.
                    topk_b_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
                except Exception:
                    # Best-effort: see note above.
                    pass
        token_deltas = deltas[byte_start:byte_end]
        avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
        # Deviation of this segment's delta from the document-wide average;
        # the slider-driven JS colors segments by this value.
        tuned_delta = avg_token_delta - avg_delta
        # Initial rendering uses white; JavaScript applies colors per slider.
        r, g, b = 255, 255, 255
        token_html_parts = []
        for char in token_text:
            if char == "<":
                escaped_char = "&lt;"
            elif char == ">":
                escaped_char = "&gt;"
            elif char == "&":
                escaped_char = "&amp;"
            elif char == "\n":
                escaped_char = "<br>"
            elif char == " ":
                escaped_char = "&nbsp;"
            elif char == "\t":
                escaped_char = "&nbsp;&nbsp;&nbsp;&nbsp;"
            else:
                escaped_char = char
            token_html_parts.append(escaped_char)
        token_span_content = "".join(token_html_parts)
        data_attrs = (
            f'data-model-a="{escape_for_attr(model_a_info)}" '
            f'data-model-b="{escape_for_attr(model_b_info)}" '
            f'data-bytes="{escape_for_attr(bytes_str)}" '
            f'data-compression-a="{escape_for_attr(compression_a_str)}" '
            f'data-compression-b="{escape_for_attr(compression_b_str)}" '
            f'data-avg-compression-a="{avg_compression_a_token:.2f}" '
            f'data-avg-compression-b="{avg_compression_b_token:.2f}" '
            f'data-tuned-delta="{tuned_delta:.6f}" '
            f'data-topk-a="{escape_for_attr(topk_a_json)}" '
            f'data-topk-b="{escape_for_attr(topk_b_json)}"'
        )
        style_attr = f'style="background-color: rgb({r},{g},{b})"'
        if token["type"] == "word":
            word_lower = token["word_lower"]
            occurrences = word_occurrences[word_lower]
            if len(occurrences) > 1:
                # Repeated word: mark it so the JS can link occurrences.
                word_id = token["word_id"]
                html_content.append(
                    f'<span class="token word" {data_attrs} {style_attr} data-word="{word_lower}" data-word-id="{word_id}">'
                    + token_span_content
                    + "</span>"
                )
            else:
                html_content.append(f'<span class="token" {data_attrs} {style_attr}>{token_span_content}</span>')
        else:
            html_content.append(f'<span class="token" {data_attrs} {style_attr}>{token_span_content}</span>')

    # Green when model A (RWKV) is better on average, red otherwise.
    delta_color = "#64ff64" if avg_delta < 0 else "#ff6464"
    html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Model Comparison</title>
<style>
body {{
    font-family: Consolas, 'Courier New', monospace;
    margin: 0;
    padding: 0;
    background-color: #f5f5f5;
}}
.header {{
    background-color: #333;
    color: white;
    padding: 20px;
    position: sticky;
    top: 0;
    z-index: 100;
}}
.header h1 {{
    margin: 0 0 15px 0;
    font-size: 18px;
}}
.meta {{
    display: flex;
    flex-wrap: wrap;
    gap: 20px;
    font-size: 12px;
    color: #c8c8c8;
}}
.legend {{
    display: flex;
    gap: 15px;
    margin-top: 10px;
}}
.legend-item {{
    display: flex;
    align-items: center;
    gap: 5px;
}}
.legend-box {{
    width: 20px;
    height: 12px;
    border: 1px solid #666;
}}
.content {{
    background-color: white;
    margin: 10px;
    padding: 15px;
    border: 1px solid #ccc;
    font-size: 14px;
    line-height: 1.8;
    word-wrap: break-word;
    position: relative;
}}
.content span {{
    padding: 1px 0;
}}
.word {{
    cursor: pointer;
    position: relative;
}}
.word:hover {{
    outline: 2px solid #007bff;
    outline-offset: 1px;
}}
.word.highlighted {{
    outline: 2px solid #ff6b6b;
    outline-offset: 1px;
}}
#svg-overlay {{
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    pointer-events: none;
    z-index: 1000;
}}
.link-line {{
    stroke: #007bff;
    stroke-width: 2;
    fill: none;
    opacity: 0.7;
}}
.link-dot {{
    fill: #007bff;
    opacity: 0.8;
}}
.token {{
    position: relative;
    cursor: help;
}}
.token:hover {{
    outline: 1px dashed #666;
}}
#tooltip {{
    position: fixed;
    background-color: rgba(0, 0, 0, 0.9);
    color: white;
    padding: 10px 14px;
    border-radius: 6px;
    font-size: 12px;
    max-width: 500px;
    z-index: 2000;
    pointer-events: none;
    display: none;
    line-height: 1.6;
    box-shadow: 0 2px 10px rgba(0,0,0,0.3);
}}
#tooltip .label {{
    color: #aaa;
    font-weight: bold;
}}
#tooltip .bytes {{
    color: #a5f3fc;
    font-family: monospace;
}}
#tooltip .loss-a {{
    color: #86efac;
    font-family: monospace;
}}
#tooltip .loss-b {{
    color: #fca5a5;
    font-family: monospace;
}}
#tooltip .model-a {{
    color: #fcd34d;
}}
#tooltip .model-b {{
    color: #7dd3fc;
}}
#tooltip .topk-section {{
    margin-top: 8px;
    padding-top: 8px;
    border-top: 1px solid #555;
}}
#tooltip .topk-container {{
    display: flex;
    gap: 16px;
}}
#tooltip .topk-column {{
    flex: 1;
    min-width: 180px;
}}
#tooltip .topk-title {{
    color: #aaa;
    font-weight: bold;
    margin-bottom: 4px;
    font-size: 11px;
}}
#tooltip .topk-title.model-a {{
    color: #86efac;
}}
#tooltip .topk-title.model-b {{
    color: #fca5a5;
}}
#tooltip .topk-list {{
    font-size: 11px;
}}
#tooltip .topk-item {{
    display: flex;
    gap: 4px;
    padding: 1px 0;
    align-items: center;
}}
#tooltip .topk-rank {{
    color: #888;
    min-width: 18px;
}}
#tooltip .topk-rank.hit {{
    color: #ffd700;
}}
#tooltip .topk-token {{
    color: #a5f3fc;
    max-width: 100px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    font-family: monospace;
}}
#tooltip .topk-prob {{
    color: #86efac;
    min-width: 45px;
    text-align: right;
}}
#tooltip .topk-hit {{
    color: #22c55e;
}}
#tooltip .topk-miss {{
    color: #ef4444;
    font-style: italic;
}}
</style>
</head>
<body>
<svg id="svg-overlay"></svg>
<div id="tooltip"></div>
<div class="header">
<div class="meta">
<div>Model A: {model_a_name}</div>
<div>Model B: {model_b_name}</div>
<div>RWKV Compression: {avg_compression_a:.2f}%</div>
<div>Qwen Compression: {avg_compression_b:.2f}%</div>
<div style="color: {delta_color}">Avg Delta: {avg_delta_compression:+.2f}%</div>
</div>
<div class="legend">
<div class="legend-item">
<div class="legend-box" style="background-color: rgb(77, 255, 77)"></div>
<span>RWKV better than avg</span>
</div>
<div class="legend-item">
<div class="legend-box" style="background-color: rgb(255, 255, 255)"></div>
<span>Equal to avg</span>
</div>
<div class="legend-item">
<div class="legend-box" style="background-color: rgb(255, 77, 77)"></div>
<span>RWKV worse than avg</span>
</div>
<div class="legend-item" style="margin-left: 20px;">
<span style="color: #aaa;">Color Range:</span>
<input type="range" id="color-range-slider" min="0" max="100" value="10" step="0.1" style="width: 200px; vertical-align: middle;">
<span id="color-range-value" style="color: #fff; min-width: 45px; display: inline-block;">10%</span>
</div>
</div>
</div>
<div class="content">
{''.join(html_content)}
</div>
<script>
const svgOverlay = document.getElementById('svg-overlay');
const words = document.querySelectorAll('.word');
const wordGroups = {{}};
words.forEach(word => {{
    const wordText = word.getAttribute('data-word');
    if (!wordGroups[wordText]) {{
        wordGroups[wordText] = [];
    }}
    wordGroups[wordText].push(word);
}});
function clearLines() {{
    svgOverlay.innerHTML = '';
    words.forEach(w => w.classList.remove('highlighted'));
}}
function drawLines(hoveredWord) {{
    clearLines();
    const wordText = hoveredWord.getAttribute('data-word');
    const wordId = parseInt(hoveredWord.getAttribute('data-word-id'));
    const sameWords = wordGroups[wordText] || [];
    const previousWords = sameWords.filter(w => {{
        const id = parseInt(w.getAttribute('data-word-id'));
        return id < wordId;
    }});
    if (previousWords.length === 0) return;
    sameWords.forEach(w => w.classList.add('highlighted'));
    const hoveredRect = hoveredWord.getBoundingClientRect();
    const hoveredX = hoveredRect.left + hoveredRect.width / 2;
    const hoveredY = hoveredRect.top + hoveredRect.height / 2;
    previousWords.forEach(prevWord => {{
        const prevRect = prevWord.getBoundingClientRect();
        const prevX = prevRect.left + prevRect.width / 2;
        const prevY = prevRect.top + prevRect.height / 2;
        const midX = (hoveredX + prevX) / 2;
        const midY = Math.min(hoveredY, prevY) - 30;
        const path = document.createElementNS('http://www.w3.org/2000/svg', 'path');
        path.setAttribute('class', 'link-line');
        path.setAttribute('d', `M ${{prevX}} ${{prevY}} Q ${{midX}} ${{midY}} ${{hoveredX}} ${{hoveredY}}`);
        svgOverlay.appendChild(path);
        const dot1 = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
        dot1.setAttribute('class', 'link-dot');
        dot1.setAttribute('cx', prevX);
        dot1.setAttribute('cy', prevY);
        dot1.setAttribute('r', 4);
        svgOverlay.appendChild(dot1);
        const dot2 = document.createElementNS('http://www.w3.org/2000/svg', 'circle');
        dot2.setAttribute('class', 'link-dot');
        dot2.setAttribute('cx', hoveredX);
        dot2.setAttribute('cy', hoveredY);
        dot2.setAttribute('r', 4);
        svgOverlay.appendChild(dot2);
    }});
}}
words.forEach(word => {{
    word.addEventListener('mouseenter', () => drawLines(word));
    word.addEventListener('mouseleave', clearLines);
}});
window.addEventListener('scroll', clearLines);
const tooltip = document.getElementById('tooltip');
const tokenSpans = document.querySelectorAll('.token');
tokenSpans.forEach(token => {{
    token.addEventListener('mouseenter', (e) => {{
        const modelA = token.getAttribute('data-model-a') || 'N/A';
        const modelB = token.getAttribute('data-model-b') || 'N/A';
        const bytes = token.getAttribute('data-bytes') || '';
        const compressionA = token.getAttribute('data-compression-a') || '';
        const compressionB = token.getAttribute('data-compression-b') || '';
        const avgCompressionA = token.getAttribute('data-avg-compression-a') || '';
        const avgCompressionB = token.getAttribute('data-avg-compression-b') || '';
        const top5A = token.getAttribute('data-topk-a') || '';
        const top5B = token.getAttribute('data-topk-b') || '';
        function formatTopkColumn(topkBase64, modelName, titleClass) {{
            if (!topkBase64) return '<div class="topk-column"><div class="topk-title ' + titleClass + '">' + modelName + '</div><div class="topk-list">N/A</div></div>';
            try {{
                // Decode base64 to UTF-8 string (atob() doesn't support UTF-8, need proper decoding)
                const binaryString = atob(topkBase64);
                const bytes = new Uint8Array(binaryString.length);
                for (let i = 0; i < binaryString.length; i++) {{
                    bytes[i] = binaryString.charCodeAt(i);
                }}
                const topkJson = new TextDecoder('utf-8').decode(bytes);
                const data = JSON.parse(topkJson);
                const [actualId, rank, topkList] = data;
                let html = '<div class="topk-column">';
                html += '<div class="topk-title ' + titleClass + '">' + modelName + '</div>';
                html += '<div class="topk-list">';
                topkList.forEach((item, idx) => {{
                    const [tokenId, prob, tokenText] = item;
                    const isHit = tokenId === actualId;
                    const rankClass = isHit ? 'topk-rank hit' : 'topk-rank';
                    const displayText = tokenText || '[' + tokenId + ']';
                    const escapedText = displayText.replace(/</g, '&lt;').replace(/>/g, '&gt;');
                    html += '<div class="topk-item">';
                    html += '<span class="' + rankClass + '">' + (idx + 1) + '.</span>';
                    html += '<span class="topk-token" title="ID: ' + tokenId + '">' + escapedText + '</span>';
                    html += '<span class="topk-prob">' + (prob * 100).toFixed(1) + '%</span>';
                    if (isHit) html += '<span class="topk-hit">✓</span>';
                    html += '</div>';
                }});
                if (rank > 10) {{
                    html += '<div class="topk-item topk-miss">Actual rank: ' + rank + '</div>';
                }}
                html += '</div></div>';
                return html;
            }} catch (e) {{
                console.error('Error in formatTopkColumn for ' + modelName + ':', e);
                console.error('topkBase64:', topkBase64);
                return '<div class="topk-column"><div class="topk-title ' + titleClass + '">' + modelName + '</div><div class="topk-list">Error: ' + e.message + '</div></div>';
            }}
        }}
        let tooltipHtml = `
<div><span class="label">Bytes:</span> <span class="bytes">${{bytes || '(empty)'}}</span></div>
<div><span class="label">RWKV Compression Rate:</span> <span class="loss-a">${{compressionA || '(empty)'}}${{avgCompressionA ? ' (avg: ' + avgCompressionA + '%)' : ''}}</span></div>
<div><span class="label">Qwen Compression Rate:</span> <span class="loss-b">${{compressionB || '(empty)'}}${{avgCompressionB ? ' (avg: ' + avgCompressionB + '%)' : ''}}</span></div>
<hr style="border-color: #555; margin: 6px 0;">
<div><span class="label">RWKV:</span> <span class="model-a">${{modelA || '(empty)'}}</span></div>
<div><span class="label">Qwen:</span> <span class="model-b">${{modelB || '(empty)'}}</span></div>
`;
        if (top5A || top5B) {{
            tooltipHtml += '<div class="topk-section"><div class="topk-container">';
            tooltipHtml += formatTopkColumn(top5A, 'RWKV Top10', 'model-a');
            tooltipHtml += formatTopkColumn(top5B, 'Qwen Top10', 'model-b');
            tooltipHtml += '</div></div>';
        }}
        tooltip.innerHTML = tooltipHtml;
        tooltip.style.display = 'block';
    }});
    token.addEventListener('mousemove', (e) => {{
        const tooltipRect = tooltip.getBoundingClientRect();
        const viewportWidth = window.innerWidth;
        const viewportHeight = window.innerHeight;
        let x = e.clientX + 15;
        let y = e.clientY + 15;
        if (x + tooltipRect.width > viewportWidth - 10) {{
            x = e.clientX - tooltipRect.width - 15;
        }}
        if (y + tooltipRect.height > viewportHeight - 10) {{
            y = e.clientY - tooltipRect.height - 15;
        }}
        if (x < 10) x = 10;
        if (y < 10) y = 10;
        tooltip.style.left = x + 'px';
        tooltip.style.top = y + 'px';
    }});
    token.addEventListener('mouseleave', () => {{
        tooltip.style.display = 'none';
    }});
}});
const slider = document.getElementById('color-range-slider');
const rangeValue = document.getElementById('color-range-value');
// Collect all tuned_delta values
const tokenData = [];
tokenSpans.forEach((token, idx) => {{
    const tunedDelta = parseFloat(token.getAttribute('data-tuned-delta'));
    if (!isNaN(tunedDelta)) {{
        tokenData.push({{ token, tunedDelta, absDelta: Math.abs(tunedDelta) }});
    }}
}});
// Calculate max_abs_tuned_delta for normalization
const maxAbsDelta = Math.max(...tokenData.map(d => d.absDelta), 1e-9);
// Sort by |tuned_delta| to get rankings
const sortedByAbs = [...tokenData].sort((a, b) => b.absDelta - a.absDelta);
sortedByAbs.forEach((item, rank) => {{
    item.rank = rank; // rank 0 = largest deviation
}});
function tunedDeltaToColor(tunedDelta, maxAbsDelta, exponent) {{
    // Normalize to [-1, 1]
    const normalized = Math.max(-1, Math.min(1, tunedDelta / maxAbsDelta));
    let r, g, b;
    if (normalized < 0) {{
        // Green (RWKV better)
        const intensity = Math.pow(-normalized, exponent);
        r = Math.round(255 * (1 - intensity * 0.85));
        g = 255;
        b = Math.round(255 * (1 - intensity * 0.85));
    }} else {{
        // Red (RWKV worse)
        const intensity = Math.pow(normalized, exponent);
        r = 255;
        g = Math.round(255 * (1 - intensity * 0.85));
        b = Math.round(255 * (1 - intensity * 0.85));
    }}
    return `rgb(${{r}}, ${{g}}, ${{b}})`;
}}
function updateColors(colorRangePercent) {{
    // colorRangePercent: 0-100, represents the proportion of tokens to color
    const colorCount = Math.round(tokenData.length * colorRangePercent / 100);
    // Calculate exponent: 100% -> 0.5, 0% -> 1.0
    const exponent = 1 - (colorRangePercent / 100) * 0.5;
    // Calculate max deviation within the colored range
    let maxAbsDeltaInRange = 1e-9;
    tokenData.forEach(item => {{
        if (item.rank < colorCount) {{
            maxAbsDeltaInRange = Math.max(maxAbsDeltaInRange, item.absDelta);
        }}
    }});
    tokenData.forEach(item => {{
        if (item.rank < colorCount) {{
            // Use dynamic normalization based on colored range
            item.token.style.backgroundColor = tunedDeltaToColor(item.tunedDelta, maxAbsDeltaInRange, exponent);
        }} else {{
            // Outside color range, white
            item.token.style.backgroundColor = 'rgb(255, 255, 255)';
        }}
    }});
}}
slider.addEventListener('input', (e) => {{
    const val = parseFloat(e.target.value);
    rangeValue.textContent = val.toFixed(1) + '%';
    updateColors(val);
}});
// Apply default color range on page load
updateColors(10);
</script>
</body>
</html>
"""
    return html