ybornachot committed on
Commit 2b05bdb · 1 Parent(s): 6db01cc

feat: made compat with HF dataset + refactor

Files changed (1)
  1. notebooks/03_fine_tuning.ipynb +212 -339
notebooks/03_fine_tuning.ipynb CHANGED
@@ -52,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -60,8 +60,9 @@
     "import functools\n",
     "from typing import List, Dict, Callable\n",
     "import os\n",
-    "import subprocess\n",
-    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
+    "import fnmatch\n",
+    "from pathlib import Path\n",
+    "from huggingface_hub import HfApi, snapshot_download\n",
     "\n",
     "import torch\n",
     "import torch.nn as nn\n",
@@ -69,6 +70,8 @@
     "from torch.utils.data import Dataset, DataLoader\n",
     "from torch.optim import AdamW\n",
     "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import pyBigWig\n",
     "from pyfaidx import Fasta\n",
@@ -90,10 +93,9 @@
     "- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
     "\n",
     "### Data\n",
+    "- **`hf_repo_id`**: HuggingFace dataset repository ID containing the benchmark data\n",
+    "- **`species`**: Species name (e.g., \"human\") to select data from the benchmark dataset\n",
     "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
-    "- **`fasta_url`**: URL to download reference genome FASTA file\n",
-    "- **`bigwig_url_list`**: List of URLs for bigWig track files to download\n",
-    "- **`bigwig_file_ids`**: List of identifiers/names for bigWig tracks (set after downloading, used for model head and metrics)\n",
     "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
     "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
     "\n",
@@ -108,6 +110,9 @@
     "- **`validate_every_n_steps`**: Run validation every N steps\n",
     "- **`num_validation_samples`**: Number of samples to use for validation set\n",
     "\n",
+    "### Test\n",
+    "- **`num_test_samples`**: Number of samples to use for test set evaluation\n",
+    "\n",
     "### General\n",
     "- **`seed`**: Random seed for reproducibility\n",
     "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
@@ -116,31 +121,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Using device: cpu\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "config = {\n",
     "    # Model\n",
     "    \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
     "    \n",
     "    # Data\n",
+    "    \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+    "    \"species\": \"arabidopsis\",\n",
     "    \"data_cache_dir\": \"./data\",\n",
-    "    \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
-    "    \"bigwig_url_list\": [\n",
-    "        \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
-    "        \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
-    "        \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
-    "        \"https://www.encodeproject.org/files/ENCFF921PHQ/@@download/ENCFF921PHQ.bigWig\",\n",
-    "    ],\n",
     "    \"sequence_length\": 32_768,\n",
     "    \"keep_target_center_fraction\": 0.375,\n",
     "    \n",
@@ -159,40 +151,11 @@
     "    \"num_test_samples\": 10000,\n",
     "    \n",
     "    # General\n",
-    "    \"seed\": 17,\n",
+    "    \"seed\": 0,\n",
     "    \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
     "    \"num_workers\": 16,\n",
     "}\n",
     "\n",
-    "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
-    "\n",
-    "# Extract filenames from URLs\n",
-    "def extract_filename_from_url(url: str) -> str:\n",
-    "    \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
-    "    # Remove query parameters if present\n",
-    "    url_clean = url.split('?')[0]\n",
-    "    # Get the last part of the URL path\n",
-    "    return url_clean.split('/')[-1]\n",
-    "\n",
-    "# Create paths for downloaded files\n",
-    "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
-    "bigwig_path_list = [\n",
-    "    os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
-    "    for url in config[\"bigwig_url_list\"]\n",
-    "]\n",
-    "\n",
-    "\n",
-    "# TODO: find a way to link the experiment accession to bigwig file ids\n",
-    "# Create bigwig_file_ids from filenames (without extension)\n",
-    "config[\"bigwig_file_ids\"] = [\n",
-    "    # os.path.splitext(extract_filename_from_url(url))[0]\n",
-    "    # for url in config[\"bigwig_url_list\"]\n",
-    "    \"ENCSR325NFE\",\n",
-    "    \"ENCSR962OTG\",\n",
-    "    \"ENCSR619DQO_P\",\n",
-    "    \"ENCSR619DQO_M\",\n",
-    "]\n",
-    "\n",
     "# Set random seed\n",
     "torch.manual_seed(config[\"seed\"])\n",
     "np.random.seed(config[\"seed\"])\n",
@@ -217,56 +180,99 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def _download_file(url: str, output_path: str) -> None:\n",
-    "    \"\"\"Download a file from URL to output_path using wget.\"\"\"\n",
-    "    subprocess.run([\"wget\", \"-c\", url, \"-O\", output_path], check=True)\n",
-    "\n",
-    "# Prepare download tasks: (url, output_path)\n",
-    "download_tasks = []\n",
-    "\n",
-    "# FASTA file\n",
-    "fasta_filename = extract_filename_from_url(config[\"fasta_url\"])\n",
-    "fasta_gz_path = os.path.join(config[\"data_cache_dir\"], fasta_filename)\n",
-    "download_tasks.append((config[\"fasta_url\"], fasta_gz_path))\n",
-    "\n",
-    "# BigWig files\n",
-    "for bigwig_url in config[\"bigwig_url_list\"]:\n",
-    "    filename = extract_filename_from_url(bigwig_url)\n",
-    "    filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
-    "    download_tasks.append((bigwig_url, filepath))\n",
-    "\n",
-    "# Download files in parallel\n",
-    "max_workers = min(len(download_tasks), 8)\n",
-    "\n",
-    "print(f\"Downloading {len(download_tasks)} files using {max_workers} workers...\")\n",
-    "with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
-    "    # Submit all download tasks\n",
-    "    future_to_path = {\n",
-    "        executor.submit(_download_file, url, path): path\n",
-    "        for url, path in download_tasks\n",
-    "    }\n",
+    "def prepare_genomics_inputs(\n",
+    "    species: str,\n",
+    "    data_cache_dir: str | Path = \"data\",\n",
+    "    hf_repo_id: str = \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+    ") -> tuple[str, list[str], list[str], pd.DataFrame, pd.DataFrame]:\n",
+    "    \"\"\"\n",
+    "    Downloads:\n",
+    "    1) FASTA from HF dataset under: <species>/genome.fasta\n",
+    "    2) BigWigs from HF dataset under: <species>/functional_tracks/**\n",
+    "    3) Splits from HF dataset under: <species>/splits.bed\n",
+    "    4) Metadata from HF dataset under: benchmark_metadata.tsv\n",
+    "    Returns:\n",
+    "        (fasta_path, bigwig_path_list, bigwig_file_ids, splits_df, metadata_df)\n",
+    "    \"\"\"\n",
+    "    cache = Path(data_cache_dir).expanduser().resolve()\n",
+    "    cache.mkdir(parents=True, exist_ok=True)\n",
     "    \n",
-    "    # Wait for all downloads to complete\n",
-    "    for future in as_completed(future_to_path):\n",
-    "        try:\n",
-    "            future.result()  # Raises exception if download failed\n",
-    "            path = future_to_path[future]\n",
-    "            print(f\"✓ Downloaded: {os.path.basename(path)}\")\n",
-    "        except Exception as e:\n",
-    "            path = future_to_path[future]\n",
-    "            raise RuntimeError(f\"Failed to download {path}: {e}\") from e\n",
-    "\n",
-    "# Extract FASTA file after download\n",
-    "print(f\"\\nExtracting {fasta_filename}...\")\n",
-    "subprocess.run([\"gunzip\", \"-f\", fasta_gz_path], check=True)\n",
-    "print(\"  Extraction complete\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Data Splits Definition"
+    "    # --- Download metadata + <species> files (FASTA, BigWigs, Splits) ---\n",
+    "    api = HfApi()\n",
+    "    files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
+    "    \n",
+    "    # Find all files to download: species directory + metadata at root\n",
+    "    species_pattern = f\"{species}/**\"\n",
+    "    metadata_file = \"benchmark_metadata.tsv\"\n",
+    "    \n",
+    "    species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
+    "    if not species_files:\n",
+    "        raise ValueError(f\"No files found matching '{species_pattern}' in '{hf_repo_id}'\")\n",
+    "    \n",
+    "    if metadata_file not in files:\n",
+    "        raise ValueError(f\"No metadata file found at '{metadata_file}' in '{hf_repo_id}'\")\n",
+    "    \n",
+    "    # Download all needed files\n",
+    "    download_patterns = [species_pattern, metadata_file]\n",
+    "    local_dir = Path(\n",
+    "        snapshot_download(\n",
+    "            repo_id=hf_repo_id,\n",
+    "            repo_type=\"dataset\",\n",
+    "            allow_patterns=download_patterns,\n",
+    "            local_dir=str(cache),\n",
+    "        )\n",
+    "    )\n",
+    "    \n",
+    "    # --- Organize outputs ---\n",
+    "    # FASTA file\n",
+    "    fasta_path_repo = f\"{species}/genome.fasta\"\n",
+    "    fasta_path = str(local_dir / fasta_path_repo)\n",
+    "    if not Path(fasta_path).is_file():\n",
+    "        raise ValueError(f\"FASTA file not found at '{fasta_path}'\")\n",
+    "    \n",
+    "    # BigWig files\n",
+    "    bigwig_paths, bigwig_ids = [], []\n",
+    "    for repo_path in species_files:\n",
+    "        lp = local_dir / repo_path\n",
+    "        if lp.is_file() and lp.suffix == \".bigwig\":\n",
+    "            bigwig_paths.append(str(lp))\n",
+    "            bigwig_ids.append(lp.stem)\n",
+    "    if not bigwig_paths:\n",
+    "        raise ValueError(f\"Found no BigWig files in '{species_pattern}'\")\n",
+    "    \n",
+    "    # Splits file\n",
+    "    splits_path_repo = f\"{species}/splits.bed\"\n",
+    "    splits_path = local_dir / splits_path_repo\n",
+    "    if not splits_path.is_file():\n",
+    "        raise ValueError(f\"Splits file not found at '{splits_path}'\")\n",
+    "    splits_df = pd.read_csv(\n",
+    "        splits_path,\n",
+    "        sep=\"\\t\",\n",
+    "        header=None,\n",
+    "        names=[\"chr_name\", \"start\", \"end\", \"split\"],\n",
+    "        dtype={\"chr_name\": str, \"start\": int, \"end\": int, \"split\": str},\n",
+    "    )\n",
+    "    \n",
+    "    # Metadata file\n",
+    "    metadata_path = local_dir / metadata_file\n",
+    "    if not metadata_path.is_file():\n",
+    "        raise ValueError(f\"Metadata file not found at '{metadata_path}'\")\n",
+    "    metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
+    "\n",
+    "    if \"species\" not in metadata_df.columns:\n",
+    "        raise ValueError(\"benchmark_metadata.tsv has no 'species' column\")\n",
+    "\n",
+    "    # Filter metadata according to species\n",
+    "    metadata_df = metadata_df[metadata_df[\"species\"] == species].reset_index(drop=True)\n",
+    "\n",
+    "    # Order metadata according to bigwig file ids\n",
+    "    metadata_df = (\n",
+    "        metadata_df.set_index(\"file_id\")\n",
+    "        .loc[bigwig_ids]\n",
+    "        .reset_index()\n",
+    "    )\n",
+    "\n",
+    "    return fasta_path, bigwig_paths, bigwig_ids, splits_df, metadata_df"
    ]
   },
   {
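The refactored cell above replaces the wget/ThreadPoolExecutor download logic with a single `snapshot_download` call against the benchmark dataset repo. A minimal standalone sketch of that pattern, assuming the per-species layout the function's docstring describes (`<species>/genome.fasta`, `<species>/functional_tracks/**`, `<species>/splits.bed`):

    from pathlib import Path

    from huggingface_hub import snapshot_download

    # Only files matching allow_patterns are fetched; the local tree mirrors
    # the repo layout, so downstream code can build paths under <species>/.
    local_dir = Path(
        snapshot_download(
            repo_id="InstaDeepAI/NTv3_benchmark_dataset",
            repo_type="dataset",
            allow_patterns=["arabidopsis/**", "benchmark_metadata.tsv"],
            local_dir="./data",
        )
    )
    print(sorted(str(p) for p in local_dir.rglob("*.bigwig")))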
@@ -275,11 +281,20 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "chrom_splits = {\n",
-    "    \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
-    "    \"val\": ['chr22'],\n",
-    "    \"test\": ['chr21']\n",
-    "}"
+    "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
+    "\n",
+    "# Download all species files + load the splits and metadata\n",
+    "(\n",
+    "    fasta_path,\n",
+    "    bigwig_paths,\n",
+    "    bigwig_ids,\n",
+    "    species_splits_df,\n",
+    "    metadata_df,\n",
+    ") = prepare_genomics_inputs(\n",
+    "    config[\"species\"],\n",
+    "    config[\"data_cache_dir\"],\n",
+    "    config[\"hf_repo_id\"]\n",
+    ")"
    ]
   },
   {
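This cell also retires the hard-coded `chrom_splits` dict from the previous revision: the train/val/test assignment now travels with the dataset as `splits.bed`, a four-column, tab-separated table (`chr_name`, `start`, `end`, `split`) matching the `pd.read_csv` call in `prepare_genomics_inputs`. A hypothetical excerpt, with made-up coordinates for illustration only:

    chr1    0           15000000    train
    chr1    15000000    16000000    val
    chr2    0           12000000    train
    chr2    12000000    13000000    test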
@@ -335,7 +350,7 @@
     "        self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
     "            model_name,\n",
     "            trust_remote_code=True,\n",
-    "            config=self.config\n",
+    "            config=self.config,\n",
     "        )\n",
     "        \n",
     "        self.keep_target_center_fraction = keep_target_center_fraction\n",
@@ -351,7 +366,7 @@
     "        \n",
     "    def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
     "        # Forward through backbone\n",
-    "        outputs = self.backbone(input_ids=tokens)\n",
+    "        outputs = self.backbone(input_ids=tokens, output_hidden_states=True)\n",
     "        embedding = outputs.hidden_states[-1]  # Last hidden state\n",
     "        \n",
     "        # Crop to center fraction\n",
@@ -379,14 +394,14 @@
     "# Create model\n",
     "model = HFModelWithHead(\n",
     "    model_name=config[\"model_name\"],\n",
-    "    bigwig_track_names=config[\"bigwig_file_ids\"],\n",
+    "    bigwig_track_names=bigwig_ids,\n",
     "    keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
     ")\n",
     "model = model.to(device)\n",
     "model.train()\n",
     "\n",
     "print(f\"Model loaded: {config['model_name']}\")\n",
-    "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
+    "print(f\"Number of bigwig tracks: {len(bigwig_ids)}\")\n",
     "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
    ]
   },
@@ -426,8 +441,8 @@
     "    Random genomic windows from a reference genome + bigWig signal.\n",
     "\n",
     "    Each sample:\n",
-    "    - picks a chromosome/region (from `chroms` or `regions`),\n",
-    "    - picks a random window of length `sequence_length`,\n",
+    "    - picks a random region from the specified split,\n",
+    "    - picks a random window of length `sequence_length` within that region,\n",
     "    - returns (sequence, signal, chrom, start, end).\n",
     "\n",
     "    This dataset is compatible with multi-worker DataLoaders. BigWig files\n",
@@ -438,11 +453,13 @@
     "    ----\n",
     "    fasta_path : str\n",
     "        Path to the reference genome FASTA (e.g. hg38.fna).\n",
-    "    bigwig_path_list : str\n",
-    "        Path to the bigWig file (e.g. ENCFF884LDL.bigWig).\n",
-    "    chroms : List[str]\n",
-    "        Chromosome names as they appear in the bigWig (e.g. [\"chr1\", \"chr2\", ...]).\n",
-    "        Used for backward compatibility or when regions=None.\n",
+    "    bigwig_path_list : list[str]\n",
+    "        List of paths to bigWig files.\n",
+    "    chrom_regions : pd.DataFrame\n",
+    "        DataFrame with columns: chr_name, start, end, split.\n",
+    "        Contains all genomic regions with their split assignments.\n",
+    "    split : str\n",
+    "        Split name to filter regions (e.g., \"train\", \"val\", \"test\").\n",
     "    sequence_length : int\n",
     "        Length of each random window (in bp).\n",
     "    num_samples : int\n",
@@ -453,18 +470,14 @@
     "        Function to transform/scale bigwig targets.\n",
     "    keep_target_center_fraction : float\n",
     "        Fraction of center sequence to keep for target prediction (crops edges to focus on center).\n",
-    "    regions : List[tuple[str, int, int]] | None\n",
-    "        Optional list of regions as (chromosome, start, end) tuples.\n",
-    "        If provided, samples are drawn randomly from within these regions only.\n",
-    "        This matches the JAX pipeline approach using BED file splits.\n",
-    "        If None, samples from entire chromosomes in `chroms`.\n",
     "    \"\"\"\n",
     "\n",
     "    def __init__(\n",
     "        self,\n",
     "        fasta_path: str,\n",
     "        bigwig_path_list: list[str],\n",
-    "        chroms: List[str],\n",
+    "        chrom_regions: pd.DataFrame,\n",
+    "        split: str,\n",
     "        sequence_length: int,\n",
     "        num_samples: int,\n",
     "        tokenizer: AutoTokenizer,\n",
@@ -479,43 +492,37 @@
     "        self.sequence_length = sequence_length\n",
     "        self.num_samples = num_samples\n",
     "        self.tokenizer = tokenizer\n",
-    "        self.transform_fn = transform_fn  # Use pre-computed transform function\n",
+    "        self.transform_fn = transform_fn\n",
     "        self.keep_target_center_fraction = keep_target_center_fraction\n",
-    "        self.chroms = chroms\n",
+    "        self.chrom_regions = chrom_regions\n",
     "\n",
-    "        # Get chromosome lengths from first BigWig file (lazy, cached per process)\n",
-    "        # We need this for validation, so open temporarily\n",
-    "        bw_handle = _get_bigwig_handle(bigwig_path_list[0])\n",
-    "        bw_chrom_lengths = bw_handle.chroms()  # dict: chrom -> length\n",
+    "        # Filter regions by split\n",
+    "        split_regions = self.chrom_regions[self.chrom_regions[\"split\"] == split].copy()\n",
     "\n",
-    "        self.valid_chroms = []\n",
-    "        self.chrom_lengths = {}\n",
+    "        # Filter valid regions (must be large enough for sequence_length)\n",
+    "        self.valid_regions = []\n",
+    "        for _, row in split_regions.iterrows():\n",
     "\n",
-    "        for c in chroms:\n",
-    "            if c not in bw_chrom_lengths or c not in self.fasta:\n",
+    "            region_length = row.end - row.start\n",
+    "            if region_length < self.sequence_length:\n",
     "                continue\n",
+    "            \n",
+    "            # Store valid region\n",
+    "            self.valid_regions.append((row.chr_name, row.start, row.end))\n",
     "\n",
-    "            fa_len = len(self.fasta[c])\n",
-    "            bw_len = bw_chrom_lengths[c]\n",
-    "            L = min(fa_len, bw_len)\n",
-    "\n",
-    "            if L > self.sequence_length:\n",
-    "                self.valid_chroms.append(c)\n",
-    "                self.chrom_lengths[c] = L\n",
-    "\n",
-    "        if not self.valid_chroms:\n",
-    "            raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n",
+    "        if not self.valid_regions:\n",
+    "            raise ValueError(f\"No valid regions found for split '{split}'\")\n",
     "\n",
     "    def __len__(self):\n",
     "        return self.num_samples\n",
     "\n",
     "    def __getitem__(self, idx):\n",
-    "\n",
-    "        # Sample from entire chromosomes\n",
-    "        chrom = random.choice(self.valid_chroms)\n",
-    "        chrom_len = self.chrom_lengths[chrom]\n",
-    "        max_start = chrom_len - self.sequence_length\n",
-    "        start = random.randint(0, max_start)\n",
+    "        # Sample a random region from the valid regions\n",
+    "        chrom, region_start, region_end = random.choice(self.valid_regions)\n",
+    "        \n",
+    "        # Sample a random window within this region\n",
+    "        max_start = region_end - self.sequence_length\n",
+    "        start = random.randint(region_start, max_start)\n",
     "        end = start + self.sequence_length\n",
     "\n",
     "        # Sequence\n",
@@ -575,133 +582,26 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Scaling functions for targets\n",
-    "def compute_chromosome_stats(track_data: np.ndarray) -> dict:\n",
-    "    \"\"\"\n",
-    "    Compute minimal statistics needed for weighted mean computation.\n",
-    "    \n",
-    "    Args:\n",
-    "        track_data: numpy array of track values for a chromosome\n",
-    "    \n",
-    "    Returns:\n",
-    "        Dictionary with statistics: sum, mean, total_count\n",
+    "def create_targets_scaling_fn(\n",
+    "    metadata_df: pd.DataFrame\n",
+    ") -> Callable[[torch.Tensor], torch.Tensor]:\n",
     "    \"\"\"\n",
-    "    track_data = track_data.astype(np.float32)\n",
-    "    \n",
-    "    # Compute statistics\n",
-    "    sum_all = np.sum(track_data)\n",
-    "    total_count = track_data.size\n",
-    "    mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
-    "    \n",
-    "    return {\n",
-    "        \"sum\": sum_all,\n",
-    "        \"mean\": mean_all,\n",
-    "        \"total_count\": total_count,\n",
-    "    }\n",
+    "    Build a scaling function based on track means contained in the metadata.\n",
     "\n",
-    "\n",
-    "def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
-    "    \"\"\"\n",
-    "    Aggregate chromosome-level statistics into file-level statistics.\n",
-    "    \n",
     "    Args:\n",
-    "        chr_stats_list: List of dictionaries, each containing chromosome-level statistics\n",
-    "    \n",
-    "    Returns:\n",
-    "        Dictionary with aggregated file-level statistics (only mean)\n",
-    "    \"\"\"\n",
-    "    # Convert to arrays for easier computation\n",
-    "    total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
-    "    means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
-    "    sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
-    "    \n",
-    "    # Aggregate total count\n",
-    "    total_count = np.sum(total_counts)\n",
-    "    \n",
-    "    # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
-    "    mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
-    "    \n",
-    "    return {\n",
-    "        \"total_count\": total_count,\n",
-    "        \"sum\": np.sum(sums),\n",
-    "        \"mean\": mean,\n",
-    "    }\n",
-    "\n",
+    "        metadata_df: pandas.DataFrame with track means\n",
     "\n",
-    "def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
-    "    \"\"\"\n",
-    "    Get track means for normalization.\n",
-    "    Computes statistics per chromosome and aggregates using weighted averaging,\n",
-    "    \n",
-    "    Args:\n",
-    "        bigwig_tracks_list: List of pyBigWig file objects\n",
-    "    \n",
-    "    Returns:\n",
-    "        Array of track means, one per bigwig file\n",
-    "    \"\"\"\n",
-    "    track_means = []\n",
-    "    \n",
-    "    for bigwig_track in bigwig_tracks_list:\n",
-    "        chrom_lengths = bigwig_track.chroms()\n",
-    "        all_chr_stats = []\n",
-    "        \n",
-    "        # Compute statistics for each chromosome\n",
-    "        for chrom_name, chrom_length in chrom_lengths.items():\n",
-    "            try:\n",
-    "                # Get chromosome data as numpy array\n",
-    "                bw_array = np.array(\n",
-    "                    bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
-    "                    dtype=np.float32\n",
-    "                )\n",
-    "                # Replace NaN with 0\n",
-    "                bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
-    "                \n",
-    "                # Compute chromosome-level statistics\n",
-    "                chr_stats = compute_chromosome_stats(bw_array)\n",
-    "                all_chr_stats.append(chr_stats)\n",
-    "            except Exception as e:\n",
-    "                # Skip chromosomes that fail to load\n",
-    "                print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
-    "                continue\n",
-    "        \n",
-    "        if not all_chr_stats:\n",
-    "            raise ValueError(f\"No valid chromosomes found for bigwig track\")\n",
-    "        \n",
-    "        # Aggregate chromosome-level stats into file-level stats\n",
-    "        file_stats = aggregate_file_statistics(all_chr_stats)\n",
-    "        \n",
-    "        # Use the weighted mean for normalization\n",
-    "        track_means.append(file_stats[\"mean\"])\n",
-    "    \n",
-    "    return np.array(track_means, dtype=np.float32)\n",
-    "\n",
-    "\n",
-    "def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
-    "    \"\"\"\n",
-    "    Build a scaling function based on track means computed from bigwig files.\n",
-    "    \n",
-    "    Opens bigwig files, computes track statistics, and creates a transform function.\n",
-    "    The statistics are computed once and reused for all calls to the returned transform function.\n",
-    "    \n",
-    "    Args:\n",
-    "        bigwig_path_list: List of paths to bigwig files\n",
-    "    \n",
     "    Returns:\n",
     "        Transform function that scales input tensors\n",
     "    \"\"\"\n",
     "    # Open bigwig files and compute track statistics\n",
-    "    print(\"Computing track statistics (this may take a while)...\")\n",
-    "    bw_list = [\n",
-    "        pyBigWig.open(bigwig_path)\n",
-    "        for bigwig_path in bigwig_path_list\n",
-    "    ]\n",
-    "    track_means = get_track_means(bw_list)\n",
-    "    print(f\"Computed track means: {track_means}\")\n",
-    "    print(f\"Track means shape: {track_means.shape}\")\n",
-    "    \n",
+    "    track_means = metadata_df[\"mean\"].to_numpy()\n",
+    "    print(f\"Track means: {track_means}\")\n",
+    "    print(f\"Number of tracks: {track_means.shape}\")\n",
+    "\n",
     "    # Create tensor from computed means\n",
     "    track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
-    "    \n",
+    "\n",
     "    def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
     "        \"\"\"\n",
     "        x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
@@ -717,20 +617,10 @@
     "            scaled,\n",
     "        )\n",
     "        return clipped\n",
-    "    \n",
+    "\n",
     "    return transform_fn"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Create scaling function\n",
-    "targets_transform_fn = create_targets_scaling_fn(bigwig_path_list)"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -741,25 +631,26 @@
     "create_dataset_fn = functools.partial(\n",
     "    GenomeBigWigDataset,\n",
     "    fasta_path=fasta_path,\n",
-    "    bigwig_path_list=bigwig_path_list,\n",
+    "    bigwig_path_list=bigwig_paths,\n",
+    "    chrom_regions=species_splits_df,\n",
     "    sequence_length=config[\"sequence_length\"],\n",
     "    tokenizer=tokenizer,\n",
-    "    transform_fn=targets_transform_fn,\n",
+    "    transform_fn=create_targets_scaling_fn(metadata_df),\n",
     "    keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
     ")\n",
     "\n",
     "train_dataset = create_dataset_fn(\n",
-    "    chroms=chrom_splits[\"train\"],\n",
+    "    split=\"train\",\n",
     "    num_samples=config[\"num_steps_training\"] * config[\"batch_size\"],\n",
     ")\n",
     "\n",
     "val_dataset = create_dataset_fn(\n",
-    "    chroms=chrom_splits[\"val\"],\n",
+    "    split=\"val\",\n",
     "    num_samples=config[\"num_validation_samples\"],\n",
     ")\n",
     "\n",
     "test_dataset = create_dataset_fn(\n",
-    "    chroms=chrom_splits[\"test\"],\n",
+    "    split=\"test\",\n",
     "    num_samples=config[\"num_test_samples\"],\n",
     ")\n",
     "\n",
@@ -785,7 +676,7 @@
     "    num_workers=config[\"num_workers\"],\n",
     ")\n",
     "\n",
-    "print(f\"Train samples: {len(train_dataset)}\")\n",
+    "print(f\"\\nTrain samples: {len(train_dataset)}\")\n",
     "print(f\"Val samples: {len(val_dataset)}\")\n",
     "print(f\"Test samples: {len(test_dataset)}\")"
    ]
@@ -912,9 +803,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "train_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n",
-    "val_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n",
-    "test_metrics = TracksMetrics(config[\"bigwig_file_ids\"])"
+    "train_metrics = TracksMetrics(bigwig_ids)\n",
+    "val_metrics = TracksMetrics(bigwig_ids)\n",
+    "test_metrics = TracksMetrics(bigwig_ids)"
    ]
   },
   {
@@ -1098,47 +989,6 @@
     "val_losses = []\n",
     "val_pearson_scores = []\n",
     "\n",
-    "# Initialize interactive plots using FigureWidget for real-time updates\n",
-    "from plotly.graph_objects import FigureWidget\n",
-    "from plotly.subplots import make_subplots\n",
-    "\n",
-    "# Create base figure with subplots\n",
-    "fig_base = make_subplots(\n",
-    "    rows=1, cols=2,\n",
-    "    subplot_titles=('Loss', 'Mean Pearson Correlation'),\n",
-    "    horizontal_spacing=0.15,\n",
-    ")\n",
-    "\n",
-    "# Add empty traces for train and val metrics\n",
-    "fig_base.add_trace(\n",
-    "    go.Scatter(x=[], y=[], mode='lines+markers', name='Train Loss', line=dict(color='blue')),\n",
-    "    row=1, col=1\n",
-    ")\n",
-    "fig_base.add_trace(\n",
-    "    go.Scatter(x=[], y=[], mode='lines+markers', name='Val Loss', line=dict(color='red')),\n",
-    "    row=1, col=1\n",
-    ")\n",
-    "fig_base.add_trace(\n",
-    "    go.Scatter(x=[], y=[], mode='lines+markers', name='Train Pearson', line=dict(color='green')),\n",
-    "    row=1, col=2\n",
-    ")\n",
-    "fig_base.add_trace(\n",
-    "    go.Scatter(x=[], y=[], mode='lines+markers', name='Val Pearson', line=dict(color='orange')),\n",
-    "    row=1, col=2\n",
-    ")\n",
-    "\n",
-    "fig_base.update_xaxes(title_text=\"Step\", row=1, col=1)\n",
-    "fig_base.update_xaxes(title_text=\"Step\", row=1, col=2)\n",
-    "fig_base.update_yaxes(title_text=\"Loss\", row=1, col=1)\n",
-    "fig_base.update_yaxes(title_text=\"Pearson Correlation\", row=1, col=2)\n",
-    "fig_base.update_layout(height=800, width=1600, showlegend=True)\n",
-    "\n",
-    "# Convert to FigureWidget for interactive updates\n",
-    "fig = FigureWidget(fig_base)\n",
-    "\n",
-    "# Display initial plot (will update in place during training)\n",
-    "display(fig)\n",
-    "\n",
     "# Create iterator for training data (will cycle if needed)\n",
     "train_iter = iter(train_loader)\n",
     "\n",
@@ -1183,11 +1033,6 @@
     "        train_losses.append(mean_loss)\n",
     "        train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
     "        \n",
-    "        # Update plots - direct assignment to FigureWidget data updates the plot automatically\n",
-    "        fig.data[0].x = train_steps\n",
-    "        fig.data[0].y = train_losses\n",
-    "        fig.data[2].x = train_steps\n",
-    "        fig.data[2].y = train_pearson_scores\n",
     "        \n",
     "        print(\n",
     "            f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
@@ -1215,11 +1060,6 @@
     "        val_losses.append(val_metrics_dict['loss'])\n",
     "        val_pearson_scores.append(val_pearson_mean)\n",
     "        \n",
-    "        # Update plots with validation data - direct assignment updates the plot automatically\n",
-    "        fig.data[1].x = val_steps\n",
-    "        fig.data[1].y = val_losses\n",
-    "        fig.data[3].x = val_steps\n",
-    "        fig.data[3].y = val_pearson_scores\n",
     "        \n",
     "        print(f\"    Validation Loss: {val_metrics_dict['loss']:.4f}\")\n",
     "        print(f\"    Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
@@ -1228,7 +1068,40 @@
     "        \n",
     "        model.train()  # Back to training mode\n",
     "\n",
-    "print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")"
+    "print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Plot training results\n",
+    "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
+    "\n",
+    "# Plot Loss\n",
+    "axes[0].plot(train_steps, train_losses, 'b-o', label='Train Loss', markersize=4, linewidth=1.5)\n",
+    "if val_steps:\n",
+    "    axes[0].plot(val_steps, val_losses, 'r-s', label='Val Loss', markersize=4, linewidth=1.5)\n",
+    "axes[0].set_xlabel('Step')\n",
+    "axes[0].set_ylabel('Loss')\n",
+    "axes[0].set_title('Loss')\n",
+    "axes[0].legend()\n",
+    "axes[0].grid(True, alpha=0.3)\n",
+    "\n",
+    "# Plot Pearson Correlation\n",
+    "axes[1].plot(train_steps, train_pearson_scores, 'g-o', label='Train Pearson', markersize=4, linewidth=1.5)\n",
+    "if val_steps:\n",
+    "    axes[1].plot(val_steps, val_pearson_scores, 'orange', marker='s', label='Val Pearson', markersize=4, linewidth=1.5)\n",
+    "axes[1].set_xlabel('Step')\n",
+    "axes[1].set_ylabel('Pearson Correlation')\n",
+    "axes[1].set_title('Mean Pearson Correlation')\n",
+    "axes[1].legend()\n",
+    "axes[1].grid(True, alpha=0.3)\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.show()\n"
    ]
   },
   {