# sidechat / classifier.py
# Last commit by lsb: 08b2dd2 (verified)
# "Port steganacrostics to a Gradio app; retarget to MiniCPM5-1B"
"""List-vs-prose classifier (Python port of the shipped part of src/eval.js).
The classifier reads the user's prompt and decides whether the answer is best
rendered as a bulleted list or as narrative prose. It is itself an LLM call,
grammar-constrained to exactly one of two literal completions: apply the chat
template, append a partial assistant response (the `prefill`), constrain
generation to one of `branches`, parse the result.
Failure modes are model-specific, so the prompt is tuned per model. The default
here is the MiniCPM5-1B winner (`minicpm_intent_write_sp`, 96% on the 100-prompt
suite) found by re-running the sweep (eval_classifier.py / sweep_minicpm.py) on
that model. The LFM2.5-350M winner (`r6_c1_v2_single_plural`, 97.5% dev / 85%
val) is kept as an alternate — it is *prose-biased* on MiniCPM (~75%), so don't
reuse it there. See CLASSIFIER_PROMPT_OPTIMIZATION.md for the original JS sweep.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, List
import torch
from transformers import LogitsProcessorList
from grammar import compile_literal, union_grammars
from logits import GrammarLogitsProcessor
@dataclass
class Variant:
    """A single classifier prompt configuration.

    A variant bundles everything one classifier call needs: the system
    message, the assistant-side prefill, the literal completions generation
    is constrained to, and the function that turns the generated text into a
    list/prose decision.
    """

    # Label for the variant (e.g. the sweep name it won under).
    name: str
    # System-message text placed ahead of the user's prompt.
    system: str
    # Partial assistant response appended after the rendered chat template.
    prefill: str
    # Literal completions the generation is restricted to.
    branches: List[str]
    # Maps the raw generated text to True (list) or False (prose).
    parse: Callable[[str], bool]
# Shared trigger-rule strings, concatenated into the `system` prompt of the
# variants below. This text was tuned by prompt sweeps (see module
# docstring), so wording, spacing, or quoting edits may shift classifier
# accuracy — re-evaluate after changing any of it.
# Base rule: neutral intent framing that defines the two labels ("list" /
# "story") which the variants' `branches` completions use.
_INTENT_BASE = (
    "Classify the user's intent. Use \"list\" when the answer is a set of "
    "separate items the user can scan. Use \"story\" when the answer flows as "
    "one narrative, single fact, or short paragraph."
)
# The add-on rules below start with a leading space so they can be
# concatenated directly after _INTENT_BASE.
# Write-forms rule: "write"/"compose" requests for prose forms -> story.
_WRITE_FORMS = (
    " Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, "
    "cover letter, email, joke, story, essay, or limerick, the answer is a story."
)
# Single-vs-plural rule: a single-fact "What is X" -> story; plural
# enumerations and steps/differences/causes/symptoms questions -> list.
_SINGLE_PLURAL = (
    " \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural "
    "enumeration) is a list; \"what are the steps/differences/causes/symptoms\" "
    "is a list."
)
# --- The shipped MiniCPM5-1B winner -----------------------------------------
# Why this prompt: on MiniCPM, every "Default to list" framing collapses to
# all-story (list 0/50), and the LFM2 winner is prose-biased. A neutral
# *intent* framing nails list recall. Two add-on rules then fix the residual
# prose misses:
#   - write-forms, which catches "write a haiku/email";
#   - single-vs-plural, which catches "what is X" single facts.
# Score: 96% on the 100-prompt suite (list 49/50, prose 47/50).
DEFAULT_VARIANT = Variant(
    name="minicpm_intent_write_sp",
    prefill="The intent is to get a ",
    system="".join((_INTENT_BASE, _WRITE_FORMS, _SINGLE_PLURAL)),
    branches=["list.", "story."],
    parse=lambda text: text.startswith("list"),
)
# --- Reference alternates (other strong variants; useful when re-tuning) -----
ALTERNATES = [
    # LFM2.5-350M winner: 97.5% dev / 85% val on LFM2, but only ~75% on
    # MiniCPM, where it is prose-biased — don't ship it there.
    Variant(
        name="r6_c1_v2_single_plural",
        prefill="The user is asking for a ",
        system=(
            "Classify the user's request. Use \"list\" when the user wants "
            "enumerated items. Use \"story\" for everything else. \"What is X\" "
            "(a single fact) is a story; \"What are the/some Xs\" (plural "
            "enumeration) is a list; \"what are the steps/differences/causes/"
            "symptoms\" is a list."
        ),
        branches=["list.", "story."],
        parse=lambda text: text.startswith("list"),
    ),
    # Intent base plus the single-plural rule only. On MiniCPM it scored 100%
    # on the screen and 93% on the full suite. It has perfect list recall but
    # misses some "write a X" prose prompts.
    Variant(
        name="minicpm_intent_sp",
        prefill="The intent is to get a ",
        system="".join((_INTENT_BASE, _SINGLE_PLURAL)),
        branches=["list.", "story."],
        parse=lambda text: text.startswith("list"),
    ),
]
# Every known variant, shipped default first.
VARIANTS = [DEFAULT_VARIANT] + ALTERNATES
def classify(ctx, prompt, variant=DEFAULT_VARIANT):
    """Run one grammar-constrained classifier call.

    The call proceeds in four steps:

    1. Build a two-message chat: ``variant.system`` as the system message and
       ``prompt`` as the user message.
    2. Render the chat with the tokenizer's chat template, including the
       generation prompt.
    3. Append ``variant.prefill`` as the start of the assistant turn.
    4. Generate with sampling disabled, under a logits processor that
       constrains the output to one of ``variant.branches``.

    Args:
        ctx: A Context (see app.py) providing ``.model``, ``.tokenizer``,
            ``.token_text``, ``.eos_token_ids`` and ``.pad_token_id``.
        prompt: The user's prompt text to classify.
        variant: The prompt configuration to use. Defaults to the shipped
            MiniCPM5-1B winner.

    Returns:
        ``(prediction, raw)``. ``raw`` is the decoded generated text, with
        special tokens skipped. ``prediction`` is ``variant.parse(raw)``:
        True for list, False for prose.
    """
    tok = ctx.tokenizer
    messages = [
        {"role": "system", "content": variant.system},
        {"role": "user", "content": prompt},
    ]
    templated = tok.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The prefill is plain text after the generation prompt, so the model
    # continues the assistant turn mid-sentence.
    full_text = templated + variant.prefill

    grammar = union_grammars([compile_literal(b) for b in variant.branches])
    proc = GrammarLogitsProcessor(grammar, tok, ctx.token_text, ctx.eos_token_ids)

    # The chat template already emitted any BOS/special tokens, so don't let
    # the tokenizer add them a second time.
    enc = tok(full_text, return_tensors="pt", add_special_tokens=False)
    # Pass only input_ids / attention_mask: some tokenizers also return
    # token_type_ids, which generate() rejects as an unused model kwarg.
    device = ctx.model.device
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        out = ctx.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # The shipped branches ("list." / "story.") are short; 16 new
            # tokens leaves headroom for the grammar to finish.
            max_new_tokens=16,
            do_sample=False,
            logits_processor=LogitsProcessorList([proc]),
            pad_token_id=ctx.pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = input_ids.shape[1]
    raw = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    return variant.parse(raw), raw