fortvivlan committed (verified)
Commit fa7e289 · 1 Parent(s): d6ac0f0

Model save

Files changed (7):
  1. README.md +8 -8
  2. config.json +197 -197
  3. dependency_classifier.py +46 -42
  4. model.safetensors +1 -1
  5. modeling_parser.py +25 -44
  6. training_args.bin +2 -2
  7. utils.py +1 -1
README.md CHANGED
@@ -21,28 +21,28 @@ model-index:
       split: validation
     metrics:
     - type: f1
-      value: 0.9352911058579663
+      value: 0.9270548177755096
       name: Null F1
     - type: f1
-      value: 0.8223238635196441
+      value: 0.8339235583777782
       name: Lemma F1
     - type: f1
-      value: 0.7874293202680182
+      value: 0.7885002678867238
       name: Morphology F1
     - type: accuracy
-      value: 0.7509689490800553
+      value: 0.7653227685854114
       name: Ud Jaccard
     - type: accuracy
-      value: 0.7934583515045791
+      value: 0.7962406996475656
       name: Eud Jaccard
     - type: f1
-      value: 0.5310531282679114
+      value: 0.6438483915854029
       name: Miscs F1
     - type: f1
-      value: 0.6223423025329784
+      value: 0.6179291073868571
       name: Deepslot F1
     - type: f1
-      value: 0.6145897578961568
+      value: 0.6220501826358034
       name: Semclass F1
 ---
config.json CHANGED
@@ -25,7 +25,7 @@
   "null_classifier_hidden_size": 512,
   "semclass_classifier_hidden_size": 512,
   "torch_dtype": "float32",
-  "transformers_version": "4.51.3",
+  "transformers_version": "4.52.3",
   "vocabulary": {
     "deepslot": {
       "0": "$Dislocation",
@@ -370,213 +370,213 @@
       "1": "ADJ#Adjective#Degree=Cmp",
       "2": "ADJ#Adjective#Degree=Pos",
       "3": "ADJ#Adjective#Degree=Sup",
-      "4": "ADJ#None#Degree=Cmp",
-      "5": "ADJ#None#Degree=Pos",
-      "6": "ADJ#None#Degree=Pos|NumType=Ord",
-      "7": "ADJ#None#Degree=Sup",
-      "8": "ADJ#None#None",
-      "9": "ADJ#Numeral#Degree=Pos|NumForm=Digit|NumType=Ord",
-      "10": "ADJ#Numeral#Degree=Pos|NumForm=Word|NumType=Ord",
-      "11": "ADJ#Prefixoid#None",
-      "12": "ADP#Adverb#None",
-      "13": "ADP#None#None",
-      "14": "ADP#Preposition#None",
+      "4": "ADJ#Numeral#Degree=Pos|NumForm=Digit|NumType=Ord",
+      "5": "ADJ#Numeral#Degree=Pos|NumForm=Word|NumType=Ord",
+      "6": "ADJ#Prefixoid#_",
+      "7": "ADJ#_#Degree=Cmp",
+      "8": "ADJ#_#Degree=Pos",
+      "9": "ADJ#_#Degree=Pos|NumType=Ord",
+      "10": "ADJ#_#Degree=Sup",
+      "11": "ADJ#_#_",
+      "12": "ADP#Adverb#_",
+      "13": "ADP#Preposition#_",
+      "14": "ADP#_#_",
       "15": "ADV#Adjective#Degree=Pos",
       "16": "ADV#Adverb#Degree=Cmp",
       "17": "ADV#Adverb#Degree=Pos",
       "18": "ADV#Adverb#Degree=Pos|NumType=Mult",
       "19": "ADV#Adverb#Degree=Sup",
-      "20": "ADV#Adverb#None",
-      "21": "ADV#Adverb#NumType=Mult",
-      "22": "ADV#Adverb#Polarity=Neg",
-      "23": "ADV#Adverb#PronType=Dem",
+      "20": "ADV#Adverb#NumType=Mult",
+      "21": "ADV#Adverb#Polarity=Neg",
+      "22": "ADV#Adverb#PronType=Dem",
+      "23": "ADV#Adverb#_",
       "24": "ADV#Invariable#Degree=Cmp",
-      "25": "ADV#Invariable#None",
-      "26": "ADV#None#Degree=Cmp",
-      "27": "ADV#None#Degree=Pos",
-      "28": "ADV#None#Degree=Sup",
-      "29": "ADV#None#None",
-      "30": "ADV#None#NumType=Mult",
-      "31": "ADV#None#PronType=Dem",
-      "32": "ADV#None#PronType=Int",
-      "33": "ADV#Prefixoid#None",
-      "34": "AUX#None#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
-      "35": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
-      "36": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin",
-      "37": "AUX#Verb#Mood=Ind|Number=Plur|Person=2|Tense=Pres|VerbForm=Fin",
-      "38": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin",
-      "39": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin",
-      "40": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Past|VerbForm=Fin",
-      "41": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Pres|VerbForm=Fin",
-      "42": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin",
-      "43": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin",
-      "44": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin",
-      "45": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
-      "46": "AUX#Verb#Mood=Sub|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
-      "47": "AUX#Verb#Mood=Sub|Number=Plur|Tense=Past|VerbForm=Part",
-      "48": "AUX#Verb#Number=Plur|Tense=Past|VerbForm=Part",
-      "49": "AUX#Verb#Number=Plur|Tense=Pres|VerbForm=Part",
-      "50": "AUX#Verb#VerbForm=Fin",
-      "51": "AUX#Verb#VerbForm=Ger",
-      "52": "AUX#Verb#VerbForm=Inf",
-      "53": "CCONJ#Conjunction#None",
-      "54": "CCONJ#None#None",
+      "25": "ADV#Invariable#_",
+      "26": "ADV#Prefixoid#_",
+      "27": "ADV#_#Degree=Cmp",
+      "28": "ADV#_#Degree=Pos",
+      "29": "ADV#_#Degree=Sup",
+      "30": "ADV#_#NumType=Mult",
+      "31": "ADV#_#PronType=Dem",
+      "32": "ADV#_#PronType=Int",
+      "33": "ADV#_#_",
+      "34": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "35": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin",
+      "36": "AUX#Verb#Mood=Ind|Number=Plur|Person=2|Tense=Pres|VerbForm=Fin",
+      "37": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin",
+      "38": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin",
+      "39": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Past|VerbForm=Fin",
+      "40": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Pres|VerbForm=Fin",
+      "41": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin",
+      "42": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin",
+      "43": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin",
+      "44": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+      "45": "AUX#Verb#Mood=Sub|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "46": "AUX#Verb#Mood=Sub|Number=Plur|Tense=Past|VerbForm=Part",
+      "47": "AUX#Verb#Number=Plur|Tense=Past|VerbForm=Part",
+      "48": "AUX#Verb#Number=Plur|Tense=Pres|VerbForm=Part",
+      "49": "AUX#Verb#VerbForm=Fin",
+      "50": "AUX#Verb#VerbForm=Ger",
+      "51": "AUX#Verb#VerbForm=Inf",
+      "52": "AUX#_#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+      "53": "CCONJ#Conjunction#_",
+      "54": "CCONJ#_#_",
       "55": "DET#Adjective#PronType=Tot",
       "56": "DET#Article#Definite=Def|PronType=Art",
       "57": "DET#Article#Definite=Ind|PronType=Art",
       "58": "DET#Conjunction#Definite=Def|PronType=Art",
-      "59": "DET#None#Definite=Def|PronType=Art",
-      "60": "DET#None#Definite=EMPTY",
-      "61": "DET#None#Definite=Ind|PronType=Art",
-      "62": "DET#None#None",
-      "63": "DET#None#Number=Sing|PronType=Dem",
-      "64": "DET#None#PronType=Int",
-      "65": "DET#None#PronType=Neg",
-      "66": "DET#None#PronType=Rcp",
-      "67": "DET#None#PronType=Tot",
-      "68": "DET#Prefixoid#None",
-      "69": "DET#Pronoun#None",
-      "70": "DET#Pronoun#Number=Plur|PronType=Dem",
-      "71": "DET#Pronoun#Number=Sing|PronType=Dem",
-      "72": "DET#Pronoun#Polarity=Neg",
-      "73": "DET#Pronoun#PronType=Ind",
-      "74": "DET#Pronoun#PronType=Int",
-      "75": "DET#Pronoun#PronType=Rel",
-      "76": "DET#Pronoun#PronType=Tot",
-      "77": "INTJ#Interjection#None",
+      "59": "DET#Prefixoid#_",
+      "60": "DET#Pronoun#Number=Plur|PronType=Dem",
+      "61": "DET#Pronoun#Number=Sing|PronType=Dem",
+      "62": "DET#Pronoun#Polarity=Neg",
+      "63": "DET#Pronoun#PronType=Ind",
+      "64": "DET#Pronoun#PronType=Int",
+      "65": "DET#Pronoun#PronType=Rel",
+      "66": "DET#Pronoun#PronType=Tot",
+      "67": "DET#Pronoun#_",
+      "68": "DET#_#Definite=Def|PronType=Art",
+      "69": "DET#_#Definite=EMPTY",
+      "70": "DET#_#Definite=Ind|PronType=Art",
+      "71": "DET#_#Number=Sing|PronType=Dem",
+      "72": "DET#_#PronType=Int",
+      "73": "DET#_#PronType=Neg",
+      "74": "DET#_#PronType=Rcp",
+      "75": "DET#_#PronType=Tot",
+      "76": "DET#_#_",
+      "77": "INTJ#Interjection#_",
       "78": "NOUN#Adverb#Number=Sing",
-      "79": "NOUN#None#Number=Plur",
-      "80": "NOUN#None#Number=Sing",
-      "81": "NOUN#Noun#Abbr=Yes|Number=Plur",
-      "82": "NOUN#Noun#Abbr=Yes|Number=Sing",
-      "83": "NOUN#Noun#NumType=Frac|Number=Sing",
-      "84": "NOUN#Noun#Number=Plur",
-      "85": "NOUN#Noun#Number=Sing",
-      "86": "NOUN#Noun#Number=Sing|Polarity=Neg",
-      "87": "NOUN#Noun#VerbForm=Fin",
-      "88": "NOUN#Prefixoid#None",
-      "89": "NOUN#Prefixoid#Number=Sing",
-      "90": "NUM#None#Degree=Pos|NumType=Ord",
-      "91": "NUM#None#NumType=Card",
-      "92": "NUM#Noun#NumForm=Word|NumType=Card",
-      "93": "NUM#Numeral#None",
-      "94": "NUM#Numeral#NumForm=Digit|NumType=Card",
-      "95": "NUM#Numeral#NumForm=Digit|NumType=Frac",
-      "96": "NUM#Numeral#NumForm=Roman|NumType=Card",
-      "97": "NUM#Numeral#NumForm=Word|NumType=Card",
-      "98": "NUM#Numeral#NumType=Card",
-      "99": "PART#None#None",
-      "100": "PART#None#Polarity=Neg",
-      "101": "PART#Particle#None",
-      "102": "PART#Particle#Polarity=Neg",
-      "103": "PPROPN#None#Number=Plur",
-      "104": "PRON#None#Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
-      "105": "PRON#None#Number=Sing",
-      "106": "PRON#None#Number=Sing|PronType=Dem",
-      "107": "PRON#None#Number=Sing|PronType=Ind",
-      "108": "PRON#None#PronType=Int",
-      "109": "PRON#None#PronType=Rel",
-      "110": "PRON#Pronoun#Case=Acc|Gender=Fem|Number=Sing|Person=3|PronType=Prs",
-      "111": "PRON#Pronoun#Case=Acc|Gender=Fem|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
-      "112": "PRON#Pronoun#Case=Acc|Gender=Masc|Number=Sing|Person=3|PronType=Prs",
-      "113": "PRON#Pronoun#Case=Acc|Gender=Masc|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
-      "114": "PRON#Pronoun#Case=Acc|Gender=Neut|Number=Sing|Person=3|PronType=Prs",
-      "115": "PRON#Pronoun#Case=Acc|Gender=Neut|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
-      "116": "PRON#Pronoun#Case=Acc|Number=Plur|Person=1|PronType=Prs",
-      "117": "PRON#Pronoun#Case=Acc|Number=Plur|Person=1|PronType=Prs|Reflex=Yes",
-      "118": "PRON#Pronoun#Case=Acc|Number=Plur|Person=2|PronType=Prs",
-      "119": "PRON#Pronoun#Case=Acc|Number=Plur|Person=3|PronType=Prs",
-      "120": "PRON#Pronoun#Case=Acc|Number=Plur|Person=3|PronType=Prs|Reflex=Yes",
-      "121": "PRON#Pronoun#Case=Acc|Number=Sing|Person=1|PronType=Prs",
-      "122": "PRON#Pronoun#Case=Acc|Number=Sing|Person=2|PronType=Prs",
-      "123": "PRON#Pronoun#Case=Acc|Number=Sing|Person=2|PronType=Prs|Reflex=Yes",
-      "124": "PRON#Pronoun#Case=Gen|Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
-      "125": "PRON#Pronoun#Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
-      "126": "PRON#Pronoun#Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
-      "127": "PRON#Pronoun#Case=Gen|Number=Plur|Person=1|Poss=Yes|PronType=Prs",
-      "128": "PRON#Pronoun#Case=Gen|Number=Plur|Person=3|Poss=Yes|PronType=Prs",
-      "129": "PRON#Pronoun#Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs",
-      "130": "PRON#Pronoun#Case=Gen|Number=Sing|Person=2|Poss=Yes|PronType=Prs",
-      "131": "PRON#Pronoun#Case=Nom|Gender=Fem|Number=Sing|Person=3|PronType=Prs",
-      "132": "PRON#Pronoun#Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs",
-      "133": "PRON#Pronoun#Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
-      "134": "PRON#Pronoun#Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs",
-      "135": "PRON#Pronoun#Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
-      "136": "PRON#Pronoun#Case=Nom|Number=Plur|Person=1|PronType=Prs",
-      "137": "PRON#Pronoun#Case=Nom|Number=Plur|Person=2|PronType=Prs",
-      "138": "PRON#Pronoun#Case=Nom|Number=Plur|Person=3|PronType=Prs",
-      "139": "PRON#Pronoun#Case=Nom|Number=Plur|Person=3|PronType=Prs|Reflex=Yes",
-      "140": "PRON#Pronoun#Case=Nom|Number=Sing|Person=1|PronType=Prs",
-      "141": "PRON#Pronoun#Case=Nom|Number=Sing|Person=2|PronType=Prs",
-      "142": "PRON#Pronoun#None",
-      "143": "PRON#Pronoun#Number=Plur",
-      "144": "PRON#Pronoun#Number=Plur|PronType=Dem",
-      "145": "PRON#Pronoun#Number=Plur|PronType=Tot",
-      "146": "PRON#Pronoun#Number=Sing",
-      "147": "PRON#Pronoun#Number=Sing|Polarity=Neg|PronType=Neg",
-      "148": "PRON#Pronoun#Number=Sing|PronType=Dem",
-      "149": "PRON#Pronoun#Number=Sing|PronType=Ind",
-      "150": "PRON#Pronoun#Number=Sing|PronType=Neg",
-      "151": "PRON#Pronoun#Number=Sing|Reflex=Yes",
-      "152": "PRON#Pronoun#PronType=Ind",
-      "153": "PRON#Pronoun#PronType=Int",
-      "154": "PRON#Pronoun#PronType=Rel",
-      "155": "PROPN#None#Abbr=Yes",
-      "156": "PROPN#None#Number=Plur",
-      "157": "PROPN#None#Number=Sing",
-      "158": "PROPN#Noun#Abbr=Yes|Number=Plur",
-      "159": "PROPN#Noun#Abbr=Yes|Number=Sing",
-      "160": "PROPN#Noun#Number=Plur",
-      "161": "PROPN#Noun#Number=Sing",
-      "162": "PROPN#Noun#Number=Sing|Polarity=Neg",
-      "163": "PROPN#Noun#PronType=Dem",
-      "164": "PROPN#Noun#VerbForm=Fin",
-      "165": "PROPN#Prefixoid#Number=Sing",
-      "166": "PUNCT#None#None",
-      "167": "PUNCT#PUNCT#None",
-      "168": "Prefixoid#Prefixoid#None",
-      "169": "SCONJ#Conjunction#None",
-      "170": "SCONJ#None#None",
-      "171": "SYM#Conjunction#None",
-      "172": "SYM#Noun#None",
-      "173": "SYM#Noun#Number=Sing",
-      "174": "VERB#None#Mood=Ind|Tense=Past|VerbForm=Fin",
-      "175": "VERB#None#Tense=Past|VerbForm=Part",
-      "176": "VERB#None#VerbForm=Ger",
-      "177": "VERB#None#VerbForm=Inf",
-      "178": "VERB#Verb#Mood=Imp|VerbForm=Inf",
-      "179": "VERB#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
-      "180": "VERB#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin",
-      "181": "VERB#Verb#Mood=Ind|Number=Plur|Person=2|Tense=Pres|VerbForm=Fin",
-      "182": "VERB#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin",
-      "183": "VERB#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin",
-      "184": "VERB#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Past|VerbForm=Fin",
-      "185": "VERB#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Pres|VerbForm=Fin",
-      "186": "VERB#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin",
-      "187": "VERB#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin",
-      "188": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Polarity=Neg|Tense=Pres|VerbForm=Fin",
-      "189": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin",
-      "190": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
-      "191": "VERB#Verb#Mood=Sub|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
-      "192": "VERB#Verb#Mood=Sub|Tense=Past|VerbForm=Part",
-      "193": "VERB#Verb#Mood=Sub|Tense=Past|VerbForm=Part|Voice=Pass",
-      "194": "VERB#Verb#Mood=Sub|VerbForm=Inf",
-      "195": "VERB#Verb#Person=1|Tense=Past|VerbForm=Part",
-      "196": "VERB#Verb#Person=1|Tense=Past|VerbForm=Part|Voice=Pass",
-      "197": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Ger",
-      "198": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Inf",
-      "199": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Part",
-      "200": "VERB#Verb#Person=2|Tense=Pres|VerbForm=Inf",
-      "201": "VERB#Verb#Tense=Past|VerbForm=Part",
-      "202": "VERB#Verb#Tense=Past|VerbForm=Part|Voice=Pass",
-      "203": "VERB#Verb#Tense=Pres|VerbForm=Part",
-      "204": "VERB#Verb#VerbForm=Fin",
-      "205": "VERB#Verb#VerbForm=Ger",
-      "206": "VERB#Verb#VerbForm=Inf",
-      "207": "X#None#Foreign=Yes",
-      "208": "X#None#None",
-      "209": "X#None#Typo=Yes",
-      "210": "X#None#foreign=Yes"
+      "79": "NOUN#Noun#Abbr=Yes|Number=Plur",
+      "80": "NOUN#Noun#Abbr=Yes|Number=Sing",
+      "81": "NOUN#Noun#NumType=Frac|Number=Sing",
+      "82": "NOUN#Noun#Number=Plur",
+      "83": "NOUN#Noun#Number=Sing",
+      "84": "NOUN#Noun#Number=Sing|Polarity=Neg",
+      "85": "NOUN#Noun#VerbForm=Fin",
+      "86": "NOUN#Prefixoid#Number=Sing",
+      "87": "NOUN#Prefixoid#_",
+      "88": "NOUN#_#Number=Plur",
+      "89": "NOUN#_#Number=Sing",
+      "90": "NUM#Noun#NumForm=Word|NumType=Card",
+      "91": "NUM#Numeral#NumForm=Digit|NumType=Card",
+      "92": "NUM#Numeral#NumForm=Digit|NumType=Frac",
+      "93": "NUM#Numeral#NumForm=Roman|NumType=Card",
+      "94": "NUM#Numeral#NumForm=Word|NumType=Card",
+      "95": "NUM#Numeral#NumType=Card",
+      "96": "NUM#Numeral#_",
+      "97": "NUM#_#Degree=Pos|NumType=Ord",
+      "98": "NUM#_#NumType=Card",
+      "99": "PART#Particle#Polarity=Neg",
+      "100": "PART#Particle#_",
+      "101": "PART#_#Polarity=Neg",
+      "102": "PART#_#_",
+      "103": "PPROPN#_#Number=Plur",
+      "104": "PRON#Pronoun#Case=Acc|Gender=Fem|Number=Sing|Person=3|PronType=Prs",
+      "105": "PRON#Pronoun#Case=Acc|Gender=Fem|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "106": "PRON#Pronoun#Case=Acc|Gender=Masc|Number=Sing|Person=3|PronType=Prs",
+      "107": "PRON#Pronoun#Case=Acc|Gender=Masc|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "108": "PRON#Pronoun#Case=Acc|Gender=Neut|Number=Sing|Person=3|PronType=Prs",
+      "109": "PRON#Pronoun#Case=Acc|Gender=Neut|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "110": "PRON#Pronoun#Case=Acc|Number=Plur|Person=1|PronType=Prs",
+      "111": "PRON#Pronoun#Case=Acc|Number=Plur|Person=1|PronType=Prs|Reflex=Yes",
+      "112": "PRON#Pronoun#Case=Acc|Number=Plur|Person=2|PronType=Prs",
+      "113": "PRON#Pronoun#Case=Acc|Number=Plur|Person=3|PronType=Prs",
+      "114": "PRON#Pronoun#Case=Acc|Number=Plur|Person=3|PronType=Prs|Reflex=Yes",
+      "115": "PRON#Pronoun#Case=Acc|Number=Sing|Person=1|PronType=Prs",
+      "116": "PRON#Pronoun#Case=Acc|Number=Sing|Person=2|PronType=Prs",
+      "117": "PRON#Pronoun#Case=Acc|Number=Sing|Person=2|PronType=Prs|Reflex=Yes",
+      "118": "PRON#Pronoun#Case=Gen|Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "119": "PRON#Pronoun#Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "120": "PRON#Pronoun#Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "121": "PRON#Pronoun#Case=Gen|Number=Plur|Person=1|Poss=Yes|PronType=Prs",
+      "122": "PRON#Pronoun#Case=Gen|Number=Plur|Person=3|Poss=Yes|PronType=Prs",
+      "123": "PRON#Pronoun#Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs",
+      "124": "PRON#Pronoun#Case=Gen|Number=Sing|Person=2|Poss=Yes|PronType=Prs",
+      "125": "PRON#Pronoun#Case=Nom|Gender=Fem|Number=Sing|Person=3|PronType=Prs",
+      "126": "PRON#Pronoun#Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs",
+      "127": "PRON#Pronoun#Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "128": "PRON#Pronoun#Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs",
+      "129": "PRON#Pronoun#Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "130": "PRON#Pronoun#Case=Nom|Number=Plur|Person=1|PronType=Prs",
+      "131": "PRON#Pronoun#Case=Nom|Number=Plur|Person=2|PronType=Prs",
+      "132": "PRON#Pronoun#Case=Nom|Number=Plur|Person=3|PronType=Prs",
+      "133": "PRON#Pronoun#Case=Nom|Number=Plur|Person=3|PronType=Prs|Reflex=Yes",
+      "134": "PRON#Pronoun#Case=Nom|Number=Sing|Person=1|PronType=Prs",
+      "135": "PRON#Pronoun#Case=Nom|Number=Sing|Person=2|PronType=Prs",
+      "136": "PRON#Pronoun#Number=Plur",
+      "137": "PRON#Pronoun#Number=Plur|PronType=Dem",
+      "138": "PRON#Pronoun#Number=Plur|PronType=Tot",
+      "139": "PRON#Pronoun#Number=Sing",
+      "140": "PRON#Pronoun#Number=Sing|Polarity=Neg|PronType=Neg",
+      "141": "PRON#Pronoun#Number=Sing|PronType=Dem",
+      "142": "PRON#Pronoun#Number=Sing|PronType=Ind",
+      "143": "PRON#Pronoun#Number=Sing|PronType=Neg",
+      "144": "PRON#Pronoun#Number=Sing|Reflex=Yes",
+      "145": "PRON#Pronoun#PronType=Ind",
+      "146": "PRON#Pronoun#PronType=Int",
+      "147": "PRON#Pronoun#PronType=Rel",
+      "148": "PRON#Pronoun#_",
+      "149": "PRON#_#Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "150": "PRON#_#Number=Sing",
+      "151": "PRON#_#Number=Sing|PronType=Dem",
+      "152": "PRON#_#Number=Sing|PronType=Ind",
+      "153": "PRON#_#PronType=Int",
+      "154": "PRON#_#PronType=Rel",
+      "155": "PROPN#Noun#Abbr=Yes|Number=Plur",
+      "156": "PROPN#Noun#Abbr=Yes|Number=Sing",
+      "157": "PROPN#Noun#Number=Plur",
+      "158": "PROPN#Noun#Number=Sing",
+      "159": "PROPN#Noun#Number=Sing|Polarity=Neg",
+      "160": "PROPN#Noun#PronType=Dem",
+      "161": "PROPN#Noun#VerbForm=Fin",
+      "162": "PROPN#Prefixoid#Number=Sing",
+      "163": "PROPN#_#Abbr=Yes",
+      "164": "PROPN#_#Number=Plur",
+      "165": "PROPN#_#Number=Sing",
+      "166": "PUNCT#PUNCT#_",
+      "167": "PUNCT#_#_",
+      "168": "Prefixoid#Prefixoid#_",
+      "169": "SCONJ#Conjunction#_",
+      "170": "SCONJ#_#_",
+      "171": "SYM#Conjunction#_",
+      "172": "SYM#Noun#Number=Sing",
+      "173": "SYM#Noun#_",
+      "174": "VERB#Verb#Mood=Imp|VerbForm=Inf",
+      "175": "VERB#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "176": "VERB#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin",
+      "177": "VERB#Verb#Mood=Ind|Number=Plur|Person=2|Tense=Pres|VerbForm=Fin",
+      "178": "VERB#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin",
+      "179": "VERB#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin",
+      "180": "VERB#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Past|VerbForm=Fin",
+      "181": "VERB#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Pres|VerbForm=Fin",
+      "182": "VERB#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin",
+      "183": "VERB#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin",
+      "184": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Polarity=Neg|Tense=Pres|VerbForm=Fin",
+      "185": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin",
+      "186": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+      "187": "VERB#Verb#Mood=Sub|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "188": "VERB#Verb#Mood=Sub|Tense=Past|VerbForm=Part",
+      "189": "VERB#Verb#Mood=Sub|Tense=Past|VerbForm=Part|Voice=Pass",
+      "190": "VERB#Verb#Mood=Sub|VerbForm=Inf",
+      "191": "VERB#Verb#Person=1|Tense=Past|VerbForm=Part",
+      "192": "VERB#Verb#Person=1|Tense=Past|VerbForm=Part|Voice=Pass",
+      "193": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Ger",
+      "194": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Inf",
+      "195": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Part",
+      "196": "VERB#Verb#Person=2|Tense=Pres|VerbForm=Inf",
+      "197": "VERB#Verb#Tense=Past|VerbForm=Part",
+      "198": "VERB#Verb#Tense=Past|VerbForm=Part|Voice=Pass",
+      "199": "VERB#Verb#Tense=Pres|VerbForm=Part",
+      "200": "VERB#Verb#VerbForm=Fin",
+      "201": "VERB#Verb#VerbForm=Ger",
+      "202": "VERB#Verb#VerbForm=Inf",
+      "203": "VERB#_#Mood=Ind|Tense=Past|VerbForm=Fin",
+      "204": "VERB#_#Tense=Past|VerbForm=Part",
+      "205": "VERB#_#VerbForm=Ger",
+      "206": "VERB#_#VerbForm=Inf",
+      "207": "X#_#Foreign=Yes",
+      "208": "X#_#Typo=Yes",
+      "209": "X#_#_",
+      "210": "X#_#foreign=Yes"
     },
     "lemma_rule": {
       "0": "cut_prefix=0|cut_suffix=0|append_suffix=",
dependency_classifier.py CHANGED
@@ -38,19 +38,21 @@ class DependencyHeadBase(nn.Module):
 
     def forward(
         self,
-        h_arc_head: Tensor,       # [batch_size, seq_len, hidden_size]
-        h_arc_dep: Tensor,        # ...
-        h_rel_head: Tensor,       # ...
-        h_rel_dep: Tensor,        # ...
-        gold_arcs: LongTensor,    # [batch_size, seq_len, seq_len]
-        mask: BoolTensor          # [batch_size, seq_len]
+        h_arc_head: Tensor,       # [batch_size, seq_len, hidden_size]
+        h_arc_dep: Tensor,        # ...
+        h_rel_head: Tensor,       # ...
+        h_rel_dep: Tensor,        # ...
+        gold_arcs: LongTensor,    # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
     ) -> dict[str, Tensor]:
 
         # Score arcs.
-        # s_arc[:, i, j] = score of edge j -> i.
+        # s_arc[:, i, j] = score of edge i -> j.
         s_arc = self.arc_attention(h_arc_head, h_arc_dep)
         # Mask undesirable values (padding, nulls, etc.) with -inf.
-        replace_masked_values(s_arc, pairwise_mask(mask), replace_with=-1e8)
+        mask2d = pairwise_mask(null_mask & padding_mask)
+        replace_masked_values(s_arc, mask2d, replace_with=-1e8)
         # Score arcs' relations.
         # [batch_size, seq_len, seq_len, num_labels]
         s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)
@@ -63,11 +65,11 @@ class DependencyHeadBase(nn.Module):
 
         # Predict arcs based on the scores.
         # [batch_size, seq_len, seq_len]
-        pred_arcs_3d = self.predict_arcs(s_arc, mask)
+        pred_arcs_matrix = self.predict_arcs(s_arc, null_mask, padding_mask)
         # [batch_size, seq_len, seq_len]
-        pred_rels_3d = self.predict_rels(s_rel)
+        pred_rels_matrix = self.predict_rels(s_rel)
         # [n_pred_arcs, 4]
-        preds_combined = self.combine_arcs_rels(pred_arcs_3d, pred_rels_3d)
+        preds_combined = self.combine_arcs_rels(pred_arcs_matrix, pred_rels_matrix)
         return {
             'preds': preds_combined,
             'loss': loss
@@ -91,8 +93,9 @@ class DependencyHeadBase(nn.Module):
 
     def predict_arcs(
         self,
-        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
-        mask: BoolTensor          # [batch_size, seq_len]
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
     ) -> LongTensor:
         """Predict arcs from scores."""
         raise NotImplementedError
@@ -127,42 +130,40 @@ class DependencyHead(DependencyHeadBase):
     @override
     def predict_arcs(
         self,
-        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
-        mask: BoolTensor          # [batch_size, seq_len]
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
    ) -> Tensor:
 
         if self.training:
             # During training, use fast greedy decoding.
             # - [batch_size, seq_len]
-            pred_arcs_seq = s_arc.argmax(dim=-1)
+            pred_arcs_seq = s_arc.argmax(dim=1)
         else:
-            # During inference, diligently decode Maximum Spanning Tree.
-            pred_arcs_seq = self._mst_decode(s_arc, mask)
-            # FIXME
-            # pred_arcs_seq = s_arc.argmax(dim=-1)
+            # During inference, decode Maximum Spanning Tree.
+            pred_arcs_seq = self._mst_decode(s_arc, padding_mask)
 
         # Upscale arcs sequence of shape [batch_size, seq_len]
         # to matrix of shape [batch_size, seq_len, seq_len].
-        pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long()
+        pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long().transpose(1, 2)
+        # Apply mask one more time (even though s_arc is already masked),
+        # because argmax erases information about masked values.
+        mask2d = pairwise_mask(null_mask & padding_mask)
+        replace_masked_values(pred_arcs, mask2d, replace_with=0)
         return pred_arcs
 
     def _mst_decode(
         self,
-        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
-        mask: Tensor              # [batch_size, seq_len]
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        padding_mask: Tensor      # [batch_size, seq_len]
     ) -> tuple[Tensor, Tensor]:
 
         batch_size = s_arc.size(0)
         device = s_arc.device
         s_arc = s_arc.cpu()
 
         # Convert scores to probabilities, as `decode_mst` expects non-negative values.
-        arc_probs = nn.functional.softmax(s_arc, dim=-1)
-        # Transpose arcs, because decode_mst defines 'energy' matrix as
-        #   energy[i,j] = "Score that `i` is the head of `j`",
-        # whereas
-        #   arc_probs[i,j] = "Probability that `j` is the head of `i`".
-        arc_probs = arc_probs.transpose(1, 2)
+        arc_probs = nn.functional.softmax(s_arc, dim=1)
 
         # `decode_mst` knows nothing about UD and ROOT, so we have to manually
         # zero probabilities of arcs leading to ROOT to make sure ROOT is a source node
@@ -177,11 +178,10 @@ class DependencyHead(DependencyHeadBase):
         pred_arcs = []
         for sample_idx in range(batch_size):
             energy = arc_probs[sample_idx]
-            # has_labels=False because we will decode them manually later.
-            lengths = mask[sample_idx].sum()
-            heads, _ = decode_mst(energy, lengths, has_labels=False)
+            length = padding_mask[sample_idx].sum()
+            heads = decode_mst(energy, length)
             # Some nodes may be isolated. Pick heads greedily in this case.
-            heads[heads <= 0] = s_arc[sample_idx].argmax(dim=-1)[heads <= 0]
+            heads[heads <= 0] = s_arc[sample_idx].argmax(dim=1)[heads <= 0]
             pred_arcs.append(heads)
 
         # shape: [batch_size, seq_len]
@@ -195,7 +195,7 @@ class DependencyHead(DependencyHeadBase):
         gold_arcs: LongTensor     # [n_arcs, 4]
     ) -> tuple[Tensor, Tensor]:
         batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
-        return F.cross_entropy(s_arc[batch_idxs, from_idxs], to_idxs)
+        return F.cross_entropy(s_arc[batch_idxs, :, to_idxs], from_idxs)
 
 
 class MultiDependencyHead(DependencyHeadBase):
@@ -206,8 +206,9 @@ class MultiDependencyHead(DependencyHeadBase):
     @override
     def predict_arcs(
         self,
-        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
-        mask: BoolTensor          # [batch_size, seq_len]
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
     ) -> Tensor:
         # Convert scores to probabilities.
         arc_probs = torch.sigmoid(s_arc)
@@ -263,8 +264,8 @@ class DependencyClassifier(nn.Module):
         embeddings: Tensor,       # [batch_size, seq_len, embedding_size]
         gold_ud: Tensor,          # [n_ud_arcs, 4]
         gold_eud: Tensor,         # [n_eud_arcs, 4]
-        mask_ud: Tensor,          # [batch_size, seq_len]
-        mask_eud: Tensor          # [batch_size, seq_len]
+        null_mask: Tensor,        # [batch_size, seq_len]
+        padding_mask: Tensor      # [batch_size, seq_len]
     ) -> dict[str, Tensor]:
 
         # - [batch_size, seq_len, hidden_size]
@@ -280,7 +281,8 @@ class DependencyClassifier(nn.Module):
             h_rel_head,
             h_rel_dep,
             gold_arcs=gold_ud,
-            mask=mask_ud
+            null_mask=null_mask,
+            padding_mask=padding_mask
         )
         output_eud = self.dependency_head_eud(
             h_arc_head,
@@ -288,7 +290,9 @@ class DependencyClassifier(nn.Module):
             h_rel_head,
             h_rel_dep,
             gold_arcs=gold_eud,
-            mask=mask_eud
+            # Ignore null mask in E-UD
+            null_mask=torch.ones_like(padding_mask),
+            padding_mask=padding_mask
         )
 
         return {
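
The argmax and cross-entropy changes above follow from the re-oriented score matrix: the new comment defines s_arc[:, i, j] as the score of edge i -> j (token i heads token j), so each dependent's head is recovered by an argmax over dim=1, and the arc loss classifies over candidate heads of a given dependent. A small self-contained sketch of that convention (random scores and hypothetical indices, not repository code):

import torch
import torch.nn.functional as F

batch_size, seq_len = 2, 5
s_arc = torch.randn(batch_size, seq_len, seq_len)  # s_arc[b, i, j]: score that i heads j

# Greedy decoding: for every dependent j, pick the best-scoring head i (column-wise argmax).
pred_heads = s_arc.argmax(dim=1)                   # [batch_size, seq_len]

# Arc loss for one gold arc (b, head, dep): softmax over the head axis of column `dep`.
b, head, dep = 0, 3, 1
loss = F.cross_entropy(s_arc[b, :, dep].unsqueeze(0), torch.tensor([head]))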
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3641289089a079abc37a0858a0a412bc6f031f0755e5b24d229c4bf92ce83976
+oid sha256:1618ac5132f5aa1c8525829b0a8ac2e7a0e38ae184cfd1bdcbe5ded4e90a63ee
 size 1141314800
modeling_parser.py CHANGED
@@ -1,8 +1,6 @@
 from torch import nn
 from torch import LongTensor
 from transformers import PreTrainedModel
-from transformers.modeling_outputs import ModelOutput
-from dataclasses import dataclass
 
 from .configuration import CobaldParserConfig
 from .encoder import WordTransformerEncoder
@@ -17,23 +15,6 @@ from .utils import (
 )
 
 
-@dataclass
-class CobaldParserOutput(ModelOutput):
-    """
-    Output type for CobaldParser.
-    """
-    loss: float = None
-    words: list = None
-    counting_mask: LongTensor = None
-    lemma_rules: LongTensor = None
-    joint_feats: LongTensor = None
-    deps_ud: LongTensor = None
-    deps_eud: LongTensor = None
-    miscs: LongTensor = None
-    deepslots: LongTensor = None
-    semclasses: LongTensor = None
-
-
 class CobaldParser(PreTrainedModel):
     """Morpho-Syntax-Semantic Parser."""
 
@@ -119,8 +100,8 @@ class CobaldParser(PreTrainedModel):
         sent_ids: list[str] = None,
         texts: list[str] = None,
         inference_mode: bool = False
-    ) -> CobaldParserOutput:
-        result = {}
+    ) -> dict:
+        output = {}
 
         # Extra [CLS] token accounts for the case when #NULL is the first token in a sentence.
         words_with_cls = prepend_cls(words)
@@ -129,62 +110,62 @@ class CobaldParser(PreTrainedModel):
         embeddings_without_nulls = self.encoder(words_without_nulls)
         # Predict nulls.
         null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
-        result["counting_mask"] = null_output['preds']
-        result["loss"] = null_output["loss"]
+        output["counting_mask"] = null_output['preds']
+        output["loss"] = null_output["loss"]
 
         # "Teacher forcing": during training, pass the original words (with gold nulls)
         # to the classification heads, so that they are trained upon correct sentences.
         if inference_mode:
             # Restore predicted nulls in the original sentences.
-            result["words"] = add_nulls(words, null_output["preds"])
+            output["words"] = add_nulls(words, null_output["preds"])
         else:
-            result["words"] = words
+            output["words"] = words
 
         # Encode words with nulls.
         # [batch_size, seq_len, embedding_size]
-        embeddings = self.encoder(result["words"])
+        embeddings = self.encoder(output["words"])
 
         # Predict lemmas and morphological features.
         if "lemma_rule" in self.classifiers:
             lemma_output = self.classifiers["lemma_rule"](embeddings, lemma_rules)
-            result["lemma_rules"] = lemma_output['preds']
-            result["loss"] += lemma_output['loss']
+            output["lemma_rules"] = lemma_output['preds']
+            output["loss"] += lemma_output['loss']
 
         if "joint_feats" in self.classifiers:
             joint_feats_output = self.classifiers["joint_feats"](embeddings, joint_feats)
-            result["joint_feats"] = joint_feats_output['preds']
-            result["loss"] += joint_feats_output['loss']
+            output["joint_feats"] = joint_feats_output['preds']
+            output["loss"] += joint_feats_output['loss']
 
         # Predict syntax.
         if "syntax" in self.classifiers:
-            padding_mask = build_padding_mask(result["words"], self.device)
-            null_mask = build_null_mask(result["words"], self.device)
+            padding_mask = build_padding_mask(output["words"], self.device)
+            null_mask = build_null_mask(output["words"], self.device)
             deps_output = self.classifiers["syntax"](
                 embeddings,
                 deps_ud,
                 deps_eud,
-                mask_ud=(padding_mask & ~null_mask),
-                mask_eud=padding_mask
+                null_mask,
+                padding_mask
             )
-            result["deps_ud"] = deps_output['preds_ud']
-            result["deps_eud"] = deps_output['preds_eud']
-            result["loss"] += deps_output['loss_ud'] + deps_output['loss_eud']
+            output["deps_ud"] = deps_output['preds_ud']
+            output["deps_eud"] = deps_output['preds_eud']
+            output["loss"] += deps_output['loss_ud'] + deps_output['loss_eud']
 
         # Predict miscellaneous features.
         if "misc" in self.classifiers:
             misc_output = self.classifiers["misc"](embeddings, miscs)
-            result["miscs"] = misc_output['preds']
-            result["loss"] += misc_output['loss']
+            output["miscs"] = misc_output['preds']
+            output["loss"] += misc_output['loss']
 
         # Predict semantics.
         if "deepslot" in self.classifiers:
             deepslot_output = self.classifiers["deepslot"](embeddings, deepslots)
-            result["deepslots"] = deepslot_output['preds']
-            result["loss"] += deepslot_output['loss']
+            output["deepslots"] = deepslot_output['preds']
+            output["loss"] += deepslot_output['loss']
 
         if "semclass" in self.classifiers:
             semclass_output = self.classifiers["semclass"](embeddings, semclasses)
-            result["semclasses"] = semclass_output['preds']
-            result["loss"] += semclass_output['loss']
+            output["semclasses"] = semclass_output['preds']
+            output["loss"] += semclass_output['loss']
 
-        return CobaldParserOutput(**result)
+        return output
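
With the CobaldParserOutput dataclass removed, forward() now returns a plain dict, so callers switch from attribute access to key access. A hedged usage sketch (the model construction and batch variables are assumed, not shown in this diff):

# Before this commit: output.deps_ud, output.loss (ModelOutput attribute access).
# After this commit: plain dict lookups.
output = model(words=batch_words, inference_mode=True)  # `model`: a loaded CobaldParser
restored = output["words"]        # sentences with predicted #NULL tokens re-inserted
loss = output["loss"]             # summed loss over all enabled classifier heads
ud_arcs = output.get("deps_ud")   # None unless the "syntax" classifier is configured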
training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:09fd75fcff1724f060e15c6d1fd2cd167eab8304208d648ba480d453b4974510
-size 5496
+oid sha256:2d8dec73b5638e57d0bca9bd4ee05cd11ce5aba98bc59f80b4e16231e6e7403f
+size 5905
utils.py CHANGED
@@ -21,7 +21,7 @@ def build_padding_mask(sentences: list[list[str]], device) -> Tensor:
     return _build_condition_mask(sentences, condition_fn=lambda word: True, device=device)
 
 def build_null_mask(sentences: list[list[str]], device) -> Tensor:
-    return _build_condition_mask(sentences, condition_fn=lambda word: word == "#NULL", device=device)
+    return _build_condition_mask(sentences, condition_fn=lambda word: word != "#NULL", device=device)
 
 
 def pairwise_mask(masks1d: Tensor) -> Tensor:
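
Note the polarity flip: build_null_mask is now True for every token that is not "#NULL", which matches the new call sites in dependency_classifier.py and modeling_parser.py, where null_mask & padding_mask selects real, non-padding tokens as UD arc endpoints (the old code computed padding_mask & ~null_mask instead). A minimal illustration of the flipped semantics; the tensors are hand-built here, assuming _build_condition_mask maps condition_fn over each word:

import torch

words = [["The", "#NULL", "cat"]]
# New convention: True everywhere except #NULL tokens.
null_mask = torch.tensor([[w != "#NULL" for w in words[0]]])  # [[True, False, True]]
padding_mask = torch.ones_like(null_mask)                     # no padding in this toy batch
keep = null_mask & padding_mask                               # eligible UD arc endpoints
print(keep)  # tensor([[ True, False,  True]])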