#!/usr/bin/env python3
"""Interactive chat and demo inference for Role SLM."""

import torch
from tokenizers import Tokenizer

from config import cfg
from model import RoleSLM


def load_model(checkpoint_name="best_model.pt"):
    """Load a trained checkpoint and its tokenizer; return (model, tokenizer, device)."""
    device = torch.device(cfg.device)
    ckpt_path = cfg.checkpoint_dir / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Restore any config overrides that were saved alongside the checkpoint.
    for key, val in ckpt.get("config", {}).items():
        if hasattr(cfg, key):
            setattr(cfg, key, val)

    model = RoleSLM()
    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model = model.to(device)
    model.eval()

    tok_path = cfg.tokenizer_dir / cfg.tokenizer_filename
    tokenizer = Tokenizer.from_file(str(tok_path))

    print(f"Model loaded: {model.count_parameters() / 1e6:.2f}M params")
    return model, tokenizer, device


def generate_response(model, tokenizer, device, prompt,
                      max_tokens=None, temperature=0.8, top_k=50, top_p=0.9):
    """Generate a completion for `prompt` and return only the newly decoded text."""
    max_tokens = max_tokens or min(cfg.max_new_tokens, 512)

    encoded = tokenizer.encode(prompt)
    ids = encoded.ids
    # Drop a trailing end-of-sequence token (id 3 in this tokenizer's vocab)
    # so generation continues the prompt instead of stopping immediately.
    if ids and ids[-1] == 3:
        ids = ids[:-1]

    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    input_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Decode only the generated continuation, not the echoed prompt.
    new_tokens = output_ids[0][input_len:].tolist()
    response = tokenizer.decode(new_tokens)

    # Strip special-token strings from the decoded text. The exact names
    # depend on the tokenizer; "<bos>"/"<eos>"/"<pad>" are assumed here.
    for special in ("<bos>", "<eos>", "<pad>"):
        response = response.replace(special, "")
    return response.strip()


DEMO_PROMPTS = [
    "The OSI model layers define",
    "OSPF routing protocol works by",
    "Network segmentation improves security by",
    "Software-defined networking enables",
    "VPN tunneling protocols include",
]


def demo_generation(model, tokenizer, device):
    """Run the canned demo prompts and print each response."""
    print(f"\n{'=' * 60}")
    print(f"Demo: {cfg.domain_name}-SLM Inference")
    print(f"{'=' * 60}\n")
    for i, prompt in enumerate(DEMO_PROMPTS, 1):
        print(f"[{i}] Prompt: {prompt}")
        response = generate_response(model, tokenizer, device, prompt, max_tokens=256)
        print(f"    Response: {response}\n")


def interactive_chat():
    """Simple REPL: read a prompt, generate, print; 'demo' runs the demo prompts."""
    print("Loading model...")
    model, tokenizer, device = load_model()

    print(f"\n{'=' * 60}")
    print(f"{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit, 'demo' for demos)")
    print(f"{'=' * 60}\n")

    while True:
        try:
            user_input = input("You: ").strip()
            if not user_input:
                continue
            if user_input.lower() == "quit":
                print("Goodbye!")
                break
            if user_input.lower() == "demo":
                demo_generation(model, tokenizer, device)
                continue
            response = generate_response(model, tokenizer, device, user_input, max_tokens=512)
            print(f"SLM: {response}\n")
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break


if __name__ == "__main__":
    interactive_chat()
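
# Example: using the generator programmatically instead of via the REPL.
# A minimal sketch, kept in a comment so importing this module stays
# side-effect-free. The module name "chat" is an assumption; substitute
# this file's actual name, and note the checkpoint/tokenizer paths from
# `config.cfg` must exist locally.
#
#     from chat import load_model, generate_response
#
#     model, tokenizer, device = load_model("best_model.pt")
#     text = generate_response(model, tokenizer, device,
#                              "VLAN trunking allows",
#                              max_tokens=128, temperature=0.7)
#     print(text)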