# IceBERT-PoS / old_label_utils.py
# Uploaded by haukurpj — commit aaca62a: "Fix inconsistencies with the old model - now works equally"
# (HuggingFace raw/history/blame page residue; file size 8.05 kB)
# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.
"""
Utility functions copied from the old fairseq-based model for label handling.
These functions handle the conversion between vector indices and dictionary indices,
accounting for the offset caused by special tokens in the label dictionary.
"""
from typing import Dict, List, Tuple
import torch
class SimpleLabelDictionary:
    """
    Minimal stand-in for the fairseq ``Dictionary`` used by the old POS model.

    Exposes the same small interface (``index``, ``unk``, ``string``,
    ``__len__``) so the label-handling helpers work without fairseq installed.
    """
    def __init__(self, labels: List[str], nspecial: int = 5):
        """
        Args:
            labels: Full symbol list with special tokens first
                (typically <s>, <pad>, </s>, <unk>, <SEP>).
            nspecial: Number of leading special tokens (typically 5).
        """
        self.symbols = labels
        self.nspecial = nspecial
        # Reverse lookup: symbol -> position in `labels`.
        self._indices = {}
        for position, symbol in enumerate(labels):
            self._indices[symbol] = position
    def index(self, label: str) -> int:
        """Return the dictionary index of `label`, or unk() if unknown."""
        if label in self._indices:
            return self._indices[label]
        return self.unk()
    def unk(self) -> int:
        """Return the index of the unknown token (fixed at 3, fairseq order)."""
        return 3
    def string(self, indices: torch.Tensor) -> str:
        """Render a tensor of indices as a space-joined label string,
        dropping fairseq-style special tokens (indices 0-3) and any
        out-of-range index."""
        if indices.dim() == 0:
            indices = indices.unsqueeze(0)
        skip = {0, 1, 2, 3}  # <s>, <pad>, </s>, <unk>
        n_symbols = len(self.symbols)
        kept = []
        for idx in indices.tolist():
            if idx in skip or not (0 <= idx < n_symbols):
                continue
            kept.append(self.symbols[idx])
        return " ".join(kept)
    def __len__(self) -> int:
        return len(self.symbols)
def make_vec_idx_to_dict_idx(dictionary: SimpleLabelDictionary, labels: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Map vector (position) indices onto label-dictionary indices.

    Args:
        dictionary: Label dictionary; its length fixes the output size.
        labels: Labels in vector order; slot i gets dictionary.index(labels[i]).
        device: Device for the returned tensor.
        fill_value: Left in place for positions beyond len(labels).

    Returns:
        Long tensor of length len(dictionary): vector index -> dictionary index.
    """
    mapping = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    for position, label in enumerate(labels):
        mapping[position] = dictionary.index(label)
    return mapping
def make_group_masks(dictionary: SimpleLabelDictionary, schema, device="cpu") -> torch.Tensor:
    """
    Build a 0/1 mask of which attribute groups are valid for each category.

    Args:
        dictionary: Label dictionary (special tokens first, then labels).
        schema: Label schema with ``group_names``, ``category_to_group_names``
            and ``label_categories`` attributes.
        device: Device for the returned tensor.

    Returns:
        int64 tensor of shape (num_non_special_labels, num_groups); rows are
        addressed by the category's position in schema.label_categories, so
        rows past the categories stay all-zero.
    """
    group_count = len(schema.group_names)
    label_count = len(dictionary) - dictionary.nspecial
    mask = torch.zeros(label_count, group_count, dtype=torch.int64, device=device)
    for category, allowed_groups in schema.category_to_group_names.items():
        dict_idx = dictionary.index(category)
        row = schema.label_categories.index(category)
        for allowed in allowed_groups:
            mask[row, schema.group_names.index(allowed)] = 1
        # Every schema category must exist in the dictionary.
        assert dict_idx != dictionary.unk()
    return mask
def make_group_name_to_group_attr_vec_idxs(dictionary: SimpleLabelDictionary, schema) -> Dict[str, torch.Tensor]:
    """
    Map each group name to the vector indices of its attribute labels.

    Args:
        dictionary: Label dictionary; dictionary indices are shifted down by
            ``nspecial`` to obtain vector indices.
        schema: Label schema with a ``group_name_to_labels`` mapping.

    Returns:
        Dict mapping each group name to a tensor of attribute vector indices.
    """
    shift = dictionary.nspecial
    return {
        group: torch.tensor([dictionary.index(attr) - shift for attr in attr_labels])
        for group, attr_labels in schema.group_name_to_labels.items()
    }
def make_dict_idx_to_vec_idx(dictionary: SimpleLabelDictionary, cats: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Map label-dictionary indices onto vector (position) indices — the inverse
    direction of make_vec_idx_to_dict_idx.

    Args:
        dictionary: Label dictionary; its length fixes the output size.
        cats: Categories in vector order.
        device: Device for the returned tensor.
        fill_value: Left in place for dictionary entries not present in cats.

    Returns:
        Long tensor of length len(dictionary): dictionary index -> vector index.
    """
    # NOTE: when target is not in label_categories, the error is silent
    mapping = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    dict_positions = [dictionary.index(category) for category in cats]
    for vec_position, dict_position in enumerate(dict_positions):
        mapping[dict_position] = vec_position
    return mapping
def clean_cats_attrs(ldict: SimpleLabelDictionary, schema, pred_cats: torch.Tensor, pred_attrs: torch.Tensor) -> List[Tuple[str, List[str]]]:
    """
    Convert predicted category and attribute indices to human-readable labels.

    Args:
        ldict: Label dictionary used to render indices back into label strings.
        schema: Label schema object (unused here; kept for interface compatibility).
        pred_cats: 1-D tensor of predicted category dictionary indices.
        pred_attrs: Predicted attribute dictionary indices; one row per
            category when 2-D, a single attribute vector when 1-D.

    Returns:
        List of (category, [attributes]) tuples, one per predicted category.
    """
    cats = ldict.string(pred_cats).split(" ")
    attrs = []
    # A 1-D tensor holds the attributes of a single prediction; a 2-D tensor
    # holds one row per predicted category.
    if pred_attrs.dim() == 1:
        split_pred_attrs = [pred_attrs]
    else:
        split_pred_attrs = pred_attrs.split(1, dim=0)
    # zip against pred_cats so excess attribute rows are dropped, as before.
    for (_cat_idx, attr_idxs) in zip(pred_cats.tolist(), split_pred_attrs):
        seq_attrs = ldict.string(attr_idxs.squeeze()).split(" ")
        # string() returns "" when every index was special/filtered, and
        # "".split(" ") is [""] — normalize that to "no attributes".
        if not any(seq_attrs):
            seq_attrs = []
        attrs.append(seq_attrs)
    return list(zip(cats, attrs))
def create_label_dictionary_from_schema(schema) -> SimpleLabelDictionary:
    """
    Create a SimpleLabelDictionary mimicking the old fairseq setup.

    Preferably loads the exact symbols from the original fairseq dictionary
    file so the conversion matches the old model perfectly; otherwise
    reconstructs the symbol list from the schema.

    Args:
        schema: Label schema object; only its ``labels`` attribute is used,
            and only on the fallback path.

    Returns:
        SimpleLabelDictionary with the same symbols as the original fairseq dict
        when available, else one rebuilt from the schema.
    """
    try:
        from fairseq.data import Dictionary
        import os
        # Known locations of the original dict_term.txt; first hit wins.
        candidate_paths = (
            'scripts/dict_term.txt',
            'icebert-pos/scripts/dict_term.txt',
            '../scripts/dict_term.txt',
        )
        for candidate in candidate_paths:
            if os.path.exists(candidate):
                loaded = Dictionary.load(candidate)
                # Exact symbol order/content from the original dictionary.
                return SimpleLabelDictionary(loaded.symbols, nspecial=loaded.nspecial)
    except Exception:
        # Best-effort: fairseq missing or file unreadable — fall back below.
        pass
    # Fallback: rebuild from the schema, using the original special-token order.
    specials = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>"]
    # The schema labels already contain <SEP>; drop it to avoid a duplicate.
    schema_labels = [label for label in schema.labels if label != "<SEP>"]
    return SimpleLabelDictionary(specials + schema_labels, nspecial=4)  # 4 special tokens before <SEP>