daniilkolbasenko committed on
Commit
c0b8285
·
verified ·
1 Parent(s): cd40371

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +267 -0
app.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import DataLoader, Dataset
7
+ import tiktoken
8
+ import gradio as gr
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ from datasets import load_dataset
12
+
13
# ---------- 1. Hard resource limits ----------
# Target machine: 12 CPU cores and ~13 GB RAM.
torch.set_num_threads(12)
torch.set_num_interop_threads(12)

# Everything runs on the CPU, so there is no CUDA cache to manage.
device = 'cpu'

# --- Model hyperparameters (sized for ~13 GB RAM) ---
vocab_size = 50257   # size of the GPT-2 BPE vocabulary
block_size = 256     # maximum context length, in tokens
n_embd = 384         # embedding width
n_head = 6           # attention heads per block
n_layer = 6          # number of transformer blocks
dropout = 0.1

# --- Training hyperparameters (reduced to save memory) ---
batch_size = 24      # was 32 -> reduced
learning_rate = 5e-4
max_iters = 15000
eval_interval = 500  # evaluate every this many training steps
eval_iters = 100     # batches averaged per loss estimate
warmup_iters = 500   # steps of linear learning-rate warmup

# --- DataLoader parameters (moderate) ---
num_workers = 6      # was 8 -> reduced
prefetch_factor = 4
# Page-locked host memory only speeds up host-to-accelerator copies. On a
# CPU-only run it brings no benefit and makes DataLoader emit a warning, so it
# is enabled only when the target device is not the CPU.
pin_memory = device != 'cpu'

print(f"Устройство: {device}")
print(f"Используется CPU потоков: {torch.get_num_threads()}")
45
+
46
# ---------- 2. Dataset and tokenization ----------
print("\n[1/5] Загрузка и токенизация датасета...")
dataset = load_dataset("JoshKeesee/Alfred-Indigo", split="train")

# Flatten every conversation into "role: content" lines, one string per dialogue.
dialogue_texts = [
    "\n".join(f"{msg['role']}: {msg['content']}" for msg in example['messages'])
    for example in dataset
]
# Dialogues are separated by a blank line in the training corpus.
all_text = "\n\n".join(dialogue_texts)
print(f"Загружено {len(dialogue_texts)} диалогов. Общий объём: {len(all_text)} символов.")

# Encode the whole corpus once with the GPT-2 BPE tokenizer.
enc = tiktoken.get_encoding("gpt2")
data = torch.tensor(enc.encode_ordinary(all_text), dtype=torch.long)

# The first 90% of the tokens are used for training, the rest for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
62
+
63
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a token sequence.

    Sample ``i`` is the pair
    ``(data[i:i + block_size], data[i + 1:i + block_size + 1])``: a context
    window and the same window shifted one token ahead (the targets).
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Each sample needs block_size + 1 tokens. Clamp at zero so a corpus
        # shorter than that is an empty dataset rather than a negative length,
        # which would make len() raise.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(inputs, targets)`` windows starting at ``idx``.

        Negative indices count from the end. Raises ``IndexError`` when
        ``idx`` is out of range instead of returning truncated windows.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        stop = idx + self.block_size
        x = self.data[idx:stop]
        y = self.data[idx + 1:stop + 1]
        return x, y
73
+
74
train_dataset = TextDataset(train_data, block_size)
val_dataset = TextDataset(val_data, block_size)

# Settings shared by both loaders (moderate number of worker processes).
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
    prefetch_factor=prefetch_factor,
)

# Training windows are drawn in random order; validation windows in file order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
95
+
96
+ # ---------- 3. Архитектура модели (оптимизированная) ----------
97
class AttentionHead(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width ``head_size`` and
    lets every position attend only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular matrix used as the causal mask (not a parameter).
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply masked attention to ``x`` of shape (B, T, n_embd) -> (B, T, head_size)."""
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product attention divides by sqrt of the key/query width
        # (the head size). Using the embedding width C here would over-shrink
        # the scores and flatten the attention distribution.
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        # Block attention to future positions before normalising.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v
115
+
116
class MultiHeadAttention(nn.Module):
    """Runs ``n_head`` attention heads side by side and mixes their outputs."""

    def __init__(self):
        super().__init__()
        per_head = n_embd // n_head
        head_modules = [AttentionHead(per_head) for _ in range(n_head)]
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Each head sees the same input; results are joined on the feature axis.
        head_outputs = [head(x) for head in self.heads]
        merged = torch.cat(head_outputs, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)
127
+
128
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x the embedding width, GELU, project back."""

    def __init__(self):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
139
+
140
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each with a residual."""

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention()
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = FeedForward()

    def forward(self, x):
        # LayerNorm is applied to the input of each sub-layer, not its output.
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        x = x + transformed
        return x
151
+
152
class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings feed a stack of ``n_layer``
    transformer blocks, followed by a final LayerNorm and a linear head that
    produces logits over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[TransformerBlock() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits for ``idx`` and, if ``targets`` is given, the loss.

        Args:
            idx: token ids of shape (B, T).
            targets: optional token ids of shape (B, T).

        Returns:
            ``(logits, loss)``. Without targets, logits have shape
            (B, T, vocab_size) and loss is ``None``. With targets, logits are
            flattened to (B*T, vocab_size) and loss is the cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)
        # Build the position indices on the same device as the input so the
        # model does not depend on the module-level `device` setting.
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids building a graph
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: starting token ids of shape (B, T).
            max_new_tokens: number of tokens to sample.
            temperature: positive divisor applied to the logits before softmax.
            top_k: if not ``None``, sample only among the k most likely tokens.

        Returns:
            Tensor of shape (B, T + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        for _ in range(max_new_tokens):
            # The position embedding only covers block_size positions, so keep
            # just the most recent block_size tokens as context.
            idx_cond = idx[:, -block_size:]
            logits, _ = self.forward(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Mask everything below the k-th largest logit.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
187
+
188
# Build the network, then wrap it with torch.compile (available in PyTorch 2.x).
model = torch.compile(GPTLanguageModel())
_param_count = sum(p.numel() for p in model.parameters())
print(f"Модель создана. Параметров: {_param_count/1e6:.2f}M")
192
+
193
+ # ---------- 4. Обучение ----------
194
def get_batch_from_loader(loader):
    """Lazily yield ``(x, y)`` pairs from ``loader``, one per batch."""
    for batch in loader:
        inputs, targets = batch
        yield inputs, targets
197
+
198
@torch.no_grad()  # evaluation only: do not build autograd graphs
def estimate_loss():
    """Estimate the mean loss on the train and validation loaders.

    Averages the loss over ``eval_iters`` batches from each loader with the
    model in eval mode, then puts the model back into train mode.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean loss for that split.
    """
    out = {}
    model.eval()
    try:
        for split, loader in [('train', train_loader), ('val', val_loader)]:
            losses = torch.zeros(eval_iters)
            loader_iter = iter(loader)
            for k in range(eval_iters):
                try:
                    X, Y = next(loader_iter)
                except StopIteration:
                    # The loader has fewer than eval_iters batches: start over.
                    loader_iter = iter(loader)
                    X, Y = next(loader_iter)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore train mode even if evaluation fails part-way through.
        model.train()
    return out
215
+
216
# AdamW over all model parameters; the learning rate set here is overwritten
# every step by the schedule in the training loop.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.1,
)
217
+
218
def get_lr(it):
    """Return the learning rate for step ``it``.

    The rate ramps up linearly over the first ``warmup_iters`` steps and stays
    at ``learning_rate`` afterwards.
    """
    warming_up = it < warmup_iters
    if not warming_up:
        return learning_rate
    return learning_rate * (it + 1) / warmup_iters
222
+
223
print("\n[2/5] Старт обучения (ограничение 12 CPU / 13 ГБ RAM)...")
start_time = time.time()

# Keep one long-lived iterator over the training loader. Calling
# iter(train_loader) on every step would start a fresh set of DataLoader worker
# processes (and throw away their prefetched batches) for each batch.
train_iter = iter(train_loader)

for iter_num in tqdm(range(max_iters), desc="Обучение"):
    # Apply the scheduled learning rate to every parameter group.
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Periodically (and on the last step) report train/validation loss.
    if iter_num % eval_interval == 0 or iter_num == max_iters - 1:
        losses = estimate_loss()
        elapsed = time.time() - start_time
        print(f"\nШаг {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f} (время {elapsed:.2f} с)")

    try:
        xb, yb = next(train_iter)
    except StopIteration:
        # End of an epoch: start a new (reshuffled) pass over the data.
        train_iter = iter(train_loader)
        xb, yb = next(train_iter)

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Clip the global gradient norm to keep updates bounded.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

print(f"\nОбучение завершено! Время: {(time.time() - start_time)/60:.2f} мин")

# Save the weights
os.makedirs('checkpoints', exist_ok=True)
# torch.compile wraps the network and exposes the original module as
# `_orig_mod`; save that module's weights so the checkpoint keys carry no
# wrapper prefix. Fall back to the model itself if it was not compiled.
torch.save(getattr(model, '_orig_mod', model).state_dict(), 'checkpoints/model_final.pth')
print("Модель сохранена в 'checkpoints/model_final.pth'")
249
+
250
+ # ---------- 5. Интерфейс Gradio ----------
251
def generate_response(prompt, max_new_tokens=150, temperature=0.7, top_k=40):
    """Generate a continuation of ``prompt`` with the trained model.

    Args:
        prompt: text to continue.
        max_new_tokens: number of tokens to sample.
        temperature: sampling temperature passed to ``model.generate``.
        top_k: top-k cutoff passed to ``model.generate``.

    Returns:
        The decoded text of the prompt followed by the sampled continuation.
    """
    token_ids = enc.encode_ordinary(prompt)
    # An empty prompt would produce a zero-length context, which has no last
    # position to read logits from. Seed generation with the end-of-text token
    # instead and drop it again before decoding.
    seeded = not token_ids
    if seeded:
        token_ids = [enc.eot_token]
    context = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Sample with dropout disabled and without tracking gradients, then put the
    # model back into whatever mode it was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_ids = model.generate(
                context,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
            )[0].tolist()
    finally:
        model.train(was_training)

    if seeded:
        generated_ids = generated_ids[1:]
    return enc.decode(generated_ids)
255
+
256
def chat_function(message, history):
    """Gradio chat callback: answer ``message``; ``history`` is not used."""
    reply = generate_response(message)
    return reply
258
+
259
# Chat UI backed by chat_function.
demo = gr.ChatInterface(
    title="🤖 GPT обучена с нуля (12 CPU / 13 ГБ RAM)",
    description="Модель обучена на Alfred-Indigo, 6 слоёв, 6 голов внимания, контекст 256 токенов. Ограничение ресурсов: 12 ядер CPU, ~13 ГБ RAM.",
    theme="soft",
    fn=chat_function,
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()