|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.distributed as dist |
|
|
|
|
|
from muon import MuonWithAuxAdam |
|
|
|
|
|
from .config import AdamWConfig, OptimizerConfig |
|
|
|
|
|
|
|
|
def materialize_and_synchronize(model: nn.Module) -> None:
    """
    Materializes 'meta' parameters and synchronizes all model parameters across ranks.

    This function finds any parameters initialized on the 'meta' device,
    creates real tensors for them on the target device, initializes them on rank 0,
    and broadcasts them to all other ranks. All other existing parameters are also
    synchronized.

    Args:
        model (nn.Module): The model to be synchronized. The model is modified in-place.
    """
    rank = dist.get_rank()
    # NOTE(review): assumes global rank == local CUDA device index (one process
    # per GPU, single node) — this breaks on multi-node launches; confirm the
    # launcher guarantees this mapping.
    device = torch.device(f"cuda:{rank}")

    # list() snapshots the parameter iterator because the loop body replaces
    # parameters on the module (delattr/setattr) while iterating.
    for name, param in list(model.named_parameters()):
        if param.device == torch.device("meta"):

            # Allocate real (uninitialized) storage with the same shape/dtype
            # as the meta parameter.
            materialized_param = torch.empty_like(param, device=device)
            if rank == 0:
                # Only rank 0 initializes; other ranks keep garbage values that
                # are overwritten by the broadcast below.
                # NOTE(review): kaiming_uniform_ needs fan-in (ndim >= 2); a
                # 1-D meta parameter (e.g. a bias) would raise here — confirm
                # all meta params are matrices.
                nn.init.kaiming_uniform_(materialized_param, a=math.sqrt(5))

            # All ranks participate; non-zero ranks receive rank 0's values
            # in place.
            dist.broadcast(materialized_param, 0)


            # Walk the dotted parameter name to find the module that owns it,
            # e.g. "layers.0.attn.weight" -> model.layers[0].attn.
            parent_module = model
            parts = name.split(".")
            for part in parts[:-1]:
                parent_module = getattr(parent_module, part)

            # Swap the meta parameter for the freshly materialized one.
            param_name = parts[-1]
            delattr(parent_module, param_name)
            setattr(parent_module, param_name, nn.Parameter(materialized_param))
        else:

            # detach() returns a view sharing the parameter's storage, so the
            # broadcast writes rank 0's values into the parameter itself
            # without touching autograd state.
            dist.broadcast(param.detach(), 0)
|
|
|
|
|
|
|
|
def setup_optimizer(model: nn.Module, cfg: OptimizerConfig) -> torch.optim.Optimizer:
    """
    Build a MuonWithAuxAdam optimizer with disjoint parameter groups.

    Hidden weight matrices (ndim >= 2, not embedding/head weights) are
    optimized with Muon; embedding weights, head weights, and scalar/vector
    parameters (ndim < 2, e.g. biases and norm gains) use the auxiliary Adam
    with their own learning rates.

    Parameters are classified exactly once so no parameter can fall into two
    groups. (Previously a 1-D parameter whose name contained "head" or
    "embed" was collected both by its name-based group and by the scalar
    group, and torch.optim rejects duplicated parameters with
    "some parameters appear in more than one parameter group".)

    Args:
        model (nn.Module): Model whose parameters are partitioned.
        cfg (OptimizerConfig): Supplies muon_lr, head_lr, embed_lr, scalar_lr.

    Returns:
        torch.optim.Optimizer: Configured MuonWithAuxAdam; every group also
        carries "initial_lr" so LR schedulers can scale relative to it.
    """
    hidden_matrix_params = []
    embed_params = []
    head_params = []
    scalar_params = []
    # Single-pass classification guarantees the four groups are disjoint.
    for name, param in model.named_parameters():
        if param.ndim < 2:
            scalar_params.append(param)
        elif "embed" in name:
            embed_params.append(param)
        elif "head" in name:
            head_params.append(param)
        else:
            hidden_matrix_params.append(param)

    adam_groups = [
        dict(params=head_params, lr=cfg.head_lr),
        dict(params=embed_params, lr=cfg.embed_lr),
        dict(params=scalar_params, lr=cfg.scalar_lr),
    ]
    # Shared Adam hyperparameters; use_muon=False routes these groups to the
    # auxiliary Adam inside MuonWithAuxAdam.
    adam_groups = [
        dict(**group, betas=(0.9, 0.999), eps=1e-8, use_muon=False)
        for group in adam_groups
    ]
    muon_group = dict(
        params=hidden_matrix_params, lr=cfg.muon_lr, momentum=0.95, use_muon=True
    )
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)

    # Record the starting LR so schedulers can compute warmup/decay from it.
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]

    return optimizer
|
|
|
|
|
|
|
|
def setup_optimizer_for_fine_tune( |
|
|
model: nn.Module, cfg: AdamWConfig |
|
|
) -> torch.optim.Optimizer: |
|
|
decay_parameters = [] |
|
|
no_decay_parameters = [] |
|
|
|
|
|
no_decay_keywords = [ |
|
|
"bias", |
|
|
"norm", |
|
|
] |
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
if not param.requires_grad: |
|
|
continue |
|
|
|
|
|
if any(nd in name for nd in no_decay_keywords): |
|
|
no_decay_parameters.append(param) |
|
|
else: |
|
|
decay_parameters.append(param) |
|
|
|
|
|
optimizer_grouped_parameters = [ |
|
|
{ |
|
|
"params": decay_parameters, |
|
|
"weight_decay": cfg.weight_decay, |
|
|
}, |
|
|
{ |
|
|
"params": no_decay_parameters, |
|
|
"weight_decay": 0.0, |
|
|
}, |
|
|
] |
|
|
|
|
|
optimizer = torch.optim.AdamW( |
|
|
optimizer_grouped_parameters, |
|
|
lr=cfg.lr, |
|
|
eps=cfg.eps, |
|
|
betas=cfg.betas, |
|
|
) |
|
|
|
|
|
for group in optimizer.param_groups: |
|
|
group["initial_lr"] = group["lr"] |
|
|
|
|
|
return optimizer |
|
|
|
|
|
|
|
|
def load_model_from_checkpoint(checkpoint: str, map_location=None):
    """
    Load a model state dict (or checkpoint object) from a file.

    Args:
        checkpoint (str): Path to the checkpoint file.
        map_location: Optional device remapping forwarded to torch.load
            (e.g. "cpu" to load a CUDA-saved checkpoint on a CPU-only host).
            Defaults to None, preserving the previous behavior of loading
            onto the devices the checkpoint was saved from.

    Returns:
        The object stored in the checkpoint (typically a state dict).
    """
    # NOTE(review): torch.load unpickles arbitrary objects; consider
    # weights_only=True for untrusted checkpoints (default changed in
    # torch 2.6) — left unset here to avoid breaking non-tensor checkpoints.
    state_dict = torch.load(checkpoint, map_location=map_location)
    return state_dict
|
|
|
|
|
|
|
|
def test_model(model, config, device: torch.device, next_token: bool, is_causal: bool):
    """
    Run a single smoke-test forward pass and report loss and CUDA memory use.

    Builds a deterministic token sequence of length config.context_length - 1,
    pairs it with shifted-by-one labels (next-token mode) or one random class
    label (classification mode), calls the model, and prints the logits shape,
    the loss, and peak/reserved CUDA memory in MB.

    Args:
        model: Model accepting input_ids/labels/attention_mask/is_causal.
        config: Must expose context_length and num_labels.
        device: Device on which the test tensors are created.
        next_token: If True, use shifted next-token labels; otherwise a single
            random classification label.
        is_causal: Forwarded unchanged to the model.
    """
    input_ids = torch.arange(config.context_length - 1, device=device).unsqueeze(0)

    if next_token:
        # Labels are the inputs shifted left by one position.
        labels = torch.arange(1, config.context_length, device=device).unsqueeze(0)
    else:
        # Single random class label for sequence classification.
        labels = torch.randint(0, config.num_labels, (1,), device=device)

    attention_mask = torch.ones_like(input_ids, device=device)
    output = model(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
        is_causal=is_causal,
    )

    print(f"Logits shape: {output.logits.shape}")
    print(f"Loss: {output.loss.item()}")

    # Convert bytes to whole MB for readability.
    peak_memory_allocated = torch.cuda.max_memory_allocated() // 1024 // 1024
    reserved_memory = torch.cuda.max_memory_reserved() // 1024 // 1024

    print(f"Peak memory allocated: {peak_memory_allocated} MB")
    print(f"Reserved memory: {reserved_memory} MB")
|
|
|
|
|
|
|
|
def summary(model: nn.Module): |
|
|
trainable_parameters = 0 |
|
|
total_parameters = 0 |
|
|
for param in model.parameters(): |
|
|
size = param.numel() |
|
|
total_parameters += size |
|
|
if param.requires_grad: |
|
|
trainable_parameters += size |
|
|
|
|
|
print(model) |
|
|
print("# Trainable parameters:", trainable_parameters) |
|
|
print("# Total parameters:", total_parameters) |
|
|
|
|
|
|
|
|
def check_grad(model: nn.Module, is_causal: bool):
    """
    Probe gradient flow through the model to verify causal masking.

    Feeds a random embedding sequence, backpropagates from the logits at the
    midpoint position t, and reports whether gradients reach positions <= t
    (expected in all cases) and positions > t (expected only when the model
    is not causal).

    Args:
        model: The model to check (must support inputs_embeds argument in
            forward, and expose config.context_length / config.embedding_dim).
        is_causal: Whether the model should behave causally.
    """
    device = next(model.parameters()).device
    config = model.config

    # Leaf tensor with requires_grad so x.grad is populated by backward().
    x = torch.randn(
        1,
        config.context_length,
        config.embedding_dim,
        requires_grad=True,
        device=device,
    )

    result = model(inputs_embeds=x, attention_mask=None, is_causal=is_causal)

    # Accept HF-style outputs, plain tuples, or raw tensors.
    if hasattr(result, "logits"):
        logits = result.logits
    elif isinstance(result, tuple):
        logits = result[0]
    else:
        logits = result

    # Backprop only from the midpoint position's logits.
    t = config.context_length // 2
    loss = logits[:, t, :].sum()

    model.zero_grad()
    loss.backward()

    # Every input position up to and including t should carry gradient.
    has_grad_past = torch.all(x.grad[:, : t + 1, :] != 0).item()
    # Any gradient past t indicates information flow from future positions.
    has_grad_future = torch.any(x.grad[:, t + 1 :, :] != 0).item()

    print(f"{is_causal=} {has_grad_past=} {has_grad_future=}")
|
|
|