bernardo-de-almeida committed commit 31547d2 · 1 Parent(s): 9759882

fine-tuning notebook

notebooks_pipelines/01_functional_track_prediction.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
notebooks_tutorials/02_fine_tuning.ipynb ADDED
@@ -0,0 +1,1318 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🧬 Fine-Tuning a Model for BigWig Track Prediction\n",
8
+ "\n",
9
+ "This notebook demonstrates a **simplified fine-tuning setup** for training a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide-resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
+ "\n",
11
+ "**⚡ Key Advantage**: This simplified pipeline achieves performance close to that of more complex training approaches while enabling fast fine-tuning. The training speed benefits from the efficient NTv3 model architecture and depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time). With NTv3 models, meaningful Pearson correlations can typically be reached within ~10 minutes of training on a 32 kb functional track prediction task.\n",
12
+ "\n",
13
+ "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
+ "\n",
15
+ "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
+ "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
17
+ "- **Constant learning rate**: Uses a fixed learning rate throughout training without learning rate scheduling\n",
18
+ "- **No gradient accumulation**: Implements simple step-based training without gradient accumulation, making the training loop more straightforward\n",
19
+ "\n",
20
+ "The pipeline walks through the complete fine-tuning workflow:\n",
21
+ "\n",
22
+ "- Loading genomic sequences from FASTA files and their corresponding BigWig signal tracks\n",
23
+ "- Setting up a PyTorch dataset with proper train/validation/test splits\n",
24
+ "- Configuring the model architecture with a custom linear head\n",
25
+ "- Implementing a training loop with appropriate loss functions and evaluation metrics\n",
26
+ "\n",
27
+ "This provides a clean interface for training and evaluation.\n",
28
+ "\n",
29
+ "The model architecture consists of a pre-trained NTv3 backbone that processes DNA sequences and a custom linear head that predicts BigWig signal values at single-nucleotide resolution. Predictions are center-cropped to focus on the central portion of the input sequence (configurable via `keep_target_center_fraction`), which helps reduce edge effects from sequence context windows. The training uses a Poisson-Multinomial loss function that captures both the scale and shape of the signal distributions, and evaluation is performed using Pearson correlation metrics on both scaled and raw predictions.\n",
30
+ "\n",
31
+ "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
+ "\n",
33
+ "📝 Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended).\n"
34
+ ]
35
+ },
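The Poisson-Multinomial loss mentioned above is not shown in this excerpt. As a rough, hedged sketch of one common formulation (the notebook's actual implementation may differ), the loss can decompose into a Poisson term on each track's total signal and a multinomial cross-entropy term on its positional distribution; `poisson_multinomial_loss` and `multinomial_weight` below are illustrative names, not taken from the notebook:

```python
import torch

def poisson_multinomial_loss(pred, target, eps=1e-7, multinomial_weight=5.0):
    """Hypothetical Poisson + multinomial loss over (batch, seq_len, tracks).

    The Poisson term matches the total predicted signal per track to the
    total observed signal (capturing scale); the multinomial term matches
    the positional distribution of the signal (capturing shape).
    """
    # Total counts per (batch, track): sum over the sequence dimension
    pred_total = pred.sum(dim=1)
    target_total = target.sum(dim=1)

    # Poisson negative log-likelihood on totals (up to a constant)
    poisson_nll = pred_total - target_total * torch.log(pred_total + eps)

    # Multinomial cross-entropy on the per-position distribution
    pred_dist = pred / (pred_total.unsqueeze(1) + eps)
    multinomial_nll = -(target * torch.log(pred_dist + eps)).sum(dim=1)

    return (poisson_nll + multinomial_weight * multinomial_nll).mean()

pred = torch.full((2, 8, 3), 0.5)
target = torch.full((2, 8, 3), 0.5)
loss = poisson_multinomial_loss(pred, target)
```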
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "# 0. 📦 Import dependencies"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "# Install dependencies\n",
50
+ "!pip install pyfaidx pyBigWig torchmetrics transformers plotly tqdm"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 20,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "import random\n",
60
+ "import functools\n",
61
+ "from typing import List, Dict, Callable\n",
62
+ "import os\n",
63
+ "import subprocess\n",
64
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
65
+ "\n",
66
+ "import torch\n",
67
+ "import torch.nn as nn\n",
68
+ "import torch.nn.functional as F\n",
69
+ "from torch.utils.data import Dataset, DataLoader\n",
70
+ "from torch.optim import AdamW\n",
71
+ "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
72
+ "import numpy as np\n",
73
+ "import pyBigWig\n",
74
+ "from pyfaidx import Fasta\n",
75
+ "from torchmetrics import PearsonCorrCoef\n",
76
+ "import plotly.graph_objects as go\n",
77
+ "from IPython.display import display\n",
78
+ "from tqdm import tqdm"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {},
84
+ "source": [
85
+ "# 1. ⚙️ Configuration\n",
86
+ "\n",
87
+ "## Configuration Parameters\n",
88
+ "\n",
89
+ "### Model\n",
90
+ "- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
91
+ "\n",
92
+ "### Data\n",
93
+ "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
94
+ "- **`fasta_url`**: URL to download reference genome FASTA file\n",
95
+ "- **`bigwig_url_list`**: List of URLs for bigWig track files to download\n",
96
+ "- **`bigwig_file_ids`**: List of identifiers/names for bigWig tracks (set after downloading, used for model head and metrics)\n",
97
+ "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
98
+ "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
99
+ "\n",
100
+ "### Training\n",
101
+ "- **`batch_size`**: Number of samples per batch\n",
102
+ "- **`learning_rate`**: Constant learning rate for optimizer\n",
103
+ "- **`weight_decay`**: L2 regularization coefficient for optimizer\n",
104
+ "- **`num_steps_training`**: Total number of training steps\n",
105
+ "- **`log_every_n_steps`**: Log training metrics every N steps\n",
106
+ "\n",
107
+ "### Validation\n",
108
+ "- **`validate_every_n_steps`**: Run validation every N steps\n",
109
+ "- **`num_validation_samples`**: Number of samples to use for validation set\n",
110
+ "\n",
111
+ "### General\n",
112
+ "- **`seed`**: Random seed for reproducibility\n",
113
+ "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
114
+ "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 21,
120
+ "metadata": {},
121
+ "outputs": [
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "Using device: cpu\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "config = {\n",
132
+ " # Model\n",
133
+ " \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
134
+ " \n",
135
+ " # Data\n",
136
+ " \"data_cache_dir\": \"./data\",\n",
137
+ " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
138
+ " \"bigwig_url_list\": [\n",
139
+ " \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
140
+ " \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
141
+ " \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
142
+ " \"https://www.encodeproject.org/files/ENCFF921PHQ/@@download/ENCFF921PHQ.bigWig\",\n",
143
+ " ],\n",
144
+ " \"sequence_length\": 32_768,\n",
145
+ " \"keep_target_center_fraction\": 0.375,\n",
146
+ " \n",
147
+ " # Training\n",
148
+ " \"batch_size\": 32,\n",
149
+ " \"num_steps_training\": 19932,\n",
150
+ " \"log_every_n_steps\": 40,\n",
151
+ " \"learning_rate\": 1e-5,\n",
152
+ " \"weight_decay\": 0.01,\n",
153
+ " \n",
154
+ " # Validation\n",
155
+ " \"validate_every_n_steps\": 400,\n",
156
+ " \"num_validation_samples\": 1000,\n",
157
+ "\n",
158
+ " # Test\n",
159
+ " \"num_test_samples\": 10000,\n",
160
+ " \n",
161
+ " # General\n",
162
+ " \"seed\": 17,\n",
163
+ " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
164
+ " \"num_workers\": 16,\n",
165
+ "}\n",
166
+ "\n",
167
+ "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
168
+ "\n",
169
+ "# Extract filenames from URLs\n",
170
+ "def extract_filename_from_url(url: str) -> str:\n",
171
+ " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
172
+ " # Remove query parameters if present\n",
173
+ " url_clean = url.split('?')[0]\n",
174
+ " # Get the last part of the URL path\n",
175
+ " return url_clean.split('/')[-1]\n",
176
+ "\n",
177
+ "# Create paths for downloaded files\n",
178
+ "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
179
+ "bigwig_path_list = [\n",
180
+ " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
181
+ " for url in config[\"bigwig_url_list\"]\n",
182
+ "]\n",
183
+ "\n",
184
+ "\n",
185
+ "# TODO: find a way to link the experiment accession to bigwig file ids\n",
186
+ "# Create bigwig_file_ids from filenames (without extension)\n",
187
+ "config[\"bigwig_file_ids\"] = [\n",
188
+ " # os.path.splitext(extract_filename_from_url(url))[0]\n",
189
+ " # for url in config[\"bigwig_url_list\"]\n",
190
+ " \"ENCSR325NFE\",\n",
191
+ " \"ENCSR962OTG\",\n",
192
+ " \"ENCSR619DQO_P\",\n",
193
+ " \"ENCSR619DQO_M\",\n",
194
+ "]\n",
195
+ "\n",
196
+ "# Set random seed\n",
197
+ "torch.manual_seed(config[\"seed\"])\n",
198
+ "np.random.seed(config[\"seed\"])\n",
199
+ "\n",
200
+ "# Set device\n",
201
+ "device = torch.device(config[\"device\"])\n",
202
+ "print(f\"Using device: {device}\")"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "markdown",
207
+ "metadata": {},
208
+ "source": [
209
+ "# 2. 📥 Genome & Track Data Download\n",
210
+ "\n",
211
+ "Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "def _download_file(url: str, output_path: str) -> None:\n",
221
+ " \"\"\"Download a file from URL to output_path using wget.\"\"\"\n",
222
+ " subprocess.run([\"wget\", \"-c\", url, \"-O\", output_path], check=True)\n",
223
+ "\n",
224
+ "# Prepare download tasks: (url, output_path)\n",
225
+ "download_tasks = []\n",
226
+ "\n",
227
+ "# FASTA file\n",
228
+ "fasta_filename = extract_filename_from_url(config[\"fasta_url\"])\n",
229
+ "fasta_gz_path = os.path.join(config[\"data_cache_dir\"], fasta_filename)\n",
230
+ "download_tasks.append((config[\"fasta_url\"], fasta_gz_path))\n",
231
+ "\n",
232
+ "# BigWig files\n",
233
+ "for bigwig_url in config[\"bigwig_url_list\"]:\n",
234
+ " filename = extract_filename_from_url(bigwig_url)\n",
235
+ " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
236
+ " download_tasks.append((bigwig_url, filepath))\n",
237
+ "\n",
238
+ "# Download files in parallel\n",
239
+ "max_workers = min(len(download_tasks), 8)\n",
240
+ "\n",
241
+ "print(f\"Downloading {len(download_tasks)} files using {max_workers} workers...\")\n",
242
+ "with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
243
+ " # Submit all download tasks\n",
244
+ " future_to_path = {\n",
245
+ " executor.submit(_download_file, url, path): path\n",
246
+ " for url, path in download_tasks\n",
247
+ " }\n",
248
+ " \n",
249
+ " # Wait for all downloads to complete\n",
250
+ " for future in as_completed(future_to_path):\n",
251
+ " try:\n",
252
+ " future.result() # Raises exception if download failed\n",
253
+ " path = future_to_path[future]\n",
254
+ " print(f\"✓ Downloaded: {os.path.basename(path)}\")\n",
255
+ " except Exception as e:\n",
256
+ " path = future_to_path[future]\n",
257
+ " raise RuntimeError(f\"Failed to download {path}: {e}\") from e\n",
258
+ "\n",
259
+ "# Extract FASTA file after download\n",
260
+ "print(f\"\\nExtracting {fasta_filename}...\")\n",
261
+ "subprocess.run([\"gunzip\", \"-f\", fasta_gz_path], check=True)\n",
262
+ "print(\"✓ Extraction complete\")"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "markdown",
267
+ "metadata": {},
268
+ "source": [
269
+ "### Data Splits Definition"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "chrom_splits = {\n",
279
+ " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
280
+ " \"val\": ['chr22'],\n",
281
+ " \"test\": ['chr21']\n",
282
+ "}"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "metadata": {},
288
+ "source": [
289
+ "# 3. 🧠 Model and tokenizer setup\n",
290
+ " \n",
291
+ "In this section, we set up the model and tokenizer. \n",
292
+ " \n",
293
+ "Our approach uses any suitable pretrained backbone from HuggingFace Transformers (for example, `InstaDeepAI/ntv3_650M_pre`),\n",
294
+ "which is then extended with an additional linear head. \n",
295
+ " \n",
296
+ "This linear head is trained for regression on a set of genomic tracks, \n",
297
+ "allowing the model to make predictions for each track at single nucleotide resolution.\n",
298
+ " \n",
299
+ "The following code wraps the HuggingFace model together with this regression head for the end-to-end task.\n"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "class LinearHead(nn.Module):\n",
309
+ " \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n",
310
+ " def __init__(self, embed_dim: int, num_labels: int):\n",
311
+ " super().__init__()\n",
312
+ " self.layer_norm = nn.LayerNorm(embed_dim)\n",
313
+ " self.head = nn.Linear(embed_dim, num_labels)\n",
314
+ " \n",
315
+ " def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
316
+ " x = self.layer_norm(x)\n",
317
+ " x = self.head(x)\n",
318
+ " x = F.softplus(x) # Ensure positive values\n",
319
+ " return x\n",
320
+ "\n",
321
+ "\n",
322
+ "class HFModelWithHead(nn.Module):\n",
323
+ " \"\"\"Simple model wrapper: HF backbone + bigwig head.\"\"\"\n",
324
+ " \n",
325
+ " def __init__(\n",
326
+ " self,\n",
327
+ " model_name: str,\n",
328
+ " bigwig_track_names: List[str],\n",
329
+ " keep_target_center_fraction: float = 0.375,\n",
330
+ " ):\n",
331
+ " super().__init__()\n",
332
+ " \n",
333
+ " # Load config and model\n",
334
+ " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
335
+ " self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
336
+ " model_name, \n",
337
+ " trust_remote_code=True,\n",
338
+ " config=self.config\n",
339
+ " )\n",
340
+ " \n",
341
+ " self.keep_target_center_fraction = keep_target_center_fraction\n",
342
+ "\n",
343
+ " if hasattr(self.config, \"embed_dim\"):\n",
344
+ " embed_dim = self.config.embed_dim\n",
345
+ " else:\n",
346
+ " raise ValueError(f\"Could not determine embed_dim for {model_name}\")\n",
347
+ " \n",
348
+ " # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n",
349
+ " self.bigwig_head = LinearHead(embed_dim, len(bigwig_track_names))\n",
350
+ " self.model_name = model_name\n",
351
+ " \n",
352
+ " def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
353
+ " # Forward through backbone\n",
354
+ " outputs = self.backbone(input_ids=tokens)\n",
355
+ " embedding = outputs.hidden_states[-1] # Last hidden state\n",
356
+ " \n",
357
+ " # Crop to center fraction\n",
358
+ " if self.keep_target_center_fraction < 1.0:\n",
359
+ " seq_len = embedding.shape[1]\n",
360
+ " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
361
+ " target_length = seq_len - 2 * target_offset\n",
362
+ " embedding = embedding[:, target_offset:target_offset + target_length, :]\n",
363
+ " \n",
364
+ " # Predict bigwig tracks\n",
365
+ " bigwig_logits = self.bigwig_head(embedding)\n",
366
+ " \n",
367
+ " return {\"bigwig_tracks_logits\": bigwig_logits}"
368
+ ]
369
+ },
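The center-cropping arithmetic in `HFModelWithHead.forward` can be checked standalone with the notebook's configured values (`sequence_length=32_768`, `keep_target_center_fraction=0.375`); this is an illustrative reproduction of that arithmetic, not part of the notebook:

```python
# Reproduces the center-crop computation from HFModelWithHead.forward
seq_len = 32_768
keep_target_center_fraction = 0.375

# Offset symmetrically crops (1 - fraction) of the sequence, half per side
target_offset = int(seq_len * (1 - keep_target_center_fraction) // 2)
target_length = seq_len - 2 * target_offset

print(target_offset, target_length)  # 10240 12288
```

The retained length, 12,288 bp, is exactly 0.375 of the 32,768 bp input, so the model's predictions cover only the central window where full sequence context is available on both sides.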
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "# Load tokenizer\n",
377
+ "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
378
+ "\n",
379
+ "# Create model\n",
380
+ "model = HFModelWithHead(\n",
381
+ " model_name=config[\"model_name\"],\n",
382
+ " bigwig_track_names=config[\"bigwig_file_ids\"],\n",
383
+ " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
384
+ ")\n",
385
+ "model = model.to(device)\n",
386
+ "model.train()\n",
387
+ "\n",
388
+ "print(f\"Model loaded: {config['model_name']}\")\n",
389
+ "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
390
+ "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "markdown",
395
+ "metadata": {},
396
+ "source": [
397
+ "# 4. 🔄 Data loading\n",
398
+ "\n",
399
+ "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "code",
404
+ "execution_count": null,
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": [
408
+ "# Process-local cache for BigWig file handles (one per worker process)\n",
409
+ "# This allows safe multi-worker DataLoader usage\n",
410
+ "_bigwig_cache = {} # Maps (process_id, file_path) -> pyBigWig handle\n",
411
+ "\n",
412
+ "\n",
413
+ "def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
414
+ " \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
415
+ " process_id = os.getpid()\n",
416
+ " cache_key = (process_id, bigwig_path)\n",
417
+ " \n",
418
+ " if cache_key not in _bigwig_cache:\n",
419
+ " _bigwig_cache[cache_key] = pyBigWig.open(bigwig_path)\n",
420
+ " \n",
421
+ " return _bigwig_cache[cache_key]\n",
422
+ "\n",
423
+ "\n",
424
+ "class GenomeBigWigDataset(Dataset):\n",
425
+ " \"\"\"\n",
426
+ " Random genomic windows from a reference genome + bigWig signal.\n",
427
+ "\n",
428
+ " Each sample:\n",
429
+ " - picks a chromosome/region (from `chroms` or `regions`),\n",
430
+ " - picks a random window of length `sequence_length`,\n",
431
+ " - returns (sequence, signal, chrom, start, end).\n",
432
+ "\n",
433
+ " This dataset is compatible with multi-worker DataLoaders. BigWig files\n",
434
+ " are opened lazily using a process-local cache, ensuring each worker process\n",
435
+ " has its own file handles and avoiding concurrent access issues.\n",
436
+ "\n",
437
+ " Args\n",
438
+ " ----\n",
439
+ " fasta_path : str\n",
440
+ " Path to the reference genome FASTA (e.g. hg38.fna).\n",
441
+ "    bigwig_path_list : List[str]\n",
442
+ "        Paths to the bigWig files (e.g. [\"ENCFF884LDL.bigWig\", ...]).\n",
443
+ " chroms : List[str]\n",
444
+ " Chromosome names as they appear in the bigWig (e.g. [\"chr1\", \"chr2\", ...]).\n",
446
+ " sequence_length : int\n",
447
+ " Length of each random window (in bp).\n",
448
+ " num_samples : int\n",
449
+ " Number of samples the dataset will provide (len(dataset)).\n",
450
+ " tokenizer : AutoTokenizer\n",
451
+ " Tokenizer to use for tokenization.\n",
452
+ " transform_fn : Callable\n",
453
+ "        Function used to transform/scale bigWig targets.\n",
454
+ " keep_target_center_fraction : float\n",
455
+ " Fraction of center sequence to keep for target prediction (crops edges to focus on center).\n",
461
+ " \"\"\"\n",
462
+ "\n",
463
+ " def __init__(\n",
464
+ " self,\n",
465
+ " fasta_path: str,\n",
466
+ " bigwig_path_list: list[str],\n",
467
+ " chroms: List[str],\n",
468
+ " sequence_length: int,\n",
469
+ " num_samples: int,\n",
470
+ " tokenizer: AutoTokenizer,\n",
471
+ " transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
472
+ " keep_target_center_fraction: float = 1.0,\n",
473
+ " ):\n",
474
+ " super().__init__()\n",
475
+ "\n",
476
+ " self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
477
+ " # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
478
+ " self.bigwig_path_list = bigwig_path_list\n",
479
+ " self.sequence_length = sequence_length\n",
480
+ " self.num_samples = num_samples\n",
481
+ " self.tokenizer = tokenizer\n",
482
+ " self.transform_fn = transform_fn # Use pre-computed transform function\n",
483
+ " self.keep_target_center_fraction = keep_target_center_fraction\n",
484
+ " self.chroms = chroms\n",
485
+ "\n",
486
+ " # Get chromosome lengths from first BigWig file (lazy, cached per process)\n",
487
+ " # We need this for validation, so open temporarily\n",
488
+ " bw_handle = _get_bigwig_handle(bigwig_path_list[0])\n",
489
+ " bw_chrom_lengths = bw_handle.chroms() # dict: chrom -> length\n",
490
+ "\n",
491
+ " self.valid_chroms = []\n",
492
+ " self.chrom_lengths = {}\n",
493
+ "\n",
494
+ " for c in chroms:\n",
495
+ " if c not in bw_chrom_lengths or c not in self.fasta:\n",
496
+ " continue\n",
497
+ "\n",
498
+ " fa_len = len(self.fasta[c])\n",
499
+ " bw_len = bw_chrom_lengths[c]\n",
500
+ " L = min(fa_len, bw_len)\n",
501
+ "\n",
502
+ " if L > self.sequence_length:\n",
503
+ " self.valid_chroms.append(c)\n",
504
+ " self.chrom_lengths[c] = L\n",
505
+ "\n",
506
+ " if not self.valid_chroms:\n",
507
+ " raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n",
508
+ "\n",
509
+ " def __len__(self):\n",
510
+ " return self.num_samples\n",
511
+ "\n",
512
+ " def __getitem__(self, idx):\n",
513
+ "\n",
514
+ " # Sample from entire chromosomes\n",
515
+ " chrom = random.choice(self.valid_chroms)\n",
516
+ " chrom_len = self.chrom_lengths[chrom]\n",
517
+ " max_start = chrom_len - self.sequence_length\n",
518
+ " start = random.randint(0, max_start)\n",
519
+ " end = start + self.sequence_length\n",
520
+ "\n",
521
+ " # Sequence\n",
522
+ " seq = self.fasta[chrom][start:end] # string slice\n",
523
+ " # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
524
+ " tokenized = self.tokenizer(\n",
525
+ " seq,\n",
526
+ " padding=\"max_length\",\n",
527
+ " truncation=True,\n",
528
+ " max_length=self.sequence_length,\n",
529
+ " return_tensors=\"pt\",\n",
530
+ " )\n",
531
+ " tokens = tokenized[\"input_ids\"][0] # Shape: (max_length,)\n",
532
+ "\n",
533
+ " # Signal from bigWig tracks (numpy array) -> torch tensor\n",
534
+ " # Get BigWig handles lazily (cached per worker process)\n",
535
+ " bigwig_targets = np.array([\n",
536
+ " _get_bigwig_handle(bw_path).values(chrom, start, end, numpy=True)\n",
537
+ " for bw_path in self.bigwig_path_list\n",
538
+ " ]) # shape (num_tracks, seq_len)\n",
539
+ " # Transpose to (seq_len, num_tracks)\n",
540
+ " bigwig_targets = bigwig_targets.T\n",
541
+ " # pyBigWig returns NaN where no data; turn NaN into 0\n",
542
+ " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
543
+ " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
544
+ " \n",
545
+ " # Crop targets to center fraction\n",
546
+ " if self.keep_target_center_fraction < 1.0:\n",
547
+ " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n",
548
+ " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
549
+ " target_length = seq_len - 2 * target_offset\n",
550
+ " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
551
+ "\n",
552
+ " # Apply scaling to targets\n",
553
+ " bigwig_targets = self.transform_fn(bigwig_targets)\n",
554
+ "\n",
555
+ " sample = {\n",
556
+ " \"tokens\": tokens,\n",
557
+ " \"bigwig_targets\": bigwig_targets,\n",
558
+ " \"chrom\": chrom,\n",
559
+ " \"start\": start,\n",
560
+ " \"end\": end,\n",
561
+ " }\n",
562
+ " return sample"
563
+ ]
564
+ },
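The target assembly in `__getitem__` stacks one array per track, transposes to put the sequence dimension first, and zeroes the NaNs that pyBigWig returns for uncovered positions. A small self-contained imitation of those steps (using random data instead of real bigWig values):

```python
import numpy as np
import torch

# Mimic stacked per-track bigWig signal: shape (num_tracks, seq_len)
num_tracks, seq_len = 4, 16
signal = np.random.rand(num_tracks, seq_len).astype(np.float32)
signal[0, :3] = np.nan  # simulate positions with no data in track 0

# Transpose to (seq_len, num_tracks), then replace NaN with 0
targets = torch.tensor(signal.T, dtype=torch.float32)
targets = torch.nan_to_num(targets, nan=0.0)
print(targets.shape)  # torch.Size([16, 4])
```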
565
+ {
566
+ "cell_type": "markdown",
567
+ "metadata": {},
568
+ "source": [
569
+ "### Data preprocessing utilities"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "# Scaling functions for targets\n",
579
+ "def compute_chromosome_stats(track_data: np.ndarray) -> dict:\n",
580
+ " \"\"\"\n",
581
+ " Compute minimal statistics needed for weighted mean computation.\n",
582
+ " \n",
583
+ " Args:\n",
584
+ " track_data: numpy array of track values for a chromosome\n",
585
+ " \n",
586
+ " Returns:\n",
587
+ " Dictionary with statistics: sum, mean, total_count\n",
588
+ " \"\"\"\n",
589
+ " track_data = track_data.astype(np.float32)\n",
590
+ " \n",
591
+ " # Compute statistics\n",
592
+ " sum_all = np.sum(track_data)\n",
593
+ " total_count = track_data.size\n",
594
+ " mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
595
+ " \n",
596
+ " return {\n",
597
+ " \"sum\": sum_all,\n",
598
+ " \"mean\": mean_all,\n",
599
+ " \"total_count\": total_count,\n",
600
+ " }\n",
601
+ "\n",
602
+ "\n",
603
+ "def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
604
+ " \"\"\"\n",
605
+ " Aggregate chromosome-level statistics into file-level statistics.\n",
606
+ " \n",
607
+ " Args:\n",
608
+ " chr_stats_list: List of dictionaries, each containing chromosome-level statistics\n",
609
+ " \n",
610
+ " Returns:\n",
611
+ " Dictionary with aggregated file-level statistics (only mean)\n",
612
+ " \"\"\"\n",
613
+ " # Convert to arrays for easier computation\n",
614
+ " total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
615
+ " means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
616
+ " sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
617
+ " \n",
618
+ " # Aggregate total count\n",
619
+ " total_count = np.sum(total_counts)\n",
620
+ " \n",
621
+ " # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
622
+ " mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
623
+ " \n",
624
+ " return {\n",
625
+ " \"total_count\": total_count,\n",
626
+ " \"sum\": np.sum(sums),\n",
627
+ " \"mean\": mean,\n",
628
+ " }\n",
629
+ "\n",
630
+ "\n",
631
+ "def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
632
+ " \"\"\"\n",
633
+ " Get track means for normalization.\n",
634
+ "    Computes statistics per chromosome and aggregates using weighted averaging.\n",
635
+ " \n",
636
+ " Args:\n",
637
+ " bigwig_tracks_list: List of pyBigWig file objects\n",
638
+ " \n",
639
+ " Returns:\n",
640
+ " Array of track means, one per bigwig file\n",
641
+ " \"\"\"\n",
642
+ " track_means = []\n",
643
+ " \n",
644
+ " for bigwig_track in bigwig_tracks_list:\n",
645
+ " chrom_lengths = bigwig_track.chroms()\n",
646
+ " all_chr_stats = []\n",
647
+ " \n",
648
+ " # Compute statistics for each chromosome\n",
649
+ " for chrom_name, chrom_length in chrom_lengths.items():\n",
650
+ " try:\n",
651
+ " # Get chromosome data as numpy array\n",
652
+ " bw_array = np.array(\n",
653
+ " bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
654
+ " dtype=np.float32\n",
655
+ " )\n",
656
+ " # Replace NaN with 0\n",
657
+ " bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
658
+ " \n",
659
+ " # Compute chromosome-level statistics\n",
660
+ " chr_stats = compute_chromosome_stats(bw_array)\n",
661
+ " all_chr_stats.append(chr_stats)\n",
662
+ " except Exception as e:\n",
663
+ " # Skip chromosomes that fail to load\n",
664
+ " print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
665
+ " continue\n",
666
+ " \n",
667
+ " if not all_chr_stats:\n",
668
+ "            raise ValueError(\"No valid chromosomes found for bigwig track\")\n",
669
+ " \n",
670
+ " # Aggregate chromosome-level stats into file-level stats\n",
671
+ " file_stats = aggregate_file_statistics(all_chr_stats)\n",
672
+ " \n",
673
+ " # Use the weighted mean for normalization\n",
674
+ " track_means.append(file_stats[\"mean\"])\n",
675
+ " \n",
676
+ " return np.array(track_means, dtype=np.float32)\n",
677
+ "\n",
678
+ "\n",
679
+ "def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
680
+ " \"\"\"\n",
681
+ " Build a scaling function based on track means computed from bigwig files.\n",
682
+ " \n",
683
+ " Opens bigwig files, computes track statistics, and creates a transform function.\n",
684
+ " The statistics are computed once and reused for all calls to the returned transform function.\n",
685
+ " \n",
686
+ " Args:\n",
687
+ " bigwig_path_list: List of paths to bigwig files\n",
688
+ " \n",
689
+ " Returns:\n",
690
+ " Transform function that scales input tensors\n",
691
+ " \"\"\"\n",
692
+ " # Open bigwig files and compute track statistics\n",
693
+ " print(\"Computing track statistics (this may take a while)...\")\n",
694
+ " bw_list = [\n",
695
+ " pyBigWig.open(bigwig_path)\n",
696
+ " for bigwig_path in bigwig_path_list\n",
697
+ " ]\n",
698
+ " track_means = get_track_means(bw_list)\n",
699
+ " print(f\"Computed track means: {track_means}\")\n",
700
+ " print(f\"Track means shape: {track_means.shape}\")\n",
701
+ " \n",
702
+ " # Create tensor from computed means\n",
703
+ " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
704
+ " \n",
705
+ " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
706
+ " \"\"\"\n",
707
+ " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
708
+ " \"\"\"\n",
709
+ " # Move constants to correct device then normalize\n",
710
+ " means = track_means_tensor.to(x.device)\n",
711
+ " scaled = x / means\n",
712
+ "\n",
713
+ " # Soft-clip values above 10: 2*sqrt(10*x) - 10 matches the identity in value and slope at the threshold\n",
714
+ " clipped = torch.where(\n",
715
+ " scaled > 10.0,\n",
716
+ " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
717
+ " scaled,\n",
718
+ " )\n",
719
+ " return clipped\n",
720
+ " \n",
721
+ " return transform_fn"
722
+ ]
723
+ },
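The soft clip inside `transform_fn` can be checked in isolation. A minimal NumPy sketch (the `signal` and `means` values below are made up for illustration):

```python
import numpy as np

def soft_clip(x, means, threshold=10.0):
    """Divide each track by its mean, then compress values above `threshold`.

    Above the threshold, 2*sqrt(threshold * x) - threshold matches the
    identity in both value and slope at x == threshold, so the transform
    stays continuous and differentiable.
    """
    scaled = x / means
    return np.where(
        scaled > threshold,
        2.0 * np.sqrt(scaled * threshold) - threshold,
        scaled,
    )

signal = np.array([[0.0, 5.0], [40.0, 90.0]])  # (positions, tracks)
means = np.array([2.0, 10.0])                  # hypothetical track means
out = soft_clip(signal, means)  # only signal[1, 0] (scaled to 20) is clipped
```

Values at or below 10× the track mean pass through unchanged; larger values are compressed, which keeps rare coverage spikes from dominating the loss.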
724
+ {
725
+ "cell_type": "code",
726
+ "execution_count": null,
727
+ "metadata": {},
728
+ "outputs": [],
729
+ "source": [
730
+ "# Create scaling function\n",
731
+ "targets_transform_fn = create_targets_scaling_fn(bigwig_path_list)"
732
+ ]
733
+ },
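`aggregate_file_statistics` (partially shown above) combines per-chromosome sums and counts into one file-level mean, which is equivalent to averaging per-chromosome means weighted by chromosome length. The arithmetic in miniature, with hypothetical per-chromosome stats:

```python
# Hypothetical per-chromosome statistics: (base count, signal sum) pairs
chr_stats = [(100, 50.0), (300, 600.0)]  # per-chromosome means: 0.5 and 2.0

total_count = sum(count for count, _ in chr_stats)
file_mean = sum(total for _, total in chr_stats) / total_count

# Identical to averaging per-chromosome means weighted by chromosome length
weighted = sum(count * (total / count) for count, total in chr_stats) / total_count
```

The naive unweighted average of the two means would be 1.25; the length-weighted mean is 1.625, because the longer chromosome contributes proportionally more bases.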
734
+ {
735
+ "cell_type": "code",
736
+ "execution_count": null,
737
+ "metadata": {},
738
+ "outputs": [],
739
+ "source": [
740
+ "# Create datasets & dataloaders\n",
741
+ "create_dataset_fn = functools.partial(\n",
742
+ " GenomeBigWigDataset,\n",
743
+ " fasta_path=fasta_path,\n",
744
+ " bigwig_path_list=bigwig_path_list,\n",
745
+ " sequence_length=config[\"sequence_length\"],\n",
746
+ " tokenizer=tokenizer,\n",
747
+ " transform_fn=targets_transform_fn,\n",
748
+ " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
749
+ ")\n",
750
+ "\n",
751
+ "train_dataset = create_dataset_fn(\n",
752
+ " chroms=chrom_splits[\"train\"],\n",
753
+ " num_samples=config[\"num_steps_training\"] * config[\"batch_size\"],\n",
754
+ ")\n",
755
+ "\n",
756
+ "val_dataset = create_dataset_fn(\n",
757
+ " chroms=chrom_splits[\"val\"],\n",
758
+ " num_samples=config[\"num_validation_samples\"],\n",
759
+ ")\n",
760
+ "\n",
761
+ "test_dataset = create_dataset_fn(\n",
762
+ " chroms=chrom_splits[\"test\"],\n",
763
+ " num_samples=config[\"num_test_samples\"],\n",
764
+ ")\n",
765
+ "\n",
766
+ "# Create dataloaders\n",
767
+ "train_loader = DataLoader(\n",
768
+ " train_dataset,\n",
769
+ " batch_size=config[\"batch_size\"],\n",
770
+ " shuffle=True,\n",
771
+ " num_workers=config[\"num_workers\"],\n",
772
+ ")\n",
773
+ "\n",
774
+ "val_loader = DataLoader(\n",
775
+ " val_dataset,\n",
776
+ " batch_size=config[\"batch_size\"],\n",
777
+ " shuffle=False,\n",
778
+ " num_workers=config[\"num_workers\"],\n",
779
+ ")\n",
780
+ "\n",
781
+ "test_loader = DataLoader(\n",
782
+ " test_dataset,\n",
783
+ " batch_size=config[\"batch_size\"],\n",
784
+ " shuffle=False,\n",
785
+ " num_workers=config[\"num_workers\"],\n",
786
+ ")\n",
787
+ "\n",
788
+ "print(f\"Train samples: {len(train_dataset)}\")\n",
789
+ "print(f\"Val samples: {len(val_dataset)}\")\n",
790
+ "print(f\"Test samples: {len(test_dataset)}\")"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "markdown",
795
+ "metadata": {},
796
+ "source": [
797
+ "# 5. ⚙️ Optimizer setup\n",
798
+ "\n",
799
+ "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function.\n",
800
+ "\n"
801
+ ]
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "execution_count": null,
806
+ "metadata": {},
807
+ "outputs": [],
808
+ "source": [
809
+ "# Training setup\n",
810
+ "print(\"Training configuration:\")\n",
811
+ "print(f\" Batch size: {config['batch_size']}\")\n",
812
+ "print(f\" Total training steps: {config['num_steps_training']}\")\n",
813
+ "print(f\" Log metrics every: {config['log_every_n_steps']} steps\")\n",
814
+ "print(f\" Validate every: {config['validate_every_n_steps']} steps\")\n",
815
+ "\n",
816
+ "# Setup optimizer\n",
817
+ "optimizer = AdamW(\n",
818
+ " model.parameters(),\n",
819
+ " lr=config[\"learning_rate\"],\n",
820
+ " weight_decay=config[\"weight_decay\"],\n",
821
+ ")\n",
822
+ "\n",
823
+ "print(\"\\nOptimizer setup:\")\n",
824
+ "print(f\" Learning rate: {config['learning_rate']}\")\n",
+ "print(f\" Weight decay: {config['weight_decay']}\")"
825
+ ]
826
+ },
827
+ {
828
+ "cell_type": "markdown",
829
+ "metadata": {},
830
+ "source": [
831
+ "# 6. 📊 Metrics setup\n",
832
+ "\n",
833
+ "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
834
+ ]
835
+ },
836
+ {
837
+ "cell_type": "code",
838
+ "execution_count": null,
839
+ "metadata": {},
840
+ "outputs": [],
841
+ "source": [
842
+ "class TracksMetrics:\n",
843
+ " \"\"\"Simple metrics tracker for tracks prediction.\"\"\"\n",
844
+ " \n",
845
+ " def __init__(self, track_names: List[str]):\n",
846
+ " self.track_names = track_names\n",
847
+ " self.num_tracks = len(track_names)\n",
848
+ " # Metrics: comparing scaled targets with scaled predictions\n",
849
+ " # Configure to use float64 for improved numerical stability\n",
850
+ " self.pearson_metrics = [\n",
851
+ " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
852
+ " ]\n",
853
+ " # Set dtype to float64 to prevent NaN warnings when variance is close to zero\n",
854
+ " for metric in self.pearson_metrics:\n",
855
+ " metric.set_dtype(torch.float64)\n",
856
+ " self.losses = []\n",
857
+ " \n",
858
+ " def reset(self):\n",
859
+ " for metric in self.pearson_metrics:\n",
860
+ " metric.reset()\n",
861
+ " self.losses = []\n",
862
+ " \n",
863
+ " def update(\n",
864
+ " self, \n",
865
+ " predictions: torch.Tensor, \n",
866
+ " targets: torch.Tensor,\n",
867
+ " loss: float\n",
868
+ " ):\n",
869
+ " \"\"\"\n",
870
+ " Update metrics.\n",
871
+ " Args:\n",
872
+ " predictions: (batch, seq_len, num_tracks)\n",
873
+ " targets: (batch, seq_len, num_tracks)\n",
874
+ " loss: scalar loss value\n",
875
+ " \"\"\"\n",
876
+ " # Flatten batch and sequence dimensions\n",
877
+ " pred_flat = predictions.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
878
+ " target_flat = targets.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
879
+ " \n",
880
+ " # Convert to float64 for improved numerical stability in Pearson correlation\n",
881
+ " pred_flat = pred_flat.to(torch.float64)\n",
882
+ " target_flat = target_flat.to(torch.float64)\n",
883
+ " \n",
884
+ " # Update metrics\n",
885
+ " for i, metric in enumerate(self.pearson_metrics):\n",
886
+ " metric.update(pred_flat[:, i], target_flat[:, i])\n",
887
+ " \n",
888
+ " self.losses.append(loss)\n",
889
+ " \n",
890
+ " def compute(self) -> Dict[str, float]:\n",
891
+ " \"\"\"Compute and return all metrics.\"\"\"\n",
892
+ " metrics_dict = {}\n",
893
+ " \n",
894
+ " # Per-track Pearson correlations\n",
895
+ " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics)):\n",
896
+ " corr = metric.compute().item()\n",
897
+ " metrics_dict[f\"{track_name}/pearson\"] = corr\n",
898
+ " \n",
899
+ " # Mean Pearson correlation\n",
900
+ " correlations = [metric.compute().item() for metric in self.pearson_metrics]\n",
901
+ " metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
902
+ " \n",
903
+ " # Mean loss\n",
904
+ " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
905
+ " \n",
906
+ " return metrics_dict"
907
+ ]
908
+ },
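torchmetrics' `PearsonCorrCoef` used in `TracksMetrics` accumulates sufficient statistics across batches rather than storing every prediction. The same idea in plain NumPy (a sketch of the technique, not the torchmetrics implementation):

```python
import numpy as np

class StreamingPearson:
    """Accumulate the sums needed for Pearson r without storing all samples."""

    def __init__(self):
        self.n = 0
        self.sx = self.sy = self.sxx = self.syy = self.sxy = 0.0

    def update(self, x, y):
        x = np.asarray(x, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        self.n += x.size
        self.sx += x.sum()
        self.sy += y.sum()
        self.sxx += (x * x).sum()
        self.syy += (y * y).sum()
        self.sxy += (x * y).sum()

    def compute(self):
        cov = self.sxy - self.sx * self.sy / self.n
        var_x = self.sxx - self.sx ** 2 / self.n
        var_y = self.syy - self.sy ** 2 / self.n
        return cov / np.sqrt(var_x * var_y)

m = StreamingPearson()
m.update([1.0, 2.0], [2.0, 4.0])  # first "batch"
m.update([3.0, 4.0], [6.0, 8.0])  # second "batch"
r = m.compute()  # y = 2x exactly, so r == 1.0
```

Accumulating in float64, as the notebook does via `set_dtype`, matters here: the `sxx - sx**2/n` form is prone to catastrophic cancellation when a track's variance is small relative to its mean.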
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": null,
912
+ "metadata": {},
913
+ "outputs": [],
914
+ "source": [
915
+ "train_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n",
916
+ "val_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n",
917
+ "test_metrics = TracksMetrics(config[\"bigwig_file_ids\"])"
918
+ ]
919
+ },
920
+ {
921
+ "cell_type": "markdown",
922
+ "metadata": {},
923
+ "source": [
924
+ "# 7. 📉 Loss functions\n",
925
+ "\n",
926
+ "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
927
+ ]
928
+ },
929
+ {
930
+ "cell_type": "code",
931
+ "execution_count": null,
932
+ "metadata": {},
933
+ "outputs": [],
934
+ "source": [
935
+ "def poisson_loss(ytrue: torch.Tensor, ypred: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:\n",
936
+ " \"\"\"Poisson loss per element: ypred - ytrue * log(ypred).\"\"\"\n",
937
+ " return ypred - ytrue * torch.log(ypred + epsilon)\n",
938
+ "\n",
939
+ "\n",
940
+ "def safe_for_grad_log_torch(x: torch.Tensor) -> torch.Tensor:\n",
941
+ " \"\"\"Guarantees that the log is defined for all x > 0 in a differentiable way.\"\"\"\n",
942
+ " return torch.log(torch.where(x > 0.0, x, torch.ones_like(x)))\n",
943
+ "\n",
944
+ "\n",
945
+ "def poisson_multinomial_loss(\n",
946
+ " logits: torch.Tensor,\n",
947
+ " targets: torch.Tensor,\n",
948
+ " shape_loss_coefficient: float = 5.0,\n",
949
+ " epsilon: float = 1e-7,\n",
950
+ ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
951
+ " \"\"\"\n",
952
+ " Regression loss for bigwig tracks (Poisson-Multinomial).\n",
953
+ " \n",
954
+ " Args:\n",
955
+ " logits: (batch, seq_length, num_tracks) - predicted counts\n",
956
+ " targets: (batch, seq_length, num_tracks) - target counts\n",
957
+ " shape_loss_coefficient: divisor applied to the scale loss before adding it to the shape loss\n",
958
+ " epsilon: epsilon for numerical stability\n",
959
+ " \n",
960
+ " Returns:\n",
961
+ " loss, scale_loss, shape_loss\n",
962
+ " \"\"\"\n",
963
+ " batch_size, seq_length, num_tracks = logits.shape\n",
964
+ " \n",
965
+ " # Scale loss: Poisson loss on total counts per sequence per track\n",
966
+ " # Sum over sequence dimension (axis=1)\n",
967
+ " sum_pred = logits.sum(dim=1) # (batch, num_tracks)\n",
968
+ " sum_true = targets.sum(dim=1) # (batch, num_tracks)\n",
969
+ " \n",
970
+ " # Compute poisson loss per (batch, track)\n",
971
+ " scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon) # (batch, num_tracks)\n",
972
+ " \n",
973
+ " # Normalize by sequence length\n",
974
+ " scale_loss = scale_loss / (seq_length + epsilon)\n",
975
+ " \n",
976
+ " # Average over batch and tracks\n",
977
+ " scale_loss = scale_loss.mean()\n",
978
+ " \n",
979
+ " # Shape loss: Multinomial loss\n",
980
+ " # Add epsilon to all positions\n",
981
+ " predicted_counts = logits + epsilon\n",
982
+ " targets_with_epsilon = targets + epsilon\n",
983
+ " \n",
984
+ " # Normalize predictions to get probabilities\n",
985
+ " denom = predicted_counts.sum(dim=1, keepdim=True) + epsilon # (batch, 1, num_tracks)\n",
986
+ " p_pred = predicted_counts / denom\n",
987
+ " \n",
988
+ " # Compute shape loss: -sum(targets * log(p_pred))\n",
989
+ " pl_pred = safe_for_grad_log_torch(p_pred)\n",
990
+ " shape_loss = -(targets_with_epsilon * pl_pred)\n",
991
+ " \n",
992
+ " # Sum over all dimensions and normalize by total number of positions\n",
993
+ " shape_denom = batch_size * seq_length * num_tracks + epsilon\n",
994
+ " shape_loss = shape_loss.sum() / shape_denom\n",
995
+ " \n",
996
+ " # Combine losses\n",
997
+ " loss = shape_loss + scale_loss / shape_loss_coefficient\n",
998
+ "\n",
999
+ " return loss, scale_loss, shape_loss\n"
1000
+ ]
1001
+ },
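The decomposition above can be sanity-checked on toy data: two predictions with identical per-track totals should get identical scale losses, while the one with the wrong profile should pay a larger shape loss. A NumPy mirror of `poisson_multinomial_loss` for a single (seq_len, tracks) example (a sketch with made-up values, simplifying the batch dimension away):

```python
import numpy as np

def poisson_multinomial_np(pred, true, shape_coef=5.0, eps=1e-7):
    """NumPy mirror of poisson_multinomial_loss for one (seq_len, tracks) pair."""
    seq_len = pred.shape[0]
    # Scale term: Poisson loss on per-track totals, normalized by length
    sum_pred, sum_true = pred.sum(axis=0), true.sum(axis=0)
    scale = ((sum_pred - sum_true * np.log(sum_pred + eps)) / (seq_len + eps)).mean()
    # Shape term: cross-entropy between target counts and the predicted profile
    p = (pred + eps) / (pred + eps).sum(axis=0, keepdims=True)
    shape = -((true + eps) * np.log(p)).sum() / (true.size + eps)
    return shape + scale / shape_coef, scale, shape

true = np.array([[1.0, 0.0], [3.0, 2.0]])   # (seq_len=2, tracks=2)
good = true.copy()                           # right totals, right profile
bad = np.array([[3.0, 2.0], [1.0, 0.0]])     # same totals, shuffled profile
loss_good, scale_good, shape_good = poisson_multinomial_np(good, true)
loss_bad, scale_bad, shape_bad = poisson_multinomial_np(bad, true)
```

Because `bad` preserves the column sums, only the shape term separates the two, illustrating why the combined loss can supervise total signal and profile independently.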
1002
+ {
1003
+ "cell_type": "markdown",
1004
+ "metadata": {},
1005
+ "source": [
1006
+ "# 8. 🏃 Training loop\n",
1007
+ "\n",
1008
+ "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
1009
+ ]
1010
+ },
1011
+ {
1012
+ "cell_type": "code",
1013
+ "execution_count": null,
1014
+ "metadata": {},
1015
+ "outputs": [],
1016
+ "source": [
1017
+ "def train_step(\n",
1018
+ " model: nn.Module,\n",
1019
+ " batch: Dict[str, torch.Tensor],\n",
1020
+ ") -> float:\n",
1021
+ " \"\"\"Single training step.\"\"\"\n",
1022
+ " tokens = batch[\"tokens\"].to(device)\n",
1023
+ " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1024
+ " \n",
1025
+ " # Forward pass\n",
1026
+ " outputs = model(tokens=tokens)\n",
1027
+ " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
1028
+ " \n",
1029
+ " # Compute loss\n",
1030
+ " loss, _, _ = poisson_multinomial_loss(\n",
1031
+ " logits=bigwig_logits,\n",
1032
+ " targets=bigwig_targets,\n",
1033
+ " )\n",
1034
+ " \n",
1035
+ " # Backward pass\n",
1036
+ " loss.backward()\n",
1037
+ " return loss.item()\n",
1038
+ "\n",
1039
+ "\n",
1040
+ "def validation_step(\n",
1041
+ " model: nn.Module,\n",
1042
+ " batch: Dict[str, torch.Tensor],\n",
1043
+ " metrics: TracksMetrics,\n",
1044
+ ") -> float:\n",
1045
+ " \"\"\"Single validation step.\"\"\"\n",
1046
+ " model.eval()\n",
1047
+ " \n",
1048
+ " tokens = batch[\"tokens\"].to(device)\n",
1049
+ " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1050
+ " \n",
1051
+ " with torch.no_grad():\n",
1052
+ " # Forward pass\n",
1053
+ " outputs = model(tokens=tokens)\n",
1054
+ " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
1055
+ " \n",
1056
+ " # Compute loss\n",
1057
+ " loss, _, _ = poisson_multinomial_loss(\n",
1058
+ " logits=bigwig_logits,\n",
1059
+ " targets=bigwig_targets,\n",
1060
+ " )\n",
1061
+ " \n",
1062
+ " # Update metrics\n",
1063
+ " metrics.update(\n",
1064
+ " predictions=bigwig_logits,\n",
1065
+ " targets=bigwig_targets,\n",
1066
+ " loss=loss.item()\n",
1067
+ " )\n",
1068
+ " \n",
1069
+ " return loss.item()"
1070
+ ]
1071
+ },
1072
+ {
1073
+ "cell_type": "markdown",
1074
+ "metadata": {},
1075
+ "source": [
1076
+ "### Interactive plotting (temporary, for debugging)"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "cell_type": "code",
1081
+ "execution_count": null,
1082
+ "metadata": {},
1083
+ "outputs": [],
1084
+ "source": [
1085
+ "# Training loop\n",
1086
+ "print(\"Starting training...\")\n",
1087
+ "print(f\"Training for {config['num_steps_training']} steps\\n\")\n",
1088
+ "\n",
1089
+ "model.train()\n",
1090
+ "train_metrics.reset()\n",
1091
+ "optimizer.zero_grad() # Initialize gradients\n",
1092
+ "\n",
1093
+ "# Track metrics for plotting\n",
1094
+ "train_steps = []\n",
1095
+ "train_losses = []\n",
1096
+ "train_pearson_scores = []\n",
1097
+ "val_steps = []\n",
1098
+ "val_losses = []\n",
1099
+ "val_pearson_scores = []\n",
1100
+ "\n",
1101
+ "# Initialize interactive plots using FigureWidget for real-time updates\n",
1102
+ "import plotly.graph_objects as go\n",
+ "from plotly.graph_objects import FigureWidget\n",
1103
+ "from plotly.subplots import make_subplots\n",
1104
+ "\n",
1105
+ "# Create base figure with subplots\n",
1106
+ "fig_base = make_subplots(\n",
1107
+ " rows=1, cols=2,\n",
1108
+ " subplot_titles=('Loss', 'Mean Pearson Correlation'),\n",
1109
+ " horizontal_spacing=0.15,\n",
1110
+ ")\n",
1111
+ "\n",
1112
+ "# Add empty traces for train and val metrics\n",
1113
+ "fig_base.add_trace(\n",
1114
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Train Loss', line=dict(color='blue')),\n",
1115
+ " row=1, col=1\n",
1116
+ ")\n",
1117
+ "fig_base.add_trace(\n",
1118
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Val Loss', line=dict(color='red')),\n",
1119
+ " row=1, col=1\n",
1120
+ ")\n",
1121
+ "fig_base.add_trace(\n",
1122
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Train Pearson', line=dict(color='green')),\n",
1123
+ " row=1, col=2\n",
1124
+ ")\n",
1125
+ "fig_base.add_trace(\n",
1126
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Val Pearson', line=dict(color='orange')),\n",
1127
+ " row=1, col=2\n",
1128
+ ")\n",
1129
+ "\n",
1130
+ "fig_base.update_xaxes(title_text=\"Step\", row=1, col=1)\n",
1131
+ "fig_base.update_xaxes(title_text=\"Step\", row=1, col=2)\n",
1132
+ "fig_base.update_yaxes(title_text=\"Loss\", row=1, col=1)\n",
1133
+ "fig_base.update_yaxes(title_text=\"Pearson Correlation\", row=1, col=2)\n",
1134
+ "fig_base.update_layout(height=800, width=1600, showlegend=True)\n",
1135
+ "\n",
1136
+ "# Convert to FigureWidget for interactive updates\n",
1137
+ "fig = FigureWidget(fig_base)\n",
1138
+ "\n",
1139
+ "# Display initial plot (will update in place during training)\n",
1140
+ "display(fig)\n",
1141
+ "\n",
1142
+ "# Create iterator for training data (will cycle if needed)\n",
1143
+ "train_iter = iter(train_loader)\n",
1144
+ "\n",
1145
+ "# Main training loop\n",
1146
+ "for step_idx in range(config[\"num_steps_training\"]):\n",
1147
+ " try:\n",
1148
+ " batch = next(train_iter)\n",
1149
+ " except StopIteration:\n",
1150
+ " # Restart iterator if we run out of data\n",
1151
+ " train_iter = iter(train_loader)\n",
1152
+ " batch = next(train_iter)\n",
1153
+ " \n",
1154
+ " # Forward pass and backward pass\n",
1155
+ " loss = train_step(model, batch)\n",
1156
+ " \n",
1157
+ " # Update optimizer\n",
1158
+ " optimizer.step()\n",
1159
+ " optimizer.zero_grad()\n",
1160
+ " \n",
1161
+ " # Update metrics\n",
1162
+ " tokens = batch[\"tokens\"].to(device)\n",
1163
+ " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1164
+ " with torch.no_grad():\n",
1165
+ " outputs = model(tokens=tokens)\n",
1166
+ " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
1167
+ " \n",
1168
+ " train_metrics.update(\n",
1169
+ " predictions=bigwig_logits,\n",
1170
+ " targets=bigwig_targets,\n",
1171
+ " loss=loss\n",
1172
+ " )\n",
1173
+ " \n",
1174
+ " # Logging\n",
1175
+ " if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
1176
+ " train_metrics_dict = train_metrics.compute()\n",
1177
+ " \n",
1178
+ " # Get accumulated mean loss across all batches since last reset\n",
1179
+ " mean_loss = train_metrics_dict['loss']\n",
1180
+ " \n",
1181
+ " # Track metrics for plotting\n",
1182
+ " train_steps.append(step_idx + 1)\n",
1183
+ " train_losses.append(mean_loss)\n",
1184
+ " train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
1185
+ " \n",
1186
+ " # Update plots - direct assignment to FigureWidget data updates the plot automatically\n",
1187
+ " fig.data[0].x = train_steps\n",
1188
+ " fig.data[0].y = train_losses\n",
1189
+ " fig.data[2].x = train_steps\n",
1190
+ " fig.data[2].y = train_pearson_scores\n",
1191
+ " \n",
1192
+ " print(\n",
1193
+ " f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
1194
+ " f\"Loss: {mean_loss:.4f} | \"\n",
1195
+ " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
1196
+ " f\"Pearson per track: \" + \", \".join(f\"{t}: {train_metrics_dict[f'{t}/pearson']:.4f}\" for t in config['bigwig_file_ids'])\n",
1197
+ " )\n",
1198
+ " train_metrics.reset()\n",
1199
+ " \n",
1200
+ " # Validation\n",
1201
+ " if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
1202
+ " print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
1203
+ " val_metrics.reset()\n",
1204
+ " model.eval()\n",
1205
+ " \n",
1206
+ " for val_batch in val_loader:\n",
1207
+ " val_loss = validation_step(model, val_batch, val_metrics)\n",
1208
+ " \n",
1209
+ " # Print validation metrics\n",
1210
+ " val_metrics_dict = val_metrics.compute()\n",
1211
+ " val_pearson_mean = val_metrics_dict['mean/pearson']\n",
1212
+ " \n",
1213
+ " # Track validation metrics\n",
1214
+ " val_steps.append(step_idx + 1)\n",
1215
+ " val_losses.append(val_metrics_dict['loss'])\n",
1216
+ " val_pearson_scores.append(val_pearson_mean)\n",
1217
+ " \n",
1218
+ " # Update plots with validation data - direct assignment updates the plot automatically\n",
1219
+ " fig.data[1].x = val_steps\n",
1220
+ " fig.data[1].y = val_losses\n",
1221
+ " fig.data[3].x = val_steps\n",
1222
+ " fig.data[3].y = val_pearson_scores\n",
1223
+ " \n",
1224
+ " print(f\" Validation Loss: {val_metrics_dict['loss']:.4f}\")\n",
1225
+ " print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
1226
+ " for track_name in config[\"bigwig_file_ids\"]:\n",
1227
+ " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
1228
+ " \n",
1229
+ " model.train() # Back to training mode\n",
1230
+ "\n",
1231
+ "print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")"
1232
+ ]
1233
+ },
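The `StopIteration`/restart pattern in the loop above, which lets a fixed step budget outlast one pass over the dataloader, can be factored into a small helper (a sketch; the notebook inlines it):

```python
def cycling_batches(loader, num_steps):
    """Yield exactly num_steps batches, restarting the iterator when exhausted."""
    it = iter(loader)
    for _ in range(num_steps):
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)
            yield next(it)

# Any iterable works; a plain list stands in for a DataLoader here
batches = list(cycling_batches([1, 2, 3], 7))  # [1, 2, 3, 1, 2, 3, 1]
```

Re-creating the iterator (rather than `itertools.cycle`) matters for a shuffling `DataLoader`: each restart draws a freshly shuffled epoch instead of replaying cached batches.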
1234
+ {
1235
+ "cell_type": "markdown",
1236
+ "metadata": {},
1237
+ "source": [
1238
+ "# 9. 🧪 Test evaluation\n",
1239
+ "\n",
1240
+ "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
1241
+ ]
1242
+ },
1243
+ {
1244
+ "cell_type": "code",
1245
+ "execution_count": null,
1246
+ "metadata": {},
1247
+ "outputs": [],
1248
+ "source": [
1249
+ "# Calculate number of test steps (based on deepspeed pipeline)\n",
1250
+ "num_test_samples = len(test_dataset)\n",
1251
+ "num_test_steps = num_test_samples // config[\"batch_size\"]\n",
1252
+ "print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
1253
+ "\n",
1254
+ "# Set model to eval mode\n",
1255
+ "model.eval()\n",
1256
+ "\n",
1257
+ "# Run test evaluation with progress bar\n",
1258
+ "for test_batch in tqdm(test_loader, desc=\"Test evaluation\", total=num_test_steps): \n",
1259
+ " _ = validation_step( \n",
1260
+ " model, \n",
1261
+ " test_batch, \n",
1262
+ " test_metrics,\n",
1263
+ " )\n",
1264
+ " \n",
1265
+ "# Compute final test metrics\n",
1266
+ "test_metrics_dict = test_metrics.compute()\n",
1267
+ "print(\"\\n\" + \"=\"*50)\n",
1268
+ "print(\"Test Set Results\")\n",
1269
+ "print(\"=\"*50)\n",
1270
+ "print(f\"\\nMetrics:\")\n",
1271
+ "print(f\" Mean Pearson: {test_metrics_dict['mean/pearson']:.4f}\")\n",
1272
+ "for track_name in config[\"bigwig_file_ids\"]: \n",
1273
+ " print(f\" {track_name}/pearson: {test_metrics_dict[f'{track_name}/pearson']:.4f}\")"
1274
+ ]
1275
+ },
1276
+ {
1277
+ "cell_type": "markdown",
1278
+ "metadata": {},
1279
+ "source": [
1280
+ "## Test set results\n",
1281
+ "\n",
1282
+ "Performance reached after ~1.5B tokens (~1,500 steps in the current 32 kb sequence setup with batch_size=32):\n",
1283
+ "\n",
1284
+ "Mean Pearson: 0.5835\n",
1285
+ "- ENCSR325NFE/pearson: 0.6081\n",
1286
+ "- ENCSR962OTG/pearson: 0.7286\n",
1287
+ "- ENCSR619DQO_P/pearson: 0.4976\n",
1288
+ "- ENCSR619DQO_M/pearson: 0.4999"
1289
+ ]
1290
+ }
1296
+ ],
1297
+ "metadata": {
1298
+ "kernelspec": {
1299
+ "display_name": "Python 3.12 (ntv3-env)",
1300
+ "language": "python",
1301
+ "name": "ntv3-env"
1302
+ },
1303
+ "language_info": {
1304
+ "codemirror_mode": {
1305
+ "name": "ipython",
1306
+ "version": 3
1307
+ },
1308
+ "file_extension": ".py",
1309
+ "mimetype": "text/x-python",
1310
+ "name": "python",
1311
+ "nbconvert_exporter": "python",
1312
+ "pygments_lexer": "ipython3",
1313
+ "version": "3.12.3"
1314
+ }
1315
+ },
1316
+ "nbformat": 4,
1317
+ "nbformat_minor": 2
1318
+ }
tabs/home.html CHANGED
@@ -84,7 +84,7 @@
84
  <ul>
85
  <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/00_quickstart_inference.ipynb" target="_blank" rel="noopener noreferrer">🚀 00 — Quickstart inference</a></li>
86
  <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/01_tracks_prediction.ipynb" target="_blank" rel="noopener noreferrer">📊 01 — Tracks prediction</a></li>
87
- <li>🎯 02 — Fine-tune on bigwig tracks</li>
88
  <li>🔍 03 — Model interpretation</li>
89
  <li>🧪 04 — Training NTv3 generative </li>
90
  </ul>
 
84
  <ul>
85
  <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/00_quickstart_inference.ipynb" target="_blank" rel="noopener noreferrer">🚀 00 — Quickstart inference</a></li>
86
  <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/01_tracks_prediction.ipynb" target="_blank" rel="noopener noreferrer">📊 01 — Tracks prediction</a></li>
87
+ <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning.ipynb" target="_blank" rel="noopener noreferrer">🎯 02 — Fine-tune on bigwig tracks</a></li>
88
  <li>🔍 03 — Model interpretation</li>
89
  <li>🧪 04 — Training NTv3 generative </li>
90
  </ul>