|
|
""" Describes SpanPredictor which aims to predict spans by taking as input |
|
|
head word and context embeddings. |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
from stanza.models.coref.const import Doc, Span |
|
|
import torch |
|
|
|
|
|
|
|
|
class SpanPredictor(torch.nn.Module):
    """ Predicts, for each span head word, which words of the same sentence
    are the start and end of the span headed by it.

    Scores every (head, word) pair within a sentence using a feed-forward
    network over [head embedding; word embedding; head-to-word distance
    embedding], followed by two 1D convolutions over the word dimension
    that emit a (start, end) score pair.
    """

    def __init__(self, input_size: int, distance_emb_size: int):
        """
        Args:
            input_size (int): size of the contextual word embeddings.
            distance_emb_size (int): size of the head-to-word distance
                embedding produced by self.emb.
        """
        super().__init__()
        # Pair features: [head_emb; word_emb; distance_emb].
        # NOTE: the original hard-coded 64 here, which only worked when
        # distance_emb_size == 64; use the parameter so the sizes agree
        # with self.emb below (identical behavior for the default 64).
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + distance_emb_size, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        # Convolutions mix information across neighboring words
        # (kernel 3, stride 1, padding 1 keeps the length) and reduce the
        # 64 pair features to 2 channels: start score and end score.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        # 127 buckets for clipped relative positions in [-63, 63]
        # (shifted to [0, 126]); index 127 is the out-of-range bucket.
        self.emb = torch.nn.Embedding(128, distance_emb_size)

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self,
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids

        Args:
            doc (Doc): the document data
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2];
                pairs outside the head's sentence are -inf
        """
        # Signed distance from each head to each word, [n_heads, n_words]
        relative_positions = (heads_ids.unsqueeze(1)
                              - torch.arange(words.shape[0],
                                             device=words.device).unsqueeze(0))
        # Shift to [0, 126]; anything outside goes to the overflow bucket 127.
        # Use `|` (not `+`): addition of bool tensors is deprecated in torch.
        emb_ids = relative_positions + 63
        emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127

        # Only (head, word) pairs within the same sentence are candidates
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        if not same_sent.any():
            # No candidate pairs (e.g. heads_ids is empty); the original
            # code would crash on lengths.max() below, so return the
            # all-masked score tensor directly.
            return torch.full((heads_ids.shape[0], words.shape[0], 2),
                              float('-inf'), device=words.device)

        # Flattened candidate pairs:
        # pair_matrix is [n_pairs, 2 * emb_size + distance_emb_size]
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        # Pad each head's candidate list to the longest sentence length
        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(),
                                    device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, pair_emb]; match pair_matrix's dtype so
        # the scatter below is exact under mixed precision
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1],
                                   device=words.device, dtype=pair_matrix.dtype)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, max_sent_len, 64]
        # Conv1d wants channels in dim 1: permute in, convolve, permute back
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, max_sent_len, 2]

        # Scatter the per-sentence scores back to full-document positions;
        # out-of-sentence pairs stay at -inf
        scores = torch.full((heads_ids.shape[0], words.shape[0], 2),
                            float('-inf'), device=words.device)
        scores[rows, cols] = res[padding_mask]

        if not self.training:
            # At inference time, mask positions that cannot be valid:
            # a span start is never after its head, an end never before it.
            # log(0) = -inf, log(1) = 0, so adding these masks the scores.
            valid_starts = torch.log((relative_positions >= 0).to(torch.float))
            valid_ends = torch.log((relative_positions <= 0).to(torch.float))
            valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
            return scores + valid_positions
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span start/end scores and gold starts/ends for the gold
        mentions in the document, or (None, None) if there are none.

        Args:
            doc (Doc): the document data; "head2span" maps heads to gold spans
            words (torch.Tensor): [n_words, emb_size] contextual embeddings

        Returns:
            Tuple of (scores [n_heads, n_words, 2], (gold starts, gold ends));
            gold ends are made inclusive (span end - 1) to match the scores.
        """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, device=self.device)
        starts = torch.tensor(starts, device=self.device)
        # Stored span ends are exclusive; scores index inclusive positions
        ends = torch.tensor(ends, device=self.device) - 1
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters; each span is (start, end) with
                an exclusive end index
        """
        if not clusters:
            return []

        # Score every head once, even if it appears in several clusters
        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            device=self.device
        )

        scores = self(doc, words, heads_ids)
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        # argmax gives the inclusive end; +1 makes it exclusive
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]
|
|
|