bowphs commited on
Commit
9634055
·
verified ·
1 Parent(s): 986094c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. stanza/stanza/models/classifiers/__init__.py +0 -0
  2. stanza/stanza/models/classifiers/config.py +55 -0
  3. stanza/stanza/models/classifiers/utils.py +41 -0
  4. stanza/stanza/models/constituency/dynamic_oracle.py +135 -0
  5. stanza/stanza/models/constituency/parse_transitions.py +641 -0
  6. stanza/stanza/models/constituency/parser_training.py +771 -0
  7. stanza/stanza/models/constituency/partitioned_transformer.py +308 -0
  8. stanza/stanza/models/coref/__init__.py +0 -0
  9. stanza/stanza/models/coref/anaphoricity_scorer.py +122 -0
  10. stanza/stanza/models/coref/cluster_checker.py +230 -0
  11. stanza/stanza/models/coref/conll.py +105 -0
  12. stanza/stanza/models/coref/const.py +27 -0
  13. stanza/stanza/models/coref/coref_chain.py +39 -0
  14. stanza/stanza/models/coref/loss.py +37 -0
  15. stanza/stanza/models/coref/model.py +784 -0
  16. stanza/stanza/models/depparse/__init__.py +0 -0
  17. stanza/stanza/models/depparse/scorer.py +60 -0
  18. stanza/stanza/models/depparse/trainer.py +250 -0
  19. stanza/stanza/models/langid/trainer.py +51 -0
  20. stanza/stanza/models/lemma/__init__.py +0 -0
  21. stanza/stanza/models/lemma/data.py +212 -0
  22. stanza/stanza/models/lemma_classifier/base_model.py +134 -0
  23. stanza/stanza/models/lemma_classifier/baseline_model.py +54 -0
  24. stanza/stanza/models/lemma_classifier/lstm_model.py +219 -0
  25. stanza/stanza/models/mwt/data.py +182 -0
  26. stanza/stanza/models/mwt/scorer.py +12 -0
  27. stanza/stanza/models/mwt/utils.py +92 -0
  28. stanza/stanza/models/ner/__init__.py +0 -0
  29. stanza/stanza/models/ner/trainer.py +268 -0
  30. stanza/stanza/models/tokenization/vocab.py +35 -0
  31. stanza/stanza/pipeline/demo/README.md +23 -0
  32. stanza/stanza/utils/datasets/constituency/__init__.py +0 -0
  33. stanza/stanza/utils/datasets/constituency/common_trees.py +23 -0
  34. stanza/stanza/utils/datasets/constituency/convert_alt.py +100 -0
  35. stanza/stanza/utils/datasets/constituency/convert_arboretum.py +443 -0
  36. stanza/stanza/utils/datasets/constituency/convert_icepahc.py +83 -0
  37. stanza/stanza/utils/datasets/constituency/convert_it_turin.py +339 -0
  38. stanza/stanza/utils/datasets/constituency/convert_it_vit.py +700 -0
  39. stanza/stanza/utils/datasets/constituency/convert_spmrl.py +35 -0
  40. stanza/stanza/utils/datasets/constituency/convert_starlang.py +96 -0
  41. stanza/stanza/utils/datasets/constituency/extract_all_silver_dataset.py +46 -0
  42. stanza/stanza/utils/datasets/constituency/relabel_tags.py +48 -0
  43. stanza/stanza/utils/datasets/constituency/selftrain.py +268 -0
  44. stanza/stanza/utils/datasets/constituency/selftrain_it.py +120 -0
  45. stanza/stanza/utils/datasets/constituency/selftrain_single_file.py +88 -0
  46. stanza/stanza/utils/datasets/constituency/selftrain_vi_quad.py +98 -0
  47. stanza/stanza/utils/datasets/constituency/selftrain_wiki.py +140 -0
  48. stanza/stanza/utils/datasets/constituency/split_holdout.py +64 -0
  49. stanza/stanza/utils/datasets/constituency/split_weighted_ensemble.py +73 -0
  50. stanza/stanza/utils/datasets/constituency/tokenize_wiki.py +104 -0
stanza/stanza/models/classifiers/__init__.py ADDED
File without changes
stanza/stanza/models/classifiers/config.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ # TODO: perhaps put the enums in this file
5
+ from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType
6
+
7
@dataclass
class CNNConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """
    Hyperparameter / architecture configuration for the CNN classifier.

    A plain data container: no behavior, just the settings a saved
    model was built with.
    """
    # convolution settings
    # filter_channels may be a single channel count or one per filter size
    filter_channels: Union[int, tuple]
    filter_sizes: tuple
    # fully connected layers between the conv output and the classifier output
    fc_shapes: tuple
    dropout: float
    num_classes: int

    # word vector settings
    wordvec_type: WVType
    # how (and whether) a second embedding is combined with the base one
    extra_wordvec_method: ExtraVectors
    extra_wordvec_dim: int
    extra_wordvec_max_norm: float

    # character language model settings
    char_lowercase: bool
    charlm_projection: int
    has_charlm_forward: bool
    has_charlm_backward: bool

    # elmo settings
    use_elmo: bool
    elmo_projection: int

    # transformer settings
    bert_model: str
    bert_finetune: bool
    bert_hidden_layers: int
    force_bert_saved: bool

    # PEFT / LoRA finetuning settings
    use_peft: bool
    lora_rank: int
    lora_alpha: float
    lora_dropout: float
    lora_modules_to_save: List
    lora_target_modules: List

    # optional bilstm layer over the inputs before the convolutions
    # NOTE(review): layer placement inferred from field names -- confirm in the model
    bilstm: bool
    bilstm_hidden_dim: int
    maxpool_width: int
    model_type: ModelType
42
+
43
@dataclass
class ConstituencyConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods
    """
    Configuration for the constituency-based classifier.

    A plain data container: no behavior, just the settings a saved
    model was built with.
    """
    # fully connected layers leading to the classifier output
    fc_shapes: tuple
    dropout: float
    num_classes: int

    # settings controlling how the constituency parser's representations are used
    # NOTE(review): semantics of the individual flags are defined by the
    # constituency model, not visible here
    constituency_backprop: bool
    constituency_batch_norm: bool
    constituency_node_attn: bool
    constituency_top_layer: bool
    constituency_all_words: bool

    model_type: ModelType
stanza/stanza/models/classifiers/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from torch import nn
4
+
5
+ """
6
+ Defines some methods which may occur in multiple model types
7
+ """
8
+ # NLP machines:
9
+ # word2vec are in
10
+ # /u/nlp/data/stanfordnlp/model_production/stanfordnlp/extern_data/word2vec
11
+ # google vectors are in
12
+ # /scr/nlp/data/wordvectors/en/google/GoogleNews-vectors-negative300.txt
13
+
14
class WVType(Enum):
    """Which family of pretrained word vectors a classifier was built with."""
    WORD2VEC = 1
    GOOGLE = 2
    FASTTEXT = 3
    OTHER = 4
19
+
20
class ExtraVectors(Enum):
    """How an extra word-vector embedding is combined with the base embedding."""
    NONE = 1
    CONCAT = 2
    SUM = 3
24
+
25
class ModelType(Enum):
    """Which classifier architecture a saved model represents."""
    CNN = 1
    CONSTITUENCY = 2
28
+
29
def build_output_layers(fc_input_size, fc_shapes, num_classes):
    """
    Build the fully connected layers mapping the final conv output to num_classes.

    Chains one nn.Linear per consecutive pair of sizes in
    [fc_input_size, *fc_shapes, num_classes], so an empty fc_shapes
    yields a single projection layer.

    Returns an nn.ModuleList
    """
    sizes = [fc_input_size] + list(fc_shapes) + [num_classes]
    linears = [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    return nn.ModuleList(linears)
stanza/stanza/models/constituency/dynamic_oracle.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ import numpy as np
4
+
5
+ from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent
6
+
7
+ RepairEnum = namedtuple("RepairEnum", "name value is_correct")
8
+
9
def score_candidates(model, state, candidates, candidate_idx):
    """
    score candidate fixed sequences by summing up the transition scores of the most important block

    the candidate with the best summed score is chosen, and the candidate sequence is reconstructed from the blocks

    candidates: list of candidate repair sequences, each a list of blocks
      (lists of transitions); the block at candidate_idx is the one scored
      NOTE(review): block 0 is skipped when advancing the state
      (candidate[1:candidate_idx]) - presumably it was already applied; confirm
    Returns (scores, best_idx, best_candidate) where best_candidate is the
    flattened transition list of the highest-scoring candidate
    """
    scores = []
    # could bulkify this if we wanted
    for candidate in candidates:
        # advance a fresh copy of the state through the blocks before the scored one
        current_state = [state]
        for block in candidate[1:candidate_idx]:
            for transition in block:
                current_state = model.bulk_apply(current_state, [transition])
        score = 0.0
        # sum the model's raw score for each transition in the scored block,
        # applying the transition before scoring the next one
        for transition in candidate[candidate_idx]:
            predictions = model.forward(current_state)
            t_idx = model.transition_map[transition]
            score += predictions[0, t_idx].cpu().item()
            current_state = model.bulk_apply(current_state, [transition])
        scores.append(score)
    best_idx = np.argmax(scores)
    # flatten the winning candidate's blocks back into one transition sequence
    best_candidate = [x for block in candidates[best_idx] for x in block]
    return scores, best_idx, best_candidate
32
+
33
def advance_past_constituents(gold_sequence, cur_index):
    """
    Advance cur_index through gold_sequence until we have seen 1 more Close than Open

    The index returned is the index of the Close which occurred after all the stuff
    Returns None if the sequence ends first
    """
    depth = 0
    for idx in range(cur_index, len(gold_sequence)):
        transition = gold_sequence[idx]
        if isinstance(transition, OpenConstituent):
            depth += 1
        elif isinstance(transition, CloseConstituent):
            depth -= 1
            # one more Close than Open: this Close ends the enclosing constituent
            if depth < 0:
                return idx
    return None
48
+
49
def find_previous_open(gold_sequence, cur_index):
    """
    Go backwards from cur_index to find the open which opens the previous block of stuff.

    Return None if it can't be found.
    """
    depth = 0
    # scan backwards, balancing Close transitions against Open transitions
    for idx in range(cur_index - 1, -1, -1):
        transition = gold_sequence[idx]
        if isinstance(transition, OpenConstituent):
            depth += 1
            # first unmatched Open going backwards is the one we want
            if depth > 0:
                return idx
        elif isinstance(transition, CloseConstituent):
            depth -= 1
    return None
66
+
67
def find_in_order_constituent_end(gold_sequence, cur_index):
    """
    Advance cur_index through gold_sequence until the next block has ended

    This is different from advance_past_constituents in that it will
    also return when there is a Shift when count == 0.  That way, we
    return the first block of things we know attach to the left
    """
    depth = 0
    shift_seen = False
    for idx in range(cur_index, len(gold_sequence)):
        transition = gold_sequence[idx]
        if isinstance(transition, OpenConstituent):
            depth += 1
        elif isinstance(transition, CloseConstituent):
            depth -= 1
            # one more Close than Open ends the block
            if depth < 0:
                return idx
        elif isinstance(transition, Shift):
            # a second Shift at depth 0 starts the next block to the right
            if shift_seen and depth == 0:
                return idx
            shift_seen = True
    return None
90
+
91
class DynamicOracle():
    """
    A dynamic oracle: chooses a repair when the parser's predicted
    transition deviates from the gold transition sequence.

    repair_types is expected to be an enum whose members carry .fn,
    .value, and .debug attributes, plus CORRECT and UNKNOWN members
    -- NOTE(review): inferred from usage in this class; confirm against
    the repair type definitions
    """
    def __init__(self, root_labels, oracle_level, repair_types, additional_levels, deactivated_levels):
        self.root_labels = root_labels
        # default oracle_level will be the UNKNOWN repair type (which each oracle should have)
        # transitions after that as experimental or ambiguous, not to be used by default
        self.oracle_level = oracle_level if oracle_level is not None else repair_types.UNKNOWN.value
        self.repair_types = repair_types
        # comma separated repair names to enable beyond the oracle_level cutoff
        self.additional_levels = set()
        if additional_levels:
            self.additional_levels = {repair_types[x.upper()] for x in additional_levels.split(",")}
        # comma separated repair names to disable entirely
        self.deactivated_levels = set()
        if deactivated_levels:
            self.deactivated_levels = {repair_types[x.upper()] for x in deactivated_levels.split(",")}

    def fix_error(self, pred_transition, model, state):
        """
        Return which error has been made, if any, along with an updated transition list

        We assume the transition sequence builds a correct tree, meaning
        that there will always be a CloseConstituent sometime after an
        OpenConstituent, for example
        """
        gold_transition = state.gold_sequence[state.num_transitions]
        if gold_transition == pred_transition:
            return self.repair_types.CORRECT, None

        for repair_type in self.repair_types:
            # members with no repair function (eg CORRECT / UNKNOWN) can't fix anything
            if repair_type.fn is None:
                continue
            # skip repairs above the configured level unless explicitly
            # enabled via additional_levels or marked debug
            if self.oracle_level is not None and repair_type.value > self.oracle_level and repair_type not in self.additional_levels and not repair_type.debug:
                continue
            if repair_type in self.deactivated_levels:
                continue
            repair = repair_type.fn(gold_transition, pred_transition, state.gold_sequence, state.num_transitions, self.root_labels, model, state)
            if repair is None:
                continue

            # some repair functions return the (repair_type, sequence) pair directly
            if isinstance(repair, tuple) and len(repair) == 2:
                return repair

            # otherwise repair is just the updated sequence
            # (dropped the original's redundant `if repair is not None`
            # re-check: repair is guaranteed non-None here)
            return repair_type, repair

        return self.repair_types.UNKNOWN, None
stanza/stanza/models/constituency/parse_transitions.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines a series of transitions (open a constituent, close a constituent, etc
3
+
4
+ Also defines a State which holds the various data needed to build
5
+ a parse tree out of tagged words.
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ import ast
10
+ from collections import defaultdict
11
+ from enum import Enum
12
+ import functools
13
+ import logging
14
+
15
+ from stanza.models.constituency.parse_tree import Tree
16
+
17
+ logger = logging.getLogger('stanza')
18
+
19
class TransitionScheme(Enum):
    """
    Ways to linearize a constituency tree into a transition sequence.

    Each member is declared as (numeric value, short name); the
    comments below record dev scores from a vi_vlsp22 experiment
    comparing the schemes.
    """
    def __new__(cls, value, short_name):
        obj = object.__new__(cls)
        obj._value_ = value
        # attach short_name as an extra attribute on each member
        obj.short_name = short_name
        return obj


    # top down, so the open transition comes before any constituents
    # score on vi_vlsp22 with 5 different sizes of bert layers,
    # bert tagger, no silver dataset:
    #   0.8171
    TOP_DOWN = 1, "top"
    # unary transitions are modeled as one entire transition
    # version that uses one transform per item,
    # score on experiment described above:
    #   0.8157
    # score using one combination step for an entire transition:
    #   0.8178
    TOP_DOWN_COMPOUND = 2, "topc"
    # unary is a separate transition.  doesn't help
    # score on experiment described above:
    #   0.8128
    TOP_DOWN_UNARY = 3, "topu"

    # open transition comes after the first constituent it cares about
    # score on experiment described above:
    #   0.8205
    # note that this is with an oracle, whereas IN_ORDER_COMPOUND does
    # not have a dynamic oracle, so there may be room for improvement
    IN_ORDER = 4, "in"

    # in order, with unaries after preterminals represented as a single
    # transition after the preterminal
    # and unaries elsewhere tied to the rest of the constituent
    # score: 0.8186
    IN_ORDER_COMPOUND = 5, "inc"

    # in order, with CompoundUnary on both preterminals and internal nodes
    # score: 0.8166
    IN_ORDER_UNARY = 6, "inu"
60
+
61
@functools.total_ordering
class Transition(ABC):
    """
    Abstract base class for one step of a transition-based parse.

    model is passed in as a dependency injection
    for example, an LSTM model can update hidden & output vectors when transitioning
    """
    @abstractmethod
    def update_state(self, state, model):
        """
        update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent

        the return value should be a tuple:
          updated word_position
          updated constituents
          new constituent to put on the queue and None
            - note that the constituent shouldn't be on the queue yet
              that allows putting it on as a batch operation, which
              saves a significant amount of time in an LSTM, for example
          OR
          data used to make a new constituent and the method used
            - for example, CloseConstituent can return the children needed
              and itself.  this allows a batch operation to build
              the constituent
        """

    def delta_opens(self):
        # net change this transition makes to the number of open constituents;
        # overridden by Open (+1) and Close (-1)
        return 0

    def apply(self, state, model):
        """
        return a new State transformed via this transition

        convenience method to call bulk_apply, which is significantly
        faster than single operations for an NN based model
        """
        update = model.bulk_apply([state], [self])
        return update[0]

    @abstractmethod
    def is_legal(self, state, model):
        """
        assess whether or not this transition is legal in this state

        at parse time, the parser might choose a transition which cannot be made
        """

    def components(self):
        """
        Return a list of transitions which could theoretically make up this transition

        For example, an Open transition with multiple labels would
        return a list of Opens with those labels
        """
        return [self]

    @abstractmethod
    def short_name(self):
        """
        A short name to identify this transition
        """

    def short_label(self):
        # "Name(label)" for labeled transitions, plain short_name otherwise
        if not hasattr(self, "label"):
            return self.short_name()

        if isinstance(self.label, str):
            label = self.label
        elif len(self.label) == 1:
            # unwrap a single-element label tuple for a cleaner display
            label = self.label[0]
        else:
            label = self.label
        return "{}({})".format(self.short_name(), label)

    def __lt__(self, other):
        # put the Shift at the front of a list, and otherwise sort alphabetically
        if self == other:
            return False
        if isinstance(self, Shift):
            return True
        if isinstance(other, Shift):
            return False
        return str(self) < str(other)


    @staticmethod
    def from_repr(desc):
        """
        This method is to avoid using eval() or otherwise trying to
        deserialize strings in a possibly untrusted manner when
        loading from a checkpoint
        """
        # argument-free transitions have a fixed repr
        if desc == 'Shift':
            return Shift()
        if desc == 'CloseConstituent':
            return CloseConstituent()
        # everything else is expected to look like ClassName(labels)
        labels = desc.split("(", maxsplit=1)
        if labels[0] not in ('CompoundUnary', 'OpenConstituent', 'Finalize'):
            raise ValueError("Unknown Transition %s" % desc)
        if len(labels) == 1:
            raise ValueError("Unexpected Transition repr, %s needs labels" % labels[0])
        if labels[1][-1] != ')':
            raise ValueError("Expected Transition repr for %s: %s(labels)" % (labels[0], labels[0]))
        trans_type = labels[0]
        labels = labels[1][:-1]
        # literal_eval only accepts python literals, unlike eval
        labels = ast.literal_eval(labels)
        if trans_type == 'CompoundUnary':
            return CompoundUnary(*labels)
        if trans_type == 'OpenConstituent':
            return OpenConstituent(*labels)
        if trans_type == 'Finalize':
            return Finalize(*labels)
        raise ValueError("Unexpected Transition %s" % desc)
173
+
174
class Shift(Transition):
    """
    Moves the next word from the word queue onto the constituent stack.

    All Shift instances are interchangeable (equal and same hash).
    """
    def update_state(self, state, model):
        """
        This will handle all aspects of a shift transition

        - push the top element of the word queue onto constituents
        - pop the top element of the word queue
        """
        new_constituent = model.transform_word_to_constituent(state)
        return state.word_position+1, state.constituents, new_constituent, None

    def is_legal(self, state, model):
        """
        Disallow shifting when the word queue is empty or there are no opens to eventually eat this word
        """
        if state.empty_word_queue():
            return False
        if model.is_top_down:
            # top down transition sequences cannot shift if there are currently no
            # Open transitions on the stack.  in such a case, the new constituent
            # will never be reduced
            if state.num_opens == 0:
                return False
            if state.num_opens == 1:
                # there must be at least one transition, since there is an open
                assert state.transitions.parent is not None
                if state.transitions.parent.parent is None:
                    # only one transition
                    trans = model.get_top_transition(state.transitions)
                    # must be an Open, since there is one open and one transitions
                    # note that an S, FRAG, etc could happen if we're using unary
                    # and ROOT-S is possible in the case of compound Open
                    # in both cases, Shift is legal
                    # Note that the corresponding problem of shifting after the ROOT-S
                    # has been closed to just ROOT is handled in CloseConstituent
                    if len(trans.label) == 1 and trans.top_label in model.root_labels:
                        # don't shift a word at the very start of a parse
                        # we want there to be an extra layer below ROOT
                        return False
        else:
            # in-order k==1 (the only other option currently)
            # can shift ONCE, but note that there is no way to consume
            # two items in a row if there is no Open on the stack.
            # As long as there is one or more open transitions,
            # everything can be eaten
            if state.num_opens == 0:
                if not state.empty_constituents:
                    return False
        return True

    def short_name(self):
        return "Shift"

    def __repr__(self):
        return "Shift"

    def __eq__(self, other):
        # all Shifts are equivalent
        if self is other:
            return True
        if isinstance(other, Shift):
            return True
        return False

    def __hash__(self):
        # arbitrary constant: every Shift hashes the same, matching __eq__
        return hash(37)
239
+
240
class CompoundUnary(Transition):
    """
    Applies one or more unary bracket labels on top of the current top constituent.
    """
    def __init__(self, *label):
        # the FIRST label will be the top of the tree
        # so CompoundUnary that results in root will have root as labels[0], for example
        self.label = tuple(label)

    def update_state(self, state, model):
        """
        Apply potentially multiple unary transitions to the same preterminal

        It reuses the CloseConstituent machinery
        """
        # only the top constituent is meaningful here
        constituents = state.constituents
        children = [constituents.value]
        constituents = constituents.pop()
        # unlike with CloseConstituent, our label is not on the stack.
        # it is just our label
        # ... but we do reuse CloseConstituent's update mechanism
        return state.word_position, constituents, (self.label, children), CloseConstituent

    def is_legal(self, state, model):
        """
        Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT
        """
        # can't unary transition nothing
        tree = model.get_top_constituent(state.constituents)
        if tree is None:
            return False
        # don't unary transition a dummy, dummy
        # and don't stack CompoundUnary transitions
        if isinstance(model.get_top_transition(state.transitions), (CompoundUnary, OpenConstituent)):
            return False
        # if we are doing IN_ORDER_COMPOUND, then we are only using these
        # transitions to model changes from a tag node to a sequence of
        # unary nodes.  can only occur at preterminals
        if model.transition_scheme() is TransitionScheme.IN_ORDER_COMPOUND:
            return tree.is_preterminal()
        if model.transition_scheme() is not TransitionScheme.TOP_DOWN_UNARY:
            return True

        # TOP_DOWN_UNARY: a ROOT unary is only legal as the very last
        # transition (one constituent, no words left) and vice versa
        is_root = self.label[0] in model.root_labels
        if not state.empty_word_queue() or not state.has_one_constituent():
            return not is_root
        else:
            return is_root

    def components(self):
        # one single-label CompoundUnary per label in this compound
        return [CompoundUnary(label) for label in self.label]

    def short_name(self):
        return "Unary"

    def __repr__(self):
        return "CompoundUnary(%s)" % ",".join(self.label)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, CompoundUnary):
            return False
        if self.label == other.label:
            return True
        return False

    def __hash__(self):
        return hash(self.label)
307
+
308
class Dummy():
    """
    Takes a space on the constituent stack to represent where an Open transition occurred
    """
    def __init__(self, label):
        self.label = label

    def is_preterminal(self):
        # a Dummy never acts as a preterminal
        return False

    def __format__(self, spec):
        # 'O' (or no spec) gives a bracketed placeholder, 'T' a latex qtree stub
        if spec in (None, '', 'O'):
            return "(%s ...)" % self.label
        if spec == 'T':
            return "\Tree [.%s ? ]" % self.label
        raise ValueError("Unhandled spec: %s" % spec)

    def __str__(self):
        return "Dummy({})".format(self.label)

    def __eq__(self, other):
        # Dummies compare equal by label
        return self is other or (isinstance(other, Dummy) and self.label == other.label)

    def __hash__(self):
        return hash(self.label)
339
+
340
def too_many_unary_nodes(tree, unary_limit):
    """
    Return True iff there are UNARY_LIMIT unary nodes in a tree in a row

    helps prevent infinite open/close patterns
    otherwise, the model can get stuck in essentially an infinite loop
    """
    if tree is None:
        return False
    node = tree
    remaining = unary_limit + 1
    # walk down the chain of single-child nodes; bail out as soon as
    # the chain branches or ends
    while remaining > 0:
        if len(node.children) != 1:
            return False
        node = node.children[0]
        remaining -= 1
    return True
354
+
355
class OpenConstituent(Transition):
    """
    Opens a new (possibly multi-label) constituent, to be finished later by a Close.

    Pushes a Dummy marker on the constituent stack so Close knows where
    the constituent started.
    """
    def __init__(self, *label):
        self.label = tuple(label)
        # convenience: the outermost label of this constituent
        self.top_label = self.label[0]

    def delta_opens(self):
        return 1

    def update_state(self, state, model):
        # open a new constituent which can later be closed
        # puts a DUMMY constituent on the stack to mark where the constituents end
        return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None

    def is_legal(self, state, model):
        """
        disallow based on the length of the sentence
        """
        if state.num_opens > state.sentence_length + 10:
            # fudge a bit so we don't miss root nodes etc in very small trees
            # also there's one really deep tree in CTB 9.0
            return False
        if model.is_top_down:
            # If the model is top down, you can't Open if there are
            # no words to eventually eat
            if state.empty_word_queue():
                return False
            # Also, you can only Open a ROOT iff it is at the root position
            # The assumption in the unary scheme is there will be no
            # root open transitions
            if not model.has_unary_transitions():
                # TODO: maybe cache this value if this is an expensive operation
                is_root = self.top_label in model.root_labels
                if is_root:
                    return state.empty_transitions()
                else:
                    return not state.empty_transitions()
        else:
            # in-order nodes can Open as long as there is at least one thing
            # on the constituency stack
            # since closing the in-order involves removing one more
            # item before the open, and it can close at any time
            # (a close immediately after the open represents a unary)
            if state.empty_constituents:
                return False
            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                # consecutive Opens don't make sense in the context of in-order
                return False
            if not model.transition_scheme() is TransitionScheme.IN_ORDER:
                # eg, IN_ORDER_UNARY or IN_ORDER_COMPOUND
                # if compound unary opens are used
                # or the unary transitions are via CompoundUnary
                # can always open as long as the word queue isn't empty
                # if the word queue is empty, only close is allowed
                return not state.empty_word_queue()
            # one other restriction - we assume all parse trees
            # start with (ROOT (first_real_con ...))
            # therefore ROOT can only occur via Open after everything
            # else has been pushed and processed
            # there are no further restrictions
            is_root = self.top_label in model.root_labels
            if is_root:
                # can't make a root node if it will be in the middle of the parse
                # can't make a root node if there's still words to eat
                # note that the second assumption wouldn't work,
                # except we are assuming there will never be multiple
                # nodes under one root
                return state.num_opens == 0 and state.empty_word_queue()
            else:
                if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents), model.unary_limit()):
                    # looks like we've been in a loop of lots of unary transitions
                    # note that we check `num_opens > 0` because otherwise we might wind up stuck
                    # in a state where the only legal transition is open, such as if the
                    # constituent stack is otherwise empty, but the open is illegal because
                    # it causes too many unaries
                    # in such a case we can forbid the corresponding close instead...
                    # if empty_word_queue, that means it is trying to make infinitiely many
                    # non-ROOT Open transitions instead of just transitioning ROOT
                    return False
                return True
        return True

    def components(self):
        # one single-label Open per label in this compound Open
        return [OpenConstituent(label) for label in self.label]

    def short_name(self):
        return "Open"

    def __repr__(self):
        return "OpenConstituent({})".format(self.label)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, OpenConstituent):
            return False
        if self.label == other.label:
            return True
        return False

    def __hash__(self):
        return hash(self.label)
456
+
457
class Finalize(Transition):
    """
    Specifically applies at the end of a parse sequence to add a ROOT

    Seemed like the simplest way to remove ROOT from the
    in_order_compound transitions while still using the mechanism of
    the transitions to build the parse tree
    """
    def __init__(self, *label):
        # label(s) of the root node to wrap the finished tree in
        self.label = tuple(label)

    def update_state(self, state, model):
        """
        Apply potentially multiple unary transitions to the same preterminal

        Only applies to preterminals
        It reuses the CloseConstituent machinery
        """
        # only the top constituent is meaningful here
        constituents = state.constituents
        children = [constituents.value]
        constituents = constituents.pop()
        # unlike with CloseConstituent, our label is not on the stack.
        # it is just our label
        label = self.label

        # ... but we do reuse CloseConstituent's update
        return state.word_position, constituents, (label, children), CloseConstituent

    def is_legal(self, state, model):
        """
        Legal if & only if there is one tree, no more words, and no ROOT yet
        """
        return state.empty_word_queue() and state.has_one_constituent() and not state.finished(model)

    def short_name(self):
        return "Finalize"

    def __repr__(self):
        return "Finalize(%s)" % ",".join(self.label)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Finalize):
            return False
        return other.label == self.label

    def __hash__(self):
        # 53 is an arbitrary salt distinguishing Finalize from other labeled transitions
        return hash((53, self.label))
507
+
508
class CloseConstituent(Transition):
    """
    Closes the currently open constituent.

    Pops children off the constituent stack back to the Dummy marker left
    by the matching Open, then schedules a new node with the Dummy's label
    to be built in a batch by build_constituents.
    """
    def delta_opens(self):
        # closing consumes exactly one open constituent
        return -1

    def update_state(self, state, model):
        """
        Pop children down to the open Dummy marker and return the pieces
        needed to build the new constituent.

        Returns (word_position, remaining_constituents, (label, children),
        CloseConstituent) so the caller can batch the node construction.
        """
        # pop constituents until we are done
        children = []
        constituents = state.constituents
        while not isinstance(model.get_top_constituent(constituents), Dummy):
            # keep the entire value from the stack - the model may need
            # the whole thing to transform the children into a new node
            children.append(constituents.value)
            constituents = constituents.pop()
        # the Dummy has the label on it
        label = model.get_top_constituent(constituents).label
        # pop past the Dummy as well
        constituents = constituents.pop()
        if not model.is_top_down:
            # the alternative to TOP_DOWN_... is IN_ORDER
            # in which case we want to pop one more constituent
            children.append(constituents.value)
            constituents = constituents.pop()
        # the children are in the opposite order of what we expect
        children.reverse()

        return state.word_position, constituents, (label, children), CloseConstituent

    @staticmethod
    def build_constituents(model, data):
        """
        builds new constituents out of the incoming data

        data is a list of tuples: (label, children)
        the model will batch the build operation
        again, the purpose of this batching is to do multiple deep learning operations at once
        """
        labels, children_lists = map(list, zip(*data))
        new_constituents = model.build_constituents(labels, children_lists)
        return new_constituents


    def is_legal(self, state, model):
        """
        Disallow if there is no Open on the stack yet

        in TOP_DOWN, if the previous transition was the Open (nothing built yet)
        in IN_ORDER, previous transition does not matter, except for one small corner case
        """
        if state.num_opens <= 0:
            return False
        if model.is_top_down:
            # note: the TOP_DOWN branch falls through to the final
            # `return True` below when none of these checks reject
            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                return False
            if state.num_opens <= 1 and not state.empty_word_queue():
                # don't close the last open until all words have been used
                return False
            if model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND:
                # when doing TOP_DOWN_COMPOUND, we assume all transitions
                # at the ROOT level have an S, SQ, FRAG, etc underneath
                # this is checked when the model is first trained
                if state.num_opens == 1 and not state.empty_word_queue():
                    return False
            elif not model.has_unary_transitions():
                # in fact, we have to leave the top level constituent
                # under the ROOT open if unary transitions are not possible
                if state.num_opens == 2 and not state.empty_word_queue():
                    return False
        elif model.transition_scheme() is TransitionScheme.IN_ORDER:
            if not isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                # we're not stuck in a loop of unaries
                return True
            if state.num_opens > 1 or state.empty_word_queue():
                # in either of these cases, the corresponding Open should be eliminated
                # if we're stuck in a loop of unaries
                return True
            node = model.get_top_constituent(state.constituents.pop())
            if too_many_unary_nodes(node, model.unary_limit()):
                # at this point, we are in a situation where
                # - multiple unaries have happened in a row
                # - there is stuff on the word_queue, so a ROOT open isn't legal
                # - there's only one constituent on the stack, so the only legal
                #   option once there are no opens left will be an open
                # this means we'll be stuck having to open again if we do close
                # this node, so instead we make the Close illegal
                return False
        else:
            # model.transition_scheme() == TransitionScheme.IN_ORDER_COMPOUND or
            # model.transition_scheme() == TransitionScheme.IN_ORDER_UNARY:
            # in both of these cases, we cannot do open/close
            # IN_ORDER_COMPOUND will use compound opens and preterminal unaries
            # IN_ORDER_UNARY will use compound unaries
            # the only restriction here is that we can't close immediately after an open
            # internal unaries are handled by the opens being compound
            # preterminal unaries are handled with CompoundUnary
            if isinstance(model.get_top_transition(state.transitions), OpenConstituent):
                return False
        return True

    def short_name(self):
        return "Close"

    def __repr__(self):
        return "CloseConstituent"

    def __eq__(self, other):
        # CloseConstituent carries no state, so any two instances are equal
        if self is other:
            return True
        if isinstance(other, CloseConstituent):
            return True
        return False

    def __hash__(self):
        # 93 is a per-class constant; all CloseConstituents hash the same
        return hash(93)
622
def check_transitions(train_transitions, other_transitions, treebank_name):
    """
    Check that all the transitions in the other dataset are known in the train set

    Weird nested unaries are warned rather than failed as long as the
    components are all known

    There is a tree in VLSP, for example, with three (!) nested NP nodes
    If this is an unknown compound transition, we won't possibly get it
    right when parsing, but at least we don't need to fail
    """
    unknown_transitions = set()
    for transition in other_transitions:
        if transition in train_transitions:
            continue
        # the compound transition itself is unknown; tolerate it as long
        # as every one of its component transitions is known
        for component in transition.components():
            if component not in train_transitions:
                raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(transition, treebank_name))
        unknown_transitions.add(transition)
    if unknown_transitions:
        logger.warning("Found transitions where the components are all valid transitions, but the complete transition is unknown: %s", sorted(unknown_transitions))
stanza/stanza/models/constituency/parser_training.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter, namedtuple
2
+ import copy
3
+ import logging
4
+ import os
5
+ import random
6
+ import re
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ #from stanza.models.common import pretrain
12
+
13
+ from stanza.models.common import utils
14
+ from stanza.models.common.foundation_cache import FoundationCache, NoTransformerFoundationCache
15
+ from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss
16
+ from stanza.models.common.utils import sort_with_indices, unsort
17
+ from stanza.models.constituency import parse_transitions
18
+ from stanza.models.constituency import transition_sequence
19
+ from stanza.models.constituency import tree_reader
20
+ from stanza.models.constituency.in_order_compound_oracle import InOrderCompoundOracle
21
+ from stanza.models.constituency.in_order_oracle import InOrderOracle
22
+ from stanza.models.constituency.lstm_model import LSTMModel
23
+ from stanza.models.constituency.parse_transitions import TransitionScheme
24
+ from stanza.models.constituency.parse_tree import Tree
25
+ from stanza.models.constituency.top_down_oracle import TopDownOracle
26
+ from stanza.models.constituency.trainer import Trainer
27
+ from stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler, verify_transitions, get_open_nodes, check_constituents, check_root_labels, remove_duplicate_trees, remove_singleton_trees
28
+ from stanza.server.parser_eval import EvaluateParser, ParseResult
29
+ from stanza.utils.get_tqdm import get_tqdm
30
+
31
# progress-bar wrapper supplied by stanza.utils.get_tqdm
tqdm = get_tqdm()

tlogger = logging.getLogger('stanza.constituency.trainer')

# one training example: the gold tree, its gold transition sequence, and
# the per-tree preterminal list (presumably the tag/word skeleton used to
# seed parser states -- see compose_train_data for how these are built)
TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])
37
class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
    """
    Accumulated statistics for one training epoch.

    Adding two EpochStats combines them field by field, so per-batch
    stats can be summed into epoch totals.
    """
    def __add__(self, other):
        # namedtuples iterate in declared field order, so zipping self with
        # other pairs each field with its counterpart before summing
        return EpochStats(*(mine + theirs for mine, theirs in zip(self, other)))
47
def evaluate(args, model_file, retag_pipeline):
    """
    Loads the given model file and tests the eval_file treebank.

    May retag the trees using retag_pipeline
    Uses a subprocess to run the Java EvalB code

    args: dict of options; reads num_generate, eval_file, the pretrain /
      charlm file paths, device, log_shapes, log_norms, retag_package
    model_file: path to a saved constituency parser
    retag_pipeline: a stanza Pipeline used to replace the gold tags, or None
    """
    # we create the Evaluator here because otherwise the transformers
    # library constantly complains about forking the process
    # note that this won't help in the event of training multiple
    # models in the same run, although since that would take hours
    # or days, that's not a very common problem
    if args['num_generate'] > 0:
        # +1 so the kbest list can hold the base parse plus the
        # num_generate alternates -- presumably; confirm in run_dev_set
        kbest = args['num_generate'] + 1
    else:
        kbest = None

    with EvaluateParser(kbest=kbest) as evaluator:
        # reuse the retag pipeline's cache when available so the charlm /
        # pretrain / transformer are not loaded twice
        foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
        load_args = {
            'wordvec_pretrain_file': args['wordvec_pretrain_file'],
            'charlm_forward_file': args['charlm_forward_file'],
            'charlm_backward_file': args['charlm_backward_file'],
            'device': args['device'],
        }
        trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache)

        if args['log_shapes']:
            trainer.log_shapes()

        treebank = tree_reader.read_treebank(args['eval_file'])
        tlogger.info("Read %d trees for evaluation", len(treebank))

        retagged_treebank = treebank
        if retag_pipeline is not None:
            # the saved model records which tag field it was trained on
            retag_method = trainer.model.retag_method
            retag_xpos = retag_method == 'xpos'
            tlogger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package'])
            retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos)
            tlogger.info("Retagging finished")

        if args['log_norms']:
            trainer.log_norms()
        # score the retagged trees against the original gold treebank
        f1, kbestF1, _ = run_dev_set(trainer.model, retagged_treebank, treebank, args, evaluator)
        tlogger.info("F1 score on %s: %f", args['eval_file'], f1)
        if kbestF1 is not None:
            tlogger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1)
95
def remove_optimizer(args, model_save_file, model_load_file):
    """
    A utility method to remove the optimizer from a save file

    Will make the save file a lot smaller
    """
    # TODO: kind of overkill to load in the pretrain rather than
    # change the load/save to work without it, but probably this
    # functionality isn't used that often anyway
    load_keys = ('wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file', 'device')
    load_args = {key: args[key] for key in load_keys}
    trainer = Trainer.load(model_load_file, args=load_args, load_optimizer=False)
    trainer.save(model_save_file)
113
def add_grad_clipping(trainer, grad_clipping):
    """
    Adds a torch.clamp hook on each parameter if grad_clipping is not None

    The hook clamps every gradient elementwise to
    [-grad_clipping, grad_clipping] during backward.
    """
    if grad_clipping is None:
        return
    clamp_hook = lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping)
    for param in trainer.model.parameters():
        if param.requires_grad:
            param.register_hook(clamp_hook)
122
def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file):
    """
    Builds a Trainer (with model) and the train_sequences and transitions for the given trees.

    Collects the vocabulary of constituents, tags, words, and transitions
    from the training trees, sanity-checks the dev and silver sets against
    them, and returns
    (trainer, train_sequences, silver_sequences, train_transitions).
    """
    train_constituents = Tree.get_unique_constituent_labels(train_trees)
    tlogger.info("Unique constituents in training set: %s", train_constituents)
    if args['check_valid_states']:
        check_constituents(train_constituents, dev_trees, "dev", fail=args['strict_check_constituents'])
        check_constituents(train_constituents, silver_trees, "silver", fail=args['strict_check_constituents'])
    constituent_counts = Tree.get_constituent_counts(train_trees)
    tlogger.info("Constituent node counts: %s", constituent_counts)

    tags = Tree.get_unique_tags(train_trees)
    if None in tags:
        raise RuntimeError("Fatal problem: the tagger put None on some of the nodes!")
    tlogger.info("Unique tags in training set: %s", tags)
    # no need to fail for missing tags between train/dev set
    # the model has an unknown tag embedding
    for tag in Tree.get_unique_tags(dev_trees):
        if tag not in tags:
            tlogger.info("Found tag in dev set which does not exist in train set: %s Continuing...", tag)

    # +1 gives the parser headroom over the deepest unary chain observed
    unary_limit = max(max(t.count_unary_depth() for t in train_trees),
                      max(t.count_unary_depth() for t in dev_trees)) + 1
    if silver_trees:
        unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees))
    tlogger.info("Unary limit: %d", unary_limit)
    train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], args['reversed'])
    dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], args['reversed'])
    silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], args['reversed'])

    tlogger.info("Total unique transitions in train set: %d", len(train_transitions))
    tlogger.info("Unique transitions in training set:\n %s", "\n ".join(map(str, train_transitions)))
    # include the components of compound transitions so that a dev/silver
    # compound whose pieces are all known is tolerated by check_transitions
    expanded_train_transitions = set(train_transitions + [x for trans in train_transitions for x in trans.components()])
    if args['check_valid_states']:
        parse_transitions.check_transitions(expanded_train_transitions, dev_transitions, "dev")
        # theoretically could just train based on the items in the silver dataset
        parse_transitions.check_transitions(expanded_train_transitions, silver_transitions, "silver")

    root_labels = Tree.get_root_labels(train_trees)
    check_root_labels(root_labels, dev_trees, "dev")
    check_root_labels(root_labels, silver_trees, "silver")
    tlogger.info("Root labels in treebank: %s", root_labels)

    # replay the gold sequences to confirm they rebuild the original trees
    verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], "train", root_labels)
    verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit, args['reversed'], "dev", root_labels)

    # we don't check against the words in the dev set as it is
    # expected there will be some UNK words
    words = Tree.get_unique_words(train_trees)
    rare_words = Tree.get_rare_words(train_trees, args['rare_word_threshold'])
    # rare/unknown silver words will just get UNK if they are not already known
    if silver_trees and args['use_silver_words']:
        tlogger.info("Getting silver words to add to the delta embedding")
        silver_words = Tree.get_common_words(tqdm(silver_trees, postfix='Silver words'), len(words))
        words = sorted(set(words + silver_words))

    # also, it's not actually an error if there is a pattern of
    # compound unary or compound open nodes which doesn't exist in the
    # train set. it just means we probably won't ever get that right
    open_nodes = get_open_nodes(train_trees, args['transition_scheme'])
    tlogger.info("Using the following open nodes:\n %s", "\n ".join(map(str, open_nodes)))

    # at this point we have:
    # pretrain
    # train_trees, dev_trees
    # lists of transitions, internal nodes, and root states the parser needs to be aware of

    trainer = Trainer.build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file)

    trainer.log_num_words_known(words)
    # grad clipping is not saved with the rest of the model,
    # so even in the case of a model we saved,
    # we now have to add the grad clipping
    add_grad_clipping(trainer, args['grad_clipping'])

    return trainer, train_sequences, silver_sequences, train_transitions
200
def train(args, model_load_file, retag_pipeline):
    """
    Build a model, train it using the requested train & dev files

    Reads the train/dev (and optional silver) treebanks, optionally
    retags them, builds a Trainer via build_trainer, and runs the
    training loop.  Returns the resulting Trainer.
    """
    utils.log_training_args(args, tlogger)

    # we create the Evaluator here because otherwise the transformers
    # library constantly complains about forking the process
    # note that this won't help in the event of training multiple
    # models in the same run, although since that would take hours
    # or days, that's not a very common problem
    if args['num_generate'] > 0:
        kbest = args['num_generate'] + 1
    else:
        kbest = None

    if args['wandb']:
        # wandb is imported lazily (and made module-global) so the
        # dependency is only required when --wandb is requested
        global wandb
        import wandb
        wandb_name = args['wandb_name'] if args['wandb_name'] else "%s_constituency" % args['shorthand']
        wandb.init(name=wandb_name, config=args)
        wandb.run.define_metric('dev_score', summary='max')

    with EvaluateParser(kbest=kbest) as evaluator:
        utils.ensure_dir(args['save_dir'])

        train_trees = tree_reader.read_treebank(args['train_file'])
        tlogger.info("Read %d trees for the training set", len(train_trees))
        if args['train_remove_duplicates']:
            train_trees = remove_duplicate_trees(train_trees, "train")
        # single-word trees are dropped unconditionally
        train_trees = remove_singleton_trees(train_trees)

        dev_trees = tree_reader.read_treebank(args['eval_file'])
        tlogger.info("Read %d trees for the dev set", len(dev_trees))
        dev_trees = remove_duplicate_trees(dev_trees, "dev")

        silver_trees = []
        if args['silver_file']:
            silver_trees = tree_reader.read_treebank(args['silver_file'])
            tlogger.info("Read %d trees for the silver training set", len(silver_trees))
            if args['silver_remove_duplicates']:
                silver_trees = remove_duplicate_trees(silver_trees, "silver")

        if retag_pipeline is not None:
            tlogger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package'])
            train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])
            dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos'])
            silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos'])
            tlogger.info("Retagging finished")

        # reuse the retag pipeline's cache when available so the
        # foundation models are only loaded once
        foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache()
        trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file)

        if args['log_shapes']:
            trainer.log_shapes()
        trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator)

        if args['wandb']:
            wandb.finish()

        return trainer
262
def compose_train_data(trees, sequences):
    """
    Zip trees, gold sequences, and per-tree preterminal skeletons into TrainItems.

    The preterminal copies keep only the tag label and the word label,
    discarding any other structure from the original nodes.
    """
    preterminal_lists = []
    for tree in trees:
        skeletons = [Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label))
                     for preterminal in tree.yield_preterminals()]
        preterminal_lists.append(skeletons)
    return [TrainItem(*item) for item in zip(trees, sequences, preterminal_lists)]
269
def next_epoch_data(leftover_training_data, train_data, epoch_size):
    """
    Return (new_leftover, epoch_data): the next epoch_size trees to train on.

    The epoch starts with any trees left over from the previous epoch,
    then appends freshly shuffled passes over the full training set until
    at least epoch_size items are available.  This way every tree is
    touched once before any tree is seen a second time.

    Note: train_data is shuffled in place, and the leftover list is also
    extended in place when it is shorter than epoch_size.
    """
    if not train_data:
        return [], []

    batch = leftover_training_data
    while len(batch) < epoch_size:
        random.shuffle(train_data)
        batch.extend(train_data)
    return batch[epoch_size:], batch[:epoch_size]
292
def update_bert_learning_rate(args, optimizer, epochs_trained):
    """
    Update the learning rate for the bert finetuning, if applicable

    The bert group's lr is set to 0 outside the configured finetuning
    window, to a stage1 multiple of the base lr during the first half of
    a multistage run, and otherwise to the standard bert multiple.
    """
    # would be nice to have a parameter group specific scheduler
    # however, there is an issue with the optimizer we had the most success with, madgrad
    # when the learning rate is 0 for a group, it still learns by some
    # small amount because of the eps parameter
    # in fact, that is enough to make the learning for the bert in the
    # second half broken
    base_group = next((group for group in optimizer.param_groups
                       if group['param_group_name'] == 'base'), None)
    if base_group is None:
        raise AssertionError("There should always be a base parameter group")

    for group in optimizer.param_groups:
        if group['param_group_name'] != 'bert':
            continue
        # Occasionally a model goes haywire and forgets how to use the transformer
        # So far we have only seen this happen with Electra on the non-NML version of PTB
        # We tried fixing that with an increasing transformer learning rate, but that
        # didn't fully resolve the problem
        # Switching to starting the finetuning after a few epochs seems to help a lot, though
        previous_lr = group['lr']
        begin_epoch = args['bert_finetune_begin_epoch']
        end_epoch = args['bert_finetune_end_epoch']
        if begin_epoch is not None and epochs_trained < begin_epoch:
            new_lr = 0.0
        elif end_epoch is not None and epochs_trained >= end_epoch:
            new_lr = 0.0
        elif args['multistage'] and epochs_trained < args['epochs'] // 2:
            new_lr = base_group['lr'] * args['stage1_bert_learning_rate']
        else:
            new_lr = base_group['lr'] * args['bert_learning_rate']
        group['lr'] = new_lr
        if new_lr != previous_lr:
            tlogger.info("Setting %s finetuning rate from %f to %f", group['param_group_name'], previous_lr, new_lr)
326
+ def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator):
327
+ """
328
+ Given an initialized model, a processed dataset, and a secondary dev dataset, train the model
329
+
330
+ The training is iterated in the following loop:
331
+ extract a batch of trees of the same length from the training set
332
+ convert those trees into initial parsing states
333
+ repeat until trees are done:
334
+ batch predict the model's interpretation of the current states
335
+ add the errors to the list of things to backprop
336
+ advance the parsing state for each of the trees
337
+ """
338
+ # Somewhat unusual, but possibly related to the extreme variability in length of trees
339
+ # Various experiments generally show about 0.5 F1 loss on various
340
+ # datasets when using 'mean' instead of 'sum' for reduction
341
+ # (Remember to adjust the weight decay when rerunning that experiment)
342
+ if args['loss'] == 'cross':
343
+ tlogger.info("Building CrossEntropyLoss(sum)")
344
+ process_outputs = lambda x: x
345
+ model_loss_function = nn.CrossEntropyLoss(reduction='sum')
346
+ elif args['loss'] == 'focal':
347
+ try:
348
+ from focal_loss.focal_loss import FocalLoss
349
+ except ImportError:
350
+ raise ImportError("focal_loss not installed. Must `pip install focal_loss_torch` to use the --loss=focal feature")
351
+ tlogger.info("Building FocalLoss, gamma=%f", args['loss_focal_gamma'])
352
+ process_outputs = lambda x: torch.softmax(x, dim=1)
353
+ model_loss_function = FocalLoss(reduction='sum', gamma=args['loss_focal_gamma'])
354
+ elif args['loss'] == 'large_margin':
355
+ tlogger.info("Building LargeMarginInSoftmaxLoss(sum)")
356
+ process_outputs = lambda x: x
357
+ model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum')
358
+ else:
359
+ raise ValueError("Unexpected loss term: %s" % args['loss'])
360
+
361
+ device = trainer.device
362
+ model_loss_function.to(device)
363
+ transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0)
364
+ for (y, x) in enumerate(trainer.transitions)}
365
+ trainer.train()
366
+
367
+ train_data = compose_train_data(train_trees, train_sequences)
368
+ silver_data = compose_train_data(silver_trees, silver_sequences)
369
+
370
+ if not args['epoch_size']:
371
+ args['epoch_size'] = len(train_data)
372
+ if silver_data and not args['silver_epoch_size']:
373
+ args['silver_epoch_size'] = args['epoch_size']
374
+
375
+ if args['multistage']:
376
+ multistage_splits = {}
377
+ # if we're halfway, only do pattn. save lattn for next time
378
+ multistage_splits[args['epochs'] // 2] = (args['pattn_num_layers'], False)
379
+ if LSTMModel.uses_lattn(args):
380
+ multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True)
381
+
382
+ oracle = None
383
+ if args['transition_scheme'] is TransitionScheme.IN_ORDER:
384
+ oracle = InOrderOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])
385
+ elif args['transition_scheme'] is TransitionScheme.IN_ORDER_COMPOUND:
386
+ oracle = InOrderCompoundOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])
387
+ elif args['transition_scheme'] is TransitionScheme.TOP_DOWN:
388
+ oracle = TopDownOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels'])
389
+
390
+ leftover_training_data = []
391
+ leftover_silver_data = []
392
+ if trainer.best_epoch > 0:
393
+ tlogger.info("Restarting trainer with a model trained for %d epochs. Best epoch %d, f1 %f", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1)
394
+
395
+ # if we're training a new model, save the initial state so it can be inspected
396
+ if args['save_each_start'] == 0 and trainer.epochs_trained == 0:
397
+ trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=True)
398
+
399
+ # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1
400
+ for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1):
401
+ trainer.train()
402
+ tlogger.info("Starting epoch %d", trainer.epochs_trained)
403
+ update_bert_learning_rate(args, trainer.optimizer, trainer.epochs_trained)
404
+
405
+ if args['log_norms']:
406
+ trainer.log_norms()
407
+ leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size'])
408
+ leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['silver_epoch_size'])
409
+ epoch_data = epoch_data + epoch_silver_data
410
+ epoch_data.sort(key=lambda x: len(x[1]))
411
+
412
+ epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args)
413
+
414
+ # print statistics
415
+ # by now we've forgotten about the original tags on the trees,
416
+ # but it doesn't matter for hill climbing
417
+ f1, _, _ = run_dev_set(trainer.model, dev_trees, dev_trees, args, evaluator)
418
+ if f1 > trainer.best_f1 or (trainer.best_epoch == 0 and trainer.best_f1 == 0.0):
419
+ # best_epoch == 0 to force a save of an initial model
420
+ # useful for tests which expect something, even when a
421
+ # very simple model didn't learn anything
422
+ tlogger.info("New best dev score: %.5f > %.5f", f1, trainer.best_f1)
423
+ trainer.best_f1 = f1
424
+ trainer.best_epoch = trainer.epochs_trained
425
+ trainer.save(args['save_name'], save_optimizer=False)
426
+ if epoch_stats.nans > 0:
427
+ tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans)
428
+ stats_log_lines = [
429
+ "Epoch %d finished" % trainer.epochs_trained,
430
+ "Transitions correct: %s" % epoch_stats.transitions_correct,
431
+ "Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
432
+ "Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
433
+ "Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
434
+ "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
435
+ ]
436
+ tlogger.info("\n ".join(stats_log_lines))
437
+
438
+ old_lr = trainer.optimizer.param_groups[0]['lr']
439
+ trainer.scheduler.step(f1)
440
+ new_lr = trainer.optimizer.param_groups[0]['lr']
441
+ if old_lr != new_lr:
442
+ tlogger.info("Updating learning rate from %f to %f", old_lr, new_lr)
443
+
444
+ if args['wandb']:
445
+ wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained)
446
+ if args['wandb_norm_regex']:
447
+ watch_regex = re.compile(args['wandb_norm_regex'])
448
+ for n, p in trainer.model.named_parameters():
449
+ if watch_regex.search(n):
450
+ wandb.log({n: torch.linalg.norm(p)})
451
+
452
+ if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']:
453
+ if any(x > 0.0 for x in (trainer.model.word_dropout.p, trainer.model.predict_dropout.p, trainer.model.lstm_input_dropout.p)):
454
+ tlogger.info("Setting dropout to 0.0 at epoch %d", trainer.epochs_trained)
455
+ trainer.model.word_dropout.p = 0
456
+ trainer.model.predict_dropout.p = 0
457
+ trainer.model.lstm_input_dropout.p = 0
458
+
459
+ # recreate the optimizer and alter the model as needed if we hit a new multistage split
460
+ if args['multistage'] and trainer.epochs_trained in multistage_splits:
461
+ # we may be loading a save model from an earlier epoch if the scores stopped increasing
462
+ epochs_trained = trainer.epochs_trained
463
+ batches_trained = trainer.batches_trained
464
+
465
+ stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained]
466
+
467
+ # when loading the model, let the saved model determine whether it has pattn or lattn
468
+ temp_args = copy.deepcopy(trainer.model.args)
469
+ temp_args.pop('pattn_num_layers', None)
470
+ temp_args.pop('lattn_d_proj', None)
471
+ # overwriting the old trainer & model will hopefully free memory
472
+ # load a new bert, even in PEFT mode, mostly so that the bert model
473
+ # doesn't collect a whole bunch of PEFTs
474
+ # for one thing, two PEFTs would mean 2x the optimizer parameters,
475
+ # messing up saving and loading the optimizer without jumping
476
+ # through more hoops
477
+ # loading the trainer w/o the foundation_cache should create
478
+ # the necessary bert_model and bert_tokenizer, and then we
479
+ # can reuse those values when building out new LSTMModel
480
+ trainer = Trainer.load(args['save_name'], temp_args, load_optimizer=False)
481
+ model = trainer.model
482
+ tlogger.info("Finished stage at epoch %d. Restarting optimizer", epochs_trained)
483
+ tlogger.info("Previous best model was at epoch %d", trainer.epochs_trained)
484
+
485
+ temp_args = dict(args)
486
+ tlogger.info("Switching to a model with %d pattn layers and %slattn", stage_pattn_layers, "" if stage_uses_lattn else "NO ")
487
+ temp_args['pattn_num_layers'] = stage_pattn_layers
488
+ if not stage_uses_lattn:
489
+ temp_args['lattn_d_proj'] = 0
490
+ pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file'])
491
+ forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file'])
492
+ backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file'])
493
+ new_model = LSTMModel(pt,
494
+ forward_charlm,
495
+ backward_charlm,
496
+ model.bert_model,
497
+ model.bert_tokenizer,
498
+ model.force_bert_saved,
499
+ model.peft_name,
500
+ model.transitions,
501
+ model.constituents,
502
+ model.tags,
503
+ model.delta_words,
504
+ model.rare_words,
505
+ model.root_labels,
506
+ model.constituent_opens,
507
+ model.unary_limit(),
508
+ temp_args)
509
+ new_model.to(device)
510
+ new_model.copy_with_new_structure(model)
511
+
512
+ optimizer = build_optimizer(temp_args, new_model, False)
513
+ scheduler = build_scheduler(temp_args, optimizer)
514
+ trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch)
515
+ add_grad_clipping(trainer, args['grad_clipping'])
516
+
517
+ # checkpoint needs to be saved AFTER rebuilding the optimizer
518
+ # so that assumptions about the optimizer in the checkpoint
519
+ # can be made based on the end of the epoch
520
+ if args['checkpoint'] and args['checkpoint_save_name']:
521
+ trainer.save(args['checkpoint_save_name'], save_optimizer=True)
522
+ # same with the "each filename", actually, in case those are
523
+ # brought back for more training or even just for testing
524
+ if args['save_each_start'] is not None and args['save_each_start'] <= trainer.epochs_trained and trainer.epochs_trained % args['save_each_frequency'] == 0:
525
+ trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=args['save_each_optimizer'])
526
+
527
+ return trainer
528
+
529
def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args):
    """Run one full training epoch over epoch_data and return the summed EpochStats.

    Batches are visited in a randomly shuffled order.  A batch whose loss
    came out NaN is skipped (no optimizer step) but still folded into the
    returned stats so the caller can see how often that happened.
    """
    batch_size = args['train_batch_size']
    batch_starts = list(range(0, len(epoch_data), batch_size))
    random.shuffle(batch_starts)

    optimizer = trainer.optimizer

    epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)

    for batch_idx, start in enumerate(tqdm(batch_starts, postfix="Epoch %d" % epoch)):
        batch = epoch_data[start:start + batch_size]
        batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args)
        trainer.batches_trained += 1

        # Early in the training, some trees will be degenerate in a
        # way that results in layers going up the tree amplifying the
        # weights until they overflow.  Generally that problem
        # resolves itself in a few iterations, so for now we just
        # ignore those batches, but report how often it happens
        if batch_stats.nans == 0:
            optimizer.step()
            optimizer.zero_grad()
        epoch_stats = epoch_stats + batch_stats

    # TODO: refactor the logging?
    total_correct = sum(epoch_stats.transitions_correct.values())
    total_incorrect = sum(epoch_stats.transitions_incorrect.values())
    tlogger.info("Transitions correct: %d\n %s", total_correct, str(epoch_stats.transitions_correct))
    tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, str(epoch_stats.transitions_incorrect))
    if len(epoch_stats.repairs_used) > 0:
        tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common()))
    if epoch_stats.fake_transitions_used > 0:
        tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used)

    return epoch_stats
564
+
565
def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args):
    """
    Train the model for one batch

    The model itself will be updated, and a bunch of stats are returned
    as an EpochStats(loss, correct, incorrect, repairs, fakes, nans).
    The caller is responsible for the optimizer step; this function only
    accumulates gradients via tree_loss.backward().

    BUGFIX: the args['watch_regex'] debugging block used to iterate
    trainer.model.named_parameters(), but `trainer` is not in scope in
    this function (the model is passed in directly), so turning on
    watch_regex raised a NameError.  It now uses `model`.
    """
    # now we add the state to the trees in the batch
    # the state is built as a bulk operation
    current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch],
                                                          [x.tree for x in training_batch],
                                                          [x.gold_sequence for x in training_batch])

    transitions_correct = Counter()
    transitions_incorrect = Counter()
    repairs_used = Counter()
    fake_transitions_used = 0

    all_errors = []
    all_answers = []

    # we iterate through the batch in the following sequence:
    # predict the logits and the applied transition for each tree in the batch
    # collect errors
    #  - we always train to the desired one-hot vector
    #    this was a noticeable improvement over training just the
    #    incorrect transitions
    # determine whether the training can continue using the "student" transition
    #   or if we need to use teacher forcing
    # update all states using either the gold or predicted transition
    # any trees which are now finished are removed from the training cycle
    while len(current_batch) > 0:
        outputs, pred_transitions, _ = model.predict(current_batch, is_legal=False)
        gold_transitions = [x.gold_sequence[x.num_transitions] for x in current_batch]
        trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions]
        all_errors.append(outputs)
        all_answers.extend(trans_tensor)

        new_batch = []
        update_transitions = []
        for pred_transition, gold_transition, state in zip(pred_transitions, gold_transitions, current_batch):
            # forget teacher forcing vs scheduled sampling
            # we're going with idiot forcing
            if pred_transition == gold_transition:
                transitions_correct[gold_transition.short_name()] += 1
                if state.num_transitions + 1 < len(state.gold_sequence):
                    if oracle is not None and epoch >= args['oracle_initial_epoch'] and random.random() < args['oracle_forced_errors']:
                        # TODO: could randomly choose from the legal transitions
                        # perhaps the second best scored transition
                        fake_transition = random.choice(model.transitions)
                        if fake_transition.is_legal(state, model):
                            _, new_sequence = oracle.fix_error(fake_transition, model, state)
                            if new_sequence is not None:
                                new_batch.append(state._replace(gold_sequence=new_sequence))
                                update_transitions.append(fake_transition)
                                fake_transitions_used = fake_transitions_used + 1
                                continue
                    new_batch.append(state)
                    update_transitions.append(gold_transition)
                continue

            transitions_incorrect[gold_transition.short_name(), pred_transition.short_name()] += 1
            # if we are on the final operation, there are two choices:
            #   - the parsing mode is IN_ORDER, and the final transition
            #     is the close to end the sequence, which has no alternatives
            #   - the parsing mode is something else, in which case
            #     we have no oracle anyway
            if state.num_transitions + 1 >= len(state.gold_sequence):
                continue

            if oracle is None or epoch < args['oracle_initial_epoch'] or not pred_transition.is_legal(state, model):
                new_batch.append(state)
                update_transitions.append(gold_transition)
                continue

            repair_type, new_sequence = oracle.fix_error(pred_transition, model, state)
            # we can only reach here on an error
            assert not repair_type.is_correct
            repairs_used[repair_type] += 1
            if new_sequence is not None and random.random() < args['oracle_frequency']:
                new_batch.append(state._replace(gold_sequence=new_sequence))
                update_transitions.append(pred_transition)
            else:
                new_batch.append(state)
                update_transitions.append(gold_transition)

        if len(current_batch) > 0:
            # bulk update states - significantly faster
            current_batch = model.bulk_apply(new_batch, update_transitions, fail=True)

    errors = torch.cat(all_errors)
    answers = torch.cat(all_answers)

    errors = process_outputs(errors)
    tree_loss = model_loss_function(errors, answers)
    tree_loss.backward()
    if args['watch_regex']:
        matched = False
        tlogger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx)
        watch_regex = re.compile(args['watch_regex'])
        # was trainer.model.named_parameters(): NameError, trainer not in scope here
        for n, p in model.named_parameters():
            if watch_regex.search(n):
                matched = True
                if p.requires_grad and p.grad is not None:
                    tlogger.info("  %s norm: %f grad: %f", n, torch.linalg.norm(p), torch.linalg.norm(p.grad))
                elif p.requires_grad:
                    tlogger.info("  %s norm: %f grad required, but is None!", n, torch.linalg.norm(p))
                else:
                    tlogger.info("  %s norm: %f grad not required", n, torch.linalg.norm(p))
        if not matched:
            tlogger.info("  (none found!)")
    if torch.any(torch.isnan(tree_loss)):
        batch_loss = 0.0
        nans = 1
    else:
        batch_loss = tree_loss.item()
        nans = 0

    return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
687
+
688
def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
    """
    This reparses a treebank and executes the CoreNLP Java EvalB code.

    It only works if CoreNLP 4.3.0 or higher is in the classpath.

    Returns a triple (f1, kbestF1, treeF1).  kbestF1 is None unless the
    evaluation ran in kbest mode.

    BUGFIX: the empty-treebank early exit used to return a 2-tuple
    (0.0, 0.0), which crashed any caller unpacking the documented
    3-tuple.  It now returns (0.0, 0.0, 0.0).
    """
    tlogger.info("Processing %d trees from %s", len(retagged_trees), args['eval_file'])
    model.eval()

    num_generate = args.get('num_generate', 0)
    keep_scores = num_generate > 0

    # parse longest trees first, then restore the original document order
    sorted_trees, original_indices = sort_with_indices(retagged_trees, key=len, reverse=True)
    tree_iterator = iter(tqdm(sorted_trees))
    treebank = model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.predict, keep_scores=keep_scores)
    treebank = unsort(treebank, original_indices)
    full_results = treebank

    if num_generate > 0:
        tlogger.info("Generating %d random analyses", args['num_generate'])
        generated_treebanks = [treebank]
        for i in tqdm(range(num_generate)):
            tree_iterator = iter(tqdm(retagged_trees, leave=False, postfix="tb%03d" % i))
            generated_treebanks.append(model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.weighted_choice, keep_scores=keep_scores))

        #best_treebank = [ParseResult(parses[0].gold, [max([p.predictions[0] for p in parses], key=itemgetter(1))], None, None)
        #                 for parses in zip(*generated_treebanks)]
        #generated_treebanks = [best_treebank] + generated_treebanks

        # TODO: if the model is dropping trees, this will not work
        full_results = [ParseResult(parses[0].gold, [p.predictions[0] for p in parses], None, None)
                        for parses in zip(*generated_treebanks)]

    if len(full_results) < len(retagged_trees):
        tlogger.warning("Only evaluating %d trees instead of %d", len(full_results), len(retagged_trees))
    else:
        # score against the original gold trees, not the retagged copies
        full_results = [x._replace(gold=gold) for x, gold in zip(full_results, original_trees)]

    if args.get('mode', None) == 'predict' and args['predict_file']:
        utils.ensure_dir(args['predict_dir'], verbose=False)
        pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".pred.mrg")
        orig_file = os.path.join(args['predict_dir'], args['predict_file'] + ".orig.mrg")
        if os.path.exists(pred_file):
            tlogger.warning("Cowardly refusing to overwrite {}".format(pred_file))
        elif os.path.exists(orig_file):
            tlogger.warning("Cowardly refusing to overwrite {}".format(orig_file))
        else:
            with open(pred_file, 'w') as fout:
                for tree in full_results:
                    output_tree = tree.predictions[0].tree
                    if args['predict_output_gold_tags']:
                        output_tree = output_tree.replace_tags(tree.gold)
                    fout.write(args['predict_format'].format(output_tree))
                    fout.write("\n")

            for i in range(num_generate):
                pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".%03d.pred.mrg" % i)
                with open(pred_file, 'w') as fout:
                    for tree in generated_treebanks[-(i+1)]:
                        output_tree = tree.predictions[0].tree
                        if args['predict_output_gold_tags']:
                            output_tree = output_tree.replace_tags(tree.gold)
                        fout.write(args['predict_format'].format(output_tree))
                        fout.write("\n")

            with open(orig_file, 'w') as fout:
                for tree in full_results:
                    fout.write(args['predict_format'].format(tree.gold))
                    fout.write("\n")

    if len(full_results) == 0:
        # match the arity of the normal return path below
        return 0.0, 0.0, 0.0
    if evaluator is None:
        if num_generate > 0:
            kbest = max(len(fr.predictions) for fr in full_results)
        else:
            kbest = None
        with EvaluateParser(kbest=kbest) as evaluator:
            response = evaluator.process(full_results)
    else:
        response = evaluator.process(full_results)

    kbestF1 = response.kbestF1 if response.HasField("kbestF1") else None
    return response.f1, kbestF1, response.treeF1
stanza/stanza/models/constituency/partitioned_transformer.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformer with partitioned content and position features.
3
+
4
+ See section 3 of https://arxiv.org/pdf/1805.01052.pdf
5
+ """
6
+
7
+ import copy
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
15
+
16
class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
    """Autograd function implementing feature-level dropout.

    A single Bernoulli mask of shape (batch, features) is drawn and then
    broadcast across the middle dimension, so a dropped feature is zeroed
    for every position of a batch element rather than independently per
    position.  The indexing `[:, None, :]` implies a 3D input
    (batch, length, features) — TODO confirm against callers.
    """

    @staticmethod
    def forward(ctx, input, p=0.5, train=False, inplace=False):
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )

        # stash settings on the context so backward() can reuse them
        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        # only mask when actually training with a nonzero rate;
        # otherwise the input passes through unchanged
        if ctx.p > 0 and ctx.train:
            # mask covers (batch, features) only — no length dimension
            ctx.noise = torch.empty(
                (input.size(0), input.size(-1)),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            if ctx.p == 1:
                ctx.noise.fill_(0)
            else:
                # scale kept features by 1/(1-p) so the expectation is unchanged
                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            # insert a singleton middle axis so the mask broadcasts over length
            ctx.noise = ctx.noise[:, None, :]
            output.mul_(ctx.noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # gradient only flows through features kept in forward();
        # the three trailing Nones correspond to p, train, inplace
        if ctx.p > 0 and ctx.train:
            return grad_output.mul(ctx.noise), None, None, None
        else:
            return grad_output, None, None, None
56
+
57
+
58
class FeatureDropout(nn.Dropout):
    """
    Feature-level dropout: takes an input of size len x num_features and drops
    each feature with probabibility p. A feature is dropped across the full
    portion of the input that corresponds to a single batch element.

    Accepts either a single tensor or a (content, position) tuple, in which
    case each partition is masked independently.
    """

    def forward(self, x):
        def drop(tensor):
            return FeatureDropoutFunction.apply(tensor, self.p, self.training, self.inplace)

        if isinstance(x, tuple):
            content, position = x
            return drop(content), drop(position)
        return drop(x)
73
+
74
+
75
# TODO: this module apparently is not treated the same the built-in
# nonlinearity modules, as multiple uses of the same relu on different
# tensors winds up mixing the gradients See if there is a way to
# resolve that other than creating a new nonlinearity for each layer
class PartitionedReLU(nn.ReLU):
    """ReLU over partitioned (content, position) features.

    Accepts either a (content, position) tuple or a single tensor, which
    is split in half along the last dimension.  Returns a tuple with the
    activation applied to each half.
    """
    def forward(self, x):
        content, position = x if isinstance(x, tuple) else torch.chunk(x, 2, dim=-1)
        return super().forward(content), super().forward(position)
86
+
87
+
88
class PartitionedLinear(nn.Module):
    """Two parallel linear maps, one per partition.

    Both the input and output dimensions are split in half: the content
    half goes through linear_c, the position half through linear_p.  A
    single-tensor input is chunked in two along the last dimension; the
    output is always a (content, position) tuple.
    """
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        half_in = in_features // 2
        half_out = out_features // 2
        self.linear_c = nn.Linear(half_in, half_out, bias)
        self.linear_p = nn.Linear(half_in, half_out, bias)

    def forward(self, x):
        content, position = x if isinstance(x, tuple) else torch.chunk(x, 2, dim=-1)
        return self.linear_c(content), self.linear_p(position)
103
+
104
+
105
class PartitionedMultiHeadAttention(nn.Module):
    """Multi-head self-attention with separate content/position projections.

    Each partition (content and position, each d_model/2 wide) has its own
    QKV and output projections; the two halves are concatenated before the
    attention dot product, so attention weights mix both kinds of features.
    Returns a (content, position) tuple.
    """
    def __init__(
        self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02
    ):
        super().__init__()

        # per-partition QKV projections: (head, d_model/2, {q,k,v}, d_qkv/2)
        self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
        # per-partition output projections: (head, d_qkv/2, d_model/2)
        self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))
        self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))

        # uniform init with bound sqrt(3) * range (matches a std of ~range)
        bound = math.sqrt(3.0) * initializer_range
        for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]:
            nn.init.uniform_(param, -bound, bound)
        # standard scaled dot-product attention factor, 1/sqrt(d_qkv)
        self.scaling_factor = 1 / d_qkv ** 0.5

        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, x, mask=None):
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            x_c, x_p = torch.chunk(x, 2, dim=-1)
        # project each partition to per-head q/k/v
        # (b=batch, t=time, f=features, h=head, c indexes {q,k,v}, a=attn dim)
        qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c)
        qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p)
        q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)]
        q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)]
        # join the halves so attention scores see both partitions at once
        q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor
        k = torch.cat([k_c, k_p], dim=-1)
        v = torch.cat([v_c, v_p], dim=-1)
        dots = torch.einsum("bhqa,bhka->bhqk", q, k)
        if mask is not None:
            # mask is True at valid tokens; padding keys get -inf pre-softmax
            dots.data.masked_fill_(~mask[:, None, None, :], -float("inf"))
        probs = F.softmax(dots, dim=-1)
        probs = self.dropout(probs)
        o = torch.einsum("bhqk,bhka->bhqa", probs, v)
        # split the attention output back into partitions and project each
        o_c, o_p = torch.chunk(o, 2, dim=-1)
        out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c)
        out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p)
        return out_c, out_p
145
+
146
+
147
class PartitionedTransformerEncoderLayer(nn.Module):
    """One transformer layer over partitioned (content, position) features.

    Self-attention followed by a feed-forward block, each with feature
    dropout, a residual connection, and a post-LayerNorm.  The partitioned
    sublayers return (content, position) tuples, which are concatenated
    back to a single d_model-wide tensor before each residual add.
    """
    def __init__(self,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 activation=None,
    ):
        super().__init__()
        self.self_attn = PartitionedMultiHeadAttention(
            d_model, n_head, d_qkv, attention_dropout=attention_dropout
        )
        self.linear1 = PartitionedLinear(d_model, d_ff)
        self.ff_dropout = FeatureDropout(ff_dropout)
        self.linear2 = PartitionedLinear(d_ff, d_model)

        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ff = nn.LayerNorm(d_model)
        self.residual_dropout_attn = FeatureDropout(residual_dropout)
        self.residual_dropout_ff = FeatureDropout(residual_dropout)

        # BUGFIX: the default used to be `activation=PartitionedReLU()`, a
        # module instance created once at class-definition time and therefore
        # shared by every layer relying on the default (classic mutable
        # default argument).  The TODO above PartitionedReLU in this file
        # notes that reusing one nonlinearity across layers mixes gradients,
        # so build a fresh instance per layer.  PartitionedReLU is stateless,
        # so the computed values are unchanged.
        self.activation = activation if activation is not None else PartitionedReLU()

    def forward(self, x, mask=None):
        # attention sublayer: attend, rejoin partitions, dropout, add & norm
        residual = self.self_attn(x, mask=mask)
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_attn(residual)
        x = self.norm_attn(x + residual)
        # feed-forward sublayer: expand, activate, dropout, contract
        residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x))))
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_ff(residual)
        x = self.norm_ff(x + residual)
        return x
183
+
184
+
185
class PartitionedTransformerEncoder(nn.Module):
    """A stack of n_layers PartitionedTransformerEncoderLayer modules.

    `activation` is a factory (class or zero-arg callable), not an
    instance: a fresh nonlinearity is constructed for every layer.
    """
    def __init__(self,
                 n_layers,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 activation=PartitionedReLU,
    ):
        super().__init__()

        def build_layer():
            # one encoder layer with its own activation instance
            return PartitionedTransformerEncoderLayer(d_model=d_model,
                                                      n_head=n_head,
                                                      d_qkv=d_qkv,
                                                      d_ff=d_ff,
                                                      ff_dropout=ff_dropout,
                                                      residual_dropout=residual_dropout,
                                                      attention_dropout=attention_dropout,
                                                      activation=activation())

        self.layers = nn.ModuleList(build_layer() for _ in range(n_layers))

    def forward(self, x, mask=None):
        out = x
        for encoder_layer in self.layers:
            out = encoder_layer(out, mask=mask)
        return out
212
+
213
+
214
class ConcatPositionalEncoding(nn.Module):
    """
    Learns a position embedding and concatenates it to the input.

    Output has shape (batch, seq, features + d_model): the first
    `features` channels are the input unchanged, the remainder are the
    learned per-position embeddings shared across the batch.
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model))
        nn.init.normal_(self.timing_table)

    def forward(self, x):
        seq_len = x.shape[1]
        # same position embeddings for every batch element
        positions = self.timing_table[:seq_len, :].expand(x.shape[0], -1, -1)
        return torch.cat([x, positions], dim=-1)
228
+
229
+ #
230
class PartitionedTransformerModule(nn.Module):
    """Projects pretrained word embeddings and runs them through a
    partitioned transformer encoder (content features in one half,
    positional timing features in the other).
    """
    def __init__(self,
                 n_layers,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 word_input_size,
                 bias,
                 morpho_emb_dropout,
                 timing,
                 encoder_max_len,
                 activation=PartitionedReLU()
                 # NOTE(review): `activation` is accepted but never used below
                 # (the encoder is built with its own default); kept for
                 # interface compatibility
    ):
        super().__init__()
        # content half: projection of the pretrained embeddings
        self.project_pretrained = nn.Linear(
            word_input_size, d_model // 2, bias=bias
        )

        self.pattention_morpho_emb_dropout = FeatureDropout(morpho_emb_dropout)
        # position half: sinusoidal or learned timing, concatenated on
        if timing == 'sin':
            self.add_timing = ConcatSinusoidalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
        elif timing == 'learned':
            self.add_timing = ConcatPositionalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
        else:
            raise ValueError("Unhandled timing type: %s" % timing)
        self.transformer_input_norm = nn.LayerNorm(d_model)
        self.pattn_encoder = PartitionedTransformerEncoder(
            n_layers,
            d_model=d_model,
            n_head=n_head,
            d_qkv=d_qkv,
            d_ff=d_ff,
            ff_dropout=ff_dropout,
            residual_dropout=residual_dropout,
            attention_dropout=attention_dropout,
        )

    def forward(self, attention_mask, bert_embeddings):
        # Prepares attention mask for feeding into the self-attention
        device = bert_embeddings[0].device
        # BUGFIX: this used to read `if attention_mask:`, which applies
        # Python truthiness to the argument; a multi-element mask tensor
        # raises RuntimeError on bool().  Testing against None keeps the
        # existing behavior for the None case and makes tensor masks work.
        if attention_mask is not None:
            valid_token_mask = attention_mask
        else:
            # build a mask from the per-sentence lengths: pad with a
            # sentinel, then mark everything that isn't the sentinel valid
            valids = []
            for sent in bert_embeddings:
                valids.append(torch.ones(len(sent), device=device))

            padded_data = torch.nn.utils.rnn.pad_sequence(
                valids,
                batch_first=True,
                padding_value=-100
            )

            valid_token_mask = padded_data != -100

        valid_token_mask = valid_token_mask.to(device=device)
        padded_embeddings = torch.nn.utils.rnn.pad_sequence(
            bert_embeddings,
            batch_first=True,
            padding_value=0
        )

        # Project the pretrained embedding onto the desired dimension
        extra_content_annotations = self.project_pretrained(padded_embeddings)

        # Add positional information through the table
        encoder_in = self.add_timing(self.pattention_morpho_emb_dropout(extra_content_annotations))
        encoder_in = self.transformer_input_norm(encoder_in)
        # Put the partitioned input through the partitioned attention
        annotations = self.pattn_encoder(encoder_in, valid_token_mask)

        return annotations
308
+
stanza/stanza/models/coref/__init__.py ADDED
File without changes
stanza/stanza/models/coref/anaphoricity_scorer.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes AnaphicityScorer, a torch module that for a matrix of
2
+ mentions produces their anaphoricity scores.
3
+ """
4
+ import torch
5
+
6
+ from stanza.models.coref import utils
7
+ from stanza.models.coref.config import Config
8
+
9
+
10
class AnaphoricityScorer(torch.nn.Module):
    """ Calculates anaphoricity scores by passing the inputs into a FFNN """
    def __init__(self,
                 in_features: int,
                 config: Config):
        """
        Args:
            in_features: width of one pairwise row (mention + antecedent
                + elementwise product + pairwise features)
            config: provides hidden_size, n_hidden_layers, dropout_rate,
                singletons, rough_k
        """
        super().__init__()
        hidden_size = config.hidden_size
        # with zero hidden layers self.hidden is an empty (identity)
        # Sequential, so the output layer must consume the raw inputs
        if not config.n_hidden_layers:
            hidden_size = in_features
        layers = []
        for i in range(config.n_hidden_layers):
            # first layer maps in_features -> hidden; the rest hidden -> hidden
            layers.extend([torch.nn.Linear(hidden_size if i else in_features,
                                           hidden_size),
                           torch.nn.LeakyReLU(),
                           torch.nn.Dropout(config.dropout_rate)])
        self.hidden = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(hidden_size, out_features=1)

        # are we going to predict singletons
        self.predict_singletons = config.singletons

        if self.predict_singletons:
            # map to whether or not this is a start of a coref given all the
            # antecedents; not used when config.singletons = False because
            # we only need to know this for predicting singletons
            self.start_map = torch.nn.Linear(config.rough_k, out_features=1, bias=False)


    def forward(self, *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                top_mentions: torch.Tensor,
                mentions_batch: torch.Tensor,
                pw_batch: torch.Tensor,
                top_rough_scores_batch: torch.Tensor,
                ) -> torch.Tensor:
        """ Builds a pairwise matrix, scores the pairs and returns the scores.

        Args:
            all_mentions (torch.Tensor): [n_mentions, mention_emb]
            mentions_batch (torch.Tensor): [batch_size, mention_emb]
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
            top_indices_batch (torch.Tensor): [batch_size, n_ants]
            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]

        Returns:
            torch.Tensor [batch_size, n_ants + 1]
                anaphoricity scores for the pairs + a dummy column
                (when predicting singletons an extra leading "start of
                coref" column is prepended, giving n_ants + 2 columns)
        """
        # [batch_size, n_ants, pair_emb]
        pair_matrix = self._get_pair_matrix(mentions_batch, pw_batch, top_mentions)

        # [batch_size, n_ants] vs [batch_size, 1]
        # first is coref scores, the second is whether its the start of a coref
        if self.predict_singletons:
            scores, start = self._ffnn(pair_matrix)
            # rough scores are added back in, then a dummy "no antecedent"
            # column is prepended by add_dummy
            scores = utils.add_dummy(scores+top_rough_scores_batch, eps=True)

            return torch.cat([start, scores], dim=1)
        else:
            scores = self._ffnn(pair_matrix)
            return utils.add_dummy(scores+top_rough_scores_batch, eps=True)

    def _ffnn(self, x: torch.Tensor) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
        """
        Calculates anaphoricity scores.

        Args:
            x: tensor of shape [batch_size, n_ants, n_features]

        Returns:
            tensor of shape [batch_size, n_ants], or, when predicting
            singletons, a tuple of (scores [batch_size, n_ants],
            start [batch_size, 1])
        """
        x = self.out(self.hidden(x))
        x = x.squeeze(2)

        if not self.predict_singletons:
            return x

        # because sometimes we only have the first 49 anaphoricities
        # (slice the start_map weights to however many antecedents exist)
        start = x @ self.start_map.weight[:,:x.shape[1]].T
        return x, start

    @staticmethod
    def _get_pair_matrix(mentions_batch: torch.Tensor,
                         pw_batch: torch.Tensor,
                         top_mentions: torch.Tensor) -> torch.Tensor:
        """
        Builds the matrix used as input for AnaphoricityScorer.

        Args:
            all_mentions (torch.Tensor): [n_mentions, mention_emb],
                all the valid mentions of the document,
                can be on a different device
            mentions_batch (torch.Tensor): [batch_size, mention_emb],
                the mentions of the current batch,
                is expected to be on the current device
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
                pairwise features of the current batch,
                is expected to be on the current device
            top_indices_batch (torch.Tensor): [batch_size, n_ants],
                indices of antecedents of each mention

        Returns:
            torch.Tensor: [batch_size, n_ants, pair_emb]
        """
        emb_size = mentions_batch.shape[1]
        n_ants = pw_batch.shape[1]

        # broadcast each mention against all of its candidate antecedents;
        # the elementwise product acts as a cheap similarity feature
        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
        b_mentions = top_mentions
        similarity = a_mentions * b_mentions

        out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
        return out
stanza/stanza/models/coref/cluster_checker.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes ClusterChecker, a class used to retrieve LEA scores.
2
+ See aclweb.org/anthology/P16-1060.pdf. """
3
+
4
+ from typing import Hashable, List, Tuple
5
+
6
+ from stanza.models.coref.const import EPSILON
7
+ import numpy as np
8
+
9
+ import math
10
+ import logging
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
+
15
class ClusterChecker:
    """ Collects information on gold and predicted clusters across documents.
    Can be used to retrieve weighted LEA-score for them, as well as the
    macro-averaged MUC, B-cubed and CEAF-e scores over documents.
    """
    def __init__(self):
        # LEA numerators and their weights (cluster sizes); the totals are
        # normalized by the accumulated weightings in total_lea
        self._lea_precision = 0.0
        self._lea_recall = 0.0
        self._lea_precision_weighting = 0.0
        self._lea_recall_weighting = 0.0
        # number of documents scored so far; used to macro-average mbc
        self._num_preds = 0.0

        # per-document muc scores, summed over documents
        self._muc_precision = 0.0
        self._muc_recall = 0.0

        # per-document b3 scores, summed over documents
        self._b3_precision = 0.0
        self._b3_recall = 0.0

        # per-document ceafe scores, summed over documents
        self._ceafe_precision = 0.0
        self._ceafe_recall = 0.0

    @staticmethod
    def _f1(p, r):
        """ Harmonic mean of precision and recall; EPSILON guards the 0/0 case """
        return (p * r) / (p + r + EPSILON) * 2

    def add_predictions(self,
                        gold_clusters: List[List[Hashable]],
                        pred_clusters: List[List[Hashable]]):
        """
        Calculates LEA for the document's clusters and stores them to later
        output weighted LEA across documents.  Also accumulates the
        per-document MUC, B-cubed and CEAF-e scores.

        Returns:
            LEA score for the document as a tuple of (f1, precision, recall)
        """
        self._num_preds += 1

        # recall scores gold against predictions; precision swaps the arguments
        recall, r_weight = ClusterChecker._lea(gold_clusters, pred_clusters)
        precision, p_weight = ClusterChecker._lea(pred_clusters, gold_clusters)

        self._muc_recall += ClusterChecker._muc(gold_clusters, pred_clusters)
        self._muc_precision += ClusterChecker._muc(pred_clusters, gold_clusters)

        self._b3_recall += ClusterChecker._b3(gold_clusters, pred_clusters)
        self._b3_precision += ClusterChecker._b3(pred_clusters, gold_clusters)

        ceafe_precision, ceafe_recall = ClusterChecker._ceafe(pred_clusters, gold_clusters)
        if math.isnan(ceafe_precision) and len(gold_clusters) > 0:
            # because our model predicted no clusters
            ceafe_precision = 0.0

        self._ceafe_precision += ceafe_precision
        self._ceafe_recall += ceafe_recall

        self._lea_recall += recall
        self._lea_recall_weighting += r_weight
        self._lea_precision += precision
        self._lea_precision_weighting += p_weight

        doc_precision = precision / (p_weight + EPSILON)
        doc_recall = recall / (r_weight + EPSILON)
        doc_f1 = (doc_precision * doc_recall) \
            / (doc_precision + doc_recall + EPSILON) * 2
        return doc_f1, doc_precision, doc_recall

    @property
    def bakeoff(self):
        """ Get the F1 macroaverage score used by the bakeoff """
        return sum(self.mbc) / 3

    @property
    def mbc(self):
        """ Get the F1 average score of (muc, b3, ceafe) over docs """
        avg_precisions = [self._muc_precision, self._b3_precision, self._ceafe_precision]
        avg_precisions = [i / (self._num_preds + EPSILON) for i in avg_precisions]

        avg_recalls = [self._muc_recall, self._b3_recall, self._ceafe_recall]
        avg_recalls = [i / (self._num_preds + EPSILON) for i in avg_recalls]

        avg_f1s = [self._f1(p, r) for p, r in zip(avg_precisions, avg_recalls)]

        return avg_f1s

    @property
    def total_lea(self):
        """ Returns weighted LEA for all the documents as
        (f1, precision, recall) """
        precision = self._lea_precision / (self._lea_precision_weighting + EPSILON)
        recall = self._lea_recall / (self._lea_recall_weighting + EPSILON)
        f1 = self._f1(precision, recall)
        return f1, precision, recall

    @staticmethod
    def _lea(key: List[List[Hashable]],
             response: List[List[Hashable]]) -> Tuple[float, float]:
        """ See aclweb.org/anthology/P16-1060.pdf.

        Returns (weighted resolution score, total weight); the caller divides
        one by the other to obtain the LEA precision or recall. """
        response_clusters = [set(cluster) for cluster in response]
        response_map = {mention: cluster
                        for cluster in response_clusters
                        for mention in cluster}
        importances = []
        resolutions = []
        for entity in key:
            size = len(entity)
            if size == 1:  # entities of size 1 are not annotated
                continue
            importances.append(size)
            # count the links (mention pairs) of the entity that the response
            # also places in one cluster
            correct_links = 0
            for i in range(size):
                for j in range(i + 1, size):
                    correct_links += int(entity[i]
                                         in response_map.get(entity[j], {}))
            # normalize by the total number of links in the entity
            resolutions.append(correct_links / (size * (size - 1) / 2))
        res = sum(imp * res for imp, res in zip(importances, resolutions))
        weight = sum(importances)
        return res, weight

    @staticmethod
    def _muc(key: List[List[Hashable]],
             response: List[List[Hashable]]) -> float:
        """ See aclweb.org/anthology/P16-1060.pdf. """

        response_clusters = [set(cluster) for cluster in response]
        response_map = {mention: cluster
                        for cluster in response_clusters
                        for mention in cluster}

        top = 0     # sum over k of |k_i| - response_partitions(|k_i|)
        bottom = 0  # sum over k of |k_i| - 1

        for entity in key:
            S = len(entity)
            # we need to figure the number of DIFFERENT clusters
            # the response assigns to members of the entity; ideally
            # this number is 1 (i.e. they are all assigned the same
            # coref).
            response_clusters = [response_map.get(i, None) for i in entity]
            # and deduplicate; note that every unassigned mention (None) is
            # deliberately kept, counting as its own one-mention partition
            deduped = []
            for i in response_clusters:
                if i is None:
                    deduped.append(i)
                elif i not in deduped:
                    deduped.append(i)
            # the "partitions" will then be size of the deduped list
            p_k = len(deduped)
            top += (S - p_k)
            bottom += (S - 1)

        try:
            return top / bottom
        except ZeroDivisionError:
            logger.warning("muc got a zero division error because the model predicted no spans!")
            return 0  # +inf technically

    @staticmethod
    def _b3(key: List[List[Hashable]],
            response: List[List[Hashable]]) -> float:
        """ See aclweb.org/anthology/P16-1060.pdf. """

        response_clusters = [set(cluster) for cluster in response]

        top = 0     # sum over key and response of (|k intersect response|^2/|k|)
        bottom = 0  # sum over k of |k_i|

        for entity in key:
            bottom += len(entity)
            entity = set(entity)

            for res_entity in response_clusters:
                top += (len(entity.intersection(res_entity))**2) / len(entity)

        try:
            return top / bottom
        except ZeroDivisionError:
            logger.warning("b3 got a zero division error because the model predicted no spans!")
            return 0  # +inf technically

    @staticmethod
    def _phi4(c1, c2):
        """ CEAF-e similarity between two clusters: twice the overlap size
        divided by the total size. """
        return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))

    @staticmethod
    def _ceafe(clusters: List[List[Hashable]], gold_clusters: List[List[Hashable]]):
        """ see https://github.com/ufal/corefud-scorer/blob/main/coval/eval/evaluator.py

        Returns (precision, recall).  Either is 0 when the corresponding side
        has no clusters at all: `similarity` below is a numpy float, so a
        division by zero would silently produce NaN (not ZeroDivisionError)
        and the NaN would corrupt the accumulated ceafe sums.
        """
        try:
            from scipy.optimize import linear_sum_assignment
        except ImportError:
            raise ImportError("To perform CEAF scoring, please install scipy via `pip install scipy` for the Kuhn-Munkres linear assignment scheme.")

        scores = np.zeros((len(gold_clusters), len(clusters)))
        for i in range(len(gold_clusters)):
            for j in range(len(clusters)):
                scores[i, j] = ClusterChecker._phi4(gold_clusters[i], clusters[j])
        # maximize the total similarity over one-to-one cluster alignments
        row_ind, col_ind = linear_sum_assignment(-scores)
        similarity = scores[row_ind, col_ind].sum()

        # precision, recall -- guard both divisions against empty inputs
        if clusters:
            prec = similarity / len(clusters)
        else:
            logger.warning("ceafe got a zero division error because the model predicted no spans!")
            prec = 0
        recc = similarity / len(gold_clusters) if gold_clusters else 0.0
        return prec, recc
230
+
stanza/stanza/models/coref/conll.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Contains functions to produce conll-formatted output files with
2
+ predicted spans and their clustering """
3
+
4
+ from collections import defaultdict
5
+ from contextlib import contextmanager
6
+ import os
7
+ from typing import List, TextIO
8
+
9
+ from stanza.models.coref.config import Config
10
+ from stanza.models.coref.const import Doc, Span
11
+
12
+
13
# pylint: disable=too-many-locals
def write_conll(doc: Doc,
                clusters: List[List[Span]],
                heads: List[List[int]],
                f_obj: TextIO):
    """ Writes span/cluster information to f_obj, which is assumed to be a file
    object open for writing

    The output is CoNLL-U-style: a "# newdoc" header, a "# sent_id" header per
    sentence, then one tab-separated line per word whose final column carries
    the Entity annotation in the "eid-head" format declared in the header.

    Args (annotations reflect actual usage below):
      doc:      document dict; reads "document_id", "cased_words", "part_id"
                and "sent_id" (per-word sentence ids)
      clusters: for each cluster, its list of (start, end) word spans
      heads:    for each cluster, head-word offsets parallel to its spans
      f_obj:    writable text file object
    """
    # seven "_" filler columns, tab-separated
    placeholder = list("\t_" * 7)
    # the nth token needs to be a number
    # (index 9 is the 5th filler column -- presumably HEAD, set to "0")
    placeholder[9] = "0"
    placeholder = "".join(placeholder)
    # sanitize the document id so it is safe to embed in newdoc/sent ids
    doc_id = doc["document_id"].replace("-", "_").replace("/", "_").replace(".","_")
    words = doc["cased_words"]
    part_id = doc["part_id"]    # NOTE(review): read but never used below
    sents = doc["sent_id"]

    max_word_len = max(len(w) for w in words)   # NOTE(review): computed but never used below

    # word index -> list of (position-in-cluster, cluster id) pairs for spans
    # that open / close / both open and close (single word) at that word
    starts = defaultdict(lambda: [])
    ends = defaultdict(lambda: [])
    single_word = defaultdict(lambda: [])

    for cluster_id, cluster in enumerate(clusters):
        if len(heads[cluster_id]) != len(cluster):
            # TODO debug this fact and why it occurs
            # print(f"cluster {cluster_id} doesn't have the same number of elements for word and span levels, skipping...")
            continue
        for cluster_part, (start, end) in enumerate(cluster):
            if end - start == 1:
                single_word[start].append((cluster_part, cluster_id))
            else:
                starts[start].append((cluster_part, cluster_id))
                # `end` is exclusive, so the span closes at word end - 1
                ends[end - 1].append((cluster_part, cluster_id))

    f_obj.write(f"# newdoc id = {doc_id}\n# global.Entity = eid-head\n")

    word_number = 0     # NOTE(review): maintained but never written; word_id is emitted instead
    sent_id = 0
    for word_id, word in enumerate(words):

        # entity markers for this word: "(eN-H" opens cluster N with head
        # offset H, "(eN-H)" is a single-word span, "eN)" closes cluster N
        cluster_info_lst = []
        for part, cluster_marker in starts[word_id]:
            start, end = clusters[cluster_marker][part]
            # clamp the head offset so it never points past the span
            cluster_info_lst.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)}")
        for part, cluster_marker in single_word[word_id]:
            start, end = clusters[cluster_marker][part]
            cluster_info_lst.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)})")
        for part, cluster_marker in ends[word_id]:
            cluster_info_lst.append(f"e{cluster_marker})")


        # we need our clusters to be ordered such that the one that is closest the first change
        # is listed last in the chains
        def compare_sort(x):
            split = x.split("-")
            if len(split) > 1:
                # opener/single-word markers sort by their head offset
                return int(split[-1].replace(")", "").strip())
            else:
                # we want everything that's a closer to be first
                return float("inf")

        cluster_info_lst = sorted(cluster_info_lst, key=compare_sort, reverse=True)
        cluster_info = "".join(cluster_info_lst) if cluster_info_lst else "_"

        # emit a sentence header whenever the per-word sentence id changes
        if word_id == 0 or sents[word_id] != sents[word_id - 1]:
            f_obj.write(f"# sent_id = {doc_id}-{sent_id}\n")
            word_number = 0
            sent_id += 1

        if cluster_info != "_":
            cluster_info = f"Entity={cluster_info}"

        # NOTE(review): writes the document-wide word_id, not the per-sentence
        # word_number -- confirm whether per-sentence numbering was intended
        f_obj.write(f"{word_id}\t{word}{placeholder}\t{cluster_info}\n")

        word_number += 1

    f_obj.write("\n")
90
+
91
+
92
@contextmanager
def open_(config: Config, epochs: int, data_split: str):
    """ Opens the gold and pred conll log files for writing in a safe way,
    creating the log directory if necessary, and yields both file objects. """
    out_dir = config.conll_log_dir
    os.makedirs(out_dir, exist_ok=True)

    prefix = f"{config.section}_{data_split}_e{epochs}"
    gold_path = os.path.join(out_dir, f"{prefix}.gold.conll")  # type: ignore
    pred_path = os.path.join(out_dir, f"{prefix}.pred.conll")  # type: ignore

    with open(gold_path, mode="w", encoding="utf8") as gold_f, \
            open(pred_path, mode="w", encoding="utf8") as pred_f:
        yield (gold_f, pred_f)
stanza/stanza/models/coref/const.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Contains type aliases for coref module """
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Tuple
5
+
6
+ import torch
7
+
8
+
9
+ EPSILON = 1e-7
10
+ LARGE_VALUE = 1000 # used instead of inf due to bug #16762 in pytorch
11
+
12
+ Doc = Dict[str, Any]
13
+ Span = Tuple[int, int]
14
+
15
+
16
@dataclass
class CorefResult:
    """ Container for the outputs of a coref model run on one document.

    All fields default to None and are filled in incrementally; the inline
    comments give the expected tensor shapes.
    """
    coref_scores: torch.Tensor = None               # [n_words, k + 1]
    coref_y: torch.Tensor = None                    # [n_words, k + 1]
    rough_y: torch.Tensor = None                    # [n_words, n_words]

    # clusters of word indices / of (start, end) spans
    word_clusters: List[List[int]] = None
    span_clusters: List[List[Span]] = None

    rough_scores: torch.Tensor = None               # [n_words, n_words]
    span_scores: torch.Tensor = None                # [n_heads, n_words, 2]
    span_y: Tuple[torch.Tensor, torch.Tensor] = None  # [n_heads] x2
stanza/stanza/models/coref/coref_chain.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Coref chain suitable for attaching to a Document after coref processing
3
+ """
4
+
5
+ # by not using namedtuple, we can use this object as output from the json module
6
+ # in the doc class as long as we wrap the encoder to print these out in dict() form
7
+ # CorefMention = namedtuple('CorefMention', ['sentence', 'start_word', 'end_word'])
8
class CorefMention:
    """ A single mention: a (start_word, end_word) range of words in one
    sentence of the document. """
    def __init__(self, sentence, start_word, end_word):
        self.sentence = sentence      # sentence the mention belongs to
        self.start_word = start_word  # first word of the mention
        self.end_word = end_word      # end of the mention -- presumably exclusive, TODO confirm

    def __repr__(self):
        # a repr makes mentions and chains much easier to inspect when debugging
        return (f"{type(self).__name__}(sentence={self.sentence!r}, "
                f"start_word={self.start_word!r}, end_word={self.end_word!r})")
13
+
14
class CorefChain:
    """ A coreference chain: the mentions that refer to one entity, plus
    which of those mentions is the representative one. """
    def __init__(self, index, mentions, representative_text, representative_index):
        self.index = index                                # id of this chain in the document
        self.mentions = mentions                          # list of mentions in the chain
        self.representative_text = representative_text    # text of the representative mention
        self.representative_index = representative_index  # index of the representative mention -- presumably into `mentions`, TODO confirm

    def __repr__(self):
        # a repr makes chains much easier to inspect when debugging
        return (f"{type(self).__name__}(index={self.index!r}, "
                f"representative_text={self.representative_text!r}, "
                f"mentions={self.mentions!r})")
20
+
21
class CorefAttachment:
    """ Attaches one token to a coref chain, recording whether the token
    starts the mention, ends it, and/or belongs to the representative one. """
    def __init__(self, chain, is_start, is_end, is_representative):
        self.chain = chain
        self.is_start = is_start
        self.is_end = is_end
        self.is_representative = is_representative

    def to_json(self):
        """ Serialize to a dict: chain index and representative text always,
        boolean flags only when they are true (keeps the json compact). """
        result = {
            "index": self.chain.index,
            "representative_text": self.chain.representative_text,
        }
        for flag in ("is_start", "is_end", "is_representative"):
            if getattr(self, flag):
                result[flag] = True
        return result
stanza/stanza/models/coref/loss.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Describes the loss function used to train the model, which is a weighted
2
+ sum of NLML and BCE losses. """
3
+
4
+ import torch
5
+
6
+
7
class CorefLoss(torch.nn.Module):
    """ Weighted sum of NLML and BCE losses.

    See the rationale for using NLML in Lee et al. 2017
    https://www.aclweb.org/anthology/D17-1018/
    The added weighted summand of BCE helps the model learn even after
    converging on the NLML task. """

    def __init__(self, bce_weight: float):
        # BCE enters as a weighted summand, so its weight must lie in [0, 1]
        assert 0 <= bce_weight <= 1
        super().__init__()
        self._bce_module = torch.nn.BCEWithLogitsLoss()
        self._bce_weight = bce_weight

    def forward(self,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
                input_: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """ Returns a weighted sum of two losses as a torch.Tensor """
        nlml_term = self._nlml(input_, target)
        bce_term = self._bce(input_, target)
        return nlml_term + bce_term * self._bce_weight

    def _bce(self,
             input_: torch.Tensor,
             target: torch.Tensor) -> torch.Tensor:
        """ For numerical stability, clamps the input before passing it to BCE.
        """
        clamped = torch.clamp(input_, min=-50, max=50)
        return self._bce_module(clamped, target)

    @staticmethod
    def _nlml(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """ Negative log marginal likelihood of the gold antecedents.
        log(target) masks (to -inf) the scores of non-gold antecedents. """
        gold_mass = torch.logsumexp(input_ + torch.log(target), dim=1)
        total_mass = torch.logsumexp(input_, dim=1)
        return (total_mass - gold_mass).mean()
stanza/stanza/models/coref/model.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ see __init__.py """
2
+
3
+ from datetime import datetime
4
+ import dataclasses
5
+ import json
6
+ import logging
7
+ import os
8
+ import random
9
+ import re
10
+ from typing import Any, Dict, List, Optional, Set, Tuple
11
+
12
+ import numpy as np # type: ignore
13
+ try:
14
+ import tomllib
15
+ except ImportError:
16
+ import tomli as tomllib
17
+ import torch
18
+ import transformers # type: ignore
19
+
20
+ from pickle import UnpicklingError
21
+ import warnings
22
+
23
+ from stanza.utils.get_tqdm import get_tqdm # type: ignore
24
+ tqdm = get_tqdm()
25
+
26
+ from stanza.models.coref import bert, conll, utils
27
+ from stanza.models.coref.anaphoricity_scorer import AnaphoricityScorer
28
+ from stanza.models.coref.cluster_checker import ClusterChecker
29
+ from stanza.models.coref.config import Config
30
+ from stanza.models.coref.const import CorefResult, Doc
31
+ from stanza.models.coref.loss import CorefLoss
32
+ from stanza.models.coref.pairwise_encoder import PairwiseEncoder
33
+ from stanza.models.coref.rough_scorer import RoughScorer
34
+ from stanza.models.coref.span_predictor import SpanPredictor
35
+ from stanza.models.coref.utils import GraphNode
36
+ from stanza.models.coref.word_encoder import WordEncoder
37
+ from stanza.models.coref.dataset import CorefDataset
38
+ from stanza.models.coref.tokenizer_customization import *
39
+
40
+ from stanza.models.common.bert_embedding import load_tokenizer
41
+ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
42
+ from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
43
+
44
+ logger = logging.getLogger('stanza')
45
+
46
+ class CorefModel: # pylint: disable=too-many-instance-attributes
47
+ """Combines all coref modules together to find coreferent spans.
48
+
49
+ Attributes:
50
+ config (coref.config.Config): the model's configuration,
51
+ see config.toml for the details
52
+ epochs_trained (int): number of epochs the model has been trained for
53
+ trainable (Dict[str, torch.nn.Module]): trainable submodules with their
54
+ names used as keys
55
+ training (bool): used to toggle train/eval modes
56
+
57
+ Submodules (in the order of their usage in the pipeline):
58
+ tokenizer (transformers.AutoTokenizer)
59
+ bert (transformers.AutoModel)
60
+ we (WordEncoder)
61
+ rough_scorer (RoughScorer)
62
+ pw (PairwiseEncoder)
63
+ a_scorer (AnaphoricityScorer)
64
+ sp (SpanPredictor)
65
+ """
66
+ def __init__(self,
67
+ epochs_trained: int = 0,
68
+ build_optimizers: bool = True,
69
+ config: Optional[dict] = None,
70
+ foundation_cache=None):
71
+ """
72
+ A newly created model is set to evaluation mode.
73
+
74
+ Args:
75
+ config_path (str): the path to the toml file with the configuration
76
+ section (str): the selected section of the config file
77
+ epochs_trained (int): the number of epochs finished
78
+ (useful for warm start)
79
+ """
80
+ if config is None:
81
+ raise ValueError("Cannot create a model without a config")
82
+ self.config = config
83
+ self.epochs_trained = epochs_trained
84
+ self._docs: Dict[str, List[Doc]] = {}
85
+ self._build_model(foundation_cache)
86
+
87
+ self.optimizers = {}
88
+ self.schedulers = {}
89
+
90
+ if build_optimizers:
91
+ self._build_optimizers()
92
+ self._set_training(False)
93
+
94
+ # final coreference resolution score
95
+ self._coref_criterion = CorefLoss(self.config.bce_loss_weight)
96
+ # score simply for the top-k choices out of the rough scorer
97
+ self._rough_criterion = CorefLoss(0)
98
+ # exact span matches
99
+ self._span_criterion = torch.nn.CrossEntropyLoss(reduction="sum")
100
+
101
+ @property
102
+ def training(self) -> bool:
103
+ """ Represents whether the model is in the training mode """
104
+ return self._training
105
+
106
+ @training.setter
107
+ def training(self, new_value: bool):
108
+ if self._training is new_value:
109
+ return
110
+ self._set_training(new_value)
111
+
112
+ # ========================================================== Public methods
113
+
114
+ @torch.no_grad()
115
+ def evaluate(self,
116
+ data_split: str = "dev",
117
+ word_level_conll: bool = False,
118
+ eval_lang: Optional[str] = None
119
+ ) -> Tuple[float, Tuple[float, float, float]]:
120
+ """ Evaluates the modes on the data split provided.
121
+
122
+ Args:
123
+ data_split (str): one of 'dev'/'test'/'train'
124
+ word_level_conll (bool): if True, outputs conll files on word-level
125
+ eval_lang (str): which language to evaluate
126
+
127
+ Returns:
128
+ mean loss
129
+ span-level LEA: f1, precision, recal
130
+ """
131
+ self.training = False
132
+ w_checker = ClusterChecker()
133
+ s_checker = ClusterChecker()
134
+ try:
135
+ data_split_data = f"{data_split}_data"
136
+ data_path = self.config.__dict__[data_split_data]
137
+ docs = self._get_docs(data_path)
138
+ except FileNotFoundError as e:
139
+ raise FileNotFoundError("Unable to find data split %s at file %s" % (data_split_data, data_path)) from e
140
+ running_loss = 0.0
141
+ s_correct = 0
142
+ s_total = 0
143
+
144
+ with conll.open_(self.config, self.epochs_trained, data_split) \
145
+ as (gold_f, pred_f):
146
+ pbar = tqdm(docs, unit="docs", ncols=0)
147
+ for doc in pbar:
148
+ if eval_lang and doc.get("lang", "") != eval_lang:
149
+ # skip that document, only used for ablation where we only
150
+ # want to test evaluation on one language
151
+ continue
152
+
153
+ res = self.run(doc)
154
+
155
+ if (res.coref_y.argmax(dim=1) == 1).all():
156
+ logger.warning(f"EVAL: skipping document with no corefs...")
157
+ continue
158
+
159
+ running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item()
160
+
161
+ if res.span_y:
162
+ pred_starts = res.span_scores[:, :, 0].argmax(dim=1)
163
+ pred_ends = res.span_scores[:, :, 1].argmax(dim=1)
164
+ s_correct += ((res.span_y[0] == pred_starts) * (res.span_y[1] == pred_ends)).sum().item()
165
+ s_total += len(pred_starts)
166
+
167
+
168
+ if word_level_conll:
169
+ raise NotImplementedError("We now write Conll-U conforming to UDCoref, which means that the span_clusters annotations will have headword info. word_level option is meaningless.")
170
+ else:
171
+ conll.write_conll(doc, doc["span_clusters"], doc["word_clusters"], gold_f)
172
+ conll.write_conll(doc, res.span_clusters, res.word_clusters, pred_f)
173
+
174
+ w_checker.add_predictions(doc["word_clusters"], res.word_clusters)
175
+ w_lea = w_checker.total_lea
176
+
177
+ s_checker.add_predictions(doc["span_clusters"], res.span_clusters)
178
+ s_lea = s_checker.total_lea
179
+
180
+ del res
181
+
182
+ pbar.set_description(
183
+ f"{data_split}:"
184
+ f" | WL: "
185
+ f" loss: {running_loss / (pbar.n + 1):<.5f},"
186
+ f" f1: {w_lea[0]:.5f},"
187
+ f" p: {w_lea[1]:.5f},"
188
+ f" r: {w_lea[2]:<.5f}"
189
+ f" | SL: "
190
+ f" sa: {s_correct / s_total:<.5f},"
191
+ f" f1: {s_lea[0]:.5f},"
192
+ f" p: {s_lea[1]:.5f},"
193
+ f" r: {s_lea[2]:<.5f}"
194
+ )
195
+ logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}")
196
+
197
+ return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff)
198
+
199
+ def load_weights(self,
200
+ path: Optional[str] = None,
201
+ ignore: Optional[Set[str]] = None,
202
+ map_location: Optional[str] = None,
203
+ noexception: bool = False) -> None:
204
+ """
205
+ Loads pretrained weights of modules saved in a file located at path.
206
+ If path is None, the last saved model with current configuration
207
+ in save_dir is loaded.
208
+ Assumes files are named like {configuration}_(e{epoch}_{time})*.pt.
209
+ """
210
+ if path is None:
211
+ # pattern = rf"{self.config.save_name}_\(e(\d+)_[^()]*\).*\.pt"
212
+ # tries to load the last checkpoint in the same dir
213
+ pattern = rf"{self.config.save_name}.*?\.checkpoint\.pt"
214
+ files = []
215
+ os.makedirs(self.config.save_dir, exist_ok=True)
216
+ for f in os.listdir(self.config.save_dir):
217
+ match_obj = re.match(pattern, f)
218
+ if match_obj:
219
+ files.append(f)
220
+ if not files:
221
+ if noexception:
222
+ logger.debug("No weights have been loaded", flush=True)
223
+ return
224
+ raise OSError(f"No weights found in {self.config.save_dir}!")
225
+ path = sorted(files)[-1]
226
+ path = os.path.join(self.config.save_dir, path)
227
+
228
+ if map_location is None:
229
+ map_location = self.config.device
230
+ logger.debug(f"Loading from {path}...")
231
+ try:
232
+ state_dicts = torch.load(path, map_location=map_location, weights_only=True)
233
+ except UnpicklingError:
234
+ state_dicts = torch.load(path, map_location=map_location, weights_only=False)
235
+ warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the coref model using this version ASAP.")
236
+ self.epochs_trained = state_dicts.pop("epochs_trained", 0)
237
+ # just ignore a config in the model, since we should already have one
238
+ # TODO: some config elements may be fixed parameters of the model,
239
+ # such as the dimensions of the head,
240
+ # so we would want to use the ones from the config even if the
241
+ # user created a weird shaped model
242
+ config = state_dicts.pop("config", {})
243
+ self.load_state_dicts(state_dicts, ignore)
244
+
245
+ def load_state_dicts(self,
246
+ state_dicts: dict,
247
+ ignore: Optional[Set[str]] = None):
248
+ """
249
+ Process the dictionaries from the save file
250
+
251
+ Loads the weights into the tensors of this model
252
+ May also have optimizer and/or schedule state
253
+ """
254
+ for key, state_dict in state_dicts.items():
255
+ logger.debug("Loading state: %s", key)
256
+ if not ignore or key not in ignore:
257
+ if key.endswith("_optimizer"):
258
+ self.optimizers[key].load_state_dict(state_dict)
259
+ elif key.endswith("_scheduler"):
260
+ self.schedulers[key].load_state_dict(state_dict)
261
+ elif key == "bert_lora":
262
+ assert self.config.lora, "Unable to load state dict of LoRA model into model initialized without LoRA!"
263
+ self.bert = load_peft_wrapper(self.bert, state_dict, vars(self.config), logger, self.peft_name)
264
+ else:
265
+ self.trainable[key].load_state_dict(state_dict, strict=False)
266
+ logger.debug(f"Loaded {key}")
267
+ if self.config.log_norms:
268
+ self.log_norms()
269
+
270
+ def build_doc(self, doc: dict) -> dict:
271
+ filter_func = TOKENIZER_FILTERS.get(self.config.bert_model,
272
+ lambda _: True)
273
+ token_map = TOKENIZER_MAPS.get(self.config.bert_model, {})
274
+
275
+ word2subword = []
276
+ subwords = []
277
+ word_id = []
278
+ for i, word in enumerate(doc["cased_words"]):
279
+ tokenized_word = (token_map[word]
280
+ if word in token_map
281
+ else self.tokenizer.tokenize(word))
282
+ tokenized_word = list(filter(filter_func, tokenized_word))
283
+ word2subword.append((len(subwords), len(subwords) + len(tokenized_word)))
284
+ subwords.extend(tokenized_word)
285
+ word_id.extend([i] * len(tokenized_word))
286
+ doc["word2subword"] = word2subword
287
+ doc["subwords"] = subwords
288
+ doc["word_id"] = word_id
289
+
290
+ doc["head2span"] = []
291
+ if "speaker" not in doc:
292
+ doc["speaker"] = ["_" for _ in doc["cased_words"]]
293
+ doc["word_clusters"] = []
294
+ doc["span_clusters"] = []
295
+
296
+ return doc
297
+
298
+
299
+ @staticmethod
300
+ def load_model(path: str,
301
+ map_location: str = "cpu",
302
+ ignore: Optional[Set[str]] = None,
303
+ config_update: Optional[dict] = None,
304
+ foundation_cache = None):
305
+ if not path:
306
+ raise FileNotFoundError("coref model got an invalid path |%s|" % path)
307
+ if not os.path.exists(path):
308
+ raise FileNotFoundError("coref model file %s not found" % path)
309
+ try:
310
+ state_dicts = torch.load(path, map_location=map_location, weights_only=True)
311
+ except UnpicklingError:
312
+ state_dicts = torch.load(path, map_location=map_location, weights_only=False)
313
+ warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the coref model using this version ASAP.")
314
+ epochs_trained = state_dicts.pop("epochs_trained", 0)
315
+ config = state_dicts.pop('config', None)
316
+ if config is None:
317
+ raise ValueError("Cannot load this format model without config in the dicts")
318
+ if isinstance(config, dict):
319
+ config = Config(**config)
320
+ if config_update:
321
+ for key, value in config_update.items():
322
+ setattr(config, key, value)
323
+ model = CorefModel(config=config, build_optimizers=False,
324
+ epochs_trained=epochs_trained, foundation_cache=foundation_cache)
325
+ model.load_state_dicts(state_dicts, ignore)
326
+ return model
327
+
328
+
329
+ def run(self, # pylint: disable=too-many-locals
330
+ doc: Doc,
331
+ ) -> CorefResult:
332
+ """
333
+ This is a massive method, but it made sense to me to not split it into
334
+ several ones to let one see the data flow.
335
+
336
+ Args:
337
+ doc (Doc): a dictionary with the document data.
338
+
339
+ Returns:
340
+ CorefResult (see const.py)
341
+ """
342
+ # Encode words with bert
343
+ # words [n_words, span_emb]
344
+ # cluster_ids [n_words]
345
+ words, cluster_ids = self.we(doc, self._bertify(doc))
346
+
347
+ # Obtain bilinear scores and leave only top-k antecedents for each word
348
+ # top_rough_scores [n_words, n_ants]
349
+ # top_indices [n_words, n_ants]
350
+ top_rough_scores, top_indices, rough_scores = self.rough_scorer(words)
351
+
352
+ # Get pairwise features [n_words, n_ants, n_pw_features]
353
+ pw = self.pw(top_indices, doc)
354
+
355
+ batch_size = self.config.a_scoring_batch_size
356
+ a_scores_lst: List[torch.Tensor] = []
357
+
358
+ for i in range(0, len(words), batch_size):
359
+ pw_batch = pw[i:i + batch_size]
360
+ words_batch = words[i:i + batch_size]
361
+ top_indices_batch = top_indices[i:i + batch_size]
362
+ top_rough_scores_batch = top_rough_scores[i:i + batch_size]
363
+
364
+ # a_scores_batch [batch_size, n_ants]
365
+ a_scores_batch = self.a_scorer(
366
+ top_mentions=words[top_indices_batch], mentions_batch=words_batch,
367
+ pw_batch=pw_batch, top_rough_scores_batch=top_rough_scores_batch
368
+ )
369
+ a_scores_lst.append(a_scores_batch)
370
+
371
+ res = CorefResult()
372
+
373
+ # coref_scores [n_spans, n_ants]
374
+ res.coref_scores = torch.cat(a_scores_lst, dim=0)
375
+
376
+ res.coref_y = self._get_ground_truth(
377
+ cluster_ids, top_indices, (top_rough_scores > float("-inf")),
378
+ self.config.clusters_starts_are_singletons,
379
+ self.config.singletons)
380
+
381
+ res.word_clusters = self._clusterize(doc, res.coref_scores, top_indices,
382
+ self.config.singletons)
383
+
384
+ res.span_scores, res.span_y = self.sp.get_training_data(doc, words)
385
+
386
+ if not self.training:
387
+ res.span_clusters = self.sp.predict(doc, words, res.word_clusters)
388
+
389
+ return res
390
+
391
+ def save_weights(self, save_path=None, save_optimizers=True):
392
+ """ Saves trainable models as state dicts. """
393
+ to_save: List[Tuple[str, Any]] = \
394
+ [(key, value) for key, value in self.trainable.items()
395
+ if (self.config.bert_finetune and not self.config.lora) or key != "bert"]
396
+ if save_optimizers:
397
+ to_save.extend(self.optimizers.items())
398
+ to_save.extend(self.schedulers.items())
399
+
400
+ time = datetime.strftime(datetime.now(), "%Y.%m.%d_%H.%M")
401
+ if save_path is None:
402
+ save_path = os.path.join(self.config.save_dir,
403
+ f"{self.config.save_name}"
404
+ f"_e{self.epochs_trained}_{time}.pt")
405
+ savedict = {name: module.state_dict() for name, module in to_save}
406
+ if self.config.lora:
407
+ # so that this dependency remains optional
408
+ from peft import get_peft_model_state_dict
409
+ savedict["bert_lora"] = get_peft_model_state_dict(self.bert, adapter_name="coref")
410
+ savedict["epochs_trained"] = self.epochs_trained # type: ignore
411
+ # save as a dictionary because the weights_only=True load option
412
+ # doesn't allow for arbitrary @dataclass configs
413
+ savedict["config"] = dataclasses.asdict(self.config)
414
+ save_dir = os.path.split(save_path)[0]
415
+ if save_dir:
416
+ os.makedirs(save_dir, exist_ok=True)
417
+ torch.save(savedict, save_path)
418
+
419
+ def log_norms(self):
420
+ lines = ["NORMS FOR MODEL PARAMTERS"]
421
+ for t_name, trainable in self.trainable.items():
422
+ for name, param in trainable.named_parameters():
423
+ if param.requires_grad:
424
+ lines.append(" %s: %s %.6g (%d)" % (t_name, name, torch.norm(param).item(), param.numel()))
425
+ logger.info("\n".join(lines))
426
+
427
+
428
    def train(self, log=False):
        """
        Trains all the trainable blocks in the model using the config provided.

        Runs from self.epochs_trained up to config.train_epochs.  After every
        epoch the model is evaluated on the dev set; the best model is saved
        to save_dir, and non-improving epochs save a resumable checkpoint
        (with optimizers) instead.

        log: whether or not to log using wandb
        """

        if log:
            import wandb
            wandb.watch((self.bert, self.pw,
                         self.a_scorer, self.we,
                         self.rough_scorer, self.sp))

        docs = self._get_docs(self.config.train_data)
        docs_ids = list(range(len(docs)))
        # used below to normalize the span loss per document
        avg_spans = docs.avg_span

        best_f1 = None
        for epoch in range(self.epochs_trained, self.config.train_epochs):
            # NOTE(review): presumably a property that flips all submodules
            # into train mode via _set_training -- confirm on the class
            self.training = True
            if self.config.log_norms:
                self.log_norms()
            running_c_loss = 0.0
            running_s_loss = 0.0
            # fresh document order each epoch
            random.shuffle(docs_ids)
            pbar = tqdm(docs_ids, unit="docs", ncols=0)
            for doc_indx, doc_id in enumerate(pbar):
                doc = docs[doc_id]

                # skip very long documents during training time
                if len(doc["subwords"]) > 5000:
                    continue

                for optim in self.optimizers.values():
                    optim.zero_grad()

                res = self.run(doc)

                # coreference loss over the antecedent scores
                c_loss = self._coref_criterion(res.coref_scores, res.coref_y)

                # span loss: start and end heads averaged, scaled down by the
                # average span count so it doesn't dominate the coref loss
                if res.span_y:
                    s_loss = (self._span_criterion(res.span_scores[:, :, 0], res.span_y[0])
                              + self._span_criterion(res.span_scores[:, :, 1], res.span_y[1])) / avg_spans / 2
                else:
                    s_loss = torch.zeros_like(c_loss)

                del res

                (c_loss + s_loss).backward()

                running_c_loss += c_loss.item()
                running_s_loss += s_loss.item()

                # log every 100 docs
                if log and doc_indx % 100 == 0:
                    wandb.log({'train_c_loss': c_loss.item(),
                               'train_s_loss': s_loss.item()})


                del c_loss, s_loss

                # one optimizer/scheduler step per document
                for optim in self.optimizers.values():
                    optim.step()
                for scheduler in self.schedulers.values():
                    scheduler.step()

                pbar.set_description(
                    f"Epoch {epoch + 1}:"
                    f" {doc['document_id']:26}"
                    f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}"
                    f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}"
                )

            self.epochs_trained += 1
            scores = self.evaluate()
            prev_best_f1 = best_f1
            if log:
                wandb.log({'dev_score': scores[1]})
                wandb.log({'dev_bakeoff': scores[-1]})

            if best_f1 is None or scores[1] > best_f1:

                if best_f1 is None:
                    logger.info("Saving new best model: F1 %.4f", scores[1])
                else:
                    logger.info("Saving new best model: F1 %.4f > %.4f", scores[1], best_f1)
                best_f1 = scores[1]
                # best model is saved without optimizers (inference artifact)
                if self.config.save_name.endswith(".pt"):
                    save_path = os.path.join(self.config.save_dir,
                                             f"{self.config.save_name}")
                else:
                    save_path = os.path.join(self.config.save_dir,
                                             f"{self.config.save_name}.pt")
                self.save_weights(save_path, save_optimizers=False)
                if self.config.save_each_checkpoint:
                    # also keep a timestamped per-epoch copy
                    self.save_weights()
            else:
                # no improvement: save a resumable checkpoint with optimizers
                if self.config.save_name.endswith(".pt"):
                    checkpoint_path = os.path.join(self.config.save_dir,
                                                   f"{self.config.save_name[:-3]}.checkpoint.pt")
                else:
                    checkpoint_path = os.path.join(self.config.save_dir,
                                                   f"{self.config.save_name}.checkpoint.pt")
                self.save_weights(checkpoint_path)
            if prev_best_f1 is not None and prev_best_f1 != best_f1:
                logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f\nPrevious best F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1, prev_best_f1)
            else:
                logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1)
537
+
538
+ # ========================================================= Private methods
539
+
540
+ def _bertify(self, doc: Doc) -> torch.Tensor:
541
+ all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer)
542
+
543
+ # we index the batches n at a time to prevent oom
544
+ result = []
545
+ for i in range(0, all_batches.shape[0], 1024):
546
+ subwords_batches = all_batches[i:i+1024]
547
+
548
+ special_tokens = np.array([self.tokenizer.cls_token_id,
549
+ self.tokenizer.sep_token_id,
550
+ self.tokenizer.pad_token_id,
551
+ self.tokenizer.eos_token_id])
552
+ subword_mask = ~(np.isin(subwords_batches, special_tokens))
553
+
554
+ subwords_batches_tensor = torch.tensor(subwords_batches,
555
+ device=self.config.device,
556
+ dtype=torch.long)
557
+ subword_mask_tensor = torch.tensor(subword_mask,
558
+ device=self.config.device)
559
+
560
+ # Obtain bert output for selected batches only
561
+ attention_mask = (subwords_batches != self.tokenizer.pad_token_id)
562
+ if "t5" in self.config.bert_model:
563
+ out = self.bert.encoder(
564
+ input_ids=subwords_batches_tensor,
565
+ attention_mask=torch.tensor(
566
+ attention_mask, device=self.config.device))
567
+ else:
568
+ out = self.bert(
569
+ subwords_batches_tensor,
570
+ attention_mask=torch.tensor(
571
+ attention_mask, device=self.config.device))
572
+
573
+ out = out['last_hidden_state']
574
+ # [n_subwords, bert_emb]
575
+ result.append(out[subword_mask_tensor])
576
+
577
+ # stack returns and return
578
+ return torch.cat(result)
579
+
580
    def _build_model(self, foundation_cache):
        """Construct the transformer and all coref head modules.

        Populates self.bert, self.tokenizer, self.pw, self.a_scorer, self.we,
        self.rough_scorer, self.sp, and the self.trainable registry used by
        saving, optimizer construction, and mode switching.
        """
        # hasattr guard: configs saved before the lora option existed
        if hasattr(self.config, 'lora') and self.config.lora:
            self.bert, self.tokenizer, peft_name = load_bert_with_peft(self.config.bert_model, "coref", foundation_cache)
            # vars() converts a dataclass to a dict, used for being able to index things like args["lora_*"]
            self.bert = build_peft_wrapper(self.bert, vars(self.config), logger, adapter_name=peft_name)
            self.peft_name = peft_name
        else:
            if self.config.bert_finetune:
                logger.debug("Coref model requested a finetuned transformer; we are not using the foundation model cache to prevent we accidentally leak the finetuning weights elsewhere.")
                foundation_cache = NoTransformerFoundationCache(foundation_cache)
            self.bert, self.tokenizer = load_bert(self.config.bert_model, foundation_cache)

        # per-model tokenizer overrides keyed by the short model name
        base_bert_name = self.config.bert_model.split("/")[-1]
        tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {})
        if tokenizer_kwargs:
            logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}")
            # we just downloaded the tokenizer, so for simplicity, we don't make another request to HF
            self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True)

        if self.config.bert_finetune or (hasattr(self.config, 'lora') and self.config.lora):
            self.bert = self.bert.train()

        self.bert = self.bert.to(self.config.device)
        self.pw = PairwiseEncoder(self.config).to(self.config.device)

        bert_emb = self.bert.config.hidden_size
        # width of the pair representation fed to the anaphoricity scorer
        pair_emb = bert_emb * 3 + self.pw.shape

        # pylint: disable=line-too-long
        self.a_scorer = AnaphoricityScorer(pair_emb, self.config).to(self.config.device)
        self.we = WordEncoder(bert_emb, self.config).to(self.config.device)
        self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device)
        self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device)

        # registry of everything with learnable weights
        self.trainable: Dict[str, torch.nn.Module] = {
            "bert": self.bert, "we": self.we,
            "rough_scorer": self.rough_scorer,
            "pw": self.pw, "a_scorer": self.a_scorer,
            "sp": self.sp
        }
620
+
621
    def _build_optimizers(self):
        """Create Adam optimizers and LR schedulers for the transformer and heads.

        Scheduling is measured in documents (the train loop steps once per
        document), so the full schedule length is n_docs * train_epochs.
        """
        n_docs = len(self._get_docs(self.config.train_data))
        self.optimizers: Dict[str, torch.optim.Optimizer] = {}
        self.schedulers: Dict[str, torch.optim.lr_scheduler.LRScheduler] = {}

        # with lora, peft manages requires_grad; otherwise freeze/unfreeze the
        # whole transformer according to the finetune flag
        if not getattr(self.config, 'lora', False):
            for param in self.bert.parameters():
                param.requires_grad = self.config.bert_finetune

        if self.config.bert_finetune:
            logger.debug("Making bert optimizer with LR of %f", self.config.bert_learning_rate)
            self.optimizers["bert_optimizer"] = torch.optim.Adam(
                self.bert.parameters(), lr=self.config.bert_learning_rate
            )
            # optionally delay finetuning: hold the transformer LR at zero for
            # the first start_finetuning steps, then warm up linearly
            start_finetuning = int(n_docs * self.config.bert_finetune_begin_epoch)
            if start_finetuning > 0:
                logger.info("Will begin finetuning transformer at iteration %d", start_finetuning)
            zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizers["bert_optimizer"], factor=0, total_iters=start_finetuning)
            warmup_scheduler = transformers.get_linear_schedule_with_warmup(
                self.optimizers["bert_optimizer"],
                start_finetuning, n_docs * self.config.train_epochs - start_finetuning)
            self.schedulers["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR(
                self.optimizers["bert_optimizer"],
                schedulers=[zero_scheduler, warmup_scheduler],
                milestones=[start_finetuning])

        # Must ensure the same ordering of parameters between launches
        modules = sorted((key, value) for key, value in self.trainable.items()
                         if key != "bert")
        params = []
        for _, module in modules:
            for param in module.parameters():
                param.requires_grad = True
                params.append(param)

        self.optimizers["general_optimizer"] = torch.optim.Adam(
            params, lr=self.config.learning_rate)
        self.schedulers["general_scheduler"] = \
            transformers.get_linear_schedule_with_warmup(
                self.optimizers["general_optimizer"],
                0, n_docs * self.config.train_epochs
            )
663
+
664
+ def _clusterize(self, doc: Doc, scores: torch.Tensor, top_indices: torch.Tensor,
665
+ singletons: bool = True):
666
+ if singletons:
667
+ antecedents = scores[:,1:].argmax(dim=1) - 1
668
+ # set the dummy values to -1, so that they are not coref to themselves
669
+ is_start = (scores[:, :2].argmax(dim=1) == 0)
670
+ else:
671
+ antecedents = scores.argmax(dim=1) - 1
672
+
673
+ not_dummy = antecedents >= 0
674
+ coref_span_heads = torch.arange(0, len(scores), device=not_dummy.device)[not_dummy]
675
+ antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
676
+
677
+ nodes = [GraphNode(i) for i in range(len(doc["cased_words"]))]
678
+ for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
679
+ nodes[i].link(nodes[j])
680
+ assert nodes[i] is not nodes[j]
681
+
682
+ visited = {}
683
+
684
+ clusters = []
685
+ for node in nodes:
686
+ if len(node.links) > 0 and not node.visited:
687
+ cluster = []
688
+ stack = [node]
689
+ while stack:
690
+ current_node = stack.pop()
691
+ current_node.visited = True
692
+ cluster.append(current_node.id)
693
+ stack.extend(link for link in current_node.links if not link.visited)
694
+ assert len(cluster) > 1
695
+ for i in cluster:
696
+ visited[i] = True
697
+ clusters.append(sorted(cluster))
698
+
699
+ if singletons:
700
+ # go through the is_start nodes; if no clusters contain that node
701
+ # i.e. visited[i] == False, we add it as a singleton
702
+ for indx, i in enumerate(is_start):
703
+ if i and not visited.get(indx, False):
704
+ clusters.append([indx])
705
+
706
+ return sorted(clusters)
707
+
708
+ def _get_docs(self, path: str) -> List[Doc]:
709
+ if path not in self._docs:
710
+ self._docs[path] = CorefDataset(path, self.config, self.tokenizer)
711
+ return self._docs[path]
712
+
713
    @staticmethod
    def _get_ground_truth(cluster_ids: torch.Tensor,
                          top_indices: torch.Tensor,
                          valid_pair_map: torch.Tensor,
                          cluster_starts: bool,
                          singletons: bool = True) -> torch.Tensor:
        """
        Args:
            cluster_ids: tensor of shape [n_words], containing cluster indices
                for each word. Non-gold words have cluster id of zero.
            top_indices: tensor of shape [n_words, n_ants],
                indices of antecedents of each word
            valid_pair_map: boolean tensor of shape [n_words, n_ants],
                whether for pair at [i, j] (i-th word and j-th word)
                j < i is True
            cluster_starts: if True, the "first coref" markers mark the first
                word of every cluster; if False, only singleton clusters
            singletons: whether to emit the extra is-start dummy column

        Returns:
            tensor of shape [n_words, n_ants + 1] (dummy added),
            containing 1 at position [i, j] if i-th and j-th words corefer.
            With singletons=True a second dummy column is prepended, so the
            shape is [n_words, n_ants + 2] and column 0 marks cluster starts.
        """
        y = cluster_ids[top_indices] * valid_pair_map  # [n_words, n_ants]
        y[y == 0] = -1  # -1 for non-gold words
        y = utils.add_dummy(y)  # [n_words, n_cands + 1]

        if singletons:
            if not cluster_starts:
                # mark only words that form single-element gold clusters
                unique, counts = cluster_ids.unique(return_counts=True)
                singleton_clusters = unique[(counts == 1) & (unique != 0)]
                first_corefs = [(cluster_ids == i).nonzero().flatten()[0] for i in singleton_clusters]
                if len(first_corefs) > 0:
                    first_coref = torch.stack(first_corefs)
                else:
                    # no singleton clusters: empty index tensor
                    first_coref = torch.tensor([]).to(cluster_ids.device).long()
            else:
                # I apologize for this abuse of everything that's good about PyTorch.
                # in essence, this line finds the INDEX of FIRST OCCURENCE of each NON-ZERO value
                # from cluster_ids. We need this information because we use it to mark the
                # special "is-start-of-ref" marker used to detect singletons.
                first_coref = (cluster_ids ==
                               cluster_ids.unique().sort().values[1:].unsqueeze(1)
                               ).float().topk(k=1, dim=1).indices.squeeze()
        y = (y == cluster_ids.unsqueeze(1))  # True if coreferent
        # For all rows with no gold antecedents setting dummy to True
        y[y.sum(dim=1) == 0, 0] = True

        if singletons:
            # add another dummy for first coref
            y = utils.add_dummy(y)  # [n_words, n_cands + 2]
            # for all rows that's a first coref, setting its dummy to True and unset the
            # non-coref dummy to false
            y[first_coref, 0] = True
            y[first_coref, 1] = False
        return y.to(torch.float)
766
+
767
+ @staticmethod
768
+ def _load_config(config_path: str,
769
+ section: str) -> Config:
770
+ with open(config_path, "rb") as fin:
771
+ config = tomllib.load(fin)
772
+ default_section = config["DEFAULT"]
773
+ current_section = config[section]
774
+ unknown_keys = (set(current_section.keys())
775
+ - set(default_section.keys()))
776
+ if unknown_keys:
777
+ raise ValueError(f"Unexpected config keys: {unknown_keys}")
778
+ return Config(section, **{**default_section, **current_section})
779
+
780
+ def _set_training(self, value: bool):
781
+ self._training = value
782
+ for module in self.trainable.values():
783
+ module.train(self._training)
784
+
stanza/stanza/models/depparse/__init__.py ADDED
File without changes
stanza/stanza/models/depparse/scorer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils and wrappers for scoring parsers.
3
+ """
4
+
5
+ from collections import Counter
6
+ import logging
7
+
8
+ from stanza.models.common.utils import ud_scores
9
+
10
+ logger = logging.getLogger('stanza')
11
+
12
def score_named_dependencies(pred_doc, gold_doc):
    """Log per-deprel precision/recall/F1 between predicted and gold docs.

    A word counts as a true positive only when both its head and its deprel
    match gold, so unlabeled attachment errors also lower these labeled
    scores.  Skips evaluation (with a warning) when the documents are not
    token-aligned.
    """
    # fix: warnings previously said "on accound of"
    if len(pred_doc.sentences) != len(gold_doc.sentences):
        logger.warning("Not evaluating individual dependency F1 on account of document length mismatch")
        return
    for sent_idx, (x, y) in enumerate(zip(pred_doc.sentences, gold_doc.sentences)):
        if len(x.words) != len(y.words):
            logger.warning("Not evaluating individual dependency F1 on account of sentence length mismatch")
            return

    tp = Counter()
    fp = Counter()
    fn = Counter()
    for pred_sentence, gold_sentence in zip(pred_doc.sentences, gold_doc.sentences):
        for pred_word, gold_word in zip(pred_sentence.words, gold_sentence.words):
            if pred_word.head == gold_word.head and pred_word.deprel == gold_word.deprel:
                tp[gold_word.deprel] += 1
            else:
                # a miss is both an FN for the gold label and an FP for the
                # predicted label
                fn[gold_word.deprel] += 1
                fp[pred_word.deprel] += 1

    labels = sorted(set(tp.keys()).union(fp.keys()).union(fn.keys()))
    if not labels:
        # empty documents: nothing to report (and max() below would raise)
        logger.info("F1 scores for each dependency: no dependencies found")
        return
    max_len = max(len(x) for x in labels)
    log_lines = []
    log_line_fmt = "%" + str(max_len) + "s: p %.4f r %.4f f1 %.4f (%d actual)"
    for label in labels:
        if tp[label] == 0:
            # avoids division by zero when a label is never correctly predicted
            precision = 0
            recall = 0
            f1 = 0
        else:
            precision = tp[label] / (tp[label] + fp[label])
            recall = tp[label] / (tp[label] + fn[label])
            f1 = 2 * (precision * recall) / (precision + recall)
        log_lines.append(log_line_fmt % (label, precision, recall, f1, tp[label] + fn[label]))
    logger.info("F1 scores for each dependency:\n  Note that unlabeled attachment errors hurt the labeled attachment scores\n%s" % "\n".join(log_lines))
47
+
48
def score(system_conllu_file, gold_conllu_file, verbose=True):
    """ Wrapper for UD parser scorer. """
    evaluation = ud_scores(gold_conllu_file, system_conllu_file)
    las = evaluation['LAS']
    if verbose:
        # report the three headline metrics as percentages
        headline = [evaluation[metric].f1 * 100 for metric in ('LAS', 'MLAS', 'BLEX')]
        logger.info("LAS\tMLAS\tBLEX")
        logger.info("{:.2f}\t{:.2f}\t{:.2f}".format(*headline))
    return las.precision, las.recall, las.f1
60
+
stanza/stanza/models/depparse/trainer.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A trainer class to handle training and testing of models.
3
+ """
4
+
5
+ import copy
6
+ import sys
7
+ import logging
8
+ import torch
9
+ from torch import nn
10
+
11
+ try:
12
+ import transformers
13
+ except ImportError:
14
+ pass
15
+
16
+ from stanza.models.common.trainer import Trainer as BaseTrainer
17
+ from stanza.models.common import utils, loss
18
+ from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache
19
+ from stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root
20
+ from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
21
+ from stanza.models.depparse.model import Parser
22
+ from stanza.models.pos.vocab import MultiVocab
23
+
24
+ logger = logging.getLogger('stanza')
25
+
26
def unpack_batch(batch, device):
    """ Unpack a batch from the data loader. """
    # the first 11 entries are tensors (or None placeholders) bound for the device
    inputs = [item if item is None else item.to(device) for item in batch[:11]]
    # the remaining entries are bookkeeping that stays on the CPU
    orig_idx, word_orig_idx, sentlens, wordlens, text = batch[11:16]
    return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text
35
+
36
class Trainer(BaseTrainer):
    """ A trainer for training models. """
    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None,
                 device=None, foundation_cache=None, ignore_model_config=False, reset_history=False):
        # training-progress bookkeeping, overwritten by load() when resuming
        self.global_step = 0
        self.last_best_step = 0
        self.dev_score_history = []

        orig_args = copy.deepcopy(args)
        # whether the training is in primary or secondary stage
        # during FT (loading weights), etc., the training is considered to be in "secondary stage"
        # during this time, we (optionally) use a different set of optimizers than that during "primary stage".
        #
        # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary

        if model_file is not None:
            # load everything from file
            self.load(model_file, pretrain, args, foundation_cache, device)

            if reset_history:
                # start counting steps/scores from scratch even though the
                # weights came from a checkpoint
                self.global_step = 0
                self.last_best_step = 0
                self.dev_score_history = []
        else:
            # build model from scratch
            self.args = args
            self.vocab = vocab

            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
            peft_name = None
            if self.args['use_peft']:
                # fine tune the bert if we're using peft
                self.args['bert_finetune'] = True
                peft_name = "depparse"
                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)

            self.model = Parser(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)
            self.model = self.model.to(device)
            self.__init_optim()

        if ignore_model_config:
            # caller-supplied args win over whatever the checkpoint stored
            self.args = orig_args

        if self.args.get('wandb'):
            import wandb
            # track gradients!
            wandb.watch(self.model, log_freq=4, log="all", log_graph=True)

    def __init_optim(self):
        """Build the optimizer dict (and schedulers) for the current stage."""
        # TODO: can get rid of args.get when models are rebuilt
        if (self.args.get("second_stage", False) and self.args.get('second_optim')):
            # secondary stage: separate optimizer settings for fine-tuning
            self.optimizer = utils.get_split_optimizer(self.args['second_optim'], self.model,
                                                       self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6,
                                                       bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0),
                                                       is_peft=self.args.get('use_peft', False),
                                                       bert_finetune_layers=self.args.get('bert_finetune_layers', None))
        else:
            self.optimizer = utils.get_split_optimizer(self.args['optim'], self.model,
                                                       self.args['lr'], betas=(0.9, self.args['beta2']),
                                                       eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0),
                                                       weight_decay=self.args.get('weight_decay', None),
                                                       bert_weight_decay=self.args.get('bert_weight_decay', 0.0),
                                                       is_peft=self.args.get('use_peft', False),
                                                       bert_finetune_layers=self.args.get('bert_finetune_layers', None))
        self.scheduler = {}
        if self.args.get("second_stage", False) and self.args.get('second_optim'):
            if self.args.get('second_warmup_steps', None):
                # one constant-with-warmup scheduler per optimizer
                for name, optimizer in self.optimizer.items():
                    name = name + "_scheduler"
                    warmup_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, self.args['second_warmup_steps'])
                    self.scheduler[name] = warmup_scheduler
        else:
            if "bert_optimizer" in self.optimizer:
                # hold the transformer LR at zero until bert_start_finetuning,
                # then warm it up; the heads train from step 0
                zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer["bert_optimizer"], factor=0, total_iters=self.args['bert_start_finetuning'])
                warmup_scheduler = transformers.get_constant_schedule_with_warmup(
                    self.optimizer["bert_optimizer"],
                    self.args['bert_warmup_steps'])
                self.scheduler["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR(
                    self.optimizer["bert_optimizer"],
                    schedulers=[zero_scheduler, warmup_scheduler],
                    milestones=[self.args['bert_start_finetuning']])

    def update(self, batch, eval=False):
        """Run one forward pass; if eval is False, also backprop and step.

        Returns the scalar loss value.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            for opt in self.optimizer.values():
                opt.zero_grad()
        # NOTE(review): local `loss` shadows the imported loss module
        loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        for opt in self.optimizer.values():
            opt.step()
        for scheduler in self.scheduler.values():
            scheduler.step()
        return loss_val

    def predict(self, batch, unsort=True):
        """Predict heads and deprels for a batch.

        Returns, per sentence, a list of [head, deprel] string pairs
        (restored to the original document order when unsort is True).
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device)
        word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs

        self.model.eval()
        batch_size = word.size(0)
        _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text)
        # decode a single-rooted tree from the score matrix
        head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root
        deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)]

        pred_tokens = [[[str(head_seqs[i][j]), deprel_seqs[i][j]] for j in range(sentlens[i]-1)] for i in range(batch_size)]
        if unsort:
            pred_tokens = utils.unsort(pred_tokens, orig_idx)
        return pred_tokens

    def save(self, filename, skip_modules=True, save_optimizer=False):
        """Serialize model, vocab, config, and training history to filename."""
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
            for k in skipped:
                del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            'config': self.args,
            'global_step': self.global_step,
            'last_best_step': self.last_best_step,
            'dev_score_history': self.dev_score_history,
        }
        if self.args.get('use_peft', False):
            # Hide import so that peft dependency is optional
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)

        if save_optimizer and self.optimizer is not None:
            params['optimizer_state_dict'] = {k: opt.state_dict() for k, opt in self.optimizer.items()}
            params['scheduler_state_dict'] = {k: scheduler.state_dict() for k, scheduler in self.scheduler.items()}

        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except BaseException:
            # best-effort save: training continues even if the disk write fails
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename, pretrain, args=None, foundation_cache=None, device=None):
        """
        Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input,
        and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args is not None: self.args.update(args)

        # preserve old models which were created before transformers were added
        if 'bert_model' not in self.args:
            self.args['bert_model'] = None

        lora_weights = checkpoint.get('bert_lora')
        if lora_weights:
            logger.debug("Found peft weights for depparse; loading a peft adapter")
            self.args["use_peft"] = True

        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        # load model
        emb_matrix = None
        if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None
            emb_matrix = pretrain.emb

        # TODO: refactor this common block of code with NER
        force_bert_saved = False
        peft_name = None
        if self.args.get('use_peft', False):
            force_bert_saved = True
            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "depparse", foundation_cache)
            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
            logger.debug("Loaded peft with name %s", peft_name)
        else:
            if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
                logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
                foundation_cache = NoTransformerFoundationCache(foundation_cache)
                force_bert_saved = True
            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

        self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
        # strict=False because skipped modules (e.g. pretrained embeddings)
        # are intentionally absent from the checkpoint
        self.model.load_state_dict(checkpoint['model'], strict=False)

        if device is not None:
            self.model = self.model.to(device)

        self.__init_optim()
        optim_state_dict = checkpoint.get("optimizer_state_dict")
        if optim_state_dict:
            for k, state in optim_state_dict.items():
                self.optimizer[k].load_state_dict(state)

        scheduler_state_dict = checkpoint.get("scheduler_state_dict")
        if scheduler_state_dict:
            for k, state in scheduler_state_dict.items():
                self.scheduler[k].load_state_dict(state)

        self.global_step = checkpoint.get("global_step", 0)
        self.last_best_step = checkpoint.get("last_best_step", 0)
        self.dev_score_history = checkpoint.get("dev_score_history", list())
stanza/stanza/models/langid/trainer.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+
4
+ from stanza.models.langid.model import LangIDBiLSTM
5
+
6
+
7
class Trainer:
    """Wraps a LangIDBiLSTM model with its optimizer and save/load helpers."""

    DEFAULT_BATCH_SIZE = 64
    DEFAULT_LAYERS = 2
    DEFAULT_EMBEDDING_DIM = 150
    DEFAULT_HIDDEN_DIM = 150

    def __init__(self, config, load_model=False, device=None):
        # config: dict with "model_path" and optionally "batch_size"; when
        # building from scratch it also needs "char_to_idx", "tag_to_idx" and
        # "lang_weights"; when load_model is True, "load_name" is the checkpoint
        self.model_path = config["model_path"]
        self.batch_size = config.get("batch_size", Trainer.DEFAULT_BATCH_SIZE)
        if load_model:
            self.load(config["load_name"], device)
        else:
            self.model = LangIDBiLSTM(config["char_to_idx"], config["tag_to_idx"], Trainer.DEFAULT_LAYERS,
                                      Trainer.DEFAULT_EMBEDDING_DIM,
                                      Trainer.DEFAULT_HIDDEN_DIM,
                                      batch_size=self.batch_size,
                                      weights=config["lang_weights"]).to(device)
        self.optimizer = optim.AdamW(self.model.parameters())

    def update(self, inputs):
        """Run one optimization step on a (sentences, targets) batch."""
        self.model.train()
        sentences, targets = inputs
        self.optimizer.zero_grad()
        y_hat = self.model.forward(sentences)
        loss = self.model.loss(y_hat, targets)
        loss.backward()
        self.optimizer.step()

    def predict(self, inputs):
        """Return the argmax language index for each sentence in the batch."""
        self.model.eval()
        # NOTE(review): targets is unpacked but unused here; gradients are not
        # disabled either -- presumably callers wrap this in torch.no_grad()
        sentences, targets = inputs
        return torch.argmax(self.model(sentences), dim=1)

    def save(self, label=None):
        """Save the model to self.model_path, optionally also a labeled copy."""
        # save a copy of model with label
        if label:
            self.model.save(f"{self.model_path[:-3]}-{label}.pt")
        self.model.save(self.model_path)

    def load(self, model_path=None, device=None):
        """Load model weights from model_path (defaults to self.model_path)."""
        if not model_path:
            model_path = self.model_path
        self.model = LangIDBiLSTM.load(model_path, device, self.batch_size)
51
+
stanza/stanza/models/lemma/__init__.py ADDED
File without changes
stanza/stanza/models/lemma/data.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import os
4
+ from collections import Counter
5
+ import logging
6
+ import torch
7
+
8
+ import stanza.models.common.seq2seq_constant as constant
9
+ from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
10
+ from stanza.models.common.vocab import DeltaVocab
11
+ from stanza.models.lemma.vocab import Vocab, MultiVocab
12
+ from stanza.models.lemma import edit
13
+ from stanza.models.common.doc import *
14
+
15
+ logger = logging.getLogger('stanza')
16
+
17
class DataLoader:
    """Loads and batches lemmatizer training examples from a Document.

    Each raw example is a (word text, upos, lemma) row; preprocess() turns
    rows into id sequences for a seq2seq lemmatizer plus an edit-type label.
    Batches are produced as tuples of padded tensors via __getitem__.
    """
    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, conll_only=False, skip=None, expand_unk_vocab=False):
        self.batch_size = batch_size
        self.args = args
        self.eval = evaluation
        # training data is shuffled; evaluation data keeps document order
        self.shuffled = not self.eval
        self.doc = doc

        data = self.raw_data()

        if conll_only: # only load conll file
            return

        # skip is a parallel list of booleans; rows flagged True are dropped
        if skip is not None:
            assert len(data) == len(skip)
            data = [x for x, y in zip(data, skip) if not y]

        # handle vocab
        if vocab is not None:
            if expand_unk_vocab:
                # extend the char vocab with characters unseen at train time,
                # keeping the existing pos vocab as-is
                pos_vocab = vocab['pos']
                char_vocab = DeltaVocab(data, vocab['char'])
                self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})
            else:
                self.vocab = vocab
        else:
            self.vocab = dict()
            char_vocab, pos_vocab = self.init_vocab(data)
            self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab['char'], self.vocab['pos'], args)
        # shuffle for training
        if self.shuffled:
            indices = list(range(len(data)))
            random.shuffle(indices)
            data = [data[i] for i in indices]
        self.num_examples = len(data)

        # chunk into batches
        data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
        self.data = data
        logger.debug("{} batches created.".format(len(data)))

    def init_vocab(self, data):
        """Build fresh char and pos vocabs from training data."""
        assert self.eval is False, "Vocab file must exist for evaluation"
        # characters come from both the surface form (d[0]) and the lemma (d[2])
        char_data = "".join(d[0] + d[2] for d in data)
        char_vocab = Vocab(char_data, self.args['lang'])
        pos_data = [d[1] for d in data]
        pos_vocab = Vocab(pos_data, self.args['lang'])
        return char_vocab, pos_vocab

    def preprocess(self, data, char_vocab, pos_vocab, args):
        """Map raw (word, pos, lemma) rows to id sequences.

        Returns a list of [src, tgt_in, tgt_out, pos, edit_type, raw word],
        where tgt_in/tgt_out are the teacher-forcing input/output sequences
        (shifted by one via SOS/EOS).
        """
        processed = []
        for d in data:
            edit_type = edit.EDIT_TO_ID[edit.get_edit_type(d[0], d[2])]
            src = list(d[0])
            src = [constant.SOS] + src + [constant.EOS]
            src = char_vocab.map(src)
            pos = d[1]
            pos = pos_vocab.unit2id(pos)
            tgt = list(d[2])
            tgt_in = char_vocab.map([constant.SOS] + tgt)
            tgt_out = char_vocab.map(tgt + [constant.EOS])
            processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]]
        return processed

    def __len__(self):
        # number of batches, not number of examples
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        # transpose: list of examples -> per-field lists
        batch = list(zip(*batch))
        assert len(batch) == 6

        # sort all fields by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # convert to tensors
        src = batch[0]
        src = get_long_tensor(src, batch_size)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in = get_long_tensor(batch[1], batch_size)
        tgt_out = get_long_tensor(batch[2], batch_size)
        pos = torch.LongTensor(batch[3])
        edits = torch.LongTensor(batch[4])
        text = batch[5]
        assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match."
        return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def raw_data(self):
        """Re-extract the raw rows from the Document."""
        return self.load_doc(self.doc, self.args.get('caseless', False), self.eval)

    @staticmethod
    def load_doc(doc, caseless, evaluation):
        """Extract rows from a Document.

        For evaluation only (text, upos, lemma) is needed; for training the
        extra fields (head, deprel, misc) drive goeswith filtering and
        CorrectForm extraction.
        """
        if evaluation:
            data = doc.get([TEXT, UPOS, LEMMA])
        else:
            data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True)
            data = DataLoader.remove_goeswith(data)
            data = DataLoader.extract_correct_forms(data)
        data = DataLoader.resolve_none(data)
        if caseless:
            data = DataLoader.lowercase_data(data)
        return data

    @staticmethod
    def extract_correct_forms(data):
        """
        Here we go through the raw data and use the CorrectForm of words tagged with CorrectForm

        In addition, if the incorrect form of the word is not present in the training data,
        we keep the incorrect form for the lemmatizer to learn from.
        This way, it can occasionally get things right in misspelled input text.

        We do check for and eliminate words where the incorrect form is already known as the
        lemma for a different word.  For example, in the English datasets, there is a "busy"
        which was meant to be "buys", and we don't want the model to learn to lemmatize "busy" to "buy"
        """
        # NOTE(review): CorrectForm rows are re-emitted as 3-tuples while other
        # rows keep whatever sequence type doc.get produced — downstream code
        # only indexes/reads fields, but lowercase_data assigns token[0] and
        # would fail on tuples; confirm caseless is never combined with this path
        new_data = []
        incorrect_forms = []
        for word in data:
            misc = word[-1]
            if not misc:
                new_data.append(word[:3])
                continue
            misc = misc.split("|")
            for piece in misc:
                if piece.startswith("CorrectForm="):
                    cf = piece.split("=", maxsplit=1)[1]
                    # treat the CorrectForm as the desired word
                    new_data.append((cf, word[1], word[2]))
                    # and save the broken one for later in case it wasn't used anywhere else
                    incorrect_forms.append((cf, word))
                    break
            else:
                # if no CorrectForm, just keep the word as normal
                new_data.append(word[:3])
        known_words = {x[0] for x in new_data}
        for correct_form, word in incorrect_forms:
            if word[0] not in known_words:
                new_data.append(word[:3])
        return new_data

    @staticmethod
    def remove_goeswith(data):
        """
        This method specifically removes words that goeswith something else, along with the something else

        The purpose is to eliminate text such as

        1 Ken kenrice@enroncommunications X GW Typo=Yes 0 root 0:root _
        2 Rice@ENRON _ X GW _ 1 goeswith 1:goeswith _
        3 COMMUNICATIONS _ X ADD _ 1 goeswith 1:goeswith _
        """
        filtered_data = []
        remove_indices = set()
        for sentence in data:
            remove_indices.clear()
            for word_idx, word in enumerate(sentence):
                if word[4] == 'goeswith':
                    remove_indices.add(word_idx)
                    # word[3] is the 1-based head index; drop the head as well
                    remove_indices.add(word[3]-1)
            filtered_data.extend([x for idx, x in enumerate(sentence) if idx not in remove_indices])
        return filtered_data

    @staticmethod
    def lowercase_data(data):
        # mutates rows in place: replaces the surface form with its lowercase
        for token in data:
            token[0] = token[0].lower()
        return data

    @staticmethod
    def resolve_none(data):
        # replace None to '_'
        for tok_idx in range(len(data)):
            for feat_idx in range(len(data[tok_idx])):
                if data[tok_idx][feat_idx] is None:
                    data[tok_idx][feat_idx] = '_'
        return data
stanza/stanza/models/lemma_classifier/base_model.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base class for the LemmaClassifier types.
3
+
4
+ Versions include LSTM and Transformer varieties
5
+ """
6
+
7
+ import logging
8
+
9
+ from abc import ABC, abstractmethod
10
+
11
+ import os
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from stanza.models.common.foundation_cache import load_pretrain
17
+ from stanza.models.lemma_classifier.constants import ModelType
18
+
19
+ from typing import List
20
+
21
+ logger = logging.getLogger('stanza.lemmaclassifier')
22
+
23
class LemmaClassifier(ABC, nn.Module):
    """Abstract base class for lemma classifier models (LSTM / Transformer).

    Holds the label maps and target word/upos filters shared by all
    variants, and implements save/load plumbing.  Subclasses provide
    get_save_dict(), forward(), convert_tags(), and model_type().
    """
    def __init__(self, label_decoder, target_words, target_upos, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # label_decoder: lemma string -> class id; label_encoder is the inverse
        self.label_decoder = label_decoder
        self.label_encoder = {y: x for x, y in label_decoder.items()}
        self.target_words = target_words
        self.target_upos = target_upos
        # module names listed here are excluded from the saved state dict
        # (e.g. frozen pretrained embeddings, charlm)
        self.unsaved_modules = []

    def add_unsaved_module(self, name, module):
        """Attach a submodule whose parameters are not serialized with the model."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def is_unsaved_module(self, name):
        # state-dict keys look like "embeddings.weight"; match on the first component
        return name.split('.')[0] in self.unsaved_modules

    def save(self, save_name):
        """
        Save the model to the given path, possibly with some args

        Creates the parent directory if needed and returns the saved dict.
        """
        save_dir = os.path.split(save_name)[0]
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_dict = self.get_save_dict()
        torch.save(save_dict, save_name)
        return save_dict

    @abstractmethod
    def model_type(self):
        """
        return a ModelType
        """

    def target_indices(self, words, tags):
        """Return the indices of tokens this classifier should re-lemmatize.

        A token qualifies when its lowercased text is a target word AND its
        tag is a target upos.
        """
        return [idx for idx, (word, tag) in enumerate(zip(words, tags)) if word.lower() in self.target_words and tag in self.target_upos]

    def predict(self, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[str]]=[]) -> torch.Tensor:
        """Predict lemma labels (as strings) for one target token per sentence."""
        upos_tags = self.convert_tags(upos_tags)
        with torch.no_grad():
            logits = self.forward(position_indices, sentences, upos_tags)  # should be size (batch_size, output_size)
            predicted_class = torch.argmax(logits, dim=1)  # should be size (batch_size, 1)
            predicted_class = [self.label_encoder[x.item()] for x in predicted_class]
        return predicted_class

    @staticmethod
    def from_checkpoint(checkpoint, args=None):
        """Rebuild the appropriate subclass from a loaded checkpoint dict.

        args, if given, may override the saved file-path settings
        (pretrain / charlm locations); all other saved args are kept.
        """
        model_type = ModelType[checkpoint['model_type']]
        if model_type is ModelType.LSTM:
            # TODO: if anyone can suggest a way to avoid this circular import
            # (or better yet, avoid the load method knowing about subclasses)
            # please do so
            # maybe the subclassing is not necessary and we just put
            # save & load in the trainer
            from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM

            saved_args = checkpoint['args']
            # other model args are part of the model and cannot be changed for evaluation or pipeline
            # the file paths might be relevant, though
            keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']
            for arg in keep_args:
                if args is not None and args.get(arg, None) is not None:
                    saved_args[arg] = args[arg]

            # TODO: refactor loading the pretrain (also done in the trainer)
            pt = load_pretrain(saved_args['wordvec_pretrain_file'])

            use_charlm = saved_args['use_charlm']
            charlm_forward_file = saved_args.get('charlm_forward_file', None)
            charlm_backward_file = saved_args.get('charlm_backward_file', None)

            model = LemmaClassifierLSTM(model_args=saved_args,
                                        output_dim=len(checkpoint['label_decoder']),
                                        pt_embedding=pt,
                                        label_decoder=checkpoint['label_decoder'],
                                        upos_to_id=checkpoint['upos_to_id'],
                                        known_words=checkpoint['known_words'],
                                        target_words=set(checkpoint['target_words']),
                                        target_upos=set(checkpoint['target_upos']),
                                        use_charlm=use_charlm,
                                        charlm_forward_file=charlm_forward_file,
                                        charlm_backward_file=charlm_backward_file)
        elif model_type is ModelType.TRANSFORMER:
            from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer

            output_dim = len(checkpoint['label_decoder'])
            saved_args = checkpoint['args']
            bert_model = saved_args['bert_model']
            model = LemmaClassifierWithTransformer(model_args=saved_args,
                                                   output_dim=output_dim,
                                                   transformer_name=bert_model,
                                                   label_decoder=checkpoint['label_decoder'],
                                                   target_words=set(checkpoint['target_words']),
                                                   target_upos=set(checkpoint['target_upos']))
        else:
            raise ValueError("Unknown model type %s" % model_type)

        # strict=False to accommodate missing parameters from the transformer or charlm
        model.load_state_dict(checkpoint['params'], strict=False)
        return model

    @staticmethod
    def load(filename, args=None):
        """Load a LemmaClassifier from a file, optionally overriding path args."""
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise

        logger.debug("Loading LemmaClassifier model from %s", filename)

        # bug fix: args was accepted but previously dropped here, so
        # path overrides passed to load() were silently ignored
        return LemmaClassifier.from_checkpoint(checkpoint, args)
stanza/stanza/models/lemma_classifier/baseline_model.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline model for the existing lemmatizer which always predicts "be" and never "have" on the "'s" token.
3
+
4
+ The BaselineModel class can be updated to any arbitrary token and predicton lemma, not just "be" on the "s" token.
5
+ """
6
+
7
+ import stanza
8
+ import os
9
+ from stanza.models.lemma_classifier.evaluate_models import evaluate_sequences
10
+ from stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file
11
+
12
class BaselineModel:
    """Trivial baseline lemmatizer: one fixed token maps to one fixed lemma.

    For example, always predict "be" (and never "have") for the "'s" token.
    The trigger token, predicted lemma, and UPOS filter are all configurable.
    """

    def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos):
        self.token_to_lemmatize = token_to_lemmatize
        self.prediction_lemma = prediction_lemma
        self.prediction_upos = prediction_upos

    def predict(self, token):
        """Return the fixed lemma for the trigger token, None otherwise."""
        if token != self.token_to_lemmatize:
            return None
        return self.prediction_lemma

    def evaluate(self, conll_path):
        """
        Evaluates the baseline model against the test set defined in conll_path.

        Returns a map where the keys are each class and the values are another map including the precision, recall and f1 scores
        for that class.

        Also returns confusion matrix.  Keys are gold tags and inner keys are predicted tags
        """
        document = load_doc_from_conll_file(conll_path)
        gold_tag_sequences = []
        pred_tag_sequences = []
        for sentence in document.sentences:
            gold_tags = []
            pred_tags = []
            for word in sentence.words:
                # only score words this baseline would actually fire on
                if word.text == self.token_to_lemmatize and word.upos in self.prediction_upos:
                    gold_tags.append(word.lemma)
                    pred_tags.append(self.prediction_lemma)
            gold_tag_sequences.append(gold_tags)
            pred_tag_sequences.append(pred_tags)

        multiclass_result, confusion_mtx, weighted_f1 = evaluate_sequences(gold_tag_sequences, pred_tag_sequences)
        return multiclass_result, confusion_mtx
47
+
48
+
49
+ if __name__ == "__main__":
50
+
51
+ bl_model = BaselineModel("'s", "be", ["AUX"])
52
+ coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu")
53
+ bl_model.evaluate(coNLL_path)
54
+
stanza/stanza/models/lemma_classifier/lstm_model.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import logging
5
+ import math
6
+ from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
7
+ from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
8
+ from typing import List, Tuple
9
+
10
+ from stanza.models.common.vocab import UNK_ID
11
+ from stanza.models.lemma_classifier import utils
12
+ from stanza.models.lemma_classifier.base_model import LemmaClassifier
13
+ from stanza.models.lemma_classifier.constants import ModelType
14
+
15
+ logger = logging.getLogger('stanza.lemmaclassifier')
16
+
17
class LemmaClassifierLSTM(LemmaClassifier):
    """
    Model architecture:
        Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding.
        From the LSTM output, we get the embedding of the specific token that we classify on.  That embedding
        is fed into an MLP for classification.

    If model_args['num_heads'] > 0, multi-head self-attention is used in
    place of the bi-LSTM to produce the per-token representation.
    """
    def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                 use_charlm=False, charlm_forward_file=None, charlm_backward_file=None):
        """
        Args:
            model_args (dict): must contain 'hidden_dim', 'num_heads', 'upos_emb_dim'
            output_dim (int): Size of output vector from MLP layer
            upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs
            pt_embedding (Pretrain): pretrained embeddings
            label_decoder (Mapping[str, int]): lemma label -> class id
            known_words (list(str)): Words which are in the training data
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos (set(str)): upos tags the target words may carry
            use_charlm (bool): Whether or not to use the charlm embeddings
            charlm_forward_file (str): The path to the forward pass model for the character language model
            charlm_backward_file (str): The path to the backward pass model for the character language model.

        Raises:
            FileNotFoundError: if the forward or backward charlm file cannot be found.
        """
        super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        self.hidden_dim = model_args['hidden_dim']
        # input_size is accumulated below as optional feature blocks are added
        self.input_size = 0
        self.num_heads = self.model_args['num_heads']

        # frozen pretrained embeddings are excluded from the saved state dict
        emb_matrix = pt_embedding.emb
        self.add_unsaved_module("embeddings", nn.Embedding.from_pretrained(emb_matrix, freeze=True))
        # normalize non-breaking spaces in the pretrain vocab keys
        self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt_embedding.vocab) }
        self.vocab_size = emb_matrix.shape[0]
        self.embedding_dim = emb_matrix.shape[1]

        # trainable delta embeddings for words seen in training;
        # index 0 is reserved as the padding / unknown slot
        self.known_words = known_words
        self.known_word_map = {word: idx for idx, word in enumerate(known_words)}
        self.delta_embedding = nn.Embedding(num_embeddings=len(known_words)+1,
                                            embedding_dim=self.embedding_dim,
                                            padding_idx=0)
        nn.init.normal_(self.delta_embedding.weight, std=0.01)

        self.input_size += self.embedding_dim

        # Optionally, include charlm embeddings
        self.use_charlm = use_charlm

        if self.use_charlm:
            if charlm_forward_file is None or not os.path.exists(charlm_forward_file):
                raise FileNotFoundError(f'Could not find forward character model: {charlm_forward_file}')
            if charlm_backward_file is None or not os.path.exists(charlm_backward_file):
                raise FileNotFoundError(f'Could not find backward character model: {charlm_backward_file}')
            self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(charlm_forward_file, finetune=False))
            self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(charlm_backward_file, finetune=False))

            self.input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()

        self.upos_emb_dim = self.model_args["upos_emb_dim"]
        self.upos_to_id = upos_to_id
        if self.upos_emb_dim > 0 and self.upos_to_id is not None:
            # TODO: should leave space for unknown POS?
            self.upos_emb = nn.Embedding(num_embeddings=len(self.upos_to_id),
                                         embedding_dim=self.upos_emb_dim,
                                         padding_idx=0)
            self.input_size += self.upos_emb_dim

        device = next(self.parameters()).device
        # Determine if attn or LSTM should be used
        if self.num_heads > 0:
            # MultiheadAttention requires embed_dim % num_heads == 0
            self.input_size = utils.round_up_to_multiple(self.input_size, self.num_heads)
            self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_size, num_heads=self.num_heads, batch_first=True).to(device)
            logger.debug(f"Using attention mechanism with embed dim {self.input_size} and {self.num_heads} attention heads.")
        else:
            self.lstm = nn.LSTM(self.input_size,
                                self.hidden_dim,
                                batch_first=True,
                                bidirectional=True)
            logger.debug(f"Using LSTM mechanism.")

        # bi-LSTM concatenates both directions; attention keeps the input size
        mlp_input_size = self.hidden_dim * 2 if self.num_heads == 0 else self.input_size
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def get_save_dict(self):
        """Build the checkpoint dict, dropping unsaved (frozen) module params."""
        save_dict = {
            "params": self.state_dict(),
            "label_decoder": self.label_decoder,
            "model_type": self.model_type().name,
            "args": self.model_args,
            "upos_to_id": self.upos_to_id,
            "known_words": self.known_words,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
        }
        skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del save_dict["params"][k]
        return save_dict

    def convert_tags(self, upos_tags: List[List[str]]):
        """Map upos tag strings to ids, or None if no upos map is available."""
        if self.upos_to_id is not None:
            return [[self.upos_to_id[x] for x in sentence] for sentence in upos_tags]
        return None

    def forward(self, pos_indices: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Computes the forward pass of the neural net

        Args:
            pos_indices (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
            sentences (List[List[str]]): A list of the token-split sentences of the input data.
            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence.

        Returns:
            torch.tensor: Output logits of the neural network, where the shape is  (n, output_size) where n is the number of sentences.
        """
        device = next(self.parameters()).device
        batch_size = len(sentences)
        token_ids = []
        delta_token_ids = []
        for words in sentences:
            # pretrain lookup; unknown words fall back to UNK_ID
            # NOTE(review): lookup lowercases the word but vocab_map keys keep
            # the pretrain's original casing — confirm the pretrain vocab is
            # lowercased upstream
            sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words]
            sentence_token_ids = torch.tensor(sentence_token_ids, device=device)
            token_ids.append(sentence_token_ids)

            # delta embedding lookup; 0 doubles as both padding and unknown
            sentence_delta_token_ids = [self.known_word_map.get(word.lower(), 0) for word in words]
            sentence_delta_token_ids = torch.tensor(sentence_delta_token_ids, device=device)
            delta_token_ids.append(sentence_delta_token_ids)

        token_ids = pad_sequence(token_ids, batch_first=True)
        delta_token_ids = pad_sequence(delta_token_ids, batch_first=True)
        embedded = self.embeddings(token_ids) + self.delta_embedding(delta_token_ids)

        if self.upos_emb_dim > 0:
            upos_tags = [torch.tensor(sentence_tags) for sentence_tags in upos_tags] # convert internal lists to tensors
            upos_tags = pad_sequence(upos_tags, batch_first=True, padding_value=0).to(device)
            pos_emb = self.upos_emb(upos_tags)
            embedded = torch.cat((embedded, pos_emb), 2).to(device)

        if self.use_charlm:
            char_reps_forward = self.charmodel_forward.build_char_representation(sentences)  # takes [[str]]
            char_reps_backward = self.charmodel_backward.build_char_representation(sentences)

            char_reps_forward = pad_sequence(char_reps_forward, batch_first=True)
            char_reps_backward = pad_sequence(char_reps_backward, batch_first=True)

            embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2)

        if self.num_heads > 0:

            def positional_encoding(seq_len, d_model, device):
                # standard sinusoidal positional encoding, shape (1, seq_len, d_model)
                encoding = torch.zeros(seq_len, d_model, device=device)
                position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)
                div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)

                encoding[:, 0::2] = torch.sin(position * div_term)
                encoding[:, 1::2] = torch.cos(position * div_term)

                # Add a new dimension to fit the batch size
                encoding = encoding.unsqueeze(0)
                return encoding

            seq_len, d_model = embedded.shape[1], embedded.shape[2]
            pos_enc = positional_encoding(seq_len, d_model, device=device)

            # in-place add of the positional encoding to the (dense) embeddings
            embedded += pos_enc.expand_as(embedded)

        # NOTE(review): 'embedded' is already a padded 3-D tensor, so this
        # pad_sequence is effectively a no-op and 'lengths' are the padded
        # (all-equal) lengths, not the true sentence lengths — the pos_indices
        # gather below is still safe, but packing gains nothing; confirm intended
        padded_sequences = pad_sequence(embedded, batch_first=True)
        lengths = torch.tensor([len(seq) for seq in embedded])

        if self.num_heads > 0:
            target_seq_length, src_seq_length = padded_sequences.size(1), padded_sequences.size(1)
            # causal (upper-triangular) mask, replicated per head
            attn_mask = torch.triu(torch.ones(batch_size * self.num_heads, target_seq_length, src_seq_length, dtype=torch.bool), diagonal=1)

            attn_mask = attn_mask.view(batch_size, self.num_heads, target_seq_length, src_seq_length)
            attn_mask = attn_mask.repeat(1, 1, 1, 1).view(batch_size * self.num_heads, target_seq_length, src_seq_length).to(device)

            attn_output, attn_weights = self.multihead_attn(padded_sequences, padded_sequences, padded_sequences, attn_mask=attn_mask)
            # Extract the hidden state at the index of the token to classify
            token_reps = attn_output[torch.arange(attn_output.size(0)), pos_indices]

        else:
            packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True)
            lstm_out, (hidden, _) = self.lstm(packed_sequences)
            # Extract the hidden state at the index of the token to classify
            unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True)
            token_reps = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices]

        # MLP forward pass
        output = self.mlp(token_reps)
        return output

    def model_type(self):
        return ModelType.LSTM
stanza/stanza/models/mwt/data.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import os
4
+ from collections import Counter, namedtuple
5
+ import logging
6
+
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from torch.utils.data import DataLoader as DL
10
+
11
+ import stanza.models.common.seq2seq_constant as constant
12
+ from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
13
+ from stanza.models.common.vocab import DeltaVocab
14
+ from stanza.models.mwt.vocab import Vocab
15
+ from stanza.models.common.doc import Document
16
+
17
+ logger = logging.getLogger('stanza')
18
+
19
# lightweight containers for one MWT example and one collated batch
DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text")
DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx")

# enforce that the MWT splitter knows about a couple different alternate apostrophes
# including covering some potential " typos
# setting the augmentation to a very low value should be enough to teach it
# about the unknown characters without messing up the predictions for other text
#
# 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07
# fix: the final entry must be U+FF07 FULLWIDTH APOSTROPHE per the codepoint
# list above; it had been garbled into a plain ASCII apostrophe, which both
# duplicated entry 2 and broke the tuple literal
APOS = ('"', "'", '\u02bc', '\u02ca', '\u055a', '\u07f4', '\u2019', '\uff07')
29
+
30
class DataLoader:
    """Loads MWT expansion examples from a Document and prepares seq2seq data.

    Each raw example is (token text, expansion text); process() maps both
    sides to character-id sequences.  to_loader() wraps this object in a
    torch DataLoader using the custom collate function below.
    """
    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False):
        self.batch_size = batch_size
        self.args = args
        # probability of swapping one apostrophe variant for another (train only)
        self.augment_apos = args.get('augment_apos', 0.0)
        self.evaluation = evaluation
        self.doc = doc

        data = self.load_doc(self.doc, evaluation=self.evaluation)

        # handle vocab
        if vocab is None:
            assert self.evaluation == False # for eval vocab must exist
            self.vocab = self.init_vocab(data)
            # if any apostrophe variant occurs in the data, make sure the
            # vocab knows all of them so augmentation can't produce UNKs
            if self.augment_apos > 0 and any(x in self.vocab for x in APOS):
                for apos in APOS:
                    self.vocab.add_unit(apos)
        elif expand_unk_vocab:
            # extend the existing vocab with unseen characters from this data
            self.vocab = DeltaVocab(data, vocab)
        else:
            self.vocab = vocab

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.evaluation:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        # shuffle for training
        if not self.evaluation:
            indices = list(range(len(data)))
            random.shuffle(indices)
            data = [data[i] for i in indices]

        self.data = data
        self.num_examples = len(data)

    def init_vocab(self, data):
        """Build a character vocab from the training data."""
        assert self.evaluation == False # for eval vocab must exist
        vocab = Vocab(data, self.args['shorthand'])
        return vocab

    def maybe_augment_apos(self, datum):
        """With probability augment_apos, replace one apostrophe variant
        (the first one found) with a randomly chosen variant on both sides."""
        for original in APOS:
            if original in datum[0]:
                if random.uniform(0,1) < self.augment_apos:
                    replacement = random.choice(APOS)
                    datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement))
                break
        return datum

    def process(self, sample):
        """Turn one raw (token, expansion) pair into id sequences.

        Returns [src ids, tgt_in ids, tgt_out ids, original token text].
        """
        if not self.evaluation and self.augment_apos > 0:
            sample = self.maybe_augment_apos(sample)
        src = list(sample[0])
        src = [constant.SOS] + src + [constant.EOS]
        tgt_in, tgt_out = self.prepare_target(self.vocab, sample)
        src = self.vocab.map(src)
        processed = [src, tgt_in, tgt_out, sample[0]]
        return processed

    def prepare_target(self, vocab, datum):
        """Build teacher-forcing target sequences (overridden in subclasses)."""
        if self.evaluation:
            tgt = list(datum[0]) # as a placeholder
        else:
            tgt = list(datum[1])
        tgt_in = vocab.map([constant.SOS] + tgt)
        tgt_out = vocab.map(tgt + [constant.EOS])
        return tgt_in, tgt_out

    def __len__(self):
        # number of individual examples (batching happens in __collate_fn)
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        sample = self.data[key]
        sample = self.process(sample)
        assert len(sample) == 4

        src = torch.tensor(sample[0])
        tgt_in = torch.tensor(sample[1])
        tgt_out = torch.tensor(sample[2])
        orig_text = sample[3]
        # the original index rides along so the collate fn can unsort later
        result = DataSample(src, tgt_in, tgt_out, orig_text), key
        return result

    @staticmethod
    def __collate_fn(data):
        """Collate a list of (DataSample, index) pairs into one DataBatch."""
        (data, idx) = zip(*data)
        (src, tgt_in, tgt_out, orig_text) = zip(*data)

        # collate_fn is given a list of length batch size
        batch_size = len(data)

        # need to sort by length of src to properly handle
        # the batching in the model itself
        lens = [len(x) for x in src]
        (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens)
        lens = [len(x) for x in src]

        # convert to tensors
        src = pad_sequence(src, True, constant.PAD_ID)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID)
        tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID)
        assert tgt_in.size(1) == tgt_out.size(1), \
                "Target input and output sequence sizes do not match."
        return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx)

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def to_loader(self):
        """Converts self to a DataLoader """

        batch_size = self.batch_size
        shuffle = not self.evaluation
        return DL(self,
                  collate_fn=self.__collate_fn,
                  batch_size=batch_size,
                  shuffle=shuffle)

    def load_doc(self, doc, evaluation=False):
        """Pull MWT expansion pairs from the Document.

        At evaluation time only the token texts are available, so each is
        wrapped in a singleton list to keep the (token, expansion) shape.
        """
        data = doc.get_mwt_expansions(evaluation)
        if evaluation: data = [[e] for e in data]
        return data
161
+
162
class BinaryDataLoader(DataLoader):
    """
    This version of the DataLoader performs the same tasks as the regular DataLoader,
    except the targets are arrays of 0/1 indicating if the character is the location
    of an MWT split
    """
    def prepare_target(self, vocab, datum):
        # at eval time only the raw token is available; at train time the
        # space-separated expansion marks where the splits go
        source = datum[0] if self.evaluation else datum[1]
        labels = [0]
        pending_space = False
        for ch in source:
            if ch == ' ':
                # remember the split; the space itself emits no label
                pending_space = True
                continue
            labels.append(1 if pending_space else 0)
            pending_space = False
        labels.append(0)
        # input and output targets are the same 0/1 sequence
        return labels, labels
182
+
stanza/stanza/models/mwt/scorer.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utils and wrappers for scoring MWT
3
+ """
4
+ from stanza.models.common.utils import ud_scores
5
+
6
def score(system_conllu_file, gold_conllu_file):
    """Score MWT output with the UD evaluation script.

    Returns (precision, recall, f1) for the "Words" metric.
    """
    results = ud_scores(gold_conllu_file, system_conllu_file)
    words = results["Words"]
    return words.precision, words.recall, words.f1
12
+
stanza/stanza/models/mwt/utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import stanza
2
+
3
+ from stanza.models.common import doc
4
+ from stanza.models.tokenization.data import TokenizationDataset
5
+ from stanza.models.tokenization.utils import predict, decode_predictions
6
+
7
def mwts_composed_of_words(doc):
    """
    Return True if every MWT in the doc is exactly the concatenation of its words' text.

    Tokens with a single word are trivially consistent and are skipped.
    (The previous version enumerated sentences and tokens but never used
    the indices; the loops now iterate the items directly.)
    """
    for sentence in doc.sentences:
        for token in sentence.tokens:
            if len(token.words) > 1:
                expected = "".join(word.text for word in token.words)
                if token.text != expected:
                    return False
    return True
18
+
19
+
20
def resplit_mwt(tokens, pipeline, keep_tokens=True):
    """
    Uses the tokenize processor and the mwt processor in the pipeline to resplit tokens into MWT

    tokens: a list of list of string
    pipeline: a Stanza pipeline which contains, at a minimum, tokenize and mwt

    keep_tokens: if True, enforce the old token boundaries by modifying
      the results of the tokenize inference.
      Otherwise, use whatever new boundaries the model comes up with.

    between running the tokenize model and breaking the text into tokens,
    we can update all_preds to use the original token boundaries
    (if and only if keep_tokens == True)

    This method returns a Document with just the tokens and words annotated.
    """
    if "tokenize" not in pipeline.processors:
        raise ValueError("Need a Pipeline with a valid tokenize processor")
    if "mwt" not in pipeline.processors:
        raise ValueError("Need a Pipeline with a valid mwt processor")
    tokenize_processor = pipeline.processors["tokenize"]
    mwt_processor = pipeline.processors["mwt"]
    # one sentence per paragraph, tokens separated by single spaces,
    # so character offsets into the fake text are predictable
    fake_text = "\n\n".join(" ".join(sentence) for sentence in tokens)

    # set up batches
    batches = TokenizationDataset(tokenize_processor.config,
                                  input_text=fake_text,
                                  vocab=tokenize_processor.vocab,
                                  evaluation=True,
                                  dictionary=tokenize_processor.trainer.dictionary)

    all_preds, all_raw = predict(trainer=tokenize_processor.trainer,
                                 data_generator=batches,
                                 batch_size=tokenize_processor.trainer.args['batch_size'],
                                 max_seqlen=tokenize_processor.config.get('max_seqlen', tokenize_processor.MAX_SEQ_LENGTH_DEFAULT),
                                 use_regex_tokens=True,
                                 num_workers=tokenize_processor.config.get('num_workers', 0))

    if keep_tokens:
        # force the tokenizer predictions back onto the original token
        # boundaries: no split labels inside a word, and a token-end
        # label at the last character of each word (unless the model
        # already predicted something nonzero there, e.g. sentence end)
        # NOTE(review): pred appears to support scalar slice assignment,
        # so it is presumably a numpy array — confirm with predict()
        for sentence, pred in zip(tokens, all_preds):
            char_idx = 0
            for word in sentence:
                if len(word) > 0:
                    pred[char_idx:char_idx+len(word)-1] = 0
                    if pred[char_idx+len(word)-1] == 0:
                        pred[char_idx+len(word)-1] = 1
                # +1 skips the separating space in fake_text
                char_idx += len(word) + 1

    _, _, document = decode_predictions(vocab=tokenize_processor.vocab,
                                        mwt_dict=None,
                                        orig_text=fake_text,
                                        all_raw=all_raw,
                                        all_preds=all_preds,
                                        no_ssplit=True,
                                        skip_newline=tokenize_processor.trainer.args['skip_newline'],
                                        use_la_ittb_shorthand=tokenize_processor.trainer.args['shorthand'] == 'la_ittb')

    # wrap the raw annotations, then let the MWT processor split the tokens
    document = doc.Document(document, fake_text)
    mwt_processor.process(document)
    return document
81
+
82
def main():
    """Small demo: resplit two token lists, once keeping the original boundaries and once not."""
    pipeline = stanza.Pipeline("en", processors="tokenize,mwt", package="gum")
    sample_tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]]
    for keep in (True, False):
        result = resplit_mwt(sample_tokens, pipeline, keep_tokens=keep)
        print(result)
90
+
91
+ if __name__ == '__main__':
92
+ main()
stanza/stanza/models/ner/__init__.py ADDED
File without changes
stanza/stanza/models/ner/trainer.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A trainer class to handle training and testing of models.
3
+ """
4
+
5
+ import sys
6
+ import logging
7
+ import torch
8
+ from torch import nn
9
+
10
+ from stanza.models.common.foundation_cache import NoTransformerFoundationCache, load_bert, load_bert_with_peft
11
+ from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
12
+ from stanza.models.common.trainer import Trainer as BaseTrainer
13
+ from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE
14
+ from stanza.models.common import utils, loss
15
+ from stanza.models.ner.model import NERTagger
16
+ from stanza.models.ner.vocab import MultiVocab
17
+ from stanza.models.common.crf import viterbi_decode
18
+
19
+
20
+ logger = logging.getLogger('stanza')
21
+
22
def unpack_batch(batch, device):
    """Unpack a batch from the data loader.

    batch[0] (raw words) stays as-is; batch[1:5] are tensors (or None)
    which are moved onto `device`; the remaining entries are sorting
    indices and length bookkeeping, returned unchanged.
    """
    inputs = [batch[0]]
    for tensor in batch[1:5]:
        inputs.append(tensor.to(device) if tensor is not None else None)
    (orig_idx, word_orig_idx, char_orig_idx,
     sentlens, wordlens, charlens, charoffsets) = batch[5:12]
    return inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets
34
+
35
def fix_singleton_tags(tags):
    """
    Normalize a BIOES tag sequence so singleton entities carry S- tags.

    Pass 1 promotes I- tags with no compatible neighbor to B- / E-;
    pass 2 then turns any B- or E- tag left without a continuation
    (including ones created by pass 1) into S-.
    """
    fixed = list(tags)
    last = len(fixed) - 1

    # pass 1: repair I- tags stranded at an entity boundary
    for idx, tag in enumerate(fixed):
        if not tag.startswith("I-"):
            continue
        ent = tag[2:]
        if idx == last or fixed[idx+1] not in ("I-" + ent, "E-" + ent):
            fixed[idx] = "E-" + ent
        if idx == 0 or fixed[idx-1] not in ("B-" + ent, "I-" + ent):
            fixed[idx] = "B-" + ent

    # pass 2: collapse unmatched B- / E- tags into singletons
    for idx, tag in enumerate(fixed):
        ent = tag[2:]
        if tag.startswith("B-") and (idx == last or fixed[idx+1] not in ("I-" + ent, "E-" + ent)):
            fixed[idx] = "S-" + ent
        if tag.startswith("E-") and (idx == 0 or fixed[idx-1] not in ("B-" + ent, "I-" + ent)):
            fixed[idx] = "S-" + ent
    return fixed
62
+
63
class Trainer(BaseTrainer):
    """ A trainer for training NER models.

    Either restores a complete trainer from a saved model file, or builds
    a fresh NERTagger from args + vocab + pretrained embeddings.
    """
    def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None,
                 train_classifier_only=False, foundation_cache=None, second_optim=False):
        if model_file is not None:
            # load everything from file
            self.load(model_file, pretrain, args, foundation_cache)
        else:
            # building from scratch requires all three of these
            assert all(var is not None for var in [args, vocab, pretrain])
            # build model from scratch
            self.args = args
            self.vocab = vocab
            bert_model, bert_tokenizer = load_bert(self.args['bert_model'])
            peft_name = None
            if self.args['use_peft']:
                # fine tune the bert if we're using peft
                self.args['bert_finetune'] = True
                peft_name = "ner"
                # peft the lovely model
                bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name)

            self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name)

            # IMPORTANT: gradient checkpointing BREAKS peft if applied before
            # 1. Apply PEFT FIRST (looksie! it's above this line)
            # 2. Run gradient checkpointing
            # https://github.com/huggingface/peft/issues/742
            if self.args.get("gradient_checkpointing", False) and self.args.get("bert_finetune", False):
                self.model.bert_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})


        # if this wasn't set anywhere, we use a default of the 0th tagset
        # we don't set this as a default in the options so that
        # we can distinguish "intentionally set to 0" and "not set at all"
        if self.args.get('predict_tagset', None) is None:
            self.args['predict_tagset'] = 0

        if train_classifier_only:
            logger.info('Disabling gradient for non-classifier layers')
            exclude = ['tag_clf', 'crit']
            for pname, p in self.model.named_parameters():
                if pname.split('.')[0] not in exclude:
                    p.requires_grad = False
        self.model = self.model.to(device)
        # second_optim selects the second-stage optimizer/learning rate
        # used after the first training phase plateaus
        if not second_optim:
            self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("use_peft"))
        else:
            self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model, self.args['second_lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), is_peft=self.args.get("use_peft"))

    def update(self, batch, eval=False):
        """Run one forward pass; when eval is False, also backprop and step.

        Returns the scalar loss value for the batch.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device)
        word, wordchars, wordchars_mask, chars, tags = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()
        loss, _, _ = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)
        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        # clip to keep the LSTM/CRF gradients stable
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True):
        """Viterbi-decode tag sequences for a batch.

        Returns a list of tag lists, unsorted back to the caller's
        original sentence order unless unsort=False.
        """
        device = next(self.model.parameters()).device
        inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device)
        word, wordchars, wordchars_mask, chars, tags = inputs

        self.model.eval()
        #batch_size = word.size(0)
        _, logits, trans = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)

        # decode
        # TODO: might need to decode multiple columns of output for
        # models with multiple layers
        trans = [x.data.cpu().numpy() for x in trans]
        logits = [x.data.cpu().numpy() for x in logits]
        batch_size = logits[0].shape[0]
        if any(x.shape[0] != batch_size for x in logits):
            raise AssertionError("Expected all of the logits to have the same size")
        tag_seqs = []
        predict_tagset = self.args['predict_tagset']
        for i in range(batch_size):
            # for each tag column in the output, decode the tag assignments
            tags = [viterbi_decode(x[i, :sentlens[i]], y)[0] for x, y in zip(logits, trans)]
            # TODO: this is to patch that the model can sometimes predict < "O"
            tags = [[x if x >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for x in y] for y in tags]
            # that gives us N lists of |sent| tags, whereas we want |sent| lists of N tags
            tags = list(zip(*tags))
            # now unmap that to the tags in the vocab
            tags = self.vocab['tag'].unmap(tags)
            # for now, allow either TagVocab or CompositeVocab
            # TODO: we might want to return all of the predictions
            # rather than a single column
            tags = [x[predict_tagset] if isinstance(x, list) else x for x in tags]
            tags = fix_singleton_tags(tags)
            tag_seqs += [tags]

        if unsort:
            tag_seqs = utils.unsort(tag_seqs, orig_idx)
        return tag_seqs

    def save(self, filename, skip_modules=True):
        """Serialize model weights, vocab, and config to filename.

        Failures other than interrupts are logged and swallowed so a bad
        disk does not kill a long training run.
        """
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
            for k in skipped:
                del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }

        if self.args["use_peft"]:
            # the lora weights are saved separately from the base transformer
            from peft import get_peft_model_state_dict
            params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name)
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename, pretrain=None, args=None, foundation_cache=None):
        """Restore a trainer state saved by save().

        args (if given) override the saved config, except for a few keys
        which fall back to the trained values when not explicitly set.
        """
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args: self.args.update(args)
        # if predict_tagset was not explicitly set in the args,
        # we use the value the model was trained with
        for keep_arg in ('predict_tagset', 'train_scheme', 'scheme'):
            if self.args.get(keep_arg, None) is None:
                self.args[keep_arg] = checkpoint['config'].get(keep_arg, None)

        lora_weights = checkpoint.get('bert_lora')
        if lora_weights:
            logger.debug("Found peft weights for NER; loading a peft adapter")
            self.args["use_peft"] = True

        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])

        emb_matrix=None
        if pretrain is not None:
            emb_matrix = pretrain.emb

        force_bert_saved = False
        peft_name = None
        if self.args.get('use_peft', False):
            force_bert_saved = True
            bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "ner", foundation_cache)
            bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name)
            logger.debug("Loaded peft with name %s", peft_name)
        else:
            if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()):
                logger.debug("Model %s has a finetuned transformer.  Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
                foundation_cache = NoTransformerFoundationCache(foundation_cache)
                force_bert_saved = True
            bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache)

        # older checkpoints used a single classifier/crit; rename those
        # weights to the new multi-column layout
        if any(x.startswith("crit.") for x in checkpoint['model'].keys()):
            logger.debug("Old model format detected.  Updating to the new format with one column of tags")
            checkpoint['model']['crits.0._transitions'] = checkpoint['model'].pop('crit._transitions')
            checkpoint['model']['tag_clfs.0.weight'] = checkpoint['model'].pop('tag_clf.weight')
            checkpoint['model']['tag_clfs.0.bias'] = checkpoint['model'].pop('tag_clf.bias')
        self.model = NERTagger(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        # there is a possible issue with the delta embeddings.
        # specifically, with older models trained without the delta
        # embedding matrix
        # if those models have been trained with the embedding
        # modifications saved as part of the base embedding,
        # we need to resave the model with the updated embedding
        # otherwise the resulting model will be broken
        if 'delta' not in self.model.vocab and 'word_emb.weight' in checkpoint['model'].keys() and 'word_emb' in self.model.unsaved_modules:
            logger.debug("Removing word_emb from unsaved_modules so that resaving %s will keep the saved embedding", filename)
            self.model.unsaved_modules.remove('word_emb')

    def get_known_tags(self):
        """
        Return the tags known by this model

        Removes the S-, B-, etc, and does not include O
        """
        tags = set()
        for tag in self.vocab['tag'].items(0):
            if tag in VOCAB_PREFIX:
                continue
            if tag == 'O':
                continue
            if len(tag) > 2 and tag[:2] in ('S-', 'B-', 'I-', 'E-'):
                tag = tag[2:]
            tags.add(tag)
        return sorted(tags)
stanza/stanza/models/tokenization/vocab.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+ import re
3
+
4
+ from stanza.models.common.vocab import BaseVocab
5
+ from stanza.models.common.vocab import UNK, PAD
6
+
7
+ SPACE_RE = re.compile(r'\s')
8
+
9
class Vocab(BaseVocab):
    """Character vocabulary for the tokenizer.

    Units are single characters; ids are assigned by descending frequency
    after the PAD and UNK specials.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # zh/ja/ko do not delimit words with spaces, so spaces inside a
        # token are artifacts and get removed rather than normalized.
        # startswith with a tuple replaces the old any([...]) list comp.
        self.lang_replaces_spaces = self.lang.startswith(('zh', 'ja', 'ko'))

    def build_vocab(self):
        """Count normalized units in self.data and assign frequency-ranked ids."""
        counter = Counter(self.normalize_unit(unit[0])
                          for para in self.data
                          for unit in para)
        # most frequent units get the smallest ids, after the specials;
        # sorted() is stable, so ties keep first-seen order as before
        self._id2unit = [PAD, UNK] + sorted(counter, key=counter.get, reverse=True)
        self._unit2id = {w: i for i, w in enumerate(self._id2unit)}

    def normalize_unit(self, unit):
        # Normalize minimal units used by the tokenizer
        # (currently the identity; hook point for subclasses)
        return unit

    def normalize_token(self, token):
        """Collapse all whitespace to single spaces and strip the left edge."""
        token = SPACE_RE.sub(' ', token.lstrip())

        if self.lang_replaces_spaces:
            token = token.replace(' ', '')

        return token
stanza/stanza/pipeline/demo/README.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Interactive Demo for Stanza
2
+
3
+ ### Requirements
4
+
5
+ stanza, flask
6
+
7
+ ### Run the demo locally
8
+
9
+ 1. Make sure you know how to disable your browser's CORS rule. For Chrome, [this extension](https://mybrowseraddon.com/access-control-allow-origin.html) works pretty well.
10
+ 2. From this directory, start the Stanza demo server
11
+
12
+ ```bash
13
+ export FLASK_APP=demo_server.py
14
+ flask run
15
+ ```
16
+
17
+ 3. In `stanza-brat.js`, uncomment the line at the top that declares `serverAddress` and point it to where your flask is serving the demo server (usually `http://localhost:5000`)
18
+
19
+ 4. Open `stanza-brat.html` in your browser (with CORS disabled) and enjoy!
20
+
21
+ ### Common issues
22
+
23
+ Make sure you have the models corresponding to the language you want to test out locally before submitting requests to the server! (Models can be obtained by `import stanza; stanza.download(<language_code>)`.)
stanza/stanza/utils/datasets/constituency/__init__.py ADDED
File without changes
stanza/stanza/utils/datasets/constituency/common_trees.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Look through 2 files, only output the common trees
3
+
4
+ pretty basic - could use some more options
5
+ """
6
+
7
+ import sys
8
+
9
def main():
    """Print only the tree lines which are identical in both input files.

    Usage: common_trees.py file1 file2
    """
    with open(sys.argv[1], encoding="utf-8") as fin:
        first = fin.readlines()
    with open(sys.argv[2], encoding="utf-8") as fin:
        second = fin.readlines()

    # compare position by position; only matching lines survive
    for left, right in zip(first, second):
        if left == right:
            print(left.strip())
20
+
21
+ if __name__ == '__main__':
22
+ main()
23
+
stanza/stanza/utils/datasets/constituency/convert_alt.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Read files of parses and the files which define the train/dev/test splits
3
+
4
+ Write out the files after splitting them
5
+
6
+ Sequence of operations:
7
+ - read the raw lines from the input files
8
+ - read the recommended splits, as per the ALT description page
9
+ - separate the trees using the recommended split files
10
+ - write back the trees
11
+ """
12
+
13
def read_split_file(split_file):
    """
    Read a split file for ALT

    The format of the file is expected to be a list of lines such as
      URL.1234 <url>
    Here, we only care about the id

    return: a set of the integer ids

    Raises ValueError on a line whose first token is not URL.<id>
    """
    with open(split_file, encoding="utf-8") as fin:
        lines = [line.strip() for line in fin]
    ids = set()
    for line in lines:
        if not line:
            continue
        url_id = line.split()[0]
        if not url_id.startswith("URL."):
            # BUG FIX: the old error path referenced `x`, a comprehension
            # variable which does not leak in Python 3, so a bad line
            # raised NameError instead of this ValueError
            raise ValueError("Unexpected line in %s: %s" % (split_file, line))
        ids.add(int(url_id.split(".", 1)[1]))
    return ids
31
+
32
def split_trees(all_lines, splits):
    """
    Assign tree lines to splits by their file id.

    Lines look like
      SNT.17873.4049 (S ...
    and go into the first split whose id set contains <file> from
    SNT.<file>.<sent>.  Raises ValueError for a line matching no split.
    """
    buckets = [[] for _ in splits]
    for line in all_lines:
        tree_id, tree_text = line.split(maxsplit=1)
        file_id = int(tree_id.split(".", 2)[1])
        for bucket, split in zip(buckets, splits):
            if file_id in split:
                bucket.append(tree_text)
                break
        else:
            # couldn't figure out which split to put this in
            raise ValueError("Couldn't find which split this line goes in:\n%s" % line)
    return buckets
51
+
52
def read_alt_lines(input_files):
    """
    Read the tree lines from the given file(s).

    Trees containing ideographic (wide) spaces are dropped, since the
    parse tree handling mishandles them and the tokenizer won't produce
    tokens which are entirely wide spaces anyway.  Trees containing raw
    backslash-x escape residue (bad encoding) are dropped as well.

    The lines are returned as text, not parsed into trees.
    """
    lines = []
    for input_file in input_files:
        with open(input_file, encoding="utf-8") as fin:
            lines.extend(fin.readlines())
    lines = [line.strip() for line in lines]
    lines = [line for line in lines if line]

    # there is 1 tree with wide space as an entire token, and 4 with wide spaces at the end of a token
    before = len(lines)
    lines = [line for line in lines if "\u3000" not in line]
    if len(lines) < before:
        print("Eliminated %d trees for having wide spaces in it" % (before - len(lines)))

    before = len(lines)
    lines = [line for line in lines if "\\x" not in line]
    if len(lines) < before:
        print("Eliminated %d trees for not being correctly encoded" % (before - len(lines)))
    return lines
81
+
82
def convert_alt(input_files, split_files, output_files):
    """
    Convert the ALT treebank into train/dev/test splits

    input_files: paths to read trees
    split_files: recommended splits from the ALT page
    output_files: where to write train/dev/test
    """
    lines = read_alt_lines(input_files)
    splits = [read_split_file(path) for path in split_files]
    sections = split_trees(lines, splits)

    for section, output_file in zip(sections, output_files):
        print("Writing %d trees to %s" % (len(section), output_file))
        with open(output_file, "w", encoding="utf-8") as fout:
            for tree in section:
                # the extra ROOT is because the ALT doesn't have this at the top of its trees
                fout.write("(ROOT {})\n".format(tree))
stanza/stanza/utils/datasets/constituency/convert_arboretum.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parses a Tiger dataset to PTB
3
+
4
+ Also handles problems specific for the Arboretum treebank.
5
+
6
+ - validation errors in the XML:
7
+ -- there is a "&" instead of an "&amp;" early on
8
+ -- there are tags "<{note}>" and "<{parentes-udeladt}>" which may or may not be relevant,
9
+ but are definitely not properly xml encoded
10
+ - trees with stranded nodes. 5 trees have links to words in a different tree.
11
+ those trees are skipped
12
+ - trees with empty nodes. 58 trees have phrase nodes with no leaves.
13
+ those trees are skipped
14
+ - trees with missing words. 134 trees have words in the text which aren't in the tree
15
+ those trees are also skipped
16
+ - trees with categories not in the category directory
17
+ for example, intj... replaced with fcl?
18
+ most of these are replaced with what might be a sensible replacement
19
+ - trees with labels that don't have an obvious replacement
20
+ these trees are eliminated, 4 total
21
+ - underscores in words. those words are split into multiple words
22
+ the tagging is not going to be ideal, but the first step of training
23
+ a parser is usually to retag the words anyway, so this should be okay
24
+ - tree 14729 is really weirdly annotated. skipped
25
+ - 5373 trees total have non-projective constituents. These don't work
26
+ with the stanza parser... in order to work around this, we rearrange
27
+ them when possible.
28
+ ((X Z) Y1 Y2 ...) -> (X Y1 Y2 Z) this rearranges 3021 trees
29
+ ((X Z1 ...) Y1 Y2 ...) -> (X Y1 Y2 Z) this rearranges 403 trees
30
+ ((X Z1 ...) (tag Y1) ...) -> (X (Y1) Z) this rearranges 1258 trees
31
+
32
+ A couple examples of things which get rearranged
33
+ (limited in scope and without the words to avoid breaking our license):
34
+
35
+ (vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7)
36
+ -->
37
+ (vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9))
38
+
39
+ (vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3)
40
+ -->
41
+ (vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4))
42
+
43
+ This process leaves behind 691 trees. In some cases, the
44
+ non-projective structure is at a higher level than the attachment.
45
+ In others, there are nested non-projectivities that are not
46
+ rearranged by the above pattern. A couple examples:
47
+
48
+ here, the 3-7 nonprojectivity has the 7 in a nested structure
49
+ (s
50
+ (par
51
+ (n s206_1)
52
+ (pu s206_2)
53
+ (fcl
54
+ (fcl
55
+ (pron-pers s206_3)
56
+ (fcl (pron-pers s206_7) (adv s206_8) (v-fin s206_9)))
57
+ (vp (v-fin s206_4) (v-inf s206_6))
58
+ (pron-pers s206_5))
59
+ (pu s206_10)))
60
+
61
+ here, 11 is attached at a higher level than 12 & 13
62
+ (s
63
+ (fcl
64
+ (icl
65
+ (np
66
+ (adv s223_1)
67
+ (np
68
+ (n s223_2)
69
+ (pp
70
+ (prp s223_3)
71
+ (par
72
+ (adv s223_4)
73
+ (prop s223_5)
74
+ (pu s223_6)
75
+ (prop s223_7)
76
+ (conj-c s223_8)
77
+ (np (adv s223_9) (prop s223_10))))))
78
+ (vp (infm s223_12) (v-inf s223_13)))
79
+ (v-fin s223_11)
80
+ (pu s223_14)))
81
+
82
+ even if we moved _6 between 2 and 7, we'd then have a completely flat
83
+ structure when moving 3..5 inside
84
+ (s
85
+ (fcl
86
+ (xx s499_1)
87
+ (np
88
+ (pp (pron-pers s499_2) (prp s499_7))
89
+ (n s499_6))
90
+ (v-fin s499_3) (adv s499_4) (adv s499_5) (pu s499_8)))
91
+
92
+ """
93
+
94
+
95
+ from collections import namedtuple
96
+ import io
97
+ import xml.etree.ElementTree as ET
98
+
99
+ from tqdm import tqdm
100
+
101
+ from stanza.models.constituency.parse_tree import Tree
102
+ from stanza.server import tsurgeon
103
+
104
def read_xml_file(input_filename):
    """
    Convert an XML file into a list of trees - each <s> becomes its own object

    The file is not clean XML as a whole, so each <s>...</s> span is
    extracted textually and parsed on its own.  The pseudo-tags
    <{parentes-udeladt}> and <{note}> are stripped first, since they are
    not valid XML.

    Raises ValueError on nested <s> elements or an unparseable sentence;
    the parse error is chained so the root cause stays visible.
    """
    print("Reading {}".format(input_filename))
    with open(input_filename, encoding="utf-8") as fin:
        lines = fin.readlines()

    sentences = []
    current_sentence = []
    in_sentence = False
    for line_idx, line in enumerate(lines):
        if line.startswith("<s "):
            if len(current_sentence) > 0:
                raise ValueError("Found the start of a sentence inside an existing sentence, line {}".format(line_idx))
            in_sentence = True

        if in_sentence:
            current_sentence.append(line)

        if line.startswith("</s>"):
            assert in_sentence
            # drop the non-XML pseudo-tags before parsing
            current_sentence = [x.replace("<{parentes-udeladt}>", "") for x in current_sentence]
            current_sentence = [x.replace("<{note}>", "") for x in current_sentence]
            sentences.append("".join(current_sentence))
            current_sentence = []
            in_sentence = False

    assert len(current_sentence) == 0

    xml_sentences = []
    for sent_idx, text in enumerate(sentences):
        sentence = io.StringIO(text)
        try:
            tree = ET.parse(sentence)
            xml_sentences.append(tree)
        except ET.ParseError as e:
            # BUG FIX: chain the original parse error instead of discarding it
            raise ValueError("Failed to parse sentence {}".format(sent_idx)) from e

    return xml_sentences
144
+
145
# a single leaf token: surface form plus POS tag
Word = namedtuple('Word', ['word', 'tag'])
# an internal phrase node: category label plus list of child ids
Node = namedtuple('Node', ['label', 'children'])

class BrokenLinkError(ValueError):
    """Raised when a tree links to a word or node id which does not exist.

    The previous pass-through __init__ added nothing over ValueError's
    own constructor and has been removed.
    """
151
+
152
def process_nodes(root_id, words, nodes, visited):
    """
    Recursively build a Tree rooted at root_id.

    words: map from id to Word leaves
    nodes: map from id to Node phrase entries
    visited: set of string ids already expanded; mutated in place so
      that cycles in the links raise instead of recursing forever

    Leaf labels are set to the word *ids* rather than the text, so the
    caller can verify the leaves come out in sorted order; the actual
    words are substituted later.
    """
    if root_id in visited:
        raise ValueError("Loop in the tree!")
    visited.add(root_id)

    if root_id in words:
        # leaf: a preterminal tag node over the word id
        word = words[root_id]
        return Tree(label=word.tag, children=Tree(label=root_id))

    if root_id in nodes:
        node = nodes[root_id]
        subtrees = [process_nodes(child_id, words, nodes, visited)
                    for child_id in node.children]
        return Tree(label=node.label, children=subtrees)

    raise BrokenLinkError("Unknown id! {}".format(root_id))
175
+
176
def check_words(tree, tsurgeon_processor):
    """
    Check that the words of a sentence are in order

    If they are not, this applies a tsurgeon to rearrange simple cases
    The tsurgeon looks at the gap between words, eg _3 to _7, and looks
    for the words between, such as _4 _5 _6.  if those words are under
    a node at the same level as the 3-7 node and does not include any
    other nodes (such as _8), that subtree is moved to between _3 and _7

    Example:

    (vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7)
    -->
    (vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9))

    Returns the (possibly rearranged) tree, or None when the
    non-projectivity could not be repaired by this pattern.
    """
    while True:
        # leaves are labeled like s4_6; recover the integer positions
        words = tree.leaf_labels()
        indices = [int(w.split("_", 1)[1]) for w in words]
        for word_idx, word_label in enumerate(indices):
            # word_idx stops at the first out-of-place word
            if word_idx != word_label - 1:
                break
        else:
            # if there are no weird indices, keep the tree
            return tree

        sorted_indices = sorted(indices)
        if indices == sorted_indices:
            raise ValueError("Skipped index! This should already be accounted for {}".format(tree))

        # a misplaced very first word has no left anchor to move after
        if word_idx == 0:
            return None

        prefix = words[0].split("_", 1)[0]
        prev_idx = word_idx - 1
        prev_label = indices[prev_idx]
        # the ids which should sit between the previous word and this one
        missing_words = ["%s_%d" % (prefix, x) for x in range(prev_label + 1, word_label)]
        missing_words = "|".join(missing_words)
        #move_tregex = "%s > (__=home > (__=parent > __=grandparent)) . (%s > (__=move > =grandparent))" % (words[word_idx], "|".join(missing_words))
        # match only when a sibling subtree contains exactly the missing span
        move_tregex = "%s > (__=home > (__=parent << %s $+ (__=move <<, %s <<- %s)))" % (words[word_idx], words[prev_idx], missing_words, missing_words)
        move_tsurgeon = "move move $+ home"
        modified = tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0]
        if modified == tree:
            # this only happens if the desired fix didn't happen
            #print("Failed to process:\n  {}\n  {} {}".format(tree, prev_label, word_label))
            return None

        # loop again: one move may expose (or fix) further disorder
        tree = modified
224
+
225
def replace_words(tree, words):
    """
    Swap each leaf label for the surface form stored in the words map.

    The leaves currently hold word ids; words maps id -> Word record.
    """
    replacements = [words[leaf].word for leaf in tree.leaf_labels()]
    return tree.replace_words(replacements)
233
+
234
def process_tree(sentence):
    """
    Convert a single ET element representing a Tiger tree to a parse tree

    Returns (tree, words), where words maps word id -> Word so that the
    ids left in the tree's leaves can later be swapped for surface forms.

    Raises ValueError for structural problems and (via process_nodes)
    BrokenLinkError when a node references an unknown id.
    """
    sentence = sentence.getroot()
    sent_id = sentence.get("id")
    if sent_id is None:
        # bugfix: this message previously interpolated sent_idx, a variable
        # which does not exist in this function, so the check raised
        # NameError instead of the intended ValueError
        raise ValueError("Tree does not have an id")
    if len(sentence) > 1:
        raise ValueError("Longer than expected number of items in {}".format(sent_id))
    graph = sentence.find("graph")
    # NOTE(review): "not graph" is also true for a <graph> element with no
    # children, in which case this error message is misleading — confirm
    # whether empty graphs occur in the data
    if not graph:
        raise ValueError("Unexpected tree structure in {} : top tag is not 'graph'".format(sent_id))

    root_id = graph.get("root")
    if not root_id:
        raise ValueError("Tree has no root id in {}".format(sent_id))

    terminals = graph.find("terminals")
    if not terminals:
        raise ValueError("No terminals in tree {}".format(sent_id))
    # some Arboretum graphs have two sets of nonterminals,
    # apparently intentionally, so we ignore that possible error
    nonterminals = graph.find("nonterminals")
    if not nonterminals:
        raise ValueError("No nonterminals in tree {}".format(sent_id))

    # read the words. the words have ids, text, and tags which we care about
    words = {}
    for word in terminals:
        if word.tag == 'parentes-udeladt' or word.tag == 'note':
            continue
        if word.tag != "t":
            raise ValueError("Unexpected tree structure in {} : word with tag other than t".format(sent_id))
        word_id = word.get("id")
        if not word_id:
            raise ValueError("Word had no id in {}".format(sent_id))
        word_text = word.get("word")
        if not word_text:
            raise ValueError("Word had no text in {}".format(sent_id))
        word_pos = word.get("pos")
        if not word_pos:
            raise ValueError("Word had no pos in {}".format(sent_id))
        words[word_id] = Word(word_text, word_pos)

    # read the nodes. the nodes have ids, labels, and children
    # some of the edges are labeled "secedge". we ignore those
    nodes = {}
    for nt in nonterminals:
        if nt.tag != "nt":
            raise ValueError("Unexpected tree structure in {} : node with tag other than nt".format(sent_id))
        nt_id = nt.get("id")
        if not nt_id:
            raise ValueError("NT has no id in {}".format(sent_id))
        nt_label = nt.get("cat")
        if not nt_label:
            raise ValueError("NT has no label in {}".format(sent_id))

        children = []
        for child in nt:
            if child.tag != "edge" and child.tag != "secedge":
                raise ValueError("NT has unexpected child in {} : {}".format(sent_id, child.tag))
            if child.tag == "edge":
                child_id = child.get("idref")
                if not child_id:
                    raise ValueError("Child is missing an id in {}".format(sent_id))
                children.append(child_id)
        nodes[nt_id] = Node(nt_label, children)

    if root_id not in nodes:
        raise ValueError("Could not find root in nodes in {}".format(sent_id))

    tree = process_nodes(root_id, words, nodes, set())
    return tree, words
308
+
309
def word_sequence_missing_words(tree):
    """
    Check if the word sequence is missing words

    Some trees skip labels, such as
    (s (fcl (pron-pers s16817_1) (v-fin s16817_2) (prp s16817_3) (pp (prp s16817_5) (par (n s16817_6) (conj-c s16817_7) (n s16817_8))) (pu s16817_9)))
    but in these cases, the word is present in the original text and simply not attached to the tree
    """
    # after sorting, a complete sentence has word positions 1..N exactly
    positions = sorted(int(label.split("_")[1]) for label in tree.leaf_labels())
    return any(expected != found for expected, found in enumerate(positions, start=1))
324
+
325
# Phrase label to use when split_underscores breaks a multiword token
# into several leaves.  Keyed by the token's POS tag with any "-" subtype
# stripped (eg "v-fin" looks up "v").
WORD_TO_PHRASE = {
    "art": "advp",  # "en smule" is the one time this happens. it is used as an advp elsewhere
    "adj": "adjp",
    "adv": "advp",
    "conj": "cp",
    "intj": "fcl",  # not sure? seems to match "hold kæft" when it shows up
    "n": "np",
    "num": "np",  # would prefer something like QP from PTB
    "pron": "np",  # ??
    "prop": "np",
    "prp": "pp",
    "v": "vp",
}
338
+
339
def split_underscores(tree):
    """
    Split multiword leaves joined with '_' into multiple preterminals.

    A preterminal over "a_b" is replaced by a new phrase node (chosen via
    WORD_TO_PHRASE from the word's POS) with one preterminal per piece,
    each reusing the original POS tag.  Raises ValueError for tags with
    no phrase mapping and for empty pieces (leading/trailing underscore).
    """
    assert not tree.is_leaf(), "Should never reach a leaf in this code path"

    if tree.is_preterminal():
        # a bare preterminal root needs no rebuilding
        return tree

    children = tree.children
    new_children = []
    for child in children:
        if child.is_preterminal():
            if '_' not in child.children[0].label:
                # ordinary single word: keep the child as-is
                new_children.append(child)
                continue

            # strip the "-" subtype off the POS before the phrase lookup
            if child.label.split("-")[0] not in WORD_TO_PHRASE:
                raise ValueError("SPLITTING {}".format(child))
            pieces = []
            for piece in child.children[0].label.split("_"):
                # This may not be accurate, but we already retag the treebank anyway
                if len(piece) == 0:
                    raise ValueError("A word started or ended with _")
                pieces.append(Tree(child.label, Tree(piece)))
            new_children.append(Tree(WORD_TO_PHRASE[child.label.split("-")[0]], pieces))
        else:
            # internal node: recurse looking for more splittable leaves
            new_children.append(split_underscores(child))

    return Tree(tree.label, new_children)
366
+
367
# Final pass over constituent labels: word-level tags which show up as
# phrase labels are mapped onto real phrase labels.
REMAP_LABELS = {
    "adj": "adjp",
    "adv": "advp",
    "intj": "fcl",
    "n": "np",
    "num": "np",  # again, a dedicated number node would be better, but there are only a few "num" labeled
    "prp": "pp",
}
375
+
376
+
377
def has_weird_constituents(tree):
    """
    Filter out trees with constituent labels we cannot interpret.

    Only a handful of trees use "p" (which has inconsistent structure
    underneath) or "cl" (meaning unclear), so those trees are dropped.
    """
    unwanted = {"p", "cl"}
    return any(label in unwanted for label in Tree.get_unique_constituent_labels(tree))
388
+
389
def convert_tiger_treebank(input_filename):
    """
    Read a Tiger XML file and convert the usable sentences to Trees.

    Sentences which cannot be converted (empty nodes, unattached words,
    unfixable word order, confusing labels, broken links) are counted and
    skipped; the counts are printed at the end along with the total.
    """
    sentences = read_xml_file(input_filename)

    unfixable = 0
    dangling = 0
    broken_links = 0
    missing_words = 0
    weird_constituents = 0
    trees = []

    with tsurgeon.Tsurgeon() as tsurgeon_processor:
        for sent_idx, sentence in enumerate(tqdm(sentences)):
            try:
                tree, words = process_tree(sentence)

                # trees with internal nodes that never reach a word
                if not tree.all_leaves_are_preterminals():
                    dangling += 1
                    continue

                # trees whose original sentence had words never attached
                if word_sequence_missing_words(tree):
                    missing_words += 1
                    continue

                # try to repair out-of-order words; None means unfixable
                tree = check_words(tree, tsurgeon_processor)
                if tree is None:
                    unfixable += 1
                    continue

                if has_weird_constituents(tree):
                    weird_constituents += 1
                    continue

                # put the surface text back in the leaves, then normalize
                tree = replace_words(tree, words)
                tree = split_underscores(tree)
                tree = tree.remap_constituent_labels(REMAP_LABELS)
                trees.append(tree)
            except BrokenLinkError as e:
                # the get("id") would have failed as a different error type if missing,
                # so we can safely use it directly like this
                broken_links += 1
                # print("Unable to process {} because of broken links: {}".format(sentence.getroot().get("id"), e))

    print("Found {} trees with empty nodes".format(dangling))
    print("Found {} trees with unattached words".format(missing_words))
    print("Found {} trees with confusing constituent labels".format(weird_constituents))
    print("Not able to rearrange {} nodes".format(unfixable))
    print("Unable to handle {} trees because of broken links, eg names in another tree".format(broken_links))
    print("Parsed {} trees from {}".format(len(trees), input_filename))
    return trees
438
+
439
def main():
    # TODO(review): the input path is hard coded, and the converted
    # treebank is only counted here — it is never written out
    treebank = convert_tiger_treebank("extern_data/constituency/danish/W0084/arboretum.tiger/arboretum.tiger")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/convert_icepahc.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Currently this doesn't function
3
+
4
+ The goal is simply to demonstrate how to use tsurgeon
5
+ """
6
+
7
+ from stanza.models.constituency.tree_reader import read_trees, read_treebank
8
+ from stanza.server import tsurgeon
9
+
10
+ TREEBANK = """
11
+ ( (IP-MAT (NP-SBJ (PRO-N Það-það))
12
+ (BEPI er-vera)
13
+ (ADVP (ADV eiginlega-eiginlega))
14
+ (ADJP (NEG ekki-ekki) (ADJ-N hægt-hægur))
15
+ (IP-INF (TO að-að) (VB lýsa-lýsa))
16
+ (NP-OB1 (N-D tilfinningu$-tilfinning) (D-D $nni-hinn))
17
+ (IP-INF (TO að-að) (VB fá-fá))
18
+ (IP-INF (TO að-að) (VB taka-taka))
19
+ (NP-OB1 (N-A þátt-þáttur))
20
+ (PP (P í-í)
21
+ (NP (D-D þessu-þessi)))
22
+ (, ,-,)
23
+ (VBPI segir-segja)
24
+ (NP-SBJ (NPR-N Sverrir-sverrir) (NPR-N Ingi-ingi))
25
+ (. .-.)))
26
+ """
27
+
28
+ # Output of the first tsurgeon:
29
+ #(ROOT
30
+ # (IP-MAT
31
+ # (NP-SBJ (PRO-N Það))
32
+ # (BEPI er)
33
+ # (ADVP (ADV eiginlega))
34
+ # (ADJP (NEG ekki) (ADJ-N hægt))
35
+ # (IP-INF (TO að) (VB lýsa))
36
+ # (NP-OB1 (N-D tilfinningu$) (D-D $nni))
37
+ # (IP-INF (TO að) (VB fá))
38
+ # (IP-INF (TO að) (VB taka))
39
+ # (NP-OB1 (N-A þátt))
40
+ # (PP
41
+ # (P í)
42
+ # (NP (D-D þessu)))
43
+ # (, ,)
44
+ # (VBPI segir)
45
+ # (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi))
46
+ # (. .)))
47
+
48
+ # Output of the second operation
49
+ #(ROOT
50
+ # (IP-MAT
51
+ # (NP-SBJ (PRO-N Það))
52
+ # (BEPI er)
53
+ # (ADVP (ADV eiginlega))
54
+ # (ADJP (NEG ekki) (ADJ-N hægt))
55
+ # (IP-INF (TO að) (VB lýsa))
56
+ # (NP-OB1 (N-D tilfinningunni))
57
+ # (IP-INF (TO að) (VB fá))
58
+ # (IP-INF (TO að) (VB taka))
59
+ # (NP-OB1 (N-A þátt))
60
+ # (PP
61
+ # (P í)
62
+ # (NP (D-D þessu)))
63
+ # (, ,)
64
+ # (VBPI segir)
65
+ # (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi))
66
+ # (. .)))
67
+
68
+
69
# parse the demo treebank above into Tree objects
treebank = read_trees(TREEBANK)

# requires the CoreNLP jars to be on the Java classpath
with tsurgeon.Tsurgeon(classpath="$CLASSPATH") as tsurgeon_processor:
    # strip the "-lemma" suffix from each leaf, eg "er-vera" -> "er"
    form_tregex = "/^(.+)-.+$/#1%form=word !< __"
    form_tsurgeon = "relabel word /^.+$/%{form}/"

    # rejoin split noun+determiner pairs marked with "$",
    # eg "tilfinningu$" + "$nni" -> "tilfinningunni", then drop the determiner
    noun_det_tregex = "/^N-/ < /^([^$]+)[$]$/#1%noun=noun $+ (/^D-/ < /^[$]([^$]+)$/#1%det=det)"
    noun_det_relabel = "relabel noun /^.+$/%{noun}%{det}/"
    noun_det_prune = "prune det"

    for tree in treebank:
        updated_tree = tsurgeon_processor.process(tree, (form_tregex, form_tsurgeon))[0]
        print("{:P}".format(updated_tree))
        updated_tree = tsurgeon_processor.process(updated_tree, (noun_det_tregex, noun_det_relabel, noun_det_prune))[0]
        print("{:P}".format(updated_tree))
stanza/stanza/utils/datasets/constituency/convert_it_turin.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Converts Turin's constituency dataset
3
+
4
+ Turin University put out a freely available constituency dataset in 2011.
5
+ It is not as large as VIT or ISST, but it is free, which is nice.
6
+
7
+ The 2011 parsing task combines trees from several sources:
8
+ http://www.di.unito.it/~tutreeb/evalita-parsingtask-11.html
9
+
10
+ There is another site for Turin treebanks:
11
+ http://www.di.unito.it/~tutreeb/treebanks.html
12
+
13
+ Weirdly, the most recent versions of the Evalita trees are not there.
14
+ The most relevant parts are the ParTUT downloads. As of Sep. 2021:
15
+
16
+ http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/JRCAcquis_It.pen
17
+ http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/UDHR_It.pen
18
+ http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/CC_It.pen
19
+ http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/FB_It.pen
20
+ http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/WIT3_It.pen
21
+
22
+ We can't simply cat all these files together as there are a bunch of
23
+ asterisks as comments and the files may have some duplicates. For
24
+ example, the JRCAcquis piece has many duplicates. Also, some don't
25
+ pass validation for one reason or another.
26
+
27
+ One oddity of these data files is that the MWT are denoted by doubling
28
+ the token. The token is not split as would be expected, though. We try
29
+ to use stanza's MWT tokenizer for IT to split the tokens, with some
30
+ rules added by hand in BIWORD_SPLITS. Two are still unsplit, though...
31
+ """
32
+
33
+ import glob
34
+ import os
35
+ import re
36
+ import sys
37
+
38
+ import stanza
39
+ from stanza.models.constituency import parse_tree
40
+ from stanza.models.constituency import tree_reader
41
+
42
def load_without_asterisks(in_file, encoding='utf-8'):
    """
    Read in_file, blanking out the asterisk "comment" lines.

    Lines containing a run of asterisks are replaced with a bare newline.
    A trailing newline is appended if the file does not end with one, so
    the result can be safely joined and parsed.
    """
    with open(in_file, encoding=encoding) as fin:
        cleaned = ["\n" if "********" in line else line for line in fin]
    if cleaned and not cleaned[-1].endswith("\n"):
        cleaned[-1] = cleaned[-1] + "\n"
    return cleaned
48
+
49
# characters used by simplify_labels to cut raw labels down to their base form
CONSTITUENT_SPLIT = re.compile("[-=#+0-9]")

# JRCA is almost entirely duplicates
# WIT3 follows a different annotation scheme
FILES_TO_ELIMINATE = ["JRCAcquis_It.pen", "WIT3_It.pen"]

# assuming this is a typo
REMAP_NODES = { "Sbar" : "SBAR" }

# PTB-style escaped brackets are mapped back to the literal characters
REMAP_WORDS = { "-LSB-": "[", "-RSB-": "]" }

# these mostly seem to be mistakes
# maybe Vbar and ADVbar should be converted to something else?
NODES_TO_ELIMINATE = ["C", "PHRASP", "PRDT", "Vbar", "parte", "ADVbar"]

# mutated by split_mwe: words it could not split, tracked so each failure
# is only reported once
UNKNOWN_SPLITS = set()
65
+
66
# a map of splits that the tokenizer or MWT doesn't handle well
# note: the original table contained several exact duplicate entries
# (formatasi, formatosi, multiplexarlo, esibirsi, pagarne, recarsi,
# trarne, esserci); each is listed once here, with the same value
BIWORD_SPLITS = {
    "offertogli": ("offerto", "gli"),
    "offertegli": ("offerte", "gli"),
    "formatasi": ("formata", "si"),
    "formatosi": ("formato", "si"),
    "Formatisi": ("Formati", "si"),
    "multiplexarlo": ("multiplexar", "lo"),
    "esibirsi": ("esibir", "si"),
    "pagarne": ("pagar", "ne"),
    "recarsi": ("recar", "si"),
    "trarne": ("trar", "ne"),
    "temerne": ("temer", "ne"),
    "esserci": ("esser", "ci"),
    "esservi": ("esser", "vi"),
    "aprirne": ("aprir", "ne"),
    "farle": ("far", "le"),
    "disporne": ("dispor", "ne"),
    "andargli": ("andar", "gli"),
    "CONSIDERARSI": ("CONSIDERAR", "SI"),
    "conferitegli": ("conferite", "gli"),
    "restituirne": ("restituir", "ne"),
    "col": ("con", "il"),
    "cogli": ("con", "gli"),
    "dirgli": ("dir", "gli"),
    "opporgli": ("oppor", "gli"),
    "eccolo": ("ecco", "lo"),
    "Eccolo": ("Ecco", "lo"),
    "Eccole": ("Ecco", "le"),
    "farci": ("far", "ci"),
    "farli": ("far", "li"),
    "farne": ("far", "ne"),
    "farsi": ("far", "si"),
    "farvi": ("far", "vi"),
    "Connettiti": ("Connetti", "ti"),
    "APPLICARSI": ("APPLICAR", "SI"),
    # This is not always two words, but if it IS two words,
    # it gets split like this
    "assicurati": ("assicura", "ti"),
    "Fatti": ("Fai", "te"),
    "ai": ("a", "i"),
    "Ai": ("A", "i"),
    "AI": ("A", "I"),
    "al": ("a", "il"),
    "Al": ("A", "il"),
    "AL": ("A", "IL"),
    "coi": ("con", "i"),
    "colla": ("con", "la"),
    "colle": ("con", "le"),
    "dal": ("da", "il"),
    "Dal": ("Da", "il"),
    "DAL": ("DA", "IL"),
    "dei": ("di", "i"),
    "Dei": ("Di", "i"),
    "DEI": ("DI", "I"),
    "del": ("di", "il"),
    "Del": ("Di", "il"),
    "DEL": ("DI", "IL"),
    "nei": ("in", "i"),
    "NEI": ("IN", "I"),
    "nel": ("in", "il"),
    "Nel": ("In", "il"),
    "NEL": ("IN", "IL"),
    "pel": ("per", "il"),
    "sui": ("su", "i"),
    "Sui": ("Su", "i"),
    "sul": ("su", "il"),
    "Sul": ("Su", "il"),
    # doubled punctuation: the token genuinely appears twice
    ",": (",", ","),
    ".": (".", "."),
    '"': ('"', '"'),
    '-': ('-', '-'),
    '-LRB-': ('-LRB-', '-LRB-'),
    "garantirne": ("garantir", "ne"),
    "aprirvi": ("aprir", "vi"),
    "esimersi": ("esimer", "si"),
    "opporsi": ("oppor", "si"),
}
151
+
152
# ALL-CAPS compounds joined with "_" can simply be split on the underscore
CAP_BIWORD = re.compile("[A-Z]+_[A-Z]+")
153
+
154
def split_mwe(tree, pipeline):
    """
    Turn the doubled tokens which mark an MWT into their split forms.

    The raw treebank represents a multiword token by repeating the token
    once per subword.  Tripled tokens are split with the stanza MWT
    pipeline; doubled tokens try BIWORD_SPLITS first, then the CAP_BIWORD
    rule, then the pipeline.  Returns the updated tree (or the original
    tree unchanged when nothing needed splitting).  Words which cannot be
    split are reported once, via the module-level UNKNOWN_SPLITS set, and
    left in place.
    """
    words = list(tree.leaf_labels())
    found = False
    # four repeats in a row was never observed; treat it as a data error
    for idx, word in enumerate(words[:-3]):
        if word == words[idx+1] and word == words[idx+2] and word == words[idx+3]:
            raise ValueError("Oh no, 4 consecutive words")

    # three in a row: always delegated to the tokenize/mwt pipeline
    for idx, word in enumerate(words[:-2]):
        if word == words[idx+1] and word == words[idx+2]:
            doc = pipeline(word)
            assert len(doc.sentences) == 1
            if len(doc.sentences[0].words) != 3:
                raise RuntimeError("Word {} not tokenized into 3 parts... thought all 3 part words were handled!".format(word))
            words[idx] = doc.sentences[0].words[0].text
            words[idx+1] = doc.sentences[0].words[1].text
            words[idx+2] = doc.sentences[0].words[2].text
            found = True

    # two in a row: hand-curated table first, then the all-caps rule,
    # then fall back to the pipeline
    for idx, word in enumerate(words[:-1]):
        if word == words[idx+1]:
            if word in BIWORD_SPLITS:
                first_word = BIWORD_SPLITS[word][0]
                second_word = BIWORD_SPLITS[word][1]
            elif CAP_BIWORD.match(word):
                first_word, second_word = word.split("_")
            else:
                doc = pipeline(word)
                assert len(doc.sentences) == 1
                if len(doc.sentences[0].words) == 2:
                    first_word = doc.sentences[0].words[0].text
                    second_word = doc.sentences[0].words[1].text
                else:
                    # give up on this word: report it once, leave it doubled
                    if word not in UNKNOWN_SPLITS:
                        UNKNOWN_SPLITS.add(word)
                        print("Could not figure out how to split {}\n {}\n {}".format(word, " ".join(words), tree))
                    continue

            words[idx] = first_word
            words[idx+1] = second_word
            found = True

    if found:
        tree = tree.replace_words(words)
    return tree
198
+
199
+
200
def load_trees(filename, pipeline):
    """
    Read and clean the trees from one raw treebank file.

    Trees with structural problems are logged and dropped; the survivors
    are normalized: labels simplified, bracket tokens remapped, and MWTs
    split (using the pipeline for tokens not in the hand-written tables).
    Returns the list of usable trees.
    """
    # some of the files are in latin-1 encoding rather than utf-8
    try:
        raw_text = load_without_asterisks(filename, "utf-8")
    except UnicodeDecodeError:
        raw_text = load_without_asterisks(filename, "latin-1")

    # also, some have messed up validation (it will be logged)
    # hence the broken_ok=True argument
    trees = tree_reader.read_trees("".join(raw_text), broken_ok=True)

    filtered_trees = []
    for tree in trees:
        if tree.children[0].label is None:
            print("Skipping a broken tree (missing label) in {}: {}".format(filename, tree))
            continue

        # leaf_labels raises ValueError on trees missing a preterminal layer
        try:
            words = tuple(tree.leaf_labels())
        except ValueError:
            print("Skipping a broken tree (missing preterminal) in {}: {}".format(filename, tree))
            continue

        if any('www.facebook' in pt.label for pt in tree.preterminals()):
            print("Skipping a tree with a weird preterminal label in {}: {}".format(filename, tree))
            continue

        # drop unlabeled material and strip label decorations (eg "NP-1" -> "NP")
        tree = tree.prune_none().simplify_labels(CONSTITUENT_SPLIT)

        if len(tree.children) > 1:
            print("Found a tree with a non-unary root! {}: {}".format(filename, tree))
            continue
        if tree.children[0].is_preterminal():
            print("Found a tree with a single preterminal node! {}: {}".format(filename, tree))
            continue

        # The expectation is that the retagging will handle this anyway
        for pt in tree.preterminals():
            if not pt.label:
                pt.label = "UNK"
                print("Found a tree with a blank preterminal label. Setting it to UNK. {}: {}".format(filename, tree))

        tree = tree.remap_constituent_labels(REMAP_NODES)
        tree = tree.remap_words(REMAP_WORDS)

        tree = split_mwe(tree, pipeline)
        if tree is None:
            continue

        # for/else: weird_label stays bound to a found label, or is reset
        # to None when the loop finishes without a break
        constituents = set(parse_tree.Tree.get_unique_constituent_labels(tree))
        for weird_label in NODES_TO_ELIMINATE:
            if weird_label in constituents:
                break
        else:
            weird_label = None
        if weird_label is not None:
            print("Skipping a tree with a weird label {} in {}: {}".format(weird_label, filename, tree))
            continue

        filtered_trees.append(tree)

    return filtered_trees
262
+
263
def save_trees(out_file, trees):
    """
    Write the trees to out_file in utf-8, one bracketed tree per line.
    """
    print("Saving {} trees to {}".format(len(trees), out_file))
    with open(out_file, "w", encoding="utf-8") as fout:
        fout.writelines(str(tree) + "\n" for tree in trees)
269
+
270
def convert_it_turin(input_path, output_path):
    """
    Convert the Turin/Evalita files under input_path to .mrg data files.

    The Evalita gold test file becomes the test set.  All other usable
    trees, minus duplicates of anything already seen, are split roughly
    9:1 into train and dev and written to output_path.
    """
    # the pipeline is needed by split_mwe to break apart doubled MWT tokens
    pipeline = stanza.Pipeline("it", processors="tokenize, mwt", tokenize_no_ssplit=True)

    os.makedirs(output_path, exist_ok=True)

    evalita_dir = os.path.join(input_path, "evalita")

    evalita_test = os.path.join(evalita_dir, "evalita11_TESTgold_CONPARSE.penn")
    it_test = os.path.join(output_path, "it_turin_test.mrg")
    test_trees = load_trees(evalita_test, pipeline)
    save_trees(it_test, test_trees)

    # remember the test sentences so no duplicate leaks into train/dev
    known_text = set()
    for tree in test_trees:
        words = tuple(tree.leaf_labels())
        assert words not in known_text
        known_text.add(words)

    # (a dead assignment of the train path was removed here; the path is
    # built below as it_train when the train set is actually written)
    evalita_files = glob.glob(os.path.join(evalita_dir, "*2011*penn"))
    turin_files = glob.glob(os.path.join(input_path, "turin", "*pen"))
    filenames = evalita_files + turin_files
    filtered_trees = []
    for filename in filenames:
        if os.path.split(filename)[1] in FILES_TO_ELIMINATE:
            continue

        trees = load_trees(filename, pipeline)
        file_trees = []

        for tree in trees:
            words = tuple(tree.leaf_labels())
            if words in known_text:
                print("Skipping a duplicate in {}: {}".format(filename, tree))
                continue

            known_text.add(words)

            file_trees.append(tree)

        filtered_trees.append((filename, file_trees))

    print("{} contains {} usable trees".format(evalita_test, len(test_trees)))
    print(" Unique constituents in {}: {}".format(evalita_test, parse_tree.Tree.get_unique_constituent_labels(test_trees)))

    train_trees = []
    dev_trees = []
    for filename, file_trees in filtered_trees:
        print("{} contains {} usable trees".format(filename, len(file_trees)))
        print(" Unique constituents in {}: {}".format(filename, parse_tree.Tree.get_unique_constituent_labels(file_trees)))
        for tree in file_trees:
            # keep roughly a 9:1 train/dev ratio as trees stream in
            if len(train_trees) <= len(dev_trees) * 9:
                train_trees.append(tree)
            else:
                dev_trees.append(tree)

    it_train = os.path.join(output_path, "it_turin_train.mrg")
    save_trees(it_train, train_trees)

    it_dev = os.path.join(output_path, "it_turin_dev.mrg")
    save_trees(it_dev, dev_trees)
331
+
332
def main():
    """
    Usage: convert_it_turin.py <input_path> <output_path>

    input_path must contain the "evalita" and "turin" subdirectories.
    """
    input_path = sys.argv[1]
    output_path = sys.argv[2]

    convert_it_turin(input_path, output_path)

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/convert_it_vit.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Converts the proprietary VIT dataset to a format suitable for stanza
2
+
3
+ There are multiple corrections in the UD version of VIT, along with
4
+ recommended splits for the MWT, along with recommended splits of
5
+ the sentences into train/dev/test
6
+
7
+ Accordingly, it is necessary to use the UD dataset as a reference
8
+
9
+ Here is a sample line of the text file we use:
10
+
11
+ #ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]]
12
+
13
+ Here you can already see multiple issues when parsing:
14
+ - the first word is "negli", which is split into In_ADP gli_DET in the UD version
15
+ - also the first word is capitalized in the UD version
16
+ - comma looks like a tempting split target, but there is a ',' in this sentence
17
+ punt-','
18
+ - not shown here is '-' which is different from the - used for denoting POS
19
+ par-'-'
20
+
21
+ Fortunately, -[ is always an open and ] is always a close
22
+
23
+ As of April 2022, the UD version of the dataset has some minor edits
24
+ which are necessary for the proper functioning of this script.
25
+ Otherwise, the MWT won't align correctly, some typos won't be
26
+ corrected, etc. These edits are released in UD 2.10
27
+
28
+ The data itself is available from ELRA:
29
+
30
+ http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/
31
+
32
+ Internally at Stanford you can contact Chris Manning or John Bauer.
33
+
34
+ The processing goes as follows:
35
+ - read in UD and con trees
36
+ some of the con trees have broken brackets and are discarded
37
+ in other cases, abbreviations were turned into single tokens in UD
38
+ - extract the MWT expansions of Italian contractions,
39
+ such as Negli -> In gli
40
+ - attempt to align the trees from the two datasets using ngrams
41
+ some trees had the sentence splitting updated
42
+ sentences which can't be matched are discarded
43
+ - use CoreNLP tsurgeon to update tokens in the con trees
44
+ based on the information in the UD dataset
45
+ - split contractions
46
+ - rearrange clitics which are occasionally non-projective
47
+ - replace the words in the con tree with the dep tree's words
48
+ this takes advantage of spelling & capitalization fixes
49
+
50
+ In 2022, there was an update to the dataset from Prof. Delmonte.
51
+ This update is hopefully in current ELRA distributions now.
52
+ If not, please contact ELRA to specifically ask for the updated version.
53
+ Internally to Stanford, feel free to ask Chris or John for the updates.
54
+ Look for the line below "original version with more errors"
55
+
56
+ In August 2022, Prof. Delmonte made a slight update in a zip file
57
+ `john.zip`. If/when that gets updated to ELRA, we will update it
58
+ here. Contact Chris or John for a copy if not updated yet, or go
59
+ back in git history to get the older version of the code which
60
+ works with the 2022 ELRA update.
61
+
62
+ Later, in September 2022, there is yet another update,
63
+ New version of VIT.zip
64
+ Unzip the contents into a folder
65
+ $CONSTITUENCY_BASE/italian/it_vit
66
+ so there should be a file
67
+ $CONSTITUENCY_BASE/italian/it_vit/VITwritten/VITconstsyntNumb
68
+
69
+ There are a few other updates needed to improve the annotations,
70
+ but all the nagging seemed to give Prof. Delmonte a headache,
71
+ so at this point we include those fixes in this script instead.
72
+ See the first few tsurgeon operations in update_mwts_and_special_cases
73
+ """
74
+
75
+ from collections import defaultdict, deque, namedtuple
76
+ import itertools
77
+ import os
78
+ import re
79
+ import sys
80
+
81
+ from tqdm import tqdm
82
+
83
+ from stanza.models.constituency.tree_reader import read_trees, UnclosedTreeError, ExtraCloseTreeError
84
+ from stanza.server import tsurgeon
85
+ from stanza.utils.conll import CoNLL
86
+ from stanza.utils.datasets.constituency.utils import SHARDS, write_dataset
87
+ import stanza.utils.default_paths as default_paths
88
+
89
def read_constituency_sentences(fin):
    """
    Split each nonblank line of the raw treebank into (sentence id, tree text).

    The tree text is left unparsed at this stage.  Raises ValueError when
    a line does not begin with one of the known sentence id formats.
    """
    sentences = []
    for raw_line in fin:
        # strip() does not remove the BOM, so drop it explicitly
        cleaned = raw_line.strip().replace('\ufeff', '')
        if not cleaned:
            continue
        sent_id, sent_text = cleaned.split(maxsplit=1)
        # two different header styles occur across files, but each file is
        # internally consistent
        if not sent_id.startswith(("#ID=sent", "ID#sent")):
            raise ValueError("Unexpected start of sentence: |{}|".format(sent_id))
        if not sent_text:
            raise ValueError("Empty text for |{}|".format(sent_id))
        sentences.append((sent_id, sent_text))
    return sentences
111
+
112
def read_constituency_file(filename):
    """Open ``filename`` as UTF-8 and return its (sent_id, raw text) pairs.

    Thin wrapper around read_constituency_sentences.
    """
    print("Reading raw constituencies from %s" % filename)
    with open(filename, encoding='utf-8') as fin:
        return read_constituency_sentences(fin)
116
+
117
# Markers of the raw VIT bracket format: a constituent opens with "tag-["
# and closes with "]"
OPEN = "-["
CLOSE = "]"

# e.g. "10_30": hour/minute written with an underscore (rewritten as "10:30")
DATE_RE = re.compile("^([0-9]{1,2})[_]([0-9]{2})$")
# e.g. "5%" or "min5%" / "plus5%": integer percentage
INTEGER_PERCENT_RE = re.compile(r"^((?:min|plus)?[0-9]{1,3})[%]$")
# e.g. "0_5%": decimal percentage, underscore or slash as the decimal mark
DECIMAL_PERCENT_RE = re.compile(r"^((?:min|plus)?[0-9]{1,3})[/_]([0-9]{1,3})[%]$")
# e.g. "1_5/2_5%": a range between two decimal percentages
RANGE_PERCENT_RE = re.compile(r"^([0-9]{1,2}[/_][0-9]{1,2})[/]([0-9]{1,2}[/_][0-9]{1,2})[%]$")
# e.g. "1_5": single-digit decimal written with an underscore
DECIMAL_RE = re.compile(r"^([0-9])[_]([0-9])$")

# One converted tree plus the constituency and dependency sentence ids it came from
ProcessedTree = namedtuple('ProcessedTree', ['con_id', 'dep_id', 'tree'])
127
+
128
# Direct rewrites for raw VIT tokens the generic tag-word splitting below
# cannot handle: quoted punctuation, dates, oddly formatted numbers,
# abbreviations, etc.  Hoisted to module level so the table is built once
# rather than on every call to raw_tree.
# NOTE: the original inline dict listed "num-0_39%minus" twice; the first
# value "(num 0,39) (num %%) (num -)" was silently shadowed, so only the
# surviving "minus" variant (consistent with 0/07%plus and 0/69%minus) is kept.
PIECE_MAPPING = {
    "agn-/ter'": "(agn ter)",
    "cong-'&'": "(cong &)",
    "da_riempire-'...'": "(da_riempire ...)",
    "date-1992_1993": "(date 1992/1993)",
    "date-'31-12-95'": "(date 31-12-95)",
    "date-'novantaquattro-95'":"(date novantaquattro-95)",
    "date-'novantaquattro-95": "(date novantaquattro-95)",
    "date-'novantaquattro-novantacinque'": "(date novantaquattro-novantacinque)",
    "dirs-':'": "(dirs :)",
    "dirs-'\"'": "(dirs \")",
    "mw-'&'": "(mw &)",
    "mw-'Presunto'": "(mw Presunto)",
    "nh-'Alain-Gauze'": "(nh Alain-Gauze)",
    "np-'porto_Marghera'": "(np Porto) (np Marghera)",
    "np-'roma-l_aquila'": "(np Roma-L'Aquila)",
    "np-'L_Aquila-Villa_Vomano'": "(np L'Aquila) (np -) (np Villa) (np Vomano)",
    "npro-'Avanti_!'": "(npro Avanti_!)",
    "npro-'Viacom-Paramount'": "(npro Viacom-Paramount)",
    "npro-'Rhone-Poulenc'": "(npro Rhone-Poulenc)",
    "npro-'Itar-Tass'": "(npro Itar-Tass)",
    "par-(-)": "(par -)",
    "par-','": "(par ,)",
    "par-'<'": "(par <)",
    "par-'>'": "(par >)",
    "par-'-'": "(par -)",
    "par-'\"'": "(par \")",
    "par-'('": "(par -LRB-)",
    "par-')'": "(par -RRB-)",
    "par-'&&'": "(par &&)",
    "punt-','": "(punt ,)",
    "punt-'-'": "(punt -)",
    "punt-';'": "(punt ;)",
    "punto-':'": "(punto :)",
    "punto-';'": "(punto ;)",
    "puntint-'!'": "(puntint !)",
    "puntint-'?'": "(puntint !)",
    "num-'2plus2'": "(num 2+2)",
    "num-/bis'": "(num bis)",
    "num-/ter'": "(num ter)",
    "num-18_00/1_00": "(num 18:00/1:00)",
    "num-1/500_2/000": "(num 1.500-2.000)",
    "num-16_1": "(num 16,1)",
    "num-0_1": "(num 0,1)",
    "num-0_3": "(num 0,3)",
    "num-2_7": "(num 2,7)",
    "num-455_68": "(num 455/68)",
    "num-437_5": "(num 437,5)",
    "num-4708_82": "(num 4708,82)",
    "num-16EQ517_7": "(num 16EQ517/7)",
    "num-2=184_90": "(num 2=184/90)",
    "num-3EQ429_20": "(num 3eq429/20)",
    "num-'1990-EQU-100'": "(num 1990-EQU-100)",
    "num-'500-EQU-250'": "(num 500-EQU-250)",
    "num-1_88/76": "(num 1-88/76)",
    "num-'70/80'": "(num 70,80)",
    "num-'18/20'": "(num 18:20)",
    "num-295/mila'": "(num 295mila)",
    "num-'295/mila'": "(num 295mila)",
    "num-0/07%plus": "(num 0,07) (num %%) (num plus)",
    "num-0/69%minus": "(num 0,69) (num %%) (num minus)",
    "num-0_39%minus": "(num 0,39) (num %%) (num minus)",
    "num-9_11/16": "(num 9-11,16)",
    "num-2/184_90": "(num 2=184/90)",
    "num-3/429_20": "(num 3eq429/20)",
    # TODO: remove the following num conversions if possible
    # this would require editing either constituency or UD
    "num-1:28_124": "(num 1=8/1242)",
    "num-1:28_397": "(num 1=8/3972)",
    "num-1:28_947": "(num 1=8/9472)",
    "num-1:29_657": "(num 1=9/6572)",
    "num-1:29_867": "(num 1=9/8672)",
    "num-1:29_874": "(num 1=9/8742)",
    "num-1:30_083": "(num 1=0/0833)",
    "num-1:30_140": "(num 1=0/1403)",
    "num-1:30_354": "(num 1=0/3543)",
    "num-1:30_453": "(num 1=0/4533)",
    "num-1:30_946": "(num 1=0/9463)",
    "num-1:31_602": "(num 1=1/6023)",
    "num-1:31_842": "(num 1=1/8423)",
    "num-1:32_087": "(num 1=2/0873)",
    "num-1:32_259": "(num 1=2/2593)",
    "num-1:33_166": "(num 1=3/1663)",
    "num-1:34_154": "(num 1=4/1543)",
    "num-1:34_556": "(num 1=4/5563)",
    "num-1:35_323": "(num 1=5/3233)",
    "num-1:36_023": "(num 1=6/0233)",
    "num-1:36_076": "(num 1=6/0763)",
    "num-1:36_651": "(num 1=6/6513)",
    "n-giga_flop/s": "(n giga_flop/s)",
    "sect-'g-1'": "(sect g-1)",
    "sect-'h-1'": "(sect h-1)",
    "sect-'h-2'": "(sect h-2)",
    "sect-'h-3'": "(sect h-3)",
    "abbr-'a-b-c'": "(abbr a-b-c)",
    "abbr-d_o_a_": "(abbr DOA)",
    "abbr-d_l_": "(abbr DL)",
    "abbr-i_s_e_f_": "(abbr ISEF)",
    "abbr-d_p_r_": "(abbr DPR)",
    "abbr-D_P_R_": "(abbr DPR)",
    "abbr-d_m_": "(abbr dm)",
    "abbr-T_U_": "(abbr TU)",
    "abbr-F_A_M_E_": "(abbr Fame)",
    "dots-'...'": "(dots ...)",
}

def raw_tree(text):
    """Convert one raw VIT bracketing into a stanza parse tree.

    A sentence will look like this:
      #ID=sent_00001 fc-[f3-[sn-[art-le, n-infrastrutture, sc-[ccom-come, sn-[n-fattore, spd-[pd-di,
      sn-[n-competitività]]]]]], f3-[spd-[pd-di, sn-[mw-Angela, nh-Airoldi]]], punto-.]
    Non-preterminal nodes have tags, followed by the stuff under the node, -[
    The node is closed by the ]

    The bracket text is tokenized into open / word / close pieces, rewritten
    into PTB-style brackets (with PIECE_MAPPING handling the special cases),
    and finally read back with the standard tree reader.

    Raises ValueError on malformed input or if more than one tree results.
    """
    pieces = []
    open_pieces = text.split(OPEN)
    for open_idx, open_piece in enumerate(open_pieces):
        if open_idx > 0:
            # put the OPEN marker back on the previous piece so we can
            # recognize it as a constituent open below
            pieces[-1] = pieces[-1] + OPEN
        open_piece = open_piece.strip()
        if not open_piece:
            raise ValueError("Unexpected empty node!")
        close_pieces = open_piece.split(CLOSE)
        for close_idx, close_piece in enumerate(close_pieces):
            if close_idx > 0:
                pieces.append(CLOSE)
            close_piece = close_piece.strip()
            if not close_piece:
                # this is okay - multiple closes at the end of a deep bracket
                continue
            word_pieces = close_piece.split(", ")
            pieces.extend([x.strip() for x in word_pieces if x.strip()])

    # at this point, pieces is a list with:
    #   tag-[     for opens
    #   tag-word  for words
    #   ]         for closes
    # this structure converts pretty well to reading using the tree reader
    new_pieces = ["(ROOT "]
    for piece in pieces:
        if piece.endswith(OPEN):
            new_pieces.append("(" + piece[:-2])
        elif piece == CLOSE:
            new_pieces.append(")")
        elif piece in PIECE_MAPPING:
            new_pieces.append(PIECE_MAPPING[piece])
        else:
            # maxsplit=1 because of words like 1990-EQU-100
            tag, word = piece.split("-", maxsplit=1)
            if word.find("'") >= 0 or word.find("(") >= 0 or word.find(")") >= 0:
                raise ValueError("Unhandled weird node: {}".format(piece))
            if word.endswith("_"):
                # a trailing _ in VIT stands for a final apostrophe
                word = word[:-1] + "'"
            date_match = DATE_RE.match(word)
            if date_match:
                # 10_30 special case sent_07072
                # 16_30 special case sent_07098
                # 21_15 special case sent_07099 and others
                word = date_match.group(1) + ":" + date_match.group(2)
            integer_percent = INTEGER_PERCENT_RE.match(word)
            if integer_percent:
                word = integer_percent.group(1) + "_%%"
            range_percent = RANGE_PERCENT_RE.match(word)
            if range_percent:
                word = range_percent.group(1) + "," + range_percent.group(2) + "_%%"
            percent = DECIMAL_PERCENT_RE.match(word)
            if percent:
                word = percent.group(1) + "," + percent.group(2) + "_%%"
            decimal = DECIMAL_RE.match(word)
            if decimal:
                word = decimal.group(1) + "," + decimal.group(2)
            # there are words which are multiple words mashed together
            # with _ for some reason
            # also, words which end in ' are replaced with _
            # fortunately, no words seem to have both
            # splitting like this means the tags are likely wrong,
            # but the conparser needs to retag anyway, so it shouldn't matter
            word_pieces = word.split("_")
            for word_piece in word_pieces:
                new_pieces.append("(%s %s)" % (tag, word_piece))
    new_pieces.append(")")

    text = " ".join(new_pieces)
    trees = read_trees(text)
    if len(trees) > 1:
        raise ValueError("Unexpected number of trees!")
    return trees[0]
316
+
317
def extract_ngrams(sentence, process_func, ngram_len=4):
    """Return the word ngrams (as tuples) for one sentence.

    process_func extracts the word list from the sentence object.  A bare
    "l" is normalized to "l'" so the two treebanks tokenize alike.  If the
    sentence has no more than ngram_len words, a single tuple containing
    all of the words is returned instead.
    """
    words = ["l'" if word == "l" else word for word in process_func(sentence)]
    if len(words) <= ngram_len:
        return [tuple(words)]
    # all sliding windows of length ngram_len
    return [tuple(words[start:start + ngram_len])
            for start in range(len(words) - ngram_len + 1)]
324
+
325
def build_ngrams(sentences, process_func, id_func, ngram_len=4):
    """
    Turn the list of processed trees into a bunch of ngrams

    The returned map is from ngram tuple to the set of sentence ids
    containing that ngram.

    The idea being that this map can be used to search for trees to
    match datasets
    """
    ngram_to_ids = defaultdict(set)
    for item in tqdm(sentences, postfix="Extracting ngrams"):
        item_id = id_func(item)
        for ngram in extract_ngrams(item, process_func, ngram_len):
            ngram_to_ids[ngram].add(item_id)
    return ngram_to_ids
341
+
342
def _dep_process(sentence):
    """Lowercased token text for a UD sentence (tokens, not MWT-split words)."""
    return [token.text.lower() for token in sentence.tokens]

def _dep_id(sentence):
    """Pull the bare id out of a UD sentence's '# sent_id ...' comment."""
    sent_id_comments = [comment for comment in sentence.comments
                        if comment.startswith("# sent_id")]
    return sent_id_comments[0].split()[-1]

def _con_process(tree):
    """Lowercased leaf words for a constituency tree."""
    return [leaf.lower() for leaf in tree.leaf_labels()]

# just the tokens (maybe use words? depends on MWT in the con dataset)
DEP_PROCESS_FUNC = _dep_process
# find the comment with "sent_id" in it, take just the id itself
DEP_ID_FUNC = _dep_id
CON_PROCESS_FUNC = _con_process
348
+
349
def match_ngrams(sentence_ngrams, ngram_map, debug=False):
    """Find the single sentence id consistent with all of these ngrams.

    Ngrams appearing in several sentences are uninformative and ignored.
    An ngram appearing in exactly one sentence nominates that sentence;
    two different nominations mean there is no reliable match.  If more
    than half of the ngrams are completely unknown, the match is rejected.
    TODO: taking the intersection of all non-empty matches might be better

    Returns the matched id, or None when no single match is found.
    """
    if debug:
        print("NGRAMS FOR DEBUG SENTENCE:")
    candidate = None
    num_unknown = 0
    for ngram in sentence_ngrams:
        matches = ngram_map[ngram]
        if debug:
            print("{} matched {}".format(ngram, len(matches)))
        if not matches:
            num_unknown += 1
        elif len(matches) == 1:
            # get the one & only element from the set
            match, = matches
            if debug:
                print("  {}".format(match))
            if candidate is None:
                candidate = match
            elif candidate != match:
                # two different unique nominations: ambiguous, give up
                return None
    if num_unknown > len(sentence_ngrams) / 2:
        return None
    return candidate
382
+
383
def match_sentences(con_tree_map, con_vit_ngrams, dep_sentences, split_name, debug_sentence=None):
    """
    Match ngrams in the dependency sentences to the constituency sentences

    Then, to make sure the constituency sentence wasn't split into two
    in the UD dataset, this checks the ngrams in the reverse direction

    con_tree_map: constituency id -> tree
    con_vit_ngrams: ngram -> set of constituency ids (from build_ngrams)
    dep_sentences: the UD sentences for one split
    split_name: only used for the summary printout
    debug_sentence: optionally a UD sent_id whose matching is traced

    Returns a dict of constituency id -> dependency id for the matches found.

    Some examples of things which don't match:
    VIT-4769 Insegnanti non vedenti, insegnanti non autosufficienti con protesi agli arti inferiori.
      this is duplicated in the original dataset, so the matching algorithm can't possibly work

    VIT-4796 I posti istituiti con attività di sostegno dei docenti che ottengono il trasferimento su classi di concorso;
      the correct con match should be sent_04829 but the brackets on that tree are broken
    """
    con_to_dep_matches = {}
    # reverse index: ngrams of the dependency sentences, for the reverse check
    dep_ngram_map = build_ngrams(dep_sentences, DEP_PROCESS_FUNC, DEP_ID_FUNC)
    unmatched = 0
    bad_match = 0
    for sentence in dep_sentences:
        sentence_ngrams = extract_ngrams(sentence, DEP_PROCESS_FUNC)
        # debug tracing only fires for the requested sentence id
        potential_match = match_ngrams(sentence_ngrams, con_vit_ngrams, debug_sentence is not None and DEP_ID_FUNC(sentence) == debug_sentence)
        if potential_match is None:
            # only print the first few failures to keep the log readable
            if unmatched < 5:
                print("Could not match the following sentence: {} {}".format(DEP_ID_FUNC(sentence), sentence.text))
            unmatched += 1
            continue
        if potential_match not in con_tree_map:
            raise ValueError("wtf")
        # reverse direction: the matched con tree must map back to a single dep sentence
        con_ngrams = extract_ngrams(con_tree_map[potential_match], CON_PROCESS_FUNC)
        reverse_match = match_ngrams(con_ngrams, dep_ngram_map)
        if reverse_match is None:
            #print("Matched sentence {} to sentence {} but the reverse match failed".format(sentence.text, " ".join(con_tree_map[potential_match].leaf_labels())))
            bad_match += 1
            continue
        con_to_dep_matches[potential_match] = reverse_match
    print("Failed to match %d sentences and found %d spurious matches in the %s section" % (unmatched, bad_match, split_name))
    return con_to_dep_matches
420
+
421
# MWT surface forms which are neither ADP+DET contractions nor verb+clitic
# combinations; these are skipped entirely when collecting expansions
EXCEPTIONS = ["gliene", "glielo", "gliela", "eccoci"]

def get_mwt(*dep_datasets):
    """
    Get the ADP/DET MWTs from the UD dataset

    This class of MWT are expanded in the UD but not the constituencies.
    Returns a map from surface token text to its (ADP, DET) word pair,
    raising ValueError for an unexpected MWT shape or an inconsistent
    expansion of the same surface form.
    """
    expansions = {}
    for dataset in dep_datasets:
        for sentence in dataset.sentences:
            for token in sentence.tokens:
                words = token.words
                if len(words) == 1:
                    continue
                # words such as "accorgermene" we just skip over
                # those are already expanded in the constituency dataset
                # TODO: the clitics are actually expanded weirdly, maybe need to compensate for that
                if words[0].upos in ('VERB', 'AUX') and all(word.upos == 'PRON' for word in words[1:]):
                    continue
                if token.text.lower() in EXCEPTIONS:
                    continue
                if len(words) != 2 or words[0].upos != 'ADP' or words[1].upos != 'DET':
                    raise ValueError("Not sure how to handle this: {}".format(token))
                expansion = (words[0].text, words[1].text)
                previous = expansions.get(token.text)
                if previous is not None:
                    if previous != expansion:
                        raise ValueError("Inconsistent MWT: {} -> {} or {}".format(token.text, expansion, previous))
                    continue
                #print("Expanding {} to {}".format(token.text, expansion))
                expansions[token.text] = expansion
    return expansions
452
+
453
def update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor):
    """
    Replace MWT structures with their UD equivalents, along with some other minor tsurgeon based edits

    original_tree: the tree as read from VIT
    dep_sentence: the UD dependency dataset version of this sentence
    mwt_map: surface token text -> (ADP, DET) expansion, as built by get_mwt
    tsurgeon_processor: handle used to run the assembled tregex/tsurgeon edits

    Returns (updated_tree, operations) where operations is the list of
    tregex+tsurgeon command lists that were applied (kept for error reports).
    """
    updated_tree = original_tree

    operations = []

    # first, remove titles or testo from the start of a sentence
    con_words = updated_tree.leaf_labels()
    if con_words[0] == "Tit'":
        operations.append(["/^Tit'$/=prune !, __", "prune prune"])
    elif con_words[0] == "TESTO":
        operations.append(["/^TESTO$/=prune !, __", "prune prune"])
    elif con_words[0] == "testo":
        operations.append(["/^testo$/ !, __ . /^:$/=prune", "prune prune"])
        operations.append(["/^testo$/=prune !, __", "prune prune"])

    if len(con_words) >= 2 and con_words[-2] == '...' and con_words[-1] == '.':
        # the most recent VIT constituency has some sentence final . after a ...
        # the UD dataset has a more typical ... ending instead
        # these lines used to say "riempire" which was rather odd
        operations.append(["/^[.][.][.]$/ . /^[.]$/=prune", "prune prune"])

    # a few constituent tags are simply errors which need to be fixed
    if original_tree.children[0].label == 'p':
        # 'p' shouldn't be at root
        operations.append(["_ROOT_ < p=p", "relabel p cp"])
    # fix one specific tree if it has an s_top in it
    operations.append(["s_top=stop < (in=in < più=piu)", "replace piu (q più)", "relabel in sq", "relabel stop sa"])
    # sect doesn't exist as a constituent. replace it with sa
    operations.append(["sect=sect < num", "relabel sect sa"])
    # ppas as an internal node gets removed
    operations.append(["ppas=ppas < (__ < __)", "excise ppas ppas"])

    # now assemble a bunch of regex to split and otherwise manipulate
    # the MWT in the trees
    for token in dep_sentence.tokens:
        # single-word tokens need no splitting
        if len(token.words) == 1:
            continue
        if token.text in mwt_map:
            mwt_pieces = mwt_map[token.text]
            if len(mwt_pieces) != 2:
                raise NotImplementedError("Expected exactly 2 pieces of mwt for %s" % token.text)
            # the MWT words in the UD version will have ' when needed,
            # but the corresponding ' is skipped in the con version of VIT,
            # hence the replace("'", "")
            # however, all' has the ' included, because this is a
            # constituent treebank, not a consistent treebank
            search_regex = "/^(?i:%s(?:')?)$/" % token.text.replace("'", "")
            # tags which seem to be relevant:
            #   avvl|ccom|php|part|partd|partda
            # case 1: the contraction is followed by an sn; insert the DET
            # inside the sn and relabel the contraction as the ADP
            tregex = "__ !> __ <<<%d (%s=child > (__=parent $+ sn=sn))" % (token.id[0], search_regex)
            tsurgeons = ["insert (art %s) >0 sn" % mwt_pieces[1], "relabel child %s" % mwt_pieces[0]]
            operations.append([tregex] + tsurgeons)

            # case 2: no following sn (and not already split); insert the DET
            # as a sibling after the parent instead
            tregex = "__ !> __ <<<%d (%s=child > (__=parent !$+ sn !$+ (art < %s)))" % (token.id[0], search_regex, mwt_pieces[1])
            tsurgeons = ["insert (art %s) $- parent" % mwt_pieces[1], "relabel child %s" % mwt_pieces[0]]
            operations.append([tregex] + tsurgeons)
        elif len(token.words) == 2:
            #print("{} not in mwt_map".format(token.text))
            # apparently some trees like sent_00381 and sent_05070
            # have the clitic in a non-projective manner
            #   [vcl-essersi, vppin-sparato, compt-[clitdat-si
            #   intj-figurarsi, fs-[cosu-quando, f-[ibar-[clit-si
            # and before you ask, there are also clitics which are
            # simply not there at all, rather than always attached
            # in a non-projective manner
            tregex = "__=parent < (/^(?i:%s)$/=child . (__=np !< __ . (/^clit/=clit < %s)))" % (token.text, token.words[1].text)
            tsurgeon = "moveprune clit $- parent"
            operations.append([tregex, tsurgeon])

            # there are also some trees which don't have clitics
            # for example, trees should look like this:
            #   [ibar-[vsup-poteva, vcl-rivelarsi], compc-[clit-si, sn-[...]]]
            # however, at least one such example for rivelarsi instead
            # looks like this, with no corresponding clit
            #   [... vcl-rivelarsi], compc-[sn-[in-ancora]]
            # note that is the actual tag, not just me being pissed off
            # breaking down the tregex:
            #   the child is the original MWT, not split
            #   !. clit verifies that it is not split (and stops the tsurgeon once fixed)
            #   !$+ checks that the parent of the MWT is the last element under parent
            #     note that !. can leave the immediate parent to touch the clit
            #   neighbor will be the place the new clit will be sticking out
            tregex = "__=parent < (/^(?i:%s)$/=child !. /^clit/) !$+ __ > (__=gp $+ __=neighbor)" % token.text
            tsurgeon = "insert (clit %s) >0 neighbor" % token.words[1].text
            operations.append([tregex, tsurgeon])

            # secondary option: while most trees are like the above,
            # with an outer bracket around the MWT and another verb,
            # some go straight into the next phrase
            # sent_05076
            #   sv5-[vcl-adeguandosi, compin-[sp-[part-alle, ...
            tregex = "__=parent < (/^(?i:%s)$/=child !. /^clit/) $+ __" % token.text
            tsurgeon = "insert (clit %s) $- parent" % token.words[1].text
            operations.append([tregex, tsurgeon])
        else:
            # NOTE(review): 3+ word MWTs outside mwt_map are deliberately
            # left untouched here
            pass
    if len(operations) > 0:
        updated_tree = tsurgeon_processor.process(updated_tree, *operations)[0]
    return updated_tree, operations
558
+
559
def update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor):
    """
    Update a tree using the mwt_map and tsurgeon to expand some MWTs

    Then replace the words in the con tree with the words in the dep tree

    con_id, dep_id are only used for error reporting.

    Raises ValueError with a detailed dump of both trees if the edited
    constituency tree does not line up word-for-word with the UD sentence.
    """
    ud_words = [x.text for x in dep_sentence.words]

    updated_tree, operations = update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor)

    # this checks number of words
    try:
        updated_tree = updated_tree.replace_words(ud_words)
    except ValueError as e:
        raise ValueError("Failed to process {} {}:\nORIGINAL TREE\n{}\nUPDATED TREE\n{}\nUPDATED LEAVES\n{}\nUD TEXT\n{}\nTsurgeons applied:\n{}\n".format(con_id, dep_id, original_tree, updated_tree, updated_tree.leaf_labels(), ud_words, "\n".join("{}".format(op) for op in operations))) from e
    return updated_tree
575
+
576
# constituency trees whose words differ from the UD data in ways this
# script does not repair; they are skipped entirely
#
# train set:
#   858: missing close parens in the UD conversion
#   1169: 'che', 'poi', 'tutti', 'i', 'Paesi', 'ue', '.' -> 'per', 'tutti', 'i', 'paesi', 'Ue', '.'
#   2375: the problem is inconsistent treatment of s_p_a_
#   05052: the heuristic to fill in a missing "si" doesn't work because there's
#     already another "si" immediately after
#
# test set:
#   09764: weird punct at end
#   10058: weird punct at end
IGNORE_IDS = ["sent_00867", "sent_01169", "sent_02375", "sent_05052", "sent_09764", "sent_10058"]
587
+
588
def extract_updated_dataset(con_tree_map, dep_sentence_map, split_ids, mwt_map, tsurgeon_processor):
    """
    Update constituency trees using the information in the dependency treebank

    Returns a list of ProcessedTree(con_id, dep_id, tree), skipping the
    few ids known to need word changes beyond MWT expansion.
    """
    results = []
    for con_id, dep_id in tqdm(split_ids.items()):
        # skip a few trees which have non-MWT word modifications
        if con_id in IGNORE_IDS:
            continue
        updated = update_tree(con_tree_map[con_id], dep_sentence_map[dep_id],
                              con_id, dep_id, mwt_map, tsurgeon_processor)
        results.append(ProcessedTree(con_id, dep_id, updated))
    return results
603
+
604
def read_updated_trees(paths, debug_sentence=None):
    """Read the raw VIT constituencies and the UD VIT data, and reconcile them.

    paths: the default_paths map (CONSTITUENCY_BASE and UDBASE must be set)
    debug_sentence: optionally a UD sent_id whose ngram matching is traced

    Returns (train_trees, dev_trees, test_trees) as lists of ProcessedTree,
    split according to the UD train/dev/test sections.

    Raises ValueError if any raw constituency tree fails to parse - as of
    the most recent data update, all trees are expected to be well bracketed.
    """
    # original version with more errors
    #con_filename = os.path.join(con_directory, "2011-12-20", "Archive", "VIT_newconstsynt.txt")
    # this is the April 2022 version
    #con_filename = os.path.join(con_directory, "VIT_newconstsynt.txt")
    # the most recent update from ELRA may look like this?
    # it's what we got, at least
    # con_filename = os.path.join(con_directory, "italian", "VITwritten", "VITconstsyntNumb")

    # needs at least UD 2.11 or this will not work
    con_directory = paths["CONSTITUENCY_BASE"]
    ud_directory = os.path.join(paths["UDBASE"], "UD_Italian-VIT")

    con_filename = os.path.join(con_directory, "italian", "it_vit", "VITwritten", "VITconstsyntNumb")
    ud_vit_train = os.path.join(ud_directory, "it_vit-ud-train.conllu")
    ud_vit_dev = os.path.join(ud_directory, "it_vit-ud-dev.conllu")
    ud_vit_test = os.path.join(ud_directory, "it_vit-ud-test.conllu")

    print("Reading UD train/dev/test from %s" % ud_directory)
    ud_train_data = CoNLL.conll2doc(input_file=ud_vit_train)
    ud_dev_data = CoNLL.conll2doc(input_file=ud_vit_dev)
    ud_test_data = CoNLL.conll2doc(input_file=ud_vit_test)

    # index the UD sentences by their sent_id for the final update step
    ud_vit_train_map = { DEP_ID_FUNC(x) : x for x in ud_train_data.sentences }
    ud_vit_dev_map = { DEP_ID_FUNC(x) : x for x in ud_dev_data.sentences }
    ud_vit_test_map = { DEP_ID_FUNC(x) : x for x in ud_test_data.sentences }

    print("Getting ADP/DET expansions from UD data")
    mwt_map = get_mwt(ud_train_data, ud_dev_data, ud_test_data)

    con_sentences = read_constituency_file(con_filename)
    num_discarded = 0
    con_tree_map = {}
    for idx, sentence in enumerate(tqdm(con_sentences, postfix="Processing")):
        try:
            tree = raw_tree(sentence[1])
            if sentence[0].startswith("#ID="):
                tree_id = sentence[0].split("=")[-1]
            else:
                tree_id = sentence[0].split("#")[-1]
            # don't care about the raw text?
            con_tree_map[tree_id] = tree
        # previously three byte-identical except clauses handled these
        # separately; they all just discard the tree and log the error
        except (UnclosedTreeError, ExtraCloseTreeError, ValueError) as e:
            num_discarded = num_discarded + 1
            print("Discarding {} because of reading error:\n {}: {}\n {}".format(sentence[0], type(e), e, sentence[1]))
            #raise ValueError("Could not process line %d" % idx) from e

    print("Discarded %d trees. Have %d trees left" % (num_discarded, len(con_tree_map)))
    if num_discarded > 0:
        raise ValueError("Oops! We thought all of the VIT trees were properly bracketed now")
    con_vit_ngrams = build_ngrams(con_tree_map.items(), lambda x: CON_PROCESS_FUNC(x[1]), lambda x: x[0])

    # TODO: match more sentences. some are probably missing because of MWT
    train_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_train_data.sentences, "train", debug_sentence)
    dev_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_dev_data.sentences, "dev", debug_sentence)
    test_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_test_data.sentences, "test", debug_sentence)
    print("Remaining total trees: %d" % (len(train_ids) + len(dev_ids) + len(test_ids)))
    print(" {} train {} dev {} test".format(len(train_ids), len(dev_ids), len(test_ids)))
    print("Updating trees with MWT and newer tokens from UD...")

    # the moveprune feature requires a new corenlp release after 4.4.0
    with tsurgeon.Tsurgeon(classpath="$CLASSPATH") as tsurgeon_processor:
        train_trees = extract_updated_dataset(con_tree_map, ud_vit_train_map, train_ids, mwt_map, tsurgeon_processor)
        dev_trees = extract_updated_dataset(con_tree_map, ud_vit_dev_map, dev_ids, mwt_map, tsurgeon_processor)
        test_trees = extract_updated_dataset(con_tree_map, ud_vit_test_map, test_ids, mwt_map, tsurgeon_processor)

    return train_trees, dev_trees, test_trees
677
+
678
def convert_it_vit(paths, dataset_name, debug_sentence=None):
    """
    Read the trees, then write them out to the expected output_directory

    paths: the default_paths map
    dataset_name: short name used for the output files, e.g. "it_vit"
    debug_sentence: optionally trace the ngram matching for one UD sent_id
    """
    train_trees, dev_trees, test_trees = read_updated_trees(paths, debug_sentence)

    # keep just the trees; the con/dep ids were only needed while matching
    train_trees = [x.tree for x in train_trees]
    dev_trees = [x.tree for x in dev_trees]
    test_trees = [x.tree for x in test_trees]

    output_directory = paths["CONSTITUENCY_DATA_DIR"]
    write_dataset([train_trees, dev_trees, test_trees], output_directory, dataset_name)
690
+
691
def main():
    """Convert the it_vit dataset using the default paths.

    An optional first command line argument names a UD sent_id to trace
    through the ngram matching for debugging.
    """
    paths = default_paths.get_default_paths()
    dataset_name = "it_vit"

    debug_sentence = sys.argv[1] if len(sys.argv) > 1 else None

    convert_it_vit(paths, dataset_name, debug_sentence)

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/convert_spmrl.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from stanza.models.constituency.parse_tree import Tree
4
+ from stanza.models.constituency.tree_reader import read_treebank
5
+ from stanza.utils.default_paths import get_default_paths
6
+
7
# dataset splits processed by this script
SHARDS = ("train", "dev", "test")
8
+
9
def add_root(tree):
    """Wrap a bare SPMRL tree fragment so it carries a ROOT node.

    A fragment rooted at a noun (NN*) first gets an NP layer, a proper
    noun (NE*) a PN layer, and an XY* fragment a VROOT layer, before the
    final ROOT wrapper is added.

    The original code restarted the chain with a plain ``if`` for the NE
    test; since an NN fragment has just been relabeled NP it can never
    also match NE, so this is the same behavior written as one clean
    if/elif chain.
    """
    if tree.label.startswith("NN"):
        tree = Tree("NP", tree)
    elif tree.label.startswith("NE"):
        tree = Tree("PN", tree)
    elif tree.label.startswith("XY"):
        tree = Tree("VROOT", tree)
    return Tree("ROOT", tree)
17
+
18
def convert_spmrl(input_directory, output_directory, short_name):
    """Convert each German SPMRL shard to a one-tree-per-line .mrg file.

    input_directory: the gold ptb directory of the SPMRL release
    output_directory: where the output files are written
    short_name: dataset prefix for the output filenames, e.g. "de_spmrl"
    """
    for shard in SHARDS:
        tree_filename = os.path.join(input_directory, shard, shard + ".German.gold.ptb")
        # add_root wraps each tree with a ROOT layer as it is read
        trees = read_treebank(tree_filename, tree_callback=add_root)
        output_filename = os.path.join(output_directory, "%s_%s.mrg" % (short_name, shard))
        with open(output_filename, "w", encoding="utf-8") as fout:
            for tree in trees:
                fout.write(str(tree))
                fout.write("\n")
        print("Wrote %d trees to %s" % (len(trees), output_filename))
28
+
29
if __name__ == '__main__':
    # convert the German SPMRL gold trees using the default stanza paths
    paths = get_default_paths()
    output_directory = paths["CONSTITUENCY_DATA_DIR"]
    # location of the extracted SPMRL German gold ptb release
    input_directory = "extern_data/constituency/spmrl/SPMRL_SHARED_2014/GERMAN_SPMRL/gold/ptb"
    convert_spmrl(input_directory, output_directory, "de_spmrl")
34
+
35
+
stanza/stanza/utils/datasets/constituency/convert_starlang.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import re
4
+
5
+ from tqdm import tqdm
6
+
7
+ from stanza.models.constituency import parse_tree
8
+ from stanza.models.constituency import tree_reader
9
+
10
# pulls the surface form out of a "{turkish=...}" leaf annotation
TURKISH_RE = re.compile(r"[{]turkish=([^}]+)[}]")

# phrasal labels which cause a tree to be rejected during conversion
DISALLOWED_LABELS = ('DT', 'DET', 's', 'vp', 'AFVP', 'CONJ', 'INTJ', '-XXX-')
13
+
14
def read_tree(text):
    """
    Reads in a tree, then extracts specifically the word from the specific format used

    Also converts LCB/RCB as needed

    Raises ValueError if the text holds more than one tree, if a leaf has
    no {turkish=...} annotation, or if the tree uses a disallowed phrasal
    node label.
    """
    trees = tree_reader.read_trees(text)
    if len(trees) > 1:
        raise ValueError("Tree file had two trees!")
    tree = trees[0]
    labels = tree.leaf_labels()
    new_labels = []
    for label in labels:
        match = TURKISH_RE.search(label)
        if match is None:
            raise ValueError("Could not find word in |{}|".format(label))
        word = match.group(1)
        word = word.replace("-LCB-", "{").replace("-RCB-", "}")
        new_labels.append(word)

    tree = tree.replace_words(new_labels)
    #tree = tree.remap_constituent_labels(LABEL_MAP)
    con_labels = tree.get_unique_constituent_labels([tree])
    # BUG FIX: the original formatted the leaked loop variable `label`
    # (the last leaf label) into this message instead of the offending
    # constituent label; report the actual disallowed label
    bad_labels = [con_label for con_label in con_labels if con_label in DISALLOWED_LABELS]
    if bad_labels:
        raise ValueError("found an unexpected phrasal node {}".format(bad_labels[0]))
    return tree
40
+
41
def read_files(filenames, conversion, log):
    """Convert each named file's text into a tree via ``conversion``.

    Files whose conversion raises ValueError are skipped (logged when
    ``log`` is truthy), as are files for which conversion returns None.
    """
    converted = []
    for filename in filenames:
        with open(filename, encoding="utf-8") as fin:
            text = fin.read()
        try:
            result = conversion(text)
        except ValueError as error:
            if log:
                print("-----------------\nFound an error in {}: {} Original text: {}".format(filename, error, text))
            continue
        if result is not None:
            converted.append(result)
    return converted
54
+
55
def read_starlang(paths, conversion=read_tree, log=True):
    """
    Read the starlang trees, converting them using the given method.

    read_tree or any other conversion turns one file at a time to a sentence.
    log is whether or not to log a ValueError - the NER division has many missing labels
    """
    if isinstance(paths, str):
        paths = (paths,)

    # collect the files for each split by filename suffix
    split_files = {"train": [], "dev": [], "test": []}
    for path in paths:
        for entry in os.listdir(path):
            full_path = os.path.join(path, entry)
            for split in split_files:
                if full_path.endswith("." + split):
                    split_files[split].append(full_path)

    total_files = sum(len(files) for files in split_files.values())
    print("Reading %d total files" % total_files)
    treebanks = [read_files(tqdm(split_files[split]), conversion=conversion, log=log)
                 for split in ("train", "dev", "test")]
    return tuple(treebanks)
81
+
82
def main(conversion=read_tree, log=True):
    """Read all three Turkish Starlang treebank releases and report split sizes.

    conversion and log are passed through to read_starlang so other
    converters can reuse this entry point.

    Returns (train_treebank, dev_treebank, test_treebank).
    """
    paths = ["extern_data/constituency/turkish/TurkishAnnotatedTreeBank-15",
             "extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-15",
             "extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-20"]
    train_treebank, dev_treebank, test_treebank = read_starlang(paths, conversion=conversion, log=log)

    print("Train: %d" % len(train_treebank))
    print("Dev: %d" % len(dev_treebank))
    print("Test: %d" % len(test_treebank))

    # show one example tree as a sanity check
    print(train_treebank[0])
    return train_treebank, dev_treebank, test_treebank

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/extract_all_silver_dataset.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ After running build_silver_dataset.py, this extracts the trees of all match levels at once
3
+
4
+ For example
5
+
6
+ python stanza/utils/datasets/constituency/extract_all_silver_dataset.py --output_prefix /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_ --parsed_trees /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_wiki_a*trees
7
+
8
+ cat /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_[012345678].mrg | sort | uniq | shuf > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg
9
+
10
+ shuf /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg | head -n 200000 > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_200K.mrg
11
+ """
12
+
13
+ import argparse
14
+ from collections import defaultdict
15
+ import json
16
+
17
def parse_args():
    """Build and parse the command line arguments for silver tree extraction."""
    parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy")
    parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.')
    parser.add_argument('--output_prefix', type=str, default=None, help='Prefix to use for outputting trees')
    parser.add_argument('--output_suffix', type=str, default=".mrg", help='Suffix to use for outputting trees')
    return parser.parse_args()
25
+
26
def main():
    """
    Group trees from the input files by match count and write one file per count.

    Each input line is a json object with a 'tree' (the PTB string) and a
    'count' (how many parsers agreed on the tree).  The output filename for a
    given count is output_prefix + count + output_suffix.
    """
    args = parse_args()

    trees = defaultdict(list)
    for filename in args.parsed_trees:
        with open(filename, encoding='utf-8') as fin:
            # iterate the file lazily instead of materializing every line
            # with readlines() - these silver tree files can be very large
            for line in fin:
                tree = json.loads(line)
                trees[tree['count']].append(tree['tree'])

    for score, tree_list in trees.items():
        filename = "%s%s%s" % (args.output_prefix, score, args.output_suffix)
        with open(filename, 'w', encoding='utf-8') as fout:
            for tree in tree_list:
                fout.write(tree)
                fout.write('\n')

if __name__ == '__main__':
    main()
45
+
46
+
stanza/stanza/utils/datasets/constituency/relabel_tags.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Retag an S-expression tree with a new set of POS tags
3
+
4
+ Also includes an option to write the new trees as bracket_labels
5
+ (essentially, skipping the treebank_to_labeled_brackets step)
6
+ """
7
+
8
+ import argparse
9
+ import logging
10
+
11
+ from stanza import Pipeline
12
+ from stanza.models.constituency import retagging
13
+ from stanza.models.constituency import tree_reader
14
+ from stanza.models.constituency.utils import retag_trees
15
+
16
+ logger = logging.getLogger('stanza')
17
+
18
def parse_args():
    """Parse the command line arguments, including the shared retagging options."""
    parser = argparse.ArgumentParser(description="Script that retags a tree file")
    parser.add_argument('--lang', default='vi', type=str, help='Language')
    parser.add_argument('--input_file', default='data/constituency/vi_vlsp21_train.mrg', help='File to retag')
    parser.add_argument('--output_file', default='vi_vlsp21_train_retagged.mrg', help='Where to write the retagged trees')
    retagging.add_retag_args(parser)

    parser.add_argument('--bracket_labels', action='store_true', help='Write the trees as bracket labels instead of S-expressions')

    # downstream retagging helpers expect a dict, not a Namespace
    args = vars(parser.parse_args())
    retagging.postprocess_args(args)

    return args
32
+
33
def main():
    """
    Retag the trees in --input_file with a POS tagger and write them to --output_file.

    With --bracket_labels the trees are written in the bracket label format
    ("{:L}") instead of S-expressions.
    """
    args = parse_args()

    retag_pipeline = retagging.build_retag_pipeline(args)

    train_trees = tree_reader.read_treebank(args['input_file'])
    logger.info("Retagging %d trees using %s", len(train_trees), args['retag_package'])
    train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos'])
    tree_format = "{:L}" if args['bracket_labels'] else "{}"
    # write utf-8 explicitly: the trees are usually non-ASCII (the default
    # language is vi) and the platform default encoding may not handle them
    with open(args['output_file'], "w", encoding="utf-8") as fout:
        for tree in train_trees:
            fout.write(tree_format.format(tree))
            fout.write("\n")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/selftrain.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Common methods for the various self-training data collection scripts
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import random
8
+ import re
9
+
10
+ import stanza
11
+ from stanza.models.common import utils
12
+ from stanza.models.common.bert_embedding import TextTooLongError
13
+ from stanza.utils.get_tqdm import get_tqdm
14
+
15
+ logger = logging.getLogger('stanza')
16
+ tqdm = get_tqdm()
17
+
18
def common_args(parser):
    """
    Add the arguments shared by all of the self-training scripts.

    Covers the output file, language, sentence limit, parser models,
    pretrain/charlm package, and output tree format.
    """
    parser.add_argument('--output_file', default='data/constituency/vi_silver.mrg',
                        help='Where to write the silver trees')
    parser.add_argument('--lang', default='vi',
                        help='Which language tools to use for tokenization and POS')
    parser.add_argument('--num_sentences', type=int, default=-1,
                        help='How many sentences to get per file (max)')
    parser.add_argument('--models', default='saved_models/constituency/vi_vlsp21_inorder.pt',
                        help='What models to use for parsing. comma-separated')
    parser.add_argument('--package', default='default',
                        help='Which package to load pretrain & charlm from for the parsers')
    parser.add_argument('--output_ptb', default=False, action='store_true',
                        help='Output trees in PTB brackets (default is a bracket language format)')
51
+
52
def add_length_args(parser):
    """
    Add --min_len/--max_len arguments, with flags to disable each limit.

    A limit of None means the corresponding bound is not enforced.
    """
    parser.add_argument('--min_len', default=5, type=int,
                        help='Minimum length sentence to keep. None = unlimited')
    parser.add_argument('--no_min_len', dest='min_len', action='store_const', const=None,
                        help='No minimum length')
    parser.add_argument('--max_len', default=100, type=int,
                        help='Maximum length sentence to keep. None = unlimited')
    parser.add_argument('--no_max_len', dest='max_len', action='store_const', const=None,
                        help='No maximum length')
79
+
80
def build_ssplit_pipe(ssplit, lang):
    """Build a tokenize-only pipeline; sentence splitting is disabled unless ssplit is True."""
    kwargs = {} if ssplit else {"tokenize_no_ssplit": True}
    return stanza.Pipeline(lang, processors="tokenize", **kwargs)
85
+
86
def build_tag_pipe(ssplit, lang, foundation_cache=None):
    """Build a tokenize+pos pipeline; sentence splitting is disabled unless ssplit is True."""
    kwargs = {} if ssplit else {"tokenize_no_ssplit": True}
    return stanza.Pipeline(lang, processors="tokenize,pos", foundation_cache=foundation_cache, **kwargs)
91
+
92
def build_parser_pipes(lang, models, package="default", foundation_cache=None):
    """
    Build separate pipelines for each parser model we want to use

    It is highly recommended to pass in a FoundationCache to reuse bottom layers
    """
    pipes = []
    for model_name in models.split(","):
        if os.path.exists(model_name):
            # the name exists on disk: load that exact model file
            pipe = stanza.Pipeline(lang, processors="constituency", package=package, constituency_model_path=model_name, constituency_pretagged=True, foundation_cache=foundation_cache)
        else:
            # otherwise, assume it is a package name?
            pipe = stanza.Pipeline(lang, processors={"constituency": model_name}, constituency_pretagged=True, package=None, foundation_cache=foundation_cache)
        pipes.append(pipe)
    return pipes
108
+
109
def split_docs(docs, ssplit_pipe, max_len=140, max_word_len=50, chunk_size=2000):
    """
    Using the ssplit pipeline, break up the documents into sentences

    Filters out sentences which are too long or have words too long.

    This step is necessary because some web text has unstructured
    sentences which overwhelm the tagger, or even text with no
    whitespace which breaks the charlm in the tokenizer or tagger
    """
    total_split = 0
    total_kept = 0
    kept_text = []

    logger.info("Splitting raw docs into sentences: %d", len(docs))
    for start in tqdm(range(0, len(docs), chunk_size)):
        batch = [stanza.Document([], text=t) for t in docs[start:start + chunk_size]]
        batch = ssplit_pipe(batch)
        sentences = [sent for doc in batch for sent in doc.sentences]
        total_split += len(sentences)
        # keep only sentences under the length limit whose longest word is
        # also under the word length limit
        sentences = [sent for sent in sentences
                     if len(sent.words) < max_len
                     and max(len(word.text) for word in sent.words) < max_word_len]
        total_kept += len(sentences)
        kept_text.extend(sent.text for sent in sentences)

    logger.info("Split sentences: %d", total_split)
    logger.info("Sentences filtered for length: %d", total_kept)
    return kept_text
138
+
139
# from https://stackoverflow.com/questions/2718196/find-all-chinese-text-in-a-string-using-python-and-regex
ZH_RE = re.compile(u'[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]', re.UNICODE)
# https://stackoverflow.com/questions/6787716/regular-expression-for-japanese-characters
JA_RE = re.compile(u'[一-龠ぁ-ゔァ-ヴー々〆〤ヶ]', re.UNICODE)
DEV_RE = re.compile(u'[\u0900-\u097f]', re.UNICODE)

def tokenize_docs(docs, pipe, min_len, max_len, max_word_len=50):
    """
    Turn the text in docs into a list of whitespace separated sentences

    docs: a list of strings
    pipe: a Stanza pipeline for tokenizing
    min_len, max_len: can be None to not filter by this attribute
    max_word_len: drop sentences containing any word of at least this many
      characters (unreasonably long words are usually junk text)
    """
    results = []
    docs = [stanza.Document([], text=t) for t in docs]
    if len(docs) == 0:
        return results
    pipe(docs)
    is_zh = pipe.lang and pipe.lang.startswith("zh")
    is_ja = pipe.lang and pipe.lang.startswith("ja")
    is_vi = pipe.lang and pipe.lang.startswith("vi")
    for doc in docs:
        for sentence in doc.sentences:
            if min_len and len(sentence.words) < min_len:
                continue
            if max_len and len(sentence.words) > max_len:
                continue
            text = sentence.text
            if (text.find("|") >= 0 or text.find("_") >= 0 or
                text.find("<") >= 0 or text.find(">") >= 0 or
                text.find("[") >= 0 or text.find("]") >= 0 or
                text.find('—') >= 0): # an em dash, seems to be part of lists
                continue
            # the VI tokenizer in particular doesn't split these well
            if any(any(w.text.find(c) >= 0 and len(w.text) > 1 for w in sentence.words)
                   for c in '"()'):
                continue
            text = [w.text.replace(" ", "_") for w in sentence.words]
            text = " ".join(text)
            if any(len(w.text) >= max_word_len for w in sentence.words):
                # skip sentences where some of the words are unreasonably long
                continue
            if not is_zh and len(ZH_RE.findall(text)) > 250:
                # some Chinese sentences show up in VI Wikipedia
                # we want to eliminate ones which will choke the bert models
                continue
            if not is_ja and len(JA_RE.findall(text)) > 150:
                # some Japanese sentences also show up in VI Wikipedia
                # we want to eliminate ones which will choke the bert models
                continue
            if is_vi and len(DEV_RE.findall(text)) > 100:
                # would need some list of languages that use
                # Devanagari to eliminate sentences from all datasets.
                # Otherwise we might accidentally throw away all the
                # text from a language we need (although that would be obvious)
                continue
            results.append(text)
    return results
199
+
200
def find_matching_trees(docs, num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=True, chunk_size=10, max_len=140, min_len=10, output_ptb=False):
    """
    Find trees where all the parsers in parser_pipes agree

    docs should be a list of strings.
    one sentence per string or a whole block of text as long as the tag_pipe can break it into sentences

    num_sentences > 0 gives an upper limit on how many sentences to extract.
    If < 0, all possible sentences are extracted

    accepted_trees is a running tally of all the trees already built,
    so that we don't reuse the same sentence if we see it again
    """
    # the progress bar counts input docs when extracting everything,
    # accepted trees when a limit was requested
    progress_total = len(docs) if num_sentences < 0 else num_sentences
    tree_format = "{}" if output_ptb else "{:L}"

    matching_trees = set()
    with tqdm(total=progress_total, leave=False) as progress:
        if shuffle:
            # NOTE: shuffles the caller's list in place
            random.shuffle(docs)
        for start in range(0, len(docs), chunk_size):
            batch = [stanza.Document([], text=t) for t in docs[start:start + chunk_size]]

            if num_sentences < 0:
                progress.update(len(batch))

            # first, retag the sentences
            tag_pipe(batch)

            batch = [doc for doc in batch if len(doc.sentences) > 0]
            if max_len is not None:
                # for now, we don't have a good way to deal with sentences longer than the bert maxlen
                batch = [doc for doc in batch
                         if max(len(sent.words) for sent in doc.sentences) < max_len]
            if len(batch) == 0:
                continue

            parses = []
            try:
                for pipe in parser_pipes:
                    pipe(batch)
                    trees = [tree_format.format(sent.constituency)
                             for doc in batch for sent in doc.sentences
                             if len(sent.words) >= min_len]
                    parses.append(trees)
            except TextTooLongError:
                # easiest is to skip this chunk - could theoretically save the other sentences
                continue

            for candidates in zip(*parses):
                if len(set(candidates)) != 1:
                    # the parsers disagreed
                    continue
                tree = candidates[0]
                if tree in accepted_trees:
                    continue
                if tree not in matching_trees:
                    matching_trees.add(tree)
                    if num_sentences >= 0:
                        progress.update(1)
                if num_sentences >= 0 and len(matching_trees) >= num_sentences:
                    return matching_trees

    return matching_trees
268
+
stanza/stanza/utils/datasets/constituency/selftrain_it.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Builds a self-training dataset from an Italian data source and two models
2
+
3
+ The idea is that the top down and the inorder parsers should make
4
+ somewhat different errors, so hopefully the sum of an 86 f1 parser and
5
+ an 85.5 f1 parser will produce some half-decent silver trees which can
6
+ be used as self-training so that a new model can do better than either.
7
+
8
+ One dataset used is PaCCSS, which has 63000 pairs of sentences:
9
+
10
+ http://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/
11
+
12
+ PaCCSS-IT: A Parallel Corpus of Complex-Simple Sentences for Automatic Text Simplification
13
+ Brunato, Dominique et al, 2016
14
+ https://aclanthology.org/D16-1034
15
+
16
+ Even larger is the IT section of Europarl, which has 1900000 lines
17
+
18
+ https://www.statmt.org/europarl/
19
+
20
+ Europarl: A Parallel Corpus for Statistical Machine Translation
21
+ Philipp Koehn
22
+ https://homepages.inf.ed.ac.uk/pkoehn/publications/europarl-mtsummit05.pdf
23
+ """
24
+
25
+ import argparse
26
+ import logging
27
+ import os
28
+ import random
29
+
30
+ import stanza
31
+ from stanza.models.common.foundation_cache import FoundationCache
32
+ from stanza.utils.datasets.constituency import selftrain
33
+ from stanza.utils.get_tqdm import get_tqdm
34
+
35
+ tqdm = get_tqdm()
36
+ logger = logging.getLogger('stanza')
37
+
38
def parse_args():
    """Parse command line arguments for the Italian silver tree builder."""
    parser = argparse.ArgumentParser(
        description="Script that converts a public IT dataset to silver standard trees"
    )
    selftrain.common_args(parser)
    parser.add_argument('--input_dir', default='extern_data/italian',
                        help='Path to the PaCCSS corpus and europarl corpus')
    parser.add_argument('--no_europarl', default=True, action='store_false', dest='europarl',
                        help='Use the europarl dataset. Turning this off makes the script a lot faster')

    # Italian-specific defaults for the shared self-training arguments
    parser.set_defaults(lang="it")
    parser.set_defaults(package="vit")
    parser.set_defaults(models="saved_models/constituency/it_best/it_vit_inorder_best.pt,saved_models/constituency/it_best/it_vit_topdown.pt")
    parser.set_defaults(output_file="data/constituency/it_silver.mrg")

    return parser.parse_args()
64
+
65
def get_paccss(input_dir):
    """
    Read the paccss dataset, which is two sentences per line

    Returns a flat list with both sentences of each pair.
    """
    input_file = os.path.join(input_dir, "PaCCSS/data-set/PACCSS-IT.txt")
    # read utf-8 explicitly instead of the platform default encoding - the
    # Italian text is not ascii.  TODO confirm the corpus is distributed as utf-8
    with open(input_file, encoding="utf-8") as fin:
        # the first line is a header line
        lines = fin.readlines()[1:]
    lines = [x.strip() for x in lines]
    lines = [x.split("\t")[:2] for x in lines if x]
    text = [y for x in lines for y in x]
    logger.info("Read %d sentences from %s", len(text), input_file)
    return text
78
+
79
def get_europarl(input_dir, ssplit_pipe):
    """
    Read the Europarl dataset

    This dataset needs to be tokenized and split into lines
    """
    input_file = os.path.join(input_dir, "europarl/europarl-v7.it-en.it")
    # europarl v7 is distributed as utf-8; don't rely on the platform default
    with open(input_file, encoding="utf-8") as fin:
        # NOTE(review): the first line is skipped as a header, but the raw
        # europarl-v7 files have no header line, so this looks copy-pasted
        # from get_paccss and likely drops one real sentence.  Behavior kept
        # for now - confirm against the corpus before changing.
        lines = fin.readlines()[1:]
    lines = [x.strip() for x in lines]
    lines = [x for x in lines if x]
    logger.info("Read %d docs from %s", len(lines), input_file)
    lines = selftrain.split_docs(lines, ssplit_pipe)
    return lines
94
+
95
def main():
    """
    Combine the two datasets, parse them, and write out the results
    """
    args = parse_args()

    foundation_cache = FoundationCache()
    ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang)
    tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang, foundation_cache=foundation_cache)
    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, package=args.package, foundation_cache=foundation_cache)

    docs = get_paccss(args.input_dir)
    if args.europarl:
        docs.extend(get_europarl(args.input_dir, ssplit_pipe))

    logger.info("Processing %d docs", len(docs))
    new_trees = selftrain.find_matching_trees(docs, args.num_sentences, set(), tag_pipe, parser_pipes, shuffle=False, chunk_size=100, output_ptb=args.output_ptb)
    # lazy %-style logging args, matching the other logger calls in this
    # script (the original eagerly formatted with the % operator)
    logger.info("Found %d unique trees which are the same between models", len(new_trees))
    # write utf-8 explicitly - the trees are generally not ascii
    with open(args.output_file, "w", encoding="utf-8") as fout:
        for tree in sorted(new_trees):
            fout.write(tree)
            fout.write("\n")


if __name__ == '__main__':
    main()
+ main()
stanza/stanza/utils/datasets/constituency/selftrain_single_file.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Builds a self-training dataset from a single file.
3
+
4
+ Default is to assume one document of text per line. If a line has
5
+ multiple sentences, they will be split using the stanza tokenizer.
6
+ """
7
+
8
+ import argparse
9
+ import io
10
+ import logging
11
+ import os
12
+
13
+ import numpy as np
14
+
15
+ import stanza
16
+ from stanza.utils.datasets.constituency import selftrain
17
+ from stanza.utils.get_tqdm import get_tqdm
18
+
19
+ logger = logging.getLogger('stanza')
20
+ tqdm = get_tqdm()
21
+
22
def parse_args():
    """
    Only specific argument for this script is the file to process
    """
    parser = argparse.ArgumentParser(
        description="Script that converts a single file of text to silver standard trees"
    )
    selftrain.common_args(parser)
    parser.add_argument('--input_file', default="vi_part_1.aa",
                        help='Path to the file to read')
    return parser.parse_args()
38
+
39
+
40
def read_file(input_file):
    """
    Read lines from an input file

    Takes care to avoid encoding errors at the end of Oscar files.
    The Oscar splits sometimes break a utf-8 character in half.
    """
    with open(input_file, "rb") as fin:
        raw = fin.read()
    # a broken trailing character becomes U+FFFD instead of raising
    decoded = raw.decode("utf-8", errors="replace")
    with io.StringIO(decoded) as buffer:
        return buffer.readlines()
53
+
54
+
55
def main():
    """
    Parse a single text file with multiple parsers, keeping trees they agree on.

    Results are appended to the output file chunk by chunk so that partial
    results are usable while the script is still running.
    """
    args = parse_args()

    # TODO: make ssplit an argument
    ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang)
    tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang)
    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models)

    # create a blank file. we will append to this file so that partial results can be used
    # write utf-8 explicitly - the trees are generally not ascii
    with open(args.output_file, "w", encoding="utf-8") as fout:
        pass

    docs = read_file(args.input_file)
    logger.info("Read %d lines from %s", len(docs), args.input_file)
    docs = selftrain.split_docs(docs, ssplit_pipe)

    # breaking into chunks lets us output partial results and see the
    # progress in log files
    accepted_trees = set()
    if len(docs) > 10000:
        chunks = tqdm(np.array_split(docs, 100), disable=False)
    else:
        chunks = [docs]
    for chunk in chunks:
        new_trees = selftrain.find_matching_trees(chunk, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100)
        accepted_trees.update(new_trees)

        with open(args.output_file, "a", encoding="utf-8") as fout:
            for tree in sorted(new_trees):
                fout.write(tree)
                fout.write("\n")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/selftrain_vi_quad.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processes the train section of VI QuAD into trees suitable for use in the conparser lm
3
+ """
4
+
5
+ import argparse
6
+ import json
7
+ import logging
8
+
9
+ import stanza
10
+ from stanza.utils.datasets.constituency import selftrain
11
+
12
+ logger = logging.getLogger('stanza')
13
+
14
def parse_args():
    """Parse the command line arguments for the ViQuAD silver tree builder."""
    parser = argparse.ArgumentParser(
        description="Script that converts vi quad to silver standard trees"
    )
    selftrain.common_args(parser)
    selftrain.add_length_args(parser)
    parser.add_argument('--input_file',
                        default="extern_data/vietnamese/ViQuAD/train_ViQuAD.json",
                        help='Path to the ViQuAD train file')
    parser.add_argument('--tokenize_only', default=False, action='store_true',
                        help='Tokenize instead of writing trees')
    return parser.parse_args()
34
+
35
def parse_quad(text):
    """
    Read in a file from the VI quad dataset

    The train file has a specific format:
    the doc has a 'data' section
    each block in the data is a separate document (138 in the train file, for example)
    each block has a 'paragraphs' section
    each paragrah has 'qas' and 'context'. we care about the qas
    each piece of qas has 'question', which is what we actually want
    """
    doc = json.loads(text)
    # flatten data -> paragraphs -> qas, keeping only the question text
    return [qa['question']
            for block in doc['data']
            for paragraph in block['paragraphs']
            for qa in paragraph['qas']]
58
+
59
+
60
def read_quad(train_file):
    """
    Read a ViQuAD json file and return the list of questions it contains.
    """
    # json files are utf-8; don't rely on the platform default encoding
    with open(train_file, encoding="utf-8") as fin:
        text = fin.read()

    return parse_quad(text)
65
+
66
def main():
    """
    Turn the train section of VI quad into a list of trees

    With --tokenize_only the questions are only tokenized and written out;
    otherwise they are parsed and only trees all parsers agree on are kept.
    """
    args = parse_args()

    docs = read_quad(args.input_file)
    logger.info("Read %d lines from %s", len(docs), args.input_file)
    if args.tokenize_only:
        pipe = stanza.Pipeline(args.lang, processors="tokenize")
        text = selftrain.tokenize_docs(docs, pipe, args.min_len, args.max_len)
        with open(args.output_file, "w", encoding="utf-8") as fout:
            for line in text:
                fout.write(line)
                fout.write("\n")
    else:
        tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang)
        parser_pipes = selftrain.build_parser_pipes(args.lang, args.models)

        # create a blank file. we will append to this file so that partial results can be used
        # utf-8 explicitly, consistent with the tokenize_only branch above
        with open(args.output_file, "w", encoding="utf-8") as fout:
            pass

        accepted_trees = set()
        new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100)
        # these are questions, so only keep trees parsed as questions
        new_trees = [tree for tree in new_trees if tree.find("(_SQ") >= 0]
        with open(args.output_file, "a", encoding="utf-8") as fout:
            for tree in sorted(new_trees):
                fout.write(tree)
                fout.write("\n")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/selftrain_wiki.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Builds a self-training dataset from an Italian data source and two models
2
+
3
+ The idea is that the top down and the inorder parsers should make
4
+ somewhat different errors, so hopefully the sum of an 86 f1 parser and
5
+ an 85.5 f1 parser will produce some half-decent silver trees which can
6
+ be used as self-training so that a new model can do better than either.
7
+
8
+ The dataset used is PaCCSS, which has 63000 pairs of sentences:
9
+
10
+ http://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/
11
+ """
12
+
13
+ import argparse
14
+ from collections import deque
15
+ import glob
16
+ import os
17
+ import random
18
+
19
+ from stanza.models.common.foundation_cache import FoundationCache
20
+ from stanza.utils.datasets.constituency import selftrain
21
+ from stanza.utils.get_tqdm import get_tqdm
22
+
23
+ tqdm = get_tqdm()
24
+
25
def parse_args():
    """Parse the command line arguments for the wikipedia silver tree builder."""
    parser = argparse.ArgumentParser(
        description="Script that converts part of a wikipedia dump to silver standard trees"
    )
    selftrain.common_args(parser)
    parser.add_argument('--input_dir',
                        default='extern_data/vietnamese/wikipedia/text',
                        help='Path to the wikipedia dump after processing by wikiextractor')
    parser.add_argument('--no_shuffle', dest='shuffle', action='store_false',
                        help="Don't shuffle files when processing the directory")

    parser.set_defaults(num_sentences=10000)

    return parser.parse_args()
46
+
47
def list_wikipedia_files(input_dir):
    """
    Get a list of wiki files under the input_dir

    Recursively traverse the directory, then sort
    """
    # a single wiki_ file may be passed directly instead of a directory
    if not os.path.isdir(input_dir) and os.path.split(input_dir)[1].startswith("wiki_"):
        return [input_dir]

    found = []
    pending = deque(glob.glob(os.path.join(input_dir, "*")))
    while pending:
        candidate = pending.pop()
        if os.path.isdir(candidate):
            pending.extend(glob.glob(os.path.join(candidate, "*")))
        elif os.path.split(candidate)[1].startswith("wiki_"):
            found.append(candidate)

    return sorted(found)
69
+
70
def read_wiki_file(filename):
    """
    Read the text from a wiki file as a list of paragraphs.

    Each <doc> </doc> is its own item in the list.
    Lines are separated by \n\n to give hints to the stanza tokenizer.
    The first line after <doc> is skipped as it is usually the document title.
    """
    with open(filename) as fin:
        lines = fin.readlines()
    docs = []
    current_doc = []
    line_iterator = iter(lines)
    line = next(line_iterator, None)
    while line is not None:
        if line.startswith("<doc"):
            # consume the following line as well - it is usually the title
            line = next(line_iterator, None)
        elif line.startswith("</doc"):
            # a lot of very short documents are links to related documents
            # a single wikipedia can have tens of thousands of useless
            # almost-duplicates, so only keep documents with > 2 lines
            if len(current_doc) > 2:
                docs.append("\n\n".join(current_doc))
            current_doc = []
        else:
            # not the start or end of a doc - hopefully valid text
            cleaned = line.replace("()", " ").replace("( )", " ").strip()
            if cleaned.find("&lt;") >= 0 or cleaned.find("&gt;") >= 0:
                # leftover markup, drop the line
                cleaned = ""
            if cleaned:
                current_doc.append(cleaned)
        line = next(line_iterator, None)

    # a file truncated before its closing </doc> still contributes its text
    if current_doc:
        docs.append("\n\n".join(current_doc))
    return docs
110
+
111
def main():
    """
    Parse wikipedia dump files with multiple parsers, keeping trees they agree on.

    Output is appended file by file so partial results are usable.
    """
    args = parse_args()

    random.seed(1234)

    wiki_files = list_wikipedia_files(args.input_dir)
    if args.shuffle:
        random.shuffle(wiki_files)

    foundation_cache = FoundationCache()
    tag_pipe = selftrain.build_tag_pipe(ssplit=True, lang=args.lang, foundation_cache=foundation_cache)
    parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, foundation_cache=foundation_cache)

    # create a blank file. we will append to this file so that partial results can be used
    # write utf-8 explicitly - the trees are generally not ascii
    with open(args.output_file, "w", encoding="utf-8") as fout:
        pass

    accepted_trees = set()
    for filename in tqdm(wiki_files, disable=False):
        docs = read_wiki_file(filename)
        new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=args.shuffle)
        accepted_trees.update(new_trees)

        with open(args.output_file, "a", encoding="utf-8") as fout:
            for tree in sorted(new_trees):
                fout.write(tree)
                fout.write("\n")

if __name__ == '__main__':
    main()
stanza/stanza/utils/datasets/constituency/split_holdout.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Split a constituency dataset randomly into 90/10 splits
3
+
4
+ TODO: add a function to rotate the pieces of the split so that each
5
+ training instance gets seen once
6
+ """
7
+
8
+ import argparse
9
+ import os
10
+ import random
11
+
12
+ from stanza.models.constituency import tree_reader
13
+ from stanza.utils.datasets.constituency.utils import copy_dev_test
14
+ from stanza.utils.default_paths import get_default_paths
15
+
16
def write_trees(base_path, dataset_name, trees):
    """Write the given trees, one per line, to <base_path>/<dataset_name>_train.mrg."""
    output_path = os.path.join(base_path, "%s_train.mrg" % dataset_name)
    with open(output_path, "w", encoding="utf-8") as fout:
        fout.writelines("%s\n" % tree for tree in trees)
+ fout.write("%s\n" % tree)
21
+
22
+
23
def main():
    """
    Split a dataset's train file into base/holdout pieces at --ratio.

    Dev and test files are copied unchanged to both new datasets.
    """
    parser = argparse.ArgumentParser(description="Split a standard dataset into 90/10 proportions of train so there is held out training data")
    parser.add_argument('--dataset', type=str, default="id_icon", help='dataset to split')
    parser.add_argument('--base_dataset', type=str, default=None, help='output name for base dataset')
    parser.add_argument('--holdout_dataset', type=str, default=None, help='output name for holdout dataset')
    parser.add_argument('--ratio', type=float, default=0.1, help='Number of trees to hold out')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    args = parser.parse_args()

    if args.base_dataset is None:
        args.base_dataset = args.dataset + "-base"
        print("--base_dataset not set, using %s" % args.base_dataset)

    if args.holdout_dataset is None:
        args.holdout_dataset = args.dataset + "-holdout"
        print("--holdout_dataset not set, using %s" % args.holdout_dataset)

    base_path = get_default_paths()["CONSTITUENCY_DATA_DIR"]
    copy_dev_test(base_path, args.dataset, args.base_dataset)
    copy_dev_test(base_path, args.dataset, args.holdout_dataset)

    train_file = os.path.join(base_path, "%s_train.mrg" % args.dataset)
    print("Reading %s" % train_file)
    trees = tree_reader.read_tree_file(train_file)

    random.seed(args.seed)

    base_train = []
    holdout_train = []
    for tree in trees:
        # each tree independently has a --ratio chance of being held out
        target = holdout_train if random.random() < args.ratio else base_train
        target.append(tree)

    write_trees(base_path, args.base_dataset, base_train)
    write_trees(base_path, args.holdout_dataset, holdout_train)

if __name__ == '__main__':
    main()
+ main()
64
+
stanza/stanza/utils/datasets/constituency/split_weighted_ensemble.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Read in a dataset and split the train portion into pieces
3
+
4
+ One chunk of the train will be the original dataset.
5
+
6
+ Others will be a sampling from the original dataset of the same size,
7
+ but sampled with replacement, with the goal being to get a random
8
+ distribution of trees with some reweighting of the original trees.
9
+ """
10
+
11
+ import argparse
12
+ import os
13
+ import random
14
+
15
+ from stanza.models.constituency import tree_reader
16
+ from stanza.models.constituency.parse_tree import Tree
17
+ from stanza.utils.datasets.constituency.utils import copy_dev_test
18
+ from stanza.utils.default_paths import get_default_paths
19
+
20
def main():
    """Produce N training splits: split 0 is the untouched train set, the rest are resamples.

    Each resampled split is drawn with replacement from the original train
    trees, after first seeding it with a couple of trees per constituent
    label so that rare labels cannot be shuffled away entirely.
    """
    parser = argparse.ArgumentParser(description="Split a standard dataset into 1 base section and N-1 random redraws of training data")
    parser.add_argument('--dataset', type=str, default="id_icon", help='dataset to split')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    parser.add_argument('--num_splits', type=int, default=5, help='Number of splits')
    args = parser.parse_args()

    random.seed(args.seed)

    base_path = get_default_paths()["CONSTITUENCY_DATA_DIR"]
    train_file = os.path.join(base_path, "%s_train.mrg" % args.dataset)
    print("Reading %s" % train_file)
    train_trees = tree_reader.read_tree_file(train_file)

    # Rare constituents (eg there are only 3 trees with SQ in id_icon) could
    # vanish from a pure resample, so index the trees by constituent label
    # and later force at least one occurrence of each label into every split.
    # TODO: this doesn't compensate for transition schemes with compound transitions,
    # such as in_order_compound. could do a similar boosting with one per transition type
    constituents = sorted(Tree.get_unique_constituent_labels(train_trees))
    trees_by_label = {label: [] for label in constituents}
    for tree in train_trees:
        for label in Tree.get_unique_constituent_labels(tree):
            trees_by_label[label].append(tree)
    for label in constituents:
        print("%d trees with %s" % (len(trees_by_label[label]), label))

    for split_index in range(args.num_splits):
        dataset_name = "%s-random-%d" % (args.dataset, split_index)

        copy_dev_test(base_path, args.dataset, dataset_name)
        if split_index == 0:
            # the first split is the original training data, untouched
            train_dataset = train_trees
        else:
            # seed with 2 (possibly duplicate) trees per label, then top up
            # with a uniform resample until the original size is reached
            train_dataset = [tree
                             for label in constituents
                             for tree in random.choices(trees_by_label[label], k=2)]
            needed_trees = len(train_trees) - len(train_dataset)
            if needed_trees > 0:
                print("%d trees already chosen. Adding %d more" % (len(train_dataset), needed_trees))
                train_dataset.extend(random.choices(train_trees, k=needed_trees))
        output_filename = os.path.join(base_path, "%s_train.mrg" % dataset_name)
        print("Writing {} trees to {}".format(len(train_dataset), output_filename))
        Tree.write_treebank(train_dataset, output_filename)


if __name__ == '__main__':
    main()
73
+
stanza/stanza/utils/datasets/constituency/tokenize_wiki.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A short script to use a Stanza tokenizer to extract tokenized sentences from Wikipedia
3
+
4
+ The first step is to convert a Wikipedia dataset using Prof. Attardi's wikiextractor:
5
+ https://github.com/attardi/wikiextractor
6
+
7
+ This script then writes out sentences, one per line, whitespace separated.
8
+ Some common issues with the tokenizer are accounted for by discarding those lines.
9
+
10
+ Also, to account for languages such as VI where whitespace occurs within words,
11
+ spaces are replaced with _.  This should not cause any confusion, as any line with
12
+ a natural _ in it has already been discarded.
13
+
14
+ for i in `echo A B C D E F G H I J K`; do nlprun "python3 stanza/utils/datasets/constituency/tokenize_wiki.py --output_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.txt --lang it --max_len 120 --input_dir /u/nlp/data/Wikipedia/itwiki/B$i --tokenizer_model saved_models/tokenize/it_combined_tokenizer.pt --download_method None" -o /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.out; done
15
+ """
16
+
17
+ import argparse
18
+ import logging
19
+
20
+ import stanza
21
+ from stanza.models.common.bert_embedding import load_tokenizer, filter_data
22
+ from stanza.utils.datasets.constituency import selftrain_wiki
23
+ from stanza.utils.datasets.constituency.selftrain import add_length_args, tokenize_docs
24
+ from stanza.utils.get_tqdm import get_tqdm
25
+
26
+ tqdm = get_tqdm()
27
+
28
def parse_args():
    """Build the argument parser for the wiki tokenizer script and parse the command line."""
    parser = argparse.ArgumentParser(description="Script that converts part of a wikipedia dump to silver standard trees")
    parser.add_argument('--output_file', default='vi_wiki_tokenized.txt', help='Where to write the tokenized lines')
    parser.add_argument('--lang', default='vi', help='Which language tools to use for tokenization and POS')

    # exactly one of the two input sources must be supplied
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--input_dir', default=None, help='Path to the wikipedia dump after processing by wikiextractor')
    input_group.add_argument('--input_file', default=None, help='Path to a single file of the wikipedia dump after processing by wikiextractor')

    parser.add_argument('--bert_tokenizer', default=None, help='Which bert tokenizer (if any) to use to filter long sentences')
    parser.add_argument('--tokenizer_model', default=None, help='Use this model instead of the current Stanza tokenizer for this language')
    parser.add_argument('--download_method', default=None, help='Download pipeline models using this method (defaults to downloading updates from HF)')
    # min_len / max_len arguments shared with the other selftrain scripts
    add_length_args(parser)
    return parser.parse_args()
72
+
73
def main():
    """Tokenize wikiextractor output with a Stanza pipeline, one sentence per output line."""
    args = parse_args()

    # collect the list of wikiextractor files to process
    if args.input_dir is not None:
        files = selftrain_wiki.list_wikipedia_files(args.input_dir)
    elif args.input_file is not None:
        files = [args.input_file]
    else:
        raise ValueError("Need to specify at least one file or directory!")

    tokenizer = None
    if args.bert_tokenizer:
        # loaded once up front; used below to drop overly long sentences
        tokenizer = load_tokenizer(args.bert_tokenizer)
        print("Max model length: %d" % tokenizer.model_max_length)

    pipeline_args = {}
    if args.tokenizer_model:
        pipeline_args["tokenize_model_path"] = args.tokenizer_model
    if args.download_method:
        pipeline_args["download_method"] = args.download_method
    pipe = stanza.Pipeline(args.lang, processors="tokenize", **pipeline_args)

    with open(args.output_file, "w", encoding="utf-8") as fout:
        for filename in tqdm(files):
            docs = selftrain_wiki.read_wiki_file(filename)
            sentences = tokenize_docs(docs, pipe, args.min_len, args.max_len)
            if args.bert_tokenizer:
                # remove sentences the bert tokenizer considers too long
                kept = filter_data(args.bert_tokenizer, [sentence.split() for sentence in sentences], tokenizer, logging.DEBUG)
                sentences = [" ".join(words) for words in kept]
            for sentence in sentences:
                fout.write(sentence)
                fout.write("\n")

if __name__ == '__main__':
    main()