"""
Interactive chat with the exported Smartwatch LM.

Usage:
  pip install torch tokenizers
  python chat.py
"""

from __future__ import annotations

import re
import sys
from pathlib import Path

import torch

import config as cfg
from model import load_model
from reply_utils import build_prompt, extract_bot_reply_from_continuation, extract_intent_reply


class ChatSession:
    """Stateful chat wrapper around the loaded model.

    Holds the model, tokenizer and device returned by ``load_model``, the
    sampling settings, and the running conversation as a list of
    ``(user_message, bot_reply)`` pairs.
    """

    def __init__(
        self,
        checkpoint_path: Path | None = None,
        tokenizer_path: Path | None = None,
        device: str | None = None,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
    ):
        """Load the model and set the sampling parameters.

        Args:
            checkpoint_path: Forwarded unchanged to ``load_model``.
            tokenizer_path: Forwarded unchanged to ``load_model``.
            device: Forwarded unchanged to ``load_model``.
            max_new_tokens: Generation length; ``None`` selects
                ``cfg.SAMPLE_MAX_NEW_TOKENS``.
            temperature: Sampling temperature; ``None`` selects
                ``cfg.SAMPLE_TEMPERATURE``.
            top_k: Top-k cutoff; ``None`` selects ``cfg.SAMPLE_TOP_K``.

        How ``None`` paths/devices are resolved is decided by ``load_model``
        and is not visible from this module.
        """
        self.model, self.tokenizer, self.device = load_model(
            checkpoint_path, tokenizer_path, device
        )
        # Use `is not None` for all three settings so that an explicit falsy
        # value (e.g. 0) is honoured instead of being silently replaced by
        # the configured default.
        self.max_new_tokens = (
            max_new_tokens if max_new_tokens is not None else cfg.SAMPLE_MAX_NEW_TOKENS
        )
        self.temperature = temperature if temperature is not None else cfg.SAMPLE_TEMPERATURE
        self.top_k = top_k if top_k is not None else cfg.SAMPLE_TOP_K
        self.history: list[tuple[str, str]] = []

    def reset(self) -> None:
        """Forget all previous turns of the conversation."""
        self.history.clear()

    @torch.no_grad()
    def say(self, user_message: str) -> str:
        """Generate a reply to ``user_message`` and record the turn.

        The message is stripped first; a blank message returns ``""`` without
        calling the model or touching the history.

        Returns:
            The reply extracted from the generated continuation by
            ``extract_bot_reply_from_continuation``.
        """
        user_message = user_message.strip()
        if not user_message:
            return ""

        prompt = build_prompt(self.history, user_message)
        start_ids = self.tokenizer.encode(prompt).ids
        x = torch.tensor([start_ids], dtype=torch.long, device=self.device)
        y = self.model.generate(
            x,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_k=self.top_k,
        )

        # Keep only the ids after the prompt.
        # NOTE(review): this assumes `generate` returns the prompt ids followed
        # by the newly sampled ids — confirm against model.generate.
        new_ids = y[0, len(start_ids):].tolist()
        continuation = self.tokenizer.decode(new_ids)
        reply = extract_bot_reply_from_continuation(continuation)

        self.history.append((user_message, reply))
        return reply

    def say_display(self, user_message: str) -> tuple[str, str, str]:
        """Return (raw_reply, intent, display_text)."""
        raw = self.say(user_message)
        parsed = extract_intent_reply(raw)
        return raw, parsed.intent, parsed.template


def print_banner() -> None:
    """Print the REPL greeting and the list of supported commands."""
    print("Smartwatch LM chat — type a message and press Enter.")
    print("Commands: quit/exit | reset (clear history) | history")
    print("-" * 60)


def _read_best_val_loss(ckpt_path: Path):
    """Return the ``best_val_loss`` entry stored in ``ckpt_path``, or ``None``.

    ``None`` is returned when the file does not exist, when the loaded object
    is not a dict, or when the key is absent.
    """
    if not ckpt_path.is_file():
        return None
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this at checkpoints you produced or otherwise trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    if not isinstance(checkpoint, dict):
        # Informational banner data only; a checkpoint with an unexpected
        # layout should not stop the chat from starting.
        return None
    return checkpoint.get("best_val_loss")


def run_repl() -> None:
    """Run the interactive read-eval-print loop on stdin/stdout.

    Exits with status 1 (after printing the error to stderr) if constructing
    the ``ChatSession`` raises ``FileNotFoundError``.

    Recognised inputs (case-insensitive): ``quit``/``exit`` end the loop,
    ``reset`` clears the history, ``history`` prints the recorded turns.
    Anything else is sent to the model. EOF or Ctrl-C at the prompt also ends
    the loop.
    """
    try:
        session = ChatSession()
    except FileNotFoundError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    # NOTE(review): this reads cfg.OUTPUT_DIR / "checkpoint.pt" regardless of
    # which checkpoint load_model actually picked — confirm they are the same.
    val_loss = _read_best_val_loss(cfg.OUTPUT_DIR / "checkpoint.pt")

    print_banner()
    print(f"device: {session.device}")
    if val_loss is not None:
        print(f"checkpoint val loss: {val_loss:.4f}")
    print()

    while True:
        try:
            user_input = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nbye")
            break

        if not user_input:
            continue

        lowered = user_input.lower()
        if lowered in {"quit", "exit"}:
            print("bye")
            break
        if lowered == "reset":
            session.reset()
            print("(history cleared)")
            continue
        if lowered == "history":
            if not session.history:
                print("(empty)")
            for user_text, bot_text in session.history:
                print(f"user: {user_text}\nbot: {bot_text}\n")
            continue

        _, intent, display = session.say_display(user_input)
        print(f"bot> {display}")
        # "NONE" is treated as "no intent" and is not shown.
        if intent and intent != "NONE":
            print(f" intent: {intent}")


if __name__ == "__main__":
    run_repl()