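"""Base Lightning module that wires MMEngine-style configs into
``lightning.pytorch``: model and metric building through registries,
optimizer and param-scheduler construction, and epoch-end metric logging."""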

import copy

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from mmengine.model import BaseModel
from mmengine.optim import _ParamScheduler
from mmengine.registry import OPTIMIZERS, PARAM_SCHEDULERS
from torchmetrics import MetricCollection

from mmpl.registry import MODELS, METRICS


class BasePLer(pl.LightningModule, BaseModel):
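    """Base ``LightningModule`` built on top of MMEngine components.

    The ``hyperparameters`` dict is expected to carry an ``optimizer`` entry
    and, optionally, ``evaluator`` and ``param_scheduler`` entries. Metric
    configs are built through the METRICS registry and stored as
    ``torchmetrics.MetricCollection`` attributes (e.g. ``val_evaluator``).
    """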
    def __init__(
            self,
            hyperparameters,
            data_preprocessor=None,
            train_cfg=None,
            test_cfg=None,
            *args,
            **kwargs
    ):
        super().__init__()
        self.hyperparameters = hyperparameters
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if data_preprocessor is not None:
            if isinstance(data_preprocessor, nn.Module):
                self.data_preprocessor = data_preprocessor
            elif isinstance(data_preprocessor, dict):
                self.data_preprocessor = MODELS.build(data_preprocessor)
            else:
                raise TypeError('data_preprocessor should be a `dict` or '
                                '`nn.Module` instance, but got '
                                f'{type(data_preprocessor)}')
        # Build metrics from the evaluator config and attach them as
        # attributes, e.g. ``self.val_evaluator``.
        evaluator_cfg = copy.deepcopy(self.hyperparameters.get('evaluator', None))
        if evaluator_cfg is not None:
            for k, v in evaluator_cfg.items():
                metrics = []
                if isinstance(v, dict):
                    v = [v]
                if isinstance(v, list):
                    for metric_cfg in v:
                        metric = METRICS.build(metric_cfg)
                        metrics.append(metric)
                else:
                    raise TypeError('evaluator entry should be a `dict` or '
                                    f'`list` instance, but got {type(v)}')
                setattr(self, k, MetricCollection(metrics, prefix=k.split('_')[0]))
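
    # A minimal sketch of what the ``evaluator`` entry of ``hyperparameters``
    # could look like (the metric type below is a hypothetical example, not
    # part of this module):
    #
    #   evaluator = dict(
    #       val_evaluator=[dict(type='MeanAveragePrecision')],
    #       test_evaluator=dict(type='MeanAveragePrecision'),
    #   )
    #
    # Each value may be a single metric config dict or a list of them; every
    # config is built through the METRICS registry and wrapped in a
    # ``torchmetrics.MetricCollection``.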

    def _set_grad(self, need_train_names: list = [], noneed_train_names: list = []):
        for name, param in self.named_parameters():
            flag = False
            for need_train_name in need_train_names:
                if need_train_name in name:
                    flag = True
            for noneed_train_name in noneed_train_names:
                if noneed_train_name in name:
                    flag = False
            param.requires_grad_(flag)

        not_specific_names = []
        for name, param in self.named_parameters():
            flag_find = False
            for specific_name in need_train_names + noneed_train_names:
                if specific_name in name:
                    flag_find = True
            if not flag_find:
                not_specific_names.append(name)

        if self.local_rank == 0:
            not_specific_names = [x.split('.')[0] for x in not_specific_names]
            not_specific_names = set(not_specific_names)
            print(f"Turning off gradients for names: {noneed_train_names}")
            print(f"Turning on gradients for names: {need_train_names}")
            print(f"Turning off gradients for not specific names: {not_specific_names}")

    def _set_train_module(self, mode=True, need_train_names: list = []):
        self.training = mode
        for name, module in self.named_children():
            flag = False
            for need_train_name in need_train_names:
                if need_train_name in name:
                    flag = True
            if flag:
                module.train(mode)
            else:
                module.eval()
        return self
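
    # Companion sketch to ``_set_grad`` (hypothetical name): keep matching
    # child modules in train mode and put everything else in eval mode:
    #
    #   self._set_train_module(mode=True, need_train_names=['prompt_head'])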

    def configure_optimizers(self):
        optimizer_cfg = copy.deepcopy(self.hyperparameters.get('optimizer'))
        base_lr = optimizer_cfg.get('lr')
        base_wd = optimizer_cfg.get('weight_decay', None)

        sub_models = optimizer_cfg.pop('sub_model', None)
        if sub_models is None:
            optimizer_cfg['params'] = filter(lambda p: p.requires_grad, self.parameters())
        else:
            if isinstance(sub_models, str):
                sub_models = {sub_models: {}}
            if isinstance(sub_models, list):
                sub_models = {x: {} for x in sub_models}
            assert isinstance(sub_models, dict), \
                f'sub_models should be a dict, but got {type(sub_models)}'

            # Build one parameter group per sub model, scaling lr and weight
            # decay by the per-group multipliers.
            for sub_model_name, value in sub_models.items():
                sub_attrs = sub_model_name.split('.')
                sub_model_ = self
                for sub_attr in sub_attrs:
                    if not hasattr(sub_model_, sub_attr):
                        raise ModuleNotFoundError(f'{sub_model_name} not in model')
                    sub_model_ = getattr(sub_model_, sub_attr)
                if isinstance(sub_model_, torch.nn.Parameter):
                    sub_models[sub_model_name]['params'] = filter(
                        lambda p: p.requires_grad, [sub_model_])
                else:
                    sub_models[sub_model_name]['params'] = filter(
                        lambda p: p.requires_grad, sub_model_.parameters())

                lr_mult = value.pop('lr_mult', 1.)
                sub_models[sub_model_name]['lr'] = base_lr * lr_mult
                if base_wd is not None:
                    decay_mult = value.pop('decay_mult', 1.)
                    sub_models[sub_model_name]['weight_decay'] = base_wd * decay_mult

            if self.local_rank == 0:
                print('All sub models:')
                for name, module in self.named_children():
                    print(name, end=', ')
                print()
                print('Needed train models:')
                for name, value in sub_models.items():
                    print(f'{name}', end=', ')
                print()

            optimizer_cfg['params'] = list(sub_models.values())

        optimizer = OPTIMIZERS.build(optimizer_cfg)
        if self.local_rank == 0:
            print('Optimizer parameter groups:')
            for param_group in optimizer.param_groups:
                print([value.shape for value in param_group['params']],
                      'learning rate: ', param_group['lr'])

        schedulers = copy.deepcopy(self.hyperparameters.get('param_scheduler', None))
        if schedulers is None:
            return [optimizer]

        param_schedulers = []
        total_step = self.trainer.estimated_stepping_batches
        for scheduler in schedulers:
            if isinstance(scheduler, _ParamScheduler):
                param_schedulers.append(scheduler)
            elif isinstance(scheduler, dict):
                _scheduler = copy.deepcopy(scheduler)
                param_schedulers.append(
                    PARAM_SCHEDULERS.build(
                        _scheduler,
                        default_args=dict(
                            optimizer=optimizer,
                            epoch_length=self.trainer.num_training_batches,
                        )
                    )
                )
            else:
                raise TypeError(
                    'scheduler should be a _ParamScheduler object or dict, '
                    f'but got {scheduler}')

        return [optimizer], param_schedulers
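
    # A minimal sketch of matching ``hyperparameters`` entries (types and
    # sub-module names below are hypothetical, not defined by this class):
    #
    #   optimizer = dict(
    #       type='AdamW', lr=1e-4, weight_decay=1e-3,
    #       sub_model={'backbone': dict(lr_mult=0.1),
    #                  'head': dict(lr_mult=1.0)},
    #   )
    #   param_scheduler = [dict(type='CosineAnnealingLR', T_max=100)]
    #
    # ``sub_model`` may also be a single name or a list of names; each entry
    # becomes its own optimizer parameter group with scaled lr/weight decay.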

    def lr_scheduler_step(self, scheduler, metric):
        # Intentionally a no-op: Lightning is not used to step the mmengine
        # param schedulers returned by ``configure_optimizers``.
        pass

    def log_grad(self, module=None) -> None:
        # Compute the 2-norm of the gradients for each layer.
        # If using mixed precision, the gradients are already unscaled here.
        if module is None:
            module = self
        norms = grad_norm(module, norm_type=2)
        max_grad = max(norms.values())
        min_grad = min(norms.values())
        self.log_dict(
            {'max_grad': max_grad, 'min_grad': min_grad},
            prog_bar=True,
            logger=True
        )

    def setup(self, stage: str) -> None:
        # Propagate dataset meta information to any metric that expects it.
        evaluators = ['train', 'val', 'test']
        for evaluator in evaluators:
            if hasattr(self, f'{evaluator}_evaluator'):
                if hasattr(self.trainer.datamodule, f'{evaluator}_dataset'):
                    dataset = getattr(self.trainer.datamodule, f'{evaluator}_dataset')
                    if hasattr(dataset, 'metainfo'):
                        evaluator_ = getattr(self, f'{evaluator}_evaluator')
                        for v in evaluator_.values():
                            if hasattr(v, 'dataset_meta'):
                                v.dataset_meta = dataset.metainfo

    def on_before_optimizer_step(self, optimizer) -> None:
        self.log_grad()

    def on_validation_epoch_end(self) -> None:
        self._log_eval_metrics('val')

    def on_test_epoch_end(self) -> None:
        self._log_eval_metrics('test')

    def on_train_epoch_end(self) -> None:
        self._log_eval_metrics('train')

    def _log_eval_metrics(self, stage):
        assert stage in ['train', 'val', 'test']
        if hasattr(self, f'{stage}_evaluator'):
            evaluator = getattr(self, f'{stage}_evaluator')
            metrics = evaluator.compute()
            metrics = {k.lower(): v for k, v in metrics.items()}

            keys = []
            for k, v in metrics.items():
                v = v.view(-1)
                for i, data in enumerate(v):
                    keys.append(f'{k}_{i}')
                    self.log(f'{k}_{i}', data, on_step=False, on_epoch=True,
                             prog_bar=True, logger=True, sync_dist=True)
            evaluator.reset()

            if hasattr(self.trainer, 'checkpoint_callback'):
                monitor = self.trainer.checkpoint_callback.monitor
                if (monitor is not None) and (monitor not in keys):
                    data = torch.tensor(0., device=self.device)
                    self.log(f'{monitor}', data, on_step=False, on_epoch=True,
                             prog_bar=True, logger=True, sync_dist=True)
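

# Rough usage sketch (assumes a concrete subclass that implements the
# training/validation steps; class, metric, and datamodule names below are
# illustrative only):
#
#   pler = MySegPLer(hyperparameters=dict(
#       optimizer=dict(type='AdamW', lr=1e-4),
#       param_scheduler=[dict(type='CosineAnnealingLR', T_max=100)],
#       evaluator=dict(val_evaluator=dict(type='MeanAveragePrecision')),
#   ))
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(pler, datamodule=my_datamodule)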