karan99300 commited on
Commit
a521b82
·
verified ·
1 Parent(s): f89f17b

Upload 6 files

Browse files
code.ipynb ADDED
@@ -0,0 +1,1400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "c2ed359a",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "2d441603",
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "data": {
21
+ "text/html": [
22
+ "<div>\n",
23
+ "<style scoped>\n",
24
+ " .dataframe tbody tr th:only-of-type {\n",
25
+ " vertical-align: middle;\n",
26
+ " }\n",
27
+ "\n",
28
+ " .dataframe tbody tr th {\n",
29
+ " vertical-align: top;\n",
30
+ " }\n",
31
+ "\n",
32
+ " .dataframe thead th {\n",
33
+ " text-align: right;\n",
34
+ " }\n",
35
+ "</style>\n",
36
+ "<table border=\"1\" class=\"dataframe\">\n",
37
+ " <thead>\n",
38
+ " <tr style=\"text-align: right;\">\n",
39
+ " <th></th>\n",
40
+ " <th>textID</th>\n",
41
+ " <th>text</th>\n",
42
+ " <th>selected_text</th>\n",
43
+ " <th>sentiment</th>\n",
44
+ " <th>Time of Tweet</th>\n",
45
+ " <th>Age of User</th>\n",
46
+ " <th>Country</th>\n",
47
+ " <th>Population -2020</th>\n",
48
+ " <th>Land Area (Km²)</th>\n",
49
+ " <th>Density (P/Km²)</th>\n",
50
+ " </tr>\n",
51
+ " </thead>\n",
52
+ " <tbody>\n",
53
+ " <tr>\n",
54
+ " <th>0</th>\n",
55
+ " <td>cb774db0d1</td>\n",
56
+ " <td>I`d have responded, if I were going</td>\n",
57
+ " <td>I`d have responded, if I were going</td>\n",
58
+ " <td>neutral</td>\n",
59
+ " <td>morning</td>\n",
60
+ " <td>0-20</td>\n",
61
+ " <td>Afghanistan</td>\n",
62
+ " <td>38928346</td>\n",
63
+ " <td>652860.0</td>\n",
64
+ " <td>60</td>\n",
65
+ " </tr>\n",
66
+ " <tr>\n",
67
+ " <th>1</th>\n",
68
+ " <td>549e992a42</td>\n",
69
+ " <td>Sooo SAD I will miss you here in San Diego!!!</td>\n",
70
+ " <td>Sooo SAD</td>\n",
71
+ " <td>negative</td>\n",
72
+ " <td>noon</td>\n",
73
+ " <td>21-30</td>\n",
74
+ " <td>Albania</td>\n",
75
+ " <td>2877797</td>\n",
76
+ " <td>27400.0</td>\n",
77
+ " <td>105</td>\n",
78
+ " </tr>\n",
79
+ " <tr>\n",
80
+ " <th>2</th>\n",
81
+ " <td>088c60f138</td>\n",
82
+ " <td>my boss is bullying me...</td>\n",
83
+ " <td>bullying me</td>\n",
84
+ " <td>negative</td>\n",
85
+ " <td>night</td>\n",
86
+ " <td>31-45</td>\n",
87
+ " <td>Algeria</td>\n",
88
+ " <td>43851044</td>\n",
89
+ " <td>2381740.0</td>\n",
90
+ " <td>18</td>\n",
91
+ " </tr>\n",
92
+ " <tr>\n",
93
+ " <th>3</th>\n",
94
+ " <td>9642c003ef</td>\n",
95
+ " <td>what interview! leave me alone</td>\n",
96
+ " <td>leave me alone</td>\n",
97
+ " <td>negative</td>\n",
98
+ " <td>morning</td>\n",
99
+ " <td>46-60</td>\n",
100
+ " <td>Andorra</td>\n",
101
+ " <td>77265</td>\n",
102
+ " <td>470.0</td>\n",
103
+ " <td>164</td>\n",
104
+ " </tr>\n",
105
+ " <tr>\n",
106
+ " <th>4</th>\n",
107
+ " <td>358bd9e861</td>\n",
108
+ " <td>Sons of ****, why couldn`t they put them on t...</td>\n",
109
+ " <td>Sons of ****,</td>\n",
110
+ " <td>negative</td>\n",
111
+ " <td>noon</td>\n",
112
+ " <td>60-70</td>\n",
113
+ " <td>Angola</td>\n",
114
+ " <td>32866272</td>\n",
115
+ " <td>1246700.0</td>\n",
116
+ " <td>26</td>\n",
117
+ " </tr>\n",
118
+ " <tr>\n",
119
+ " <th>...</th>\n",
120
+ " <td>...</td>\n",
121
+ " <td>...</td>\n",
122
+ " <td>...</td>\n",
123
+ " <td>...</td>\n",
124
+ " <td>...</td>\n",
125
+ " <td>...</td>\n",
126
+ " <td>...</td>\n",
127
+ " <td>...</td>\n",
128
+ " <td>...</td>\n",
129
+ " <td>...</td>\n",
130
+ " </tr>\n",
131
+ " <tr>\n",
132
+ " <th>27476</th>\n",
133
+ " <td>4eac33d1c0</td>\n",
134
+ " <td>wish we could come see u on Denver husband l...</td>\n",
135
+ " <td>d lost</td>\n",
136
+ " <td>negative</td>\n",
137
+ " <td>night</td>\n",
138
+ " <td>31-45</td>\n",
139
+ " <td>Ghana</td>\n",
140
+ " <td>31072940</td>\n",
141
+ " <td>227540.0</td>\n",
142
+ " <td>137</td>\n",
143
+ " </tr>\n",
144
+ " <tr>\n",
145
+ " <th>27477</th>\n",
146
+ " <td>4f4c4fc327</td>\n",
147
+ " <td>I`ve wondered about rake to. The client has ...</td>\n",
148
+ " <td>, don`t force</td>\n",
149
+ " <td>negative</td>\n",
150
+ " <td>morning</td>\n",
151
+ " <td>46-60</td>\n",
152
+ " <td>Greece</td>\n",
153
+ " <td>10423054</td>\n",
154
+ " <td>128900.0</td>\n",
155
+ " <td>81</td>\n",
156
+ " </tr>\n",
157
+ " <tr>\n",
158
+ " <th>27478</th>\n",
159
+ " <td>f67aae2310</td>\n",
160
+ " <td>Yay good for both of you. Enjoy the break - y...</td>\n",
161
+ " <td>Yay good for both of you.</td>\n",
162
+ " <td>positive</td>\n",
163
+ " <td>noon</td>\n",
164
+ " <td>60-70</td>\n",
165
+ " <td>Grenada</td>\n",
166
+ " <td>112523</td>\n",
167
+ " <td>340.0</td>\n",
168
+ " <td>331</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <th>27479</th>\n",
172
+ " <td>ed167662a5</td>\n",
173
+ " <td>But it was worth it ****.</td>\n",
174
+ " <td>But it was worth it ****.</td>\n",
175
+ " <td>positive</td>\n",
176
+ " <td>night</td>\n",
177
+ " <td>70-100</td>\n",
178
+ " <td>Guatemala</td>\n",
179
+ " <td>17915568</td>\n",
180
+ " <td>107160.0</td>\n",
181
+ " <td>167</td>\n",
182
+ " </tr>\n",
183
+ " <tr>\n",
184
+ " <th>27480</th>\n",
185
+ " <td>6f7127d9d7</td>\n",
186
+ " <td>All this flirting going on - The ATG smiles...</td>\n",
187
+ " <td>All this flirting going on - The ATG smiles. Y...</td>\n",
188
+ " <td>neutral</td>\n",
189
+ " <td>morning</td>\n",
190
+ " <td>0-20</td>\n",
191
+ " <td>Guinea</td>\n",
192
+ " <td>13132795</td>\n",
193
+ " <td>246000.0</td>\n",
194
+ " <td>53</td>\n",
195
+ " </tr>\n",
196
+ " </tbody>\n",
197
+ "</table>\n",
198
+ "<p>27481 rows × 10 columns</p>\n",
199
+ "</div>"
200
+ ],
201
+ "text/plain": [
202
+ " textID text \\\n",
203
+ "0 cb774db0d1 I`d have responded, if I were going \n",
204
+ "1 549e992a42 Sooo SAD I will miss you here in San Diego!!! \n",
205
+ "2 088c60f138 my boss is bullying me... \n",
206
+ "3 9642c003ef what interview! leave me alone \n",
207
+ "4 358bd9e861 Sons of ****, why couldn`t they put them on t... \n",
208
+ "... ... ... \n",
209
+ "27476 4eac33d1c0 wish we could come see u on Denver husband l... \n",
210
+ "27477 4f4c4fc327 I`ve wondered about rake to. The client has ... \n",
211
+ "27478 f67aae2310 Yay good for both of you. Enjoy the break - y... \n",
212
+ "27479 ed167662a5 But it was worth it ****. \n",
213
+ "27480 6f7127d9d7 All this flirting going on - The ATG smiles... \n",
214
+ "\n",
215
+ " selected_text sentiment \\\n",
216
+ "0 I`d have responded, if I were going neutral \n",
217
+ "1 Sooo SAD negative \n",
218
+ "2 bullying me negative \n",
219
+ "3 leave me alone negative \n",
220
+ "4 Sons of ****, negative \n",
221
+ "... ... ... \n",
222
+ "27476 d lost negative \n",
223
+ "27477 , don`t force negative \n",
224
+ "27478 Yay good for both of you. positive \n",
225
+ "27479 But it was worth it ****. positive \n",
226
+ "27480 All this flirting going on - The ATG smiles. Y... neutral \n",
227
+ "\n",
228
+ " Time of Tweet Age of User Country Population -2020 \\\n",
229
+ "0 morning 0-20 Afghanistan 38928346 \n",
230
+ "1 noon 21-30 Albania 2877797 \n",
231
+ "2 night 31-45 Algeria 43851044 \n",
232
+ "3 morning 46-60 Andorra 77265 \n",
233
+ "4 noon 60-70 Angola 32866272 \n",
234
+ "... ... ... ... ... \n",
235
+ "27476 night 31-45 Ghana 31072940 \n",
236
+ "27477 morning 46-60 Greece 10423054 \n",
237
+ "27478 noon 60-70 Grenada 112523 \n",
238
+ "27479 night 70-100 Guatemala 17915568 \n",
239
+ "27480 morning 0-20 Guinea 13132795 \n",
240
+ "\n",
241
+ " Land Area (Km²) Density (P/Km²) \n",
242
+ "0 652860.0 60 \n",
243
+ "1 27400.0 105 \n",
244
+ "2 2381740.0 18 \n",
245
+ "3 470.0 164 \n",
246
+ "4 1246700.0 26 \n",
247
+ "... ... ... \n",
248
+ "27476 227540.0 137 \n",
249
+ "27477 128900.0 81 \n",
250
+ "27478 340.0 331 \n",
251
+ "27479 107160.0 167 \n",
252
+ "27480 246000.0 53 \n",
253
+ "\n",
254
+ "[27481 rows x 10 columns]"
255
+ ]
256
+ },
257
+ "execution_count": 2,
258
+ "metadata": {},
259
+ "output_type": "execute_result"
260
+ }
261
+ ],
262
+ "source": [
263
+ "df=pd.read_csv('train.csv',encoding='unicode_escape')\n",
264
+ "df"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": 3,
270
+ "id": "60b7c4de",
271
+ "metadata": {},
272
+ "outputs": [
273
+ {
274
+ "data": {
275
+ "text/html": [
276
+ "<div>\n",
277
+ "<style scoped>\n",
278
+ " .dataframe tbody tr th:only-of-type {\n",
279
+ " vertical-align: middle;\n",
280
+ " }\n",
281
+ "\n",
282
+ " .dataframe tbody tr th {\n",
283
+ " vertical-align: top;\n",
284
+ " }\n",
285
+ "\n",
286
+ " .dataframe thead th {\n",
287
+ " text-align: right;\n",
288
+ " }\n",
289
+ "</style>\n",
290
+ "<table border=\"1\" class=\"dataframe\">\n",
291
+ " <thead>\n",
292
+ " <tr style=\"text-align: right;\">\n",
293
+ " <th></th>\n",
294
+ " <th>text</th>\n",
295
+ " <th>sentiment</th>\n",
296
+ " </tr>\n",
297
+ " </thead>\n",
298
+ " <tbody>\n",
299
+ " <tr>\n",
300
+ " <th>0</th>\n",
301
+ " <td>I`d have responded, if I were going</td>\n",
302
+ " <td>neutral</td>\n",
303
+ " </tr>\n",
304
+ " <tr>\n",
305
+ " <th>1</th>\n",
306
+ " <td>Sooo SAD I will miss you here in San Diego!!!</td>\n",
307
+ " <td>negative</td>\n",
308
+ " </tr>\n",
309
+ " <tr>\n",
310
+ " <th>2</th>\n",
311
+ " <td>my boss is bullying me...</td>\n",
312
+ " <td>negative</td>\n",
313
+ " </tr>\n",
314
+ " <tr>\n",
315
+ " <th>3</th>\n",
316
+ " <td>what interview! leave me alone</td>\n",
317
+ " <td>negative</td>\n",
318
+ " </tr>\n",
319
+ " <tr>\n",
320
+ " <th>4</th>\n",
321
+ " <td>Sons of ****, why couldn`t they put them on t...</td>\n",
322
+ " <td>negative</td>\n",
323
+ " </tr>\n",
324
+ " <tr>\n",
325
+ " <th>...</th>\n",
326
+ " <td>...</td>\n",
327
+ " <td>...</td>\n",
328
+ " </tr>\n",
329
+ " <tr>\n",
330
+ " <th>27476</th>\n",
331
+ " <td>wish we could come see u on Denver husband l...</td>\n",
332
+ " <td>negative</td>\n",
333
+ " </tr>\n",
334
+ " <tr>\n",
335
+ " <th>27477</th>\n",
336
+ " <td>I`ve wondered about rake to. The client has ...</td>\n",
337
+ " <td>negative</td>\n",
338
+ " </tr>\n",
339
+ " <tr>\n",
340
+ " <th>27478</th>\n",
341
+ " <td>Yay good for both of you. Enjoy the break - y...</td>\n",
342
+ " <td>positive</td>\n",
343
+ " </tr>\n",
344
+ " <tr>\n",
345
+ " <th>27479</th>\n",
346
+ " <td>But it was worth it ****.</td>\n",
347
+ " <td>positive</td>\n",
348
+ " </tr>\n",
349
+ " <tr>\n",
350
+ " <th>27480</th>\n",
351
+ " <td>All this flirting going on - The ATG smiles...</td>\n",
352
+ " <td>neutral</td>\n",
353
+ " </tr>\n",
354
+ " </tbody>\n",
355
+ "</table>\n",
356
+ "<p>27481 rows × 2 columns</p>\n",
357
+ "</div>"
358
+ ],
359
+ "text/plain": [
360
+ " text sentiment\n",
361
+ "0 I`d have responded, if I were going neutral\n",
362
+ "1 Sooo SAD I will miss you here in San Diego!!! negative\n",
363
+ "2 my boss is bullying me... negative\n",
364
+ "3 what interview! leave me alone negative\n",
365
+ "4 Sons of ****, why couldn`t they put them on t... negative\n",
366
+ "... ... ...\n",
367
+ "27476 wish we could come see u on Denver husband l... negative\n",
368
+ "27477 I`ve wondered about rake to. The client has ... negative\n",
369
+ "27478 Yay good for both of you. Enjoy the break - y... positive\n",
370
+ "27479 But it was worth it ****. positive\n",
371
+ "27480 All this flirting going on - The ATG smiles... neutral\n",
372
+ "\n",
373
+ "[27481 rows x 2 columns]"
374
+ ]
375
+ },
376
+ "execution_count": 3,
377
+ "metadata": {},
378
+ "output_type": "execute_result"
379
+ }
380
+ ],
381
+ "source": [
382
+ "df.drop(df.columns[[0,2,4,5,6,7,8,9]], axis=1, inplace=True)\n",
383
+ "df"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": 4,
389
+ "id": "296d66e2",
390
+ "metadata": {},
391
+ "outputs": [],
392
+ "source": [
393
+ "labels=df.sentiment.unique()"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 5,
399
+ "id": "99085d3e",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "label_dict={}\n",
404
+ "for id,label in enumerate(labels):\n",
405
+ " label_dict[label]=id"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": 6,
411
+ "id": "fa1e4160",
412
+ "metadata": {},
413
+ "outputs": [
414
+ {
415
+ "data": {
416
+ "text/plain": [
417
+ "{'neutral': 0, 'negative': 1, 'positive': 2}"
418
+ ]
419
+ },
420
+ "execution_count": 6,
421
+ "metadata": {},
422
+ "output_type": "execute_result"
423
+ }
424
+ ],
425
+ "source": [
426
+ "label_dict"
427
+ ]
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": 7,
432
+ "id": "eaa2c872",
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": [
436
+ "df['label']=df.sentiment.replace(label_dict)"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 8,
442
+ "id": "63ed05e3",
443
+ "metadata": {},
444
+ "outputs": [
445
+ {
446
+ "data": {
447
+ "text/html": [
448
+ "<div>\n",
449
+ "<style scoped>\n",
450
+ " .dataframe tbody tr th:only-of-type {\n",
451
+ " vertical-align: middle;\n",
452
+ " }\n",
453
+ "\n",
454
+ " .dataframe tbody tr th {\n",
455
+ " vertical-align: top;\n",
456
+ " }\n",
457
+ "\n",
458
+ " .dataframe thead th {\n",
459
+ " text-align: right;\n",
460
+ " }\n",
461
+ "</style>\n",
462
+ "<table border=\"1\" class=\"dataframe\">\n",
463
+ " <thead>\n",
464
+ " <tr style=\"text-align: right;\">\n",
465
+ " <th></th>\n",
466
+ " <th>text</th>\n",
467
+ " <th>sentiment</th>\n",
468
+ " <th>label</th>\n",
469
+ " </tr>\n",
470
+ " </thead>\n",
471
+ " <tbody>\n",
472
+ " <tr>\n",
473
+ " <th>0</th>\n",
474
+ " <td>I`d have responded, if I were going</td>\n",
475
+ " <td>neutral</td>\n",
476
+ " <td>0</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <th>1</th>\n",
480
+ " <td>Sooo SAD I will miss you here in San Diego!!!</td>\n",
481
+ " <td>negative</td>\n",
482
+ " <td>1</td>\n",
483
+ " </tr>\n",
484
+ " <tr>\n",
485
+ " <th>2</th>\n",
486
+ " <td>my boss is bullying me...</td>\n",
487
+ " <td>negative</td>\n",
488
+ " <td>1</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <th>3</th>\n",
492
+ " <td>what interview! leave me alone</td>\n",
493
+ " <td>negative</td>\n",
494
+ " <td>1</td>\n",
495
+ " </tr>\n",
496
+ " <tr>\n",
497
+ " <th>4</th>\n",
498
+ " <td>Sons of ****, why couldn`t they put them on t...</td>\n",
499
+ " <td>negative</td>\n",
500
+ " <td>1</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <th>...</th>\n",
504
+ " <td>...</td>\n",
505
+ " <td>...</td>\n",
506
+ " <td>...</td>\n",
507
+ " </tr>\n",
508
+ " <tr>\n",
509
+ " <th>27476</th>\n",
510
+ " <td>wish we could come see u on Denver husband l...</td>\n",
511
+ " <td>negative</td>\n",
512
+ " <td>1</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <th>27477</th>\n",
516
+ " <td>I`ve wondered about rake to. The client has ...</td>\n",
517
+ " <td>negative</td>\n",
518
+ " <td>1</td>\n",
519
+ " </tr>\n",
520
+ " <tr>\n",
521
+ " <th>27478</th>\n",
522
+ " <td>Yay good for both of you. Enjoy the break - y...</td>\n",
523
+ " <td>positive</td>\n",
524
+ " <td>2</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <th>27479</th>\n",
528
+ " <td>But it was worth it ****.</td>\n",
529
+ " <td>positive</td>\n",
530
+ " <td>2</td>\n",
531
+ " </tr>\n",
532
+ " <tr>\n",
533
+ " <th>27480</th>\n",
534
+ " <td>All this flirting going on - The ATG smiles...</td>\n",
535
+ " <td>neutral</td>\n",
536
+ " <td>0</td>\n",
537
+ " </tr>\n",
538
+ " </tbody>\n",
539
+ "</table>\n",
540
+ "<p>27481 rows × 3 columns</p>\n",
541
+ "</div>"
542
+ ],
543
+ "text/plain": [
544
+ " text sentiment label\n",
545
+ "0 I`d have responded, if I were going neutral 0\n",
546
+ "1 Sooo SAD I will miss you here in San Diego!!! negative 1\n",
547
+ "2 my boss is bullying me... negative 1\n",
548
+ "3 what interview! leave me alone negative 1\n",
549
+ "4 Sons of ****, why couldn`t they put them on t... negative 1\n",
550
+ "... ... ... ...\n",
551
+ "27476 wish we could come see u on Denver husband l... negative 1\n",
552
+ "27477 I`ve wondered about rake to. The client has ... negative 1\n",
553
+ "27478 Yay good for both of you. Enjoy the break - y... positive 2\n",
554
+ "27479 But it was worth it ****. positive 2\n",
555
+ "27480 All this flirting going on - The ATG smiles... neutral 0\n",
556
+ "\n",
557
+ "[27481 rows x 3 columns]"
558
+ ]
559
+ },
560
+ "execution_count": 8,
561
+ "metadata": {},
562
+ "output_type": "execute_result"
563
+ }
564
+ ],
565
+ "source": [
566
+ "df"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": 9,
572
+ "id": "b1dc846d",
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "from sklearn.model_selection import train_test_split\n",
577
+ "X_train,X_val,y_train,y_val=train_test_split(df.index.values,df['label'].values,train_size=0.8,random_state=0)"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 10,
583
+ "id": "418fb78e",
584
+ "metadata": {},
585
+ "outputs": [],
586
+ "source": [
587
+ "df['data_type'] = 'not_set'\n",
588
+ "df.loc[X_train, 'data_type'] = 'train'\n",
589
+ "df.loc[X_val, 'data_type'] = 'val'"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": 11,
595
+ "id": "ceed0315",
596
+ "metadata": {},
597
+ "outputs": [],
598
+ "source": [
599
+ "df['text']=df['text'].astype(str)"
600
+ ]
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "execution_count": 12,
605
+ "id": "d3c74651",
606
+ "metadata": {},
607
+ "outputs": [
608
+ {
609
+ "data": {
610
+ "text/html": [
611
+ "<div>\n",
612
+ "<style scoped>\n",
613
+ " .dataframe tbody tr th:only-of-type {\n",
614
+ " vertical-align: middle;\n",
615
+ " }\n",
616
+ "\n",
617
+ " .dataframe tbody tr th {\n",
618
+ " vertical-align: top;\n",
619
+ " }\n",
620
+ "\n",
621
+ " .dataframe thead th {\n",
622
+ " text-align: right;\n",
623
+ " }\n",
624
+ "</style>\n",
625
+ "<table border=\"1\" class=\"dataframe\">\n",
626
+ " <thead>\n",
627
+ " <tr style=\"text-align: right;\">\n",
628
+ " <th></th>\n",
629
+ " <th>text</th>\n",
630
+ " <th>sentiment</th>\n",
631
+ " <th>label</th>\n",
632
+ " <th>data_type</th>\n",
633
+ " </tr>\n",
634
+ " </thead>\n",
635
+ " <tbody>\n",
636
+ " <tr>\n",
637
+ " <th>0</th>\n",
638
+ " <td>I`d have responded, if I were going</td>\n",
639
+ " <td>neutral</td>\n",
640
+ " <td>0</td>\n",
641
+ " <td>train</td>\n",
642
+ " </tr>\n",
643
+ " <tr>\n",
644
+ " <th>1</th>\n",
645
+ " <td>Sooo SAD I will miss you here in San Diego!!!</td>\n",
646
+ " <td>negative</td>\n",
647
+ " <td>1</td>\n",
648
+ " <td>train</td>\n",
649
+ " </tr>\n",
650
+ " <tr>\n",
651
+ " <th>2</th>\n",
652
+ " <td>my boss is bullying me...</td>\n",
653
+ " <td>negative</td>\n",
654
+ " <td>1</td>\n",
655
+ " <td>train</td>\n",
656
+ " </tr>\n",
657
+ " <tr>\n",
658
+ " <th>3</th>\n",
659
+ " <td>what interview! leave me alone</td>\n",
660
+ " <td>negative</td>\n",
661
+ " <td>1</td>\n",
662
+ " <td>train</td>\n",
663
+ " </tr>\n",
664
+ " <tr>\n",
665
+ " <th>4</th>\n",
666
+ " <td>Sons of ****, why couldn`t they put them on t...</td>\n",
667
+ " <td>negative</td>\n",
668
+ " <td>1</td>\n",
669
+ " <td>val</td>\n",
670
+ " </tr>\n",
671
+ " <tr>\n",
672
+ " <th>...</th>\n",
673
+ " <td>...</td>\n",
674
+ " <td>...</td>\n",
675
+ " <td>...</td>\n",
676
+ " <td>...</td>\n",
677
+ " </tr>\n",
678
+ " <tr>\n",
679
+ " <th>27476</th>\n",
680
+ " <td>wish we could come see u on Denver husband l...</td>\n",
681
+ " <td>negative</td>\n",
682
+ " <td>1</td>\n",
683
+ " <td>train</td>\n",
684
+ " </tr>\n",
685
+ " <tr>\n",
686
+ " <th>27477</th>\n",
687
+ " <td>I`ve wondered about rake to. The client has ...</td>\n",
688
+ " <td>negative</td>\n",
689
+ " <td>1</td>\n",
690
+ " <td>train</td>\n",
691
+ " </tr>\n",
692
+ " <tr>\n",
693
+ " <th>27478</th>\n",
694
+ " <td>Yay good for both of you. Enjoy the break - y...</td>\n",
695
+ " <td>positive</td>\n",
696
+ " <td>2</td>\n",
697
+ " <td>val</td>\n",
698
+ " </tr>\n",
699
+ " <tr>\n",
700
+ " <th>27479</th>\n",
701
+ " <td>But it was worth it ****.</td>\n",
702
+ " <td>positive</td>\n",
703
+ " <td>2</td>\n",
704
+ " <td>val</td>\n",
705
+ " </tr>\n",
706
+ " <tr>\n",
707
+ " <th>27480</th>\n",
708
+ " <td>All this flirting going on - The ATG smiles...</td>\n",
709
+ " <td>neutral</td>\n",
710
+ " <td>0</td>\n",
711
+ " <td>val</td>\n",
712
+ " </tr>\n",
713
+ " </tbody>\n",
714
+ "</table>\n",
715
+ "<p>27481 rows × 4 columns</p>\n",
716
+ "</div>"
717
+ ],
718
+ "text/plain": [
719
+ " text sentiment label \\\n",
720
+ "0 I`d have responded, if I were going neutral 0 \n",
721
+ "1 Sooo SAD I will miss you here in San Diego!!! negative 1 \n",
722
+ "2 my boss is bullying me... negative 1 \n",
723
+ "3 what interview! leave me alone negative 1 \n",
724
+ "4 Sons of ****, why couldn`t they put them on t... negative 1 \n",
725
+ "... ... ... ... \n",
726
+ "27476 wish we could come see u on Denver husband l... negative 1 \n",
727
+ "27477 I`ve wondered about rake to. The client has ... negative 1 \n",
728
+ "27478 Yay good for both of you. Enjoy the break - y... positive 2 \n",
729
+ "27479 But it was worth it ****. positive 2 \n",
730
+ "27480 All this flirting going on - The ATG smiles... neutral 0 \n",
731
+ "\n",
732
+ " data_type \n",
733
+ "0 train \n",
734
+ "1 train \n",
735
+ "2 train \n",
736
+ "3 train \n",
737
+ "4 val \n",
738
+ "... ... \n",
739
+ "27476 train \n",
740
+ "27477 train \n",
741
+ "27478 val \n",
742
+ "27479 val \n",
743
+ "27480 val \n",
744
+ "\n",
745
+ "[27481 rows x 4 columns]"
746
+ ]
747
+ },
748
+ "execution_count": 12,
749
+ "metadata": {},
750
+ "output_type": "execute_result"
751
+ }
752
+ ],
753
+ "source": [
754
+ "df"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 13,
760
+ "id": "b018cca8",
761
+ "metadata": {},
762
+ "outputs": [
763
+ {
764
+ "data": {
765
+ "text/plain": [
766
+ "array([' I`d have responded, if I were going',\n",
767
+ " ' Sooo SAD I will miss you here in San Diego!!!',\n",
768
+ " 'my boss is bullying me...', ...,\n",
769
+ " 'So I get up early and I feel good about the day. I walk to work and I`m feeling alright. But guess what... I don`t work today.',\n",
770
+ " ' wish we could come see u on Denver husband lost his job and can`t afford it',\n",
771
+ " ' I`ve wondered about rake to. The client has made it clear .NET only, don`t force devs to learn a new lang #agile #ccnet'],\n",
772
+ " dtype=object)"
773
+ ]
774
+ },
775
+ "execution_count": 13,
776
+ "metadata": {},
777
+ "output_type": "execute_result"
778
+ }
779
+ ],
780
+ "source": [
781
+ "df[df.data_type=='train'].text.values"
782
+ ]
783
+ },
784
+ {
785
+ "cell_type": "code",
786
+ "execution_count": 14,
787
+ "id": "0d03c58e",
788
+ "metadata": {},
789
+ "outputs": [
790
+ {
791
+ "name": "stderr",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "c:\\Users\\KARAN\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
795
+ " from .autonotebook import tqdm as notebook_tqdm\n"
796
+ ]
797
+ }
798
+ ],
799
+ "source": [
800
+ "from transformers import BertTokenizer\n",
801
+ "from torch.utils.data import TensorDataset\n",
802
+ "import torch"
803
+ ]
804
+ },
805
+ {
806
+ "cell_type": "code",
807
+ "execution_count": 15,
808
+ "id": "1fc7bfd6",
809
+ "metadata": {},
810
+ "outputs": [],
811
+ "source": [
812
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', \n",
813
+ " do_lower_case=True)"
814
+ ]
815
+ },
816
+ {
817
+ "cell_type": "code",
818
+ "execution_count": 16,
819
+ "id": "ea3521d9",
820
+ "metadata": {},
821
+ "outputs": [],
822
+ "source": [
823
+ "encoded_data_train = tokenizer.batch_encode_plus(\n",
824
+ " df[df.data_type=='train'].text.values.tolist(), \n",
825
+ " add_special_tokens=True, \n",
826
+ " return_attention_mask=True, \n",
827
+ " max_length=256,\n",
828
+ " padding='max_length',\n",
829
+ " truncation=True,\n",
830
+ " return_tensors='pt',\n",
831
+ ")\n",
832
+ "\n",
833
+ "encoded_data_val = tokenizer.batch_encode_plus(\n",
834
+ " df[df.data_type=='val'].text.values.tolist(), \n",
835
+ " add_special_tokens=True, \n",
836
+ " return_attention_mask=True, \n",
837
+ " max_length=256,\n",
838
+ " truncation=True,\n",
839
+ " padding='max_length', \n",
840
+ " return_tensors='pt'\n",
841
+ ")\n",
842
+ "\n",
843
+ "\n",
844
+ "input_ids_train = encoded_data_train['input_ids']\n",
845
+ "attention_masks_train = encoded_data_train['attention_mask']\n",
846
+ "labels_train = torch.tensor(df[df.data_type=='train'].label.values)\n",
847
+ "\n",
848
+ "input_ids_val = encoded_data_val['input_ids']\n",
849
+ "attention_masks_val = encoded_data_val['attention_mask']\n",
850
+ "labels_val = torch.tensor(df[df.data_type=='val'].label.values)"
851
+ ]
852
+ },
853
+ {
854
+ "cell_type": "code",
855
+ "execution_count": 17,
856
+ "id": "d56c3636",
857
+ "metadata": {},
858
+ "outputs": [],
859
+ "source": [
860
+ "train_data=TensorDataset(input_ids_train,attention_masks_train,labels_train)\n",
861
+ "val_data=TensorDataset(input_ids_val,attention_masks_val,labels_val)"
862
+ ]
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "execution_count": 18,
867
+ "id": "c1e6192b",
868
+ "metadata": {},
869
+ "outputs": [
870
+ {
871
+ "name": "stderr",
872
+ "output_type": "stream",
873
+ "text": [
874
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
875
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
876
+ ]
877
+ }
878
+ ],
879
+ "source": [
880
+ "from transformers import BertForSequenceClassification\n",
881
+ "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\",\n",
882
+ " num_labels=len(label_dict),\n",
883
+ " output_attentions=False,\n",
884
+ " output_hidden_states=False)"
885
+ ]
886
+ },
887
+ {
888
+ "cell_type": "code",
889
+ "execution_count": 19,
890
+ "id": "18b4fca0",
891
+ "metadata": {},
892
+ "outputs": [],
893
+ "source": [
894
+ "from torch.utils.data import DataLoader,RandomSampler,SequentialSampler\n",
895
+ "train_loader=DataLoader(\n",
896
+ " train_data,\n",
897
+ " sampler=RandomSampler(train_data),\n",
898
+ " batch_size=4\n",
899
+ ")\n",
900
+ "\n",
901
+ "val_loader=DataLoader(\n",
902
+ " val_data,\n",
903
+ " sampler=SequentialSampler(val_data),\n",
904
+ " batch_size=4\n",
905
+ ")"
906
+ ]
907
+ },
908
+ {
909
+ "cell_type": "code",
910
+ "execution_count": 20,
911
+ "id": "b3f37358",
912
+ "metadata": {},
913
+ "outputs": [],
914
+ "source": [
915
+ "from transformers import get_linear_schedule_with_warmup\n",
916
+ "from torch.optim import AdamW\n",
917
+ "\n",
918
+ "optimizer=AdamW(\n",
919
+ " model.parameters(),\n",
920
+ " lr=1e-5,\n",
921
+ " eps=1e-8\n",
922
+ ")\n",
923
+ "\n",
924
+ "epochs=5\n",
925
+ "\n",
926
+ "scheduler=get_linear_schedule_with_warmup(\n",
927
+ " optimizer,\n",
928
+ " num_warmup_steps=0,\n",
929
+ " num_training_steps=len(train_data)*epochs\n",
930
+ ")"
931
+ ]
932
+ },
933
+ {
934
+ "cell_type": "code",
935
+ "execution_count": 21,
936
+ "id": "a5ccc6d8",
937
+ "metadata": {},
938
+ "outputs": [
939
+ {
940
+ "name": "stdout",
941
+ "output_type": "stream",
942
+ "text": [
943
+ "cuda\n"
944
+ ]
945
+ }
946
+ ],
947
+ "source": [
948
+ "device = torch.device('cuda')\n",
949
+ "model.to(device)\n",
950
+ "print(device)"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": 22,
956
+ "id": "1ad2f635",
957
+ "metadata": {},
958
+ "outputs": [],
959
+ "source": [
960
+ "import numpy as np\n",
961
+ "def eval(val_loader,model):\n",
962
+ " model.eval()\n",
963
+ " loss_val_total=0\n",
964
+ " preds,true=[],[]\n",
965
+ " \n",
966
+ " for batch in val_loader:\n",
967
+ " batch=tuple(b.to(device) for b in batch)\n",
968
+ " inputs = {'input_ids': batch[0],\n",
969
+ " 'attention_mask': batch[1],\n",
970
+ " 'labels': batch[2],\n",
971
+ " }\n",
972
+ " \n",
973
+ " with torch.no_grad():\n",
974
+ " outputs=model(**inputs)\n",
975
+ " \n",
976
+ " loss=outputs[0]\n",
977
+ " logits=outputs[1]\n",
978
+ " loss_val_total+=loss.item()\n",
979
+ " logits=logits.detach().cpu().numpy()\n",
980
+ " labels=inputs['labels'].cpu().numpy()\n",
981
+ " preds.append(logits)\n",
982
+ " true.append(labels)\n",
983
+ " \n",
984
+ " loss_val_avg=loss_val_total/len(val_loader)\n",
985
+ " predictions=np.concatenate(preds,axis=0)\n",
986
+ " true_vals=np.concatenate(true,axis=0)\n",
987
+ " \n",
988
+ " return loss_val_avg,predictions,true_vals"
989
+ ]
990
+ },
991
+ {
992
+ "cell_type": "code",
993
+ "execution_count": 138,
994
+ "id": "05f1146d",
995
+ "metadata": {},
996
+ "outputs": [
997
+ {
998
+ "name": "stderr",
999
+ "output_type": "stream",
1000
+ "text": [
1001
+ " \r"
1002
+ ]
1003
+ },
1004
+ {
1005
+ "name": "stdout",
1006
+ "output_type": "stream",
1007
+ "text": [
1008
+ "\n",
1009
+ "Epoch 1\n",
1010
+ "Training loss: 1.1087568206265244\n",
1011
+ "Validation loss: 1.1073771828738126\n"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "name": "stderr",
1016
+ "output_type": "stream",
1017
+ "text": [
1018
+ " \r"
1019
+ ]
1020
+ },
1021
+ {
1022
+ "name": "stdout",
1023
+ "output_type": "stream",
1024
+ "text": [
1025
+ "\n",
1026
+ "Epoch 2\n",
1027
+ "Training loss: 1.1035335373561803\n",
1028
+ "Validation loss: 1.0943231875246222\n"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "name": "stderr",
1033
+ "output_type": "stream",
1034
+ "text": [
1035
+ " \r"
1036
+ ]
1037
+ },
1038
+ {
1039
+ "name": "stdout",
1040
+ "output_type": "stream",
1041
+ "text": [
1042
+ "\n",
1043
+ "Epoch 3\n",
1044
+ "Training loss: 1.0946122174852106\n",
1045
+ "Validation loss: 1.0898548677617854\n"
1046
+ ]
1047
+ },
1048
+ {
1049
+ "name": "stderr",
1050
+ "output_type": "stream",
1051
+ "text": [
1052
+ " \r"
1053
+ ]
1054
+ },
1055
+ {
1056
+ "name": "stdout",
1057
+ "output_type": "stream",
1058
+ "text": [
1059
+ "\n",
1060
+ "Epoch 4\n",
1061
+ "Training loss: 1.0907055499164993\n",
1062
+ "Validation loss: 1.0901057242480192\n"
1063
+ ]
1064
+ },
1065
+ {
1066
+ "name": "stderr",
1067
+ "output_type": "stream",
1068
+ "text": [
1069
+ " \r"
1070
+ ]
1071
+ },
1072
+ {
1073
+ "name": "stdout",
1074
+ "output_type": "stream",
1075
+ "text": [
1076
+ "\n",
1077
+ "Epoch 5\n",
1078
+ "Training loss: 1.0898382831825786\n",
1079
+ "Validation loss: 1.0943078843030063\n"
1080
+ ]
1081
+ }
1082
+ ],
1083
+ "source": [
1084
+ "from tqdm import tqdm\n",
1085
+ "for epoch in range(1,epochs+1):\n",
1086
+ " model.train()\n",
1087
+ " loss_train_total=0\n",
1088
+ " progress_bar = tqdm(train_loader,desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)\n",
1089
+ " for batch in progress_bar:\n",
1090
+ "\n",
1091
+ " model.zero_grad()\n",
1092
+ " \n",
1093
+ " batch = tuple(b.to(device) for b in batch)\n",
1094
+ " \n",
1095
+ " inputs = {'input_ids': batch[0],\n",
1096
+ " 'attention_mask': batch[1],\n",
1097
+ " 'labels': batch[2],\n",
1098
+ " } \n",
1099
+ "\n",
1100
+ " outputs = model(**inputs)\n",
1101
+ " \n",
1102
+ " loss = outputs[0]\n",
1103
+ " loss_train_total += loss.item()\n",
1104
+ " loss.backward()\n",
1105
+ "\n",
1106
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
1107
+ "\n",
1108
+ " optimizer.step()\n",
1109
+ " scheduler.step()\n",
1110
+ " \n",
1111
+ " progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})\n",
1112
+ " torch.save(model.state_dict(), f'finetuned_BERT_epoch_{epoch}.model')\n",
1113
+ " \n",
1114
+ " tqdm.write(f'\\nEpoch {epoch}')\n",
1115
+ " \n",
1116
+ " loss_train_avg = loss_train_total/len(train_loader) \n",
1117
+ " tqdm.write(f'Training loss: {loss_train_avg}')\n",
1118
+ " \n",
1119
+ " val_loss, predictions, true_vals = eval(val_loader)\n",
1120
+ " tqdm.write(f'Validation loss: {val_loss}')"
1121
+ ]
1122
+ },
1123
+ {
1124
+ "cell_type": "code",
1125
+ "execution_count": 23,
1126
+ "id": "e9f7735a",
1127
+ "metadata": {},
1128
+ "outputs": [
1129
+ {
1130
+ "name": "stderr",
1131
+ "output_type": "stream",
1132
+ "text": [
1133
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
1134
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1135
+ ]
1136
+ },
1137
+ {
1138
+ "data": {
1139
+ "text/plain": [
1140
+ "BertForSequenceClassification(\n",
1141
+ " (bert): BertModel(\n",
1142
+ " (embeddings): BertEmbeddings(\n",
1143
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
1144
+ " (position_embeddings): Embedding(512, 768)\n",
1145
+ " (token_type_embeddings): Embedding(2, 768)\n",
1146
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1147
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1148
+ " )\n",
1149
+ " (encoder): BertEncoder(\n",
1150
+ " (layer): ModuleList(\n",
1151
+ " (0-11): 12 x BertLayer(\n",
1152
+ " (attention): BertAttention(\n",
1153
+ " (self): BertSelfAttention(\n",
1154
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
1155
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
1156
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
1157
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1158
+ " )\n",
1159
+ " (output): BertSelfOutput(\n",
1160
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
1161
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1162
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1163
+ " )\n",
1164
+ " )\n",
1165
+ " (intermediate): BertIntermediate(\n",
1166
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
1167
+ " (intermediate_act_fn): GELUActivation()\n",
1168
+ " )\n",
1169
+ " (output): BertOutput(\n",
1170
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
1171
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1172
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1173
+ " )\n",
1174
+ " )\n",
1175
+ " )\n",
1176
+ " )\n",
1177
+ " (pooler): BertPooler(\n",
1178
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
1179
+ " (activation): Tanh()\n",
1180
+ " )\n",
1181
+ " )\n",
1182
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1183
+ " (classifier): Linear(in_features=768, out_features=3, bias=True)\n",
1184
+ ")"
1185
+ ]
1186
+ },
1187
+ "execution_count": 23,
1188
+ "metadata": {},
1189
+ "output_type": "execute_result"
1190
+ }
1191
+ ],
1192
+ "source": [
1193
+ "model=BertForSequenceClassification.from_pretrained(\n",
1194
+ " 'bert-base-uncased',\n",
1195
+ " num_labels=len(label_dict),\n",
1196
+ " output_attentions=False,\n",
1197
+ " output_hidden_states=False\n",
1198
+ ")\n",
1199
+ " \n",
1200
+ "model.to(device)"
1201
+ ]
1202
+ },
1203
+ {
1204
+ "cell_type": "code",
1205
+ "execution_count": 40,
1206
+ "id": "2cb2cb43",
1207
+ "metadata": {},
1208
+ "outputs": [
1209
+ {
1210
+ "data": {
1211
+ "text/plain": [
1212
+ "<All keys matched successfully>"
1213
+ ]
1214
+ },
1215
+ "execution_count": 40,
1216
+ "metadata": {},
1217
+ "output_type": "execute_result"
1218
+ }
1219
+ ],
1220
+ "source": [
1221
+ "model.load_state_dict(torch.load('finetuned_BERT_epoch_2.model',map_location=torch.device('cpu')))"
1222
+ ]
1223
+ },
1224
+ {
1225
+ "cell_type": "code",
1226
+ "execution_count": 41,
1227
+ "id": "86053301",
1228
+ "metadata": {},
1229
+ "outputs": [],
1230
+ "source": [
1231
+ "loss,predictions,true_vals=eval(val_loader,model)"
1232
+ ]
1233
+ },
1234
+ {
1235
+ "cell_type": "code",
1236
+ "execution_count": null,
1237
+ "id": "26089e26",
1238
+ "metadata": {},
1239
+ "outputs": [
1240
+ {
1241
+ "data": {
1242
+ "text/plain": [
1243
+ "array([[-0.94487095, 2.4501007 , -2.4328873 ],\n",
1244
+ " [ 3.0208707 , -1.7925887 , 0.409608 ],\n",
1245
+ " [-1.245395 , 2.8607914 , -2.7080884 ],\n",
1246
+ " ...,\n",
1247
+ " [-0.13207848, -2.0695374 , 3.5249124 ],\n",
1248
+ " [-1.0361273 , -2.475614 , 3.9253955 ],\n",
1249
+ " [-0.3563956 , -2.541143 , 3.703467 ]], dtype=float32)"
1250
+ ]
1251
+ },
1252
+ "execution_count": 34,
1253
+ "metadata": {},
1254
+ "output_type": "execute_result"
1255
+ }
1256
+ ],
1257
+ "source": [
1258
+ "predictions"
1259
+ ]
1260
+ },
1261
+ {
1262
+ "cell_type": "code",
1263
+ "execution_count": null,
1264
+ "id": "bf20a5ca",
1265
+ "metadata": {},
1266
+ "outputs": [
1267
+ {
1268
+ "data": {
1269
+ "text/plain": [
1270
+ "array([1, 0, 1, ..., 2, 2, 2], dtype=int64)"
1271
+ ]
1272
+ },
1273
+ "execution_count": 35,
1274
+ "metadata": {},
1275
+ "output_type": "execute_result"
1276
+ }
1277
+ ],
1278
+ "source": [
1279
+ "preds_flat = np.argmax(predictions, axis=1).flatten()\n",
1280
+ "preds_flat"
1281
+ ]
1282
+ },
1283
+ {
1284
+ "cell_type": "code",
1285
+ "execution_count": null,
1286
+ "id": "70d73cf6",
1287
+ "metadata": {},
1288
+ "outputs": [
1289
+ {
1290
+ "data": {
1291
+ "text/plain": [
1292
+ "array([1, 0, 1, ..., 2, 2, 0], dtype=int64)"
1293
+ ]
1294
+ },
1295
+ "execution_count": 36,
1296
+ "metadata": {},
1297
+ "output_type": "execute_result"
1298
+ }
1299
+ ],
1300
+ "source": [
1301
+ "true_vals"
1302
+ ]
1303
+ },
1304
+ {
1305
+ "cell_type": "code",
1306
+ "execution_count": null,
1307
+ "id": "f4d78070",
1308
+ "metadata": {},
1309
+ "outputs": [],
1310
+ "source": [
1311
+ "def accuracy_per_class(preds, labels):\n",
1312
+ " label_dict_inverse = {v: k for k, v in label_dict.items()}\n",
1313
+ " \n",
1314
+ " preds_flat = np.argmax(preds, axis=1).flatten()\n",
1315
+ " labels_flat = labels.flatten()\n",
1316
+ "\n",
1317
+ " for label in np.unique(labels_flat):\n",
1318
+ " y_preds = preds_flat[labels_flat==label]\n",
1319
+ " y_true = labels_flat[labels_flat==label]\n",
1320
+ " print(f'Class: {label_dict_inverse[label]}')\n",
1321
+ " print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\\n')"
1322
+ ]
1323
+ },
1324
+ {
1325
+ "cell_type": "code",
1326
+ "execution_count": null,
1327
+ "id": "46eb06a4",
1328
+ "metadata": {},
1329
+ "outputs": [
1330
+ {
1331
+ "name": "stdout",
1332
+ "output_type": "stream",
1333
+ "text": [
1334
+ "Class: neutral\n",
1335
+ "Accuracy: 1571/2195\n",
1336
+ "\n",
1337
+ "Class: negative\n",
1338
+ "Accuracy: 1230/1563\n",
1339
+ "\n",
1340
+ "Class: positive\n",
1341
+ "Accuracy: 1501/1739\n",
1342
+ "\n"
1343
+ ]
1344
+ }
1345
+ ],
1346
+ "source": [
1347
+ "accuracy_per_class(predictions, true_vals)"
1348
+ ]
1349
+ },
1350
+ {
1351
+ "cell_type": "code",
1352
+ "execution_count": null,
1353
+ "id": "284950e9",
1354
+ "metadata": {},
1355
+ "outputs": [
1356
+ {
1357
+ "name": "stdout",
1358
+ "output_type": "stream",
1359
+ "text": [
1360
+ " precision recall f1-score support\n",
1361
+ "\n",
1362
+ " 0 0.76 0.72 0.74 2195\n",
1363
+ " 1 0.80 0.79 0.79 1563\n",
1364
+ " 2 0.80 0.86 0.83 1739\n",
1365
+ "\n",
1366
+ " accuracy 0.78 5497\n",
1367
+ " macro avg 0.78 0.79 0.79 5497\n",
1368
+ "weighted avg 0.78 0.78 0.78 5497\n",
1369
+ "\n"
1370
+ ]
1371
+ }
1372
+ ],
1373
+ "source": [
1374
+ "from sklearn.metrics import classification_report\n",
1375
+ "print(classification_report(true_vals,preds_flat))"
1376
+ ]
1377
+ }
1378
+ ],
1379
+ "metadata": {
1380
+ "kernelspec": {
1381
+ "display_name": "Python 3 (ipykernel)",
1382
+ "language": "python",
1383
+ "name": "python3"
1384
+ },
1385
+ "language_info": {
1386
+ "codemirror_mode": {
1387
+ "name": "ipython",
1388
+ "version": 3
1389
+ },
1390
+ "file_extension": ".py",
1391
+ "mimetype": "text/x-python",
1392
+ "name": "python",
1393
+ "nbconvert_exporter": "python",
1394
+ "pygments_lexer": "ipython3",
1395
+ "version": "3.11.6"
1396
+ }
1397
+ },
1398
+ "nbformat": 4,
1399
+ "nbformat_minor": 5
1400
+ }
finetuned_BERT_epoch_1.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b613d1cef3a1c6b4954e0ace43e7038a307eac4c5c2bbfa277b62433f743e5e
3
+ size 438022947
finetuned_BERT_epoch_2.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:085ac801240230dea513a1d851552f8a76a6f96a27eb965f458703d7ce626129
3
+ size 438022947
finetuned_BERT_epoch_3.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2dd1ccc341b74f8a933befbb518c950b660e248560e5a220e2a4a7a0a48f6c6
3
+ size 438022947
finetuned_BERT_epoch_4.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abb5978e32e26fb66bbe6d94a96351af8cb1f9a48df0d018f0c1157f5176fef1
3
+ size 438022947
finetuned_BERT_epoch_5.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54cdde522a80be579c662dd01bd4199fc79278e6782a0e4560d539f568cf819b
3
+ size 438022947