# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

import json
from typing import Any, Dict, Optional

from transformers import AutoConfig, RobertaConfig


class IceBertPosConfig(RobertaConfig):
    """
    Configuration for the IceBERT POS (Part-of-Speech) tagging model.

    Extends RobertaConfig with a label schema (used for multilabel token
    classification) and a handful of sizes derived from that schema.

    Attributes set by ``__init__`` in addition to the parent's:
        label_schema: the schema dict in use (supplied or the built-in default).
        num_categories: ``len(label_schema["label_categories"])``.
        num_labels: ``len(label_schema["labels"])``.
        num_groups: ``len(label_schema["group_names"])``.
        classifier_dropout: the supplied value, or 0.1 when ``None`` is given.
        attr_proj_input_size: ``num_categories + hidden_size``.
    """

    model_type = "icebert-pos"

    def __init__(
        self,
        label_schema: Optional[Dict[str, Any]] = None,
        classifier_dropout: Optional[float] = None,
        **kwargs
    ):
        """
        Build the configuration.

        Args:
            label_schema: schema dict; must provide the keys
                "label_categories", "labels" and "group_names". When ``None``,
                the built-in default schema is used.
            classifier_dropout: dropout for the classification head; ``None``
                selects 0.1.
            **kwargs: forwarded unchanged to the parent constructor.

        Raises:
            KeyError: if the schema lacks one of the keys listed above.
        """
        super().__init__(**kwargs)

        # Fall back to the built-in schema (terms2.json content) when none is given.
        schema = self._get_default_label_schema() if label_schema is None else label_schema
        self.label_schema = schema

        # Sizes derived directly from the schema.
        self.num_categories = len(schema["label_categories"])
        self.num_labels = len(schema["labels"])
        self.num_groups = len(schema["group_names"])

        # Classification head parameters.
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.classifier_dropout = classifier_dropout

        # Input width of the attribute projection:
        # (category_probs + hidden_size) -> num_labels.
        # hidden_size is expected to have been set by the parent constructor.
        self.attr_proj_input_size = self.num_categories + self.hidden_size

    @staticmethod
    def _get_default_label_schema() -> Dict[str, Any]:
        """Return a freshly built default label schema (terms2.json content)."""
        categories = [
            "n", "g", "x", "e", "v", "l",
            "fa", "fb", "fe", "fo", "fp", "fs", "ft",
            "tf", "ta", "tp", "to",
            "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
            "cn", "ct", "c",
            "aa", "af", "au", "ao", "aþ", "ae", "as",
            "ks", "kt",
            "p", "pl", "pk", "pg", "pa",
            "ns", "m",
        ]

        # Which attribute groups apply to each category; categories that are
        # absent from this mapping have no entry in the original data either.
        category_groups = {
            "n": ["gender", "number", "case", "def", "proper"],
            "g": ["gender", "number", "case"],
            "l": ["gender", "number", "case", "adj_c", "deg"],
            "fa": ["gender", "number", "case"],
            "fb": ["gender", "number", "case"],
            "fe": ["gender", "number", "case"],
            "fs": ["gender", "number", "case"],
            "ft": ["gender", "number", "case"],
            "fo": ["gender_or_person", "number", "case"],
            "fp": ["gender_or_person", "number", "case"],
            "tf": ["gender", "number", "case"],
            "sn": ["voice"],
            "sb": ["voice", "person", "number", "tense"],
            "sf": ["voice", "person", "number", "tense"],
            "sv": ["voice", "person", "number", "tense"],
            "ss": ["voice"],
            "sl": ["voice", "person", "number", "tense"],
            "sþ": ["voice", "gender", "number", "case"],
            "aa": ["deg"],
            "af": ["deg"],
            "au": ["deg"],
            "ao": ["deg"],
            "aþ": ["deg"],
            "ae": ["deg"],
            "as": ["deg"],
        }

        groups = [
            "gender",
            "gender_or_person",
            "number",
            "case",
            "def",
            "proper",
            "adj_c",
            "deg",
            "voice",
            "person",
            "tense",
        ]

        group_labels = {
            "gender": ["masc", "fem", "neut", "gender_x"],
            "number": ["sing", "plur"],
            "person": ["1", "2", "3"],
            "gender_or_person": ["masc", "fem", "neut", "gender_x", "1", "2", "3"],
            "case": ["nom", "acc", "dat", "gen"],
            "deg": ["pos", "cmp", "superl"],
            "voice": ["act", "mid"],
            "tense": ["pres", "past"],
            "def": ["definite"],
            "proper": ["proper"],
            "adj_c": ["strong", "weak", "equiinflected"],
        }

        # Flat label vocabulary, kept as an explicit literal so it can be
        # compared one-to-one with the schema file.
        labels = [
            "",
            "n", "g", "x", "e", "v", "l",
            "fa", "fb", "fe", "fo", "fp", "fs", "ft",
            "tf", "ta", "tp", "to",
            "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
            "cn", "ct", "c",
            "aa", "af", "au", "ao", "aþ", "ae", "as",
            "ks", "kt",
            "p", "pl", "pk", "pg", "pa",
            "ns", "m",
            "masc", "fem", "neut", "gender_x",
            "1", "2", "3",
            "sing", "plur",
            "nom", "acc", "dat", "gen",
            "definite", "proper",
            "strong", "weak", "equiinflected",
            "pos", "cmp", "superl",
            "past", "pres",
            "pass", "act", "mid",
        ]

        return {
            "label_categories": categories,
            "category_to_group_names": category_groups,
            "group_names": groups,
            "group_name_to_labels": group_labels,
            "labels": labels,
            "null": None,
            "null_leaf": None,
            "separator": "",
            "ignore_categories": ["x", "e"],
        }

    @classmethod
    def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
        """
        Create a config from a label schema stored as a JSON file.

        Args:
            schema_path: path of the JSON file; it is read as UTF-8.
            **kwargs: forwarded to the class constructor.

        Returns:
            A new instance of ``cls`` built with the parsed schema.
        """
        with open(schema_path, "r", encoding="utf-8") as schema_file:
            parsed_schema = json.load(schema_file)
        return cls(label_schema=parsed_schema, **kwargs)


# Make the config discoverable through AutoConfig under its model type.
AutoConfig.register("icebert-pos", IceBertPosConfig)