{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "m0",
   "metadata": {},
   "source": [
    "# Train & Evaluate BERT\n",
    "\n",
    "Sentence-pair boundary classification on combined PubMed + Wikipedia + Gutenberg data."
   ]
  },
  {
   "cell_type": "code",
   "id": "c0",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "# Change to the directory two levels above the kernel's starting directory so the\n",
    "# relative paths used in this notebook (\"env.txt\", \"data\") resolve from there.\n",
    "# The target is computed once per kernel session and cached in PROJECT_ROOT, so\n",
    "# re-running this cell does not climb two more levels each time.\n",
    "if \"PROJECT_ROOT\" not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\"))\n",
    "os.chdir(PROJECT_ROOT)\n",
    "print(\"Working dir:\", os.getcwd())\n",
    "import wandb\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "\n",
    "# Load variables from env.txt, then log in to Weights & Biases with WB_TOKEN\n",
    "# (os.getenv returns None when the variable is not set).\n",
    "load_dotenv(\"env.txt\")\n",
    "wandb.login(key=os.getenv(\"WB_TOKEN\"))"
   ]
  },
  {
   "cell_type": "code",
   "id": "c1",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.metrics import classification_report, confusion_matrix, f1_score, matthews_corrcoef\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, EarlyStoppingCallback, Trainer, TrainingArguments\n",
    "\n",
    "from src.datasets.combined_pairs_dataset import CombinedPairsDataset, CombinedPairsConfig, NUM_LABELS, ID2LABEL, LABEL2ID\n",
    "from src.models.bert import load_bert, load_bert_tokenizer\n",
    "from src.models.train import WeightedTrainer, compute_metrics\n",
    "from src.schemas.training_args import BertTrainingArgs\n",
    "\n",
    "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s %(levelname)s %(message)s\")\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m1",
   "metadata": {},
   "source": [
    "## 1. 
Configuration & Data"
   ]
  },
  {
   "cell_type": "code",
   "id": "c2",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters: BertTrainingArgs instantiated with no overrides.\n",
    "args = BertTrainingArgs()\n",
    "print(f\"epochs: {args.epochs}, lr: {args.lr}, batch_size: {args.batch_size}, patience: {args.patience}\")\n",
    "\n",
    "# Weights & Biases project and run name, exported as environment variables.\n",
    "os.environ.update({\"WANDB_PROJECT\": \"bottlecap\", \"WANDB_RUN_NAME\": \"bert\"})"
   ]
  },
  {
   "cell_type": "code",
   "id": "c3",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the raw splits, then compute class weights from the training split.\n",
    "cfg = CombinedPairsConfig(\n",
    "    data_root=\"data\",\n",
    "    gutenberg_train_cap=args.gutenberg_cap,\n",
    "    seed=args.seed,\n",
    "    max_length=args.max_length,\n",
    ")\n",
    "builder = CombinedPairsDataset(cfg)\n",
    "raw_splits = builder.build_splits()\n",
    "class_weights = builder.compute_class_weights(raw_splits[\"train\"])\n",
    "print(f\"Class weights: {class_weights.tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m2",
   "metadata": {},
   "source": [
    "## 2. Train"
   ]
  },
  {
   "cell_type": "code",
   "id": "c4",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model, tokenizer, and the tokenized dataset dict built from the raw splits.\n",
    "model = load_bert()\n",
    "tokenizer = load_bert_tokenizer()\n",
    "dd = builder.build_hf_dataset_dict(tokenizer, raw_splits=raw_splits)\n",
    "\n",
    "# Report the parameter count and the size of each split.\n",
    "param_count = sum(p.numel() for p in model.parameters())\n",
    "print(\n",
    "    f\"Params: {param_count:,}\"\n",
    "    f\" Train: {len(dd['train']):,} Val: {len(dd['val']):,} Test: {len(dd['test']):,}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "id": "c5",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Early stopping is attached only when a positive patience is configured.\n",
    "callbacks = []\n",
    "if args.patience > 0:\n",
    "    callbacks.append(EarlyStoppingCallback(early_stopping_patience=args.patience))\n",
    "\n",
    "trainer = WeightedTrainer(\n",
    "    class_weights=class_weights,\n",
    "    model=model,\n",
    "    args=args.to_training_arguments(),\n",
    "    train_dataset=dd[\"train\"],\n",
    "    eval_dataset=dd[\"val\"],\n",
    "    compute_metrics=compute_metrics,\n",
    "    callbacks=callbacks,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m3",
   "metadata": {},
   "source": [
    "## 3. 
Loss Curves"
   ]
  },
  {
   "cell_type": "code",
   "id": "c6",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plots_dir = Path(args.output_dir) / \"plots\"\n",
    "plots_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Split the trainer's log history once: entries with \"eval_loss\" are evaluation\n",
    "# records; entries with \"loss\" but no \"eval_loss\" are training records.\n",
    "history = trainer.state.log_history\n",
    "eval_logs = [entry for entry in history if \"eval_loss\" in entry]\n",
    "train_logs = [entry for entry in history if \"loss\" in entry and \"eval_loss\" not in entry]\n",
    "\n",
    "train_steps = [entry[\"step\"] for entry in train_logs]\n",
    "train_loss = [entry[\"loss\"] for entry in train_logs]\n",
    "eval_steps = [entry[\"step\"] for entry in eval_logs]\n",
    "eval_loss = [entry[\"eval_loss\"] for entry in eval_logs]\n",
    "# A missing \"epoch\" or \"eval_macro_f1\" key falls back to 0.\n",
    "eval_epochs = [entry.get(\"epoch\", 0) for entry in eval_logs]\n",
    "eval_f1 = [entry.get(\"eval_macro_f1\", 0) for entry in eval_logs]\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))\n",
    "\n",
    "# Left panel: training and validation loss against the step.\n",
    "ax1.plot(train_steps, train_loss, label=\"Train\", alpha=0.7)\n",
    "ax1.plot(eval_steps, eval_loss, \"o-\", label=\"Val\")\n",
    "ax1.set_xlabel(\"Step\")\n",
    "ax1.set_ylabel(\"Loss\")\n",
    "ax1.set_title(\"BERT — Loss\")\n",
    "ax1.legend()\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# Right panel: validation macro F1 against the epoch.\n",
    "ax2.plot(eval_epochs, eval_f1, \"s-\", label=\"Macro F1\")\n",
    "ax2.set_xlabel(\"Epoch\")\n",
    "ax2.set_ylabel(\"F1\")\n",
    "ax2.set_title(\"BERT — Macro F1\")\n",
    "ax2.legend()\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "loss_curve_path = plots_dir / \"loss_curves.png\"\n",
    "fig.savefig(loss_curve_path, dpi=150)\n",
    "plt.show()\n",
    "print(f\"Saved to {loss_curve_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m4",
   "metadata": {},
   "source": [
    "## 4. 
Save Best Model"
   ]
  },
  {
   "cell_type": "code",
   "id": "c7",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trainer's current model, the tokenizer and the class weights side by side.\n",
    "# NOTE(review): this is the best checkpoint only if args.to_training_arguments()\n",
    "# enables load_best_model_at_end — confirm.\n",
    "best_dir = Path(args.output_dir) / \"best\"\n",
    "best_dir.mkdir(parents=True, exist_ok=True)\n",
    "trainer.save_model(str(best_dir))\n",
    "tokenizer.save_pretrained(str(best_dir))\n",
    "torch.save(class_weights, best_dir / \"class_weights.pt\")\n",
    "\n",
    "# Record the settings of this run next to the saved model.\n",
    "train_config = {\"model_type\": \"bert\", \"pretrained\": model.config._name_or_path, \"epochs\": args.epochs, \"batch_size\": args.batch_size, \"lr\": args.lr, \"max_length\": args.max_length, \"class_weights\": class_weights.tolist(), \"num_labels\": NUM_LABELS, \"id2label\": ID2LABEL, \"label2id\": LABEL2ID}\n",
    "with open(best_dir / \"train_config.json\", \"w\") as f:\n",
    "    json.dump(train_config, f, indent=2)\n",
    "print(f\"Saved to {best_dir}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m5",
   "metadata": {},
   "source": [
    "## 5. Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "id": "c8",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import ConfusionMatrixDisplay\n",
    "\n",
    "# Derive the label ids from NUM_LABELS rather than hard-coding the class count, so the\n",
    "# report and the confusion matrix stay in step with the dataset definition.\n",
    "# Assumes ID2LABEL is keyed by 0..NUM_LABELS-1 — TODO confirm.\n",
    "label_ids = list(range(NUM_LABELS))\n",
    "target_names = [ID2LABEL[i] for i in label_ids]\n",
    "\n",
    "# Load the saved model and build the prediction Trainer once; neither depends on the\n",
    "# split, so both are reused for every split below.\n",
    "eval_model = AutoModelForSequenceClassification.from_pretrained(str(best_dir))\n",
    "eval_trainer = Trainer(\n",
    "    model=eval_model,\n",
    "    args=TrainingArguments(\n",
    "        # Platform temp directory instead of a hard-coded absolute \"/tmp\" path.\n",
    "        output_dir=os.path.join(tempfile.gettempdir(), \"eval\"),\n",
    "        per_device_eval_batch_size=args.batch_size * 2,\n",
    "        report_to=\"none\",\n",
    "    ),\n",
    ")\n",
    "\n",
    "for split in [\"val\", \"test\"]:\n",
    "    preds_out = eval_trainer.predict(dd[split])\n",
    "    preds = np.argmax(preds_out.predictions, axis=-1)\n",
    "    labels = preds_out.label_ids\n",
    "\n",
    "    print(f\"\\n{'='*60}\")\n",
    "    print(f\" BERT — {split} ({len(dd[split]):,} samples)\")\n",
    "    print(f\"{'='*60}\")\n",
    "    print(f\" Macro F1: {f1_score(labels, preds, average='macro'):.4f}\")\n",
    "    print(f\" MCC: {matthews_corrcoef(labels, preds):.4f}\\n\")\n",
    "    # Passing labels= keeps the report rows aligned with target_names even when a class\n",
    "    # is absent from both the labels and the predictions of this split; without it,\n",
    "    # classification_report raises on the name/class count mismatch.\n",
    "    print(classification_report(labels, preds, labels=label_ids, target_names=target_names, digits=4))\n",
    "\n",
    "    cm = confusion_matrix(labels, preds, labels=label_ids)\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "    ConfusionMatrixDisplay(cm, display_labels=target_names).plot(ax=ax, cmap=\"Blues\", values_format=\",\")\n",
    "    ax.set_title(f\"BERT — {split}\")\n",
    "    plt.tight_layout()\n",
    "    cm_path = plots_dir / f\"confusion_matrix_{split}.png\"\n",
    "    fig.savefig(cm_path, dpi=150)\n",
    "    plt.show()\n",
    "    # Release the figure so figures do not accumulate across loop iterations.\n",
    "    plt.close(fig)\n",
    "    print(f\"Saved to {cm_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}