"""
Post-processing for Smartwatch LM chat replies (BPE gibberish removal).

Use after tokenizer.decode() — same logic is embedded in colab_all_in_one.py.

Pipeline (see ``process_model_output``): clean the decoded text, keep only the
first bot line, parse an optional ``<INTENT:NAME>`` tag, then fill ``<SLOT>``
placeholders from caller-supplied data.
"""

from __future__ import annotations

import re
from dataclasses import dataclass

_BPE_SPACE = "\u0120"  # Ġ — ByteLevel BPE rendering of the space byte
_BPE_NEWLINE = "\u010a"  # Ċ — ByteLevel BPE rendering of the newline byte

# Mis-decoded punctuation -> intended character. The "âĢ?" entries are the
# UTF-8 bytes E2 80 xx shown through the GPT-2 ByteLevel byte-to-unicode
# table; the last two entries appear to be cp1252-style double-decodings.
# Dashes are normalised to an em dash, quotes to plain ASCII quotes.
_MOJIBAKE_REPLACEMENTS: tuple[tuple[str, str], ...] = (
    ("âĢĶ", "—"),  # U+2014 em dash
    ("âĢĻ", "'"),  # U+2019 right single quotation mark
    ("âĢĺ", "'"),  # U+2018 left single quotation mark
    ("âĢľ", '"'),  # U+201C left double quotation mark
    ("âĢĿ", '"'),  # U+201D right double quotation mark
    ("âĢĵ", "—"),  # U+2013 en dash (normalised like the entry below)
    ("âĢ¦", "…"),  # U+2026 horizontal ellipsis
    ("’", "'"),
    ("–", "—"),
)

_MULTI_SPACE_RE = re.compile(r" +")
# Any <...> span; its inner whitespace is removed by _compact_tag.
_TAG_RE = re.compile(r"<\s*([^>]+?)\s*>")
# A space wrongly inserted before a contraction suffix ("I 'm", "don 't").
# Restricted to these suffixes so opening quotes ("say 'hi'") keep their space.
_SPACED_CONTRACTION_RE = re.compile(r" '(?=(?:s|t|m|d|re|ve|ll)\b)", re.IGNORECASE)
_USER_PREFIX_RE = re.compile(r"^\s*user\s*:", re.IGNORECASE)
_BOT_REST_RE = re.compile(r"bot\s*:\s*(.+)", re.IGNORECASE | re.DOTALL)
_BOT_PREFIX_RE = re.compile(r"^\s*bot\s*:\s*", re.IGNORECASE)
_BOT_MARKER_RE = re.compile(r"bot\s*:", re.IGNORECASE)
_USER_TURN_RE = re.compile(r"\nuser\s*:", re.IGNORECASE)
_INTENT_TAG_RE = re.compile(r"<\s*INTENT\s*:[^>]+>", re.IGNORECASE)
# Groups: (1) intent name inside the tag, (2) template text after the tag.
_INTENT_LINE_RE = re.compile(r"\s*<\s*INTENT\s*:\s*([^>]*?)\s*>\s*(.*)", re.IGNORECASE)
_SLOT_RE = re.compile(r"<([A-Z_]+)>")


def build_prompt(history: list[tuple[str, str]], user_message: str) -> str:
    """Render *history* plus *user_message* as a ``user:``/``bot:`` transcript.

    Each ``(user_text, bot_text)`` pair becomes two lines. The result ends with
    an open ``bot:`` line for the model to continue. Messages are inserted
    verbatim, so a newline inside one would look like a new transcript line.
    """
    lines: list[str] = []
    for user_text, bot_text in history:
        lines.append(f"user: {user_text}")
        lines.append(f"bot: {bot_text}")
    lines.append(f"user: {user_message}")
    lines.append("bot:")
    return "\n".join(lines)


def _compact_tag(match: re.Match[str]) -> str:
    """Rebuild a ``<...>`` tag with all whitespace inside it removed."""
    inner = re.sub(r"\s+", "", match.group(1))
    return f"<{inner}>"


def clean_reply(text: str) -> str:
    """Remove ByteLevel BPE artifacts (Ġ Ċ) and fix broken punctuation.

    Steps:
    - Map Ġ/Ċ back to a space and a newline.
    - Repair known mojibake.
    - Collapse runs of spaces.
    - Remove whitespace inside ``<...>`` tags (``< TIME >`` -> ``<TIME>``).
    - Re-attach spaced contractions (``I 'm`` -> ``I'm``).
    - Strip the result.
    """
    out = text.replace(_BPE_SPACE, " ").replace(_BPE_NEWLINE, "\n")
    for bad, good in _MOJIBAKE_REPLACEMENTS:
        out = out.replace(bad, good)
    out = _MULTI_SPACE_RE.sub(" ", out)
    out = _TAG_RE.sub(_compact_tag, out)
    out = _SPACED_CONTRACTION_RE.sub("'", out)
    return out.strip()


def _first_bot_line(text: str) -> str:
    """Keep only the first bot utterance; drop hallucinated user turns.

    Returns "" when the text opens with a ``user:`` turn that contains no
    ``bot:`` reply at all.
    """
    text = clean_reply(text.lstrip())
    if _USER_PREFIX_RE.match(text):
        # The model started by inventing a user turn: skip to its bot reply.
        match = _BOT_REST_RE.search(text)
        if not match:
            return ""
        text = match.group(1)
    text = _BOT_PREFIX_RE.sub("", text, count=1)
    # Only the first physical line is the reply. Later paragraphs and any
    # "user:" turns (which always begin on a new line) are dropped with it.
    first_line = text.split("\n", 1)[0]
    return clean_reply(first_line.strip())


def extract_bot_reply(prompt: str, generated: str) -> str:
    """Strip prompt prefix and return one cleaned bot line.

    *generated* may be the full decoded text (prompt + continuation) or just
    the continuation.
    - If it starts with the (rstripped) prompt, that prefix is removed.
    - Otherwise, if it contains ``bot:`` markers, the segment after as many
      markers as the prompt itself contains is used. This selects the first
      new bot turn rather than the last (possibly hallucinated) one. If there
      are not enough markers, the last segment is used.
    """
    base = prompt.rstrip()
    if base and generated.startswith(base):
        reply = generated[len(base) :]
    elif _BOT_MARKER_RE.search(generated):
        parts = _BOT_MARKER_RE.split(generated)
        prompt_turns = len(_BOT_MARKER_RE.findall(base))
        reply = parts[prompt_turns] if 0 < prompt_turns < len(parts) else parts[-1]
    else:
        reply = generated
    return _first_bot_line(reply)


def extract_bot_reply_from_continuation(continuation: str) -> str:
    """Decode only new tokens, then extract the first bot line."""
    return _first_bot_line(continuation)


@dataclass
class ParsedReply:
    """Result of parsing one bot line for an ``<INTENT:...>`` tag."""

    # Intent name from the tag, or "NONE" when no tag was found.
    intent: str
    # Reply text following the tag; may still contain <SLOT> placeholders.
    template: str


def extract_intent_reply(text: str) -> ParsedReply:
    """Split a reply into its ``<INTENT:NAME>`` tag and the template after it.

    Without a tag, the intent is "NONE" and the template is the first line of
    the cleaned text. With a tag, any text before the tag is discarded, and the
    template is the rest of the tag's line (stopping at a ``user:`` turn).
    """
    cleaned = clean_reply(text)
    tag = _INTENT_TAG_RE.search(cleaned)
    if tag is None:
        first = cleaned.split("\n", 1)[0].strip()
        return ParsedReply(intent="NONE", template=first or cleaned)
    rest = _USER_TURN_RE.split(cleaned[tag.start() :], maxsplit=1)[0]
    line = rest.split("\n", 1)[0].strip()
    parsed = _INTENT_LINE_RE.match(line)
    if parsed:
        return ParsedReply(
            intent=parsed.group(1).strip(),
            template=parsed.group(2).strip(),
        )
    return ParsedReply(intent="NONE", template=line)


def fill_slots(text: str, data: dict[str, str]) -> str:
    """Replace ``<NAME>`` placeholders (A-Z and _) with ``str(data[NAME])``.

    Placeholders without an entry in *data* are left unchanged.
    """
    return _SLOT_RE.sub(
        lambda m: str(data.get(m.group(1), m.group(0))),
        text,
    )


def process_model_output(
    prompt: str,
    generated: str,
    slot_data: dict[str, str] | None = None,
) -> tuple[str, ParsedReply, str]:
    """Raw continuation -> cleaned bot line -> intent parse -> slot-filled display.

    Returns ``(raw_line, parsed_reply, display_text)``. The prompt-aware
    extraction is tried only when the continuation alone yields nothing.
    """
    raw = extract_bot_reply_from_continuation(generated)
    if not raw:
        raw = extract_bot_reply(prompt, generated)
    parsed = extract_intent_reply(raw)
    display = fill_slots(parsed.template, slot_data or {})
    return raw, parsed, display