"""HSSM v2 GPU pretraining script (Colab A6000 optimized).

Defines a small hierarchical-state-mixer language model with optional sparse
mixture-of-experts feed-forward layers, streams FineWeb-Edu with sequence
packing, and trains it with AdamW (optionally under bf16 autocast).
"""
import argparse
import contextlib
import json
import os
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Iterator, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from datasets import load_dataset


@dataclass
class HSSMV2Config:
    """Hyperparameters for :class:`HSSMV2LM`."""

    vocab_size: int
    d_model: int = 288
    n_layers: int = 10
    d_ff: int = 512  # hidden width of the dense GatedMLP blocks
    state_rank: int = 128  # bottleneck width of the per-chunk state projection
    chunk_size: int = 8  # tokens averaged into one chunk summary by the mixer
    dropout: float = 0.0  # NOTE(review): not read by any module in this file
    max_seq_len: int = 1024  # stored with the config; the model does not read it
    tie_embeddings: bool = True
    num_experts: int = 64
    experts_per_token: int = 1
    expert_dim: int = 2048
    moe_every: int = 4  # every moe_every-th block uses SparseMoE; <= 0 disables MoE
    aux_loss_coef: float = 1e-2  # weight of the summed MoE balance loss in the total loss


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learned scale (no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(norm + self.eps) * self.weight


class HierarchicalStateMixer(nn.Module):
    """Causal token mixer combining a local convolution with chunk summaries.

    The input is projected into ``gate``, ``value`` and ``residual`` streams.
    Each position receives:

    * a depthwise convolution over itself and the preceding
      ``kernel_size - 1`` positions of ``value`` (left padding only), and
    * a low-rank summary of the *previous* ``chunk_size``-token chunk.
      The first chunk receives zeros.

    Both paths only look backwards, so the module does not leak future tokens
    under a next-token-prediction loss. Parameter shapes match the earlier
    (non-causal) version, so its state dicts still load. Weights trained with
    that version relied on future context, however.
    """

    kernel_size = 5

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.d_model = config.d_model
        self.state_rank = config.state_rank
        self.chunk_size = config.chunk_size
        self.in_proj = nn.Linear(config.d_model, config.d_model * 3)
        # padding=0: causal left padding is applied explicitly in forward().
        self.depthwise = nn.Conv1d(
            config.d_model,
            config.d_model,
            kernel_size=self.kernel_size,
            padding=0,
            groups=config.d_model,
        )
        self.chunk_proj = nn.Linear(config.d_model, config.d_model)
        self.state_in = nn.Linear(config.d_model, config.state_rank)
        self.state_out = nn.Linear(config.state_rank, config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix tokens in ``x`` of shape (batch, seq_len, d_model); output has the same shape."""
        gate, value, residual = self.in_proj(x).chunk(3, dim=-1)
        batch, seq_len, dim = value.shape

        # Pad only on the left so output t depends on inputs t-(k-1) .. t.
        conv_in = F.pad(value.transpose(1, 2), (self.kernel_size - 1, 0))
        local = self.depthwise(conv_in).transpose(1, 2)

        # Zero-pad the sequence to a whole number of chunks, then average each chunk.
        pad_len = (-seq_len) % self.chunk_size
        local_padded = F.pad(local, (0, 0, 0, pad_len)) if pad_len else local
        num_chunks = local_padded.size(1) // self.chunk_size
        chunked = local_padded.reshape(batch, num_chunks, self.chunk_size, dim).mean(dim=2)
        states = self.state_out(torch.tanh(self.state_in(self.chunk_proj(chunked))))

        # Shift by one chunk so chunk c only sees the summary of chunk c-1.
        # That summary covers tokens strictly before chunk c.
        states = F.pad(states, (0, 0, 1, 0))[:, :num_chunks]
        expanded = states.repeat_interleave(self.chunk_size, dim=1)[:, :seq_len, :]

        mixed = local + expanded + residual
        return self.out_proj(torch.sigmoid(gate) * mixed)


class GatedMLP(nn.Module):
    """Dense SiLU-gated feed-forward layer: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.up_proj = nn.Linear(config.d_model, config.d_ff)
        self.gate_proj = nn.Linear(config.d_model, config.d_ff)
        self.down_proj = nn.Linear(config.d_ff, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class ExpertMLP(nn.Module):
    """One SiLU-gated expert of width ``expert_dim`` used inside :class:`SparseMoE`."""

    def __init__(self, d_model: int, expert_dim: int):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expert_dim)
        self.gate_proj = nn.Linear(d_model, expert_dim)
        self.down_proj = nn.Linear(expert_dim, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class SparseMoE(nn.Module):
    """Top-k routed mixture of :class:`ExpertMLP` with a load-balancing loss.

    Returns ``(output, aux_loss)``. ``aux_loss = num_experts * sum_i(P_i * f_i)``,
    where ``P_i`` is the mean router probability of expert ``i`` and ``f_i`` is
    the fraction of tokens routed to it.
    """

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        self.router = nn.Linear(config.d_model, config.num_experts, bias=False)
        self.experts = nn.ModuleList([
            ExpertMLP(config.d_model, config.expert_dim) for _ in range(config.num_experts)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)
        num_tokens = x_flat.size(0)
        k = self.experts_per_token

        router_probs = F.softmax(self.router(x_flat), dim=-1)
        topk_weights, topk_indices = torch.topk(router_probs, k=k, dim=-1)
        if k > 1:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Flatten the (token, slot) assignments and group them by expert with a
        # single sort. The per-expert counts need one host sync instead of one
        # sync per expert.
        flat_expert_ids = topk_indices.reshape(-1)
        flat_weights = topk_weights.reshape(-1)
        flat_token_ids = torch.arange(num_tokens, device=x.device).repeat_interleave(k)
        order = torch.argsort(flat_expert_ids)
        sorted_token_ids = flat_token_ids[order]
        sorted_weights = flat_weights[order]
        tokens_per_expert = torch.bincount(flat_expert_ids, minlength=self.num_experts)

        output = torch.zeros_like(x_flat)
        start = 0
        for expert, count in zip(self.experts, tokens_per_expert.tolist()):
            if count == 0:
                continue
            end = start + count
            token_ids = sorted_token_ids[start:end]
            expert_output = expert(x_flat.index_select(0, token_ids))
            expert_weight = sorted_weights[start:end].unsqueeze(-1)
            output.index_add_(0, token_ids, (expert_output * expert_weight).to(output.dtype))
            start = end

        importance = router_probs.mean(dim=0)
        # Top-k indices are distinct per token, so count / num_tokens is the
        # fraction of tokens that picked each expert.
        load = tokens_per_expert.float() / max(num_tokens, 1)
        aux_loss = self.num_experts * torch.sum(importance * load)
        return output.view(batch, seq_len, d_model), aux_loss


class HSSMV2Block(nn.Module):
    """Pre-norm residual block: mixer, then a dense MLP or a SparseMoE."""

    def __init__(self, config: HSSMV2Config, use_moe: bool = False):
        super().__init__()
        self.norm1 = RMSNorm(config.d_model)
        self.mixer = HierarchicalStateMixer(config)
        self.norm2 = RMSNorm(config.d_model)
        self.use_moe = use_moe
        self.ff = SparseMoE(config) if use_moe else GatedMLP(config)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(hidden_states, aux_loss)``; aux_loss is a zero scalar for dense blocks."""
        x = x + self.mixer(self.norm1(x))
        if self.use_moe:
            ff_out, aux_loss = self.ff(self.norm2(x))
            return x + ff_out, aux_loss
        return x + self.ff(self.norm2(x)), x.new_zeros(())


class HSSMV2LM(nn.Module):
    """Causal language model built from :class:`HSSMV2Block` layers."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([
            HSSMV2Block(
                config,
                use_moe=config.moe_every > 0 and (layer_idx + 1) % config.moe_every == 0,
            )
            for layer_idx in range(config.n_layers)
        ])
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Compute logits and, when ``labels`` is given, the training loss.

        ``labels`` are aligned with ``input_ids``; the shift is done here, so
        position t predicts ``labels[:, t + 1]``. Entries equal to -100 are
        ignored. Returns a dict with ``loss`` (None without labels),
        ``logits`` and ``aux_loss``. ``aux_loss`` is the MoE balance losses
        summed over blocks; ``loss`` adds it scaled by ``aux_loss_coef``.
        """
        x = self.embed(input_ids)
        aux_loss = x.new_zeros(())
        for block in self.blocks:
            x, block_aux = block(x)
            aux_loss = aux_loss + block_aux
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift the labels left and mark the last position as ignored.
            # Slicing the logits instead would copy the (batch, seq, vocab) tensor.
            shift_labels = F.pad(labels[:, 1:], (0, 1), value=-100)
            ce_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
            loss = ce_loss + (self.config.aux_loss_coef * aux_loss)
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss}

    def num_parameters(self) -> int:
        """Total (deduplicated) parameter count, so tied embeddings count once."""
        return sum(p.numel() for p in self.parameters())


class FineWebDataset(IterableDataset):
    """First N rows of FineWeb-Edu, tokenized and packed into fixed windows.

    Documents are separated by the tokenizer's EOS token (PAD if there is no
    EOS) and concatenated. The stream is then cut into windows of exactly
    ``max_seq_len`` tokens. Consecutive windows overlap by one token.
    Because the model shifts labels internally, the first position of a
    window is never a prediction target; repeating the previous window's last
    token there means every token after the first is predicted once.
    A trailing partial window is dropped.
    """

    def __init__(
        self,
        tokenizer,
        max_seq_len: int,
        max_rows: int = 5_000_000,
        split: str = "train",
        text_field: str = "text",
    ):
        super().__init__()
        if max_seq_len < 2:
            raise ValueError(f"max_seq_len must be at least 2, got {max_seq_len}")
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_rows = max_rows
        self.split = split
        self.text_field = text_field

    def _iter_texts(self) -> Iterator[str]:
        """Yield stripped, non-empty texts from the first ``max_rows`` rows of the stream.

        NOTE(review): with DataLoader ``num_workers > 0`` every worker calls
        this and opens its own stream. Depending on the installed ``datasets``
        version, that stream may be split across workers by file shard, or
        read in full by each worker (duplicating data). Either way,
        ``max_rows`` is counted per worker. Confirm against the installed
        version.
        """
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name="sample-10BT",
            split=self.split,
            streaming=True,
        )
        for i, item in enumerate(ds):
            if i >= self.max_rows:
                break
            text = str(item.get(self.text_field, "") or "").strip()
            if text:
                yield text

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Compare with None explicitly: a valid token id may be 0.
        eos_id = self.tokenizer.eos_token_id
        if eos_id is None:
            eos_id = self.tokenizer.pad_token_id
        if eos_id is None:
            raise ValueError("tokenizer has neither an EOS nor a PAD token to separate documents")

        window_len = self.max_seq_len
        stride = window_len - 1  # one-token overlap between consecutive windows
        buffer = []
        for text in self._iter_texts():
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            buffer.extend(token_ids)
            buffer.append(eos_id)

            start = 0
            while len(buffer) - start >= window_len:
                sample = torch.tensor(buffer[start:start + window_len], dtype=torch.long)
                start += stride
                yield {"input_ids": sample, "labels": sample.clone()}
            # Drop consumed tokens once per document instead of re-slicing per window.
            del buffer[:start]


def collate_batch(batch):
    """Stack per-sample ``input_ids`` and ``labels`` tensors into batch tensors."""
    return {
        "input_ids": torch.stack([b["input_ids"] for b in batch]),
        "labels": torch.stack([b["labels"] for b in batch]),
    }


def _unwrap(model: nn.Module) -> nn.Module:
    """Return the wrapped module when ``model`` is a DataParallel wrapper."""
    return model.module if hasattr(model, "module") else model


def _count_active_parameters(model: nn.Module, config: HSSMV2Config) -> int:
    """Estimate the parameters used per token.

    Each MoE layer runs ``experts_per_token`` of its ``num_experts`` experts
    per token, so only that fraction of expert weights counts as active.
    Routers and all non-expert weights are always active.
    """
    expert_params = sum(
        p.numel() for name, p in model.named_parameters() if ".experts." in name
    )
    total = sum(p.numel() for p in model.parameters())
    active_expert_params = expert_params * config.experts_per_token // max(config.num_experts, 1)
    return total - expert_params + active_expert_params


def train(args):
    """Pretrain :class:`HSSMV2LM` on FineWeb-Edu as configured by the parsed CLI ``args``.

    ``--max-steps`` counts optimizer steps. Each step consumes
    ``grad_accum_steps`` micro-batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    use_bf16 = bool(getattr(args, "bf16", True)) and device.type == "cuda"
    print(f"bf16: {use_bf16}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    tokenizer.model_max_length = int(1e30)

    config = HSSMV2Config(
        # len() also counts added special tokens, which vocab_size omits.
        vocab_size=len(tokenizer),
        d_model=args.d_model,
        n_layers=args.n_layers,
        d_ff=args.d_ff,
        state_rank=args.state_rank,
        chunk_size=args.chunk_size,
        max_seq_len=args.max_seq_len,
        num_experts=args.num_experts,
        experts_per_token=args.experts_per_token,
        expert_dim=args.expert_dim,
        moe_every=args.moe_every,
        aux_loss_coef=args.aux_loss_coef,
    )
    model = HSSMV2LM(config)
    total_params = model.num_parameters()
    print(f"Total params: {total_params:,} ({total_params/1e6:.2f}M)")
    active_params = _count_active_parameters(model, config)
    print(f"Active per forward: ~{active_params/1e6:.2f}M")

    model = model.to(device)
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
        model = nn.DataParallel(model)

    dataset = FineWebDataset(
        tokenizer,
        args.max_seq_len,
        max_rows=args.max_rows,
        split=args.dataset_split,
        text_field=args.text_field,
    )
    dataloader_kwargs = {
        "dataset": dataset,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "collate_fn": collate_batch,
        "drop_last": True,
        "pin_memory": device.type == "cuda",
    }
    if args.num_workers > 0:
        dataloader_kwargs["persistent_workers"] = True
        dataloader_kwargs["prefetch_factor"] = 4
    dataloader = DataLoader(**dataloader_kwargs)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay
    )
    if args.max_steps > 0:
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps
        )
    else:
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Packed samples contain no padding. Mask the pad id only when it is a
    # distinct token. When pad == eos (e.g. GPT-2), masking it would stop the
    # model from ever learning to predict end-of-document.
    pad_id = tokenizer.pad_token_id
    mask_pad = pad_id is not None and pad_id != tokenizer.eos_token_id
    grad_accum = max(1, args.grad_accum_steps)

    model.train()
    step = 0  # optimizer steps
    micro_step = 0  # forward/backward passes
    start_time = time.time()
    grad_norm = 0.0
    # Kept as device tensors so .item() host syncs happen only when logging.
    accum_loss = torch.zeros((), device=device)
    last_aux_loss = torch.zeros((), device=device)
    optimizer.zero_grad(set_to_none=True)

    for batch in dataloader:
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)
        if mask_pad:
            labels = labels.masked_fill(labels == pad_id, -100)

        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if use_bf16
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            outputs = model(input_ids=input_ids, labels=labels)

        # DataParallel gathers each replica's scalar loss into a vector;
        # reduce to a scalar before backward and logging.
        loss = outputs["loss"].float().mean()
        aux_loss_val = outputs.get("aux_loss")
        if aux_loss_val is not None:
            last_aux_loss = aux_loss_val.detach().float().mean()
        (loss / grad_accum).backward()
        accum_loss += loss.detach()
        micro_step += 1
        if micro_step % grad_accum != 0:
            continue

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        step += 1
        window_loss = accum_loss / grad_accum  # mean loss over this step's micro-batches
        accum_loss = torch.zeros((), device=device)

        if args.log_every > 0 and step % args.log_every == 0:
            elapsed = time.time() - start_time
            tokens = step * grad_accum * args.batch_size * args.max_seq_len
            print(json.dumps({
                "step": step,
                "loss": round(float(window_loss.item()), 5),
                "aux_loss": round(float(last_aux_loss.item()), 5),
                "lr": scheduler.get_last_lr()[0],
                "tokens": tokens,
                "tokens_per_sec": round(tokens / max(elapsed, 1e-6), 2),
                "grad_norm": round(float(grad_norm), 4),
                "gpu_mem_gb": round(torch.cuda.memory_allocated() / 1e9, 2) if device.type == "cuda" else 0,
            }))

        if args.save_every > 0 and step % args.save_every == 0:
            checkpoint = {
                "step": step,
                "model_state_dict": _unwrap(model).state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "config": asdict(config),
            }
            torch.save(checkpoint, output_dir / f"step_{step:07d}.pt")
            torch.save(checkpoint, output_dir / "latest.pt")

        if args.max_steps > 0 and step >= args.max_steps:
            break

    # Final save
    final = {
        "step": step,
        "model_state_dict": _unwrap(model).state_dict(),
        "config": asdict(config),
        "finished_at": time.time(),
    }
    torch.save(final, output_dir / "final.pt")
    print(f"Training complete. Final checkpoint: {output_dir / 'final.pt'}")


def parse_args():
    """Parse command-line options for :func:`train`."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-split", default="train")
    parser.add_argument("--text-field", default="text")
    parser.add_argument("--max-rows", type=int, default=5_000_000)
    parser.add_argument("--tokenizer-name", default="gpt2")
    parser.add_argument("--output-dir", default="/content/hssm_v2_runs")
    parser.add_argument("--max-seq-len", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--max-steps", type=int, default=50_000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight-decay", type=float, default=0.1)
    parser.add_argument("--warmup-steps", type=int, default=1000)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--save-every", type=int, default=5000)
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--no-bf16", action="store_false", dest="bf16")
    parser.set_defaults(bf16=True)
    parser.add_argument("--d-model", type=int, default=288)
    parser.add_argument("--n-layers", type=int, default=10)
    parser.add_argument("--d-ff", type=int, default=512)
    parser.add_argument("--state-rank", type=int, default=128)
    parser.add_argument("--chunk-size", type=int, default=8)
    parser.add_argument("--num-experts", type=int, default=64)
    parser.add_argument("--experts-per-token", type=int, default=1)
    parser.add_argument("--expert-dim", type=int, default=2048)
    parser.add_argument("--moe-every", type=int, default=4)
    parser.add_argument("--aux-loss-coef", type=float, default=1e-2)
    return parser.parse_args()


if __name__ == "__main__":
    train(parse_args())