Awaitinf commited on
Commit
9ddf22a
·
verified ·
1 Parent(s): c82fe68

Upload modeling.ipynb

Browse files
Files changed (1) hide show
  1. modeling.ipynb +1220 -0
modeling.ipynb ADDED
@@ -0,0 +1,1220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "initial_id",
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "ExecuteTime": {
9
+ "end_time": "2025-03-25T16:04:47.234614Z",
10
+ "start_time": "2025-03-25T16:04:47.228876Z"
11
+ }
12
+ },
13
+ "source": [
14
+ "import torch\n",
15
+ "import torch.nn as nn\n",
16
+ "from transformers import BertForSequenceClassification, BertTokenizerFast, BertModel, BertPreTrainedModel, BertConfig\n",
17
+ "from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput\n",
18
+ "from typing import Optional, Tuple, Union"
19
+ ],
20
+ "outputs": [],
21
+ "execution_count": 46
22
+ },
23
+ {
24
+ "metadata": {
25
+ "ExecuteTime": {
26
+ "end_time": "2025-03-25T16:04:47.392034Z",
27
+ "start_time": "2025-03-25T16:04:47.379839Z"
28
+ }
29
+ },
30
+ "cell_type": "code",
31
+ "source": [
32
+ "class BertConvModel(BertPreTrainedModel):\n",
33
+ " def __init__(self, config: BertConfig):\n",
34
+ " super().__init__(config)\n",
35
+ " self.encoder = BertModel(config)\n",
36
+ " self.conv3 = nn.Conv1d(\n",
37
+ " in_channels=config.hidden_size,\n",
38
+ " out_channels=256,\n",
39
+ " kernel_size=3,\n",
40
+ " padding=1,\n",
41
+ " )\n",
42
+ " self.conv5 = nn.Conv1d(\n",
43
+ " in_channels=config.hidden_size,\n",
44
+ " out_channels=256,\n",
45
+ " kernel_size=5,\n",
46
+ " padding=2,\n",
47
+ " )\n",
48
+ " self.conv7 = nn.Conv1d(\n",
49
+ " in_channels=config.hidden_size,\n",
50
+ " out_channels=256,\n",
51
+ " kernel_size=7,\n",
52
+ " padding=3,\n",
53
+ " )\n",
54
+ " self.conv_bn = nn.BatchNorm1d(256*3)\n",
55
+ " self.linear = nn.Linear(256*3, config.hidden_size)\n",
56
+ " self.act = nn.GELU()\n",
57
+ " self.layernorm = nn.LayerNorm(config.hidden_size)\n",
58
+ "\n",
59
+ " def forward(self, input_ids, attention_mask, token_type_ids):\n",
60
+ " encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
61
+ " last_hidden_state = encoder_outputs.last_hidden_state # [B, L, H]\n",
62
+ "\n",
63
+ " hidden_conv = last_hidden_state.permute(0, 2, 1) # [B, H, L]\n",
64
+ "\n",
65
+ " combined = torch.cat([\n",
66
+ " self.conv3(hidden_conv),\n",
67
+ " self.conv5(hidden_conv),\n",
68
+ " self.conv7(hidden_conv),\n",
69
+ " ], dim=1).permute(0,2, 1) # [B, L, H]\n",
70
+ " fused = self.linear(combined)\n",
71
+ " fused = self.act(fused)\n",
72
+ "\n",
73
+ " output = last_hidden_state + fused\n",
74
+ " output = self.layernorm(output)\n",
75
+ "\n",
76
+ " return BaseModelOutput(\n",
77
+ " last_hidden_state=output\n",
78
+ " )"
79
+ ],
80
+ "id": "34d786f5b97b8bab",
81
+ "outputs": [],
82
+ "execution_count": 47
83
+ },
84
+ {
85
+ "metadata": {
86
+ "ExecuteTime": {
87
+ "end_time": "2025-03-25T16:04:47.507570Z",
88
+ "start_time": "2025-03-25T16:04:47.490208Z"
89
+ }
90
+ },
91
+ "cell_type": "code",
92
+ "source": [
93
+ "class BertConvForSequenceClassification(BertPreTrainedModel):\n",
94
+ " def __init__(self, config: BertConfig):\n",
95
+ " super().__init__(config)\n",
96
+ " self.config = config\n",
97
+ " self.num_labels = config.num_labels\n",
98
+ " self.bert_conv = BertConvModel(config)\n",
99
+ " classifier_dropout = (\n",
100
+ " config.classifier_dropout if config.classifier_dropout is not None\n",
101
+ " else config.hidden_dropout_prob\n",
102
+ " )\n",
103
+ " self.dropout = nn.Dropout(classifier_dropout)\n",
104
+ " self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n",
105
+ "\n",
106
+ " self.post_init()\n",
107
+ "\n",
108
+ " def forward(\n",
109
+ " self,\n",
110
+ " input_ids: Optional[torch.Tensor] = None,\n",
111
+ " attention_mask: Optional[torch.Tensor] = None,\n",
112
+ " token_type_ids: Optional[torch.Tensor] = None,\n",
113
+ " labels: Optional[torch.Tensor] = None,\n",
114
+ " return_dict: Optional[bool] = None,\n",
115
+ " ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:\n",
116
+ " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
117
+ "\n",
118
+ " outputs = self.bert_conv(\n",
119
+ " input_ids=input_ids,\n",
120
+ " attention_mask=attention_mask,\n",
121
+ " token_type_ids=token_type_ids\n",
122
+ " )\n",
123
+ "\n",
124
+ " last_hidden_state = outputs.last_hidden_state\n",
125
+ " pooled_output = last_hidden_state[:, 0, :]\n",
126
+ " pooled_output = self.dropout(pooled_output)\n",
127
+ " logits = self.classifier(pooled_output)\n",
128
+ "\n",
129
+ " loss = None\n",
130
+ " if labels is not None:\n",
131
+ " if self.config.problem_type is None:\n",
132
+ " if self.num_labels == 1:\n",
133
+ " self.config.problem_type = \"regression\"\n",
134
+ " elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
135
+ " self.config.problem_type = \"single_label_classification\"\n",
136
+ " else:\n",
137
+ " self.config.problem_type = \"multi_label_classification\"\n",
138
+ "\n",
139
+ " if self.config.problem_type == \"regression\":\n",
140
+ " loss_fct = nn.MSELoss()\n",
141
+ " if self.num_labels == 1:\n",
142
+ " loss = loss_fct(logits.squeeze(), labels.squeeze())\n",
143
+ " else:\n",
144
+ " loss = loss_fct(logits, labels)\n",
145
+ " elif self.config.problem_type == \"single_label_classification\":\n",
146
+ " loss_fct = nn.CrossEntropyLoss()\n",
147
+ " loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n",
148
+ " elif self.config.problem_type == \"multi_label_classification\":\n",
149
+ " loss_fct = nn.BCEWithLogitsLoss()\n",
150
+ " loss = loss_fct(logits, labels)\n",
151
+ "\n",
152
+ " if not return_dict:\n",
153
+ " output = (logits,) + outputs[2:]\n",
154
+ " return ((loss,) + output) if loss is not None else output\n",
155
+ "\n",
156
+ " return SequenceClassifierOutput(\n",
157
+ " loss=loss,\n",
158
+ " logits=logits,\n",
159
+ " hidden_states=outputs.hidden_states,\n",
160
+ " attentions=outputs.attentions,\n",
161
+ " )"
162
+ ],
163
+ "id": "e1afead74e5d56c8",
164
+ "outputs": [],
165
+ "execution_count": 48
166
+ },
167
+ {
168
+ "metadata": {
169
+ "ExecuteTime": {
170
+ "end_time": "2025-03-25T16:04:47.593331Z",
171
+ "start_time": "2025-03-25T16:04:47.588584Z"
172
+ }
173
+ },
174
+ "cell_type": "code",
175
+ "source": "from datasets import load_dataset, concatenate_datasets, DatasetDict",
176
+ "id": "ef15760c46f3148b",
177
+ "outputs": [],
178
+ "execution_count": 49
179
+ },
180
+ {
181
+ "metadata": {
182
+ "ExecuteTime": {
183
+ "end_time": "2025-03-25T14:49:08.446489Z",
184
+ "start_time": "2025-03-25T14:49:03.966024Z"
185
+ }
186
+ },
187
+ "cell_type": "code",
188
+ "source": [
189
+ "mnli = load_dataset(\"bias-amplified-splits/mnli\", \"minority_examples\")\n",
190
+ "mnli"
191
+ ],
192
+ "id": "dec1c1ae07c4474",
193
+ "outputs": [
194
+ {
195
+ "data": {
196
+ "text/plain": [
197
+ "DatasetDict({\n",
198
+ " train.biased: Dataset({\n",
199
+ " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
200
+ " num_rows: 309873\n",
201
+ " })\n",
202
+ " train.anti_biased: Dataset({\n",
203
+ " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
204
+ " num_rows: 82829\n",
205
+ " })\n",
206
+ " validation_matched.biased: Dataset({\n",
207
+ " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
208
+ " num_rows: 7771\n",
209
+ " })\n",
210
+ " validation_matched.anti_biased: Dataset({\n",
211
+ " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
212
+ " num_rows: 2044\n",
213
+ " })\n",
214
+ " validation_mismatched.biased: Dataset({\n",
215
+ " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
216
+ " num_rows: 7797\n",
217
+ " })\n",
218
+ " validation_mismatched.anti_biased: Dataset({\n",
219
+ " features: ['premise', 'hypothesis', 'label', 'idx'],\n",
220
+ " num_rows: 2035\n",
221
+ " })\n",
222
+ "})"
223
+ ]
224
+ },
225
+ "execution_count": 7,
226
+ "metadata": {},
227
+ "output_type": "execute_result"
228
+ }
229
+ ],
230
+ "execution_count": 7
231
+ },
232
+ {
233
+ "metadata": {
234
+ "ExecuteTime": {
235
+ "end_time": "2025-03-25T14:49:08.554250Z",
236
+ "start_time": "2025-03-25T14:49:08.506856Z"
237
+ }
238
+ },
239
+ "cell_type": "code",
240
+ "source": [
241
+ "train = concatenate_datasets([mnli[\"train.biased\"], mnli[\"train.anti_biased\"]])\n",
242
+ "\n",
243
+ "val_matched_biased = mnli[\"validation_matched.biased\"]\n",
244
+ "val_matched_anti_biased = mnli[\"validation_matched.anti_biased\"]\n",
245
+ "val_matched = concatenate_datasets([val_matched_biased, val_matched_anti_biased])\n",
246
+ "\n",
247
+ "val_mismatched_biased = mnli[\"validation_mismatched.biased\"]\n",
248
+ "val_mismatched_anti_biased = mnli[\"validation_mismatched.anti_biased\"]\n",
249
+ "val_mismatched = concatenate_datasets([val_mismatched_biased, val_mismatched_anti_biased])\n",
250
+ "\n",
251
+ "test = concatenate_datasets([val_matched, val_mismatched])"
252
+ ],
253
+ "id": "f7a87126395bf25d",
254
+ "outputs": [],
255
+ "execution_count": 8
256
+ },
257
+ {
258
+ "metadata": {
259
+ "ExecuteTime": {
260
+ "end_time": "2025-03-25T14:49:08.594195Z",
261
+ "start_time": "2025-03-25T14:49:08.575454Z"
262
+ }
263
+ },
264
+ "cell_type": "code",
265
+ "source": [
266
+ "data = DatasetDict({\n",
267
+ " \"train\": train,\n",
268
+ " \"test\": test,\n",
269
+ "}).remove_columns(['idx'])\n",
270
+ "data"
271
+ ],
272
+ "id": "bdeb7ea17acd9bd4",
273
+ "outputs": [
274
+ {
275
+ "data": {
276
+ "text/plain": [
277
+ "DatasetDict({\n",
278
+ " train: Dataset({\n",
279
+ " features: ['premise', 'hypothesis', 'label'],\n",
280
+ " num_rows: 392702\n",
281
+ " })\n",
282
+ " test: Dataset({\n",
283
+ " features: ['premise', 'hypothesis', 'label'],\n",
284
+ " num_rows: 19647\n",
285
+ " })\n",
286
+ "})"
287
+ ]
288
+ },
289
+ "execution_count": 9,
290
+ "metadata": {},
291
+ "output_type": "execute_result"
292
+ }
293
+ ],
294
+ "execution_count": 9
295
+ },
296
+ {
297
+ "metadata": {
298
+ "ExecuteTime": {
299
+ "end_time": "2025-03-25T14:49:08.967254Z",
300
+ "start_time": "2025-03-25T14:49:08.636482Z"
301
+ }
302
+ },
303
+ "cell_type": "code",
304
+ "source": "tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")",
305
+ "id": "cb332f816eb96ca6",
306
+ "outputs": [],
307
+ "execution_count": 10
308
+ },
309
+ {
310
+ "metadata": {
311
+ "ExecuteTime": {
312
+ "end_time": "2025-03-25T14:49:09.013451Z",
313
+ "start_time": "2025-03-25T14:49:08.996755Z"
314
+ }
315
+ },
316
+ "cell_type": "code",
317
+ "source": [
318
+ "premise = \"The cat sat on the mat.\"\n",
319
+ "hypothesis = \"The cat was sitting on the mat.\"\n",
320
+ "\n",
321
+ "tokenizer.decode(tokenizer(premise, hypothesis, padding=True)['input_ids'])"
322
+ ],
323
+ "id": "e6b12359e9a054fc",
324
+ "outputs": [
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "'[CLS] the cat sat on the mat. [SEP] the cat was sitting on the mat. [SEP]'"
329
+ ]
330
+ },
331
+ "execution_count": 11,
332
+ "metadata": {},
333
+ "output_type": "execute_result"
334
+ }
335
+ ],
336
+ "execution_count": 11
337
+ },
338
+ {
339
+ "metadata": {
340
+ "ExecuteTime": {
341
+ "end_time": "2025-03-25T14:49:09.066234Z",
342
+ "start_time": "2025-03-25T14:49:09.061293Z"
343
+ }
344
+ },
345
+ "cell_type": "code",
346
+ "source": [
347
+ "def preprocess(examples):\n",
348
+ " return tokenizer(examples['premise'], examples['hypothesis'], truncation=\"longest_first\", max_length=512)"
349
+ ],
350
+ "id": "403351acf45cb794",
351
+ "outputs": [],
352
+ "execution_count": 12
353
+ },
354
+ {
355
+ "metadata": {
356
+ "ExecuteTime": {
357
+ "end_time": "2025-03-25T14:49:16.810416Z",
358
+ "start_time": "2025-03-25T14:49:09.115611Z"
359
+ }
360
+ },
361
+ "cell_type": "code",
362
+ "source": "tokenized_data = data.map(preprocess, batched=True, num_proc=20, remove_columns=(\"premise\", \"hypothesis\"))",
363
+ "id": "737168d34408b655",
364
+ "outputs": [
365
+ {
366
+ "data": {
367
+ "text/plain": [
368
+ "Map (num_proc=20): 0%| | 0/392702 [00:00<?, ? examples/s]"
369
+ ],
370
+ "application/vnd.jupyter.widget-view+json": {
371
+ "version_major": 2,
372
+ "version_minor": 0,
373
+ "model_id": "a5873684701845f1ae4038a470915b95"
374
+ }
375
+ },
376
+ "metadata": {},
377
+ "output_type": "display_data"
378
+ },
379
+ {
380
+ "data": {
381
+ "text/plain": [
382
+ "Map (num_proc=20): 0%| | 0/19647 [00:00<?, ? examples/s]"
383
+ ],
384
+ "application/vnd.jupyter.widget-view+json": {
385
+ "version_major": 2,
386
+ "version_minor": 0,
387
+ "model_id": "db032cc4d3d94defb7ce0102e49fc98b"
388
+ }
389
+ },
390
+ "metadata": {},
391
+ "output_type": "display_data"
392
+ }
393
+ ],
394
+ "execution_count": 13
395
+ },
396
+ {
397
+ "metadata": {
398
+ "ExecuteTime": {
399
+ "end_time": "2025-03-25T14:49:16.840646Z",
400
+ "start_time": "2025-03-25T14:49:16.825416Z"
401
+ }
402
+ },
403
+ "cell_type": "code",
404
+ "source": [
405
+ "from transformers import DataCollatorWithPadding\n",
406
+ "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score"
407
+ ],
408
+ "id": "ec5d3fbf714a893a",
409
+ "outputs": [],
410
+ "execution_count": 14
411
+ },
412
+ {
413
+ "metadata": {
414
+ "ExecuteTime": {
415
+ "end_time": "2025-03-25T14:49:16.895814Z",
416
+ "start_time": "2025-03-25T14:49:16.891528Z"
417
+ }
418
+ },
419
+ "cell_type": "code",
420
+ "source": "data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='longest', max_length=512)",
421
+ "id": "ef0c6727f15456db",
422
+ "outputs": [],
423
+ "execution_count": 15
424
+ },
425
+ {
426
+ "metadata": {
427
+ "ExecuteTime": {
428
+ "end_time": "2025-03-25T14:49:16.941226Z",
429
+ "start_time": "2025-03-25T14:49:16.935457Z"
430
+ }
431
+ },
432
+ "cell_type": "code",
433
+ "source": [
434
+ "def compute_metrics(pred):\n",
435
+ " labels = pred.label_ids\n",
436
+ " preds = pred.predictions.argmax(-1)\n",
437
+ "\n",
438
+ " # Calculate accuracy\n",
439
+ " accuracy = accuracy_score(labels, preds)\n",
440
+ "\n",
441
+ " # Calculate precision, recall, and F1-score\n",
442
+ " precision = precision_score(labels, preds, average='weighted')\n",
443
+ " recall = recall_score(labels, preds, average='weighted')\n",
444
+ " f1 = f1_score(labels, preds, average='weighted')\n",
445
+ "\n",
446
+ " return {\n",
447
+ " 'accuracy': accuracy,\n",
448
+ " 'precision': precision,\n",
449
+ " 'recall': recall,\n",
450
+ " 'f1': f1\n",
451
+ " }"
452
+ ],
453
+ "id": "3cd73639be70537d",
454
+ "outputs": [],
455
+ "execution_count": 16
456
+ },
457
+ {
458
+ "metadata": {
459
+ "ExecuteTime": {
460
+ "end_time": "2025-03-25T14:49:16.996574Z",
461
+ "start_time": "2025-03-25T14:49:16.991601Z"
462
+ }
463
+ },
464
+ "cell_type": "code",
465
+ "source": [
466
+ "id2label = {0: \"entailment\", 1: \"neutral\", 2: \"contradiction\"}\n",
467
+ "label2id = {\"entailment\": 0, \"neutral\": 1, \"contradiction\": 2}"
468
+ ],
469
+ "id": "885857aacaa3eab7",
470
+ "outputs": [],
471
+ "execution_count": 17
472
+ },
473
+ {
474
+ "metadata": {
475
+ "ExecuteTime": {
476
+ "end_time": "2025-03-25T16:11:57.476093Z",
477
+ "start_time": "2025-03-25T16:11:54.107040Z"
478
+ }
479
+ },
480
+ "cell_type": "code",
481
+ "source": [
482
+ "config = BertConfig.from_pretrained(\"bert-base-uncased\", num_labels=3, id2label=id2label, label2id=label2id)\n",
483
+ "model = BertModel.from_pretrained('bert-base-uncased', config=config)\n",
484
+ "encoder = BertConvModel(config)\n",
485
+ "encoder.encoder = model\n",
486
+ "model = BertConvForSequenceClassification(config)\n",
487
+ "model.bert_conv = encoder\n",
488
+ "model"
489
+ ],
490
+ "id": "7ade0d7a2bcdb241",
491
+ "outputs": [
492
+ {
493
+ "data": {
494
+ "text/plain": [
495
+ "BertConvForSequenceClassification(\n",
496
+ " (bert_conv): BertConvModel(\n",
497
+ " (encoder): BertModel(\n",
498
+ " (embeddings): BertEmbeddings(\n",
499
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
500
+ " (position_embeddings): Embedding(512, 768)\n",
501
+ " (token_type_embeddings): Embedding(2, 768)\n",
502
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
503
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
504
+ " )\n",
505
+ " (encoder): BertEncoder(\n",
506
+ " (layer): ModuleList(\n",
507
+ " (0-11): 12 x BertLayer(\n",
508
+ " (attention): BertAttention(\n",
509
+ " (self): BertSdpaSelfAttention(\n",
510
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
511
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
512
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
513
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
514
+ " )\n",
515
+ " (output): BertSelfOutput(\n",
516
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
517
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
518
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
519
+ " )\n",
520
+ " )\n",
521
+ " (intermediate): BertIntermediate(\n",
522
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
523
+ " (intermediate_act_fn): GELUActivation()\n",
524
+ " )\n",
525
+ " (output): BertOutput(\n",
526
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
527
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
528
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
529
+ " )\n",
530
+ " )\n",
531
+ " )\n",
532
+ " )\n",
533
+ " (pooler): BertPooler(\n",
534
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
535
+ " (activation): Tanh()\n",
536
+ " )\n",
537
+ " )\n",
538
+ " (conv3): Conv1d(768, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
539
+ " (conv5): Conv1d(768, 256, kernel_size=(5,), stride=(1,), padding=(2,))\n",
540
+ " (conv7): Conv1d(768, 256, kernel_size=(7,), stride=(1,), padding=(3,))\n",
541
+ " (linear): Linear(in_features=768, out_features=768, bias=True)\n",
542
+ " (act): GELU(approximate='none')\n",
543
+ " (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
544
+ " )\n",
545
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
546
+ " (classifier): Linear(in_features=768, out_features=3, bias=True)\n",
547
+ ")"
548
+ ]
549
+ },
550
+ "execution_count": 57,
551
+ "metadata": {},
552
+ "output_type": "execute_result"
553
+ }
554
+ ],
555
+ "execution_count": 57
556
+ },
557
+ {
558
+ "metadata": {
559
+ "ExecuteTime": {
560
+ "end_time": "2025-03-25T14:49:17.860824Z",
561
+ "start_time": "2025-03-25T14:49:17.771010Z"
562
+ }
563
+ },
564
+ "cell_type": "code",
565
+ "source": [
566
+ "from transformers import TrainingArguments, Trainer, get_linear_schedule_with_warmup\n",
567
+ "from torch.optim import Adam"
568
+ ],
569
+ "id": "501931d88d6ef5f4",
570
+ "outputs": [],
571
+ "execution_count": 20
572
+ },
573
+ {
574
+ "metadata": {
575
+ "ExecuteTime": {
576
+ "end_time": "2025-03-25T14:52:07.994699Z",
577
+ "start_time": "2025-03-25T14:52:07.988599Z"
578
+ }
579
+ },
580
+ "cell_type": "code",
581
+ "source": [
582
+ "optimizer = Adam(\n",
583
+ " params=model.parameters(),\n",
584
+ " lr=2.5e-5,\n",
585
+ " weight_decay=0.01,\n",
586
+ " betas=(0.9, 0.999),\n",
587
+ " eps=1e-06,\n",
588
+ ")\n",
589
+ "scheduler = get_linear_schedule_with_warmup(\n",
590
+ " optimizer,\n",
591
+ " num_warmup_steps=500,\n",
592
+ " num_training_steps=60000,\n",
593
+ ")"
594
+ ],
595
+ "id": "8b13932fcba55735",
596
+ "outputs": [],
597
+ "execution_count": 38
598
+ },
599
+ {
600
+ "metadata": {
601
+ "ExecuteTime": {
602
+ "end_time": "2025-03-25T16:12:05.683266Z",
603
+ "start_time": "2025-03-25T16:12:05.481694Z"
604
+ }
605
+ },
606
+ "cell_type": "code",
607
+ "source": [
608
+ "training_args = TrainingArguments(\n",
609
+ " output_dir=\"./output\",\n",
610
+ " overwrite_output_dir=True,\n",
611
+ " eval_strategy=\"steps\",\n",
612
+ " logging_strategy=\"steps\",\n",
613
+ " save_strategy=\"steps\",\n",
614
+ " save_steps=5000,\n",
615
+ " eval_steps=5000,\n",
616
+ " logging_steps=5000,\n",
617
+ " max_steps=20000,\n",
618
+ " learning_rate=3e-5,\n",
619
+ " weight_decay=0.001,\n",
620
+ " adam_epsilon=1e-8,\n",
621
+ " warmup_steps=1000,\n",
622
+ " report_to=\"tensorboard\",\n",
623
+ " per_device_train_batch_size=64,\n",
624
+ " #gradient_accumulation_steps=2,\n",
625
+ " per_device_eval_batch_size=256,\n",
626
+ " fp16=True,\n",
627
+ ")\n",
628
+ "\n",
629
+ "trainer = Trainer(\n",
630
+ " model=model,\n",
631
+ " args=training_args,\n",
632
+ " train_dataset=tokenized_data['train'],\n",
633
+ " eval_dataset=tokenized_data['test'],\n",
634
+ " processing_class=tokenizer,\n",
635
+ " data_collator=data_collator,\n",
636
+ " #preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
637
+ " compute_metrics=compute_metrics,\n",
638
+ " #optimizers=(optimizer, scheduler),\n",
639
+ ")"
640
+ ],
641
+ "id": "7bf077e4e34721ef",
642
+ "outputs": [],
643
+ "execution_count": 58
644
+ },
645
+ {
646
+ "metadata": {
647
+ "ExecuteTime": {
648
+ "end_time": "2025-03-25T16:12:26.697585Z",
649
+ "start_time": "2025-03-25T16:12:06.841129Z"
650
+ }
651
+ },
652
+ "cell_type": "code",
653
+ "source": "trainer.evaluate()",
654
+ "id": "ff824e1aef72287e",
655
+ "outputs": [
656
+ {
657
+ "data": {
658
+ "text/plain": [
659
+ "<IPython.core.display.HTML object>"
660
+ ],
661
+ "text/html": [
662
+ "\n",
663
+ " <div>\n",
664
+ " \n",
665
+ " <progress value='154' max='77' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
666
+ " [77/77 15:51]\n",
667
+ " </div>\n",
668
+ " "
669
+ ]
670
+ },
671
+ "metadata": {},
672
+ "output_type": "display_data"
673
+ },
674
+ {
675
+ "data": {
676
+ "text/plain": [
677
+ "{'eval_loss': 1.1944094896316528,\n",
678
+ " 'eval_model_preparation_time': 0.007,\n",
679
+ " 'eval_accuracy': 0.3641268387031099,\n",
680
+ " 'eval_precision': 0.3050162608329799,\n",
681
+ " 'eval_recall': 0.3641268387031099,\n",
682
+ " 'eval_f1': 0.29583067778201166,\n",
683
+ " 'eval_runtime': 19.8257,\n",
684
+ " 'eval_samples_per_second': 990.988,\n",
685
+ " 'eval_steps_per_second': 3.884}"
686
+ ]
687
+ },
688
+ "execution_count": 59,
689
+ "metadata": {},
690
+ "output_type": "execute_result"
691
+ }
692
+ ],
693
+ "execution_count": 59
694
+ },
695
+ {
696
+ "metadata": {
697
+ "ExecuteTime": {
698
+ "end_time": "2025-03-25T17:16:14.590823Z",
699
+ "start_time": "2025-03-25T16:12:26.732033Z"
700
+ }
701
+ },
702
+ "cell_type": "code",
703
+ "source": "trainer.train()",
704
+ "id": "46524fd4a95af711",
705
+ "outputs": [
706
+ {
707
+ "data": {
708
+ "text/plain": [
709
+ "<IPython.core.display.HTML object>"
710
+ ],
711
+ "text/html": [
712
+ "\n",
713
+ " <div>\n",
714
+ " \n",
715
+ " <progress value='20000' max='20000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
716
+ " [20000/20000 1:03:47, Epoch 3/4]\n",
717
+ " </div>\n",
718
+ " <table border=\"1\" class=\"dataframe\">\n",
719
+ " <thead>\n",
720
+ " <tr style=\"text-align: left;\">\n",
721
+ " <th>Step</th>\n",
722
+ " <th>Training Loss</th>\n",
723
+ " <th>Validation Loss</th>\n",
724
+ " <th>Model Preparation Time</th>\n",
725
+ " <th>Accuracy</th>\n",
726
+ " <th>Precision</th>\n",
727
+ " <th>Recall</th>\n",
728
+ " <th>F1</th>\n",
729
+ " </tr>\n",
730
+ " </thead>\n",
731
+ " <tbody>\n",
732
+ " <tr>\n",
733
+ " <td>5000</td>\n",
734
+ " <td>0.567100</td>\n",
735
+ " <td>0.434957</td>\n",
736
+ " <td>0.007000</td>\n",
737
+ " <td>0.831832</td>\n",
738
+ " <td>0.836941</td>\n",
739
+ " <td>0.831832</td>\n",
740
+ " <td>0.832825</td>\n",
741
+ " </tr>\n",
742
+ " <tr>\n",
743
+ " <td>10000</td>\n",
744
+ " <td>0.368900</td>\n",
745
+ " <td>0.424474</td>\n",
746
+ " <td>0.007000</td>\n",
747
+ " <td>0.843895</td>\n",
748
+ " <td>0.845985</td>\n",
749
+ " <td>0.843895</td>\n",
750
+ " <td>0.844391</td>\n",
751
+ " </tr>\n",
752
+ " <tr>\n",
753
+ " <td>15000</td>\n",
754
+ " <td>0.275500</td>\n",
755
+ " <td>0.501343</td>\n",
756
+ " <td>0.007000</td>\n",
757
+ " <td>0.844556</td>\n",
758
+ " <td>0.847259</td>\n",
759
+ " <td>0.844556</td>\n",
760
+ " <td>0.845071</td>\n",
761
+ " </tr>\n",
762
+ " <tr>\n",
763
+ " <td>20000</td>\n",
764
+ " <td>0.201000</td>\n",
765
+ " <td>0.551570</td>\n",
766
+ " <td>0.007000</td>\n",
767
+ " <td>0.845676</td>\n",
768
+ " <td>0.848408</td>\n",
769
+ " <td>0.845676</td>\n",
770
+ " <td>0.846358</td>\n",
771
+ " </tr>\n",
772
+ " </tbody>\n",
773
+ "</table><p>"
774
+ ]
775
+ },
776
+ "metadata": {},
777
+ "output_type": "display_data"
778
+ },
779
+ {
780
+ "data": {
781
+ "text/plain": [
782
+ "TrainOutput(global_step=20000, training_loss=0.35313146667480466, metrics={'train_runtime': 3827.3631, 'train_samples_per_second': 334.434, 'train_steps_per_second': 5.226, 'total_flos': 7.376417927681814e+16, 'train_loss': 0.35313146667480466, 'epoch': 3.259452411994785})"
783
+ ]
784
+ },
785
+ "execution_count": 60,
786
+ "metadata": {},
787
+ "output_type": "execute_result"
788
+ }
789
+ ],
790
+ "execution_count": 60
791
+ },
792
+ {
793
+ "metadata": {},
794
+ "cell_type": "markdown",
795
+ "source": "Result:",
796
+ "id": "3531b1a18e32af40"
797
+ },
798
+ {
799
+ "metadata": {},
800
+ "cell_type": "markdown",
801
+ "source": [
802
+ "<table>\n",
803
+ " <thead>\n",
804
+ " <tr>\n",
805
+ " <th>Step</th>\n",
806
+ " <th>Training Loss</th>\n",
807
+ " <th>Validation Loss</th>\n",
808
+ " <th>Model Preparation Time</th>\n",
809
+ " <th>Accuracy</th>\n",
810
+ " <th>Precision</th>\n",
811
+ " <th>Recall</th>\n",
812
+ " <th>F1</th>\n",
813
+ " </tr>\n",
814
+ " </thead>\n",
815
+ " <tbody>\n",
816
+ " <tr>\n",
817
+ " <td>5000</td>\n",
818
+ " <td>0.567100</td>\n",
819
+ " <td>0.434957</td>\n",
820
+ " <td>0.007000</td>\n",
821
+ " <td>0.831832</td>\n",
822
+ " <td>0.836941</td>\n",
823
+ " <td>0.831832</td>\n",
824
+ " <td>0.832825</td>\n",
825
+ " </tr>\n",
826
+ " <tr>\n",
827
+ " <td>10000</td>\n",
828
+ " <td>0.368900</td>\n",
829
+ " <td>0.424474</td>\n",
830
+ " <td>0.007000</td>\n",
831
+ " <td>0.843895</td>\n",
832
+ " <td>0.845985</td>\n",
833
+ " <td>0.843895</td>\n",
834
+ " <td>0.844391</td>\n",
835
+ " </tr>\n",
836
+ " <tr>\n",
837
+ " <td>15000</td>\n",
838
+ " <td>0.275500</td>\n",
839
+ " <td>0.501343</td>\n",
840
+ " <td>0.007000</td>\n",
841
+ " <td>0.844556</td>\n",
842
+ " <td>0.847259</td>\n",
843
+ " <td>0.844556</td>\n",
844
+ " <td>0.845071</td>\n",
845
+ " </tr>\n",
846
+ " <tr>\n",
847
+ " <td>20000</td>\n",
848
+ " <td>0.201000</td>\n",
849
+ " <td>0.551570</td>\n",
850
+ " <td>0.007000</td>\n",
851
+ " <td>0.845676</td>\n",
852
+ " <td>0.848408</td>\n",
853
+ " <td>0.845676</td>\n",
854
+ " <td>0.846358</td>\n",
855
+ " </tr>\n",
856
+ " </tbody>\n",
857
+ " </table>\n"
858
+ ],
859
+ "id": "db1ed297ef851eab"
860
+ },
861
+ {
862
+ "metadata": {},
863
+ "cell_type": "markdown",
864
+ "source": "Comparison",
865
+ "id": "e484fedcd827fd96"
866
+ },
867
+ {
868
+ "metadata": {
869
+ "ExecuteTime": {
870
+ "end_time": "2025-03-25T17:16:15.342756Z",
871
+ "start_time": "2025-03-25T17:16:14.703412Z"
872
+ }
873
+ },
874
+ "cell_type": "code",
875
+ "source": "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\", num_labels=3, id2label=id2label, label2id=label2id)",
876
+ "id": "8b3a585df6e993c8",
877
+ "outputs": [
878
+ {
879
+ "name": "stderr",
880
+ "output_type": "stream",
881
+ "text": [
882
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
883
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
884
+ ]
885
+ }
886
+ ],
887
+ "execution_count": 61
888
+ },
889
+ {
890
+ "metadata": {
891
+ "ExecuteTime": {
892
+ "end_time": "2025-03-25T17:16:15.555113Z",
893
+ "start_time": "2025-03-25T17:16:15.392354Z"
894
+ }
895
+ },
896
+ "cell_type": "code",
897
+ "source": [
898
+ "training_args = TrainingArguments(\n",
899
+ " output_dir=\"./compare\",\n",
900
+ " overwrite_output_dir=True,\n",
901
+ " eval_strategy=\"steps\",\n",
902
+ " logging_strategy=\"steps\",\n",
903
+ " save_strategy=\"steps\",\n",
904
+ " save_steps=5000,\n",
905
+ " eval_steps=5000,\n",
906
+ " logging_steps=5000,\n",
907
+ " max_steps=20000,\n",
908
+ " learning_rate=3e-5,\n",
909
+ " weight_decay=0.001,\n",
910
+ " adam_epsilon=1e-8,\n",
911
+ " warmup_steps=1000,\n",
912
+ " report_to=\"tensorboard\",\n",
913
+ " per_device_train_batch_size=64,\n",
914
+ " #gradient_accumulation_steps=2,\n",
915
+ " per_device_eval_batch_size=256,\n",
916
+ " fp16=True,\n",
917
+ ")\n",
918
+ "\n",
919
+ "trainer = Trainer(\n",
920
+ " model=model,\n",
921
+ " args=training_args,\n",
922
+ " train_dataset=tokenized_data['train'],\n",
923
+ " eval_dataset=tokenized_data['test'],\n",
924
+ " processing_class=tokenizer,\n",
925
+ " data_collator=data_collator,\n",
926
+ " #preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
927
+ " compute_metrics=compute_metrics,\n",
928
+ " #optimizers=(optimizer, scheduler),\n",
929
+ ")"
930
+ ],
931
+ "id": "be0ec82ebb4c18ee",
932
+ "outputs": [],
933
+ "execution_count": 62
934
+ },
935
+ {
936
+ "metadata": {
937
+ "ExecuteTime": {
938
+ "end_time": "2025-03-25T17:16:36.635075Z",
939
+ "start_time": "2025-03-25T17:16:15.574263Z"
940
+ }
941
+ },
942
+ "cell_type": "code",
943
+ "source": "trainer.evaluate()",
944
+ "id": "157359d28e31f33f",
945
+ "outputs": [
946
+ {
947
+ "data": {
948
+ "text/plain": [
949
+ "<IPython.core.display.HTML object>"
950
+ ],
951
+ "text/html": [
952
+ "\n",
953
+ " <div>\n",
954
+ " \n",
955
+ " <progress value='154' max='77' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
956
+ " [77/77 16:10]\n",
957
+ " </div>\n",
958
+ " "
959
+ ]
960
+ },
961
+ "metadata": {},
962
+ "output_type": "display_data"
963
+ },
964
+ {
965
+ "data": {
966
+ "text/plain": [
967
+ "{'eval_loss': 1.173392415046692,\n",
968
+ " 'eval_model_preparation_time': 0.0034,\n",
969
+ " 'eval_accuracy': 0.3155189087392477,\n",
970
+ " 'eval_precision': 0.31114208439248486,\n",
971
+ " 'eval_recall': 0.3155189087392477,\n",
972
+ " 'eval_f1': 0.1570748637959829,\n",
973
+ " 'eval_runtime': 21.0427,\n",
974
+ " 'eval_samples_per_second': 933.671,\n",
975
+ " 'eval_steps_per_second': 3.659}"
976
+ ]
977
+ },
978
+ "execution_count": 63,
979
+ "metadata": {},
980
+ "output_type": "execute_result"
981
+ }
982
+ ],
983
+ "execution_count": 63
984
+ },
985
+ {
986
+ "metadata": {
987
+ "ExecuteTime": {
988
+ "end_time": "2025-03-25T18:19:14.283378Z",
989
+ "start_time": "2025-03-25T17:16:36.663666Z"
990
+ }
991
+ },
992
+ "cell_type": "code",
993
+ "source": "trainer.train()",
994
+ "id": "96899639b391ead",
995
+ "outputs": [
996
+ {
997
+ "data": {
998
+ "text/plain": [
999
+ "<IPython.core.display.HTML object>"
1000
+ ],
1001
+ "text/html": [
1002
+ "\n",
1003
+ " <div>\n",
1004
+ " \n",
1005
+ " <progress value='20000' max='20000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1006
+ " [20000/20000 1:02:36, Epoch 3/4]\n",
1007
+ " </div>\n",
1008
+ " <table border=\"1\" class=\"dataframe\">\n",
1009
+ " <thead>\n",
1010
+ " <tr style=\"text-align: left;\">\n",
1011
+ " <th>Step</th>\n",
1012
+ " <th>Training Loss</th>\n",
1013
+ " <th>Validation Loss</th>\n",
1014
+ " <th>Model Preparation Time</th>\n",
1015
+ " <th>Accuracy</th>\n",
1016
+ " <th>Precision</th>\n",
1017
+ " <th>Recall</th>\n",
1018
+ " <th>F1</th>\n",
1019
+ " </tr>\n",
1020
+ " </thead>\n",
1021
+ " <tbody>\n",
1022
+ " <tr>\n",
1023
+ " <td>5000</td>\n",
1024
+ " <td>0.566400</td>\n",
1025
+ " <td>0.437872</td>\n",
1026
+ " <td>0.003400</td>\n",
1027
+ " <td>0.831374</td>\n",
1028
+ " <td>0.836292</td>\n",
1029
+ " <td>0.831374</td>\n",
1030
+ " <td>0.832473</td>\n",
1031
+ " </tr>\n",
1032
+ " <tr>\n",
1033
+ " <td>10000</td>\n",
1034
+ " <td>0.369200</td>\n",
1035
+ " <td>0.426002</td>\n",
1036
+ " <td>0.003400</td>\n",
1037
+ " <td>0.843437</td>\n",
1038
+ " <td>0.846317</td>\n",
1039
+ " <td>0.843437</td>\n",
1040
+ " <td>0.844099</td>\n",
1041
+ " </tr>\n",
1042
+ " <tr>\n",
1043
+ " <td>15000</td>\n",
1044
+ " <td>0.276800</td>\n",
1045
+ " <td>0.481546</td>\n",
1046
+ " <td>0.003400</td>\n",
1047
+ " <td>0.842826</td>\n",
1048
+ " <td>0.845495</td>\n",
1049
+ " <td>0.842826</td>\n",
1050
+ " <td>0.843323</td>\n",
1051
+ " </tr>\n",
1052
+ " <tr>\n",
1053
+ " <td>20000</td>\n",
1054
+ " <td>0.203600</td>\n",
1055
+ " <td>0.529640</td>\n",
1056
+ " <td>0.003400</td>\n",
1057
+ " <td>0.844047</td>\n",
1058
+ " <td>0.846602</td>\n",
1059
+ " <td>0.844047</td>\n",
1060
+ " <td>0.844737</td>\n",
1061
+ " </tr>\n",
1062
+ " </tbody>\n",
1063
+ "</table><p>"
1064
+ ]
1065
+ },
1066
+ "metadata": {},
1067
+ "output_type": "display_data"
1068
+ },
1069
+ {
1070
+ "data": {
1071
+ "text/plain": [
1072
+ "TrainOutput(global_step=20000, training_loss=0.35399990234375, metrics={'train_runtime': 3757.0942, 'train_samples_per_second': 340.689, 'train_steps_per_second': 5.323, 'total_flos': 7.083480128775549e+16, 'train_loss': 0.35399990234375, 'epoch': 3.259452411994785})"
1073
+ ]
1074
+ },
1075
+ "execution_count": 64,
1076
+ "metadata": {},
1077
+ "output_type": "execute_result"
1078
+ }
1079
+ ],
1080
+ "execution_count": 64
1081
+ },
1082
+ {
1083
+ "metadata": {},
1084
+ "cell_type": "markdown",
1085
+ "source": "Result:",
1086
+ "id": "6d9604187a3d84c5"
1087
+ },
1088
+ {
1089
+ "metadata": {},
1090
+ "cell_type": "markdown",
1091
+ "source": [
1092
+ "<table>\n",
1093
+ " <thead>\n",
1094
+ " <tr>\n",
1095
+ " <th>Step</th>\n",
1096
+ " <th>Training Loss</th>\n",
1097
+ " <th>Validation Loss</th>\n",
1098
+ " <th>Model Preparation Time</th>\n",
1099
+ " <th>Accuracy</th>\n",
1100
+ " <th>Precision</th>\n",
1101
+ " <th>Recall</th>\n",
1102
+ " <th>F1</th>\n",
1103
+ " </tr>\n",
1104
+ " </thead>\n",
1105
+ " <tbody>\n",
1106
+ " <tr>\n",
1107
+ " <td>5000</td>\n",
1108
+ " <td>0.566400</td>\n",
1109
+ " <td>0.437872</td>\n",
1110
+ " <td>0.003400</td>\n",
1111
+ " <td>0.831374</td>\n",
1112
+ " <td>0.836292</td>\n",
1113
+ " <td>0.831374</td>\n",
1114
+ " <td>0.832473</td>\n",
1115
+ " </tr>\n",
1116
+ " <tr>\n",
1117
+ " <td>10000</td>\n",
1118
+ " <td>0.369200</td>\n",
1119
+ " <td>0.426002</td>\n",
1120
+ " <td>0.003400</td>\n",
1121
+ " <td>0.843437</td>\n",
1122
+ " <td>0.846317</td>\n",
1123
+ " <td>0.843437</td>\n",
1124
+ " <td>0.844099</td>\n",
1125
+ " </tr>\n",
1126
+ " <tr>\n",
1127
+ " <td>15000</td>\n",
1128
+ " <td>0.276800</td>\n",
1129
+ " <td>0.481546</td>\n",
1130
+ " <td>0.003400</td>\n",
1131
+ " <td>0.842826</td>\n",
1132
+ " <td>0.845495</td>\n",
1133
+ " <td>0.842826</td>\n",
1134
+ " <td>0.843323</td>\n",
1135
+ " </tr>\n",
1136
+ " <tr>\n",
1137
+ " <td>20000</td>\n",
1138
+ " <td>0.203600</td>\n",
1139
+ " <td>0.529640</td>\n",
1140
+ " <td>0.003400</td>\n",
1141
+ " <td>0.844047</td>\n",
1142
+ " <td>0.846602</td>\n",
1143
+ " <td>0.844047</td>\n",
1144
+ " <td>0.844737</td>\n",
1145
+ " </tr>\n",
1146
+ " </tbody>\n",
1147
+ "</table>"
1148
+ ],
1149
+ "id": "a32ebe5d6ce99f98"
1150
+ },
1151
+ {
1152
+ "metadata": {},
1153
+ "cell_type": "markdown",
1154
+ "source": "ChromaDB Embedding Function",
1155
+ "id": "f37bfcfe59d88a95"
1156
+ },
1157
+ {
1158
+ "metadata": {},
1159
+ "cell_type": "code",
1160
+ "outputs": [],
1161
+ "execution_count": null,
1162
+ "source": "from chromadb import Documents, EmbeddingFunction, Embeddings",
1163
+ "id": "291aa9e620dc571d"
1164
+ },
1165
+ {
1166
+ "metadata": {},
1167
+ "cell_type": "code",
1168
+ "outputs": [],
1169
+ "execution_count": null,
1170
+ "source": [
1171
+ "class BertConvEmbeddingFunction(EmbeddingFunction):\n",
1172
+ " def __init__(self, model_path, device=None):\n",
1173
+ " super().__init__()\n",
1174
+ " self.model = BertConvModel.from_pretrained(model_path)\n",
1175
+ " self.tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n",
1176
+ " self.device = device or torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1177
+ " self.model.to(self.device)\n",
1178
+ " self.model.eval()\n",
1179
+ "\n",
1180
+ " def __call__(self, input: Documents) -> Embeddings:\n",
1181
+ " encoded_input = self.tokenizer(\n",
1182
+ " input,\n",
1183
+ " padding=True,\n",
1184
+ " truncation=True,\n",
1185
+ " max_length=512,\n",
1186
+ " return_tensors=\"pt\",\n",
1187
+ " )\n",
1188
+ " encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}\n",
1189
+ "\n",
1190
+ " with torch.no_grad():\n",
1191
+ " outputs = self.model(**encoded_input, return_dict=True)\n",
1192
+ "\n",
1193
+ " embeddings = outputs.last_hidden_state.cpu().tolist()\n",
1194
+ " return embeddings"
1195
+ ],
1196
+ "id": "142c0fcc0a92667a"
1197
+ }
1198
+ ],
1199
+ "metadata": {
1200
+ "kernelspec": {
1201
+ "display_name": "Python 3",
1202
+ "language": "python",
1203
+ "name": "python3"
1204
+ },
1205
+ "language_info": {
1206
+ "codemirror_mode": {
1207
+ "name": "ipython",
1208
+ "version": 2
1209
+ },
1210
+ "file_extension": ".py",
1211
+ "mimetype": "text/x-python",
1212
+ "name": "python",
1213
+ "nbconvert_exporter": "python",
1214
+ "pygments_lexer": "ipython2",
1215
+ "version": "2.7.6"
1216
+ }
1217
+ },
1218
+ "nbformat": 4,
1219
+ "nbformat_minor": 5
1220
+ }