| | |
| | """Interactive chat and demo inference for Role SLM.""" |
| |
|
| | import torch |
| | from tokenizers import Tokenizer |
| | from config import cfg |
| | from model import RoleSLM |
| |
|
| |
|
def load_model(checkpoint_name="best_model.pt"):
    """Load a trained RoleSLM checkpoint and its tokenizer.

    Config overrides stored in the checkpoint are applied to ``cfg``
    before the model is built, so the architecture matches the weights.

    Returns:
        (model, tokenizer, device) — model is in eval mode on ``cfg.device``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    device = torch.device(cfg.device)
    ckpt_path = cfg.checkpoint_dir / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Replay saved hyperparameters onto the live config (unknown keys ignored).
    for name, value in checkpoint.get("config", {}).items():
        if hasattr(cfg, name):
            setattr(cfg, name, value)

    model = RoleSLM()
    # strict=False tolerates missing/unexpected keys in the state dict.
    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    model = model.to(device)
    model.eval()

    tokenizer_path = cfg.tokenizer_dir / cfg.tokenizer_filename
    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device
| |
|
| |
|
def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9):
    """Generate a completion for ``prompt`` and return the decoded new text.

    Falsy ``max_tokens`` falls back to ``min(cfg.max_new_tokens, 512)``.
    Special tokens are scrubbed from the decoded continuation.
    """
    if not max_tokens:
        max_tokens = min(cfg.max_new_tokens, 512)

    token_ids = tokenizer.encode(prompt).ids
    # Id 3 is presumably the tokenizer's <eos> — TODO confirm; a trailing
    # one is dropped so the model continues the prompt instead of stopping.
    if token_ids and token_ids[-1] == 3:
        token_ids = token_ids[:-1]

    prompt_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
    prompt_len = prompt_tensor.shape[1]

    with torch.no_grad():
        generated = model.generate(
            prompt_tensor,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Decode only the newly generated suffix, then strip special markers.
    continuation = generated[0][prompt_len:].tolist()
    text = tokenizer.decode(continuation)
    for marker in ("<eos>", "<bos>", "<pad>"):
        text = text.replace(marker, "")
    return text.strip()
| |
|
| |
|
| | DEMO_PROMPTS = ['Object-oriented design principles include', 'Microservices architecture benefits include', 'The SOLID principles in software engineering are', 'Database indexing improves query performance by', 'RESTful API design best practices include'] |
| |
|
| |
|
def demo_generation(model, tokenizer, device):
    """Run every prompt in DEMO_PROMPTS through the model, printing results."""
    print(f"\n{'='*60}")
    print(f"Demo: {cfg.domain_name}-SLM Inference")
    print(f"{'='*60}\n")
    for idx, demo_prompt in enumerate(DEMO_PROMPTS, start=1):
        print(f"[{idx}] Prompt: {demo_prompt}")
        reply = generate_response(model, tokenizer, device, demo_prompt, max_tokens=256)
        print(f" Response: {reply}\n")
| |
|
| |
|
def interactive_chat():
    """Run a REPL that feeds user prompts to the model.

    Commands: 'quit' exits, 'demo' runs the canned demo prompts.
    Ctrl-C or end-of-input (closed/piped stdin) ends the session cleanly.
    """
    print("Loading model...")
    model, tokenizer, device = load_model()
    print(f"\n{'='*60}")
    print(f"{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit, 'demo' for demos)")
    print(f"{'='*60}\n")
    while True:
        try:
            user_input = input("You: ").strip()
            if not user_input:
                continue
            if user_input.lower() == "quit":
                print("Goodbye!")
                break
            if user_input.lower() == "demo":
                demo_generation(model, tokenizer, device)
                continue
            response = generate_response(model, tokenizer, device, user_input, max_tokens=512)
            print(f"SLM: {response}\n")
        except (KeyboardInterrupt, EOFError):
            # Fix: input() raises EOFError on exhausted/piped stdin, which
            # previously crashed with a traceback; exit like Ctrl-C instead.
            print("\nGoodbye!")
            break
| |
|
| |
|
| | if __name__ == "__main__": |
| | interactive_chat() |
| |
|