Tbruand committed on
Commit ·
e5bc494
1
Parent(s): ad6dade
feat(notebook): ajout du notebook d'entraînement CamemBERT (fine-tuning)
Browse files
notebooks/02_train_camenbert.ipynb
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "e63f42f0-e971-4017-8550-21fdcfc2de11",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.7.0)\n",
|
| 14 |
+
"Requirement already satisfied: numpy>=1.22.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.24.1)\n",
|
| 15 |
+
"Requirement already satisfied: scipy>=1.8.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.15.3)\n",
|
| 16 |
+
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.5.1)\n",
|
| 17 |
+
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.6.0)\n",
|
| 18 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 19 |
+
"\u001b[0m\n",
|
| 20 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
|
| 21 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
| 22 |
+
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (2.3.0)\n",
|
| 23 |
+
"Requirement already satisfied: numpy>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from pandas) (1.24.1)\n",
|
| 24 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas) (2.8.2)\n",
|
| 25 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2025.2)\n",
|
| 26 |
+
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas) (2025.2)\n",
|
| 27 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
|
| 28 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 29 |
+
"\u001b[0m\n",
|
| 30 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
|
| 31 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
| 32 |
+
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.67.1)\n",
|
| 33 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 34 |
+
"\u001b[0m\n",
|
| 35 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
|
| 36 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
| 37 |
+
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.2.0)\n",
|
| 38 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 39 |
+
"\u001b[0m\n",
|
| 40 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
|
| 41 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
| 42 |
+
"Looking in indexes: https://download.pytorch.org/whl/cu118\n",
|
| 43 |
+
"Collecting torch==2.0.1\n",
|
| 44 |
+
" Using cached https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp310-cp310-linux_x86_64.whl (2267.3 MB)\n",
|
| 45 |
+
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.16.0+cu118)\n",
|
| 46 |
+
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.1.0+cu118)\n",
|
| 47 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1) (3.9.0)\n",
|
| 48 |
+
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1) (4.4.0)\n",
|
| 49 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1) (1.12)\n",
|
| 50 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1) (3.0)\n",
|
| 51 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1) (3.1.2)\n",
|
| 52 |
+
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.0.1) (2.0.0)\n",
|
| 53 |
+
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch==2.0.1) (4.0.3)\n",
|
| 54 |
+
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch==2.0.1) (18.1.8)\n",
|
| 55 |
+
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.24.1)\n",
|
| 56 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchvision) (2.31.0)\n",
|
| 57 |
+
"INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.\n",
|
| 58 |
+
"Collecting torchvision\n",
|
| 59 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.22.1%2Bcu118-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n",
|
| 60 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.22.0%2Bcu118-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (6.1 kB)\n",
|
| 61 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.21.0%2Bcu118-cp310-cp310-linux_x86_64.whl.metadata (6.1 kB)\n",
|
| 62 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.20.1%2Bcu118-cp310-cp310-linux_x86_64.whl (6.5 MB)\n",
|
| 63 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.20.0%2Bcu118-cp310-cp310-linux_x86_64.whl (6.5 MB)\n",
|
| 64 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.19.1%2Bcu118-cp310-cp310-linux_x86_64.whl (6.3 MB)\n",
|
| 65 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.19.0%2Bcu118-cp310-cp310-linux_x86_64.whl (6.3 MB)\n",
|
| 66 |
+
"INFO: pip is still looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.\n",
|
| 67 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.18.1%2Bcu118-cp310-cp310-linux_x86_64.whl (6.3 MB)\n",
|
| 68 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.18.0%2Bcu118-cp310-cp310-linux_x86_64.whl (6.3 MB)\n",
|
| 69 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.17.2%2Bcu118-cp310-cp310-linux_x86_64.whl (6.2 MB)\n",
|
| 70 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.17.1%2Bcu118-cp310-cp310-linux_x86_64.whl (6.2 MB)\n",
|
| 71 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.17.0%2Bcu118-cp310-cp310-linux_x86_64.whl (6.2 MB)\n",
|
| 72 |
+
"INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.\n",
|
| 73 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.16.2%2Bcu118-cp310-cp310-linux_x86_64.whl (6.1 MB)\n",
|
| 74 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.16.1%2Bcu118-cp310-cp310-linux_x86_64.whl (6.1 MB)\n",
|
| 75 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-linux_x86_64.whl (6.1 MB)\n",
|
| 76 |
+
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.3.0)\n",
|
| 77 |
+
"INFO: pip is looking at multiple versions of torchaudio to determine which version is compatible with other requirements. This could take a while.\n",
|
| 78 |
+
"Collecting torchaudio\n",
|
| 79 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.7.1%2Bcu118-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (6.6 kB)\n",
|
| 80 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.7.0%2Bcu118-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (6.6 kB)\n",
|
| 81 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.6.0%2Bcu118-cp310-cp310-linux_x86_64.whl.metadata (6.6 kB)\n",
|
| 82 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.5.1%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 83 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.5.0%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 84 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.4.1%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 85 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.4.0%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 86 |
+
"INFO: pip is still looking at multiple versions of torchaudio to determine which version is compatible with other requirements. This could take a while.\n",
|
| 87 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 88 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.3.0%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 89 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.2.2%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 90 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.2.1%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 91 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.2.0%2Bcu118-cp310-cp310-linux_x86_64.whl (3.3 MB)\n",
|
| 92 |
+
"INFO: This is taking longer than usual. You might need to provide the dependency resolver with stricter constraints to reduce runtime. See https://pip.pypa.io/warnings/backtracking for guidance. If you want to abort this run, press Ctrl + C.\n",
|
| 93 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.1.2%2Bcu118-cp310-cp310-linux_x86_64.whl (3.2 MB)\n",
|
| 94 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.1.1%2Bcu118-cp310-cp310-linux_x86_64.whl (3.2 MB)\n",
|
| 95 |
+
" Using cached https://download.pytorch.org/whl/cu118/torchaudio-2.0.2%2Bcu118-cp310-cp310-linux_x86_64.whl (4.4 MB)\n",
|
| 96 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.0.1) (2.1.2)\n",
|
| 97 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2.1.1)\n",
|
| 98 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (3.4)\n",
|
| 99 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (1.26.13)\n",
|
| 100 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchvision) (2022.12.7)\n",
|
| 101 |
+
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.0.1) (1.3.0)\n",
|
| 102 |
+
"Installing collected packages: torch, torchvision, torchaudio\n",
|
| 103 |
+
" Attempting uninstall: torch\n",
|
| 104 |
+
" Found existing installation: torch 2.1.0+cu118\n",
|
| 105 |
+
" Uninstalling torch-2.1.0+cu118:\n",
|
| 106 |
+
" Successfully uninstalled torch-2.1.0+cu118\n",
|
| 107 |
+
" Rolling back uninstall of torch\n",
|
| 108 |
+
" Moving to /usr/local/bin/convert-caffe2-to-onnx\n",
|
| 109 |
+
" from /tmp/pip-uninstall-ior9qvf0/convert-caffe2-to-onnx\n",
|
| 110 |
+
" Moving to /usr/local/bin/convert-onnx-to-caffe2\n",
|
| 111 |
+
" from /tmp/pip-uninstall-ior9qvf0/convert-onnx-to-caffe2\n",
|
| 112 |
+
" Moving to /usr/local/bin/torchrun\n",
|
| 113 |
+
" from /tmp/pip-uninstall-ior9qvf0/torchrun\n",
|
| 114 |
+
" Moving to /usr/local/lib/python3.10/dist-packages/functorch/\n",
|
| 115 |
+
" from /usr/local/lib/python3.10/dist-packages/~unctorch\n",
|
| 116 |
+
" Moving to /usr/local/lib/python3.10/dist-packages/nvfuser/\n",
|
| 117 |
+
" from /usr/local/lib/python3.10/dist-packages/~vfuser\n",
|
| 118 |
+
" Moving to /usr/local/lib/python3.10/dist-packages/torch-2.1.0+cu118.dist-info/\n",
|
| 119 |
+
" from /usr/local/lib/python3.10/dist-packages/~orch-2.1.0+cu118.dist-info\n",
|
| 120 |
+
" Moving to /usr/local/lib/python3.10/dist-packages/torch/\n",
|
| 121 |
+
" from /usr/local/lib/python3.10/dist-packages/~orch\n",
|
| 122 |
+
" Moving to /usr/local/lib/python3.10/dist-packages/torchgen/\n",
|
| 123 |
+
" from /usr/local/lib/python3.10/dist-packages/~orchgen\n",
|
| 124 |
+
"\u001b[31mERROR: Could not install packages due to an OSError: [Errno 28] No space left on device\n",
|
| 125 |
+
"\u001b[0m\u001b[31m\n",
|
| 126 |
+
"\u001b[0m\n",
|
| 127 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
|
| 128 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
| 129 |
+
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.52.4)\n",
|
| 130 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.9.0)\n",
|
| 131 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.33.0)\n",
|
| 132 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.24.1)\n",
|
| 133 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
|
| 134 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
|
| 135 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.11.6)\n",
|
| 136 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
|
| 137 |
+
"Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.21.1)\n",
|
| 138 |
+
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.5.3)\n",
|
| 139 |
+
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.67.1)\n",
|
| 140 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (2025.5.1)\n",
|
| 141 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (4.4.0)\n",
|
| 142 |
+
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (1.1.3)\n",
|
| 143 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.1.1)\n",
|
| 144 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
|
| 145 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.13)\n",
|
| 146 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n",
|
| 147 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 148 |
+
"\u001b[0m\n",
|
| 149 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
|
| 150 |
+
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
|
| 151 |
+
]
|
| 152 |
+
}
|
| 153 |
+
],
|
| 154 |
+
"source": [
|
| 155 |
+
"!pip install scikit-learn\n",
|
| 156 |
+
"!pip install pandas\n",
|
| 157 |
+
"!pip install tqdm\n",
|
| 158 |
+
"!pip install sentencepiece\n",
|
| 159 |
+
"!pip install torch==2.0.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
|
| 160 |
+
"!pip install --upgrade transformers"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": 2,
|
| 166 |
+
"id": "7db96f7b-0cd3-4710-93d8-391622e60c25",
|
| 167 |
+
"metadata": {},
|
| 168 |
+
"outputs": [],
|
| 169 |
+
"source": [
|
| 170 |
+
"import os\n",
|
| 171 |
+
"import json\n",
|
| 172 |
+
"import pandas as pd\n",
|
| 173 |
+
"import torch\n",
|
| 174 |
+
"from torch.optim import AdamW\n",
|
| 175 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 176 |
+
"from torch.nn import CrossEntropyLoss\n",
|
| 177 |
+
"from transformers import CamembertTokenizer, CamembertForSequenceClassification, get_scheduler\n",
|
| 178 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 179 |
+
"from sklearn.metrics import classification_report\n",
|
| 180 |
+
"from tqdm import tqdm"
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"cell_type": "code",
|
| 185 |
+
"execution_count": 3,
|
| 186 |
+
"id": "338ab5df-bc6d-4f2c-8a5e-bc9bb7b658f1",
|
| 187 |
+
"metadata": {},
|
| 188 |
+
"outputs": [],
|
| 189 |
+
"source": [
|
| 190 |
+
"# ─────────────────────────────────────────────\n",
|
| 191 |
+
"# ⚙️ Config\n",
|
| 192 |
+
"# ─────────────────────────────────────────────\n",
|
| 193 |
+
"DEBUG = False\n",
|
| 194 |
+
"BATCH_SIZE = 64\n",
|
| 195 |
+
"EPOCHS = 3 if not DEBUG else 1\n",
|
| 196 |
+
"MAX_LEN = 128\n",
|
| 197 |
+
"LR = 2e-5\n",
|
| 198 |
+
"PATIENCE = 2 # pour l'early stopping"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": 4,
|
| 204 |
+
"id": "3ccb7df2-77fc-461b-ac1e-05f1d8be7ed0",
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [
|
| 207 |
+
{
|
| 208 |
+
"name": "stdout",
|
| 209 |
+
"output_type": "stream",
|
| 210 |
+
"text": [
|
| 211 |
+
"Classes : df_labels\n",
|
| 212 |
+
"0 189412\n",
|
| 213 |
+
"1 33982\n",
|
| 214 |
+
"Name: count, dtype: int64\n"
|
| 215 |
+
]
|
| 216 |
+
}
|
| 217 |
+
],
|
| 218 |
+
"source": [
|
| 219 |
+
"# ─────────────────────────────────────────────\n",
|
| 220 |
+
"# 📁 Chargement du dataset\n",
|
| 221 |
+
"# ─────────────────────────────────────────────\n",
|
| 222 |
+
"df = pd.read_csv(\"jigsaw-toxic-comment-train-google-fr-cleaned.csv\")\n",
|
| 223 |
+
"df['comment_text'] = df['comment_text'].astype(str)\n",
|
| 224 |
+
"df.rename(columns={'comment_text': 'texts'}, inplace=True)\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']\n",
|
| 227 |
+
"other_cols_to_drop = ['Unnamed: 0.1', 'Unnamed: 0', 'id']\n",
|
| 228 |
+
"cols_to_drop = label_cols + other_cols_to_drop\n",
|
| 229 |
+
"\n",
|
| 230 |
+
"df['df_labels'] = df[label_cols].max(axis=1)\n",
|
| 231 |
+
"df = df.drop(columns=cols_to_drop)\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"# Debug : sous-échantillonnage équilibré\n",
|
| 234 |
+
"if DEBUG:\n",
|
| 235 |
+
" df_0 = df[df[\"df_labels\"] == 0].sample(500, random_state=42)\n",
|
| 236 |
+
" df_1 = df[df[\"df_labels\"] == 1].sample(500, random_state=42)\n",
|
| 237 |
+
" df = pd.concat([df_0, df_1]).sample(frac=1, random_state=42)\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"print(\"Classes :\", df['df_labels'].value_counts())"
|
| 240 |
+
]
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"cell_type": "code",
|
| 244 |
+
"execution_count": 5,
|
| 245 |
+
"id": "83c4cf79-57d9-4d54-8d8f-0e4249b8a930",
|
| 246 |
+
"metadata": {},
|
| 247 |
+
"outputs": [],
|
| 248 |
+
"source": [
|
| 249 |
+
"# ─────────────────────────────────────────────\n",
|
| 250 |
+
"# 🔢 Dataset\n",
|
| 251 |
+
"# ─────────────────────────────────────────────\n",
|
| 252 |
+
"tokenizer = CamembertTokenizer.from_pretrained(\"camembert-base\")\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"class CommentDataset(Dataset):\n",
|
| 255 |
+
" def __init__(self, texts, labels, tokenizer, max_len):\n",
|
| 256 |
+
" self.texts = texts\n",
|
| 257 |
+
" self.labels = labels\n",
|
| 258 |
+
" self.tokenizer = tokenizer\n",
|
| 259 |
+
" self.max_len = max_len\n",
|
| 260 |
+
"\n",
|
| 261 |
+
" def __len__(self):\n",
|
| 262 |
+
" return len(self.texts)\n",
|
| 263 |
+
"\n",
|
| 264 |
+
" def __getitem__(self, idx):\n",
|
| 265 |
+
" encoding = self.tokenizer(\n",
|
| 266 |
+
" self.texts[idx],\n",
|
| 267 |
+
" padding=\"max_length\",\n",
|
| 268 |
+
" truncation=True,\n",
|
| 269 |
+
" max_length=self.max_len,\n",
|
| 270 |
+
" return_tensors=\"pt\"\n",
|
| 271 |
+
" )\n",
|
| 272 |
+
" item = {key: val.squeeze() for key, val in encoding.items()}\n",
|
| 273 |
+
" item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)\n",
|
| 274 |
+
" return item\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"# Split\n",
|
| 277 |
+
"X_train, X_val, y_train, y_val = train_test_split(df[\"texts\"].tolist(), df[\"df_labels\"].tolist(), test_size=0.2, random_state=42)\n",
|
| 278 |
+
"\n",
|
| 279 |
+
"train_dataset = CommentDataset(X_train, y_train, tokenizer, MAX_LEN)\n",
|
| 280 |
+
"val_dataset = CommentDataset(X_val, y_val, tokenizer, MAX_LEN)\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
|
| 283 |
+
"val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"cell_type": "code",
|
| 288 |
+
"execution_count": 6,
|
| 289 |
+
"id": "05e8419e-dbeb-42ba-b9ed-ea099e96244a",
|
| 290 |
+
"metadata": {},
|
| 291 |
+
"outputs": [
|
| 292 |
+
{
|
| 293 |
+
"name": "stderr",
|
| 294 |
+
"output_type": "stream",
|
| 295 |
+
"text": [
|
| 296 |
+
"Some weights of CamembertForSequenceClassification were not initialized from the model checkpoint at camembert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
|
| 297 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
| 298 |
+
]
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"name": "stdout",
|
| 302 |
+
"output_type": "stream",
|
| 303 |
+
"text": [
|
| 304 |
+
"Poids pour la loss : tensor([1.0000, 5.5739])\n"
|
| 305 |
+
]
|
| 306 |
+
}
|
| 307 |
+
],
|
| 308 |
+
"source": [
|
| 309 |
+
"# ─────────────────────────────────────────────\n",
|
| 310 |
+
"# 🧠 Modèle + loss pondérée\n",
|
| 311 |
+
"# ─────────────────────────────────────────────\n",
|
| 312 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 313 |
+
"model = CamembertForSequenceClassification.from_pretrained(\"camembert-base\", num_labels=2).to(device)\n",
|
| 314 |
+
"\n",
|
| 315 |
+
"# pondération dynamique\n",
|
| 316 |
+
"if DEBUG:\n",
|
| 317 |
+
" class_weights = torch.tensor([1.0, 1.0], dtype=torch.float)\n",
|
| 318 |
+
"else:\n",
|
| 319 |
+
" count_0 = df[df[\"df_labels\"] == 0].shape[0]\n",
|
| 320 |
+
" count_1 = df[df[\"df_labels\"] == 1].shape[0]\n",
|
| 321 |
+
" class_weights = torch.tensor([1.0, count_0 / count_1], dtype=torch.float)\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"print(f\"Poids pour la loss : {class_weights}\")\n",
|
| 324 |
+
"loss_fn = CrossEntropyLoss(weight=class_weights.to(device))\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"# Optimiseur et scheduler\n",
|
| 327 |
+
"optimizer = AdamW(model.parameters(), lr=LR)\n",
|
| 328 |
+
"scheduler = get_scheduler(\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * EPOCHS)\n"
|
| 329 |
+
]
|
| 330 |
+
},
|
| 331 |
+
{
|
| 332 |
+
"cell_type": "code",
|
| 333 |
+
"execution_count": 7,
|
| 334 |
+
"id": "0b181b9f-64a6-479f-b585-221a598cded6",
|
| 335 |
+
"metadata": {},
|
| 336 |
+
"outputs": [
|
| 337 |
+
{
|
| 338 |
+
"name": "stdout",
|
| 339 |
+
"output_type": "stream",
|
| 340 |
+
"text": [
|
| 341 |
+
"\n",
|
| 342 |
+
"🌟 Epoch 1/3\n"
|
| 343 |
+
]
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"name": "stderr",
|
| 347 |
+
"output_type": "stream",
|
| 348 |
+
"text": [
|
| 349 |
+
"Entraînement: 100%|██████████| 2793/2793 [18:23<00:00, 2.53it/s]\n"
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"name": "stdout",
|
| 354 |
+
"output_type": "stream",
|
| 355 |
+
"text": [
|
| 356 |
+
"📉 Loss moyenne : 0.5043\n"
|
| 357 |
+
]
|
| 358 |
+
},
|
| 359 |
+
{
|
| 360 |
+
"name": "stderr",
|
| 361 |
+
"output_type": "stream",
|
| 362 |
+
"text": [
|
| 363 |
+
"Évaluation: 100%|██████████| 699/699 [01:50<00:00, 6.32it/s]\n"
|
| 364 |
+
]
|
| 365 |
+
},
|
| 366 |
+
{
|
| 367 |
+
"name": "stdout",
|
| 368 |
+
"output_type": "stream",
|
| 369 |
+
"text": [
|
| 370 |
+
"🎯 F1-score (weighted) : 0.8826\n",
|
| 371 |
+
"✅ Nouveau meilleur modèle — sauvegarde manuelle...\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"🌟 Epoch 2/3\n"
|
| 374 |
+
]
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"name": "stderr",
|
| 378 |
+
"output_type": "stream",
|
| 379 |
+
"text": [
|
| 380 |
+
"Entraînement: 100%|██████████| 2793/2793 [18:26<00:00, 2.53it/s]\n"
|
| 381 |
+
]
|
| 382 |
+
},
|
| 383 |
+
{
|
| 384 |
+
"name": "stdout",
|
| 385 |
+
"output_type": "stream",
|
| 386 |
+
"text": [
|
| 387 |
+
"📉 Loss moyenne : 0.4711\n"
|
| 388 |
+
]
|
| 389 |
+
},
|
| 390 |
+
{
|
| 391 |
+
"name": "stderr",
|
| 392 |
+
"output_type": "stream",
|
| 393 |
+
"text": [
|
| 394 |
+
"Évaluation: 100%|██████████| 699/699 [01:49<00:00, 6.39it/s]\n"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"name": "stdout",
|
| 399 |
+
"output_type": "stream",
|
| 400 |
+
"text": [
|
| 401 |
+
"🎯 F1-score (weighted) : 0.8735\n",
|
| 402 |
+
"⏳ EarlyStopping patience : 1/2\n",
|
| 403 |
+
"\n",
|
| 404 |
+
"🌟 Epoch 3/3\n"
|
| 405 |
+
]
|
| 406 |
+
},
|
| 407 |
+
{
|
| 408 |
+
"name": "stderr",
|
| 409 |
+
"output_type": "stream",
|
| 410 |
+
"text": [
|
| 411 |
+
"Entraînement: 100%|██████████| 2793/2793 [18:26<00:00, 2.52it/s]\n"
|
| 412 |
+
]
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"name": "stdout",
|
| 416 |
+
"output_type": "stream",
|
| 417 |
+
"text": [
|
| 418 |
+
"📉 Loss moyenne : 0.4485\n"
|
| 419 |
+
]
|
| 420 |
+
},
|
| 421 |
+
{
|
| 422 |
+
"name": "stderr",
|
| 423 |
+
"output_type": "stream",
|
| 424 |
+
"text": [
|
| 425 |
+
"Évaluation: 100%|██████████| 699/699 [01:50<00:00, 6.35it/s]\n"
|
| 426 |
+
]
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"name": "stdout",
|
| 430 |
+
"output_type": "stream",
|
| 431 |
+
"text": [
|
| 432 |
+
"🎯 F1-score (weighted) : 0.8816\n",
|
| 433 |
+
"⏳ EarlyStopping patience : 2/2\n",
|
| 434 |
+
"🛑 Arrêt anticipé — pas d'amélioration\n"
|
| 435 |
+
]
|
| 436 |
+
}
|
| 437 |
+
],
|
| 438 |
+
"source": [
# Fine-tuning loop with per-epoch validation, best-model checkpointing and
# early stopping.
#
# Each epoch trains on train_loader, evaluates on val_loader, and compares the
# weighted F1-score with the best one seen so far. On improvement the weights,
# config, tokenizer and classification report are written to disk; otherwise a
# patience counter is incremented and the loop stops once it reaches PATIENCE.
#
# Expects these names from earlier cells: model, tokenizer, optimizer,
# scheduler, loss_fn, train_loader, val_loader, device, EPOCHS, PATIENCE,
# torch, tqdm and classification_report.
import json
import os

best_f1 = 0
patience_counter = 0

# Resolve the output locations once instead of on every improving epoch.
# Creating "outputs/model" also creates "outputs", where the metrics file goes.
save_dir = "outputs/model"
metrics_path = "outputs/metrics.json"
os.makedirs(save_dir, exist_ok=True)

# Drop any gradients left over from an interrupted run of this cell so they
# cannot leak into the first optimizer step.
optimizer.zero_grad()

for epoch in range(EPOCHS):
    print(f"\n🌟 Epoch {epoch + 1}/{EPOCHS}")
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc="Entraînement"):
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = model(**batch).logits
        # The training loss comes from loss_fn, not from the model output.
        loss = loss_fn(logits, batch["labels"])
        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch
        optimizer.zero_grad()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"📉 Loss moyenne : {avg_loss:.4f}")

    # Evaluation on the validation set, without gradient tracking.
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Évaluation"):
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            preds = torch.argmax(logits, dim=1)
            y_true.extend(batch["labels"].cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    report = classification_report(
        y_true,
        y_pred,
        target_names=["Non toxique", "Toxique"],
        output_dict=True,
    )
    # Model selection and early stopping are both driven by the weighted F1.
    f1 = report["weighted avg"]["f1-score"]
    print(f"🎯 F1-score (weighted) : {f1:.4f}")

    if f1 > best_f1:
        best_f1 = f1
        patience_counter = 0
        print("✅ Nouveau meilleur modèle — sauvegarde manuelle...")

        # Manual save of the weights.
        torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))

        # Model configuration.
        model.config.to_json_file(os.path.join(save_dir, "config.json"))

        # Tokenizer files.
        tokenizer.save_pretrained(save_dir)

        # Classification report of the best epoch (overwritten on each improvement).
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=4)

    else:
        patience_counter += 1
        print(f"⏳ EarlyStopping patience : {patience_counter}/{PATIENCE}")
        if patience_counter >= PATIENCE:
            print("🛑 Arrêt anticipé — pas d'amélioration")
            break

# NOTE: nothing is reloaded after the loop, so the in-memory `model` holds the
# weights of the last epoch that ran; the best checkpoint is the one on disk
# under save_dir.
|
| 506 |
+
]
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"cell_type": "code",
|
| 510 |
+
"execution_count": 8,
|
| 511 |
+
"id": "ba6f2d7c-0daf-48db-96a2-33935dca1d9e",
|
| 512 |
+
"metadata": {},
|
| 513 |
+
"outputs": [
|
| 514 |
+
{
|
| 515 |
+
"name": "stdout",
|
| 516 |
+
"output_type": "stream",
|
| 517 |
+
"text": [
|
| 518 |
+
"📊 Métriques sauvegardées :\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"🗂 Classe : Non toxique\n",
|
| 521 |
+
" 🔸 Précision : 0.9294\n",
|
| 522 |
+
" 🔸 Rappel : 0.9329\n",
|
| 523 |
+
" 🔸 F1-score : 0.9312\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"🗂 Classe : Toxique\n",
|
| 526 |
+
" 🔸 Précision : 0.6193\n",
|
| 527 |
+
" 🔸 Rappel : 0.6065\n",
|
| 528 |
+
" 🔸 F1-score : 0.6129\n",
|
| 529 |
+
"\n",
|
| 530 |
+
"🔄 Moyennes pondérées (weighted avg) :\n",
|
| 531 |
+
" ✅ Précision : 0.8821\n",
|
| 532 |
+
" ✅ Rappel : 0.8831\n",
|
| 533 |
+
" ✅ F1-score : 0.8826\n"
|
| 534 |
+
]
|
| 535 |
+
}
|
| 536 |
+
],
|
| 537 |
+
"source": [
|
| 538 |
+
"import json\n",
|
| 539 |
+
"import os\n",
|
| 540 |
+
"\n",
# Location of the classification report written by the training loop.
metrics_path = "outputs/metrics.json"

# Nothing to display when no report has been written yet.
if not os.path.exists(metrics_path):
    print("❌ Aucune métrique trouvée dans outputs/metrics.json")
else:
    with open(metrics_path, "r") as f:
        metrics = json.load(f)

    print("📊 Métriques sauvegardées :\n")

    # Per-class precision / recall / F1, one section per class.
    for label in ("Non toxique", "Toxique"):
        scores = metrics[label]
        print(f"🗂 Classe : {label}")
        print(f" 🔸 Précision : {scores['precision']:.4f}")
        print(f" 🔸 Rappel : {scores['recall']:.4f}")
        print(f" 🔸 F1-score : {scores['f1-score']:.4f}\n")

    # Aggregate scores stored under the report's "weighted avg" entry.
    weighted = metrics["weighted avg"]
    print("🔄 Moyennes pondérées (weighted avg) :")
    print(f" ✅ Précision : {weighted['precision']:.4f}")
    print(f" ✅ Rappel : {weighted['recall']:.4f}")
    print(f" ✅ F1-score : {weighted['f1-score']:.4f}")
|
| 562 |
+
]
|
| 563 |
+
}
|
| 564 |
+
],
|
| 565 |
+
"metadata": {
|
| 566 |
+
"kernelspec": {
|
| 567 |
+
"display_name": "Python 3 (ipykernel)",
|
| 568 |
+
"language": "python",
|
| 569 |
+
"name": "python3"
|
| 570 |
+
},
|
| 571 |
+
"language_info": {
|
| 572 |
+
"codemirror_mode": {
|
| 573 |
+
"name": "ipython",
|
| 574 |
+
"version": 3
|
| 575 |
+
},
|
| 576 |
+
"file_extension": ".py",
|
| 577 |
+
"mimetype": "text/x-python",
|
| 578 |
+
"name": "python",
|
| 579 |
+
"nbconvert_exporter": "python",
|
| 580 |
+
"pygments_lexer": "ipython3",
|
| 581 |
+
"version": "3.10.12"
|
| 582 |
+
}
|
| 583 |
+
},
|
| 584 |
+
"nbformat": 4,
|
| 585 |
+
"nbformat_minor": 5
|
| 586 |
+
}
|