import math
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# ---- Model config ----
# Hugging Face model id for the causal LM used for scoring.
MODEL_NAME = "gpt2"  # e.g., "distilgpt2", "gpt2-medium"

# Prefer GPU when available; all tensors below are created on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: from_pretrained loads (and may download) weights at import time —
# a network/disk side effect of merely importing this module.
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference mode: disables dropout etc.

EPS = 1e-9  # tie tolerance when comparing the two candidates' scores
def safe_exp(x: float) -> str:
    """Format exp(x) in scientific notation, degrading gracefully on overflow."""
    try:
        value = math.exp(x)
    except OverflowError:
        # exp() of a large log-odds difference blows past float range.
        return "∞ (overflow)"
    except Exception:
        # Anything non-numeric: show a neutral placeholder.
        return "—"
    return f"{value:.6e}"
def is_finite(x: float) -> bool:
    """Return True iff x is a real, finite number (not None, NaN, or ±Inf)."""
    if x is None:
        return False
    return math.isfinite(x)
def seq_logprob(context: str, candidate: str, assume_leading_space: bool, show_topk: int):
    """
    Compute log P(candidate | context) via the chain rule over candidate tokens.

    Improvement over the naive loop: runs a SINGLE teacher-forced forward pass
    over context+candidate instead of one full forward pass per candidate
    token. For a causal LM the per-position log-probs are identical, but model
    work drops from O(n^2) to O(n) in the candidate length.

    Args:
        context: Conditioning text. Must tokenize to at least one token.
        candidate: Continuation text to score.
        assume_leading_space: Prepend " " to the candidate (GPT-2 BPE encodes
            word-initial tokens with a leading space).
        show_topk: If > 0, list the top-k alternative tokens at each step.

    Returns:
        (total_logprob, detail_text, token_texts, num_tokens).
        total_logprob is None (with an explanatory message in detail_text)
        when either the candidate or the context tokenizes to nothing.
    """
    cand_text = (" " + candidate) if assume_leading_space else candidate
    cand_ids = tok.encode(cand_text, add_special_tokens=False)
    if len(cand_ids) == 0:
        return None, "Candidate tokenized to empty sequence (check spacing).", [], 0
    ctx_ids = tok.encode(context)
    if len(ctx_ids) == 0:
        # Guard: a causal LM needs at least one conditioning token.
        return None, "Context tokenized to empty sequence.", [], 0

    with torch.no_grad():
        # One forward pass over the whole sequence (teacher forcing).
        full_ids = torch.tensor([ctx_ids + cand_ids], device=device)
        logits = model(input_ids=full_ids).logits          # (1, seq_len, vocab)
        logprobs = torch.log_softmax(logits, dim=-1)

    n_ctx = len(ctx_ids)
    total_logprob = 0.0
    step_lines = []
    token_texts = []
    for i, t_id in enumerate(cand_ids):
        # Position n_ctx - 1 + i holds the distribution that predicts
        # the i-th candidate token.
        step_lp = logprobs[0, n_ctx - 1 + i]               # (vocab,)
        token_lp = step_lp[t_id].item()
        total_logprob += token_lp
        tok_str = tok.decode([t_id])
        token_texts.append(tok_str)
        line = (f"Step {i+1}: token={repr(tok_str)} logprob={token_lp:.6f} "
                f"prob={math.exp(token_lp):.6e}")
        if show_topk > 0:
            topk_vals, topk_idx = torch.topk(step_lp, k=min(show_topk, step_lp.shape[-1]))
            tops = ", ".join(f"{repr(tok.decode([int(idx)]))}:{val.item():.2f}"
                             for idx, val in zip(topk_idx, topk_vals))
            line += f"\n top-{show_topk}: {tops}"
        step_lines.append(line)
    return total_logprob, "\n".join(step_lines), token_texts, len(cand_ids)
def compare_candidates(context, candA, candB, assume_space, topk, use_len_norm):
    """
    Compare P(candA|context) vs P(candB|context) and render the UI outputs.

    Fix: the error-return and headline f-strings had lost their HTML tags,
    leaving unterminated string literals (a SyntaxError); the markup is
    reconstructed here, with the repeated error-return shape factored into
    a helper.

    Returns a 6-tuple matching the Gradio outputs list:
    (headline_html, summaryA_md, detailsA, summaryB_md, detailsB, hidden).
    On validation or numerical failure the headline slot carries an error
    banner and the remaining outputs are empty strings.
    """
    def _error(msg: str):
        # Uniform error return: red banner in the HTML slot, rest blank.
        return (f"<div style='color:#b91c1c;font-weight:600;'>{msg}</div>",
                "", "", "", "", "")

    # Basic input checks.
    errors = []
    if not context.strip():
        errors.append("Please enter a context.")
    if not candA.strip():
        errors.append("Please enter Candidate A.")
    if not candB.strip():
        errors.append("Please enter Candidate B.")
    if errors:
        return _error(" ".join(errors))

    # Compute chain-rule log-probs for both candidates.
    lpA, detA, toksA, nA = seq_logprob(context, candA, assume_space, topk)
    lpB, detB, toksB, nB = seq_logprob(context, candB, assume_space, topk)

    # Numerical guards (None / NaN / Inf).
    if not (is_finite(lpA) and is_finite(lpB)):
        return _error("Numerical issue (NaN/Inf). "
                      "Try shorter context, smaller model (e.g., distilgpt2), "
                      "or disable length normalization.")

    # Optionally length-normalize (per-token average log-prob).
    # Note: odds under length-normalization are "per-token odds", not
    # whole-sequence odds.
    if use_len_norm:
        if nA == 0 or nB == 0:
            return _error("Empty tokenization. "
                          "Check spacing or turn off 'assume leading space'.")
        scoreA = lpA / nA
        scoreB = lpB / nB
        label_suffix = " (per-token)"
    else:
        scoreA = lpA
        scoreB = lpB
        label_suffix = ""

    diff = scoreA - scoreB  # log-odds if unnormalized; log per-token odds otherwise

    # Winner logic with proper tie handling.
    if abs(diff) <= EPS:
        winner = "Tie"
    elif diff > 0:
        winner = "Candidate A"
    else:
        winner = "Candidate B"

    ratio_str = safe_exp(diff)

    # Accent color per outcome.
    if winner == "Candidate A":
        win_color = "#166534"  # green
    elif winner == "Candidate B":
        win_color = "#1d4ed8"  # blue
    else:
        win_color = "#92400e"  # amber

    caveat = ("Per-token uses average log-prob." if use_len_norm
              else "Whole-sequence comparison.")
    headline = (
        f"<div style='border:2px solid {win_color};border-radius:8px;padding:12px;'>"
        f"<div style='font-size:1.25em;font-weight:700;color:{win_color};'>"
        f"Winner: {winner}{label_suffix}</div>"
        f"<div>"
        f"Odds A/B{label_suffix} = {ratio_str} | "
        f"log-odds A−B{label_suffix} = {diff:.6f}"
        f"</div>"
        f"<div style='font-size:0.9em;opacity:0.8;'>"
        f"(Odds > 1 ⇒ A more probable; < 1 ⇒ B more probable. {caveat})"
        f"</div>"
        f"</div>"
    )

    def summarize(label, cand, lp, toks, n):
        # Markdown summary card for one candidate.
        return (
            f"**{label}**: {cand}\n\n"
            f"Tokenization: {toks}\n"
            f"Total logprob: {lp:.6f}\n"
            f"Sequence probability: {math.exp(lp):.6e}\n"
            f"Tokens: {n}"
        )

    sumA = summarize("Candidate A", candA, lpA, toksA, nA)
    sumB = summarize("Candidate B", candB, lpB, toksB, nB)
    return headline, sumA, detA, sumB, detB, ""
def swap(a, b):
    """Return the two inputs in reversed order (wired to the Swap A↔B button)."""
    pair = (b, a)
    return pair
# ---- Gradio UI (module-level: builds the app and launches it on import) ----
with gr.Blocks(title="Two-Candidate Next-Token Comparator (Robust)") as demo:
    # Header / usage notes.
    gr.Markdown(
        "# Two-Candidate Next-Word/Token Probability (Robust)\n"
        "Compare **P(A|context)** vs **P(B|context)** from a pretrained causal LM (no fine-tuning).\n"
        "- Proper tie handling and numerical guards.\n"
        "- Optional **length normalization** (per-token).\n"
        "- Use **Swap** to sanity-check symmetry."
    )
    # Inputs: shared context, then the two candidate continuations.
    with gr.Row():
        context = gr.Textbox(label="Context (prompt)", lines=6, placeholder="Paste prior text here...")
    with gr.Row():
        candA = gr.Textbox(label="Candidate A (follow-up)")
        candB = gr.Textbox(label="Candidate B (follow-up)")
    # Scoring options: leading-space handling, top-k display, length norm.
    with gr.Row():
        assume_space = gr.Checkbox(value=True, label="Assume leading space before candidates (useful for GPT-2 tokenization)")
        topk = gr.Slider(0, 20, value=5, step=1, label="Show top-k alternatives (per token step)")
        use_len_norm = gr.Checkbox(value=False, label="Use length normalization (average log-prob per token)")
    with gr.Row():
        btn_compare = gr.Button("Compare", variant="primary")
        btn_swap = gr.Button("Swap A ↔ B")
    # Outputs: winner banner, then per-candidate summary + step details.
    winner_html = gr.HTML()
    summaryA = gr.Markdown()
    detailsA = gr.Textbox(label="Candidate A — step-by-step", lines=10)
    summaryB = gr.Markdown()
    detailsB = gr.Textbox(label="Candidate B — step-by-step", lines=10)
    # Placeholder sink for the 6th output slot of compare_candidates.
    _hidden = gr.Textbox(visible=False)
    # Wire the buttons: compare scores both candidates; swap exchanges A/B.
    btn_compare.click(
        fn=compare_candidates,
        inputs=[context, candA, candB, assume_space, topk, use_len_norm],
        outputs=[winner_html, summaryA, detailsA, summaryB, detailsB, _hidden]
    )
    btn_swap.click(
        fn=swap, inputs=[candA, candB], outputs=[candA, candB]
    )
demo.launch()