Commit cdc7a28 · 1 Parent(s): 13a5a11
feat: update notebooks
notebooks_tutorials/{02_fine_tuning.ipynb → 02_fine_tuning_pretrained_model.ipynb}
RENAMED
@@ -4,9 +4,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# 🧬 Fine-Tuning a Model on BigWig Tracks Prediction\n",
+"# 🧬 Fine-Tuning a Pre-Trained Model on BigWig Tracks Prediction\n",
 "\n",
-"This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
+"This notebook demonstrates a **simplified fine-tuning setup** that enables training of a **pre-trained Nucleotide Transformer v3 (NTv3) model** to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
 "\n",
 "📊 We provide access to the NTv3-benchmark data that we released on our Hugging Face dataset: `InstaDeepAI/NTv3_benchmark_dataset`. In this repository, you will find ready-to-use genome FASTA files, Bigwig tracks, metadata, but also the splits that were used for the benchmark.\n",
 "\n",

@@ -15,7 +15,7 @@
 "- **Constant learning rate**: Uses a fixed learning rate throughout training without learning rate scheduling\n",
 "- **No gradient accumulation**: Implements simple step-based training without gradient accumulation, making the training loop more straightforward\n",
 "\n",
-"**⚡ Key Advantage**: This simplified pipeline achieves close performance to more complex training approaches while enabling fast fine-tuning: on a H100 GPU and using 16 workers for data loading, it takes ~15min to reach acceptable performances for a 32kb functional tracks prediction task on **NTv3_8M_pre** model. The training speed benefits from the efficient NTv3 model architecture, but of course depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time)"
+"**⚡ Key Advantage**: This simplified pipeline achieves close performance to more complex training approaches while enabling fast fine-tuning: on a H100 GPU and using 16 workers for data loading, it takes ~15min to reach acceptable performances for a 32kb functional tracks prediction task on **NTv3_8M_pre** model. The training speed benefits from the efficient NTv3 model architecture, but of course depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time)."
 ]
 },
 {
notebooks_tutorials/03_fine_tuning_posttrained_model.ipynb
ADDED
@@ -0,0 +1,1141 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# 🧬 Fine-Tuning a Post-trained Model on BigWig Tracks Prediction\n",
+"\n",
+"This notebook demonstrates a **simplified fine-tuning setup** that enables training of a **post-trained Nucleotide Transformer v3 (NTv3) model** to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
+"\n",
+"**🎯 Notebook purpose:**\n",
+"This notebook is configured to train the `NTv3_650M_post` model on the `human` species from the NTv3 benchmark dataset. To run this training, you will need a large GPU (either A100 or H100).\n",
+"For a simplified version of this notebook that uses the `NTv3_8M_pre` model and runs on a CPU, please see the [02_fine_tuning_pretrained_model.ipynb](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning_pretrained_model.ipynb) notebook.\n",
+"This notebook uses the same \"simplified setup\" described there.\n",
+"\n",
+"📝 Note for Google Colab users: This notebook is compatible with Colab, but it is designed to run on a high-performance GPU. The default parameters can be used with an H100 with 80GB of HBM."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# 0. 📦 Import dependencies"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Install dependencies\n",
+"!pip install pyfaidx pyBigWig torchmetrics transformers"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import random\n",
+"import functools\n",
+"from typing import List, Dict, Callable\n",
+"import os\n",
+"import fnmatch\n",
+"from pathlib import Path\n",
+"from huggingface_hub import HfApi, snapshot_download\n",
+"\n",
+"import torch\n",
+"import torch.nn as nn\n",
+"import torch.nn.functional as F\n",
+"from torch.utils.data import Dataset, DataLoader\n",
+"from torch.optim import AdamW\n",
+"from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
+"import pandas as pd\n",
+"import matplotlib.pyplot as plt\n",
+"import numpy as np\n",
+"import pyBigWig\n",
+"from pyfaidx import Fasta\n",
+"from torchmetrics import PearsonCorrCoef\n",
+"from tqdm import tqdm"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# 1. ⚙️ Configuration\n",
+"\n",
+"## Configuration Parameters\n",
+"\n",
+"### Model\n",
+"- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
+"\n",
+"### Data\n",
+"- **`hf_repo_id`**: HuggingFace dataset repository ID containing the benchmark data\n",
+"- **`species_name`**: Species name (e.g., \"human\") to select data from the benchmark dataset\n",
+"- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
+"- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
+"- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
+"\n",
+"### Training\n",
+"- **`batch_size`**: Number of samples per batch\n",
+"- **`learning_rate`**: Constant learning rate for optimizer\n",
+"- **`weight_decay`**: L2 regularization coefficient for optimizer\n",
+"- **`num_steps_training`**: Total number of training steps\n",
+"- **`log_every_n_steps`**: Log training metrics every N steps\n",
+"\n",
+"### Validation\n",
+"- **`validate_every_n_steps`**: Run validation every N steps\n",
+"- **`num_validation_samples`**: Number of samples to use for validation set\n",
+"\n",
+"### Test\n",
+"- **`num_test_samples`**: Number of samples to use for test set evaluation\n",
+"\n",
+"### General\n",
+"- **`seed`**: Random seed for reproducibility\n",
+"- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
+"- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"config = {\n",
+"    # Model\n",
+"    \"model_name\": \"InstaDeepAI/NTv3_650M_post\",\n",
+"\n",
+"    # Data\n",
+"    \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+"    \"species_name\": \"human\",\n",
+"    \"data_cache_dir\": \"./data\",\n",
+"    \"sequence_length\": 32_768,\n",
+"    \"keep_target_center_fraction\": 0.375,\n",
+"\n",
+"    # Training\n",
+"    \"batch_size\": 4,\n",
+"    \"num_steps_training\": 15000,  # ~2B tokens\n",
+"    \"log_every_n_steps\": 40,\n",
+"    \"learning_rate\": 1e-5,\n",
+"    \"weight_decay\": 0.01,\n",
+"\n",
+"    # Validation\n",
+"    \"validate_every_n_steps\": 400,\n",
+"    \"num_validation_samples\": 1000,\n",
+"\n",
+"    # Test\n",
+"    \"num_test_samples\": 10000,\n",
+"\n",
+"    # General\n",
+"    \"seed\": 0,\n",
+"    \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
+"    \"num_workers\": 16,\n",
+"}\n",
+"\n",
+"# Set random seed\n",
+"torch.manual_seed(config[\"seed\"])\n",
+"np.random.seed(config[\"seed\"])\n",
+"\n",
+"# Set device\n",
+"device = torch.device(config[\"device\"])\n",
+"print(f\"Using device: {device}\")"
+]
+},
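As a sanity check on these settings, the center-crop arithmetic implied by `sequence_length` and `keep_target_center_fraction` can be worked out directly. This is a standalone illustration; the notebook computes the same quantities inside its `crop_center` helper:

```python
# Reproduce the center-crop arithmetic from the config values above.
sequence_length = 32_768
keep_target_center_fraction = 0.375

# Bases dropped on each side, then the retained central window.
target_offset = int(sequence_length * (1 - keep_target_center_fraction) // 2)
target_length = sequence_length - 2 * target_offset

print(target_offset, target_length)  # 10240 12288
```

So with the defaults, the model sees 32,768 bp of context but is only asked to predict tracks for the central 12,288 bp, avoiding edge effects at the window boundaries.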
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# 2. 📥 Genome & Tracks Data Download\n",
+"\n",
+"Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"def prepare_genomics_inputs(\n",
+"    species: str,\n",
+"    data_cache_dir: str | Path = \"data\",\n",
+"    hf_repo_id: str = \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+"    bigwig_file_ids: list[str] | None = None,\n",
+") -> tuple[str, list[str], list[str], pd.DataFrame, pd.DataFrame]:\n",
+"    \"\"\"\n",
+"    Downloads:\n",
+"    1) FASTA from HF dataset under: <species>/genome.fasta\n",
+"    2) BigWigs from HF dataset under: <species>/functional_tracks/**\n",
+"       (filtered by bigwig_file_ids if provided)\n",
+"    3) Splits from HF dataset under: <species>/splits.bed\n",
+"    4) Metadata from HF dataset under: benchmark_metadata.tsv\n",
+"\n",
+"    Args:\n",
+"        species: Species name (e.g., \"human\", \"arabidopsis\")\n",
+"        data_cache_dir: Directory where downloaded data files will be stored\n",
+"        hf_repo_id: HuggingFace dataset repository ID\n",
+"        bigwig_file_ids: Optional list of BigWig file IDs to download. If None,\n",
+"            downloads all available BigWig files for the species.\n",
+"\n",
+"    Returns:\n",
+"        (fasta_path, bigwig_paths, bigwig_ids, splits_df, metadata_df)\n",
+"    \"\"\"\n",
+"    cache = Path(data_cache_dir).expanduser().resolve()\n",
+"    cache.mkdir(parents=True, exist_ok=True)\n",
+"\n",
+"    # --- Download metadata + <species> files (FASTA, BigWigs, Splits) ---\n",
+"    metadata_file = \"benchmark_metadata.tsv\"\n",
+"    download_patterns = [metadata_file, f\"{species}/genome.fasta\", f\"{species}/splits.bed\"]\n",
+"\n",
+"    if bigwig_file_ids is not None:\n",
+"        # List files to validate requested BigWig files exist\n",
+"        api = HfApi()\n",
+"        files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
+"        species_pattern = f\"{species}/**\"\n",
+"        species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
+"\n",
+"        # Get all available BigWig file IDs and their paths\n",
+"        available_bigwig_files = {\n",
+"            Path(p).stem: p for p in species_files\n",
+"            if Path(p).suffix == \".bigwig\"\n",
+"        }\n",
+"\n",
+"        # Check that all requested files exist\n",
+"        missing_files = set(bigwig_file_ids) - set(available_bigwig_files.keys())\n",
+"        if missing_files:\n",
+"            raise ValueError(\n",
+"                f\"Requested BigWig files not found: {missing_files}. \"\n",
+"                f\"Available files: {list(available_bigwig_files.keys())}\"\n",
+"            )\n",
+"\n",
+"        # Add specific patterns for requested BigWig files only\n",
+"        for file_id in bigwig_file_ids:\n",
+"            download_patterns.append(available_bigwig_files[file_id])\n",
+"    else:\n",
+"        # Download all BigWig files\n",
+"        download_patterns.append(f\"{species}/functional_tracks/*.bigwig\")\n",
+"    local_dir = Path(\n",
+"        snapshot_download(\n",
+"            repo_id=hf_repo_id,\n",
+"            repo_type=\"dataset\",\n",
+"            allow_patterns=download_patterns,\n",
+"            local_dir=str(cache),\n",
+"        )\n",
+"    )\n",
+"\n",
+"    # --- Organize outputs ---\n",
+"    # FASTA file\n",
+"    fasta_path_repo = f\"{species}/genome.fasta\"\n",
+"    fasta_path = str(local_dir / fasta_path_repo)\n",
+"\n",
+"    # BigWig files - use downloaded files directly\n",
+"    bigwig_dir = local_dir / species / \"functional_tracks\"\n",
+"\n",
+"    if bigwig_file_ids is not None:\n",
+"        bigwig_paths = [str(bigwig_dir / f\"{file_id}.bigwig\") for file_id in bigwig_file_ids]\n",
+"        bigwig_ids = bigwig_file_ids\n",
+"    else:\n",
+"        # Find all downloaded BigWig files\n",
+"        bigwig_paths = [str(bigwig_file) for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
+"        bigwig_ids = [bigwig_file.stem for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
+"\n",
+"    # Splits file\n",
+"    splits_path_repo = f\"{species}/splits.bed\"\n",
+"    splits_path = local_dir / splits_path_repo\n",
+"\n",
+"    splits_df = pd.read_csv(\n",
+"        splits_path,\n",
+"        sep=\"\\t\",\n",
+"        header=None,\n",
+"        names=[\"chr_name\", \"start\", \"end\", \"split\"],\n",
+"        dtype={\"chr_name\": str, \"start\": int, \"end\": int, \"split\": str},\n",
+"    )\n",
+"\n",
+"    # Metadata file\n",
+"    metadata_path = local_dir / metadata_file\n",
+"    metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
+"\n",
+"    # Filter metadata according to species\n",
+"    metadata_df = metadata_df[metadata_df[\"species_name\"] == species].reset_index(drop=True)\n",
+"\n",
+"    # Order metadata according to bigwig file ids\n",
+"    metadata_df = (\n",
+"        metadata_df.set_index(\"file_id\")\n",
+"        .loc[bigwig_ids]\n",
+"        .reset_index()\n",
+"    )\n",
+"\n",
+"    return fasta_path, bigwig_paths, bigwig_ids, splits_df, metadata_df"
+]
+},
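The repo-listing filter in the cell above leans on two small behaviors worth noting: `fnmatch` patterns are not slash-aware (so `human/**` simply matches any path starting with `human/`), and `Path.stem` drops only the final suffix, turning a track path into its file ID. A quick standalone check, using made-up file names that mimic the benchmark dataset layout:

```python
import fnmatch
from pathlib import Path

# Hypothetical repo listing for illustration (not real dataset contents).
files = [
    "benchmark_metadata.tsv",
    "human/genome.fasta",
    "human/functional_tracks/atac_liver.bigwig",
    "human/functional_tracks/chip_h3k27ac.bigwig",
    "mouse/genome.fasta",
]

# fnmatch's '*' matches '/' as well, so 'human/**' keeps every path under human/.
species_files = [p for p in files if fnmatch.fnmatch(p, "human/**")]

# Map file IDs (stems) to full repo paths, keeping only the .bigwig tracks.
available = {Path(p).stem: p for p in species_files if Path(p).suffix == ".bigwig"}
print(sorted(available))  # ['atac_liver', 'chip_h3k27ac']
```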
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
+"\n",
+"# Download all species files + load the splits and metadata\n",
+"(\n",
+"    fasta_path,\n",
+"    bigwig_paths,\n",
+"    bigwig_ids,\n",
+"    species_splits_df,\n",
+"    metadata_df,\n",
+") = prepare_genomics_inputs(\n",
+"    config[\"species_name\"],\n",
+"    config[\"data_cache_dir\"],\n",
+"    config[\"hf_repo_id\"]\n",
+")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# 3. 🧠 Model and tokenizer setup\n",
+"\n",
+"In this section, we set up the model and tokenizer.\n",
+"\n",
+"Our approach uses any suitable pretrained backbone from HuggingFace Transformers (for example, `InstaDeepAI/ntv3_650M_pre`), which is then extended with an additional linear head.\n",
+"\n",
+"This linear head is trained for regression on a set of genomic tracks, allowing the model to make predictions for each track at single-nucleotide resolution.\n",
+"\n",
+"The following code wraps the HuggingFace model together with this regression head for the end-to-end task.\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"def crop_center(x: np.ndarray, keep_target_center_fraction: float = 0.375) -> np.ndarray:\n",
+"    \"\"\"Crop the central sequence-length fraction for arrays of size (..., seq_len, num_tracks).\"\"\"\n",
+"    seq_len = x.shape[-2]\n",
+"    target_offset = int(seq_len * (1 - keep_target_center_fraction) // 2)\n",
+"    target_length = seq_len - 2 * target_offset\n",
+"    return x[..., target_offset:target_offset + target_length, :]\n",
+"\n",
+"\n",
+"class LinearHead(nn.Module):\n",
+"    \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n",
+"\n",
+"    def __init__(self, embed_dim: int, num_labels: int):\n",
+"        super().__init__()\n",
+"        self.layer_norm = nn.LayerNorm(embed_dim)\n",
+"        self.head = nn.Linear(embed_dim, num_labels)\n",
+"\n",
+"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+"        x = self.layer_norm(x)\n",
+"        x = self.head(x)\n",
+"        x = F.softplus(x)  # Ensure positive values\n",
+"        return x\n",
+"\n",
+"\n",
+"class HFModelWithHead(nn.Module):\n",
+"    \"\"\"Simple model wrapper: HF backbone + bigwig head.\"\"\"\n",
+"\n",
+"    def __init__(\n",
+"        self,\n",
+"        model_name: str,\n",
+"        bigwig_track_names: List[str],\n",
+"        species_str: str,\n",
+"        keep_target_center_fraction: float = 0.375,\n",
+"    ):\n",
+"        super().__init__()\n",
+"\n",
+"        # Load config and model\n",
+"        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
+"        self.backbone = AutoModel.from_pretrained(\n",
+"            model_name,\n",
+"            trust_remote_code=True,\n",
+"            config=self.config,\n",
+"        )\n",
+"        if species_str in self.backbone.supported_species:\n",
+"            self.species_ids = self.backbone.encode_species(species_str)\n",
+"            print(f\"Using species: {species_str} with ids: {self.species_ids}\")\n",
+"        else:\n",
+"            # Mask token id\n",
+"            print(f\"{species_str} not in supported species, using mask token id\")\n",
+"            self.species_ids = torch.LongTensor([2])\n",
+"\n",
+"        self.keep_target_center_fraction = keep_target_center_fraction\n",
+"\n",
+"        # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n",
+"        self.bigwig_head = LinearHead(self.config.embed_dim, len(bigwig_track_names))\n",
+"        self.model_name = model_name\n",
+"\n",
+"    def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
+"        # Forward through backbone\n",
+"        species_tokens = torch.repeat_interleave(self.species_ids, tokens.shape[0])\n",
+"        species_tokens = species_tokens.to(tokens.device)\n",
+"        outputs = self.backbone(input_ids=tokens, species_ids=species_tokens, output_hidden_states=True)\n",
+"        embedding = outputs.hidden_states[-1]  # Last hidden state\n",
+"\n",
+"        # Crop to center fraction\n",
+"        if self.keep_target_center_fraction < 1.0:\n",
+"            embedding = crop_center(embedding, self.keep_target_center_fraction)\n",
+"\n",
+"        # Predict bigwig tracks\n",
+"        bigwig_logits = self.bigwig_head(embedding)\n",
+"\n",
+"        return {\"bigwig_tracks_logits\": bigwig_logits}"
+]
+},
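To see what `crop_center` keeps, here is a standalone run on a dummy array (the helper is restated verbatim so the snippet runs on its own): with a length-16 sequence axis and a 0.5 keep fraction, 4 positions are dropped on each side.

```python
import numpy as np

def crop_center(x: np.ndarray, keep_target_center_fraction: float = 0.375) -> np.ndarray:
    # Same helper as in the cell above: keep the central fraction of axis -2.
    seq_len = x.shape[-2]
    target_offset = int(seq_len * (1 - keep_target_center_fraction) // 2)
    target_length = seq_len - 2 * target_offset
    return x[..., target_offset:target_offset + target_length, :]

x = np.arange(16).reshape(1, 16, 1)   # (batch, seq_len, num_tracks)
y = crop_center(x, keep_target_center_fraction=0.5)

print(y.shape)      # (1, 8, 1)
print(y[0, :, 0])   # [ 4  5  6  7  8  9 10 11]
```

The same slicing applies unchanged to the model's `(batch, seq_len, embed_dim)` embeddings, since only axis -2 is cropped.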
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Load tokenizer\n",
+"tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
+"\n",
+"# Create model\n",
+"model = HFModelWithHead(\n",
+"    model_name=config[\"model_name\"],\n",
+"    bigwig_track_names=bigwig_ids,\n",
+"    species_str=config[\"species_name\"],\n",
+"    keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
+")\n",
+"model = model.to(device)\n",
+"model.train()\n",
+"\n",
+"print(f\"Model loaded: {config['model_name']}\")\n",
+"print(f\"Number of bigwig tracks: {len(bigwig_ids)}\")\n",
+"print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# 4. 🔄 Data loading\n",
+"\n",
+"Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
+]
+},
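The dataset cell below opens FASTA and BigWig files through process-local handle caches, so each DataLoader worker gets its own file handles rather than sharing one across forked processes. Reduced to its core, the pattern looks like this (names here are illustrative, not from the notebook):

```python
import os

_handle_cache: dict = {}

def get_handle(path: str, opener):
    # Key on (pid, path): after a DataLoader worker forks, os.getpid() changes,
    # so each worker opens its own handle instead of reusing the parent's.
    key = (os.getpid(), path)
    if key not in _handle_cache:
        _handle_cache[key] = opener(path)
    return _handle_cache[key]

# Within one process, repeated calls return the same cached handle.
h1 = get_handle("/tmp/example.bigwig", lambda p: object())
h2 = get_handle("/tmp/example.bigwig", lambda p: object())
print(h1 is h2)  # True
```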
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Process-local cache for file handles (one per worker process)\n",
+"# This allows safe multi-worker DataLoader usage\n",
+"_fasta_cache = {}   # Maps (process_id, file_path) -> Fasta handle\n",
+"_bigwig_cache = {}  # Maps (process_id, file_path) -> pyBigWig handle\n",
+"\n",
+"\n",
+"def _get_fasta_handle(fasta_path: str) -> Fasta:\n",
+"    \"\"\"Get or create a FASTA file handle for the current process.\"\"\"\n",
+"    process_id = os.getpid()\n",
+"    abs_path = str(Path(fasta_path).resolve())\n",
+"    cache_key = (process_id, abs_path)\n",
+"\n",
+"    if cache_key not in _fasta_cache:\n",
+"        _fasta_cache[cache_key] = Fasta(abs_path, as_raw=True, sequence_always_upper=True)\n",
+"\n",
+"    return _fasta_cache[cache_key]\n",
+"\n",
+"\n",
+"def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
+"    \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
+"    process_id = os.getpid()\n",
+"    abs_path = str(Path(bigwig_path).resolve())\n",
+"    cache_key = (process_id, abs_path)\n",
+"\n",
+"    if cache_key not in _bigwig_cache:\n",
+"        # Check if file exists before trying to open\n",
+"        if not Path(abs_path).exists():\n",
+"            raise FileNotFoundError(\n",
+"                f\"BigWig file not found: {abs_path}\\n\"\n",
+"                f\"Original path: {bigwig_path}\\n\"\n",
+"                f\"Current working directory: {os.getcwd()}\"\n",
+"            )\n",
+"\n",
+"        try:\n",
+"            _bigwig_cache[cache_key] = pyBigWig.open(abs_path)\n",
+"        except Exception as e:\n",
+"            raise RuntimeError(\n",
+"                f\"Failed to open BigWig file: {abs_path} with error: {str(e)}\\n\"\n",
+"                f\"File exists: {Path(abs_path).exists()}\\n\"\n",
+"                f\"File size: {Path(abs_path).stat().st_size if Path(abs_path).exists() else 'N/A'} bytes\"\n",
+"            ) from e\n",
+"\n",
+"    return _bigwig_cache[cache_key]\n",
+"\n",
+"\n",
+"class GenomeBigWigDataset(Dataset):\n",
+"    \"\"\"\n",
+"    A PyTorch dataset to access a reference genome and bigwig tracks. The dataset is\n",
+"    compatible with multi-worker DataLoaders (using process-local file handles and lazy\n",
+"    loading). For each sample, a random genomic region is picked from the specified split,\n",
+"    and a random window of length `sequence_length` within that region is returned.\n",
+"    \"\"\"\n",
+"\n",
+"    def __init__(\n",
+"        self,\n",
+"        fasta_path: str,\n",
+"        bigwig_path_list: list[str],\n",
+"        chrom_regions: pd.DataFrame,\n",
+"        split: str,\n",
+
" sequence_length: int,\n",
|
| 492 |
+
" num_samples: int,\n",
|
| 493 |
+
" tokenizer: AutoTokenizer,\n",
|
| 494 |
+
" transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
|
| 495 |
+
" keep_target_center_fraction: float = 1.0,\n",
|
| 496 |
+
" ):\n",
|
| 497 |
+
" super().__init__()\n",
|
| 498 |
+
"\n",
|
| 499 |
+
" # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
|
| 500 |
+
" self.fasta_path = fasta_path\n",
|
| 501 |
+
" self.bigwig_path_list = bigwig_path_list\n",
|
| 502 |
+
" self.sequence_length = sequence_length\n",
|
| 503 |
+
" self.num_samples = num_samples\n",
|
| 504 |
+
" self.tokenizer = tokenizer\n",
|
| 505 |
+
" self.transform_fn = transform_fn\n",
|
| 506 |
+
" self.keep_target_center_fraction = keep_target_center_fraction\n",
|
| 507 |
+
" self.chrom_regions = chrom_regions\n",
|
| 508 |
+
"\n",
|
| 509 |
+
" # Filter regions by split\n",
|
| 510 |
+
" split_regions = self.chrom_regions[self.chrom_regions[\"split\"] == split].copy()\n",
|
| 511 |
+
"\n",
|
| 512 |
+
" # Filter valid regions (must be large enough for sequence_length)\n",
|
| 513 |
+
" self.valid_regions = []\n",
|
| 514 |
+
" for _, row in split_regions.iterrows():\n",
|
| 515 |
+
"\n",
|
| 516 |
+
" region_length = row.end - row.start\n",
|
| 517 |
+
" if region_length < self.sequence_length:\n",
|
| 518 |
+
" continue\n",
|
| 519 |
+
" \n",
|
| 520 |
+
" # Store valid region\n",
|
| 521 |
+
" self.valid_regions.append((row.chr_name, row.start, row.end))\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" def __len__(self):\n",
|
| 524 |
+
" return self.num_samples\n",
|
| 525 |
+
"\n",
|
| 526 |
+
" def __getitem__(self, idx):\n",
|
| 527 |
+
" # Sample a random region from the valid regions\n",
|
| 528 |
+
" chrom, region_start, region_end = random.choice(self.valid_regions)\n",
|
| 529 |
+
" \n",
|
| 530 |
+
" # Sample a random window within this region\n",
|
| 531 |
+
" max_start = region_end - self.sequence_length\n",
|
| 532 |
+
" start = random.randint(region_start, max_start)\n",
|
| 533 |
+
" end = start + self.sequence_length\n",
|
| 534 |
+
"\n",
|
| 535 |
+
" # Sequence - get FASTA handle lazily (cached per worker process)\n",
|
| 536 |
+
" fasta = _get_fasta_handle(self.fasta_path)\n",
|
| 537 |
+
" seq = fasta[chrom][start:end] # string slice\n",
|
| 538 |
+
" # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
|
| 539 |
+
" tokenized = self.tokenizer(\n",
|
| 540 |
+
" seq,\n",
|
| 541 |
+
" padding=\"max_length\",\n",
|
| 542 |
+
" truncation=True,\n",
|
| 543 |
+
" max_length=self.sequence_length,\n",
|
| 544 |
+
" return_tensors=\"pt\",\n",
|
| 545 |
+
" )\n",
|
| 546 |
+
" tokens = tokenized[\"input_ids\"][0] # Shape: (max_length,)\n",
|
| 547 |
+
"\n",
|
| 548 |
+
" # Signal from bigWig tracks (numpy array) -> torch tensor\n",
|
| 549 |
+
" # Get BigWig handles lazily (cached per worker process)\n",
|
| 550 |
+
" bigwig_targets = np.array([\n",
|
| 551 |
+
" _get_bigwig_handle(bw_path).values(chrom, start, end, numpy=True)\n",
|
| 552 |
+
" for bw_path in self.bigwig_path_list\n",
|
| 553 |
+
" ]) # shape (num_tracks, seq_len)\n",
|
| 554 |
+
" # Transpose to (seq_len, num_tracks)\n",
|
| 555 |
+
" bigwig_targets = bigwig_targets.T\n",
|
| 556 |
+
" # pyBigWig returns NaN where no data; turn NaN into 0\n",
|
| 557 |
+
" bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
|
| 558 |
+
" bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
|
| 559 |
+
" \n",
|
| 560 |
+
" # Crop targets to center fraction\n",
|
| 561 |
+
" if self.keep_target_center_fraction < 1.0:\n",
|
| 562 |
+
" bigwig_targets = crop_center(bigwig_targets, self.keep_target_center_fraction)\n",
|
| 563 |
+
"\n",
|
| 564 |
+
" # Apply scaling to targets\n",
|
| 565 |
+
" bigwig_targets = self.transform_fn(bigwig_targets)\n",
|
| 566 |
+
"\n",
|
| 567 |
+
" sample = {\n",
|
| 568 |
+
" \"tokens\": tokens,\n",
|
| 569 |
+
" \"bigwig_targets\": bigwig_targets,\n",
|
| 570 |
+
" \"chrom\": chrom,\n",
|
| 571 |
+
" \"start\": start,\n",
|
| 572 |
+
" \"end\": end,\n",
|
| 573 |
+
" }\n",
|
| 574 |
+
" return sample"
|
| 575 |
+
]
|
| 576 |
+
},
|
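The process-keyed caching pattern above generalises beyond FASTA/BigWig handles. A minimal standalone sketch, with a stand-in object instead of a real file handle:

```python
import os

_cache = {}  # maps (process_id, path) -> handle

def get_handle(path: str):
    # Keying on os.getpid() gives each DataLoader worker process its own handle,
    # since open file handles generally cannot be shared across forked workers.
    key = (os.getpid(), path)
    if key not in _cache:
        _cache[key] = object()  # stand-in for Fasta(...) or pyBigWig.open(...)
    return _cache[key]

print(get_handle("genome.fa") is get_handle("genome.fa"))  # True: handle is reused
```

Within one worker the handle is opened once and reused; a different worker process sees a different `os.getpid()` and opens its own.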
| 577 |
+
{
|
| 578 |
+
"cell_type": "markdown",
|
| 579 |
+
"metadata": {},
|
| 580 |
+
"source": [
|
| 581 |
+
"### Data preprocessing utilities"
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
{
|
| 585 |
+
"cell_type": "code",
|
| 586 |
+
"execution_count": null,
|
| 587 |
+
"metadata": {},
|
| 588 |
+
"outputs": [],
|
| 589 |
+
"source": [
|
| 590 |
+
"def create_targets_scaling_fn(\n",
|
| 591 |
+
" metadata_df: pd.DataFrame\n",
|
| 592 |
+
") -> Callable[[torch.Tensor], torch.Tensor]:\n",
|
| 593 |
+
" \"\"\"\n",
|
| 594 |
+
" Build a scaling function that uses the track means to normalise and softclip the targets.\n",
|
| 595 |
+
" \"\"\"\n",
|
| 596 |
+
" # Open bigwig files and compute track statistics\n",
|
| 597 |
+
" track_means = metadata_df[\"mean\"].to_numpy()\n",
|
| 598 |
+
" print(f\"Track means: {track_means}\")\n",
|
| 599 |
+
" print(f\"Number of tracks: {track_means.shape}\")\n",
|
| 600 |
+
"\n",
|
| 601 |
+
" # Create tensor from computed means\n",
|
| 602 |
+
" track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
|
| 603 |
+
"\n",
|
| 604 |
+
" def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
|
| 605 |
+
" # Move constants to correct device then normalize\n",
|
| 606 |
+
" means = track_means_tensor.to(x.device)\n",
|
| 607 |
+
" scaled = x / means\n",
|
| 608 |
+
"\n",
|
| 609 |
+
" # Smooth clipping: if > 10, apply formula\n",
|
| 610 |
+
" clipped = torch.where(\n",
|
| 611 |
+
" scaled > 10.0,\n",
|
| 612 |
+
" 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
|
| 613 |
+
" scaled,\n",
|
| 614 |
+
" )\n",
|
| 615 |
+
" return clipped\n",
|
| 616 |
+
"\n",
|
| 617 |
+
" return transform_fn"
|
| 618 |
+
]
|
| 619 |
+
},
|
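The soft-clipping branch can be checked on scalars. This is a standalone re-implementation for illustration; the notebook's `transform_fn` applies the same formula elementwise with `torch.where`:

```python
import math

def soft_clip(x: float, threshold: float = 10.0) -> float:
    # Values above `threshold` are compressed to 2*sqrt(x*threshold) - threshold;
    # at x == threshold both branches agree (2*sqrt(100) - 10 == 10), so the
    # transform is continuous.
    return 2.0 * math.sqrt(x * threshold) - threshold if x > threshold else x

print([soft_clip(v) for v in (0.5, 10.0, 40.0, 1000.0)])  # [0.5, 10.0, 30.0, 190.0]
```

Large outliers are strongly compressed (1000 → 190) while values at or below the threshold pass through unchanged, which keeps extreme coverage peaks from dominating the loss.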
| 620 |
+
{
|
| 621 |
+
"cell_type": "code",
|
| 622 |
+
"execution_count": null,
|
| 623 |
+
"metadata": {},
|
| 624 |
+
"outputs": [],
|
| 625 |
+
"source": [
|
| 626 |
+
"# Create datasets & dataloaders\n",
|
| 627 |
+
"create_dataset_fn = functools.partial(\n",
|
| 628 |
+
" GenomeBigWigDataset,\n",
|
| 629 |
+
" fasta_path=fasta_path,\n",
|
| 630 |
+
" bigwig_path_list=bigwig_paths,\n",
|
| 631 |
+
" chrom_regions=species_splits_df,\n",
|
| 632 |
+
" sequence_length=config[\"sequence_length\"],\n",
|
| 633 |
+
" tokenizer=tokenizer,\n",
|
| 634 |
+
" transform_fn=create_targets_scaling_fn(metadata_df),\n",
|
| 635 |
+
" keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
|
| 636 |
+
")\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"train_dataset = create_dataset_fn(\n",
|
| 639 |
+
" split=\"train\",\n",
|
| 640 |
+
" num_samples=config[\"num_steps_training\"] * config[\"batch_size\"],\n",
|
| 641 |
+
")\n",
|
| 642 |
+
"\n",
|
| 643 |
+
"val_dataset = create_dataset_fn(\n",
|
| 644 |
+
" split=\"val\",\n",
|
| 645 |
+
" num_samples=config[\"num_validation_samples\"],\n",
|
| 646 |
+
")\n",
|
| 647 |
+
"\n",
|
| 648 |
+
"test_dataset = create_dataset_fn(\n",
|
| 649 |
+
" split=\"test\",\n",
|
| 650 |
+
" num_samples=config[\"num_test_samples\"],\n",
|
| 651 |
+
")\n",
|
| 652 |
+
"\n",
|
| 653 |
+
"# Create dataloaders\n",
|
| 654 |
+
"train_loader = DataLoader(\n",
|
| 655 |
+
" train_dataset,\n",
|
| 656 |
+
" batch_size=config[\"batch_size\"],\n",
|
| 657 |
+
" shuffle=True,\n",
|
| 658 |
+
" num_workers=config[\"num_workers\"],\n",
|
| 659 |
+
")\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"val_loader = DataLoader(\n",
|
| 662 |
+
" val_dataset,\n",
|
| 663 |
+
" batch_size=config[\"batch_size\"],\n",
|
| 664 |
+
" shuffle=False,\n",
|
| 665 |
+
" num_workers=config[\"num_workers\"],\n",
|
| 666 |
+
")\n",
|
| 667 |
+
"\n",
|
| 668 |
+
"test_loader = DataLoader(\n",
|
| 669 |
+
" test_dataset,\n",
|
| 670 |
+
" batch_size=config[\"batch_size\"],\n",
|
| 671 |
+
" shuffle=False,\n",
|
| 672 |
+
" num_workers=config[\"num_workers\"],\n",
|
| 673 |
+
")\n",
|
| 674 |
+
"\n",
|
| 675 |
+
"print(f\"\\nTrain samples: {len(train_dataset)}\")\n",
|
| 676 |
+
"print(f\"Val samples: {len(val_dataset)}\")\n",
|
| 677 |
+
"print(f\"Test samples: {len(test_dataset)}\")"
|
| 678 |
+
]
|
| 679 |
+
},
|
| 680 |
+
{
|
| 681 |
+
"cell_type": "markdown",
|
| 682 |
+
"metadata": {},
|
| 683 |
+
"source": [
|
| 684 |
+
"# 5. ⚙️ Optimizer setup\n",
|
| 685 |
+
"\n",
|
| 686 |
+
"Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function.\n",
|
| 687 |
+
"\n"
|
| 688 |
+
]
|
| 689 |
+
},
|
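For intuition, a conceptual sketch of the decoupled weight decay that distinguishes AdamW from classic Adam with L2 regularisation (single scalar parameter, plain SGD standing in for Adam's moment-based update):

```python
# Decoupled weight decay (AdamW-style): the decay shrinks the parameter
# directly rather than being folded into the gradient.
lr, weight_decay = 1e-3, 0.01
param, grad = 0.5, 0.2

param_decayed = param * (1.0 - lr * weight_decay)  # decay step
param_new = param_decayed - lr * grad              # gradient step
print(param_new)
```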
| 690 |
+
{
|
| 691 |
+
"cell_type": "code",
|
| 692 |
+
"execution_count": null,
|
| 693 |
+
"metadata": {},
|
| 694 |
+
"outputs": [],
|
| 695 |
+
"source": [
|
| 696 |
+
"# Training setup\n",
|
| 697 |
+
"print(f\"Training configuration:\")\n",
|
| 698 |
+
"print(f\" Batch size: {config['batch_size']}\")\n",
|
| 699 |
+
"print(f\" Total training steps: {config['num_steps_training']}\")\n",
|
| 700 |
+
"print(f\" Log metrics every: {config['log_every_n_steps']} steps\")\n",
|
| 701 |
+
"print(f\" Validate every: {config['validate_every_n_steps']} steps\")\n",
|
| 702 |
+
"\n",
|
| 703 |
+
"# Setup optimizer\n",
|
| 704 |
+
"optimizer = AdamW(\n",
|
| 705 |
+
" model.parameters(),\n",
|
| 706 |
+
" lr=config[\"learning_rate\"],\n",
|
| 707 |
+
" weight_decay=config[\"weight_decay\"],\n",
|
| 708 |
+
")\n",
|
| 709 |
+
"\n",
|
| 710 |
+
"print(f\"\\nOptimizer setup:\")\n",
|
| 711 |
+
"print(f\" Learning rate: {config['learning_rate']}\")"
|
| 712 |
+
]
|
| 713 |
+
},
|
| 714 |
+
{
|
| 715 |
+
"cell_type": "markdown",
|
| 716 |
+
"metadata": {},
|
| 717 |
+
"source": [
|
| 718 |
+
"# 6. 📊 Metrics setup\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
|
| 721 |
+
]
|
| 722 |
+
},
|
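For intuition, the per-track Pearson correlation can be computed directly. This plain-Python helper illustrates the same statistic that torchmetrics' `PearsonCorrCoef` accumulates incrementally:

```python
import math

def pearson(xs, ys):
    # Pearson correlation: covariance of the two series divided by the
    # product of their standard deviations.
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

pred = [1.0, 2.0, 3.0, 4.0]
true = [1.1, 1.9, 3.2, 3.8]
print(pearson(pred, true))  # close to 1: predictions track the target profile
```

A value near 1 means the predicted signal rises and falls with the ground truth, regardless of absolute scale, which is why it complements the loss as an evaluation metric.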
| 723 |
+
{
|
| 724 |
+
"cell_type": "code",
|
| 725 |
+
"execution_count": null,
|
| 726 |
+
"metadata": {},
|
| 727 |
+
"outputs": [],
|
| 728 |
+
"source": [
|
| 729 |
+
"class TracksMetrics:\n",
|
| 730 |
+
" \"\"\"Metrics to handle multi-track pearson correlations and losses\"\"\"\n",
|
| 731 |
+
" \n",
|
| 732 |
+
" def __init__(self, track_names: List[str], split: str):\n",
|
| 733 |
+
" self.track_names = track_names\n",
|
| 734 |
+
" self.num_tracks = len(track_names)\n",
|
| 735 |
+
" self.split = split\n",
|
| 736 |
+
"\n",
|
| 737 |
+
" # Initialise metrics \n",
|
| 738 |
+
" self.pearson = PearsonCorrCoef(num_outputs=self.num_tracks).to(device)\n",
|
| 739 |
+
" self.pearson.set_dtype(torch.float64) # Use float64 for improved numerical stability\n",
|
| 740 |
+
" self.losses = []\n",
|
| 741 |
+
"\n",
|
| 742 |
+
" # Record mean metrics per logging interval\n",
|
| 743 |
+
" self.step_idxs = []\n",
|
| 744 |
+
" self.mean_pearsons = []\n",
|
| 745 |
+
" self.mean_losses = []\n",
|
| 746 |
+
" \n",
|
| 747 |
+
" def reset(self):\n",
|
| 748 |
+
" self.pearson.reset()\n",
|
| 749 |
+
" self.losses = []\n",
|
| 750 |
+
" \n",
|
| 751 |
+
" def update(\n",
|
| 752 |
+
" self, \n",
|
| 753 |
+
" predictions: torch.Tensor, \n",
|
| 754 |
+
" targets: torch.Tensor,\n",
|
| 755 |
+
" loss: float\n",
|
| 756 |
+
" ):\n",
|
| 757 |
+
" \"\"\"\n",
|
| 758 |
+
" Update the metrics with predictions and targets of shape (..., num_tracks) and a scalar loss.\n",
|
| 759 |
+
" \"\"\"\n",
|
| 760 |
+
" # Flatten batch and sequence dimensions\n",
|
| 761 |
+
" pred_flat = predictions.detach().reshape(-1, self.num_tracks).to(torch.float64) # (N, num_tracks)\n",
|
| 762 |
+
" target_flat = targets.detach().reshape(-1, self.num_tracks).to(torch.float64) # (N, num_tracks)\n",
|
| 763 |
+
" \n",
|
| 764 |
+
" # Update metrics\n",
|
| 765 |
+
" self.pearson.update(pred_flat, target_flat)\n",
|
| 766 |
+
" self.losses.append(loss)\n",
|
| 767 |
+
" \n",
|
| 768 |
+
" def compute(self) -> Dict[str, float]:\n",
|
| 769 |
+
" \"\"\"Compute the pearson correlations and loss and return a dictionary of metrics.\"\"\"\n",
|
| 770 |
+
" # Per-track Pearson correlations\n",
|
| 771 |
+
" correlations = self.pearson.compute().cpu().numpy()\n",
|
| 772 |
+
" metrics_dict = {\n",
|
| 773 |
+
" f\"{track_name}/pearson\": correlations[i] for i, track_name in enumerate(self.track_names)\n",
|
| 774 |
+
" }\n",
|
| 775 |
+
" metrics_dict[\"mean/pearson\"] = correlations.mean()\n",
|
| 776 |
+
" \n",
|
| 777 |
+
" # Mean loss\n",
|
| 778 |
+
" metrics_dict[\"loss\"] = np.mean(self.losses)\n",
|
| 779 |
+
" \n",
|
| 780 |
+
" return metrics_dict\n",
|
| 781 |
+
"\n",
|
| 782 |
+
" def update_mean_metrics(self, step_idx: int):\n",
|
| 783 |
+
" \"\"\"Update the mean metrics over the logging interval and save to a csv file.\"\"\"\n",
|
| 784 |
+
" # Update mean metrics with the mean pearson & average loss\n",
|
| 785 |
+
" metrics_dict = self.compute()\n",
|
| 786 |
+
" self.step_idxs.append(step_idx)\n",
|
| 787 |
+
" self.mean_pearsons.append(metrics_dict[\"mean/pearson\"])\n",
|
| 788 |
+
" self.mean_losses.append(metrics_dict[\"loss\"])\n",
|
| 789 |
+
"\n",
|
| 790 |
+
" # Save metrics to a csv for plotting\n",
|
| 791 |
+
" data = {\n",
|
| 792 |
+
" \"step\": self.step_idxs,\n",
|
| 793 |
+
" \"mean_loss\": self.mean_losses,\n",
|
| 794 |
+
" \"mean_pearson\": self.mean_pearsons,\n",
|
| 795 |
+
" }\n",
|
| 796 |
+
" df = pd.DataFrame(data)\n",
|
| 797 |
+
" df.to_csv(f\"metrics_{self.split}.csv\", index=False)\n",
|
| 798 |
+
" \n",
|
| 799 |
+
" def print_metrics(self, print_per_track: bool = False):\n",
|
| 800 |
+
" \"\"\"Print a summary of the metrics.\"\"\"\n",
|
| 801 |
+
" print(\n",
|
| 802 |
+
" f\"Step {self.step_idxs[-1]}/{config['num_steps_training']} | \"\n",
|
| 803 |
+
" f\"Loss: {self.mean_losses[-1]:.4f} | \"\n",
|
| 804 |
+
" f\"Mean Pearson: {self.mean_pearsons[-1]:.4f}\"\n",
|
| 805 |
+
" )\n",
|
| 806 |
+
" metrics_dict = self.compute()\n",
|
| 807 |
+
" if print_per_track:\n",
|
| 808 |
+
" for metric_key, metric_value in metrics_dict.items():\n",
|
| 809 |
+
" print(f\" {metric_key}: {metric_value:.4f}\")\n",
|
| 810 |
+
" "
|
| 811 |
+
]
|
| 812 |
+
},
|
| 813 |
+
{
|
| 814 |
+
"cell_type": "code",
|
| 815 |
+
"execution_count": null,
|
| 816 |
+
"metadata": {},
|
| 817 |
+
"outputs": [],
|
| 818 |
+
"source": [
|
| 819 |
+
"train_metrics = TracksMetrics(bigwig_ids, \"train\")\n",
|
| 820 |
+
"val_metrics = TracksMetrics(bigwig_ids, \"val\")\n",
|
| 821 |
+
"test_metrics = TracksMetrics(bigwig_ids, \"test\")"
|
| 822 |
+
]
|
| 823 |
+
},
|
| 824 |
+
{
|
| 825 |
+
"cell_type": "markdown",
|
| 826 |
+
"metadata": {},
|
| 827 |
+
"source": [
|
| 828 |
+
"# 7. 📉 Loss functions\n",
|
| 829 |
+
"\n",
|
| 830 |
+
"Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
|
| 831 |
+
]
|
| 832 |
+
},
|
| 833 |
+
{
|
| 834 |
+
"cell_type": "code",
|
| 835 |
+
"execution_count": null,
|
| 836 |
+
"metadata": {},
|
| 837 |
+
"outputs": [],
|
| 838 |
+
"source": [
|
| 839 |
+
"def poisson_loss(ytrue: torch.Tensor, ypred: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:\n",
|
| 840 |
+
" \"\"\"Poisson loss per element: ypred - ytrue * log(ypred).\"\"\"\n",
|
| 841 |
+
" return ypred - ytrue * torch.log(ypred + epsilon)\n",
|
| 842 |
+
"\n",
|
| 843 |
+
"\n",
|
| 844 |
+
"def safe_for_grad_log_torch(x: torch.Tensor) -> torch.Tensor:\n",
|
| 845 |
+
" \"\"\"Guarantees that the log is defined for all x > 0 in a differentiable way.\"\"\"\n",
|
| 846 |
+
" return torch.log(torch.where(x > 0.0, x, torch.ones_like(x)))\n",
|
| 847 |
+
"\n",
|
| 848 |
+
"\n",
|
| 849 |
+
"def poisson_multinomial_loss(\n",
|
| 850 |
+
" logits: torch.Tensor,\n",
|
| 851 |
+
" targets: torch.Tensor,\n",
|
| 852 |
+
" shape_loss_coefficient: float = 5.0,\n",
|
| 853 |
+
" epsilon: float = 1e-7,\n",
|
| 854 |
+
") -> torch.Tensor: \n",
|
| 855 |
+
" \"\"\"\n",
|
| 856 |
+
" Regression loss for bigwig tracks (Poisson-Multinomial). The logits and targets are\n",
|
| 857 |
+
" expected to be of shape (batch, seq_length, num_tracks).\n",
|
| 858 |
+
" \"\"\"\n",
|
| 859 |
+
" batch_size, seq_length, num_tracks = logits.shape\n",
|
| 860 |
+
" \n",
|
| 861 |
+
" # Scale loss: Poisson loss on total counts per sequence per track\n",
|
| 862 |
+
" # Sum over sequence dimension (axis=1)\n",
|
| 863 |
+
" sum_pred = logits.sum(dim=1) # (batch, num_tracks)\n",
|
| 864 |
+
" sum_true = targets.sum(dim=1) # (batch, num_tracks)\n",
|
| 865 |
+
" \n",
|
| 866 |
+
" # Compute poisson loss per (batch, track)\n",
|
| 867 |
+
" scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon) # (batch, num_tracks)\n",
|
| 868 |
+
" \n",
|
| 869 |
+
" # Normalize by sequence length\n",
|
| 870 |
+
" scale_loss = scale_loss / (seq_length + epsilon)\n",
|
| 871 |
+
" \n",
|
| 872 |
+
" # Average over batch and tracks\n",
|
| 873 |
+
" scale_loss = scale_loss.mean()\n",
|
| 874 |
+
" \n",
|
| 875 |
+
" # Shape loss: Multinomial loss\n",
|
| 876 |
+
" # Add epsilon to all positions\n",
|
| 877 |
+
" predicted_counts = logits + epsilon\n",
|
| 878 |
+
" targets_with_epsilon = targets + epsilon\n",
|
| 879 |
+
" \n",
|
| 880 |
+
" # Normalize predictions to get probabilities\n",
|
| 881 |
+
" denom = predicted_counts.sum(dim=1, keepdim=True) + epsilon # (batch, 1, num_tracks)\n",
|
| 882 |
+
" p_pred = predicted_counts / denom\n",
|
| 883 |
+
" \n",
|
| 884 |
+
" # Compute shape loss: -sum(targets * log(p_pred))\n",
|
| 885 |
+
" pl_pred = safe_for_grad_log_torch(p_pred)\n",
|
| 886 |
+
" shape_loss = -(targets_with_epsilon * pl_pred)\n",
|
| 887 |
+
" \n",
|
| 888 |
+
" # Sum over all dimensions and normalize by total number of positions\n",
|
| 889 |
+
" shape_denom = batch_size * seq_length * num_tracks + epsilon\n",
|
| 890 |
+
" shape_loss = shape_loss.sum() / shape_denom\n",
|
| 891 |
+
" \n",
|
| 892 |
+
" # Combine losses\n",
|
| 893 |
+
" loss = shape_loss + scale_loss / shape_loss_coefficient\n",
|
| 894 |
+
"\n",
|
| 895 |
+
" return loss\n"
|
| 896 |
+
]
|
| 897 |
+
},
|
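A single-track, single-sequence version of the same loss, written in plain Python for intuition (the torch function above, which handles full batches and multiple tracks, is the one used for training):

```python
import math

def poisson_multinomial_1d(pred, true, shape_coeff=5.0, eps=1e-7):
    n = len(pred)
    sum_pred, sum_true = sum(pred), sum(true)
    # Scale term: Poisson loss on total counts, normalised by sequence length.
    scale = (sum_pred - sum_true * math.log(sum_pred + eps)) / (n + eps)
    # Shape term: multinomial cross-entropy on the normalised signal profile.
    denom = sum(p + eps for p in pred) + eps
    shape = -sum((t + eps) * math.log((p + eps) / denom)
                 for p, t in zip(pred, true)) / (n + eps)
    return shape + scale / shape_coeff

true = [0.0, 2.0, 5.0, 1.0]
print(poisson_multinomial_1d(true, true))                   # loss at a perfect prediction
print(poisson_multinomial_1d([v * 3 for v in true], true))  # higher: totals overshoot
```

Scaling the prediction by 3 leaves the normalised shape unchanged but inflates the total counts, so only the scale term grows; this is exactly the decomposition the loss is designed around.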
| 898 |
+
{
|
| 899 |
+
"cell_type": "markdown",
|
| 900 |
+
"metadata": {},
|
| 901 |
+
"source": [
|
| 902 |
+
"# 8. 🏃 Training loop\n",
|
| 903 |
+
"\n",
|
| 904 |
+
"Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
|
| 905 |
+
]
|
| 906 |
+
},
|
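The step-based loop below cycles its data iterator rather than iterating epoch by epoch. The pattern in isolation, with a toy list standing in for the DataLoader:

```python
# When the iterator is exhausted, it is recreated so training can run for
# more steps than one pass over the data.
data = [1, 2, 3]
it = iter(data)
seen = []
for _ in range(7):  # more steps than len(data)
    try:
        batch = next(it)
    except StopIteration:
        it = iter(data)
        batch = next(it)
    seen.append(batch)
print(seen)  # [1, 2, 3, 1, 2, 3, 1]
```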
| 907 |
+
{
|
| 908 |
+
"cell_type": "code",
|
| 909 |
+
"execution_count": null,
|
| 910 |
+
"metadata": {},
|
| 911 |
+
"outputs": [],
|
| 912 |
+
"source": [
|
| 913 |
+
"def train_step(\n",
|
| 914 |
+
" model: nn.Module,\n",
|
| 915 |
+
" optimizer: torch.optim.Optimizer,\n",
|
| 916 |
+
" batch: Dict[str, torch.Tensor],\n",
|
| 917 |
+
" train_metrics: TracksMetrics,\n",
|
| 918 |
+
") -> None:\n",
|
| 919 |
+
" \"\"\"Single training step.\"\"\"\n",
|
| 920 |
+
" tokens = batch[\"tokens\"].to(device)\n",
|
| 921 |
+
" bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
|
| 922 |
+
" \n",
|
| 923 |
+
" # Forward pass\n",
|
| 924 |
+
" outputs = model(tokens=tokens)\n",
|
| 925 |
+
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
| 926 |
+
" \n",
|
| 927 |
+
" # Compute loss\n",
|
| 928 |
+
" loss = poisson_multinomial_loss(\n",
|
| 929 |
+
" logits=bigwig_logits,\n",
|
| 930 |
+
" targets=bigwig_targets,\n",
|
| 931 |
+
" )\n",
|
| 932 |
+
"\n",
|
| 933 |
+
" # Backward pass\n",
|
| 934 |
+
" optimizer.zero_grad()\n",
|
| 935 |
+
" loss.backward()\n",
|
| 936 |
+
" optimizer.step()\n",
|
| 937 |
+
"\n",
|
| 938 |
+
" # Update metrics\n",
|
| 939 |
+
" train_metrics.update(\n",
|
| 940 |
+
" predictions=bigwig_logits,\n",
|
| 941 |
+
" targets=bigwig_targets,\n",
|
| 942 |
+
" loss=loss.item()\n",
|
| 943 |
+
" )\n",
|
| 944 |
+
" \n",
|
| 945 |
+
"\n",
|
| 946 |
+
"\n",
|
| 947 |
+
"def validation_step(\n",
|
| 948 |
+
" model: nn.Module,\n",
|
| 949 |
+
" batch: Dict[str, torch.Tensor],\n",
|
| 950 |
+
" metrics: TracksMetrics,\n",
|
| 951 |
+
") -> None:\n",
|
| 952 |
+
" \"\"\"Single validation step.\"\"\"\n",
|
| 953 |
+
" tokens = batch[\"tokens\"].to(device)\n",
|
| 954 |
+
" bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
|
| 955 |
+
" \n",
|
| 956 |
+
" with torch.no_grad():\n",
|
| 957 |
+
" # Forward pass\n",
|
| 958 |
+
" outputs = model(tokens=tokens)\n",
|
| 959 |
+
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
| 960 |
+
" \n",
|
| 961 |
+
" # Compute loss\n",
|
| 962 |
+
" loss = poisson_multinomial_loss(\n",
|
| 963 |
+
" logits=bigwig_logits,\n",
|
| 964 |
+
" targets=bigwig_targets,\n",
|
| 965 |
+
" )\n",
|
| 966 |
+
" \n",
|
| 967 |
+
" # Update metrics\n",
|
| 968 |
+
" metrics.update(\n",
|
| 969 |
+
" predictions=bigwig_logits,\n",
|
| 970 |
+
" targets=bigwig_targets,\n",
|
| 971 |
+
" loss=loss.item()\n",
|
| 972 |
+
" )"
|
| 973 |
+
]
|
| 974 |
+
},
|
| 975 |
+
{
|
| 976 |
+
"cell_type": "markdown",
|
| 977 |
+
"metadata": {},
|
| 978 |
+
"source": [
|
| 979 |
+
"## Run Training Loop"
|
| 980 |
+
]
|
| 981 |
+
},
|
| 982 |
+
{
|
| 983 |
+
"cell_type": "code",
|
| 984 |
+
"execution_count": null,
|
| 985 |
+
"metadata": {},
|
| 986 |
+
"outputs": [],
|
| 987 |
+
"source": [
|
| 988 |
+
"# Training loop\n",
|
| 989 |
+
"print(f\"Starting training for {config['num_steps_training']} steps\\n\")\n",
|
| 990 |
+
"\n",
|
| 991 |
+
"# Create iterator for training data (will cycle if needed)\n",
|
| 992 |
+
"train_iter = iter(train_loader)\n",
|
| 993 |
+
"model.train()\n",
|
| 994 |
+
"\n",
|
| 995 |
+
"# Main training loop\n",
|
| 996 |
+
"for step_idx in range(config[\"num_steps_training\"]):\n",
|
| 997 |
+
" try:\n",
|
| 998 |
+
" batch = next(train_iter)\n",
|
| 999 |
+
" except StopIteration:\n",
|
| 1000 |
+
" # Restart iterator if we run out of data\n",
|
| 1001 |
+
" train_iter = iter(train_loader)\n",
|
| 1002 |
+
" batch = next(train_iter)\n",
|
| 1003 |
+
" \n",
|
| 1004 |
+
" # Take a training step\n",
|
| 1005 |
+
" train_step(model, optimizer, batch, train_metrics)\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
" # Logging\n",
|
| 1008 |
+
" if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
|
| 1009 |
+
" train_metrics.update_mean_metrics(step_idx + 1)\n",
|
| 1010 |
+
" train_metrics.print_metrics()\n",
|
| 1011 |
+
" train_metrics.reset()\n",
|
| 1012 |
+
" \n",
|
| 1013 |
+
" # Validation\n",
|
| 1014 |
+
" if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
|
| 1015 |
+
" print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
|
| 1016 |
+
" model.eval()\n",
|
| 1017 |
+
" \n",
|
| 1018 |
+
" for val_batch in val_loader:\n",
|
| 1019 |
+
" validation_step(model, val_batch, val_metrics)\n",
|
| 1020 |
+
" \n",
|
| 1021 |
+
" val_metrics.update_mean_metrics(step_idx + 1)\n",
|
| 1022 |
+
" val_metrics.print_metrics(print_per_track=True)\n",
|
| 1023 |
+
" val_metrics.reset()\n",
|
| 1024 |
+
"\n",
|
| 1025 |
+
" # Back to training mode\n",
|
| 1026 |
+
" print(\"\\n\" + \"-\"*100 + \"\\nTraining metrics:\")\n",
|
| 1027 |
+
" model.train() \n",
|
| 1028 |
+
"\n",
|
| 1029 |
+
"print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")\n"
|
| 1030 |
+
]
|
| 1031 |
+
},
|
| 1032 |
+
{
|
| 1033 |
+
"cell_type": "code",
|
| 1034 |
+
"execution_count": null,
|
| 1035 |
+
"metadata": {},
|
| 1036 |
+
"outputs": [],
|
| 1037 |
+
"source": [
|
| 1038 |
+
"# Plot training results\n",
|
| 1039 |
+
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
|
| 1040 |
+
"\n",
|
| 1041 |
+
"df_train = pd.read_csv(\"metrics_train.csv\")\n",
|
| 1042 |
+
"df_val = pd.read_csv(\"metrics_val.csv\")\n",
|
| 1043 |
+
"\n",
|
| 1044 |
+
"# Plot Loss\n",
|
| 1045 |
+
"axes[0].plot(df_train[\"step\"], df_train[\"mean_loss\"], 'b-o', label='Train Loss', markersize=4, linewidth=1.5)\n",
|
| 1046 |
+
"axes[0].plot(df_val[\"step\"], df_val[\"mean_loss\"], 'r-s', label='Val Loss', markersize=4, linewidth=1.5)\n",
|
| 1047 |
+
"axes[0].set_xlabel('Step')\n",
|
| 1048 |
+
"axes[0].set_ylabel('Loss')\n",
|
| 1049 |
+
"axes[0].set_title('Loss')\n",
|
| 1050 |
+
"axes[0].legend()\n",
|
| 1051 |
+
"axes[0].grid(True, alpha=0.3)\n",
|
| 1052 |
+
"\n",
|
| 1053 |
+
"# Plot Pearson Correlation\n",
|
| 1054 |
+
"axes[1].plot(df_train[\"step\"], df_train[\"mean_pearson\"], 'g-o', label='Train Pearson', markersize=4, linewidth=1.5)\n",
|
| 1055 |
+
"axes[1].plot(df_val[\"step\"], df_val[\"mean_pearson\"], 'orange', marker='s', label='Val Pearson', markersize=4, linewidth=1.5)\n",
|
| 1056 |
+
"axes[1].set_xlabel('Step')\n",
|
| 1057 |
+
"axes[1].set_ylabel('Pearson Correlation')\n",
|
| 1058 |
+
"axes[1].set_title('Mean Pearson Correlation')\n",
|
| 1059 |
+
"axes[1].legend()\n",
|
| 1060 |
+
"axes[1].grid(True, alpha=0.3)"
|
| 1061 |
+
]
|
| 1062 |
+
},
|
| 1063 |
+
{
|
| 1064 |
+
"cell_type": "markdown",
|
| 1065 |
+
"metadata": {},
|
| 1066 |
+
"source": [
|
| 1067 |
+
"# 9. 🧪 Test evaluation\n",
|
| 1068 |
+
"\n",
|
| 1069 |
+
"Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
|
| 1070 |
+
]
|
| 1071 |
+
},
|
| 1072 |
+
{
|
| 1073 |
+
"cell_type": "code",
|
| 1074 |
+
"execution_count": null,
|
| 1075 |
+
"metadata": {},
|
| 1076 |
+
"outputs": [],
|
| 1077 |
+
"source": [
|
| 1078 |
+
"# Calculate number of test steps (based on deepspeed pipeline)\n",
|
| 1079 |
+
"num_test_samples = len(test_dataset)\n",
|
| 1080 |
+
"num_test_steps = num_test_samples // config[\"batch_size\"]\n",
|
| 1081 |
+
"print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
|
| 1082 |
+
"\n",
|
| 1083 |
+
"# Set model to eval mode\n",
|
| 1084 |
+
"model.eval()\n",
|
| 1085 |
+
"\n",
|
| 1086 |
+
"# Run test evaluation with progress bar\n",
|
| 1087 |
+
"for test_batch in tqdm(test_loader, desc=\"Test evaluation\", total=num_test_steps): \n",
|
| 1088 |
+
" validation_step( \n",
|
| 1089 |
+
" model, \n",
|
| 1090 |
+
" test_batch, \n",
|
| 1091 |
+
" test_metrics,\n",
|
| 1092 |
+
" )\n",
|
| 1093 |
+
" \n",
|
| 1094 |
+
"# Compute final test metrics\n",
|
| 1095 |
+
"test_metrics_dict = test_metrics.compute()\n",
|
| 1096 |
+
"print(\"\\n\" + \"=\"*50)\n",
|
| 1097 |
+
"print(\"Test Set Results\")\n",
|
| 1098 |
+
"print(\"=\"*50)\n",
|
| 1099 |
+
"print(f\"\\nMetrics:\")\n",
|
| 1100 |
+
"print(f\" Mean Pearson: {test_metrics_dict['mean/pearson']:.4f}\")\n",
|
| 1101 |
+
"for track_name in bigwig_ids: \n",
|
| 1102 |
+
" print(f\" {track_name}/pearson: {test_metrics_dict[f'{track_name}/pearson']:.4f}\")"
|
| 1103 |
+
]
|
| 1104 |
+
},
|
| 1105 |
+
{
|
| 1106 |
+
"cell_type": "markdown",
|
| 1107 |
+
"metadata": {},
|
| 1108 |
+
"source": [
|
| 1109 |
+
" ## Test set results\n",
|
| 1110 |
+
"\n",
|
| 1111 |
+
"#TODO: Add test set results after run!"
|
| 1112 |
+
]
|
| 1113 |
+
},
|
| 1114 |
+
{
|
| 1115 |
+
"cell_type": "markdown",
|
| 1116 |
+
"metadata": {},
|
| 1117 |
+
"source": []
|
| 1118 |
+
}
|
| 1119 |
+
],
|
| 1120 |
+
"metadata": {
|
| 1121 |
+
"kernelspec": {
|
| 1122 |
+
"display_name": ".venv",
|
| 1123 |
+
"language": "python",
|
| 1124 |
+
"name": "python3"
|
| 1125 |
+
},
|
| 1126 |
+
"language_info": {
|
| 1127 |
+
"codemirror_mode": {
|
| 1128 |
+
"name": "ipython",
|
| 1129 |
+
"version": 3
|
| 1130 |
+
},
|
| 1131 |
+
"file_extension": ".py",
|
| 1132 |
+
"mimetype": "text/x-python",
|
| 1133 |
+
"name": "python",
|
| 1134 |
+
"nbconvert_exporter": "python",
|
| 1135 |
+
"pygments_lexer": "ipython3",
|
| 1136 |
+
"version": "3.11.14"
|
| 1137 |
+
}
|
| 1138 |
+
},
|
| 1139 |
+
"nbformat": 4,
|
| 1140 |
+
"nbformat_minor": 2
|
| 1141 |
+
}
|
notebooks_tutorials/{03_model_interpretation.ipynb → 04_model_interpretation.ipynb}
RENAMED
|
File without changes
|
tabs/home.html
CHANGED
|
@@ -84,10 +84,11 @@
|
|
| 84 |
<ul>
|
| 85 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/00_quickstart_inference.ipynb" target="_blank" rel="noopener noreferrer">🚀 00 — Quickstart inference</a></li>
|
| 86 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/01_tracks_prediction.ipynb" target="_blank" rel="noopener noreferrer">📊 01 — Tracks prediction</a></li>
|
| 87 |
-
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/
|
| 88 |
-
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/
|
| 89 |
-
<li
|
| 90 |
-
<li
|
|
|
|
| 91 |
</ul>
|
| 92 |
</div>
|
| 93 |
<div class="card">
|
|
|
|
| 84 |
<ul>
|
| 85 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/00_quickstart_inference.ipynb" target="_blank" rel="noopener noreferrer">🚀 00 — Quickstart inference</a></li>
|
| 86 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/01_tracks_prediction.ipynb" target="_blank" rel="noopener noreferrer">📊 01 — Tracks prediction</a></li>
|
| 87 |
+
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning_pretrained_model.ipynb" target="_blank" rel="noopener noreferrer">🎯 02 — Fine-tune a pre-trained model on bigwig tracks</a></li>
|
| 88 |
+
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/03_fine_tuning_posttrained_model.ipynb" target="_blank" rel="noopener noreferrer">🎯 03 — Fine-tune a post-trained model on bigwig tracks</a></li>
|
| 89 |
+
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/04_model_interpretation.ipynb" target="_blank" rel="noopener noreferrer">🔍 04 — Model interpretation</a></li>
|
| 90 |
+
<li>🧪 05 — Training NTv3-generative <em>(coming soon)</em></li>
|
| 91 |
+
<li>🪰 06 — Generating enhancer sequences <em>(coming soon)</em></li>
|
| 92 |
</ul>
|
| 93 |
</div>
|
| 94 |
<div class="card">
|