# Source: Hugging Face repository upload by mmcarpi
# "Upload custom model with source code and tokenizer" (commit 7e9dc48, verified)
import math
import torch
import torch.nn as nn
import torch.distributed as dist
from muon import MuonWithAuxAdam
from .config import AdamWConfig, OptimizerConfig
def materialize_and_synchronize(model: nn.Module):
    """
    Materialize 'meta' parameters and synchronize all model parameters across ranks.

    Finds any parameters still on the 'meta' device, allocates real tensors for
    them on this rank's GPU, initializes them on rank 0, and broadcasts the
    result so every rank ends up with identical weights. Parameters already on
    a real device are broadcast from rank 0 as well.

    Requires torch.distributed to be initialized with one CUDA device per rank.

    Args:
        model (nn.Module): The model to be synchronized. Modified in-place.
    """
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    for name, param in list(model.named_parameters()):
        if param.device == torch.device("meta"):
            # Materialize, initialize on rank 0, and broadcast
            materialized_param = torch.empty_like(param, device=device)
            if rank == 0:
                if materialized_param.ndim >= 2:
                    nn.init.kaiming_uniform_(materialized_param, a=math.sqrt(5))
                else:
                    # kaiming_uniform_ raises for tensors with fewer than 2
                    # dims; zero-init 1-D params (biases, norm gains) instead.
                    nn.init.zeros_(materialized_param)
            dist.broadcast(materialized_param, 0)
            # Replace the meta parameter with the real, synchronized one.
            # Walk down to the owning submodule so setattr re-registers the
            # parameter at the right place in the module tree.
            parent_module = model
            parts = name.split(".")
            for part in parts[:-1]:
                parent_module = getattr(parent_module, part)
            param_name = parts[-1]
            delattr(parent_module, param_name)
            setattr(parent_module, param_name, nn.Parameter(materialized_param))
        else:
            # detach() shares storage with the parameter, so the in-place
            # broadcast writes straight into the model's weights.
            dist.broadcast(param.detach(), 0)
def setup_optimizer(model: nn.Module, cfg: OptimizerConfig) -> torch.optim.Optimizer:
    """
    Build a MuonWithAuxAdam optimizer: Muon for hidden matrices, Adam for the rest.

    Parameters are partitioned into four *disjoint* groups:
      * params whose name contains "head"          -> Adam, cfg.head_lr
      * params whose name contains "embed"         -> Adam, cfg.embed_lr
      * remaining params with ndim < 2 (biases,
        gains)                                     -> Adam, cfg.scalar_lr
      * remaining matrices with ndim >= 2          -> Muon, cfg.muon_lr

    Each group also records its starting LR under "initial_lr" so LR
    schedulers can recover it.
    """
    hidden_matrix_params = []
    embed_params = []
    head_params = []
    scalar_params = []
    # Single pass with name checks taking precedence over the ndim split:
    # previously a 1-D "embed"/"head" parameter landed in both its name group
    # and the scalar group, which optimizers reject as a duplicated param.
    for n, p in model.named_parameters():
        if "embed" in n:
            embed_params.append(p)
        elif "head" in n:
            head_params.append(p)
        elif p.ndim >= 2:
            hidden_matrix_params.append(p)
        else:
            scalar_params.append(p)
    adam_groups = [
        dict(params=head_params, lr=cfg.head_lr),
        dict(params=embed_params, lr=cfg.embed_lr),
        dict(params=scalar_params, lr=cfg.scalar_lr),
    ]
    adam_groups = [
        dict(**group, betas=(0.9, 0.999), eps=1e-8, use_muon=False)
        for group in adam_groups
    ]
    muon_group = dict(
        params=hidden_matrix_params, lr=cfg.muon_lr, momentum=0.95, use_muon=True
    )
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
def setup_optimizer_for_fine_tune(
    model: nn.Module, cfg: AdamWConfig
) -> torch.optim.Optimizer:
    """
    Create an AdamW optimizer with weight decay disabled for biases and norms.

    Trainable parameters whose names contain "bias" or "norm" go into a
    zero-weight-decay group; all other trainable parameters use
    cfg.weight_decay. Frozen parameters are skipped entirely. Each group's
    starting LR is recorded under "initial_lr" for schedulers.
    """
    no_decay_keywords = ("bias", "norm")
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights never reach the optimizer
        matches = any(kw in name for kw in no_decay_keywords)
        (no_decay if matches else decay).append(param)
    optimizer = torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": cfg.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=cfg.lr,
        eps=cfg.eps,
        betas=cfg.betas,
    )
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
def load_model_from_checkpoint(checkpoint: str, map_location=None):
    """
    Load a checkpoint file and return its contents (typically a state dict).

    Args:
        checkpoint: Path to a file produced by torch.save.
        map_location: Optional device remapping forwarded to torch.load;
            pass "cpu" to load GPU-saved tensors on a CPU-only machine.
            The default (None) preserves the previous behavior of loading
            onto the devices the tensors were saved from.

    Returns:
        Whatever object was serialized, usually a parameter state dict.
    """
    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True (the default from torch 2.6) — only load trusted files.
    state_dict = torch.load(checkpoint, map_location=map_location)
    return state_dict
def test_model(model, config, device: torch.device, next_token: bool, is_causal: bool):
input_ids = torch.arange(
start=0, end=config.context_length - 1, device=device
).unsqueeze(0)
labels = (
torch.arange(start=1, end=config.context_length, device=device).unsqueeze(0)
if next_token
else torch.randint(0, config.num_labels, (1,), device=device)
)
attention_mask = torch.ones_like(input_ids, device=device)
output = model(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
is_causal=is_causal,
)
print(f"Logits shape: {output.logits.shape}")
print(f"Loss: {output.loss.item()}")
peak_memory_allocated = torch.cuda.max_memory_allocated() // 1024 // 1024
reserved_memory = torch.cuda.max_memory_reserved() // 1024 // 1024
print(f"Peak memory allocated: {peak_memory_allocated} MB")
print(f"Reserved memory: {reserved_memory} MB")
def summary(model: nn.Module):
    """
    Print the model architecture plus trainable and total parameter counts.
    """
    total_parameters = sum(p.numel() for p in model.parameters())
    trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(model)
    print("# Trainable parameters:", trainable_parameters)
    print("# Total parameters:", total_parameters)
def check_grad(model: nn.Module, is_causal: bool):
    """
    Probe gradient flow through the model to verify causality or masking.

    Feeds random embeddings that require grad, sums the logits at the middle
    position t, backpropagates, and reports whether positions up to t and
    positions after t received any non-zero input gradient. A causal model
    is expected to show has_grad_future=False.

    Args:
        model: The model to check (must support inputs_embeds argument in forward).
        is_causal: Whether the model should behave causally.
    """
    device = next(model.parameters()).device
    config = model.config
    # Random embeddings whose gradient we can inspect after backward
    x = torch.randn(
        1,
        config.context_length,
        config.embedding_dim,
        requires_grad=True,
        device=device,
    )
    output = model(inputs_embeds=x, attention_mask=None, is_causal=is_causal)
    # Normalize the possible return types down to a logits tensor
    if hasattr(output, "logits"):
        logits = output.logits
    elif isinstance(output, tuple):
        logits = output[0]
    else:
        logits = output
    # The loss depends only on the middle token's logits
    t = config.context_length // 2
    loss = logits[:, t, :].sum()
    model.zero_grad()
    loss.backward()
    # Positions 0..t vs. positions t+1.. — a causal model must show no
    # gradient flowing back from future positions.
    has_grad_past = torch.all(x.grad[:, : t + 1, :] != 0).item()
    has_grad_future = torch.any(x.grad[:, t + 1 :, :] != 0).item()
    print(f"{is_causal=} {has_grad_past=} {has_grad_future=}")