Spaces:
Running
Running
Commit
·
452ae9b
1
Parent(s):
04ab453
update
Browse files- visualization/html_generator.py +75 -35
visualization/html_generator.py
CHANGED
|
@@ -36,6 +36,7 @@ def get_rwkv_tokenizer():
|
|
| 36 |
if _rwkv_tokenizer is None:
|
| 37 |
from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
|
| 38 |
import os
|
|
|
|
| 39 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 40 |
vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt")
|
| 41 |
_rwkv_tokenizer = TRIE_TOKENIZER(vocab_path)
|
|
@@ -297,24 +298,28 @@ def generate_comparison_html(
|
|
| 297 |
rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
|
| 298 |
|
| 299 |
if re.search(r"\w", token_text, re.UNICODE):
|
| 300 |
-
tokens.append(
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
| 309 |
else:
|
| 310 |
-
tokens.append(
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
| 318 |
|
| 319 |
# Track word occurrences
|
| 320 |
word_occurrences = {}
|
|
@@ -335,14 +340,16 @@ def generate_comparison_html(
|
|
| 335 |
def escape_for_attr(s):
|
| 336 |
# Escape all characters that could break HTML attributes
|
| 337 |
# Order matters: & must be escaped first
|
| 338 |
-
return (
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
| 346 |
|
| 347 |
for token in tokens:
|
| 348 |
token_text = token["text"]
|
|
@@ -382,7 +389,8 @@ def generate_comparison_html(
|
|
| 382 |
]
|
| 383 |
# Use base64 encoding to avoid escaping issues
|
| 384 |
import base64
|
| 385 |
-
|
|
|
|
| 386 |
except Exception as e:
|
| 387 |
pass
|
| 388 |
if topk_predictions_b is not None and model_b_token_ranges:
|
|
@@ -393,7 +401,8 @@ def generate_comparison_html(
|
|
| 393 |
decoded_pred = [pred[0], pred[1], [[tid, prob, decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]]
|
| 394 |
# Use base64 encoding to avoid escaping issues
|
| 395 |
import base64
|
| 396 |
-
|
|
|
|
| 397 |
except Exception as e:
|
| 398 |
pass
|
| 399 |
|
|
@@ -402,7 +411,13 @@ def generate_comparison_html(
|
|
| 402 |
token_deltas = deltas[byte_start:byte_end]
|
| 403 |
avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
|
| 404 |
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
r, g, b = color
|
| 407 |
|
| 408 |
token_html_parts = []
|
|
@@ -857,6 +872,7 @@ def generate_comparison_html(
|
|
| 857 |
const avgDelta = {avg_delta_compression};
|
| 858 |
const slider = document.getElementById('saturation-slider');
|
| 859 |
const saturationValue = document.getElementById('saturation-value');
|
|
|
|
| 860 |
|
| 861 |
const allDeltas = [];
|
| 862 |
tokenSpans.forEach(token => {{
|
|
@@ -864,6 +880,12 @@ def generate_comparison_html(
|
|
| 864 |
if (!isNaN(delta)) allDeltas.push(delta);
|
| 865 |
}});
|
| 866 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
function percentile(arr, p) {{
|
| 868 |
const sorted = [...arr].sort((a, b) => a - b);
|
| 869 |
const idx = (p / 100) * (sorted.length - 1);
|
|
@@ -873,9 +895,9 @@ def generate_comparison_html(
|
|
| 873 |
return sorted[lower] + (sorted[upper] - sorted[lower]) * (idx - lower);
|
| 874 |
}}
|
| 875 |
|
| 876 |
-
function deltaToColor(
|
| 877 |
if (maxDeviation === 0) return 'rgb(255, 255, 255)';
|
| 878 |
-
const deviation =
|
| 879 |
let normalized = Math.max(-1, Math.min(1, deviation / maxDeviation));
|
| 880 |
let r, g, b;
|
| 881 |
if (normalized < 0) {{
|
|
@@ -892,13 +914,31 @@ def generate_comparison_html(
|
|
| 892 |
return `rgb(${{r}}, ${{g}}, ${{b}})`;
|
| 893 |
}}
|
| 894 |
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
const delta = parseFloat(token.getAttribute('data-delta'));
|
| 900 |
if (!isNaN(delta)) {{
|
| 901 |
-
|
|
|
|
| 902 |
}}
|
| 903 |
}});
|
| 904 |
}}
|
|
|
|
| 36 |
if _rwkv_tokenizer is None:
|
| 37 |
from rwkv.rwkv_tokenizer import TRIE_TOKENIZER
|
| 38 |
import os
|
| 39 |
+
|
| 40 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 41 |
vocab_path = os.path.join(os.path.dirname(script_dir), "support", "rwkv_vocab_v20230424.txt")
|
| 42 |
_rwkv_tokenizer = TRIE_TOKENIZER(vocab_path)
|
|
|
|
| 298 |
rwkv_toks = get_tokens_for_range(start_byte, end_byte, rwkv_tokens)
|
| 299 |
|
| 300 |
if re.search(r"\w", token_text, re.UNICODE):
|
| 301 |
+
tokens.append(
|
| 302 |
+
{
|
| 303 |
+
"type": "word",
|
| 304 |
+
"text": token_text,
|
| 305 |
+
"byte_start": start_byte,
|
| 306 |
+
"byte_end": end_byte,
|
| 307 |
+
"word_lower": token_text.lower(),
|
| 308 |
+
"qwen_tokens": qwen_toks,
|
| 309 |
+
"rwkv_tokens": rwkv_toks,
|
| 310 |
+
}
|
| 311 |
+
)
|
| 312 |
else:
|
| 313 |
+
tokens.append(
|
| 314 |
+
{
|
| 315 |
+
"type": "non-word",
|
| 316 |
+
"text": token_text,
|
| 317 |
+
"byte_start": start_byte,
|
| 318 |
+
"byte_end": end_byte,
|
| 319 |
+
"qwen_tokens": qwen_toks,
|
| 320 |
+
"rwkv_tokens": rwkv_toks,
|
| 321 |
+
}
|
| 322 |
+
)
|
| 323 |
|
| 324 |
# Track word occurrences
|
| 325 |
word_occurrences = {}
|
|
|
|
| 340 |
def escape_for_attr(s):
|
| 341 |
# Escape all characters that could break HTML attributes
|
| 342 |
# Order matters: & must be escaped first
|
| 343 |
+
return (
|
| 344 |
+
s.replace("&", "&")
|
| 345 |
+
.replace('"', """)
|
| 346 |
+
.replace("'", "'")
|
| 347 |
+
.replace("<", "<")
|
| 348 |
+
.replace(">", ">")
|
| 349 |
+
.replace("\n", " ")
|
| 350 |
+
.replace("\r", " ")
|
| 351 |
+
.replace("\t", "	")
|
| 352 |
+
)
|
| 353 |
|
| 354 |
for token in tokens:
|
| 355 |
token_text = token["text"]
|
|
|
|
| 389 |
]
|
| 390 |
# Use base64 encoding to avoid escaping issues
|
| 391 |
import base64
|
| 392 |
+
|
| 393 |
+
topk_a_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
|
| 394 |
except Exception as e:
|
| 395 |
pass
|
| 396 |
if topk_predictions_b is not None and model_b_token_ranges:
|
|
|
|
| 401 |
decoded_pred = [pred[0], pred[1], [[tid, prob, decode_token(tid, tokenizer_b, model_type_b)] for tid, prob in pred[2]]]
|
| 402 |
# Use base64 encoding to avoid escaping issues
|
| 403 |
import base64
|
| 404 |
+
|
| 405 |
+
topk_b_json = base64.b64encode(json.dumps(decoded_pred, ensure_ascii=False).encode("utf-8")).decode("ascii")
|
| 406 |
except Exception as e:
|
| 407 |
pass
|
| 408 |
|
|
|
|
| 411 |
token_deltas = deltas[byte_start:byte_end]
|
| 412 |
avg_token_delta = sum(token_deltas) / len(token_deltas) if token_deltas else 0
|
| 413 |
|
| 414 |
+
# Apply power transformation to enhance color differentiation
|
| 415 |
+
# Preserve the sign and apply power to the absolute value
|
| 416 |
+
power_n = 3
|
| 417 |
+
sign = 1 if avg_token_delta >= 0 else -1
|
| 418 |
+
avg_token_delta_powered = sign * (abs(avg_token_delta) ** power_n)
|
| 419 |
+
|
| 420 |
+
color = delta_to_color(avg_token_delta_powered, avg_delta, max_deviation)
|
| 421 |
r, g, b = color
|
| 422 |
|
| 423 |
token_html_parts = []
|
|
|
|
| 872 |
const avgDelta = {avg_delta_compression};
|
| 873 |
const slider = document.getElementById('saturation-slider');
|
| 874 |
const saturationValue = document.getElementById('saturation-value');
|
| 875 |
+
const powerN = 3; // Must match Python's power_n
|
| 876 |
|
| 877 |
const allDeltas = [];
|
| 878 |
tokenSpans.forEach(token => {{
|
|
|
|
| 880 |
if (!isNaN(delta)) allDeltas.push(delta);
|
| 881 |
}});
|
| 882 |
|
| 883 |
+
// Apply power transformation to delta values (matching Python's logic)
|
| 884 |
+
function applyPower(delta) {{
|
| 885 |
+
const sign = delta >= 0 ? 1 : -1;
|
| 886 |
+
return sign * Math.pow(Math.abs(delta), powerN);
|
| 887 |
+
}}
|
| 888 |
+
|
| 889 |
function percentile(arr, p) {{
|
| 890 |
const sorted = [...arr].sort((a, b) => a - b);
|
| 891 |
const idx = (p / 100) * (sorted.length - 1);
|
|
|
|
| 895 |
return sorted[lower] + (sorted[upper] - sorted[lower]) * (idx - lower);
|
| 896 |
}}
|
| 897 |
|
| 898 |
+
function deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation) {{
|
| 899 |
if (maxDeviation === 0) return 'rgb(255, 255, 255)';
|
| 900 |
+
const deviation = deltaPowered - avgDeltaPowered;
|
| 901 |
let normalized = Math.max(-1, Math.min(1, deviation / maxDeviation));
|
| 902 |
let r, g, b;
|
| 903 |
if (normalized < 0) {{
|
|
|
|
| 914 |
return `rgb(${{r}}, ${{g}}, ${{b}})`;
|
| 915 |
}}
|
| 916 |
|
| 917 |
+
// Pre-compute powered deltas and avgDeltaPowered
|
| 918 |
+
const allDeltasPowered = allDeltas.map(d => applyPower(d));
|
| 919 |
+
const avgDeltaPowered = applyPower(avgDelta);
|
| 920 |
+
|
| 921 |
+
// Pre-compute min and max deviations for logarithmic interpolation
|
| 922 |
+
const allDeviations = allDeltasPowered.map(d => Math.abs(d - avgDeltaPowered));
|
| 923 |
+
const minDeviation = Math.max(percentile(allDeviations, 1), 1e-9); // Use 1st percentile to avoid extreme outliers
|
| 924 |
+
const maxDeviationFull = Math.max(percentile(allDeviations, 100), 1e-6);
|
| 925 |
+
|
| 926 |
+
function updateColors(sliderValue) {{
|
| 927 |
+
// Use logarithmic interpolation for smoother perceptual control
|
| 928 |
+
// sliderValue: 50-100, maps to maxDeviation from minDeviation to maxDeviationFull
|
| 929 |
+
// Lower slider value = lower maxDeviation = more saturation (more colors hit the clamp)
|
| 930 |
+
// Higher slider value = higher maxDeviation = less saturation (fewer colors hit the clamp)
|
| 931 |
+
const t = (sliderValue - 50) / 50; // Normalize to 0-1
|
| 932 |
+
// Logarithmic interpolation: exp(lerp(log(min), log(max), t))
|
| 933 |
+
const logMin = Math.log(minDeviation);
|
| 934 |
+
const logMax = Math.log(maxDeviationFull);
|
| 935 |
+
const maxDeviation = Math.exp(logMin + t * (logMax - logMin));
|
| 936 |
+
|
| 937 |
+
tokenSpans.forEach((token, idx) => {{
|
| 938 |
const delta = parseFloat(token.getAttribute('data-delta'));
|
| 939 |
if (!isNaN(delta)) {{
|
| 940 |
+
const deltaPowered = applyPower(delta);
|
| 941 |
+
token.style.backgroundColor = deltaToColor(deltaPowered, avgDeltaPowered, maxDeviation);
|
| 942 |
}}
|
| 943 |
}});
|
| 944 |
}}
|