Fix inconsistencies with the old model - it now produces equivalent results

Files changed:
- config.json +0 -2
- configuration.py +67 -18
- modeling.py +178 -237
- old_label_utils.py +223 -0
config.json
CHANGED

@@ -469,8 +469,6 @@
       "act",
       "mid"
     ],
-    "null": null,
-    "null_leaf": null,
     "separator": "<SEP>"
   },
   "layer_norm_eps": 1e-05,
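Why the two removed keys mattered: the label_schema block of config.json is passed straight into LabelSchema(**label_schema) in configuration.py below, and a dataclass __init__ rejects keys it does not declare, so stale "null"/"null_leaf" entries would break loading. A minimal sketch with toy field names (not the real schema):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class MiniSchema:
        # Toy stand-in for LabelSchema with only two of its fields.
        labels: List[str]
        separator: str

    clean = {"labels": ["n", "g"], "separator": "<SEP>"}
    stale = dict(clean, null=None)  # extra key, like the entries removed from config.json

    MiniSchema(**clean)  # fine
    try:
        MiniSchema(**stale)
    except TypeError as err:
        print(err)  # unexpected keyword argument 'null'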
configuration.py
CHANGED

@@ -2,11 +2,36 @@
 # This file is part of IceBERT POS model conversion.

 import json
-from …
+from dataclasses import dataclass
+from typing import Dict, List, Optional

 from transformers import AutoConfig, RobertaConfig


+@dataclass
+class LabelSchema:
+    """
+    Dataclass representing the structure of a POS tagging label schema.
+
+    The schema defines a hierarchical structure where:
+    - Categories (e.g., 'n', 'v', 'l') are the main POS types
+    - Groups (e.g., 'gender', 'number', 'case') are grammatical attribute types
+    - Labels are the specific values for each group (e.g., 'masc', 'fem', 'sing', 'plur')
+
+    Each category maps to applicable groups, and each group maps to its possible labels.
+    This enables multilabel classification where tokens get both a category and
+    relevant grammatical attributes.
+    """
+
+    label_categories: List[str]
+    category_to_group_names: Dict[str, List[str]]
+    group_names: List[str]
+    group_name_to_labels: Dict[str, List[str]]
+    labels: List[str]
+    separator: str
+    ignore_categories: List[str]
+
+
 class IceBertPosConfig(RobertaConfig):
     """
     Configuration class for IceBERT POS (Part-of-Speech) tagging model.

@@ -18,7 +43,7 @@ class IceBertPosConfig(RobertaConfig):
     model_type = "icebert-pos"

     def __init__(
-        self, label_schema: Optional[…
+        self, label_schema: Optional[LabelSchema] = None, classifier_dropout: Optional[float] = None, **kwargs
     ):
         super().__init__(**kwargs)

@@ -26,12 +51,16 @@ class IceBertPosConfig(RobertaConfig):
         if label_schema is None:
             label_schema = self._get_default_label_schema()

+        # Convert dict to LabelSchema if needed (when loaded from JSON)
+        if isinstance(label_schema, dict):
+            label_schema = LabelSchema(**label_schema)
+
         self.label_schema = label_schema

         # Derive parameters from label schema
-        self.num_categories = len(label_schema[…
-        self.num_labels = len(label_schema[…
-        self.num_groups = len(label_schema[…
+        self.num_categories = len(label_schema.label_categories)
+        self.num_labels = len(label_schema.labels)
+        self.num_groups = len(label_schema.group_names)

         # Classification head parameters
         self.classifier_dropout = classifier_dropout if classifier_dropout is not None else 0.1

@@ -41,10 +70,10 @@ class IceBertPosConfig(RobertaConfig):
         self.attr_proj_input_size = self.num_categories + self.hidden_size

     @staticmethod
-    def _get_default_label_schema() -> …
+    def _get_default_label_schema() -> LabelSchema:
         """Default label schema corresponding to terms2.json"""
-        return …
-        …
+        return LabelSchema(
+            label_categories=[
                 "n",
                 "g",
                 "x",

@@ -89,7 +118,7 @@ class IceBertPosConfig(RobertaConfig):
                 "ns",
                 "m",
             ],
-            …
+            category_to_group_names={
                 "n": ["gender", "number", "case", "def", "proper"],
                 "g": ["gender", "number", "case"],
                 "l": ["gender", "number", "case", "adj_c", "deg"],

@@ -116,7 +145,7 @@ class IceBertPosConfig(RobertaConfig):
                 "ae": ["deg"],
                 "as": ["deg"],
             },
-            …
+            group_names=[
                 "gender",
                 "gender_or_person",
                 "number",

@@ -129,7 +158,7 @@ class IceBertPosConfig(RobertaConfig):
                 "person",
                 "tense",
             ],
-            …
+            group_name_to_labels={
                 "gender": ["masc", "fem", "neut", "gender_x"],
                 "number": ["sing", "plur"],
                 "person": ["1", "2", "3"],

@@ -142,7 +171,7 @@ class IceBertPosConfig(RobertaConfig):
                 "proper": ["proper"],
                 "adj_c": ["strong", "weak", "equiinflected"],
             },
-            …
+            labels=[
                 "<SEP>",
                 "n",
                 "g",

@@ -214,17 +243,37 @@ class IceBertPosConfig(RobertaConfig):
                 "act",
                 "mid",
             ],
-            "…
-            "…
-            …
+            separator="<SEP>",
+            ignore_categories=["x", "e"],
+        )
+
+    def to_dict(self):
+        """Convert config to dictionary, handling LabelSchema serialization."""
+        output = super().to_dict()
+
+        # Convert LabelSchema to dict for JSON serialization
+        if hasattr(self, 'label_schema') and self.label_schema is not None:
+            if isinstance(self.label_schema, LabelSchema):
+                output['label_schema'] = {
+                    'label_categories': self.label_schema.label_categories,
+                    'category_to_group_names': self.label_schema.category_to_group_names,
+                    'group_names': self.label_schema.group_names,
+                    'group_name_to_labels': self.label_schema.group_name_to_labels,
+                    'labels': self.label_schema.labels,
+                    'separator': self.label_schema.separator,
+                    'ignore_categories': self.label_schema.ignore_categories,
+                }
+            else:
+                output['label_schema'] = self.label_schema
+
+        return output

     @classmethod
     def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
         """Create config from a label schema JSON file"""
         with open(schema_path, "r", encoding="utf-8") as f:
-            …
+            schema_dict = json.load(f)
+            label_schema = LabelSchema(**schema_dict)
         return cls(label_schema=label_schema, **kwargs)
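A rough usage sketch for the new config round-trip. The package name icebert_pos is a placeholder for however these modules are actually imported; everything else follows the code above.

    # Hypothetical usage; assumes the files are importable as a package named icebert_pos.
    import json

    from icebert_pos.configuration import IceBertPosConfig, LabelSchema

    config = IceBertPosConfig()  # no schema given, so the terms2.json default is used
    assert isinstance(config.label_schema, LabelSchema)
    print(config.num_categories, config.num_groups)

    # to_dict() flattens the LabelSchema to a plain dict for JSON, and __init__
    # converts it back to a LabelSchema when the config is re-created.
    blob = json.dumps(config.to_dict())
    reloaded = IceBertPosConfig(**json.loads(blob))
    assert isinstance(reloaded.label_schema, LabelSchema)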
modeling.py
CHANGED

@@ -11,6 +11,15 @@ from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel

 from .configuration import IceBertPosConfig
+from .old_label_utils import (
+    SimpleLabelDictionary,
+    clean_cats_attrs,
+    create_label_dictionary_from_schema,
+    make_dict_idx_to_vec_idx,
+    make_group_masks,
+    make_group_name_to_group_attr_vec_idxs,
+    make_vec_idx_to_dict_idx,
+)

 logger = logging.getLogger(__name__)

@@ -38,11 +47,11 @@ class MultiLabelTokenClassificationHead(nn.Module):
     def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
-            features: Word-level features of shape (…
+            features: Word-level features of shape (batch_size, max_words, hidden_size)

         Returns:
-            cat_logits: Category logits of shape (…
-            attr_logits: Attribute logits of shape (…
+            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
+            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
         """
         x = self.dropout(features)
         x = self.dense(x)

@@ -81,9 +90,22 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.classifier = MultiLabelTokenClassificationHead(config)

+        # Create label dictionary and mappings (mimicking old fairseq model)
+        self.label_dictionary = create_label_dictionary_from_schema(config.label_schema)
+        self._setup_label_mappings()
+
         # Initialize weights and apply final processing
         self.post_init()

+    def _setup_label_mappings(self):
+        """Setup label mappings similar to the old fairseq model."""
+        schema = self.config.label_schema
+
+        self.group_name_to_group_attr_vec_idxs = make_group_name_to_group_attr_vec_idxs(self.label_dictionary, schema)
+        self.cat_dict_idx_to_vec_idx = make_dict_idx_to_vec_idx(self.label_dictionary, schema.label_categories)
+        self.cat_vec_idx_to_dict_idx = make_vec_idx_to_dict_idx(self.label_dictionary, schema.label_categories)
+        self.group_mask = make_group_masks(self.label_dictionary, schema)
+
     def forward(
         self,
         input_ids: torch.Tensor,

@@ -101,7 +123,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Args:
             input_ids: Token indices of shape (batch_size, sequence_length)
             attention_mask: Attention mask of shape (batch_size, sequence_length)
-            word_mask: Binary mask indicating word boundaries (1 = word start)
+            word_mask: Binary mask indicating word boundaries (1 = word start) of shape (batch_size, sequence_length)

         Returns:
             cat_logits: Category logits of shape (batch_size, max_words, num_categories)

@@ -118,22 +140,37 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
-            output_hidden_states=…
+            output_hidden_states=True,
             return_dict=return_dict,
         )

-        …
-        # Aggregate subword tokens to word-level representations using word_mask
-        word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask)
-
-        # Apply classification head
-        cat_logits, attr_logits = self.classifier(word_features)
-
-        # …
+        x = outputs[0]  # (batch_size, seq_len, hidden)

+        # Copy exact logic from old model
+        _, _, inner_dim = x.shape

+        # use first bpe token of word as representation
+        x = x[:, 1:-1, :]
+        starts = word_mask[:, 1:-1]  # remove bos, eos
+        ends = starts.roll(-1, dims=[-1]).nonzero()[:, -1] + 1
+        starts = starts.nonzero().tolist()
+        mean_words = []
+        for (seq_idx, token_idx), end in zip(starts, ends):
+            mean_words.append(x[seq_idx, token_idx:end, :].mean(dim=0))
+        mean_words = torch.stack(mean_words)
+        words = mean_words
+        # Innermost dimension is mask for tokens at head of word.
+        nwords = word_mask.sum(dim=-1)
+        (cat_logits, attr_logits) = self.classifier(words)
+
+        # (Batch * Time) x Depth -> Batch x Time x Depth
+        cat_logits = pad_sequence(cat_logits.split((nwords).tolist()), padding_value=0, batch_first=True)
+        attr_logits = pad_sequence(
+            attr_logits.split((nwords).tolist()),
+            padding_value=0,
+            batch_first=True,
+        )
+        return cat_logits, attr_logits

     def _aggregate_subword_tokens(
         self, sequence_output: torch.Tensor, word_mask: torch.Tensor

@@ -147,7 +184,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)

         Returns:
-            word_features: Word-level features (…
+            word_features: Word-level features (batch_size, max_words, hidden_size)
             nwords: Number of words per sequence (batch_size,)
         """
         # TODO: Verify that BOS and EOS are handled correctly - I'm worried that this does not correctly handle padding

@@ -234,7 +271,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):

         cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)

-        return self._logits_to_labels(cat_logits, attr_logits, …
+        return self._logits_to_labels(cat_logits, attr_logits, word_mask)

     def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
         """

@@ -245,18 +282,23 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             input_shape: Shape of input_ids tensor (batch_size, seq_len)

         Returns:
-            word_mask: Binary tensor where 1 indicates start of word
+            word_mask: Binary tensor where 1 indicates start of word (batch_size, seq_len)
         """
         batch_size, seq_len = input_shape
         word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)

         for batch_idx, seq_word_ids in enumerate(word_ids):
+            # Truncate to exclude BOS and EOS tokens (first and last)
+            truncated_word_ids = seq_word_ids[1:-1]
             prev_word_id = None
-            for token_idx, word_id in enumerate(…
+            for token_idx, word_id in enumerate(truncated_word_ids):
                 if word_id != prev_word_id:
-                    word_mask[batch_idx, token_idx] = 1
+                    word_mask[batch_idx, token_idx + 1] = 1  # +1 to account for BOS
                 prev_word_id = word_id

+            # Debug logging to match fairseq model
+            logger.debug(f"Word mask: {word_mask[batch_idx].tolist()}")
+
         return word_mask

     def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:

@@ -270,231 +312,130 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Returns:
             List of sequences, each containing (category, [attributes]) per word
         """
-        # …
-        …
-        batch_input_ids = torch.stack(batch_input_ids)
-        batch_attention_mask = torch.stack(batch_attention_mask)
+        # Split sentences by spaces to get proper word boundaries
+        # This fixes the issue where tokens like "Kl." get split incorrectly
+        sentences_split = [sentence.split() for sentence in sentences]
+
+        # Use batch_encode_plus with is_split_into_words=True to preserve word boundaries
+        encoding = tokenizer.batch_encode_plus(
+            sentences_split,
+            return_tensors="pt",
+            padding=True,
+            is_split_into_words=True,
+            add_special_tokens=True
+        )
+
+        batch_input_ids = encoding["input_ids"]
+        batch_attention_mask = encoding["attention_mask"]
+        word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]
+
+        # Debug logging to match fairseq model
+        for i in range(len(sentences)):
+            logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
+            logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
+            logger.debug(f"Word IDs: {word_ids_list[i]}")

         return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)

-    def _make_group_name_to_group_attr_vec_idxs(self):
-        """Create mapping from group names to their attribute vector indices"""
-        group_name_to_group_attr_vec_idxs = {}
-        labels = self.config.label_schema["labels"]
-        nspecial = 0  # Number of special tokens in label dictionary (like <SEP>)
-
-        for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items():
-            vec_idxs = []
-            for label in group_labels:
-                if label in labels:
-                    # Find index in labels list, but subtract nspecial to get vector index
-                    label_dict_idx = labels.index(label)
-                    if label_dict_idx >= nspecial:  # Skip special tokens
-                        vec_idxs.append(label_dict_idx - nspecial)
-            group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs)
-
-        return group_name_to_group_attr_vec_idxs
-
-    def _make_group_masks(self):
-        """Create group masks for each category"""
-        label_categories = self.config.label_schema["label_categories"]
-        group_names = self.config.label_schema["group_names"]
-        category_to_group_names = self.config.label_schema["category_to_group_names"]
-
-        num_cats = len(label_categories)
-        num_groups = len(group_names)
-
-        group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool)
-
-        for cat_idx, category in enumerate(label_categories):
-            if category in category_to_group_names:
-                for group_name in category_to_group_names[category]:
-                    if group_name in group_names:
-                        group_idx = group_names.index(group_name)
-                        group_mask[cat_idx, group_idx] = True
-
-        return group_mask
-
-    def _make_category_mappings(self):
-        """Create mappings between category vector indices and dictionary indices"""
-        labels = self.config.label_schema["labels"]
-        label_categories = self.config.label_schema["label_categories"]
-
-        # Create mapping from category names to vector indices (0-based)
-        cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long)
-        cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long)
-
-        for vec_idx, category in enumerate(label_categories):
-            if category in labels:
-                dict_idx = labels.index(category)
-                cat_dict_idx_to_vec_idx[dict_idx] = vec_idx
-                cat_vec_idx_to_dict_idx[vec_idx] = dict_idx
-
-        return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx
-
-    def _count_words_per_sequence(self, word_ids: List[List[int]]) -> List[int]:
-        """Count the number of unique words in each sequence."""
-        words_per_seq = []
-        for seq_word_ids in word_ids:
-            unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None)
-            words_per_seq.append(len(unique_word_ids))
-        return words_per_seq
-
-    def _predict_categories_for_sequence(
-        self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Predict categories for a single sequence and return both vector and dictionary indices."""
-        pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
-        pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
-        return pred_cat_vec_idxs, pred_cats
-
-    def _predict_attributes_for_group(
-        self,
-        attr_logits: torch.Tensor,
-        seq_idx: int,
-        seq_nwords: int,
-        group_vec_idxs: torch.Tensor,
-        seq_group_mask: torch.Tensor,
-        group_idx: int,
-    ) -> torch.Tensor:
-        """Predict attributes for a single group."""
-        if len(group_vec_idxs) == 0:
-            return torch.zeros(seq_nwords, dtype=torch.long)
-
-        # Get logits for this group
-        group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
-
-        if len(group_vec_idxs) == 1:
-            # Single element group: use sigmoid > 0.5
-            group_pred = group_logits.sigmoid().ge(0.5).long()
-            group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx]
-        else:
-            # Multi element group: use argmax
-            group_pred_vec_idxs = group_logits.max(dim=-1).indices
-            group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx]
-
-        return group_pred_dict_idxs
-
-    def _predict_all_attributes_for_sequence(
-        self,
-        attr_logits: torch.Tensor,
-        seq_idx: int,
-        seq_nwords: int,
-        pred_cat_vec_idxs: torch.Tensor,
-        group_name_to_group_attr_vec_idxs: dict,
-        group_mask: torch.Tensor,
-        group_names: List[str],
-    ) -> torch.Tensor:
-        """Predict all attributes for a single sequence."""
-        seq_group_mask = group_mask[pred_cat_vec_idxs]
-        pred_attrs = []
-
-        for group_idx, group_name in enumerate(group_names):
-            if group_name not in group_name_to_group_attr_vec_idxs:
-                pred_attrs.append(torch.zeros(seq_nwords, dtype=torch.long))
-                continue
-
-            group_vec_idxs = group_name_to_group_attr_vec_idxs[group_name]
-            group_pred_dict_idxs = self._predict_attributes_for_group(
-                attr_logits, seq_idx, seq_nwords, group_vec_idxs, seq_group_mask, group_idx
-            )
-            pred_attrs.append(group_pred_dict_idxs)
-
-        # Stack predictions
-        if pred_attrs:
-            return torch.stack([p.squeeze() if p.dim() > 1 else p for p in pred_attrs]).t()
-        else:
-            return torch.zeros(seq_nwords, len(group_names), dtype=torch.long)
-
-    def _convert_predictions_to_labels(
-        self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str]
-    ) -> List[Tuple[str, List[str]]]:
-        """Convert prediction tensors to human-readable labels."""
-        seq_nwords = pred_cats.size(0)
-        seq_predictions = []
-
-        for word_idx in range(seq_nwords):
-            # Category (convert from dictionary index to string)
-            cat_dict_idx = pred_cats[word_idx].item()
-            if cat_dict_idx < len(labels):
-                category = labels[cat_dict_idx]
-            else:
-                category = "UNK"
-
-            # Attributes (convert from dictionary indices to strings)
-            attributes = []
-            for group_idx in range(len(group_names)):
-                attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item()
-                if attr_dict_idx > 0 and attr_dict_idx < len(labels):  # Skip 0 (empty) and out of bounds
-                    attributes.append(labels[attr_dict_idx])
-
-            seq_predictions.append((category, attributes))
-
-        return seq_predictions
-
     def _logits_to_labels(
-        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, …
+        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
     ) -> List[List[Tuple[str, List[str]]]]:
         """
         Convert logits to human-readable labels using fairseq's group-based logic.
+        Copied from the old model's logits_to_labels method.
         """
-        # …
-        …
+        # logits: Batch x Time x Labels
+        bsz, _, num_cats = cat_logits.shape
+        _, _, num_attrs = attr_logits.shape
+        nwords = word_mask.sum(-1)
+
+        assert num_attrs == len(self.config.label_schema.labels)
+        assert num_cats == len(self.config.label_schema.label_categories)
+
+        batch_cats = []
+        batch_attrs = []
+        for seq_idx in range(bsz):
+            seq_nwords = nwords[seq_idx]
+            pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
+            pred_cats = self.cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
+
+            group_mask = self.group_mask[pred_cat_vec_idxs]
+            offset = self.label_dictionary.nspecial
+            pred_attrs = []
+            for group_idx, group_name in enumerate(self.config.label_schema.group_names):
+                group_vec_idxs = self.group_name_to_group_attr_vec_idxs[group_name]
+                # logits: (bsz * nwords) x labels
+                group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
+                if len(group_vec_idxs) == 1:
+                    group_pred = group_logits.sigmoid().ge(0.5).long()
+                    group_pred_dict_idxs = (group_pred.squeeze() * (group_vec_idxs.item() + offset)).T.to(
+                        "cpu"
+                    ) * group_mask[:, group_idx]
+                else:
+                    group_pred_vec_idxs = group_logits.max(dim=-1).indices
+                    group_pred_dict_idxs = (group_vec_idxs[group_pred_vec_idxs] + offset) * group_mask[:, group_idx]
+                pred_attrs.append(group_pred_dict_idxs)
+
+            pred_attrs = torch.stack([p.squeeze() for p in pred_attrs]).t()
+
+            batch_cats.append(pred_cats)
+            batch_attrs.append(pred_attrs)
+
+        predictions = list(
+            [
+                clean_cats_attrs(
+                    self.label_dictionary,
+                    self.config.label_schema,
+                    seq_cats,
+                    seq_attrs,
+                )
+                for seq_cats, seq_attrs in zip(batch_cats, batch_attrs)
+            ]
+        )
+
+        return predictions
+
+
+def make_vec_idx_to_dict_idx(dictionary, labels, device="cpu", fill_value=-100):
+    vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
+    for vec_idx, label in enumerate(labels):
+        vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
+    return vec_idx_to_dict_idx
+
+
+def make_group_masks(dictionary, schema, device="cpu"):
+    num_groups = len(schema.group_names)
+    offset = dictionary.nspecial
+    num_labels = len(dictionary) - offset
+    ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
+    for cat, cat_group_names in schema.category_to_group_names.items():
+        cat_label_idx = dictionary.index(cat)
+        cat_vec_idx = schema.label_categories.index(cat)
+        for group_name in cat_group_names:
+            ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
+        assert cat_label_idx != dictionary.unk()
+    for cat in schema.label_categories:
+        cat_label_idx = dictionary.index(cat)
+        assert cat_label_idx != dictionary.unk()
+    return ret_mask
+
+
+def make_group_name_to_group_attr_vec_idxs(dict_, schema):
+    offset = dict_.nspecial
+    group_names = schema.group_name_to_labels.keys()
+    name_to_labels = schema.group_name_to_labels
+    group_name_to_group_attr_vec_idxs = {
+        name: torch.tensor([dict_.index(item) - offset for item in name_to_labels[name]]) for name in group_names
+    }
+    return group_name_to_group_attr_vec_idxs
+
+
+def make_dict_idx_to_vec_idx(dictionary, cats, device="cpu", fill_value=-100):
+    # NOTE: when target is not in label_categories, the error is silent
+    map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
+    for vec_idx, label in enumerate(cats):
+        map_tgt[dictionary.index(label)] = vec_idx
+    return map_tgt


 AutoConfig.register("icebert-pos", IceBertPosConfig)
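A rough end-to-end sketch of calling the converted model. The checkpoint directory and the icebert_pos package name are placeholders; the sentence is only an example of whitespace-pretokenized input, as expected by predict_labels_from_text above.

    # Hypothetical usage; paths and package name are placeholders.
    from transformers import AutoTokenizer

    from icebert_pos.modeling import IceBertPosForTokenClassification

    model_dir = "path/to/converted-icebert-pos"  # placeholder, not a real repo id
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = IceBertPosForTokenClassification.from_pretrained(model_dir)
    model.eval()

    sentence = "Kl. 18 kemur hún heim ."
    # predict_labels_from_text whitespace-splits each sentence, so one
    # (category, [attributes]) pair comes back per whitespace-separated token.
    predictions = model.predict_labels_from_text([sentence], tokenizer)
    for word, (cat, attrs) in zip(sentence.split(), predictions[0]):
        print(word, cat, attrs)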
old_label_utils.py
ADDED

@@ -0,0 +1,223 @@
# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

"""
Utility functions copied from the old fairseq-based model for label handling.
These functions handle the conversion between vector indices and dictionary indices,
accounting for the offset caused by special tokens in the label dictionary.
"""

from typing import Dict, List, Tuple

import torch


class SimpleLabelDictionary:
    """
    Simplified version of fairseq Dictionary to handle label mappings.
    This replaces the fairseq Dictionary dependency while maintaining the same interface.
    """

    def __init__(self, labels: List[str], nspecial: int = 5):
        """
        Args:
            labels: List of labels including special tokens at the beginning
            nspecial: Number of special tokens (typically 5: <pad>, <s>, </s>, <unk>, <SEP>)
        """
        self.symbols = labels
        self.nspecial = nspecial
        self._indices = {label: idx for idx, label in enumerate(labels)}

    def index(self, label: str) -> int:
        """Get index of label in dictionary."""
        return self._indices.get(label, self.unk())

    def unk(self) -> int:
        """Return index of unknown token (typically 3)."""
        return 3

    def string(self, indices: torch.Tensor) -> str:
        """Convert tensor of indices to space-separated string of labels."""
        if indices.dim() == 0:
            indices = indices.unsqueeze(0)

        # Filter out special tokens like fairseq Dictionary does
        special_indices_to_ignore = {0, 1, 2, 3}  # BOS, PAD, EOS, UNK

        labels = [
            self.symbols[idx] for idx in indices.tolist()
            if 0 <= idx < len(self.symbols) and idx not in special_indices_to_ignore
        ]
        return " ".join(labels)

    def __len__(self) -> int:
        return len(self.symbols)


def make_vec_idx_to_dict_idx(dictionary: SimpleLabelDictionary, labels: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Create mapping from vector indices to dictionary indices.

    Args:
        dictionary: Label dictionary
        labels: List of labels
        device: Device for tensor
        fill_value: Fill value for missing entries

    Returns:
        Tensor mapping vector indices to dictionary indices
    """
    vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    for vec_idx, label in enumerate(labels):
        vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
    return vec_idx_to_dict_idx


def make_group_masks(dictionary: SimpleLabelDictionary, schema, device="cpu") -> torch.Tensor:
    """
    Create group masks indicating which groups are valid for each category.

    Args:
        dictionary: Label dictionary
        schema: Label schema object
        device: Device for tensor

    Returns:
        Tensor of shape (num_categories, num_groups) with 1 for valid combinations
    """
    num_groups = len(schema.group_names)
    offset = dictionary.nspecial
    num_labels = len(dictionary) - offset
    ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)

    for cat, cat_group_names in schema.category_to_group_names.items():
        cat_label_idx = dictionary.index(cat)
        cat_vec_idx = schema.label_categories.index(cat)
        for group_name in cat_group_names:
            ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
        assert cat_label_idx != dictionary.unk()

    return ret_mask


def make_group_name_to_group_attr_vec_idxs(dictionary: SimpleLabelDictionary, schema) -> Dict[str, torch.Tensor]:
    """
    Create mapping from group names to their attribute vector indices.

    Args:
        dictionary: Label dictionary
        schema: Label schema object

    Returns:
        Dictionary mapping group names to tensor of vector indices
    """
    offset = dictionary.nspecial
    group_names = schema.group_name_to_labels.keys()
    name_to_labels = schema.group_name_to_labels
    group_name_to_group_attr_vec_idxs = {
        name: torch.tensor([dictionary.index(item) - offset for item in name_to_labels[name]])
        for name in group_names
    }
    return group_name_to_group_attr_vec_idxs


def make_dict_idx_to_vec_idx(dictionary: SimpleLabelDictionary, cats: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Create mapping from dictionary indices to vector indices.

    Args:
        dictionary: Label dictionary
        cats: List of categories
        device: Device for tensor
        fill_value: Fill value for missing entries

    Returns:
        Tensor mapping dictionary indices to vector indices
    """
    # NOTE: when target is not in label_categories, the error is silent
    map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    for vec_idx, label in enumerate(cats):
        map_tgt[dictionary.index(label)] = vec_idx
    return map_tgt


def clean_cats_attrs(ldict: SimpleLabelDictionary, schema, pred_cats: torch.Tensor, pred_attrs: torch.Tensor) -> List[Tuple[str, List[str]]]:
    """
    Convert predicted category and attribute indices to human-readable labels.

    Args:
        ldict: Label dictionary
        schema: Label schema object
        pred_cats: Predicted category indices
        pred_attrs: Predicted attribute indices

    Returns:
        List of (category, [attributes]) tuples
    """
    cats = ldict.string(pred_cats).split(" ")
    attrs = []

    if len(pred_attrs.shape) == 1:
        split_pred_attrs = [pred_attrs]
    else:
        split_pred_attrs = pred_attrs.split(1, dim=0)

    for (_cat_idx, attr_idxs) in zip(pred_cats.tolist(), split_pred_attrs):
        seq_attrs = [lbl for lbl in ldict.string((attr_idxs.squeeze())).split(" ")]
        if not any(it for it in seq_attrs):
            seq_attrs = []
        attrs.append(seq_attrs)

    return list(zip(cats, attrs))


def create_label_dictionary_from_schema(schema) -> SimpleLabelDictionary:
    """
    Create a SimpleLabelDictionary from a label schema, mimicking the old fairseq setup.
    Load the exact symbols from the original fairseq dictionary to ensure perfect compatibility.

    Args:
        schema: Label schema object (unused, kept for compatibility)

    Returns:
        SimpleLabelDictionary with exact same symbols as original fairseq dict
    """
    try:
        # Load original fairseq dictionary to get exact symbol order and content
        from fairseq.data import Dictionary
        import os

        # Try to find the original dict_term.txt file
        possible_paths = [
            'scripts/dict_term.txt',
            'icebert-pos/scripts/dict_term.txt',
            '../scripts/dict_term.txt'
        ]

        original_dict = None
        for path in possible_paths:
            if os.path.exists(path):
                original_dict = Dictionary.load(path)
                break

        if original_dict is not None:
            # Use exact symbols from original dictionary
            return SimpleLabelDictionary(original_dict.symbols, nspecial=original_dict.nspecial)

    except ImportError:
        # Fallback if fairseq is not available
        pass
    except Exception:
        # Fallback if file loading fails
        pass

    # Fallback: reconstruct from schema (original logic)
    # Use the correct special token order from original dictionary
    special_symbols = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>"]

    # The schema labels start with <SEP>, so we need to skip it
    schema_labels_without_sep = [label for label in schema.labels if label != "<SEP>"]

    # Combine: special tokens + schema labels (without duplicate <SEP>)
    all_symbols = special_symbols + schema_labels_without_sep

    return SimpleLabelDictionary(all_symbols, nspecial=4)  # 4 special tokens before <SEP>
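A small illustration of the nspecial offset these helpers compensate for. The symbol list below is made up, not the contents of the real dict_term.txt, and the sketch assumes old_label_utils.py is importable directly (inside the package it is a relative import).

    # Toy illustration of the dictionary-index vs. classifier-vector-index offset.
    from old_label_utils import SimpleLabelDictionary, make_group_name_to_group_attr_vec_idxs


    class ToySchema:
        # Stand-in for LabelSchema with only the field these helpers need.
        group_name_to_labels = {"gender": ["masc", "fem", "neut"]}


    symbols = ["<s>", "<pad>", "</s>", "<unk>", "masc", "fem", "neut"]
    ldict = SimpleLabelDictionary(symbols, nspecial=4)

    idxs = make_group_name_to_group_attr_vec_idxs(ldict, ToySchema())
    print(idxs["gender"])      # tensor([0, 1, 2]): dictionary index minus the nspecial offset
    print(ldict.index("fem"))  # 5 in the dictionary, i.e. position 1 in the classifier's attribute vector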