Upload folder using huggingface_hub
Browse files- 0-data-prepare.ipynb +509 -0
- 1-training.ipynb +0 -0
- 2-prepare-data-2.ipynb +1933 -0
- 3-training.ipynb +0 -0
- mlp_model.pth +3 -0
- pre-train.json.zip +3 -0
- scaler.joblib +3 -0
- test.json.zip +3 -0
- train.json.zip +3 -0
- val.json.zip +3 -0
0-data-prepare.ipynb
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "7384109c-4507-4895-ac34-8400e2978021",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import ujson as json"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 2,
|
| 16 |
+
"id": "de47a236-bb5d-4cb9-89e4-a8eb3f66abaa",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [
|
| 19 |
+
{
|
| 20 |
+
"name": "stdout",
|
| 21 |
+
"output_type": "stream",
|
| 22 |
+
"text": [
|
| 23 |
+
"art_en_crime.json 22116 dict_keys(['id', 'embedding', 'themes', 'language', 'title', 'content', 'cluster_id'])\n",
|
| 24 |
+
"fincrime_train.json 27548 dict_keys(['content', 'themes', 'embedding', 'llm_themes', 'cl_themes'])\n",
|
| 25 |
+
"train_multilang_2.json 119583 dict_keys(['language', 'content', 'id', 'lang', 'embedding', 'cluster_id', 'themes', 'llm_themes'])\n",
|
| 26 |
+
"other_lang_train.json 149000 dict_keys(['embedding', 'themes', 'id', 'content'])\n"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"data": {
|
| 31 |
+
"text/plain": [
|
| 32 |
+
"318247"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
"execution_count": 2,
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"output_type": "execute_result"
|
| 38 |
+
}
|
| 39 |
+
],
|
| 40 |
+
"source": [
|
| 41 |
+
"train_files = ['art_en_crime.json', 'fincrime_train.json', 'train_multilang_2.json', 'other_lang_train.json']\n",
|
| 42 |
+
"# 'en_train.json'\n",
|
| 43 |
+
"train_data = []\n",
|
| 44 |
+
"for train_file in train_files:\n",
|
| 45 |
+
" tm = json.load(open('topic_data_new_sorted/'+train_file))\n",
|
| 46 |
+
" print(train_file, len(tm), tm[0].keys())\n",
|
| 47 |
+
" train_data.extend(tm)\n",
|
| 48 |
+
"len(train_data)"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": 3,
|
| 54 |
+
"id": "97e48648-4a9f-40b9-8b95-2740a93b4762",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [],
|
| 57 |
+
"source": [
|
| 58 |
+
"for i in train_data:\n",
|
| 59 |
+
" del i['embedding']"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"cell_type": "code",
|
| 64 |
+
"execution_count": 4,
|
| 65 |
+
"id": "b3b5f700-5467-4d8b-963d-060a1962ab00",
|
| 66 |
+
"metadata": {},
|
| 67 |
+
"outputs": [],
|
| 68 |
+
"source": [
|
| 69 |
+
"final_data = []\n",
|
| 70 |
+
"contents = set()\n",
|
| 71 |
+
"for i in train_data:\n",
|
| 72 |
+
" if not i.get('content'):\n",
|
| 73 |
+
" continue\n",
|
| 74 |
+
" th = hash(i['content'])\n",
|
| 75 |
+
" if th not in contents:\n",
|
| 76 |
+
" final_data.append(i)\n",
|
| 77 |
+
" contents.add(th)"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": 5,
|
| 83 |
+
"id": "3527a7a2-b58c-406b-890e-99002ef050df",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [
|
| 86 |
+
{
|
| 87 |
+
"data": {
|
| 88 |
+
"text/plain": [
|
| 89 |
+
"(22116, 245203, 318247)"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
"execution_count": 5,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"output_type": "execute_result"
|
| 95 |
+
}
|
| 96 |
+
],
|
| 97 |
+
"source": [
|
| 98 |
+
"tc, cc = 0, 0\n",
|
| 99 |
+
"for i in train_data:\n",
|
| 100 |
+
" if i.get('title'):\n",
|
| 101 |
+
" tc += 1\n",
|
| 102 |
+
" if i.get('content'):\n",
|
| 103 |
+
" cc += 1\n",
|
| 104 |
+
"tc, cc, len(train_data)"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": 6,
|
| 110 |
+
"id": "460312d0-236c-418d-a53e-1357fdc586bc",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [
|
| 113 |
+
{
|
| 114 |
+
"data": {
|
| 115 |
+
"text/plain": [
|
| 116 |
+
"144496"
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
"execution_count": 6,
|
| 120 |
+
"metadata": {},
|
| 121 |
+
"output_type": "execute_result"
|
| 122 |
+
}
|
| 123 |
+
],
|
| 124 |
+
"source": [
|
| 125 |
+
"len(final_data)"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": 7,
|
| 131 |
+
"id": "631689b1-c18d-427c-baea-8d2d3d4b9781",
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"outputs": [
|
| 134 |
+
{
|
| 135 |
+
"data": {
|
| 136 |
+
"text/plain": [
|
| 137 |
+
"{'Politics': 44642,\n",
|
| 138 |
+
" 'Crime': 58051,\n",
|
| 139 |
+
" 'Financial Crime': 24289,\n",
|
| 140 |
+
" 'Business': 25247,\n",
|
| 141 |
+
" 'Entertainment': 18221,\n",
|
| 142 |
+
" 'Finance': 6166,\n",
|
| 143 |
+
" 'Economics': 1772,\n",
|
| 144 |
+
" 'Sports': 15184,\n",
|
| 145 |
+
" 'Tech': 4107,\n",
|
| 146 |
+
" 'Automotive': 2116,\n",
|
| 147 |
+
" 'Health': 7829,\n",
|
| 148 |
+
" 'Lifestyle': 368,\n",
|
| 149 |
+
" 'Science': 3481,\n",
|
| 150 |
+
" 'Travel': 914,\n",
|
| 151 |
+
" 'Weather': 1070,\n",
|
| 152 |
+
" 'General': 8972,\n",
|
| 153 |
+
" 'Types of life insurance fraud': 1,\n",
|
| 154 |
+
" 'Consequences of insurance fraud': 1,\n",
|
| 155 |
+
" 'How to prevent life insurance fraud': 1,\n",
|
| 156 |
+
" 'Front Running': 1,\n",
|
| 157 |
+
" 'fraud': 1}"
|
| 158 |
+
]
|
| 159 |
+
},
|
| 160 |
+
"execution_count": 7,
|
| 161 |
+
"metadata": {},
|
| 162 |
+
"output_type": "execute_result"
|
| 163 |
+
}
|
| 164 |
+
],
|
| 165 |
+
"source": [
|
| 166 |
+
"themes = {}\n",
|
| 167 |
+
"for i in final_data:\n",
|
| 168 |
+
" for t in i['themes']:\n",
|
| 169 |
+
" themes[t] = themes.get(t, 0) + 1\n",
|
| 170 |
+
"themes"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"cell_type": "code",
|
| 175 |
+
"execution_count": 8,
|
| 176 |
+
"id": "a7bccdc9-de94-4e7b-b686-0fa062f5e781",
|
| 177 |
+
"metadata": {},
|
| 178 |
+
"outputs": [
|
| 179 |
+
{
|
| 180 |
+
"name": "stderr",
|
| 181 |
+
"output_type": "stream",
|
| 182 |
+
"text": [
|
| 183 |
+
"/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 184 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"data": {
|
| 189 |
+
"text/plain": [
|
| 190 |
+
"'tr'"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
"execution_count": 8,
|
| 194 |
+
"metadata": {},
|
| 195 |
+
"output_type": "execute_result"
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"source": [
|
| 199 |
+
"from fast_langdetect import detect\n",
|
| 200 |
+
"result = detect(text=\"Bugün hava çok güzel\", model='full', k=1)\n",
|
| 201 |
+
"result[0]['lang']"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": 9,
|
| 207 |
+
"id": "ed34643a-974c-4995-b349-6ce194397985",
|
| 208 |
+
"metadata": {},
|
| 209 |
+
"outputs": [
|
| 210 |
+
{
|
| 211 |
+
"data": {
|
| 212 |
+
"text/plain": [
|
| 213 |
+
"{'en': 49682,\n",
|
| 214 |
+
" 'hi': 541,\n",
|
| 215 |
+
" 'mt': 1,\n",
|
| 216 |
+
" 'fr': 11193,\n",
|
| 217 |
+
" 'de': 12987,\n",
|
| 218 |
+
" 'pt': 1586,\n",
|
| 219 |
+
" 'it': 1682,\n",
|
| 220 |
+
" 'ko': 455,\n",
|
| 221 |
+
" 'es': 4034,\n",
|
| 222 |
+
" 'zh': 5,\n",
|
| 223 |
+
" 'ar': 35732,\n",
|
| 224 |
+
" 'bn': 420,\n",
|
| 225 |
+
" 'pl': 1387,\n",
|
| 226 |
+
" 'nl': 824,\n",
|
| 227 |
+
" 'lv': 379,\n",
|
| 228 |
+
" 'hk': 345,\n",
|
| 229 |
+
" 'hr': 511,\n",
|
| 230 |
+
" 'ja': 591,\n",
|
| 231 |
+
" 'cy': 182,\n",
|
| 232 |
+
" 'sv': 852,\n",
|
| 233 |
+
" 'da': 2120,\n",
|
| 234 |
+
" 'el': 462,\n",
|
| 235 |
+
" 'tr': 507,\n",
|
| 236 |
+
" 'ro': 546,\n",
|
| 237 |
+
" 'ur': 723,\n",
|
| 238 |
+
" 'mr': 344,\n",
|
| 239 |
+
" 'so': 338,\n",
|
| 240 |
+
" 'fa': 729,\n",
|
| 241 |
+
" 'mk': 451,\n",
|
| 242 |
+
" 'gu': 385,\n",
|
| 243 |
+
" 'th': 617,\n",
|
| 244 |
+
" 'lt': 444,\n",
|
| 245 |
+
" 'tw': 246,\n",
|
| 246 |
+
" 'sl': 568,\n",
|
| 247 |
+
" 'ml': 414,\n",
|
| 248 |
+
" 'te': 344,\n",
|
| 249 |
+
" 'he': 526,\n",
|
| 250 |
+
" 'cs': 671,\n",
|
| 251 |
+
" 'et': 660,\n",
|
| 252 |
+
" 'ta': 463,\n",
|
| 253 |
+
" 'gl': 9,\n",
|
| 254 |
+
" 'id': 538,\n",
|
| 255 |
+
" 'ca': 506,\n",
|
| 256 |
+
" 'ast': 1,\n",
|
| 257 |
+
" 'eu': 1,\n",
|
| 258 |
+
" 'sk': 455,\n",
|
| 259 |
+
" 'sq': 424,\n",
|
| 260 |
+
" 'ne': 429,\n",
|
| 261 |
+
" 'fi': 652,\n",
|
| 262 |
+
" 'sw': 394,\n",
|
| 263 |
+
" 'bg': 507,\n",
|
| 264 |
+
" 'ru': 653,\n",
|
| 265 |
+
" 'hu': 564,\n",
|
| 266 |
+
" 'cn': 261,\n",
|
| 267 |
+
" 'vi': 535,\n",
|
| 268 |
+
" 'pa': 343,\n",
|
| 269 |
+
" 'no': 1156,\n",
|
| 270 |
+
" 'tl': 559,\n",
|
| 271 |
+
" 'uk': 500,\n",
|
| 272 |
+
" 'kn': 374,\n",
|
| 273 |
+
" 'af': 611,\n",
|
| 274 |
+
" 'arz': 58,\n",
|
| 275 |
+
" 'nn': 6,\n",
|
| 276 |
+
" 'dv': 10,\n",
|
| 277 |
+
" 'azb': 1,\n",
|
| 278 |
+
" 'sd': 1,\n",
|
| 279 |
+
" 'ckb': 1}"
|
| 280 |
+
]
|
| 281 |
+
},
|
| 282 |
+
"execution_count": 9,
|
| 283 |
+
"metadata": {},
|
| 284 |
+
"output_type": "execute_result"
|
| 285 |
+
}
|
| 286 |
+
],
|
| 287 |
+
"source": [
|
| 288 |
+
"langs = {}\n",
|
| 289 |
+
"for i in final_data:\n",
|
| 290 |
+
" l = i.get('language')\n",
|
| 291 |
+
" if not l:\n",
|
| 292 |
+
" l = i.get('language')\n",
|
| 293 |
+
" if not l:\n",
|
| 294 |
+
" l = detect(text=i['content'], model='full', k=1)[0]['lang']\n",
|
| 295 |
+
" langs[l] = langs.get(l, 0) + 1\n",
|
| 296 |
+
"langs"
|
| 297 |
+
]
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"cell_type": "code",
|
| 301 |
+
"execution_count": 12,
|
| 302 |
+
"id": "dcdc4560-2bcb-4d90-a64d-436b85f23bc6",
|
| 303 |
+
"metadata": {},
|
| 304 |
+
"outputs": [],
|
| 305 |
+
"source": [
|
| 306 |
+
"del train_data"
|
| 307 |
+
]
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
"cell_type": "code",
|
| 311 |
+
"execution_count": 2,
|
| 312 |
+
"id": "770489b9-78c2-4231-9c8b-ba000342fe6d",
|
| 313 |
+
"metadata": {},
|
| 314 |
+
"outputs": [],
|
| 315 |
+
"source": [
|
| 316 |
+
"final_data = json.load(open('train1.json'))"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"cell_type": "code",
|
| 321 |
+
"execution_count": 7,
|
| 322 |
+
"id": "a69be4e4-08b3-4830-9b4d-b8ea67f15656",
|
| 323 |
+
"metadata": {},
|
| 324 |
+
"outputs": [],
|
| 325 |
+
"source": [
|
| 326 |
+
"import requests\n",
|
| 327 |
+
"from requests.adapters import HTTPAdapter\n",
|
| 328 |
+
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
|
| 329 |
+
"from typing import List, Dict, Any, Optional, Tuple\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"def _make_session(pool_size: int) -> requests.Session:\n",
|
| 332 |
+
" s = requests.Session()\n",
|
| 333 |
+
" adapter = HTTPAdapter(pool_connections=pool_size, pool_maxsize=pool_size, max_retries=0)\n",
|
| 334 |
+
" s.mount(\"http://\", adapter)\n",
|
| 335 |
+
" s.mount(\"https://\", adapter)\n",
|
| 336 |
+
" return s\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"\n",
|
| 339 |
+
"def _embed_batch(\n",
|
| 340 |
+
" session: requests.Session,\n",
|
| 341 |
+
" base_url: str,\n",
|
| 342 |
+
" text: str,\n",
|
| 343 |
+
" timeout_s: float,\n",
|
| 344 |
+
" extra_headers: Optional[Dict[str, str]] = None,\n",
|
| 345 |
+
") -> Any:\n",
|
| 346 |
+
" \"\"\"\n",
|
| 347 |
+
" Native TEI endpoint: POST {base_url}/embed\n",
|
| 348 |
+
" Payload: {\"inputs\": [..texts..]}\n",
|
| 349 |
+
" \"\"\"\n",
|
| 350 |
+
" url = base_url.rstrip(\"/\") + \"/embed\"\n",
|
| 351 |
+
" headers = {\"Content-Type\": \"application/json\"}\n",
|
| 352 |
+
" if extra_headers:\n",
|
| 353 |
+
" headers.update(extra_headers)\n",
|
| 354 |
+
" r = session.post(url, json={\"inputs\": text}, headers=headers, timeout=timeout_s)\n",
|
| 355 |
+
" r.raise_for_status()\n",
|
| 356 |
+
" return r.json()\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"def one_call(art):\n",
|
| 359 |
+
" if \"embedding\" in art:\n",
|
| 360 |
+
" return\n",
|
| 361 |
+
" text = []\n",
|
| 362 |
+
" if art.get(\"title\"):\n",
|
| 363 |
+
" text.append(art[\"title\"])\n",
|
| 364 |
+
" if art.get(\"content\"):\n",
|
| 365 |
+
" text.append(art[\"content\"])\n",
|
| 366 |
+
" text = \"\\n\\n\".join(text)\n",
|
| 367 |
+
" res = _embed_batch(session, base_url, text, timeout_s=30)\n",
|
| 368 |
+
" art[\"embedding\"] = res[0]\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"WORKERS = 64\n",
|
| 371 |
+
"session = _make_session(pool_size=WORKERS)\n",
|
| 372 |
+
"base_url = \"http://172.83.12.123:8080\""
|
| 373 |
+
]
|
| 374 |
+
},
|
| 375 |
+
{
|
| 376 |
+
"cell_type": "code",
|
| 377 |
+
"execution_count": 6,
|
| 378 |
+
"id": "e8d7127f-af28-4029-aa16-35b52269317a",
|
| 379 |
+
"metadata": {},
|
| 380 |
+
"outputs": [],
|
| 381 |
+
"source": [
|
| 382 |
+
"with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
|
| 383 |
+
" _ = [ex.submit(one_call, art) for art in final_data]"
|
| 384 |
+
]
|
| 385 |
+
},
|
| 386 |
+
{
|
| 387 |
+
"cell_type": "code",
|
| 388 |
+
"execution_count": 10,
|
| 389 |
+
"id": "1ba25f4b-7bec-40a6-9b78-d014a64674c4",
|
| 390 |
+
"metadata": {},
|
| 391 |
+
"outputs": [
|
| 392 |
+
{
|
| 393 |
+
"data": {
|
| 394 |
+
"text/plain": [
|
| 395 |
+
"(144496, 144496)"
|
| 396 |
+
]
|
| 397 |
+
},
|
| 398 |
+
"execution_count": 10,
|
| 399 |
+
"metadata": {},
|
| 400 |
+
"output_type": "execute_result"
|
| 401 |
+
}
|
| 402 |
+
],
|
| 403 |
+
"source": [
|
| 404 |
+
"te = 0\n",
|
| 405 |
+
"for i in final_data:\n",
|
| 406 |
+
" if \"embedding\" in i and len(i[\"embedding\"]) == 1024:\n",
|
| 407 |
+
" te += 1\n",
|
| 408 |
+
"te, len(final_data)"
|
| 409 |
+
]
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"cell_type": "code",
|
| 413 |
+
"execution_count": 9,
|
| 414 |
+
"id": "1e821b09-cfce-4ae6-82c6-1d786efcc676",
|
| 415 |
+
"metadata": {},
|
| 416 |
+
"outputs": [],
|
| 417 |
+
"source": [
|
| 418 |
+
"json.dump(final_data, open('train1.json', 'w'))"
|
| 419 |
+
]
|
| 420 |
+
},
|
| 421 |
+
{
|
| 422 |
+
"cell_type": "code",
|
| 423 |
+
"execution_count": null,
|
| 424 |
+
"id": "19a5069a-6a0b-4ed2-ae4c-173899b4d632",
|
| 425 |
+
"metadata": {},
|
| 426 |
+
"outputs": [],
|
| 427 |
+
"source": []
|
| 428 |
+
},
|
| 429 |
+
{
|
| 430 |
+
"cell_type": "code",
|
| 431 |
+
"execution_count": null,
|
| 432 |
+
"id": "4b02dcf6-2d64-4dab-a417-da3d2c411254",
|
| 433 |
+
"metadata": {},
|
| 434 |
+
"outputs": [],
|
| 435 |
+
"source": []
|
| 436 |
+
},
|
| 437 |
+
{
|
| 438 |
+
"cell_type": "code",
|
| 439 |
+
"execution_count": 9,
|
| 440 |
+
"id": "f427d890-ce22-4a04-9157-04528db80369",
|
| 441 |
+
"metadata": {},
|
| 442 |
+
"outputs": [
|
| 443 |
+
{
|
| 444 |
+
"name": "stdout",
|
| 445 |
+
"output_type": "stream",
|
| 446 |
+
"text": [
|
| 447 |
+
"en_test.json 750 750 750\n",
|
| 448 |
+
"other_lang_test.json 1000 981 981\n",
|
| 449 |
+
"spanish_tagged.json 1199 1197 1197\n",
|
| 450 |
+
"fincrime_test.json 623 623 623\n"
|
| 451 |
+
]
|
| 452 |
+
}
|
| 453 |
+
],
|
| 454 |
+
"source": [
|
| 455 |
+
"files = [\"en_test.json\", \"other_lang_test.json\", \"spanish_tagged.json\", \"fincrime_test.json\"]\n",
|
| 456 |
+
"for file in files:\n",
|
| 457 |
+
" tm = json.load(open('topic_data_new_sorted/'+file))\n",
|
| 458 |
+
" tf = []\n",
|
| 459 |
+
" cc, em = 0, 0\n",
|
| 460 |
+
" contents = set()\n",
|
| 461 |
+
" for i in tm:\n",
|
| 462 |
+
" if i['content'].lower() in contents:\n",
|
| 463 |
+
" continue\n",
|
| 464 |
+
" if i.get('content'):\n",
|
| 465 |
+
" cc += 1\n",
|
| 466 |
+
" contents.add(i['content'].lower())\n",
|
| 467 |
+
" if \"embedding\" in i:\n",
|
| 468 |
+
" del i[\"embedding\"]\n",
|
| 469 |
+
" tf.append(i)\n",
|
| 470 |
+
" with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
|
| 471 |
+
" _ = [ex.submit(one_call, art) for art in tf]\n",
|
| 472 |
+
" for i in tf:\n",
|
| 473 |
+
" if \"embedding\" in i and len(i[\"embedding\"]) == 1024:\n",
|
| 474 |
+
" em += 1\n",
|
| 475 |
+
" json.dump(tf, open('test/' + file, 'w'))\n",
|
| 476 |
+
" print(file, len(tm), cc, em)"
|
| 477 |
+
]
|
| 478 |
+
},
|
| 479 |
+
{
|
| 480 |
+
"cell_type": "code",
|
| 481 |
+
"execution_count": null,
|
| 482 |
+
"id": "7e42ec21-99df-4b62-b678-86bf24bb259e",
|
| 483 |
+
"metadata": {},
|
| 484 |
+
"outputs": [],
|
| 485 |
+
"source": []
|
| 486 |
+
}
|
| 487 |
+
],
|
| 488 |
+
"metadata": {
|
| 489 |
+
"kernelspec": {
|
| 490 |
+
"display_name": "Python 3 (ipykernel)",
|
| 491 |
+
"language": "python",
|
| 492 |
+
"name": "python3"
|
| 493 |
+
},
|
| 494 |
+
"language_info": {
|
| 495 |
+
"codemirror_mode": {
|
| 496 |
+
"name": "ipython",
|
| 497 |
+
"version": 3
|
| 498 |
+
},
|
| 499 |
+
"file_extension": ".py",
|
| 500 |
+
"mimetype": "text/x-python",
|
| 501 |
+
"name": "python",
|
| 502 |
+
"nbconvert_exporter": "python",
|
| 503 |
+
"pygments_lexer": "ipython3",
|
| 504 |
+
"version": "3.10.12"
|
| 505 |
+
}
|
| 506 |
+
},
|
| 507 |
+
"nbformat": 4,
|
| 508 |
+
"nbformat_minor": 5
|
| 509 |
+
}
|
1-training.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
2-prepare-data-2.ipynb
ADDED
|
@@ -0,0 +1,1933 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "a184875c-343b-4a39-ac51-2a80f8f661a5",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import ujson as json\n",
|
| 11 |
+
"import random\n",
|
| 12 |
+
"import os"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"id": "a2f82047-027a-4950-8a17-6c8437d2f2c9",
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"outputs": [],
|
| 21 |
+
"source": [
|
| 22 |
+
"1. prepare dataset by lang + theme\n",
|
| 23 |
+
"2. choose MMR candidates (take at least 15K per theme and 100 per lang, en=7K, other=8K)\n",
|
| 24 |
+
"3. calculate embedding of qwen\n",
|
| 25 |
+
"4. again choose MMR based on the qwen embedding; also make a list of LLM-verify candidates where the qwen embedding does not match the given themes (5K each)\n",
|
| 26 |
+
"5. confirm with LLM for selected candidates\n",
|
| 27 |
+
"6. split train & test also pre-train"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "markdown",
|
| 32 |
+
"id": "e1679138-5178-4262-a2f1-5589dd80c7e9",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"source": [
|
| 35 |
+
"# 1. prepare dataset by lang + theme"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"id": "c70cc644-b406-4105-8bef-5dad0552d354",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"files = ['art_en.json', 'art_en2.json', 'art_non_en.json', 'art_non_en2.json']\n",
|
| 46 |
+
"for file in files:\n",
|
| 47 |
+
" fo = open('raw_articles/' + file)\n",
|
| 48 |
+
" tt = {}\n",
|
| 49 |
+
" for line in fo:\n",
|
| 50 |
+
" d = json.loads(line)\n",
|
| 51 |
+
" if not d['_source'].get('nlp'):\n",
|
| 52 |
+
" continue\n",
|
| 53 |
+
" for t in d['_source']['nlp']['theme']:\n",
|
| 54 |
+
" tt[t] = tt.get(t, 0) + 1\n",
|
| 55 |
+
" tt = dict(sorted(tt.items(), key=lambda item: item[1]))\n",
|
| 56 |
+
" print(file, tt)"
|
| 57 |
+
]
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
"cell_type": "code",
|
| 61 |
+
"execution_count": null,
|
| 62 |
+
"id": "94a9c39a-e060-4b87-8edd-e65a3c480c86",
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": [
|
| 66 |
+
"files = ['art_en2_specific.json', 'art_non_en_specific.json']\n",
|
| 67 |
+
"for file in files:\n",
|
| 68 |
+
" fo = open('raw_articles/' + file)\n",
|
| 69 |
+
" tt = {}\n",
|
| 70 |
+
" for line in fo:\n",
|
| 71 |
+
" d = json.loads(line)\n",
|
| 72 |
+
" if not d['_source'].get('nlp'):\n",
|
| 73 |
+
" continue\n",
|
| 74 |
+
" for t in d['_source']['nlp']['theme']:\n",
|
| 75 |
+
" tt[t] = tt.get(t, 0) + 1\n",
|
| 76 |
+
" tt = dict(sorted(tt.items(), key=lambda item: item[1]))\n",
|
| 77 |
+
" print(file, tt)"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"id": "b2301a8d-aef2-4776-9d02-c35aae042267",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": []
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": null,
|
| 91 |
+
"id": "e28c0287-b649-4dea-a6f3-1e83a2669b69",
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"outputs": [],
|
| 94 |
+
"source": [
|
| 95 |
+
"files = ['art_en.json', 'art_en2.json', 'art_non_en.json', 'art_non_en2.json', 'art_en2_specific.json', 'art_non_en_specific.json']\n",
|
| 96 |
+
"tt = {}\n",
|
| 97 |
+
"for file in files:\n",
|
| 98 |
+
" fo = open('raw_articles/' + file)\n",
|
| 99 |
+
" for line in fo:\n",
|
| 100 |
+
" d = json.loads(line)\n",
|
| 101 |
+
" if not d['_source'].get('nlp'):\n",
|
| 102 |
+
" continue\n",
|
| 103 |
+
" for t in d['_source']['nlp']['theme']:\n",
|
| 104 |
+
" tt[t] = tt.get(t, 0) + 1\n",
|
| 105 |
+
" tt = dict(sorted(tt.items(), key=lambda item: item[1]))\n",
|
| 106 |
+
"print(tt)"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"id": "9b735036-3a34-4d20-bae9-58f8b5bc7100",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [],
|
| 115 |
+
"source": [
|
| 116 |
+
"priority_order = [\n",
|
| 117 |
+
" \"Economics\",\n",
|
| 118 |
+
" \"Financial Crime\",\n",
|
| 119 |
+
" \"Finance\",\n",
|
| 120 |
+
" \"Lifestyle\",\n",
|
| 121 |
+
" \"Automotive\",\n",
|
| 122 |
+
" \"Science\",\n",
|
| 123 |
+
" \"Tech\",\n",
|
| 124 |
+
" \"Travel\",\n",
|
| 125 |
+
" \"Weather\",\n",
|
| 126 |
+
" \"Health\",\n",
|
| 127 |
+
" \"Crime\",\n",
|
| 128 |
+
" \"Sports\",\n",
|
| 129 |
+
" \"General\",\n",
|
| 130 |
+
" \"Business\",\n",
|
| 131 |
+
" \"Politics\",\n",
|
| 132 |
+
" \"Entertainment\",\n",
|
| 133 |
+
" ]\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"def sort_themes_by_priority(themes):\n",
|
| 136 |
+
" return sorted(\n",
|
| 137 |
+
" themes,\n",
|
| 138 |
+
" key=lambda theme: priority_order.index(theme)\n",
|
| 139 |
+
" if theme in priority_order else float(\"inf\")\n",
|
| 140 |
+
" )"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "code",
|
| 145 |
+
"execution_count": null,
|
| 146 |
+
"id": "e71c50e1-c159-4767-b366-f23153199e96",
|
| 147 |
+
"metadata": {},
|
| 148 |
+
"outputs": [],
|
| 149 |
+
"source": [
|
| 150 |
+
"en_ds = {}\n",
|
| 151 |
+
"\n",
|
| 152 |
+
"files = ['art_en.json', 'art_en2.json', 'art_en2_specific.json']\n",
|
| 153 |
+
"ts = set()\n",
|
| 154 |
+
"for file in files:\n",
|
| 155 |
+
" fo = open('raw_articles/' + file)\n",
|
| 156 |
+
" for line in fo:\n",
|
| 157 |
+
" d = json.loads(line)\n",
|
| 158 |
+
" art = {\"id\": d[\"_id\"], **d['_source']}\n",
|
| 159 |
+
" if not art.get('nlp'):\n",
|
| 160 |
+
" continue\n",
|
| 161 |
+
" if not art['nlp'].get('new_embedding') or not art['nlp'].get('theme'):\n",
|
| 162 |
+
" continue\n",
|
| 163 |
+
" th = hash(art['title'].lower())\n",
|
| 164 |
+
" if th in ts:\n",
|
| 165 |
+
" continue\n",
|
| 166 |
+
" ts.add(th)\n",
|
| 167 |
+
" sts = art['nlp']['theme']\n",
|
| 168 |
+
" sts = sort_themes_by_priority(sts)\n",
|
| 169 |
+
" for st in sts:\n",
|
| 170 |
+
" if st not in en_ds:\n",
|
| 171 |
+
" en_ds[st] = []\n",
|
| 172 |
+
" if len(en_ds[st]) >= 80000:\n",
|
| 173 |
+
" continue\n",
|
| 174 |
+
" en_ds[st].append(art)\n",
|
| 175 |
+
" break"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "code",
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"id": "71b6a2b2-20c1-4fb7-a10d-2c4a5eb1d1e4",
|
| 182 |
+
"metadata": {},
|
| 183 |
+
"outputs": [],
|
| 184 |
+
"source": [
|
| 185 |
+
"for st in en_ds:\n",
|
| 186 |
+
" print(st, '==>', len(en_ds[st]))"
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "code",
|
| 191 |
+
"execution_count": null,
|
| 192 |
+
"id": "e8ffff0a-94ac-4d50-8ca1-179d163c458f",
|
| 193 |
+
"metadata": {},
|
| 194 |
+
"outputs": [],
|
| 195 |
+
"source": [
|
| 196 |
+
"json.dump(en_ds, open('filtered/en_ds.json', 'w'))"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "code",
|
| 201 |
+
"execution_count": null,
|
| 202 |
+
"id": "0cdd6962-d7f8-4c1b-aea5-979b0f622308",
|
| 203 |
+
"metadata": {},
|
| 204 |
+
"outputs": [],
|
| 205 |
+
"source": [
|
| 206 |
+
"from fast_langdetect import detect"
|
| 207 |
+
]
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"cell_type": "code",
|
| 211 |
+
"execution_count": null,
|
| 212 |
+
"id": "cdb31010-0d34-47dd-8843-2fd417cea31f",
|
| 213 |
+
"metadata": {},
|
| 214 |
+
"outputs": [],
|
| 215 |
+
"source": [
|
| 216 |
+
"ml_ds = {}\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"files = ['art_non_en.json', 'art_non_en2.json', 'art_non_en_specific.json']\n",
|
| 219 |
+
"ts = set()\n",
|
| 220 |
+
"for file in files:\n",
|
| 221 |
+
" fo = open('raw_articles/' + file)\n",
|
| 222 |
+
" for line in fo:\n",
|
| 223 |
+
" d = json.loads(line)\n",
|
| 224 |
+
" art = {\"id\": d[\"_id\"], **d['_source']}\n",
|
| 225 |
+
" if not art.get('nlp'):\n",
|
| 226 |
+
" continue\n",
|
| 227 |
+
" if not art['nlp'].get('new_embedding') or not art['nlp'].get('theme'):\n",
|
| 228 |
+
" continue\n",
|
| 229 |
+
" th = hash(art['title'].lower())\n",
|
| 230 |
+
" if th in ts:\n",
|
| 231 |
+
" continue\n",
|
| 232 |
+
" ts.add(th)\n",
|
| 233 |
+
"\n",
|
| 234 |
+
" lang = detect(text=art['content'], model='full', k=1)[0]['lang']\n",
|
| 235 |
+
" if lang not in ml_ds:\n",
|
| 236 |
+
" ml_ds[lang] = {}\n",
|
| 237 |
+
"\n",
|
| 238 |
+
" sts = art['nlp']['theme']\n",
|
| 239 |
+
" sts = sort_themes_by_priority(sts)\n",
|
| 240 |
+
" for st in sts:\n",
|
| 241 |
+
" if st not in ml_ds[lang]:\n",
|
| 242 |
+
" ml_ds[lang][st] = []\n",
|
| 243 |
+
" if len(ml_ds[lang][st]) >= 5000:\n",
|
| 244 |
+
" continue\n",
|
| 245 |
+
" ml_ds[lang][st].append(art)\n",
|
| 246 |
+
" break"
|
| 247 |
+
]
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"cell_type": "code",
|
| 251 |
+
"execution_count": null,
|
| 252 |
+
"id": "0dbe9913-f48f-4f08-9c25-89e9b4e9dc71",
|
| 253 |
+
"metadata": {},
|
| 254 |
+
"outputs": [],
|
| 255 |
+
"source": []
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"cell_type": "code",
|
| 259 |
+
"execution_count": null,
|
| 260 |
+
"id": "222f0b88-0808-426b-8ace-f57727265f79",
|
| 261 |
+
"metadata": {},
|
| 262 |
+
"outputs": [],
|
| 263 |
+
"source": [
|
| 264 |
+
"for lang in ml_ds:\n",
|
| 265 |
+
" xx = {t: len(ml_ds[lang][t]) for t in ml_ds[lang]}\n",
|
| 266 |
+
" print(lang, '==>', xx)"
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"cell_type": "code",
|
| 271 |
+
"execution_count": null,
|
| 272 |
+
"id": "127ce293-80b4-4f9d-973f-8293df8865ad",
|
| 273 |
+
"metadata": {},
|
| 274 |
+
"outputs": [],
|
| 275 |
+
"source": [
|
| 276 |
+
"for lang in ml_ds:\n",
|
| 277 |
+
" tot = sum(len(ml_ds[lang][t]) for t in ml_ds[lang])\n",
|
| 278 |
+
" if tot < 1000:\n",
|
| 279 |
+
" continue\n",
|
| 280 |
+
" json.dump(ml_ds[lang], open('filtered/' + lang + '_ds.json', 'w'))"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"cell_type": "markdown",
|
| 285 |
+
"id": "44328527-b38f-4a69-90cd-92fbaa1e1fd3",
|
| 286 |
+
"metadata": {},
|
| 287 |
+
"source": [
|
| 288 |
+
"# 2. choose MMR candidates (take at least 15K per theme and 100 per lang, en=7K, other=8K)"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"cell_type": "code",
|
| 293 |
+
"execution_count": null,
|
| 294 |
+
"id": "9ac418a0-0efd-4c77-ae86-4d8c8eb107c5",
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"outputs": [],
|
| 297 |
+
"source": [
|
| 298 |
+
"#!/usr/bin/env python3\n",
|
| 299 |
+
"from __future__ import annotations\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"from typing import Any, Dict, List, Optional\n",
|
| 302 |
+
"import numpy as np\n",
|
| 303 |
+
"import torch\n",
|
| 304 |
+
"import faiss\n",
|
| 305 |
+
"import faiss.contrib.torch_utils # enables passing torch tensors to faiss on GPU\n",
|
| 306 |
+
"\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"def torch_l2_normalize_(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:\n",
|
| 309 |
+
" # in-place row-wise normalize\n",
|
| 310 |
+
" norms = torch.linalg.norm(x, dim=1, keepdim=True).clamp_min(eps)\n",
|
| 311 |
+
" x.div_(norms)\n",
|
| 312 |
+
" return x\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"\n",
|
| 315 |
+
"class FaissGpuDiverseSelectorTorch:\n",
|
| 316 |
+
" \"\"\"\n",
|
| 317 |
+
" GPU-only pipeline:\n",
|
| 318 |
+
" - build embeddings tensor on GPU (torch)\n",
|
| 319 |
+
" - normalize on GPU\n",
|
| 320 |
+
" - build GpuIndexFlatIP (cosine via normalization)\n",
|
| 321 |
+
" - diversity selection via GPU spherical clustering (on a SAMPLE) + GPU NN mapping\n",
|
| 322 |
+
" \"\"\"\n",
|
| 323 |
+
"\n",
|
| 324 |
+
" def __init__(\n",
|
| 325 |
+
" self,\n",
|
| 326 |
+
" embedding_key: str = \"new_embedding\",\n",
|
| 327 |
+
" gpu_id: int = 0,\n",
|
| 328 |
+
" seed: int = 123,\n",
|
| 329 |
+
" temp_mem_gb: float = 6.0,\n",
|
| 330 |
+
" use_float16_storage: bool = True,\n",
|
| 331 |
+
" build_batch_size: int = 16384, # batching reduces peak CPU RAM + helps conversion\n",
|
| 332 |
+
" ):\n",
|
| 333 |
+
" self.embedding_key = embedding_key\n",
|
| 334 |
+
" self.gpu_id = int(gpu_id)\n",
|
| 335 |
+
" self.seed = int(seed)\n",
|
| 336 |
+
" self.temp_mem_bytes = int(temp_mem_gb * 1024**3)\n",
|
| 337 |
+
" self.use_float16_storage = bool(use_float16_storage)\n",
|
| 338 |
+
" self.build_batch_size = int(build_batch_size)\n",
|
| 339 |
+
"\n",
|
| 340 |
+
" self.items: Optional[List[Dict[str, Any]]] = None\n",
|
| 341 |
+
" self.xb: Optional[torch.Tensor] = None # (N,d) on GPU, float32 normalized\n",
|
| 342 |
+
" self.res: Optional[faiss.StandardGpuResources] = None\n",
|
| 343 |
+
" self.index: Optional[faiss.GpuIndexFlatIP] = None\n",
|
| 344 |
+
" self.d: Optional[int] = None\n",
|
| 345 |
+
"\n",
|
| 346 |
+
" def fit(self, items: List[Dict[str, Any]]) -> \"FaissGpuDiverseSelectorTorch\":\n",
|
| 347 |
+
" ngpu = faiss.get_num_gpus()\n",
|
| 348 |
+
" if ngpu <= 0:\n",
|
| 349 |
+
" raise RuntimeError(\"faiss.get_num_gpus()==0. You are not using faiss-gpu / CUDA not visible.\")\n",
|
| 350 |
+
" if self.gpu_id >= ngpu:\n",
|
| 351 |
+
" raise RuntimeError(f\"gpu_id={self.gpu_id} out of range; FAISS sees {ngpu} GPUs.\")\n",
|
| 352 |
+
"\n",
|
| 353 |
+
" self.items = items\n",
|
| 354 |
+
" n = len(items)\n",
|
| 355 |
+
" if n == 0:\n",
|
| 356 |
+
" raise ValueError(\"Empty items\")\n",
|
| 357 |
+
"\n",
|
| 358 |
+
" # Infer dim from first embedding\n",
|
| 359 |
+
" first = items[0][self.embedding_key]\n",
|
| 360 |
+
" d = len(first)\n",
|
| 361 |
+
" self.d = d\n",
|
| 362 |
+
"\n",
|
| 363 |
+
" # Build GPU tensor in batches (still CPU->GPU copy, but avoids extra NumPy overhead)\n",
|
| 364 |
+
" dev = torch.device(f\"cuda:{self.gpu_id}\")\n",
|
| 365 |
+
" chunks: List[torch.Tensor] = []\n",
|
| 366 |
+
"\n",
|
| 367 |
+
" bs = self.build_batch_size\n",
|
| 368 |
+
" for start in range(0, n, bs):\n",
|
| 369 |
+
" end = min(start + bs, n)\n",
|
| 370 |
+
" # NOTE: this Python loop is unavoidable with list-of-dicts.\n",
|
| 371 |
+
" # If you can store embeddings as a single array/tensor upstream, do that instead.\n",
|
| 372 |
+
" batch = np.asarray([items[i][self.embedding_key] for i in range(start, end)], dtype=np.float32)\n",
|
| 373 |
+
" t = torch.from_numpy(batch).to(device=dev, non_blocking=False)\n",
|
| 374 |
+
" chunks.append(t)\n",
|
| 375 |
+
"\n",
|
| 376 |
+
" xb = torch.cat(chunks, dim=0) # (N,d) float32 on GPU\n",
|
| 377 |
+
"\n",
|
| 378 |
+
" # Normalize on GPU\n",
|
| 379 |
+
" torch_l2_normalize_(xb)\n",
|
| 380 |
+
"\n",
|
| 381 |
+
" # Create and keep GPU resources + big temp memory\n",
|
| 382 |
+
" res = faiss.StandardGpuResources()\n",
|
| 383 |
+
" res.setTempMemory(self.temp_mem_bytes)\n",
|
| 384 |
+
" self.res = res\n",
|
| 385 |
+
"\n",
|
| 386 |
+
" # GPU IP index (cosine via normalization)\n",
|
| 387 |
+
" cfg = faiss.GpuIndexFlatConfig()\n",
|
| 388 |
+
" cfg.device = self.gpu_id\n",
|
| 389 |
+
" cfg.useFloat16 = self.use_float16_storage\n",
|
| 390 |
+
"\n",
|
| 391 |
+
" index = faiss.GpuIndexFlatIP(res, d, cfg)\n",
|
| 392 |
+
"\n",
|
| 393 |
+
" # Add directly from GPU torch tensor (faiss.contrib.torch_utils)\n",
|
| 394 |
+
" index.add(xb)\n",
|
| 395 |
+
"\n",
|
| 396 |
+
" self.xb = xb\n",
|
| 397 |
+
" self.index = index\n",
|
| 398 |
+
" return self\n",
|
| 399 |
+
"\n",
|
| 400 |
+
" def select(\n",
|
| 401 |
+
" self,\n",
|
| 402 |
+
" num_select: int,\n",
|
| 403 |
+
" train_per_centroid: int = 500, # IMPORTANT: keep this modest to avoid “forever”\n",
|
| 404 |
+
" niter: int = 6,\n",
|
| 405 |
+
" nredo: int = 1,\n",
|
| 406 |
+
" centroid_search_k: int = 128,\n",
|
| 407 |
+
" ) -> List[Dict[str, Any]]:\n",
|
| 408 |
+
" if self.items is None or self.xb is None or self.index is None or self.res is None or self.d is None:\n",
|
| 409 |
+
" raise RuntimeError(\"Call fit() first.\")\n",
|
| 410 |
+
"\n",
|
| 411 |
+
" items = self.items\n",
|
| 412 |
+
" xb = self.xb\n",
|
| 413 |
+
" index = self.index\n",
|
| 414 |
+
" res = self.res\n",
|
| 415 |
+
" d = self.d\n",
|
| 416 |
+
"\n",
|
| 417 |
+
" N = xb.shape[0]\n",
|
| 418 |
+
" k = min(int(num_select), int(N))\n",
|
| 419 |
+
" if k <= 0:\n",
|
| 420 |
+
" return []\n",
|
| 421 |
+
"\n",
|
| 422 |
+
" # Sample training data ON GPU\n",
|
| 423 |
+
" train_sz = min(int(N), k * int(train_per_centroid))\n",
|
| 424 |
+
" g = torch.Generator(device=xb.device)\n",
|
| 425 |
+
" g.manual_seed(self.seed)\n",
|
| 426 |
+
"\n",
|
| 427 |
+
" if train_sz < N:\n",
|
| 428 |
+
" perm = torch.randperm(int(N), generator=g, device=xb.device)\n",
|
| 429 |
+
" train_idx = perm[:train_sz]\n",
|
| 430 |
+
" xtrain = xb.index_select(0, train_idx)\n",
|
| 431 |
+
" else:\n",
|
| 432 |
+
" xtrain = xb\n",
|
| 433 |
+
"\n",
|
| 434 |
+
" # FAISS GPU spherical clustering\n",
|
| 435 |
+
" clus = faiss.Clustering(d, k)\n",
|
| 436 |
+
" clus.seed = self.seed\n",
|
| 437 |
+
" clus.niter = int(niter)\n",
|
| 438 |
+
" clus.nredo = int(nredo)\n",
|
| 439 |
+
" clus.spherical = True\n",
|
| 440 |
+
" clus.verbose = False\n",
|
| 441 |
+
"\n",
|
| 442 |
+
" cfg = faiss.GpuIndexFlatConfig()\n",
|
| 443 |
+
" cfg.device = self.gpu_id\n",
|
| 444 |
+
" cfg.useFloat16 = self.use_float16_storage\n",
|
| 445 |
+
" assign_index = faiss.GpuIndexFlatIP(res, d, cfg)\n",
|
| 446 |
+
"\n",
|
| 447 |
+
" clus.train(xtrain.cpu().numpy(), assign_index)\n",
|
| 448 |
+
"\n",
|
| 449 |
+
" centroids = faiss.vector_to_array(clus.centroids).reshape(k, d).astype(np.float32, copy=False)\n",
|
| 450 |
+
" # Move centroids to GPU torch, normalize, then search on GPU\n",
|
| 451 |
+
" C = torch.from_numpy(centroids).to(device=xb.device)\n",
|
| 452 |
+
" torch_l2_normalize_(C)\n",
|
| 453 |
+
"\n",
|
| 454 |
+
" centroid_search_k = max(int(centroid_search_k), 1)\n",
|
| 455 |
+
" _, I = index.search(C, centroid_search_k) # I is a torch tensor (thanks to torch_utils)\n",
|
| 456 |
+
"\n",
|
| 457 |
+
" chosen: List[int] = []\n",
|
| 458 |
+
" used = set()\n",
|
| 459 |
+
"\n",
|
| 460 |
+
" I_cpu = I.to(\"cpu\").numpy() # small: (k, centroid_search_k)\n",
|
| 461 |
+
" for r in range(k):\n",
|
| 462 |
+
" pick = None\n",
|
| 463 |
+
" for j in range(centroid_search_k):\n",
|
| 464 |
+
" idx = int(I_cpu[r, j])\n",
|
| 465 |
+
" if idx >= 0 and idx not in used:\n",
|
| 466 |
+
" pick = idx\n",
|
| 467 |
+
" break\n",
|
| 468 |
+
" if pick is not None:\n",
|
| 469 |
+
" used.add(pick)\n",
|
| 470 |
+
" chosen.append(pick)\n",
|
| 471 |
+
"\n",
|
| 472 |
+
" # Fill if duplicates reduced count\n",
|
| 473 |
+
" if len(chosen) < k:\n",
|
| 474 |
+
" for r in range(k):\n",
|
| 475 |
+
" if len(chosen) >= k:\n",
|
| 476 |
+
" break\n",
|
| 477 |
+
" for j in range(centroid_search_k):\n",
|
| 478 |
+
" idx = int(I_cpu[r, j])\n",
|
| 479 |
+
" if idx >= 0 and idx not in used:\n",
|
| 480 |
+
" used.add(idx)\n",
|
| 481 |
+
" chosen.append(idx)\n",
|
| 482 |
+
" if len(chosen) >= k:\n",
|
| 483 |
+
" break\n",
|
| 484 |
+
"\n",
|
| 485 |
+
" return [items[i] for i in chosen[:k]]"
|
| 486 |
+
]
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"cell_type": "code",
|
| 490 |
+
"execution_count": null,
|
| 491 |
+
"id": "2ab2368e-7ec0-4870-b092-243b6c45bd70",
|
| 492 |
+
"metadata": {},
|
| 493 |
+
"outputs": [],
|
| 494 |
+
"source": []
|
| 495 |
+
},
|
| 496 |
+
{
|
| 497 |
+
"cell_type": "code",
|
| 498 |
+
"execution_count": null,
|
| 499 |
+
"id": "d5fd1433-1ffd-489f-8bc7-bca2d76a22d4",
|
| 500 |
+
"metadata": {},
|
| 501 |
+
"outputs": [],
|
| 502 |
+
"source": []
|
| 503 |
+
},
|
| 504 |
+
{
|
| 505 |
+
"cell_type": "code",
|
| 506 |
+
"execution_count": null,
|
| 507 |
+
"id": "c4a937d7-e6bf-4f98-bbcd-ed9fb746b731",
|
| 508 |
+
"metadata": {},
|
| 509 |
+
"outputs": [],
|
| 510 |
+
"source": [
|
| 511 |
+
"files = os.listdir('filtered')\n",
|
| 512 |
+
"files = [i for i in files if i != 'en_ds.json' and i.endswith('.json')]"
|
| 513 |
+
]
|
| 514 |
+
},
|
| 515 |
+
{
|
| 516 |
+
"cell_type": "code",
|
| 517 |
+
"execution_count": null,
|
| 518 |
+
"id": "c3c93e4b-5f2e-4ad2-9b39-59942539bce5",
|
| 519 |
+
"metadata": {
|
| 520 |
+
"scrolled": true
|
| 521 |
+
},
|
| 522 |
+
"outputs": [],
|
| 523 |
+
"source": [
|
| 524 |
+
"non_en = {}\n",
|
| 525 |
+
"for file in files:\n",
|
| 526 |
+
" data = json.load(open('filtered/'+file))\n",
|
| 527 |
+
" for theme in data:\n",
|
| 528 |
+
" data[theme] = [{\"id\": i[\"id\"], \"title\": i[\"title\"], \"content\": i[\"content\"], \"language\": i[\"language\"], **i[\"nlp\"]} for i in data[theme]]\n",
|
| 529 |
+
" mmr_cands = len(data[theme])//5\n",
|
| 530 |
+
" if mmr_cands <= 2:\n",
|
| 531 |
+
" continue\n",
|
| 532 |
+
" selector = FaissGpuDiverseSelectorTorch(\n",
|
| 533 |
+
" embedding_key=\"new_embedding\",\n",
|
| 534 |
+
" gpu_id=0,\n",
|
| 535 |
+
" temp_mem_gb=6.0,\n",
|
| 536 |
+
" use_float16_storage=True,\n",
|
| 537 |
+
" build_batch_size=8192,\n",
|
| 538 |
+
" ).fit(data[theme])\n",
|
| 539 |
+
" picked = selector.select(\n",
|
| 540 |
+
" num_select=mmr_cands,\n",
|
| 541 |
+
" train_per_centroid=500,\n",
|
| 542 |
+
" niter=100,\n",
|
| 543 |
+
" centroid_search_k=512,\n",
|
| 544 |
+
" )\n",
|
| 545 |
+
" if theme not in non_en:\n",
|
| 546 |
+
" non_en[theme] = []\n",
|
| 547 |
+
" print(file, theme, len(data[theme]), len(picked), mmr_cands)\n",
|
| 548 |
+
" non_en[theme].extend(picked)\n"
|
| 549 |
+
]
|
| 550 |
+
},
|
| 551 |
+
{
|
| 552 |
+
"cell_type": "code",
|
| 553 |
+
"execution_count": null,
|
| 554 |
+
"id": "499a2017-1bdc-44ea-b886-6a55bf7f8a36",
|
| 555 |
+
"metadata": {
|
| 556 |
+
"scrolled": true
|
| 557 |
+
},
|
| 558 |
+
"outputs": [],
|
| 559 |
+
"source": [
|
| 560 |
+
"final_ds = {}\n",
|
| 561 |
+
"for theme in non_en:\n",
|
| 562 |
+
" selector = FaissGpuDiverseSelectorTorch(\n",
|
| 563 |
+
" embedding_key=\"new_embedding\",\n",
|
| 564 |
+
" gpu_id=0,\n",
|
| 565 |
+
" temp_mem_gb=6.0,\n",
|
| 566 |
+
" use_float16_storage=True,\n",
|
| 567 |
+
" build_batch_size=8192,\n",
|
| 568 |
+
" ).fit(non_en[theme])\n",
|
| 569 |
+
" picked = selector.select(\n",
|
| 570 |
+
" num_select=7000,\n",
|
| 571 |
+
" train_per_centroid=500,\n",
|
| 572 |
+
" niter=100,\n",
|
| 573 |
+
" centroid_search_k=512,\n",
|
| 574 |
+
" )\n",
|
| 575 |
+
" if theme not in final_ds:\n",
|
| 576 |
+
" final_ds[theme] = []\n",
|
| 577 |
+
" print(theme, len(non_en[theme]), len(picked))\n",
|
| 578 |
+
" for i in picked:\n",
|
| 579 |
+
" del i['new_embedding']\n",
|
| 580 |
+
" final_ds[theme].extend(picked)\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"json.dump(final_ds, open('filtered2/ml_step1.json', 'w'))"
|
| 583 |
+
]
|
| 584 |
+
},
|
| 585 |
+
{
|
| 586 |
+
"cell_type": "code",
|
| 587 |
+
"execution_count": null,
|
| 588 |
+
"id": "7161b895-b229-47a6-9082-b493fbd7428a",
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"outputs": [],
|
| 591 |
+
"source": [
|
| 592 |
+
"en = json.load(open('filtered/en_ds.json'))"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
{
|
| 596 |
+
"cell_type": "code",
|
| 597 |
+
"execution_count": null,
|
| 598 |
+
"id": "ce6213a5-40bb-454e-950b-b4b23e045e13",
|
| 599 |
+
"metadata": {
|
| 600 |
+
"scrolled": true
|
| 601 |
+
},
|
| 602 |
+
"outputs": [],
|
| 603 |
+
"source": [
|
| 604 |
+
"final_ds = {}\n",
|
| 605 |
+
"for theme in en:\n",
|
| 606 |
+
" en[theme] = [{\"id\": i[\"id\"], \"title\": i[\"title\"], \"content\": i[\"content\"], \"language\": i[\"language\"], **i[\"nlp\"]} for i in en[theme]]\n",
|
| 607 |
+
" selector = FaissGpuDiverseSelectorTorch(\n",
|
| 608 |
+
" embedding_key=\"new_embedding\",\n",
|
| 609 |
+
" gpu_id=0,\n",
|
| 610 |
+
" temp_mem_gb=6.0,\n",
|
| 611 |
+
" use_float16_storage=True,\n",
|
| 612 |
+
" build_batch_size=8192,\n",
|
| 613 |
+
" ).fit(en[theme])\n",
|
| 614 |
+
" picked = selector.select(\n",
|
| 615 |
+
" num_select=8000,\n",
|
| 616 |
+
" train_per_centroid=500,\n",
|
| 617 |
+
" niter=100,\n",
|
| 618 |
+
" centroid_search_k=512,\n",
|
| 619 |
+
" )\n",
|
| 620 |
+
" if theme not in final_ds:\n",
|
| 621 |
+
" final_ds[theme] = []\n",
|
| 622 |
+
" print(theme, len(en[theme]), len(picked))\n",
|
| 623 |
+
" for i in picked:\n",
|
| 624 |
+
" del i['new_embedding']\n",
|
| 625 |
+
" final_ds[theme].extend(picked)\n",
|
| 626 |
+
"\n",
|
| 627 |
+
"json.dump(final_ds, open('filtered2/en_step1.json', 'w'))"
|
| 628 |
+
]
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"cell_type": "markdown",
|
| 632 |
+
"id": "842369c9-6624-404d-808b-1e39a4a6a1b0",
|
| 633 |
+
"metadata": {},
|
| 634 |
+
"source": [
|
| 635 |
+
"# 3. calculate embedding of qwen"
|
| 636 |
+
]
|
| 637 |
+
},
|
| 638 |
+
{
|
| 639 |
+
"cell_type": "code",
|
| 640 |
+
"execution_count": null,
|
| 641 |
+
"id": "bfce123b-6d47-4353-b36c-3db708274742",
|
| 642 |
+
"metadata": {},
|
| 643 |
+
"outputs": [],
|
| 644 |
+
"source": [
|
| 645 |
+
"import requests\n",
|
| 646 |
+
"from requests.adapters import HTTPAdapter\n",
|
| 647 |
+
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
|
| 648 |
+
"from typing import List, Dict, Any, Optional, Tuple\n",
|
| 649 |
+
"\n",
|
| 650 |
+
"def _make_session(pool_size: int) -> requests.Session:\n",
|
| 651 |
+
" s = requests.Session()\n",
|
| 652 |
+
" adapter = HTTPAdapter(pool_connections=pool_size, pool_maxsize=pool_size, max_retries=0)\n",
|
| 653 |
+
" s.mount(\"http://\", adapter)\n",
|
| 654 |
+
" s.mount(\"https://\", adapter)\n",
|
| 655 |
+
" return s\n",
|
| 656 |
+
"\n",
|
| 657 |
+
"\n",
|
| 658 |
+
"def _embed_batch(\n",
|
| 659 |
+
" session: requests.Session,\n",
|
| 660 |
+
" base_url: str,\n",
|
| 661 |
+
" text: str,\n",
|
| 662 |
+
" timeout_s: float,\n",
|
| 663 |
+
" extra_headers = {'x-api-token': '***'},\n",
|
| 664 |
+
") -> Any:\n",
|
| 665 |
+
" \"\"\"\n",
|
| 666 |
+
" Native TEI endpoint: POST {base_url}/embed\n",
|
| 667 |
+
" Payload: {\"inputs\": [..texts..]}\n",
|
| 668 |
+
" \"\"\"\n",
|
| 669 |
+
" url = base_url.rstrip(\"/\") + \"/qwen\"\n",
|
| 670 |
+
" headers = {\"Content-Type\": \"application/json\"}\n",
|
| 671 |
+
" if not text.startswith('query: '):\n",
|
| 672 |
+
" text = 'query: ' + text\n",
|
| 673 |
+
" if extra_headers:\n",
|
| 674 |
+
" headers.update(extra_headers)\n",
|
| 675 |
+
" r = session.post(url, json={\"inputs\": text}, headers=headers, timeout=timeout_s)\n",
|
| 676 |
+
" r.raise_for_status()\n",
|
| 677 |
+
" return r.json()\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"def one_call(art):\n",
|
| 680 |
+
" if \"embedding\" in art:\n",
|
| 681 |
+
" return\n",
|
| 682 |
+
" text = []\n",
|
| 683 |
+
" if art.get(\"title\"):\n",
|
| 684 |
+
" text.append(art[\"title\"])\n",
|
| 685 |
+
" if art.get(\"content\"):\n",
|
| 686 |
+
" text.append(art[\"content\"])\n",
|
| 687 |
+
" text = \"\\n\\n\".join(text)\n",
|
| 688 |
+
" try:\n",
|
| 689 |
+
" res = _embed_batch(session, base_url, text, timeout_s=90)\n",
|
| 690 |
+
" except Exception as e:\n",
|
| 691 |
+
" print(e)\n",
|
| 692 |
+
" art[\"embedding\"] = res[0]\n",
|
| 693 |
+
"\n",
|
| 694 |
+
"WORKERS = 512\n",
|
| 695 |
+
"session = _make_session(pool_size=WORKERS)\n",
|
| 696 |
+
"base_url = \"http://65.19.132.154:9001\""
|
| 697 |
+
]
|
| 698 |
+
},
|
| 699 |
+
{
|
| 700 |
+
"cell_type": "code",
|
| 701 |
+
"execution_count": null,
|
| 702 |
+
"id": "e5e64eac-dc40-41ec-a497-45ce9cb2f637",
|
| 703 |
+
"metadata": {},
|
| 704 |
+
"outputs": [],
|
| 705 |
+
"source": []
|
| 706 |
+
},
|
| 707 |
+
{
|
| 708 |
+
"cell_type": "code",
|
| 709 |
+
"execution_count": null,
|
| 710 |
+
"id": "7304ec1b-2a7a-42f7-a0fb-f6e5e6c4eb94",
|
| 711 |
+
"metadata": {},
|
| 712 |
+
"outputs": [],
|
| 713 |
+
"source": [
|
| 714 |
+
"for file in ['en_step1.json', 'ml_step1.json']:\n",
|
| 715 |
+
" data = json.load(open('filtered2/' + file))\n",
|
| 716 |
+
" for theme in data:\n",
|
| 717 |
+
" with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
|
| 718 |
+
" _ = [ex.submit(one_call, art) for art in data[theme]]\n",
|
| 719 |
+
" json.dump(data, open('filtered2/' + file, 'w'))"
|
| 720 |
+
]
|
| 721 |
+
},
|
| 722 |
+
{
|
| 723 |
+
"cell_type": "markdown",
|
| 724 |
+
"id": "59b381bb-d4ca-4414-b299-3da7531faed3",
|
| 725 |
+
"metadata": {
|
| 726 |
+
"jp-MarkdownHeadingCollapsed": true
|
| 727 |
+
},
|
| 728 |
+
"source": [
|
| 729 |
+
"# 4. again choose MMR based on the qwen embedding; also make a list of LLM-verify candidates where the qwen embedding does not match the given themes (5K each)"
|
| 730 |
+
]
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"cell_type": "code",
|
| 734 |
+
"execution_count": null,
|
| 735 |
+
"id": "80e76f11-f5de-4c3e-b130-5646b4bf5273",
|
| 736 |
+
"metadata": {},
|
| 737 |
+
"outputs": [],
|
| 738 |
+
"source": [
|
| 739 |
+
"from fast_langdetect import detect\n",
|
| 740 |
+
"\n",
|
| 741 |
+
"en_data, ml_data = [], []\n",
|
| 742 |
+
"data = json.load(open('filtered2/en_step1.json'))\n",
|
| 743 |
+
"for theme in data:\n",
|
| 744 |
+
" en_data.extend(data[theme])\n",
|
| 745 |
+
"\n",
|
| 746 |
+
"data = json.load(open('filtered2/ml_step1.json'))\n",
|
| 747 |
+
"for theme in data:\n",
|
| 748 |
+
" ml_data.extend(data[theme])\n",
|
| 749 |
+
"\n",
|
| 750 |
+
"for file in ['train1.json', 'spanish_tagged.json', 'train_multilang_2.json']:\n",
|
| 751 |
+
" data = json.load(open('filtered2/' + file))\n",
|
| 752 |
+
" for i in data:\n",
|
| 753 |
+
" if not i.get('content'):\n",
|
| 754 |
+
" continue\n",
|
| 755 |
+
" l = i.get('language')\n",
|
| 756 |
+
" if not l:\n",
|
| 757 |
+
" l = i.get('language')\n",
|
| 758 |
+
" if not l:\n",
|
| 759 |
+
" l = detect(text=i['content'], model='full', k=1)[0]['lang']\n",
|
| 760 |
+
" t = {**i, \"language\": l, \"file\": file}\n",
|
| 761 |
+
" if l == 'en':\n",
|
| 762 |
+
" en_data.append(t)\n",
|
| 763 |
+
" else:\n",
|
| 764 |
+
" ml_data.append(t)"
|
| 765 |
+
]
|
| 766 |
+
},
|
| 767 |
+
{
|
| 768 |
+
"cell_type": "code",
|
| 769 |
+
"execution_count": null,
|
| 770 |
+
"id": "1a435d53-1cf9-46e3-809a-208bb0a76122",
|
| 771 |
+
"metadata": {},
|
| 772 |
+
"outputs": [],
|
| 773 |
+
"source": [
|
| 774 |
+
"# for i in ml_data:\n",
|
| 775 |
+
"# if \"embedding\" in i:\n",
|
| 776 |
+
"# del i[\"embedding\"]"
|
| 777 |
+
]
|
| 778 |
+
},
|
| 779 |
+
{
|
| 780 |
+
"cell_type": "code",
|
| 781 |
+
"execution_count": null,
|
| 782 |
+
"id": "b5979e1a-4e94-45ee-937d-102550918ef7",
|
| 783 |
+
"metadata": {},
|
| 784 |
+
"outputs": [],
|
| 785 |
+
"source": [
|
| 786 |
+
"with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
|
| 787 |
+
" _ = [ex.submit(one_call, art) for art in en_data]"
|
| 788 |
+
]
|
| 789 |
+
},
|
| 790 |
+
{
|
| 791 |
+
"cell_type": "code",
|
| 792 |
+
"execution_count": null,
|
| 793 |
+
"id": "63e62230-bc41-442c-a612-08b22abdd0b3",
|
| 794 |
+
"metadata": {},
|
| 795 |
+
"outputs": [],
|
| 796 |
+
"source": [
|
| 797 |
+
"with ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
|
| 798 |
+
" _ = [ex.submit(one_call, art) for art in ml_data]"
|
| 799 |
+
]
|
| 800 |
+
},
|
| 801 |
+
{
|
| 802 |
+
"cell_type": "code",
|
| 803 |
+
"execution_count": null,
|
| 804 |
+
"id": "59d5c0cf-37ab-406c-ad53-042c7af451cd",
|
| 805 |
+
"metadata": {},
|
| 806 |
+
"outputs": [],
|
| 807 |
+
"source": [
|
| 808 |
+
"len(en_data), len(ml_data)"
|
| 809 |
+
]
|
| 810 |
+
},
|
| 811 |
+
{
|
| 812 |
+
"cell_type": "code",
|
| 813 |
+
"execution_count": null,
|
| 814 |
+
"id": "3e5543ef-53d6-4e97-923d-203a0382e878",
|
| 815 |
+
"metadata": {},
|
| 816 |
+
"outputs": [],
|
| 817 |
+
"source": [
|
| 818 |
+
"json.dump(en_data, open('filtered/en_step2.json', 'w'))"
|
| 819 |
+
]
|
| 820 |
+
},
|
| 821 |
+
{
|
| 822 |
+
"cell_type": "code",
|
| 823 |
+
"execution_count": null,
|
| 824 |
+
"id": "b8cedffd-36f8-4684-8126-2b6b78d67aec",
|
| 825 |
+
"metadata": {},
|
| 826 |
+
"outputs": [],
|
| 827 |
+
"source": [
|
| 828 |
+
"json.dump(ml_data, open('filtered/ml_step2.json', 'w'))"
|
| 829 |
+
]
|
| 830 |
+
},
|
| 831 |
+
{
|
| 832 |
+
"cell_type": "code",
|
| 833 |
+
"execution_count": null,
|
| 834 |
+
"id": "b5c6d123-12a2-4275-a4b0-1d43951ed922",
|
| 835 |
+
"metadata": {},
|
| 836 |
+
"outputs": [],
|
| 837 |
+
"source": [
|
| 838 |
+
"# easy to confirm with theme embedding first and find mismatches\n",
|
| 839 |
+
"themes = [\n",
|
| 840 |
+
" \"Economics\",\n",
|
| 841 |
+
" \"Financial Crime\",\n",
|
| 842 |
+
" \"Finance\",\n",
|
| 843 |
+
" \"Lifestyle\",\n",
|
| 844 |
+
" \"Automotive\",\n",
|
| 845 |
+
" \"Science\",\n",
|
| 846 |
+
" \"Tech\",\n",
|
| 847 |
+
" \"Travel\",\n",
|
| 848 |
+
" \"Weather\",\n",
|
| 849 |
+
" \"Health\",\n",
|
| 850 |
+
" \"Crime\",\n",
|
| 851 |
+
" \"Sports\",\n",
|
| 852 |
+
" \"General\",\n",
|
| 853 |
+
" \"Business\",\n",
|
| 854 |
+
" \"Politics\",\n",
|
| 855 |
+
" \"Entertainment\",\n",
|
| 856 |
+
" ]\n",
|
| 857 |
+
"theme_emb = {i: _embed_batch(session, base_url, i, timeout_s=90)[0] for i in themes}"
|
| 858 |
+
]
|
| 859 |
+
},
|
| 860 |
+
{
|
| 861 |
+
"cell_type": "code",
|
| 862 |
+
"execution_count": null,
|
| 863 |
+
"id": "d041e14a-cb28-4714-afbe-582ccb64fcbb",
|
| 864 |
+
"metadata": {},
|
| 865 |
+
"outputs": [],
|
| 866 |
+
"source": [
|
| 867 |
+
"import numpy as np\n",
|
| 868 |
+
"\n",
|
| 869 |
+
"def l2_normalize(x, axis=-1, eps=1e-12):\n",
|
| 870 |
+
" x = np.asarray(x, dtype=np.float32)\n",
|
| 871 |
+
" return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)\n",
|
| 872 |
+
"\n",
|
| 873 |
+
"def prepare_theme_matrix(theme_to_emb: dict):\n",
|
| 874 |
+
" theme_names = list(theme_to_emb.keys())\n",
|
| 875 |
+
" theme_mat = np.vstack([theme_to_emb[t] for t in theme_names]).astype(np.float32)\n",
|
| 876 |
+
" theme_mat = l2_normalize(theme_mat, axis=1) # (T, D)\n",
|
| 877 |
+
" return theme_names, theme_mat\n",
|
| 878 |
+
"\n",
|
| 879 |
+
"def similarities(article_mat, theme_mat):\n",
|
| 880 |
+
" # article_mat: (N, D) theme_mat: (T, D)\n",
|
| 881 |
+
" article_mat = l2_normalize(article_mat, axis=1)\n",
|
| 882 |
+
" return article_mat @ theme_mat.T # (N, T) cosine similarities\n",
|
| 883 |
+
"\n",
|
| 884 |
+
"def pick_by_gap_row(scores, gap=0.10, min_score=0.25, max_k=3):\n",
|
| 885 |
+
" # scores: (T,) raw similarity\n",
|
| 886 |
+
" idx = np.argsort(scores)[::-1]\n",
|
| 887 |
+
" s = scores[idx]\n",
|
| 888 |
+
"\n",
|
| 889 |
+
" # keep while above floor\n",
|
| 890 |
+
" keep = np.where(s >= min_score)[0]\n",
|
| 891 |
+
" if keep.size == 0:\n",
|
| 892 |
+
" return idx[:0] # none\n",
|
| 893 |
+
" last = keep[-1]\n",
|
| 894 |
+
"\n",
|
| 895 |
+
" # stop at first big drop\n",
|
| 896 |
+
" drops = s[:-1] - s[1:]\n",
|
| 897 |
+
" cut_points = np.where(drops >= gap)[0]\n",
|
| 898 |
+
" if cut_points.size > 0:\n",
|
| 899 |
+
" last = min(last, cut_points[0])\n",
|
| 900 |
+
"\n",
|
| 901 |
+
" last = min(last, max_k - 1) # cap number of themes\n",
|
| 902 |
+
" return idx[: last + 1]\n",
|
| 903 |
+
"\n",
|
| 904 |
+
"def assign_themes(article_mat, theme_to_emb, gap=0.08, min_score=0.2, max_k=5):\n",
|
| 905 |
+
" theme_names, theme_mat = prepare_theme_matrix(theme_to_emb)\n",
|
| 906 |
+
" sim = similarities(article_mat, theme_mat) # (N, T)\n",
|
| 907 |
+
"\n",
|
| 908 |
+
" results = []\n",
|
| 909 |
+
" for i in range(sim.shape[0]):\n",
|
| 910 |
+
" chosen_idx = pick_by_gap_row(sim[i], gap=gap, min_score=min_score, max_k=max_k)\n",
|
| 911 |
+
" results.append([theme_names[j] for j in chosen_idx])\n",
|
| 912 |
+
" return results\n",
|
| 913 |
+
"\n",
|
| 914 |
+
"def is_prediction_on_top(my_prediction, emb_prediction):\n",
|
| 915 |
+
" \"\"\"\n",
|
| 916 |
+
" Returns True if:\n",
|
| 917 |
+
" - len(my_prediction)==1: the single predicted theme is the top-1 embedding theme\n",
|
| 918 |
+
" - len(my_prediction)>1: the top-K embedding themes (K=len(my_prediction)) match my_prediction as a set\n",
|
| 919 |
+
" (order-insensitive), meaning all predicted themes are \"on top\" together.\n",
|
| 920 |
+
" \"\"\"\n",
|
| 921 |
+
" if not my_prediction or not emb_prediction:\n",
|
| 922 |
+
" return False\n",
|
| 923 |
+
"\n",
|
| 924 |
+
" k = len(my_prediction)\n",
|
| 925 |
+
" topk = emb_prediction[:k]\n",
|
| 926 |
+
"\n",
|
| 927 |
+
" if k == 1:\n",
|
| 928 |
+
" return my_prediction[0] == emb_prediction[0]\n",
|
| 929 |
+
"\n",
|
| 930 |
+
" return set(my_prediction) == set(topk)"
|
| 931 |
+
]
|
| 932 |
+
},
|
| 933 |
+
{
|
| 934 |
+
"cell_type": "code",
|
| 935 |
+
"execution_count": null,
|
| 936 |
+
"id": "7c72387e-4343-47f0-88fd-ef964a03fbcd",
|
| 937 |
+
"metadata": {},
|
| 938 |
+
"outputs": [],
|
| 939 |
+
"source": [
|
| 940 |
+
"is_prediction_on_top(en_data[500]['themes'], assign_themes([en_data[500]['embedding']], theme_emb)[0])"
|
| 941 |
+
]
|
| 942 |
+
},
|
| 943 |
+
{
|
| 944 |
+
"cell_type": "code",
|
| 945 |
+
"execution_count": null,
|
| 946 |
+
"id": "61121189-f649-4660-a4d0-973cae54de6e",
|
| 947 |
+
"metadata": {},
|
| 948 |
+
"outputs": [],
|
| 949 |
+
"source": [
|
| 950 |
+
"results = assign_themes([i['embedding'] for i in en_data], theme_emb)\n",
|
| 951 |
+
"nb = 0\n",
|
| 952 |
+
"for i, j in zip(en_data, results):\n",
|
| 953 |
+
" if is_prediction_on_top(i['themes'], j):\n",
|
| 954 |
+
" i['need_to_validate'] = False\n",
|
| 955 |
+
" nb += 1\n",
|
| 956 |
+
" else:\n",
|
| 957 |
+
" i['need_to_validate'] = True\n",
|
| 958 |
+
"nb, len(en_data)"
|
| 959 |
+
]
|
| 960 |
+
},
|
| 961 |
+
{
|
| 962 |
+
"cell_type": "code",
|
| 963 |
+
"execution_count": null,
|
| 964 |
+
"id": "2895b09b-130c-4c57-b0e3-4b8bcbc41f0c",
|
| 965 |
+
"metadata": {},
|
| 966 |
+
"outputs": [],
|
| 967 |
+
"source": [
|
| 968 |
+
"results = assign_themes([i['embedding'] for i in ml_data], theme_emb)\n",
|
| 969 |
+
"nb = 0\n",
|
| 970 |
+
"for i, j in zip(ml_data, results):\n",
|
| 971 |
+
" if is_prediction_on_top(i['themes'], j):\n",
|
| 972 |
+
" i['need_to_validate'] = False\n",
|
| 973 |
+
" nb += 1\n",
|
| 974 |
+
" else:\n",
|
| 975 |
+
" i['need_to_validate'] = True\n",
|
| 976 |
+
"nb, len(ml_data)"
|
| 977 |
+
]
|
| 978 |
+
},
|
| 979 |
+
{
|
| 980 |
+
"cell_type": "code",
|
| 981 |
+
"execution_count": null,
|
| 982 |
+
"id": "446c19a3-afab-4100-964a-093969034e91",
|
| 983 |
+
"metadata": {},
|
| 984 |
+
"outputs": [],
|
| 985 |
+
"source": [
|
| 986 |
+
"import uuid\n",
|
| 987 |
+
"str(uuid.uuid4())"
|
| 988 |
+
]
|
| 989 |
+
},
|
| 990 |
+
{
|
| 991 |
+
"cell_type": "code",
|
| 992 |
+
"execution_count": null,
|
| 993 |
+
"id": "93dc427c-cebd-4376-a565-ff9e8283077f",
|
| 994 |
+
"metadata": {},
|
| 995 |
+
"outputs": [],
|
| 996 |
+
"source": [
|
| 997 |
+
"nid, nt = 0, 0\n",
|
| 998 |
+
"for i in en_data + ml_data:\n",
|
| 999 |
+
" if not i.get('id'):\n",
|
| 1000 |
+
" i['id'] = str(uuid.uuid4())\n",
|
| 1001 |
+
" nid += 1\n",
|
| 1002 |
+
" if not i.get('title'):\n",
|
| 1003 |
+
" nt += 1\n",
|
| 1004 |
+
"len(en_data + ml_data), nid, nt"
|
| 1005 |
+
]
|
| 1006 |
+
},
|
| 1007 |
+
{
|
| 1008 |
+
"cell_type": "code",
|
| 1009 |
+
"execution_count": null,
|
| 1010 |
+
"id": "edfe0095-a239-4df6-9c1c-8d8f6033892b",
|
| 1011 |
+
"metadata": {},
|
| 1012 |
+
"outputs": [],
|
| 1013 |
+
"source": [
|
| 1014 |
+
"json.dump(en_data, open('filtered2/en_step2.json', 'w'))"
|
| 1015 |
+
]
|
| 1016 |
+
},
|
| 1017 |
+
{
|
| 1018 |
+
"cell_type": "code",
|
| 1019 |
+
"execution_count": null,
|
| 1020 |
+
"id": "6ed59f14-c070-41a5-bf81-2f21fb89e797",
|
| 1021 |
+
"metadata": {},
|
| 1022 |
+
"outputs": [],
|
| 1023 |
+
"source": [
|
| 1024 |
+
"json.dump(ml_data, open('filtered2/ml_step2.json', 'w'))"
|
| 1025 |
+
]
|
| 1026 |
+
},
|
| 1027 |
+
{
|
| 1028 |
+
"cell_type": "code",
|
| 1029 |
+
"execution_count": null,
|
| 1030 |
+
"id": "67d594c0-6eab-4191-855e-b98e1d490649",
|
| 1031 |
+
"metadata": {},
|
| 1032 |
+
"outputs": [],
|
| 1033 |
+
"source": []
|
| 1034 |
+
},
|
| 1035 |
+
{
|
| 1036 |
+
"cell_type": "code",
|
| 1037 |
+
"execution_count": null,
|
| 1038 |
+
"id": "7d612263-349d-4ae6-81f3-baf643452c1c",
|
| 1039 |
+
"metadata": {
|
| 1040 |
+
"scrolled": true
|
| 1041 |
+
},
|
| 1042 |
+
"outputs": [],
|
| 1043 |
+
"source": [
|
| 1044 |
+
"en_themewise = {}\n",
|
| 1045 |
+
"for i in en_data:\n",
|
| 1046 |
+
" sts = sort_themes_by_priority(i['themes'])\n",
|
| 1047 |
+
" for st in sts:\n",
|
| 1048 |
+
" if st not in en_themewise:\n",
|
| 1049 |
+
" en_themewise[st] = []\n",
|
| 1050 |
+
" en_themewise[st].append(i)\n",
|
| 1051 |
+
" break\n",
|
| 1052 |
+
"for i in en_themewise:\n",
|
| 1053 |
+
" tcc = len(en_themewise[i])\n",
|
| 1054 |
+
" selector = FaissGpuDiverseSelectorTorch(\n",
|
| 1055 |
+
" embedding_key=\"embedding\",\n",
|
| 1056 |
+
" gpu_id=0,\n",
|
| 1057 |
+
" temp_mem_gb=6.0,\n",
|
| 1058 |
+
" use_float16_storage=True,\n",
|
| 1059 |
+
" build_batch_size=8192,\n",
|
| 1060 |
+
" ).fit(en_themewise[i])\n",
|
| 1061 |
+
" picked = selector.select(\n",
|
| 1062 |
+
" num_select=5000,\n",
|
| 1063 |
+
" train_per_centroid=500,\n",
|
| 1064 |
+
" niter=100,\n",
|
| 1065 |
+
" centroid_search_k=512,\n",
|
| 1066 |
+
" )\n",
|
| 1067 |
+
" en_themewise[i] = picked\n",
|
| 1068 |
+
" print(i, '==>', tcc, len(picked))\n",
|
| 1069 |
+
"en_data = []\n",
|
| 1070 |
+
"for theme in en_themewise:\n",
|
| 1071 |
+
" en_data.extend(en_themewise[theme])\n",
|
| 1072 |
+
"json.dump(en_data, open('filtered2/en_step3.json', 'w'))"
|
| 1073 |
+
]
|
| 1074 |
+
},
|
| 1075 |
+
{
|
| 1076 |
+
"cell_type": "code",
|
| 1077 |
+
"execution_count": null,
|
| 1078 |
+
"id": "f5d4d3c4-9248-46b2-ba52-51dd4e22838d",
|
| 1079 |
+
"metadata": {
|
| 1080 |
+
"scrolled": true
|
| 1081 |
+
},
|
| 1082 |
+
"outputs": [],
|
| 1083 |
+
"source": [
|
| 1084 |
+
"ml_themewise = {}\n",
|
| 1085 |
+
"for i in ml_data:\n",
|
| 1086 |
+
" sts = sort_themes_by_priority(i['themes'])\n",
|
| 1087 |
+
" for st in sts:\n",
|
| 1088 |
+
" if st not in ml_themewise:\n",
|
| 1089 |
+
" ml_themewise[st] = []\n",
|
| 1090 |
+
" ml_themewise[st].append(i)\n",
|
| 1091 |
+
" break\n",
|
| 1092 |
+
"for i in ml_themewise:\n",
|
| 1093 |
+
" print(i, '==>', len(ml_themewise[i]))\n",
|
| 1094 |
+
" tcc = len(ml_themewise[i])\n",
|
| 1095 |
+
" selector = FaissGpuDiverseSelectorTorch(\n",
|
| 1096 |
+
" embedding_key=\"embedding\",\n",
|
| 1097 |
+
" gpu_id=0,\n",
|
| 1098 |
+
" temp_mem_gb=6.0,\n",
|
| 1099 |
+
" use_float16_storage=True,\n",
|
| 1100 |
+
" build_batch_size=8192,\n",
|
| 1101 |
+
" ).fit(ml_themewise[i])\n",
|
| 1102 |
+
" picked = selector.select(\n",
|
| 1103 |
+
" num_select=5000,\n",
|
| 1104 |
+
" train_per_centroid=500,\n",
|
| 1105 |
+
" niter=100,\n",
|
| 1106 |
+
" centroid_search_k=512,\n",
|
| 1107 |
+
" )\n",
|
| 1108 |
+
" ml_themewise[i] = picked\n",
|
| 1109 |
+
" print(i, '==>', tcc, len(picked))\n",
|
| 1110 |
+
"ml_data = []\n",
|
| 1111 |
+
"for theme in ml_themewise:\n",
|
| 1112 |
+
" ml_data.extend(ml_themewise[theme])\n",
|
| 1113 |
+
"json.dump(ml_data, open('filtered2/ml_step3.json', 'w'))\n",
|
| 1114 |
+
"len(ml_data)"
|
| 1115 |
+
]
|
| 1116 |
+
},
|
| 1117 |
+
{
|
| 1118 |
+
"cell_type": "markdown",
|
| 1119 |
+
"id": "17b1492a-0b40-43b3-9a46-3fea7fd4a7f6",
|
| 1120 |
+
"metadata": {},
|
| 1121 |
+
"source": [
|
| 1122 |
+
"# 5. confirm with LLM for selected candidates"
|
| 1123 |
+
]
|
| 1124 |
+
},
|
| 1125 |
+
{
|
| 1126 |
+
"cell_type": "code",
|
| 1127 |
+
"execution_count": null,
|
| 1128 |
+
"id": "033d00e6-9068-4705-a8a5-7a0aa4828d52",
|
| 1129 |
+
"metadata": {},
|
| 1130 |
+
"outputs": [],
|
| 1131 |
+
"source": [
|
| 1132 |
+
"nb = 0\n",
|
| 1133 |
+
"for i in en_data + ml_data:\n",
|
| 1134 |
+
" if i.get('need_to_validate'):\n",
|
| 1135 |
+
" nb += 1\n",
|
| 1136 |
+
"nb, len(en_data + ml_data)"
|
| 1137 |
+
]
|
| 1138 |
+
},
|
| 1139 |
+
{
|
| 1140 |
+
"cell_type": "code",
|
| 1141 |
+
"execution_count": 2,
|
| 1142 |
+
"id": "6701aa6e-2797-4083-9496-1308d71ffbde",
|
| 1143 |
+
"metadata": {},
|
| 1144 |
+
"outputs": [],
|
| 1145 |
+
"source": [
|
| 1146 |
+
"import json\n",
|
| 1147 |
+
"import logging\n",
|
| 1148 |
+
"import time\n",
|
| 1149 |
+
"import random\n",
|
| 1150 |
+
"\n",
|
| 1151 |
+
"import openai\n",
|
| 1152 |
+
"\n",
|
| 1153 |
+
"OPENAI_KEY = '***'\n",
|
| 1154 |
+
"MAX_OPENAI_RETRIES = 5\n",
|
| 1155 |
+
"\n",
|
| 1156 |
+
"openai.api_key = OPENAI_KEY\n",
|
| 1157 |
+
"\n",
|
| 1158 |
+
"\n",
|
| 1159 |
+
"def extract_function_data(\n",
|
| 1160 |
+
" function: dict,\n",
|
| 1161 |
+
" content: str,\n",
|
| 1162 |
+
" target_model: str = 'gpt-4o-mini',\n",
|
| 1163 |
+
" role: str = \"You are a zero-shot classification model.\",\n",
|
| 1164 |
+
" retries: int = 0,\n",
|
| 1165 |
+
" extra_tags: dict = None\n",
|
| 1166 |
+
"):\n",
|
| 1167 |
+
" retries += 1\n",
|
| 1168 |
+
" if extra_tags:\n",
|
| 1169 |
+
" openpipe_tags.update(extra_tags)\n",
|
| 1170 |
+
" if retries > MAX_OPENAI_RETRIES:\n",
|
| 1171 |
+
" logging.error(\"Failed to Extract Event Data\", extra={\n",
|
| 1172 |
+
" \"content_length\": len(content.split(\" \")),\n",
|
| 1173 |
+
" \"error_message\": f\"Failed to extract function data after {MAX_OPENAI_RETRIES} retries\"\n",
|
| 1174 |
+
" })\n",
|
| 1175 |
+
" return None\n",
|
| 1176 |
+
" try:\n",
|
| 1177 |
+
" response = openai.chat.completions.create(\n",
|
| 1178 |
+
" model=target_model,\n",
|
| 1179 |
+
" messages=[{\"role\": \"system\", \"content\": role},\n",
|
| 1180 |
+
" {\"role\": \"user\", \"content\": content}],\n",
|
| 1181 |
+
" functions=[function],\n",
|
| 1182 |
+
" function_call={\"name\": function[\"name\"]},\n",
|
| 1183 |
+
" temperature=0\n",
|
| 1184 |
+
" )\n",
|
| 1185 |
+
"\n",
|
| 1186 |
+
" # Except rate limit error\n",
|
| 1187 |
+
" except openai.RateLimitError as e:\n",
|
| 1188 |
+
" # Ask the question again with 3/4 of the content\n",
|
| 1189 |
+
" logging.error(\"Retrying OpenAI Request\", extra={\n",
|
| 1190 |
+
" \"error\": str(e),\n",
|
| 1191 |
+
" \"error_message\": \"API limit reached, try again after 5 seconds\",\n",
|
| 1192 |
+
" })\n",
|
| 1193 |
+
" # sleep for random time between 5 to 15 seconds\n",
|
| 1194 |
+
" time.sleep(random.randint(5, 15))\n",
|
| 1195 |
+
" return extract_function_data(function, target_model, content, retries)\n",
|
| 1196 |
+
"\n",
|
| 1197 |
+
" # Except all other errors\n",
|
| 1198 |
+
" except Exception as e:\n",
|
| 1199 |
+
" logging.error(\"Retrying OpenAI Request\", extra={\n",
|
| 1200 |
+
" \"error\": str(e),\n",
|
| 1201 |
+
" \"error_message\": \"Unkown Error, Retrying with 3/4 of the content\",\n",
|
| 1202 |
+
" })\n",
|
| 1203 |
+
" logging.error(e)\n",
|
| 1204 |
+
" content_length = len(content.split(\" \"))\n",
|
| 1205 |
+
" updated_content = \" \".join(content.split(\" \")[:int(content_length * 0.75)])\n",
|
| 1206 |
+
"\n",
|
| 1207 |
+
" return extract_function_data(function, target_model, updated_content,\n",
|
| 1208 |
+
" retries)\n",
|
| 1209 |
+
"\n",
|
| 1210 |
+
" try:\n",
|
| 1211 |
+
" result = json.loads(response.choices[0].message.function_call.arguments)\n",
|
| 1212 |
+
" ## Remove Keys with empty string value because OpenAI is returning empty strings as keys sometimes\n",
|
| 1213 |
+
" result = {k: v for k, v in result.items() if v not in [\"\", [], None]}\n",
|
| 1214 |
+
"\n",
|
| 1215 |
+
" logging.info(\"Successfully extracted Event Data\", extra={\n",
|
| 1216 |
+
" \"error_message\": \"Content Classified Successfully\",\n",
|
| 1217 |
+
" \"content\": str(result)\n",
|
| 1218 |
+
" })\n",
|
| 1219 |
+
" return result\n",
|
| 1220 |
+
" except Exception as e:\n",
|
| 1221 |
+
" logging.error(\"Failed to Parse JSON from OpenAI in Extraction\", extra={\n",
|
| 1222 |
+
" \"error\": str(e),\n",
|
| 1223 |
+
" \"openai_response\": str(response),\n",
|
| 1224 |
+
" })\n",
|
| 1225 |
+
" return None"
|
| 1226 |
+
]
|
| 1227 |
+
},
|
| 1228 |
+
{
|
| 1229 |
+
"cell_type": "code",
|
| 1230 |
+
"execution_count": null,
|
| 1231 |
+
"id": "d2992602-ab1b-43ad-a2c2-ede79671affc",
|
| 1232 |
+
"metadata": {},
|
| 1233 |
+
"outputs": [],
|
| 1234 |
+
"source": [
|
| 1235 |
+
"int_theme_classification_function = {\n",
|
| 1236 |
+
" \"name\": \"theme_classification\",\n",
|
| 1237 |
+
" \"description\": \"please classify below news article into any of these categories, ['Health', 'Entertainment', 'Business', 'Science', 'Politics', 'Finance', 'Economics', 'Tech', 'Crime', 'Sports', 'Lifestyle', 'Automotive', 'Travel', 'Weather', 'General', 'Financial Crime'], You should return \\\"General\\\" only if article doesnt fall into any of these, Please classigy it carefully. Don't tell me bad result before confirming. also don't return all categories, return general if it not fits in any of above\",\n",
|
| 1238 |
+
" \"parameters\": {\n",
|
| 1239 |
+
" \"type\": \"object\",\n",
|
| 1240 |
+
" \"properties\": {'Health': {'type': 'boolean'},\n",
|
| 1241 |
+
" 'Entertainment': {'type': 'boolean'},\n",
|
| 1242 |
+
" 'Business': {'type': 'boolean'}, 'Science': {'type': 'boolean'},\n",
|
| 1243 |
+
" 'Politics': {'type': 'boolean'}, 'Finance': {'type': 'boolean'},\n",
|
| 1244 |
+
" 'Economics': {'type': 'boolean'}, 'Tech': {'type': 'boolean'},\n",
|
| 1245 |
+
" 'Crime': {'type': 'boolean'}, 'Sports': {'type': 'boolean'},\n",
|
| 1246 |
+
" 'Lifestyle': {'type': 'boolean'},\n",
|
| 1247 |
+
" 'Automotive': {'type': 'boolean'}, 'Travel': {'type': 'boolean'},\n",
|
| 1248 |
+
" 'Weather': {'type': 'boolean'}, 'General': {'type': 'boolean'}, 'Financial Crime': {'type': 'boolean'}},\n",
|
| 1249 |
+
" \"required\": []\n",
|
| 1250 |
+
" }\n",
|
| 1251 |
+
"}"
|
| 1252 |
+
]
|
| 1253 |
+
},
|
| 1254 |
+
{
|
| 1255 |
+
"cell_type": "code",
|
| 1256 |
+
"execution_count": null,
|
| 1257 |
+
"id": "245b1afc-7926-47fb-9f4b-fdae08edb4f0",
|
| 1258 |
+
"metadata": {},
|
| 1259 |
+
"outputs": [],
|
| 1260 |
+
"source": [
|
| 1261 |
+
"for i in en_data[::-1]:\n",
|
| 1262 |
+
" if i['need_to_validate']:\n",
|
| 1263 |
+
" break"
|
| 1264 |
+
]
|
| 1265 |
+
},
|
| 1266 |
+
{
|
| 1267 |
+
"cell_type": "code",
|
| 1268 |
+
"execution_count": null,
|
| 1269 |
+
"id": "05314580-6e86-4fb6-a554-2d357deb21b1",
|
| 1270 |
+
"metadata": {},
|
| 1271 |
+
"outputs": [],
|
| 1272 |
+
"source": [
|
| 1273 |
+
"themes = extract_function_data(int_theme_classification_function, \"Title\" + i['title'] + '\\n\\nContent:' + i['content'], target_model='gpt-4.1-mini')\n",
|
| 1274 |
+
"themes = [i for i in themes if themes[i]]\n",
|
| 1275 |
+
"themes"
|
| 1276 |
+
]
|
| 1277 |
+
},
|
| 1278 |
+
{
|
| 1279 |
+
"cell_type": "code",
|
| 1280 |
+
"execution_count": null,
|
| 1281 |
+
"id": "fc20aeb3-8552-4873-bda2-37e498c04d49",
|
| 1282 |
+
"metadata": {},
|
| 1283 |
+
"outputs": [],
|
| 1284 |
+
"source": [
|
| 1285 |
+
"def validate(art):\n",
|
| 1286 |
+
" if not art.get('need_to_validate'):\n",
|
| 1287 |
+
" return\n",
|
| 1288 |
+
" if \"prev_themes\" in art:\n",
|
| 1289 |
+
" return\n",
|
| 1290 |
+
" content = \"\"\n",
|
| 1291 |
+
" if art.get(\"title\"):\n",
|
| 1292 |
+
" content = \"Title:\" + art['title'] + '\\n\\nContent:' + art['content']\n",
|
| 1293 |
+
" else:\n",
|
| 1294 |
+
" content = art['content']\n",
|
| 1295 |
+
" themes = extract_function_data(int_theme_classification_function, content, target_model='gpt-4.1-mini')\n",
|
| 1296 |
+
" themes = [i for i in themes if themes[i]]\n",
|
| 1297 |
+
" if themes:\n",
|
| 1298 |
+
" art['prev_themes'] = art.pop('themes')\n",
|
| 1299 |
+
" art['themes'] = themes"
|
| 1300 |
+
]
|
| 1301 |
+
},
|
| 1302 |
+
{
|
| 1303 |
+
"cell_type": "code",
|
| 1304 |
+
"execution_count": null,
|
| 1305 |
+
"id": "11ca3cca-df82-4c92-aac8-58c6722d5ed7",
|
| 1306 |
+
"metadata": {},
|
| 1307 |
+
"outputs": [],
|
| 1308 |
+
"source": [
|
| 1309 |
+
"with ThreadPoolExecutor(max_workers=128) as ex:\n",
|
| 1310 |
+
" _ = [ex.submit(validate, art) for art in en_data]"
|
| 1311 |
+
]
|
| 1312 |
+
},
|
| 1313 |
+
{
|
| 1314 |
+
"cell_type": "code",
|
| 1315 |
+
"execution_count": null,
|
| 1316 |
+
"id": "b2541763-4a46-4195-b2e9-48df2e992add",
|
| 1317 |
+
"metadata": {},
|
| 1318 |
+
"outputs": [],
|
| 1319 |
+
"source": [
|
| 1320 |
+
"with ThreadPoolExecutor(max_workers=128) as ex:\n",
|
| 1321 |
+
" _ = [ex.submit(validate, art) for art in ml_data]"
|
| 1322 |
+
]
|
| 1323 |
+
},
|
| 1324 |
+
{
|
| 1325 |
+
"cell_type": "code",
|
| 1326 |
+
"execution_count": null,
|
| 1327 |
+
"id": "07227c4d-6cdb-41dc-87d3-b67c73f15b6f",
|
| 1328 |
+
"metadata": {},
|
| 1329 |
+
"outputs": [],
|
| 1330 |
+
"source": [
|
| 1331 |
+
"nb = 0\n",
|
| 1332 |
+
"for i in en_data + ml_data:\n",
|
| 1333 |
+
" if i['need_to_validate'] and not i.get('prev_themes'):\n",
|
| 1334 |
+
" nb += 1\n",
|
| 1335 |
+
" print(i.get('title'), i.get('content'))\n",
|
| 1336 |
+
"nb, len(en_data + ml_data)"
|
| 1337 |
+
]
|
| 1338 |
+
},
|
| 1339 |
+
{
|
| 1340 |
+
"cell_type": "code",
|
| 1341 |
+
"execution_count": null,
|
| 1342 |
+
"id": "24004f74-9ca5-4c70-88c1-f1ec3fbb13c0",
|
| 1343 |
+
"metadata": {},
|
| 1344 |
+
"outputs": [],
|
| 1345 |
+
"source": [
|
| 1346 |
+
"json.dump(en_data, open('filtered2/en_step4.json', 'w'))"
|
| 1347 |
+
]
|
| 1348 |
+
},
|
| 1349 |
+
{
|
| 1350 |
+
"cell_type": "code",
|
| 1351 |
+
"execution_count": null,
|
| 1352 |
+
"id": "67a33aa0-e8e7-44c3-89d6-2ca097679988",
|
| 1353 |
+
"metadata": {},
|
| 1354 |
+
"outputs": [],
|
| 1355 |
+
"source": [
|
| 1356 |
+
"json.dump(ml_data, open('filtered2/ml_step4.json', 'w'))"
|
| 1357 |
+
]
|
| 1358 |
+
},
|
| 1359 |
+
{
|
| 1360 |
+
"cell_type": "markdown",
|
| 1361 |
+
"id": "1ac2b2bd-7419-4f67-a023-59bac45d57b4",
|
| 1362 |
+
"metadata": {},
|
| 1363 |
+
"source": [
|
| 1364 |
+
"# 6. split train & test also pre-train"
|
| 1365 |
+
]
|
| 1366 |
+
},
|
| 1367 |
+
{
|
| 1368 |
+
"cell_type": "code",
|
| 1369 |
+
"execution_count": null,
|
| 1370 |
+
"id": "afab9235-4e4c-4a28-b6d1-e6071efd056a",
|
| 1371 |
+
"metadata": {},
|
| 1372 |
+
"outputs": [],
|
| 1373 |
+
"source": [
|
| 1374 |
+
"random.shuffle(en_data)\n",
|
| 1375 |
+
"random.shuffle(ml_data)"
|
| 1376 |
+
]
|
| 1377 |
+
},
|
| 1378 |
+
{
|
| 1379 |
+
"cell_type": "code",
|
| 1380 |
+
"execution_count": null,
|
| 1381 |
+
"id": "d57dc5c4-1756-497c-9be3-9e3c4086f667",
|
| 1382 |
+
"metadata": {},
|
| 1383 |
+
"outputs": [],
|
| 1384 |
+
"source": [
|
| 1385 |
+
"train = {}\n",
|
| 1386 |
+
"test = {}\n",
|
| 1387 |
+
"\n",
|
| 1388 |
+
"for i in en_data:\n",
|
| 1389 |
+
" sts = sort_themes_by_priority(i['themes'])\n",
|
| 1390 |
+
" if 'Education' in sts:\n",
|
| 1391 |
+
" sts.remove('Education')\n",
|
| 1392 |
+
" for st in sts:\n",
|
| 1393 |
+
" if st not in train:\n",
|
| 1394 |
+
" train[st] = []\n",
|
| 1395 |
+
" test[st] = []\n",
|
| 1396 |
+
" if len(test[st]) < 500 and i['need_to_validate']:\n",
|
| 1397 |
+
" test[st].append(i)\n",
|
| 1398 |
+
" break\n",
|
| 1399 |
+
" else:\n",
|
| 1400 |
+
" train[st].append(i)\n",
|
| 1401 |
+
" break"
|
| 1402 |
+
]
|
| 1403 |
+
},
|
| 1404 |
+
{
|
| 1405 |
+
"cell_type": "code",
|
| 1406 |
+
"execution_count": null,
|
| 1407 |
+
"id": "aa27ad3e-71a5-4aed-bc2f-32771feda321",
|
| 1408 |
+
"metadata": {},
|
| 1409 |
+
"outputs": [],
|
| 1410 |
+
"source": [
|
| 1411 |
+
"for i in ml_data:\n",
|
| 1412 |
+
" sts = sort_themes_by_priority(i['themes'])\n",
|
| 1413 |
+
" if 'Education' in sts:\n",
|
| 1414 |
+
" sts.remove('Education')\n",
|
| 1415 |
+
" for st in sts:\n",
|
| 1416 |
+
" if st not in train:\n",
|
| 1417 |
+
" train[st] = []\n",
|
| 1418 |
+
" test[st] = []\n",
|
| 1419 |
+
" if len(test[st]) < 1000 and i['need_to_validate']:\n",
|
| 1420 |
+
" test[st].append(i)\n",
|
| 1421 |
+
" break\n",
|
| 1422 |
+
" else:\n",
|
| 1423 |
+
" train[st].append(i)\n",
|
| 1424 |
+
" break"
|
| 1425 |
+
]
|
| 1426 |
+
},
|
| 1427 |
+
{
|
| 1428 |
+
"cell_type": "code",
|
| 1429 |
+
"execution_count": null,
|
| 1430 |
+
"id": "b9cd3ece-301f-4fee-b8fd-a1c514287ce5",
|
| 1431 |
+
"metadata": {},
|
| 1432 |
+
"outputs": [],
|
| 1433 |
+
"source": [
|
| 1434 |
+
"for i in test:\n",
|
| 1435 |
+
" print(i, len(test[i]), len(train[i]))"
|
| 1436 |
+
]
|
| 1437 |
+
},
|
| 1438 |
+
{
|
| 1439 |
+
"cell_type": "code",
|
| 1440 |
+
"execution_count": null,
|
| 1441 |
+
"id": "f6ac7932-ca64-4d06-8a12-e1235a0ee501",
|
| 1442 |
+
"metadata": {},
|
| 1443 |
+
"outputs": [],
|
| 1444 |
+
"source": [
|
| 1445 |
+
"json.dump(sum([train[i] for i in train], []), open('filtered2/train.json', 'w'))"
|
| 1446 |
+
]
|
| 1447 |
+
},
|
| 1448 |
+
{
|
| 1449 |
+
"cell_type": "code",
|
| 1450 |
+
"execution_count": null,
|
| 1451 |
+
"id": "ae639b65-1b28-4079-899c-8ae5dd86eb4d",
|
| 1452 |
+
"metadata": {},
|
| 1453 |
+
"outputs": [],
|
| 1454 |
+
"source": [
|
| 1455 |
+
"json.dump(sum([test[i] for i in test], []), open('filtered2/test.json', 'w'))"
|
| 1456 |
+
]
|
| 1457 |
+
},
|
| 1458 |
+
{
|
| 1459 |
+
"cell_type": "code",
|
| 1460 |
+
"execution_count": null,
|
| 1461 |
+
"id": "fc0fc992-1c32-42ed-95c8-b1b77d878b93",
|
| 1462 |
+
"metadata": {},
|
| 1463 |
+
"outputs": [],
|
| 1464 |
+
"source": [
|
| 1465 |
+
"tests_cont = set([i['content'].lower().strip() for i in sum([test[i] for i in test], [])])\n",
|
| 1466 |
+
"len(tests_cont)"
|
| 1467 |
+
]
|
| 1468 |
+
},
|
| 1469 |
+
{
|
| 1470 |
+
"cell_type": "code",
|
| 1471 |
+
"execution_count": null,
|
| 1472 |
+
"id": "2d7ce8f4-00a4-419a-8b8d-6c06cf058aed",
|
| 1473 |
+
"metadata": {},
|
| 1474 |
+
"outputs": [],
|
| 1475 |
+
"source": [
|
| 1476 |
+
"pre_train = json.load(open('filtered2/en_step2.json')) + json.load(open('filtered2/ml_step2.json'))\n",
|
| 1477 |
+
"len(pre_train)"
|
| 1478 |
+
]
|
| 1479 |
+
},
|
| 1480 |
+
{
|
| 1481 |
+
"cell_type": "code",
|
| 1482 |
+
"execution_count": null,
|
| 1483 |
+
"id": "ba3972ee-ac8c-408b-b314-cb89424ff46d",
|
| 1484 |
+
"metadata": {},
|
| 1485 |
+
"outputs": [],
|
| 1486 |
+
"source": [
|
| 1487 |
+
"pre_train = [i for i in pre_train if i['content'].lower().strip() not in tests_cont]\n",
|
| 1488 |
+
"len(pre_train)"
|
| 1489 |
+
]
|
| 1490 |
+
},
|
| 1491 |
+
{
|
| 1492 |
+
"cell_type": "code",
|
| 1493 |
+
"execution_count": null,
|
| 1494 |
+
"id": "25594c3a-f5f1-4d77-a297-a33229630ac5",
|
| 1495 |
+
"metadata": {},
|
| 1496 |
+
"outputs": [],
|
| 1497 |
+
"source": [
|
| 1498 |
+
"json.dump(pre_train, open('filtered2/pre-train.json', 'w'))"
|
| 1499 |
+
]
|
| 1500 |
+
},
|
| 1501 |
+
{
|
| 1502 |
+
"cell_type": "markdown",
|
| 1503 |
+
"id": "751f2c18-80b4-4f18-aa51-8d4da2e26e01",
|
| 1504 |
+
"metadata": {},
|
| 1505 |
+
"source": [
|
| 1506 |
+
"# split again val, train and test"
|
| 1507 |
+
]
|
| 1508 |
+
},
|
| 1509 |
+
{
|
| 1510 |
+
"cell_type": "code",
|
| 1511 |
+
"execution_count": null,
|
| 1512 |
+
"id": "0991b233-ec51-481e-90b3-208f5c2303ec",
|
| 1513 |
+
"metadata": {},
|
| 1514 |
+
"outputs": [],
|
| 1515 |
+
"source": [
|
| 1516 |
+
"train = json.load(open('filtered2/train.json'))\n",
|
| 1517 |
+
"pre_train = json.load(open('filtered2/pre-train.json'))\n",
|
| 1518 |
+
"test = json.load(open('filtered2/test.json'))"
|
| 1519 |
+
]
|
| 1520 |
+
},
|
| 1521 |
+
{
|
| 1522 |
+
"cell_type": "code",
|
| 1523 |
+
"execution_count": null,
|
| 1524 |
+
"id": "ac1fb5ff-9702-4605-899c-0b63f22fd478",
|
| 1525 |
+
"metadata": {},
|
| 1526 |
+
"outputs": [],
|
| 1527 |
+
"source": [
|
| 1528 |
+
"def validate_train(art):\n",
|
| 1529 |
+
" if \"prev_themes\" in art:\n",
|
| 1530 |
+
" return\n",
|
| 1531 |
+
" content = \"\"\n",
|
| 1532 |
+
" if art.get(\"title\"):\n",
|
| 1533 |
+
" content = \"Title:\" + art['title'] + '\\n\\nContent:' + art['content']\n",
|
| 1534 |
+
" else:\n",
|
| 1535 |
+
" content = art['content']\n",
|
| 1536 |
+
" themes = extract_function_data(int_theme_classification_function, content, target_model='gpt-4.1-mini')\n",
|
| 1537 |
+
" themes = [i for i in themes if themes[i]]\n",
|
| 1538 |
+
" if themes:\n",
|
| 1539 |
+
" art['prev_themes'] = art.pop('themes')\n",
|
| 1540 |
+
" art['themes'] = themes"
|
| 1541 |
+
]
|
| 1542 |
+
},
|
| 1543 |
+
{
|
| 1544 |
+
"cell_type": "code",
|
| 1545 |
+
"execution_count": null,
|
| 1546 |
+
"id": "9ab0726e-a993-4400-b444-014537253fed",
|
| 1547 |
+
"metadata": {},
|
| 1548 |
+
"outputs": [],
|
| 1549 |
+
"source": [
|
| 1550 |
+
"from requests.adapters import HTTPAdapter\n",
|
| 1551 |
+
"from concurrent.futures import ThreadPoolExecutor, as_completed\n",
|
| 1552 |
+
"with ThreadPoolExecutor(max_workers=512) as ex:\n",
|
| 1553 |
+
" _ = [ex.submit(validate_train, art) for art in train]"
|
| 1554 |
+
]
|
| 1555 |
+
},
|
| 1556 |
+
{
|
| 1557 |
+
"cell_type": "code",
|
| 1558 |
+
"execution_count": null,
|
| 1559 |
+
"id": "549b4492-97e6-4a26-b4da-fc8ad4c52073",
|
| 1560 |
+
"metadata": {},
|
| 1561 |
+
"outputs": [],
|
| 1562 |
+
"source": [
|
| 1563 |
+
"nb = 0\n",
|
| 1564 |
+
"for i in train:\n",
|
| 1565 |
+
" if \"prev_themes\" not in i:\n",
|
| 1566 |
+
" nb += 1\n",
|
| 1567 |
+
"nb, len(train)"
|
| 1568 |
+
]
|
| 1569 |
+
},
|
| 1570 |
+
{
|
| 1571 |
+
"cell_type": "code",
|
| 1572 |
+
"execution_count": null,
|
| 1573 |
+
"id": "d011264f-f4af-4d82-bbfc-7d621d6f2215",
|
| 1574 |
+
"metadata": {},
|
| 1575 |
+
"outputs": [],
|
| 1576 |
+
"source": [
|
| 1577 |
+
"json.dump(train, open('filtered2/train.json', 'w'))"
|
| 1578 |
+
]
|
| 1579 |
+
},
|
| 1580 |
+
{
|
| 1581 |
+
"cell_type": "code",
|
| 1582 |
+
"execution_count": null,
|
| 1583 |
+
"id": "054fd8df-9ea0-4d39-ac99-7e3e43c6c4d2",
|
| 1584 |
+
"metadata": {},
|
| 1585 |
+
"outputs": [],
|
| 1586 |
+
"source": [
|
| 1587 |
+
"ntrain = {}\n",
|
| 1588 |
+
"nval = {}\n",
|
| 1589 |
+
"\n",
|
| 1590 |
+
"for i in train:\n",
|
| 1591 |
+
" sts = sort_themes_by_priority(i['themes'])\n",
|
| 1592 |
+
" sts = [i for i in sts if i in priority_order]\n",
|
| 1593 |
+
" for st in sts:\n",
|
| 1594 |
+
" if st not in ntrain:\n",
|
| 1595 |
+
" ntrain[st] = []\n",
|
| 1596 |
+
" nval[st] = []\n",
|
| 1597 |
+
" if len(nval[st]) < 500 and i['need_to_validate']:\n",
|
| 1598 |
+
" nval[st].append(i)\n",
|
| 1599 |
+
" break\n",
|
| 1600 |
+
" else:\n",
|
| 1601 |
+
" ntrain[st].append(i)\n",
|
| 1602 |
+
" break\n",
|
| 1603 |
+
"\n",
|
| 1604 |
+
"ntrain = sum([ntrain[i] for i in ntrain], [])\n",
|
| 1605 |
+
"nval = sum([nval[i] for i in nval], [])"
|
| 1606 |
+
]
|
| 1607 |
+
},
|
| 1608 |
+
{
|
| 1609 |
+
"cell_type": "code",
|
| 1610 |
+
"execution_count": null,
|
| 1611 |
+
"id": "7df359b3-b815-402a-a0b8-fd5d09cc7ee2",
|
| 1612 |
+
"metadata": {},
|
| 1613 |
+
"outputs": [],
|
| 1614 |
+
"source": [
|
| 1615 |
+
"len(ntrain), len(nval)"
|
| 1616 |
+
]
|
| 1617 |
+
},
|
| 1618 |
+
{
|
| 1619 |
+
"cell_type": "code",
|
| 1620 |
+
"execution_count": null,
|
| 1621 |
+
"id": "b37878b1-4f2b-4a87-85ac-7a685e9fe00a",
|
| 1622 |
+
"metadata": {},
|
| 1623 |
+
"outputs": [],
|
| 1624 |
+
"source": []
|
| 1625 |
+
},
|
| 1626 |
+
{
|
| 1627 |
+
"cell_type": "code",
|
| 1628 |
+
"execution_count": null,
|
| 1629 |
+
"id": "36c43a30-4b0b-443f-a29a-af78848b271c",
|
| 1630 |
+
"metadata": {},
|
| 1631 |
+
"outputs": [],
|
| 1632 |
+
"source": [
|
| 1633 |
+
"npretrain = {}\n",
|
| 1634 |
+
"contset = {}\n",
|
| 1635 |
+
"\n",
|
| 1636 |
+
"for i in nval+test:\n",
|
| 1637 |
+
" contset[i['content'].lower().strip()] = i['themes']\n",
|
| 1638 |
+
"len(contset)"
|
| 1639 |
+
]
|
| 1640 |
+
},
|
| 1641 |
+
{
|
| 1642 |
+
"cell_type": "code",
|
| 1643 |
+
"execution_count": null,
|
| 1644 |
+
"id": "7bc1d022-41c4-45f8-be2d-4246cd052802",
|
| 1645 |
+
"metadata": {},
|
| 1646 |
+
"outputs": [],
|
| 1647 |
+
"source": [
|
| 1648 |
+
"fin_rows = []\n",
|
| 1649 |
+
"for i in pre_train:\n",
|
| 1650 |
+
" if 'Financial Crime' not in i['themes']:\n",
|
| 1651 |
+
" continue\n",
|
| 1652 |
+
" if i['content'].lower().strip() not in contset:\n",
|
| 1653 |
+
" fin_rows.append(i)\n",
|
| 1654 |
+
"len(fin_rows)"
|
| 1655 |
+
]
|
| 1656 |
+
},
|
| 1657 |
+
{
|
| 1658 |
+
"cell_type": "code",
|
| 1659 |
+
"execution_count": null,
|
| 1660 |
+
"id": "d65d514c-c7cb-4a81-be63-da1420eb89c0",
|
| 1661 |
+
"metadata": {},
|
| 1662 |
+
"outputs": [],
|
| 1663 |
+
"source": [
|
| 1664 |
+
"with ThreadPoolExecutor(max_workers=512) as ex:\n",
|
| 1665 |
+
" _ = [ex.submit(validate_train, art) for art in fin_rows]"
|
| 1666 |
+
]
|
| 1667 |
+
},
|
| 1668 |
+
{
|
| 1669 |
+
"cell_type": "code",
|
| 1670 |
+
"execution_count": null,
|
| 1671 |
+
"id": "8b881fd9-3894-40ac-a032-2a81452f3d7b",
|
| 1672 |
+
"metadata": {},
|
| 1673 |
+
"outputs": [],
|
| 1674 |
+
"source": [
|
| 1675 |
+
"tt = {}\n",
|
| 1676 |
+
"for i in fin_rows[:10000]:\n",
|
| 1677 |
+
" for t in i['themes']:\n",
|
| 1678 |
+
" tt[t] = tt.get(t, 0) + 1\n",
|
| 1679 |
+
"tt"
|
| 1680 |
+
]
|
| 1681 |
+
},
|
| 1682 |
+
{
|
| 1683 |
+
"cell_type": "code",
|
| 1684 |
+
"execution_count": null,
|
| 1685 |
+
"id": "9e9c7f98-befe-4eb6-9bad-28b839ce385b",
|
| 1686 |
+
"metadata": {},
|
| 1687 |
+
"outputs": [],
|
| 1688 |
+
"source": [
|
| 1689 |
+
"for i in fin_rows:\n",
|
| 1690 |
+
" if i['content'].lower().strip() not in contset:\n",
|
| 1691 |
+
" contset[i['content'].lower().strip()] = i['themes']"
|
| 1692 |
+
]
|
| 1693 |
+
},
|
| 1694 |
+
{
|
| 1695 |
+
"cell_type": "code",
|
| 1696 |
+
"execution_count": null,
|
| 1697 |
+
"id": "fc3dfbef-9361-4d38-b0b3-c8cbc63d04ce",
|
| 1698 |
+
"metadata": {},
|
| 1699 |
+
"outputs": [],
|
| 1700 |
+
"source": [
|
| 1701 |
+
"ntrain = ntrain + fin_rows\n",
|
| 1702 |
+
"len(ntrain)"
|
| 1703 |
+
]
|
| 1704 |
+
},
|
| 1705 |
+
{
|
| 1706 |
+
"cell_type": "code",
|
| 1707 |
+
"execution_count": null,
|
| 1708 |
+
"id": "cb8f2cf0-49cf-4306-ac1f-73d19b139ab4",
|
| 1709 |
+
"metadata": {},
|
| 1710 |
+
"outputs": [],
|
| 1711 |
+
"source": [
|
| 1712 |
+
"final_train = {}\n",
|
| 1713 |
+
"for i in ntrain:\n",
|
| 1714 |
+
" sts = sort_themes_by_priority(i['themes'])\n",
|
| 1715 |
+
" sts = [i for i in sts if i in priority_order]\n",
|
| 1716 |
+
" for st in sts:\n",
|
| 1717 |
+
" if st not in final_train:\n",
|
| 1718 |
+
" final_train[st] = []\n",
|
| 1719 |
+
" if len(final_train[st]) < 6000:\n",
|
| 1720 |
+
" final_train[st].append(i)\n",
|
| 1721 |
+
" break\n",
|
| 1722 |
+
"for i in final_train:\n",
|
| 1723 |
+
" print(i, len(final_train[i]))"
|
| 1724 |
+
]
|
| 1725 |
+
},
|
| 1726 |
+
{
|
| 1727 |
+
"cell_type": "code",
|
| 1728 |
+
"execution_count": null,
|
| 1729 |
+
"id": "2d062629-4f0c-4e15-a36d-ff92ae463399",
|
| 1730 |
+
"metadata": {},
|
| 1731 |
+
"outputs": [],
|
| 1732 |
+
"source": [
|
| 1733 |
+
"final_train = sum([final_train[t] for t in final_train], [])\n",
|
| 1734 |
+
"len(final_train)"
|
| 1735 |
+
]
|
| 1736 |
+
},
|
| 1737 |
+
{
|
| 1738 |
+
"cell_type": "code",
|
| 1739 |
+
"execution_count": null,
|
| 1740 |
+
"id": "b7d10a78-26c0-461b-800f-cae91b4b7e4a",
|
| 1741 |
+
"metadata": {},
|
| 1742 |
+
"outputs": [],
|
| 1743 |
+
"source": []
|
| 1744 |
+
},
|
| 1745 |
+
{
|
| 1746 |
+
"cell_type": "code",
|
| 1747 |
+
"execution_count": null,
|
| 1748 |
+
"id": "32b75f1c-0b21-46a8-a2cc-9dbe5cb17088",
|
| 1749 |
+
"metadata": {},
|
| 1750 |
+
"outputs": [],
|
| 1751 |
+
"source": [
|
| 1752 |
+
"for i in pre_train:\n",
|
| 1753 |
+
" if i['content'].lower().strip() in contset:\n",
|
| 1754 |
+
" i['themes'] = contset[i['content'].lower().strip()]"
|
| 1755 |
+
]
|
| 1756 |
+
},
|
| 1757 |
+
{
|
| 1758 |
+
"cell_type": "code",
|
| 1759 |
+
"execution_count": 67,
|
| 1760 |
+
"id": "606aac46-b3ce-48a8-96d5-b179845dfae9",
|
| 1761 |
+
"metadata": {},
|
| 1762 |
+
"outputs": [],
|
| 1763 |
+
"source": [
|
| 1764 |
+
"valid_themes = set(priority_order)\n",
|
| 1765 |
+
"\n",
|
| 1766 |
+
"def dumpjson(rows, file_name):\n",
|
| 1767 |
+
" t = {}\n",
|
| 1768 |
+
" for row in rows:\n",
|
| 1769 |
+
" themes = row['themes']\n",
|
| 1770 |
+
" themes = [i for i in themes if i in valid_themes]\n",
|
| 1771 |
+
" row['themes'] = themes\n",
|
| 1772 |
+
" for th in themes:\n",
|
| 1773 |
+
" t[th] = t.get(th, 0) + 1\n",
|
| 1774 |
+
" print(len(rows), t)\n",
|
| 1775 |
+
" json.dump(rows, open(file_name, 'w'))"
|
| 1776 |
+
]
|
| 1777 |
+
},
|
| 1778 |
+
{
|
| 1779 |
+
"cell_type": "code",
|
| 1780 |
+
"execution_count": 68,
|
| 1781 |
+
"id": "09f7e9ac-e6fd-47ce-823b-dabe74b27ba0",
|
| 1782 |
+
"metadata": {},
|
| 1783 |
+
"outputs": [
|
| 1784 |
+
{
|
| 1785 |
+
"name": "stdout",
|
| 1786 |
+
"output_type": "stream",
|
| 1787 |
+
"text": [
|
| 1788 |
+
"92245 {'Entertainment': 6854, 'General': 6302, 'Politics': 14207, 'Business': 14722, 'Lifestyle': 7827, 'Tech': 8823, 'Health': 7985, 'Crime': 9548, 'Economics': 7703, 'Sports': 6607, 'Travel': 6812, 'Science': 4773, 'Automotive': 4698, 'Weather': 6351, 'Finance': 8080, 'Financial Crime': 5218}\n"
|
| 1789 |
+
]
|
| 1790 |
+
}
|
| 1791 |
+
],
|
| 1792 |
+
"source": [
|
| 1793 |
+
"dumpjson(final_train, 'filtered2/train.json')"
|
| 1794 |
+
]
|
| 1795 |
+
},
|
| 1796 |
+
{
|
| 1797 |
+
"cell_type": "code",
|
| 1798 |
+
"execution_count": 69,
|
| 1799 |
+
"id": "902ece36-0efa-4225-97c1-7a6b0f6082a6",
|
| 1800 |
+
"metadata": {},
|
| 1801 |
+
"outputs": [
|
| 1802 |
+
{
|
| 1803 |
+
"name": "stdout",
|
| 1804 |
+
"output_type": "stream",
|
| 1805 |
+
"text": [
|
| 1806 |
+
"8000 {'Entertainment': 588, 'General': 534, 'Politics': 1285, 'Lifestyle': 505, 'Health': 687, 'Sports': 570, 'Weather': 529, 'Business': 1465, 'Tech': 763, 'Crime': 676, 'Travel': 543, 'Science': 527, 'Automotive': 509, 'Economics': 500, 'Finance': 703, 'Financial Crime': 507}\n"
|
| 1807 |
+
]
|
| 1808 |
+
}
|
| 1809 |
+
],
|
| 1810 |
+
"source": [
|
| 1811 |
+
"dumpjson(nval, 'filtered2/val.json')"
|
| 1812 |
+
]
|
| 1813 |
+
},
|
| 1814 |
+
{
|
| 1815 |
+
"cell_type": "code",
|
| 1816 |
+
"execution_count": 70,
|
| 1817 |
+
"id": "03a5343e-2fee-4051-be11-d37bab9a4c24",
|
| 1818 |
+
"metadata": {},
|
| 1819 |
+
"outputs": [
|
| 1820 |
+
{
|
| 1821 |
+
"name": "stdout",
|
| 1822 |
+
"output_type": "stream",
|
| 1823 |
+
"text": [
|
| 1824 |
+
"16000 {'Entertainment': 1161, 'Politics': 2414, 'Crime': 1337, 'Business': 2698, 'General': 1065, 'Sports': 1118, 'Financial Crime': 1009, 'Finance': 1371, 'Travel': 1143, 'Tech': 1476, 'Health': 1295, 'Automotive': 1014, 'Weather': 1060, 'Lifestyle': 1023, 'Economics': 1000, 'Science': 1056}\n"
|
| 1825 |
+
]
|
| 1826 |
+
}
|
| 1827 |
+
],
|
| 1828 |
+
"source": [
|
| 1829 |
+
"dumpjson(test, 'filtered2/test.json')"
|
| 1830 |
+
]
|
| 1831 |
+
},
|
| 1832 |
+
{
|
| 1833 |
+
"cell_type": "code",
|
| 1834 |
+
"execution_count": null,
|
| 1835 |
+
"id": "6892ae88-b0ae-4318-82fe-868f0f260c1c",
|
| 1836 |
+
"metadata": {},
|
| 1837 |
+
"outputs": [],
|
| 1838 |
+
"source": [
|
| 1839 |
+
"for i in train:\n",
|
| 1840 |
+
" if i['content'].lower().strip() not in contset:\n",
|
| 1841 |
+
" contset[i['content'].lower().strip()] = i['themes']\n",
|
| 1842 |
+
"\n",
|
| 1843 |
+
"for i in pre_train:\n",
|
| 1844 |
+
" if i['content'].lower().strip() in contset:\n",
|
| 1845 |
+
" i['themes'] = contset[i['content'].lower().strip()]"
|
| 1846 |
+
]
|
| 1847 |
+
},
|
| 1848 |
+
{
|
| 1849 |
+
"cell_type": "code",
|
| 1850 |
+
"execution_count": null,
|
| 1851 |
+
"id": "9cd446ac-0209-4937-a8da-045e8d17edb8",
|
| 1852 |
+
"metadata": {},
|
| 1853 |
+
"outputs": [],
|
| 1854 |
+
"source": [
|
| 1855 |
+
"len(pre_train)"
|
| 1856 |
+
]
|
| 1857 |
+
},
|
| 1858 |
+
{
|
| 1859 |
+
"cell_type": "code",
|
| 1860 |
+
"execution_count": null,
|
| 1861 |
+
"id": "2ae5ec6d-ed70-4ba8-a188-a1c3dd561dbf",
|
| 1862 |
+
"metadata": {},
|
| 1863 |
+
"outputs": [],
|
| 1864 |
+
"source": [
|
| 1865 |
+
"to_ignore = set()\n",
|
| 1866 |
+
"for i in test + nval:\n",
|
| 1867 |
+
" to_ignore.add(i['content'].lower().strip())\n",
|
| 1868 |
+
"\n",
|
| 1869 |
+
"final_pre_train = []\n",
|
| 1870 |
+
"for i in pre_train:\n",
|
| 1871 |
+
" if i['content'].lower().strip() in to_ignore:\n",
|
| 1872 |
+
" continue\n",
|
| 1873 |
+
" final_pre_train.append(i)\n",
|
| 1874 |
+
"len(final_pre_train), len(to_ignore)"
|
| 1875 |
+
]
|
| 1876 |
+
},
|
| 1877 |
+
{
|
| 1878 |
+
"cell_type": "code",
|
| 1879 |
+
"execution_count": 71,
|
| 1880 |
+
"id": "6f560a9a-cf70-47f3-97f3-ae45f520590f",
|
| 1881 |
+
"metadata": {},
|
| 1882 |
+
"outputs": [
|
| 1883 |
+
{
|
| 1884 |
+
"name": "stdout",
|
| 1885 |
+
"output_type": "stream",
|
| 1886 |
+
"text": [
|
| 1887 |
+
"397784 {'Sports': 37315, 'Lifestyle': 21900, 'Crime': 88859, 'Business': 60980, 'Entertainment': 40906, 'Politics': 78084, 'General': 25417, 'Tech': 25008, 'Health': 32249, 'Travel': 15170, 'Finance': 21892, 'Economics': 13128, 'Weather': 13586, 'Science': 14338, 'Automotive': 12176, 'Financial Crime': 5035}\n"
|
| 1888 |
+
]
|
| 1889 |
+
}
|
| 1890 |
+
],
|
| 1891 |
+
"source": [
|
| 1892 |
+
"dumpjson(final_pre_train, 'filtered2/pre-train.json')"
|
| 1893 |
+
]
|
| 1894 |
+
},
|
| 1895 |
+
{
|
| 1896 |
+
"cell_type": "code",
|
| 1897 |
+
"execution_count": null,
|
| 1898 |
+
"id": "dd98be36-8ec4-444e-a7fa-d7196334a630",
|
| 1899 |
+
"metadata": {},
|
| 1900 |
+
"outputs": [],
|
| 1901 |
+
"source": []
|
| 1902 |
+
},
|
| 1903 |
+
{
|
| 1904 |
+
"cell_type": "code",
|
| 1905 |
+
"execution_count": null,
|
| 1906 |
+
"id": "1fbf2b53-45a6-4727-b862-bb7819e4562e",
|
| 1907 |
+
"metadata": {},
|
| 1908 |
+
"outputs": [],
|
| 1909 |
+
"source": []
|
| 1910 |
+
}
|
| 1911 |
+
],
|
| 1912 |
+
"metadata": {
|
| 1913 |
+
"kernelspec": {
|
| 1914 |
+
"display_name": "Python 3 (ipykernel)",
|
| 1915 |
+
"language": "python",
|
| 1916 |
+
"name": "python3"
|
| 1917 |
+
},
|
| 1918 |
+
"language_info": {
|
| 1919 |
+
"codemirror_mode": {
|
| 1920 |
+
"name": "ipython",
|
| 1921 |
+
"version": 3
|
| 1922 |
+
},
|
| 1923 |
+
"file_extension": ".py",
|
| 1924 |
+
"mimetype": "text/x-python",
|
| 1925 |
+
"name": "python",
|
| 1926 |
+
"nbconvert_exporter": "python",
|
| 1927 |
+
"pygments_lexer": "ipython3",
|
| 1928 |
+
"version": "3.10.12"
|
| 1929 |
+
}
|
| 1930 |
+
},
|
| 1931 |
+
"nbformat": 4,
|
| 1932 |
+
"nbformat_minor": 5
|
| 1933 |
+
}
|
3-training.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
mlp_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32b352f602bdf62dc246c0d6509495469520d1c6e4caa856619beeae90d89c5f
|
| 3 |
+
size 9100341
|
pre-train.json.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:025bb8dd1a317149eb9ce207b992142cbc99a0e50efa0704b5222701bf69d2da
|
| 3 |
+
size 509109586
|
scaler.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a50abf4348f226b4f3dbd0fda27ae284de7882851bcdcfda3e25a3bcaed0bbab
|
| 3 |
+
size 25191
|
test.json.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbd8a5492b857ef5f1c0b8eb58a0f9c0bcc85532bae026bde30e3c47269ca423
|
| 3 |
+
size 22253927
|
train.json.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5192ac529ad7c2c8083981325dd14d774a92f04ae768aac38f0b9d6669d75d6
|
| 3 |
+
size 129843701
|
val.json.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3de1a164f7512828b632939d5cc75b002f7f8564c0ba142b3329cb5e2e3a5199
|
| 3 |
+
size 11034512
|