""" Describes PairwiseEncodes, that transforms pairwise features, such as distance between the mentions, same/different speaker into feature embeddings """ from typing import List import torch from stanza.models.coref.config import Config from stanza.models.coref.const import Doc class PairwiseEncoder(torch.nn.Module): """ A Pytorch module to obtain feature embeddings for pairwise features Usage: encoder = PairwiseEncoder(config) pairwise_features = encoder(pair_indices, doc) """ def __init__(self, config: Config): super().__init__() emb_size = config.embedding_size self.genre2int = {g: gi for gi, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} self.genre_emb = torch.nn.Embedding(len(self.genre2int), emb_size) # each position corresponds to a bucket: # [(0, 2), (2, 3), (3, 4), (4, 5), (5, 8), # (8, 16), (16, 32), (32, 64), (64, float("inf"))] self.distance_emb = torch.nn.Embedding(9, emb_size) # two possibilities: same vs different speaker self.speaker_emb = torch.nn.Embedding(2, emb_size) self.dropout = torch.nn.Dropout(config.dropout_rate) self.__full_pw = config.full_pairwise if self.__full_pw: self.shape = emb_size * 3 # genre, distance, speaker\ else: self.shape = emb_size # distance only @property def device(self) -> torch.device: """ A workaround to get current device (which is assumed to be the device of the first parameter of one of the submodules) """ return next(self.genre_emb.parameters()).device def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch top_indices: torch.Tensor, doc: Doc) -> torch.Tensor: word_ids = torch.arange(0, len(doc["cased_words"]), device=self.device) # bucketing the distance (see __init__()) distance = (word_ids.unsqueeze(1) - word_ids[top_indices] ).clamp_min_(min=1) log_distance = distance.to(torch.float).log2().floor_() log_distance = log_distance.clamp_max_(max=6).to(torch.long) distance = torch.where(distance < 5, distance - 1, log_distance + 2) distance = self.distance_emb(distance) if not self.__full_pw: return self.dropout(distance) # calculate speaker embeddings speaker_map = torch.tensor(self._speaker_map(doc), device=self.device) same_speaker = (speaker_map[top_indices] == speaker_map.unsqueeze(1)) same_speaker = self.speaker_emb(same_speaker.to(torch.long)) # if there is no genre information, use "wb" as the genre (which is what the # Pipeline does genre = torch.tensor(self.genre2int.get(doc["document_id"][:2], self.genre2int["wb"]), device=self.device).expand_as(top_indices) genre = self.genre_emb(genre) return self.dropout(torch.cat((same_speaker, distance, genre), dim=2)) @staticmethod def _speaker_map(doc: Doc) -> List[int]: """ Returns a tensor where i-th element is the speaker id of i-th word. """ # if speaker is not found in the doc, simply return "speaker#1" for all the speakers # and embed them using the same ID # speaker string -> speaker id str2int = {s: i for i, s in enumerate(set(doc.get("speaker", ["speaker#1" for _ in range(len(doc["deprel"]))])))} # word id -> speaker id return [str2int[s] for s in doc.get("speaker", ["speaker#1" for _ in range(len(doc["deprel"]))])]