import math

import torch
import torch.distributed as dist
import torch.nn as nn
from muon import MuonWithAuxAdam

from .config import AdamWConfig, OptimizerConfig


def materialize_and_synchronize(model: nn.Module):
    """
    Materializes 'meta' parameters and synchronizes all model parameters across ranks.

    This function finds any parameters initialized on the 'meta' device, creates
    real tensors for them on the target device, initializes them on rank 0, and
    broadcasts them to all other ranks. All other existing parameters are also
    synchronized.

    Args:
        model (nn.Module): The model to be synchronized. The model is modified
            in-place.
    """
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    for name, param in list(model.named_parameters()):
        if param.device == torch.device("meta"):
            # Materialize, initialize on rank 0, and broadcast.
            # NOTE(review): kaiming_uniform_ requires a tensor with >= 2 dims;
            # a 1-D meta parameter (e.g. a bias) would raise here — confirm
            # that only matrix-shaped params are ever left on 'meta'.
            materialized_param = torch.empty_like(param, device=device)
            if rank == 0:
                nn.init.kaiming_uniform_(materialized_param, a=math.sqrt(5))
            dist.broadcast(materialized_param, 0)

            # Replace the meta parameter with the real, synchronized one by
            # walking down to the owning submodule via the dotted name.
            parent_module = model
            parts = name.split(".")
            for part in parts[:-1]:
                parent_module = getattr(parent_module, part)
            param_name = parts[-1]
            delattr(parent_module, param_name)
            setattr(parent_module, param_name, nn.Parameter(materialized_param))
        else:
            # Synchronize parameters already on a real device. broadcast is
            # in-place and detach() shares storage, so the parameter itself
            # receives rank 0's values.
            dist.broadcast(param.detach(), 0)


def setup_optimizer(model: nn.Module, cfg: OptimizerConfig) -> torch.optim.Optimizer:
    """Build a MuonWithAuxAdam optimizer with per-group learning rates.

    Hidden matrix parameters (ndim >= 2, not embeddings or head) are optimized
    with Muon; head, embedding, and scalar (ndim < 2) parameters use the
    auxiliary Adam groups. Each group's starting LR is recorded as
    ``initial_lr`` for scheduler use.

    Args:
        model: Model whose parameters are partitioned by name/shape.
        cfg: Config providing ``head_lr``, ``embed_lr``, ``scalar_lr``,
            and ``muon_lr``.

    Returns:
        The configured MuonWithAuxAdam optimizer.
    """
    # NOTE(review): routing is by name substring; a parameter whose name
    # contains both "embed" and "head" would land in two groups — verify
    # model naming makes these disjoint.
    hidden_matrix_params = [
        p
        for n, p in model.named_parameters()
        if p.ndim >= 2 and "embed" not in n and "head" not in n
    ]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [p for n, p in model.named_parameters() if "head" in n]

    adam_groups = [
        dict(params=head_params, lr=cfg.head_lr),
        dict(params=embed_params, lr=cfg.embed_lr),
        dict(params=scalar_params, lr=cfg.scalar_lr),
    ]
    adam_groups = [
        dict(**group, betas=(0.9, 0.999), eps=1e-8, use_muon=False)
        for group in adam_groups
    ]
    muon_group = dict(
        params=hidden_matrix_params, lr=cfg.muon_lr, momentum=0.95, use_muon=True
    )
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)

    # Record the starting LR so LR schedulers can reference it.
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer


def setup_optimizer_for_fine_tune(
    model: nn.Module, cfg: AdamWConfig
) -> torch.optim.Optimizer:
    """Build an AdamW optimizer with standard weight-decay grouping.

    Parameters whose names contain "bias" or "norm" are excluded from weight
    decay; frozen parameters (``requires_grad=False``) are skipped entirely.

    Args:
        model: Model to optimize.
        cfg: Config providing ``lr``, ``eps``, ``betas``, and ``weight_decay``.

    Returns:
        The configured AdamW optimizer with ``initial_lr`` set on each group.
    """
    decay_parameters = []
    no_decay_parameters = []
    no_decay_keywords = [
        "bias",
        "norm",
    ]
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay_keywords):
            no_decay_parameters.append(param)
        else:
            decay_parameters.append(param)

    optimizer_grouped_parameters = [
        {
            "params": decay_parameters,
            "weight_decay": cfg.weight_decay,
        },
        {
            "params": no_decay_parameters,
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=cfg.lr,
        eps=cfg.eps,
        betas=cfg.betas,
    )
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer


def load_model_from_checkpoint(checkpoint: str):
    """Load and return a state dict (or saved object) from a checkpoint path.

    NOTE(review): no ``map_location`` is given, so tensors load onto the
    device they were saved from; ``weights_only`` behavior depends on the
    installed torch version — confirm both are intended for this workflow.
    """
    state_dict = torch.load(checkpoint)
    return state_dict


def test_model(model, config, device: torch.device, next_token: bool, is_causal: bool):
    """Run a single smoke-test forward pass and report loss and GPU memory.

    Args:
        model: Model accepting ``input_ids``, ``labels``, ``attention_mask``,
            and ``is_causal`` keyword arguments.
        config: Provides ``context_length`` and ``num_labels``.
        device: Device on which the synthetic batch is created.
        next_token: If True, labels are the inputs shifted by one (language
            modeling); otherwise a single random class label is used.
        is_causal: Forwarded to the model's attention masking.
    """
    input_ids = torch.arange(
        start=0, end=config.context_length - 1, device=device
    ).unsqueeze(0)
    labels = (
        torch.arange(start=1, end=config.context_length, device=device).unsqueeze(0)
        if next_token
        else torch.randint(0, config.num_labels, (1,), device=device)
    )
    attention_mask = torch.ones_like(input_ids, device=device)
    output = model(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
        is_causal=is_causal,
    )
    print(f"Logits shape: {output.logits.shape}")
    print(f"Loss: {output.loss.item()}")

    # Report peak CUDA memory in MiB.
    peak_memory_allocated = torch.cuda.max_memory_allocated() // 1024 // 1024
    reserved_memory = torch.cuda.max_memory_reserved() // 1024 // 1024
    print(f"Peak memory allocated: {peak_memory_allocated} MB")
    print(f"Reserved memory: {reserved_memory} MB")


def summary(model: nn.Module):
    """Print the model structure plus trainable and total parameter counts."""
    trainable_parameters = 0
    total_parameters = 0
    for param in model.parameters():
        size = param.numel()
        total_parameters += size
        if param.requires_grad:
            trainable_parameters += size
    print(model)
    print("# Trainable parameters:", trainable_parameters)
    print("# Total parameters:", total_parameters)


def check_grad(model: nn.Module, is_causal: bool):
    """
    Checks the gradient flow of the model to verify causality or masking.

    Args:
        model: The model to check (must support inputs_embeds argument in
            forward).
        is_causal: Whether the model should behave causally.
    """
    device = next(model.parameters()).device
    config = model.config

    # Generate random embeddings with gradient tracking
    x = torch.randn(
        1,
        config.context_length,
        config.embedding_dim,
        requires_grad=True,
        device=device,
    )

    # Forward pass using inputs_embeds
    output = model(inputs_embeds=x, attention_mask=None, is_causal=is_causal)

    # Handle different output types
    if hasattr(output, "logits"):
        logits = output.logits
    elif isinstance(output, tuple):
        logits = output[0]
    else:
        logits = output

    # Calculate loss at the middle token
    t = config.context_length // 2
    loss = logits[:, t, :].sum()

    # Clear previous gradients and compute new ones
    model.zero_grad()
    loss.backward()

    # Check past gradients (0 to t): a causal model should let gradient flow
    # from token t back to every earlier position.
    grad_up_to_t = x.grad[:, : t + 1, :]
    has_grad_past = torch.all(grad_up_to_t != 0).item()

    # Check future gradients (t+1 to end): a causal model should show none.
    grad_after_t = x.grad[:, t + 1 :, :]
    has_grad_future = torch.any(grad_after_t != 0).item()

    print(f"{is_causal=} {has_grad_past=} {has_grad_future=}")