import inspect
from torch.optim import AdamW


class CustomAdamW(AdamW):
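    """AdamW variant that splits parameters into two groups: tensors with two or
    more dimensions receive the given weight decay, while 1-D tensors (biases,
    norm parameters) receive none."""
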
    def __init__(self, params, weight_decay, *args, **kwargs):
        # Accept either a mapping of named parameters or a plain iterable,
        # keeping only parameters that require gradients.
        if isinstance(params, dict):
            params = [p for p in params.values() if p.requires_grad]
        else:
            params = [p for p in params if p.requires_grad]

        # Apply weight decay only to tensors with two or more dimensions
        # (e.g. matmul and embedding weights); 1-D tensors such as biases
        # and norm scales are not decayed.
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

        # Pass the parameter groups positionally so that any positional
        # arguments forwarded via *args (e.g. lr) cannot collide with the
        # `params` argument of AdamW.
        super().__init__(optim_groups, *args, **kwargs)
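

# Minimal usage sketch (assumption: not part of the original module; the model
# and hyperparameter values below are illustrative placeholders only).
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(16, 4)
    optimizer = CustomAdamW(model.parameters(), weight_decay=0.1, lr=3e-4)
    print(optimizer)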