E-katrin committed on
Commit
1281049
·
verified ·
1 Parent(s): a9eb0cf

Upload ConlluTokenClassificationPipeline

Browse files
Files changed (2) hide show
  1. config.json +10 -0
  2. pipeline.py +236 -0
config.json CHANGED
@@ -8,6 +8,16 @@
8
  "AutoModel": "modeling_parser.CobaldParser"
9
  },
10
  "consecutive_null_limit": 3,
 
 
 
 
 
 
 
 
 
 
11
  "deepslot_classifier_hidden_size": 256,
12
  "dependency_classifier_hidden_size": 128,
13
  "dropout": 0.1,
 
8
  "AutoModel": "modeling_parser.CobaldParser"
9
  },
10
  "consecutive_null_limit": 3,
11
+ "custom_pipelines": {
12
+ "conllu-parsing": {
13
+ "impl": "pipeline.ConlluTokenClassificationPipeline",
14
+ "pt": [
15
+ "AutoModel"
16
+ ],
17
+ "tf": [],
18
+ "type": "text"
19
+ }
20
+ },
21
  "deepslot_classifier_hidden_size": 256,
22
  "dependency_classifier_hidden_size": 128,
23
  "dropout": 0.1,
pipeline.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from transformers import Pipeline
4
+
5
+ from src.lemmatize_helper import reconstruct_lemma
6
+
7
+
8
+ class ConlluTokenClassificationPipeline(Pipeline):
9
+ def __init__(
10
+ self,
11
+ model,
12
+ tokenizer: callable = None,
13
+ sentenizer: callable = None,
14
+ **kwargs
15
+ ):
16
+ super().__init__(model=model, **kwargs)
17
+ self.tokenizer = tokenizer
18
+ self.sentenizer = sentenizer
19
+
20
+
21
+ def _sanitize_parameters(self, output_format: str = 'list', **kwargs):
22
+ if output_format not in ['list', 'str']:
23
+ raise ValueError(
24
+ f"output_format must be 'str' or 'list', not {output_format}"
25
+ )
26
+ # capture output_format for postprocessing
27
+ return {}, {}, {'output_format': output_format}
28
+
29
+
30
+ def preprocess(self, inputs: str) -> dict:
31
+ if not isinstance(inputs, str):
32
+ raise ValueError("pipeline input must be string (text)")
33
+
34
+ sentences = [sentence for sentence in self.sentenizer(inputs)]
35
+ words = [
36
+ [word for word in self.tokenizer(sentence)]
37
+ for sentence in sentences
38
+ ]
39
+ # stash for later post‐processing
40
+ self._texts = sentences
41
+ return {"words": words}
42
+
43
+ def _forward(self, model_inputs: dict) -> dict:
44
+ return self.model(**model_inputs, inference_mode=True)
45
+
46
+
47
+ def postprocess(self, model_outputs: dict, output_format: str) -> list[dict] | str:
48
+ sentences = self._decode_model_output(model_outputs)
49
+ # Format sentences into CoNLL-U string if requested.
50
+ if output_format == 'str':
51
+ sentences = self._format_as_conllu(sentences)
52
+ return sentences
53
+
54
+ def _decode_model_output(self, model_outputs: dict) -> list[dict]:
55
+ n_sentences = len(model_outputs["words"])
56
+
57
+ sentences_decoded = []
58
+ for i in range(n_sentences):
59
+
60
+ def select_arcs(arcs, batch_idx):
61
+ # Select arcs where batch index == batch_idx
62
+ # Return tensor of shape [n_selected_arcs, 3]
63
+ return arcs[arcs[:, 0] == batch_idx][:, 1:]
64
+
65
+ # Model outputs are padded tensors, so only leave first `n_words` labels.
66
+ n_words = len(model_outputs["words"][i])
67
+
68
+ optional_tags = {}
69
+ if "lemma_rules" in model_outputs:
70
+ optional_tags["lemma_rule_ids"] = model_outputs["lemma_rules"][i, :n_words].tolist()
71
+ if "joint_feats" in model_outputs:
72
+ optional_tags["joint_feats_ids"] = model_outputs["joint_feats"][i, :n_words].tolist()
73
+ if "deps_ud" in model_outputs:
74
+ optional_tags["deps_ud"] = select_arcs(model_outputs["deps_ud"], i).tolist()
75
+ if "deps_eud" in model_outputs:
76
+ optional_tags["deps_eud"] = select_arcs(model_outputs["deps_eud"], i).tolist()
77
+ if "miscs" in model_outputs:
78
+ optional_tags["misc_ids"] = model_outputs["miscs"][i, :n_words].tolist()
79
+ if "deepslots" in model_outputs:
80
+ optional_tags["deepslot_ids"] = model_outputs["deepslots"][i, :n_words].tolist()
81
+ if "semclasses" in model_outputs:
82
+ optional_tags["semclass_ids"] = model_outputs["semclasses"][i, :n_words].tolist()
83
+
84
+ sentence_decoded = self._decode_sentence(
85
+ text=self._texts[i],
86
+ words=model_outputs["words"][i],
87
+ **optional_tags,
88
+ )
89
+ sentences_decoded.append(sentence_decoded)
90
+ return sentences_decoded
91
+
92
+ def _decode_sentence(
93
+ self,
94
+ text: str,
95
+ words: list[str],
96
+ lemma_rule_ids: list[int] = None,
97
+ joint_feats_ids: list[int] = None,
98
+ deps_ud: list[list[int]] = None,
99
+ deps_eud: list[list[int]] = None,
100
+ misc_ids: list[int] = None,
101
+ deepslot_ids: list[int] = None,
102
+ semclass_ids: list[int] = None
103
+ ) -> dict:
104
+
105
+ # Enumerate words in the sentence, starting from 1.
106
+ ids = self._enumerate_words(words)
107
+
108
+ result = {
109
+ "text": text,
110
+ "words": words,
111
+ "ids": ids
112
+ }
113
+
114
+ # Decode lemmas.
115
+ if lemma_rule_ids:
116
+ result["lemmas"] = [
117
+ reconstruct_lemma(
118
+ word,
119
+ self.model.config.vocabulary["lemma_rule"][lemma_rule_id]
120
+ )
121
+ for word, lemma_rule_id in zip(words, lemma_rule_ids, strict=True)
122
+ ]
123
+ # Decode POS and features.
124
+ if joint_feats_ids:
125
+ upos, xpos, feats = zip(
126
+ *[
127
+ self.model.config.vocabulary["joint_feats"][joint_feats_id].split('#')
128
+ for joint_feats_id in joint_feats_ids
129
+ ],
130
+ strict=True
131
+ )
132
+ result["upos"] = list(upos)
133
+ result["xpos"] = list(xpos)
134
+ result["feats"] = list(feats)
135
+ # Decode syntax.
136
+ renumerate_and_decode_arcs = lambda arcs, id2rel: [
137
+ (
138
+ # ids stores inverse mapping from internal numeration to the standard
139
+ # conllu numeration, so simply use ids[internal_idx] to retrieve token id
140
+ # from internal index.
141
+ ids[arc_from] if arc_from != arc_to else '0',
142
+ ids[arc_to],
143
+ id2rel[deprel_id]
144
+ )
145
+ for arc_from, arc_to, deprel_id in arcs
146
+ ]
147
+ if deps_ud:
148
+ result["deps_ud"] = renumerate_and_decode_arcs(
149
+ deps_ud,
150
+ self.model.config.vocabulary["ud_deprel"]
151
+ )
152
+ if deps_eud:
153
+ result["deps_eud"] = renumerate_and_decode_arcs(
154
+ deps_eud,
155
+ self.model.config.vocabulary["eud_deprel"]
156
+ )
157
+ # Decode misc.
158
+ if misc_ids:
159
+ result["miscs"] = [
160
+ self.model.config.vocabulary["misc"][misc_id]
161
+ for misc_id in misc_ids
162
+ ]
163
+ # Decode semantics.
164
+ if deepslot_ids:
165
+ result["deepslots"] = [
166
+ self.model.config.vocabulary["deepslot"][deepslot_id]
167
+ for deepslot_id in deepslot_ids
168
+ ]
169
+ if semclass_ids:
170
+ result["semclasses"] = [
171
+ self.model.config.vocabulary["semclass"][semclass_id]
172
+ for semclass_id in semclass_ids
173
+ ]
174
+ return result
175
+
176
+ @staticmethod
177
+ def _enumerate_words(words: list[str]) -> list[str]:
178
+ ids = []
179
+ current_id = 0
180
+ current_null_count = 0
181
+ for word in words:
182
+ if word == "#NULL":
183
+ current_null_count += 1
184
+ ids.append(f"{current_id}.{current_null_count}")
185
+ else:
186
+ current_id += 1
187
+ current_null_count = 0
188
+ ids.append(f"{current_id}")
189
+ return ids
190
+
191
+ @staticmethod
192
+ def _format_as_conllu(sentences: list[dict]) -> str:
193
+ """
194
+ Format a list of sentence dicts into a CoNLL-U formatted string.
195
+ """
196
+ formatted = []
197
+ for sentence in sentences:
198
+ # The first line is a text matadata.
199
+ lines = [f"# text = {sentence['text']}"]
200
+
201
+ id2idx = {token_id: idx for idx, token_id in enumerate(sentence['ids'])}
202
+
203
+ # Basic syntax.
204
+ heads = [''] * len(id2idx)
205
+ deprels = [''] * len(id2idx)
206
+ if "deps_ud" in sentence:
207
+ for arc_from, arc_to, deprel in sentence['deps_ud']:
208
+ token_idx = id2idx[arc_to]
209
+ heads[token_idx] = arc_from
210
+ deprels[token_idx] = deprel
211
+
212
+ # Enhanced syntax.
213
+ deps_dicts = [{} for _ in range(len(id2idx))]
214
+ if "deps_eud" in sentence:
215
+ for arc_from, arc_to, deprel in sentence['deps_eud']:
216
+ token_idx = id2idx[arc_to]
217
+ deps_dicts[token_idx][arc_from] = deprel
218
+
219
+ for idx, token_id in enumerate(sentence['ids']):
220
+ word = sentence['words'][idx]
221
+ lemma = sentence['lemmas'][idx] if "lemmas" in sentence else ''
222
+ upos = sentence['upos'][idx] if "upos" in sentence else ''
223
+ xpos = sentence['xpos'][idx] if "xpos" in sentence else ''
224
+ feats = sentence['feats'][idx] if "feats" in sentence else ''
225
+ deps = '|'.join(f"{head}:{rel}" for head, rel in deps_dicts[idx].items()) or '_'
226
+ misc = sentence['miscs'][idx] if "miscs" in sentence else ''
227
+ deepslot = sentence['deepslots'][idx] if "deepslots" in sentence else ''
228
+ semclass = sentence['semclasses'][idx] if "semclasses" in sentence else ''
229
+ # CoNLL-U columns
230
+ line = '\t'.join([
231
+ token_id, word, lemma, upos, xpos, feats, heads[idx],
232
+ deprels[idx], deps, misc, deepslot, semclass
233
+ ])
234
+ lines.append(line)
235
+ formatted.append('\n'.join(lines))
236
+ return '\n\n'.join(formatted)