Model save
Browse files- README.md +3 -3
- dependency_classifier.py +46 -42
- model.safetensors +1 -1
- modeling_parser.py +25 -44
- training_args.bin +1 -1
- utils.py +1 -1
README.md
CHANGED
|
@@ -21,13 +21,13 @@ model-index:
|
|
| 21 |
split: validation
|
| 22 |
metrics:
|
| 23 |
- type: f1
|
| 24 |
-
value: 0.
|
| 25 |
name: Null F1
|
| 26 |
- type: accuracy
|
| 27 |
-
value: 0.
|
| 28 |
name: Ud Jaccard
|
| 29 |
- type: accuracy
|
| 30 |
-
value: 0.
|
| 31 |
name: Eud Jaccard
|
| 32 |
---
|
| 33 |
|
|
|
|
| 21 |
split: validation
|
| 22 |
metrics:
|
| 23 |
- type: f1
|
| 24 |
+
value: 0.2499754084433992
|
| 25 |
name: Null F1
|
| 26 |
- type: accuracy
|
| 27 |
+
value: 0.32062444472648816
|
| 28 |
name: Ud Jaccard
|
| 29 |
- type: accuracy
|
| 30 |
+
value: 0.7903051003317022
|
| 31 |
name: Eud Jaccard
|
| 32 |
---
|
| 33 |
|
dependency_classifier.py
CHANGED
|
@@ -38,19 +38,21 @@ class DependencyHeadBase(nn.Module):
|
|
| 38 |
|
| 39 |
def forward(
|
| 40 |
self,
|
| 41 |
-
h_arc_head: Tensor,
|
| 42 |
-
h_arc_dep: Tensor,
|
| 43 |
-
h_rel_head: Tensor,
|
| 44 |
-
h_rel_dep: Tensor,
|
| 45 |
-
gold_arcs: LongTensor,
|
| 46 |
-
|
|
|
|
| 47 |
) -> dict[str, Tensor]:
|
| 48 |
-
|
| 49 |
# Score arcs.
|
| 50 |
-
# s_arc[:, i, j] = score of edge
|
| 51 |
s_arc = self.arc_attention(h_arc_head, h_arc_dep)
|
| 52 |
# Mask undesirable values (padding, nulls, etc.) with -inf.
|
| 53 |
-
|
|
|
|
| 54 |
# Score arcs' relations.
|
| 55 |
# [batch_size, seq_len, seq_len, num_labels]
|
| 56 |
s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)
|
|
@@ -63,11 +65,11 @@ class DependencyHeadBase(nn.Module):
|
|
| 63 |
|
| 64 |
# Predict arcs based on the scores.
|
| 65 |
# [batch_size, seq_len, seq_len]
|
| 66 |
-
|
| 67 |
# [batch_size, seq_len, seq_len]
|
| 68 |
-
|
| 69 |
# [n_pred_arcs, 4]
|
| 70 |
-
preds_combined = self.combine_arcs_rels(
|
| 71 |
return {
|
| 72 |
'preds': preds_combined,
|
| 73 |
'loss': loss
|
|
@@ -91,8 +93,9 @@ class DependencyHeadBase(nn.Module):
|
|
| 91 |
|
| 92 |
def predict_arcs(
|
| 93 |
self,
|
| 94 |
-
s_arc: Tensor,
|
| 95 |
-
|
|
|
|
| 96 |
) -> LongTensor:
|
| 97 |
"""Predict arcs from scores."""
|
| 98 |
raise NotImplementedError
|
|
@@ -127,42 +130,40 @@ class DependencyHead(DependencyHeadBase):
|
|
| 127 |
@override
|
| 128 |
def predict_arcs(
|
| 129 |
self,
|
| 130 |
-
s_arc: Tensor,
|
| 131 |
-
|
|
|
|
| 132 |
) -> Tensor:
|
| 133 |
|
| 134 |
if self.training:
|
| 135 |
# During training, use fast greedy decoding.
|
| 136 |
# - [batch_size, seq_len]
|
| 137 |
-
pred_arcs_seq = s_arc.argmax(dim=
|
| 138 |
else:
|
| 139 |
-
# During inference,
|
| 140 |
-
pred_arcs_seq = self._mst_decode(s_arc,
|
| 141 |
-
# FIXME
|
| 142 |
-
# pred_arcs_seq = s_arc.argmax(dim=-1)
|
| 143 |
|
| 144 |
# Upscale arcs sequence of shape [batch_size, seq_len]
|
| 145 |
# to matrix of shape [batch_size, seq_len, seq_len].
|
| 146 |
-
pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
return pred_arcs
|
| 148 |
|
| 149 |
def _mst_decode(
|
| 150 |
self,
|
| 151 |
-
s_arc: Tensor,
|
| 152 |
-
|
| 153 |
) -> tuple[Tensor, Tensor]:
|
| 154 |
-
|
| 155 |
batch_size = s_arc.size(0)
|
| 156 |
device = s_arc.device
|
| 157 |
s_arc = s_arc.cpu()
|
| 158 |
|
| 159 |
# Convert scores to probabilities, as `decode_mst` expects non-negative values.
|
| 160 |
-
arc_probs = nn.functional.softmax(s_arc, dim=
|
| 161 |
-
# Transpose arcs, because decode_mst defines 'energy' matrix as
|
| 162 |
-
# energy[i,j] = "Score that `i` is the head of `j`",
|
| 163 |
-
# whereas
|
| 164 |
-
# arc_probs[i,j] = "Probability that `j` is the head of `i`".
|
| 165 |
-
arc_probs = arc_probs.transpose(1, 2)
|
| 166 |
|
| 167 |
# `decode_mst` knows nothing about UD and ROOT, so we have to manually
|
| 168 |
# zero probabilities of arcs leading to ROOT to make sure ROOT is a source node
|
|
@@ -177,11 +178,10 @@ class DependencyHead(DependencyHeadBase):
|
|
| 177 |
pred_arcs = []
|
| 178 |
for sample_idx in range(batch_size):
|
| 179 |
energy = arc_probs[sample_idx]
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
heads, _ = decode_mst(energy, lengths, has_labels=False)
|
| 183 |
# Some nodes may be isolated. Pick heads greedily in this case.
|
| 184 |
-
heads[heads <= 0] = s_arc[sample_idx].argmax(dim=
|
| 185 |
pred_arcs.append(heads)
|
| 186 |
|
| 187 |
# shape: [batch_size, seq_len]
|
|
@@ -195,7 +195,7 @@ class DependencyHead(DependencyHeadBase):
|
|
| 195 |
gold_arcs: LongTensor # [n_arcs, 4]
|
| 196 |
) -> tuple[Tensor, Tensor]:
|
| 197 |
batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
|
| 198 |
-
return F.cross_entropy(s_arc[batch_idxs,
|
| 199 |
|
| 200 |
|
| 201 |
class MultiDependencyHead(DependencyHeadBase):
|
|
@@ -206,8 +206,9 @@ class MultiDependencyHead(DependencyHeadBase):
|
|
| 206 |
@override
|
| 207 |
def predict_arcs(
|
| 208 |
self,
|
| 209 |
-
s_arc: Tensor,
|
| 210 |
-
|
|
|
|
| 211 |
) -> Tensor:
|
| 212 |
# Convert scores to probabilities.
|
| 213 |
arc_probs = torch.sigmoid(s_arc)
|
|
@@ -263,8 +264,8 @@ class DependencyClassifier(nn.Module):
|
|
| 263 |
embeddings: Tensor, # [batch_size, seq_len, embedding_size]
|
| 264 |
gold_ud: Tensor, # [n_ud_arcs, 4]
|
| 265 |
gold_eud: Tensor, # [n_eud_arcs, 4]
|
| 266 |
-
|
| 267 |
-
|
| 268 |
) -> dict[str, Tensor]:
|
| 269 |
|
| 270 |
# - [batch_size, seq_len, hidden_size]
|
|
@@ -280,7 +281,8 @@ class DependencyClassifier(nn.Module):
|
|
| 280 |
h_rel_head,
|
| 281 |
h_rel_dep,
|
| 282 |
gold_arcs=gold_ud,
|
| 283 |
-
|
|
|
|
| 284 |
)
|
| 285 |
output_eud = self.dependency_head_eud(
|
| 286 |
h_arc_head,
|
|
@@ -288,7 +290,9 @@ class DependencyClassifier(nn.Module):
|
|
| 288 |
h_rel_head,
|
| 289 |
h_rel_dep,
|
| 290 |
gold_arcs=gold_eud,
|
| 291 |
-
mask
|
|
|
|
|
|
|
| 292 |
)
|
| 293 |
|
| 294 |
return {
|
|
|
|
| 38 |
|
| 39 |
def forward(
|
| 40 |
self,
|
| 41 |
+
h_arc_head: Tensor, # [batch_size, seq_len, hidden_size]
|
| 42 |
+
h_arc_dep: Tensor, # ...
|
| 43 |
+
h_rel_head: Tensor, # ...
|
| 44 |
+
h_rel_dep: Tensor, # ...
|
| 45 |
+
gold_arcs: LongTensor, # [n_arcs, 4]
|
| 46 |
+
null_mask: BoolTensor, # [batch_size, seq_len]
|
| 47 |
+
padding_mask: BoolTensor # [batch_size, seq_len]
|
| 48 |
) -> dict[str, Tensor]:
|
| 49 |
+
|
| 50 |
# Score arcs.
|
| 51 |
+
# s_arc[:, i, j] = score of edge i -> j.
|
| 52 |
s_arc = self.arc_attention(h_arc_head, h_arc_dep)
|
| 53 |
# Mask undesirable values (padding, nulls, etc.) with a large negative score (-1e8).
|
| 54 |
+
mask2d = pairwise_mask(null_mask & padding_mask)
|
| 55 |
+
replace_masked_values(s_arc, mask2d, replace_with=-1e8)
|
| 56 |
# Score arcs' relations.
|
| 57 |
# [batch_size, seq_len, seq_len, num_labels]
|
| 58 |
s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)
|
|
|
|
| 65 |
|
| 66 |
# Predict arcs based on the scores.
|
| 67 |
# [batch_size, seq_len, seq_len]
|
| 68 |
+
pred_arcs_matrix = self.predict_arcs(s_arc, null_mask, padding_mask)
|
| 69 |
# [batch_size, seq_len, seq_len]
|
| 70 |
+
pred_rels_matrix = self.predict_rels(s_rel)
|
| 71 |
# [n_pred_arcs, 4]
|
| 72 |
+
preds_combined = self.combine_arcs_rels(pred_arcs_matrix, pred_rels_matrix)
|
| 73 |
return {
|
| 74 |
'preds': preds_combined,
|
| 75 |
'loss': loss
|
|
|
|
| 93 |
|
| 94 |
def predict_arcs(
|
| 95 |
self,
|
| 96 |
+
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 97 |
+
null_mask: BoolTensor, # [batch_size, seq_len]
|
| 98 |
+
padding_mask: BoolTensor # [batch_size, seq_len]
|
| 99 |
) -> LongTensor:
|
| 100 |
"""Predict arcs from scores."""
|
| 101 |
raise NotImplementedError
|
|
|
|
| 130 |
@override
|
| 131 |
def predict_arcs(
|
| 132 |
self,
|
| 133 |
+
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 134 |
+
null_mask: BoolTensor, # [batch_size, seq_len]
|
| 135 |
+
padding_mask: BoolTensor # [batch_size, seq_len]
|
| 136 |
) -> Tensor:
|
| 137 |
|
| 138 |
if self.training:
|
| 139 |
# During training, use fast greedy decoding.
|
| 140 |
# - [batch_size, seq_len]
|
| 141 |
+
pred_arcs_seq = s_arc.argmax(dim=1)
|
| 142 |
else:
|
| 143 |
+
# During inference, decode Maximum Spanning Tree.
|
| 144 |
+
pred_arcs_seq = self._mst_decode(s_arc, padding_mask)
|
|
|
|
|
|
|
| 145 |
|
| 146 |
# Upscale arcs sequence of shape [batch_size, seq_len]
|
| 147 |
# to matrix of shape [batch_size, seq_len, seq_len].
|
| 148 |
+
pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long().transpose(1, 2)
|
| 149 |
+
# Apply mask one more time (even though s_arc is already masked),
|
| 150 |
+
# because argmax erases information about masked values.
|
| 151 |
+
mask2d = pairwise_mask(null_mask & padding_mask)
|
| 152 |
+
replace_masked_values(pred_arcs, mask2d, replace_with=0)
|
| 153 |
return pred_arcs
|
| 154 |
|
| 155 |
def _mst_decode(
|
| 156 |
self,
|
| 157 |
+
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 158 |
+
padding_mask: Tensor
|
| 159 |
) -> tuple[Tensor, Tensor]:
|
| 160 |
+
|
| 161 |
batch_size = s_arc.size(0)
|
| 162 |
device = s_arc.device
|
| 163 |
s_arc = s_arc.cpu()
|
| 164 |
|
| 165 |
# Convert scores to probabilities, as `decode_mst` expects non-negative values.
|
| 166 |
+
arc_probs = nn.functional.softmax(s_arc, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# `decode_mst` knows nothing about UD and ROOT, so we have to manually
|
| 169 |
# zero probabilities of arcs leading to ROOT to make sure ROOT is a source node
|
|
|
|
| 178 |
pred_arcs = []
|
| 179 |
for sample_idx in range(batch_size):
|
| 180 |
energy = arc_probs[sample_idx]
|
| 181 |
+
length = padding_mask[sample_idx].sum()
|
| 182 |
+
heads = decode_mst(energy, length)
|
|
|
|
| 183 |
# Some nodes may be isolated. Pick heads greedily in this case.
|
| 184 |
+
heads[heads <= 0] = s_arc[sample_idx].argmax(dim=1)[heads <= 0]
|
| 185 |
pred_arcs.append(heads)
|
| 186 |
|
| 187 |
# shape: [batch_size, seq_len]
|
|
|
|
| 195 |
gold_arcs: LongTensor # [n_arcs, 4]
|
| 196 |
) -> tuple[Tensor, Tensor]:
|
| 197 |
batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
|
| 198 |
+
return F.cross_entropy(s_arc[batch_idxs, :, to_idxs], from_idxs)
|
| 199 |
|
| 200 |
|
| 201 |
class MultiDependencyHead(DependencyHeadBase):
|
|
|
|
| 206 |
@override
|
| 207 |
def predict_arcs(
|
| 208 |
self,
|
| 209 |
+
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 210 |
+
null_mask: BoolTensor, # [batch_size, seq_len]
|
| 211 |
+
padding_mask: BoolTensor # [batch_size, seq_len]
|
| 212 |
) -> Tensor:
|
| 213 |
# Convert scores to probabilities.
|
| 214 |
arc_probs = torch.sigmoid(s_arc)
|
|
|
|
| 264 |
embeddings: Tensor, # [batch_size, seq_len, embedding_size]
|
| 265 |
gold_ud: Tensor, # [n_ud_arcs, 4]
|
| 266 |
gold_eud: Tensor, # [n_eud_arcs, 4]
|
| 267 |
+
null_mask: Tensor, # [batch_size, seq_len]
|
| 268 |
+
padding_mask: Tensor # [batch_size, seq_len]
|
| 269 |
) -> dict[str, Tensor]:
|
| 270 |
|
| 271 |
# - [batch_size, seq_len, hidden_size]
|
|
|
|
| 281 |
h_rel_head,
|
| 282 |
h_rel_dep,
|
| 283 |
gold_arcs=gold_ud,
|
| 284 |
+
null_mask=null_mask,
|
| 285 |
+
padding_mask=padding_mask
|
| 286 |
)
|
| 287 |
output_eud = self.dependency_head_eud(
|
| 288 |
h_arc_head,
|
|
|
|
| 290 |
h_rel_head,
|
| 291 |
h_rel_dep,
|
| 292 |
gold_arcs=gold_eud,
|
| 293 |
+
# Ignore null mask in E-UD
|
| 294 |
+
null_mask=torch.ones_like(padding_mask),
|
| 295 |
+
padding_mask=padding_mask
|
| 296 |
)
|
| 297 |
|
| 298 |
return {
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1147244460
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c0c2327dbffac624222d10865069c3b63b26c65dd0c034a5d86210d080c8dc47
|
| 3 |
size 1147244460
|
modeling_parser.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
from torch import nn
|
| 2 |
from torch import LongTensor
|
| 3 |
from transformers import PreTrainedModel
|
| 4 |
-
from transformers.modeling_outputs import ModelOutput
|
| 5 |
-
from dataclasses import dataclass
|
| 6 |
|
| 7 |
from .configuration import CobaldParserConfig
|
| 8 |
from .encoder import WordTransformerEncoder
|
|
@@ -17,23 +15,6 @@ from .utils import (
|
|
| 17 |
)
|
| 18 |
|
| 19 |
|
| 20 |
-
@dataclass
|
| 21 |
-
class CobaldParserOutput(ModelOutput):
|
| 22 |
-
"""
|
| 23 |
-
Output type for CobaldParser.
|
| 24 |
-
"""
|
| 25 |
-
loss: float = None
|
| 26 |
-
words: list = None
|
| 27 |
-
counting_mask: LongTensor = None
|
| 28 |
-
lemma_rules: LongTensor = None
|
| 29 |
-
joint_feats: LongTensor = None
|
| 30 |
-
deps_ud: LongTensor = None
|
| 31 |
-
deps_eud: LongTensor = None
|
| 32 |
-
miscs: LongTensor = None
|
| 33 |
-
deepslots: LongTensor = None
|
| 34 |
-
semclasses: LongTensor = None
|
| 35 |
-
|
| 36 |
-
|
| 37 |
class CobaldParser(PreTrainedModel):
|
| 38 |
"""Morpho-Syntax-Semantic Parser."""
|
| 39 |
|
|
@@ -119,8 +100,8 @@ class CobaldParser(PreTrainedModel):
|
|
| 119 |
sent_ids: list[str] = None,
|
| 120 |
texts: list[str] = None,
|
| 121 |
inference_mode: bool = False
|
| 122 |
-
) ->
|
| 123 |
-
|
| 124 |
|
| 125 |
# Extra [CLS] token accounts for the case when #NULL is the first token in a sentence.
|
| 126 |
words_with_cls = prepend_cls(words)
|
|
@@ -129,62 +110,62 @@ class CobaldParser(PreTrainedModel):
|
|
| 129 |
embeddings_without_nulls = self.encoder(words_without_nulls)
|
| 130 |
# Predict nulls.
|
| 131 |
null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
|
| 135 |
# "Teacher forcing": during training, pass the original words (with gold nulls)
|
| 136 |
# to the classification heads, so that they are trained upon correct sentences.
|
| 137 |
if inference_mode:
|
| 138 |
# Restore predicted nulls in the original sentences.
|
| 139 |
-
|
| 140 |
else:
|
| 141 |
-
|
| 142 |
|
| 143 |
# Encode words with nulls.
|
| 144 |
# [batch_size, seq_len, embedding_size]
|
| 145 |
-
embeddings = self.encoder(
|
| 146 |
|
| 147 |
# Predict lemmas and morphological features.
|
| 148 |
if "lemma_rule" in self.classifiers:
|
| 149 |
lemma_output = self.classifiers["lemma_rule"](embeddings, lemma_rules)
|
| 150 |
-
|
| 151 |
-
|
| 152 |
|
| 153 |
if "joint_feats" in self.classifiers:
|
| 154 |
joint_feats_output = self.classifiers["joint_feats"](embeddings, joint_feats)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
|
| 158 |
# Predict syntax.
|
| 159 |
if "syntax" in self.classifiers:
|
| 160 |
-
padding_mask = build_padding_mask(
|
| 161 |
-
null_mask = build_null_mask(
|
| 162 |
deps_output = self.classifiers["syntax"](
|
| 163 |
embeddings,
|
| 164 |
deps_ud,
|
| 165 |
deps_eud,
|
| 166 |
-
|
| 167 |
-
|
| 168 |
)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
|
| 173 |
# Predict miscellaneous features.
|
| 174 |
if "misc" in self.classifiers:
|
| 175 |
misc_output = self.classifiers["misc"](embeddings, miscs)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
|
| 179 |
# Predict semantics.
|
| 180 |
if "deepslot" in self.classifiers:
|
| 181 |
deepslot_output = self.classifiers["deepslot"](embeddings, deepslots)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
|
| 185 |
if "semclass" in self.classifiers:
|
| 186 |
semclass_output = self.classifiers["semclass"](embeddings, semclasses)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
|
| 190 |
-
return
|
|
|
|
| 1 |
from torch import nn
|
| 2 |
from torch import LongTensor
|
| 3 |
from transformers import PreTrainedModel
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from .configuration import CobaldParserConfig
|
| 6 |
from .encoder import WordTransformerEncoder
|
|
|
|
| 15 |
)
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class CobaldParser(PreTrainedModel):
|
| 19 |
"""Morpho-Syntax-Semantic Parser."""
|
| 20 |
|
|
|
|
| 100 |
sent_ids: list[str] = None,
|
| 101 |
texts: list[str] = None,
|
| 102 |
inference_mode: bool = False
|
| 103 |
+
) -> dict:
|
| 104 |
+
output = {}
|
| 105 |
|
| 106 |
# Extra [CLS] token accounts for the case when #NULL is the first token in a sentence.
|
| 107 |
words_with_cls = prepend_cls(words)
|
|
|
|
| 110 |
embeddings_without_nulls = self.encoder(words_without_nulls)
|
| 111 |
# Predict nulls.
|
| 112 |
null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
|
| 113 |
+
output["counting_mask"] = null_output['preds']
|
| 114 |
+
output["loss"] = null_output["loss"]
|
| 115 |
|
| 116 |
# "Teacher forcing": during training, pass the original words (with gold nulls)
|
| 117 |
# to the classification heads, so that they are trained upon correct sentences.
|
| 118 |
if inference_mode:
|
| 119 |
# Restore predicted nulls in the original sentences.
|
| 120 |
+
output["words"] = add_nulls(words, null_output["preds"])
|
| 121 |
else:
|
| 122 |
+
output["words"] = words
|
| 123 |
|
| 124 |
# Encode words with nulls.
|
| 125 |
# [batch_size, seq_len, embedding_size]
|
| 126 |
+
embeddings = self.encoder(output["words"])
|
| 127 |
|
| 128 |
# Predict lemmas and morphological features.
|
| 129 |
if "lemma_rule" in self.classifiers:
|
| 130 |
lemma_output = self.classifiers["lemma_rule"](embeddings, lemma_rules)
|
| 131 |
+
output["lemma_rules"] = lemma_output['preds']
|
| 132 |
+
output["loss"] += lemma_output['loss']
|
| 133 |
|
| 134 |
if "joint_feats" in self.classifiers:
|
| 135 |
joint_feats_output = self.classifiers["joint_feats"](embeddings, joint_feats)
|
| 136 |
+
output["joint_feats"] = joint_feats_output['preds']
|
| 137 |
+
output["loss"] += joint_feats_output['loss']
|
| 138 |
|
| 139 |
# Predict syntax.
|
| 140 |
if "syntax" in self.classifiers:
|
| 141 |
+
padding_mask = build_padding_mask(output["words"], self.device)
|
| 142 |
+
null_mask = build_null_mask(output["words"], self.device)
|
| 143 |
deps_output = self.classifiers["syntax"](
|
| 144 |
embeddings,
|
| 145 |
deps_ud,
|
| 146 |
deps_eud,
|
| 147 |
+
null_mask,
|
| 148 |
+
padding_mask
|
| 149 |
)
|
| 150 |
+
output["deps_ud"] = deps_output['preds_ud']
|
| 151 |
+
output["deps_eud"] = deps_output['preds_eud']
|
| 152 |
+
output["loss"] += deps_output['loss_ud'] + deps_output['loss_eud']
|
| 153 |
|
| 154 |
# Predict miscellaneous features.
|
| 155 |
if "misc" in self.classifiers:
|
| 156 |
misc_output = self.classifiers["misc"](embeddings, miscs)
|
| 157 |
+
output["miscs"] = misc_output['preds']
|
| 158 |
+
output["loss"] += misc_output['loss']
|
| 159 |
|
| 160 |
# Predict semantics.
|
| 161 |
if "deepslot" in self.classifiers:
|
| 162 |
deepslot_output = self.classifiers["deepslot"](embeddings, deepslots)
|
| 163 |
+
output["deepslots"] = deepslot_output['preds']
|
| 164 |
+
output["loss"] += deepslot_output['loss']
|
| 165 |
|
| 166 |
if "semclass" in self.classifiers:
|
| 167 |
semclass_output = self.classifiers["semclass"](embeddings, semclasses)
|
| 168 |
+
output["semclasses"] = semclass_output['preds']
|
| 169 |
+
output["loss"] += semclass_output['loss']
|
| 170 |
|
| 171 |
+
return output
|
training_args.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 5432
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e122e18ca0d9c5f65733c55e15f58827e494045ce872cd55b7379b88c8e83ee6
|
| 3 |
size 5432
|
utils.py
CHANGED
|
@@ -21,7 +21,7 @@ def build_padding_mask(sentences: list[list[str]], device) -> Tensor:
|
|
| 21 |
return _build_condition_mask(sentences, condition_fn=lambda word: True, device=device)
|
| 22 |
|
| 23 |
def build_null_mask(sentences: list[list[str]], device) -> Tensor:
|
| 24 |
-
return _build_condition_mask(sentences, condition_fn=lambda word: word =
|
| 25 |
|
| 26 |
|
| 27 |
def pairwise_mask(masks1d: Tensor) -> Tensor:
|
|
|
|
| 21 |
return _build_condition_mask(sentences, condition_fn=lambda word: True, device=device)
|
| 22 |
|
| 23 |
def build_null_mask(sentences: list[list[str]], device) -> Tensor:
|
| 24 |
+
return _build_condition_mask(sentences, condition_fn=lambda word: word != "#NULL", device=device)
|
| 25 |
|
| 26 |
|
| 27 |
def pairwise_mask(masks1d: Tensor) -> Tensor:
|