| """ |
| Интерактивный REPL для болтовни с обученной моделью. |
| |
| Запуск: |
| python chat.py --out_dir=out-chat50m |
| python chat.py --out_dir=out-chat50m --temperature=0.8 --top_k=50 |
| python chat.py --out_dir=out-chat50m --system="Ты дружелюбный ассистент." |
| |
| Команды внутри REPL (в скобках -- однобуквенные алиасы): |
| /help /h показать список команд |
| /show /s показать текущие параметры сэмплинга |
| /reset /r сбросить историю диалога |
| /system <т> /sys <т> сменить system-промпт + reset |
| /temp <f> /t <f> temperature (>0) |
| /top_p <f> /p <f> nucleus sampling (0..1] |
| /top_k <i> /k <i> top-k (0 = выкл) |
| /rep <f> /rp <f> repetition_penalty (>=1.0) |
| /max_tokens <i> /mt <i> лимит длины ответа |
| /preset <n> /ps <n> creative | balanced | precise |
| /quit /q выйти |
| """ |
|
|
| import os |
| import sys |
| import io |
| import argparse |
| import pickle |
| import torch |
| import sentencepiece as spm |
|
|
| from model import GPTConfig, GPT |
|
|
# On Windows the console streams may use a legacy code page; rewrap
# stdin/stdout over the raw buffers as UTF-8 so Cyrillic text round-trips.
if sys.platform == 'win32':
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
|
|
# Special tokens of the chat format (must match the tokenizer's vocabulary).
SYS_TOK = '<|system|>'
USR_TOK = '<|user|>'
ASS_TOK = '<|assistant|>'
EOT_TOK = '<|endoftext|>'


def build_prompt(history, system):
    """Serialize the dialog into the model's chat format.

    `history` is a list of (role, content) pairs; `system` is the system
    prompt (skipped when empty). The result always ends with the
    <|assistant|> tag so the model continues with the assistant's reply.
    """
    pieces = [f'{SYS_TOK}{system}{EOT_TOK}'] if system else []
    for role, text in history:
        tag = USR_TOK if role == 'user' else ASS_TOK
        pieces.append(f'{tag}{text}{EOT_TOK}')
    return ''.join(pieces) + ASS_TOK
|
|
|
|
@torch.no_grad()
def generate_until_eot(model, idx, eot_id, max_new_tokens, temperature, top_k, top_p,
                       repetition_penalty, repetition_window, device, on_token=None):
    """Sample from `model` until `eot_id` is drawn or `max_new_tokens` is reached.

    Applies, in order: repetition penalty over the last `repetition_window`
    tokens, temperature scaling, top-k filtering, then nucleus (top-p)
    filtering, and finally multinomial sampling.

    Args:
        model: module where model(idx) -> (logits, loss) and
            model.config.block_size is the context limit.
        idx: (1, T) LongTensor with the encoded prompt.
        eot_id: token id that stops generation (never appended or returned).
        max_new_tokens: hard cap on generated tokens.
        temperature: softmax temperature; clamped to >= 1e-6.
        top_k: keep only the k most likely tokens (None or 0 disables).
        top_p: nucleus sampling threshold in (0, 1) (None/edge values disable).
        repetition_penalty: >1.0 discounts recently seen tokens (1.0 disables).
        repetition_window: how many trailing tokens the penalty looks at.
        device: unused; kept for backward-compatible call signature.
        on_token: optional callback on_token(new_id, all_new_ids) fired after
            each accepted token (used by the caller for streaming output).

    Returns:
        list[int]: the newly generated token ids, without the eot token.
    """
    new_ids = []
    block_size = model.config.block_size
    for _ in range(max_new_tokens):
        # Crop the context from the left if it exceeds the model's window.
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        logits, _ = model(idx_cond)
        # Clone so in-place edits below don't touch the model's output tensor.
        logits = logits[:, -1, :].clone()

        # Repetition penalty: divide positive / multiply negative logits of
        # tokens seen in the recent window (CTRL-style penalty).
        if repetition_penalty and repetition_penalty != 1.0:
            recent = idx[0, -repetition_window:].tolist()
            if recent:
                uniq = list(set(recent))
                t = torch.tensor(uniq, device=logits.device, dtype=torch.long)
                cur = logits[0, t]
                cur = torch.where(cur > 0, cur / repetition_penalty, cur * repetition_penalty)
                logits[0, t] = cur

        # Temperature (guard against division by zero).
        logits = logits / max(temperature, 1e-6)

        # Top-k: drop everything below the k-th largest logit.
        if top_k is not None and top_k > 0:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('inf')

        # Top-p (nucleus): keep the smallest prefix of sorted tokens whose
        # cumulative probability exceeds top_p; always keep the top token.
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            mask = cum > top_p
            mask[..., 1:] = mask[..., :-1].clone()  # shift right: keep first token past the threshold
            mask[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(mask, -float('inf'))
            logits = torch.full_like(logits, -float('inf')).scatter(-1, sorted_idx, sorted_logits)

        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        nid = int(next_id.item())
        if nid == eot_id:
            break
        new_ids.append(nid)
        idx = torch.cat([idx, next_id], dim=1)
        if on_token is not None:
            on_token(nid, new_ids)
    return new_ids
|
|
|
|
def main():
    # --- CLI arguments: paths plus default sampling parameters -------------
    ap = argparse.ArgumentParser()
    ap.add_argument('--out_dir', default='out-chat50m')
    ap.add_argument('--data_dir', default='data/chat_ru')
    ap.add_argument('--system', default='Ты вежливый и полезный ассистент. Отвечай по-русски.')
    ap.add_argument('--temperature', type=float, default=0.7)
    ap.add_argument('--top_k', type=int, default=40)
    ap.add_argument('--top_p', type=float, default=0.9)
    ap.add_argument('--repetition_penalty', type=float, default=1.15,
                    help='1.0 = выкл; 1.1-1.3 типичные значения')
    ap.add_argument('--repetition_window', type=int, default=128,
                    help='в каком окне последних токенов штрафовать повторы')
    ap.add_argument('--max_new_tokens', type=int, default=512)
    ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    ap.add_argument('--dtype', default='bfloat16')
    args = ap.parse_args()

    # --- Tokenizer: SentencePiece model + meta.pkl with special-token ids --
    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data_dir, 'tokenizer.model'))
    with open(os.path.join(args.data_dir, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)
    eot_id = meta['special_tokens']['endoftext']
    print(f'tokenizer ok, vocab={sp.get_piece_size()}, eot_id={eot_id}')

    # --- Model: restore weights from the training checkpoint ---------------
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    ckpt_path = os.path.join(args.out_dir, 'ckpt.pt')
    print(f'loading checkpoint: {ckpt_path}')
    ckpt = torch.load(ckpt_path, map_location=args.device, weights_only=False)
    gptconf = GPTConfig(**ckpt['model_args'])
    model = GPT(gptconf)
    sd = ckpt['model']
    # Strip the '_orig_mod.' prefix that torch.compile() adds to state-dict keys.
    for k in list(sd.keys()):
        if k.startswith('_orig_mod.'):
            sd[k[len('_orig_mod.'):]] = sd.pop(k)
    model.load_state_dict(sd)
    model.eval()
    model.to(args.device)
    print(f'model: {model.get_num_params()/1e6:.1f}M params, block_size={model.config.block_size}')

    # Autocast context used around generation (mixed precision).
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
    autocast = torch.amp.autocast(device_type=('cuda' if 'cuda' in args.device else 'cpu'),
                                  dtype=ptdtype)

    # Mutable sampling parameters; the REPL commands below edit this dict.
    params = dict(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        repetition_window=args.repetition_window,
        max_new_tokens=args.max_new_tokens,
    )

    # Named parameter bundles for the /preset command.
    PRESETS = {
        'creative': dict(temperature=1.0, top_k=80, top_p=0.95, repetition_penalty=1.10),
        'balanced': dict(temperature=0.7, top_k=40, top_p=0.90, repetition_penalty=1.15),
        'precise': dict(temperature=0.35, top_k=20, top_p=0.85, repetition_penalty=1.25),
    }

    HELP = (
        'Команды (в скобках -- однобуквенные алиасы):\n'
        ' /help /h показать эту справку\n'
        ' /show /s показать текущие параметры\n'
        ' /reset /r сбросить историю диалога\n'
        ' /system <т> /sys <т> сменить system-промпт + reset\n'
        ' /temp <f> /t <f> temperature (>0)\n'
        ' /top_p <f> /p <f> nucleus sampling (0..1]\n'
        ' /top_k <i> /k <i> top-k (0 = выкл)\n'
        ' /rep <f> /rp <f> repetition_penalty (>=1.0)\n'
        ' /max_tokens<i> /mt <i> лимит длины ответа\n'
        ' /preset <n> /ps <n> ' + ' | '.join(PRESETS.keys()) + '\n'
        ' /quit /q выйти'
    )

    # Alias -> canonical command name.
    CANONICAL = {
        '/h': '/help', '/s': '/show', '/r': '/reset',
        '/q': '/quit', '/exit': '/quit',
        '/sys': '/system', '/t': '/temp', '/p': '/top_p', '/k': '/top_k',
        '/rp': '/rep', '/mt': '/max_tokens', '/ps': '/preset',
    }

    def show_params():
        """Print the current system prompt and sampling settings."""
        print(f' system: {system!r}')
        print(f' temperature={params["temperature"]}, top_k={params["top_k"]}, '
              f'top_p={params["top_p"]}, repetition_penalty={params["repetition_penalty"]}, '
              f'max_new_tokens={params["max_new_tokens"]}')

    def parse_set(line, prefix, kind, validate=None):
        """Parse '/cmd value' for a single parameter. Returns (ok, value_or_msg)."""
        s = line[len(prefix):].strip()
        if not s:
            return False, f'нужен аргумент: {prefix} <value>'
        try:
            v = kind(s)
        except ValueError:
            return False, f'не могу разобрать как {kind.__name__}: {s!r}'
        if validate is not None:
            err = validate(v)
            if err:
                return False, err
        return True, v

    # --- REPL loop ----------------------------------------------------------
    history = []          # list of (role, content) dialog turns
    system = args.system  # current system prompt (read by the closures above)
    print()
    print('=== chat REPL === /help для списка команд')
    show_params()
    print()
    while True:
        try:
            user = input('you> ').strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl+D / Ctrl+C at the prompt exits the REPL.
            print()
            break
        if not user:
            continue

        # Slash commands are handled locally and never reach the model.
        if user.startswith('/'):
            head, _, rest = user.partition(' ')
            cmd = CANONICAL.get(head, head)  # resolve alias to canonical name
            rest = rest.strip()
            # Rebuild the line with the canonical command so parse_set can
            # strip a known prefix.
            full = cmd if not rest else f'{cmd} {rest}'

            if cmd == '/quit':
                break
            elif cmd == '/help':
                print(HELP)
            elif cmd == '/show':
                show_params()
            elif cmd == '/reset':
                history = []
                print('(история сброшена)')
            elif cmd == '/system':
                system = rest
                history = []
                print(f'(новый system: {system!r}, история сброшена)')
            elif cmd == '/temp':
                ok, v = parse_set(full, '/temp', float,
                                  lambda x: None if x > 0 else 'temperature должен быть > 0')
                if ok: params['temperature'] = v; print(f'(temperature = {v})')
                else: print(f'! {v}')
            elif cmd == '/top_p':
                ok, v = parse_set(full, '/top_p', float,
                                  lambda x: None if 0 < x <= 1.0 else 'top_p должен быть в (0..1]')
                if ok: params['top_p'] = v; print(f'(top_p = {v})')
                else: print(f'! {v}')
            elif cmd == '/top_k':
                ok, v = parse_set(full, '/top_k', int,
                                  lambda x: None if x >= 0 else 'top_k должен быть >= 0')
                if ok: params['top_k'] = v; print(f'(top_k = {v})')
                else: print(f'! {v}')
            elif cmd == '/rep':
                ok, v = parse_set(full, '/rep', float,
                                  lambda x: None if x >= 1.0 else 'repetition_penalty должен быть >= 1.0')
                if ok: params['repetition_penalty'] = v; print(f'(repetition_penalty = {v})')
                else: print(f'! {v}')
            elif cmd == '/max_tokens':
                ok, v = parse_set(full, '/max_tokens', int,
                                  lambda x: None if 1 <= x <= 4096 else 'max_tokens в [1..4096]')
                if ok: params['max_new_tokens'] = v; print(f'(max_new_tokens = {v})')
                else: print(f'! {v}')
            elif cmd == '/preset':
                if rest not in PRESETS:
                    print(f'! пресет {rest!r} не найден. доступны: {list(PRESETS.keys())}')
                else:
                    params.update(PRESETS[rest])
                    print(f'(пресет {rest}: {PRESETS[rest]})')
            else:
                print(f'! неизвестная команда {head!r}. /help для списка.')
            continue

        # --- Normal user turn: encode the whole dialog and generate --------
        history.append(('user', user))
        prompt = build_prompt(history, system)
        ids = sp.encode(prompt, out_type=int)
        # Leave headroom for the reply; drop the oldest tokens if too long.
        max_ctx = model.config.block_size - 64
        if len(ids) > max_ctx:
            ids = ids[-max_ctx:]
        idx = torch.tensor([ids], dtype=torch.long, device=args.device)

        # Streaming: re-decode the full id list each step and print only the
        # newly appeared suffix (individual pieces are not printable alone).
        printed = {'text': '', 'ids': []}
        def on_token(nid, all_ids):
            printed['ids'] = list(all_ids)
            full = sp.decode(all_ids)
            delta = full[len(printed['text']):]
            if delta:
                print(delta, end='', flush=True)
            printed['text'] = full

        print('bot> ', end='', flush=True)
        try:
            with autocast:
                new_ids = generate_until_eot(model, idx, eot_id, params['max_new_tokens'],
                                             params['temperature'], params['top_k'],
                                             params['top_p'], params['repetition_penalty'],
                                             params['repetition_window'], args.device,
                                             on_token=on_token)
        except KeyboardInterrupt:
            # Ctrl+C mid-generation: keep whatever was already streamed.
            new_ids = printed['ids']
            print('\n(прервано Ctrl+C)')
        print()
        reply = sp.decode(new_ids).strip()
        history.append(('assistant', reply))
        print()


if __name__ == '__main__':
    main()
|
|