Port steganacrostics to a Gradio app; retarget to MiniCPM5-1B

#1
by lsb - opened
Files changed (12) hide show
  1. .gitignore +16 -0
  2. README.md +50 -5
  3. app.py +334 -54
  4. classifier.py +124 -0
  5. crossing_search.py +292 -0
  6. eval_classifier.py +185 -0
  7. grammar.py +169 -0
  8. logits.py +84 -0
  9. masking.py +53 -0
  10. requirements.txt +5 -0
  11. sweep_minicpm.py +121 -0
  12. tokinfo.py +46 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python venv (huge: torch, transformers, …) — never commit
2
+ .venv/
3
+ venv/
4
+
5
+ # Byte-compiled / caches
6
+ __pycache__/
7
+ *.py[cod]
8
+ *.egg-info/
9
+ .ipynb_checkpoints/
10
+
11
+ # Local model / HF caches (models are downloaded at runtime)
12
+ .cache/
13
+ hf_cache/
14
+
15
+ # OS cruft
16
+ .DS_Store
README.md CHANGED
@@ -4,14 +4,59 @@ emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
10
- hf_oauth: true
11
- hf_oauth_scopes:
12
- - inference-api
13
  license: apache-2.0
14
  short_description: Completely normal text assistant, with talking on the side
 
 
 
 
 
 
15
  ---
16
 
17
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 6.18.0
8
  app_file: app.py
9
  pinned: false
 
 
 
10
  license: apache-2.0
11
  short_description: Completely normal text assistant, with talking on the side
12
+ tags:
13
+ - track:wood
14
+ - sponsor:openbmb
15
+ - achievement:offgrid
16
+ - achievement:sharing
17
+ - achievement:fieldnotes
18
  ---
19
 
20
+ # Side chat
21
+
22
+ A Gradio port of the browser steganacrostics app. A completely normal text
23
+ assistant — except every line of the answer secretly starts with the next
24
+ letter of a hidden **secret** word (an acrostic). It does this with
25
+ grammar-constrained decoding over a small local model (`openbmb/MiniCPM5-1B` by
26
+ default; set `SIDECHAT_MODEL=LiquidAI/LFM2.5-350M` for the smaller, faster
27
+ original), running on **CPU** via PyTorch `transformers`.
28
+
29
+ What's ported from the JavaScript original (`../../src/`):
30
+
31
+ - **Grammar engine** (`grammar.py`) — a tiny NFA that pins each line to its
32
+ forced first letter, with optional ` * ` bullets and a max line length.
33
+ - **Constrained generation** (`logits.py` + `masking.py`) — a `LogitsProcessor`
34
+ that masks every token that would break the acrostic; EOS only at an accept
35
+ state. A state-keyed cache makes the per-step vocab scan cheap.
36
+ - **List-vs-prose classifier** (`classifier.py`) — an optimized prompt,
37
+ grammar-constrained to `list.` / `story.`, that auto-picks the render mode.
38
+ The prompt is tuned per model: failure modes are model-specific, so
39
+ `eval_classifier.py` (50 list + 50 prose prompts) and `sweep_minicpm.py`
40
+ re-optimize it for whatever model is in use.
41
+ - **Local-crossing search** (`crossing_search.py`) — the "extra attention at
42
+ the constraint": generate each prose line greedily, then choose where to
43
+ break it so a short window straddling the crossing (last *k* tokens + forced
44
+ letter + next *j* tokens) reads best. Plus stealth lowercase casing and a
45
+ minimum line length.
46
+
47
+ Run locally:
48
+
49
+ ```
50
+ pip install -r requirements.txt
51
+ python app.py
52
+ ```
53
+
54
+ Then open the printed URL, type a prompt, set a secret in ⚙️ Settings, and click
55
+ Generate. The list-vs-prose classifier runs automatically on each Generate (turn
56
+ it off in ⚙️ Settings to set the render mode by hand, or use 🔎 Detect to preview
57
+ it). Because everything runs on CPU, generation takes seconds (more for the
58
+ larger model); the crossing search trades extra time for smoother prose.
59
+
60
+ The model is downloaded from the Hugging Face Hub on first run. Custom logits
61
+ processing requires the model to run in-process, so this app does not use the
62
+ remote Inference API.
app.py CHANGED
@@ -1,69 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
 
 
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
20
 
21
- messages.extend(history)
22
 
23
- messages.append({"role": "user", "content": message})
 
24
 
25
- response = ""
 
 
 
 
 
 
 
 
 
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
41
 
 
 
 
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- with gr.Blocks() as demo:
63
- with gr.Sidebar():
64
- gr.LoginButton()
65
- chatbot.render()
 
 
 
 
66
 
67
 
68
  if __name__ == "__main__":
69
- demo.launch()
 
1
+ """Side chat — a Gradio port of the browser steganacrostics app.
2
+
3
+ Completely normal text assistant, with a secret talking on the side: every line
4
+ of the answer starts with successive letters of a hidden "secret" word (an
5
+ acrostic), produced by grammar-constrained decoding. A list-vs-prose classifier
6
+ auto-picks the render mode, and an optional local-crossing search spends extra
7
+ attention at each constraint cliff so the forced letters read as the natural
8
+ next word.
9
+
10
+ Runs the model locally on CPU with PyTorch transformers (the remote Inference
11
+ API can't do custom logits processing, which is the whole point here).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ import re
18
+ import threading
19
+ import queue
20
+ import time
21
+
22
+ import torch
23
+ from transformers import (
24
+ AutoModelForCausalLM,
25
+ AutoTokenizer,
26
+ LogitsProcessorList,
27
+ TextIteratorStreamer,
28
+ )
29
+
30
  import gradio as gr
 
31
 
32
+ from grammar import compile_acrostic, union_grammars
33
+ from logits import GrammarLogitsProcessor, build_token_text_table
34
+ from tokinfo import build_tok_info
35
+ from classifier import classify, DEFAULT_VARIANT
36
+ from crossing_search import generate_crossing_search
37
 
38
+ # Default to MiniCPM5-1B (OpenBMB); override with SIDECHAT_MODEL, e.g.
39
+ # SIDECHAT_MODEL=LiquidAI/LFM2.5-350M for the smaller, faster original.
40
+ MODEL_ID = os.environ.get("SIDECHAT_MODEL", "openbmb/MiniCPM5-1B")
41
+ DEVICE = "cpu" # pure CPU by request
 
 
 
 
 
 
 
 
 
42
 
43
+ LIST_SYSTEM = (
44
+ "You are a helpful assistant. Answer as a plain bulleted list — one short "
45
+ "item per line. Do not use markdown, bold text, headings, code, or numbered "
46
+ "lists."
47
+ )
48
+ PROSE_SYSTEM = (
49
+ "You are a helpful assistant. Answer in plain prose. Do not use markdown, "
50
+ "bold text, headings, code, or bulleted/numbered lists."
51
+ )
52
 
 
53
 
54
+ class Context:
55
+ """Everything the generation + classifier code needs, built once at startup."""
56
 
57
+ def __init__(self):
58
+ print(f"loading {MODEL_ID} on {DEVICE}…", flush=True)
59
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
60
+ self.model = (
61
+ AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32)
62
+ .to(DEVICE)
63
+ .eval()
64
+ )
65
+ self.model.device # noqa: B018 (touch to confirm)
66
+ vocab = self.model.config.vocab_size
67
 
68
+ t0 = time.perf_counter()
69
+ self.token_text = build_token_text_table(self.tokenizer, vocab)
70
+ print(f"token table built in {time.perf_counter() - t0:.1f}s ({vocab} tokens)", flush=True)
 
 
 
 
 
 
 
 
71
 
72
+ eos = set()
 
73
 
74
+ def add_eos(x):
75
+ if x is None:
76
+ return
77
+ if isinstance(x, (list, tuple)):
78
+ for y in x:
79
+ add_eos(y)
80
+ else:
81
+ eos.add(int(x))
82
 
83
+ add_eos(self.tokenizer.eos_token_id)
84
+ add_eos(getattr(self.model.config, "eos_token_id", None))
85
+ add_eos(getattr(self.model.generation_config, "eos_token_id", None))
86
+ self.eos_token_ids = sorted(eos)
87
+
88
+ pad = self.tokenizer.pad_token_id
89
+ if pad is None:
90
+ pad = getattr(self.model.generation_config, "pad_token_id", None)
91
+ if pad is None:
92
+ pad = self.eos_token_ids[0]
93
+ self.pad_token_id = int(pad)
94
+
95
+ self.tok_info = build_tok_info(self.token_text, self.eos_token_ids)
96
+ print("context ready.", flush=True)
97
+
98
+
99
+ CTX = Context()
100
+
101
+
102
+ # Case-insensitive acrostic; in list mode the very first ` * ` prefix is optional
103
+ # (some models start with a preamble-free letter, others don't).
104
+ def build_grammar(secret, list_mode, max_line):
105
+ if not list_mode:
106
+ return compile_acrostic(secret, list_prefix="", max_line=max_line, case_insensitive=True)
107
+ with_prefix = compile_acrostic(
108
+ secret, list_prefix=" * ", max_line=max_line, case_insensitive=True, first_line_prefix=True
109
+ )
110
+ without_prefix = compile_acrostic(
111
+ secret, list_prefix=" * ", max_line=max_line, case_insensitive=True, first_line_prefix=False
112
+ )
113
+ return union_grammars([with_prefix, without_prefix])
114
+
115
+
116
+ def check_acrostic(output, secret):
117
+ """Find a window of line-initial letters matching the secret (case-insensitive,
118
+ bullets stripped). Returns (ok, firsts)."""
119
+ lines = [l.strip() for l in output.split("\n")]
120
+ lines = [l for l in lines if l]
121
+ def strip(l):
122
+ return re.sub(r"^\*?\s*", "", l)
123
+ firsts = [(strip(l)[:1] or "") for l in lines]
124
+ n = len(secret)
125
+ for i in range(0, len(firsts) - n + 1):
126
+ if all(firsts[i + j].lower() == secret[j].lower() for j in range(n)):
127
+ return True, "".join(firsts[i:i + n])
128
+ return False, "".join(firsts)
129
+
130
+
131
+ # --- Classifier: drive the list/prose checkbox ------------------------------
132
+ def classify_fn(prompt):
133
+ if not (prompt or "").strip():
134
+ return gr.update(), "enter a prompt to detect list vs. prose"
135
+ pred, raw = classify(CTX, prompt, DEFAULT_VARIANT)
136
+ label = "list" if pred else "prose"
137
+ return pred, f"detected **{label}** (classifier raw: {raw!r})"
138
+
139
+
140
+ def maybe_detect(prompt, list_mode, auto_detect):
141
+ """Runs before Generate: when auto-detect is on, classify the prompt and set
142
+ the list/prose checkbox from it. Otherwise leave the manual choice alone."""
143
+ if auto_detect and (prompt or "").strip():
144
+ pred, raw = classify(CTX, prompt, DEFAULT_VARIANT)
145
+ return pred, f"detected **{'list' if pred else 'prose'}** (raw {raw!r}) — generating…"
146
+ return list_mode, gr.update()
147
+
148
+
149
+ # --- Generation -------------------------------------------------------------
150
+ def _run_in_thread(target):
151
+ """Run target() in a daemon thread; return a queue it pushes to. target
152
+ receives the queue and must push a None sentinel when finished."""
153
+ q = queue.Queue()
154
+ threading.Thread(target=target, args=(q,), daemon=True).start()
155
+ return q
156
+
157
+
158
+ def generate_fn(prompt, secret, list_mode, max_line, crossing, k, j, R, min_line):
159
+ # Strip spaces: a multi-word secret spells its letters across lines; spaces
160
+ # would force odd punctuation-prefixed "word-break" lines. The field still
161
+ # shows the spaced version; the acrostic uses only the letters.
162
+ secret = re.sub(r"\s+", "", (secret or "").strip())
163
+ if not secret:
164
+ yield "(secret is empty — open ⚙️ Settings and set one)", "", ""
165
+ return
166
+
167
+ list_mode = bool(list_mode)
168
+ max_line = max(1, int(max_line or 80))
169
+ system_prompt = LIST_SYSTEM if list_mode else PROSE_SYSTEM
170
+
171
+ try:
172
+ grammar = build_grammar(secret, list_mode, max_line)
173
+ except Exception as e: # noqa: BLE001
174
+ yield f"grammar build error: {e}", "", ""
175
+ return
176
+
177
+ # --- Local-crossing search (prose only) ---------------------------------
178
+ if crossing and not list_mode:
179
+ k = max(0, int(k or 4))
180
+ j = max(0, int(j or 3))
181
+ R = max(0, int(R or 4))
182
+ min_line = min(max_line, max(0, int(min_line or 30)))
183
+ status = f"generating (local-crossing search · k={k}, j={j}, R={R}, minLine={min_line})…"
184
+ committed = [""]
185
+ t0 = time.perf_counter()
186
+
187
+ def worker(q):
188
+ def on_line(line_text, info):
189
+ committed[0] += line_text
190
+ q.put(committed[0])
191
+ try:
192
+ res = generate_crossing_search(
193
+ CTX, grammar, secret, max_line, prompt, system_prompt,
194
+ k=k, j=j, R=R, min_line=min_line, on_line=on_line,
195
+ )
196
+ q.put(("done", res))
197
+ except Exception as e: # noqa: BLE001
198
+ q.put(("error", str(e)))
199
+ q.put(None)
200
+
201
+ q = _run_in_thread(worker)
202
+ result = None
203
+ yield "", "", status
204
+ while True:
205
+ item = q.get()
206
+ if item is None:
207
+ break
208
+ if isinstance(item, tuple) and item[0] == "done":
209
+ result = item[1]
210
+ elif isinstance(item, tuple) and item[0] == "error":
211
+ yield committed[0], f"error: {item[1]}", "error"
212
+ else:
213
+ yield item, "", status
214
+ elapsed = time.perf_counter() - t0
215
+ text = result["text"] if result else committed[0]
216
+ per_line = result["per_line"] if result else []
217
+ n_moved = sum(1 for p in per_line if p.get("r", 0) > 0)
218
+ ok, firsts = check_acrostic(text, secret)
219
+ metrics = (
220
+ f"local-crossing · {elapsed:.2f}s · {len(per_line)} lines · "
221
+ f"{n_moved} breaks moved · acrostic {'OK' if ok else 'MISS'} ({firsts})"
222
+ )
223
+ yield text, metrics, "done (local-crossing search)."
224
+ return
225
+
226
+ # --- Plain grammar-constrained greedy (token-streamed) ------------------
227
+ proc = GrammarLogitsProcessor(grammar, CTX.tokenizer, CTX.token_text, CTX.eos_token_ids)
228
+ messages = [
229
+ {"role": "system", "content": system_prompt},
230
+ {"role": "user", "content": prompt},
231
+ ]
232
+ enc = CTX.tokenizer.apply_chat_template(
233
+ messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
234
+ ).to(DEVICE)
235
+ streamer = TextIteratorStreamer(CTX.tokenizer, skip_prompt=True, skip_special_tokens=True)
236
+
237
+ gen_kwargs = dict(
238
+ **enc,
239
+ max_new_tokens=400,
240
+ do_sample=False,
241
+ logits_processor=LogitsProcessorList([proc]),
242
+ streamer=streamer,
243
+ pad_token_id=CTX.pad_token_id,
244
+ )
245
+
246
+ t0 = time.perf_counter()
247
+ thread = threading.Thread(target=CTX.model.generate, kwargs=gen_kwargs, daemon=True)
248
+ thread.start()
249
+
250
+ acc = ""
251
+ t_first = None
252
+ tokens = 0
253
+ chars = 0
254
+ yield "", "", "generating (grammar-constrained)…"
255
+ for chunk in streamer:
256
+ if not chunk:
257
+ continue
258
+ if t_first is None:
259
+ t_first = time.perf_counter()
260
+ tokens += 1
261
+ chars += len(chunk)
262
+ acc += chunk
263
+ gen_s = max(0.001, time.perf_counter() - t_first)
264
+ tps = tokens / gen_s
265
+ ttft = (t_first - t0)
266
+ yield acc, f"TTFT {ttft:.2f}s · ~{tps:.1f} tok/s · {tokens} tokens · {chars} chars", "generating…"
267
+ thread.join()
268
+
269
+ wall = time.perf_counter() - t0
270
+ s = proc.stats
271
+ proc_ms = s["total_ms"]
272
+ ttft = (t_first - t0) if t_first else 0.0
273
+ ok, firsts = check_acrostic(acc, secret)
274
+ metrics = (
275
+ f"TTFT {ttft:.2f}s · {tokens} tokens · {chars} chars · wall {wall:.2f}s · "
276
+ f"mask {proc_ms:.0f}ms ({(proc_ms/1000)/wall*100:.0f}%) · "
277
+ f"acrostic {'OK' if ok else 'MISS'} ({firsts})"
278
+ )
279
+ yield acc, metrics, "done. edit the secret and/or prompt and click Generate again."
280
+
281
+
282
+ # --- UI ---------------------------------------------------------------------
283
+ with gr.Blocks(title="Side chat") as demo:
284
+ gr.Markdown("# Side chat")
285
+ gr.Markdown(
286
+ "Completely normal text assistant, with talking on the side. Each line "
287
+ "of the answer secretly starts with the next letter of your **secret** "
288
+ f"word — grammar-constrained decoding on `{MODEL_ID}`, running locally on CPU."
289
+ )
290
+
291
+ prompt = gr.Textbox(label="Prompt", value="what are some easy-to-make home recipes?", lines=2)
292
+ gr.Examples(
293
+ examples=[
294
+ ["what are some easy-to-make home recipes?"],
295
+ ["please write a few sentences about regular expressions"],
296
+ ],
297
+ inputs=prompt,
298
+ label="Demo prompts (one detects as a list, one as prose)",
299
+ )
300
+ run = gr.Button("Generate", variant="primary")
301
+ output = gr.Textbox(label="Output", lines=10, interactive=False)
302
+ metrics = gr.Markdown("")
303
+
304
+ with gr.Accordion("⚙️ Settings", open=False):
305
+ secret = gr.Textbox(
306
+ label="Secret (each line will start with these letters)", value="subtle"
307
+ )
308
+ auto_detect = gr.Checkbox(
309
+ label="auto-detect list vs. prose on Generate (LLM classifier)",
310
+ value=True,
311
+ )
312
+ list_mode = gr.Checkbox(
313
+ label="render as bulleted list (each line prefixed with ` * `) — "
314
+ "set by auto-detect; uncheck auto-detect to set it manually",
315
+ value=True,
316
+ )
317
+ # Manual preview: run the classifier without generating (debug aid).
318
+ detect = gr.Button("🔎 Detect list / prose (preview only)", size="sm")
319
+ max_line = gr.Number(label="Max chars per line (after the prefix + letter)", value=80, precision=0)
320
+
321
+ gr.Markdown("**Local-crossing search** (prose only) — extra attention at each constraint cliff")
322
+ crossing = gr.Checkbox(
323
+ label="enable local-crossing search (greedy line, then pick the break "
324
+ "that makes the crossing read best; list mode stays greedy)",
325
+ value=False,
326
+ )
327
+ win_k = gr.Number(label="↳ window before the break (k content tokens)", value=4, precision=0)
328
+ win_j = gr.Number(label="↳ window after the forced letter (j content tokens)", value=3, precision=0)
329
+ max_rewind = gr.Number(label="↳ max tokens to trim the break earlier (R; 0 = greedy)", value=4, precision=0)
330
+ min_line = gr.Number(label="↳ min chars per line (avoid stubby lines; 0 = off)", value=30, precision=0)
331
+
332
+ status = gr.Markdown("ready.")
333
+
334
+ # Manual preview: detect list vs. prose without generating.
335
+ detect.click(classify_fn, [prompt], [list_mode, status])
336
+ prompt.submit(classify_fn, [prompt], [list_mode, status])
337
 
338
+ # Generate: auto-detect first (updates the checkbox), then generate using it.
339
+ run.click(
340
+ maybe_detect, [prompt, list_mode, auto_detect], [list_mode, status]
341
+ ).then(
342
+ generate_fn,
343
+ [prompt, secret, list_mode, max_line, crossing, win_k, win_j, max_rewind, min_line],
344
+ [output, metrics, status],
345
+ )
346
 
347
 
348
  if __name__ == "__main__":
349
+ demo.queue().launch(theme=gr.themes.Soft())
classifier.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """List-vs-prose classifier (Python port of the shipped part of src/eval.js).
2
+
3
+ The classifier reads the user's prompt and decides whether the answer is best
4
+ rendered as a bulleted list or as narrative prose. It is itself an LLM call,
5
+ grammar-constrained to exactly one of two literal completions: apply the chat
6
+ template, append a partial assistant response (the `prefill`), constrain
7
+ generation to one of `branches`, parse the result.
8
+
9
+ Failure modes are model-specific, so the prompt is tuned per model. The default
10
+ here is the MiniCPM5-1B winner (`minicpm_intent_write_sp`, 96% on the 100-prompt
11
+ suite) found by re-running the sweep (eval_classifier.py / sweep_minicpm.py) on
12
+ that model. The LFM2.5-350M winner (`r6_c1_v2_single_plural`, 97.5% dev / 85%
13
+ val) is kept as an alternate — it is *prose-biased* on MiniCPM (~75%), so don't
14
+ reuse it there. See CLASSIFIER_PROMPT_OPTIMIZATION.md for the original JS sweep.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from dataclasses import dataclass
20
+ from typing import Callable, List
21
+
22
+ import torch
23
+ from transformers import LogitsProcessorList
24
+
25
+ from grammar import compile_literal, union_grammars
26
+ from logits import GrammarLogitsProcessor
27
+
28
+
29
+ @dataclass
30
+ class Variant:
31
+ name: str
32
+ system: str
33
+ prefill: str
34
+ branches: List[str]
35
+ parse: Callable[[str], bool] # raw generated text -> True (list) / False (prose)
36
+
37
+
38
+ # Shared trigger-rule strings.
39
+ _INTENT_BASE = (
40
+ "Classify the user's intent. Use \"list\" when the answer is a set of "
41
+ "separate items the user can scan. Use \"story\" when the answer flows as "
42
+ "one narrative, single fact, or short paragraph."
43
+ )
44
+ _WRITE_FORMS = (
45
+ " Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, "
46
+ "cover letter, email, joke, story, essay, or limerick, the answer is a story."
47
+ )
48
+ _SINGLE_PLURAL = (
49
+ " \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural "
50
+ "enumeration) is a list; \"what are the steps/differences/causes/symptoms\" "
51
+ "is a list."
52
+ )
53
+
54
+ # --- The shipped MiniCPM5-1B winner -----------------------------------------
55
+ # On MiniCPM, every "Default to list" framing collapses to all-story (list 0/50)
56
+ # and the LFM2 winner is prose-biased. A neutral *intent* framing nails list
57
+ # recall; adding the write-forms rule (catches "write a haiku/email") and the
58
+ # single-vs-plural rule (catches "what is X" single facts) fixes the residual
59
+ # prose misses. 96% on the 100-prompt suite (list 49/50, prose 47/50).
60
+ DEFAULT_VARIANT = Variant(
61
+ name="minicpm_intent_write_sp",
62
+ system=_INTENT_BASE + _WRITE_FORMS + _SINGLE_PLURAL,
63
+ prefill="The intent is to get a ",
64
+ branches=["list.", "story."],
65
+ parse=lambda s: s.startswith("list"),
66
+ )
67
+
68
+ # --- Reference alternates (other strong variants; useful when re-tuning) -----
69
+ ALTERNATES = [
70
+ # The LFM2.5-350M winner (97.5% dev / 85% val on LFM2; ~75% on MiniCPM).
71
+ Variant(
72
+ name="r6_c1_v2_single_plural",
73
+ system=(
74
+ "Classify the user's request. Use \"list\" when the user wants "
75
+ "enumerated items. Use \"story\" for everything else. \"What is X\" "
76
+ "(a single fact) is a story; \"What are the/some Xs\" (plural "
77
+ "enumeration) is a list; \"what are the steps/differences/causes/"
78
+ "symptoms\" is a list."
79
+ ),
80
+ prefill="The user is asking for a ",
81
+ branches=["list.", "story."],
82
+ parse=lambda s: s.startswith("list"),
83
+ ),
84
+ # Intent base + single-plural only (100% screen, 93% full on MiniCPM;
85
+ # perfect list recall but misses some "write a X" prose prompts).
86
+ Variant(
87
+ name="minicpm_intent_sp",
88
+ system=_INTENT_BASE + _SINGLE_PLURAL,
89
+ prefill="The intent is to get a ",
90
+ branches=["list.", "story."],
91
+ parse=lambda s: s.startswith("list"),
92
+ ),
93
+ ]
94
+
95
+ VARIANTS = [DEFAULT_VARIANT, *ALTERNATES]
96
+
97
+
98
+ def classify(ctx, prompt, variant=DEFAULT_VARIANT):
99
+ """Run one classifier call. ctx is a Context (see app.py): .model,
100
+ .tokenizer, .token_text, .eos_token_ids. Returns (prediction, raw)."""
101
+ tok = ctx.tokenizer
102
+ messages = [
103
+ {"role": "system", "content": variant.system},
104
+ {"role": "user", "content": prompt},
105
+ ]
106
+ templated = tok.apply_chat_template(
107
+ messages, tokenize=False, add_generation_prompt=True
108
+ )
109
+ full_text = templated + variant.prefill
110
+
111
+ grammar = union_grammars([compile_literal(b) for b in variant.branches])
112
+ proc = GrammarLogitsProcessor(grammar, tok, ctx.token_text, ctx.eos_token_ids)
113
+
114
+ enc = tok(full_text, return_tensors="pt", add_special_tokens=False).to(ctx.model.device)
115
+ with torch.no_grad():
116
+ out = ctx.model.generate(
117
+ **enc,
118
+ max_new_tokens=16,
119
+ do_sample=False,
120
+ logits_processor=LogitsProcessorList([proc]),
121
+ pad_token_id=ctx.pad_token_id,
122
+ )
123
+ raw = tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True)
124
+ return variant.parse(raw), raw
crossing_search.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local-crossing-objective search for acrostics (Python port of
2
+ src/crossingSearch.js + the LineMaskScore / NewlineStop bits of
3
+ src/surprisalLookahead.js).
4
+
5
+ This is the "extra attention when we come to a constraint": the forced first
6
+ letter of each line is a cliff where the constraint fights what the model wants
7
+ to say. Greedy is a strong baseline; search only beats it when the *objective*
8
+ is right. So we change only the objective:
9
+
10
+ 1. Score a SHORT fixed window straddling the crossing — the last `k` content
11
+ tokens before the break, plus the forced letter and the next `j` tokens.
12
+ Length-neutral; the structural newline is never scored, so there's no
13
+ run-to-the-wall bias.
14
+ 2. Look `j` tokens PAST the forced letter (does the next line *continue*
15
+ well?), not just at it.
16
+ 3. Make the break point a search variable, snapped to word boundaries:
17
+ generate the line greedily, then consider ending it 0..R tokens earlier.
18
+ r=0 (greedy) is always a candidate, so this can only match or beat greedy.
19
+
20
+ Plus two stealth touches carried in LineMaskScore: lowercase the forced letter
21
+ mid-sentence (so the acrostic hides), and a minimum line length (no stubby
22
+ lines). Public-API only: each line/rollout is a fresh model.generate()
23
+ continuation of the chat-templated prompt + committed text fed back as a string.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import re
29
+
30
+ import numpy as np
31
+ import torch
32
+ from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
33
+
34
+ from masking import LegalCache
35
+
36
+
37
+ # Mid-sentence iff the text so far doesn't end a sentence — then the next forced
38
+ # letter should be lowercase. Empty prefix (line 0) is a sentence start.
39
+ _SENTENCE_END = re.compile(r"[.!?][\"'”’)\]]?$")
40
+ _WORD_START = re.compile(r"^[\s.,;:!?)\]\"'’”]")
41
+
42
+
43
+ def mid_sentence(t):
44
+ s = re.sub(r"\s+$", "", t or "")
45
+ return len(s) > 0 and not _SENTENCE_END.search(s)
46
+
47
+
48
+ class LineMaskScore(LogitsProcessor):
49
+ """Grammar-masking processor that also accumulates the greedy chosen-token
50
+ log-prob per step, applies stealth casing + minimum line length, and (when
51
+ capture_top_n > 0, at the first step) records the top-N legal openings and a
52
+ surprise signal."""
53
+
54
+ def __init__(self, grammar, start_state, tokenizer, token_text, tok_info, cache,
55
+ capture_top_n=0, force_lower_first=False, min_line=0):
56
+ super().__init__()
57
+ self.grammar = grammar
58
+ self.start_state = start_state
59
+ self.tokenizer = tokenizer
60
+ self.token_text = token_text
61
+ self.info = tok_info
62
+ self.cache = cache
63
+ self.capture_top_n = capture_top_n
64
+ self.force_lower_first = force_lower_first
65
+ self.min_line = min_line
66
+ self.prompt_length = None
67
+ self.step_logprobs = [] # chosen (argmax) log-prob, one per generated step
68
+ self.top_n = None
69
+ self.surprise = None
70
+
71
+ def __call__(self, input_ids, scores):
72
+ ids = input_ids[0]
73
+ if self.prompt_length is None:
74
+ self.prompt_length = ids.shape[0]
75
+ generated = ids[self.prompt_length:].tolist()
76
+ gen = (
77
+ self.tokenizer.decode(generated, skip_special_tokens=True)
78
+ if generated
79
+ else ""
80
+ )
81
+
82
+ state = self.grammar.advance(self.start_state, gen)
83
+ data = scores[0]
84
+ if state == -1:
85
+ self.step_logprobs.append(0.0)
86
+ return scores
87
+
88
+ first_step = self.top_n is None and self.surprise is None
89
+ want_signal = self.capture_top_n and first_step
90
+
91
+ lse_all = None
92
+ if want_signal:
93
+ lse_all = torch.logsumexp(data, dim=0).item()
94
+
95
+ # Grammar-legal set (shared, state-cached), then stealth + minLine on top.
96
+ legal = self.cache.legal_np(state).copy()
97
+ info = self.info
98
+
99
+ if self.force_lower_first and len(gen) == 0:
100
+ if np.any(legal & info.alpha_lower):
101
+ legal &= ~info.alpha_upper
102
+
103
+ if self.min_line and len(gen) < self.min_line:
104
+ body = legal & ~info.eos_mask & info.nonempty & ~info.has_newline
105
+ if np.any(body):
106
+ legal &= ~(info.eos_mask | info.has_newline)
107
+
108
+ illegal = torch.from_numpy(~legal).to(data.device)
109
+ data[illegal] = float("-inf")
110
+
111
+ max_legal = torch.max(data)
112
+ if max_legal.item() == float("-inf"):
113
+ self.step_logprobs.append(0.0)
114
+ return scores
115
+ lse_masked = torch.logsumexp(data, dim=0)
116
+ self.step_logprobs.append((max_legal - lse_masked).item())
117
+
118
+ if want_signal:
119
+ self.surprise = lse_all - max_legal.item()
120
+ legal_idx = np.nonzero(legal)[0]
121
+ logits = data[torch.from_numpy(legal_idx).to(data.device)]
122
+ order = torch.argsort(logits, descending=True)[: self.capture_top_n]
123
+ lse_m = lse_masked.item()
124
+ self.top_n = [
125
+ {"id": int(legal_idx[int(o)]), "logit": float(logits[int(o)]),
126
+ "logprob": float(logits[int(o)]) - lse_m}
127
+ for o in order
128
+ ]
129
+ return scores
130
+
131
+
132
+ class NewlineStop(StoppingCriteria):
133
+ """Stop a rollout as soon as the newly generated token contains a newline."""
134
+
135
+ def __init__(self, prompt_length, has_newline):
136
+ super().__init__()
137
+ self.prompt_length = prompt_length
138
+ self.has_newline = has_newline
139
+
140
+ def __call__(self, input_ids, scores, **kwargs):
141
+ out = []
142
+ for ids in input_ids:
143
+ if ids.shape[0] <= self.prompt_length:
144
+ out.append(False)
145
+ else:
146
+ out.append(bool(self.has_newline[int(ids[-1])]))
147
+ return torch.tensor(out, dtype=torch.bool, device=input_ids.device)
148
+
149
+
150
+ def _mean(a):
151
+ return sum(a) / len(a) if a else 0.0
152
+
153
+
154
+ def generate_crossing_search(ctx, grammar, secret, max_line, prompt, system_prompt,
155
+ k=4, j=3, R=4, min_line=30, on_line=None):
156
+ """Generate acrostic text with the local-crossing search. Returns
157
+ {"text": str, "per_line": [...]}. on_line(line_text, info) is called as each
158
+ line is committed (for incremental display)."""
159
+ model = ctx.model
160
+ tok = ctx.tokenizer
161
+ token_text = ctx.token_text
162
+ info = ctx.tok_info
163
+ cache = LegalCache(grammar, token_text, ctx.eos_token_ids) # shared across rollouts
164
+ has_newline = info.has_newline
165
+
166
+ messages = [
167
+ {"role": "system", "content": system_prompt},
168
+ {"role": "user", "content": prompt},
169
+ ]
170
+ prompt_string = tok.apply_chat_template(
171
+ messages, tokenize=False, add_generation_prompt=True
172
+ )
173
+
174
+ def enc_ids(text):
175
+ return tok(text, add_special_tokens=False).input_ids
176
+
177
+ def generate_from(text, max_new_tokens, proc, stop_newline):
178
+ enc = tok(text, return_tensors="pt", add_special_tokens=False).to(model.device)
179
+ stops = None
180
+ if stop_newline:
181
+ stops = StoppingCriteriaList(
182
+ [NewlineStop(enc["input_ids"].shape[1], has_newline)]
183
+ )
184
+ with torch.no_grad():
185
+ out = model.generate(
186
+ **enc,
187
+ max_new_tokens=max_new_tokens,
188
+ do_sample=False,
189
+ logits_processor=LogitsProcessorList([proc]),
190
+ stopping_criteria=stops,
191
+ pad_token_id=ctx.pad_token_id,
192
+ )
193
+ gen_ids = out[0][enc["input_ids"].shape[1]:]
194
+ return tok.decode(gen_ids, skip_special_tokens=True)
195
+
196
+ # Greedy line from prefix_text (acrostic text so far).
197
+ def gen_line(prefix_text, is_last):
198
+ start_state = grammar.advance(grammar.initial, prefix_text)
199
+ ctx_str = prompt_string + prefix_text
200
+ proc = LineMaskScore(
201
+ grammar, start_state, tok, token_text, info, cache,
202
+ force_lower_first=mid_sentence(prefix_text), min_line=min_line,
203
+ )
204
+ text = generate_from(ctx_str, max_line + 8, proc, stop_newline=not is_last)
205
+ if not is_last:
206
+ nl = text.find("\n")
207
+ if nl != -1:
208
+ text = text[: nl + 1]
209
+ base_n = len(enc_ids(ctx_str))
210
+ line_ids = enc_ids(ctx_str + text)[base_n:]
211
+ return text, line_ids, proc.step_logprobs
212
+
213
+ # Roll the NEXT line's opening: forced letter + up to n-1 content tokens.
214
+ def roll_open(prefix_text, n):
215
+ start_state = grammar.advance(grammar.initial, prefix_text)
216
+ if start_state == -1:
217
+ return []
218
+ proc = LineMaskScore(
219
+ grammar, start_state, tok, token_text, info, cache,
220
+ force_lower_first=mid_sentence(prefix_text),
221
+ )
222
+ generate_from(prompt_string + prefix_text, n, proc, stop_newline=True)
223
+ return proc.step_logprobs
224
+
225
+ n_lines = len(secret)
226
+ committed = ""
227
+ per_line = []
228
+
229
+ for i in range(n_lines):
230
+ is_last = i == n_lines - 1
231
+ text, line_ids, logps = gen_line(committed, is_last)
232
+
233
+ # Last line, or no break search: commit the greedy line as-is.
234
+ if is_last or R <= 0:
235
+ committed += text
236
+ per_line.append({"line": i, "chosen": text, "r": 0, "candidates": None})
237
+ if on_line:
238
+ on_line(text, {"line": i})
239
+ continue
240
+
241
+ m = min(len(line_ids), len(logps))
242
+ ids = line_ids[-m:] if m else []
243
+ lps = logps[-m:] if m else []
244
+ has_nl = text.endswith("\n")
245
+ line_start_state = grammar.advance(grammar.initial, committed)
246
+
247
+ candidates = []
248
+ for r in range(0, min(R, m - 1) + 1):
249
+ # r tokens trimmed -> break after (m-r) tokens. Require the first
250
+ # trimmed token to begin a new word/punctuation (clean boundary).
251
+ if r > 0:
252
+ first_trimmed = token_text[ids[m - r]]
253
+ if not first_trimmed or not _WORD_START.match(first_trimmed):
254
+ continue
255
+ kept_ids = ids[: m - r]
256
+ if not kept_ids:
257
+ continue
258
+ prefix_text = re.sub(r"\n+$", "", tok.decode(kept_ids, skip_special_tokens=True))
259
+ broke_line = prefix_text + "\n"
260
+ # The trimmed line must still be grammar-legal (keep the forced letter).
261
+ if grammar.advance(line_start_state, broke_line) == -1:
262
+ continue
263
+ if r > 0 and len(prefix_text) < min_line:
264
+ continue
265
+
266
+ before_lps = lps[: m - r]
267
+ if r == 0 and has_nl:
268
+ before_lps = before_lps[:-1]
269
+ before_lps = before_lps[-k:]
270
+
271
+ after_lps = roll_open(committed + broke_line, 1 + j)
272
+
273
+ candidates.append({
274
+ "r": r,
275
+ "broke_line": broke_line,
276
+ "score": _mean(before_lps + after_lps),
277
+ "n_before": len(before_lps),
278
+ "n_after": len(after_lps),
279
+ "preview": broke_line[-28:],
280
+ })
281
+
282
+ chosen, r = text, 0
283
+ if candidates:
284
+ candidates.sort(key=lambda c: c["score"], reverse=True)
285
+ chosen = candidates[0]["broke_line"]
286
+ r = candidates[0]["r"]
287
+ committed += chosen
288
+ per_line.append({"line": i, "chosen": chosen, "r": r, "candidates": candidates})
289
+ if on_line:
290
+ on_line(chosen, {"line": i})
291
+
292
+ return {"text": committed, "per_line": per_line}
eval_classifier.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """List-vs-prose classifier eval harness (Python port of the dataset + runner
2
+ from src/eval.js).
3
+
4
+ 50 list-style + 50 prose-style hand-picked prompts, split 10+10 validation /
5
+ 40+40 dev. Run this as a script to sweep candidate classifier variants on the
6
+ current model and pick the best one for it:
7
+
8
+ SIDECHAT_MODEL=openbmb/MiniCPM5-1B .venv/bin/python eval_classifier.py
9
+
10
+ It prints a ranking table (dev accuracy, list-recall, prose-recall) and then
11
+ validates the top variants on the held-out set. The winner becomes
12
+ classifier.DEFAULT_VARIANT.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import time
18
+
19
+ from classifier import Variant, classify
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Datasets (ported verbatim from src/eval.js)
23
+ # ---------------------------------------------------------------------------
24
+
25
+ LIST_PROMPTS = [
26
+ # --- validation (first 10) ---
27
+ "list 10 ways to improve morale at work",
28
+ "give me five reasons to learn Rust",
29
+ "what are the main benefits of meditation?",
30
+ "suggest some names for my new puppy",
31
+ "name three famous jazz musicians",
32
+ "list the ingredients for guacamole",
33
+ "what are the steps to change a tire?",
34
+ "give me ideas for weekend activities with kids",
35
+ "tips for packing light when traveling",
36
+ "what are some common Italian desserts?",
37
+ # --- dev (next 40) ---
38
+ "list popular video game consoles from the 1990s",
39
+ "suggest questions to ask at a job interview",
40
+ "what are the symptoms of dehydration?",
41
+ "name ten countries in Africa",
42
+ "list some movies directed by Christopher Nolan",
43
+ "give me seven examples of onomatopoeia",
44
+ "what tools do I need to build a raised garden bed?",
45
+ "suggest some icebreaker activities for a team meeting",
46
+ "ways to reduce food waste at home",
47
+ "list the planets in order from the sun",
48
+ "what are the main differences between Python 2 and Python 3?",
49
+ "give me 5 good podcast recommendations about history",
50
+ "name three types of dance",
51
+ "top tourist attractions in Kyoto",
52
+ "list common symptoms of the flu",
53
+ "what are some healthy snack ideas for kids?",
54
+ "suggest some books similar to The Hobbit",
55
+ "name five spices commonly used in Indian cooking",
56
+ "list programming languages that compile to WebAssembly",
57
+ "give me a list of yoga poses for beginners",
58
+ "what are some good stretches before running?",
59
+ "name the colors of the rainbow",
60
+ "list the months of the year in French",
61
+ "what are common causes of burnout?",
62
+ "suggest some romantic date ideas in New York",
63
+ "give me a bullet list of home safety tips",
64
+ "list the bones in the human hand",
65
+ "ways to learn a new language quickly",
66
+ "name five mammals native to Australia",
67
+ "what are some highlights of the French Revolution?",
68
+ "list common pitfalls of distributed systems",
69
+ "top 10 songs from the 1980s",
70
+ "suggest some hobbies for introverts",
71
+ "name the original members of The Beatles",
72
+ "what are the primary colors?",
73
+ "list reasons to adopt a cat",
74
+ "give me 6 tips for better sleep hygiene",
75
+ "name the Great Lakes",
76
+ "list programming concepts every developer should know",
77
+ "suggest some vegan dinner recipes",
78
+ ]
79
+
80
+ PROSE_PROMPTS = [
81
+ # --- validation (first 10) ---
82
+ "tell me a short story about a lighthouse keeper",
83
+ "write a haiku about autumn",
84
+ "explain how a solar panel works in a paragraph",
85
+ "summarize the plot of Pride and Prejudice",
86
+ 'what does the word "quixotic" mean?',
87
+ 'translate "good morning" to Japanese',
88
+ "write a professional email declining a meeting",
89
+ "describe the taste of a ripe mango",
90
+ "compose a poem about loneliness",
91
+ "what is the capital of Australia?",
92
+ # --- dev (next 40) ---
93
+ "tell me about the invention of the printing press",
94
+ "write a cover letter for a software engineering role",
95
+ "explain the theory of relativity to a 10-year-old",
96
+ "who was Marie Curie?",
97
+ "describe a sunset over the ocean",
98
+ "what is photosynthesis?",
99
+ "write a bedtime story for a 4-year-old",
100
+ "explain how blockchain works",
101
+ "tell me about the history of tea in China",
102
+ "describe the plot of Inception",
103
+ "write a haiku about the sea",
104
+ "what is the meaning of life according to Camus?",
105
+ "tell me a joke about programming",
106
+ "explain why the sky is blue",
107
+ "describe what it feels like to run a marathon",
108
+ "write a love letter in the style of Shakespeare",
109
+ "what year did the Berlin Wall fall?",
110
+ "tell me about the architecture of the Sagrada Familia",
111
+ "write a persuasive essay on renewable energy",
112
+ "describe the personality of a golden retriever",
113
+ "who was the first person on the moon?",
114
+ "tell me about quantum entanglement briefly",
115
+ "write a one-paragraph synopsis of The Great Gatsby",
116
+ 'what is the etymology of the word "sandwich"?',
117
+ "explain why we dream",
118
+ "tell me a myth about the origin of fire",
119
+ "describe the feeling of nostalgia",
120
+ "write a toast for a wedding",
121
+ 'what does "serendipity" mean?',
122
+ "tell me about your favorite season",
123
+ "explain the difference between empathy and sympathy",
124
+ "who wrote Hamlet?",
125
+ "write a limerick about cats",
126
+ "tell me a ghost story",
127
+ "describe Mount Fuji in winter",
128
+ "what happened in the Cuban Missile Crisis?",
129
+ "explain how a car engine works",
130
+ "tell me a folk tale from Ireland",
131
+ "write an essay on the importance of libraries",
132
+ "describe a perfect day",
133
+ ]
134
+
135
+ VALIDATION_LIST = LIST_PROMPTS[:10]
136
+ VALIDATION_PROSE = PROSE_PROMPTS[:10]
137
+ DEV_LIST = LIST_PROMPTS[10:]
138
+ DEV_PROSE = PROSE_PROMPTS[10:]
139
+
140
+
141
+ def make_labelled(list_prompts, prose_prompts):
142
+ return [{"prompt": p, "expected": True} for p in list_prompts] + [
143
+ {"prompt": p, "expected": False} for p in prose_prompts
144
+ ]
145
+
146
+
147
+ def run_variant_on(ctx, variant, labelled, on_progress=None):
148
+ results = []
149
+ for i, item in enumerate(labelled):
150
+ pred, raw = classify(ctx, item["prompt"], variant)
151
+ results.append({**item, "prediction": pred, "raw": raw, "correct": pred == item["expected"]})
152
+ if on_progress:
153
+ on_progress(i + 1, len(labelled))
154
+ correct = sum(1 for r in results if r["correct"])
155
+ list_total = sum(1 for r in results if r["expected"])
156
+ prose_total = len(results) - list_total
157
+ list_hit = sum(1 for r in results if r["expected"] and r["correct"])
158
+ prose_hit = sum(1 for r in results if not r["expected"] and r["correct"])
159
+ return {
160
+ "variant": variant.name,
161
+ "accuracy": correct / len(results),
162
+ "correct": correct,
163
+ "total": len(results),
164
+ "list_recall": (list_hit, list_total),
165
+ "prose_recall": (prose_hit, prose_total),
166
+ "results": results,
167
+ }
168
+
169
+
170
+ def sweep(ctx, variants, labelled, label=""):
171
+ summaries = []
172
+ for v in variants:
173
+ t0 = time.time()
174
+ res = run_variant_on(ctx, v, labelled)
175
+ res["wall_s"] = time.time() - t0
176
+ lh, lt = res["list_recall"]
177
+ ph, pt = res["prose_recall"]
178
+ print(
179
+ f" [{label}] {v.name:30} {res['correct']:>2}/{res['total']} "
180
+ f"= {res['accuracy']*100:5.1f}% list {lh}/{lt} prose {ph}/{pt} "
181
+ f"({res['wall_s']:.0f}s)",
182
+ flush=True,
183
+ )
184
+ summaries.append(res)
185
+ return summaries
grammar.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tiny grammar engine for acrostic-style constraints.
2
+
3
+ A faithful Python port of src/grammar.js.
4
+
5
+ Primitives:
6
+ - Atoms: {"kind": "lit", "allowed": set[str]} (consumes exactly one char from
7
+ `allowed`) or {"kind": "body", "max": int} (consumes 0..max non-newline
8
+ chars).
9
+ - Atom sequences are concatenation-only; with the body/newline structure we
10
+ use, transitions are deterministic, so state packs into one int:
11
+ atom_idx * stride + count.
12
+
13
+ Builders:
14
+ - compile_acrostic(secret, ...) — list-mode or prose-mode acrostic.
15
+ - compile_literal(text) — exact-text matcher (used by the classifier).
16
+ - union_grammars([g1, g2, ...]) — accept if any branch is alive.
17
+
18
+ The dead-state sentinel is -1 everywhere, matching the JS original.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ # Spaces in the secret are treated as "word breaks" — they don't pin the line to
24
+ # any particular letter, but they still produce a line, and the line must start
25
+ # with a punctuation character so the acrostic reads naturally
26
+ # ("HI WORLD" -> H... / I... / <punct>... / W... / O... / R... / L... / D...).
27
+ PUNCT_FOR_SPACE = set(
28
+ list(".,;:!?-")
29
+ + list("()[]{}")
30
+ + list("~<>")
31
+ + ['"', "'", "`"]
32
+ + list("@#$%&+=/\\|_^")
33
+ )
34
+
35
+
36
+ class AtomGrammar:
37
+ """A single concatenation-only atom-sequence grammar (an NFA packed into ints)."""
38
+
39
+ def __init__(self, atoms):
40
+ self.atoms = atoms
41
+ max_body_max = 0
42
+ for a in atoms:
43
+ if a["kind"] == "body" and a["max"] > max_body_max:
44
+ max_body_max = a["max"]
45
+ self.stride = max_body_max + 2
46
+ self.PAST_END = len(atoms) * self.stride
47
+ self.state_count = self.PAST_END + 1
48
+
49
+ # Precompute accepting states: (a, c) accepts iff atom `a` can be
50
+ # epsilon-skipped at count `c` AND (a+1, 0) is accepting.
51
+ accepting = bytearray(self.state_count)
52
+ accepting[self.PAST_END] = 1
53
+ next_accepting = True
54
+ for a in range(len(atoms) - 1, -1, -1):
55
+ atom = atoms[a]
56
+ if next_accepting:
57
+ mn = 1 if atom["kind"] == "lit" else 0
58
+ mx = 1 if atom["kind"] == "lit" else atom["max"]
59
+ for c in range(mn, mx + 1):
60
+ accepting[a * self.stride + c] = 1
61
+ next_accepting = accepting[a * self.stride + 0] == 1
62
+ self.accepting = accepting
63
+
64
+ self.initial = 0
65
+
66
+ def _consume_at(self, a, ch):
67
+ atoms = self.atoms
68
+ while a < len(atoms):
69
+ atom = atoms[a]
70
+ if atom["kind"] == "lit":
71
+ if ch in atom["allowed"]:
72
+ return self.PAST_END if a + 1 >= len(atoms) else (a + 1) * self.stride
73
+ return -1
74
+ if ch != "\n":
75
+ return a * self.stride + 1
76
+ a += 1
77
+ return -1
78
+
79
+ def advance(self, state, s):
80
+ stride = self.stride
81
+ atoms = self.atoms
82
+ cur = state
83
+ for ch in s:
84
+ if cur == self.PAST_END:
85
+ return -1
86
+ a = cur // stride
87
+ c = cur - a * stride
88
+ atom = atoms[a]
89
+ if atom["kind"] == "lit":
90
+ if c < 1 and ch in atom["allowed"]:
91
+ nxt = self.PAST_END if a + 1 >= len(atoms) else (a + 1) * stride
92
+ else:
93
+ return -1
94
+ else:
95
+ if c < atom["max"] and ch != "\n":
96
+ nxt = a * stride + (c + 1)
97
+ else:
98
+ nxt = self._consume_at(a + 1, ch)
99
+ if nxt == -1:
100
+ return -1
101
+ cur = nxt
102
+ return cur
103
+
104
+ def accepts(self, state):
105
+ return state is not None and 0 <= state < self.state_count and self.accepting[state] == 1
106
+
107
+
108
+ def compile_acrostic(secret, list_prefix=" * ", max_line=80, case_insensitive=False, first_line_prefix=True):
109
+ if not secret:
110
+ raise ValueError("secret must be non-empty")
111
+ atoms = []
112
+ for i, letter in enumerate(secret):
113
+ want_prefix = i > 0 or first_line_prefix
114
+ if want_prefix:
115
+ for c in list_prefix:
116
+ atoms.append({"kind": "lit", "allowed": {c}})
117
+ if letter == " ":
118
+ allowed = set(PUNCT_FOR_SPACE)
119
+ elif case_insensitive:
120
+ allowed = {letter.upper(), letter.lower()}
121
+ else:
122
+ allowed = {letter}
123
+ atoms.append({"kind": "lit", "allowed": allowed})
124
+ atoms.append({"kind": "body", "max": max_line})
125
+ if i < len(secret) - 1:
126
+ atoms.append({"kind": "lit", "allowed": {"\n"}})
127
+ return AtomGrammar(atoms)
128
+
129
+
130
+ def compile_literal(text):
131
+ if not text:
132
+ raise ValueError("literal must be non-empty")
133
+ atoms = [{"kind": "lit", "allowed": {c}} for c in text]
134
+ return AtomGrammar(atoms)
135
+
136
+
137
+ class UnionGrammar:
138
+ """Run several grammars in parallel; a token is alive iff at least one branch
139
+ is alive. State is a list of per-branch ints (-1 = dead branch). When every
140
+ branch is dead, advance returns -1 (the single-grammar dead sentinel)."""
141
+
142
+ def __init__(self, grammars):
143
+ self.grammars = grammars
144
+ self.initial = [g.initial for g in grammars]
145
+
146
+ def advance(self, state, s):
147
+ nxt = [-1] * len(self.grammars)
148
+ any_live = False
149
+ for i, g in enumerate(self.grammars):
150
+ if state[i] == -1:
151
+ nxt[i] = -1
152
+ continue
153
+ r = g.advance(state[i], s)
154
+ nxt[i] = r
155
+ if r != -1:
156
+ any_live = True
157
+ return nxt if any_live else -1
158
+
159
+ def accepts(self, state):
160
+ if state == -1:
161
+ return False
162
+ for i, g in enumerate(self.grammars):
163
+ if state[i] != -1 and g.accepts(state[i]):
164
+ return True
165
+ return False
166
+
167
+
168
+ def union_grammars(grammars):
169
+ return UnionGrammar(grammars)
logits.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grammar-constrained LogitsProcessor (Python port of src/logits.js).
2
+
3
+ At each generation step:
4
+ 1. Decode the generated suffix back to text.
5
+ 2. Advance the grammar NFA by that text.
6
+ 3. For every candidate token id, check whether appending its decoded text
7
+ keeps the NFA alive; mask losers to -inf (via the shared LegalCache).
8
+ 4. EOS is allowed only once the NFA has reached an accept state.
9
+
10
+ Per-token decoding can disagree with BPE sequence-decoding in edge cases
11
+ (merged punctuation, etc.); for the acrostic patterns we care about this
12
+ approximation is fine.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import time
18
+
19
+ from transformers import LogitsProcessor
20
+
21
+ from masking import LegalCache
22
+
23
+
24
+ def build_token_text_table(tokenizer, vocab_size):
25
+ """One-shot build of tokenId -> text, using per-token decode. Special tokens
26
+ decode to '' under skip_special_tokens=True, which we treat as
27
+ "disallowed" (empty string)."""
28
+ texts = tokenizer.batch_decode(
29
+ [[i] for i in range(vocab_size)], skip_special_tokens=True
30
+ )
31
+ return [t if isinstance(t, str) else "" for t in texts]
32
+
33
+
34
+ class GrammarLogitsProcessor(LogitsProcessor):
35
+ def __init__(self, grammar, tokenizer, token_text, eos_token_ids=(), legal_cache=None):
36
+ super().__init__()
37
+ self.grammar = grammar
38
+ self.tokenizer = tokenizer
39
+ self.token_text = token_text
40
+ self.cache = legal_cache or LegalCache(grammar, token_text, eos_token_ids)
41
+ self.prompt_length = None
42
+ self.stats = _fresh_stats()
43
+
44
+ def reset(self):
45
+ self.prompt_length = None
46
+ self.stats = _fresh_stats()
47
+
48
+ def __call__(self, input_ids, scores):
49
+ t_entry = time.perf_counter()
50
+ ids = input_ids[0]
51
+ if self.prompt_length is None:
52
+ self.prompt_length = ids.shape[0]
53
+
54
+ generated = ids[self.prompt_length:].tolist()
55
+ text = (
56
+ self.tokenizer.decode(generated, skip_special_tokens=True)
57
+ if generated
58
+ else ""
59
+ )
60
+
61
+ state = self.grammar.advance(self.grammar.initial, text)
62
+ data = scores[0]
63
+
64
+ if state == -1:
65
+ # Already violated; nothing useful to do without rewinding. Let the
66
+ # original logits through so generation at least terminates.
67
+ self._record(time.perf_counter() - t_entry, -1)
68
+ return scores
69
+
70
+ illegal = self.cache.illegal_tensor(state)
71
+ data[illegal.to(data.device)] = float("-inf")
72
+
73
+ self._record(time.perf_counter() - t_entry, int((~illegal).sum().item()))
74
+ return scores
75
+
76
+ def _record(self, dt, survivors):
77
+ st = self.stats
78
+ st["calls"] += 1
79
+ st["total_ms"] += dt * 1000.0
80
+ st["per_step"].append({"ms": dt * 1000.0, "survivors": survivors})
81
+
82
+
83
+ def _fresh_stats():
84
+ return {"calls": 0, "total_ms": 0.0, "per_step": []}
masking.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared grammar-legality computation + a per-state cache.
2
+
3
+ The set of grammar-legal next tokens is a pure function of the grammar state, so
4
+ we cache the boolean legal mask by state. This is what makes the crossing search
5
+ affordable: its many short rollouts all start from the same handful of
6
+ line-start states and reuse one (expensive) full-vocab scan.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ class LegalCache:
16
+ def __init__(self, grammar, token_text, eos_token_ids=()):
17
+ self.grammar = grammar
18
+ self.token_text = token_text
19
+ self.eos_token_ids = [int(x) for x in eos_token_ids]
20
+ # Special tokens decode to '' and are always illegal — never probe them.
21
+ self._scan_ids = [i for i, t in enumerate(token_text) if t]
22
+ self._legal_cache = {} # state-key -> np.bool_ array
23
+ self._illegal_cache = {} # state-key -> torch.BoolTensor
24
+
25
+ @staticmethod
26
+ def _key(state):
27
+ return state if isinstance(state, int) else tuple(state)
28
+
29
+ def legal_np(self, state):
30
+ key = self._key(state)
31
+ cached = self._legal_cache.get(key)
32
+ if cached is not None:
33
+ return cached
34
+ advance = self.grammar.advance
35
+ token_text = self.token_text
36
+ at_accept = self.grammar.accepts(state)
37
+ legal = np.zeros(len(token_text), dtype=bool)
38
+ for i in self._scan_ids:
39
+ if advance(state, token_text[i]) != -1:
40
+ legal[i] = True
41
+ for eid in self.eos_token_ids:
42
+ legal[eid] = at_accept
43
+ self._legal_cache[key] = legal
44
+ return legal
45
+
46
+ def illegal_tensor(self, state):
47
+ key = self._key(state)
48
+ cached = self._illegal_cache.get(key)
49
+ if cached is not None:
50
+ return cached
51
+ t = torch.from_numpy(~self.legal_np(state))
52
+ self._illegal_cache[key] = t
53
+ return t
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers>=5.12
3
+ accelerate
4
+ numpy
5
+ gradio>=6
sweep_minicpm.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sweep candidate list/prose classifier variants on the current model, pick a
2
+ winner. CPU-conscious: screen all candidates on a 20-prompt subset, then run the
3
+ top few on the full 80-prompt dev set + 20-prompt validation set.
4
+
5
+ SIDECHAT_MODEL=openbmb/MiniCPM5-1B .venv/bin/python sweep_minicpm.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import time
11
+
12
+ import app # loads the model + CTX
13
+ from classifier import Variant
14
+ from eval_classifier import (
15
+ DEV_LIST, DEV_PROSE, VALIDATION_LIST, VALIDATION_PROSE, make_labelled, sweep,
16
+ )
17
+
18
+ CTX = app.CTX
19
+ parse_list = lambda s: s.startswith("list")
20
+ parse_items = lambda s: s.startswith("items")
21
+
22
+ # Candidate variants spanning the axes that mattered in the JS sweep: intent vs.
23
+ # request framing, default-to-list vs. default-to-story polarity, list/story
24
+ # branch vocab, and the single-vs-plural rule. (true/false and CAPS branches are
25
+ # known-bad and omitted.)
26
+ C1_BASE = (
27
+ "Classify the user's request. Use \"list\" when the user wants enumerated "
28
+ "items. Use \"story\" for everything else."
29
+ )
30
+ SINGLE_PLURAL = (
31
+ " \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural "
32
+ "enumeration) is a list; \"what are the steps/differences/causes/symptoms\" "
33
+ "is a list."
34
+ )
35
+ WRITE_FORMS = (
36
+ " Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, "
37
+ "cover letter, email, joke, story, essay, or limerick, the answer is a story."
38
+ )
39
+ EXTENDED_TRIGGERS = (
40
+ "Classify the user's request. Default to \"list\". Use \"story\" only when the "
41
+ "user asks for narrative/prose: \"tell me a story\", \"write a poem/haiku/"
42
+ "limerick/email/essay/letter\", \"describe\", \"explain\", \"translate\", "
43
+ "\"summarize\", \"what does X mean\", \"who was/is\", \"what is X\", \"when "
44
+ "did\", \"why does\", \"how does (concept)\", \"compose\"."
45
+ )
46
+ STORY_DEFAULT = (
47
+ "Classify the user's request. Default to \"story\". Use \"list\" only when the "
48
+ "user clearly asks for multiple discrete items: \"list N\", \"name N\", "
49
+ "\"give N\", \"top N\", \"suggest some\", \"ways to\", \"tips\", \"steps\", "
50
+ "\"reasons\", \"examples of\"."
51
+ )
52
+
53
+ INTENT_BASE = (
54
+ "Classify the user's intent. Use \"list\" when the answer is a set of "
55
+ "separate items the user can scan. Use \"story\" when the answer flows as "
56
+ "one narrative, single fact, or short paragraph."
57
+ )
58
+ DEFAULT_LIST_BASE = (
59
+ "Classify the user's request. Default to \"list\". Use \"story\" only when "
60
+ "the user clearly asks for narrative: \"tell me a story\", \"write a "
61
+ "poem/haiku/email\", \"describe X\", \"explain X\", \"translate X\", "
62
+ "\"what does X mean\", \"who was/what is/when did\"."
63
+ )
64
+
65
+ # Baseline reveals MiniCPM's split failure: intent/list-default bases nail list
66
+ # recall but miss prose "write a haiku" (write-forms) and "capital of Australia"
67
+ # (single fact); the c1 base is the opposite. So pair list-favoring bases with
68
+ # the WRITE_FORMS + SINGLE_PLURAL rules that target exactly those prose misses.
69
+ CANDIDATES = [
70
+ # Baselines (carried from the LFM2 sweep).
71
+ Variant("c1_single_plural", C1_BASE + SINGLE_PLURAL, "The user is asking for a ", ["list.", "story."], parse_list),
72
+ Variant("intent_two_rules", INTENT_BASE, "The intent is to get a ", ["list.", "story."], parse_list),
73
+ # Intent base + targeted prose rules.
74
+ Variant("intent_write", INTENT_BASE + WRITE_FORMS, "The intent is to get a ", ["list.", "story."], parse_list),
75
+ Variant("intent_sp", INTENT_BASE + SINGLE_PLURAL, "The intent is to get a ", ["list.", "story."], parse_list),
76
+ Variant("intent_write_sp", INTENT_BASE + WRITE_FORMS + SINGLE_PLURAL, "The intent is to get a ", ["list.", "story."], parse_list),
77
+ # Default-to-list base + targeted prose rules.
78
+ Variant("default_list", DEFAULT_LIST_BASE, "The user wants the answer as a ", ["list.", "story."], parse_list),
79
+ Variant("default_list_write_sp", DEFAULT_LIST_BASE + WRITE_FORMS + SINGLE_PLURAL, "The user wants the answer as a ", ["list.", "story."], parse_list),
80
+ # c1 base + write-forms (complementary to single_plural).
81
+ Variant("c1_write_sp", C1_BASE + WRITE_FORMS + SINGLE_PLURAL, "The user is asking for a ", ["list.", "story."], parse_list),
82
+ # Long built-in trigger list (no separate rules).
83
+ Variant("extended_triggers", EXTENDED_TRIGGERS, "The user wants the answer as a ", ["list.", "story."], parse_list),
84
+ # Alternate branch vocab.
85
+ Variant(
86
+ "items_text",
87
+ "Classify the user's intent. Use \"items\" when the user wants enumerated "
88
+ "items. Use \"text\" for everything else (narrative, single answer, "
89
+ "explanation, translation, story, poem).",
90
+ "The intent is to get ", ["items.", "text."], parse_items,
91
+ ),
92
+ ]
93
+
94
+
95
+ def main():
96
+ print(f"model: {app.MODEL_ID} · {len(CTX.token_text)} tokens", flush=True)
97
+ # Fast screen on a 20-prompt subset (first 10 of each dev class).
98
+ screen = make_labelled(DEV_LIST[:10], DEV_PROSE[:10])
99
+ print(f"\n=== SCREEN ({len(screen)} prompts) ===", flush=True)
100
+ t0 = time.time()
101
+ screen_res = sweep(CTX, CANDIDATES, screen, label="screen")
102
+ screen_res.sort(key=lambda r: r["accuracy"], reverse=True)
103
+ print(f"screen done in {(time.time()-t0)/60:.1f} min", flush=True)
104
+
105
+ top = [next(c for c in CANDIDATES if c.name == r["variant"]) for r in screen_res[:3]]
106
+ print(f"\ntop 3 on screen: {[c.name for c in top]}", flush=True)
107
+
108
+ full = make_labelled(DEV_LIST, DEV_PROSE) + make_labelled(VALIDATION_LIST, VALIDATION_PROSE)
109
+ print(f"\n=== FULL ({len(full)} prompts: 50 list + 50 prose) ===", flush=True)
110
+ full_res = sweep(CTX, top, full, label="full")
111
+ full_res.sort(key=lambda r: r["accuracy"], reverse=True)
112
+
113
+ print("\n=== RANKING (full) ===", flush=True)
114
+ for r in full_res:
115
+ lh, lt = r["list_recall"]; ph, pt = r["prose_recall"]
116
+ print(f" {r['variant']:30} {r['accuracy']*100:5.1f}% list {lh}/{lt} prose {ph}/{pt}", flush=True)
117
+ print(f"\nWINNER: {full_res[0]['variant']} @ {full_res[0]['accuracy']*100:.1f}%", flush=True)
118
+
119
+
120
+ if __name__ == "__main__":
121
+ main()
tokinfo.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Precomputed per-token boolean arrays used by the crossing search's stealth
2
+ casing and minimum-line-length masking. Built once at startup."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from dataclasses import dataclass
7
+
8
+ import numpy as np
9
+
10
+
11
+ def _is_ascii_alpha(ch):
12
+ return ("a" <= ch <= "z") or ("A" <= ch <= "Z")
13
+
14
+
15
+ @dataclass
16
+ class TokInfo:
17
+ has_newline: np.ndarray # token text contains '\n'
18
+ alpha_lower: np.ndarray # first ASCII letter is lowercase
19
+ alpha_upper: np.ndarray # first ASCII letter is uppercase
20
+ nonempty: np.ndarray # token decodes to a non-empty string
21
+ eos_mask: np.ndarray # token is an EOS id
22
+
23
+
24
+ def build_tok_info(token_text, eos_token_ids):
25
+ n = len(token_text)
26
+ has_newline = np.zeros(n, dtype=bool)
27
+ alpha_lower = np.zeros(n, dtype=bool)
28
+ alpha_upper = np.zeros(n, dtype=bool)
29
+ nonempty = np.zeros(n, dtype=bool)
30
+ for i, t in enumerate(token_text):
31
+ if not t:
32
+ continue
33
+ nonempty[i] = True
34
+ if "\n" in t:
35
+ has_newline[i] = True
36
+ for ch in t:
37
+ if _is_ascii_alpha(ch):
38
+ if ch.islower():
39
+ alpha_lower[i] = True
40
+ else:
41
+ alpha_upper[i] = True
42
+ break
43
+ eos_mask = np.zeros(n, dtype=bool)
44
+ for e in eos_token_ids:
45
+ eos_mask[int(e)] = True
46
+ return TokInfo(has_newline, alpha_lower, alpha_upper, nonempty, eos_mask)