| """
|
| Post-processing for Smartwatch LM chat replies (BPE gibberish removal).
|
|
|
| Use after tokenizer.decode() — same logic is embedded in colab_all_in_one.py.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import re
|
| from dataclasses import dataclass
|
|
|
# Marker characters that clean_reply() turns back into real whitespace.
_BPE_SPACE = chr(0x0120)  # replaced with a plain space
_BPE_NEWLINE = chr(0x010A)  # replaced with "\n"

# (garbled sequence, intended punctuation) pairs, applied in order by
# clean_reply().  The last two rows normalise typographic punctuation
# rather than repair garbling.
_MOJIBAKE_REPLACEMENTS: tuple[tuple[str, str], ...] = (
    ("âĢĶ", "—"),
    ("âĢĻ", "'"),
    ("âĢĺ", "'"),
    ("’", "'"),
    ("–", "—"),
)
|
|
|
|
|
def build_prompt(history: list[tuple[str, str]], user_message: str) -> str:
    """Render a chat transcript that ends with an open ``bot:`` turn.

    Args:
        history: Earlier ``(user_text, bot_text)`` exchanges, oldest first.
        user_message: The new user message to append after the history.

    Returns:
        Newline-joined ``user: ...`` / ``bot: ...`` lines, followed by the
        new user line and a bare ``bot:`` line for the model to complete.
    """
    transcript: list[str] = []
    for asked, answered in history:
        transcript.extend((f"user: {asked}", f"bot: {answered}"))
    transcript += [f"user: {user_message}", "bot:"]
    return "\n".join(transcript)
|
|
|
|
|
| def _compact_tag(match: re.Match[str]) -> str:
|
| inner = re.sub(r"\s+", "", match.group(1))
|
| return f"<{inner}>"
|
|
|
|
|
def clean_reply(text: str) -> str:
    """Strip ByteLevel BPE artifacts (Ġ Ċ) and repair broken punctuation.

    Steps, in order: restore spaces/newlines from the BPE markers, apply the
    ``_MOJIBAKE_REPLACEMENTS`` table, collapse runs of spaces, remove
    whitespace inside ``<...>`` tags, drop a space that precedes an
    apostrophe, and trim the ends.
    """
    cleaned = text.replace(_BPE_SPACE, " ")
    cleaned = cleaned.replace(_BPE_NEWLINE, "\n")

    for garbled, fixed in _MOJIBAKE_REPLACEMENTS:
        cleaned = cleaned.replace(garbled, fixed)

    # Only literal spaces are collapsed; newlines are left for callers that
    # split on them.
    cleaned = re.sub(r" +", " ", cleaned)
    cleaned = re.sub(r"<\s*([^>]+?)\s*>", _compact_tag, cleaned)
    cleaned = cleaned.replace(" '", "'")
    return cleaned.strip()
|
|
|
|
|
def _first_bot_line(text: str) -> str:
    """Return only the first bot utterance, discarding hallucinated turns.

    If the cleaned text opens with a ``user:`` turn, everything up to the
    first ``bot:`` marker is skipped; when no such marker exists the result
    is an empty string.  A leading ``bot:`` label is removed, and the text is
    cut at the next ``user:`` turn, the first blank line, and the first
    newline before a final cleaning pass.
    """
    body = clean_reply(text.lstrip())

    opens_with_user = re.match(r"^\s*user\s*:", body, re.IGNORECASE) is not None
    if opens_with_user:
        found = re.search(r"bot\s*:\s*(.+)", body, re.IGNORECASE | re.DOTALL)
        if found is None:
            # Only a user turn was produced; there is no bot text to keep.
            return ""
        body = found.group(1)

    body = re.sub(r"^\s*bot\s*:\s*", "", body, count=1, flags=re.IGNORECASE)
    body = re.split(r"\n\s*user\s*:", body, maxsplit=1, flags=re.IGNORECASE)[0]

    paragraph = body.partition("\n\n")[0]
    first_line = paragraph.partition("\n")[0]
    return clean_reply(first_line.strip())
|
|
|
|
|
def extract_bot_reply(prompt: str, generated: str) -> str:
    """Strip the prompt prefix from *generated* and return one cleaned bot line.

    Args:
        prompt: The prompt that was fed to the model.
        generated: Decoded model output, normally the prompt followed by the
            model's continuation.

    Returns:
        The first bot line of the continuation, as produced by
        ``_first_bot_line``.
    """
    prefix = prompt.rstrip()
    # Accept the prompt prefix whatever follows it (space, newline, marker
    # character, or nothing).  Requiring exactly one trailing space sent
    # outputs such as "...bot:\nHello\nuser: ...\nbot: ..." to the fallback
    # below, which returns the *last* ``bot:`` segment, i.e. a hallucinated
    # later turn instead of the actual reply.  With an empty prompt the
    # historical "starts with a space" test is kept unchanged.
    if generated.startswith(prefix) and (prefix or generated.startswith(" ")):
        reply = generated[len(prefix):]
    elif re.search(r"bot\s*:", generated, re.IGNORECASE):
        # Prompt not reproduced verbatim: fall back to the text after the
        # final ``bot:`` marker (maxsplit=0 means "no limit").
        reply = re.split(r"bot\s*:", generated, maxsplit=0, flags=re.IGNORECASE)[-1]
    else:
        reply = generated
    return _first_bot_line(reply)
|
|
|
|
|
def extract_bot_reply_from_continuation(continuation: str) -> str:
    """Extract the first bot line from a continuation.

    *continuation* is meant to be the decoding of the newly generated tokens
    only (no prompt); it is handed straight to ``_first_bot_line``.
    """
    first_line = _first_bot_line(continuation)
    return first_line
|
|
|
|
|
@dataclass
class ParsedReply:
    """Result of splitting a bot line into an intent name and reply text."""

    # Name captured from an ``<INTENT:...>`` tag, or "NONE" when absent.
    intent: str
    # Reply text; may still contain ``<SLOT>`` placeholders.
    template: str
|
|
|
|
|
def extract_intent_reply(text: str) -> ParsedReply:
    """Split a reply into its ``<INTENT:NAME>`` tag and the template after it.

    The text is cleaned first.  Without an intent tag, the first line is
    returned with intent ``"NONE"``.  With a tag, parsing starts at the tag,
    stops at a following ``user:`` turn or newline, and the tag name plus the
    remaining text on that line are returned.  A tag whose name is not made
    of letters/underscores also yields intent ``"NONE"`` with the whole line
    as template.
    """
    cleaned = clean_reply(text)

    tag = re.search(r"<\s*INTENT\s*:[^>]+>", cleaned, re.IGNORECASE)
    if tag is None:
        head = cleaned.partition("\n")[0].strip()
        return ParsedReply(intent="NONE", template=head or cleaned)

    # Ignore anything the model emitted before the tag.
    tail = cleaned[tag.start():]
    tail = re.split(r"\nuser\s*:", tail, maxsplit=1, flags=re.IGNORECASE)[0]
    line = tail.partition("\n")[0].strip()

    parsed = re.match(r"^<INTENT:([A-Z_]+)>\s*(.*)", line, re.IGNORECASE | re.DOTALL)
    if parsed is None:
        return ParsedReply(intent="NONE", template=line)

    name, remainder = parsed.groups()
    return ParsedReply(intent=name, template=remainder.strip())
|
|
|
|
|
def fill_slots(text: str, data: dict[str, str]) -> str:
    """Substitute ``<NAME>`` placeholders in *text* with values from *data*.

    Only placeholders made of uppercase ASCII letters and underscores are
    considered.  A placeholder whose name is not a key of *data* is left
    untouched.  Values are inserted literally (no backslash processing).
    """

    def _lookup(found: re.Match[str]) -> str:
        placeholder, slot_name = found.group(0), found.group(1)
        return data.get(slot_name, placeholder)

    return re.sub(r"<([A-Z_]+)>", _lookup, text)
|
|
|
|
|
def process_model_output(
    prompt: str,
    generated: str,
    slot_data: dict[str, str] | None = None,
) -> tuple[str, ParsedReply, str]:
    """Run the full pipeline: bot line -> intent parse -> slot-filled display.

    Args:
        prompt: Prompt given to the model; only used by the fallback path.
        generated: Model output, first treated as a raw continuation.
        slot_data: Values for ``<SLOT>`` placeholders; ``None`` means none.

    Returns:
        ``(raw, parsed, display)``: the cleaned bot line, its parsed
        intent/template, and the template with known slots filled in.
    """
    # Try the text as a bare continuation first; if that yields nothing,
    # retry with prompt-prefix stripping.
    raw = extract_bot_reply_from_continuation(generated) or extract_bot_reply(
        prompt, generated
    )
    parsed = extract_intent_reply(raw)
    slots = slot_data if slot_data else {}
    display = fill_slots(parsed.template, slots)
    return (raw, parsed, display)
|
|
|