prz4587 commited on
Commit
f88be8d
·
verified ·
1 Parent(s): e925eab

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. 11-train.ipynb +0 -0
  2. model.pth +3 -0
  3. threshold-detect.ipynb +600 -0
11-train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f46834fd6aba4676d4232a308276e2913eeaf4f155d2cc739a6218f7255d2f05
3
+ size 7415509
threshold-detect.ipynb ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "caa2786c-27ba-484a-9ff4-da4e378ef019",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import ujson as json"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 5,
16
+ "id": "b1f04aee-ba66-4e67-9183-36cf2534c08f",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "testd = json.load(open('filtered2/test.json'))"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 6,
26
+ "id": "a663864b-2795-4aac-9817-8c6d8c31c4e5",
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "16000"
33
+ ]
34
+ },
35
+ "execution_count": 6,
36
+ "metadata": {},
37
+ "output_type": "execute_result"
38
+ }
39
+ ],
40
+ "source": [
41
+ "len(testd)"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 22,
47
+ "id": "65be569c-13bb-4b61-bfd1-5be33428409b",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "from joblib import load\n",
52
+ "import torch\n",
53
+ "from torch import nn\n",
54
+ "from transformers import BertModel, BertConfig\n",
55
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
56
+ "\n",
57
+ "classes = [\n",
58
+ " \"Automotive\",\n",
59
+ " \"Business\",\n",
60
+ " \"Crime\",\n",
61
+ " \"Economics\",\n",
62
+ " \"Entertainment\",\n",
63
+ " \"Finance\",\n",
64
+ " \"Financial Crime\",\n",
65
+ " \"General\",\n",
66
+ " \"Health\",\n",
67
+ " \"Lifestyle\",\n",
68
+ " \"Politics\",\n",
69
+ " \"Science\",\n",
70
+ " \"Sports\",\n",
71
+ " \"Tech\",\n",
72
+ " \"Travel\",\n",
73
+ " \"Weather\",\n",
74
+ "]\n",
75
+ "mlb = MultiLabelBinarizer(classes=classes)\n",
76
+ "mlb.fit([[]])\n",
77
+ "\n",
78
+ "NUM_LABELS = len(classes)\n",
79
+ "\n",
80
+ "\n",
81
+ "# class SimpleMLP(nn.Module):\n",
82
+ "# def __init__(self, input_dim=1024, num_labels=NUM_LABELS):\n",
83
+ "# super().__init__()\n",
84
+ "# self.net = nn.Sequential(\n",
85
+ "# nn.Linear(input_dim, 1024),\n",
86
+ "# nn.ReLU(),\n",
87
+ "# nn.Dropout(0.1),\n",
88
+ "# nn.Linear(1024, 512),\n",
89
+ "# nn.ReLU(),\n",
90
+ "# nn.Linear(512, 512),\n",
91
+ "# nn.ReLU(),\n",
92
+ "# nn.Linear(512, 512),\n",
93
+ "# nn.ReLU(),\n",
94
+ "# nn.Linear(512, 256),\n",
95
+ "# nn.ReLU(),\n",
96
+ "# nn.Linear(256, 128),\n",
97
+ "# nn.ReLU(),\n",
98
+ "# nn.Linear(128, 64),\n",
99
+ "# nn.ReLU(),\n",
100
+ "# nn.LayerNorm(64),\n",
101
+ "# nn.Linear(64, num_labels),\n",
102
+ "# )\n",
103
+ "\n",
104
+ "# def forward(self, x):\n",
105
+ "# return self.net(x) # logits\n",
106
+ "\n",
107
+ "\n",
108
+ "# THEME_MODEL = SimpleMLP(num_labels=len(mlb.classes_))\n",
109
+ "# device = 0 if torch.cuda.is_available() else -1\n",
110
+ "# if device == 0:\n",
111
+ "# THEME_MODEL.load_state_dict(torch.load('qwen_embedding_theme/mlp_model.pth'))\n",
112
+ "# else:\n",
113
+ "# THEME_MODEL.load_state_dict(torch.load(\n",
114
+ "# 'qwen_embedding_theme/mlp_model.pth', map_location=torch.device('cpu')))\n",
115
+ "# THEME_MODEL.eval()\n",
116
+ "# if torch.cuda.is_available():\n",
117
+ "# THEME_MODEL.to(device)\n",
118
+ "\n",
119
+ "# SCALER = load('qwen_embedding_theme/scaler.joblib')\n"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 23,
125
+ "id": "c895eac1-6b42-4b8c-ab26-3a79b507e8a7",
126
+ "metadata": {},
127
+ "outputs": [
128
+ {
129
+ "name": "stdout",
130
+ "output_type": "stream",
131
+ "text": [
132
+ "classes: ['Automotive', 'Business', 'Crime', 'Economics', 'Entertainment', 'Finance', 'Financial Crime', 'General', 'Health', 'Lifestyle', 'Politics', 'Science', 'Sports', 'Tech', 'Travel', 'Weather']\n",
133
+ "mlb.classes_: ['Automotive', 'Business', 'Crime', 'Economics', 'Entertainment', 'Finance', 'Financial Crime', 'General', 'Health', 'Lifestyle', 'Politics', 'Science', 'Sports', 'Tech', 'Travel', 'Weather']\n"
134
+ ]
135
+ }
136
+ ],
137
+ "source": [
138
+ "print(\"classes:\", classes)\n",
139
+ "print(\"mlb.classes_:\", list(mlb.classes_))"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 24,
145
+ "id": "857c360a-07c4-41ef-bcd3-a90b862eaea4",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "def batch(iterable, n=10):\n",
150
+ " l = len(iterable)\n",
151
+ " for ndx in range(0, l, n):\n",
152
+ " yield iterable[ndx:min(ndx + n, l)]\n",
153
+ "\n",
154
+ "def get_multilabel_themes(embeddings, batch_size=16):\n",
155
+ " results = []\n",
156
+ "\n",
157
+ " for batch_embeddings in batch(embeddings, n=batch_size):\n",
158
+ " # Transform embeddings with the pre-fitted scaler\n",
159
+ " batch_embeddings = SCALER.transform(batch_embeddings)\n",
160
+ " batch_embeddings = torch.tensor(batch_embeddings, dtype=torch.float)\n",
161
+ " if device == 0:\n",
162
+ " batch_embeddings = batch_embeddings.to(\"cuda\")\n",
163
+ "\n",
164
+ " with torch.no_grad():\n",
165
+ " predictions = THEME_MODEL(batch_embeddings)\n",
166
+ " probabilities = torch.sigmoid(predictions)\n",
167
+ "\n",
168
+ " # Convert probabilities to CPU and numpy format for easier processing with sklearn\n",
169
+ " probabilities = probabilities.cpu().numpy()\n",
170
+ "\n",
171
+ " # Prepare the list of dictionaries with theme names and scores\n",
172
+ " for probability in probabilities:\n",
173
+ " result = [{'name': label, 'score': round(float(score), 2)} for label, score in\n",
174
+ " zip(mlb.classes_, probability)]\n",
175
+ " if not result:\n",
176
+ " result = [{'name': label, 'score': round(float(score), 2)} for label, score in\n",
177
+ " zip(mlb.classes_, probability)]\n",
178
+ " result = [max(result, key=lambda x: x['score'])]\n",
179
+ " result = sorted(result, key=lambda x: x['score'], reverse=True)\n",
180
+ " results.append(result)\n",
181
+ "\n",
182
+ " return results"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": 25,
188
+ "id": "3f23a7c4-12df-45b7-96f8-148f7cbff251",
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "class WideShallowMLP(nn.Module):\n",
193
+ " def __init__(\n",
194
+ " self,\n",
195
+ " input_dim=1024,\n",
196
+ " num_labels=NUM_LABELS,\n",
197
+ " dropout=0.35,\n",
198
+ " activation=\"gelu\", # \"gelu\" or \"silu\"\n",
199
+ " hidden2=768, # 2nd layer width\n",
200
+ " temperature=0.6, # logit temperature scaling\n",
201
+ " ):\n",
202
+ " super().__init__()\n",
203
+ "\n",
204
+ " if activation == \"gelu\":\n",
205
+ " act = nn.GELU()\n",
206
+ " elif activation == \"silu\":\n",
207
+ " act = nn.SiLU()\n",
208
+ " else:\n",
209
+ " raise ValueError(f\"Unknown activation: {activation}\")\n",
210
+ "\n",
211
+ " self.temperature = float(temperature)\n",
212
+ "\n",
213
+ " self.net = nn.Sequential(\n",
214
+ " nn.Linear(input_dim, 1024),\n",
215
+ " act,\n",
216
+ " nn.LayerNorm(1024),\n",
217
+ " nn.Dropout(dropout),\n",
218
+ "\n",
219
+ " nn.Linear(1024, hidden2),\n",
220
+ " act,\n",
221
+ " nn.LayerNorm(hidden2),\n",
222
+ " nn.Dropout(dropout),\n",
223
+ "\n",
224
+ " nn.Linear(hidden2, num_labels),\n",
225
+ " )\n",
226
+ "\n",
227
+ " def forward(self, x):\n",
228
+ " logits = self.net(x)\n",
229
+ " if self.temperature != 1.0:\n",
230
+ " logits = logits / self.temperature\n",
231
+ " return logits\n",
232
+ "\n",
233
+ "THEME_MODEL = WideShallowMLP()\n",
234
+ "device = 0 if torch.cuda.is_available() else -1\n",
235
+ "if device == 0:\n",
236
+ " THEME_MODEL.load_state_dict(torch.load('exp_artifacts_grid/ws_gelu_d0.35_wd0.003_h2768_t0.6.pth'))\n",
237
+ "else:\n",
238
+ " THEME_MODEL.load_state_dict(torch.load(\n",
239
+ " 'exp_artifacts_grid/ws_gelu_d0.35_wd0.003_h2768_t0.6.pth', map_location=torch.device('cpu')))\n",
240
+ "THEME_MODEL.eval()\n",
241
+ "if torch.cuda.is_available():\n",
242
+ " THEME_MODEL.to(device)\n",
243
+ "\n",
244
+ "SCALER = load('exp_artifacts_grid/scaler.joblib')"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": 26,
250
+ "id": "22c1a4c1-d0e1-44f9-8518-6d81acec062e",
251
+ "metadata": {},
252
+ "outputs": [],
253
+ "source": [
254
+ "import numpy as np\n",
255
+ "thresholds = np.load('exp_artifacts_grid/ws_gelu_d0.35_wd0.003_h2768_t0.6_thresholds.npy')"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": 27,
261
+ "id": "f8a98800-9a1d-417e-904f-06f814be81fb",
262
+ "metadata": {},
263
+ "outputs": [
264
+ {
265
+ "data": {
266
+ "text/plain": [
267
+ "array([0.5 , 0.6 , 0.65, 0.55, 0.5 , 0.5 , 0.45, 0.5 , 0.6 , 0.75, 0.45,\n",
268
+ " 0.5 , 0.5 , 0.6 , 0.55, 0.2 ], dtype=float32)"
269
+ ]
270
+ },
271
+ "execution_count": 27,
272
+ "metadata": {},
273
+ "output_type": "execute_result"
274
+ }
275
+ ],
276
+ "source": [
277
+ "thresholds"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": 28,
283
+ "id": "c4ac491c-96bb-48d9-9037-ffd758d81ad8",
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "qwen_themes = get_multilabel_themes([i['embedding'] for i in testd])"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 29,
293
+ "id": "9a131376-cfbc-4ebd-9f32-859e61e989e0",
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "data": {
298
+ "text/plain": [
299
+ "[{'name': 'Business', 'score': 0.86},\n",
300
+ " {'name': 'Automotive', 'score': 0.47},\n",
301
+ " {'name': 'Travel', 'score': 0.37},\n",
302
+ " {'name': 'Tech', 'score': 0.19},\n",
303
+ " {'name': 'Lifestyle', 'score': 0.07},\n",
304
+ " {'name': 'General', 'score': 0.04},\n",
305
+ " {'name': 'Health', 'score': 0.03},\n",
306
+ " {'name': 'Politics', 'score': 0.03},\n",
307
+ " {'name': 'Economics', 'score': 0.02},\n",
308
+ " {'name': 'Finance', 'score': 0.02},\n",
309
+ " {'name': 'Science', 'score': 0.02},\n",
310
+ " {'name': 'Sports', 'score': 0.02},\n",
311
+ " {'name': 'Entertainment', 'score': 0.01},\n",
312
+ " {'name': 'Crime', 'score': 0.0},\n",
313
+ " {'name': 'Financial Crime', 'score': 0.0},\n",
314
+ " {'name': 'Weather', 'score': 0.0}]"
315
+ ]
316
+ },
317
+ "execution_count": 29,
318
+ "metadata": {},
319
+ "output_type": "execute_result"
320
+ }
321
+ ],
322
+ "source": [
323
+ "qwen_themes[12469]"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": 30,
329
+ "id": "17e1e5db-9fe6-431b-bc58-255e1288bd44",
330
+ "metadata": {},
331
+ "outputs": [
332
+ {
333
+ "data": {
334
+ "text/plain": [
335
+ "['Business', 'Travel']"
336
+ ]
337
+ },
338
+ "execution_count": 30,
339
+ "metadata": {},
340
+ "output_type": "execute_result"
341
+ }
342
+ ],
343
+ "source": [
344
+ "testd[12469]['themes']"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": 31,
350
+ "id": "1a4fa192-a344-4d7d-bd7b-ff13f723dfd0",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
355
+ "from sklearn.metrics import (\n",
356
+ " accuracy_score,\n",
357
+ " hamming_loss,\n",
358
+ " precision_score,\n",
359
+ " recall_score,\n",
360
+ " f1_score,\n",
361
+ " jaccard_score,\n",
362
+ " classification_report\n",
363
+ ")\n",
364
+ "\n",
365
+ "def print_multilabel_metrics(y_true, y_pred):\n",
366
+ " \"\"\"\n",
367
+ " y_true: List[List[str]] -> ground truth labels\n",
368
+ " y_pred: List[List[str]] -> predicted labels\n",
369
+ " \"\"\"\n",
370
+ "\n",
371
+ " mlb = MultiLabelBinarizer()\n",
372
+ " y_true_bin = mlb.fit_transform(y_true)\n",
373
+ " y_pred_bin = mlb.transform(y_pred)\n",
374
+ "\n",
375
+ " acc = accuracy_score(y_true_bin, y_pred_bin)\n",
376
+ " h_loss = hamming_loss(y_true_bin, y_pred_bin)\n",
377
+ " prec_micro = precision_score(y_true_bin, y_pred_bin, average=\"micro\", zero_division=0)\n",
378
+ " rec_micro = recall_score(y_true_bin, y_pred_bin, average=\"micro\", zero_division=0)\n",
379
+ " f1_micro = f1_score(y_true_bin, y_pred_bin, average=\"micro\", zero_division=0)\n",
380
+ " jacc = jaccard_score(y_true_bin, y_pred_bin, average=\"samples\", zero_division=0)\n",
381
+ "\n",
382
+ " print(\"Accuracy:\", acc)\n",
383
+ " print(\"Hamming Loss:\", h_loss)\n",
384
+ " print(\"Precision (micro):\", prec_micro)\n",
385
+ " print(\"Recall (micro):\", rec_micro)\n",
386
+ " print(\"F1-Score (micro):\", f1_micro)\n",
387
+ " print(\"Jaccard Similarity (samples avg):\", jacc)\n",
388
+ " print(\"\\nClassification Report:\")\n",
389
+ " print(classification_report(y_true_bin, y_pred_bin, target_names=mlb.classes_, zero_division=0))\n"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "markdown",
394
+ "id": "f517dcbf-8a64-49e6-a9b1-754353548ca8",
395
+ "metadata": {},
396
+ "source": [
397
+ "# score >= 0.1"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 32,
403
+ "id": "2e071073-bd2b-4ad9-b0dd-00c9ee4dc756",
404
+ "metadata": {},
405
+ "outputs": [
406
+ {
407
+ "name": "stdout",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "Accuracy: 0.230875\n",
411
+ "Hamming Loss: 0.113875\n",
412
+ "Precision (micro): 0.4182678401718937\n",
413
+ "Recall (micro): 0.9531544256120528\n",
414
+ "F1-Score (micro): 0.5814020275121334\n",
415
+ "Jaccard Similarity (samples avg): 0.524701389322483\n",
416
+ "\n",
417
+ "Classification Report:\n",
418
+ " precision recall f1-score support\n",
419
+ "\n",
420
+ " Automotive 0.52 0.95 0.67 1014\n",
421
+ " Business 0.39 0.97 0.56 2698\n",
422
+ " Crime 0.37 0.96 0.53 1337\n",
423
+ " Economics 0.31 0.97 0.47 1000\n",
424
+ " Entertainment 0.45 0.97 0.61 1161\n",
425
+ " Finance 0.37 0.97 0.54 1371\n",
426
+ "Financial Crime 0.49 0.94 0.64 1009\n",
427
+ " General 0.26 0.89 0.41 1065\n",
428
+ " Health 0.44 0.95 0.61 1295\n",
429
+ " Lifestyle 0.29 0.91 0.44 1023\n",
430
+ " Politics 0.48 0.96 0.64 2414\n",
431
+ " Science 0.45 0.96 0.61 1056\n",
432
+ " Sports 0.64 0.96 0.77 1118\n",
433
+ " Tech 0.46 0.97 0.62 1476\n",
434
+ " Travel 0.47 0.95 0.63 1143\n",
435
+ " Weather 0.63 0.94 0.75 1060\n",
436
+ "\n",
437
+ " micro avg 0.42 0.95 0.58 21240\n",
438
+ " macro avg 0.44 0.95 0.59 21240\n",
439
+ " weighted avg 0.44 0.95 0.59 21240\n",
440
+ " samples avg 0.53 0.95 0.64 21240\n",
441
+ "\n"
442
+ ]
443
+ }
444
+ ],
445
+ "source": [
446
+ "all_labels = [sorted(i['themes']) for i in testd]\n",
447
+ "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= 0.1))) for i in qwen_themes]\n",
448
+ "print_multilabel_metrics(all_labels, y_pred)"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "markdown",
453
+ "id": "8746d0e7-8897-4d1f-8a46-1fccd618d198",
454
+ "metadata": {},
455
+ "source": [
456
+ "# score with ML detected"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": 35,
462
+ "id": "63d3005f-df66-4205-ba00-31d60d109930",
463
+ "metadata": {},
464
+ "outputs": [
465
+ {
466
+ "name": "stdout",
467
+ "output_type": "stream",
468
+ "text": [
469
+ "Accuracy: 0.5604375\n",
470
+ "Hamming Loss: 0.04196484375\n",
471
+ "Precision (micro): 0.727671018956318\n",
472
+ "Recall (micro): 0.789783427495292\n",
473
+ "F1-Score (micro): 0.7574560314270878\n",
474
+ "Jaccard Similarity (samples avg): 0.712795386904762\n",
475
+ "\n",
476
+ "Classification Report:\n",
477
+ " precision recall f1-score support\n",
478
+ "\n",
479
+ " Automotive 0.82 0.81 0.82 1014\n",
480
+ " Business 0.70 0.74 0.72 2698\n",
481
+ " Crime 0.72 0.73 0.72 1337\n",
482
+ " Economics 0.62 0.79 0.70 1000\n",
483
+ " Entertainment 0.79 0.83 0.81 1161\n",
484
+ " Finance 0.70 0.81 0.75 1371\n",
485
+ "Financial Crime 0.74 0.78 0.76 1009\n",
486
+ " General 0.47 0.67 0.55 1065\n",
487
+ " Health 0.80 0.78 0.79 1295\n",
488
+ " Lifestyle 0.70 0.63 0.66 1023\n",
489
+ " Politics 0.74 0.86 0.80 2414\n",
490
+ " Science 0.72 0.79 0.76 1056\n",
491
+ " Sports 0.89 0.88 0.89 1118\n",
492
+ " Tech 0.78 0.80 0.79 1476\n",
493
+ " Travel 0.80 0.80 0.80 1143\n",
494
+ " Weather 0.75 0.89 0.81 1060\n",
495
+ "\n",
496
+ " micro avg 0.73 0.79 0.76 21240\n",
497
+ " macro avg 0.73 0.79 0.76 21240\n",
498
+ " weighted avg 0.73 0.79 0.76 21240\n",
499
+ " samples avg 0.76 0.82 0.76 21240\n",
500
+ "\n"
501
+ ]
502
+ }
503
+ ],
504
+ "source": [
505
+ "threshold = {'Automotive': 0.5, 'Business': 0.6, 'Crime': 0.65, 'Economics': 0.55, 'Entertainment': 0.5, 'Finance': 0.5, 'Financial Crime': 0.45, 'General': 0.5, 'Health': 0.6, 'Lifestyle': 0.75, 'Politics': 0.45, 'Science': 0.5, 'Sports': 0.5, 'Tech': 0.6, 'Travel': 0.55, 'Weather': 0.2}\n",
506
+ "all_labels = [sorted(i['themes']) for i in testd]\n",
507
+ "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= threshold[j['name']]))) for i in qwen_themes]\n",
508
+ "print_multilabel_metrics(all_labels, y_pred)"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "markdown",
513
+ "id": "57e004e7-a815-465f-9991-efde869338c4",
514
+ "metadata": {},
515
+ "source": [
516
+ "# score >= 0.5"
517
+ ]
518
+ },
519
+ {
520
+ "cell_type": "code",
521
+ "execution_count": 34,
522
+ "id": "ddb9dfdc-5cba-4d9d-bdf6-fa80ea157bf6",
523
+ "metadata": {},
524
+ "outputs": [
525
+ {
526
+ "name": "stdout",
527
+ "output_type": "stream",
528
+ "text": [
529
+ "Accuracy: 0.5435\n",
530
+ "Hamming Loss: 0.04390625\n",
531
+ "Precision (micro): 0.7069707757264674\n",
532
+ "Recall (micro): 0.8040960451977401\n",
533
+ "F1-Score (micro): 0.7524120005286576\n",
534
+ "Jaccard Similarity (samples avg): 0.7082564100829726\n",
535
+ "\n",
536
+ "Classification Report:\n",
537
+ " precision recall f1-score support\n",
538
+ "\n",
539
+ " Automotive 0.82 0.81 0.82 1014\n",
540
+ " Business 0.66 0.79 0.72 2698\n",
541
+ " Crime 0.64 0.80 0.71 1337\n",
542
+ " Economics 0.60 0.81 0.69 1000\n",
543
+ " Entertainment 0.79 0.83 0.81 1161\n",
544
+ " Finance 0.70 0.81 0.75 1371\n",
545
+ "Financial Crime 0.76 0.76 0.76 1009\n",
546
+ " General 0.47 0.67 0.55 1065\n",
547
+ " Health 0.76 0.82 0.79 1295\n",
548
+ " Lifestyle 0.56 0.77 0.65 1023\n",
549
+ " Politics 0.77 0.84 0.80 2414\n",
550
+ " Science 0.72 0.79 0.76 1056\n",
551
+ " Sports 0.89 0.88 0.89 1118\n",
552
+ " Tech 0.74 0.84 0.78 1476\n",
553
+ " Travel 0.78 0.81 0.80 1143\n",
554
+ " Weather 0.86 0.76 0.81 1060\n",
555
+ "\n",
556
+ " micro avg 0.71 0.80 0.75 21240\n",
557
+ " macro avg 0.72 0.80 0.75 21240\n",
558
+ " weighted avg 0.72 0.80 0.76 21240\n",
559
+ " samples avg 0.75 0.83 0.76 21240\n",
560
+ "\n"
561
+ ]
562
+ }
563
+ ],
564
+ "source": [
565
+ "all_labels = [sorted(i['themes']) for i in testd]\n",
566
+ "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= 0.5))) for i in qwen_themes]\n",
567
+ "print_multilabel_metrics(all_labels, y_pred)"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "id": "1f50c82e-d524-450c-aef5-0170af19123d",
574
+ "metadata": {},
575
+ "outputs": [],
576
+ "source": []
577
+ }
578
+ ],
579
+ "metadata": {
580
+ "kernelspec": {
581
+ "display_name": "Python 3 (ipykernel)",
582
+ "language": "python",
583
+ "name": "python3"
584
+ },
585
+ "language_info": {
586
+ "codemirror_mode": {
587
+ "name": "ipython",
588
+ "version": 3
589
+ },
590
+ "file_extension": ".py",
591
+ "mimetype": "text/x-python",
592
+ "name": "python",
593
+ "nbconvert_exporter": "python",
594
+ "pygments_lexer": "ipython3",
595
+ "version": "3.10.12"
596
+ }
597
+ },
598
+ "nbformat": 4,
599
+ "nbformat_minor": 5
600
+ }