import inspect

import torch
import torch.distributed as dist


def get_adamw(
    model: torch.nn.Module,
    weight_decay: float,
    learning_rate: float,
    betas: tuple[float, float],
    device_type: str,
) -> torch.optim.AdamW:
    """
    Create an AdamW optimizer for the given model with specified parameters.

    Args:
        model (torch.nn.Module): The model for which the optimizer is created.
        weight_decay (float): The weight decay (L2 penalty) for the optimizer.
        learning_rate (float): The learning rate for the optimizer.
        betas (tuple[float, float]): Coefficients used for computing running
            averages of the gradient and its square.
        device_type (str): The device type ('cuda' or 'cpu') on which the
            optimizer will operate.

    Returns:
        torch.optim.AdamW: The AdamW optimizer configured with the specified
            parameters.
    """
    # Collect all trainable parameters.
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}

    # Apply weight decay only to tensors with 2+ dimensions (weight matrices,
    # embeddings); 1-D tensors such as biases and norm scales are exempt.
    decay_params = [p for p in param_dict.values() if p.dim() >= 2]
    nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(
        f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
    )
    print(
        f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
    )

    # Use the fused AdamW kernel when the installed PyTorch exposes it and we
    # are running on CUDA; it is faster than the default for-loop implementation.
    fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device_type == "cuda"
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(
        optim_groups, lr=learning_rate, betas=betas, **extra_args
    )
    print(f"using fused AdamW: {use_fused}")

    return optimizer
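

# A minimal usage sketch (illustrative only): the toy model and the
# hyperparameter values below are assumptions for demonstration, not
# values taken from this codebase.
#
#   model = torch.nn.Linear(128, 10)
#   optimizer = get_adamw(
#       model,
#       weight_decay=0.1,
#       learning_rate=3e-4,
#       betas=(0.9, 0.95),
#       device_type="cuda" if torch.cuda.is_available() else "cpu",
#   )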


def get_optimizer(configs, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Build the optimizer selected by the config: AdamW if
    configs.adam.use_adamw is set, otherwise plain Adam."""
    if configs.adam.use_adamw:
        optimizer = get_adamw(
            model=model,
            weight_decay=configs.adam.weight_decay,
            learning_rate=configs.adam.lr,
            betas=(configs.adam.beta1, configs.adam.beta2),
            device_type="cuda" if torch.cuda.is_available() else "cpu",
        )
    else:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=configs.adam.lr,
            weight_decay=configs.adam.weight_decay,
            betas=(configs.adam.beta1, configs.adam.beta2),
        )
    return optimizer
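

# The config object is assumed to expose an `adam` namespace with the fields
# read above. A hypothetical sketch of that shape (illustrative only, not
# this repo's actual config class):
#
#   from types import SimpleNamespace
#   configs = SimpleNamespace(
#       adam=SimpleNamespace(
#           use_adamw=True, lr=3e-4, weight_decay=0.1, beta1=0.9, beta2=0.95
#       )
#   )
#   optimizer = get_optimizer(configs, model)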


def is_loss_nan_check(loss: torch.Tensor) -> bool:
    """Check the validity of the current loss across all ranks.

    Args:
        loss: the loss from the model.

    Returns:
        bool: True if the loss is NaN or Inf on at least one rank,
            False otherwise.
    """

    def is_nan(x):
        return torch.isnan(x).any() or torch.isinf(x).any()

    def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM):
        # Sum the flag across ranks so a NaN/Inf on any single rank is
        # visible to all of them; a no-op outside distributed training.
        if dist.is_initialized():
            dist.all_reduce(tensor, op=op)
        return tensor

    # Encode the local NaN/Inf status as a float flag on the same device as
    # the loss (all_reduce requires a device tensor under NCCL).
    nan_flag = torch.tensor(1.0 if is_nan(loss) else 0.0, device=loss.device)

    all_reduce_tensor(nan_flag)
    return nan_flag.item() > 0.0
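

# A sketch of how this check might be used inside a training step to skip
# updates on bad batches (hypothetical loop, not code from this file):
#
#   loss = compute_loss(model, batch)
#   if is_loss_nan_check(loss):
#       optimizer.zero_grad(set_to_none=True)  # skip the step on all ranks
#   else:
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad(set_to_none=True)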