ybornachot commited on
Commit
65f032b
·
1 Parent(s): b6b1c80

feat: enhanced dataset with multiprocessing compatibility + added documentation

Browse files
Files changed (2) hide show
  1. .gitattributes +1 -0
  2. notebooks/03_fine_tuning.ipynb +3 -1425
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
notebooks/03_fine_tuning.ipynb CHANGED
@@ -1,1425 +1,3 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Simple PyTorch Tracks Fine-Tuning Pipeline\n",
8
- "\n",
9
- "This notebook implements a simple PyTorch-based deep learning pipeline for tracks prediction fine-tuning.\n",
10
- "\n",
11
- "## Overview\n",
12
- "- Loads a HuggingFace model (NTv3) as backbone\n",
13
- "- Adds a prediction head for bigwig tracks\n",
14
- "- Fine-tunes on tracks prediction with a simple training loop\n"
15
- ]
16
- },
17
- {
18
- "cell_type": "code",
19
- "execution_count": 1,
20
- "metadata": {},
21
- "outputs": [],
22
- "source": [
23
- "# Install useful dependencies\n",
24
- "# !pip install pyBigWig\n",
25
- "# !pip install pyfaidx\n",
26
- "# !pip install torchmetrics"
27
- ]
28
- },
29
- {
30
- "cell_type": "code",
31
- "execution_count": 1,
32
- "metadata": {},
33
- "outputs": [],
34
- "source": [
35
- "# 0. Imports\n",
36
- "import random\n",
37
- "import functools\n",
38
- "from typing import List, Dict, Callable\n",
39
- "import os\n",
40
- "import subprocess\n",
41
- "\n",
42
- "import torch\n",
43
- "import torch.nn as nn\n",
44
- "import torch.nn.functional as F\n",
45
- "from torch.utils.data import Dataset, DataLoader\n",
46
- "from torch.optim import AdamW\n",
47
- "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
48
- "import numpy as np\n",
49
- "import pyBigWig\n",
50
- "from pyfaidx import Fasta\n",
51
- "from torchmetrics import PearsonCorrCoef\n",
52
- "import plotly.graph_objects as go\n",
53
- "from plotly.subplots import make_subplots\n",
54
- "from IPython.display import display"
55
- ]
56
- },
57
- {
58
- "cell_type": "markdown",
59
- "metadata": {},
60
- "source": [
61
- "# 1. Configuration setup\n",
62
- "\n",
63
- "## Configuration Parameters\n",
64
- "\n",
65
- "### Model\n",
66
- "- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
67
- "\n",
68
- "### Data\n",
69
- "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
70
- "- **`fasta_url`**: URL to download reference genome FASTA file\n",
71
- "- **`bigwig_url_list`**: List of URLs for bigWig track files to download\n",
72
- "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
73
- "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
74
- "\n",
75
- "### Training\n",
76
- "- **`batch_size`**: Number of samples per batch\n",
77
- "- **`learning_rate`**: Constant learning rate for optimizer\n",
78
- "- **`weight_decay`**: L2 regularization coefficient for optimizer\n",
79
- "- **`num_steps_training`**: Total number of training steps\n",
80
- "- **`log_every_n_steps`**: Log training metrics every N steps\n",
81
- "- **`validate_every_n_steps`**: Run validation every N steps\n",
82
- "\n",
83
- "### Validation\n",
84
- "- **`num_validation_samples`**: Number of samples to use for validation set\n",
85
- "\n",
86
- "### General\n",
87
- "- **`seed`**: Random seed for reproducibility\n",
88
- "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
89
- "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)"
90
- ]
91
- },
92
- {
93
- "cell_type": "code",
94
- "execution_count": 15,
95
- "metadata": {},
96
- "outputs": [
97
- {
98
- "name": "stdout",
99
- "output_type": "stream",
100
- "text": [
101
- "Using device: cpu\n"
102
- ]
103
- }
104
- ],
105
- "source": [
106
- "config = {\n",
107
- " # Model\n",
108
- " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\",\n",
109
- " \n",
110
- " # Data\n",
111
- " \"data_cache_dir\": \"./data\",\n",
112
- " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
113
- " \"bigwig_url_list\": [\n",
114
- " \"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"\n",
115
- " ],\n",
116
- " \"sequence_length\": 1_024,\n",
117
- " \"keep_target_center_fraction\": 0.375,\n",
118
- " \n",
119
- " # Training\n",
120
- " \"batch_size\": 8,\n",
121
- " \"num_steps_training\": 1000,\n",
122
- " \"log_every_n_steps\": 10,\n",
123
- " \"learning_rate\": 1e-5,\n",
124
- " \"weight_decay\": 0.01,\n",
125
- " \n",
126
- " # Validation\n",
127
- " \"validate_every_n_steps\": 50,\n",
128
- " \"num_validation_samples\": 100,\n",
129
- " \n",
130
- " # General\n",
131
- " \"seed\": 42,\n",
132
- " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
133
- " \"num_workers\": 0,\n",
134
- "}\n",
135
- "\n",
136
- "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
137
- "\n",
138
- "# Extract filenames from URLs\n",
139
- "def extract_filename_from_url(url: str) -> str:\n",
140
- " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
141
- " # Remove query parameters if present\n",
142
- " url_clean = url.split('?')[0]\n",
143
- " # Get the last part of the URL path\n",
144
- " return url_clean.split('/')[-1]\n",
145
- "\n",
146
- "# Create paths for downloaded files\n",
147
- "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
148
- "bigwig_path_list = [\n",
149
- " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
150
- " for url in config[\"bigwig_url_list\"]\n",
151
- "]\n",
152
- "\n",
153
- "# Create bigwig_file_ids from filenames (without extension)\n",
154
- "config[\"bigwig_file_ids\"] = [\n",
155
- " os.path.splitext(extract_filename_from_url(url))[0]\n",
156
- " for url in config[\"bigwig_url_list\"]\n",
157
- "]\n",
158
- "\n",
159
- "# Set random seed\n",
160
- "torch.manual_seed(config[\"seed\"])\n",
161
- "np.random.seed(config[\"seed\"])\n",
162
- "\n",
163
- "# Set device\n",
164
- "device = torch.device(config[\"device\"])\n",
165
- "print(f\"Using device: {device}\")"
166
- ]
167
- },
168
- {
169
- "cell_type": "markdown",
170
- "metadata": {},
171
- "source": [
172
- "# 2. Data download"
173
- ]
174
- },
175
- {
176
- "cell_type": "code",
177
- "execution_count": 3,
178
- "metadata": {},
179
- "outputs": [
180
- {
181
- "name": "stdout",
182
- "output_type": "stream",
183
- "text": [
184
- "--2025-12-10 14:47:06-- https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\n",
185
- "Resolving hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)... 128.114.119.163\n",
186
- "Connecting to hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)|128.114.119.163|:443... connected.\n",
187
- "HTTP request sent, awaiting response... 200 OK\n",
188
- "Length: 983659424 (938M) [application/x-gzip]\n",
189
- "Saving to: './data/hg38.fa.gz'\n",
190
- "\n",
191
- "hg38.fa.gz 100%[===================>] 938.09M 10.4MB/s in 1m 43s \n",
192
- "\n",
193
- "2025-12-10 14:48:50 (9.09 MB/s) - './data/hg38.fa.gz' saved [983659424/983659424]\n",
194
- "\n"
195
- ]
196
- }
197
- ],
198
- "source": [
199
- "# Download fasta file\n",
200
- "!wget -c {config[\"fasta_url\"]} -P {config[\"data_cache_dir\"]}/ && gunzip -f {config[\"data_cache_dir\"]}/{config[\"fasta_url\"].split(os.path.sep)[-1]}"
201
- ]
202
- },
203
- {
204
- "cell_type": "code",
205
- "execution_count": 7,
206
- "metadata": {},
207
- "outputs": [
208
- {
209
- "name": "stdout",
210
- "output_type": "stream",
211
- "text": [
212
- "Downloading ENCFF884LDL.bigWig...\n"
213
- ]
214
- },
215
- {
216
- "name": "stderr",
217
- "output_type": "stream",
218
- "text": [
219
- "--2025-12-10 14:54:41-- https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\n",
220
- "Resolving www.encodeproject.org (www.encodeproject.org)... 34.211.244.144\n",
221
- "Connecting to www.encodeproject.org (www.encodeproject.org)|34.211.244.144|:443... connected.\n",
222
- "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
223
- "Location: https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482 [following]\n",
224
- "--2025-12-10 14:54:42-- https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482\n",
225
- "Resolving encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)... 52.92.248.169, 52.92.211.49, 3.5.80.18, ...\n",
226
- "Connecting to encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)|52.92.248.169|:443... connected.\n",
227
- "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
228
- "\n",
229
- " The file is already fully retrieved; nothing to do.\n",
230
- "\n"
231
- ]
232
- }
233
- ],
234
- "source": [
235
- "# Download bigwig files\n",
236
- "for bigwig_url in config[\"bigwig_url_list\"]:\n",
237
- " filename = extract_filename_from_url(bigwig_url)\n",
238
- " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
239
- " print(f\"Downloading {filename}...\")\n",
240
- " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)"
241
- ]
242
- },
243
- {
244
- "cell_type": "code",
245
- "execution_count": 3,
246
- "metadata": {},
247
- "outputs": [],
248
- "source": [
249
- "chrom_splits = {\n",
250
- " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
251
- " \"val\": ['chr22'],\n",
252
- " \"test\": ['chr21']\n",
253
- "}"
254
- ]
255
- },
256
- {
257
- "cell_type": "markdown",
258
- "metadata": {},
259
- "source": [
260
- "# 3. Model and tokenizer setup"
261
- ]
262
- },
263
- {
264
- "cell_type": "code",
265
- "execution_count": 4,
266
- "metadata": {},
267
- "outputs": [],
268
- "source": [
269
- "class LinearHead(nn.Module):\n",
270
- " \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n",
271
- " def __init__(self, embed_dim: int, num_labels: int):\n",
272
- " super().__init__()\n",
273
- " self.layer_norm = nn.LayerNorm(embed_dim)\n",
274
- " self.head = nn.Linear(embed_dim, num_labels)\n",
275
- " \n",
276
- " def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
277
- " x = self.layer_norm(x)\n",
278
- " x = self.head(x)\n",
279
- " x = F.softplus(x) # Ensure positive values\n",
280
- " return x\n",
281
- "\n",
282
- "\n",
283
- "class HFModelWithHead(nn.Module):\n",
284
- " \"\"\"Simple model wrapper: HF backbone + bigwig head.\"\"\"\n",
285
- " \n",
286
- " def __init__(\n",
287
- " self,\n",
288
- " model_name: str,\n",
289
- " bigwig_track_names: List[str],\n",
290
- " keep_target_center_fraction: float = 0.375,\n",
291
- " ):\n",
292
- " super().__init__()\n",
293
- " \n",
294
- " # Load config and model\n",
295
- " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
296
- " self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
297
- " model_name, \n",
298
- " trust_remote_code=True,\n",
299
- " config=self.config\n",
300
- " )\n",
301
- " \n",
302
- " self.keep_target_center_fraction = keep_target_center_fraction\n",
303
- "\n",
304
- " if hasattr(self.config, \"embed_dim\"):\n",
305
- " embed_dim = self.config.embed_dim\n",
306
- " else:\n",
307
- " raise ValueError(f\"Could not determine embed_dim for {model_name}\")\n",
308
- " \n",
309
- " # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n",
310
- " self.bigwig_head = LinearHead(embed_dim, len(bigwig_track_names))\n",
311
- " self.model_name = model_name\n",
312
- " \n",
313
- " def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
314
- " # Forward through backbone\n",
315
- " outputs = self.backbone(input_ids=tokens)\n",
316
- " embedding = outputs.hidden_states[-1] # Last hidden state\n",
317
- " \n",
318
- " # Crop to center fraction\n",
319
- " if self.keep_target_center_fraction < 1.0:\n",
320
- " seq_len = embedding.shape[1]\n",
321
- " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
322
- " target_length = seq_len - 2 * target_offset\n",
323
- " embedding = embedding[:, target_offset:target_offset + target_length, :]\n",
324
- " \n",
325
- " # Predict bigwig tracks\n",
326
- " bigwig_logits = self.bigwig_head(embedding)\n",
327
- " \n",
328
- " return {\"bigwig_tracks_logits\": bigwig_logits}"
329
- ]
330
- },
331
- {
332
- "cell_type": "code",
333
- "execution_count": 5,
334
- "metadata": {},
335
- "outputs": [
336
- {
337
- "name": "stdout",
338
- "output_type": "stream",
339
- "text": [
340
- "Model loaded: InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\n",
341
- "Number of bigwig tracks: 1\n",
342
- "Model parameters: 7,693,244\n"
343
- ]
344
- }
345
- ],
346
- "source": [
347
- "# Load tokenizer\n",
348
- "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
349
- "\n",
350
- "# Create model\n",
351
- "model = HFModelWithHead(\n",
352
- " model_name=config[\"model_name\"],\n",
353
- " bigwig_track_names=config[\"bigwig_file_ids\"],\n",
354
- " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
355
- ")\n",
356
- "model = model.to(device)\n",
357
- "model.train()\n",
358
- "\n",
359
- "print(f\"Model loaded: {config['model_name']}\")\n",
360
- "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
361
- "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
362
- ]
363
- },
364
- {
365
- "cell_type": "code",
366
- "execution_count": 6,
367
- "metadata": {},
368
- "outputs": [],
369
- "source": [
370
- "# Scaling functions for targets\n",
371
- "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
372
- " \"\"\"\n",
373
- " Get track means for normalization.\n",
374
- " For now, return dummy values. In real pipeline, this loads from metadata.\n",
375
- " \"\"\"\n",
376
- " # Dummy values - in real pipeline, this would load from actual metadata\n",
377
- " return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n",
378
- "\n",
379
- "\n",
380
- "def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
381
- " \"\"\"\n",
382
- " Build a scaling function based on track means.\n",
383
- " \"\"\"\n",
384
- " # Load track means\n",
385
- " track_means_np = get_track_means(bigwig_file_ids)\n",
386
- " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
387
- " \n",
388
- " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
389
- " \"\"\"\n",
390
- " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
391
- " \"\"\"\n",
392
- " # Move constants to correct device then normalize\n",
393
- " means = track_means.to(x.device)\n",
394
- " scaled = x / means\n",
395
- "\n",
396
- " # Smooth clipping: if > 10, apply formula\n",
397
- " clipped = torch.where(\n",
398
- " scaled > 10.0,\n",
399
- " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
400
- " scaled,\n",
401
- " )\n",
402
- " return clipped\n",
403
- " \n",
404
- " return transform_fn"
405
- ]
406
- },
407
- {
408
- "cell_type": "markdown",
409
- "metadata": {},
410
- "source": [
411
- "# 4. Data loading"
412
- ]
413
- },
414
- {
415
- "cell_type": "code",
416
- "execution_count": 7,
417
- "metadata": {},
418
- "outputs": [],
419
- "source": [
420
- "class GenomeBigWigDataset(Dataset):\n",
421
- " \"\"\"\n",
422
- " Random genomic windows from a reference genome + bigWig signal.\n",
423
- "\n",
424
- " Each sample:\n",
425
- " - picks a chromosome from `chroms`,\n",
426
- " - picks a random window of length `window_size`,\n",
427
- " - returns (sequence, signal, chrom, start, end).\n",
428
- "\n",
429
- " Args\n",
430
- " ----\n",
431
- " fasta_path : str\n",
432
- " Path to the reference genome FASTA (e.g. hg38.fna).\n",
433
- " bigwig_path : str\n",
434
- " Path to the bigWig file (e.g. ENCFF884LDL.bigWig).\n",
435
- " chroms : List[str]\n",
436
- " Chromosome names as they appear in the bigWig (e.g. [\"chr1\", \"chr2\", ...]).\n",
437
- " window_size : int\n",
438
- " Length of each random window (in bp).\n",
439
- " num_samples : int\n",
440
- " Number of samples the dataset will provide (len(dataset)).\n",
441
- " chrom_mapping : Optional[Dict[str, str]]\n",
442
- " Optional mapping from bigWig chrom name -> FASTA chrom name.\n",
443
- " If None, assumes the same names in both.\n",
444
- " Example for hg38 RefSeq FASTA:\n",
445
- " {\n",
446
- " \"chr1\": \"NC_000001.11\",\n",
447
- " \"chr2\": \"NC_000002.12\",\n",
448
- " ...\n",
449
- " }\n",
450
- " \"\"\"\n",
451
- "\n",
452
- " def __init__(\n",
453
- " self,\n",
454
- " fasta_path: str,\n",
455
- " bigwig_path_list: list[str],\n",
456
- " chroms: List[str],\n",
457
- " sequence_length: int,\n",
458
- " num_samples: int,\n",
459
- " tokenizer: AutoTokenizer,\n",
460
- " transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
461
- " keep_target_center_fraction: float = 1.0,\n",
462
- " num_tracks: int = 1,\n",
463
- " ):\n",
464
- " super().__init__()\n",
465
- "\n",
466
- " self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
467
- " self.bw_list = [\n",
468
- " pyBigWig.open(bigwig_path)\n",
469
- " for bigwig_path in bigwig_path_list\n",
470
- " ]\n",
471
- " self.sequence_length = sequence_length\n",
472
- " self.num_samples = num_samples\n",
473
- " self.tokenizer = tokenizer\n",
474
- " self.transform_fn = transform_fn\n",
475
- " self.keep_target_center_fraction = keep_target_center_fraction\n",
476
- " self.num_tracks = num_tracks\n",
477
- " self.chroms = chroms\n",
478
- "\n",
479
- " # Intersect lengths between FASTA and bigWig for safety\n",
480
- " bw_chrom_lengths = self.bw_list[0].chroms() # dict: chrom -> length\n",
481
- "\n",
482
- " self.valid_chroms = []\n",
483
- " self.chrom_lengths = {}\n",
484
- "\n",
485
- " for c in chroms:\n",
486
- " if c not in bw_chrom_lengths or c not in self.fasta:\n",
487
- " continue\n",
488
- "\n",
489
- " fa_len = len(self.fasta[c])\n",
490
- " bw_len = bw_chrom_lengths[c]\n",
491
- " L = min(fa_len, bw_len)\n",
492
- "\n",
493
- " if L > self.sequence_length:\n",
494
- " self.valid_chroms.append(c)\n",
495
- " self.chrom_lengths[c] = L\n",
496
- "\n",
497
- " if not self.valid_chroms:\n",
498
- " raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n",
499
- "\n",
500
- " def __len__(self):\n",
501
- " return self.num_samples\n",
502
- "\n",
503
- " def __getitem__(self, idx):\n",
504
- " # Ignore idx, sample randomly\n",
505
- " chrom = random.choice(self.valid_chroms)\n",
506
- " chrom_len = self.chrom_lengths[chrom]\n",
507
- "\n",
508
- " max_start = chrom_len - self.sequence_length\n",
509
- " start = random.randint(0, max_start)\n",
510
- " end = start + self.sequence_length\n",
511
- "\n",
512
- " # Sequence\n",
513
- " seq = self.fasta[chrom][start:end] # string slice\n",
514
- " tokens = self.tokenizer(\n",
515
- " seq,\n",
516
- " return_tensors=\"pt\", # Returns a dict of PyTorch tensors\n",
517
- " )[\"input_ids\"][0]\n",
518
- " # The 'input_ids' field contains the tokenized sequence.\n",
519
- " # For a single input string, its shape is typically (1, len(seq))\n",
520
- "\n",
521
- " # Signal from bigWig tracks (numpy array) -> torch tensor\n",
522
- " bigwig_targets = np.array([\n",
523
- " self.bw_list[i].values(chrom, start, end, numpy=True)\n",
524
- " for i in range(len(self.bw_list))\n",
525
- " ]) # shape (num_tracks, seq_len)\n",
526
- " # Transpose to (seq_len, num_tracks)\n",
527
- " bigwig_targets = bigwig_targets.T\n",
528
- " # pyBigWig returns NaN where no data; turn NaN into 0\n",
529
- " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
530
- " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
531
- " \n",
532
- " # Crop targets to center fraction\n",
533
- " if self.keep_target_center_fraction < 1.0:\n",
534
- " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n",
535
- " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
536
- " target_length = seq_len - 2 * target_offset\n",
537
- " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
538
- "\n",
539
- " # Apply scaling to targets\n",
540
- " bigwig_targets = self.transform_fn(bigwig_targets)\n",
541
- "\n",
542
- " sample = {\n",
543
- " \"tokens\": tokens,\n",
544
- " \"bigwig_targets\": bigwig_targets,\n",
545
- " \"chrom\": chrom,\n",
546
- " \"start\": start,\n",
547
- " \"end\": end,\n",
548
- " }\n",
549
- " return sample"
550
- ]
551
- },
552
- {
553
- "cell_type": "code",
554
- "execution_count": 16,
555
- "metadata": {},
556
- "outputs": [
557
- {
558
- "name": "stdout",
559
- "output_type": "stream",
560
- "text": [
561
- "Train samples: 100\n",
562
- "Val samples: 100\n",
563
- "Test samples: 100\n"
564
- ]
565
- }
566
- ],
567
- "source": [
568
- "# Create scaling function\n",
569
- "transform_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n",
570
- "\n",
571
- "create_dataset_fn = functools.partial(\n",
572
- " GenomeBigWigDataset,\n",
573
- " fasta_path=fasta_path,\n",
574
- " bigwig_path_list=bigwig_path_list,\n",
575
- " sequence_length=config[\"sequence_length\"],\n",
576
- " tokenizer=tokenizer,\n",
577
- " transform_fn=transform_fn,\n",
578
- " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
579
- " num_tracks=len(config[\"bigwig_file_ids\"]),\n",
580
- ")\n",
581
- "\n",
582
- "train_dataset = create_dataset_fn(\n",
583
- " chroms=chrom_splits[\"train\"],\n",
584
- " num_samples=100,\n",
585
- ")\n",
586
- "\n",
587
- "val_dataset = create_dataset_fn(\n",
588
- " chroms=chrom_splits[\"val\"],\n",
589
- " num_samples=config[\"num_validation_samples\"],\n",
590
- ")\n",
591
- "\n",
592
- "test_dataset = create_dataset_fn(\n",
593
- " chroms=chrom_splits[\"test\"],\n",
594
- " num_samples=config[\"num_validation_samples\"],\n",
595
- ")\n",
596
- "\n",
597
- "# Create dataloaders\n",
598
- "train_loader = DataLoader(\n",
599
- " train_dataset,\n",
600
- " batch_size=config[\"batch_size\"],\n",
601
- " shuffle=True,\n",
602
- " num_workers=config[\"num_workers\"],\n",
603
- ")\n",
604
- "\n",
605
- "val_loader = DataLoader(\n",
606
- " val_dataset,\n",
607
- " batch_size=config[\"batch_size\"],\n",
608
- " shuffle=False,\n",
609
- " num_workers=config[\"num_workers\"],\n",
610
- ")\n",
611
- "\n",
612
- "test_loader = DataLoader(\n",
613
- " test_dataset,\n",
614
- " batch_size=config[\"batch_size\"],\n",
615
- " shuffle=False,\n",
616
- " num_workers=config[\"num_workers\"],\n",
617
- ")\n",
618
- "\n",
619
- "print(f\"Train samples: {len(train_dataset)}\")\n",
620
- "print(f\"Val samples: {len(val_dataset)}\")\n",
621
- "print(f\"Test samples: {len(test_dataset)}\")"
622
- ]
623
- },
624
- {
625
- "cell_type": "markdown",
626
- "metadata": {},
627
- "source": [
628
- "# 5. Optimizer setup\n"
629
- ]
630
- },
631
- {
632
- "cell_type": "code",
633
- "execution_count": 17,
634
- "metadata": {},
635
- "outputs": [
636
- {
637
- "name": "stdout",
638
- "output_type": "stream",
639
- "text": [
640
- "Training configuration:\n",
641
- " Batch size: 8\n",
642
- " Total training steps: 1000\n",
643
- " Log metrics every: 10 steps\n",
644
- " Validate every: 50 steps\n",
645
- "\n",
646
- "Optimizer setup:\n",
647
- " Learning rate: 1e-05\n"
648
- ]
649
- }
650
- ],
651
- "source": [
652
- "# Training setup\n",
653
- "print(f\"Training configuration:\")\n",
654
- "print(f\" Batch size: {config[\"batch_size\"]}\")\n",
655
- "print(f\" Total training steps: {config[\"num_steps_training\"]}\")\n",
656
- "print(f\" Log metrics every: {config[\"log_every_n_steps\"]} steps\")\n",
657
- "print(f\" Validate every: {config[\"validate_every_n_steps\"]} steps\")\n",
658
- "\n",
659
- "# Setup optimizer\n",
660
- "optimizer = AdamW(\n",
661
- " model.parameters(),\n",
662
- " lr=config[\"learning_rate\"],\n",
663
- " weight_decay=config[\"weight_decay\"],\n",
664
- ")\n",
665
- "\n",
666
- "print(f\"\\nOptimizer setup:\")\n",
667
- "print(f\" Learning rate: {config['learning_rate']}\")"
668
- ]
669
- },
670
- {
671
- "cell_type": "markdown",
672
- "metadata": {},
673
- "source": [
674
- "# 6. Metrics setup (using TorchMetrics)"
675
- ]
676
- },
677
- {
678
- "cell_type": "code",
679
- "execution_count": 18,
680
- "metadata": {},
681
- "outputs": [],
682
- "source": [
683
- "class TracksMetrics:\n",
684
- " \"\"\"Simple metrics tracker for tracks prediction.\"\"\"\n",
685
- " \n",
686
- " def __init__(self, track_names: List[str]):\n",
687
- " self.track_names = track_names\n",
688
- " self.num_tracks = len(track_names)\n",
689
- " # Metrics: comparing scaled targets with scaled predictions\n",
690
- " self.pearson_metrics = [\n",
691
- " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
692
- " ]\n",
693
- " self.losses = []\n",
694
- " \n",
695
- " def reset(self):\n",
696
- " for metric in self.pearson_metrics:\n",
697
- " metric.reset()\n",
698
- " self.losses = []\n",
699
- " \n",
700
- " def update(\n",
701
- " self, \n",
702
- " predictions: torch.Tensor, \n",
703
- " targets: torch.Tensor,\n",
704
- " loss: float\n",
705
- " ):\n",
706
- " \"\"\"\n",
707
- " Update metrics.\n",
708
- " Args:\n",
709
- " predictions: (batch, seq_len, num_tracks)\n",
710
- " targets: (batch, seq_len, num_tracks)\n",
711
- " loss: scalar loss value\n",
712
- " \"\"\"\n",
713
- " # Flatten batch and sequence dimensions\n",
714
- " pred_flat = predictions.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
715
- " target_flat = targets.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
716
- " \n",
717
- " # Update metrics\n",
718
- " for i, metric in enumerate(self.pearson_metrics):\n",
719
- " metric.update(pred_flat[:, i], target_flat[:, i])\n",
720
- " \n",
721
- " self.losses.append(loss)\n",
722
- " \n",
723
- " def compute(self) -> Dict[str, float]:\n",
724
- " \"\"\"Compute and return all metrics.\"\"\"\n",
725
- " metrics_dict = {}\n",
726
- " \n",
727
- " # Per-track Pearson correlations\n",
728
- " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics)):\n",
729
- " corr = metric.compute().item()\n",
730
- " metrics_dict[f\"{track_name}/pearson\"] = corr\n",
731
- " \n",
732
- " # Mean Pearson correlation\n",
733
- " correlations = [metric.compute().item() for metric in self.pearson_metrics]\n",
734
- " metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
735
- " \n",
736
- " # Mean loss\n",
737
- " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
738
- " \n",
739
- " return metrics_dict"
740
- ]
741
- },
742
- {
743
- "cell_type": "code",
744
- "execution_count": 19,
745
- "metadata": {},
746
- "outputs": [],
747
- "source": [
748
- "train_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n",
749
- "val_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n",
750
- "test_metrics = TracksMetrics(config[\"bigwig_file_ids\"])"
751
- ]
752
- },
753
- {
754
- "cell_type": "markdown",
755
- "metadata": {},
756
- "source": [
757
- "# 7. Loss functions"
758
- ]
759
- },
760
- {
761
- "cell_type": "code",
762
- "execution_count": 20,
763
- "metadata": {},
764
- "outputs": [],
765
- "source": [
766
- "def poisson_loss(ytrue: torch.Tensor, ypred: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:\n",
767
- " \"\"\"Poisson loss per element: ypred - ytrue * log(ypred).\"\"\"\n",
768
- " return ypred - ytrue * torch.log(ypred + epsilon)\n",
769
- "\n",
770
- "\n",
771
- "def safe_for_grad_log_torch(x: torch.Tensor) -> torch.Tensor:\n",
772
- " \"\"\"Guarantees that the log is defined for all x > 0 in a differentiable way.\"\"\"\n",
773
- " return torch.log(torch.where(x > 0.0, x, torch.ones_like(x)))\n",
774
- "\n",
775
- "\n",
776
- "def poisson_multinomial_loss(\n",
777
- " logits: torch.Tensor,\n",
778
- " targets: torch.Tensor,\n",
779
- " shape_loss_coefficient: float = 5.0,\n",
780
- " epsilon: float = 1e-7,\n",
781
- ") -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n",
782
- " \"\"\"\n",
783
- " Regression loss for bigwig tracks (MSE, Poisson, or Poisson-Multinomial).\n",
784
- " \"\"\"\n",
785
- "\n",
786
- " # Scale loss\n",
787
- " sum_pred = logits.sum(dim=1) # (batch, num_tracks)\n",
788
- " sum_true = targets.sum(dim=1) # (batch, num_tracks)\n",
789
- " scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon)\n",
790
- " scale_loss = scale_loss.mean()\n",
791
- " \n",
792
- " # Shape loss\n",
793
- " denom = logits.sum(dim=1, keepdim=True) + epsilon\n",
794
- " p_pred = logits / denom\n",
795
- " pl_pred = safe_for_grad_log_torch(p_pred)\n",
796
- " shape_loss = -(targets * pl_pred).mean()\n",
797
- " \n",
798
- " # Combine\n",
799
- " loss = shape_loss + scale_loss / shape_loss_coefficient\n",
800
- "\n",
801
- " return loss, scale_loss, shape_loss\n"
802
- ]
803
- },
804
- {
805
- "cell_type": "markdown",
806
- "metadata": {},
807
- "source": [
808
- "# 8. Training loop"
809
- ]
810
- },
811
- {
812
- "cell_type": "code",
813
- "execution_count": 21,
814
- "metadata": {},
815
- "outputs": [],
816
- "source": [
817
- "def train_step(\n",
818
- " model: nn.Module,\n",
819
- " batch: Dict[str, torch.Tensor],\n",
820
- ") -> float:\n",
821
- " \"\"\"Single training step.\"\"\"\n",
822
- " tokens = batch[\"tokens\"].to(device)\n",
823
- " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
824
- " \n",
825
- " # Forward pass\n",
826
- " outputs = model(tokens=tokens)\n",
827
- " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
828
- " \n",
829
- " # Compute loss\n",
830
- " loss, _, _ = poisson_multinomial_loss(\n",
831
- " logits=bigwig_logits,\n",
832
- " targets=bigwig_targets,\n",
833
- " )\n",
834
- " \n",
835
- " # Backward pass\n",
836
- " loss.backward()\n",
837
- " return loss.item()\n",
838
- "\n",
839
- "\n",
840
- "def validation_step(\n",
841
- " model: nn.Module,\n",
842
- " batch: Dict[str, torch.Tensor],\n",
843
- " metrics: TracksMetrics,\n",
844
- ") -> float:\n",
845
- " \"\"\"Single validation step.\"\"\"\n",
846
- " model.eval()\n",
847
- " \n",
848
- " tokens = batch[\"tokens\"].to(device)\n",
849
- " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
850
- " \n",
851
- " with torch.no_grad():\n",
852
- " # Forward pass\n",
853
- " outputs = model(tokens=tokens)\n",
854
- " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
855
- " \n",
856
- " # Compute loss\n",
857
- " loss, _, _ = poisson_multinomial_loss(\n",
858
- " logits=bigwig_logits,\n",
859
- " targets=bigwig_targets,\n",
860
- " )\n",
861
- " \n",
862
- " # Update metrics\n",
863
- " metrics.update(\n",
864
- " predictions=bigwig_logits,\n",
865
- " targets=bigwig_targets,\n",
866
- " loss=loss.item()\n",
867
- " )\n",
868
- " \n",
869
- " return loss.item()"
870
- ]
871
- },
872
- {
873
- "cell_type": "markdown",
874
- "metadata": {},
875
- "source": [
876
- "### Interactive plotting is temporary for debug"
877
- ]
878
- },
879
- {
880
- "cell_type": "code",
881
- "execution_count": 22,
882
- "metadata": {},
883
- "outputs": [
884
- {
885
- "name": "stdout",
886
- "output_type": "stream",
887
- "text": [
888
- "Starting training...\n",
889
- "Training for 1000 steps\n",
890
- "\n"
891
- ]
892
- },
893
- {
894
- "data": {
895
- "application/vnd.jupyter.widget-view+json": {
896
- "model_id": "5935c992adb7428bac8de1aa6873dd7e",
897
- "version_major": 2,
898
- "version_minor": 0
899
- },
900
- "text/plain": [
901
- "FigureWidget({\n",
902
- " 'data': [{'line': {'color': 'blue'},\n",
903
- " 'mode': 'lines+markers',\n",
904
- " 'name': 'Train Loss',\n",
905
- " 'type': 'scatter',\n",
906
- " 'uid': '5424e4af-13b6-48c8-a367-8aa145c3a9db',\n",
907
- " 'x': [],\n",
908
- " 'xaxis': 'x',\n",
909
- " 'y': [],\n",
910
- " 'yaxis': 'y'},\n",
911
- " {'line': {'color': 'red'},\n",
912
- " 'mode': 'lines+markers',\n",
913
- " 'name': 'Val Loss',\n",
914
- " 'type': 'scatter',\n",
915
- " 'uid': 'fe995660-5f01-4c12-9d7d-9ed19ddee785',\n",
916
- " 'x': [],\n",
917
- " 'xaxis': 'x',\n",
918
- " 'y': [],\n",
919
- " 'yaxis': 'y'},\n",
920
- " {'line': {'color': 'green'},\n",
921
- " 'mode': 'lines+markers',\n",
922
- " 'name': 'Train Pearson',\n",
923
- " 'type': 'scatter',\n",
924
- " 'uid': '8453b45b-4613-41bc-a46b-ac59ba9e6f97',\n",
925
- " 'x': [],\n",
926
- " 'xaxis': 'x2',\n",
927
- " 'y': [],\n",
928
- " 'yaxis': 'y2'},\n",
929
- " {'line': {'color': 'orange'},\n",
930
- " 'mode': 'lines+markers',\n",
931
- " 'name': 'Val Pearson',\n",
932
- " 'type': 'scatter',\n",
933
- " 'uid': '0887ea97-abf9-4fcf-8ea8-c638dc153a4d',\n",
934
- " 'x': [],\n",
935
- " 'xaxis': 'x2',\n",
936
- " 'y': [],\n",
937
- " 'yaxis': 'y2'}],\n",
938
- " 'layout': {'annotations': [{'font': {'size': 16},\n",
939
- " 'showarrow': False,\n",
940
- " 'text': 'Loss',\n",
941
- " 'x': 0.2125,\n",
942
- " 'xanchor': 'center',\n",
943
- " 'xref': 'paper',\n",
944
- " 'y': 1.0,\n",
945
- " 'yanchor': 'bottom',\n",
946
- " 'yref': 'paper'},\n",
947
- " {'font': {'size': 16},\n",
948
- " 'showarrow': False,\n",
949
- " 'text': 'Mean Pearson Correlation',\n",
950
- " 'x': 0.7875,\n",
951
- " 'xanchor': 'center',\n",
952
- " 'xref': 'paper',\n",
953
- " 'y': 1.0,\n",
954
- " 'yanchor': 'bottom',\n",
955
- " 'yref': 'paper'}],\n",
956
- " 'height': 800,\n",
957
- " 'showlegend': True,\n",
958
- " 'template': '...',\n",
959
- " 'title': {'text': 'Training'},\n",
960
- " 'width': 1600,\n",
961
- " 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.425], 'title': {'text': 'Step'}},\n",
962
- " 'xaxis2': {'anchor': 'y2', 'domain': [0.575, 1.0], 'title': {'text': 'Step'}},\n",
963
- " 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Loss'}},\n",
964
- " 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'title': {'text': 'Pearson Correlation'}}}\n",
965
- "})"
966
- ]
967
- },
968
- "metadata": {},
969
- "output_type": "display_data"
970
- },
971
- {
972
- "name": "stderr",
973
- "output_type": "stream",
974
- "text": [
975
- "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning:\n",
976
- "\n",
977
- "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
978
- "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
979
- "\n"
980
- ]
981
- },
982
- {
983
- "name": "stdout",
984
- "output_type": "stream",
985
- "text": [
986
- "Step 10/1000 | Loss: 0.2374 | Mean Pearson: 0.0382 | LR: 1.00e-05\n",
987
- "Step 20/1000 | Loss: 2.2259 | Mean Pearson: -0.0884 | LR: 1.00e-05\n",
988
- "Step 30/1000 | Loss: 20.0122 | Mean Pearson: 0.1379 | LR: 1.00e-05\n",
989
- "Step 40/1000 | Loss: 9.6938 | Mean Pearson: -0.1497 | LR: 1.00e-05\n",
990
- "Step 50/1000 | Loss: -1.8435 | Mean Pearson: -0.1875 | LR: 1.00e-05\n",
991
- "\n",
992
- "Running validation at step 50...\n",
993
- " Validation Loss: 11.5599\n",
994
- " Validation Mean Pearson: -0.1576\n",
995
- " ENCFF884LDL/pearson: -0.1576\n",
996
- "Step 60/1000 | Loss: 1.4427 | Mean Pearson: 0.2841 | LR: 1.00e-05\n",
997
- "Step 70/1000 | Loss: -3.4037 | Mean Pearson: -0.1362 | LR: 1.00e-05\n",
998
- "Step 80/1000 | Loss: 9.0958 | Mean Pearson: -0.1319 | LR: 1.00e-05\n",
999
- "Step 90/1000 | Loss: -7.8433 | Mean Pearson: -0.0576 | LR: 1.00e-05\n",
1000
- "Step 100/1000 | Loss: 7.3503 | Mean Pearson: -0.2150 | LR: 1.00e-05\n",
1001
- "\n",
1002
- "Running validation at step 100...\n",
1003
- " Validation Loss: 22.3383\n",
1004
- " Validation Mean Pearson: -0.2867\n",
1005
- " ENCFF884LDL/pearson: -0.2867\n",
1006
- "Step 110/1000 | Loss: -8.1600 | Mean Pearson: -0.1616 | LR: 1.00e-05\n",
1007
- "Step 120/1000 | Loss: -0.8743 | Mean Pearson: -0.1318 | LR: 1.00e-05\n",
1008
- "Step 130/1000 | Loss: -2.9825 | Mean Pearson: -0.0480 | LR: 1.00e-05\n",
1009
- "Step 140/1000 | Loss: -2.4524 | Mean Pearson: -0.0879 | LR: 1.00e-05\n",
1010
- "Step 150/1000 | Loss: 3.8818 | Mean Pearson: -0.0907 | LR: 1.00e-05\n",
1011
- "\n",
1012
- "Running validation at step 150...\n",
1013
- " Validation Loss: 19.6866\n",
1014
- " Validation Mean Pearson: -0.2207\n",
1015
- " ENCFF884LDL/pearson: -0.2207\n",
1016
- "Step 160/1000 | Loss: -1.0933 | Mean Pearson: -0.1243 | LR: 1.00e-05\n",
1017
- "Step 170/1000 | Loss: -2.2577 | Mean Pearson: -0.0212 | LR: 1.00e-05\n",
1018
- "Step 180/1000 | Loss: 0.0738 | Mean Pearson: 0.5643 | LR: 1.00e-05\n",
1019
- "Step 190/1000 | Loss: -0.1097 | Mean Pearson: 0.0309 | LR: 1.00e-05\n",
1020
- "Step 200/1000 | Loss: -8.7972 | Mean Pearson: 0.4804 | LR: 1.00e-05\n",
1021
- "\n",
1022
- "Running validation at step 200...\n",
1023
- " Validation Loss: -8.8160\n",
1024
- " Validation Mean Pearson: 0.0912\n",
1025
- " ENCFF884LDL/pearson: 0.0912\n",
1026
- "Step 210/1000 | Loss: -2.5429 | Mean Pearson: 0.3908 | LR: 1.00e-05\n",
1027
- "Step 220/1000 | Loss: -6.8421 | Mean Pearson: 0.4080 | LR: 1.00e-05\n",
1028
- "Step 230/1000 | Loss: -4.4312 | Mean Pearson: -0.0400 | LR: 1.00e-05\n",
1029
- "Step 240/1000 | Loss: -11.4732 | Mean Pearson: 0.6653 | LR: 1.00e-05\n",
1030
- "Step 250/1000 | Loss: -9.2648 | Mean Pearson: 0.0539 | LR: 1.00e-05\n",
1031
- "\n",
1032
- "Running validation at step 250...\n",
1033
- " Validation Loss: -6.8987\n",
1034
- " Validation Mean Pearson: 0.0654\n",
1035
- " ENCFF884LDL/pearson: 0.0654\n",
1036
- "Step 260/1000 | Loss: -0.6699 | Mean Pearson: 0.0913 | LR: 1.00e-05\n",
1037
- "Step 270/1000 | Loss: -8.6625 | Mean Pearson: 0.3179 | LR: 1.00e-05\n",
1038
- "Step 280/1000 | Loss: -11.7691 | Mean Pearson: 0.0004 | LR: 1.00e-05\n",
1039
- "Step 290/1000 | Loss: -14.1622 | Mean Pearson: 0.0492 | LR: 1.00e-05\n",
1040
- "Step 300/1000 | Loss: 0.9208 | Mean Pearson: 0.0607 | LR: 1.00e-05\n",
1041
- "\n",
1042
- "Running validation at step 300...\n",
1043
- " Validation Loss: -5.0427\n",
1044
- " Validation Mean Pearson: 0.3464\n",
1045
- " ENCFF884LDL/pearson: 0.3464\n",
1046
- "Step 310/1000 | Loss: -1.2881 | Mean Pearson: 0.1696 | LR: 1.00e-05\n",
1047
- "Step 320/1000 | Loss: -18.6637 | Mean Pearson: 0.0892 | LR: 1.00e-05\n",
1048
- "Step 330/1000 | Loss: -36.6038 | Mean Pearson: 0.3356 | LR: 1.00e-05\n",
1049
- "Step 340/1000 | Loss: -2.4984 | Mean Pearson: 0.2305 | LR: 1.00e-05\n",
1050
- "Step 350/1000 | Loss: -4.7985 | Mean Pearson: 0.0968 | LR: 1.00e-05\n",
1051
- "\n",
1052
- "Running validation at step 350...\n",
1053
- " Validation Loss: -13.6500\n",
1054
- " Validation Mean Pearson: 0.2737\n",
1055
- " ENCFF884LDL/pearson: 0.2737\n",
1056
- "Step 360/1000 | Loss: -9.4795 | Mean Pearson: 0.0579 | LR: 1.00e-05\n",
1057
- "Step 370/1000 | Loss: 0.3531 | Mean Pearson: 0.0240 | LR: 1.00e-05\n",
1058
- "Step 380/1000 | Loss: -5.7921 | Mean Pearson: 0.4119 | LR: 1.00e-05\n",
1059
- "Step 390/1000 | Loss: -2.7049 | Mean Pearson: 0.1343 | LR: 1.00e-05\n",
1060
- "Step 400/1000 | Loss: -32.8422 | Mean Pearson: 0.1545 | LR: 1.00e-05\n",
1061
- "\n",
1062
- "Running validation at step 400...\n",
1063
- " Validation Loss: -4.3502\n",
1064
- " Validation Mean Pearson: 0.3124\n",
1065
- " ENCFF884LDL/pearson: 0.3124\n",
1066
- "Step 410/1000 | Loss: -18.9574 | Mean Pearson: 0.0594 | LR: 1.00e-05\n",
1067
- "Step 420/1000 | Loss: -5.4032 | Mean Pearson: 0.2804 | LR: 1.00e-05\n",
1068
- "Step 430/1000 | Loss: -0.5171 | Mean Pearson: 0.1835 | LR: 1.00e-05\n",
1069
- "Step 440/1000 | Loss: -3.4071 | Mean Pearson: 0.0680 | LR: 1.00e-05\n",
1070
- "Step 450/1000 | Loss: -3.5580 | Mean Pearson: 0.0850 | LR: 1.00e-05\n",
1071
- "\n",
1072
- "Running validation at step 450...\n",
1073
- " Validation Loss: -7.3308\n",
1074
- " Validation Mean Pearson: 0.1128\n",
1075
- " ENCFF884LDL/pearson: 0.1128\n",
1076
- "Step 460/1000 | Loss: -0.9750 | Mean Pearson: 0.1717 | LR: 1.00e-05\n",
1077
- "Step 470/1000 | Loss: -5.5775 | Mean Pearson: 0.1321 | LR: 1.00e-05\n",
1078
- "Step 480/1000 | Loss: -1.1170 | Mean Pearson: 0.1484 | LR: 1.00e-05\n",
1079
- "Step 490/1000 | Loss: -3.8053 | Mean Pearson: 0.1959 | LR: 1.00e-05\n",
1080
- "Step 500/1000 | Loss: -4.5933 | Mean Pearson: 0.1860 | LR: 1.00e-05\n",
1081
- "\n",
1082
- "Running validation at step 500...\n",
1083
- " Validation Loss: -5.7617\n",
1084
- " Validation Mean Pearson: 0.3155\n",
1085
- " ENCFF884LDL/pearson: 0.3155\n",
1086
- "Step 510/1000 | Loss: -3.3306 | Mean Pearson: 0.2815 | LR: 1.00e-05\n",
1087
- "Step 520/1000 | Loss: -2.1962 | Mean Pearson: 0.1151 | LR: 1.00e-05\n",
1088
- "Step 530/1000 | Loss: -1.5388 | Mean Pearson: 0.3783 | LR: 1.00e-05\n",
1089
- "Step 540/1000 | Loss: -2.2349 | Mean Pearson: 0.0734 | LR: 1.00e-05\n",
1090
- "Step 550/1000 | Loss: -1.5502 | Mean Pearson: 0.2171 | LR: 1.00e-05\n",
1091
- "\n",
1092
- "Running validation at step 550...\n",
1093
- " Validation Loss: -3.0059\n",
1094
- " Validation Mean Pearson: 0.2325\n",
1095
- " ENCFF884LDL/pearson: 0.2325\n",
1096
- "Step 560/1000 | Loss: -2.0764 | Mean Pearson: -0.0049 | LR: 1.00e-05\n",
1097
- "Step 570/1000 | Loss: -1.7384 | Mean Pearson: 0.2989 | LR: 1.00e-05\n",
1098
- "Step 580/1000 | Loss: -6.7306 | Mean Pearson: 0.2522 | LR: 1.00e-05\n",
1099
- "Step 590/1000 | Loss: -3.2473 | Mean Pearson: 0.1042 | LR: 1.00e-05\n",
1100
- "Step 600/1000 | Loss: -4.2841 | Mean Pearson: 0.1936 | LR: 1.00e-05\n",
1101
- "\n",
1102
- "Running validation at step 600...\n",
1103
- " Validation Loss: -4.5611\n",
1104
- " Validation Mean Pearson: 0.2744\n",
1105
- " ENCFF884LDL/pearson: 0.2744\n",
1106
- "Step 610/1000 | Loss: -3.5691 | Mean Pearson: 0.1803 | LR: 1.00e-05\n",
1107
- "Step 620/1000 | Loss: -7.2129 | Mean Pearson: 0.0901 | LR: 1.00e-05\n",
1108
- "Step 630/1000 | Loss: -6.0598 | Mean Pearson: 0.1795 | LR: 1.00e-05\n",
1109
- "Step 640/1000 | Loss: -2.8917 | Mean Pearson: 0.1111 | LR: 1.00e-05\n",
1110
- "Step 650/1000 | Loss: -2.7210 | Mean Pearson: 0.3566 | LR: 1.00e-05\n",
1111
- "\n",
1112
- "Running validation at step 650...\n",
1113
- " Validation Loss: -4.3997\n",
1114
- " Validation Mean Pearson: 0.3327\n",
1115
- " ENCFF884LDL/pearson: 0.3327\n",
1116
- "Step 660/1000 | Loss: -3.4793 | Mean Pearson: 0.0441 | LR: 1.00e-05\n",
1117
- "Step 670/1000 | Loss: -1.9743 | Mean Pearson: 0.1364 | LR: 1.00e-05\n",
1118
- "Step 680/1000 | Loss: -5.7498 | Mean Pearson: 0.2330 | LR: 1.00e-05\n",
1119
- "Step 690/1000 | Loss: -12.8701 | Mean Pearson: 0.3182 | LR: 1.00e-05\n",
1120
- "Step 700/1000 | Loss: -1.5847 | Mean Pearson: 0.1971 | LR: 1.00e-05\n",
1121
- "\n",
1122
- "Running validation at step 700...\n",
1123
- " Validation Loss: -2.0630\n",
1124
- " Validation Mean Pearson: 0.1267\n",
1125
- " ENCFF884LDL/pearson: 0.1267\n",
1126
- "Step 710/1000 | Loss: -6.0704 | Mean Pearson: 0.3715 | LR: 1.00e-05\n",
1127
- "Step 720/1000 | Loss: -2.6020 | Mean Pearson: 0.1244 | LR: 1.00e-05\n",
1128
- "Step 730/1000 | Loss: -58.8965 | Mean Pearson: 0.5625 | LR: 1.00e-05\n",
1129
- "Step 740/1000 | Loss: -1.2855 | Mean Pearson: 0.2658 | LR: 1.00e-05\n",
1130
- "Step 750/1000 | Loss: -4.4599 | Mean Pearson: 0.0137 | LR: 1.00e-05\n",
1131
- "\n",
1132
- "Running validation at step 750...\n",
1133
- " Validation Loss: -11.1562\n",
1134
- " Validation Mean Pearson: 0.0844\n",
1135
- " ENCFF884LDL/pearson: 0.0844\n",
1136
- "Step 760/1000 | Loss: -11.6905 | Mean Pearson: 0.1914 | LR: 1.00e-05\n",
1137
- "Step 770/1000 | Loss: -4.0964 | Mean Pearson: 0.2022 | LR: 1.00e-05\n",
1138
- "Step 780/1000 | Loss: -1.5512 | Mean Pearson: 0.3568 | LR: 1.00e-05\n",
1139
- "Step 790/1000 | Loss: -5.5843 | Mean Pearson: 0.2058 | LR: 1.00e-05\n",
1140
- "Step 800/1000 | Loss: -3.9190 | Mean Pearson: 0.4362 | LR: 1.00e-05\n",
1141
- "\n",
1142
- "Running validation at step 800...\n",
1143
- " Validation Loss: -4.7017\n",
1144
- " Validation Mean Pearson: 0.3817\n",
1145
- " ENCFF884LDL/pearson: 0.3817\n",
1146
- "Step 810/1000 | Loss: -7.6856 | Mean Pearson: 0.0672 | LR: 1.00e-05\n",
1147
- "Step 820/1000 | Loss: -5.3603 | Mean Pearson: 0.2325 | LR: 1.00e-05\n",
1148
- "Step 830/1000 | Loss: -3.8539 | Mean Pearson: 0.2808 | LR: 1.00e-05\n",
1149
- "Step 840/1000 | Loss: -8.1141 | Mean Pearson: 0.2529 | LR: 1.00e-05\n",
1150
- "Step 850/1000 | Loss: -10.5886 | Mean Pearson: 0.3454 | LR: 1.00e-05\n",
1151
- "\n",
1152
- "Running validation at step 850...\n",
1153
- " Validation Loss: -4.9108\n",
1154
- " Validation Mean Pearson: 0.2195\n",
1155
- " ENCFF884LDL/pearson: 0.2195\n",
1156
- "Step 860/1000 | Loss: -4.1028 | Mean Pearson: 0.3304 | LR: 1.00e-05\n",
1157
- "Step 870/1000 | Loss: -7.1834 | Mean Pearson: 0.1206 | LR: 1.00e-05\n",
1158
- "Step 880/1000 | Loss: -8.9869 | Mean Pearson: 0.3584 | LR: 1.00e-05\n",
1159
- "Step 890/1000 | Loss: -2.2697 | Mean Pearson: 0.0943 | LR: 1.00e-05\n",
1160
- "Step 900/1000 | Loss: -14.0142 | Mean Pearson: 0.4761 | LR: 1.00e-05\n",
1161
- "\n",
1162
- "Running validation at step 900...\n",
1163
- " Validation Loss: -3.2329\n",
1164
- " Validation Mean Pearson: 0.3635\n",
1165
- " ENCFF884LDL/pearson: 0.3635\n",
1166
- "Step 910/1000 | Loss: -9.0941 | Mean Pearson: 0.2754 | LR: 1.00e-05\n",
1167
- "Step 920/1000 | Loss: -4.6371 | Mean Pearson: 0.0167 | LR: 1.00e-05\n",
1168
- "Step 930/1000 | Loss: -7.9853 | Mean Pearson: 0.0941 | LR: 1.00e-05\n",
1169
- "Step 940/1000 | Loss: -22.9349 | Mean Pearson: 0.5140 | LR: 1.00e-05\n",
1170
- "Step 950/1000 | Loss: -2.0866 | Mean Pearson: 0.1746 | LR: 1.00e-05\n",
1171
- "\n",
1172
- "Running validation at step 950...\n",
1173
- " Validation Loss: -8.8318\n",
1174
- " Validation Mean Pearson: 0.1597\n",
1175
- " ENCFF884LDL/pearson: 0.1597\n",
1176
- "Step 960/1000 | Loss: -4.8540 | Mean Pearson: 0.6318 | LR: 1.00e-05\n",
1177
- "Step 970/1000 | Loss: -4.1091 | Mean Pearson: 0.0985 | LR: 1.00e-05\n",
1178
- "Step 980/1000 | Loss: -5.1141 | Mean Pearson: 0.2031 | LR: 1.00e-05\n",
1179
- "Step 990/1000 | Loss: -4.1959 | Mean Pearson: 0.2404 | LR: 1.00e-05\n",
1180
- "Step 1000/1000 | Loss: -0.9942 | Mean Pearson: 0.2742 | LR: 1.00e-05\n",
1181
- "\n",
1182
- "Running validation at step 1000...\n",
1183
- " Validation Loss: -4.2796\n",
1184
- " Validation Mean Pearson: 0.1425\n",
1185
- " ENCFF884LDL/pearson: 0.1425\n",
1186
- "\n",
1187
- "Training completed after 1000 steps.\n"
1188
- ]
1189
- }
1190
- ],
1191
- "source": [
1192
- "# Training loop\n",
1193
- "print(\"Starting training...\")\n",
1194
- "print(f\"Training for {config[\"num_steps_training\"]} steps\\n\")\n",
1195
- "\n",
1196
- "model.train()\n",
1197
- "train_metrics.reset()\n",
1198
- "optimizer.zero_grad() # Initialize gradients\n",
1199
- "\n",
1200
- "# Track metrics for plotting\n",
1201
- "train_steps = []\n",
1202
- "train_losses = []\n",
1203
- "train_pearson_scores = []\n",
1204
- "val_steps = []\n",
1205
- "val_losses = []\n",
1206
- "val_pearson_scores = []\n",
1207
- "\n",
1208
- "# Initialize interactive plots using FigureWidget for real-time updates\n",
1209
- "from plotly.graph_objects import FigureWidget\n",
1210
- "from plotly.subplots import make_subplots\n",
1211
- "\n",
1212
- "# Create base figure with subplots\n",
1213
- "fig_base = make_subplots(\n",
1214
- " rows=1, cols=2,\n",
1215
- " subplot_titles=('Loss', 'Mean Pearson Correlation'),\n",
1216
- " horizontal_spacing=0.15,\n",
1217
- ")\n",
1218
- "\n",
1219
- "# Add empty traces for train and val metrics\n",
1220
- "fig_base.add_trace(\n",
1221
- " go.Scatter(x=[], y=[], mode='lines+markers', name='Train Loss', line=dict(color='blue')),\n",
1222
- " row=1, col=1\n",
1223
- ")\n",
1224
- "fig_base.add_trace(\n",
1225
- " go.Scatter(x=[], y=[], mode='lines+markers', name='Val Loss', line=dict(color='red')),\n",
1226
- " row=1, col=1\n",
1227
- ")\n",
1228
- "fig_base.add_trace(\n",
1229
- " go.Scatter(x=[], y=[], mode='lines+markers', name='Train Pearson', line=dict(color='green')),\n",
1230
- " row=1, col=2\n",
1231
- ")\n",
1232
- "fig_base.add_trace(\n",
1233
- " go.Scatter(x=[], y=[], mode='lines+markers', name='Val Pearson', line=dict(color='orange')),\n",
1234
- " row=1, col=2\n",
1235
- ")\n",
1236
- "\n",
1237
- "fig_base.update_xaxes(title_text=\"Step\", row=1, col=1)\n",
1238
- "fig_base.update_xaxes(title_text=\"Step\", row=1, col=2)\n",
1239
- "fig_base.update_yaxes(title_text=\"Loss\", row=1, col=1)\n",
1240
- "fig_base.update_yaxes(title_text=\"Pearson Correlation\", row=1, col=2)\n",
1241
- "fig_base.update_layout(height=800, width=1600, showlegend=True, title_text=\"Training\")\n",
1242
- "\n",
1243
- "# Convert to FigureWidget for interactive updates\n",
1244
- "fig = FigureWidget(fig_base)\n",
1245
- "\n",
1246
- "# Display initial plot (will update in place during training)\n",
1247
- "display(fig)\n",
1248
- "\n",
1249
- "# Create iterator for training data (will cycle if needed)\n",
1250
- "train_iter = iter(train_loader)\n",
1251
- "\n",
1252
- "# Main training loop\n",
1253
- "for step_idx in range(config[\"num_steps_training\"]):\n",
1254
- " try:\n",
1255
- " batch = next(train_iter)\n",
1256
- " except StopIteration:\n",
1257
- " # Restart iterator if we run out of data\n",
1258
- " train_iter = iter(train_loader)\n",
1259
- " batch = next(train_iter)\n",
1260
- " \n",
1261
- " # Forward pass and backward pass\n",
1262
- " loss = train_step(model, batch)\n",
1263
- " \n",
1264
- " # Update optimizer\n",
1265
- " optimizer.step()\n",
1266
- " optimizer.zero_grad()\n",
1267
- " \n",
1268
- " # Update metrics\n",
1269
- " tokens = batch[\"tokens\"].to(device)\n",
1270
- " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1271
- " with torch.no_grad():\n",
1272
- " outputs = model(tokens=tokens)\n",
1273
- " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
1274
- " \n",
1275
- " train_metrics.update(\n",
1276
- " predictions=bigwig_logits,\n",
1277
- " targets=bigwig_targets,\n",
1278
- " loss=loss\n",
1279
- " )\n",
1280
- " \n",
1281
- " # Logging\n",
1282
- " if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
1283
- " train_metrics_dict = train_metrics.compute()\n",
1284
- " current_lr = optimizer.param_groups[0]['lr']\n",
1285
- " \n",
1286
- " # Track metrics for plotting\n",
1287
- " train_steps.append(step_idx + 1)\n",
1288
- " train_losses.append(loss)\n",
1289
- " train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
1290
- " \n",
1291
- " # Update plots - direct assignment to FigureWidget data updates the plot automatically\n",
1292
- " fig.data[0].x = train_steps\n",
1293
- " fig.data[0].y = train_losses\n",
1294
- " fig.data[2].x = train_steps\n",
1295
- " fig.data[2].y = train_pearson_scores\n",
1296
- " \n",
1297
- " print(f\"Step {step_idx + 1}/{config[\"num_steps_training\"]} | \"\n",
1298
- " f\"Loss: {loss:.4f} | \"\n",
1299
- " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
1300
- " f\"LR: {current_lr:.2e}\")\n",
1301
- " train_metrics.reset()\n",
1302
- " \n",
1303
- " # Validation\n",
1304
- " if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
1305
- " print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
1306
- " val_metrics.reset()\n",
1307
- " model.eval()\n",
1308
- " \n",
1309
- " val_batch_losses = []\n",
1310
- " for val_batch in val_loader:\n",
1311
- " val_loss = validation_step(model, val_batch, val_metrics)\n",
1312
- " val_batch_losses.append(val_loss)\n",
1313
- " \n",
1314
- " # Print validation metrics\n",
1315
- " val_metrics_dict = val_metrics.compute()\n",
1316
- " val_loss_mean = np.mean(val_batch_losses)\n",
1317
- " val_pearson_mean = val_metrics_dict['mean/pearson']\n",
1318
- " \n",
1319
- " # Track validation metrics\n",
1320
- " val_steps.append(step_idx + 1)\n",
1321
- " val_losses.append(val_loss_mean)\n",
1322
- " val_pearson_scores.append(val_pearson_mean)\n",
1323
- " \n",
1324
- " # Update plots with validation data - direct assignment updates the plot automatically\n",
1325
- " fig.data[1].x = val_steps\n",
1326
- " fig.data[1].y = val_losses\n",
1327
- " fig.data[3].x = val_steps\n",
1328
- " fig.data[3].y = val_pearson_scores\n",
1329
- " \n",
1330
- " print(f\" Validation Loss: {val_loss_mean:.4f}\")\n",
1331
- " print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
1332
- " for track_name in config[\"bigwig_file_ids\"]:\n",
1333
- " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
1334
- " \n",
1335
- " model.train() # Back to training mode\n",
1336
- "\n",
1337
- "print(f\"\\nTraining completed after {config[\"num_steps_training\"]} steps.\")"
1338
- ]
1339
- },
1340
- {
1341
- "cell_type": "markdown",
1342
- "metadata": {},
1343
- "source": [
1344
- "# 10. Test evaluation"
1345
- ]
1346
- },
1347
- {
1348
- "cell_type": "code",
1349
- "execution_count": 24,
1350
- "metadata": {},
1351
- "outputs": [
1352
- {
1353
- "name": "stdout",
1354
- "output_type": "stream",
1355
- "text": [
1356
- "Running test evaluation with 12 steps (100 samples)\n",
1357
- "\n",
1358
- "==================================================\n",
1359
- "Test Set Results\n",
1360
- "==================================================\n",
1361
- "\n",
1362
- "Metrics:\n",
1363
- " Mean Pearson: 0.1787\n",
1364
- " ENCFF884LDL/pearson: 0.1787\n"
1365
- ]
1366
- }
1367
- ],
1368
- "source": [
1369
- "# Calculate number of test steps (based on deepspeed pipeline)\n",
1370
- "num_test_samples = len(test_dataset)\n",
1371
- "num_test_steps = num_test_samples // config[\"batch_size\"]\n",
1372
- "print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
1373
- "\n",
1374
- "# Set model to eval mode\n",
1375
- "model.eval()\n",
1376
- "\n",
1377
- "for test_batch in test_loader: \n",
1378
- "\n",
1379
- " _ = validation_step( \n",
1380
- " model, \n",
1381
- " test_batch, \n",
1382
- " test_metrics,\n",
1383
- " )\n",
1384
- " \n",
1385
- "# Compute final test metrics\n",
1386
- "test_metrics_dict = test_metrics.compute()\n",
1387
- "print(\"\\n\" + \"=\"*50)\n",
1388
- "print(\"Test Set Results\")\n",
1389
- "print(\"=\"*50)\n",
1390
- "print(f\"\\nMetrics:\")\n",
1391
- "print(f\" Mean Pearson: {test_metrics_dict['mean/pearson']:.4f}\")\n",
1392
- "for track_name in config[\"bigwig_file_ids\"]: \n",
1393
- " print(f\" {track_name}/pearson: {test_metrics_dict[f'{track_name}/pearson']:.4f}\")"
1394
- ]
1395
- },
1396
- {
1397
- "cell_type": "code",
1398
- "execution_count": null,
1399
- "metadata": {},
1400
- "outputs": [],
1401
- "source": []
1402
- }
1403
- ],
1404
- "metadata": {
1405
- "kernelspec": {
1406
- "display_name": "Python 3.12 (ntv3-env)",
1407
- "language": "python",
1408
- "name": "ntv3-env"
1409
- },
1410
- "language_info": {
1411
- "codemirror_mode": {
1412
- "name": "ipython",
1413
- "version": 3
1414
- },
1415
- "file_extension": ".py",
1416
- "mimetype": "text/x-python",
1417
- "name": "python",
1418
- "nbconvert_exporter": "python",
1419
- "pygments_lexer": "ipython3",
1420
- "version": "3.12.3"
1421
- }
1422
- },
1423
- "nbformat": 4,
1424
- "nbformat_minor": 2
1425
- }
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd2b425dc0d358a64ac0e27c1c8b32eef79069b995edcdf2b81549988ac97026
3
+ size 14418415