{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Train & Evaluate DistilBERT / BERT / DeBERTa\n\nSentence-pair boundary classification on combined PubMed + Wikipedia + Gutenberg data.\n\n**Metrics:** per-class F1, macro F1, weighted F1, MCC (Matthews Correlation Coefficient)."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "\n",
    "import wandb\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# Change to the directory two levels above the one the kernel started in\n",
    "# (presumably the project root that holds `src/` and `data/` - confirm).\n",
    "# The start directory is remembered so that re-running this cell is\n",
    "# idempotent instead of climbing two more levels on every execution.\n",
    "if \"_KERNEL_START_DIR\" not in globals():\n",
    "    _KERNEL_START_DIR = os.getcwd()\n",
    "PROJECT_ROOT = os.path.normpath(os.path.join(_KERNEL_START_DIR, \"..\", \"..\"))\n",
    "os.chdir(PROJECT_ROOT)\n",
    "# Keep the `src` package importable even if the kernel does not search the cwd\n",
    "if PROJECT_ROOT not in sys.path:\n",
    "    sys.path.insert(0, PROJECT_ROOT)\n",
    "print(\"Working dir:\", os.getcwd())\n",
    "\n",
    "load_dotenv()\n",
    "wb_token = os.getenv(\"WB_TOKEN\")\n",
    "if not wb_token:\n",
    "    # Make the missing token visible instead of passing None silently\n",
    "    print(\"WB_TOKEN is not set; wandb.login() will use its own credential lookup.\")\n",
    "wandb.login(key=wb_token or None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.metrics import (\n",
    "    classification_report,\n",
    "    confusion_matrix,\n",
    "    f1_score,\n",
    "    matthews_corrcoef,\n",
    ")\n",
    "from transformers import (\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    EarlyStoppingCallback,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from src.datasets.combined_pairs_dataset import (\n",
    "    CombinedPairsDataset,\n",
    "    CombinedPairsConfig,\n",
    "    NUM_LABELS,\n",
    "    ID2LABEL,\n",
    "    LABEL2ID,\n",
    ")\n",
    "from src.models.bert import load_bert, load_bert_tokenizer\n",
    "from src.models.deberta import load_deberta, load_deberta_tokenizer\n",
    "from src.models.distilbert import load_distilbert, load_distilbert_tokenizer\n",
    "from src.models.train import WeightedTrainer, compute_metrics\n",
    "from src.schemas.training_args import BertTrainingArgs, DebertaTrainingArgs, DistilBertTrainingArgs\n",
    "\n",
    "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s %(levelname)s %(message)s\")\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Device: {device}\")\n",
    "print(f\"PyTorch: {torch.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Training args from dataclasses (edit fields to override defaults) ──\n",
    "distilbert_args = DistilBertTrainingArgs()\n",
    "bert_args = BertTrainingArgs()\n",
    "deberta_args = DebertaTrainingArgs()\n",
    "\n",
    "# Echo the effective settings of every model so the run is self-describing\n",
    "for name, args in [(\"DistilBERT\", distilbert_args), (\"BERT\", bert_args), (\"DeBERTa\", deberta_args)]:\n",
    "    print(f\"{name} config:\")\n",
    "    print(f\" output_dir: {args.output_dir}\")\n",
    "    print(f\" epochs: {args.epochs}\")\n",
    "    print(f\" batch_size: {args.batch_size}\")\n",
    "    print(f\" lr: {args.lr}\")\n",
    "    print(f\" max_length: {args.max_length}\")\n",
    "    print(f\" gutenberg_cap: {args.gutenberg_cap}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Build dataset splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All three models share the raw splits built from the DistilBERT settings\n",
    "# below. Warn when another model's args disagree, because the shared dataset\n",
    "# config would otherwise ignore that model's value without any notice.\n",
    "for other_name, other_args in [(\"BERT\", bert_args), (\"DeBERTa\", deberta_args)]:\n",
    "    for field in (\"gutenberg_cap\", \"seed\", \"max_length\"):\n",
    "        shared_value = getattr(distilbert_args, field)\n",
    "        other_value = getattr(other_args, field)\n",
    "        if other_value != shared_value:\n",
    "            logging.warning(\n",
    "                \"%s %s=%r differs from DistilBERT %s=%r; the shared dataset config uses the DistilBERT value.\",\n",
    "                other_name, field, other_value, field, shared_value,\n",
    "            )\n",
    "\n",
    "cfg = CombinedPairsConfig(\n",
    "    data_root=\"data\",\n",
    "    gutenberg_train_cap=distilbert_args.gutenberg_cap,\n",
    "    seed=distilbert_args.seed,\n",
    "    max_length=distilbert_args.max_length,\n",
    ")\n",
    "builder = CombinedPairsDataset(cfg)\n",
    "raw_splits = builder.build_splits()\n",
    "# Class weights are computed on the training split only and reused by every trainer\n",
    "class_weights = builder.compute_class_weights(raw_splits[\"train\"])\n",
    "print(f\"\\nClass weights: {class_weights.tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 
Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import ConfusionMatrixDisplay\n",
    "\n",
    "\n",
    "def plot_loss_curves(trainer, model_name: str, output_dir: str):\n",
    "    \"\"\"Plot loss and evaluation-metric curves from the Trainer log history.\n",
    "\n",
    "    Draws train/validation loss against the step and macro F1, weighted F1\n",
    "    and MCC against the epoch, shows the figure and saves it as\n",
    "    ``<output_dir>/plots/loss_curves.png``.\n",
    "\n",
    "    Args:\n",
    "        trainer: Trainer whose ``state.log_history`` holds the logged entries.\n",
    "        model_name: Name used in the subplot titles.\n",
    "        output_dir: Directory under which the ``plots`` folder is created.\n",
    "    \"\"\"\n",
    "    plots_dir = Path(output_dir) / \"plots\"\n",
    "    plots_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # A metric missing from an eval entry becomes a gap in the curve, not a fake 0\n",
    "    missing = float(\"nan\")\n",
    "\n",
    "    train_steps, train_loss = [], []\n",
    "    eval_steps, eval_loss = [], []\n",
    "    eval_epochs, eval_macro_f1, eval_weighted_f1, eval_mcc = [], [], [], []\n",
    "\n",
    "    for entry in trainer.state.log_history:\n",
    "        # Entries with \"loss\" but no \"eval_loss\" are training logs\n",
    "        if \"loss\" in entry and \"eval_loss\" not in entry:\n",
    "            train_steps.append(entry[\"step\"])\n",
    "            train_loss.append(entry[\"loss\"])\n",
    "        if \"eval_loss\" in entry:\n",
    "            eval_steps.append(entry[\"step\"])\n",
    "            eval_loss.append(entry[\"eval_loss\"])\n",
    "            eval_epochs.append(entry.get(\"epoch\", 0))\n",
    "            eval_macro_f1.append(entry.get(\"eval_macro_f1\", missing))\n",
    "            eval_weighted_f1.append(entry.get(\"eval_weighted_f1\", missing))\n",
    "            eval_mcc.append(entry.get(\"eval_mcc\", missing))\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n",
    "\n",
    "    # Loss curves\n",
    "    ax = axes[0]\n",
    "    ax.plot(train_steps, train_loss, label=\"Train loss\", alpha=0.7)\n",
    "    ax.plot(eval_steps, eval_loss, \"o-\", label=\"Val loss\", markersize=6)\n",
    "    ax.set_xlabel(\"Step\")\n",
    "    ax.set_ylabel(\"Loss\")\n",
    "    ax.set_title(f\"{model_name} — Loss\")\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "    # Metrics per epoch\n",
    "    ax = axes[1]\n",
    "    ax.plot(eval_epochs, eval_macro_f1, \"s-\", label=\"Macro F1\", markersize=6)\n",
    "    ax.plot(eval_epochs, eval_weighted_f1, \"^-\", label=\"Weighted F1\", markersize=6)\n",
    "    ax.plot(eval_epochs, eval_mcc, \"D-\", label=\"MCC\", markersize=6)\n",
    "    ax.set_xlabel(\"Epoch\")\n",
    "    ax.set_ylabel(\"Score\")\n",
    "    ax.set_title(f\"{model_name} — Eval Metrics\")\n",
    "    # MCC lies in [-1, 1]: extend the axis below zero when a score is negative\n",
    "    finite_scores = [s for s in eval_macro_f1 + eval_weighted_f1 + eval_mcc if np.isfinite(s)]\n",
    "    lowest = min(finite_scores, default=0.0)\n",
    "    ax.set_ylim(lowest - 0.05 if lowest < 0 else 0.0, 1)\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    curve_path = plots_dir / \"loss_curves.png\"\n",
    "    fig.savefig(curve_path, dpi=150)\n",
    "    plt.show()\n",
    "    plt.close(fig)  # release the figure once it has been rendered and saved\n",
    "    print(f\"Saved to {curve_path}\")\n",
    "\n",
    "\n",
    "def save_best(trainer, model, tokenizer, train_args, model_type: str, weights=None):\n",
    "    \"\"\"Save model, tokenizer, class weights and training config to ``<output_dir>/best``.\n",
    "\n",
    "    NOTE(review): ``trainer.save_model`` writes the weights the trainer holds\n",
    "    when this is called; they are the best checkpoint only if the training\n",
    "    arguments reload it at the end of training - confirm in\n",
    "    ``to_training_arguments()``.\n",
    "\n",
    "    Args:\n",
    "        trainer: Trainer whose model is written with ``save_model``.\n",
    "        model: Model whose config supplies the pretrained name.\n",
    "        tokenizer: Tokenizer saved next to the model.\n",
    "        train_args: Args object providing ``output_dir``, ``epochs``,\n",
    "            ``batch_size``, ``lr`` and ``max_length``.\n",
    "        model_type: Short identifier stored in ``train_config.json``.\n",
    "        weights: Class-weight tensor to store; defaults to the notebook-level\n",
    "            ``class_weights`` so existing calls behave as before.\n",
    "    \"\"\"\n",
    "    if weights is None:\n",
    "        weights = class_weights\n",
    "\n",
    "    best_dir = Path(train_args.output_dir) / \"best\"\n",
    "    best_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    trainer.save_model(str(best_dir))\n",
    "    tokenizer.save_pretrained(str(best_dir))\n",
    "    torch.save(weights, best_dir / \"class_weights.pt\")\n",
    "\n",
    "    train_config = {\n",
    "        \"model_type\": model_type,\n",
    "        # Public accessor for the value previously read from the private `_name_or_path`\n",
    "        \"pretrained\": getattr(model.config, \"name_or_path\", None),\n",
    "        \"epochs\": train_args.epochs,\n",
    "        \"batch_size\": train_args.batch_size,\n",
    "        \"lr\": train_args.lr,\n",
    "        \"max_length\": train_args.max_length,\n",
    "        \"class_weights\": weights.tolist(),\n",
    "        \"num_labels\": NUM_LABELS,\n",
    "        \"id2label\": ID2LABEL,\n",
    "        \"label2id\": LABEL2ID,\n",
    "    }\n",
    "    with open(best_dir / \"train_config.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(train_config, f, indent=2)\n",
    "\n",
    "    print(f\"Best model saved to {best_dir}\")\n",
    "\n",
    "\n",
    "def evaluate_model(model_dir: Path, dd, model_name: str, split=\"test\", batch_size=32, seed=42):\n",
    "    \"\"\"Run ``Trainer.predict()`` on one split and report the full metrics.\n",
    "\n",
    "    Prints weighted F1, macro F1, MCC and the classification report, shows\n",
    "    and saves the confusion matrix under ``<model_dir>/../plots`` and writes\n",
    "    the metrics to ``<model_dir>/<split>_metrics.json``.\n",
    "\n",
    "    Args:\n",
    "        model_dir: Directory of the saved model to load.\n",
    "        dd: Mapping from split name to the dataset to predict on.\n",
    "        model_name: Name used in the printed header and the plot title.\n",
    "        split: Key of ``dd`` to evaluate.\n",
    "        batch_size: Per-device evaluation batch size.\n",
    "        seed: Seed passed to ``TrainingArguments``.\n",
    "\n",
    "    Returns:\n",
    "        The classification report dict extended with ``weighted_f1``,\n",
    "        ``macro_f1`` and ``mcc``.\n",
    "    \"\"\"\n",
    "    import tempfile\n",
    "\n",
    "    model_dir = Path(model_dir)\n",
    "    plots_dir = model_dir.parent / \"plots\"\n",
    "    plots_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))\n",
    "    ds = dd[split]\n",
    "\n",
    "    # A managed scratch directory replaces the hard-coded /tmp path and is removed afterwards\n",
    "    with tempfile.TemporaryDirectory(prefix=\"eval_output_\") as scratch_dir:\n",
    "        eval_args = TrainingArguments(\n",
    "            output_dir=scratch_dir,\n",
    "            per_device_eval_batch_size=batch_size,\n",
    "            report_to=\"none\",\n",
    "            seed=seed,\n",
    "        )\n",
    "        trainer = Trainer(model=model, args=eval_args)\n",
    "        predictions = trainer.predict(ds)\n",
    "\n",
    "    preds = np.argmax(predictions.predictions, axis=-1)\n",
    "    labels = predictions.label_ids\n",
    "    # Derive the class list from NUM_LABELS instead of a hard-coded 3\n",
    "    label_ids = list(range(NUM_LABELS))\n",
    "    target_names = [ID2LABEL[i] for i in label_ids]\n",
    "\n",
    "    weighted_f1 = f1_score(labels, preds, average=\"weighted\")\n",
    "    macro_f1 = f1_score(labels, preds, average=\"macro\")\n",
    "    mcc = matthews_corrcoef(labels, preds)\n",
    "\n",
    "    print(f\"\\n{'='*60}\")\n",
    "    print(f\" {model_name} — {split} split ({len(ds):,} samples)\")\n",
    "    print(f\"{'='*60}\\n\")\n",
    "    print(f\" Weighted F1: {weighted_f1:.4f}\")\n",
    "    print(f\" Macro F1: {macro_f1:.4f}\")\n",
    "    print(f\" MCC: {mcc:.4f}\\n\")\n",
    "\n",
    "    # Explicit `labels` keeps the report aligned with target_names even when a\n",
    "    # class is absent from this split; without it sklearn raises a ValueError.\n",
    "    report_kwargs = dict(labels=label_ids, target_names=target_names, zero_division=0)\n",
    "    print(classification_report(labels, preds, digits=4, **report_kwargs))\n",
    "\n",
    "    # Confusion matrix\n",
    "    cm = confusion_matrix(labels, preds, labels=label_ids)\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "    ConfusionMatrixDisplay(cm, display_labels=target_names).plot(ax=ax, cmap=\"Blues\", values_format=\",\")\n",
    "    ax.set_title(f\"{model_name} — {split}\")\n",
    "    plt.tight_layout()\n",
    "    cm_path = plots_dir / f\"confusion_matrix_{split}.png\"\n",
    "    fig.savefig(cm_path, dpi=150)\n",
    "    plt.show()\n",
    "    plt.close(fig)  # release the figure once it has been rendered and saved\n",
    "    print(f\"Saved to {cm_path}\")\n",
    "\n",
    "    # save metrics json (plain floats so the values always serialize)\n",
    "    report_dict = classification_report(labels, preds, output_dict=True, **report_kwargs)\n",
    "    report_dict[\"weighted_f1\"] = float(weighted_f1)\n",
    "    report_dict[\"macro_f1\"] = float(macro_f1)\n",
    "    report_dict[\"mcc\"] = float(mcc)\n",
    "\n",
    "    out_path = model_dir / f\"{split}_metrics.json\"\n",
    "    with open(out_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(report_dict, f, indent=2)\n",
    "    print(f\"Metrics saved to {out_path}\")\n",
    "\n",
    "    return report_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 4. 
Train DistilBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distilbert_model = load_distilbert()\n",
    "distilbert_tokenizer = load_distilbert_tokenizer()\n",
    "\n",
    "# Tokenize the shared raw splits with this model's tokenizer\n",
    "dd_distilbert = builder.build_hf_dataset_dict(distilbert_tokenizer, raw_splits=raw_splits)\n",
    "\n",
    "distilbert_param_count = sum(p.numel() for p in distilbert_model.parameters())\n",
    "print(f\"Params: {distilbert_param_count:,}\")\n",
    "print(f\"Train: {len(dd_distilbert['train']):,} Val: {len(dd_distilbert['val']):,} Test: {len(dd_distilbert['test']):,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Early stopping is attached only when a positive patience is configured\n",
    "distilbert_callbacks = []\n",
    "if distilbert_args.patience > 0:\n",
    "    distilbert_callbacks.append(\n",
    "        EarlyStoppingCallback(early_stopping_patience=distilbert_args.patience)\n",
    "    )\n",
    "\n",
    "distilbert_trainer = WeightedTrainer(\n",
    "    class_weights=class_weights,\n",
    "    model=distilbert_model,\n",
    "    args=distilbert_args.to_training_arguments(),\n",
    "    train_dataset=dd_distilbert[\"train\"],\n",
    "    eval_dataset=dd_distilbert[\"val\"],\n",
    "    compute_metrics=compute_metrics,\n",
    "    callbacks=distilbert_callbacks,\n",
    ")\n",
    "\n",
    "distilbert_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_loss_curves(distilbert_trainer, \"DistilBERT\", distilbert_args.output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_best(distilbert_trainer, distilbert_model, distilbert_tokenizer, distilbert_args, \"distilbert\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Evaluate DistilBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distilbert_best = Path(distilbert_args.output_dir) / \"best\"\n",
    "\n",
    "# Validation split, evaluated with twice the configured batch size\n",
    "distilbert_val_metrics = evaluate_model(\n",
    "    distilbert_best,\n",
    "    dd_distilbert,\n",
    "    \"DistilBERT\",\n",
    "    split=\"val\",\n",
    "    batch_size=2 * distilbert_args.batch_size,\n",
    "    seed=distilbert_args.seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out test split; these numbers feed the model comparison\n",
    "distilbert_test_metrics = evaluate_model(\n",
    "    distilbert_best,\n",
    "    dd_distilbert,\n",
    "    \"DistilBERT\",\n",
    "    split=\"test\",\n",
    "    batch_size=2 * distilbert_args.batch_size,\n",
    "    seed=distilbert_args.seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 5. Train BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_model = load_bert()\n",
    "bert_tokenizer = load_bert_tokenizer()\n",
    "\n",
    "# Tokenize the shared raw splits with this model's tokenizer\n",
    "dd_bert = builder.build_hf_dataset_dict(bert_tokenizer, raw_splits=raw_splits)\n",
    "\n",
    "bert_param_count = sum(p.numel() for p in bert_model.parameters())\n",
    "print(f\"Params: {bert_param_count:,}\")\n",
    "print(f\"Train: {len(dd_bert['train']):,} Val: {len(dd_bert['val']):,} Test: {len(dd_bert['test']):,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Early stopping is attached only when a positive patience is configured\n",
    "bert_callbacks = []\n",
    "if bert_args.patience > 0:\n",
    "    bert_callbacks.append(\n",
    "        EarlyStoppingCallback(early_stopping_patience=bert_args.patience)\n",
    "    )\n",
    "\n",
    "bert_trainer = WeightedTrainer(\n",
    "    class_weights=class_weights,\n",
    "    model=bert_model,\n",
    "    args=bert_args.to_training_arguments(),\n",
    "    train_dataset=dd_bert[\"train\"],\n",
    "    eval_dataset=dd_bert[\"val\"],\n",
    "    compute_metrics=compute_metrics,\n",
    "    callbacks=bert_callbacks,\n",
    ")\n",
    "\n",
    "bert_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_loss_curves(bert_trainer, \"BERT\", bert_args.output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_best(bert_trainer, bert_model, bert_tokenizer, bert_args, \"bert\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 Evaluate BERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_best = Path(bert_args.output_dir) / \"best\"\n",
    "\n",
    "# Validation split, evaluated with twice the configured batch size\n",
    "bert_val_metrics = evaluate_model(\n",
    "    bert_best,\n",
    "    dd_bert,\n",
    "    \"BERT\",\n",
    "    split=\"val\",\n",
    "    batch_size=2 * bert_args.batch_size,\n",
    "    seed=bert_args.seed,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out test split; these numbers feed the model comparison\n",
    "bert_test_metrics = evaluate_model(\n",
    "    bert_best,\n",
    "    dd_bert,\n",
    "    \"BERT\",\n",
    "    split=\"test\",\n",
    "    batch_size=2 * bert_args.batch_size,\n",
    "    seed=bert_args.seed,\n",
    ")"
   ]
}, { "cell_type": "markdown", "source": "---\n## 6. Train DeBERTa", "metadata": {} }, { "cell_type": "code", "source": "deberta_model = load_deberta()\ndeberta_tokenizer = load_deberta_tokenizer()\n\ndd_deberta = builder.build_hf_dataset_dict(deberta_tokenizer, raw_splits=raw_splits)\n\nprint(f\"Params: {sum(p.numel() for p in deberta_model.parameters()):,}\")\nprint(f\"Train: {len(dd_deberta['train']):,} Val: {len(dd_deberta['val']):,} Test: {len(dd_deberta['test']):,}\")", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": "deberta_trainer = WeightedTrainer(\n class_weights=class_weights,\n model=deberta_model,\n args=deberta_args.to_training_arguments(),\n train_dataset=dd_deberta[\"train\"],\n eval_dataset=dd_deberta[\"val\"],\n compute_metrics=compute_metrics,\n callbacks=[EarlyStoppingCallback(early_stopping_patience=deberta_args.patience)]\n if deberta_args.patience > 0 else [],\n)\n\ndeberta_trainer.train()", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": "plot_loss_curves(deberta_trainer, \"DeBERTa\", deberta_args.output_dir)", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": "save_best(deberta_trainer, deberta_model, deberta_tokenizer, deberta_args, \"deberta\")", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": "### 6.1 Evaluate DeBERTa", "metadata": {} }, { "cell_type": "code", "source": "deberta_best = Path(deberta_args.output_dir) / \"best\"\ndeberta_val_metrics = evaluate_model(deberta_best, dd_deberta, \"DeBERTa\", split=\"val\",\n batch_size=deberta_args.batch_size * 2,\n seed=deberta_args.seed)", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": "deberta_test_metrics = evaluate_model(deberta_best, dd_deberta, \"DeBERTa\", split=\"test\",\n batch_size=deberta_args.batch_size * 2,\n seed=deberta_args.seed)", "metadata": {}, 
"execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": "---\n## 7. Compare models" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import pandas as pd\n\ncomparison = pd.DataFrame({\n \"Metric\": [\"Weighted F1\", \"Macro F1\", \"MCC\",\n \"F1 SAME_PARA\", \"F1 NEW_PARA\", \"F1 NEWLINE\"],\n \"DistilBERT\": [\n distilbert_test_metrics[\"weighted_f1\"],\n distilbert_test_metrics[\"macro_f1\"],\n distilbert_test_metrics[\"mcc\"],\n distilbert_test_metrics[\"SAME_PARAGRAPH\"][\"f1-score\"],\n distilbert_test_metrics[\"NEW_PARAGRAPH\"][\"f1-score\"],\n distilbert_test_metrics[\"NEWLINE\"][\"f1-score\"],\n ],\n \"BERT\": [\n bert_test_metrics[\"weighted_f1\"],\n bert_test_metrics[\"macro_f1\"],\n bert_test_metrics[\"mcc\"],\n bert_test_metrics[\"SAME_PARAGRAPH\"][\"f1-score\"],\n bert_test_metrics[\"NEW_PARAGRAPH\"][\"f1-score\"],\n bert_test_metrics[\"NEWLINE\"][\"f1-score\"],\n ],\n \"DeBERTa\": [\n deberta_test_metrics[\"weighted_f1\"],\n deberta_test_metrics[\"macro_f1\"],\n deberta_test_metrics[\"mcc\"],\n deberta_test_metrics[\"SAME_PARAGRAPH\"][\"f1-score\"],\n deberta_test_metrics[\"NEW_PARAGRAPH\"][\"f1-score\"],\n deberta_test_metrics[\"NEWLINE\"][\"f1-score\"],\n ],\n})\n\ncomparison = comparison.set_index(\"Metric\")\ncomparison = comparison.round(4)\ncomparison" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import matplotlib.pyplot as plt\n\ncomparison_dir = Path(\"checkpoints/plots\")\ncomparison_dir.mkdir(parents=True, exist_ok=True)\n\nmetrics_to_plot = [\"Weighted F1\", \"Macro F1\", \"MCC\"]\nplot_data = comparison.loc[metrics_to_plot]\n\nax = plot_data.plot.bar(rot=0, figsize=(10, 4))\nax.set_ylim(0, 1)\nax.set_ylabel(\"Score\")\nax.set_title(\"DistilBERT vs BERT vs DeBERTa — Test Set\")\nax.legend(loc=\"lower right\")\n\nfor container in ax.containers:\n ax.bar_label(container, fmt=\"%.3f\", fontsize=8, 
padding=2)\n\nplt.tight_layout()\nplt.savefig(comparison_dir / \"comparison_main_metrics.png\", dpi=150)\nplt.show()\nprint(f\"Saved to {comparison_dir / 'comparison_main_metrics.png'}\")" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "per_class = [\"F1 SAME_PARA\", \"F1 NEW_PARA\", \"F1 NEWLINE\"]\nplot_data = comparison.loc[per_class]\n\nax = plot_data.plot.bar(rot=0, figsize=(10, 4))\nax.set_ylim(0, 1)\nax.set_ylabel(\"F1 Score\")\nax.set_title(\"Per-class F1 — Test Set\")\nax.legend(loc=\"lower right\")\n\nfor container in ax.containers:\n ax.bar_label(container, fmt=\"%.3f\", fontsize=8, padding=2)\n\nplt.tight_layout()\nplt.savefig(comparison_dir / \"comparison_per_class_f1.png\", dpi=150)\nplt.show()\nprint(f\"Saved to {comparison_dir / 'comparison_per_class_f1.png'}\")" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 4 }