E-katrin committed on
Commit
0457841
·
verified ·
1 Parent(s): 319c66a

Upload ConlluTokenClassificationPipeline

Browse files
Files changed (2) hide show
  1. config.json +10 -0
  2. pipeline.py +237 -0
config.json CHANGED
@@ -8,6 +8,16 @@
8
  "AutoModel": "modeling_parser.CobaldParser"
9
  },
10
  "consecutive_null_limit": 3,
 
 
 
 
 
 
 
 
 
 
11
  "deepslot_classifier_hidden_size": 256,
12
  "dependency_classifier_hidden_size": 128,
13
  "dropout": 0.1,
 
8
  "AutoModel": "modeling_parser.CobaldParser"
9
  },
10
  "consecutive_null_limit": 3,
11
+ "custom_pipelines": {
12
+ "conllu-parsing": {
13
+ "impl": "pipeline.ConlluTokenClassificationPipeline",
14
+ "pt": [
15
+ "AutoModel"
16
+ ],
17
+ "tf": [],
18
+ "type": "text"
19
+ }
20
+ },
21
  "deepslot_classifier_hidden_size": 256,
22
  "dependency_classifier_hidden_size": 128,
23
  "dropout": 0.1,
pipeline.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
from collections.abc import Callable

from transformers import Pipeline

from src.lemmatize_helper import reconstruct_lemma
6
+
7
+
class ConlluTokenClassificationPipeline(Pipeline):
    """Transformers pipeline that parses raw text into CoNLL-U style annotations.

    The pipeline splits text into sentences (``sentenizer``) and words
    (``tokenizer``), runs the parser model, and decodes its outputs into
    per-sentence dicts or a CoNLL-U formatted string.
    """

    def __init__(
        self,
        model,
        tokenizer: Callable | None = None,
        sentenizer: Callable | None = None,
        **kwargs
    ):
        """Create the pipeline.

        Args:
            model: the parser model (forwarded to ``Pipeline.__init__``).
            tokenizer: callable mapping a sentence string to an iterable of words.
            sentenizer: callable mapping a text string to an iterable of sentences.
            **kwargs: passed through to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        # NOTE(review): Pipeline.__init__ may also set self.tokenizer; we
        # deliberately overwrite it with the user-supplied callable here.
        self.tokenizer = tokenizer
        self.sentenizer = sentenizer
21
+ def _sanitize_parameters(self, output_format: str = 'list', **kwargs):
22
+ if output_format not in ['list', 'str']:
23
+ raise ValueError(
24
+ f"output_format must be 'str' or 'list', not {output_format}"
25
+ )
26
+ # capture output_format for postprocessing
27
+ return {}, {}, {'output_format': output_format}
28
+
29
+
30
+ def preprocess(self, inputs: str) -> dict:
31
+ if not isinstance(inputs, str):
32
+ raise ValueError("pipeline input must be string (text)")
33
+
34
+ sentences = [sentence for sentence in self.sentenizer(inputs)]
35
+ words = [
36
+ [word for word in self.tokenizer(sentence)]
37
+ for sentence in sentences
38
+ ]
39
+ # stash for later post‐processing
40
+ self._texts = sentences
41
+ return {"words": words}
42
+
43
+
    def _forward(self, model_inputs: dict) -> dict:
        # Run the parser on the preprocessed batch. `inference_mode=True` is
        # passed to the model call — NOTE(review): its exact semantics come
        # from the model's forward (presumably "decode predictions instead of
        # computing a loss"); confirm against CobaldParser.
        return self.model(**model_inputs, inference_mode=True)
48
+ def postprocess(self, model_outputs: dict, output_format: str) -> list[dict] | str:
49
+ sentences = self._decode_model_output(model_outputs)
50
+ # Format sentences into CoNLL-U string if requested.
51
+ if output_format == 'str':
52
+ sentences = self._format_as_conllu(sentences)
53
+ return sentences
54
+
55
+ def _decode_model_output(self, model_outputs: dict) -> list[dict]:
56
+ n_sentences = len(model_outputs["words"])
57
+
58
+ sentences_decoded = []
59
+ for i in range(n_sentences):
60
+
61
+ def select_arcs(arcs, batch_idx):
62
+ # Select arcs where batch index == batch_idx
63
+ # Return tensor of shape [n_selected_arcs, 3]
64
+ return arcs[arcs[:, 0] == batch_idx][:, 1:]
65
+
66
+ # Model outputs are padded tensors, so only leave first `n_words` labels.
67
+ n_words = len(model_outputs["words"][i])
68
+
69
+ optional_tags = {}
70
+ if "lemma_rules" in model_outputs:
71
+ optional_tags["lemma_rule_ids"] = model_outputs["lemma_rules"][i, :n_words].tolist()
72
+ if "joint_feats" in model_outputs:
73
+ optional_tags["joint_feats_ids"] = model_outputs["joint_feats"][i, :n_words].tolist()
74
+ if "deps_ud" in model_outputs:
75
+ optional_tags["deps_ud"] = select_arcs(model_outputs["deps_ud"], i).tolist()
76
+ if "deps_eud" in model_outputs:
77
+ optional_tags["deps_eud"] = select_arcs(model_outputs["deps_eud"], i).tolist()
78
+ if "miscs" in model_outputs:
79
+ optional_tags["misc_ids"] = model_outputs["miscs"][i, :n_words].tolist()
80
+ if "deepslots" in model_outputs:
81
+ optional_tags["deepslot_ids"] = model_outputs["deepslots"][i, :n_words].tolist()
82
+ if "semclasses" in model_outputs:
83
+ optional_tags["semclass_ids"] = model_outputs["semclasses"][i, :n_words].tolist()
84
+
85
+ sentence_decoded = self._decode_sentence(
86
+ text=self._texts[i],
87
+ words=model_outputs["words"][i],
88
+ **optional_tags,
89
+ )
90
+ sentences_decoded.append(sentence_decoded)
91
+ return sentences_decoded
92
+
93
+ def _decode_sentence(
94
+ self,
95
+ text: str,
96
+ words: list[str],
97
+ lemma_rule_ids: list[int] = None,
98
+ joint_feats_ids: list[int] = None,
99
+ deps_ud: list[list[int]] = None,
100
+ deps_eud: list[list[int]] = None,
101
+ misc_ids: list[int] = None,
102
+ deepslot_ids: list[int] = None,
103
+ semclass_ids: list[int] = None
104
+ ) -> dict:
105
+
106
+ # Enumerate words in the sentence, starting from 1.
107
+ ids = self._enumerate_words(words)
108
+
109
+ result = {
110
+ "text": text,
111
+ "words": words,
112
+ "ids": ids
113
+ }
114
+
115
+ # Decode lemmas.
116
+ if lemma_rule_ids:
117
+ result["lemmas"] = [
118
+ reconstruct_lemma(
119
+ word,
120
+ self.model.config.vocabulary["lemma_rule"][lemma_rule_id]
121
+ )
122
+ for word, lemma_rule_id in zip(words, lemma_rule_ids, strict=True)
123
+ ]
124
+ # Decode POS and features.
125
+ if joint_feats_ids:
126
+ upos, xpos, feats = zip(
127
+ *[
128
+ self.model.config.vocabulary["joint_feats"][joint_feats_id].split('#')
129
+ for joint_feats_id in joint_feats_ids
130
+ ],
131
+ strict=True
132
+ )
133
+ result["upos"] = list(upos)
134
+ result["xpos"] = list(xpos)
135
+ result["feats"] = list(feats)
136
+ # Decode syntax.
137
+ renumerate_and_decode_arcs = lambda arcs, id2rel: [
138
+ (
139
+ # ids stores inverse mapping from internal numeration to the standard
140
+ # conllu numeration, so simply use ids[internal_idx] to retrieve token id
141
+ # from internal index.
142
+ ids[arc_from] if arc_from != arc_to else '0',
143
+ ids[arc_to],
144
+ id2rel[deprel_id]
145
+ )
146
+ for arc_from, arc_to, deprel_id in arcs
147
+ ]
148
+ if deps_ud:
149
+ result["deps_ud"] = renumerate_and_decode_arcs(
150
+ deps_ud,
151
+ self.model.config.vocabulary["ud_deprel"]
152
+ )
153
+ if deps_eud:
154
+ result["deps_eud"] = renumerate_and_decode_arcs(
155
+ deps_eud,
156
+ self.model.config.vocabulary["eud_deprel"]
157
+ )
158
+ # Decode misc.
159
+ if misc_ids:
160
+ result["miscs"] = [
161
+ self.model.config.vocabulary["misc"][misc_id]
162
+ for misc_id in misc_ids
163
+ ]
164
+ # Decode semantics.
165
+ if deepslot_ids:
166
+ result["deepslots"] = [
167
+ self.model.config.vocabulary["deepslot"][deepslot_id]
168
+ for deepslot_id in deepslot_ids
169
+ ]
170
+ if semclass_ids:
171
+ result["semclasses"] = [
172
+ self.model.config.vocabulary["semclass"][semclass_id]
173
+ for semclass_id in semclass_ids
174
+ ]
175
+ return result
176
+
177
+ @staticmethod
178
+ def _enumerate_words(words: list[str]) -> list[str]:
179
+ ids = []
180
+ current_id = 0
181
+ current_null_count = 0
182
+ for word in words:
183
+ if word == "#NULL":
184
+ current_null_count += 1
185
+ ids.append(f"{current_id}.{current_null_count}")
186
+ else:
187
+ current_id += 1
188
+ current_null_count = 0
189
+ ids.append(f"{current_id}")
190
+ return ids
191
+
192
+ @staticmethod
193
+ def _format_as_conllu(sentences: list[dict]) -> str:
194
+ """
195
+ Format a list of sentence dicts into a CoNLL-U formatted string.
196
+ """
197
+ formatted = []
198
+ for sentence in sentences:
199
+ # The first line is a text matadata.
200
+ lines = [f"# text = {sentence['text']}"]
201
+
202
+ id2idx = {token_id: idx for idx, token_id in enumerate(sentence['ids'])}
203
+
204
+ # Basic syntax.
205
+ heads = [''] * len(id2idx)
206
+ deprels = [''] * len(id2idx)
207
+ if "deps_ud" in sentence:
208
+ for arc_from, arc_to, deprel in sentence['deps_ud']:
209
+ token_idx = id2idx[arc_to]
210
+ heads[token_idx] = arc_from
211
+ deprels[token_idx] = deprel
212
+
213
+ # Enhanced syntax.
214
+ deps_dicts = [{} for _ in range(len(id2idx))]
215
+ if "deps_eud" in sentence:
216
+ for arc_from, arc_to, deprel in sentence['deps_eud']:
217
+ token_idx = id2idx[arc_to]
218
+ deps_dicts[token_idx][arc_from] = deprel
219
+
220
+ for idx, token_id in enumerate(sentence['ids']):
221
+ word = sentence['words'][idx]
222
+ lemma = sentence['lemmas'][idx] if "lemmas" in sentence else ''
223
+ upos = sentence['upos'][idx] if "upos" in sentence else ''
224
+ xpos = sentence['xpos'][idx] if "xpos" in sentence else ''
225
+ feats = sentence['feats'][idx] if "feats" in sentence else ''
226
+ deps = '|'.join(f"{head}:{rel}" for head, rel in deps_dicts[idx].items()) or '_'
227
+ misc = sentence['miscs'][idx] if "miscs" in sentence else ''
228
+ deepslot = sentence['deepslots'][idx] if "deepslots" in sentence else ''
229
+ semclass = sentence['semclasses'][idx] if "semclasses" in sentence else ''
230
+ # CoNLL-U columns
231
+ line = '\t'.join([
232
+ token_id, word, lemma, upos, xpos, feats, heads[idx],
233
+ deprels[idx], deps, misc, deepslot, semclass
234
+ ])
235
+ lines.append(line)
236
+ formatted.append('\n'.join(lines))
237
+ return '\n\n'.join(formatted)