from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch import Tensor, FloatTensor, BoolTensor, LongTensor
import torch.nn.functional as F

from transformers.activations import ACT2FN

from cobald_parser.bilinear_matrix_attention import BilinearMatrixAttention
from cobald_parser.chu_liu_edmonds import decode_mst
from cobald_parser.utils import pairwise_mask, replace_masked_values


class DependencyHeadBase(nn.Module):
    """
    Base class for scoring arcs and relations between tokens in a dependency tree/graph.

    Subclasses must implement `calc_arc_loss` and `predict_arcs`; relation scoring,
    relation loss and the final arc/relation combination are shared.
    """

    def __init__(self, hidden_size: int, n_rels: int):
        super().__init__()

        # A single-label biaffine scorer for arc existence...
        self.arc_attention = BilinearMatrixAttention(
            hidden_size,
            hidden_size,
            use_input_biases=True,
            n_labels=1
        )
        # ...and an `n_rels`-label biaffine scorer for arc relations.
        self.rel_attention = BilinearMatrixAttention(
            hidden_size,
            hidden_size,
            use_input_biases=True,
            n_labels=n_rels
        )

    def forward(
        self,
        h_arc_head: Tensor,        # [batch_size, seq_len, hidden_size]
        h_arc_dep: Tensor,         # ...
        h_rel_head: Tensor,        # ...
        h_rel_dep: Tensor,         # ...
        gold_arcs: LongTensor,     # [n_arcs, 4] or None
        null_mask: BoolTensor,     # [batch_size, seq_len]
        padding_mask: BoolTensor   # [batch_size, seq_len]
    ) -> dict[str, Tensor]:
        """
        Score arcs and relations, compute the loss (if gold arcs are given)
        and predict labeled arcs.

        Returns a dict with:
          - 'preds': [n_pred_arcs, 4] rows of (batch_idx, from_idx, to_idx, rel_type);
          - 'loss': arc loss + relation loss, or the plain float 0.0 when
            `gold_arcs` is None.
        """
        # Score arcs.
        # s_arc[:, i, j] = score of edge i -> j.
        s_arc = self.arc_attention(h_arc_head, h_arc_dep)
        # Mask undesirable values (padding, nulls, etc.) with a large negative value.
        # NOTE(review): the return value is not used, so `replace_masked_values`
        # is presumably in-place - confirm against cobald_parser.utils.
        mask2d = pairwise_mask(null_mask & padding_mask)
        replace_masked_values(s_arc, mask2d, replace_with=-1e8)

        # Score arcs' relations.
        # [batch_size, seq_len, seq_len, num_labels]
        s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)

        # Calculate loss.
        loss = 0.0
        if gold_arcs is not None:
            loss += self.calc_arc_loss(s_arc, gold_arcs)
            loss += self.calc_rel_loss(s_rel, gold_arcs)

        # Predict arcs based on the scores.
        # [batch_size, seq_len, seq_len]
        pred_arcs_matrix = self.predict_arcs(s_arc, null_mask, padding_mask)
        # [batch_size, seq_len, seq_len]
        pred_rels_matrix = self.predict_rels(s_rel)
        # [n_pred_arcs, 4]
        preds_combined = self.combine_arcs_rels(pred_arcs_matrix, pred_rels_matrix)
        return {
            'preds': preds_combined,
            'loss': loss
        }

    @staticmethod
    def calc_arc_loss(
        s_arc: Tensor,         # [batch_size, seq_len, seq_len]
        gold_arcs: LongTensor  # [n_arcs, 4]
    ) -> Tensor:
        """Calculate arc loss. Must be overridden by subclasses."""
        raise NotImplementedError

    @staticmethod
    def calc_rel_loss(
        s_rel: Tensor,         # [batch_size, seq_len, seq_len, num_labels]
        gold_arcs: LongTensor  # [n_arcs, 4]
    ) -> Tensor:
        """Cross-entropy between relation scores at gold arc positions and gold relations."""
        batch_idxs, arcs_from, arcs_to, rels = gold_arcs.T
        # [n_arcs, num_labels]
        gold_arc_rel_scores = s_rel[batch_idxs, arcs_from, arcs_to]
        return F.cross_entropy(gold_arc_rel_scores, rels)

    def predict_arcs(
        self,
        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
        null_mask: BoolTensor,    # [batch_size, seq_len]
        padding_mask: BoolTensor  # [batch_size, seq_len]
    ) -> LongTensor:
        """Predict arcs from scores. Must be overridden by subclasses."""
        raise NotImplementedError

    def predict_rels(
        self,
        s_rel: FloatTensor  # [batch_size, seq_len, seq_len, num_labels]
    ) -> LongTensor:
        """Pick the highest-scoring relation for every (from, to) pair."""
        return s_rel.argmax(dim=-1).long()

    @staticmethod
    def combine_arcs_rels(
        pred_arcs: LongTensor,  # [batch_size, seq_len, seq_len]
        pred_rels: LongTensor   # [batch_size, seq_len, seq_len]
    ) -> LongTensor:
        """Select relations towards predicted arcs."""
        assert pred_arcs.shape == pred_rels.shape
        # Get indices where arcs exist
        indices = pred_arcs.nonzero(as_tuple=True)
        batch_idxs, from_idxs, to_idxs = indices
        # Get corresponding relation types
        rel_types = pred_rels[batch_idxs, from_idxs, to_idxs]
        # Stack as [batch_idx, from_idx, to_idx, rel_type]
        return torch.stack([batch_idxs, from_idxs, to_idxs, rel_types], dim=1)


class DependencyHead(DependencyHeadBase):
    """
    Basic UD syntax specialization that predicts single edge for each token.
    """

    def predict_arcs(
        self,
        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
        null_mask: BoolTensor,    # [batch_size, seq_len]
        padding_mask: BoolTensor  # [batch_size, seq_len]
    ) -> Tensor:
        """
        Predict exactly one head per token and return the arcs as a 0/1 matrix
        of shape [batch_size, seq_len, seq_len] where matrix[b, i, j] = 1 means
        an arc i -> j.
        """
        # FIXME: inference is supposed to decode a Maximum Spanning Tree via
        # `self._mst_decode(s_arc, padding_mask)`. That path is currently disabled,
        # so fast greedy decoding is used both in training and in evaluation.
        # - [batch_size, seq_len]
        pred_arcs_seq = s_arc.argmax(dim=1)

        # Upscale arcs sequence of shape [batch_size, seq_len]
        # to matrix of shape [batch_size, seq_len, seq_len].
        # one_hot yields [batch, dep, head]; transpose to [batch, head, dep].
        pred_arcs = F.one_hot(
            pred_arcs_seq,
            num_classes=pred_arcs_seq.size(1)
        ).long().transpose(1, 2)
        # Apply mask one more time (even though s_arc is already masked),
        # because argmax erases information about masked values.
        mask2d = pairwise_mask(null_mask & padding_mask)
        replace_masked_values(pred_arcs, mask2d, replace_with=0)
        return pred_arcs

    def _mst_decode(
        self,
        s_arc: Tensor,        # [batch_size, seq_len, seq_len]
        padding_mask: Tensor  # [batch_size, seq_len]
    ) -> Tensor:
        """
        Decode one head per token with a Maximum Spanning Tree algorithm.

        Returns a tensor of shape [batch_size, seq_len] on the device of `s_arc`,
        where entry [b, j] is the predicted head of token j.

        NOTE(review): currently not called (see the FIXME in `predict_arcs`).
        The orientation `decode_mst` expects for its energy matrix is not
        visible from this file - confirm before re-enabling.
        """
        batch_size = s_arc.size(0)
        device = s_arc.device
        s_arc = s_arc.cpu()

        # Convert scores to probabilities, as `decode_mst` expects non-negative values.
        arc_probs = F.softmax(s_arc, dim=1)

        # `decode_mst` knows nothing about UD and ROOT, so we have to manually
        # zero probabilities of arcs leading to ROOT to make sure ROOT is a source node
        # of a graph.

        # Decode ROOT positions from diagonals.
        # shape: [batch_size]
        root_idxs = arc_probs.diagonal(dim1=1, dim2=2).argmax(dim=-1)
        # Zero out arcs leading to ROOTs.
        arc_probs[torch.arange(batch_size), :, root_idxs] = 0.0

        # Greedy fallback heads: best head (axis 1 of the batched tensor) for every
        # dependent, exactly as in `predict_arcs`.
        # shape: [batch_size, seq_len]
        greedy_heads = s_arc.argmax(dim=1)

        pred_arcs = []
        for energy, sample_mask, sample_greedy in zip(arc_probs, padding_mask, greedy_heads):
            length = sample_mask.sum()
            heads = decode_mst(energy, length)
            # Some nodes may be isolated. Pick heads greedily in this case.
            isolated = heads <= 0
            heads[isolated] = sample_greedy[isolated]
            pred_arcs.append(heads)

        # shape: [batch_size, seq_len]
        pred_arcs = torch.from_numpy(np.stack(pred_arcs)).long().to(device)
        return pred_arcs

    @staticmethod
    def calc_arc_loss(
        s_arc: Tensor,         # [batch_size, seq_len, seq_len]
        gold_arcs: LongTensor  # [n_arcs, 4]
    ) -> Tensor:
        """
        Cross-entropy over candidate heads: for every gold arc, the scores of all
        possible heads of its dependent are classified against the gold head.
        """
        batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
        # [n_arcs, seq_len] - scores of every head for each gold dependent.
        head_scores = s_arc[batch_idxs, :, to_idxs]
        return F.cross_entropy(head_scores, from_idxs)


class MultiDependencyHead(DependencyHeadBase):
    """
    Enhanced UD syntax specialization that predicts multiple edges for each token.
    """

    def predict_arcs(
        self,
        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
        null_mask: BoolTensor,    # [batch_size, seq_len]
        padding_mask: BoolTensor  # [batch_size, seq_len]
    ) -> Tensor:
        """
        Predict every arc whose probability exceeds 0.5.

        The masks are not used here; `forward` has already applied them to
        `s_arc` before this method is called.
        """
        # Convert scores to probabilities.
        arc_probs = torch.sigmoid(s_arc)
        # Find confident arcs (with prob > 0.5).
        return arc_probs.round().long()

    @staticmethod
    def calc_arc_loss(
        s_arc: Tensor,         # [batch_size, seq_len, seq_len]
        gold_arcs: LongTensor  # [n_arcs, 4]
    ) -> Tensor:
        """Binary cross-entropy between arc logits and the dense gold arc matrix."""
        batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
        # Gold arcs but as a matrix, where matrix[i, arc_from, arc_to] = 1.0 if the arc is present.
        gold_arcs_matrix = torch.zeros_like(s_arc)
        gold_arcs_matrix[batch_idxs, from_idxs, to_idxs] = 1.0
        # Padded arcs' logits are huge negative values that don't contribute to the loss.
        return F.binary_cross_entropy_with_logits(s_arc, gold_arcs_matrix)


class DependencyClassifier(nn.Module):
    """
    Dozat and Manning's biaffine dependency classifier.

    Holds four MLPs (arc/rel x head/dep) whose outputs are shared between
    a basic UD head and an enhanced UD head.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_rels_ud: int,
        n_rels_eud: int,
        activation: str,
        dropout: float,
    ):
        super().__init__()

        self.arc_dep_mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout)
        )
        # All mlps are equal (deep copies, so parameters are not shared).
        self.arc_head_mlp = deepcopy(self.arc_dep_mlp)
        self.rel_dep_mlp = deepcopy(self.arc_dep_mlp)
        self.rel_head_mlp = deepcopy(self.arc_dep_mlp)

        self.dependency_head_ud = DependencyHead(hidden_size, n_rels_ud)
        self.dependency_head_eud = MultiDependencyHead(hidden_size, n_rels_eud)

    def forward(
        self,
        embeddings: Tensor,    # [batch_size, seq_len, embedding_size]
        gold_ud: Tensor,       # [n_ud_arcs, 4]
        gold_eud: Tensor,      # [n_eud_arcs, 4]
        null_mask: Tensor,     # [batch_size, seq_len]
        padding_mask: Tensor   # [batch_size, seq_len]
    ) -> dict[str, Tensor]:
        """
        Run both dependency heads over the same projected embeddings.

        Returns a dict with keys 'preds_ud', 'preds_eud', 'loss_ud', 'loss_eud'
        taken from the 'preds' / 'loss' outputs of the respective heads.
        """
        # - [batch_size, seq_len, hidden_size]
        h_arc_head = self.arc_head_mlp(embeddings)
        h_arc_dep = self.arc_dep_mlp(embeddings)
        h_rel_head = self.rel_head_mlp(embeddings)
        h_rel_dep = self.rel_dep_mlp(embeddings)

        # Share the h vectors between dependency and multi-dependency heads.
        output_ud = self.dependency_head_ud(
            h_arc_head,
            h_arc_dep,
            h_rel_head,
            h_rel_dep,
            gold_arcs=gold_ud,
            null_mask=null_mask,
            padding_mask=padding_mask
        )
        output_eud = self.dependency_head_eud(
            h_arc_head,
            h_arc_dep,
            h_rel_head,
            h_rel_dep,
            gold_arcs=gold_eud,
            # Ignore null mask in E-UD
            null_mask=torch.ones_like(padding_mask),
            padding_mask=padding_mask
        )

        return {
            'preds_ud': output_ud["preds"],
            'preds_eud': output_eud["preds"],
            'loss_ud': output_ud["loss"],
            'loss_eud': output_eud["loss"]
        }