Spaces:
Running
Running
Commit
·
b6b1c80
1
Parent(s):
e712656
fix: notebook simplification
Browse files- notebooks/03_fine_tuning.ipynb +598 -523
notebooks/03_fine_tuning.ipynb
CHANGED
|
@@ -16,7 +16,7 @@
|
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"cell_type": "code",
|
| 19 |
-
"execution_count":
|
| 20 |
"metadata": {},
|
| 21 |
"outputs": [],
|
| 22 |
"source": [
|
|
@@ -28,14 +28,14 @@
|
|
| 28 |
},
|
| 29 |
{
|
| 30 |
"cell_type": "code",
|
| 31 |
-
"execution_count":
|
| 32 |
"metadata": {},
|
| 33 |
"outputs": [],
|
| 34 |
"source": [
|
| 35 |
"# 0. Imports\n",
|
| 36 |
"import random\n",
|
| 37 |
"import functools\n",
|
| 38 |
-
"from typing import List, Dict,
|
| 39 |
"import os\n",
|
| 40 |
"import subprocess\n",
|
| 41 |
"\n",
|
|
@@ -48,19 +48,50 @@
|
|
| 48 |
"import numpy as np\n",
|
| 49 |
"import pyBigWig\n",
|
| 50 |
"from pyfaidx import Fasta\n",
|
| 51 |
-
"from torchmetrics import PearsonCorrCoef"
|
|
|
|
|
|
|
|
|
|
| 52 |
]
|
| 53 |
},
|
| 54 |
{
|
| 55 |
"cell_type": "markdown",
|
| 56 |
"metadata": {},
|
| 57 |
"source": [
|
| 58 |
-
"# 1. Configuration setup"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
]
|
| 60 |
},
|
| 61 |
{
|
| 62 |
"cell_type": "code",
|
| 63 |
-
"execution_count":
|
| 64 |
"metadata": {},
|
| 65 |
"outputs": [
|
| 66 |
{
|
|
@@ -74,37 +105,32 @@
|
|
| 74 |
"source": [
|
| 75 |
"config = {\n",
|
| 76 |
" # Model\n",
|
| 77 |
-
" \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\"
|
| 78 |
" \n",
|
| 79 |
" # Data\n",
|
| 80 |
-
" \"data_cache_dir\": \"./data\"
|
| 81 |
-
" \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\"
|
| 82 |
-
" \"bigwig_url_list\": [\
|
| 83 |
-
"
|
| 84 |
-
"
|
|
|
|
|
|
|
| 85 |
" \n",
|
| 86 |
" # Training\n",
|
| 87 |
-
" \"batch_size\":
|
| 88 |
-
" \"
|
| 89 |
-
" \"
|
| 90 |
-
" \n",
|
| 91 |
-
" \"
|
| 92 |
-
" \"num_tokens_per_update\": 4_096, # Target tokens per optimizer update (batch_size * seq_len * grad_accum)\n",
|
| 93 |
-
" \"num_tokens_per_log\": 8_192, # Tokens between training logs (how often to print metrics)\n",
|
| 94 |
-
" \"num_tokens_per_validation\": 16_384, # Tokens between validation runs (how often to evaluate on validation set)\n",
|
| 95 |
" \n",
|
| 96 |
" # Validation\n",
|
| 97 |
-
" \"
|
| 98 |
-
" \n",
|
| 99 |
-
" # Loss\n",
|
| 100 |
-
" \"bigwig_loss_weight\": 1.0, # Weight multiplier for bigwig prediction loss\n",
|
| 101 |
-
" \"bigwig_scalar_loss_function\": \"poisson-multinomial\", # Loss function type for bigwig tracks\n",
|
| 102 |
-
" \"bigwig_shape_loss_coefficient\": 5.0, # Coefficient balancing shape loss vs scale loss in poisson-multinomial loss\n",
|
| 103 |
" \n",
|
| 104 |
" # General\n",
|
| 105 |
-
" \"seed\": 42
|
| 106 |
-
" \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\"
|
| 107 |
-
" \"num_workers\": 0
|
| 108 |
"}\n",
|
| 109 |
"\n",
|
| 110 |
"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
|
|
@@ -216,7 +242,7 @@
|
|
| 216 |
},
|
| 217 |
{
|
| 218 |
"cell_type": "code",
|
| 219 |
-
"execution_count":
|
| 220 |
"metadata": {},
|
| 221 |
"outputs": [],
|
| 222 |
"source": [
|
|
@@ -236,7 +262,7 @@
|
|
| 236 |
},
|
| 237 |
{
|
| 238 |
"cell_type": "code",
|
| 239 |
-
"execution_count":
|
| 240 |
"metadata": {},
|
| 241 |
"outputs": [],
|
| 242 |
"source": [
|
|
@@ -304,7 +330,7 @@
|
|
| 304 |
},
|
| 305 |
{
|
| 306 |
"cell_type": "code",
|
| 307 |
-
"execution_count":
|
| 308 |
"metadata": {},
|
| 309 |
"outputs": [
|
| 310 |
{
|
|
@@ -335,6 +361,49 @@
|
|
| 335 |
"print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
|
| 336 |
]
|
| 337 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
{
|
| 339 |
"cell_type": "markdown",
|
| 340 |
"metadata": {},
|
|
@@ -344,7 +413,7 @@
|
|
| 344 |
},
|
| 345 |
{
|
| 346 |
"cell_type": "code",
|
| 347 |
-
"execution_count":
|
| 348 |
"metadata": {},
|
| 349 |
"outputs": [],
|
| 350 |
"source": [
|
|
@@ -388,6 +457,7 @@
|
|
| 388 |
" sequence_length: int,\n",
|
| 389 |
" num_samples: int,\n",
|
| 390 |
" tokenizer: AutoTokenizer,\n",
|
|
|
|
| 391 |
" keep_target_center_fraction: float = 1.0,\n",
|
| 392 |
" num_tracks: int = 1,\n",
|
| 393 |
" ):\n",
|
|
@@ -401,6 +471,7 @@
|
|
| 401 |
" self.sequence_length = sequence_length\n",
|
| 402 |
" self.num_samples = num_samples\n",
|
| 403 |
" self.tokenizer = tokenizer\n",
|
|
|
|
| 404 |
" self.keep_target_center_fraction = keep_target_center_fraction\n",
|
| 405 |
" self.num_tracks = num_tracks\n",
|
| 406 |
" self.chroms = chroms\n",
|
|
@@ -465,6 +536,9 @@
|
|
| 465 |
" target_length = seq_len - 2 * target_offset\n",
|
| 466 |
" bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
|
| 467 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 468 |
" sample = {\n",
|
| 469 |
" \"tokens\": tokens,\n",
|
| 470 |
" \"bigwig_targets\": bigwig_targets,\n",
|
|
@@ -477,7 +551,7 @@
|
|
| 477 |
},
|
| 478 |
{
|
| 479 |
"cell_type": "code",
|
| 480 |
-
"execution_count":
|
| 481 |
"metadata": {},
|
| 482 |
"outputs": [
|
| 483 |
{
|
|
@@ -485,18 +559,22 @@
|
|
| 485 |
"output_type": "stream",
|
| 486 |
"text": [
|
| 487 |
"Train samples: 100\n",
|
| 488 |
-
"Val samples:
|
| 489 |
-
"Test samples:
|
| 490 |
]
|
| 491 |
}
|
| 492 |
],
|
| 493 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 494 |
"create_dataset_fn = functools.partial(\n",
|
| 495 |
" GenomeBigWigDataset,\n",
|
| 496 |
" fasta_path=fasta_path,\n",
|
| 497 |
" bigwig_path_list=bigwig_path_list,\n",
|
| 498 |
" sequence_length=config[\"sequence_length\"],\n",
|
| 499 |
" tokenizer=tokenizer,\n",
|
|
|
|
| 500 |
" keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
|
| 501 |
" num_tracks=len(config[\"bigwig_file_ids\"]),\n",
|
| 502 |
")\n",
|
|
@@ -552,21 +630,18 @@
|
|
| 552 |
},
|
| 553 |
{
|
| 554 |
"cell_type": "code",
|
| 555 |
-
"execution_count":
|
| 556 |
"metadata": {},
|
| 557 |
"outputs": [
|
| 558 |
{
|
| 559 |
"name": "stdout",
|
| 560 |
"output_type": "stream",
|
| 561 |
"text": [
|
| 562 |
-
"
|
| 563 |
-
"
|
| 564 |
-
"
|
| 565 |
-
"\n",
|
| 566 |
-
"
|
| 567 |
-
" Total training steps: 32\n",
|
| 568 |
-
" Log training metrics every: 2 steps\n",
|
| 569 |
-
" Run validation every: 4 steps\n",
|
| 570 |
"\n",
|
| 571 |
"Optimizer setup:\n",
|
| 572 |
" Learning rate: 1e-05\n"
|
|
@@ -574,37 +649,12 @@
|
|
| 574 |
}
|
| 575 |
],
|
| 576 |
"source": [
|
| 577 |
-
"#
|
| 578 |
-
"
|
| 579 |
-
"
|
| 580 |
-
"
|
| 581 |
-
"\n",
|
| 582 |
-
"
|
| 583 |
-
"num_accumulation_gradient = max(1, int(config[\"num_tokens_per_update\"] // (batch_size * num_devices * sequence_length)))\n",
|
| 584 |
-
"\n",
|
| 585 |
-
"# Calculate effective batch size and tokens per update\n",
|
| 586 |
-
"effective_batch_size = batch_size * num_devices * num_accumulation_gradient\n",
|
| 587 |
-
"effective_num_tokens_per_update = effective_batch_size * sequence_length\n",
|
| 588 |
-
"\n",
|
| 589 |
-
"print(f\"Gradient accumulation steps: {num_accumulation_gradient}\")\n",
|
| 590 |
-
"print(f\"Effective batch size: {effective_batch_size}\")\n",
|
| 591 |
-
"print(f\"Effective tokens per update: {effective_num_tokens_per_update}\")\n",
|
| 592 |
-
"\n",
|
| 593 |
-
"# Compute logging constants (based on deepspeed pipeline: compute_logging_constants)\n",
|
| 594 |
-
"num_train_samples = len(train_dataset)\n",
|
| 595 |
-
"num_tokens_per_update = effective_num_tokens_per_update # Same as effective_num_tokens_per_update\n",
|
| 596 |
-
"\n",
|
| 597 |
-
"# Total training steps based on token budget\n",
|
| 598 |
-
"num_steps_training = config[\"num_tokens_training\"] // num_tokens_per_update\n",
|
| 599 |
-
"\n",
|
| 600 |
-
"# Steps for logging and validation\n",
|
| 601 |
-
"log_train_step = int(np.ceil(config[\"num_tokens_per_log\"] / num_tokens_per_update))\n",
|
| 602 |
-
"log_validation_step = int(np.ceil(config[\"num_tokens_per_validation\"] / num_tokens_per_update))\n",
|
| 603 |
-
"\n",
|
| 604 |
-
"print(f\"\\nTraining constants:\")\n",
|
| 605 |
-
"print(f\" Total training steps: {num_steps_training}\")\n",
|
| 606 |
-
"print(f\" Log training metrics every: {log_train_step} steps\")\n",
|
| 607 |
-
"print(f\" Run validation every: {log_validation_step} steps\")\n",
|
| 608 |
"\n",
|
| 609 |
"# Setup optimizer\n",
|
| 610 |
"optimizer = AdamW(\n",
|
|
@@ -626,87 +676,62 @@
|
|
| 626 |
},
|
| 627 |
{
|
| 628 |
"cell_type": "code",
|
| 629 |
-
"execution_count":
|
| 630 |
"metadata": {},
|
| 631 |
"outputs": [],
|
| 632 |
"source": [
|
| 633 |
"class TracksMetrics:\n",
|
| 634 |
-
" \"\"\"Simple metrics tracker for tracks prediction
|
| 635 |
" \n",
|
| 636 |
" def __init__(self, track_names: List[str]):\n",
|
| 637 |
" self.track_names = track_names\n",
|
| 638 |
" self.num_tracks = len(track_names)\n",
|
| 639 |
-
" #
|
| 640 |
-
" self.
|
| 641 |
-
" PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
|
| 642 |
-
" ]\n",
|
| 643 |
-
" # Raw metrics: comparing raw targets with unscaled predictions\n",
|
| 644 |
-
" self.pearson_metrics_raw = [\n",
|
| 645 |
" PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
|
| 646 |
" ]\n",
|
| 647 |
" self.losses = []\n",
|
| 648 |
" \n",
|
| 649 |
" def reset(self):\n",
|
| 650 |
-
" for metric in self.
|
| 651 |
-
" metric.reset()\n",
|
| 652 |
-
" for metric in self.pearson_metrics_raw:\n",
|
| 653 |
" metric.reset()\n",
|
| 654 |
" self.losses = []\n",
|
| 655 |
" \n",
|
| 656 |
" def update(\n",
|
| 657 |
" self, \n",
|
| 658 |
-
"
|
| 659 |
-
"
|
| 660 |
-
" predictions_raw: torch.Tensor,\n",
|
| 661 |
-
" targets_raw: torch.Tensor,\n",
|
| 662 |
" loss: float\n",
|
| 663 |
" ):\n",
|
| 664 |
" \"\"\"\n",
|
| 665 |
-
" Update
|
| 666 |
" Args:\n",
|
| 667 |
-
"
|
| 668 |
-
"
|
| 669 |
-
" predictions_raw: (batch, seq_len, num_tracks) - raw/unscaled predictions\n",
|
| 670 |
-
" targets_raw: (batch, seq_len, num_tracks) - raw targets\n",
|
| 671 |
" loss: scalar loss value\n",
|
| 672 |
" \"\"\"\n",
|
| 673 |
" # Flatten batch and sequence dimensions\n",
|
| 674 |
-
"
|
| 675 |
-
"
|
| 676 |
-
" pred_raw_flat = predictions_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
|
| 677 |
-
" target_raw_flat = targets_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
|
| 678 |
-
" \n",
|
| 679 |
-
" # Update scaled metrics\n",
|
| 680 |
-
" for i, metric in enumerate(self.pearson_metrics_scaled):\n",
|
| 681 |
-
" metric.update(pred_scaled_flat[:, i], target_scaled_flat[:, i])\n",
|
| 682 |
" \n",
|
| 683 |
-
" # Update
|
| 684 |
-
" for i, metric in enumerate(self.
|
| 685 |
-
" metric.update(
|
| 686 |
" \n",
|
| 687 |
" self.losses.append(loss)\n",
|
| 688 |
" \n",
|
| 689 |
" def compute(self) -> Dict[str, float]:\n",
|
| 690 |
-
" \"\"\"Compute and return all metrics
|
| 691 |
" metrics_dict = {}\n",
|
| 692 |
" \n",
|
| 693 |
-
" #
|
| 694 |
-
" for i, (track_name, metric) in enumerate(zip(self.track_names, self.
|
| 695 |
" corr = metric.compute().item()\n",
|
| 696 |
-
" metrics_dict[f\"
|
| 697 |
" \n",
|
| 698 |
-
" #
|
| 699 |
-
"
|
| 700 |
-
" metrics_dict[\"
|
| 701 |
-
" \n",
|
| 702 |
-
" # Raw metrics: per-track Pearson correlations\n",
|
| 703 |
-
" for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_raw)):\n",
|
| 704 |
-
" corr = metric.compute().item()\n",
|
| 705 |
-
" metrics_dict[f\"metrics_raw/{track_name}/pearson\"] = corr\n",
|
| 706 |
-
" \n",
|
| 707 |
-
" # Raw metrics: mean Pearson correlation\n",
|
| 708 |
-
" correlations_raw = [metric.compute().item() for metric in self.pearson_metrics_raw]\n",
|
| 709 |
-
" metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n",
|
| 710 |
" \n",
|
| 711 |
" # Mean loss\n",
|
| 712 |
" metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
|
|
@@ -716,7 +741,7 @@
|
|
| 716 |
},
|
| 717 |
{
|
| 718 |
"cell_type": "code",
|
| 719 |
-
"execution_count":
|
| 720 |
"metadata": {},
|
| 721 |
"outputs": [],
|
| 722 |
"source": [
|
|
@@ -729,148 +754,12 @@
|
|
| 729 |
"cell_type": "markdown",
|
| 730 |
"metadata": {},
|
| 731 |
"source": [
|
| 732 |
-
"# 7.
|
| 733 |
]
|
| 734 |
},
|
| 735 |
{
|
| 736 |
"cell_type": "code",
|
| 737 |
-
"execution_count":
|
| 738 |
-
"metadata": {},
|
| 739 |
-
"outputs": [
|
| 740 |
-
{
|
| 741 |
-
"name": "stdout",
|
| 742 |
-
"output_type": "stream",
|
| 743 |
-
"text": [
|
| 744 |
-
"Scaling functions created\n"
|
| 745 |
-
]
|
| 746 |
-
}
|
| 747 |
-
],
|
| 748 |
-
"source": [
|
| 749 |
-
"def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
|
| 750 |
-
" \"\"\"\n",
|
| 751 |
-
" Get track means for normalization.\n",
|
| 752 |
-
" For now, return dummy values. In real pipeline, this loads from metadata.\n",
|
| 753 |
-
" \"\"\"\n",
|
| 754 |
-
" # Dummy values - in real pipeline, this would load from actual metadata\n",
|
| 755 |
-
" return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n",
|
| 756 |
-
"\n",
|
| 757 |
-
"\n",
|
| 758 |
-
"def get_rna_seq_track_ids(bigwig_file_ids: List[str]) -> List[int]:\n",
|
| 759 |
-
" \"\"\"\n",
|
| 760 |
-
" Get RNA-seq track indices.\n",
|
| 761 |
-
" For now, return empty list. In real pipeline, this identifies RNA-seq tracks.\n",
|
| 762 |
-
" \"\"\"\n",
|
| 763 |
-
" # Dummy - in real pipeline, this would identify RNA-seq tracks\n",
|
| 764 |
-
" return []\n",
|
| 765 |
-
"\n",
|
| 766 |
-
"\n",
|
| 767 |
-
"def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
|
| 768 |
-
" \"\"\"\n",
|
| 769 |
-
" Build a scaling function based on track means and RNA-seq squashing.\n",
|
| 770 |
-
" Copied from the supervised tracks pipeline.\n",
|
| 771 |
-
" \"\"\"\n",
|
| 772 |
-
" # Load track means\n",
|
| 773 |
-
" track_means_np = get_track_means(bigwig_file_ids)\n",
|
| 774 |
-
" track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
|
| 775 |
-
" \n",
|
| 776 |
-
" # Get which tracks use squashing\n",
|
| 777 |
-
" rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n",
|
| 778 |
-
" apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n",
|
| 779 |
-
" if len(rna_ids) > 0:\n",
|
| 780 |
-
" apply_squashing[rna_ids] = True\n",
|
| 781 |
-
" \n",
|
| 782 |
-
" def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
|
| 783 |
-
" \"\"\"\n",
|
| 784 |
-
" x: torch.Tensor, shape (batch, seq_len, num_tracks)\n",
|
| 785 |
-
" \"\"\"\n",
|
| 786 |
-
" device = x.device\n",
|
| 787 |
-
" \n",
|
| 788 |
-
" # Move constants to correct device\n",
|
| 789 |
-
" means = track_means.to(device)\n",
|
| 790 |
-
" squash_mask = apply_squashing.to(device)\n",
|
| 791 |
-
" \n",
|
| 792 |
-
" # Normalize\n",
|
| 793 |
-
" scaled = x / means\n",
|
| 794 |
-
" \n",
|
| 795 |
-
" # Power squashing where needed\n",
|
| 796 |
-
" squashed = torch.where(\n",
|
| 797 |
-
" squash_mask.view(1, 1, -1),\n",
|
| 798 |
-
" scaled.pow(0.75),\n",
|
| 799 |
-
" scaled,\n",
|
| 800 |
-
" )\n",
|
| 801 |
-
" \n",
|
| 802 |
-
" # Smooth clipping: if > 10, apply formula\n",
|
| 803 |
-
" clipped = torch.where(\n",
|
| 804 |
-
" squashed > 10.0,\n",
|
| 805 |
-
" 2.0 * torch.sqrt(squashed * 10.0) - 10.0,\n",
|
| 806 |
-
" squashed,\n",
|
| 807 |
-
" )\n",
|
| 808 |
-
" \n",
|
| 809 |
-
" return clipped\n",
|
| 810 |
-
" \n",
|
| 811 |
-
" return transform_fn\n",
|
| 812 |
-
"\n",
|
| 813 |
-
"\n",
|
| 814 |
-
"def create_predictions_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
|
| 815 |
-
" \"\"\"\n",
|
| 816 |
-
" Inverse scaling function to apply on predictions before computing metrics.\n",
|
| 817 |
-
" Copied from the supervised tracks pipeline.\n",
|
| 818 |
-
" \"\"\"\n",
|
| 819 |
-
" # Load means\n",
|
| 820 |
-
" track_means_np = get_track_means(bigwig_file_ids)\n",
|
| 821 |
-
" track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
|
| 822 |
-
" \n",
|
| 823 |
-
" # RNA-seq mask\n",
|
| 824 |
-
" rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n",
|
| 825 |
-
" apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n",
|
| 826 |
-
" if len(rna_ids) > 0:\n",
|
| 827 |
-
" apply_squashing[rna_ids] = True\n",
|
| 828 |
-
" \n",
|
| 829 |
-
" def inverse_transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
|
| 830 |
-
" \"\"\"\n",
|
| 831 |
-
" x: torch.Tensor, shape (batch, seq_len, num_tracks)\n",
|
| 832 |
-
" \"\"\"\n",
|
| 833 |
-
" device = x.device\n",
|
| 834 |
-
" means = track_means.to(device)\n",
|
| 835 |
-
" squash_mask = apply_squashing.to(device)\n",
|
| 836 |
-
" \n",
|
| 837 |
-
" # Undo clipping\n",
|
| 838 |
-
" unclipped = torch.where(\n",
|
| 839 |
-
" x > 10.0,\n",
|
| 840 |
-
" (x + 10.0).pow(2) / (4 * 10.0),\n",
|
| 841 |
-
" x,\n",
|
| 842 |
-
" )\n",
|
| 843 |
-
" \n",
|
| 844 |
-
" # Undo squashing\n",
|
| 845 |
-
" unsquashed = torch.where(\n",
|
| 846 |
-
" squash_mask.view(1, 1, -1),\n",
|
| 847 |
-
" unclipped.pow(1.0 / 0.75),\n",
|
| 848 |
-
" unclipped,\n",
|
| 849 |
-
" )\n",
|
| 850 |
-
" \n",
|
| 851 |
-
" # Undo normalization\n",
|
| 852 |
-
" return unsquashed * means\n",
|
| 853 |
-
" \n",
|
| 854 |
-
" return inverse_transform_fn\n",
|
| 855 |
-
"\n",
|
| 856 |
-
"\n",
|
| 857 |
-
"# Create scaling functions\n",
|
| 858 |
-
"scale_targets_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n",
|
| 859 |
-
"scale_predictions_fn = create_predictions_scaling_fn(config[\"bigwig_file_ids\"])\n",
|
| 860 |
-
"\n",
|
| 861 |
-
"print(\"Scaling functions created\")"
|
| 862 |
-
]
|
| 863 |
-
},
|
| 864 |
-
{
|
| 865 |
-
"cell_type": "markdown",
|
| 866 |
-
"metadata": {},
|
| 867 |
-
"source": [
|
| 868 |
-
"# 8. Loss functions"
|
| 869 |
-
]
|
| 870 |
-
},
|
| 871 |
-
{
|
| 872 |
-
"cell_type": "code",
|
| 873 |
-
"execution_count": 40,
|
| 874 |
"metadata": {},
|
| 875 |
"outputs": [],
|
| 876 |
"source": [
|
|
@@ -887,49 +776,24 @@
|
|
| 887 |
"def poisson_multinomial_loss(\n",
|
| 888 |
" logits: torch.Tensor,\n",
|
| 889 |
" targets: torch.Tensor,\n",
|
| 890 |
-
" mask: torch.Tensor | None = None,\n",
|
| 891 |
" shape_loss_coefficient: float = 5.0,\n",
|
| 892 |
" epsilon: float = 1e-7,\n",
|
| 893 |
") -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n",
|
| 894 |
" \"\"\"\n",
|
| 895 |
" Regression loss for bigwig tracks (MSE, Poisson, or Poisson-Multinomial).\n",
|
| 896 |
" \"\"\"\n",
|
| 897 |
-
" scale_loss, shape_loss = None, None\n",
|
| 898 |
-
" \n",
|
| 899 |
-
" if mask is None:\n",
|
| 900 |
-
" mask = torch.ones_like(targets, dtype=torch.float32, device=targets.device)\n",
|
| 901 |
-
" else:\n",
|
| 902 |
-
" mask = mask.float()\n",
|
| 903 |
-
" \n",
|
| 904 |
-
" mask_sum = mask.sum() + epsilon\n",
|
| 905 |
-
" masked_logits = logits * mask\n",
|
| 906 |
-
" masked_targets = targets * mask\n",
|
| 907 |
"\n",
|
| 908 |
" # Scale loss\n",
|
| 909 |
-
"
|
| 910 |
-
"
|
| 911 |
-
" \n",
|
| 912 |
-
" sum_pred = masked_logits.sum(dim=1) # (batch, num_tracks)\n",
|
| 913 |
-
" sum_true = masked_targets.sum(dim=1) # (batch, num_tracks)\n",
|
| 914 |
-
" \n",
|
| 915 |
" scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon)\n",
|
| 916 |
-
" scale_loss = scale_loss
|
| 917 |
-
" \n",
|
| 918 |
-
" if mask_per_sequence.any():\n",
|
| 919 |
-
" scale_loss_filtered = scale_loss[mask_per_sequence]\n",
|
| 920 |
-
" scale_loss = scale_loss_filtered.mean()\n",
|
| 921 |
-
" else:\n",
|
| 922 |
-
" scale_loss = torch.tensor(0.0, device=targets.device, dtype=targets.dtype)\n",
|
| 923 |
" \n",
|
| 924 |
" # Shape loss\n",
|
| 925 |
-
"
|
| 926 |
-
"
|
| 927 |
-
" \n",
|
| 928 |
-
" denom = predicted_counts.sum(dim=1, keepdim=True) + epsilon\n",
|
| 929 |
-
" p_pred = predicted_counts / denom\n",
|
| 930 |
-
" \n",
|
| 931 |
" pl_pred = safe_for_grad_log_torch(p_pred)\n",
|
| 932 |
-
" shape_loss = -(
|
| 933 |
" \n",
|
| 934 |
" # Combine\n",
|
| 935 |
" loss = shape_loss + scale_loss / shape_loss_coefficient\n",
|
|
@@ -941,57 +805,42 @@
|
|
| 941 |
"cell_type": "markdown",
|
| 942 |
"metadata": {},
|
| 943 |
"source": [
|
| 944 |
-
"#
|
| 945 |
]
|
| 946 |
},
|
| 947 |
{
|
| 948 |
"cell_type": "code",
|
| 949 |
-
"execution_count":
|
| 950 |
"metadata": {},
|
| 951 |
"outputs": [],
|
| 952 |
"source": [
|
| 953 |
"def train_step(\n",
|
| 954 |
" model: nn.Module,\n",
|
| 955 |
" batch: Dict[str, torch.Tensor],\n",
|
| 956 |
-
" optimizer: torch.optim.Optimizer,\n",
|
| 957 |
-
" scale_targets_fn: Callable,\n",
|
| 958 |
-
" config: Dict,\n",
|
| 959 |
-
" num_accumulation_steps: int = 1,\n",
|
| 960 |
") -> float:\n",
|
| 961 |
-
" \"\"\"Single training step
|
| 962 |
" tokens = batch[\"tokens\"].to(device)\n",
|
| 963 |
-
" bigwig_targets = batch[\"bigwig_targets\"].to(device)
|
| 964 |
" \n",
|
| 965 |
" # Forward pass\n",
|
| 966 |
" outputs = model(tokens=tokens)\n",
|
| 967 |
-
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]
|
| 968 |
-
" \n",
|
| 969 |
-
" # Scale targets\n",
|
| 970 |
-
" scaled_targets = scale_targets_fn(bigwig_targets)\n",
|
| 971 |
" \n",
|
| 972 |
" # Compute loss\n",
|
| 973 |
" loss, _, _ = poisson_multinomial_loss(\n",
|
| 974 |
" logits=bigwig_logits,\n",
|
| 975 |
-
" targets=
|
| 976 |
-
" shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n",
|
| 977 |
" )\n",
|
| 978 |
" \n",
|
| 979 |
-
" #
|
| 980 |
-
" loss = loss / num_accumulation_steps\n",
|
| 981 |
-
" \n",
|
| 982 |
-
" # Backward pass (accumulate gradients)\n",
|
| 983 |
" loss.backward()\n",
|
| 984 |
-
" \n",
|
| 985 |
-
" return loss.item() * num_accumulation_steps # Return unscaled loss for logging\n",
|
| 986 |
"\n",
|
| 987 |
"\n",
|
| 988 |
"def validation_step(\n",
|
| 989 |
" model: nn.Module,\n",
|
| 990 |
" batch: Dict[str, torch.Tensor],\n",
|
| 991 |
-
" scale_targets_fn: Callable,\n",
|
| 992 |
-
" scale_predictions_fn: Callable,\n",
|
| 993 |
" metrics: TracksMetrics,\n",
|
| 994 |
-
" config: Dict,\n",
|
| 995 |
") -> float:\n",
|
| 996 |
" \"\"\"Single validation step.\"\"\"\n",
|
| 997 |
" model.eval()\n",
|
|
@@ -1004,35 +853,32 @@
|
|
| 1004 |
" outputs = model(tokens=tokens)\n",
|
| 1005 |
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
| 1006 |
" \n",
|
| 1007 |
-
" #
|
| 1008 |
-
" scaled_targets = scale_targets_fn(bigwig_targets)\n",
|
| 1009 |
-
" \n",
|
| 1010 |
-
" # Compute loss (using scaled targets)\n",
|
| 1011 |
" loss, _, _ = poisson_multinomial_loss(\n",
|
| 1012 |
" logits=bigwig_logits,\n",
|
| 1013 |
-
" targets=
|
| 1014 |
-
" shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n",
|
| 1015 |
" )\n",
|
| 1016 |
" \n",
|
| 1017 |
-
" #
|
| 1018 |
-
" # (predictions are in scaled space, need to inverse transform)\n",
|
| 1019 |
-
" unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
|
| 1020 |
-
" \n",
|
| 1021 |
-
" # Update metrics (using original space targets and predictions)\n",
|
| 1022 |
" metrics.update(\n",
|
| 1023 |
-
"
|
| 1024 |
-
"
|
| 1025 |
-
" predictions_raw=unscaled_predictions,\n",
|
| 1026 |
-
" targets_raw=bigwig_targets,\n",
|
| 1027 |
" loss=loss.item()\n",
|
| 1028 |
" )\n",
|
| 1029 |
" \n",
|
| 1030 |
" return loss.item()"
|
| 1031 |
]
|
| 1032 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1033 |
{
|
| 1034 |
"cell_type": "code",
|
| 1035 |
-
"execution_count":
|
| 1036 |
"metadata": {},
|
| 1037 |
"outputs": [
|
| 1038 |
{
|
|
@@ -1040,163 +886,455 @@
|
|
| 1040 |
"output_type": "stream",
|
| 1041 |
"text": [
|
| 1042 |
"Starting training...\n",
|
| 1043 |
-
"Training for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1044 |
"\n",
|
| 1045 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1046 |
"\n",
|
| 1047 |
-
"Running validation at step
|
| 1048 |
-
" Validation Loss:
|
| 1049 |
-
" Validation Mean Pearson: -0.
|
| 1050 |
-
" ENCFF884LDL/pearson: -0.
|
| 1051 |
-
"Step
|
| 1052 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1053 |
"\n",
|
| 1054 |
-
"Running validation at step
|
| 1055 |
-
" Validation Loss:
|
| 1056 |
-
" Validation Mean Pearson:
|
| 1057 |
-
" ENCFF884LDL/pearson:
|
| 1058 |
-
"Step
|
| 1059 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1060 |
"\n",
|
| 1061 |
-
"Running validation at step
|
| 1062 |
-
" Validation Loss:
|
| 1063 |
-
" Validation Mean Pearson:
|
| 1064 |
-
" ENCFF884LDL/pearson:
|
| 1065 |
-
"Step
|
| 1066 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1067 |
"\n",
|
| 1068 |
-
"Running validation at step
|
| 1069 |
-
" Validation Loss:
|
| 1070 |
-
" Validation Mean Pearson:
|
| 1071 |
-
" ENCFF884LDL/pearson:
|
| 1072 |
-
"Step
|
| 1073 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1074 |
"\n",
|
| 1075 |
-
"Running validation at step
|
| 1076 |
-
" Validation Loss:
|
| 1077 |
-
" Validation Mean Pearson:
|
| 1078 |
-
" ENCFF884LDL/pearson:
|
| 1079 |
-
"Step
|
| 1080 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1081 |
"\n",
|
| 1082 |
-
"Running validation at step
|
| 1083 |
-
" Validation Loss:
|
| 1084 |
-
" Validation Mean Pearson:
|
| 1085 |
-
" ENCFF884LDL/pearson:
|
| 1086 |
-
"Step
|
| 1087 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1088 |
"\n",
|
| 1089 |
-
"Running validation at step
|
| 1090 |
-
" Validation Loss:
|
| 1091 |
-
" Validation Mean Pearson:
|
| 1092 |
-
" ENCFF884LDL/pearson:
|
| 1093 |
-
"Step
|
| 1094 |
-
"Step
|
|
|
|
|
|
|
|
|
|
| 1095 |
"\n",
|
| 1096 |
-
"Running validation at step
|
| 1097 |
-
" Validation Loss:
|
| 1098 |
-
" Validation Mean Pearson:
|
| 1099 |
-
" ENCFF884LDL/pearson:
|
| 1100 |
-
"Step
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1101 |
"\n",
|
| 1102 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1103 |
]
|
| 1104 |
}
|
| 1105 |
],
|
| 1106 |
"source": [
|
| 1107 |
-
"# Training loop
|
| 1108 |
"print(\"Starting training...\")\n",
|
| 1109 |
-
"print(f\"Training for {num_steps_training} steps
|
| 1110 |
"\n",
|
| 1111 |
"model.train()\n",
|
| 1112 |
"train_metrics.reset()\n",
|
| 1113 |
"optimizer.zero_grad() # Initialize gradients\n",
|
| 1114 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1115 |
"# Create iterator for training data (will cycle if needed)\n",
|
| 1116 |
"train_iter = iter(train_loader)\n",
|
| 1117 |
-
"
|
| 1118 |
-
"\n",
|
| 1119 |
-
"
|
| 1120 |
-
"
|
| 1121 |
-
"
|
| 1122 |
-
"
|
| 1123 |
-
"
|
| 1124 |
-
"
|
| 1125 |
-
"
|
| 1126 |
-
" except StopIteration:\n",
|
| 1127 |
-
" # Restart iterator if we run out of data\n",
|
| 1128 |
-
" train_iter = iter(train_loader)\n",
|
| 1129 |
-
" batch = next(train_iter)\n",
|
| 1130 |
-
" \n",
|
| 1131 |
-
" # Forward pass and accumulate gradients\n",
|
| 1132 |
-
" loss = train_step(\n",
|
| 1133 |
-
" model, batch, optimizer, scale_targets_fn, config, \n",
|
| 1134 |
-
" num_accumulation_steps=num_accumulation_gradient\n",
|
| 1135 |
-
" )\n",
|
| 1136 |
-
" accumulated_loss += loss\n",
|
| 1137 |
" \n",
|
| 1138 |
-
" #
|
|
|
|
|
|
|
|
|
|
| 1139 |
" optimizer.step()\n",
|
| 1140 |
" optimizer.zero_grad()\n",
|
| 1141 |
" \n",
|
| 1142 |
-
"
|
| 1143 |
-
" num_tokens_seen += effective_num_tokens_per_update\n",
|
| 1144 |
-
" \n",
|
| 1145 |
-
" # Update metrics (on last batch of accumulation)\n",
|
| 1146 |
" tokens = batch[\"tokens\"].to(device)\n",
|
| 1147 |
" bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
|
| 1148 |
" with torch.no_grad():\n",
|
| 1149 |
" outputs = model(tokens=tokens)\n",
|
| 1150 |
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
| 1151 |
" \n",
|
| 1152 |
-
" # Scale targets for scaled metrics\n",
|
| 1153 |
-
" scaled_targets = scale_targets_fn(bigwig_targets)\n",
|
| 1154 |
-
" \n",
|
| 1155 |
-
" # Unscale predictions for raw metrics\n",
|
| 1156 |
-
" unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
|
| 1157 |
-
" \n",
|
| 1158 |
-
" avg_loss = accumulated_loss / num_accumulation_gradient\n",
|
| 1159 |
" train_metrics.update(\n",
|
| 1160 |
-
"
|
| 1161 |
-
"
|
| 1162 |
-
"
|
| 1163 |
-
" targets_raw=bigwig_targets,\n",
|
| 1164 |
-
" loss=avg_loss\n",
|
| 1165 |
" )\n",
|
| 1166 |
" \n",
|
| 1167 |
" # Logging\n",
|
| 1168 |
-
" if
|
| 1169 |
" train_metrics_dict = train_metrics.compute()\n",
|
| 1170 |
-
" current_lr =
|
| 1171 |
-
"
|
| 1172 |
-
"
|
| 1173 |
-
"
|
| 1174 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1175 |
" train_metrics.reset()\n",
|
| 1176 |
" \n",
|
| 1177 |
" # Validation\n",
|
| 1178 |
-
" if
|
| 1179 |
-
" print(f\"\\nRunning validation at step {
|
| 1180 |
" val_metrics.reset()\n",
|
| 1181 |
" model.eval()\n",
|
| 1182 |
" \n",
|
| 1183 |
-
"
|
| 1184 |
" for val_batch in val_loader:\n",
|
| 1185 |
-
" val_loss = validation_step(\n",
|
| 1186 |
-
"
|
| 1187 |
-
" )\n",
|
| 1188 |
-
" val_losses.append(val_loss)\n",
|
| 1189 |
" \n",
|
| 1190 |
" # Print validation metrics\n",
|
| 1191 |
" val_metrics_dict = val_metrics.compute()\n",
|
| 1192 |
-
"
|
| 1193 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1194 |
" for track_name in config[\"bigwig_file_ids\"]:\n",
|
| 1195 |
-
" print(f\" {track_name}/pearson: {val_metrics_dict[f'
|
| 1196 |
" \n",
|
| 1197 |
" model.train() # Back to training mode\n",
|
| 1198 |
"\n",
|
| 1199 |
-
"print(f\"\\nTraining completed after {num_steps_training} steps
|
| 1200 |
]
|
| 1201 |
},
|
| 1202 |
{
|
|
@@ -1208,122 +1346,59 @@
|
|
| 1208 |
},
|
| 1209 |
{
|
| 1210 |
"cell_type": "code",
|
| 1211 |
-
"execution_count":
|
| 1212 |
-
"metadata": {},
|
| 1213 |
-
"outputs": [],
|
| 1214 |
-
"source": [
|
| 1215 |
-
"def test_step(\n",
|
| 1216 |
-
" model: nn.Module,\n",
|
| 1217 |
-
" batch: Dict[str, torch.Tensor],\n",
|
| 1218 |
-
" scale_targets_fn: Callable,\n",
|
| 1219 |
-
" scale_predictions_fn: Callable,\n",
|
| 1220 |
-
" metrics: TracksMetrics,\n",
|
| 1221 |
-
") -> None:\n",
|
| 1222 |
-
" \"\"\"\n",
|
| 1223 |
-
" Pure evaluation step for test set (no loss computation).\n",
|
| 1224 |
-
" Based on tracks_evaluation_step_torch from deepspeed pipeline.\n",
|
| 1225 |
-
" \"\"\"\n",
|
| 1226 |
-
" tokens = batch[\"tokens\"].to(device)\n",
|
| 1227 |
-
" bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n",
|
| 1228 |
-
" \n",
|
| 1229 |
-
" with torch.no_grad():\n",
|
| 1230 |
-
" # Forward pass\n",
|
| 1231 |
-
" outputs = model(tokens=tokens)\n",
|
| 1232 |
-
" bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n",
|
| 1233 |
-
" \n",
|
| 1234 |
-
" # Scale targets for scaled metrics\n",
|
| 1235 |
-
" scaled_targets = scale_targets_fn(bigwig_targets)\n",
|
| 1236 |
-
" \n",
|
| 1237 |
-
" # Unscale predictions for raw metrics\n",
|
| 1238 |
-
" unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
|
| 1239 |
-
" \n",
|
| 1240 |
-
" # Update metrics with both scaled and raw values\n",
|
| 1241 |
-
" # Pass 0.0 as loss since we don't compute loss in test evaluation\n",
|
| 1242 |
-
" metrics.update(\n",
|
| 1243 |
-
" predictions_scaled=bigwig_logits,\n",
|
| 1244 |
-
" targets_scaled=scaled_targets,\n",
|
| 1245 |
-
" predictions_raw=unscaled_predictions,\n",
|
| 1246 |
-
" targets_raw=bigwig_targets,\n",
|
| 1247 |
-
" loss=0.0\n",
|
| 1248 |
-
" )"
|
| 1249 |
-
]
|
| 1250 |
-
},
|
| 1251 |
-
{
|
| 1252 |
-
"cell_type": "code",
|
| 1253 |
-
"execution_count": 28,
|
| 1254 |
"metadata": {},
|
| 1255 |
"outputs": [
|
| 1256 |
{
|
| 1257 |
"name": "stdout",
|
| 1258 |
"output_type": "stream",
|
| 1259 |
"text": [
|
| 1260 |
-
"\n",
|
| 1261 |
-
"==================================================\n",
|
| 1262 |
-
"Test Set Evaluation\n",
|
| 1263 |
-
"==================================================\n",
|
| 1264 |
-
"Running test evaluation with 5 steps (10 samples)\n",
|
| 1265 |
"\n",
|
| 1266 |
"==================================================\n",
|
| 1267 |
"Test Set Results\n",
|
| 1268 |
"==================================================\n",
|
| 1269 |
"\n",
|
| 1270 |
-
"
|
| 1271 |
-
" Mean Pearson
|
| 1272 |
-
" ENCFF884LDL/pearson:
|
| 1273 |
-
"\n",
|
| 1274 |
-
"Raw Metrics (raw predictions vs raw targets):\n",
|
| 1275 |
-
" Mean Pearson (raw): -0.0020\n",
|
| 1276 |
-
" ENCFF884LDL/pearson: -0.0020\n",
|
| 1277 |
-
"==================================================\n"
|
| 1278 |
]
|
| 1279 |
}
|
| 1280 |
],
|
| 1281 |
"source": [
|
| 1282 |
-
"print(\"\\n\" + \"=\"*50)\n",
|
| 1283 |
-
"print(\"Test Set Evaluation\")\n",
|
| 1284 |
-
"print(\"=\"*50)\n",
|
| 1285 |
-
"\n",
|
| 1286 |
"# Calculate number of test steps (based on deepspeed pipeline)\n",
|
| 1287 |
"num_test_samples = len(test_dataset)\n",
|
| 1288 |
"num_test_steps = num_test_samples // config[\"batch_size\"]\n",
|
| 1289 |
-
"\n",
|
| 1290 |
"print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
|
| 1291 |
"\n",
|
| 1292 |
"# Set model to eval mode\n",
|
| 1293 |
"model.eval()\n",
|
| 1294 |
"\n",
|
| 1295 |
-
"
|
| 1296 |
-
"test_iter = iter(test_loader)\n",
|
| 1297 |
"\n",
|
| 1298 |
-
"
|
| 1299 |
-
"
|
| 1300 |
-
"
|
| 1301 |
-
"
|
| 1302 |
-
" except StopIteration:\n",
|
| 1303 |
-
" break\n",
|
| 1304 |
-
" \n",
|
| 1305 |
-
" # Perform test evaluation (pure evaluation, no loss computation)\n",
|
| 1306 |
-
" test_step(\n",
|
| 1307 |
-
" model, test_batch, scale_targets_fn, scale_predictions_fn, test_metrics\n",
|
| 1308 |
" )\n",
|
| 1309 |
-
"\n",
|
| 1310 |
"# Compute final test metrics\n",
|
| 1311 |
"test_metrics_dict = test_metrics.compute()\n",
|
| 1312 |
-
"\n",
|
| 1313 |
"print(\"\\n\" + \"=\"*50)\n",
|
| 1314 |
"print(\"Test Set Results\")\n",
|
| 1315 |
"print(\"=\"*50)\n",
|
| 1316 |
-
"print(f\"\\
|
| 1317 |
-
"print(f\" Mean Pearson
|
| 1318 |
-
"for track_name in config[\"bigwig_file_ids\"]
|
| 1319 |
-
" print(f\" {track_name}/pearson: {test_metrics_dict[f'
|
| 1320 |
-
"\n",
|
| 1321 |
-
"print(f\"\\nRaw Metrics (raw predictions vs raw targets):\")\n",
|
| 1322 |
-
"print(f\" Mean Pearson (raw): {test_metrics_dict['metrics_raw/mean/pearson']:.4f}\")\n",
|
| 1323 |
-
"for track_name in config[\"bigwig_file_ids\"]:\n",
|
| 1324 |
-
" print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n",
|
| 1325 |
-
"print(\"=\"*50)"
|
| 1326 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1327 |
}
|
| 1328 |
],
|
| 1329 |
"metadata": {
|
|
|
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"cell_type": "code",
|
| 19 |
+
"execution_count": 1,
|
| 20 |
"metadata": {},
|
| 21 |
"outputs": [],
|
| 22 |
"source": [
|
|
|
|
| 28 |
},
|
| 29 |
{
|
| 30 |
"cell_type": "code",
|
| 31 |
+
"execution_count": 1,
|
| 32 |
"metadata": {},
|
| 33 |
"outputs": [],
|
| 34 |
"source": [
|
| 35 |
"# 0. Imports\n",
|
| 36 |
"import random\n",
|
| 37 |
"import functools\n",
|
| 38 |
+
"from typing import List, Dict, Callable\n",
|
| 39 |
"import os\n",
|
| 40 |
"import subprocess\n",
|
| 41 |
"\n",
|
|
|
|
| 48 |
"import numpy as np\n",
|
| 49 |
"import pyBigWig\n",
|
| 50 |
"from pyfaidx import Fasta\n",
|
| 51 |
+
"from torchmetrics import PearsonCorrCoef\n",
|
| 52 |
+
"import plotly.graph_objects as go\n",
|
| 53 |
+
"from plotly.subplots import make_subplots\n",
|
| 54 |
+
"from IPython.display import display"
|
| 55 |
]
|
| 56 |
},
|
| 57 |
{
|
| 58 |
"cell_type": "markdown",
|
| 59 |
"metadata": {},
|
| 60 |
"source": [
|
| 61 |
+
"# 1. Configuration setup\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"## Configuration Parameters\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"### Model\n",
|
| 66 |
+
"- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"### Data\n",
|
| 69 |
+
"- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
|
| 70 |
+
"- **`fasta_url`**: URL to download reference genome FASTA file\n",
|
| 71 |
+
"- **`bigwig_url_list`**: List of URLs for bigWig track files to download\n",
|
| 72 |
+
"- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
|
| 73 |
+
"- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"### Training\n",
|
| 76 |
+
"- **`batch_size`**: Number of samples per batch\n",
|
| 77 |
+
"- **`learning_rate`**: Constant learning rate for optimizer\n",
|
| 78 |
+
"- **`weight_decay`**: L2 regularization coefficient for optimizer\n",
|
| 79 |
+
"- **`num_steps_training`**: Total number of training steps\n",
|
| 80 |
+
"- **`log_every_n_steps`**: Log training metrics every N steps\n",
|
| 81 |
+
"- **`validate_every_n_steps`**: Run validation every N steps\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"### Validation\n",
|
| 84 |
+
"- **`num_validation_samples`**: Number of samples to use for validation set\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"### General\n",
|
| 87 |
+
"- **`seed`**: Random seed for reproducibility\n",
|
| 88 |
+
"- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
|
| 89 |
+
"- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)"
|
| 90 |
]
|
| 91 |
},
|
| 92 |
{
|
| 93 |
"cell_type": "code",
|
| 94 |
+
"execution_count": 15,
|
| 95 |
"metadata": {},
|
| 96 |
"outputs": [
|
| 97 |
{
|
|
|
|
| 105 |
"source": [
|
| 106 |
"config = {\n",
|
| 107 |
" # Model\n",
|
| 108 |
+
" \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\",\n",
|
| 109 |
" \n",
|
| 110 |
" # Data\n",
|
| 111 |
+
" \"data_cache_dir\": \"./data\",\n",
|
| 112 |
+
" \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
|
| 113 |
+
" \"bigwig_url_list\": [\n",
|
| 114 |
+
" \"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"\n",
|
| 115 |
+
" ],\n",
|
| 116 |
+
" \"sequence_length\": 1_024,\n",
|
| 117 |
+
" \"keep_target_center_fraction\": 0.375,\n",
|
| 118 |
" \n",
|
| 119 |
" # Training\n",
|
| 120 |
+
" \"batch_size\": 8,\n",
|
| 121 |
+
" \"num_steps_training\": 1000,\n",
|
| 122 |
+
" \"log_every_n_steps\": 10,\n",
|
| 123 |
+
" \"learning_rate\": 1e-5,\n",
|
| 124 |
+
" \"weight_decay\": 0.01,\n",
|
|
|
|
|
|
|
|
|
|
| 125 |
" \n",
|
| 126 |
" # Validation\n",
|
| 127 |
+
" \"validate_every_n_steps\": 50,\n",
|
| 128 |
+
" \"num_validation_samples\": 100,\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
" \n",
|
| 130 |
" # General\n",
|
| 131 |
+
" \"seed\": 42,\n",
|
| 132 |
+
" \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
|
| 133 |
+
" \"num_workers\": 0,\n",
|
| 134 |
"}\n",
|
| 135 |
"\n",
|
| 136 |
"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
|
|
|
|
| 242 |
},
|
| 243 |
{
|
| 244 |
"cell_type": "code",
|
| 245 |
+
"execution_count": 3,
|
| 246 |
"metadata": {},
|
| 247 |
"outputs": [],
|
| 248 |
"source": [
|
|
|
|
| 262 |
},
|
| 263 |
{
|
| 264 |
"cell_type": "code",
|
| 265 |
+
"execution_count": 4,
|
| 266 |
"metadata": {},
|
| 267 |
"outputs": [],
|
| 268 |
"source": [
|
|
|
|
| 330 |
},
|
| 331 |
{
|
| 332 |
"cell_type": "code",
|
| 333 |
+
"execution_count": 5,
|
| 334 |
"metadata": {},
|
| 335 |
"outputs": [
|
| 336 |
{
|
|
|
|
| 361 |
"print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
|
| 362 |
]
|
| 363 |
},
|
| 364 |
+
{
|
| 365 |
+
"cell_type": "code",
|
| 366 |
+
"execution_count": 6,
|
| 367 |
+
"metadata": {},
|
| 368 |
+
"outputs": [],
|
| 369 |
+
"source": [
|
| 370 |
+
"# Scaling functions for targets\n",
|
| 371 |
+
"def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
|
| 372 |
+
" \"\"\"\n",
|
| 373 |
+
" Get track means for normalization.\n",
|
| 374 |
+
" For now, return dummy values. In real pipeline, this loads from metadata.\n",
|
| 375 |
+
" \"\"\"\n",
|
| 376 |
+
" # Dummy values - in real pipeline, this would load from actual metadata\n",
|
| 377 |
+
" return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n",
|
| 378 |
+
"\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
|
| 381 |
+
" \"\"\"\n",
|
| 382 |
+
" Build a scaling function based on track means.\n",
|
| 383 |
+
" \"\"\"\n",
|
| 384 |
+
" # Load track means\n",
|
| 385 |
+
" track_means_np = get_track_means(bigwig_file_ids)\n",
|
| 386 |
+
" track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
|
| 387 |
+
" \n",
|
| 388 |
+
" def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
|
| 389 |
+
" \"\"\"\n",
|
| 390 |
+
" x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
|
| 391 |
+
" \"\"\"\n",
|
| 392 |
+
" # Move constants to correct device then normalize\n",
|
| 393 |
+
" means = track_means.to(x.device)\n",
|
| 394 |
+
" scaled = x / means\n",
|
| 395 |
+
"\n",
|
| 396 |
+
" # Smooth clipping: if > 10, apply formula\n",
|
| 397 |
+
" clipped = torch.where(\n",
|
| 398 |
+
" scaled > 10.0,\n",
|
| 399 |
+
" 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
|
| 400 |
+
" scaled,\n",
|
| 401 |
+
" )\n",
|
| 402 |
+
" return clipped\n",
|
| 403 |
+
" \n",
|
| 404 |
+
" return transform_fn"
|
| 405 |
+
]
|
| 406 |
+
},
|
| 407 |
{
|
| 408 |
"cell_type": "markdown",
|
| 409 |
"metadata": {},
|
|
|
|
| 413 |
},
|
| 414 |
{
|
| 415 |
"cell_type": "code",
|
| 416 |
+
"execution_count": 7,
|
| 417 |
"metadata": {},
|
| 418 |
"outputs": [],
|
| 419 |
"source": [
|
|
|
|
| 457 |
" sequence_length: int,\n",
|
| 458 |
" num_samples: int,\n",
|
| 459 |
" tokenizer: AutoTokenizer,\n",
|
| 460 |
+
" transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
|
| 461 |
" keep_target_center_fraction: float = 1.0,\n",
|
| 462 |
" num_tracks: int = 1,\n",
|
| 463 |
" ):\n",
|
|
|
|
| 471 |
" self.sequence_length = sequence_length\n",
|
| 472 |
" self.num_samples = num_samples\n",
|
| 473 |
" self.tokenizer = tokenizer\n",
|
| 474 |
+
" self.transform_fn = transform_fn\n",
|
| 475 |
" self.keep_target_center_fraction = keep_target_center_fraction\n",
|
| 476 |
" self.num_tracks = num_tracks\n",
|
| 477 |
" self.chroms = chroms\n",
|
|
|
|
| 536 |
" target_length = seq_len - 2 * target_offset\n",
|
| 537 |
" bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
|
| 538 |
"\n",
|
| 539 |
+
" # Apply scaling to targets\n",
|
| 540 |
+
" bigwig_targets = self.transform_fn(bigwig_targets)\n",
|
| 541 |
+
"\n",
|
| 542 |
" sample = {\n",
|
| 543 |
" \"tokens\": tokens,\n",
|
| 544 |
" \"bigwig_targets\": bigwig_targets,\n",
|
|
|
|
| 551 |
},
|
| 552 |
{
|
| 553 |
"cell_type": "code",
|
| 554 |
+
"execution_count": 16,
|
| 555 |
"metadata": {},
|
| 556 |
"outputs": [
|
| 557 |
{
|
|
|
|
| 559 |
"output_type": "stream",
|
| 560 |
"text": [
|
| 561 |
"Train samples: 100\n",
|
| 562 |
+
"Val samples: 100\n",
|
| 563 |
+
"Test samples: 100\n"
|
| 564 |
]
|
| 565 |
}
|
| 566 |
],
|
| 567 |
"source": [
|
| 568 |
+
"# Create scaling function\n",
|
| 569 |
+
"transform_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n",
|
| 570 |
+
"\n",
|
| 571 |
"create_dataset_fn = functools.partial(\n",
|
| 572 |
" GenomeBigWigDataset,\n",
|
| 573 |
" fasta_path=fasta_path,\n",
|
| 574 |
" bigwig_path_list=bigwig_path_list,\n",
|
| 575 |
" sequence_length=config[\"sequence_length\"],\n",
|
| 576 |
" tokenizer=tokenizer,\n",
|
| 577 |
+
" transform_fn=transform_fn,\n",
|
| 578 |
" keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
|
| 579 |
" num_tracks=len(config[\"bigwig_file_ids\"]),\n",
|
| 580 |
")\n",
|
|
|
|
| 630 |
},
|
| 631 |
{
|
| 632 |
"cell_type": "code",
|
| 633 |
+
"execution_count": 17,
|
| 634 |
"metadata": {},
|
| 635 |
"outputs": [
|
| 636 |
{
|
| 637 |
"name": "stdout",
|
| 638 |
"output_type": "stream",
|
| 639 |
"text": [
|
| 640 |
+
"Training configuration:\n",
|
| 641 |
+
" Batch size: 8\n",
|
| 642 |
+
" Total training steps: 1000\n",
|
| 643 |
+
" Log metrics every: 10 steps\n",
|
| 644 |
+
" Validate every: 50 steps\n",
|
|
|
|
|
|
|
|
|
|
| 645 |
"\n",
|
| 646 |
"Optimizer setup:\n",
|
| 647 |
" Learning rate: 1e-05\n"
|
|
|
|
| 649 |
}
|
| 650 |
],
|
| 651 |
"source": [
|
| 652 |
+
"# Training setup\n",
|
| 653 |
+
"print(f\"Training configuration:\")\n",
|
| 654 |
+
"print(f\" Batch size: {config[\"batch_size\"]}\")\n",
|
| 655 |
+
"print(f\" Total training steps: {config[\"num_steps_training\"]}\")\n",
|
| 656 |
+
"print(f\" Log metrics every: {config[\"log_every_n_steps\"]} steps\")\n",
|
| 657 |
+
"print(f\" Validate every: {config[\"validate_every_n_steps\"]} steps\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
"\n",
|
| 659 |
"# Setup optimizer\n",
|
| 660 |
"optimizer = AdamW(\n",
|
|
|
|
| 676 |
},
|
| 677 |
{
|
| 678 |
"cell_type": "code",
|
| 679 |
+
"execution_count": 18,
|
| 680 |
"metadata": {},
|
| 681 |
"outputs": [],
|
| 682 |
"source": [
|
| 683 |
"class TracksMetrics:\n",
|
| 684 |
+
" \"\"\"Simple metrics tracker for tracks prediction.\"\"\"\n",
|
| 685 |
" \n",
|
| 686 |
" def __init__(self, track_names: List[str]):\n",
|
| 687 |
" self.track_names = track_names\n",
|
| 688 |
" self.num_tracks = len(track_names)\n",
|
| 689 |
+
" # Metrics: comparing scaled targets with scaled predictions\n",
|
| 690 |
+
" self.pearson_metrics = [\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
" PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
|
| 692 |
" ]\n",
|
| 693 |
" self.losses = []\n",
|
| 694 |
" \n",
|
| 695 |
" def reset(self):\n",
|
| 696 |
+
" for metric in self.pearson_metrics:\n",
|
|
|
|
|
|
|
| 697 |
" metric.reset()\n",
|
| 698 |
" self.losses = []\n",
|
| 699 |
" \n",
|
| 700 |
" def update(\n",
|
| 701 |
" self, \n",
|
| 702 |
+
" predictions: torch.Tensor, \n",
|
| 703 |
+
" targets: torch.Tensor,\n",
|
|
|
|
|
|
|
| 704 |
" loss: float\n",
|
| 705 |
" ):\n",
|
| 706 |
" \"\"\"\n",
|
| 707 |
+
" Update metrics.\n",
|
| 708 |
" Args:\n",
|
| 709 |
+
" predictions: (batch, seq_len, num_tracks)\n",
|
| 710 |
+
" targets: (batch, seq_len, num_tracks)\n",
|
|
|
|
|
|
|
| 711 |
" loss: scalar loss value\n",
|
| 712 |
" \"\"\"\n",
|
| 713 |
" # Flatten batch and sequence dimensions\n",
|
| 714 |
+
" pred_flat = predictions.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
|
| 715 |
+
" target_flat = targets.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
" \n",
|
| 717 |
+
" # Update metrics\n",
|
| 718 |
+
" for i, metric in enumerate(self.pearson_metrics):\n",
|
| 719 |
+
" metric.update(pred_flat[:, i], target_flat[:, i])\n",
|
| 720 |
" \n",
|
| 721 |
" self.losses.append(loss)\n",
|
| 722 |
" \n",
|
| 723 |
" def compute(self) -> Dict[str, float]:\n",
|
| 724 |
+
" \"\"\"Compute and return all metrics.\"\"\"\n",
|
| 725 |
" metrics_dict = {}\n",
|
| 726 |
" \n",
|
| 727 |
+
" # Per-track Pearson correlations\n",
|
| 728 |
+
" for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics)):\n",
|
| 729 |
" corr = metric.compute().item()\n",
|
| 730 |
+
" metrics_dict[f\"{track_name}/pearson\"] = corr\n",
|
| 731 |
" \n",
|
| 732 |
+
" # Mean Pearson correlation\n",
|
| 733 |
+
" correlations = [metric.compute().item() for metric in self.pearson_metrics]\n",
|
| 734 |
+
" metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
" \n",
|
| 736 |
" # Mean loss\n",
|
| 737 |
" metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
|
|
|
|
| 741 |
},
|
| 742 |
{
|
| 743 |
"cell_type": "code",
|
| 744 |
+
"execution_count": 19,
|
| 745 |
"metadata": {},
|
| 746 |
"outputs": [],
|
| 747 |
"source": [
|
|
|
|
| 754 |
"cell_type": "markdown",
|
| 755 |
"metadata": {},
|
| 756 |
"source": [
|
| 757 |
+
"# 7. Loss functions"
|
| 758 |
]
|
| 759 |
},
|
| 760 |
{
|
| 761 |
"cell_type": "code",
|
| 762 |
+
"execution_count": 20,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
"metadata": {},
|
| 764 |
"outputs": [],
|
| 765 |
"source": [
|
|
|
|
| 776 |
"def poisson_multinomial_loss(\n",
|
| 777 |
" logits: torch.Tensor,\n",
|
| 778 |
" targets: torch.Tensor,\n",
|
|
|
|
| 779 |
" shape_loss_coefficient: float = 5.0,\n",
|
| 780 |
" epsilon: float = 1e-7,\n",
|
| 781 |
") -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n",
|
| 782 |
" \"\"\"\n",
|
| 783 |
" Regression loss for bigwig tracks (MSE, Poisson, or Poisson-Multinomial).\n",
|
| 784 |
" \"\"\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
"\n",
|
| 786 |
" # Scale loss\n",
|
| 787 |
+
" sum_pred = logits.sum(dim=1) # (batch, num_tracks)\n",
|
| 788 |
+
" sum_true = targets.sum(dim=1) # (batch, num_tracks)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
" scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon)\n",
|
| 790 |
+
" scale_loss = scale_loss.mean()\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
" \n",
|
| 792 |
" # Shape loss\n",
|
| 793 |
+
" denom = logits.sum(dim=1, keepdim=True) + epsilon\n",
|
| 794 |
+
" p_pred = logits / denom\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
" pl_pred = safe_for_grad_log_torch(p_pred)\n",
|
| 796 |
+
" shape_loss = -(targets * pl_pred).mean()\n",
|
| 797 |
" \n",
|
| 798 |
" # Combine\n",
|
| 799 |
" loss = shape_loss + scale_loss / shape_loss_coefficient\n",
|
|
|
|
| 805 |
"cell_type": "markdown",
|
| 806 |
"metadata": {},
|
| 807 |
"source": [
|
| 808 |
+
"# 8. Training loop"
|
| 809 |
]
|
| 810 |
},
|
| 811 |
{
|
| 812 |
"cell_type": "code",
|
| 813 |
+
"execution_count": 21,
|
| 814 |
"metadata": {},
|
| 815 |
"outputs": [],
|
| 816 |
"source": [
|
| 817 |
"def train_step(\n",
|
| 818 |
" model: nn.Module,\n",
|
| 819 |
" batch: Dict[str, torch.Tensor],\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
") -> float:\n",
|
| 821 |
+
" \"\"\"Single training step.\"\"\"\n",
|
| 822 |
" tokens = batch[\"tokens\"].to(device)\n",
|
| 823 |
+
" bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
|
| 824 |
" \n",
|
| 825 |
" # Forward pass\n",
|
| 826 |
" outputs = model(tokens=tokens)\n",
|
| 827 |
+
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
|
|
|
|
|
|
|
|
|
| 828 |
" \n",
|
| 829 |
" # Compute loss\n",
|
| 830 |
" loss, _, _ = poisson_multinomial_loss(\n",
|
| 831 |
" logits=bigwig_logits,\n",
|
| 832 |
+
" targets=bigwig_targets,\n",
|
|
|
|
| 833 |
" )\n",
|
| 834 |
" \n",
|
| 835 |
+
" # Backward pass\n",
|
|
|
|
|
|
|
|
|
|
| 836 |
" loss.backward()\n",
|
| 837 |
+
" return loss.item()\n",
|
|
|
|
| 838 |
"\n",
|
| 839 |
"\n",
|
| 840 |
"def validation_step(\n",
|
| 841 |
" model: nn.Module,\n",
|
| 842 |
" batch: Dict[str, torch.Tensor],\n",
|
|
|
|
|
|
|
| 843 |
" metrics: TracksMetrics,\n",
|
|
|
|
| 844 |
") -> float:\n",
|
| 845 |
" \"\"\"Single validation step.\"\"\"\n",
|
| 846 |
" model.eval()\n",
|
|
|
|
| 853 |
" outputs = model(tokens=tokens)\n",
|
| 854 |
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
| 855 |
" \n",
|
| 856 |
+
" # Compute loss\n",
|
|
|
|
|
|
|
|
|
|
| 857 |
" loss, _, _ = poisson_multinomial_loss(\n",
|
| 858 |
" logits=bigwig_logits,\n",
|
| 859 |
+
" targets=bigwig_targets,\n",
|
|
|
|
| 860 |
" )\n",
|
| 861 |
" \n",
|
| 862 |
+
" # Update metrics\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
" metrics.update(\n",
|
| 864 |
+
" predictions=bigwig_logits,\n",
|
| 865 |
+
" targets=bigwig_targets,\n",
|
|
|
|
|
|
|
| 866 |
" loss=loss.item()\n",
|
| 867 |
" )\n",
|
| 868 |
" \n",
|
| 869 |
" return loss.item()"
|
| 870 |
]
|
| 871 |
},
|
| 872 |
+
{
|
| 873 |
+
"cell_type": "markdown",
|
| 874 |
+
"metadata": {},
|
| 875 |
+
"source": [
|
| 876 |
+
"### Interactive plotting is temporary for debug"
|
| 877 |
+
]
|
| 878 |
+
},
|
| 879 |
{
|
| 880 |
"cell_type": "code",
|
| 881 |
+
"execution_count": 22,
|
| 882 |
"metadata": {},
|
| 883 |
"outputs": [
|
| 884 |
{
|
|
|
|
| 886 |
"output_type": "stream",
|
| 887 |
"text": [
|
| 888 |
"Starting training...\n",
|
| 889 |
+
"Training for 1000 steps\n",
|
| 890 |
+
"\n"
|
| 891 |
+
]
|
| 892 |
+
},
|
| 893 |
+
{
|
| 894 |
+
"data": {
|
| 895 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 896 |
+
"model_id": "5935c992adb7428bac8de1aa6873dd7e",
|
| 897 |
+
"version_major": 2,
|
| 898 |
+
"version_minor": 0
|
| 899 |
+
},
|
| 900 |
+
"text/plain": [
|
| 901 |
+
"FigureWidget({\n",
|
| 902 |
+
" 'data': [{'line': {'color': 'blue'},\n",
|
| 903 |
+
" 'mode': 'lines+markers',\n",
|
| 904 |
+
" 'name': 'Train Loss',\n",
|
| 905 |
+
" 'type': 'scatter',\n",
|
| 906 |
+
" 'uid': '5424e4af-13b6-48c8-a367-8aa145c3a9db',\n",
|
| 907 |
+
" 'x': [],\n",
|
| 908 |
+
" 'xaxis': 'x',\n",
|
| 909 |
+
" 'y': [],\n",
|
| 910 |
+
" 'yaxis': 'y'},\n",
|
| 911 |
+
" {'line': {'color': 'red'},\n",
|
| 912 |
+
" 'mode': 'lines+markers',\n",
|
| 913 |
+
" 'name': 'Val Loss',\n",
|
| 914 |
+
" 'type': 'scatter',\n",
|
| 915 |
+
" 'uid': 'fe995660-5f01-4c12-9d7d-9ed19ddee785',\n",
|
| 916 |
+
" 'x': [],\n",
|
| 917 |
+
" 'xaxis': 'x',\n",
|
| 918 |
+
" 'y': [],\n",
|
| 919 |
+
" 'yaxis': 'y'},\n",
|
| 920 |
+
" {'line': {'color': 'green'},\n",
|
| 921 |
+
" 'mode': 'lines+markers',\n",
|
| 922 |
+
" 'name': 'Train Pearson',\n",
|
| 923 |
+
" 'type': 'scatter',\n",
|
| 924 |
+
" 'uid': '8453b45b-4613-41bc-a46b-ac59ba9e6f97',\n",
|
| 925 |
+
" 'x': [],\n",
|
| 926 |
+
" 'xaxis': 'x2',\n",
|
| 927 |
+
" 'y': [],\n",
|
| 928 |
+
" 'yaxis': 'y2'},\n",
|
| 929 |
+
" {'line': {'color': 'orange'},\n",
|
| 930 |
+
" 'mode': 'lines+markers',\n",
|
| 931 |
+
" 'name': 'Val Pearson',\n",
|
| 932 |
+
" 'type': 'scatter',\n",
|
| 933 |
+
" 'uid': '0887ea97-abf9-4fcf-8ea8-c638dc153a4d',\n",
|
| 934 |
+
" 'x': [],\n",
|
| 935 |
+
" 'xaxis': 'x2',\n",
|
| 936 |
+
" 'y': [],\n",
|
| 937 |
+
" 'yaxis': 'y2'}],\n",
|
| 938 |
+
" 'layout': {'annotations': [{'font': {'size': 16},\n",
|
| 939 |
+
" 'showarrow': False,\n",
|
| 940 |
+
" 'text': 'Loss',\n",
|
| 941 |
+
" 'x': 0.2125,\n",
|
| 942 |
+
" 'xanchor': 'center',\n",
|
| 943 |
+
" 'xref': 'paper',\n",
|
| 944 |
+
" 'y': 1.0,\n",
|
| 945 |
+
" 'yanchor': 'bottom',\n",
|
| 946 |
+
" 'yref': 'paper'},\n",
|
| 947 |
+
" {'font': {'size': 16},\n",
|
| 948 |
+
" 'showarrow': False,\n",
|
| 949 |
+
" 'text': 'Mean Pearson Correlation',\n",
|
| 950 |
+
" 'x': 0.7875,\n",
|
| 951 |
+
" 'xanchor': 'center',\n",
|
| 952 |
+
" 'xref': 'paper',\n",
|
| 953 |
+
" 'y': 1.0,\n",
|
| 954 |
+
" 'yanchor': 'bottom',\n",
|
| 955 |
+
" 'yref': 'paper'}],\n",
|
| 956 |
+
" 'height': 800,\n",
|
| 957 |
+
" 'showlegend': True,\n",
|
| 958 |
+
" 'template': '...',\n",
|
| 959 |
+
" 'title': {'text': 'Training'},\n",
|
| 960 |
+
" 'width': 1600,\n",
|
| 961 |
+
" 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.425], 'title': {'text': 'Step'}},\n",
|
| 962 |
+
" 'xaxis2': {'anchor': 'y2', 'domain': [0.575, 1.0], 'title': {'text': 'Step'}},\n",
|
| 963 |
+
" 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Loss'}},\n",
|
| 964 |
+
" 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'title': {'text': 'Pearson Correlation'}}}\n",
|
| 965 |
+
"})"
|
| 966 |
+
]
|
| 967 |
+
},
|
| 968 |
+
"metadata": {},
|
| 969 |
+
"output_type": "display_data"
|
| 970 |
+
},
|
| 971 |
+
{
|
| 972 |
+
"name": "stderr",
|
| 973 |
+
"output_type": "stream",
|
| 974 |
+
"text": [
|
| 975 |
+
"/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning:\n",
|
| 976 |
+
"\n",
|
| 977 |
+
"In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
|
| 978 |
+
"CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
|
| 979 |
+
"\n"
|
| 980 |
+
]
|
| 981 |
+
},
|
| 982 |
+
{
|
| 983 |
+
"name": "stdout",
|
| 984 |
+
"output_type": "stream",
|
| 985 |
+
"text": [
|
| 986 |
+
"Step 10/1000 | Loss: 0.2374 | Mean Pearson: 0.0382 | LR: 1.00e-05\n",
|
| 987 |
+
"Step 20/1000 | Loss: 2.2259 | Mean Pearson: -0.0884 | LR: 1.00e-05\n",
|
| 988 |
+
"Step 30/1000 | Loss: 20.0122 | Mean Pearson: 0.1379 | LR: 1.00e-05\n",
|
| 989 |
+
"Step 40/1000 | Loss: 9.6938 | Mean Pearson: -0.1497 | LR: 1.00e-05\n",
|
| 990 |
+
"Step 50/1000 | Loss: -1.8435 | Mean Pearson: -0.1875 | LR: 1.00e-05\n",
|
| 991 |
+
"\n",
|
| 992 |
+
"Running validation at step 50...\n",
|
| 993 |
+
" Validation Loss: 11.5599\n",
|
| 994 |
+
" Validation Mean Pearson: -0.1576\n",
|
| 995 |
+
" ENCFF884LDL/pearson: -0.1576\n",
|
| 996 |
+
"Step 60/1000 | Loss: 1.4427 | Mean Pearson: 0.2841 | LR: 1.00e-05\n",
|
| 997 |
+
"Step 70/1000 | Loss: -3.4037 | Mean Pearson: -0.1362 | LR: 1.00e-05\n",
|
| 998 |
+
"Step 80/1000 | Loss: 9.0958 | Mean Pearson: -0.1319 | LR: 1.00e-05\n",
|
| 999 |
+
"Step 90/1000 | Loss: -7.8433 | Mean Pearson: -0.0576 | LR: 1.00e-05\n",
|
| 1000 |
+
"Step 100/1000 | Loss: 7.3503 | Mean Pearson: -0.2150 | LR: 1.00e-05\n",
|
| 1001 |
"\n",
|
| 1002 |
+
"Running validation at step 100...\n",
|
| 1003 |
+
" Validation Loss: 22.3383\n",
|
| 1004 |
+
" Validation Mean Pearson: -0.2867\n",
|
| 1005 |
+
" ENCFF884LDL/pearson: -0.2867\n",
|
| 1006 |
+
"Step 110/1000 | Loss: -8.1600 | Mean Pearson: -0.1616 | LR: 1.00e-05\n",
|
| 1007 |
+
"Step 120/1000 | Loss: -0.8743 | Mean Pearson: -0.1318 | LR: 1.00e-05\n",
|
| 1008 |
+
"Step 130/1000 | Loss: -2.9825 | Mean Pearson: -0.0480 | LR: 1.00e-05\n",
|
| 1009 |
+
"Step 140/1000 | Loss: -2.4524 | Mean Pearson: -0.0879 | LR: 1.00e-05\n",
|
| 1010 |
+
"Step 150/1000 | Loss: 3.8818 | Mean Pearson: -0.0907 | LR: 1.00e-05\n",
|
| 1011 |
"\n",
|
| 1012 |
+
"Running validation at step 150...\n",
|
| 1013 |
+
" Validation Loss: 19.6866\n",
|
| 1014 |
+
" Validation Mean Pearson: -0.2207\n",
|
| 1015 |
+
" ENCFF884LDL/pearson: -0.2207\n",
|
| 1016 |
+
"Step 160/1000 | Loss: -1.0933 | Mean Pearson: -0.1243 | LR: 1.00e-05\n",
|
| 1017 |
+
"Step 170/1000 | Loss: -2.2577 | Mean Pearson: -0.0212 | LR: 1.00e-05\n",
|
| 1018 |
+
"Step 180/1000 | Loss: 0.0738 | Mean Pearson: 0.5643 | LR: 1.00e-05\n",
|
| 1019 |
+
"Step 190/1000 | Loss: -0.1097 | Mean Pearson: 0.0309 | LR: 1.00e-05\n",
|
| 1020 |
+
"Step 200/1000 | Loss: -8.7972 | Mean Pearson: 0.4804 | LR: 1.00e-05\n",
|
| 1021 |
"\n",
|
| 1022 |
+
"Running validation at step 200...\n",
|
| 1023 |
+
" Validation Loss: -8.8160\n",
|
| 1024 |
+
" Validation Mean Pearson: 0.0912\n",
|
| 1025 |
+
" ENCFF884LDL/pearson: 0.0912\n",
|
| 1026 |
+
"Step 210/1000 | Loss: -2.5429 | Mean Pearson: 0.3908 | LR: 1.00e-05\n",
|
| 1027 |
+
"Step 220/1000 | Loss: -6.8421 | Mean Pearson: 0.4080 | LR: 1.00e-05\n",
|
| 1028 |
+
"Step 230/1000 | Loss: -4.4312 | Mean Pearson: -0.0400 | LR: 1.00e-05\n",
|
| 1029 |
+
"Step 240/1000 | Loss: -11.4732 | Mean Pearson: 0.6653 | LR: 1.00e-05\n",
|
| 1030 |
+
"Step 250/1000 | Loss: -9.2648 | Mean Pearson: 0.0539 | LR: 1.00e-05\n",
|
| 1031 |
"\n",
|
| 1032 |
+
"Running validation at step 250...\n",
|
| 1033 |
+
" Validation Loss: -6.8987\n",
|
| 1034 |
+
" Validation Mean Pearson: 0.0654\n",
|
| 1035 |
+
" ENCFF884LDL/pearson: 0.0654\n",
|
| 1036 |
+
"Step 260/1000 | Loss: -0.6699 | Mean Pearson: 0.0913 | LR: 1.00e-05\n",
|
| 1037 |
+
"Step 270/1000 | Loss: -8.6625 | Mean Pearson: 0.3179 | LR: 1.00e-05\n",
|
| 1038 |
+
"Step 280/1000 | Loss: -11.7691 | Mean Pearson: 0.0004 | LR: 1.00e-05\n",
|
| 1039 |
+
"Step 290/1000 | Loss: -14.1622 | Mean Pearson: 0.0492 | LR: 1.00e-05\n",
|
| 1040 |
+
"Step 300/1000 | Loss: 0.9208 | Mean Pearson: 0.0607 | LR: 1.00e-05\n",
|
| 1041 |
"\n",
|
| 1042 |
+
"Running validation at step 300...\n",
|
| 1043 |
+
" Validation Loss: -5.0427\n",
|
| 1044 |
+
" Validation Mean Pearson: 0.3464\n",
|
| 1045 |
+
" ENCFF884LDL/pearson: 0.3464\n",
|
| 1046 |
+
"Step 310/1000 | Loss: -1.2881 | Mean Pearson: 0.1696 | LR: 1.00e-05\n",
|
| 1047 |
+
"Step 320/1000 | Loss: -18.6637 | Mean Pearson: 0.0892 | LR: 1.00e-05\n",
|
| 1048 |
+
"Step 330/1000 | Loss: -36.6038 | Mean Pearson: 0.3356 | LR: 1.00e-05\n",
|
| 1049 |
+
"Step 340/1000 | Loss: -2.4984 | Mean Pearson: 0.2305 | LR: 1.00e-05\n",
|
| 1050 |
+
"Step 350/1000 | Loss: -4.7985 | Mean Pearson: 0.0968 | LR: 1.00e-05\n",
|
| 1051 |
"\n",
|
| 1052 |
+
"Running validation at step 350...\n",
|
| 1053 |
+
" Validation Loss: -13.6500\n",
|
| 1054 |
+
" Validation Mean Pearson: 0.2737\n",
|
| 1055 |
+
" ENCFF884LDL/pearson: 0.2737\n",
|
| 1056 |
+
"Step 360/1000 | Loss: -9.4795 | Mean Pearson: 0.0579 | LR: 1.00e-05\n",
|
| 1057 |
+
"Step 370/1000 | Loss: 0.3531 | Mean Pearson: 0.0240 | LR: 1.00e-05\n",
|
| 1058 |
+
"Step 380/1000 | Loss: -5.7921 | Mean Pearson: 0.4119 | LR: 1.00e-05\n",
|
| 1059 |
+
"Step 390/1000 | Loss: -2.7049 | Mean Pearson: 0.1343 | LR: 1.00e-05\n",
|
| 1060 |
+
"Step 400/1000 | Loss: -32.8422 | Mean Pearson: 0.1545 | LR: 1.00e-05\n",
|
| 1061 |
"\n",
|
| 1062 |
+
"Running validation at step 400...\n",
|
| 1063 |
+
" Validation Loss: -4.3502\n",
|
| 1064 |
+
" Validation Mean Pearson: 0.3124\n",
|
| 1065 |
+
" ENCFF884LDL/pearson: 0.3124\n",
|
| 1066 |
+
"Step 410/1000 | Loss: -18.9574 | Mean Pearson: 0.0594 | LR: 1.00e-05\n",
|
| 1067 |
+
"Step 420/1000 | Loss: -5.4032 | Mean Pearson: 0.2804 | LR: 1.00e-05\n",
|
| 1068 |
+
"Step 430/1000 | Loss: -0.5171 | Mean Pearson: 0.1835 | LR: 1.00e-05\n",
|
| 1069 |
+
"Step 440/1000 | Loss: -3.4071 | Mean Pearson: 0.0680 | LR: 1.00e-05\n",
|
| 1070 |
+
"Step 450/1000 | Loss: -3.5580 | Mean Pearson: 0.0850 | LR: 1.00e-05\n",
|
| 1071 |
"\n",
|
| 1072 |
+
"Running validation at step 450...\n",
|
| 1073 |
+
" Validation Loss: -7.3308\n",
|
| 1074 |
+
" Validation Mean Pearson: 0.1128\n",
|
| 1075 |
+
" ENCFF884LDL/pearson: 0.1128\n",
|
| 1076 |
+
"Step 460/1000 | Loss: -0.9750 | Mean Pearson: 0.1717 | LR: 1.00e-05\n",
|
| 1077 |
+
"Step 470/1000 | Loss: -5.5775 | Mean Pearson: 0.1321 | LR: 1.00e-05\n",
|
| 1078 |
+
"Step 480/1000 | Loss: -1.1170 | Mean Pearson: 0.1484 | LR: 1.00e-05\n",
|
| 1079 |
+
"Step 490/1000 | Loss: -3.8053 | Mean Pearson: 0.1959 | LR: 1.00e-05\n",
|
| 1080 |
+
"Step 500/1000 | Loss: -4.5933 | Mean Pearson: 0.1860 | LR: 1.00e-05\n",
|
| 1081 |
"\n",
|
| 1082 |
+
"Running validation at step 500...\n",
|
| 1083 |
+
" Validation Loss: -5.7617\n",
|
| 1084 |
+
" Validation Mean Pearson: 0.3155\n",
|
| 1085 |
+
" ENCFF884LDL/pearson: 0.3155\n",
|
| 1086 |
+
"Step 510/1000 | Loss: -3.3306 | Mean Pearson: 0.2815 | LR: 1.00e-05\n",
|
| 1087 |
+
"Step 520/1000 | Loss: -2.1962 | Mean Pearson: 0.1151 | LR: 1.00e-05\n",
|
| 1088 |
+
"Step 530/1000 | Loss: -1.5388 | Mean Pearson: 0.3783 | LR: 1.00e-05\n",
|
| 1089 |
+
"Step 540/1000 | Loss: -2.2349 | Mean Pearson: 0.0734 | LR: 1.00e-05\n",
|
| 1090 |
+
"Step 550/1000 | Loss: -1.5502 | Mean Pearson: 0.2171 | LR: 1.00e-05\n",
|
| 1091 |
"\n",
|
| 1092 |
+
"Running validation at step 550...\n",
|
| 1093 |
+
" Validation Loss: -3.0059\n",
|
| 1094 |
+
" Validation Mean Pearson: 0.2325\n",
|
| 1095 |
+
" ENCFF884LDL/pearson: 0.2325\n",
|
| 1096 |
+
"Step 560/1000 | Loss: -2.0764 | Mean Pearson: -0.0049 | LR: 1.00e-05\n",
|
| 1097 |
+
"Step 570/1000 | Loss: -1.7384 | Mean Pearson: 0.2989 | LR: 1.00e-05\n",
|
| 1098 |
+
"Step 580/1000 | Loss: -6.7306 | Mean Pearson: 0.2522 | LR: 1.00e-05\n",
|
| 1099 |
+
"Step 590/1000 | Loss: -3.2473 | Mean Pearson: 0.1042 | LR: 1.00e-05\n",
|
| 1100 |
+
"Step 600/1000 | Loss: -4.2841 | Mean Pearson: 0.1936 | LR: 1.00e-05\n",
|
| 1101 |
+
"\n",
|
| 1102 |
+
"Running validation at step 600...\n",
|
| 1103 |
+
" Validation Loss: -4.5611\n",
|
| 1104 |
+
" Validation Mean Pearson: 0.2744\n",
|
| 1105 |
+
" ENCFF884LDL/pearson: 0.2744\n",
|
| 1106 |
+
"Step 610/1000 | Loss: -3.5691 | Mean Pearson: 0.1803 | LR: 1.00e-05\n",
|
| 1107 |
+
"Step 620/1000 | Loss: -7.2129 | Mean Pearson: 0.0901 | LR: 1.00e-05\n",
|
| 1108 |
+
"Step 630/1000 | Loss: -6.0598 | Mean Pearson: 0.1795 | LR: 1.00e-05\n",
|
| 1109 |
+
"Step 640/1000 | Loss: -2.8917 | Mean Pearson: 0.1111 | LR: 1.00e-05\n",
|
| 1110 |
+
"Step 650/1000 | Loss: -2.7210 | Mean Pearson: 0.3566 | LR: 1.00e-05\n",
|
| 1111 |
+
"\n",
|
| 1112 |
+
"Running validation at step 650...\n",
|
| 1113 |
+
" Validation Loss: -4.3997\n",
|
| 1114 |
+
" Validation Mean Pearson: 0.3327\n",
|
| 1115 |
+
" ENCFF884LDL/pearson: 0.3327\n",
|
| 1116 |
+
"Step 660/1000 | Loss: -3.4793 | Mean Pearson: 0.0441 | LR: 1.00e-05\n",
|
| 1117 |
+
"Step 670/1000 | Loss: -1.9743 | Mean Pearson: 0.1364 | LR: 1.00e-05\n",
|
| 1118 |
+
"Step 680/1000 | Loss: -5.7498 | Mean Pearson: 0.2330 | LR: 1.00e-05\n",
|
| 1119 |
+
"Step 690/1000 | Loss: -12.8701 | Mean Pearson: 0.3182 | LR: 1.00e-05\n",
|
| 1120 |
+
"Step 700/1000 | Loss: -1.5847 | Mean Pearson: 0.1971 | LR: 1.00e-05\n",
|
| 1121 |
+
"\n",
|
| 1122 |
+
"Running validation at step 700...\n",
|
| 1123 |
+
" Validation Loss: -2.0630\n",
|
| 1124 |
+
" Validation Mean Pearson: 0.1267\n",
|
| 1125 |
+
" ENCFF884LDL/pearson: 0.1267\n",
|
| 1126 |
+
"Step 710/1000 | Loss: -6.0704 | Mean Pearson: 0.3715 | LR: 1.00e-05\n",
|
| 1127 |
+
"Step 720/1000 | Loss: -2.6020 | Mean Pearson: 0.1244 | LR: 1.00e-05\n",
|
| 1128 |
+
"Step 730/1000 | Loss: -58.8965 | Mean Pearson: 0.5625 | LR: 1.00e-05\n",
|
| 1129 |
+
"Step 740/1000 | Loss: -1.2855 | Mean Pearson: 0.2658 | LR: 1.00e-05\n",
|
| 1130 |
+
"Step 750/1000 | Loss: -4.4599 | Mean Pearson: 0.0137 | LR: 1.00e-05\n",
|
| 1131 |
+
"\n",
|
| 1132 |
+
"Running validation at step 750...\n",
|
| 1133 |
+
" Validation Loss: -11.1562\n",
|
| 1134 |
+
" Validation Mean Pearson: 0.0844\n",
|
| 1135 |
+
" ENCFF884LDL/pearson: 0.0844\n",
|
| 1136 |
+
"Step 760/1000 | Loss: -11.6905 | Mean Pearson: 0.1914 | LR: 1.00e-05\n",
|
| 1137 |
+
"Step 770/1000 | Loss: -4.0964 | Mean Pearson: 0.2022 | LR: 1.00e-05\n",
|
| 1138 |
+
"Step 780/1000 | Loss: -1.5512 | Mean Pearson: 0.3568 | LR: 1.00e-05\n",
|
| 1139 |
+
"Step 790/1000 | Loss: -5.5843 | Mean Pearson: 0.2058 | LR: 1.00e-05\n",
|
| 1140 |
+
"Step 800/1000 | Loss: -3.9190 | Mean Pearson: 0.4362 | LR: 1.00e-05\n",
|
| 1141 |
+
"\n",
|
| 1142 |
+
"Running validation at step 800...\n",
|
| 1143 |
+
" Validation Loss: -4.7017\n",
|
| 1144 |
+
" Validation Mean Pearson: 0.3817\n",
|
| 1145 |
+
" ENCFF884LDL/pearson: 0.3817\n",
|
| 1146 |
+
"Step 810/1000 | Loss: -7.6856 | Mean Pearson: 0.0672 | LR: 1.00e-05\n",
|
| 1147 |
+
"Step 820/1000 | Loss: -5.3603 | Mean Pearson: 0.2325 | LR: 1.00e-05\n",
|
| 1148 |
+
"Step 830/1000 | Loss: -3.8539 | Mean Pearson: 0.2808 | LR: 1.00e-05\n",
|
| 1149 |
+
"Step 840/1000 | Loss: -8.1141 | Mean Pearson: 0.2529 | LR: 1.00e-05\n",
|
| 1150 |
+
"Step 850/1000 | Loss: -10.5886 | Mean Pearson: 0.3454 | LR: 1.00e-05\n",
|
| 1151 |
+
"\n",
|
| 1152 |
+
"Running validation at step 850...\n",
|
| 1153 |
+
" Validation Loss: -4.9108\n",
|
| 1154 |
+
" Validation Mean Pearson: 0.2195\n",
|
| 1155 |
+
" ENCFF884LDL/pearson: 0.2195\n",
|
| 1156 |
+
"Step 860/1000 | Loss: -4.1028 | Mean Pearson: 0.3304 | LR: 1.00e-05\n",
|
| 1157 |
+
"Step 870/1000 | Loss: -7.1834 | Mean Pearson: 0.1206 | LR: 1.00e-05\n",
|
| 1158 |
+
"Step 880/1000 | Loss: -8.9869 | Mean Pearson: 0.3584 | LR: 1.00e-05\n",
|
| 1159 |
+
"Step 890/1000 | Loss: -2.2697 | Mean Pearson: 0.0943 | LR: 1.00e-05\n",
|
| 1160 |
+
"Step 900/1000 | Loss: -14.0142 | Mean Pearson: 0.4761 | LR: 1.00e-05\n",
|
| 1161 |
+
"\n",
|
| 1162 |
+
"Running validation at step 900...\n",
|
| 1163 |
+
" Validation Loss: -3.2329\n",
|
| 1164 |
+
" Validation Mean Pearson: 0.3635\n",
|
| 1165 |
+
" ENCFF884LDL/pearson: 0.3635\n",
|
| 1166 |
+
"Step 910/1000 | Loss: -9.0941 | Mean Pearson: 0.2754 | LR: 1.00e-05\n",
|
| 1167 |
+
"Step 920/1000 | Loss: -4.6371 | Mean Pearson: 0.0167 | LR: 1.00e-05\n",
|
| 1168 |
+
"Step 930/1000 | Loss: -7.9853 | Mean Pearson: 0.0941 | LR: 1.00e-05\n",
|
| 1169 |
+
"Step 940/1000 | Loss: -22.9349 | Mean Pearson: 0.5140 | LR: 1.00e-05\n",
|
| 1170 |
+
"Step 950/1000 | Loss: -2.0866 | Mean Pearson: 0.1746 | LR: 1.00e-05\n",
|
| 1171 |
+
"\n",
|
| 1172 |
+
"Running validation at step 950...\n",
|
| 1173 |
+
" Validation Loss: -8.8318\n",
|
| 1174 |
+
" Validation Mean Pearson: 0.1597\n",
|
| 1175 |
+
" ENCFF884LDL/pearson: 0.1597\n",
|
| 1176 |
+
"Step 960/1000 | Loss: -4.8540 | Mean Pearson: 0.6318 | LR: 1.00e-05\n",
|
| 1177 |
+
"Step 970/1000 | Loss: -4.1091 | Mean Pearson: 0.0985 | LR: 1.00e-05\n",
|
| 1178 |
+
"Step 980/1000 | Loss: -5.1141 | Mean Pearson: 0.2031 | LR: 1.00e-05\n",
|
| 1179 |
+
"Step 990/1000 | Loss: -4.1959 | Mean Pearson: 0.2404 | LR: 1.00e-05\n",
|
| 1180 |
+
"Step 1000/1000 | Loss: -0.9942 | Mean Pearson: 0.2742 | LR: 1.00e-05\n",
|
| 1181 |
+
"\n",
|
| 1182 |
+
"Running validation at step 1000...\n",
|
| 1183 |
+
" Validation Loss: -4.2796\n",
|
| 1184 |
+
" Validation Mean Pearson: 0.1425\n",
|
| 1185 |
+
" ENCFF884LDL/pearson: 0.1425\n",
|
| 1186 |
+
"\n",
|
| 1187 |
+
"Training completed after 1000 steps.\n"
|
| 1188 |
]
|
| 1189 |
}
|
| 1190 |
],
|
| 1191 |
"source": [
|
| 1192 |
+
"# Training loop\n",
|
| 1193 |
"print(\"Starting training...\")\n",
|
| 1194 |
+
"print(f\"Training for {config['num_steps_training']} steps\\n\")\n",
|
| 1195 |
"\n",
|
| 1196 |
"model.train()\n",
|
| 1197 |
"train_metrics.reset()\n",
|
| 1198 |
"optimizer.zero_grad() # Initialize gradients\n",
|
| 1199 |
"\n",
|
| 1200 |
+
"# Track metrics for plotting\n",
|
| 1201 |
+
"train_steps = []\n",
|
| 1202 |
+
"train_losses = []\n",
|
| 1203 |
+
"train_pearson_scores = []\n",
|
| 1204 |
+
"val_steps = []\n",
|
| 1205 |
+
"val_losses = []\n",
|
| 1206 |
+
"val_pearson_scores = []\n",
|
| 1207 |
+
"\n",
|
| 1208 |
+
"# Initialize interactive plots using FigureWidget for real-time updates\n",
|
| 1209 |
+
"from plotly.graph_objects import FigureWidget\n",
|
| 1210 |
+
"from plotly.subplots import make_subplots\n",
|
| 1211 |
+
"\n",
|
| 1212 |
+
"# Create base figure with subplots\n",
|
| 1213 |
+
"fig_base = make_subplots(\n",
|
| 1214 |
+
" rows=1, cols=2,\n",
|
| 1215 |
+
" subplot_titles=('Loss', 'Mean Pearson Correlation'),\n",
|
| 1216 |
+
" horizontal_spacing=0.15,\n",
|
| 1217 |
+
")\n",
|
| 1218 |
+
"\n",
|
| 1219 |
+
"# Add empty traces for train and val metrics\n",
|
| 1220 |
+
"fig_base.add_trace(\n",
|
| 1221 |
+
" go.Scatter(x=[], y=[], mode='lines+markers', name='Train Loss', line=dict(color='blue')),\n",
|
| 1222 |
+
" row=1, col=1\n",
|
| 1223 |
+
")\n",
|
| 1224 |
+
"fig_base.add_trace(\n",
|
| 1225 |
+
" go.Scatter(x=[], y=[], mode='lines+markers', name='Val Loss', line=dict(color='red')),\n",
|
| 1226 |
+
" row=1, col=1\n",
|
| 1227 |
+
")\n",
|
| 1228 |
+
"fig_base.add_trace(\n",
|
| 1229 |
+
" go.Scatter(x=[], y=[], mode='lines+markers', name='Train Pearson', line=dict(color='green')),\n",
|
| 1230 |
+
" row=1, col=2\n",
|
| 1231 |
+
")\n",
|
| 1232 |
+
"fig_base.add_trace(\n",
|
| 1233 |
+
" go.Scatter(x=[], y=[], mode='lines+markers', name='Val Pearson', line=dict(color='orange')),\n",
|
| 1234 |
+
" row=1, col=2\n",
|
| 1235 |
+
")\n",
|
| 1236 |
+
"\n",
|
| 1237 |
+
"fig_base.update_xaxes(title_text=\"Step\", row=1, col=1)\n",
|
| 1238 |
+
"fig_base.update_xaxes(title_text=\"Step\", row=1, col=2)\n",
|
| 1239 |
+
"fig_base.update_yaxes(title_text=\"Loss\", row=1, col=1)\n",
|
| 1240 |
+
"fig_base.update_yaxes(title_text=\"Pearson Correlation\", row=1, col=2)\n",
|
| 1241 |
+
"fig_base.update_layout(height=800, width=1600, showlegend=True, title_text=\"Training\")\n",
|
| 1242 |
+
"\n",
|
| 1243 |
+
"# Convert to FigureWidget for interactive updates\n",
|
| 1244 |
+
"fig = FigureWidget(fig_base)\n",
|
| 1245 |
+
"\n",
|
| 1246 |
+
"# Display initial plot (will update in place during training)\n",
|
| 1247 |
+
"display(fig)\n",
|
| 1248 |
+
"\n",
|
| 1249 |
"# Create iterator for training data (will cycle if needed)\n",
|
| 1250 |
"train_iter = iter(train_loader)\n",
|
| 1251 |
+
"\n",
|
| 1252 |
+
"# Main training loop\n",
|
| 1253 |
+
"for step_idx in range(config[\"num_steps_training\"]):\n",
|
| 1254 |
+
" try:\n",
|
| 1255 |
+
" batch = next(train_iter)\n",
|
| 1256 |
+
" except StopIteration:\n",
|
| 1257 |
+
" # Restart iterator if we run out of data\n",
|
| 1258 |
+
" train_iter = iter(train_loader)\n",
|
| 1259 |
+
" batch = next(train_iter)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1260 |
" \n",
|
| 1261 |
+
" # Forward pass and backward pass\n",
|
| 1262 |
+
" loss = train_step(model, batch)\n",
|
| 1263 |
+
" \n",
|
| 1264 |
+
" # Update optimizer\n",
|
| 1265 |
" optimizer.step()\n",
|
| 1266 |
" optimizer.zero_grad()\n",
|
| 1267 |
" \n",
|
| 1268 |
+
" # Update metrics\n",
|
|
|
|
|
|
|
|
|
|
| 1269 |
" tokens = batch[\"tokens\"].to(device)\n",
|
| 1270 |
" bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
|
| 1271 |
" with torch.no_grad():\n",
|
| 1272 |
" outputs = model(tokens=tokens)\n",
|
| 1273 |
" bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
|
| 1274 |
" \n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1275 |
" train_metrics.update(\n",
|
| 1276 |
+
" predictions=bigwig_logits,\n",
|
| 1277 |
+
" targets=bigwig_targets,\n",
|
| 1278 |
+
" loss=loss\n",
|
|
|
|
|
|
|
| 1279 |
" )\n",
|
| 1280 |
" \n",
|
| 1281 |
" # Logging\n",
|
| 1282 |
+
" if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
|
| 1283 |
" train_metrics_dict = train_metrics.compute()\n",
|
| 1284 |
+
" current_lr = optimizer.param_groups[0]['lr']\n",
|
| 1285 |
+
" \n",
|
| 1286 |
+
" # Track metrics for plotting\n",
|
| 1287 |
+
" train_steps.append(step_idx + 1)\n",
|
| 1288 |
+
" train_losses.append(loss)\n",
|
| 1289 |
+
" train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
|
| 1290 |
+
" \n",
|
| 1291 |
+
" # Update plots - direct assignment to FigureWidget data updates the plot automatically\n",
|
| 1292 |
+
" fig.data[0].x = train_steps\n",
|
| 1293 |
+
" fig.data[0].y = train_losses\n",
|
| 1294 |
+
" fig.data[2].x = train_steps\n",
|
| 1295 |
+
" fig.data[2].y = train_pearson_scores\n",
|
| 1296 |
+
" \n",
|
| 1297 |
+
" print(f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
|
| 1298 |
+
" f\"Loss: {loss:.4f} | \"\n",
|
| 1299 |
+
" f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
|
| 1300 |
+
" f\"LR: {current_lr:.2e}\")\n",
|
| 1301 |
" train_metrics.reset()\n",
|
| 1302 |
" \n",
|
| 1303 |
" # Validation\n",
|
| 1304 |
+
" if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
|
| 1305 |
+
" print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
|
| 1306 |
" val_metrics.reset()\n",
|
| 1307 |
" model.eval()\n",
|
| 1308 |
" \n",
|
| 1309 |
+
" val_batch_losses = []\n",
|
| 1310 |
" for val_batch in val_loader:\n",
|
| 1311 |
+
" val_loss = validation_step(model, val_batch, val_metrics)\n",
|
| 1312 |
+
" val_batch_losses.append(val_loss)\n",
|
|
|
|
|
|
|
| 1313 |
" \n",
|
| 1314 |
" # Print validation metrics\n",
|
| 1315 |
" val_metrics_dict = val_metrics.compute()\n",
|
| 1316 |
+
" val_loss_mean = np.mean(val_batch_losses)\n",
|
| 1317 |
+
" val_pearson_mean = val_metrics_dict['mean/pearson']\n",
|
| 1318 |
+
" \n",
|
| 1319 |
+
" # Track validation metrics\n",
|
| 1320 |
+
" val_steps.append(step_idx + 1)\n",
|
| 1321 |
+
" val_losses.append(val_loss_mean)\n",
|
| 1322 |
+
" val_pearson_scores.append(val_pearson_mean)\n",
|
| 1323 |
+
" \n",
|
| 1324 |
+
" # Update plots with validation data - direct assignment updates the plot automatically\n",
|
| 1325 |
+
" fig.data[1].x = val_steps\n",
|
| 1326 |
+
" fig.data[1].y = val_losses\n",
|
| 1327 |
+
" fig.data[3].x = val_steps\n",
|
| 1328 |
+
" fig.data[3].y = val_pearson_scores\n",
|
| 1329 |
+
" \n",
|
| 1330 |
+
" print(f\" Validation Loss: {val_loss_mean:.4f}\")\n",
|
| 1331 |
+
" print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
|
| 1332 |
" for track_name in config[\"bigwig_file_ids\"]:\n",
|
| 1333 |
+
" print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
|
| 1334 |
" \n",
|
| 1335 |
" model.train() # Back to training mode\n",
|
| 1336 |
"\n",
|
| 1337 |
+
"print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")"
|
| 1338 |
]
|
| 1339 |
},
|
| 1340 |
{
|
|
|
|
| 1346 |
},
|
| 1347 |
{
|
| 1348 |
"cell_type": "code",
|
| 1349 |
+
"execution_count": 24,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1350 |
"metadata": {},
|
| 1351 |
"outputs": [
|
| 1352 |
{
|
| 1353 |
"name": "stdout",
|
| 1354 |
"output_type": "stream",
|
| 1355 |
"text": [
|
| 1356 |
+
"Running test evaluation with 12 steps (100 samples)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1357 |
"\n",
|
| 1358 |
"==================================================\n",
|
| 1359 |
"Test Set Results\n",
|
| 1360 |
"==================================================\n",
|
| 1361 |
"\n",
|
| 1362 |
+
"Metrics:\n",
|
| 1363 |
+
" Mean Pearson: 0.1787\n",
|
| 1364 |
+
" ENCFF884LDL/pearson: 0.1787\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1365 |
]
|
| 1366 |
}
|
| 1367 |
],
|
| 1368 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1369 |
"# Calculate number of test steps (based on deepspeed pipeline)\n",
|
| 1370 |
"num_test_samples = len(test_dataset)\n",
|
| 1371 |
"num_test_steps = num_test_samples // config[\"batch_size\"]\n",
|
|
|
|
| 1372 |
"print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
|
| 1373 |
"\n",
|
| 1374 |
"# Set model to eval mode\n",
|
| 1375 |
"model.eval()\n",
|
| 1376 |
"\n",
|
| 1377 |
+
"for test_batch in test_loader: \n",
|
|
|
|
| 1378 |
"\n",
|
| 1379 |
+
" _ = validation_step( \n",
|
| 1380 |
+
" model, \n",
|
| 1381 |
+
" test_batch, \n",
|
| 1382 |
+
" test_metrics,\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1383 |
" )\n",
|
| 1384 |
+
" \n",
|
| 1385 |
"# Compute final test metrics\n",
|
| 1386 |
"test_metrics_dict = test_metrics.compute()\n",
|
|
|
|
| 1387 |
"print(\"\\n\" + \"=\"*50)\n",
|
| 1388 |
"print(\"Test Set Results\")\n",
|
| 1389 |
"print(\"=\"*50)\n",
|
| 1390 |
+
"print(\"\\nMetrics:\")\n",
|
| 1391 |
+
"print(f\" Mean Pearson: {test_metrics_dict['mean/pearson']:.4f}\")\n",
|
| 1392 |
+
"for track_name in config[\"bigwig_file_ids\"]: \n",
|
| 1393 |
+
" print(f\" {track_name}/pearson: {test_metrics_dict[f'{track_name}/pearson']:.4f}\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1394 |
]
|
| 1395 |
+
},
|
| 1396 |
+
{
|
| 1397 |
+
"cell_type": "code",
|
| 1398 |
+
"execution_count": null,
|
| 1399 |
+
"metadata": {},
|
| 1400 |
+
"outputs": [],
|
| 1401 |
+
"source": []
|
| 1402 |
}
|
| 1403 |
],
|
| 1404 |
"metadata": {
|