File size: 7,812 Bytes
48051fc
 
 
 
 
f0253ab
fea1f8d
48051fc
 
 
 
 
 
fea1f8d
 
 
 
 
 
 
 
 
 
 
 
 
 
f0253ab
 
fea1f8d
 
f0253ab
48051fc
 
 
 
f0253ab
48051fc
fea1f8d
48051fc
 
f0253ab
48051fc
f0253ab
 
48051fc
 
f0253ab
48051fc
fea1f8d
 
48051fc
f0253ab
 
 
 
 
 
 
 
fea1f8d
 
f0253ab
 
 
fea1f8d
 
f0253ab
 
 
48051fc
 
fea1f8d
f0253ab
fea1f8d
 
 
f0253ab
fea1f8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0253ab
fea1f8d
48051fc
fea1f8d
 
 
 
 
 
 
f0253ab
fea1f8d
f0253ab
fea1f8d
f0253ab
 
 
 
 
fea1f8d
f0253ab
fea1f8d
f0253ab
 
 
fea1f8d
f0253ab
fea1f8d
 
f0253ab
 
fea1f8d
 
f0253ab
 
 
fea1f8d
 
 
 
 
 
 
 
 
 
 
f0253ab
fea1f8d
 
 
 
 
 
f0253ab
fea1f8d
 
 
 
 
f0253ab
48051fc
fea1f8d
48051fc
fea1f8d
 
48051fc
fea1f8d
f0253ab
fea1f8d
 
 
 
f0253ab
 
fea1f8d
 
 
 
f0253ab
 
fea1f8d
f0253ab
fea1f8d
 
 
 
 
 
f0253ab
48051fc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import math
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ---- Model config ----
# Any HF causal-LM checkpoint works; gpt2 keeps startup fast on CPU.
MODEL_NAME = "gpt2"  # e.g., "distilgpt2", "gpt2-medium"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Shared tokenizer/model used by seq_logprob below; eval() disables dropout
# so scores are deterministic.
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()

# Tolerance below which two scores are declared a tie (see compare_candidates).
EPS = 1e-9  # tie tolerance

def safe_exp(x: float) -> str:
    """Render exp(x) in scientific notation, degrading gracefully on overflow."""
    try:
        value = math.exp(x)
    except OverflowError:
        # |x| too large for a float exponent — report symbolically.
        return "∞ (overflow)"
    except Exception:
        # Any other failure (e.g. non-numeric input) gets a placeholder.
        return "—"
    return f"{value:.6e}"

def is_finite(x: float) -> bool:
    """Return True iff *x* is an actual finite number (not None, NaN, or ±Inf)."""
    if x is None:
        return False
    return math.isfinite(x)

def seq_logprob(context: str, candidate: str, assume_leading_space: bool, show_topk: int):
    """
    Compute log P(candidate | context) via the chain rule over candidate tokens.

    Uses a SINGLE forward pass over context+candidate: for a causal LM the
    logits at position p already condition on tokens [0..p], so scoring every
    candidate token at once (teacher forcing) gives the same per-token
    log-probs as the original one-pass-per-token loop, in O(1) model calls
    instead of O(len(candidate)).

    Args:
        context: conditioning text fed to the LM (assumed non-empty; the
            Gradio handler validates this before calling).
        candidate: continuation text to score.
        assume_leading_space: prepend " " to the candidate — GPT-2 BPE encodes
            word-initial tokens with a leading space.
        show_topk: if > 0, append the top-k alternative tokens for each step.

    Returns:
        (total_logprob, detail_text, token_texts, num_tokens).
        total_logprob is None (with an explanatory detail_text) when the
        candidate tokenizes to an empty sequence.
    """
    cand_text = (" " + candidate) if assume_leading_space else candidate
    with torch.no_grad():
        ctx_ids = tok.encode(context, return_tensors="pt").to(device)
        cand_ids = tok.encode(cand_text, add_special_tokens=False)

        if len(cand_ids) == 0:
            return None, "Candidate tokenized to empty sequence (check spacing).", [], 0

        # One forward pass over the full sequence; logits at index p predict
        # token p+1, so candidate token i is scored at index ctx_len + i - 1.
        cand_tensor = torch.tensor([cand_ids], device=device)
        full_ids = torch.cat([ctx_ids, cand_tensor], dim=1)
        logits = model(input_ids=full_ids).logits
        ctx_len = ctx_ids.shape[1]

        total_logprob = 0.0
        step_lines = []
        token_texts = []

        for i, t_id in enumerate(cand_ids):
            logprobs = torch.log_softmax(logits[:, ctx_len + i - 1, :], dim=-1)
            token_lp = logprobs[0, t_id].item()
            total_logprob += token_lp

            tok_str = tok.decode([t_id])
            token_texts.append(tok_str)

            line = (
                f"Step {i+1}: token={repr(tok_str)}  logprob={token_lp:.6f}  "
                f"prob={math.exp(token_lp):.6e}"
            )
            if show_topk > 0:
                topk_vals, topk_idx = torch.topk(logprobs, k=min(show_topk, logprobs.shape[-1]), dim=-1)
                tops = ", ".join([f"{repr(tok.decode([int(idx)]))}:{val.item():.2f}"
                                  for idx, val in zip(topk_idx[0], topk_vals[0])])
                line += f"\n  top-{show_topk}: {tops}"
            step_lines.append(line)

        return total_logprob, "\n".join(step_lines), token_texts, len(cand_ids)

def compare_candidates(context, candA, candB, assume_space, topk, use_len_norm):
    """
    Gradio handler: score both candidates under the LM and render a verdict.

    Args:
        context: conditioning prompt text.
        candA, candB: candidate continuations to compare.
        assume_space: forwarded to seq_logprob (prepend " " before scoring).
        topk: number of per-step alternatives to show in the detail panes.
        use_len_norm: compare average per-token log-prob instead of the total.

    Returns:
        6-tuple (winner_html, summaryA, detailsA, summaryB, detailsB, hidden)
        matching the output components wired up in the Blocks UI.
    """
    # Basic checks
    errors = []
    if not context.strip():
        errors.append("Please enter a context.")
    if not candA.strip():
        errors.append("Please enter Candidate A.")
    if not candB.strip():
        errors.append("Please enter Candidate B.")
    if errors:
        msg = " ".join(errors)
        return (f"<div style='color:#b00020;font-weight:600'>{msg}</div>",
                "", "", "", "", "")

    # Compute
    lpA, detA, toksA, nA = seq_logprob(context, candA, assume_space, topk)
    lpB, detB, toksB, nB = seq_logprob(context, candB, assume_space, topk)

    # Tokenization failures return None with an explanatory message in the
    # detail slot — surface that message instead of mislabeling it as a
    # numerical (NaN/Inf) issue.
    if lpA is None or lpB is None:
        msg = detA if lpA is None else detB
        return (f"<div style='color:#b00020;font-weight:600'>{msg}</div>",
                "", "", "", "", "")

    # Validate numbers
    if not (is_finite(lpA) and is_finite(lpB)):
        return ("<div style='color:#b00020;font-weight:600'>Numerical issue (NaN/Inf). "
                "Try shorter context, smaller model (e.g., distilgpt2), or disable length normalization.</div>",
                "", "", "", "", "")

    # Optionally length-normalize (per-token average log-prob)
    # Note: odds under length-normalization are "per-token odds", not whole-sequence odds.
    if use_len_norm:
        if nA == 0 or nB == 0:
            return ("<div style='color:#b00020;font-weight:600'>Empty tokenization. "
                    "Check spacing or turn off 'assume leading space'.</div>",
                    "", "", "", "", "")
        scoreA = lpA / nA
        scoreB = lpB / nB
        label_suffix = " (per-token)"
    else:
        scoreA = lpA
        scoreB = lpB
        label_suffix = ""

    diff = scoreA - scoreB  # log-odds if unnormalized; log per-token odds otherwise

    # Winner logic with proper tie handling
    if abs(diff) <= EPS:
        winner = "Tie"
    elif diff > 0:
        winner = "Candidate A"
    else:
        winner = "Candidate B"

    # exp(diff) = odds ratio A/B; safe_exp guards against overflow for large diffs.
    ratio_str = safe_exp(diff)

    # Colors
    if winner == "Candidate A":
        win_color = "#166534"  # green
    elif winner == "Candidate B":
        win_color = "#1d4ed8"  # blue
    else:
        win_color = "#92400e"  # amber

    # Headline
    headline = (
        f"<div style='padding:14px;border-radius:12px;background:#f8fafc;"
        f"border:1px solid #e2e8f0;margin-bottom:10px'>"
        f"<div style='font-size:20px;font-weight:800;color:{win_color};'>Winner: {winner}{label_suffix}</div>"
        f"<div style='margin-top:6px;font-size:16px;'>"
        f"Odds A/B{label_suffix} = <b>{ratio_str}</b> &nbsp;|&nbsp; "
        f"log-odds A−B{label_suffix} = <b>{diff:.6f}</b>"
        f"</div>"
        f"<div style='margin-top:6px;color:#475569'>"
        f"(Odds &gt; 1 ⇒ A more probable; &lt; 1 ⇒ B more probable. "
        f"{'Per-token uses average log-prob.' if use_len_norm else 'Whole-sequence comparison.'})"
        f"</div></div>"
    )

    def summarize(label, cand, lp, toks, n):
        # Markdown summary for one candidate's overall score.
        return (
            f"**{label}**: {cand}\n\n"
            f"Tokenization: {toks}\n"
            f"Total logprob: {lp:.6f}\n"
            f"Sequence probability: {math.exp(lp):.6e}\n"
            f"Tokens: {n}"
        )

    sumA = summarize("Candidate A", candA, lpA, toksA, nA)
    sumB = summarize("Candidate B", candB, lpB, toksB, nB)

    return headline, sumA, detA, sumB, detB, ""

def swap(a, b):
    """Return the two arguments in reversed order (wired to the Swap button)."""
    flipped = (b, a)
    return flipped

# ---- UI layout and event wiring ----
with gr.Blocks(title="Two-Candidate Next-Token Comparator (Robust)") as demo:
    gr.Markdown(
        "# Two-Candidate Next-Word/Token Probability (Robust)\n"
        "Compare **P(A|context)** vs **P(B|context)** from a pretrained causal LM (no fine-tuning).\n"
        "- Proper tie handling and numerical guards.\n"
        "- Optional **length normalization** (per-token).\n"
        "- Use **Swap** to sanity-check symmetry."
    )
    with gr.Row():
        context = gr.Textbox(label="Context (prompt)", lines=6, placeholder="Paste prior text here...")
    with gr.Row():
        candA = gr.Textbox(label="Candidate A (follow-up)")
        candB = gr.Textbox(label="Candidate B (follow-up)")
    with gr.Row():
        # Options mirror the parameters of compare_candidates.
        assume_space = gr.Checkbox(value=True, label="Assume leading space before candidates (useful for GPT-2 tokenization)")
        topk = gr.Slider(0, 20, value=5, step=1, label="Show top-k alternatives (per token step)")
        use_len_norm = gr.Checkbox(value=False, label="Use length normalization (average log-prob per token)")
    with gr.Row():
        btn_compare = gr.Button("Compare", variant="primary")
        btn_swap = gr.Button("Swap A ↔ B")

    # Output components, in the same order compare_candidates returns them.
    winner_html = gr.HTML()
    summaryA = gr.Markdown()
    detailsA = gr.Textbox(label="Candidate A — step-by-step", lines=10)
    summaryB = gr.Markdown()
    detailsB = gr.Textbox(label="Candidate B — step-by-step", lines=10)
    # Hidden sink for the 6th return value (always empty string).
    _hidden = gr.Textbox(visible=False)

    btn_compare.click(
        fn=compare_candidates,
        inputs=[context, candA, candB, assume_space, topk, use_len_norm],
        outputs=[winner_html, summaryA, detailsA, summaryB, detailsB, _hidden]
    )

    # Swap writes B into A and A into B, so re-running Compare should
    # invert the reported odds.
    btn_swap.click(
        fn=swap, inputs=[candA, candB], outputs=[candA, candB]
    )

demo.launch()