haukurpj commited on
Commit
d50a6a0
·
1 Parent(s): 72d99d6

add support for IFD tags and some refactoring

Browse files
Files changed (5) hide show
  1. README.md +21 -5
  2. configuration.py +56 -0
  3. ifd_utils.py +136 -0
  4. modeling.py +86 -105
  5. old_label_utils.py +0 -223
README.md CHANGED
@@ -13,24 +13,40 @@ from transformers import AutoModel, AutoTokenizer
13
  model = AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
14
  tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")
15
 
16
- # Pre tokenized
 
 
 
 
 
 
 
17
  sentence = "Ég veit að þú kemur í kvöld til mín ."
18
 
 
19
  result = model.predict_labels_from_text([sentence], tokenizer)
20
  expected = [
21
  [
22
  ("fp", ["1", "sing", "nom"]),
23
- ("sf", ["sing", "act", "1", "pres"]),
24
  ("c", []),
25
  ("fp", ["2", "sing", "nom"]),
26
- ("sf", ["sing", "act", "2", "pres"]),
27
- ("af", ["pos"]),
28
  ("n", ["neut", "sing", "acc"]),
29
- ("af", ["pos"]),
30
  ("fp", ["1", "sing", "gen"]),
31
  ("pl", []),
32
  ]
33
  ]
34
  assert result == expected, f"Expected {expected}, but got {result}"
35
  print("Test passed successfully!")
 
 
 
 
 
 
 
 
36
  ```
 
13
  model = AutoModel.from_pretrained("mideind/IceBERT-PoS", trust_remote_code=True)
14
  tokenizer = AutoTokenizer.from_pretrained("mideind/IceBERT-PoS")
15
 
16
+ ## Prediction Methods
17
+
18
+ The model provides two prediction methods:
19
+
20
+ - **`predict_labels_from_text()`**: Returns structured predictions as (category, [attributes]) tuples. Use this for downstream NLP tasks or when you need the semantic meaning of each prediction.
21
+ - **`predict_ifd_labels_from_text()`**: Returns predictions in IFD (Icelandic Frequency Dictionary) format. Use this for evaluation against MIM-GOLD datasets or when you need compatibility with traditional Icelandic POS taggers.
22
+
23
+ # Example sentence which is already tokenized (with a classic tokenizer)
24
  sentence = "Ég veit að þú kemur í kvöld til mín ."
25
 
26
+ # Get predictions in (category, [attributes]) format
27
  result = model.predict_labels_from_text([sentence], tokenizer)
28
  expected = [
29
  [
30
  ("fp", ["1", "sing", "nom"]),
31
+ ("sf", ["act", "1", "sing", "pres"]),
32
  ("c", []),
33
  ("fp", ["2", "sing", "nom"]),
34
+ ("sf", ["act", "2", "sing", "pres"]),
35
+ ("af", []),
36
  ("n", ["neut", "sing", "acc"]),
37
+ ("af", []),
38
  ("fp", ["1", "sing", "gen"]),
39
  ("pl", []),
40
  ]
41
  ]
42
  assert result == expected, f"Expected {expected}, but got {result}"
43
  print("Test passed successfully!")
44
+
45
+ # Get predictions in IFD format (for MIM-GOLD evaluation)
46
+ ifd_result = model.predict_ifd_labels_from_text([sentence], tokenizer)
47
+ ifd_expected = [
48
+ ["fp1en", "sfg1en", "c", "fp2en", "sfg2en", "af", "nheo", "af", "fp1ee", "pl"]
49
+ ]
50
+ assert ifd_result == ifd_expected, f"Expected {ifd_expected}, but got {ifd_result}"
51
+ print("IFD conversion test passed successfully!")
52
  ```
configuration.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  from dataclasses import dataclass
6
  from typing import Dict, List, Optional
7
 
 
8
  from transformers import AutoConfig, RobertaConfig
9
 
10
 
@@ -31,6 +32,61 @@ class LabelSchema:
31
  separator: str
32
  ignore_categories: List[str]
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  class IceBertPosConfig(RobertaConfig):
36
  """
 
5
  from dataclasses import dataclass
6
  from typing import Dict, List, Optional
7
 
8
+ import torch
9
  from transformers import AutoConfig, RobertaConfig
10
 
11
 
 
32
  separator: str
33
  ignore_categories: List[str]
34
 
35
+ def get_group_name_to_group_attr_indices(self, device="cpu") -> Dict[str, torch.Tensor]:
36
+ """
37
+ Create mapping from group names to their attribute indices in the labels list.
38
+
39
+ Returns:
40
+ Dictionary mapping group names to tensor of label indices
41
+ """
42
+ group_name_to_group_attr_indices = {}
43
+ for group_name, group_labels in self.group_name_to_labels.items():
44
+ indices = []
45
+ for label in group_labels:
46
+ if label in self.labels:
47
+ indices.append(self.labels.index(label))
48
+ group_name_to_group_attr_indices[group_name] = torch.tensor(indices, device=device)
49
+ return group_name_to_group_attr_indices
50
+
51
+ def get_group_masks(self, device="cpu") -> torch.Tensor:
52
+ """
53
+ Create group masks indicating which groups are valid for each category.
54
+
55
+ Returns:
56
+ Tensor of shape (num_categories, num_groups) with 1 for valid combinations
57
+ """
58
+ num_categories = len(self.label_categories)
59
+ num_groups = len(self.group_names)
60
+ group_mask = torch.zeros(num_categories, num_groups, dtype=torch.int64, device=device)
61
+
62
+ for cat, cat_group_names in self.category_to_group_names.items():
63
+ if cat in self.label_categories:
64
+ cat_idx = self.label_categories.index(cat)
65
+ for group_name in cat_group_names:
66
+ if group_name in self.group_names:
67
+ group_idx = self.group_names.index(group_name)
68
+ group_mask[cat_idx, group_idx] = 1
69
+
70
+ return group_mask
71
+
72
+ def get_category_name_to_index(self) -> Dict[str, int]:
73
+ """
74
+ Create mapping from category names to their indices.
75
+
76
+ Returns:
77
+ Dictionary mapping category names to their indices
78
+ """
79
+ return {cat: idx for idx, cat in enumerate(self.label_categories)}
80
+
81
+ def get_label_name_to_index(self) -> Dict[str, int]:
82
+ """
83
+ Create mapping from label names to their indices.
84
+
85
+ Returns:
86
+ Dictionary mapping label names to their indices
87
+ """
88
+ return {label: idx for idx, label in enumerate(self.labels)}
89
+
90
 
91
  class IceBertPosConfig(RobertaConfig):
92
  """
ifd_utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) Miðeind ehf.
2
+ # This file is part of IceBERT POS model conversion.
3
+
4
+ """
5
+ IFD (Icelandic Frequency Dictionary) utilities for converting model predictions
6
+ to IFD format labels used in MIM-GOLD evaluation.
7
+ """
8
+
9
+ import numpy as np
10
+ from typing import List, Tuple
11
+
12
+ # Category and feature definitions
13
+ CATS = [
14
+ "n", "g", "x", "e", "v", "l", "fa", "fb", "fe", "fo", "fp", "fs", "ft", "tf",
15
+ "ta", "tp", "to", "sn", "sb", "sf", "sv", "ss", "sl", "sþ", "cn", "ct", "c",
16
+ "aa", "af", "au", "ao", "aþ", "ae", "as", "ks", "kt", "p", "pl", "pk", "pg",
17
+ "pa", "ns", "m"
18
+ ]
19
+
20
+ FEATS = [
21
+ "masc", "fem", "neut", "gender_x", "1", "2", "3", "sing", "plur", "nom",
22
+ "acc", "dat", "gen", "definite", "proper", "strong", "weak", "equiinflected",
23
+ "pos", "cmp", "superl", "past", "pres", "pass", "act", "mid"
24
+ ]
25
+
26
+ LABELS = CATS + FEATS
27
+ LABEL_TO_IDX = {label: idx for (idx, label) in enumerate(LABELS)}
28
+
29
+ # IFD conversion mappings
30
+ GENDER = {"k": "masc", "v": "fem", "h": "neut", "-": "gender_x"}
31
+ NUMBER = {"e": "sing", "f": "plur"}
32
+ PERSON = {"1": "1", "2": "2", "3": "3"}
33
+ CASE = {"n": "nom", "o": "acc", "þ": "dat", "e": "gen"}
34
+ DEGREE = {"f": "pos", "m": "cmp", "e": "superl"}
35
+ VOICE = {"g": "act", "m": "mid"}
36
+ TENSE = {"n": "pres", "þ": "past"}
37
+ ADJ_CLASS = {"s": "strong", "v": "weak", "o": "equiinflected"}
38
+ DEFINITE = {"g": "definite", " ": "indefinite"}
39
+
40
+ TAGSET = {
41
+ "n": [
42
+ GENDER,
43
+ NUMBER,
44
+ CASE,
45
+ {"g": "definite", "-": "", " ": ""},
46
+ {"": "", "s": "proper"},
47
+ ],
48
+ "l": [GENDER, NUMBER, CASE, ADJ_CLASS, DEGREE],
49
+ "f": [{**GENDER, **PERSON}, NUMBER, CASE],
50
+ "g": [GENDER, NUMBER, CASE],
51
+ "t": [GENDER, NUMBER, CASE],
52
+ "sþ": [VOICE, GENDER, NUMBER, CASE],
53
+ "s": [VOICE, PERSON, NUMBER, TENSE],
54
+ "a": [DEGREE],
55
+ }
56
+
57
+
58
+ def vec2ifd(vec):
59
+ """Convert one-hot vector to IFD format tag."""
60
+ cat_idx = np.argmax(vec[:len(CATS)])
61
+ cat = CATS[cat_idx]
62
+ idxs = list(np.where(vec == 1)[0])
63
+ features = [LABELS[int(idx)] for idx in idxs if int(idx) >= len(CATS)]
64
+
65
+ if not features:
66
+ return cat
67
+
68
+ ret = []
69
+ codes = []
70
+ tagset_key = cat[0]
71
+ tagset_key = "sþ" if cat.startswith("sþ") else tagset_key
72
+
73
+ if tagset_key not in TAGSET:
74
+ return cat
75
+
76
+ for feature in TAGSET[tagset_key]:
77
+ for code, val in feature.items():
78
+ if val in features:
79
+ ret.append(val)
80
+ codes.append(code)
81
+
82
+ tag = "".join([cat] + codes)
83
+
84
+ if cat == "n" and "proper" in features and "definite" not in features:
85
+ tag = tag[:-1] + "-" + tag[-1]
86
+
87
+ return tag
88
+
89
+
90
+ def convert_predictions_to_ifd(predictions: List[Tuple[str, List[str]]]) -> List[str]:
91
+ """
92
+ Convert model predictions to IFD format using logic from the original model.
93
+
94
+ Args:
95
+ predictions: List of (category, [attributes]) tuples from model
96
+
97
+ Returns:
98
+ List of IFD format labels
99
+ """
100
+ ifd_labels = []
101
+
102
+ for labelset in predictions:
103
+ cat, feats = labelset
104
+ labels_to_map = [cat]
105
+
106
+ # Apply the same logic as the original predict_ifd_labels method
107
+ if len(feats) == 1 and feats[0] == "pos":
108
+ # This label is used as a default for training but implied in mim format
109
+ feats = []
110
+ elif cat == "sl" and "act" in feats:
111
+ # Number and tense are not shown for sl act in mim format
112
+ feats = [f for f in feats if f not in ["1", "sing", "pres"]]
113
+
114
+ labels_to_map += feats
115
+
116
+ # Create one-hot vector from labels
117
+ vec = np.zeros(len(LABELS))
118
+ for label in labels_to_map:
119
+ if label in LABEL_TO_IDX:
120
+ vec[LABEL_TO_IDX[label]] = 1
121
+
122
+ # Convert to IFD format
123
+ try:
124
+ ifd_label = vec2ifd(vec)
125
+ if ifd_label == "ns":
126
+ # This is to comply with the format
127
+ ifd_label = "n----s"
128
+ ifd_labels.append(ifd_label)
129
+ except Exception:
130
+ # Fallback to naive concatenation if conversion fails
131
+ if feats:
132
+ ifd_labels.append(cat + "".join(feats))
133
+ else:
134
+ ifd_labels.append(cat)
135
+
136
+ return ifd_labels
modeling.py CHANGED
@@ -11,15 +11,7 @@ from torch.nn.utils.rnn import pad_sequence
11
  from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
12
 
13
  from .configuration import IceBertPosConfig
14
- from .old_label_utils import (
15
- SimpleLabelDictionary,
16
- clean_cats_attrs,
17
- create_label_dictionary_from_schema,
18
- make_dict_idx_to_vec_idx,
19
- make_group_masks,
20
- make_group_name_to_group_attr_vec_idxs,
21
- make_vec_idx_to_dict_idx,
22
- )
23
 
24
  logger = logging.getLogger(__name__)
25
 
@@ -90,21 +82,31 @@ class IceBertPosForTokenClassification(PreTrainedModel):
90
  self.roberta = RobertaModel(config, add_pooling_layer=False)
91
  self.classifier = MultiLabelTokenClassificationHead(config)
92
 
93
- # Create label dictionary and mappings (mimicking old fairseq model)
94
- self.label_dictionary = create_label_dictionary_from_schema(config.label_schema)
95
  self._setup_label_mappings()
96
 
97
  # Initialize weights and apply final processing
98
  self.post_init()
99
 
100
  def _setup_label_mappings(self):
101
- """Setup label mappings similar to the old fairseq model."""
102
  schema = self.config.label_schema
103
 
104
- self.group_name_to_group_attr_vec_idxs = make_group_name_to_group_attr_vec_idxs(self.label_dictionary, schema)
105
- self.cat_dict_idx_to_vec_idx = make_dict_idx_to_vec_idx(self.label_dictionary, schema.label_categories)
106
- self.cat_vec_idx_to_dict_idx = make_vec_idx_to_dict_idx(self.label_dictionary, schema.label_categories)
107
- self.group_mask = make_group_masks(self.label_dictionary, schema)
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def forward(
110
  self,
@@ -316,20 +318,16 @@ class IceBertPosForTokenClassification(PreTrainedModel):
316
  # Split sentences by spaces to get proper word boundaries
317
  # This fixes the issue where tokens like "Kl." get split incorrectly
318
  sentences_split = [sentence.split() for sentence in sentences]
319
-
320
  # Use batch_encode_plus with is_split_into_words=True to preserve word boundaries
321
  encoding = tokenizer.batch_encode_plus(
322
- sentences_split,
323
- return_tensors="pt",
324
- padding=True,
325
- is_split_into_words=True,
326
- add_special_tokens=True
327
  )
328
-
329
  batch_input_ids = encoding["input_ids"]
330
  batch_attention_mask = encoding["attention_mask"]
331
  word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]
332
-
333
  # Debug logging to match fairseq model
334
  for i in range(len(sentences)):
335
  logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
@@ -342,8 +340,7 @@ class IceBertPosForTokenClassification(PreTrainedModel):
342
  self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
343
  ) -> List[List[Tuple[str, List[str]]]]:
344
  """
345
- Convert logits to human-readable labels using fairseq's group-based logic.
346
- Copied from the old model's logits_to_labels method.
347
  """
348
  # logits: Batch x Time x Labels
349
  bsz, _, num_cats = cat_logits.shape
@@ -353,90 +350,74 @@ class IceBertPosForTokenClassification(PreTrainedModel):
353
  assert num_attrs == len(self.config.label_schema.labels)
354
  assert num_cats == len(self.config.label_schema.label_categories)
355
 
356
- batch_cats = []
357
- batch_attrs = []
 
358
  for seq_idx in range(bsz):
359
  seq_nwords = nwords[seq_idx]
360
- pred_cat_vec_idxs = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
361
- pred_cats = self.cat_vec_idx_to_dict_idx[pred_cat_vec_idxs]
362
-
363
- group_mask = self.group_mask[pred_cat_vec_idxs]
364
- offset = self.label_dictionary.nspecial
365
- pred_attrs = []
366
- for group_idx, group_name in enumerate(self.config.label_schema.group_names):
367
- group_vec_idxs = self.group_name_to_group_attr_vec_idxs[group_name]
368
- # logits: (bsz * nwords) x labels
369
- group_logits = attr_logits[seq_idx, :seq_nwords, group_vec_idxs]
370
- if len(group_vec_idxs) == 1:
371
- group_pred = group_logits.sigmoid().ge(0.5).long()
372
- group_pred_dict_idxs = (group_pred.squeeze() * (group_vec_idxs.item() + offset)).T.to(
373
- "cpu"
374
- ) * group_mask[:, group_idx]
375
- else:
376
- group_pred_vec_idxs = group_logits.max(dim=-1).indices
377
- group_pred_dict_idxs = (group_vec_idxs[group_pred_vec_idxs] + offset) * group_mask[:, group_idx]
378
- pred_attrs.append(group_pred_dict_idxs)
379
-
380
- pred_attrs = torch.stack([p.squeeze() for p in pred_attrs]).t()
381
-
382
- batch_cats.append(pred_cats)
383
- batch_attrs.append(pred_attrs)
384
-
385
- predictions = list(
386
- [
387
- clean_cats_attrs(
388
- self.label_dictionary,
389
- self.config.label_schema,
390
- seq_cats,
391
- seq_attrs,
392
- )
393
- for seq_cats, seq_attrs in zip(batch_cats, batch_attrs)
394
- ]
395
- )
 
 
 
396
 
397
  return predictions
398
 
399
-
400
- def make_vec_idx_to_dict_idx(dictionary, labels, device="cpu", fill_value=-100):
401
- vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
402
- for vec_idx, label in enumerate(labels):
403
- vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
404
- return vec_idx_to_dict_idx
405
-
406
-
407
- def make_group_masks(dictionary, schema, device="cpu"):
408
- num_groups = len(schema.group_names)
409
- offset = dictionary.nspecial
410
- num_labels = len(dictionary) - offset
411
- ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
412
- for cat, cat_group_names in schema.category_to_group_names.items():
413
- cat_label_idx = dictionary.index(cat)
414
- cat_vec_idx = schema.label_categories.index(cat)
415
- for group_name in cat_group_names:
416
- ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
417
- assert cat_label_idx != dictionary.unk()
418
- for cat in schema.label_categories:
419
- cat_label_idx = dictionary.index(cat)
420
- assert cat_label_idx != dictionary.unk()
421
- return ret_mask
422
-
423
-
424
- def make_group_name_to_group_attr_vec_idxs(dict_, schema):
425
- offset = dict_.nspecial
426
- group_names = schema.group_name_to_labels.keys()
427
- name_to_labels = schema.group_name_to_labels
428
- group_name_to_group_attr_vec_idxs = {
429
- name: torch.tensor([dict_.index(item) - offset for item in name_to_labels[name]]) for name in group_names
430
- }
431
- return group_name_to_group_attr_vec_idxs
432
-
433
-
434
- def make_dict_idx_to_vec_idx(dictionary, cats, device="cpu", fill_value=-100):
435
- # NOTE: when target is not in label_categories, the error is silent
436
- map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
437
- for vec_idx, label in enumerate(cats):
438
- map_tgt[dictionary.index(label)] = vec_idx
439
- return map_tgt
440
 
441
 
442
  AutoConfig.register("icebert-pos", IceBertPosConfig)
 
11
  from transformers import AutoConfig, AutoModel, PreTrainedModel, RobertaModel
12
 
13
  from .configuration import IceBertPosConfig
14
+ from .ifd_utils import convert_predictions_to_ifd
 
 
 
 
 
 
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
82
  self.roberta = RobertaModel(config, add_pooling_layer=False)
83
  self.classifier = MultiLabelTokenClassificationHead(config)
84
 
 
 
85
  self._setup_label_mappings()
86
 
87
  # Initialize weights and apply final processing
88
  self.post_init()
89
 
90
  def _setup_label_mappings(self):
91
+ """Setup label mappings using schema methods."""
92
  schema = self.config.label_schema
93
 
94
+ # Get model device for tensor creation
95
+ device = next(self.parameters()).device if len(list(self.parameters())) > 0 else torch.device("cpu")
96
+
97
+ # Register group mask as buffer so it moves with model.to(device)
98
+ self.register_buffer("group_mask", schema.get_group_masks(device=device))
99
+
100
+ # Register group attribute indices as buffers
101
+ group_attr_indices = schema.get_group_name_to_group_attr_indices(device=device)
102
+ self.group_name_to_group_attr_indices = {}
103
+ for group_name, indices in group_attr_indices.items():
104
+ buffer_name = f"group_attr_indices_{group_name}"
105
+ self.register_buffer(buffer_name, indices)
106
+ self.group_name_to_group_attr_indices[group_name] = getattr(self, buffer_name)
107
+
108
+ # Category name to index mapping (regular dict, no device movement needed)
109
+ self.category_name_to_index = schema.get_category_name_to_index()
110
 
111
  def forward(
112
  self,
 
318
  # Split sentences by spaces to get proper word boundaries
319
  # This fixes the issue where tokens like "Kl." get split incorrectly
320
  sentences_split = [sentence.split() for sentence in sentences]
321
+
322
  # Use batch_encode_plus with is_split_into_words=True to preserve word boundaries
323
  encoding = tokenizer.batch_encode_plus(
324
+ sentences_split, return_tensors="pt", padding=True, is_split_into_words=True, add_special_tokens=True
 
 
 
 
325
  )
326
+
327
  batch_input_ids = encoding["input_ids"]
328
  batch_attention_mask = encoding["attention_mask"]
329
  word_ids_list = [encoding.word_ids(i) for i in range(len(sentences))]
330
+
331
  # Debug logging to match fairseq model
332
  for i in range(len(sentences)):
333
  logger.debug(f"Encoded tokens: {batch_input_ids[i]}")
 
340
  self, cat_logits: torch.Tensor, attr_logits: torch.Tensor, word_mask: torch.Tensor
341
  ) -> List[List[Tuple[str, List[str]]]]:
342
  """
343
+ Convert logits to human-readable labels using schema-based logic.
 
344
  """
345
  # logits: Batch x Time x Labels
346
  bsz, _, num_cats = cat_logits.shape
 
350
  assert num_attrs == len(self.config.label_schema.labels)
351
  assert num_cats == len(self.config.label_schema.label_categories)
352
 
353
+ predictions = []
354
+ schema = self.config.label_schema
355
+
356
  for seq_idx in range(bsz):
357
  seq_nwords = nwords[seq_idx]
358
+ pred_cat_indices = cat_logits[seq_idx, :seq_nwords].max(dim=-1).indices
359
+
360
+ seq_predictions = []
361
+ for word_idx in range(seq_nwords):
362
+ cat_idx = int(pred_cat_indices[word_idx].item())
363
+ cat_name = schema.label_categories[cat_idx]
364
+
365
+ # Get valid groups for this category
366
+ valid_groups = schema.category_to_group_names.get(cat_name, [])
367
+
368
+ # Collect attributes for this word
369
+ attributes = []
370
+ for group_name in valid_groups:
371
+ if group_name in self.group_name_to_group_attr_indices:
372
+ group_indices = self.group_name_to_group_attr_indices[group_name]
373
+ if len(group_indices) > 0:
374
+ group_logits = attr_logits[seq_idx, word_idx, group_indices]
375
+ if len(group_indices) == 1:
376
+ # Binary decision
377
+ if group_logits.sigmoid().item() > 0.5:
378
+ attr_idx = int(group_indices[0].item())
379
+ attributes.append(schema.labels[attr_idx])
380
+ else:
381
+ # Multi-class decision
382
+ best_idx = int(group_logits.max(dim=-1).indices.item())
383
+ attr_idx = int(group_indices[best_idx].item())
384
+ attributes.append(schema.labels[attr_idx])
385
+
386
+ # Apply specific rules from original model
387
+ if len(attributes) == 1 and attributes[0] == "pos":
388
+ # This label is used as a default for training but implied in mim format
389
+ attributes = []
390
+ elif cat_name == "sl" and "act" in attributes:
391
+ # Number and tense are not shown for sl act in mim format
392
+ attributes = [attr for attr in attributes if attr not in ["1", "sing", "pres"]]
393
+
394
+ seq_predictions.append((cat_name, attributes))
395
+
396
+ predictions.append(seq_predictions)
397
 
398
  return predictions
399
 
400
+ def predict_ifd_labels_from_text(self, sentences: List[str], tokenizer) -> List[List[str]]:
401
+ """
402
+ Predict IFD format labels from raw text.
403
+
404
+ Args:
405
+ sentences: List of input sentences
406
+ tokenizer: HuggingFace tokenizer
407
+
408
+ Returns:
409
+ List of sequences, each containing IFD format labels per word
410
+ """
411
+ # Get model predictions in (category, [attributes]) format
412
+ predictions = self.predict_labels_from_text(sentences, tokenizer)
413
+
414
+ # Convert each sentence's predictions to IFD format
415
+ ifd_predictions = []
416
+ for sentence_predictions in predictions:
417
+ ifd_labels = convert_predictions_to_ifd(sentence_predictions)
418
+ ifd_predictions.append(ifd_labels)
419
+
420
+ return ifd_predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
 
423
  AutoConfig.register("icebert-pos", IceBertPosConfig)
old_label_utils.py DELETED
@@ -1,223 +0,0 @@
1
- # Copyright (C) Miðeind ehf.
2
- # This file is part of IceBERT POS model conversion.
3
-
4
- """
5
- Utility functions copied from the old fairseq-based model for label handling.
6
- These functions handle the conversion between vector indices and dictionary indices,
7
- accounting for the offset caused by special tokens in the label dictionary.
8
- """
9
-
10
- from typing import Dict, List, Tuple
11
- import torch
12
-
13
-
14
- class SimpleLabelDictionary:
15
- """
16
- Simplified version of fairseq Dictionary to handle label mappings.
17
- This replaces the fairseq Dictionary dependency while maintaining the same interface.
18
- """
19
-
20
- def __init__(self, labels: List[str], nspecial: int = 5):
21
- """
22
- Args:
23
- labels: List of labels including special tokens at the beginning
24
- nspecial: Number of special tokens (typically 5: <pad>, <s>, </s>, <unk>, <SEP>)
25
- """
26
- self.symbols = labels
27
- self.nspecial = nspecial
28
- self._indices = {label: idx for idx, label in enumerate(labels)}
29
-
30
- def index(self, label: str) -> int:
31
- """Get index of label in dictionary."""
32
- return self._indices.get(label, self.unk())
33
-
34
- def unk(self) -> int:
35
- """Return index of unknown token (typically 3)."""
36
- return 3
37
-
38
- def string(self, indices: torch.Tensor) -> str:
39
- """Convert tensor of indices to space-separated string of labels."""
40
- if indices.dim() == 0:
41
- indices = indices.unsqueeze(0)
42
-
43
- # Filter out special tokens like fairseq Dictionary does
44
- special_indices_to_ignore = {0, 1, 2, 3} # BOS, PAD, EOS, UNK
45
-
46
- labels = [
47
- self.symbols[idx] for idx in indices.tolist()
48
- if 0 <= idx < len(self.symbols) and idx not in special_indices_to_ignore
49
- ]
50
- return " ".join(labels)
51
-
52
- def __len__(self) -> int:
53
- return len(self.symbols)
54
-
55
-
56
- def make_vec_idx_to_dict_idx(dictionary: SimpleLabelDictionary, labels: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
57
- """
58
- Create mapping from vector indices to dictionary indices.
59
-
60
- Args:
61
- dictionary: Label dictionary
62
- labels: List of labels
63
- device: Device for tensor
64
- fill_value: Fill value for missing entries
65
-
66
- Returns:
67
- Tensor mapping vector indices to dictionary indices
68
- """
69
- vec_idx_to_dict_idx = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
70
- for vec_idx, label in enumerate(labels):
71
- vec_idx_to_dict_idx[vec_idx] = dictionary.index(label)
72
- return vec_idx_to_dict_idx
73
-
74
-
75
- def make_group_masks(dictionary: SimpleLabelDictionary, schema, device="cpu") -> torch.Tensor:
76
- """
77
- Create group masks indicating which groups are valid for each category.
78
-
79
- Args:
80
- dictionary: Label dictionary
81
- schema: Label schema object
82
- device: Device for tensor
83
-
84
- Returns:
85
- Tensor of shape (num_categories, num_groups) with 1 for valid combinations
86
- """
87
- num_groups = len(schema.group_names)
88
- offset = dictionary.nspecial
89
- num_labels = len(dictionary) - offset
90
- ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)
91
-
92
- for cat, cat_group_names in schema.category_to_group_names.items():
93
- cat_label_idx = dictionary.index(cat)
94
- cat_vec_idx = schema.label_categories.index(cat)
95
- for group_name in cat_group_names:
96
- ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1
97
- assert cat_label_idx != dictionary.unk()
98
-
99
- return ret_mask
100
-
101
-
102
- def make_group_name_to_group_attr_vec_idxs(dictionary: SimpleLabelDictionary, schema) -> Dict[str, torch.Tensor]:
103
- """
104
- Create mapping from group names to their attribute vector indices.
105
-
106
- Args:
107
- dictionary: Label dictionary
108
- schema: Label schema object
109
-
110
- Returns:
111
- Dictionary mapping group names to tensor of vector indices
112
- """
113
- offset = dictionary.nspecial
114
- group_names = schema.group_name_to_labels.keys()
115
- name_to_labels = schema.group_name_to_labels
116
- group_name_to_group_attr_vec_idxs = {
117
- name: torch.tensor([dictionary.index(item) - offset for item in name_to_labels[name]])
118
- for name in group_names
119
- }
120
- return group_name_to_group_attr_vec_idxs
121
-
122
-
123
- def make_dict_idx_to_vec_idx(dictionary: SimpleLabelDictionary, cats: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
124
- """
125
- Create mapping from dictionary indices to vector indices.
126
-
127
- Args:
128
- dictionary: Label dictionary
129
- cats: List of categories
130
- device: Device for tensor
131
- fill_value: Fill value for missing entries
132
-
133
- Returns:
134
- Tensor mapping dictionary indices to vector indices
135
- """
136
- # NOTE: when target is not in label_categories, the error is silent
137
- map_tgt = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
138
- for vec_idx, label in enumerate(cats):
139
- map_tgt[dictionary.index(label)] = vec_idx
140
- return map_tgt
141
-
142
-
143
- def clean_cats_attrs(ldict: SimpleLabelDictionary, schema, pred_cats: torch.Tensor, pred_attrs: torch.Tensor) -> List[Tuple[str, List[str]]]:
144
- """
145
- Convert predicted category and attribute indices to human-readable labels.
146
-
147
- Args:
148
- ldict: Label dictionary
149
- schema: Label schema object
150
- pred_cats: Predicted category indices
151
- pred_attrs: Predicted attribute indices
152
-
153
- Returns:
154
- List of (category, [attributes]) tuples
155
- """
156
- cats = ldict.string(pred_cats).split(" ")
157
- attrs = []
158
-
159
- if len(pred_attrs.shape) == 1:
160
- split_pred_attrs = [pred_attrs]
161
- else:
162
- split_pred_attrs = pred_attrs.split(1, dim=0)
163
-
164
- for (_cat_idx, attr_idxs) in zip(pred_cats.tolist(), split_pred_attrs):
165
- seq_attrs = [lbl for lbl in ldict.string((attr_idxs.squeeze())).split(" ")]
166
- if not any(it for it in seq_attrs):
167
- seq_attrs = []
168
- attrs.append(seq_attrs)
169
-
170
- return list(zip(cats, attrs))
171
-
172
-
173
- def create_label_dictionary_from_schema(schema) -> SimpleLabelDictionary:
174
- """
175
- Create a SimpleLabelDictionary from a label schema, mimicking the old fairseq setup.
176
- Load the exact symbols from the original fairseq dictionary to ensure perfect compatibility.
177
-
178
- Args:
179
- schema: Label schema object (unused, kept for compatibility)
180
-
181
- Returns:
182
- SimpleLabelDictionary with exact same symbols as original fairseq dict
183
- """
184
- try:
185
- # Load original fairseq dictionary to get exact symbol order and content
186
- from fairseq.data import Dictionary
187
- import os
188
-
189
- # Try to find the original dict_term.txt file
190
- possible_paths = [
191
- 'scripts/dict_term.txt',
192
- 'icebert-pos/scripts/dict_term.txt',
193
- '../scripts/dict_term.txt'
194
- ]
195
-
196
- original_dict = None
197
- for path in possible_paths:
198
- if os.path.exists(path):
199
- original_dict = Dictionary.load(path)
200
- break
201
-
202
- if original_dict is not None:
203
- # Use exact symbols from original dictionary
204
- return SimpleLabelDictionary(original_dict.symbols, nspecial=original_dict.nspecial)
205
-
206
- except ImportError:
207
- # Fallback if fairseq is not available
208
- pass
209
- except Exception:
210
- # Fallback if file loading fails
211
- pass
212
-
213
- # Fallback: reconstruct from schema (original logic)
214
- # Use the correct special token order from original dictionary
215
- special_symbols = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>"]
216
-
217
- # The schema labels start with <SEP>, so we need to skip it
218
- schema_labels_without_sep = [label for label in schema.labels if label != "<SEP>"]
219
-
220
- # Combine: special tokens + schema labels (without duplicate <SEP>)
221
- all_symbols = special_symbols + schema_labels_without_sep
222
-
223
- return SimpleLabelDictionary(all_symbols, nspecial=4) # 4 special tokens before <SEP>