smartwatch-lm-0.2 / reply_utils.py
prathamkode's picture
Upload folder using huggingface_hub
a330cfa verified
"""
Post-processing for Smartwatch LM chat replies (BPE gibberish removal).
Use after tokenizer.decode() — same logic is embedded in colab_all_in_one.py.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
# Visible stand-in characters that show up in decoded ByteLevel BPE text;
# clean_reply() maps them back to a real space and a real newline.
_BPE_SPACE = "\u0120" # Ġ
_BPE_NEWLINE = "\u010a" # Ċ
# (garbled, replacement) pairs that clean_reply() applies in this order with
# str.replace. The three "âĢ…" entries are garbled renderings of an em dash
# and of single quotes; the last two entries normalise a curly apostrophe to
# a straight one and an en dash to an em dash.
_MOJIBAKE_REPLACEMENTS: tuple[tuple[str, str], ...] = (
("âĢĶ", "—"),
("âĢĻ", "'"),
("âĢĺ", "'"),
("’", "'"),
("–", "—"),
)
def build_prompt(history: list[tuple[str, str]], user_message: str) -> str:
    """Render earlier turns plus the new user message as a chat prompt.

    Each ``(user_text, bot_text)`` pair in ``history`` becomes a ``user:``
    line followed by a ``bot:`` line.  The prompt then ends with the new
    ``user:`` line and a bare ``bot:`` line left open for the model to
    complete.  Lines are separated by a single newline with no trailing
    newline.
    """
    past_turns = "".join(
        f"user: {asked}\nbot: {answered}\n" for asked, answered in history
    )
    return f"{past_turns}user: {user_message}\nbot:"
def _compact_tag(match: re.Match[str]) -> str:
    """Rebuild a matched ``<...>`` tag with every whitespace run removed.

    Intended as a ``re.sub`` callback: group 1 of ``match`` holds the text
    between the angle brackets.
    """
    inner_text = match[1]
    return "<" + re.sub(r"\s+", "", inner_text) + ">"
def clean_reply(text: str) -> str:
    """Strip ByteLevel BPE artifacts (Ġ Ċ) and repair broken punctuation.

    Steps, in order: map the BPE space/newline stand-ins to real characters,
    apply the mojibake replacement table, collapse runs of spaces, remove
    whitespace inside ``<...>`` tags, drop a space that precedes an
    apostrophe, and trim surrounding whitespace.
    """
    # Both stand-ins are single characters and neither replacement is itself
    # a stand-in, so one translate pass equals the two chained replaces.
    bpe_table = str.maketrans({_BPE_SPACE: " ", _BPE_NEWLINE: "\n"})
    cleaned = text.translate(bpe_table)

    for garbled, fixed in _MOJIBAKE_REPLACEMENTS:
        cleaned = cleaned.replace(garbled, fixed)

    single_spaced = re.sub(r" +", " ", cleaned)
    tags_compacted = re.sub(r"<\s*([^>]+?)\s*>", _compact_tag, single_spaced)
    apostrophes_joined = tags_compacted.replace(" '", "'")
    return apostrophes_joined.strip()
def _first_bot_line(text: str) -> str:
    """Return only the first bot utterance, discarding hallucinated user turns.

    The text is cleaned, an optional leading ``bot:`` label is removed, and
    everything from the next ``user:`` turn or the first line break onward
    is dropped.  If the text opens with a ``user:`` turn, parsing resumes
    after the first ``bot:`` marker; when no such marker (followed by text)
    exists, an empty string is returned.
    """
    body = clean_reply(text.lstrip())

    # Output that opens with a user turn: skip ahead to the bot's answer.
    if re.match(r"^\s*user\s*:", body, re.IGNORECASE) is not None:
        bot_turn = re.search(r"bot\s*:\s*(.+)", body, re.IGNORECASE | re.DOTALL)
        if bot_turn is None:
            return ""
        body = bot_turn.group(1)

    # Drop a leading "bot:" label, then anything from the next user turn on.
    body = re.sub(r"^\s*bot\s*:\s*", "", body, count=1, flags=re.IGNORECASE)
    next_user = re.search(r"\n\s*user\s*:", body, re.IGNORECASE)
    if next_user is not None:
        body = body[: next_user.start()]

    # Keeping just the first line also cuts at any blank-line paragraph break.
    first_line = body.partition("\n")[0]
    return clean_reply(first_line.strip())
def extract_bot_reply(prompt: str, generated: str) -> str:
    """Strip the prompt from ``generated`` and return one cleaned bot line.

    Args:
        prompt: The prompt that was fed to the model (normally ending in
            ``bot:``).
        generated: Decoded model output, expected to contain the prompt
            followed by the model's continuation.

    Returns:
        The first bot utterance after the prompt, cleaned by
        ``_first_bot_line``; may be an empty string.

    The prompt is matched as a right-trimmed prefix, so the reply is found
    whether the model continued with a space, a newline, or nothing after
    ``bot:``.  When the prompt is not an exact prefix (for example because
    decoding altered whitespace), only as many ``bot:`` markers as the
    prompt itself contains are skipped.  Any extra turns the model
    hallucinated stay in the remainder, where ``_first_bot_line`` trims
    them, instead of the last hallucinated turn being taken as the reply.
    """
    prefix = prompt.rstrip()
    if prefix and generated.startswith(prefix):
        reply = generated[len(prefix):]
    else:
        marker = r"bot\s*:"
        markers_in_prompt = len(re.findall(marker, prompt, flags=re.IGNORECASE))
        # maxsplit must be >= 1: re.split treats 0 as "no limit", which would
        # bring back the take-the-last-turn behaviour.  With no marker in
        # ``generated`` the split returns the text unchanged.
        pieces = re.split(
            marker,
            generated,
            maxsplit=max(markers_in_prompt, 1),
            flags=re.IGNORECASE,
        )
        reply = pieces[-1]
    return _first_bot_line(reply)
def extract_bot_reply_from_continuation(continuation: str) -> str:
    """Return the first bot line of a continuation.

    ``continuation`` is expected to be the text decoded from the newly
    generated tokens only (no prompt); it is passed to ``_first_bot_line``
    unchanged.
    """
    first_line = _first_bot_line(continuation)
    return first_line
@dataclass
class ParsedReply:
    """Result of parsing a bot line for an ``<INTENT:...>`` tag."""

    # Name captured from the <INTENT:NAME> tag, or "NONE" when
    # extract_intent_reply() could not parse a tag.
    intent: str
    # Text following the tag on the same line; without a parsed tag this is
    # the line (or cleaned text) itself.
    template: str
def extract_intent_reply(text: str) -> ParsedReply:
    """Split a bot reply into its ``<INTENT:NAME>`` tag and the template text.

    The text is cleaned first.  If no intent tag is present, the result has
    intent ``"NONE"`` and the first line as template (or the whole cleaned
    text when that first line is empty).  Otherwise parsing starts at the
    tag, is limited to the tag's own line, and yields the tag name and the
    trimmed remainder of that line.  A tag that does not fit the
    ``<INTENT:LETTERS_OR_UNDERSCORES>`` shape yields intent ``"NONE"`` with
    the whole line as template.
    """
    cleaned = clean_reply(text)
    tag = re.search(r"<\s*INTENT\s*:[^>]+>", cleaned, re.IGNORECASE)
    if tag is None:
        head = cleaned.partition("\n")[0].strip()
        return ParsedReply(intent="NONE", template=head if head else cleaned)

    # Work from the tag onward, stopping before any following user turn.
    tail = cleaned[tag.start():]
    tail = re.split(r"\nuser\s*:", tail, maxsplit=1, flags=re.IGNORECASE)[0]
    tagged_line = tail.partition("\n")[0].strip()

    parsed = re.match(
        r"^<INTENT:([A-Z_]+)>\s*(.*)",
        tagged_line,
        re.IGNORECASE | re.DOTALL,
    )
    if parsed is None:
        return ParsedReply(intent="NONE", template=tagged_line)
    name, remainder = parsed.groups()
    return ParsedReply(intent=name, template=remainder.strip())
def fill_slots(text: str, data: dict[str, str]) -> str:
    """Replace ``<SLOT_NAME>`` placeholders in ``text`` with values from ``data``.

    Args:
        text: Template text containing placeholders made of uppercase ASCII
            letters and underscores inside angle brackets, e.g. ``<STEPS>``.
        data: Mapping from slot name (without brackets) to its value.

    Returns:
        ``text`` with every known placeholder substituted.  Placeholders
        whose name is not in ``data`` are left untouched.

    Values are converted with ``str()`` so that numeric readings (such as an
    integer step count) can be passed directly instead of making ``re.sub``
    raise ``TypeError``; a ``None`` value renders as an empty string.
    """

    def _substitute(match: re.Match[str]) -> str:
        name = match.group(1)
        if name not in data:
            # Unknown slot: keep the placeholder text as-is.
            return match.group(0)
        value = data[name]
        return "" if value is None else str(value)

    return re.sub(r"<([A-Z_]+)>", _substitute, text)
def process_model_output(
    prompt: str,
    generated: str,
    slot_data: dict[str, str] | None = None,
) -> tuple[str, ParsedReply, str]:
    """Turn model output into a cleaned bot line, a parsed intent and display text.

    Args:
        prompt: The prompt that was fed to the model.
        generated: Decoded model output.  Either just the continuation, or
            the full sequence (prompt followed by the continuation).
        slot_data: Optional values for ``<SLOT>`` placeholders in the reply;
            ``None`` is treated as an empty mapping.

    Returns:
        A tuple ``(raw, parsed, display)``: the cleaned first bot line, its
        ``ParsedReply``, and the reply template with slots filled in.
    """
    # If the caller decoded the whole sequence, remove the echoed prompt
    # first.  Otherwise a prompt containing history would make the
    # continuation parser return the first *historical* bot line rather
    # than the newly generated one.
    prompt_prefix = prompt.rstrip()
    if prompt_prefix and generated.startswith(prompt_prefix):
        continuation = generated[len(prompt_prefix):]
    else:
        continuation = generated

    raw = extract_bot_reply_from_continuation(continuation)
    if not raw:
        # Nothing usable in the continuation: fall back to prompt-aware parsing.
        raw = extract_bot_reply(prompt, generated)
    parsed = extract_intent_reply(raw)
    display = fill_slots(parsed.template, slot_data or {})
    return raw, parsed, display