mini-tron-50 / chat.py
Imperius's picture
Upload folder using huggingface_hub
e5855a0 verified
"""
Интерактивный REPL для болтовни с обученной моделью.
Запуск:
python chat.py --out_dir=out-chat50m
python chat.py --out_dir=out-chat50m --temperature=0.8 --top_k=50
python chat.py --out_dir=out-chat50m --system="Ты дружелюбный ассистент."
Команды внутри REPL (в скобках -- однобуквенные алиасы):
/help /h показать список команд
/show /s показать текущие параметры сэмплинга
/reset /r сбросить историю диалога
/system <т> /sys <т> сменить system-промпт + reset
/temp <f> /t <f> temperature (>0)
/top_p <f> /p <f> nucleus sampling (0..1]
/top_k <i> /k <i> top-k (0 = выкл)
/rep <f> /rp <f> repetition_penalty (>=1.0)
/max_tokens<i> /mt <i> лимит длины ответа
/preset <n> /ps <n> creative | balanced | precise
/quit /q выйти
"""
import os
import sys
import io
import argparse
import pickle
import torch
import sentencepiece as spm
from model import GPTConfig, GPT
if sys.platform == 'win32':
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
SYS_TOK = '<|system|>'
USR_TOK = '<|user|>'
ASS_TOK = '<|assistant|>'
EOT_TOK = '<|endoftext|>'
def build_prompt(history, system):
"""history: list of (role, content). Возвращает строку, заканчивающуюся на <|assistant|>."""
parts = []
if system:
parts.append(f'{SYS_TOK}{system}{EOT_TOK}')
for role, content in history:
tok = USR_TOK if role == 'user' else ASS_TOK
parts.append(f'{tok}{content}{EOT_TOK}')
parts.append(ASS_TOK)
return ''.join(parts)
@torch.no_grad()
def generate_until_eot(model, idx, eot_id, max_new_tokens, temperature, top_k, top_p,
repetition_penalty, repetition_window, device, on_token=None):
"""Сэмплинг до <|endoftext|> или max_new_tokens с repetition_penalty + top-k + top-p.
on_token(new_id, all_new_ids) -- опц. колбэк после каждого нового токена (для streaming).
"""
new_ids = []
block_size = model.config.block_size
prompt_len = idx.size(1)
for _ in range(max_new_tokens):
idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
logits, _ = model(idx_cond)
logits = logits[:, -1, :].clone() # (1, V)
# repetition penalty: штрафуем токены, появлявшиеся в последнем окне
if repetition_penalty and repetition_penalty != 1.0:
recent = idx[0, -repetition_window:].tolist()
if recent:
uniq = list(set(recent))
t = torch.tensor(uniq, device=logits.device, dtype=torch.long)
cur = logits[0, t]
# классический CTRL-style: положительные logits делим, отрицательные -- умножаем
cur = torch.where(cur > 0, cur / repetition_penalty, cur * repetition_penalty)
logits[0, t] = cur
logits = logits / max(temperature, 1e-6)
# top-k
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf')
# top-p (nucleus)
if top_p is not None and 0.0 < top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
mask = cum > top_p
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False
sorted_logits = sorted_logits.masked_fill(mask, -float('inf'))
logits = torch.full_like(logits, -float('inf')).scatter(-1, sorted_idx, sorted_logits)
probs = torch.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
nid = int(next_id.item())
if nid == eot_id:
break
new_ids.append(nid)
idx = torch.cat([idx, next_id], dim=1)
if on_token is not None:
on_token(nid, new_ids)
return new_ids
def main():
ap = argparse.ArgumentParser()
ap.add_argument('--out_dir', default='out-chat50m')
ap.add_argument('--data_dir', default='data/chat_ru')
ap.add_argument('--system', default='Ты вежливый и полезный ассистент. Отвечай по-русски.')
ap.add_argument('--temperature', type=float, default=0.7)
ap.add_argument('--top_k', type=int, default=40)
ap.add_argument('--top_p', type=float, default=0.9)
ap.add_argument('--repetition_penalty', type=float, default=1.15,
help='1.0 = выкл; 1.1-1.3 типичные значения')
ap.add_argument('--repetition_window', type=int, default=128,
help='в каком окне последних токенов штрафовать повторы')
ap.add_argument('--max_new_tokens', type=int, default=512)
ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
ap.add_argument('--dtype', default='bfloat16')
args = ap.parse_args()
# tokenizer + meta
sp = spm.SentencePieceProcessor()
sp.Load(os.path.join(args.data_dir, 'tokenizer.model'))
with open(os.path.join(args.data_dir, 'meta.pkl'), 'rb') as f:
meta = pickle.load(f)
eot_id = meta['special_tokens']['endoftext']
print(f'tokenizer ok, vocab={sp.get_piece_size()}, eot_id={eot_id}')
# model
ckpt_path = os.path.join(args.out_dir, 'ckpt.pt')
print(f'loading checkpoint: {ckpt_path}')
ckpt = torch.load(ckpt_path, map_location=args.device, weights_only=False)
gptconf = GPTConfig(**ckpt['model_args'])
model = GPT(gptconf)
sd = ckpt['model']
# снять префикс _orig_mod. если был torch.compile
for k in list(sd.keys()):
if k.startswith('_orig_mod.'):
sd[k[len('_orig_mod.'):]] = sd.pop(k)
model.load_state_dict(sd)
model.eval()
model.to(args.device)
print(f'model: {model.get_num_params()/1e6:.1f}M params, block_size={model.config.block_size}')
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
autocast = torch.amp.autocast(device_type=('cuda' if 'cuda' in args.device else 'cpu'),
dtype=ptdtype)
# Параметры сэмплинга, изменяемые на лету через /-команды.
params = dict(
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
repetition_window=args.repetition_window,
max_new_tokens=args.max_new_tokens,
)
PRESETS = {
'creative': dict(temperature=1.0, top_k=80, top_p=0.95, repetition_penalty=1.10),
'balanced': dict(temperature=0.7, top_k=40, top_p=0.90, repetition_penalty=1.15),
'precise': dict(temperature=0.35, top_k=20, top_p=0.85, repetition_penalty=1.25),
}
HELP = (
'Команды (в скобках -- однобуквенные алиасы):\n'
' /help /h показать эту справку\n'
' /show /s показать текущие параметры\n'
' /reset /r сбросить историю диалога\n'
' /system <т> /sys <т> сменить system-промпт + reset\n'
' /temp <f> /t <f> temperature (>0)\n'
' /top_p <f> /p <f> nucleus sampling (0..1]\n'
' /top_k <i> /k <i> top-k (0 = выкл)\n'
' /rep <f> /rp <f> repetition_penalty (>=1.0)\n'
' /max_tokens<i> /mt <i> лимит длины ответа\n'
' /preset <n> /ps <n> ' + ' | '.join(PRESETS.keys()) + '\n'
' /quit /q выйти'
)
# Алиасы: первое слово в команде раскрывается в каноническое.
CANONICAL = {
'/h': '/help', '/s': '/show', '/r': '/reset',
'/q': '/quit', '/exit': '/quit',
'/sys': '/system', '/t': '/temp', '/p': '/top_p', '/k': '/top_k',
'/rp': '/rep', '/mt': '/max_tokens', '/ps': '/preset',
}
def show_params():
print(f' system: {system!r}')
print(f' temperature={params["temperature"]}, top_k={params["top_k"]}, '
f'top_p={params["top_p"]}, repetition_penalty={params["repetition_penalty"]}, '
f'max_new_tokens={params["max_new_tokens"]}')
def parse_set(line, prefix, kind, validate=None):
"""Распарсить '/cmd value' для одного параметра. Возвращает (ok, value_or_msg)."""
s = line[len(prefix):].strip()
if not s:
return False, f'нужен аргумент: {prefix} <value>'
try:
v = kind(s)
except ValueError:
return False, f'не могу разобрать как {kind.__name__}: {s!r}'
if validate is not None:
err = validate(v)
if err:
return False, err
return True, v
history = [] # list[(role, content)]
system = args.system
print()
print('=== chat REPL === /help для списка команд')
show_params()
print()
while True:
try:
user = input('you> ').strip()
except (EOFError, KeyboardInterrupt):
print()
break
if not user:
continue
# Команды: первое слово раскрывается через CANONICAL
if user.startswith('/'):
head, _, rest = user.partition(' ')
cmd = CANONICAL.get(head, head)
rest = rest.strip()
full = cmd if not rest else f'{cmd} {rest}'
if cmd == '/quit':
break
elif cmd == '/help':
print(HELP)
elif cmd == '/show':
show_params()
elif cmd == '/reset':
history = []
print('(история сброшена)')
elif cmd == '/system':
system = rest
history = []
print(f'(новый system: {system!r}, история сброшена)')
elif cmd == '/temp':
ok, v = parse_set(full, '/temp', float,
lambda x: None if x > 0 else 'temperature должен быть > 0')
if ok: params['temperature'] = v; print(f'(temperature = {v})')
else: print(f'! {v}')
elif cmd == '/top_p':
ok, v = parse_set(full, '/top_p', float,
lambda x: None if 0 < x <= 1.0 else 'top_p должен быть в (0..1]')
if ok: params['top_p'] = v; print(f'(top_p = {v})')
else: print(f'! {v}')
elif cmd == '/top_k':
ok, v = parse_set(full, '/top_k', int,
lambda x: None if x >= 0 else 'top_k должен быть >= 0')
if ok: params['top_k'] = v; print(f'(top_k = {v})')
else: print(f'! {v}')
elif cmd == '/rep':
ok, v = parse_set(full, '/rep', float,
lambda x: None if x >= 1.0 else 'repetition_penalty должен быть >= 1.0')
if ok: params['repetition_penalty'] = v; print(f'(repetition_penalty = {v})')
else: print(f'! {v}')
elif cmd == '/max_tokens':
ok, v = parse_set(full, '/max_tokens', int,
lambda x: None if 1 <= x <= 4096 else 'max_tokens в [1..4096]')
if ok: params['max_new_tokens'] = v; print(f'(max_new_tokens = {v})')
else: print(f'! {v}')
elif cmd == '/preset':
if rest not in PRESETS:
print(f'! пресет {rest!r} не найден. доступны: {list(PRESETS.keys())}')
else:
params.update(PRESETS[rest])
print(f'(пресет {rest}: {PRESETS[rest]})')
else:
print(f'! неизвестная команда {head!r}. /help для списка.')
continue
history.append(('user', user))
prompt = build_prompt(history, system)
ids = sp.encode(prompt, out_type=int)
# обрезаем по block_size слева, оставляя минимум 64 для генерации
max_ctx = model.config.block_size - 64
if len(ids) > max_ctx:
ids = ids[-max_ctx:]
idx = torch.tensor([ids], dtype=torch.long, device=args.device)
# Streaming: после каждого нового токена декодируем весь префикс и печатаем
# дельту -- так корректно склеиваются подслова BPE (без ▁-артефактов).
printed = {'text': '', 'ids': []}
def on_token(nid, all_ids):
# храним актуальный список id чтобы при Ctrl+C сохранить partial-ответ
printed['ids'] = list(all_ids)
full = sp.decode(all_ids)
delta = full[len(printed['text']):]
if delta:
print(delta, end='', flush=True)
printed['text'] = full
print('bot> ', end='', flush=True)
try:
with autocast:
new_ids = generate_until_eot(model, idx, eot_id, params['max_new_tokens'],
params['temperature'], params['top_k'],
params['top_p'], params['repetition_penalty'],
params['repetition_window'], args.device,
on_token=on_token)
except KeyboardInterrupt:
new_ids = printed['ids']
print('\n(прервано Ctrl+C)')
print() # перевод строки после финального токена
reply = sp.decode(new_ids).strip()
history.append(('assistant', reply))
print()
if __name__ == '__main__':
main()