ybornachot committed on
Commit
35ab8fa
·
1 Parent(s): c9625a0

refactor: improved comments and markdown

Files changed (1)
  1. notebooks/03_fine_tuning.ipynb +72 -73
notebooks/03_fine_tuning.ipynb CHANGED
@@ -8,29 +8,32 @@
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
- "**⚡ Key Advantage**: This simplified pipeline achieves close performance to more complex training approaches while enabling fast fine-tuning. The training speed benefits from the efficient NTv3 model architecture and depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time). With NTv3 models, meaningful Pearson correlations can typically be reached within ~10minutes of training on a 32kb functional tracks prediction task.\n",
12
  "\n",
13
  "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
- "\n",
15
- "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
17
  "- **Constant learning rate**: Uses a fixed learning rate throughout training without learning rate scheduling\n",
18
  "- **No gradient accumulation**: Implements simple step-based training without gradient accumulation, making the training loop more straightforward\n",
19
  "\n",
20
- "The pipeline walks through the complete fine-tuning workflow:\n",
 
 
 
21
  "\n",
22
- "- Loading genomic sequences from FASTA files and their corresponding BigWig signal tracks\n",
 
23
  "- Setting up a PyTorch dataset with proper train/validation/test splits\n",
24
  "- Configuring the model architecture with a custom linear head\n",
25
  "- Implementing a training loop with appropriate loss functions and evaluation metrics\n",
 
26
  "\n",
27
  "This provides a clean interface for training and evaluation.\n",
28
  "\n",
29
  "The model architecture consists of a pre-trained NTv3 backbone that processes DNA sequences and a custom linear head that predicts BigWig signal values at single-nucleotide resolution. Predictions are center-cropped to focus on the central portion of the input sequence (configurable via `keep_target_center_fraction`), which helps reduce edge effects from sequence context windows. The training uses a Poisson-Multinomial loss function that captures both the scale and shape of the signal distributions, and evaluation is performed using Pearson correlation metrics on both scaled and raw predictions.\n",
30
  "\n",
31
- "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
  "\n",
33
- "📝 Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended).\n"
34
  ]
35
  },
36
  {
@@ -56,28 +59,29 @@
56
  "metadata": {},
57
  "outputs": [],
58
  "source": [
59
- "import random\n",
60
  "import functools\n",
61
- "from typing import List, Dict, Callable\n",
62
- "import os\n",
63
  "import fnmatch\n",
 
 
64
  "from pathlib import Path\n",
65
- "from huggingface_hub import HfApi, snapshot_download\n",
66
  "\n",
67
- "import torch\n",
68
- "import torch.nn as nn\n",
69
- "import torch.nn.functional as F\n",
70
- "from torch.utils.data import Dataset, DataLoader\n",
71
- "from torch.optim import AdamW\n",
72
- "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
73
- "import pandas as pd\n",
74
  "import matplotlib.pyplot as plt\n",
75
  "import numpy as np\n",
 
76
  "import pyBigWig\n",
77
  "from pyfaidx import Fasta\n",
 
 
 
 
 
78
  "from torchmetrics import PearsonCorrCoef\n",
79
- "import plotly.graph_objects as go\n",
80
- "from IPython.display import display\n",
81
  "from tqdm import tqdm"
82
  ]
83
  },
@@ -131,29 +135,29 @@
131
  " \n",
132
  " # Data\n",
133
  " \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
134
- " \"species\": \"arabidopsis\",\n",
135
  " \"data_cache_dir\": \"./data\",\n",
136
  " \"sequence_length\": 32_768,\n",
137
  " \"keep_target_center_fraction\": 0.375,\n",
138
  " \n",
139
  " # Training\n",
140
- " \"batch_size\": 32,\n",
141
- " \"num_steps_training\": 19932,\n",
142
- " \"log_every_n_steps\": 40,\n",
143
  " \"learning_rate\": 1e-5,\n",
144
  " \"weight_decay\": 0.01,\n",
145
  " \n",
146
  " # Validation\n",
147
- " \"validate_every_n_steps\": 400,\n",
148
- " \"num_validation_samples\": 1000,\n",
149
  "\n",
150
  " # Test\n",
151
- " \"num_test_samples\": 10000,\n",
152
  " \n",
153
  " # General\n",
154
- " \"seed\": 0,\n",
155
  " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
156
- " \"num_workers\": 16,\n",
157
  "}\n",
158
  "\n",
159
  "# Set random seed\n",
@@ -171,7 +175,15 @@
171
  "source": [
172
  "# 2. 📥 Genome & Tracks Data Download\n",
173
  "\n",
174
- "Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
 
 
 
 
 
 
 
 
175
  ]
176
  },
177
  {
@@ -204,13 +216,7 @@
204
  " # Find all files to download: species directory + metadata at root\n",
205
  " species_pattern = f\"{species}/**\"\n",
206
  " metadata_file = \"benchmark_metadata.tsv\"\n",
207
- " \n",
208
  " species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
209
- " if not species_files:\n",
210
- " raise ValueError(f\"No files found matching '{species_pattern}' in '{hf_repo_id}'\")\n",
211
- " \n",
212
- " if metadata_file not in files:\n",
213
- " raise ValueError(f\"No metadata file found at '{metadata_file}' in '{hf_repo_id}'\")\n",
214
  " \n",
215
  " # Download all needed files\n",
216
  " download_patterns = [species_pattern, metadata_file]\n",
@@ -255,13 +261,8 @@
255
  " \n",
256
  " # Metadata file\n",
257
  " metadata_path = local_dir / metadata_file\n",
258
- " if not metadata_path.is_file():\n",
259
- " raise ValueError(f\"Metadata file not found at '{metadata_path}'\")\n",
260
  " metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
261
  "\n",
262
- " if \"species\" not in metadata_df.columns:\n",
263
- " raise ValueError(\"benchmark_metadata.tsv has no 'species' column\")\n",
264
- "\n",
265
  " # Filter metadata according to species\n",
266
  " metadata_df = metadata_df[metadata_df[\"species\"] == species].reset_index(drop=True)\n",
267
  "\n",
@@ -414,6 +415,13 @@
414
  "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
415
  ]
416
  },
 
 
 
 
 
 
 
417
  {
418
  "cell_type": "code",
419
  "execution_count": null,
@@ -573,7 +581,7 @@
573
  "cell_type": "markdown",
574
  "metadata": {},
575
  "source": [
576
- "### Data preprocessing utilities"
577
  ]
578
  },
579
  {
@@ -621,6 +629,13 @@
621
  " return transform_fn"
622
  ]
623
  },
 
 
 
 
 
 
 
624
  {
625
  "cell_type": "code",
626
  "execution_count": null,
@@ -736,19 +751,12 @@
736
  " def __init__(self, track_names: List[str]):\n",
737
  " self.track_names = track_names\n",
738
  " self.num_tracks = len(track_names)\n",
739
- " # Metrics: comparing scaled targets with scaled predictions\n",
740
- " # Configure to use float64 for improved numerical stability\n",
741
- " self.pearson_metrics = [\n",
742
- " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
743
- " ]\n",
744
- " # Set dtype to float64 to prevent NaN warnings when variance is close to zero\n",
745
- " for metric in self.pearson_metrics:\n",
746
- " metric.set_dtype(torch.float64)\n",
747
  " self.losses = []\n",
748
  " \n",
749
  " def reset(self):\n",
750
- " for metric in self.pearson_metrics:\n",
751
- " metric.reset()\n",
752
  " self.losses = []\n",
753
  " \n",
754
  " def update(\n",
@@ -771,10 +779,7 @@
771
  " # Convert to float64 for improved numerical stability in Pearson correlation\n",
772
  " pred_flat = pred_flat.to(torch.float64)\n",
773
  " target_flat = target_flat.to(torch.float64)\n",
774
- " \n",
775
- " # Update metrics\n",
776
- " for i, metric in enumerate(self.pearson_metrics):\n",
777
- " metric.update(pred_flat[:, i], target_flat[:, i])\n",
778
  " \n",
779
  " self.losses.append(loss)\n",
780
  " \n",
@@ -782,13 +787,12 @@
782
  " \"\"\"Compute and return all metrics.\"\"\"\n",
783
  " metrics_dict = {}\n",
784
  " \n",
785
- " # Per-track Pearson correlations\n",
786
- " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics)):\n",
787
- " corr = metric.compute().item()\n",
788
- " metrics_dict[f\"{track_name}/pearson\"] = corr\n",
789
  " \n",
790
  " # Mean Pearson correlation\n",
791
- " correlations = [metric.compute().item() for metric in self.pearson_metrics]\n",
792
  " metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
793
  " \n",
794
  " # Mean loss\n",
@@ -887,7 +891,7 @@
887
  " # Combine losses\n",
888
  " loss = shape_loss + scale_loss / shape_loss_coefficient\n",
889
  "\n",
890
- " return loss, scale_loss, shape_loss\n"
891
  ]
892
  },
893
  {
@@ -918,7 +922,7 @@
918
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
919
  " \n",
920
  " # Compute loss\n",
921
- " loss, _, _ = poisson_multinomial_loss(\n",
922
  " logits=bigwig_logits,\n",
923
  " targets=bigwig_targets,\n",
924
  " )\n",
@@ -945,7 +949,7 @@
945
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
946
  " \n",
947
  " # Compute loss\n",
948
- " loss, _, _ = poisson_multinomial_loss(\n",
949
  " logits=bigwig_logits,\n",
950
  " targets=bigwig_targets,\n",
951
  " )\n",
@@ -960,13 +964,6 @@
960
  " return loss.item()"
961
  ]
962
  },
963
- {
964
- "cell_type": "markdown",
965
- "metadata": {},
966
- "source": [
967
- "### Interactive plotting is temporary for debug"
968
- ]
969
- },
970
  {
971
  "cell_type": "code",
972
  "execution_count": null,
@@ -1038,7 +1035,7 @@
1038
  " f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
1039
  " f\"Loss: {mean_loss:.4f} | \"\n",
1040
  " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
1041
- " f\"Pearson per track: {train_metrics_dict[f'{track_name}/pearson']:.4f for track_name in config['bigwig_file_ids']}\"\n",
1042
  " )\n",
1043
  " train_metrics.reset()\n",
1044
  " \n",
@@ -1063,7 +1060,7 @@
1063
  " \n",
1064
  " print(f\" Validation Loss: {val_metrics_dict['loss']:.4f}\")\n",
1065
  " print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
1066
- " for track_name in config[\"bigwig_file_ids\"]:\n",
1067
  " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
1068
  " \n",
1069
  " model.train() # Back to training mode\n",
@@ -1154,6 +1151,8 @@
1154
  "\n",
1155
  "Performances reached at ~1.5B tokens (~1500 steps in current 32kb sequences setup with batch_size=32)\n",
1156
  "\n",
 
 
1157
  "Mean Pearson: 0.5835\n",
1158
  "- ENCSR325NFE/pearson: 0.6081\n",
1159
  "- ENCSR962OTG/pearson: 0.7286\n",
 
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
+ "We provide access to the NTv3-benchmark data that we released on our Hugging Face dataset: `InstaDeepAI/NTv3_benchmark_dataset`. In this repository, you will find ready-to-use genome FASTA files, Bigwig tracks, metadata, but also the splits that were used for the benchmark.\n",
12
  "\n",
13
  "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
 
 
14
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
15
  "- **Constant learning rate**: Uses a fixed learning rate throughout training without learning rate scheduling\n",
16
  "- **No gradient accumulation**: Implements simple step-based training without gradient accumulation, making the training loop more straightforward\n",
17
  "\n",
18
+ "**⚡ Key Advantage**: This simplified pipeline achieves close performance to more complex training approaches while enabling fast fine-tuning: on a H100 GPU and using 16 workers for data loading, it takes ~10min to reach acceptable performances for a 32kb functional tracks prediction task. The training speed benefits from the efficient NTv3 model architecture, but of course depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time).\n",
19
+ "\n",
20
+ "**⚠️ Important Note on Hardware Requirements**: If the pipeline is designed to run on limited resources (e.g., Google Colab with a T4 GPU and 2CPUs), the timing mentioned was obtained on an **H100 GPU with 16 CPUs**. If you want to reach similar performance levels, you should be aware that you'll need **significant hardware resources** (high-end GPUs with substantial memory and multiple data loading workers). Training times will vary significantly based on your hardware configuration.\n",
21
+ "\n",
22
  "\n",
23
+ "The pipeline walks through the complete fine-tuning workflow:\n",
24
+ "- Loading genomic FASTA files sequences and their corresponding BigWig signal tracks from Hugging Face dataset\n",
25
  "- Setting up a PyTorch dataset with proper train/validation/test splits\n",
26
  "- Configuring the model architecture with a custom linear head\n",
27
  "- Implementing a training loop with appropriate loss functions and evaluation metrics\n",
28
+ "- Evaluation of the fine-tuned model on the test set\n",
29
  "\n",
30
  "This provides a clean interface for training and evaluation.\n",
31
  "\n",
32
  "The model architecture consists of a pre-trained NTv3 backbone that processes DNA sequences and a custom linear head that predicts BigWig signal values at single-nucleotide resolution. Predictions are center-cropped to focus on the central portion of the input sequence (configurable via `keep_target_center_fraction`), which helps reduce edge effects from sequence context windows. The training uses a Poisson-Multinomial loss function that captures both the scale and shape of the signal distributions, and evaluation is performed using Pearson correlation metrics on both scaled and raw predictions.\n",
33
  "\n",
34
+ "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to genomic tracks or improve performance on particular cell types or experimental conditions.\n",
35
  "\n",
36
+ "📝 Note for Google Colab users: This notebook is compatible with Colab and designed to work with limited resources! For faster training, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended). However, keep in mind that the timing benchmarks mentioned above were obtained on much more powerful hardware (H100 GPU), so your training times on Colab may be significantly longer."
37
  ]
38
  },
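The intro above mentions a Poisson-Multinomial loss that combines a "scale" term and a "shape" term. As a rough conceptual sketch only (the notebook defines its own `poisson_multinomial_loss` on logits further down; the function below is hypothetical, operates on non-negative predicted rates, and the `shape_loss_coefficient` default is arbitrary):

```python
import torch
import torch.nn.functional as F

def poisson_multinomial_sketch(pred, target, eps=1e-7, shape_loss_coefficient=5.0):
    # pred, target: non-negative signals of shape (batch, length, tracks)
    pred_total = pred.sum(dim=1)      # total predicted signal per track ("scale")
    target_total = target.sum(dim=1)  # total observed signal per track

    # Scale term: Poisson negative log-likelihood on the per-track totals
    scale_loss = F.poisson_nll_loss(pred_total, target_total, log_input=False)

    # Shape term: cross-entropy between positional profiles (signal normalized along the sequence)
    pred_profile = pred / (pred_total.unsqueeze(1) + eps)
    shape_loss = -(target * torch.log(pred_profile + eps)).sum(dim=1).mean()

    # Same combination as in the notebook's loss cell
    return shape_loss + scale_loss / shape_loss_coefficient
```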
39
  {
 
59
  "metadata": {},
60
  "outputs": [],
61
  "source": [
62
+ "# Standard library imports\n",
63
  "import functools\n",
 
 
64
  "import fnmatch\n",
65
+ "import os\n",
66
+ "import random\n",
67
  "from pathlib import Path\n",
68
+ "from typing import Callable, Dict, List\n",
69
  "\n",
70
+ "# Third-party imports\n",
71
+ "from huggingface_hub import HfApi, snapshot_download\n",
72
+ "from IPython.display import display\n",
 
 
 
 
73
  "import matplotlib.pyplot as plt\n",
74
  "import numpy as np\n",
75
+ "import pandas as pd\n",
76
  "import pyBigWig\n",
77
  "from pyfaidx import Fasta\n",
78
+ "import torch\n",
79
+ "import torch.nn as nn\n",
80
+ "import torch.nn.functional as F\n",
81
+ "from torch.optim import AdamW\n",
82
+ "from torch.utils.data import DataLoader, Dataset\n",
83
  "from torchmetrics import PearsonCorrCoef\n",
84
+ "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
 
85
  "from tqdm import tqdm"
86
  ]
87
  },
 
135
  " \n",
136
  " # Data\n",
137
  " \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
138
+ " \"species\": \"human\",\n",
139
  " \"data_cache_dir\": \"./data\",\n",
140
  " \"sequence_length\": 32_768,\n",
141
  " \"keep_target_center_fraction\": 0.375,\n",
142
  " \n",
143
  " # Training\n",
144
+ " \"batch_size\": 12,\n",
145
+ " \"num_steps_training\": 5315, # reproduce 10% of benchmark training length\n",
146
+ " \"log_every_n_steps\": 20,\n",
147
  " \"learning_rate\": 1e-5,\n",
148
  " \"weight_decay\": 0.01,\n",
149
  " \n",
150
  " # Validation\n",
151
+ " \"validate_every_n_steps\": 100,\n",
152
+ " \"num_validation_samples\": 1_000,\n",
153
  "\n",
154
  " # Test\n",
155
+ " \"num_test_samples\": 10_000,\n",
156
  " \n",
157
  " # General\n",
158
+ " \"seed\": 0, # for reproducibility\n",
159
  " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
160
+ " \"num_workers\": 2, # If resources allows it, consider increasing for faster data loading\n",
161
  "}\n",
162
  "\n",
163
  "# Set random seed\n",
 
175
  "source": [
176
  "# 2. 📥 Genome & Tracks Data Download\n",
177
  "\n",
178
+ "Download all needed data for fine-tuning from our Hugging Face NTv3 benchmark dataset for the selected species.\n",
179
+ "\n",
180
+ "These files contain:\n",
181
+ "- genomic sequences \n",
182
+ "- experimental signal data (e.g., ChIP-seq, ATAC-seq)\n",
183
+ "- ready-to-use splits (as bed files)\n",
184
+ "- metadata for data normalization\n",
185
+ "\n",
186
+ "If you want to fine-tune a model a species that is not available on our Hugging Face dataset or on other bigwig tracks, you should consider downloading the genome FASTA file and/or the BigWig files from URLs, using `wget`."
187
  ]
188
  },
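For the custom-data case mentioned above, here is a minimal sketch of fetching your own genome FASTA and BigWig files from their URLs (the URLs and filenames below are placeholders, not real resources):

```python
import urllib.request
from pathlib import Path

data_dir = Path("./data/custom")
data_dir.mkdir(parents=True, exist_ok=True)

# Placeholder URLs: replace with the actual locations of your genome and track files
files = {
    "my_species.fa": "https://example.org/genomes/my_species.fa",
    "my_experiment.bw": "https://example.org/tracks/my_experiment.bw",
}
for name, url in files.items():
    urllib.request.urlretrieve(url, str(data_dir / name))
```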
189
  {
 
216
  " # Find all files to download: species directory + metadata at root\n",
217
  " species_pattern = f\"{species}/**\"\n",
218
  " metadata_file = \"benchmark_metadata.tsv\"\n",
 
219
  " species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
 
 
 
 
 
220
  " \n",
221
  " # Download all needed files\n",
222
  " download_patterns = [species_pattern, metadata_file]\n",
 
261
  " \n",
262
  " # Metadata file\n",
263
  " metadata_path = local_dir / metadata_file\n",
 
 
264
  " metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
265
  "\n",
 
 
 
266
  " # Filter metadata according to species\n",
267
  " metadata_df = metadata_df[metadata_df[\"species\"] == species].reset_index(drop=True)\n",
268
  "\n",
 
415
  "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
416
  ]
417
  },
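As a rough illustration of the random-window sampling described above, here is a hypothetical, standalone sketch (not the notebook's actual dataset class, which additionally handles tokenization, target scaling, multi-track extraction, and split handling):

```python
import random

import numpy as np
import pyBigWig
from pyfaidx import Fasta
from torch.utils.data import Dataset


class RandomWindowDataset(Dataset):
    """Hypothetical sketch: sample random fixed-length windows from a genome
    and return the DNA sequence plus the matching BigWig signal values."""

    def __init__(self, fasta_path, bigwig_path, chroms, window=32_768, n_samples=1_000):
        self.fasta = Fasta(fasta_path)
        self.bw = pyBigWig.open(bigwig_path)
        self.chroms = chroms          # chromosomes assigned to this split
        self.window = window
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        chrom = random.choice(self.chroms)
        start = random.randint(0, len(self.fasta[chrom]) - self.window)
        seq = str(self.fasta[chrom][start:start + self.window])
        # BigWig returns NaN where there is no signal; replace with 0
        signal = np.nan_to_num(
            np.array(self.bw.values(chrom, start, start + self.window), dtype=np.float32)
        )
        return seq, signal
```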
418
+ {
419
+ "cell_type": "markdown",
420
+ "metadata": {},
421
+ "source": [
422
+ "### Dataset Setup"
423
+ ]
424
+ },
425
  {
426
  "cell_type": "code",
427
  "execution_count": null,
 
581
  "cell_type": "markdown",
582
  "metadata": {},
583
  "source": [
584
+ "### Scaling function Setup"
585
  ]
586
  },
587
  {
 
629
  " return transform_fn"
630
  ]
631
  },
632
+ {
633
+ "cell_type": "markdown",
634
+ "metadata": {},
635
+ "source": [
636
+ "### Instantiate datasets and dataloaders"
637
+ ]
638
+ },
639
  {
640
  "cell_type": "code",
641
  "execution_count": null,
 
751
  " def __init__(self, track_names: List[str]):\n",
752
  " self.track_names = track_names\n",
753
  " self.num_tracks = len(track_names)\n",
754
+ " self.pearson_metric = PearsonCorrCoef(num_outputs=self.num_tracks).to(device)\n",
755
+ " self.pearson_metric.set_dtype(torch.float64)\n",
 
 
 
 
 
 
756
  " self.losses = []\n",
757
  " \n",
758
  " def reset(self):\n",
759
+ " self.pearson_metric.reset()\n",
 
760
  " self.losses = []\n",
761
  " \n",
762
  " def update(\n",
 
779
  " # Convert to float64 for improved numerical stability in Pearson correlation\n",
780
  " pred_flat = pred_flat.to(torch.float64)\n",
781
  " target_flat = target_flat.to(torch.float64)\n",
782
+ " self.pearson_metric.update(pred_flat, target_flat)\n",
 
 
 
783
  " \n",
784
  " self.losses.append(loss)\n",
785
  " \n",
 
787
  " \"\"\"Compute and return all metrics.\"\"\"\n",
788
  " metrics_dict = {}\n",
789
  " \n",
790
+ " # Compute Pearson correlation per track\n",
791
+ " correlations = self.pearson_metric.compute().numpy()\n",
792
+ " for i, track_name in enumerate(self.track_names):\n",
793
+ " metrics_dict[f\"{track_name}/pearson\"] = correlations[i]\n",
794
  " \n",
795
  " # Mean Pearson correlation\n",
 
796
  " metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
797
  " \n",
798
  " # Mean loss\n",
 
891
  " # Combine losses\n",
892
  " loss = shape_loss + scale_loss / shape_loss_coefficient\n",
893
  "\n",
894
+ " return loss\n"
895
  ]
896
  },
897
  {
 
922
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
923
  " \n",
924
  " # Compute loss\n",
925
+ " loss = poisson_multinomial_loss(\n",
926
  " logits=bigwig_logits,\n",
927
  " targets=bigwig_targets,\n",
928
  " )\n",
 
949
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
950
  " \n",
951
  " # Compute loss\n",
952
+ " loss = poisson_multinomial_loss(\n",
953
  " logits=bigwig_logits,\n",
954
  " targets=bigwig_targets,\n",
955
  " )\n",
 
964
  " return loss.item()"
965
  ]
966
  },
 
 
 
 
 
 
 
967
  {
968
  "cell_type": "code",
969
  "execution_count": null,
 
1035
  " f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
1036
  " f\"Loss: {mean_loss:.4f} | \"\n",
1037
  " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
1038
+ " f\"Pearson per track: {\", \".join([f\"{track_name}/pearson: {train_metrics_dict[f'{track_name}/pearson']:.4f}\" for track_name in bigwig_ids])}\"\n",
1039
  " )\n",
1040
  " train_metrics.reset()\n",
1041
  " \n",
 
1060
  " \n",
1061
  " print(f\" Validation Loss: {val_metrics_dict['loss']:.4f}\")\n",
1062
  " print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
1063
+ " for track_name in bigwig_ids:\n",
1064
  " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
1065
  " \n",
1066
  " model.train() # Back to training mode\n",
 
1151
  "\n",
1152
  "Performances reached at ~1.5B tokens (~1500 steps in current 32kb sequences setup with batch_size=32)\n",
1153
  "\n",
1154
+ "**Hardware configuration**: These results were obtained on an **H100 GPU with 16 workers** for data loading in approximately **~10 minutes** of training.\n",
1155
+ "\n",
1156
  "Mean Pearson: 0.5835\n",
1157
  "- ENCSR325NFE/pearson: 0.6081\n",
1158
  "- ENCSR962OTG/pearson: 0.7286\n",