Imperius commited on
Commit
e5855a0
·
verified ·
1 Parent(s): c2bd8f2

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,6 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.bin filter=lfs diff=lfs merge=lfs -text
2
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ *.pt filter=lfs diff=lfs merge=lfs -text
4
+ *.model filter=lfs diff=lfs merge=lfs -text
5
+ *.exe filter=lfs diff=lfs merge=lfs -text
6
+ *.pkl filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,3 +1,346 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - ru
4
+ license: apache-2.0
5
+ library_name: pytorch
6
+ tags:
7
+ - text-generation
8
+ - russian
9
+ - chat
10
+ - chatbot
11
+ - nanogpt
12
+ - small-llm
13
+ - sft
14
+ - educational
15
+ - research
16
+ pipeline_tag: text-generation
17
+ inference: false
18
+ ---
19
+
20
+ # mini-tron-50 (SFT)
21
+
22
+ 50M-параметровая GPT-2-стиль языковая модель, обученная с нуля на русском
23
+ SFT-корпусе как чат-бот. Educational baseline, иллюстрирующий **что можно
24
+ выжать из модели такого размера** без претрейна на сыром тексте.
25
+
26
+ ## TL;DR
27
+
28
+ - **Architecture**: 10 layer × 8 head × 512 emb (GPT-2 style), 47.85M params
29
+ - **Trained from scratch**: SFT на ~1 ГБ chat-данных (1.7M диалогов), 1 эпоха
30
+ - **Hardware**: 1× RTX 3050 Laptop (4 ГБ VRAM), 13 часов
31
+ - **Tokenizer**: SentencePiece BPE, 32k vocab, custom-trained на корпусе
32
+ - **Format**: ChatML с спецтокенами `<|system|>`, `<|user|>`, `<|assistant|>`,
33
+ `<|endoftext|>`
34
+ - **Loss masking**: только на assistant-токенах (стандартный SFT-trick)
35
+ - **Status**: SFT-фаза успешна; KTO-фаза проведена и провалилась (см. ниже)
36
+
37
+ Модель **не достигла своего потолка** — это одна из её особенностей. Она
38
+ тренировалась ровно одну эпоху по корпусу (Chinchilla-оптимум для 50M), и
39
+ контурные кривые показывают что улучшения val_loss ещё не выходили на плато.
40
+ Дальнейшая тренировка или продолжение на новых данных вероятно даст ощутимое
41
+ улучшение.
42
+
43
+ ## Архитектура
44
+
45
+ ```text
46
+ GPTConfig(
47
+ n_layer = 10,
48
+ n_head = 8,
49
+ n_embd = 512,
50
+ block_size = 1024,
51
+ vocab_size = 32000,
52
+ bias = False,
53
+ dropout = 0.1,
54
+ )
55
+ ```
56
+
57
+ Стандартный GPT-2 без bias. Tied embeddings (`transformer.wte` shared с `lm_head`).
58
+ Attention использует `F.scaled_dot_product_attention` (flash-attn под капотом
59
+ на Ampere+).
60
+
61
+ ## Что модель умеет
62
+
63
+ | Способность | Качество |
64
+ | --- | --- |
65
+ | Грамматика русского | ✅ безупречно (падежи, согласования, синтаксис) |
66
+ | Chat-формат (отвечает в роли ассистента) | ✅ устойчиво |
67
+ | Markdown-структура (списки, **bold**, заголовки) | ✅ имитирует GPT-4-стиль |
68
+ | Самоидентификация ("я ИИ-ассистент") | ✅ говорит правильные слова |
69
+ | Завершение по EOS | ✅ обычно сама останавливается |
70
+ | Локальная связность 1-2 предложения | ⚠️ местами осмысленно |
71
+ | Ответ строго по теме промпта | ⚠️ слышит триггерные слова, не суть |
72
+ | Факты и точные знания | ❌ галлюцинации |
73
+ | Арифметика | ❌ имитирует подсчёт без него |
74
+ | Многошаговый reasoning | ❌ |
75
+ | Code (синтаксис + семантика) | ❌ форма правильная, не работает |
76
+
77
+ ## Известные failure modes
78
+
79
+ 1. **Canned-ответы** на простые вопросы:
80
+ ```
81
+ you> Любишь котов или собак?
82
+ bot> Привет! Я рад, что смог помочь вам сегодня. Если у вас есть вопросы,
83
+ не стесняйтесь обращаться к нам. Удачи!
84
+ ```
85
+ Заученный шаблонный хвост ChatGPT-style ответов.
86
+
87
+ 2. **Tutorial-простыни** на любой открытый промпт:
88
+ ```
89
+ you> Расскажи о себе
90
+ bot> ### Шаг 1: Определение задачи
91
+ ### Шаг 2: ...
92
+ ### Заключение
93
+ ```
94
+
95
+ 3. **Token-loops** на промптах вне распределения:
96
+ ```
97
+ you> Что больше: 17 или 71?
98
+ bot> Чтобы посчитать... + 112 - 112 + 112 + 112 - 112 + 112 ... [50+ повторов]
99
+ ```
100
+
101
+ 4. **Семантическая каша** в полностью грамотной обёртке:
102
+ ```
103
+ you> Сколько будет 7 умножить на 8?
104
+ bot> 5! = (5 × 8) / 8 = 120. Теперь разделим 120 на 8...
105
+ Итак, всего будет 140 способов выбрать 7 умножить на 8.
106
+ ```
107
+ Модель имитирует жанр школьной арифметики, не выполняя саму операцию.
108
+
109
+ Эти failure modes лечатся не SFT, а preference learning'ом + увеличением
110
+ размера модели.
111
+
112
+ ## Тренировочные данные
113
+
114
+ **Источник**: `big-russian-dataset` (HuggingFace) — русскоязычный SFT-корпус.
115
+
116
+ | Сплит | Диалогов | После фильтра |
117
+ | --- | --- | --- |
118
+ | train | 1.71M | 1,709,621 (99.9%) |
119
+ | val | 18.5k | 10,396 (56%) |
120
+
121
+ **Фильтр**: `overall_score ≥ 6 AND safety ≥ 8 AND pii_leak = 0`.
122
+
123
+ В train авторы датасета сами уже почистили мусор — там нет записей со
124
+ score < 6, поэтому фильтр пропускает почти всё. В val разброс score 1-10
125
+ оставлен специально для оценки на трудных примерах.
126
+
127
+ **Объём в токенах**: ~1.04 ГБ токенов в train.bin, из них ~603M токенов под
128
+ loss (assistant + EOT, 57.7%).
129
+
130
+ ## Тренировочные параметры
131
+
132
+ ```python
133
+ # AdamW
134
+ learning_rate = 3e-4 # cosine decay → min_lr=3e-5
135
+ weight_decay = 0.1
136
+ beta1, beta2 = 0.9, 0.95
137
+ grad_clip = 1.0
138
+
139
+ # Schedule
140
+ warmup_iters = 200
141
+ max_iters = 16000 # ~1 эпоха
142
+ lr_decay_iters = 16000
143
+
144
+ # Batch
145
+ batch_size = 2
146
+ gradient_accumulation_steps = 32 # effective batch = 64 sequences
147
+ block_size = 1024
148
+ # tokens per iter = 65,536
149
+
150
+ # System
151
+ dtype = 'bfloat16'
152
+ compile = False
153
+ ```
154
+
155
+ ## Кривая обучения
156
+
157
+ ```text
158
+ iter 0 loss 10.49 (≈ ln(32000), стартовая случайная инициализация)
159
+ iter 500 loss ~5 (warmup закончен, LR на peak)
160
+ iter 5500 loss ~2.4 (первый saved checkpoint)
161
+ iter 11500 loss ~1.7 (третий)
162
+ iter 14500 loss ~1.5 (best val_loss ~ 1.8)
163
+ iter 16000 loss ~1.45 (max_iters достигнут)
164
+ ```
165
+
166
+ train-val gap к концу ~1.7 nats — здоровое значение для SFT на small model.
167
+
168
+ ## Как использовать
169
+
170
+ Модель распространяется в формате nanoGPT (Karpathy). **Не совместима напрямую
171
+ с `transformers.AutoModelForCausalLM`** без конвертации (имена весов отличаются от
172
+ GPT-2-стандарта, в частности нет bias).
173
+
174
+ ### Вариант 1 — через nanoGPT-репозиторий (Python)
175
+
176
+ ```bash
177
+ git clone https://github.com/karpathy/nanoGPT.git
178
+ cd nanoGPT
179
+ # скопировать ckpt.pt, tokenizer.model, meta.pkl сюда
180
+ python chat.py --out_dir=<path-to-checkpoint-dir>
181
+ ```
182
+
183
+ `chat.py` из этого релиза предоставляет интерактивный REPL с командами для
184
+ изменения temperature, top_p, top_k, repetition_penalty на лету.
185
+
186
+ ### Вариант 2 — через готовый Go single-binary
187
+
188
+ В этом релизе есть pre-built `nanogpt-chat.exe` (~97 МБ) — single-binary с
189
+ встроенными весами, токенайзером, конфигом. Запускается без Python и без
190
+ GPU. Performance: ~36 ток/сек на 4-ядерном CPU.
191
+
192
+ ```bash
193
+ ./nanogpt-chat.exe
194
+ ```
195
+
196
+ Команды REPL:
197
+
198
+ ```text
199
+ /help справка
200
+ /show текущие параметры
201
+ /reset сбросить историю диалога
202
+ /system <текст> сменить system-prompt
203
+ /temp <float> temperature (>0)
204
+ /top_p <float> nucleus sampling
205
+ /top_k <int> top-k
206
+ /rep <float> repetition_penalty
207
+ /preset <name> creative | balanced | precise
208
+ /quit выход
209
+ ```
210
+
211
+ Все команды имеют однобуквенные алиасы (/h, /s, /r, /q, /t, /p, /k, ...).
212
+
213
+ ### Вариант 3 — конвертация в HuggingFace transformers
214
+
215
+ Не предусмотрена в релизе. Требует переименования весов под GPT-2 conventions,
216
+ добавления nullевых bias, и сборки `GPT2LMHeadModel` с custom tokenizer
217
+ (SentencePiece, не GPT-2 BPE). Если кто-то напишет конвертер — буду рад PR.
218
+
219
+ ## Примеры
220
+
221
+ ```text
222
+ you> Привет
223
+ bot> Привет! Я готов помочь вам в достижении ваших целей. Хорошего дня!
224
+
225
+ you> Объясни простыми словами, что такое черная дыра.
226
+ bot> Для начала, давайте разберемся с этим вопросом шаг за шагом.
227
+
228
+ ### Шаг 1: Изучение черной дыры
229
+ - Это означает, что черная дыра является частью черной дыры.
230
+ - Она может иметь разные размеры...
231
+ [грамотно, но содержательно бредово -- это потолок 50M на фактах]
232
+
233
+ you> Напиши сортировку пузырьком на Python.
234
+ bot> Вот пример сортировки пузырьком на Python:
235
+ ```python
236
+ def merge_sort(arr):
237
+ """Выполняет сортировку пузырьком в отсортированном порядке."""
238
+ sorted_array = {}
239
+ for i in range(len(arr)):
240
+ if arr[i] < sorted_array[i - 1]:
241
+ return merge_sort(arr[i], reverse=True)
242
+ return sorted_array
243
+ ```
244
+ [имя функции merge_sort вместо bubble_sort, словарь вместо двух циклов --
245
+ форма правильная, реализация бессмысленная]
246
+ ```
247
+
248
+ Хорошие промпты для этой модели — короткие, диалоговые, не требующие точных
249
+ фактов:
250
+
251
+ ```text
252
+ you> Привет, как дела?
253
+ you> Меня зовут Лёша. А тебя?
254
+ you> Назови три фрукта.
255
+ you> Расскажи короткую сказку.
256
+ ```
257
+
258
+ ## Ограничения
259
+
260
+ - **Знания**: модель **не** содержит достоверных фактов. Не стоит спрашивать о
261
+ датах, именах, числах, географии, биологии, медицине. Любой ответ —
262
+ имитация жанра справки, а не реальная информация.
263
+ - **Reasoning**: многошаговая логика недоступна. Арифметика — имитируется
264
+ без выполнения. Code — синтаксически правдоподобен, но не работает.
265
+ - **Длина**: модель тренировалась с `block_size=1024`. Длинные диалоги
266
+ (>800 токенов в истории) обрезаются с начала — модель «забывает» ранние
267
+ реплики.
268
+ - **Языки**: только русский. На английских промптах попытается отвечать,
269
+ но качество хуже.
270
+ - **Безопасность**: модель тренировалась только на отфильтрованной части
271
+ датасета (`safety ≥ 8`), но не имеет специального alignment — на
272
+ откровенно вредных промптах поведение не гарантировано.
273
+
274
+ ## Что не получилось
275
+
276
+ После SFT была попытка preference-learning'а через **KTO** для подавления
277
+ известных failure modes. Обе попытки (β=0.1 и β=0.03) дали полностью
278
+ разрушенную модель — связные ответы превратились в семантический мусор.
279
+ Подробный root-cause анализ — в `04_kto_attempts.md` сопровождающего отчёта.
280
+
281
+ Кратко: комбинация (a) бага в реализации loss (отсутствие `clamp(z_ref, 0)`)
282
+ и (b) asymmetric difficulty между chosen-данными и self-generated rejected.
283
+ После исправления бага деградация всё равно осталась, просто медленнее.
284
+
285
+ Любопытный side-effect: после KTO модель уходила не просто в шум, а в
286
+ «афористически-философский» регистр — узнаваемый стилистический хвост
287
+ распределения, который KTO не давила (см. секцию «Inverse mode collapse»
288
+ в отчёте).
289
+
290
+ В этом релизе публикуется **только SFT-чекпоинт**, KTO-веса не включены.
291
+
292
+ ## Возможности дообучения
293
+
294
+ Модель **не на потолке**. Несколько направлений для продолжения:
295
+
296
+ 1. **Continued SFT** на расширенном корпусе. Особенно — добавить корпус с
297
+ фактическими знаниями (например, выжимки из Википедии) и кодом. Каждые
298
+ ~30% новых данных стоит давать ~1-2 эпохи.
299
+ 2. **Pre-training на сыром тексте** (если хочется уйти ниже 50M-потолка
300
+ качества). 1-5 ГБ русского OSCAR/CulturaX перед SFT может дать
301
+ значительный буст.
302
+ 3. **Distillation от внешней большой модели**. Текущий датасет уже дистилл,
303
+ но генерация новых ответов от Claude / GPT-4o-mini / Yandex YandexGPT
304
+ на тех же промптах даст разнообразие стилей.
305
+ 4. **Preference learning** (DPO/KTO) с **внешними** rejected (не
306
+ self-generated). Например, low-score ответы из val того же датасета.
307
+ 5. **Scale up** до 100-200M params с теми же гиперпараметрами и тем же
308
+ корпусом. Сильно нелинейный бу��т качества.
309
+
310
+ ## Файлы релиза
311
+
312
+ | Файл | Размер | Описание |
313
+ | --- | --- | --- |
314
+ | `ckpt.pt` | 553 МБ | nanoGPT-checkpoint (модель + optimizer state + config) |
315
+ | `tokenizer.model` | 930 КБ | SentencePiece-токенайзер (BPE 32k) |
316
+ | `meta.pkl` | <1 КБ | спецтокены ID + vocab_size |
317
+ | `nanogpt-chat.exe` (опц.) | 97 МБ | Go single-binary с встроенной моделью |
318
+ | `model_card.md` | этот файл | |
319
+
320
+ Если хочется только inference — `tokenizer.model` + `ckpt.pt` достаточно.
321
+
322
+ ## Citation / благодарности
323
+
324
+ ```bibtex
325
+ @misc{mini-tron-50,
326
+ title = {mini-tron-50: 50M Russian chat model trained from scratch},
327
+ author = {Impi},
328
+ year = {2026},
329
+ note = {Educational baseline; nanoGPT architecture}
330
+ }
331
+ ```
332
+
333
+ Использованные ресурсы:
334
+
335
+ - [nanoGPT](https://github.com/karpathy/nanoGPT) by Andrej Karpathy — основа
336
+ архитектуры и тренировочного цикла
337
+ - `big-russian-dataset` — обучающий корпус (необходимо проверить
338
+ оригинальную лицензию датасета перед использованием derivatives для
339
+ коммерческих целей)
340
+
341
+ ## Лицензия
342
+
343
+ Apache 2.0 — на код и веса этой модели. **Внимание**: лицензия на исходный
344
+ датасет (`big-russian-dataset`) может налагать дополнительные ограничения
345
+ на использование. Для коммерческого применения проверь оригинальную
346
+ лицензию датасета.
chat.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Интерактивный REPL для болтовни с обученной моделью.
3
+
4
+ Запуск:
5
+ python chat.py --out_dir=out-chat50m
6
+ python chat.py --out_dir=out-chat50m --temperature=0.8 --top_k=50
7
+ python chat.py --out_dir=out-chat50m --system="Ты дружелюбный ассистент."
8
+
9
+ Команды внутри REPL (в скобках -- однобуквенные алиасы):
10
+ /help /h показать список команд
11
+ /show /s показать текущие параметры сэмплинга
12
+ /reset /r сбросить историю диалога
13
+ /system <т> /sys <т> сменить system-промпт + reset
14
+ /temp <f> /t <f> temperature (>0)
15
+ /top_p <f> /p <f> nucleus sampling (0..1]
16
+ /top_k <i> /k <i> top-k (0 = выкл)
17
+ /rep <f> /rp <f> repetition_penalty (>=1.0)
18
+ /max_tokens<i> /mt <i> лимит длины ответа
19
+ /preset <n> /ps <n> creative | balanced | precise
20
+ /quit /q выйти
21
+ """
22
+
23
+ import os
24
+ import sys
25
+ import io
26
+ import argparse
27
+ import pickle
28
+ import torch
29
+ import sentencepiece as spm
30
+
31
+ from model import GPTConfig, GPT
32
+
33
+ if sys.platform == 'win32':
34
+ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
35
+ sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
36
+
37
+ SYS_TOK = '<|system|>'
38
+ USR_TOK = '<|user|>'
39
+ ASS_TOK = '<|assistant|>'
40
+ EOT_TOK = '<|endoftext|>'
41
+
42
+
43
+ def build_prompt(history, system):
44
+ """history: list of (role, content). Возвращает строку, заканчивающуюся на <|assistant|>."""
45
+ parts = []
46
+ if system:
47
+ parts.append(f'{SYS_TOK}{system}{EOT_TOK}')
48
+ for role, content in history:
49
+ tok = USR_TOK if role == 'user' else ASS_TOK
50
+ parts.append(f'{tok}{content}{EOT_TOK}')
51
+ parts.append(ASS_TOK)
52
+ return ''.join(parts)
53
+
54
+
55
+ @torch.no_grad()
56
+ def generate_until_eot(model, idx, eot_id, max_new_tokens, temperature, top_k, top_p,
57
+ repetition_penalty, repetition_window, device, on_token=None):
58
+ """Сэмплинг до <|endoftext|> или max_new_tokens с repetition_penalty + top-k + top-p.
59
+
60
+ on_token(new_id, all_new_ids) -- опц. колбэк после каждого нового токена (для streaming).
61
+ """
62
+ new_ids = []
63
+ block_size = model.config.block_size
64
+ prompt_len = idx.size(1)
65
+ for _ in range(max_new_tokens):
66
+ idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
67
+ logits, _ = model(idx_cond)
68
+ logits = logits[:, -1, :].clone() # (1, V)
69
+
70
+ # repetition penalty: штрафуем токены, появлявшиеся в последнем окне
71
+ if repetition_penalty and repetition_penalty != 1.0:
72
+ recent = idx[0, -repetition_window:].tolist()
73
+ if recent:
74
+ uniq = list(set(recent))
75
+ t = torch.tensor(uniq, device=logits.device, dtype=torch.long)
76
+ cur = logits[0, t]
77
+ # классический CTRL-style: положительные logits делим, отрицательные -- умножаем
78
+ cur = torch.where(cur > 0, cur / repetition_penalty, cur * repetition_penalty)
79
+ logits[0, t] = cur
80
+
81
+ logits = logits / max(temperature, 1e-6)
82
+
83
+ # top-k
84
+ if top_k is not None and top_k > 0:
85
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
86
+ logits[logits < v[:, [-1]]] = -float('inf')
87
+
88
+ # top-p (nucleus)
89
+ if top_p is not None and 0.0 < top_p < 1.0:
90
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
91
+ cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
92
+ mask = cum > top_p
93
+ mask[..., 1:] = mask[..., :-1].clone()
94
+ mask[..., 0] = False
95
+ sorted_logits = sorted_logits.masked_fill(mask, -float('inf'))
96
+ logits = torch.full_like(logits, -float('inf')).scatter(-1, sorted_idx, sorted_logits)
97
+
98
+ probs = torch.softmax(logits, dim=-1)
99
+ next_id = torch.multinomial(probs, num_samples=1)
100
+ nid = int(next_id.item())
101
+ if nid == eot_id:
102
+ break
103
+ new_ids.append(nid)
104
+ idx = torch.cat([idx, next_id], dim=1)
105
+ if on_token is not None:
106
+ on_token(nid, new_ids)
107
+ return new_ids
108
+
109
+
110
+ def main():
111
+ ap = argparse.ArgumentParser()
112
+ ap.add_argument('--out_dir', default='out-chat50m')
113
+ ap.add_argument('--data_dir', default='data/chat_ru')
114
+ ap.add_argument('--system', default='Ты вежливый и полезный ассистент. Отвечай по-русски.')
115
+ ap.add_argument('--temperature', type=float, default=0.7)
116
+ ap.add_argument('--top_k', type=int, default=40)
117
+ ap.add_argument('--top_p', type=float, default=0.9)
118
+ ap.add_argument('--repetition_penalty', type=float, default=1.15,
119
+ help='1.0 = выкл; 1.1-1.3 типичные значения')
120
+ ap.add_argument('--repetition_window', type=int, default=128,
121
+ help='в каком окне последних токенов штрафовать повторы')
122
+ ap.add_argument('--max_new_tokens', type=int, default=512)
123
+ ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
124
+ ap.add_argument('--dtype', default='bfloat16')
125
+ args = ap.parse_args()
126
+
127
+ # tokenizer + meta
128
+ sp = spm.SentencePieceProcessor()
129
+ sp.Load(os.path.join(args.data_dir, 'tokenizer.model'))
130
+ with open(os.path.join(args.data_dir, 'meta.pkl'), 'rb') as f:
131
+ meta = pickle.load(f)
132
+ eot_id = meta['special_tokens']['endoftext']
133
+ print(f'tokenizer ok, vocab={sp.get_piece_size()}, eot_id={eot_id}')
134
+
135
+ # model
136
+ ckpt_path = os.path.join(args.out_dir, 'ckpt.pt')
137
+ print(f'loading checkpoint: {ckpt_path}')
138
+ ckpt = torch.load(ckpt_path, map_location=args.device, weights_only=False)
139
+ gptconf = GPTConfig(**ckpt['model_args'])
140
+ model = GPT(gptconf)
141
+ sd = ckpt['model']
142
+ # снять префикс _orig_mod. если был torch.compile
143
+ for k in list(sd.keys()):
144
+ if k.startswith('_orig_mod.'):
145
+ sd[k[len('_orig_mod.'):]] = sd.pop(k)
146
+ model.load_state_dict(sd)
147
+ model.eval()
148
+ model.to(args.device)
149
+ print(f'model: {model.get_num_params()/1e6:.1f}M params, block_size={model.config.block_size}')
150
+
151
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
152
+ autocast = torch.amp.autocast(device_type=('cuda' if 'cuda' in args.device else 'cpu'),
153
+ dtype=ptdtype)
154
+
155
+ # Параметры сэмплинга, изменяемые на лету через /-команды.
156
+ params = dict(
157
+ temperature=args.temperature,
158
+ top_k=args.top_k,
159
+ top_p=args.top_p,
160
+ repetition_penalty=args.repetition_penalty,
161
+ repetition_window=args.repetition_window,
162
+ max_new_tokens=args.max_new_tokens,
163
+ )
164
+
165
+ PRESETS = {
166
+ 'creative': dict(temperature=1.0, top_k=80, top_p=0.95, repetition_penalty=1.10),
167
+ 'balanced': dict(temperature=0.7, top_k=40, top_p=0.90, repetition_penalty=1.15),
168
+ 'precise': dict(temperature=0.35, top_k=20, top_p=0.85, repetition_penalty=1.25),
169
+ }
170
+
171
+ HELP = (
172
+ 'Команды (в скобках -- однобуквенные алиасы):\n'
173
+ ' /help /h показать эту справку\n'
174
+ ' /show /s показать текущие параметры\n'
175
+ ' /reset /r сбросить историю диалога\n'
176
+ ' /system <т> /sys <т> сменить system-промпт + reset\n'
177
+ ' /temp <f> /t <f> temperature (>0)\n'
178
+ ' /top_p <f> /p <f> nucleus sampling (0..1]\n'
179
+ ' /top_k <i> /k <i> top-k (0 = выкл)\n'
180
+ ' /rep <f> /rp <f> repetition_penalty (>=1.0)\n'
181
+ ' /max_tokens<i> /mt <i> лимит длины ответа\n'
182
+ ' /preset <n> /ps <n> ' + ' | '.join(PRESETS.keys()) + '\n'
183
+ ' /quit /q выйти'
184
+ )
185
+
186
+ # Алиасы: первое слово в команде раскрывается в каноническое.
187
+ CANONICAL = {
188
+ '/h': '/help', '/s': '/show', '/r': '/reset',
189
+ '/q': '/quit', '/exit': '/quit',
190
+ '/sys': '/system', '/t': '/temp', '/p': '/top_p', '/k': '/top_k',
191
+ '/rp': '/rep', '/mt': '/max_tokens', '/ps': '/preset',
192
+ }
193
+
194
+ def show_params():
195
+ print(f' system: {system!r}')
196
+ print(f' temperature={params["temperature"]}, top_k={params["top_k"]}, '
197
+ f'top_p={params["top_p"]}, repetition_penalty={params["repetition_penalty"]}, '
198
+ f'max_new_tokens={params["max_new_tokens"]}')
199
+
200
+ def parse_set(line, prefix, kind, validate=None):
201
+ """Распарсить '/cmd value' для одного параметра. Возвращает (ok, value_or_msg)."""
202
+ s = line[len(prefix):].strip()
203
+ if not s:
204
+ return False, f'нужен аргумент: {prefix} <value>'
205
+ try:
206
+ v = kind(s)
207
+ except ValueError:
208
+ return False, f'не могу разобрать как {kind.__name__}: {s!r}'
209
+ if validate is not None:
210
+ err = validate(v)
211
+ if err:
212
+ return False, err
213
+ return True, v
214
+
215
+ history = [] # list[(role, content)]
216
+ system = args.system
217
+ print()
218
+ print('=== chat REPL === /help для списка команд')
219
+ show_params()
220
+ print()
221
+ while True:
222
+ try:
223
+ user = input('you> ').strip()
224
+ except (EOFError, KeyboardInterrupt):
225
+ print()
226
+ break
227
+ if not user:
228
+ continue
229
+
230
+ # Команды: первое слово раскрывается через CANONICAL
231
+ if user.startswith('/'):
232
+ head, _, rest = user.partition(' ')
233
+ cmd = CANONICAL.get(head, head)
234
+ rest = rest.strip()
235
+ full = cmd if not rest else f'{cmd} {rest}'
236
+
237
+ if cmd == '/quit':
238
+ break
239
+ elif cmd == '/help':
240
+ print(HELP)
241
+ elif cmd == '/show':
242
+ show_params()
243
+ elif cmd == '/reset':
244
+ history = []
245
+ print('(история сброшена)')
246
+ elif cmd == '/system':
247
+ system = rest
248
+ history = []
249
+ print(f'(новый system: {system!r}, история сброшена)')
250
+ elif cmd == '/temp':
251
+ ok, v = parse_set(full, '/temp', float,
252
+ lambda x: None if x > 0 else 'temperature должен быть > 0')
253
+ if ok: params['temperature'] = v; print(f'(temperature = {v})')
254
+ else: print(f'! {v}')
255
+ elif cmd == '/top_p':
256
+ ok, v = parse_set(full, '/top_p', float,
257
+ lambda x: None if 0 < x <= 1.0 else 'top_p должен быть в (0..1]')
258
+ if ok: params['top_p'] = v; print(f'(top_p = {v})')
259
+ else: print(f'! {v}')
260
+ elif cmd == '/top_k':
261
+ ok, v = parse_set(full, '/top_k', int,
262
+ lambda x: None if x >= 0 else 'top_k должен быть >= 0')
263
+ if ok: params['top_k'] = v; print(f'(top_k = {v})')
264
+ else: print(f'! {v}')
265
+ elif cmd == '/rep':
266
+ ok, v = parse_set(full, '/rep', float,
267
+ lambda x: None if x >= 1.0 else 'repetition_penalty должен быть >= 1.0')
268
+ if ok: params['repetition_penalty'] = v; print(f'(repetition_penalty = {v})')
269
+ else: print(f'! {v}')
270
+ elif cmd == '/max_tokens':
271
+ ok, v = parse_set(full, '/max_tokens', int,
272
+ lambda x: None if 1 <= x <= 4096 else 'max_tokens в [1..4096]')
273
+ if ok: params['max_new_tokens'] = v; print(f'(max_new_tokens = {v})')
274
+ else: print(f'! {v}')
275
+ elif cmd == '/preset':
276
+ if rest not in PRESETS:
277
+ print(f'! пресет {rest!r} не найден. доступны: {list(PRESETS.keys())}')
278
+ else:
279
+ params.update(PRESETS[rest])
280
+ print(f'(пресет {rest}: {PRESETS[rest]})')
281
+ else:
282
+ print(f'! неизвестная команда {head!r}. /help для списка.')
283
+ continue
284
+
285
+ history.append(('user', user))
286
+ prompt = build_prompt(history, system)
287
+ ids = sp.encode(prompt, out_type=int)
288
+ # обрезаем по block_size слева, оставляя минимум 64 для генерации
289
+ max_ctx = model.config.block_size - 64
290
+ if len(ids) > max_ctx:
291
+ ids = ids[-max_ctx:]
292
+ idx = torch.tensor([ids], dtype=torch.long, device=args.device)
293
+
294
+ # Streaming: после каждого нового токена декодируем весь префикс и печатаем
295
+ # дельту -- так корректно склеиваются подслова BPE (без ▁-артефактов).
296
+ printed = {'text': '', 'ids': []}
297
+ def on_token(nid, all_ids):
298
+ # храним актуальный список id чтобы при Ctrl+C сохранить partial-ответ
299
+ printed['ids'] = list(all_ids)
300
+ full = sp.decode(all_ids)
301
+ delta = full[len(printed['text']):]
302
+ if delta:
303
+ print(delta, end='', flush=True)
304
+ printed['text'] = full
305
+
306
+ print('bot> ', end='', flush=True)
307
+ try:
308
+ with autocast:
309
+ new_ids = generate_until_eot(model, idx, eot_id, params['max_new_tokens'],
310
+ params['temperature'], params['top_k'],
311
+ params['top_p'], params['repetition_penalty'],
312
+ params['repetition_window'], args.device,
313
+ on_token=on_token)
314
+ except KeyboardInterrupt:
315
+ new_ids = printed['ids']
316
+ print('\n(прервано Ctrl+C)')
317
+ print() # перевод строки после финального токена
318
+ reply = sp.decode(new_ids).strip()
319
+ history.append(('assistant', reply))
320
+ print()
321
+
322
+
323
+ if __name__ == '__main__':
324
+ main()
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "nanogpt",
3
+ "architectures": [
4
+ "GPT"
5
+ ],
6
+ "n_layer": 10,
7
+ "n_head": 8,
8
+ "n_embd": 512,
9
+ "block_size": 1024,
10
+ "vocab_size": 32000,
11
+ "bias": false,
12
+ "dropout": 0.1,
13
+ "tie_word_embeddings": true,
14
+ "torch_dtype": "float16",
15
+ "gpt2_equivalent": {
16
+ "n_positions": 1024,
17
+ "n_ctx": 1024,
18
+ "n_embd": 512,
19
+ "n_head": 8,
20
+ "n_layer": 10,
21
+ "vocab_size": 32000
22
+ }
23
+ }
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "temperature": 0.7,
4
+ "top_p": 0.9,
5
+ "top_k": 40,
6
+ "repetition_penalty": 1.15,
7
+ "max_new_tokens": 300,
8
+ "eos_token_id": 5,
9
+ "pad_token_id": 0,
10
+ "transformers_version": "4.x"
11
+ }
meta.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9015d57a544db8e4d826f6db89df6f0ddacb81e5de04a043f00b82b9af0f3150
3
+ size 176
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c2e988e95f6cccb62a68cc8def8fc8dfb6e84571f46d38aa5c34e5774b2ad5e
3
+ size 129527208
nanogpt-chat.exe ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70e61cb8ab61719183a5964ee062469d41ca5b50d2f36e6cbbaad10ddae2919c
3
+ size 100751360
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abb2b369aee6d6a9e2aef3bcb95591546106037796c29c7c77b3c1bb68966c67
3
+ size 129541319
special_tokens_map.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token": "<|endoftext|>",
3
+ "pad_token": "<pad>",
4
+ "unk_token": "<unk>",
5
+ "additional_special_tokens": [
6
+ "<|system|>",
7
+ "<|user|>",
8
+ "<|assistant|>"
9
+ ]
10
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0bc6eb521395e82253f7cfe70df6314be2b1cf10e45756a74cb0d511cd66d17
3
+ size 952151
tokenizer_config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast",
3
+ "model_max_length": 1024,
4
+ "bos_token": null,
5
+ "eos_token": "<|endoftext|>",
6
+ "pad_token": "<pad>",
7
+ "unk_token": "<unk>",
8
+ "additional_special_tokens": [
9
+ "<|system|>",
10
+ "<|user|>",
11
+ "<|assistant|>"
12
+ ],
13
+ "clean_up_tokenization_spaces": false,
14
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}<|system|>{{ message['content'] }}<|endoftext|>{% elif message['role'] == 'user' %}<|user|>{{ message['content'] }}<|endoftext|>{% elif message['role'] == 'assistant' %}<|assistant|>{{ message['content'] }}<|endoftext|>{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
15
+ "sp_model_file": "tokenizer.model",
16
+ "special_token_ids": {
17
+ "pad": 0,
18
+ "unk": 1,
19
+ "system": 2,
20
+ "user": 3,
21
+ "assistant": 4,
22
+ "endoftext": 5
23
+ }
24
+ }