ybornachot commited on
Commit
11ccfa8
·
1 Parent(s): 8ade038

refactor: cleaning

Browse files
Files changed (1) hide show
  1. notebooks_tutorials/02_fine_tuning.ipynb +149 -244
notebooks_tutorials/02_fine_tuning.ipynb CHANGED
@@ -8,31 +8,25 @@
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
- "We provide access to the NTv3-benchmark data that we released on our Hugging Face dataset: `InstaDeepAI/NTv3_benchmark_dataset`. In this repository, you will find ready-to-use genome FASTA files, Bigwig tracks, metadata, but also the splits that were used for the benchmark.\n",
12
  "\n",
13
  "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
15
  "- **Constant learning rate**: Uses a fixed learning rate throughout training without learning rate scheduling\n",
16
  "- **No gradient accumulation**: Implements simple step-based training without gradient accumulation, making the training loop more straightforward\n",
17
  "\n",
18
- "**⚡ Key Advantage**: This simplified pipeline achieves close performance to more complex training approaches while enabling fast fine-tuning: on a H100 GPU and using 16 workers for data loading, it takes ~15min to reach acceptable performances for a 32kb functional tracks prediction task on **NTv3_8M_pre** model. The training speed benefits from the efficient NTv3 model architecture, but of course depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time).\n",
19
- "\n",
20
- "**⚠️ Important Note on Hardware Requirements**: While this pipeline is designed to run on limited resources (e.g., Google Colab with a T4 GPU and 2CPUs), the mentioned training time or displayed performances (see **Test evaluation** section) was obtained on a more powerful setup. If you want to reach similar performance levels, you should be aware that you'll need **significant hardware resources** (high-end GPUs with substantial memory and multiple data loading workers). Training times will vary significantly based on your hardware configuration.\n",
21
- "\n",
22
- "The pipeline walks through the complete fine-tuning workflow:\n",
23
- "- Loading genomic FASTA files sequences and their corresponding BigWig signal tracks from Hugging Face dataset\n",
24
- "- Setting up a PyTorch dataset with proper train/validation/test splits\n",
25
- "- Configuring the model architecture with a custom linear head\n",
26
- "- Implementing a training loop with appropriate loss functions and evaluation metrics\n",
27
- "- Evaluation of the fine-tuned model on the test set\n",
28
- "\n",
29
- "This provides a clean interface for fine-tuning and evaluation.\n",
30
- "\n",
31
- "The model architecture consists of a pre-trained NTv3 backbone that processes DNA sequences and a custom linear head that predicts BigWig signal values at single-nucleotide resolution. Predictions are center-cropped to focus on the central portion of the input sequence (configurable via `keep_target_center_fraction`), which helps reduce edge effects from sequence context windows. The training uses a Poisson-Multinomial loss function that captures both the scale and shape of the signal distributions, and evaluation is performed using Pearson correlation metrics on both scaled and raw predictions.\n",
32
  "\n",
33
- "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to genomic tracks or improve performance on particular cell types or experimental conditions.\n",
34
  "\n",
35
- "📝 Note for Google Colab users: This notebook is compatible with Colab and designed to work with limited resources! For faster training, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended). However, keep in mind that the timing benchmarks mentioned above were obtained on much more powerful hardware (H100 GPU), so your training times on Colab may be significantly longer."
36
  ]
37
  },
38
  {
@@ -58,7 +52,6 @@
58
  "metadata": {},
59
  "outputs": [],
60
  "source": [
61
- "# Standard library imports\n",
62
  "import functools\n",
63
  "import fnmatch\n",
64
  "import os\n",
@@ -66,7 +59,6 @@
66
  "from pathlib import Path\n",
67
  "from typing import Callable, Dict, List\n",
68
  "\n",
69
- "# Third-party imports\n",
70
  "from huggingface_hub import HfApi, snapshot_download\n",
71
  "import matplotlib.pyplot as plt\n",
72
  "import numpy as np\n",
@@ -88,15 +80,10 @@
88
  "metadata": {},
89
  "source": [
90
  "# 1. ⚙️ Configuration\n",
91
- " \n",
92
- "💡 **Tip:** The parameters below are pre-configured for minimal requirements and are suitable for running on a Colab GPU, but this may come at the cost of reduced model performance or slower training. \n",
93
- " \n",
94
- "Feel free to experiment with these parameters according to your available resources:\n",
95
- "- If you have a more powerful GPU, **increase** `batch_size`, `learning_rate`, and `num_steps_training` for better performance and more robust training results.\n",
96
- "- To speed up training (especially during data loading), consider increasing the `num_workers` value if memory and CPU resources allow.\n",
97
- "\n",
98
- "Current configuration allow to reach decent performances and completes training in ~1h30 on a colab environment with one T4 GPU and 2CPUs. \n",
99
  "\n",
 
 
 
100
  "\n",
101
  "## Configuration Parameters\n",
102
  "\n",
@@ -204,7 +191,7 @@
204
  },
205
  {
206
  "cell_type": "code",
207
- "execution_count": 3,
208
  "metadata": {},
209
  "outputs": [],
210
  "source": [
@@ -279,8 +266,6 @@
279
  " # FASTA file\n",
280
  " fasta_path_repo = f\"{species}/genome.fasta\"\n",
281
  " fasta_path = str(local_dir / fasta_path_repo)\n",
282
- " if not Path(fasta_path).is_file():\n",
283
- " raise ValueError(f\"FASTA file not found at '{fasta_path}'\")\n",
284
  " \n",
285
  " # BigWig files - use downloaded files directly\n",
286
  " bigwig_dir = local_dir / species / \"functional_tracks\"\n",
@@ -296,8 +281,7 @@
296
  " # Splits file\n",
297
  " splits_path_repo = f\"{species}/splits.bed\"\n",
298
  " splits_path = local_dir / splits_path_repo\n",
299
- " if not splits_path.is_file():\n",
300
- " raise ValueError(f\"Splits file not found at '{splits_path}'\")\n",
301
  " splits_df = pd.read_csv(\n",
302
  " splits_path, \n",
303
  " sep=\"\\t\", \n",
@@ -311,7 +295,7 @@
311
  " metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
312
  "\n",
313
  " # Filter metadata according to species\n",
314
- " metadata_df = metadata_df[metadata_df[\"species\"] == species].reset_index(drop=True)\n",
315
  "\n",
316
  " # Order metadata according to bigwig file ids\n",
317
  " metadata_df = (\n",
@@ -367,23 +351,24 @@
367
  "source": [
368
  "# 3. 🧠 Model and tokenizer setup\n",
369
  " \n",
370
- "In this section, we set up the model and tokenizer. \n",
371
- " \n",
372
- "Our approach uses any suitable pretrained backbone from HuggingFace Transformers (for example, `InstaDeepAI/ntv3_650M_pre`),\n",
373
- "which is then extended with an additional linear head. \n",
374
- " \n",
375
- "This linear head is trained for regression on a set of genomic tracks, \n",
376
- "allowing the model to make predictions for each track at single nucleotide resolution.\n",
377
- " \n",
378
- "The following code wraps the HuggingFace model together with this regression head for the end-to-end task.\n"
379
  ]
380
  },
381
  {
382
  "cell_type": "code",
383
- "execution_count": 20,
384
  "metadata": {},
385
  "outputs": [],
386
  "source": [
 
 
 
 
 
 
 
387
  "class LinearHead(nn.Module):\n",
388
  " \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n",
389
  " def __init__(self, embed_dim: int, num_labels: int):\n",
@@ -419,11 +404,7 @@
419
  " self.backbone = torch.compile(backbone)\n",
420
  " \n",
421
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
422
- "\n",
423
- " if hasattr(self.config, \"embed_dim\"):\n",
424
- " embed_dim = self.config.embed_dim\n",
425
- " else:\n",
426
- " raise ValueError(f\"Could not determine embed_dim for {model_name}\")\n",
427
  " \n",
428
  " # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n",
429
  " self.bigwig_head = LinearHead(embed_dim, len(bigwig_track_names))\n",
@@ -436,10 +417,7 @@
436
  " \n",
437
  " # Crop to center fraction\n",
438
  " if self.keep_target_center_fraction < 1.0:\n",
439
- " seq_len = embedding.shape[1]\n",
440
- " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
441
- " target_length = seq_len - 2 * target_offset\n",
442
- " embedding = embedding[:, target_offset:target_offset + target_length, :]\n",
443
  " \n",
444
  " # Predict bigwig tracks\n",
445
  " bigwig_logits = self.bigwig_head(embedding)\n",
@@ -449,7 +427,7 @@
449
  },
450
  {
451
  "cell_type": "code",
452
- "execution_count": 21,
453
  "metadata": {},
454
  "outputs": [
455
  {
@@ -473,7 +451,6 @@
473
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
474
  ")\n",
475
  "model = model.to(device)\n",
476
- "model.train()\n",
477
  "\n",
478
  "print(f\"Model loaded: {config['model_name']}\")\n",
479
  "print(f\"Number of bigwig tracks: {len(bigwig_ids)}\")\n",
@@ -498,7 +475,7 @@
498
  },
499
  {
500
  "cell_type": "code",
501
- "execution_count": 22,
502
  "metadata": {},
503
  "outputs": [],
504
  "source": [
@@ -539,8 +516,7 @@
539
  " _bigwig_cache[cache_key] = pyBigWig.open(abs_path)\n",
540
  " except Exception as e:\n",
541
  " raise RuntimeError(\n",
542
- " f\"Failed to open BigWig file: {abs_path}\\n\"\n",
543
- " f\"Error: {str(e)}\\n\"\n",
544
  " f\"File exists: {Path(abs_path).exists()}\\n\"\n",
545
  " f\"File size: {Path(abs_path).stat().st_size if Path(abs_path).exists() else 'N/A'} bytes\"\n",
546
  " ) from e\n",
@@ -550,38 +526,10 @@
550
  "\n",
551
  "class GenomeBigWigDataset(Dataset):\n",
552
  " \"\"\"\n",
553
- " Random genomic windows from a reference genome + bigWig signal.\n",
554
- "\n",
555
- " Each sample:\n",
556
- " - picks a random region from the specified split,\n",
557
- " - picks a random window of length `sequence_length` within that region,\n",
558
- " - returns (sequence, signal, chrom, start, end).\n",
559
- "\n",
560
- " This dataset is compatible with multi-worker DataLoaders. BigWig files\n",
561
- " are opened lazily using a process-local cache, ensuring each worker process\n",
562
- " has its own file handles and avoiding concurrent access issues.\n",
563
- "\n",
564
- " Args\n",
565
- " ----\n",
566
- " fasta_path : str\n",
567
- " Path to the reference genome FASTA (e.g. hg38.fna).\n",
568
- " bigwig_path_list : list[str]\n",
569
- " List of paths to bigWig files.\n",
570
- " chrom_regions : pd.DataFrame\n",
571
- " DataFrame with columns: chr_name, start, end, split.\n",
572
- " Contains all genomic regions with their split assignments.\n",
573
- " split : str\n",
574
- " Split name to filter regions (e.g., \"train\", \"val\", \"test\").\n",
575
- " sequence_length : int\n",
576
- " Length of each random window (in bp).\n",
577
- " num_samples : int\n",
578
- " Number of samples the dataset will provide (len(dataset)).\n",
579
- " tokenizer : AutoTokenizer\n",
580
- " Tokenizer to use for tokenization.\n",
581
- " transform_fn : Callable\n",
582
- " Function to transform/scaling bigwig targets.\n",
583
- " keep_target_center_fraction : float\n",
584
- " Fraction of center sequence to keep for target prediction (crops edges to focus on center).\n",
585
  " \"\"\"\n",
586
  "\n",
587
  " def __init__(\n",
@@ -622,9 +570,6 @@
622
  " # Store valid region\n",
623
  " self.valid_regions.append((row.chr_name, row.start, row.end))\n",
624
  "\n",
625
- " if not self.valid_regions:\n",
626
- " raise ValueError(f\"No valid regions found for split '{split}'\")\n",
627
- "\n",
628
  " def __len__(self):\n",
629
  " return self.num_samples\n",
630
  "\n",
@@ -664,10 +609,7 @@
664
  " \n",
665
  " # Crop targets to center fraction\n",
666
  " if self.keep_target_center_fraction < 1.0:\n",
667
- " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n",
668
- " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
669
- " target_length = seq_len - 2 * target_offset\n",
670
- " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
671
  "\n",
672
  " # Apply scaling to targets\n",
673
  " bigwig_targets = self.transform_fn(bigwig_targets)\n",
@@ -691,7 +633,7 @@
691
  },
692
  {
693
  "cell_type": "code",
694
- "execution_count": 8,
695
  "metadata": {},
696
  "outputs": [],
697
  "source": [
@@ -699,13 +641,7 @@
699
  " metadata_df: pd.DataFrame\n",
700
  ") -> Callable[[torch.Tensor], torch.Tensor]:\n",
701
  " \"\"\"\n",
702
- " Build a scaling function based on track means contained in the metadata.\n",
703
- "\n",
704
- " Args:\n",
705
- " metadata_df: pandas.DataFrame with track means\n",
706
- "\n",
707
- " Returns:\n",
708
- " Transform function that scales input tensors\n",
709
  " \"\"\"\n",
710
  " # Open bigwig files and compute track statistics\n",
711
  " track_means = metadata_df[\"mean\"].to_numpy()\n",
@@ -716,9 +652,6 @@
716
  " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
717
  "\n",
718
  " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
719
- " \"\"\"\n",
720
- " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
721
- " \"\"\"\n",
722
  " # Move constants to correct device then normalize\n",
723
  " means = track_means_tensor.to(x.device)\n",
724
  " scaled = x / means\n",
@@ -879,22 +812,30 @@
879
  },
880
  {
881
  "cell_type": "code",
882
- "execution_count": 25,
883
  "metadata": {},
884
  "outputs": [],
885
  "source": [
886
  "class TracksMetrics:\n",
887
- " \"\"\"Simple metrics tracker for tracks prediction.\"\"\"\n",
888
  " \n",
889
- " def __init__(self, track_names: List[str]):\n",
890
  " self.track_names = track_names\n",
891
  " self.num_tracks = len(track_names)\n",
892
- " self.pearson_metric = PearsonCorrCoef(num_outputs=self.num_tracks).to(device)\n",
893
- " self.pearson_metric.set_dtype(torch.float64)\n",
 
 
 
894
  " self.losses = []\n",
 
 
 
 
 
895
  " \n",
896
  " def reset(self):\n",
897
- " self.pearson_metric.reset()\n",
898
  " self.losses = []\n",
899
  " \n",
900
  " def update(\n",
@@ -904,51 +845,70 @@
904
  " loss: float\n",
905
  " ):\n",
906
  " \"\"\"\n",
907
- " Update metrics.\n",
908
- " Args:\n",
909
- " predictions: (batch, seq_len, num_tracks)\n",
910
- " targets: (batch, seq_len, num_tracks)\n",
911
- " loss: scalar loss value\n",
912
  " \"\"\"\n",
913
  " # Flatten batch and sequence dimensions\n",
914
- " pred_flat = predictions.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
915
- " target_flat = targets.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n",
916
- " \n",
917
- " # Convert to float64 for improved numerical stability in Pearson correlation\n",
918
- " pred_flat = pred_flat.to(torch.float64)\n",
919
- " target_flat = target_flat.to(torch.float64)\n",
920
- " self.pearson_metric.update(pred_flat, target_flat)\n",
921
  " \n",
 
 
922
  " self.losses.append(loss)\n",
923
  " \n",
924
  " def compute(self) -> Dict[str, float]:\n",
925
- " \"\"\"Compute and return all metrics.\"\"\"\n",
926
- " metrics_dict = {}\n",
927
- " \n",
928
- " # Compute Pearson correlation per track\n",
929
- " # Move to CPU before converting to numpy\n",
930
- " correlations = self.pearson_metric.compute().cpu().numpy()\n",
931
- " for i, track_name in enumerate(self.track_names):\n",
932
- " metrics_dict[f\"{track_name}/pearson\"] = correlations[i]\n",
933
- " \n",
934
- " # Mean Pearson correlation\n",
935
- " metrics_dict[\"mean/pearson\"] = np.nanmean(correlations)\n",
936
  " \n",
937
  " # Mean loss\n",
938
- " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
939
  " \n",
940
- " return metrics_dict"
 
 
 
 
 
 
 
 
 
 
 
941
  ]
942
  },
943
  {
944
  "cell_type": "code",
945
- "execution_count": 26,
946
  "metadata": {},
947
  "outputs": [],
948
  "source": [
949
- "train_metrics = TracksMetrics(bigwig_ids)\n",
950
- "val_metrics = TracksMetrics(bigwig_ids)\n",
951
- "test_metrics = TracksMetrics(bigwig_ids)"
952
  ]
953
  },
954
  {
@@ -962,7 +922,7 @@
962
  },
963
  {
964
  "cell_type": "code",
965
- "execution_count": 27,
966
  "metadata": {},
967
  "outputs": [],
968
  "source": [
@@ -983,16 +943,8 @@
983
  " epsilon: float = 1e-7,\n",
984
  ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
985
  " \"\"\"\n",
986
- " Regression loss for bigwig tracks (Poisson-Multinomial).\n",
987
- " \n",
988
- " Args:\n",
989
- " logits: (batch, seq_length, num_tracks) - predicted counts\n",
990
- " targets: (batch, seq_length, num_tracks) - target counts\n",
991
- " shape_loss_coefficient: coefficient to weight scale loss\n",
992
- " epsilon: epsilon for numerical stability\n",
993
- " \n",
994
- " Returns:\n",
995
- " loss, scale_loss, shape_loss\n",
996
  " \"\"\"\n",
997
  " batch_size, seq_length, num_tracks = logits.shape\n",
998
  " \n",
@@ -1044,14 +996,16 @@
1044
  },
1045
  {
1046
  "cell_type": "code",
1047
- "execution_count": 28,
1048
  "metadata": {},
1049
  "outputs": [],
1050
  "source": [
1051
  "def train_step(\n",
1052
  " model: nn.Module,\n",
 
1053
  " batch: Dict[str, torch.Tensor],\n",
1054
- ") -> float:\n",
 
1055
  " \"\"\"Single training step.\"\"\"\n",
1056
  " tokens = batch[\"tokens\"].to(device)\n",
1057
  " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
@@ -1065,19 +1019,27 @@
1065
  " logits=bigwig_logits,\n",
1066
  " targets=bigwig_targets,\n",
1067
  " )\n",
1068
- " \n",
1069
  " # Backward pass\n",
 
1070
  " loss.backward()\n",
1071
- " return loss.item()\n",
 
 
 
 
 
 
 
 
 
1072
  "\n",
1073
  "def validation_step(\n",
1074
  " model: nn.Module,\n",
1075
  " batch: Dict[str, torch.Tensor],\n",
1076
  " metrics: TracksMetrics,\n",
1077
- ") -> float:\n",
1078
  " \"\"\"Single validation step.\"\"\"\n",
1079
- " model.eval()\n",
1080
- " \n",
1081
  " tokens = batch[\"tokens\"].to(device)\n",
1082
  " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1083
  " \n",
@@ -1097,14 +1059,12 @@
1097
  " predictions=bigwig_logits,\n",
1098
  " targets=bigwig_targets,\n",
1099
  " loss=loss.item()\n",
1100
- " )\n",
1101
- " \n",
1102
- " return loss.item()"
1103
  ]
1104
  },
1105
  {
1106
  "cell_type": "code",
1107
- "execution_count": 29,
1108
  "metadata": {},
1109
  "outputs": [
1110
  {
@@ -2327,23 +2287,11 @@
2327
  ],
2328
  "source": [
2329
  "# Training loop\n",
2330
- "print(\"Starting training...\")\n",
2331
- "print(f\"Training for {config['num_steps_training']} steps\\n\")\n",
2332
- "\n",
2333
- "model.train()\n",
2334
- "train_metrics.reset()\n",
2335
- "optimizer.zero_grad() # Initialize gradients\n",
2336
- "\n",
2337
- "# Track metrics for plotting\n",
2338
- "train_steps = []\n",
2339
- "train_losses = []\n",
2340
- "train_pearson_scores = []\n",
2341
- "val_steps = []\n",
2342
- "val_losses = []\n",
2343
- "val_pearson_scores = []\n",
2344
  "\n",
2345
  "# Create iterator for training data (will cycle if needed)\n",
2346
  "train_iter = iter(train_loader)\n",
 
2347
  "\n",
2348
  "# Main training loop\n",
2349
  "for step_idx in range(config[\"num_steps_training\"]):\n",
@@ -2354,78 +2302,37 @@
2354
  " train_iter = iter(train_loader)\n",
2355
  " batch = next(train_iter)\n",
2356
  " \n",
2357
- " # Forward pass and backward pass\n",
2358
- " loss = train_step(model, batch)\n",
2359
- " \n",
2360
- " # Update optimizer\n",
2361
- " optimizer.step()\n",
2362
- " optimizer.zero_grad()\n",
2363
- " \n",
2364
- " # Update metrics\n",
2365
- " tokens = batch[\"tokens\"].to(device)\n",
2366
- " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
2367
- " with torch.no_grad():\n",
2368
- " outputs = model(tokens=tokens)\n",
2369
- " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
2370
- " \n",
2371
- " train_metrics.update(\n",
2372
- " predictions=bigwig_logits,\n",
2373
- " targets=bigwig_targets,\n",
2374
- " loss=loss\n",
2375
- " )\n",
2376
- " \n",
2377
  " # Logging\n",
2378
  " if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
2379
- " train_metrics_dict = train_metrics.compute()\n",
2380
- " \n",
2381
- " # Get accumulated mean loss across all batches since last reset\n",
2382
- " mean_loss = train_metrics_dict['loss']\n",
2383
- " \n",
2384
- " # Track metrics for plotting\n",
2385
- " train_steps.append(step_idx + 1)\n",
2386
- " train_losses.append(mean_loss)\n",
2387
- " train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
2388
- " \n",
2389
- " \n",
2390
- " print(\n",
2391
- " f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
2392
- " f\"Loss: {mean_loss:.4f} | \"\n",
2393
- " f\"Mean Pearson: {train_metrics_dict['mean/pearson']:.4f}\"\n",
2394
- " )\n",
2395
  " train_metrics.reset()\n",
2396
  " \n",
2397
  " # Validation\n",
2398
  " if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
2399
  " print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
2400
- " val_metrics.reset()\n",
2401
  " model.eval()\n",
2402
  " \n",
2403
  " for val_batch in val_loader:\n",
2404
- " val_loss = validation_step(model, val_batch, val_metrics)\n",
2405
- " \n",
2406
- " # Print validation metrics\n",
2407
- " val_metrics_dict = val_metrics.compute()\n",
2408
- " val_pearson_mean = val_metrics_dict['mean/pearson']\n",
2409
- " \n",
2410
- " # Track validation metrics\n",
2411
- " val_steps.append(step_idx + 1)\n",
2412
- " val_losses.append(val_metrics_dict['loss'])\n",
2413
- " val_pearson_scores.append(val_pearson_mean)\n",
2414
- " \n",
2415
  " \n",
2416
- " print(f\" Validation Loss: {val_metrics_dict['loss']:.4f}\")\n",
2417
- " print(f\" Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
2418
- " for track_name in bigwig_ids:\n",
2419
- " print(f\" {track_name}/pearson: {val_metrics_dict[f'{track_name}/pearson']:.4f}\")\n",
2420
- " \n",
2421
- " model.train() # Back to training mode\n",
 
2422
  "\n",
2423
  "print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")\n"
2424
  ]
2425
  },
2426
  {
2427
  "cell_type": "code",
2428
- "execution_count": 30,
2429
  "metadata": {},
2430
  "outputs": [
2431
  {
@@ -2441,12 +2348,14 @@
2441
  ],
2442
  "source": [
2443
  "# Plot training results\n",
2444
- "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
 
 
 
2445
  "\n",
2446
  "# Plot Loss\n",
2447
- "axes[0].plot(train_steps, train_losses, 'b-o', label='Train Loss', markersize=4, linewidth=1.5)\n",
2448
- "if val_steps:\n",
2449
- " axes[0].plot(val_steps, val_losses, 'r-s', label='Val Loss', markersize=4, linewidth=1.5)\n",
2450
  "axes[0].set_xlabel('Step')\n",
2451
  "axes[0].set_ylabel('Loss')\n",
2452
  "axes[0].set_title('Loss')\n",
@@ -2454,17 +2363,13 @@
2454
  "axes[0].grid(True, alpha=0.3)\n",
2455
  "\n",
2456
  "# Plot Pearson Correlation\n",
2457
- "axes[1].plot(train_steps, train_pearson_scores, 'g-o', label='Train Pearson', markersize=4, linewidth=1.5)\n",
2458
- "if val_steps:\n",
2459
- " axes[1].plot(val_steps, val_pearson_scores, 'orange', marker='s', label='Val Pearson', markersize=4, linewidth=1.5)\n",
2460
  "axes[1].set_xlabel('Step')\n",
2461
  "axes[1].set_ylabel('Pearson Correlation')\n",
2462
  "axes[1].set_title('Mean Pearson Correlation')\n",
2463
  "axes[1].legend()\n",
2464
- "axes[1].grid(True, alpha=0.3)\n",
2465
- "\n",
2466
- "plt.tight_layout()\n",
2467
- "plt.show()\n"
2468
  ]
2469
  },
2470
  {
 
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
+ "📊 We provide access to the NTv3-benchmark data that we released on our Hugging Face dataset: `InstaDeepAI/NTv3_benchmark_dataset`. In this repository, you will find ready-to-use genome FASTA files, Bigwig tracks, metadata, but also the splits that were used for the benchmark.\n",
12
  "\n",
13
  "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
15
  "- **Constant learning rate**: Uses a fixed learning rate throughout training without learning rate scheduling\n",
16
  "- **No gradient accumulation**: Implements simple step-based training without gradient accumulation, making the training loop more straightforward\n",
17
  "\n",
18
+ "**⚡ Key Advantage**: This simplified pipeline achieves close performance to more complex training approaches while enabling fast fine-tuning: on a H100 GPU and using 16 workers for data loading, it takes ~15min to reach acceptable performances for a 32kb functional tracks prediction task on **NTv3_8M_pre** model. The training speed benefits from the efficient NTv3 model architecture, but of course depends on your hardware capabilities (GPU acceleration and multi-worker data loading significantly reduce training time).\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "## 💻 A note on hardware\n",
 
 
 
 
 
 
26
  "\n",
27
+ "While this pipeline is designed to run on limited resources (e.g., Google Colab with a T4 GPU and 2CPUs), the mentioned training time or displayed performances (see **Test evaluation** section) was obtained on a more powerful setup. If you want to reach similar performance levels, you should be aware that you'll need **significant hardware resources** (high-end GPUs with substantial memory and multiple data loading workers). Training times will vary significantly based on your hardware configuration.\n",
28
  "\n",
29
+ "📝 Note for Google Colab users: This notebook is compatible with Colab and designed to work with limited resources! For faster training, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended)."
30
  ]
31
  },
32
  {
 
52
  "metadata": {},
53
  "outputs": [],
54
  "source": [
 
55
  "import functools\n",
56
  "import fnmatch\n",
57
  "import os\n",
 
59
  "from pathlib import Path\n",
60
  "from typing import Callable, Dict, List\n",
61
  "\n",
 
62
  "from huggingface_hub import HfApi, snapshot_download\n",
63
  "import matplotlib.pyplot as plt\n",
64
  "import numpy as np\n",
 
80
  "metadata": {},
81
  "source": [
82
  "# 1. ⚙️ Configuration\n",
 
 
 
 
 
 
 
 
83
  "\n",
84
+ "⏳ The parameters below are pre-configured to enable training on a T4 GPU (free on Colab). For faster training, use a more powerful GPU and increase the `batch_size`, `learning_rate`, and `num_steps_training` parameters. To speed up dataloading, consider increasing the `num_workers` value if memory and CPU resources allow.\n",
85
+ " \n",
86
+ "🕰️ Current configuration allow to reach decent performances and completes training in ~1h30 on a colab environment with one T4 GPU and 2CPUs. \n",
87
  "\n",
88
  "## Configuration Parameters\n",
89
  "\n",
 
191
  },
192
  {
193
  "cell_type": "code",
194
+ "execution_count": null,
195
  "metadata": {},
196
  "outputs": [],
197
  "source": [
 
266
  " # FASTA file\n",
267
  " fasta_path_repo = f\"{species}/genome.fasta\"\n",
268
  " fasta_path = str(local_dir / fasta_path_repo)\n",
 
 
269
  " \n",
270
  " # BigWig files - use downloaded files directly\n",
271
  " bigwig_dir = local_dir / species / \"functional_tracks\"\n",
 
281
  " # Splits file\n",
282
  " splits_path_repo = f\"{species}/splits.bed\"\n",
283
  " splits_path = local_dir / splits_path_repo\n",
284
+ "\n",
 
285
  " splits_df = pd.read_csv(\n",
286
  " splits_path, \n",
287
  " sep=\"\\t\", \n",
 
295
  " metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
296
  "\n",
297
  " # Filter metadata according to species\n",
298
+ " metadata_df = metadata_df[metadata_df[\"species_common_name\"] == species].reset_index(drop=True)\n",
299
  "\n",
300
  " # Order metadata according to bigwig file ids\n",
301
  " metadata_df = (\n",
 
351
  "source": [
352
  "# 3. 🧠 Model and tokenizer setup\n",
353
  " \n",
354
+ "This section sets up the model by extended any pretrained backbone from HuggingFace Transformers (for example, `InstaDeepAI/ntv3_650M_pre`) with a custom linear head.\n",
355
+ "This linear head is trained for regression on a set of genomic tracks, allowing the model to make predictions for each track at single nucleotide resolution.\n",
356
+ "Predictions are center-cropped to focus on the central portion of the input sequence (configurable via `keep_target_center_fraction`), which helps reduce edge effects from sequence context windows.\n"
 
 
 
 
 
 
357
  ]
358
  },
359
  {
360
  "cell_type": "code",
361
+ "execution_count": null,
362
  "metadata": {},
363
  "outputs": [],
364
  "source": [
365
+ "def crop_center(x: np.ndarray, keep_target_center_fraction: float = 0.375) -> np.ndarray:\n",
366
+ " \"\"\"Crop the central sequence-length fraction for arrays of size (..., seq_len, num_tracks)\"\"\"\n",
367
+ " seq_len = x.shape[-2]\n",
368
+ " target_offset = int(seq_len * (1 - keep_target_center_fraction) // 2)\n",
369
+ " target_length = seq_len - 2 * target_offset\n",
370
+ " return x[..., target_offset:target_offset + target_length, :]\n",
371
+ "\n",
372
  "class LinearHead(nn.Module):\n",
373
  " \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n",
374
  " def __init__(self, embed_dim: int, num_labels: int):\n",
 
404
  " self.backbone = torch.compile(backbone)\n",
405
  " \n",
406
  " self.keep_target_center_fraction = keep_target_center_fraction\n",
407
+ " embed_dim = self.config.embed_dim\n",
 
 
 
 
408
  " \n",
409
  " # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n",
410
  " self.bigwig_head = LinearHead(embed_dim, len(bigwig_track_names))\n",
 
417
  " \n",
418
  " # Crop to center fraction\n",
419
  " if self.keep_target_center_fraction < 1.0:\n",
420
+ " embedding = crop_center(embedding, self.keep_target_center_fraction)\n",
 
 
 
421
  " \n",
422
  " # Predict bigwig tracks\n",
423
  " bigwig_logits = self.bigwig_head(embedding)\n",
 
427
  },
428
  {
429
  "cell_type": "code",
430
+ "execution_count": null,
431
  "metadata": {},
432
  "outputs": [
433
  {
 
451
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
452
  ")\n",
453
  "model = model.to(device)\n",
 
454
  "\n",
455
  "print(f\"Model loaded: {config['model_name']}\")\n",
456
  "print(f\"Number of bigwig tracks: {len(bigwig_ids)}\")\n",
 
475
  },
476
  {
477
  "cell_type": "code",
478
+ "execution_count": null,
479
  "metadata": {},
480
  "outputs": [],
481
  "source": [
 
516
  " _bigwig_cache[cache_key] = pyBigWig.open(abs_path)\n",
517
  " except Exception as e:\n",
518
  " raise RuntimeError(\n",
519
+ " f\"Failed to open BigWig file: {abs_path} with error: {str(e)}\\n\"\n",
 
520
  " f\"File exists: {Path(abs_path).exists()}\\n\"\n",
521
  " f\"File size: {Path(abs_path).stat().st_size if Path(abs_path).exists() else 'N/A'} bytes\"\n",
522
  " ) from e\n",
 
526
  "\n",
527
  "class GenomeBigWigDataset(Dataset):\n",
528
  " \"\"\"\n",
529
+ " A PyTorch dataset to access a reference genome and bigwig tracks. The dataset is \n",
530
+ " compatible with multi-worker DataLoaders (using process-local file handles and lazy \n",
531
+ " loading). For each sample, a random genomic region is picked from the specified split,\n",
532
+ " and a random window of length `sequence_length` within that region is returned.\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  " \"\"\"\n",
534
  "\n",
535
  " def __init__(\n",
 
570
  " # Store valid region\n",
571
  " self.valid_regions.append((row.chr_name, row.start, row.end))\n",
572
  "\n",
 
 
 
573
  " def __len__(self):\n",
574
  " return self.num_samples\n",
575
  "\n",
 
609
  " \n",
610
  " # Crop targets to center fraction\n",
611
  " if self.keep_target_center_fraction < 1.0:\n",
612
+ " bigwig_targets = crop_center(bigwig_targets, self.keep_target_center_fraction)\n",
 
 
 
613
  "\n",
614
  " # Apply scaling to targets\n",
615
  " bigwig_targets = self.transform_fn(bigwig_targets)\n",
 
633
  },
634
  {
635
  "cell_type": "code",
636
+ "execution_count": null,
637
  "metadata": {},
638
  "outputs": [],
639
  "source": [
 
641
  " metadata_df: pd.DataFrame\n",
642
  ") -> Callable[[torch.Tensor], torch.Tensor]:\n",
643
  " \"\"\"\n",
644
+ " Build a scaling function that uses the track means to normalise and softclip the targets.\n",
 
 
 
 
 
 
645
  " \"\"\"\n",
646
  " # Open bigwig files and compute track statistics\n",
647
  " track_means = metadata_df[\"mean\"].to_numpy()\n",
 
652
  " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
653
  "\n",
654
  " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
 
 
 
655
  " # Move constants to correct device then normalize\n",
656
  " means = track_means_tensor.to(x.device)\n",
657
  " scaled = x / means\n",
 
812
  },
813
  {
814
  "cell_type": "code",
815
+ "execution_count": null,
816
  "metadata": {},
817
  "outputs": [],
818
  "source": [
819
  "class TracksMetrics:\n",
820
+ " \"\"\"Metrics to handle multi-track pearson correlations and losses\"\"\"\n",
821
  " \n",
822
+ " def __init__(self, track_names: List[str], split: str):\n",
823
  " self.track_names = track_names\n",
824
  " self.num_tracks = len(track_names)\n",
825
+ " self.split = split\n",
826
+ "\n",
827
+ " # Initialise metrics \n",
828
+ " self.pearson = PearsonCorrCoef(num_outputs=self.num_tracks).to(device)\n",
829
+ " self.pearson.set_dtype(torch.float64) # Use float64 for improved numerical stability\n",
830
  " self.losses = []\n",
831
+ "\n",
832
+ " # Record mean metrics per logging interval\n",
833
+ " self.step_idxs = []\n",
834
+ " self.mean_pearsons = []\n",
835
+ " self.mean_losses = []\n",
836
  " \n",
837
  " def reset(self):\n",
838
+ " self.pearson.reset()\n",
839
  " self.losses = []\n",
840
  " \n",
841
  " def update(\n",
 
845
  " loss: float\n",
846
  " ):\n",
847
  " \"\"\"\n",
848
+ " Update the metrics with predictions and targets of shape (..., num_tracks) and a scalar loss.\n",
 
 
 
 
849
  " \"\"\"\n",
850
  " # Flatten batch and sequence dimensions\n",
851
+ " pred_flat = predictions.detach().reshape(-1, self.num_tracks).to(torch.float64) # (N, num_tracks)\n",
852
+ " target_flat = targets.detach().reshape(-1, self.num_tracks).to(torch.float64) # (N, num_tracks)\n",
 
 
 
 
 
853
  " \n",
854
+ " # Update metrics\n",
855
+ " self.pearson.update(pred_flat, target_flat)\n",
856
  " self.losses.append(loss)\n",
857
  " \n",
858
  " def compute(self) -> Dict[str, float]:\n",
859
+ " \"\"\"Compute the pearson correlations and loss and return a dictionary of metrics.\"\"\"\n",
860
+ " # Per-track Pearson correlations\n",
861
+ " correlations = self.pearson.compute().cpu().numpy()\n",
862
+ " metrics_dict = {\n",
863
+ " f\"{track_name}/pearson\": correlations[i] for i, track_name in enumerate(self.track_names)\n",
864
+ " }\n",
865
+ " metrics_dict[\"mean/pearson\"] = correlations.mean()\n",
 
 
 
 
866
  " \n",
867
  " # Mean loss\n",
868
+ " metrics_dict[\"loss\"] = np.mean(self.losses)\n",
869
+ " \n",
870
+ " return metrics_dict\n",
871
+ "\n",
872
+ " def update_mean_metrics(self, step_idx: int):\n",
873
+ " \"\"\"Update the mean metrics over the logging interval and save to a csv file.\"\"\"\n",
874
+ " # Update mean metrics with the mean pearson & average loss\n",
875
+ " metrics_dict = self.compute()\n",
876
+ " self.step_idxs.append(step_idx)\n",
877
+ " self.mean_pearsons.append(metrics_dict[\"mean/pearson\"])\n",
878
+ " self.mean_losses.append(metrics_dict[\"loss\"])\n",
879
+ "\n",
880
+ " # Save metrics to a csv for plotting\n",
881
+ " data = {\n",
882
+ " \"step\": self.step_idxs,\n",
883
+ " \"mean_loss\": self.mean_losses,\n",
884
+ " \"mean_pearson\": self.mean_pearsons,\n",
885
+ " }\n",
886
+ " df = pd.DataFrame(data)\n",
887
+ " df.to_csv(f\"metrics_{self.split}.csv\", index=False)\n",
888
  " \n",
889
+ " def print_metrics(self, print_per_track: bool = False):\n",
890
+ " \"\"\"Print a summary of the metrics.\"\"\"\n",
891
+ " print(\n",
892
+ " f\"Step {self.step_idxs[-1]}/{config['num_steps_training']} | \"\n",
893
+ " f\"Loss: {self.mean_losses[-1]:.4f} | \"\n",
894
+ " f\"Mean Pearson: {self.mean_pearsons[-1]:.4f}\"\n",
895
+ " )\n",
896
+ " metrics_dict = self.compute()\n",
897
+ " if print_per_track:\n",
898
+ " for metric_key, metric_value in metrics_dict.items():\n",
899
+ " print(f\" {metric_key}: {metric_value:.4f}\")\n",
900
+ " "
901
  ]
902
  },
903
  {
904
  "cell_type": "code",
905
+ "execution_count": null,
906
  "metadata": {},
907
  "outputs": [],
908
  "source": [
909
+ "train_metrics = TracksMetrics(bigwig_ids, \"train\")\n",
910
+ "val_metrics = TracksMetrics(bigwig_ids, \"val\")\n",
911
+ "test_metrics = TracksMetrics(bigwig_ids, \"test\")"
912
  ]
913
  },
914
  {
 
922
  },
923
  {
924
  "cell_type": "code",
925
+ "execution_count": null,
926
  "metadata": {},
927
  "outputs": [],
928
  "source": [
 
943
  " epsilon: float = 1e-7,\n",
944
  ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
945
  " \"\"\"\n",
946
+ " Regression loss for bigwig tracks (Poisson-Multinomial). The logits and targets are\n",
947
+ " expected to be of shape (batch, seq_length, num_tracks).\n",
 
 
 
 
 
 
 
 
948
  " \"\"\"\n",
949
  " batch_size, seq_length, num_tracks = logits.shape\n",
950
  " \n",
 
996
  },
997
  {
998
  "cell_type": "code",
999
+ "execution_count": null,
1000
  "metadata": {},
1001
  "outputs": [],
1002
  "source": [
1003
  "def train_step(\n",
1004
  " model: nn.Module,\n",
1005
+ " optimizer: torch.optim.Optimizer,\n",
1006
  " batch: Dict[str, torch.Tensor],\n",
1007
+ " train_metrics: TracksMetrics,\n",
1008
+ ") -> None:\n",
1009
  " \"\"\"Single training step.\"\"\"\n",
1010
  " tokens = batch[\"tokens\"].to(device)\n",
1011
  " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
 
1019
  " logits=bigwig_logits,\n",
1020
  " targets=bigwig_targets,\n",
1021
  " )\n",
1022
+ "\n",
1023
  " # Backward pass\n",
1024
+ " optimizer.zero_grad()\n",
1025
  " loss.backward()\n",
1026
+ " optimizer.step()\n",
1027
+ "\n",
1028
+ " # Update metrics\n",
1029
+ " train_metrics.update(\n",
1030
+ " predictions=bigwig_logits,\n",
1031
+ " targets=bigwig_targets,\n",
1032
+ " loss=loss.item()\n",
1033
+ " )\n",
1034
+ " \n",
1035
+ "\n",
1036
  "\n",
1037
  "def validation_step(\n",
1038
  " model: nn.Module,\n",
1039
  " batch: Dict[str, torch.Tensor],\n",
1040
  " metrics: TracksMetrics,\n",
1041
+ ") -> None:\n",
1042
  " \"\"\"Single validation step.\"\"\"\n",
 
 
1043
  " tokens = batch[\"tokens\"].to(device)\n",
1044
  " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
1045
  " \n",
 
1059
  " predictions=bigwig_logits,\n",
1060
  " targets=bigwig_targets,\n",
1061
  " loss=loss.item()\n",
1062
+ " )"
 
 
1063
  ]
1064
  },
1065
  {
1066
  "cell_type": "code",
1067
+ "execution_count": null,
1068
  "metadata": {},
1069
  "outputs": [
1070
  {
 
2287
  ],
2288
  "source": [
2289
  "# Training loop\n",
2290
+ "print(f\"Starting training for {config['num_steps_training']} steps\\n\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
2291
  "\n",
2292
  "# Create iterator for training data (will cycle if needed)\n",
2293
  "train_iter = iter(train_loader)\n",
2294
+ "model.train()\n",
2295
  "\n",
2296
  "# Main training loop\n",
2297
  "for step_idx in range(config[\"num_steps_training\"]):\n",
 
2302
  " train_iter = iter(train_loader)\n",
2303
  " batch = next(train_iter)\n",
2304
  " \n",
2305
+ " # Take a training step\n",
2306
+ " train_step(model, optimizer, batch, train_metrics)\n",
2307
+ "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2308
  " # Logging\n",
2309
  " if (step_idx + 1) % config[\"log_every_n_steps\"] == 0:\n",
2310
+ " train_metrics.update_mean_metrics(step_idx + 1)\n",
2311
+ " train_metrics.print_metrics()\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2312
  " train_metrics.reset()\n",
2313
  " \n",
2314
  " # Validation\n",
2315
  " if (step_idx + 1) % config[\"validate_every_n_steps\"] == 0:\n",
2316
  " print(f\"\\nRunning validation at step {step_idx + 1}...\")\n",
 
2317
  " model.eval()\n",
2318
  " \n",
2319
  " for val_batch in val_loader:\n",
2320
+ " validation_step(model, val_batch, val_metrics)\n",
 
 
 
 
 
 
 
 
 
 
2321
  " \n",
2322
+ " val_metrics.update_mean_metrics(step_idx + 1)\n",
2323
+ " val_metrics.print_metrics(print_per_track=True)\n",
2324
+ " val_metrics.reset()\n",
2325
+ "\n",
2326
+ " # Back to training mode\n",
2327
+ " print(\"\\n\" + \"-\"*100 + \"\\nTraining metrics:\")\n",
2328
+ " model.train() \n",
2329
  "\n",
2330
  "print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")\n"
2331
  ]
2332
  },
2333
  {
2334
  "cell_type": "code",
2335
+ "execution_count": null,
2336
  "metadata": {},
2337
  "outputs": [
2338
  {
 
2348
  ],
2349
  "source": [
2350
  "# Plot training results\n",
2351
+ "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
2352
+ "\n",
2353
+ "df_train = pd.read_csv(\"metrics_train.csv\")\n",
2354
+ "df_val = pd.read_csv(\"metrics_val.csv\")\n",
2355
  "\n",
2356
  "# Plot Loss\n",
2357
+ "axes[0].plot(df_train[\"step\"], df_train[\"mean_loss\"], 'b-o', label='Train Loss', markersize=4, linewidth=1.5)\n",
2358
+ "axes[0].plot(df_val[\"step\"], df_val[\"mean_loss\"], 'r-s', label='Val Loss', markersize=4, linewidth=1.5)\n",
 
2359
  "axes[0].set_xlabel('Step')\n",
2360
  "axes[0].set_ylabel('Loss')\n",
2361
  "axes[0].set_title('Loss')\n",
 
2363
  "axes[0].grid(True, alpha=0.3)\n",
2364
  "\n",
2365
  "# Plot Pearson Correlation\n",
2366
+ "axes[1].plot(df_train[\"step\"], df_train[\"mean_pearson\"], 'g-o', label='Train Pearson', markersize=4, linewidth=1.5)\n",
2367
+ "axes[1].plot(df_val[\"step\"], df_val[\"mean_pearson\"], 'orange', marker='s', label='Val Pearson', markersize=4, linewidth=1.5)\n",
 
2368
  "axes[1].set_xlabel('Step')\n",
2369
  "axes[1].set_ylabel('Pearson Correlation')\n",
2370
  "axes[1].set_title('Mean Pearson Correlation')\n",
2371
  "axes[1].legend()\n",
2372
+ "axes[1].grid(True, alpha=0.3)"
 
 
 
2373
  ]
2374
  },
2375
  {