Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from model.mention_proposal import MentionProposalModule | |
| from model.utils import get_gt_actions | |
| from model.memory.entity_memory import EntityMemory | |
| from torch.profiler import profile, record_function, ProfilerActivity | |
| from typing import Dict, List, Tuple | |
| from omegaconf import DictConfig | |
| from torch import Tensor | |
| from transformers import PreTrainedTokenizerFast | |
| import logging | |
| import random | |
| from collections import defaultdict | |
| import copy | |
| import time | |
# Configure logging once at import time: records are prefixed with a timestamp
# and everything at INFO level or above is emitted.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
# No name is passed, so this is the root logger.
logger = logging.getLogger()
| class EntityRankingModel(nn.Module): | |
| """ | |
| Coreference model based on Entity-Ranking paradigm. | |
| In the entity-ranking paradigm, given a new mention we rank the different | |
| entity clusters to determine the clustering updates. Entity-Ranking paradigm | |
| allows for a naturally scalable solution to coreference resolution. | |
| Reference: Rahman and Ng [https://arxiv.org/pdf/1405.5202.pdf] | |
| This particular implementation represents the entities/clusters via fixed-dimensional | |
| dense representations, typically a simple average of mention representations. | |
| Clustering is performed in an online, autoregressive manner where mentions are | |
| processed in a left-to-right manner. | |
| References: | |
| Toshniwal et al [https://arxiv.org/pdf/2010.02807.pdf] | |
| Toshniwal et al [https://arxiv.org/pdf/2109.09667.pdf] | |
| """ | |
def __init__(self, model_config: DictConfig, train_config: DictConfig):
    """Build the mention proposer, the entity memory and the loss function.

    Args:
        model_config: Model configuration. This constructor reads
            ``metadata_params`` (``use_genre_feature``, ``genres``),
            ``memory`` (including ``mem_type.name``) and
            ``mention_params.emb_size``. When the genre feature is enabled,
            ``memory.num_feats`` is overwritten with 3 on this config.
        train_config: Training configuration. This constructor reads
            ``dropout_rate`` and ``label_smoothing_wt``.
    """
    super(EntityRankingModel, self).__init__()
    self.config = model_config
    self.train_config = train_config

    # A single dropout module is shared by the proposer and the memory.
    self.drop_module = nn.Dropout(p=train_config.dropout_rate)

    # Template that forward_training deep-copies for every document.
    # The mention counters start at 0.001 rather than 0 - presumably to keep
    # downstream ratios finite; TODO confirm against the training loop.
    counter_keys = ("ment_correct", "ment_total", "ment_tp", "ment_pp", "ment_ap")
    self.loss_template_dict = {
        "total": torch.tensor(0.0, requires_grad=True),
        "ment_loss": torch.tensor(0.0),
        "coref": torch.tensor(0.0),
        "mention_count": torch.tensor(0.0),
    }
    self.loss_template_dict.update(
        {key: torch.tensor(0.001) for key in counter_keys}
    )

    # Document encoder + mention proposer.
    self.mention_proposer = MentionProposalModule(
        self.config, train_config, drop_module=self.drop_module
    )

    # Clustering module.
    span_emb_size: int = self.mention_proposer.span_emb_size
    meta_params = self.config.metadata_params
    # The memory config must be adjusted before the memory is constructed.
    if meta_params.use_genre_feature:
        self.config.memory.num_feats = 3
    self.mem_type = self.config.memory.mem_type.name
    self.memory_net = EntityMemory(
        config=self.config.memory,
        span_emb_size=span_emb_size,
        drop_module=self.drop_module,
    )

    self.loss_fn = nn.CrossEntropyLoss(
        label_smoothing=self.train_config.label_smoothing_wt
    )

    # One embedding per configured genre, only when the feature is in use.
    if meta_params.use_genre_feature:
        self.genre_embeddings = nn.Embedding(
            num_embeddings=len(meta_params.genres),
            embedding_dim=self.config.mention_params.emb_size,
        )
@property
def device(self) -> torch.device:
    """Return the device reported by the mention proposer.

    Exposed as a property because the rest of the class reads it as an
    attribute (e.g. ``torch.tensor(..., device=self.device)``); as a plain
    method those call sites would receive a bound method instead of a device.
    """
    return self.mention_proposer.device
def get_params(self, named=False) -> Tuple[List, List]:
    """Split the model parameters into document-encoder and remaining ones.

    A parameter counts as a document-encoder parameter when its name
    contains the substring ``"doc_encoder"``.

    Args:
        named: When true, each element is a ``(name, parameter)`` pair;
            otherwise only the parameter is returned.

    Returns:
        A tuple ``(encoder_params, mem_params)`` of two lists, each in the
        order produced by ``self.named_parameters()``.
    """
    encoder_params: List = []
    mem_params: List = []
    for param_name, param in self.named_parameters():
        bucket = encoder_params if "doc_encoder" in param_name else mem_params
        bucket.append((param_name, param) if named else param)
    return encoder_params, mem_params
def get_tokenizer(self) -> PreTrainedTokenizerFast:
    """Return the tokenizer of the mention proposer's document encoder."""
    doc_encoder = self.mention_proposer.doc_encoder
    return doc_encoder.get_tokenizer()
def get_metadata(self, document: Dict) -> Dict:
    """Extract metadata such as the document genre from the document.

    Args:
        document: Document dictionary; only ``document["doc_key"]`` is read,
            and only when the genre feature is enabled.

    Returns:
        ``{"genre": <genre embedding>}`` when
        ``metadata_params.use_genre_feature`` is set, otherwise ``{}``.
    """
    meta_params = self.config.metadata_params
    if not meta_params.use_genre_feature:
        return {}

    # The genre code is the first two characters of the document key;
    # codes that are not configured fall back to the default genre.
    genre_code = document["doc_key"][:2]
    if genre_code not in meta_params.genres:
        genre_code = meta_params.default_genre
    genre_idx = meta_params.genres.index(genre_code)

    genre_tensor = torch.tensor(genre_idx, device=self.device)
    return {"genre": self.genre_embeddings(genre_tensor)}
def calculate_coref_loss(
    self, action_prob_list: List, action_tuple_list: List[Tuple[int, str]]
) -> Tensor:
    """Calculates the coreference loss for the autoregressive online clustering module.

    Only actions tagged ``'c'`` or ``'o'`` contribute to the loss. Entries of
    ``action_prob_list`` are consumed in order, one per such action; every
    other action is skipped without consuming an entry.

    Args:
        action_prob_list (List):
            One 1-D score tensor per scored action, covering the existing
            clusters followed by a final "new/other cluster" slot. Each tensor
            is passed to ``self.loss_fn`` with a batch dimension added.
        action_tuple_list (List[Tuple[int, str]]):
            Ground truth actions represented as a tuple of cluster index and action string.
            'c' represents that the mention is coreferent with an existing cluster
            (the target is the given cluster index) while 'o' represents that the
            mention belongs to no existing cluster (the target is the last slot).

    Returns:
        coref_loss (torch.Tensor):
            The scalar tensor holding the sum of the per-action losses
            (``0.0`` when there is nothing to score).

    Raises:
        IndexError: If there are more 'c'/'o' actions than score tensors.
    """
    coref_loss = torch.tensor(0.0, device=self.device)
    # Position in action_prob_list; advances only for scored actions.
    score_idx = 0
    for cell_idx, action_str in action_tuple_list:
        if action_str == "c":
            # Mention belongs to a tracked entity.
            gt_idx = cell_idx
        elif action_str == "o":
            # Mention belongs to the catch-all slot, which is scored last.
            gt_idx = action_prob_list[score_idx].shape[0] - 1
        else:
            continue

        scores = torch.unsqueeze(action_prob_list[score_idx], dim=0)
        target = torch.tensor([gt_idx], device=self.device)
        coref_loss += self.loss_fn(scores, target)
        score_idx += 1
    return coref_loss
@staticmethod
def get_filtered_clusters(
    clusters,
    init_token_offset,
    final_token_offset,
    cluster_mask=None,
    with_offset=True,
):
    """Filter clusters from a document given the token offsets.

    A mention ``(start, end)`` is kept when ``start >= init_token_offset``
    and ``end < final_token_offset``. Clusters left without any mention are
    dropped.

    Declared static because it takes no ``self`` yet is invoked as
    ``self.get_filtered_clusters(...)`` by the forward passes.

    Args:
        clusters: Iterable of clusters, each an iterable of
            ``(ment_start, ment_end)`` pairs.
        init_token_offset: First token index of the window (inclusive).
        final_token_offset: End of the window (exclusive bound on ``ment_end``).
        cluster_mask: Optional sequence with one flag per cluster; callers
            are expected to guarantee ``len(cluster_mask) == len(clusters)``.
            Mentions of clusters whose flag is false are not returned as
            their own cluster; they are appended to the last returned cluster
            (or returned as the only cluster when nothing else survived).
        with_offset: When true, mentions keep their document-level indices;
            otherwise ``init_token_offset`` is subtracted from both ends.

    Returns:
        List of filtered clusters, each a list of ``(start, end)`` tuples.
    """
    shift = 0 if with_offset else init_token_offset
    filt_clusters = []
    # Mentions of masked-out clusters, i.e. entities whose representative
    # phrase is not available among the current mentions.
    no_rep_cluster_mentions = []

    for cluster_ind, orig_cluster in enumerate(clusters):
        cluster = [
            (ment_start - shift, ment_end - shift)
            for ment_start, ment_end in orig_cluster
            if ment_start >= init_token_offset and ment_end < final_token_offset
        ]
        if not cluster:
            continue
        if cluster_mask and not cluster_mask[cluster_ind]:
            no_rep_cluster_mentions.extend(cluster)
        else:
            filt_clusters.append(cluster)

    if cluster_mask:
        # NOTE(review): this assumes the last surviving cluster is the
        # catch-all one; if that cluster has no mention in the window the
        # orphaned mentions land in the last remaining entity - confirm.
        if filt_clusters:
            filt_clusters[-1].extend(no_rep_cluster_mentions)
        else:
            filt_clusters.append(no_rep_cluster_mentions)
    return filt_clusters
@staticmethod
def get_filtered_representatives(
    representatives, init_token_offset, final_token_offset, with_offset=True
):
    """Filter representative phrases of a document given the token offsets.

    A representative ``(start, end)`` is kept when
    ``start >= init_token_offset`` and ``end < final_token_offset``.

    Declared static because it takes no ``self`` yet is invoked as
    ``self.get_filtered_representatives(...)`` by the forward passes.

    Args:
        representatives: Iterable of ``(ment_start, ment_end)`` pairs.
        init_token_offset: First token index of the window (inclusive).
        final_token_offset: End of the window (exclusive bound on ``ment_end``).
        with_offset: When true, spans keep their document-level indices;
            otherwise ``init_token_offset`` is subtracted from both ends.

    Returns:
        A tuple ``(filt_reps, indices)``: the kept spans as tuples, and the
        position of each kept span in ``representatives``.
    """
    shift = 0 if with_offset else init_token_offset
    filt_reps = []
    indices = []
    for rep_ind, (ment_start, ment_end) in enumerate(representatives):
        if ment_start >= init_token_offset and ment_end < final_token_offset:
            filt_reps.append((ment_start - shift, ment_end - shift))
            indices.append(rep_ind)
    return filt_reps, indices
@staticmethod
def mask_representative_phrases(rep_emb_list):
    """Randomly drop some of the available representative embeddings.

    Entries that are not ``int`` are treated as available embeddings; ``int``
    entries (the ``-1`` placeholder) mark unavailable ones. When more than
    one embedding is available, a random number (at least one) of them is
    kept and the rest are replaced by ``-1``.

    Declared static because it takes no ``self`` yet is invoked as
    ``self.mask_representative_phrases(...)`` during training.

    Args:
        rep_emb_list: List of embeddings and ``-1`` placeholders. It is
            modified in place.

    Returns:
        The same list object that was passed in.
    """
    positive_inds = [
        ind for ind, rep_emb in enumerate(rep_emb_list) if not isinstance(rep_emb, int)
    ]
    if len(positive_inds) > 1:
        # Same draw order as before (count first, then shuffle) so that
        # seeded runs remain reproducible.
        num_entities_preserved = random.randint(1, len(positive_inds))
        random.shuffle(positive_inds)
        for ind in positive_inds[num_entities_preserved:]:
            rep_emb_list[ind] = -1
    return rep_emb_list
def forward_training(self, document: Dict) -> Dict:
    """Forward pass for training.

    The document is processed segment by segment: the mention proposer is run
    on every selected segment, its mention loss and counters are accumulated,
    and the proposed mentions are then clustered by the memory network against
    the ground-truth actions.

    Args:
        document: The tensorized document. Keys read here: ``clusters``,
            ``representatives``, ``sentences``, ``tensorized_sent``,
            ``sentence_map``, ``subtoken_map``, ``sent_len_list`` and
            ``doc_key``.

    Returns:
        loss_dict (Dict): Loss dictionary containing the losses of different stages of the model.
    """
    representatives = document["representatives"]
    assert (
        len(document["clusters"]) == len(representatives) + 1
    ), "Length of clusters not equal to length of representatives + 1."
    assert representatives == sorted(
        representatives
    ), "Representatives are not sorted."

    loss_dict = copy.deepcopy(self.loss_template_dict)

    # Pick the window of segments to train on: the whole document, or a
    # random contiguous window when the document exceeds the configured cap.
    sentences = document["sentences"]
    num_segments = len(sentences)
    max_training_segments = self.train_config.get("max_training_segments", None)
    first_seg, last_seg = 0, num_segments
    if max_training_segments is not None and num_segments > max_training_segments:
        first_seg = random.randint(0, num_segments - max_training_segments)
        last_seg = first_seg + max_training_segments

    # Accumulators over all processed segments. A representative whose
    # embedding has not been produced stays at the -1 placeholder.
    pred_mentions_list = []
    mention_emb_list = []
    rep_emb_list = [-1] * len(representatives)

    # Number of tokens that precede the first selected segment.
    init_token_offset = sum(len(sentences[seg_idx]) for seg_idx in range(first_seg))
    token_offset = init_token_offset

    # Metadata such as document genre can be used by the model for clustering.
    metadata = self.get_metadata(document)

    ment_loss = None

    # Step 1: Predict all the mentions.
    for seg_idx in range(first_seg, last_seg):
        num_tokens = len(sentences[seg_idx])
        seg_start = token_offset
        seg_end = seg_start + num_tokens
        token_span = slice(seg_start, seg_end)
        # Every path through this iteration moves on to the next segment,
        # so the running offset is advanced once, up front.
        token_offset = seg_end

        representatives_entities, rep_filtered_inds = (
            self.get_filtered_representatives(
                representatives, seg_start, seg_end, with_offset=False
            )
        )
        cur_doc_slice = {
            "tensorized_sent": document["tensorized_sent"][seg_idx],
            "sentence_map": document["sentence_map"][token_span],
            "subtoken_map": document["subtoken_map"][token_span],
            "sent_len_list": [document["sent_len_list"][seg_idx]],
            "clusters": self.get_filtered_clusters(
                document["clusters"], seg_start, seg_end, with_offset=False
            ),
            "representatives": representatives_entities,
            "doc_key": document["doc_key"],
        }

        # With gold mentions enabled, a segment without any gold mention
        # leaves nothing to do.
        if (
            len(cur_doc_slice["clusters"]) == 0
            and self.mention_proposer.config.mention_params.use_gold_ments
        ):
            continue

        proposer_output_dict = self.mention_proposer(cur_doc_slice, eval_loss=True)

        # The mention loss is collected before the "no mentions" check so that
        # segments without predicted mentions still contribute to it.
        if "ment_loss" in proposer_output_dict:
            if ment_loss is None:
                ment_loss = proposer_output_dict["ment_loss"]
            else:
                ment_loss += proposer_output_dict["ment_loss"]

        seg_mentions = proposer_output_dict.get("ments", None)
        if seg_mentions is None:
            continue

        # Shift segment-local mention indices to document-level indices.
        cur_pred_mentions = seg_mentions + seg_start
        pred_mentions_list.extend(cur_pred_mentions.tolist())
        mention_emb_list.extend(proposer_output_dict["ment_emb_list"])

        for key in ("ment_correct", "ment_total", "ment_tp", "ment_pp", "ment_ap"):
            if key in proposer_output_dict:
                loss_dict[key] += proposer_output_dict[key]

        # Store each representative embedding at its document-level position.
        for ind, rep_ind in enumerate(rep_filtered_inds):
            rep_emb_list[rep_ind] = proposer_output_dict["rep_emb_list"][ind]

    # Collect the mention detection loss. It is always reported, but only
    # contributes to the total when the training config asks for it.
    if ment_loss is not None:
        if self.train_config.ment_loss_incl:
            loss_dict["total"] = ment_loss
        loss_dict["ment_loss"] = ment_loss

    # Step 2: Perform clustering.
    # Optionally hide a random subset of the available representatives.
    if self.train_config.get("generalise", False):
        rep_emb_list = self.mask_representative_phrases(rep_emb_list)

    # One flag per representative: true when its embedding is available.
    entities_mask = [not isinstance(rep_emb, int) for rep_emb in rep_emb_list]
    rep_emb_list_filtered = [
        rep_emb for rep_emb, available in zip(rep_emb_list, entities_mask) if available
    ]
    # Extra flag for the final cluster, which holds the mentions that do not
    # belong to any representative phrase.
    entities_mask.append(True)

    # Clusters restricted to the processed part of the document.
    truncated_document_clusters = {
        "clusters": self.get_filtered_clusters(
            document["clusters"],
            init_token_offset,
            token_offset,
            cluster_mask=entities_mask,
        )
    }
    # NOTE(review): this re-checks the untruncated document, exactly like the
    # assertion at the top of the method, although the message talks about
    # the segmented clusters - confirm what was meant to be validated.
    assert (
        len(document["clusters"]) == len(document["representatives"]) + 1
    ), "Number of clusters and representatives after segmentation do not match."

    # Ground truth clustering actions for the predicted mentions.
    gt_actions: List[Tuple[int, str]] = get_gt_actions(
        pred_mentions_list, truncated_document_clusters, self.config.memory.mem_type
    )
    pred_mentions = torch.tensor(pred_mentions_list, device=self.device)

    # Without any representative embedding there is nothing to rank against,
    # hence no coreference loss.
    if len(rep_emb_list_filtered) == 0:
        return loss_dict

    coref_new_list = self.memory_net.forward_training(
        pred_mentions, mention_emb_list, rep_emb_list_filtered, gt_actions, metadata
    )
    if len(coref_new_list) > 0:
        coref_loss = self.calculate_coref_loss(coref_new_list, gt_actions)
        loss_dict["total"] = loss_dict["total"] + coref_loss
        loss_dict["coref"] = coref_loss
        loss_dict["mention_count"] += torch.tensor(len(coref_new_list))
    return loss_dict
def forward(self, document: Dict, teacher_force=False, gold_mentions=False):
    """Forward pass of the streaming coreference model.

    This method performs streaming coreference. The entity clusters from previous
    document chunks are represented as vectors and passed along to the processing
    of subsequent chunks along with the metadata associated with these clusters.

    Args:
        document (Dict): Tensorized document. Keys read here: ``sentences``,
            ``representatives``, ``clusters``, ``tensorized_sent``,
            ``sentence_map``, ``subtoken_map``, ``sent_len_list`` and, when
            present, ``ext_predicted_mentions``. All remaining keys are
            passed through to the mention proposer unchanged.
        teacher_force (bool): Forwarded to the memory network.
        gold_mentions (bool): Forwarded to the mention proposer. When true,
            segments without any gold cluster are skipped.

    Returns:
        A tuple of eight elements:

        pred_mentions_list (List): Mentions predicted by the mention proposal module,
            with document-level token indices.
        pred_mention_emb_list (List): Embeddings of the predicted mentions,
            converted with ``tolist()``.
        mention_scores (List): Scores assigned by the mention proposal module for
            the predicted mentions.
        gt_actions (List): Ground truth clustering actions; useful for calculating oracle performance.
        pred_actions (List): Actions predicted by the clustering module for the predicted mentions.
        coref_scores_doc (List): Scores returned by the clustering module,
            each converted with ``tolist()``.
        entity_cluster_states: Final state returned by the clustering module
            with every value converted with ``tolist()``, or ``None`` when the
            clustering module never ran.
        link_time (float): Seconds spent in the clustering step, summed over chunks.
    """
    assert document["representatives"] == sorted(
        document["representatives"]
    ), "Representatives are not sorted."
    print("Device: ", self.device)
    print("#" * 40)

    # Accumulators over all chunks of the document.
    pred_mentions_list, pred_mention_emb_list = [], []
    mention_scores, pred_actions, coref_scores_doc = [], [], []

    # Entity clusters carried across chunks and current document token offset.
    entity_cluster_states, token_offset = None, 0
    metadata = self.get_metadata(document)
    link_time = 0.0

    for idx, sentence in enumerate(document["sentences"]):
        num_tokens = len(sentence)
        seg_start = token_offset
        seg_end = seg_start + num_tokens
        # Every path through this iteration moves on to the next chunk, so
        # the running offset is advanced once, up front.
        token_offset = seg_end

        new_representatives_entities, _ = self.get_filtered_representatives(
            document["representatives"], seg_start, seg_end, with_offset=False
        )
        ext_predicted_mentions_filt, _ = self.get_filtered_representatives(
            document.get("ext_predicted_mentions", []),
            seg_start,
            seg_end,
            with_offset=False,
        )
        cur_example = {
            "tensorized_sent": document["tensorized_sent"][idx],
            "sentence_map": document["sentence_map"][seg_start:seg_end],
            "subtoken_map": document["subtoken_map"][seg_start:seg_end],
            "sent_len_list": [document["sent_len_list"][idx]],
            "clusters": self.get_filtered_clusters(
                document["clusters"], seg_start, seg_end, with_offset=False
            ),
            "representatives": new_representatives_entities,
            "ext_predicted_mentions": ext_predicted_mentions_filt,
        }
        # Pass along other metadata without overriding the chunk-level keys.
        for key in document:
            if key not in cur_example:
                cur_example[key] = document[key]

        # In gold-mention mode a chunk without gold clusters has no mentions.
        if len(cur_example["clusters"]) == 0 and (
            self.mention_proposer.config.mention_params.use_gold_ments
            or gold_mentions
        ):
            continue

        proposer_output_dict = self.mention_proposer(
            cur_example, gold_mentions=gold_mentions
        )
        if proposer_output_dict.get("ments", None) is None:
            continue

        # Add the document offset to mentions predicted for the current chunk.
        # It's important to add the offset before clustering because features like
        # number of tokens between the last mention of the cluster and the current mention
        # will be affected if the current token indices of the mention are not supplied.
        cur_pred_mentions = proposer_output_dict.get("ments") + seg_start
        cur_mentions_as_list = cur_pred_mentions.tolist()
        pred_mentions_list.extend(cur_mentions_as_list)

        # Ground truth actions are computed over all mentions seen so far;
        # only those of the current chunk are handed to the memory network.
        gt_actions_full: List[Tuple[int, str]] = get_gt_actions(
            pred_mentions_list, document, self.config.memory.mem_type
        )
        num_new_mentions = len(cur_mentions_as_list)
        # A bare [-0:] slice would return the whole history instead of
        # nothing, so the empty case is handled explicitly.
        gt_actions = gt_actions_full[-num_new_mentions:] if num_new_mentions else []

        pred_mention_emb_list.extend(
            [emb.tolist() for emb in proposer_output_dict.get("ment_emb_list")]
        )
        mention_scores.extend(proposer_output_dict["ment_scores"].tolist())

        # perf_counter is monotonic, unlike wall-clock time.
        start_time = time.perf_counter()
        repr_candidates = list(proposer_output_dict["rep_emb_list"])
        # Pass along entity clusters from previous chunks while processing next chunks.
        cur_pred_actions, entity_cluster_states, coref_scores_list = (
            self.memory_net(
                cur_pred_mentions,
                list(proposer_output_dict["ment_emb_list"]),
                repr_candidates,
                gt_actions,
                metadata,
                teacher_force=teacher_force,
                memory_init=entity_cluster_states,
            )
        )
        link_time += time.perf_counter() - start_time

        pred_actions.extend(cur_pred_actions)
        coref_scores_doc.extend(coref_scores_list)

    gt_actions = get_gt_actions(
        pred_mentions_list, document, self.config.memory.mem_type
    )  # Useful for oracle calcs

    # Convert the remaining outputs to plain Python lists.
    coref_scores_doc = [scores.tolist() for scores in coref_scores_doc]
    if entity_cluster_states is not None:
        for key in entity_cluster_states:
            entity_cluster_states[key] = entity_cluster_states[key].tolist()

    return (
        pred_mentions_list,
        pred_mention_emb_list,
        mention_scores,
        gt_actions,
        pred_actions,
        coref_scores_doc,
        entity_cluster_states,
        link_time,
    )