davidbeaver commited on
Commit
48051fc
·
verified ·
1 Parent(s): 161580b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import math
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Checkpoint to score with; any HF causal-LM checkpoint name works here.
MODEL_NAME = "gpt2" # swap to "gpt2-medium" etc. if you like
# Prefer GPU when available; every tensor built below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

tok = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference only — disables dropout etc.
def next_seq_prob(context, candidate, assume_leading_space, show_topk):
    """Score P(candidate | context) under the causal LM via the chain rule.

    Runs the model once per candidate token (teacher forcing): condition on
    the context, read the next-token log-probability of the gold candidate
    token, append it, and repeat. The sequence probability is the product of
    the per-step probabilities (sum of log-probs, exponentiated).

    Args:
        context: Prompt text the candidate is conditioned on.
        candidate: Proposed next word / token sequence.
        assume_leading_space: If True, prepend " " so the candidate aligns
            with GPT-2 BPE word-start tokens.
        show_topk: Number of top alternative tokens to list per step
            (0 disables the listing). May arrive as a float from the
            Gradio slider, so it is coerced to int below.

    Returns:
        A 3-tuple of strings (summary, per-step detail text, "") matching
        the three Gradio output components.
    """
    if not context.strip():
        return "Please enter context.", "", ""
    if not candidate.strip():
        return "Please enter a candidate next word/token.", "", ""

    # BUGFIX: gr.Slider can deliver a float (e.g. 10.0); torch.topk requires
    # an integer k and would raise TypeError. Coerce once up front — this also
    # keeps the "top-{show_topk}" label rendering as an int.
    show_topk = int(show_topk)

    # Optionally prepend a leading space (helps align with GPT-2 BPE “word” starts)
    cand_text = (" " + candidate) if assume_leading_space else candidate

    with torch.no_grad():
        # Encode context
        ctx_ids = tok.encode(context, return_tensors="pt").to(device)
        # Tokenize candidate (no special tokens)
        cand_ids = tok.encode(cand_text, add_special_tokens=False)
        if len(cand_ids) == 0:
            return "Candidate tokenized to empty sequence (check spacing).", "", ""

        total_logprob = 0.0
        step_details = []

        # Start from context, then feed each candidate token step-by-step
        # (teacher forcing).
        input_ids = ctx_ids
        for i, t_id in enumerate(cand_ids):
            outputs = model(input_ids=input_ids)
            logits = outputs.logits[:, -1, :]  # distribution over next token
            logprobs = torch.log_softmax(logits, dim=-1)
            token_logprob = logprobs[0, t_id].item()
            total_logprob += token_logprob

            # Top-k alternatives for display; skip the work entirely when
            # the listing is disabled (it is only ever shown for k > 0).
            topk_pairs = []
            if show_topk > 0:
                topk_vals, topk_idx = torch.topk(
                    logprobs, k=min(show_topk, logprobs.shape[-1]), dim=-1
                )
                topk_pairs = [
                    (tok.decode(int(idx)), float(val))
                    for idx, val in zip(topk_idx[0].tolist(), topk_vals[0].tolist())
                ]

            step_details.append({
                "step": i + 1,
                "predicted_for": tok.decode([t_id]),
                "logprob": token_logprob,
                "prob": math.exp(token_logprob),
                "topk": topk_pairs,
            })

            # Append the true token to continue conditioning.
            input_ids = torch.cat(
                [input_ids, torch.tensor([[t_id]], device=device)], dim=1
            )

    seq_prob = math.exp(total_logprob)
    # Human-friendly note about words vs tokens.
    tokenized_candidate = [tok.decode([i]) for i in cand_ids]
    summary = (
        f"Candidate tokenization: {tokenized_candidate}\n"
        f"Total logprob (chain rule): {total_logprob:.6f}\n"
        f"Sequence probability: {seq_prob:.6e}"
    )

    # Pretty print step details. Loop variable renamed from `tok` to avoid
    # shadowing the module-level tokenizer.
    lines = []
    for d in step_details:
        lines.append(
            f"Step {d['step']}: token={repr(d['predicted_for'])} "
            f"logprob={d['logprob']:.6f} prob={d['prob']:.6e}"
        )
        if show_topk > 0:
            tops = ", ".join(f"{repr(t)}:{lp:.2f}" for t, lp in d["topk"])
            lines.append(f" top-{show_topk} logprobs: {tops}")
    detail_text = "\n".join(lines)

    return summary, detail_text, ""
# ----- Gradio UI: three text outputs mirror next_seq_prob's 3-tuple return -----
with gr.Blocks(title="Next-Token Probability (no fine-tuning)") as demo:
    gr.Markdown("# Next-Token Probability\n"
                "Compute the probability of a chosen next word/token sequence given a prior text segment.")
    with gr.Row():
        context = gr.Textbox(label="Context (prompt)", lines=6)
    with gr.Row():
        candidate = gr.Textbox(label="Candidate next word / token sequence")
    with gr.Row():
        # Leading-space toggle matters because GPT-2 BPE encodes word starts
        # with a preceding space; slider picks how many alternatives to show.
        assume_space = gr.Checkbox(value=True, label="Assume leading space before candidate (useful for word starts in GPT-2 tokenization)")
        topk = gr.Slider(0, 20, value=10, step=1, label="Show top-k alternatives (per step)")
    btn = gr.Button("Compute probability")
    summary = gr.Textbox(label="Summary", lines=4)
    details = gr.Textbox(label="Step-by-step (per token)", lines=12)
    # Hidden third output absorbs the function's always-empty third return value.
    _hidden = gr.Textbox(visible=False) # placeholder

    btn.click(fn=next_seq_prob, inputs=[context, candidate, assume_space, topk], outputs=[summary, details, _hidden])

demo.launch()