stanza-digphil / stanza /models /coref /span_predictor.py
Albin Thörn Cleland
Clean initial commit with LFS
19b8775
""" Describes SpanPredictor which aims to predict spans by taking as input
head word and context embeddings.
"""
from typing import List, Optional, Tuple
from stanza.models.coref.const import Doc, Span
import torch
class SpanPredictor(torch.nn.Module):
    """ Scores candidate span start/end words for each span head word.

    For every head, only the words in the same sentence as the head are
    scored. Each (head, candidate) pair is represented by the concatenation
    of the head embedding, the candidate embedding and a learned embedding of
    their relative distance. That vector goes through a feed-forward network
    and then two 1D convolutions run along the candidates of each head, which
    produce one start score and one end score per candidate.
    """

    def __init__(self, input_size: int, distance_emb_size: int):
        """
        Args:
            input_size (int): size of each contextual word embedding
            distance_emb_size (int): size of the relative-distance embedding
        """
        super().__init__()
        # Input is [head_emb; candidate_emb; distance_emb], so the distance
        # part must match the embedding size actually used by self.emb
        # (previously hard-coded as 64, which broke any other size).
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + distance_emb_size, input_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(input_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, 64),
        )
        # Mixes the scores of neighbouring candidates; final 2 channels are
        # the start and end scores.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1)
        )
        self.emb = torch.nn.Embedding(128, distance_emb_size)  # [-63, 63] + too_far

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.ffnn.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                doc: Doc,
                words: torch.Tensor,
                heads_ids: torch.Tensor) -> torch.Tensor:
        """
        Calculates span start/end scores of words for each span head in
        heads_ids.

        Words outside the head's sentence get -inf. In eval mode, starts
        after the head and ends before the head also get -inf, so that
        start <= head <= end is enforced.

        Args:
            doc (Doc): the document data; doc["sent_id"] gives the sentence
                index of every word
            words (torch.Tensor): contextual embeddings for each word in the
                document, [n_words, emb_size]
            heads_ids (torch.Tensor): word indices of span heads (integer)

        Returns:
            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
        """
        n_words = words.shape[0]
        n_heads = heads_ids.shape[0]
        if n_heads == 0:
            # Nothing to score; lengths.max() below would fail on an empty tensor
            return torch.empty((0, n_words, 2), dtype=words.dtype, device=words.device)

        # Obtain distance embedding indices, [n_heads, n_words]
        # relative_positions[h, w] = head_index - word_index
        relative_positions = (heads_ids.unsqueeze(1)
                              - torch.arange(n_words, device=words.device).unsqueeze(0))
        emb_ids = relative_positions + 63  # make all valid distances positive
        emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127  # "too_far"

        # Obtain "same sentence" boolean mask, [n_heads, n_words]
        sent_id = torch.tensor(doc["sent_id"], device=words.device)
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))

        # To save memory, only pass candidates from one sentence for each head.
        # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
        # for each candidate among the words in the same sentence as span_head
        # [n_pairs, input_size * 2 + distance_emb_size]
        # nonzero() enumerates pairs in row-major order, which is the same
        # order in which padding_mask's True entries are filled below.
        rows, cols = same_sent.nonzero(as_tuple=True)
        pair_matrix = torch.cat((
            words[heads_ids[rows]],
            words[cols],
            self.emb(emb_ids[rows, cols]),
        ), dim=1)

        lengths = same_sent.sum(dim=1)
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]

        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
        # This is necessary to allow the convolution layer to look at several
        # word scores. The dtype matches pair_matrix so the masked assignment
        # works under reduced precision as well.
        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1],
                                   dtype=pair_matrix.dtype, device=words.device)
        padded_pairs[padding_mask] = pair_matrix

        res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
        res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]

        scores = torch.full((n_heads, n_words, 2), float('-inf'),
                            dtype=res.dtype, device=words.device)
        scores[rows, cols] = res[padding_mask]

        # Make sure that start <= head <= end during inference
        if not self.training:
            # start invalid when word > head; end invalid when word < head
            invalid = torch.stack((relative_positions < 0, relative_positions > 0), dim=2)
            scores = scores.masked_fill(invalid, float('-inf'))
        return scores

    def get_training_data(self,
                          doc: Doc,
                          words: torch.Tensor
                          ) -> Tuple[Optional[torch.Tensor],
                                     Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """ Returns span starts/ends for gold mentions in the document.

        Args:
            doc (Doc): the document data; doc["head2span"] is a collection
                of (head, start, end) triples with exclusive end
            words (torch.Tensor): [n_words, emb_size] word embeddings

        Returns:
            (scores, (starts, ends)), where scores is the output of forward()
            for the gold heads (sorted by head). ends is converted to
            inclusive indices. Returns (None, None) if there are no gold
            mentions.
        """
        head2span = sorted(doc["head2span"])
        if not head2span:
            return None, None
        heads, starts, ends = zip(*head2span)
        heads = torch.tensor(heads, dtype=torch.long, device=self.device)
        starts = torch.tensor(starts, dtype=torch.long, device=self.device)
        ends = torch.tensor(ends, dtype=torch.long, device=self.device) - 1  # exclusive -> inclusive
        return self(doc, words, heads), (starts, ends)

    def predict(self,
                doc: Doc,
                words: torch.Tensor,
                clusters: List[List[int]]) -> List[List[Span]]:
        """
        Predicts span clusters based on the word clusters.

        Args:
            doc (Doc): the document data
            words (torch.Tensor): [n_words, emb_size] matrix containing
                embeddings for each of the words in the text
            clusters (List[List[int]]): a list of clusters where each cluster
                is a list of word indices

        Returns:
            List[List[Span]]: span clusters; each span is a (start, end) pair
            with an exclusive end, in the same order as the input clusters
        """
        if not clusters:
            return []

        # dtype is explicit so an all-empty input does not produce a float tensor
        heads_ids = torch.tensor(
            sorted(i for cluster in clusters for i in cluster),
            dtype=torch.long,
            device=self.device
        )
        if heads_ids.numel() == 0:
            return [[] for _ in clusters]

        scores = self(doc, words, heads_ids)
        starts = scores[:, :, 0].argmax(dim=1).tolist()
        ends = (scores[:, :, 1].argmax(dim=1) + 1).tolist()  # inclusive -> exclusive

        head2span = {
            head: (start, end)
            for head, start, end in zip(heads_ids.tolist(), starts, ends)
        }

        return [[head2span[head] for head in cluster]
                for cluster in clusters]