File size: 5,912 Bytes
692d9ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict

class Classifier(object):
    """Encode, decode and score a mix of multi-class and multi-label heads.

    ``config`` maps a head kind (``MULTI_CLASS`` / ``MULTI_LABEL``) to a list
    of item dicts. Every item has a ``name`` and ``labels`` (a sequence of
    ``(key, value)`` pairs). All items share one flat label vector of length
    ``num_labels``; item ``name`` owns the contiguous columns
    ``indices[name]`` (a ``range``), assigned in config iteration order.

    Extra item keys read by the methods below:

    * multi-class: ``column`` (row field holding the label key) and an
      optional ``loss_weight`` (default 1);
    * multi-label: ``columns`` (mapping of row field -> mapping of stripped
      field value -> number to store; one field per label, in label order),
      ``threshold`` and an optional ``loss_weight`` (default 1).
    """

    MULTI_CLASS = 'multi_class'
    MULTI_LABEL = 'multi_label'
    MODEL_CONFIG = 'classifier_config'

    # Class-level defaults only; every instance replaces them in setup().
    id2label = None
    label2id = None
    num_labels = 0
    indices = {}

    def __init__(self, config):
        self.config = config
        self.setup()

    def setup(self):
        """Build ``id2label``/``label2id``/``num_labels`` and the column ranges.

        Raises:
            ValueError: if two items share a label key. The flat vocabulary is
                keyed by label, so a shared key would make ``num_labels``
                smaller than the columns reserved by the per-item ranges and
                encoding/decoding would silently misalign.
        """
        # Label key -> owning item name, in first-seen (== column) order.
        owner = {}
        for items in self.config.values():
            for item in items:
                # OrderedDict collapses duplicates *within* one item exactly
                # like _compute_indices does, so only cross-item clashes fail.
                for key in OrderedDict(item['labels']):
                    if key in owner:
                        raise ValueError(
                            'label %r is used by both %r and %r; label keys '
                            'must be unique across items'
                            % (key, owner[key], item['name']))
                    owner[key] = item['name']

        self.id2label = dict(enumerate(owner))
        self.label2id = {label: idx for idx, label in self.id2label.items()}
        self.num_labels = len(self.id2label)

        self._compute_indices()

    def items(self, cls):
        """Return the item list configured for head kind ``cls``.

        A kind missing from ``config`` yields an empty list, so configs with
        only multi-class or only multi-label heads are supported.
        """
        return self.config.get(cls, [])

    def _compute_indices(self):
        """Assign each item a contiguous ``range`` of label-vector columns."""
        self.indices = {}
        offset = 0
        for items in self.config.values():
            for item in items:
                width = len(OrderedDict(item['labels']))
                self.indices[item['name']] = range(offset, offset + width)
                offset += width

    def _span(self, name):
        """Return item ``name``'s columns as a slice.

        A slice is a basic index for both numpy arrays and torch tensors.
        """
        cols = self.indices[name]
        return slice(cols.start, cols.stop)

    def encode_labels(self, row):
        """Encode one data ``row`` into a float vector of length ``num_labels``.

        ``row`` maps field names to string values; values are ``.strip()``-ed
        before lookup. Multi-class heads get a one-hot in their columns.
        Multi-label heads store, per label column, the number that the item's
        ``columns`` mapping gives for the field value.

        Raises:
            KeyError: if a field is missing from ``row`` or holds a value
                the item does not define.
        """
        label_encodings = np.zeros(self.num_labels)

        for item in self.items(self.MULTI_CLASS):
            offset = self.indices[item['name']].start
            position = {
                key: i for i, key in enumerate(OrderedDict(item['labels']))
            }
            value = row[item['column']].strip()
            label_encodings[offset + position[value]] = 1

        for item in self.items(self.MULTI_LABEL):
            offset = self.indices[item['name']].start
            # The i-th field in ``columns`` fills the item's i-th label column.
            for i, (column_name, value_map) in enumerate(item['columns'].items()):
                label_encodings[offset + i] = value_map[row[column_name].strip()]

        return label_encodings

    def preds_from_logits(self, logits):
        """Turn a ``(rows, num_labels)`` logit array into a 0/1 prediction array.

        Multi-class heads mark their argmax column. Multi-label heads mark
        every column whose raw logit is ``>= item['threshold']``. No sigmoid
        is applied, so a threshold of 0 corresponds to probability 0.5.
        """
        preds = np.zeros_like(logits)
        rows, _ = preds.shape

        for item in self.items(self.MULTI_CLASS):
            span = self._span(item['name'])
            best = np.argmax(logits[:, span], axis=-1)
            preds[np.arange(rows), best + span.start] = 1

        for item in self.items(self.MULTI_LABEL):
            span = self._span(item['name'])
            preds[:, span] = (logits[:, span] >= item['threshold']).astype(float)

        return preds

    def compute_losses(self, logits, labels):
        """Return the weighted losses summed per head kind.

        Assumes ``logits`` and ``labels`` are float torch tensors of shape
        ``(batch, num_labels)``, with labels laid out as ``encode_labels``
        produces them (TODO confirm with callers). Multi-class heads use
        ``F.cross_entropy`` with the label columns as class-probability
        targets. Multi-label heads use ``F.binary_cross_entropy_with_logits``.

        Returns:
            ``{MULTI_CLASS: loss, MULTI_LABEL: loss}``. Each value is a 0-dim
            tensor, and it is zero when that kind has no heads.
        """
        # Start value for sum(): keeps the result a tensor on logits'
        # device/dtype even when a kind has no heads.
        zero = logits.new_zeros(())

        multi_class_losses = []
        for item in self.items(self.MULTI_CLASS):
            span = self._span(item['name'])
            loss = F.cross_entropy(logits[:, span], labels[:, span])
            multi_class_losses.append(item.get('loss_weight', 1) * loss)

        multi_label_losses = []
        for item in self.items(self.MULTI_LABEL):
            span = self._span(item['name'])
            loss = F.binary_cross_entropy_with_logits(logits[:, span], labels[:, span])
            multi_label_losses.append(item.get('loss_weight', 1) * loss)

        # The list is passed as ONE iterable; unpacking it into sum() would
        # misread the second loss as the start value and fail on 0 or 3+ heads.
        return {
            self.MULTI_CLASS: sum(multi_class_losses, zero),
            self.MULTI_LABEL: sum(multi_label_losses, zero),
        }

    def get_results(self, logits):
        """Decode ``logits`` into one result dict per row.

        Multi-class item ``name`` maps to ``{'key': label_key, 'value':
        label_value}``; both are ``None`` if none of its columns is set.
        Multi-label item ``name`` maps to the list of label values whose
        column is set, in label order.
        """
        predictions = self.preds_from_logits(logits)
        results = []

        for row in predictions:
            result = {}

            for item in self.items(self.MULTI_CLASS):
                cls_labels = OrderedDict(item['labels'])
                key = next(
                    (self.id2label[i] for i in self.indices[item['name']] if row[i] == 1),
                    None,
                )
                result[item['name']] = {
                    'key': key,
                    'value': None if key is None else cls_labels[key],
                }

            for item in self.items(self.MULTI_LABEL):
                cls_labels = OrderedDict(item['labels'])
                result[item['name']] = [
                    cls_labels[self.id2label[i]]
                    for i in self.indices[item['name']] if row[i] == 1
                ]

            results.append(result)

        return results

    def random_logits(self, num_rows=1):
        """Return uniform random logits in [-2, 2) of shape ``(num_rows, num_labels)``."""
        return np.random.uniform(-2, 2, (num_rows, self.num_labels))