haukurpj committed on
Commit
2d923bf
·
verified ·
1 Parent(s): f17941f

Upload folder using huggingface_hub

Browse files
__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

# Package entry point: re-export the public configuration and model classes
# so users can import them directly from the package root.
from .configuration import IceBertPosConfig
from .modeling import IceBertPosForTokenClassification, MultiLabelTokenClassificationHead

# Package version string.
__version__ = "0.1.0"

# Explicit public API of the package.
__all__ = [
    "IceBertPosConfig",
    "IceBertPosForTokenClassification",
    "MultiLabelTokenClassificationHead",
]
__pycache__/__init__.cpython-38.pyc ADDED
Binary file (157 Bytes). View file
 
__pycache__/configuration.cpython-38.pyc ADDED
Binary file (3.52 kB). View file
 
__pycache__/modeling.cpython-38.pyc ADDED
Binary file (15.2 kB). View file
 
config.json ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "IceBertPosForTokenClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "attr_proj_input_size": 811,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration.IceBertPosConfig",
9
+ "AutoModel": "modeling.IceBertPosForTokenClassification"
10
+ },
11
+ "bos_token_id": 0,
12
+ "classifier_dropout": 0.0,
13
+ "eos_token_id": 2,
14
+ "hidden_act": "gelu",
15
+ "hidden_dropout_prob": 0.1,
16
+ "hidden_size": 768,
17
+ "id2label": {
18
+ "0": "LABEL_0",
19
+ "1": "LABEL_1",
20
+ "2": "LABEL_2",
21
+ "3": "LABEL_3",
22
+ "4": "LABEL_4",
23
+ "5": "LABEL_5",
24
+ "6": "LABEL_6",
25
+ "7": "LABEL_7",
26
+ "8": "LABEL_8",
27
+ "9": "LABEL_9",
28
+ "10": "LABEL_10",
29
+ "11": "LABEL_11",
30
+ "12": "LABEL_12",
31
+ "13": "LABEL_13",
32
+ "14": "LABEL_14",
33
+ "15": "LABEL_15",
34
+ "16": "LABEL_16",
35
+ "17": "LABEL_17",
36
+ "18": "LABEL_18",
37
+ "19": "LABEL_19",
38
+ "20": "LABEL_20",
39
+ "21": "LABEL_21",
40
+ "22": "LABEL_22",
41
+ "23": "LABEL_23",
42
+ "24": "LABEL_24",
43
+ "25": "LABEL_25",
44
+ "26": "LABEL_26",
45
+ "27": "LABEL_27",
46
+ "28": "LABEL_28",
47
+ "29": "LABEL_29",
48
+ "30": "LABEL_30",
49
+ "31": "LABEL_31",
50
+ "32": "LABEL_32",
51
+ "33": "LABEL_33",
52
+ "34": "LABEL_34",
53
+ "35": "LABEL_35",
54
+ "36": "LABEL_36",
55
+ "37": "LABEL_37",
56
+ "38": "LABEL_38",
57
+ "39": "LABEL_39",
58
+ "40": "LABEL_40",
59
+ "41": "LABEL_41",
60
+ "42": "LABEL_42",
61
+ "43": "LABEL_43",
62
+ "44": "LABEL_44",
63
+ "45": "LABEL_45",
64
+ "46": "LABEL_46",
65
+ "47": "LABEL_47",
66
+ "48": "LABEL_48",
67
+ "49": "LABEL_49",
68
+ "50": "LABEL_50",
69
+ "51": "LABEL_51",
70
+ "52": "LABEL_52",
71
+ "53": "LABEL_53",
72
+ "54": "LABEL_54",
73
+ "55": "LABEL_55",
74
+ "56": "LABEL_56",
75
+ "57": "LABEL_57",
76
+ "58": "LABEL_58",
77
+ "59": "LABEL_59",
78
+ "60": "LABEL_60",
79
+ "61": "LABEL_61",
80
+ "62": "LABEL_62",
81
+ "63": "LABEL_63",
82
+ "64": "LABEL_64",
83
+ "65": "LABEL_65",
84
+ "66": "LABEL_66",
85
+ "67": "LABEL_67",
86
+ "68": "LABEL_68",
87
+ "69": "LABEL_69"
88
+ },
89
+ "initializer_range": 0.02,
90
+ "intermediate_size": 3072,
91
+ "label2id": {
92
+ "LABEL_0": 0,
93
+ "LABEL_1": 1,
94
+ "LABEL_10": 10,
95
+ "LABEL_11": 11,
96
+ "LABEL_12": 12,
97
+ "LABEL_13": 13,
98
+ "LABEL_14": 14,
99
+ "LABEL_15": 15,
100
+ "LABEL_16": 16,
101
+ "LABEL_17": 17,
102
+ "LABEL_18": 18,
103
+ "LABEL_19": 19,
104
+ "LABEL_2": 2,
105
+ "LABEL_20": 20,
106
+ "LABEL_21": 21,
107
+ "LABEL_22": 22,
108
+ "LABEL_23": 23,
109
+ "LABEL_24": 24,
110
+ "LABEL_25": 25,
111
+ "LABEL_26": 26,
112
+ "LABEL_27": 27,
113
+ "LABEL_28": 28,
114
+ "LABEL_29": 29,
115
+ "LABEL_3": 3,
116
+ "LABEL_30": 30,
117
+ "LABEL_31": 31,
118
+ "LABEL_32": 32,
119
+ "LABEL_33": 33,
120
+ "LABEL_34": 34,
121
+ "LABEL_35": 35,
122
+ "LABEL_36": 36,
123
+ "LABEL_37": 37,
124
+ "LABEL_38": 38,
125
+ "LABEL_39": 39,
126
+ "LABEL_4": 4,
127
+ "LABEL_40": 40,
128
+ "LABEL_41": 41,
129
+ "LABEL_42": 42,
130
+ "LABEL_43": 43,
131
+ "LABEL_44": 44,
132
+ "LABEL_45": 45,
133
+ "LABEL_46": 46,
134
+ "LABEL_47": 47,
135
+ "LABEL_48": 48,
136
+ "LABEL_49": 49,
137
+ "LABEL_5": 5,
138
+ "LABEL_50": 50,
139
+ "LABEL_51": 51,
140
+ "LABEL_52": 52,
141
+ "LABEL_53": 53,
142
+ "LABEL_54": 54,
143
+ "LABEL_55": 55,
144
+ "LABEL_56": 56,
145
+ "LABEL_57": 57,
146
+ "LABEL_58": 58,
147
+ "LABEL_59": 59,
148
+ "LABEL_6": 6,
149
+ "LABEL_60": 60,
150
+ "LABEL_61": 61,
151
+ "LABEL_62": 62,
152
+ "LABEL_63": 63,
153
+ "LABEL_64": 64,
154
+ "LABEL_65": 65,
155
+ "LABEL_66": 66,
156
+ "LABEL_67": 67,
157
+ "LABEL_68": 68,
158
+ "LABEL_69": 69,
159
+ "LABEL_7": 7,
160
+ "LABEL_8": 8,
161
+ "LABEL_9": 9
162
+ },
163
+ "label_schema": {
164
+ "category_to_group_names": {
165
+ "aa": [
166
+ "deg"
167
+ ],
168
+ "ae": [
169
+ "deg"
170
+ ],
171
+ "af": [
172
+ "deg"
173
+ ],
174
+ "ao": [
175
+ "deg"
176
+ ],
177
+ "as": [
178
+ "deg"
179
+ ],
180
+ "au": [
181
+ "deg"
182
+ ],
183
+ "a\u00fe": [
184
+ "deg"
185
+ ],
186
+ "fa": [
187
+ "gender",
188
+ "number",
189
+ "case"
190
+ ],
191
+ "fb": [
192
+ "gender",
193
+ "number",
194
+ "case"
195
+ ],
196
+ "fe": [
197
+ "gender",
198
+ "number",
199
+ "case"
200
+ ],
201
+ "fo": [
202
+ "gender_or_person",
203
+ "number",
204
+ "case"
205
+ ],
206
+ "fp": [
207
+ "gender_or_person",
208
+ "number",
209
+ "case"
210
+ ],
211
+ "fs": [
212
+ "gender",
213
+ "number",
214
+ "case"
215
+ ],
216
+ "ft": [
217
+ "gender",
218
+ "number",
219
+ "case"
220
+ ],
221
+ "g": [
222
+ "gender",
223
+ "number",
224
+ "case"
225
+ ],
226
+ "l": [
227
+ "gender",
228
+ "number",
229
+ "case",
230
+ "adj_c",
231
+ "deg"
232
+ ],
233
+ "n": [
234
+ "gender",
235
+ "number",
236
+ "case",
237
+ "def",
238
+ "proper"
239
+ ],
240
+ "sb": [
241
+ "voice",
242
+ "person",
243
+ "number",
244
+ "tense"
245
+ ],
246
+ "sf": [
247
+ "voice",
248
+ "person",
249
+ "number",
250
+ "tense"
251
+ ],
252
+ "sl": [
253
+ "voice",
254
+ "person",
255
+ "number",
256
+ "tense"
257
+ ],
258
+ "sn": [
259
+ "voice"
260
+ ],
261
+ "ss": [
262
+ "voice"
263
+ ],
264
+ "sv": [
265
+ "voice",
266
+ "person",
267
+ "number",
268
+ "tense"
269
+ ],
270
+ "s\u00fe": [
271
+ "voice",
272
+ "gender",
273
+ "number",
274
+ "case"
275
+ ],
276
+ "tf": [
277
+ "gender",
278
+ "number",
279
+ "case"
280
+ ]
281
+ },
282
+ "group_name_to_labels": {
283
+ "adj_c": [
284
+ "strong",
285
+ "weak",
286
+ "equiinflected"
287
+ ],
288
+ "case": [
289
+ "nom",
290
+ "acc",
291
+ "dat",
292
+ "gen"
293
+ ],
294
+ "def": [
295
+ "definite"
296
+ ],
297
+ "deg": [
298
+ "pos",
299
+ "cmp",
300
+ "superl"
301
+ ],
302
+ "gender": [
303
+ "masc",
304
+ "fem",
305
+ "neut",
306
+ "gender_x"
307
+ ],
308
+ "gender_or_person": [
309
+ "masc",
310
+ "fem",
311
+ "neut",
312
+ "gender_x",
313
+ "1",
314
+ "2",
315
+ "3"
316
+ ],
317
+ "number": [
318
+ "sing",
319
+ "plur"
320
+ ],
321
+ "person": [
322
+ "1",
323
+ "2",
324
+ "3"
325
+ ],
326
+ "proper": [
327
+ "proper"
328
+ ],
329
+ "tense": [
330
+ "pres",
331
+ "past"
332
+ ],
333
+ "voice": [
334
+ "act",
335
+ "mid"
336
+ ]
337
+ },
338
+ "group_names": [
339
+ "gender",
340
+ "gender_or_person",
341
+ "number",
342
+ "case",
343
+ "def",
344
+ "proper",
345
+ "adj_c",
346
+ "deg",
347
+ "voice",
348
+ "person",
349
+ "tense"
350
+ ],
351
+ "ignore_categories": [
352
+ "x",
353
+ "e"
354
+ ],
355
+ "label_categories": [
356
+ "n",
357
+ "g",
358
+ "x",
359
+ "e",
360
+ "v",
361
+ "l",
362
+ "fa",
363
+ "fb",
364
+ "fe",
365
+ "fo",
366
+ "fp",
367
+ "fs",
368
+ "ft",
369
+ "tf",
370
+ "ta",
371
+ "tp",
372
+ "to",
373
+ "sn",
374
+ "sb",
375
+ "sf",
376
+ "sv",
377
+ "ss",
378
+ "sl",
379
+ "s\u00fe",
380
+ "cn",
381
+ "ct",
382
+ "c",
383
+ "aa",
384
+ "af",
385
+ "au",
386
+ "ao",
387
+ "a\u00fe",
388
+ "ae",
389
+ "as",
390
+ "ks",
391
+ "kt",
392
+ "p",
393
+ "pl",
394
+ "pk",
395
+ "pg",
396
+ "pa",
397
+ "ns",
398
+ "m"
399
+ ],
400
+ "labels": [
401
+ "<SEP>",
402
+ "n",
403
+ "g",
404
+ "x",
405
+ "e",
406
+ "v",
407
+ "l",
408
+ "fa",
409
+ "fb",
410
+ "fe",
411
+ "fo",
412
+ "fp",
413
+ "fs",
414
+ "ft",
415
+ "tf",
416
+ "ta",
417
+ "tp",
418
+ "to",
419
+ "sn",
420
+ "sb",
421
+ "sf",
422
+ "sv",
423
+ "ss",
424
+ "sl",
425
+ "s\u00fe",
426
+ "cn",
427
+ "ct",
428
+ "c",
429
+ "aa",
430
+ "af",
431
+ "au",
432
+ "ao",
433
+ "a\u00fe",
434
+ "ae",
435
+ "as",
436
+ "ks",
437
+ "kt",
438
+ "p",
439
+ "pl",
440
+ "pk",
441
+ "pg",
442
+ "pa",
443
+ "ns",
444
+ "m",
445
+ "masc",
446
+ "fem",
447
+ "neut",
448
+ "gender_x",
449
+ "1",
450
+ "2",
451
+ "3",
452
+ "sing",
453
+ "plur",
454
+ "nom",
455
+ "acc",
456
+ "dat",
457
+ "gen",
458
+ "definite",
459
+ "proper",
460
+ "strong",
461
+ "weak",
462
+ "equiinflected",
463
+ "pos",
464
+ "cmp",
465
+ "superl",
466
+ "past",
467
+ "pres",
468
+ "pass",
469
+ "act",
470
+ "mid"
471
+ ],
472
+ "null": null,
473
+ "null_leaf": null,
474
+ "separator": "<SEP>"
475
+ },
476
+ "layer_norm_eps": 1e-05,
477
+ "max_position_embeddings": 514,
478
+ "model_type": "icebert-pos",
479
+ "num_attention_heads": 12,
480
+ "num_categories": 43,
481
+ "num_groups": 12,
482
+ "num_hidden_layers": 12,
483
+ "pad_token_id": 1,
484
+ "position_embedding_type": "absolute",
485
+ "torch_dtype": "float32",
486
+ "transformers_version": "4.46.3",
487
+ "type_vocab_size": 1,
488
+ "use_cache": true,
489
+ "vocab_size": 49937
490
+ }
configuration.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) Miðeind ehf.
2
+ # This file is part of IceBERT POS model conversion.
3
+
4
+ import json
5
+ from typing import Any, Dict, Optional
6
+
7
+ from transformers import AutoConfig, RobertaConfig
8
+
9
+
10
class IceBertPosConfig(RobertaConfig):
    """
    Configuration class for IceBERT POS (Part-of-Speech) tagging model.

    This configuration inherits from RobertaConfig and adds POS-specific parameters
    derived from the label schema used for multilabel token classification.

    Args:
        label_schema: Dictionary describing POS categories, attribute groups and
            labels (the structure of terms2.json). When None, a built-in default
            schema is used.
        classifier_dropout: Dropout probability for the classification head.
            Defaults to 0.1 when not given (the released checkpoint's
            config.json sets it to 0.0).
        **kwargs: Forwarded unchanged to RobertaConfig.
    """

    model_type = "icebert-pos"

    def __init__(
        self, label_schema: Optional[Dict[str, Any]] = None, classifier_dropout: Optional[float] = None, **kwargs
    ):
        super().__init__(**kwargs)

        # Default label schema (terms2.json content)
        if label_schema is None:
            label_schema = self._get_default_label_schema()

        self.label_schema = label_schema

        # Derive parameters from label schema.
        self.num_categories = len(label_schema["label_categories"])
        # NOTE(review): num_labels is a property on the transformers base config;
        # assigning it also regenerates id2label/label2id to this size — confirm
        # this is the intended interaction with the saved id2label mapping.
        self.num_labels = len(label_schema["labels"])
        self.num_groups = len(label_schema["group_names"])

        # Classification head parameters
        self.classifier_dropout = classifier_dropout if classifier_dropout is not None else 0.1

        # Computed input size for attribute projection
        # (category_probs + hidden_size) -> num_labels
        self.attr_proj_input_size = self.num_categories + self.hidden_size

    @staticmethod
    def _get_default_label_schema() -> Dict[str, Any]:
        """Default label schema corresponding to terms2.json"""
        return {
            # All POS category tags the model can assign to a word.
            "label_categories": [
                "n",
                "g",
                "x",
                "e",
                "v",
                "l",
                "fa",
                "fb",
                "fe",
                "fo",
                "fp",
                "fs",
                "ft",
                "tf",
                "ta",
                "tp",
                "to",
                "sn",
                "sb",
                "sf",
                "sv",
                "ss",
                "sl",
                "sþ",
                "cn",
                "ct",
                "c",
                "aa",
                "af",
                "au",
                "ao",
                "aþ",
                "ae",
                "as",
                "ks",
                "kt",
                "p",
                "pl",
                "pk",
                "pg",
                "pa",
                "ns",
                "m",
            ],
            # Which attribute groups apply to each category; categories absent
            # from this mapping carry no attributes.
            "category_to_group_names": {
                "n": ["gender", "number", "case", "def", "proper"],
                "g": ["gender", "number", "case"],
                "l": ["gender", "number", "case", "adj_c", "deg"],
                "fa": ["gender", "number", "case"],
                "fb": ["gender", "number", "case"],
                "fe": ["gender", "number", "case"],
                "fs": ["gender", "number", "case"],
                "ft": ["gender", "number", "case"],
                "fo": ["gender_or_person", "number", "case"],
                "fp": ["gender_or_person", "number", "case"],
                "tf": ["gender", "number", "case"],
                "sn": ["voice"],
                "sb": ["voice", "person", "number", "tense"],
                "sf": ["voice", "person", "number", "tense"],
                "sv": ["voice", "person", "number", "tense"],
                "ss": ["voice"],
                "sl": ["voice", "person", "number", "tense"],
                "sþ": ["voice", "gender", "number", "case"],
                "aa": ["deg"],
                "af": ["deg"],
                "au": ["deg"],
                "ao": ["deg"],
                "aþ": ["deg"],
                "ae": ["deg"],
                "as": ["deg"],
            },
            # Canonical ordering of the attribute groups.
            "group_names": [
                "gender",
                "gender_or_person",
                "number",
                "case",
                "def",
                "proper",
                "adj_c",
                "deg",
                "voice",
                "person",
                "tense",
            ],
            # Labels belonging to each attribute group.
            "group_name_to_labels": {
                "gender": ["masc", "fem", "neut", "gender_x"],
                "number": ["sing", "plur"],
                "person": ["1", "2", "3"],
                "gender_or_person": ["masc", "fem", "neut", "gender_x", "1", "2", "3"],
                "case": ["nom", "acc", "dat", "gen"],
                "deg": ["pos", "cmp", "superl"],
                "voice": ["act", "mid"],
                "tense": ["pres", "past"],
                "def": ["definite"],
                "proper": ["proper"],
                "adj_c": ["strong", "weak", "equiinflected"],
            },
            # Flat label dictionary: separator, then categories, then attributes.
            "labels": [
                "<SEP>",
                "n",
                "g",
                "x",
                "e",
                "v",
                "l",
                "fa",
                "fb",
                "fe",
                "fo",
                "fp",
                "fs",
                "ft",
                "tf",
                "ta",
                "tp",
                "to",
                "sn",
                "sb",
                "sf",
                "sv",
                "ss",
                "sl",
                "sþ",
                "cn",
                "ct",
                "c",
                "aa",
                "af",
                "au",
                "ao",
                "aþ",
                "ae",
                "as",
                "ks",
                "kt",
                "p",
                "pl",
                "pk",
                "pg",
                "pa",
                "ns",
                "m",
                "masc",
                "fem",
                "neut",
                "gender_x",
                "1",
                "2",
                "3",
                "sing",
                "plur",
                "nom",
                "acc",
                "dat",
                "gen",
                "definite",
                "proper",
                "strong",
                "weak",
                "equiinflected",
                "pos",
                "cmp",
                "superl",
                "past",
                "pres",
                "pass",
                "act",
                "mid",
            ],
            "null": None,
            "null_leaf": None,
            "separator": "<SEP>",
            # Categories that carry no trainable attributes and are skipped.
            "ignore_categories": ["x", "e"],
        }

    @classmethod
    def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
        """Create config from a label schema JSON file"""
        with open(schema_path, "r", encoding="utf-8") as f:
            label_schema = json.load(f)
        return cls(label_schema=label_schema, **kwargs)
229
+
230
+
231
+ AutoConfig.register("icebert-pos", IceBertPosConfig)
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc24ca46b3b1024c92be719a8964c1336185c3d188674f2f4b96c1064fdaab7f
3
+ size 497965196
modeling.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) Miðeind ehf.
2
+ # This file is part of IceBERT POS model conversion.
3
+
4
+ import logging
5
+ from typing import List, Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
12
+
13
+ from .configuration import IceBertPosConfig
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class MultiLabelTokenClassificationHead(nn.Module):
    """Head for multilabel word-level classification tasks.

    Predicts a POS-category distribution per word and, conditioned on it, an
    attribute logit vector: the softmaxed category probabilities are
    concatenated with the transformed word features before the attribute
    projection.
    """

    def __init__(self, config: "IceBertPosConfig"):
        """
        Args:
            config: Configuration providing ``num_categories``, ``num_labels``,
                ``hidden_size`` and ``classifier_dropout``. Annotated as a
                string forward reference so this class does not require the
                sibling configuration module at annotation-evaluation time.
        """
        super().__init__()
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size

        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation_fn = F.relu
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Category projection: hidden_size -> num_categories
        self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)

        # Attribute projection: (hidden_size + num_categories) -> num_labels
        self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            features: Word-level features of shape (total_words, hidden_size)

        Returns:
            cat_logits: Category logits of shape (total_words, num_categories)
            attr_logits: Attribute logits of shape (total_words, num_labels)
        """
        x = self.dropout(features)
        x = self.dense(x)
        # Layer norm applied before the non-linearity (order kept from the
        # original converted fairseq head).
        x = self.layer_norm(x)
        x = self.activation_fn(x)

        # Predict categories
        cat_logits = self.cat_proj(x)
        cat_probs = torch.softmax(cat_logits, dim=-1)

        # Predict attributes from the category distribution + features
        attr_input = torch.cat((cat_probs, x), dim=-1)
        attr_logits = self.out_proj(attr_input)

        return cat_logits, attr_logits
61
+
62
+
63
+ class IceBertPosForTokenClassification(PreTrainedModel):
64
+ """
65
+ IceBERT model for multilabel token classification (POS tagging).
66
+
67
+ This model performs word-level POS tagging by:
68
+ 1. Encoding input with RoBERTa
69
+ 2. Aggregating subword tokens to word-level representations
70
+ 3. Predicting both categories and attributes for each word
71
+ """
72
+
73
+ config_class = IceBertPosConfig
74
+
75
+ def __init__(self, config: IceBertPosConfig):
76
+ super().__init__(config)
77
+ self.config = config
78
+ self.num_categories = config.num_categories
79
+ self.num_labels = config.num_labels
80
+
81
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
82
+ self.classifier = MultiLabelTokenClassificationHead(config)
83
+
84
+ # Initialize weights and apply final processing
85
+ self.post_init()
86
+
87
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode the input and produce word-level POS logits.

        Args:
            input_ids: Token indices of shape (batch_size, sequence_length)
            attention_mask: Attention mask of shape (batch_size, sequence_length)
            word_mask: Binary mask indicating word boundaries (1 = word start)

        The remaining keyword arguments are passed through unchanged to the
        underlying RoBERTa encoder.

        Returns:
            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)

        Note:
            This method always returns a plain tuple of the two word-level
            logit tensors; ``return_dict`` only affects the inner encoder call.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get RoBERTa outputs
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)

        # Collapse subword tokens into one averaged vector per word, using
        # word_mask to find word-initial positions.
        word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask)

        # Joint category + attribute prediction on the flat word features.
        cat_logits, attr_logits = self.classifier(word_features)

        # Re-pack flat per-word logits into padded (batch, max_words, ...) tensors.
        cat_logits_batch, attr_logits_batch = self._reshape_to_batch_format(cat_logits, attr_logits, nwords)

        return cat_logits_batch, attr_logits_batch
137
+
138
+ def _aggregate_subword_tokens(
139
+ self, sequence_output: torch.Tensor, word_mask: torch.Tensor
140
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
141
+ """
142
+ Aggregate subword token representations to word-level representations.
143
+ Following the original fairseq approach by averaging subword tokens within each word.
144
+
145
+ Args:
146
+ sequence_output: subword token representations (batch_size, seq_len, hidden_size)
147
+ word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)
148
+
149
+ Returns:
150
+ word_features: Word-level features (total_words, hidden_size)
151
+ nwords: Number of words per sequence (batch_size,)
152
+ """
153
+ # TODO: Verify that BOS and EOS are handled correctly - I'm worried that this does not correctly handle padding
154
+ # Remove BOS and EOS tokens (first and last positions)
155
+ x = sequence_output[:, 1:-1, :] # (batch_size, seq_len-2, hidden_size)
156
+ starts = word_mask[:, 1:-1] # (batch_size, seq_len-2)
157
+
158
+ # Count words per sequence
159
+ nwords = starts.sum(dim=-1) # (batch_size,)
160
+
161
+ # Find word boundaries and average tokens within each word
162
+ mean_words = []
163
+ batch_size, seq_len, hidden_size = x.shape
164
+
165
+ for batch_idx in range(batch_size):
166
+ seq_starts = starts[batch_idx] # (seq_len-2,)
167
+ seq_x = x[batch_idx] # (seq_len-2, hidden_size)
168
+
169
+ # Find start positions of words
170
+ start_positions = seq_starts.nonzero(as_tuple=True)[0] # positions where words start
171
+
172
+ if len(start_positions) == 0:
173
+ continue
174
+
175
+ # Calculate end positions (start of next word or end of sequence)
176
+ end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)])
177
+
178
+ # Average tokens within each word
179
+ for start_pos, end_pos in zip(start_positions, end_positions):
180
+ word_tokens = seq_x[start_pos:end_pos] # tokens in this word
181
+ word_repr = word_tokens.mean(dim=0) # average representation
182
+ mean_words.append(word_repr)
183
+
184
+ if len(mean_words) == 0:
185
+ return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords
186
+
187
+ return torch.stack(mean_words), nwords
188
+
189
+ def _reshape_to_batch_format(
190
+ self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
191
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
192
+ """
193
+ Reshape word-level predictions back to batch format.
194
+ Following the original fairseq approach with pad_sequence.
195
+
196
+ Args:
197
+ cat_logits: Category logits (total_words, num_categories)
198
+ attr_logits: Attribute logits (total_words, num_labels)
199
+ nwords: Number of words per sequence (batch_size,)
200
+
201
+ Returns:
202
+ cat_logits_batch: (batch_size, max_words, num_categories)
203
+ attr_logits_batch: (batch_size, max_words, num_labels)
204
+ """
205
+
206
+ # Split logits by sequence using word counts
207
+ words_per_seq = nwords.tolist()
208
+ cat_logits_split = cat_logits.split(words_per_seq)
209
+ attr_logits_split = attr_logits.split(words_per_seq)
210
+
211
+ # Pad to same length (matching original fairseq approach)
212
+ cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
213
+ attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)
214
+
215
+ return cat_logits_batch, attr_logits_batch
216
+
217
    @torch.no_grad()
    def predict_labels(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[int]]
    ) -> List[List[Tuple[str, List[str]]]]:
        """
        Predict POS labels for input sequences (inference only, gradients disabled).

        Args:
            input_ids: Token indices
            attention_mask: Attention mask
            word_ids: Word boundaries (HF-tokenizer-style per-token word ids,
                ``None`` for special tokens)

        Returns:
            List of sequences, each containing (category, [attributes]) per word
        """
        # Convert word_ids to word_mask (1 at each position starting a new word)
        word_mask = self._word_ids_to_word_mask(word_ids, input_ids.shape)

        cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

        # NOTE(review): _logits_to_labels is defined elsewhere in this module
        # (not visible in this chunk); it decodes the logit tensors into
        # (category, attribute-list) string pairs.
        return self._logits_to_labels(cat_logits, attr_logits, word_ids)
238
+
239
+ def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
240
+ """
241
+ Convert word_ids to word_mask (binary mask indicating word boundaries).
242
+
243
+ Args:
244
+ word_ids: List of word id sequences
245
+ input_shape: Shape of input_ids tensor (batch_size, seq_len)
246
+
247
+ Returns:
248
+ word_mask: Binary tensor where 1 indicates start of word
249
+ """
250
+ batch_size, seq_len = input_shape
251
+ word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
252
+
253
+ for batch_idx, seq_word_ids in enumerate(word_ids):
254
+ prev_word_id = None
255
+ for token_idx, word_id in enumerate(seq_word_ids):
256
+ if word_id != prev_word_id:
257
+ word_mask[batch_idx, token_idx] = 1
258
+ prev_word_id = word_id
259
+
260
+ return word_mask
261
+
262
+ def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
263
+ """
264
+ Predict POS labels from raw text using fairseq-style preprocessing.
265
+
266
+ Args:
267
+ sentences: List of input sentences
268
+ tokenizer: HuggingFace tokenizer
269
+
270
+ Returns:
271
+ List of sequences, each containing (category, [attributes]) per word
272
+ """
273
+ # Tokenize with fairseq-style preprocessing
274
+ encodings = [tokenizer(sent, return_tensors="pt") for sent in sentences]
275
+ word_ids_list = [encoding.word_ids() for encoding in encodings]
276
+
277
+ # Batch the inputs
278
+ max_len = max(encoding["input_ids"].shape[1] for encoding in encodings)
279
+ batch_input_ids = []
280
+ batch_attention_mask = []
281
+
282
+ for encoding in encodings:
283
+ input_ids = encoding["input_ids"][0]
284
+ attention_mask = encoding["attention_mask"][0]
285
+
286
+ # Pad to max length
287
+ pad_len = max_len - len(input_ids)
288
+ if pad_len > 0:
289
+ input_ids = torch.cat([input_ids, torch.ones(pad_len, dtype=torch.long)]) # pad_token_id = 1
290
+ attention_mask = torch.cat([attention_mask, torch.zeros(pad_len, dtype=torch.long)])
291
+
292
+ batch_input_ids.append(input_ids)
293
+ batch_attention_mask.append(attention_mask)
294
+
295
+ batch_input_ids = torch.stack(batch_input_ids)
296
+ batch_attention_mask = torch.stack(batch_attention_mask)
297
+
298
+ return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)
299
+
300
+ def _make_group_name_to_group_attr_vec_idxs(self):
301
+ """Create mapping from group names to their attribute vector indices"""
302
+ group_name_to_group_attr_vec_idxs = {}
303
+ labels = self.config.label_schema["labels"]
304
+ nspecial = 0 # Number of special tokens in label dictionary (like <SEP>)
305
+
306
+ for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items():
307
+ vec_idxs = []
308
+ for label in group_labels:
309
+ if label in labels:
310
+ # Find index in labels list, but subtract nspecial to get vector index
311
+ label_dict_idx = labels.index(label)
312
+ if label_dict_idx >= nspecial: # Skip special tokens
313
+ vec_idxs.append(label_dict_idx - nspecial)
314
+ group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs)
315
+
316
+ return group_name_to_group_attr_vec_idxs
317
+
318
+ def _make_group_masks(self):
319
+ """Create group masks for each category"""
320
+ label_categories = self.config.label_schema["label_categories"]
321
+ group_names = self.config.label_schema["group_names"]
322
+ category_to_group_names = self.config.label_schema["category_to_group_names"]
323
+
324
+ num_cats = len(label_categories)
325
+ num_groups = len(group_names)
326
+
327
+ group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool)
328
+
329
+ for cat_idx, category in enumerate(label_categories):
330
+ if category in category_to_group_names:
331
+ for group_name in category_to_group_names[category]:
332
+ if group_name in group_names:
333
+ group_idx = group_names.index(group_name)
334
+ group_mask[cat_idx, group_idx] = True
335
+
336
+ return group_mask
337
+
338
+ def _make_category_mappings(self):
339
+ """Create mappings between category vector indices and dictionary indices"""
340
+ labels = self.config.label_schema["labels"]
341
+ label_categories = self.config.label_schema["label_categories"]
342
+
343
+ # Create mapping from category names to vector indices (0-based)
344
+ cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long)
345
+ cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long)
346
+
347
+ for vec_idx, category in enumerate(label_categories):
348
+ if category in labels:
349
+ dict_idx = labels.index(category)
350
+ cat_dict_idx_to_vec_idx[dict_idx] = vec_idx
351
+ cat_vec_idx_to_dict_idx[vec_idx] = dict_idx
352
+
353
+ return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx
354
+
355
+ def _count_words_per_sequence(self, word_ids: List[List[int]]) -> List[int]:
356
+ """Count the number of unique words in each sequence."""
357
+ words_per_seq = []
358
+ for seq_word_ids in word_ids:
359
+ unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None)
360
+ words_per_seq.append(len(unique_word_ids))
361
+ return words_per_seq
362
+
363
+ def _predict_categories_for_sequence(
364
+ self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor
365
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
366
+ """Predict categories for a single sequence and return both vector and dictionary indices."""
367
+ pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
368
+ pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
369
+ return pred_cat_vec_idxs, pred_cats
370
+
371
+ def _predict_attributes_for_group(
372
+ self,
373
+ attr_logits: torch.Tensor,
374
+ seq_idx: int,
375
+ seq_nwords: int,
376
+ group_vec_idxs: torch.Tensor,
377
+ seq_group_mask: torch.Tensor,
378
+ group_idx: int,
379
+ ) -> torch.Tensor:
380
+ """Predict attributes for a single group."""
381
+ if len(group_vec_idxs) == 0:
382
+ return torch.zeros(seq_nwords, dtype=torch.long)
383
+
384
+ # Get logits for this group
385
+ group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
386
+
387
+ if len(group_vec_idxs) == 1:
388
+ # Single element group: use sigmoid > 0.5
389
+ group_pred = group_logits.sigmoid().ge(0.5).long()
390
+ group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx]
391
+ else:
392
+ # Multi element group: use argmax
393
+ group_pred_vec_idxs = group_logits.max(dim=-1).indices
394
+ group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx]
395
+
396
+ return group_pred_dict_idxs
397
+
398
def _predict_all_attributes_for_sequence(
    self,
    attr_logits: torch.Tensor,
    seq_idx: int,
    seq_nwords: int,
    pred_cat_vec_idxs: torch.Tensor,
    group_name_to_group_attr_vec_idxs: dict,
    group_mask: torch.Tensor,
    group_names: List[str],
) -> torch.Tensor:
    """Predict the attributes of every group for one sequence.

    Returns a ``(seq_nwords, num_groups)`` long tensor of label-dictionary
    indices (0 means "no attribute from this group").
    """
    # Rows of the category->group mask for this sequence's predicted categories.
    seq_group_mask = group_mask[pred_cat_vec_idxs]

    per_group = []
    for group_idx, group_name in enumerate(group_names):
        if group_name not in group_name_to_group_attr_vec_idxs:
            # Unknown group: contributes no attribute (dictionary index 0).
            per_group.append(torch.zeros(seq_nwords, dtype=torch.long))
        else:
            vec_idxs = group_name_to_group_attr_vec_idxs[group_name]
            per_group.append(
                self._predict_attributes_for_group(
                    attr_logits, seq_idx, seq_nwords, vec_idxs, seq_group_mask, group_idx
                )
            )

    if not per_group:
        return torch.zeros(seq_nwords, len(group_names), dtype=torch.long)

    # (num_groups, seq_nwords) -> (seq_nwords, num_groups)
    flattened = [p.squeeze() if p.dim() > 1 else p for p in per_group]
    return torch.stack(flattened).t()
429
+ def _convert_predictions_to_labels(
430
+ self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str]
431
+ ) -> List[Tuple[str, List[str]]]:
432
+ """Convert prediction tensors to human-readable labels."""
433
+ seq_nwords = pred_cats.size(0)
434
+ seq_predictions = []
435
+
436
+ for word_idx in range(seq_nwords):
437
+ # Category (convert from dictionary index to string)
438
+ cat_dict_idx = pred_cats[word_idx].item()
439
+ if cat_dict_idx < len(labels):
440
+ category = labels[cat_dict_idx]
441
+ else:
442
+ category = "UNK"
443
+
444
+ # Attributes (convert from dictionary indices to strings)
445
+ attributes = []
446
+ for group_idx in range(len(group_names)):
447
+ attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item()
448
+ if attr_dict_idx > 0 and attr_dict_idx < len(labels): # Skip 0 (empty) and out of bounds
449
+ attributes.append(labels[attr_dict_idx])
450
+
451
+ seq_predictions.append((category, attributes))
452
+
453
+ return seq_predictions
454
+
455
def _logits_to_labels(
    self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_ids: List[List[int]]
) -> List[List[Tuple[str, List[str]]]]:
    """Decode category and attribute logits into per-word
    ``(category, attributes)`` string pairs for every sequence in the batch,
    following fairseq's group-based decoding logic.
    """
    # Build the lookup structures the per-sequence decoders need.
    group_name_to_group_attr_vec_idxs = self._make_group_name_to_group_attr_vec_idxs()
    group_mask = self._make_group_masks()
    _, cat_vec_idx_to_dict_idx = self._make_category_mappings()

    schema = self.config.label_schema
    labels = schema["labels"]
    group_names = schema["group_names"]

    words_per_seq = self._count_words_per_sequence(word_ids)

    batch_predictions: List[List[Tuple[str, List[str]]]] = []
    for seq_idx in range(cat_logits.size(0)):
        seq_nwords = words_per_seq[seq_idx]

        # Category decisions first — they determine which attribute groups
        # are active for each word.
        pred_cat_vec_idxs, pred_cats = self._predict_categories_for_sequence(
            cat_logits, seq_idx, seq_nwords, cat_vec_idx_to_dict_idx
        )

        pred_attrs_tensor = self._predict_all_attributes_for_sequence(
            attr_logits,
            seq_idx,
            seq_nwords,
            pred_cat_vec_idxs,
            group_name_to_group_attr_vec_idxs,
            group_mask,
            group_names,
        )

        batch_predictions.append(
            self._convert_predictions_to_labels(pred_cats, pred_attrs_tensor, labels, group_names)
        )

    return batch_predictions
499
# Make the custom config/model discoverable through the transformers Auto*
# API and serializable with trust_remote_code (auto_map entries in config.json).
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": true,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "<pad>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "3": {
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "49936": {
37
+ "content": "<mask>",
38
+ "lstrip": true,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ }
44
+ },
45
+ "bos_token": "<s>",
46
+ "clean_up_tokenization_spaces": false,
47
+ "cls_token": "<s>",
48
+ "eos_token": "</s>",
49
+ "errors": "replace",
50
+ "mask_token": "<mask>",
51
+ "model_max_length": 1000000000000000019884624838656,
52
+ "pad_token": "<pad>",
53
+ "sep_token": "</s>",
54
+ "tokenizer_class": "RobertaTokenizer",
55
+ "trim_offsets": true,
56
+ "unk_token": "<unk>"
57
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff