davidbeaver commited on
Commit
f0253ab
·
verified ·
1 Parent(s): 635598d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -62
app.py CHANGED
@@ -3,98 +3,172 @@ import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
- MODEL_NAME = "gpt2" # swap to "gpt2-medium" etc. if you like
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
  tok = AutoTokenizer.from_pretrained(MODEL_NAME)
10
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
11
  model.eval()
12
 
13
- def next_seq_prob(context, candidate, assume_leading_space, show_topk):
 
 
 
 
14
  if not context.strip():
15
- return "Please enter context.", "", ""
16
- if not candidate.strip():
17
- return "Please enter a candidate next word/token.", "", ""
18
 
19
- # Optionally prepend a leading space (helps align with GPT-2 BPE “word starts)
20
  cand_text = (" " + candidate) if assume_leading_space else candidate
21
 
22
  with torch.no_grad():
23
- # Encode context
24
  ctx_ids = tok.encode(context, return_tensors="pt").to(device)
25
- # Tokenize candidate (no special tokens)
26
  cand_ids = tok.encode(cand_text, add_special_tokens=False)
 
27
  if len(cand_ids) == 0:
28
- return "Candidate tokenized to empty sequence (check spacing).", "", ""
29
 
30
  total_logprob = 0.0
31
- step_details = []
32
-
33
- # Start from context, then feed each candidate token step-by-step (teacher forcing)
34
  input_ids = ctx_ids
 
 
35
  for i, t_id in enumerate(cand_ids):
36
  outputs = model(input_ids=input_ids)
37
- logits = outputs.logits[:, -1, :] # distribution over next token
38
  logprobs = torch.log_softmax(logits, dim=-1)
39
  token_logprob = logprobs[0, t_id].item()
40
  total_logprob += token_logprob
41
 
42
- # top-k display
43
- topk_vals, topk_idx = torch.topk(logprobs, k=min(show_topk, logprobs.shape[-1]), dim=-1)
44
- topk_pairs = [
45
- (tok.decode(int(idx)), float(val))
46
- for idx, val in zip(topk_idx[0].tolist(), topk_vals[0].tolist())
47
- ]
48
-
49
- step_details.append({
50
- "step": i+1,
51
- "predicted_for": tok.decode([t_id]),
52
- "logprob": token_logprob,
53
- "prob": math.exp(token_logprob),
54
- "topk": topk_pairs
55
- })
56
-
57
- # append the true token to continue conditioning
 
 
58
  input_ids = torch.cat([input_ids, torch.tensor([[t_id]], device=device)], dim=1)
59
 
60
- seq_prob = math.exp(total_logprob)
61
- # Human-friendly note about words vs tokens
62
- tokenized_candidate = [tok.decode([i]) for i in cand_ids]
63
- summary = (
64
- f"Candidate tokenization: {tokenized_candidate}\n"
65
- f"Total logprob (chain rule): {total_logprob:.6f}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  f"Sequence probability: {seq_prob:.6e}"
67
  )
68
 
69
- # Pretty print step details
70
- lines = []
71
- for d in step_details:
72
- lines.append(
73
- f"Step {d['step']}: token={repr(d['predicted_for'])} "
74
- f"logprob={d['logprob']:.6f} prob={d['prob']:.6e}"
75
- )
76
- if show_topk > 0:
77
- tops = ", ".join([f"{repr(tok)}:{lp:.2f}" for tok, lp in d["topk"]])
78
- lines.append(f" top-{show_topk} logprobs: {tops}")
79
- detail_text = "\n".join(lines)
80
-
81
- return summary, detail_text, ""
82
-
83
- with gr.Blocks(title="Next-Token Probability (no fine-tuning)") as demo:
84
- gr.Markdown("# Next-Token Probability\n"
85
- "Compute the probability of a chosen next word/token sequence given a prior text segment.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  with gr.Row():
87
- context = gr.Textbox(label="Context (prompt)", lines=6)
 
88
  with gr.Row():
89
- candidate = gr.Textbox(label="Candidate next word / token sequence")
 
 
90
  with gr.Row():
91
- assume_space = gr.Checkbox(value=True, label="Assume leading space before candidate (useful for word starts in GPT-2 tokenization)")
92
- topk = gr.Slider(0, 20, value=10, step=1, label="Show top-k alternatives (per step)")
93
- btn = gr.Button("Compute probability")
94
- summary = gr.Textbox(label="Summary", lines=4)
95
- details = gr.Textbox(label="Step-by-step (per token)", lines=12)
96
- _hidden = gr.Textbox(visible=False) # placeholder
97
-
98
- btn.click(fn=next_seq_prob, inputs=[context, candidate, assume_space, topk], outputs=[summary, details, _hidden])
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  demo.launch()
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
+ # ---- Model config ----
7
+ MODEL_NAME = "gpt2" # e.g., "distilgpt2", "gpt2", "gpt2-medium"
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  tok = AutoTokenizer.from_pretrained(MODEL_NAME)
11
  model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
12
  model.eval()
13
 
14
+ def seq_logprob(context: str, candidate: str, assume_leading_space: bool, show_topk: int):
15
+ """
16
+ Return (total_logprob, step_detail_text, token_list)
17
+ Computes P(candidate | context) via chain rule over the candidate tokens.
18
+ """
19
  if not context.strip():
20
+ return None, "Please enter context.", []
 
 
21
 
22
+ # Helpful for GPT-2 BPE “word starts
23
  cand_text = (" " + candidate) if assume_leading_space else candidate
24
 
25
  with torch.no_grad():
 
26
  ctx_ids = tok.encode(context, return_tensors="pt").to(device)
 
27
  cand_ids = tok.encode(cand_text, add_special_tokens=False)
28
+
29
  if len(cand_ids) == 0:
30
+ return None, "Candidate tokenized to empty sequence (check spacing).", []
31
 
32
  total_logprob = 0.0
33
+ step_lines = []
 
 
34
  input_ids = ctx_ids
35
+ token_texts = []
36
+
37
  for i, t_id in enumerate(cand_ids):
38
  outputs = model(input_ids=input_ids)
39
+ logits = outputs.logits[:, -1, :]
40
  logprobs = torch.log_softmax(logits, dim=-1)
41
  token_logprob = logprobs[0, t_id].item()
42
  total_logprob += token_logprob
43
 
44
+ tok_str = tok.decode([t_id])
45
+ token_texts.append(tok_str)
46
+
47
+ if show_topk > 0:
48
+ topk_vals, topk_idx = torch.topk(logprobs, k=min(show_topk, logprobs.shape[-1]), dim=-1)
49
+ tops = ", ".join([f"{repr(tok.decode([int(idx)]))}:{val.item():.2f}"
50
+ for idx, val in zip(topk_idx[0], topk_vals[0])])
51
+ step_lines.append(
52
+ f"Step {i+1}: token={repr(tok_str)} logprob={token_logprob:.6f} "
53
+ f"prob={math.exp(token_logprob):.6e}\n top-{show_topk}: {tops}"
54
+ )
55
+ else:
56
+ step_lines.append(
57
+ f"Step {i+1}: token={repr(tok_str)} logprob={token_logprob:.6f} "
58
+ f"prob={math.exp(token_logprob):.6e}"
59
+ )
60
+
61
+ # teacher-forcing: append the true token to continue conditioning
62
  input_ids = torch.cat([input_ids, torch.tensor([[t_id]], device=device)], dim=1)
63
 
64
+ detail_text = "\n".join(step_lines)
65
+ return total_logprob, detail_text, token_texts
66
+
67
+ def compare_candidates(context, cand1, cand2, assume_space, topk):
68
+ # Basic input checks
69
+ errs = []
70
+ if not context.strip():
71
+ errs.append("Please enter a context.")
72
+ if not cand1.strip():
73
+ errs.append("Please enter Candidate A.")
74
+ if not cand2.strip():
75
+ errs.append("Please enter Candidate B.")
76
+ if errs:
77
+ return (
78
+ f"<div style='color:#b00020;font-weight:600'>{' '.join(errs)}</div>",
79
+ "", "", "", "", ""
80
+ )
81
+
82
+ # Compute log-probs
83
+ logp1, details1, toks1 = seq_logprob(context, cand1, assume_space, topk)
84
+ logp2, details2, toks2 = seq_logprob(context, cand2, assume_space, topk)
85
+
86
+ if logp1 is None or logp2 is None:
87
+ return (
88
+ "<div style='color:#b00020;font-weight:600'>Tokenization error. Check inputs.</div>",
89
+ details1, details2, "", "", ""
90
+ )
91
+
92
+ # Summaries for each candidate
93
+ def make_summary(label, cand, logp, toks):
94
+ seq_prob = math.exp(logp)
95
+ return (
96
+ f"**{label}**: {cand}\n\n"
97
+ f"Tokenization: {toks}\n"
98
+ f"Total logprob: {logp:.6f}\n"
99
  f"Sequence probability: {seq_prob:.6e}"
100
  )
101
 
102
+ summary1 = make_summary("Candidate A", cand1, logp1, toks1)
103
+ summary2 = make_summary("Candidate B", cand2, logp2, toks2)
104
+
105
+ # Ratio, odds, and winner
106
+ log_odds = logp1 - logp2 # log(P(A)/P(B))
107
+ # Cap extreme ratios for display; still show exact log-odds
108
+ try:
109
+ ratio = math.exp(log_odds)
110
+ ratio_str = f"{ratio:.6e}"
111
+ except OverflowError:
112
+ ratio_str = "(overflow)"
113
+ winner = "Candidate A" if logp1 > logp2 else ("Tie" if abs(log_odds) < 1e-12 else "Candidate B")
114
+
115
+ if winner == "Candidate A":
116
+ win_color = "#166534" # green
117
+ elif winner == "Candidate B":
118
+ win_color = "#1d4ed8" # blue
119
+ else:
120
+ win_color = "#92400e" # amber (tie)
121
+
122
+ headline = (
123
+ f"<div style='padding:14px;border-radius:12px;background:#f8fafc;"
124
+ f"border:1px solid #e2e8f0;margin-bottom:10px'>"
125
+ f"<div style='font-size:20px;font-weight:800;color:{win_color};'>Winner: {winner}</div>"
126
+ f"<div style='margin-top:6px;font-size:16px;'>"
127
+ f"Odds (A/B) = <b>{ratio_str}</b> &nbsp;|&nbsp; "
128
+ f"log-odds = <b>{log_odds:.6f}</b>"
129
+ f"</div>"
130
+ f"<div style='margin-top:6px;color:#475569'>"
131
+ f"(Odds &gt; 1 means A is more probable; &lt; 1 means B is more probable.)"
132
+ f"</div></div>"
133
+ )
134
+
135
+ return headline, summary1, details1, summary2, details2, ""
136
+
137
+ with gr.Blocks(title="Two-Candidate Next-Token Probability Comparator") as demo:
138
+ gr.Markdown(
139
+ "# Two-Candidate Next-Word/Token Probability\n"
140
+ "Given a **context**, compare the conditional probabilities of **two candidate continuations**.\n"
141
+ "- Uses a pretrained causal LM (default: GPT-2). No fine-tuning.\n"
142
+ "- Works at the **token** level; multi-token “words” are handled via the chain rule.\n"
143
+ "- The **Winner** is the higher-probability candidate; we also show the **odds ratio (A/B)** and log-odds."
144
+ )
145
+
146
  with gr.Row():
147
+ context = gr.Textbox(label="Context (prompt)", lines=6, placeholder="Paste your prior text here...")
148
+
149
  with gr.Row():
150
+ cand1 = gr.Textbox(label="Candidate A (follow-up)")
151
+ cand2 = gr.Textbox(label="Candidate B (follow-up)")
152
+
153
  with gr.Row():
154
+ assume_space = gr.Checkbox(
155
+ value=True,
156
+ label="Assume leading space before candidates (helps align with word starts in GPT-2 tokenization)"
157
+ )
158
+ topk = gr.Slider(0, 20, value=5, step=1, label="Show top-k alternatives (per token step)")
159
+
160
+ btn = gr.Button("Compare")
161
+ winner_html = gr.HTML()
162
+ summary1 = gr.Markdown()
163
+ details1 = gr.Textbox(label="Candidate A — step-by-step", lines=10)
164
+ summary2 = gr.Markdown()
165
+ details2 = gr.Textbox(label="Candidate B — step-by-step", lines=10)
166
+ _hidden = gr.Textbox(visible=False)
167
+
168
+ btn.click(
169
+ fn=compare_candidates,
170
+ inputs=[context, cand1, cand2, assume_space, topk],
171
+ outputs=[winner_html, summary1, details1, summary2, details2, _hidden]
172
+ )
173
 
174
  demo.launch()