File size: 3,282 Bytes
7d0c7c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
"""Interactive chat and demo inference for Role SLM."""

import torch
from tokenizers import Tokenizer
from config import cfg
from model import RoleSLM


def load_model(checkpoint_name="best_model.pt"):
    """Load a trained RoleSLM checkpoint together with its tokenizer.

    Config entries saved in the checkpoint are copied onto the global ``cfg``
    (only for attributes ``cfg`` already has) *before* ``RoleSLM()`` is
    constructed, so the model is built with the checkpoint's settings.
    The torch device is resolved from ``cfg.device`` before that override.

    Args:
        checkpoint_name: File name of the checkpoint inside
            ``cfg.checkpoint_dir``.

    Returns:
        ``(model, tokenizer, device)`` — the model is in eval mode and has
        been moved to ``device``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    device = torch.device(cfg.device)
    ckpt_path = cfg.checkpoint_dir / checkpoint_name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # SECURITY: weights_only=False unpickles arbitrary Python objects, so only
    # load checkpoints from trusted sources. (Presumably required because the
    # saved config holds non-tensor objects — confirm before switching to True.)
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Apply saved settings; unknown keys are ignored, a missing/None config is a no-op.
    for key, val in (ckpt.get("config") or {}).items():
        if hasattr(cfg, key):
            setattr(cfg, key, val)

    model = RoleSLM()
    # strict=False tolerates key mismatches, but report them: a silently
    # partially-initialized model produces garbage with no other signal.
    incompatible = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if incompatible.missing_keys:
        print(f"Warning: {len(incompatible.missing_keys)} parameter(s) missing from "
              f"checkpoint, left at init values: {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"Warning: {len(incompatible.unexpected_keys)} unexpected key(s) in "
              f"checkpoint were ignored: {incompatible.unexpected_keys[:5]}")
    model = model.to(device)
    model.eval()

    tok_path = cfg.tokenizer_dir / cfg.tokenizer_filename
    tokenizer = Tokenizer.from_file(str(tok_path))
    print(f"Model loaded: {model.count_parameters()/1e6:.2f}M params")
    return model, tokenizer, device


def generate_response(model, tokenizer, device, prompt, max_tokens=None,
                      temperature=0.8, top_k=50, top_p=0.9):
    """Generate a continuation of *prompt* and return it as cleaned text.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...,
            temperature=..., top_k=..., top_p=...)`` that returns the input
            ids followed by the newly generated ids.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``token_to_id``.
        device: Torch device the prompt ids are placed on.
        prompt: Text to continue.
        max_tokens: Maximum number of new tokens. ``None`` means
            ``min(cfg.max_new_tokens, 512)``; an explicit ``0`` is honoured.
        temperature, top_k, top_p: Sampling parameters forwarded unchanged
            to ``model.generate``.

    Returns:
        The decoded new tokens, cut at the first end-of-sequence token, with
        any ``<eos>``/``<bos>``/``<pad>`` text removed and whitespace stripped.

    Raises:
        ValueError: If the prompt encodes to no tokens at all.
    """
    if max_tokens is None:
        max_tokens = min(cfg.max_new_tokens, 512)

    # Look the EOS id up in the vocabulary instead of assuming a fixed id;
    # fall back to the previously hard-coded 3 when "<eos>" is not a token.
    eos_id = tokenizer.token_to_id("<eos>")
    if eos_id is None:
        eos_id = 3

    ids = tokenizer.encode(prompt).ids
    # The encoding may end with EOS; drop it so the model continues the
    # prompt instead of starting after an end marker.
    if ids and ids[-1] == eos_id:
        ids = ids[:-1]
    if not ids:
        raise ValueError("Prompt produced no tokens to condition generation on")
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    input_len = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_tokens,
                                    temperature=temperature, top_k=top_k, top_p=top_p)

    new_tokens = output_ids[0][input_len:].tolist()
    # Anything sampled after the first EOS is not part of this answer.
    if eos_id in new_tokens:
        new_tokens = new_tokens[:new_tokens.index(eos_id)]
    response = tokenizer.decode(new_tokens)
    for marker in ("<eos>", "<bos>", "<pad>"):
        response = response.replace(marker, "")
    return response.strip()


# Fixed prompts that demo_generation feeds to the model, one per demo entry.
DEMO_PROMPTS = [
    'Object-oriented design principles include',
    'Microservices architecture benefits include',
    'The SOLID principles in software engineering are',
    'Database indexing improves query performance by',
    'RESTful API design best practices include',
]


def demo_generation(model, tokenizer, device):
    """Print a numbered prompt/response pair for every entry in DEMO_PROMPTS.

    Output is preceded by a banner naming ``cfg.domain_name``; each prompt
    is completed via ``generate_response`` with a 256-token budget.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Demo: {cfg.domain_name}-SLM Inference")
    print(f"{rule}\n")
    for number, demo_prompt in enumerate(DEMO_PROMPTS, start=1):
        print(f"[{number}] Prompt: {demo_prompt}")
        answer = generate_response(model, tokenizer, device, demo_prompt, max_tokens=256)
        print(f"    Response: {answer}\n")


def interactive_chat():
    """Load the default checkpoint and run a console chat loop.

    Commands (case-insensitive): ``quit`` exits and ``demo`` runs the demo
    prompts. Empty lines are ignored. Any other input is answered by the
    model with a 512-token budget. Ctrl-C and end of input (Ctrl-D, or
    exhausted piped stdin) also exit cleanly.
    """
    print("Loading model...")
    model, tokenizer, device = load_model()
    print(f"\n{'='*60}")
    print(f"{cfg.domain_name}-SLM Interactive Chat (type 'quit' to exit, 'demo' for demos)")
    print(f"{'='*60}\n")
    while True:
        try:
            user_input = input("You: ").strip()
            if not user_input:
                continue
            if user_input.lower() == "quit":
                print("Goodbye!")
                break
            if user_input.lower() == "demo":
                demo_generation(model, tokenizer, device)
                continue
            response = generate_response(model, tokenizer, device, user_input, max_tokens=512)
            print(f"SLM: {response}\n")
        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError at end of stream (Ctrl-D / piped stdin);
            # treat it like Ctrl-C instead of crashing with a traceback.
            print("\nGoodbye!")
            break


# Script entry point: start the interactive chat only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    interactive_chat()