Model save

Files changed:
- README.md (+8, -8)
- config.json (+197, -197)
- dependency_classifier.py (+46, -42)
- model.safetensors (+1, -1)
- modeling_parser.py (+25, -44)
- training_args.bin (+2, -2)
- utils.py (+1, -1)
README.md
CHANGED
@@ -21,28 +21,28 @@ model-index:
       split: validation
     metrics:
     - type: f1
-      value: 0.
+      value: 0.9270548177755096
       name: Null F1
     - type: f1
-      value: 0.
+      value: 0.8339235583777782
       name: Lemma F1
     - type: f1
-      value: 0.
+      value: 0.7885002678867238
       name: Morphology F1
     - type: accuracy
-      value: 0.
+      value: 0.7653227685854114
       name: Ud Jaccard
     - type: accuracy
-      value: 0.
+      value: 0.7962406996475656
       name: Eud Jaccard
     - type: f1
-      value: 0.
+      value: 0.6438483915854029
       name: Miscs F1
     - type: f1
-      value: 0.
+      value: 0.6179291073868571
       name: Deepslot F1
     - type: f1
-      value: 0.
+      value: 0.6220501826358034
       name: Semclass F1
 ---
 
config.json
CHANGED
@@ -25,7 +25,7 @@
   "null_classifier_hidden_size": 512,
   "semclass_classifier_hidden_size": 512,
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.52.3",
   "vocabulary": {
     "deepslot": {
       "0": "$Dislocation",
@@ -370,213 +370,213 @@
       "1": "ADJ#Adjective#Degree=Cmp",
       "2": "ADJ#Adjective#Degree=Pos",
       "3": "ADJ#Adjective#Degree=Sup",
-      "4": "ADJ#
-      "5": "ADJ#
-      "6": "ADJ#
-      "7": "ADJ#
-      "8": "ADJ#
-      "9": "ADJ#
-      "10": "ADJ#
-      "11": "ADJ#
-      "12": "ADP#Adverb#
-      "13": "ADP#
-      "14": "ADP#
+      "4": "ADJ#Numeral#Degree=Pos|NumForm=Digit|NumType=Ord",
+      "5": "ADJ#Numeral#Degree=Pos|NumForm=Word|NumType=Ord",
+      "6": "ADJ#Prefixoid#_",
+      "7": "ADJ#_#Degree=Cmp",
+      "8": "ADJ#_#Degree=Pos",
+      "9": "ADJ#_#Degree=Pos|NumType=Ord",
+      "10": "ADJ#_#Degree=Sup",
+      "11": "ADJ#_#_",
+      "12": "ADP#Adverb#_",
+      "13": "ADP#Preposition#_",
+      "14": "ADP#_#_",
       "15": "ADV#Adjective#Degree=Pos",
       "16": "ADV#Adverb#Degree=Cmp",
       "17": "ADV#Adverb#Degree=Pos",
       "18": "ADV#Adverb#Degree=Pos|NumType=Mult",
       "19": "ADV#Adverb#Degree=Sup",
-      "20": "ADV#Adverb#
-      "21": "ADV#Adverb#
-      "22": "ADV#Adverb#
-      "23": "ADV#Adverb#
+      "20": "ADV#Adverb#NumType=Mult",
+      "21": "ADV#Adverb#Polarity=Neg",
+      "22": "ADV#Adverb#PronType=Dem",
+      "23": "ADV#Adverb#_",
       "24": "ADV#Invariable#Degree=Cmp",
-      "25": "ADV#Invariable#
-      "26": "ADV#
-      "27": "ADV#
-      "28": "ADV#
-      "29": "ADV#
-      "30": "ADV#
-      "31": "ADV#
-      "32": "ADV#
-      "33": "ADV#
-      "34": "AUX#
-      "35": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=
-      "36": "AUX#Verb#Mood=Ind|Number=Plur|Person=
-      "37": "AUX#Verb#Mood=Ind|Number=Plur|Person=
-      "38": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=
-      "39": "AUX#Verb#Mood=Ind|Number=
-      "40": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=
-      "41": "AUX#Verb#Mood=Ind|Number=Sing|Person=
-      "42": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=
-      "43": "AUX#Verb#Mood=Ind|Number=Sing|Person=
-      "44": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=
-      "45": "AUX#Verb#Mood=
-      "46": "AUX#Verb#Mood=Sub|Number=Plur|
-      "47": "AUX#Verb#
-      "48": "AUX#Verb#Number=Plur|Tense=
-      "49": "AUX#Verb#
-      "50": "AUX#Verb#VerbForm=
-      "51": "AUX#Verb#VerbForm=
-      "52": "AUX#
-      "53": "CCONJ#Conjunction#
-      "54": "CCONJ#
+      "25": "ADV#Invariable#_",
+      "26": "ADV#Prefixoid#_",
+      "27": "ADV#_#Degree=Cmp",
+      "28": "ADV#_#Degree=Pos",
+      "29": "ADV#_#Degree=Sup",
+      "30": "ADV#_#NumType=Mult",
+      "31": "ADV#_#PronType=Dem",
+      "32": "ADV#_#PronType=Int",
+      "33": "ADV#_#_",
+      "34": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "35": "AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin",
+      "36": "AUX#Verb#Mood=Ind|Number=Plur|Person=2|Tense=Pres|VerbForm=Fin",
+      "37": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin",
+      "38": "AUX#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin",
+      "39": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Past|VerbForm=Fin",
+      "40": "AUX#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Pres|VerbForm=Fin",
+      "41": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin",
+      "42": "AUX#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin",
+      "43": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin",
+      "44": "AUX#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+      "45": "AUX#Verb#Mood=Sub|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "46": "AUX#Verb#Mood=Sub|Number=Plur|Tense=Past|VerbForm=Part",
+      "47": "AUX#Verb#Number=Plur|Tense=Past|VerbForm=Part",
+      "48": "AUX#Verb#Number=Plur|Tense=Pres|VerbForm=Part",
+      "49": "AUX#Verb#VerbForm=Fin",
+      "50": "AUX#Verb#VerbForm=Ger",
+      "51": "AUX#Verb#VerbForm=Inf",
+      "52": "AUX#_#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+      "53": "CCONJ#Conjunction#_",
+      "54": "CCONJ#_#_",
       "55": "DET#Adjective#PronType=Tot",
       "56": "DET#Article#Definite=Def|PronType=Art",
       "57": "DET#Article#Definite=Ind|PronType=Art",
       "58": "DET#Conjunction#Definite=Def|PronType=Art",
-      "59": "DET#
-      "60": "DET#
-      "61": "DET#
-      "62": "DET#
-      "63": "DET#
-      "64": "DET#
-      "65": "DET#
-      "66": "DET#
-      "67": "DET#
-      "68": "DET#
-      "69": "DET#
-      "70": "DET#
-      "71": "DET#
-      "72": "DET#
-      "73": "DET#
-      "74": "DET#
-      "75": "DET#
-      "76": "DET#
-      "77": "INTJ#Interjection#
+      "59": "DET#Prefixoid#_",
+      "60": "DET#Pronoun#Number=Plur|PronType=Dem",
+      "61": "DET#Pronoun#Number=Sing|PronType=Dem",
+      "62": "DET#Pronoun#Polarity=Neg",
+      "63": "DET#Pronoun#PronType=Ind",
+      "64": "DET#Pronoun#PronType=Int",
+      "65": "DET#Pronoun#PronType=Rel",
+      "66": "DET#Pronoun#PronType=Tot",
+      "67": "DET#Pronoun#_",
+      "68": "DET#_#Definite=Def|PronType=Art",
+      "69": "DET#_#Definite=EMPTY",
+      "70": "DET#_#Definite=Ind|PronType=Art",
+      "71": "DET#_#Number=Sing|PronType=Dem",
+      "72": "DET#_#PronType=Int",
+      "73": "DET#_#PronType=Neg",
+      "74": "DET#_#PronType=Rcp",
+      "75": "DET#_#PronType=Tot",
+      "76": "DET#_#_",
+      "77": "INTJ#Interjection#_",
       "78": "NOUN#Adverb#Number=Sing",
-      "79": "NOUN#
-      "80": "NOUN#
-      "81": "NOUN#Noun#
-      "82": "NOUN#Noun#
-      "83": "NOUN#Noun#
-      "84": "NOUN#Noun#Number=
-      "85": "NOUN#Noun#
-      "86": "NOUN#
-      "87": "NOUN#
-      "88": "NOUN#
-      "89": "NOUN#
-      "90": "NUM#
-      "91": "NUM#
-      "92": "NUM#
-      "93": "NUM#Numeral#
-      "94": "NUM#Numeral#NumForm=
-      "95": "NUM#Numeral#
-      "96": "NUM#Numeral#
-      "97": "NUM#
-      "98": "NUM#
-      "99": "PART#
-      "100": "PART#
-      "101": "PART#
-      "102": "PART#
-      "103": "PPROPN#
-      "104": "PRON#
-      "105": "PRON#
-      "106": "PRON#
-      "107": "PRON#
-      "108": "PRON#
-      "109": "PRON#
-      "110": "PRON#Pronoun#Case=Acc|
-      "111": "PRON#Pronoun#Case=Acc|
-      "112": "PRON#Pronoun#Case=Acc|
-      "113": "PRON#Pronoun#Case=Acc|
-      "114": "PRON#Pronoun#Case=Acc|
-      "115": "PRON#Pronoun#Case=Acc|
-      "116": "PRON#Pronoun#Case=Acc|Number=
-      "117": "PRON#Pronoun#Case=Acc|Number=
-      "118": "PRON#Pronoun#Case=
-      "119": "PRON#Pronoun#Case=
-      "120": "PRON#Pronoun#Case=
-      "121": "PRON#Pronoun#Case=
-      "122": "PRON#Pronoun#Case=
-      "123": "PRON#Pronoun#Case=
-      "124": "PRON#Pronoun#Case=Gen|
-      "125": "PRON#Pronoun#Case=
-      "126": "PRON#Pronoun#Case=
-      "127": "PRON#Pronoun#Case=
-      "128": "PRON#Pronoun#Case=
-      "129": "PRON#Pronoun#Case=
-      "130": "PRON#Pronoun#Case=
-      "131": "PRON#Pronoun#Case=Nom|
-      "132": "PRON#Pronoun#Case=Nom|
-      "133": "PRON#Pronoun#Case=Nom|
-      "134": "PRON#Pronoun#Case=Nom|
-      "135": "PRON#Pronoun#Case=Nom|
-      "136": "PRON#Pronoun#
-      "137": "PRON#Pronoun#
-      "138": "PRON#Pronoun#
-      "139": "PRON#Pronoun#
-      "140": "PRON#Pronoun#
-      "141": "PRON#Pronoun#
-      "142": "PRON#Pronoun#
-      "143": "PRON#Pronoun#Number=
-      "144": "PRON#Pronoun#Number=
-      "145": "PRON#Pronoun#
-      "146": "PRON#Pronoun#
-      "147": "PRON#Pronoun#
-      "148": "PRON#Pronoun#
-      "149": "PRON#
-      "150": "PRON#
-      "151": "PRON#
-      "152": "PRON#
-      "153": "PRON#
-      "154": "PRON#
-      "155": "PROPN#
-      "156": "PROPN#
-      "157": "PROPN#
-      "158": "PROPN#Noun#
-      "159": "PROPN#Noun#
-      "160": "PROPN#Noun#
-      "161": "PROPN#Noun#
-      "162": "PROPN#
-      "163": "PROPN#
-      "164": "PROPN#
-      "165": "PROPN#
-      "166": "PUNCT#
-      "167": "PUNCT#
-      "168": "Prefixoid#Prefixoid#
-      "169": "SCONJ#Conjunction#
-      "170": "SCONJ#
-      "171": "SYM#Conjunction#
-      "172": "SYM#Noun#
-      "173": "SYM#Noun#
-      "174": "VERB#
-      "175": "VERB#
-      "176": "VERB#
-      "177": "VERB#
-      "178": "VERB#Verb#Mood=
-      "179": "VERB#Verb#Mood=Ind|Number=Plur|Person=
-      "180": "VERB#Verb#Mood=Ind|Number=
-      "181": "VERB#Verb#Mood=Ind|Number=
-      "182": "VERB#Verb#Mood=Ind|Number=
-      "183": "VERB#Verb#Mood=Ind|Number=
-      "184": "VERB#Verb#Mood=Ind|Number=Sing|Person=
-      "185": "VERB#Verb#Mood=Ind|Number=Sing|Person=
-      "186": "VERB#Verb#Mood=Ind|Number=Sing|Person=
-      "187": "VERB#Verb#Mood=
-      "188": "VERB#Verb#Mood=
-      "189": "VERB#Verb#Mood=
-      "190": "VERB#Verb#Mood=
-      "191": "VERB#Verb#
-      "192": "VERB#Verb#
-      "193": "VERB#Verb#
-      "194": "VERB#Verb#
-      "195": "VERB#Verb#Person=1|Tense=
-      "196": "VERB#Verb#Person=
-      "197": "VERB#Verb#
-      "198": "VERB#Verb#
-      "199": "VERB#Verb#
-      "200": "VERB#Verb#
-      "201": "VERB#Verb#
-      "202": "VERB#Verb#
-      "203": "VERB#
-      "204": "VERB#
-      "205": "VERB#
-      "206": "VERB#
-      "207": "X#
-      "208": "X#
-      "209": "X#
-      "210": "X#
+      "79": "NOUN#Noun#Abbr=Yes|Number=Plur",
+      "80": "NOUN#Noun#Abbr=Yes|Number=Sing",
+      "81": "NOUN#Noun#NumType=Frac|Number=Sing",
+      "82": "NOUN#Noun#Number=Plur",
+      "83": "NOUN#Noun#Number=Sing",
+      "84": "NOUN#Noun#Number=Sing|Polarity=Neg",
+      "85": "NOUN#Noun#VerbForm=Fin",
+      "86": "NOUN#Prefixoid#Number=Sing",
+      "87": "NOUN#Prefixoid#_",
+      "88": "NOUN#_#Number=Plur",
+      "89": "NOUN#_#Number=Sing",
+      "90": "NUM#Noun#NumForm=Word|NumType=Card",
+      "91": "NUM#Numeral#NumForm=Digit|NumType=Card",
+      "92": "NUM#Numeral#NumForm=Digit|NumType=Frac",
+      "93": "NUM#Numeral#NumForm=Roman|NumType=Card",
+      "94": "NUM#Numeral#NumForm=Word|NumType=Card",
+      "95": "NUM#Numeral#NumType=Card",
+      "96": "NUM#Numeral#_",
+      "97": "NUM#_#Degree=Pos|NumType=Ord",
+      "98": "NUM#_#NumType=Card",
+      "99": "PART#Particle#Polarity=Neg",
+      "100": "PART#Particle#_",
+      "101": "PART#_#Polarity=Neg",
+      "102": "PART#_#_",
+      "103": "PPROPN#_#Number=Plur",
+      "104": "PRON#Pronoun#Case=Acc|Gender=Fem|Number=Sing|Person=3|PronType=Prs",
+      "105": "PRON#Pronoun#Case=Acc|Gender=Fem|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "106": "PRON#Pronoun#Case=Acc|Gender=Masc|Number=Sing|Person=3|PronType=Prs",
+      "107": "PRON#Pronoun#Case=Acc|Gender=Masc|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "108": "PRON#Pronoun#Case=Acc|Gender=Neut|Number=Sing|Person=3|PronType=Prs",
+      "109": "PRON#Pronoun#Case=Acc|Gender=Neut|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "110": "PRON#Pronoun#Case=Acc|Number=Plur|Person=1|PronType=Prs",
+      "111": "PRON#Pronoun#Case=Acc|Number=Plur|Person=1|PronType=Prs|Reflex=Yes",
+      "112": "PRON#Pronoun#Case=Acc|Number=Plur|Person=2|PronType=Prs",
+      "113": "PRON#Pronoun#Case=Acc|Number=Plur|Person=3|PronType=Prs",
+      "114": "PRON#Pronoun#Case=Acc|Number=Plur|Person=3|PronType=Prs|Reflex=Yes",
+      "115": "PRON#Pronoun#Case=Acc|Number=Sing|Person=1|PronType=Prs",
+      "116": "PRON#Pronoun#Case=Acc|Number=Sing|Person=2|PronType=Prs",
+      "117": "PRON#Pronoun#Case=Acc|Number=Sing|Person=2|PronType=Prs|Reflex=Yes",
+      "118": "PRON#Pronoun#Case=Gen|Gender=Fem|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "119": "PRON#Pronoun#Case=Gen|Gender=Masc|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "120": "PRON#Pronoun#Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "121": "PRON#Pronoun#Case=Gen|Number=Plur|Person=1|Poss=Yes|PronType=Prs",
+      "122": "PRON#Pronoun#Case=Gen|Number=Plur|Person=3|Poss=Yes|PronType=Prs",
+      "123": "PRON#Pronoun#Case=Gen|Number=Sing|Person=1|Poss=Yes|PronType=Prs",
+      "124": "PRON#Pronoun#Case=Gen|Number=Sing|Person=2|Poss=Yes|PronType=Prs",
+      "125": "PRON#Pronoun#Case=Nom|Gender=Fem|Number=Sing|Person=3|PronType=Prs",
+      "126": "PRON#Pronoun#Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs",
+      "127": "PRON#Pronoun#Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "128": "PRON#Pronoun#Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs",
+      "129": "PRON#Pronoun#Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs|Reflex=Yes",
+      "130": "PRON#Pronoun#Case=Nom|Number=Plur|Person=1|PronType=Prs",
+      "131": "PRON#Pronoun#Case=Nom|Number=Plur|Person=2|PronType=Prs",
+      "132": "PRON#Pronoun#Case=Nom|Number=Plur|Person=3|PronType=Prs",
+      "133": "PRON#Pronoun#Case=Nom|Number=Plur|Person=3|PronType=Prs|Reflex=Yes",
+      "134": "PRON#Pronoun#Case=Nom|Number=Sing|Person=1|PronType=Prs",
+      "135": "PRON#Pronoun#Case=Nom|Number=Sing|Person=2|PronType=Prs",
+      "136": "PRON#Pronoun#Number=Plur",
+      "137": "PRON#Pronoun#Number=Plur|PronType=Dem",
+      "138": "PRON#Pronoun#Number=Plur|PronType=Tot",
+      "139": "PRON#Pronoun#Number=Sing",
+      "140": "PRON#Pronoun#Number=Sing|Polarity=Neg|PronType=Neg",
+      "141": "PRON#Pronoun#Number=Sing|PronType=Dem",
+      "142": "PRON#Pronoun#Number=Sing|PronType=Ind",
+      "143": "PRON#Pronoun#Number=Sing|PronType=Neg",
+      "144": "PRON#Pronoun#Number=Sing|Reflex=Yes",
+      "145": "PRON#Pronoun#PronType=Ind",
+      "146": "PRON#Pronoun#PronType=Int",
+      "147": "PRON#Pronoun#PronType=Rel",
+      "148": "PRON#Pronoun#_",
+      "149": "PRON#_#Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs",
+      "150": "PRON#_#Number=Sing",
+      "151": "PRON#_#Number=Sing|PronType=Dem",
+      "152": "PRON#_#Number=Sing|PronType=Ind",
+      "153": "PRON#_#PronType=Int",
+      "154": "PRON#_#PronType=Rel",
+      "155": "PROPN#Noun#Abbr=Yes|Number=Plur",
+      "156": "PROPN#Noun#Abbr=Yes|Number=Sing",
+      "157": "PROPN#Noun#Number=Plur",
+      "158": "PROPN#Noun#Number=Sing",
+      "159": "PROPN#Noun#Number=Sing|Polarity=Neg",
+      "160": "PROPN#Noun#PronType=Dem",
+      "161": "PROPN#Noun#VerbForm=Fin",
+      "162": "PROPN#Prefixoid#Number=Sing",
+      "163": "PROPN#_#Abbr=Yes",
+      "164": "PROPN#_#Number=Plur",
+      "165": "PROPN#_#Number=Sing",
+      "166": "PUNCT#PUNCT#_",
+      "167": "PUNCT#_#_",
+      "168": "Prefixoid#Prefixoid#_",
+      "169": "SCONJ#Conjunction#_",
+      "170": "SCONJ#_#_",
+      "171": "SYM#Conjunction#_",
+      "172": "SYM#Noun#Number=Sing",
+      "173": "SYM#Noun#_",
+      "174": "VERB#Verb#Mood=Imp|VerbForm=Inf",
+      "175": "VERB#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "176": "VERB#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Pres|VerbForm=Fin",
+      "177": "VERB#Verb#Mood=Ind|Number=Plur|Person=2|Tense=Pres|VerbForm=Fin",
+      "178": "VERB#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Past|VerbForm=Fin",
+      "179": "VERB#Verb#Mood=Ind|Number=Plur|Person=3|Tense=Pres|VerbForm=Fin",
+      "180": "VERB#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Past|VerbForm=Fin",
+      "181": "VERB#Verb#Mood=Ind|Number=Sing|Person=1|Tense=Pres|VerbForm=Fin",
+      "182": "VERB#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Past|VerbForm=Fin",
+      "183": "VERB#Verb#Mood=Ind|Number=Sing|Person=2|Tense=Pres|VerbForm=Fin",
+      "184": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Polarity=Neg|Tense=Pres|VerbForm=Fin",
+      "185": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin",
+      "186": "VERB#Verb#Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin",
+      "187": "VERB#Verb#Mood=Sub|Number=Plur|Person=1|Tense=Past|VerbForm=Fin",
+      "188": "VERB#Verb#Mood=Sub|Tense=Past|VerbForm=Part",
+      "189": "VERB#Verb#Mood=Sub|Tense=Past|VerbForm=Part|Voice=Pass",
+      "190": "VERB#Verb#Mood=Sub|VerbForm=Inf",
+      "191": "VERB#Verb#Person=1|Tense=Past|VerbForm=Part",
+      "192": "VERB#Verb#Person=1|Tense=Past|VerbForm=Part|Voice=Pass",
+      "193": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Ger",
+      "194": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Inf",
+      "195": "VERB#Verb#Person=1|Tense=Pres|VerbForm=Part",
+      "196": "VERB#Verb#Person=2|Tense=Pres|VerbForm=Inf",
+      "197": "VERB#Verb#Tense=Past|VerbForm=Part",
+      "198": "VERB#Verb#Tense=Past|VerbForm=Part|Voice=Pass",
+      "199": "VERB#Verb#Tense=Pres|VerbForm=Part",
+      "200": "VERB#Verb#VerbForm=Fin",
+      "201": "VERB#Verb#VerbForm=Ger",
+      "202": "VERB#Verb#VerbForm=Inf",
+      "203": "VERB#_#Mood=Ind|Tense=Past|VerbForm=Fin",
+      "204": "VERB#_#Tense=Past|VerbForm=Part",
+      "205": "VERB#_#VerbForm=Ger",
+      "206": "VERB#_#VerbForm=Inf",
+      "207": "X#_#Foreign=Yes",
+      "208": "X#_#Typo=Yes",
+      "209": "X#_#_",
+      "210": "X#_#foreign=Yes"
     },
     "lemma_rule": {
       "0": "cut_prefix=0|cut_suffix=0|append_suffix=",
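The long map above is the label inventory for one of the joint classification heads: each entry packs a UPOS tag, a source-grammar category, and a `|`-separated morphological feature bundle into a single string joined by `#`, with `_` marking an empty field. A minimal sketch of splitting such a label back into its parts (the helper name below is illustrative, not part of this repo):

```python
def split_joint_label(label: str) -> tuple[str, str, dict[str, str]]:
    """Split e.g. "ADJ#Adjective#Degree=Cmp" into (upos, category, feats)."""
    upos, category, feats_str = label.split("#")
    feats = {}
    if feats_str != "_":
        feats = dict(pair.split("=", 1) for pair in feats_str.split("|"))
    return upos, category, feats

assert split_joint_label("AUX#Verb#Mood=Ind|Number=Plur|Person=1|Tense=Past|VerbForm=Fin") == (
    "AUX", "Verb",
    {"Mood": "Ind", "Number": "Plur", "Person": "1", "Tense": "Past", "VerbForm": "Fin"},
)
```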
dependency_classifier.py
CHANGED
@@ -38,19 +38,21 @@ class DependencyHeadBase(nn.Module):
 
     def forward(
         self,
-        h_arc_head: Tensor,
-        h_arc_dep: Tensor,
-        h_rel_head: Tensor,
-        h_rel_dep: Tensor,
-        gold_arcs: LongTensor,
-
+        h_arc_head: Tensor,       # [batch_size, seq_len, hidden_size]
+        h_arc_dep: Tensor,        # ...
+        h_rel_head: Tensor,       # ...
+        h_rel_dep: Tensor,        # ...
+        gold_arcs: LongTensor,    # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
     ) -> dict[str, Tensor]:
-
+
         # Score arcs.
-        # s_arc[:, i, j] = score of edge
+        # s_arc[:, i, j] = score of edge i -> j.
        s_arc = self.arc_attention(h_arc_head, h_arc_dep)
         # Mask undesirable values (padding, nulls, etc.) with -inf.
-
+        mask2d = pairwise_mask(null_mask & padding_mask)
+        replace_masked_values(s_arc, mask2d, replace_with=-1e8)
         # Score arcs' relations.
         # [batch_size, seq_len, seq_len, num_labels]
         s_rel = self.rel_attention(h_rel_head, h_rel_dep).permute(0, 2, 3, 1)
@@ -63,11 +65,11 @@ class DependencyHeadBase(nn.Module):
 
         # Predict arcs based on the scores.
         # [batch_size, seq_len, seq_len]
-
+        pred_arcs_matrix = self.predict_arcs(s_arc, null_mask, padding_mask)
         # [batch_size, seq_len, seq_len]
-
+        pred_rels_matrix = self.predict_rels(s_rel)
         # [n_pred_arcs, 4]
-        preds_combined = self.combine_arcs_rels(
+        preds_combined = self.combine_arcs_rels(pred_arcs_matrix, pred_rels_matrix)
         return {
             'preds': preds_combined,
             'loss': loss
@@ -91,8 +93,9 @@ class DependencyHeadBase(nn.Module):
 
     def predict_arcs(
         self,
-        s_arc: Tensor,
-
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
     ) -> LongTensor:
         """Predict arcs from scores."""
         raise NotImplementedError
@@ -127,42 +130,40 @@ class DependencyHead(DependencyHeadBase):
     @override
     def predict_arcs(
         self,
-        s_arc: Tensor,
-
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len, seq_len]
     ) -> Tensor:
 
         if self.training:
             # During training, use fast greedy decoding.
             # - [batch_size, seq_len]
-            pred_arcs_seq = s_arc.argmax(dim=
+            pred_arcs_seq = s_arc.argmax(dim=1)
         else:
-            # During inference,
-            pred_arcs_seq = self._mst_decode(s_arc,
-            # FIXME
-            # pred_arcs_seq = s_arc.argmax(dim=-1)
+            # During inference, decode Maximum Spanning Tree.
+            pred_arcs_seq = self._mst_decode(s_arc, padding_mask)
 
         # Upscale arcs sequence of shape [batch_size, seq_len]
         # to matrix of shape [batch_size, seq_len, seq_len].
-        pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long()
+        pred_arcs = F.one_hot(pred_arcs_seq, num_classes=pred_arcs_seq.size(1)).long().transpose(1, 2)
+        # Apply mask one more time (even though s_arc is already masked),
+        # because argmax erases information about masked values.
+        mask2d = pairwise_mask(null_mask & padding_mask)
+        replace_masked_values(pred_arcs, mask2d, replace_with=0)
         return pred_arcs
 
     def _mst_decode(
         self,
-        s_arc: Tensor,
-
+        s_arc: Tensor,  # [batch_size, seq_len, seq_len]
+        padding_mask: Tensor
     ) -> tuple[Tensor, Tensor]:
-
+
         batch_size = s_arc.size(0)
         device = s_arc.device
         s_arc = s_arc.cpu()
 
         # Convert scores to probabilities, as `decode_mst` expects non-negative values.
-        arc_probs = nn.functional.softmax(s_arc, dim=
-        # Transpose arcs, because decode_mst defines 'energy' matrix as
-        # energy[i,j] = "Score that `i` is the head of `j`",
-        # whereas
-        # arc_probs[i,j] = "Probability that `j` is the head of `i`".
-        arc_probs = arc_probs.transpose(1, 2)
+        arc_probs = nn.functional.softmax(s_arc, dim=1)
 
         # `decode_mst` knows nothing about UD and ROOT, so we have to manually
         # zero probabilities of arcs leading to ROOT to make sure ROOT is a source node
@@ -177,11 +178,10 @@ class DependencyHead(DependencyHeadBase):
         pred_arcs = []
         for sample_idx in range(batch_size):
             energy = arc_probs[sample_idx]
-
-
-            heads, _ = decode_mst(energy, lengths, has_labels=False)
+            length = padding_mask[sample_idx].sum()
+            heads = decode_mst(energy, length)
             # Some nodes may be isolated. Pick heads greedily in this case.
-            heads[heads <= 0] = s_arc[sample_idx].argmax(dim=
+            heads[heads <= 0] = s_arc[sample_idx].argmax(dim=1)[heads <= 0]
             pred_arcs.append(heads)
 
         # shape: [batch_size, seq_len]
@@ -195,7 +195,7 @@ class DependencyHead(DependencyHeadBase):
         gold_arcs: LongTensor  # [n_arcs, 4]
     ) -> tuple[Tensor, Tensor]:
         batch_idxs, from_idxs, to_idxs, _ = gold_arcs.T
-        return F.cross_entropy(s_arc[batch_idxs,
+        return F.cross_entropy(s_arc[batch_idxs, :, to_idxs], from_idxs)
 
 
 class MultiDependencyHead(DependencyHeadBase):
@@ -206,8 +206,9 @@ class MultiDependencyHead(DependencyHeadBase):
     @override
     def predict_arcs(
         self,
-        s_arc: Tensor,
-
+        s_arc: Tensor,            # [batch_size, seq_len, seq_len]
+        null_mask: BoolTensor,    # [batch_size, seq_len]
+        padding_mask: BoolTensor  # [batch_size, seq_len]
     ) -> Tensor:
         # Convert scores to probabilities.
         arc_probs = torch.sigmoid(s_arc)
@@ -263,8 +264,8 @@ class DependencyClassifier(nn.Module):
         embeddings: Tensor,  # [batch_size, seq_len, embedding_size]
         gold_ud: Tensor,     # [n_ud_arcs, 4]
         gold_eud: Tensor,    # [n_eud_arcs, 4]
-
-
+        null_mask: Tensor,   # [batch_size, seq_len]
+        padding_mask: Tensor # [batch_size, seq_len]
     ) -> dict[str, Tensor]:
 
         # - [batch_size, seq_len, hidden_size]
@@ -280,7 +281,8 @@ class DependencyClassifier(nn.Module):
             h_rel_head,
             h_rel_dep,
             gold_arcs=gold_ud,
-
+            null_mask=null_mask,
+            padding_mask=padding_mask
         )
         output_eud = self.dependency_head_eud(
             h_arc_head,
@@ -288,7 +290,9 @@ class DependencyClassifier(nn.Module):
             h_rel_head,
             h_rel_dep,
             gold_arcs=gold_eud,
-            mask
+            # Ignore null mask in E-UD
+            null_mask=torch.ones_like(padding_mask),
+            padding_mask=padding_mask
         )
 
         return {
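The new masking logic above leans on two helpers. `pairwise_mask` is declared in utils.py (see the end of this diff); `replace_masked_values` is not shown in this commit, so the body below is an assumption consistent with its call sites (in-place fill of masked-out positions). Note that since `s_arc[:, i, j]` scores the edge `i -> j`, `argmax(dim=1)` selects the best head `i` for every dependent `j`.

```python
import torch
from torch import Tensor

def pairwise_mask(masks1d: Tensor) -> Tensor:
    # Lift a [batch_size, seq_len] mask to [batch_size, seq_len, seq_len]:
    # cell (i, j) is valid only if both tokens i and j are valid.
    return masks1d.unsqueeze(2) & masks1d.unsqueeze(1)

def replace_masked_values(tensor: Tensor, mask: Tensor, replace_with: float) -> None:
    # Assumed semantics: in-place overwrite wherever the mask is False.
    tensor.masked_fill_(~mask, replace_with)

s_arc = torch.randn(1, 4, 4)
valid = torch.tensor([[True, True, True, False]])  # last token is padding
replace_masked_values(s_arc, pairwise_mask(valid), replace_with=-1e8)
heads = s_arc.argmax(dim=1)  # best head index for every dependent
```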
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:1618ac5132f5aa1c8525829b0a8ac2e7a0e38ae184cfd1bdcbe5ded4e90a63ee
 size 1141314800
modeling_parser.py
CHANGED
@@ -1,8 +1,6 @@
 from torch import nn
 from torch import LongTensor
 from transformers import PreTrainedModel
-from transformers.modeling_outputs import ModelOutput
-from dataclasses import dataclass
 
 from .configuration import CobaldParserConfig
 from .encoder import WordTransformerEncoder
@@ -17,23 +15,6 @@ from .utils import (
 )
 
 
-@dataclass
-class CobaldParserOutput(ModelOutput):
-    """
-    Output type for CobaldParser.
-    """
-    loss: float = None
-    words: list = None
-    counting_mask: LongTensor = None
-    lemma_rules: LongTensor = None
-    joint_feats: LongTensor = None
-    deps_ud: LongTensor = None
-    deps_eud: LongTensor = None
-    miscs: LongTensor = None
-    deepslots: LongTensor = None
-    semclasses: LongTensor = None
-
-
 class CobaldParser(PreTrainedModel):
     """Morpho-Syntax-Semantic Parser."""
 
@@ -119,8 +100,8 @@ class CobaldParser(PreTrainedModel):
         sent_ids: list[str] = None,
         texts: list[str] = None,
         inference_mode: bool = False
-    ) ->
-
+    ) -> dict:
+        output = {}
 
         # Extra [CLS] token accounts for the case when #NULL is the first token in a sentence.
         words_with_cls = prepend_cls(words)
@@ -129,62 +110,62 @@ class CobaldParser(PreTrainedModel):
         embeddings_without_nulls = self.encoder(words_without_nulls)
         # Predict nulls.
         null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
-
-
+        output["counting_mask"] = null_output['preds']
+        output["loss"] = null_output["loss"]
 
         # "Teacher forcing": during training, pass the original words (with gold nulls)
         # to the classification heads, so that they are trained upon correct sentences.
         if inference_mode:
             # Restore predicted nulls in the original sentences.
-
+            output["words"] = add_nulls(words, null_output["preds"])
         else:
-
+            output["words"] = words
 
         # Encode words with nulls.
         # [batch_size, seq_len, embedding_size]
-        embeddings = self.encoder(
+        embeddings = self.encoder(output["words"])
 
         # Predict lemmas and morphological features.
         if "lemma_rule" in self.classifiers:
             lemma_output = self.classifiers["lemma_rule"](embeddings, lemma_rules)
-
-
+            output["lemma_rules"] = lemma_output['preds']
+            output["loss"] += lemma_output['loss']
 
         if "joint_feats" in self.classifiers:
             joint_feats_output = self.classifiers["joint_feats"](embeddings, joint_feats)
-
-
+            output["joint_feats"] = joint_feats_output['preds']
+            output["loss"] += joint_feats_output['loss']
 
         # Predict syntax.
         if "syntax" in self.classifiers:
-            padding_mask = build_padding_mask(
-            null_mask = build_null_mask(
+            padding_mask = build_padding_mask(output["words"], self.device)
+            null_mask = build_null_mask(output["words"], self.device)
             deps_output = self.classifiers["syntax"](
                 embeddings,
                 deps_ud,
                 deps_eud,
-
-
+                null_mask,
+                padding_mask
             )
-
-
-
+            output["deps_ud"] = deps_output['preds_ud']
+            output["deps_eud"] = deps_output['preds_eud']
+            output["loss"] += deps_output['loss_ud'] + deps_output['loss_eud']
 
         # Predict miscellaneous features.
         if "misc" in self.classifiers:
             misc_output = self.classifiers["misc"](embeddings, miscs)
-
-
+            output["miscs"] = misc_output['preds']
+            output["loss"] += misc_output['loss']
 
         # Predict semantics.
         if "deepslot" in self.classifiers:
             deepslot_output = self.classifiers["deepslot"](embeddings, deepslots)
-
-
+            output["deepslots"] = deepslot_output['preds']
+            output["loss"] += deepslot_output['loss']
 
         if "semclass" in self.classifiers:
             semclass_output = self.classifiers["semclass"](embeddings, semclasses)
-
-
+            output["semclasses"] = semclass_output['preds']
+            output["loss"] += semclass_output['loss']
 
-        return
+        return output
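With `CobaldParserOutput` removed, `forward` now returns a plain dict whose keys mirror the old dataclass fields, and `output["loss"]` accumulates the losses of every enabled head. A hypothetical call (the `parser` and `batch` names are illustrative only, not from this repo):

```python
# Hypothetical usage; `parser` is a loaded CobaldParser, `batch` a collated batch.
output = parser(
    words=batch["words"],
    inference_mode=True,  # restore predicted #NULLs instead of teacher forcing
)
loss = output["loss"]            # sum over all enabled classification heads
sentences = output["words"]      # words with predicted #NULL tokens inserted
ud_arcs = output.get("deps_ud")  # present only when the "syntax" head is configured
```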
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size 
+oid sha256:2d8dec73b5638e57d0bca9bd4ee05cd11ce5aba98bc59f80b4e16231e6e7403f
+size 5905
utils.py
CHANGED
@@ -21,7 +21,7 @@ def build_padding_mask(sentences: list[list[str]], device) -> Tensor:
     return _build_condition_mask(sentences, condition_fn=lambda word: True, device=device)
 
 def build_null_mask(sentences: list[list[str]], device) -> Tensor:
-    return _build_condition_mask(sentences, condition_fn=lambda word: word =
+    return _build_condition_mask(sentences, condition_fn=lambda word: word != "#NULL", device=device)
 
 
 def pairwise_mask(masks1d: Tensor) -> Tensor:
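Note the semantics of the fixed line: `build_null_mask` now marks regular words True and `#NULL` tokens False, which is exactly how the syntax head consumes it (`null_mask & padding_mask` keeps real, non-padding words). `_build_condition_mask` itself is not part of this diff; a plausible sketch consistent with its call sites:

```python
import torch
from torch import Tensor

# Assumed implementation of the private helper: one boolean per word,
# padded with False up to the longest sentence in the batch.
def _build_condition_mask(sentences: list[list[str]], condition_fn, device) -> Tensor:
    max_len = max(len(sentence) for sentence in sentences)
    mask = torch.zeros(len(sentences), max_len, dtype=torch.bool, device=device)
    for i, sentence in enumerate(sentences):
        for j, word in enumerate(sentence):
            mask[i, j] = condition_fn(word)
    return mask

# build_null_mask([["The", "#NULL", "cat"]], device="cpu")
# -> tensor([[ True, False,  True]])
```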