ybornachot committed on
Commit
0861370
·
1 Parent(s): eadf098

feat: multiproc data download + cleaned cells

Files changed (1)
  1. notebooks/03_fine_tuning.ipynb +260 -215
notebooks/03_fine_tuning.ipynb CHANGED
@@ -4,13 +4,13 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# \ud83e\uddec Fine-Tuning a Model on BigWig Tracks Prediction\n",
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
- "**\u26a1 Key Advantage**: This simplified pipeline achieves **close performance to more complex training approaches** while enabling **relatively fast fine-tuning in approximately one hour**. The setup is designed for rapid experimentation and iteration, making it ideal for adapting pre-trained models to your specific genomic tracks or experimental conditions without the overhead of complex distributed training infrastructure.\n",
12
  "\n",
13
- "**\ud83d\udd27 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
  "\n",
15
  "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
@@ -30,7 +30,14 @@
30
  "\n",
31
  "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
  "\n",
33
- "\ud83d\udcdd Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime \u2192 Change runtime type \u2192 GPU (T4 or better recommended).\n"
34
  ]
35
  },
36
  {
@@ -43,11 +50,39 @@
43
  "!pip install pyfaidx pyBigWig torchmetrics transformers plotly"
44
  ]
45
  },
46
  {
47
  "cell_type": "markdown",
48
  "metadata": {},
49
  "source": [
50
- "# 1. \ud83d\udce6 Imports + Configuration\n",
51
  "\n",
52
  "## Configuration Parameters\n",
53
  "\n",
@@ -81,37 +116,17 @@
81
  },
82
  {
83
  "cell_type": "code",
84
- "execution_count": null,
85
  "metadata": {},
86
- "outputs": [],
87
- "source": [
88
- "# 0. Imports\n",
89
- "import random\n",
90
- "import functools\n",
91
- "from typing import List, Dict, Callable\n",
92
- "import os\n",
93
- "import subprocess\n",
94
- "\n",
95
- "import torch\n",
96
- "import torch.nn as nn\n",
97
- "import torch.nn.functional as F\n",
98
- "from torch.utils.data import Dataset, DataLoader\n",
99
- "from torch.optim import AdamW\n",
100
- "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
101
- "import numpy as np\n",
102
- "import pyBigWig\n",
103
- "from pyfaidx import Fasta\n",
104
- "from torchmetrics import PearsonCorrCoef\n",
105
- "import plotly.graph_objects as go\n",
106
- "from IPython.display import display\n",
107
- "from tqdm import tqdm"
108
- ]
109
- },
110
- {
111
- "cell_type": "code",
112
- "execution_count": null,
113
- "metadata": {},
114
- "outputs": [],
115
  "source": [
116
  "config = {\n",
117
  " # Model\n",
@@ -121,7 +136,6 @@
121
  " \"data_cache_dir\": \"./data\",\n",
122
  " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
123
  " \"bigwig_url_list\": [\n",
124
- " # \"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\",\n",
125
  " \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
126
  " \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
127
  " \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
@@ -167,6 +181,8 @@
167
  " for url in config[\"bigwig_url_list\"]\n",
168
  "]\n",
169
  "\n",
170
  "# Create bigwig_file_ids from filenames (without extension)\n",
171
  "config[\"bigwig_file_ids\"] = [\n",
172
  " # os.path.splitext(extract_filename_from_url(url))[0]\n",
@@ -190,7 +206,7 @@
190
  "cell_type": "markdown",
191
  "metadata": {},
192
  "source": [
193
- "# 2. \ud83d\udce5 Genome & Tracks Data Download\n",
194
  "\n",
195
  "Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
196
  ]
@@ -201,36 +217,56 @@
201
  "metadata": {},
202
  "outputs": [],
203
  "source": [
204
- "# Download fasta file\n",
205
  "fasta_filename = extract_filename_from_url(config[\"fasta_url\"])\n",
206
  "fasta_gz_path = os.path.join(config[\"data_cache_dir\"], fasta_filename)\n",
207
  "\n",
208
- "print(f\"Downloading {fasta_filename}...\")\n",
209
- "subprocess.run([\"wget\", \"-c\", config[\"fasta_url\"], \"-O\", fasta_gz_path], check=True)\n",
210
- "\n",
211
- "print(f\"Extracting {fasta_filename}...\")\n",
212
- "subprocess.run([\"gunzip\", \"-f\", fasta_gz_path], check=True)"
213
- ]
214
- },
215
- {
216
- "cell_type": "code",
217
- "execution_count": null,
218
- "metadata": {},
219
- "outputs": [],
220
- "source": [
221
- "# Download bigwig files\n",
222
  "for bigwig_url in config[\"bigwig_url_list\"]:\n",
223
  " filename = extract_filename_from_url(bigwig_url)\n",
224
  " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
225
- " print(f\"Downloading {filename}...\")\n",
226
- " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)"
227
  ]
228
  },
229
  {
230
  "cell_type": "markdown",
231
  "metadata": {},
232
  "source": [
233
- "## Data Splits Definition"
234
  ]
235
  },
236
  {
@@ -250,7 +286,7 @@
250
  "cell_type": "markdown",
251
  "metadata": {},
252
  "source": [
253
- "# 3. \ud83e\udde0 Model and tokenizer setup\n",
254
  " \n",
255
  "In this section, we set up the model and tokenizer. \n",
256
  " \n",
@@ -354,163 +390,11 @@
354
  "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
355
  ]
356
  },
357
- {
358
- "cell_type": "code",
359
- "execution_count": null,
360
- "metadata": {},
361
- "outputs": [],
362
- "source": [
363
- "# Scaling functions for targets\n",
364
- "def compute_chromosome_stats(track_data: np.ndarray) -> dict:\n",
365
- " \"\"\"\n",
366
- " Compute minimal statistics needed for weighted mean computation.\n",
367
- " \n",
368
- " Args:\n",
369
- " track_data: numpy array of track values for a chromosome\n",
370
- " \n",
371
- " Returns:\n",
372
- " Dictionary with statistics: sum, mean, total_count\n",
373
- " \"\"\"\n",
374
- " track_data = track_data.astype(np.float32)\n",
375
- " \n",
376
- " # Compute statistics\n",
377
- " sum_all = np.sum(track_data)\n",
378
- " total_count = track_data.size\n",
379
- " mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
380
- " \n",
381
- " return {\n",
382
- " \"sum\": sum_all,\n",
383
- " \"mean\": mean_all,\n",
384
- " \"total_count\": total_count,\n",
385
- " }\n",
386
- "\n",
387
- "\n",
388
- "def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
389
- " \"\"\"\n",
390
- " Aggregate chromosome-level statistics into file-level statistics.\n",
391
- " \n",
392
- " Args:\n",
393
- " chr_stats_list: List of dictionaries, each containing chromosome-level statistics\n",
394
- " \n",
395
- " Returns:\n",
396
- " Dictionary with aggregated file-level statistics (only mean)\n",
397
- " \"\"\"\n",
398
- " # Convert to arrays for easier computation\n",
399
- " total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
400
- " means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
401
- " sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
402
- " \n",
403
- " # Aggregate total count\n",
404
- " total_count = np.sum(total_counts)\n",
405
- " \n",
406
- " # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
407
- " mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
408
- " \n",
409
- " return {\n",
410
- " \"total_count\": total_count,\n",
411
- " \"sum\": np.sum(sums),\n",
412
- " \"mean\": mean,\n",
413
- " }\n",
414
- "\n",
415
- "\n",
416
- "def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
417
- " \"\"\"\n",
418
- " Get track means for normalization.\n",
419
- " Computes statistics per chromosome and aggregates using weighted averaging,\n",
420
- " \n",
421
- " Args:\n",
422
- " bigwig_tracks_list: List of pyBigWig file objects\n",
423
- " \n",
424
- " Returns:\n",
425
- " Array of track means, one per bigwig file\n",
426
- " \"\"\"\n",
427
- " track_means = []\n",
428
- " \n",
429
- " for bigwig_track in bigwig_tracks_list:\n",
430
- " chrom_lengths = bigwig_track.chroms()\n",
431
- " all_chr_stats = []\n",
432
- " \n",
433
- " # Compute statistics for each chromosome\n",
434
- " for chrom_name, chrom_length in chrom_lengths.items():\n",
435
- " try:\n",
436
- " # Get chromosome data as numpy array\n",
437
- " bw_array = np.array(\n",
438
- " bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
439
- " dtype=np.float32\n",
440
- " )\n",
441
- " # Replace NaN with 0\n",
442
- " bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
443
- " \n",
444
- " # Compute chromosome-level statistics\n",
445
- " chr_stats = compute_chromosome_stats(bw_array)\n",
446
- " all_chr_stats.append(chr_stats)\n",
447
- " except Exception as e:\n",
448
- " # Skip chromosomes that fail to load\n",
449
- " print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
450
- " continue\n",
451
- " \n",
452
- " if not all_chr_stats:\n",
453
- " raise ValueError(f\"No valid chromosomes found for bigwig track\")\n",
454
- " \n",
455
- " # Aggregate chromosome-level stats into file-level stats\n",
456
- " file_stats = aggregate_file_statistics(all_chr_stats)\n",
457
- " \n",
458
- " # Use the weighted mean for normalization\n",
459
- " track_means.append(file_stats[\"mean\"])\n",
460
- " \n",
461
- " return np.array(track_means, dtype=np.float32)\n",
462
- "\n",
463
- "\n",
464
- "def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
465
- " \"\"\"\n",
466
- " Build a scaling function based on track means computed from bigwig files.\n",
467
- " \n",
468
- " Opens bigwig files, computes track statistics, and creates a transform function.\n",
469
- " The statistics are computed once and reused for all calls to the returned transform function.\n",
470
- " \n",
471
- " Args:\n",
472
- " bigwig_path_list: List of paths to bigwig files\n",
473
- " \n",
474
- " Returns:\n",
475
- " Transform function that scales input tensors\n",
476
- " \"\"\"\n",
477
- " # Open bigwig files and compute track statistics\n",
478
- " print(\"Computing track statistics (this may take a while)...\")\n",
479
- " bw_list = [\n",
480
- " pyBigWig.open(bigwig_path)\n",
481
- " for bigwig_path in bigwig_path_list\n",
482
- " ]\n",
483
- " track_means = get_track_means(bw_list)\n",
484
- " print(f\"Computed track means: {track_means}\")\n",
485
- " print(f\"Track means shape: {track_means.shape}\")\n",
486
- " \n",
487
- " # Create tensor from computed means\n",
488
- " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
489
- " \n",
490
- " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
491
- " \"\"\"\n",
492
- " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
493
- " \"\"\"\n",
494
- " # Move constants to correct device then normalize\n",
495
- " means = track_means_tensor.to(x.device)\n",
496
- " scaled = x / means\n",
497
- "\n",
498
- " # Smooth clipping: if > 10, apply formula\n",
499
- " clipped = torch.where(\n",
500
- " scaled > 10.0,\n",
501
- " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
502
- " scaled,\n",
503
- " )\n",
504
- " return clipped\n",
505
- " \n",
506
- " return transform_fn"
507
- ]
508
- },
509
  {
510
  "cell_type": "markdown",
511
  "metadata": {},
512
  "source": [
513
- "# 4. \ud83d\udd04 Data loading\n",
514
  "\n",
515
  "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
516
  ]
@@ -678,6 +562,165 @@
678
  " return sample"
679
  ]
680
  },
681
  {
682
  "cell_type": "code",
683
  "execution_count": null,
@@ -751,7 +794,7 @@
751
  "cell_type": "markdown",
752
  "metadata": {},
753
  "source": [
754
- "# 5. \u2699\ufe0f Optimizer setup\n",
755
  "\n",
756
  "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function.\n",
757
  "\n"
@@ -785,7 +828,7 @@
785
  "cell_type": "markdown",
786
  "metadata": {},
787
  "source": [
788
- "# 6. \ud83d\udcca Metrics setup\n",
789
  "\n",
790
  "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
791
  ]
@@ -878,7 +921,7 @@
878
  "cell_type": "markdown",
879
  "metadata": {},
880
  "source": [
881
- "# 7. \ud83d\udcc9 Loss functions\n",
882
  "\n",
883
  "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
884
  ]
@@ -960,7 +1003,7 @@
960
  "cell_type": "markdown",
961
  "metadata": {},
962
  "source": [
963
- "# 8. \ud83c\udfc3 Training loop\n",
964
  "\n",
965
  "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
966
  ]
@@ -1192,7 +1235,7 @@
1192
  "cell_type": "markdown",
1193
  "metadata": {},
1194
  "source": [
1195
- "# 9. \ud83e\uddea Test evaluation\n",
1196
  "\n",
1197
  "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
1198
  ]
@@ -1236,6 +1279,8 @@
1236
  "source": [
1237
  " ## Test set results\n",
1238
  "\n",
1239
  "Mean Pearson: 0.5835\n",
1240
  "- ENCSR325NFE/pearson: 0.6081\n",
1241
  "- ENCSR962OTG/pearson: 0.7286\n",
@@ -1270,4 +1315,4 @@
1270
  },
1271
  "nbformat": 4,
1272
  "nbformat_minor": 2
1273
- }
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# 🧬 Fine-Tuning a Model on BigWig Tracks Prediction\n",
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
+ "**⚡ Key Advantage**: This simplified pipeline achieves performance close to that of more complex training approaches while enabling fast fine-tuning. The training speed benefits from the efficient NTv3 model architecture and depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time). With NTv3 models, meaningful Pearson correlations can typically be reached within ~10 minutes of training on a 32kb functional tracks prediction task.\n",
12
  "\n",
13
+ "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
  "\n",
15
  "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
 
30
  "\n",
31
  "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
  "\n",
33
+ "📝 Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended).\n"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "# 0. 📦 Import dependencies"
41
  ]
42
  },
43
  {
 
50
  "!pip install pyfaidx pyBigWig torchmetrics transformers plotly"
51
  ]
52
  },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 20,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "import random\n",
60
+ "import functools\n",
61
+ "from typing import List, Dict, Callable\n",
62
+ "import os\n",
63
+ "import subprocess\n",
64
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
65
+ "\n",
66
+ "import torch\n",
67
+ "import torch.nn as nn\n",
68
+ "import torch.nn.functional as F\n",
69
+ "from torch.utils.data import Dataset, DataLoader\n",
70
+ "from torch.optim import AdamW\n",
71
+ "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
72
+ "import numpy as np\n",
73
+ "import pyBigWig\n",
74
+ "from pyfaidx import Fasta\n",
75
+ "from torchmetrics import PearsonCorrCoef\n",
76
+ "import plotly.graph_objects as go\n",
77
+ "from IPython.display import display\n",
78
+ "from tqdm import tqdm"
79
+ ]
80
+ },
81
  {
82
  "cell_type": "markdown",
83
  "metadata": {},
84
  "source": [
85
+ "# 1. ⚙️ Configuration\n",
86
  "\n",
87
  "## Configuration Parameters\n",
88
  "\n",
 
116
  },
117
  {
118
  "cell_type": "code",
119
+ "execution_count": 21,
120
  "metadata": {},
121
+ "outputs": [
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "Using device: cpu\n"
127
+ ]
128
+ }
129
+ ],
130
  "source": [
131
  "config = {\n",
132
  " # Model\n",
 
136
  " \"data_cache_dir\": \"./data\",\n",
137
  " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
138
  " \"bigwig_url_list\": [\n",
 
139
  " \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
140
  " \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
141
  " \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
 
181
  " for url in config[\"bigwig_url_list\"]\n",
182
  "]\n",
183
  "\n",
184
+ "\n",
185
+ "# TODO: find a way to link the experiment accession to bigwig file ids\n",
186
  "# Create bigwig_file_ids from filenames (without extension)\n",
187
  "config[\"bigwig_file_ids\"] = [\n",
188
  " # os.path.splitext(extract_filename_from_url(url))[0]\n",
 
206
  "cell_type": "markdown",
207
  "metadata": {},
208
  "source": [
209
+ "# 2. 📥 Genome & Tracks Data Download\n",
210
  "\n",
211
  "Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
212
  ]
 
217
  "metadata": {},
218
  "outputs": [],
219
  "source": [
220
+ "def _download_file(url: str, output_path: str) -> None:\n",
221
+ " \"\"\"Download a file from URL to output_path using wget.\"\"\"\n",
222
+ " subprocess.run([\"wget\", \"-c\", url, \"-O\", output_path], check=True)\n",
223
+ "\n",
224
+ "# Prepare download tasks: (url, output_path)\n",
225
+ "download_tasks = []\n",
226
+ "\n",
227
+ "# FASTA file\n",
228
  "fasta_filename = extract_filename_from_url(config[\"fasta_url\"])\n",
229
  "fasta_gz_path = os.path.join(config[\"data_cache_dir\"], fasta_filename)\n",
230
+ "download_tasks.append((config[\"fasta_url\"], fasta_gz_path))\n",
231
  "\n",
232
+ "# BigWig files\n",
 
233
  "for bigwig_url in config[\"bigwig_url_list\"]:\n",
234
  " filename = extract_filename_from_url(bigwig_url)\n",
235
  " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
236
+ " download_tasks.append((bigwig_url, filepath))\n",
237
+ "\n",
238
+ "# Download files in parallel\n",
239
+ "max_workers = min(len(download_tasks), 8)\n",
240
+ "\n",
241
+ "print(f\"Downloading {len(download_tasks)} files using {max_workers} workers...\")\n",
242
+ "with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
243
+ " # Submit all download tasks\n",
244
+ " future_to_path = {\n",
245
+ " executor.submit(_download_file, url, path): path\n",
246
+ " for url, path in download_tasks\n",
247
+ " }\n",
248
+ " \n",
249
+ " # Wait for all downloads to complete\n",
250
+ " for future in as_completed(future_to_path):\n",
251
+ " try:\n",
252
+ " future.result() # Raises exception if download failed\n",
253
+ " path = future_to_path[future]\n",
254
+ " print(f\"✓ Downloaded: {os.path.basename(path)}\")\n",
255
+ " except Exception as e:\n",
256
+ " path = future_to_path[future]\n",
257
+ " raise RuntimeError(f\"Failed to download {path}: {e}\") from e\n",
258
+ "\n",
259
+ "# Extract FASTA file after download\n",
260
+ "print(f\"\\nExtracting {fasta_filename}...\")\n",
261
+ "subprocess.run([\"gunzip\", \"-f\", fasta_gz_path], check=True)\n",
262
+ "print(\"✓ Extraction complete\")"
263
  ]
264
  },
265
  {
266
  "cell_type": "markdown",
267
  "metadata": {},
268
  "source": [
269
+ "### Data Splits Definition"
270
  ]
271
  },
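The chromosome-based splitting described above can be sketched in a few lines. Chromosome assignments here are hypothetical, not the notebook's actual split:

```python
# Sketch of a chromosome-based split: each chromosome belongs to exactly one
# split, so no genomic region leaks between train, validation, and test.
splits = {
    "train": [f"chr{i}" for i in range(1, 21)],  # chr1..chr20
    "val": ["chr21", "chrX"],
    "test": ["chr22"],
}

# Splits must be disjoint for the held-out evaluation to be unbiased
all_chroms = sum(splits.values(), [])
assert len(all_chroms) == len(set(all_chroms))
```

Assigning whole chromosomes (rather than sub-regions) keeps the split logic trivial at the cost of slightly less balanced split sizes.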
272
  {
 
286
  "cell_type": "markdown",
287
  "metadata": {},
288
  "source": [
289
+ "# 3. 🧠 Model and tokenizer setup\n",
290
  " \n",
291
  "In this section, we set up the model and tokenizer. \n",
292
  " \n",
 
390
  "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
391
  ]
392
  },
393
  {
394
  "cell_type": "markdown",
395
  "metadata": {},
396
  "source": [
397
+ "# 4. 🔄 Data loading\n",
398
  "\n",
399
  "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
400
  ]
 
562
  " return sample"
563
  ]
564
  },
565
+ {
566
+ "cell_type": "markdown",
567
+ "metadata": {},
568
+ "source": [
569
+ "### Data preprocessing utilities"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "# Scaling functions for targets\n",
579
+ "def compute_chromosome_stats(track_data: np.ndarray) -> dict:\n",
580
+ " \"\"\"\n",
581
+ " Compute minimal statistics needed for weighted mean computation.\n",
582
+ " \n",
583
+ " Args:\n",
584
+ " track_data: numpy array of track values for a chromosome\n",
585
+ " \n",
586
+ " Returns:\n",
587
+ " Dictionary with statistics: sum, mean, total_count\n",
588
+ " \"\"\"\n",
589
+ " track_data = track_data.astype(np.float32)\n",
590
+ " \n",
591
+ " # Compute statistics\n",
592
+ " sum_all = np.sum(track_data)\n",
593
+ " total_count = track_data.size\n",
594
+ " mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
595
+ " \n",
596
+ " return {\n",
597
+ " \"sum\": sum_all,\n",
598
+ " \"mean\": mean_all,\n",
599
+ " \"total_count\": total_count,\n",
600
+ " }\n",
601
+ "\n",
602
+ "\n",
603
+ "def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
604
+ " \"\"\"\n",
605
+ " Aggregate chromosome-level statistics into file-level statistics.\n",
606
+ " \n",
607
+ " Args:\n",
608
+ " chr_stats_list: List of dictionaries, each containing chromosome-level statistics\n",
609
+ " \n",
610
+ " Returns:\n",
611
+ " Dictionary with aggregated file-level statistics (only mean)\n",
612
+ " \"\"\"\n",
613
+ " # Convert to arrays for easier computation\n",
614
+ " total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
615
+ " means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
616
+ " sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
617
+ " \n",
618
+ " # Aggregate total count\n",
619
+ " total_count = np.sum(total_counts)\n",
620
+ " \n",
621
+ " # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
622
+ " mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
623
+ " \n",
624
+ " return {\n",
625
+ " \"total_count\": total_count,\n",
626
+ " \"sum\": np.sum(sums),\n",
627
+ " \"mean\": mean,\n",
628
+ " }\n",
629
+ "\n",
630
+ "\n",
631
+ "def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
632
+ " \"\"\"\n",
633
+ " Get track means for normalization.\n",
634
+ " Computes statistics per chromosome and aggregates using weighted averaging.\n",
635
+ " \n",
636
+ " Args:\n",
637
+ " bigwig_tracks_list: List of pyBigWig file objects\n",
638
+ " \n",
639
+ " Returns:\n",
640
+ " Array of track means, one per bigwig file\n",
641
+ " \"\"\"\n",
642
+ " track_means = []\n",
643
+ " \n",
644
+ " for bigwig_track in bigwig_tracks_list:\n",
645
+ " chrom_lengths = bigwig_track.chroms()\n",
646
+ " all_chr_stats = []\n",
647
+ " \n",
648
+ " # Compute statistics for each chromosome\n",
649
+ " for chrom_name, chrom_length in chrom_lengths.items():\n",
650
+ " try:\n",
651
+ " # Get chromosome data as numpy array\n",
652
+ " bw_array = np.array(\n",
653
+ " bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
654
+ " dtype=np.float32\n",
655
+ " )\n",
656
+ " # Replace NaN with 0\n",
657
+ " bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
658
+ " \n",
659
+ " # Compute chromosome-level statistics\n",
660
+ " chr_stats = compute_chromosome_stats(bw_array)\n",
661
+ " all_chr_stats.append(chr_stats)\n",
662
+ " except Exception as e:\n",
663
+ " # Skip chromosomes that fail to load\n",
664
+ " print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
665
+ " continue\n",
666
+ " \n",
667
+ " if not all_chr_stats:\n",
668
+ " raise ValueError(\"No valid chromosomes found for bigwig track\")\n",
669
+ " \n",
670
+ " # Aggregate chromosome-level stats into file-level stats\n",
671
+ " file_stats = aggregate_file_statistics(all_chr_stats)\n",
672
+ " \n",
673
+ " # Use the weighted mean for normalization\n",
674
+ " track_means.append(file_stats[\"mean\"])\n",
675
+ " \n",
676
+ " return np.array(track_means, dtype=np.float32)\n",
677
+ "\n",
678
+ "\n",
679
+ "def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
680
+ " \"\"\"\n",
681
+ " Build a scaling function based on track means computed from bigwig files.\n",
682
+ " \n",
683
+ " Opens bigwig files, computes track statistics, and creates a transform function.\n",
684
+ " The statistics are computed once and reused for all calls to the returned transform function.\n",
685
+ " \n",
686
+ " Args:\n",
687
+ " bigwig_path_list: List of paths to bigwig files\n",
688
+ " \n",
689
+ " Returns:\n",
690
+ " Transform function that scales input tensors\n",
691
+ " \"\"\"\n",
692
+ " # Open bigwig files and compute track statistics\n",
693
+ " print(\"Computing track statistics (this may take a while)...\")\n",
694
+ " bw_list = [\n",
695
+ " pyBigWig.open(bigwig_path)\n",
696
+ " for bigwig_path in bigwig_path_list\n",
697
+ " ]\n",
698
+ " track_means = get_track_means(bw_list)\n",
699
+ " print(f\"Computed track means: {track_means}\")\n",
700
+ " print(f\"Track means shape: {track_means.shape}\")\n",
701
+ " \n",
702
+ " # Create tensor from computed means\n",
703
+ " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
704
+ " \n",
705
+ " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
706
+ " \"\"\"\n",
707
+ " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
708
+ " \"\"\"\n",
709
+ " # Move constants to correct device then normalize\n",
710
+ " means = track_means_tensor.to(x.device)\n",
711
+ " scaled = x / means\n",
712
+ "\n",
713
+ " # Smooth clipping: if > 10, apply formula\n",
714
+ " clipped = torch.where(\n",
715
+ " scaled > 10.0,\n",
716
+ " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
717
+ " scaled,\n",
718
+ " )\n",
719
+ " return clipped\n",
720
+ " \n",
721
+ " return transform_fn"
722
+ ]
723
+ },
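The smooth clipping in `transform_fn` can be checked in scalar form: both branches agree at the threshold 10, so the mapping is continuous while strongly compressing large outliers. A standalone sketch of the same formula:

```python
import math

def smooth_clip(x: float) -> float:
    # Mirrors transform_fn's piecewise rule: identity up to 10,
    # then 2*sqrt(10*x) - 10, which compresses large values.
    return 2.0 * math.sqrt(10.0 * x) - 10.0 if x > 10.0 else x

assert smooth_clip(10.0) == 10.0     # both branches give 10 at the threshold
assert smooth_clip(40.0) == 30.0     # 2*sqrt(400) - 10
assert smooth_clip(1000.0) == 190.0  # 2*sqrt(10000) - 10
```

Note the compression: a 100x outlier (1000 vs. 10) is mapped to only 19x the threshold value, which stabilizes training against extreme signal peaks.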
724
  {
725
  "cell_type": "code",
726
  "execution_count": null,
 
794
  "cell_type": "markdown",
795
  "metadata": {},
796
  "source": [
797
+ "# 5. ⚙️ Optimizer setup\n",
798
  "\n",
799
  "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function.\n",
800
  "\n"
 
828
  "cell_type": "markdown",
829
  "metadata": {},
830
  "source": [
831
+ "# 6. 📊 Metrics setup\n",
832
  "\n",
833
  "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
834
  ]
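The quantity that `torchmetrics.PearsonCorrCoef` tracks per target track can be written out directly; a minimal pure-Python sketch of the same statistic:

```python
import math

def pearson(preds, target):
    # Pearson r: centered dot product divided by the product of norms.
    mp = sum(preds) / len(preds)
    mt = sum(target) / len(target)
    p = [x - mp for x in preds]
    t = [x - mt for x in target]
    num = sum(a * b for a, b in zip(p, t))
    den = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in t))
    return num / den

# A perfectly linear relationship gives r = 1.0
print(pearson([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]))
```

In the notebook one metric instance is kept per BigWig track, so each track's correlation is reported separately alongside the mean.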
 
921
  "cell_type": "markdown",
922
  "metadata": {},
923
  "source": [
924
+ "# 7. 📉 Loss functions\n",
925
  "\n",
926
  "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
927
  ]
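A minimal NumPy sketch of the Poisson-Multinomial idea (function name, `shape_weight`, and the exact term weighting are assumptions, not the notebook's implementation): the total-count term follows a Poisson likelihood to capture scale, and the positional term a multinomial cross-entropy to capture shape:

```python
import numpy as np

def poisson_multinomial_loss(pred, target, eps=1e-7, shape_weight=1.0):
    """Sketch of a Poisson-Multinomial loss for non-negative signal arrays
    of shape (seq_len, tracks); hypothetical, for illustration only."""
    pred_total = pred.sum(axis=0)      # predicted total signal per track
    target_total = target.sum(axis=0)
    # Poisson NLL on totals (up to a target-only constant): scale term
    poisson = (pred_total - target_total * np.log(pred_total + eps)).mean()
    # Cross-entropy between per-position distributions: shape term
    log_p = np.log(pred / (pred_total + eps) + eps)
    multinomial = -(target * log_p).sum(axis=0).mean()
    return poisson + shape_weight * multinomial
```

Splitting scale from shape means the model is penalized separately for predicting the wrong total amount of signal and for putting that signal at the wrong positions.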
 
1003
  "cell_type": "markdown",
1004
  "metadata": {},
1005
  "source": [
1006
+ "# 8. 🏃 Training loop\n",
1007
  "\n",
1008
  "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
1009
  ]
 
1235
  "cell_type": "markdown",
1236
  "metadata": {},
1237
  "source": [
1238
+ "# 9. 🧪 Test evaluation\n",
1239
  "\n",
1240
  "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
1241
  ]
 
1279
  "source": [
1280
  " ## Test set results\n",
1281
  "\n",
1282
+ "Performance reached at ~1.5B tokens (~1500 steps in the current 32kb-sequence setup with batch_size=32)\n",
1283
+ "\n",
1284
  "Mean Pearson: 0.5835\n",
1285
  "- ENCSR325NFE/pearson: 0.6081\n",
1286
  "- ENCSR962OTG/pearson: 0.7286\n",
 
1315
  },
1316
  "nbformat": 4,
1317
  "nbformat_minor": 2
1318
+ }