stanza-digphil / stanza /models /coref /pairwise_encoder.py
Albin Thörn Cleland
Clean initial commit with LFS
19b8775
""" Describes PairwiseEncoder, which transforms pairwise features, such as
the distance between mentions and same/different speaker, into feature
embeddings.
"""
from typing import List
import torch
from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc
class PairwiseEncoder(torch.nn.Module):
    """ A Pytorch module to obtain feature embeddings for pairwise features

    The bucketed distance between a word and each of its candidate
    antecedents is always embedded. When ``config.full_pairwise`` is set,
    a same/different-speaker embedding and a genre embedding are
    concatenated to it.

    Usage:
        encoder = PairwiseEncoder(config)
        pairwise_features = encoder(pair_indices, doc)
    """

    # Genre codes looked up from the first two characters of the document id.
    _GENRES = ("bc", "bn", "mz", "nw", "pt", "tc", "wb")
    # Genre used when the document id is missing or has an unknown prefix.
    _DEFAULT_GENRE = "wb"
    # Speaker label given to every word when the doc has no speaker info.
    _DEFAULT_SPEAKER = "speaker#1"

    def __init__(self, config: "Config"):
        """ Builds the embedding tables.

        Reads ``embedding_size``, ``dropout_rate`` and ``full_pairwise``
        from ``config``. All three embedding tables are created regardless
        of ``full_pairwise``; the flag only controls which ones forward()
        uses.
        """
        super().__init__()
        emb_size = config.embedding_size

        self.genre2int = {genre: idx for idx, genre in enumerate(self._GENRES)}
        self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size)

        # each position corresponds to a bucket:
        #   [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8),
        #    (8, 16), (16, 32), (32, 64), (64, float("inf"))]
        self.distance_emb = torch.nn.Embedding(9, emb_size)

        # two possibilities: same vs different speaker
        self.speaker_emb = torch.nn.Embedding(2, emb_size)

        self.dropout = torch.nn.Dropout(config.dropout_rate)

        self.__full_pw = config.full_pairwise

        # size of the last dimension of the tensor returned by forward()
        if self.__full_pw:
            self.shape = emb_size * 3  # genre, distance, speaker
        else:
            self.shape = emb_size      # distance only

    @property
    def device(self) -> torch.device:
        """ A workaround to get current device (which is assumed to be the
        device of the first parameter of one of the submodules) """
        return next(self.genre_emb.parameters()).device

    def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                top_indices: torch.Tensor,
                doc: "Doc") -> torch.Tensor:
        """ Embeds the pairwise features of each (word, candidate) pair.

        Args:
            top_indices: integer tensor of word indices used to index the
                word positions; it is broadcast against a column of word
                positions, so it is expected to be 2-D with one row per
                word of the document.
            doc: the document; ``doc["cased_words"]`` gives the number of
                words. In full pairwise mode ``doc["speaker"]`` and
                ``doc["document_id"]`` are used when present.

        Returns:
            The embeddings after dropout, with one extra trailing dimension
            of size ``self.shape`` added to the shape of ``top_indices``.
            In full pairwise mode the trailing dimension is laid out as
            (speaker, distance, genre).
        """
        word_ids = torch.arange(0, len(doc["cased_words"]), device=self.device)

        # bucketing the distance (see __init__())
        # non-positive distances are clamped to 1, i.e. fall into bucket 0
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                    ).clamp_min_(min=1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        # distances 1..4 get their own buckets 0..3, larger ones are
        # bucketed by floor(log2(distance)) into buckets 4..8
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
        distance = self.distance_emb(distance)

        if not self.__full_pw:
            return self.dropout(distance)

        # calculate speaker embeddings
        speaker_map = torch.tensor(self._speaker_map(doc), device=self.device)
        same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1))
        same_speaker = self.speaker_emb(same_speaker.to(torch.long))

        # if there is no genre information (missing document id or unknown
        # prefix), use "wb" as the genre (which is what the Pipeline does)
        document_id = doc.get("document_id") or ""
        genre_id = self.genre2int.get(document_id[:2],
                                      self.genre2int[self._DEFAULT_GENRE])
        genre = torch.tensor(genre_id, device=self.device).expand_as(top_indices)
        genre = self.genre_emb(genre)

        return self.dropout(torch.cat((same_speaker, distance, genre), dim=2))

    @staticmethod
    def _speaker_map(doc: "Doc") -> List[int]:
        """
        Returns a list where i-th element is the speaker id of i-th word.

        Speaker ids are assigned in order of first appearance, so the
        result is deterministic for a given document.
        """
        speakers = doc.get("speaker")
        if speakers is None:
            # if speaker is not found in the doc, simply use "speaker#1" for
            # all the words and embed them using the same ID; the length is
            # taken from the words, falling back to "deprel" for docs that
            # do not carry "cased_words"
            per_word = doc["cased_words"] if "cased_words" in doc else doc["deprel"]
            speakers = [PairwiseEncoder._DEFAULT_SPEAKER] * len(per_word)

        # speaker string -> speaker id
        str2int = {s: i for i, s in enumerate(dict.fromkeys(speakers))}

        # word id -> speaker id
        return [str2int[s] for s in speakers]