{ "cells": [
 { "cell_type": "markdown", "id": "11ab9cd5-a6e4-416a-b44f-201e8bf8ee84", "metadata": {}, "source": [ "## Test inference" ] },
 { "cell_type": "code", "execution_count": 5, "id": "40523be3-6ec7-4cac-aa90-6b5177c0f07d", "metadata": { "tags": [] }, "outputs": [], "source": [
 "from pdfminer.high_level import extract_text"
 ] },
 { "cell_type": "code", "execution_count": 26, "id": "c0e5cc3f-5a9d-4b0f-8f7c-d46c0f79b5df", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Cannot set gray non-stroke color because /'P3954' is an invalid float value\n" ] } ], "source": [
 "# Pull the raw text out of the sample PDF that the inference test runs on\n",
 "text = extract_text(\n",
 "    \"./example_docs_for_inference/publication_climate.pdf\"\n",
 ")"
 ] },
 { "cell_type": "code", "execution_count": 27, "id": "120528e3-26b9-40ce-ac8c-3c30c3092d28", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ISSN 1831-9424 \n", "\n", "How to plan mitigation, adaptatio\n" ] } ], "source": [
 "# Sanity check: show the first 50 characters of the extracted text\n",
 "print(text[:50])"
 ] },
 { "cell_type": "code", "execution_count": 9, "id": "d191928f-381e-4da3-8342-1300909b52c5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mbarhdadi/projects/training/eurovoc_training_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] },
 { "name": "stdout", "output_type": "stream", "text": [ "Model loaded. 
Ready to predict 6958 eurovoc labels.\n" ] } ], "source": [
 "import pickle\n",
 "from transformers import AutoTokenizer, AutoModel\n",
 "from eurovoc import EurovocTagger\n",
 "\n",
 "# Load MLBinarizer\n",
 "# SECURITY: pickle.load can execute arbitrary code - only load files you created or trust.\n",
 "with open('./models_finetuned/latest/mlb.pickle', 'rb') as f:\n",
 "    mlb = pickle.load(f)\n",
 "\n",
 "# Load tokenizer\n",
 "BERT_MODEL_NAME = \"nlpaueb/legal-bert-base-uncased\"\n",
 "tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)\n",
 "\n",
 "# Load trained model\n",
 "checkpoint_path = \"./models_finetuned/latest/EurovocTaggerFP32-epoch=04-val_loss=0.00.ckpt\"\n",
 "model = EurovocTagger.load_from_checkpoint(\n",
 "    checkpoint_path,\n",
 "    bert_model_name=BERT_MODEL_NAME,\n",
 "    n_classes=len(mlb.classes_)\n",
 ")\n",
 "\n",
 "\n",
 "print(f\"Model loaded. Ready to predict {len(mlb.classes_)} eurovoc labels.\")"
 ] },
 { "cell_type": "code", "execution_count": 15, "id": "7a1fd7e6-e14d-4c24-97ae-abcd5a30ab71", "metadata": { "tags": [] }, "outputs": [], "source": [
 "def get_eurovoc_id_to_term_mapping():\n",
 "    \"\"\"\n",
 "    Create a mapping from eurovoc IDs to their human-readable terms.\n",
 "\n",
 "    Downloads the Eurovoc dataset document from publications.europa.eu as XML\n",
 "    and reads one entry per `xs:enumeration` element. Entries that cannot be\n",
 "    parsed are reported with a printed warning and skipped.\n",
 "\n",
 "    Returns:\n",
 "        Dict mapping eurovoc_id -> {'original': term_name,\n",
 "                                    'lowercase': term_name.lower()}\n",
 "\n",
 "    Raises:\n",
 "        requests.HTTPError: if the server answers with a 4xx/5xx status.\n",
 "        requests.Timeout: if the server does not answer within the timeout.\n",
 "    \"\"\"\n",
 "    import requests\n",
 "    import xmltodict\n",
 "\n",
 "    response = requests.get(\n",
 "        'http://publications.europa.eu/resource/dataset/eurovoc',\n",
 "        headers={\n",
 "            'Accept': 'application/xml',\n",
 "            'Accept-Language': 'en',\n",
 "            'User-Agent': 'Mozilla/5.0'\n",
 "        },\n",
 "        # Without a timeout a stalled connection would block the notebook forever\n",
 "        timeout=60,\n",
 "    )\n",
 "    # Fail loudly on an HTTP error instead of trying to parse an error page as XML\n",
 "    response.raise_for_status()\n",
 "\n",
 "    data = xmltodict.parse(response.content)\n",
 "    terms = data['xs:schema']['xs:simpleType']['xs:restriction']['xs:enumeration']\n",
 "    # xmltodict yields a single dict (not a list) when the element occurs only once\n",
 "    if isinstance(terms, dict):\n",
 "        terms = [terms]\n",
 "\n",
 "    eurovoc_id_to_term = {}\n",
 "    for term in terms:\n",
 "        try:\n",
 "            name = term['xs:annotation']['xs:documentation'].split('/')[0].strip()\n",
 "            eurovoc_id = term['@value'].split(':')[1]\n",
 "        except (KeyError, IndexError, AttributeError, TypeError):\n",
 "            # AttributeError/TypeError cover a documentation node that is not a plain string\n",
 "            print(f\"⚠️ Could not parse term: {term}\")\n",
 "            continue\n",
 "\n",
 "        # Map ID -> term\n",
 "        eurovoc_id_to_term[eurovoc_id] = {\n",
 "            'original': name,\n",
 "            'lowercase': name.lower()\n",
 "        }\n",
 "\n",
 "    print(f\"✓ Loaded {len(eurovoc_id_to_term)} eurovoc terms\")\n",
 "    return eurovoc_id_to_term"
 ] },
 { "cell_type": "code", "execution_count": 23, "id": "d2b703ea-ca41-4353-8776-1a226f02c56b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading Eurovoc terms...\n", "✓ Loaded 7488 eurovoc terms\n" ] } ], "source": [
 "print(\"Loading Eurovoc terms...\")\n",
 "eurovoc_id_to_term = get_eurovoc_id_to_term_mapping()\n"
 ] },
 { "cell_type": "code", "execution_count": 24, "id": "7a5fed81-64e8-4454-a56b-73eb50676b75", "metadata": { "tags": [] }, "outputs": [], "source": [
 "import torch\n",
 "import numpy as np\n",
 "from transformers import AutoTokenizer\n",
 "\n",
 "def predict_eurovoc_labels(text, model, mlb, tokenizer,\n",
 "                           eurovoc_id_to_term=None,\n",
 "                           max_token_len=512,\n",
 "                           threshold=0.5,\n",
 "                           top_k=10,\n",
 "                           device='cuda'):\n",
 "    \"\"\"\n",
 "    Predict Eurovoc labels for one document.\n",
 "\n",
 "    The tokenizer truncates `text` to `max_token_len` tokens, so only the\n",
 "    beginning of a long document is scored.\n",
 "\n",
 "    Args:\n",
 "        text: Document text.\n",
 "        model: Model called as `model(input_ids, attention_mask)`; it must\n",
 "            return a 2-tuple whose second item holds one score per label.\n",
 "        mlb: Object whose `classes_` numpy array lists the label IDs in the\n",
 "            same order as the model's output columns.\n",
 "        tokenizer: Hugging Face tokenizer matching the model.\n",
 "        eurovoc_id_to_term: Optional dict eurovoc_id -> {'original', 'lowercase'}\n",
 "            used to attach readable terms to the predictions.\n",
 "        max_token_len: Length the input is padded/truncated to.\n",
 "        threshold: Minimum score for a label to be listed in 'above_threshold'.\n",
 "        top_k: Number of best-scoring labels listed in 'top_k'\n",
 "            (0 or less yields an empty list).\n",
 "        device: Torch device; falls back to 'cpu' when CUDA is requested\n",
 "            but not available.\n",
 "\n",
 "    Returns:\n",
 "        {'above_threshold': {'predictions': [...], 'count': int},\n",
 "         'top_k': {'predictions': [...]}} where every prediction is a dict\n",
 "        with keys 'eurovoc_id', 'probability', 'term' and 'term_lower'.\n",
 "        'above_threshold' is in label-index order, 'top_k' is sorted by\n",
 "        descending score.\n",
 "    \"\"\"\n",
 "    # Fall back to CPU instead of failing on machines without a usable GPU\n",
 "    if str(device).startswith('cuda') and not torch.cuda.is_available():\n",
 "        print(\"CUDA is not available, running inference on CPU.\")\n",
 "        device = 'cpu'\n",
 "\n",
 "    model.eval()\n",
 "    model.to(device)\n",
 "\n",
 "    # Tokenize (calling the tokenizer replaces the deprecated encode_plus)\n",
 "    encoding = tokenizer(\n",
 "        text,\n",
 "        add_special_tokens=True,\n",
 "        max_length=max_token_len,\n",
 "        return_token_type_ids=False,\n",
 "        padding=\"max_length\",\n",
 "        truncation=True,\n",
 "        return_attention_mask=True,\n",
 "        return_tensors='pt',\n",
 "    )\n",
 "\n",
 "    input_ids = encoding[\"input_ids\"].to(device)\n",
 "    attention_mask = encoding[\"attention_mask\"].to(device)\n",
 "\n",
 "    # Predict\n",
 "    with torch.no_grad():\n",
 "        _, outputs = model(input_ids, attention_mask)\n",
 "\n",
 "    # NOTE(review): the output is used as probabilities as-is, which assumes the\n",
 "    # model's forward() already applies a sigmoid - confirm in EurovocTagger.\n",
 "    probabilities = outputs.cpu().numpy()[0]\n",
 "\n",
 "    # Helper function to enrich labels with terms\n",
 "    def enrich_labels(label_ids, probs):\n",
 "        \"\"\"Add human-readable terms to eurovoc IDs\"\"\"\n",
 "        enriched = []\n",
 "        for label_id, prob in zip(label_ids, probs):\n",
 "            found = bool(eurovoc_id_to_term) and label_id in eurovoc_id_to_term\n",
 "            enriched.append({\n",
 "                'eurovoc_id': label_id,\n",
 "                'probability': float(prob),\n",
 "                # Term fields stay None when no mapping was given or the ID is unknown\n",
 "                'term': eurovoc_id_to_term[label_id]['original'] if found else None,\n",
 "                'term_lower': eurovoc_id_to_term[label_id]['lowercase'] if found else None,\n",
 "            })\n",
 "        return enriched\n",
 "\n",
 "    # Get predictions above threshold\n",
 "    predicted_indices = np.where(probabilities >= threshold)[0]\n",
 "    predicted_labels = mlb.classes_[predicted_indices]\n",
 "    predicted_probs = probabilities[predicted_indices]\n",
 "\n",
 "    # Get top-k predictions. Clamp k first: a slice like [-0:] would return\n",
 "    # every label instead of none when top_k is 0.\n",
 "    k = max(0, min(int(top_k), probabilities.shape[0]))\n",
 "    top_k_indices = np.argsort(probabilities)[::-1][:k]\n",
 "    top_k_labels = mlb.classes_[top_k_indices]\n",
 "    top_k_probs = probabilities[top_k_indices]\n",
 "\n",
 "    return {\n",
 "        'above_threshold': {\n",
 "            'predictions': enrich_labels(predicted_labels, predicted_probs),\n",
 "            'count': len(predicted_labels)\n",
 "        },\n",
 "        'top_k': {\n",
 "            'predictions': enrich_labels(top_k_labels, top_k_probs)\n",
 "        }\n",
 "    }"
 ] },
 { "cell_type": "code", "execution_count": 28, "id": "030b99aa-edc7-472a-8c7f-636a47a9cdce", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Document length: 696483 characters\n", "Truncated to: 2048 tokens (~2048 chars)\n", "\n", "Running inference...\n", "\n", "================================================================================\n", "TOP 15 PREDICTED EUROVOC LABELS (with terms)\n", "================================================================================\n", "642 | energy saving | 0.8567\n", "6700 | energy efficiency | 0.7060\n", "2281 | poverty | 0.4645\n", "5311 | user guide | 0.4198\n", "2498 | energy policy | 0.3545\n", "5482 | climate change | 0.1736\n", "754 | renewable energy | 0.1338\n", "6400 | reduction of gas emissions | 0.1321\n", "2517 | social policy | 0.1260\n", "475 | energy distribution | 0.1253\n", "5188 | information technology | 0.1087\n", "2715 | energy production | 0.1087\n", "2451 | EU policy | 0.0812\n", "4139 
| serial publication | 0.0808\n", "83 | living conditions | 0.0793\n", "\n", "5 labels above threshold (0.3)\n", "\n", "================================================================================\n", "PREDICTIONS ABOVE THRESHOLD (with readable terms)\n", "================================================================================\n", "2281 | poverty | 0.4645\n", "2498 | energy policy | 0.3545\n", "5311 | user guide | 0.4198\n", "642 | energy saving | 0.8567\n", "6700 | energy efficiency | 0.7060\n" ] } ], "source": [
 "# Inference settings, defined once so the call and the printed headers cannot drift apart\n",
 "MAX_TOKEN_LEN = 512\n",
 "THRESHOLD = 0.3\n",
 "TOP_K = 15\n",
 "NO_TERM = \"(term not found)\"\n",
 "\n",
 "print(f\"Document length: {len(text)} characters\")\n",
 "# Only the first MAX_TOKEN_LEN tokens reach the model; ~4 chars per token is a rough estimate\n",
 "print(f\"Truncated to: {MAX_TOKEN_LEN} tokens (~{MAX_TOKEN_LEN * 4} chars)\\n\")\n",
 "\n",
 "print(\"Running inference...\\n\")\n",
 "results = predict_eurovoc_labels(\n",
 "    text=text,\n",
 "    model=model,\n",
 "    mlb=mlb,\n",
 "    tokenizer=tokenizer,\n",
 "    eurovoc_id_to_term=eurovoc_id_to_term,  # adds readable terms to the predicted IDs\n",
 "    max_token_len=MAX_TOKEN_LEN,\n",
 "    threshold=THRESHOLD,\n",
 "    top_k=TOP_K\n",
 ")\n",
 "print(\"=\" * 80)\n",
 "print(f\"TOP {TOP_K} PREDICTED EUROVOC LABELS\")\n",
 "print(\"=\" * 80)\n",
 "\n",
 "for pred in results['top_k']['predictions']:\n",
 "    term = pred['term'] or NO_TERM\n",
 "    print(f\"{pred['eurovoc_id']:15s} | {term:45s} | {pred['probability']:.4f}\")\n",
 "\n",
 "print(f\"\\n{results['above_threshold']['count']} labels above threshold ({THRESHOLD})\")\n",
 "\n",
 "print(\"\\n\" + \"=\" * 80)\n",
 "print(\"PREDICTIONS ABOVE THRESHOLD\")\n",
 "print(\"=\" * 80)\n",
 "\n",
 "# Print every prediction, also those without a known term, so the rows match the count above\n",
 "for pred in results['above_threshold']['predictions']:\n",
 "    term = pred['term'] or NO_TERM\n",
 "    print(f\"{pred['eurovoc_id']:15s} | {term:45s} | {pred['probability']:.4f}\")"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "27ebc73c-5832-4702-bc1e-dd026ebeed02", "metadata": {}, "outputs": [], "source": [] } ],
 "metadata": { "kernelspec": { "display_name": "eurovoc_training_env", "language": "python", "name": "eurovoc_training_env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }