|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Utility functions copied from the old fairseq-based model for label handling. |
|
|
These functions handle the conversion between vector indices and dictionary indices, |
|
|
accounting for the offset caused by special tokens in the label dictionary. |
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Tuple |
|
|
import torch |
|
|
|
|
|
|
|
|
class SimpleLabelDictionary:
    """
    Minimal stand-in for a fairseq ``Dictionary``, restricted to label handling.

    Exposes the same interface (``index``, ``unk``, ``string``, ``__len__``) so
    code written against the fairseq Dictionary keeps working without the
    fairseq dependency.
    """

    def __init__(self, labels: List[str], nspecial: int = 5):
        """
        Args:
            labels: Full symbol list, with the special tokens first.
            nspecial: Count of leading special tokens
                (typically 5: <pad>, <s>, </s>, <unk>, <SEP>).
        """
        # Keep the symbol list plus a reverse map for O(1) index() lookups.
        self.symbols = labels
        self.nspecial = nspecial
        self._indices = {}
        for position, symbol in enumerate(labels):
            self._indices[symbol] = position

    def index(self, label: str) -> int:
        """Return the dictionary index of ``label``, or the unk index if absent."""
        try:
            return self._indices[label]
        except KeyError:
            return self.unk()

    def unk(self) -> int:
        """Index of the unknown token (fixed at 3, matching fairseq's layout)."""
        return 3

    def string(self, indices: torch.Tensor) -> str:
        """
        Render a tensor of indices as a space-joined label string.

        Indices 0-3 (the fairseq builtin specials) and out-of-range values are
        silently dropped.
        """
        if indices.dim() == 0:
            indices = indices.unsqueeze(0)

        skipped = {0, 1, 2, 3}
        parts = []
        for idx in indices.tolist():
            if idx in skipped or not 0 <= idx < len(self.symbols):
                continue
            parts.append(self.symbols[idx])
        return " ".join(parts)

    def __len__(self) -> int:
        """Total number of symbols, specials included."""
        return len(self.symbols)
|
|
|
|
|
|
|
|
def make_vec_idx_to_dict_idx(dictionary: SimpleLabelDictionary, labels: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Build the vector-index -> dictionary-index lookup table.

    The table has one slot per dictionary symbol; the first ``len(labels)``
    slots hold the dictionary index of the corresponding label, the remaining
    slots keep ``fill_value``.

    Args:
        dictionary: Label dictionary
        labels: Labels in vector order
        device: Device for the returned tensor
        fill_value: Value left in slots with no corresponding label

    Returns:
        Long tensor of shape ``(len(dictionary),)``
    """
    dict_idxs = [dictionary.index(label) for label in labels]
    table = torch.full((len(dictionary),), fill_value=fill_value, dtype=torch.long, device=device)
    if dict_idxs:
        table[: len(dict_idxs)] = torch.tensor(dict_idxs, dtype=torch.long, device=device)
    return table
|
|
|
|
|
|
|
|
def make_group_masks(dictionary: SimpleLabelDictionary, schema, device="cpu") -> torch.Tensor:
    """
    Build a 0/1 mask of which groups are valid for each category.

    Args:
        dictionary: Label dictionary (``nspecial`` gives the label offset)
        schema: Label schema with ``group_names``, ``label_categories`` and
            ``category_to_group_names``
        device: Device for the returned tensor

    Returns:
        int64 tensor of shape (num_labels, num_groups) where entry
        (category_vec_idx, group_idx) is 1 iff the group is valid for
        that category
    """
    label_count = len(dictionary) - dictionary.nspecial
    mask = torch.zeros(label_count, len(schema.group_names), dtype=torch.int64, device=device)

    for category, valid_groups in schema.category_to_group_names.items():
        row = schema.label_categories.index(category)
        for group in valid_groups:
            mask[row, schema.group_names.index(group)] = 1
        # Every schema category must resolve to a real dictionary entry.
        assert dictionary.index(category) != dictionary.unk()

    return mask
|
|
|
|
|
|
|
|
def make_group_name_to_group_attr_vec_idxs(dictionary: SimpleLabelDictionary, schema) -> Dict[str, torch.Tensor]:
    """
    Map each group name to the vector indices of its attribute labels.

    Each attribute's dictionary index is shifted down by
    ``dictionary.nspecial`` so the result is relative to the label vector.

    Args:
        dictionary: Label dictionary
        schema: Label schema with ``group_name_to_labels``

    Returns:
        Dictionary mapping group names to a 1-D tensor of vector indices
    """
    shift = dictionary.nspecial
    result: Dict[str, torch.Tensor] = {}
    for group_name, group_labels in schema.group_name_to_labels.items():
        result[group_name] = torch.tensor([dictionary.index(lbl) - shift for lbl in group_labels])
    return result
|
|
|
|
|
|
|
|
def make_dict_idx_to_vec_idx(dictionary: SimpleLabelDictionary, cats: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Build the inverse lookup: dictionary index -> vector index.

    Args:
        dictionary: Label dictionary
        cats: Categories in vector order
        device: Device for the returned tensor
        fill_value: Value for dictionary entries that have no vector slot

    Returns:
        Long tensor of shape ``(len(dictionary),)``
    """
    lookup = torch.full((len(dictionary),), fill_value=fill_value, dtype=torch.long, device=device)
    for vec_idx, dict_idx in enumerate(dictionary.index(cat) for cat in cats):
        lookup[dict_idx] = vec_idx
    return lookup
|
|
|
|
|
|
|
|
def clean_cats_attrs(ldict: SimpleLabelDictionary, schema, pred_cats: torch.Tensor, pred_attrs: torch.Tensor) -> List[Tuple[str, List[str]]]:
    """
    Turn predicted category/attribute index tensors into readable labels.

    Args:
        ldict: Label dictionary
        schema: Label schema object (not referenced in this function)
        pred_cats: Tensor of predicted category indices
        pred_attrs: 1-D or 2-D tensor of predicted attribute indices
            (one row per category when 2-D)

    Returns:
        List of (category, [attributes]) tuples
    """
    cats = ldict.string(pred_cats).split(" ")

    # Normalize pred_attrs into one row per predicted category.
    rows = [pred_attrs] if len(pred_attrs.shape) == 1 else pred_attrs.split(1, dim=0)

    attrs: List[List[str]] = []
    # zip trims to the shorter of (categories, attribute rows), as before.
    for _dict_idx, row in zip(pred_cats.tolist(), rows):
        row_labels = ldict.string(row.squeeze()).split(" ")
        # An all-special row renders as "", which splits to [""] -> drop it.
        attrs.append(row_labels if any(row_labels) else [])

    return list(zip(cats, attrs))
|
|
|
|
|
|
|
|
def create_label_dictionary_from_schema(schema) -> SimpleLabelDictionary:
    """
    Create a SimpleLabelDictionary, preferring the original fairseq dictionary.

    If fairseq is installed and one of the known ``dict_term.txt`` files
    exists, the exact symbols (and nspecial count) of that dictionary are
    reused, ensuring perfect compatibility with the old fairseq-based model.
    Otherwise the dictionary is reconstructed from ``schema.labels``.

    Args:
        schema: Label schema object; ``schema.labels`` is read only when the
            fairseq dictionary file cannot be loaded.

    Returns:
        SimpleLabelDictionary -- either an exact copy of the original fairseq
        dictionary, or a reconstruction from the schema labels.
    """
    try:
        import os

        from fairseq.data import Dictionary

        for path in (
            'scripts/dict_term.txt',
            'icebert-pos/scripts/dict_term.txt',
            '../scripts/dict_term.txt',
        ):
            if os.path.exists(path):
                original_dict = Dictionary.load(path)
                return SimpleLabelDictionary(original_dict.symbols, nspecial=original_dict.nspecial)
    except ImportError:
        # fairseq not installed -- fall back to the schema-based rebuild.
        pass
    except Exception:
        # NOTE(review): deliberately best-effort (e.g. unreadable/corrupt
        # dict file falls through to the rebuild) -- consider logging here.
        pass

    # Rebuild mimicking the fairseq layout: the 4 builtin specials come
    # first; <SEP> sits at index 4 (fairseq added it as a regular symbol),
    # which is why nspecial=4 even though there are 5 "special" strings.
    special_symbols = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>"]
    schema_labels_without_sep = [label for label in schema.labels if label != "<SEP>"]
    return SimpleLabelDictionary(special_symbols + schema_labels_without_sep, nspecial=4)