ybornachot committed
Commit b6b1c80 · 1 Parent(s): e712656

fix: notebook simplification

Files changed (1):
  1. notebooks/03_fine_tuning.ipynb +598 -523

notebooks/03_fine_tuning.ipynb CHANGED
@@ -16,7 +16,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -28,14 +28,14 @@
  },
  {
  "cell_type": "code",
- "execution_count": 29,
+ "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
  "# 0. Imports\n",
  "import random\n",
  "import functools\n",
- "from typing import List, Dict, Optional, Callable\n",
+ "from typing import List, Dict, Callable\n",
  "import os\n",
  "import subprocess\n",
  "\n",
@@ -48,19 +48,50 @@
  "import numpy as np\n",
  "import pyBigWig\n",
  "from pyfaidx import Fasta\n",
- "from torchmetrics import PearsonCorrCoef"
+ "from torchmetrics import PearsonCorrCoef\n",
+ "import plotly.graph_objects as go\n",
+ "from plotly.subplots import make_subplots\n",
+ "from IPython.display import display"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "# 1. Configuration setup"
+ "# 1. Configuration setup\n",
+ "\n",
+ "## Configuration Parameters\n",
+ "\n",
+ "### Model\n",
+ "- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
+ "\n",
+ "### Data\n",
+ "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
+ "- **`fasta_url`**: URL to download reference genome FASTA file\n",
+ "- **`bigwig_url_list`**: List of URLs for bigWig track files to download\n",
+ "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
+ "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
+ "\n",
+ "### Training\n",
+ "- **`batch_size`**: Number of samples per batch\n",
+ "- **`learning_rate`**: Constant learning rate for optimizer\n",
+ "- **`weight_decay`**: L2 regularization coefficient for optimizer\n",
+ "- **`num_steps_training`**: Total number of training steps\n",
+ "- **`log_every_n_steps`**: Log training metrics every N steps\n",
+ "- **`validate_every_n_steps`**: Run validation every N steps\n",
+ "\n",
+ "### Validation\n",
+ "- **`num_validation_samples`**: Number of samples to use for validation set\n",
+ "\n",
+ "### General\n",
+ "- **`seed`**: Random seed for reproducibility\n",
+ "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
+ "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 30,
+ "execution_count": 15,
  "metadata": {},
  "outputs": [
  {
@@ -74,37 +105,32 @@
  "source": [
  "config = {\n",
  " # Model\n",
- " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\", # HuggingFace model name/identifier\n",
+ " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\",\n",
  " \n",
  " # Data\n",
- " \"data_cache_dir\": \"./data\", # Directory where downloaded data files (FASTA, bigWig) will be stored\n",
- " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\", # URL to download reference genome FASTA file\n",
- " \"bigwig_url_list\": [\"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"], # List of URLs for bigWig track files to download\n",
- " \"sequence_length\": 1_024, # Length of input sequences in base pairs (bp)\n",
- " \"keep_target_center_fraction\": 0.375, # Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
+ " \"data_cache_dir\": \"./data\",\n",
+ " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
+ " \"bigwig_url_list\": [\n",
+ " \"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"\n",
+ " ],\n",
+ " \"sequence_length\": 1_024,\n",
+ " \"keep_target_center_fraction\": 0.375,\n",
  " \n",
  " # Training\n",
- " \"batch_size\": 2, # Number of samples per batch\n",
- " \"learning_rate\": 1e-5, # Constant learning rate for optimizer\n",
- " \"weight_decay\": 0.01, # L2 regularization coefficient for optimizer\n",
- " \n",
- " \"num_tokens_training\": 131_072, # Total training tokens budget (determines total training steps)\n",
- " \"num_tokens_per_update\": 4_096, # Target tokens per optimizer update (batch_size * seq_len * grad_accum)\n",
- " \"num_tokens_per_log\": 8_192, # Tokens between training logs (how often to print metrics)\n",
- " \"num_tokens_per_validation\": 16_384, # Tokens between validation runs (how often to evaluate on validation set)\n",
+ " \"batch_size\": 8,\n",
+ " \"num_steps_training\": 1000,\n",
+ " \"log_every_n_steps\": 10,\n",
+ " \"learning_rate\": 1e-5,\n",
+ " \"weight_decay\": 0.01,\n",
  " \n",
  " # Validation\n",
- " \"num_validation_samples\": 10, # Number of samples to use for validation set\n",
- " \n",
- " # Loss\n",
- " \"bigwig_loss_weight\": 1.0, # Weight multiplier for bigwig prediction loss\n",
- " \"bigwig_scalar_loss_function\": \"poisson-multinomial\", # Loss function type for bigwig tracks\n",
- " \"bigwig_shape_loss_coefficient\": 5.0, # Coefficient balancing shape loss vs scale loss in poisson-multinomial loss\n",
+ " \"validate_every_n_steps\": 50,\n",
+ " \"num_validation_samples\": 100,\n",
  " \n",
  " # General\n",
- " \"seed\": 42, # Random seed for reproducibility\n",
- " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\", # Device to run training on (\"cuda\" or \"cpu\")\n",
- " \"num_workers\": 0, # Number of worker processes for DataLoader (0 = single-threaded)\n",
+ " \"seed\": 42,\n",
+ " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
+ " \"num_workers\": 0,\n",
  "}\n",
  "\n",
  "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
@@ -216,7 +242,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 31,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -236,7 +262,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 32,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -304,7 +330,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 33,
+ "execution_count": 5,
  "metadata": {},
  "outputs": [
  {
@@ -335,6 +361,49 @@
  "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
  ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Scaling functions for targets\n",
+ "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
+ " \"\"\"\n",
+ " Get track means for normalization.\n",
+ " For now, return dummy values. In real pipeline, this loads from metadata.\n",
+ " \"\"\"\n",
+ " # Dummy values - in real pipeline, this would load from actual metadata\n",
+ " return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n",
+ "\n",
+ "\n",
+ "def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
+ " \"\"\"\n",
+ " Build a scaling function based on track means.\n",
+ " \"\"\"\n",
+ " # Load track means\n",
+ " track_means_np = get_track_means(bigwig_file_ids)\n",
+ " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
+ " \n",
+ " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
+ " \"\"\"\n",
+ " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
+ " \"\"\"\n",
+ " # Move constants to correct device then normalize\n",
+ " means = track_means.to(x.device)\n",
+ " scaled = x / means\n",
+ "\n",
+ " # Smooth clipping: if > 10, apply formula\n",
+ " clipped = torch.where(\n",
+ " scaled > 10.0,\n",
+ " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
+ " scaled,\n",
+ " )\n",
+ " return clipped\n",
+ " \n",
+ " return transform_fn"
+ ]
+ },
  {
  "cell_type": "markdown",
  "metadata": {},
@@ -344,7 +413,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 34,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -388,6 +457,7 @@
  " sequence_length: int,\n",
  " num_samples: int,\n",
  " tokenizer: AutoTokenizer,\n",
+ " transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
  " keep_target_center_fraction: float = 1.0,\n",
  " num_tracks: int = 1,\n",
  " ):\n",
@@ -401,6 +471,7 @@
  " self.sequence_length = sequence_length\n",
  " self.num_samples = num_samples\n",
  " self.tokenizer = tokenizer\n",
+ " self.transform_fn = transform_fn\n",
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
  " self.num_tracks = num_tracks\n",
  " self.chroms = chroms\n",
@@ -465,6 +536,9 @@
  " target_length = seq_len - 2 * target_offset\n",
  " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
  "\n",
+ " # Apply scaling to targets\n",
+ " bigwig_targets = self.transform_fn(bigwig_targets)\n",
+ "\n",
  " sample = {\n",
  " \"tokens\": tokens,\n",
  " \"bigwig_targets\": bigwig_targets,\n",
@@ -477,7 +551,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 35,
+ "execution_count": 16,
  "metadata": {},
  "outputs": [
  {
@@ -485,18 +559,22 @@
  "output_type": "stream",
  "text": [
  "Train samples: 100\n",
- "Val samples: 10\n",
- "Test samples: 10\n"
+ "Val samples: 100\n",
+ "Test samples: 100\n"
  ]
  }
  ],
  "source": [
+ "# Create scaling function\n",
+ "transform_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n",
+ "\n",
  "create_dataset_fn = functools.partial(\n",
  " GenomeBigWigDataset,\n",
  " fasta_path=fasta_path,\n",
  " bigwig_path_list=bigwig_path_list,\n",
  " sequence_length=config[\"sequence_length\"],\n",
  " tokenizer=tokenizer,\n",
+ " transform_fn=transform_fn,\n",
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
  " num_tracks=len(config[\"bigwig_file_ids\"]),\n",
  ")\n",
@@ -552,21 +630,18 @@
  },
  {
  "cell_type": "code",
- "execution_count": 36,
+ "execution_count": 17,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Gradient accumulation steps: 2\n",
- "Effective batch size: 4\n",
- "Effective tokens per update: 4096\n",
- "\n",
- "Training constants:\n",
- " Total training steps: 32\n",
- " Log training metrics every: 2 steps\n",
- " Run validation every: 4 steps\n",
+ "Training configuration:\n",
+ " Batch size: 8\n",
+ " Total training steps: 1000\n",
+ " Log metrics every: 10 steps\n",
+ " Validate every: 50 steps\n",
  "\n",
  "Optimizer setup:\n",
  " Learning rate: 1e-05\n"
@@ -574,37 +649,12 @@
  }
  ],
  "source": [
- "# Calculate gradient accumulation steps and effective batch size\n",
- "num_devices = 1 # Single device for now\n",
- "sequence_length = config[\"sequence_length\"]\n",
- "batch_size = config[\"batch_size\"]\n",
- "\n",
- "# Calculate gradient accumulation steps\n",
- "num_accumulation_gradient = max(1, int(config[\"num_tokens_per_update\"] // (batch_size * num_devices * sequence_length)))\n",
- "\n",
- "# Calculate effective batch size and tokens per update\n",
- "effective_batch_size = batch_size * num_devices * num_accumulation_gradient\n",
- "effective_num_tokens_per_update = effective_batch_size * sequence_length\n",
- "\n",
- "print(f\"Gradient accumulation steps: {num_accumulation_gradient}\")\n",
- "print(f\"Effective batch size: {effective_batch_size}\")\n",
- "print(f\"Effective tokens per update: {effective_num_tokens_per_update}\")\n",
- "\n",
- "# Compute logging constants (based on deepspeed pipeline: compute_logging_constants)\n",
- "num_train_samples = len(train_dataset)\n",
- "num_tokens_per_update = effective_num_tokens_per_update # Same as effective_num_tokens_per_update\n",
- "\n",
- "# Total training steps based on token budget\n",
- "num_steps_training = config[\"num_tokens_training\"] // num_tokens_per_update\n",
- "\n",
- "# Steps for logging and validation\n",
- "log_train_step = int(np.ceil(config[\"num_tokens_per_log\"] / num_tokens_per_update))\n",
- "log_validation_step = int(np.ceil(config[\"num_tokens_per_validation\"] / num_tokens_per_update))\n",
- "\n",
- "print(f\"\\nTraining constants:\")\n",
- "print(f\" Total training steps: {num_steps_training}\")\n",
- "print(f\" Log training metrics every: {log_train_step} steps\")\n",
- "print(f\" Run validation every: {log_validation_step} steps\")\n",
+ "# Training setup\n",
+ "print(f\"Training configuration:\")\n",
+ "print(f\" Batch size: {config[\"batch_size\"]}\")\n",
+ "print(f\" Total training steps: {config[\"num_steps_training\"]}\")\n",
+ "print(f\" Log metrics every: {config[\"log_every_n_steps\"]} steps\")\n",
+ "print(f\" Validate every: {config[\"validate_every_n_steps\"]} steps\")\n",
  "\n",
  "# Setup optimizer\n",
  "optimizer = AdamW(\n",
@@ -626,87 +676,62 @@
  },
  {
  "cell_type": "code",
- "execution_count": 37,
+ "execution_count": 18,
  "metadata": {},
  "outputs": [],
  "source": [
  "class TracksMetrics:\n",
- " \"\"\"Simple metrics tracker for tracks prediction with both scaled and raw metrics.\"\"\"\n",
+ " \"\"\"Simple metrics tracker for tracks prediction.\"\"\"\n",
  " \n",
  " def __init__(self, track_names: List[str]):\n",
  " self.track_names = track_names\n",
  " self.num_tracks = len(track_names)\n",
- " # Scaled metrics: comparing scaled targets with scaled predictions\n",
- " self.pearson_metrics_scaled = [\n",
- " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
- " ]\n",
- " # Raw metrics: comparing raw targets with unscaled predictions\n",
- " self.pearson_metrics_raw = [\n",
+ " # Metrics: comparing scaled targets with scaled predictions\n",
+ " self.pearson_metrics = [\n",
  " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n",
  " ]\n",
  " self.losses = []\n",
  " \n",
  " def reset(self):\n",
- " for metric in self.pearson_metrics_scaled:\n",
- " metric.reset()\n",
- " for metric in self.pearson_metrics_raw:\n",
+ " for metric in self.pearson_metrics:\n",
  " metric.reset()\n",
  " self.losses = []\n",
  " \n",
  " def update(\n",
  " self, \n",
- " predictions_scaled: torch.Tensor, \n",
- " targets_scaled: torch.Tensor,\n",
- " predictions_raw: torch.Tensor,\n",
- " targets_raw: torch.Tensor,\n",
+ " predictions: torch.Tensor, \n",
+ " targets: torch.Tensor,\n",
  " loss: float\n",
  " ):\n",
  " \"\"\"\n",
- " Update both scaled and raw metrics.\n",
+ " Update metrics.\n",
  " Args:\n",
- " predictions_scaled: (batch, seq_len, num_tracks) - scaled predictions\n",
- " targets_scaled: (batch, seq_len, num_tracks) - scaled targets\n",
- " predictions_raw: (batch, seq_len, num_tracks) - raw/unscaled predictions\n",
- " targets_raw: (batch, seq_len, num_tracks) - raw targets\n",
+ " predictions: (batch, seq_len, num_tracks)\n",
+ " targets: (batch, seq_len, num_tracks)\n",
  " loss: scalar loss value\n",
  " \"\"\"\n",
  " # Flatten batch and sequence dimensions\n",
- " pred_scaled_flat = predictions_scaled.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
- " target_scaled_flat = targets_scaled.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
- " pred_raw_flat = predictions_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
- " target_raw_flat = targets_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
- " \n",
- " # Update scaled metrics\n",
- " for i, metric in enumerate(self.pearson_metrics_scaled):\n",
- " metric.update(pred_scaled_flat[:, i], target_scaled_flat[:, i])\n",
+ " pred_flat = predictions.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
+ " target_flat = targets.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
  " \n",
- " # Update raw metrics\n",
- " for i, metric in enumerate(self.pearson_metrics_raw):\n",
- " metric.update(pred_raw_flat[:, i], target_raw_flat[:, i])\n",
+ " # Update metrics\n",
+ " for i, metric in enumerate(self.pearson_metrics):\n",
+ " metric.update(pred_flat[:, i], target_flat[:, i])\n",
  " \n",
  " self.losses.append(loss)\n",
  " \n",
  " def compute(self) -> Dict[str, float]:\n",
- " \"\"\"Compute and return all metrics (both scaled and raw).\"\"\"\n",
+ " \"\"\"Compute and return all metrics.\"\"\"\n",
  " metrics_dict = {}\n",
  " \n",
- " # Scaled metrics: per-track Pearson correlations\n",
- " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_scaled)):\n",
+ " # Per-track Pearson correlations\n",
+ " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics)):\n",
  " corr = metric.compute().item()\n",
- " metrics_dict[f\"metrics_scaled/{track_name}/pearson\"] = corr\n",
+ " metrics_dict[f\"{track_name}/pearson\"] = corr\n",
  " \n",
- " # Scaled metrics: mean Pearson correlation\n",
- " correlations_scaled = [metric.compute().item() for metric in self.pearson_metrics_scaled]\n",
- " metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n",
- " \n",
- " # Raw metrics: per-track Pearson correlations\n",
- " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_raw)):\n",
- " corr = metric.compute().item()\n",
- " metrics_dict[f\"metrics_raw/{track_name}/pearson\"] = corr\n",
- " \n",
- " # Raw metrics: mean Pearson correlation\n",
- " correlations_raw = [metric.compute().item() for metric in self.pearson_metrics_raw]\n",
- " metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n",
+ " # Mean Pearson correlation\n",
+ " correlations = [metric.compute().item() for metric in self.pearson_metrics]\n",
+ " metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
  " \n",
  " # Mean loss\n",
  " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
@@ -716,7 +741,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 38,
+ "execution_count": 19,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -729,148 +754,12 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "# 7. Scaling functions setup (copied from pipeline)"
+ "# 7. Loss functions"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 39,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Scaling functions created\n"
- ]
- }
- ],
- "source": [
- "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n",
- " \"\"\"\n",
- " Get track means for normalization.\n",
- " For now, return dummy values. In real pipeline, this loads from metadata.\n",
- " \"\"\"\n",
- " # Dummy values - in real pipeline, this would load from actual metadata\n",
- " return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n",
- "\n",
- "\n",
- "def get_rna_seq_track_ids(bigwig_file_ids: List[str]) -> List[int]:\n",
- " \"\"\"\n",
- " Get RNA-seq track indices.\n",
- " For now, return empty list. In real pipeline, this identifies RNA-seq tracks.\n",
- " \"\"\"\n",
- " # Dummy - in real pipeline, this would identify RNA-seq tracks\n",
- " return []\n",
- "\n",
- "\n",
- "def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
- " \"\"\"\n",
- " Build a scaling function based on track means and RNA-seq squashing.\n",
- " Copied from the supervised tracks pipeline.\n",
- " \"\"\"\n",
- " # Load track means\n",
- " track_means_np = get_track_means(bigwig_file_ids)\n",
- " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
- " \n",
- " # Get which tracks use squashing\n",
- " rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n",
- " apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n",
- " if len(rna_ids) > 0:\n",
- " apply_squashing[rna_ids] = True\n",
- " \n",
- " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
- " \"\"\"\n",
- " x: torch.Tensor, shape (batch, seq_len, num_tracks)\n",
- " \"\"\"\n",
- " device = x.device\n",
- " \n",
- " # Move constants to correct device\n",
- " means = track_means.to(device)\n",
- " squash_mask = apply_squashing.to(device)\n",
- " \n",
- " # Normalize\n",
- " scaled = x / means\n",
- " \n",
- " # Power squashing where needed\n",
- " squashed = torch.where(\n",
- " squash_mask.view(1, 1, -1),\n",
- " scaled.pow(0.75),\n",
- " scaled,\n",
- " )\n",
- " \n",
- " # Smooth clipping: if > 10, apply formula\n",
- " clipped = torch.where(\n",
- " squashed > 10.0,\n",
- " 2.0 * torch.sqrt(squashed * 10.0) - 10.0,\n",
- " squashed,\n",
- " )\n",
- " \n",
- " return clipped\n",
- " \n",
- " return transform_fn\n",
- "\n",
- "\n",
- "def create_predictions_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
- " \"\"\"\n",
- " Inverse scaling function to apply on predictions before computing metrics.\n",
- " Copied from the supervised tracks pipeline.\n",
- " \"\"\"\n",
- " # Load means\n",
- " track_means_np = get_track_means(bigwig_file_ids)\n",
- " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n",
- " \n",
- " # RNA-seq mask\n",
- " rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n",
- " apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n",
- " if len(rna_ids) > 0:\n",
- " apply_squashing[rna_ids] = True\n",
- " \n",
- " def inverse_transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
- " \"\"\"\n",
- " x: torch.Tensor, shape (batch, seq_len, num_tracks)\n",
- " \"\"\"\n",
- " device = x.device\n",
- " means = track_means.to(device)\n",
- " squash_mask = apply_squashing.to(device)\n",
- " \n",
- " # Undo clipping\n",
- " unclipped = torch.where(\n",
- " x > 10.0,\n",
- " (x + 10.0).pow(2) / (4 * 10.0),\n",
- " x,\n",
- " )\n",
- " \n",
- " # Undo squashing\n",
- " unsquashed = torch.where(\n",
- " squash_mask.view(1, 1, -1),\n",
- " unclipped.pow(1.0 / 0.75),\n",
- " unclipped,\n",
- " )\n",
- " \n",
- " # Undo normalization\n",
- " return unsquashed * means\n",
- " \n",
- " return inverse_transform_fn\n",
- "\n",
- "\n",
- "# Create scaling functions\n",
- "scale_targets_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n",
- "scale_predictions_fn = create_predictions_scaling_fn(config[\"bigwig_file_ids\"])\n",
- "\n",
- "print(\"Scaling functions created\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 8. Loss functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 40,
+ "execution_count": 20,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -887,49 +776,24 @@
  "def poisson_multinomial_loss(\n",
  " logits: torch.Tensor,\n",
  " targets: torch.Tensor,\n",
- " mask: torch.Tensor | None = None,\n",
  " shape_loss_coefficient: float = 5.0,\n",
  " epsilon: float = 1e-7,\n",
  ") -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n",
  " \"\"\"\n",
  " Regression loss for bigwig tracks (MSE, Poisson, or Poisson-Multinomial).\n",
  " \"\"\"\n",
- " scale_loss, shape_loss = None, None\n",
- " \n",
- " if mask is None:\n",
- " mask = torch.ones_like(targets, dtype=torch.float32, device=targets.device)\n",
- " else:\n",
- " mask = mask.float()\n",
- " \n",
- " mask_sum = mask.sum() + epsilon\n",
- " masked_logits = logits * mask\n",
- " masked_targets = targets * mask\n",
  "\n",
  " # Scale loss\n",
- " mask_sum_per_track_per_seq = mask.sum(dim=1) # (batch, num_tracks)\n",
- " mask_per_sequence = mask_sum_per_track_per_seq > 0.0 # (batch, num_tracks)\n",
- " \n",
- " sum_pred = masked_logits.sum(dim=1) # (batch, num_tracks)\n",
- " sum_true = masked_targets.sum(dim=1) # (batch, num_tracks)\n",
- " \n",
+ " sum_pred = logits.sum(dim=1) # (batch, num_tracks)\n",
+ " sum_true = targets.sum(dim=1) # (batch, num_tracks)\n",
  " scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon)\n",
- " scale_loss = scale_loss / (mask_sum_per_track_per_seq + epsilon)\n",
- " \n",
- " if mask_per_sequence.any():\n",
- " scale_loss_filtered = scale_loss[mask_per_sequence]\n",
- " scale_loss = scale_loss_filtered.mean()\n",
- " else:\n",
- " scale_loss = torch.tensor(0.0, device=targets.device, dtype=targets.dtype)\n",
+ " scale_loss = scale_loss.mean()\n",
  " \n",
  " # Shape loss\n",
- " predicted_counts = masked_logits + (epsilon * mask)\n",
- " masked_targets_with_epsilon = masked_targets + (epsilon * mask)\n",
- " \n",
- " denom = predicted_counts.sum(dim=1, keepdim=True) + epsilon\n",
- " p_pred = predicted_counts / denom\n",
- " \n",
+ " denom = logits.sum(dim=1, keepdim=True) + epsilon\n",
+ " p_pred = logits / denom\n",
  " pl_pred = safe_for_grad_log_torch(p_pred)\n",
- " shape_loss = -(masked_targets_with_epsilon * pl_pred).sum() / mask_sum\n",
+ " shape_loss = -(targets * pl_pred).mean()\n",
  " \n",
  " # Combine\n",
  " loss = shape_loss + scale_loss / shape_loss_coefficient\n",
@@ -941,57 +805,42 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "# 9. Training loop"
+ "# 8. Training loop"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 41,
+ "execution_count": 21,
  "metadata": {},
  "outputs": [],
  "source": [
  "def train_step(\n",
  " model: nn.Module,\n",
  " batch: Dict[str, torch.Tensor],\n",
- " optimizer: torch.optim.Optimizer,\n",
- " scale_targets_fn: Callable,\n",
- " config: Dict,\n",
- " num_accumulation_steps: int = 1,\n",
  ") -> float:\n",
- " \"\"\"Single training step with gradient accumulation support.\"\"\"\n",
+ " \"\"\"Single training step.\"\"\"\n",
  " tokens = batch[\"tokens\"].to(device)\n",
- " bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n",
+ " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
  " \n",
  " # Forward pass\n",
  " outputs = model(tokens=tokens)\n",
- " bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n",
- " \n",
- " # Scale targets\n",
- " scaled_targets = scale_targets_fn(bigwig_targets)\n",
+ " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
  " \n",
  " # Compute loss\n",
  " loss, _, _ = poisson_multinomial_loss(\n",
  " logits=bigwig_logits,\n",
- " targets=scaled_targets,\n",
- " shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n",
+ " targets=bigwig_targets,\n",
  " )\n",
  " \n",
- " # Scale loss by accumulation steps (for gradient accumulation)\n",
- " loss = loss / num_accumulation_steps\n",
- " \n",
- " # Backward pass (accumulate gradients)\n",
+ " # Backward pass\n",
  " loss.backward()\n",
- " \n",
- " return loss.item() * num_accumulation_steps # Return unscaled loss for logging\n",
+ " return loss.item()\n",
  "\n",
  "\n",
  "def validation_step(\n",
  " model: nn.Module,\n",
  " batch: Dict[str, torch.Tensor],\n",
- " scale_targets_fn: Callable,\n",
- " scale_predictions_fn: Callable,\n",
  " metrics: TracksMetrics,\n",
- " config: Dict,\n",
  ") -> float:\n",
  " \"\"\"Single validation step.\"\"\"\n",
  " model.eval()\n",
@@ -1004,35 +853,32 @@
  " outputs = model(tokens=tokens)\n",
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
  " \n",
- " # Scale targets for loss computation\n",
- " scaled_targets = scale_targets_fn(bigwig_targets)\n",
- " \n",
- " # Compute loss (using scaled targets)\n",
+ " # Compute loss\n",
  " loss, _, _ = poisson_multinomial_loss(\n",
  " logits=bigwig_logits,\n",
- " targets=scaled_targets,\n",
- " shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n",
+ " targets=bigwig_targets,\n",
  " )\n",
  " \n",
- " # Scale predictions back to original space for metrics\n",
- " # (predictions are in scaled space, need to inverse transform)\n",
- " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
- " \n",
- " # Update metrics (using original space targets and predictions)\n",
+ " # Update metrics\n",
  " metrics.update(\n",
- " predictions_scaled=bigwig_logits,\n",
- " targets_scaled=scaled_targets,\n",
- " predictions_raw=unscaled_predictions,\n",
- " targets_raw=bigwig_targets,\n",
+ " predictions=bigwig_logits,\n",
+ " targets=bigwig_targets,\n",
  " loss=loss.item()\n",
  " )\n",
  " \n",
  " return loss.item()"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Interactive plotting is temporary for debug"
+ ]
+ },
  {
  "cell_type": "code",
- "execution_count": 42,
+ "execution_count": 22,
  "metadata": {},
  "outputs": [
  {
@@ -1040,163 +886,455 @@
  "output_type": "stream",
  "text": [
  "Starting training...\n",
- "Training for 32 steps with 2 gradient accumulation steps\n",
+ "Training for 1000 steps\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5935c992adb7428bac8de1aa6873dd7e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "FigureWidget({\n",
+ " 'data': [{'line': {'color': 'blue'},\n",
+ " 'mode': 'lines+markers',\n",
+ " 'name': 'Train Loss',\n",
+ " 'type': 'scatter',\n",
+ " 'uid': '5424e4af-13b6-48c8-a367-8aa145c3a9db',\n",
+ " 'x': [],\n",
+ " 'xaxis': 'x',\n",
+ " 'y': [],\n",
+ " 'yaxis': 'y'},\n",
+ " {'line': {'color': 'red'},\n",
+ " 'mode': 'lines+markers',\n",
+ " 'name': 'Val Loss',\n",
+ " 'type': 'scatter',\n",
+ " 'uid': 'fe995660-5f01-4c12-9d7d-9ed19ddee785',\n",
+ " 'x': [],\n",
+ " 'xaxis': 'x',\n",
+ " 'y': [],\n",
+ " 'yaxis': 'y'},\n",
+ " {'line': {'color': 'green'},\n",
+ " 'mode': 'lines+markers',\n",
+ " 'name': 'Train Pearson',\n",
+ " 'type': 'scatter',\n",
+ " 'uid': '8453b45b-4613-41bc-a46b-ac59ba9e6f97',\n",
+ " 'x': [],\n",
+ " 'xaxis': 'x2',\n",
+ " 'y': [],\n",
+ " 'yaxis': 'y2'},\n",
+ " {'line': {'color': 'orange'},\n",
+ " 'mode': 'lines+markers',\n",
+ " 'name': 'Val Pearson',\n",
+ " 'type': 'scatter',\n",
+ " 'uid': '0887ea97-abf9-4fcf-8ea8-c638dc153a4d',\n",
+ " 'x': [],\n",
+ " 'xaxis': 'x2',\n",
+ " 'y': [],\n",
+ " 'yaxis': 'y2'}],\n",
+ " 'layout': {'annotations': [{'font': {'size': 16},\n",
+ " 'showarrow': False,\n",
+ " 'text': 'Loss',\n",
+ " 'x': 0.2125,\n",
+ " 'xanchor': 'center',\n",
+ " 'xref': 'paper',\n",
+ " 'y': 1.0,\n",
+ " 'yanchor': 'bottom',\n",
+ " 'yref': 'paper'},\n",
+ " {'font': {'size': 16},\n",
+ " 'showarrow': False,\n",
+ " 'text': 'Mean Pearson Correlation',\n",
+ " 'x': 0.7875,\n",
+ " 'xanchor': 'center',\n",
+ " 'xref': 'paper',\n",
+ " 'y': 1.0,\n",
+ " 'yanchor': 'bottom',\n",
+ " 'yref': 'paper'}],\n",
+ " 'height': 800,\n",
+ " 'showlegend': True,\n",
+ " 'template': '...',\n",
+ " 'title': {'text': 'Training'},\n",
+ " 'width': 1600,\n",
+ " 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.425], 'title': {'text': 'Step'}},\n",
+ " 'xaxis2': {'anchor': 'y2', 'domain': [0.575, 1.0], 'title': {'text': 'Step'}},\n",
+ " 'yaxis': {'anchor': 'x', 'domain': [0.0, 1.0], 'title': {'text': 'Loss'}},\n",
+ " 'yaxis2': {'anchor': 'x2', 'domain': [0.0, 1.0], 'title': {'text': 'Pearson Correlation'}}}\n",
+ "})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning:\n",
+ "\n",
+ "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
+ "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Step 10/1000 | Loss: 0.2374 | Mean Pearson: 0.0382 | LR: 1.00e-05\n",
+ "Step 20/1000 | Loss: 2.2259 | Mean Pearson: -0.0884 | LR: 1.00e-05\n",
+ "Step 30/1000 | Loss: 20.0122 | Mean Pearson: 0.1379 | LR: 1.00e-05\n",
+ "Step 40/1000 | Loss: 9.6938 | Mean Pearson: -0.1497 | LR: 1.00e-05\n",
+ "Step 50/1000 | Loss: -1.8435 | Mean Pearson: -0.1875 | LR: 1.00e-05\n",
+ "\n",
+ "Running validation at step 50...\n",
+ " Validation Loss: 11.5599\n",
+ " Validation Mean Pearson: -0.1576\n",
+ " ENCFF884LDL/pearson: -0.1576\n",
+ "Step 60/1000 | Loss: 1.4427 | Mean Pearson: 0.2841 | LR: 1.00e-05\n",
+ "Step 70/1000 | Loss: -3.4037 | Mean Pearson: -0.1362 | LR: 1.00e-05\n",
+ "Step 80/1000 | Loss: 9.0958 | Mean Pearson: -0.1319 | LR: 1.00e-05\n",
+ "Step 90/1000 | Loss: -7.8433 | Mean Pearson: -0.0576 | LR: 1.00e-05\n",
+ "Step 100/1000 | Loss: 7.3503 | Mean Pearson: -0.2150 | LR: 1.00e-05\n",
  "\n",
- "Step 1/32 | Loss: 0.5661 | Mean Pearson: -0.0525 | Tokens: 4,096\n",
+ "Running validation at step 100...\n",
+ " Validation Loss: 22.3383\n",
+ " Validation Mean Pearson: -0.2867\n",
+ " ENCFF884LDL/pearson: -0.2867\n",
+ "Step 110/1000 | Loss: -8.1600 | Mean Pearson: -0.1616 | LR: 1.00e-05\n",
+ "Step 120/1000 | Loss: -0.8743 | Mean Pearson: -0.1318 | LR: 1.00e-05\n",
+ "Step 130/1000 | Loss: -2.9825 | Mean Pearson: -0.0480 | LR: 1.00e-05\n",
+ "Step 140/1000 | Loss: -2.4524 | Mean Pearson: -0.0879 | LR: 1.00e-05\n",
+ "Step 150/1000 | Loss: 3.8818 | Mean Pearson: -0.0907 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 0...\n",
- " Validation Loss: 0.3987\n",
- " Validation Mean Pearson: -0.0426\n",
- " ENCFF884LDL/pearson: -0.0426\n",
- "Step 3/32 | Loss: 0.3825 | Mean Pearson: -0.0112 | Tokens: 12,288\n",
- "Step 5/32 | Loss: 1.1384 | Mean Pearson: -0.0777 | Tokens: 20,480\n",
+ "Running validation at step 150...\n",
+ " Validation Loss: 19.6866\n",
+ " Validation Mean Pearson: -0.2207\n",
+ " ENCFF884LDL/pearson: -0.2207\n",
+ "Step 160/1000 | Loss: -1.0933 | Mean Pearson: -0.1243 | LR: 1.00e-05\n",
+ "Step 170/1000 | Loss: -2.2577 | Mean Pearson: -0.0212 | LR: 1.00e-05\n",
+ "Step 180/1000 | Loss: 0.0738 | Mean Pearson: 0.5643 | LR: 1.00e-05\n",
+ "Step 190/1000 | Loss: -0.1097 | Mean Pearson: 0.0309 | LR: 1.00e-05\n",
+ "Step 200/1000 | Loss: -8.7972 | Mean Pearson: 0.4804 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 4...\n",
- " Validation Loss: 0.4381\n",
- " Validation Mean Pearson: -0.0017\n",
- " ENCFF884LDL/pearson: -0.0017\n",
- "Step 7/32 | Loss: 0.4961 | Mean Pearson: -0.0188 | Tokens: 28,672\n",
- "Step 9/32 | Loss: 0.4903 | Mean Pearson: -0.1522 | Tokens: 36,864\n",
+ "Running validation at step 200...\n",
+ " Validation Loss: -8.8160\n",
+ " Validation Mean Pearson: 0.0912\n",
+ " ENCFF884LDL/pearson: 0.0912\n",
+ "Step 210/1000 | Loss: -2.5429 | Mean Pearson: 0.3908 | LR: 1.00e-05\n",
+ "Step 220/1000 | Loss: -6.8421 | Mean Pearson: 0.4080 | LR: 1.00e-05\n",
+ "Step 230/1000 | Loss: -4.4312 | Mean Pearson: -0.0400 | LR: 1.00e-05\n",
+ "Step 240/1000 | Loss: -11.4732 | Mean Pearson: 0.6653 | LR: 1.00e-05\n",
+ "Step 250/1000 | Loss: -9.2648 | Mean Pearson: 0.0539 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 8...\n",
- " Validation Loss: 0.3429\n",
- " Validation Mean Pearson: -0.0997\n",
- " ENCFF884LDL/pearson: -0.0997\n",
- "Step 11/32 | Loss: 0.4597 | Mean Pearson: -0.0199 | Tokens: 45,056\n",
- "Step 13/32 | Loss: 0.6507 | Mean Pearson: -0.0256 | Tokens: 53,248\n",
+ "Running validation at step 250...\n",
+ " Validation Loss: -6.8987\n",
+ " Validation Mean Pearson: 0.0654\n",
+ " ENCFF884LDL/pearson: 0.0654\n",
+ "Step 260/1000 | Loss: -0.6699 | Mean Pearson: 0.0913 | LR: 1.00e-05\n",
+ "Step 270/1000 | Loss: -8.6625 | Mean Pearson: 0.3179 | LR: 1.00e-05\n",
+ "Step 280/1000 | Loss: -11.7691 | Mean Pearson: 0.0004 | LR: 1.00e-05\n",
+ "Step 290/1000 | Loss: -14.1622 | Mean Pearson: 0.0492 | LR: 1.00e-05\n",
+ "Step 300/1000 | Loss: 0.9208 | Mean Pearson: 0.0607 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 12...\n",
- " Validation Loss: 0.3901\n",
- " Validation Mean Pearson: -0.0786\n",
- " ENCFF884LDL/pearson: -0.0786\n",
- "Step 15/32 | Loss: 0.3911 | Mean Pearson: -0.0419 | Tokens: 61,440\n",
- "Step 17/32 | Loss: 0.4202 | Mean Pearson: -0.0883 | Tokens: 69,632\n",
+ "Running validation at step 300...\n",
+ " Validation Loss: -5.0427\n",
+ " Validation Mean Pearson: 0.3464\n",
+ " ENCFF884LDL/pearson: 0.3464\n",
+ "Step 310/1000 | Loss: -1.2881 | Mean Pearson: 0.1696 | LR: 1.00e-05\n",
+ "Step 320/1000 | Loss: -18.6637 | Mean Pearson: 0.0892 | LR: 1.00e-05\n",
+ "Step 330/1000 | Loss: -36.6038 | Mean Pearson: 0.3356 | LR: 1.00e-05\n",
+ "Step 340/1000 | Loss: -2.4984 | Mean Pearson: 0.2305 | LR: 1.00e-05\n",
+ "Step 350/1000 | Loss: -4.7985 | Mean Pearson: 0.0968 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 16...\n",
- " Validation Loss: 0.3626\n",
- " Validation Mean Pearson: -0.0840\n",
- " ENCFF884LDL/pearson: -0.0840\n",
- "Step 19/32 | Loss: 0.3608 | Mean Pearson: -0.1057 | Tokens: 77,824\n",
- "Step 21/32 | Loss: 0.3942 | Mean Pearson: 0.1459 | Tokens: 86,016\n",
+ "Running validation at step 350...\n",
+ " Validation Loss: -13.6500\n",
+ " Validation Mean Pearson: 0.2737\n",
+ " ENCFF884LDL/pearson: 0.2737\n",
+ "Step 360/1000 | Loss: -9.4795 | Mean Pearson: 0.0579 | LR: 1.00e-05\n",
+ "Step 370/1000 | Loss: 0.3531 | Mean Pearson: 0.0240 | LR: 1.00e-05\n",
+ "Step 380/1000 | Loss: -5.7921 | Mean Pearson: 0.4119 | LR: 1.00e-05\n",
+ "Step 390/1000 | Loss: -2.7049 | Mean Pearson: 0.1343 | LR: 1.00e-05\n",
+ "Step 400/1000 | Loss: -32.8422 | Mean Pearson: 0.1545 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 20...\n",
- " Validation Loss: 0.3281\n",
- " Validation Mean Pearson: -0.0667\n",
- " ENCFF884LDL/pearson: -0.0667\n",
- "Step 23/32 | Loss: 0.4090 | Mean Pearson: 0.0540 | Tokens: 94,208\n",
- "Step 25/32 | Loss: 0.5151 | Mean Pearson: -0.0076 | Tokens: 102,400\n",
+ "Running validation at step 400...\n",
+ " Validation Loss: -4.3502\n",
+ " Validation Mean Pearson: 0.3124\n",
+ " ENCFF884LDL/pearson: 0.3124\n",
+ "Step 410/1000 | Loss: -18.9574 | Mean Pearson: 0.0594 | LR: 1.00e-05\n",
+ "Step 420/1000 | Loss: -5.4032 | Mean Pearson: 0.2804 | LR: 1.00e-05\n",
+ "Step 430/1000 | Loss: -0.5171 | Mean Pearson: 0.1835 | LR: 1.00e-05\n",
+ "Step 440/1000 | Loss: -3.4071 | Mean Pearson: 0.0680 | LR: 1.00e-05\n",
+ "Step 450/1000 | Loss: -3.5580 | Mean Pearson: 0.0850 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 24...\n",
- " Validation Loss: 0.2927\n",
- " Validation Mean Pearson: -0.0409\n",
- " ENCFF884LDL/pearson: -0.0409\n",
- "Step 27/32 | Loss: 0.4339 | Mean Pearson: -0.0887 | Tokens: 110,592\n",
- "Step 29/32 | Loss: 0.4516 | Mean Pearson: -0.0763 | Tokens: 118,784\n",
+ "Running validation at step 450...\n",
+ " Validation Loss: -7.3308\n",
+ " Validation Mean Pearson: 0.1128\n",
+ " ENCFF884LDL/pearson: 0.1128\n",
+ "Step 460/1000 | Loss: -0.9750 | Mean Pearson: 0.1717 | LR: 1.00e-05\n",
+ "Step 470/1000 | Loss: -5.5775 | Mean Pearson: 0.1321 | LR: 1.00e-05\n",
+ "Step 480/1000 | Loss: -1.1170 | Mean Pearson: 0.1484 | LR: 1.00e-05\n",
+ "Step 490/1000 | Loss: -3.8053 | Mean Pearson: 0.1959 | LR: 1.00e-05\n",
+ "Step 500/1000 | Loss: -4.5933 | Mean Pearson: 0.1860 | LR: 1.00e-05\n",
  "\n",
- "Running validation at step 28...\n",
- " Validation Loss: 0.3076\n",
- " Validation Mean Pearson: -0.0861\n",
- " ENCFF884LDL/pearson: -0.0861\n",
- "Step 31/32 | Loss: 0.4121 | Mean Pearson: -0.0530 | Tokens: 126,976\n",
+ "Running validation at step 500...\n",
+ " Validation Loss: -5.7617\n",
+ " Validation Mean Pearson: 0.3155\n",
+ " ENCFF884LDL/pearson: 0.3155\n",
+ "Step 510/1000 | Loss: -3.3306 | Mean Pearson: 0.2815 | LR: 1.00e-05\n",
+ "Step 520/1000 | Loss: -2.1962 | Mean Pearson: 0.1151 | LR: 1.00e-05\n",
+ "Step 530/1000 | Loss: -1.5388 | Mean Pearson: 0.3783 | LR: 1.00e-05\n",
  "\n",
- "Training completed after 32 steps!\n"
  ]
  }
  ],
  "source": [
- "# Training loop (step-based with gradient accumulation)\n",
  "print(\"Starting training...\")\n",
- "print(f\"Training for {num_steps_training} steps with {num_accumulation_gradient} gradient accumulation steps\\n\")\n",
  "\n",
  "model.train()\n",
  "train_metrics.reset()\n",
  "optimizer.zero_grad() # Initialize gradients\n",
  "\n",
  "# Create iterator for training data (will cycle if needed)\n",
  "train_iter = iter(train_loader)\n",
- "num_tokens_seen = 0\n",
- "\n",
- "# Main training loop: for loop over optimizer steps (like deepspeed pipeline)\n",
- "for optimizer_step_idx in range(num_steps_training):\n",
- " # Gradient accumulation loop\n",
- " accumulated_loss = 0.0\n",
- " for acc_idx in range(num_accumulation_gradient):\n",
- " try:\n",
- " batch = next(train_iter)\n",
- " except StopIteration:\n",
- " # Restart iterator if we run out of data\n",
- " train_iter = iter(train_loader)\n",
- " batch = next(train_iter)\n",
- " \n",
- " # Forward pass and accumulate gradients\n",
- " loss = train_step(\n",
- " model, batch, optimizer, scale_targets_fn, config, \n",
- " num_accumulation_steps=num_accumulation_gradient\n",
- " )\n",
- " accumulated_loss += loss\n",
  " \n",
- " # Update optimizer (after accumulation)\n",
  " optimizer.step()\n",
  " optimizer.zero_grad()\n",
  " \n",
- " # Update tokens seen\n",
- " num_tokens_seen += effective_num_tokens_per_update\n",
- " \n",
- " # Update metrics (on last batch of accumulation)\n",
  " tokens = batch[\"tokens\"].to(device)\n",
  " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
  " with torch.no_grad():\n",
  " outputs = model(tokens=tokens)\n",
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
  " \n",
- " # Scale targets for scaled metrics\n",
- " scaled_targets = scale_targets_fn(bigwig_targets)\n",
- " \n",
- " # Unscale predictions for raw metrics\n",
- " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
- " \n",
- " avg_loss = accumulated_loss / num_accumulation_gradient\n",
  " train_metrics.update(\n",
- " predictions_scaled=bigwig_logits,\n",
- " targets_scaled=scaled_targets,\n",
- " predictions_raw=unscaled_predictions,\n",
- " targets_raw=bigwig_targets,\n",
- " loss=avg_loss\n",
  " )\n",
  " \n",
  " # Logging\n",
- " if optimizer_step_idx % log_train_step == 0:\n",
  " train_metrics_dict = train_metrics.compute()\n",
- " current_lr = config[\"learning_rate\"]\n",
- " print(f\"Step {optimizer_step_idx + 1}/{num_steps_training} | \"\n",
- " f\"Loss: {avg_loss:.4f} | \"\n",
- " f\"Mean Pearson: {train_metrics_dict['metrics_scaled/mean/pearson']:.4f} | \"\n",
- " f\"Tokens: {num_tokens_seen:,}\")\n",
  " train_metrics.reset()\n",
  " \n",
  " # Validation\n",
- " if optimizer_step_idx % log_validation_step == 0:\n",
- " print(f\"\\nRunning validation at step {optimizer_step_idx}...\")\n",
  " val_metrics.reset()\n",
  " model.eval()\n",
  " \n",
- " val_losses = []\n",
  " for val_batch in val_loader:\n",
- " val_loss = validation_step(\n",
- " model, val_batch, scale_targets_fn, scale_predictions_fn, val_metrics, config\n",
- " )\n",
- " val_losses.append(val_loss)\n",
  " \n",
  " # Print validation metrics\n",
  " val_metrics_dict = val_metrics.compute()\n",
- " print(f\" Validation Loss: {np.mean(val_losses):.4f}\")\n",
- " print(f\" Validation Mean Pearson: {val_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n",
  " for track_name in config[\"bigwig_file_ids\"]:\n",
- " print(f\" {track_name}/pearson: {val_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n",
  " \n",
  " model.train() # Back to training mode\n",
  "\n",
- "print(f\"\\nTraining completed after {num_steps_training} steps!\")\n"
  ]
  },
  {
@@ -1208,122 +1346,59 @@
  },
  {
  "cell_type": "code",
- "execution_count": 43,
- "metadata": {},
- "outputs": [],
- "source": [
- "def test_step(\n",
- " model: nn.Module,\n",
- " batch: Dict[str, torch.Tensor],\n",
- " scale_targets_fn: Callable,\n",
- " scale_predictions_fn: Callable,\n",
- " metrics: TracksMetrics,\n",
- ") -> None:\n",
- " \"\"\"\n",
- " Pure evaluation step for test set (no loss computation).\n",
- " Based on tracks_evaluation_step_torch from deepspeed pipeline.\n",
- " \"\"\"\n",
- " tokens = batch[\"tokens\"].to(device)\n",
- " bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n",
- " \n",
- " with torch.no_grad():\n",
- " # Forward pass\n",
- " outputs = model(tokens=tokens)\n",
- " bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n",
- " \n",
- " # Scale targets for scaled metrics\n",
- " scaled_targets = scale_targets_fn(bigwig_targets)\n",
- " \n",
- " # Unscale predictions for raw metrics\n",
- " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
- " \n",
- " # Update metrics with both scaled and raw values\n",
- " # Pass 0.0 as loss since we don't compute loss in test evaluation\n",
- " metrics.update(\n",
- " predictions_scaled=bigwig_logits,\n",
- " targets_scaled=scaled_targets,\n",
- " predictions_raw=unscaled_predictions,\n",
- " targets_raw=bigwig_targets,\n",
- " loss=0.0\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "\n",
- "==================================================\n",
- "Test Set Evaluation\n",
- "==================================================\n",
- "Running test evaluation with 5 steps (10 samples)\n",
  "\n",
  "==================================================\n",
  "Test Set Results\n",
  "==================================================\n",
  "\n",
- "Scaled Metrics (scaled predictions vs scaled targets):\n",
- " Mean Pearson (scaled): -0.0020\n",
- " ENCFF884LDL/pearson: -0.0020\n",
- "\n",
- "Raw Metrics (raw predictions vs raw targets):\n",
- " Mean Pearson (raw): -0.0020\n",
- " ENCFF884LDL/pearson: -0.0020\n",
- "==================================================\n"
  ]
  }
  ],
  "source": [
- "print(\"\\n\" + \"=\"*50)\n",
- "print(\"Test Set Evaluation\")\n",
- "print(\"=\"*50)\n",
- "\n",
  "# Calculate number of test steps (based on deepspeed pipeline)\n",
  "num_test_samples = len(test_dataset)\n",
  "num_test_steps = num_test_samples // config[\"batch_size\"]\n",
- "\n",
  "print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
  "\n",
  "# Set model to eval mode\n",
  "model.eval()\n",
  "\n",
- "# Create iterator for test data\n",
- "test_iter = iter(test_loader)\n",
  "\n",
- "# Run test evaluation (based on deepspeed pipeline: for loop over test steps)\n",
- "for _ in range(num_test_steps):\n",
- " try:\n",
- " test_batch = next(test_iter)\n",
- " except StopIteration:\n",
- " break\n",
- " \n",
- " # Perform test evaluation (pure evaluation, no loss computation)\n",
- " test_step(\n",
- " model, test_batch, scale_targets_fn, scale_predictions_fn, test_metrics\n",
- " )\n",
- "\n",
  "# Compute final test metrics\n",
  "test_metrics_dict = test_metrics.compute()\n",
- "\n",
  "print(\"\\n\" + \"=\"*50)\n",
  "print(\"Test Set Results\")\n",
  "print(\"=\"*50)\n",
- "print(f\"\\nScaled Metrics (scaled predictions vs scaled targets):\")\n",
- "print(f\" Mean Pearson (scaled): {test_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n",
- "for track_name in config[\"bigwig_file_ids\"]:\n",
- " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n",
- "\n",
- "print(f\"\\nRaw Metrics (raw predictions vs raw targets):\")\n",
- "print(f\" Mean Pearson (raw): {test_metrics_dict['metrics_raw/mean/pearson']:.4f}\")\n",
- "for track_name in config[\"bigwig_file_ids\"]:\n",
- " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n",
- "print(\"=\"*50)"
  ]
  }
  ],
  "metadata": {
 
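The core simplification in this commit is the new `poisson_multinomial_loss`, which drops the masking machinery in favor of a plain scale-plus-shape decomposition. Below is a minimal, self-contained sketch of the simplified loss assembled from the `+` lines above, for trying it outside the notebook. The helpers `poisson_loss` and `safe_for_grad_log_torch` live in a notebook cell this diff does not show, so the versions here are assumptions (a standard Poisson negative log-likelihood and a clamped log), not the notebook's actual definitions.

```python
import torch

# Assumed helper: standard Poisson negative log-likelihood (stands in for the
# notebook's own poisson_loss, which is defined in a cell outside this diff).
def poisson_loss(y_true: torch.Tensor, y_pred: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
    return y_pred - y_true * torch.log(y_pred + epsilon)

# Assumed helper: log that stays finite for gradient computation (stands in
# for the notebook's safe_for_grad_log_torch).
def safe_for_grad_log_torch(x: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
    return torch.log(torch.clamp(x, min=epsilon))

def poisson_multinomial_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    shape_loss_coefficient: float = 5.0,
    epsilon: float = 1e-7,
):
    """Simplified loss from the commit: Poisson scale term + multinomial shape term."""
    # Scale loss: compare total counts per (batch, track)
    sum_pred = logits.sum(dim=1)   # (batch, num_tracks)
    sum_true = targets.sum(dim=1)  # (batch, num_tracks)
    scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon).mean()

    # Shape loss: cross-entropy between targets and the predicted
    # positional distribution along the sequence axis
    denom = logits.sum(dim=1, keepdim=True) + epsilon
    p_pred = logits / denom
    pl_pred = safe_for_grad_log_torch(p_pred)
    shape_loss = -(targets * pl_pred).mean()

    # Combine: shape term dominates, scale term is down-weighted
    loss = shape_loss + scale_loss / shape_loss_coefficient
    return loss, scale_loss, shape_loss

# Smoke test on random non-negative tensors shaped (batch, seq_len, num_tracks)
logits = torch.rand(2, 384, 1)
targets = torch.rand(2, 384, 1)
loss, scale_loss, shape_loss = poisson_multinomial_loss(logits, targets)
print(loss.item(), scale_loss.item(), shape_loss.item())
```

The three-element return matches how the notebook calls it (`loss, _, _ = poisson_multinomial_loss(...)`). Note the shape term can go negative once predictions concentrate mass correctly, which is consistent with the negative losses in the training log above.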
1089
+ "Step 540/1000 | Loss: -2.2349 | Mean Pearson: 0.0734 | LR: 1.00e-05\n",
1090
+ "Step 550/1000 | Loss: -1.5502 | Mean Pearson: 0.2171 | LR: 1.00e-05\n",
1091
  "\n",
1092
+ "Running validation at step 550...\n",
1093
+ " Validation Loss: -3.0059\n",
1094
+ " Validation Mean Pearson: 0.2325\n",
1095
+ " ENCFF884LDL/pearson: 0.2325\n",
1096
+ "Step 560/1000 | Loss: -2.0764 | Mean Pearson: -0.0049 | LR: 1.00e-05\n",
1097
+ "Step 570/1000 | Loss: -1.7384 | Mean Pearson: 0.2989 | LR: 1.00e-05\n",
1098
+ "Step 580/1000 | Loss: -6.7306 | Mean Pearson: 0.2522 | LR: 1.00e-05\n",
1099
+ "Step 590/1000 | Loss: -3.2473 | Mean Pearson: 0.1042 | LR: 1.00e-05\n",
1100
+ "Step 600/1000 | Loss: -4.2841 | Mean Pearson: 0.1936 | LR: 1.00e-05\n",
1101
+ "\n",
1102
+ "Running validation at step 600...\n",
1103
+ " Validation Loss: -4.5611\n",
1104
+ " Validation Mean Pearson: 0.2744\n",
1105
+ " ENCFF884LDL/pearson: 0.2744\n",
1106
+ "Step 610/1000 | Loss: -3.5691 | Mean Pearson: 0.1803 | LR: 1.00e-05\n",
1107
+ "Step 620/1000 | Loss: -7.2129 | Mean Pearson: 0.0901 | LR: 1.00e-05\n",
1108
+ "Step 630/1000 | Loss: -6.0598 | Mean Pearson: 0.1795 | LR: 1.00e-05\n",
1109
+ "Step 640/1000 | Loss: -2.8917 | Mean Pearson: 0.1111 | LR: 1.00e-05\n",
1110
+ "Step 650/1000 | Loss: -2.7210 | Mean Pearson: 0.3566 | LR: 1.00e-05\n",
1111
+ "\n",
1112
+ "Running validation at step 650...\n",
1113
+ " Validation Loss: -4.3997\n",
1114
+ " Validation Mean Pearson: 0.3327\n",
1115
+ " ENCFF884LDL/pearson: 0.3327\n",
1116
+ "Step 660/1000 | Loss: -3.4793 | Mean Pearson: 0.0441 | LR: 1.00e-05\n",
1117
+ "Step 670/1000 | Loss: -1.9743 | Mean Pearson: 0.1364 | LR: 1.00e-05\n",
1118
+ "Step 680/1000 | Loss: -5.7498 | Mean Pearson: 0.2330 | LR: 1.00e-05\n",
1119
+ "Step 690/1000 | Loss: -12.8701 | Mean Pearson: 0.3182 | LR: 1.00e-05\n",
1120
+ "Step 700/1000 | Loss: -1.5847 | Mean Pearson: 0.1971 | LR: 1.00e-05\n",
1121
+ "\n",
1122
+ "Running validation at step 700...\n",
1123
+ " Validation Loss: -2.0630\n",
1124
+ " Validation Mean Pearson: 0.1267\n",
1125
+ " ENCFF884LDL/pearson: 0.1267\n",
1126
+ "Step 710/1000 | Loss: -6.0704 | Mean Pearson: 0.3715 | LR: 1.00e-05\n",
1127
+ "Step 720/1000 | Loss: -2.6020 | Mean Pearson: 0.1244 | LR: 1.00e-05\n",
1128
+ "Step 730/1000 | Loss: -58.8965 | Mean Pearson: 0.5625 | LR: 1.00e-05\n",
1129
+ "Step 740/1000 | Loss: -1.2855 | Mean Pearson: 0.2658 | LR: 1.00e-05\n",
1130
+ "Step 750/1000 | Loss: -4.4599 | Mean Pearson: 0.0137 | LR: 1.00e-05\n",
1131
+ "\n",
1132
+ "Running validation at step 750...\n",
1133
+ " Validation Loss: -11.1562\n",
1134
+ " Validation Mean Pearson: 0.0844\n",
1135
+ " ENCFF884LDL/pearson: 0.0844\n",
1136
+ "Step 760/1000 | Loss: -11.6905 | Mean Pearson: 0.1914 | LR: 1.00e-05\n",
1137
+ "Step 770/1000 | Loss: -4.0964 | Mean Pearson: 0.2022 | LR: 1.00e-05\n",
1138
+ "Step 780/1000 | Loss: -1.5512 | Mean Pearson: 0.3568 | LR: 1.00e-05\n",
1139
+ "Step 790/1000 | Loss: -5.5843 | Mean Pearson: 0.2058 | LR: 1.00e-05\n",
1140
+ "Step 800/1000 | Loss: -3.9190 | Mean Pearson: 0.4362 | LR: 1.00e-05\n",
1141
+ "\n",
1142
+ "Running validation at step 800...\n",
1143
+ " Validation Loss: -4.7017\n",
1144
+ " Validation Mean Pearson: 0.3817\n",
1145
+ " ENCFF884LDL/pearson: 0.3817\n",
1146
+ "Step 810/1000 | Loss: -7.6856 | Mean Pearson: 0.0672 | LR: 1.00e-05\n",
1147
+ "Step 820/1000 | Loss: -5.3603 | Mean Pearson: 0.2325 | LR: 1.00e-05\n",
1148
+ "Step 830/1000 | Loss: -3.8539 | Mean Pearson: 0.2808 | LR: 1.00e-05\n",
1149
+ "Step 840/1000 | Loss: -8.1141 | Mean Pearson: 0.2529 | LR: 1.00e-05\n",
1150
+ "Step 850/1000 | Loss: -10.5886 | Mean Pearson: 0.3454 | LR: 1.00e-05\n",
1151
+ "\n",
1152
+ "Running validation at step 850...\n",
1153
+ " Validation Loss: -4.9108\n",
1154
+ " Validation Mean Pearson: 0.2195\n",
1155
+ " ENCFF884LDL/pearson: 0.2195\n",
1156
+ "Step 860/1000 | Loss: -4.1028 | Mean Pearson: 0.3304 | LR: 1.00e-05\n",
1157
+ "Step 870/1000 | Loss: -7.1834 | Mean Pearson: 0.1206 | LR: 1.00e-05\n",
1158
+ "Step 880/1000 | Loss: -8.9869 | Mean Pearson: 0.3584 | LR: 1.00e-05\n",
1159
+ "Step 890/1000 | Loss: -2.2697 | Mean Pearson: 0.0943 | LR: 1.00e-05\n",
1160
+ "Step 900/1000 | Loss: -14.0142 | Mean Pearson: 0.4761 | LR: 1.00e-05\n",
1161
+ "\n",
1162
+ "Running validation at step 900...\n",
1163
+ " Validation Loss: -3.2329\n",
1164
+ " Validation Mean Pearson: 0.3635\n",
1165
+ " ENCFF884LDL/pearson: 0.3635\n",
1166
+ "Step 910/1000 | Loss: -9.0941 | Mean Pearson: 0.2754 | LR: 1.00e-05\n",
1167
+ "Step 920/1000 | Loss: -4.6371 | Mean Pearson: 0.0167 | LR: 1.00e-05\n",
1168
+ "Step 930/1000 | Loss: -7.9853 | Mean Pearson: 0.0941 | LR: 1.00e-05\n",
1169
+ "Step 940/1000 | Loss: -22.9349 | Mean Pearson: 0.5140 | LR: 1.00e-05\n",
1170
+ "Step 950/1000 | Loss: -2.0866 | Mean Pearson: 0.1746 | LR: 1.00e-05\n",
1171
+ "\n",
1172
+ "Running validation at step 950...\n",
1173
+ " Validation Loss: -8.8318\n",
1174
+ " Validation Mean Pearson: 0.1597\n",
1175
+ " ENCFF884LDL/pearson: 0.1597\n",
1176
+ "Step 960/1000 | Loss: -4.8540 | Mean Pearson: 0.6318 | LR: 1.00e-05\n",
1177
+ "Step 970/1000 | Loss: -4.1091 | Mean Pearson: 0.0985 | LR: 1.00e-05\n",
1178
+ "Step 980/1000 | Loss: -5.1141 | Mean Pearson: 0.2031 | LR: 1.00e-05\n",
1179
+ "Step 990/1000 | Loss: -4.1959 | Mean Pearson: 0.2404 | LR: 1.00e-05\n",
1180
+ "Step 1000/1000 | Loss: -0.9942 | Mean Pearson: 0.2742 | LR: 1.00e-05\n",
1181
+ "\n",
1182
+ "Running validation at step 1000...\n",
1183
+ " Validation Loss: -4.2796\n",
1184
+ " Validation Mean Pearson: 0.1425\n",
1185
+ " ENCFF884LDL/pearson: 0.1425\n",
1186
+ "\n",
1187
+ "Training completed after 1000 steps.\n"
1188
  ]
1189
  }
1190
  ],
1191
  "source": [
1192
+ "# Training loop\n",
1193
  "print(\"Starting training...\")\n",
1194
+ "print(f\"Training for {config[\"num_steps_training\"]} steps\\n\")\n",
1195
  "\n",
1196
  "model.train()\n",
1197
  "train_metrics.reset()\n",
1198
  "optimizer.zero_grad() # Initialize gradients\n",
1199
  "\n",
1200
+ "# Track metrics for plotting\n",
1201
+ "train_steps = []\n",
1202
+ "train_losses = []\n",
1203
+ "train_pearson_scores = []\n",
1204
+ "val_steps = []\n",
1205
+ "val_losses = []\n",
1206
+ "val_pearson_scores = []\n",
1207
+ "\n",
1208
+ "# Initialize interactive plots using FigureWidget for real-time updates\n",
1209
+ "from plotly.graph_objects import FigureWidget\n",
1210
+ "from plotly.subplots import make_subplots\n",
1211
+ "\n",
1212
+ "# Create base figure with subplots\n",
1213
+ "fig_base = make_subplots(\n",
1214
+ " rows=1, cols=2,\n",
1215
+ " subplot_titles=('Loss', 'Mean Pearson Correlation'),\n",
1216
+ " horizontal_spacing=0.15,\n",
1217
+ ")\n",
1218
+ "\n",
1219
+ "# Add empty traces for train and val metrics\n",
1220
+ "fig_base.add_trace(\n",
1221
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Train Loss', line=dict(color='blue')),\n",
1222
+ " row=1, col=1\n",
1223
+ ")\n",
1224
+ "fig_base.add_trace(\n",
1225
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Val Loss', line=dict(color='red')),\n",
1226
+ " row=1, col=1\n",
1227
+ ")\n",
1228
+ "fig_base.add_trace(\n",
1229
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Train Pearson', line=dict(color='green')),\n",
1230
+ " row=1, col=2\n",
1231
+ ")\n",
1232
+ "fig_base.add_trace(\n",
1233
+ " go.Scatter(x=[], y=[], mode='lines+markers', name='Val Pearson', line=dict(color='orange')),\n",
1234
+ " row=1, col=2\n",
1235
+ ")\n",
1236
+ "\n",
1237
+ "fig_base.update_xaxes(title_text=\"Step\", row=1, col=1)\n",
1238
+ "fig_base.update_xaxes(title_text=\"Step\", row=1, col=2)\n",
1239
+ "fig_base.update_yaxes(title_text=\"Loss\", row=1, col=1)\n",
1240
+ "fig_base.update_yaxes(title_text=\"Pearson Correlation\", row=1, col=2)\n",
1241
+ "fig_base.update_layout(height=800, width=1600, showlegend=True, title_text=\"Training\")\n",
1242
+ "\n",
1243
+ "# Convert to FigureWidget for interactive updates\n",
1244
+ "fig = FigureWidget(fig_base)\n",
1245
+ "\n",
1246
+ "# Display initial plot (will update in place during training)\n",
1247
+ "display(fig)\n",
1248
+ "\n",
1249
  "# Create iterator for training data (will cycle if needed)\n",
1250
  "train_iter = iter(train_loader)\n",
1251
+ "\n",
1252
+ "# Main training loop\n",
1253
+ "for step_idx in range(config[\"num_steps_training\"]):\n",
1254
+ " try:\n",
1255
+ " batch = next(train_iter)\n",
1256
+ " except StopIteration:\n",
1257
+ " # Restart iterator if we run out of data\n",
1258
+ " train_iter = iter(train_loader)\n",
1259
+ " batch = next(train_iter)\n",
 
 
 
 
 
 
 
 
 
 
 
1260
  " \n",
1261
+ " # Forward pass and backward pass\n",
1262
+ " loss = train_step(model, batch)\n",
1263
+ " \n",
1264
+ " # Update optimizer\n",
1265
  " optimizer.step()\n",
1266
  " optimizer.zero_grad()\n",
1267
  " \n",
1268
+ " # Update metrics\n",
 
 
 
1269
  " tokens = batch[\"tokens\"].to(device)\n",
1270
  " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1271
  " with torch.no_grad():\n",
1272
  " outputs = model(tokens=tokens)\n",
1273
  " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
1274
  " \n",
 
 
 
 
 
 
 
1275
  " train_metrics.update(\n",
1276
+ " predictions=bigwig_logits,\n",
1277
+ " targets=bigwig_targets,\n",
1278
+ " loss=loss\n",
 
 
1279
  " )\n",
1280
  " \n",
1281
  " # Logging\n",
1282
+ " if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
1283
  " train_metrics_dict = train_metrics.compute()\n",
1284
+ " current_lr = optimizer.param_groups[0]['lr']\n",
1285
+ " \n",
1286
+ " # Track metrics for plotting\n",
1287
+ " train_steps.append(step_idx + 1)\n",
1288
+ " train_losses.append(loss)\n",
1289
+ " train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
1290
+ " \n",
1291
+ " # Update plots - direct assignment to FigureWidget data updates the plot automatically\n",
1292
+ " fig.data[0].x = train_steps\n",
1293
+ " fig.data[0].y = train_losses\n",
1294
+ " fig.data[2].x = train_steps\n",
1295
+ " fig.data[2].y = train_pearson_scores\n",
1296
+ " \n",
1297
+ " print(f\"Step {step_idx + 1}/{config[\"num_steps_training\"]} | \"\n",
1298
+ " f\"Loss: {loss:.4f} | \"\n",
1299
+ " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f} | \"\n",
1300
+ " f\"LR: {current_lr:.2e}\")\n",
1301
  " train_metrics.reset()\n",
1302
  " \n",
1303
  " # Validation\n",
1304
+ " if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
1305
+ " print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
1306
  " val_metrics.reset()\n",
1307
  " model.eval()\n",
1308
  " \n",
1309
+ " val_batch_losses = []\n",
1310
  " for val_batch in val_loader:\n",
1311
+ " val_loss = validation_step(model, val_batch, val_metrics)\n",
1312
+ " val_batch_losses.append(val_loss)\n",
 
 
1313
  " \n",
1314
  " # Print validation metrics\n",
1315
  " val_metrics_dict = val_metrics.compute()\n",
1316
+ " val_loss_mean = np.mean(val_batch_losses)\n",
1317
+ " val_pearson_mean = val_metrics_dict['mean/pearson']\n",
1318
+ " \n",
1319
+ " # Track validation metrics\n",
1320
+ " val_steps.append(step_idx + 1)\n",
1321
+ " val_losses.append(val_loss_mean)\n",
1322
+ " val_pearson_scores.append(val_pearson_mean)\n",
1323
+ " \n",
1324
+ " # Update plots with validation data - direct assignment updates the plot automatically\n",
1325
+ " fig.data[1].x = val_steps\n",
1326
+ " fig.data[1].y = val_losses\n",
1327
+ " fig.data[3].x = val_steps\n",
1328
+ " fig.data[3].y = val_pearson_scores\n",
1329
+ " \n",
1330
+ " print(f\" Validation Loss: {val_loss_mean:.4f}\")\n",
1331
+ " print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
1332
  " for track_name in config[\"bigwig_file_ids\"]:\n",
1333
+ " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
1334
  " \n",
1335
  " model.train() # Back to training mode\n",
1336
  "\n",
1337
+ "print(f\"\\nTraining completed after {config[\"num_steps_training\"]} steps.\")"
1338
  ]
1339
  },
1340
  {
1346
  },
1347
  {
1348
  "cell_type": "code",
1349
+ "execution_count": 24,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1350
  "metadata": {},
1351
  "outputs": [
1352
  {
1353
  "name": "stdout",
1354
  "output_type": "stream",
1355
  "text": [
1356
+ "Running test evaluation with 12 steps (100 samples)\n",
 
 
 
 
1357
  "\n",
1358
  "==================================================\n",
1359
  "Test Set Results\n",
1360
  "==================================================\n",
1361
  "\n",
1362
+ "Metrics:\n",
1363
+ " Mean Pearson: 0.1787\n",
1364
+ " ENCFF884LDL/pearson: 0.1787\n"
 
 
 
 
 
1365
  ]
1366
  }
1367
  ],
1368
  "source": [
 
 
 
 
1369
  "# Calculate number of test steps (based on deepspeed pipeline)\n",
1370
  "num_test_samples = len(test_dataset)\n",
1371
  "num_test_steps = num_test_samples // config[\"batch_size\"]\n",
 
1372
  "print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n",
1373
  "\n",
1374
  "# Set model to eval mode\n",
1375
  "model.eval()\n",
1376
  "\n",
1377
+ "for test_batch in test_loader: \n",
 
1378
  "\n",
1379
+ " _ = validation_step( \n",
1380
+ " model, \n",
1381
+ " test_batch, \n",
1382
+ " test_metrics,\n",
 
 
 
 
 
 
1383
  " )\n",
1384
+ " \n",
1385
  "# Compute final test metrics\n",
1386
  "test_metrics_dict = test_metrics.compute()\n",
 
1387
  "print(\"\\n\" + \"=\"*50)\n",
1388
  "print(\"Test Set Results\")\n",
1389
  "print(\"=\"*50)\n",
1390
+ "print(f\"\\nMetrics:\")\n",
1391
+ "print(f\" Mean Pearson: {test_metrics_dict['mean/pearson']:.4f}\")\n",
1392
+ "for track_name in config[\"bigwig_file_ids\"]: \n",
1393
+ " print(f\" {track_name}/pearson: {test_metrics_dict[f'{track_name}/pearson']:.4f}\")"
 
 
 
 
 
 
1394
  ]
1395
+ },
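For completeness, inference with the fine-tuned model follows the same call pattern used in the training and test loops above; a minimal sketch (assumes a `tokens` LongTensor produced by the notebook's tokenization pipeline, on the same device as the model):

    model.eval()
    with torch.no_grad():
        outputs = model(tokens=tokens)              # tokens: (batch, seq_len)
        preds = outputs["bigwig_tracks_logits"]     # predicted track signal per position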
1396
+ {
1397
+ "cell_type": "code",
1398
+ "execution_count": null,
1399
+ "metadata": {},
1400
+ "outputs": [],
1401
+ "source": []
1402
  }
1403
  ],
1404
  "metadata": {