from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, FloatTensor, BoolTensor, LongTensor
from transformers.activations import ACT2FN

from cobald_parser.bilinear_matrix_attention import BilinearMatrixAttention
from cobald_parser.chu_liu_edmonds import decode_mst
from cobald_parser.utils import pairwise_mask, replace_masked_values


class DependencyHeadBase(nn.Module):
"""
Base class for scoring arcs and relations between tokens in a dependency tree/graph.
"""
def __init__(self, hidden_size: int, n_rels: int):
super().__init__()
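        # Biaffine scorers in the style of Dozat & Manning: `arc_attention`
        # yields one score per (head, dependent) pair, `rel_attention` one
        # score per pair and per relation label.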
self.arc_attention = BilinearMatrixAttention(
hidden_size,
hidden_size,
use_input_biases=True,
n_labels=1
)
self.rel_attention = BilinearMatrixAttention(
hidden_size,
hidden_size,
use_input_biases=True,
n_labels=n_rels
)
def forward(
self,
h_arc_head: Tensor, # [batch_size, seq_len, hidden_size]
h_arc_dep: Tensor, # ...
h_rel_head: Tensor, # ...
h_rel_dep: Tensor, # ...
        gold_arcs: LongTensor,     # [n_arcs, 4]
null_mask: BoolTensor, # [batch_size, seq_len]
padding_mask: BoolTensor # [batch_size, seq_len]
) -> dict[str, Tensor]:
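        """
        Score all head-dependent pairs, compute the loss against gold arcs
        (when provided), and decode predictions.

        Returns a dict with 'preds' of shape [n_pred_arcs, 4], whose rows are
        (batch_idx, from_idx, to_idx, rel_id), and a scalar 'loss'.
        """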
# Score arcs.
# s_arc[:, i, j] = score of edge i -> j.
s_arc = self.arc_attention(h_arc_head, h_arc_dep)
        # Mask undesirable values (padding, nulls, etc.) with a large negative number.
mask2d = pairwise_mask(null_mask & padding_mask)
replace_masked_values(s_arc, mask2d, replace_with=-1e8)
# Score arcs' relations.
# [batch_size, seq_len, seq_len, num_labels]
s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)
# Calculate loss.
loss = 0.0
if gold_arcs is not None:
loss += self.calc_arc_loss(s_arc, gold_arcs)
loss += self.calc_rel_loss(s_rel, gold_arcs)
# Predict arcs based on the scores.
# [batch_size, seq_len, seq_len]
pred_arcs_matrix = self.predict_arcs(s_arc, null_mask, padding_mask)
# [batch_size, seq_len, seq_len]
pred_rels_matrix = self.predict_rels(s_rel)
# [n_pred_arcs, 4]
preds_combined = self.combine_arcs_rels(pred_arcs_matrix, pred_rels_matrix)
return {
'preds': preds_combined,
'loss': loss
}
@staticmethod
def calc_arc_loss(
s_arc: Tensor, # [batch_size, seq_len, seq_len]
gold_arcs: LongTensor # [n_arcs, 4]
) -> Tensor:
"""Calculate arc loss."""
raise NotImplementedError
@staticmethod
def calc_rel_loss(
s_rel: Tensor, # [batch_size, seq_len, seq_len, num_labels]
gold_arcs: LongTensor # [n_arcs, 4]
) -> Tensor:
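        """
        Cross-entropy over relation logits gathered at gold arc positions.
        Each row of `gold_arcs` is (batch_idx, from_idx, to_idx, rel_id).
        """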
batch_idxs, arcs_from, arcs_to, rels = gold_arcs.T
return F.cross_entropy(s_rel[batch_idxs, arcs_from, arcs_to], rels)
def predict_arcs(
self,
s_arc: Tensor, # [batch_size, seq_len, seq_len]
null_mask: BoolTensor, # [batch_size, seq_len]
padding_mask: BoolTensor # [batch_size, seq_len]
) -> LongTensor:
"""Predict arcs from scores."""
raise NotImplementedError
def predict_rels(
self,
s_rel: FloatTensor
) -> LongTensor:
return s_rel.argmax(dim=-1).long()
@staticmethod
def combine_arcs_rels(
pred_arcs: LongTensor,
pred_rels: LongTensor
) -> LongTensor:
"""Select relations towards predicted arcs."""
assert pred_arcs.shape == pred_rels.shape
# Get indices where arcs exist
indices = pred_arcs.nonzero(as_tuple=True)
batch_idxs, from_idxs, to_idxs = indices
# Get corresponding relation types
rel_types = pred_rels[batch_idxs, from_idxs, to_idxs]
# Stack as [batch_idx, from_idx, to_idx, rel_type]
return torch.stack([batch_idxs, from_idxs, to_idxs, rel_types], dim=1)


class DependencyHead(DependencyHeadBase):
    """
    Basic UD syntax specialization that predicts a single head for each token.
    """
def predict_arcs(
self,
s_arc: Tensor, # [batch_size, seq_len, seq_len]
null_mask: BoolTensor, # [batch_size, seq_len]
        padding_mask: BoolTensor  # [batch_size, seq_len]
) -> Tensor:
if self.training:
# During training, use fast greedy decoding.
# - [batch_size, seq_len]
pred_arcs_seq = s_arc.argmax(dim=1)
        else:
            # FIXME: during inference, decode the Maximum Spanning Tree instead
            # of greedy argmax:
            # pred_arcs_seq = self._mst_decode(s_arc, padding_mask)
            pred_arcs_seq = s_arc.argmax(dim=1)
# Upscale arcs sequence of shape [batch_size, seq_len]
# to matrix of shape [batch_size, seq_len, seq_len].
pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long().transpose(1, 2)
# Apply mask one more time (even though s_arc is already masked),
# because argmax erases information about masked values.
mask2d = pairwise_mask(null_mask & padding_mask)
replace_masked_values(pred_arcs, mask2d, replace_with=0)
return pred_arcs
def _mst_decode(
self,
s_arc: Tensor, # [batch_size, seq_len, seq_len]
        padding_mask: Tensor  # [batch_size, seq_len]
    ) -> Tensor:
batch_size = s_arc.size(0)
device = s_arc.device
s_arc = s_arc.cpu()
        # Convert scores to probabilities, as `decode_mst` expects non-negative values.
        arc_probs = F.softmax(s_arc, dim=1)
        # `decode_mst` knows nothing about UD and ROOT, so manually zero out the
        # probabilities of arcs leading into ROOT to ensure ROOT is a source node
        # of the graph.
# Decode ROOT positions from diagonals.
# shape: [batch_size]
root_idxs = arc_probs.diagonal(dim1=1, dim2=2).argmax(dim=-1)
# Zero out arcs leading to ROOTs.
arc_probs[torch.arange(batch_size), :, root_idxs] = 0.0
pred_arcs = []
for sample_idx in range(batch_size):
energy = arc_probs[sample_idx]
length = padding_mask[sample_idx].sum()
heads = decode_mst(energy, length)
# Some nodes may be isolated. Pick heads greedily in this case.
heads[heads <= 0] = s_arc[sample_idx].argmax(dim=1)[heads <= 0]
pred_arcs.append(heads)
# shape: [batch_size, seq_len]
pred_arcs = torch.from_numpy(np.stack(pred_arcs)).long().to(device)
return pred_arcs
@staticmethod
def calc_arc_loss(
s_arc: Tensor, # [batch_size, seq_len, seq_len]
gold_arcs: LongTensor # [n_arcs, 4]
    ) -> Tensor:
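        """
        For each gold dependent token, cross-entropy over its candidate heads:
        `s_arc[b, :, j]` holds the head scores for dependent `j` in sample `b`.
        """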
batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
return F.cross_entropy(s_arc[batch_idxs, :, to_idxs], from_idxs)


class MultiDependencyHead(DependencyHeadBase):
    """
    Enhanced UD syntax specialization that predicts multiple edges per token.
    """
def predict_arcs(
self,
s_arc: Tensor, # [batch_size, seq_len, seq_len]
null_mask: BoolTensor, # [batch_size, seq_len]
padding_mask: BoolTensor # [batch_size, seq_len]
) -> Tensor:
# Convert scores to probabilities.
arc_probs = torch.sigmoid(s_arc)
# Find confident arcs (with prob > 0.5).
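        # Masked positions were filled with -1e8 upstream, so their
        # probabilities round to 0 and never survive the threshold.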
return arc_probs.round().long()
@staticmethod
def calc_arc_loss(
s_arc: Tensor, # [batch_size, seq_len, seq_len]
gold_arcs: LongTensor # [n_arcs, 4]
) -> Tensor:
batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
        # Gold arcs as a dense matrix: matrix[batch, from, to] = 1.0 if the arc is present.
gold_arcs_matrix = torch.zeros_like(s_arc)
gold_arcs_matrix[batch_idxs, from_idxs, to_idxs] = 1.0
        # Padded arcs' logits are large negative values that contribute almost nothing to the loss.
return F.binary_cross_entropy_with_logits(s_arc, gold_arcs_matrix)


class DependencyClassifier(nn.Module):
    """
    Dozat and Manning's biaffine dependency classifier.
    """
def __init__(
self,
input_size: int,
hidden_size: int,
n_rels_ud: int,
n_rels_eud: int,
activation: str,
dropout: float,
):
super().__init__()
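        # Role-specific MLPs project the shared encoder states into separate
        # low-dimensional spaces for arc/relation heads and dependents.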
self.arc_dep_mlp = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(input_size, hidden_size),
ACT2FN[activation],
nn.Dropout(dropout)
)
        # All four MLPs share the same architecture, so clone the first one.
self.arc_head_mlp = deepcopy(self.arc_dep_mlp)
self.rel_dep_mlp = deepcopy(self.arc_dep_mlp)
self.rel_head_mlp = deepcopy(self.arc_dep_mlp)
self.dependency_head_ud = DependencyHead(hidden_size, n_rels_ud)
self.dependency_head_eud = MultiDependencyHead(hidden_size, n_rels_eud)
def forward(
self,
embeddings: Tensor, # [batch_size, seq_len, embedding_size]
gold_ud: Tensor, # [n_ud_arcs, 4]
gold_eud: Tensor, # [n_eud_arcs, 4]
null_mask: Tensor, # [batch_size, seq_len]
padding_mask: Tensor # [batch_size, seq_len]
) -> dict[str, Tensor]:
# - [batch_size, seq_len, hidden_size]
h_arc_head = self.arc_head_mlp(embeddings)
h_arc_dep = self.arc_dep_mlp(embeddings)
h_rel_head = self.rel_head_mlp(embeddings)
h_rel_dep = self.rel_dep_mlp(embeddings)
# Share the h vectors between dependency and multi-dependency heads.
output_ud = self.dependency_head_ud(
h_arc_head,
h_arc_dep,
h_rel_head,
h_rel_dep,
gold_arcs=gold_ud,
null_mask=null_mask,
padding_mask=padding_mask
)
output_eud = self.dependency_head_eud(
h_arc_head,
h_arc_dep,
h_rel_head,
h_rel_dep,
gold_arcs=gold_eud,
            # Ignore the null mask for E-UD: null nodes participate in enhanced graphs.
null_mask=torch.ones_like(padding_mask),
padding_mask=padding_mask
)
return {
'preds_ud': output_ud["preds"],
'preds_eud': output_eud["preds"],
'loss_ud': output_ud["loss"],
'loss_eud': output_eud["loss"]
}
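

# A minimal smoke-test sketch (not part of the original module). It assumes the
# `cobald_parser` package is importable; the sizes and relation inventories below
# are placeholder values. Rows of `gold_ud`/`gold_eud` follow the documented
# format (batch_idx, from_idx, to_idx, rel_id).
if __name__ == "__main__":
    batch_size, seq_len, embedding_size = 2, 8, 768
    classifier = DependencyClassifier(
        input_size=embedding_size,
        hidden_size=128,
        n_rels_ud=40,
        n_rels_eud=60,
        activation="relu",
        dropout=0.1,
    )
    embeddings = torch.randn(batch_size, seq_len, embedding_size)
    # One gold arc per sentence: token 2 attaches to token 1 with relation 3.
    gold_ud = torch.tensor([[0, 1, 2, 3], [1, 1, 2, 3]])
    gold_eud = gold_ud.clone()
    null_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    padding_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    output = classifier(embeddings, gold_ud, gold_eud, null_mask, padding_mask)
    print(output["loss_ud"].item(), output["loss_eud"].item())
    print(output["preds_ud"].shape)  # [n_pred_arcs, 4]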