Spaces:
Running
Running
| import torch | |
| from model.memory import BaseMemory | |
| from pytorch_utils.modules import MLP | |
| import torch.nn as nn | |
| from omegaconf import DictConfig | |
| from typing import Dict, Tuple, List | |
| from torch import Tensor | |
| from tqdm import tqdm | |
| import math | |
class EntityMemory(BaseMemory):
    """Module for clustering proposed mention spans using Entity-Ranking paradigm."""

    def __init__(
        self, config: DictConfig, span_emb_size: int, drop_module: nn.Module
    ) -> None:
        """Initialise the base memory and record the configured memory type.

        Args:
            config: Memory configuration; ``config.mem_type`` is read here.
            span_emb_size: Forwarded unchanged to ``BaseMemory.__init__``.
            drop_module: Forwarded unchanged to ``BaseMemory.__init__``.
        """
        super().__init__(config, span_emb_size, drop_module)
        self.mem_type: DictConfig = config.mem_type

    def forward_training(
        self,
        ment_boundaries: Tensor,
        mention_emb_list: List[Tensor],
        rep_emb_list: List[Tensor],
        gt_actions: List[Tuple[int, str]],
        metadata: Dict,
    ) -> List[Tensor]:
        """Forward pass during coreference model training (teacher forcing).

        Args:
            ment_boundaries: Mention boundaries of proposed mentions; row ``i``
                unpacks into ``(start, end)`` for mention ``i``.
            mention_emb_list: Embedding list of proposed mentions.
            rep_emb_list: Entity representations used to initialise the memory.
                Must be non-empty.
            gt_actions: Ground truth clustering actions as ``(cell_idx, action)``
                pairs, aligned with ``mention_emb_list``.
            metadata: Metadata such as document genre.

        Returns:
            coref_new_list: Scores returned by ``get_coref_new_scores`` for each
            mention, in mention order.

        Raises:
            AssertionError: If ``rep_emb_list`` is empty.
        """
        assert (
            len(rep_emb_list) != 0
        ), "There are no entity representations, should not happen."

        # Initialize memory
        coref_new_list = []
        mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
            self.initialize_memory(rep=rep_emb_list)
        )

        for ment_idx, (ment_emb, (gt_cell_idx, gt_action_str)) in enumerate(
            zip(mention_emb_list, gt_actions)
        ):
            ment_start, _ = ment_boundaries[ment_idx]

            if self.config.num_feats != 0:
                feature_embs = self.get_feature_embs(
                    ment_start, last_mention_start, ent_counter, metadata
                )
            else:
                # No features configured: zero-width placeholder, one row per
                # memory cell.
                feature_embs = torch.empty(
                    mem_vectors.shape[0], 0, device=self.device
                )

            coref_new_scores = self.get_coref_new_scores(
                ment_emb, mem_vectors, mem_vectors_init, ent_counter, feature_embs
            )
            coref_new_list.append(coref_new_scores)

            # Teacher forcing: the memory follows the ground-truth action.
            # Update memory only if action is cluster and memory is not static.
            if gt_action_str == "c" and self.config.type != "static":
                # The mask is built only here, so mentions that do not update
                # the memory skip the `.item()` host sync.
                num_ents = int(torch.sum((ent_counter > 0).long()).item())
                cell_mask = (
                    torch.arange(start=0, end=num_ents, device=self.device)
                    == gt_cell_idx
                ).float()
                mask = cell_mask.unsqueeze(dim=1).repeat(1, self.mem_size)

                coref_vec = self.coref_update(
                    ment_emb, mem_vectors, gt_cell_idx, ent_counter
                )
                # Out-of-place blend: only the selected cell takes coref_vec,
                # and mem_vectors is rebound rather than modified in place.
                mem_vectors = mem_vectors * (1 - mask) + mask * coref_vec
                ent_counter[gt_cell_idx] = ent_counter[gt_cell_idx] + 1
                last_mention_start[gt_cell_idx] = ment_start

        return coref_new_list

    def forward(
        self,
        ment_boundaries: Tensor,
        mention_emb_list: List[Tensor],
        rep_emb_list: List[Tensor],
        gt_actions: List[Tuple[int, str]],
        metadata: Dict,
        teacher_force: bool = False,
        memory_init=None,
    ) -> Tuple[List, Dict, List[Tensor]]:
        """Forward pass for clustering entity mentions during inference/evaluation.

        Args:
            ment_boundaries: Start and end token indices for the proposed mentions.
            mention_emb_list: Embedding list of proposed mentions.
            rep_emb_list: Entity representations passed to ``initialize_memory``.
            gt_actions: Ground truth ``(cell_idx, action)`` pairs; must have the
                same length as ``mention_emb_list``.
            metadata: Metadata features such as document genre embedding.
            teacher_force: If True, the ground truth actions are recorded (and,
                for non-static memory, applied) instead of the predicted ones.
            memory_init: Initializer for memory. For streaming coreference, we can
                pass the previous memory state via this dictionary; its entries
                are forwarded as keyword arguments to ``initialize_memory``.

        Returns:
            pred_actions: List of recorded clustering actions.
            mem_state: Current memory state with keys ``"mem"``, ``"mem_init"``,
                ``"ent_counter"`` and ``"last_mention_start"``.
            coref_scores_list: Detached CPU copies of the scores computed for
                each scored mention.

        Raises:
            AssertionError: If ``mention_emb_list`` and ``gt_actions`` differ in
                length.
        """
        # Check length of mention_emb_list == gt_actions
        assert len(mention_emb_list) == len(gt_actions)

        # Initialize memory
        init_kwargs = {} if memory_init is None else memory_init
        mem_vectors, mem_vectors_init, ent_counter, last_mention_start = (
            self.initialize_memory(**init_kwargs, rep=rep_emb_list)
        )

        pred_actions = []  # argmax actions
        coref_scores_list = []

        if self.config.type == "static":
            # Tensorized approach for static method: the memory is never updated
            # in this branch, so mentions are scored in batches.
            batch_size = self.config.batch_size
            num_mentions = len(mention_emb_list)

            for start_idx in range(0, num_mentions, batch_size):
                end_idx = min(start_idx + batch_size, num_mentions)

                if ent_counter.numel() == 0:
                    # Empty memory: nothing to score against. Record one
                    # (0, "o") action per mention in the batch.
                    pred_actions.extend([(0, "o")] * (end_idx - start_idx))
                    continue

                ment_emb_tensor = torch.stack(
                    mention_emb_list[start_idx:end_idx], dim=0
                )
                ment_start = ment_boundaries[start_idx:end_idx, 0]

                if self.config.num_feats != 0:
                    feature_embs = self.get_feature_embs_tensorized(
                        ment_start, last_mention_start, ent_counter, metadata
                    )
                else:
                    # Zero-width placeholder: (batch, memory cells, 0).
                    feature_embs = torch.empty(
                        ment_start.shape[0],
                        mem_vectors.shape[0],
                        0,
                        device=self.device,
                    )

                coref_new_scores = self.get_coref_new_scores_tensorized(
                    ment_emb_tensor,
                    mem_vectors,
                    mem_vectors_init,
                    ent_counter,
                    feature_embs,
                )
                # Iterating the batched score tensor yields one entry per mention.
                coref_scores_list.extend(coref_new_scores.clone().detach().cpu())

                assigned_cluster = self.assign_cluster_tensorized(coref_new_scores)
                if teacher_force:
                    pred_actions.extend(gt_actions[start_idx:end_idx])
                else:
                    pred_actions.extend(assigned_cluster)
        else:
            for ment_idx, ment_emb in enumerate(mention_emb_list):
                if ent_counter.numel() == 0:
                    # Empty memory: nothing to score against.
                    pred_actions.append((0, "o"))
                    continue

                ment_start, _ = ment_boundaries[ment_idx]

                if self.config.num_feats != 0:
                    feature_embs = self.get_feature_embs(
                        ment_start, last_mention_start, ent_counter, metadata
                    )
                else:
                    feature_embs = torch.empty(
                        mem_vectors.shape[0], 0, device=self.device
                    )

                coref_new_scores = self.get_coref_new_scores(
                    ment_emb, mem_vectors, mem_vectors_init, ent_counter, feature_embs
                )
                coref_scores_list.append(coref_new_scores.clone().detach().cpu())

                pred_cell_idx, pred_action_str = self.assign_cluster(coref_new_scores)

                if teacher_force:
                    next_cell_idx, next_action_str = gt_actions[ment_idx]
                    pred_actions.append(gt_actions[ment_idx])
                else:
                    next_cell_idx, next_action_str = pred_cell_idx, pred_action_str
                    pred_actions.append((pred_cell_idx, pred_action_str))

                if next_action_str == "c":
                    # Cluster action: fold the mention into the chosen cell.
                    coref_vec = self.coref_update(
                        ment_emb, mem_vectors, next_cell_idx, ent_counter
                    )
                    mem_vectors[next_cell_idx] = coref_vec
                    ent_counter[next_cell_idx] = ent_counter[next_cell_idx] + 1
                    last_mention_start[next_cell_idx] = ment_start

        mem_state = {
            "mem": mem_vectors,
            "mem_init": mem_vectors_init,
            "ent_counter": ent_counter,
            "last_mention_start": last_mention_start,
        }
        return pred_actions, mem_state, coref_scores_list