"""List-vs-prose classifier (Python port of the shipped part of src/eval.js). The classifier reads the user's prompt and decides whether the answer is best rendered as a bulleted list or as narrative prose. It is itself an LLM call, grammar-constrained to exactly one of two literal completions: apply the chat template, append a partial assistant response (the `prefill`), constrain generation to one of `branches`, parse the result. Failure modes are model-specific, so the prompt is tuned per model. The default here is the MiniCPM5-1B winner (`minicpm_intent_write_sp`, 96% on the 100-prompt suite) found by re-running the sweep (eval_classifier.py / sweep_minicpm.py) on that model. The LFM2.5-350M winner (`r6_c1_v2_single_plural`, 97.5% dev / 85% val) is kept as an alternate — it is *prose-biased* on MiniCPM (~75%), so don't reuse it there. See CLASSIFIER_PROMPT_OPTIMIZATION.md for the original JS sweep. """ from __future__ import annotations from dataclasses import dataclass from typing import Callable, List import torch from transformers import LogitsProcessorList from grammar import compile_literal, union_grammars from logits import GrammarLogitsProcessor @dataclass class Variant: name: str system: str prefill: str branches: List[str] parse: Callable[[str], bool] # raw generated text -> True (list) / False (prose) # Shared trigger-rule strings. _INTENT_BASE = ( "Classify the user's intent. Use \"list\" when the answer is a set of " "separate items the user can scan. Use \"story\" when the answer flows as " "one narrative, single fact, or short paragraph." ) _WRITE_FORMS = ( " Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, " "cover letter, email, joke, story, essay, or limerick, the answer is a story." ) _SINGLE_PLURAL = ( " \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural " "enumeration) is a list; \"what are the steps/differences/causes/symptoms\" " "is a list." ) # --- The shipped MiniCPM5-1B winner ----------------------------------------- # On MiniCPM, every "Default to list" framing collapses to all-story (list 0/50) # and the LFM2 winner is prose-biased. A neutral *intent* framing nails list # recall; adding the write-forms rule (catches "write a haiku/email") and the # single-vs-plural rule (catches "what is X" single facts) fixes the residual # prose misses. 96% on the 100-prompt suite (list 49/50, prose 47/50). DEFAULT_VARIANT = Variant( name="minicpm_intent_write_sp", system=_INTENT_BASE + _WRITE_FORMS + _SINGLE_PLURAL, prefill="The intent is to get a ", branches=["list.", "story."], parse=lambda s: s.startswith("list"), ) # --- Reference alternates (other strong variants; useful when re-tuning) ----- ALTERNATES = [ # The LFM2.5-350M winner (97.5% dev / 85% val on LFM2; ~75% on MiniCPM). Variant( name="r6_c1_v2_single_plural", system=( "Classify the user's request. Use \"list\" when the user wants " "enumerated items. Use \"story\" for everything else. \"What is X\" " "(a single fact) is a story; \"What are the/some Xs\" (plural " "enumeration) is a list; \"what are the steps/differences/causes/" "symptoms\" is a list." ), prefill="The user is asking for a ", branches=["list.", "story."], parse=lambda s: s.startswith("list"), ), # Intent base + single-plural only (100% screen, 93% full on MiniCPM; # perfect list recall but misses some "write a X" prose prompts). Variant( name="minicpm_intent_sp", system=_INTENT_BASE + _SINGLE_PLURAL, prefill="The intent is to get a ", branches=["list.", "story."], parse=lambda s: s.startswith("list"), ), ] VARIANTS = [DEFAULT_VARIANT, *ALTERNATES] def classify(ctx, prompt, variant=DEFAULT_VARIANT): """Run one classifier call. ctx is a Context (see app.py): .model, .tokenizer, .token_text, .eos_token_ids. Returns (prediction, raw).""" tok = ctx.tokenizer messages = [ {"role": "system", "content": variant.system}, {"role": "user", "content": prompt}, ] templated = tok.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) full_text = templated + variant.prefill grammar = union_grammars([compile_literal(b) for b in variant.branches]) proc = GrammarLogitsProcessor(grammar, tok, ctx.token_text, ctx.eos_token_ids) enc = tok(full_text, return_tensors="pt", add_special_tokens=False).to(ctx.model.device) with torch.no_grad(): out = ctx.model.generate( **enc, max_new_tokens=16, do_sample=False, logits_processor=LogitsProcessorList([proc]), pad_token_id=ctx.pad_token_id, ) raw = tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True) return variant.parse(raw), raw