| |
| """Interactive chat and demo inference for Role SLM.""" |
|
|
| import torch |
| from tokenizers import Tokenizer |
| from config import cfg |
| from model import RoleSLM |
|
|
|
|
def load_model(checkpoint_name="best_model.pt"):
    """Load a trained RoleSLM checkpoint and its tokenizer.

    Hyperparameters saved inside the checkpoint's ``config`` dict are copied
    back onto the global ``cfg`` before the model is constructed, so the
    rebuilt architecture matches the one that was trained.

    Args:
        checkpoint_name: file name inside ``cfg.checkpoint_dir``.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model is on ``device``
        and in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    device = torch.device(cfg.device)
    ckpt_path = cfg.checkpoint_dir / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # weights_only=False: checkpoint also stores a plain-dict config payload.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Restore training-time hyperparameters onto the live config object.
    saved_config = checkpoint.get("config", {})
    for name, value in saved_config.items():
        if hasattr(cfg, name):
            setattr(cfg, name, value)

    model = RoleSLM()
    # strict=False tolerates keys added/removed between code revisions.
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    model = model.to(device)
    model.eval()

    tokenizer = Tokenizer.from_file(str(cfg.tokenizer_dir / cfg.tokenizer_filename))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device
|
|
|
|
def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9, eos_token_id=3):
    """Generate a completion for ``prompt`` and return the decoded text.

    Args:
        model: RoleSLM instance providing a ``generate`` method.
        tokenizer: ``tokenizers.Tokenizer`` used for encode/decode.
        device: torch device the input tensor should live on.
        prompt: raw prompt string.
        max_tokens: cap on new tokens; defaults to ``min(cfg.max_new_tokens, 512)``.
        temperature: sampling temperature passed to ``model.generate``.
        top_k: top-k sampling cutoff.
        top_p: nucleus sampling cutoff.
        eos_token_id: id of the end-of-sequence token the tokenizer may append
            to the encoded prompt (previously a hard-coded magic ``3``).
            Stripped so the model continues the prompt instead of stopping.

    Returns:
        The generated continuation only (prompt excluded), with special-token
        markers removed and surrounding whitespace stripped.
    """
    max_tokens = max_tokens or min(cfg.max_new_tokens, 512)
    encoded = tokenizer.encode(prompt)
    ids = encoded.ids
    # Drop a trailing EOS the tokenizer may have appended; otherwise the
    # model would treat the prompt as already finished.
    if ids and ids[-1] == eos_token_id:
        ids = ids[:-1]
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    input_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_tokens,
                                    temperature=temperature, top_k=top_k, top_p=top_p)

    # Keep only the newly generated tokens, then scrub special-token markers.
    new_tokens = output_ids[0][input_len:].tolist()
    response = tokenizer.decode(new_tokens)
    for marker in ("<eos>", "<bos>", "<pad>"):
        response = response.replace(marker, "")
    return response.strip()
|
|
|
|
| DEMO_PROMPTS = ['Object-oriented design principles include', 'Microservices architecture benefits include', 'The SOLID principles in software engineering are', 'Database indexing improves query performance by', 'RESTful API design best practices include'] |
|
|
|
|
def demo_generation(model, tokenizer, device):
    """Run every demo prompt through the model and print the completions."""
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Demo: {cfg.domain_name}-SLM Inference")
    print(f"{banner}\n")
    for index, prompt in enumerate(DEMO_PROMPTS, start=1):
        print(f"[{index}] Prompt: {prompt}")
        answer = generate_response(model, tokenizer, device, prompt, max_tokens=256)
        print(f" Response: {answer}\n")
|
|
|
|
def interactive_chat():
    """REPL-style chat loop against the loaded model.

    Commands: 'quit' exits, 'demo' runs the canned demo prompts; anything
    else is treated as a prompt for generation. Exits cleanly on Ctrl-C
    (KeyboardInterrupt) and on end-of-input (EOFError, e.g. piped stdin) —
    the original code crashed with a traceback on EOF.
    """
    print("Loading model...")
    model, tokenizer, device = load_model()
    print(f"\n{'='*60}")
    print(f"{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit, 'demo' for demos)")
    print(f"{'='*60}\n")
    while True:
        try:
            user_input = input("You: ").strip()
            if not user_input:
                continue
            if user_input.lower() == "quit":
                print("Goodbye!")
                break
            if user_input.lower() == "demo":
                demo_generation(model, tokenizer, device)
                continue
            response = generate_response(model, tokenizer, device, user_input, max_tokens=512)
            print(f"SLM: {response}\n")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C or closed stdin: exit gracefully instead of crashing.
            print("\nGoodbye!")
            break
|
|
|
|
# Script entry point: launch the interactive chat REPL.
if __name__ == "__main__":
    interactive_chat()
|
|