|
|
import math |
|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face model id handed to both from_pretrained() calls below.
MODEL_NAME = "gpt2"

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer and causal-LM weights for MODEL_NAME, loaded once when the module
# is imported and reused by every scoring call.
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(device)
# Inference only: put the module in evaluation mode (disables dropout etc.).
model.eval()

# Score differences whose magnitude is at most this are reported as a tie.
EPS = 1e-9
|
|
|
|
|
def safe_exp(x: float) -> str:
    """Format ``exp(x)`` in scientific notation (6 decimals) for display.

    Returns ``"∞ (overflow)"`` when the result does not fit in a float, and
    ``"—"`` when *x* is not a real number (e.g. ``None`` or a string). Any other
    exception indicates a genuine bug and is allowed to propagate instead of
    being silently rendered as a dash.
    """
    try:
        value = math.exp(x)
    except OverflowError:
        return "∞ (overflow)"
    except (TypeError, ValueError):
        # Non-numeric input: show a placeholder rather than crashing the UI.
        return "—"
    return f"{value:.6e}"
|
|
|
|
|
def is_finite(x: float) -> bool:
    """Return True only for a real, finite number; ``None``, NaN and ±inf give False."""
    if x is None:
        return False
    return math.isfinite(x)
|
|
|
|
|
def seq_logprob(context: str, candidate: str, assume_leading_space: bool, show_topk: int):
    """
    Compute log P(candidate | context) via the chain rule over candidate tokens.

    All candidate tokens are scored with ONE forward pass over
    ``context_ids + candidate_ids[:-1]``. In a causal LM the logits at position j
    depend only on tokens 0..j, so this yields the same per-token log-probs
    (up to floating-point rounding) as feeding the growing prefix one token at a
    time, but costs one pass instead of one per candidate token.

    Args:
        context: Conditioning text. If it encodes to no tokens, the tokenizer's
            BOS (or, failing that, EOS) token is used as the only conditioning
            position. If context + candidate exceed the model's position window,
            the oldest context tokens are dropped and a note is added to the details.
        candidate: Continuation whose probability is measured.
        assume_leading_space: Prepend one space to *candidate* before encoding
            (GPT-2's BPE folds a preceding space into the following word token).
        show_topk: If > 0, list the k most likely next tokens at each step in the
            detail text. Coerced to ``int`` and clamped to the vocabulary size.

    Returns:
        ``(total_logprob, detail_text, token_list, num_tokens)``. When the
        candidate cannot be scored, ``total_logprob`` is ``None``, ``detail_text``
        explains why, ``token_list`` is empty and ``num_tokens`` is 0.
    """
    cand_text = (" " + candidate) if assume_leading_space else candidate
    cand_ids = tok.encode(cand_text, add_special_tokens=False)
    if not cand_ids:
        return None, "Candidate tokenized to empty sequence (check spacing).", [], 0

    ctx_ids = tok.encode(context)
    if not ctx_ids:
        # The first candidate token needs a preceding position to be predicted from.
        fallback = tok.bos_token_id if tok.bos_token_id is not None else tok.eos_token_id
        if fallback is None:
            return (None,
                    "Context tokenized to empty sequence and the tokenizer has no BOS/EOS token.",
                    [], 0)
        ctx_ids = [fallback]

    notes = []
    # Keep the input inside the model's position window by dropping the oldest
    # context tokens. The input is ctx + cand[:-1], so it fits when
    # len(ctx) + len(cand) - 1 <= max_len, and at least one context token must remain.
    max_len = getattr(model.config, "max_position_embeddings", None)
    if max_len:
        if len(cand_ids) > max_len:
            return (None,
                    f"Candidate is {len(cand_ids)} tokens, longer than the model's "
                    f"{max_len}-token window.",
                    [], 0)
        keep = max_len - len(cand_ids) + 1
        if len(ctx_ids) > keep:
            notes.append(f"Note: context left-truncated to its last {keep} tokens "
                         f"to fit the model's {max_len}-token window.")
            ctx_ids = ctx_ids[-keep:]

    n_ctx = len(ctx_ids)
    input_ids = torch.tensor([ctx_ids + cand_ids[:-1]], device=device)
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits
    # logits[0, j] is the next-token distribution after input position j, so rows
    # n_ctx-1 .. end predict cand_ids[0] .. cand_ids[-1] (exactly len(cand_ids) rows).
    step_logprobs = torch.log_softmax(logits[0, n_ctx - 1:].float(), dim=-1)

    # Gradio sliders may deliver floats; torch.topk requires an int k.
    k = min(max(int(show_topk or 0), 0), step_logprobs.shape[-1])

    total_logprob = 0.0
    step_lines = list(notes)
    token_texts = []
    for step, (t_id, row) in enumerate(zip(cand_ids, step_logprobs), start=1):
        token_lp = row[t_id].item()
        total_logprob += token_lp

        tok_str = tok.decode([t_id])
        token_texts.append(tok_str)

        line = (f"Step {step}: token={tok_str!r} logprob={token_lp:.6f} "
                f"prob={math.exp(token_lp):.6e}")
        if k > 0:
            top_vals, top_idx = torch.topk(row, k=k)
            tops = ", ".join(
                f"{tok.decode([int(idx)])!r}:{val.item():.2f}"
                for idx, val in zip(top_idx, top_vals)
            )
            line += f"\n top-{k}: {tops}"
        step_lines.append(line)

    return total_logprob, "\n".join(step_lines), token_texts, len(cand_ids)
|
|
|
|
|
def _error_outputs(message: str):
    """Build the 6 UI outputs for an error: a red banner followed by five empty panes."""
    return (
        f"<div style='color:#b00020;font-weight:600'>{message}</div>",
        "", "", "", "", "",
    )


def _summarize_candidate(label: str, cand: str, lp: float, toks: list, n: int) -> str:
    """Render one candidate's Markdown summary.

    Metric lines end with two spaces before the newline (a Markdown hard line
    break); in CommonMark a lone newline is only a soft break, which would run
    all metrics together on one line.
    """
    return (
        f"**{label}**: {cand}\n\n"
        f"Tokenization: {toks}  \n"
        f"Total logprob: {lp:.6f}  \n"
        f"Sequence probability: {math.exp(lp):.6e}  \n"
        f"Tokens: {n}"
    )


def compare_candidates(context, candA, candB, assume_space, topk, use_len_norm):
    """Score two continuations of *context* and render the comparison for the UI.

    Args:
        context: Prompt text; must contain non-whitespace characters.
        candA, candB: Candidate continuations; must contain non-whitespace characters.
        assume_space: Forwarded to ``seq_logprob`` as ``assume_leading_space``.
        topk: Forwarded to ``seq_logprob`` as ``show_topk``.
        use_len_norm: Compare average per-token log-prob instead of the total.

    Returns:
        A 6-tuple ``(winner_html, summaryA, detailsA, summaryB, detailsB, hidden)``.
        On invalid input or a scoring failure the first element is a red error
        banner and the other five are empty strings.
    """
    # Gradio textboxes normally send "", but treat a missing value the same way.
    context = context or ""
    candA = candA or ""
    candB = candB or ""

    errors = []
    if not context.strip():
        errors.append("Please enter a context.")
    if not candA.strip():
        errors.append("Please enter Candidate A.")
    if not candB.strip():
        errors.append("Please enter Candidate B.")
    if errors:
        return _error_outputs(" ".join(errors))

    lpA, detA, toksA, nA = seq_logprob(context, candA, assume_space, topk)
    lpB, detB, toksB, nB = seq_logprob(context, candB, assume_space, topk)

    # seq_logprob returns lp=None, with the reason in the detail slot, when it
    # cannot score a candidate (e.g. empty tokenization). Report that reason here;
    # otherwise None would fall through to the misleading NaN/Inf message below.
    for label, lp, detail in (("Candidate A", lpA, detA), ("Candidate B", lpB, detB)):
        if lp is None:
            return _error_outputs(f"{label}: {detail}")

    if not (is_finite(lpA) and is_finite(lpB)):
        return _error_outputs(
            "Numerical issue (NaN/Inf). "
            "Try shorter context, smaller model (e.g., distilgpt2), or disable length normalization."
        )

    if use_len_norm:
        # A non-None logprob implies at least one scored token, so nA, nB >= 1.
        scoreA, scoreB = lpA / nA, lpB / nB
        label_suffix = " (per-token)"
    else:
        scoreA, scoreB = lpA, lpB
        label_suffix = ""

    # Log-odds of A over B; exp(diff) is the odds ratio shown to the user.
    diff = scoreA - scoreB
    if abs(diff) <= EPS:
        winner = "Tie"
    elif diff > 0:
        winner = "Candidate A"
    else:
        winner = "Candidate B"
    win_color = {"Candidate A": "#166534", "Candidate B": "#1d4ed8"}.get(winner, "#92400e")
    ratio_str = safe_exp(diff)

    headline = (
        f"<div style='padding:14px;border-radius:12px;background:#f8fafc;"
        f"border:1px solid #e2e8f0;margin-bottom:10px'>"
        f"<div style='font-size:20px;font-weight:800;color:{win_color};'>Winner: {winner}{label_suffix}</div>"
        f"<div style='margin-top:6px;font-size:16px;'>"
        f"Odds A/B{label_suffix} = <b>{ratio_str}</b> | "
        f"log-odds A−B{label_suffix} = <b>{diff:.6f}</b>"
        f"</div>"
        f"<div style='margin-top:6px;color:#475569'>"
        f"(Odds > 1 ⇒ A more probable; < 1 ⇒ B more probable. "
        f"{'Per-token uses average log-prob.' if use_len_norm else 'Whole-sequence comparison.'})"
        f"</div></div>"
    )

    sumA = _summarize_candidate("Candidate A", candA, lpA, toksA, nA)
    sumB = _summarize_candidate("Candidate B", candB, lpB, toksB, nB)
    return headline, sumA, detA, sumB, detB, ""
|
|
|
|
|
def swap(a, b):
    """Return the two inputs in reverse order, i.e. the tuple ``(b, a)``."""
    first, second = a, b
    return second, first
|
|
|
|
|
with gr.Blocks(title="Two-Candidate Next-Token Comparator (Robust)") as demo:
    gr.Markdown(
        "# Two-Candidate Next-Word/Token Probability (Robust)\n"
        "Compare **P(A|context)** vs **P(B|context)** from a pretrained causal LM (no fine-tuning).\n"
        "- Proper tie handling and numerical guards.\n"
        "- Optional **length normalization** (per-token).\n"
        "- Use **Swap** to sanity-check symmetry."
    )

    # Inputs.
    with gr.Row():
        context = gr.Textbox(label="Context (prompt)", lines=6, placeholder="Paste prior text here...")
    with gr.Row():
        candA = gr.Textbox(label="Candidate A (follow-up)")
        candB = gr.Textbox(label="Candidate B (follow-up)")
    with gr.Row():
        assume_space = gr.Checkbox(value=True, label="Assume leading space before candidates (useful for GPT-2 tokenization)")
        topk = gr.Slider(0, 20, value=5, step=1, label="Show top-k alternatives (per token step)")
        use_len_norm = gr.Checkbox(value=False, label="Use length normalization (average log-prob per token)")
    with gr.Row():
        btn_compare = gr.Button("Compare", variant="primary")
        btn_swap = gr.Button("Swap A ↔ B")

    # Outputs, in the same order compare_candidates returns them.
    winner_html = gr.HTML()
    summaryA = gr.Markdown()
    detailsA = gr.Textbox(label="Candidate A — step-by-step", lines=10)
    summaryB = gr.Markdown()
    detailsB = gr.Textbox(label="Candidate B — step-by-step", lines=10)
    # compare_candidates returns a sixth value; this invisible box absorbs it.
    _hidden = gr.Textbox(visible=False)

    btn_compare.click(
        fn=compare_candidates,
        inputs=[context, candA, candB, assume_space, topk, use_len_norm],
        outputs=[winner_html, summaryA, detailsA, summaryB, detailsB, _hidden],
    )

    # Exchange the two candidate texts in place; results are not recomputed.
    btn_swap.click(
        fn=swap, inputs=[candA, candB], outputs=[candA, candB]
    )


if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (tests, notebooks, `gradio app.py` reload mode) does not block on launch().
    demo.launch()