File size: 4,995 Bytes
08b2dd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""List-vs-prose classifier (Python port of the shipped part of src/eval.js).

The classifier reads the user's prompt and decides whether the answer is best
rendered as a bulleted list or as narrative prose. It is itself an LLM call,
grammar-constrained to exactly one of two literal completions: apply the chat
template, append a partial assistant response (the `prefill`), constrain
generation to one of `branches`, parse the result.

Failure modes are model-specific, so the prompt is tuned per model. The default
here is the MiniCPM5-1B winner (`minicpm_intent_write_sp`, 96% on the 100-prompt
suite) found by re-running the sweep (eval_classifier.py / sweep_minicpm.py) on
that model. The LFM2.5-350M winner (`r6_c1_v2_single_plural`, 97.5% dev / 85%
val) is kept as an alternate — it is *prose-biased* on MiniCPM (~75%), so don't
reuse it there. See CLASSIFIER_PROMPT_OPTIMIZATION.md for the original JS sweep.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List

import torch
from transformers import LogitsProcessorList

from grammar import compile_literal, union_grammars
from logits import GrammarLogitsProcessor


@dataclass
class Variant:
    """One classifier prompt configuration.

    Bundles the system prompt, the partial assistant reply that generation
    continues from, the literal completions generation is constrained to, and
    the function that turns the generated text into a list/prose decision.
    """

    # Identifier for this prompt configuration (used to tell variants apart).
    name: str
    # System-message content sent ahead of the user's prompt.
    system: str
    # Partial assistant response appended after the chat template.
    prefill: str
    # Literal completions; generation is constrained to exactly one of them.
    branches: List[str]
    parse: Callable[[str], bool]  # raw generated text -> True (list) / False (prose)


# Prompt fragments shared by several variants. Each optional rule starts with
# a space so it can be appended directly after the base instruction.
_INTENT_BASE = (
    "Classify the user's intent. "
    'Use "list" when the answer is a set of separate items the user can scan. '
    'Use "story" when the answer flows as one narrative, single fact, or '
    "short paragraph."
)
# Rule: requests to write/compose a piece of text are prose ("story").
_WRITE_FORMS = (
    ' Whenever the user asks to "write" or "compose" a haiku, poem, letter, '
    "cover letter, email, joke, story, essay, or limerick, "
    "the answer is a story."
)
# Rule: a single-fact question is prose; a plural enumeration is a list.
_SINGLE_PLURAL = (
    ' "What is X" (a single fact) is a story;'
    ' "What are the/some Xs" (plural enumeration) is a list;'
    ' "what are the steps/differences/causes/symptoms" is a list.'
)

# --- The shipped MiniCPM5-1B winner -----------------------------------------
# Why this prompt: on MiniCPM, any "Default to list" framing collapses to
# all-story (list 0/50), and the LFM2 winner is prose-biased. A neutral
# *intent* framing gets list recall right; the residual prose misses are fixed
# by two extra rules:
#   - write-forms    (catches "write a haiku/email"),
#   - single-vs-plural (catches "what is X" single facts).
# Result: 96% on the 100-prompt suite (list 49/50, prose 47/50).
DEFAULT_VARIANT = Variant(
    name="minicpm_intent_write_sp",
    system="".join((_INTENT_BASE, _WRITE_FORMS, _SINGLE_PLURAL)),
    prefill="The intent is to get a ",
    branches=["list.", "story."],
    # The constrained output is one of the two branches, so a prefix check
    # is enough to tell them apart.
    parse=lambda raw: raw.startswith("list"),
)

# --- Reference alternates (other strong variants; useful when re-tuning) -----
ALTERNATES = [
    # LFM2.5-350M winner: 97.5% dev / 85% val on LFM2, but only ~75% on
    # MiniCPM. Its prompt is spelled out in full rather than assembled from
    # the shared rule strings.
    Variant(
        name="r6_c1_v2_single_plural",
        system=(
            "Classify the user's request. "
            'Use "list" when the user wants enumerated items. '
            'Use "story" for everything else. '
            '"What is X" (a single fact) is a story; '
            '"What are the/some Xs" (plural enumeration) is a list; '
            '"what are the steps/differences/causes/symptoms" is a list.'
        ),
        prefill="The user is asking for a ",
        branches=["list.", "story."],
        parse=lambda raw: raw.startswith("list"),
    ),
    # Intent base plus the single-vs-plural rule only: 100% on the screen and
    # 93% on the full suite for MiniCPM. List recall is perfect, but some
    # "write a X" prose prompts are missed.
    Variant(
        name="minicpm_intent_sp",
        system="".join((_INTENT_BASE, _SINGLE_PLURAL)),
        prefill="The intent is to get a ",
        branches=["list.", "story."],
        parse=lambda raw: raw.startswith("list"),
    ),
]

# Every known variant, shipped default first.
VARIANTS = [DEFAULT_VARIANT] + ALTERNATES


def classify(ctx, prompt, variant=DEFAULT_VARIANT):
    """Classify one prompt with a single grammar-constrained generation.

    Args:
        ctx: A Context (see app.py). This function reads ``.model``,
            ``.tokenizer``, ``.token_text``, ``.eos_token_ids`` and
            ``.pad_token_id`` from it.
        prompt: The user's prompt text, sent as the user message.
        variant: The prompt configuration to use; defaults to the shipped one.

    Returns:
        A ``(prediction, raw)`` tuple: ``raw`` is the decoded text generated
        after the prefill, and ``prediction`` is ``variant.parse(raw)``
        (True for list, False for prose).
    """
    tokenizer = ctx.tokenizer

    # Render the system + user turns through the chat template, then append
    # the partial assistant reply the model is asked to continue.
    chat_text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": variant.system},
            {"role": "user", "content": prompt},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    model_input = chat_text + variant.prefill

    # Restrict generation to exactly one of the variant's literal branches.
    branch_grammars = [compile_literal(branch) for branch in variant.branches]
    constraint = GrammarLogitsProcessor(
        union_grammars(branch_grammars),
        tokenizer,
        ctx.token_text,
        ctx.eos_token_ids,
    )

    # The templated text is already complete, so add_special_tokens=False
    # keeps the tokenizer from adding any more special tokens.
    encoded = tokenizer(model_input, return_tensors="pt", add_special_tokens=False)
    encoded = encoded.to(ctx.model.device)

    with torch.no_grad():
        sequences = ctx.model.generate(
            **encoded,
            max_new_tokens=16,
            do_sample=False,  # greedy decoding
            logits_processor=LogitsProcessorList([constraint]),
            pad_token_id=ctx.pad_token_id,
        )

    # Skip the first input-length positions and decode only what follows.
    input_len = encoded["input_ids"].shape[1]
    completion_ids = sequences[0][input_len:]
    raw = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return variant.parse(raw), raw