# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 1 — Install packages                                   ║
# ║  RUN THIS CELL ALONE FIRST — it will auto-restart runtime    ║
# ╚══════════════════════════════════════════════════════════════╝
import subprocess, sys, os

def pip(*pkgs):
    """Quietly install the given packages into this kernel's environment."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *pkgs])

# No pinned numpy/pandas — let Colab use its pre-installed compatible versions.
# Pinning old numpy/pandas causes the 'mtrand ABI mismatch' ValueError.
pip(
    # Pinned: a later unpinned install followed by `datasets==2.21.0` produced
    # an fsspec/gcsfs dependency conflict — install the required version once.
    "datasets==2.21.0",
    "transformers>=4.40.0",
    "sentence-transformers>=2.7.0",
    "scikit-learn>=1.4.0",
    "tqdm>=4.66.0",
    "accelerate>=0.26.0",
    "evaluate",
)

print("✅ Packages installed — restarting runtime now …")
os.kill(os.getpid(), 9)  # SIGKILL the kernel so Colab restarts it (~5 s) with fresh packages
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 2 — Mount Drive + GPU check                            ║
# ║  Run AFTER the runtime has restarted                         ║
# ╚══════════════════════════════════════════════════════════════╝
from google.colab import drive
drive.mount("/content/drive")

import torch

# Probe CUDA once and reuse the flag for the prints and the hard check.
has_gpu = torch.cuda.is_available()
print("CUDA available:", has_gpu)
print("Device name: ", torch.cuda.get_device_name(0) if has_gpu else "CPU")
assert has_gpu, "❌ No GPU — set Runtime type to T4 GPU!"
DEVICE = 0
╔══════════════════════════════════════════════════════════════╗\n", "# ║ CELL 3 — Config ║\n", "# ╚══════════════════════════════════════════════════════════════╝\n", "import os\n", "\n", "CNLI_SIZE = 6820 # full ContractNLI train split\n", "MNLI_SIZE = 50000 # pool; genre filter keeps ~8-10k government rows\n", "SYNTH_SIZE = 1000 # synthetic contradiction pairs\n", "\n", "BASE_MODEL = \"typeform/distilbert-base-uncased-mnli\"\n", "OUTPUT_DIR = \"/content/drive/MyDrive/Athernex/nli_contract_model_final\"\n", "EPOCHS = 5\n", "BATCH_SIZE = 32 # T4 handles 32 at max_length=128\n", "LR = 2e-5\n", "MAX_LEN = 128\n", "\n", "LABEL2ID = {\"entailment\": 0, \"contradiction\": 1, \"neutral\": 2}\n", "ID2LABEL = {v: k for k, v in LABEL2ID.items()}\n", "\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", "print(f\"✅ Config ready | output → {OUTPUT_DIR}\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "cell-04-helpers", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ Data helpers defined.\n" ] } ], "source": [ "# ╔══════════════════════════════════════════════════════════════╗\n", "# ║ CELL 4 — Data loading helpers ║\n", "# ╚══════════════════════════════════════════════════════════════╝\n", "import re\n", "import pandas as pd\n", "from datasets import load_dataset\n", "\n", "def clean_clause(text: str) -> str:\n", " text = re.sub(r'\\s+', ' ', text).strip()\n", " text = re.sub(r'[^\\x00-\\x7F]+', '', text)\n", " return text\n", "\n", "def load_contract_nli(split: str = \"train\", size: int = CNLI_SIZE):\n", " \"\"\"Full ContractNLI — kiddothe2b/contract-nli, subset contractnli_a.\"\"\"\n", " slice_str = f\"{split}[:{size}]\" if size else split\n", " return load_dataset(\n", " \"kiddothe2b/contract-nli\", \"contractnli_a\",\n", " split=slice_str, trust_remote_code=True\n", " )\n", "\n", "def process_contract_nli(dataset) -> pd.DataFrame:\n", " \"\"\"kiddothe2b schema: 0=contradiction, 1=entailment, 2=neutral.\"\"\"\n", " label_map 
= {0: \"contradiction\", 1: \"entailment\", 2: \"neutral\"}\n", " records = []\n", " for s in dataset:\n", " p = clean_clause(s[\"premise\"])\n", " h = clean_clause(s[\"hypothesis\"])\n", " if len(p) < 20 or len(h) < 20:\n", " continue\n", " records.append({\"clause1\": p, \"clause2\": h,\n", " \"label\": label_map.get(s[\"label\"], \"neutral\")})\n", " return pd.DataFrame(records)\n", "\n", "def load_mnli_government(split: str = \"train\", size: int = MNLI_SIZE):\n", " \"\"\"MultiNLI filtered to government genre.\"\"\"\n", " if split == \"validation\":\n", " split = \"validation_matched\"\n", " slice_str = f\"{split}[:{size}]\" if size else split\n", " ds = load_dataset(\"nyu-mll/multi_nli\", split=slice_str, trust_remote_code=True)\n", " return ds.filter(lambda x: x[\"genre\"] == \"government\")\n", "\n", "def process_mnli_government(dataset) -> pd.DataFrame:\n", " \"\"\"MultiNLI schema: 0=entailment, 1=neutral, 2=contradiction.\"\"\"\n", " label_map = {0: \"entailment\", 1: \"neutral\", 2: \"contradiction\"}\n", " records = []\n", " for s in dataset:\n", " if not s[\"premise\"] or not s[\"hypothesis\"]:\n", " continue\n", " p = clean_clause(s[\"premise\"])\n", " h = clean_clause(s[\"hypothesis\"])\n", " if len(p) < 20 or len(h) < 20:\n", " continue\n", " records.append({\"clause1\": p, \"clause2\": h,\n", " \"label\": label_map.get(s[\"label\"], \"neutral\")})\n", " return pd.DataFrame(records)\n", "\n", "NEGATION_MAP = {\n", " \"shall\": \"shall not\", \"must\": \"must not\",\n", " \"will\": \"will not\", \"may\": \"may not\",\n", " \"is required to\": \"is not required to\",\n", " \"exclusive\": \"non-exclusive\", \"limited\": \"unlimited\",\n", " \"terminate\": \"not terminate\",\n", "}\n", "\n", "def simulate_contradiction(clause: str):\n", " for term, negated in NEGATION_MAP.items():\n", " if term in clause.lower():\n", " return re.sub(term, negated, clause, count=1, flags=re.IGNORECASE)\n", " return None\n", "\n", "def build_synthetic_pairs(clauses: list, 
sample_size: int = SYNTH_SIZE) -> pd.DataFrame:\n", " import random; random.seed(42)\n", " sampled = random.sample(clauses, min(sample_size, len(clauses)))\n", " records = []\n", " for clause in sampled:\n", " neg = simulate_contradiction(clause)\n", " if neg:\n", " records.append({\"clause1\": clause, \"clause2\": neg, \"label\": \"contradiction\"})\n", " return pd.DataFrame(records)\n", "\n", "print(\"✅ Data helpers defined.\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "cell-05-build", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "=======================================================\n", "BUILDING FULL TRAINING DATA\n", "=======================================================\n", "\n", "[1/3] ContractNLI (size=6820) ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:103: UserWarning: \n", "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. 
Secrets can only be fetched when running from the Colab UI.'.\n", "You are not authenticated with the Hugging Face Hub in this notebook.\n", "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3828de1843b244eab1864397b1be07ec", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/796k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d6569e0dd754662b7e7c7da5704ce01", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/213k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "99d706588b8f46b5bafd97aca5302e57", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/114k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cb547e43501246ab9077fd18cfa7e2ca", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0%| | 0/6819 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3bcc04612f034aa6afd1f4bcfeeb0890", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating test split: 0%| | 0/1991 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "172e8bc3142d4905badd5e7c1b9dc649", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating validation split: 0%| | 0/978 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " → 6819 pairs\n", "label\n", "entailment 3195\n", "neutral 2820\n", "contradiction 804\n", "\n", "[2/3] MultiNLI government (pool=50000) ...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f20dccd17a3648cdb76685916aa5603d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/8.89k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a15041dfe77b4f8fbe005b02daa90f53", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/214M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a3bd313424c04fa69cb074c62aa20e5b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/4.94M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c2711b551a664e6d8e2680dbf96890ac", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/5.10M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "acfa4ba975724a03af7232b72a6d07bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0%| | 0/392702 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f5d309c3902744d0921d0255f760c863", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating validation_matched split: 0%| | 0/9815 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "82e9096c59c24f2a83f580b310a5255a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating validation_mismatched split: 0%| | 0/9832 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "97ee834b40814d619522c5852d240c66", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Filter: 0%| | 0/50000 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " → 9937 pairs\n", "label\n", "contradiction 3526\n", "entailment 3318\n", "neutral 3093\n", "\n", "[3/3] Synthetic contradictions (size=1000) ...\n", " → 842 pairs\n", "\n", "=======================================================\n", "✅ Total training pairs: 17598\n", "label\n", "entailment 6513\n", "neutral 5913\n", "contradiction 5172\n", "=======================================================\n" ] } ], "source": [ "# ╔══════════════════════════════════════════════════════════════╗\n", "# ║ CELL 5 — Build full training dataset ║\n", "# ╚══════════════════════════════════════════════════════════════╝\n", "print(\"\\n\" + \"=\"*55)\n", "print(\"BUILDING FULL TRAINING DATA\")\n", "print(\"=\"*55)\n", "\n", "print(f\"\\n[1/3] ContractNLI (size={CNLI_SIZE}) ...\")\n", "cnli_raw = load_contract_nli(size=CNLI_SIZE)\n", "cnli_df = process_contract_nli(cnli_raw)\n", "print(f\" → {len(cnli_df)} pairs\")\n", "print(cnli_df[\"label\"].value_counts().to_string())\n", "\n", "print(f\"\\n[2/3] MultiNLI government (pool={MNLI_SIZE}) ...\")\n", "mnli_raw = load_mnli_government(size=MNLI_SIZE)\n", "mnli_df = process_mnli_government(mnli_raw)\n", "print(f\" → {len(mnli_df)} pairs\")\n", "print(mnli_df[\"label\"].value_counts().to_string())\n", "\n", "print(f\"\\n[3/3] Synthetic contradictions 
(size={SYNTH_SIZE}) ...\")\n", "synth_df = build_synthetic_pairs(cnli_df[\"clause1\"].tolist(), sample_size=SYNTH_SIZE)\n", "print(f\" → {len(synth_df)} pairs\")\n", "\n", "valid_labels = {\"entailment\", \"contradiction\", \"neutral\"}\n", "full_df = (\n", " pd.concat([cnli_df, mnli_df, synth_df], ignore_index=True)\n", " .sample(frac=1, random_state=42)\n", ")\n", "full_df = full_df[full_df[\"label\"].isin(valid_labels)].dropna(\n", " subset=[\"clause1\", \"clause2\", \"label\"])\n", "\n", "print(f\"\\n{'='*55}\")\n", "print(f\"✅ Total training pairs: {len(full_df)}\")\n", "print(full_df[\"label\"].value_counts().to_string())\n", "print(\"=\"*55)" ] }, { "cell_type": "code", "execution_count": 5, "id": "cell-06-tokenize", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n", "WARNING:huggingface_hub.utils._http:Warning: You are sending unauthenticated requests to the HF Hub. 
Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c4180a00e39f4b20a183ac5e50fcc535", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/776 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6acbaa79f2bc445a8dba7cbac5e0db63", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/258 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cf1904a867804acd804c1269a4cd0b36", "version_major": 2, "version_minor": 0 }, "text/plain": [ "vocab.txt: 0.00B [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c51bd902071d4089a6e35bce0008d300", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/112 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "736813af7da3439e801d423e723617c9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/14958 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58c1e30f69f84c45a120c54a3d64865e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/2640 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Train: 14958 | Eval: 2640\n" ] } ], "source": [ "# ╔══════════════════════════════════════════════════════════════╗\n", "# ║ CELL 6 — Tokenize & split ║\n", "# ╚══════════════════════════════════════════════════════════════╝\n", "from datasets import Dataset\n", "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n", "\n", "df_train = full_df.copy()\n", "df_train[\"label\"] = df_train[\"label\"].map(LABEL2ID)\n", "\n", "hf_ds = Dataset.from_pandas(df_train[[\"clause1\", \"clause2\", \"label\"]])\n", "splits = hf_ds.train_test_split(test_size=0.15, seed=42)\n", "\n", "def tokenize_fn(batch):\n", " return tokenizer(\n", " batch[\"clause1\"], batch[\"clause2\"],\n", " truncation=True, padding=\"max_length\", max_length=MAX_LEN\n", " )\n", "\n", "tokenized = splits.map(tokenize_fn, batched=True, batch_size=256)\n", "# Drop raw text columns — keep only model inputs + label\n", "tokenized = tokenized.remove_columns([\"clause1\", \"clause2\"])\n", "tokenized.set_format(\"torch\")\n", "\n", "print(f\"Train: {len(tokenized['train'])} | Eval: {len(tokenized['test'])}\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "cell-07-model", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4ae908a15d7f40d1acfb2dc7bb152bdf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/268M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "882b3a80c47e48d2bd370b812ba97b85", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading weights: 0%| | 0/104 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "warmup_ratio is deprecated and will be removed in 
# ╔══════════════════════════════════════════════════════════════╗
# ║  CELL 7 — Model & TrainingArguments                          ║
# ╚══════════════════════════════════════════════════════════════╝
from transformers import (
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
import numpy as np
from sklearn.metrics import f1_score, accuracy_score

model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=3,
    id2label=ID2LABEL,
    label2id=LABEL2ID,
    # NOTE(review): the base MNLI checkpoint also has a 3-way head, so its
    # weights are reused even if its original label order differs from
    # LABEL2ID — fine for full fine-tuning, but pre-training predictions
    # would be scrambled. Confirm the base model's label order if doing
    # zero-shot evaluation before training.
    ignore_mismatched_sizes=True,
)

def compute_metrics(eval_pred):
    """Accuracy + weighted F1 over the eval split (labels are class ids)."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted", zero_division=0),
    }

# `warmup_ratio` is deprecated in this transformers version (see the runtime
# warning) — compute the equivalent explicit step count: 10% of total
# optimizer steps (ceil(train_size / batch) steps per epoch, single GPU,
# no gradient accumulation).
steps_per_epoch = -(-len(tokenized["train"]) // BATCH_SIZE)  # ceil division
warmup_steps = int(0.1 * steps_per_epoch * EPOCHS)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LR,
    weight_decay=0.01,
    warmup_steps=warmup_steps,  # ← replaces deprecated warmup_ratio=0.1
    lr_scheduler_type="cosine",
    eval_strategy="epoch",      # ← replaces deprecated evaluation_strategy
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_steps=50,
    fp16=True,                  # T4 supports FP16 — ~2x speed boost
    dataloader_num_workers=2,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    compute_metrics=compute_metrics,
    # Stop if weighted F1 fails to improve for 2 consecutive epochs.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

print("✅ Model & Trainer ready.")
| Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Accuracy | \n", "F1 | \n", "
|---|---|---|---|---|
| 1 | \n", "0.331108 | \n", "0.299753 | \n", "0.882197 | \n", "0.883166 | \n", "
| 2 | \n", "0.287276 | \n", "0.249932 | \n", "0.909470 | \n", "0.909536 | \n", "
| 3 | \n", "0.193082 | \n", "0.232398 | \n", "0.920076 | \n", "0.919938 | \n", "
| 4 | \n", "0.181951 | \n", "0.229638 | \n", "0.923485 | \n", "0.923234 | \n", "
| 5 | \n", "0.148084 | \n", "0.244799 | \n", "0.921591 | \n", "0.921433 | \n", "
"
],
"text/plain": [
"