| | |
| | |
| | import os |
| | import sys |
| | import time |
| | import wandb |
| | import importlib |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from torch.cuda.amp import autocast, GradScaler |
| | from utils import * |
| |
|
class WrappedModel(nn.Module):
    """Thin container that stores a model under the ``module`` attribute.

    The trainer in this file reaches the real network through
    ``<wrapper>.module``; this class supplies that attribute for a bare model.
    (Presumably this mirrors the ``.module`` attribute of PyTorch's parallel
    wrappers -- confirm against the launch script.)
    """

    def __init__(self, model):
        """Register ``model`` as the sub-module named ``module``."""
        super().__init__()
        self.module = model

    def forward(self, data_a, data_v, labels=None):
        """Forward all three arguments, positionally, to the wrapped model."""
        wrapped = self.module
        return wrapped(data_a, data_v, labels)
| |
|
class EquiAV_ft(nn.Module):
    """Fine-tuning network: a dynamically imported backbone plus its loss.

    Args:
        trainfunc_ft: Name of the module under the ``loss`` package that
            provides a ``LossFunction`` class.
        model: Name of the module under the ``models`` package that provides
            a ``MainModel`` class.
        **kwargs: Passed unchanged to both ``MainModel`` and ``LossFunction``.
    """

    def __init__(self, trainfunc_ft, model, **kwargs):
        super().__init__()

        self.model = model

        # Both classes are looked up by module name at construction time.
        model_module = importlib.import_module('models.' + model)
        EquiAV_Model = getattr(model_module, 'MainModel')
        self.__M__ = EquiAV_Model(**kwargs)

        loss_module = importlib.import_module('loss.' + trainfunc_ft)
        LossFunction = getattr(loss_module, 'LossFunction')
        self.__L__ = LossFunction(**kwargs)

    def forward(self, data_a, data_v, labels=None):
        """Run the backbone on both inputs and evaluate the loss.

        Returns:
            A pair ``(loss_dict, outputs)`` where ``loss_dict`` is
            ``{'loss': <loss value>}`` and ``outputs`` is
            ``{'output': <backbone output>}``.
        """
        # Call the backbone instance rather than ``.forward`` so that any
        # hooks registered on it are executed.
        output_ft = self.__M__(data_a, data_v)
        # The loss object is still invoked through ``.forward``: nothing in
        # this file guarantees it is callable as an nn.Module.
        loss_ft = self.__L__.forward(output_ft, labels)

        return {'loss': loss_ft}, {'output': output_ft}
| |
|
| |
|
class ModelTrainer(nn.Module):
    """Fine-tuning driver that owns the model, optimizer, LR scheduler and AMP scaler.

    Args:
        equiAV: Wrapped network. It must expose the real model as ``.module``
            and return ``(loss_dict, outputs)`` when called with
            ``(data_a, data_v, labels=...)``.
        gpu: CUDA device index used for ``.cuda(gpu)``; the process with
            ``gpu == 0`` is the only one that prints and logs.
        optimizer: Name of the module under the ``optimizer`` package that
            provides an ``Optimizer`` class.
        mixedprec: Enables ``autocast`` and gradient scaling.
        freeze_base: When true, every parameter that is not one of the MLP
            head parameters gets ``requires_grad = False``.
        no_wandb: Disables ``wandb.log`` calls.
        **kwargs: Must contain ``max_epoch``, ``lr``, ``head_lr``,
            ``scheduler``, ``warmup_epoch``, ``iteration_per_epoch``,
            ``model``, ``print_freq`` and ``result_save_path``. The whole
            dict is also forwarded to the optimizer and scheduler.
    """

    def __init__(self, equiAV, gpu, optimizer='adamw', mixedprec=True, freeze_base=False, no_wandb=False, **kwargs):
        super().__init__()

        self.__model__ = equiAV
        self.gpu = gpu
        self.mixedprec = mixedprec
        self.freeze_base = freeze_base

        self.max_epoch = kwargs['max_epoch']

        self.lr = kwargs['lr']
        self.head_lr = kwargs['head_lr']
        self.scheduler = kwargs['scheduler']
        self.warmup_epoch = kwargs['warmup_epoch']
        self.ipe = kwargs['iteration_per_epoch']
        self.ipe_scale = 1.0

        self.model_name = kwargs['model']

        # Names (relative to ``equiAV.module``) of the MLP head parameters:
        # layers 0 and 1 of each head, weight and bias.
        head_names = ('mlp_head', 'mlp_head2', 'mlp_head_a', 'mlp_head_v', 'mlp_head_concat')
        mlp_names = {
            '__M__.{}.{}.{}'.format(head, layer, kind)
            for head in head_names
            for layer in (0, 1)
            for kind in ('weight', 'bias')
        }

        # Split the parameters into head / backbone groups in a single pass,
        # keeping the order produced by named_parameters().
        mlp_params, base_params = [], []
        for name, param in self.__model__.module.named_parameters():
            (mlp_params if name in mlp_names else base_params).append(param)

        if self.freeze_base:
            if self.gpu == 0:
                print('Pretrained backbone parameters are frozen.')
            for param in base_params:
                param.requires_grad = False

        # Collected after the optional freeze so the count printed below is accurate.
        trainables = [p for p in self.__model__.parameters() if p.requires_grad]

        Optimizer = getattr(importlib.import_module('optimizer.' + optimizer), 'Optimizer')
        param_groups = [
            {'params': base_params, 'lr': self.lr},
            # head_lr is a multiplier on the base learning rate.
            {'params': mlp_params, 'lr': self.lr * self.head_lr},
        ]
        self.__optimizer__ = Optimizer(param_groups, gpu=gpu, **kwargs)

        base_lr = self.__optimizer__.param_groups[0]['lr']
        mlp_lr = self.__optimizer__.param_groups[1]['lr']

        if self.gpu == 0:
            print('Total pretrained backbone parameter number is : {:.3f} million\n'.format(sum(p.numel() for p in base_params) / 1e6))
            print('Total parameter number is : {:.3f} million'.format(sum(p.numel() for p in self.__model__.parameters()) / 1e6))
            print('Total trainable parameter number is : {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6))
            print('Total newly initialized MLP parameter number is : {:.3f} million'.format(sum(p.numel() for p in mlp_params) / 1e6))
            print('The newly initialized mlp layer uses {:.3f} x larger lr'.format(self.head_lr))
            print(f'Start : base lr: {base_lr}, mlp lr: {mlp_lr}')

        Scheduler = getattr(importlib.import_module('scheduler.' + self.scheduler), 'Scheduler')
        self.__scheduler__ = Scheduler(
            self.__optimizer__,
            warmup_steps=int(self.warmup_epoch * self.ipe),
            ref_lr=self.lr,
            T_max=int(self.ipe_scale * self.max_epoch * self.ipe),
            **kwargs)

        self.scaler = GradScaler(growth_factor=2, backoff_factor=0.5, growth_interval=2000) if self.mixedprec else None

        self.no_wandb = no_wandb
        self.print_freq = kwargs['print_freq']
        self.result_save_path = kwargs['result_save_path']

    def train_network(self, loader=None, evalmode=None, epoch=-1):
        """Run one pass over ``loader``, either training or evaluating.

        Each batch must unpack into seven items; the first two are moved to
        the GPU as ``data_a`` / ``data_v`` (presumably audio and visual
        inputs -- confirm against the dataset) and the last one is the labels.

        Args:
            loader: Iterable of batches with ``len()`` and ``batch_size``.
            evalmode: Truthy to evaluate (no optimizer step), falsy to train.
            epoch: Epoch number shown in the progress prefix.

        Returns:
            In eval mode ``(stats, loss)`` where ``stats`` comes from
            ``calculate_stats`` and ``loss`` is the mean batch loss.
            In training mode the running average held by the loss meter.
        """
        batch_time = AverageMeter('Time', ':6.2f')
        data_time = AverageMeter('Data', ':6.2f')
        mem = AverageMeter('Mem (GB)', ':6.1f')
        metrics = AverageMeter('Val Loss' if evalmode else 'Train Loss', ':1.3e')

        progress = ProgressMeter(
            len(loader),
            [batch_time, data_time, mem, metrics],
            prefix="Epoch: [{}]".format(epoch))

        param_num = sum(p.numel() for p in self.__model__.parameters())

        if evalmode:
            self.__model__.eval()
            A_predictions, A_targets, A_loss = [], [], []
        else:
            self.__model__.train()

        log_path = os.path.join(self.result_save_path, 'log.txt')
        log_tag = 'Eval: ' if evalmode else 'Train: '

        end = time.time()

        for data_iter, data in enumerate(loader):
            data_time.update(time.time() - end)

            data_a, data_v, _, _, _, _, labels = data

            data_a = data_a.cuda(self.gpu)
            data_v = data_v.cuda(self.gpu)
            labels = labels.cuda(self.gpu)

            # Autograd is only needed when an optimizer step follows.
            with torch.set_grad_enabled(not evalmode), autocast(enabled=self.mixedprec):
                loss_dict, outputs = self.__model__(data_a, data_v, labels=labels)
            loss = loss_dict['loss']

            if not evalmode:
                # The scheduler is stepped once per iteration, before the optimizer.
                self.__scheduler__.step()

                if self.mixedprec:
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.__optimizer__)
                    self.scaler.update()
                else:
                    loss.backward()
                    self.__optimizer__.step()

                self.zero_grad()
            else:
                A_predictions.append(outputs['output'].to('cpu').detach())
                # NOTE(review): labels stay on the GPU while predictions are on
                # the CPU -- confirm calculate_stats accepts mixed devices.
                A_targets.append(labels)
                A_loss.append(loss.to('cpu').detach())

            # Detach so the meter's running sum does not chain autograd nodes
            # from every iteration of the epoch.
            metrics.update(loss.detach(), loader.batch_size)

            batch_time.update(time.time() - end)
            end = time.time()
            # True division: the meter prints one decimal place.
            mem.update(torch.cuda.max_memory_allocated() / 1e9)

            # Only the gpu == 0 process logs, like every other print in this class.
            if data_iter % self.print_freq == 0 and self.gpu == 0:
                if not self.no_wandb and not evalmode:
                    # L2 norm of all parameters divided by the parameter count.
                    sq_sum = sum(torch.pow(p.detach(), 2).sum() for p in self.__model__.parameters())
                    param_avg = torch.sqrt(sq_sum) / param_num
                    wandb.log({"Train Loss": metrics.val,
                               'scaler': self.scaler.get_scale() if self.mixedprec else 0,
                               'base_lr': self.__optimizer__.param_groups[0]['lr'],
                               'mlp_lr': self.__optimizer__.param_groups[1]['lr'],
                               'param_avg': param_avg,
                               })

                log_info = progress.display(data_iter)

                with open(log_path, 'a') as f:
                    f.write(log_tag + '\t'.join(log_info) + '\n')

        if evalmode:
            audio_output = torch.cat(A_predictions)
            target = torch.cat(A_targets)
            loss = np.mean(A_loss)

            stats = calculate_stats(audio_output, target)

            return stats, loss

        sys.stdout.write("\n")
        progress.synchronize()

        return metrics.avg

    def saveParameters(self, path):
        """Save the state dict of the underlying model (``.module``) to ``path``."""
        torch.save(self.__model__.module.state_dict(), path)

    def loadParameters(self, path):
        """Copy matching tensors from the checkpoint at ``path`` into the model.

        A checkpoint key that is not found in the model is retried under an
        alternative name: with ``'__M__.'`` removed, or -- when the configured
        model name contains ``'cav_mae'`` -- with ``'module'`` replaced by
        ``'__M__'``. Keys that still do not match, and tensors whose shape
        differs from the model's, are reported (on gpu 0) and skipped.
        """
        self_state = self.__model__.module.state_dict()
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        loaded_state = torch.load(path, map_location="cuda:%d" % self.gpu)
        is_main = self.gpu == 0

        for origname, param in loaded_state.items():
            name = origname
            if name not in self_state:
                if 'cav_mae' not in self.model_name:
                    name = origname.replace('__M__.', '')
                else:
                    name = origname.replace('module', '__M__')
                if name not in self_state:
                    if is_main:
                        print("{} is not in the model.".format(origname))
                    continue

            if self_state[name].size() != param.size():
                if is_main:
                    print("Wrong parameter length: {}, model: {}, loaded: {}".format(origname, self_state[name].size(), param.size()))
                continue

            # state_dict() tensors share storage with the live parameters,
            # so the in-place copy updates the model.
            self_state[name].copy_(param)
            # Reported only after the copy, so skipped tensors are never
            # announced as loaded.
            if is_main:
                print("{} is loaded in the model".format(name))