diff --git "a/notebooks_tutorials/03_fine_tuning_posttrained_model.ipynb" "b/notebooks_tutorials/03_fine_tuning_posttrained_model.ipynb" --- "a/notebooks_tutorials/03_fine_tuning_posttrained_model.ipynb" +++ "b/notebooks_tutorials/03_fine_tuning_posttrained_model.ipynb" @@ -6,11 +6,11 @@ "source": [ "# 🧬 Fine-Tuning a Post-trained Model on BigWig Tracks Prediction\n", "\n", - "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a **post-trained Nucleotide Transformer v3 (NTv3) model** to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n", + "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a post-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a post-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n", "\n", "**🎯 Notebook purpose:**\n", "This notebook is configured to train the `NTv3_650M_post` model on the `human` species from the NTv3 benchmark dataset. 
To run this training, you will need a large GPU (either A100 or H100).\n", - "For a simplified version of this notebook that uses the `NTv3_8M_pre` model and runs on a CPU, please see the [02_fine_tuning.ipynb](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning_pretrained_model.ipynb) notebook.\n", + "For a simplified version of this notebook that uses the `NTv3_8M_pre` model and runs on a CPU, please see the [02_fine_tuning.ipynb](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning.ipynb) notebook.\n", "The notebook uses the same \"simplified setup\" as described there. \n", "\n", "📝 Note for Google Colab users: This notebook is compatible with Colab! This notebook is designed to be run on a high-performance GPU. The default parameters can be used with a H100 with 80GB of HBM." @@ -25,9 +25,58 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pyfaidx in ./.venv/lib/python3.11/site-packages (0.9.0.3)\n", + "Requirement already satisfied: pyBigWig in ./.venv/lib/python3.11/site-packages (0.3.24)\n", + "Requirement already satisfied: torchmetrics in ./.venv/lib/python3.11/site-packages (1.8.2)\n", + "Requirement already satisfied: transformers in ./.venv/lib/python3.11/site-packages (4.57.1)\n", + "Requirement already satisfied: packaging in ./.venv/lib/python3.11/site-packages (from pyfaidx) (24.2)\n", + "Requirement already satisfied: numpy>1.20.0 in ./.venv/lib/python3.11/site-packages (from torchmetrics) (2.1.3)\n", + "Requirement already satisfied: torch>=2.0.0 in ./.venv/lib/python3.11/site-packages (from torchmetrics) (2.5.1+cu121)\n", + "Requirement already satisfied: lightning-utilities>=0.8.0 in ./.venv/lib/python3.11/site-packages (from torchmetrics) (0.15.2)\n", + "Requirement already satisfied: filelock 
in ./.venv/lib/python3.11/site-packages (from transformers) (3.17.0)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.34.0 in ./.venv/lib/python3.11/site-packages (from transformers) (0.36.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in ./.venv/lib/python3.11/site-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in ./.venv/lib/python3.11/site-packages (from transformers) (2024.11.6)\n", + "Requirement already satisfied: requests in ./.venv/lib/python3.11/site-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in ./.venv/lib/python3.11/site-packages (from transformers) (0.22.1)\n", + "Requirement already satisfied: safetensors>=0.4.3 in ./.venv/lib/python3.11/site-packages (from transformers) (0.7.0)\n", + "Requirement already satisfied: tqdm>=4.27 in ./.venv/lib/python3.11/site-packages (from transformers) (4.67.1)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in ./.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (2025.3.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in ./.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (4.12.2)\n", + "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in ./.venv/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.34.0->transformers) (1.2.0)\n", + "Requirement already satisfied: setuptools in ./.venv/lib/python3.11/site-packages (from lightning-utilities>=0.8.0->torchmetrics) (77.0.1)\n", + "Requirement already satisfied: networkx in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (3.6.1)\n", + "Requirement already satisfied: jinja2 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (3.1.6)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (12.1.105)\n", + "Requirement 
already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (9.1.0.70)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (12.1.105)\n", + "Requirement already satisfied: triton==3.1.0 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (3.1.0)\n", + "Requirement already satisfied: sympy==1.13.1 in ./.venv/lib/python3.11/site-packages (from torch>=2.0.0->torchmetrics) (1.13.1)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in ./.venv/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.0.0->torchmetrics) (12.9.86)\n", + "Requirement already 
satisfied: mpmath<1.4,>=1.1.0 in ./.venv/lib/python3.11/site-packages (from sympy==1.13.1->torch>=2.0.0->torchmetrics) (1.3.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in ./.venv/lib/python3.11/site-packages (from jinja2->torch>=2.0.0->torchmetrics) (3.0.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./.venv/lib/python3.11/site-packages (from requests->transformers) (3.4.1)\n", + "Requirement already satisfied: idna<4,>=2.5 in ./.venv/lib/python3.11/site-packages (from requests->transformers) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./.venv/lib/python3.11/site-packages (from requests->transformers) (2.3.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./.venv/lib/python3.11/site-packages (from requests->transformers) (2025.1.31)\n" + ] + } + ], "source": [ "# Install dependencies\n", "!pip install pyfaidx pyBigWig torchmetrics transformers" @@ -35,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -97,14 +146,25 @@ "### General\n", "- **`seed`**: Random seed for reproducibility\n", "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n", - "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)" + "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)\n", + "\n", + "### Reproducing the results reported in the paper\n", + "This notebook is configured to train for 15k steps, which is ~2B tokens. 
To reproduce the results reported in the paper, you should train for 120k steps (~15B tokens).\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], "source": [ "config = {\n", " # Model\n", @@ -157,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -275,9 +335,178 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9128a6c975214f6c88ec724605e91a6b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 37 files: 0%| | 0/37 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "# Plot training results\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", @@ -1071,9 +3481,78 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running test evaluation with 2500 steps (10000 samples)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Test evaluation: 100%|██████████| 2500/2500 [11:09<00:00, 3.74it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==================================================\n", + "Test Set Results\n", + "==================================================\n", + "\n", + "Metrics:\n", + " Mean Pearson: 0.5413\n", + " ENCSR154HRN_P/pearson: 0.4806\n", + " ENCSR701YIC/pearson: 0.5618\n", + " ENCSR935RNW_P/pearson: 0.4262\n", + " ENCSR100LIJ_P/pearson: 0.4342\n", + " ENCSR527JGN_P/pearson: 0.6413\n", + " ENCSR114HGS_P/pearson: 0.2809\n", + " ENCSR484LTQ_M/pearson: 0.5097\n", + " 
ENCSR935RNW_M/pearson: 0.4451\n", + " ENCSR410DWV/pearson: 0.7197\n", + " ENCSR046BCI_M/pearson: 0.4100\n", + " ENCSR814RGG/pearson: 0.7306\n", + " ENCSR799DGV_P/pearson: 0.4292\n", + " ENCSR527JGN_M/pearson: 0.5838\n", + " ENCSR154HRN_M/pearson: 0.4915\n", + " ENCSR862QCH_M/pearson: 0.5296\n", + " ENCSR100LIJ_M/pearson: 0.4482\n", + " ENCSR321PWZ_P/pearson: 0.6399\n", + " ENCSR484LTQ_P/pearson: 0.4387\n", + " ENCSR619DQO_P/pearson: 0.6396\n", + " ENCSR325NFE/pearson: 0.7549\n", + " ENCSR249ROI_P/pearson: 0.5665\n", + " ENCSR249ROI_M/pearson: 0.5497\n", + " ENCSR754DRC/pearson: 0.4786\n", + " ENCSR321PWZ_M/pearson: 0.6713\n", + " ENCSR862QCH_P/pearson: 0.4601\n", + " ENCSR046BCI_P/pearson: 0.3657\n", + " ENCSR799DGV_M/pearson: 0.4400\n", + " ENCSR962OTG/pearson: 0.8034\n", + " ENCSR682BFG/pearson: 0.6500\n", + " ENCSR628PLS/pearson: 0.6073\n", + " ENCSR619DQO_M/pearson: 0.5792\n", + " ENCSR487QSB/pearson: 0.6684\n", + " ENCSR114HGS_M/pearson: 0.2939\n", + " ENCSR863PSM/pearson: 0.6735\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Calculate number of test steps (based on deepspeed pipeline)\n", "num_test_samples = len(test_dataset)\n", @@ -1108,7 +3587,44 @@ "source": [ " ## Test set results\n", "\n", - "#TODO: Add test set results after run!" 
+ "Mean Pearson: 0.5413\n", + "\n", + "- ENCSR154HRN_P/pearson: 0.4806\n", + "- ENCSR701YIC/pearson: 0.5618\n", + "- ENCSR935RNW_P/pearson: 0.4262\n", + "- ENCSR100LIJ_P/pearson: 0.4342\n", + "- ENCSR527JGN_P/pearson: 0.6413\n", + "- ENCSR114HGS_P/pearson: 0.2809\n", + "- ENCSR484LTQ_M/pearson: 0.5097\n", + "- ENCSR935RNW_M/pearson: 0.4451\n", + "- ENCSR410DWV/pearson: 0.7197\n", + "- ENCSR046BCI_M/pearson: 0.4100\n", + "- ENCSR814RGG/pearson: 0.7306\n", + "- ENCSR799DGV_P/pearson: 0.4292\n", + "- ENCSR527JGN_M/pearson: 0.5838\n", + "- ENCSR154HRN_M/pearson: 0.4915\n", + "- ENCSR862QCH_M/pearson: 0.5296\n", + "- ENCSR100LIJ_M/pearson: 0.4482\n", + "- ENCSR321PWZ_P/pearson: 0.6399\n", + "- ENCSR484LTQ_P/pearson: 0.4387\n", + "- ENCSR619DQO_P/pearson: 0.6396\n", + "- ENCSR325NFE/pearson: 0.7549\n", + "- ENCSR249ROI_P/pearson: 0.5665\n", + "- ENCSR249ROI_M/pearson: 0.5497\n", + "- ENCSR754DRC/pearson: 0.4786\n", + "- ENCSR321PWZ_M/pearson: 0.6713\n", + "- ENCSR862QCH_P/pearson: 0.4601\n", + "- ENCSR046BCI_P/pearson: 0.3657\n", + "- ENCSR799DGV_M/pearson: 0.4400\n", + "- ENCSR962OTG/pearson: 0.8034\n", + "- ENCSR682BFG/pearson: 0.6500\n", + "- ENCSR628PLS/pearson: 0.6073\n", + "- ENCSR619DQO_M/pearson: 0.5792\n", + "- ENCSR487QSB/pearson: 0.6684\n", + "- ENCSR114HGS_M/pearson: 0.2939\n", + "- ENCSR863PSM/pearson: 0.6735\n", + "\n", + "NOTE: This notebook trains for only ~2B tokens; the performance reported in the paper is reached after ~15B tokens (~8x more training than this notebook)." ] }, {