vikranth1111 committed
Commit a4b94b2 · 1 Parent(s): fff24c6

Upload 16 files

01_kpy_first_model_errors.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
02_error_analysis_first_model.ipynb ADDED
@@ -0,0 +1,612 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[nltk_data] Downloading package stopwords to\n",
+ "[nltk_data] C:\\Users\\kurti\\AppData\\Roaming\\nltk_data...\n",
+ "[nltk_data] Package stopwords is already up-to-date!\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import re\n",
+ "import nltk\n",
+ "import string\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from nltk.corpus import stopwords\n",
+ "from nltk.stem import PorterStemmer\n",
+ "from nltk.tokenize import TweetTokenizer\n",
+ "from sklearn.naive_bayes import MultinomialNB\n",
+ "from sklearn.model_selection import StratifiedKFold\n",
+ "from sklearn.feature_extraction.text import CountVectorizer\n",
+ "\n",
+ "nltk.download(\"stopwords\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def process_tweet(tweet):\n",
+ " \"\"\"\n",
+ " Process tweet function.\n",
+ " Input:\n",
+ " tweet: a string containing a tweet\n",
+ " Returns:\n",
+ " tweets_clean: a string of the processed, stemmed tokens\n",
+ "\n",
+ " *Taken from Coursera NLP Specialization Course 1, week 1 programming\n",
+ " assignment*\n",
+ " \"\"\"\n",
+ " stemmer = PorterStemmer()\n",
+ " stopwords_english = stopwords.words('english')\n",
+ " # remove stock market tickers like $GE\n",
+ " tweet = re.sub(r'\\$\\w*', '', str(tweet))\n",
+ " # remove old style retweet text \"RT\"\n",
+ " tweet = re.sub(r'^RT[\\s]+', '', str(tweet))\n",
+ " # remove hyperlinks\n",
+ " tweet = re.sub(r'https?:\\/\\/.*[\\r\\n]*', '', str(tweet))\n",
+ " # remove hashtags\n",
+ " # only removing the hash # sign from the word\n",
+ " tweet = re.sub(r'#', '', str(tweet))\n",
+ " # tokenize tweets\n",
+ " tokenizer = TweetTokenizer(preserve_case=False, strip_handles=True,\n",
+ " reduce_len=True)\n",
+ " tweet_tokens = tokenizer.tokenize(tweet)\n",
+ "\n",
+ " tweets_clean = []\n",
+ " for word in tweet_tokens:\n",
+ " if (word not in stopwords_english and # remove stopwords\n",
+ " word not in string.punctuation): # remove punctuation\n",
+ " # tweets_clean.append(word)\n",
+ " stem_word = stemmer.stem(word) # stemming word\n",
+ " tweets_clean.append(stem_word)\n",
+ "\n",
+ " return \" \".join(tweets_clean)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>id</th>\n",
+ " <th>all_text</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>3796</td>\n",
+ " <td>new weapon caus un-imagin destruct destructionnon</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>3185</td>\n",
+ " <td>f @ing thing gishwh got soak delug go pad tamp...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>7769</td>\n",
+ " <td>dt rt ‰ ûïthe col polic catch pickpocket liver...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>191</td>\n",
+ " <td>aftershock back school kick great want thank e...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>9810</td>\n",
+ " <td>respons trauma children addict develop defens ...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>5</th>\n",
+ " <td>7934</td>\n",
+ " <td>look like got caught rainstorm amaz disgust ti...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>6</th>\n",
+ " <td>2538</td>\n",
+ " <td>favorit ladi came volunt meet hope join youth ...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7</th>\n",
+ " <td>2611</td>\n",
+ " <td>ux fail emv peopl want insert remov quickli li...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>8</th>\n",
+ " <td>9756</td>\n",
+ " <td>can't find ariana grand shirt fuck tragedytrag...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>9</th>\n",
+ " <td>6254</td>\n",
+ " <td>murder stori america ‰ ûª first hijack</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " id all_text\n",
+ "0 3796 new weapon caus un-imagin destruct destructionnon\n",
+ "1 3185 f @ing thing gishwh got soak delug go pad tamp...\n",
+ "2 7769 dt rt ‰ ûïthe col polic catch pickpocket liver...\n",
+ "3 191 aftershock back school kick great want thank e...\n",
+ "4 9810 respons trauma children addict develop defens ...\n",
+ "5 7934 look like got caught rainstorm amaz disgust ti...\n",
+ "6 2538 favorit ladi came volunt meet hope join youth ...\n",
+ "7 2611 ux fail emv peopl want insert remov quickli li...\n",
+ "8 9756 can't find ariana grand shirt fuck tragedytrag...\n",
+ "9 6254 murder stori america ‰ ûª first hijack"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# read train data\n",
+ "df = pd.read_csv(\"../inputs/train.csv\")\n",
+ "# shuffle data\n",
+ "df = df.sample(frac=1, random_state=42).reset_index(drop=True)\n",
+ "# create new column \"all_text\"\n",
+ "df[\"all_text\"] = df[\"text\"] + df[\"keyword\"].fillna(\"none\") + df[\"location\"].fillna(\"none\")\n",
+ "# split into features and labels\n",
+ "X = df.drop([\"text\", \"keyword\", \"location\", \"target\"], axis=1)\n",
+ "y = df[\"target\"]\n",
+ "\n",
+ "# process tweets\n",
+ "X[\"all_text\"] = X[\"all_text\"].apply(process_tweet)\n",
+ "X.head(10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# map each tweet idx to its out-of-fold prediction\n",
+ "pred_idx_dict = {}\n",
+ "# initialize kfold\n",
+ "skf = StratifiedKFold(n_splits=5, shuffle=False)\n",
+ "for fold, (train_idx, val_idx) in enumerate(skf.split(X=X, y=y)):\n",
+ " X_train, X_val = X.loc[train_idx, :], X.loc[val_idx, :]\n",
+ " y_train, y_val = y[train_idx], y[val_idx]\n",
+ "\n",
+ " # vectorize text, fitting the vectorizer on the training fold only\n",
+ " count_vect = CountVectorizer()\n",
+ " X_train_vect = count_vect.fit_transform(X_train[\"all_text\"].values)\n",
+ " X_val_vect = count_vect.transform(X_val[\"all_text\"].values)\n",
+ " \n",
+ " # fit the classifier and predict on the validation fold\n",
+ " clf = MultinomialNB()\n",
+ " clf.fit(X_train_vect, y_train)\n",
+ " y_preds = clf.predict(X_val_vect)\n",
+ " \n",
+ " # map each validation tweet idx to its prediction\n",
+ " for idx, key in enumerate(val_idx):\n",
+ " pred_idx_dict[key] = y_preds[idx]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# create df with actual labels and predictions\n",
+ "error_df = X.copy()\n",
+ "error_df.rename(columns={\"all_text\":\"processed_all_text\"}, inplace=True)\n",
+ "error_df[\"all_text\"] = df[\"all_text\"].values # rows of X are aligned with df\n",
+ "error_df[\"actual\"] = y.copy()\n",
+ "error_df[\"predictions\"] = error_df.index.map(pred_idx_dict) # align by index, not dict insertion order"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>id</th>\n",
+ " <th>processed_all_text</th>\n",
+ " <th>all_text</th>\n",
+ " <th>actual</th>\n",
+ " <th>predictions</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>3796</td>\n",
+ " <td>new weapon caus un-imagin destruct destructionnon</td>\n",
+ " <td>So you have a new weapon that can cause un-ima...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>3185</td>\n",
+ " <td>f @ing thing gishwh got soak delug go pad tamp...</td>\n",
+ " <td>The f$&amp;amp;@ing things I do for #GISHWHES Just...</td>\n",
+ " <td>0</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>7769</td>\n",
+ " <td>dt rt ‰ ûïthe col polic catch pickpocket liver...</td>\n",
+ " <td>DT @georgegalloway: RT @Galloway4Mayor: ‰ÛÏThe...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>191</td>\n",
+ " <td>aftershock back school kick great want thank e...</td>\n",
+ " <td>Aftershock back to school kick off was great. ...</td>\n",
+ " <td>0</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>9810</td>\n",
+ " <td>respons trauma children addict develop defens ...</td>\n",
+ " <td>in response to trauma Children of Addicts deve...</td>\n",
+ " <td>0</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>...</th>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7608</th>\n",
+ " <td>7470</td>\n",
+ " <td>mani obliter server alway like play :D obliter...</td>\n",
+ " <td>@Eganator2000 There aren't many Obliteration s...</td>\n",
+ " <td>0</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7609</th>\n",
+ " <td>7691</td>\n",
+ " <td>panic attack bc enough money drug alcohol want...</td>\n",
+ " <td>just had a panic attack bc I don't have enough...</td>\n",
+ " <td>0</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7610</th>\n",
+ " <td>1242</td>\n",
+ " <td>omron hem 712c automat blood pressur monitor s...</td>\n",
+ " <td>Omron HEM-712C Automatic Blood Pressure Monito...</td>\n",
+ " <td>0</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7611</th>\n",
+ " <td>10862</td>\n",
+ " <td>offici say quarantin place alabama home possib...</td>\n",
+ " <td>Officials say a quarantine is in place at an A...</td>\n",
+ " <td>1</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>7612</th>\n",
+ " <td>10409</td>\n",
+ " <td>move england five year ago today whirlwind time</td>\n",
+ " <td>I moved to England five years ago today. What ...</td>\n",
+ " <td>1</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "<p>7613 rows × 5 columns</p>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " id processed_all_text \\\n",
+ "0 3796 new weapon caus un-imagin destruct destructionnon \n",
+ "1 3185 f @ing thing gishwh got soak delug go pad tamp... \n",
+ "2 7769 dt rt ‰ ûïthe col polic catch pickpocket liver... \n",
+ "3 191 aftershock back school kick great want thank e... \n",
+ "4 9810 respons trauma children addict develop defens ... \n",
+ "... ... ... \n",
+ "7608 7470 mani obliter server alway like play :D obliter... \n",
+ "7609 7691 panic attack bc enough money drug alcohol want... \n",
+ "7610 1242 omron hem 712c automat blood pressur monitor s... \n",
+ "7611 10862 offici say quarantin place alabama home possib... \n",
+ "7612 10409 move england five year ago today whirlwind time \n",
+ "\n",
+ " all_text actual predictions \n",
+ "0 So you have a new weapon that can cause un-ima... 1 0 \n",
+ "1 The f$&amp;@ing things I do for #GISHWHES Just... 0 0 \n",
+ "2 DT @georgegalloway: RT @Galloway4Mayor: ‰ÛÏThe... 1 0 \n",
+ "3 Aftershock back to school kick off was great. ... 0 0 \n",
+ "4 in response to trauma Children of Addicts deve... 0 1 \n",
+ "... ... ... ... \n",
+ "7608 @Eganator2000 There aren't many Obliteration s... 0 0 \n",
+ "7609 just had a panic attack bc I don't have enough... 0 0 \n",
+ "7610 Omron HEM-712C Automatic Blood Pressure Monito... 0 1 \n",
+ "7611 Officials say a quarantine is in place at an A... 1 1 \n",
+ "7612 I moved to England five years ago today. What ... 1 1 \n",
+ "\n",
+ "[7613 rows x 5 columns]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "error_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>id</th>\n",
+ " <th>processed_all_text</th>\n",
+ " <th>all_text</th>\n",
+ " <th>actual</th>\n",
+ " <th>predictions</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>149</th>\n",
+ " <td>1061</td>\n",
+ " <td>ye i'm bleed heart liberal.bleedingl oak tx</td>\n",
+ " <td>@KatRamsland Yes I'm a bleeding heart liberal....</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>518</th>\n",
+ " <td>8946</td>\n",
+ " <td>storm came . . fuck coolstormnon</td>\n",
+ " <td>So this storm just came out of no where. .fuck...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3161</th>\n",
+ " <td>143</td>\n",
+ " <td>car even week got fuck car accid .. mf can't f...</td>\n",
+ " <td>only had a car for not even a week and got in ...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>6624</th>\n",
+ " <td>9044</td>\n",
+ " <td>spacex founder musk structur failur took falcon 9</td>\n",
+ " <td>SpaceX Founder Musk: Structural Failure Took D...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>881</th>\n",
+ " <td>1458</td>\n",
+ " <td>anoth one anoth one still ain't done shit one ...</td>\n",
+ " <td>'I did another one I did another one. You stil...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4314</th>\n",
+ " <td>10364</td>\n",
+ " <td>router one latest ddo attack weapon</td>\n",
+ " <td>Your Router is One of the Latest DDoS Attack W...</td>\n",
+ " <td>0</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>5399</th>\n",
+ " <td>6188</td>\n",
+ " <td>gov brown allow parol 1976 chowchilla school b...</td>\n",
+ " <td>Gov. Brown allows parole for 1976 Chowchilla s...</td>\n",
+ " <td>0</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4266</th>\n",
+ " <td>4911</td>\n",
+ " <td>chick masturb guy get explod face</td>\n",
+ " <td>Chick masturbates a guy until she gets explode...</td>\n",
+ " <td>1</td>\n",
+ " <td>0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3959</th>\n",
+ " <td>2112</td>\n",
+ " <td>borrow concern possibl interest rate rise coul...</td>\n",
+ " <td>#Borrowers concerned at possible #interest rat...</td>\n",
+ " <td>0</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>6445</th>\n",
+ " <td>7926</td>\n",
+ " <td>stuck rainstorm stay toward middl road street ...</td>\n",
+ " <td>Stuck in a rainstorm? Stay toward the middle o...</td>\n",
+ " <td>0</td>\n",
+ " <td>1</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " id processed_all_text \\\n",
+ "149 1061 ye i'm bleed heart liberal.bleedingl oak tx \n",
+ "518 8946 storm came . . fuck coolstormnon \n",
+ "3161 143 car even week got fuck car accid .. mf can't f... \n",
+ "6624 9044 spacex founder musk structur failur took falcon 9 \n",
+ "881 1458 anoth one anoth one still ain't done shit one ... \n",
+ "4314 10364 router one latest ddo attack weapon \n",
+ "5399 6188 gov brown allow parol 1976 chowchilla school b... \n",
+ "4266 4911 chick masturb guy get explod face \n",
+ "3959 2112 borrow concern possibl interest rate rise coul... \n",
+ "6445 7926 stuck rainstorm stay toward middl road street ... \n",
+ "\n",
+ " all_text actual predictions \n",
+ "149 @KatRamsland Yes I'm a bleeding heart liberal.... 1 0 \n",
+ "518 So this storm just came out of no where. .fuck... 1 0 \n",
+ "3161 only had a car for not even a week and got in ... 1 0 \n",
+ "6624 SpaceX Founder Musk: Structural Failure Took D... 1 0 \n",
+ "881 'I did another one I did another one. You stil... 1 0 \n",
+ "4314 Your Router is One of the Latest DDoS Attack W... 0 1 \n",
+ "5399 Gov. Brown allows parole for 1976 Chowchilla s... 0 1 \n",
+ "4266 Chick masturbates a guy until she gets explode... 1 0 \n",
+ "3959 #Borrowers concerned at possible #interest rat... 0 1 \n",
+ "6445 Stuck in a rainstorm? Stay toward the middle o... 0 1 "
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# store only the misclassified instances\n",
+ "misclassified_df = error_df[error_df[\"actual\"] != error_df[\"predictions\"]]\n",
+ "# keep only 100 of the misclassified instances\n",
+ "misclassified_100 = misclassified_df.sample(n=100, random_state=42)\n",
+ "misclassified_100.head(10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "misclassified_100.to_csv(\"misclassified_data.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
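
The fold loop above collects out-of-fold predictions by hand and then has to realign them by index. A minimal sketch of the same idea with sklearn's cross_val_predict, assuming the X["all_text"] and y built in the cells above; a Pipeline re-fits the vectorizer on each training fold, so there is no leakage, and the returned predictions are already in row order:

    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.model_selection import StratifiedKFold, cross_val_predict
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.pipeline import make_pipeline

    # vectorizer + classifier fitted per fold
    pipe = make_pipeline(CountVectorizer(), MultinomialNB())
    skf = StratifiedKFold(n_splits=5, shuffle=False)
    # out-of-fold predictions, ordered by row index
    oof_preds = cross_val_predict(pipe, X["all_text"].values, y, cv=skf)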
03_kpy_data_exploration.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
VERSION ADDED
@@ -0,0 +1 @@
+ 0.1.0
config.py ADDED
@@ -0,0 +1,40 @@
+ # data
+ DATA_DIR = "../inputs/"
+ ORIGINAL_TRAIN = DATA_DIR + "train.csv"
+ MODIFIED_TRAIN = DATA_DIR + "modified_train.csv"
+ TEST_DATA = DATA_DIR + "test.csv"
+ MODIFIED_TEST = DATA_DIR + "modified_test.csv"
+ SUBMISSION = DATA_DIR + "sample_submission.csv"
+ MODEL_DIR = "../models/"
+ IMAGES = "../images/"
+
+ # features
+ ID = "id"
+ TEXT = "text"
+ KEYWORD = "keyword"
+ LOCATION = "location"
+ FOLD = "kfold"
+ TOKENS = "tokens"
+
+ # created features
+ ALL_TEXT = "all_text"
+ CLEANED_TEXT = "cleaned_text"
+
+ # target
+ TARGET = "target"
+ RELABELED_TARGET = "relabeled_target"
+
+ # Pretrained Word2Vec
+ PRETRAINED_WORD2VEC = "word2vec-google-news-300"
+ EMBED_SIZE = 300
+
+ # TRAINING
+ HIDDEN_DIM = 256
+ TARGET_DIM = 1
+ BATCH_SIZE = 32
+ N_EPOCHS = 8
+ N_SPLITS = 5
+ LEARNING_RATE = 1e-3
+ MAXLEN = 202
+ VOCAB_SIZE = 172901
+
data_cleaning.py ADDED
@@ -0,0 +1,50 @@
+ import pandas as pd
+
+ import config
+
+ def relabel_target(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     Relabel duplicate tweets that are mislabelled in the training dataset
+     :param df: A pandas dataframe with a "target" column
+     :return: df with the relabelled target column added
+     """
+     # copy old target label
+     df[config.RELABELED_TARGET] = df[config.TARGET].copy()
+     # relabel samples whose labels disagree with their duplicates
+     df.loc[df[config.TEXT] == 'like for the music video I want some real action shit like burning buildings and police chases not some weak ben winston shit',
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == 'Hellfire is surrounded by desires so be careful and don‰Ûªt let your desires control you! #Afterlife',
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == 'To fight bioterrorism sir.',
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == '.POTUS #StrategicPatience is a strategy for #Genocide; refugees; IDP Internally displaced people; horror; etc. https://t.co/rqWuoy1fm4',
+            config.RELABELED_TARGET] = 1
+     df.loc[df[config.TEXT] == 'CLEARED:incident with injury:I-495 inner loop Exit 31 - MD 97/Georgia Ave Silver Spring',
+            config.RELABELED_TARGET] = 1
+     df.loc[df[config.TEXT] == '#foodscare #offers2go #NestleIndia slips into loss after #Magginoodle #ban unsafe and hazardous for #humanconsumption',
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == 'In #islam saving a person is equal in reward to saving all humans! Islam is the opposite of terrorism!',
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == 'Who is bringing the tornadoes and floods. Who is bringing the climate change. God is after America He is plaguing her\n \n#FARRAKHAN #QUOTE',
+            config.RELABELED_TARGET] = 1
+     df.loc[df[config.TEXT] == 'RT NotExplained: The only known image of infamous hijacker D.B. Cooper. http://t.co/JlzK2HdeTG',
+            config.RELABELED_TARGET] = 1
+     df.loc[df[config.TEXT] == "Mmmmmm I'm burning.... I'm burning buildings I'm building.... Oooooohhhh oooh ooh...",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "wowo--=== 12000 Nigerian refugees repatriated from Cameroon",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "He came to a land which was engulfed in tribal war and turned it into a land of peace i.e. Madinah. #ProphetMuhammad #islam",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "Hellfire! We don‰Ûªt even want to think about it or mention it so let‰Ûªs not do anything that leads to it #islam!",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "The Prophet (peace be upon him) said 'Save yourself from Hellfire even if it is by giving half a date in charity.'",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "Caution: breathing may be hazardous to your health.",
+            config.RELABELED_TARGET] = 1
+     df.loc[df[config.TEXT] == "I Pledge Allegiance To The P.O.P.E. And The Burning Buildings of Epic City. ??????",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "#Allah describes piling up #wealth thinking it would last #forever as the description of the people of #Hellfire in Surah Humaza. #Reflect",
+            config.RELABELED_TARGET] = 0
+     df.loc[df[config.TEXT] == "that horrible sinking feeling when you‰Ûªve been at home on your phone for a while and you realise its been on 3G this whole time",
+            config.RELABELED_TARGET] = 0
+     return df
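
A minimal usage sketch for relabel_target, assuming the file paths defined in config.py:

    import pandas as pd

    import config
    from data_cleaning import relabel_target

    df = pd.read_csv(config.ORIGINAL_TRAIN)
    df = relabel_target(df)
    # count how many labels the relabelling changed
    n_changed = (df[config.TARGET] != df[config.RELABELED_TARGET]).sum()
    print(f"{n_changed} tweets relabelled")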
features.py ADDED
@@ -0,0 +1,19 @@
+ import numpy as np
+ import gensim.downloader as api
+
+ import config
+
+ def get_word2vec_enc(corpus, gensim_pretrained_emb: str) -> np.ndarray:
+     """
+     Get the pretrained Word2Vec vector for each word in the corpus
+     :param corpus: iterable of (word, index) pairs, e.g. tokenizer.word_index.items()
+     :param gensim_pretrained_emb: name of the pretrained gensim embedding to load
+     :return: embedding weight matrix of shape (VOCAB_SIZE, EMBED_SIZE)
+     """
+     word_vecs = api.load(gensim_pretrained_emb)
+     embedding_weights = np.zeros((config.VOCAB_SIZE, config.EMBED_SIZE))
+     for word, i in corpus:
+         # words missing from the pretrained vocabulary keep an all-zero row
+         if word in word_vecs:
+             embedding_weights[i] = word_vecs[word]
+     return embedding_weights
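
A usage sketch; the two example sentences passed to fit_on_texts are hypothetical stand-ins for the cleaned training text used in train.py below:

    from tensorflow.keras.preprocessing.text import Tokenizer

    import config
    import features as f

    tokenizer = Tokenizer(oov_token="<unk>")
    tokenizer.fit_on_texts(["forest fire near la ronge", "residents asked shelter place"])
    # rows for words absent from the pretrained vectors remain all-zero
    embedding_matrix = f.get_word2vec_enc(tokenizer.word_index.items(), config.PRETRAINED_WORD2VEC)
    print(embedding_matrix.shape)  # (config.VOCAB_SIZE, config.EMBED_SIZE)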
fixed_df_naive_bayes.png ADDED
incorrect_naive_bayes.png ADDED
inference.py ADDED
@@ -0,0 +1,40 @@
+ import pickle
+
+ import numpy as np
+ import pandas as pd
+ from tensorflow.keras.models import load_model
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+
+ import config
+ import preprocessing as pp
+
+ def predict_test(model: str, test_data: str = config.MODIFIED_TEST):
+
+     # path to model
+     model_path = f"{config.MODEL_DIR}PRETRAIN_WORD2VEC_{model}/"
+
+     # read data
+     df_test = pd.read_csv(test_data)
+
+     # clean the text
+     df_test[config.CLEANED_TEXT] = df_test[config.TEXT].apply(pp.clean_tweet)
+
+     # load the tokenizer fitted in train.py
+     with open(f'{model_path}tokenizer.pkl', 'rb') as handle:
+         tokenizer = pickle.load(handle)
+
+     # convert tokens to sequences and pad them
+     data_values = tokenizer.texts_to_sequences(df_test[config.CLEANED_TEXT].values)
+     X_padded = pad_sequences(data_values, maxlen=config.MAXLEN)
+
+     # load the classifier
+     clf = load_model(f"{model_path}{model}_Word2Vec.h5")
+     predictions = clf.predict_classes(X_padded, verbose=0)
+
+     return predictions
+
+ if __name__ == "__main__":
+     submission = predict_test(model="LSTM")
+     sample_sub = pd.read_csv(config.SUBMISSION)
+     sample_sub.loc[:, config.TARGET] = submission
+     sample_sub.to_csv(f"{config.MODEL_DIR}PRETRAIN_WORD2VEC_LSTM/LSTM.csv", index=False)
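
Note: Sequential.predict_classes was removed in TensorFlow 2.6. On newer versions an equivalent (a sketch, not part of this commit) is to threshold the sigmoid output:

    predictions = (clf.predict(X_padded) > 0.5).astype("int32")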
lstm_model.py ADDED
@@ -0,0 +1,18 @@
+ from tensorflow.keras.layers import Dense, Dropout, LSTM, Bidirectional
+ from tensorflow.keras import Sequential
+
+ def my_LSTM(embedding_layer):
+     print('Creating model...')
+     model = Sequential()
+     model.add(embedding_layer)
+     model.add(Dropout(0.2))
+     model.add(Bidirectional(LSTM(units=64, dropout=0.1, recurrent_dropout=0.1)))
+     model.add(Dense(50, activation="relu"))
+     model.add(Dropout(0.1))
+     model.add(Dense(1, activation="sigmoid"))
+
+     print('Compiling...')
+     model.compile(loss='binary_crossentropy',
+                   optimizer='adam',
+                   metrics=["accuracy"])
+     return model
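
A minimal sketch wiring my_LSTM to an embedding layer, with shapes taken from config.py; the random weights stand in for the Word2Vec matrix built in train.py:

    import numpy as np
    from tensorflow.keras.layers import Embedding

    import config
    from lstm_model import my_LSTM

    dummy_weights = np.random.rand(config.VOCAB_SIZE, config.EMBED_SIZE)
    embedding_layer = Embedding(input_dim=config.VOCAB_SIZE,
                                output_dim=config.EMBED_SIZE,
                                weights=[dummy_weights],
                                input_length=config.MAXLEN,
                                trainable=False)
    model = my_LSTM(embedding_layer)
    model.summary()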
model_dispatcher.py ADDED
@@ -0,0 +1,8 @@
+ from sklearn import linear_model, naive_bayes, ensemble, svm
+
+ MODELS = {
+     "logistic_regression": linear_model.LogisticRegression(max_iter=1000, random_state=42),
+     "naive_bayes": naive_bayes.MultinomialNB(),
+     "random_forest": ensemble.RandomForestClassifier(n_estimators=500, random_state=42, n_jobs=-1),
+     "svm": svm.SVC(C=10)
+ }
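
The dispatcher keeps model selection out of the training code; a usage sketch, where X_train_vect, X_val_vect and y_train are hypothetical outputs of a fitted CountVectorizer:

    from model_dispatcher import MODELS

    clf = MODELS["naive_bayes"]  # or "logistic_regression", "random_forest", "svm"
    clf.fit(X_train_vect, y_train)
    y_preds = clf.predict(X_val_vect)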
preprocessing.py ADDED
@@ -0,0 +1,36 @@
+ import re
+ import string
+
+ import nltk
+ from nltk.corpus import stopwords
+ from nltk.tokenize import TweetTokenizer
+
+ nltk.download("stopwords")
+
+ def clean_tweet(tweet: str) -> str:
+     """
+     Remove stock market tickers, the RT symbol, hyperlinks and the hashtag symbol, then lowercase the text and strip punctuation and stopwords
+     :param tweet: tweet by a unique user
+     :return: cleaned string without hashtags, punctuation, or stopwords
+     """
+     # remove stock market tickers like $GE
+     tweet = re.sub(r'\$\w*', '', str(tweet))
+     # remove old style retweet text "RT" (before lowercasing; the pattern is case-sensitive)
+     tweet = re.sub(r'^RT[\s]+', '', str(tweet))
+     # remove hyperlinks
+     tweet = re.sub(r'https?:\/\/.*[\r\n]*', '', str(tweet))
+     # remove hashtags
+     # only removing the hash # sign from the word
+     tweet = re.sub(r'#', '', str(tweet))
+     # make text lower case
+     tweet = tweet.lower()
+
+     # remove punctuation
+     punct = set(string.punctuation)
+     tweet = "".join(ch for ch in tweet if ch not in punct)
+
+     # remove stopwords
+     stop_words = set(stopwords.words("english"))
+     tweet = " ".join(word for word in tweet.split() if word not in stop_words)
+
+     return tweet
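
A quick check of clean_tweet on a made-up tweet; the output shown is approximate, since the exact result depends on the NLTK stopword list in use:

    from preprocessing import clean_tweet

    print(clean_tweet("RT @user: Forest fire near La Ronge Sask. #wildfire http://t.co/abc"))
    # -> roughly: "user forest fire near la ronge sask wildfire"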
train.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ # select the plaidml Keras backend for GPU use; must be set before Keras is imported
+ os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
+
+ import pickle
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import f1_score
+ from tensorflow.keras.layers import Embedding
+ from sklearn.model_selection import StratifiedKFold
+ from tensorflow.keras.preprocessing.text import Tokenizer
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+
+ import config
+ import preprocessing as pp
+ import features as f
+ import data_cleaning as data_clean
+ from lstm_model import my_LSTM
+
+ def run_training(model: str) -> None:
+     """
+     Train the model and serialize it to disk
+     """
+     # read train and test data
+     df_train = pd.read_csv(config.ORIGINAL_TRAIN)
+     df_test = pd.read_csv(config.TEST_DATA)
+
+     # relabel mislabeled samples
+     df_train = data_clean.relabel_target(df_train)
+
+     # shuffle data
+     df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)
+
+     # clean the text
+     df_train[config.CLEANED_TEXT] = df_train[config.TEXT].apply(pp.clean_tweet)
+     df_test[config.CLEANED_TEXT] = df_test[config.TEXT].apply(pp.clean_tweet)
+
+     # save the modified train and test data
+     df_train.to_csv(config.MODIFIED_TRAIN, index=False)
+     df_test.to_csv(config.MODIFIED_TEST, index=False)
+     del df_test
+
+     # convert text to numerical representation
+     tokenizer = Tokenizer(oov_token="<unk>")
+     tokenizer.fit_on_texts(df_train[config.CLEANED_TEXT])
+
+     # path to save model
+     model_path = f"{config.MODEL_DIR}PRETRAIN_WORD2VEC_{model}/"
+
+     # create the model folder if it does not exist
+     if not os.path.exists(model_path):
+         os.makedirs(model_path)
+
+     # save the tokenizer
+     with open(f'{model_path}tokenizer.pkl', 'wb') as handle:
+         pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+     # pad the sequences
+     X_padded = pad_sequences(tokenizer.texts_to_sequences(df_train[config.CLEANED_TEXT].values), maxlen=config.MAXLEN)
+
+     # get the pretrained word embeddings and prepare the embedding layer
+     embedding_matrix = f.get_word2vec_enc(tokenizer.word_index.items(), config.PRETRAINED_WORD2VEC)
+     embedding_layer = Embedding(input_dim=config.VOCAB_SIZE,
+                                 output_dim=config.EMBED_SIZE,
+                                 weights=[embedding_matrix],
+                                 input_length=config.MAXLEN,
+                                 trainable=False)
+
+     # target values
+     y = df_train[config.RELABELED_TARGET].values
+
+     # train a single model
+     clf = my_LSTM(embedding_layer)
+     clf.fit(X_padded, y,
+             epochs=config.N_EPOCHS,
+             verbose=1)
+
+     # persist the model
+     clf.save(f"{model_path}{model}_Word2Vec.h5")
+
+ if __name__ == "__main__":
+     run_training("LSTM")
+
+
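
train.py imports StratifiedKFold and f1_score but fits a single model on the full training set. A sketch of the per-fold evaluation those imports suggest, assuming the X_padded, y and embedding_layer built inside run_training and config.N_SPLITS folds; this is one possible evaluation loop, not the commit's method:

    skf = StratifiedKFold(n_splits=config.N_SPLITS, shuffle=False)
    fold_scores = []
    for fold, (train_idx, val_idx) in enumerate(skf.split(X_padded, y)):
        # the frozen embedding layer can be shared safely across folds
        clf = my_LSTM(embedding_layer)
        clf.fit(X_padded[train_idx], y[train_idx],
                epochs=config.N_EPOCHS, verbose=0)
        val_preds = (clf.predict(X_padded[val_idx]) > 0.5).astype("int32")
        fold_scores.append(f1_score(y[val_idx], val_preds))
    print(f"mean F1 across folds: {np.mean(fold_scores):.4f}")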
user_interface.py ADDED
@@ -0,0 +1,47 @@
+ import pickle
+
+ import numpy as np
+ import gradio as gr
+ from tensorflow.keras.models import load_model
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+
+ from src import config
+ from src import preprocessing as pp
+
+ def predict(text: str):
+     """
+     Predict the class of an instance
+     :param text: The tweet text we want to classify
+     :return: The model output as a readable label
+     """
+     outcome_dict = {0: "Non-Disaster", 1: "Disaster"}
+
+     # path to model
+     model_path = "models/PRETRAIN_WORD2VEC_LSTM/"
+
+     # clean the text
+     clean_text = pp.clean_tweet(text)
+     clean_text = np.array([clean_text])
+
+     # load the tokenizer fitted in train.py
+     with open(f'{model_path}tokenizer.pkl', 'rb') as handle:
+         tokenizer = pickle.load(handle)
+
+     # convert tokens to sequences and pad them
+     data_values = tokenizer.texts_to_sequences(clean_text)
+     X_padded = pad_sequences(data_values, maxlen=config.MAXLEN)
+
+     # load the classifier
+     clf = load_model(f"{model_path}LSTM_Word2Vec.h5")
+     prediction = clf.predict_classes(X_padded, verbose=0)
+
+     prediction = prediction.sum()
+     return outcome_dict[prediction]
+
+ if __name__ == "__main__":
+     iface = gr.Interface(
+         fn=predict,
+         inputs=gr.inputs.Textbox(lines=3, placeholder="Insert Tweet..."),
+         outputs="text"
+     )
+     iface.launch()
utils.py ADDED
@@ -0,0 +1,15 @@
+ # Code Source
+ # https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
+
+ from tensorflow.keras import backend as K
+
+ def f1_metric(y_true, y_pred):  # taken from old Keras source code
+     true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+     possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
+     predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
+     precision = true_positives / (predicted_positives + K.epsilon())
+     recall = true_positives / (possible_positives + K.epsilon())
+     f1_val = 2 * (precision * recall) / (precision + recall + K.epsilon())
+     return f1_val
+
+
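
f1_metric plugs into compile like any built-in metric; a sketch with a hypothetical model (e.g. one returned by lstm_model.my_LSTM). Keras averages the metric over batches, so the reported value only approximates the global F1:

    from utils import f1_metric

    model.compile(loss="binary_crossentropy",
                  optimizer="adam",
                  metrics=["accuracy", f1_metric])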