shivansarora committed
Commit d7332b4 · verified · Parent: d03160a

Create level_model.py

Files changed (1): level_model.py (+236 −0)
level_model.py ADDED
@@ -0,0 +1,236 @@
import torch
import tqdm
from torch import nn
from util import mean_pooling, convert_numeral_to_six_levels
from model_base import LevelEstimaterBase
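

# Sentence-level CEFR proficiency estimators built on a pretrained language
# model (LevelEstimaterBase, defined in model_base.py).
# - LevelEstimaterClassification predicts a level directly, either as a
#   classification over self.CEFR_lvs classes or as a scalar regression whose
#   output is mapped back to levels via convert_numeral_to_six_levels.
# - LevelEstimaterContrastive scores sentences by cosine similarity to
#   learned per-level prototype embeddings.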
class LevelEstimaterClassification(LevelEstimaterBase):
    def __init__(self, pretrained_model, problem_type, with_ib, with_loss_weight,
                 attach_wlv, num_labels,
                 word_num_labels, alpha,
                 ib_beta,
                 batch_size,
                 learning_rate,
                 warmup,
                 lm_layer, corpus_path=None, test_corpus_path=None):
        super().__init__(corpus_path, test_corpus_path, pretrained_model, with_ib, attach_wlv, num_labels,
                         word_num_labels, alpha,
                         batch_size,
                         learning_rate, warmup, lm_layer)
        self.save_hyperparameters()

        self.problem_type = problem_type
        self.with_loss_weight = with_loss_weight
        self.ib_beta = ib_beta
        self.dropout = nn.Dropout(0.1)

        if self.problem_type == "regression":
            self.slv_classifier = nn.Linear(self.lm.config.hidden_size, 1)
            self.loss_fct = nn.MSELoss()
        else:
            self.slv_classifier = nn.Linear(self.lm.config.hidden_size, self.CEFR_lvs)
            if self.with_loss_weight and corpus_path is not None:
                train_sentlv_weights = self.precompute_loss_weights()
                self.loss_fct = nn.CrossEntropyLoss(weight=train_sentlv_weights)
            else:
                self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, inputs):
        # In Lightning, forward defines the prediction/inference actions.
        # Returns (loss, predictions, logs) when gold labels are present,
        # otherwise just the numpy predictions.
        outputs, information_loss = self.encode(inputs)  # information_loss is unused in this class
        outputs = mean_pooling(outputs, attention_mask=inputs['attention_mask'])
        logits = self.slv_classifier(self.dropout(outputs))

        if self.problem_type == "regression":
            predictions = convert_numeral_to_six_levels(logits.detach().clone().cpu().numpy())
        else:
            predictions = torch.argmax(torch.softmax(logits.detach().clone(), dim=1), dim=1, keepdim=True)

        loss = None
        if 'slabels_high' in inputs:
            if self.problem_type == "regression":
                labels = (inputs['slabels_high'] + inputs['slabels_low']) / 2
                cls_loss = self.loss_fct(logits.squeeze(), labels.squeeze())
            else:
                labels = self.get_gold_labels(predictions, inputs['slabels_low'].detach().clone(),
                                              inputs['slabels_high'].detach().clone())
                cls_loss = self.loss_fct(logits.view(-1, self.CEFR_lvs), labels.view(-1))

            loss = cls_loss
            logs = {"loss": cls_loss}

        # In regression mode, predictions may already be a numpy array
        # (convert_numeral_to_six_levels receives numpy input), and numpy
        # arrays have no .cpu(), so guard on the type before converting.
        if isinstance(predictions, torch.Tensor):
            predictions = predictions.cpu().numpy()

        return (loss, predictions, logs) if loss is not None else predictions

    def step(self, batch):
        loss, predictions, logs = self.forward(batch)
        return loss, logs

    def _shared_eval_step(self, batch):
        loss, predictions, logs = self.forward(batch)

        gold_labels_low = batch['slabels_low'].cpu().detach().clone().numpy()
        gold_labels_high = batch['slabels_high'].cpu().detach().clone().numpy()
        golds_predictions = {'gold_labels_low': gold_labels_low, 'gold_labels_high': gold_labels_high,
                             'pred_labels': predictions}

        return logs, golds_predictions

    def training_step(self, batch, batch_idx):
        loss, logs = self.step(batch)
        self.log_dict({f"train_{k}": v for k, v in logs.items()})
        return loss

    def validation_step(self, batch, batch_idx):
        logs, golds_predictions = self._shared_eval_step(batch)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return golds_predictions

    def validation_epoch_end(self, outputs):
        logs = self.evaluation(outputs)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})

    def test_step(self, batch, batch_idx):
        logs, golds_predictions = self._shared_eval_step(batch)
        self.log_dict({f"test_{k}": v for k, v in logs.items()})
        return golds_predictions

    def test_epoch_end(self, outputs):
        logs = self.evaluation(outputs, test=True)
        self.log_dict({f"test_{k}": v for k, v in logs.items()})
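
# Note: validation_epoch_end/test_epoch_end follow the PyTorch Lightning 1.x
# hook API; Lightning 2.x removed them in favor of on_validation_epoch_end
# and on_test_epoch_end, so this module assumes a 1.x-series install.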


class LevelEstimaterContrastive(LevelEstimaterBase):
    def __init__(self, corpus_path, test_corpus_path, pretrained_model, problem_type, with_ib, with_loss_weight,
                 attach_wlv, num_labels,
                 word_num_labels,
                 num_prototypes,
                 alpha,
                 ib_beta,
                 batch_size,
                 learning_rate,
                 warmup,
                 lm_layer):
        super().__init__(corpus_path, test_corpus_path, pretrained_model, with_ib, attach_wlv, num_labels,
                         word_num_labels, alpha,
                         batch_size,
                         learning_rate, warmup, lm_layer)
        self.save_hyperparameters()

        self.problem_type = problem_type
        self.num_prototypes = num_prototypes
        self.with_loss_weight = with_loss_weight
        self.ib_beta = ib_beta

        # num_prototypes learned vectors per CEFR level, stored in one table.
        self.prototype = nn.Embedding(self.CEFR_lvs * self.num_prototypes, self.lm.config.hidden_size)
        # nn.init.xavier_normal_(self.prototype.weight)  # Xavier initialization
        # nn.init.orthogonal_(self.prototype.weight)  # Make prototype vectors orthogonal

        if self.with_loss_weight:
            loss_weights = self.precompute_loss_weights()
            self.loss_fct = nn.CrossEntropyLoss(weight=loss_weights)
        else:
            self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, batch):
        # In Lightning, forward defines the prediction/inference actions.
        outputs, information_loss = self.encode(batch)  # information_loss is unused in this class
        outputs = mean_pooling(outputs, attention_mask=batch['attention_mask'])

        # Positives: cosine similarity between sentence embeddings and prototypes.
        outputs = torch.nn.functional.normalize(outputs)
        positive_prototypes = torch.nn.functional.normalize(self.prototype.weight)
        logits = torch.mm(outputs, positive_prototypes.T)
        logits = logits.reshape((-1, self.num_prototypes, self.CEFR_lvs))
        logits = logits.mean(dim=1)
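        # Note: the reshape above assumes the prototype rows are laid out as
        # num_prototypes consecutive blocks of CEFR_lvs vectors, which is the
        # layout produced by prototype_initials.repeat(self.num_prototypes, 1)
        # in on_train_start below; the mean over dim=1 then pools each level's
        # score across its prototypes.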

        # prediction
        predictions = torch.argmax(torch.softmax(logits.detach().clone(), dim=1), dim=1, keepdim=True)

        loss = None
        if 'slabels_high' in batch:
            labels = self.get_gold_labels(predictions, batch['slabels_low'].detach().clone(),
                                          batch['slabels_high'].detach().clone())
            # cross-entropy loss
            cls_loss = self.loss_fct(logits.view(-1, self.CEFR_lvs), labels.view(-1))

            loss = cls_loss
            logs = {"loss": loss}

        predictions = predictions.cpu().numpy()

        return (loss, predictions, logs) if loss is not None else predictions

    def _shared_eval_step(self, batch):
        loss, predictions, logs = self.forward(batch)

        gold_labels_low = batch['slabels_low'].cpu().detach().clone().numpy()
        gold_labels_high = batch['slabels_high'].cpu().detach().clone().numpy()
        golds_predictions = {'gold_labels_low': gold_labels_low, 'gold_labels_high': gold_labels_high,
                             'pred_labels': predictions}

        return logs, golds_predictions

    def on_train_start(self) -> None:
        # Initialize prototypes from mean language-model embeddings per level.
        epsilon = 1.0e-6
        higher_labels, lower_labels = [], []
        prototype_initials = torch.full((self.CEFR_lvs, self.lm.config.hidden_size),
                                        fill_value=epsilon).to(self.device)

        self.lm.eval()
        for batch in tqdm.tqdm(self.train_dataloader(), leave=False, desc='init prototypes'):
            # view(-1) keeps the labels one-dimensional even for a batch of size 1.
            higher_labels += batch['slabels_high'].view(-1).detach().clone().numpy().tolist()
            lower_labels += batch['slabels_low'].view(-1).detach().clone().numpy().tolist()
            # Move to the module's device rather than assuming CUDA.
            batch = {k: v.to(self.device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = self.lm(batch['input_ids'], attention_mask=batch['attention_mask'],
                                  output_hidden_states=True)
                outputs_mean = mean_pooling(outputs.hidden_states[self.lm_layer],
                                            attention_mask=batch['attention_mask'])
                for lv in range(self.CEFR_lvs):
                    prototype_initials[lv] += outputs_mean[
                        (batch['slabels_low'].squeeze() == lv) | (batch['slabels_high'].squeeze() == lv)].sum(0)
        if not self.with_ib:
            self.lm.train()

        higher_labels = torch.tensor(higher_labels)
        lower_labels = torch.tensor(lower_labels)
        for lv in range(self.CEFR_lvs):
            denom = torch.count_nonzero((higher_labels == lv) | (lower_labels == lv)) + epsilon
            prototype_initials[lv] = prototype_initials[lv] / denom

        var = torch.var(prototype_initials).item() * 0.05  # Gaussian noise at 5% of the tensor's variance
        # prototype_initials = torch.repeat_interleave(prototype_initials, self.num_prototypes, dim=0)
        prototype_initials = prototype_initials.repeat(self.num_prototypes, 1)
        noise = (var ** 0.5) * torch.randn(prototype_initials.size()).to(self.device)
        prototype_initials = prototype_initials + noise  # add Gaussian noise
        self.prototype.weight = nn.Parameter(prototype_initials)
        nn.init.orthogonal_(self.prototype.weight)  # Make prototype vectors orthogonal
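        # Note: nn.init.orthogonal_ overwrites the tensor in place with a
        # freshly sampled (semi-)orthogonal matrix, so the embedding-based
        # initialization above mainly fixes the weight's shape here; if the
        # per-level means are meant to survive, this call would need to go.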

        # # Init with Xavier
        # nn.init.xavier_normal_(self.prototype.weight)  # Xavier initialization

    def training_step(self, batch, batch_idx):
        loss, predictions, logs = self.forward(batch)
        self.log_dict({f"train_{k}": v for k, v in logs.items()})
        return loss

    def validation_step(self, batch, batch_idx):
        logs, golds_predictions = self._shared_eval_step(batch)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return golds_predictions

    def validation_epoch_end(self, outputs):
        logs = self.evaluation(outputs)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})

    def test_step(self, batch, batch_idx):
        logs, golds_predictions = self._shared_eval_step(batch)
        self.log_dict({f"test_{k}": v for k, v in logs.items()})
        return golds_predictions

    def test_epoch_end(self, outputs):
        logs = self.evaluation(outputs, test=True)
        self.log_dict({f"test_{k}": v for k, v in logs.items()})
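
For reference, here is a minimal sketch of how these estimators might be driven with a PyTorch Lightning 1.x Trainer. The keyword arguments mirror the constructor signature above, but every concrete value (checkpoint name, corpus paths, label counts, hyperparameters) is an illustrative assumption rather than a setting taken from this repository:

import pytorch_lightning as pl

from level_model import LevelEstimaterClassification

# All concrete values below are illustrative assumptions.
model = LevelEstimaterClassification(
    pretrained_model="bert-base-cased",   # any Hugging Face encoder checkpoint
    problem_type="classification",        # or "regression"
    with_ib=False,
    with_loss_weight=True,
    attach_wlv=False,
    num_labels=6,
    word_num_labels=6,
    alpha=0.2,
    ib_beta=1.0e-4,
    batch_size=16,
    learning_rate=1.0e-5,
    warmup=0,
    lm_layer=-1,
    corpus_path="data/train.tsv",         # hypothetical path
    test_corpus_path="data/test.tsv",     # hypothetical path
)

# Assumes LevelEstimaterBase provides the dataloaders, as the call to
# self.train_dataloader() in on_train_start suggests.
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
trainer.fit(model)
trainer.test(model)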