ybornachot committed on
Commit
b04b4fa
·
1 Parent(s): 8849cef

fix: metrics correction

Browse files
Files changed (1) hide show
  1. notebooks/03_fine_tuning.ipynb +110 -158
notebooks/03_fine_tuning.ipynb CHANGED
@@ -26,7 +26,7 @@
26
  },
27
  {
28
  "cell_type": "code",
29
- "execution_count": null,
30
  "metadata": {},
31
  "outputs": [
32
  {
@@ -66,7 +66,7 @@
66
  },
67
  {
68
  "cell_type": "code",
69
- "execution_count": null,
70
  "metadata": {},
71
  "outputs": [
72
  {
@@ -112,7 +112,7 @@
112
  " # General\n",
113
  " \"seed\": 42,\n",
114
  " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
115
- " \"num_workers\": 4, # Number of worker processes for DataLoader\n",
116
  "}\n",
117
  "\n",
118
  "# Set random seed\n",
@@ -132,27 +132,9 @@
132
  },
133
  {
134
  "cell_type": "code",
135
- "execution_count": 2,
136
  "metadata": {},
137
- "outputs": [
138
- {
139
- "name": "stdout",
140
- "output_type": "stream",
141
- "text": [
142
- "--2025-12-09 18:33:50-- https://ftp.ncbi.nlm.nih.gov/genomes/refseq/vertebrate_mammalian/Homo_sapiens/latest_assembly_versions/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz\n",
143
- "Resolving ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)... 2607:f220:41e:250::7, 2607:f220:41e:250::11, 2607:f220:41e:250::12, ...\n",
144
- "Connecting to ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)|2607:f220:41e:250::7|:443... connected.\n",
145
- "HTTP request sent, awaiting response... 200 OK\n",
146
- "Length: 972898531 (928M) [application/x-gzip]\n",
147
- "Saving to: 'GCF_000001405.40_GRCh38.p14_genomic.fna.gz'\n",
148
- "\n",
149
- "GCF_000001405.40_GR 100%[===================>] 927.83M 18.4MB/s in 51s \n",
150
- "\n",
151
- "2025-12-09 18:34:42 (18.0 MB/s) - 'GCF_000001405.40_GRCh38.p14_genomic.fna.gz' saved [972898531/972898531]\n",
152
- "\n"
153
- ]
154
- }
155
- ],
156
  "source": [
157
  "!wget -c https://ftp.ncbi.nlm.nih.gov/genomes/refseq/vertebrate_mammalian/Homo_sapiens/latest_assembly_versions/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz \\\n",
158
  "&& gunzip -f GCF_000001405.40_GRCh38.p14_genomic.fna.gz"
@@ -160,22 +142,9 @@
160
  },
161
  {
162
  "cell_type": "code",
163
- "execution_count": 16,
164
  "metadata": {},
165
- "outputs": [
166
- {
167
- "name": "stdout",
168
- "output_type": "stream",
169
- "text": [
170
- "--2025-12-09 22:13:59-- https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL\n",
171
- "Resolving www.encodeproject.org (www.encodeproject.org)... 34.211.244.144\n",
172
- "Connecting to www.encodeproject.org (www.encodeproject.org)|34.211.244.144|:443... connected.\n",
173
- "HTTP request sent, awaiting response... 404 Not Found\n",
174
- "2025-12-09 22:14:00 ERROR 404: Not Found.\n",
175
- "\n"
176
- ]
177
- }
178
- ],
179
  "source": [
180
  "!wget -O ENCFF884LDL \"$(curl -s https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL | sed -n 's/.*href=\\\"\\([^\\\"]*ENCFF884LDL[^\\\"]*\\)\\\".*/\\1/p')\" \\\n",
181
  "&& echo \"Downloaded ENCFF884LDL\""
@@ -183,39 +152,16 @@
183
  },
184
  {
185
  "cell_type": "code",
186
- "execution_count": 4,
187
  "metadata": {},
188
- "outputs": [
189
- {
190
- "name": "stdout",
191
- "output_type": "stream",
192
- "text": [
193
- "--2025-12-09 18:41:24-- https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\n",
194
- "Resolving www.encodeproject.org (www.encodeproject.org)... 34.211.244.144\n",
195
- "Connecting to www.encodeproject.org (www.encodeproject.org)|34.211.244.144|:443... connected.\n",
196
- "HTTP request sent, awaiting response... 307 Temporary Redirect\n",
197
- "Location: https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNX3AXUNFS3&Signature=Ca%2Bz1PL7zdbGzyRggtvN686q4oE%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEPr%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEaCXVzLXdlc3QtMiJGMEQCIAggXesBwHBuGSivVx0RvF5f2vZbk09TPBdf%2FYJUt%2BLWAiAKrh58c%2Bm%2F%2ByrujtQxgltFGzGo5qXSWv%2B0zPaa3gKUTCq8BQjC%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMa%2FegIMq%2By2ql10quKpAFATT6r6oWCSXqrBd2gfR8S1QNvY%2BKjvbr%2BvS2ifnF5NqfByJgZxdXVC65WI8fUYqgspTQB5Az%2BE5O4jR8EnFBv%2FjO6DqrWkQQOUsHUFFGXJjarvCPdYjqJmV9SyeTuzNeV0xwFX%2Fleq1%2F4f3eAV81Nv5J%2B8UeHYn5GxtwS%2BjhzVsCJ8tqAo6yRi0wPteU8nb8yLJb%2F%2FWvQLZce7Yc9%2BZkuxKKGoEKQstRSGLCh%2FjtnNfvGp0x20mj5C7wsk61LHBJlNV3KVD7qZHZ57N1CBx5XNuJ%2BkJp6eBU8htM%2FY73tBkp4w5xHNyI5F%2B7JxjDDjo4YOikyLKk7tnTmWfC2lEGXXx33D8xyBxi4oNnK76R0N296GRSHS22esmo12YGK5QNvVbU4SuZUUWjVcrGFqtN%2F7ff1K%2FdqiRyh6TDvXbOUf%2Bk691iqwRY34LbXoJsOzcux5wwQGbHfcSdGrp2Y3KtpDGEdHiiTVHJeHi9pxBvlwvmjM5lXjJjtjOFqXIF%2F%2FygXdl4wUIMMsuinPWpA5xVIk4kg1Bv5XVNuqcPJl7Dl2ZdRzQvwc0Xl5dBL39ZAz9MvCffPV2Fb3hiL5vIQJ2ySdDnqXDhTuUsWGy81MltoznoOVbvuu64FAEp4GdwnwRH1ILlVOKQ1bHR5FSHqb8OFVqAQezRljaJY2ds1J2HMAJ2AJtg3k8XNQScR%2FutxWkI3pYDnAQQQkHHw3aFWNNYbQMfyAAptJohtNGClRoTiepBUckqxpgvMXwEOTJzpUEi0sMIxMkXMWa3ncKFHQAP6P3eKxBOjW8s%2F3BXwRlbgsNdQvqDUdf2dD5KLeHfpyKbdPnG0C6yZAxBF%2Fk4jO1F2F4o533RZGF8Ww7qMc5Ij2ww%2BbPhyQY6sgG2uZfWDKxd1yRNOufiZW%2FAtmcEQg%2BtzoWnq6TxyhU0OCY%2BN7xR8HO4UaT0Od0C06PHugNQCUS6eJusR0IfSRJ7ozZJUomphTeCPXw1G%2B6RVsni%2B9lGE8SlRLTMzNvzQJv8oJNZsoi6DVWlK%2FGt7TgwxSKH8%2BVQmal7nXUqR9f8Dh7CF1KppbVtNiGDaxTIN%2F7j%2BwIFrKHIMOYhC1dt5gPFnIQwnj1%2BuyEw5FWF3hKIkD%2Bc&Expires=1765431685 [following]\n",
198
- "--2025-12-09 18:41:25-- https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNX3AXUNFS3&Signature=Ca%2Bz1PL7zdbGzyRggtvN686q4oE%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEPr%2F%2F%2F%2F%2F%2F%2F%2F%2F%2FwEaCXVzLXdlc3QtMiJGMEQCIAggXesBwHBuGSivVx0RvF5f2vZbk09TPBdf%2FYJUt%2BLWAiAKrh58c%2Bm%2F%2ByrujtQxgltFGzGo5qXSWv%2B0zPaa3gKUTCq8BQjC%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMa%2FegIMq%2By2ql10quKpAFATT6r6oWCSXqrBd2gfR8S1QNvY%2BKjvbr%2BvS2ifnF5NqfByJgZxdXVC65WI8fUYqgspTQB5Az%2BE5O4jR8EnFBv%2FjO6DqrWkQQOUsHUFFGXJjarvCPdYjqJmV9SyeTuzNeV0xwFX%2Fleq1%2F4f3eAV81Nv5J%2B8UeHYn5GxtwS%2BjhzVsCJ8tqAo6yRi0wPteU8nb8yLJb%2F%2FWvQLZce7Yc9%2BZkuxKKGoEKQstRSGLCh%2FjtnNfvGp0x20mj5C7wsk61LHBJlNV3KVD7qZHZ57N1CBx5XNuJ%2BkJp6eBU8htM%2FY73tBkp4w5xHNyI5F%2B7JxjDDjo4YOikyLKk7tnTmWfC2lEGXXx33D8xyBxi4oNnK76R0N296GRSHS22esmo12YGK5QNvVbU4SuZUUWjVcrGFqtN%2F7ff1K%2FdqiRyh6TDvXbOUf%2Bk691iqwRY34LbXoJsOzcux5wwQGbHfcSdGrp2Y3KtpDGEdHiiTVHJeHi9pxBvlwvmjM5lXjJjtjOFqXIF%2F%2FygXdl4wUIMMsuinPWpA5xVIk4kg1Bv5XVNuqcPJl7Dl2ZdRzQvwc0Xl5dBL39ZAz9MvCffPV2Fb3hiL5vIQJ2ySdDnqXDhTuUsWGy81MltoznoOVbvuu64FAEp4GdwnwRH1ILlVOKQ1bHR5FSHqb8OFVqAQezRljaJY2ds1J2HMAJ2AJtg3k8XNQScR%2FutxWkI3pYDnAQQQkHHw3aFWNNYbQMfyAAptJohtNGClRoTiepBUckqxpgvMXwEOTJzpUEi0sMIxMkXMWa3ncKFHQAP6P3eKxBOjW8s%2F3BXwRlbgsNdQvqDUdf2dD5KLeHfpyKbdPnG0C6yZAxBF%2Fk4jO1F2F4o533RZGF8Ww7qMc5Ij2ww%2BbPhyQY6sgG2uZfWDKxd1yRNOufiZW%2FAtmcEQg%2BtzoWnq6TxyhU0OCY%2BN7xR8HO4UaT0Od0C06PHugNQCUS6eJusR0IfSRJ7ozZJUomphTeCPXw1G%2B6RVsni%2B9lGE8SlRLTMzNvzQJv8oJNZsoi6DVWlK%2FGt7TgwxSKH8%2BVQmal7nXUqR9f8Dh7CF1KppbVtNiGDaxTIN%2F7j%2BwIFrKHIMOYhC1dt5gPFnIQwnj1%2BuyEw5FWF3hKIkD%2Bc&Expires=1765431685\n",
199
- "Resolving encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)... 3.5.81.13, 52.92.211.217, 52.92.197.57, ...\n",
200
- "Connecting to encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)|3.5.81.13|:443... connected.\n",
201
- "HTTP request sent, awaiting response... 200 OK\n",
202
- "Length: 568139478 (542M) [binary/octet-stream]\n",
203
- "Saving to: 'ENCFF884LDL.bigWig'\n",
204
- "\n",
205
- "ENCFF884LDL.bigWig 100%[===================>] 541.82M 9.64MB/s in 79s \n",
206
- "\n",
207
- "2025-12-09 18:42:45 (6.88 MB/s) - 'ENCFF884LDL.bigWig' saved [568139478/568139478]\n",
208
- "\n"
209
- ]
210
- }
211
- ],
212
  "source": [
213
  "!wget -c https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig"
214
  ]
215
  },
216
  {
217
  "cell_type": "code",
218
- "execution_count": 5,
219
  "metadata": {},
220
  "outputs": [],
221
  "source": [
@@ -265,7 +211,7 @@
265
  },
266
  {
267
  "cell_type": "code",
268
- "execution_count": 71,
269
  "metadata": {},
270
  "outputs": [],
271
  "source": [
@@ -341,7 +287,7 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 72,
345
  "metadata": {},
346
  "outputs": [
347
  {
@@ -387,7 +333,7 @@
387
  },
388
  {
389
  "cell_type": "code",
390
- "execution_count": null,
391
  "metadata": {},
392
  "outputs": [],
393
  "source": [
@@ -500,20 +446,22 @@
500
  " # For a single input string, its shape is typically (1, len(seq))\n",
501
  "\n",
502
  " # Signal from bigWig tracks (numpy array) -> torch tensor\n",
503
- " bigwig_targets = [\n",
504
  " self.bw_list[i].values(chrom, start, end, numpy=True)\n",
505
  " for i in range(len(self.bw_list))\n",
506
- " ]\n",
 
 
507
  " # pyBigWig returns NaN where no data; turn NaN into 0\n",
508
  " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
509
  " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
510
  " \n",
511
  " # Crop targets to center fraction\n",
512
  " if self.keep_target_center_fraction < 1.0:\n",
513
- " seq_len = bigwig_targets.shape[0]\n",
514
  " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
515
  " target_length = seq_len - 2 * target_offset\n",
516
- " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length]\n",
517
  "\n",
518
  " sample = {\n",
519
  " \"tokens\": tokens,\n",
@@ -527,7 +475,7 @@
527
  },
528
  {
529
  "cell_type": "code",
530
- "execution_count": null,
531
  "metadata": {},
532
  "outputs": [
533
  {
@@ -535,7 +483,8 @@
535
  "output_type": "stream",
536
  "text": [
537
  "Train samples: 100\n",
538
- "Val samples: 10\n"
 
539
  ]
540
  }
541
  ],
@@ -605,7 +554,7 @@
605
  },
606
  {
607
  "cell_type": "code",
608
- "execution_count": 59,
609
  "metadata": {},
610
  "outputs": [],
611
  "source": [
@@ -633,7 +582,7 @@
633
  },
634
  {
635
  "cell_type": "code",
636
- "execution_count": 60,
637
  "metadata": {},
638
  "outputs": [
639
  {
@@ -727,7 +676,7 @@
727
  },
728
  {
729
  "cell_type": "code",
730
- "execution_count": null,
731
  "metadata": {},
732
  "outputs": [],
733
  "source": [
@@ -794,20 +743,20 @@
794
  " # Scaled metrics: per-track Pearson correlations\n",
795
  " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_scaled)):\n",
796
  " corr = metric.compute().item()\n",
797
- " metrics_dict[f\"{track_name}/pearson_scaled\"] = corr\n",
798
  " \n",
799
  " # Scaled metrics: mean Pearson correlation\n",
800
  " correlations_scaled = [metric.compute().item() for metric in self.pearson_metrics_scaled]\n",
801
- " metrics_dict[\"mean/pearson_scaled\"] = np.nanmean(correlations_scaled)\n",
802
  " \n",
803
  " # Raw metrics: per-track Pearson correlations\n",
804
  " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_raw)):\n",
805
  " corr = metric.compute().item()\n",
806
- " metrics_dict[f\"{track_name}/pearson_raw\"] = corr\n",
807
  " \n",
808
  " # Raw metrics: mean Pearson correlation\n",
809
  " correlations_raw = [metric.compute().item() for metric in self.pearson_metrics_raw]\n",
810
- " metrics_dict[\"mean/pearson_raw\"] = np.nanmean(correlations_raw)\n",
811
  " \n",
812
  " # Mean loss\n",
813
  " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
@@ -817,7 +766,7 @@
817
  },
818
  {
819
  "cell_type": "code",
820
- "execution_count": null,
821
  "metadata": {},
822
  "outputs": [],
823
  "source": [
@@ -835,7 +784,7 @@
835
  },
836
  {
837
  "cell_type": "code",
838
- "execution_count": 63,
839
  "metadata": {},
840
  "outputs": [
841
  {
@@ -971,7 +920,7 @@
971
  },
972
  {
973
  "cell_type": "code",
974
- "execution_count": 64,
975
  "metadata": {},
976
  "outputs": [],
977
  "source": [
@@ -1047,7 +996,7 @@
1047
  },
1048
  {
1049
  "cell_type": "code",
1050
- "execution_count": null,
1051
  "metadata": {},
1052
  "outputs": [],
1053
  "source": [
@@ -1133,7 +1082,7 @@
1133
  },
1134
  {
1135
  "cell_type": "code",
1136
- "execution_count": null,
1137
  "metadata": {},
1138
  "outputs": [
1139
  {
@@ -1142,78 +1091,63 @@
1142
  "text": [
1143
  "Starting training...\n",
1144
  "Training for 32 steps with 2 gradient accumulation steps\n",
1145
- "\n"
1146
- ]
1147
- },
1148
- {
1149
- "name": "stderr",
1150
- "output_type": "stream",
1151
- "text": [
1152
- "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
1153
- "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
1154
- " warnings.warn(error_message)\n"
1155
- ]
1156
- },
1157
- {
1158
- "name": "stdout",
1159
- "output_type": "stream",
1160
- "text": [
1161
- "Step 0/32 | Loss: 1.5993 | Mean Pearson: -0.0848 | LR: 1.17e-09 | Tokens: 4,096\n",
1162
  "\n",
1163
  "Running validation at step 0...\n",
1164
- " Validation Loss: 0.6607\n",
1165
- " Validation Mean Pearson: -0.0054\n",
1166
- " ENCFF884LDL/pearson: -0.0054\n",
1167
- "Step 2/32 | Loss: 0.3453 | Mean Pearson: -0.2111 | LR: 2.50e-09 | Tokens: 12,288\n",
1168
- "Step 4/32 | Loss: 1.0248 | Mean Pearson: -0.0197 | LR: 2.41e-09 | Tokens: 20,480\n",
1169
  "\n",
1170
  "Running validation at step 4...\n",
1171
- " Validation Loss: 0.5158\n",
1172
- " Validation Mean Pearson: 0.0160\n",
1173
- " ENCFF884LDL/pearson: 0.0160\n",
1174
- "Step 6/32 | Loss: 0.3720 | Mean Pearson: 0.0140 | LR: 2.32e-09 | Tokens: 28,672\n",
1175
- "Step 8/32 | Loss: 0.4894 | Mean Pearson: -0.0300 | LR: 2.23e-09 | Tokens: 36,864\n",
1176
  "\n",
1177
  "Running validation at step 8...\n",
1178
- " Validation Loss: 0.5024\n",
1179
- " Validation Mean Pearson: -0.0443\n",
1180
- " ENCFF884LDL/pearson: -0.0443\n",
1181
- "Step 10/32 | Loss: 0.4039 | Mean Pearson: -0.0783 | LR: 2.13e-09 | Tokens: 45,056\n",
1182
- "Step 12/32 | Loss: 0.4974 | Mean Pearson: 0.0227 | LR: 2.02e-09 | Tokens: 53,248\n",
1183
  "\n",
1184
  "Running validation at step 12...\n",
1185
- " Validation Loss: 0.5107\n",
1186
- " Validation Mean Pearson: -0.0596\n",
1187
- " ENCFF884LDL/pearson: -0.0596\n",
1188
- "Step 14/32 | Loss: 0.2984 | Mean Pearson: -0.0820 | LR: 1.91e-09 | Tokens: 61,440\n",
1189
- "Step 16/32 | Loss: 0.5219 | Mean Pearson: -0.0668 | LR: 1.80e-09 | Tokens: 69,632\n",
1190
  "\n",
1191
  "Running validation at step 16...\n",
1192
- " Validation Loss: 0.8410\n",
1193
- " Validation Mean Pearson: 0.0041\n",
1194
- " ENCFF884LDL/pearson: 0.0041\n",
1195
- "Step 18/32 | Loss: 0.3663 | Mean Pearson: 0.0888 | LR: 1.67e-09 | Tokens: 77,824\n",
1196
- "Step 20/32 | Loss: 0.4024 | Mean Pearson: -0.0628 | LR: 1.54e-09 | Tokens: 86,016\n",
1197
  "\n",
1198
  "Running validation at step 20...\n",
1199
- " Validation Loss: 0.4043\n",
1200
- " Validation Mean Pearson: -0.1108\n",
1201
- " ENCFF884LDL/pearson: -0.1108\n",
1202
- "Step 22/32 | Loss: 0.4096 | Mean Pearson: -0.0249 | LR: 1.39e-09 | Tokens: 94,208\n",
1203
- "Step 24/32 | Loss: 0.3930 | Mean Pearson: -0.0779 | LR: 1.23e-09 | Tokens: 102,400\n",
1204
  "\n",
1205
  "Running validation at step 24...\n",
1206
- " Validation Loss: 0.3426\n",
1207
- " Validation Mean Pearson: 0.0236\n",
1208
- " ENCFF884LDL/pearson: 0.0236\n",
1209
- "Step 26/32 | Loss: 0.4457 | Mean Pearson: -0.0219 | LR: 1.04e-09 | Tokens: 110,592\n",
1210
- "Step 28/32 | Loss: 0.4520 | Mean Pearson: 0.0580 | LR: 8.04e-10 | Tokens: 118,784\n",
1211
  "\n",
1212
  "Running validation at step 28...\n",
1213
- " Validation Loss: 0.3757\n",
1214
- " Validation Mean Pearson: 0.0050\n",
1215
- " ENCFF884LDL/pearson: 0.0050\n",
1216
- "Step 30/32 | Loss: 0.9341 | Mean Pearson: -0.0122 | LR: 4.64e-10 | Tokens: 126,976\n",
1217
  "\n",
1218
  "Training completed after 32 steps!\n"
1219
  ]
@@ -1290,7 +1224,7 @@
1290
  " current_lr = scheduler.get_last_lr()[0] if scheduler else config[\"learning_rate\"]\n",
1291
  " print(f\"Step {optimizer_step_idx + 1}/{num_steps_training} | \"\n",
1292
  " f\"Loss: {avg_loss:.4f} | \"\n",
1293
- " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
1294
  " f\"LR: {current_lr:.2e} | \"\n",
1295
  " f\"Tokens: {num_tokens_seen:,}\")\n",
1296
  " train_metrics.reset()\n",
@@ -1311,9 +1245,9 @@
1311
  " # Print validation metrics\n",
1312
  " val_metrics_dict = val_metrics.compute()\n",
1313
  " print(f\" Validation Loss: {np.mean(val_losses):.4f}\")\n",
1314
- " print(f\" Validation Mean Pearson: {val_metrics_dict['mean/pearson']:.4f}\")\n",
1315
  " for track_name in config[\"bigwig_file_ids\"]:\n",
1316
- " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
1317
  " \n",
1318
  " model.train() # Back to training mode\n",
1319
  "\n",
@@ -1329,7 +1263,7 @@
1329
  },
1330
  {
1331
  "cell_type": "code",
1332
- "execution_count": null,
1333
  "metadata": {},
1334
  "outputs": [],
1335
  "source": [
@@ -1381,18 +1315,36 @@
1381
  "\n",
1382
  "==================================================\n",
1383
  "Test Set Evaluation\n",
1384
- "==================================================\n"
 
1385
  ]
1386
  },
1387
  {
1388
- "ename": "NameError",
1389
- "evalue": "name 'test_dataset' is not defined",
1390
- "output_type": "error",
1391
- "traceback": [
1392
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
1393
- "\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
1394
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[68]\u001b[39m\u001b[32m, line 10\u001b[39m\n\u001b[32m 8\u001b[39m \u001b[38;5;66;03m# Calculate number of test steps (based on deepspeed pipeline)\u001b[39;00m\n\u001b[32m 9\u001b[39m test_batch_size = config[\u001b[33m\"\u001b[39m\u001b[33mbatch_size\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m num_test_samples = \u001b[38;5;28mlen\u001b[39m(\u001b[43mtest_dataset\u001b[49m)\n\u001b[32m 11\u001b[39m num_test_steps = num_test_samples // test_batch_size\n\u001b[32m 13\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mRunning test evaluation with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_test_steps\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m steps (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_test_samples\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m samples)\u001b[39m\u001b[33m\"\u001b[39m)\n",
1395
- "\u001b[31mNameError\u001b[39m: name 'test_dataset' is not defined"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1396
  ]
1397
  }
1398
  ],
@@ -1432,14 +1384,14 @@
1432
  "print(\"Test Set Results\")\n",
1433
  "print(\"=\"*50)\n",
1434
  "print(f\"\\nScaled Metrics (scaled predictions vs scaled targets):\")\n",
1435
- "print(f\" Mean Pearson (scaled): {test_metrics_dict['mean/pearson_scaled']:.4f}\")\n",
1436
  "for track_name in config[\"bigwig_file_ids\"]:\n",
1437
- " print(f\" {track_name}/pearson_scaled: {test_metrics_dict[f'{track_name}/pearson_scaled']:.4f}\")\n",
1438
  "\n",
1439
  "print(f\"\\nRaw Metrics (raw predictions vs raw targets):\")\n",
1440
- "print(f\" Mean Pearson (raw): {test_metrics_dict['mean/pearson_raw']:.4f}\")\n",
1441
  "for track_name in config[\"bigwig_file_ids\"]:\n",
1442
- " print(f\" {track_name}/pearson_raw: {test_metrics_dict[f'{track_name}/pearson_raw']:.4f}\")\n",
1443
  "print(\"=\"*50)"
1444
  ]
1445
  },
 
26
  },
27
  {
28
  "cell_type": "code",
29
+ "execution_count": 1,
30
  "metadata": {},
31
  "outputs": [
32
  {
 
66
  },
67
  {
68
  "cell_type": "code",
69
+ "execution_count": 2,
70
  "metadata": {},
71
  "outputs": [
72
  {
 
112
  " # General\n",
113
  " \"seed\": 42,\n",
114
  " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
115
+ " \"num_workers\": 0, # Number of worker processes for DataLoader\n",
116
  "}\n",
117
  "\n",
118
  "# Set random seed\n",
 
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": null,
136
  "metadata": {},
137
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  "source": [
139
  "!wget -c https://ftp.ncbi.nlm.nih.gov/genomes/refseq/vertebrate_mammalian/Homo_sapiens/latest_assembly_versions/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz \\\n",
140
  "&& gunzip -f GCF_000001405.40_GRCh38.p14_genomic.fna.gz"
 
142
  },
143
  {
144
  "cell_type": "code",
145
+ "execution_count": null,
146
  "metadata": {},
147
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  "source": [
149
  "!wget -O ENCFF884LDL \"$(curl -s https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL | sed -n 's/.*href=\\\"\\([^\\\"]*ENCFF884LDL[^\\\"]*\\)\\\".*/\\1/p')\" \\\n",
150
  "&& echo \"Downloaded ENCFF884LDL\""
 
152
  },
153
  {
154
  "cell_type": "code",
155
+ "execution_count": null,
156
  "metadata": {},
157
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  "source": [
159
  "!wget -c https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig"
160
  ]
161
  },
162
  {
163
  "cell_type": "code",
164
+ "execution_count": 3,
165
  "metadata": {},
166
  "outputs": [],
167
  "source": [
 
211
  },
212
  {
213
  "cell_type": "code",
214
+ "execution_count": 4,
215
  "metadata": {},
216
  "outputs": [],
217
  "source": [
 
287
  },
288
  {
289
  "cell_type": "code",
290
+ "execution_count": 5,
291
  "metadata": {},
292
  "outputs": [
293
  {
 
333
  },
334
  {
335
  "cell_type": "code",
336
+ "execution_count": 34,
337
  "metadata": {},
338
  "outputs": [],
339
  "source": [
 
446
  " # For a single input string, its shape is typically (1, len(seq))\n",
447
  "\n",
448
  " # Signal from bigWig tracks (numpy array) -> torch tensor\n",
449
+ " bigwig_targets = np.array([\n",
450
  " self.bw_list[i].values(chrom, start, end, numpy=True)\n",
451
  " for i in range(len(self.bw_list))\n",
452
+ " ]) # shape (num_tracks, seq_len)\n",
453
+ " # Transpose to (seq_len, num_tracks)\n",
454
+ " bigwig_targets = bigwig_targets.T\n",
455
  " # pyBigWig returns NaN where no data; turn NaN into 0\n",
456
  " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
457
  " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
458
  " \n",
459
  " # Crop targets to center fraction\n",
460
  " if self.keep_target_center_fraction < 1.0:\n",
461
+ " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n",
462
  " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
463
  " target_length = seq_len - 2 * target_offset\n",
464
+ " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
465
  "\n",
466
  " sample = {\n",
467
  " \"tokens\": tokens,\n",
 
475
  },
476
  {
477
  "cell_type": "code",
478
+ "execution_count": 35,
479
  "metadata": {},
480
  "outputs": [
481
  {
 
483
  "output_type": "stream",
484
  "text": [
485
  "Train samples: 100\n",
486
+ "Val samples: 10\n",
487
+ "Test samples: 10\n"
488
  ]
489
  }
490
  ],
 
554
  },
555
  {
556
  "cell_type": "code",
557
+ "execution_count": 36,
558
  "metadata": {},
559
  "outputs": [],
560
  "source": [
 
582
  },
583
  {
584
  "cell_type": "code",
585
+ "execution_count": 37,
586
  "metadata": {},
587
  "outputs": [
588
  {
 
676
  },
677
  {
678
  "cell_type": "code",
679
+ "execution_count": 38,
680
  "metadata": {},
681
  "outputs": [],
682
  "source": [
 
743
  " # Scaled metrics: per-track Pearson correlations\n",
744
  " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_scaled)):\n",
745
  " corr = metric.compute().item()\n",
746
+ " metrics_dict[f\"metrics_scaled/{track_name}/pearson\"] = corr\n",
747
  " \n",
748
  " # Scaled metrics: mean Pearson correlation\n",
749
  " correlations_scaled = [metric.compute().item() for metric in self.pearson_metrics_scaled]\n",
750
+ " metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n",
751
  " \n",
752
  " # Raw metrics: per-track Pearson correlations\n",
753
  " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_raw)):\n",
754
  " corr = metric.compute().item()\n",
755
+ " metrics_dict[f\"metrics_raw/{track_name}/pearson\"] = corr\n",
756
  " \n",
757
  " # Raw metrics: mean Pearson correlation\n",
758
  " correlations_raw = [metric.compute().item() for metric in self.pearson_metrics_raw]\n",
759
+ " metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n",
760
  " \n",
761
  " # Mean loss\n",
762
  " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
 
766
  },
767
  {
768
  "cell_type": "code",
769
+ "execution_count": 39,
770
  "metadata": {},
771
  "outputs": [],
772
  "source": [
 
784
  },
785
  {
786
  "cell_type": "code",
787
+ "execution_count": 40,
788
  "metadata": {},
789
  "outputs": [
790
  {
 
920
  },
921
  {
922
  "cell_type": "code",
923
+ "execution_count": 41,
924
  "metadata": {},
925
  "outputs": [],
926
  "source": [
 
996
  },
997
  {
998
  "cell_type": "code",
999
+ "execution_count": 42,
1000
  "metadata": {},
1001
  "outputs": [],
1002
  "source": [
 
1082
  },
1083
  {
1084
  "cell_type": "code",
1085
+ "execution_count": 43,
1086
  "metadata": {},
1087
  "outputs": [
1088
  {
 
1091
  "text": [
1092
  "Starting training...\n",
1093
  "Training for 32 steps with 2 gradient accumulation steps\n",
1094
+ "\n",
1095
+ "Step 1/32 | Loss: 0.7569 | Mean Pearson: -0.1473 | LR: 1.17e-09 | Tokens: 4,096\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1096
  "\n",
1097
  "Running validation at step 0...\n",
1098
+ " Validation Loss: 1.0152\n",
1099
+ " Validation Mean Pearson: -0.0414\n",
1100
+ " ENCFF884LDL/pearson: -0.0414\n",
1101
+ "Step 3/32 | Loss: 0.3793 | Mean Pearson: -0.0229 | LR: 2.50e-09 | Tokens: 12,288\n",
1102
+ "Step 5/32 | Loss: 0.4111 | Mean Pearson: -0.1739 | LR: 2.41e-09 | Tokens: 20,480\n",
1103
  "\n",
1104
  "Running validation at step 4...\n",
1105
+ " Validation Loss: 0.4801\n",
1106
+ " Validation Mean Pearson: 0.0120\n",
1107
+ " ENCFF884LDL/pearson: 0.0120\n",
1108
+ "Step 7/32 | Loss: 0.3404 | Mean Pearson: -0.0191 | LR: 2.32e-09 | Tokens: 28,672\n",
1109
+ "Step 9/32 | Loss: 0.3950 | Mean Pearson: 0.0090 | LR: 2.23e-09 | Tokens: 36,864\n",
1110
  "\n",
1111
  "Running validation at step 8...\n",
1112
+ " Validation Loss: 0.5865\n",
1113
+ " Validation Mean Pearson: -0.0260\n",
1114
+ " ENCFF884LDL/pearson: -0.0260\n",
1115
+ "Step 11/32 | Loss: 0.3750 | Mean Pearson: 0.0121 | LR: 2.13e-09 | Tokens: 45,056\n",
1116
+ "Step 13/32 | Loss: 0.4380 | Mean Pearson: -0.0126 | LR: 2.02e-09 | Tokens: 53,248\n",
1117
  "\n",
1118
  "Running validation at step 12...\n",
1119
+ " Validation Loss: 0.3997\n",
1120
+ " Validation Mean Pearson: 0.0093\n",
1121
+ " ENCFF884LDL/pearson: 0.0093\n",
1122
+ "Step 15/32 | Loss: 0.3469 | Mean Pearson: -0.0279 | LR: 1.91e-09 | Tokens: 61,440\n",
1123
+ "Step 17/32 | Loss: 0.5098 | Mean Pearson: -0.2044 | LR: 1.80e-09 | Tokens: 69,632\n",
1124
  "\n",
1125
  "Running validation at step 16...\n",
1126
+ " Validation Loss: 0.3752\n",
1127
+ " Validation Mean Pearson: -0.0178\n",
1128
+ " ENCFF884LDL/pearson: -0.0178\n",
1129
+ "Step 19/32 | Loss: 0.4899 | Mean Pearson: -0.0424 | LR: 1.67e-09 | Tokens: 77,824\n",
1130
+ "Step 21/32 | Loss: 0.3889 | Mean Pearson: -0.0332 | LR: 1.54e-09 | Tokens: 86,016\n",
1131
  "\n",
1132
  "Running validation at step 20...\n",
1133
+ " Validation Loss: 0.4217\n",
1134
+ " Validation Mean Pearson: -0.0205\n",
1135
+ " ENCFF884LDL/pearson: -0.0205\n",
1136
+ "Step 23/32 | Loss: 0.3392 | Mean Pearson: 0.0235 | LR: 1.39e-09 | Tokens: 94,208\n",
1137
+ "Step 25/32 | Loss: 0.4165 | Mean Pearson: 0.0033 | LR: 1.23e-09 | Tokens: 102,400\n",
1138
  "\n",
1139
  "Running validation at step 24...\n",
1140
+ " Validation Loss: 0.4363\n",
1141
+ " Validation Mean Pearson: -0.0379\n",
1142
+ " ENCFF884LDL/pearson: -0.0379\n",
1143
+ "Step 27/32 | Loss: 0.7630 | Mean Pearson: 0.0683 | LR: 1.04e-09 | Tokens: 110,592\n",
1144
+ "Step 29/32 | Loss: 0.7357 | Mean Pearson: 0.0050 | LR: 8.04e-10 | Tokens: 118,784\n",
1145
  "\n",
1146
  "Running validation at step 28...\n",
1147
+ " Validation Loss: 0.6629\n",
1148
+ " Validation Mean Pearson: -0.0370\n",
1149
+ " ENCFF884LDL/pearson: -0.0370\n",
1150
+ "Step 31/32 | Loss: 0.3690 | Mean Pearson: -0.0808 | LR: 4.64e-10 | Tokens: 126,976\n",
1151
  "\n",
1152
  "Training completed after 32 steps!\n"
1153
  ]
 
1224
  " current_lr = scheduler.get_last_lr()[0] if scheduler else config[\"learning_rate\"]\n",
1225
  " print(f\"Step {optimizer_step_idx + 1}/{num_steps_training} | \"\n",
1226
  " f\"Loss: {avg_loss:.4f} | \"\n",
1227
+ " f\"Mean Pearson: {train_metrics_dict['metrics_scaled/mean/pearson']:.4f} | \"\n",
1228
  " f\"LR: {current_lr:.2e} | \"\n",
1229
  " f\"Tokens: {num_tokens_seen:,}\")\n",
1230
  " train_metrics.reset()\n",
 
1245
  " # Print validation metrics\n",
1246
  " val_metrics_dict = val_metrics.compute()\n",
1247
  " print(f\" Validation Loss: {np.mean(val_losses):.4f}\")\n",
1248
+ " print(f\" Validation Mean Pearson: {val_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n",
1249
  " for track_name in config[\"bigwig_file_ids\"]:\n",
1250
+ " print(f\" {track_name}/pearson: {val_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n",
1251
  " \n",
1252
  " model.train() # Back to training mode\n",
1253
  "\n",
 
1263
  },
1264
  {
1265
  "cell_type": "code",
1266
+ "execution_count": 44,
1267
  "metadata": {},
1268
  "outputs": [],
1269
  "source": [
 
1315
  "\n",
1316
  "==================================================\n",
1317
  "Test Set Evaluation\n",
1318
+ "==================================================\n",
1319
+ "Running test evaluation with 5 steps (10 samples)\n"
1320
  ]
1321
  },
1322
  {
1323
+ "name": "stderr",
1324
+ "output_type": "stream",
1325
+ "text": [
1326
+ "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
1327
+ "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
1328
+ " warnings.warn(error_message)\n"
1329
+ ]
1330
+ },
1331
+ {
1332
+ "name": "stdout",
1333
+ "output_type": "stream",
1334
+ "text": [
1335
+ "\n",
1336
+ "==================================================\n",
1337
+ "Test Set Results\n",
1338
+ "==================================================\n",
1339
+ "\n",
1340
+ "Scaled Metrics (scaled predictions vs scaled targets):\n",
1341
+ " Mean Pearson (scaled): -0.0362\n",
1342
+ " metrics_scaled/ENCFF884LDL/pearson: -0.0362\n",
1343
+ "\n",
1344
+ "Raw Metrics (raw predictions vs raw targets):\n",
1345
+ " Mean Pearson (raw): -0.0362\n",
1346
+ " metrics_raw/ENCFF884LDL/pearson: -0.0362\n",
1347
+ "==================================================\n"
1348
  ]
1349
  }
1350
  ],
 
1384
  "print(\"Test Set Results\")\n",
1385
  "print(\"=\"*50)\n",
1386
  "print(f\"\\nScaled Metrics (scaled predictions vs scaled targets):\")\n",
1387
+ "print(f\" Mean Pearson (scaled): {test_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n",
1388
  "for track_name in config[\"bigwig_file_ids\"]:\n",
1389
+ " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n",
1390
  "\n",
1391
  "print(f\"\\nRaw Metrics (raw predictions vs raw targets):\")\n",
1392
+ "print(f\" Mean Pearson (raw): {test_metrics_dict['metrics_raw/mean/pearson']:.4f}\")\n",
1393
  "for track_name in config[\"bigwig_file_ids\"]:\n",
1394
+ " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n",
1395
  "print(\"=\"*50)"
1396
  ]
1397
  },