import numpy as np
import torch.nn.functional as F
from collections import OrderedDict


class Classifier(object):
    """Encode/decode mixed multi-class and multi-label classification targets.

    ``config`` maps a task kind (``MULTI_CLASS`` / ``MULTI_LABEL``) to a list
    of "items". Each item is a dict with:

    * ``name``     -- unique task name.
    * ``labels``   -- ordered ``(key, display_value)`` pairs.
    * ``column``   -- (multi-class) dataset column holding the label key.
    * ``columns``  -- (multi-label) mapping of column name to a
      ``{raw_value: 0/1}`` encoding map (assumed schema -- confirm against
      callers).
    * ``threshold``   -- (multi-label) decision threshold applied to logits.
    * ``loss_weight`` -- optional scale factor for the item's loss (default 1).

    All labels across all items are flattened into one logit/target vector;
    each item owns a contiguous slice of it (see ``indices``).
    """

    MULTI_CLASS = 'multi_class'
    MULTI_LABEL = 'multi_label'
    MODEL_CONFIG = 'classifier_config'

    # Populated by setup(); declared at class level for discoverability.
    id2label = None   # {int index: label key}
    label2id = None   # {label key: int index}
    num_labels = 0    # total flattened label count
    indices = {}      # {item name: range of indices into the label vector}

    def __init__(self, config):
        self.config = config
        self.setup()

    def setup(self):
        """Build the global label vocabulary and per-item index ranges."""
        all_items = [item for items in self.config.values() for item in items]
        # OrderedDict keeps first-seen order; duplicate label keys across
        # items would silently collapse here, so keys must stay unique.
        labels_dict = OrderedDict(
            [(k, v) for item in all_items for (k, v) in item['labels']])
        self.id2label = {idx: _l for (idx, _l) in enumerate(labels_dict)}
        self.label2id = {_l: idx for (idx, _l) in enumerate(labels_dict)}
        self.num_labels = len(self.id2label)
        self._compute_indices()

    def items(self, cls):
        """Return the configured items for task kind ``cls``.

        Returns an empty list (instead of raising KeyError) when the kind is
        absent, so the encode/decode/loss loops degrade gracefully.
        """
        return self.config.get(cls, [])

    def _compute_indices(self):
        """Assign each item a contiguous ``range`` of label-vector indices."""
        all_items = [item for items in self.config.values() for item in items]
        self.indices = {}
        offset = 0
        for item in all_items:
            width = len(OrderedDict(item['labels']))
            self.indices[item['name']] = range(offset, offset + width)
            offset += width

    def encode_labels(self, row):
        """Encode one dataset ``row`` (mapping column -> string) into a dense
        0/1 target vector of length ``num_labels``.

        Multi-class items set exactly one position in their slice; multi-label
        items set one position per configured column.
        """
        label_encodings = np.zeros(self.num_labels)
        for item in self.items(self.MULTI_CLASS):
            labels = OrderedDict(item['labels'])
            offset = self.indices[item['name']].start
            cls_label2id = {_l: i for (i, _l) in enumerate(labels.keys())}
            column_value = row[item['column']].strip()
            label_encodings[offset + cls_label2id[column_value]] = 1
        for item in self.items(self.MULTI_LABEL):
            offset = self.indices[item['name']].start
            columns = item['columns']
            # columns: {column_name: {raw_value: 0/1}} -- assumed; confirm.
            for (cidx, column_name) in enumerate(columns):
                value_map = columns[column_name]
                column_value = row[column_name].strip()
                label_encodings[offset + cidx] = value_map[column_value]
        return label_encodings

    def preds_from_logits(self, logits):
        """Convert raw logits ``(rows, num_labels)`` into a 0/1 matrix.

        Multi-class slices take the argmax; multi-label slices are thresholded
        per the item's ``threshold``.
        """
        preds = np.zeros_like(logits)
        (rows, _) = preds.shape
        for item in self.items(self.MULTI_CLASS):
            cls_indices = self.indices[item['name']]
            best_classes = np.argmax(logits[:, cls_indices], axis=-1)
            preds[np.arange(rows), best_classes + cls_indices.start] = 1
        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            threshold = item['threshold']
            preds[:, cls_indices] = (
                logits[:, cls_indices] >= threshold).astype(float)
        return preds

    def compute_losses(self, logits, labels):
        """Return per-kind training losses as a dict.

        Cross-entropy over each multi-class slice, BCE-with-logits over each
        multi-label slice, each scaled by its optional ``loss_weight``.

        Bugfix: the original ``sum(*losses)`` unpacked the list so the second
        element became ``sum()``'s *start* argument -- it mis-summed tensor
        elements with one or two items and raised TypeError with zero or
        three-plus items. ``sum(losses)`` is the correct reduction.
        """
        multi_class_losses = []
        multi_label_losses = []
        for item in self.items(self.MULTI_CLASS):
            cls_indices = self.indices[item['name']]
            cls_loss_weight = item.get('loss_weight', 1)
            cls_loss = F.cross_entropy(
                logits[:, cls_indices],
                labels[:, cls_indices]).unsqueeze(dim=0)
            multi_class_losses.append(cls_loss_weight * cls_loss)
        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            cls_loss_weight = item.get('loss_weight', 1)
            cls_loss = F.binary_cross_entropy_with_logits(
                logits[:, cls_indices],
                labels[:, cls_indices]).unsqueeze(dim=0)
            multi_label_losses.append(cls_loss_weight * cls_loss)
        return {
            self.MULTI_CLASS: sum(multi_class_losses),
            self.MULTI_LABEL: sum(multi_label_losses),
        }

    def get_results(self, logits):
        """Decode logits into human-readable results, one dict per row.

        Multi-class items map to ``{'key': ..., 'value': ...}`` (``None`` when
        no label in the slice was predicted); multi-label items map to the
        list of predicted display values.
        """
        predictions = self.preds_from_logits(logits)
        decoded_predictions = [
            [self.id2label[i] for (i, flag) in enumerate(row) if flag == 1]
            for row in predictions
        ]
        results = []
        for decoded in decoded_predictions:
            result = {}
            for item in self.items(self.MULTI_CLASS):
                cls_labels = OrderedDict(item['labels'])
                key = next((_l for _l in decoded if _l in cls_labels), None)
                result[item['name']] = {
                    'key': key,
                    'value': None if key is None else cls_labels[key],
                }
            for item in self.items(self.MULTI_LABEL):
                cls_labels = OrderedDict(item['labels'])
                result[item['name']] = [
                    cls_labels[_l] for _l in decoded if _l in cls_labels]
            results.append(result)
        return results

    def random_logits(self, num_rows=1):
        """Uniform random logits in [-2, 2), shape (num_rows, num_labels)."""
        return np.random.uniform(-2, 2, (num_rows, self.num_labels))