itriedcoding committed on
Commit
1228d03
·
verified ·
1 Parent(s): 6384860

Upload modeling_sage_1b.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_sage_1b.py +340 -0
modeling_sage_1b.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, pickle, math, time, sys
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, processors
7
+ from datasets import load_dataset
8
+ import numpy as np
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
# Environment flag consumed by huggingface_hub; set before any Hub access.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# All artefacts (raw data, tokenizer, checkpoints) are stored next to this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")
TOKENIZER_DIR = os.path.join(BASE_DIR, "tokenizer")
MODEL_DIR = os.path.join(BASE_DIR, "model")
for _directory in (DATA_DIR, TOKENIZER_DIR, MODEL_DIR):
    os.makedirs(_directory, exist_ok=True)

torch.set_num_threads(4)

# ---- Model shape ----
VOCAB_SIZE = 50000
HIDDEN_SIZE = 1536
NUM_LAYERS = 30
NUM_HEADS = 12
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS
INTERMEDIATE_SIZE = 6144
MAX_SEQ_LEN = 128

# ---- Data / optimisation ----
NUM_SAMPLES = 10000
TRAIN_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 4
LEARNING_RATE = 4e-4
NUM_EPOCHS = 3
WARMUP_STEPS = 50

# Back-of-the-envelope parameter count: input embedding, per-layer weights
# (four square attention projections, three FFN matrices, two norm vectors),
# and the output projection.
_embedding_p = VOCAB_SIZE * HIDDEN_SIZE
_attention_p = 4 * HIDDEN_SIZE * HIDDEN_SIZE
_ffn_p = 3 * HIDDEN_SIZE * INTERMEDIATE_SIZE
_norm_p = 2 * HIDDEN_SIZE
_lm_head_p = HIDDEN_SIZE * VOCAB_SIZE
total_p = _embedding_p + NUM_LAYERS * (_attention_p + _ffn_p + _norm_p) + _lm_head_p
print(f"=== Sage 1B ({total_p/1e9:.3f}B params) ===")
42
+
43
# ====== STEP 1: Load English Dataset ======
# Stream the first NUM_SAMPLES TinyStories records and keep those with at
# least 100 characters of text. If that leaves fewer than NUM_SAMPLES texts,
# top up from Wikipedia on a best-effort basis.
print("\n--- Step 1: Loading English text dataset ---")
dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
samples = []
start = time.time()
for i, example in enumerate(dataset):
    if i >= NUM_SAMPLES:
        break
    # `or ""` guards against records whose text field is present but None.
    text = (example.get("text") or "").strip()
    if len(text) >= 100:
        samples.append(text)
    if (i+1) % 10000 == 0:
        print(f" {i+1}/{NUM_SAMPLES} scanned, {len(samples)} valid ({time.time()-start:.0f}s)")

# Supplement with more if needed. The threshold is NUM_SAMPLES (previously a
# hard-coded 10000) so it stays in step with the loop bound used below.
if len(samples) < NUM_SAMPLES:
    print(f" Only {len(samples)} valid samples. Trying additional sources...")
    try:
        ds2 = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)
        for i, ex in enumerate(ds2):
            if len(samples) >= NUM_SAMPLES:
                break
            text = (ex.get("text") or "").strip()
            if len(text) >= 200:
                # Keep only the first 2000 characters of each article.
                samples.append(text[:2000])
            if (i+1) % 5000 == 0:
                print(f" wiki: {i+1} scanned, {len(samples)} total")
    except Exception as e:
        # Deliberate best-effort: a failed supplement must not abort the run.
        print(f" Wikipedia supplement failed: {e}")

print(f"Collected {len(samples)} samples in {time.time()-start:.0f}s")
with open(os.path.join(DATA_DIR, "raw_texts.pkl"), "wb") as f:
    pickle.dump(samples, f)
76
+
77
# ====== STEP 2: Train BPE Tokenizer ======
# Byte-level BPE trained on the collected texts and saved as tokenizer.json.
print("\n--- Step 2: Training BPE tokenizer ---")
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    # Order matters: the rest of the script relies on the ids these receive.
    special_tokens=["<PAD>", "<UNK>", "<BOS>", "<EOS>"],
    min_frequency=2,
    # Seed the vocabulary with every byte-level symbol. Without this, bytes
    # absent from the training texts have no token, and since the BPE model
    # has no unk_token configured they would be dropped at encode time.
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
tokenizer.train_from_iterator(samples, trainer=trainer)
tokenizer.save(os.path.join(TOKENIZER_DIR, "tokenizer.json"))
print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
91
+
92
# ====== STEP 3: Tokenize ======
# Every text becomes a fixed-length row: <BOS> + (truncated) ids + <EOS>,
# right-padded with <PAD> up to MAX_SEQ_LEN.
print("\n--- Step 3: Tokenizing ---")
pad_id = tokenizer.token_to_id("<PAD>")
bos_id = tokenizer.token_to_id("<BOS>")
eos_id = tokenizer.token_to_id("<EOS>")

_content_budget = MAX_SEQ_LEN - 2  # two slots are reserved for BOS and EOS
tokenized = []
for text in samples:
    content = tokenizer.encode(text).ids[:_content_budget]
    row = [bos_id, *content, eos_id]
    # Multiplying by a non-positive count yields [], so full rows are untouched.
    row.extend([pad_id] * (MAX_SEQ_LEN - len(row)))
    tokenized.append(row)

tensor_data = torch.tensor(tokenized, dtype=torch.long)
torch.save(tensor_data, os.path.join(DATA_DIR, "tokenized.pt"))
print(f"Tokenized {len(tokenized)} sequences, shape: {tensor_data.shape}")

# ====== STEP 4: Build Model ======
print("\n--- Step 4: Building Sage 1B model ---")
112
+
113
class RotaryEmbedding(nn.Module):
    """Lazily built, cached cos/sin tables for rotary position embeddings.

    Args:
        dim: Size of the per-head feature dimension the tables are built for.
        max_seq_len: Minimum number of positions materialised when the cache
            is (re)built.
    """

    def __init__(self, dim, max_seq_len=MAX_SEQ_LEN):
        super().__init__()
        # One inverse frequency per pair of feature channels.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        # Filled on first use by get_cos_sin().
        self._cos = None
        self._sin = None

    def get_cos_sin(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables covering the first ``seq_len`` positions.

        Args:
            x: Tensor whose device selects where the tables live; its second
                dimension supplies ``seq_len`` when that is not given.
            seq_len: Number of positions required.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, dim)``.
        """
        if seq_len is None:
            seq_len = x.size(1)

        cache_is_stale = (
            self._cos is None
            or self._cos.size(-2) < seq_len
            or self._cos.device != x.device
        )
        if cache_is_stale:
            # Grow the table to the requested length when it exceeds
            # max_seq_len; building only max_seq_len positions would rebuild
            # on every call and still return a table that is too short.
            table_len = max(seq_len, self.max_seq_len)
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(table_len, device=x.device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            # Duplicate so the table spans all `dim` channels, then add two
            # leading broadcast dimensions.
            emb = torch.cat((freqs, freqs), dim=-1)[None, None]
            self._cos = emb.cos()
            self._sin = emb.sin()
        return self._cos[..., :seq_len, :], self._sin[..., :seq_len, :]
130
+
131
def rotate_half(x):
    """Split the last dimension in two chunks (a, b) and return cat(-b, a)."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
134
+
135
def apply_rotary(x, cos, sin):
    """Combine ``x`` with its half-rotated copy using the cos/sin tables."""
    unrotated_part = x * cos
    rotated_part = rotate_half(x) * sin
    return unrotated_part + rotated_part
137
+
138
class Attention(nn.Module):
    """Multi-head self-attention with rotary embedding applied to queries and keys."""

    def __init__(self):
        super().__init__()

        def square_projection():
            # Bias-free HIDDEN_SIZE -> HIDDEN_SIZE linear map.
            return nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, bias=False)

        self.q_proj = square_projection()
        self.k_proj = square_projection()
        self.v_proj = square_projection()
        self.o_proj = square_projection()

    def forward(self, x, cos, sin, mask):
        """Attend over ``x`` of shape (batch, seq, HIDDEN_SIZE).

        ``cos``/``sin`` are the rotary tables and ``mask`` is an additive
        mask that is cropped to the current sequence length before use.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return t.reshape(batch, seq_len, NUM_HEADS, HEAD_DIM).transpose(1, 2)

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        queries = apply_rotary(queries, cos, sin)
        keys = apply_rotary(keys, cos, sin)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(HEAD_DIM)
        scores = scores + mask[:, :, :seq_len, :seq_len]
        weights = F.softmax(scores, dim=-1)

        # Merge the heads back into a single hidden dimension.
        context = weights.matmul(values).transpose(1, 2).reshape(batch, seq_len, HIDDEN_SIZE)
        return self.o_proj(context)
155
+
156
class FeedForward(nn.Module):
    """Gated feed-forward block: down(silu(gate(x)) * up(x)), all bias-free."""

    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.up = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.down = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)

    def forward(self, x):
        """Project up through the gated branch pair, then back to HIDDEN_SIZE."""
        gated = F.silu(self.gate(x))
        lifted = self.up(x)
        return self.down(gated * lifted)
164
+
165
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self):
        super().__init__()
        self.attn_norm = nn.RMSNorm(HIDDEN_SIZE, eps=1e-6)
        self.ffn_norm = nn.RMSNorm(HIDDEN_SIZE, eps=1e-6)
        self.attn = Attention()
        self.ffn = FeedForward()

    def forward(self, x, cos, sin, mask):
        """Apply both sub-layers, each normalised on input and added back residually."""
        attended = self.attn(self.attn_norm(x), cos, sin, mask)
        x = x + attended
        transformed = self.ffn(self.ffn_norm(x))
        return x + transformed
176
+
177
# Cache of causal masks keyed by (sequence length, device string).
mask_cache = {}


def get_causal_mask(T, device):
    """Return an additive causal mask of shape ``(1, 1, T, T)`` on ``device``.

    Entries strictly above the diagonal are ``-inf``; all others are 0.
    Masks are cached per (T, device): keying on T alone would hand a mask
    built on one device to a caller running on another.
    """
    key = (T, str(device))
    if key not in mask_cache:
        mask_cache[key] = torch.triu(
            torch.full((T, T), float('-inf'), device=device), diagonal=1
        )
    return mask_cache[key][None, None]
183
+
184
class Sage1B(nn.Module):
    """Decoder-only transformer language model built from TransformerBlock layers."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.layers = nn.ModuleList([TransformerBlock() for _ in range(NUM_LAYERS)])
        self.norm = nn.RMSNorm(HIDDEN_SIZE, eps=1e-6)
        self.lm_head = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False)
        self.rotary = RotaryEmbedding(HEAD_DIM)
        self.max_seq_len = MAX_SEQ_LEN
        self.vocab_size = VOCAB_SIZE
        self.hidden_size = HIDDEN_SIZE

    def forward(self, input_ids, labels=None):
        """Run the model over ``input_ids`` of shape (batch, seq).

        Args:
            input_ids: Integer token ids.
            labels: Optional target ids with the same shape as ``input_ids``.
                Positions equal to 0 are excluded from the loss.

        Returns:
            ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is not given.
        """
        B, T = input_ids.shape
        # Embeddings are scaled by sqrt(hidden size) before entering the stack.
        x = self.embed_tokens(input_ids) * math.sqrt(HIDDEN_SIZE)
        cos, sin = self.rotary.get_cos_sin(x, T)
        mask = get_causal_mask(T, x.device)
        for layer in self.layers:
            x = layer(x, cos, sin, mask)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous label slices are accepted.
            # NOTE(review): ignore_index=0 assumes <PAD> has token id 0 - confirm
            # against the trained tokenizer.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=0,
            )
        return loss, logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=40,
                 eos_token_id=3):
        """Sample up to ``max_new_tokens`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt ids of shape (batch, seq).
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Divisor applied to the last-position logits.
            top_k: If > 0, sample only from the ``top_k`` highest logits
                (clamped to the vocabulary size).
            eos_token_id: Id that ends a sequence. Defaults to 3.
                NOTE(review): presumably the id of <EOS> - confirm.

        Returns:
            The prompt with the sampled tokens appended. Rows that finish
            early are filled with ``eos_token_id`` until every row is done.
        """
        self.eval()
        # Per-row flag so batches larger than one can stop correctly; calling
        # .item() on the sampled tensor raises when batch size > 1.
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            # Crop only the context fed to the model so the returned sequence
            # keeps the whole prompt and all generated tokens.
            context = input_ids[:, -MAX_SEQ_LEN:]
            _, logits = self.forward(context)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float('-inf'))
            probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            # Rows that already ended keep emitting the end token.
            nxt = torch.where(finished.unsqueeze(1), torch.full_like(nxt, eos_token_id), nxt)
            input_ids = torch.cat([input_ids, nxt], dim=1)
            finished = finished | (nxt.squeeze(1) == eos_token_id)
            if bool(finished.all()):
                break
        return input_ids
227
+
228
# Instantiate the model and report its actual parameter count.
model = Sage1B()
total_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {total_params:,} ({total_params/1e9:.3f}B)")

# Hyper-parameters written alongside the weights as config.json.
config = {
    "vocab_size": VOCAB_SIZE, "hidden_size": HIDDEN_SIZE,
    "num_hidden_layers": NUM_LAYERS, "num_attention_heads": NUM_HEADS,
    "head_dim": HEAD_DIM, "intermediate_size": INTERMEDIATE_SIZE,
    "max_position_embeddings": MAX_SEQ_LEN, "model_type": "sage_1b",
    "total_params": total_params, "torch_dtype": "float32",
}
with open(os.path.join(MODEL_DIR, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

# Copy this file as modeling_sage_1b.py for HF distribution.
# Both files are opened in context managers so the source handle is closed
# as well, and the encoding is pinned so the copy does not depend on the
# platform default.
with open(os.path.abspath(__file__), "r", encoding="utf-8") as src, \
        open(os.path.join(MODEL_DIR, "modeling_sage_1b.py"), "w", encoding="utf-8") as f:
    f.write(src.read())
245
+
246
# ====== STEP 5: Train ======
# Re-read the tokenized tensor that Step 3 wrote to disk.
print("\n--- Step 5: Training ---")
_tokenized_path = os.path.join(DATA_DIR, "tokenized.pt")
data = torch.load(_tokenized_path)
print(f"Training samples: {len(data)}")
250
+
251
class TextDataset(Dataset):
    """Serves (input, target) pairs where the target is the input shifted by one."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        inputs = row[:-1]   # everything except the last token
        targets = row[1:]   # everything except the first token
        return inputs, targets
259
+
260
# Shuffled mini-batches; an incomplete final batch is dropped.
tds = TextDataset(data)
loader = DataLoader(
    tds,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# AdamW over all model parameters; the learning rate is overwritten per
# optimizer step by the training loop.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
264
+
265
def get_lr(step):
    """Learning rate for optimizer step ``step``.

    Ramps linearly up to LEARNING_RATE over the first WARMUP_STEPS steps,
    then decays linearly, reaching 10% of LEARNING_RATE at step 10000 and
    staying there.
    """
    if step >= WARMUP_STEPS:
        progress = min(step, 10000) / 10000
        return LEARNING_RATE * (1 - progress * 0.9)
    return LEARNING_RATE * (step + 1) / WARMUP_STEPS
269
+
270
# Training loop with gradient accumulation. After each epoch the weights are
# checkpointed if the epoch's mean training loss is the best seen so far.
best_loss = float('inf')
global_step = 0
batches_per_epoch = len(loader)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    n_batches = 0
    optimizer.zero_grad()
    epoch_start = time.time()

    for bidx, (inp, tgt) in enumerate(loader):
        loss, _ = model(inp, labels=tgt)
        # Scale so the accumulated gradient is the mean over the group.
        loss = loss / GRAD_ACCUM_STEPS
        loss.backward()

        # Step on every full accumulation group AND on the last batch of the
        # epoch. Without the second condition, gradients from a trailing
        # partial group would be thrown away by the zero_grad() at the start
        # of the next epoch. (A partial group is still divided by
        # GRAD_ACCUM_STEPS, so its update is proportionally smaller.)
        batches_seen = bidx + 1
        if batches_seen % GRAD_ACCUM_STEPS == 0 or batches_seen == batches_per_epoch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            for pg in optimizer.param_groups:
                pg['lr'] = get_lr(global_step)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

        # Undo the accumulation scaling so the running average is per batch.
        total_loss += loss.item() * GRAD_ACCUM_STEPS
        n_batches += 1

        if batches_seen % 200 == 0:
            elapsed = time.time() - epoch_start
            avg = total_loss / max(n_batches, 1)
            lr = optimizer.param_groups[0]['lr']
            print(f" E{epoch+1} B{bidx+1}/{len(loader)} | Loss: {avg:.4f} | LR: {lr:.2e} | {elapsed:.0f}s")

    avg = total_loss / max(n_batches, 1)
    et = time.time() - epoch_start
    print(f"Epoch {epoch+1} | Avg loss: {avg:.4f} | Time: {et:.0f}s | Steps: {global_step}")

    if avg < best_loss:
        best_loss = avg
        sd = model.state_dict()
        # Full-precision checkpoint plus a copy with float32 tensors cast to half.
        torch.save(sd, os.path.join(MODEL_DIR, "pytorch_model.bin"))
        torch.save({k: v.half() if v.dtype == torch.float32 else v for k, v in sd.items()},
                   os.path.join(MODEL_DIR, "pytorch_model_state.bin"))
        print(f" Best model saved (loss: {avg:.4f})")
313
+
314
# Final save
# NOTE(review): this writes to the same two paths as the best-epoch
# checkpoint, so the last-epoch weights replace the best ones - confirm
# that is intended.
sd = model.state_dict()
torch.save(sd, os.path.join(MODEL_DIR, "pytorch_model.bin"))
torch.save({k: v.half() if v.dtype == torch.float32 else v for k, v in sd.items()},
           os.path.join(MODEL_DIR, "pytorch_model_state.bin"))

# Save tokenizer pickle
with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "wb") as f:
    pickle.dump(tokenizer, f)

# Test generation
print("\n--- Test generation ---")
model.eval()
# Reload from disk so the saved tokenizer.json itself is exercised.
test_tokenizer = Tokenizer.from_file(os.path.join(TOKENIZER_DIR, "tokenizer.json"))
prompt = "Once upon a time"
tokens = test_tokenizer.encode(prompt).ids
# Look the <BOS> id up instead of hard-coding 1, so the prompt stays correct
# if the special-token order ever changes.
bos_token_id = test_tokenizer.token_to_id("<BOS>")
inp = torch.tensor([[bos_token_id] + tokens[:20]], dtype=torch.long)
out = model.generate(inp, max_new_tokens=30, temperature=0.7)
gen_text = test_tokenizer.decode(out[0].tolist(), skip_special_tokens=True)
print(f"Prompt: {prompt}")
print(f"Generated: {gen_text}")

print("\n=== DONE ===")
print(f"Params: {total_params:,} ({total_params/1e9:.3f}B)")
print(f"Best loss: {best_loss:.4f}")
print(f"Model: {MODEL_DIR}")