Update dependency_classifier.py

#3
by E-katrin - opened
Files changed (1) hide show
  1. dependency_classifier.py +5 -5
dependency_classifier.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import override
2
  from copy import deepcopy
3
 
4
  import numpy as np
@@ -127,7 +127,7 @@ class DependencyHead(DependencyHeadBase):
127
  Basic UD syntax specialization that predicts single edge for each token.
128
  """
129
 
130
- @override
131
  def predict_arcs(
132
  self,
133
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
@@ -191,7 +191,7 @@ class DependencyHead(DependencyHeadBase):
191
  return pred_arcs
192
 
193
  @staticmethod
194
- @override
195
  def calc_arc_loss(
196
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
197
  gold_arcs: LongTensor # [n_arcs, 4]
@@ -205,7 +205,7 @@ class MultiDependencyHead(DependencyHeadBase):
205
  Enhanced UD syntax specialization that predicts multiple edges for each token.
206
  """
207
 
208
- @override
209
  def predict_arcs(
210
  self,
211
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
@@ -218,7 +218,7 @@ class MultiDependencyHead(DependencyHeadBase):
218
  return arc_probs.round().long()
219
 
220
  @staticmethod
221
- @override
222
  def calc_arc_loss(
223
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
224
  gold_arcs: LongTensor # [n_arcs, 4]
 
1
+ #from typing import override
2
  from copy import deepcopy
3
 
4
  import numpy as np
 
127
  Basic UD syntax specialization that predicts single edge for each token.
128
  """
129
 
130
+ #@override
131
  def predict_arcs(
132
  self,
133
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
 
191
  return pred_arcs
192
 
193
  @staticmethod
194
+ #@override
195
  def calc_arc_loss(
196
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
197
  gold_arcs: LongTensor # [n_arcs, 4]
 
205
  Enhanced UD syntax specialization that predicts multiple edges for each token.
206
  """
207
 
208
+ #@override
209
  def predict_arcs(
210
  self,
211
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
 
218
  return arc_probs.round().long()
219
 
220
  @staticmethod
221
+ #@override
222
  def calc_arc_loss(
223
  s_arc: Tensor, # [batch_size, seq_len, seq_len]
224
  gold_arcs: LongTensor # [n_arcs, 4]