Spaces:
Running
Running
Commit
·
89ffd35
1
Parent(s):
6e05130
fix: with cell outputs
Browse files- notebooks/03_fine_tuning.ipynb +156 -21
notebooks/03_fine_tuning.ipynb
CHANGED
|
@@ -561,7 +561,7 @@
|
|
| 561 |
},
|
| 562 |
{
|
| 563 |
"cell_type": "code",
|
| 564 |
-
"execution_count":
|
| 565 |
"metadata": {},
|
| 566 |
"outputs": [],
|
| 567 |
"source": [
|
|
@@ -589,9 +589,29 @@
|
|
| 589 |
},
|
| 590 |
{
|
| 591 |
"cell_type": "code",
|
| 592 |
-
"execution_count":
|
| 593 |
"metadata": {},
|
| 594 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
"source": [
|
| 596 |
"# Calculate gradient accumulation steps and effective batch size\n",
|
| 597 |
"num_devices = 1 # Single device for now\n",
|
|
@@ -663,7 +683,7 @@
|
|
| 663 |
},
|
| 664 |
{
|
| 665 |
"cell_type": "code",
|
| 666 |
-
"execution_count":
|
| 667 |
"metadata": {},
|
| 668 |
"outputs": [],
|
| 669 |
"source": [
|
|
@@ -753,7 +773,7 @@
|
|
| 753 |
},
|
| 754 |
{
|
| 755 |
"cell_type": "code",
|
| 756 |
-
"execution_count":
|
| 757 |
"metadata": {},
|
| 758 |
"outputs": [],
|
| 759 |
"source": [
|
|
@@ -771,9 +791,17 @@
|
|
| 771 |
},
|
| 772 |
{
|
| 773 |
"cell_type": "code",
|
| 774 |
-
"execution_count":
|
| 775 |
"metadata": {},
|
| 776 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
"source": [
|
| 778 |
"def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
|
| 779 |
" \"\"\"\n",
|
|
@@ -899,7 +927,7 @@
|
|
| 899 |
},
|
| 900 |
{
|
| 901 |
"cell_type": "code",
|
| 902 |
-
"execution_count":
|
| 903 |
"metadata": {},
|
| 904 |
"outputs": [],
|
| 905 |
"source": [
|
|
@@ -975,7 +1003,7 @@
|
|
| 975 |
},
|
| 976 |
{
|
| 977 |
"cell_type": "code",
|
| 978 |
-
"execution_count":
|
| 979 |
"metadata": {},
|
| 980 |
"outputs": [],
|
| 981 |
"source": [
|
|
@@ -1061,9 +1089,98 @@
|
|
| 1061 |
},
|
| 1062 |
{
|
| 1063 |
"cell_type": "code",
|
| 1064 |
-
"execution_count":
|
| 1065 |
"metadata": {},
|
| 1066 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1067 |
"source": [
|
| 1068 |
"# Training loop (step-based with gradient accumulation)\n",
|
| 1069 |
"print(\"Starting training...\")\n",
|
|
@@ -1174,7 +1291,7 @@
|
|
| 1174 |
},
|
| 1175 |
{
|
| 1176 |
"cell_type": "code",
|
| 1177 |
-
"execution_count":
|
| 1178 |
"metadata": {},
|
| 1179 |
"outputs": [],
|
| 1180 |
"source": [
|
|
@@ -1216,9 +1333,34 @@
|
|
| 1216 |
},
|
| 1217 |
{
|
| 1218 |
"cell_type": "code",
|
| 1219 |
-
"execution_count":
|
| 1220 |
"metadata": {},
|
| 1221 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1222 |
"source": [
|
| 1223 |
"print(\"\\n\" + \"=\"*50)\n",
|
| 1224 |
"print(\"Test Set Evaluation\")\n",
|
|
@@ -1265,13 +1407,6 @@
|
|
| 1265 |
" print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n",
|
| 1266 |
"print(\"=\"*50)"
|
| 1267 |
]
|
| 1268 |
-
},
|
| 1269 |
-
{
|
| 1270 |
-
"cell_type": "code",
|
| 1271 |
-
"execution_count": null,
|
| 1272 |
-
"metadata": {},
|
| 1273 |
-
"outputs": [],
|
| 1274 |
-
"source": []
|
| 1275 |
}
|
| 1276 |
],
|
| 1277 |
"metadata": {
|
|
|
|
| 561 |
},
|
| 562 |
{
|
| 563 |
"cell_type": "code",
|
| 564 |
+
"execution_count": 19,
|
| 565 |
"metadata": {},
|
| 566 |
"outputs": [],
|
| 567 |
"source": [
|
|
|
|
| 589 |
},
|
| 590 |
{
|
| 591 |
"cell_type": "code",
|
| 592 |
+
"execution_count": 20,
|
| 593 |
"metadata": {},
|
| 594 |
+
"outputs": [
|
| 595 |
+
{
|
| 596 |
+
"name": "stdout",
|
| 597 |
+
"output_type": "stream",
|
| 598 |
+
"text": [
|
| 599 |
+
"Gradient accumulation steps: 2\n",
|
| 600 |
+
"Effective batch size: 4\n",
|
| 601 |
+
"Effective tokens per update: 4096\n",
|
| 602 |
+
"\n",
|
| 603 |
+
"Training constants:\n",
|
| 604 |
+
" Total training steps: 32\n",
|
| 605 |
+
" Log training metrics every: 2 steps\n",
|
| 606 |
+
" Run validation every: 4 steps\n",
|
| 607 |
+
" Warmup steps: 3\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"Optimizer setup:\n",
|
| 610 |
+
" Initial LR: 1e-05\n",
|
| 611 |
+
" Peak LR: 5e-05\n"
|
| 612 |
+
]
|
| 613 |
+
}
|
| 614 |
+
],
|
| 615 |
"source": [
|
| 616 |
"# Calculate gradient accumulation steps and effective batch size\n",
|
| 617 |
"num_devices = 1 # Single device for now\n",
|
|
|
|
| 683 |
},
|
| 684 |
{
|
| 685 |
"cell_type": "code",
|
| 686 |
+
"execution_count": 21,
|
| 687 |
"metadata": {},
|
| 688 |
"outputs": [],
|
| 689 |
"source": [
|
|
|
|
| 773 |
},
|
| 774 |
{
|
| 775 |
"cell_type": "code",
|
| 776 |
+
"execution_count": 22,
|
| 777 |
"metadata": {},
|
| 778 |
"outputs": [],
|
| 779 |
"source": [
|
|
|
|
| 791 |
},
|
| 792 |
{
|
| 793 |
"cell_type": "code",
|
| 794 |
+
"execution_count": 23,
|
| 795 |
"metadata": {},
|
| 796 |
+
"outputs": [
|
| 797 |
+
{
|
| 798 |
+
"name": "stdout",
|
| 799 |
+
"output_type": "stream",
|
| 800 |
+
"text": [
|
| 801 |
+
"Scaling functions created\n"
|
| 802 |
+
]
|
| 803 |
+
}
|
| 804 |
+
],
|
| 805 |
"source": [
|
| 806 |
"def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
|
| 807 |
" \"\"\"\n",
|
|
|
|
| 927 |
},
|
| 928 |
{
|
| 929 |
"cell_type": "code",
|
| 930 |
+
"execution_count": 24,
|
| 931 |
"metadata": {},
|
| 932 |
"outputs": [],
|
| 933 |
"source": [
|
|
|
|
| 1003 |
},
|
| 1004 |
{
|
| 1005 |
"cell_type": "code",
|
| 1006 |
+
"execution_count": 25,
|
| 1007 |
"metadata": {},
|
| 1008 |
"outputs": [],
|
| 1009 |
"source": [
|
|
|
|
| 1089 |
},
|
| 1090 |
{
|
| 1091 |
"cell_type": "code",
|
| 1092 |
+
"execution_count": 26,
|
| 1093 |
"metadata": {},
|
| 1094 |
+
"outputs": [
|
| 1095 |
+
{
|
| 1096 |
+
"name": "stdout",
|
| 1097 |
+
"output_type": "stream",
|
| 1098 |
+
"text": [
|
| 1099 |
+
"Starting training...\n",
|
| 1100 |
+
"Training for 32 steps with 2 gradient accumulation steps\n",
|
| 1101 |
+
"\n"
|
| 1102 |
+
]
|
| 1103 |
+
},
|
| 1104 |
+
{
|
| 1105 |
+
"name": "stderr",
|
| 1106 |
+
"output_type": "stream",
|
| 1107 |
+
"text": [
|
| 1108 |
+
"/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
|
| 1109 |
+
"CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
|
| 1110 |
+
" warnings.warn(error_message)\n",
|
| 1111 |
+
"/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The variance of predictions or target is close to zero. This can cause instability in Pearson correlationcoefficient, leading to wrong results. Consider re-scaling the input if possible or computing using alarger dtype (currently using torch.float32). Setting the correlation coefficient to nan.\n",
|
| 1112 |
+
" warnings.warn(*args, **kwargs)\n",
|
| 1113 |
+
"/tmp/ipykernel_1758159/1960846655.py:68: RuntimeWarning: Mean of empty slice\n",
|
| 1114 |
+
" metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n",
|
| 1115 |
+
"/tmp/ipykernel_1758159/1960846655.py:77: RuntimeWarning: Mean of empty slice\n",
|
| 1116 |
+
" metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n"
|
| 1117 |
+
]
|
| 1118 |
+
},
|
| 1119 |
+
{
|
| 1120 |
+
"name": "stdout",
|
| 1121 |
+
"output_type": "stream",
|
| 1122 |
+
"text": [
|
| 1123 |
+
"Step 1/32 | Loss: 0.8378 | Mean Pearson: nan | LR: 1.17e-09 | Tokens: 4,096\n",
|
| 1124 |
+
"\n",
|
| 1125 |
+
"Running validation at step 0...\n",
|
| 1126 |
+
" Validation Loss: 0.5279\n",
|
| 1127 |
+
" Validation Mean Pearson: -0.0192\n",
|
| 1128 |
+
" ENCFF884LDL/pearson: -0.0192\n",
|
| 1129 |
+
"Step 3/32 | Loss: 0.4650 | Mean Pearson: -0.0149 | LR: 2.50e-09 | Tokens: 12,288\n",
|
| 1130 |
+
"Step 5/32 | Loss: 0.3369 | Mean Pearson: -0.1350 | LR: 2.41e-09 | Tokens: 20,480\n",
|
| 1131 |
+
"\n",
|
| 1132 |
+
"Running validation at step 4...\n",
|
| 1133 |
+
" Validation Loss: 0.3878\n",
|
| 1134 |
+
" Validation Mean Pearson: -0.1298\n",
|
| 1135 |
+
" ENCFF884LDL/pearson: -0.1298\n",
|
| 1136 |
+
"Step 7/32 | Loss: 0.3609 | Mean Pearson: -0.0102 | LR: 2.32e-09 | Tokens: 28,672\n",
|
| 1137 |
+
"Step 9/32 | Loss: 0.3301 | Mean Pearson: -0.0902 | LR: 2.23e-09 | Tokens: 36,864\n",
|
| 1138 |
+
"\n",
|
| 1139 |
+
"Running validation at step 8...\n",
|
| 1140 |
+
" Validation Loss: 0.4743\n",
|
| 1141 |
+
" Validation Mean Pearson: -0.0739\n",
|
| 1142 |
+
" ENCFF884LDL/pearson: -0.0739\n",
|
| 1143 |
+
"Step 11/32 | Loss: 0.3905 | Mean Pearson: -0.0113 | LR: 2.13e-09 | Tokens: 45,056\n",
|
| 1144 |
+
"Step 13/32 | Loss: 0.3181 | Mean Pearson: -0.1564 | LR: 2.02e-09 | Tokens: 53,248\n",
|
| 1145 |
+
"\n",
|
| 1146 |
+
"Running validation at step 12...\n",
|
| 1147 |
+
" Validation Loss: 0.3337\n",
|
| 1148 |
+
" Validation Mean Pearson: -0.0650\n",
|
| 1149 |
+
" ENCFF884LDL/pearson: -0.0650\n",
|
| 1150 |
+
"Step 15/32 | Loss: 0.3638 | Mean Pearson: 0.0295 | LR: 1.91e-09 | Tokens: 61,440\n",
|
| 1151 |
+
"Step 17/32 | Loss: 0.4170 | Mean Pearson: -0.0442 | LR: 1.80e-09 | Tokens: 69,632\n",
|
| 1152 |
+
"\n",
|
| 1153 |
+
"Running validation at step 16...\n",
|
| 1154 |
+
" Validation Loss: 0.7969\n",
|
| 1155 |
+
" Validation Mean Pearson: -0.0304\n",
|
| 1156 |
+
" ENCFF884LDL/pearson: -0.0304\n",
|
| 1157 |
+
"Step 19/32 | Loss: 0.5033 | Mean Pearson: -0.0173 | LR: 1.67e-09 | Tokens: 77,824\n",
|
| 1158 |
+
"Step 21/32 | Loss: 0.4084 | Mean Pearson: -0.0516 | LR: 1.54e-09 | Tokens: 86,016\n",
|
| 1159 |
+
"\n",
|
| 1160 |
+
"Running validation at step 20...\n",
|
| 1161 |
+
" Validation Loss: 0.3475\n",
|
| 1162 |
+
" Validation Mean Pearson: -0.3040\n",
|
| 1163 |
+
" ENCFF884LDL/pearson: -0.3040\n",
|
| 1164 |
+
"Step 23/32 | Loss: 0.4915 | Mean Pearson: -0.1727 | LR: 1.39e-09 | Tokens: 94,208\n",
|
| 1165 |
+
"Step 25/32 | Loss: 0.3654 | Mean Pearson: -0.3257 | LR: 1.23e-09 | Tokens: 102,400\n",
|
| 1166 |
+
"\n",
|
| 1167 |
+
"Running validation at step 24...\n",
|
| 1168 |
+
" Validation Loss: 0.4069\n",
|
| 1169 |
+
" Validation Mean Pearson: -0.0551\n",
|
| 1170 |
+
" ENCFF884LDL/pearson: -0.0551\n",
|
| 1171 |
+
"Step 27/32 | Loss: 0.5344 | Mean Pearson: -0.0604 | LR: 1.04e-09 | Tokens: 110,592\n",
|
| 1172 |
+
"Step 29/32 | Loss: 0.3671 | Mean Pearson: -0.0290 | LR: 8.04e-10 | Tokens: 118,784\n",
|
| 1173 |
+
"\n",
|
| 1174 |
+
"Running validation at step 28...\n",
|
| 1175 |
+
" Validation Loss: 0.3162\n",
|
| 1176 |
+
" Validation Mean Pearson: -0.1008\n",
|
| 1177 |
+
" ENCFF884LDL/pearson: -0.1008\n",
|
| 1178 |
+
"Step 31/32 | Loss: 0.5994 | Mean Pearson: -0.0107 | LR: 4.64e-10 | Tokens: 126,976\n",
|
| 1179 |
+
"\n",
|
| 1180 |
+
"Training completed after 32 steps!\n"
|
| 1181 |
+
]
|
| 1182 |
+
}
|
| 1183 |
+
],
|
| 1184 |
"source": [
|
| 1185 |
"# Training loop (step-based with gradient accumulation)\n",
|
| 1186 |
"print(\"Starting training...\")\n",
|
|
|
|
| 1291 |
},
|
| 1292 |
{
|
| 1293 |
"cell_type": "code",
|
| 1294 |
+
"execution_count": 27,
|
| 1295 |
"metadata": {},
|
| 1296 |
"outputs": [],
|
| 1297 |
"source": [
|
|
|
|
| 1333 |
},
|
| 1334 |
{
|
| 1335 |
"cell_type": "code",
|
| 1336 |
+
"execution_count": 28,
|
| 1337 |
"metadata": {},
|
| 1338 |
+
"outputs": [
|
| 1339 |
+
{
|
| 1340 |
+
"name": "stdout",
|
| 1341 |
+
"output_type": "stream",
|
| 1342 |
+
"text": [
|
| 1343 |
+
"\n",
|
| 1344 |
+
"==================================================\n",
|
| 1345 |
+
"Test Set Evaluation\n",
|
| 1346 |
+
"==================================================\n",
|
| 1347 |
+
"Running test evaluation with 5 steps (10 samples)\n",
|
| 1348 |
+
"\n",
|
| 1349 |
+
"==================================================\n",
|
| 1350 |
+
"Test Set Results\n",
|
| 1351 |
+
"==================================================\n",
|
| 1352 |
+
"\n",
|
| 1353 |
+
"Scaled Metrics (scaled predictions vs scaled targets):\n",
|
| 1354 |
+
" Mean Pearson (scaled): -0.0020\n",
|
| 1355 |
+
" ENCFF884LDL/pearson: -0.0020\n",
|
| 1356 |
+
"\n",
|
| 1357 |
+
"Raw Metrics (raw predictions vs raw targets):\n",
|
| 1358 |
+
" Mean Pearson (raw): -0.0020\n",
|
| 1359 |
+
" ENCFF884LDL/pearson: -0.0020\n",
|
| 1360 |
+
"==================================================\n"
|
| 1361 |
+
]
|
| 1362 |
+
}
|
| 1363 |
+
],
|
| 1364 |
"source": [
|
| 1365 |
"print(\"\\n\" + \"=\"*50)\n",
|
| 1366 |
"print(\"Test Set Evaluation\")\n",
|
|
|
|
| 1407 |
" print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n",
|
| 1408 |
"print(\"=\"*50)"
|
| 1409 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1410 |
}
|
| 1411 |
],
|
| 1412 |
"metadata": {
|