""" Интерактивный REPL для болтовни с обученной моделью. Запуск: python chat.py --out_dir=out-chat50m python chat.py --out_dir=out-chat50m --temperature=0.8 --top_k=50 python chat.py --out_dir=out-chat50m --system="Ты дружелюбный ассистент." Команды внутри REPL (в скобках -- однобуквенные алиасы): /help /h показать список команд /show /s показать текущие параметры сэмплинга /reset /r сбросить историю диалога /system <т> /sys <т> сменить system-промпт + reset /temp /t temperature (>0) /top_p /p nucleus sampling (0..1] /top_k /k top-k (0 = выкл) /rep /rp repetition_penalty (>=1.0) /max_tokens /mt лимит длины ответа /preset /ps creative | balanced | precise /quit /q выйти """ import os import sys import io import argparse import pickle import torch import sentencepiece as spm from model import GPTConfig, GPT if sys.platform == 'win32': sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') SYS_TOK = '<|system|>' USR_TOK = '<|user|>' ASS_TOK = '<|assistant|>' EOT_TOK = '<|endoftext|>' def build_prompt(history, system): """history: list of (role, content). Возвращает строку, заканчивающуюся на <|assistant|>.""" parts = [] if system: parts.append(f'{SYS_TOK}{system}{EOT_TOK}') for role, content in history: tok = USR_TOK if role == 'user' else ASS_TOK parts.append(f'{tok}{content}{EOT_TOK}') parts.append(ASS_TOK) return ''.join(parts) @torch.no_grad() def generate_until_eot(model, idx, eot_id, max_new_tokens, temperature, top_k, top_p, repetition_penalty, repetition_window, device, on_token=None): """Сэмплинг до <|endoftext|> или max_new_tokens с repetition_penalty + top-k + top-p. on_token(new_id, all_new_ids) -- опц. колбэк после каждого нового токена (для streaming). """ new_ids = [] block_size = model.config.block_size prompt_len = idx.size(1) for _ in range(max_new_tokens): idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:] logits, _ = model(idx_cond) logits = logits[:, -1, :].clone() # (1, V) # repetition penalty: штрафуем токены, появлявшиеся в последнем окне if repetition_penalty and repetition_penalty != 1.0: recent = idx[0, -repetition_window:].tolist() if recent: uniq = list(set(recent)) t = torch.tensor(uniq, device=logits.device, dtype=torch.long) cur = logits[0, t] # классический CTRL-style: положительные logits делим, отрицательные -- умножаем cur = torch.where(cur > 0, cur / repetition_penalty, cur * repetition_penalty) logits[0, t] = cur logits = logits / max(temperature, 1e-6) # top-k if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('inf') # top-p (nucleus) if top_p is not None and 0.0 < top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1) cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1) mask = cum > top_p mask[..., 1:] = mask[..., :-1].clone() mask[..., 0] = False sorted_logits = sorted_logits.masked_fill(mask, -float('inf')) logits = torch.full_like(logits, -float('inf')).scatter(-1, sorted_idx, sorted_logits) probs = torch.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) nid = int(next_id.item()) if nid == eot_id: break new_ids.append(nid) idx = torch.cat([idx, next_id], dim=1) if on_token is not None: on_token(nid, new_ids) return new_ids def main(): ap = argparse.ArgumentParser() ap.add_argument('--out_dir', default='out-chat50m') ap.add_argument('--data_dir', default='data/chat_ru') ap.add_argument('--system', default='Ты вежливый и полезный ассистент. Отвечай по-русски.') ap.add_argument('--temperature', type=float, default=0.7) ap.add_argument('--top_k', type=int, default=40) ap.add_argument('--top_p', type=float, default=0.9) ap.add_argument('--repetition_penalty', type=float, default=1.15, help='1.0 = выкл; 1.1-1.3 типичные значения') ap.add_argument('--repetition_window', type=int, default=128, help='в каком окне последних токенов штрафовать повторы') ap.add_argument('--max_new_tokens', type=int, default=512) ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') ap.add_argument('--dtype', default='bfloat16') args = ap.parse_args() # tokenizer + meta sp = spm.SentencePieceProcessor() sp.Load(os.path.join(args.data_dir, 'tokenizer.model')) with open(os.path.join(args.data_dir, 'meta.pkl'), 'rb') as f: meta = pickle.load(f) eot_id = meta['special_tokens']['endoftext'] print(f'tokenizer ok, vocab={sp.get_piece_size()}, eot_id={eot_id}') # model ckpt_path = os.path.join(args.out_dir, 'ckpt.pt') print(f'loading checkpoint: {ckpt_path}') ckpt = torch.load(ckpt_path, map_location=args.device, weights_only=False) gptconf = GPTConfig(**ckpt['model_args']) model = GPT(gptconf) sd = ckpt['model'] # снять префикс _orig_mod. если был torch.compile for k in list(sd.keys()): if k.startswith('_orig_mod.'): sd[k[len('_orig_mod.'):]] = sd.pop(k) model.load_state_dict(sd) model.eval() model.to(args.device) print(f'model: {model.get_num_params()/1e6:.1f}M params, block_size={model.config.block_size}') ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] autocast = torch.amp.autocast(device_type=('cuda' if 'cuda' in args.device else 'cpu'), dtype=ptdtype) # Параметры сэмплинга, изменяемые на лету через /-команды. params = dict( temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, repetition_penalty=args.repetition_penalty, repetition_window=args.repetition_window, max_new_tokens=args.max_new_tokens, ) PRESETS = { 'creative': dict(temperature=1.0, top_k=80, top_p=0.95, repetition_penalty=1.10), 'balanced': dict(temperature=0.7, top_k=40, top_p=0.90, repetition_penalty=1.15), 'precise': dict(temperature=0.35, top_k=20, top_p=0.85, repetition_penalty=1.25), } HELP = ( 'Команды (в скобках -- однобуквенные алиасы):\n' ' /help /h показать эту справку\n' ' /show /s показать текущие параметры\n' ' /reset /r сбросить историю диалога\n' ' /system <т> /sys <т> сменить system-промпт + reset\n' ' /temp /t temperature (>0)\n' ' /top_p /p nucleus sampling (0..1]\n' ' /top_k /k top-k (0 = выкл)\n' ' /rep /rp repetition_penalty (>=1.0)\n' ' /max_tokens /mt лимит длины ответа\n' ' /preset /ps ' + ' | '.join(PRESETS.keys()) + '\n' ' /quit /q выйти' ) # Алиасы: первое слово в команде раскрывается в каноническое. CANONICAL = { '/h': '/help', '/s': '/show', '/r': '/reset', '/q': '/quit', '/exit': '/quit', '/sys': '/system', '/t': '/temp', '/p': '/top_p', '/k': '/top_k', '/rp': '/rep', '/mt': '/max_tokens', '/ps': '/preset', } def show_params(): print(f' system: {system!r}') print(f' temperature={params["temperature"]}, top_k={params["top_k"]}, ' f'top_p={params["top_p"]}, repetition_penalty={params["repetition_penalty"]}, ' f'max_new_tokens={params["max_new_tokens"]}') def parse_set(line, prefix, kind, validate=None): """Распарсить '/cmd value' для одного параметра. Возвращает (ok, value_or_msg).""" s = line[len(prefix):].strip() if not s: return False, f'нужен аргумент: {prefix} ' try: v = kind(s) except ValueError: return False, f'не могу разобрать как {kind.__name__}: {s!r}' if validate is not None: err = validate(v) if err: return False, err return True, v history = [] # list[(role, content)] system = args.system print() print('=== chat REPL === /help для списка команд') show_params() print() while True: try: user = input('you> ').strip() except (EOFError, KeyboardInterrupt): print() break if not user: continue # Команды: первое слово раскрывается через CANONICAL if user.startswith('/'): head, _, rest = user.partition(' ') cmd = CANONICAL.get(head, head) rest = rest.strip() full = cmd if not rest else f'{cmd} {rest}' if cmd == '/quit': break elif cmd == '/help': print(HELP) elif cmd == '/show': show_params() elif cmd == '/reset': history = [] print('(история сброшена)') elif cmd == '/system': system = rest history = [] print(f'(новый system: {system!r}, история сброшена)') elif cmd == '/temp': ok, v = parse_set(full, '/temp', float, lambda x: None if x > 0 else 'temperature должен быть > 0') if ok: params['temperature'] = v; print(f'(temperature = {v})') else: print(f'! {v}') elif cmd == '/top_p': ok, v = parse_set(full, '/top_p', float, lambda x: None if 0 < x <= 1.0 else 'top_p должен быть в (0..1]') if ok: params['top_p'] = v; print(f'(top_p = {v})') else: print(f'! {v}') elif cmd == '/top_k': ok, v = parse_set(full, '/top_k', int, lambda x: None if x >= 0 else 'top_k должен быть >= 0') if ok: params['top_k'] = v; print(f'(top_k = {v})') else: print(f'! {v}') elif cmd == '/rep': ok, v = parse_set(full, '/rep', float, lambda x: None if x >= 1.0 else 'repetition_penalty должен быть >= 1.0') if ok: params['repetition_penalty'] = v; print(f'(repetition_penalty = {v})') else: print(f'! {v}') elif cmd == '/max_tokens': ok, v = parse_set(full, '/max_tokens', int, lambda x: None if 1 <= x <= 4096 else 'max_tokens в [1..4096]') if ok: params['max_new_tokens'] = v; print(f'(max_new_tokens = {v})') else: print(f'! {v}') elif cmd == '/preset': if rest not in PRESETS: print(f'! пресет {rest!r} не найден. доступны: {list(PRESETS.keys())}') else: params.update(PRESETS[rest]) print(f'(пресет {rest}: {PRESETS[rest]})') else: print(f'! неизвестная команда {head!r}. /help для списка.') continue history.append(('user', user)) prompt = build_prompt(history, system) ids = sp.encode(prompt, out_type=int) # обрезаем по block_size слева, оставляя минимум 64 для генерации max_ctx = model.config.block_size - 64 if len(ids) > max_ctx: ids = ids[-max_ctx:] idx = torch.tensor([ids], dtype=torch.long, device=args.device) # Streaming: после каждого нового токена декодируем весь префикс и печатаем # дельту -- так корректно склеиваются подслова BPE (без ▁-артефактов). printed = {'text': '', 'ids': []} def on_token(nid, all_ids): # храним актуальный список id чтобы при Ctrl+C сохранить partial-ответ printed['ids'] = list(all_ids) full = sp.decode(all_ids) delta = full[len(printed['text']):] if delta: print(delta, end='', flush=True) printed['text'] = full print('bot> ', end='', flush=True) try: with autocast: new_ids = generate_until_eot(model, idx, eot_id, params['max_new_tokens'], params['temperature'], params['top_k'], params['top_p'], params['repetition_penalty'], params['repetition_window'], args.device, on_token=on_token) except KeyboardInterrupt: new_ids = printed['ids'] print('\n(прервано Ctrl+C)') print() # перевод строки после финального токена reply = sp.decode(new_ids).strip() history.append(('assistant', reply)) print() if __name__ == '__main__': main()