ybornachot committed on
Commit
6e05130
·
1 Parent(s): b04b4fa

fix: simplified data download + loading

Browse files
Files changed (1) hide show
  1. notebooks/03_fine_tuning.ipynb +122 -251
notebooks/03_fine_tuning.ipynb CHANGED
@@ -21,30 +21,23 @@
21
  "outputs": [],
22
  "source": [
23
  "# Install useful dependencies\n",
24
- "# !pip install -r requirements.txt"
 
 
25
  ]
26
  },
27
  {
28
  "cell_type": "code",
29
- "execution_count": 1,
30
  "metadata": {},
31
- "outputs": [
32
- {
33
- "name": "stderr",
34
- "output_type": "stream",
35
- "text": [
36
- "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
37
- " from .autonotebook import tqdm as notebook_tqdm\n"
38
- ]
39
- }
40
- ],
41
  "source": [
42
  "# 0. Imports\n",
43
  "import random\n",
44
  "import functools\n",
45
  "from typing import List, Dict, Optional, Callable\n",
46
- "import pyBigWig\n",
47
- "from pyfaidx import Fasta\n",
48
  "\n",
49
  "import torch\n",
50
  "import torch.nn as nn\n",
@@ -54,6 +47,8 @@
54
  "from torch.optim.lr_scheduler import LambdaLR\n",
55
  "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
56
  "import numpy as np\n",
 
 
57
  "from torchmetrics import PearsonCorrCoef"
58
  ]
59
  },
@@ -66,7 +61,7 @@
66
  },
67
  {
68
  "cell_type": "code",
69
- "execution_count": 2,
70
  "metadata": {},
71
  "outputs": [
72
  {
@@ -81,11 +76,12 @@
81
  "config = {\n",
82
  " # Model\n",
83
  " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\", # NTv3 model\n",
84
- " \"pretrained\": True,\n",
85
  " \n",
86
  " # Data\n",
 
 
 
87
  " \"sequence_length\": 1_024,\n",
88
- " \"bigwig_file_ids\": [\"ENCFF884LDL\"], # Example track names\n",
89
  " \"keep_target_center_fraction\": 0.375,\n",
90
  " \n",
91
  " # Training\n",
@@ -115,10 +111,34 @@
115
  " \"num_workers\": 0, # Number of worker processes for DataLoader\n",
116
  "}\n",
117
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  "# Set random seed\n",
119
  "torch.manual_seed(config[\"seed\"])\n",
120
  "np.random.seed(config[\"seed\"])\n",
121
  "\n",
 
122
  "device = torch.device(config[\"device\"])\n",
123
  "print(f\"Using device: {device}\")"
124
  ]
@@ -132,73 +152,82 @@
132
  },
133
  {
134
  "cell_type": "code",
135
- "execution_count": null,
136
- "metadata": {},
137
- "outputs": [],
138
- "source": [
139
- "!wget -c https://ftp.ncbi.nlm.nih.gov/genomes/refseq/vertebrate_mammalian/Homo_sapiens/latest_assembly_versions/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz \\\n",
140
- "&& gunzip -f GCF_000001405.40_GRCh38.p14_genomic.fna.gz"
141
- ]
142
- },
143
- {
144
- "cell_type": "code",
145
- "execution_count": null,
146
  "metadata": {},
147
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  "source": [
149
- "!wget -O ENCFF884LDL \"$(curl -s https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL | sed -n 's/.*href=\\\"\\([^\\\"]*ENCFF884LDL[^\\\"]*\\)\\\".*/\\1/p')\" \\\n",
150
- "&& echo \"Downloaded ENCFF884LDL\""
151
  ]
152
  },
153
  {
154
  "cell_type": "code",
155
- "execution_count": null,
156
  "metadata": {},
157
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  "source": [
159
- "!wget -c https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig"
 
 
 
 
 
160
  ]
161
  },
162
  {
163
  "cell_type": "code",
164
- "execution_count": 3,
165
  "metadata": {},
166
  "outputs": [],
167
  "source": [
168
- "chrom_mapping = {\n",
169
- " \"chr1\": \"NC_000001.11\",\n",
170
- " \"chr2\": \"NC_000002.12\",\n",
171
- " \"chr3\": \"NC_000003.12\",\n",
172
- " \"chr4\": \"NC_000004.12\",\n",
173
- " \"chr5\": \"NC_000005.10\",\n",
174
- " \"chr6\": \"NC_000006.12\",\n",
175
- " \"chr7\": \"NC_000007.14\",\n",
176
- " \"chr8\": \"NC_000008.11\",\n",
177
- " \"chr9\": \"NC_000009.12\",\n",
178
- " \"chr10\": \"NC_000010.11\",\n",
179
- " \"chr11\": \"NC_000011.10\",\n",
180
- " \"chr12\": \"NC_000012.12\",\n",
181
- " \"chr13\": \"NC_000013.11\",\n",
182
- " \"chr14\": \"NC_000014.9\",\n",
183
- " \"chr15\": \"NC_000015.10\",\n",
184
- " \"chr16\": \"NC_000016.10\",\n",
185
- " \"chr17\": \"NC_000017.11\",\n",
186
- " \"chr18\": \"NC_000018.10\",\n",
187
- " \"chr19\": \"NC_000019.10\",\n",
188
- " \"chr20\": \"NC_000020.11\",\n",
189
- " \"chr21\": \"NC_000021.9\",\n",
190
- " \"chr22\": \"NC_000022.11\",\n",
191
- " \"chrX\": \"NC_000023.11\",\n",
192
- " \"chrY\": \"NC_000024.10\",\n",
193
- " # mitochondrial\n",
194
- " \"chrM\": \"NC_012920.1\",\n",
195
- " \"chrMT\": \"NC_012920.1\",\n",
196
- "}\n",
197
- "\n",
198
  "chrom_splits = {\n",
199
- " \"train\": [f\"chr{i}\" for i in range(1, 19)],\n",
200
- " \"val\": [f\"chr{i}\" for i in range(19, 21)],\n",
201
- " \"test\": [f\"chr{i}\" for i in range(21, 23)],\n",
202
  "}"
203
  ]
204
  },
@@ -211,7 +240,7 @@
211
  },
212
  {
213
  "cell_type": "code",
214
- "execution_count": 4,
215
  "metadata": {},
216
  "outputs": [],
217
  "source": [
@@ -237,24 +266,16 @@
237
  " model_name: str,\n",
238
  " bigwig_track_names: List[str],\n",
239
  " keep_target_center_fraction: float = 0.375,\n",
240
- " pretrained: bool = True,\n",
241
  " ):\n",
242
  " super().__init__()\n",
243
  " \n",
244
  " # Load config and model\n",
245
  " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
246
- "\n",
247
- " if pretrained:\n",
248
- " self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
249
- " model_name, \n",
250
- " trust_remote_code=True,\n",
251
- " config=self.config\n",
252
- " )\n",
253
- " else:\n",
254
- " self.backbone = AutoModelForMaskedLM.from_config(\n",
255
- " self.config, \n",
256
- " trust_remote_code=True\n",
257
- " )\n",
258
  " \n",
259
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
260
  "\n",
@@ -287,7 +308,7 @@
287
  },
288
  {
289
  "cell_type": "code",
290
- "execution_count": 5,
291
  "metadata": {},
292
  "outputs": [
293
  {
@@ -314,7 +335,6 @@
314
  " model_name=config[\"model_name\"],\n",
315
  " bigwig_track_names=config[\"bigwig_file_ids\"],\n",
316
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
317
- " pretrained=config[\"pretrained\"],\n",
318
  ")\n",
319
  "model = model.to(device)\n",
320
  "model.train()\n",
@@ -333,7 +353,7 @@
333
  },
334
  {
335
  "cell_type": "code",
336
- "execution_count": 34,
337
  "metadata": {},
338
  "outputs": [],
339
  "source": [
@@ -377,7 +397,6 @@
377
  " sequence_length: int,\n",
378
  " num_samples: int,\n",
379
  " tokenizer: AutoTokenizer,\n",
380
- " chrom_mapping: Optional[Dict[str, str]] = None,\n",
381
  " keep_target_center_fraction: float = 1.0,\n",
382
  " num_tracks: int = 1,\n",
383
  " ):\n",
@@ -393,9 +412,7 @@
393
  " self.tokenizer = tokenizer\n",
394
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
395
  " self.num_tracks = num_tracks\n",
396
- "\n",
397
  " self.chroms = chroms\n",
398
- " self.chrom_mapping = chrom_mapping or {c: c for c in chroms}\n",
399
  "\n",
400
  " # Intersect lengths between FASTA and bigWig for safety\n",
401
  " bw_chrom_lengths = self.bw_list[0].chroms() # dict: chrom -> length\n",
@@ -404,13 +421,10 @@
404
  " self.chrom_lengths = {}\n",
405
  "\n",
406
  " for c in chroms:\n",
407
- " if c not in bw_chrom_lengths:\n",
408
- " continue\n",
409
- " fa_name = self.chrom_mapping.get(c, c)\n",
410
- " if fa_name not in self.fasta:\n",
411
  " continue\n",
412
  "\n",
413
- " fa_len = len(self.fasta[fa_name])\n",
414
  " bw_len = bw_chrom_lengths[c]\n",
415
  " L = min(fa_len, bw_len)\n",
416
  "\n",
@@ -433,11 +447,8 @@
433
  " start = random.randint(0, max_start)\n",
434
  " end = start + self.sequence_length\n",
435
  "\n",
436
- " # FASTA chromosome name may differ\n",
437
- " fa_chrom = self.chrom_mapping.get(chrom, chrom)\n",
438
- "\n",
439
  " # Sequence\n",
440
- " seq = self.fasta[fa_chrom][start:end] # string slice\n",
441
  " tokens = self.tokenizer(\n",
442
  " seq,\n",
443
  " return_tensors=\"pt\", # Returns a dict of PyTorch tensors\n",
@@ -475,7 +486,7 @@
475
  },
476
  {
477
  "cell_type": "code",
478
- "execution_count": 35,
479
  "metadata": {},
480
  "outputs": [
481
  {
@@ -489,16 +500,12 @@
489
  }
490
  ],
491
  "source": [
492
- "fasta_path = \"./GCF_000001405.40_GRCh38.p14_genomic.fna\"\n",
493
- "bigwig_path_list = [\"./ENCFF884LDL.bigWig\"]\n",
494
- "\n",
495
  "create_dataset_fn = functools.partial(\n",
496
  " GenomeBigWigDataset,\n",
497
  " fasta_path=fasta_path,\n",
498
  " bigwig_path_list=bigwig_path_list,\n",
499
  " sequence_length=config[\"sequence_length\"],\n",
500
  " tokenizer=tokenizer,\n",
501
- " chrom_mapping=chrom_mapping,\n",
502
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
503
  " num_tracks=len(config[\"bigwig_file_ids\"]),\n",
504
  ")\n",
@@ -554,7 +561,7 @@
554
  },
555
  {
556
  "cell_type": "code",
557
- "execution_count": 36,
558
  "metadata": {},
559
  "outputs": [],
560
  "source": [
@@ -582,29 +589,9 @@
582
  },
583
  {
584
  "cell_type": "code",
585
- "execution_count": 37,
586
  "metadata": {},
587
- "outputs": [
588
- {
589
- "name": "stdout",
590
- "output_type": "stream",
591
- "text": [
592
- "Gradient accumulation steps: 2\n",
593
- "Effective batch size: 4\n",
594
- "Effective tokens per update: 4096\n",
595
- "\n",
596
- "Training constants:\n",
597
- " Total training steps: 32\n",
598
- " Log training metrics every: 2 steps\n",
599
- " Run validation every: 4 steps\n",
600
- " Warmup steps: 3\n",
601
- "\n",
602
- "Optimizer setup:\n",
603
- " Initial LR: 1e-05\n",
604
- " Peak LR: 5e-05\n"
605
- ]
606
- }
607
- ],
608
  "source": [
609
  "# Calculate gradient accumulation steps and effective batch size\n",
610
  "num_devices = 1 # Single device for now\n",
@@ -676,7 +663,7 @@
676
  },
677
  {
678
  "cell_type": "code",
679
- "execution_count": 38,
680
  "metadata": {},
681
  "outputs": [],
682
  "source": [
@@ -766,7 +753,7 @@
766
  },
767
  {
768
  "cell_type": "code",
769
- "execution_count": 39,
770
  "metadata": {},
771
  "outputs": [],
772
  "source": [
@@ -784,17 +771,9 @@
784
  },
785
  {
786
  "cell_type": "code",
787
- "execution_count": 40,
788
  "metadata": {},
789
- "outputs": [
790
- {
791
- "name": "stdout",
792
- "output_type": "stream",
793
- "text": [
794
- "Scaling functions created\n"
795
- ]
796
- }
797
- ],
798
  "source": [
799
  "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
800
  " \"\"\"\n",
@@ -920,7 +899,7 @@
920
  },
921
  {
922
  "cell_type": "code",
923
- "execution_count": 41,
924
  "metadata": {},
925
  "outputs": [],
926
  "source": [
@@ -996,7 +975,7 @@
996
  },
997
  {
998
  "cell_type": "code",
999
- "execution_count": 42,
1000
  "metadata": {},
1001
  "outputs": [],
1002
  "source": [
@@ -1082,77 +1061,9 @@
1082
  },
1083
  {
1084
  "cell_type": "code",
1085
- "execution_count": 43,
1086
  "metadata": {},
1087
- "outputs": [
1088
- {
1089
- "name": "stdout",
1090
- "output_type": "stream",
1091
- "text": [
1092
- "Starting training...\n",
1093
- "Training for 32 steps with 2 gradient accumulation steps\n",
1094
- "\n",
1095
- "Step 1/32 | Loss: 0.7569 | Mean Pearson: -0.1473 | LR: 1.17e-09 | Tokens: 4,096\n",
1096
- "\n",
1097
- "Running validation at step 0...\n",
1098
- " Validation Loss: 1.0152\n",
1099
- " Validation Mean Pearson: -0.0414\n",
1100
- " ENCFF884LDL/pearson: -0.0414\n",
1101
- "Step 3/32 | Loss: 0.3793 | Mean Pearson: -0.0229 | LR: 2.50e-09 | Tokens: 12,288\n",
1102
- "Step 5/32 | Loss: 0.4111 | Mean Pearson: -0.1739 | LR: 2.41e-09 | Tokens: 20,480\n",
1103
- "\n",
1104
- "Running validation at step 4...\n",
1105
- " Validation Loss: 0.4801\n",
1106
- " Validation Mean Pearson: 0.0120\n",
1107
- " ENCFF884LDL/pearson: 0.0120\n",
1108
- "Step 7/32 | Loss: 0.3404 | Mean Pearson: -0.0191 | LR: 2.32e-09 | Tokens: 28,672\n",
1109
- "Step 9/32 | Loss: 0.3950 | Mean Pearson: 0.0090 | LR: 2.23e-09 | Tokens: 36,864\n",
1110
- "\n",
1111
- "Running validation at step 8...\n",
1112
- " Validation Loss: 0.5865\n",
1113
- " Validation Mean Pearson: -0.0260\n",
1114
- " ENCFF884LDL/pearson: -0.0260\n",
1115
- "Step 11/32 | Loss: 0.3750 | Mean Pearson: 0.0121 | LR: 2.13e-09 | Tokens: 45,056\n",
1116
- "Step 13/32 | Loss: 0.4380 | Mean Pearson: -0.0126 | LR: 2.02e-09 | Tokens: 53,248\n",
1117
- "\n",
1118
- "Running validation at step 12...\n",
1119
- " Validation Loss: 0.3997\n",
1120
- " Validation Mean Pearson: 0.0093\n",
1121
- " ENCFF884LDL/pearson: 0.0093\n",
1122
- "Step 15/32 | Loss: 0.3469 | Mean Pearson: -0.0279 | LR: 1.91e-09 | Tokens: 61,440\n",
1123
- "Step 17/32 | Loss: 0.5098 | Mean Pearson: -0.2044 | LR: 1.80e-09 | Tokens: 69,632\n",
1124
- "\n",
1125
- "Running validation at step 16...\n",
1126
- " Validation Loss: 0.3752\n",
1127
- " Validation Mean Pearson: -0.0178\n",
1128
- " ENCFF884LDL/pearson: -0.0178\n",
1129
- "Step 19/32 | Loss: 0.4899 | Mean Pearson: -0.0424 | LR: 1.67e-09 | Tokens: 77,824\n",
1130
- "Step 21/32 | Loss: 0.3889 | Mean Pearson: -0.0332 | LR: 1.54e-09 | Tokens: 86,016\n",
1131
- "\n",
1132
- "Running validation at step 20...\n",
1133
- " Validation Loss: 0.4217\n",
1134
- " Validation Mean Pearson: -0.0205\n",
1135
- " ENCFF884LDL/pearson: -0.0205\n",
1136
- "Step 23/32 | Loss: 0.3392 | Mean Pearson: 0.0235 | LR: 1.39e-09 | Tokens: 94,208\n",
1137
- "Step 25/32 | Loss: 0.4165 | Mean Pearson: 0.0033 | LR: 1.23e-09 | Tokens: 102,400\n",
1138
- "\n",
1139
- "Running validation at step 24...\n",
1140
- " Validation Loss: 0.4363\n",
1141
- " Validation Mean Pearson: -0.0379\n",
1142
- " ENCFF884LDL/pearson: -0.0379\n",
1143
- "Step 27/32 | Loss: 0.7630 | Mean Pearson: 0.0683 | LR: 1.04e-09 | Tokens: 110,592\n",
1144
- "Step 29/32 | Loss: 0.7357 | Mean Pearson: 0.0050 | LR: 8.04e-10 | Tokens: 118,784\n",
1145
- "\n",
1146
- "Running validation at step 28...\n",
1147
- " Validation Loss: 0.6629\n",
1148
- " Validation Mean Pearson: -0.0370\n",
1149
- " ENCFF884LDL/pearson: -0.0370\n",
1150
- "Step 31/32 | Loss: 0.3690 | Mean Pearson: -0.0808 | LR: 4.64e-10 | Tokens: 126,976\n",
1151
- "\n",
1152
- "Training completed after 32 steps!\n"
1153
- ]
1154
- }
1155
- ],
1156
  "source": [
1157
  "# Training loop (step-based with gradient accumulation)\n",
1158
  "print(\"Starting training...\")\n",
@@ -1263,7 +1174,7 @@
1263
  },
1264
  {
1265
  "cell_type": "code",
1266
- "execution_count": 44,
1267
  "metadata": {},
1268
  "outputs": [],
1269
  "source": [
@@ -1307,47 +1218,7 @@
1307
  "cell_type": "code",
1308
  "execution_count": null,
1309
  "metadata": {},
1310
- "outputs": [
1311
- {
1312
- "name": "stdout",
1313
- "output_type": "stream",
1314
- "text": [
1315
- "\n",
1316
- "==================================================\n",
1317
- "Test Set Evaluation\n",
1318
- "==================================================\n",
1319
- "Running test evaluation with 5 steps (10 samples)\n"
1320
- ]
1321
- },
1322
- {
1323
- "name": "stderr",
1324
- "output_type": "stream",
1325
- "text": [
1326
- "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
1327
- "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
1328
- " warnings.warn(error_message)\n"
1329
- ]
1330
- },
1331
- {
1332
- "name": "stdout",
1333
- "output_type": "stream",
1334
- "text": [
1335
- "\n",
1336
- "==================================================\n",
1337
- "Test Set Results\n",
1338
- "==================================================\n",
1339
- "\n",
1340
- "Scaled Metrics (scaled predictions vs scaled targets):\n",
1341
- " Mean Pearson (scaled): -0.0362\n",
1342
- " metrics_scaled/ENCFF884LDL/pearson: -0.0362\n",
1343
- "\n",
1344
- "Raw Metrics (raw predictions vs raw targets):\n",
1345
- " Mean Pearson (raw): -0.0362\n",
1346
- " metrics_raw/ENCFF884LDL/pearson: -0.0362\n",
1347
- "==================================================\n"
1348
- ]
1349
- }
1350
- ],
1351
  "source": [
1352
  "print(\"\\n\" + \"=\"*50)\n",
1353
  "print(\"Test Set Evaluation\")\n",
 
21
  "outputs": [],
22
  "source": [
23
  "# Install useful dependencies\n",
24
+ "# !pip install pyBigWig\n",
25
+ "# !pip install pyfaidx\n",
26
+ "# !pip install torchmetrics"
27
  ]
28
  },
29
  {
30
  "cell_type": "code",
31
+ "execution_count": 5,
32
  "metadata": {},
33
+ "outputs": [],
 
 
 
 
 
 
 
 
 
34
  "source": [
35
  "# 0. Imports\n",
36
  "import random\n",
37
  "import functools\n",
38
  "from typing import List, Dict, Optional, Callable\n",
39
+ "import os\n",
40
+ "import subprocess\n",
41
  "\n",
42
  "import torch\n",
43
  "import torch.nn as nn\n",
 
47
  "from torch.optim.lr_scheduler import LambdaLR\n",
48
  "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
49
  "import numpy as np\n",
50
+ "import pyBigWig\n",
51
+ "from pyfaidx import Fasta\n",
52
  "from torchmetrics import PearsonCorrCoef"
53
  ]
54
  },
 
61
  },
62
  {
63
  "cell_type": "code",
64
+ "execution_count": 6,
65
  "metadata": {},
66
  "outputs": [
67
  {
 
76
  "config = {\n",
77
  " # Model\n",
78
  " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\", # NTv3 model\n",
 
79
  " \n",
80
  " # Data\n",
81
+ " \"data_cache_dir\": \"./data\",\n",
82
+ " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
83
+ " \"bigwig_url_list\": [\"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"],\n",
84
  " \"sequence_length\": 1_024,\n",
 
85
  " \"keep_target_center_fraction\": 0.375,\n",
86
  " \n",
87
  " # Training\n",
 
111
  " \"num_workers\": 0, # Number of worker processes for DataLoader\n",
112
  "}\n",
113
  "\n",
114
+ "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
115
+ "\n",
116
+ "# Extract filenames from URLs\n",
117
+ "def extract_filename_from_url(url: str) -> str:\n",
118
+ " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
119
+ " # Remove query parameters if present\n",
120
+ " url_clean = url.split('?')[0]\n",
121
+ " # Get the last part of the URL path\n",
122
+ " return url_clean.split('/')[-1]\n",
123
+ "\n",
124
+ "# Create paths for downloaded files\n",
125
+ "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
126
+ "bigwig_path_list = [\n",
127
+ " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
128
+ " for url in config[\"bigwig_url_list\"]\n",
129
+ "]\n",
130
+ "\n",
131
+ "# Create bigwig_file_ids from filenames (without extension)\n",
132
+ "config[\"bigwig_file_ids\"] = [\n",
133
+ " os.path.splitext(extract_filename_from_url(url))[0]\n",
134
+ " for url in config[\"bigwig_url_list\"]\n",
135
+ "]\n",
136
+ "\n",
137
  "# Set random seed\n",
138
  "torch.manual_seed(config[\"seed\"])\n",
139
  "np.random.seed(config[\"seed\"])\n",
140
  "\n",
141
+ "# Set device\n",
142
  "device = torch.device(config[\"device\"])\n",
143
  "print(f\"Using device: {device}\")"
144
  ]
 
152
  },
153
  {
154
  "cell_type": "code",
155
+ "execution_count": 3,
 
 
 
 
 
 
 
 
 
 
156
  "metadata": {},
157
+ "outputs": [
158
+ {
159
+ "name": "stdout",
160
+ "output_type": "stream",
161
+ "text": [
162
+ "--2025-12-10 14:47:06-- https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\n",
163
+ "Resolving hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)... 128.114.119.163\n",
164
+ "Connecting to hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)|128.114.119.163|:443... connected.\n",
165
+ "HTTP request sent, awaiting response... 200 OK\n",
166
+ "Length: 983659424 (938M) [application/x-gzip]\n",
167
+ "Saving to: './data/hg38.fa.gz'\n",
168
+ "\n",
169
+ "hg38.fa.gz 100%[===================>] 938.09M 10.4MB/s in 1m 43s \n",
170
+ "\n",
171
+ "2025-12-10 14:48:50 (9.09 MB/s) - './data/hg38.fa.gz' saved [983659424/983659424]\n",
172
+ "\n"
173
+ ]
174
+ }
175
+ ],
176
  "source": [
177
+ "# Download fasta file\n",
178
+ "!wget -c {config[\"fasta_url\"]} -P {config[\"data_cache_dir\"]}/ && gunzip -f {config[\"data_cache_dir\"]}/{config[\"fasta_url\"].split(os.path.sep)[-1]}"
179
  ]
180
  },
181
  {
182
  "cell_type": "code",
183
+ "execution_count": 7,
184
  "metadata": {},
185
+ "outputs": [
186
+ {
187
+ "name": "stdout",
188
+ "output_type": "stream",
189
+ "text": [
190
+ "Downloading ENCFF884LDL.bigWig...\n"
191
+ ]
192
+ },
193
+ {
194
+ "name": "stderr",
195
+ "output_type": "stream",
196
+ "text": [
197
+ "--2025-12-10 14:54:41-- https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\n",
198
+ "Resolving www.encodeproject.org (www.encodeproject.org)... 34.211.244.144\n",
199
+ "Connecting to www.encodeproject.org (www.encodeproject.org)|34.211.244.144|:443... connected.\n",
200
+ "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
201
+ "Location: https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482 [following]\n",
202
+ "--2025-12-10 14:54:42-- https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482\n",
203
+ "Resolving encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)... 52.92.248.169, 52.92.211.49, 3.5.80.18, ...\n",
204
+ "Connecting to encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)|52.92.248.169|:443... connected.\n",
205
+ "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n",
206
+ "\n",
207
+ " The file is already fully retrieved; nothing to do.\n",
208
+ "\n"
209
+ ]
210
+ }
211
+ ],
212
  "source": [
213
+ "# Download bigwig files\n",
214
+ "for bigwig_url in config[\"bigwig_url_list\"]:\n",
215
+ " filename = extract_filename_from_url(bigwig_url)\n",
216
+ " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
217
+ " print(f\"Downloading {filename}...\")\n",
218
+ " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)"
219
  ]
220
  },
221
  {
222
  "cell_type": "code",
223
+ "execution_count": 8,
224
  "metadata": {},
225
  "outputs": [],
226
  "source": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  "chrom_splits = {\n",
228
+ " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
229
+ " \"val\": ['chr22'],\n",
230
+ " \"test\": ['chr21']\n",
231
  "}"
232
  ]
233
  },
 
240
  },
241
  {
242
  "cell_type": "code",
243
+ "execution_count": 11,
244
  "metadata": {},
245
  "outputs": [],
246
  "source": [
 
266
  " model_name: str,\n",
267
  " bigwig_track_names: List[str],\n",
268
  " keep_target_center_fraction: float = 0.375,\n",
 
269
  " ):\n",
270
  " super().__init__()\n",
271
  " \n",
272
  " # Load config and model\n",
273
  " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
274
+ " self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
275
+ " model_name, \n",
276
+ " trust_remote_code=True,\n",
277
+ " config=self.config\n",
278
+ " )\n",
 
 
 
 
 
 
 
279
  " \n",
280
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
281
  "\n",
 
308
  },
309
  {
310
  "cell_type": "code",
311
+ "execution_count": 12,
312
  "metadata": {},
313
  "outputs": [
314
  {
 
335
  " model_name=config[\"model_name\"],\n",
336
  " bigwig_track_names=config[\"bigwig_file_ids\"],\n",
337
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
 
338
  ")\n",
339
  "model = model.to(device)\n",
340
  "model.train()\n",
 
353
  },
354
  {
355
  "cell_type": "code",
356
+ "execution_count": 17,
357
  "metadata": {},
358
  "outputs": [],
359
  "source": [
 
397
  " sequence_length: int,\n",
398
  " num_samples: int,\n",
399
  " tokenizer: AutoTokenizer,\n",
 
400
  " keep_target_center_fraction: float = 1.0,\n",
401
  " num_tracks: int = 1,\n",
402
  " ):\n",
 
412
  " self.tokenizer = tokenizer\n",
413
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
414
  " self.num_tracks = num_tracks\n",
 
415
  " self.chroms = chroms\n",
 
416
  "\n",
417
  " # Intersect lengths between FASTA and bigWig for safety\n",
418
  " bw_chrom_lengths = self.bw_list[0].chroms() # dict: chrom -> length\n",
 
421
  " self.chrom_lengths = {}\n",
422
  "\n",
423
  " for c in chroms:\n",
424
+ " if c not in bw_chrom_lengths or c not in self.fasta:\n",
 
 
 
425
  " continue\n",
426
  "\n",
427
+ " fa_len = len(self.fasta[c])\n",
428
  " bw_len = bw_chrom_lengths[c]\n",
429
  " L = min(fa_len, bw_len)\n",
430
  "\n",
 
447
  " start = random.randint(0, max_start)\n",
448
  " end = start + self.sequence_length\n",
449
  "\n",
 
 
 
450
  " # Sequence\n",
451
+ " seq = self.fasta[chrom][start:end] # string slice\n",
452
  " tokens = self.tokenizer(\n",
453
  " seq,\n",
454
  " return_tensors=\"pt\", # Returns a dict of PyTorch tensors\n",
 
486
  },
487
  {
488
  "cell_type": "code",
489
+ "execution_count": 18,
490
  "metadata": {},
491
  "outputs": [
492
  {
 
500
  }
501
  ],
502
  "source": [
 
 
 
503
  "create_dataset_fn = functools.partial(\n",
504
  " GenomeBigWigDataset,\n",
505
  " fasta_path=fasta_path,\n",
506
  " bigwig_path_list=bigwig_path_list,\n",
507
  " sequence_length=config[\"sequence_length\"],\n",
508
  " tokenizer=tokenizer,\n",
 
509
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
510
  " num_tracks=len(config[\"bigwig_file_ids\"]),\n",
511
  ")\n",
 
561
  },
562
  {
563
  "cell_type": "code",
564
+ "execution_count": null,
565
  "metadata": {},
566
  "outputs": [],
567
  "source": [
 
589
  },
590
  {
591
  "cell_type": "code",
592
+ "execution_count": null,
593
  "metadata": {},
594
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  "source": [
596
  "# Calculate gradient accumulation steps and effective batch size\n",
597
  "num_devices = 1 # Single device for now\n",
 
663
  },
664
  {
665
  "cell_type": "code",
666
+ "execution_count": null,
667
  "metadata": {},
668
  "outputs": [],
669
  "source": [
 
753
  },
754
  {
755
  "cell_type": "code",
756
+ "execution_count": null,
757
  "metadata": {},
758
  "outputs": [],
759
  "source": [
 
771
  },
772
  {
773
  "cell_type": "code",
774
+ "execution_count": null,
775
  "metadata": {},
776
+ "outputs": [],
 
 
 
 
 
 
 
 
777
  "source": [
778
  "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
779
  " \"\"\"\n",
 
899
  },
900
  {
901
  "cell_type": "code",
902
+ "execution_count": null,
903
  "metadata": {},
904
  "outputs": [],
905
  "source": [
 
975
  },
976
  {
977
  "cell_type": "code",
978
+ "execution_count": null,
979
  "metadata": {},
980
  "outputs": [],
981
  "source": [
 
1061
  },
1062
  {
1063
  "cell_type": "code",
1064
+ "execution_count": null,
1065
  "metadata": {},
1066
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1067
  "source": [
1068
  "# Training loop (step-based with gradient accumulation)\n",
1069
  "print(\"Starting training...\")\n",
 
1174
  },
1175
  {
1176
  "cell_type": "code",
1177
+ "execution_count": null,
1178
  "metadata": {},
1179
  "outputs": [],
1180
  "source": [
 
1218
  "cell_type": "code",
1219
  "execution_count": null,
1220
  "metadata": {},
1221
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1222
  "source": [
1223
  "print(\"\\n\" + \"=\"*50)\n",
1224
  "print(\"Test Set Evaluation\")\n",