#!/usr/bin/python
# -*- coding: utf-8 -*-
import os
import sys
import time
import wandb
import importlib
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from utils import AverageMeter, ProgressMeter, calculate_stats
class WrappedModel(nn.Module):
    ## This wrapper exposes the underlying model as `.module`, so the structure matches
    ## DataParallel/DistributedDataParallel and single- and multi-GPU code paths stay identical.
    def __init__(self, model):
        super(WrappedModel, self).__init__()
        self.module = model

    def forward(self, data_a, data_v, labels=None):
        return self.module(data_a, data_v, labels)
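# Illustrative sketch (not part of the original pipeline): with this wrapper, single-
# and multi-GPU runs expose the network the same way, so the trainer can always reach
# the raw model via `.module` (as saveParameters/loadParameters do below). Here
# `world_size` and `gpu` are hypothetical names:
#
#     net = EquiAV_ft(trainfunc_ft, model, **kwargs)
#     net = WrappedModel(net) if world_size == 1 else \
#           nn.parallel.DistributedDataParallel(net, device_ids=[gpu])
#     state = net.module.state_dict()  # identical access pattern either way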
class EquiAV_ft(nn.Module):
    def __init__(self, trainfunc_ft, model, **kwargs):
        super(EquiAV_ft, self).__init__()
        self.model = model
        # dynamically resolve the model and loss classes from their module names
        EquiAV_Model = importlib.import_module('models.' + model).__getattribute__('MainModel')
        self.__M__ = EquiAV_Model(**kwargs)
        LossFunction = importlib.import_module('loss.' + trainfunc_ft).__getattribute__('LossFunction')
        self.__L__ = LossFunction(**kwargs)
    def forward(self, data_a, data_v, labels=None):
        output_ft = self.__M__(data_a, data_v)
        loss_ft = self.__L__(output_ft, labels)
        loss_dict = {'loss': loss_ft}
        outputs = {'output': output_ft}
        return loss_dict, outputs
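# Note: the importlib pattern above resolves classes from config strings; for a
# hypothetical model='some_model' it is equivalent to
#
#     from models.some_model import MainModel as EquiAV_Model
#
# so new models and losses can be added as modules without touching this trainer.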
class ModelTrainer(nn.Module):
def __init__(self, equiAV, gpu, optimizer='adamw', mixedprec=True, freeze_base=False, no_wandb=False, **kwargs):
super(ModelTrainer, self).__init__()
self.__model__ = equiAV
self.gpu = gpu
self.mixedprec = mixedprec
self.freeze_base = freeze_base
self.max_epoch = kwargs['max_epoch']
self.lr = kwargs['lr']
self.head_lr = kwargs['head_lr']
self.scheduler = kwargs['scheduler']
self.warmup_epoch = kwargs['warmup_epoch']
self.ipe = kwargs['iteration_per_epoch']
self.ipe_scale = 1.0
self.model_name = kwargs['model']
        # Candidate MLP-head parameter names: these layers are newly initialized for
        # finetuning (i.e., not pretrained), so they use a larger learning rate below.
mlp_list = ['__M__.mlp_head.0.weight', '__M__.mlp_head.0.bias', '__M__.mlp_head.1.weight', '__M__.mlp_head.1.bias',
'__M__.mlp_head2.0.weight', '__M__.mlp_head2.0.bias', '__M__.mlp_head2.1.weight', '__M__.mlp_head2.1.bias',
'__M__.mlp_head_a.0.weight', '__M__.mlp_head_a.0.bias', '__M__.mlp_head_a.1.weight', '__M__.mlp_head_a.1.bias',
'__M__.mlp_head_v.0.weight', '__M__.mlp_head_v.0.bias', '__M__.mlp_head_v.1.weight', '__M__.mlp_head_v.1.bias',
'__M__.mlp_head_concat.0.weight', '__M__.mlp_head_concat.0.bias', '__M__.mlp_head_concat.1.weight', '__M__.mlp_head_concat.1.bias']
        named_params = list(self.__model__.module.named_parameters())
        mlp_params = [p for n, p in named_params if n in mlp_list]
        base_params = [p for n, p in named_params if n not in mlp_list]
        if self.freeze_base:
if self.gpu == 0: print('Pretrained backbone parameters are frozen.')
for param in base_params:
param.requires_grad = False
trainables = [p for p in self.__model__.parameters() if p.requires_grad]
Optimizer = importlib.import_module('optimizer.'+optimizer).__getattribute__('Optimizer')
self.__optimizer__ = Optimizer([{'params': base_params, 'lr': self.lr}, {'params': mlp_params, 'lr': self.lr * self.head_lr}], gpu=gpu, **kwargs)
base_lr = self.__optimizer__.param_groups[0]['lr']
mlp_lr = self.__optimizer__.param_groups[1]['lr']
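        # For reference, a minimal sketch of the same two-group setup with plain
        # torch.optim.AdamW (hypothetical hyperparameter values; the project's own
        # optimizer wrapper may add weight decay, betas, etc. from **kwargs):
        #
        #     opt = torch.optim.AdamW([
        #         {'params': base_params, 'lr': 1e-5},       # pretrained backbone
        #         {'params': mlp_params,  'lr': 1e-5 * 50},  # new heads: head_lr x larger
        #     ])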
if self.gpu == 0:
            print('Pretrained backbone parameters: {:.3f} million\n'.format(sum(p.numel() for p in base_params) / 1e6))
            print('Total parameters: {:.3f} million'.format(sum(p.numel() for p in self.__model__.parameters()) / 1e6))
            print('Trainable parameters: {:.3f} million'.format(sum(p.numel() for p in trainables) / 1e6))
            print('Newly initialized MLP parameters: {:.3f} million'.format(sum(p.numel() for p in mlp_params) / 1e6))
            print('The newly initialized MLP layers use a {:.3f}x larger lr'.format(self.head_lr))
            print(f'Start: base lr: {base_lr}, mlp lr: {mlp_lr}')
        # for preliminary tests only; formal experiments should use a fixed learning-rate scheduler
Scheduler = importlib.import_module('scheduler.'+self.scheduler).__getattribute__('Scheduler')
self.__scheduler__ = Scheduler(self.__optimizer__, warmup_steps=int(self.warmup_epoch*self.ipe), ref_lr=self.lr, T_max=int(self.ipe_scale*self.max_epoch*self.ipe), **kwargs)
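        # The scheduler is stepped once per iteration (see train_network), so warmup_epoch
        # and max_epoch are converted to step counts via iteration_per_epoch (self.ipe).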
self.scaler = GradScaler(growth_factor=2, backoff_factor=0.5, growth_interval=2000) if self.mixedprec else None
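        # Note on the GradScaler settings: the loss is multiplied by a running scale
        # before backward() so small fp16 gradients do not underflow; the scale is
        # halved (backoff_factor=0.5) whenever infs/NaNs are found and doubled
        # (growth_factor=2) after every 2000 (growth_interval) overflow-free steps.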
# logging
self.no_wandb = no_wandb
self.print_freq = kwargs['print_freq']
self.result_save_path = kwargs['result_save_path']
def train_network(self, loader=None, evalmode=None, epoch=-1):
# Setting for the logging
batch_time = AverageMeter('Time', ':6.2f')
data_time = AverageMeter('Data', ':6.2f')
mem = AverageMeter('Mem (GB)', ':6.1f')
metrics = AverageMeter('Train Loss', ':1.3e') if not evalmode else AverageMeter('Val Loss', ':1.3e')
progress = ProgressMeter(
len(loader),
[batch_time, data_time, mem, metrics],
prefix="Epoch: [{}]".format(epoch))
        # total number of model parameters (used to normalize the weight-magnitude metric below)
        param_num = sum(p.numel() for p in self.__model__.parameters())
        if evalmode:
            self.__model__.eval()
            A_predictions, A_targets, A_loss, syn_A_loss = [], [], [], []
        else:
            self.__model__.train()
data_iter = 0
end = time.time()
for data in loader:
# measure data loading time
data_time.update(time.time() - end)
            data_a, data_v, _, _, _, _, labels = data
            # move inputs to the current GPU
            data_a = data_a.cuda(self.gpu)  # batch x target_length x num_mel_bins
            data_v = data_v.cuda(self.gpu)  # batch x channel x width x height
            labels = labels.cuda(self.gpu)
# ==================== FORWARD PASS ====================
with autocast(enabled=self.mixedprec):
loss_dict, outputs = self.__model__(data_a, data_v, labels=labels)
            if not evalmode:
                _new_lr = self.__scheduler__.step()  # per-iteration LR schedule
                # reset gradients before backward in both branches (the original
                # mixed-precision path never cleared them, so gradients accumulated)
                self.__optimizer__.zero_grad()
                if self.mixedprec:
                    # mixed precision: scale the loss so fp16 gradients do not underflow
                    self.scaler.scale(loss_dict['loss']).backward()
                    self.scaler.step(self.__optimizer__)
                    self.scaler.update()
                else:
                    # single precision
                    loss_dict['loss'].backward()
                    self.__optimizer__.step()
                # logging
                metrics.update(loss_dict['loss'].item(), loader.batch_size)
            else:
                predictions = outputs['output'].to('cpu').detach()
                A_predictions.append(predictions)
                # move targets to CPU as well, so torch.cat / calculate_stats below
                # do not mix CPU and GPU tensors
                A_targets.append(labels.to('cpu').detach())
                A_loss.append(loss_dict['loss'].to('cpu').detach())
                syn_A_loss.append(loss_dict['loss'].reshape(1).to('cpu').detach())
                metrics.update(loss_dict['loss'].item(), loader.batch_size)
            # measure elapsed time and peak memory (bytes -> GB)
            batch_time.update(time.time() - end)
            end = time.time()
            mem.update(torch.cuda.max_memory_allocated() // 1e9)
            if data_iter % self.print_freq == 0:
                # global L2 norm of all weights divided by the parameter count,
                # logged as a training-stability monitor
                param_sum = sum(torch.pow(p.detach(), 2).sum() for p in self.__model__.parameters())
                param_avg = torch.sqrt(param_sum) / param_num
                if self.gpu == 0:
                    if not self.no_wandb and not evalmode:
                        wandb.log({'Train Loss': metrics.val,
                                   'scaler': self.scaler.get_scale() if self.mixedprec else 0,
                                   'base_lr': self.__optimizer__.param_groups[0]['lr'],
                                   'mlp_lr': self.__optimizer__.param_groups[1]['lr'],
                                   'param_avg': param_avg,
                                   })
                    log_info = progress.display(data_iter)
                    prefix = 'Eval: ' if evalmode else 'Train: '
                    with open(os.path.join(self.result_save_path, 'log.txt'), 'a') as f:
                        f.write(prefix + '\t'.join(log_info) + '\n')
data_iter += 1
        if evalmode:
            audio_output = torch.cat(A_predictions)
            target = torch.cat(A_targets)
            loss = np.mean([l.item() for l in A_loss])
            stats = calculate_stats(audio_output, target)
            return stats, loss
        else:
            sys.stdout.write('\n')
            progress.synchronize()
            return metrics.avg
## ===== ===== ===== ===== ===== ===== ===== =====
## Save parameters
## ===== ===== ===== ===== ===== ===== ===== =====
    def saveParameters(self, path):
        torch.save(self.__model__.module.state_dict(), path)
## ===== ===== ===== ===== ===== ===== ===== =====
## Load parameters
## ===== ===== ===== ===== ===== ===== ===== =====
    def loadParameters(self, path):
        self_state = self.__model__.module.state_dict()
        loaded_state = torch.load(path, map_location="cuda:%d" % self.gpu)
        for name, param in loaded_state.items():
            origname = name
            if name not in self_state:
                # remap checkpoint keys to this model's naming: strip the '__M__.' prefix,
                # or (for cav_mae checkpoints) replace the 'module' prefix with '__M__'
                name = origname.replace('__M__.', '') if 'cav_mae' not in self.model_name else origname.replace('module', '__M__')
                if name not in self_state:
                    if self.gpu == 0: print("{} is not in the model.".format(origname))
                    continue
            if self.gpu == 0: print("{} is loaded in the model".format(name))
            if self_state[name].size() != loaded_state[origname].size():
                if self.gpu == 0: print("Wrong parameter size: {}, model: {}, loaded: {}".format(origname, self_state[name].size(), loaded_state[origname].size()))
                continue
            self_state[name].copy_(param)
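# Usage sketch (an assumption for illustration; the actual entry point lives
# elsewhere in the repo and its argument names may differ):
#
#     model = WrappedModel(EquiAV_ft(trainfunc_ft, model_name, **kwargs)).cuda(gpu)
#     trainer = ModelTrainer(model, gpu=gpu, optimizer='adamw', **kwargs)
#     for epoch in range(kwargs['max_epoch']):
#         train_loss = trainer.train_network(loader=train_loader, evalmode=False, epoch=epoch)
#         stats, val_loss = trainer.train_network(loader=val_loader, evalmode=True, epoch=epoch)
#         trainer.saveParameters(os.path.join(result_save_path, f'epoch{epoch}.model'))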