haukurpj committed
Commit aaca62a · 1 Parent(s): d7ae5c8

Fix inconsistencies with the old model - now works equally

Files changed (4):
  1. config.json +0 -2
  2. configuration.py +67 -18
  3. modeling.py +178 -237
  4. old_label_utils.py +223 -0
config.json CHANGED
@@ -469,8 +469,6 @@
       "act",
       "mid"
     ],
-    "null": null,
-    "null_leaf": null,
     "separator": "<SEP>"
   },
   "layer_norm_eps": 1e-05,
configuration.py CHANGED
@@ -2,11 +2,36 @@
 # This file is part of IceBERT POS model conversion.
 
 import json
-from typing import Any, Dict, Optional
+from dataclasses import dataclass
+from typing import Dict, List, Optional
 
 from transformers import AutoConfig, RobertaConfig
 
 
+@dataclass
+class LabelSchema:
+    """
+    Dataclass representing the structure of a POS tagging label schema.
+
+    The schema defines a hierarchical structure where:
+    - Categories (e.g., 'n', 'v', 'l') are the main POS types
+    - Groups (e.g., 'gender', 'number', 'case') are grammatical attribute types
+    - Labels are the specific values for each group (e.g., 'masc', 'fem', 'sing', 'plur')
+
+    Each category maps to applicable groups, and each group maps to its possible labels.
+    This enables multilabel classification where tokens get both a category and
+    relevant grammatical attributes.
+    """
+
+    label_categories: List[str]
+    category_to_group_names: Dict[str, List[str]]
+    group_names: List[str]
+    group_name_to_labels: Dict[str, List[str]]
+    labels: List[str]
+    separator: str
+    ignore_categories: List[str]
+
+
 class IceBertPosConfig(RobertaConfig):
     """
     Configuration class for IceBERT POS (Part-of-Speech) tagging model.
@@ -18,7 +43,7 @@ class IceBertPosConfig(RobertaConfig):
     model_type = "icebert-pos"
 
     def __init__(
-        self, label_schema: Optional[Dict[str, Any]] = None, classifier_dropout: Optional[float] = None, **kwargs
+        self, label_schema: Optional[LabelSchema] = None, classifier_dropout: Optional[float] = None, **kwargs
     ):
         super().__init__(**kwargs)
 
@@ -26,12 +51,16 @@
         if label_schema is None:
             label_schema = self._get_default_label_schema()
 
+        # Convert dict to LabelSchema if needed (when loaded from JSON)
+        if isinstance(label_schema, dict):
+            label_schema = LabelSchema(**label_schema)
+
         self.label_schema = label_schema
 
         # Derive parameters from label schema
-        self.num_categories = len(label_schema["label_categories"])
-        self.num_labels = len(label_schema["labels"])
-        self.num_groups = len(label_schema["group_names"])
+        self.num_categories = len(label_schema.label_categories)
+        self.num_labels = len(label_schema.labels)
+        self.num_groups = len(label_schema.group_names)
 
         # Classification head parameters
         self.classifier_dropout = classifier_dropout if classifier_dropout is not None else 0.1
@@ -41,10 +70,10 @@
         self.attr_proj_input_size = self.num_categories + self.hidden_size
 
     @staticmethod
-    def _get_default_label_schema() -> Dict[str, Any]:
+    def _get_default_label_schema() -> LabelSchema:
         """Default label schema corresponding to terms2.json"""
-        return {
-            "label_categories": [
+        return LabelSchema(
+            label_categories=[
                 "n",
                 "g",
                 "x",
@@ -89,7 +118,7 @@ class IceBertPosConfig(RobertaConfig):
                 "ns",
                 "m",
             ],
-            "category_to_group_names": {
+            category_to_group_names={
                 "n": ["gender", "number", "case", "def", "proper"],
                 "g": ["gender", "number", "case"],
                 "l": ["gender", "number", "case", "adj_c", "deg"],
@@ -116,7 +145,7 @@
                 "ae": ["deg"],
                 "as": ["deg"],
             },
-            "group_names": [
+            group_names=[
                 "gender",
                 "gender_or_person",
                 "number",
@@ -129,7 +158,7 @@
                 "person",
                 "tense",
             ],
-            "group_name_to_labels": {
+            group_name_to_labels={
                 "gender": ["masc", "fem", "neut", "gender_x"],
                 "number": ["sing", "plur"],
                 "person": ["1", "2", "3"],
@@ -142,7 +171,7 @@
                 "proper": ["proper"],
                 "adj_c": ["strong", "weak", "equiinflected"],
             },
-            "labels": [
+            labels=[
                 "<SEP>",
                 "n",
                 "g",
@@ -214,17 +243,37 @@
                 "act",
                 "mid",
             ],
-            "null": None,
-            "null_leaf": None,
-            "separator": "<SEP>",
-            "ignore_categories": ["x", "e"],
-        }
+            separator="<SEP>",
+            ignore_categories=["x", "e"],
+        )
+
+    def to_dict(self):
+        """Convert config to dictionary, handling LabelSchema serialization."""
+        output = super().to_dict()
+
+        # Convert LabelSchema to dict for JSON serialization
+        if hasattr(self, 'label_schema') and self.label_schema is not None:
+            if isinstance(self.label_schema, LabelSchema):
+                output['label_schema'] = {
+                    'label_categories': self.label_schema.label_categories,
+                    'category_to_group_names': self.label_schema.category_to_group_names,
+                    'group_names': self.label_schema.group_names,
+                    'group_name_to_labels': self.label_schema.group_name_to_labels,
+                    'labels': self.label_schema.labels,
+                    'separator': self.label_schema.separator,
+                    'ignore_categories': self.label_schema.ignore_categories,
+                }
+            else:
+                output['label_schema'] = self.label_schema
+
+        return output
 
     @classmethod
     def from_label_schema_file(cls, schema_path: str, **kwargs) -> "IceBertPosConfig":
         """Create config from a label schema JSON file"""
        with open(schema_path, "r", encoding="utf-8") as f:
-            label_schema = json.load(f)
+            schema_dict = json.load(f)
+        label_schema = LabelSchema(**schema_dict)
         return cls(label_schema=label_schema, **kwargs)

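For reference, a minimal usage sketch of the new schema handling (not part of the commit; it assumes configuration.py is importable as a plain module): it exercises the default-schema fallback, the dataclass-to-dict conversion in to_dict(), and the dict-to-dataclass branch in __init__.

# Hedged sketch: round-trip the default LabelSchema through the config.
from configuration import IceBertPosConfig, LabelSchema

config = IceBertPosConfig()  # no schema given: falls back to _get_default_label_schema()
assert isinstance(config.label_schema, LabelSchema)
assert config.num_categories == len(config.label_schema.label_categories)

# to_dict() turns the dataclass into a plain dict so save_pretrained()
# can serialize it into config.json.
as_dict = config.to_dict()
assert isinstance(as_dict["label_schema"], dict)

# Passing the dict back in goes through the isinstance(label_schema, dict)
# branch of __init__ and reconstructs an equal LabelSchema.
config2 = IceBertPosConfig(label_schema=as_dict["label_schema"])
assert config2.label_schema == config.label_schema
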
modeling.py CHANGED
@@ -11,6 +11,15 @@ from torch.nn.utils.rnn import pad_sequence
 from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
 
 from .configuration import IceBertPosConfig
+from .old_label_utils import (
+    SimpleLabelDictionary,
+    clean_cats_attrs,
+    create_label_dictionary_from_schema,
+    make_dict_idx_to_vec_idx,
+    make_group_masks,
+    make_group_name_to_group_attr_vec_idxs,
+    make_vec_idx_to_dict_idx,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -38,11 +47,11 @@ class MultiLabelTokenClassificationHead(nn.Module):
     def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
-            features: Word-level features of shape (total_words, hidden_size)
+            features: Word-level features of shape (batch_size, max_words, hidden_size)
 
         Returns:
-            cat_logits: Category logits of shape (total_words, num_categories)
-            attr_logits: Attribute logits of shape (total_words, num_labels)
+            cat_logits: Category logits of shape (batch_size, max_words, num_categories)
+            attr_logits: Attribute logits of shape (batch_size, max_words, num_labels)
         """
         x = self.dropout(features)
         x = self.dense(x)
@@ -81,9 +90,22 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.classifier = MultiLabelTokenClassificationHead(config)
 
+        # Create label dictionary and mappings (mimicking old fairseq model)
+        self.label_dictionary = create_label_dictionary_from_schema(config.label_schema)
+        self._setup_label_mappings()
+
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _setup_label_mappings(self):
+        """Setup label mappings similar to the old fairseq model."""
+        schema = self.config.label_schema
+
+        self.group_name_to_group_attr_vec_idxs = make_group_name_to_group_attr_vec_idxs(self.label_dictionary, schema)
+        self.cat_dict_idx_to_vec_idx = make_dict_idx_to_vec_idx(self.label_dictionary, schema.label_categories)
+        self.cat_vec_idx_to_dict_idx = make_vec_idx_to_dict_idx(self.label_dictionary, schema.label_categories)
+        self.group_mask = make_group_masks(self.label_dictionary, schema)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -101,7 +123,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Args:
             input_ids: Token indices of shape (batch_size, sequence_length)
             attention_mask: Attention mask of shape (batch_size, sequence_length)
-            word_mask: Binary mask indicating word boundaries (1 = word start)
+            word_mask: Binary mask indicating word boundaries (1 = word start) of shape (batch_size, sequence_length)
 
         Returns:
             cat_logits: Category logits of shape (batch_size, max_words, num_categories)
@@ -118,22 +140,37 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=True,
             return_dict=return_dict,
         )
 
-        sequence_output = outputs[0]  # (batch_size, seq_len, hidden_size)
-
-        # Aggregate subword tokens to word-level representations using word_mask
-        word_features, nwords = self._aggregate_subword_tokens(sequence_output, word_mask)
-
-        # Apply classification head
-        cat_logits, attr_logits = self.classifier(word_features)
-
-        # Reshape back to batch format using word counts
-        cat_logits_batch, attr_logits_batch = self._reshape_to_batch_format(cat_logits, attr_logits, nwords)
-
-        return cat_logits_batch, attr_logits_batch
+        x = outputs[0]  # (batch_size, seq_len, hidden)
+
+        # Copy exact logic from old model
+        _, _, inner_dim = x.shape
+
+        # use first bpe token of word as representation
+        x = x[:, 1:-1, :]
+        starts = word_mask[:, 1:-1]  # remove bos, eos
+        ends = starts.roll(-1, dims=[-1]).nonzero()[:, -1] + 1
+        starts = starts.nonzero().tolist()
+        mean_words = []
+        for (seq_idx, token_idx), end in zip(starts, ends):
+            mean_words.append(x[seq_idx, token_idx:end, :].mean(dim=0))
+        mean_words = torch.stack(mean_words)
+        words = mean_words
+        # Innermost dimension is mask for tokens at head of word.
+        nwords = word_mask.sum(dim=-1)
+        (cat_logits, attr_logits) = self.classifier(words)
+
+        # (Batch * Time) x Depth -> Batch x Time x Depth
+        cat_logits = pad_sequence(cat_logits.split((nwords).tolist()), padding_value=0, batch_first=True)
+        attr_logits = pad_sequence(
+            attr_logits.split((nwords).tolist()),
+            padding_value=0,
+            batch_first=True,
+        )
+        return cat_logits, attr_logits
 
     def _aggregate_subword_tokens(
         self, sequence_output: torch.Tensor, word_mask: torch.Tensor
@@ -147,7 +184,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             word_mask: Binary mask where 1 indicates start of word (batch_size, seq_len)
 
         Returns:
-            word_features: Word-level features (total_words, hidden_size)
+            word_features: Word-level features (batch_size, max_words, hidden_size)
             nwords: Number of words per sequence (batch_size,)
         """
         # TODO: Verify that BOS and EOS are handled correctly - I'm worried that this does not correctly handle padding
@@ -234,7 +271,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
 
         cat_logits, attr_logits = self.forward(input_ids=input_ids, attention_mask=attention_mask, word_mask=word_mask)
 
-        return self._logits_to_labels(cat_logits, attr_logits, word_ids)
+        return self._logits_to_labels(cat_logits, attr_logits, word_mask)
 
     def _word_ids_to_word_mask(self, word_ids: List[List[int]], input_shape: torch.Size) -> torch.Tensor:
         """
@@ -245,18 +282,23 @@ class IceBertPosForTokenClassification(PreTrainedModel):
             input_shape: Shape of input_ids tensor (batch_size, seq_len)
 
         Returns:
-            word_mask: Binary tensor where 1 indicates start of word
+            word_mask: Binary tensor where 1 indicates start of word (batch_size, seq_len)
         """
         batch_size, seq_len = input_shape
         word_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
 
         for batch_idx, seq_word_ids in enumerate(word_ids):
+            # Truncate to exclude BOS and EOS tokens (first and last)
+            truncated_word_ids = seq_word_ids[1:-1]
             prev_word_id = None
-            for token_idx, word_id in enumerate(seq_word_ids):
+            for token_idx, word_id in enumerate(truncated_word_ids):
                 if word_id != prev_word_id:
-                    word_mask[batch_idx, token_idx] = 1
+                    word_mask[batch_idx, token_idx + 1] = 1  # +1 to account for BOS
                 prev_word_id = word_id
 
+            # Debug logging to match fairseq model
+            logger.debug(f"Word mask: {word_mask[batch_idx].tolist()}")
+
         return word_mask
 
     def predict_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[Tuple[str, List[str]]]]:
@@ -270,231 +312,130 @@ class IceBertPosForTokenClassification(PreTrainedModel):
         Returns:
             List of sequences, each containing (category, [attributes]) per word
         """
-        # Tokenize with fairseq-style preprocessing
-        encodings = [tokenizer(sent, return_tensors="pt") for sent in sentences]
-        word_ids_list = [encoding.word_ids() for encoding in encodings]
-
-        # Batch the inputs
-        max_len = max(encoding["input_ids"].shape[1] for encoding in encodings)
-        batch_input_ids = []
-        batch_attention_mask = []
-
-        for encoding in encodings:
-            input_ids = encoding["input_ids"][0]
-            attention_mask = encoding["attention_mask"][0]
-
-            # Pad to max length
-            pad_len = max_len - len(input_ids)
-            if pad_len > 0:
-                input_ids = torch.cat([input_ids, torch.ones(pad_len, dtype=torch.long)])  # pad_token_id = 1
-                attention_mask = torch.cat([attention_mask, torch.zeros(pad_len, dtype=torch.long)])
-
-            batch_input_ids.append(input_ids)
-            batch_attention_mask.append(attention_mask)
-
-        batch_input_ids = torch.stack(batch_input_ids)
-        batch_attention_mask = torch.stack(batch_attention_mask)
+        # Split sentences by spaces to get proper word boundaries
+        # This fixes the issue where tokens like "Kl." get split incorrectly
+        sentences_split = [sentence.split() for sentence in sentences]
+
+        # Use batch_encode_plus with is_split_into_words=True to preserve word boundaries
+        encoding = tokenizer.batch_encode_plus(
+            sentences_split,
+            return_tensors="pt",
+            padding=True,
+            is_split_into_words=True,
+            add_special_tokens=True
+        )
+
+        batch_input_ids = encoding["input_ids"]
+        batch_attention_mask = encoding["attention_mask"]
+        word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]
+
+        # Debug logging to match fairseq model
+        for i in range(len(sentences)):
+            logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
+            logger.debug(f"Decoded tokens: {tokenizer.convert_ids_to_tokens(batch_input_ids[i].tolist())}")
+            logger.debug(f"Word IDs: {word_ids_list[i]}")
 
         return self.predict_labels(batch_input_ids, batch_attention_mask, word_ids_list)
 
-    def _make_group_name_to_group_attr_vec_idxs(self):
-        """Create mapping from group names to their attribute vector indices"""
-        group_name_to_group_attr_vec_idxs = {}
-        labels = self.config.label_schema["labels"]
-        nspecial = 0  # Number of special tokens in label dictionary (like <SEP>)
-
-        for group_name, group_labels in self.config.label_schema["group_name_to_labels"].items():
-            vec_idxs = []
-            for label in group_labels:
-                if label in labels:
-                    # Find index in labels list, but subtract nspecial to get vector index
-                    label_dict_idx = labels.index(label)
-                    if label_dict_idx >= nspecial:  # Skip special tokens
-                        vec_idxs.append(label_dict_idx - nspecial)
-            group_name_to_group_attr_vec_idxs[group_name] = torch.tensor(vec_idxs)
-
-        return group_name_to_group_attr_vec_idxs
-
-    def _make_group_masks(self):
-        """Create group masks for each category"""
-        label_categories = self.config.label_schema["label_categories"]
-        group_names = self.config.label_schema["group_names"]
-        category_to_group_names = self.config.label_schema["category_to_group_names"]
-
-        num_cats = len(label_categories)
-        num_groups = len(group_names)
-
-        group_mask = torch.zeros(num_cats, num_groups, dtype=torch.bool)
-
-        for cat_idx, category in enumerate(label_categories):
-            if category in category_to_group_names:
-                for group_name in category_to_group_names[category]:
-                    if group_name in group_names:
-                        group_idx = group_names.index(group_name)
-                        group_mask[cat_idx, group_idx] = True
-
-        return group_mask
-
-    def _make_category_mappings(self):
-        """Create mappings between category vector indices and dictionary indices"""
-        labels = self.config.label_schema["labels"]
-        label_categories = self.config.label_schema["label_categories"]
-
-        # Create mapping from category names to vector indices (0-based)
-        cat_dict_idx_to_vec_idx = torch.zeros(len(labels), dtype=torch.long)
-        cat_vec_idx_to_dict_idx = torch.zeros(len(label_categories), dtype=torch.long)
-
-        for vec_idx, category in enumerate(label_categories):
-            if category in labels:
-                dict_idx = labels.index(category)
-                cat_dict_idx_to_vec_idx[dict_idx] = vec_idx
-                cat_vec_idx_to_dict_idx[vec_idx] = dict_idx
-
-        return cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx
-
-    def _count_words_per_sequence(self, word_ids: List[List[int]]) -> List[int]:
-        """Count the number of unique words in each sequence."""
-        words_per_seq = []
-        for seq_word_ids in word_ids:
-            unique_word_ids = set(word_id for word_id in seq_word_ids if word_id is not None)
-            words_per_seq.append(len(unique_word_ids))
-        return words_per_seq
-
-    def _predict_categories_for_sequence(
-        self, cat_logits: torch.Tensor, seq_idx: int, seq_nwords: int, cat_vec_idx_to_dict_idx: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Predict categories for a single sequence and return both vector and dictionary indices."""
-        pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
-        pred_cats = cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
-        return pred_cat_vec_idxs, pred_cats
-
-    def _predict_attributes_for_group(
-        self,
-        attr_logits: torch.Tensor,
-        seq_idx: int,
-        seq_nwords: int,
-        group_vec_idxs: torch.Tensor,
-        seq_group_mask: torch.Tensor,
-        group_idx: int,
-    ) -> torch.Tensor:
-        """Predict attributes for a single group."""
-        if len(group_vec_idxs) == 0:
-            return torch.zeros(seq_nwords, dtype=torch.long)
-
-        # Get logits for this group
-        group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
-
-        if len(group_vec_idxs) == 1:
-            # Single element group: use sigmoid > 0.5
-            group_pred = group_logits.sigmoid().ge(0.5).long()
-            group_pred_dict_idxs = (group_pred.squeeze() * group_vec_idxs.item()) * seq_group_mask[:, group_idx]
-        else:
-            # Multi element group: use argmax
-            group_pred_vec_idxs = group_logits.max(dim=-1).indices
-            group_pred_dict_idxs = group_vec_idxs[group_pred_vec_idxs] * seq_group_mask[:, group_idx]
-
-        return group_pred_dict_idxs
-
-    def _predict_all_attributes_for_sequence(
-        self,
-        attr_logits: torch.Tensor,
-        seq_idx: int,
-        seq_nwords: int,
-        pred_cat_vec_idxs: torch.Tensor,
-        group_name_to_group_attr_vec_idxs: dict,
-        group_mask: torch.Tensor,
-        group_names: List[str],
-    ) -> torch.Tensor:
-        """Predict all attributes for a single sequence."""
-        seq_group_mask = group_mask[pred_cat_vec_idxs]
-        pred_attrs = []
-
-        for group_idx, group_name in enumerate(group_names):
-            if group_name not in group_name_to_group_attr_vec_idxs:
-                pred_attrs.append(torch.zeros(seq_nwords, dtype=torch.long))
-                continue
-
-            group_vec_idxs = group_name_to_group_attr_vec_idxs[group_name]
-            group_pred_dict_idxs = self._predict_attributes_for_group(
-                attr_logits, seq_idx, seq_nwords, group_vec_idxs, seq_group_mask, group_idx
-            )
-            pred_attrs.append(group_pred_dict_idxs)
-
-        # Stack predictions
-        if pred_attrs:
-            return torch.stack([p.squeeze() if p.dim() > 1 else p for p in pred_attrs]).t()
-        else:
-            return torch.zeros(seq_nwords, len(group_names), dtype=torch.long)
-
-    def _convert_predictions_to_labels(
-        self, pred_cats: torch.Tensor, pred_attrs_tensor: torch.Tensor, labels: List[str], group_names: List[str]
-    ) -> List[Tuple[str, List[str]]]:
-        """Convert prediction tensors to human-readable labels."""
-        seq_nwords = pred_cats.size(0)
-        seq_predictions = []
-
-        for word_idx in range(seq_nwords):
-            # Category (convert from dictionary index to string)
-            cat_dict_idx = pred_cats[word_idx].item()
-            if cat_dict_idx < len(labels):
-                category = labels[cat_dict_idx]
-            else:
-                category = "UNK"
-
-            # Attributes (convert from dictionary indices to strings)
-            attributes = []
-            for group_idx in range(len(group_names)):
-                attr_dict_idx = pred_attrs_tensor[word_idx, group_idx].item()
-                if attr_dict_idx > 0 and attr_dict_idx < len(labels):  # Skip 0 (empty) and out of bounds
-                    attributes.append(labels[attr_dict_idx])
-
-            seq_predictions.append((category, attributes))
-
-        return seq_predictions
-
     def _logits_to_labels(
-        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_ids: List[List[int]]
+        self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
     ) -> List[List[Tuple[str, List[str]]]]:
         """
         Convert logits to human-readable labels using fairseq's group-based logic.
+        Copied from the old model's logits_to_labels method.
         """
-        # Create necessary mappings
-        group_name_to_group_attr_vec_idxs = self._make_group_name_to_group_attr_vec_idxs()
-        group_mask = self._make_group_masks()
-        cat_dict_idx_to_vec_idx, cat_vec_idx_to_dict_idx = self._make_category_mappings()
-
-        label_schema = self.config.label_schema
-        labels = label_schema["labels"]
-        group_names = label_schema["group_names"]
-
-        batch_size = cat_logits.size(0)
-        words_per_seq = self._count_words_per_sequence(word_ids)
-        batch_predictions = []
-
-        for seq_idx in range(batch_size):
-            seq_nwords = words_per_seq[seq_idx]
-
-            # Predict categories
-            pred_cat_vec_idxs, pred_cats = self._predict_categories_for_sequence(
-                cat_logits, seq_idx, seq_nwords, cat_vec_idx_to_dict_idx
-            )
-
-            # Predict attributes
-            pred_attrs_tensor = self._predict_all_attributes_for_sequence(
-                attr_logits,
-                seq_idx,
-                seq_nwords,
-                pred_cat_vec_idxs,
-                group_name_to_group_attr_vec_idxs,
-                group_mask,
-                group_names,
-            )
-
-            # Convert to labels
-            seq_predictions = self._convert_predictions_to_labels(pred_cats, pred_attrs_tensor, labels, group_names)
-            batch_predictions.append(seq_predictions)
-
-        return batch_predictions
+        # logits: Batch x Time x Labels
+        bsz, _, num_cats = cat_logits.shape
+        _, _, num_attrs = attr_logits.shape
+        nwords = word_mask.sum(-1)
+
+        assert num_attrs == len(self.config.label_schema.labels)
+        assert num_cats == len(self.config.label_schema.label_categories)
+
+        batch_cats = []
+        batch_attrs = []
+        for seq_idx in range(bsz):
+            seq_nwords = nwords[seq_idx]
+            pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
+            pred_cats = self.cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
+
+            group_mask = self.group_mask[pred_cat_vec_idxs]
+            offset = self.label_dictionary.nspecial
+            pred_attrs = []
+            for group_idx, group_name in enumerate(self.config.label_schema.group_names):
+                group_vec_idxs = self.group_name_to_group_attr_vec_idxs[group_name]
+                # logits: (bsz * nwords) x labels
+                group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
+                if len(group_vec_idxs) == 1:
+                    group_pred = group_logits.sigmoid().ge(0.5).long()
+                    group_pred_dict_idxs = (group_pred.squeeze() * (group_vec_idxs.item() + offset)).T.to(
+                        "cpu"
+                    ) * group_mask[:, group_idx]
+                else:
+                    group_pred_vec_idxs = group_logits.max(dim=-1).indices
+                    group_pred_dict_idxs = (group_vec_idxs[group_pred_vec_idxs] + offset) * group_mask[:, group_idx]
+                pred_attrs.append(group_pred_dict_idxs)
+
+            pred_attrs = torch.stack([p.squeeze() for p in pred_attrs]).t()
+
+            batch_cats.append(pred_cats)
+            batch_attrs.append(pred_attrs)
+
+        predictions = list(
+            [
+                clean_cats_attrs(
+                    self.label_dictionary,
+                    self.config.label_schema,
+                    seq_cats,
+                    seq_attrs,
+                )
+                for seq_cats, seq_attrs in zip(batch_cats, batch_attrs)
+            ]
+        )
+
+        return predictions
+
+
+def make_vec_idx_to_dict_idx(dictionary, labels, device="cpu", fill_value=-100):
+    vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
+    for vec_idx, label in enumerate(labels):
+        vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
+    return vec_idx_to_dict_idx
+
+
+def make_group_masks(dictionary, schema, device="cpu"):
+    num_groups = len(schema.group_names)
+    offset = dictionary.nspecial
+    num_labels = len(dictionary) - offset
+    ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
+    for cat, cat_group_names in schema.category_to_group_names.items():
+        cat_label_idx = dictionary.index(cat)
+        cat_vec_idx = schema.label_categories.index(cat)
+        for group_name in cat_group_names:
+            ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
+        assert cat_label_idx != dictionary.unk()
+    for cat in schema.label_categories:
+        cat_label_idx = dictionary.index(cat)
+        assert cat_label_idx != dictionary.unk()
+    return ret_mask
+
+
+def make_group_name_to_group_attr_vec_idxs(dict_, schema):
+    offset = dict_.nspecial
+    group_names = schema.group_name_to_labels.keys()
+    name_to_labels = schema.group_name_to_labels
+    group_name_to_group_attr_vec_idxs = {
+        name: torch.tensor([dict_.index(item) - offset for item in name_to_labels[name]]) for name in group_names
+    }
+    return group_name_to_group_attr_vec_idxs
+
+
+def make_dict_idx_to_vec_idx(dictionary, cats, device="cpu", fill_value=-100):
+    # NOTE: when target is not in label_categories, the error is silent
+    map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
+    for vec_idx, label in enumerate(cats):
+        map_tgt[dictionary.index(label)] = vec_idx
+    return map_tgt
 
 
 AutoConfig.register("icebert-pos", IceBertPosConfig)
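
For context, an end-to-end sketch of how the rewritten entry point is meant to be called (not part of the commit; the checkpoint path is a placeholder, and loading via AutoModel with trust_remote_code assumes the repository wires up the usual auto_map). The point is that predict_labels_from_text now whitespace-splits the input itself and encodes with is_split_into_words=True, so punctuation must already be separated by spaces.

# Hypothetical usage sketch; "path/to/icebert-pos" is a placeholder checkpoint.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/icebert-pos")
model = AutoModel.from_pretrained("path/to/icebert-pos", trust_remote_code=True)
model.eval()

# Words come from sentence.split(), so "." is a separate word here.
sentences = ["Hún las bókina ."]
predictions = model.predict_labels_from_text(sentences, tokenizer)

# One (category, [attributes]) pair per whitespace-separated word.
for word, (category, attributes) in zip(sentences[0].split(), predictions[0]):
    print(word, category, attributes)
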
old_label_utils.py ADDED
@@ -0,0 +1,223 @@
+# Copyright (C) Miðeind ehf.
+# This file is part of IceBERT POS model conversion.
+
+"""
+Utility functions copied from the old fairseq-based model for label handling.
+These functions handle the conversion between vector indices and dictionary indices,
+accounting for the offset caused by special tokens in the label dictionary.
+"""
+
+from typing import Dict, List, Tuple
+import torch
+
+
+class SimpleLabelDictionary:
+    """
+    Simplified version of fairseq Dictionary to handle label mappings.
+    This replaces the fairseq Dictionary dependency while maintaining the same interface.
+    """
+
+    def __init__(self, labels: List[str], nspecial: int = 5):
+        """
+        Args:
+            labels: List of labels including special tokens at the beginning
+            nspecial: Number of special tokens (typically 5: <pad>, <s>, </s>, <unk>, <SEP>)
+        """
+        self.symbols = labels
+        self.nspecial = nspecial
+        self._indices = {label: idx for idx, label in enumerate(labels)}
+
+    def index(self, label: str) -> int:
+        """Get index of label in dictionary."""
+        return self._indices.get(label, self.unk())
+
+    def unk(self) -> int:
+        """Return index of unknown token (typically 3)."""
+        return 3
+
+    def string(self, indices: torch.Tensor) -> str:
+        """Convert tensor of indices to space-separated string of labels."""
+        if indices.dim() == 0:
+            indices = indices.unsqueeze(0)
+
+        # Filter out special tokens like fairseq Dictionary does
+        special_indices_to_ignore = {0, 1, 2, 3}  # BOS, PAD, EOS, UNK
+
+        labels = [
+            self.symbols[idx] for idx in indices.tolist()
+            if 0 <= idx < len(self.symbols) and idx not in special_indices_to_ignore
+        ]
+        return " ".join(labels)
+
+    def __len__(self) -> int:
+        return len(self.symbols)
+
+
+def make_vec_idx_to_dict_idx(dictionary: SimpleLabelDictionary, labels: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
+    """
+    Create mapping from vector indices to dictionary indices.
+
+    Args:
+        dictionary: Label dictionary
+        labels: List of labels
+        device: Device for tensor
+        fill_value: Fill value for missing entries
+
+    Returns:
+        Tensor mapping vector indices to dictionary indices
+    """
+    vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
+    for vec_idx, label in enumerate(labels):
+        vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
+    return vec_idx_to_dict_idx
+
+
+def make_group_masks(dictionary: SimpleLabelDictionary, schema, device="cpu") -> torch.Tensor:
+    """
+    Create group masks indicating which groups are valid for each category.
+
+    Args:
+        dictionary: Label dictionary
+        schema: Label schema object
+        device: Device for tensor
+
+    Returns:
+        Tensor of shape (num_categories, num_groups) with 1 for valid combinations
+    """
+    num_groups = len(schema.group_names)
+    offset = dictionary.nspecial
+    num_labels = len(dictionary) - offset
+    ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
+
+    for cat, cat_group_names in schema.category_to_group_names.items():
+        cat_label_idx = dictionary.index(cat)
+        cat_vec_idx = schema.label_categories.index(cat)
+        for group_name in cat_group_names:
+            ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
+        assert cat_label_idx != dictionary.unk()
+
+    return ret_mask
+
+
+def make_group_name_to_group_attr_vec_idxs(dictionary: SimpleLabelDictionary, schema) -> Dict[str, torch.Tensor]:
+    """
+    Create mapping from group names to their attribute vector indices.
+
+    Args:
+        dictionary: Label dictionary
+        schema: Label schema object
+
+    Returns:
+        Dictionary mapping group names to tensor of vector indices
+    """
+    offset = dictionary.nspecial
+    group_names = schema.group_name_to_labels.keys()
+    name_to_labels = schema.group_name_to_labels
+    group_name_to_group_attr_vec_idxs = {
+        name: torch.tensor([dictionary.index(item) - offset for item in name_to_labels[name]])
+        for name in group_names
+    }
+    return group_name_to_group_attr_vec_idxs
+
+
+def make_dict_idx_to_vec_idx(dictionary: SimpleLabelDictionary, cats: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
+    """
+    Create mapping from dictionary indices to vector indices.
+
+    Args:
+        dictionary: Label dictionary
+        cats: List of categories
+        device: Device for tensor
+        fill_value: Fill value for missing entries
+
+    Returns:
+        Tensor mapping dictionary indices to vector indices
+    """
+    # NOTE: when target is not in label_categories, the error is silent
+    map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
+    for vec_idx, label in enumerate(cats):
+        map_tgt[dictionary.index(label)] = vec_idx
+    return map_tgt
+
+
+def clean_cats_attrs(ldict: SimpleLabelDictionary, schema, pred_cats: torch.Tensor, pred_attrs: torch.Tensor) -> List[Tuple[str, List[str]]]:
+    """
+    Convert predicted category and attribute indices to human-readable labels.
+
+    Args:
+        ldict: Label dictionary
+        schema: Label schema object
+        pred_cats: Predicted category indices
+        pred_attrs: Predicted attribute indices
+
+    Returns:
+        List of (category, [attributes]) tuples
+    """
+    cats = ldict.string(pred_cats).split(" ")
+    attrs = []
+
+    if len(pred_attrs.shape) == 1:
+        split_pred_attrs = [pred_attrs]
+    else:
+        split_pred_attrs = pred_attrs.split(1, dim=0)
+
+    for (_cat_idx, attr_idxs) in zip(pred_cats.tolist(), split_pred_attrs):
+        seq_attrs = [lbl for lbl in ldict.string((attr_idxs.squeeze())).split(" ")]
+        if not any(it for it in seq_attrs):
+            seq_attrs = []
+        attrs.append(seq_attrs)
+
+    return list(zip(cats, attrs))
+
+
+def create_label_dictionary_from_schema(schema) -> SimpleLabelDictionary:
+    """
+    Create a SimpleLabelDictionary from a label schema, mimicking the old fairseq setup.
+    Load the exact symbols from the original fairseq dictionary to ensure perfect compatibility.
+
+    Args:
+        schema: Label schema object (unused, kept for compatibility)
+
+    Returns:
+        SimpleLabelDictionary with exact same symbols as original fairseq dict
+    """
+    try:
+        # Load original fairseq dictionary to get exact symbol order and content
+        from fairseq.data import Dictionary
+        import os
+
+        # Try to find the original dict_term.txt file
+        possible_paths = [
+            'scripts/dict_term.txt',
+            'icebert-pos/scripts/dict_term.txt',
+            '../scripts/dict_term.txt'
+        ]
+
+        original_dict = None
+        for path in possible_paths:
+            if os.path.exists(path):
+                original_dict = Dictionary.load(path)
+                break
+
+        if original_dict is not None:
+            # Use exact symbols from original dictionary
+            return SimpleLabelDictionary(original_dict.symbols, nspecial=original_dict.nspecial)
+
+    except ImportError:
+        # Fallback if fairseq is not available
+        pass
+    except Exception:
+        # Fallback if file loading fails
+        pass
+
+    # Fallback: reconstruct from schema (original logic)
+    # Use the correct special token order from original dictionary
+    special_symbols = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>"]
+
+    # The schema labels start with <SEP>, so we need to skip it
+    schema_labels_without_sep = [label for label in schema.labels if label != "<SEP>"]
+
+    # Combine: special tokens + schema labels (without duplicate <SEP>)
+    all_symbols = special_symbols + schema_labels_without_sep
+
+    return SimpleLabelDictionary(all_symbols, nspecial=4)  # 4 special tokens before <SEP>
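
Finally, a toy sketch of the nspecial offset bookkeeping these helpers implement (not part of the commit; old_label_utils is assumed importable, and the symbol list is made up to mirror the fallback order above):

import torch
from types import SimpleNamespace

from old_label_utils import SimpleLabelDictionary, make_group_name_to_group_attr_vec_idxs

# Toy dictionary in the fallback order: 4 special tokens, then <SEP>, then labels.
symbols = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>", "n", "g", "masc", "fem"]
ldict = SimpleLabelDictionary(symbols, nspecial=4)

# Dictionary indices count the special tokens; attribute vector indices do not.
assert ldict.index("masc") == 7
schema = SimpleNamespace(group_name_to_labels={"gender": ["masc", "fem"]})
idxs = make_group_name_to_group_attr_vec_idxs(ldict, schema)
assert idxs["gender"].tolist() == [3, 4]  # dictionary index minus nspecial

# string() drops fairseq-style special indices (0-3) when decoding predictions.
assert ldict.string(torch.tensor([7, 8])) == "masc fem"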