File size: 8,045 Bytes
aaca62a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Copyright (C) Miðeind ehf.
# This file is part of IceBERT POS model conversion.

"""
Utility functions copied from the old fairseq-based model for label handling.
These functions handle the conversion between vector indices and dictionary indices,
accounting for the offset caused by special tokens in the label dictionary.
"""

from typing import Dict, List, Tuple
import torch


class SimpleLabelDictionary:
    """
    Minimal stand-in for a fairseq ``Dictionary`` used for label lookups.

    Exposes the subset of that interface the helpers in this module use
    (``symbols``, ``nspecial``, ``index``, ``unk``, ``string`` and ``len``),
    so label handling works without fairseq installed.
    """

    def __init__(self, labels: List[str], nspecial: int = 5):
        """
        Args:
            labels: Full symbol list with special tokens first. A symbol's
                position in this list is its dictionary index. The list is
                stored as-is (not copied).
            nspecial: Number of leading entries of ``labels`` that are special
                tokens. Defaults to 5. Note that the schema-based fallback in
                ``create_label_dictionary_from_schema`` builds the layout
                ``<s>, <pad>, </s>, <unk>, <SEP>`` and passes 4.
        """
        self.symbols = labels
        self.nspecial = nspecial
        # Symbol -> index; if a symbol appears more than once, its last position wins.
        self._indices = dict(zip(labels, range(len(labels))))

    def index(self, label: str) -> int:
        """Return the dictionary index of ``label``, or ``unk()`` if it is absent."""
        try:
            return self._indices[label]
        except KeyError:
            return self.unk()

    def unk(self) -> int:
        """
        Return the index used for unknown labels.

        This is fixed at 3, the ``<unk>`` slot in the ``<s>, <pad>, </s>, <unk>``
        layout. It is not looked up in ``symbols``.
        """
        return 3

    def string(self, indices: torch.Tensor) -> str:
        """
        Render a tensor of dictionary indices as space-separated labels.

        A 0-d tensor is treated as a single index. Indices 0-3 (the special
        ``<s>``/``<pad>``/``</s>``/``<unk>`` slots) are dropped, as are indices
        outside ``[0, len(symbols))``. Returns ``""`` when nothing survives.
        """
        flat = indices.unsqueeze(0) if indices.dim() == 0 else indices
        skipped = {0, 1, 2, 3}
        size = len(self.symbols)

        kept = []
        for position in flat.tolist():
            if not (0 <= position < size) or position in skipped:
                continue
            kept.append(self.symbols[position])
        return " ".join(kept)

    def __len__(self) -> int:
        """Number of symbols, special tokens included."""
        return len(self.symbols)


def make_vec_idx_to_dict_idx(dictionary: SimpleLabelDictionary, labels: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Build a lookup table from label-vector positions to dictionary indices.

    Position ``i`` of the result holds ``dictionary.index(labels[i])``. The
    table has ``len(dictionary)`` entries, not ``len(labels)``. Positions at
    or beyond ``len(labels)`` keep ``fill_value``. A label missing from the
    dictionary is stored as whatever ``dictionary.index`` returns for it
    (``unk()`` for SimpleLabelDictionary).

    Args:
        dictionary: Label dictionary providing ``index`` and ``len``.
        labels: Labels in vector order.
        device: Device for the returned tensor.
        fill_value: Value for positions not covered by ``labels``.

    Returns:
        1-D ``torch.long`` tensor of length ``len(dictionary)``.

    Raises:
        IndexError: if ``labels`` is longer than the dictionary.
    """
    table = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    dict_indices = [dictionary.index(label) for label in labels]
    for position, dict_idx in enumerate(dict_indices):
        table[position] = dict_idx
    return table


def make_group_masks(dictionary: SimpleLabelDictionary, schema, device="cpu") -> torch.Tensor:
    """
    Build a 0/1 matrix marking which attribute groups apply to each category.

    Row ``i`` corresponds to ``schema.label_categories[i]`` and column ``j`` to
    ``schema.group_names[j]``. An entry is 1 when group ``j`` is listed for
    that category in ``schema.category_to_group_names``, and 0 otherwise.

    Args:
        dictionary: Label dictionary. Every category in
            ``schema.category_to_group_names`` must be present in it.
        schema: Label schema exposing ``group_names``, ``label_categories`` and
            ``category_to_group_names``.
        device: Device for the returned tensor.

    Returns:
        ``torch.int64`` tensor of shape
        ``(len(dictionary) - dictionary.nspecial, len(schema.group_names))``.
        Rows past the last category stay all zero.

    Raises:
        ValueError: if a category is not in ``dictionary`` (it resolves to
            ``unk()``). ``list.index`` also raises ValueError when a category
            or group name is missing from ``schema.label_categories`` or
            ``schema.group_names`` (assuming those are lists/tuples).
    """
    num_groups = len(schema.group_names)
    # Rows are sized to the non-special part of the dictionary, as in the
    # original fairseq-era layout; only the first len(label_categories) rows are used.
    num_labels = len(dictionary) - dictionary.nspecial
    ret_mask = torch.zeros(num_labels, num_groups, dtype=torch.int64, device=device)

    unk_idx = dictionary.unk()
    for cat, cat_group_names in schema.category_to_group_names.items():
        # Explicit check rather than `assert`, so validation survives `python -O`.
        if dictionary.index(cat) == unk_idx:
            raise ValueError(f"Category {cat!r} is not in the label dictionary")
        cat_vec_idx = schema.label_categories.index(cat)
        for group_name in cat_group_names:
            ret_mask[cat_vec_idx, schema.group_names.index(group_name)] = 1

    return ret_mask


def make_group_name_to_group_attr_vec_idxs(dictionary: SimpleLabelDictionary, schema) -> Dict[str, torch.Tensor]:
    """
    Map each attribute group name to the vector indices of its labels.

    A label's vector index is its dictionary index minus ``dictionary.nspecial``,
    i.e. its position among the non-special symbols.

    Args:
        dictionary: Label dictionary.
        schema: Label schema exposing ``group_name_to_labels`` (group name ->
            iterable of label strings).

    Returns:
        Dict, in ``schema.group_name_to_labels`` order, mapping each group name
        to a 1-D ``torch.long`` tensor of vector indices. An empty group yields
        an empty ``torch.long`` tensor.

    Note:
        A label missing from ``dictionary`` is not reported. It resolves to
        ``dictionary.unk()`` and therefore to ``unk() - nspecial``, which is
        negative whenever ``unk() < nspecial``.
    """
    offset = dictionary.nspecial
    return {
        name: torch.tensor(
            [dictionary.index(label) - offset for label in labels],
            # Explicit dtype: torch.tensor([]) would otherwise be float32,
            # which cannot be used to index other tensors.
            dtype=torch.long,
        )
        for name, labels in schema.group_name_to_labels.items()
    }


def make_dict_idx_to_vec_idx(dictionary: SimpleLabelDictionary, cats: List[str], device="cpu", fill_value=-100) -> torch.Tensor:
    """
    Build a lookup table from dictionary indices to category-vector positions.

    Entry ``dictionary.index(cats[i])`` of the result holds ``i``. Every other
    entry keeps ``fill_value``, so a target whose dictionary index is not a
    category maps to ``fill_value`` without any error being raised.

    Args:
        dictionary: Label dictionary providing ``index`` and ``len``.
        cats: Categories in vector order.
        device: Device for the returned tensor.
        fill_value: Value for dictionary indices that are not categories.

    Returns:
        1-D ``torch.long`` tensor of length ``len(dictionary)``.

    Note:
        When several categories resolve to the same dictionary index, the
        later one wins. With SimpleLabelDictionary, a category missing from the
        dictionary resolves to ``unk()`` and so writes into the unk slot.
    """
    lookup = torch.full((len(dictionary),), device=device, fill_value=fill_value, dtype=torch.long)
    dict_positions = [dictionary.index(cat) for cat in cats]
    for vec_idx, dict_idx in enumerate(dict_positions):
        lookup[dict_idx] = vec_idx
    return lookup


def clean_cats_attrs(ldict: SimpleLabelDictionary, schema, pred_cats: torch.Tensor, pred_attrs: torch.Tensor) -> List[Tuple[str, List[str]]]:
    """
    Turn predicted category and attribute indices into label strings.

    Args:
        ldict: Label dictionary used to render indices via ``ldict.string``.
        schema: Not read by this function; accepted for interface compatibility.
        pred_cats: Dictionary indices of the predicted categories, one per
            token. Must be at least 1-D: ``.tolist()`` on a 0-d tensor gives a
            scalar, which ``zip`` cannot iterate.
        pred_attrs: Dictionary indices of the predicted attributes. A 1-D
            tensor is treated as a single token's attributes. Otherwise each
            slice along dim 0 holds one token's attributes.

    Returns:
        List of ``(category, attributes)`` pairs. A token whose attributes
        render to nothing gets ``[]``. Pairs are formed positionally and
        truncated to the shortest input. Categories that ``ldict.string`` drops
        (e.g. special indices, for SimpleLabelDictionary) therefore shift the
        pairing.
    """
    cats = ldict.string(pred_cats).split(" ")

    # One tensor per token: keep a 1-D tensor whole, else slice rows off dim 0.
    per_token = [pred_attrs] if pred_attrs.dim() == 1 else pred_attrs.split(1, dim=0)

    attrs = []
    # pred_cats only bounds the number of iterations here; its values are unused.
    for _, token_attrs in zip(pred_cats.tolist(), per_token):
        rendered = ldict.string(token_attrs.squeeze()).split(" ")
        # "".split(" ") yields [""]; normalise an all-empty result to [].
        attrs.append(rendered if any(rendered) else [])

    return list(zip(cats, attrs))


def create_label_dictionary_from_schema(schema) -> SimpleLabelDictionary:
    """
    Create a SimpleLabelDictionary that mirrors the old fairseq label dictionary.

    The original fairseq ``dict_term.txt`` is preferred because it reproduces
    the exact symbol order and ``nspecial``. Candidate paths are tried in order,
    relative to the current working directory. The dictionary is rebuilt from
    ``schema.labels`` instead in three cases: fairseq is unavailable, no
    candidate file exists, or every existing candidate fails to load.

    Args:
        schema: Label schema object. Only ``schema.labels`` is read, and only
            on the fallback path.

    Returns:
        SimpleLabelDictionary holding either the original fairseq symbols, or
        ``<s> <pad> </s> <unk> <SEP>`` followed by the schema labels (minus any
        ``<SEP>``) with ``nspecial=4``.

    Warns:
        RuntimeWarning: when importing fairseq fails for a reason other than
            ``ImportError``, or an existing dictionary file fails to load.
            Neither case raises; the function falls back as described above.
    """
    import os
    import warnings

    # Candidate locations of the original dict_term.txt.
    possible_paths = [
        'scripts/dict_term.txt',
        'icebert-pos/scripts/dict_term.txt',
        '../scripts/dict_term.txt'
    ]

    try:
        from fairseq.data import Dictionary
    except ImportError:
        # fairseq is optional: this module exists to avoid depending on it.
        Dictionary = None
    except Exception as err:
        # A broken fairseq install can fail at import time with other errors;
        # fall back, but do not hide it.
        warnings.warn(
            f"Importing fairseq failed ({err!r}); rebuilding label dictionary from schema.",
            RuntimeWarning,
            stacklevel=2,
        )
        Dictionary = None

    if Dictionary is not None:
        for path in possible_paths:
            if not os.path.exists(path):
                continue
            try:
                original_dict = Dictionary.load(path)
                symbols, nspecial = original_dict.symbols, original_dict.nspecial
            except Exception as err:
                warnings.warn(
                    f"Could not load fairseq label dictionary from {path!r} ({err!r}); "
                    "trying remaining candidates before rebuilding from schema.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            # Use the exact symbols and special-token count of the original dictionary.
            return SimpleLabelDictionary(symbols, nspecial=nspecial)

    # Fallback: reconstruct from schema using the original special-token order.
    special_symbols = ["<s>", "<pad>", "</s>", "<unk>", "<SEP>"]

    # <SEP> is already in special_symbols; drop it from the schema labels so it is not duplicated.
    schema_labels_without_sep = [label for label in schema.labels if label != "<SEP>"]

    all_symbols = special_symbols + schema_labels_without_sep

    return SimpleLabelDictionary(all_symbols, nspecial=4)  # 4 special tokens before <SEP>