# bert-question-classifier / classifier.py
# Hugging Face upload "Upload MultiTaskClassifierPipeline" (commit 692d9ea)
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
class Classifier(object):
    """Label bookkeeping for a multi-task (multi-class + multi-label) head.

    ``config`` maps a task kind (``MULTI_CLASS`` / ``MULTI_LABEL``) to a list
    of "item" dicts.  Every item owns a contiguous slice of a flat label
    vector of width ``num_labels``; this class tracks those slices and
    converts between dataset rows, logits, and decoded results.

    Item shape (inferred from usage below — confirm against the training
    config):
      * common:      ``name`` (str), ``labels`` (list of (key, value) pairs)
      * multi_class: ``column`` — dataset column holding the class key
      * multi_label: ``columns`` — mapping column name -> {cell value: 0/1},
                     ``threshold`` — logit cutoff used at prediction time
      * optional:    ``loss_weight`` — scale factor for the item's loss
    """

    MULTI_CLASS = 'multi_class'
    MULTI_LABEL = 'multi_label'
    MODEL_CONFIG = 'classifier_config'

    # Class-level defaults; setup() shadows them with instance attributes.
    id2label = None   # index -> label key
    label2id = None   # label key -> index
    num_labels = 0    # total width of the flat label vector
    indices = {}      # item name -> range of columns in the label vector

    def __init__(self, config):
        self.config = config
        self.setup()

    def _all_items(self):
        """Flatten every task kind's item list, in declaration order."""
        return [item for items in self.config.values() for item in items]

    def setup(self):
        """Build the global label<->index maps and per-item index ranges."""
        # OrderedDict de-duplicates keys while keeping first-seen order.
        labels_dict = OrderedDict(
            (k, v) for item in self._all_items() for (k, v) in item['labels']
        )
        self.id2label = dict(enumerate(labels_dict))
        self.label2id = {label: idx for idx, label in enumerate(labels_dict)}
        self.num_labels = len(self.id2label)
        self._compute_indices()

    def items(self, cls):
        """Return the configured items for task kind ``cls``.

        Returns [] when the kind is absent, so a config with only one task
        kind no longer raises KeyError in the loops below.
        """
        return self.config.get(cls, [])

    def _compute_indices(self):
        """Assign each item a contiguous ``range`` of label columns."""
        self.indices = {}
        offset = 0
        for item in self._all_items():
            width = len(OrderedDict(item['labels']))
            self.indices[item['name']] = range(offset, offset + width)
            offset += width

    def encode_labels(self, row):
        """Encode one dataset ``row`` (mapping column -> str) into a flat
        0/1 numpy vector of length ``num_labels``."""
        encoded = np.zeros(self.num_labels)
        for item in self.items(self.MULTI_CLASS):
            labels = OrderedDict(item['labels'])
            offset = self.indices[item['name']].start
            label2id = {label: i for i, label in enumerate(labels)}
            value = row[item['column']].strip()
            # One-hot: exactly one bit set inside this item's slice.
            encoded[offset + label2id[value]] = 1
        for item in self.items(self.MULTI_LABEL):
            offset = self.indices[item['name']].start
            columns = item['columns']
            for cidx, column_name in enumerate(columns):
                # columns[column_name] maps the raw cell value to 0/1.
                value_map = columns[column_name]
                encoded[offset + cidx] = value_map[row[column_name].strip()]
        return encoded

    def preds_from_logits(self, logits):
        """Turn a (rows, num_labels) logit matrix into 0/1 predictions:
        argmax inside each multi-class slice, per-item threshold inside
        each multi-label slice."""
        preds = np.zeros_like(logits)
        rows = preds.shape[0]
        for item in self.items(self.MULTI_CLASS):
            cls_indices = self.indices[item['name']]
            best = np.argmax(logits[:, cls_indices], axis=-1)
            preds[np.arange(rows), best + cls_indices.start] = 1
        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            threshold = item['threshold']
            preds[:, cls_indices] = (logits[:, cls_indices] >= threshold).astype(float)
        return preds

    def compute_losses(self, logits, labels):
        """Compute per-kind losses from torch tensors.

        Cross-entropy (with the one-hot slice of ``labels`` as probability
        targets) for each multi-class item; BCE-with-logits for each
        multi-label item.  Each item's loss is scaled by its optional
        ``loss_weight`` (default 1).

        Returns {MULTI_CLASS: summed loss, MULTI_LABEL: summed loss}; a
        kind with no items contributes 0.
        """
        multi_class_losses = []
        multi_label_losses = []
        for item in self.items(self.MULTI_CLASS):
            sl = self.indices[item['name']]
            weight = item.get('loss_weight', 1)
            loss = F.cross_entropy(logits[:, sl], labels[:, sl]).unsqueeze(dim=0)
            multi_class_losses.append(weight * loss)
        for item in self.items(self.MULTI_LABEL):
            sl = self.indices[item['name']]
            weight = item.get('loss_weight', 1)
            loss = F.binary_cross_entropy_with_logits(
                logits[:, sl], labels[:, sl]
            ).unsqueeze(dim=0)
            multi_label_losses.append(weight * loss)
        # BUG FIX: the original called sum(*losses), which unpacks the list
        # into sum(iterable, start) — for a single item it iterates the loss
        # tensor's elements (wrong shape), and for zero or 3+ items it
        # raises TypeError.  sum(list) is correct for any length.
        return {
            self.MULTI_CLASS: sum(multi_class_losses),
            self.MULTI_LABEL: sum(multi_label_losses),
        }

    def get_results(self, logits):
        """Decode a logit matrix into one result dict per row.

        Multi-class items map to {'key': label key or None, 'value': its
        configured value or None}; multi-label items map to the list of
        configured values whose columns fired.
        """
        predictions = self.preds_from_logits(logits)
        decoded_rows = [
            [self.id2label[i] for i, flag in enumerate(row) if flag == 1]
            for row in predictions
        ]
        results = []
        for decoded in decoded_rows:
            result = {}
            for item in self.items(self.MULTI_CLASS):
                labels = OrderedDict(item['labels'])
                # First decoded label belonging to this item (None if the
                # argmax never set a bit here — shouldn't happen in practice).
                key = next((l for l in decoded if l in labels), None)
                result[item['name']] = {
                    'key': key,
                    'value': None if key is None else labels[key],
                }
            for item in self.items(self.MULTI_LABEL):
                labels = OrderedDict(item['labels'])
                result[item['name']] = [labels[l] for l in decoded if l in labels]
            results.append(result)
        return results

    def random_logits(self, num_rows=1):
        """Uniform random logits in [-2, 2) — handy for smoke tests."""
        return np.random.uniform(-2, 2, (num_rows, self.num_labels))