| """HSSM v2 GPU Pretraining - Colab A6000 optimized""" |
| import argparse |
| import contextlib |
| import json |
| import os |
| import time |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| from typing import Dict, Iterator, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, IterableDataset, get_worker_info |
| from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup |
| from datasets import load_dataset |
|
|
@dataclass
class HSSMV2Config:
    """Hyperparameters for the HSSM v2 language model."""
    vocab_size: int                # set from the tokenizer at construction time
    d_model: int = 288             # residual-stream width
    n_layers: int = 10             # number of HSSMV2Block layers
    d_ff: int = 512                # hidden width of the dense GatedMLP blocks
    state_rank: int = 128          # low-rank bottleneck of the chunk-state path
    chunk_size: int = 8            # tokens pooled per chunk in the state mixer
    dropout: float = 0.0           # NOTE(review): declared but not read by any module in this file
    max_seq_len: int = 1024        # training window length
    tie_embeddings: bool = True    # share embedding matrix with the LM head
    num_experts: int = 64          # MoE bank size in each sparse layer
    experts_per_token: int = 1     # top-k routing fan-out
    expert_dim: int = 2048         # hidden width of each expert MLP
    moe_every: int = 4             # every Nth block uses SparseMoE instead of GatedMLP
    aux_loss_coef: float = 1e-2    # weight of the router load-balancing loss
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Per-feature scale, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
|
|
|
|
class HierarchicalStateMixer(nn.Module):
    """Token mixer: a depthwise local convolution combined with a coarse,
    chunk-pooled state path squeezed through a low-rank bottleneck."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.d_model = config.d_model
        self.state_rank = config.state_rank
        self.chunk_size = config.chunk_size
        # One projection produces the gate / value / residual streams.
        self.in_proj = nn.Linear(config.d_model, config.d_model * 3)
        # Depthwise conv mixes a 5-token neighborhood per channel.
        self.depthwise = nn.Conv1d(
            config.d_model, config.d_model,
            kernel_size=5, padding=2, groups=config.d_model
        )
        self.chunk_proj = nn.Linear(config.d_model, config.d_model)
        self.state_in = nn.Linear(config.d_model, config.state_rank)
        self.state_out = nn.Linear(config.state_rank, config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value, residual = self.in_proj(x).chunk(3, dim=-1)
        # Conv1d expects (batch, channels, seq).
        local = self.depthwise(value.transpose(1, 2)).transpose(1, 2)

        batch, seq_len, dim = local.shape
        # Right-pad so the sequence divides evenly into fixed-size chunks.
        remainder = seq_len % self.chunk_size
        if remainder:
            padded = F.pad(local, (0, 0, 0, self.chunk_size - remainder))
        else:
            padded = local
        n_chunks = padded.size(1) // self.chunk_size
        # Mean-pool each chunk, then pass through the low-rank state path.
        pooled = padded.view(batch, n_chunks, self.chunk_size, dim).mean(dim=2)
        coarse = torch.tanh(self.state_in(self.chunk_proj(pooled)))
        coarse = self.state_out(coarse)
        # Broadcast each chunk state back over its tokens, trimming the pad.
        broadcast = coarse.repeat_interleave(self.chunk_size, dim=1)[:, :seq_len, :]

        mixed = local + broadcast + residual
        return self.out_proj(torch.sigmoid(gate) * mixed)
|
|
|
|
class GatedMLP(nn.Module):
    """SwiGLU-style feed-forward: silu(gate) elementwise-scales the up path."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.up_proj = nn.Linear(config.d_model, config.d_ff)
        self.gate_proj = nn.Linear(config.d_model, config.d_ff)
        self.down_proj = nn.Linear(config.d_ff, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class ExpertMLP(nn.Module):
    """A single MoE expert: gated MLP with its own hidden width."""

    def __init__(self, d_model: int, expert_dim: int):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expert_dim)
        self.gate_proj = nn.Linear(d_model, expert_dim)
        self.down_proj = nn.Linear(expert_dim, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activation = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(activation)
|
|
|
|
class SparseMoE(nn.Module):
    """Top-k token routing over a bank of expert MLPs, plus a Switch-style
    load-balancing auxiliary loss."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        # Bias-free linear router: one logit per expert.
        self.router = nn.Linear(config.d_model, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            ExpertMLP(config.d_model, config.expert_dim)
            for _ in range(config.num_experts)
        )

    def forward(self, x: torch.Tensor):
        batch, seq_len, d_model = x.shape
        tokens = x.reshape(-1, d_model)
        probs = F.softmax(self.router(tokens), dim=-1)
        weights, indices = torch.topk(probs, k=self.experts_per_token, dim=-1)
        # Renormalize the selected gates to sum to one (only matters for k > 1).
        if self.experts_per_token > 1:
            weights = weights / weights.sum(dim=-1, keepdim=True)

        combined = torch.zeros_like(tokens)
        per_expert_load = []
        for idx, expert in enumerate(self.experts):
            hit_mask = indices == idx
            # Fraction of tokens routed to this expert in any top-k slot.
            per_expert_load.append(hit_mask.any(dim=-1).float().mean())
            if not hit_mask.any():
                continue
            rows, slots = torch.where(hit_mask)
            routed = expert(tokens.index_select(0, rows))
            gates = weights[rows, slots].unsqueeze(-1)
            combined.index_add_(0, rows, routed * gates)

        # Balance importance (mean router prob) against realized load,
        # scaled by the expert count.
        importance = probs.mean(dim=0)
        load = torch.stack(per_expert_load)
        aux_loss = self.num_experts * torch.sum(importance * load)
        return combined.view(batch, seq_len, d_model), aux_loss
|
|
|
|
class HSSMV2Block(nn.Module):
    """One pre-norm residual block: hierarchical state mixer followed by a
    feed-forward stage (SparseMoE when ``use_moe`` else dense GatedMLP)."""
    def __init__(self, config: HSSMV2Config, use_moe: bool = False):
        super().__init__()
        self.norm1 = RMSNorm(config.d_model)
        self.mixer = HierarchicalStateMixer(config)
        self.norm2 = RMSNorm(config.d_model)
        self.use_moe = use_moe
        self.ff = SparseMoE(config) if use_moe else GatedMLP(config)


    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        # Returns (hidden_states, aux_loss); the original annotation claimed a
        # bare Tensor but both branches return a 2-tuple.
        x = x + self.mixer(self.norm1(x))
        if self.use_moe:
            ff_out, aux_loss = self.ff(self.norm2(x))
            x = x + ff_out
            return x, aux_loss
        # Dense path has no routing loss: return a zero scalar so callers can
        # accumulate aux losses uniformly across blocks.
        return x + self.ff(self.norm2(x)), x.new_zeros(())
|
|
|
|
class HSSMV2LM(nn.Module):
    """HSSM v2 causal LM: embedding -> stacked blocks -> (tied) LM head."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        # Every `moe_every`-th block (1-indexed) uses the sparse MoE stage.
        self.blocks = nn.ModuleList(
            HSSMV2Block(config, use_moe=((i + 1) % config.moe_every == 0))
            for i in range(config.n_layers)
        )
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            # Weight tying: output projection shares the embedding matrix.
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Return a dict with ``logits``, summed MoE ``aux_loss``, and (when
        ``labels`` is given) ``loss`` = CE + aux_loss_coef * aux_loss."""
        hidden = self.embed(input_ids)
        aux_loss = hidden.new_zeros(())
        for block in self.blocks:
            hidden, block_aux = block(hidden)
            aux_loss = aux_loss + block_aux
        logits = self.lm_head(self.norm(hidden))

        loss = None
        if labels is not None:
            # Internal shift: position t is scored against labels[t + 1], so
            # callers pass labels identical to input_ids.
            flat_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            flat_labels = labels[:, 1:].contiguous().reshape(-1)
            ce_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)
            loss = ce_loss + (self.config.aux_loss_coef * aux_loss)
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss}

    def num_parameters(self) -> int:
        """Total parameter count; tied weights are shared Parameter objects and
        ``parameters()`` yields each shared tensor once."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
class FineWebDataset(IterableDataset):
    """First N rows of FineWeb-Edu, tokenized and packed into fixed windows.

    Streams the dataset, appends EOS after each document, and yields
    non-overlapping ``max_seq_len`` windows. The row stream is sharded across
    DataLoader workers so multi-worker loading does not duplicate samples
    (previously every worker replayed the identical stream).
    """
    def __init__(
        self,
        tokenizer,
        max_seq_len: int,
        max_rows: int = 5_000_000,
        split: str = "train",
        text_field: str = "text",
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_rows = max_rows
        self.split = split
        self.text_field = text_field


    def _iter_texts(self):
        """Yield non-empty document texts, sharded by DataLoader worker."""
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name="sample-10BT",
            split=self.split,
            streaming=True
        )
        # Shard by global row index: each worker keeps rows where
        # i % num_workers == worker_id. In the main process (no workers)
        # everything passes through unchanged.
        worker = get_worker_info()
        num_shards = worker.num_workers if worker is not None else 1
        shard_id = worker.id if worker is not None else 0
        for i, item in enumerate(ds):
            if i >= self.max_rows:
                break
            if i % num_shards != shard_id:
                continue
            text = str(item.get(self.text_field, "") or "").strip()
            if text:
                yield text


    def __iter__(self) -> Iterator[Dict]:
        """Pack the token stream into (max_seq_len + 1)-token windows.

        ``input_ids`` and ``labels`` are identical; the model shifts labels
        internally (position t is scored against token t + 1).
        """
        buffer = []
        eos_id = self.tokenizer.eos_token_id or self.tokenizer.pad_token_id
        for text in self._iter_texts():
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            buffer.extend(token_ids + [eos_id])
            while len(buffer) >= self.max_seq_len + 1:
                window = buffer[:self.max_seq_len + 1]
                # Advance by max_seq_len: the window's last token is retained
                # as the first token of the next window.
                buffer = buffer[self.max_seq_len:]
                sample = torch.tensor(window, dtype=torch.long)
                yield {"input_ids": sample[:-1], "labels": sample[:-1].clone()}
|
|
|
|
def collate_batch(batch):
    """Stack per-sample tensors into batched ``input_ids`` / ``labels``."""
    collated = {}
    for key in ("input_ids", "labels"):
        collated[key] = torch.stack([sample[key] for sample in batch])
    return collated
|
|
|
|
def train(args):
    """Pretrain HSSMV2LM on streamed FineWeb-Edu.

    ``args`` is the namespace produced by ``parse_args``. Side effects:
    stdout JSON logging and checkpoints under ``args.output_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        # TF32 matmul/conv are a free speedup on Ampere-class GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    use_bf16 = bool(getattr(args, "bf16", True)) and device.type == "cuda"
    print(f"bf16: {use_bf16}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    tokenizer.model_max_length = int(1e30)  # we pack manually; silence length warnings

    config = HSSMV2Config(
        # len(tokenizer) includes added special tokens; tokenizer.vocab_size
        # does not and can undersize the embedding table.
        vocab_size=len(tokenizer),
        d_model=args.d_model,
        n_layers=args.n_layers,
        d_ff=args.d_ff,
        state_rank=args.state_rank,
        chunk_size=args.chunk_size,
        max_seq_len=args.max_seq_len,
        # These CLI flags were previously parsed but never forwarded, so the
        # dataclass defaults always won.
        num_experts=args.num_experts,
        experts_per_token=args.experts_per_token,
        expert_dim=args.expert_dim,
        moe_every=args.moe_every,
        aux_loss_coef=args.aux_loss_coef,
    )

    model = HSSMV2LM(config)
    total_params = model.num_parameters()
    print(f"Total params: {total_params:,} ({total_params/1e6:.2f}M)")

    # Active-per-forward estimate: dense params plus the routed fraction of the
    # expert banks. (The old filter `"experts" not in name or ".experts." in
    # name` was a tautology and counted every parameter.)
    expert_params = sum(
        p.numel() for name, p in model.named_parameters() if ".experts." in name
    )
    active_params = (total_params - expert_params) + (
        expert_params * config.experts_per_token // config.num_experts
    )
    print(f"Active per forward: ~{active_params/1e6:.2f}M")

    model = model.to(device)
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
        model = nn.DataParallel(model)

    dataset = FineWebDataset(
        tokenizer, args.max_seq_len,
        max_rows=args.max_rows,
        split=args.dataset_split,
        text_field=args.text_field,  # was parsed but silently ignored before
    )
    dataloader_kwargs = {
        "dataset": dataset,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "collate_fn": collate_batch,
        "drop_last": True,
        "pin_memory": device.type == "cuda",
    }
    if args.num_workers > 0:
        dataloader_kwargs["persistent_workers"] = True
        dataloader_kwargs["prefetch_factor"] = 4
    dataloader = DataLoader(**dataloader_kwargs)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr,
        betas=(0.9, 0.95), weight_decay=args.weight_decay
    )

    # The scheduler advances once per *optimizer* step (every grad_accum_steps
    # micro-batches), so its budget must be in optimizer steps, not micro-steps
    # (otherwise the cosine decay never completes when grad_accum_steps > 1).
    if args.max_steps > 0:
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps,
            num_training_steps=max(args.max_steps // args.grad_accum_steps, 1)
        )
    else:
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Packed windows contain no padding, and GPT-2-style tokenizers reuse EOS
    # as pad; masking pad ids there would erase every EOS target from the
    # loss. Only mask when pad is genuinely distinct from EOS.
    mask_pad_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id != tokenizer.eos_token_id else None
    )

    model.train()
    step = 0
    start_time = time.time()
    grad_norm = 0.0
    last_aux_loss = 0.0
    optimizer.zero_grad(set_to_none=True)

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)
        if mask_pad_id is not None:
            labels = labels.masked_fill(labels == mask_pad_id, -100)

        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if use_bf16 else contextlib.nullcontext()
        )
        with autocast_ctx:
            outputs = model(input_ids=input_ids, labels=labels)
        aux_loss_val = outputs.get("aux_loss")
        if aux_loss_val is not None:
            # .mean() keeps this valid under DataParallel, where per-replica
            # scalars are gathered into a vector.
            last_aux_loss = float(aux_loss_val.detach().mean().item())

        # .mean() for the DataParallel case (no-op on a scalar); scale for
        # gradient accumulation.
        loss = outputs["loss"].float().mean() / args.grad_accum_steps
        loss.backward()

        if (step + 1) % args.grad_accum_steps == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), args.max_grad_norm
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        step += 1

        if step % args.log_every == 0:
            elapsed = time.time() - start_time
            tokens = step * args.batch_size * args.max_seq_len
            print(json.dumps({
                "step": step,
                "loss": round(float(loss.item() * args.grad_accum_steps), 5),
                "aux_loss": round(last_aux_loss, 5),
                "lr": scheduler.get_last_lr()[0],
                "tokens": tokens,
                "tokens_per_sec": round(tokens / max(elapsed, 1e-6), 2),
                "grad_norm": round(float(grad_norm), 4) if isinstance(grad_norm, torch.Tensor) else float(grad_norm),
                "gpu_mem_gb": round(torch.cuda.memory_allocated() / 1e9, 2) if device.type == "cuda" else 0
            }))

        if step % args.save_every == 0:
            # Unwrap DataParallel so checkpoints load into a bare model.
            state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
            checkpoint = {
                "step": step,
                "model_state_dict": state,
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "config": asdict(config),
            }
            torch.save(checkpoint, output_dir / f"step_{step:07d}.pt")
            torch.save(checkpoint, output_dir / "latest.pt")

        # max_steps is counted in micro-batches, matching step/logging cadence.
        if args.max_steps > 0 and step >= args.max_steps:
            break

    final = {
        "step": step,
        "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "config": asdict(config),
        "finished_at": time.time()
    }
    torch.save(final, output_dir / "final.pt")
    print(f"Training complete. Final checkpoint: {output_dir / 'final.pt'}")
|
|
|
|
def parse_args(argv=None):
    """Build and parse the training CLI.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (the ``argparse`` default), so existing callers are unaffected;
            passing an explicit list makes the function testable/embeddable.

    Returns:
        argparse.Namespace with all training hyperparameters.
    """
    parser = argparse.ArgumentParser()
    # Data / output
    parser.add_argument("--dataset-split", default="train")
    parser.add_argument("--text-field", default="text")
    parser.add_argument("--max-rows", type=int, default=5_000_000)
    parser.add_argument("--tokenizer-name", default="gpt2")
    parser.add_argument("--output-dir", default="/content/hssm_v2_runs")
    # Optimization
    parser.add_argument("--max-seq-len", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--max-steps", type=int, default=50_000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight-decay", type=float, default=0.1)
    parser.add_argument("--warmup-steps", type=int, default=1000)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--save-every", type=int, default=5000)
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--num-workers", type=int, default=8)
    # Precision: bf16 defaults to on; --no-bf16 disables it.
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--no-bf16", action="store_false", dest="bf16")
    parser.set_defaults(bf16=True)
    # Model architecture
    parser.add_argument("--d-model", type=int, default=288)
    parser.add_argument("--n-layers", type=int, default=10)
    parser.add_argument("--d-ff", type=int, default=512)
    parser.add_argument("--state-rank", type=int, default=128)
    parser.add_argument("--chunk-size", type=int, default=8)
    parser.add_argument("--num-experts", type=int, default=64)
    parser.add_argument("--experts-per-token", type=int, default=1)
    parser.add_argument("--expert-dim", type=int, default=2048)
    parser.add_argument("--moe-every", type=int, default=4)
    parser.add_argument("--aux-loss-coef", type=float, default=1e-2)
    return parser.parse_args(argv)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse flags and launch the training loop.
    train(parse_args())
|
|