# Source: HSSM-v2-250M / hssm_v2_gpu_pretrain.py
# Uploaded with huggingface_hub (revision c610c1d, verified).
"""HSSM v2 GPU Pretraining - Colab A6000 optimized"""
import argparse
import contextlib
import json
import os
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Iterator, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from datasets import load_dataset
@dataclass
class HSSMV2Config:
    """Hyperparameters for the HSSM v2 language model and its MoE layers."""
    vocab_size: int  # tokenizer vocabulary size (required, no default)
    d_model: int = 288  # residual-stream width
    n_layers: int = 10  # number of HSSMV2Block layers
    d_ff: int = 512  # hidden width of the dense GatedMLP feed-forward
    state_rank: int = 128  # low-rank bottleneck of the chunk-level state path
    chunk_size: int = 8  # tokens per chunk in HierarchicalStateMixer
    dropout: float = 0.0  # declared but not read by any module in this file
    max_seq_len: int = 1024  # training sequence length
    tie_embeddings: bool = True  # share lm_head weight with the token embedding
    num_experts: int = 64  # experts per SparseMoE layer
    experts_per_token: int = 1  # top-k routing fan-out per token
    expert_dim: int = 2048  # hidden width of each ExpertMLP
    moe_every: int = 4  # every moe_every-th block uses SparseMoE instead of GatedMLP
    aux_loss_coef: float = 1e-2  # weight of the MoE load-balancing auxiliary loss
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias term."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale each vector by the reciprocal RMS over the last dimension,
        # then apply the learned per-channel gain.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight
class HierarchicalStateMixer(nn.Module):
    """Token mixer: depthwise conv (local) plus a low-rank chunk-level state.

    NOTE(review): neither the conv (symmetric padding=2) nor the per-chunk
    mean is causal — each position can see a few tokens to its right.
    Confirm this is intended for a next-token-prediction objective.
    """

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.d_model = config.d_model
        self.state_rank = config.state_rank
        self.chunk_size = config.chunk_size
        # One fused projection producing gate / value / residual streams.
        self.in_proj = nn.Linear(config.d_model, config.d_model * 3)
        # Depthwise 1-D conv over the sequence axis; kernel 5, same-length output.
        self.depthwise = nn.Conv1d(
            config.d_model, config.d_model,
            kernel_size=5, padding=2, groups=config.d_model
        )
        self.chunk_proj = nn.Linear(config.d_model, config.d_model)
        self.state_in = nn.Linear(config.d_model, config.state_rank)
        self.state_out = nn.Linear(config.state_rank, config.d_model)
        self.out_proj = nn.Linear(config.d_model, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value, residual = self.in_proj(x).chunk(3, dim=-1)
        # Local mixing; Conv1d expects (batch, channels, seq).
        local = self.depthwise(value.transpose(1, 2)).transpose(1, 2)
        batch, seq_len, dim = local.shape
        # Right-pad so the sequence divides evenly into fixed-size chunks.
        remainder = seq_len % self.chunk_size
        if remainder:
            padded = F.pad(local, (0, 0, 0, self.chunk_size - remainder))
        else:
            padded = local
        n_chunks = padded.size(1) // self.chunk_size
        # Chunk summary: mean-pool each chunk, project, squash through a
        # low-rank tanh bottleneck, and project back to model width.
        summary = padded.view(batch, n_chunks, self.chunk_size, dim).mean(dim=2)
        summary = self.chunk_proj(summary)
        state = self.state_out(torch.tanh(self.state_in(summary)))
        # Broadcast each chunk's state over its tokens and trim the padding.
        broadcast = state.repeat_interleave(self.chunk_size, dim=1)[:, :seq_len, :]
        combined = local + broadcast + residual
        return self.out_proj(torch.sigmoid(gate) * combined)
class GatedMLP(nn.Module):
    """SwiGLU-style feed-forward: silu(gate) * up, projected back to d_model."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.up_proj = nn.Linear(config.d_model, config.d_ff)
        self.gate_proj = nn.Linear(config.d_model, config.d_ff)
        self.down_proj = nn.Linear(config.d_ff, config.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class ExpertMLP(nn.Module):
    """A single MoE expert: same SwiGLU shape as GatedMLP, with its own width."""

    def __init__(self, d_model: int, expert_dim: int):
        super().__init__()
        self.up_proj = nn.Linear(d_model, expert_dim)
        self.gate_proj = nn.Linear(d_model, expert_dim)
        self.down_proj = nn.Linear(expert_dim, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(hidden)
class SparseMoE(nn.Module):
    """Top-k routed mixture-of-experts feed-forward layer.

    Each token is processed by the ``experts_per_token`` experts with the
    highest router probability; their outputs are combined with the
    (renormalized) router weights.  forward() returns ``(output, aux_loss)``
    where ``aux_loss`` is an importance-times-load balancing term.
    """
    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        # Bias-free linear router over the model dimension.
        self.router = nn.Linear(config.d_model, config.num_experts, bias=False)
        self.experts = nn.ModuleList([
            ExpertMLP(config.d_model, config.expert_dim) for _ in range(config.num_experts)
        ])
    def forward(self, x: torch.Tensor):
        # Route per token: flatten (batch, seq) into one token axis.
        batch, seq_len, d_model = x.shape
        x_flat = x.reshape(-1, d_model)
        router_logits = self.router(x_flat)
        router_probs = F.softmax(router_logits, dim=-1)
        topk_weights, topk_indices = torch.topk(router_probs, k=self.experts_per_token, dim=-1)
        if self.experts_per_token > 1:
            # Renormalize so the selected experts' weights sum to 1 per token.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        output = torch.zeros_like(x_flat)
        expert_load = []
        for expert_id, expert in enumerate(self.experts):
            token_mask = topk_indices == expert_id
            # Fraction of tokens routed to this expert (non-differentiable).
            expert_load.append(token_mask.any(dim=-1).float().mean())
            if not token_mask.any():
                continue
            # Gather this expert's tokens, run them through the expert, and
            # scatter-add the router-weighted results into the flat output.
            token_positions, slot_positions = torch.where(token_mask)
            expert_input = x_flat.index_select(0, token_positions)
            expert_output = expert(expert_input)
            expert_weight = topk_weights[token_positions, slot_positions].unsqueeze(-1)
            output.index_add_(0, token_positions, expert_output * expert_weight)
        # Balancing loss: mean router prob (importance) dotted with realized
        # load, scaled by num_experts so a perfectly uniform router scores ~1.
        importance = router_probs.mean(dim=0)
        load = torch.stack(expert_load)
        aux_loss = self.num_experts * torch.sum(importance * load)
        return output.view(batch, seq_len, d_model), aux_loss
class HSSMV2Block(nn.Module):
    """Pre-norm residual block: state mixer, then a dense or MoE feed-forward.

    forward() actually returns ``(hidden, aux_loss)``; aux_loss is a zero
    scalar for dense (GatedMLP) blocks.
    """

    def __init__(self, config: HSSMV2Config, use_moe: bool = False):
        super().__init__()
        self.norm1 = RMSNorm(config.d_model)
        self.mixer = HierarchicalStateMixer(config)
        self.norm2 = RMSNorm(config.d_model)
        self.use_moe = use_moe
        self.ff = SparseMoE(config) if use_moe else GatedMLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mixer(self.norm1(x))
        if not self.use_moe:
            # Dense path carries a zero aux loss to keep the interface uniform.
            return x + self.ff(self.norm2(x)), x.new_zeros(())
        ff_out, aux_loss = self.ff(self.norm2(x))
        return x + ff_out, aux_loss
class HSSMV2LM(nn.Module):
    """HSSM v2 language model: embedding, mixer blocks, tied LM head."""

    def __init__(self, config: HSSMV2Config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        # Every moe_every-th block (1-indexed) swaps its FFN for SparseMoE.
        self.blocks = nn.ModuleList([
            HSSMV2Block(config, use_moe=((layer_idx + 1) % config.moe_every == 0))
            for layer_idx in range(config.n_layers)
        ])
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_embeddings:
            # Weight tying: output projection shares the embedding matrix.
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None):
        """Return {"loss", "logits", "aux_loss"}; loss is None without labels."""
        hidden = self.embed(input_ids)
        aux_total = hidden.new_zeros(())
        for block in self.blocks:
            hidden, block_aux = block(hidden)
            aux_total = aux_total + block_aux
        hidden = self.norm(hidden)
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # Shift internally: position t predicts labels[t+1]; -100 is ignored.
            flat_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            flat_labels = labels[:, 1:].contiguous().reshape(-1)
            ce_loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)
            loss = ce_loss + (self.config.aux_loss_coef * aux_total)
        return {"loss": loss, "logits": logits, "aux_loss": aux_total}

    def num_parameters(self) -> int:
        """Total parameter count (tied tensors counted once, as parameters()
        de-duplicates shared Parameters by default)."""
        return sum(p.numel() for p in self.parameters())
class FineWebDataset(IterableDataset):
    """Streams the first ``max_rows`` rows of FineWeb-Edu and packs them into
    fixed-length token windows.

    Documents are tokenized, joined with an EOS separator, and emitted as
    contiguous windows of ``max_seq_len`` tokens. ``labels`` equals
    ``input_ids`` because HSSMV2LM shifts by one position internally.

    Fix vs. original: with ``num_workers > 0`` every DataLoader worker used
    to stream the *same* rows, duplicating every sample once per worker.
    Rows are now round-robin sharded across workers via ``get_worker_info()``
    (single-process behavior is unchanged).
    """

    def __init__(
        self,
        tokenizer,
        max_seq_len: int,
        max_rows: int = 5_000_000,
        split: str = "train",
        text_field: str = "text",
    ):
        super().__init__()
        self.tokenizer = tokenizer      # must expose .encode() and eos/pad ids
        self.max_seq_len = max_seq_len  # tokens per emitted sample
        self.max_rows = max_rows        # cap on dataset rows consumed
        self.split = split
        self.text_field = text_field    # column holding the document text

    def _iter_texts(self):
        """Yield non-empty document strings, sharded across DataLoader workers."""
        worker = get_worker_info()
        num_shards = worker.num_workers if worker is not None else 1
        shard_id = worker.id if worker is not None else 0
        ds = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name="sample-10BT",
            split=self.split,
            streaming=True
        )
        for i, item in enumerate(ds):
            if i >= self.max_rows:
                break
            # Round-robin sharding so each worker sees a disjoint subset.
            if i % num_shards != shard_id:
                continue
            text = str(item.get(self.text_field, "") or "").strip()
            if text:
                yield text

    def _pack(self, texts) -> Iterator[Dict]:
        """Tokenize *texts* and pack into (max_seq_len + 1)-token windows."""
        buffer = []
        eos_id = self.tokenizer.eos_token_id or self.tokenizer.pad_token_id
        for text in texts:
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            buffer.extend(token_ids + [eos_id])
            while len(buffer) >= self.max_seq_len + 1:
                window = buffer[:self.max_seq_len + 1]
                # Slide by max_seq_len: one token of overlap seeds the next window.
                buffer = buffer[self.max_seq_len:]
                sample = torch.tensor(window, dtype=torch.long)
                # labels == input_ids: the model applies the one-token shift.
                yield {"input_ids": sample[:-1], "labels": sample[:-1].clone()}

    def __iter__(self) -> Iterator[Dict]:
        return self._pack(self._iter_texts())
def collate_batch(batch):
    """Stack per-sample tensors into one batch dict for the DataLoader."""
    input_ids = [sample["input_ids"] for sample in batch]
    labels = [sample["labels"] for sample in batch]
    return {"input_ids": torch.stack(input_ids), "labels": torch.stack(labels)}
def train(args):
    """Run the pretraining loop described by the parsed CLI ``args``.

    Builds the tokenizer, model, streaming dataloader, optimizer and LR
    schedule, then trains with optional bf16 autocast and gradient
    accumulation, checkpointing periodically into ``args.output_dir``.

    Fixes vs. original:
      * the MoE CLI flags (--num-experts, --experts-per-token, --expert-dim,
        --moe-every, --aux-loss-coef) were parsed but never passed to the
        model config;
      * --text-field was parsed but never forwarded to the dataset;
      * the "active params" estimate was a no-op (its filter was always true
        for expert weights, so it just reprinted the total);
      * under nn.DataParallel the gathered loss/aux_loss are vectors, so
        .backward()/.item() would fail — reduce with .mean() (a no-op on a
        single GPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        # TF32 matmul/conv is a free speedup on Ampere+ at this model scale.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    use_bf16 = bool(getattr(args, "bf16", True)) and device.type == "cuda"
    print(f"bf16: {use_bf16}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    # Effectively disable the tokenizer's max-length warning for packing.
    tokenizer.model_max_length = int(1e30)

    config = HSSMV2Config(
        vocab_size=tokenizer.vocab_size,
        d_model=args.d_model,
        n_layers=args.n_layers,
        d_ff=args.d_ff,
        state_rank=args.state_rank,
        chunk_size=args.chunk_size,
        max_seq_len=args.max_seq_len,
        # Fix: these flags were previously parsed but silently ignored.
        num_experts=args.num_experts,
        experts_per_token=args.experts_per_token,
        expert_dim=args.expert_dim,
        moe_every=args.moe_every,
        aux_loss_coef=args.aux_loss_coef,
    )
    model = HSSMV2LM(config)
    total_params = model.num_parameters()
    print(f"Total params: {total_params:,} ({total_params/1e6:.2f}M)")
    # Active params per forward = everything except the experts that are not
    # routed to (experts_per_token of num_experts experts fire per MoE layer).
    expert_params = sum(
        p.numel() for name, p in model.named_parameters() if ".experts." in name
    )
    active_params = total_params - expert_params + (
        expert_params * config.experts_per_token // config.num_experts
    )
    print(f"Active per forward: ~{active_params/1e6:.2f}M")

    model = model.to(device)
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
        model = nn.DataParallel(model)

    dataset = FineWebDataset(
        tokenizer, args.max_seq_len,
        max_rows=args.max_rows,
        split=args.dataset_split,
        text_field=args.text_field,  # fix: --text-field was parsed but unused
    )
    dataloader_kwargs = {
        "dataset": dataset,
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "collate_fn": collate_batch,
        "drop_last": True,
        "pin_memory": device.type == "cuda",
    }
    if args.num_workers > 0:
        # Only valid with worker processes; DataLoader rejects them otherwise.
        dataloader_kwargs["persistent_workers"] = True
        dataloader_kwargs["prefetch_factor"] = 4
    dataloader = DataLoader(**dataloader_kwargs)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr,
        betas=(0.9, 0.95), weight_decay=args.weight_decay
    )
    if args.max_steps > 0:
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps
        )
    else:
        # Open-ended runs: warm up, then hold the learning rate constant.
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    model.train()
    step = 0
    start_time = time.time()
    grad_norm = 0.0
    last_aux_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)
        # NOTE(review): for GPT-2-style tokenizers pad_token == eos_token and
        # the packed stream contains no padding, so this masks the EOS
        # separators out of the loss. Confirm that is intended.
        labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)
        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if use_bf16 else contextlib.nullcontext()
        )
        with autocast_ctx:
            outputs = model(input_ids=input_ids, labels=labels)
        aux_loss_val = outputs.get("aux_loss")
        if aux_loss_val is not None:
            # .mean() reduces the per-replica vector DataParallel gathers;
            # it is a no-op for the single-GPU scalar.
            last_aux_loss = float(aux_loss_val.detach().mean().item())
        loss = outputs["loss"].float().mean() / args.grad_accum_steps
        loss.backward()
        if (step + 1) % args.grad_accum_steps == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), args.max_grad_norm
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        step += 1
        if step % args.log_every == 0:
            elapsed = time.time() - start_time
            tokens = step * args.batch_size * args.max_seq_len
            print(json.dumps({
                "step": step,
                "loss": round(float(loss.item() * args.grad_accum_steps), 5),
                "aux_loss": round(last_aux_loss, 5),
                "lr": scheduler.get_last_lr()[0],
                "tokens": tokens,
                "tokens_per_sec": round(tokens / max(elapsed, 1e-6), 2),
                "grad_norm": round(float(grad_norm), 4) if isinstance(grad_norm, torch.Tensor) else float(grad_norm),
                "gpu_mem_gb": round(torch.cuda.memory_allocated() / 1e9, 2) if device.type == "cuda" else 0
            }))
        if step % args.save_every == 0:
            # Full resume state: model, optimizer, scheduler, config.
            checkpoint = {
                "step": step,
                "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "config": asdict(config),
            }
            torch.save(checkpoint, output_dir / f"step_{step:07d}.pt")
            torch.save(checkpoint, output_dir / "latest.pt")
        if args.max_steps > 0 and step >= args.max_steps:
            break

    # Final save: weights + config only (resume state lives in latest.pt).
    final = {
        "step": step,
        "model_state_dict": model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        "config": asdict(config),
        "finished_at": time.time()
    }
    torch.save(final, output_dir / "final.pt")
    print(f"Training complete. Final checkpoint: {output_dir / 'final.pt'}")
def parse_args(argv=None):
    """Build and parse the training CLI.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``.
            Passing an explicit list makes this function testable without
            touching ``sys.argv`` (backward compatible: existing callers
            passing nothing keep the old behavior).

    Returns:
        argparse.Namespace with all training hyperparameters.
    """
    parser = argparse.ArgumentParser(description="HSSM v2 GPU pretraining")
    # Data
    parser.add_argument("--dataset-split", default="train")
    parser.add_argument("--text-field", default="text")
    parser.add_argument("--max-rows", type=int, default=5_000_000)
    parser.add_argument("--tokenizer-name", default="gpt2")
    parser.add_argument("--output-dir", default="/content/hssm_v2_runs")
    # Optimization / schedule
    parser.add_argument("--max-seq-len", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--max-steps", type=int, default=50_000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight-decay", type=float, default=0.1)
    parser.add_argument("--warmup-steps", type=int, default=1000)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--save-every", type=int, default=5000)
    parser.add_argument("--log-every", type=int, default=10)
    parser.add_argument("--num-workers", type=int, default=8)
    # Precision: bf16 defaults to on; --no-bf16 disables it.
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--no-bf16", action="store_false", dest="bf16")
    parser.set_defaults(bf16=True)
    # Model architecture
    parser.add_argument("--d-model", type=int, default=288)
    parser.add_argument("--n-layers", type=int, default=10)
    parser.add_argument("--d-ff", type=int, default=512)
    parser.add_argument("--state-rank", type=int, default=128)
    parser.add_argument("--chunk-size", type=int, default=8)
    parser.add_argument("--num-experts", type=int, default=64)
    parser.add_argument("--experts-per-token", type=int, default=1)
    parser.add_argument("--expert-dim", type=int, default=2048)
    parser.add_argument("--moe-every", type=int, default=4)
    parser.add_argument("--aux-loss-coef", type=float, default=1e-2)
    return parser.parse_args(argv)
# Script entry point: parse CLI flags and launch training.
if __name__ == "__main__":
    train(parse_args())