fortvivlan committed
Commit 83f244f · verified · 1 Parent(s): a60093b

Model save

Files changed (10)
  1. README.md +57 -0
  2. config.json +512 -0
  3. configuration.py +48 -0
  4. dependency_classifier.py +299 -0
  5. encoder.py +109 -0
  6. mlp_classifier.py +46 -0
  7. model.safetensors +3 -0
  8. modeling_parser.py +190 -0
  9. training_args.bin +3 -0
  10. utils.py +66 -0
README.md ADDED
@@ -0,0 +1,57 @@
+ ---
+ base_model: xlm-roberta-base
+ datasets: CoBaLD/enhanced-ud-syntax
+ language: en
+ library_name: transformers
+ license: gpl-3.0
+ metrics:
+ - accuracy
+ - f1
+ pipeline_tag: cobald-parsing
+ tags:
+ - pytorch
+ model-index:
+ - name: CoBaLD/cobald-parser-pretrain-en
+   results:
+   - task:
+       type: token-classification
+     dataset:
+       name: enhanced-ud-syntax
+       type: CoBaLD/enhanced-ud-syntax
+       split: validation
+     metrics:
+     - type: f1
+       value: 0.2499754084433992
+       name: Null F1
+     - type: accuracy
+       value: 0.6778357854769815
+       name: Ud Jaccard
+     - type: accuracy
+       value: 0.74552842927079
+       name: Eud Jaccard
+ ---
+
+ # Model Card for cobald-parser-pretrain-en
+
+ A transformer-based multihead parser for CoBaLD annotation.
+
+ This model parses pre-tokenized CoNLL-U text and jointly labels each token with three tiers of tags:
+ * Grammatical tags (lemma, UPOS, XPOS, morphological features),
+ * Syntactic tags (basic and enhanced Universal Dependencies),
+ * Semantic tags (deep slot and semantic class).
+
+ ## Model Sources
+
+ - **Repository:** https://github.com/CobaldAnnotation/CobaldParser
+ - **Paper:** https://dialogue-conf.org/wp-content/uploads/2025/04/BaiukIBaiukAPetrovaM.009.pdf
+ - **Demo:** [coming soon]
+
+ ## Citation
+
+ @inproceedings{baiuk2025cobald,
+   title={CoBaLD Parser: Joint Morphosyntactic and Semantic Annotation},
+   author={Baiuk, Ilia and Baiuk, Alexandra and Petrova, Maria},
+   booktitle={Proceedings of the International Conference "Dialogue"},
+   volume={I},
+   year={2025}
+ }
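The card above does not yet include a usage snippet. As a hedged sketch only: the config registers a custom `cobald-parsing` pipeline, so loading with `trust_remote_code=True` should expose it. The exact input/output format is defined by `pipeline.ConlluTokenClassificationPipeline`, which is not part of this commit.

```python
from transformers import pipeline

# Hypothetical usage; trust_remote_code is required so that the custom
# "cobald-parsing" pipeline and CobaldParser classes are loaded from the repo.
parser = pipeline(
    "cobald-parsing",
    model="CoBaLD/cobald-parser-pretrain-en",
    trust_remote_code=True,
)

# The parser works on pre-tokenized, CoNLL-U style input; see the repository
# for the exact expected format.
result = parser("The quick brown fox jumps over the lazy dog .")
```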
config.json ADDED
@@ -0,0 +1,512 @@
1
+ {
2
+ "activation": "relu",
3
+ "architectures": [
4
+ "CobaldParser"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration.CobaldParserConfig",
8
+ "AutoModel": "modeling_parser.CobaldParser"
9
+ },
10
+ "consecutive_null_limit": 3,
11
+ "custom_pipelines": {
12
+ "cobald-parsing": {
13
+ "impl": "pipeline.ConlluTokenClassificationPipeline",
14
+ "pt": "CobaldParser"
15
+ }
16
+ },
17
+ "deepslot_classifier_hidden_size": 256,
18
+ "dependency_classifier_hidden_size": 128,
19
+ "dropout": 0.1,
20
+ "encoder_model_name": "xlm-roberta-base",
21
+ "lemma_classifier_hidden_size": 512,
22
+ "misc_classifier_hidden_size": 512,
23
+ "model_type": "cobald_parser",
24
+ "morphology_classifier_hidden_size": 512,
25
+ "null_classifier_hidden_size": 512,
26
+ "semclass_classifier_hidden_size": 512,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.51.3",
29
+ "vocabulary": {
30
+ "eud_deprel": {
31
+ "0": "acl",
32
+ "1": "acl:about",
33
+ "2": "acl:about_whether",
34
+ "3": "acl:after",
35
+ "4": "acl:against",
36
+ "5": "acl:as",
37
+ "6": "acl:as_if",
38
+ "7": "acl:as_to",
39
+ "8": "acl:at",
40
+ "9": "acl:before",
41
+ "10": "acl:behind",
42
+ "11": "acl:between",
43
+ "12": "acl:beyond",
44
+ "13": "acl:but",
45
+ "14": "acl:but_to",
46
+ "15": "acl:concerning",
47
+ "16": "acl:except_that",
48
+ "17": "acl:for",
49
+ "18": "acl:for_to",
50
+ "19": "acl:from",
51
+ "20": "acl:if",
52
+ "21": "acl:in",
53
+ "22": "acl:including",
54
+ "23": "acl:including_whether",
55
+ "24": "acl:inside",
56
+ "25": "acl:instead_of",
57
+ "26": "acl:into",
58
+ "27": "acl:like",
59
+ "28": "acl:of",
60
+ "29": "acl:of_if",
61
+ "30": "acl:of_why",
62
+ "31": "acl:on",
63
+ "32": "acl:once",
64
+ "33": "acl:over",
65
+ "34": "acl:prior_to",
66
+ "35": "acl:regarding",
67
+ "36": "acl:relcl",
68
+ "37": "acl:relcl:to",
69
+ "38": "acl:since",
70
+ "39": "acl:such_as",
71
+ "40": "acl:than",
72
+ "41": "acl:that",
73
+ "42": "acl:though",
74
+ "43": "acl:to",
75
+ "44": "acl:toward",
76
+ "45": "acl:towards",
77
+ "46": "acl:under",
78
+ "47": "acl:until",
79
+ "48": "acl:upon",
80
+ "49": "acl:when",
81
+ "50": "acl:where",
82
+ "51": "acl:whether",
83
+ "52": "acl:why",
84
+ "53": "acl:with",
85
+ "54": "advcl",
86
+ "55": "advcl:about",
87
+ "56": "advcl:about_whether",
88
+ "57": "advcl:after",
89
+ "58": "advcl:against",
90
+ "59": "advcl:albeit",
91
+ "60": "advcl:along_with",
92
+ "61": "advcl:although",
93
+ "62": "advcl:as",
94
+ "63": "advcl:as_if",
95
+ "64": "advcl:as_in",
96
+ "65": "advcl:as_though",
97
+ "66": "advcl:as_to",
98
+ "67": "advcl:as_well_as",
99
+ "68": "advcl:as_with",
100
+ "69": "advcl:at",
101
+ "70": "advcl:because",
102
+ "71": "advcl:before",
103
+ "72": "advcl:behind",
104
+ "73": "advcl:besides",
105
+ "74": "advcl:between",
106
+ "75": "advcl:beyond",
107
+ "76": "advcl:but",
108
+ "77": "advcl:by",
109
+ "78": "advcl:cause",
110
+ "79": "advcl:despite",
111
+ "80": "advcl:due_to",
112
+ "81": "advcl:except",
113
+ "82": "advcl:except_for",
114
+ "83": "advcl:except_that",
115
+ "84": "advcl:for",
116
+ "85": "advcl:for_if",
117
+ "86": "advcl:for_to",
118
+ "87": "advcl:from",
119
+ "88": "advcl:given",
120
+ "89": "advcl:if",
121
+ "90": "advcl:if_to",
122
+ "91": "advcl:in",
123
+ "92": "advcl:in_between",
124
+ "93": "advcl:in_case",
125
+ "94": "advcl:in_order",
126
+ "95": "advcl:in_order_for",
127
+ "96": "advcl:in_order_to",
128
+ "97": "advcl:in_that",
129
+ "98": "advcl:including_by",
130
+ "99": "advcl:inside",
131
+ "100": "advcl:insofar_as",
132
+ "101": "advcl:instead_of",
133
+ "102": "advcl:into",
134
+ "103": "advcl:lest",
135
+ "104": "advcl:like",
136
+ "105": "advcl:of",
137
+ "106": "advcl:of_whether",
138
+ "107": "advcl:on",
139
+ "108": "advcl:on_whether",
140
+ "109": "advcl:once",
141
+ "110": "advcl:out",
142
+ "111": "advcl:over",
143
+ "112": "advcl:past",
144
+ "113": "advcl:prior_to",
145
+ "114": "advcl:provided",
146
+ "115": "advcl:rather_than",
147
+ "116": "advcl:relcl",
148
+ "117": "advcl:relcl:because",
149
+ "118": "advcl:since",
150
+ "119": "advcl:so",
151
+ "120": "advcl:so_as_to",
152
+ "121": "advcl:so_that",
153
+ "122": "advcl:such_as",
154
+ "123": "advcl:than",
155
+ "124": "advcl:than_if",
156
+ "125": "advcl:that",
157
+ "126": "advcl:the",
158
+ "127": "advcl:though",
159
+ "128": "advcl:through",
160
+ "129": "advcl:till",
161
+ "130": "advcl:to",
162
+ "131": "advcl:toward",
163
+ "132": "advcl:towards",
164
+ "133": "advcl:under",
165
+ "134": "advcl:unless",
166
+ "135": "advcl:until",
167
+ "136": "advcl:upon",
168
+ "137": "advcl:when",
169
+ "138": "advcl:where",
170
+ "139": "advcl:whereas",
171
+ "140": "advcl:whether",
172
+ "141": "advcl:while",
173
+ "142": "advcl:whilst",
174
+ "143": "advcl:whither",
175
+ "144": "advcl:with",
176
+ "145": "advcl:without",
177
+ "146": "advmod",
178
+ "147": "amod",
179
+ "148": "appos",
180
+ "149": "aux",
181
+ "150": "aux:pass",
182
+ "151": "case",
183
+ "152": "case:of",
184
+ "153": "cc",
185
+ "154": "cc:preconj",
186
+ "155": "ccomp",
187
+ "156": "compound",
188
+ "157": "compound:prt",
189
+ "158": "conj",
190
+ "159": "conj:and",
191
+ "160": "conj:and_or",
192
+ "161": "conj:and_yet",
193
+ "162": "conj:as_well_as",
194
+ "163": "conj:but",
195
+ "164": "conj:et",
196
+ "165": "conj:for",
197
+ "166": "conj:let_alone",
198
+ "167": "conj:minus",
199
+ "168": "conj:nor",
200
+ "169": "conj:not",
201
+ "170": "conj:not_to_mention",
202
+ "171": "conj:or",
203
+ "172": "conj:plus",
204
+ "173": "conj:plus_minus",
205
+ "174": "conj:rather_than",
206
+ "175": "conj:slash",
207
+ "176": "conj:though",
208
+ "177": "conj:yet",
209
+ "178": "cop",
210
+ "179": "csubj",
211
+ "180": "csubj:outer",
212
+ "181": "csubj:pass",
213
+ "182": "csubj:xsubj",
214
+ "183": "dep",
215
+ "184": "det",
216
+ "185": "det:predet",
217
+ "186": "discourse",
218
+ "187": "dislocated",
219
+ "188": "expl",
220
+ "189": "fixed",
221
+ "190": "flat",
222
+ "191": "goeswith",
223
+ "192": "iobj",
224
+ "193": "list",
225
+ "194": "mark",
226
+ "195": "nmod",
227
+ "196": "nmod:a_la",
228
+ "197": "nmod:aboard",
229
+ "198": "nmod:about",
230
+ "199": "nmod:above",
231
+ "200": "nmod:according_to",
232
+ "201": "nmod:across",
233
+ "202": "nmod:after",
234
+ "203": "nmod:against",
235
+ "204": "nmod:along",
236
+ "205": "nmod:alongside",
237
+ "206": "nmod:amidst",
238
+ "207": "nmod:among",
239
+ "208": "nmod:amongst",
240
+ "209": "nmod:around",
241
+ "210": "nmod:as",
242
+ "211": "nmod:as_for",
243
+ "212": "nmod:as_in",
244
+ "213": "nmod:as_opposed_to",
245
+ "214": "nmod:as_to",
246
+ "215": "nmod:astride",
247
+ "216": "nmod:at",
248
+ "217": "nmod:atop",
249
+ "218": "nmod:barring",
250
+ "219": "nmod:because_of",
251
+ "220": "nmod:before",
252
+ "221": "nmod:behind",
253
+ "222": "nmod:below",
254
+ "223": "nmod:besides",
255
+ "224": "nmod:between",
256
+ "225": "nmod:beyond",
257
+ "226": "nmod:but",
258
+ "227": "nmod:by",
259
+ "228": "nmod:circa",
260
+ "229": "nmod:colon",
261
+ "230": "nmod:concerning",
262
+ "231": "nmod:desc",
263
+ "232": "nmod:despite",
264
+ "233": "nmod:down",
265
+ "234": "nmod:due_to",
266
+ "235": "nmod:during",
267
+ "236": "nmod:except",
268
+ "237": "nmod:except_for",
269
+ "238": "nmod:excluding",
270
+ "239": "nmod:following",
271
+ "240": "nmod:for",
272
+ "241": "nmod:from",
273
+ "242": "nmod:from_across",
274
+ "243": "nmod:from_below",
275
+ "244": "nmod:from_outside",
276
+ "245": "nmod:from_over",
277
+ "246": "nmod:in",
278
+ "247": "nmod:including",
279
+ "248": "nmod:inside",
280
+ "249": "nmod:instead_of",
281
+ "250": "nmod:into",
282
+ "251": "nmod:like",
283
+ "252": "nmod:minus",
284
+ "253": "nmod:near",
285
+ "254": "nmod:next_to",
286
+ "255": "nmod:of",
287
+ "256": "nmod:off",
288
+ "257": "nmod:on",
289
+ "258": "nmod:onto",
290
+ "259": "nmod:opposite",
291
+ "260": "nmod:out",
292
+ "261": "nmod:out_of",
293
+ "262": "nmod:outside",
294
+ "263": "nmod:over",
295
+ "264": "nmod:past",
296
+ "265": "nmod:per",
297
+ "266": "nmod:poss",
298
+ "267": "nmod:post",
299
+ "268": "nmod:prior_to",
300
+ "269": "nmod:pro",
301
+ "270": "nmod:re",
302
+ "271": "nmod:regarding",
303
+ "272": "nmod:round",
304
+ "273": "nmod:save",
305
+ "274": "nmod:since",
306
+ "275": "nmod:slash",
307
+ "276": "nmod:such_as",
308
+ "277": "nmod:than",
309
+ "278": "nmod:through",
310
+ "279": "nmod:throughout",
311
+ "280": "nmod:thru",
312
+ "281": "nmod:times",
313
+ "282": "nmod:to",
314
+ "283": "nmod:toward",
315
+ "284": "nmod:towards",
316
+ "285": "nmod:under",
317
+ "286": "nmod:unlike",
318
+ "287": "nmod:unmarked",
319
+ "288": "nmod:until",
320
+ "289": "nmod:up",
321
+ "290": "nmod:upon",
322
+ "291": "nmod:versus",
323
+ "292": "nmod:via",
324
+ "293": "nmod:with",
325
+ "294": "nmod:within",
326
+ "295": "nmod:without",
327
+ "296": "nmod:x",
328
+ "297": "nsubj",
329
+ "298": "nsubj:outer",
330
+ "299": "nsubj:pass",
331
+ "300": "nsubj:pass:xsubj",
332
+ "301": "nsubj:xsubj",
333
+ "302": "nummod",
334
+ "303": "obj",
335
+ "304": "obl",
336
+ "305": "obl:aboard",
337
+ "306": "obl:about",
338
+ "307": "obl:above",
339
+ "308": "obl:according_to",
340
+ "309": "obl:across",
341
+ "310": "obl:after",
342
+ "311": "obl:against",
343
+ "312": "obl:agent",
344
+ "313": "obl:along",
345
+ "314": "obl:along_with",
346
+ "315": "obl:alongside",
347
+ "316": "obl:amid",
348
+ "317": "obl:amidst",
349
+ "318": "obl:among",
350
+ "319": "obl:amongst",
351
+ "320": "obl:around",
352
+ "321": "obl:as",
353
+ "322": "obl:as_for",
354
+ "323": "obl:as_in",
355
+ "324": "obl:as_of",
356
+ "325": "obl:as_opposed_to",
357
+ "326": "obl:as_to",
358
+ "327": "obl:aside",
359
+ "328": "obl:at",
360
+ "329": "obl:atop",
361
+ "330": "obl:because_of",
362
+ "331": "obl:before",
363
+ "332": "obl:behind",
364
+ "333": "obl:below",
365
+ "334": "obl:beneath",
366
+ "335": "obl:beside",
367
+ "336": "obl:besides",
368
+ "337": "obl:between",
369
+ "338": "obl:beyond",
370
+ "339": "obl:but",
371
+ "340": "obl:by",
372
+ "341": "obl:circa",
373
+ "342": "obl:concerning",
374
+ "343": "obl:depending",
375
+ "344": "obl:depending_on",
376
+ "345": "obl:depending_upon",
377
+ "346": "obl:despite",
378
+ "347": "obl:down",
379
+ "348": "obl:due_to",
380
+ "349": "obl:during",
381
+ "350": "obl:except",
382
+ "351": "obl:except_for",
383
+ "352": "obl:following",
384
+ "353": "obl:for",
385
+ "354": "obl:for_post",
386
+ "355": "obl:from",
387
+ "356": "obl:from_across",
388
+ "357": "obl:from_among",
389
+ "358": "obl:from_behind",
390
+ "359": "obl:from_over",
391
+ "360": "obl:given",
392
+ "361": "obl:in",
393
+ "362": "obl:in_between",
394
+ "363": "obl:in_case_of",
395
+ "364": "obl:including",
396
+ "365": "obl:including_before",
397
+ "366": "obl:including_for",
398
+ "367": "obl:including_in",
399
+ "368": "obl:inside",
400
+ "369": "obl:instead_of",
401
+ "370": "obl:into",
402
+ "371": "obl:like",
403
+ "372": "obl:minus",
404
+ "373": "obl:near",
405
+ "374": "obl:nearby",
406
+ "375": "obl:nigh",
407
+ "376": "obl:notwithstanding",
408
+ "377": "obl:of",
409
+ "378": "obl:off",
410
+ "379": "obl:off_of",
411
+ "380": "obl:on",
412
+ "381": "obl:on_board",
413
+ "382": "obl:onto",
414
+ "383": "obl:opposite",
415
+ "384": "obl:out",
416
+ "385": "obl:out_of",
417
+ "386": "obl:outside",
418
+ "387": "obl:over",
419
+ "388": "obl:past",
420
+ "389": "obl:per",
421
+ "390": "obl:post",
422
+ "391": "obl:prior_to",
423
+ "392": "obl:re",
424
+ "393": "obl:regarding",
425
+ "394": "obl:round",
426
+ "395": "obl:since",
427
+ "396": "obl:than",
428
+ "397": "obl:through",
429
+ "398": "obl:throughout",
430
+ "399": "obl:thru",
431
+ "400": "obl:till",
432
+ "401": "obl:to",
433
+ "402": "obl:to_before",
434
+ "403": "obl:toward",
435
+ "404": "obl:towards",
436
+ "405": "obl:under",
437
+ "406": "obl:underneath",
438
+ "407": "obl:unlike",
439
+ "408": "obl:unmarked",
440
+ "409": "obl:until",
441
+ "410": "obl:unto",
442
+ "411": "obl:up",
443
+ "412": "obl:up_to",
444
+ "413": "obl:upon",
445
+ "414": "obl:versus",
446
+ "415": "obl:via",
447
+ "416": "obl:with",
448
+ "417": "obl:within",
449
+ "418": "obl:without",
450
+ "419": "parataxis",
451
+ "420": "punct",
452
+ "421": "ref",
453
+ "422": "reparandum",
454
+ "423": "root",
455
+ "424": "vocative",
456
+ "425": "xcomp"
457
+ },
458
+ "ud_deprel": {
459
+ "0": "acl",
460
+ "1": "acl:relcl",
461
+ "2": "advcl",
462
+ "3": "advcl:relcl",
463
+ "4": "advmod",
464
+ "5": "amod",
465
+ "6": "appos",
466
+ "7": "aux",
467
+ "8": "aux:pass",
468
+ "9": "case",
469
+ "10": "cc",
470
+ "11": "cc:preconj",
471
+ "12": "ccomp",
472
+ "13": "compound",
473
+ "14": "compound:prt",
474
+ "15": "conj",
475
+ "16": "cop",
476
+ "17": "csubj",
477
+ "18": "csubj:outer",
478
+ "19": "csubj:pass",
479
+ "20": "dep",
480
+ "21": "det",
481
+ "22": "det:predet",
482
+ "23": "discourse",
483
+ "24": "dislocated",
484
+ "25": "expl",
485
+ "26": "fixed",
486
+ "27": "flat",
487
+ "28": "goeswith",
488
+ "29": "iobj",
489
+ "30": "list",
490
+ "31": "mark",
491
+ "32": "nmod",
492
+ "33": "nmod:desc",
493
+ "34": "nmod:poss",
494
+ "35": "nmod:unmarked",
495
+ "36": "nsubj",
496
+ "37": "nsubj:outer",
497
+ "38": "nsubj:pass",
498
+ "39": "nummod",
499
+ "40": "obj",
500
+ "41": "obl",
501
+ "42": "obl:agent",
502
+ "43": "obl:unmarked",
503
+ "44": "orphan",
504
+ "45": "parataxis",
505
+ "46": "punct",
506
+ "47": "reparandum",
507
+ "48": "root",
508
+ "49": "vocative",
509
+ "50": "xcomp"
510
+ }
511
+ }
512
+ }
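Most of this file is the label vocabulary: one id-to-string mapping per annotation column (426 enhanced-UD relation labels, 51 basic-UD labels). A small sketch of inspecting it after loading, assuming the checkpoint id from the model card:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "CoBaLD/cobald-parser-pretrain-en", trust_remote_code=True
)
# CobaldParserConfig converts the serialized string keys back to int.
print(config.vocabulary["ud_deprel"][48])    # "root"
print(len(config.vocabulary["eud_deprel"]))  # 426
```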
configuration.py ADDED
@@ -0,0 +1,48 @@
+ from transformers import PretrainedConfig
+
+
+ class CobaldParserConfig(PretrainedConfig):
+     model_type = "cobald_parser"
+
+     def __init__(
+         self,
+         encoder_model_name: str = None,
+         null_classifier_hidden_size: int = 0,
+         lemma_classifier_hidden_size: int = 0,
+         morphology_classifier_hidden_size: int = 0,
+         dependency_classifier_hidden_size: int = 0,
+         misc_classifier_hidden_size: int = 0,
+         deepslot_classifier_hidden_size: int = 0,
+         semclass_classifier_hidden_size: int = 0,
+         activation: str = 'relu',
+         dropout: float = 0.1,
+         consecutive_null_limit: int = 0,
+         vocabulary: dict[str, dict[int, str]] = None,
+         **kwargs
+     ):
+         self.encoder_model_name = encoder_model_name
+         self.null_classifier_hidden_size = null_classifier_hidden_size
+         self.consecutive_null_limit = consecutive_null_limit
+         self.lemma_classifier_hidden_size = lemma_classifier_hidden_size
+         self.morphology_classifier_hidden_size = morphology_classifier_hidden_size
+         self.dependency_classifier_hidden_size = dependency_classifier_hidden_size
+         self.misc_classifier_hidden_size = misc_classifier_hidden_size
+         self.deepslot_classifier_hidden_size = deepslot_classifier_hidden_size
+         self.semclass_classifier_hidden_size = semclass_classifier_hidden_size
+         self.activation = activation
+         self.dropout = dropout
+         # The serialized config stores mappings with string keys,
+         # e.g. {"0": "acl", "1": "conj"}, so we have to convert the keys back to int.
+         self.vocabulary = {
+             column: {int(k): v for k, v in labels.items()}
+             for column, labels in (vocabulary or {}).items()
+         }
+         # HACK: tell the HF Hub about the custom pipeline.
+         # It should not be hardcoded like this, but other workarounds are worse.
+         self.custom_pipelines = {
+             "cobald-parsing": {
+                 "impl": "pipeline.ConlluTokenClassificationPipeline",
+                 "pt": "CobaldParser",
+             }
+         }
+         super().__init__(**kwargs)
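A minimal sketch of the key conversion done above; the flat import path is an assumption based on the files in this commit:

```python
from configuration import CobaldParserConfig

# Keys arrive as strings when read from a serialized config.json...
config = CobaldParserConfig(
    encoder_model_name="xlm-roberta-base",
    consecutive_null_limit=3,
    vocabulary={"ud_deprel": {"0": "acl", "1": "conj"}},
)
# ...and are usable as ints afterwards.
assert config.vocabulary["ud_deprel"][0] == "acl"
```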
dependency_classifier.py ADDED
@@ -0,0 +1,299 @@
1
+ from typing import override
2
+ from copy import deepcopy
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch import Tensor, FloatTensor, BoolTensor, LongTensor
9
+ import torch.nn.functional as F
10
+
11
+ from transformers.activations import ACT2FN
12
+
13
+ from cobald_parser.bilinear_matrix_attention import BilinearMatrixAttention
14
+ from cobald_parser.chu_liu_edmonds import decode_mst
15
+ from cobald_parser.utils import pairwise_mask, replace_masked_values
16
+
17
+
18
+ class DependencyHeadBase(nn.Module):
19
+ """
20
+ Base class for scoring arcs and relations between tokens in a dependency tree/graph.
21
+ """
22
+
23
+ def __init__(self, hidden_size: int, n_rels: int):
24
+ super().__init__()
25
+
26
+ self.arc_attention = BilinearMatrixAttention(
27
+ hidden_size,
28
+ hidden_size,
29
+ use_input_biases=True,
30
+ n_labels=1
31
+ )
32
+ self.rel_attention = BilinearMatrixAttention(
33
+ hidden_size,
34
+ hidden_size,
35
+ use_input_biases=True,
36
+ n_labels=n_rels
37
+ )
38
+
39
+ def forward(
40
+ self,
41
+ h_arc_head: Tensor, # [batch_size, seq_len, hidden_size]
42
+ h_arc_dep: Tensor, # ...
43
+ h_rel_head: Tensor, # ...
44
+ h_rel_dep: Tensor, # ...
45
+ gold_arcs: LongTensor, # [batch_size, seq_len, seq_len]
46
+ mask: BoolTensor # [batch_size, seq_len]
47
+ ) -> dict[str, Tensor]:
48
+
49
+ # Score arcs.
50
+ # s_arc[:, i, j] = score of edge j -> i.
51
+ s_arc = self.arc_attention(h_arc_head, h_arc_dep)
52
+ # Mask out undesirable positions (padding, nulls, etc.) with a large negative score.
53
+ replace_masked_values(s_arc, pairwise_mask(mask), replace_with=-1e8)
54
+ # Score arcs' relations.
55
+ # [batch_size, seq_len, seq_len, num_labels]
56
+ s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)
57
+
58
+ # Calculate loss.
59
+ loss = 0.0
60
+ if gold_arcs is not None:
61
+ loss += self.calc_arc_loss(s_arc, gold_arcs)
62
+ loss += self.calc_rel_loss(s_rel, gold_arcs)
63
+
64
+ # Predict arcs based on the scores.
65
+ # [batch_size, seq_len, seq_len]
66
+ pred_arcs_3d = self.predict_arcs(s_arc, mask)
67
+ # [batch_size, seq_len, seq_len]
68
+ pred_rels_3d = self.predict_rels(s_rel)
69
+ # [n_pred_arcs, 4]
70
+ preds_combined = self.combine_arcs_rels(pred_arcs_3d, pred_rels_3d)
71
+ return {
72
+ 'preds': preds_combined,
73
+ 'loss': loss
74
+ }
75
+
76
+ @staticmethod
77
+ def calc_arc_loss(
78
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
79
+ gold_arcs: LongTensor # [n_arcs, 4]
80
+ ) -> Tensor:
81
+ """Calculate arc loss."""
82
+ raise NotImplementedError
83
+
84
+ @staticmethod
85
+ def calc_rel_loss(
86
+ s_rel: Tensor, # [batch_size, seq_len, seq_len, num_labels]
87
+ gold_arcs: LongTensor # [n_arcs, 4]
88
+ ) -> Tensor:
89
+ batch_idxs, arcs_from, arcs_to, rels = gold_arcs.T
90
+ return F.cross_entropy(s_rel[batch_idxs, arcs_from, arcs_to], rels)
91
+
92
+ def predict_arcs(
93
+ self,
94
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
95
+ mask: BoolTensor # [batch_size, seq_len]
96
+ ) -> LongTensor:
97
+ """Predict arcs from scores."""
98
+ raise NotImplementedError
99
+
100
+ def predict_rels(
101
+ self,
102
+ s_rel: FloatTensor
103
+ ) -> LongTensor:
104
+ return s_rel.argmax(dim=-1).long()
105
+
106
+ @staticmethod
107
+ def combine_arcs_rels(
108
+ pred_arcs: LongTensor,
109
+ pred_rels: LongTensor
110
+ ) -> LongTensor:
111
+ """Select relations towards predicted arcs."""
112
+ assert pred_arcs.shape == pred_rels.shape
113
+ # Get indices where arcs exist
114
+ indices = pred_arcs.nonzero(as_tuple=True)
115
+ batch_idxs, from_idxs, to_idxs = indices
116
+ # Get corresponding relation types
117
+ rel_types = pred_rels[batch_idxs, from_idxs, to_idxs]
118
+ # Stack as [batch_idx, from_idx, to_idx, rel_type]
119
+ return torch.stack([batch_idxs, from_idxs, to_idxs, rel_types], dim=1)
120
+
121
+
122
+ class DependencyHead(DependencyHeadBase):
123
+ """
124
+ Basic UD syntax specialization that predicts a single edge for each token.
125
+ """
126
+
127
+ @override
128
+ def predict_arcs(
129
+ self,
130
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
131
+ mask: BoolTensor # [batch_size, seq_len]
132
+ ) -> Tensor:
133
+
134
+ if self.training:
135
+ # During training, use fast greedy decoding.
136
+ # - [batch_size, seq_len]
137
+ pred_arcs_seq = s_arc.argmax(dim=-1)
138
+ else:
139
+ # During inference, diligently decode Maximum Spanning Tree.
140
+ pred_arcs_seq = self._mst_decode(s_arc, mask)
141
+ # FIXME
142
+ # pred_arcs_seq = s_arc.argmax(dim=-1)
143
+
144
+ # Upscale arcs sequence of shape [batch_size, seq_len]
145
+ # to matrix of shape [batch_size, seq_len, seq_len].
146
+ pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long()
147
+ return pred_arcs
148
+
149
+ def _mst_decode(
150
+ self,
151
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
152
+ mask: Tensor # [batch_size, seq_len]
153
+ ) -> Tensor:
154
+
155
+ batch_size = s_arc.size(0)
156
+ device = s_arc.device
157
+ s_arc = s_arc.cpu()
158
+
159
+ # Convert scores to probabilities, as `decode_mst` expects non-negative values.
160
+ arc_probs = nn.functional.softmax(s_arc, dim=-1)
161
+ # Transpose arcs, because decode_mst defines 'energy' matrix as
162
+ # energy[i,j] = "Score that `i` is the head of `j`",
163
+ # whereas
164
+ # arc_probs[i,j] = "Probability that `j` is the head of `i`".
165
+ arc_probs = arc_probs.transpose(1, 2)
166
+
167
+ # `decode_mst` knows nothing about UD and ROOT, so we have to manually
168
+ # zero probabilities of arcs leading to ROOT to make sure ROOT is a source node
169
+ # of a graph.
170
+
171
+ # Decode ROOT positions from diagonals.
172
+ # shape: [batch_size]
173
+ root_idxs = arc_probs.diagonal(dim1=1, dim2=2).argmax(dim=-1)
174
+ # Zero out arcs leading to ROOTs.
175
+ arc_probs[torch.arange(batch_size), :, root_idxs] = 0.0
176
+
177
+ pred_arcs = []
178
+ for sample_idx in range(batch_size):
179
+ energy = arc_probs[sample_idx]
180
+ # has_labels=False because we will decode them manually later.
181
+ lengths = mask[sample_idx].sum()
182
+ heads, _ = decode_mst(energy, lengths, has_labels=False)
183
+ # Some nodes may be isolated. Pick heads greedily in this case.
184
+ heads[heads <= 0] = s_arc[sample_idx].argmax(dim=-1)[heads <= 0]
185
+ pred_arcs.append(heads)
186
+
187
+ # shape: [batch_size, seq_len]
188
+ pred_arcs = torch.from_numpy(np.stack(pred_arcs)).long().to(device)
189
+ return pred_arcs
190
+
191
+ @staticmethod
192
+ @override
193
+ def calc_arc_loss(
194
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
195
+ gold_arcs: LongTensor # [n_arcs, 4]
196
+ ) -> Tensor:
197
+ batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
198
+ return F.cross_entropy(s_arc[batch_idxs, from_idxs], to_idxs)
199
+
200
+
201
+ class MultiDependencyHead(DependencyHeadBase):
202
+ """
203
+ Enhanced UD syntax specialization that predicts multiple edges for each token.
204
+ """
205
+
206
+ @override
207
+ def predict_arcs(
208
+ self,
209
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
210
+ mask: BoolTensor # [batch_size, seq_len]
211
+ ) -> Tensor:
212
+ # Convert scores to probabilities.
213
+ arc_probs = torch.sigmoid(s_arc)
214
+ # Find confident arcs (with prob > 0.5).
215
+ return arc_probs.round().long()
216
+
217
+ @staticmethod
218
+ @override
219
+ def calc_arc_loss(
220
+ s_arc: Tensor, # [batch_size, seq_len, seq_len]
221
+ gold_arcs: LongTensor # [n_arcs, 4]
222
+ ) -> Tensor:
223
+ batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
224
+ # Gold arcs as a matrix, where matrix[batch, arc_from, arc_to] = 1.0 if the arc is present.
225
+ gold_arcs_matrix = torch.zeros_like(s_arc)
226
+ gold_arcs_matrix[batch_idxs, from_idxs, to_idxs] = 1.0
227
+ # Padded arcs' logits are large negative values that do not contribute to the loss.
228
+ return F.binary_cross_entropy_with_logits(s_arc, gold_arcs_matrix)
229
+
230
+
231
+ class DependencyClassifier(nn.Module):
232
+ """
233
+ Dozat and Manning's biaffine dependency classifier.
234
+ """
235
+
236
+ def __init__(
237
+ self,
238
+ input_size: int,
239
+ hidden_size: int,
240
+ n_rels_ud: int,
241
+ n_rels_eud: int,
242
+ activation: str,
243
+ dropout: float,
244
+ ):
245
+ super().__init__()
246
+
247
+ self.arc_dep_mlp = nn.Sequential(
248
+ nn.Dropout(dropout),
249
+ nn.Linear(input_size, hidden_size),
250
+ ACT2FN[activation],
251
+ nn.Dropout(dropout)
252
+ )
253
+ # All four MLPs share the same architecture.
254
+ self.arc_head_mlp = deepcopy(self.arc_dep_mlp)
255
+ self.rel_dep_mlp = deepcopy(self.arc_dep_mlp)
256
+ self.rel_head_mlp = deepcopy(self.arc_dep_mlp)
257
+
258
+ self.dependency_head_ud = DependencyHead(hidden_size, n_rels_ud)
259
+ self.dependency_head_eud = MultiDependencyHead(hidden_size, n_rels_eud)
260
+
261
+ def forward(
262
+ self,
263
+ embeddings: Tensor, # [batch_size, seq_len, embedding_size]
264
+ gold_ud: Tensor, # [n_ud_arcs, 4]
265
+ gold_eud: Tensor, # [n_eud_arcs, 4]
266
+ mask_ud: Tensor, # [batch_size, seq_len]
267
+ mask_eud: Tensor # [batch_size, seq_len]
268
+ ) -> dict[str, Tensor]:
269
+
270
+ # - [batch_size, seq_len, hidden_size]
271
+ h_arc_head = self.arc_head_mlp(embeddings)
272
+ h_arc_dep = self.arc_dep_mlp(embeddings)
273
+ h_rel_head = self.rel_head_mlp(embeddings)
274
+ h_rel_dep = self.rel_dep_mlp(embeddings)
275
+
276
+ # Share the h vectors between dependency and multi-dependency heads.
277
+ output_ud = self.dependency_head_ud(
278
+ h_arc_head,
279
+ h_arc_dep,
280
+ h_rel_head,
281
+ h_rel_dep,
282
+ gold_arcs=gold_ud,
283
+ mask=mask_ud
284
+ )
285
+ output_eud = self.dependency_head_eud(
286
+ h_arc_head,
287
+ h_arc_dep,
288
+ h_rel_head,
289
+ h_rel_dep,
290
+ gold_arcs=gold_eud,
291
+ mask=mask_eud
292
+ )
293
+
294
+ return {
295
+ 'preds_ud': output_ud["preds"],
296
+ 'preds_eud': output_eud["preds"],
297
+ 'loss_ud': output_ud["loss"],
298
+ 'loss_eud': output_eud["loss"]
299
+ }
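To make the arc encoding concrete: `combine_arcs_rels` turns the dense `[batch, dep, head]` predictions into rows of `[batch_idx, from_idx, to_idx, rel_type]`, the same format the gold arcs use. A self-contained sketch, using the `ud_deprel` ids from config.json (36 = "nsubj", 48 = "root"):

```python
import torch

# Dense predictions for one sentence of 3 positions:
# pred_arcs[b, i, j] = 1 means position j is predicted as the head of position i.
pred_arcs = torch.zeros(1, 3, 3, dtype=torch.long)
pred_arcs[0, 1, 2] = 1  # token 1 attaches to token 2
pred_arcs[0, 2, 0] = 1  # token 2 attaches to the root position 0
pred_rels = torch.full((1, 3, 3), 36, dtype=torch.long)  # "nsubj" everywhere...
pred_rels[0, 2, 0] = 48                                  # ...except the root arc

batch_idxs, from_idxs, to_idxs = pred_arcs.nonzero(as_tuple=True)
rel_types = pred_rels[batch_idxs, from_idxs, to_idxs]
combined = torch.stack([batch_idxs, from_idxs, to_idxs, rel_types], dim=1)
print(combined)  # tensor([[ 0,  1,  2, 36],
                 #         [ 0,  2,  0, 48]])
```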
encoder.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch import Tensor, LongTensor
4
+
5
+ from transformers import AutoTokenizer, AutoModel
6
+
7
+
8
+ class WordTransformerEncoder(nn.Module):
9
+ """
10
+ Encodes sentences into word-level embeddings using a pretrained MLM transformer.
11
+ """
12
+ def __init__(self, model_name: str):
13
+ super().__init__()
14
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ # Model like BERT, RoBERTa, etc.
16
+ self.model = AutoModel.from_pretrained(model_name)
17
+
18
+ def forward(self, words: list[list[str]]) -> Tensor:
19
+ """
20
+ Build word embeddings.
21
+
22
+ - Tokenizes input sentences into subtokens.
23
+ - Passes the subtokens through the pre-trained transformer model.
24
+ - Aggregates subtoken embeddings into word embeddings using mean pooling.
25
+ """
26
+ batch_size = len(words)
27
+
28
+ # BPE tokenization: split words into subtokens, e.g. ['kidding'] -> ['▁ki', 'dding'].
29
+ subtokens = self.tokenizer(
30
+ words,
31
+ padding=True,
32
+ truncation=True,
33
+ is_split_into_words=True,
34
+ return_tensors='pt'
35
+ )
36
+ subtokens = subtokens.to(self.model.device)
37
+ # Index words from 1 and reserve 0 for special subtokens (e.g. <s>, </s>, padding, etc.).
38
+ # Such numbering makes the subsequent aggregation easier.
39
+ words_ids = torch.stack([
40
+ torch.tensor(
41
+ [word_id + 1 if word_id is not None else 0 for word_id in subtokens.word_ids(batch_idx)],
42
+ dtype=torch.long,
43
+ device=self.model.device
44
+ )
45
+ for batch_idx in range(batch_size)
46
+ ])
47
+
48
+ # Run model and extract subtokens embeddings from the last layer.
49
+ subtokens_embeddings = self.model(**subtokens).last_hidden_state
50
+
51
+ # Aggregate subtoken embeddings into word embeddings.
+ # [batch_size, n_words, embedding_size]
+ words_embeddings = self._aggregate_subtokens_embeddings(subtokens_embeddings, words_ids)
+ return words_embeddings
55
+
56
+ def _aggregate_subtokens_embeddings(
57
+ self,
58
+ subtokens_embeddings: Tensor, # [batch_size, n_subtokens, embedding_size]
59
+ words_ids: LongTensor # [batch_size, n_subtokens]
60
+ ) -> Tensor:
61
+ """
62
+ Aggregate subtoken embeddings into word embeddings by averaging.
63
+
64
+ This method ensures that multiple subtokens corresponding to a single word are combined
65
+ into a single embedding.
66
+ """
67
+ batch_size, n_subtokens, embedding_size = subtokens_embeddings.shape
68
+ # The number of words in a sentence plus an "auxiliary" word in the beginning.
69
+ n_words = torch.max(words_ids) + 1
70
+
71
+ words_embeddings = torch.zeros(
72
+ size=(batch_size, n_words, embedding_size),
73
+ dtype=subtokens_embeddings.dtype,
74
+ device=self.model.device
75
+ )
76
+ words_ids_expanded = words_ids.unsqueeze(-1).expand(batch_size, n_subtokens, embedding_size)
77
+
78
+ # Use scatter_reduce_ to average embeddings of subtokens corresponding to the same word.
79
+ # All the padding and special subtokens will be aggregated into an "auxiliary" first embedding,
80
+ # namely into words_embeddings[:, 0, :].
81
+ words_embeddings.scatter_reduce_(
82
+ dim=1,
83
+ index=words_ids_expanded,
84
+ src=subtokens_embeddings,
85
+ reduce="mean",
86
+ include_self=False
87
+ )
88
+ # Now remove the auxiliary word in the beginning.
89
+ words_embeddings = words_embeddings[:, 1:, :]
90
+ return words_embeddings
91
+
92
+ def get_embedding_size(self) -> int:
93
+ """Returns the embedding size of the transformer model, e.g. 768 for BERT."""
94
+ return self.model.config.hidden_size
95
+
96
+ def get_embeddings_layer(self):
97
+ """Returns the embeddings model."""
98
+ return self.model.embeddings
99
+
100
+ def get_transformer_layers(self) -> list[nn.Module]:
101
+ """
102
+ Return a flat list of all transformer-*block* layers, excluding embeddings/poolers, etc.
103
+ """
104
+ layers = []
105
+ for sub in self.model.modules():
106
+ # find all ModuleLists (these always hold the actual block layers)
107
+ if isinstance(sub, nn.ModuleList):
108
+ layers.extend(list(sub))
109
+ return layers
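A short sketch of the encoder on pre-tokenized input; the flat module path is an assumption, and the first call downloads `xlm-roberta-base`:

```python
from encoder import WordTransformerEncoder

encoder = WordTransformerEncoder("xlm-roberta-base")
sentences = [
    ["The", "cat", "sat", "."],
    ["Hello", "world", "!"],
]
embeddings = encoder(sentences)
# One vector per word, padded to the longest sentence in the batch:
# (2, 4, encoder.get_embedding_size())
print(embeddings.shape)
```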
mlp_classifier.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ from torch import nn
+ from torch import Tensor, LongTensor
+
+ from transformers.activations import ACT2FN
+
+
+ class MlpClassifier(nn.Module):
+     """Simple feed-forward multilayer perceptron classifier."""
+
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         n_classes: int,
+         activation: str,
+         dropout: float,
+         class_weights: list[float] = None,
+     ):
+         super().__init__()
+
+         self.n_classes = n_classes
+         self.classifier = nn.Sequential(
+             nn.Dropout(dropout),
+             nn.Linear(input_size, hidden_size),
+             ACT2FN[activation],
+             nn.Dropout(dropout),
+             nn.Linear(hidden_size, n_classes)
+         )
+         if class_weights is not None:
+             # CrossEntropyLoss expects floating-point class weights.
+             class_weights = torch.tensor(class_weights, dtype=torch.float)
+         self.cross_entropy = nn.CrossEntropyLoss(weight=class_weights)
+
+     def forward(self, embeddings: Tensor, labels: LongTensor = None) -> dict:
+         logits = self.classifier(embeddings)
+         # Calculate loss.
+         loss = 0.0
+         if labels is not None:
+             # Flatten tensors to the [N, n_classes] / [N] shapes cross-entropy expects.
+             loss = self.cross_entropy(
+                 logits.view(-1, self.n_classes),
+                 labels.view(-1)
+             )
+         # Predictions.
+         preds = logits.argmax(dim=-1)
+         return {'preds': preds, 'loss': loss}
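A quick sanity-check sketch of the classifier head on random token embeddings (sizes are illustrative; the flat import path is an assumption):

```python
import torch
from mlp_classifier import MlpClassifier

clf = MlpClassifier(input_size=768, hidden_size=512, n_classes=10,
                    activation="relu", dropout=0.1)
embeddings = torch.randn(2, 5, 768)      # [batch, seq_len, embedding_size]
labels = torch.randint(0, 10, (2, 5))    # one gold class per token
out = clf(embeddings, labels)
print(out["preds"].shape)  # torch.Size([2, 5])
print(out["loss"].item())  # scalar cross-entropy
```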
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfb31d300bc938f0353cb18e733a5c611ef8a6bbbbea8766e75878af3c304b06
3
+ size 1147244460
modeling_parser.py ADDED
@@ -0,0 +1,190 @@
1
+ from torch import nn
2
+ from torch import LongTensor
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import ModelOutput
5
+ from dataclasses import dataclass
6
+
7
+ from .configuration import CobaldParserConfig
8
+ from .encoder import WordTransformerEncoder
9
+ from .mlp_classifier import MlpClassifier
10
+ from .dependency_classifier import DependencyClassifier
11
+ from .utils import (
12
+ build_padding_mask,
13
+ build_null_mask,
14
+ prepend_cls,
15
+ remove_nulls,
16
+ add_nulls
17
+ )
18
+
19
+
20
+ @dataclass
21
+ class CobaldParserOutput(ModelOutput):
22
+ """
23
+ Output type for CobaldParser.
24
+ """
25
+ loss: float = None
26
+ words: list = None
27
+ counting_mask: LongTensor = None
28
+ lemma_rules: LongTensor = None
29
+ joint_feats: LongTensor = None
30
+ deps_ud: LongTensor = None
31
+ deps_eud: LongTensor = None
32
+ miscs: LongTensor = None
33
+ deepslots: LongTensor = None
34
+ semclasses: LongTensor = None
35
+
36
+
37
+ class CobaldParser(PreTrainedModel):
38
+ """Morpho-Syntax-Semantic Parser."""
39
+
40
+ config_class = CobaldParserConfig
41
+
42
+ def __init__(self, config: CobaldParserConfig):
43
+ super().__init__(config)
44
+
45
+ self.encoder = WordTransformerEncoder(
46
+ model_name=config.encoder_model_name
47
+ )
48
+ embedding_size = self.encoder.get_embedding_size()
49
+
50
+ self.classifiers = nn.ModuleDict()
51
+ self.classifiers["null"] = MlpClassifier(
52
+ input_size=self.encoder.get_embedding_size(),
53
+ hidden_size=config.null_classifier_hidden_size,
54
+ n_classes=config.consecutive_null_limit + 1,
55
+ activation=config.activation,
56
+ dropout=config.dropout
57
+ )
58
+ if "lemma_rule" in config.vocabulary:
59
+ self.classifiers["lemma_rule"] = MlpClassifier(
60
+ input_size=embedding_size,
61
+ hidden_size=config.lemma_classifier_hidden_size,
62
+ n_classes=len(config.vocabulary["lemma_rule"]),
63
+ activation=config.activation,
64
+ dropout=config.dropout
65
+ )
66
+ if "joint_feats" in config.vocabulary:
67
+ self.classifiers["joint_feats"] = MlpClassifier(
68
+ input_size=embedding_size,
69
+ hidden_size=config.morphology_classifier_hidden_size,
70
+ n_classes=len(config.vocabulary["joint_feats"]),
71
+ activation=config.activation,
72
+ dropout=config.dropout
73
+ )
74
+ if "ud_deprel" in config.vocabulary or "eud_deprel" in config.vocabulary:
75
+ self.classifiers["syntax"] = DependencyClassifier(
76
+ input_size=embedding_size,
77
+ hidden_size=config.dependency_classifier_hidden_size,
78
+ n_rels_ud=len(config.vocabulary["ud_deprel"]),
79
+ n_rels_eud=len(config.vocabulary["eud_deprel"]),
80
+ activation=config.activation,
81
+ dropout=config.dropout
82
+ )
83
+ if "misc" in config.vocabulary:
84
+ self.classifiers["misc"] = MlpClassifier(
85
+ input_size=embedding_size,
86
+ hidden_size=config.misc_classifier_hidden_size,
87
+ n_classes=len(config.vocabulary["misc"]),
88
+ activation=config.activation,
89
+ dropout=config.dropout
90
+ )
91
+ if "deepslot" in config.vocabulary:
92
+ self.classifiers["deepslot"] = MlpClassifier(
93
+ input_size=embedding_size,
94
+ hidden_size=config.deepslot_classifier_hidden_size,
95
+ n_classes=len(config.vocabulary["deepslot"]),
96
+ activation=config.activation,
97
+ dropout=config.dropout
98
+ )
99
+ if "semclass" in config.vocabulary:
100
+ self.classifiers["semclass"] = MlpClassifier(
101
+ input_size=embedding_size,
102
+ hidden_size=config.semclass_classifier_hidden_size,
103
+ n_classes=len(config.vocabulary["semclass"]),
104
+ activation=config.activation,
105
+ dropout=config.dropout
106
+ )
107
+
108
+ def forward(
109
+ self,
110
+ words: list[list[str]],
111
+ counting_masks: LongTensor = None,
112
+ lemma_rules: LongTensor = None,
113
+ joint_feats: LongTensor = None,
114
+ deps_ud: LongTensor = None,
115
+ deps_eud: LongTensor = None,
116
+ miscs: LongTensor = None,
117
+ deepslots: LongTensor = None,
118
+ semclasses: LongTensor = None,
119
+ sent_ids: list[str] = None,
120
+ texts: list[str] = None,
121
+ inference_mode: bool = False
122
+ ) -> CobaldParserOutput:
123
+ result = {}
124
+
125
+ # Extra [CLS] token accounts for the case when #NULL is the first token in a sentence.
126
+ words_with_cls = prepend_cls(words)
127
+ words_without_nulls = remove_nulls(words_with_cls)
128
+ # Embeddings of words without nulls.
129
+ embeddings_without_nulls = self.encoder(words_without_nulls)
130
+ # Predict nulls.
131
+ null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
132
+ result["counting_mask"] = null_output['preds']
133
+ result["loss"] = null_output["loss"]
134
+
135
+ # "Teacher forcing": during training, pass the original words (with gold nulls)
136
+ # to the classification heads, so that they are trained on correct (gold) sentences.
137
+ if inference_mode:
138
+ # Restore predicted nulls in the original sentences.
139
+ result["words"] = add_nulls(words, null_output["preds"])
140
+ else:
141
+ result["words"] = words
142
+
143
+ # Encode words with nulls.
144
+ # [batch_size, seq_len, embedding_size]
145
+ embeddings = self.encoder(result["words"])
146
+
147
+ # Predict lemmas and morphological features.
148
+ if "lemma_rule" in self.classifiers:
149
+ lemma_output = self.classifiers["lemma_rule"](embeddings, lemma_rules)
150
+ result["lemma_rules"] = lemma_output['preds']
151
+ result["loss"] += lemma_output['loss']
152
+
153
+ if "joint_feats" in self.classifiers:
154
+ joint_feats_output = self.classifiers["joint_feats"](embeddings, joint_feats)
155
+ result["joint_feats"] = joint_feats_output['preds']
156
+ result["loss"] += joint_feats_output['loss']
157
+
158
+ # Predict syntax.
159
+ if "syntax" in self.classifiers:
160
+ padding_mask = build_padding_mask(result["words"], self.device)
161
+ null_mask = build_null_mask(result["words"], self.device)
162
+ deps_output = self.classifiers["syntax"](
163
+ embeddings,
164
+ deps_ud,
165
+ deps_eud,
166
+ mask_ud=(padding_mask & ~null_mask),
167
+ mask_eud=padding_mask
168
+ )
169
+ result["deps_ud"] = deps_output['preds_ud']
170
+ result["deps_eud"] = deps_output['preds_eud']
171
+ result["loss"] += deps_output['loss_ud'] + deps_output['loss_eud']
172
+
173
+ # Predict miscellaneous features.
174
+ if "misc" in self.classifiers:
175
+ misc_output = self.classifiers["misc"](embeddings, miscs)
176
+ result["miscs"] = misc_output['preds']
177
+ result["loss"] += misc_output['loss']
178
+
179
+ # Predict semantics.
180
+ if "deepslot" in self.classifiers:
181
+ deepslot_output = self.classifiers["deepslot"](embeddings, deepslots)
182
+ result["deepslots"] = deepslot_output['preds']
183
+ result["loss"] += deepslot_output['loss']
184
+
185
+ if "semclass" in self.classifiers:
186
+ semclass_output = self.classifiers["semclass"](embeddings, semclasses)
187
+ result["semclasses"] = semclass_output['preds']
188
+ result["loss"] += semclass_output['loss']
189
+
190
+ return CobaldParserOutput(**result)
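Direct use of the model (outside the custom pipeline) would look roughly like the sketch below; this is an assumption based on the forward signature above, not a documented API.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "CoBaLD/cobald-parser-pretrain-en", trust_remote_code=True
)
model.eval()

words = [["The", "cat", "sat", "on", "the", "mat", "."]]
output = model(words, inference_mode=True)
print(output.words)    # tokens with predicted #NULLs inserted
print(output.deps_ud)  # rows of [batch_idx, from_idx, to_idx, rel_id] for basic UD
```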
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb86bfc1859cb8830ba649eb0e08e3be821f421c3c9c251c0e1fe160e95afe74
3
+ size 5432
utils.py ADDED
@@ -0,0 +1,66 @@
+ import torch
+ from torch import Tensor
+
+
+ def pad_sequences(sequences: list[Tensor], padding_value: int) -> Tensor:
+     """
+     Stack 1d tensors (sequences) into a single 2d tensor so that each sequence is padded on the
+     right.
+     """
+     return torch.nn.utils.rnn.pad_sequence(sequences, padding_value=padding_value, batch_first=True)
+
+
+ def _build_condition_mask(sentences: list[list[str]], condition_fn: callable, device) -> Tensor:
+     masks = [
+         torch.tensor([condition_fn(word) for word in sentence], dtype=bool, device=device)
+         for sentence in sentences
+     ]
+     return pad_sequences(masks, padding_value=False)
+
+ def build_padding_mask(sentences: list[list[str]], device) -> Tensor:
+     return _build_condition_mask(sentences, condition_fn=lambda word: True, device=device)
+
+ def build_null_mask(sentences: list[list[str]], device) -> Tensor:
+     return _build_condition_mask(sentences, condition_fn=lambda word: word == "#NULL", device=device)
+
+
+ def pairwise_mask(masks1d: Tensor) -> Tensor:
+     """
+     Calculate an outer product of a mask, i.e. masks2d[:, i, j] = masks1d[:, i] & masks1d[:, j].
+     """
+     return masks1d[:, None, :] & masks1d[:, :, None]
+
+
+ # Credits: https://docs.allennlp.org/main/api/nn/util/#replace_masked_values
+ def replace_masked_values(tensor: Tensor, mask: Tensor, replace_with: float):
+     """
+     Replace all masked values in tensor with `replace_with`.
+     """
+     assert tensor.dim() == mask.dim(), f"tensor.dim() of {tensor.dim()} != mask.dim() of {mask.dim()}"
+     tensor.masked_fill_(~mask, replace_with)
+
+
+ def prepend_cls(sentences: list[list[str]]) -> list[list[str]]:
+     """
+     Return a copy of sentences with a [CLS] token prepended.
+     """
+     return [["[CLS]", *sentence] for sentence in sentences]
+
+ def remove_nulls(sentences: list[list[str]]) -> list[list[str]]:
+     """
+     Return a copy of sentences with nulls removed.
+     """
+     return [[word for word in sentence if word != "#NULL"] for sentence in sentences]
+
+ def add_nulls(sentences: list[list[str]], counting_masks) -> list[list[str]]:
+     """
+     Return a copy of sentences with nulls restored according to counting masks.
+     """
+     sentences_with_nulls = []
+     for sentence, counting_mask in zip(sentences, counting_masks):
+         sentence_with_nulls = []
+         for word, n_nulls_to_insert in zip(sentence, counting_mask):
+             sentence_with_nulls.append(word)
+             sentence_with_nulls.extend(["#NULL"] * n_nulls_to_insert)
+         sentences_with_nulls.append(sentence_with_nulls)
+     return sentences_with_nulls
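For illustration, how a counting mask maps onto restored `#NULL` tokens (assuming the helpers above are importable from `utils`):

```python
from utils import add_nulls, remove_nulls

words = [["Mary", "likes", "tea", "and", "John", "coffee"]]
# One count per word: that many #NULLs are inserted right after the word.
counting_masks = [[0, 0, 0, 0, 1, 0]]
print(add_nulls(words, counting_masks))
# [['Mary', 'likes', 'tea', 'and', 'John', '#NULL', 'coffee']]
assert remove_nulls(add_nulls(words, counting_masks)) == words
```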