import math
import os

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.modules.optimizers import *
from src.modules.embeddings import *
from src.modules.schedulers import *
from src.modules.tokenizers import *
from src.modules.metrics import *
from src.modules.losses import *
from src.utils.misc import *
from src.utils.logger import Logger
from src.utils.mapper import configmapper
from src.utils.configuration import Config


@configmapper.map("trainers", "base")
class BaseTrainer:
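    """Config-driven training and validation loop.

    Optimizers, schedulers, criteria, and metrics are resolved by name via
    ``configmapper``, so the same loop can train any registered model.
    """
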
    def __init__(self, config):
        self._config = config
        self.metrics = {
            configmapper.get_object("metrics", metric["type"]): metric["params"]
            for metric in self._config.main_config.metrics
        }
        self.train_config = self._config.train
        self.val_config = self._config.val
        self.log_label = self.train_config.log.log_label
        # Log and validate together only when a combined interval is configured;
        # default to False so the attribute always exists when train() checks it.
        self.val_log_together = self.train_config.log_and_val_interval is not None
        print("Logging with label: ", self.log_label)

    def train(self, model, train_dataset, val_dataset=None, logger=None):
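        """Train ``model`` on ``train_dataset``, validating and checkpointing
        on the configured intervals.

        Note: the end-of-training block calls ``self.val`` unconditionally, so
        a ``val_dataset`` is effectively required for a full run.
        """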
        device = torch.device(self._config.main_config.device.name)
        model.to(device)

        # Build the optimizer from config, forwarding params only when given.
        optim_params = self.train_config.optimizer.params
        if optim_params:
            optimizer = configmapper.get_object(
                "optimizers", self.train_config.optimizer.type
            )(model.parameters(), **optim_params.as_dict())
        else:
            optimizer = configmapper.get_object(
                "optimizers", self.train_config.optimizer.type
            )(model.parameters())

        if self.train_config.scheduler is not None:
            scheduler_params = self.train_config.scheduler.params
            if scheduler_params:
                scheduler = configmapper.get_object(
                    "schedulers", self.train_config.scheduler.type
                )(optimizer, **scheduler_params.as_dict())
            else:
                scheduler = configmapper.get_object(
                    "schedulers", self.train_config.scheduler.type
                )(optimizer)

        criterion_params = self.train_config.criterion.params
        if criterion_params:
            criterion = configmapper.get_object(
                "losses", self.train_config.criterion.type
            )(**criterion_params.as_dict())
        else:
            criterion = configmapper.get_object(
                "losses", self.train_config.criterion.type
            )()

        # Use the dataset's own collate function if it defines one.
        if hasattr(train_dataset, "custom_collate_fn"):
            train_loader = DataLoader(
                dataset=train_dataset,
                collate_fn=train_dataset.custom_collate_fn,
                **self.train_config.loader_params.as_dict(),
            )
        else:
            train_loader = DataLoader(
                dataset=train_dataset, **self.train_config.loader_params.as_dict()
            )

        max_epochs = self.train_config.max_epochs
        batch_size = self.train_config.loader_params.batch_size

        if self.val_log_together:
            val_interval = self.train_config.log_and_val_interval
            log_interval = val_interval
        else:
            val_interval = self.train_config.val_interval
            log_interval = self.train_config.log.log_interval

        if logger is None:
            train_logger = Logger(**self.train_config.log.logger_params.as_dict())
        else:
            train_logger = logger

        train_log_values = self.train_config.log.values.as_dict()

        best_score = (
            -math.inf if self.train_config.save_on.desired == "max" else math.inf
        )
        save_on_score = self.train_config.save_on.score
        best_step = -1
        best_model = None

        best_hparam_list = None
        best_hparam_name_list = None
        best_metrics_list = None
        best_metrics_name_list = None

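        # Main loop: one tqdm bar per epoch; validation, checkpointing, and
        # logging are triggered inside the batch loop by global_step.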
        global_step = 0
        for epoch in range(1, max_epochs + 1):
            print(
                "Epoch: {}/{}, Global Step: {}".format(epoch, max_epochs, global_step)
            )
            train_loss = 0
            val_loss = 0

            if self.train_config.label_type == "float":
                all_labels = torch.FloatTensor().to(device)
            else:
                all_labels = torch.LongTensor().to(device)

            all_outputs = torch.Tensor().to(device)

            train_scores = None
            val_scores = None

            pbar = tqdm(total=math.ceil(len(train_dataset) / batch_size))
            pbar.set_description("Epoch " + str(epoch))

            val_counter = 0

            for step, batch in enumerate(train_loader):
                model.train()
                optimizer.zero_grad()
                inputs, labels = batch

                if self.train_config.label_type == "float":
                    labels = labels.float()

                for key in inputs:
                    inputs[key] = inputs[key].to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(torch.squeeze(outputs), labels)
                loss.backward()

                all_labels = torch.cat((all_labels, labels), 0)

                # Keep raw outputs for regression, argmax class ids otherwise.
                if self.train_config.label_type == "float":
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat(
                        (all_outputs, torch.argmax(outputs, dim=1)), 0
                    )

                train_loss += loss.item()
                optimizer.step()

                if self.train_config.scheduler is not None:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(train_loss / (step + 1))
                    else:
                        scheduler.step()

                pbar.set_postfix_str(f"Train Loss: {train_loss / (step + 1)}")
                pbar.update(1)

                global_step += 1

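                # Periodic validation; also refresh the "best" checkpoint when
                # the monitored score improves.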
                if val_dataset is not None and (global_step - 1) % val_interval == 0:
                    val_scores = self.val(
                        model,
                        val_dataset,
                        criterion,
                        device,
                        global_step,
                        train_logger,
                        train_log_values,
                    )

                    if self.train_config.save_on is not None:
                        train_scores = self.get_scores(
                            train_loss,
                            global_step,
                            self.train_config.criterion.type,
                            all_outputs,
                            all_labels,
                        )

                        best_score, best_step, save_flag = self.check_best(
                            val_scores, save_on_score, best_score, global_step
                        )

                        store_dict = {
                            "model_state_dict": model.state_dict(),
                            "best_step": best_step,
                            "best_score": best_score,
                            "save_on_score": save_on_score,
                        }

                        path = self.train_config.save_on.best_path.format(
                            self.log_label
                        )

                        self.save(store_dict, path, save_flag)

                        if save_flag and train_log_values["hparams"] is not None:
                            (
                                best_hparam_list,
                                best_hparam_name_list,
                                best_metrics_list,
                                best_metrics_name_list,
                            ) = self.update_hparams(
                                train_scores, val_scores, desc="best_val"
                            )

                # Periodic logging of the running loss and metrics.
                if (global_step - 1) % log_interval == 0:
                    train_loss_name = self.train_config.criterion.type
                    metric_list = [
                        metric(
                            all_labels.cpu(),
                            all_outputs.detach().cpu(),
                            **self.metrics[metric],
                        )
                        for metric in self.metrics
                    ]
                    metric_name_list = [
                        metric["type"] for metric in self._config.main_config.metrics
                    ]

                    train_scores = self.log(
                        train_loss / (step + 1),
                        train_loss_name,
                        metric_list,
                        metric_name_list,
                        train_logger,
                        train_log_values,
                        global_step,
                        append_text=self.train_config.append_text,
                    )

            pbar.close()

            if not os.path.exists(self.train_config.checkpoint.checkpoint_dir):
                os.makedirs(self.train_config.checkpoint.checkpoint_dir)

            if self.train_config.save_after_epoch:
                store_dict = {
                    "model_state_dict": model.state_dict(),
                }

                path = f"{self.train_config.checkpoint.checkpoint_dir}_{self.train_config.log.log_label}_{epoch}.pth"

                self.save(store_dict, path, save_flag=1)

            # After the last epoch: run a final validation and log epoch-level
            # training scores.
            if epoch == max_epochs:
                val_scores = self.val(
                    model,
                    val_dataset,
                    criterion,
                    device,
                    global_step,
                    train_logger,
                    train_log_values,
                )

                train_loss_name = self.train_config.criterion.type
                metric_list = [
                    metric(
                        all_labels.cpu(),
                        all_outputs.detach().cpu(),
                        **self.metrics[metric],
                    )
                    for metric in self.metrics
                ]
                metric_name_list = [
                    metric["type"] for metric in self._config.main_config.metrics
                ]

                train_scores = self.log(
                    train_loss / len(train_loader),
                    train_loss_name,
                    metric_list,
                    metric_name_list,
                    train_logger,
                    train_log_values,
                    global_step,
                    append_text=self.train_config.append_text,
                )

                if self.train_config.save_on is not None:
                    train_scores = self.get_scores(
                        train_loss,
                        len(train_loader),
                        self.train_config.criterion.type,
                        all_outputs,
                        all_labels,
                    )

                    best_score, best_step, save_flag = self.check_best(
                        val_scores, save_on_score, best_score, global_step
                    )

                    store_dict = {
                        "model_state_dict": model.state_dict(),
                        "best_step": best_step,
                        "best_score": best_score,
                        "save_on_score": save_on_score,
                    }

                    path = self.train_config.save_on.best_path.format(self.log_label)

                    self.save(store_dict, path, save_flag)

                    if save_flag and train_log_values["hparams"] is not None:
                        (
                            best_hparam_list,
                            best_hparam_name_list,
                            best_metrics_list,
                            best_metrics_name_list,
                        ) = self.update_hparams(
                            train_scores, val_scores, desc="best_val"
                        )

                    # Save the final model alongside the best one, and persist
                    # hyperparameters with both best and final metrics.
                    train_scores = self.get_scores(
                        train_loss,
                        len(train_loader),
                        self.train_config.criterion.type,
                        all_outputs,
                        all_labels,
                    )

                    store_dict = {
                        "model_state_dict": model.state_dict(),
                        "final_step": global_step,
                        "final_score": train_scores[save_on_score],
                        "save_on_score": save_on_score,
                    }

                    path = self.train_config.save_on.final_path.format(self.log_label)

                    self.save(store_dict, path, save_flag=1)

                    if train_log_values["hparams"] is not None:
                        (
                            final_hparam_list,
                            final_hparam_name_list,
                            final_metrics_list,
                            final_metrics_name_list,
                        ) = self.update_hparams(train_scores, val_scores, desc="final")
                        train_logger.save_hyperparams(
                            best_hparam_list,
                            best_hparam_name_list,
                            [int(self.log_label)]
                            + best_metrics_list
                            + final_metrics_list,
                            ["hparams/log_label"]
                            + best_metrics_name_list
                            + final_metrics_name_list,
                        )

    def get_scores(self, loss, divisor, loss_name, all_outputs, all_labels):
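        """Return a dict mapping the loss name and each metric name to its
        value, with the loss averaged by ``divisor``."""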
        avg_loss = loss / divisor

        metric_list = [
            metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
            for metric in self.metrics
        ]
        metric_name_list = [
            metric["type"] for metric in self._config.main_config.metrics
        ]

        return dict(zip([loss_name] + metric_name_list, [avg_loss] + metric_list))

    def check_best(self, val_scores, save_on_score, best_score, global_step):
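        """Compare the monitored validation score against the best seen so far.

        Returns the (possibly updated) best score, the step it was achieved at,
        and a save flag; ``best_step`` is only meaningful when ``save_flag`` is 1.
        """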
        save_flag = 0
        best_step = global_step
        if self.train_config.save_on.desired == "min":
            if val_scores[save_on_score] < best_score:
                save_flag = 1
                best_score = val_scores[save_on_score]
                best_step = global_step
        else:
            if val_scores[save_on_score] > best_score:
                save_flag = 1
                best_score = val_scores[save_on_score]
                best_step = global_step
        return best_score, best_step, save_flag

    def update_hparams(self, train_scores, val_scores, desc):
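        """Collect the configured hyperparameters and prefix train/val score
        keys with ``hparams/{desc}_...`` for the logger."""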
        hparam_list = []
        hparam_name_list = []
        for hparam in self.train_config.log.values.hparams:
            hparam_list.append(get_item_in_config(self._config, hparam["path"]))
            if isinstance(hparam_list[-1], Config):
                hparam_list[-1] = hparam_list[-1].as_dict()
            hparam_name_list.append(hparam["name"])

        val_keys, val_values = zip(*val_scores.items())
        train_keys, train_values = zip(*train_scores.items())
        val_keys = list(val_keys)
        train_keys = list(train_keys)
        val_values = list(val_values)
        train_values = list(train_values)
        for i, key in enumerate(val_keys):
            val_keys[i] = f"hparams/{desc}_val_" + val_keys[i]
        for i, key in enumerate(train_keys):
            train_keys[i] = f"hparams/{desc}_train_" + train_keys[i]

        return (
            hparam_list,
            hparam_name_list,
            train_values + val_values,
            train_keys + val_keys,
        )

    def save(self, store_dict, path, save_flag=0):
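        """Write ``store_dict`` to ``path`` via ``torch.save``, creating parent
        directories as needed; a falsy ``save_flag`` makes this a no-op."""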
        if save_flag:
            dirs = os.path.dirname(path)
            if dirs:
                os.makedirs(dirs, exist_ok=True)
            torch.save(store_dict, path)

    def log(
        self,
        loss,
        loss_name,
        metric_list,
        metric_name_list,
        logger,
        log_values,
        global_step,
        append_text,
    ):
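        """Log the loss and metrics under prefixed names and return a dict of
        the raw (unprefixed) scores."""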
        return_dic = dict(zip([loss_name] + metric_name_list, [loss] + metric_list))

        loss_name = f"{append_text}_{self.log_label}_{loss_name}"
        if log_values["loss"]:
            logger.save_params(
                [loss],
                [loss_name],
                combine=True,
                combine_name="losses",
                global_step=global_step,
            )

        for i in range(len(metric_name_list)):
            metric_name_list[i] = f"{append_text}_{self.log_label}_{metric_name_list[i]}"
        if log_values["metrics"]:
            logger.save_params(
                metric_list,
                metric_name_list,
                combine=True,
                combine_name="metrics",
                global_step=global_step,
            )

        return return_dic

    def val(
        self,
        model,
        dataset,
        criterion,
        device,
        global_step,
        train_logger=None,
        train_log_values=None,
        log=True,
    ):
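        """Evaluate ``model`` on ``dataset`` and return a dict of loss and
        metric scores; when ``log`` is True the scores are also logged."""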
        append_text = self.val_config.append_text
        if train_logger is not None:
            val_logger = train_logger
        else:
            val_logger = Logger(**self.val_config.log.logger_params.as_dict())

        if train_log_values is not None:
            val_log_values = train_log_values
        else:
            val_log_values = self.val_config.log.values.as_dict()

        if hasattr(dataset, "custom_collate_fn"):
            val_loader = DataLoader(
                dataset=dataset,
                collate_fn=dataset.custom_collate_fn,
                **self.val_config.loader_params.as_dict(),
            )
        else:
            val_loader = DataLoader(
                dataset=dataset, **self.val_config.loader_params.as_dict()
            )

        all_outputs = torch.Tensor().to(device)
        if self.train_config.label_type == "float":
            all_labels = torch.FloatTensor().to(device)
        else:
            all_labels = torch.LongTensor().to(device)

        batch_size = self.val_config.loader_params.batch_size

        with torch.no_grad():
            model.eval()
            val_loss = 0
            for j, batch in enumerate(val_loader):
                inputs, labels = batch

                if self.train_config.label_type == "float":
                    labels = labels.float()

                for key in inputs:
                    inputs[key] = inputs[key].to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(torch.squeeze(outputs), labels)
                val_loss += loss.item()

                all_labels = torch.cat((all_labels, labels), 0)

                if self.train_config.label_type == "float":
                    all_outputs = torch.cat((all_outputs, outputs), 0)
                else:
                    all_outputs = torch.cat(
                        (all_outputs, torch.argmax(outputs, dim=1)), 0
                    )

            val_loss = val_loss / len(val_loader)

            val_loss_name = self.train_config.criterion.type

            metric_list = [
                metric(all_labels.cpu(), all_outputs.detach().cpu(), **self.metrics[metric])
                for metric in self.metrics
            ]
            metric_name_list = [
                metric["type"] for metric in self._config.main_config.metrics
            ]
            return_dic = dict(
                zip([val_loss_name] + metric_name_list, [val_loss] + metric_list)
            )
            if log:
                val_scores = self.log(
                    val_loss,
                    val_loss_name,
                    metric_list,
                    metric_name_list,
                    val_logger,
                    val_log_values,
                    global_step,
                    append_text,
                )
                return val_scores
            return return_dic