alxvlsv committed on
Commit
82c5f87
·
1 Parent(s): 5210d19

training notebook

Browse files
Files changed (1) hide show
  1. notebooks/emotions_training.ipynb +976 -0
notebooks/emotions_training.ipynb ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "8d5c6c94-3c83-4252-a1d5-690104ac69d9",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\""
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "markdown",
16
+ "id": "9a42461a-8ea6-4760-822d-48b7f055182e",
17
+ "metadata": {},
18
+ "source": [
19
+ "## Imports"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "id": "0374c250-ffae-4ca4-81de-fc1bdce0c98d",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "from datasets import load_dataset\n",
30
+ "import datasets\n",
31
+ "from transformers import pipeline\n",
32
+ "import torch\n",
33
+ "from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification\n",
34
+ "from torch.utils.data import DataLoader\n",
35
+ "from transformers import Trainer, TrainingArguments\n",
36
+ "\n",
37
+ "import numpy as np\n",
38
+ "from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "a9414e7d-1b89-4182-b6e6-e483a475f5e2",
44
+ "metadata": {},
45
+ "source": [
46
+ "## Dataset"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 3,
52
+ "id": "6c542e4a-61e1-4598-ac08-8a36024e07fd",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "ds = load_dataset(\"seara/ru_go_emotions\", \"simplified\")"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 4,
62
+ "id": "2bc3197f-513a-4267-8db2-d42d71185314",
63
+ "metadata": {},
64
+ "outputs": [
65
+ {
66
+ "data": {
67
+ "text/plain": [
68
+ "DatasetDict({\n",
69
+ " train: Dataset({\n",
70
+ " features: ['ru_text', 'text', 'labels', 'id'],\n",
71
+ " num_rows: 43410\n",
72
+ " })\n",
73
+ " validation: Dataset({\n",
74
+ " features: ['ru_text', 'text', 'labels', 'id'],\n",
75
+ " num_rows: 5426\n",
76
+ " })\n",
77
+ " test: Dataset({\n",
78
+ " features: ['ru_text', 'text', 'labels', 'id'],\n",
79
+ " num_rows: 5427\n",
80
+ " })\n",
81
+ "})"
82
+ ]
83
+ },
84
+ "execution_count": 4,
85
+ "metadata": {},
86
+ "output_type": "execute_result"
87
+ }
88
+ ],
89
+ "source": [
90
+ "ds"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 5,
96
+ "id": "b0bc35d4-492d-4fa9-9e5d-54495df25429",
97
+ "metadata": {},
98
+ "outputs": [
99
+ {
100
+ "data": {
101
+ "text/plain": [
102
+ "{'ru_text': Value(dtype='string', id=None),\n",
103
+ " 'text': Value(dtype='string', id=None),\n",
104
+ " 'labels': Sequence(feature=ClassLabel(names=['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral'], id=None), length=-1, id=None),\n",
105
+ " 'id': Value(dtype='string', id=None)}"
106
+ ]
107
+ },
108
+ "execution_count": 5,
109
+ "metadata": {},
110
+ "output_type": "execute_result"
111
+ }
112
+ ],
113
+ "source": [
114
+ "ds['train'].features"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 6,
120
+ "id": "000a51e0-df27-4a55-94d3-a31e0b0749f7",
121
+ "metadata": {},
122
+ "outputs": [
123
+ {
124
+ "data": {
125
+ "text/plain": [
126
+ "{'ru_text': 'Моя любимая еда — это все, что мне не приходилось готовить самому.',\n",
127
+ " 'text': \"My favourite food is anything I didn't have to cook myself.\",\n",
128
+ " 'labels': [27],\n",
129
+ " 'id': 'eebbqej'}"
130
+ ]
131
+ },
132
+ "execution_count": 6,
133
+ "metadata": {},
134
+ "output_type": "execute_result"
135
+ }
136
+ ],
137
+ "source": [
138
+ "ds['train'][0]"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 7,
144
+ "id": "47601982-246d-4815-b8fd-f4dd9ea3b736",
145
+ "metadata": {},
146
+ "outputs": [
147
+ {
148
+ "data": {
149
+ "text/plain": [
150
+ "['admiration',\n",
151
+ " 'amusement',\n",
152
+ " 'anger',\n",
153
+ " 'annoyance',\n",
154
+ " 'approval',\n",
155
+ " 'caring',\n",
156
+ " 'confusion',\n",
157
+ " 'curiosity',\n",
158
+ " 'desire',\n",
159
+ " 'disappointment',\n",
160
+ " 'disapproval',\n",
161
+ " 'disgust',\n",
162
+ " 'embarrassment',\n",
163
+ " 'excitement',\n",
164
+ " 'fear',\n",
165
+ " 'gratitude',\n",
166
+ " 'grief',\n",
167
+ " 'joy',\n",
168
+ " 'love',\n",
169
+ " 'nervousness',\n",
170
+ " 'optimism',\n",
171
+ " 'pride',\n",
172
+ " 'realization',\n",
173
+ " 'relief',\n",
174
+ " 'remorse',\n",
175
+ " 'sadness',\n",
176
+ " 'surprise',\n",
177
+ " 'neutral']"
178
+ ]
179
+ },
180
+ "execution_count": 7,
181
+ "metadata": {},
182
+ "output_type": "execute_result"
183
+ }
184
+ ],
185
+ "source": [
186
+ "ds['train'].features['labels'].feature.names"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": 8,
192
+ "id": "2ab5e33e-12b2-4739-a7cb-17b494bd1c1c",
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
+ "num_classes = len(ds['train'].features['labels'].feature.names)"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": 9,
202
+ "id": "938b4f50-7f39-43dc-9c15-9153838dd575",
203
+ "metadata": {},
204
+ "outputs": [
205
+ {
206
+ "data": {
207
+ "text/plain": [
208
+ "28"
209
+ ]
210
+ },
211
+ "execution_count": 9,
212
+ "metadata": {},
213
+ "output_type": "execute_result"
214
+ }
215
+ ],
216
+ "source": [
217
+ "num_classes"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "id": "32b7f3a5-265d-4674-8a12-63c4c88220b3",
223
+ "metadata": {},
224
+ "source": [
225
+ "## Model"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 10,
231
+ "id": "a01edf3f-5af5-495a-b1df-9f0fff586a48",
232
+ "metadata": {},
233
+ "outputs": [
234
+ {
235
+ "name": "stderr",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
239
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "# model_name = 'cointegrated/rubert-tiny2'\n",
245
+ "model_name = 'DeepPavlov/rubert-base-cased'\n",
246
+ "# model_name = 'distilbert-base-cased'\n",
247
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
248
+ "# model = AutoModel.from_pretrained(model_name)\n",
249
+ "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_classes, problem_type=\"multi_label_classification\")"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 11,
255
+ "id": "939cdacf-8c9a-418c-882e-d494f40bc5c6",
256
+ "metadata": {},
257
+ "outputs": [
258
+ {
259
+ "data": {
260
+ "text/plain": [
261
+ "BertForSequenceClassification(\n",
262
+ " (bert): BertModel(\n",
263
+ " (embeddings): BertEmbeddings(\n",
264
+ " (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
265
+ " (position_embeddings): Embedding(512, 768)\n",
266
+ " (token_type_embeddings): Embedding(2, 768)\n",
267
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
268
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
269
+ " )\n",
270
+ " (encoder): BertEncoder(\n",
271
+ " (layer): ModuleList(\n",
272
+ " (0-11): 12 x BertLayer(\n",
273
+ " (attention): BertAttention(\n",
274
+ " (self): BertSdpaSelfAttention(\n",
275
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
276
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
277
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
278
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
279
+ " )\n",
280
+ " (output): BertSelfOutput(\n",
281
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
282
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
283
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
284
+ " )\n",
285
+ " )\n",
286
+ " (intermediate): BertIntermediate(\n",
287
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
288
+ " (intermediate_act_fn): GELUActivation()\n",
289
+ " )\n",
290
+ " (output): BertOutput(\n",
291
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
292
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
293
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
294
+ " )\n",
295
+ " )\n",
296
+ " )\n",
297
+ " )\n",
298
+ " (pooler): BertPooler(\n",
299
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
300
+ " (activation): Tanh()\n",
301
+ " )\n",
302
+ " )\n",
303
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
304
+ " (classifier): Linear(in_features=768, out_features=28, bias=True)\n",
305
+ ")"
306
+ ]
307
+ },
308
+ "execution_count": 11,
309
+ "metadata": {},
310
+ "output_type": "execute_result"
311
+ }
312
+ ],
313
+ "source": [
314
+ "model"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": 12,
320
+ "id": "5828f4c4-f3cc-4bb7-99f1-3f9c001adfeb",
321
+ "metadata": {},
322
+ "outputs": [
323
+ {
324
+ "name": "stderr",
325
+ "output_type": "stream",
326
+ "text": [
327
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
328
+ ]
329
+ },
330
+ {
331
+ "name": "stdout",
332
+ "output_type": "stream",
333
+ "text": [
334
+ "SequenceClassifierOutput(loss=None, logits=tensor([[ 0.1746, 0.0823, -0.0107, 0.0438, 0.1315, -0.0874, 0.0370, 0.0327,\n",
335
+ " 0.3731, -0.0010, 0.0453, 0.0532, -0.0753, -0.1153, -0.2895, 0.0379,\n",
336
+ " -0.1960, 0.0733, -0.0482, 0.0208, -0.1297, 0.0133, -0.0212, -0.0974,\n",
337
+ " 0.1149, 0.0732, 0.0702, -0.2103],\n",
338
+ " [ 0.1693, -0.0349, 0.0288, -0.1285, -0.0371, -0.0007, 0.1751, 0.0494,\n",
339
+ " 0.2685, -0.1137, 0.0994, 0.0226, 0.0758, -0.0487, -0.0107, -0.0709,\n",
340
+ " 0.0073, -0.0396, 0.0166, 0.0358, 0.0964, -0.1060, 0.0394, 0.0961,\n",
341
+ " 0.0808, -0.0306, 0.2214, -0.0157]]), hidden_states=None, attentions=None)\n"
342
+ ]
343
+ }
344
+ ],
345
+ "source": [
346
+ "lines = [\n",
347
+ " \"Крутая тачка.\",\n",
348
+ " \"Моя любимая еда — это все, что мне не приходилось готовить самому.\",\n",
349
+ "]\n",
350
+ "\n",
351
+ "tokens_info = tokenizer(lines, padding=True, truncation=True, return_tensors=\"pt\")\n",
352
+ "\n",
353
+ "# прямой проход через модель\n",
354
+ "with torch.no_grad():\n",
355
+ " outputs = model(**tokens_info)\n",
356
+ "\n",
357
+ "print(outputs)"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "markdown",
362
+ "id": "8a9f349e-847e-47d5-a6c6-8ea4800f86be",
363
+ "metadata": {},
364
+ "source": [
365
+ "## Tokenize"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 27,
371
+ "id": "96aa6f5f-5c11-4db4-8288-f2a7bcc571b8",
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": [
375
+ "def tokenize_function(examples):\n",
376
+ " return tokenizer(examples[\"ru_text\"], padding='longest', truncation=True)"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 28,
382
+ "id": "2ea8b2bf-1a8a-4167-a1a0-1f8ce7b20e94",
383
+ "metadata": {},
384
+ "outputs": [],
385
+ "source": [
386
+ "def one_hot_labels(example):\n",
387
+ " one_hot = [0.0] * num_classes\n",
388
+ " for label in example[\"labels\"]:\n",
389
+ " one_hot[label] = 1.0\n",
390
+ " example[\"labels\"] = one_hot\n",
391
+ " return example"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": 41,
397
+ "id": "4b408bb0-97f0-42f4-b13b-35b0b3b01506",
398
+ "metadata": {},
399
+ "outputs": [
400
+ {
401
+ "data": {
402
+ "text/plain": [
403
+ "21"
404
+ ]
405
+ },
406
+ "execution_count": 41,
407
+ "metadata": {},
408
+ "output_type": "execute_result"
409
+ }
410
+ ],
411
+ "source": [
412
+ "len(tokenize_function(ds[\"train\"][2])['input_ids'])"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": 30,
418
+ "id": "0bfdfddd-6d73-45a1-89f5-221a75f28745",
419
+ "metadata": {},
420
+ "outputs": [
421
+ {
422
+ "data": {
423
+ "application/vnd.jupyter.widget-view+json": {
424
+ "model_id": "e8f6a19a7e9449eabd1d62a382d0db96",
425
+ "version_major": 2,
426
+ "version_minor": 0
427
+ },
428
+ "text/plain": [
429
+ "Map: 0%| | 0/43410 [00:00<?, ? examples/s]"
430
+ ]
431
+ },
432
+ "metadata": {},
433
+ "output_type": "display_data"
434
+ },
435
+ {
436
+ "data": {
437
+ "application/vnd.jupyter.widget-view+json": {
438
+ "model_id": "a3d27e8f742642c99e27787850ba4bf9",
439
+ "version_major": 2,
440
+ "version_minor": 0
441
+ },
442
+ "text/plain": [
443
+ "Map: 0%| | 0/5426 [00:00<?, ? examples/s]"
444
+ ]
445
+ },
446
+ "metadata": {},
447
+ "output_type": "display_data"
448
+ },
449
+ {
450
+ "data": {
451
+ "application/vnd.jupyter.widget-view+json": {
452
+ "model_id": "95ac2872c042459c992a935a82549f8d",
453
+ "version_major": 2,
454
+ "version_minor": 0
455
+ },
456
+ "text/plain": [
457
+ "Map: 0%| | 0/5427 [00:00<?, ? examples/s]"
458
+ ]
459
+ },
460
+ "metadata": {},
461
+ "output_type": "display_data"
462
+ },
463
+ {
464
+ "data": {
465
+ "application/vnd.jupyter.widget-view+json": {
466
+ "model_id": "f9f45cce90b140d09ac6f580d1e4e00e",
467
+ "version_major": 2,
468
+ "version_minor": 0
469
+ },
470
+ "text/plain": [
471
+ "Map: 0%| | 0/43410 [00:00<?, ? examples/s]"
472
+ ]
473
+ },
474
+ "metadata": {},
475
+ "output_type": "display_data"
476
+ },
477
+ {
478
+ "data": {
479
+ "application/vnd.jupyter.widget-view+json": {
480
+ "model_id": "740df2a579bb471493de1582891e6e51",
481
+ "version_major": 2,
482
+ "version_minor": 0
483
+ },
484
+ "text/plain": [
485
+ "Map: 0%| | 0/5426 [00:00<?, ? examples/s]"
486
+ ]
487
+ },
488
+ "metadata": {},
489
+ "output_type": "display_data"
490
+ },
491
+ {
492
+ "data": {
493
+ "application/vnd.jupyter.widget-view+json": {
494
+ "model_id": "4f1b63546df94e2e9635fbdcde980b5a",
495
+ "version_major": 2,
496
+ "version_minor": 0
497
+ },
498
+ "text/plain": [
499
+ "Map: 0%| | 0/5427 [00:00<?, ? examples/s]"
500
+ ]
501
+ },
502
+ "metadata": {},
503
+ "output_type": "display_data"
504
+ }
505
+ ],
506
+ "source": [
507
+ "tokenized_datasets = ds.map(tokenize_function)\n",
508
+ "converted_datasets = tokenized_datasets.map(one_hot_labels)"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": 31,
514
+ "id": "84d53da8-5364-4079-b9ee-110e7facef96",
515
+ "metadata": {},
516
+ "outputs": [
517
+ {
518
+ "data": {
519
+ "application/vnd.jupyter.widget-view+json": {
520
+ "model_id": "a3cbea591a984389b0d147a53d20b65f",
521
+ "version_major": 2,
522
+ "version_minor": 0
523
+ },
524
+ "text/plain": [
525
+ "Casting the dataset: 0%| | 0/43410 [00:00<?, ? examples/s]"
526
+ ]
527
+ },
528
+ "metadata": {},
529
+ "output_type": "display_data"
530
+ },
531
+ {
532
+ "data": {
533
+ "application/vnd.jupyter.widget-view+json": {
534
+ "model_id": "d3df1c1449ff4e08a56f18e664c20195",
535
+ "version_major": 2,
536
+ "version_minor": 0
537
+ },
538
+ "text/plain": [
539
+ "Casting the dataset: 0%| | 0/5426 [00:00<?, ? examples/s]"
540
+ ]
541
+ },
542
+ "metadata": {},
543
+ "output_type": "display_data"
544
+ },
545
+ {
546
+ "data": {
547
+ "application/vnd.jupyter.widget-view+json": {
548
+ "model_id": "43fa8619d62443da92be37fb8c8aada1",
549
+ "version_major": 2,
550
+ "version_minor": 0
551
+ },
552
+ "text/plain": [
553
+ "Casting the dataset: 0%| | 0/5427 [00:00<?, ? examples/s]"
554
+ ]
555
+ },
556
+ "metadata": {},
557
+ "output_type": "display_data"
558
+ }
559
+ ],
560
+ "source": [
561
+ "converted_datasets.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
562
+ "converted_datasets = converted_datasets.cast_column(\"labels\", datasets.features.Sequence(datasets.Value(\"float32\")))"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": 32,
568
+ "id": "64326646-2b8b-40f4-99be-9422b53018e4",
569
+ "metadata": {},
570
+ "outputs": [],
571
+ "source": [
572
+ "train_dataset = converted_datasets[\"train\"].shuffle(seed=42)\n",
573
+ "val_dataset = converted_datasets[\"validation\"].shuffle(seed=42)\n",
574
+ "test_dataset = converted_datasets[\"test\"].shuffle(seed=42)"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "code",
579
+ "execution_count": 33,
580
+ "id": "eeb7f1cd-6586-40a1-a629-1f2b760fe7d4",
581
+ "metadata": {},
582
+ "outputs": [
583
+ {
584
+ "data": {
585
+ "text/plain": [
586
+ "torch.float32"
587
+ ]
588
+ },
589
+ "execution_count": 33,
590
+ "metadata": {},
591
+ "output_type": "execute_result"
592
+ }
593
+ ],
594
+ "source": [
595
+ "train_dataset['labels'][0].dtype"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "code",
600
+ "execution_count": 40,
601
+ "id": "1e28ccc5-3fe6-41b3-b515-f92969660249",
602
+ "metadata": {},
603
+ "outputs": [
604
+ {
605
+ "data": {
606
+ "text/plain": [
607
+ "{'labels': tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
608
+ " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),\n",
609
+ " 'input_ids': tensor([ 101, 11601, 7363, 128, 1761, 18934, 842, 15991, 47993,\n",
610
+ " 860, 1703, 38969, 70261, 128, 63935, 128, 8542, 4725,\n",
611
+ " 106183, 40831, 28231, 845, 10843, 100820, 4346, 89470, 132,\n",
612
+ " 102]),\n",
613
+ " 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
614
+ " 1, 1, 1, 1])}"
615
+ ]
616
+ },
617
+ "execution_count": 40,
618
+ "metadata": {},
619
+ "output_type": "execute_result"
620
+ }
621
+ ],
622
+ "source": [
623
+ "train_dataset[4]"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "markdown",
628
+ "id": "36f67b80-61bc-4147-b817-ce81296d2ecf",
629
+ "metadata": {},
630
+ "source": [
631
+ "## Training"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": 34,
637
+ "id": "b3e4c392-b725-4647-81bf-bc9f14f5c814",
638
+ "metadata": {},
639
+ "outputs": [],
640
+ "source": [
641
+ "def compute_metrics(eval_pred):\n",
642
+ " logits, labels = eval_pred\n",
643
+ " probs = 1 / (1 + np.exp(-logits)) # Sigmoid\n",
644
+ " preds = (probs > 0.5).astype(int) # Превращаем в 0/1\n",
645
+ "\n",
646
+ " return {\n",
647
+ " \"f1_micro\": f1_score(labels, preds, average=\"micro\"),\n",
648
+ " \"f1_macro\": f1_score(labels, preds, average=\"macro\"),\n",
649
+ " \"precision\": precision_score(labels, preds, average=\"micro\"),\n",
650
+ " \"recall\": recall_score(labels, preds, average=\"micro\"),\n",
651
+ " \"accuracy\": accuracy_score(labels, preds) # Доля примеров, у которых полный набор меток совпал (subset accuracy)\n",
652
+ " }"
653
+ ]
654
+ },
655
+ {
656
+ "cell_type": "code",
657
+ "execution_count": 35,
658
+ "id": "9ff1a4e7-a3df-4c4f-bba8-8961cfa1a144",
659
+ "metadata": {},
660
+ "outputs": [],
661
+ "source": [
662
+ "training_args = TrainingArguments(\n",
663
+ " output_dir=f\"./rubert\",\n",
664
+ " overwrite_output_dir=True,\n",
665
+ " num_train_epochs=10,\n",
666
+ " learning_rate=1e-5,\n",
667
+ " lr_scheduler_type=\"cosine\",\n",
668
+ " # lr_scheduler_kwargs={},\n",
669
+ " warmup_ratio=0.05,\n",
670
+ " # warmup_steps=10,\n",
671
+ " per_device_train_batch_size=16,\n",
672
+ " gradient_accumulation_steps=1,\n",
673
+ " log_level=\"error\",\n",
674
+ " # logging_dir=\"output_dir/runs/CURRENT_DATETIME_HOSTNAME\" # логи для tensorboard (default)\n",
675
+ " logging_strategy=\"steps\",\n",
676
+ " logging_steps=1,\n",
677
+ " save_strategy=\"epoch\",\n",
678
+ " # save_steps=1,\n",
679
+ " save_total_limit=2,\n",
680
+ " save_safetensors=True, # safetensors вместо torch.save / torch.load\n",
681
+ " save_only_model=False, # сохраняем optimizer, scheduler, rng, ...\n",
682
+ " use_cpu=False,\n",
683
+ " seed=42,\n",
684
+ " # bf16=True, # использовать bf16 вместо fp32\n",
685
+ " eval_strategy=\"epoch\",\n",
686
+ " # eval_steps=32,\n",
687
+ " disable_tqdm=False,\n",
688
+ " load_best_model_at_end=False,\n",
689
+ " # label_smoothing_factor=0.,\n",
690
+ " optim=\"adamw_torch\",\n",
691
+ " # optim_args=...,\n",
692
+ " # resume_from_checkpoint=...,\n",
693
+ " # auto_find_batch_size=...,\n",
694
+ ")"
695
+ ]
696
+ },
697
+ {
698
+ "cell_type": "code",
699
+ "execution_count": 36,
700
+ "id": "7c33f1b2-cdfa-4295-814f-261f066633c8",
701
+ "metadata": {},
702
+ "outputs": [],
703
+ "source": [
704
+ "import gc\n",
705
+ "\n",
706
+ "gc.collect()\n",
707
+ "torch.cuda.empty_cache()"
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "code",
712
+ "execution_count": 42,
713
+ "id": "a2ac0381-be84-4dd4-94e7-53cfd3bddc63",
714
+ "metadata": {},
715
+ "outputs": [
716
+ {
717
+ "data": {
718
+ "text/html": [
719
+ "\n",
720
+ " <div>\n",
721
+ " \n",
722
+ " <progress value='27140' max='27140' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
723
+ " [27140/27140 25:37, Epoch 10/10]\n",
724
+ " </div>\n",
725
+ " <table border=\"1\" class=\"dataframe\">\n",
726
+ " <thead>\n",
727
+ " <tr style=\"text-align: left;\">\n",
728
+ " <th>Epoch</th>\n",
729
+ " <th>Training Loss</th>\n",
730
+ " <th>Validation Loss</th>\n",
731
+ " <th>F1 Micro</th>\n",
732
+ " <th>F1 Macro</th>\n",
733
+ " <th>Precision</th>\n",
734
+ " <th>Recall</th>\n",
735
+ " <th>Accuracy</th>\n",
736
+ " </tr>\n",
737
+ " </thead>\n",
738
+ " <tbody>\n",
739
+ " <tr>\n",
740
+ " <td>1</td>\n",
741
+ " <td>0.167400</td>\n",
742
+ " <td>0.117092</td>\n",
743
+ " <td>0.365053</td>\n",
744
+ " <td>0.112839</td>\n",
745
+ " <td>0.775458</td>\n",
746
+ " <td>0.238715</td>\n",
747
+ " <td>0.239219</td>\n",
748
+ " </tr>\n",
749
+ " <tr>\n",
750
+ " <td>2</td>\n",
751
+ " <td>0.124500</td>\n",
752
+ " <td>0.098455</td>\n",
753
+ " <td>0.487169</td>\n",
754
+ " <td>0.199681</td>\n",
755
+ " <td>0.705830</td>\n",
756
+ " <td>0.371944</td>\n",
757
+ " <td>0.365647</td>\n",
758
+ " </tr>\n",
759
+ " <tr>\n",
760
+ " <td>3</td>\n",
761
+ " <td>0.069300</td>\n",
762
+ " <td>0.094669</td>\n",
763
+ " <td>0.524119</td>\n",
764
+ " <td>0.308011</td>\n",
765
+ " <td>0.688249</td>\n",
766
+ " <td>0.423197</td>\n",
767
+ " <td>0.404165</td>\n",
768
+ " </tr>\n",
769
+ " <tr>\n",
770
+ " <td>4</td>\n",
771
+ " <td>0.101400</td>\n",
772
+ " <td>0.094210</td>\n",
773
+ " <td>0.524894</td>\n",
774
+ " <td>0.329945</td>\n",
775
+ " <td>0.682731</td>\n",
776
+ " <td>0.426332</td>\n",
777
+ " <td>0.405824</td>\n",
778
+ " </tr>\n",
779
+ " <tr>\n",
780
+ " <td>5</td>\n",
781
+ " <td>0.115100</td>\n",
782
+ " <td>0.097984</td>\n",
783
+ " <td>0.534122</td>\n",
784
+ " <td>0.351584</td>\n",
785
+ " <td>0.636659</td>\n",
786
+ " <td>0.460031</td>\n",
787
+ " <td>0.429414</td>\n",
788
+ " </tr>\n",
789
+ " <tr>\n",
790
+ " <td>6</td>\n",
791
+ " <td>0.030200</td>\n",
792
+ " <td>0.101337</td>\n",
793
+ " <td>0.527109</td>\n",
794
+ " <td>0.364458</td>\n",
795
+ " <td>0.626647</td>\n",
796
+ " <td>0.454859</td>\n",
797
+ " <td>0.423701</td>\n",
798
+ " </tr>\n",
799
+ " <tr>\n",
800
+ " <td>7</td>\n",
801
+ " <td>0.052100</td>\n",
802
+ " <td>0.103811</td>\n",
803
+ " <td>0.527860</td>\n",
804
+ " <td>0.365408</td>\n",
805
+ " <td>0.614664</td>\n",
806
+ " <td>0.462539</td>\n",
807
+ " <td>0.427571</td>\n",
808
+ " </tr>\n",
809
+ " <tr>\n",
810
+ " <td>8</td>\n",
811
+ " <td>0.009300</td>\n",
812
+ " <td>0.105722</td>\n",
813
+ " <td>0.530681</td>\n",
814
+ " <td>0.371352</td>\n",
815
+ " <td>0.608722</td>\n",
816
+ " <td>0.470376</td>\n",
817
+ " <td>0.431810</td>\n",
818
+ " </tr>\n",
819
+ " <tr>\n",
820
+ " <td>9</td>\n",
821
+ " <td>0.008400</td>\n",
822
+ " <td>0.107027</td>\n",
823
+ " <td>0.531044</td>\n",
824
+ " <td>0.374502</td>\n",
825
+ " <td>0.606030</td>\n",
826
+ " <td>0.472571</td>\n",
827
+ " <td>0.432731</td>\n",
828
+ " </tr>\n",
829
+ " <tr>\n",
830
+ " <td>10</td>\n",
831
+ " <td>0.040200</td>\n",
832
+ " <td>0.107173</td>\n",
833
+ " <td>0.530246</td>\n",
834
+ " <td>0.375274</td>\n",
835
+ " <td>0.604983</td>\n",
836
+ " <td>0.471944</td>\n",
837
+ " <td>0.432916</td>\n",
838
+ " </tr>\n",
839
+ " </tbody>\n",
840
+ "</table><p>"
841
+ ],
842
+ "text/plain": [
843
+ "<IPython.core.display.HTML object>"
844
+ ]
845
+ },
846
+ "metadata": {},
847
+ "output_type": "display_data"
848
+ },
849
+ {
850
+ "data": {
851
+ "text/plain": [
852
+ "TrainOutput(global_step=27140, training_loss=0.08366055827060757, metrics={'train_runtime': 1538.1289, 'train_samples_per_second': 282.226, 'train_steps_per_second': 17.645, 'total_flos': 8421320854320816.0, 'train_loss': 0.08366055827060757, 'epoch': 10.0})"
853
+ ]
854
+ },
855
+ "execution_count": 42,
856
+ "metadata": {},
857
+ "output_type": "execute_result"
858
+ }
859
+ ],
860
+ "source": [
861
+ "from transformers import DataCollatorWithPadding\n",
862
+ "\n",
863
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
864
+ "\n",
865
+ "trainer = Trainer(\n",
866
+ " model=model,\n",
867
+ " args=training_args,\n",
868
+ " train_dataset=train_dataset,\n",
869
+ " eval_dataset=val_dataset,\n",
870
+ " data_collator=data_collator,\n",
871
+ " compute_metrics=compute_metrics,\n",
872
+ ")\n",
873
+ "\n",
874
+ "trainer.train()"
875
+ ]
876
+ },
877
+ {
878
+ "cell_type": "code",
879
+ "execution_count": null,
880
+ "id": "ff46bdbd-e945-4abb-a39d-dca292b9856b",
881
+ "metadata": {},
882
+ "outputs": [],
883
+ "source": []
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": 48,
888
+ "id": "eaceeaef-7286-48f3-9f79-b35d7d41da23",
889
+ "metadata": {},
890
+ "outputs": [],
891
+ "source": [
892
+ "tokenizer = AutoTokenizer.from_pretrained('emotions/my_model')\n",
893
+ "model = AutoModelForSequenceClassification.from_pretrained('emotions/my_model', num_labels=num_classes, problem_type=\"multi_label_classification\")"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": 50,
899
+ "id": "c3c93f39-b481-4e4a-b1da-dd50f1e94742",
900
+ "metadata": {},
901
+ "outputs": [
902
+ {
903
+ "data": {
904
+ "application/vnd.jupyter.widget-view+json": {
905
+ "model_id": "0a8fc03ea3144dc5b2093dcbc5953a57",
906
+ "version_major": 2,
907
+ "version_minor": 0
908
+ },
909
+ "text/plain": [
910
+ "model.safetensors: 0%| | 0.00/712M [00:00<?, ?B/s]"
911
+ ]
912
+ },
913
+ "metadata": {},
914
+ "output_type": "display_data"
915
+ },
916
+ {
917
+ "data": {
918
+ "application/vnd.jupyter.widget-view+json": {
919
+ "model_id": "b4f26250408b4b4ea332bee596aa14de",
920
+ "version_major": 2,
921
+ "version_minor": 0
922
+ },
923
+ "text/plain": [
924
+ "README.md: 0%| | 0.00/5.17k [00:00<?, ?B/s]"
925
+ ]
926
+ },
927
+ "metadata": {},
928
+ "output_type": "display_data"
929
+ },
930
+ {
931
+ "data": {
932
+ "text/plain": [
933
+ "CommitInfo(commit_url='https://huggingface.co/alxvlsv/rubert-emotions/commit/d958f2338ac01e6fe177f0186124322d6d18114a', commit_message='Upload tokenizer', commit_description='', oid='d958f2338ac01e6fe177f0186124322d6d18114a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/alxvlsv/rubert-emotions', endpoint='https://huggingface.co', repo_type='model', repo_id='alxvlsv/rubert-emotions'), pr_revision=None, pr_num=None)"
934
+ ]
935
+ },
936
+ "execution_count": 50,
937
+ "metadata": {},
938
+ "output_type": "execute_result"
939
+ }
940
+ ],
941
+ "source": [
942
+ "model.push_to_hub(\"alxvlsv/rubert-emotions\")\n",
943
+ "tokenizer.push_to_hub(\"alxvlsv/rubert-emotions\")"
944
+ ]
945
+ },
946
+ {
947
+ "cell_type": "code",
948
+ "execution_count": null,
949
+ "id": "ebb82e42-50c5-4338-ac57-bbffa85c25b1",
950
+ "metadata": {},
951
+ "outputs": [],
952
+ "source": []
953
+ }
954
+ ],
955
+ "metadata": {
956
+ "kernelspec": {
957
+ "display_name": "Python (shad)",
958
+ "language": "python",
959
+ "name": "shad"
960
+ },
961
+ "language_info": {
962
+ "codemirror_mode": {
963
+ "name": "ipython",
964
+ "version": 3
965
+ },
966
+ "file_extension": ".py",
967
+ "mimetype": "text/x-python",
968
+ "name": "python",
969
+ "nbconvert_exporter": "python",
970
+ "pygments_lexer": "ipython3",
971
+ "version": "3.11.11"
972
+ }
973
+ },
974
+ "nbformat": 4,
975
+ "nbformat_minor": 5
976
+ }