|
|
from copy import deepcopy |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch import Tensor, FloatTensor, BoolTensor, LongTensor |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers.activations import ACT2FN |
|
|
|
|
|
from cobald_parser.bilinear_matrix_attention import BilinearMatrixAttention |
|
|
from cobald_parser.chu_liu_edmonds import decode_mst |
|
|
from cobald_parser.utils import pairwise_mask, replace_masked_values |
|
|
|
|
|
|
|
|
class DependencyHeadBase(nn.Module):
    """
    Base class for scoring arcs and relations between tokens in a dependency tree/graph.

    Subclasses define how arcs are decoded from scores (`predict_arcs`) and how the
    arc loss is computed (`calc_arc_loss`); relation scoring, relation loss, and the
    combination of arc/relation predictions are shared here.
    """

    def __init__(self, hidden_size: int, n_rels: int):
        super().__init__()

        # Biaffine scorer producing a single arc score per (head, dependent) pair.
        self.arc_attention = BilinearMatrixAttention(
            hidden_size,
            hidden_size,
            use_input_biases=True,
            n_labels=1
        )
        # Biaffine scorer producing one score per relation label per (head, dependent) pair.
        self.rel_attention = BilinearMatrixAttention(
            hidden_size,
            hidden_size,
            use_input_biases=True,
            n_labels=n_rels
        )

    def forward(
        self,
        h_arc_head: Tensor,
        h_arc_dep: Tensor,
        h_rel_head: Tensor,
        h_rel_dep: Tensor,
        gold_arcs: "LongTensor | None",
        null_mask: BoolTensor,
        padding_mask: BoolTensor
    ) -> dict[str, Tensor]:
        """
        Score arcs and relations, compute the loss when gold arcs are given,
        and decode predictions.

        Args:
            h_arc_head, h_arc_dep: head / dependent token representations for arc scoring.
            h_rel_head, h_rel_dep: head / dependent token representations for relation scoring.
            gold_arcs: gold arcs as rows of (batch_idx, from_idx, to_idx, rel_id),
                or None at pure inference time (then 'loss' stays 0.0).
            null_mask: token mask related to null tokens — polarity assumed True = keep;
                TODO confirm against `pairwise_mask`.
            padding_mask: mask of non-padding tokens (assumed True = real token).

        Returns:
            dict with 'preds' (rows of (batch_idx, from_idx, to_idx, rel_id))
            and 'loss' (scalar tensor, or float 0.0 when gold_arcs is None).
        """
        # Arc scores; presumably [batch, seq_len, seq_len] — one score per
        # (head, dependent) pair. TODO confirm axis layout against BilinearMatrixAttention.
        s_arc = self.arc_attention(h_arc_head, h_arc_dep)

        # Suppress scores of pairs involving masked tokens so they cannot win
        # during decoding. NOTE(review): relies on `replace_masked_values`
        # mutating s_arc in place — its return value is discarded; verify.
        mask2d = pairwise_mask(null_mask & padding_mask)
        replace_masked_values(s_arc, mask2d, replace_with=-1e8)

        # Relation scores, permuted so the label dimension is last:
        # presumably [batch, seq_len, seq_len, n_rels].
        s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)

        # Accumulate arc + relation losses only when supervision is available.
        loss = 0.0
        if gold_arcs is not None:
            loss += self.calc_arc_loss(s_arc, gold_arcs)
            loss += self.calc_rel_loss(s_rel, gold_arcs)

        # Decode arcs (subclass-specific) and relations, then keep only the
        # relations sitting on predicted arcs.
        pred_arcs_matrix = self.predict_arcs(s_arc, null_mask, padding_mask)
        pred_rels_matrix = self.predict_rels(s_rel)
        preds_combined = self.combine_arcs_rels(pred_arcs_matrix, pred_rels_matrix)

        return {
            'preds': preds_combined,
            'loss': loss
        }

    @staticmethod
    def calc_arc_loss(
        s_arc: Tensor,
        gold_arcs: LongTensor
    ) -> Tensor:
        """Calculate arc loss."""
        raise NotImplementedError

    @staticmethod
    def calc_rel_loss(
        s_rel: Tensor,
        gold_arcs: LongTensor
    ) -> Tensor:
        """Cross-entropy over relation labels, evaluated only at the gold arc positions."""
        batch_idxs, arcs_from, arcs_to, rels = gold_arcs.T
        return F.cross_entropy(s_rel[batch_idxs, arcs_from, arcs_to], rels)

    def predict_arcs(
        self,
        s_arc: Tensor,
        null_mask: BoolTensor,
        padding_mask: BoolTensor
    ) -> LongTensor:
        """Predict arcs from scores."""
        raise NotImplementedError

    def predict_rels(
        self,
        s_rel: FloatTensor
    ) -> LongTensor:
        """Pick the best-scoring relation label for every (head, dependent) pair."""
        return s_rel.argmax(dim=-1).long()

    @staticmethod
    def combine_arcs_rels(
        pred_arcs: LongTensor,
        pred_rels: LongTensor
    ) -> LongTensor:
        """Select relations towards predicted arcs."""
        assert pred_arcs.shape == pred_rels.shape

        # Coordinates of predicted (non-zero) arcs in the adjacency matrix.
        indices = pred_arcs.nonzero(as_tuple=True)
        batch_idxs, from_idxs, to_idxs = indices

        # Relation label predicted at each arc position.
        rel_types = pred_rels[batch_idxs, from_idxs, to_idxs]

        # Rows of (batch_idx, from_idx, to_idx, rel_id).
        return torch.stack([batch_idxs, from_idxs, to_idxs, rel_types], dim=1)
|
|
|
|
|
|
|
|
class DependencyHead(DependencyHeadBase):
    """
    Basic UD syntax specialization that predicts single edge for each token.

    Training uses cheap greedy per-token argmax decoding; evaluation uses
    Chu-Liu/Edmonds maximum-spanning-tree decoding so predicted arcs form a tree.
    """

    def predict_arcs(
        self,
        s_arc: Tensor,
        null_mask: BoolTensor,
        padding_mask: BoolTensor
    ) -> Tensor:
        """
        Decode one head per token and return a 0/1 adjacency matrix with
        masked positions zeroed out.
        """
        if self.training:
            # Greedy decoding: no tree constraint, but fast and differentiating
            # tokens individually is sufficient for training-time predictions.
            pred_arcs_seq = s_arc.argmax(dim=1)
        else:
            # BUGFIX: this branch previously duplicated the greedy argmax above,
            # leaving `_mst_decode` dead code. At inference, decode the maximum
            # spanning tree instead so the output is a well-formed tree.
            pred_arcs_seq = self._mst_decode(s_arc, padding_mask)

        # Heads sequence [batch, seq] -> one-hot adjacency [batch, seq, seq]
        # with pred_arcs[b, head, dep] = 1.
        pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long().transpose(1, 2)

        # Zero out arcs touching masked (null / padding) tokens.
        mask2d = pairwise_mask(null_mask & padding_mask)
        replace_masked_values(pred_arcs, mask2d, replace_with=0)
        return pred_arcs

    def _mst_decode(
        self,
        s_arc: Tensor,
        padding_mask: Tensor
    ) -> LongTensor:
        """
        Chu-Liu/Edmonds decoding of the highest-probability spanning tree.

        Returns a [batch, seq_len] tensor of head indices — same layout as the
        greedy `s_arc.argmax(dim=1)`.
        (FIX: return annotation was `tuple[Tensor, Tensor]` but a single tensor
        is returned.)
        """
        batch_size = s_arc.size(0)
        device = s_arc.device
        # MST decoding runs sample-by-sample on CPU.
        s_arc = s_arc.cpu()

        # Probability distribution over candidate heads (dim=1 indexes heads).
        arc_probs = nn.functional.softmax(s_arc, dim=1)

        # The diagonal encodes "token is its own head", i.e. the root choice.
        # Pick the most probable root per sample and forbid arcs pointing at it.
        root_idxs = arc_probs.diagonal(dim1=1, dim2=2).argmax(dim=-1)
        arc_probs[torch.arange(batch_size), :, root_idxs] = 0.0

        pred_arcs = []
        for sample_idx in range(batch_size):
            energy = arc_probs[sample_idx]
            length = padding_mask[sample_idx].sum()
            heads = decode_mst(energy, length)

            # Fall back to greedy heads wherever MST yielded no valid head.
            heads[heads <= 0] = s_arc[sample_idx].argmax(dim=1)[heads <= 0]
            pred_arcs.append(heads)

        # decode_mst presumably returns numpy arrays — stack them and move the
        # result back to the original device. TODO confirm return type.
        pred_arcs = torch.from_numpy(np.stack(pred_arcs)).long().to(device)
        return pred_arcs

    @staticmethod
    def calc_arc_loss(
        s_arc: Tensor,
        gold_arcs: LongTensor
    ) -> Tensor:
        """
        Cross-entropy over candidate heads for each gold dependent.
        (FIX: return annotation was `tuple[Tensor, Tensor]` but a single scalar
        tensor is returned.)
        """
        batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
        # s_arc[b, :, dep] is the score distribution over heads for dependent `dep`.
        return F.cross_entropy(s_arc[batch_idxs, :, to_idxs], from_idxs)
|
|
|
|
|
|
|
|
class MultiDependencyHead(DependencyHeadBase):
    """
    Enhanced UD syntax specialization that predicts multiple edges for each token.

    Each (head, dependent) cell is decided independently, so a token may
    receive zero, one, or several incoming arcs.
    """

    def predict_arcs(
        self,
        s_arc: Tensor,
        null_mask: BoolTensor,
        padding_mask: BoolTensor
    ) -> Tensor:
        # An arc is predicted wherever its sigmoid probability rounds to 1.
        # The masks are unused here: s_arc is expected to arrive already masked
        # to large negative scores (see the base forward), so masked cells
        # decode to 0 on their own.
        return torch.sigmoid(s_arc).round().long()

    @staticmethod
    def calc_arc_loss(
        s_arc: Tensor,
        gold_arcs: LongTensor
    ) -> Tensor:
        # Scatter the sparse gold arc list into a dense 0/1 target matrix,
        # then score every (head, dependent) cell with binary cross-entropy.
        sample_idxs, head_idxs, dep_idxs, _ = gold_arcs.unbind(dim=1)
        targets = torch.zeros_like(s_arc)
        targets[sample_idxs, head_idxs, dep_idxs] = 1.0
        return F.binary_cross_entropy_with_logits(s_arc, targets)
|
|
|
|
|
|
|
|
class DependencyClassifier(nn.Module):
    """
    Dozat and Manning's biaffine dependency classifier.

    Projects shared token embeddings into four role-specific spaces
    (arc-head, arc-dep, rel-head, rel-dep) and feeds them to two heads:
    a single-head UD tree parser and a multi-head Enhanced-UD graph parser.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        n_rels_ud: int,
        n_rels_eud: int,
        activation: str,
        dropout: float,
    ):
        """
        Args:
            input_size: dimensionality of the incoming token embeddings.
            hidden_size: dimensionality of the role-specific projections.
            n_rels_ud: number of basic UD relation labels.
            n_rels_eud: number of Enhanced UD relation labels.
            activation: key into transformers' ACT2FN activation registry.
            dropout: dropout probability used before and after the MLP's Linear.
        """
        super().__init__()

        # Role-specific MLP: dropout -> linear -> activation -> dropout.
        self.arc_dep_mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_size, hidden_size),
            ACT2FN[activation],
            nn.Dropout(dropout)
        )
        # Independent copies: same architecture and initial weights,
        # separate parameters.
        self.arc_head_mlp = deepcopy(self.arc_dep_mlp)
        self.rel_dep_mlp = deepcopy(self.arc_dep_mlp)
        self.rel_head_mlp = deepcopy(self.arc_dep_mlp)

        # Extra trainable projections applied on top of the (frozen) MLPs.
        self.arc_dep_extra = nn.Linear(hidden_size, hidden_size)
        self.arc_head_extra = nn.Linear(hidden_size, hidden_size)
        self.rel_dep_extra = nn.Linear(hidden_size, hidden_size)
        self.rel_head_extra = nn.Linear(hidden_size, hidden_size)

        # Freeze the four base MLPs; only the *_extra projections (and the
        # heads below) receive gradients.
        # NOTE(review): as written, the MLPs stay frozen at their random
        # initialization — confirm this is intended (e.g. pretrained weights
        # are loaded into them later) rather than an oversight.
        for mlp in (self.arc_dep_mlp, self.arc_head_mlp,
                    self.rel_dep_mlp, self.rel_head_mlp):
            mlp.requires_grad_(False)
        # FIX: removed four dead-code loops that set requires_grad=True on the
        # freshly created *_extra layers — that is already the default.

        self.dependency_head_ud = DependencyHead(hidden_size, n_rels_ud)
        self.dependency_head_eud = MultiDependencyHead(hidden_size, n_rels_eud)

    def forward(
        self,
        embeddings: Tensor,
        gold_ud: Tensor,
        gold_eud: Tensor,
        null_mask: Tensor,
        padding_mask: Tensor
    ) -> dict[str, Tensor]:
        """
        Run both dependency heads over a batch of token embeddings.

        Args:
            embeddings: token representations, presumably [batch, seq, input_size].
            gold_ud / gold_eud: gold arcs as (batch_idx, from, to, rel) rows,
                or None at inference (the heads skip the loss then).
            null_mask: mask distinguishing null tokens.
            padding_mask: mask of non-padding tokens.

        Returns:
            dict with predictions and losses for both UD and Enhanced UD.
        """
        # Role-specific representations: frozen MLP followed by trainable projection.
        h_arc_head = self.arc_head_extra(self.arc_head_mlp(embeddings))
        h_arc_dep = self.arc_dep_extra(self.arc_dep_mlp(embeddings))
        h_rel_head = self.rel_head_extra(self.rel_head_mlp(embeddings))
        h_rel_dep = self.rel_dep_extra(self.rel_dep_mlp(embeddings))

        output_ud = self.dependency_head_ud(
            h_arc_head, h_arc_dep, h_rel_head, h_rel_dep,
            gold_arcs=gold_ud,
            null_mask=null_mask,
            padding_mask=padding_mask
        )
        # Enhanced UD presumably may attach null tokens, so the null mask is
        # disabled (all-ones) for the graph head — only padding is masked.
        output_eud = self.dependency_head_eud(
            h_arc_head, h_arc_dep, h_rel_head, h_rel_dep,
            gold_arcs=gold_eud,
            null_mask=torch.ones_like(padding_mask),
            padding_mask=padding_mask
        )
        return {
            'preds_ud': output_ud["preds"],
            'preds_eud': output_eud["preds"],
            'loss_ud': output_ud["loss"],
            'loss_eud': output_eud["loss"]
        }