TreeLeek commited on
Commit
ac8476e
·
verified ·
1 Parent(s): 85da6d4

Upload chat_stage_b.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chat_stage_b.py +148 -0
chat_stage_b.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ chat_stage_b.py — Chat with Leek using the Stage B checkpoint.
4
+
5
+ She responds to instructions now, not just text completion.
6
+ Type your message, press Enter. Type 'quit' to exit.
7
+
8
+ Usage:
9
+ python3 chat_stage_b.py --block-size 512
10
+ python3 chat_stage_b.py --block-size 512 --temp 0.7
11
+ """
12
+
13
+ import argparse
14
+ import sys
15
+ from pathlib import Path
16
+
17
+ import mlx.core as mx
18
+ import mlx.utils as mlx_utils
19
+ import numpy as np
20
+ import sentencepiece as spm
21
+
22
+ ROOT = Path(__file__).parent
23
+ sys.path.insert(0, str(ROOT))
24
+
25
+ from leeknet_500m import LeekNet500M, TOKENIZER_MODEL, CKPT_DIR, BLOCK_SIZE
26
+
27
+
28
+ def load_best_checkpoint(model):
29
+ ckpts = sorted(CKPT_DIR.glob('stage_b_step*_best.npz'),
30
+ key=lambda p: int(p.stem.split('step')[1].split('_')[0]))
31
+ if not ckpts:
32
+ ckpts = sorted(CKPT_DIR.glob('stage_b_step*.npz'),
33
+ key=lambda p: int(p.stem.split('step')[1].split('_')[0]))
34
+ if not ckpts:
35
+ print('no Stage B checkpoint found')
36
+ sys.exit(1)
37
+ latest = ckpts[-1]
38
+ print(f'loading: {latest.name}')
39
+ w = np.load(latest)
40
+ model.load_weights([(k, mx.array(v)) for k, v in w.items()])
41
+
42
+
43
+ def generate(model, tok, prompt_ids, max_new_tokens, temperature, block_size):
44
+ ctx = mx.array([prompt_ids], dtype=mx.int32)
45
+ generated = []
46
+
47
+ for _ in range(max_new_tokens):
48
+ if ctx.shape[1] > block_size:
49
+ ctx = ctx[:, -block_size:]
50
+
51
+ logits = model(ctx)
52
+ next_logits = logits[0, -1]
53
+
54
+ if temperature <= 0.0:
55
+ next_id = int(mx.argmax(next_logits).item())
56
+ else:
57
+ next_logits = next_logits / temperature
58
+ probs = mx.softmax(next_logits)
59
+ mx.eval(probs)
60
+ p = np.array(probs.tolist())
61
+ p = p / p.sum()
62
+ next_id = int(np.random.choice(len(p), p=p))
63
+
64
+ if next_id == tok.eos_id():
65
+ break
66
+
67
+ generated.append(next_id)
68
+ ctx = mx.concatenate([ctx, mx.array([[next_id]])], axis=1)
69
+
70
+ full_text = tok.decode(prompt_ids + generated)
71
+ prev_text = tok.decode(prompt_ids + generated[:-1])
72
+ print(full_text[len(prev_text):], end='', flush=True)
73
+
74
+ print()
75
+ return generated
76
+
77
+
78
+ def main():
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument('--block-size', type=int, default=512)
81
+ parser.add_argument('--temp', type=float, default=0.8)
82
+ parser.add_argument('--max-tokens', type=int, default=400)
83
+ parser.add_argument('--system', type=str, default=None,
84
+ help='system prompt prepended before conversation')
85
+ parser.add_argument('--no-system', action='store_true',
86
+ help='disable default system prompt')
87
+ args = parser.parse_args()
88
+
89
+ print('loading tokenizer...')
90
+ tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))
91
+
92
+ print('building model...')
93
+ model = LeekNet500M(block_size=args.block_size)
94
+ load_best_checkpoint(model)
95
+
96
+ default_system = (
97
+ "You are a helpful, direct, and honest assistant. "
98
+ "Answer questions clearly and accurately. "
99
+ "Be concise. Do not ramble or use flowery language."
100
+ )
101
+
102
+ if args.no_system:
103
+ system = None
104
+ elif args.system:
105
+ system = args.system
106
+ else:
107
+ system = default_system
108
+
109
+ print(f'\nready. block_size={args.block_size} temp={args.temp}')
110
+ if system:
111
+ print(f'system: {system}')
112
+ print('type your message and press Enter. quit to exit.\n')
113
+
114
+ history = []
115
+ if system:
116
+ history.append(f'System: {system}')
117
+
118
+ while True:
119
+ try:
120
+ user_input = input('Human: ').strip()
121
+ except (EOFError, KeyboardInterrupt):
122
+ print()
123
+ break
124
+
125
+ if not user_input or user_input.lower() in ('quit', 'exit', 'q'):
126
+ break
127
+
128
+ history.append(f'Human: {user_input}')
129
+ prompt = '\n'.join(history) + '\nAssistant:'
130
+
131
+ prompt_ids = tok.encode(prompt)
132
+
133
+ print('Assistant: ', end='', flush=True)
134
+ generated = generate(model, tok, prompt_ids, args.max_tokens, args.temp, args.block_size)
135
+
136
+ response_text = tok.decode(generated).strip()
137
+ history.append(f'Assistant: {response_text}')
138
+
139
+ # keep history from growing past block_size
140
+ while len(tok.encode('\n'.join(history))) > args.block_size - 100:
141
+ if len(history) > 2:
142
+ history = history[2:]
143
+ else:
144
+ break
145
+
146
+
147
+ if __name__ == '__main__':
148
+ main()