prathamkode's picture
Upload folder using huggingface_hub
a330cfa verified
"""
Interactive chat with the exported Smartwatch LM.
Usage:
pip install torch tokenizers
python chat.py
"""
from __future__ import annotations
import re
import sys
from pathlib import Path
import torch
import config as cfg
from model import load_model
from reply_utils import build_prompt, extract_bot_reply_from_continuation, extract_intent_reply
class ChatSession:
    """Stateful multi-turn chat over the exported Smartwatch LM.

    Owns the model, tokenizer and device returned by ``load_model`` together
    with the conversation so far, which is fed back into every new prompt via
    ``build_prompt``.
    """

    def __init__(
        self,
        checkpoint_path: Path | None = None,
        tokenizer_path: Path | None = None,
        device: str | None = None,
        max_new_tokens: int | None = None,
        temperature: float | None = None,
        top_k: int | None = None,
    ):
        """Load the model and resolve the sampling settings.

        Args:
            checkpoint_path: Forwarded to ``load_model``; ``None`` presumably
                selects its default checkpoint (confirm in model.py).
            tokenizer_path: Forwarded to ``load_model``; ``None`` presumably
                selects its default tokenizer.
            device: Forwarded to ``load_model``, which returns the device
                actually used (stored as ``self.device``).
            max_new_tokens: Tokens sampled per reply. ``None`` falls back to
                ``cfg.SAMPLE_MAX_NEW_TOKENS``; any other value, including 0,
                is used as given.
            temperature: Sampling temperature; ``None`` falls back to
                ``cfg.SAMPLE_TEMPERATURE``.
            top_k: Top-k cutoff; ``None`` falls back to ``cfg.SAMPLE_TOP_K``.

        Raises:
            Anything ``load_model`` raises (e.g. ``FileNotFoundError``, which
            ``run_repl`` reports and exits on).
        """
        self.model, self.tokenizer, self.device = load_model(
            checkpoint_path, tokenizer_path, device
        )
        # Every override uses an explicit ``None`` check so a deliberate falsy
        # value (0, 0.0) is honoured instead of being swapped for the default.
        self.max_new_tokens = self._value_or_default(
            max_new_tokens, cfg.SAMPLE_MAX_NEW_TOKENS
        )
        self.temperature = self._value_or_default(temperature, cfg.SAMPLE_TEMPERATURE)
        self.top_k = self._value_or_default(top_k, cfg.SAMPLE_TOP_K)
        # Completed turns as (user_message, bot_reply), oldest first.
        self.history: list[tuple[str, str]] = []

    @staticmethod
    def _value_or_default(value, default):
        """Return *value* unless it is ``None``, in which case return *default*."""
        return default if value is None else value

    def reset(self) -> None:
        """Forget all previous turns; the next prompt starts a fresh conversation."""
        self.history.clear()

    @torch.no_grad()
    def say(self, user_message: str) -> str:
        """Generate the bot's reply to *user_message* and record the turn.

        The message is stripped first. A blank message returns ``""`` without
        touching the model or the history. Otherwise the prompt is built from
        the whole history plus the new message, a continuation is sampled, the
        reply is extracted from it, and ``(user_message, reply)`` is appended
        to ``self.history`` before the reply is returned.

        NOTE(review): history (and therefore the prompt) grows without bound;
        whether the model crops to its context window is not visible here.
        """
        user_message = user_message.strip()
        if not user_message:
            return ""
        prompt = build_prompt(self.history, user_message)
        start_ids = self.tokenizer.encode(prompt).ids
        # Batch of one sequence.
        x = torch.tensor([start_ids], dtype=torch.long, device=self.device)
        y = self.model.generate(
            x,
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_k=self.top_k,
        )
        # generate() appears to return the prompt followed by the sampled
        # tokens; keep only the newly generated part.
        new_ids = y[0, len(start_ids):].tolist()
        continuation = self.tokenizer.decode(new_ids)
        reply = extract_bot_reply_from_continuation(continuation)
        self.history.append((user_message, reply))
        return reply

    def say_display(self, user_message: str) -> tuple[str, str, str]:
        """Reply to *user_message* and split the reply for display.

        Returns:
            ``(raw_reply, intent, display_text)`` where ``raw_reply`` is what
            ``say`` returned and ``intent`` / ``display_text`` are the
            ``intent`` and ``template`` fields from ``extract_intent_reply``.
        """
        raw = self.say(user_message)
        parsed = extract_intent_reply(raw)
        return raw, parsed.intent, parsed.template
def print_banner() -> None:
    """Write the REPL greeting, the command summary and a divider to stdout."""
    banner_lines = (
        "Smartwatch LM chat — type a message and press Enter.",
        "Commands: quit/exit | reset (clear history) | history",
        "-" * 60,
    )
    for line in banner_lines:
        print(line)
def _read_best_val_loss(ckpt_path: Path) -> float | None:
    """Best-effort read of ``best_val_loss`` from a training checkpoint.

    Returns ``None`` when the file is missing, cannot be loaded, is not a
    dict, or has no ``best_val_loss`` entry. The value is only displayed, so a
    bad checkpoint must not stop the chat. The loaded checkpoint is released
    when this function returns instead of staying alive for the whole REPL.
    """
    if not ckpt_path.is_file():
        return None
    try:
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects, which can execute code — only use trusted checkpoints.
        # Left as-is because the checkpoint may contain non-tensor objects the
        # weights-only loader would reject; confirm before tightening.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except Exception as exc:  # display-only metadata: warn, don't crash
        print(f"warning: could not read {ckpt_path}: {exc}", file=sys.stderr)
        return None
    if not isinstance(checkpoint, dict):
        return None
    return checkpoint.get("best_val_loss")


def run_repl() -> None:
    """Run the interactive chat loop on stdin/stdout.

    Creates a ``ChatSession`` (printing the error and exiting with status 1 on
    ``FileNotFoundError``). It then prints the banner, the device and, when
    available, the checkpoint's best validation loss. After that it reads
    messages until ``quit``/``exit``, EOF, or Ctrl-C at the prompt. ``reset``
    clears the history and ``history`` prints it. Ctrl-C while a reply is
    being generated abandons only that reply.
    """
    try:
        session = ChatSession()
    except FileNotFoundError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    # NOTE(review): this reads the training checkpoint under cfg.OUTPUT_DIR,
    # which may differ from the checkpoint ChatSession() actually loaded —
    # confirm they match before trusting the printed loss.
    val_loss = _read_best_val_loss(cfg.OUTPUT_DIR / "checkpoint.pt")

    print_banner()
    print(f"device: {session.device}")
    if val_loss is not None:
        print(f"checkpoint val loss: {val_loss:.4f}")
    print()

    while True:
        try:
            user_input = input("you> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nbye")
            break
        if not user_input:
            continue

        command = user_input.lower()
        if command in {"quit", "exit"}:
            print("bye")
            break
        if command == "reset":
            session.reset()
            print("(history cleared)")
            continue
        if command == "history":
            if not session.history:
                print("(empty)")
            for user_text, bot_text in session.history:
                print(f"user: {user_text}\nbot: {bot_text}\n")
            continue

        try:
            _, intent, display = session.say_display(user_input)
        except KeyboardInterrupt:
            # Abort just this reply. ChatSession.say appends to the history
            # only after generation finishes, so the interrupted turn is not
            # recorded.
            print("\n(interrupted)")
            continue
        print(f"bot> {display}")
        if intent and intent != "NONE":
            print(f" intent: {intent}")
if __name__ == "__main__":
    # Start the interactive chat only when run as a script, not when imported.
    run_repl()