Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import abc | |
| import math | |
| import yaml | |
| import torch | |
| from torch.utils.tensorboard import SummaryWriter | |
| from .option import default_hparas | |
| from utils.util import human_format, Timer | |
| from utils.load_yaml import HpsYaml | |
class BaseSolver:
    '''
    Prototype Solver for all kinds of tasks.

    Arguments
        config - yaml-styled config; read both with ``in`` and attribute
                 access (e.g. ``config.hparas.max_step``)
        paras  - argparse outcome; this class reads ``gpu``, ``name``,
                 ``config``, ``seed``, ``ckpdir``, ``logdir``, ``load``,
                 ``warm_start`` and ``verbose``
        mode   - "train"/"test"

    Subclasses implement load_data(), set_model() and exec(); set_model()
    must provide ``self.model`` and ``self.optimizer``, which backward(),
    load_ckpt() and save_checkpoint() rely on.
    '''

    def __init__(self, config, paras, mode="train"):
        # General Settings
        self.config = config    # load from yaml file
        self.paras = paras      # command line args
        self.mode = mode        # 'train' or 'test'
        # Copy default hyper-parameters onto the instance. GRAD_CLIP and
        # TB_FLUSH_FREQ are read below but never set elsewhere in this class,
        # so presumably they come from here.
        for k, v in default_hparas.items():
            setattr(self, k, v)
        use_cuda = self.paras.gpu and torch.cuda.is_available()
        self.device = torch.device('cuda') if use_cuda else torch.device('cpu')

        # Name experiment: --name wins, then config.exp_name, then the
        # config file name with '.yaml' stripped.
        self.exp_name = paras.name
        if self.exp_name is None:
            if 'exp_name' in self.config:
                self.exp_name = self.config.exp_name
            else:
                # By default, exp is named after config file
                self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
            # NOTE(review): confirm whether the seed suffix is meant to apply to
            # a name taken from config.exp_name too (it does here); an explicit
            # --name never gets it.
            if mode == 'train':
                self.exp_name += '_seed{}'.format(paras.seed)

        # NOTE: only 'train' mode sets ckpdir/logdir/log/timer/step here; the
        # former 'test'-mode setup (paras.outdir, config.src) is disabled, so
        # in 'test' mode methods using those attributes need them set elsewhere.
        if mode == 'train':
            # Filepath setup
            os.makedirs(paras.ckpdir, exist_ok=True)
            self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
            os.makedirs(self.ckpdir, exist_ok=True)

            # Logger settings
            self.logdir = os.path.join(paras.logdir, self.exp_name)
            self.log = SummaryWriter(
                self.logdir, flush_secs=self.TB_FLUSH_FREQ)
            self.timer = Timer()

            # Hyper-parameters
            self.step = 0
            self.valid_step = config.hparas.valid_step
            self.max_step = config.hparas.max_step

            self.verbose('Exp. name : {}'.format(self.exp_name))
            self.verbose('Loading data... large corpus may took a while.')

    def backward(self, loss):
        '''
        Standard backward step with self.timer and debugger.

        Back-propagates ``loss``, clips the gradient norm of ``self.model`` to
        ``self.GRAD_CLIP`` and steps ``self.optimizer`` unless that norm is
        NaN, in which case the update is skipped and an error is logged.
        Time spent is recorded under the 'bw' timer key. Gradients are not
        zeroed here (presumably done by the caller / optimizer wrapper).

        Arguments
            loss - the loss to perform loss.backward()
        Returns
            the total gradient norm reported by clip_grad_norm_
        '''
        self.timer.set()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.GRAD_CLIP)
        if math.isnan(grad_norm):
            self.verbose('Error : grad norm is NaN @ step ' + str(self.step))
        else:
            self.optimizer.step()
        self.timer.cnt('bw')
        return grad_norm

    def load_ckpt(self):
        '''
        Load ckpt if --load option is specified.

        With --warm_start only model weights are restored, skipping any keys
        listed in ``config.model.ignore_layers`` (those keep the model's
        current values). Otherwise the weights are restored and, in 'train'
        mode, the global step and optimizer state too; in 'test' mode the
        model is switched to eval() and the recorded metric (if any) logged.
        '''
        print(self.paras)
        if self.paras.load is None:
            return
        # Keep tensors on the training device; map to CPU for evaluation.
        map_location = self.device if self.mode == 'train' else 'cpu'

        if self.paras.warm_start:
            self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
            ckpt = torch.load(self.paras.load, map_location=map_location)
            model_dict = ckpt['model']
            if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
                ignored = set(self.config.model.ignore_layers)
                model_dict = {k: v for k, v in model_dict.items() if k not in ignored}
                # Fill the ignored keys from the current model so that the
                # strict load_state_dict below still sees a complete dict.
                dummy_dict = self.model.state_dict()
                dummy_dict.update(model_dict)
                model_dict = dummy_dict
            self.model.load_state_dict(model_dict)
            return

        # Load weights
        ckpt = torch.load(self.paras.load, map_location=map_location)
        self.model.load_state_dict(ckpt['model'])

        # Load task-dependent items
        if self.mode == 'train':
            self.step = ckpt['global_step']
            self.optimizer.load_opt_state_dict(ckpt['optimizer'])
            self.verbose('Load ckpt from {}, restarting at step {}'.format(
                self.paras.load, self.step))
        else:
            # The recorded metric is any entry holding a plain float
            # (save_checkpoint stores it as ``metric: score``). The last match
            # wins, and there may be none at all.
            metric, score = None, None
            for k, v in ckpt.items():
                if type(v) is float:
                    metric, score = k, v
            self.model.eval()
            if metric is None:
                self.verbose('Evaluation target = {} (no recorded metric)'.format(
                    self.paras.load))
            else:
                self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
                    self.paras.load, metric, score))

    def verbose(self, msg):
        ''' Verbose function for print information to stdout.

        ``msg`` is a string or a list of strings (one line each); nothing is
        printed unless ``paras.verbose`` is truthy.
        '''
        if self.paras.verbose:
            if isinstance(msg, list):
                for m in msg:
                    print('[INFO]', m.ljust(100))
            else:
                print('[INFO]', msg.ljust(100))

    def progress(self, msg):
        ''' Verbose function for updating progress on stdout (do not include newline) '''
        if self.paras.verbose:
            sys.stdout.write("\033[K")  # Clear line
            print('[{}] {}'.format(human_format(self.step), msg), end='\r')

    def write_log(self, log_name, log_dict):
        '''
        Write log to TensorBoard at the current ``self.step``.
            log_name - <str> Name of tensorboard variable; its content selects
                       the writer: 'align'/'spec' -> image, 'text'/'hyp' -> text,
                       anything else -> scalars
            log_dict - <dict> of scalars (None / NaN entries are dropped),
                       an (image, dataformats) pair for images, or text;
                       nothing is written when it is None or empty
        '''
        if isinstance(log_dict, dict):
            log_dict = {key: val for key, val in log_dict.items()
                        if val is not None and not math.isnan(val)}
        if log_dict is None or len(log_dict) == 0:
            return
        if 'align' in log_name or 'spec' in log_name:
            img, form = log_dict
            self.log.add_image(
                log_name, img, global_step=self.step, dataformats=form)
        elif 'text' in log_name or 'hyp' in log_name:
            self.log.add_text(log_name, log_dict, self.step)
        else:
            self.log.add_scalars(log_name, log_dict, self.step)

    def save_checkpoint(self, f_name, metric, score, show_msg=True):
        '''
        Ckpt saver
            f_name   - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
            metric   - <str> key under which ``score`` is stored in the ckpt
            score    - <float> The value of metric used to evaluate model; store
                       a plain Python float so load_ckpt() can find it in 'test' mode
            show_msg - <bool> report the save through verbose()
        '''
        ckpt_path = os.path.join(self.ckpdir, f_name)
        full_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.get_opt_state_dict(),
            "global_step": self.step,
            metric: score,
        }
        torch.save(full_dict, ckpt_path)
        if show_msg:
            self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
                         format(human_format(self.step), metric, score, ckpt_path))

    # ----------------------------------- Abstract Methods ------------------------------------------ #
    def load_data(self):
        '''
        Called by main to load all data
        After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
        No return value
        '''
        raise NotImplementedError

    def set_model(self):
        '''
        Called by main to set models
        After this call, model related attributes should be setup (e.g. self.l2_loss)
        The followings MUST be setup
            - self.model (torch.nn.Module)
            - self.optimizer (src.Optimizer),
                init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
        Loading pre-trained model should also be performed here
        No return value
        '''
        raise NotImplementedError

    def exec(self):
        '''
        Called by main to execute training/inference
        '''
        raise NotImplementedError