ruba2ksa commited on
Commit
beddee5
·
verified ·
1 Parent(s): 25e904c

Upload RoBERTa_Fine_Tuning_Emotion_classification.ipynb

Browse files
RoBERTa_Fine_Tuning_Emotion_classification.ipynb ADDED
@@ -0,0 +1,1612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "accelerator": "GPU",
13
+ "widgets": {
14
+ "application/vnd.jupyter.widget-state+json": {
15
+ "f848095d186b49e08417c293b642faed": {
16
+ "model_module": "@jupyter-widgets/controls",
17
+ "model_name": "HBoxModel",
18
+ "model_module_version": "1.5.0",
19
+ "state": {
20
+ "_dom_classes": [],
21
+ "_model_module": "@jupyter-widgets/controls",
22
+ "_model_module_version": "1.5.0",
23
+ "_model_name": "HBoxModel",
24
+ "_view_count": null,
25
+ "_view_module": "@jupyter-widgets/controls",
26
+ "_view_module_version": "1.5.0",
27
+ "_view_name": "HBoxView",
28
+ "box_style": "",
29
+ "children": [
30
+ "IPY_MODEL_4f6eac487752459b82e2a5ea7d5902c8",
31
+ "IPY_MODEL_f3a2348c535a47878bca775a1f5d50d5",
32
+ "IPY_MODEL_800b720695984617856bd1b4ec7a180c"
33
+ ],
34
+ "layout": "IPY_MODEL_a7911b3fad6a4db9b891e406745bcc19"
35
+ }
36
+ },
37
+ "4f6eac487752459b82e2a5ea7d5902c8": {
38
+ "model_module": "@jupyter-widgets/controls",
39
+ "model_name": "HTMLModel",
40
+ "model_module_version": "1.5.0",
41
+ "state": {
42
+ "_dom_classes": [],
43
+ "_model_module": "@jupyter-widgets/controls",
44
+ "_model_module_version": "1.5.0",
45
+ "_model_name": "HTMLModel",
46
+ "_view_count": null,
47
+ "_view_module": "@jupyter-widgets/controls",
48
+ "_view_module_version": "1.5.0",
49
+ "_view_name": "HTMLView",
50
+ "description": "",
51
+ "description_tooltip": null,
52
+ "layout": "IPY_MODEL_c03b91347c1d43dc81d1c277c9b0ac0a",
53
+ "placeholder": "​",
54
+ "style": "IPY_MODEL_2df14874354f4483a63532dae109082e",
55
+ "value": "model.safetensors: 100%"
56
+ }
57
+ },
58
+ "f3a2348c535a47878bca775a1f5d50d5": {
59
+ "model_module": "@jupyter-widgets/controls",
60
+ "model_name": "FloatProgressModel",
61
+ "model_module_version": "1.5.0",
62
+ "state": {
63
+ "_dom_classes": [],
64
+ "_model_module": "@jupyter-widgets/controls",
65
+ "_model_module_version": "1.5.0",
66
+ "_model_name": "FloatProgressModel",
67
+ "_view_count": null,
68
+ "_view_module": "@jupyter-widgets/controls",
69
+ "_view_module_version": "1.5.0",
70
+ "_view_name": "ProgressView",
71
+ "bar_style": "success",
72
+ "description": "",
73
+ "description_tooltip": null,
74
+ "layout": "IPY_MODEL_f2e2a2f73c724d77bfd0dd01c574d192",
75
+ "max": 331055963,
76
+ "min": 0,
77
+ "orientation": "horizontal",
78
+ "style": "IPY_MODEL_b9ac9418f0474c33a1f40e0e86a8fe74",
79
+ "value": 331055963
80
+ }
81
+ },
82
+ "800b720695984617856bd1b4ec7a180c": {
83
+ "model_module": "@jupyter-widgets/controls",
84
+ "model_name": "HTMLModel",
85
+ "model_module_version": "1.5.0",
86
+ "state": {
87
+ "_dom_classes": [],
88
+ "_model_module": "@jupyter-widgets/controls",
89
+ "_model_module_version": "1.5.0",
90
+ "_model_name": "HTMLModel",
91
+ "_view_count": null,
92
+ "_view_module": "@jupyter-widgets/controls",
93
+ "_view_module_version": "1.5.0",
94
+ "_view_name": "HTMLView",
95
+ "description": "",
96
+ "description_tooltip": null,
97
+ "layout": "IPY_MODEL_d44a2fd4cf724ef6a67662e69a626eee",
98
+ "placeholder": "​",
99
+ "style": "IPY_MODEL_131a2ae47ff14ad38cc60f7434c76bfd",
100
+ "value": " 331M/331M [00:01<00:00, 228MB/s]"
101
+ }
102
+ },
103
+ "a7911b3fad6a4db9b891e406745bcc19": {
104
+ "model_module": "@jupyter-widgets/base",
105
+ "model_name": "LayoutModel",
106
+ "model_module_version": "1.2.0",
107
+ "state": {
108
+ "_model_module": "@jupyter-widgets/base",
109
+ "_model_module_version": "1.2.0",
110
+ "_model_name": "LayoutModel",
111
+ "_view_count": null,
112
+ "_view_module": "@jupyter-widgets/base",
113
+ "_view_module_version": "1.2.0",
114
+ "_view_name": "LayoutView",
115
+ "align_content": null,
116
+ "align_items": null,
117
+ "align_self": null,
118
+ "border": null,
119
+ "bottom": null,
120
+ "display": null,
121
+ "flex": null,
122
+ "flex_flow": null,
123
+ "grid_area": null,
124
+ "grid_auto_columns": null,
125
+ "grid_auto_flow": null,
126
+ "grid_auto_rows": null,
127
+ "grid_column": null,
128
+ "grid_gap": null,
129
+ "grid_row": null,
130
+ "grid_template_areas": null,
131
+ "grid_template_columns": null,
132
+ "grid_template_rows": null,
133
+ "height": null,
134
+ "justify_content": null,
135
+ "justify_items": null,
136
+ "left": null,
137
+ "margin": null,
138
+ "max_height": null,
139
+ "max_width": null,
140
+ "min_height": null,
141
+ "min_width": null,
142
+ "object_fit": null,
143
+ "object_position": null,
144
+ "order": null,
145
+ "overflow": null,
146
+ "overflow_x": null,
147
+ "overflow_y": null,
148
+ "padding": null,
149
+ "right": null,
150
+ "top": null,
151
+ "visibility": null,
152
+ "width": null
153
+ }
154
+ },
155
+ "c03b91347c1d43dc81d1c277c9b0ac0a": {
156
+ "model_module": "@jupyter-widgets/base",
157
+ "model_name": "LayoutModel",
158
+ "model_module_version": "1.2.0",
159
+ "state": {
160
+ "_model_module": "@jupyter-widgets/base",
161
+ "_model_module_version": "1.2.0",
162
+ "_model_name": "LayoutModel",
163
+ "_view_count": null,
164
+ "_view_module": "@jupyter-widgets/base",
165
+ "_view_module_version": "1.2.0",
166
+ "_view_name": "LayoutView",
167
+ "align_content": null,
168
+ "align_items": null,
169
+ "align_self": null,
170
+ "border": null,
171
+ "bottom": null,
172
+ "display": null,
173
+ "flex": null,
174
+ "flex_flow": null,
175
+ "grid_area": null,
176
+ "grid_auto_columns": null,
177
+ "grid_auto_flow": null,
178
+ "grid_auto_rows": null,
179
+ "grid_column": null,
180
+ "grid_gap": null,
181
+ "grid_row": null,
182
+ "grid_template_areas": null,
183
+ "grid_template_columns": null,
184
+ "grid_template_rows": null,
185
+ "height": null,
186
+ "justify_content": null,
187
+ "justify_items": null,
188
+ "left": null,
189
+ "margin": null,
190
+ "max_height": null,
191
+ "max_width": null,
192
+ "min_height": null,
193
+ "min_width": null,
194
+ "object_fit": null,
195
+ "object_position": null,
196
+ "order": null,
197
+ "overflow": null,
198
+ "overflow_x": null,
199
+ "overflow_y": null,
200
+ "padding": null,
201
+ "right": null,
202
+ "top": null,
203
+ "visibility": null,
204
+ "width": null
205
+ }
206
+ },
207
+ "2df14874354f4483a63532dae109082e": {
208
+ "model_module": "@jupyter-widgets/controls",
209
+ "model_name": "DescriptionStyleModel",
210
+ "model_module_version": "1.5.0",
211
+ "state": {
212
+ "_model_module": "@jupyter-widgets/controls",
213
+ "_model_module_version": "1.5.0",
214
+ "_model_name": "DescriptionStyleModel",
215
+ "_view_count": null,
216
+ "_view_module": "@jupyter-widgets/base",
217
+ "_view_module_version": "1.2.0",
218
+ "_view_name": "StyleView",
219
+ "description_width": ""
220
+ }
221
+ },
222
+ "f2e2a2f73c724d77bfd0dd01c574d192": {
223
+ "model_module": "@jupyter-widgets/base",
224
+ "model_name": "LayoutModel",
225
+ "model_module_version": "1.2.0",
226
+ "state": {
227
+ "_model_module": "@jupyter-widgets/base",
228
+ "_model_module_version": "1.2.0",
229
+ "_model_name": "LayoutModel",
230
+ "_view_count": null,
231
+ "_view_module": "@jupyter-widgets/base",
232
+ "_view_module_version": "1.2.0",
233
+ "_view_name": "LayoutView",
234
+ "align_content": null,
235
+ "align_items": null,
236
+ "align_self": null,
237
+ "border": null,
238
+ "bottom": null,
239
+ "display": null,
240
+ "flex": null,
241
+ "flex_flow": null,
242
+ "grid_area": null,
243
+ "grid_auto_columns": null,
244
+ "grid_auto_flow": null,
245
+ "grid_auto_rows": null,
246
+ "grid_column": null,
247
+ "grid_gap": null,
248
+ "grid_row": null,
249
+ "grid_template_areas": null,
250
+ "grid_template_columns": null,
251
+ "grid_template_rows": null,
252
+ "height": null,
253
+ "justify_content": null,
254
+ "justify_items": null,
255
+ "left": null,
256
+ "margin": null,
257
+ "max_height": null,
258
+ "max_width": null,
259
+ "min_height": null,
260
+ "min_width": null,
261
+ "object_fit": null,
262
+ "object_position": null,
263
+ "order": null,
264
+ "overflow": null,
265
+ "overflow_x": null,
266
+ "overflow_y": null,
267
+ "padding": null,
268
+ "right": null,
269
+ "top": null,
270
+ "visibility": null,
271
+ "width": null
272
+ }
273
+ },
274
+ "b9ac9418f0474c33a1f40e0e86a8fe74": {
275
+ "model_module": "@jupyter-widgets/controls",
276
+ "model_name": "ProgressStyleModel",
277
+ "model_module_version": "1.5.0",
278
+ "state": {
279
+ "_model_module": "@jupyter-widgets/controls",
280
+ "_model_module_version": "1.5.0",
281
+ "_model_name": "ProgressStyleModel",
282
+ "_view_count": null,
283
+ "_view_module": "@jupyter-widgets/base",
284
+ "_view_module_version": "1.2.0",
285
+ "_view_name": "StyleView",
286
+ "bar_color": null,
287
+ "description_width": ""
288
+ }
289
+ },
290
+ "d44a2fd4cf724ef6a67662e69a626eee": {
291
+ "model_module": "@jupyter-widgets/base",
292
+ "model_name": "LayoutModel",
293
+ "model_module_version": "1.2.0",
294
+ "state": {
295
+ "_model_module": "@jupyter-widgets/base",
296
+ "_model_module_version": "1.2.0",
297
+ "_model_name": "LayoutModel",
298
+ "_view_count": null,
299
+ "_view_module": "@jupyter-widgets/base",
300
+ "_view_module_version": "1.2.0",
301
+ "_view_name": "LayoutView",
302
+ "align_content": null,
303
+ "align_items": null,
304
+ "align_self": null,
305
+ "border": null,
306
+ "bottom": null,
307
+ "display": null,
308
+ "flex": null,
309
+ "flex_flow": null,
310
+ "grid_area": null,
311
+ "grid_auto_columns": null,
312
+ "grid_auto_flow": null,
313
+ "grid_auto_rows": null,
314
+ "grid_column": null,
315
+ "grid_gap": null,
316
+ "grid_row": null,
317
+ "grid_template_areas": null,
318
+ "grid_template_columns": null,
319
+ "grid_template_rows": null,
320
+ "height": null,
321
+ "justify_content": null,
322
+ "justify_items": null,
323
+ "left": null,
324
+ "margin": null,
325
+ "max_height": null,
326
+ "max_width": null,
327
+ "min_height": null,
328
+ "min_width": null,
329
+ "object_fit": null,
330
+ "object_position": null,
331
+ "order": null,
332
+ "overflow": null,
333
+ "overflow_x": null,
334
+ "overflow_y": null,
335
+ "padding": null,
336
+ "right": null,
337
+ "top": null,
338
+ "visibility": null,
339
+ "width": null
340
+ }
341
+ },
342
+ "131a2ae47ff14ad38cc60f7434c76bfd": {
343
+ "model_module": "@jupyter-widgets/controls",
344
+ "model_name": "DescriptionStyleModel",
345
+ "model_module_version": "1.5.0",
346
+ "state": {
347
+ "_model_module": "@jupyter-widgets/controls",
348
+ "_model_module_version": "1.5.0",
349
+ "_model_name": "DescriptionStyleModel",
350
+ "_view_count": null,
351
+ "_view_module": "@jupyter-widgets/base",
352
+ "_view_module_version": "1.2.0",
353
+ "_view_name": "StyleView",
354
+ "description_width": ""
355
+ }
356
+ }
357
+ }
358
+ }
359
+ },
360
+ "cells": [
361
+ {
362
+ "cell_type": "markdown",
363
+ "metadata": {
364
+ "id": "Wj6eoKzotv5I"
365
+ },
366
+ "source": [
367
+ "## Emotion Classification using Fine-tuned BERT model\n",
368
+ "\n",
369
+ "In this tutorial, I will show to fine-tune a language model (LM) for emotion classification with code adapted from this [tutorial](https://zablo.net/blog/post/custom-classifier-on-bert-model-guide-polemo2-sentiment-analysis/) by MARCIN ZABŁOCKI. I adapted his tutorial and modified the code to suit the emotion classification task using a different BERT model. Please refer to his tutorial for more detailed explanations for each code block. I really liked his tutorial because of the attention to detail and the use of high-level libraries to take care of certain parts of the model such as training and finding a good learning rate.\n",
370
+ "\n",
371
+ "Before you get started, make sure to enable `GPU` in the runtime and be sure to\n",
372
+ "restart the runtime in this environment after installing the `pytorch-lr-finder` library.\n",
373
+ "\n",
374
+ "This tutorial is in a rough draft so if you find any issues with this tutorial or have any further questions reach out to me via [Twitter](https://twitter.com/omarsar0).\n",
375
+ "\n",
376
+ "Note that the notebook was created a little while back so if something break it's because the code is not compatible with the library changes.\n"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "metadata": {
382
+ "id": "G2tokZqttmTA"
383
+ },
384
+ "source": [
385
+ "%%capture\n",
386
+ "!pip install transformers tokenizers pytorch-lightning"
387
+ ],
388
+ "execution_count": 10,
389
+ "outputs": []
390
+ },
391
+ {
392
+ "cell_type": "markdown",
393
+ "source": [
394
+ "Note: you need to Restart runtime after running this code segment"
395
+ ],
396
+ "metadata": {
397
+ "id": "I0jZnNegGhZj"
398
+ }
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "metadata": {
403
+ "id": "k9ZKIIGvuW5m"
404
+ },
405
+ "source": [
406
+ "%%capture\n",
407
+ "!git clone https://github.com/davidtvs/pytorch-lr-finder.git && cd pytorch-lr-finder && python setup.py install"
408
+ ],
409
+ "execution_count": 11,
410
+ "outputs": []
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "metadata": {
415
+ "id": "qqRRWe4UuuIh",
416
+ "outputId": "a12be031-4bc9-404e-e741-9d4710b57683",
417
+ "colab": {
418
+ "base_uri": "https://localhost:8080/",
419
+ "height": 35
420
+ }
421
+ },
422
+ "source": [
423
+ "import torch\n",
424
+ "from torch import nn\n",
425
+ "from typing import List\n",
426
+ "import torch.nn.functional as F\n",
427
+ "from transformers import DistilBertTokenizer, AutoTokenizer, AutoModelWithLMHead, DistilBertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n",
428
+ "import logging\n",
429
+ "import os\n",
430
+ "from functools import lru_cache\n",
431
+ "from tokenizers import ByteLevelBPETokenizer\n",
432
+ "from tokenizers.processors import BertProcessing\n",
433
+ "import pytorch_lightning as pl\n",
434
+ "from torch.utils.data import DataLoader, Dataset\n",
435
+ "import pandas as pd\n",
436
+ "from argparse import Namespace\n",
437
+ "from sklearn.metrics import classification_report\n",
438
+ "torch.__version__"
439
+ ],
440
+ "execution_count": 12,
441
+ "outputs": [
442
+ {
443
+ "output_type": "execute_result",
444
+ "data": {
445
+ "text/plain": [
446
+ "'2.2.1+cu121'"
447
+ ],
448
+ "application/vnd.google.colaboratory.intrinsic+json": {
449
+ "type": "string"
450
+ }
451
+ },
452
+ "metadata": {},
453
+ "execution_count": 12
454
+ }
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "markdown",
459
+ "metadata": {
460
+ "id": "_whSBDujRiga"
461
+ },
462
+ "source": [
463
+ "## Load the Pretrained Language Model\n",
464
+ "We are first going to look at pretrained language model provided by HuggingFace models. We will use a variant of BERT, called DistilRoBERTa base. The `base` model has less parameters than the `larger` model.\n",
465
+ "\n",
466
+ "[RoBERTa](https://arxiv.org/abs/1907.11692) is a variant of of BERT which \"*modifies key hyperparameters, removing the next-sentence pretraining objective and training with much larger mini-batches and learning rates*\".\n",
467
+ "\n",
468
+ "Knowledge distillation help to train smaller LMs with similar performance and potential."
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "markdown",
473
+ "metadata": {
474
+ "id": "BvHNcMckSR4M"
475
+ },
476
+ "source": [
477
+ "First, let's load the tokenizer for this model:"
478
+ ]
479
+ },
480
+ {
481
+ "cell_type": "code",
482
+ "metadata": {
483
+ "id": "BPbTd5lmuzQn"
484
+ },
485
+ "source": [
486
+ "tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')"
487
+ ],
488
+ "execution_count": 13,
489
+ "outputs": []
490
+ },
491
+ {
492
+ "cell_type": "markdown",
493
+ "metadata": {
494
+ "id": "7KAbKMqJSWRo"
495
+ },
496
+ "source": [
497
+ "Now let's load the actual model with the LM head that takes care of the prediciton for the LM. When fine-tuning we don't use the head and instead use the base model. The code below shows how to do this:"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "metadata": {
503
+ "id": "PCXYlMydzQlP",
504
+ "outputId": "2845314c-bfcb-47a5-9e83-fea79a4c4409",
505
+ "colab": {
506
+ "base_uri": "https://localhost:8080/",
507
+ "height": 158,
508
+ "referenced_widgets": [
509
+ "f848095d186b49e08417c293b642faed",
510
+ "4f6eac487752459b82e2a5ea7d5902c8",
511
+ "f3a2348c535a47878bca775a1f5d50d5",
512
+ "800b720695984617856bd1b4ec7a180c",
513
+ "a7911b3fad6a4db9b891e406745bcc19",
514
+ "c03b91347c1d43dc81d1c277c9b0ac0a",
515
+ "2df14874354f4483a63532dae109082e",
516
+ "f2e2a2f73c724d77bfd0dd01c574d192",
517
+ "b9ac9418f0474c33a1f40e0e86a8fe74",
518
+ "d44a2fd4cf724ef6a67662e69a626eee",
519
+ "131a2ae47ff14ad38cc60f7434c76bfd"
520
+ ]
521
+ }
522
+ },
523
+ "source": [
524
+ "model = AutoModelWithLMHead.from_pretrained(\"distilroberta-base\")\n",
525
+ "base_model = model.base_model"
526
+ ],
527
+ "execution_count": 14,
528
+ "outputs": [
529
+ {
530
+ "output_type": "stream",
531
+ "name": "stderr",
532
+ "text": [
533
+ "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/modeling_auto.py:1595: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
534
+ " warnings.warn(\n"
535
+ ]
536
+ },
537
+ {
538
+ "output_type": "display_data",
539
+ "data": {
540
+ "text/plain": [
541
+ "model.safetensors: 0%| | 0.00/331M [00:00<?, ?B/s]"
542
+ ],
543
+ "application/vnd.jupyter.widget-view+json": {
544
+ "version_major": 2,
545
+ "version_minor": 0,
546
+ "model_id": "f848095d186b49e08417c293b642faed"
547
+ }
548
+ },
549
+ "metadata": {}
550
+ },
551
+ {
552
+ "output_type": "stream",
553
+ "name": "stderr",
554
+ "text": [
555
+ "Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
556
+ "- This IS expected if you are initializing RobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
557
+ "- This IS NOT expected if you are initializing RobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
558
+ ]
559
+ }
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "markdown",
564
+ "metadata": {
565
+ "id": "K2_8S8BXSpNa"
566
+ },
567
+ "source": [
568
+ "Let's now try out the tokenizer first:"
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "metadata": {
574
+ "id": "5fidSmH-zrY_",
575
+ "outputId": "b396329f-341c-40c5-9294-7e4019f7adf7",
576
+ "colab": {
577
+ "base_uri": "https://localhost:8080/"
578
+ }
579
+ },
580
+ "source": [
581
+ "text = \"Elvis is the king of rock!\"\n",
582
+ "enc = tokenizer.encode_plus(text)\n",
583
+ "enc.keys()"
584
+ ],
585
+ "execution_count": 15,
586
+ "outputs": [
587
+ {
588
+ "output_type": "execute_result",
589
+ "data": {
590
+ "text/plain": [
591
+ "dict_keys(['input_ids', 'attention_mask'])"
592
+ ]
593
+ },
594
+ "metadata": {},
595
+ "execution_count": 15
596
+ }
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "metadata": {
602
+ "id": "m8F8yQCDTDQi",
603
+ "outputId": "cc768922-4463-472d-bbfd-fda843517f48",
604
+ "colab": {
605
+ "base_uri": "https://localhost:8080/"
606
+ }
607
+ },
608
+ "source": [
609
+ "print(enc)"
610
+ ],
611
+ "execution_count": 16,
612
+ "outputs": [
613
+ {
614
+ "output_type": "stream",
615
+ "name": "stdout",
616
+ "text": [
617
+ "{'input_ids': [0, 9682, 9578, 16, 5, 8453, 9, 3152, 328, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
618
+ ]
619
+ }
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "markdown",
624
+ "metadata": {
625
+ "id": "P3wSCLKW0ndh"
626
+ },
627
+ "source": [
628
+ "`input_ids` are the numerical encoding of the tokens in the vocabulary. `attention_mask` is an addition option used when batching sequences together and you want to tell the model which tokens should be attented to ([read more](https://huggingface.co/transformers/glossary.html#attention-mask)). The attention mask information helps when dealing with variance in the size of sequences and we need a way to tell the model that we don't want to attend to the padded indices of the sequence.\n",
629
+ "\n",
630
+ "We are only using `input_ids` and `attention_mask`\n",
631
+ "\n",
632
+ "We need to also unsqueeze to simulate batch processing\n",
633
+ "\n",
634
+ "Using DistilBertForSequenceClassification: https://huggingface.co/transformers/model_doc/distilbert.html#distilbertforsequenceclassification"
635
+ ]
636
+ },
637
+ {
638
+ "cell_type": "code",
639
+ "metadata": {
640
+ "id": "Mxsts4uT0PgA",
641
+ "outputId": "78dcf59f-cd7b-4d4e-8bf3-e807a9f35dbe",
642
+ "colab": {
643
+ "base_uri": "https://localhost:8080/"
644
+ }
645
+ },
646
+ "source": [
647
+ "out = base_model(torch.tensor(enc[\"input_ids\"]).unsqueeze(0), torch.tensor(enc[\"attention_mask\"]).unsqueeze(0))\n",
648
+ "out[0].shape"
649
+ ],
650
+ "execution_count": 17,
651
+ "outputs": [
652
+ {
653
+ "output_type": "execute_result",
654
+ "data": {
655
+ "text/plain": [
656
+ "torch.Size([1, 10, 768])"
657
+ ]
658
+ },
659
+ "metadata": {},
660
+ "execution_count": 17
661
+ }
662
+ ]
663
+ },
664
+ {
665
+ "cell_type": "code",
666
+ "metadata": {
667
+ "id": "ZiCO-n_1AHIf",
668
+ "outputId": "b8498d89-c107-4077-f5c3-37c0a19ef89b",
669
+ "colab": {
670
+ "base_uri": "https://localhost:8080/"
671
+ }
672
+ },
673
+ "source": [
674
+ "## size of representation of one of the tokens\n",
675
+ "out[0][:,0,:].shape"
676
+ ],
677
+ "execution_count": 18,
678
+ "outputs": [
679
+ {
680
+ "output_type": "execute_result",
681
+ "data": {
682
+ "text/plain": [
683
+ "torch.Size([1, 768])"
684
+ ]
685
+ },
686
+ "metadata": {},
687
+ "execution_count": 18
688
+ }
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "markdown",
693
+ "metadata": {
694
+ "id": "srwIb9nr4g4t"
695
+ },
696
+ "source": [
697
+ "`torch.Size([1, 768])` represents batch_size, number of tokens in input text (lenght of tokenized text), model's output hidden size."
698
+ ]
699
+ },
700
+ {
701
+ "cell_type": "code",
702
+ "metadata": {
703
+ "id": "iAsg0H6g53Bf",
704
+ "outputId": "1892e9cd-fd84-4978-8dd2-037d21e3dfb8",
705
+ "colab": {
706
+ "base_uri": "https://localhost:8080/"
707
+ }
708
+ },
709
+ "source": [
710
+ "t = \"Elvis is the king of rock\"\n",
711
+ "enc = tokenizer.encode_plus(t)\n",
712
+ "token_representations = base_model(torch.tensor(enc[\"input_ids\"]).unsqueeze(0))[0][0]\n",
713
+ "print(enc[\"input_ids\"])\n",
714
+ "print(tokenizer.decode(enc[\"input_ids\"]))\n",
715
+ "print(f\"Length: {len(enc['input_ids'])}\")\n",
716
+ "print(token_representations.shape)"
717
+ ],
718
+ "execution_count": 19,
719
+ "outputs": [
720
+ {
721
+ "output_type": "stream",
722
+ "name": "stdout",
723
+ "text": [
724
+ "[0, 9682, 9578, 16, 5, 8453, 9, 3152, 2]\n",
725
+ "<s>Elvis is the king of rock</s>\n",
726
+ "Length: 9\n",
727
+ "torch.Size([9, 768])\n"
728
+ ]
729
+ }
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "markdown",
734
+ "metadata": {
735
+ "id": "9RFifOoY7Hsc"
736
+ },
737
+ "source": [
738
+ "## Building Custom Classification head on top of LM base model"
739
+ ]
740
+ },
741
+ {
742
+ "cell_type": "markdown",
743
+ "metadata": {
744
+ "id": "vSUMm4Oq7nvR"
745
+ },
746
+ "source": [
747
+ "Use Mish activiation function as in the one proposed in the original tutorial"
748
+ ]
749
+ },
750
+ {
751
+ "cell_type": "code",
752
+ "metadata": {
753
+ "id": "tCEDXLxq628O"
754
+ },
755
+ "source": [
756
+ "# from https://github.com/digantamisra98/Mish/blob/b5f006660ac0b4c46e2c6958ad0301d7f9c59651/Mish/Torch/mish.py\n",
757
+ "@torch.jit.script\n",
758
+ "def mish(input):\n",
759
+ " return input * torch.tanh(F.softplus(input))\n",
760
+ "\n",
761
+ "class Mish(nn.Module):\n",
762
+ " def forward(self, input):\n",
763
+ " return mish(input)"
764
+ ],
765
+ "execution_count": 20,
766
+ "outputs": []
767
+ },
768
+ {
769
+ "cell_type": "markdown",
770
+ "metadata": {
771
+ "id": "C6Ln6KWm74ku"
772
+ },
773
+ "source": [
774
+ "The model we will use to do the fine-tuning"
775
+ ]
776
+ },
777
+ {
778
+ "cell_type": "code",
779
+ "metadata": {
780
+ "id": "9VDRSRsc71H2"
781
+ },
782
+ "source": [
783
+ "class EmoModel(nn.Module):\n",
784
+ " def __init__(self, base_model, n_classes, base_model_output_size=768, dropout=0.05):\n",
785
+ " super().__init__()\n",
786
+ " self.base_model = base_model\n",
787
+ "\n",
788
+ " self.classifier = nn.Sequential(\n",
789
+ " nn.Dropout(dropout),\n",
790
+ " nn.Linear(base_model_output_size, base_model_output_size),\n",
791
+ " Mish(),\n",
792
+ " nn.Dropout(dropout),\n",
793
+ " nn.Linear(base_model_output_size, n_classes)\n",
794
+ " )\n",
795
+ "\n",
796
+ " for layer in self.classifier:\n",
797
+ " if isinstance(layer, nn.Linear):\n",
798
+ " layer.weight.data.normal_(mean=0.0, std=0.02)\n",
799
+ " if layer.bias is not None:\n",
800
+ " layer.bias.data.zero_()\n",
801
+ "\n",
802
+ " def forward(self, input_, *args):\n",
803
+ " X, attention_mask = input_\n",
804
+ " hidden_states = self.base_model(X, attention_mask=attention_mask)\n",
805
+ "\n",
806
+ " # maybe do some pooling / RNNs... go crazy here!\n",
807
+ "\n",
808
+ " # use the <s> representation\n",
809
+ " return self.classifier(hidden_states[0][:, 0, :])"
810
+ ],
811
+ "execution_count": 21,
812
+ "outputs": []
813
+ },
814
+ {
815
+ "cell_type": "markdown",
816
+ "metadata": {
817
+ "id": "wjgME-3O8Yfo"
818
+ },
819
+ "source": [
820
+ "### Pretest the model with dummy text\n",
821
+ "We want to ensure that the model is returing the right information back."
822
+ ]
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "metadata": {
827
+ "id": "Y6H9eF8A8XeV",
828
+ "outputId": "4bc9b2b2-9882-4218-b780-1af26e3b3969",
829
+ "colab": {
830
+ "base_uri": "https://localhost:8080/"
831
+ }
832
+ },
833
+ "source": [
834
+ "classifier = EmoModel(AutoModelWithLMHead.from_pretrained(\"distilroberta-base\").base_model, 3)"
835
+ ],
836
+ "execution_count": 22,
837
+ "outputs": [
838
+ {
839
+ "output_type": "stream",
840
+ "name": "stderr",
841
+ "text": [
842
+ "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/modeling_auto.py:1595: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
843
+ " warnings.warn(\n",
844
+ "Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
845
+ "- This IS expected if you are initializing RobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
846
+ "- This IS NOT expected if you are initializing RobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
847
+ ]
848
+ }
849
+ ]
850
+ },
851
+ {
852
+ "cell_type": "code",
853
+ "metadata": {
854
+ "id": "-sjfHJ_L9iNH"
855
+ },
856
+ "source": [
857
+ "X = torch.tensor(enc[\"input_ids\"]).unsqueeze(0).to('cpu')\n",
858
+ "attn = torch.tensor(enc[\"attention_mask\"]).unsqueeze(0).to('cpu')"
859
+ ],
860
+ "execution_count": 23,
861
+ "outputs": []
862
+ },
863
+ {
864
+ "cell_type": "code",
865
+ "metadata": {
866
+ "id": "o6QhCuEC-y2z",
867
+ "outputId": "eed26cf5-303f-4098-ef84-3d4ab47d6f37",
868
+ "colab": {
869
+ "base_uri": "https://localhost:8080/"
870
+ }
871
+ },
872
+ "source": [
873
+ "classifier((X, attn))"
874
+ ],
875
+ "execution_count": 24,
876
+ "outputs": [
877
+ {
878
+ "output_type": "execute_result",
879
+ "data": {
880
+ "text/plain": [
881
+ "tensor([[-0.0993, 0.0813, -0.1939]], grad_fn=<AddmmBackward0>)"
882
+ ]
883
+ },
884
+ "metadata": {},
885
+ "execution_count": 24
886
+ }
887
+ ]
888
+ },
889
+ {
890
+ "cell_type": "markdown",
891
+ "metadata": {
892
+ "id": "I-N7WSY7Cb7v"
893
+ },
894
+ "source": [
895
+ "## Prepare your dataset for fine-tuning"
896
+ ]
897
+ },
898
+ {
899
+ "cell_type": "code",
900
+ "metadata": {
901
+ "id": "jDWkjaLV-5tj"
902
+ },
903
+ "source": [
904
+ "!mkdir -p tokenizer"
905
+ ],
906
+ "execution_count": 25,
907
+ "outputs": []
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "metadata": {
912
+ "id": "wMMm5Ye1Db-m",
913
+ "outputId": "2227ea88-5302-43eb-d876-9e4a772a391d",
914
+ "colab": {
915
+ "base_uri": "https://localhost:8080/"
916
+ }
917
+ },
918
+ "source": [
919
+ "## load pretrained tokenizer information\n",
920
+ "tokenizer.save_pretrained(\"tokenizer\")"
921
+ ],
922
+ "execution_count": 26,
923
+ "outputs": [
924
+ {
925
+ "output_type": "execute_result",
926
+ "data": {
927
+ "text/plain": [
928
+ "('tokenizer/tokenizer_config.json',\n",
929
+ " 'tokenizer/special_tokens_map.json',\n",
930
+ " 'tokenizer/vocab.json',\n",
931
+ " 'tokenizer/merges.txt',\n",
932
+ " 'tokenizer/added_tokens.json',\n",
933
+ " 'tokenizer/tokenizer.json')"
934
+ ]
935
+ },
936
+ "metadata": {},
937
+ "execution_count": 26
938
+ }
939
+ ]
940
+ },
941
+ {
942
+ "cell_type": "code",
943
+ "metadata": {
944
+ "id": "3FVtbmrzDkF8",
945
+ "outputId": "5d58c54e-5c35-4c79-e791-a1bc60d396e8",
946
+ "colab": {
947
+ "base_uri": "https://localhost:8080/"
948
+ }
949
+ },
950
+ "source": [
951
+ "!ls tokenizer"
952
+ ],
953
+ "execution_count": 27,
954
+ "outputs": [
955
+ {
956
+ "output_type": "stream",
957
+ "name": "stdout",
958
+ "text": [
959
+ "merges.txt special_tokens_map.json tokenizer_config.json tokenizer.json vocab.json\n"
960
+ ]
961
+ }
962
+ ]
963
+ },
964
+ {
965
+ "cell_type": "markdown",
966
+ "metadata": {
967
+ "id": "BhTEgIaLEDRo"
968
+ },
969
+ "source": [
970
+ "Implement CollateFN using fast tokenizers.\n",
971
+ "This function basically takes care of proper tokenization and batches of sequences. This way you don't need to create your batches manually. Find out more about Tokenizers [here](https://github.com/huggingface/tokenizers/tree/master/bindings/python)."
972
+ ]
973
+ },
974
+ {
975
+ "cell_type": "code",
976
+ "metadata": {
977
+ "id": "3SCLBZsMDn4s"
978
+ },
979
+ "source": [
980
+ "class TokenizersCollateFn:\n",
981
+ " def __init__(self, max_tokens=512):\n",
982
+ "\n",
983
+ " ## RoBERTa uses BPE tokenizer similar to GPT\n",
984
+ " t = ByteLevelBPETokenizer(\n",
985
+ " \"tokenizer/vocab.json\",\n",
986
+ " \"tokenizer/merges.txt\"\n",
987
+ " )\n",
988
+ " t._tokenizer.post_processor = BertProcessing(\n",
989
+ " (\"</s>\", t.token_to_id(\"</s>\")),\n",
990
+ " (\"<s>\", t.token_to_id(\"<s>\")),\n",
991
+ " )\n",
992
+ " t.enable_truncation(max_tokens)\n",
993
+ " t.enable_padding(length=max_tokens, pad_id=t.token_to_id(\"<pad>\"))\n",
994
+ " self.tokenizer = t\n",
995
+ "\n",
996
+ " def __call__(self, batch):\n",
997
+ " encoded = self.tokenizer.encode_batch([x[0] for x in batch])\n",
998
+ " sequences_padded = torch.tensor([enc.ids for enc in encoded])\n",
999
+ " attention_masks_padded = torch.tensor([enc.attention_mask for enc in encoded])\n",
1000
+ " labels = torch.tensor([x[1] for x in batch])\n",
1001
+ "\n",
1002
+ " return (sequences_padded, attention_masks_padded), labels"
1003
+ ],
1004
+ "execution_count": 28,
1005
+ "outputs": []
1006
+ },
1007
+ {
1008
+ "cell_type": "markdown",
1009
+ "metadata": {
1010
+ "id": "4hu70Ng0Eqls"
1011
+ },
1012
+ "source": [
1013
+ "## Getting the Data and Preview it\n",
1014
+ "Below we are going to load the data and show you how to create the splits. However, we don't need to split the data manually becuase I have already created the splits and stored those files seperately which you can quickly download below:"
1015
+ ]
1016
+ },
1017
+ {
1018
+ "cell_type": "code",
1019
+ "metadata": {
1020
+ "id": "JZ3SoJH3fUsq",
1021
+ "outputId": "45966756-4264-434d-a33d-ca6cc53aac6a",
1022
+ "colab": {
1023
+ "base_uri": "https://localhost:8080/"
1024
+ }
1025
+ },
1026
+ "source": [
1027
+ "!wget https://www.dropbox.com/s/ikkqxfdbdec3fuj/test.txt\n",
1028
+ "!wget https://www.dropbox.com/s/1pzkadrvffbqw6o/train.txt\n",
1029
+ "!wget https://www.dropbox.com/s/2mzialpsgf9k5l3/val.txt"
1030
+ ],
1031
+ "execution_count": 29,
1032
+ "outputs": [
1033
+ {
1034
+ "output_type": "stream",
1035
+ "name": "stdout",
1036
+ "text": [
1037
+ "--2024-03-15 23:58:45-- https://www.dropbox.com/s/ikkqxfdbdec3fuj/test.txt\n",
1038
+ "Resolving www.dropbox.com (www.dropbox.com)... 162.125.2.18, 2620:100:6017:18::a27d:212\n",
1039
+ "Connecting to www.dropbox.com (www.dropbox.com)|162.125.2.18|:443... connected.\n",
1040
+ "HTTP request sent, awaiting response... 302 Found\n",
1041
+ "Location: /s/raw/ikkqxfdbdec3fuj/test.txt [following]\n",
1042
+ "--2024-03-15 23:58:45-- https://www.dropbox.com/s/raw/ikkqxfdbdec3fuj/test.txt\n",
1043
+ "Reusing existing connection to www.dropbox.com:443.\n",
1044
+ "HTTP request sent, awaiting response... 404 Not Found\n",
1045
+ "2024-03-15 23:58:45 ERROR 404: Not Found.\n",
1046
+ "\n",
1047
+ "--2024-03-15 23:58:45-- https://www.dropbox.com/s/1pzkadrvffbqw6o/train.txt\n",
1048
+ "Resolving www.dropbox.com (www.dropbox.com)... 162.125.2.18, 2620:100:6017:18::a27d:212\n",
1049
+ "Connecting to www.dropbox.com (www.dropbox.com)|162.125.2.18|:443... connected.\n",
1050
+ "HTTP request sent, awaiting response... 302 Found\n",
1051
+ "Location: /s/raw/1pzkadrvffbqw6o/train.txt [following]\n",
1052
+ "--2024-03-15 23:58:45-- https://www.dropbox.com/s/raw/1pzkadrvffbqw6o/train.txt\n",
1053
+ "Reusing existing connection to www.dropbox.com:443.\n",
1054
+ "HTTP request sent, awaiting response... 404 Not Found\n",
1055
+ "2024-03-15 23:58:46 ERROR 404: Not Found.\n",
1056
+ "\n",
1057
+ "--2024-03-15 23:58:46-- https://www.dropbox.com/s/2mzialpsgf9k5l3/val.txt\n",
1058
+ "Resolving www.dropbox.com (www.dropbox.com)... 162.125.2.18, 2620:100:6017:18::a27d:212\n",
1059
+ "Connecting to www.dropbox.com (www.dropbox.com)|162.125.2.18|:443... connected.\n",
1060
+ "HTTP request sent, awaiting response... 302 Found\n",
1061
+ "Location: /s/raw/2mzialpsgf9k5l3/val.txt [following]\n",
1062
+ "--2024-03-15 23:58:46-- https://www.dropbox.com/s/raw/2mzialpsgf9k5l3/val.txt\n",
1063
+ "Reusing existing connection to www.dropbox.com:443.\n",
1064
+ "HTTP request sent, awaiting response... 404 Not Found\n",
1065
+ "2024-03-15 23:58:46 ERROR 404: Not Found.\n",
1066
+ "\n"
1067
+ ]
1068
+ }
1069
+ ]
1070
+ },
1071
+ {
1072
+ "cell_type": "code",
1073
+ "metadata": {
1074
+ "id": "r_03fxufWX_G"
1075
+ },
1076
+ "source": [
1077
+ "## export the datasets as txt files\n",
1078
+ "## EXERCISE: Change this to an address\n",
1079
+ "\n",
1080
+ "train_path = \"train.txt\"\n",
1081
+ "test_path = \"test.txt\"\n",
1082
+ "val_path = \"val.txt\"\n",
1083
+ "\n",
1084
+ "## emotion labels\n",
1085
+ "label2int = {\n",
1086
+ " \"sadness\": 0,\n",
1087
+ " \"joy\": 1,\n",
1088
+ " \"love\": 2,\n",
1089
+ " \"anger\": 3,\n",
1090
+ " \"fear\": 4,\n",
1091
+ " \"surprise\": 5\n",
1092
+ "}\n",
1093
+ "\n",
1094
+ "emotions = [ \"sadness\", \"joy\", \"love\", \"anger\", \"fear\", \"surprise\"]"
1095
+ ],
1096
+ "execution_count": 30,
1097
+ "outputs": []
1098
+ },
1099
+ {
1100
+ "cell_type": "markdown",
1101
+ "source": [
1102
+ "### A Quick Look at the dataset\n",
1103
+ "Below is a few code sniphets to get a good idea of the dataset we are using here. You can skip this whole subsection if you like."
1104
+ ],
1105
+ "metadata": {
1106
+ "id": "-FJ-wN1_zmkV"
1107
+ }
1108
+ },
1109
+ {
1110
+ "cell_type": "code",
1111
+ "metadata": {
1112
+ "id": "t23zHggkEpc-",
1113
+ "outputId": "3a9615d4-492f-4134-aaa4-43cf15234fb8",
1114
+ "colab": {
1115
+ "base_uri": "https://localhost:8080/"
1116
+ }
1117
+ },
1118
+ "source": [
1119
+ "!wget https://www.dropbox.com/s/607ptdakxuh5i4s/merged_training.pkl"
1120
+ ],
1121
+ "execution_count": 31,
1122
+ "outputs": [
1123
+ {
1124
+ "output_type": "stream",
1125
+ "name": "stdout",
1126
+ "text": [
1127
+ "--2024-03-15 23:58:46-- https://www.dropbox.com/s/607ptdakxuh5i4s/merged_training.pkl\n",
1128
+ "Resolving www.dropbox.com (www.dropbox.com)... 162.125.2.18, 2620:100:6017:18::a27d:212\n",
1129
+ "Connecting to www.dropbox.com (www.dropbox.com)|162.125.2.18|:443... connected.\n",
1130
+ "HTTP request sent, awaiting response... 302 Found\n",
1131
+ "Location: /s/raw/607ptdakxuh5i4s/merged_training.pkl [following]\n",
1132
+ "--2024-03-15 23:58:46-- https://www.dropbox.com/s/raw/607ptdakxuh5i4s/merged_training.pkl\n",
1133
+ "Reusing existing connection to www.dropbox.com:443.\n",
1134
+ "HTTP request sent, awaiting response... 404 Not Found\n",
1135
+ "2024-03-15 23:58:46 ERROR 404: Not Found.\n",
1136
+ "\n"
1137
+ ]
1138
+ }
1139
+ ]
1140
+ },
1141
+ {
1142
+ "cell_type": "code",
1143
+ "metadata": {
1144
+ "id": "PQrMSUTRF06B"
1145
+ },
1146
+ "source": [
1147
+ "import pickle\n",
1148
+ "\n",
1149
+ "## helper function\n",
1150
+ "def load_from_pickle(directory):\n",
1151
+ " return pickle.load(open(directory,\"rb\"))"
1152
+ ],
1153
+ "execution_count": 32,
1154
+ "outputs": []
1155
+ },
1156
+ {
1157
+ "cell_type": "code",
1158
+ "metadata": {
1159
+ "id": "XGz89mNSHaYM",
1160
+ "outputId": "ca0ffab9-8002-43fe-8761-c4f98f495482",
1161
+ "colab": {
1162
+ "base_uri": "https://localhost:8080/",
1163
+ "height": 305
1164
+ }
1165
+ },
1166
+ "source": [
1167
+ "data = load_from_pickle(directory=\"merged_training.pkl\")\n",
1168
+ "\n",
1169
+ "## using a sample\n",
1170
+ "data= data[data[\"emotions\"].isin(emotions)]\n",
1171
+ "\n",
1172
+ "\n",
1173
+ "data = data.sample(n=20000);\n",
1174
+ "\n",
1175
+ "data.emotions.value_counts().plot.bar()"
1176
+ ],
1177
+ "execution_count": 33,
1178
+ "outputs": [
1179
+ {
1180
+ "output_type": "error",
1181
+ "ename": "FileNotFoundError",
1182
+ "evalue": "[Errno 2] No such file or directory: 'merged_training.pkl'",
1183
+ "traceback": [
1184
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1185
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
1186
+ "\u001b[0;32m<ipython-input-33-b230c266f99a>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_from_pickle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdirectory\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"merged_training.pkl\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m## using a sample\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"emotions\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0memotions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
1187
+ "\u001b[0;32m<ipython-input-32-01bb35124bd3>\u001b[0m in \u001b[0;36mload_from_pickle\u001b[0;34m(directory)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m## helper function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mload_from_pickle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdirectory\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdirectory\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
1188
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'merged_training.pkl'"
1189
+ ]
1190
+ }
1191
+ ]
1192
+ },
1193
+ {
1194
+ "cell_type": "code",
1195
+ "metadata": {
1196
+ "id": "Comaf36-Hb6X"
1197
+ },
1198
+ "source": [
1199
+ "data.count()"
1200
+ ],
1201
+ "execution_count": null,
1202
+ "outputs": []
1203
+ },
1204
+ {
1205
+ "cell_type": "markdown",
1206
+ "metadata": {
1207
+ "id": "jYxc8fx_H3ad"
1208
+ },
1209
+ "source": [
1210
+ "Data has been preprocessed already, using technique from this paper: https://www.aclweb.org/anthology/D18-1404/"
1211
+ ]
1212
+ },
1213
+ {
1214
+ "cell_type": "code",
1215
+ "metadata": {
1216
+ "id": "gYKK7ujRHfRt"
1217
+ },
1218
+ "source": [
1219
+ "data.head()"
1220
+ ],
1221
+ "execution_count": null,
1222
+ "outputs": []
1223
+ },
1224
+ {
1225
+ "cell_type": "code",
1226
+ "metadata": {
1227
+ "id": "JXovcl56NFPp"
1228
+ },
1229
+ "source": [
1230
+ "## reset index\n",
1231
+ "data.reset_index(drop=True, inplace=True)"
1232
+ ],
1233
+ "execution_count": null,
1234
+ "outputs": []
1235
+ },
1236
+ {
1237
+ "cell_type": "code",
1238
+ "metadata": {
1239
+ "id": "pSzoz9InH0Ta"
1240
+ },
1241
+ "source": [
1242
+ "## check unique emotions in the dataset\n",
1243
+ "data.emotions.unique()"
1244
+ ],
1245
+ "execution_count": null,
1246
+ "outputs": []
1247
+ },
1248
+ {
1249
+ "cell_type": "markdown",
1250
+ "metadata": {
1251
+ "id": "rJm31gKShQus"
1252
+ },
1253
+ "source": [
1254
+ "## Split the data and store into individual text files\n",
1255
+ "\n",
1256
+ "If you are using your own dataset and want to split it for training, you can uncomment the code below. Otherwise, just skip it."
1257
+ ]
1258
+ },
1259
+ {
1260
+ "cell_type": "code",
1261
+ "metadata": {
1262
+ "id": "6ooNxSnPiztL"
1263
+ },
1264
+ "source": [
1265
+ "## uncomment the code below to generate the text files for your train, val, and test datasets.\n",
1266
+ "\n",
1267
+ "'''\n",
1268
+ "from sklearn.model_selection import train_test_split\n",
1269
+ "import numpy as np\n",
1270
+ "\n",
1271
+ "# Creating training and validation sets using an 80-20 split\n",
1272
+ "input_train, input_val, target_train, target_val = train_test_split(data.text.to_numpy(),\n",
1273
+ " data.emotions.to_numpy(),\n",
1274
+ " test_size=0.2)\n",
1275
+ "\n",
1276
+ "# Split the validataion further to obtain a holdout dataset (for testing) -- split 50:50\n",
1277
+ "input_val, input_test, target_val, target_test = train_test_split(input_val, target_val, test_size=0.5)\n",
1278
+ "\n",
1279
+ "\n",
1280
+ "## create a dataframe for each dataset\n",
1281
+ "train_dataset = pd.DataFrame(data={\"text\": input_train, \"class\": target_train})\n",
1282
+ "val_dataset = pd.DataFrame(data={\"text\": input_val, \"class\": target_val})\n",
1283
+ "test_dataset = pd.DataFrame(data={\"text\": input_test, \"class\": target_test})\n",
1284
+ "final_dataset = {\"train\": train_dataset, \"val\": val_dataset , \"test\": test_dataset }\n",
1285
+ "\n",
1286
+ "train_dataset.to_csv(train_path, sep=\";\",header=False, index=False)\n",
1287
+ "val_dataset.to_csv(test_path, sep=\";\",header=False, index=False)\n",
1288
+ "test_dataset.to_csv(val_path, sep=\";\",header=False, index=False)\n",
1289
+ "'''"
1290
+ ],
1291
+ "execution_count": null,
1292
+ "outputs": []
1293
+ },
1294
+ {
1295
+ "cell_type": "markdown",
1296
+ "metadata": {
1297
+ "id": "rAD1J6c0dLp8"
1298
+ },
1299
+ "source": [
1300
+ "## Create the Dataset object"
1301
+ ]
1302
+ },
1303
+ {
1304
+ "cell_type": "markdown",
1305
+ "metadata": {
1306
+ "id": "aOOI69vwIYcN"
1307
+ },
1308
+ "source": [
1309
+ "Create the Dataset object that will be used to load the different datasets."
1310
+ ]
1311
+ },
1312
+ {
1313
+ "cell_type": "code",
1314
+ "metadata": {
1315
+ "id": "Ktr6xeMuISin"
1316
+ },
1317
+ "source": [
1318
+ "class EmoDataset(Dataset):\n",
1319
+ " def __init__(self, path):\n",
1320
+ " super().__init__()\n",
1321
+ " self.data_column = \"text\"\n",
1322
+ " self.class_column = \"class\"\n",
1323
+ " self.data = pd.read_csv(path, sep=\";\", header=None, names=[self.data_column, self.class_column],\n",
1324
+ " engine=\"python\")\n",
1325
+ "\n",
1326
+ " def __getitem__(self, idx):\n",
1327
+ " return self.data.loc[idx, self.data_column], label2int[self.data.loc[idx, self.class_column]]\n",
1328
+ "\n",
1329
+ " def __len__(self):\n",
1330
+ " return self.data.shape[0]"
1331
+ ],
1332
+ "execution_count": null,
1333
+ "outputs": []
1334
+ },
1335
+ {
1336
+ "cell_type": "markdown",
1337
+ "metadata": {
1338
+ "id": "9EYQRq3qJH7n"
1339
+ },
1340
+ "source": [
1341
+ "Sanity check"
1342
+ ]
1343
+ },
1344
+ {
1345
+ "cell_type": "code",
1346
+ "metadata": {
1347
+ "id": "uGWw4wGEJGhJ"
1348
+ },
1349
+ "source": [
1350
+ "ds = EmoDataset(train_path)\n",
1351
+ "ds[19]"
1352
+ ],
1353
+ "execution_count": null,
1354
+ "outputs": []
1355
+ },
1356
+ {
1357
+ "cell_type": "markdown",
1358
+ "metadata": {
1359
+ "id": "0h6tTn9hd6v8"
1360
+ },
1361
+ "source": [
1362
+ "## Training with PyTorchLightning\n",
1363
+ "\n",
1364
+ "[PyTorchLightning](https://www.pytorchlightning.ai/) is a library that abstracts the complexity of training neural networks with PyTorch. It is built on top of PyTorch and simplifies training.\n",
1365
+ "\n",
1366
+ "![](https://pytorch-lightning.readthedocs.io/en/latest/_images/pt_to_pl.png)"
1367
+ ]
1368
+ },
1369
+ {
1370
+ "cell_type": "code",
1371
+ "metadata": {
1372
+ "id": "RJHhNRcZK7sV"
1373
+ },
1374
+ "source": [
1375
+ "## Methods required by PyTorchLightning\n",
1376
+ "\n",
1377
+ "class TrainingModule(pl.LightningModule):\n",
1378
+ " def __init__(self, hparams):\n",
1379
+ " super().__init__()\n",
1380
+ " self.model = EmoModel(AutoModelWithLMHead.from_pretrained(\"distilroberta-base\").base_model, len(emotions))\n",
1381
+ " self.loss = nn.CrossEntropyLoss() ## combines LogSoftmax() and NLLLoss()\n",
1382
+ " #self.hparams = hparams\n",
1383
+ " self.hparams.update(vars(hparams))\n",
1384
+ "\n",
1385
+ " def step(self, batch, step_name=\"train\"):\n",
1386
+ " X, y = batch\n",
1387
+ " loss = self.loss(self.forward(X), y)\n",
1388
+ " loss_key = f\"{step_name}_loss\"\n",
1389
+ " tensorboard_logs = {loss_key: loss}\n",
1390
+ "\n",
1391
+ " return { (\"loss\" if step_name == \"train\" else loss_key): loss, 'log': tensorboard_logs,\n",
1392
+ " \"progress_bar\": {loss_key: loss}}\n",
1393
+ "\n",
1394
+ " def forward(self, X, *args):\n",
1395
+ " return self.model(X, *args)\n",
1396
+ "\n",
1397
+ " def training_step(self, batch, batch_idx):\n",
1398
+ " return self.step(batch, \"train\")\n",
1399
+ "\n",
1400
+ " def validation_step(self, batch, batch_idx):\n",
1401
+ " return self.step(batch, \"val\")\n",
1402
+ "\n",
1403
+ " def validation_end(self, outputs: List[dict]):\n",
1404
+ " loss = torch.stack([x[\"val_loss\"] for x in outputs]).mean()\n",
1405
+ " return {\"val_loss\": loss}\n",
1406
+ "\n",
1407
+ " def test_step(self, batch, batch_idx):\n",
1408
+ " return self.step(batch, \"test\")\n",
1409
+ "\n",
1410
+ " def train_dataloader(self):\n",
1411
+ " return self.create_data_loader(self.hparams.train_path, shuffle=True)\n",
1412
+ "\n",
1413
+ " def val_dataloader(self):\n",
1414
+ " return self.create_data_loader(self.hparams.val_path)\n",
1415
+ "\n",
1416
+ " def test_dataloader(self):\n",
1417
+ " return self.create_data_loader(self.hparams.test_path)\n",
1418
+ "\n",
1419
+ " def create_data_loader(self, ds_path: str, shuffle=False):\n",
1420
+ " return DataLoader(\n",
1421
+ " EmoDataset(ds_path),\n",
1422
+ " batch_size=self.hparams.batch_size,\n",
1423
+ " shuffle=shuffle,\n",
1424
+ " collate_fn=TokenizersCollateFn()\n",
1425
+ " )\n",
1426
+ "\n",
1427
+ " @lru_cache()\n",
1428
+ " def total_steps(self):\n",
1429
+ " return len(self.train_dataloader()) // self.hparams.accumulate_grad_batches * self.hparams.epochs\n",
1430
+ "\n",
1431
+ " def configure_optimizers(self):\n",
1432
+ " ## use AdamW optimizer -- faster approach to training NNs\n",
1433
+ " ## read: https://www.fast.ai/2018/07/02/adam-weight-decay/\n",
1434
+ " optimizer = AdamW(self.model.parameters(), lr=self.hparams.lr)\n",
1435
+ " lr_scheduler = get_linear_schedule_with_warmup(\n",
1436
+ " optimizer,\n",
1437
+ " num_warmup_steps=self.hparams.warmup_steps,\n",
1438
+ " num_training_steps=self.total_steps(),\n",
1439
+ " )\n",
1440
+ " return [optimizer], [{\"scheduler\": lr_scheduler, \"interval\": \"step\"}]"
1441
+ ],
1442
+ "execution_count": null,
1443
+ "outputs": []
1444
+ },
1445
+ {
1446
+ "cell_type": "markdown",
1447
+ "metadata": {
1448
+ "id": "OGc7Vw1moHxr"
1449
+ },
1450
+ "source": [
1451
+ "## Finding Learning rate for the model\n",
1452
+ "\n",
1453
+ "The code below aims to obtain valuable information about the optimal learning rate during a pretraining run. Determine boundary and increase the leanring rate linearly or exponentially.\n",
1454
+ "\n",
1455
+ "More: https://github.com/davidtvs/pytorch-lr-finder"
1456
+ ]
1457
+ },
1458
+ {
1459
+ "cell_type": "code",
1460
+ "metadata": {
1461
+ "id": "xL4lNPDFoFyU"
1462
+ },
1463
+ "source": [
1464
+ "lr=0.1 ## uper bound LR\n",
1465
+ "from torch_lr_finder import LRFinder\n",
1466
+ "hparams_tmp = Namespace(\n",
1467
+ " train_path=train_path,\n",
1468
+ " val_path=val_path,\n",
1469
+ " test_path=test_path,\n",
1470
+ " batch_size=16,\n",
1471
+ " warmup_steps=100,\n",
1472
+ " epochs=1,\n",
1473
+ " lr=lr,\n",
1474
+ " accumulate_grad_batches=1,\n",
1475
+ ")\n",
1476
+ "module = TrainingModule(hparams_tmp)\n",
1477
+ "criterion = nn.CrossEntropyLoss()\n",
1478
+ "optimizer = AdamW(module.parameters(), lr=5e-7) ## lower bound LR\n",
1479
+ "lr_finder = LRFinder(module, optimizer, criterion, device=\"cuda\")\n",
1480
+ "lr_finder.range_test(module.train_dataloader(), end_lr=100, num_iter=100, accumulation_steps=hparams_tmp.accumulate_grad_batches)\n",
1481
+ "lr_finder.plot()\n",
1482
+ "lr_finder.reset()"
1483
+ ],
1484
+ "execution_count": null,
1485
+ "outputs": []
1486
+ },
1487
+ {
1488
+ "cell_type": "code",
1489
+ "metadata": {
1490
+ "id": "YdqP56M1oXav"
1491
+ },
1492
+ "source": [
1493
+ "lr = 1e-4\n",
1494
+ "lr"
1495
+ ],
1496
+ "execution_count": null,
1497
+ "outputs": []
1498
+ },
1499
+ {
1500
+ "cell_type": "code",
1501
+ "metadata": {
1502
+ "id": "vMab6vu0Bow0"
1503
+ },
1504
+ "source": [
1505
+ "lr_finder.plot(show_lr=lr)"
1506
+ ],
1507
+ "execution_count": null,
1508
+ "outputs": []
1509
+ },
1510
+ {
1511
+ "cell_type": "markdown",
1512
+ "metadata": {
1513
+ "id": "ZhHutCseBxjJ"
1514
+ },
1515
+ "source": [
1516
+ "## Training the Emotion Classifier"
1517
+ ]
1518
+ },
1519
+ {
1520
+ "cell_type": "code",
1521
+ "metadata": {
1522
+ "id": "q3FiLr3LBrjs"
1523
+ },
1524
+ "source": [
1525
+ "hparams = Namespace(\n",
1526
+ " train_path=train_path,\n",
1527
+ " val_path=val_path,\n",
1528
+ " test_path=test_path,\n",
1529
+ " batch_size=32,\n",
1530
+ " warmup_steps=100,\n",
1531
+ " epochs=1,\n",
1532
+ " lr=lr,\n",
1533
+ " accumulate_grad_batches=1\n",
1534
+ ")\n",
1535
+ "module = TrainingModule(hparams)"
1536
+ ],
1537
+ "execution_count": null,
1538
+ "outputs": []
1539
+ },
1540
+ {
1541
+ "cell_type": "code",
1542
+ "metadata": {
1543
+ "id": "N8Jv_U25B37g"
1544
+ },
1545
+ "source": [
1546
+ "## garbage collection\n",
1547
+ "import gc; gc.collect()\n",
1548
+ "torch.cuda.empty_cache()"
1549
+ ],
1550
+ "execution_count": null,
1551
+ "outputs": []
1552
+ },
1553
+ {
1554
+ "cell_type": "code",
1555
+ "metadata": {
1556
+ "id": "oRnl4HXvB5-T"
1557
+ },
1558
+ "source": [
1559
+ "## train roughly for about 10-15 minutes with GPU enabled.\n",
1560
+ "trainer = pl.Trainer(gpus=1, max_epochs=hparams.epochs, progress_bar_refresh_rate=10,\n",
1561
+ " accumulate_grad_batches=hparams.accumulate_grad_batches)\n",
1562
+ "\n",
1563
+ "trainer.fit(module)"
1564
+ ],
1565
+ "execution_count": null,
1566
+ "outputs": []
1567
+ },
1568
+ {
1569
+ "cell_type": "code",
1570
+ "metadata": {
1571
+ "id": "Y8kzE1AeB_ij"
1572
+ },
1573
+ "source": [
1574
+ "with torch.no_grad():\n",
1575
+ " progress = [\"/\", \"-\", \"\\\\\", \"|\", \"/\", \"-\", \"\\\\\", \"|\"]\n",
1576
+ " module.eval()\n",
1577
+ " true_y, pred_y = [], []\n",
1578
+ " for i, batch_ in enumerate(module.test_dataloader()):\n",
1579
+ " (X, attn), y = batch_\n",
1580
+ " batch = (X.cuda(), attn.cuda())\n",
1581
+ " print(progress[i % len(progress)], end=\"\\r\")\n",
1582
+ " y_pred = torch.argmax(module(batch), dim=1)\n",
1583
+ " true_y.extend(y.cpu())\n",
1584
+ " pred_y.extend(y_pred.cpu())\n",
1585
+ "print(\"\\n\" + \"_\" * 80)\n",
1586
+ "print(classification_report(true_y, pred_y, target_names=label2int.keys(), digits=len(emotions)))"
1587
+ ],
1588
+ "execution_count": null,
1589
+ "outputs": []
1590
+ },
1591
+ {
1592
+ "cell_type": "code",
1593
+ "metadata": {
1594
+ "id": "U0_Z_4Pkl3fc"
1595
+ },
1596
+ "source": [
1597
+ "!nvidia-smi"
1598
+ ],
1599
+ "execution_count": null,
1600
+ "outputs": []
1601
+ },
1602
+ {
1603
+ "cell_type": "code",
1604
+ "source": [],
1605
+ "metadata": {
1606
+ "id": "ifER7sn-Htge"
1607
+ },
1608
+ "execution_count": null,
1609
+ "outputs": []
1610
+ }
1611
+ ]
1612
+ }