|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
from transformers import AutoConfig, RobertaConfig |
|
|
|
|
|
|
|
|
class IceBertPosConfig(RobertaConfig):
    """
    Configuration class for the IceBERT part-of-speech (POS) tagging model.

    Builds on ``RobertaConfig`` and adds sizes derived from the label schema used
    for multilabel token classification (number of word categories, labels and
    attribute groups), plus a classifier dropout that defaults to 0.1.
    """

    model_type = "icebert-pos"

    def __init__(
        self, label_schema: Optional[Dict[str, Any]] = None, classifier_dropout: Optional[float] = None, **kwargs
    ):
        """
        Args:
            label_schema: Schema dict. This constructor reads its "label_categories",
                "labels" and "group_names" entries; a missing entry raises KeyError.
                ``None`` selects the built-in schema from ``_get_default_label_schema``.
            classifier_dropout: Dropout probability for the classifier; ``None`` means 0.1.
            **kwargs: Forwarded unchanged to ``RobertaConfig.__init__``.
        """
        super().__init__(**kwargs)

        schema = label_schema if label_schema is not None else self._get_default_label_schema()
        self.label_schema = schema

        # Sizes derived from the schema; assigned after super().__init__ so they take precedence.
        self.num_categories = len(schema["label_categories"])
        self.num_labels = len(schema["labels"])
        self.num_groups = len(schema["group_names"])

        # Fall back to 0.1 when no dropout was supplied.
        self.classifier_dropout = 0.1 if classifier_dropout is None else classifier_dropout

        # Input width of the attribute projection: num_categories + hidden_size.
        # Presumably category scores concatenated onto the encoder hidden state — confirm in the model code.
        self.attr_proj_input_size = self.num_categories + self.hidden_size

    @staticmethod
    def _get_default_label_schema() -> Dict[str, Any]:
        """Return a freshly built copy of the default label schema (corresponds to terms2.json)."""
        # Word categories. The flat label vocabulary below starts with these, in this order.
        categories = [
            "n", "g", "x", "e", "v", "l",
            "fa", "fb", "fe", "fo", "fp", "fs", "ft",
            "tf", "ta", "tp", "to",
            "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
            "cn", "ct", "c",
            "aa", "af", "au", "ao", "aþ", "ae", "as",
            "ks", "kt",
            "p", "pl", "pk", "pg", "pa",
            "ns", "m",
        ]
        # Attribute values appended after the categories in the flat label vocabulary.
        # NOTE(review): "pass" is not listed under any group in group_name_to_labels, and
        # "past"/"pres" appear in the opposite order from the "tense" group. Both are kept
        # as-is, presumably because changing them would alter num_labels / label indices.
        attribute_labels = [
            "masc", "fem", "neut", "gender_x",
            "1", "2", "3",
            "sing", "plur",
            "nom", "acc", "dat", "gen",
            "definite", "proper",
            "strong", "weak", "equiinflected",
            "pos", "cmp", "superl",
            "past", "pres", "pass",
            "act", "mid",
        ]
        return {
            "label_categories": categories,
            "category_to_group_names": {
                "n": ["gender", "number", "case", "def", "proper"],
                "g": ["gender", "number", "case"],
                "l": ["gender", "number", "case", "adj_c", "deg"],
                "fa": ["gender", "number", "case"],
                "fb": ["gender", "number", "case"],
                "fe": ["gender", "number", "case"],
                "fs": ["gender", "number", "case"],
                "ft": ["gender", "number", "case"],
                "fo": ["gender_or_person", "number", "case"],
                "fp": ["gender_or_person", "number", "case"],
                "tf": ["gender", "number", "case"],
                "sn": ["voice"],
                "sb": ["voice", "person", "number", "tense"],
                "sf": ["voice", "person", "number", "tense"],
                "sv": ["voice", "person", "number", "tense"],
                "ss": ["voice"],
                "sl": ["voice", "person", "number", "tense"],
                "sþ": ["voice", "gender", "number", "case"],
                # Categories whose only group is "deg"; each gets its own fresh list.
                **{category: ["deg"] for category in ("aa", "af", "au", "ao", "aþ", "ae", "as")},
            },
            "group_names": [
                "gender", "gender_or_person", "number", "case", "def", "proper",
                "adj_c", "deg", "voice", "person", "tense",
            ],
            "group_name_to_labels": {
                "gender": ["masc", "fem", "neut", "gender_x"],
                "number": ["sing", "plur"],
                "person": ["1", "2", "3"],
                "gender_or_person": ["masc", "fem", "neut", "gender_x", "1", "2", "3"],
                "case": ["nom", "acc", "dat", "gen"],
                "deg": ["pos", "cmp", "superl"],
                "voice": ["act", "mid"],
                "tense": ["pres", "past"],
                "def": ["definite"],
                "proper": ["proper"],
                "adj_c": ["strong", "weak", "equiinflected"],
            },
            # Separator token first, then every category, then every attribute value.
            "labels": ["<SEP>", *categories, *attribute_labels],
            "null": None,
            "null_leaf": None,
            "separator": "<SEP>",
            "ignore_categories": ["x", "e"],
        }

    @classmethod
    def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
        """Build a config whose label schema is read from the UTF-8 JSON file at ``schema_path``.

        Any extra keyword arguments are passed through to the constructor.
        """
        with open(schema_path, "r", encoding="utf-8") as handle:
            schema = json.load(handle)
        return cls(label_schema=schema, **kwargs)
|
|
|
|
|
|
|
|
# Register the config class with AutoConfig under its own model_type ("icebert-pos"),
# so the registration key and the class attribute cannot drift apart.
AutoConfig.register(IceBertPosConfig.model_type, IceBertPosConfig)
|
|
|