| from basicsr.utils import get_root_logger |
| from basicsr.utils.registry import MODEL_REGISTRY |
| from .video_base_model import VideoBaseModel |
|
|
|
|
@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
    """EDVR Model.

    Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.  # noqa: E501

    Extra training options read from ``opt['train']`` by this class:

    - ``tsa_iter`` (optional): if truthy, only parameters whose name contains
      ``'fusion'`` (the TSA module, per the log message) are trained while
      ``current_iter < tsa_iter``; from iteration ``tsa_iter`` onwards all
      parameters are trained.
    - ``dcn_lr_mul`` (default 1): learning-rate multiplier for parameters
      whose name contains ``'dcn'``.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # Iteration at which the TSA-only warm-up ends; None/0 disables it.
        # Always defined (None outside training) so the attribute never goes missing.
        self.train_tsa_iter = opt['train'].get('tsa_iter') if self.is_train else None
        # True while the non-fusion parameters are frozen for the TSA warm-up.
        self._tsa_only = False

    def setup_optimizers(self):
        """Build the generator optimizer and append it to ``self.optimizers``.

        If ``dcn_lr_mul`` is 1, all ``net_g`` parameters share one param group.
        Otherwise, parameters whose name contains ``'dcn'`` get their own group
        with learning rate ``optim_g['lr'] * dcn_lr_mul``, and all remaining
        parameters use ``optim_g['lr']``.

        Raises:
            KeyError: if ``optim_g`` lacks ``'type'`` (or ``'lr'`` when
                ``dcn_lr_mul != 1``).
        """
        train_opt = self.opt['train']
        # Copy so that popping 'type' below does not mutate the shared config;
        # mutating it made any second call fail with KeyError('type').
        optim_opt = dict(train_opt['optim_g'])
        dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
        logger = get_root_logger()
        logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.')
        if dcn_lr_mul == 1:
            optim_params = self.net_g.parameters()
        else:
            # Split dcn params from the rest so they can use a scaled lr.
            normal_params = []
            dcn_params = []
            for name, param in self.net_g.named_parameters():
                if 'dcn' in name:
                    dcn_params.append(param)
                else:
                    normal_params.append(param)
            base_lr = optim_opt['lr']
            optim_params = [
                {'params': normal_params, 'lr': base_lr},  # normal params first
                {'params': dcn_params, 'lr': base_lr * dcn_lr_mul},
            ]

        optim_type = optim_opt.pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **optim_opt)
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        """Apply the optional TSA-only warm-up, then delegate to the base step.

        Args:
            current_iter (int): index of the current training iteration;
                presumably 1-based (the original code froze at iteration 1) —
                TODO confirm against the training loop.
        """
        if self.train_tsa_iter:
            # Decide from the iteration number rather than exact equality, so the
            # warm-up is also applied after a resume inside the warm-up window and
            # tsa_iter == 1 no longer leaves the network frozen forever.
            in_warmup = current_iter < self.train_tsa_iter
            if in_warmup and not self._tsa_only:
                logger = get_root_logger()
                logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
                for name, param in self.net_g.named_parameters():
                    if 'fusion' not in name:
                        param.requires_grad = False
                self._tsa_only = True
            elif not in_warmup and self._tsa_only:
                logger = get_root_logger()
                logger.warning('Train all the parameters.')
                for param in self.net_g.parameters():
                    param.requires_grad = True
                self._tsa_only = False

        super().optimize_parameters(current_iter)
|
|