{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "caa2786c-27ba-484a-9ff4-da4e378ef019", "metadata": {}, "outputs": [], "source": [ "import ujson as json" ] }, { "cell_type": "code", "execution_count": 2, "id": "b1f04aee-ba66-4e67-9183-36cf2534c08f", "metadata": {}, "outputs": [], "source": [ "testd = json.load(open('../filtered2/test.json'))" ] }, { "cell_type": "code", "execution_count": 3, "id": "a663864b-2795-4aac-9817-8c6d8c31c4e5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16000" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(testd)" ] }, { "cell_type": "code", "execution_count": 4, "id": "65be569c-13bb-4b61-bfd1-5be33428409b", "metadata": {}, "outputs": [], "source": [ "from joblib import load\n", "import torch\n", "from torch import nn\n", "from transformers import BertModel, BertConfig\n", "from sklearn.preprocessing import MultiLabelBinarizer\n", "\n", "classes = [\n", " \"Automotive\",\n", " \"Business\",\n", " \"Crime\",\n", " \"Economics\",\n", " \"Entertainment\",\n", " \"Finance\",\n", " \"Financial Crime\",\n", " \"General\",\n", " \"Health\",\n", " \"Lifestyle\",\n", " \"Politics\",\n", " \"Science\",\n", " \"Sports\",\n", " \"Tech\",\n", " \"Travel\",\n", " \"Weather\",\n", "]\n", "mlb = MultiLabelBinarizer(classes=classes)\n", "mlb.fit([[]])\n", "\n", "NUM_LABELS = len(classes)\n", "\n", "\n", "# class SimpleMLP(nn.Module):\n", "# def __init__(self, input_dim=1024, num_labels=NUM_LABELS):\n", "# super().__init__()\n", "# self.net = nn.Sequential(\n", "# nn.Linear(input_dim, 1024),\n", "# nn.ReLU(),\n", "# nn.Dropout(0.1),\n", "# nn.Linear(1024, 512),\n", "# nn.ReLU(),\n", "# nn.Linear(512, 512),\n", "# nn.ReLU(),\n", "# nn.Linear(512, 512),\n", "# nn.ReLU(),\n", "# nn.Linear(512, 256),\n", "# nn.ReLU(),\n", "# nn.Linear(256, 128),\n", "# nn.ReLU(),\n", "# nn.Linear(128, 64),\n", "# nn.ReLU(),\n", "# nn.LayerNorm(64),\n", "# nn.Linear(64, num_labels),\n", "# )\n", 
"\n", "# def forward(self, x):\n", "# return self.net(x) # logits\n", "\n", "\n", "# THEME_MODEL = SimpleMLP(num_labels=len(mlb.classes_))\n", "# device = 0 if torch.cuda.is_available() else -1\n", "# if device == 0:\n", "# THEME_MODEL.load_state_dict(torch.load('qwen_embedding_theme/mlp_model.pth'))\n", "# else:\n", "# THEME_MODEL.load_state_dict(torch.load(\n", "# 'qwen_embedding_theme/mlp_model.pth', map_location=torch.device('cpu')))\n", "# THEME_MODEL.eval()\n", "# if torch.cuda.is_available():\n", "# THEME_MODEL.to(device)\n", "\n", "# SCALER = load('qwen_embedding_theme/scaler.joblib')\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "c895eac1-6b42-4b8c-ab26-3a79b507e8a7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "classes: ['Automotive', 'Business', 'Crime', 'Economics', 'Entertainment', 'Finance', 'Financial Crime', 'General', 'Health', 'Lifestyle', 'Politics', 'Science', 'Sports', 'Tech', 'Travel', 'Weather']\n", "mlb.classes_: ['Automotive', 'Business', 'Crime', 'Economics', 'Entertainment', 'Finance', 'Financial Crime', 'General', 'Health', 'Lifestyle', 'Politics', 'Science', 'Sports', 'Tech', 'Travel', 'Weather']\n" ] } ], "source": [ "print(\"classes:\", classes)\n", "print(\"mlb.classes_:\", list(mlb.classes_))" ] }, { "cell_type": "code", "execution_count": 6, "id": "857c360a-07c4-41ef-bcd3-a90b862eaea4", "metadata": {}, "outputs": [], "source": [ "def batch(iterable, n=10):\n", " l = len(iterable)\n", " for ndx in range(0, l, n):\n", " yield iterable[ndx:min(ndx + n, l)]\n", "\n", "def get_multilabel_themes(embeddings, batch_size=16):\n", " results = []\n", "\n", " for batch_embeddings in batch(embeddings, n=batch_size):\n", " # Transform embeddings with the pre-fitted scaler\n", " batch_embeddings = SCALER.transform(batch_embeddings)\n", " batch_embeddings = torch.tensor(batch_embeddings, dtype=torch.float)\n", " if device == 0:\n", " batch_embeddings = batch_embeddings.to(\"cuda\")\n", "\n", " 
with torch.no_grad():\n", "            predictions = THEME_MODEL(batch_embeddings)\n", "            probabilities = torch.sigmoid(predictions)\n", "\n", "        # Convert probabilities to CPU and numpy format for easier processing with sklearn\n", "        probabilities = probabilities.cpu().numpy()\n", "\n", "        # Prepare the list of dictionaries with theme names and scores\n", "        for probability in probabilities:\n", "            result = [{'name': label, 'score': round(float(score), 2)} for label, score in\n", "                      zip(mlb.classes_, probability)]\n", "            result = sorted(result, key=lambda x: x['score'], reverse=True)\n", "            results.append(result)\n", "\n", "    return results" ] }, { "cell_type": "code", "execution_count": 7, "id": "3f23a7c4-12df-45b7-96f8-148f7cbff251", "metadata": {}, "outputs": [], "source": [ "class WideShallowMLP(nn.Module):\n", "    def __init__(\n", "        self,\n", "        input_dim=1024,\n", "        num_labels=NUM_LABELS,\n", "        dropout=0.35,\n", "        activation=\"gelu\",      # \"gelu\" or \"silu\"\n", "        hidden2=768,            # 2nd layer width\n", "        temperature=0.6,        # logit temperature scaling\n", "    ):\n", "        super().__init__()\n", "\n", "        if activation == \"gelu\":\n", "            act = nn.GELU()\n", "        elif activation == \"silu\":\n", "            act = nn.SiLU()\n", "        else:\n", "            raise ValueError(f\"Unknown activation: {activation}\")\n", "\n", "        self.temperature = float(temperature)\n", "\n", "        self.net = nn.Sequential(\n", "            nn.Linear(input_dim, 1024),\n", "            act,\n", "            nn.LayerNorm(1024),\n", "            nn.Dropout(dropout),\n", "\n", "            nn.Linear(1024, hidden2),\n", "            act,\n", "            nn.LayerNorm(hidden2),\n", "            nn.Dropout(dropout),\n", "\n", "            nn.Linear(hidden2, num_labels),\n", "        )\n", "\n", "    def forward(self, x):\n", "        logits = self.net(x)\n", "        if self.temperature != 1.0:\n", "            logits = logits / self.temperature\n", "        return logits\n", "\n", "THEME_MODEL = 
WideShallowMLP()\n", "device = 0 if torch.cuda.is_available() else -1\n", "if device == 0:\n", " THEME_MODEL.load_state_dict(torch.load('model.pth'))\n", "else:\n", " THEME_MODEL.load_state_dict(torch.load(\n", " 'model.pth', map_location=torch.device('cpu')))\n", "THEME_MODEL.eval()\n", "if torch.cuda.is_available():\n", " THEME_MODEL.to(device)\n", "\n", "SCALER = load('scaler.joblib')" ] }, { "cell_type": "code", "execution_count": 8, "id": "c4ac491c-96bb-48d9-9037-ffd758d81ad8", "metadata": {}, "outputs": [], "source": [ "qwen_themes = get_multilabel_themes([i['embedding'] for i in testd])" ] }, { "cell_type": "code", "execution_count": 9, "id": "9a131376-cfbc-4ebd-9f32-859e61e989e0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'name': 'Business', 'score': 0.86},\n", " {'name': 'Automotive', 'score': 0.47},\n", " {'name': 'Travel', 'score': 0.37},\n", " {'name': 'Tech', 'score': 0.19},\n", " {'name': 'Lifestyle', 'score': 0.07},\n", " {'name': 'General', 'score': 0.04},\n", " {'name': 'Health', 'score': 0.03},\n", " {'name': 'Politics', 'score': 0.03},\n", " {'name': 'Economics', 'score': 0.02},\n", " {'name': 'Finance', 'score': 0.02},\n", " {'name': 'Science', 'score': 0.02},\n", " {'name': 'Sports', 'score': 0.02},\n", " {'name': 'Entertainment', 'score': 0.01},\n", " {'name': 'Crime', 'score': 0.0},\n", " {'name': 'Financial Crime', 'score': 0.0},\n", " {'name': 'Weather', 'score': 0.0}]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qwen_themes[12469]" ] }, { "cell_type": "code", "execution_count": 10, "id": "17e1e5db-9fe6-431b-bc58-255e1288bd44", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['Business', 'Travel']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testd[12469]['themes']" ] }, { "cell_type": "code", "execution_count": 11, "id": "1a4fa192-a344-4d7d-bd7b-ff13f723dfd0", "metadata": {}, "outputs": [], "source": [ "from 
sklearn.preprocessing import MultiLabelBinarizer\n", "from sklearn.metrics import (\n", "    accuracy_score,\n", "    hamming_loss,\n", "    precision_score,\n", "    recall_score,\n", "    f1_score,\n", "    jaccard_score,\n", "    classification_report\n", ")\n", "\n", "def print_multilabel_metrics(y_true, y_pred):\n", "    \"\"\"\n", "    y_true: List[List[str]] -> ground truth labels\n", "    y_pred: List[List[str]] -> predicted labels\n", "    \"\"\"\n", "\n", "    # Fit on the union of true and predicted label sets so a predicted label\n", "    # absent from y_true is not silently dropped by transform(y_pred).\n", "    mlb = MultiLabelBinarizer()\n", "    mlb.fit(y_true + y_pred)\n", "    y_true_bin = mlb.transform(y_true)\n", "    y_pred_bin = mlb.transform(y_pred)\n", "\n", "    acc = accuracy_score(y_true_bin, y_pred_bin)\n", "    h_loss = hamming_loss(y_true_bin, y_pred_bin)\n", "    prec_micro = precision_score(y_true_bin, y_pred_bin, average=\"micro\", zero_division=0)\n", "    rec_micro = recall_score(y_true_bin, y_pred_bin, average=\"micro\", zero_division=0)\n", "    f1_micro = f1_score(y_true_bin, y_pred_bin, average=\"micro\", zero_division=0)\n", "    jacc = jaccard_score(y_true_bin, y_pred_bin, average=\"samples\", zero_division=0)\n", "\n", "    print(\"Accuracy:\", acc)\n", "    print(\"Hamming Loss:\", h_loss)\n", "    print(\"Precision (micro):\", prec_micro)\n", "    print(\"Recall (micro):\", rec_micro)\n", "    print(\"F1-Score (micro):\", f1_micro)\n", "    print(\"Jaccard Similarity (samples avg):\", jacc)\n", "    print(\"\\nClassification Report:\")\n", "    print(classification_report(y_true_bin, y_pred_bin, target_names=mlb.classes_, zero_division=0))\n" ] }, { "cell_type": "markdown", "id": "f517dcbf-8a64-49e6-a9b1-754353548ca8", "metadata": {}, "source": [ "# score >= 0.1" ] }, { "cell_type": "code", "execution_count": 12, "id": "2e071073-bd2b-4ad9-b0dd-00c9ee4dc756", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.230875\n", "Hamming Loss: 0.113875\n", "Precision (micro): 0.4182678401718937\n", "Recall (micro): 0.9531544256120528\n", "F1-Score (micro): 0.5814020275121334\n", "Jaccard Similarity (samples avg): 0.524701389322483\n", "\n", 
"Classification Report:\n", " precision recall f1-score support\n", "\n", " Automotive 0.52 0.95 0.67 1014\n", " Business 0.39 0.97 0.56 2698\n", " Crime 0.37 0.96 0.53 1337\n", " Economics 0.31 0.97 0.47 1000\n", " Entertainment 0.45 0.97 0.61 1161\n", " Finance 0.37 0.97 0.54 1371\n", "Financial Crime 0.49 0.94 0.64 1009\n", " General 0.26 0.89 0.41 1065\n", " Health 0.44 0.95 0.61 1295\n", " Lifestyle 0.29 0.91 0.44 1023\n", " Politics 0.48 0.96 0.64 2414\n", " Science 0.45 0.96 0.61 1056\n", " Sports 0.64 0.96 0.77 1118\n", " Tech 0.46 0.97 0.62 1476\n", " Travel 0.47 0.95 0.63 1143\n", " Weather 0.63 0.94 0.75 1060\n", "\n", " micro avg 0.42 0.95 0.58 21240\n", " macro avg 0.44 0.95 0.59 21240\n", " weighted avg 0.44 0.95 0.59 21240\n", " samples avg 0.53 0.95 0.64 21240\n", "\n" ] } ], "source": [ "all_labels = [sorted(i['themes']) for i in testd]\n", "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= 0.1))) for i in qwen_themes]\n", "print_multilabel_metrics(all_labels, y_pred)" ] }, { "cell_type": "markdown", "id": "8746d0e7-8897-4d1f-8a46-1fccd618d198", "metadata": {}, "source": [ "# score with ML detected" ] }, { "cell_type": "code", "execution_count": 13, "id": "63d3005f-df66-4205-ba00-31d60d109930", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.5604375\n", "Hamming Loss: 0.04196484375\n", "Precision (micro): 0.727671018956318\n", "Recall (micro): 0.789783427495292\n", "F1-Score (micro): 0.7574560314270878\n", "Jaccard Similarity (samples avg): 0.712795386904762\n", "\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " Automotive 0.82 0.81 0.82 1014\n", " Business 0.70 0.74 0.72 2698\n", " Crime 0.72 0.73 0.72 1337\n", " Economics 0.62 0.79 0.70 1000\n", " Entertainment 0.79 0.83 0.81 1161\n", " Finance 0.70 0.81 0.75 1371\n", "Financial Crime 0.74 0.78 0.76 1009\n", " General 0.47 0.67 0.55 1065\n", " Health 0.80 0.78 0.79 1295\n", " Lifestyle 0.70 0.63 0.66 
1023\n", " Politics 0.74 0.86 0.80 2414\n", " Science 0.72 0.79 0.76 1056\n", " Sports 0.89 0.88 0.89 1118\n", " Tech 0.78 0.80 0.79 1476\n", " Travel 0.80 0.80 0.80 1143\n", " Weather 0.75 0.89 0.81 1060\n", "\n", " micro avg 0.73 0.79 0.76 21240\n", " macro avg 0.73 0.79 0.76 21240\n", " weighted avg 0.73 0.79 0.76 21240\n", " samples avg 0.76 0.82 0.76 21240\n", "\n" ] } ], "source": [ "threshold = {'Automotive': 0.5, 'Business': 0.6, 'Crime': 0.65, 'Economics': 0.55, 'Entertainment': 0.5, 'Finance': 0.5, 'Financial Crime': 0.45, 'General': 0.5, 'Health': 0.6, 'Lifestyle': 0.75, 'Politics': 0.45, 'Science': 0.5, 'Sports': 0.5, 'Tech': 0.6, 'Travel': 0.55, 'Weather': 0.2}\n", "all_labels = [sorted(i['themes']) for i in testd]\n", "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= threshold[j['name']]))) for i in qwen_themes]\n", "print_multilabel_metrics(all_labels, y_pred)" ] }, { "cell_type": "markdown", "id": "57e004e7-a815-465f-9991-efde869338c4", "metadata": {}, "source": [ "# score >= 0.5" ] }, { "cell_type": "code", "execution_count": 14, "id": "ddb9dfdc-5cba-4d9d-bdf6-fa80ea157bf6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.5435\n", "Hamming Loss: 0.04390625\n", "Precision (micro): 0.7069707757264674\n", "Recall (micro): 0.8040960451977401\n", "F1-Score (micro): 0.7524120005286576\n", "Jaccard Similarity (samples avg): 0.7082564100829726\n", "\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " Automotive 0.82 0.81 0.82 1014\n", " Business 0.66 0.79 0.72 2698\n", " Crime 0.64 0.80 0.71 1337\n", " Economics 0.60 0.81 0.69 1000\n", " Entertainment 0.79 0.83 0.81 1161\n", " Finance 0.70 0.81 0.75 1371\n", "Financial Crime 0.76 0.76 0.76 1009\n", " General 0.47 0.67 0.55 1065\n", " Health 0.76 0.82 0.79 1295\n", " Lifestyle 0.56 0.77 0.65 1023\n", " Politics 0.77 0.84 0.80 2414\n", " Science 0.72 0.79 0.76 1056\n", " Sports 0.89 0.88 0.89 1118\n", " Tech 0.74 0.84 
0.78 1476\n", " Travel 0.78 0.81 0.80 1143\n", " Weather 0.86 0.76 0.81 1060\n", "\n", " micro avg 0.71 0.80 0.75 21240\n", " macro avg 0.72 0.80 0.75 21240\n", " weighted avg 0.72 0.80 0.76 21240\n", " samples avg 0.75 0.83 0.76 21240\n", "\n" ] } ], "source": [ "all_labels = [sorted(i['themes']) for i in testd]\n", "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= 0.5))) for i in qwen_themes]\n", "print_multilabel_metrics(all_labels, y_pred)" ] }, { "cell_type": "markdown", "id": "2ddb7e38-537c-40ae-bd61-10bbaadf82f9", "metadata": {}, "source": [ "# en" ] }, { "cell_type": "code", "execution_count": 15, "id": "9e17f95b-537e-4841-896c-b4472af52b8f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8000\n", "Accuracy: 0.5805\n", "Hamming Loss: 0.03965625\n", "Precision (micro): 0.7507596145498742\n", "Recall (micro): 0.7968303694830923\n", "F1-Score (micro): 0.773109243697479\n", "Jaccard Similarity (samples avg): 0.7310900793650794\n", "\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " Automotive 0.83 0.77 0.80 508\n", " Business 0.73 0.78 0.75 1492\n", " Crime 0.74 0.73 0.74 667\n", " Economics 0.69 0.76 0.72 500\n", " Entertainment 0.80 0.83 0.82 586\n", " Finance 0.70 0.79 0.75 679\n", "Financial Crime 0.73 0.88 0.80 506\n", " General 0.53 0.64 0.58 540\n", " Health 0.83 0.82 0.83 670\n", " Lifestyle 0.72 0.68 0.70 509\n", " Politics 0.78 0.86 0.82 1218\n", " Science 0.76 0.77 0.76 538\n", " Sports 0.90 0.89 0.90 570\n", " Tech 0.78 0.80 0.79 771\n", " Travel 0.80 0.79 0.79 570\n", " Weather 0.73 0.90 0.80 529\n", "\n", " micro avg 0.75 0.80 0.77 10853\n", " macro avg 0.75 0.79 0.77 10853\n", " weighted avg 0.75 0.80 0.77 10853\n", " samples avg 0.78 0.83 0.78 10853\n", "\n" ] } ], "source": [ "ends = [i for i in testd if i['language'] == 'en']\n", "qwen_themes = get_multilabel_themes([i['embedding'] for i in ends])\n", "all_labels = [sorted(i['themes']) for i in ends]\n", "y_pred 
= [sorted(list(set(j['name'] for j in i if j['score'] >= threshold[j['name']]))) for i in qwen_themes]\n", "print(len(ends))\n", "print_multilabel_metrics(all_labels, y_pred)" ] }, { "cell_type": "markdown", "id": "a2a5a8f8-399c-4191-95ed-0f594fb82e1d", "metadata": {}, "source": [ "# non-en" ] }, { "cell_type": "code", "execution_count": 16, "id": "995c5397-9b2e-45d5-871a-0ff4d27e295c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8000\n", "Accuracy: 0.540375\n", "Hamming Loss: 0.0442734375\n", "Precision (micro): 0.7046124501473904\n", "Recall (micro): 0.7824203331086935\n", "F1-Score (micro): 0.7414807718625975\n", "Jaccard Similarity (samples avg): 0.6945006944444444\n", "\n", "Classification Report:\n", " precision recall f1-score support\n", "\n", " Automotive 0.81 0.86 0.84 506\n", " Business 0.67 0.69 0.68 1206\n", " Crime 0.69 0.73 0.71 670\n", " Economics 0.57 0.82 0.67 500\n", " Entertainment 0.78 0.83 0.81 575\n", " Finance 0.69 0.83 0.75 692\n", "Financial Crime 0.74 0.68 0.71 503\n", " General 0.43 0.70 0.53 525\n", " Health 0.77 0.73 0.75 625\n", " Lifestyle 0.69 0.57 0.62 514\n", " Politics 0.71 0.86 0.78 1196\n", " Science 0.69 0.82 0.75 518\n", " Sports 0.88 0.87 0.88 548\n", " Tech 0.78 0.81 0.79 705\n", " Travel 0.80 0.81 0.80 573\n", " Weather 0.76 0.89 0.82 531\n", "\n", " micro avg 0.70 0.78 0.74 10387\n", " macro avg 0.72 0.78 0.74 10387\n", " weighted avg 0.72 0.78 0.74 10387\n", " samples avg 0.74 0.80 0.75 10387\n", "\n" ] } ], "source": [ "ends = [i for i in testd if i['language'] != 'en']\n", "qwen_themes = get_multilabel_themes([i['embedding'] for i in ends])\n", "all_labels = [sorted(i['themes']) for i in ends]\n", "y_pred = [sorted(list(set(j['name'] for j in i if j['score'] >= threshold[j['name']]))) for i in qwen_themes]\n", "print(len(ends))\n", "print_multilabel_metrics(all_labels, y_pred)" ] }, { "cell_type": "code", "execution_count": null, "id": "796f9191-7771-409d-a0c3-d7a88d83485f", 
"metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }