IceBERT-PoS / configuration.py
haukurpj's picture
Upload folder using huggingface_hub
2d923bf verified
raw
history blame
6.94 kB
# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.
import json
from typing import Any, Dict, Optional
from transformers import AutoConfig, RobertaConfig
class IceBertPosConfig(RobertaConfig):
    """Configuration for the IceBERT part-of-speech (POS) tagging model.

    Extends ``RobertaConfig`` with a label schema and a few sizes computed
    from that schema, which the multilabel token-classification head reads.

    Args:
        label_schema: Mapping that must contain at least the keys
            ``"label_categories"``, ``"labels"`` and ``"group_names"``.
            When ``None``, the built-in schema returned by
            ``_get_default_label_schema`` is used.
        classifier_dropout: Dropout value stored for the classification
            head; ``None`` is replaced by ``0.1``.
        **kwargs: Forwarded unchanged to ``RobertaConfig.__init__``.

    Attributes set by the initializer:
        label_schema: The schema in effect (the caller's object, not a copy).
        num_categories: ``len(label_schema["label_categories"])``.
        num_labels: ``len(label_schema["labels"])``.
        num_groups: ``len(label_schema["group_names"])``.
        classifier_dropout: See ``classifier_dropout`` above.
        attr_proj_input_size: ``num_categories + hidden_size``.
    """

    model_type = "icebert-pos"

    def __init__(
        self,
        label_schema: Optional[Dict[str, Any]] = None,
        classifier_dropout: Optional[float] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # No schema supplied: fall back to the built-in one (terms2.json content).
        schema = self._get_default_label_schema() if label_schema is None else label_schema
        self.label_schema = schema

        # Sizes read straight off the schema.
        self.num_categories = len(schema["label_categories"])
        self.num_labels = len(schema["labels"])
        self.num_groups = len(schema["group_names"])

        # Classification-head dropout; an explicit 0.0 is kept, only None is defaulted.
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.classifier_dropout = classifier_dropout

        # Width of the attribute projection input: the category probabilities
        # concatenated with the hidden state (projected to num_labels).
        self.attr_proj_input_size = self.num_categories + self.hidden_size

    @staticmethod
    def _get_default_label_schema() -> Dict[str, Any]:
        """Build and return the default label schema (mirrors terms2.json).

        A new dictionary with new lists is created on every call, so callers
        may mutate the result without affecting later calls.
        """
        separator = "<SEP>"

        # Top-level word-class categories, in schema order.
        categories = [
            "n", "g", "x", "e", "v", "l",
            "fa", "fb", "fe", "fo", "fp", "fs", "ft",
            "tf", "ta", "tp", "to",
            "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
            "cn", "ct", "c",
            "aa", "af", "au", "ao", "aþ", "ae", "as",
            "ks", "kt",
            "p", "pl", "pk", "pg", "pa",
            "ns", "m",
        ]

        # Attribute labels that follow the categories in the flat label list.
        attribute_labels = [
            "masc", "fem", "neut", "gender_x",
            "1", "2", "3",
            "sing", "plur",
            "nom", "acc", "dat", "gen",
            "definite",
            "proper",
            "strong", "weak", "equiinflected",
            "pos", "cmp", "superl",
            "past", "pres",
            "pass", "act", "mid",
        ]

        return {
            "label_categories": categories,
            "category_to_group_names": {
                "n": ["gender", "number", "case", "def", "proper"],
                "g": ["gender", "number", "case"],
                "l": ["gender", "number", "case", "adj_c", "deg"],
                "fa": ["gender", "number", "case"],
                "fb": ["gender", "number", "case"],
                "fe": ["gender", "number", "case"],
                "fs": ["gender", "number", "case"],
                "ft": ["gender", "number", "case"],
                "fo": ["gender_or_person", "number", "case"],
                "fp": ["gender_or_person", "number", "case"],
                "tf": ["gender", "number", "case"],
                "sn": ["voice"],
                "sb": ["voice", "person", "number", "tense"],
                "sf": ["voice", "person", "number", "tense"],
                "sv": ["voice", "person", "number", "tense"],
                "ss": ["voice"],
                "sl": ["voice", "person", "number", "tense"],
                "sþ": ["voice", "gender", "number", "case"],
                "aa": ["deg"],
                "af": ["deg"],
                "au": ["deg"],
                "ao": ["deg"],
                "aþ": ["deg"],
                "ae": ["deg"],
                "as": ["deg"],
            },
            "group_names": [
                "gender", "gender_or_person", "number", "case",
                "def", "proper", "adj_c", "deg",
                "voice", "person", "tense",
            ],
            "group_name_to_labels": {
                "gender": ["masc", "fem", "neut", "gender_x"],
                "number": ["sing", "plur"],
                "person": ["1", "2", "3"],
                "gender_or_person": ["masc", "fem", "neut", "gender_x", "1", "2", "3"],
                "case": ["nom", "acc", "dat", "gen"],
                "deg": ["pos", "cmp", "superl"],
                "voice": ["act", "mid"],
                "tense": ["pres", "past"],
                "def": ["definite"],
                "proper": ["proper"],
                "adj_c": ["strong", "weak", "equiinflected"],
            },
            # Flat label list: separator first, then every category, then attributes.
            "labels": [separator, *categories, *attribute_labels],
            "null": None,
            "null_leaf": None,
            "separator": separator,
            "ignore_categories": ["x", "e"],
        }

    @classmethod
    def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
        """Build a config whose label schema is read from a UTF-8 JSON file.

        Args:
            schema_path: Path of the JSON file holding the label schema.
            **kwargs: Extra keyword arguments passed on to the constructor.

        Returns:
            A new instance of ``cls`` using the parsed schema.
        """
        with open(schema_path, encoding="utf-8") as schema_file:
            parsed_schema = json.load(schema_file)
        return cls(label_schema=parsed_schema, **kwargs)
# Import-time side effect: register IceBertPosConfig with AutoConfig under the
# "icebert-pos" key, the same string as the class's `model_type`.
AutoConfig.register("icebert-pos", IceBertPosConfig)