Spaces:
Sleeping
Sleeping
| """List-vs-prose classifier (Python port of the shipped part of src/eval.js). | |
| The classifier reads the user's prompt and decides whether the answer is best | |
| rendered as a bulleted list or as narrative prose. It is itself an LLM call, | |
| grammar-constrained to exactly one of two literal completions: apply the chat | |
| template, append a partial assistant response (the `prefill`), constrain | |
| generation to one of `branches`, parse the result. | |
| Failure modes are model-specific, so the prompt is tuned per model. The default | |
| here is the MiniCPM5-1B winner (`minicpm_intent_write_sp`, 96% on the 100-prompt | |
| suite) found by re-running the sweep (eval_classifier.py / sweep_minicpm.py) on | |
| that model. The LFM2.5-350M winner (`r6_c1_v2_single_plural`, 97.5% dev / 85% | |
| val) is kept as an alternate — it is *prose-biased* on MiniCPM (~75%), so don't | |
| reuse it there. See CLASSIFIER_PROMPT_OPTIMIZATION.md for the original JS sweep. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Callable, List | |
| import torch | |
| from transformers import LogitsProcessorList | |
| from grammar import compile_literal, union_grammars | |
| from logits import GrammarLogitsProcessor | |
# @dataclass generates the keyword/positional __init__ (plus __repr__/__eq__)
# that the Variant(...) definitions in this module rely on; without it the
# class takes no constructor arguments.
@dataclass
class Variant:
    """One classifier prompt configuration.

    ``classify`` sends ``system`` as the system message, appends ``prefill``
    after the chat template's generation prompt, constrains decoding to exactly
    one of ``branches``, and maps the decoded completion through ``parse``.
    """

    name: str  # identifier for the variant (e.g. the sweep name of the prompt)
    system: str  # system-prompt text carrying the classification rules
    prefill: str  # partial assistant response that the branches continue
    branches: List[str]  # literal completions the grammar permits
    parse: Callable[[str], bool]  # raw generated text -> True (list) / False (prose)
| # Shared trigger-rule strings. | |
| _INTENT_BASE = ( | |
| "Classify the user's intent. Use \"list\" when the answer is a set of " | |
| "separate items the user can scan. Use \"story\" when the answer flows as " | |
| "one narrative, single fact, or short paragraph." | |
| ) | |
| _WRITE_FORMS = ( | |
| " Whenever the user asks to \"write\" or \"compose\" a haiku, poem, letter, " | |
| "cover letter, email, joke, story, essay, or limerick, the answer is a story." | |
| ) | |
| _SINGLE_PLURAL = ( | |
| " \"What is X\" (a single fact) is a story; \"What are the/some Xs\" (plural " | |
| "enumeration) is a list; \"what are the steps/differences/causes/symptoms\" " | |
| "is a list." | |
| ) | |
# --- The shipped MiniCPM5-1B winner -----------------------------------------
# Every "Default to list" framing tried on MiniCPM collapsed to all-story
# (list 0/50), and the LFM2 winner is prose-biased there. A neutral *intent*
# framing gets list recall right. The write-forms rule then catches
# "write a haiku/email" requests, and the single-vs-plural rule catches
# "what is X" single facts, which fixes the remaining prose misses.
# Score: 96% on the 100-prompt suite (list 49/50, prose 47/50).
DEFAULT_VARIANT = Variant(
    name="minicpm_intent_write_sp",
    prefill="The intent is to get a ",
    branches=["list.", "story."],
    system="".join((_INTENT_BASE, _WRITE_FORMS, _SINGLE_PLURAL)),
    parse=lambda completion: completion.startswith("list"),
)
# --- Reference alternates (other strong variants; useful when re-tuning) -----
ALTERNATES = [
    # LFM2.5-350M winner: 97.5% dev / 85% val on LFM2, but only ~75% on
    # MiniCPM. Its rules are written out inline rather than assembled from
    # the shared fragments.
    Variant(
        name="r6_c1_v2_single_plural",
        prefill="The user is asking for a ",
        branches=["list.", "story."],
        system=(
            "Classify the user's request. "
            'Use "list" when the user wants enumerated items. '
            'Use "story" for everything else. '
            '"What is X" (a single fact) is a story; '
            '"What are the/some Xs" (plural enumeration) is a list; '
            '"what are the steps/differences/causes/symptoms" is a list.'
        ),
        parse=lambda completion: completion.startswith("list"),
    ),
    # Intent base plus the single-plural rule only (100% screen / 93% full on
    # MiniCPM). List recall is perfect, but it misses some "write a X" prose
    # prompts.
    Variant(
        name="minicpm_intent_sp",
        prefill="The intent is to get a ",
        branches=["list.", "story."],
        system="".join((_INTENT_BASE, _SINGLE_PLURAL)),
        parse=lambda completion: completion.startswith("list"),
    ),
]
# Every known variant, with the shipped default first.
VARIANTS = [DEFAULT_VARIANT] + ALTERNATES
def classify(ctx, prompt, variant=DEFAULT_VARIANT, *, max_new_tokens=16):
    """Run one grammar-constrained classifier call.

    Args:
        ctx: a Context (see app.py) exposing ``.model``, ``.tokenizer``,
            ``.token_text``, ``.eos_token_ids`` and ``.pad_token_id``.
        prompt: the user's prompt to classify.
        variant: the prompt configuration to use. Defaults to
            ``DEFAULT_VARIANT``.
        max_new_tokens: generation budget (keyword-only, default 16, the
            previously hard-coded value). It must be large enough for the
            longest branch to be emitted in full; otherwise ``raw`` is only
            a prefix of a branch.

    Returns:
        ``(prediction, raw)``: ``prediction`` is ``variant.parse(raw)``
        (True = list, False = prose), and ``raw`` is the decoded text
        generated after the prompt.
    """
    tokenizer = ctx.tokenizer
    messages = [
        {"role": "system", "content": variant.system},
        {"role": "user", "content": prompt},
    ]
    # Render the template as text so the prefill can be appended as the
    # beginning of the assistant's turn.
    templated = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    full_text = templated + variant.prefill

    # Constrain decoding to exactly one of the literal branches.
    grammar = union_grammars([compile_literal(b) for b in variant.branches])
    proc = GrammarLogitsProcessor(grammar, tokenizer, ctx.token_text, ctx.eos_token_ids)

    # add_special_tokens=False: presumably the rendered chat template already
    # contains the model's special tokens, so they must not be added twice.
    enc = tokenizer(full_text, return_tensors="pt", add_special_tokens=False).to(ctx.model.device)
    prompt_len = enc["input_ids"].shape[1]
    with torch.no_grad():
        out = ctx.model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # no sampling: deterministic choice among allowed tokens
            logits_processor=LogitsProcessorList([proc]),
            pad_token_id=ctx.pad_token_id,
        )
    # Decode only the newly generated tokens, dropping the echoed prompt.
    raw = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return variant.parse(raw), raw