lekhnathrijal commited on
Commit
692d9ea
·
1 Parent(s): fc3e987

Upload MultiTaskClassifierPipeline

Browse files
classifier.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn.functional as F
3
+ from collections import OrderedDict
4
+
5
class Classifier(object):
    """Label codec for a multi-task text classifier over one flat logit vector.

    Two task kinds share the same logit vector, each task owning a contiguous
    slice of it:

    * ``MULTI_CLASS`` -- exactly one label per task (decoded via argmax).
    * ``MULTI_LABEL`` -- any subset of labels per task (decoded via threshold).

    ``config`` maps a task kind (``'multi_class'`` / ``'multi_label'``) to a
    list of task items.  Every item carries ``name`` and ``labels`` (an ordered
    list of ``(key, value)`` pairs).  A multi-class item also has ``column``
    (the dataset column holding the label key); a multi-label item has
    ``columns`` (mapping of column name to a ``{column_value: 0/1}`` encoding)
    and a ``threshold`` used when decoding predictions.

    NOTE(review): label keys are assumed to be globally unique across tasks --
    a duplicated key would collapse in the vocabulary built by ``setup`` while
    ``_compute_indices`` still reserves a slot per task, desynchronizing the
    two.  Confirm against the training configs.
    """

    MULTI_CLASS = 'multi_class'
    MULTI_LABEL = 'multi_label'
    # Key under model_config.task_specific_params where this config is stored.
    MODEL_CONFIG = 'classifier_config'

    # Populated by setup(); declared here for discoverability.
    id2label = None     # {index: label_key}
    label2id = None     # {label_key: index}
    num_labels = 0      # total width of the flat label vector
    indices = {}        # {task_name: range over the label vector}

    def __init__(self, config):
        self.config = config
        self.setup()

    def setup(self):
        """Build the flat label vocabulary and the per-task index ranges."""
        all_items = [item for items in self.config.values() for item in items]
        # OrderedDict keeps the first-seen order of label keys across tasks.
        labels_dict = OrderedDict([(k, v) for item in all_items for (k, v) in item['labels']])

        self.id2label = {idx: _l for (idx, _l) in enumerate(labels_dict)}
        self.label2id = {_l: idx for (idx, _l) in enumerate(labels_dict)}
        self.num_labels = len(self.id2label)

        self._compute_indices()

    def items(self, cls):
        """Return the task items configured for kind ``cls`` ([] when absent)."""
        # .get keeps configs that define only one task kind working.
        return self.config.get(cls, [])

    def _compute_indices(self):
        """Assign each task a contiguous ``range`` of columns in the label vector."""
        all_items = [item for items in self.config.values() for item in items]

        self.indices = {}

        range_offset = 0
        for item in all_items:
            cls_labels = OrderedDict(item['labels'])

            range_start = range_offset
            range_end = range_start + len(cls_labels)

            self.indices[item['name']] = range(range_start, range_end)
            range_offset = range_end

    def encode_labels(self, row):
        """Encode one ``row`` (mapping of column name -> string) into a
        multi-hot numpy vector of length ``num_labels``.

        Raises KeyError when a column value is not a configured label key.
        """
        label_encodings = np.zeros(self.num_labels)

        for item in self.items(self.MULTI_CLASS):
            labels = OrderedDict(item['labels'])
            cls_indices = self.indices[item['name']]
            offset = cls_indices.start
            cls_label2id = {_l: i for (i, _l) in enumerate(labels.keys())}
            column_value = row[item['column']].strip()

            # One-hot within this task's slice.
            label_encodings[offset + cls_label2id[column_value]] = 1

        for item in self.items(self.MULTI_LABEL):
            offset = self.indices[item['name']].start
            columns = item['columns']

            # Each column contributes one slot; its mapping turns the raw
            # column value into a 0/1 indicator.
            for (cidx, column_name) in enumerate(columns):
                cls_label2id = columns[column_name]
                column_value = row[column_name].strip()

                label_encodings[offset + cidx] = cls_label2id[column_value]

        return label_encodings

    def preds_from_logits(self, logits):
        """Turn a ``(rows, num_labels)`` logit array into 0/1 predictions.

        Multi-class slices get a one-hot argmax; multi-label slices are
        thresholded per item.
        """
        preds = np.zeros_like(logits)
        (rows, _) = preds.shape

        for item in self.items(self.MULTI_CLASS):
            cls_indices = self.indices[item['name']]
            best_classes = np.argmax(logits[:, cls_indices], axis=-1)
            preds[np.arange(rows), best_classes + cls_indices.start] = 1

        for item in self.items(self.MULTI_LABEL):
            cls_indices = self.indices[item['name']]
            threshold = item['threshold']

            preds[:, cls_indices] = (logits[:, cls_indices] >= threshold).astype(float)

        return preds

    def compute_losses(self, logits, labels):
        """Compute one combined loss tensor per task kind.

        Multi-class tasks use cross-entropy over their logit slice (labels are
        the one-hot / soft targets produced by ``encode_labels``); multi-label
        tasks use binary cross-entropy with logits.  Per-task losses are scaled
        by the optional ``loss_weight`` item key (default 1) and summed.

        Returns a dict ``{MULTI_CLASS: tensor, MULTI_LABEL: tensor}``.
        """
        multi_class_losses = []
        multi_label_losses = []

        for item in self.items(self.MULTI_CLASS):
            # list() makes the torch fancy-indexing explicit.
            cls_indices = list(self.indices[item['name']])
            cls_loss_weight = item.get('loss_weight', 1)
            cls_loss = F.cross_entropy(logits[:, cls_indices], labels[:, cls_indices]).unsqueeze(dim=0)

            multi_class_losses.append(cls_loss_weight * cls_loss)

        for item in self.items(self.MULTI_LABEL):
            cls_indices = list(self.indices[item['name']])
            cls_loss_weight = item.get('loss_weight', 1)
            cls_loss = F.binary_cross_entropy_with_logits(logits[:, cls_indices], labels[:, cls_indices]).unsqueeze(dim=0)

            multi_label_losses.append(cls_loss_weight * cls_loss)

        # BUG FIX: the original `sum(*losses)` unpacked the list into sum()'s
        # arguments, raising TypeError for zero tasks or three-or-more tasks
        # of one kind (sum() takes at most two arguments).  `sum(losses)`
        # adds the loss tensors element-wise as intended.
        return {
            self.MULTI_CLASS: sum(multi_class_losses),
            self.MULTI_LABEL: sum(multi_label_losses),
        }

    def get_results(self, logits):
        """Decode logits into one result dict per row.

        Multi-class tasks map to ``{'key': label_key_or_None, 'value': ...}``;
        multi-label tasks map to the list of predicted label values.
        """
        predictions = self.preds_from_logits(logits)
        decoded_predictions = [
            [self.id2label[i] for (i, _l) in enumerate(row) if _l == 1]
            for row in predictions
        ]

        results = []

        for decoded in decoded_predictions:
            result = {}

            for item in self.items(self.MULTI_CLASS):
                cls_labels = OrderedDict(item['labels'])

                # The argmax guarantees at most one of this task's keys is
                # present in `decoded`.
                key = next((_l for _l in decoded if _l in cls_labels), None)
                value = None if key is None else cls_labels[key]

                result[item['name']] = {
                    'key': key,
                    'value': value,
                }

            for item in self.items(self.MULTI_LABEL):
                cls_labels = OrderedDict(item['labels'])

                result[item['name']] = [cls_labels[_l] for _l in decoded if _l in cls_labels]

            results.append(result)

        return results

    def random_logits(self, num_rows=1):
        """Return uniform random logits in [-2, 2) -- handy for smoke tests."""
        return np.random.uniform(-2, 2, (num_rows, self.num_labels))
classifier_pipeline.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+ from .classifier import Classifier
3
+
4
class MultiTaskClassifierPipeline(Pipeline):
    """Hugging Face custom pipeline for the multi-task sequence classifier.

    Tokenizes text to fixed-length tensors, runs the model, and decodes the
    logits through ``Classifier`` using the label configuration stored in the
    model config's ``task_specific_params``.
    """

    def _sanitize_parameters(self, **kwargs):
        # No runtime-configurable parameters yet: everything runs on defaults.
        preprocess_kwargs = {}
        postprocess_kwargs = {}

        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, inputs):
        """Tokenize ``inputs`` to max-length padded tensors on the pipeline device."""
        return self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            return_tensors=self.framework,
        ).to(self.device)

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Decode the model's logits into the per-task result dict for one input."""
        model_config = self.model.config
        classifier = Classifier(model_config.task_specific_params[Classifier.MODEL_CONFIG])
        # BUG FIX: plain .numpy() raises for tensors that live on an
        # accelerator (preprocess moves inputs to self.device) or that still
        # require grad; detach and move to CPU first.
        logits = model_outputs.logits.detach().cpu().numpy()

        return classifier.get_results(logits)[0]
config.json CHANGED
@@ -1,10 +1,28 @@
1
  {
2
- "_name_or_path": "google-bert/bert-large-uncased",
3
  "architectures": [
4
  "BertForSequenceClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "classifier_dropout": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  "gradient_checkpointing": false,
9
  "hidden_act": "gelu",
10
  "hidden_dropout_prob": 0.1,
 
1
  {
2
+ "_name_or_path": "ai-research-lab/bert-question-classifier",
3
  "architectures": [
4
  "BertForSequenceClassification"
5
  ],
6
  "attention_probs_dropout_prob": 0.1,
7
  "classifier_dropout": null,
8
+ "custom_pipelines": {
9
+ "question-classifier": {
10
+ "default": {
11
+ "model": {
12
+ "pt": [
13
+ "ai-research-lab/bert-question-classifier",
14
+ "main"
15
+ ]
16
+ }
17
+ },
18
+ "impl": "classifier_pipeline.MultiTaskClassifierPipeline",
19
+ "pt": [
20
+ "AutoModelForSequenceClassification"
21
+ ],
22
+ "tf": [],
23
+ "type": "text"
24
+ }
25
+ },
26
  "gradient_checkpointing": false,
27
  "hidden_act": "gelu",
28
  "hidden_dropout_prob": 0.1,
special_tokens_map.json CHANGED
@@ -1,7 +1,37 @@
1
  {
2
- "cls_token": "[CLS]",
3
- "mask_token": "[MASK]",
4
- "pad_token": "[PAD]",
5
- "sep_token": "[SEP]",
6
- "unk_token": "[UNK]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  }
 
1
  {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
  }
tokenizer.json CHANGED
@@ -6,7 +6,16 @@
6
  "strategy": "LongestFirst",
7
  "stride": 0
8
  },
9
- "padding": null,
 
 
 
 
 
 
 
 
 
10
  "added_tokens": [
11
  {
12
  "id": 0,
 
6
  "strategy": "LongestFirst",
7
  "stride": 0
8
  },
9
+ "padding": {
10
+ "strategy": {
11
+ "Fixed": 512
12
+ },
13
+ "direction": "Right",
14
+ "pad_to_multiple_of": null,
15
+ "pad_id": 0,
16
+ "pad_type_id": 0,
17
+ "pad_token": "[PAD]"
18
+ },
19
  "added_tokens": [
20
  {
21
  "id": 0,
tokenizer_config.json CHANGED
@@ -50,8 +50,11 @@
50
  "model_max_length": 512,
51
  "pad_token": "[PAD]",
52
  "sep_token": "[SEP]",
 
53
  "strip_accents": null,
54
  "tokenize_chinese_chars": true,
55
  "tokenizer_class": "BertTokenizer",
 
 
56
  "unk_token": "[UNK]"
57
  }
 
50
  "model_max_length": 512,
51
  "pad_token": "[PAD]",
52
  "sep_token": "[SEP]",
53
+ "stride": 0,
54
  "strip_accents": null,
55
  "tokenize_chinese_chars": true,
56
  "tokenizer_class": "BertTokenizer",
57
+ "truncation_side": "right",
58
+ "truncation_strategy": "longest_first",
59
  "unk_token": "[UNK]"
60
  }