BryanW's picture
Upload folder using huggingface_hub
3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import math
from pprint import pformat
from typing import Tuple, List, Dict, Union
import torch.nn
import infinity.utils.dist as dist
def filter_params(model, ndim_dict, nowd_keys=(), lr_scale=0.0) -> Tuple[
List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
with_lr_scale = hasattr(model, 'get_layer_id_and_scale_exp') and 0 < lr_scale <= 1
print(f'[get_param_groups][lr decay] with_lr_scale={with_lr_scale}, lr_scale={lr_scale}')
para_groups, para_groups_dbg = {}, {}
names, paras = [], []
names_no_grad = []
count, numel = 0, 0
for name, para in model.named_parameters():
name = name.replace('_fsdp_wrapped_module.', '')
if not para.requires_grad:
names_no_grad.append(name)
continue # frozen weights
count += 1
numel += para.numel()
names.append(name)
paras.append(para)
if ndim_dict.get(name, 2) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
cur_wd_sc, group_name = 0., 'ND'
# elif any(k in name for k in small_wd_keys):
# cur_wd_sc, group_name = small_wd, 'small_decay'
else:
cur_wd_sc, group_name = 1., 'D'
if with_lr_scale:
layer_id, scale_exp = model.get_layer_id_and_scale_exp(name)
group_name = f'layer{layer_id}_' + group_name
cur_lr_sc = lr_scale ** scale_exp
dbg = f'[layer {layer_id}][sc = {lr_scale} ** {scale_exp}]'
else:
cur_lr_sc = 1.
dbg = f'[no scale]'
if group_name not in para_groups:
para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': dbg}
para_groups[group_name]['params'].append(para)
para_groups_dbg[group_name]['params'].append(name)
for g in para_groups_dbg.values():
g['params'] = pformat(', '.join(g['params']), width=200)
print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
for rk in range(dist.get_world_size()):
dist.barrier()
if dist.get_rank() == rk:
print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
print('')
assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
del ndim_dict
return names, paras, list(para_groups.values())