DevHunterAI commited on
Commit
c610c1d
·
verified ·
1 Parent(s): 8827556

Upload hssm_v2_gpu_pretrain.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hssm_v2_gpu_pretrain.py +437 -0
hssm_v2_gpu_pretrain.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HSSM v2 GPU Pretraining - Colab A6000 optimized"""
2
+ import argparse
3
+ import contextlib
4
+ import json
5
+ import os
6
+ import time
7
+ from dataclasses import asdict, dataclass
8
+ from pathlib import Path
9
+ from typing import Dict, Iterator, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
15
+ from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
16
+ from datasets import load_dataset
17
+
18
+ @dataclass
19
+ class HSSMV2Config:
20
+ vocab_size: int
21
+ d_model: int = 288
22
+ n_layers: int = 10
23
+ d_ff: int = 512
24
+ state_rank: int = 128
25
+ chunk_size: int = 8
26
+ dropout: float = 0.0
27
+ max_seq_len: int = 1024
28
+ tie_embeddings: bool = True
29
+ num_experts: int = 64
30
+ experts_per_token: int = 1
31
+ expert_dim: int = 2048
32
+ moe_every: int = 4
33
+ aux_loss_coef: float = 1e-2
34
+
35
+
36
+ class RMSNorm(nn.Module):
37
+ def __init__(self, dim: int, eps: float = 1e-6):
38
+ super().__init__()
39
+ self.weight = nn.Parameter(torch.ones(dim))
40
+ self.eps = eps
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ norm = x.pow(2).mean(dim=-1, keepdim=True)
44
+ return x * torch.rsqrt(norm + self.eps) * self.weight
45
+
46
+
47
+ class HierarchicalStateMixer(nn.Module):
48
+ def __init__(self, config: HSSMV2Config):
49
+ super().__init__()
50
+ self.d_model = config.d_model
51
+ self.state_rank = config.state_rank
52
+ self.chunk_size = config.chunk_size
53
+ self.in_proj = nn.Linear(config.d_model, config.d_model * 3)
54
+ self.depthwise = nn.Conv1d(
55
+ config.d_model, config.d_model,
56
+ kernel_size=5, padding=2, groups=config.d_model
57
+ )
58
+ self.chunk_proj = nn.Linear(config.d_model, config.d_model)
59
+ self.state_in = nn.Linear(config.d_model, config.state_rank)
60
+ self.state_out = nn.Linear(config.state_rank, config.d_model)
61
+ self.out_proj = nn.Linear(config.d_model, config.d_model)
62
+
63
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
64
+ gate, value, residual = self.in_proj(x).chunk(3, dim=-1)
65
+ local = self.depthwise(value.transpose(1, 2)).transpose(1, 2)
66
+
67
+ batch, seq_len, dim = local.shape
68
+ pad_len = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
69
+ if pad_len:
70
+ local_padded = F.pad(local, (0, 0, 0, pad_len))
71
+ else:
72
+ local_padded = local
73
+ num_chunks = local_padded.size(1) // self.chunk_size
74
+ chunked = local_padded.view(batch, num_chunks, self.chunk_size, dim).mean(dim=2)
75
+ chunked = self.chunk_proj(chunked)
76
+ states = torch.tanh(self.state_in(chunked))
77
+ states = self.state_out(states)
78
+ expanded = states.repeat_interleave(self.chunk_size, dim=1)[:, :seq_len, :]
79
+
80
+ mixed = local + expanded + residual
81
+ return self.out_proj(torch.sigmoid(gate) * mixed)
82
+
83
+
84
+ class GatedMLP(nn.Module):
85
+ def __init__(self, config: HSSMV2Config):
86
+ super().__init__()
87
+ self.up_proj = nn.Linear(config.d_model, config.d_ff)
88
+ self.gate_proj = nn.Linear(config.d_model, config.d_ff)
89
+ self.down_proj = nn.Linear(config.d_ff, config.d_model)
90
+
91
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
92
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
93
+
94
+
95
+ class ExpertMLP(nn.Module):
96
+ def __init__(self, d_model: int, expert_dim: int):
97
+ super().__init__()
98
+ self.up_proj = nn.Linear(d_model, expert_dim)
99
+ self.gate_proj = nn.Linear(d_model, expert_dim)
100
+ self.down_proj = nn.Linear(expert_dim, d_model)
101
+
102
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
103
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
104
+
105
+
106
+ class SparseMoE(nn.Module):
107
+ def __init__(self, config: HSSMV2Config):
108
+ super().__init__()
109
+ self.num_experts = config.num_experts
110
+ self.experts_per_token = config.experts_per_token
111
+ self.router = nn.Linear(config.d_model, config.num_experts, bias=False)
112
+ self.experts = nn.ModuleList([
113
+ ExpertMLP(config.d_model, config.expert_dim) for _ in range(config.num_experts)
114
+ ])
115
+
116
+ def forward(self, x: torch.Tensor):
117
+ batch, seq_len, d_model = x.shape
118
+ x_flat = x.reshape(-1, d_model)
119
+ router_logits = self.router(x_flat)
120
+ router_probs = F.softmax(router_logits, dim=-1)
121
+ topk_weights, topk_indices = torch.topk(router_probs, k=self.experts_per_token, dim=-1)
122
+ if self.experts_per_token > 1:
123
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
124
+
125
+ output = torch.zeros_like(x_flat)
126
+ expert_load = []
127
+ for expert_id, expert in enumerate(self.experts):
128
+ token_mask = topk_indices == expert_id
129
+ expert_load.append(token_mask.any(dim=-1).float().mean())
130
+ if not token_mask.any():
131
+ continue
132
+ token_positions, slot_positions = torch.where(token_mask)
133
+ expert_input = x_flat.index_select(0, token_positions)
134
+ expert_output = expert(expert_input)
135
+ expert_weight = topk_weights[token_positions, slot_positions].unsqueeze(-1)
136
+ output.index_add_(0, token_positions, expert_output * expert_weight)
137
+
138
+ importance = router_probs.mean(dim=0)
139
+ load = torch.stack(expert_load)
140
+ aux_loss = self.num_experts * torch.sum(importance * load)
141
+ return output.view(batch, seq_len, d_model), aux_loss
142
+
143
+
144
+ class HSSMV2Block(nn.Module):
145
+ def __init__(self, config: HSSMV2Config, use_moe: bool = False):
146
+ super().__init__()
147
+ self.norm1 = RMSNorm(config.d_model)
148
+ self.mixer = HierarchicalStateMixer(config)
149
+ self.norm2 = RMSNorm(config.d_model)
150
+ self.use_moe = use_moe
151
+ self.ff = SparseMoE(config) if use_moe else GatedMLP(config)
152
+
153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
154
+ x = x + self.mixer(self.norm1(x))
155
+ if self.use_moe:
156
+ ff_out, aux_loss = self.ff(self.norm2(x))
157
+ x = x + ff_out
158
+ return x, aux_loss
159
+ return x + self.ff(self.norm2(x)), x.new_zeros(())
160
+
161
+
162
+ class HSSMV2LM(nn.Module):
163
+ def __init__(self, config: HSSMV2Config):
164
+ super().__init__()
165
+ self.config = config
166
+ self.embed = nn.Embedding(config.vocab_size, config.d_model)
167
+ self.blocks = nn.ModuleList([
168
+ HSSMV2Block(config, use_moe=((layer_idx + 1) % config.moe_every == 0))
169
+ for layer_idx in range(config.n_layers)
170
+ ])
171
+ self.norm = RMSNorm(config.d_model)
172
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
173
+ if config.tie_embeddings:
174
+ self.lm_head.weight = self.embed.weight
175
+
176
+ def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
177
+ x = self.embed(input_ids)
178
+ aux_loss = x.new_zeros(())
179
+ for block in self.blocks:
180
+ x, block_aux = block(x)
181
+ aux_loss = aux_loss + block_aux
182
+ x = self.norm(x)
183
+ logits = self.lm_head(x)
184
+ loss = None
185
+ if labels is not None:
186
+ ce_loss = F.cross_entropy(
187
+ logits[:, :-1, :].reshape(-1, logits.size(-1)),
188
+ labels[:, 1:].contiguous().reshape(-1),
189
+ ignore_index=-100
190
+ )
191
+ loss = ce_loss + (self.config.aux_loss_coef * aux_loss)
192
+ return {"loss": loss, "logits": logits, "aux_loss": aux_loss}
193
+
194
+ def num_parameters(self) -> int:
195
+ return sum(p.numel() for p in self.parameters())
196
+
197
+
198
+ class FineWebDataset(IterableDataset):
199
+ """First N rows of FineWeb-Edu with packing."""
200
+ def __init__(
201
+ self,
202
+ tokenizer,
203
+ max_seq_len: int,
204
+ max_rows: int = 5_000_000,
205
+ split: str = "train",
206
+ text_field: str = "text",
207
+ ):
208
+ super().__init__()
209
+ self.tokenizer = tokenizer
210
+ self.max_seq_len = max_seq_len
211
+ self.max_rows = max_rows
212
+ self.split = split
213
+ self.text_field = text_field
214
+
215
+ def _iter_texts(self):
216
+ ds = load_dataset(
217
+ "HuggingFaceFW/fineweb-edu",
218
+ name="sample-10BT",
219
+ split=self.split,
220
+ streaming=True
221
+ )
222
+ for i, item in enumerate(ds):
223
+ if i >= self.max_rows:
224
+ break
225
+ text = str(item.get(self.text_field, "") or "").strip()
226
+ if text:
227
+ yield text
228
+
229
+ def __iter__(self) -> Iterator[Dict]:
230
+ buffer = []
231
+ eos_id = self.tokenizer.eos_token_id or self.tokenizer.pad_token_id
232
+ for text in self._iter_texts():
233
+ token_ids = self.tokenizer.encode(text, add_special_tokens=False)
234
+ if not token_ids:
235
+ continue
236
+ buffer.extend(token_ids + [eos_id])
237
+ while len(buffer) >= self.max_seq_len + 1:
238
+ window = buffer[:self.max_seq_len + 1]
239
+ buffer = buffer[self.max_seq_len:]
240
+ sample = torch.tensor(window, dtype=torch.long)
241
+ yield {"input_ids": sample[:-1], "labels": sample[:-1].clone()}
242
+
243
+
244
+ def collate_batch(batch):
245
+ return {
246
+ "input_ids": torch.stack([b["input_ids"] for b in batch]),
247
+ "labels": torch.stack([b["labels"] for b in batch]),
248
+ }
249
+
250
+
251
+ def train(args):
252
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
253
+ print(f"Device: {device}")
254
+ if device.type == "cuda":
255
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
256
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
257
+ torch.backends.cuda.matmul.allow_tf32 = True
258
+ torch.backends.cudnn.allow_tf32 = True
259
+ torch.backends.cudnn.benchmark = True
260
+ use_bf16 = bool(getattr(args, "bf16", True)) and device.type == "cuda"
261
+ print(f"bf16: {use_bf16}")
262
+
263
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
264
+ if tokenizer.pad_token is None:
265
+ tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
266
+ tokenizer.model_max_length = int(1e30)
267
+
268
+ config = HSSMV2Config(
269
+ vocab_size=tokenizer.vocab_size,
270
+ d_model=args.d_model,
271
+ n_layers=args.n_layers,
272
+ d_ff=args.d_ff,
273
+ state_rank=args.state_rank,
274
+ chunk_size=args.chunk_size,
275
+ max_seq_len=args.max_seq_len,
276
+ )
277
+
278
+ model = HSSMV2LM(config)
279
+ total_params = model.num_parameters()
280
+ print(f"Total params: {total_params:,} ({total_params/1e6:.2f}M)")
281
+
282
+ # Calculate active params (non-MoE layers + 1 expert per MoE layer)
283
+ active_params = sum(
284
+ p.numel() for name, p in model.named_parameters()
285
+ if "experts" not in name or f".experts." in name
286
+ )
287
+ # Actually active is ~d_model paths
288
+ print(f"Active per forward: ~{active_params/1e6:.2f}M")
289
+
290
+ model = model.to(device)
291
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
292
+ print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
293
+ model = nn.DataParallel(model)
294
+
295
+ dataset = FineWebDataset(
296
+ tokenizer, args.max_seq_len,
297
+ max_rows=args.max_rows,
298
+ split=args.dataset_split
299
+ )
300
+ dataloader_kwargs = {
301
+ "dataset": dataset,
302
+ "batch_size": args.batch_size,
303
+ "num_workers": args.num_workers,
304
+ "collate_fn": collate_batch,
305
+ "drop_last": True,
306
+ "pin_memory": device.type == "cuda",
307
+ }
308
+ if args.num_workers > 0:
309
+ dataloader_kwargs["persistent_workers"] = True
310
+ dataloader_kwargs["prefetch_factor"] = 4
311
+ dataloader = DataLoader(**dataloader_kwargs)
312
+
313
+ optimizer = torch.optim.AdamW(
314
+ model.parameters(), lr=args.lr,
315
+ betas=(0.9, 0.95), weight_decay=args.weight_decay
316
+ )
317
+
318
+ if args.max_steps > 0:
319
+ scheduler = get_cosine_schedule_with_warmup(
320
+ optimizer, num_warmup_steps=args.warmup_steps,
321
+ num_training_steps=args.max_steps
322
+ )
323
+ else:
324
+ scheduler = get_constant_schedule_with_warmup(
325
+ optimizer, num_warmup_steps=args.warmup_steps
326
+ )
327
+
328
+ output_dir = Path(args.output_dir)
329
+ output_dir.mkdir(parents=True, exist_ok=True)
330
+
331
+ model.train()
332
+ step = 0
333
+ start_time = time.time()
334
+ grad_norm = 0.0
335
+ last_aux_loss = 0.0
336
+ optimizer.zero_grad(set_to_none=True)
337
+
338
+ for batch in dataloader:
339
+ input_ids = batch["input_ids"].to(device, non_blocking=True)
340
+ labels = batch["labels"].to(device, non_blocking=True)
341
+ labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)
342
+
343
+ autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16) if use_bf16 else contextlib.nullcontext()
344
+ with autocast_ctx:
345
+ outputs = model(input_ids=input_ids, labels=labels)
346
+ aux_loss_val = outputs.get("aux_loss")
347
+ if aux_loss_val is not None:
348
+ last_aux_loss = float(aux_loss_val.detach().item())
349
+
350
+ loss = outputs["loss"].float() / args.grad_accum_steps
351
+ loss.backward()
352
+
353
+ if (step + 1) % args.grad_accum_steps == 0:
354
+ grad_norm = torch.nn.utils.clip_grad_norm_(
355
+ model.parameters(), args.max_grad_norm
356
+ )
357
+ optimizer.step()
358
+ scheduler.step()
359
+ optimizer.zero_grad(set_to_none=True)
360
+
361
+ step += 1
362
+
363
+ if step % args.log_every == 0:
364
+ elapsed = time.time() - start_time
365
+ tokens = step * args.batch_size * args.max_seq_len
366
+ print(json.dumps({
367
+ "step": step,
368
+ "loss": round(float(loss.item() * args.grad_accum_steps), 5),
369
+ "aux_loss": round(last_aux_loss, 5),
370
+ "lr": scheduler.get_last_lr()[0],
371
+ "tokens": tokens,
372
+ "tokens_per_sec": round(tokens / max(elapsed, 1e-6), 2),
373
+ "grad_norm": round(float(grad_norm), 4) if isinstance(grad_norm, torch.Tensor) else float(grad_norm),
374
+ "gpu_mem_gb": round(torch.cuda.memory_allocated() / 1e9, 2) if device.type == "cuda" else 0
375
+ }))
376
+
377
+ if step % args.save_every == 0:
378
+ checkpoint = {
379
+ "step": step,
380
+ "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
381
+ "optimizer_state_dict": optimizer.state_dict(),
382
+ "scheduler_state_dict": scheduler.state_dict(),
383
+ "config": asdict(config),
384
+ }
385
+ torch.save(checkpoint, output_dir / f"step_{step:07d}.pt")
386
+ torch.save(checkpoint, output_dir / "latest.pt")
387
+
388
+ if args.max_steps > 0 and step >= args.max_steps:
389
+ break
390
+
391
+ # Final save
392
+ final = {
393
+ "step": step,
394
+ "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
395
+ "config": asdict(config),
396
+ "finished_at": time.time()
397
+ }
398
+ torch.save(final, output_dir / "final.pt")
399
+ print(f"Training complete. Final checkpoint: {output_dir / 'final.pt'}")
400
+
401
+
402
+ def parse_args():
403
+ parser = argparse.ArgumentParser()
404
+ parser.add_argument("--dataset-split", default="train")
405
+ parser.add_argument("--text-field", default="text")
406
+ parser.add_argument("--max-rows", type=int, default=5_000_000)
407
+ parser.add_argument("--tokenizer-name", default="gpt2")
408
+ parser.add_argument("--output-dir", default="/content/hssm_v2_runs")
409
+ parser.add_argument("--max-seq-len", type=int, default=1024)
410
+ parser.add_argument("--batch-size", type=int, default=256)
411
+ parser.add_argument("--grad-accum-steps", type=int, default=1)
412
+ parser.add_argument("--max-steps", type=int, default=50_000)
413
+ parser.add_argument("--lr", type=float, default=3e-4)
414
+ parser.add_argument("--weight-decay", type=float, default=0.1)
415
+ parser.add_argument("--warmup-steps", type=int, default=1000)
416
+ parser.add_argument("--max-grad-norm", type=float, default=1.0)
417
+ parser.add_argument("--save-every", type=int, default=5000)
418
+ parser.add_argument("--log-every", type=int, default=10)
419
+ parser.add_argument("--num-workers", type=int, default=8)
420
+ parser.add_argument("--bf16", action="store_true")
421
+ parser.add_argument("--no-bf16", action="store_false", dest="bf16")
422
+ parser.set_defaults(bf16=True)
423
+ parser.add_argument("--d-model", type=int, default=288)
424
+ parser.add_argument("--n-layers", type=int, default=10)
425
+ parser.add_argument("--d-ff", type=int, default=512)
426
+ parser.add_argument("--state-rank", type=int, default=128)
427
+ parser.add_argument("--chunk-size", type=int, default=8)
428
+ parser.add_argument("--num-experts", type=int, default=64)
429
+ parser.add_argument("--experts-per-token", type=int, default=1)
430
+ parser.add_argument("--expert-dim", type=int, default=2048)
431
+ parser.add_argument("--moe-every", type=int, default=4)
432
+ parser.add_argument("--aux-loss-coef", type=float, default=1e-2)
433
+ return parser.parse_args()
434
+
435
+
436
+ if __name__ == "__main__":
437
+ train(parse_args())