import math
from pprint import pformat
from typing import Tuple, List, Dict, Union

import torch.nn

import infinity.utils.dist as dist


def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[
    List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
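    """
    Group a model's trainable parameters into optimizer param groups.

    Parameters whose original ndim (looked up in `ndim_dict`) is 1, whose name ends with
    'bias', or whose name contains any key in `nowd_keys` get no weight decay ('wd_sc' = 0);
    all others keep it ('wd_sc' = 1). If the model exposes `get_layer_id_and_scale_exp` and
    0 < lr_scale <= 1, each group also gets a layer-wise learning-rate scale
    'lr_sc' = lr_scale ** scale_exp.

    Returns (names, params, param_groups), where each group dict holds 'params',
    'wd_sc' and 'lr_sc'.
    """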
    with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1
    print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}')
    para_groups, para_groups_dbg = {}, {}
    names, paras = [], []
    names_no_grad = []
    count, numel = 0, 0
    for name, para in model.named_parameters():
        name = name.replace('_fsdp_wrapped_module.', '')
        if not para.requires_grad:
            names_no_grad.append(name)
            continue    # skip frozen parameters
        count += 1
        numel += para.numel()
        names.append(name)
        paras.append(para)

        # no weight decay for 1-D params (norms, etc.), biases, and explicitly listed keys
        if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
            cur_wd_sc, group_name = 0., 'ND'
        else:
            cur_wd_sc, group_name = 1., 'D'

        # optional layer-wise lr decay: lr_sc = lr_scale ** scale_exp
        if with_lr_scale:
            layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
            group_name = f'layer{layer_id}_' + group_name
            cur_lr_sc = lr_scale ** scale_exp
            dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
        else:
            cur_lr_sc = 1.
            dbg = '[no scale]'

        if group_name not in para_groups:
            para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
            para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg}
        para_groups[group_name]['params'].append(para)
        para_groups_dbg[group_name]['params'].append(name)
    # pretty-print the grouping for debugging
    for g in para_groups_dbg.values():
        g['params'] = pformat(', '.join(g['params']), width=200)

    print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')

    # print per-rank stats in rank order; NOTE: `force=True` is not a builtin print kwarg,
    # it presumably relies on the repo patching builtins.print to also print on non-master ranks
    for rk in range(dist.get_world_size()):
        dist.barrier()
        if dist.get_rank() == rk:
            print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
    print('')

    assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
    del ndim_dict
    return names, paras, list(para_groups.values())
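
# Usage sketch (not part of this module; `model`, `ndim_dict`, `base_lr` and `base_wd` are
# illustrative placeholders): a trainer could turn the returned groups into torch.optim
# param groups by scaling a base lr / weight decay with each group's 'lr_sc' / 'wd_sc', e.g.:
#
#   names, paras, para_groups = filter_params(model, ndim_dict, nowd_keys=('norm',), lr_scale=0.9)
#   optimizer = torch.optim.AdamW(
#       [{'params': g['params'], 'lr': base_lr * g['lr_sc'], 'weight_decay': base_wd * g['wd_sc']}
#        for g in para_groups],
#       lr=base_lr, weight_decay=base_wd,
#   )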