prz4587 commited on
Commit
e925eab
·
verified ·
1 Parent(s): 338be64

Upload folder using huggingface_hub

Browse files
0-data-prepare.ipynb ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "7384109c-4507-4895-ac34-8400e2978021",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import ujson as json"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "de47a236-bb5d-4cb9-89e4-a8eb3f66abaa",
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "name": "stdout",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "art_en_crime.json 22116 dict_keys(['id', 'embedding', 'themes', 'language', 'title', 'content', 'cluster_id'])\n",
24
+ "fincrime_train.json 27548 dict_keys(['content', 'themes', 'embedding', 'llm_themes', 'cl_themes'])\n",
25
+ "train_multilang_2.json 119583 dict_keys(['language', 'content', 'id', 'lang', 'embedding', 'cluster_id', 'themes', 'llm_themes'])\n",
26
+ "other_lang_train.json 149000 dict_keys(['embedding', 'themes', 'id', 'content'])\n"
27
+ ]
28
+ },
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "318247"
33
+ ]
34
+ },
35
+ "execution_count": 2,
36
+ "metadata": {},
37
+ "output_type": "execute_result"
38
+ }
39
+ ],
40
+ "source": [
41
+ "train_files = ['art_en_crime.json', 'fincrime_train.json', 'train_multilang_2.json', 'other_lang_train.json']\n",
42
+ "# 'en_train.json'\n",
43
+ "train_data = []\n",
44
+ "for train_file in train_files:\n",
45
+ " tm = json.load(open('topic_data_new_sorted/'+train_file))\n",
46
+ " print(train_file, len(tm), tm[0].keys())\n",
47
+ " train_data.extend(tm)\n",
48
+ "len(train_data)"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "97e48648-4a9f-40b9-8b95-2740a93b4762",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "for i in train_data:\n",
59
+ " del i['embedding']"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 4,
65
+ "id": "b3b5f700-5467-4d8b-963d-060a1962ab00",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "final_data = []\n",
70
+ "contents = set()\n",
71
+ "for i in train_data:\n",
72
+ " if not i.get('content'):\n",
73
+ " continue\n",
74
+ " th = hash(i['content'])\n",
75
+ " if th not in contents:\n",
76
+ " final_data.append(i)\n",
77
+ " contents.add(th)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 5,
83
+ "id": "3527a7a2-b58c-406b-890e-99002ef050df",
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "data": {
88
+ "text/plain": [
89
+ "(22116, 245203, 318247)"
90
+ ]
91
+ },
92
+ "execution_count": 5,
93
+ "metadata": {},
94
+ "output_type": "execute_result"
95
+ }
96
+ ],
97
+ "source": [
98
+ "tc, cc = 0, 0\n",
99
+ "for i in train_data:\n",
100
+ " if i.get('title'):\n",
101
+ " tc += 1\n",
102
+ " if i.get('content'):\n",
103
+ " cc += 1\n",
104
+ "tc, cc, len(train_data)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 6,
110
+ "id": "460312d0-236c-418d-a53e-1357fdc586bc",
111
+ "metadata": {},
112
+ "outputs": [
113
+ {
114
+ "data": {
115
+ "text/plain": [
116
+ "144496"
117
+ ]
118
+ },
119
+ "execution_count": 6,
120
+ "metadata": {},
121
+ "output_type": "execute_result"
122
+ }
123
+ ],
124
+ "source": [
125
+ "len(final_data)"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 7,
131
+ "id": "631689b1-c18d-427c-baea-8d2d3d4b9781",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "data": {
136
+ "text/plain": [
137
+ "{'Politics': 44642,\n",
138
+ " 'Crime': 58051,\n",
139
+ " 'Financial Crime': 24289,\n",
140
+ " 'Business': 25247,\n",
141
+ " 'Entertainment': 18221,\n",
142
+ " 'Finance': 6166,\n",
143
+ " 'Economics': 1772,\n",
144
+ " 'Sports': 15184,\n",
145
+ " 'Tech': 4107,\n",
146
+ " 'Automotive': 2116,\n",
147
+ " 'Health': 7829,\n",
148
+ " 'Lifestyle': 368,\n",
149
+ " 'Science': 3481,\n",
150
+ " 'Travel': 914,\n",
151
+ " 'Weather': 1070,\n",
152
+ " 'General': 8972,\n",
153
+ " 'Types of life insurance fraud': 1,\n",
154
+ " 'Consequences of insurance fraud': 1,\n",
155
+ " 'How to prevent life insurance fraud': 1,\n",
156
+ " 'Front Running': 1,\n",
157
+ " 'fraud': 1}"
158
+ ]
159
+ },
160
+ "execution_count": 7,
161
+ "metadata": {},
162
+ "output_type": "execute_result"
163
+ }
164
+ ],
165
+ "source": [
166
+ "themes = {}\n",
167
+ "for i in final_data:\n",
168
+ " for t in i['themes']:\n",
169
+ " themes[t] = themes.get(t, 0) + 1\n",
170
+ "themes"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 8,
176
+ "id": "a7bccdc9-de94-4e7b-b686-0fa062f5e781",
177
+ "metadata": {},
178
+ "outputs": [
179
+ {
180
+ "name": "stderr",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
184
+ " from .autonotebook import tqdm as notebook_tqdm\n"
185
+ ]
186
+ },
187
+ {
188
+ "data": {
189
+ "text/plain": [
190
+ "'tr'"
191
+ ]
192
+ },
193
+ "execution_count": 8,
194
+ "metadata": {},
195
+ "output_type": "execute_result"
196
+ }
197
+ ],
198
+ "source": [
199
+ "from fast_langdetect import detect\n",
200
+ "result = detect(text=\"Bugün hava çok güzel\", model='full', k=1)\n",
201
+ "result[0]['lang']"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": 9,
207
+ "id": "ed34643a-974c-4995-b349-6ce194397985",
208
+ "metadata": {},
209
+ "outputs": [
210
+ {
211
+ "data": {
212
+ "text/plain": [
213
+ "{'en': 49682,\n",
214
+ " 'hi': 541,\n",
215
+ " 'mt': 1,\n",
216
+ " 'fr': 11193,\n",
217
+ " 'de': 12987,\n",
218
+ " 'pt': 1586,\n",
219
+ " 'it': 1682,\n",
220
+ " 'ko': 455,\n",
221
+ " 'es': 4034,\n",
222
+ " 'zh': 5,\n",
223
+ " 'ar': 35732,\n",
224
+ " 'bn': 420,\n",
225
+ " 'pl': 1387,\n",
226
+ " 'nl': 824,\n",
227
+ " 'lv': 379,\n",
228
+ " 'hk': 345,\n",
229
+ " 'hr': 511,\n",
230
+ " 'ja': 591,\n",
231
+ " 'cy': 182,\n",
232
+ " 'sv': 852,\n",
233
+ " 'da': 2120,\n",
234
+ " 'el': 462,\n",
235
+ " 'tr': 507,\n",
236
+ " 'ro': 546,\n",
237
+ " 'ur': 723,\n",
238
+ " 'mr': 344,\n",
239
+ " 'so': 338,\n",
240
+ " 'fa': 729,\n",
241
+ " 'mk': 451,\n",
242
+ " 'gu': 385,\n",
243
+ " 'th': 617,\n",
244
+ " 'lt': 444,\n",
245
+ " 'tw': 246,\n",
246
+ " 'sl': 568,\n",
247
+ " 'ml': 414,\n",
248
+ " 'te': 344,\n",
249
+ " 'he': 526,\n",
250
+ " 'cs': 671,\n",
251
+ " 'et': 660,\n",
252
+ " 'ta': 463,\n",
253
+ " 'gl': 9,\n",
254
+ " 'id': 538,\n",
255
+ " 'ca': 506,\n",
256
+ " 'ast': 1,\n",
257
+ " 'eu': 1,\n",
258
+ " 'sk': 455,\n",
259
+ " 'sq': 424,\n",
260
+ " 'ne': 429,\n",
261
+ " 'fi': 652,\n",
262
+ " 'sw': 394,\n",
263
+ " 'bg': 507,\n",
264
+ " 'ru': 653,\n",
265
+ " 'hu': 564,\n",
266
+ " 'cn': 261,\n",
267
+ " 'vi': 535,\n",
268
+ " 'pa': 343,\n",
269
+ " 'no': 1156,\n",
270
+ " 'tl': 559,\n",
271
+ " 'uk': 500,\n",
272
+ " 'kn': 374,\n",
273
+ " 'af': 611,\n",
274
+ " 'arz': 58,\n",
275
+ " 'nn': 6,\n",
276
+ " 'dv': 10,\n",
277
+ " 'azb': 1,\n",
278
+ " 'sd': 1,\n",
279
+ " 'ckb': 1}"
280
+ ]
281
+ },
282
+ "execution_count": 9,
283
+ "metadata": {},
284
+ "output_type": "execute_result"
285
+ }
286
+ ],
287
+ "source": [
288
+ "langs = {}\n",
289
+ "for i in final_data:\n",
290
+ " l = i.get('language')\n",
291
+ " if not l:\n",
292
+ " l = i.get('language')\n",
293
+ " if not l:\n",
294
+ " l = detect(text=i['content'], model='full', k=1)[0]['lang']\n",
295
+ " langs[l] = langs.get(l, 0) + 1\n",
296
+ "langs"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": 12,
302
+ "id": "dcdc4560-2bcb-4d90-a64d-436b85f23bc6",
303
+ "metadata": {},
304
+ "outputs": [],
305
+ "source": [
306
+ "del train_data"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": 2,
312
+ "id": "770489b9-78c2-4231-9c8b-ba000342fe6d",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "final_data = json.load(open('train1.json'))"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 7,
322
+ "id": "a69be4e4-08b3-4830-9b4d-b8ea67f15656",
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "import requests\n",
327
+ "from requests.adapters import HTTPAdapter\n",
328
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
329
+ "from typing import List, Dict, Any, Optional, Tuple\n",
330
+ "\n",
331
+ "def _make_session(pool_size: int) -> requests.Session:\n",
332
+ " s = requests.Session()\n",
333
+ " adapter = HTTPAdapter(pool_connections=pool_size, pool_maxsize=pool_size, max_retries=0)\n",
334
+ " s.mount(\"http://\", adapter)\n",
335
+ " s.mount(\"https://\", adapter)\n",
336
+ " return s\n",
337
+ "\n",
338
+ "\n",
339
+ "def _embed_batch(\n",
340
+ " session: requests.Session,\n",
341
+ " base_url: str,\n",
342
+ " text: str,\n",
343
+ " timeout_s: float,\n",
344
+ " extra_headers: Optional[Dict[str, str]] = None,\n",
345
+ ") -> Any:\n",
346
+ " \"\"\"\n",
347
+ " Native TEI endpoint: POST {base_url}/embed\n",
348
+ " Payload: {\"inputs\": [..texts..]}\n",
349
+ " \"\"\"\n",
350
+ " url = base_url.rstrip(\"/\") + \"/embed\"\n",
351
+ " headers = {\"Content-Type\": \"application/json\"}\n",
352
+ " if extra_headers:\n",
353
+ " headers.update(extra_headers)\n",
354
+ " r = session.post(url, json={\"inputs\": text}, headers=headers, timeout=timeout_s)\n",
355
+ " r.raise_for_status()\n",
356
+ " return r.json()\n",
357
+ "\n",
358
+ "def one_call(art):\n",
359
+ " if \"embedding\" in art:\n",
360
+ " return\n",
361
+ " text = []\n",
362
+ " if art.get(\"title\"):\n",
363
+ " text.append(art[\"title\"])\n",
364
+ " if art.get(\"content\"):\n",
365
+ " text.append(art[\"content\"])\n",
366
+ " text = \"\\n\\n\".join(text)\n",
367
+ " res = _embed_batch(session, base_url, text, timeout_s=30)\n",
368
+ " art[\"embedding\"] = res[0]\n",
369
+ "\n",
370
+ "WORKERS = 64\n",
371
+ "session = _make_session(pool_size=WORKERS)\n",
372
+ "base_url = \"http://172.83.12.123:8080\""
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "execution_count": 6,
378
+ "id": "e8d7127f-af28-4029-aa16-35b52269317a",
379
+ "metadata": {},
380
+ "outputs": [],
381
+ "source": [
382
+ "with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
383
+ " _ = [ex.submit(one_call, art) for art in final_data]"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": 10,
389
+ "id": "1ba25f4b-7bec-40a6-9b78-d014a64674c4",
390
+ "metadata": {},
391
+ "outputs": [
392
+ {
393
+ "data": {
394
+ "text/plain": [
395
+ "(144496, 144496)"
396
+ ]
397
+ },
398
+ "execution_count": 10,
399
+ "metadata": {},
400
+ "output_type": "execute_result"
401
+ }
402
+ ],
403
+ "source": [
404
+ "te = 0\n",
405
+ "for i in final_data:\n",
406
+ " if \"embedding\" in i and len(i[\"embedding\"]) == 1024:\n",
407
+ " te += 1\n",
408
+ "te, len(final_data)"
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": 9,
414
+ "id": "1e821b09-cfce-4ae6-82c6-1d786efcc676",
415
+ "metadata": {},
416
+ "outputs": [],
417
+ "source": [
418
+ "json.dump(final_data, open('train1.json', 'w'))"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "19a5069a-6a0b-4ed2-ae4c-173899b4d632",
425
+ "metadata": {},
426
+ "outputs": [],
427
+ "source": []
428
+ },
429
+ {
430
+ "cell_type": "code",
431
+ "execution_count": null,
432
+ "id": "4b02dcf6-2d64-4dab-a417-da3d2c411254",
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": []
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "execution_count": 9,
440
+ "id": "f427d890-ce22-4a04-9157-04528db80369",
441
+ "metadata": {},
442
+ "outputs": [
443
+ {
444
+ "name": "stdout",
445
+ "output_type": "stream",
446
+ "text": [
447
+ "en_test.json 750 750 750\n",
448
+ "other_lang_test.json 1000 981 981\n",
449
+ "spanish_tagged.json 1199 1197 1197\n",
450
+ "fincrime_test.json 623 623 623\n"
451
+ ]
452
+ }
453
+ ],
454
+ "source": [
455
+ "files = [\"en_test.json\", \"other_lang_test.json\", \"spanish_tagged.json\", \"fincrime_test.json\"]\n",
456
+ "for file in files:\n",
457
+ " tm = json.load(open('topic_data_new_sorted/'+file))\n",
458
+ " tf = []\n",
459
+ " cc, em = 0, 0\n",
460
+ " contents = set()\n",
461
+ " for i in tm:\n",
462
+ " if i['content'].lower() in contents:\n",
463
+ " continue\n",
464
+ " if i.get('content'):\n",
465
+ " cc += 1\n",
466
+ " contents.add(i['content'].lower())\n",
467
+ " if \"embedding\" in i:\n",
468
+ " del i[\"embedding\"]\n",
469
+ " tf.append(i)\n",
470
+ " with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
471
+ " _ = [ex.submit(one_call, art) for art in tf]\n",
472
+ " for i in tf:\n",
473
+ " if \"embedding\" in i and len(i[\"embedding\"]) == 1024:\n",
474
+ " em += 1\n",
475
+ " json.dump(tf, open('test/' + file, 'w'))\n",
476
+ " print(file, len(tm), cc, em)"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": null,
482
+ "id": "7e42ec21-99df-4b62-b678-86bf24bb259e",
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": []
486
+ }
487
+ ],
488
+ "metadata": {
489
+ "kernelspec": {
490
+ "display_name": "Python 3 (ipykernel)",
491
+ "language": "python",
492
+ "name": "python3"
493
+ },
494
+ "language_info": {
495
+ "codemirror_mode": {
496
+ "name": "ipython",
497
+ "version": 3
498
+ },
499
+ "file_extension": ".py",
500
+ "mimetype": "text/x-python",
501
+ "name": "python",
502
+ "nbconvert_exporter": "python",
503
+ "pygments_lexer": "ipython3",
504
+ "version": "3.10.12"
505
+ }
506
+ },
507
+ "nbformat": 4,
508
+ "nbformat_minor": 5
509
+ }
1-training.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
2-prepare-data-2.ipynb ADDED
@@ -0,0 +1,1933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "a184875c-343b-4a39-ac51-2a80f8f661a5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import ujson as json\n",
11
+ "import random\n",
12
+ "import os"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "a2f82047-027a-4950-8a17-6c8437d2f2c9",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "1. prepare dataset by lang + theme\n",
23
+ "2. choose MMR candidates (take atleast 15K each theme and 100 per each lang, en=7K, other=8K)\n",
24
+ "3. calculate embedding of qwen\n",
25
+ "4. again choose MMR based on qwen embedding also make list of llm verify candidate when qwen emb not matching with given themes (5K each)\n",
26
+ "5. confirm with LLM for selected candidates\n",
27
+ "6. split train & test also pre-train"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "id": "e1679138-5178-4262-a2f1-5589dd80c7e9",
33
+ "metadata": {},
34
+ "source": [
35
+ "# 1. prepare dataset by lang + theme"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "c70cc644-b406-4105-8bef-5dad0552d354",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "files = ['art_en.json', 'art_en2.json', 'art_non_en.json', 'art_non_en2.json']\n",
46
+ "for file in files:\n",
47
+ " fo = open('raw_articles/' + file)\n",
48
+ " tt = {}\n",
49
+ " for line in fo:\n",
50
+ " d = json.loads(line)\n",
51
+ " if not d['_source'].get('nlp'):\n",
52
+ " continue\n",
53
+ " for t in d['_source']['nlp']['theme']:\n",
54
+ " tt[t] = tt.get(t, 0) + 1\n",
55
+ " tt = dict(sorted(tt.items(), key=lambda item: item[1]))\n",
56
+ " print(file, tt)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "id": "94a9c39a-e060-4b87-8edd-e65a3c480c86",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "files = ['art_en2_specific.json', 'art_non_en_specific.json']\n",
67
+ "for file in files:\n",
68
+ " fo = open('raw_articles/' + file)\n",
69
+ " tt = {}\n",
70
+ " for line in fo:\n",
71
+ " d = json.loads(line)\n",
72
+ " if not d['_source'].get('nlp'):\n",
73
+ " continue\n",
74
+ " for t in d['_source']['nlp']['theme']:\n",
75
+ " tt[t] = tt.get(t, 0) + 1\n",
76
+ " tt = dict(sorted(tt.items(), key=lambda item: item[1]))\n",
77
+ " print(file, tt)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "b2301a8d-aef2-4776-9d02-c35aae042267",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": []
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "e28c0287-b649-4dea-a6f3-1e83a2669b69",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "files = ['art_en.json', 'art_en2.json', 'art_non_en.json', 'art_non_en2.json', 'art_en2_specific.json', 'art_non_en_specific.json']\n",
96
+ "tt = {}\n",
97
+ "for file in files:\n",
98
+ " fo = open('raw_articles/' + file)\n",
99
+ " for line in fo:\n",
100
+ " d = json.loads(line)\n",
101
+ " if not d['_source'].get('nlp'):\n",
102
+ " continue\n",
103
+ " for t in d['_source']['nlp']['theme']:\n",
104
+ " tt[t] = tt.get(t, 0) + 1\n",
105
+ " tt = dict(sorted(tt.items(), key=lambda item: item[1]))\n",
106
+ "print(tt)"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "9b735036-3a34-4d20-bae9-58f8b5bc7100",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "priority_order = [\n",
117
+ " \"Economics\",\n",
118
+ " \"Financial Crime\",\n",
119
+ " \"Finance\",\n",
120
+ " \"Lifestyle\",\n",
121
+ " \"Automotive\",\n",
122
+ " \"Science\",\n",
123
+ " \"Tech\",\n",
124
+ " \"Travel\",\n",
125
+ " \"Weather\",\n",
126
+ " \"Health\",\n",
127
+ " \"Crime\",\n",
128
+ " \"Sports\",\n",
129
+ " \"General\",\n",
130
+ " \"Business\",\n",
131
+ " \"Politics\",\n",
132
+ " \"Entertainment\",\n",
133
+ " ]\n",
134
+ "\n",
135
+ "def sort_themes_by_priority(themes):\n",
136
+ " return sorted(\n",
137
+ " themes,\n",
138
+ " key=lambda theme: priority_order.index(theme)\n",
139
+ " if theme in priority_order else float(\"inf\")\n",
140
+ " )"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "e71c50e1-c159-4767-b366-f23153199e96",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "en_ds = {}\n",
151
+ "\n",
152
+ "files = ['art_en.json', 'art_en2.json', 'art_en2_specific.json']\n",
153
+ "ts = set()\n",
154
+ "for file in files:\n",
155
+ " fo = open('raw_articles/' + file)\n",
156
+ " for line in fo:\n",
157
+ " d = json.loads(line)\n",
158
+ " art = {\"id\": d[\"_id\"], **d['_source']}\n",
159
+ " if not art.get('nlp'):\n",
160
+ " continue\n",
161
+ " if not art['nlp'].get('new_embedding') or not art['nlp'].get('theme'):\n",
162
+ " continue\n",
163
+ " th = hash(art['title'].lower())\n",
164
+ " if th in ts:\n",
165
+ " continue\n",
166
+ " ts.add(th)\n",
167
+ " sts = art['nlp']['theme']\n",
168
+ " sts = sort_themes_by_priority(sts)\n",
169
+ " for st in sts:\n",
170
+ " if st not in en_ds:\n",
171
+ " en_ds[st] = []\n",
172
+ " if len(en_ds[st]) >= 80000:\n",
173
+ " continue\n",
174
+ " en_ds[st].append(art)\n",
175
+ " break"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "71b6a2b2-20c1-4fb7-a10d-2c4a5eb1d1e4",
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "for st in en_ds:\n",
186
+ " print(st, '==>', len(en_ds[st]))"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "id": "e8ffff0a-94ac-4d50-8ca1-179d163c458f",
193
+ "metadata": {},
194
+ "outputs": [],
195
+ "source": [
196
+ "json.dump(en_ds, open('filtered/en_ds.json', 'w'))"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "id": "0cdd6962-d7f8-4c1b-aea5-979b0f622308",
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "from fast_langdetect import detect"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "id": "cdb31010-0d34-47dd-8843-2fd417cea31f",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "ml_ds = {}\n",
217
+ "\n",
218
+ "files = ['art_non_en.json', 'art_non_en2.json', 'art_non_en_specific.json']\n",
219
+ "ts = set()\n",
220
+ "for file in files:\n",
221
+ " fo = open('raw_articles/' + file)\n",
222
+ " for line in fo:\n",
223
+ " d = json.loads(line)\n",
224
+ " art = {\"id\": d[\"_id\"], **d['_source']}\n",
225
+ " if not art.get('nlp'):\n",
226
+ " continue\n",
227
+ " if not art['nlp'].get('new_embedding') or not art['nlp'].get('theme'):\n",
228
+ " continue\n",
229
+ " th = hash(art['title'].lower())\n",
230
+ " if th in ts:\n",
231
+ " continue\n",
232
+ " ts.add(th)\n",
233
+ "\n",
234
+ " lang = detect(text=art['content'], model='full', k=1)[0]['lang']\n",
235
+ " if lang not in ml_ds:\n",
236
+ " ml_ds[lang] = {}\n",
237
+ "\n",
238
+ " sts = art['nlp']['theme']\n",
239
+ " sts = sort_themes_by_priority(sts)\n",
240
+ " for st in sts:\n",
241
+ " if st not in ml_ds[lang]:\n",
242
+ " ml_ds[lang][st] = []\n",
243
+ " if len(ml_ds[lang][st]) >= 5000:\n",
244
+ " continue\n",
245
+ " ml_ds[lang][st].append(art)\n",
246
+ " break"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "id": "0dbe9913-f48f-4f08-9c25-89e9b4e9dc71",
253
+ "metadata": {},
254
+ "outputs": [],
255
+ "source": []
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "222f0b88-0808-426b-8ace-f57727265f79",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "for lang in ml_ds:\n",
265
+ " xx = {t: len(ml_ds[lang][t]) for t in ml_ds[lang]}\n",
266
+ " print(lang, '==>', xx)"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "id": "127ce293-80b4-4f9d-973f-8293df8865ad",
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "for lang in ml_ds:\n",
277
+ " tot = sum(len(ml_ds[lang][t]) for t in ml_ds[lang])\n",
278
+ " if tot < 1000:\n",
279
+ " continue\n",
280
+ " json.dump(ml_ds[lang], open('filtered/' + lang + '_ds.json', 'w'))"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "id": "44328527-b38f-4a69-90cd-92fbaa1e1fd3",
286
+ "metadata": {},
287
+ "source": [
288
+ "# 2. choose MMR candidates (take atleast 15K each theme and 100 per each lang, en=7K, other=8K)"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "id": "9ac418a0-0efd-4c77-ae86-4d8c8eb107c5",
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "#!/usr/bin/env python3\n",
299
+ "from __future__ import annotations\n",
300
+ "\n",
301
+ "from typing import Any, Dict, List, Optional\n",
302
+ "import numpy as np\n",
303
+ "import torch\n",
304
+ "import faiss\n",
305
+ "import faiss.contrib.torch_utils # enables passing torch tensors to faiss on GPU\n",
306
+ "\n",
307
+ "\n",
308
+ "def torch_l2_normalize_(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:\n",
309
+ " # in-place row-wise normalize\n",
310
+ " norms = torch.linalg.norm(x, dim=1, keepdim=True).clamp_min(eps)\n",
311
+ " x.div_(norms)\n",
312
+ " return x\n",
313
+ "\n",
314
+ "\n",
315
+ "class FaissGpuDiverseSelectorTorch:\n",
316
+ " \"\"\"\n",
317
+ " GPU-only pipeline:\n",
318
+ " - build embeddings tensor on GPU (torch)\n",
319
+ " - normalize on GPU\n",
320
+ " - build GpuIndexFlatIP (cosine via normalization)\n",
321
+ " - diversity selection via GPU spherical clustering (on a SAMPLE) + GPU NN mapping\n",
322
+ " \"\"\"\n",
323
+ "\n",
324
+ " def __init__(\n",
325
+ " self,\n",
326
+ " embedding_key: str = \"new_embedding\",\n",
327
+ " gpu_id: int = 0,\n",
328
+ " seed: int = 123,\n",
329
+ " temp_mem_gb: float = 6.0,\n",
330
+ " use_float16_storage: bool = True,\n",
331
+ " build_batch_size: int = 16384, # batching reduces peak CPU RAM + helps conversion\n",
332
+ " ):\n",
333
+ " self.embedding_key = embedding_key\n",
334
+ " self.gpu_id = int(gpu_id)\n",
335
+ " self.seed = int(seed)\n",
336
+ " self.temp_mem_bytes = int(temp_mem_gb * 1024**3)\n",
337
+ " self.use_float16_storage = bool(use_float16_storage)\n",
338
+ " self.build_batch_size = int(build_batch_size)\n",
339
+ "\n",
340
+ " self.items: Optional[List[Dict[str, Any]]] = None\n",
341
+ " self.xb: Optional[torch.Tensor] = None # (N,d) on GPU, float32 normalized\n",
342
+ " self.res: Optional[faiss.StandardGpuResources] = None\n",
343
+ " self.index: Optional[faiss.GpuIndexFlatIP] = None\n",
344
+ " self.d: Optional[int] = None\n",
345
+ "\n",
346
+ " def fit(self, items: List[Dict[str, Any]]) -> \"FaissGpuDiverseSelectorTorch\":\n",
347
+ " ngpu = faiss.get_num_gpus()\n",
348
+ " if ngpu <= 0:\n",
349
+ " raise RuntimeError(\"faiss.get_num_gpus()==0. You are not using faiss-gpu / CUDA not visible.\")\n",
350
+ " if self.gpu_id >= ngpu:\n",
351
+ " raise RuntimeError(f\"gpu_id={self.gpu_id} out of range; FAISS sees {ngpu} GPUs.\")\n",
352
+ "\n",
353
+ " self.items = items\n",
354
+ " n = len(items)\n",
355
+ " if n == 0:\n",
356
+ " raise ValueError(\"Empty items\")\n",
357
+ "\n",
358
+ " # Infer dim from first embedding\n",
359
+ " first = items[0][self.embedding_key]\n",
360
+ " d = len(first)\n",
361
+ " self.d = d\n",
362
+ "\n",
363
+ " # Build GPU tensor in batches (still CPU->GPU copy, but avoids extra NumPy overhead)\n",
364
+ " dev = torch.device(f\"cuda:{self.gpu_id}\")\n",
365
+ " chunks: List[torch.Tensor] = []\n",
366
+ "\n",
367
+ " bs = self.build_batch_size\n",
368
+ " for start in range(0, n, bs):\n",
369
+ " end = min(start + bs, n)\n",
370
+ " # NOTE: this Python loop is unavoidable with list-of-dicts.\n",
371
+ " # If you can store embeddings as a single array/tensor upstream, do that instead.\n",
372
+ " batch = np.asarray([items[i][self.embedding_key] for i in range(start, end)], dtype=np.float32)\n",
373
+ " t = torch.from_numpy(batch).to(device=dev, non_blocking=False)\n",
374
+ " chunks.append(t)\n",
375
+ "\n",
376
+ " xb = torch.cat(chunks, dim=0) # (N,d) float32 on GPU\n",
377
+ "\n",
378
+ " # Normalize on GPU\n",
379
+ " torch_l2_normalize_(xb)\n",
380
+ "\n",
381
+ " # Create and keep GPU resources + big temp memory\n",
382
+ " res = faiss.StandardGpuResources()\n",
383
+ " res.setTempMemory(self.temp_mem_bytes)\n",
384
+ " self.res = res\n",
385
+ "\n",
386
+ " # GPU IP index (cosine via normalization)\n",
387
+ " cfg = faiss.GpuIndexFlatConfig()\n",
388
+ " cfg.device = self.gpu_id\n",
389
+ " cfg.useFloat16 = self.use_float16_storage\n",
390
+ "\n",
391
+ " index = faiss.GpuIndexFlatIP(res, d, cfg)\n",
392
+ "\n",
393
+ " # Add directly from GPU torch tensor (faiss.contrib.torch_utils)\n",
394
+ " index.add(xb)\n",
395
+ "\n",
396
+ " self.xb = xb\n",
397
+ " self.index = index\n",
398
+ " return self\n",
399
+ "\n",
400
+ " def select(\n",
401
+ " self,\n",
402
+ " num_select: int,\n",
403
+ " train_per_centroid: int = 500, # IMPORTANT: keep this modest to avoid “forever”\n",
404
+ " niter: int = 6,\n",
405
+ " nredo: int = 1,\n",
406
+ " centroid_search_k: int = 128,\n",
407
+ " ) -> List[Dict[str, Any]]:\n",
408
+ " if self.items is None or self.xb is None or self.index is None or self.res is None or self.d is None:\n",
409
+ " raise RuntimeError(\"Call fit() first.\")\n",
410
+ "\n",
411
+ " items = self.items\n",
412
+ " xb = self.xb\n",
413
+ " index = self.index\n",
414
+ " res = self.res\n",
415
+ " d = self.d\n",
416
+ "\n",
417
+ " N = xb.shape[0]\n",
418
+ " k = min(int(num_select), int(N))\n",
419
+ " if k <= 0:\n",
420
+ " return []\n",
421
+ "\n",
422
+ " # Sample training data ON GPU\n",
423
+ " train_sz = min(int(N), k * int(train_per_centroid))\n",
424
+ " g = torch.Generator(device=xb.device)\n",
425
+ " g.manual_seed(self.seed)\n",
426
+ "\n",
427
+ " if train_sz < N:\n",
428
+ " perm = torch.randperm(int(N), generator=g, device=xb.device)\n",
429
+ " train_idx = perm[:train_sz]\n",
430
+ " xtrain = xb.index_select(0, train_idx)\n",
431
+ " else:\n",
432
+ " xtrain = xb\n",
433
+ "\n",
434
+ " # FAISS GPU spherical clustering\n",
435
+ " clus = faiss.Clustering(d, k)\n",
436
+ " clus.seed = self.seed\n",
437
+ " clus.niter = int(niter)\n",
438
+ " clus.nredo = int(nredo)\n",
439
+ " clus.spherical = True\n",
440
+ " clus.verbose = False\n",
441
+ "\n",
442
+ " cfg = faiss.GpuIndexFlatConfig()\n",
443
+ " cfg.device = self.gpu_id\n",
444
+ " cfg.useFloat16 = self.use_float16_storage\n",
445
+ " assign_index = faiss.GpuIndexFlatIP(res, d, cfg)\n",
446
+ "\n",
447
+ " clus.train(xtrain.cpu().numpy(), assign_index)\n",
448
+ "\n",
449
+ " centroids = faiss.vector_to_array(clus.centroids).reshape(k, d).astype(np.float32, copy=False)\n",
450
+ " # Move centroids to GPU torch, normalize, then search on GPU\n",
451
+ " C = torch.from_numpy(centroids).to(device=xb.device)\n",
452
+ " torch_l2_normalize_(C)\n",
453
+ "\n",
454
+ " centroid_search_k = max(int(centroid_search_k), 1)\n",
455
+ " _, I = index.search(C, centroid_search_k) # I is a torch tensor (thanks to torch_utils)\n",
456
+ "\n",
457
+ " chosen: List[int] = []\n",
458
+ " used = set()\n",
459
+ "\n",
460
+ " I_cpu = I.to(\"cpu\").numpy() # small: (k, centroid_search_k)\n",
461
+ " for r in range(k):\n",
462
+ " pick = None\n",
463
+ " for j in range(centroid_search_k):\n",
464
+ " idx = int(I_cpu[r, j])\n",
465
+ " if idx >= 0 and idx not in used:\n",
466
+ " pick = idx\n",
467
+ " break\n",
468
+ " if pick is not None:\n",
469
+ " used.add(pick)\n",
470
+ " chosen.append(pick)\n",
471
+ "\n",
472
+ " # Fill if duplicates reduced count\n",
473
+ " if len(chosen) < k:\n",
474
+ " for r in range(k):\n",
475
+ " if len(chosen) >= k:\n",
476
+ " break\n",
477
+ " for j in range(centroid_search_k):\n",
478
+ " idx = int(I_cpu[r, j])\n",
479
+ " if idx >= 0 and idx not in used:\n",
480
+ " used.add(idx)\n",
481
+ " chosen.append(idx)\n",
482
+ " if len(chosen) >= k:\n",
483
+ " break\n",
484
+ "\n",
485
+ " return [items[i] for i in chosen[:k]]"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": null,
491
+ "id": "2ab2368e-7ec0-4870-b092-243b6c45bd70",
492
+ "metadata": {},
493
+ "outputs": [],
494
+ "source": []
495
+ },
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": null,
499
+ "id": "d5fd1433-1ffd-489f-8bc7-bca2d76a22d4",
500
+ "metadata": {},
501
+ "outputs": [],
502
+ "source": []
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": null,
507
+ "id": "c4a937d7-e6bf-4f98-bbcd-ed9fb746b731",
508
+ "metadata": {},
509
+ "outputs": [],
510
+ "source": [
511
+ "files = os.listdir('filtered')\n",
512
+ "files = [i for i in files if i != 'en_ds.json' and i.endswith('.json')]"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": null,
518
+ "id": "c3c93e4b-5f2e-4ad2-9b39-59942539bce5",
519
+ "metadata": {
520
+ "scrolled": true
521
+ },
522
+ "outputs": [],
523
+ "source": [
524
+ "non_en = {}\n",
525
+ "for file in files:\n",
526
+ " data = json.load(open('filtered/'+file))\n",
527
+ " for theme in data:\n",
528
+ " data[theme] = [{\"id\": i[\"id\"], \"title\": i[\"title\"], \"content\": i[\"content\"], \"language\": i[\"language\"], **i[\"nlp\"]} for i in data[theme]]\n",
529
+ " mmr_cands = len(data[theme])//5\n",
530
+ " if mmr_cands <= 2:\n",
531
+ " continue\n",
532
+ " selector = FaissGpuDiverseSelectorTorch(\n",
533
+ " embedding_key=\"new_embedding\",\n",
534
+ " gpu_id=0,\n",
535
+ " temp_mem_gb=6.0,\n",
536
+ " use_float16_storage=True,\n",
537
+ " build_batch_size=8192,\n",
538
+ " ).fit(data[theme])\n",
539
+ " picked = selector.select(\n",
540
+ " num_select=mmr_cands,\n",
541
+ " train_per_centroid=500,\n",
542
+ " niter=100,\n",
543
+ " centroid_search_k=512,\n",
544
+ " )\n",
545
+ " if theme not in non_en:\n",
546
+ " non_en[theme] = []\n",
547
+ " print(file, theme, len(data[theme]), len(picked), mmr_cands)\n",
548
+ " non_en[theme].extend(picked)\n"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": null,
554
+ "id": "499a2017-1bdc-44ea-b886-6a55bf7f8a36",
555
+ "metadata": {
556
+ "scrolled": true
557
+ },
558
+ "outputs": [],
559
+ "source": [
560
+ "final_ds = {}\n",
561
+ "for theme in non_en:\n",
562
+ " selector = FaissGpuDiverseSelectorTorch(\n",
563
+ " embedding_key=\"new_embedding\",\n",
564
+ " gpu_id=0,\n",
565
+ " temp_mem_gb=6.0,\n",
566
+ " use_float16_storage=True,\n",
567
+ " build_batch_size=8192,\n",
568
+ " ).fit(non_en[theme])\n",
569
+ " picked = selector.select(\n",
570
+ " num_select=7000,\n",
571
+ " train_per_centroid=500,\n",
572
+ " niter=100,\n",
573
+ " centroid_search_k=512,\n",
574
+ " )\n",
575
+ " if theme not in final_ds:\n",
576
+ " final_ds[theme] = []\n",
577
+ " print(theme, len(non_en[theme]), len(picked))\n",
578
+ " for i in picked:\n",
579
+ " del i['new_embedding']\n",
580
+ " final_ds[theme].extend(picked)\n",
581
+ "\n",
582
+ "json.dump(final_ds, open('filtered2/ml_step1.json', 'w'))"
583
+ ]
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": null,
588
+ "id": "7161b895-b229-47a6-9082-b493fbd7428a",
589
+ "metadata": {},
590
+ "outputs": [],
591
+ "source": [
592
+ "en = json.load(open('filtered/en_ds.json'))"
593
+ ]
594
+ },
595
+ {
596
+ "cell_type": "code",
597
+ "execution_count": null,
598
+ "id": "ce6213a5-40bb-454e-950b-b4b23e045e13",
599
+ "metadata": {
600
+ "scrolled": true
601
+ },
602
+ "outputs": [],
603
+ "source": [
604
+ "final_ds = {}\n",
605
+ "for theme in en:\n",
606
+ " en[theme] = [{\"id\": i[\"id\"], \"title\": i[\"title\"], \"content\": i[\"content\"], \"language\": i[\"language\"], **i[\"nlp\"]} for i in en[theme]]\n",
607
+ " selector = FaissGpuDiverseSelectorTorch(\n",
608
+ " embedding_key=\"new_embedding\",\n",
609
+ " gpu_id=0,\n",
610
+ " temp_mem_gb=6.0,\n",
611
+ " use_float16_storage=True,\n",
612
+ " build_batch_size=8192,\n",
613
+ " ).fit(en[theme])\n",
614
+ " picked = selector.select(\n",
615
+ " num_select=8000,\n",
616
+ " train_per_centroid=500,\n",
617
+ " niter=100,\n",
618
+ " centroid_search_k=512,\n",
619
+ " )\n",
620
+ " if theme not in final_ds:\n",
621
+ " final_ds[theme] = []\n",
622
+ " print(theme, len(en[theme]), len(picked))\n",
623
+ " for i in picked:\n",
624
+ " del i['new_embedding']\n",
625
+ " final_ds[theme].extend(picked)\n",
626
+ "\n",
627
+ "json.dump(final_ds, open('filtered2/en_step1.json', 'w'))"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "markdown",
632
+ "id": "842369c9-6624-404d-808b-1e39a4a6a1b0",
633
+ "metadata": {},
634
+ "source": [
635
+ "# 3.calculate embedding of qwen"
636
+ ]
637
+ },
638
+ {
639
+ "cell_type": "code",
640
+ "execution_count": null,
641
+ "id": "bfce123b-6d47-4353-b36c-3db708274742",
642
+ "metadata": {},
643
+ "outputs": [],
644
+ "source": [
645
+ "import requests\n",
646
+ "from requests.adapters import HTTPAdapter\n",
647
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
648
+ "from typing import List, Dict, Any, Optional, Tuple\n",
649
+ "\n",
650
+ "def _make_session(pool_size: int) -> requests.Session:\n",
651
+ " s = requests.Session()\n",
652
+ " adapter = HTTPAdapter(pool_connections=pool_size, pool_maxsize=pool_size, max_retries=0)\n",
653
+ " s.mount(\"http://\", adapter)\n",
654
+ " s.mount(\"https://\", adapter)\n",
655
+ " return s\n",
656
+ "\n",
657
+ "\n",
658
+ "def _embed_batch(\n",
659
+ " session: requests.Session,\n",
660
+ " base_url: str,\n",
661
+ " text: str,\n",
662
+ " timeout_s: float,\n",
663
+ " extra_headers = {'x-api-token': '***'},\n",
664
+ ") -> Any:\n",
665
+ " \"\"\"\n",
666
+ " Native TEI endpoint: POST {base_url}/embed\n",
667
+ " Payload: {\"inputs\": [..texts..]}\n",
668
+ " \"\"\"\n",
669
+ " url = base_url.rstrip(\"/\") + \"/qwen\"\n",
670
+ " headers = {\"Content-Type\": \"application/json\"}\n",
671
+ " if not text.startswith('query: '):\n",
672
+ " text = 'query: ' + text\n",
673
+ " if extra_headers:\n",
674
+ " headers.update(extra_headers)\n",
675
+ " r = session.post(url, json={\"inputs\": text}, headers=headers, timeout=timeout_s)\n",
676
+ " r.raise_for_status()\n",
677
+ " return r.json()\n",
678
+ "\n",
679
+ "def one_call(art):\n",
680
+ " if \"embedding\" in art:\n",
681
+ " return\n",
682
+ " text = []\n",
683
+ " if art.get(\"title\"):\n",
684
+ " text.append(art[\"title\"])\n",
685
+ " if art.get(\"content\"):\n",
686
+ " text.append(art[\"content\"])\n",
687
+ " text = \"\\n\\n\".join(text)\n",
688
+ "    try:\n",
689
+ "        res = _embed_batch(session, base_url, text, timeout_s=90)\n",
690
+ "        art[\"embedding\"] = res[0]\n",
691
+ "    except Exception as e:\n",
692
+ "        print(e)\n",
693
+ "\n",
694
+ "WORKERS = 512\n",
695
+ "session = _make_session(pool_size=WORKERS)\n",
696
+ "base_url = \"http://65.19.132.154:9001\""
697
+ ]
698
+ },
699
+ {
700
+ "cell_type": "code",
701
+ "execution_count": null,
702
+ "id": "e5e64eac-dc40-41ec-a497-45ce9cb2f637",
703
+ "metadata": {},
704
+ "outputs": [],
705
+ "source": []
706
+ },
707
+ {
708
+ "cell_type": "code",
709
+ "execution_count": null,
710
+ "id": "7304ec1b-2a7a-42f7-a0fb-f6e5e6c4eb94",
711
+ "metadata": {},
712
+ "outputs": [],
713
+ "source": [
714
+ "for file in ['en_step1.json', 'ml_step1.json']:\n",
715
+ " data = json.load(open('filtered2/' + file))\n",
716
+ " for theme in data:\n",
717
+ " with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
718
+ " _ = [ex.submit(one_call, art) for art in data[theme]]\n",
719
+ " json.dump(data, open('filtered2/' + file, 'w'))"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "markdown",
724
+ "id": "59b381bb-d4ca-4414-b299-3da7531faed3",
725
+ "metadata": {
726
+ "jp-MarkdownHeadingCollapsed": true
727
+ },
728
+ "source": [
729
+ "# 4. again choose MMR based on qwen embedding also make list of llm verify candidate when qwen emb not matching with given themes (5K each)"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "execution_count": null,
735
+ "id": "80e76f11-f5de-4c3e-b130-5646b4bf5273",
736
+ "metadata": {},
737
+ "outputs": [],
738
+ "source": [
739
+ "from fast_langdetect import detect\n",
740
+ "\n",
741
+ "en_data, ml_data = [], []\n",
742
+ "data = json.load(open('filtered2/en_step1.json'))\n",
743
+ "for theme in data:\n",
744
+ " en_data.extend(data[theme])\n",
745
+ "\n",
746
+ "data = json.load(open('filtered2/ml_step1.json'))\n",
747
+ "for theme in data:\n",
748
+ " ml_data.extend(data[theme])\n",
749
+ "\n",
750
+ "for file in ['train1.json', 'spanish_tagged.json', 'train_multilang_2.json']:\n",
751
+ " data = json.load(open('filtered2/' + file))\n",
752
+ " for i in data:\n",
753
+ " if not i.get('content'):\n",
754
+ " continue\n",
755
+ " l = i.get('language')\n",
756
+ " if not l:\n",
757
+ "            l = i.get('lang')\n",
758
+ " if not l:\n",
759
+ " l = detect(text=i['content'], model='full', k=1)[0]['lang']\n",
760
+ " t = {**i, \"language\": l, \"file\": file}\n",
761
+ " if l == 'en':\n",
762
+ " en_data.append(t)\n",
763
+ " else:\n",
764
+ " ml_data.append(t)"
765
+ ]
766
+ },
767
+ {
768
+ "cell_type": "code",
769
+ "execution_count": null,
770
+ "id": "1a435d53-1cf9-46e3-809a-208bb0a76122",
771
+ "metadata": {},
772
+ "outputs": [],
773
+ "source": [
774
+ "# for i in ml_data:\n",
775
+ "# if \"embedding\" in i:\n",
776
+ "# del i[\"embedding\"]"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": null,
782
+ "id": "b5979e1a-4e94-45ee-937d-102550918ef7",
783
+ "metadata": {},
784
+ "outputs": [],
785
+ "source": [
786
+ "with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
787
+ " _ = [ex.submit(one_call, art) for art in en_data]"
788
+ ]
789
+ },
790
+ {
791
+ "cell_type": "code",
792
+ "execution_count": null,
793
+ "id": "63e62230-bc41-442c-a612-08b22abdd0b3",
794
+ "metadata": {},
795
+ "outputs": [],
796
+ "source": [
797
+ "with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
798
+ " _ = [ex.submit(one_call, art) for art in ml_data]"
799
+ ]
800
+ },
801
+ {
802
+ "cell_type": "code",
803
+ "execution_count": null,
804
+ "id": "59d5c0cf-37ab-406c-ad53-042c7af451cd",
805
+ "metadata": {},
806
+ "outputs": [],
807
+ "source": [
808
+ "len(en_data), len(ml_data)"
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "execution_count": null,
814
+ "id": "3e5543ef-53d6-4e97-923d-203a0382e878",
815
+ "metadata": {},
816
+ "outputs": [],
817
+ "source": [
818
+ "json.dump(en_data, open('filtered/en_step2.json', 'w'))"
819
+ ]
820
+ },
821
+ {
822
+ "cell_type": "code",
823
+ "execution_count": null,
824
+ "id": "b8cedffd-36f8-4684-8126-2b6b78d67aec",
825
+ "metadata": {},
826
+ "outputs": [],
827
+ "source": [
828
+ "json.dump(ml_data, open('filtered/ml_step2.json', 'w'))"
829
+ ]
830
+ },
831
+ {
832
+ "cell_type": "code",
833
+ "execution_count": null,
834
+ "id": "b5c6d123-12a2-4275-a4b0-1d43951ed922",
835
+ "metadata": {},
836
+ "outputs": [],
837
+ "source": [
838
+ "# easy to confirm with theme embedding first and find mismatches\n",
839
+ "themes = [\n",
840
+ " \"Economics\",\n",
841
+ " \"Financial Crime\",\n",
842
+ " \"Finance\",\n",
843
+ " \"Lifestyle\",\n",
844
+ " \"Automotive\",\n",
845
+ " \"Science\",\n",
846
+ " \"Tech\",\n",
847
+ " \"Travel\",\n",
848
+ " \"Weather\",\n",
849
+ " \"Health\",\n",
850
+ " \"Crime\",\n",
851
+ " \"Sports\",\n",
852
+ " \"General\",\n",
853
+ " \"Business\",\n",
854
+ " \"Politics\",\n",
855
+ " \"Entertainment\",\n",
856
+ " ]\n",
857
+ "theme_emb = {i: _embed_batch(session, base_url, i, timeout_s=90)[0] for i in themes}"
858
+ ]
859
+ },
860
+ {
861
+ "cell_type": "code",
862
+ "execution_count": null,
863
+ "id": "d041e14a-cb28-4714-afbe-582ccb64fcbb",
864
+ "metadata": {},
865
+ "outputs": [],
866
+ "source": [
867
+ "import numpy as np\n",
868
+ "\n",
869
+ "def l2_normalize(x, axis=-1, eps=1e-12):\n",
870
+ " x = np.asarray(x, dtype=np.float32)\n",
871
+ " return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)\n",
872
+ "\n",
873
+ "def prepare_theme_matrix(theme_to_emb: dict):\n",
874
+ " theme_names = list(theme_to_emb.keys())\n",
875
+ " theme_mat = np.vstack([theme_to_emb[t] for t in theme_names]).astype(np.float32)\n",
876
+ " theme_mat = l2_normalize(theme_mat, axis=1) # (T, D)\n",
877
+ " return theme_names, theme_mat\n",
878
+ "\n",
879
+ "def similarities(article_mat, theme_mat):\n",
880
+ " # article_mat: (N, D) theme_mat: (T, D)\n",
881
+ " article_mat = l2_normalize(article_mat, axis=1)\n",
882
+ " return article_mat @ theme_mat.T # (N, T) cosine similarities\n",
883
+ "\n",
884
+ "def pick_by_gap_row(scores, gap=0.10, min_score=0.25, max_k=3):\n",
885
+ " # scores: (T,) raw similarity\n",
886
+ " idx = np.argsort(scores)[::-1]\n",
887
+ " s = scores[idx]\n",
888
+ "\n",
889
+ " # keep while above floor\n",
890
+ " keep = np.where(s >= min_score)[0]\n",
891
+ " if keep.size == 0:\n",
892
+ " return idx[:0] # none\n",
893
+ " last = keep[-1]\n",
894
+ "\n",
895
+ " # stop at first big drop\n",
896
+ " drops = s[:-1] - s[1:]\n",
897
+ " cut_points = np.where(drops >= gap)[0]\n",
898
+ " if cut_points.size > 0:\n",
899
+ " last = min(last, cut_points[0])\n",
900
+ "\n",
901
+ " last = min(last, max_k - 1) # cap number of themes\n",
902
+ " return idx[: last + 1]\n",
903
+ "\n",
904
+ "def assign_themes(article_mat, theme_to_emb, gap=0.08, min_score=0.2, max_k=5):\n",
905
+ " theme_names, theme_mat = prepare_theme_matrix(theme_to_emb)\n",
906
+ " sim = similarities(article_mat, theme_mat) # (N, T)\n",
907
+ "\n",
908
+ " results = []\n",
909
+ " for i in range(sim.shape[0]):\n",
910
+ " chosen_idx = pick_by_gap_row(sim[i], gap=gap, min_score=min_score, max_k=max_k)\n",
911
+ " results.append([theme_names[j] for j in chosen_idx])\n",
912
+ " return results\n",
913
+ "\n",
914
+ "def is_prediction_on_top(my_prediction, emb_prediction):\n",
915
+ " \"\"\"\n",
916
+ " Returns True if:\n",
917
+ " - len(my_prediction)==1: the single predicted theme is the top-1 embedding theme\n",
918
+ " - len(my_prediction)>1: the top-K embedding themes (K=len(my_prediction)) match my_prediction as a set\n",
919
+ " (order-insensitive), meaning all predicted themes are \"on top\" together.\n",
920
+ " \"\"\"\n",
921
+ " if not my_prediction or not emb_prediction:\n",
922
+ " return False\n",
923
+ "\n",
924
+ " k = len(my_prediction)\n",
925
+ " topk = emb_prediction[:k]\n",
926
+ "\n",
927
+ " if k == 1:\n",
928
+ " return my_prediction[0] == emb_prediction[0]\n",
929
+ "\n",
930
+ " return set(my_prediction) == set(topk)"
931
+ ]
932
+ },
933
+ {
934
+ "cell_type": "code",
935
+ "execution_count": null,
936
+ "id": "7c72387e-4343-47f0-88fd-ef964a03fbcd",
937
+ "metadata": {},
938
+ "outputs": [],
939
+ "source": [
940
+ "is_prediction_on_top(en_data[500]['themes'], assign_themes([en_data[500]['embedding']], theme_emb)[0])"
941
+ ]
942
+ },
943
+ {
944
+ "cell_type": "code",
945
+ "execution_count": null,
946
+ "id": "61121189-f649-4660-a4d0-973cae54de6e",
947
+ "metadata": {},
948
+ "outputs": [],
949
+ "source": [
950
+ "results = assign_themes([i['embedding'] for i in en_data], theme_emb)\n",
951
+ "nb = 0\n",
952
+ "for i, j in zip(en_data, results):\n",
953
+ " if is_prediction_on_top(i['themes'], j):\n",
954
+ " i['need_to_validate'] = False\n",
955
+ " nb += 1\n",
956
+ " else:\n",
957
+ " i['need_to_validate'] = True\n",
958
+ "nb, len(en_data)"
959
+ ]
960
+ },
961
+ {
962
+ "cell_type": "code",
963
+ "execution_count": null,
964
+ "id": "2895b09b-130c-4c57-b0e3-4b8bcbc41f0c",
965
+ "metadata": {},
966
+ "outputs": [],
967
+ "source": [
968
+ "results = assign_themes([i['embedding'] for i in ml_data], theme_emb)\n",
969
+ "nb = 0\n",
970
+ "for i, j in zip(ml_data, results):\n",
971
+ " if is_prediction_on_top(i['themes'], j):\n",
972
+ " i['need_to_validate'] = False\n",
973
+ " nb += 1\n",
974
+ " else:\n",
975
+ " i['need_to_validate'] = True\n",
976
+ "nb, len(ml_data)"
977
+ ]
978
+ },
979
+ {
980
+ "cell_type": "code",
981
+ "execution_count": null,
982
+ "id": "446c19a3-afab-4100-964a-093969034e91",
983
+ "metadata": {},
984
+ "outputs": [],
985
+ "source": [
986
+ "import uuid\n",
987
+ "str(uuid.uuid4())"
988
+ ]
989
+ },
990
+ {
991
+ "cell_type": "code",
992
+ "execution_count": null,
993
+ "id": "93dc427c-cebd-4376-a565-ff9e8283077f",
994
+ "metadata": {},
995
+ "outputs": [],
996
+ "source": [
997
+ "nid, nt = 0, 0\n",
998
+ "for i in en_data + ml_data:\n",
999
+ " if not i.get('id'):\n",
1000
+ " i['id'] = str(uuid.uuid4())\n",
1001
+ " nid += 1\n",
1002
+ " if not i.get('title'):\n",
1003
+ " nt += 1\n",
1004
+ "len(en_data + ml_data), nid, nt"
1005
+ ]
1006
+ },
1007
+ {
1008
+ "cell_type": "code",
1009
+ "execution_count": null,
1010
+ "id": "edfe0095-a239-4df6-9c1c-8d8f6033892b",
1011
+ "metadata": {},
1012
+ "outputs": [],
1013
+ "source": [
1014
+ "json.dump(en_data, open('filtered2/en_step2.json', 'w'))"
1015
+ ]
1016
+ },
1017
+ {
1018
+ "cell_type": "code",
1019
+ "execution_count": null,
1020
+ "id": "6ed59f14-c070-41a5-bf81-2f21fb89e797",
1021
+ "metadata": {},
1022
+ "outputs": [],
1023
+ "source": [
1024
+ "json.dump(ml_data, open('filtered2/ml_step2.json', 'w'))"
1025
+ ]
1026
+ },
1027
+ {
1028
+ "cell_type": "code",
1029
+ "execution_count": null,
1030
+ "id": "67d594c0-6eab-4191-855e-b98e1d490649",
1031
+ "metadata": {},
1032
+ "outputs": [],
1033
+ "source": []
1034
+ },
1035
+ {
1036
+ "cell_type": "code",
1037
+ "execution_count": null,
1038
+ "id": "7d612263-349d-4ae6-81f3-baf643452c1c",
1039
+ "metadata": {
1040
+ "scrolled": true
1041
+ },
1042
+ "outputs": [],
1043
+ "source": [
1044
+ "en_themewise = {}\n",
1045
+ "for i in en_data:\n",
1046
+ " sts = sort_themes_by_priority(i['themes'])\n",
1047
+ " for st in sts:\n",
1048
+ " if st not in en_themewise:\n",
1049
+ " en_themewise[st] = []\n",
1050
+ " en_themewise[st].append(i)\n",
1051
+ " break\n",
1052
+ "for i in en_themewise:\n",
1053
+ " tcc = len(en_themewise[i])\n",
1054
+ " selector = FaissGpuDiverseSelectorTorch(\n",
1055
+ " embedding_key=\"embedding\",\n",
1056
+ " gpu_id=0,\n",
1057
+ " temp_mem_gb=6.0,\n",
1058
+ " use_float16_storage=True,\n",
1059
+ " build_batch_size=8192,\n",
1060
+ " ).fit(en_themewise[i])\n",
1061
+ " picked = selector.select(\n",
1062
+ " num_select=5000,\n",
1063
+ " train_per_centroid=500,\n",
1064
+ " niter=100,\n",
1065
+ " centroid_search_k=512,\n",
1066
+ " )\n",
1067
+ " en_themewise[i] = picked\n",
1068
+ " print(i, '==>', tcc, len(picked))\n",
1069
+ "en_data = []\n",
1070
+ "for theme in en_themewise:\n",
1071
+ " en_data.extend(en_themewise[theme])\n",
1072
+ "json.dump(en_data, open('filtered2/en_step3.json', 'w'))"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "cell_type": "code",
1077
+ "execution_count": null,
1078
+ "id": "f5d4d3c4-9248-46b2-ba52-51dd4e22838d",
1079
+ "metadata": {
1080
+ "scrolled": true
1081
+ },
1082
+ "outputs": [],
1083
+ "source": [
1084
+ "ml_themewise = {}\n",
1085
+ "for i in ml_data:\n",
1086
+ " sts = sort_themes_by_priority(i['themes'])\n",
1087
+ " for st in sts:\n",
1088
+ " if st not in ml_themewise:\n",
1089
+ " ml_themewise[st] = []\n",
1090
+ " ml_themewise[st].append(i)\n",
1091
+ " break\n",
1092
+ "for i in ml_themewise:\n",
1093
+ " print(i, '==>', len(ml_themewise[i]))\n",
1094
+ " tcc = len(ml_themewise[i])\n",
1095
+ " selector = FaissGpuDiverseSelectorTorch(\n",
1096
+ " embedding_key=\"embedding\",\n",
1097
+ " gpu_id=0,\n",
1098
+ " temp_mem_gb=6.0,\n",
1099
+ " use_float16_storage=True,\n",
1100
+ " build_batch_size=8192,\n",
1101
+ " ).fit(ml_themewise[i])\n",
1102
+ " picked = selector.select(\n",
1103
+ " num_select=5000,\n",
1104
+ " train_per_centroid=500,\n",
1105
+ " niter=100,\n",
1106
+ " centroid_search_k=512,\n",
1107
+ " )\n",
1108
+ " ml_themewise[i] = picked\n",
1109
+ " print(i, '==>', tcc, len(picked))\n",
1110
+ "ml_data = []\n",
1111
+ "for theme in ml_themewise:\n",
1112
+ " ml_data.extend(ml_themewise[theme])\n",
1113
+ "json.dump(ml_data, open('filtered2/ml_step3.json', 'w'))\n",
1114
+ "len(ml_data)"
1115
+ ]
1116
+ },
1117
+ {
1118
+ "cell_type": "markdown",
1119
+ "id": "17b1492a-0b40-43b3-9a46-3fea7fd4a7f6",
1120
+ "metadata": {},
1121
+ "source": [
1122
+ "# 5. confirm with LLM for selected candidates"
1123
+ ]
1124
+ },
1125
+ {
1126
+ "cell_type": "code",
1127
+ "execution_count": null,
1128
+ "id": "033d00e6-9068-4705-a8a5-7a0aa4828d52",
1129
+ "metadata": {},
1130
+ "outputs": [],
1131
+ "source": [
1132
+ "nb = 0\n",
1133
+ "for i in en_data + ml_data:\n",
1134
+ " if i.get('need_to_validate'):\n",
1135
+ " nb += 1\n",
1136
+ "nb, len(en_data + ml_data)"
1137
+ ]
1138
+ },
1139
+ {
1140
+ "cell_type": "code",
1141
+ "execution_count": 2,
1142
+ "id": "6701aa6e-2797-4083-9496-1308d71ffbde",
1143
+ "metadata": {},
1144
+ "outputs": [],
1145
+ "source": [
1146
+ "import json\n",
1147
+ "import logging\n",
1148
+ "import time\n",
1149
+ "import random\n",
1150
+ "\n",
1151
+ "import openai\n",
1152
+ "\n",
1153
+ "OPENAI_KEY = '***'\n",
1154
+ "MAX_OPENAI_RETRIES = 5\n",
1155
+ "\n",
1156
+ "openai.api_key = OPENAI_KEY\n",
1157
+ "\n",
1158
+ "\n",
1159
+ "def extract_function_data(\n",
1160
+ " function: dict,\n",
1161
+ " content: str,\n",
1162
+ " target_model: str = 'gpt-4o-mini',\n",
1163
+ " role: str = \"You are a zero-shot classification model.\",\n",
1164
+ " retries: int = 0,\n",
1165
+ " extra_tags: dict = None\n",
1166
+ "):\n",
1167
+ " retries += 1\n",
1168
+ " if extra_tags:\n",
1169
+ " openpipe_tags.update(extra_tags)\n",
1170
+ " if retries > MAX_OPENAI_RETRIES:\n",
1171
+ " logging.error(\"Failed to Extract Event Data\", extra={\n",
1172
+ " \"content_length\": len(content.split(\" \")),\n",
1173
+ " \"error_message\": f\"Failed to extract function data after {MAX_OPENAI_RETRIES} retries\"\n",
1174
+ " })\n",
1175
+ " return None\n",
1176
+ " try:\n",
1177
+ " response = openai.chat.completions.create(\n",
1178
+ " model=target_model,\n",
1179
+ " messages=[{\"role\": \"system\", \"content\": role},\n",
1180
+ " {\"role\": \"user\", \"content\": content}],\n",
1181
+ " functions=[function],\n",
1182
+ " function_call={\"name\": function[\"name\"]},\n",
1183
+ " temperature=0\n",
1184
+ " )\n",
1185
+ "\n",
1186
+ " # Except rate limit error\n",
1187
+ " except openai.RateLimitError as e:\n",
1188
+ " # Ask the question again with 3/4 of the content\n",
1189
+ " logging.error(\"Retrying OpenAI Request\", extra={\n",
1190
+ " \"error\": str(e),\n",
1191
+ " \"error_message\": \"API limit reached, try again after 5 seconds\",\n",
1192
+ " })\n",
1193
+ " # sleep for random time between 5 to 15 seconds\n",
1194
+ " time.sleep(random.randint(5, 15))\n",
1195
+ "        return extract_function_data(function, content, target_model, retries=retries)\n",
1196
+ "\n",
1197
+ " # Except all other errors\n",
1198
+ " except Exception as e:\n",
1199
+ " logging.error(\"Retrying OpenAI Request\", extra={\n",
1200
+ " \"error\": str(e),\n",
1201
+ " \"error_message\": \"Unkown Error, Retrying with 3/4 of the content\",\n",
1202
+ " })\n",
1203
+ " logging.error(e)\n",
1204
+ " content_length = len(content.split(\" \"))\n",
1205
+ " updated_content = \" \".join(content.split(\" \")[:int(content_length * 0.75)])\n",
1206
+ "\n",
1207
+ "        return extract_function_data(function, updated_content, target_model,\n",
1208
+ "                                     retries=retries)\n",
1209
+ "\n",
1210
+ " try:\n",
1211
+ " result = json.loads(response.choices[0].message.function_call.arguments)\n",
1212
+ " ## Remove Keys with empty string value because OpenAI is returning empty strings as keys sometimes\n",
1213
+ " result = {k: v for k, v in result.items() if v not in [\"\", [], None]}\n",
1214
+ "\n",
1215
+ " logging.info(\"Successfully extracted Event Data\", extra={\n",
1216
+ " \"error_message\": \"Content Classified Successfully\",\n",
1217
+ " \"content\": str(result)\n",
1218
+ " })\n",
1219
+ " return result\n",
1220
+ " except Exception as e:\n",
1221
+ " logging.error(\"Failed to Parse JSON from OpenAI in Extraction\", extra={\n",
1222
+ " \"error\": str(e),\n",
1223
+ " \"openai_response\": str(response),\n",
1224
+ " })\n",
1225
+ " return None"
1226
+ ]
1227
+ },
1228
+ {
1229
+ "cell_type": "code",
1230
+ "execution_count": null,
1231
+ "id": "d2992602-ab1b-43ad-a2c2-ede79671affc",
1232
+ "metadata": {},
1233
+ "outputs": [],
1234
+ "source": [
1235
+ "int_theme_classification_function = {\n",
1236
+ " \"name\": \"theme_classification\",\n",
1237
+ " \"description\": \"please classify below news article into any of these categories, ['Health', 'Entertainment', 'Business', 'Science', 'Politics', 'Finance', 'Economics', 'Tech', 'Crime', 'Sports', 'Lifestyle', 'Automotive', 'Travel', 'Weather', 'General', 'Financial Crime'], You should return \\\"General\\\" only if article doesnt fall into any of these, Please classigy it carefully. Don't tell me bad result before confirming. also don't return all categories, return general if it not fits in any of above\",\n",
1238
+ " \"parameters\": {\n",
1239
+ " \"type\": \"object\",\n",
1240
+ " \"properties\": {'Health': {'type': 'boolean'},\n",
1241
+ " 'Entertainment': {'type': 'boolean'},\n",
1242
+ " 'Business': {'type': 'boolean'}, 'Science': {'type': 'boolean'},\n",
1243
+ " 'Politics': {'type': 'boolean'}, 'Finance': {'type': 'boolean'},\n",
1244
+ " 'Economics': {'type': 'boolean'}, 'Tech': {'type': 'boolean'},\n",
1245
+ " 'Crime': {'type': 'boolean'}, 'Sports': {'type': 'boolean'},\n",
1246
+ " 'Lifestyle': {'type': 'boolean'},\n",
1247
+ " 'Automotive': {'type': 'boolean'}, 'Travel': {'type': 'boolean'},\n",
1248
+ " 'Weather': {'type': 'boolean'}, 'General': {'type': 'boolean'}, 'Financial Crime': {'type': 'boolean'}},\n",
1249
+ " \"required\": []\n",
1250
+ " }\n",
1251
+ "}"
1252
+ ]
1253
+ },
1254
+ {
1255
+ "cell_type": "code",
1256
+ "execution_count": null,
1257
+ "id": "245b1afc-7926-47fb-9f4b-fdae08edb4f0",
1258
+ "metadata": {},
1259
+ "outputs": [],
1260
+ "source": [
1261
+ "for i in en_data[::-1]:\n",
1262
+ " if i['need_to_validate']:\n",
1263
+ " break"
1264
+ ]
1265
+ },
1266
+ {
1267
+ "cell_type": "code",
1268
+ "execution_count": null,
1269
+ "id": "05314580-6e86-4fb6-a554-2d357deb21b1",
1270
+ "metadata": {},
1271
+ "outputs": [],
1272
+ "source": [
1273
+ "themes = extract_function_data(int_theme_classification_function, \"Title:\" + i['title'] + '\\n\\nContent:' + i['content'], target_model='gpt-4.1-mini')\n",
1274
+ "themes = [i for i in themes if themes[i]]\n",
1275
+ "themes"
1276
+ ]
1277
+ },
1278
+ {
1279
+ "cell_type": "code",
1280
+ "execution_count": null,
1281
+ "id": "fc20aeb3-8552-4873-bda2-37e498c04d49",
1282
+ "metadata": {},
1283
+ "outputs": [],
1284
+ "source": [
1285
+ "def validate(art):\n",
1286
+ " if not art.get('need_to_validate'):\n",
1287
+ " return\n",
1288
+ " if \"prev_themes\" in art:\n",
1289
+ " return\n",
1290
+ " content = \"\"\n",
1291
+ " if art.get(\"title\"):\n",
1292
+ " content = \"Title:\" + art['title'] + '\\n\\nContent:' + art['content']\n",
1293
+ " else:\n",
1294
+ " content = art['content']\n",
1295
+ " themes = extract_function_data(int_theme_classification_function, content, target_model='gpt-4.1-mini')\n",
1296
+ " themes = [i for i in themes if themes[i]]\n",
1297
+ " if themes:\n",
1298
+ " art['prev_themes'] = art.pop('themes')\n",
1299
+ " art['themes'] = themes"
1300
+ ]
1301
+ },
1302
+ {
1303
+ "cell_type": "code",
1304
+ "execution_count": null,
1305
+ "id": "11ca3cca-df82-4c92-aac8-58c6722d5ed7",
1306
+ "metadata": {},
1307
+ "outputs": [],
1308
+ "source": [
1309
+ "with ThreadPoolExecutor(max_workers=128) as ex:\n",
1310
+ " _ = [ex.submit(validate, art) for art in en_data]"
1311
+ ]
1312
+ },
1313
+ {
1314
+ "cell_type": "code",
1315
+ "execution_count": null,
1316
+ "id": "b2541763-4a46-4195-b2e9-48df2e992add",
1317
+ "metadata": {},
1318
+ "outputs": [],
1319
+ "source": [
1320
+ "with ThreadPoolExecutor(max_workers=128) as ex:\n",
1321
+ " _ = [ex.submit(validate, art) for art in ml_data]"
1322
+ ]
1323
+ },
1324
+ {
1325
+ "cell_type": "code",
1326
+ "execution_count": null,
1327
+ "id": "07227c4d-6cdb-41dc-87d3-b67c73f15b6f",
1328
+ "metadata": {},
1329
+ "outputs": [],
1330
+ "source": [
1331
+ "nb = 0\n",
1332
+ "for i in en_data + ml_data:\n",
1333
+ " if i['need_to_validate'] and not i.get('prev_themes'):\n",
1334
+ " nb += 1\n",
1335
+ " print(i.get('title'), i.get('content'))\n",
1336
+ "nb, len(en_data + ml_data)"
1337
+ ]
1338
+ },
1339
+ {
1340
+ "cell_type": "code",
1341
+ "execution_count": null,
1342
+ "id": "24004f74-9ca5-4c70-88c1-f1ec3fbb13c0",
1343
+ "metadata": {},
1344
+ "outputs": [],
1345
+ "source": [
1346
+ "json.dump(en_data, open('filtered2/en_step4.json', 'w'))"
1347
+ ]
1348
+ },
1349
+ {
1350
+ "cell_type": "code",
1351
+ "execution_count": null,
1352
+ "id": "67a33aa0-e8e7-44c3-89d6-2ca097679988",
1353
+ "metadata": {},
1354
+ "outputs": [],
1355
+ "source": [
1356
+ "json.dump(ml_data, open('filtered2/ml_step4.json', 'w'))"
1357
+ ]
1358
+ },
1359
+ {
1360
+ "cell_type": "markdown",
1361
+ "id": "1ac2b2bd-7419-4f67-a023-59bac45d57b4",
1362
+ "metadata": {},
1363
+ "source": [
1364
+ "# 6. split train & test also pre-train"
1365
+ ]
1366
+ },
1367
+ {
1368
+ "cell_type": "code",
1369
+ "execution_count": null,
1370
+ "id": "afab9235-4e4c-4a28-b6d1-e6071efd056a",
1371
+ "metadata": {},
1372
+ "outputs": [],
1373
+ "source": [
1374
+ "random.shuffle(en_data)\n",
1375
+ "random.shuffle(ml_data)"
1376
+ ]
1377
+ },
1378
+ {
1379
+ "cell_type": "code",
1380
+ "execution_count": null,
1381
+ "id": "d57dc5c4-1756-497c-9be3-9e3c4086f667",
1382
+ "metadata": {},
1383
+ "outputs": [],
1384
+ "source": [
1385
+ "train = {}\n",
1386
+ "test = {}\n",
1387
+ "\n",
1388
+ "for i in en_data:\n",
1389
+ " sts = sort_themes_by_priority(i['themes'])\n",
1390
+ " if 'Education' in sts:\n",
1391
+ " sts.remove('Education')\n",
1392
+ " for st in sts:\n",
1393
+ " if st not in train:\n",
1394
+ " train[st] = []\n",
1395
+ " test[st] = []\n",
1396
+ " if len(test[st]) < 500 and i['need_to_validate']:\n",
1397
+ " test[st].append(i)\n",
1398
+ " break\n",
1399
+ " else:\n",
1400
+ " train[st].append(i)\n",
1401
+ " break"
1402
+ ]
1403
+ },
1404
+ {
1405
+ "cell_type": "code",
1406
+ "execution_count": null,
1407
+ "id": "aa27ad3e-71a5-4aed-bc2f-32771feda321",
1408
+ "metadata": {},
1409
+ "outputs": [],
1410
+ "source": [
1411
+ "for i in ml_data:\n",
1412
+ " sts = sort_themes_by_priority(i['themes'])\n",
1413
+ " if 'Education' in sts:\n",
1414
+ " sts.remove('Education')\n",
1415
+ " for st in sts:\n",
1416
+ " if st not in train:\n",
1417
+ " train[st] = []\n",
1418
+ " test[st] = []\n",
1419
+ " if len(test[st]) < 1000 and i['need_to_validate']:\n",
1420
+ " test[st].append(i)\n",
1421
+ " break\n",
1422
+ " else:\n",
1423
+ " train[st].append(i)\n",
1424
+ " break"
1425
+ ]
1426
+ },
1427
+ {
1428
+ "cell_type": "code",
1429
+ "execution_count": null,
1430
+ "id": "b9cd3ece-301f-4fee-b8fd-a1c514287ce5",
1431
+ "metadata": {},
1432
+ "outputs": [],
1433
+ "source": [
1434
+ "for i in test:\n",
1435
+ " print(i, len(test[i]), len(train[i]))"
1436
+ ]
1437
+ },
1438
+ {
1439
+ "cell_type": "code",
1440
+ "execution_count": null,
1441
+ "id": "f6ac7932-ca64-4d06-8a12-e1235a0ee501",
1442
+ "metadata": {},
1443
+ "outputs": [],
1444
+ "source": [
1445
+ "json.dump(sum([train[i] for i in train], []), open('filtered2/train.json', 'w'))"
1446
+ ]
1447
+ },
1448
+ {
1449
+ "cell_type": "code",
1450
+ "execution_count": null,
1451
+ "id": "ae639b65-1b28-4079-899c-8ae5dd86eb4d",
1452
+ "metadata": {},
1453
+ "outputs": [],
1454
+ "source": [
1455
+ "json.dump(sum([test[i] for i in test], []), open('filtered2/test.json', 'w'))"
1456
+ ]
1457
+ },
1458
+ {
1459
+ "cell_type": "code",
1460
+ "execution_count": null,
1461
+ "id": "fc0fc992-1c32-42ed-95c8-b1b77d878b93",
1462
+ "metadata": {},
1463
+ "outputs": [],
1464
+ "source": [
1465
+ "tests_cont = set([i['content'].lower().strip() for i in sum([test[i] for i in test], [])])\n",
1466
+ "len(tests_cont)"
1467
+ ]
1468
+ },
1469
+ {
1470
+ "cell_type": "code",
1471
+ "execution_count": null,
1472
+ "id": "2d7ce8f4-00a4-419a-8b8d-6c06cf058aed",
1473
+ "metadata": {},
1474
+ "outputs": [],
1475
+ "source": [
1476
+ "pre_train = json.load(open('filtered2/en_step2.json')) + json.load(open('filtered2/ml_step2.json'))\n",
1477
+ "len(pre_train)"
1478
+ ]
1479
+ },
1480
+ {
1481
+ "cell_type": "code",
1482
+ "execution_count": null,
1483
+ "id": "ba3972ee-ac8c-408b-b314-cb89424ff46d",
1484
+ "metadata": {},
1485
+ "outputs": [],
1486
+ "source": [
1487
+ "pre_train = [i for i in pre_train if i['content'].lower().strip() not in tests_cont]\n",
1488
+ "len(pre_train)"
1489
+ ]
1490
+ },
1491
+ {
1492
+ "cell_type": "code",
1493
+ "execution_count": null,
1494
+ "id": "25594c3a-f5f1-4d77-a297-a33229630ac5",
1495
+ "metadata": {},
1496
+ "outputs": [],
1497
+ "source": [
1498
+ "json.dump(pre_train, open('filtered2/pre-train.json', 'w'))"
1499
+ ]
1500
+ },
1501
+ {
1502
+ "cell_type": "markdown",
1503
+ "id": "751f2c18-80b4-4f18-aa51-8d4da2e26e01",
1504
+ "metadata": {},
1505
+ "source": [
1506
+ "# split again val, train and test"
1507
+ ]
1508
+ },
1509
+ {
1510
+ "cell_type": "code",
1511
+ "execution_count": null,
1512
+ "id": "0991b233-ec51-481e-90b3-208f5c2303ec",
1513
+ "metadata": {},
1514
+ "outputs": [],
1515
+ "source": [
1516
+ "train = json.load(open('filtered2/train.json'))\n",
1517
+ "pre_train = json.load(open('filtered2/pre-train.json'))\n",
1518
+ "test = json.load(open('filtered2/test.json'))"
1519
+ ]
1520
+ },
1521
+ {
1522
+ "cell_type": "code",
1523
+ "execution_count": null,
1524
+ "id": "ac1fb5ff-9702-4605-899c-0b63f22fd478",
1525
+ "metadata": {},
1526
+ "outputs": [],
1527
+ "source": [
1528
+ "def validate_train(art):\n",
1529
+ " if \"prev_themes\" in art:\n",
1530
+ " return\n",
1531
+ " content = \"\"\n",
1532
+ " if art.get(\"title\"):\n",
1533
+ " content = \"Title:\" + art['title'] + '\\n\\nContent:' + art['content']\n",
1534
+ " else:\n",
1535
+ " content = art['content']\n",
1536
+ " themes = extract_function_data(int_theme_classification_function, content, target_model='gpt-4.1-mini')\n",
1537
+ " themes = [i for i in themes if themes[i]]\n",
1538
+ " if themes:\n",
1539
+ " art['prev_themes'] = art.pop('themes')\n",
1540
+ " art['themes'] = themes"
1541
+ ]
1542
+ },
1543
+ {
1544
+ "cell_type": "code",
1545
+ "execution_count": null,
1546
+ "id": "9ab0726e-a993-4400-b444-014537253fed",
1547
+ "metadata": {},
1548
+ "outputs": [],
1549
+ "source": [
1550
+ "from requests.adapters import HTTPAdapter\n",
1551
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
1552
+ "with ThreadPoolExecutor(max_workers=512) as ex:\n",
1553
+ " _ = [ex.submit(validate_train, art) for art in train]"
1554
+ ]
1555
+ },
1556
+ {
1557
+ "cell_type": "code",
1558
+ "execution_count": null,
1559
+ "id": "549b4492-97e6-4a26-b4da-fc8ad4c52073",
1560
+ "metadata": {},
1561
+ "outputs": [],
1562
+ "source": [
1563
+ "nb = 0\n",
1564
+ "for i in train:\n",
1565
+ " if \"prev_themes\" not in i:\n",
1566
+ " nb += 1\n",
1567
+ "nb, len(train)"
1568
+ ]
1569
+ },
1570
+ {
1571
+ "cell_type": "code",
1572
+ "execution_count": null,
1573
+ "id": "d011264f-f4af-4d82-bbfc-7d621d6f2215",
1574
+ "metadata": {},
1575
+ "outputs": [],
1576
+ "source": [
1577
+ "json.dump(train, open('filtered2/train.json', 'w'))"
1578
+ ]
1579
+ },
1580
+ {
1581
+ "cell_type": "code",
1582
+ "execution_count": null,
1583
+ "id": "054fd8df-9ea0-4d39-ac99-7e3e43c6c4d2",
1584
+ "metadata": {},
1585
+ "outputs": [],
1586
+ "source": [
1587
+ "ntrain = {}\n",
1588
+ "nval = {}\n",
1589
+ "\n",
1590
+ "for i in train:\n",
1591
+ " sts = sort_themes_by_priority(i['themes'])\n",
1592
+ " sts = [i for i in sts if i in priority_order]\n",
1593
+ " for st in sts:\n",
1594
+ " if st not in ntrain:\n",
1595
+ " ntrain[st] = []\n",
1596
+ " nval[st] = []\n",
1597
+ " if len(nval[st]) < 500 and i['need_to_validate']:\n",
1598
+ " nval[st].append(i)\n",
1599
+ " break\n",
1600
+ " else:\n",
1601
+ " ntrain[st].append(i)\n",
1602
+ " break\n",
1603
+ "\n",
1604
+ "ntrain = sum([ntrain[i] for i in ntrain], [])\n",
1605
+ "nval = sum([nval[i] for i in nval], [])"
1606
+ ]
1607
+ },
1608
+ {
1609
+ "cell_type": "code",
1610
+ "execution_count": null,
1611
+ "id": "7df359b3-b815-402a-a0b8-fd5d09cc7ee2",
1612
+ "metadata": {},
1613
+ "outputs": [],
1614
+ "source": [
1615
+ "len(ntrain), len(nval)"
1616
+ ]
1617
+ },
1618
+ {
1619
+ "cell_type": "code",
1620
+ "execution_count": null,
1621
+ "id": "b37878b1-4f2b-4a87-85ac-7a685e9fe00a",
1622
+ "metadata": {},
1623
+ "outputs": [],
1624
+ "source": []
1625
+ },
1626
+ {
1627
+ "cell_type": "code",
1628
+ "execution_count": null,
1629
+ "id": "36c43a30-4b0b-443f-a29a-af78848b271c",
1630
+ "metadata": {},
1631
+ "outputs": [],
1632
+ "source": [
1633
+ "npretrain = {}\n",
1634
+ "contset = {}\n",
1635
+ "\n",
1636
+ "for i in nval+test:\n",
1637
+ " contset[i['content'].lower().strip()] = i['themes']\n",
1638
+ "len(contset)"
1639
+ ]
1640
+ },
1641
+ {
1642
+ "cell_type": "code",
1643
+ "execution_count": null,
1644
+ "id": "7bc1d022-41c4-45f8-be2d-4246cd052802",
1645
+ "metadata": {},
1646
+ "outputs": [],
1647
+ "source": [
1648
+ "fin_rows = []\n",
1649
+ "for i in pre_train:\n",
1650
+ " if 'Financial Crime' not in i['themes']:\n",
1651
+ " continue\n",
1652
+ " if i['content'].lower().strip() not in contset:\n",
1653
+ " fin_rows.append(i)\n",
1654
+ "len(fin_rows)"
1655
+ ]
1656
+ },
1657
+ {
1658
+ "cell_type": "code",
1659
+ "execution_count": null,
1660
+ "id": "d65d514c-c7cb-4a81-be63-da1420eb89c0",
1661
+ "metadata": {},
1662
+ "outputs": [],
1663
+ "source": [
1664
+ "with ThreadPoolExecutor(max_workers=512) as ex:\n",
1665
+ " _ = [ex.submit(validate_train, art) for art in fin_rows]"
1666
+ ]
1667
+ },
1668
+ {
1669
+ "cell_type": "code",
1670
+ "execution_count": null,
1671
+ "id": "8b881fd9-3894-40ac-a032-2a81452f3d7b",
1672
+ "metadata": {},
1673
+ "outputs": [],
1674
+ "source": [
1675
+ "tt = {}\n",
1676
+ "for i in fin_rows[:10000]:\n",
1677
+ " for t in i['themes']:\n",
1678
+ " tt[t] = tt.get(t, 0) + 1\n",
1679
+ "tt"
1680
+ ]
1681
+ },
1682
+ {
1683
+ "cell_type": "code",
1684
+ "execution_count": null,
1685
+ "id": "9e9c7f98-befe-4eb6-9bad-28b839ce385b",
1686
+ "metadata": {},
1687
+ "outputs": [],
1688
+ "source": [
1689
+ "for i in fin_rows:\n",
1690
+ " if i['content'].lower().strip() not in contset:\n",
1691
+ " contset[i['content'].lower().strip()] = i['themes']"
1692
+ ]
1693
+ },
1694
+ {
1695
+ "cell_type": "code",
1696
+ "execution_count": null,
1697
+ "id": "fc3dfbef-9361-4d38-b0b3-c8cbc63d04ce",
1698
+ "metadata": {},
1699
+ "outputs": [],
1700
+ "source": [
1701
+ "ntrain = ntrain + fin_rows\n",
1702
+ "len(ntrain)"
1703
+ ]
1704
+ },
1705
+ {
1706
+ "cell_type": "code",
1707
+ "execution_count": null,
1708
+ "id": "cb8f2cf0-49cf-4306-ac1f-73d19b139ab4",
1709
+ "metadata": {},
1710
+ "outputs": [],
1711
+ "source": [
1712
+ "final_train = {}\n",
1713
+ "for i in ntrain:\n",
1714
+ " sts = sort_themes_by_priority(i['themes'])\n",
1715
+ " sts = [i for i in sts if i in priority_order]\n",
1716
+ " for st in sts:\n",
1717
+ " if st not in final_train:\n",
1718
+ " final_train[st] = []\n",
1719
+ " if len(final_train[st]) < 6000:\n",
1720
+ " final_train[st].append(i)\n",
1721
+ " break\n",
1722
+ "for i in final_train:\n",
1723
+ " print(i, len(final_train[i]))"
1724
+ ]
1725
+ },
1726
+ {
1727
+ "cell_type": "code",
1728
+ "execution_count": null,
1729
+ "id": "2d062629-4f0c-4e15-a36d-ff92ae463399",
1730
+ "metadata": {},
1731
+ "outputs": [],
1732
+ "source": [
1733
+ "final_train = sum([final_train[t] for t in final_train], [])\n",
1734
+ "len(final_train)"
1735
+ ]
1736
+ },
1737
+ {
1738
+ "cell_type": "code",
1739
+ "execution_count": null,
1740
+ "id": "b7d10a78-26c0-461b-800f-cae91b4b7e4a",
1741
+ "metadata": {},
1742
+ "outputs": [],
1743
+ "source": []
1744
+ },
1745
+ {
1746
+ "cell_type": "code",
1747
+ "execution_count": null,
1748
+ "id": "32b75f1c-0b21-46a8-a2cc-9dbe5cb17088",
1749
+ "metadata": {},
1750
+ "outputs": [],
1751
+ "source": [
1752
+ "for i in pre_train:\n",
1753
+ " if i['content'].lower().strip() in contset:\n",
1754
+ " i['themes'] = contset[i['content'].lower().strip()]"
1755
+ ]
1756
+ },
1757
+ {
1758
+ "cell_type": "code",
1759
+ "execution_count": 67,
1760
+ "id": "606aac46-b3ce-48a8-96d5-b179845dfae9",
1761
+ "metadata": {},
1762
+ "outputs": [],
1763
+ "source": [
1764
+ "valid_themes = set(priority_order)\n",
1765
+ "\n",
1766
+ "def dumpjson(rows, file_name):\n",
1767
+ " t = {}\n",
1768
+ " for row in rows:\n",
1769
+ " themes = row['themes']\n",
1770
+ " themes = [i for i in themes if i in valid_themes]\n",
1771
+ " row['themes'] = themes\n",
1772
+ " for th in themes:\n",
1773
+ " t[th] = t.get(th, 0) + 1\n",
1774
+ " print(len(rows), t)\n",
1775
+ " json.dump(rows, open(file_name, 'w'))"
1776
+ ]
1777
+ },
1778
+ {
1779
+ "cell_type": "code",
1780
+ "execution_count": 68,
1781
+ "id": "09f7e9ac-e6fd-47ce-823b-dabe74b27ba0",
1782
+ "metadata": {},
1783
+ "outputs": [
1784
+ {
1785
+ "name": "stdout",
1786
+ "output_type": "stream",
1787
+ "text": [
1788
+ "92245 {'Entertainment': 6854, 'General': 6302, 'Politics': 14207, 'Business': 14722, 'Lifestyle': 7827, 'Tech': 8823, 'Health': 7985, 'Crime': 9548, 'Economics': 7703, 'Sports': 6607, 'Travel': 6812, 'Science': 4773, 'Automotive': 4698, 'Weather': 6351, 'Finance': 8080, 'Financial Crime': 5218}\n"
1789
+ ]
1790
+ }
1791
+ ],
1792
+ "source": [
1793
+ "dumpjson(final_train, 'filtered2/train.json')"
1794
+ ]
1795
+ },
1796
+ {
1797
+ "cell_type": "code",
1798
+ "execution_count": 69,
1799
+ "id": "902ece36-0efa-4225-97c1-7a6b0f6082a6",
1800
+ "metadata": {},
1801
+ "outputs": [
1802
+ {
1803
+ "name": "stdout",
1804
+ "output_type": "stream",
1805
+ "text": [
1806
+ "8000 {'Entertainment': 588, 'General': 534, 'Politics': 1285, 'Lifestyle': 505, 'Health': 687, 'Sports': 570, 'Weather': 529, 'Business': 1465, 'Tech': 763, 'Crime': 676, 'Travel': 543, 'Science': 527, 'Automotive': 509, 'Economics': 500, 'Finance': 703, 'Financial Crime': 507}\n"
1807
+ ]
1808
+ }
1809
+ ],
1810
+ "source": [
1811
+ "dumpjson(nval, 'filtered2/val.json')"
1812
+ ]
1813
+ },
1814
+ {
1815
+ "cell_type": "code",
1816
+ "execution_count": 70,
1817
+ "id": "03a5343e-2fee-4051-be11-d37bab9a4c24",
1818
+ "metadata": {},
1819
+ "outputs": [
1820
+ {
1821
+ "name": "stdout",
1822
+ "output_type": "stream",
1823
+ "text": [
1824
+ "16000 {'Entertainment': 1161, 'Politics': 2414, 'Crime': 1337, 'Business': 2698, 'General': 1065, 'Sports': 1118, 'Financial Crime': 1009, 'Finance': 1371, 'Travel': 1143, 'Tech': 1476, 'Health': 1295, 'Automotive': 1014, 'Weather': 1060, 'Lifestyle': 1023, 'Economics': 1000, 'Science': 1056}\n"
1825
+ ]
1826
+ }
1827
+ ],
1828
+ "source": [
1829
+ "dumpjson(test, 'filtered2/test.json')"
1830
+ ]
1831
+ },
1832
+ {
1833
+ "cell_type": "code",
1834
+ "execution_count": null,
1835
+ "id": "6892ae88-b0ae-4318-82fe-868f0f260c1c",
1836
+ "metadata": {},
1837
+ "outputs": [],
1838
+ "source": [
1839
+ "for i in train:\n",
1840
+ " if i['content'].lower().strip() not in contset:\n",
1841
+ " contset[i['content'].lower().strip()] = i['themes']\n",
1842
+ "\n",
1843
+ "for i in pre_train:\n",
1844
+ " if i['content'].lower().strip() in contset:\n",
1845
+ " i['themes'] = contset[i['content'].lower().strip()]"
1846
+ ]
1847
+ },
1848
+ {
1849
+ "cell_type": "code",
1850
+ "execution_count": null,
1851
+ "id": "9cd446ac-0209-4937-a8da-045e8d17edb8",
1852
+ "metadata": {},
1853
+ "outputs": [],
1854
+ "source": [
1855
+ "len(pre_train)"
1856
+ ]
1857
+ },
1858
+ {
1859
+ "cell_type": "code",
1860
+ "execution_count": null,
1861
+ "id": "2ae5ec6d-ed70-4ba8-a188-a1c3dd561dbf",
1862
+ "metadata": {},
1863
+ "outputs": [],
1864
+ "source": [
1865
+ "to_ignore = set()\n",
1866
+ "for i in test + nval:\n",
1867
+ " to_ignore.add(i['content'].lower().strip())\n",
1868
+ "\n",
1869
+ "final_pre_train = []\n",
1870
+ "for i in pre_train:\n",
1871
+ " if i['content'].lower().strip() in to_ignore:\n",
1872
+ " continue\n",
1873
+ " final_pre_train.append(i)\n",
1874
+ "len(final_pre_train), len(to_ignore)"
1875
+ ]
1876
+ },
1877
+ {
1878
+ "cell_type": "code",
1879
+ "execution_count": 71,
1880
+ "id": "6f560a9a-cf70-47f3-97f3-ae45f520590f",
1881
+ "metadata": {},
1882
+ "outputs": [
1883
+ {
1884
+ "name": "stdout",
1885
+ "output_type": "stream",
1886
+ "text": [
1887
+ "397784 {'Sports': 37315, 'Lifestyle': 21900, 'Crime': 88859, 'Business': 60980, 'Entertainment': 40906, 'Politics': 78084, 'General': 25417, 'Tech': 25008, 'Health': 32249, 'Travel': 15170, 'Finance': 21892, 'Economics': 13128, 'Weather': 13586, 'Science': 14338, 'Automotive': 12176, 'Financial Crime': 5035}\n"
1888
+ ]
1889
+ }
1890
+ ],
1891
+ "source": [
1892
+ "dumpjson(final_pre_train, 'filtered2/pre-train.json')"
1893
+ ]
1894
+ },
1895
+ {
1896
+ "cell_type": "code",
1897
+ "execution_count": null,
1898
+ "id": "dd98be36-8ec4-444e-a7fa-d7196334a630",
1899
+ "metadata": {},
1900
+ "outputs": [],
1901
+ "source": []
1902
+ },
1903
+ {
1904
+ "cell_type": "code",
1905
+ "execution_count": null,
1906
+ "id": "1fbf2b53-45a6-4727-b862-bb7819e4562e",
1907
+ "metadata": {},
1908
+ "outputs": [],
1909
+ "source": []
1910
+ }
1911
+ ],
1912
+ "metadata": {
1913
+ "kernelspec": {
1914
+ "display_name": "Python 3 (ipykernel)",
1915
+ "language": "python",
1916
+ "name": "python3"
1917
+ },
1918
+ "language_info": {
1919
+ "codemirror_mode": {
1920
+ "name": "ipython",
1921
+ "version": 3
1922
+ },
1923
+ "file_extension": ".py",
1924
+ "mimetype": "text/x-python",
1925
+ "name": "python",
1926
+ "nbconvert_exporter": "python",
1927
+ "pygments_lexer": "ipython3",
1928
+ "version": "3.10.12"
1929
+ }
1930
+ },
1931
+ "nbformat": 4,
1932
+ "nbformat_minor": 5
1933
+ }
3-training.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
mlp_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32b352f602bdf62dc246c0d6509495469520d1c6e4caa856619beeae90d89c5f
3
+ size 9100341
pre-train.json.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:025bb8dd1a317149eb9ce207b992142cbc99a0e50efa0704b5222701bf69d2da
3
+ size 509109586
scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a50abf4348f226b4f3dbd0fda27ae284de7882851bcdcfda3e25a3bcaed0bbab
3
+ size 25191
test.json.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbd8a5492b857ef5f1c0b8eb58a0f9c0bcc85532bae026bde30e3c47269ca423
3
+ size 22253927
train.json.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5192ac529ad7c2c8083981325dd14d774a92f04ae768aac38f0b9d6669d75d6
3
+ size 129843701
val.json.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3de1a164f7512828b632939d5cc75b002f7f8564c0ba142b3329cb5e2e3a5199
3
+ size 11034512