ybornachot commited on
Commit
eadf098
·
1 Parent(s): b0d8db2

fix: come back to older version

Browse files
Files changed (1) hide show
  1. notebooks/03_fine_tuning.ipynb +499 -170
notebooks/03_fine_tuning.ipynb CHANGED
@@ -4,15 +4,13 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# 🧬 Fine-Tuning a Model on BigWig Tracks Prediction\n",
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
- "**⚡ Key Advantage**: This simplified pipeline achieves **close performance to more complex training approaches** while enabling **fast fine-tuning**. The training speed benefits from the efficient NTv3 model architecture and depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time). With NTv3 models, meaningful Pearson correlations can typically be reached within ~10minutes of training on a 32kb functional tracks prediction task. \n",
12
  "\n",
13
- "While this notebook currently focuses on NTv3 models, the pipeline structure can be extended to work with other foundation models. The setup is designed for rapid experimentation and iteration, making it ideal for adapting pre-trained models to your specific genomic tracks or experimental conditions without the overhead of complex distributed training infrastructure.\n",
14
- "\n",
15
- "**🔧 Main Simplifications**: Compared to the full supervised tracks prediction pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
16
  "\n",
17
  "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
18
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
@@ -32,14 +30,7 @@
32
  "\n",
33
  "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
34
  "\n",
35
- "📝 Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime Change runtime type GPU (T4 or better recommended).\n"
36
- ]
37
- },
38
- {
39
- "cell_type": "markdown",
40
- "metadata": {},
41
- "source": [
42
- "# 0. 📦 Imports"
43
  ]
44
  },
45
  {
@@ -49,38 +40,14 @@
49
  "outputs": [],
50
  "source": [
51
  "# Install dependencies\n",
52
- "!pip install datasets transformers torchmetrics plotly "
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": 7,
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "# Imports\n",
62
- "from typing import List, Dict\n",
63
- "import os\n",
64
- "\n",
65
- "import torch\n",
66
- "import torch.nn as nn\n",
67
- "import torch.nn.functional as F\n",
68
- "from torch.utils.data import DataLoader\n",
69
- "from torch.optim import AdamW\n",
70
- "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
71
- "from datasets import load_dataset\n",
72
- "import numpy as np\n",
73
- "from torchmetrics import PearsonCorrCoef\n",
74
- "import plotly.graph_objects as go\n",
75
- "from IPython.display import display\n",
76
- "from tqdm import tqdm"
77
  ]
78
  },
79
  {
80
  "cell_type": "markdown",
81
  "metadata": {},
82
  "source": [
83
- "# 1. ⚙️ Configuration\n",
84
  "\n",
85
  "## Configuration Parameters\n",
86
  "\n",
@@ -116,25 +83,45 @@
116
  "cell_type": "code",
117
  "execution_count": null,
118
  "metadata": {},
119
- "outputs": [
120
- {
121
- "name": "stdout",
122
- "output_type": "stream",
123
- "text": [
124
- "Using device: cpu\n"
125
- ]
126
- }
127
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  "source": [
129
  "config = {\n",
130
  " # Model\n",
131
  " \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
132
  " \n",
133
- " # Data - Hugging Face Dataset Configuration\n",
134
- " \"dataset_name\": \"InstaDeepAI/bigwig_tracks\", # Hugging Face dataset name or path to script\n",
135
  " \"data_cache_dir\": \"./data\",\n",
136
  " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
137
  " \"bigwig_url_list\": [\n",
 
138
  " \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
139
  " \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
140
  " \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
@@ -165,8 +152,25 @@
165
  "\n",
166
  "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
167
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  "# Create bigwig_file_ids from filenames (without extension)\n",
169
  "config[\"bigwig_file_ids\"] = [\n",
 
 
170
  " \"ENCSR325NFE\",\n",
171
  " \"ENCSR962OTG\",\n",
172
  " \"ENCSR619DQO_P\",\n",
@@ -186,7 +190,67 @@
186
  "cell_type": "markdown",
187
  "metadata": {},
188
  "source": [
189
- "# 2. 🧠 Model and tokenizer setup\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  " \n",
191
  "In this section, we set up the model and tokenizer. \n",
192
  " \n",
@@ -196,12 +260,12 @@
196
  "This linear head is trained for regression on a set of genomic tracks, \n",
197
  "allowing the model to make predictions for each track at single nucleotide resolution.\n",
198
  " \n",
199
- "The following code wraps the HuggingFace model together with this regression head for the end-to-end task."
200
  ]
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 9,
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
@@ -269,19 +333,9 @@
269
  },
270
  {
271
  "cell_type": "code",
272
- "execution_count": 10,
273
- "metadata": {},
274
- "outputs": [
275
- {
276
- "name": "stdout",
277
- "output_type": "stream",
278
- "text": [
279
- "Model loaded: InstaDeepAI/NTv3_8M_pre\n",
280
- "Number of bigwig tracks: 4\n",
281
- "Model parameters: 7,694,015\n"
282
- ]
283
- }
284
- ],
285
  "source": [
286
  "# Load tokenizer\n",
287
  "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
@@ -297,16 +351,168 @@
297
  "\n",
298
  "print(f\"Model loaded: {config['model_name']}\")\n",
299
  "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
300
- "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  ]
302
  },
303
  {
304
  "cell_type": "markdown",
305
  "metadata": {},
306
  "source": [
307
- "# 3. 📥 Dataset setup\n",
308
  "\n",
309
- "Load the Hugging Face dataset and set up the data pipeline. The dataset automatically handles downloading FASTA and BigWig files, normalizing tracks, and sampling random genomic windows."
310
  ]
311
  },
312
  {
@@ -315,30 +521,161 @@
315
  "metadata": {},
316
  "outputs": [],
317
  "source": [
318
- "# Chromosomes split definition\n",
319
- "chrom_splits = {\n",
320
- " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
321
- " \"val\": ['chr22'],\n",
322
- " \"test\": ['chr21']\n",
323
- "}\n",
324
  "\n",
325
- "# Number of desired samples per split\n",
326
- "num_samples = {\n",
327
- " \"train\": config[\"num_steps_training\"] * config[\"batch_size\"],\n",
328
- " \"val\": config[\"num_validation_samples\"],\n",
329
- " \"test\": config[\"num_test_samples\"],\n",
330
- "}\n",
331
  "\n",
332
- "print(f\"Loading dataset from {config['dataset_name']}...\")\n",
333
- "dataset = load_dataset(\n",
334
- " config[\"dataset_name\"],\n",
335
- " data_files=chrom_splits,\n",
336
- " num_samples=num_samples,\n",
337
- " fasta_url=config[\"fasta_url\"],\n",
338
- " bigwig_urls=config[\"bigwig_url_list\"],\n",
339
- " sequence_length=config[\"sequence_length\"],\n",
340
- " data_dir=config[\"data_cache_dir\"],\n",
341
- ")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  ]
343
  },
344
  {
@@ -347,92 +684,84 @@
347
  "metadata": {},
348
  "outputs": [],
349
  "source": [
350
- "# Tokenization function\n",
351
- "def tokenize_examples(examples):\n",
352
- " \"\"\"Tokenize sequences and prepare targets.\"\"\"\n",
353
- " sequences = examples[\"sequence\"]\n",
354
- " \n",
355
- " # Tokenize sequences\n",
356
- " tokenized = tokenizer(\n",
357
- " sequences,\n",
358
- " max_length=config[\"sequence_length\"],\n",
359
- " padding=\"max_length\",\n",
360
- " truncation=True,\n",
361
- " return_tensors=None,\n",
362
- " )\n",
363
- " \n",
364
- " # Crop targets to center fraction if needed\n",
365
- " if config[\"keep_target_center_fraction\"] < 1.0:\n",
366
- " seq_len = examples[\"bigwig_targets\"].shape[0]\n",
367
- " target_offset = int(seq_len * (1 - config[\"keep_target_center_fraction\"]) // 2)\n",
368
- " target_length = seq_len - 2 * target_offset\n",
369
- " examples[\"bigwig_targets\"] = examples[\"bigwig_targets\"][target_offset:target_offset + target_length, :]\n",
370
- " \n",
371
- " return {\n",
372
- " \"tokens\": tokenized[\"input_ids\"],\n",
373
- " \"bigwig_targets\": examples[\"bigwig_targets\"],\n",
374
- " }\n",
375
  "\n",
376
- "# Apply tokenization\n",
377
- "print(\"Tokenizing sequences...\")\n",
378
- "dataset = dataset.map(\n",
379
- " tokenize_examples,\n",
380
- " batched=True,\n",
381
- " remove_columns=[\"sequence\"], # Remove original sequence after tokenization\n",
382
  ")\n",
383
  "\n",
384
- "# Format for PyTorch\n",
385
- "dataset = dataset.with_format(\"torch\")\n",
 
 
386
  "\n",
387
- "dataloaders = {}\n",
388
- "for split_name in chrom_splits.keys():\n",
389
- " dataloaders[split_name] = DataLoader(\n",
390
- " dataset[split_name],\n",
391
- " batch_size=config[\"batch_size\"],\n",
392
- " shuffle=(split_name == \"train\"),\n",
393
- " num_workers=config[\"num_workers\"],\n",
394
- " )\n",
395
  "\n",
396
- "# Extract DataLoaders\n",
397
- "train_loader = dataloaders[\"train\"]\n",
398
- "val_loader = dataloaders[\"val\"]\n",
399
- "test_loader = dataloaders[\"test\"]\n",
 
 
 
 
 
 
 
 
 
 
400
  "\n",
401
- "print(f\"\\nData pipeline created successfully!\")\n",
402
- "print(f\"Train batches: {len(train_loader)}\")\n",
403
- "print(f\"Val batches: {len(val_loader)}\")\n",
404
- "print(f\"Test batches: {len(test_loader)}\")"
 
 
 
 
 
 
405
  ]
406
  },
407
  {
408
  "cell_type": "markdown",
409
  "metadata": {},
410
  "source": [
411
- "# 4. ⚙️ Optimizer setup\n",
412
  "\n",
413
- "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function."
 
414
  ]
415
  },
416
  {
417
  "cell_type": "code",
418
- "execution_count": 15,
419
- "metadata": {},
420
- "outputs": [
421
- {
422
- "name": "stdout",
423
- "output_type": "stream",
424
- "text": [
425
- "Training configuration:\n",
426
- " Batch size: 32\n",
427
- " Total training steps: 19932\n",
428
- " Log metrics every: 40 steps\n",
429
- " Validate every: 400 steps\n",
430
- "\n",
431
- "Optimizer setup:\n",
432
- " Learning rate: 1e-05\n"
433
- ]
434
- }
435
- ],
436
  "source": [
437
  "# Training setup\n",
438
  "print(f\"Training configuration:\")\n",
@@ -456,14 +785,14 @@
456
  "cell_type": "markdown",
457
  "metadata": {},
458
  "source": [
459
- "# 5. 📊 Metrics setup\n",
460
  "\n",
461
  "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
462
  ]
463
  },
464
  {
465
  "cell_type": "code",
466
- "execution_count": 16,
467
  "metadata": {},
468
  "outputs": [],
469
  "source": [
@@ -536,7 +865,7 @@
536
  },
537
  {
538
  "cell_type": "code",
539
- "execution_count": 17,
540
  "metadata": {},
541
  "outputs": [],
542
  "source": [
@@ -549,14 +878,14 @@
549
  "cell_type": "markdown",
550
  "metadata": {},
551
  "source": [
552
- "# 6. 📉 Loss function\n",
553
  "\n",
554
  "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
555
  ]
556
  },
557
  {
558
  "cell_type": "code",
559
- "execution_count": 18,
560
  "metadata": {},
561
  "outputs": [],
562
  "source": [
@@ -631,14 +960,14 @@
631
  "cell_type": "markdown",
632
  "metadata": {},
633
  "source": [
634
- "# 7. 🏃 Training loop\n",
635
  "\n",
636
  "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
637
  ]
638
  },
639
  {
640
  "cell_type": "code",
641
- "execution_count": 19,
642
  "metadata": {},
643
  "outputs": [],
644
  "source": [
@@ -706,7 +1035,7 @@
706
  },
707
  {
708
  "cell_type": "code",
709
- "execution_count": 18,
710
  "metadata": {},
711
  "outputs": [],
712
  "source": [
@@ -863,7 +1192,7 @@
863
  "cell_type": "markdown",
864
  "metadata": {},
865
  "source": [
866
- "# 9. 🧪 Test evaluation\n",
867
  "\n",
868
  "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
869
  ]
@@ -941,4 +1270,4 @@
941
  },
942
  "nbformat": 4,
943
  "nbformat_minor": 2
944
- }
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# \ud83e\uddec Fine-Tuning a Model on BigWig Tracks Prediction\n",
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
+ "**\u26a1 Key Advantage**: This simplified pipeline achieves **close performance to more complex training approaches** while enabling **relatively fast fine-tuning in approximately one hour**. The setup is designed for rapid experimentation and iteration, making it ideal for adapting pre-trained models to your specific genomic tracks or experimental conditions without the overhead of complex distributed training infrastructure.\n",
12
  "\n",
13
+ "**\ud83d\udd27 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
 
 
14
  "\n",
15
  "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
 
30
  "\n",
31
  "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
  "\n",
33
+ "\ud83d\udcdd Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime \u2192 Change runtime type \u2192 GPU (T4 or better recommended).\n"
 
 
 
 
 
 
 
34
  ]
35
  },
36
  {
 
40
  "outputs": [],
41
  "source": [
42
  "# Install dependencies\n",
43
+ "!pip install pyfaidx pyBigWig torchmetrics transformers plotly"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ]
45
  },
46
  {
47
  "cell_type": "markdown",
48
  "metadata": {},
49
  "source": [
50
+ "# 1. \ud83d\udce6 Imports + Configuration\n",
51
  "\n",
52
  "## Configuration Parameters\n",
53
  "\n",
 
83
  "cell_type": "code",
84
  "execution_count": null,
85
  "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "# 0. Imports\n",
89
+ "import random\n",
90
+ "import functools\n",
91
+ "from typing import List, Dict, Callable\n",
92
+ "import os\n",
93
+ "import subprocess\n",
94
+ "\n",
95
+ "import torch\n",
96
+ "import torch.nn as nn\n",
97
+ "import torch.nn.functional as F\n",
98
+ "from torch.utils.data import Dataset, DataLoader\n",
99
+ "from torch.optim import AdamW\n",
100
+ "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
101
+ "import numpy as np\n",
102
+ "import pyBigWig\n",
103
+ "from pyfaidx import Fasta\n",
104
+ "from torchmetrics import PearsonCorrCoef\n",
105
+ "import plotly.graph_objects as go\n",
106
+ "from IPython.display import display\n",
107
+ "from tqdm import tqdm"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
  "source": [
116
  "config = {\n",
117
  " # Model\n",
118
  " \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
119
  " \n",
120
+ " # Data\n",
 
121
  " \"data_cache_dir\": \"./data\",\n",
122
  " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
123
  " \"bigwig_url_list\": [\n",
124
+ " # \"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\",\n",
125
  " \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
126
  " \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
127
  " \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
 
152
  "\n",
153
  "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
154
  "\n",
155
+ "# Extract filenames from URLs\n",
156
+ "def extract_filename_from_url(url: str) -> str:\n",
157
+ " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
158
+ " # Remove query parameters if present\n",
159
+ " url_clean = url.split('?')[0]\n",
160
+ " # Get the last part of the URL path\n",
161
+ " return url_clean.split('/')[-1]\n",
162
+ "\n",
163
+ "# Create paths for downloaded files\n",
164
+ "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
165
+ "bigwig_path_list = [\n",
166
+ " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
167
+ " for url in config[\"bigwig_url_list\"]\n",
168
+ "]\n",
169
+ "\n",
170
  "# Create bigwig_file_ids from filenames (without extension)\n",
171
  "config[\"bigwig_file_ids\"] = [\n",
172
+ " # os.path.splitext(extract_filename_from_url(url))[0]\n",
173
+ " # for url in config[\"bigwig_url_list\"]\n",
174
  " \"ENCSR325NFE\",\n",
175
  " \"ENCSR962OTG\",\n",
176
  " \"ENCSR619DQO_P\",\n",
 
190
  "cell_type": "markdown",
191
  "metadata": {},
192
  "source": [
193
+ "# 2. \ud83d\udce5 Genome & Tracks Data Download\n",
194
+ "\n",
195
+ "Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "metadata": {},
202
+ "outputs": [],
203
+ "source": [
204
+ "# Download fasta file\n",
205
+ "fasta_filename = extract_filename_from_url(config[\"fasta_url\"])\n",
206
+ "fasta_gz_path = os.path.join(config[\"data_cache_dir\"], fasta_filename)\n",
207
+ "\n",
208
+ "print(f\"Downloading {fasta_filename}...\")\n",
209
+ "subprocess.run([\"wget\", \"-c\", config[\"fasta_url\"], \"-O\", fasta_gz_path], check=True)\n",
210
+ "\n",
211
+ "print(f\"Extracting {fasta_filename}...\")\n",
212
+ "subprocess.run([\"gunzip\", \"-f\", fasta_gz_path], check=True)"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "# Download bigwig files\n",
222
+ "for bigwig_url in config[\"bigwig_url_list\"]:\n",
223
+ " filename = extract_filename_from_url(bigwig_url)\n",
224
+ " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
225
+ " print(f\"Downloading {filename}...\")\n",
226
+ " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "metadata": {},
232
+ "source": [
233
+ "## Data Splits Definition"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": [
242
+ "chrom_splits = {\n",
243
+ " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
244
+ " \"val\": ['chr22'],\n",
245
+ " \"test\": ['chr21']\n",
246
+ "}"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "metadata": {},
252
+ "source": [
253
+ "# 3. \ud83e\udde0 Model and tokenizer setup\n",
254
  " \n",
255
  "In this section, we set up the model and tokenizer. \n",
256
  " \n",
 
260
  "This linear head is trained for regression on a set of genomic tracks, \n",
261
  "allowing the model to make predictions for each track at single nucleotide resolution.\n",
262
  " \n",
263
+ "The following code wraps the HuggingFace model together with this regression head for the end-to-end task.\n"
264
  ]
265
  },
266
  {
267
  "cell_type": "code",
268
+ "execution_count": null,
269
  "metadata": {},
270
  "outputs": [],
271
  "source": [
 
333
  },
334
  {
335
  "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {},
338
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
339
  "source": [
340
  "# Load tokenizer\n",
341
  "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
 
351
  "\n",
352
  "print(f\"Model loaded: {config['model_name']}\")\n",
353
  "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
354
+ "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "# Scaling functions for targets\n",
364
+ "def compute_chromosome_stats(track_data: np.ndarray) -> dict:\n",
365
+ " \"\"\"\n",
366
+ " Compute minimal statistics needed for weighted mean computation.\n",
367
+ " \n",
368
+ " Args:\n",
369
+ " track_data: numpy array of track values for a chromosome\n",
370
+ " \n",
371
+ " Returns:\n",
372
+ " Dictionary with statistics: sum, mean, total_count\n",
373
+ " \"\"\"\n",
374
+ " track_data = track_data.astype(np.float32)\n",
375
+ " \n",
376
+ " # Compute statistics\n",
377
+ " sum_all = np.sum(track_data)\n",
378
+ " total_count = track_data.size\n",
379
+ " mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
380
+ " \n",
381
+ " return {\n",
382
+ " \"sum\": sum_all,\n",
383
+ " \"mean\": mean_all,\n",
384
+ " \"total_count\": total_count,\n",
385
+ " }\n",
386
+ "\n",
387
+ "\n",
388
+ "def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
389
+ " \"\"\"\n",
390
+ " Aggregate chromosome-level statistics into file-level statistics.\n",
391
+ " \n",
392
+ " Args:\n",
393
+ " chr_stats_list: List of dictionaries, each containing chromosome-level statistics\n",
394
+ " \n",
395
+ " Returns:\n",
396
+ " Dictionary with aggregated file-level statistics (only mean)\n",
397
+ " \"\"\"\n",
398
+ " # Convert to arrays for easier computation\n",
399
+ " total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
400
+ " means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
401
+ " sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
402
+ " \n",
403
+ " # Aggregate total count\n",
404
+ " total_count = np.sum(total_counts)\n",
405
+ " \n",
406
+ " # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
407
+ " mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
408
+ " \n",
409
+ " return {\n",
410
+ " \"total_count\": total_count,\n",
411
+ " \"sum\": np.sum(sums),\n",
412
+ " \"mean\": mean,\n",
413
+ " }\n",
414
+ "\n",
415
+ "\n",
416
+ "def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
417
+ " \"\"\"\n",
418
+ " Get track means for normalization.\n",
419
+ " Computes statistics per chromosome and aggregates using weighted averaging,\n",
420
+ " \n",
421
+ " Args:\n",
422
+ " bigwig_tracks_list: List of pyBigWig file objects\n",
423
+ " \n",
424
+ " Returns:\n",
425
+ " Array of track means, one per bigwig file\n",
426
+ " \"\"\"\n",
427
+ " track_means = []\n",
428
+ " \n",
429
+ " for bigwig_track in bigwig_tracks_list:\n",
430
+ " chrom_lengths = bigwig_track.chroms()\n",
431
+ " all_chr_stats = []\n",
432
+ " \n",
433
+ " # Compute statistics for each chromosome\n",
434
+ " for chrom_name, chrom_length in chrom_lengths.items():\n",
435
+ " try:\n",
436
+ " # Get chromosome data as numpy array\n",
437
+ " bw_array = np.array(\n",
438
+ " bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
439
+ " dtype=np.float32\n",
440
+ " )\n",
441
+ " # Replace NaN with 0\n",
442
+ " bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
443
+ " \n",
444
+ " # Compute chromosome-level statistics\n",
445
+ " chr_stats = compute_chromosome_stats(bw_array)\n",
446
+ " all_chr_stats.append(chr_stats)\n",
447
+ " except Exception as e:\n",
448
+ " # Skip chromosomes that fail to load\n",
449
+ " print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
450
+ " continue\n",
451
+ " \n",
452
+ " if not all_chr_stats:\n",
453
+ " raise ValueError(f\"No valid chromosomes found for bigwig track\")\n",
454
+ " \n",
455
+ " # Aggregate chromosome-level stats into file-level stats\n",
456
+ " file_stats = aggregate_file_statistics(all_chr_stats)\n",
457
+ " \n",
458
+ " # Use the weighted mean for normalization\n",
459
+ " track_means.append(file_stats[\"mean\"])\n",
460
+ " \n",
461
+ " return np.array(track_means, dtype=np.float32)\n",
462
+ "\n",
463
+ "\n",
464
+ "def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
465
+ " \"\"\"\n",
466
+ " Build a scaling function based on track means computed from bigwig files.\n",
467
+ " \n",
468
+ " Opens bigwig files, computes track statistics, and creates a transform function.\n",
469
+ " The statistics are computed once and reused for all calls to the returned transform function.\n",
470
+ " \n",
471
+ " Args:\n",
472
+ " bigwig_path_list: List of paths to bigwig files\n",
473
+ " \n",
474
+ " Returns:\n",
475
+ " Transform function that scales input tensors\n",
476
+ " \"\"\"\n",
477
+ " # Open bigwig files and compute track statistics\n",
478
+ " print(\"Computing track statistics (this may take a while)...\")\n",
479
+ " bw_list = [\n",
480
+ " pyBigWig.open(bigwig_path)\n",
481
+ " for bigwig_path in bigwig_path_list\n",
482
+ " ]\n",
483
+ " track_means = get_track_means(bw_list)\n",
484
+ " print(f\"Computed track means: {track_means}\")\n",
485
+ " print(f\"Track means shape: {track_means.shape}\")\n",
486
+ " \n",
487
+ " # Create tensor from computed means\n",
488
+ " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
489
+ " \n",
490
+ " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
491
+ " \"\"\"\n",
492
+ " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
493
+ " \"\"\"\n",
494
+ " # Move constants to correct device then normalize\n",
495
+ " means = track_means_tensor.to(x.device)\n",
496
+ " scaled = x / means\n",
497
+ "\n",
498
+ " # Smooth clipping: if > 10, apply formula\n",
499
+ " clipped = torch.where(\n",
500
+ " scaled > 10.0,\n",
501
+ " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
502
+ " scaled,\n",
503
+ " )\n",
504
+ " return clipped\n",
505
+ " \n",
506
+ " return transform_fn"
507
  ]
508
  },
509
  {
510
  "cell_type": "markdown",
511
  "metadata": {},
512
  "source": [
513
+ "# 4. \ud83d\udd04 Data loading\n",
514
  "\n",
515
+ "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
516
  ]
517
  },
518
  {
 
521
  "metadata": {},
522
  "outputs": [],
523
  "source": [
524
+ "# Process-local cache for BigWig file handles (one per worker process)\n",
525
+ "# This allows safe multi-worker DataLoader usage\n",
526
+ "_bigwig_cache = {} # Maps (process_id, file_path) -> pyBigWig handle\n",
 
 
 
527
  "\n",
 
 
 
 
 
 
528
  "\n",
529
+ "def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
530
+ " \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
531
+ " process_id = os.getpid()\n",
532
+ " cache_key = (process_id, bigwig_path)\n",
533
+ " \n",
534
+ " if cache_key not in _bigwig_cache:\n",
535
+ " _bigwig_cache[cache_key] = pyBigWig.open(bigwig_path)\n",
536
+ " \n",
537
+ " return _bigwig_cache[cache_key]\n",
538
+ "\n",
539
+ "\n",
540
+ "class GenomeBigWigDataset(Dataset):\n",
541
+ " \"\"\"\n",
542
+ " Random genomic windows from a reference genome + bigWig signal.\n",
543
+ "\n",
544
+ " Each sample:\n",
545
+ " - picks a chromosome/region (from `chroms` or `regions`),\n",
546
+ " - picks a random window of length `sequence_length`,\n",
547
+ " - returns (sequence, signal, chrom, start, end).\n",
548
+ "\n",
549
+ " This dataset is compatible with multi-worker DataLoaders. BigWig files\n",
550
+ " are opened lazily using a process-local cache, ensuring each worker process\n",
551
+ " has its own file handles and avoiding concurrent access issues.\n",
552
+ "\n",
553
+ " Args\n",
554
+ " ----\n",
555
+ " fasta_path : str\n",
556
+ " Path to the reference genome FASTA (e.g. hg38.fna).\n",
557
+ " bigwig_path_list : str\n",
558
+ " Path to the bigWig file (e.g. ENCFF884LDL.bigWig).\n",
559
+ " chroms : List[str]\n",
560
+ " Chromosome names as they appear in the bigWig (e.g. [\"chr1\", \"chr2\", ...]).\n",
561
+ " Used for backward compatibility or when regions=None.\n",
562
+ " sequence_length : int\n",
563
+ " Length of each random window (in bp).\n",
564
+ " num_samples : int\n",
565
+ " Number of samples the dataset will provide (len(dataset)).\n",
566
+ " tokenizer : AutoTokenizer\n",
567
+ " Tokenizer to use for tokenization.\n",
568
+ " transform_fn : Callable\n",
569
+ " Function to transform/scaling bigwig targets.\n",
570
+ " keep_target_center_fraction : float\n",
571
+ " Fraction of center sequence to keep for target prediction (crops edges to focus on center).\n",
572
+ " regions : List[tuple[str, int, int]] | None\n",
573
+ " Optional list of regions as (chromosome, start, end) tuples.\n",
574
+ " If provided, samples are drawn randomly from within these regions only.\n",
575
+ " This matches the JAX pipeline approach using BED file splits.\n",
576
+ " If None, samples from entire chromosomes in `chroms`.\n",
577
+ " \"\"\"\n",
578
+ "\n",
579
+ " def __init__(\n",
580
+ " self,\n",
581
+ " fasta_path: str,\n",
582
+ " bigwig_path_list: list[str],\n",
583
+ " chroms: List[str],\n",
584
+ " sequence_length: int,\n",
585
+ " num_samples: int,\n",
586
+ " tokenizer: AutoTokenizer,\n",
587
+ " transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
588
+ " keep_target_center_fraction: float = 1.0,\n",
589
+ " ):\n",
590
+ " super().__init__()\n",
591
+ "\n",
592
+ " self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
593
+ " # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
594
+ " self.bigwig_path_list = bigwig_path_list\n",
595
+ " self.sequence_length = sequence_length\n",
596
+ " self.num_samples = num_samples\n",
597
+ " self.tokenizer = tokenizer\n",
598
+ " self.transform_fn = transform_fn # Use pre-computed transform function\n",
599
+ " self.keep_target_center_fraction = keep_target_center_fraction\n",
600
+ " self.chroms = chroms\n",
601
+ "\n",
602
+ " # Get chromosome lengths from first BigWig file (lazy, cached per process)\n",
603
+ " # We need this for validation, so open temporarily\n",
604
+ " bw_handle = _get_bigwig_handle(bigwig_path_list[0])\n",
605
+ " bw_chrom_lengths = bw_handle.chroms() # dict: chrom -> length\n",
606
+ "\n",
607
+ " self.valid_chroms = []\n",
608
+ " self.chrom_lengths = {}\n",
609
+ "\n",
610
+ " for c in chroms:\n",
611
+ " if c not in bw_chrom_lengths or c not in self.fasta:\n",
612
+ " continue\n",
613
+ "\n",
614
+ " fa_len = len(self.fasta[c])\n",
615
+ " bw_len = bw_chrom_lengths[c]\n",
616
+ " L = min(fa_len, bw_len)\n",
617
+ "\n",
618
+ " if L > self.sequence_length:\n",
619
+ " self.valid_chroms.append(c)\n",
620
+ " self.chrom_lengths[c] = L\n",
621
+ "\n",
622
+ " if not self.valid_chroms:\n",
623
+ " raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n",
624
+ "\n",
625
+ " def __len__(self):\n",
626
+ " return self.num_samples\n",
627
+ "\n",
628
+ " def __getitem__(self, idx):\n",
629
+ "\n",
630
+ " # Sample from entire chromosomes\n",
631
+ " chrom = random.choice(self.valid_chroms)\n",
632
+ " chrom_len = self.chrom_lengths[chrom]\n",
633
+ " max_start = chrom_len - self.sequence_length\n",
634
+ " start = random.randint(0, max_start)\n",
635
+ " end = start + self.sequence_length\n",
636
+ "\n",
637
+ " # Sequence\n",
638
+ " seq = self.fasta[chrom][start:end] # string slice\n",
639
+ " # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
640
+ " tokenized = self.tokenizer(\n",
641
+ " seq,\n",
642
+ " padding=\"max_length\",\n",
643
+ " truncation=True,\n",
644
+ " max_length=self.sequence_length,\n",
645
+ " return_tensors=\"pt\",\n",
646
+ " )\n",
647
+ " tokens = tokenized[\"input_ids\"][0] # Shape: (max_length,)\n",
648
+ "\n",
649
+ " # Signal from bigWig tracks (numpy array) -> torch tensor\n",
650
+ " # Get BigWig handles lazily (cached per worker process)\n",
651
+ " bigwig_targets = np.array([\n",
652
+ " _get_bigwig_handle(bw_path).values(chrom, start, end, numpy=True)\n",
653
+ " for bw_path in self.bigwig_path_list\n",
654
+ " ]) # shape (num_tracks, seq_len)\n",
655
+ " # Transpose to (seq_len, num_tracks)\n",
656
+ " bigwig_targets = bigwig_targets.T\n",
657
+ " # pyBigWig returns NaN where no data; turn NaN into 0\n",
658
+ " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
659
+ " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
660
+ " \n",
661
+ " # Crop targets to center fraction\n",
662
+ " if self.keep_target_center_fraction < 1.0:\n",
663
+ " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n",
664
+ " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
665
+ " target_length = seq_len - 2 * target_offset\n",
666
+ " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
667
+ "\n",
668
+ " # Apply scaling to targets\n",
669
+ " bigwig_targets = self.transform_fn(bigwig_targets)\n",
670
+ "\n",
671
+ " sample = {\n",
672
+ " \"tokens\": tokens,\n",
673
+ " \"bigwig_targets\": bigwig_targets,\n",
674
+ " \"chrom\": chrom,\n",
675
+ " \"start\": start,\n",
676
+ " \"end\": end,\n",
677
+ " }\n",
678
+ " return sample"
679
  ]
680
  },
681
  {
 
684
  "metadata": {},
685
  "outputs": [],
686
  "source": [
687
+ "# Create scaling function\n",
688
+ "targets_transform_fn = create_targets_scaling_fn(bigwig_path_list)"
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": null,
694
+ "metadata": {},
695
+ "outputs": [],
696
+ "source": [
697
+ "# Create datasets & dataloaders\n",
698
+ "create_dataset_fn = functools.partial(\n",
699
+ " GenomeBigWigDataset,\n",
700
+ " fasta_path=fasta_path,\n",
701
+ " bigwig_path_list=bigwig_path_list,\n",
702
+ " sequence_length=config[\"sequence_length\"],\n",
703
+ " tokenizer=tokenizer,\n",
704
+ " transform_fn=targets_transform_fn,\n",
705
+ " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
706
+ ")\n",
 
 
 
 
 
707
  "\n",
708
+ "train_dataset = create_dataset_fn(\n",
709
+ " chroms=chrom_splits[\"train\"],\n",
710
+ " num_samples=config[\"num_steps_training\"] * config[\"batch_size\"],\n",
 
 
 
711
  ")\n",
712
  "\n",
713
+ "val_dataset = create_dataset_fn(\n",
714
+ " chroms=chrom_splits[\"val\"],\n",
715
+ " num_samples=config[\"num_validation_samples\"],\n",
716
+ ")\n",
717
  "\n",
718
+ "test_dataset = create_dataset_fn(\n",
719
+ " chroms=chrom_splits[\"test\"],\n",
720
+ " num_samples=config[\"num_test_samples\"],\n",
721
+ ")\n",
 
 
 
 
722
  "\n",
723
+ "# Create dataloaders\n",
724
+ "train_loader = DataLoader(\n",
725
+ " train_dataset,\n",
726
+ " batch_size=config[\"batch_size\"],\n",
727
+ " shuffle=True,\n",
728
+ " num_workers=config[\"num_workers\"],\n",
729
+ ")\n",
730
+ "\n",
731
+ "val_loader = DataLoader(\n",
732
+ " val_dataset,\n",
733
+ " batch_size=config[\"batch_size\"],\n",
734
+ " shuffle=False,\n",
735
+ " num_workers=config[\"num_workers\"],\n",
736
+ ")\n",
737
  "\n",
738
+ "test_loader = DataLoader(\n",
739
+ " test_dataset,\n",
740
+ " batch_size=config[\"batch_size\"],\n",
741
+ " shuffle=False,\n",
742
+ " num_workers=config[\"num_workers\"],\n",
743
+ ")\n",
744
+ "\n",
745
+ "print(f\"Train samples: {len(train_dataset)}\")\n",
746
+ "print(f\"Val samples: {len(val_dataset)}\")\n",
747
+ "print(f\"Test samples: {len(test_dataset)}\")"
748
  ]
749
  },
750
  {
751
  "cell_type": "markdown",
752
  "metadata": {},
753
  "source": [
754
+ "# 5. \u2699\ufe0f Optimizer setup\n",
755
  "\n",
756
+ "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function.\n",
757
+ "\n"
758
  ]
759
  },
760
  {
761
  "cell_type": "code",
762
+ "execution_count": null,
763
+ "metadata": {},
764
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
  "source": [
766
  "# Training setup\n",
767
  "print(f\"Training configuration:\")\n",
 
785
  "cell_type": "markdown",
786
  "metadata": {},
787
  "source": [
788
+ "# 6. \ud83d\udcca Metrics setup\n",
789
  "\n",
790
  "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
791
  ]
792
  },
793
  {
794
  "cell_type": "code",
795
+ "execution_count": null,
796
  "metadata": {},
797
  "outputs": [],
798
  "source": [
 
865
  },
866
  {
867
  "cell_type": "code",
868
+ "execution_count": null,
869
  "metadata": {},
870
  "outputs": [],
871
  "source": [
 
878
  "cell_type": "markdown",
879
  "metadata": {},
880
  "source": [
881
+ "# 7. \ud83d\udcc9 Loss functions\n",
882
  "\n",
883
  "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
884
  ]
885
  },
886
  {
887
  "cell_type": "code",
888
+ "execution_count": null,
889
  "metadata": {},
890
  "outputs": [],
891
  "source": [
 
960
  "cell_type": "markdown",
961
  "metadata": {},
962
  "source": [
963
+ "# 8. \ud83c\udfc3 Training loop\n",
964
  "\n",
965
  "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
966
  ]
967
  },
968
  {
969
  "cell_type": "code",
970
+ "execution_count": null,
971
  "metadata": {},
972
  "outputs": [],
973
  "source": [
 
1035
  },
1036
  {
1037
  "cell_type": "code",
1038
+ "execution_count": null,
1039
  "metadata": {},
1040
  "outputs": [],
1041
  "source": [
 
1192
  "cell_type": "markdown",
1193
  "metadata": {},
1194
  "source": [
1195
+ "# 9. \ud83e\uddea Test evaluation\n",
1196
  "\n",
1197
  "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
1198
  ]
 
1270
  },
1271
  "nbformat": 4,
1272
  "nbformat_minor": 2
1273
+ }