File size: 4,173 Bytes
a330cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""

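Post-processing for Smartwatch LM chat replies.

Repairs ByteLevel BPE artifacts (Ġ/Ċ) and mojibake in decoded model output,
extracts the first bot reply and its <INTENT:...> tag, and fills <SLOT>
placeholders. Apply after tokenizer.decode(); the same logic is embedded in
colab_all_in_one.py.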

"""

from __future__ import annotations

import re
from dataclasses import dataclass

_BPE_SPACE = "\u0120"  # Ġ — ByteLevel BPE stand-in for byte 0x20 (space)
_BPE_NEWLINE = "\u010a"  # Ċ — ByteLevel BPE stand-in for byte 0x0A (newline)

# (bad, good) pairs applied in order by clean_reply.
# Each escaped key is the ByteLevel-BPE rendering of a punctuation character's
# UTF-8 bytes. Example: em dash = E2 80 94 -> "\u00e2" "\u0122" "\u0136".
# The plain-character keys normalize typographic punctuation to the same forms
# the escaped keys produce.
_MOJIBAKE_REPLACEMENTS: tuple[tuple[str, str], ...] = (
    ("\u00e2\u0122\u0136", "—"),  # "âĢĶ" em dash (E2 80 94)
    ("\u00e2\u0122\u0135", "—"),  # "âĢĵ" en dash (E2 80 93) -> em dash, like "–" below
    ("\u00e2\u0122\u013b", "'"),  # "âĢĻ" right single quote (E2 80 99)
    ("\u00e2\u0122\u013a", "'"),  # "âĢĺ" left single quote (E2 80 98)
    ("\u00e2\u0122\u013e", '"'),  # "âĢľ" left double quote (E2 80 9C)
    ("\u00e2\u0122\u013f", '"'),  # "âĢĿ" right double quote (E2 80 9D)
    ("\u00e2\u0122\u00a6", "…"),  # "âĢ¦" ellipsis (E2 80 A6)
    ("’", "'"),
    ("‘", "'"),
    ("“", '"'),
    ("”", '"'),
    ("–", "—"),
)


def build_prompt(history: list[tuple[str, str]], user_message: str) -> str:
    """Render the conversation as the plain-text prompt fed to the model.

    Every past ``(user_text, bot_text)`` exchange becomes a ``user: ...`` line
    followed by a ``bot: ...`` line. The new message is appended as a final
    ``user:`` line. The prompt ends with a bare ``bot:`` so the model continues
    with the bot's turn. Lines are joined with ``"\\n"``.
    """
    turns = [
        line
        for user_text, bot_text in history
        for line in (f"user: {user_text}", f"bot: {bot_text}")
    ]
    turns += [f"user: {user_message}", "bot:"]
    return "\n".join(turns)


def _compact_tag(match: re.Match[str]) -> str:
    inner = re.sub(r"\s+", "", match.group(1))
    return f"<{inner}>"


def clean_reply(text: str) -> str:
    """Normalize decoded model text into readable form.

    Steps, in order:
      1. Turn the ByteLevel BPE stand-ins back into whitespace
         (``_BPE_SPACE`` -> space, ``_BPE_NEWLINE`` -> newline).
      2. Apply every (bad, good) pair in ``_MOJIBAKE_REPLACEMENTS``.
      3. Collapse runs of spaces to a single space.
      4. Remove whitespace inside each ``<...>`` span, so
         ``< INTENT : X >`` becomes ``<INTENT:X>``.
      5. Drop any space directly before an apostrophe, then trim both ends.
    """
    bpe_whitespace = str.maketrans({_BPE_SPACE: " ", _BPE_NEWLINE: "\n"})
    result = text.translate(bpe_whitespace)
    for mojibake, fixed in _MOJIBAKE_REPLACEMENTS:
        result = result.replace(mojibake, fixed)
    result = re.sub(r" +", " ", result)
    result = re.sub(r"<\s*([^>]+?)\s*>", _compact_tag, result)
    result = result.replace(" '", "'")
    return result.strip()


def _first_bot_line(text: str) -> str:
    """Return the first bot utterance in model output, cleaned.

    The model sometimes opens with a hallucinated ``user:`` turn or keeps
    generating further ``user:``/``bot:`` exchanges after its reply. This keeps
    only the first line of the first bot turn.

    Returns "" when the text opens with a ``user:`` turn that is not followed
    by a ``bot:`` turn, or when nothing is left after cleaning.
    """
    text = clean_reply(text.lstrip())
    if re.match(r"^\s*user\s*:", text, re.IGNORECASE):
        # Jump past the hallucinated user turn to the next bot label. The
        # lookbehind stops words ending in "bot" (e.g. "robot:") from being
        # taken as a speaker label.
        match = re.search(
            r"(?<![A-Za-z])bot\s*:\s*(.+)", text, re.IGNORECASE | re.DOTALL
        )
        if not match:
            return ""
        text = match.group(1)
    text = re.sub(r"^\s*bot\s*:\s*", "", text, count=1, flags=re.IGNORECASE)
    # Cut at the next hallucinated user turn.
    text = re.split(r"\n\s*user\s*:", text, maxsplit=1, flags=re.IGNORECASE)[0]
    # Keep only the first line. This also drops anything after a blank line.
    return clean_reply(text.split("\n", 1)[0].strip())


def extract_bot_reply(prompt: str, generated: str) -> str:
    """Strip the prompt from a full decode and return one cleaned bot line.

    ``generated`` is expected to be the decoded prompt followed by the model's
    continuation. If it starts with the prompt (ignoring the prompt's trailing
    whitespace), everything after the prompt is the reply, whether or not
    the model emitted a leading space.

    Otherwise the text is split on ``bot:`` speaker labels. If ``generated``
    has more labels than ``prompt``, the first extra label starts the reply,
    so hallucinated ``user:``/``bot:`` turns after it cannot replace the real
    reply. Otherwise the text after the final label is used. Text with no
    label is treated as the reply itself.

    Returns "" if no usable bot utterance remains.
    """
    # The lookbehind keeps "robot:" and similar words from counting as labels.
    # It still matches right after raw BPE stand-ins such as "Ċ".
    bot_label = r"(?<![A-Za-z])bot\s*:"
    head = prompt.rstrip()
    if head and generated.startswith(head):
        reply = generated[len(head):]
    else:
        segments = re.split(bot_label, generated, flags=re.IGNORECASE)
        if len(segments) == 1:
            reply = generated
        else:
            prompt_labels = len(re.findall(bot_label, prompt, flags=re.IGNORECASE))
            # segments[k] is the text after the prompt's own k-th (final) label.
            index = prompt_labels if 0 < prompt_labels < len(segments) else -1
            reply = segments[index]
    return _first_bot_line(reply)


def extract_bot_reply_from_continuation(continuation: str) -> str:
    """Return the first cleaned bot line from the newly generated tokens only.

    Use this when only the new tokens were decoded (no prompt prefix). Leading
    ``user:`` turns, a leading ``bot:`` label, and any turns after the first
    bot line are discarded.
    """
    first_line = _first_bot_line(continuation)
    return first_line


@dataclass()
class ParsedReply:
    """Intent tag and reply template parsed from a single bot line.

    ``intent`` is ``"NONE"`` when the line carried no usable ``<INTENT:...>``
    tag. ``template`` is the reply text, which may still contain ``<SLOT>``
    placeholders.
    """

    intent: str  # intent name taken from the <INTENT:...> tag, or "NONE"
    template: str  # reply text; may contain <SLOT> placeholders


def extract_intent_reply(text: str) -> ParsedReply:
    """Split a bot line into its intent tag and reply template.

    The model is expected to emit ``<INTENT:NAME> template text``. The text is
    cleaned first, which compacts ``< INTENT : NAME >`` to ``<INTENT:NAME>``.
    Then the first intent tag and the rest of its line are parsed. Any text
    before the tag is discarded.

    Returns:
        ParsedReply. If a well-formed tag is found, ``intent`` is the tag name
        upper-cased and ``template`` is the stripped text after the tag on the
        same line. The tag is matched case-insensitively, so ``<intent:set_x>``
        and ``<INTENT:SET_X>`` give the same intent. If there is no tag,
        ``intent`` is "NONE" and ``template`` is the first line of the cleaned
        text. If the tag is malformed (name not matching ``[A-Z_]+``),
        ``intent`` is "NONE" and ``template`` is the whole tag line.
    """
    cleaned = clean_reply(text)
    match = re.search(r"<\s*INTENT\s*:[^>]+>", cleaned, re.IGNORECASE)
    if not match:
        first = cleaned.split("\n", 1)[0].strip()
        return ParsedReply(intent="NONE", template=first or cleaned)

    # Only the tag's own line is kept. Any hallucinated turns start on later
    # lines, so they are dropped as well.
    line = cleaned[match.start():].split("\n", 1)[0].strip()

    intent_match = re.match(r"^<INTENT:([A-Z_]+)>\s*(.*)", line, re.IGNORECASE | re.DOTALL)
    if intent_match:
        return ParsedReply(
            # Canonical upper case, matching "NONE" and the [A-Z_] tag grammar.
            intent=intent_match.group(1).upper(),
            template=intent_match.group(2).strip(),
        )
    return ParsedReply(intent="NONE", template=line)


def fill_slots(text: str, data: dict[str, str]) -> str:
    """Replace ``<NAME>`` placeholders in *text* with values from *data*.

    Only names made of uppercase ASCII letters and underscores are
    recognized. A placeholder whose name is not in *data* is left untouched.
    """

    def substitute(found: re.Match[str]) -> str:
        slot_name = found.group(1)
        return data.get(slot_name, found.group(0))

    return re.sub(r"<([A-Z_]+)>", substitute, text)


def process_model_output(
    prompt: str,
    generated: str,
    slot_data: dict[str, str] | None = None,
) -> tuple[str, ParsedReply, str]:
    """Turn raw model output into a bot line, its parsed intent, and display text.

    Pipeline: cleaned bot line -> intent parse -> slot-filled display.

    Args:
        prompt: The prompt that was fed to the model.
        generated: The decoded output. This may be only the newly generated
            tokens (the continuation) or the full decode, prompt included.
        slot_data: Values for ``<SLOT>`` placeholders. Missing slots are left
            as-is.

    Returns:
        ``(raw, parsed, display)``: the cleaned bot line, its ParsedReply,
        and ``parsed.template`` with slots filled.
    """
    head = prompt.rstrip()
    if head and generated.startswith(head):
        # Full decode: strip the prompt first. Otherwise the continuation
        # parser sees the leading "user:" and returns the first *history*
        # bot turn contained in the prompt.
        raw = extract_bot_reply(prompt, generated)
    else:
        raw = extract_bot_reply_from_continuation(generated)
        if not raw:
            raw = extract_bot_reply(prompt, generated)
    parsed = extract_intent_reply(raw)
    display = fill_slots(parsed.template, slot_data or {})
    return raw, parsed, display