| """
|
| Interactive chat with the exported Smartwatch LM.
|
|
|
| Usage:
|
| pip install torch tokenizers
|
| python chat.py
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import re
|
| import sys
|
| from pathlib import Path
|
|
|
| import torch
|
|
|
| import config as cfg
|
| from model import load_model
|
| from reply_utils import build_prompt, extract_bot_reply_from_continuation, extract_intent_reply
|
|
|
|
|
class ChatSession:
    """Multi-turn chat session around the exported Smartwatch LM.

    Holds the model, tokenizer and device returned by ``load_model``, the
    sampling settings, and the conversation so far as a list of
    ``(user_message, bot_reply)`` pairs. Every :meth:`say` call builds its
    prompt from that whole history plus the new message.
    """

    def __init__(
        self,
        checkpoint_path: Path | None = None,
        tokenizer_path: Path | None = None,
        device: str | None = None,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
    ):
        """Load the model/tokenizer and resolve the sampling settings.

        Args:
            checkpoint_path: Passed straight to ``load_model``; ``None``
                presumably means "use its default checkpoint" — confirm.
            tokenizer_path: Passed straight to ``load_model``; ``None``
                presumably means "use its default tokenizer" — confirm.
            device: Passed straight to ``load_model``, which returns the
                device actually used (stored in ``self.device``).
            max_new_tokens: Tokens to sample per reply; ``None`` falls back
                to ``cfg.SAMPLE_MAX_NEW_TOKENS``.
            temperature: Sampling temperature; ``None`` falls back to
                ``cfg.SAMPLE_TEMPERATURE``.
            top_k: Top-k cutoff for sampling; ``None`` falls back to
                ``cfg.SAMPLE_TOP_K``.

        Raises:
            FileNotFoundError: Presumably raised by ``load_model`` when the
                checkpoint/tokenizer is missing (``run_repl`` catches it).
        """
        self.model, self.tokenizer, self.device = load_model(
            checkpoint_path, tokenizer_path, device
        )
        # Resolve every sampling knob the same way: only ``None`` means "use
        # the config default". Using ``is not None`` (rather than ``or``) means
        # an explicit ``0`` is honoured instead of being silently replaced.
        self.max_new_tokens = (
            max_new_tokens if max_new_tokens is not None else cfg.SAMPLE_MAX_NEW_TOKENS
        )
        self.temperature = temperature if temperature is not None else cfg.SAMPLE_TEMPERATURE
        self.top_k = top_k if top_k is not None else cfg.SAMPLE_TOP_K
        # Completed turns, oldest first, as (user_message, bot_reply).
        self.history: list[tuple[str, str]] = []

    def reset(self) -> None:
        """Forget all previous turns (the model itself is untouched)."""
        self.history.clear()

    @torch.no_grad()  # inference only: no autograd bookkeeping needed
    def say(self, user_message: str) -> str:
        """Generate the bot's reply to *user_message* and record the turn.

        The message is stripped first; an empty or whitespace-only message
        returns ``""`` without calling the model or modifying history.

        Args:
            user_message: Raw text typed by the user.

        Returns:
            The reply extracted from the sampled continuation by
            ``extract_bot_reply_from_continuation``. The same reply is
            appended to :attr:`history` together with the stripped message.
        """
        user_message = user_message.strip()
        if not user_message:
            return ""

        # NOTE(review): history is never truncated, so the prompt grows every
        # turn — confirm that model.generate crops its input to the context
        # window, otherwise long sessions will eventually fail.
        prompt = build_prompt(self.history, user_message)
        start_ids = self.tokenizer.encode(prompt).ids
        x = torch.tensor([start_ids], dtype=torch.long, device=self.device)
        y = self.model.generate(
            x,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_k=self.top_k,
        )
        # Keep only the newly sampled tokens of batch row 0. This slice assumes
        # generate returns the prompt followed by the new tokens.
        new_ids = y[0, len(start_ids):].tolist()
        continuation = self.tokenizer.decode(new_ids)
        reply = extract_bot_reply_from_continuation(continuation)
        # Recorded only after generation succeeds, so an interrupted or failed
        # generation leaves history unchanged.
        self.history.append((user_message, reply))
        return reply

    def say_display(self, user_message: str) -> tuple[str, str, str]:
        """Return (raw_reply, intent, display_text).

        Delegates to :meth:`say` (so the turn is recorded in history), then
        splits the raw reply with ``extract_intent_reply``; ``display_text``
        is the parsed result's ``template`` field.
        """
        raw = self.say(user_message)
        parsed = extract_intent_reply(raw)
        return raw, parsed.intent, parsed.template
|
|
|
|
|
def print_banner() -> None:
    """Show the REPL greeting, the available commands, and a divider rule."""
    banner_lines = (
        "Smartwatch LM chat — type a message and press Enter.",
        "Commands: quit/exit | reset (clear history) | history",
        "-" * 60,
    )
    for text in banner_lines:
        print(text)
|
|
|
|
|
def _read_best_val_loss(ckpt_path: Path) -> float | None:
    """Best-effort read of ``best_val_loss`` from a training checkpoint.

    The value is only shown as an informational banner line, so any problem
    reading it is reported on stderr and treated as "unknown" rather than
    aborting the chat. The loaded checkpoint is dropped when this function
    returns, so its contents are not kept alive for the whole REPL session.

    Args:
        ckpt_path: Checkpoint file to inspect.

    Returns:
        The stored value as ``float``, or ``None`` if the file is missing,
        unreadable, not a dict, lacks the key, or holds a non-numeric value.
    """
    if not ckpt_path.is_file():
        return None
    try:
        # weights_only=False unpickles arbitrary objects. That is acceptable
        # only because this is the project's own locally produced checkpoint;
        # never point this at an untrusted file.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except Exception as exc:  # cosmetic info only — never kill the REPL over it
        print(f"warning: could not read {ckpt_path}: {exc}", file=sys.stderr)
        return None
    if not isinstance(checkpoint, dict):
        return None
    val_loss = checkpoint.get("best_val_loss")
    if val_loss is None:
        return None
    try:
        # float() also accepts a 0-dim tensor, normalising the stored type.
        return float(val_loss)
    except (TypeError, ValueError, RuntimeError):
        return None


def run_repl() -> None:
    """Run the interactive read-eval-print loop until quit/exit or EOF.

    Commands (case-insensitive): ``quit``/``exit`` leave the loop,
    ``reset`` clears the conversation history, ``history`` prints it.
    Any other non-empty input is sent to the model. The display text is
    printed, followed by the intent when one other than ``"NONE"`` is
    extracted. Ctrl+C while waiting for input exits; Ctrl+C during
    generation cancels only that reply.

    Exits the process with status 1 if ``ChatSession()`` raises
    ``FileNotFoundError``.
    """
    try:
        session = ChatSession()
    except FileNotFoundError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    # NOTE(review): hard-coded path — confirm it is the same checkpoint that
    # load_model picks by default, otherwise the reported loss is for a
    # different file.
    val_loss = _read_best_val_loss(cfg.OUTPUT_DIR / "checkpoint.pt")

    print_banner()
    print(f"device: {session.device}")
    if val_loss is not None:
        print(f"checkpoint val loss: {val_loss:.4f}")
    print()

    while True:
        try:
            user_input = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nbye")
            break

        if not user_input:
            continue
        lowered = user_input.lower()
        if lowered in {"quit", "exit"}:
            print("bye")
            break
        if lowered == "reset":
            session.reset()
            print("(history cleared)")
            continue
        if lowered == "history":
            if not session.history:
                print("(empty)")
            for user_text, bot_text in session.history:
                print(f"user: {user_text}\nbot: {bot_text}\n")
            continue

        try:
            _, intent, display = session.say_display(user_input)
        except KeyboardInterrupt:
            # Ctrl+C during a slow generation cancels just this reply. The
            # turn is not recorded because ChatSession.say appends to history
            # only after generation finishes.
            print("\n(generation interrupted)")
            continue
        print(f"bot> {display}")
        if intent and intent != "NONE":
            print(f" intent: {intent}")
|
|
|
|
|
# Start the interactive REPL only when this file is executed as a script
# (e.g. ``python chat.py``), not when the module is imported.
if __name__ == "__main__":
    run_repl()
|
|
|