Spaces:
Sleeping
Sleeping
Port steganacrostics to a Gradio app; retarget to MiniCPM5-1B
#1
by lsb - opened
- .gitignore +16 -0
- README.md +50 -5
- app.py +334 -54
- classifier.py +124 -0
- crossing_search.py +292 -0
- eval_classifier.py +185 -0
- grammar.py +169 -0
- logits.py +84 -0
- masking.py +53 -0
- requirements.txt +5 -0
- sweep_minicpm.py +121 -0
- tokinfo.py +46 -0
.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python venv (huge: torch, transformers, …) — never commit
|
| 2 |
+
.venv/
|
| 3 |
+
venv/
|
| 4 |
+
|
| 5 |
+
# Byte-compiled / caches
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[cod]
|
| 8 |
+
*.egg-info/
|
| 9 |
+
.ipynb_checkpoints/
|
| 10 |
+
|
| 11 |
+
# Local model / HF caches (models are downloaded at runtime)
|
| 12 |
+
.cache/
|
| 13 |
+
hf_cache/
|
| 14 |
+
|
| 15 |
+
# OS cruft
|
| 16 |
+
.DS_Store
|
README.md
CHANGED
|
@@ -4,14 +4,59 @@ emoji: 💬
|
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
hf_oauth: true
|
| 11 |
-
hf_oauth_scopes:
|
| 12 |
-
- inference-api
|
| 13 |
license: apache-2.0
|
| 14 |
short_description: Completely normal text assistant, with talking on the side
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
---
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
colorFrom: yellow
|
| 5 |
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.18.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
| 10 |
license: apache-2.0
|
| 11 |
short_description: Completely normal text assistant, with talking on the side
|
| 12 |
+
tags:
|
| 13 |
+
- track:wood
|
| 14 |
+
- sponsor:openbmb
|
| 15 |
+
- achievement:offgrid
|
| 16 |
+
- achievement:sharing
|
| 17 |
+
- achievement:fieldnotes
|
| 18 |
---
|
| 19 |
|
| 20 |
+
# Side chat
|
| 21 |
+
|
| 22 |
+
A Gradio port of the browser steganacrostics app. A completely normal text
|
| 23 |
+
assistant — except every line of the answer secretly starts with the next
|
| 24 |
+
letter of a hidden **secret** word (an acrostic). It does this with
|
| 25 |
+
grammar-constrained decoding over a small local model (`openbmb/MiniCPM5-1B` by
|
| 26 |
+
default; set `SIDECHAT_MODEL=LiquidAI/LFM2.5-350M` for the smaller, faster
|
| 27 |
+
original), running on **CPU** via PyTorch `transformers`.
|
| 28 |
+
|
| 29 |
+
What's ported from the JavaScript original (`../../src/`):
|
| 30 |
+
|
| 31 |
+
- **Grammar engine** (`grammar.py`) — a tiny NFA that pins each line to its
|
| 32 |
+
forced first letter, with optional ` * ` bullets and a max line length.
|
| 33 |
+
- **Constrained generation** (`logits.py` + `masking.py`) — a `LogitsProcessor`
|
| 34 |
+
that masks every token that would break the acrostic; EOS only at an accept
|
| 35 |
+
state. A state-keyed cache makes the per-step vocab scan cheap.
|
| 36 |
+
- **List-vs-prose classifier** (`classifier.py`) — an optimized prompt,
|
| 37 |
+
grammar-constrained to `list.` / `story.`, that auto-picks the render mode.
|
| 38 |
+
The prompt is tuned per model: failure modes are model-specific, so
|
| 39 |
+
`eval_classifier.py` (50 list + 50 prose prompts) and `sweep_minicpm.py`
|
| 40 |
+
re-optimize it for whatever model is in use.
|
| 41 |
+
- **Local-crossing search** (`crossing_search.py`) — the "extra attention at
|
| 42 |
+
the constraint": generate each prose line greedily, then choose where to
|
| 43 |
+
break it so a short window straddling the crossing (last *k* tokens + forced
|
| 44 |
+
letter + next *j* tokens) reads best. Plus stealth lowercase casing and a
|
| 45 |
+
minimum line length.
|
| 46 |
+
|
| 47 |
+
Run locally:
|
| 48 |
+
|
| 49 |
+
```
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
python app.py
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Then open the printed URL, type a prompt, set a secret in ⚙️ Settings, and click
|
| 55 |
+
Generate. The list-vs-prose classifier runs automatically on each Generate (turn
|
| 56 |
+
it off in ⚙️ Settings to set the render mode by hand, or use 🔎 Detect to preview
|
| 57 |
+
it). Because everything runs on CPU, generation takes seconds (more for the
|
| 58 |
+
larger model); the crossing search trades extra time for smoother prose.
|
| 59 |
+
|
| 60 |
+
The model is downloaded from the Hugging Face Hub on first run. Custom logits
|
| 61 |
+
processing requires the model to run in-process, so this app does not use the
|
| 62 |
+
remote Inference API.
|
app.py
CHANGED
|
@@ -1,69 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from huggingface_hub import InferenceClient
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
max_tokens,
|
| 10 |
-
temperature,
|
| 11 |
-
top_p,
|
| 12 |
-
hf_token: gr.OAuthToken,
|
| 13 |
-
):
|
| 14 |
-
"""
|
| 15 |
-
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
| 16 |
-
"""
|
| 17 |
-
client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
messages.extend(history)
|
| 22 |
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
stream=True,
|
| 31 |
-
temperature=temperature,
|
| 32 |
-
top_p=top_p,
|
| 33 |
-
):
|
| 34 |
-
choices = message.choices
|
| 35 |
-
token = ""
|
| 36 |
-
if len(choices) and choices[0].delta.content:
|
| 37 |
-
token = choices[0].delta.content
|
| 38 |
|
| 39 |
-
|
| 40 |
-
yield response
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
""
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
if __name__ == "__main__":
|
| 69 |
-
demo.launch()
|
|
|
|
| 1 |
+
"""Side chat — a Gradio port of the browser steganacrostics app.
|
| 2 |
+
|
| 3 |
+
Completely normal text assistant, with a secret talking on the side: every line
|
| 4 |
+
of the answer starts with successive letters of a hidden "secret" word (an
|
| 5 |
+
acrostic), produced by grammar-constrained decoding. A list-vs-prose classifier
|
| 6 |
+
auto-picks the render mode, and an optional local-crossing search spends extra
|
| 7 |
+
attention at each constraint cliff so the forced letters read as the natural
|
| 8 |
+
next word.
|
| 9 |
+
|
| 10 |
+
Runs the model locally on CPU with PyTorch transformers (the remote Inference
|
| 11 |
+
API can't do custom logits processing, which is the whole point here).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
import threading
|
| 19 |
+
import queue
|
| 20 |
+
import time
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import (
|
| 24 |
+
AutoModelForCausalLM,
|
| 25 |
+
AutoTokenizer,
|
| 26 |
+
LogitsProcessorList,
|
| 27 |
+
TextIteratorStreamer,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
import gradio as gr
|
|
|
|
| 31 |
|
| 32 |
+
from grammar import compile_acrostic, union_grammars
|
| 33 |
+
from logits import GrammarLogitsProcessor, build_token_text_table
|
| 34 |
+
from tokinfo import build_tok_info
|
| 35 |
+
from classifier import classify, DEFAULT_VARIANT
|
| 36 |
+
from crossing_search import generate_crossing_search
|
| 37 |
|
| 38 |
+
# Default to MiniCPM5-1B (OpenBMB); override with SIDECHAT_MODEL, e.g.
|
| 39 |
+
# SIDECHAT_MODEL=LiquidAI/LFM2.5-350M for the smaller, faster original.
|
| 40 |
+
MODEL_ID = os.environ.get("SIDECHAT_MODEL", "openbmb/MiniCPM5-1B")
|
| 41 |
+
DEVICE = "cpu" # pure CPU by request
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
LIST_SYSTEM = (
|
| 44 |
+
"You are a helpful assistant. Answer as a plain bulleted list — one short "
|
| 45 |
+
"item per line. Do not use markdown, bold text, headings, code, or numbered "
|
| 46 |
+
"lists."
|
| 47 |
+
)
|
| 48 |
+
PROSE_SYSTEM = (
|
| 49 |
+
"You are a helpful assistant. Answer in plain prose. Do not use markdown, "
|
| 50 |
+
"bold text, headings, code, or bulleted/numbered lists."
|
| 51 |
+
)
|
| 52 |
|
|
|
|
| 53 |
|
| 54 |
+
class Context:
|
| 55 |
+
"""Everything the generation + classifier code needs, built once at startup."""
|
| 56 |
|
| 57 |
+
def __init__(self):
|
| 58 |
+
print(f"loading {MODEL_ID} on {DEVICE}…", flush=True)
|
| 59 |
+
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 60 |
+
self.model = (
|
| 61 |
+
AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32)
|
| 62 |
+
.to(DEVICE)
|
| 63 |
+
.eval()
|
| 64 |
+
)
|
| 65 |
+
self.model.device # noqa: B018 (touch to confirm)
|
| 66 |
+
vocab = self.model.config.vocab_size
|
| 67 |
|
| 68 |
+
t0 = time.perf_counter()
|
| 69 |
+
self.token_text = build_token_text_table(self.tokenizer, vocab)
|
| 70 |
+
print(f"token table built in {time.perf_counter() - t0:.1f}s ({vocab} tokens)", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
+
eos = set()
|
|
|
|
| 73 |
|
| 74 |
+
def add_eos(x):
|
| 75 |
+
if x is None:
|
| 76 |
+
return
|
| 77 |
+
if isinstance(x, (list, tuple)):
|
| 78 |
+
for y in x:
|
| 79 |
+
add_eos(y)
|
| 80 |
+
else:
|
| 81 |
+
eos.add(int(x))
|
| 82 |
|
| 83 |
+
add_eos(self.tokenizer.eos_token_id)
|
| 84 |
+
add_eos(getattr(self.model.config, "eos_token_id", None))
|
| 85 |
+
add_eos(getattr(self.model.generation_config, "eos_token_id", None))
|
| 86 |
+
self.eos_token_ids = sorted(eos)
|
| 87 |
+
|
| 88 |
+
pad = self.tokenizer.pad_token_id
|
| 89 |
+
if pad is None:
|
| 90 |
+
pad = getattr(self.model.generation_config, "pad_token_id", None)
|
| 91 |
+
if pad is None:
|
| 92 |
+
pad = self.eos_token_ids[0]
|
| 93 |
+
self.pad_token_id = int(pad)
|
| 94 |
+
|
| 95 |
+
self.tok_info = build_tok_info(self.token_text, self.eos_token_ids)
|
| 96 |
+
print("context ready.", flush=True)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
CTX = Context()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Case-insensitive acrostic; in list mode the very first ` * ` prefix is optional
|
| 103 |
+
# (some models start with a preamble-free letter, others don't).
|
| 104 |
+
def build_grammar(secret, list_mode, max_line):
|
| 105 |
+
if not list_mode:
|
| 106 |
+
return compile_acrostic(secret, list_prefix="", max_line=max_line, case_insensitive=True)
|
| 107 |
+
with_prefix = compile_acrostic(
|
| 108 |
+
secret, list_prefix=" * ", max_line=max_line, case_insensitive=True, first_line_prefix=True
|
| 109 |
+
)
|
| 110 |
+
without_prefix = compile_acrostic(
|
| 111 |
+
secret, list_prefix=" * ", max_line=max_line, case_insensitive=True, first_line_prefix=False
|
| 112 |
+
)
|
| 113 |
+
return union_grammars([with_prefix, without_prefix])
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def check_acrostic(output, secret):
|
| 117 |
+
"""Find a window of line-initial letters matching the secret (case-insensitive,
|
| 118 |
+
bullets stripped). Returns (ok, firsts)."""
|
| 119 |
+
lines = [l.strip() for l in output.split("\n")]
|
| 120 |
+
lines = [l for l in lines if l]
|
| 121 |
+
def strip(l):
|
| 122 |
+
return re.sub(r"^\*?\s*", "", l)
|
| 123 |
+
firsts = [(strip(l)[:1] or "") for l in lines]
|
| 124 |
+
n = len(secret)
|
| 125 |
+
for i in range(0, len(firsts) - n + 1):
|
| 126 |
+
if all(firsts[i + j].lower() == secret[j].lower() for j in range(n)):
|
| 127 |
+
return True, "".join(firsts[i:i + n])
|
| 128 |
+
return False, "".join(firsts)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# --- Classifier: drive the list/prose checkbox ------------------------------
|
| 132 |
+
def classify_fn(prompt):
|
| 133 |
+
if not (prompt or "").strip():
|
| 134 |
+
return gr.update(), "enter a prompt to detect list vs. prose"
|
| 135 |
+
pred, raw = classify(CTX, prompt, DEFAULT_VARIANT)
|
| 136 |
+
label = "list" if pred else "prose"
|
| 137 |
+
return pred, f"detected **{label}** (classifier raw: {raw!r})"
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def maybe_detect(prompt, list_mode, auto_detect):
|
| 141 |
+
"""Runs before Generate: when auto-detect is on, classify the prompt and set
|
| 142 |
+
the list/prose checkbox from it. Otherwise leave the manual choice alone."""
|
| 143 |
+
if auto_detect and (prompt or "").strip():
|
| 144 |
+
pred, raw = classify(CTX, prompt, DEFAULT_VARIANT)
|
| 145 |
+
return pred, f"detected **{'list' if pred else 'prose'}** (raw {raw!r}) — generating…"
|
| 146 |
+
return list_mode, gr.update()
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# --- Generation -------------------------------------------------------------
|
| 150 |
+
def _run_in_thread(target):
|
| 151 |
+
"""Run target() in a daemon thread; return a queue it pushes to. target
|
| 152 |
+
receives the queue and must push a None sentinel when finished."""
|
| 153 |
+
q = queue.Queue()
|
| 154 |
+
threading.Thread(target=target, args=(q,), daemon=True).start()
|
| 155 |
+
return q
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def generate_fn(prompt, secret, list_mode, max_line, crossing, k, j, R, min_line):
|
| 159 |
+
# Strip spaces: a multi-word secret spells its letters across lines; spaces
|
| 160 |
+
# would force odd punctuation-prefixed "word-break" lines. The field still
|
| 161 |
+
# shows the spaced version; the acrostic uses only the letters.
|
| 162 |
+
secret = re.sub(r"\s+", "", (secret or "").strip())
|
| 163 |
+
if not secret:
|
| 164 |
+
yield "(secret is empty — open ⚙️ Settings and set one)", "", ""
|
| 165 |
+
return
|
| 166 |
+
|
| 167 |
+
list_mode = bool(list_mode)
|
| 168 |
+
max_line = max(1, int(max_line or 80))
|
| 169 |
+
system_prompt = LIST_SYSTEM if list_mode else PROSE_SYSTEM
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
grammar = build_grammar(secret, list_mode, max_line)
|
| 173 |
+
except Exception as e: # noqa: BLE001
|
| 174 |
+
yield f"grammar build error: {e}", "", ""
|
| 175 |
+
return
|
| 176 |
+
|
| 177 |
+
# --- Local-crossing search (prose only) ---------------------------------
|
| 178 |
+
if crossing and not list_mode:
|
| 179 |
+
k = max(0, int(k or 4))
|
| 180 |
+
j = max(0, int(j or 3))
|
| 181 |
+
R = max(0, int(R or 4))
|
| 182 |
+
min_line = min(max_line, max(0, int(min_line or 30)))
|
| 183 |
+
status = f"generating (local-crossing search · k={k}, j={j}, R={R}, minLine={min_line})…"
|
| 184 |
+
committed = [""]
|
| 185 |
+
t0 = time.perf_counter()
|
| 186 |
+
|
| 187 |
+
def worker(q):
|
| 188 |
+
def on_line(line_text, info):
|
| 189 |
+
committed[0] += line_text
|
| 190 |
+
q.put(committed[0])
|
| 191 |
+
try:
|
| 192 |
+
res = generate_crossing_search(
|
| 193 |
+
CTX, grammar, secret, max_line, prompt, system_prompt,
|
| 194 |
+
k=k, j=j, R=R, min_line=min_line, on_line=on_line,
|
| 195 |
+
)
|
| 196 |
+
q.put(("done", res))
|
| 197 |
+
except Exception as e: # noqa: BLE001
|
| 198 |
+
q.put(("error", str(e)))
|
| 199 |
+
q.put(None)
|
| 200 |
+
|
| 201 |
+
q = _run_in_thread(worker)
|
| 202 |
+
result = None
|
| 203 |
+
yield "", "", status
|
| 204 |
+
while True:
|
| 205 |
+
item = q.get()
|
| 206 |
+
if item is None:
|
| 207 |
+
break
|
| 208 |
+
if isinstance(item, tuple) and item[0] == "done":
|
| 209 |
+
result = item[1]
|
| 210 |
+
elif isinstance(item, tuple) and item[0] == "error":
|
| 211 |
+
yield committed[0], f"error: {item[1]}", "error"
|
| 212 |
+
else:
|
| 213 |
+
yield item, "", status
|
| 214 |
+
elapsed = time.perf_counter() - t0
|
| 215 |
+
text = result["text"] if result else committed[0]
|
| 216 |
+
per_line = result["per_line"] if result else []
|
| 217 |
+
n_moved = sum(1 for p in per_line if p.get("r", 0) > 0)
|
| 218 |
+
ok, firsts = check_acrostic(text, secret)
|
| 219 |
+
metrics = (
|
| 220 |
+
f"local-crossing · {elapsed:.2f}s · {len(per_line)} lines · "
|
| 221 |
+
f"{n_moved} breaks moved · acrostic {'OK' if ok else 'MISS'} ({firsts})"
|
| 222 |
+
)
|
| 223 |
+
yield text, metrics, "done (local-crossing search)."
|
| 224 |
+
return
|
| 225 |
+
|
| 226 |
+
# --- Plain grammar-constrained greedy (token-streamed) ------------------
|
| 227 |
+
proc = GrammarLogitsProcessor(grammar, CTX.tokenizer, CTX.token_text, CTX.eos_token_ids)
|
| 228 |
+
messages = [
|
| 229 |
+
{"role": "system", "content": system_prompt},
|
| 230 |
+
{"role": "user", "content": prompt},
|
| 231 |
+
]
|
| 232 |
+
enc = CTX.tokenizer.apply_chat_template(
|
| 233 |
+
messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
|
| 234 |
+
).to(DEVICE)
|
| 235 |
+
streamer = TextIteratorStreamer(CTX.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 236 |
+
|
| 237 |
+
gen_kwargs = dict(
|
| 238 |
+
**enc,
|
| 239 |
+
max_new_tokens=400,
|
| 240 |
+
do_sample=False,
|
| 241 |
+
logits_processor=LogitsProcessorList([proc]),
|
| 242 |
+
streamer=streamer,
|
| 243 |
+
pad_token_id=CTX.pad_token_id,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
t0 = time.perf_counter()
|
| 247 |
+
thread = threading.Thread(target=CTX.model.generate, kwargs=gen_kwargs, daemon=True)
|
| 248 |
+
thread.start()
|
| 249 |
+
|
| 250 |
+
acc = ""
|
| 251 |
+
t_first = None
|
| 252 |
+
tokens = 0
|
| 253 |
+
chars = 0
|
| 254 |
+
yield "", "", "generating (grammar-constrained)…"
|
| 255 |
+
for chunk in streamer:
|
| 256 |
+
if not chunk:
|
| 257 |
+
continue
|
| 258 |
+
if t_first is None:
|
| 259 |
+
t_first = time.perf_counter()
|
| 260 |
+
tokens += 1
|
| 261 |
+
chars += len(chunk)
|
| 262 |
+
acc += chunk
|
| 263 |
+
gen_s = max(0.001, time.perf_counter() - t_first)
|
| 264 |
+
tps = tokens / gen_s
|
| 265 |
+
ttft = (t_first - t0)
|
| 266 |
+
yield acc, f"TTFT {ttft:.2f}s · ~{tps:.1f} tok/s · {tokens} tokens · {chars} chars", "generating…"
|
| 267 |
+
thread.join()
|
| 268 |
+
|
| 269 |
+
wall = time.perf_counter() - t0
|
| 270 |
+
s = proc.stats
|
| 271 |
+
proc_ms = s["total_ms"]
|
| 272 |
+
ttft = (t_first - t0) if t_first else 0.0
|
| 273 |
+
ok, firsts = check_acrostic(acc, secret)
|
| 274 |
+
metrics = (
|
| 275 |
+
f"TTFT {ttft:.2f}s · {tokens} tokens · {chars} chars · wall {wall:.2f}s · "
|
| 276 |
+
f"mask {proc_ms:.0f}ms ({(proc_ms/1000)/wall*100:.0f}%) · "
|
| 277 |
+
f"acrostic {'OK' if ok else 'MISS'} ({firsts})"
|
| 278 |
+
)
|
| 279 |
+
yield acc, metrics, "done. edit the secret and/or prompt and click Generate again."
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# --- UI ---------------------------------------------------------------------
|
| 283 |
+
with gr.Blocks(title="Side chat") as demo:
|
| 284 |
+
gr.Markdown("# Side chat")
|
| 285 |
+
gr.Markdown(
|
| 286 |
+
"Completely normal text assistant, with talking on the side. Each line "
|
| 287 |
+
"of the answer secretly starts with the next letter of your **secret** "
|
| 288 |
+
f"word — grammar-constrained decoding on `{MODEL_ID}`, running locally on CPU."
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
prompt = gr.Textbox(label="Prompt", value="what are some easy-to-make home recipes?", lines=2)
|
| 292 |
+
gr.Examples(
|
| 293 |
+
examples=[
|
| 294 |
+
["what are some easy-to-make home recipes?"],
|
| 295 |
+
["please write a few sentences about regular expressions"],
|
| 296 |
+
],
|
| 297 |
+
inputs=prompt,
|
| 298 |
+
label="Demo prompts (one detects as a list, one as prose)",
|
| 299 |
+
)
|
| 300 |
+
run = gr.Button("Generate", variant="primary")
|
| 301 |
+
output = gr.Textbox(label="Output", lines=10, interactive=False)
|
| 302 |
+
metrics = gr.Markdown("")
|
| 303 |
+
|
| 304 |
+
with gr.Accordion("⚙️ Settings", open=False):
|
| 305 |
+
secret = gr.Textbox(
|
| 306 |
+
label="Secret (each line will start with these letters)", value="subtle"
|
| 307 |
+
)
|
| 308 |
+
auto_detect = gr.Checkbox(
|
| 309 |
+
label="auto-detect list vs. prose on Generate (LLM classifier)",
|
| 310 |
+
value=True,
|
| 311 |
+
)
|
| 312 |
+
list_mode = gr.Checkbox(
|
| 313 |
+
label="render as bulleted list (each line prefixed with ` * `) — "
|
| 314 |
+
"set by auto-detect; uncheck auto-detect to set it manually",
|
| 315 |
+
value=True,
|
| 316 |
+
)
|
| 317 |
+
# Manual preview: run the classifier without generating (debug aid).
|
| 318 |
+
detect = gr.Button("🔎 Detect list / prose (preview only)", size="sm")
|
| 319 |
+
max_line = gr.Number(label="Max chars per line (after the prefix + letter)", value=80, precision=0)
|
| 320 |
+
|
| 321 |
+
gr.Markdown("**Local-crossing search** (prose only) — extra attention at each constraint cliff")
|
| 322 |
+
crossing = gr.Checkbox(
|
| 323 |
+
label="enable local-crossing search (greedy line, then pick the break "
|
| 324 |
+
"that makes the crossing read best; list mode stays greedy)",
|
| 325 |
+
value=False,
|
| 326 |
+
)
|
| 327 |
+
win_k = gr.Number(label="↳ window before the break (k content tokens)", value=4, precision=0)
|
| 328 |
+
win_j = gr.Number(label="↳ window after the forced letter (j content tokens)", value=3, precision=0)
|
| 329 |
+
max_rewind = gr.Number(label="↳ max tokens to trim the break earlier (R; 0 = greedy)", value=4, precision=0)
|
| 330 |
+
min_line = gr.Number(label="↳ min chars per line (avoid stubby lines; 0 = off)", value=30, precision=0)
|
| 331 |
+
|
| 332 |
+
status = gr.Markdown("ready.")
|
| 333 |
+
|
| 334 |
+
# Manual preview: detect list vs. prose without generating.
|
| 335 |
+
detect.click(classify_fn, [prompt], [list_mode, status])
|
| 336 |
+
prompt.submit(classify_fn, [prompt], [list_mode, status])
|
| 337 |
|
| 338 |
+
# Generate: auto-detect first (updates the checkbox), then generate using it.
|
| 339 |
+
run.click(
|
| 340 |
+
maybe_detect, [prompt, list_mode, auto_detect], [list_mode, status]
|
| 341 |
+
).then(
|
| 342 |
+
generate_fn,
|
| 343 |
+
[prompt, secret, list_mode, max_line, crossing, win_k, win_j, max_rewind, min_line],
|
| 344 |
+
[output, metrics, status],
|
| 345 |
+
)
|
| 346 |
|
| 347 |
|
| 348 |
if __name__ == "__main__":
|
| 349 |
+
demo.queue().launch(theme=gr.themes.Soft())
|
classifier.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""List-vs-prose classifier (Python port of the shipped part of src/eval.js).
|
| 2 |
+
|
| 3 |
+
The classifier reads the user's prompt and decides whether the answer is best
|
| 4 |
+
rendered as a bulleted list or as narrative prose. It is itself an LLM call,
|
| 5 |
+
grammar-constrained to exactly one of two literal completions: apply the chat
|
| 6 |
+
template, append a partial assistant response (the `prefill`), constrain
|
| 7 |
+
generation to one of `branches`, parse the result.
|
| 8 |
+
|
| 9 |
+
Failure modes are model-specific, so the prompt is tuned per model. The default
|
| 10 |
+
here is the MiniCPM5-1B winner (`minicpm_intent_write_sp`, 96% on the 100-prompt
|
| 11 |
+
suite) found by re-running the sweep (eval_classifier.py / sweep_minicpm.py) on
|
| 12 |
+
that model. The LFM2.5-350M winner (`r6_c1_v2_single_plural`, 97.5% dev / 85%
|
| 13 |
+
val) is kept as an alternate — it is *prose-biased* on MiniCPM (~75%), so don't
|
| 14 |
+
reuse it there. See CLASSIFIER_PROMPT_OPTIMIZATION.md for the original JS sweep.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Callable, List
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import LogitsProcessorList
|
| 24 |
+
|
| 25 |
+
from grammar import compile_literal, union_grammars
|
| 26 |
+
from logits import GrammarLogitsProcessor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class Variant:
|
| 31 |
+
name: str
|
| 32 |
+
system: str
|
| 33 |
+
prefill: str
|
| 34 |
+
branches: List[str]
|
| 35 |
+
parse: Callable[[str], bool] # raw generated text -> True (list) / False (prose)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Shared trigger-rule strings.
|
| 39 |
+
_INTENT_BASE = (
|
| 40 |
+
"Classify the user's intent. Use \"list\" when the answer is a set of "
|
| 41 |
+
"separate items the user can scan. Use \"story\" when the answer flows as "
|
| 42 |
+
"one narrative, single fact, or short paragraph."
|
| 43 |
+
)
|
| 44 |
+
_WRITE_FORMS = (
|
| 45 |
+
" Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, "
|
| 46 |
+
"cover letter, email, joke, story, essay, or limerick, the answer is a story."
|
| 47 |
+
)
|
| 48 |
+
_SINGLE_PLURAL = (
|
| 49 |
+
" \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural "
|
| 50 |
+
"enumeration) is a list; \"what are the steps/differences/causes/symptoms\" "
|
| 51 |
+
"is a list."
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
# --- The shipped MiniCPM5-1B winner -----------------------------------------
|
| 55 |
+
# On MiniCPM, every "Default to list" framing collapses to all-story (list 0/50)
|
| 56 |
+
# and the LFM2 winner is prose-biased. A neutral *intent* framing nails list
|
| 57 |
+
# recall; adding the write-forms rule (catches "write a haiku/email") and the
|
| 58 |
+
# single-vs-plural rule (catches "what is X" single facts) fixes the residual
|
| 59 |
+
# prose misses. 96% on the 100-prompt suite (list 49/50, prose 47/50).
|
| 60 |
+
DEFAULT_VARIANT = Variant(
|
| 61 |
+
name="minicpm_intent_write_sp",
|
| 62 |
+
system=_INTENT_BASE + _WRITE_FORMS + _SINGLE_PLURAL,
|
| 63 |
+
prefill="The intent is to get a ",
|
| 64 |
+
branches=["list.", "story."],
|
| 65 |
+
parse=lambda s: s.startswith("list"),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# --- Reference alternates (other strong variants; useful when re-tuning) -----
|
| 69 |
+
ALTERNATES = [
|
| 70 |
+
# The LFM2.5-350M winner (97.5% dev / 85% val on LFM2; ~75% on MiniCPM).
|
| 71 |
+
Variant(
|
| 72 |
+
name="r6_c1_v2_single_plural",
|
| 73 |
+
system=(
|
| 74 |
+
"Classify the user's request. Use \"list\" when the user wants "
|
| 75 |
+
"enumerated items. Use \"story\" for everything else. \"What is X\" "
|
| 76 |
+
"(a single fact) is a story; \"What are the/some Xs\" (plural "
|
| 77 |
+
"enumeration) is a list; \"what are the steps/differences/causes/"
|
| 78 |
+
"symptoms\" is a list."
|
| 79 |
+
),
|
| 80 |
+
prefill="The user is asking for a ",
|
| 81 |
+
branches=["list.", "story."],
|
| 82 |
+
parse=lambda s: s.startswith("list"),
|
| 83 |
+
),
|
| 84 |
+
# Intent base + single-plural only (100% screen, 93% full on MiniCPM;
|
| 85 |
+
# perfect list recall but misses some "write a X" prose prompts).
|
| 86 |
+
Variant(
|
| 87 |
+
name="minicpm_intent_sp",
|
| 88 |
+
system=_INTENT_BASE + _SINGLE_PLURAL,
|
| 89 |
+
prefill="The intent is to get a ",
|
| 90 |
+
branches=["list.", "story."],
|
| 91 |
+
parse=lambda s: s.startswith("list"),
|
| 92 |
+
),
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
VARIANTS = [DEFAULT_VARIANT, *ALTERNATES]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def classify(ctx, prompt, variant=DEFAULT_VARIANT):
|
| 99 |
+
"""Run one classifier call. ctx is a Context (see app.py): .model,
|
| 100 |
+
.tokenizer, .token_text, .eos_token_ids. Returns (prediction, raw)."""
|
| 101 |
+
tok = ctx.tokenizer
|
| 102 |
+
messages = [
|
| 103 |
+
{"role": "system", "content": variant.system},
|
| 104 |
+
{"role": "user", "content": prompt},
|
| 105 |
+
]
|
| 106 |
+
templated = tok.apply_chat_template(
|
| 107 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 108 |
+
)
|
| 109 |
+
full_text = templated + variant.prefill
|
| 110 |
+
|
| 111 |
+
grammar = union_grammars([compile_literal(b) for b in variant.branches])
|
| 112 |
+
proc = GrammarLogitsProcessor(grammar, tok, ctx.token_text, ctx.eos_token_ids)
|
| 113 |
+
|
| 114 |
+
enc = tok(full_text, return_tensors="pt", add_special_tokens=False).to(ctx.model.device)
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
out = ctx.model.generate(
|
| 117 |
+
**enc,
|
| 118 |
+
max_new_tokens=16,
|
| 119 |
+
do_sample=False,
|
| 120 |
+
logits_processor=LogitsProcessorList([proc]),
|
| 121 |
+
pad_token_id=ctx.pad_token_id,
|
| 122 |
+
)
|
| 123 |
+
raw = tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 124 |
+
return variant.parse(raw), raw
|
crossing_search.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local-crossing-objective search for acrostics (Python port of
|
| 2 |
+
src/crossingSearch.js + the LineMaskScore / NewlineStop bits of
|
| 3 |
+
src/surprisalLookahead.js).
|
| 4 |
+
|
| 5 |
+
This is the "extra attention when we come to a constraint": the forced first
|
| 6 |
+
letter of each line is a cliff where the constraint fights what the model wants
|
| 7 |
+
to say. Greedy is a strong baseline; search only beats it when the *objective*
|
| 8 |
+
is right. So we change only the objective:
|
| 9 |
+
|
| 10 |
+
1. Score a SHORT fixed window straddling the crossing — the last `k` content
|
| 11 |
+
tokens before the break, plus the forced letter and the next `j` tokens.
|
| 12 |
+
Length-neutral; the structural newline is never scored, so there's no
|
| 13 |
+
run-to-the-wall bias.
|
| 14 |
+
2. Look `j` tokens PAST the forced letter (does the next line *continue*
|
| 15 |
+
well?), not just at it.
|
| 16 |
+
3. Make the break point a search variable, snapped to word boundaries:
|
| 17 |
+
generate the line greedily, then consider ending it 0..R tokens earlier.
|
| 18 |
+
r=0 (greedy) is always a candidate, so this can only match or beat greedy.
|
| 19 |
+
|
| 20 |
+
Plus two stealth touches carried in LineMaskScore: lowercase the forced letter
|
| 21 |
+
mid-sentence (so the acrostic hides), and a minimum line length (no stubby
|
| 22 |
+
lines). Public-API only: each line/rollout is a fresh model.generate()
|
| 23 |
+
continuation of the chat-templated prompt + committed text fed back as a string.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import re
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
|
| 33 |
+
|
| 34 |
+
from masking import LegalCache
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Mid-sentence iff the text so far doesn't end a sentence — then the next forced
|
| 38 |
+
# letter should be lowercase. Empty prefix (line 0) is a sentence start.
|
| 39 |
+
_SENTENCE_END = re.compile(r"[.!?][\"'”’)\]]?$")
|
| 40 |
+
_WORD_START = re.compile(r"^[\s.,;:!?)\]\"'’”]")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def mid_sentence(t):
|
| 44 |
+
s = re.sub(r"\s+$", "", t or "")
|
| 45 |
+
return len(s) > 0 and not _SENTENCE_END.search(s)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class LineMaskScore(LogitsProcessor):
|
| 49 |
+
"""Grammar-masking processor that also accumulates the greedy chosen-token
|
| 50 |
+
log-prob per step, applies stealth casing + minimum line length, and (when
|
| 51 |
+
capture_top_n > 0, at the first step) records the top-N legal openings and a
|
| 52 |
+
surprise signal."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, grammar, start_state, tokenizer, token_text, tok_info, cache,
|
| 55 |
+
capture_top_n=0, force_lower_first=False, min_line=0):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.grammar = grammar
|
| 58 |
+
self.start_state = start_state
|
| 59 |
+
self.tokenizer = tokenizer
|
| 60 |
+
self.token_text = token_text
|
| 61 |
+
self.info = tok_info
|
| 62 |
+
self.cache = cache
|
| 63 |
+
self.capture_top_n = capture_top_n
|
| 64 |
+
self.force_lower_first = force_lower_first
|
| 65 |
+
self.min_line = min_line
|
| 66 |
+
self.prompt_length = None
|
| 67 |
+
self.step_logprobs = [] # chosen (argmax) log-prob, one per generated step
|
| 68 |
+
self.top_n = None
|
| 69 |
+
self.surprise = None
|
| 70 |
+
|
| 71 |
+
def __call__(self, input_ids, scores):
|
| 72 |
+
ids = input_ids[0]
|
| 73 |
+
if self.prompt_length is None:
|
| 74 |
+
self.prompt_length = ids.shape[0]
|
| 75 |
+
generated = ids[self.prompt_length:].tolist()
|
| 76 |
+
gen = (
|
| 77 |
+
self.tokenizer.decode(generated, skip_special_tokens=True)
|
| 78 |
+
if generated
|
| 79 |
+
else ""
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
state = self.grammar.advance(self.start_state, gen)
|
| 83 |
+
data = scores[0]
|
| 84 |
+
if state == -1:
|
| 85 |
+
self.step_logprobs.append(0.0)
|
| 86 |
+
return scores
|
| 87 |
+
|
| 88 |
+
first_step = self.top_n is None and self.surprise is None
|
| 89 |
+
want_signal = self.capture_top_n and first_step
|
| 90 |
+
|
| 91 |
+
lse_all = None
|
| 92 |
+
if want_signal:
|
| 93 |
+
lse_all = torch.logsumexp(data, dim=0).item()
|
| 94 |
+
|
| 95 |
+
# Grammar-legal set (shared, state-cached), then stealth + minLine on top.
|
| 96 |
+
legal = self.cache.legal_np(state).copy()
|
| 97 |
+
info = self.info
|
| 98 |
+
|
| 99 |
+
if self.force_lower_first and len(gen) == 0:
|
| 100 |
+
if np.any(legal & info.alpha_lower):
|
| 101 |
+
legal &= ~info.alpha_upper
|
| 102 |
+
|
| 103 |
+
if self.min_line and len(gen) < self.min_line:
|
| 104 |
+
body = legal & ~info.eos_mask & info.nonempty & ~info.has_newline
|
| 105 |
+
if np.any(body):
|
| 106 |
+
legal &= ~(info.eos_mask | info.has_newline)
|
| 107 |
+
|
| 108 |
+
illegal = torch.from_numpy(~legal).to(data.device)
|
| 109 |
+
data[illegal] = float("-inf")
|
| 110 |
+
|
| 111 |
+
max_legal = torch.max(data)
|
| 112 |
+
if max_legal.item() == float("-inf"):
|
| 113 |
+
self.step_logprobs.append(0.0)
|
| 114 |
+
return scores
|
| 115 |
+
lse_masked = torch.logsumexp(data, dim=0)
|
| 116 |
+
self.step_logprobs.append((max_legal - lse_masked).item())
|
| 117 |
+
|
| 118 |
+
if want_signal:
|
| 119 |
+
self.surprise = lse_all - max_legal.item()
|
| 120 |
+
legal_idx = np.nonzero(legal)[0]
|
| 121 |
+
logits = data[torch.from_numpy(legal_idx).to(data.device)]
|
| 122 |
+
order = torch.argsort(logits, descending=True)[: self.capture_top_n]
|
| 123 |
+
lse_m = lse_masked.item()
|
| 124 |
+
self.top_n = [
|
| 125 |
+
{"id": int(legal_idx[int(o)]), "logit": float(logits[int(o)]),
|
| 126 |
+
"logprob": float(logits[int(o)]) - lse_m}
|
| 127 |
+
for o in order
|
| 128 |
+
]
|
| 129 |
+
return scores
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class NewlineStop(StoppingCriteria):
|
| 133 |
+
"""Stop a rollout as soon as the newly generated token contains a newline."""
|
| 134 |
+
|
| 135 |
+
def __init__(self, prompt_length, has_newline):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.prompt_length = prompt_length
|
| 138 |
+
self.has_newline = has_newline
|
| 139 |
+
|
| 140 |
+
def __call__(self, input_ids, scores, **kwargs):
|
| 141 |
+
out = []
|
| 142 |
+
for ids in input_ids:
|
| 143 |
+
if ids.shape[0] <= self.prompt_length:
|
| 144 |
+
out.append(False)
|
| 145 |
+
else:
|
| 146 |
+
out.append(bool(self.has_newline[int(ids[-1])]))
|
| 147 |
+
return torch.tensor(out, dtype=torch.bool, device=input_ids.device)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _mean(a):
|
| 151 |
+
return sum(a) / len(a) if a else 0.0
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def generate_crossing_search(ctx, grammar, secret, max_line, prompt, system_prompt,
|
| 155 |
+
k=4, j=3, R=4, min_line=30, on_line=None):
|
| 156 |
+
"""Generate acrostic text with the local-crossing search. Returns
|
| 157 |
+
{"text": str, "per_line": [...]}. on_line(line_text, info) is called as each
|
| 158 |
+
line is committed (for incremental display)."""
|
| 159 |
+
model = ctx.model
|
| 160 |
+
tok = ctx.tokenizer
|
| 161 |
+
token_text = ctx.token_text
|
| 162 |
+
info = ctx.tok_info
|
| 163 |
+
cache = LegalCache(grammar, token_text, ctx.eos_token_ids) # shared across rollouts
|
| 164 |
+
has_newline = info.has_newline
|
| 165 |
+
|
| 166 |
+
messages = [
|
| 167 |
+
{"role": "system", "content": system_prompt},
|
| 168 |
+
{"role": "user", "content": prompt},
|
| 169 |
+
]
|
| 170 |
+
prompt_string = tok.apply_chat_template(
|
| 171 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def enc_ids(text):
|
| 175 |
+
return tok(text, add_special_tokens=False).input_ids
|
| 176 |
+
|
| 177 |
+
def generate_from(text, max_new_tokens, proc, stop_newline):
|
| 178 |
+
enc = tok(text, return_tensors="pt", add_special_tokens=False).to(model.device)
|
| 179 |
+
stops = None
|
| 180 |
+
if stop_newline:
|
| 181 |
+
stops = StoppingCriteriaList(
|
| 182 |
+
[NewlineStop(enc["input_ids"].shape[1], has_newline)]
|
| 183 |
+
)
|
| 184 |
+
with torch.no_grad():
|
| 185 |
+
out = model.generate(
|
| 186 |
+
**enc,
|
| 187 |
+
max_new_tokens=max_new_tokens,
|
| 188 |
+
do_sample=False,
|
| 189 |
+
logits_processor=LogitsProcessorList([proc]),
|
| 190 |
+
stopping_criteria=stops,
|
| 191 |
+
pad_token_id=ctx.pad_token_id,
|
| 192 |
+
)
|
| 193 |
+
gen_ids = out[0][enc["input_ids"].shape[1]:]
|
| 194 |
+
return tok.decode(gen_ids, skip_special_tokens=True)
|
| 195 |
+
|
| 196 |
+
# Greedy line from prefix_text (acrostic text so far).
|
| 197 |
+
def gen_line(prefix_text, is_last):
|
| 198 |
+
start_state = grammar.advance(grammar.initial, prefix_text)
|
| 199 |
+
ctx_str = prompt_string + prefix_text
|
| 200 |
+
proc = LineMaskScore(
|
| 201 |
+
grammar, start_state, tok, token_text, info, cache,
|
| 202 |
+
force_lower_first=mid_sentence(prefix_text), min_line=min_line,
|
| 203 |
+
)
|
| 204 |
+
text = generate_from(ctx_str, max_line + 8, proc, stop_newline=not is_last)
|
| 205 |
+
if not is_last:
|
| 206 |
+
nl = text.find("\n")
|
| 207 |
+
if nl != -1:
|
| 208 |
+
text = text[: nl + 1]
|
| 209 |
+
base_n = len(enc_ids(ctx_str))
|
| 210 |
+
line_ids = enc_ids(ctx_str + text)[base_n:]
|
| 211 |
+
return text, line_ids, proc.step_logprobs
|
| 212 |
+
|
| 213 |
+
# Roll the NEXT line's opening: forced letter + up to n-1 content tokens.
|
| 214 |
+
def roll_open(prefix_text, n):
|
| 215 |
+
start_state = grammar.advance(grammar.initial, prefix_text)
|
| 216 |
+
if start_state == -1:
|
| 217 |
+
return []
|
| 218 |
+
proc = LineMaskScore(
|
| 219 |
+
grammar, start_state, tok, token_text, info, cache,
|
| 220 |
+
force_lower_first=mid_sentence(prefix_text),
|
| 221 |
+
)
|
| 222 |
+
generate_from(prompt_string + prefix_text, n, proc, stop_newline=True)
|
| 223 |
+
return proc.step_logprobs
|
| 224 |
+
|
| 225 |
+
n_lines = len(secret)
|
| 226 |
+
committed = ""
|
| 227 |
+
per_line = []
|
| 228 |
+
|
| 229 |
+
for i in range(n_lines):
|
| 230 |
+
is_last = i == n_lines - 1
|
| 231 |
+
text, line_ids, logps = gen_line(committed, is_last)
|
| 232 |
+
|
| 233 |
+
# Last line, or no break search: commit the greedy line as-is.
|
| 234 |
+
if is_last or R <= 0:
|
| 235 |
+
committed += text
|
| 236 |
+
per_line.append({"line": i, "chosen": text, "r": 0, "candidates": None})
|
| 237 |
+
if on_line:
|
| 238 |
+
on_line(text, {"line": i})
|
| 239 |
+
continue
|
| 240 |
+
|
| 241 |
+
m = min(len(line_ids), len(logps))
|
| 242 |
+
ids = line_ids[-m:] if m else []
|
| 243 |
+
lps = logps[-m:] if m else []
|
| 244 |
+
has_nl = text.endswith("\n")
|
| 245 |
+
line_start_state = grammar.advance(grammar.initial, committed)
|
| 246 |
+
|
| 247 |
+
candidates = []
|
| 248 |
+
for r in range(0, min(R, m - 1) + 1):
|
| 249 |
+
# r tokens trimmed -> break after (m-r) tokens. Require the first
|
| 250 |
+
# trimmed token to begin a new word/punctuation (clean boundary).
|
| 251 |
+
if r > 0:
|
| 252 |
+
first_trimmed = token_text[ids[m - r]]
|
| 253 |
+
if not first_trimmed or not _WORD_START.match(first_trimmed):
|
| 254 |
+
continue
|
| 255 |
+
kept_ids = ids[: m - r]
|
| 256 |
+
if not kept_ids:
|
| 257 |
+
continue
|
| 258 |
+
prefix_text = re.sub(r"\n+$", "", tok.decode(kept_ids, skip_special_tokens=True))
|
| 259 |
+
broke_line = prefix_text + "\n"
|
| 260 |
+
# The trimmed line must still be grammar-legal (keep the forced letter).
|
| 261 |
+
if grammar.advance(line_start_state, broke_line) == -1:
|
| 262 |
+
continue
|
| 263 |
+
if r > 0 and len(prefix_text) < min_line:
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
before_lps = lps[: m - r]
|
| 267 |
+
if r == 0 and has_nl:
|
| 268 |
+
before_lps = before_lps[:-1]
|
| 269 |
+
before_lps = before_lps[-k:]
|
| 270 |
+
|
| 271 |
+
after_lps = roll_open(committed + broke_line, 1 + j)
|
| 272 |
+
|
| 273 |
+
candidates.append({
|
| 274 |
+
"r": r,
|
| 275 |
+
"broke_line": broke_line,
|
| 276 |
+
"score": _mean(before_lps + after_lps),
|
| 277 |
+
"n_before": len(before_lps),
|
| 278 |
+
"n_after": len(after_lps),
|
| 279 |
+
"preview": broke_line[-28:],
|
| 280 |
+
})
|
| 281 |
+
|
| 282 |
+
chosen, r = text, 0
|
| 283 |
+
if candidates:
|
| 284 |
+
candidates.sort(key=lambda c: c["score"], reverse=True)
|
| 285 |
+
chosen = candidates[0]["broke_line"]
|
| 286 |
+
r = candidates[0]["r"]
|
| 287 |
+
committed += chosen
|
| 288 |
+
per_line.append({"line": i, "chosen": chosen, "r": r, "candidates": candidates})
|
| 289 |
+
if on_line:
|
| 290 |
+
on_line(chosen, {"line": i})
|
| 291 |
+
|
| 292 |
+
return {"text": committed, "per_line": per_line}
|
eval_classifier.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""List-vs-prose classifier eval harness (Python port of the dataset + runner
|
| 2 |
+
from src/eval.js).
|
| 3 |
+
|
| 4 |
+
50 list-style + 50 prose-style hand-picked prompts, split 10+10 validation /
|
| 5 |
+
40+40 dev. Run this as a script to sweep candidate classifier variants on the
|
| 6 |
+
current model and pick the best one for it:
|
| 7 |
+
|
| 8 |
+
SIDECHAT_MODEL=openbmb/MiniCPM5-1B .venv/bin/python eval_classifier.py
|
| 9 |
+
|
| 10 |
+
It prints a ranking table (dev accuracy, list-recall, prose-recall) and then
|
| 11 |
+
validates the top variants on the held-out set. The winner becomes
|
| 12 |
+
classifier.DEFAULT_VARIANT.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
from classifier import Variant, classify
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Datasets (ported verbatim from src/eval.js)
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
LIST_PROMPTS = [
|
| 26 |
+
# --- validation (first 10) ---
|
| 27 |
+
"list 10 ways to improve morale at work",
|
| 28 |
+
"give me five reasons to learn Rust",
|
| 29 |
+
"what are the main benefits of meditation?",
|
| 30 |
+
"suggest some names for my new puppy",
|
| 31 |
+
"name three famous jazz musicians",
|
| 32 |
+
"list the ingredients for guacamole",
|
| 33 |
+
"what are the steps to change a tire?",
|
| 34 |
+
"give me ideas for weekend activities with kids",
|
| 35 |
+
"tips for packing light when traveling",
|
| 36 |
+
"what are some common Italian desserts?",
|
| 37 |
+
# --- dev (next 40) ---
|
| 38 |
+
"list popular video game consoles from the 1990s",
|
| 39 |
+
"suggest questions to ask at a job interview",
|
| 40 |
+
"what are the symptoms of dehydration?",
|
| 41 |
+
"name ten countries in Africa",
|
| 42 |
+
"list some movies directed by Christopher Nolan",
|
| 43 |
+
"give me seven examples of onomatopoeia",
|
| 44 |
+
"what tools do I need to build a raised garden bed?",
|
| 45 |
+
"suggest some icebreaker activities for a team meeting",
|
| 46 |
+
"ways to reduce food waste at home",
|
| 47 |
+
"list the planets in order from the sun",
|
| 48 |
+
"what are the main differences between Python 2 and Python 3?",
|
| 49 |
+
"give me 5 good podcast recommendations about history",
|
| 50 |
+
"name three types of dance",
|
| 51 |
+
"top tourist attractions in Kyoto",
|
| 52 |
+
"list common symptoms of the flu",
|
| 53 |
+
"what are some healthy snack ideas for kids?",
|
| 54 |
+
"suggest some books similar to The Hobbit",
|
| 55 |
+
"name five spices commonly used in Indian cooking",
|
| 56 |
+
"list programming languages that compile to WebAssembly",
|
| 57 |
+
"give me a list of yoga poses for beginners",
|
| 58 |
+
"what are some good stretches before running?",
|
| 59 |
+
"name the colors of the rainbow",
|
| 60 |
+
"list the months of the year in French",
|
| 61 |
+
"what are common causes of burnout?",
|
| 62 |
+
"suggest some romantic date ideas in New York",
|
| 63 |
+
"give me a bullet list of home safety tips",
|
| 64 |
+
"list the bones in the human hand",
|
| 65 |
+
"ways to learn a new language quickly",
|
| 66 |
+
"name five mammals native to Australia",
|
| 67 |
+
"what are some highlights of the French Revolution?",
|
| 68 |
+
"list common pitfalls of distributed systems",
|
| 69 |
+
"top 10 songs from the 1980s",
|
| 70 |
+
"suggest some hobbies for introverts",
|
| 71 |
+
"name the original members of The Beatles",
|
| 72 |
+
"what are the primary colors?",
|
| 73 |
+
"list reasons to adopt a cat",
|
| 74 |
+
"give me 6 tips for better sleep hygiene",
|
| 75 |
+
"name the Great Lakes",
|
| 76 |
+
"list programming concepts every developer should know",
|
| 77 |
+
"suggest some vegan dinner recipes",
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
PROSE_PROMPTS = [
|
| 81 |
+
# --- validation (first 10) ---
|
| 82 |
+
"tell me a short story about a lighthouse keeper",
|
| 83 |
+
"write a haiku about autumn",
|
| 84 |
+
"explain how a solar panel works in a paragraph",
|
| 85 |
+
"summarize the plot of Pride and Prejudice",
|
| 86 |
+
'what does the word "quixotic" mean?',
|
| 87 |
+
'translate "good morning" to Japanese',
|
| 88 |
+
"write a professional email declining a meeting",
|
| 89 |
+
"describe the taste of a ripe mango",
|
| 90 |
+
"compose a poem about loneliness",
|
| 91 |
+
"what is the capital of Australia?",
|
| 92 |
+
# --- dev (next 40) ---
|
| 93 |
+
"tell me about the invention of the printing press",
|
| 94 |
+
"write a cover letter for a software engineering role",
|
| 95 |
+
"explain the theory of relativity to a 10-year-old",
|
| 96 |
+
"who was Marie Curie?",
|
| 97 |
+
"describe a sunset over the ocean",
|
| 98 |
+
"what is photosynthesis?",
|
| 99 |
+
"write a bedtime story for a 4-year-old",
|
| 100 |
+
"explain how blockchain works",
|
| 101 |
+
"tell me about the history of tea in China",
|
| 102 |
+
"describe the plot of Inception",
|
| 103 |
+
"write a haiku about the sea",
|
| 104 |
+
"what is the meaning of life according to Camus?",
|
| 105 |
+
"tell me a joke about programming",
|
| 106 |
+
"explain why the sky is blue",
|
| 107 |
+
"describe what it feels like to run a marathon",
|
| 108 |
+
"write a love letter in the style of Shakespeare",
|
| 109 |
+
"what year did the Berlin Wall fall?",
|
| 110 |
+
"tell me about the architecture of the Sagrada Familia",
|
| 111 |
+
"write a persuasive essay on renewable energy",
|
| 112 |
+
"describe the personality of a golden retriever",
|
| 113 |
+
"who was the first person on the moon?",
|
| 114 |
+
"tell me about quantum entanglement briefly",
|
| 115 |
+
"write a one-paragraph synopsis of The Great Gatsby",
|
| 116 |
+
'what is the etymology of the word "sandwich"?',
|
| 117 |
+
"explain why we dream",
|
| 118 |
+
"tell me a myth about the origin of fire",
|
| 119 |
+
"describe the feeling of nostalgia",
|
| 120 |
+
"write a toast for a wedding",
|
| 121 |
+
'what does "serendipity" mean?',
|
| 122 |
+
"tell me about your favorite season",
|
| 123 |
+
"explain the difference between empathy and sympathy",
|
| 124 |
+
"who wrote Hamlet?",
|
| 125 |
+
"write a limerick about cats",
|
| 126 |
+
"tell me a ghost story",
|
| 127 |
+
"describe Mount Fuji in winter",
|
| 128 |
+
"what happened in the Cuban Missile Crisis?",
|
| 129 |
+
"explain how a car engine works",
|
| 130 |
+
"tell me a folk tale from Ireland",
|
| 131 |
+
"write an essay on the importance of libraries",
|
| 132 |
+
"describe a perfect day",
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
+
VALIDATION_LIST = LIST_PROMPTS[:10]
|
| 136 |
+
VALIDATION_PROSE = PROSE_PROMPTS[:10]
|
| 137 |
+
DEV_LIST = LIST_PROMPTS[10:]
|
| 138 |
+
DEV_PROSE = PROSE_PROMPTS[10:]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def make_labelled(list_prompts, prose_prompts):
|
| 142 |
+
return [{"prompt": p, "expected": True} for p in list_prompts] + [
|
| 143 |
+
{"prompt": p, "expected": False} for p in prose_prompts
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def run_variant_on(ctx, variant, labelled, on_progress=None):
|
| 148 |
+
results = []
|
| 149 |
+
for i, item in enumerate(labelled):
|
| 150 |
+
pred, raw = classify(ctx, item["prompt"], variant)
|
| 151 |
+
results.append({**item, "prediction": pred, "raw": raw, "correct": pred == item["expected"]})
|
| 152 |
+
if on_progress:
|
| 153 |
+
on_progress(i + 1, len(labelled))
|
| 154 |
+
correct = sum(1 for r in results if r["correct"])
|
| 155 |
+
list_total = sum(1 for r in results if r["expected"])
|
| 156 |
+
prose_total = len(results) - list_total
|
| 157 |
+
list_hit = sum(1 for r in results if r["expected"] and r["correct"])
|
| 158 |
+
prose_hit = sum(1 for r in results if not r["expected"] and r["correct"])
|
| 159 |
+
return {
|
| 160 |
+
"variant": variant.name,
|
| 161 |
+
"accuracy": correct / len(results),
|
| 162 |
+
"correct": correct,
|
| 163 |
+
"total": len(results),
|
| 164 |
+
"list_recall": (list_hit, list_total),
|
| 165 |
+
"prose_recall": (prose_hit, prose_total),
|
| 166 |
+
"results": results,
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def sweep(ctx, variants, labelled, label=""):
|
| 171 |
+
summaries = []
|
| 172 |
+
for v in variants:
|
| 173 |
+
t0 = time.time()
|
| 174 |
+
res = run_variant_on(ctx, v, labelled)
|
| 175 |
+
res["wall_s"] = time.time() - t0
|
| 176 |
+
lh, lt = res["list_recall"]
|
| 177 |
+
ph, pt = res["prose_recall"]
|
| 178 |
+
print(
|
| 179 |
+
f" [{label}] {v.name:30} {res['correct']:>2}/{res['total']} "
|
| 180 |
+
f"= {res['accuracy']*100:5.1f}% list {lh}/{lt} prose {ph}/{pt} "
|
| 181 |
+
f"({res['wall_s']:.0f}s)",
|
| 182 |
+
flush=True,
|
| 183 |
+
)
|
| 184 |
+
summaries.append(res)
|
| 185 |
+
return summaries
|
grammar.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tiny grammar engine for acrostic-style constraints.
|
| 2 |
+
|
| 3 |
+
A faithful Python port of src/grammar.js.
|
| 4 |
+
|
| 5 |
+
Primitives:
|
| 6 |
+
- Atoms: {"kind": "lit", "allowed": set[str]} (consumes exactly one char from
|
| 7 |
+
`allowed`) or {"kind": "body", "max": int} (consumes 0..max non-newline
|
| 8 |
+
chars).
|
| 9 |
+
- Atom sequences are concatenation-only; with the body/newline structure we
|
| 10 |
+
use, transitions are deterministic, so state packs into one int:
|
| 11 |
+
atom_idx * stride + count.
|
| 12 |
+
|
| 13 |
+
Builders:
|
| 14 |
+
- compile_acrostic(secret, ...) — list-mode or prose-mode acrostic.
|
| 15 |
+
- compile_literal(text) — exact-text matcher (used by the classifier).
|
| 16 |
+
- union_grammars([g1, g2, ...]) — accept if any branch is alive.
|
| 17 |
+
|
| 18 |
+
The dead-state sentinel is -1 everywhere, matching the JS original.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
# Spaces in the secret are treated as "word breaks" — they don't pin the line to
|
| 24 |
+
# any particular letter, but they still produce a line, and the line must start
|
| 25 |
+
# with a punctuation character so the acrostic reads naturally
|
| 26 |
+
# ("HI WORLD" -> H... / I... / <punct>... / W... / O... / R... / L... / D...).
|
| 27 |
+
PUNCT_FOR_SPACE = set(
|
| 28 |
+
list(".,;:!?-")
|
| 29 |
+
+ list("()[]{}")
|
| 30 |
+
+ list("~<>")
|
| 31 |
+
+ ['"', "'", "`"]
|
| 32 |
+
+ list("@#$%&+=/\\|_^")
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AtomGrammar:
|
| 37 |
+
"""A single concatenation-only atom-sequence grammar (an NFA packed into ints)."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, atoms):
|
| 40 |
+
self.atoms = atoms
|
| 41 |
+
max_body_max = 0
|
| 42 |
+
for a in atoms:
|
| 43 |
+
if a["kind"] == "body" and a["max"] > max_body_max:
|
| 44 |
+
max_body_max = a["max"]
|
| 45 |
+
self.stride = max_body_max + 2
|
| 46 |
+
self.PAST_END = len(atoms) * self.stride
|
| 47 |
+
self.state_count = self.PAST_END + 1
|
| 48 |
+
|
| 49 |
+
# Precompute accepting states: (a, c) accepts iff atom `a` can be
|
| 50 |
+
# epsilon-skipped at count `c` AND (a+1, 0) is accepting.
|
| 51 |
+
accepting = bytearray(self.state_count)
|
| 52 |
+
accepting[self.PAST_END] = 1
|
| 53 |
+
next_accepting = True
|
| 54 |
+
for a in range(len(atoms) - 1, -1, -1):
|
| 55 |
+
atom = atoms[a]
|
| 56 |
+
if next_accepting:
|
| 57 |
+
mn = 1 if atom["kind"] == "lit" else 0
|
| 58 |
+
mx = 1 if atom["kind"] == "lit" else atom["max"]
|
| 59 |
+
for c in range(mn, mx + 1):
|
| 60 |
+
accepting[a * self.stride + c] = 1
|
| 61 |
+
next_accepting = accepting[a * self.stride + 0] == 1
|
| 62 |
+
self.accepting = accepting
|
| 63 |
+
|
| 64 |
+
self.initial = 0
|
| 65 |
+
|
| 66 |
+
def _consume_at(self, a, ch):
|
| 67 |
+
atoms = self.atoms
|
| 68 |
+
while a < len(atoms):
|
| 69 |
+
atom = atoms[a]
|
| 70 |
+
if atom["kind"] == "lit":
|
| 71 |
+
if ch in atom["allowed"]:
|
| 72 |
+
return self.PAST_END if a + 1 >= len(atoms) else (a + 1) * self.stride
|
| 73 |
+
return -1
|
| 74 |
+
if ch != "\n":
|
| 75 |
+
return a * self.stride + 1
|
| 76 |
+
a += 1
|
| 77 |
+
return -1
|
| 78 |
+
|
| 79 |
+
def advance(self, state, s):
|
| 80 |
+
stride = self.stride
|
| 81 |
+
atoms = self.atoms
|
| 82 |
+
cur = state
|
| 83 |
+
for ch in s:
|
| 84 |
+
if cur == self.PAST_END:
|
| 85 |
+
return -1
|
| 86 |
+
a = cur // stride
|
| 87 |
+
c = cur - a * stride
|
| 88 |
+
atom = atoms[a]
|
| 89 |
+
if atom["kind"] == "lit":
|
| 90 |
+
if c < 1 and ch in atom["allowed"]:
|
| 91 |
+
nxt = self.PAST_END if a + 1 >= len(atoms) else (a + 1) * stride
|
| 92 |
+
else:
|
| 93 |
+
return -1
|
| 94 |
+
else:
|
| 95 |
+
if c < atom["max"] and ch != "\n":
|
| 96 |
+
nxt = a * stride + (c + 1)
|
| 97 |
+
else:
|
| 98 |
+
nxt = self._consume_at(a + 1, ch)
|
| 99 |
+
if nxt == -1:
|
| 100 |
+
return -1
|
| 101 |
+
cur = nxt
|
| 102 |
+
return cur
|
| 103 |
+
|
| 104 |
+
def accepts(self, state):
|
| 105 |
+
return state is not None and 0 <= state < self.state_count and self.accepting[state] == 1
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def compile_acrostic(secret, list_prefix=" * ", max_line=80, case_insensitive=False, first_line_prefix=True):
|
| 109 |
+
if not secret:
|
| 110 |
+
raise ValueError("secret must be non-empty")
|
| 111 |
+
atoms = []
|
| 112 |
+
for i, letter in enumerate(secret):
|
| 113 |
+
want_prefix = i > 0 or first_line_prefix
|
| 114 |
+
if want_prefix:
|
| 115 |
+
for c in list_prefix:
|
| 116 |
+
atoms.append({"kind": "lit", "allowed": {c}})
|
| 117 |
+
if letter == " ":
|
| 118 |
+
allowed = set(PUNCT_FOR_SPACE)
|
| 119 |
+
elif case_insensitive:
|
| 120 |
+
allowed = {letter.upper(), letter.lower()}
|
| 121 |
+
else:
|
| 122 |
+
allowed = {letter}
|
| 123 |
+
atoms.append({"kind": "lit", "allowed": allowed})
|
| 124 |
+
atoms.append({"kind": "body", "max": max_line})
|
| 125 |
+
if i < len(secret) - 1:
|
| 126 |
+
atoms.append({"kind": "lit", "allowed": {"\n"}})
|
| 127 |
+
return AtomGrammar(atoms)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def compile_literal(text):
|
| 131 |
+
if not text:
|
| 132 |
+
raise ValueError("literal must be non-empty")
|
| 133 |
+
atoms = [{"kind": "lit", "allowed": {c}} for c in text]
|
| 134 |
+
return AtomGrammar(atoms)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class UnionGrammar:
|
| 138 |
+
"""Run several grammars in parallel; a token is alive iff at least one branch
|
| 139 |
+
is alive. State is a list of per-branch ints (-1 = dead branch). When every
|
| 140 |
+
branch is dead, advance returns -1 (the single-grammar dead sentinel)."""
|
| 141 |
+
|
| 142 |
+
def __init__(self, grammars):
|
| 143 |
+
self.grammars = grammars
|
| 144 |
+
self.initial = [g.initial for g in grammars]
|
| 145 |
+
|
| 146 |
+
def advance(self, state, s):
|
| 147 |
+
nxt = [-1] * len(self.grammars)
|
| 148 |
+
any_live = False
|
| 149 |
+
for i, g in enumerate(self.grammars):
|
| 150 |
+
if state[i] == -1:
|
| 151 |
+
nxt[i] = -1
|
| 152 |
+
continue
|
| 153 |
+
r = g.advance(state[i], s)
|
| 154 |
+
nxt[i] = r
|
| 155 |
+
if r != -1:
|
| 156 |
+
any_live = True
|
| 157 |
+
return nxt if any_live else -1
|
| 158 |
+
|
| 159 |
+
def accepts(self, state):
|
| 160 |
+
if state == -1:
|
| 161 |
+
return False
|
| 162 |
+
for i, g in enumerate(self.grammars):
|
| 163 |
+
if state[i] != -1 and g.accepts(state[i]):
|
| 164 |
+
return True
|
| 165 |
+
return False
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def union_grammars(grammars):
|
| 169 |
+
return UnionGrammar(grammars)
|
logits.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grammar-constrained LogitsProcessor (Python port of src/logits.js).
|
| 2 |
+
|
| 3 |
+
At each generation step:
|
| 4 |
+
1. Decode the generated suffix back to text.
|
| 5 |
+
2. Advance the grammar NFA by that text.
|
| 6 |
+
3. For every candidate token id, check whether appending its decoded text
|
| 7 |
+
keeps the NFA alive; mask losers to -inf (via the shared LegalCache).
|
| 8 |
+
4. EOS is allowed only once the NFA has reached an accept state.
|
| 9 |
+
|
| 10 |
+
Per-token decoding can disagree with BPE sequence-decoding in edge cases
|
| 11 |
+
(merged punctuation, etc.); for the acrostic patterns we care about this
|
| 12 |
+
approximation is fine.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
from transformers import LogitsProcessor
|
| 20 |
+
|
| 21 |
+
from masking import LegalCache
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_token_text_table(tokenizer, vocab_size):
|
| 25 |
+
"""One-shot build of tokenId -> text, using per-token decode. Special tokens
|
| 26 |
+
decode to '' under skip_special_tokens=True, which we treat as
|
| 27 |
+
"disallowed" (empty string)."""
|
| 28 |
+
texts = tokenizer.batch_decode(
|
| 29 |
+
[[i] for i in range(vocab_size)], skip_special_tokens=True
|
| 30 |
+
)
|
| 31 |
+
return [t if isinstance(t, str) else "" for t in texts]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class GrammarLogitsProcessor(LogitsProcessor):
|
| 35 |
+
def __init__(self, grammar, tokenizer, token_text, eos_token_ids=(), legal_cache=None):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.grammar = grammar
|
| 38 |
+
self.tokenizer = tokenizer
|
| 39 |
+
self.token_text = token_text
|
| 40 |
+
self.cache = legal_cache or LegalCache(grammar, token_text, eos_token_ids)
|
| 41 |
+
self.prompt_length = None
|
| 42 |
+
self.stats = _fresh_stats()
|
| 43 |
+
|
| 44 |
+
def reset(self):
|
| 45 |
+
self.prompt_length = None
|
| 46 |
+
self.stats = _fresh_stats()
|
| 47 |
+
|
| 48 |
+
def __call__(self, input_ids, scores):
|
| 49 |
+
t_entry = time.perf_counter()
|
| 50 |
+
ids = input_ids[0]
|
| 51 |
+
if self.prompt_length is None:
|
| 52 |
+
self.prompt_length = ids.shape[0]
|
| 53 |
+
|
| 54 |
+
generated = ids[self.prompt_length:].tolist()
|
| 55 |
+
text = (
|
| 56 |
+
self.tokenizer.decode(generated, skip_special_tokens=True)
|
| 57 |
+
if generated
|
| 58 |
+
else ""
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
state = self.grammar.advance(self.grammar.initial, text)
|
| 62 |
+
data = scores[0]
|
| 63 |
+
|
| 64 |
+
if state == -1:
|
| 65 |
+
# Already violated; nothing useful to do without rewinding. Let the
|
| 66 |
+
# original logits through so generation at least terminates.
|
| 67 |
+
self._record(time.perf_counter() - t_entry, -1)
|
| 68 |
+
return scores
|
| 69 |
+
|
| 70 |
+
illegal = self.cache.illegal_tensor(state)
|
| 71 |
+
data[illegal.to(data.device)] = float("-inf")
|
| 72 |
+
|
| 73 |
+
self._record(time.perf_counter() - t_entry, int((~illegal).sum().item()))
|
| 74 |
+
return scores
|
| 75 |
+
|
| 76 |
+
def _record(self, dt, survivors):
|
| 77 |
+
st = self.stats
|
| 78 |
+
st["calls"] += 1
|
| 79 |
+
st["total_ms"] += dt * 1000.0
|
| 80 |
+
st["per_step"].append({"ms": dt * 1000.0, "survivors": survivors})
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _fresh_stats():
|
| 84 |
+
return {"calls": 0, "total_ms": 0.0, "per_step": []}
|
masking.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared grammar-legality computation + a per-state cache.
|
| 2 |
+
|
| 3 |
+
The set of grammar-legal next tokens is a pure function of the grammar state, so
|
| 4 |
+
we cache the boolean legal mask by state. This is what makes the crossing search
|
| 5 |
+
affordable: its many short rollouts all start from the same handful of
|
| 6 |
+
line-start states and reuse one (expensive) full-vocab scan.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LegalCache:
|
| 16 |
+
def __init__(self, grammar, token_text, eos_token_ids=()):
|
| 17 |
+
self.grammar = grammar
|
| 18 |
+
self.token_text = token_text
|
| 19 |
+
self.eos_token_ids = [int(x) for x in eos_token_ids]
|
| 20 |
+
# Special tokens decode to '' and are always illegal — never probe them.
|
| 21 |
+
self._scan_ids = [i for i, t in enumerate(token_text) if t]
|
| 22 |
+
self._legal_cache = {} # state-key -> np.bool_ array
|
| 23 |
+
self._illegal_cache = {} # state-key -> torch.BoolTensor
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def _key(state):
|
| 27 |
+
return state if isinstance(state, int) else tuple(state)
|
| 28 |
+
|
| 29 |
+
def legal_np(self, state):
|
| 30 |
+
key = self._key(state)
|
| 31 |
+
cached = self._legal_cache.get(key)
|
| 32 |
+
if cached is not None:
|
| 33 |
+
return cached
|
| 34 |
+
advance = self.grammar.advance
|
| 35 |
+
token_text = self.token_text
|
| 36 |
+
at_accept = self.grammar.accepts(state)
|
| 37 |
+
legal = np.zeros(len(token_text), dtype=bool)
|
| 38 |
+
for i in self._scan_ids:
|
| 39 |
+
if advance(state, token_text[i]) != -1:
|
| 40 |
+
legal[i] = True
|
| 41 |
+
for eid in self.eos_token_ids:
|
| 42 |
+
legal[eid] = at_accept
|
| 43 |
+
self._legal_cache[key] = legal
|
| 44 |
+
return legal
|
| 45 |
+
|
| 46 |
+
def illegal_tensor(self, state):
|
| 47 |
+
key = self._key(state)
|
| 48 |
+
cached = self._illegal_cache.get(key)
|
| 49 |
+
if cached is not None:
|
| 50 |
+
return cached
|
| 51 |
+
t = torch.from_numpy(~self.legal_np(state))
|
| 52 |
+
self._illegal_cache[key] = t
|
| 53 |
+
return t
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers>=5.12
|
| 3 |
+
accelerate
|
| 4 |
+
numpy
|
| 5 |
+
gradio>=6
|
sweep_minicpm.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Sweep candidate list/prose classifier variants on the current model, pick a
|
| 2 |
+
winner. CPU-conscious: screen all candidates on a 20-prompt subset, then run the
|
| 3 |
+
top few on the full 80-prompt dev set + 20-prompt validation set.
|
| 4 |
+
|
| 5 |
+
SIDECHAT_MODEL=openbmb/MiniCPM5-1B .venv/bin/python sweep_minicpm.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import app # loads the model + CTX
|
| 13 |
+
from classifier import Variant
|
| 14 |
+
from eval_classifier import (
|
| 15 |
+
DEV_LIST, DEV_PROSE, VALIDATION_LIST, VALIDATION_PROSE, make_labelled, sweep,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
CTX = app.CTX
|
| 19 |
+
parse_list = lambda s: s.startswith("list")
|
| 20 |
+
parse_items = lambda s: s.startswith("items")
|
| 21 |
+
|
| 22 |
+
# Candidate variants spanning the axes that mattered in the JS sweep: intent vs.
|
| 23 |
+
# request framing, default-to-list vs. default-to-story polarity, list/story
|
| 24 |
+
# branch vocab, and the single-vs-plural rule. (true/false and CAPS branches are
|
| 25 |
+
# known-bad and omitted.)
|
| 26 |
+
C1_BASE = (
|
| 27 |
+
"Classify the user's request. Use \"list\" when the user wants enumerated "
|
| 28 |
+
"items. Use \"story\" for everything else."
|
| 29 |
+
)
|
| 30 |
+
SINGLE_PLURAL = (
|
| 31 |
+
" \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural "
|
| 32 |
+
"enumeration) is a list; \"what are the steps/differences/causes/symptoms\" "
|
| 33 |
+
"is a list."
|
| 34 |
+
)
|
| 35 |
+
WRITE_FORMS = (
|
| 36 |
+
" Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, "
|
| 37 |
+
"cover letter, email, joke, story, essay, or limerick, the answer is a story."
|
| 38 |
+
)
|
| 39 |
+
EXTENDED_TRIGGERS = (
|
| 40 |
+
"Classify the user's request. Default to \"list\". Use \"story\" only when the "
|
| 41 |
+
"user asks for narrative/prose: \"tell me a story\", \"write a poem/haiku/"
|
| 42 |
+
"limerick/email/essay/letter\", \"describe\", \"explain\", \"translate\", "
|
| 43 |
+
"\"summarize\", \"what does X mean\", \"who was/is\", \"what is X\", \"when "
|
| 44 |
+
"did\", \"why does\", \"how does (concept)\", \"compose\"."
|
| 45 |
+
)
|
| 46 |
+
STORY_DEFAULT = (
|
| 47 |
+
"Classify the user's request. Default to \"story\". Use \"list\" only when the "
|
| 48 |
+
"user clearly asks for multiple discrete items: \"list N\", \"name N\", "
|
| 49 |
+
"\"give N\", \"top N\", \"suggest some\", \"ways to\", \"tips\", \"steps\", "
|
| 50 |
+
"\"reasons\", \"examples of\"."
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
INTENT_BASE = (
|
| 54 |
+
"Classify the user's intent. Use \"list\" when the answer is a set of "
|
| 55 |
+
"separate items the user can scan. Use \"story\" when the answer flows as "
|
| 56 |
+
"one narrative, single fact, or short paragraph."
|
| 57 |
+
)
|
| 58 |
+
DEFAULT_LIST_BASE = (
|
| 59 |
+
"Classify the user's request. Default to \"list\". Use \"story\" only when "
|
| 60 |
+
"the user clearly asks for narrative: \"tell me a story\", \"write a "
|
| 61 |
+
"poem/haiku/email\", \"describe X\", \"explain X\", \"translate X\", "
|
| 62 |
+
"\"what does X mean\", \"who was/what is/when did\"."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Baseline reveals MiniCPM's split failure: intent/list-default bases nail list
|
| 66 |
+
# recall but miss prose "write a haiku" (write-forms) and "capital of Australia"
|
| 67 |
+
# (single fact); the c1 base is the opposite. So pair list-favoring bases with
|
| 68 |
+
# the WRITE_FORMS + SINGLE_PLURAL rules that target exactly those prose misses.
|
| 69 |
+
CANDIDATES = [
|
| 70 |
+
# Baselines (carried from the LFM2 sweep).
|
| 71 |
+
Variant("c1_single_plural", C1_BASE + SINGLE_PLURAL, "The user is asking for a ", ["list.", "story."], parse_list),
|
| 72 |
+
Variant("intent_two_rules", INTENT_BASE, "The intent is to get a ", ["list.", "story."], parse_list),
|
| 73 |
+
# Intent base + targeted prose rules.
|
| 74 |
+
Variant("intent_write", INTENT_BASE + WRITE_FORMS, "The intent is to get a ", ["list.", "story."], parse_list),
|
| 75 |
+
Variant("intent_sp", INTENT_BASE + SINGLE_PLURAL, "The intent is to get a ", ["list.", "story."], parse_list),
|
| 76 |
+
Variant("intent_write_sp", INTENT_BASE + WRITE_FORMS + SINGLE_PLURAL, "The intent is to get a ", ["list.", "story."], parse_list),
|
| 77 |
+
# Default-to-list base + targeted prose rules.
|
| 78 |
+
Variant("default_list", DEFAULT_LIST_BASE, "The user wants the answer as a ", ["list.", "story."], parse_list),
|
| 79 |
+
Variant("default_list_write_sp", DEFAULT_LIST_BASE + WRITE_FORMS + SINGLE_PLURAL, "The user wants the answer as a ", ["list.", "story."], parse_list),
|
| 80 |
+
# c1 base + write-forms (complementary to single_plural).
|
| 81 |
+
Variant("c1_write_sp", C1_BASE + WRITE_FORMS + SINGLE_PLURAL, "The user is asking for a ", ["list.", "story."], parse_list),
|
| 82 |
+
# Long built-in trigger list (no separate rules).
|
| 83 |
+
Variant("extended_triggers", EXTENDED_TRIGGERS, "The user wants the answer as a ", ["list.", "story."], parse_list),
|
| 84 |
+
# Alternate branch vocab.
|
| 85 |
+
Variant(
|
| 86 |
+
"items_text",
|
| 87 |
+
"Classify the user's intent. Use \"items\" when the user wants enumerated "
|
| 88 |
+
"items. Use \"text\" for everything else (narrative, single answer, "
|
| 89 |
+
"explanation, translation, story, poem).",
|
| 90 |
+
"The intent is to get ", ["items.", "text."], parse_items,
|
| 91 |
+
),
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def main():
|
| 96 |
+
print(f"model: {app.MODEL_ID} · {len(CTX.token_text)} tokens", flush=True)
|
| 97 |
+
# Fast screen on a 20-prompt subset (first 10 of each dev class).
|
| 98 |
+
screen = make_labelled(DEV_LIST[:10], DEV_PROSE[:10])
|
| 99 |
+
print(f"\n=== SCREEN ({len(screen)} prompts) ===", flush=True)
|
| 100 |
+
t0 = time.time()
|
| 101 |
+
screen_res = sweep(CTX, CANDIDATES, screen, label="screen")
|
| 102 |
+
screen_res.sort(key=lambda r: r["accuracy"], reverse=True)
|
| 103 |
+
print(f"screen done in {(time.time()-t0)/60:.1f} min", flush=True)
|
| 104 |
+
|
| 105 |
+
top = [next(c for c in CANDIDATES if c.name == r["variant"]) for r in screen_res[:3]]
|
| 106 |
+
print(f"\ntop 3 on screen: {[c.name for c in top]}", flush=True)
|
| 107 |
+
|
| 108 |
+
full = make_labelled(DEV_LIST, DEV_PROSE) + make_labelled(VALIDATION_LIST, VALIDATION_PROSE)
|
| 109 |
+
print(f"\n=== FULL ({len(full)} prompts: 50 list + 50 prose) ===", flush=True)
|
| 110 |
+
full_res = sweep(CTX, top, full, label="full")
|
| 111 |
+
full_res.sort(key=lambda r: r["accuracy"], reverse=True)
|
| 112 |
+
|
| 113 |
+
print("\n=== RANKING (full) ===", flush=True)
|
| 114 |
+
for r in full_res:
|
| 115 |
+
lh, lt = r["list_recall"]; ph, pt = r["prose_recall"]
|
| 116 |
+
print(f" {r['variant']:30} {r['accuracy']*100:5.1f}% list {lh}/{lt} prose {ph}/{pt}", flush=True)
|
| 117 |
+
print(f"\nWINNER: {full_res[0]['variant']} @ {full_res[0]['accuracy']*100:.1f}%", flush=True)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
main()
|
tokinfo.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Precomputed per-token boolean arrays used by the crossing search's stealth
|
| 2 |
+
casing and minimum-line-length masking. Built once at startup."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _is_ascii_alpha(ch):
|
| 12 |
+
return ("a" <= ch <= "z") or ("A" <= ch <= "Z")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class TokInfo:
|
| 17 |
+
has_newline: np.ndarray # token text contains '\n'
|
| 18 |
+
alpha_lower: np.ndarray # first ASCII letter is lowercase
|
| 19 |
+
alpha_upper: np.ndarray # first ASCII letter is uppercase
|
| 20 |
+
nonempty: np.ndarray # token decodes to a non-empty string
|
| 21 |
+
eos_mask: np.ndarray # token is an EOS id
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_tok_info(token_text, eos_token_ids):
|
| 25 |
+
n = len(token_text)
|
| 26 |
+
has_newline = np.zeros(n, dtype=bool)
|
| 27 |
+
alpha_lower = np.zeros(n, dtype=bool)
|
| 28 |
+
alpha_upper = np.zeros(n, dtype=bool)
|
| 29 |
+
nonempty = np.zeros(n, dtype=bool)
|
| 30 |
+
for i, t in enumerate(token_text):
|
| 31 |
+
if not t:
|
| 32 |
+
continue
|
| 33 |
+
nonempty[i] = True
|
| 34 |
+
if "\n" in t:
|
| 35 |
+
has_newline[i] = True
|
| 36 |
+
for ch in t:
|
| 37 |
+
if _is_ascii_alpha(ch):
|
| 38 |
+
if ch.islower():
|
| 39 |
+
alpha_lower[i] = True
|
| 40 |
+
else:
|
| 41 |
+
alpha_upper[i] = True
|
| 42 |
+
break
|
| 43 |
+
eos_mask = np.zeros(n, dtype=bool)
|
| 44 |
+
for e in eos_token_ids:
|
| 45 |
+
eos_mask[int(e)] = True
|
| 46 |
+
return TokInfo(has_newline, alpha_lower, alpha_upper, nonempty, eos_mask)
|