| | import logging |
| |
|
| | log = logging.getLogger() |
| |
|
| |
|
def get_parameter_groups(model, stage_cfg, print_log=False):
    """
    Assign different weight decays and learning rates to different parameters.
    Returns a parameter group which can be passed to the optimizer.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs; each ``param`` must have a
            ``requires_grad`` attribute and be hashable.
        stage_cfg: config object providing ``weight_decay``,
            ``embed_weight_decay``, ``backbone_lr_ratio`` and
            ``learning_rate``.
        print_log: when true, log every parameter that is classified as a
            backbone or embedding parameter.

    Returns:
        A list of three dicts, always in this order (a group's ``params``
        list may be empty):
          1. backbone params  - lr = learning_rate * backbone_lr_ratio,
             weight_decay = weight_decay
          2. embedding params - lr = learning_rate,
             weight_decay = embed_weight_decay
          3. all other params - lr = learning_rate,
             weight_decay = weight_decay
    """
    weight_decay = stage_cfg.weight_decay
    embed_weight_decay = stage_cfg.embed_weight_decay
    backbone_lr_ratio = stage_cfg.backbone_lr_ratio
    base_lr = stage_cfg.learning_rate

    # Prefix stripped from parameter names before classification
    # (presumably added by a DataParallel/DDP-style wrapper -- confirm).
    wrapper_prefix = 'module.'
    backbone_prefix = 'pixel_encoder.'
    # str.endswith accepts a tuple, so one call tests every embedding suffix.
    embedding_suffixes = tuple(
        f'{e}.weight'
        for e in ('summary_pos', 'query_init', 'query_emb', 'obj_pe'))

    backbone_params = []
    embed_params = []
    other_params = []

    # Parameters already assigned to a group; a parameter object that shows
    # up under several names must be handed to the optimizer only once.
    memo = set()
    for name, param in model.named_parameters():
        # Frozen parameters are not optimized.
        if not param.requires_grad:
            continue
        if param in memo:
            continue
        memo.add(param)

        # Match the full 'module.' prefix (dot included). Testing only for
        # 'module' would also truncate unrelated names such as
        # 'module_x.weight' or 'modulator.weight' and could misclassify them.
        if name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]

        if name.startswith(backbone_prefix):
            backbone_params.append(param)
            if print_log:
                log.info('%s counted as a backbone parameter.', name)
        elif name.endswith(embedding_suffixes):
            embed_params.append(param)
            if print_log:
                log.info('%s counted as an embedding parameter.', name)
        else:
            other_params.append(param)

    parameter_groups = [
        {
            'params': backbone_params,
            'lr': base_lr * backbone_lr_ratio,
            'weight_decay': weight_decay,
        },
        {
            'params': embed_params,
            'lr': base_lr,
            'weight_decay': embed_weight_decay,
        },
        {
            'params': other_params,
            'lr': base_lr,
            'weight_decay': weight_decay,
        },
    ]

    return parameter_groups