# gpu_symbol/engine/core/yaml_config.py
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import re
import copy
from ._config import BaseConfig
from .workspace import create
from .yaml_utils import load_config, merge_config, merge_dict
class YAMLConfig(BaseConfig):
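    """Builds training components lazily from a YAML config file.

    Values from the YAML file can be overridden via keyword arguments; the
    model, optimizer, dataloaders, etc. are only instantiated on first access.
    """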
def __init__(self, cfg_path: str, **kwargs) -> None:
super().__init__()
cfg = load_config(cfg_path)
cfg = merge_dict(cfg, kwargs)
self.yaml_cfg = copy.deepcopy(cfg)
for k in super().__dict__:
if not k.startswith('_') and k in cfg:
self.__dict__[k] = cfg[k]
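
    # Illustrative usage (a sketch; the config path and the `lr` key are
    # assumptions, not part of this file). Keyword arguments are merged on top
    # of the YAML contents by `merge_dict`, so
    #   cfg = YAMLConfig('configs/model.yml', lr=1e-4, use_amp=True)
    # would override `lr` and `use_amp` from the YAML file before any
    # component is built.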
@property
def global_cfg(self, ):
return merge_config(self.yaml_cfg, inplace=False, overwrite=False)
@property
def model(self, ) -> torch.nn.Module:
if self._model is None and 'model' in self.yaml_cfg:
self._model = create(self.yaml_cfg['model'], self.global_cfg)
return super().model
@property
def postprocessor(self, ) -> torch.nn.Module:
if self._postprocessor is None and 'postprocessor' in self.yaml_cfg:
self._postprocessor = create(self.yaml_cfg['postprocessor'], self.global_cfg)
return super().postprocessor
@property
def criterion(self, ) -> torch.nn.Module:
if self._criterion is None and 'criterion' in self.yaml_cfg:
self._criterion = create(self.yaml_cfg['criterion'], self.global_cfg)
return super().criterion
@property
def optimizer(self, ) -> optim.Optimizer:
if self._optimizer is None and 'optimizer' in self.yaml_cfg:
params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)
self._optimizer = create('optimizer', self.global_cfg, params=params)
return super().optimizer
@property
def lr_scheduler(self, ) -> optim.lr_scheduler.LRScheduler:
if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
self._lr_scheduler = create('lr_scheduler', self.global_cfg, optimizer=self.optimizer)
print(f'Initial lr: {self._lr_scheduler.get_last_lr()}')
return super().lr_scheduler
@property
def lr_warmup_scheduler(self, ) -> optim.lr_scheduler.LRScheduler:
        if self._lr_warmup_scheduler is None and 'lr_warmup_scheduler' in self.yaml_cfg:
self._lr_warmup_scheduler = create('lr_warmup_scheduler', self.global_cfg, lr_scheduler=self.lr_scheduler)
return super().lr_warmup_scheduler
@property
def train_dataloader(self, ) -> DataLoader:
if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
self._train_dataloader = self.build_dataloader('train_dataloader')
return super().train_dataloader
@property
def val_dataloader(self, ) -> DataLoader:
if self._val_dataloader is None and 'val_dataloader' in self.yaml_cfg:
self._val_dataloader = self.build_dataloader('val_dataloader')
return super().val_dataloader
@property
def ema(self, ) -> torch.nn.Module:
if self._ema is None and self.yaml_cfg.get('use_ema', False):
self._ema = create('ema', self.global_cfg, model=self.model)
return super().ema
@property
def scaler(self, ):
if self._scaler is None and self.yaml_cfg.get('use_amp', False):
self._scaler = create('scaler', self.global_cfg)
return super().scaler
@property
def evaluator(self, ):
if self._evaluator is None and 'evaluator' in self.yaml_cfg:
if self.yaml_cfg['evaluator']['type'] == 'CocoEvaluator':
from ..data import get_coco_api_from_dataset
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
self._evaluator = create('evaluator', self.global_cfg, coco_gt=base_ds)
else:
raise NotImplementedError(f"{self.yaml_cfg['evaluator']['type']}")
return super().evaluator
@staticmethod
def get_optim_params(cfg: dict, model: nn.Module):
"""
E.g.:
^(?=.*a)(?=.*b).*$ means including a and b
^(?=.*(?:a|b)).*$ means including a or b
^(?=.*a)(?!.*b).*$ means including a, but not b
"""
        assert 'type' in cfg, 'optimizer config must specify a `type`'
cfg = copy.deepcopy(cfg)
if 'params' not in cfg:
return model.parameters()
        assert isinstance(cfg['params'], list), '`params` must be a list of param-group dicts'
param_groups = []
visited = []
for pg in cfg['params']:
pattern = pg['params']
params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
pg['params'] = params.values()
param_groups.append(pg)
visited.extend(list(params.keys()))
# print(params.keys())
names = [k for k, v in model.named_parameters() if v.requires_grad]
if len(visited) < len(names):
unseen = set(names) - set(visited)
params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
param_groups.append({'params': params.values()})
visited.extend(list(params.keys()))
# print(params.keys())
        assert len(visited) == len(names), 'some trainable parameters were not assigned to a param group'
return param_groups
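
    # Illustrative optimizer config consumed by `get_optim_params` (a sketch;
    # the concrete patterns and values are assumptions, not taken from any
    # particular config file):
    #
    #   optimizer:
    #     type: AdamW
    #     lr: 0.0001
    #     params:
    #       - params: '^(?=.*backbone).*$'   # names containing "backbone"
    #         lr: 0.00001
    #       - params: '^(?=.*bias).*$'       # bias parameters
    #         weight_decay: 0.
    #
    # Parameters not matched by any pattern fall into a final default group.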
@staticmethod
def get_rank_batch_size(cfg):
"""compute batch size for per rank if total_batch_size is provided.
"""
assert ('total_batch_size' in cfg or 'batch_size' in cfg) \
and not ('total_batch_size' in cfg and 'batch_size' in cfg), \
            'specify exactly one of `batch_size` or `total_batch_size`'
total_batch_size = cfg.get('total_batch_size', None)
if total_batch_size is None:
bs = cfg.get('batch_size')
else:
from ..misc import dist_utils
assert total_batch_size % dist_utils.get_world_size() == 0, \
'total_batch_size should be divisible by world size'
bs = total_batch_size // dist_utils.get_world_size()
return bs
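
    # Example (illustrative numbers): with `total_batch_size: 32` and a world
    # size of 4 GPUs, each rank gets batch_size = 32 // 4 = 8; with
    # `batch_size: 8` the value is used per rank as given.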
def build_dataloader(self, name: str):
bs = self.get_rank_batch_size(self.yaml_cfg[name])
global_cfg = self.global_cfg
if 'total_batch_size' in global_cfg[name]:
# pop unexpected key for dataloader init
_ = global_cfg[name].pop('total_batch_size')
print(f'building {name} with batch_size={bs}...')
loader = create(name, global_cfg, batch_size=bs)
loader.shuffle = self.yaml_cfg[name].get('shuffle', False)
return loader
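

# Illustrative end-to-end usage (a sketch; the config path is an assumption):
#
#   cfg = YAMLConfig('configs/model.yml', use_amp=True, use_ema=True)
#   model = cfg.model                      # built lazily from the 'model' entry
#   optimizer = cfg.optimizer              # param groups resolved via get_optim_params
#   train_loader = cfg.train_dataloader    # per-rank batch size via get_rank_batch_size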