File size: 7,812 Bytes
48051fc f0253ab fea1f8d 48051fc fea1f8d f0253ab fea1f8d f0253ab 48051fc f0253ab 48051fc fea1f8d 48051fc f0253ab 48051fc f0253ab 48051fc f0253ab 48051fc fea1f8d 48051fc f0253ab fea1f8d f0253ab fea1f8d f0253ab 48051fc fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d 48051fc fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab 48051fc fea1f8d 48051fc fea1f8d 48051fc fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab fea1f8d f0253ab 48051fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import math
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# ---- Model config ----
# Pretrained causal LM used for scoring; swap in a different checkpoint here.
MODEL_NAME = "gpt2" # e.g., "distilgpt2", "gpt2-medium"
# Prefer GPU when available; all tensors created below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference only: disables dropout and other train-time behavior
EPS = 1e-9 # tie tolerance: log-prob differences at or below this count as a tie
def safe_exp(x: float) -> str:
    """Format exp(x) in scientific notation, degrading gracefully on overflow."""
    try:
        value = math.exp(x)
    except OverflowError:
        # exp(x) exceeds float range for large positive x
        return "∞ (overflow)"
    except Exception:
        # any other failure (e.g. non-numeric input)
        return "—"
    return f"{value:.6e}"
def is_finite(x: float) -> bool:
    """Return True when *x* is a real, finite number (None, NaN, ±inf all fail)."""
    if x is None:
        return False
    return math.isfinite(x)
def seq_logprob(context: str, candidate: str, assume_leading_space: bool, show_topk: int):
    """
    Compute log P(candidate | context) via the chain rule over candidate tokens.

    Args:
        context: conditioning text.
        candidate: continuation text to score.
        assume_leading_space: prepend " " to the candidate before tokenizing
            (GPT-2 BPE encodes word-initial tokens with a leading space).
        show_topk: if > 0, list the top-k alternative tokens at each step.

    Returns:
        (total_logprob, detail_text, token_texts, num_tokens); total_logprob
        is None when the context or candidate tokenizes to an empty sequence.
    """
    cand_text = (" " + candidate) if assume_leading_space else candidate
    show_topk = int(show_topk)  # Gradio sliders may deliver floats; topk needs int k
    with torch.no_grad():
        ctx_ids = tok.encode(context, return_tensors="pt").to(device)
        cand_ids = tok.encode(cand_text, add_special_tokens=False)
        if len(cand_ids) == 0:
            return None, "Candidate tokenized to empty sequence (check spacing).", [], 0
        if ctx_ids.shape[1] == 0:
            # Guard: an empty context would make the position slice below invalid.
            return None, "Context tokenized to empty sequence.", [], 0
        # Single teacher-forced forward pass over context + candidate:
        # the logits at position (ctx_len - 1 + i) predict candidate token i.
        # This replaces the original per-token re-forwarding (one full O(seq)
        # pass per candidate token) with a single pass; per-step log-probs are
        # identical because each step conditions on the true preceding tokens.
        cand_tensor = torch.tensor([cand_ids], device=device)
        full_ids = torch.cat([ctx_ids, cand_tensor], dim=1)
        logits = model(input_ids=full_ids).logits
        ctx_len = ctx_ids.shape[1]
        step_logprobs = torch.log_softmax(
            logits[:, ctx_len - 1 : ctx_len - 1 + len(cand_ids), :], dim=-1
        )
        total_logprob = 0.0
        step_lines = []
        token_texts = []
        for i, t_id in enumerate(cand_ids):
            logprobs = step_logprobs[:, i, :]
            token_lp = logprobs[0, t_id].item()
            total_logprob += token_lp
            tok_str = tok.decode([t_id])
            token_texts.append(tok_str)
            line = (f"Step {i+1}: token={repr(tok_str)} logprob={token_lp:.6f} "
                    f"prob={math.exp(token_lp):.6e}")
            if show_topk > 0:
                topk_vals, topk_idx = torch.topk(
                    logprobs, k=min(show_topk, logprobs.shape[-1]), dim=-1)
                tops = ", ".join([f"{repr(tok.decode([int(idx)]))}:{val.item():.2f}"
                                  for idx, val in zip(topk_idx[0], topk_vals[0])])
                line += f"\n top-{show_topk}: {tops}"
            step_lines.append(line)
    return total_logprob, "\n".join(step_lines), token_texts, len(cand_ids)
def compare_candidates(context, candA, candB, assume_space, topk, use_len_norm):
    """
    Score both candidates under the model and render the comparison.

    Args mirror the Gradio inputs; returns a 6-tuple matching the Gradio
    outputs: (winner_html, summaryA, detailsA, summaryB, detailsB, hidden).
    On any validation or numerical error, returns an error banner in the
    first slot and empty strings elsewhere.
    """
    # ---- Basic input validation ----
    errors = []
    if not context.strip():
        errors.append("Please enter a context.")
    if not candA.strip():
        errors.append("Please enter Candidate A.")
    if not candB.strip():
        errors.append("Please enter Candidate B.")
    if errors:
        msg = " ".join(errors)
        return (f"<div style='color:#b00020;font-weight:600'>{msg}</div>",
                "", "", "", "", "")
    # ---- Compute log P(candidate | context) for both candidates ----
    lpA, detA, toksA, nA = seq_logprob(context, candA, assume_space, topk)
    lpB, detB, toksB, nB = seq_logprob(context, candB, assume_space, topk)
    # ---- Numerical guards ----
    if not (is_finite(lpA) and is_finite(lpB)):
        return ("<div style='color:#b00020;font-weight:600'>Numerical issue (NaN/Inf). "
                "Try shorter context, smaller model (e.g., distilgpt2), or disable length normalization.</div>",
                "", "", "", "", "")
    # Optionally length-normalize (per-token average log-prob).
    # Note: odds under length-normalization are "per-token odds", not whole-sequence odds.
    if use_len_norm:
        if nA == 0 or nB == 0:
            return ("<div style='color:#b00020;font-weight:600'>Empty tokenization. "
                    "Check spacing or turn off 'assume leading space'.</div>",
                    "", "", "", "", "")
        scoreA = lpA / nA
        scoreB = lpB / nB
        label_suffix = " (per-token)"
    else:
        scoreA = lpA
        scoreB = lpB
        label_suffix = ""
    diff = scoreA - scoreB  # log-odds if unnormalized; log per-token odds otherwise
    # ---- Winner logic with proper tie handling ----
    if abs(diff) <= EPS:
        winner = "Tie"
    elif diff > 0:
        winner = "Candidate A"
    else:
        winner = "Candidate B"
    ratio_str = safe_exp(diff)
    # Headline colors per outcome
    if winner == "Candidate A":
        win_color = "#166534"  # green
    elif winner == "Candidate B":
        win_color = "#1d4ed8"  # blue
    else:
        win_color = "#92400e"  # amber
    # ---- Headline banner ----
    headline = (
        f"<div style='padding:14px;border-radius:12px;background:#f8fafc;"
        f"border:1px solid #e2e8f0;margin-bottom:10px'>"
        f"<div style='font-size:20px;font-weight:800;color:{win_color};'>Winner: {winner}{label_suffix}</div>"
        f"<div style='margin-top:6px;font-size:16px;'>"
        f"Odds A/B{label_suffix} = <b>{ratio_str}</b> | "
        f"log-odds A−B{label_suffix} = <b>{diff:.6f}</b>"
        f"</div>"
        f"<div style='margin-top:6px;color:#475569'>"
        f"(Odds > 1 ⇒ A more probable; < 1 ⇒ B more probable. "
        f"{'Per-token uses average log-prob.' if use_len_norm else 'Whole-sequence comparison.'})"
        f"</div></div>"
    )
    def summarize(label, cand, lp, toks, n):
        # One Markdown card per candidate. safe_exp keeps probability
        # formatting consistent with the odds readout above and guards
        # against non-finite values, instead of calling math.exp directly.
        return (
            f"**{label}**: {cand}\n\n"
            f"Tokenization: {toks}\n"
            f"Total logprob: {lp:.6f}\n"
            f"Sequence probability: {safe_exp(lp)}\n"
            f"Tokens: {n}"
        )
    sumA = summarize("Candidate A", candA, lpA, toksA, nA)
    sumB = summarize("Candidate B", candB, lpB, toksB, nB)
    return headline, sumA, detA, sumB, detB, ""
def swap(a, b):
    """Exchange the two candidate values (used for the A ↔ B sanity check)."""
    pair = (a, b)
    return pair[1], pair[0]
# ---- Gradio UI wiring ----
with gr.Blocks(title="Two-Candidate Next-Token Comparator (Robust)") as demo:
    gr.Markdown(
        "# Two-Candidate Next-Word/Token Probability (Robust)\n"
        "Compare **P(A|context)** vs **P(B|context)** from a pretrained causal LM (no fine-tuning).\n"
        "- Proper tie handling and numerical guards.\n"
        "- Optional **length normalization** (per-token).\n"
        "- Use **Swap** to sanity-check symmetry."
    )
    # Inputs: shared context plus the two candidate continuations to compare.
    with gr.Row():
        context = gr.Textbox(label="Context (prompt)", lines=6, placeholder="Paste prior text here...")
    with gr.Row():
        candA = gr.Textbox(label="Candidate A (follow-up)")
        candB = gr.Textbox(label="Candidate B (follow-up)")
    # Scoring options; semantics live in seq_logprob / compare_candidates.
    with gr.Row():
        assume_space = gr.Checkbox(value=True, label="Assume leading space before candidates (useful for GPT-2 tokenization)")
        topk = gr.Slider(0, 20, value=5, step=1, label="Show top-k alternatives (per token step)")
        use_len_norm = gr.Checkbox(value=False, label="Use length normalization (average log-prob per token)")
    with gr.Row():
        btn_compare = gr.Button("Compare", variant="primary")
        btn_swap = gr.Button("Swap A ↔ B")
    # Outputs: headline verdict, then per-candidate summary + step-by-step detail.
    winner_html = gr.HTML()
    summaryA = gr.Markdown()
    detailsA = gr.Textbox(label="Candidate A — step-by-step", lines=10)
    summaryB = gr.Markdown()
    detailsB = gr.Textbox(label="Candidate B — step-by-step", lines=10)
    # Invisible placeholder so compare_candidates can return a fixed 6-tuple.
    _hidden = gr.Textbox(visible=False)
    btn_compare.click(
        fn=compare_candidates,
        inputs=[context, candA, candB, assume_space, topk, use_len_norm],
        outputs=[winner_html, summaryA, detailsA, summaryB, detailsB, _hidden]
    )
    # Swap writes B into A's textbox and vice versa (no recompute).
    btn_swap.click(
        fn=swap, inputs=[candA, candB], outputs=[candA, candB]
    )
demo.launch()