Update dependency_classifier.py
#3
by
E-katrin - opened
- dependency_classifier.py +5 -5
dependency_classifier.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import override
|
| 2 |
from copy import deepcopy
|
| 3 |
|
| 4 |
import numpy as np
|
|
@@ -127,7 +127,7 @@ class DependencyHead(DependencyHeadBase):
|
|
| 127 |
Basic UD syntax specialization that predicts single edge for each token.
|
| 128 |
"""
|
| 129 |
|
| 130 |
-
@override
|
| 131 |
def predict_arcs(
|
| 132 |
self,
|
| 133 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
|
@@ -191,7 +191,7 @@ class DependencyHead(DependencyHeadBase):
|
|
| 191 |
return pred_arcs
|
| 192 |
|
| 193 |
@staticmethod
|
| 194 |
-
@override
|
| 195 |
def calc_arc_loss(
|
| 196 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 197 |
gold_arcs: LongTensor # [n_arcs, 4]
|
|
@@ -205,7 +205,7 @@ class MultiDependencyHead(DependencyHeadBase):
|
|
| 205 |
Enhanced UD syntax specialization that predicts multiple edges for each token.
|
| 206 |
"""
|
| 207 |
|
| 208 |
-
@override
|
| 209 |
def predict_arcs(
|
| 210 |
self,
|
| 211 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
|
@@ -218,7 +218,7 @@ class MultiDependencyHead(DependencyHeadBase):
|
|
| 218 |
return arc_probs.round().long()
|
| 219 |
|
| 220 |
@staticmethod
|
| 221 |
-
@override
|
| 222 |
def calc_arc_loss(
|
| 223 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 224 |
gold_arcs: LongTensor # [n_arcs, 4]
|
|
|
|
| 1 |
+
#from typing import override
|
| 2 |
from copy import deepcopy
|
| 3 |
|
| 4 |
import numpy as np
|
|
|
|
| 127 |
Basic UD syntax specialization that predicts single edge for each token.
|
| 128 |
"""
|
| 129 |
|
| 130 |
+
#@override
|
| 131 |
def predict_arcs(
|
| 132 |
self,
|
| 133 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
|
|
|
| 191 |
return pred_arcs
|
| 192 |
|
| 193 |
@staticmethod
|
| 194 |
+
#@override
|
| 195 |
def calc_arc_loss(
|
| 196 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 197 |
gold_arcs: LongTensor # [n_arcs, 4]
|
|
|
|
| 205 |
Enhanced UD syntax specialization that predicts multiple edges for each token.
|
| 206 |
"""
|
| 207 |
|
| 208 |
+
#@override
|
| 209 |
def predict_arcs(
|
| 210 |
self,
|
| 211 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
|
|
|
| 218 |
return arc_probs.round().long()
|
| 219 |
|
| 220 |
@staticmethod
|
| 221 |
+
#@override
|
| 222 |
def calc_arc_loss(
|
| 223 |
s_arc: Tensor, # [batch_size, seq_len, seq_len]
|
| 224 |
gold_arcs: LongTensor # [n_arcs, 4]
|