| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.distributed as dist |
| |
|
| | from muon import MuonWithAuxAdam |
| |
|
| | from .config import AdamWConfig, OptimizerConfig |
| |
|
| |
|
def materialize_and_synchronize(model: nn.Module):
    """
    Materializes 'meta' parameters and synchronizes all model parameters across ranks.

    This function finds any parameters initialized on the 'meta' device,
    creates real tensors for them on the target device, initializes them on rank 0,
    and broadcasts them to all other ranks. All other existing parameters are also
    synchronized from rank 0.

    Args:
        model (nn.Module): The model to be synchronized. The model is modified in-place.
    """
    rank = dist.get_rank()
    # NOTE(review): assumes one process per GPU with global rank == local device
    # index; on multi-node jobs this should use the *local* rank — confirm.
    device = torch.device(f"cuda:{rank}")

    for name, param in list(model.named_parameters()):
        if param.is_meta:
            # Allocate real storage; contents on ranks != 0 are garbage until
            # the broadcast below overwrites them.
            materialized_param = torch.empty_like(param, device=device)
            if rank == 0:
                if materialized_param.ndim >= 2:
                    nn.init.kaiming_uniform_(materialized_param, a=math.sqrt(5))
                else:
                    # kaiming_uniform_ raises ValueError for tensors with fewer
                    # than 2 dimensions (biases, scales) — zero-init those.
                    nn.init.zeros_(materialized_param)

            dist.broadcast(materialized_param, 0)

            # Walk the dotted name to the owning submodule so the parameter
            # can be replaced on the module tree in-place.
            parent_module = model
            parts = name.split(".")
            for part in parts[:-1]:
                parent_module = getattr(parent_module, part)

            param_name = parts[-1]
            delattr(parent_module, param_name)
            # Preserve the original requires_grad flag so frozen parameters
            # stay frozen after materialization.
            setattr(
                parent_module,
                param_name,
                nn.Parameter(materialized_param, requires_grad=param.requires_grad),
            )
        else:
            # Already-materialized parameter: rank 0 is the source of truth.
            dist.broadcast(param.detach(), 0)
| |
|
| |
|
def setup_optimizer(model: nn.Module, cfg: OptimizerConfig) -> torch.optim.Optimizer:
    """Build a MuonWithAuxAdam optimizer with per-group learning rates.

    Hidden weight matrices (ndim >= 2, name not containing "embed"/"head") are
    optimized with Muon; embeddings, the output head, and remaining low-rank
    (scalar/vector) parameters use the auxiliary Adam groups. Each group also
    records its starting lr under "initial_lr" for lr schedulers.

    Args:
        model: The model whose parameters are partitioned into groups.
        cfg: Config providing head_lr, embed_lr, scalar_lr and muon_lr.

    Returns:
        The configured MuonWithAuxAdam optimizer.
    """
    hidden_matrix_params = [
        p
        for n, p in model.named_parameters()
        if p.ndim >= 2 and "embed" not in n and "head" not in n
    ]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    head_params = [p for n, p in model.named_parameters() if "head" in n]
    # Exclude embed/head names here: a 1-D bias on an embedding or head layer
    # would otherwise appear in two parameter groups, which optimizers reject.
    scalar_params = [
        p
        for n, p in model.named_parameters()
        if p.ndim < 2 and "embed" not in n and "head" not in n
    ]

    adam_groups = [
        dict(params=head_params, lr=cfg.head_lr),
        dict(params=embed_params, lr=cfg.embed_lr),
        dict(params=scalar_params, lr=cfg.scalar_lr),
    ]
    adam_groups = [
        dict(**group, betas=(0.9, 0.999), eps=1e-8, use_muon=False)
        for group in adam_groups
    ]
    muon_group = dict(
        params=hidden_matrix_params, lr=cfg.muon_lr, momentum=0.95, use_muon=True
    )
    optimizer = MuonWithAuxAdam([*adam_groups, muon_group])

    # Record the starting lr so resumed schedulers can reconstruct their state.
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]

    return optimizer
| |
|
| |
|
def setup_optimizer_for_fine_tune(
    model: nn.Module, cfg: AdamWConfig
) -> torch.optim.Optimizer:
    """Build an AdamW optimizer with weight decay disabled for biases/norms.

    Trainable parameters whose names contain "bias" or "norm" go into a
    zero-weight-decay group; everything else decays by cfg.weight_decay.
    Each group records its starting lr under "initial_lr" for lr schedulers.
    """
    no_decay_keywords = ("bias", "norm")

    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen params never enter the optimizer
        target = no_decay if any(kw in name for kw in no_decay_keywords) else decay
        target.append(param)

    grouped = [
        {"params": decay, "weight_decay": cfg.weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

    optimizer = torch.optim.AdamW(
        grouped,
        lr=cfg.lr,
        eps=cfg.eps,
        betas=cfg.betas,
    )

    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]

    return optimizer
| |
|
| |
|
def load_model_from_checkpoint(checkpoint: str, map_location=None):
    """Load a serialized checkpoint (typically a state dict) from disk.

    Args:
        checkpoint: Path to a file produced by ``torch.save``.
        map_location: Optional device remapping forwarded to ``torch.load``
            (e.g. ``"cpu"`` to load a GPU checkpoint on a CPU-only host).
            Defaults to None, preserving the previous behavior of restoring
            tensors to the devices they were saved from.

    Returns:
        The deserialized object — typically a state dict.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. If these checkpoints contain tensors only, consider
    # weights_only=True — confirm before changing.
    return torch.load(checkpoint, map_location=map_location)
| |
|
| |
|
def test_model(model, config, device: torch.device, next_token: bool, is_causal: bool):
    """Run a single forward pass as a smoke test and report loss and GPU memory.

    Feeds a deterministic ramp of token ids of length config.context_length - 1.
    When next_token is True, labels are the inputs shifted by one position
    (language-modeling style); otherwise one random class label is drawn from
    config.num_labels (sequence-classification style).
    """
    seq_len = config.context_length - 1
    input_ids = torch.arange(seq_len, device=device).unsqueeze(0)

    if next_token:
        # Shift-by-one targets for next-token prediction.
        labels = torch.arange(1, seq_len + 1, device=device).unsqueeze(0)
    else:
        labels = torch.randint(0, config.num_labels, (1,), device=device)

    attention_mask = torch.ones_like(input_ids, device=device)
    output = model(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
        is_causal=is_causal,
    )

    print(f"Logits shape: {output.logits.shape}")
    print(f"Loss: {output.loss.item()}")

    mib = 1024 * 1024
    print(f"Peak memory allocated: {torch.cuda.max_memory_allocated() // mib} MB")
    print(f"Reserved memory: {torch.cuda.max_memory_reserved() // mib} MB")
| |
|
| |
|
def summary(model: nn.Module):
    """Print the model architecture and its trainable/total parameter counts."""
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_parameters = sum(n for n, _ in sizes)
    trainable_parameters = sum(n for n, trainable in sizes if trainable)

    print(model)
    print("# Trainable parameters:", trainable_parameters)
    print("# Total parameters:", total_parameters)
| |
|
| |
|
def check_grad(model: nn.Module, is_causal: bool):
    """
    Checks the gradient flow of the model to verify causality or masking.

    Backpropagates from a single middle time step t and inspects the input
    gradient: for a causal model, positions after t should receive no
    gradient. Results are printed, not returned.

    Args:
        model: The model to check (must support inputs_embeds argument in forward).
        is_causal: Whether the model should behave causally.
    """
    device = next(model.parameters()).device
    config = model.config

    # Differentiable input so per-position gradients can be read back.
    x = torch.randn(
        1,
        config.context_length,
        config.embedding_dim,
        requires_grad=True,
        device=device,
    )

    output = model(inputs_embeds=x, attention_mask=None, is_causal=is_causal)

    # Normalize the various possible return shapes to a logits tensor.
    if hasattr(output, "logits"):
        logits = output.logits
    elif isinstance(output, tuple):
        logits = output[0]
    else:
        logits = output

    # Backprop from the middle position only.
    t = config.context_length // 2
    model.zero_grad()
    logits[:, t, :].sum().backward()

    # Positions 0..t: expected to carry gradient for an attending model.
    has_grad_past = torch.all(x.grad[:, : t + 1, :] != 0).item()
    # Positions t+1..: must be gradient-free if the model is causal.
    has_grad_future = torch.any(x.grad[:, t + 1 :, :] != 0).item()

    print(f"{is_causal=} {has_grad_past=} {has_grad_future=}")
| |
|