#!/usr/bin/env python3
"""Interactive chat and demo inference for Role SLM."""
import torch
from tokenizers import Tokenizer
from config import cfg
from model import RoleSLM
def load_model(checkpoint_name="best_model.pt"):
    """Restore a trained RoleSLM checkpoint plus its tokenizer.

    Args:
        checkpoint_name: File name inside ``cfg.checkpoint_dir`` to load.

    Returns:
        A ``(model, tokenizer, device)`` triple ready for inference.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    device = torch.device(cfg.device)
    ckpt_path = cfg.checkpoint_dir / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # weights_only=False because the checkpoint carries a config dict as well.
    # NOTE(review): full unpickling — only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # Overlay hyperparameters saved with the checkpoint onto the live config,
    # but only for attributes the config object actually defines.
    saved = ckpt.get("config", {})
    for key in saved:
        if hasattr(cfg, key):
            setattr(cfg, key, saved[key])
    model = RoleSLM()
    # strict=False tolerates minor state-dict mismatches between versions.
    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model = model.to(device)
    model.eval()
    tokenizer = Tokenizer.from_file(str(cfg.tokenizer_dir / cfg.tokenizer_filename))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device
def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9):
    """Generate a completion for *prompt* and return the decoded text.

    Args:
        model: Model exposing a ``generate(input_ids, ...)`` method.
        tokenizer: ``tokenizers.Tokenizer``-like object (encode/decode/token_to_id).
        device: Torch device the input tensor is placed on.
        prompt: Raw text to complete.
        max_tokens: Generation budget; defaults to ``min(cfg.max_new_tokens, 512)``.
        temperature, top_k, top_p: Sampling parameters forwarded to ``generate``.

    Returns:
        The decoded continuation with special tokens removed and whitespace stripped.
    """
    max_tokens = max_tokens or min(cfg.max_new_tokens, 512)
    ids = tokenizer.encode(prompt).ids
    # Look the EOS id up from the tokenizer instead of hard-coding it; fall
    # back to the historical id 3 when the vocab has no "<eos>" entry.
    eos_id = tokenizer.token_to_id("<eos>")
    if eos_id is None:
        eos_id = 3
    # Drop a trailing EOS so the model continues the prompt instead of stopping.
    if ids and ids[-1] == eos_id:
        ids = ids[:-1]
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    input_len = input_ids.shape[1]
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_tokens,
                                    temperature=temperature, top_k=top_k, top_p=top_p)
    # Keep only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0][input_len:].tolist()
    response = tokenizer.decode(new_tokens)
    # Remove any special tokens the decoder left in the text.
    for special in ("<eos>", "<bos>", "<pad>"):
        response = response.replace(special, "")
    return response.strip()
# Canned prompts used by demo_generation() to showcase the model.
DEMO_PROMPTS = [
    'Object-oriented design principles include',
    'Microservices architecture benefits include',
    'The SOLID principles in software engineering are',
    'Database indexing improves query performance by',
    'RESTful API design best practices include',
]
def demo_generation(model, tokenizer, device):
    """Run every canned DEMO_PROMPTS entry through the model and print results."""
    print(f"\n{'='*60}")
    print(f"Demo: {cfg.domain_name}-SLM Inference")
    print(f"{'='*60}\n")
    # Numbered from 1 for human-readable output.
    for idx, prompt in enumerate(DEMO_PROMPTS, start=1):
        print(f"[{idx}] Prompt: {prompt}")
        answer = generate_response(model, tokenizer, device, prompt, max_tokens=256)
        print(f" Response: {answer}\n")
def interactive_chat():
    """Console REPL for the model.

    Commands: 'quit' exits, 'demo' runs the canned demo prompts; anything
    else is treated as a prompt. Ctrl-C at any point also exits cleanly.
    """
    print("Loading model...")
    model, tokenizer, device = load_model()
    print(f"\n{'='*60}")
    print(f"{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit, 'demo' for demos)")
    print(f"{'='*60}\n")
    while True:
        # The whole iteration sits inside the try so a Ctrl-C during
        # generation (not just at the input() call) exits gracefully.
        try:
            user_input = input("You: ").strip()
            if not user_input:
                continue
            command = user_input.lower()
            if command == "quit":
                print("Goodbye!")
                break
            if command == "demo":
                demo_generation(model, tokenizer, device)
                continue
            reply = generate_response(model, tokenizer, device, user_input, max_tokens=512)
            print(f"SLM: {reply}\n")
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
# Start the REPL only when executed as a script, not on import.
if __name__ == "__main__":
    interactive_chat()