""" Describes SpanPredictor which aims to predict spans by taking as input
head word and context embeddings.
"""

from typing import List, Optional, Tuple

from stanza.models.coref.const import Doc, Span
import torch


class SpanPredictor(torch.nn.Module):
    """ Scores each word as a possible span start / span end for a set of
    span head words.

    For every head, the candidates are the words that share the head's
    sentence id. Each (head, candidate) pair is described by the head
    embedding, the candidate embedding and an embedding of their relative
    position; the pairs go through a feed-forward network and then through
    1d convolutions over the candidate axis, yielding two scores per pair
    (start, end).
    """

    # Relative positions (head index - word index) in [-63, 63] are shifted
    # by +63 into embedding ids [0, 126]; everything outside that window
    # shares the single "too far" id 127, hence 128 embedding rows.
    _MAX_DISTANCE = 63
    _TOO_FAR_ID = 2 * _MAX_DISTANCE + 1
    _N_DISTANCE_IDS = _TOO_FAR_ID + 1

    def __init__(self, input_size: int, distance_emb_size: int):
        """
        Args:
            input_size (int): size of a single word embedding
            distance_emb_size (int): size of the relative position embedding
        """
        super().__init__()
        self.ffnn = torch.nn.Sequential(
            # The input is head_emb + candidate_emb + distance_emb, so its
            # width has to follow distance_emb_size (it used to be hard-coded
            # to 64, which broke any other distance_emb_size).
            torch.nn.Linear(input_size * 2 + distance_emb_size, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        # kernel_size=3, stride=1, padding=1: sequence length is preserved
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(self._N_DISTANCE_IDS, distance_emb_size)  # [-63, 63] + too_far

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data; only doc["sent_id"] is read here
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2].
                Words outside of the head's sentence are scored -inf.
                When the module is not in training mode, starts located
                after the head and ends located before the head are
                scored -inf as well.
        """
        n_words = words.shape[0]

        # Nothing to score; lengths.max() below would raise on an empty tensor
        if heads_ids.numel() == 0:
            return torch.empty((0, n_words, 2), device=words.device)

        # Obtain distance embedding indices, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1)
                              - torch.arange(n_words, device=words.device).unsqueeze(0))
        emb_ids = relative_positions + self._MAX_DISTANCE  # make all valid distances positive
        too_far = (emb_ids < 0) | (emb_ids >= self._TOO_FAR_ID)
        emb_ids[too_far] = self._TOO_FAR_ID                # "too_far"

        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        # To save memory, only pass candidates from one sentence for each head
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_pairs, input_size * 2 + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        # Number of candidates for each head
        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        # Conv1d expects channels before the sequence axis, hence the permutes
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]

        # Scatter the scores of real (non-padding) candidates back to their
        # word positions; everything else stays at -inf
        scores = torch.full((heads_ids.shape[0], n_words, 2), float('-inf'), device=words.device)
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        if not self.training:
            # log(1) = 0 keeps a score, log(0) = -inf rules the position out
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span starts/ends for gold mentions in the document.

        Args:
            doc (Doc): the document data; doc["head2span"] is read as a
                collection of (head, start, end) triples and doc["sent_id"]
                is used by the forward pass
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]

        Returns:
            (None, None) if doc["head2span"] is empty, otherwise a tuple of
            the span scores [n_heads, n_words, 2] computed for the gold heads
            (sorted) and a (starts, ends) pair of gold index tensors.
        """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, device=self.device)
        starts = torch.tensor(starts, device=self.device)
        # Stored ends are shifted by -1 so that they index the last word of
        # the span (presumably they are exclusive in the document data)
        ends = torch.tensor(ends, device=self.device) - 1
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters; each head word is replaced with
                a (start, end) pair, where end is the best scoring end word
                index plus one
        """
        if not clusters:
            return []

        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            device=self.device
        )

        scores = self(doc, words, heads_ids)
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]