|
|
|
|
|
|
|
|
|
|
|
import copy
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional

import torch
from transformers import AutoConfig, RobertaConfig
|
|
|
|
|
|
|
|
@dataclass
class LabelSchema:
    """
    Dataclass representing the structure of a POS tagging label schema.

    The schema defines a hierarchical structure where:
    - Categories (e.g., 'n', 'v', 'l') are the main POS types
    - Groups (e.g., 'gender', 'number', 'case') are grammatical attribute types
    - Labels are the specific values for each group (e.g., 'masc', 'fem', 'sing', 'plur')

    Each category maps to applicable groups, and each group maps to its possible labels.
    This enables multilabel classification where tokens get both a category and
    relevant grammatical attributes.

    Attributes:
        label_categories: Ordered category names; a category's position is its index.
        category_to_group_names: Category name -> names of the groups that apply to it.
        group_names: Ordered group names; a group's position is its index.
        group_name_to_labels: Group name -> names of the labels belonging to that group.
        labels: Ordered flat label vocabulary; a label's position is its index.
        separator: Separator token string.
        ignore_categories: Category names flagged as ignorable (how they are used is
            decided outside this class).
    """

    label_categories: List[str]
    category_to_group_names: Dict[str, List[str]]
    group_names: List[str]
    group_name_to_labels: Dict[str, List[str]]
    labels: List[str]
    separator: str
    ignore_categories: List[str]

    @staticmethod
    def _first_index_map(items: List[str]) -> Dict[str, int]:
        """Map each item to the position of its first occurrence (same result as ``list.index``)."""
        index_map: Dict[str, int] = {}
        for idx, item in enumerate(items):
            index_map.setdefault(item, idx)
        return index_map

    def get_group_name_to_group_attr_indices(self, device="cpu") -> Dict[str, torch.Tensor]:
        """
        Create mapping from group names to their attribute indices in the labels list.

        Labels of a group that do not occur in ``labels`` are skipped. If a label occurs
        more than once in ``labels``, its first position is used.

        Args:
            device: Device on which the index tensors are allocated.

        Returns:
            Dictionary mapping group names to an int64 tensor of label indices. The
            tensor is empty (but still int64) when none of the group's labels are
            present in ``labels``.
        """
        label_to_index = self._first_index_map(self.labels)

        return {
            group_name: torch.tensor(
                [label_to_index[label] for label in group_labels if label in label_to_index],
                # Explicit dtype: torch.tensor([]) would otherwise be float32 and
                # unusable for indexing.
                dtype=torch.long,
                device=device,
            )
            for group_name, group_labels in self.group_name_to_labels.items()
        }

    def get_group_masks(self, device="cpu") -> torch.Tensor:
        """
        Create group masks indicating which groups are valid for each category.

        Categories in ``category_to_group_names`` that are not in ``label_categories``,
        and group names that are not in ``group_names``, are ignored.

        Args:
            device: Device on which the mask is allocated.

        Returns:
            int64 tensor of shape (num_categories, num_groups) with 1 for valid
            (category, group) combinations and 0 elsewhere.
        """
        category_to_index = self._first_index_map(self.label_categories)
        group_to_index = self._first_index_map(self.group_names)

        group_mask = torch.zeros(
            len(self.label_categories), len(self.group_names), dtype=torch.int64, device=device
        )

        for cat, cat_group_names in self.category_to_group_names.items():
            cat_idx = category_to_index.get(cat)
            if cat_idx is None:
                continue
            for group_name in cat_group_names:
                group_idx = group_to_index.get(group_name)
                if group_idx is not None:
                    group_mask[cat_idx, group_idx] = 1

        return group_mask

    def get_category_name_to_index(self) -> Dict[str, int]:
        """
        Create mapping from category names to their indices.

        Returns:
            Dictionary mapping category names to their position in ``label_categories``
            (the last position wins if a name is duplicated).
        """
        return {cat: idx for idx, cat in enumerate(self.label_categories)}

    def get_label_name_to_index(self) -> Dict[str, int]:
        """
        Create mapping from label names to their indices.

        Returns:
            Dictionary mapping label names to their position in ``labels``
            (the last position wins if a name is duplicated).
        """
        return {label: idx for idx, label in enumerate(self.labels)}
|
|
|
|
|
|
|
|
class IceBertPosConfig(RobertaConfig):
    """
    Configuration class for IceBERT POS (Part-of-Speech) tagging model.

    This configuration inherits from RobertaConfig and adds POS-specific parameters
    derived from the label schema used for multilabel token classification.

    Attributes set in ``__init__`` from the schema:
        num_categories: Number of entries in ``label_schema.label_categories``.
        num_labels: Number of entries in ``label_schema.labels``.
        num_groups: Number of entries in ``label_schema.group_names``.
        attr_proj_input_size: ``num_categories + hidden_size``.
    """

    model_type = "icebert-pos"

    def __init__(
        self, label_schema: Optional[LabelSchema] = None, classifier_dropout: Optional[float] = None, **kwargs
    ):
        """
        Args:
            label_schema: A ``LabelSchema``, a plain dict of its fields (e.g. the
                output of ``to_dict`` when a serialized config is read back), or
                None to use the built-in default schema.
            classifier_dropout: Classifier dropout probability; 0.1 when None.
            **kwargs: Forwarded unchanged to ``RobertaConfig``.
        """
        super().__init__(**kwargs)

        if label_schema is None:
            label_schema = self._get_default_label_schema()

        # A serialized config carries the schema as a plain dict; rebuild the dataclass.
        if isinstance(label_schema, dict):
            label_schema = LabelSchema(**label_schema)

        self.label_schema = label_schema

        # Sizes derived from the schema.
        self.num_categories = len(label_schema.label_categories)
        self.num_labels = len(label_schema.labels)
        self.num_groups = len(label_schema.group_names)

        self.classifier_dropout = classifier_dropout if classifier_dropout is not None else 0.1

        # NOTE(review): presumably the attribute projection consumes category scores
        # concatenated with the hidden state — confirm against the model code.
        self.attr_proj_input_size = self.num_categories + self.hidden_size

    @staticmethod
    def _get_default_label_schema() -> LabelSchema:
        """Default label schema corresponding to terms2.json"""
        categories = [
            "n", "g", "x", "e", "v", "l",
            "fa", "fb", "fe", "fo", "fp", "fs", "ft",
            "tf", "ta", "tp", "to",
            "sn", "sb", "sf", "sv", "ss", "sl", "sþ",
            "cn", "ct", "c",
            "aa", "af", "au", "ao", "aþ", "ae", "as",
            "ks", "kt",
            "p", "pl", "pk", "pg", "pa",
            "ns", "m",
        ]
        attribute_labels = [
            "masc", "fem", "neut", "gender_x",
            "1", "2", "3",
            "sing", "plur",
            "nom", "acc", "dat", "gen",
            "definite",
            "proper",
            "strong", "weak", "equiinflected",
            "pos", "cmp", "superl",
            "past", "pres", "pass",
            "act", "mid",
        ]
        return LabelSchema(
            label_categories=categories,
            category_to_group_names={
                "n": ["gender", "number", "case", "def", "proper"],
                "g": ["gender", "number", "case"],
                "l": ["gender", "number", "case", "adj_c", "deg"],
                "fa": ["gender", "number", "case"],
                "fb": ["gender", "number", "case"],
                "fe": ["gender", "number", "case"],
                "fs": ["gender", "number", "case"],
                "ft": ["gender", "number", "case"],
                "fo": ["gender_or_person", "number", "case"],
                "fp": ["gender_or_person", "number", "case"],
                "tf": ["gender", "number", "case"],
                "sn": ["voice"],
                "sb": ["voice", "person", "number", "tense"],
                "sf": ["voice", "person", "number", "tense"],
                "sv": ["voice", "person", "number", "tense"],
                "ss": ["voice"],
                "sl": ["voice", "person", "number", "tense"],
                "sþ": ["voice", "gender", "number", "case"],
                "aa": ["deg"],
                "af": ["deg"],
                "au": ["deg"],
                "ao": ["deg"],
                "aþ": ["deg"],
                "ae": ["deg"],
                "as": ["deg"],
            },
            group_names=[
                "gender",
                "gender_or_person",
                "number",
                "case",
                "def",
                "proper",
                "adj_c",
                "deg",
                "voice",
                "person",
                "tense",
            ],
            group_name_to_labels={
                "gender": ["masc", "fem", "neut", "gender_x"],
                "number": ["sing", "plur"],
                "person": ["1", "2", "3"],
                "gender_or_person": ["masc", "fem", "neut", "gender_x", "1", "2", "3"],
                "case": ["nom", "acc", "dat", "gen"],
                "deg": ["pos", "cmp", "superl"],
                "voice": ["act", "mid"],
                "tense": ["pres", "past"],
                "def": ["definite"],
                "proper": ["proper"],
                "adj_c": ["strong", "weak", "equiinflected"],
            },
            # Flat label vocabulary: separator, then every category (same order as
            # label_categories), then every attribute value.
            labels=["<SEP>", *categories, *attribute_labels],
            separator="<SEP>",
            ignore_categories=["x", "e"],
        )

    def to_dict(self):
        """
        Convert config to dictionary, handling LabelSchema serialization.

        The ``label_schema`` entry is an independent plain-dict copy, so mutating the
        returned dictionary never affects this config's schema.
        """
        output = super().to_dict()

        label_schema = getattr(self, "label_schema", None)
        if label_schema is not None:
            if isinstance(label_schema, LabelSchema):
                # asdict() copies every field recursively (including any field added
                # to LabelSchema later), so the output does not alias live lists/dicts.
                output["label_schema"] = asdict(label_schema)
            else:
                output["label_schema"] = copy.deepcopy(label_schema)

        return output

    @classmethod
    def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
        """
        Create config from a label schema JSON file.

        Args:
            schema_path: Path to a UTF-8 JSON file whose top-level object holds the
                ``LabelSchema`` fields as keys.
            **kwargs: Extra arguments forwarded to the constructor.

        Returns:
            A new config built from the loaded schema.
        """
        with open(schema_path, "r", encoding="utf-8") as f:
            schema_dict = json.load(f)
        return cls(label_schema=LabelSchema(**schema_dict), **kwargs)
|
|
|
|
|
|
|
|
# Register under the class's own model_type: AutoConfig.register raises if the key
# differs from config.model_type, so deriving it keeps the two from drifting apart.
AutoConfig.register(IceBertPosConfig.model_type, IceBertPosConfig)
|
|
|