File size: 4,530 Bytes
352de18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Build HTML for completion view with per-token probability styling."""

from __future__ import annotations

import base64
import html
import json
import math
from typing import Any


def _lerp_byte(low: int, high: int, weight: float) -> int:
    """Blend ``low`` toward ``high`` by ``weight`` and round to the nearest integer."""
    span = high - low
    blended = low + span * weight
    return int(round(blended))


# Lighter range of the sequential "Blues" colormap (matplotlib / ColorBrewer).
# The darker end is omitted so black text drawn over the tint keeps its contrast.
# Each stop is decoded from its hex code into an (red, green, blue) int triple.
_SEQUENTIAL_BLUES_STOPS: list[tuple[int, int, int]] = [
    (int(code[1:3], 16), int(code[3:5], 16), int(code[5:7], 16))
    for code in ("#f7fbff", "#deebf7", "#c6dbef", "#9ecae1", "#6baed6")
]


def _interpolate_sequential_stops(
    stops: list[tuple[int, int, int]],
    weight: float,
) -> tuple[int, int, int]:
    """
    Piecewise-linearly interpolate an RGB colour along evenly spaced stops.

    With ``k`` stops, stop ``i`` sits at position ``i / (k - 1)`` in [0, 1].
    ``weight`` is clamped to [0, 1]. NaN is treated as 0, so it maps to the
    first stop rather than the last.

    Args:
        stops: Non-empty list of ``(red, green, blue)`` integer triples.
        weight: Position along the colormap.

    Returns:
        The interpolated ``(red, green, blue)`` triple. Each channel is blended
        between the two surrounding stops with ``_lerp_byte``.

    Raises:
        ValueError: If ``stops`` is empty.
    """
    if not stops:
        raise ValueError("stops must contain at least one colour")
    if len(stops) == 1:
        return stops[0]
    # Explicit NaN check: max(0.0, min(1.0, nan)) evaluates to 1.0, which would
    # silently select the darkest stop.
    weight = 0.0 if math.isnan(weight) else max(0.0, min(1.0, weight))
    segment_count = len(stops) - 1
    scaled = weight * segment_count
    # weight == 1.0 lands exactly on the last stop. Keep it in the final segment
    # so that stops[segment_index + 1] stays in range.
    segment_index = min(math.floor(scaled), segment_count - 1)
    fraction = scaled - segment_index
    low = stops[segment_index]
    high = stops[segment_index + 1]
    return (
        _lerp_byte(low[0], high[0], fraction),
        _lerp_byte(low[1], high[1], fraction),
        _lerp_byte(low[2], high[2], fraction),
    )


def probability_to_css_background(probability: float) -> str:
    """
    Return a CSS ``rgb(r,g,b)`` colour for shading a token by its probability.

    The probability is clamped to [0, 1], with NaN counting as 0. The result is
    the position along the light band of the ColorBrewer / matplotlib sequential
    Blues stops, interpolated piecewise-linearly between adjacent stops. Staying
    in the light band keeps black text readable on top of the tint.
    """
    if not math.isnan(probability):
        weight = min(1.0, max(0.0, float(probability)))
    else:
        weight = 0.0
    r, g, b = _interpolate_sequential_stops(_SEQUENTIAL_BLUES_STOPS, weight)
    return "rgb({},{},{})".format(r, g, b)


def _replace_non_finite_floats(value: Any) -> Any:
    """
    Return a copy of ``value`` with NaN / ±inf floats replaced by ``None``.

    Dicts, lists and tuples are walked recursively. Other values are returned
    unchanged. Tuples come back as lists, which ``json.dumps`` would serialize
    as arrays anyway.
    """
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _replace_non_finite_floats(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_replace_non_finite_floats(item) for item in value]
    return value


def _encode_tooltip_payload(
    alternatives: list[dict[str, Any]],
    sampled_token_text: str,
    sampled_probability: float,
    chosen_in_top5: bool,
) -> str:
    """
    Base64 JSON for safe use in a data attribute.

    The JSON object has the keys ``alternatives``, ``sampled_token`` (with
    ``token_text`` and ``probability``) and ``chosen_in_top5``. Non-finite
    probabilities (NaN, ±inf) are emitted as ``null`` so the payload is valid,
    standard JSON.

    Args:
        alternatives: The top-5 alternative dicts for this token.
        sampled_token_text: Display text of the token that was sampled.
        sampled_probability: Probability of the sampled token.
        chosen_in_top5: Whether the sampled token appears in ``alternatives``.

    Returns:
        The ASCII base64 encoding of the JSON payload.
    """
    document = {
        "alternatives": alternatives,
        "sampled_token": {
            "token_text": sampled_token_text,
            "probability": sampled_probability,
        },
        "chosen_in_top5": chosen_in_top5,
    }
    # Without sanitizing, json.dumps writes the bare tokens NaN / Infinity, which
    # are not valid JSON and make strict parsers (e.g. JavaScript's JSON.parse)
    # throw. NaN probabilities are an expected input in this module.
    payload = json.dumps(_replace_non_finite_floats(document), ensure_ascii=True)
    # ensure_ascii=True \u-escapes non-ASCII text, so the decoded base64 is plain
    # ASCII JSON.
    return base64.b64encode(payload.encode("utf-8")).decode("ascii")


def build_completion_html(
    prompt_text: str,
    token_display_strings: list[str],
    chosen_probabilities: list[float],
    top5_alternatives: list[list[dict[str, Any]]],
    chosen_in_top5_flags: list[bool],
) -> str:
    """
    Render the prompt plus a probability-tinted span for each completion token.

    The prompt is HTML-escaped and emitted as plain text. Each completion token
    becomes a ``<span class="completion-token">``:

    - its background colour comes from ``probability_to_css_background``;
    - its ``data-top5`` attribute holds the base64 JSON tooltip payload from
      ``_encode_tooltip_payload``.

    Everything is wrapped in a single ``completion-playground-root`` div styled
    with ``white-space: pre-wrap``.

    Args:
        prompt_text: Text shown before the completion tokens.
        token_display_strings: Display text of each sampled token.
        chosen_probabilities: Probability of each sampled token.
        top5_alternatives: Per token, up to five dicts with keys
            ``token_text`` and ``probability``.
        chosen_in_top5_flags: Per token, whether the sampled token appears in
            that token's top-5 list.

    Returns:
        The HTML fragment as one string.

    Raises:
        ValueError: If the four per-token lists differ in length.
    """
    token_count = len(token_display_strings)
    length_checks = (
        (chosen_probabilities, "token_display_strings and chosen_probabilities length mismatch"),
        (top5_alternatives, "token_display_strings and top5_alternatives length mismatch"),
        (chosen_in_top5_flags, "token_display_strings and chosen_in_top5_flags length mismatch"),
    )
    for sequence, message in length_checks:
        if len(sequence) != token_count:
            raise ValueError(message)

    # The root div preserves whitespace and newlines. The inline <style> makes
    # each token span an inline-block aligned on the text baseline.
    opening = (
        '<div class="completion-playground-root" style="white-space: pre-wrap; word-break: break-word;">'
        "<style>"
        ".completion-playground-root .completion-token{"
        "display:inline-block;vertical-align:baseline;"
        "}</style>"
    )
    html_chunks: list[str] = [opening, html.escape(prompt_text)]

    rows = zip(
        token_display_strings,
        chosen_probabilities,
        top5_alternatives,
        chosen_in_top5_flags,
        strict=True,
    )
    for token_text, token_probability, token_alternatives, in_top5 in rows:
        tint = probability_to_css_background(token_probability)
        tooltip_b64 = _encode_tooltip_payload(
            token_alternatives, token_text, token_probability, in_top5
        )
        visible_text = html.escape(token_text)
        tooltip_attr = html.escape(tooltip_b64, quote=True)
        html_chunks.append(
            f'<span class="completion-token" style="background-color:{tint};cursor:pointer;" '
            f'data-top5="{tooltip_attr}">{visible_text}</span>'
        )

    html_chunks.append("</div>")
    return "".join(html_chunks)