Upload folder using huggingface_hub
Browse files- __init__.py +13 -0
- __pycache__/__init__.cpython-38.pyc +0 -0
- __pycache__/configuration.cpython-38.pyc +0 -0
- __pycache__/modeling.cpython-38.pyc +0 -0
- config.json +490 -0
- configuration.py +231 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- modeling.py +503 -0
- special_tokens_map.json +15 -0
- tokenizer.json +0 -0
- tokenizer_config.json +57 -0
- vocab.json +0 -0
__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

"""Package entry point: re-exports the IceBERT POS configuration and model classes."""

from .configuration import IceBertPosConfig
from .modeling import IceBertPosForTokenClassification, MultiLabelTokenClassificationHead

__version__ = "0.1.0"

# Explicit public API of the package.
__all__ = [
    "IceBertPosConfig",
    "IceBertPosForTokenClassification",
    "MultiLabelTokenClassificationHead",
]
|
__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (157 Bytes). View file
|
|
|
__pycache__/configuration.cpython-38.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
__pycache__/modeling.cpython-38.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
config.json
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"IceBertPosForTokenClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"attr_proj_input_size": 811,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "configuration.IceBertPosConfig",
|
| 9 |
+
"AutoModel": "modeling.IceBertPosForTokenClassification"
|
| 10 |
+
},
|
| 11 |
+
"bos_token_id": 0,
|
| 12 |
+
"classifier_dropout": 0.0,
|
| 13 |
+
"eos_token_id": 2,
|
| 14 |
+
"hidden_act": "gelu",
|
| 15 |
+
"hidden_dropout_prob": 0.1,
|
| 16 |
+
"hidden_size": 768,
|
| 17 |
+
"id2label": {
|
| 18 |
+
"0": "LABEL_0",
|
| 19 |
+
"1": "LABEL_1",
|
| 20 |
+
"2": "LABEL_2",
|
| 21 |
+
"3": "LABEL_3",
|
| 22 |
+
"4": "LABEL_4",
|
| 23 |
+
"5": "LABEL_5",
|
| 24 |
+
"6": "LABEL_6",
|
| 25 |
+
"7": "LABEL_7",
|
| 26 |
+
"8": "LABEL_8",
|
| 27 |
+
"9": "LABEL_9",
|
| 28 |
+
"10": "LABEL_10",
|
| 29 |
+
"11": "LABEL_11",
|
| 30 |
+
"12": "LABEL_12",
|
| 31 |
+
"13": "LABEL_13",
|
| 32 |
+
"14": "LABEL_14",
|
| 33 |
+
"15": "LABEL_15",
|
| 34 |
+
"16": "LABEL_16",
|
| 35 |
+
"17": "LABEL_17",
|
| 36 |
+
"18": "LABEL_18",
|
| 37 |
+
"19": "LABEL_19",
|
| 38 |
+
"20": "LABEL_20",
|
| 39 |
+
"21": "LABEL_21",
|
| 40 |
+
"22": "LABEL_22",
|
| 41 |
+
"23": "LABEL_23",
|
| 42 |
+
"24": "LABEL_24",
|
| 43 |
+
"25": "LABEL_25",
|
| 44 |
+
"26": "LABEL_26",
|
| 45 |
+
"27": "LABEL_27",
|
| 46 |
+
"28": "LABEL_28",
|
| 47 |
+
"29": "LABEL_29",
|
| 48 |
+
"30": "LABEL_30",
|
| 49 |
+
"31": "LABEL_31",
|
| 50 |
+
"32": "LABEL_32",
|
| 51 |
+
"33": "LABEL_33",
|
| 52 |
+
"34": "LABEL_34",
|
| 53 |
+
"35": "LABEL_35",
|
| 54 |
+
"36": "LABEL_36",
|
| 55 |
+
"37": "LABEL_37",
|
| 56 |
+
"38": "LABEL_38",
|
| 57 |
+
"39": "LABEL_39",
|
| 58 |
+
"40": "LABEL_40",
|
| 59 |
+
"41": "LABEL_41",
|
| 60 |
+
"42": "LABEL_42",
|
| 61 |
+
"43": "LABEL_43",
|
| 62 |
+
"44": "LABEL_44",
|
| 63 |
+
"45": "LABEL_45",
|
| 64 |
+
"46": "LABEL_46",
|
| 65 |
+
"47": "LABEL_47",
|
| 66 |
+
"48": "LABEL_48",
|
| 67 |
+
"49": "LABEL_49",
|
| 68 |
+
"50": "LABEL_50",
|
| 69 |
+
"51": "LABEL_51",
|
| 70 |
+
"52": "LABEL_52",
|
| 71 |
+
"53": "LABEL_53",
|
| 72 |
+
"54": "LABEL_54",
|
| 73 |
+
"55": "LABEL_55",
|
| 74 |
+
"56": "LABEL_56",
|
| 75 |
+
"57": "LABEL_57",
|
| 76 |
+
"58": "LABEL_58",
|
| 77 |
+
"59": "LABEL_59",
|
| 78 |
+
"60": "LABEL_60",
|
| 79 |
+
"61": "LABEL_61",
|
| 80 |
+
"62": "LABEL_62",
|
| 81 |
+
"63": "LABEL_63",
|
| 82 |
+
"64": "LABEL_64",
|
| 83 |
+
"65": "LABEL_65",
|
| 84 |
+
"66": "LABEL_66",
|
| 85 |
+
"67": "LABEL_67",
|
| 86 |
+
"68": "LABEL_68",
|
| 87 |
+
"69": "LABEL_69"
|
| 88 |
+
},
|
| 89 |
+
"initializer_range": 0.02,
|
| 90 |
+
"intermediate_size": 3072,
|
| 91 |
+
"label2id": {
|
| 92 |
+
"LABEL_0": 0,
|
| 93 |
+
"LABEL_1": 1,
|
| 94 |
+
"LABEL_10": 10,
|
| 95 |
+
"LABEL_11": 11,
|
| 96 |
+
"LABEL_12": 12,
|
| 97 |
+
"LABEL_13": 13,
|
| 98 |
+
"LABEL_14": 14,
|
| 99 |
+
"LABEL_15": 15,
|
| 100 |
+
"LABEL_16": 16,
|
| 101 |
+
"LABEL_17": 17,
|
| 102 |
+
"LABEL_18": 18,
|
| 103 |
+
"LABEL_19": 19,
|
| 104 |
+
"LABEL_2": 2,
|
| 105 |
+
"LABEL_20": 20,
|
| 106 |
+
"LABEL_21": 21,
|
| 107 |
+
"LABEL_22": 22,
|
| 108 |
+
"LABEL_23": 23,
|
| 109 |
+
"LABEL_24": 24,
|
| 110 |
+
"LABEL_25": 25,
|
| 111 |
+
"LABEL_26": 26,
|
| 112 |
+
"LABEL_27": 27,
|
| 113 |
+
"LABEL_28": 28,
|
| 114 |
+
"LABEL_29": 29,
|
| 115 |
+
"LABEL_3": 3,
|
| 116 |
+
"LABEL_30": 30,
|
| 117 |
+
"LABEL_31": 31,
|
| 118 |
+
"LABEL_32": 32,
|
| 119 |
+
"LABEL_33": 33,
|
| 120 |
+
"LABEL_34": 34,
|
| 121 |
+
"LABEL_35": 35,
|
| 122 |
+
"LABEL_36": 36,
|
| 123 |
+
"LABEL_37": 37,
|
| 124 |
+
"LABEL_38": 38,
|
| 125 |
+
"LABEL_39": 39,
|
| 126 |
+
"LABEL_4": 4,
|
| 127 |
+
"LABEL_40": 40,
|
| 128 |
+
"LABEL_41": 41,
|
| 129 |
+
"LABEL_42": 42,
|
| 130 |
+
"LABEL_43": 43,
|
| 131 |
+
"LABEL_44": 44,
|
| 132 |
+
"LABEL_45": 45,
|
| 133 |
+
"LABEL_46": 46,
|
| 134 |
+
"LABEL_47": 47,
|
| 135 |
+
"LABEL_48": 48,
|
| 136 |
+
"LABEL_49": 49,
|
| 137 |
+
"LABEL_5": 5,
|
| 138 |
+
"LABEL_50": 50,
|
| 139 |
+
"LABEL_51": 51,
|
| 140 |
+
"LABEL_52": 52,
|
| 141 |
+
"LABEL_53": 53,
|
| 142 |
+
"LABEL_54": 54,
|
| 143 |
+
"LABEL_55": 55,
|
| 144 |
+
"LABEL_56": 56,
|
| 145 |
+
"LABEL_57": 57,
|
| 146 |
+
"LABEL_58": 58,
|
| 147 |
+
"LABEL_59": 59,
|
| 148 |
+
"LABEL_6": 6,
|
| 149 |
+
"LABEL_60": 60,
|
| 150 |
+
"LABEL_61": 61,
|
| 151 |
+
"LABEL_62": 62,
|
| 152 |
+
"LABEL_63": 63,
|
| 153 |
+
"LABEL_64": 64,
|
| 154 |
+
"LABEL_65": 65,
|
| 155 |
+
"LABEL_66": 66,
|
| 156 |
+
"LABEL_67": 67,
|
| 157 |
+
"LABEL_68": 68,
|
| 158 |
+
"LABEL_69": 69,
|
| 159 |
+
"LABEL_7": 7,
|
| 160 |
+
"LABEL_8": 8,
|
| 161 |
+
"LABEL_9": 9
|
| 162 |
+
},
|
| 163 |
+
"label_schema": {
|
| 164 |
+
"category_to_group_names": {
|
| 165 |
+
"aa": [
|
| 166 |
+
"deg"
|
| 167 |
+
],
|
| 168 |
+
"ae": [
|
| 169 |
+
"deg"
|
| 170 |
+
],
|
| 171 |
+
"af": [
|
| 172 |
+
"deg"
|
| 173 |
+
],
|
| 174 |
+
"ao": [
|
| 175 |
+
"deg"
|
| 176 |
+
],
|
| 177 |
+
"as": [
|
| 178 |
+
"deg"
|
| 179 |
+
],
|
| 180 |
+
"au": [
|
| 181 |
+
"deg"
|
| 182 |
+
],
|
| 183 |
+
"a\u00fe": [
|
| 184 |
+
"deg"
|
| 185 |
+
],
|
| 186 |
+
"fa": [
|
| 187 |
+
"gender",
|
| 188 |
+
"number",
|
| 189 |
+
"case"
|
| 190 |
+
],
|
| 191 |
+
"fb": [
|
| 192 |
+
"gender",
|
| 193 |
+
"number",
|
| 194 |
+
"case"
|
| 195 |
+
],
|
| 196 |
+
"fe": [
|
| 197 |
+
"gender",
|
| 198 |
+
"number",
|
| 199 |
+
"case"
|
| 200 |
+
],
|
| 201 |
+
"fo": [
|
| 202 |
+
"gender_or_person",
|
| 203 |
+
"number",
|
| 204 |
+
"case"
|
| 205 |
+
],
|
| 206 |
+
"fp": [
|
| 207 |
+
"gender_or_person",
|
| 208 |
+
"number",
|
| 209 |
+
"case"
|
| 210 |
+
],
|
| 211 |
+
"fs": [
|
| 212 |
+
"gender",
|
| 213 |
+
"number",
|
| 214 |
+
"case"
|
| 215 |
+
],
|
| 216 |
+
"ft": [
|
| 217 |
+
"gender",
|
| 218 |
+
"number",
|
| 219 |
+
"case"
|
| 220 |
+
],
|
| 221 |
+
"g": [
|
| 222 |
+
"gender",
|
| 223 |
+
"number",
|
| 224 |
+
"case"
|
| 225 |
+
],
|
| 226 |
+
"l": [
|
| 227 |
+
"gender",
|
| 228 |
+
"number",
|
| 229 |
+
"case",
|
| 230 |
+
"adj_c",
|
| 231 |
+
"deg"
|
| 232 |
+
],
|
| 233 |
+
"n": [
|
| 234 |
+
"gender",
|
| 235 |
+
"number",
|
| 236 |
+
"case",
|
| 237 |
+
"def",
|
| 238 |
+
"proper"
|
| 239 |
+
],
|
| 240 |
+
"sb": [
|
| 241 |
+
"voice",
|
| 242 |
+
"person",
|
| 243 |
+
"number",
|
| 244 |
+
"tense"
|
| 245 |
+
],
|
| 246 |
+
"sf": [
|
| 247 |
+
"voice",
|
| 248 |
+
"person",
|
| 249 |
+
"number",
|
| 250 |
+
"tense"
|
| 251 |
+
],
|
| 252 |
+
"sl": [
|
| 253 |
+
"voice",
|
| 254 |
+
"person",
|
| 255 |
+
"number",
|
| 256 |
+
"tense"
|
| 257 |
+
],
|
| 258 |
+
"sn": [
|
| 259 |
+
"voice"
|
| 260 |
+
],
|
| 261 |
+
"ss": [
|
| 262 |
+
"voice"
|
| 263 |
+
],
|
| 264 |
+
"sv": [
|
| 265 |
+
"voice",
|
| 266 |
+
"person",
|
| 267 |
+
"number",
|
| 268 |
+
"tense"
|
| 269 |
+
],
|
| 270 |
+
"s\u00fe": [
|
| 271 |
+
"voice",
|
| 272 |
+
"gender",
|
| 273 |
+
"number",
|
| 274 |
+
"case"
|
| 275 |
+
],
|
| 276 |
+
"tf": [
|
| 277 |
+
"gender",
|
| 278 |
+
"number",
|
| 279 |
+
"case"
|
| 280 |
+
]
|
| 281 |
+
},
|
| 282 |
+
"group_name_to_labels": {
|
| 283 |
+
"adj_c": [
|
| 284 |
+
"strong",
|
| 285 |
+
"weak",
|
| 286 |
+
"equiinflected"
|
| 287 |
+
],
|
| 288 |
+
"case": [
|
| 289 |
+
"nom",
|
| 290 |
+
"acc",
|
| 291 |
+
"dat",
|
| 292 |
+
"gen"
|
| 293 |
+
],
|
| 294 |
+
"def": [
|
| 295 |
+
"definite"
|
| 296 |
+
],
|
| 297 |
+
"deg": [
|
| 298 |
+
"pos",
|
| 299 |
+
"cmp",
|
| 300 |
+
"superl"
|
| 301 |
+
],
|
| 302 |
+
"gender": [
|
| 303 |
+
"masc",
|
| 304 |
+
"fem",
|
| 305 |
+
"neut",
|
| 306 |
+
"gender_x"
|
| 307 |
+
],
|
| 308 |
+
"gender_or_person": [
|
| 309 |
+
"masc",
|
| 310 |
+
"fem",
|
| 311 |
+
"neut",
|
| 312 |
+
"gender_x",
|
| 313 |
+
"1",
|
| 314 |
+
"2",
|
| 315 |
+
"3"
|
| 316 |
+
],
|
| 317 |
+
"number": [
|
| 318 |
+
"sing",
|
| 319 |
+
"plur"
|
| 320 |
+
],
|
| 321 |
+
"person": [
|
| 322 |
+
"1",
|
| 323 |
+
"2",
|
| 324 |
+
"3"
|
| 325 |
+
],
|
| 326 |
+
"proper": [
|
| 327 |
+
"proper"
|
| 328 |
+
],
|
| 329 |
+
"tense": [
|
| 330 |
+
"pres",
|
| 331 |
+
"past"
|
| 332 |
+
],
|
| 333 |
+
"voice": [
|
| 334 |
+
"act",
|
| 335 |
+
"mid"
|
| 336 |
+
]
|
| 337 |
+
},
|
| 338 |
+
"group_names": [
|
| 339 |
+
"gender",
|
| 340 |
+
"gender_or_person",
|
| 341 |
+
"number",
|
| 342 |
+
"case",
|
| 343 |
+
"def",
|
| 344 |
+
"proper",
|
| 345 |
+
"adj_c",
|
| 346 |
+
"deg",
|
| 347 |
+
"voice",
|
| 348 |
+
"person",
|
| 349 |
+
"tense"
|
| 350 |
+
],
|
| 351 |
+
"ignore_categories": [
|
| 352 |
+
"x",
|
| 353 |
+
"e"
|
| 354 |
+
],
|
| 355 |
+
"label_categories": [
|
| 356 |
+
"n",
|
| 357 |
+
"g",
|
| 358 |
+
"x",
|
| 359 |
+
"e",
|
| 360 |
+
"v",
|
| 361 |
+
"l",
|
| 362 |
+
"fa",
|
| 363 |
+
"fb",
|
| 364 |
+
"fe",
|
| 365 |
+
"fo",
|
| 366 |
+
"fp",
|
| 367 |
+
"fs",
|
| 368 |
+
"ft",
|
| 369 |
+
"tf",
|
| 370 |
+
"ta",
|
| 371 |
+
"tp",
|
| 372 |
+
"to",
|
| 373 |
+
"sn",
|
| 374 |
+
"sb",
|
| 375 |
+
"sf",
|
| 376 |
+
"sv",
|
| 377 |
+
"ss",
|
| 378 |
+
"sl",
|
| 379 |
+
"s\u00fe",
|
| 380 |
+
"cn",
|
| 381 |
+
"ct",
|
| 382 |
+
"c",
|
| 383 |
+
"aa",
|
| 384 |
+
"af",
|
| 385 |
+
"au",
|
| 386 |
+
"ao",
|
| 387 |
+
"a\u00fe",
|
| 388 |
+
"ae",
|
| 389 |
+
"as",
|
| 390 |
+
"ks",
|
| 391 |
+
"kt",
|
| 392 |
+
"p",
|
| 393 |
+
"pl",
|
| 394 |
+
"pk",
|
| 395 |
+
"pg",
|
| 396 |
+
"pa",
|
| 397 |
+
"ns",
|
| 398 |
+
"m"
|
| 399 |
+
],
|
| 400 |
+
"labels": [
|
| 401 |
+
"<SEP>",
|
| 402 |
+
"n",
|
| 403 |
+
"g",
|
| 404 |
+
"x",
|
| 405 |
+
"e",
|
| 406 |
+
"v",
|
| 407 |
+
"l",
|
| 408 |
+
"fa",
|
| 409 |
+
"fb",
|
| 410 |
+
"fe",
|
| 411 |
+
"fo",
|
| 412 |
+
"fp",
|
| 413 |
+
"fs",
|
| 414 |
+
"ft",
|
| 415 |
+
"tf",
|
| 416 |
+
"ta",
|
| 417 |
+
"tp",
|
| 418 |
+
"to",
|
| 419 |
+
"sn",
|
| 420 |
+
"sb",
|
| 421 |
+
"sf",
|
| 422 |
+
"sv",
|
| 423 |
+
"ss",
|
| 424 |
+
"sl",
|
| 425 |
+
"s\u00fe",
|
| 426 |
+
"cn",
|
| 427 |
+
"ct",
|
| 428 |
+
"c",
|
| 429 |
+
"aa",
|
| 430 |
+
"af",
|
| 431 |
+
"au",
|
| 432 |
+
"ao",
|
| 433 |
+
"a\u00fe",
|
| 434 |
+
"ae",
|
| 435 |
+
"as",
|
| 436 |
+
"ks",
|
| 437 |
+
"kt",
|
| 438 |
+
"p",
|
| 439 |
+
"pl",
|
| 440 |
+
"pk",
|
| 441 |
+
"pg",
|
| 442 |
+
"pa",
|
| 443 |
+
"ns",
|
| 444 |
+
"m",
|
| 445 |
+
"masc",
|
| 446 |
+
"fem",
|
| 447 |
+
"neut",
|
| 448 |
+
"gender_x",
|
| 449 |
+
"1",
|
| 450 |
+
"2",
|
| 451 |
+
"3",
|
| 452 |
+
"sing",
|
| 453 |
+
"plur",
|
| 454 |
+
"nom",
|
| 455 |
+
"acc",
|
| 456 |
+
"dat",
|
| 457 |
+
"gen",
|
| 458 |
+
"definite",
|
| 459 |
+
"proper",
|
| 460 |
+
"strong",
|
| 461 |
+
"weak",
|
| 462 |
+
"equiinflected",
|
| 463 |
+
"pos",
|
| 464 |
+
"cmp",
|
| 465 |
+
"superl",
|
| 466 |
+
"past",
|
| 467 |
+
"pres",
|
| 468 |
+
"pass",
|
| 469 |
+
"act",
|
| 470 |
+
"mid"
|
| 471 |
+
],
|
| 472 |
+
"null": null,
|
| 473 |
+
"null_leaf": null,
|
| 474 |
+
"separator": "<SEP>"
|
| 475 |
+
},
|
| 476 |
+
"layer_norm_eps": 1e-05,
|
| 477 |
+
"max_position_embeddings": 514,
|
| 478 |
+
"model_type": "icebert-pos",
|
| 479 |
+
"num_attention_heads": 12,
|
| 480 |
+
"num_categories": 43,
|
| 481 |
+
"num_groups": 12,
|
| 482 |
+
"num_hidden_layers": 12,
|
| 483 |
+
"pad_token_id": 1,
|
| 484 |
+
"position_embedding_type": "absolute",
|
| 485 |
+
"torch_dtype": "float32",
|
| 486 |
+
"transformers_version": "4.46.3",
|
| 487 |
+
"type_vocab_size": 1,
|
| 488 |
+
"use_cache": true,
|
| 489 |
+
"vocab_size": 49937
|
| 490 |
+
}
|
configuration.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) Miðeind ehf.
|
| 2 |
+
# This file is part of IceBERT POS model conversion.
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
|
| 7 |
+
from transformers import AutoConfig, RobertaConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class IceBertPosConfig(RobertaConfig):
    """
    Configuration class for IceBERT POS (Part-of-Speech) tagging model.

    This configuration inherits from RobertaConfig and adds POS-specific parameters
    derived from the label schema used for multilabel token classification.

    Derived attributes:
        num_categories: number of POS category tags (``len(label_schema["label_categories"])``).
        num_labels: size of the flat label vocabulary (``len(label_schema["labels"])``).
        num_groups: number of attribute groups (``len(label_schema["group_names"])``).
        attr_proj_input_size: input width of the attribute projection,
            ``num_categories + hidden_size``.
    """

    model_type = "icebert-pos"

    def __init__(
        self, label_schema: Optional[Dict[str, Any]] = None, classifier_dropout: Optional[float] = None, **kwargs
    ):
        """
        Args:
            label_schema: Mapping describing POS categories, attribute groups and
                labels (same structure as ``_get_default_label_schema()``).
                When ``None``, the built-in default schema is used.
            classifier_dropout: Dropout probability for the classification head;
                defaults to 0.1 when not given.
            **kwargs: Forwarded unchanged to ``RobertaConfig``.
        """
        super().__init__(**kwargs)

        # Default label schema (terms2.json content)
        if label_schema is None:
            label_schema = self._get_default_label_schema()

        self.label_schema = label_schema

        # Derive parameters from label schema
        self.num_categories = len(label_schema["label_categories"])
        # NOTE(review): on the transformers base config, ``num_labels`` is a
        # property whose setter can rebuild id2label/label2id — confirm this
        # assignment interacts as intended when id2label is also passed via kwargs.
        self.num_labels = len(label_schema["labels"])
        self.num_groups = len(label_schema["group_names"])

        # Classification head parameters
        self.classifier_dropout = classifier_dropout if classifier_dropout is not None else 0.1

        # Computed input size for attribute projection
        # (category_probs + hidden_size) -> num_labels
        self.attr_proj_input_size = self.num_categories + self.hidden_size

    @staticmethod
    def _get_default_label_schema() -> Dict[str, Any]:
        """Default label schema corresponding to terms2.json.

        Keys:
            label_categories: fine-grained POS category tags.
            category_to_group_names: attribute groups that apply to each category;
                categories absent from this mapping carry no attributes.
            group_names: canonical ordering of attribute groups.
            group_name_to_labels: label inventory per attribute group.
            labels: full flat label vocabulary ("<SEP>" first, then categories,
                then attribute labels).
            null / null_leaf / separator / ignore_categories: schema metadata.
        """
        return {
            "label_categories": [
                "n", "g", "x", "e", "v", "l",
                "fa", "fb", "fe", "fo", "fp", "fs", "ft",
                "tf", "ta", "tp", "to",
                "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
                "cn", "ct", "c",
                "aa", "af", "au", "ao", "aþ", "ae", "as",
                "ks", "kt",
                "p", "pl", "pk", "pg", "pa",
                "ns", "m",
            ],
            "category_to_group_names": {
                "n": ["gender", "number", "case", "def", "proper"],
                "g": ["gender", "number", "case"],
                "l": ["gender", "number", "case", "adj_c", "deg"],
                "fa": ["gender", "number", "case"],
                "fb": ["gender", "number", "case"],
                "fe": ["gender", "number", "case"],
                "fs": ["gender", "number", "case"],
                "ft": ["gender", "number", "case"],
                "fo": ["gender_or_person", "number", "case"],
                "fp": ["gender_or_person", "number", "case"],
                "tf": ["gender", "number", "case"],
                "sn": ["voice"],
                "sb": ["voice", "person", "number", "tense"],
                "sf": ["voice", "person", "number", "tense"],
                "sv": ["voice", "person", "number", "tense"],
                "ss": ["voice"],
                "sl": ["voice", "person", "number", "tense"],
                "sþ": ["voice", "gender", "number", "case"],
                "aa": ["deg"],
                "af": ["deg"],
                "au": ["deg"],
                "ao": ["deg"],
                "aþ": ["deg"],
                "ae": ["deg"],
                "as": ["deg"],
            },
            "group_names": [
                "gender",
                "gender_or_person",
                "number",
                "case",
                "def",
                "proper",
                "adj_c",
                "deg",
                "voice",
                "person",
                "tense",
            ],
            "group_name_to_labels": {
                "gender": ["masc", "fem", "neut", "gender_x"],
                "number": ["sing", "plur"],
                "person": ["1", "2", "3"],
                "gender_or_person": ["masc", "fem", "neut", "gender_x", "1", "2", "3"],
                "case": ["nom", "acc", "dat", "gen"],
                "deg": ["pos", "cmp", "superl"],
                "voice": ["act", "mid"],
                "tense": ["pres", "past"],
                "def": ["definite"],
                "proper": ["proper"],
                "adj_c": ["strong", "weak", "equiinflected"],
            },
            "labels": [
                "<SEP>",
                "n", "g", "x", "e", "v", "l",
                "fa", "fb", "fe", "fo", "fp", "fs", "ft",
                "tf", "ta", "tp", "to",
                "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
                "cn", "ct", "c",
                "aa", "af", "au", "ao", "aþ", "ae", "as",
                "ks", "kt",
                "p", "pl", "pk", "pg", "pa",
                "ns", "m",
                "masc", "fem", "neut", "gender_x",
                "1", "2", "3",
                "sing", "plur",
                "nom", "acc", "dat", "gen",
                "definite", "proper",
                "strong", "weak", "equiinflected",
                "pos", "cmp", "superl",
                "past", "pres",
                "pass",
                "act", "mid",
            ],
            "null": None,
            "null_leaf": None,
            "separator": "<SEP>",
            "ignore_categories": ["x", "e"],
        }

    @classmethod
    def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
        """Create config from a label schema JSON file.

        Args:
            schema_path: Path to a JSON file with the same structure as
                ``_get_default_label_schema()``.
            **kwargs: Forwarded to the constructor (and on to RobertaConfig).
        """
        with open(schema_path, "r", encoding="utf-8") as f:
            label_schema = json.load(f)
        return cls(label_schema=label_schema, **kwargs)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Register the custom config so AutoConfig.from_pretrained can resolve
# model_type "icebert-pos" to this class.
AutoConfig.register("icebert-pos", IceBertPosConfig)
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc24ca46b3b1024c92be719a8964c1336185c3d188674f2f4b96c1064fdaab7f
|
| 3 |
+
size 497965196
|
modeling.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) Miðeind ehf.
|
| 2 |
+
# This file is part of IceBERT POS model conversion.
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
from typing import List, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 11 |
+
from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
|
| 12 |
+
|
| 13 |
+
from .configuration import IceBertPosConfig
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MultiLabelTokenClassificationHead(nn.Module):
    """Two-stage head for multilabel word-level classification.

    Stage one predicts a POS category per word; stage two predicts attribute
    labels conditioned on the category probability distribution concatenated
    with the transformed word features.
    """

    def __init__(self, config: IceBertPosConfig):
        super().__init__()
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size

        # Shared feature transform applied before both prediction stages.
        # (Attribute names are load-bearing: they must match the checkpoint's
        # state_dict keys.)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.activation_fn = F.relu
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Stage one: hidden_size -> num_categories
        self.cat_proj = nn.Linear(self.hidden_size, self.num_categories)

        # Stage two: (hidden_size + num_categories) -> num_labels
        self.out_proj = nn.Linear(self.hidden_size + self.num_categories, self.num_labels)

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run both classification stages on word-level features.

        Args:
            features: Word-level features of shape (total_words, hidden_size).

        Returns:
            A ``(category_logits, attribute_logits)`` pair with shapes
            (total_words, num_categories) and (total_words, num_labels).
        """
        # Shared transform: dropout -> dense -> layer norm -> ReLU.
        transformed = self.dense(self.dropout(features))
        transformed = self.activation_fn(self.layer_norm(transformed))

        # Stage one: per-word category logits.
        category_logits = self.cat_proj(transformed)

        # Stage two: attribute logits conditioned on the category distribution.
        category_probs = torch.softmax(category_logits, dim=-1)
        attribute_logits = self.out_proj(torch.cat((category_probs, transformed), dim=-1))

        return category_logits, attribute_logits
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class IceBertPosForTokenClassification(PreTrainedModel):
|
| 64 |
+
"""
|
| 65 |
+
IceBERT model for multilabel token classification (POS tagging).
|
| 66 |
+
|
| 67 |
+
This model performs word-level POS tagging by:
|
| 68 |
+
1. Encoding input with RoBERTa
|
| 69 |
+
2. Aggregating subword tokens to word-level representations
|
| 70 |
+
3. Predicting both categories and attributes for each word
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
config_class = IceBertPosConfig
|
| 74 |
+
|
| 75 |
+
    def __init__(self, config: IceBertPosConfig):
        """Build the RoBERTa encoder and the multilabel classification head.

        Args:
            config: IceBertPosConfig carrying both the RoBERTa hyperparameters
                and the label-schema-derived sizes (num_categories, num_labels).
        """
        super().__init__(config)
        self.config = config
        self.num_categories = config.num_categories
        self.num_labels = config.num_labels

        # Encoder without the pooling layer: only token-level states are used.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = MultiLabelTokenClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()
|
| 86 |
+
|
| 87 |
+
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        word_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            input_ids: Token indices of shape (batch_size, sequence_length)
            attention_mask: Attention mask of shape (batch_size, sequence_length)
            word_mask: Binary mask indicating word boundaries (1 = word start)

            The remaining arguments are forwarded unchanged to the RoBERTa encoder.

        Returns:
            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
        """
        # NOTE(review): return_dict only affects the encoder call below; this
        # method itself always returns a plain tuple of two tensors.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get RoBERTa outputs
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First element is the last hidden state in both tuple and
        # ModelOutput return modes.
        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)

        # Aggregate subword tokens to word-level representations using word_mask
        word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask)

        # Apply classification head
        cat_logits, attr_logits = self.classifier(word_features)

        # Reshape back to batch format using word counts
        cat_logits_batch, attr_logits_batch = self._reshape_to_batch_format(cat_logits, attr_logits, nwords)

        return cat_logits_batch, attr_logits_batch
|
| 137 |
+
|
| 138 |
+
def _aggregate_subword_tokens(
|
| 139 |
+
self, sequence_output: torch.Tensor, word_mask: torch.Tensor
|
| 140 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 141 |
+
"""
|
| 142 |
+
Aggregate subword token representations to word-level representations.
|
| 143 |
+
Following the original fairseq approach by averaging subword tokens within each word.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
sequence_output: subword token representations (batch_size, seq_len, hidden_size)
|
| 147 |
+
word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
word_features: Word-level features (total_words, hidden_size)
|
| 151 |
+
nwords: Number of words per sequence (batch_size,)
|
| 152 |
+
"""
|
| 153 |
+
# TODO: Verify that BOS and EOS are handled correctly - I'm worried that this does not correctly handle padding
|
| 154 |
+
# Remove BOS and EOS tokens (first and last positions)
|
| 155 |
+
x = sequence_output[:, 1:-1, :] # (batch_size, seq_len-2, hidden_size)
|
| 156 |
+
starts = word_mask[:, 1:-1] # (batch_size, seq_len-2)
|
| 157 |
+
|
| 158 |
+
# Count words per sequence
|
| 159 |
+
nwords = starts.sum(dim=-1) # (batch_size,)
|
| 160 |
+
|
| 161 |
+
# Find word boundaries and average tokens within each word
|
| 162 |
+
mean_words = []
|
| 163 |
+
batch_size, seq_len, hidden_size = x.shape
|
| 164 |
+
|
| 165 |
+
for batch_idx in range(batch_size):
|
| 166 |
+
seq_starts = starts[batch_idx] # (seq_len-2,)
|
| 167 |
+
seq_x = x[batch_idx] # (seq_len-2, hidden_size)
|
| 168 |
+
|
| 169 |
+
# Find start positions of words
|
| 170 |
+
start_positions = seq_starts.nonzero(as_tuple=True)[0] # positions where words start
|
| 171 |
+
|
| 172 |
+
if len(start_positions) == 0:
|
| 173 |
+
continue
|
| 174 |
+
|
| 175 |
+
# Calculate end positions (start of next word or end of sequence)
|
| 176 |
+
end_positions = torch.cat([start_positions[1:], torch.tensor([seq_len], device=start_positions.device)])
|
| 177 |
+
|
| 178 |
+
# Average tokens within each word
|
| 179 |
+
for start_pos, end_pos in zip(start_positions, end_positions):
|
| 180 |
+
word_tokens = seq_x[start_pos:end_pos] # tokens in this word
|
| 181 |
+
word_repr = word_tokens.mean(dim=0) # average representation
|
| 182 |
+
mean_words.append(word_repr)
|
| 183 |
+
|
| 184 |
+
if len(mean_words) == 0:
|
| 185 |
+
return torch.empty(0, sequence_output.size(-1), device=sequence_output.device), nwords
|
| 186 |
+
|
| 187 |
+
return torch.stack(mean_words), nwords
|
| 188 |
+
|
| 189 |
+
def _reshape_to_batch_format(
|
| 190 |
+
self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, nwords: torch.Tensor
|
| 191 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 192 |
+
"""
|
| 193 |
+
Reshape word-level predictions back to batch format.
|
| 194 |
+
Following the original fairseq approach with pad_sequence.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
cat_logits: Category logits (total_words, num_categories)
|
| 198 |
+
attr_logits: Attribute logits (total_words, num_labels)
|
| 199 |
+
nwords: Number of words per sequence (batch_size,)
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
cat_logits_batch: (batch_size, max_words, num_categories)
|
| 203 |
+
attr_logits_batch: (batch_size, max_words, num_labels)
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
# Split logits by sequence using word counts
|
| 207 |
+
words_per_seq = nwords.tolist()
|
| 208 |
+
cat_logits_split = cat_logits.split(words_per_seq)
|
| 209 |
+
attr_logits_split = attr_logits.split(words_per_seq)
|
| 210 |
+
|
| 211 |
+
# Pad to same length (matching original fairseq approach)
|
| 212 |
+
cat_logits_batch = pad_sequence(cat_logits_split, batch_first=True, padding_value=0)
|
| 213 |
+
attr_logits_batch = pad_sequence(attr_logits_split, batch_first=True, padding_value=0)
|
| 214 |
+
|
| 215 |
+
return cat_logits_batch, attr_logits_batch
|
| 216 |
+
|
| 217 |
+
@torch.no_grad()
def predict_labels(
    self, input_ids: torch.Tensor, attention_mask: torch.Tensor, word_ids: List[List[int]]
) -> List[List[Tuple[str, List[str]]]]:
    """Run inference and decode POS labels for a batch.

    Args:
        input_ids: Token indices.
        attention_mask: Attention mask.
        word_ids: Per-sequence word ids marking word boundaries.

    Returns:
        One list per sequence, each element a (category, [attributes])
        pair per word.
    """
    # Derive the binary word-start mask the forward pass expects.
    mask = self._word_ids_to_word_mask(word_ids, input_ids.shape)
    cat_logits, attr_logits = self.forward(
        input_ids=input_ids, attention_mask=attention_mask, word_mask=mask
    )
    return self._logits_to_labels(cat_logits, attr_logits, word_ids)
|
| 238 |
+
|
| 239 |
+
def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
|
| 240 |
+
"""
|
| 241 |
+
Convert word_ids to word_mask (binary mask indicating word boundaries).
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
word_ids: List of word id sequences
|
| 245 |
+
input_shape: Shape of input_ids tensor (batch_size, seq_len)
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
word_mask: Binary tensor where 1 indicates start of word
|
| 249 |
+
"""
|
| 250 |
+
batch_size, seq_len = input_shape
|
| 251 |
+
word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
|
| 252 |
+
|
| 253 |
+
for batch_idx, seq_word_ids in enumerate(word_ids):
|
| 254 |
+
prev_word_id = None
|
| 255 |
+
for token_idx, word_id in enumerate(seq_word_ids):
|
| 256 |
+
if word_id != prev_word_id:
|
| 257 |
+
word_mask[batch_idx, token_idx] = 1
|
| 258 |
+
prev_word_id = word_id
|
| 259 |
+
|
| 260 |
+
return word_mask
|
| 261 |
+
|
| 262 |
+
def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
    """Predict POS labels from raw text.

    Sentences are tokenized independently, right-padded to the longest
    sequence, batched, and fed to :meth:`predict_labels`.

    Args:
        sentences: Input sentences.
        tokenizer: HuggingFace tokenizer (must support ``word_ids()``).

    Returns:
        One list per sentence, each element a (category, [attributes])
        pair per word.
    """
    # Robustness fix: an empty batch previously crashed in max().
    if not sentences:
        return []

    encodings = [tokenizer(sent, return_tensors="pt") for sent in sentences]
    word_ids_list = [enc.word_ids() for enc in encodings]

    max_len = max(enc["input_ids"].shape[1] for enc in encodings)
    # Bug fix: use the tokenizer's own pad id instead of a hard-coded 1.
    # Fall back to 1 (RoBERTa convention, the previous behavior) if unset.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 1

    batch_input_ids = []
    batch_attention_mask = []
    for enc in encodings:
        ids = enc["input_ids"][0]
        mask = enc["attention_mask"][0]
        pad_len = max_len - len(ids)
        if pad_len > 0:
            ids = torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)])
            mask = torch.cat([mask, torch.zeros(pad_len, dtype=torch.long)])
        batch_input_ids.append(ids)
        batch_attention_mask.append(mask)

    return self.predict_labels(
        torch.stack(batch_input_ids), torch.stack(batch_attention_mask), word_ids_list
    )
|
| 299 |
+
|
| 300 |
+
def _make_group_name_to_group_attr_vec_idxs(self):
|
| 301 |
+
"""Create mapping from group names to their attribute vector indices"""
|
| 302 |
+
group_name_to_group_attr_vec_idxs = {}
|
| 303 |
+
labels = self.config.label_schema["labels"]
|
| 304 |
+
nspecial = 0 # Number of special tokens in label dictionary (like <SEP>)
|
| 305 |
+
|
| 306 |
+
for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items():
|
| 307 |
+
vec_idxs = []
|
| 308 |
+
for label in group_labels:
|
| 309 |
+
if label in labels:
|
| 310 |
+
# Find index in labels list, but subtract nspecial to get vector index
|
| 311 |
+
label_dict_idx = labels.index(label)
|
| 312 |
+
if label_dict_idx >= nspecial: # Skip special tokens
|
| 313 |
+
vec_idxs.append(label_dict_idx - nspecial)
|
| 314 |
+
group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs)
|
| 315 |
+
|
| 316 |
+
return group_name_to_group_attr_vec_idxs
|
| 317 |
+
|
| 318 |
+
def _make_group_masks(self):
|
| 319 |
+
"""Create group masks for each category"""
|
| 320 |
+
label_categories = self.config.label_schema["label_categories"]
|
| 321 |
+
group_names = self.config.label_schema["group_names"]
|
| 322 |
+
category_to_group_names = self.config.label_schema["category_to_group_names"]
|
| 323 |
+
|
| 324 |
+
num_cats = len(label_categories)
|
| 325 |
+
num_groups = len(group_names)
|
| 326 |
+
|
| 327 |
+
group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool)
|
| 328 |
+
|
| 329 |
+
for cat_idx, category in enumerate(label_categories):
|
| 330 |
+
if category in category_to_group_names:
|
| 331 |
+
for group_name in category_to_group_names[category]:
|
| 332 |
+
if group_name in group_names:
|
| 333 |
+
group_idx = group_names.index(group_name)
|
| 334 |
+
group_mask[cat_idx, group_idx] = True
|
| 335 |
+
|
| 336 |
+
return group_mask
|
| 337 |
+
|
| 338 |
+
def _make_category_mappings(self):
|
| 339 |
+
"""Create mappings between category vector indices and dictionary indices"""
|
| 340 |
+
labels = self.config.label_schema["labels"]
|
| 341 |
+
label_categories = self.config.label_schema["label_categories"]
|
| 342 |
+
|
| 343 |
+
# Create mapping from category names to vector indices (0-based)
|
| 344 |
+
cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long)
|
| 345 |
+
cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long)
|
| 346 |
+
|
| 347 |
+
for vec_idx, category in enumerate(label_categories):
|
| 348 |
+
if category in labels:
|
| 349 |
+
dict_idx = labels.index(category)
|
| 350 |
+
cat_dict_idx_to_vec_idx[dict_idx] = vec_idx
|
| 351 |
+
cat_vec_idx_to_dict_idx[vec_idx] = dict_idx
|
| 352 |
+
|
| 353 |
+
return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx
|
| 354 |
+
|
| 355 |
+
def _count_words_per_sequence(self, word_ids: List[List[int]]) -> List[int]:
|
| 356 |
+
"""Count the number of unique words in each sequence."""
|
| 357 |
+
words_per_seq = []
|
| 358 |
+
for seq_word_ids in word_ids:
|
| 359 |
+
unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None)
|
| 360 |
+
words_per_seq.append(len(unique_word_ids))
|
| 361 |
+
return words_per_seq
|
| 362 |
+
|
| 363 |
+
def _predict_categories_for_sequence(
|
| 364 |
+
self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor
|
| 365 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 366 |
+
"""Predict categories for a single sequence and return both vector and dictionary indices."""
|
| 367 |
+
pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
|
| 368 |
+
pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
|
| 369 |
+
return pred_cat_vec_idxs, pred_cats
|
| 370 |
+
|
| 371 |
+
def _predict_attributes_for_group(
|
| 372 |
+
self,
|
| 373 |
+
attr_logits: torch.Tensor,
|
| 374 |
+
seq_idx: int,
|
| 375 |
+
seq_nwords: int,
|
| 376 |
+
group_vec_idxs: torch.Tensor,
|
| 377 |
+
seq_group_mask: torch.Tensor,
|
| 378 |
+
group_idx: int,
|
| 379 |
+
) -> torch.Tensor:
|
| 380 |
+
"""Predict attributes for a single group."""
|
| 381 |
+
if len(group_vec_idxs) == 0:
|
| 382 |
+
return torch.zeros(seq_nwords, dtype=torch.long)
|
| 383 |
+
|
| 384 |
+
# Get logits for this group
|
| 385 |
+
group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
|
| 386 |
+
|
| 387 |
+
if len(group_vec_idxs) == 1:
|
| 388 |
+
# Single element group: use sigmoid > 0.5
|
| 389 |
+
group_pred = group_logits.sigmoid().ge(0.5).long()
|
| 390 |
+
group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx]
|
| 391 |
+
else:
|
| 392 |
+
# Multi element group: use argmax
|
| 393 |
+
group_pred_vec_idxs = group_logits.max(dim=-1).indices
|
| 394 |
+
group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx]
|
| 395 |
+
|
| 396 |
+
return group_pred_dict_idxs
|
| 397 |
+
|
| 398 |
+
def _predict_all_attributes_for_sequence(
    self,
    attr_logits: torch.Tensor,
    seq_idx: int,
    seq_nwords: int,
    pred_cat_vec_idxs: torch.Tensor,
    group_name_to_group_attr_vec_idxs: dict,
    group_mask: torch.Tensor,
    group_names: List[str],
) -> torch.Tensor:
    """Predict all attributes for a single sequence.

    Decodes every attribute group independently via
    ``_predict_attributes_for_group`` and assembles the per-group results
    into a (seq_nwords, num_groups) tensor of label-dictionary indices
    (0 = no attribute).

    Args:
        attr_logits: (batch, max_words, num_labels) attribute logits.
        seq_idx: Index of the sequence within the batch.
        seq_nwords: Number of real words in this sequence.
        pred_cat_vec_idxs: Predicted category vector indices per word.
        group_name_to_group_attr_vec_idxs: Group name -> member vector indices.
        group_mask: (num_categories, num_groups) mask of licensed groups.
        group_names: Ordered group names defining the output column order.
    """
    # Row-select the group mask by each word's predicted category:
    # seq_group_mask[w, g] tells whether group g applies to word w.
    seq_group_mask = group_mask[pred_cat_vec_idxs]
    pred_attrs = []

    for group_idx, group_name in enumerate(group_names):
        if group_name not in group_name_to_group_attr_vec_idxs:
            # Unknown group: contribute an all-zero (no attribute) column.
            pred_attrs.append(torch.zeros(seq_nwords, dtype=torch.long))
            continue

        group_vec_idxs = group_name_to_group_attr_vec_idxs[group_name]
        group_pred_dict_idxs = self._predict_attributes_for_group(
            attr_logits, seq_idx, seq_nwords, group_vec_idxs, seq_group_mask, group_idx
        )
        pred_attrs.append(group_pred_dict_idxs)

    # Stack predictions
    # NOTE(review): the squeeze() collapses stray extra dims before
    # stacking; with seq_nwords == 1 the single-member-group path may
    # yield a 0-dim tensor — shape edge case to confirm.
    if pred_attrs:
        return torch.stack([p.squeeze() if p.dim() > 1 else p for p in pred_attrs]).t()
    else:
        return torch.zeros(seq_nwords, len(group_names), dtype=torch.long)
|
| 428 |
+
|
| 429 |
+
def _convert_predictions_to_labels(
|
| 430 |
+
self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str]
|
| 431 |
+
) -> List[Tuple[str, List[str]]]:
|
| 432 |
+
"""Convert prediction tensors to human-readable labels."""
|
| 433 |
+
seq_nwords = pred_cats.size(0)
|
| 434 |
+
seq_predictions = []
|
| 435 |
+
|
| 436 |
+
for word_idx in range(seq_nwords):
|
| 437 |
+
# Category (convert from dictionary index to string)
|
| 438 |
+
cat_dict_idx = pred_cats[word_idx].item()
|
| 439 |
+
if cat_dict_idx < len(labels):
|
| 440 |
+
category = labels[cat_dict_idx]
|
| 441 |
+
else:
|
| 442 |
+
category = "UNK"
|
| 443 |
+
|
| 444 |
+
# Attributes (convert from dictionary indices to strings)
|
| 445 |
+
attributes = []
|
| 446 |
+
for group_idx in range(len(group_names)):
|
| 447 |
+
attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item()
|
| 448 |
+
if attr_dict_idx > 0 and attr_dict_idx < len(labels): # Skip 0 (empty) and out of bounds
|
| 449 |
+
attributes.append(labels[attr_dict_idx])
|
| 450 |
+
|
| 451 |
+
seq_predictions.append((category, attributes))
|
| 452 |
+
|
| 453 |
+
return seq_predictions
|
| 454 |
+
|
| 455 |
+
def _logits_to_labels(
    self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_ids: List[List[int]]
) -> List[List[Tuple[str, List[str]]]]:
    """Decode batched logits into per-word (category, [attributes]) strings.

    Reimplements fairseq's group-based decoding: categories are chosen
    first, then each attribute group is decoded subject to the groups
    licensed by the chosen category.
    """
    group_attr_idxs = self._make_group_name_to_group_attr_vec_idxs()
    group_mask = self._make_group_masks()
    _, cat_vec_idx_to_dict_idx = self._make_category_mappings()

    schema = self.config.label_schema
    labels = schema["labels"]
    group_names = schema["group_names"]
    words_per_seq = self._count_words_per_sequence(word_ids)

    batch_predictions = []
    for seq_idx in range(cat_logits.size(0)):
        n = words_per_seq[seq_idx]

        cat_vec_idxs, cat_dict_idxs = self._predict_categories_for_sequence(
            cat_logits, seq_idx, n, cat_vec_idx_to_dict_idx
        )
        attrs = self._predict_all_attributes_for_sequence(
            attr_logits,
            seq_idx,
            n,
            cat_vec_idxs,
            group_attr_idxs,
            group_mask,
            group_names,
        )
        batch_predictions.append(
            self._convert_predictions_to_labels(cat_dict_idxs, attrs, labels, group_names)
        )

    return batch_predictions
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# Make the custom config/model discoverable through the transformers Auto*
# factories, and record them for auto-class (remote code) resolution.
AutoConfig.register("icebert-pos", IceBertPosConfig)
AutoModel.register(IceBertPosConfig, IceBertPosForTokenClassification)
IceBertPosConfig.register_for_auto_class()
IceBertPosForTokenClassification.register_for_auto_class("AutoModel")
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<s>",
|
| 3 |
+
"cls_token": "<s>",
|
| 4 |
+
"eos_token": "</s>",
|
| 5 |
+
"mask_token": {
|
| 6 |
+
"content": "<mask>",
|
| 7 |
+
"lstrip": true,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"pad_token": "<pad>",
|
| 13 |
+
"sep_token": "</s>",
|
| 14 |
+
"unk_token": "<unk>"
|
| 15 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": true,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"0": {
|
| 5 |
+
"content": "<s>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"1": {
|
| 13 |
+
"content": "<pad>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"2": {
|
| 21 |
+
"content": "</s>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"3": {
|
| 29 |
+
"content": "<unk>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"49936": {
|
| 37 |
+
"content": "<mask>",
|
| 38 |
+
"lstrip": true,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
}
|
| 44 |
+
},
|
| 45 |
+
"bos_token": "<s>",
|
| 46 |
+
"clean_up_tokenization_spaces": false,
|
| 47 |
+
"cls_token": "<s>",
|
| 48 |
+
"eos_token": "</s>",
|
| 49 |
+
"errors": "replace",
|
| 50 |
+
"mask_token": "<mask>",
|
| 51 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 52 |
+
"pad_token": "<pad>",
|
| 53 |
+
"sep_token": "</s>",
|
| 54 |
+
"tokenizer_class": "RobertaTokenizer",
|
| 55 |
+
"trim_offsets": true,
|
| 56 |
+
"unk_token": "<unk>"
|
| 57 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|