|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
from collections import OrderedDict |
|
|
|
|
|
class Classifier(object):
    """Label codec for a mixed multi-class / multi-label classification head.

    ``config`` maps a task kind (:attr:`MULTI_CLASS` / :attr:`MULTI_LABEL`)
    to a list of "items"; each item owns a contiguous slice of one global
    label vector of length :attr:`num_labels`.  Each item is a dict with:

    * ``name``        -- unique item identifier (keys :attr:`indices`).
    * ``labels``      -- ordered ``(key, value)`` pairs; keys build the global
                         label space, values are human-readable names.
                         NOTE(review): keys are assumed globally unique —
                         duplicates would collapse in the label maps while
                         ``_compute_indices`` still counts them, desyncing
                         ``num_labels``.
    * ``column``      -- (multi-class only) row column holding the class key.
    * ``columns``     -- (multi-label only) mapping of row column name to a
                         ``{cell value -> 0/1}`` table, one column per label,
                         in the same order as ``labels``.
    * ``threshold``   -- (multi-label only) decision threshold on the logits.
    * ``loss_weight`` -- optional per-item loss scaling factor (default 1).
    """

    MULTI_CLASS = 'multi_class'
    MULTI_LABEL = 'multi_label'
    MODEL_CONFIG = 'classifier_config'

    # Class-level defaults; real values are assigned per-instance in setup().
    id2label = None    # {global index -> label key}
    label2id = None    # {label key -> global index}
    num_labels = 0     # total size of the global label vector
    indices = {}       # {item name -> range of that item's slice}

    def __init__(self, config):
        self.config = config
        self.setup()

    def setup(self):
        """Build the global label maps and per-item index ranges."""
        all_items = [item for items in self.config.values() for item in items]
        # OrderedDict keeps first-seen order of the label keys across items.
        labels_dict = OrderedDict([(k, v) for item in all_items for (k, v) in item['labels']])

        self.id2label = {idx: _l for (idx, _l) in enumerate(labels_dict)}
        self.label2id = {_l: idx for (idx, _l) in enumerate(labels_dict)}
        self.num_labels = len(self.id2label)

        self._compute_indices()

    def items(self, cls):
        """Return the configured items for task kind ``cls``.

        A missing kind yields an empty list so partially-specified configs
        (e.g. multi-class only) work throughout the class.
        """
        return self.config.get(cls, [])

    def _compute_indices(self):
        """Assign each item a contiguous index range in the label vector."""
        all_items = [item for items in self.config.values() for item in items]

        self.indices = {}

        range_offset = 0
        for item in all_items:
            cls_labels = OrderedDict(item['labels'])
            name = item['name']

            range_start = range_offset
            range_end = range_start + len(cls_labels)

            self.indices[name] = range(range_start, range_end)
            range_offset = range_end

    def encode_labels(self, row):
        """Encode one data row into a 0/1 vector of length ``num_labels``.

        ``row`` is a mapping of column name -> string cell value; values are
        stripped of surrounding whitespace before lookup.  Unknown cell
        values raise ``KeyError`` (fail fast on bad data).
        """
        label_encodings = np.zeros(self.num_labels)

        for item in self.items(self.MULTI_CLASS):
            labels = OrderedDict(item['labels'])
            cls_indices = self.indices[item['name']]
            column_name = item['column']
            offset = cls_indices.start
            cls_label2id = {_l: i for (i, _l) in enumerate(labels.keys())}
            column_value = row[column_name].strip()

            # Exactly one hot position inside this item's slice.
            label_encodings[offset + cls_label2id[column_value]] = 1

        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            offset = cls_indices.start
            columns = item['columns']

            # One source column per label position, in slice order.
            for (cidx, column_name) in enumerate(columns):
                cls_label2id = columns[column_name]
                column_value = row[column_name].strip()

                label_encodings[offset + cidx] = cls_label2id[column_value]

        return label_encodings

    def preds_from_logits(self, logits):
        """Turn ``(rows, num_labels)`` logits into 0/1 predictions.

        Multi-class items: argmax within the item's slice (one winner).
        Multi-label items: every label at or above the item's threshold
        fires independently.
        """
        preds = np.zeros_like(logits)
        (rows, _) = preds.shape

        for item in self.items(self.MULTI_CLASS):
            cls_indices = self.indices[item['name']]
            index_offset = cls_indices.start
            best_classes = np.argmax(logits[:, cls_indices], axis=-1)
            preds[np.arange(rows), [i + index_offset for i in best_classes]] = 1

        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            threshold = item['threshold']

            preds[:, cls_indices] = (logits[:, cls_indices] >= threshold).astype(float)

        return preds

    def compute_losses(self, logits, labels):
        """Compute per-kind weighted losses over torch tensors.

        Returns ``{MULTI_CLASS: tensor, MULTI_LABEL: tensor}`` (a key is
        omitted when no items of that kind are configured).  Each loss is a
        shape-``(1,)`` tensor: the loss_weight-scaled sum over that kind's
        items of cross-entropy (multi-class) / BCE-with-logits (multi-label)
        on the item's slice.

        BUG FIX: this previously used ``sum(*loss_list)``, which unpacks the
        list into ``sum()``'s arguments — the second tensor became the
        ``start`` value, a single-item list yielded a 0-d tensor, and an
        empty list raised ``TypeError``.  ``sum(loss_list)`` adds the
        tensors element-wise as intended.
        """
        multi_class_losses = []
        multi_label_losses = []

        losses = {}

        for item in self.items(self.MULTI_CLASS):
            cls_indices = self.indices[item['name']]
            cls_loss_weight = item.get('loss_weight', 1)
            # One-hot float targets act as (soft) class probabilities here.
            cls_loss = F.cross_entropy(logits[:, cls_indices], labels[:, cls_indices]).unsqueeze(dim=0)

            multi_class_losses.append(cls_loss_weight * cls_loss)

        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            cls_loss_weight = item.get('loss_weight', 1)
            cls_loss = F.binary_cross_entropy_with_logits(logits[:, cls_indices], labels[:, cls_indices]).unsqueeze(dim=0)

            multi_label_losses.append(cls_loss_weight * cls_loss)

        # was: sum(*multi_class_losses) / sum(*multi_label_losses) -- see
        # docstring.  Guard the empty case instead of crashing.
        if multi_class_losses:
            losses[self.MULTI_CLASS] = sum(multi_class_losses)
        if multi_label_losses:
            losses[self.MULTI_LABEL] = sum(multi_label_losses)

        return losses

    def get_results(self, logits):
        """Decode logits into one human-readable result dict per row.

        Multi-class items map to ``{'key': label key or None,
        'value': label value or None}``; multi-label items map to the list
        of fired label values.
        """
        predictions = self.preds_from_logits(logits)
        decoded_predictions = [
            [self.id2label[i] for (i, _l) in enumerate(row) if _l == 1]
            for row in predictions
        ]

        results = []

        for decoded in decoded_predictions:

            result = {}

            for item in self.items(self.MULTI_CLASS):
                cls_labels = OrderedDict(item['labels'])
                name = item['name']

                # First decoded key belonging to this item (at most one can
                # fire, since preds_from_logits one-hots the slice).
                key = next((_l for _l in decoded if _l in cls_labels), None)

                if key is None:
                    value = None
                else:
                    value = cls_labels[key]

                result[name] = {
                    'key': key,
                    'value': value,
                }

            for item in self.items(self.MULTI_LABEL):
                cls_labels = OrderedDict(item['labels'])
                name = item['name']

                result[name] = [cls_labels[_l] for _l in decoded if _l in cls_labels]

            results.append(result)

        return results

    def random_logits(self, num_rows=1):
        """Uniform random logits in [-2, 2) for smoke-testing the pipeline."""
        return np.random.uniform(-2, 2, (num_rows, self.num_labels))