Question Answering
sanjudebnath committed on
Commit
a966e1f
verified
1 Parent(s): 22f3dba

Delete question_answering.ipynb

Browse files
Files changed (1) hide show
  1. question_answering.ipynb +0 -2403
question_answering.ipynb DELETED
@@ -1,2403 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "19817716",
6
- "metadata": {},
7
- "source": [
8
- "# Question Answering\n",
9
- "This notebook trains and evaluates several question answering models. We first introduce a representation for the dataset and its corresponding DataLoader, and then evaluate the different models."
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": 50,
15
- "id": "49bf46c6",
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "from transformers import DistilBertModel, DistilBertForMaskedLM, DistilBertConfig, \\\n",
20
- " DistilBertTokenizerFast, AutoTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast, BertConfig\n",
21
- "from torch import nn\n",
22
- "from pathlib import Path\n",
23
- "import torch\n",
24
- "import pandas as pd\n",
25
- "from typing import Optional \n",
26
- "from tqdm.auto import tqdm\n",
27
- "from util import eval_test_set, count_parameters\n",
28
- "from torch.optim import AdamW, RMSprop\n",
29
- "\n",
30
- "\n",
31
- "from qa_model import QuestionDistilBERT, SimpleQuestionDistilBERT, ReuseQuestionDistilBERT, Dataset, test_model"
32
- ]
33
- },
34
- {
35
- "cell_type": "markdown",
36
- "id": "3ea47820",
37
- "metadata": {},
38
- "source": [
39
- "## Data\n",
40
- "The data processing is partly based on the Hugging Face course tutorial (https://huggingface.co/course/chapter7/7?fw=pt)."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 51,
46
- "id": "7b1b2b3e",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')"
51
- ]
52
- },
53
- {
54
- "cell_type": "code",
55
- "execution_count": 52,
56
- "id": "f276eba7",
57
- "metadata": {
58
- "scrolled": false
59
- },
60
- "outputs": [],
61
- "source": [
62
- " \n",
63
- "# create datasets and loaders for training and test set\n",
64
- "squad_paths = [str(x) for x in Path('data/training_squad/').glob('**/*.txt')]\n",
65
- "nat_paths = [str(x) for x in Path('data/natural_questions_train/').glob('**/*.txt')]\n",
66
- "hotpotqa_paths = [str(x) for x in Path('data/hotpotqa_training/').glob('**/*.txt')]"
67
- ]
68
- },
69
- {
70
- "cell_type": "markdown",
71
- "id": "ad8d532a",
72
- "metadata": {},
73
- "source": [
74
- "## POC Model\n",
75
- "* Works very well:\n",
76
- " * A dropout of 0.1 is too small (the model overfits after the first epoch), so it was changed to 0.15\n",
77
- " * Difference between AdamW and RMSprop minimal\n",
78
- " \n",
79
- "### Results:\n",
80
- "Dropout = 0.15\n",
81
- "* Mean EM: 0.5374\n",
82
- "* Mean F-1: 0.6826317532406944\n",
83
- "\n",
84
- "Dropout = 0.2 (overfitting relatively similar to the first run, but this dropout value seems to be too high)\n",
85
- "* Mean EM: 0.5044\n",
86
- "* Mean F-1: 0.6437359169276439"
87
- ]
88
- },
89
- {
90
- "cell_type": "code",
91
- "execution_count": 54,
92
- "id": "703e7f38",
93
- "metadata": {},
94
- "outputs": [],
95
- "source": [
96
- "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
97
- "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
98
- "\n",
99
- "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
100
- " natural_question_paths=None, \n",
101
- " hotpotqa_paths = None, tokenizer=tokenizer)\n",
102
- "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": 55,
108
- "id": "6672f614",
109
- "metadata": {},
110
- "outputs": [],
111
- "source": [
112
- "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n",
113
- "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n",
114
- "mod = model.distilbert"
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": 56,
120
- "id": "dec15198",
121
- "metadata": {},
122
- "outputs": [
123
- {
124
- "data": {
125
- "text/plain": [
126
- "SimpleQuestionDistilBERT(\n",
127
- " (distilbert): DistilBertModel(\n",
128
- " (embeddings): Embeddings(\n",
129
- " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
130
- " (position_embeddings): Embedding(512, 768)\n",
131
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
132
- " (dropout): Dropout(p=0.1, inplace=False)\n",
133
- " )\n",
134
- " (transformer): Transformer(\n",
135
- " (layer): ModuleList(\n",
136
- " (0): TransformerBlock(\n",
137
- " (attention): MultiHeadSelfAttention(\n",
138
- " (dropout): Dropout(p=0.1, inplace=False)\n",
139
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
140
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
141
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
142
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
143
- " )\n",
144
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
145
- " (ffn): FFN(\n",
146
- " (dropout): Dropout(p=0.1, inplace=False)\n",
147
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
148
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
149
- " (activation): GELUActivation()\n",
150
- " )\n",
151
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
152
- " )\n",
153
- " (1): TransformerBlock(\n",
154
- " (attention): MultiHeadSelfAttention(\n",
155
- " (dropout): Dropout(p=0.1, inplace=False)\n",
156
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
157
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
158
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
159
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
160
- " )\n",
161
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
162
- " (ffn): FFN(\n",
163
- " (dropout): Dropout(p=0.1, inplace=False)\n",
164
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
165
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
166
- " (activation): GELUActivation()\n",
167
- " )\n",
168
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
169
- " )\n",
170
- " (2): TransformerBlock(\n",
171
- " (attention): MultiHeadSelfAttention(\n",
172
- " (dropout): Dropout(p=0.1, inplace=False)\n",
173
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
174
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
175
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
176
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
177
- " )\n",
178
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
179
- " (ffn): FFN(\n",
180
- " (dropout): Dropout(p=0.1, inplace=False)\n",
181
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
182
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
183
- " (activation): GELUActivation()\n",
184
- " )\n",
185
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
186
- " )\n",
187
- " (3): TransformerBlock(\n",
188
- " (attention): MultiHeadSelfAttention(\n",
189
- " (dropout): Dropout(p=0.1, inplace=False)\n",
190
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
191
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
192
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
193
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
194
- " )\n",
195
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
196
- " (ffn): FFN(\n",
197
- " (dropout): Dropout(p=0.1, inplace=False)\n",
198
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
199
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
200
- " (activation): GELUActivation()\n",
201
- " )\n",
202
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
203
- " )\n",
204
- " (4): TransformerBlock(\n",
205
- " (attention): MultiHeadSelfAttention(\n",
206
- " (dropout): Dropout(p=0.1, inplace=False)\n",
207
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
208
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
209
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
210
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
211
- " )\n",
212
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
213
- " (ffn): FFN(\n",
214
- " (dropout): Dropout(p=0.1, inplace=False)\n",
215
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
216
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
217
- " (activation): GELUActivation()\n",
218
- " )\n",
219
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
220
- " )\n",
221
- " (5): TransformerBlock(\n",
222
- " (attention): MultiHeadSelfAttention(\n",
223
- " (dropout): Dropout(p=0.1, inplace=False)\n",
224
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
225
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
226
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
227
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
228
- " )\n",
229
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
230
- " (ffn): FFN(\n",
231
- " (dropout): Dropout(p=0.1, inplace=False)\n",
232
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
233
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
234
- " (activation): GELUActivation()\n",
235
- " )\n",
236
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
237
- " )\n",
238
- " )\n",
239
- " )\n",
240
- " )\n",
241
- " (dropout): Dropout(p=0.5, inplace=False)\n",
242
- " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
243
- ")"
244
- ]
245
- },
246
- "execution_count": 56,
247
- "metadata": {},
248
- "output_type": "execute_result"
249
- }
250
- ],
251
- "source": [
252
- "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
253
- "model = SimpleQuestionDistilBERT(mod)\n",
254
- "model.to(device)"
255
- ]
256
- },
257
- {
258
- "cell_type": "code",
259
- "execution_count": 57,
260
- "id": "9def3c83",
261
- "metadata": {},
262
- "outputs": [
263
- {
264
- "name": "stdout",
265
- "output_type": "stream",
266
- "text": [
267
- "+---------------------------------------------------------+------------+\n",
268
- "| Modules | Parameters |\n",
269
- "+---------------------------------------------------------+------------+\n",
270
- "| distilbert.embeddings.word_embeddings.weight | 23440896 |\n",
271
- "| distilbert.embeddings.position_embeddings.weight | 393216 |\n",
272
- "| distilbert.embeddings.LayerNorm.weight | 768 |\n",
273
- "| distilbert.embeddings.LayerNorm.bias | 768 |\n",
274
- "| distilbert.transformer.layer.0.attention.q_lin.weight | 589824 |\n",
275
- "| distilbert.transformer.layer.0.attention.q_lin.bias | 768 |\n",
276
- "| distilbert.transformer.layer.0.attention.k_lin.weight | 589824 |\n",
277
- "| distilbert.transformer.layer.0.attention.k_lin.bias | 768 |\n",
278
- "| distilbert.transformer.layer.0.attention.v_lin.weight | 589824 |\n",
279
- "| distilbert.transformer.layer.0.attention.v_lin.bias | 768 |\n",
280
- "| distilbert.transformer.layer.0.attention.out_lin.weight | 589824 |\n",
281
- "| distilbert.transformer.layer.0.attention.out_lin.bias | 768 |\n",
282
- "| distilbert.transformer.layer.0.sa_layer_norm.weight | 768 |\n",
283
- "| distilbert.transformer.layer.0.sa_layer_norm.bias | 768 |\n",
284
- "| distilbert.transformer.layer.0.ffn.lin1.weight | 2359296 |\n",
285
- "| distilbert.transformer.layer.0.ffn.lin1.bias | 3072 |\n",
286
- "| distilbert.transformer.layer.0.ffn.lin2.weight | 2359296 |\n",
287
- "| distilbert.transformer.layer.0.ffn.lin2.bias | 768 |\n",
288
- "| distilbert.transformer.layer.0.output_layer_norm.weight | 768 |\n",
289
- "| distilbert.transformer.layer.0.output_layer_norm.bias | 768 |\n",
290
- "| distilbert.transformer.layer.1.attention.q_lin.weight | 589824 |\n",
291
- "| distilbert.transformer.layer.1.attention.q_lin.bias | 768 |\n",
292
- "| distilbert.transformer.layer.1.attention.k_lin.weight | 589824 |\n",
293
- "| distilbert.transformer.layer.1.attention.k_lin.bias | 768 |\n",
294
- "| distilbert.transformer.layer.1.attention.v_lin.weight | 589824 |\n",
295
- "| distilbert.transformer.layer.1.attention.v_lin.bias | 768 |\n",
296
- "| distilbert.transformer.layer.1.attention.out_lin.weight | 589824 |\n",
297
- "| distilbert.transformer.layer.1.attention.out_lin.bias | 768 |\n",
298
- "| distilbert.transformer.layer.1.sa_layer_norm.weight | 768 |\n",
299
- "| distilbert.transformer.layer.1.sa_layer_norm.bias | 768 |\n",
300
- "| distilbert.transformer.layer.1.ffn.lin1.weight | 2359296 |\n",
301
- "| distilbert.transformer.layer.1.ffn.lin1.bias | 3072 |\n",
302
- "| distilbert.transformer.layer.1.ffn.lin2.weight | 2359296 |\n",
303
- "| distilbert.transformer.layer.1.ffn.lin2.bias | 768 |\n",
304
- "| distilbert.transformer.layer.1.output_layer_norm.weight | 768 |\n",
305
- "| distilbert.transformer.layer.1.output_layer_norm.bias | 768 |\n",
306
- "| distilbert.transformer.layer.2.attention.q_lin.weight | 589824 |\n",
307
- "| distilbert.transformer.layer.2.attention.q_lin.bias | 768 |\n",
308
- "| distilbert.transformer.layer.2.attention.k_lin.weight | 589824 |\n",
309
- "| distilbert.transformer.layer.2.attention.k_lin.bias | 768 |\n",
310
- "| distilbert.transformer.layer.2.attention.v_lin.weight | 589824 |\n",
311
- "| distilbert.transformer.layer.2.attention.v_lin.bias | 768 |\n",
312
- "| distilbert.transformer.layer.2.attention.out_lin.weight | 589824 |\n",
313
- "| distilbert.transformer.layer.2.attention.out_lin.bias | 768 |\n",
314
- "| distilbert.transformer.layer.2.sa_layer_norm.weight | 768 |\n",
315
- "| distilbert.transformer.layer.2.sa_layer_norm.bias | 768 |\n",
316
- "| distilbert.transformer.layer.2.ffn.lin1.weight | 2359296 |\n",
317
- "| distilbert.transformer.layer.2.ffn.lin1.bias | 3072 |\n",
318
- "| distilbert.transformer.layer.2.ffn.lin2.weight | 2359296 |\n",
319
- "| distilbert.transformer.layer.2.ffn.lin2.bias | 768 |\n",
320
- "| distilbert.transformer.layer.2.output_layer_norm.weight | 768 |\n",
321
- "| distilbert.transformer.layer.2.output_layer_norm.bias | 768 |\n",
322
- "| distilbert.transformer.layer.3.attention.q_lin.weight | 589824 |\n",
323
- "| distilbert.transformer.layer.3.attention.q_lin.bias | 768 |\n",
324
- "| distilbert.transformer.layer.3.attention.k_lin.weight | 589824 |\n",
325
- "| distilbert.transformer.layer.3.attention.k_lin.bias | 768 |\n",
326
- "| distilbert.transformer.layer.3.attention.v_lin.weight | 589824 |\n",
327
- "| distilbert.transformer.layer.3.attention.v_lin.bias | 768 |\n",
328
- "| distilbert.transformer.layer.3.attention.out_lin.weight | 589824 |\n",
329
- "| distilbert.transformer.layer.3.attention.out_lin.bias | 768 |\n",
330
- "| distilbert.transformer.layer.3.sa_layer_norm.weight | 768 |\n",
331
- "| distilbert.transformer.layer.3.sa_layer_norm.bias | 768 |\n",
332
- "| distilbert.transformer.layer.3.ffn.lin1.weight | 2359296 |\n",
333
- "| distilbert.transformer.layer.3.ffn.lin1.bias | 3072 |\n",
334
- "| distilbert.transformer.layer.3.ffn.lin2.weight | 2359296 |\n",
335
- "| distilbert.transformer.layer.3.ffn.lin2.bias | 768 |\n",
336
- "| distilbert.transformer.layer.3.output_layer_norm.weight | 768 |\n",
337
- "| distilbert.transformer.layer.3.output_layer_norm.bias | 768 |\n",
338
- "| distilbert.transformer.layer.4.attention.q_lin.weight | 589824 |\n",
339
- "| distilbert.transformer.layer.4.attention.q_lin.bias | 768 |\n",
340
- "| distilbert.transformer.layer.4.attention.k_lin.weight | 589824 |\n",
341
- "| distilbert.transformer.layer.4.attention.k_lin.bias | 768 |\n",
342
- "| distilbert.transformer.layer.4.attention.v_lin.weight | 589824 |\n",
343
- "| distilbert.transformer.layer.4.attention.v_lin.bias | 768 |\n",
344
- "| distilbert.transformer.layer.4.attention.out_lin.weight | 589824 |\n",
345
- "| distilbert.transformer.layer.4.attention.out_lin.bias | 768 |\n",
346
- "| distilbert.transformer.layer.4.sa_layer_norm.weight | 768 |\n",
347
- "| distilbert.transformer.layer.4.sa_layer_norm.bias | 768 |\n",
348
- "| distilbert.transformer.layer.4.ffn.lin1.weight | 2359296 |\n",
349
- "| distilbert.transformer.layer.4.ffn.lin1.bias | 3072 |\n",
350
- "| distilbert.transformer.layer.4.ffn.lin2.weight | 2359296 |\n",
351
- "| distilbert.transformer.layer.4.ffn.lin2.bias | 768 |\n",
352
- "| distilbert.transformer.layer.4.output_layer_norm.weight | 768 |\n",
353
- "| distilbert.transformer.layer.4.output_layer_norm.bias | 768 |\n",
354
- "| distilbert.transformer.layer.5.attention.q_lin.weight | 589824 |\n",
355
- "| distilbert.transformer.layer.5.attention.q_lin.bias | 768 |\n",
356
- "| distilbert.transformer.layer.5.attention.k_lin.weight | 589824 |\n",
357
- "| distilbert.transformer.layer.5.attention.k_lin.bias | 768 |\n",
358
- "| distilbert.transformer.layer.5.attention.v_lin.weight | 589824 |\n",
359
- "| distilbert.transformer.layer.5.attention.v_lin.bias | 768 |\n",
360
- "| distilbert.transformer.layer.5.attention.out_lin.weight | 589824 |\n",
361
- "| distilbert.transformer.layer.5.attention.out_lin.bias | 768 |\n",
362
- "| distilbert.transformer.layer.5.sa_layer_norm.weight | 768 |\n",
363
- "| distilbert.transformer.layer.5.sa_layer_norm.bias | 768 |\n",
364
- "| distilbert.transformer.layer.5.ffn.lin1.weight | 2359296 |\n",
365
- "| distilbert.transformer.layer.5.ffn.lin1.bias | 3072 |\n",
366
- "| distilbert.transformer.layer.5.ffn.lin2.weight | 2359296 |\n",
367
- "| distilbert.transformer.layer.5.ffn.lin2.bias | 768 |\n",
368
- "| distilbert.transformer.layer.5.output_layer_norm.weight | 768 |\n",
369
- "| distilbert.transformer.layer.5.output_layer_norm.bias | 768 |\n",
370
- "| classifier.weight | 1536 |\n",
371
- "| classifier.bias | 2 |\n",
372
- "+---------------------------------------------------------+------------+\n",
373
- "Total Trainable Params: 66364418\n"
374
- ]
375
- },
376
- {
377
- "data": {
378
- "text/plain": [
379
- "66364418"
380
- ]
381
- },
382
- "execution_count": 57,
383
- "metadata": {},
384
- "output_type": "execute_result"
385
- }
386
- ],
387
- "source": [
388
- "count_parameters(model)"
389
- ]
390
- },
391
- {
392
- "cell_type": "markdown",
393
- "id": "426a6311",
394
- "metadata": {},
395
- "source": [
396
- "### Testing the model"
397
- ]
398
- },
399
- {
400
- "cell_type": "code",
401
- "execution_count": 58,
402
- "id": "6151c201",
403
- "metadata": {},
404
- "outputs": [],
405
- "source": [
406
- "# Smaller dataset (first two SQuAD files only) for checking the model with test_model below; the loader must wrap test_ds, not the full training dataset\n",
407
- "batch_size = 8\n",
408
- "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
409
- "test_ds_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)\n",
410
- "optim = RMSprop(model.parameters(), lr=1e-4)"
411
- ]
412
- },
413
- {
414
- "cell_type": "code",
415
- "execution_count": 59,
416
- "id": "aeae0c56",
417
- "metadata": {},
418
- "outputs": [
419
- {
420
- "name": "stdout",
421
- "output_type": "stream",
422
- "text": [
423
- "Passed\n"
424
- ]
425
- }
426
- ],
427
- "source": [
428
- "test_model(model, optim, test_ds_loader, device)"
429
- ]
430
- },
431
- {
432
- "cell_type": "markdown",
433
- "id": "59928d34",
434
- "metadata": {},
435
- "source": [
436
- "### Model Training"
437
- ]
438
- },
439
- {
440
- "cell_type": "code",
441
- "execution_count": 60,
442
- "id": "a8017b8c",
443
- "metadata": {},
444
- "outputs": [
445
- {
446
- "data": {
447
- "text/plain": [
448
- "SimpleQuestionDistilBERT(\n",
449
- " (distilbert): DistilBertModel(\n",
450
- " (embeddings): Embeddings(\n",
451
- " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
452
- " (position_embeddings): Embedding(512, 768)\n",
453
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
454
- " (dropout): Dropout(p=0.1, inplace=False)\n",
455
- " )\n",
456
- " (transformer): Transformer(\n",
457
- " (layer): ModuleList(\n",
458
- " (0): TransformerBlock(\n",
459
- " (attention): MultiHeadSelfAttention(\n",
460
- " (dropout): Dropout(p=0.1, inplace=False)\n",
461
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
462
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
463
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
464
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
465
- " )\n",
466
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
467
- " (ffn): FFN(\n",
468
- " (dropout): Dropout(p=0.1, inplace=False)\n",
469
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
470
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
471
- " (activation): GELUActivation()\n",
472
- " )\n",
473
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
474
- " )\n",
475
- " (1): TransformerBlock(\n",
476
- " (attention): MultiHeadSelfAttention(\n",
477
- " (dropout): Dropout(p=0.1, inplace=False)\n",
478
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
479
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
480
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
481
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
482
- " )\n",
483
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
484
- " (ffn): FFN(\n",
485
- " (dropout): Dropout(p=0.1, inplace=False)\n",
486
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
487
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
488
- " (activation): GELUActivation()\n",
489
- " )\n",
490
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
491
- " )\n",
492
- " (2): TransformerBlock(\n",
493
- " (attention): MultiHeadSelfAttention(\n",
494
- " (dropout): Dropout(p=0.1, inplace=False)\n",
495
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
496
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
497
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
498
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
499
- " )\n",
500
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
501
- " (ffn): FFN(\n",
502
- " (dropout): Dropout(p=0.1, inplace=False)\n",
503
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
504
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
505
- " (activation): GELUActivation()\n",
506
- " )\n",
507
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
508
- " )\n",
509
- " (3): TransformerBlock(\n",
510
- " (attention): MultiHeadSelfAttention(\n",
511
- " (dropout): Dropout(p=0.1, inplace=False)\n",
512
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
513
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
514
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
515
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
516
- " )\n",
517
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
518
- " (ffn): FFN(\n",
519
- " (dropout): Dropout(p=0.1, inplace=False)\n",
520
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
521
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
522
- " (activation): GELUActivation()\n",
523
- " )\n",
524
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
525
- " )\n",
526
- " (4): TransformerBlock(\n",
527
- " (attention): MultiHeadSelfAttention(\n",
528
- " (dropout): Dropout(p=0.1, inplace=False)\n",
529
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
530
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
531
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
532
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
533
- " )\n",
534
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
535
- " (ffn): FFN(\n",
536
- " (dropout): Dropout(p=0.1, inplace=False)\n",
537
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
538
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
539
- " (activation): GELUActivation()\n",
540
- " )\n",
541
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
542
- " )\n",
543
- " (5): TransformerBlock(\n",
544
- " (attention): MultiHeadSelfAttention(\n",
545
- " (dropout): Dropout(p=0.1, inplace=False)\n",
546
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
547
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
548
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
549
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
550
- " )\n",
551
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
552
- " (ffn): FFN(\n",
553
- " (dropout): Dropout(p=0.1, inplace=False)\n",
554
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
555
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
556
- " (activation): GELUActivation()\n",
557
- " )\n",
558
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
559
- " )\n",
560
- " )\n",
561
- " )\n",
562
- " )\n",
563
- " (dropout): Dropout(p=0.5, inplace=False)\n",
564
- " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
565
- ")"
566
- ]
567
- },
568
- "execution_count": 60,
569
- "metadata": {},
570
- "output_type": "execute_result"
571
- }
572
- ],
573
- "source": [
574
- "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
575
- "model = SimpleQuestionDistilBERT(mod)\n",
576
- "model.to(device)"
577
- ]
578
- },
579
- {
580
- "cell_type": "code",
581
- "execution_count": 61,
582
- "id": "f13c12dc",
583
- "metadata": {},
584
- "outputs": [],
585
- "source": [
586
- "model.train()\n",
587
- "optim = RMSprop(model.parameters(), lr=1e-4)"
588
- ]
589
- },
590
- {
591
- "cell_type": "code",
592
- "execution_count": 22,
593
- "id": "e4fa54d9",
594
- "metadata": {},
595
- "outputs": [
596
- {
597
- "data": {
598
- "application/vnd.jupyter.widget-view+json": {
599
- "model_id": "0016d9f5ba764eb98e9df8573995c86c",
600
- "version_major": 2,
601
- "version_minor": 0
602
- },
603
- "text/plain": [
604
- " 0%| | 0/10875 [00:00<?, ?it/s]"
605
- ]
606
- },
607
- "metadata": {},
608
- "output_type": "display_data"
609
- },
610
- {
611
- "name": "stdout",
612
- "output_type": "stream",
613
- "text": [
614
- "Mean Training Error 0.7555404769408292\n"
615
- ]
616
- },
617
- {
618
- "data": {
619
- "application/vnd.jupyter.widget-view+json": {
620
- "model_id": "96af0e22e2ee44fd920795b0e7317839",
621
- "version_major": 2,
622
- "version_minor": 0
623
- },
624
- "text/plain": [
625
- " 0%| | 0/2500 [00:00<?, ?it/s]"
626
- ]
627
- },
628
- "metadata": {},
629
- "output_type": "display_data"
630
- },
631
- {
632
- "name": "stdout",
633
- "output_type": "stream",
634
- "text": [
635
- "Mean Test Error 1.761920437876694\n"
636
- ]
637
- },
638
- {
639
- "data": {
640
- "application/vnd.jupyter.widget-view+json": {
641
- "model_id": "5160ffe5f60e4b72b46746a33b1d60d0",
642
- "version_major": 2,
643
- "version_minor": 0
644
- },
645
- "text/plain": [
646
- " 0%| | 0/10875 [00:00<?, ?it/s]"
647
- ]
648
- },
649
- "metadata": {},
650
- "output_type": "display_data"
651
- },
652
- {
653
- "ename": "KeyboardInterrupt",
654
- "evalue": "",
655
- "output_type": "error",
656
- "traceback": [
657
- "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
658
- "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
659
- "Cell \u001B[0;32mIn [22], line 18\u001B[0m\n\u001B[1;32m 16\u001B[0m \u001B[38;5;66;03m# print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\u001B[39;00m\n\u001B[1;32m 17\u001B[0m loss \u001B[38;5;241m=\u001B[39m outputs[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m'\u001B[39m]\n\u001B[0;32m---> 18\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n\u001B[1;32m 19\u001B[0m \u001B[38;5;66;03m# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\u001B[39;00m\n\u001B[1;32m 20\u001B[0m optim\u001B[38;5;241m.\u001B[39mstep()\n",
660
- "File \u001B[0;32m~/Documents/University/WS2022/applieddl/venv/lib64/python3.10/site-packages/torch/_tensor.py:396\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 387\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 388\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 389\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 390\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 394\u001B[0m create_graph\u001B[38;5;241m=\u001B[39mcreate_graph,\n\u001B[1;32m 395\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs)\n\u001B[0;32m--> 396\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\u001B[43m)\u001B[49m\n",
661
- "File \u001B[0;32m~/Documents/University/WS2022/applieddl/venv/lib64/python3.10/site-packages/torch/autograd/__init__.py:173\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 168\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 170\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 171\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 172\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 173\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 174\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 175\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
662
- "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
663
- ]
664
- }
665
- ],
666
- "source": [
667
- "epochs = 5\n",
668
- "\n",
669
- "for epoch in range(epochs):\n",
670
- "    loop = tqdm(loader, leave=True)\n",
671
- "    model.train()\n",
672
- "    mean_training_error = []\n",
673
- "    for batch in loop:\n",
674
- "        optim.zero_grad()\n",
675
- "        # move the batch tensors onto `device` before the forward pass\n",
676
- "        input_ids = batch['input_ids'].to(device)\n",
677
- "        attention_mask = batch['attention_mask'].to(device)\n",
678
- "        start = batch['start_positions'].to(device)\n",
679
- "        end = batch['end_positions'].to(device)\n",
680
- "        # the gold span positions are passed in; the loss is read from outputs['loss'] below\n",
681
- "        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
682
- "        # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
683
- "        loss = outputs['loss']\n",
684
- "        loss.backward()\n",
685
- "        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
686
- "        optim.step()\n",
687
- "        mean_training_error.append(loss.item())\n",
688
- "        loop.set_description(f'Epoch {epoch}')\n",
689
- "        loop.set_postfix(loss=loss.item())\n",
690
- "    print(\"Mean Training Error\", np.mean(mean_training_error))\n",
691
- "\n",
692
- "    # evaluation pass: nothing is optimised here, so gradient tracking is switched off\n",
693
- "    loop = tqdm(test_loader, leave=True)\n",
694
- "    model.eval()\n",
695
- "    mean_test_error = []\n",
696
- "    with torch.no_grad():\n",
697
- "        for batch in loop:\n",
698
- "            input_ids = batch['input_ids'].to(device)\n",
699
- "            attention_mask = batch['attention_mask'].to(device)\n",
700
- "            start = batch['start_positions'].to(device)\n",
701
- "            end = batch['end_positions'].to(device)\n",
702
- "\n",
703
- "            outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
704
- "            # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
705
- "            loss = outputs['loss']\n",
706
- "\n",
707
- "            mean_test_error.append(loss.item())\n",
708
- "            loop.set_description(f'Epoch {epoch} Testset')\n",
709
- "            loop.set_postfix(loss=loss.item())\n",
710
- "    print(\"Mean Test Error\", np.mean(mean_test_error))"
711
- ]
712
- },
713
- {
714
- "cell_type": "code",
715
- "execution_count": 19,
716
- "id": "6ff26fb4",
717
- "metadata": {},
718
- "outputs": [],
719
- "source": [
720
- "torch.save(model.state_dict(), \"simple_distilbert_qa.model\")"
721
- ]
722
- },
723
- {
724
- "cell_type": "code",
725
- "execution_count": 20,
726
- "id": "a5e7abeb",
727
- "metadata": {},
728
- "outputs": [
729
- {
730
- "data": {
731
- "text/plain": [
732
- "<All keys matched successfully>"
733
- ]
734
- },
735
- "execution_count": 20,
736
- "metadata": {},
737
- "output_type": "execute_result"
738
- }
739
- ],
740
- "source": [
741
- "model = SimpleQuestionDistilBERT(mod)\n",
742
- "model.load_state_dict(torch.load(\"simple_distilbert_qa.model\"))"
743
- ]
744
- },
745
- {
746
- "cell_type": "code",
747
- "execution_count": 18,
748
- "id": "f5ad7bee",
749
- "metadata": {},
750
- "outputs": [
751
- {
752
- "name": "stderr",
753
- "output_type": "stream",
754
- "text": [
755
- "100%|██████████| 2500/2500 [02:09<00:00, 19.37it/s]"
756
- ]
757
- },
758
- {
759
- "name": "stdout",
760
- "output_type": "stream",
761
- "text": [
762
- "Mean EM: 0.5374\n",
763
- "Mean F-1: 0.6826317532406944\n"
764
- ]
765
- },
766
- {
767
- "name": "stderr",
768
- "output_type": "stream",
769
- "text": [
770
- "\n"
771
- ]
772
- }
773
- ],
774
- "source": [
775
- "eval_test_set(model, tokenizer, test_loader, device)"
776
- ]
777
- },
778
- {
779
- "cell_type": "markdown",
780
- "id": "fa6017a8",
781
- "metadata": {},
782
- "source": [
783
- "## Freeze baseline and train new head\n",
784
- "This was my initial idea: to freeze the layers and add a completely new head, which we train from scratch. I tried a lot of different configurations, but nothing really worked; the CrossEntropyLoss usually stayed at about 3 the whole time. Below, you can see the different heads I have tried.\n",
785
- "\n",
786
- "Furthermore, I experimented with different data, because I thought there might not be enough data overall. I would conclude that this didn't work because (1) Transformers are very data-hungry and I probably still used too little data (one epoch took about 1h though, so it wasn't possible to use even more). (2) We train the new layers from scratch, which means they contain absolutely no structure about the problem and task beforehand. I do not think that this way of training leads to better results / less energy used all in all, because it would be too resource-intensive.\n",
787
- "\n",
788
- "The following setup is partly based on the HuggingFace implementation of the question answering model (https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/models/distilbert/modeling_distilbert.py#L805)"
789
- ]
790
- },
791
- {
792
- "cell_type": "code",
793
- "execution_count": 62,
794
- "id": "92b21967",
795
- "metadata": {},
796
- "outputs": [],
797
- "source": [
798
- "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")"
799
- ]
800
- },
801
- {
802
- "cell_type": "code",
803
- "execution_count": 63,
804
- "id": "1d7b3a8c",
805
- "metadata": {},
806
- "outputs": [],
807
- "source": [
808
- "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")"
809
- ]
810
- },
811
- {
812
- "cell_type": "code",
813
- "execution_count": 64,
814
- "id": "91444894",
815
- "metadata": {},
816
- "outputs": [],
817
- "source": [
818
- "# only take base model, we do not need the classification head\n",
819
- "mod = model.distilbert"
820
- ]
821
- },
822
- {
823
- "cell_type": "code",
824
- "execution_count": 65,
825
- "id": "74ca6c07",
826
- "metadata": {},
827
- "outputs": [
828
- {
829
- "data": {
830
- "text/plain": [
831
- "QuestionDistilBERT(\n",
832
- " (distilbert): DistilBertModel(\n",
833
- " (embeddings): Embeddings(\n",
834
- " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
835
- " (position_embeddings): Embedding(512, 768)\n",
836
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
837
- " (dropout): Dropout(p=0.1, inplace=False)\n",
838
- " )\n",
839
- " (transformer): Transformer(\n",
840
- " (layer): ModuleList(\n",
841
- " (0): TransformerBlock(\n",
842
- " (attention): MultiHeadSelfAttention(\n",
843
- " (dropout): Dropout(p=0.1, inplace=False)\n",
844
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
845
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
846
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
847
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
848
- " )\n",
849
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
850
- " (ffn): FFN(\n",
851
- " (dropout): Dropout(p=0.1, inplace=False)\n",
852
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
853
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
854
- " (activation): GELUActivation()\n",
855
- " )\n",
856
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
857
- " )\n",
858
- " (1): TransformerBlock(\n",
859
- " (attention): MultiHeadSelfAttention(\n",
860
- " (dropout): Dropout(p=0.1, inplace=False)\n",
861
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
862
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
863
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
864
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
865
- " )\n",
866
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
867
- " (ffn): FFN(\n",
868
- " (dropout): Dropout(p=0.1, inplace=False)\n",
869
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
870
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
871
- " (activation): GELUActivation()\n",
872
- " )\n",
873
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
874
- " )\n",
875
- " (2): TransformerBlock(\n",
876
- " (attention): MultiHeadSelfAttention(\n",
877
- " (dropout): Dropout(p=0.1, inplace=False)\n",
878
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
879
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
880
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
881
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
882
- " )\n",
883
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
884
- " (ffn): FFN(\n",
885
- " (dropout): Dropout(p=0.1, inplace=False)\n",
886
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
887
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
888
- " (activation): GELUActivation()\n",
889
- " )\n",
890
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
891
- " )\n",
892
- " (3): TransformerBlock(\n",
893
- " (attention): MultiHeadSelfAttention(\n",
894
- " (dropout): Dropout(p=0.1, inplace=False)\n",
895
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
896
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
897
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
898
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
899
- " )\n",
900
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
901
- " (ffn): FFN(\n",
902
- " (dropout): Dropout(p=0.1, inplace=False)\n",
903
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
904
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
905
- " (activation): GELUActivation()\n",
906
- " )\n",
907
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
908
- " )\n",
909
- " (4): TransformerBlock(\n",
910
- " (attention): MultiHeadSelfAttention(\n",
911
- " (dropout): Dropout(p=0.1, inplace=False)\n",
912
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
913
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
914
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
915
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
916
- " )\n",
917
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
918
- " (ffn): FFN(\n",
919
- " (dropout): Dropout(p=0.1, inplace=False)\n",
920
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
921
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
922
- " (activation): GELUActivation()\n",
923
- " )\n",
924
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
925
- " )\n",
926
- " (5): TransformerBlock(\n",
927
- " (attention): MultiHeadSelfAttention(\n",
928
- " (dropout): Dropout(p=0.1, inplace=False)\n",
929
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
930
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
931
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
932
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
933
- " )\n",
934
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
935
- " (ffn): FFN(\n",
936
- " (dropout): Dropout(p=0.1, inplace=False)\n",
937
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
938
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
939
- " (activation): GELUActivation()\n",
940
- " )\n",
941
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
942
- " )\n",
943
- " )\n",
944
- " )\n",
945
- " )\n",
946
- " (relu): ReLU()\n",
947
- " (dropout): Dropout(p=0.1, inplace=False)\n",
948
- " (te): TransformerEncoder(\n",
949
- " (layers): ModuleList(\n",
950
- " (0): TransformerEncoderLayer(\n",
951
- " (self_attn): MultiheadAttention(\n",
952
- " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
953
- " )\n",
954
- " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
955
- " (dropout): Dropout(p=0.1, inplace=False)\n",
956
- " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
957
- " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
958
- " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
959
- " (dropout1): Dropout(p=0.1, inplace=False)\n",
960
- " (dropout2): Dropout(p=0.1, inplace=False)\n",
961
- " )\n",
962
- " (1): TransformerEncoderLayer(\n",
963
- " (self_attn): MultiheadAttention(\n",
964
- " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
965
- " )\n",
966
- " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
967
- " (dropout): Dropout(p=0.1, inplace=False)\n",
968
- " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
969
- " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
970
- " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
971
- " (dropout1): Dropout(p=0.1, inplace=False)\n",
972
- " (dropout2): Dropout(p=0.1, inplace=False)\n",
973
- " )\n",
974
- " (2): TransformerEncoderLayer(\n",
975
- " (self_attn): MultiheadAttention(\n",
976
- " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
977
- " )\n",
978
- " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
979
- " (dropout): Dropout(p=0.1, inplace=False)\n",
980
- " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
981
- " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
982
- " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
983
- " (dropout1): Dropout(p=0.1, inplace=False)\n",
984
- " (dropout2): Dropout(p=0.1, inplace=False)\n",
985
- " )\n",
986
- " )\n",
987
- " )\n",
988
- " (classifier): Sequential(\n",
989
- " (0): Dropout(p=0.1, inplace=False)\n",
990
- " (1): ReLU()\n",
991
- " (2): Linear(in_features=768, out_features=512, bias=True)\n",
992
- " (3): Dropout(p=0.1, inplace=False)\n",
993
- " (4): ReLU()\n",
994
- " (5): Linear(in_features=512, out_features=256, bias=True)\n",
995
- " (6): Dropout(p=0.1, inplace=False)\n",
996
- " (7): ReLU()\n",
997
- " (8): Linear(in_features=256, out_features=128, bias=True)\n",
998
- " (9): Dropout(p=0.1, inplace=False)\n",
999
- " (10): ReLU()\n",
1000
- " (11): Linear(in_features=128, out_features=64, bias=True)\n",
1001
- " (12): Dropout(p=0.1, inplace=False)\n",
1002
- " (13): ReLU()\n",
1003
- " (14): Linear(in_features=64, out_features=2, bias=True)\n",
1004
- " )\n",
1005
- ")"
1006
- ]
1007
- },
1008
- "execution_count": 65,
1009
- "metadata": {},
1010
- "output_type": "execute_result"
1011
- }
1012
- ],
1013
- "source": [
1014
- "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
1015
- "model = QuestionDistilBERT(mod)\n",
1016
- "model.to(device)"
1017
- ]
1018
- },
1019
- {
1020
- "cell_type": "code",
1021
- "execution_count": 66,
1022
- "id": "340857f9",
1023
- "metadata": {},
1024
- "outputs": [
1025
- {
1026
- "name": "stdout",
1027
- "output_type": "stream",
1028
- "text": [
1029
- "+---------------------------------------+------------+\n",
1030
- "| Modules | Parameters |\n",
1031
- "+---------------------------------------+------------+\n",
1032
- "| te.layers.0.self_attn.in_proj_weight | 1769472 |\n",
1033
- "| te.layers.0.self_attn.in_proj_bias | 2304 |\n",
1034
- "| te.layers.0.self_attn.out_proj.weight | 589824 |\n",
1035
- "| te.layers.0.self_attn.out_proj.bias | 768 |\n",
1036
- "| te.layers.0.linear1.weight | 1572864 |\n",
1037
- "| te.layers.0.linear1.bias | 2048 |\n",
1038
- "| te.layers.0.linear2.weight | 1572864 |\n",
1039
- "| te.layers.0.linear2.bias | 768 |\n",
1040
- "| te.layers.0.norm1.weight | 768 |\n",
1041
- "| te.layers.0.norm1.bias | 768 |\n",
1042
- "| te.layers.0.norm2.weight | 768 |\n",
1043
- "| te.layers.0.norm2.bias | 768 |\n",
1044
- "| te.layers.1.self_attn.in_proj_weight | 1769472 |\n",
1045
- "| te.layers.1.self_attn.in_proj_bias | 2304 |\n",
1046
- "| te.layers.1.self_attn.out_proj.weight | 589824 |\n",
1047
- "| te.layers.1.self_attn.out_proj.bias | 768 |\n",
1048
- "| te.layers.1.linear1.weight | 1572864 |\n",
1049
- "| te.layers.1.linear1.bias | 2048 |\n",
1050
- "| te.layers.1.linear2.weight | 1572864 |\n",
1051
- "| te.layers.1.linear2.bias | 768 |\n",
1052
- "| te.layers.1.norm1.weight | 768 |\n",
1053
- "| te.layers.1.norm1.bias | 768 |\n",
1054
- "| te.layers.1.norm2.weight | 768 |\n",
1055
- "| te.layers.1.norm2.bias | 768 |\n",
1056
- "| te.layers.2.self_attn.in_proj_weight | 1769472 |\n",
1057
- "| te.layers.2.self_attn.in_proj_bias | 2304 |\n",
1058
- "| te.layers.2.self_attn.out_proj.weight | 589824 |\n",
1059
- "| te.layers.2.self_attn.out_proj.bias | 768 |\n",
1060
- "| te.layers.2.linear1.weight | 1572864 |\n",
1061
- "| te.layers.2.linear1.bias | 2048 |\n",
1062
- "| te.layers.2.linear2.weight | 1572864 |\n",
1063
- "| te.layers.2.linear2.bias | 768 |\n",
1064
- "| te.layers.2.norm1.weight | 768 |\n",
1065
- "| te.layers.2.norm1.bias | 768 |\n",
1066
- "| te.layers.2.norm2.weight | 768 |\n",
1067
- "| te.layers.2.norm2.bias | 768 |\n",
1068
- "| classifier.2.weight | 393216 |\n",
1069
- "| classifier.2.bias | 512 |\n",
1070
- "| classifier.5.weight | 131072 |\n",
1071
- "| classifier.5.bias | 256 |\n",
1072
- "| classifier.8.weight | 32768 |\n",
1073
- "| classifier.8.bias | 128 |\n",
1074
- "| classifier.11.weight | 8192 |\n",
1075
- "| classifier.11.bias | 64 |\n",
1076
- "| classifier.14.weight | 128 |\n",
1077
- "| classifier.14.bias | 2 |\n",
1078
- "+---------------------------------------+------------+\n",
1079
- "Total Trainable Params: 17108290\n"
1080
- ]
1081
- },
1082
- {
1083
- "data": {
1084
- "text/plain": [
1085
- "17108290"
1086
- ]
1087
- },
1088
- "execution_count": 66,
1089
- "metadata": {},
1090
- "output_type": "execute_result"
1091
- }
1092
- ],
1093
- "source": [
1094
- "count_parameters(model)"
1095
- ]
1096
- },
1097
- {
1098
- "cell_type": "markdown",
1099
- "id": "9babd013",
1100
- "metadata": {},
1101
- "source": [
1102
- "### Testing the model\n",
1103
- "This is the same procedure as in `distilbert.ipynb`. "
1104
- ]
1105
- },
1106
- {
1107
- "cell_type": "code",
1108
- "execution_count": 67,
1109
- "id": "694c828b",
1110
- "metadata": {},
1111
- "outputs": [],
1112
- "source": [
1113
- "# smoke-test setup: build a small dataset from the first two SQuAD files and feed *that* to the loader\n",
1114
- "batch_size = 8\n",
1115
- "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
1116
- "test_ds_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)\n",
1117
- "optim=torch.optim.Adam(model.parameters())"
1118
- ]
1119
- },
1120
- {
1121
- "cell_type": "code",
1122
- "execution_count": 68,
1123
- "id": "a76587df",
1124
- "metadata": {},
1125
- "outputs": [
1126
- {
1127
- "name": "stdout",
1128
- "output_type": "stream",
1129
- "text": [
1130
- "Passed\n"
1131
- ]
1132
- }
1133
- ],
1134
- "source": [
1135
- "test_model(model, optim, test_ds_loader, device)"
1136
- ]
1137
- },
1138
- {
1139
- "cell_type": "markdown",
1140
- "id": "7c326e8e",
1141
- "metadata": {},
1142
- "source": [
1143
- "### Training the model\n",
1144
- "* Parameter Tuning:\n",
1145
- " * Learning Rate: I experimented with several values, 1e-4 seemed to work best for me. 1e-3 was very unstable and 1e-5 was too small.\n",
1146
- " * Gradient Clipping: I experimented with this, but the difference was only minimal\n",
1147
- "\n",
1148
- "Data:\n",
1149
- "* I first used only the SQuAD dataset, but generalisation is a problem\n",
1150
- " * The dataset is relatively small and we often have entries with the same context but different questions\n",
1151
- " * I believe, the diversity is not big enough to train a fully functional model\n",
1152
- "* Hence, I included the Natural Questions dataset too\n",
1153
- " * It is however a lot more messy - I elaborated a bit more on this in `load_data.ipynb`\n",
1154
- "* Also the hotpotqa data was used\n",
1155
- "\n",
1156
- "Tested with: \n",
1157
- "* 3 Linear Layers\n",
1158
- " * Training Error high - needed more layers\n",
1159
- " * Already expected - this was mostly a Proof of Concept\n",
1160
- "* 1 TransformerEncoder with 4 attention heads + 1 Linear Layer:\n",
1161
- " * Training Error was high, still too simple\n",
1162
- "* 1 TransformerEncoder with 8 heads + 1 Linear Layer:\n",
1163
- " * Training Error gets lower, however stagnates at some point\n",
1164
- " * Probably still too simple, it doesn't generalise either\n",
1165
- "* 2 TransformerEncoder with 8 and 4 heads + 1 Linear Layer:\n",
1166
- " * Loss gets down but doesn't go further after some time\n"
1167
- ]
1168
- },
1169
- {
1170
- "cell_type": "code",
1171
- "execution_count": null,
1172
- "id": "2e9f4bd3",
1173
- "metadata": {},
1174
- "outputs": [],
1175
- "source": [
1176
- "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=nat_paths, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
1177
- "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
1178
- "\n",
1179
- "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
1180
- " natural_question_paths=None, \n",
1181
- " hotpotqa_paths = None, tokenizer=tokenizer)\n",
1182
- "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
1183
- ]
1184
- },
1185
- {
1186
- "cell_type": "code",
1187
- "execution_count": 26,
1188
- "id": "03a6de37",
1189
- "metadata": {},
1190
- "outputs": [],
1191
- "source": [
1192
- "model = QuestionDistilBERT(mod).to(device)  # the new head layers must live on the same device as the batches"
1193
- ]
1194
- },
1195
- {
1196
- "cell_type": "code",
1197
- "execution_count": 41,
1198
- "id": "ed854b73",
1199
- "metadata": {},
1200
- "outputs": [],
1201
- "source": [
1202
- "from torch.optim import AdamW, RMSprop\n",
1203
- "\n",
1204
- "model.train()\n",
1205
- "optim = RMSprop(model.parameters(), lr=1e-4)"
1206
- ]
1207
- },
1208
- {
1209
- "cell_type": "code",
1210
- "execution_count": 42,
1211
- "id": "79fdfcc9",
1212
- "metadata": {},
1213
- "outputs": [],
1214
- "source": [
1215
- "from torch.utils.tensorboard import SummaryWriter\n",
1216
- "writer = SummaryWriter()"
1217
- ]
1218
- },
1219
- {
1220
- "cell_type": "code",
1221
- "execution_count": null,
1222
- "id": "f7bddb43",
1223
- "metadata": {},
1224
- "outputs": [
1225
- {
1226
- "data": {
1227
- "application/vnd.jupyter.widget-view+json": {
1228
- "model_id": "5e9e74167c4b4b22b3218f4ca3c5abf0",
1229
- "version_major": 2,
1230
- "version_minor": 0
1231
- },
1232
- "text/plain": [
1233
- " 0%| | 0/21750 [00:00<?, ?it/s]"
1234
- ]
1235
- },
1236
- "metadata": {},
1237
- "output_type": "display_data"
1238
- },
1239
- {
1240
- "name": "stdout",
1241
- "output_type": "stream",
1242
- "text": [
1243
- "Mean Training Error 3.8791405910185013\n"
1244
- ]
1245
- },
1246
- {
1247
- "data": {
1248
- "application/vnd.jupyter.widget-view+json": {
1249
- "model_id": "f3ce562fc61d4bfc83a4860eb06bc20c",
1250
- "version_major": 2,
1251
- "version_minor": 0
1252
- },
1253
- "text/plain": [
1254
- " 0%| | 0/1250 [00:00<?, ?it/s]"
1255
- ]
1256
- },
1257
- "metadata": {},
1258
- "output_type": "display_data"
1259
- },
1260
- {
1261
- "name": "stdout",
1262
- "output_type": "stream",
1263
- "text": [
1264
- "Mean Test Error 3.7705092002868654\n"
1265
- ]
1266
- },
1267
- {
1268
- "data": {
1269
- "application/vnd.jupyter.widget-view+json": {
1270
- "model_id": "2e84e21cedd446a0a5f5a40501711d1c",
1271
- "version_major": 2,
1272
- "version_minor": 0
1273
- },
1274
- "text/plain": [
1275
- " 0%| | 0/21750 [00:00<?, ?it/s]"
1276
- ]
1277
- },
1278
- "metadata": {},
1279
- "output_type": "display_data"
1280
- },
1281
- {
1282
- "name": "stdout",
1283
- "output_type": "stream",
1284
- "text": [
1285
- "Mean Training Error 3.7389922174091996\n"
1286
- ]
1287
- },
1288
- {
1289
- "data": {
1290
- "application/vnd.jupyter.widget-view+json": {
1291
- "model_id": "07135c48be0146498cd37d767c1ee6ab",
1292
- "version_major": 2,
1293
- "version_minor": 0
1294
- },
1295
- "text/plain": [
1296
- " 0%| | 0/1250 [00:00<?, ?it/s]"
1297
- ]
1298
- },
1299
- "metadata": {},
1300
- "output_type": "display_data"
1301
- },
1302
- {
1303
- "name": "stdout",
1304
- "output_type": "stream",
1305
- "text": [
1306
- "Mean Test Error 3.7443671816825868\n"
1307
- ]
1308
- },
1309
- {
1310
- "data": {
1311
- "application/vnd.jupyter.widget-view+json": {
1312
- "model_id": "e9a51fbabc7043c2819a68e247e4a3ec",
1313
- "version_major": 2,
1314
- "version_minor": 0
1315
- },
1316
- "text/plain": [
1317
- " 0%| | 0/21750 [00:00<?, ?it/s]"
1318
- ]
1319
- },
1320
- "metadata": {},
1321
- "output_type": "display_data"
1322
- },
1323
- {
1324
- "name": "stdout",
1325
- "output_type": "stream",
1326
- "text": [
1327
- "Mean Training Error 3.7031057048117977\n"
1328
- ]
1329
- },
1330
- {
1331
- "data": {
1332
- "application/vnd.jupyter.widget-view+json": {
1333
- "model_id": "bfdbcc9fe32542a19c47bc1d7704400e",
1334
- "version_major": 2,
1335
- "version_minor": 0
1336
- },
1337
- "text/plain": [
1338
- " 0%| | 0/1250 [00:00<?, ?it/s]"
1339
- ]
1340
- },
1341
- "metadata": {},
1342
- "output_type": "display_data"
1343
- },
1344
- {
1345
- "name": "stdout",
1346
- "output_type": "stream",
1347
- "text": [
1348
- "Mean Test Error 3.743248237323761\n"
1349
- ]
1350
- },
1351
- {
1352
- "data": {
1353
- "application/vnd.jupyter.widget-view+json": {
1354
- "model_id": "81fd1278b22643dc9fb3ac306533a240",
1355
- "version_major": 2,
1356
- "version_minor": 0
1357
- },
1358
- "text/plain": [
1359
- " 0%| | 0/21750 [00:00<?, ?it/s]"
1360
- ]
1361
- },
1362
- "metadata": {},
1363
- "output_type": "display_data"
1364
- },
1365
- {
1366
- "name": "stdout",
1367
- "output_type": "stream",
1368
- "text": [
1369
- "Mean Training Error 3.6711661003430685\n"
1370
- ]
1371
- },
1372
- {
1373
- "data": {
1374
- "application/vnd.jupyter.widget-view+json": {
1375
- "model_id": "8b38d6cd44e048ec8bcd6b5cb86cce16",
1376
- "version_major": 2,
1377
- "version_minor": 0
1378
- },
1379
- "text/plain": [
1380
- " 0%| | 0/1250 [00:00<?, ?it/s]"
1381
- ]
1382
- },
1383
- "metadata": {},
1384
- "output_type": "display_data"
1385
- },
1386
- {
1387
- "name": "stdout",
1388
- "output_type": "stream",
1389
- "text": [
1390
- "Mean Test Error 3.740310479736328\n"
1391
- ]
1392
- },
1393
- {
1394
- "data": {
1395
- "application/vnd.jupyter.widget-view+json": {
1396
- "model_id": "825248aa3f934f4aade9d973e6f3b43e",
1397
- "version_major": 2,
1398
- "version_minor": 0
1399
- },
1400
- "text/plain": [
1401
- " 0%| | 0/21750 [00:00<?, ?it/s]"
1402
- ]
1403
- },
1404
- "metadata": {},
1405
- "output_type": "display_data"
1406
- },
1407
- {
1408
- "name": "stdout",
1409
- "output_type": "stream",
1410
- "text": [
1411
- "Mean Training Error 3.6591619139813827\n"
1412
- ]
1413
- },
1414
- {
1415
- "data": {
1416
- "application/vnd.jupyter.widget-view+json": {
1417
- "model_id": "edceb7af0ec6450997820967638c12db",
1418
- "version_major": 2,
1419
- "version_minor": 0
1420
- },
1421
- "text/plain": [
1422
- " 0%| | 0/1250 [00:00<?, ?it/s]"
1423
- ]
1424
- },
1425
- "metadata": {},
1426
- "output_type": "display_data"
1427
- },
1428
- {
1429
- "name": "stdout",
1430
- "output_type": "stream",
1431
- "text": [
1432
- "Mean Test Error 3.8138498876571654\n"
1433
- ]
1434
- },
1435
- {
1436
- "data": {
1437
- "application/vnd.jupyter.widget-view+json": {
1438
- "model_id": "27e903eb0d0f4f949c234e4faf4277a1",
1439
- "version_major": 2,
1440
- "version_minor": 0
1441
- },
1442
- "text/plain": [
1443
- " 0%| | 0/21750 [00:00<?, ?it/s]"
1444
- ]
1445
- },
1446
- "metadata": {},
1447
- "output_type": "display_data"
1448
- }
1449
- ],
1450
- "source": [
1451
- "epochs = 20\n",
1452
- "\n",
1453
- "for epoch in range(epochs):\n",
1454
- "    loop = tqdm(loader, leave=True)\n",
1455
- "    model.train()\n",
1456
- "    mean_training_error = []\n",
1457
- "    for batch in loop:\n",
1458
- "        optim.zero_grad()\n",
1459
- "        # move the batch tensors onto `device` before the forward pass\n",
1460
- "        input_ids = batch['input_ids'].to(device)\n",
1461
- "        attention_mask = batch['attention_mask'].to(device)\n",
1462
- "        start = batch['start_positions'].to(device)\n",
1463
- "        end = batch['end_positions'].to(device)\n",
1464
- "        # the gold span positions are passed in; the loss is read from outputs['loss'] below\n",
1465
- "        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
1466
- "\n",
1467
- "        loss = outputs['loss']\n",
1468
- "        loss.backward()\n",
1469
- "\n",
1470
- "        optim.step()\n",
1471
- "        mean_training_error.append(loss.item())\n",
1472
- "        loop.set_description(f'Epoch {epoch}')\n",
1473
- "        loop.set_postfix(loss=loss.item())\n",
1474
- "    print(\"Mean Training Error\", np.mean(mean_training_error))\n",
1475
- "    writer.add_scalar(\"Loss/train\", np.mean(mean_training_error), epoch)\n",
1476
- "    # evaluation pass: nothing is optimised here, so gradient tracking is switched off\n",
1477
- "    loop = tqdm(test_loader, leave=True)\n",
1478
- "    model.eval()\n",
1479
- "    mean_test_error = []\n",
1480
- "    with torch.no_grad():\n",
1481
- "        for batch in loop:\n",
1482
- "            input_ids = batch['input_ids'].to(device)\n",
1483
- "            attention_mask = batch['attention_mask'].to(device)\n",
1484
- "            start = batch['start_positions'].to(device)\n",
1485
- "            end = batch['end_positions'].to(device)\n",
1486
- "\n",
1487
- "            outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
1488
- "            # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
1489
- "            loss = outputs['loss']\n",
1490
- "\n",
1491
- "            mean_test_error.append(loss.item())\n",
1492
- "            loop.set_description(f'Epoch {epoch} Testset')\n",
1493
- "            loop.set_postfix(loss=loss.item())\n",
1494
- "    print(\"Mean Test Error\", np.mean(mean_test_error))\n",
1495
- "    writer.add_scalar(\"Loss/test\", np.mean(mean_test_error), epoch)"
1496
- ]
1497
- },
1498
- {
1499
- "cell_type": "code",
1500
- "execution_count": 238,
1501
- "id": "a9d6af2e",
1502
- "metadata": {},
1503
- "outputs": [],
1504
- "source": [
1505
- "writer.close()"
1506
- ]
1507
- },
1508
- {
1509
- "cell_type": "code",
1510
- "execution_count": 33,
1511
- "id": "ba43447e",
1512
- "metadata": {},
1513
- "outputs": [],
1514
- "source": [
1515
- "torch.save(model.state_dict(), \"distilbert_qa.model\")"
1516
- ]
1517
- },
1518
- {
1519
- "cell_type": "code",
1520
- "execution_count": 34,
1521
- "id": "ffc49aca",
1522
- "metadata": {},
1523
- "outputs": [
1524
- {
1525
- "data": {
1526
- "text/plain": [
1527
- "<All keys matched successfully>"
1528
- ]
1529
- },
1530
- "execution_count": 34,
1531
- "metadata": {},
1532
- "output_type": "execute_result"
1533
- }
1534
- ],
1535
- "source": [
1536
- "model = QuestionDistilBERT(mod)\n",
1537
- "model.load_state_dict(torch.load(\"distilbert_qa.model\"))"
1538
- ]
1539
- },
1540
- {
1541
- "cell_type": "code",
1542
- "execution_count": 35,
1543
- "id": "730a86c1",
1544
- "metadata": {},
1545
- "outputs": [
1546
- {
1547
- "name": "stderr",
1548
- "output_type": "stream",
1549
- "text": [
1550
- "100%|██████████| 2500/2500 [02:57<00:00, 14.09it/s]"
1551
- ]
1552
- },
1553
- {
1554
- "name": "stdout",
1555
- "output_type": "stream",
1556
- "text": [
1557
- "Mean EM: 0.0479\n",
1558
- "Mean F-1: 0.08989175857485086\n"
1559
- ]
1560
- },
1561
- {
1562
- "name": "stderr",
1563
- "output_type": "stream",
1564
- "text": [
1565
- "\n"
1566
- ]
1567
- }
1568
- ],
1569
- "source": [
1570
- "eval_test_set(model, tokenizer, test_loader, device)"
1571
- ]
1572
- },
1573
- {
1574
- "cell_type": "markdown",
1575
- "id": "bd1c7076",
1576
- "metadata": {},
1577
- "source": [
1578
- "## Reuse Layer\n",
1579
- "This was inspired by how well the original model with just one classification head worked. I felt like the main problem with the previous model was the lack of structure which was already in the layers, combined with the massive amount of resources needed for a Transformer.\n",
1580
- "\n",
1581
- "Hence, I tried cloning the last (and then the last two) layers of the DistilBERT model, putting a classifier on top and using this as the head. The base DistilBERT model is completely frozen. This worked extremely well, even though we fine-tune only about 21% of the parameters we did before (14 million as opposed to 66 million!). Below you can see the results.\n",
1582
- "\n",
1583
- "### Last DistilBERT layer\n",
1584
- "\n",
1585
- "Dropout 0.1 and RMSprop 1e-4:\n",
1586
- "* Mean EM: 0.3888\n",
1587
- "* Mean F-1: 0.5122932744694068\n",
1588
- "\n",
1589
- "Dropout 0.25: stagnates very early\n",
1590
- "* Mean EM: 0.3552\n",
1591
- "* Mean F-1: 0.4711235721312687\n",
1592
- "\n",
1593
- "Dropout 0.15: seems to work well - training and test error stagnate around 1.7 and 1.8 but good generalisation (need to add more layers)\n",
1594
- "* Mean EM: 0.4119\n",
1595
- "* Mean F-1: 0.5296387232893214\n",
1596
- "\n",
1597
- "### Last DistilBERT layer + more Dense layers\n",
1598
- "Dropout 0.15 + 4 dense layers ((768-512)-(512-256)-(256-128)-(128-2)) & ReLU: doesn't work too well - stagnates at around 2.4\n",
1599
- "\n",
1600
- "### Last two DistilBERT layers\n",
1601
- "Dropout 0.1 but last 2 DistilBERT layers: works very well, but early overfitting - maybe use more data\n",
1602
- "* Mean EM: 0.458\n",
1603
- "* Mean F-1: 0.6003368353673634\n",
1604
- "\n",
1605
- "Dropout 0.1 - last 2 DistilBERT layers: all data\n",
1606
- "* Mean EM: 0.484\n",
1607
- "* Mean F-1: 0.6344960035215299\n",
1608
- "\n",
1609
- "Dropout 0.15 - **BEST**\n",
1610
- "* Mean EM: 0.5178\n",
1611
- "* Mean F-1: 0.6671140689626448\n",
1612
- "\n",
1613
- "Dropout 0.2 - doesn't work too well\n",
1614
- "* Mean EM: 0.4353\n",
1615
- "* Mean F-1: 0.5776847879304647\n"
1616
- ]
1617
- },
1618
- {
1619
- "cell_type": "code",
1620
- "execution_count": 69,
1621
- "id": "654e09e8",
1622
- "metadata": {},
1623
- "outputs": [],
1624
- "source": [
1625
- "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
1626
- "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
1627
- "\n",
1628
- "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
1629
- " natural_question_paths=None, \n",
1630
- " hotpotqa_paths = None, tokenizer=tokenizer)\n",
1631
- "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
1632
- ]
1633
- },
1634
- {
1635
- "cell_type": "code",
1636
- "execution_count": 70,
1637
- "id": "707c0cb5",
1638
- "metadata": {},
1639
- "outputs": [
1640
- {
1641
- "data": {
1642
- "text/plain": [
1643
- "ReuseQuestionDistilBERT(\n",
1644
- " (te): ModuleList(\n",
1645
- " (0): TransformerBlock(\n",
1646
- " (attention): MultiHeadSelfAttention(\n",
1647
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1648
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1649
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1650
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1651
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1652
- " )\n",
1653
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1654
- " (ffn): FFN(\n",
1655
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1656
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1657
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1658
- " (activation): GELUActivation()\n",
1659
- " )\n",
1660
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1661
- " )\n",
1662
- " (1): TransformerBlock(\n",
1663
- " (attention): MultiHeadSelfAttention(\n",
1664
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1665
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1666
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1667
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1668
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1669
- " )\n",
1670
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1671
- " (ffn): FFN(\n",
1672
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1673
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1674
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1675
- " (activation): GELUActivation()\n",
1676
- " )\n",
1677
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1678
- " )\n",
1679
- " )\n",
1680
- " (distilbert): DistilBertModel(\n",
1681
- " (embeddings): Embeddings(\n",
1682
- " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
1683
- " (position_embeddings): Embedding(512, 768)\n",
1684
- " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1685
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1686
- " )\n",
1687
- " (transformer): Transformer(\n",
1688
- " (layer): ModuleList(\n",
1689
- " (0): TransformerBlock(\n",
1690
- " (attention): MultiHeadSelfAttention(\n",
1691
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1692
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1693
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1694
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1695
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1696
- " )\n",
1697
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1698
- " (ffn): FFN(\n",
1699
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1700
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1701
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1702
- " (activation): GELUActivation()\n",
1703
- " )\n",
1704
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1705
- " )\n",
1706
- " (1): TransformerBlock(\n",
1707
- " (attention): MultiHeadSelfAttention(\n",
1708
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1709
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1710
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1711
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1712
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1713
- " )\n",
1714
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1715
- " (ffn): FFN(\n",
1716
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1717
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1718
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1719
- " (activation): GELUActivation()\n",
1720
- " )\n",
1721
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1722
- " )\n",
1723
- " (2): TransformerBlock(\n",
1724
- " (attention): MultiHeadSelfAttention(\n",
1725
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1726
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1727
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1728
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1729
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1730
- " )\n",
1731
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1732
- " (ffn): FFN(\n",
1733
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1734
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1735
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1736
- " (activation): GELUActivation()\n",
1737
- " )\n",
1738
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1739
- " )\n",
1740
- " (3): TransformerBlock(\n",
1741
- " (attention): MultiHeadSelfAttention(\n",
1742
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1743
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1744
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1745
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1746
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1747
- " )\n",
1748
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1749
- " (ffn): FFN(\n",
1750
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1751
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1752
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1753
- " (activation): GELUActivation()\n",
1754
- " )\n",
1755
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1756
- " )\n",
1757
- " (4): TransformerBlock(\n",
1758
- " (attention): MultiHeadSelfAttention(\n",
1759
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1760
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1761
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1762
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1763
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1764
- " )\n",
1765
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1766
- " (ffn): FFN(\n",
1767
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1768
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1769
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1770
- " (activation): GELUActivation()\n",
1771
- " )\n",
1772
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1773
- " )\n",
1774
- " (5): TransformerBlock(\n",
1775
- " (attention): MultiHeadSelfAttention(\n",
1776
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1777
- " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1778
- " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1779
- " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1780
- " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1781
- " )\n",
1782
- " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1783
- " (ffn): FFN(\n",
1784
- " (dropout): Dropout(p=0.1, inplace=False)\n",
1785
- " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1786
- " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1787
- " (activation): GELUActivation()\n",
1788
- " )\n",
1789
- " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1790
- " )\n",
1791
- " )\n",
1792
- " )\n",
1793
- " )\n",
1794
- " (relu): ReLU()\n",
1795
- " (dropout): Dropout(p=0.15, inplace=False)\n",
1796
- " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
1797
- ")"
1798
- ]
1799
- },
1800
- "execution_count": 70,
1801
- "metadata": {},
1802
- "output_type": "execute_result"
1803
- }
1804
- ],
1805
- "source": [
1806
- "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n",
1807
- "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n",
1808
- "mod = model.distilbert\n",
1809
- "\n",
1810
- "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
1811
- "model = ReuseQuestionDistilBERT(mod)\n",
1812
- "model.to(device)"
1813
- ]
1814
- },
1815
- {
1816
- "cell_type": "code",
1817
- "execution_count": 71,
1818
- "id": "d2c6bff5",
1819
- "metadata": {},
1820
- "outputs": [
1821
- {
1822
- "name": "stdout",
1823
- "output_type": "stream",
1824
- "text": [
1825
- "+-------------------------------+------------+\n",
1826
- "| Modules | Parameters |\n",
1827
- "+-------------------------------+------------+\n",
1828
- "| te.0.attention.q_lin.weight | 589824 |\n",
1829
- "| te.0.attention.q_lin.bias | 768 |\n",
1830
- "| te.0.attention.k_lin.weight | 589824 |\n",
1831
- "| te.0.attention.k_lin.bias | 768 |\n",
1832
- "| te.0.attention.v_lin.weight | 589824 |\n",
1833
- "| te.0.attention.v_lin.bias | 768 |\n",
1834
- "| te.0.attention.out_lin.weight | 589824 |\n",
1835
- "| te.0.attention.out_lin.bias | 768 |\n",
1836
- "| te.0.sa_layer_norm.weight | 768 |\n",
1837
- "| te.0.sa_layer_norm.bias | 768 |\n",
1838
- "| te.0.ffn.lin1.weight | 2359296 |\n",
1839
- "| te.0.ffn.lin1.bias | 3072 |\n",
1840
- "| te.0.ffn.lin2.weight | 2359296 |\n",
1841
- "| te.0.ffn.lin2.bias | 768 |\n",
1842
- "| te.0.output_layer_norm.weight | 768 |\n",
1843
- "| te.0.output_layer_norm.bias | 768 |\n",
1844
- "| te.1.attention.q_lin.weight | 589824 |\n",
1845
- "| te.1.attention.q_lin.bias | 768 |\n",
1846
- "| te.1.attention.k_lin.weight | 589824 |\n",
1847
- "| te.1.attention.k_lin.bias | 768 |\n",
1848
- "| te.1.attention.v_lin.weight | 589824 |\n",
1849
- "| te.1.attention.v_lin.bias | 768 |\n",
1850
- "| te.1.attention.out_lin.weight | 589824 |\n",
1851
- "| te.1.attention.out_lin.bias | 768 |\n",
1852
- "| te.1.sa_layer_norm.weight | 768 |\n",
1853
- "| te.1.sa_layer_norm.bias | 768 |\n",
1854
- "| te.1.ffn.lin1.weight | 2359296 |\n",
1855
- "| te.1.ffn.lin1.bias | 3072 |\n",
1856
- "| te.1.ffn.lin2.weight | 2359296 |\n",
1857
- "| te.1.ffn.lin2.bias | 768 |\n",
1858
- "| te.1.output_layer_norm.weight | 768 |\n",
1859
- "| te.1.output_layer_norm.bias | 768 |\n",
1860
- "| classifier.weight | 1536 |\n",
1861
- "| classifier.bias | 2 |\n",
1862
- "+-------------------------------+------------+\n",
1863
- "Total Trainable Params: 14177282\n"
1864
- ]
1865
- },
1866
- {
1867
- "data": {
1868
- "text/plain": [
1869
- "14177282"
1870
- ]
1871
- },
1872
- "execution_count": 71,
1873
- "metadata": {},
1874
- "output_type": "execute_result"
1875
- }
1876
- ],
1877
- "source": [
1878
- "count_parameters(model)"
1879
- ]
1880
- },
1881
- {
1882
- "cell_type": "markdown",
1883
- "id": "c386c2eb",
1884
- "metadata": {},
1885
- "source": [
1886
- "### Testing the Model"
1887
- ]
1888
- },
1889
- {
1890
- "cell_type": "code",
1891
- "execution_count": 72,
1892
- "id": "818deed3",
1893
- "metadata": {},
1894
- "outputs": [],
1895
- "source": [
1896
- "# get smaller dataset\n",
1897
- "batch_size = 8\n",
1898
- "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
1899
- "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
1900
- "optim=torch.optim.Adam(model.parameters())"
1901
- ]
1902
- },
1903
- {
1904
- "cell_type": "code",
1905
- "execution_count": 73,
1906
- "id": "9da40760",
1907
- "metadata": {},
1908
- "outputs": [
1909
- {
1910
- "name": "stdout",
1911
- "output_type": "stream",
1912
- "text": [
1913
- "Passed\n"
1914
- ]
1915
- }
1916
- ],
1917
- "source": [
1918
- "test_model(model, optim, test_ds_loader, device)"
1919
- ]
1920
- },
1921
- {
1922
- "cell_type": "markdown",
1923
- "id": "c3f80248",
1924
- "metadata": {},
1925
- "source": [
1926
- "### Model Training"
1927
- ]
1928
- },
1929
- {
1930
- "cell_type": "code",
1931
- "execution_count": 24,
1932
- "id": "e1adabe6",
1933
- "metadata": {},
1934
- "outputs": [],
1935
- "source": [
1936
- "from torch.optim import AdamW, RMSprop\n",
1937
- "\n",
1938
- "model.train()\n",
1939
- "optim = AdamW(model.parameters(), lr=1e-4)"
1940
- ]
1941
- },
1942
- {
1943
- "cell_type": "code",
1944
- "execution_count": 25,
1945
- "id": "efe1cbd5",
1946
- "metadata": {},
1947
- "outputs": [
1948
- {
1949
- "data": {
1950
- "application/vnd.jupyter.widget-view+json": {
1951
- "model_id": "8785757b04214102830ded36c1392c8d",
1952
- "version_major": 2,
1953
- "version_minor": 0
1954
- },
1955
- "text/plain": [
1956
- " 0%| | 0/35000 [00:00<?, ?it/s]"
1957
- ]
1958
- },
1959
- "metadata": {},
1960
- "output_type": "display_data"
1961
- },
1962
- {
1963
- "name": "stdout",
1964
- "output_type": "stream",
1965
- "text": [
1966
- "Mean Training Error 2.6535016193100383\n"
1967
- ]
1968
- },
1969
- {
1970
- "data": {
1971
- "application/vnd.jupyter.widget-view+json": {
1972
- "model_id": "836f5365498642fa9ae891a86dca5892",
1973
- "version_major": 2,
1974
- "version_minor": 0
1975
- },
1976
- "text/plain": [
1977
- " 0%| | 0/2500 [00:00<?, ?it/s]"
1978
- ]
1979
- },
1980
- "metadata": {},
1981
- "output_type": "display_data"
1982
- },
1983
- {
1984
- "name": "stdout",
1985
- "output_type": "stream",
1986
- "text": [
1987
- "Mean Test Error 2.384517493388057\n"
1988
- ]
1989
- },
1990
- {
1991
- "data": {
1992
- "application/vnd.jupyter.widget-view+json": {
1993
- "model_id": "981e1cef83a1477e920d1cdbffdfcde1",
1994
- "version_major": 2,
1995
- "version_minor": 0
1996
- },
1997
- "text/plain": [
1998
- " 0%| | 0/35000 [00:00<?, ?it/s]"
1999
- ]
2000
- },
2001
- "metadata": {},
2002
- "output_type": "display_data"
2003
- },
2004
- {
2005
- "name": "stdout",
2006
- "output_type": "stream",
2007
- "text": [
2008
- "Mean Training Error 2.172889394424643\n"
2009
- ]
2010
- },
2011
- {
2012
- "data": {
2013
- "application/vnd.jupyter.widget-view+json": {
2014
- "model_id": "20a785e7fefb43239f1120992d2c3416",
2015
- "version_major": 2,
2016
- "version_minor": 0
2017
- },
2018
- "text/plain": [
2019
- " 0%| | 0/2500 [00:00<?, ?it/s]"
2020
- ]
2021
- },
2022
- "metadata": {},
2023
- "output_type": "display_data"
2024
- },
2025
- {
2026
- "name": "stdout",
2027
- "output_type": "stream",
2028
- "text": [
2029
- "Mean Test Error 2.013008696398139\n"
2030
- ]
2031
- },
2032
- {
2033
- "data": {
2034
- "application/vnd.jupyter.widget-view+json": {
2035
- "model_id": "47831e65b1ed4be78e8e7cb24068b0c3",
2036
- "version_major": 2,
2037
- "version_minor": 0
2038
- },
2039
- "text/plain": [
2040
- " 0%| | 0/35000 [00:00<?, ?it/s]"
2041
- ]
2042
- },
2043
- "metadata": {},
2044
- "output_type": "display_data"
2045
- },
2046
- {
2047
- "name": "stdout",
2048
- "output_type": "stream",
2049
- "text": [
2050
- "Mean Training Error 1.9743544759827\n"
2051
- ]
2052
- },
2053
- {
2054
- "data": {
2055
- "application/vnd.jupyter.widget-view+json": {
2056
- "model_id": "15904a3f930249fb944ea87184676e14",
2057
- "version_major": 2,
2058
- "version_minor": 0
2059
- },
2060
- "text/plain": [
2061
- " 0%| | 0/2500 [00:00<?, ?it/s]"
2062
- ]
2063
- },
2064
- "metadata": {},
2065
- "output_type": "display_data"
2066
- },
2067
- {
2068
- "name": "stdout",
2069
- "output_type": "stream",
2070
- "text": [
2071
- "Mean Test Error 1.8922049684919418\n"
2072
- ]
2073
- },
2074
- {
2075
- "data": {
2076
- "application/vnd.jupyter.widget-view+json": {
2077
- "model_id": "108bdbf644d94d78910195992b9e2652",
2078
- "version_major": 2,
2079
- "version_minor": 0
2080
- },
2081
- "text/plain": [
2082
- " 0%| | 0/35000 [00:00<?, ?it/s]"
2083
- ]
2084
- },
2085
- "metadata": {},
2086
- "output_type": "display_data"
2087
- },
2088
- {
2089
- "name": "stdout",
2090
- "output_type": "stream",
2091
- "text": [
2092
- "Mean Training Error 1.857202093189742\n"
2093
- ]
2094
- },
2095
- {
2096
- "data": {
2097
- "application/vnd.jupyter.widget-view+json": {
2098
- "model_id": "d6a75a6ab40d4a2599b7511bfc60bf83",
2099
- "version_major": 2,
2100
- "version_minor": 0
2101
- },
2102
- "text/plain": [
2103
- " 0%| | 0/2500 [00:00<?, ?it/s]"
2104
- ]
2105
- },
2106
- "metadata": {},
2107
- "output_type": "display_data"
2108
- },
2109
- {
2110
- "name": "stdout",
2111
- "output_type": "stream",
2112
- "text": [
2113
- "Mean Test Error 1.793771461571753\n"
2114
- ]
2115
- },
2116
- {
2117
- "data": {
2118
- "application/vnd.jupyter.widget-view+json": {
2119
- "model_id": "d3468a6ba72a4f42b0e7cc77ee0a0011",
2120
- "version_major": 2,
2121
- "version_minor": 0
2122
- },
2123
- "text/plain": [
2124
- " 0%| | 0/35000 [00:00<?, ?it/s]"
2125
- ]
2126
- },
2127
- "metadata": {},
2128
- "output_type": "display_data"
2129
- },
2130
- {
2131
- "name": "stdout",
2132
- "output_type": "stream",
2133
- "text": [
2134
- "Mean Training Error 1.7750537034896867\n"
2135
- ]
2136
- },
2137
- {
2138
- "data": {
2139
- "application/vnd.jupyter.widget-view+json": {
2140
- "model_id": "8aca0aa529d2452e8bd29fe7ada934f2",
2141
- "version_major": 2,
2142
- "version_minor": 0
2143
- },
2144
- "text/plain": [
2145
- " 0%| | 0/2500 [00:00<?, ?it/s]"
2146
- ]
2147
- },
2148
- "metadata": {},
2149
- "output_type": "display_data"
2150
- },
2151
- {
2152
- "name": "stdout",
2153
- "output_type": "stream",
2154
- "text": [
2155
- "Mean Test Error 1.7466133671954274\n"
2156
- ]
2157
- },
2158
- {
2159
- "data": {
2160
- "application/vnd.jupyter.widget-view+json": {
2161
- "model_id": "e09abdfa63c841ce97f445ba9b3eeaa8",
2162
- "version_major": 2,
2163
- "version_minor": 0
2164
- },
2165
- "text/plain": [
2166
- " 0%| | 0/35000 [00:00<?, ?it/s]"
2167
- ]
2168
- },
2169
- "metadata": {},
2170
- "output_type": "display_data"
2171
- },
2172
- {
2173
- "name": "stdout",
2174
- "output_type": "stream",
2175
- "text": [
2176
- "Mean Training Error 1.7097622096568346\n"
2177
- ]
2178
- },
2179
- {
2180
- "data": {
2181
- "application/vnd.jupyter.widget-view+json": {
2182
- "model_id": "0f49dd32d33e4f398be0942a59d735ce",
2183
- "version_major": 2,
2184
- "version_minor": 0
2185
- },
2186
- "text/plain": [
2187
- " 0%| | 0/2500 [00:00<?, ?it/s]"
2188
- ]
2189
- },
2190
- "metadata": {},
2191
- "output_type": "display_data"
2192
- },
2193
- {
2194
- "name": "stdout",
2195
- "output_type": "stream",
2196
- "text": [
2197
- "Mean Test Error 1.7642206047609448\n"
2198
- ]
2199
- },
2200
- {
2201
- "data": {
2202
- "application/vnd.jupyter.widget-view+json": {
2203
- "model_id": "a493dd70ffb64cd19830e5dc98607979",
2204
- "version_major": 2,
2205
- "version_minor": 0
2206
- },
2207
- "text/plain": [
2208
- " 0%| | 0/35000 [00:00<?, ?it/s]"
2209
- ]
2210
- },
2211
- "metadata": {},
2212
- "output_type": "display_data"
2213
- },
2214
- {
2215
- "name": "stderr",
2216
- "output_type": "stream",
2217
- "text": [
2218
- "\n",
2219
- "KeyboardInterrupt\n",
2220
- "\n"
2221
- ]
2222
- }
2223
- ],
2224
- "source": [
2225
- "epochs = 16\n",
2226
- "\n",
2227
- "for epoch in range(epochs):\n",
2228
- " loop = tqdm(loader, leave=True)\n",
2229
- " model.train()\n",
2230
- " mean_training_error = []\n",
2231
- " for batch in loop:\n",
2232
- " optim.zero_grad()\n",
2233
- " \n",
2234
- " input_ids = batch['input_ids'].to(device)\n",
2235
- " attention_mask = batch['attention_mask'].to(device)\n",
2236
- " start = batch['start_positions'].to(device)\n",
2237
- " end = batch['end_positions'].to(device)\n",
2238
- " \n",
2239
- " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
2240
- " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
2241
- " loss = outputs['loss']\n",
2242
- " loss.backward()\n",
2243
- " # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
2244
- " optim.step()\n",
2245
- " mean_training_error.append(loss.item())\n",
2246
- " loop.set_description(f'Epoch {epoch}')\n",
2247
- " loop.set_postfix(loss=loss.item())\n",
2248
- " print(\"Mean Training Error\", np.mean(mean_training_error))\n",
2249
- " \n",
2250
- " loop = tqdm(test_loader, leave=True)\n",
2251
- " model.eval()\n",
2252
- " mean_test_error = []\n",
2253
- " for batch in loop:\n",
2254
- " \n",
2255
- " input_ids = batch['input_ids'].to(device)\n",
2256
- " attention_mask = batch['attention_mask'].to(device)\n",
2257
- " start = batch['start_positions'].to(device)\n",
2258
- " end = batch['end_positions'].to(device)\n",
2259
- " \n",
2260
- " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
2261
- " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
2262
- " loss = outputs['loss']\n",
2263
- " \n",
2264
- " mean_test_error.append(loss.item())\n",
2265
- " loop.set_description(f'Epoch {epoch} Testset')\n",
2266
- " loop.set_postfix(loss=loss.item())\n",
2267
- " print(\"Mean Test Error\", np.mean(mean_test_error))\n",
2268
- " torch.save(model.state_dict(), \"distilbert_reuse_{}\".format(epoch))"
2269
- ]
2270
- },
2271
- {
2272
- "cell_type": "code",
2273
- "execution_count": 48,
2274
- "id": "fdf37d18",
2275
- "metadata": {},
2276
- "outputs": [],
2277
- "source": [
2278
- "torch.save(model.state_dict(), \"distilbert_reuse.model\")"
2279
- ]
2280
- },
2281
- {
2282
- "cell_type": "code",
2283
- "execution_count": 49,
2284
- "id": "d1cfded4",
2285
- "metadata": {},
2286
- "outputs": [],
2287
- "source": [
2288
- "m = ReuseQuestionDistilBERT(mod)\n",
2289
- "m.load_state_dict(torch.load(\"distilbert_reuse.model\"))\n",
2290
- "model = m"
2291
- ]
2292
- },
2293
- {
2294
- "cell_type": "code",
2295
- "execution_count": 47,
2296
- "id": "233bdc18",
2297
- "metadata": {},
2298
- "outputs": [
2299
- {
2300
- "name": "stderr",
2301
- "output_type": "stream",
2302
- "text": [
2303
- "100%|██████████| 2500/2500 [02:51<00:00, 14.59it/s]"
2304
- ]
2305
- },
2306
- {
2307
- "name": "stdout",
2308
- "output_type": "stream",
2309
- "text": [
2310
- "Mean EM: 0.5178\n",
2311
- "Mean F-1: 0.6671140689626448\n"
2312
- ]
2313
- },
2314
- {
2315
- "name": "stderr",
2316
- "output_type": "stream",
2317
- "text": [
2318
- "\n"
2319
- ]
2320
- }
2321
- ],
2322
- "source": [
2323
- "eval_test_set(model, tokenizer, test_loader, device)"
2324
- ]
2325
- },
2326
- {
2327
- "cell_type": "code",
2328
- "execution_count": null,
2329
- "id": "0fb1ce9e",
2330
- "metadata": {},
2331
- "outputs": [],
2332
- "source": []
2333
- }
2334
- ],
2335
- "metadata": {
2336
- "kernelspec": {
2337
- "display_name": "Python 3.10.8 ('venv': venv)",
2338
- "language": "python",
2339
- "name": "python3"
2340
- },
2341
- "language_info": {
2342
- "codemirror_mode": {
2343
- "name": "ipython",
2344
- "version": 3
2345
- },
2346
- "file_extension": ".py",
2347
- "mimetype": "text/x-python",
2348
- "name": "python",
2349
- "nbconvert_exporter": "python",
2350
- "pygments_lexer": "ipython3",
2351
- "version": "3.10.8"
2352
- },
2353
- "toc": {
2354
- "base_numbering": 1,
2355
- "nav_menu": {},
2356
- "number_sections": true,
2357
- "sideBar": true,
2358
- "skip_h1_title": false,
2359
- "title_cell": "Table of Contents",
2360
- "title_sidebar": "Contents",
2361
- "toc_cell": false,
2362
- "toc_position": {},
2363
- "toc_section_display": true,
2364
- "toc_window_display": false
2365
- },
2366
- "varInspector": {
2367
- "cols": {
2368
- "lenName": 16,
2369
- "lenType": 16,
2370
- "lenVar": 40
2371
- },
2372
- "kernels_config": {
2373
- "python": {
2374
- "delete_cmd_postfix": "",
2375
- "delete_cmd_prefix": "del ",
2376
- "library": "var_list.py",
2377
- "varRefreshCmd": "print(var_dic_list())"
2378
- },
2379
- "r": {
2380
- "delete_cmd_postfix": ") ",
2381
- "delete_cmd_prefix": "rm(",
2382
- "library": "var_list.r",
2383
- "varRefreshCmd": "cat(var_dic_list()) "
2384
- }
2385
- },
2386
- "types_to_exclude": [
2387
- "module",
2388
- "function",
2389
- "builtin_function_or_method",
2390
- "instance",
2391
- "_Feature"
2392
- ],
2393
- "window_display": false
2394
- },
2395
- "vscode": {
2396
- "interpreter": {
2397
- "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
2398
- }
2399
- }
2400
- },
2401
- "nbformat": 4,
2402
- "nbformat_minor": 5
2403
- }