Spaces:
Running
Running
| import sys | |
| import os | |
| import time | |
| import logging | |
| import torch | |
| ## Deterministic algorithms are enabled by the next line; when running on CUDA, also set the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 | |
| torch.use_deterministic_algorithms(True) | |
| import json | |
| import numpy as np | |
| import random | |
| import wandb | |
| from omegaconf import OmegaConf, open_dict | |
| from os import path | |
| from collections import OrderedDict, defaultdict | |
| from transformers import get_linear_schedule_with_warmup | |
| from transformers import AutoModel, AutoTokenizer | |
| from data_utils.utils import load_dataset, load_eval_dataset | |
| import pytorch_utils.utils as utils | |
| from torch.profiler import profile, record_function, ProfilerActivity | |
| from model.entity_ranking_model import EntityRankingModel | |
| from model.mention_proposal import MentionProposalModule | |
| from data_utils.tensorize_dataset import TensorizeDataset | |
| from pytorch_utils.optimization_utils import get_inverse_square_root_decay | |
| from utils_evaluate import coref_evaluation | |
| from typing import Dict, Union, List, Optional | |
| from omegaconf import DictConfig | |
| import copy | |
# Configure the root logger once at import time: timestamped messages, INFO level.
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
logger = logging.getLogger()

# Zero-initialised template of the running statistics kept per dataset.
# Fresh (deep) copies of this dict are used as accumulators during training
# and dev evaluation; every entry starts at 0.0.
loss_acc_template_dict = {
    stat_name: 0.0
    for stat_name in (
        "total",
        "ment_loss",
        "coref",
        "mention_count",
        "processed_docs",
        "ment_correct",
        "ment_total",
        "ment_tp",
        "ment_pp",
        "ment_ap",
    )
}
| class Experiment: | |
| """Class for training and evaluating coreference models.""" | |
def __init__(self, config: DictConfig):
    """Create the experiment and run it to completion.

    When ``config.train`` is true the model is built, data is loaded, the
    optimizers are set up, an existing checkpoint at ``paths.model_path`` is
    loaded, and training runs if the stopping criteria are not already met.
    In every case the constructor then looks for the best checkpoint at
    ``paths.best_model_path`` and runs the final evaluation on it.

    Args:
        config: Experiment configuration.

    Note:
        The process exits with status 1 if no best-model checkpoint exists.
    """
    self.config = config
    print("Seeded: ", config.seed)
    print("Cuda Available: ", torch.cuda.is_available())

    # Evaluation-only mode is the negation of the ``train`` flag.
    self.eval_model: bool = not self.config.train

    # Bookkeeping of key training variables.
    self.train_info = dict(
        val_perf=0.0,
        global_steps=0,
        num_stuck_evals=0,
        peak_memory=0.0,
    )
    self.wandbdata = {}

    # Checkpoint locations: last checkpoint and best checkpoint.
    paths_config = self.config.paths
    self.model_path = paths_config.model_path
    self.best_model_path = paths_config.best_model_path

    if not self.eval_model:
        # Order matters: the tokenizer used for data processing comes from
        # the model, and the optimizers need the number of training steps
        # that is computed while loading the data.
        self._build_model()
        self._load_data()
        self._setup_training()
        self._load_previous_checkpoint()
        if self._is_training_remaining():
            self.train()

    # Final evaluation always runs on the best checkpoint.
    if not path.exists(self.best_model_path):
        logger.info("No model accessible!")
        sys.exit(1)

    self._initialize_best_model()
    self._load_data()
    self.perform_final_eval()
def _build_model(self) -> None:
    """Instantiate the entity-ranking model from the current configuration.

    The model is moved to ``config.device`` when CUDA is available, and a
    summary of the model is printed.
    """
    experiment_config = self.config
    self.model = EntityRankingModel(
        model_config=experiment_config.model,
        train_config=experiment_config.trainer,
    )
    if torch.cuda.is_available():
        self.model.cuda(device=experiment_config.device)

    utils.print_model_info(self.model)
    sys.stdout.flush()
def _load_data(self):
    """Load the raw documents of every configured dataset and tensorize them.

    The first half reads each dataset listed in ``config.datasets`` from
    disk; the second half tensorizes the loaded documents with the model's
    tokenizer.

    Attributes populated:
        data_iter_map (dict): ``split -> dataset name -> tensorized docs``.
        conll_data_dir (dict): ``dataset name -> CoNLL directory`` for
            datasets flagged ``has_conll`` whose directory exists.
        num_split_docs_map (dict): ``split -> dataset name -> configured
            number of docs`` (``None`` when the dataset sets no limit).

    In training mode it additionally sets on ``config.trainer``:
        eval_per_k_steps: if unset, one pass over the (subsampled) training
            data of all datasets.
        num_training_steps: ``eval_per_k_steps * max_evals``.
    """
    self.data_iter_map = {}
    self.conll_data_dir = {}
    self.num_split_docs_map = {"train": {}, "dev": {}, "test": {}}
    raw_data_map = {}

    transformer_config = self.config.model.doc_encoder.transformer
    max_segment_len: int = transformer_config.max_segment_len
    model_name: str = transformer_config.name
    add_speaker_tokens: bool = self.config.model.doc_encoder.add_speaker_tokens
    base_data_dir: str = path.abspath(self.config.paths.base_data_dir)

    # Load data
    for dataset_name, attributes in self.config.datasets.items():
        num_train_docs: Optional[int] = attributes.get("num_train_docs", None)
        num_dev_docs: Optional[int] = attributes.get("num_dev_docs", None)
        num_test_docs: Optional[int] = attributes.get("num_test_docs", None)
        singleton_file: Optional[str] = attributes.get("singleton_file", None)
        external_md_file: Optional[str] = attributes.get("external_md_file", None)
        cross_val_split = attributes.get("cross_val_split", None)

        if singleton_file is not None:
            singleton_file = path.join(base_data_dir, singleton_file)
            if path.exists(singleton_file):
                logger.info(f"Singleton file found: {singleton_file}")

        if external_md_file is not None:
            external_md_file = path.join(base_data_dir, external_md_file)
            if path.exists(external_md_file):
                logger.info(
                    f"External mention detector file found: {external_md_file}"
                )

        # Data directory is a function of dataset name and tokenizer used.
        dataset_root = path.join(base_data_dir, dataset_name)
        data_dir = path.join(dataset_root, model_name)

        # Prefer the speaker-token variant of the data when it exists.
        if add_speaker_tokens:
            pot_data_dir = path.join(dataset_root, model_name + "_speaker")
            if path.exists(pot_data_dir):
                data_dir = pot_data_dir

        # Datasets such as litbank have cross validation splits.
        if cross_val_split is not None:
            data_dir = path.join(data_dir, str(cross_val_split))

        logger.info("Data directory: %s" % data_dir)

        # CoNLL data dir
        if attributes.get("has_conll", False):
            conll_dir = path.join(dataset_root, "conll")
            if cross_val_split is not None:
                conll_dir = path.join(conll_dir, str(cross_val_split))
            if path.exists(conll_dir):
                self.conll_data_dir[dataset_name] = conll_dir

        self.num_split_docs_map["train"][dataset_name] = num_train_docs
        self.num_split_docs_map["dev"][dataset_name] = num_dev_docs
        self.num_split_docs_map["test"][dataset_name] = num_test_docs

        if self.eval_model:
            print("In Eval Model DataLoader")
            raw_data_map[dataset_name] = load_eval_dataset(
                data_dir,
                external_md_file=external_md_file,
                max_segment_len=max_segment_len,
                dataset_name=dataset_name,
            )
        else:
            raw_data_map[dataset_name] = load_dataset(
                data_dir,
                singleton_file=singleton_file,
                num_dev_docs=num_dev_docs,
                num_test_docs=num_test_docs,
                max_segment_len=max_segment_len,
                dataset_name=dataset_name,
            )

    # Tensorize data
    data_processor = TensorizeDataset(
        self.model.get_tokenizer(),
        remove_singletons=(not self.config.keep_singletons),
    )

    if self.eval_model:
        for split in ["dev", "test"]:
            self.data_iter_map[split] = {}

        for dataset in raw_data_map:
            for split in raw_data_map[dataset]:
                self.data_iter_map[split][dataset] = data_processor.tensorize_data(
                    raw_data_map[dataset][split], training=False
                )
        return

    # Training
    for split in ["train", "dev", "test"]:
        self.data_iter_map[split] = {}
        training = split == "train"
        for dataset in raw_data_map:
            self.data_iter_map[split][dataset] = data_processor.tensorize_data(
                raw_data_map[dataset][split], training=training
            )

    # Estimate number of training steps
    trainer_config = self.config.trainer
    if trainer_config.eval_per_k_steps is None:
        # One "epoch" (with subsampling) over all datasets used in joint
        # training. A dataset without a configured ``num_train_docs`` is
        # used in full by ``runtime_load_dataset``, so it contributes its
        # actual size here instead of ``None`` (which would break the sum).
        steps_per_epoch = 0
        for dataset, train_docs in self.data_iter_map["train"].items():
            subsample_size = self.num_split_docs_map["train"].get(dataset)
            steps_per_epoch += (
                len(train_docs) if subsample_size is None else subsample_size
            )
        trainer_config.eval_per_k_steps = steps_per_epoch

    trainer_config.num_training_steps = (
        trainer_config.eval_per_k_steps * trainer_config.max_evals
    )
    logger.info(f"Number of training steps: {trainer_config.num_training_steps}")
    logger.info(f"Eval per k steps: {trainer_config.eval_per_k_steps}")
def _load_previous_checkpoint(self):
    """Load the last checkpoint from ``self.model_path`` if one exists.

    When no checkpoint is present the freshly initialized model is kept.
    CUDA memory summaries are printed before and after loading; they are
    skipped when CUDA is unavailable, matching the guards used elsewhere in
    this class.
    """
    cuda_available = torch.cuda.is_available()

    print("Model Path: ", self.model_path)
    if cuda_available:
        print("Model Initialised:", torch.cuda.memory_summary(self.config.device))

    if path.exists(self.model_path):
        # Resume from the last saved checkpoint.
        self.load_model(self.model_path, last_checkpoint=True)
        logger.info("Model loaded\n")
        if cuda_available:
            print(
                "Loaded Model Returned:",
                torch.cuda.memory_summary(self.config.device),
            )
    else:
        # Starting training from scratch.
        logger.info("Model initialized\n")

    sys.stdout.flush()
def _is_training_remaining(self):
    """Tell whether training should (still) run.

    Training stops in either of two situations:
        (a) the number of consecutive evaluations without improvement has
            reached the configured patience;
        (b) the number of steps taken has reached the total number of
            training steps.

    Returns:
        bool: True to keep training, False to move on to final evaluation.
    """
    progress, trainer_config = self.train_info, self.config.trainer
    return not (
        progress["num_stuck_evals"] >= trainer_config.patience
        or progress["global_steps"] >= trainer_config.num_training_steps
    )
def _setup_training(self):
    """Reset the training bookkeeping and create the optimizers."""
    # Fresh tracking state; ``max_mem`` is the largest peak GPU memory (GB)
    # recorded by the training loop.
    self.train_info = dict(
        val_perf=0.0,
        global_steps=0,
        num_stuck_evals=0,
        peak_memory=0.0,
        max_mem=0.0,
    )
    self._initialize_optimizers()
def _initialize_optimizers(self):
    """Create the optimizers, LR schedulers and gradient scaler.

    Two parameter groups are handled, keyed ``"mem"`` and ``"doc"`` in
    ``self.optimizer`` / ``self.optim_scheduler``:
        * ``"mem"``: the second element of ``model.get_params()``, optimized
          with Adam and no warmup.
        * ``"doc"``: the document-encoder parameters, only when the encoder
          is finetuned; AdamW with 10% warmup, and no weight decay on bias
          and LayerNorm weights.

    ``optimizer.lr_decay == "inv"`` selects inverse-square-root decay;
    anything else selects a linear schedule.
    """
    optimizer_config: DictConfig = self.config.optimizer
    train_config: DictConfig = self.config.trainer

    def build_scheduler(target_optimizer, warmup_steps):
        # Pick the decay schedule requested in the optimizer config.
        if optimizer_config.lr_decay == "inv":
            return get_inverse_square_root_decay(
                target_optimizer, num_warmup_steps=warmup_steps
            )
        return get_linear_schedule_with_warmup(
            target_optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=train_config.num_training_steps,
        )

    self.optimizer, self.optim_scheduler = {}, {}

    # Gradient scaler required for mixed precision training (CUDA only).
    self.scaler = torch.GradScaler("cuda") if torch.cuda.is_available() else None

    # Optimizer for clustering params; no warmup for these parameters.
    self.optimizer["mem"] = torch.optim.Adam(
        self.model.get_params()[1], lr=optimizer_config.init_lr, eps=1e-6
    )
    self.optim_scheduler["mem"] = build_scheduler(self.optimizer["mem"], 0)

    if not self.config.model.doc_encoder.finetune:
        return

    # Optimizer for document encoder: split parameters by whether their
    # name marks them as exempt from weight decay.
    # NOTE(review): assumes get_params(named=True)[0] yields (name, param)
    # pairs in a stable order - confirm against the model.
    exempt_markers = ("bias", "LayerNorm.weight")
    decayed_params, exempt_params = [], []
    for param_name, param in self.model.get_params(named=True)[0]:
        is_exempt = any(marker in param_name for marker in exempt_markers)
        (exempt_params if is_exempt else decayed_params).append(param)

    grouped_param = [
        {
            "params": decayed_params,
            "lr": optimizer_config.fine_tune_lr,
            "weight_decay": 1e-2,
        },
        {
            "params": exempt_params,
            "lr": optimizer_config.fine_tune_lr,
            "weight_decay": 0.0,
        },
    ]
    self.optimizer["doc"] = torch.optim.AdamW(
        grouped_param, lr=optimizer_config.fine_tune_lr, eps=1e-6
    )

    # Scheduler for document encoder: warm up over the first 10% of steps.
    self.optim_scheduler["doc"] = build_scheduler(
        self.optimizer["doc"], int(0.1 * train_config.num_training_steps)
    )
def agg(self, datadepdict):
    """Sum per-dataset statistics and derive normalized metrics.

    Args:
        datadepdict: Mapping ``dataset name -> {statistic name: value}``.

    Returns:
        defaultdict(float): The key-wise sums over all datasets plus the
        derived entries ``loss_norm``, ``ment_acc``, ``ment_prec``,
        ``ment_rec`` and ``ment_f1``. Every ratio whose denominator is not
        positive is reported as 0 instead of raising ZeroDivisionError
        (this happens, e.g., when no document was processed).
    """

    def ratio(numerator, denominator):
        # Division that degrades to 0 for an empty/zero denominator.
        return numerator / denominator if denominator > 0 else 0

    agg_dict = defaultdict(float)
    for dataset_stats in datadepdict.values():
        for key, value in dataset_stats.items():
            agg_dict[key] += value

    # Per-mention coref loss plus per-candidate mention loss; only defined
    # once at least one mention has been processed.
    if agg_dict["mention_count"] > 0:
        agg_dict["loss_norm"] = ratio(
            agg_dict["coref"], agg_dict["mention_count"]
        ) + ratio(agg_dict["ment_loss"], agg_dict["ment_total"])
    else:
        agg_dict["loss_norm"] = 0

    agg_dict["ment_acc"] = ratio(agg_dict["ment_correct"], agg_dict["ment_total"])
    agg_dict["ment_prec"] = ratio(agg_dict["ment_tp"], agg_dict["ment_pp"])
    agg_dict["ment_rec"] = ratio(agg_dict["ment_tp"], agg_dict["ment_ap"])

    precision, recall = agg_dict["ment_prec"], agg_dict["ment_rec"]
    agg_dict["ment_f1"] = ratio(2 * (precision * recall), precision + recall)
    return agg_dict
def train(self) -> None:
    """Run the training loop until a stopping criterion is met.

    Each pass of the outer loop draws a freshly shuffled (and possibly
    subsampled) list of training documents and performs one optimization
    step per document. Every ``trainer.eval_per_k_steps`` steps the
    accumulated training statistics are logged (to wandb if enabled), the
    model is evaluated on the dev sets via ``periodic_model_eval``, and the
    stopping criteria are checked.
    """
    model, optimizer, scheduler = self.model, self.optimizer, self.optim_scheduler
    model.train()

    optimizer_config, train_config = self.config.optimizer, self.config.trainer
    start_time = time.time()
    print("Started Training..")

    while True:
        logger.info("Steps done %d" % (self.train_info["global_steps"]))
        train_data = self.runtime_load_dataset("train")
        np.random.shuffle(train_data)
        logger.info("Per epoch training steps: %d" % len(train_data))

        encoder_params, task_params = model.get_params()
        stat_per_dataset = defaultdict(
            lambda: copy.deepcopy(loss_acc_template_dict)
        )
        agg_stat = self.agg

        def handle_example(document: Dict) -> Union[None, float]:
            """Take one optimization step; return the loss or None if skipped."""
            # The step counter advances even when the update is skipped.
            self.train_info["global_steps"] += 1
            for key in optimizer:
                optimizer[key].zero_grad()

            loss_dict: Dict = model.forward_training(document)
            total_loss = loss_dict["total"]
            if total_loss is None or torch.isnan(total_loss):
                print("Problem with Loss. Should not occur often")
                return None

            total_loss.backward()

            # Gradient clipping; a non-finite gradient norm skips the update.
            try:
                for param_group in (encoder_params, task_params):
                    torch.nn.utils.clip_grad_norm_(
                        param_group,
                        optimizer_config.max_gradient_norm,
                        error_if_nonfinite=True,
                    )
            except RuntimeError:
                print("Non Finite Gradient")
                return None

            # Record the learning rates in effect for this step.
            for key in optimizer:
                self.wandbdata[key + "_lr"] = scheduler[key].get_last_lr()[0]

            for key in optimizer:
                optimizer[key].step()
                scheduler[key].step()

            # Accumulate the step's statistics under the document's dataset.
            dataset_stats = stat_per_dataset[document["dataset_name"]]
            for key, value in loss_dict.items():
                dataset_stats[key] += value.item()
            dataset_stats["processed_docs"] += 1
            return total_loss.item()

        # Training "epoch" -> May not correspond to actual epoch
        for cur_document in train_data:
            loss = handle_example(cur_document)

            if self.train_info["global_steps"] % train_config.log_frequency == 0:
                # Peak GPU memory (GB) since the last reset.
                max_mem = (
                    torch.cuda.max_memory_allocated(self.config.device) / (1024**3)
                    if torch.cuda.is_available()
                    else 0.0
                )
                if self.train_info.get("max_mem", 0.0) < max_mem:
                    self.train_info["max_mem"] = max_mem

                if loss is not None:
                    logger.info(
                        "{} {:.3f} Max mem {:.1f} GB".format(
                            cur_document["doc_key"],
                            loss,
                            max_mem,
                        )
                    )
                sys.stdout.flush()
                if torch.cuda.is_available():
                    torch.cuda.reset_peak_memory_stats()

            if train_config.eval_per_k_steps and (
                self.train_info["global_steps"] % train_config.eval_per_k_steps == 0
            ):
                print("Eval needs to be done here")
                coref_dict = {}
                print(stat_per_dataset)
                if self.config.use_wandb:
                    self._wandb_log(
                        split="train",
                        stat_per_dataset=stat_per_dataset,
                        agg_stat=agg_stat,
                        coref_dict=coref_dict,
                        step=self.train_info["global_steps"]
                        // train_config.eval_per_k_steps,
                    )
                # Start a fresh accumulation window after each evaluation.
                stat_per_dataset = defaultdict(
                    lambda: copy.deepcopy(loss_acc_template_dict)
                )

                macro_fscore = self.periodic_model_eval()
                model.train()

                # Get elapsed time
                elapsed_time = time.time() - start_time
                start_time = time.time()
                # Both numbers are macro F1: the value returned by
                # periodic_model_eval and the best one stored in val_perf.
                logger.info(
                    "Steps: %d, Macro F1: %.1f, Max Macro F1: %.1f, Time: %.2f"
                    % (
                        self.train_info["global_steps"],
                        macro_fscore,
                        self.train_info["val_perf"],
                        elapsed_time,
                    )
                )

                # Check stopping criteria
                if not self._is_training_remaining():
                    break

        # Check stopping criteria
        if not self._is_training_remaining():
            break

        logger.handlers[0].flush()
def runtime_load_dataset(self, split):
    """Return the documents of ``split`` from all datasets as one list.

    Each dataset's document list is shuffled in place (this reorders the
    list stored in ``self.data_iter_map``). If a document count is
    configured for the dataset in ``self.num_split_docs_map[split]``, only
    that many documents are taken from the shuffled list - useful for joint
    training; otherwise the whole dataset is used. A configured count larger
    than the dataset is capped at the dataset size instead of raising
    IndexError.

    Args:
        split: Name of the split, e.g. ``"train"`` or ``"dev"``.

    Returns:
        list: Concatenation of the selected documents, dataset by dataset.
    """
    data = []
    for dataset, dataset_data in self.data_iter_map[split].items():
        # The shuffle is what makes the prefix taken below a random subsample.
        np.random.shuffle(dataset_data)

        num_docs = self.num_split_docs_map[split].get(dataset, None)
        if num_docs is None:
            data += dataset_data
            continue

        # Subsampling the data - slicing caps at the dataset size.
        subsample = dataset_data[:num_docs]
        logger.info(f"{dataset}: Subsampled {len(subsample)}")
        data += subsample

    return data
def _wandb_log(self, split, stat_per_dataset, agg_stat, coref_dict, step=None):
    """Send statistics and coreference scores to wandb.

    Logged keys are prefixed with ``split`` and, for per-dataset values,
    the dataset name.

    Args:
        split: Prefix for the logged keys (e.g. ``"train"``, ``"dev"``).
        stat_per_dataset: ``dataset name -> {statistic: value}``.
        agg_stat: Callable aggregating ``stat_per_dataset`` into one dict,
            or a falsy value to skip aggregated logging.
        coref_dict: ``dataset name -> evaluation result dict``.
        step: wandb step passed to every ``wandb.log`` call.
    """

    def log_value(key, value):
        # One wandb.log call per metric, as the callers have always done.
        wandb.log(data={key: value}, step=step)

    for dataset_name, stats in stat_per_dataset.items():
        prefix = f"{split}/{dataset_name}"
        for metric_name, value in stats.items():
            log_value(f"{prefix}/{metric_name}", value)

        if not stats["mention_count"] > 0.0:
            print("No mentions processed. Should not occur many times.")
            continue

        ment_total = stats["ment_total"]
        ment_prec = stats["ment_tp"] / stats["ment_pp"] if stats["ment_pp"] > 0 else 0
        ment_rec = stats["ment_tp"] / stats["ment_ap"] if stats["ment_ap"] > 0 else 0
        ment_f1 = (
            2 * (ment_prec * ment_rec) / (ment_prec + ment_rec)
            if (ment_prec + ment_rec) > 0
            else 0
        )
        # Guard the ment_total divisions: mention_count > 0 alone does not
        # establish that ment_total is non-zero.
        ment_loss_norm = stats["ment_loss"] / ment_total if ment_total > 0 else 0
        ment_acc = stats["ment_correct"] / ment_total if ment_total > 0 else 0
        wandb.log(
            data={
                f"{prefix}/loss_norm": stats["coref"] / stats["mention_count"]
                + ment_loss_norm,
                f"{prefix}/ment_acc": ment_acc,
                f"{prefix}/ment_prec": ment_prec,
                f"{prefix}/ment_rec": ment_rec,
                f"{prefix}/ment_f1": ment_f1,
            },
            step=step,
        )

    if agg_stat:
        # Aggregate once instead of once per logged metric.
        aggregated = agg_stat(stat_per_dataset)
        for metric, value in aggregated.items():
            log_value(f"{split}/{metric}", value)

    for dataset, results in coref_dict.items():
        # Log result for individual metrics (entries that are dicts).
        for key, value in results.items():
            if isinstance(value, dict):
                log_value(f"{split}/{dataset}/{key}", value.get("fscore", 0.0))

        # Log the overall scores.
        log_value(f"{split}/{dataset}/CoNLL", results.get("fscore", 0.0))
        log_value(f"{split}/{dataset}/Micro-F1", results.get("f1_micro", 0.0))
        log_value(f"{split}/{dataset}/Macro-F1", results.get("f1_macro", 0.0))

    # Extra values collected during training (e.g. learning rates).
    wandb.log(data=self.wandbdata, step=step)
def periodic_model_eval(self) -> float:
    """Evaluate on the dev sets during training and save checkpoints.

    Computes dev-set loss statistics, runs ``coref_evaluation`` on every
    dev dataset, logs to wandb if enabled, saves the model as the best
    checkpoint when the mean macro F1 improves on ``train_info["val_perf"]``
    (otherwise increments ``num_stuck_evals``), and saves the last
    checkpoint when ``trainer.to_save_model`` is set.

    Returns:
        float: Mean ``f1_macro`` over the dev datasets (0.0 if there are
        none).
    """
    self.model.eval()
    train_config = self.config.trainer
    eval_index = self.train_info["global_steps"] // train_config.eval_per_k_steps

    # Dev loss calculations.
    dev_data = self.runtime_load_dataset("dev")
    np.random.shuffle(dev_data)
    stat_per_dataset = defaultdict(lambda: copy.deepcopy(loss_acc_template_dict))
    agg_stat = self.agg

    # Only loss values are read here, so no autograd graph is needed.
    with torch.no_grad():
        for document in dev_data:
            loss_dict: Dict = self.model.forward_training(document)
            total_loss = loss_dict["total"]
            if total_loss is None or torch.isnan(total_loss):
                print("Problem with Loss. Should not occur many times")
                continue

            dataset_stats = stat_per_dataset[document["dataset_name"]]
            for key, value in loss_dict.items():
                dataset_stats[key] += value.item()
            dataset_stats["processed_docs"] += 1

    # Dev performance (predicted mentions, no teacher forcing).
    coref_dict = {}
    for dataset in self.data_iter_map["dev"]:
        coref_dict[dataset] = coref_evaluation(
            self.config,
            self.model,
            self.data_iter_map,
            dataset,
            teacher_force=False,
            gold_mentions=False,
            _iter="_" + str(eval_index),
            conll_data_dir=self.conll_data_dir,
        )

    if self.config.use_wandb:
        self._wandb_log(
            split="dev",
            stat_per_dataset=stat_per_dataset,
            agg_stat=agg_stat,
            coref_dict=coref_dict,
            step=eval_index,
        )

    # Mean macro F1 over datasets is the model-selection metric.
    if coref_dict:
        macro_fscore = sum(
            result["f1_macro"] for result in coref_dict.values()
        ) / len(coref_dict)
    else:
        macro_fscore = 0.0
    logger.info(
        "Avg Macro F1: %.1f, Max Macro F1: %.1f"
        % (macro_fscore, self.train_info["val_perf"])
    )

    # Update model if dev performance improves
    if macro_fscore > self.train_info["val_perf"]:
        self.train_info["num_stuck_evals"] = 0
        self.train_info["val_perf"] = macro_fscore
        logger.info("Saving best model")
        self.save_model(self.best_model_path, last_checkpoint=False)
    else:
        self.train_info["num_stuck_evals"] += 1

    # Save the last checkpoint (with optimizer/scheduler state).
    if self.config.trainer.to_save_model:
        self.save_model(self.model_path, last_checkpoint=True)

    # Go back to training mode
    self.model.train()
    return macro_fscore
def perform_final_eval(self) -> None:
    """Evaluate the loaded model on the dev and test splits and save results.

    For every dataset a ``perf.json`` with the full result dictionaries of
    both splits is written to ``<model_dir>/<dataset>/``; a summary
    ``perf.json`` containing the ``f1_micro`` score per split and dataset
    is written to ``<model_dir>``. Results are also logged to wandb if
    enabled.
    """
    self.model.eval()
    model_dir = self.config.paths.model_dir

    perf_summary = {"best_perf": self.train_info["val_perf"]}
    if model_dir:
        perf_summary["model_dir"] = path.normpath(model_dir)

    logger.info(
        "Max training memory: %.1f GB" % self.train_info.get("max_mem", 0.0)
    )
    logger.info("Validation performance: %.1f" % self.train_info["val_perf"])

    perf_file_dict = {}
    dataset_output_dict = {}

    for split in ["dev", "test"]:
        perf_summary[split] = {}
        logger.info("\n")
        logger.info("%s" % split.capitalize())

        coref_dict = {}
        for dataset in self.data_iter_map.get(split, []):
            dataset_dir = path.join(model_dir, dataset)
            os.makedirs(dataset_dir, exist_ok=True)
            dataset_output_dict.setdefault(dataset, {})
            perf_file_dict.setdefault(dataset, path.join(dataset_dir, "perf.json"))

            print("Dataset Name:", self.config.datasets[dataset].name)
            logger.info("Dataset: %s\n" % self.config.datasets[dataset].name)

            # Predicted mentions, no teacher forcing.
            result_dict = coref_evaluation(
                self.config,
                self.model,
                self.data_iter_map,
                dataset=dataset,
                split=split,
                teacher_force=False,
                gold_mentions=False,
                final_eval=True,
                conll_data_dir=self.conll_data_dir,
            )
            coref_dict[dataset] = result_dict
            dataset_output_dict[dataset][split] = result_dict
            perf_summary[split][dataset] = result_dict["f1_micro"]

        if self.config.use_wandb:
            self._wandb_log(
                split=split,
                stat_per_dataset={},
                agg_stat=None,
                coref_dict=coref_dict,
                step=None,
            )
        sys.stdout.flush()

    # Per-dataset result files; ``with`` guarantees the handles are closed.
    for dataset, output_dict in dataset_output_dict.items():
        perf_file = perf_file_dict[dataset]
        with open(perf_file, "w") as perf_fp:
            json.dump(output_dict, perf_fp, indent=2)
        logger.info("Final performance summary at %s" % path.abspath(perf_file))

    summary_file = path.join(model_dir, "perf.json")
    with open(summary_file, "w") as summary_fp:
        json.dump(perf_summary, summary_fp, indent=2)
    logger.info("Performance summary file: %s" % path.abspath(summary_file))
def _initialize_best_model(self):
    """Rebuild the model from the best checkpoint for final evaluation.

    The configuration stored in the checkpoint is the basis for the model;
    selected parts can be replaced by the current configuration through the
    ``override_encoder`` and ``override_memory`` flags, and
    ``model.mention_params.ext_ment`` is always taken from the current
    configuration. The checkpoint's ``train_info`` replaces the current one.
    """
    # NOTE(review): the checkpoint stores more than tensors (config,
    # train_info), so it should only be loaded from a trusted location.
    checkpoint = torch.load(self.best_model_path, map_location="cpu")
    config = checkpoint["config"]
    current_config = self.config

    # Due to version changes, the encoder description may need replacing.
    if current_config.get("override_encoder", False):
        current_transformer = current_config.model.doc_encoder.transformer
        print(type(current_transformer))
        print(current_transformer)
        config.model.doc_encoder.transformer = current_transformer

    # Override memory, e.g. to test with a different bounded memory size.
    if current_config.get("override_memory", False):
        config.model.memory = current_config.model.memory

    with open_dict(config):
        print("Config change")
        config.model.mention_params.ext_ment = (
            current_config.model.mention_params.ext_ment
        )

    config = utils.fill_missing_configs(config, current_config)
    print("Type: ", config.model.memory.type)

    self.config.model = config.model
    self.train_info = checkpoint["train_info"]

    if self.config.model.doc_encoder.finetune:
        # A finetuned encoder is stored in its own directory next to the
        # checkpoint; point the model at it when it exists.
        doc_encoder_dir = path.join(
            path.dirname(self.best_model_path),
            self.config.paths.doc_encoder_dirname,
        )
        if path.exists(doc_encoder_dir):
            logger.info(
                "Loading document encoder from %s" % path.abspath(doc_encoder_dir)
            )
            config.model.doc_encoder.transformer.model_str = doc_encoder_dir

    self.model = EntityRankingModel(config.model, config.trainer)

    # Document encoder parameters are absent from the checkpoint (hence
    # strict=False); they come from the huggingface initialization.
    self.model.load_state_dict(checkpoint["model"], strict=False)

    if torch.cuda.is_available():
        self.model.cuda(device=self.config.device)
def load_model(self, location: str, last_checkpoint=True) -> None:
    """Load model weights from a checkpoint into ``self.model``.

    Only the model parameters are restored. The checkpoint's config,
    ``train_info``, optimizer and scheduler states are NOT loaded by this
    method. When the document encoder is finetuned, the encoder and its
    tokenizer are re-loaded from the encoder directory next to the
    checkpoint.

    Args:
        location: Path of the checkpoint file.
        last_checkpoint: Accepted for interface compatibility; currently
            not used by this method.
    """
    # NOTE(review): the checkpoint stores more than tensors, so it should
    # only be loaded from a trusted location.
    checkpoint = torch.load(location, map_location="cpu")
    logger.info("Loading model from %s" % path.abspath(location))

    # The encoder weights are not part of the checkpoint, so strict=False
    # is required.
    self.model.load_state_dict(checkpoint["model"], strict=False)

    if self.config.model.doc_encoder.finetune:
        # Load the document encoder params if encoder is finetuned
        doc_encoder_dir = path.join(
            path.dirname(location), self.config.paths.doc_encoder_dirname
        )
        logger.info(
            "Loading document encoder from %s" % path.abspath(doc_encoder_dir)
        )
        doc_encoder = self.model.mention_proposer.doc_encoder
        doc_encoder.lm_encoder = AutoModel.from_pretrained(
            pretrained_model_name_or_path=doc_encoder_dir
        )
        doc_encoder.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=doc_encoder_dir,
            clean_up_tokenization_spaces=True,
        )
        # Re-enable gradient checkpointing on the newly loaded encoder.
        if doc_encoder.config.finetune:
            doc_encoder.lm_encoder.gradient_checkpointing_enable()

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        self.model.cuda(device=self.config.device)
        # Memory statistics are only queried when CUDA is present.
        print("Loaded Model:", torch.cuda.memory_summary())

    # NOTE(review): not sure ``torch.autograd.grad_checkpoint_enabled``
    # exists in the installed torch version - looked up defensively so a
    # missing attribute does not abort checkpoint loading.
    checkpoint_probe = getattr(torch.autograd, "grad_checkpoint_enabled", None)
    print(
        "Gradient checkpointing enabled? ",
        checkpoint_probe() if callable(checkpoint_probe) else "unknown",
    )

    # Release the CPU copy of the checkpoint and cached GPU blocks.
    del checkpoint
    if cuda_available:
        torch.cuda.empty_cache()
def save_model(self, location: os.PathLike, last_checkpoint=True) -> None:
    """Save the model (and optionally the training state) to ``location``.

    The model is saved in two parts:
        (a) document-encoder parameters and tokenizer, written in
            huggingface format to a directory next to ``location`` - only
            when the encoder is finetuned;
        (b) all remaining parameters together with ``train_info``, RNG
            states, the config and, if present, the gradient-scaler state.

    Args:
        location: Path of the checkpoint file to write.
        last_checkpoint: If True, optimizer and scheduler states are saved
            as well; they are left out otherwise because they take up a lot
            of space.
    """
    # Everything except the encoder weights ("lm_encoder." keys) goes into
    # the checkpoint file.
    model_state_dict = OrderedDict(
        (key, value)
        for key, value in self.model.state_dict().items()
        if "lm_encoder." not in key
    )
    finetune_encoder = self.config.model.doc_encoder.finetune

    # Save the document encoder params
    if finetune_encoder:
        doc_encoder_dir = path.join(
            path.dirname(location), self.config.paths.doc_encoder_dirname
        )
        os.makedirs(doc_encoder_dir, exist_ok=True)
        logger.info(f"Encoder saved at {path.abspath(doc_encoder_dir)}")

        doc_encoder = self.model.mention_proposer.doc_encoder
        # Save the encoder
        doc_encoder.lm_encoder.save_pretrained(
            save_directory=doc_encoder_dir, save_config=True
        )
        # Save the tokenizer
        doc_encoder.tokenizer.save_pretrained(doc_encoder_dir)

    save_dict = {
        "train_info": self.train_info,
        "model": model_state_dict,
        "rng_state": torch.get_rng_state(),
        "np_rng_state": np.random.get_state(),
        "config": self.config,
    }

    if self.scaler is not None:
        save_dict["scaler"] = self.scaler.state_dict()

    if last_checkpoint:
        # For last checkpoint save the optimizer and scheduler states as well
        param_groups: List[str] = ["mem", "doc"] if finetune_encoder else ["mem"]
        save_dict["optimizer"] = {
            group: self.optimizer[group].state_dict() for group in param_groups
        }
        save_dict["scheduler"] = {
            group: self.optim_scheduler[group].state_dict()
            for group in param_groups
        }

    torch.save(save_dict, location)
    logger.info(f"Model saved at: {path.abspath(location)}")