ybornachot commited on
Commit
a30b8d8
·
1 Parent(s): 3ffcb7a

feat: link to HF dataset to abstract data pipeline

Browse files
Files changed (1) hide show
  1. notebooks/03_fine_tuning.ipynb +569 -501
notebooks/03_fine_tuning.ipynb CHANGED
@@ -4,13 +4,13 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# \ud83e\uddec Fine-Tuning a Model on BigWig Tracks Prediction\n",
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
- "**\u26a1 Key Advantage**: This simplified pipeline achieves **close performance to more complex training approaches** while enabling **relatively fast fine-tuning in approximately one hour**. The setup is designed for rapid experimentation and iteration, making it ideal for adapting pre-trained models to your specific genomic tracks or experimental conditions without the overhead of complex distributed training infrastructure.\n",
12
  "\n",
13
- "**\ud83d\udd27 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
  "\n",
15
  "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
@@ -30,24 +30,417 @@
30
  "\n",
31
  "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
  "\n",
33
- "\ud83d\udcdd Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime \u2192 Change runtime type \u2192 GPU (T4 or better recommended).\n"
 
 
 
 
 
 
 
34
  ]
35
  },
36
  {
37
  "cell_type": "code",
38
- "execution_count": null,
39
  "metadata": {},
40
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  "source": [
42
  "# Install dependencies\n",
43
- "!pip install pyfaidx pyBigWig torchmetrics transformers plotly"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  ]
45
  },
46
  {
47
  "cell_type": "markdown",
48
  "metadata": {},
49
  "source": [
50
- "# 1. \ud83d\udce6 Imports + Configuration\n",
51
  "\n",
52
  "## Configuration Parameters\n",
53
  "\n",
@@ -81,43 +474,24 @@
81
  },
82
  {
83
  "cell_type": "code",
84
- "execution_count": null,
85
  "metadata": {},
86
- "outputs": [],
87
- "source": [
88
- "# 0. Imports\n",
89
- "import random\n",
90
- "import functools\n",
91
- "from typing import List, Dict, Callable\n",
92
- "import os\n",
93
- "import subprocess\n",
94
- "\n",
95
- "import torch\n",
96
- "import torch.nn as nn\n",
97
- "import torch.nn.functional as F\n",
98
- "from torch.utils.data import Dataset, DataLoader\n",
99
- "from torch.optim import AdamW\n",
100
- "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
101
- "import numpy as np\n",
102
- "import pyBigWig\n",
103
- "from pyfaidx import Fasta\n",
104
- "from torchmetrics import PearsonCorrCoef\n",
105
- "import plotly.graph_objects as go\n",
106
- "from IPython.display import display\n",
107
- "from tqdm import tqdm"
108
- ]
109
- },
110
- {
111
- "cell_type": "code",
112
- "execution_count": null,
113
- "metadata": {},
114
- "outputs": [],
115
  "source": [
116
  "config = {\n",
117
  " # Model\n",
118
  " \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
119
  " \n",
120
- " # Data\n",
 
121
  " \"data_cache_dir\": \"./data\",\n",
122
  " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
123
  " \"bigwig_url_list\": [\n",
@@ -152,25 +526,8 @@
152
  "\n",
153
  "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
154
  "\n",
155
- "# Extract filenames from URLs\n",
156
- "def extract_filename_from_url(url: str) -> str:\n",
157
- " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
158
- " # Remove query parameters if present\n",
159
- " url_clean = url.split('?')[0]\n",
160
- " # Get the last part of the URL path\n",
161
- " return url_clean.split('/')[-1]\n",
162
- "\n",
163
- "# Create paths for downloaded files\n",
164
- "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
165
- "bigwig_path_list = [\n",
166
- " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
167
- " for url in config[\"bigwig_url_list\"]\n",
168
- "]\n",
169
- "\n",
170
  "# Create bigwig_file_ids from filenames (without extension)\n",
171
  "config[\"bigwig_file_ids\"] = [\n",
172
- " # os.path.splitext(extract_filename_from_url(url))[0]\n",
173
- " # for url in config[\"bigwig_url_list\"]\n",
174
  " \"ENCSR325NFE\",\n",
175
  " \"ENCSR962OTG\",\n",
176
  " \"ENCSR619DQO_P\",\n",
@@ -190,67 +547,7 @@
190
  "cell_type": "markdown",
191
  "metadata": {},
192
  "source": [
193
- "# 2. \ud83d\udce5 Genome & Tracks Data Download\n",
194
- "\n",
195
- "Download the reference genome FASTA file and BigWig signal tracks from public repositories. These files contain the genomic sequences and experimental signal data (e.g., ChIP-seq, ATAC-seq) that we'll use for training."
196
- ]
197
- },
198
- {
199
- "cell_type": "code",
200
- "execution_count": null,
201
- "metadata": {},
202
- "outputs": [],
203
- "source": [
204
- "# Download fasta file\n",
205
- "fasta_filename = extract_filename_from_url(config[\"fasta_url\"])\n",
206
- "fasta_gz_path = os.path.join(config[\"data_cache_dir\"], fasta_filename)\n",
207
- "\n",
208
- "print(f\"Downloading {fasta_filename}...\")\n",
209
- "subprocess.run([\"wget\", \"-c\", config[\"fasta_url\"], \"-O\", fasta_gz_path], check=True)\n",
210
- "\n",
211
- "print(f\"Extracting {fasta_filename}...\")\n",
212
- "subprocess.run([\"gunzip\", \"-f\", fasta_gz_path], check=True)"
213
- ]
214
- },
215
- {
216
- "cell_type": "code",
217
- "execution_count": null,
218
- "metadata": {},
219
- "outputs": [],
220
- "source": [
221
- "# Download bigwig files\n",
222
- "for bigwig_url in config[\"bigwig_url_list\"]:\n",
223
- " filename = extract_filename_from_url(bigwig_url)\n",
224
- " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n",
225
- " print(f\"Downloading {filename}...\")\n",
226
- " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)"
227
- ]
228
- },
229
- {
230
- "cell_type": "markdown",
231
- "metadata": {},
232
- "source": [
233
- "## Data Splits Definition"
234
- ]
235
- },
236
- {
237
- "cell_type": "code",
238
- "execution_count": null,
239
- "metadata": {},
240
- "outputs": [],
241
- "source": [
242
- "chrom_splits = {\n",
243
- " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
244
- " \"val\": ['chr22'],\n",
245
- " \"test\": ['chr21']\n",
246
- "}"
247
- ]
248
- },
249
- {
250
- "cell_type": "markdown",
251
- "metadata": {},
252
- "source": [
253
- "# 3. \ud83e\udde0 Model and tokenizer setup\n",
254
  " \n",
255
  "In this section, we set up the model and tokenizer. \n",
256
  " \n",
@@ -260,12 +557,12 @@
260
  "This linear head is trained for regression on a set of genomic tracks, \n",
261
  "allowing the model to make predictions for each track at single nucleotide resolution.\n",
262
  " \n",
263
- "The following code wraps the HuggingFace model together with this regression head for the end-to-end task.\n"
264
  ]
265
  },
266
  {
267
  "cell_type": "code",
268
- "execution_count": null,
269
  "metadata": {},
270
  "outputs": [],
271
  "source": [
@@ -333,9 +630,19 @@
333
  },
334
  {
335
  "cell_type": "code",
336
- "execution_count": null,
337
  "metadata": {},
338
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
339
  "source": [
340
  "# Load tokenizer\n",
341
  "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
@@ -351,341 +658,84 @@
351
  "\n",
352
  "print(f\"Model loaded: {config['model_name']}\")\n",
353
  "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
354
- "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
355
- ]
356
- },
357
- {
358
- "cell_type": "code",
359
- "execution_count": null,
360
- "metadata": {},
361
- "outputs": [],
362
- "source": [
363
- "# Scaling functions for targets\n",
364
- "def compute_chromosome_stats(track_data: np.ndarray) -> dict:\n",
365
- " \"\"\"\n",
366
- " Compute minimal statistics needed for weighted mean computation.\n",
367
- " \n",
368
- " Args:\n",
369
- " track_data: numpy array of track values for a chromosome\n",
370
- " \n",
371
- " Returns:\n",
372
- " Dictionary with statistics: sum, mean, total_count\n",
373
- " \"\"\"\n",
374
- " track_data = track_data.astype(np.float32)\n",
375
- " \n",
376
- " # Compute statistics\n",
377
- " sum_all = np.sum(track_data)\n",
378
- " total_count = track_data.size\n",
379
- " mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
380
- " \n",
381
- " return {\n",
382
- " \"sum\": sum_all,\n",
383
- " \"mean\": mean_all,\n",
384
- " \"total_count\": total_count,\n",
385
- " }\n",
386
- "\n",
387
- "\n",
388
- "def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
389
- " \"\"\"\n",
390
- " Aggregate chromosome-level statistics into file-level statistics.\n",
391
- " \n",
392
- " Args:\n",
393
- " chr_stats_list: List of dictionaries, each containing chromosome-level statistics\n",
394
- " \n",
395
- " Returns:\n",
396
- " Dictionary with aggregated file-level statistics (only mean)\n",
397
- " \"\"\"\n",
398
- " # Convert to arrays for easier computation\n",
399
- " total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
400
- " means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
401
- " sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
402
- " \n",
403
- " # Aggregate total count\n",
404
- " total_count = np.sum(total_counts)\n",
405
- " \n",
406
- " # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
407
- " mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
408
- " \n",
409
- " return {\n",
410
- " \"total_count\": total_count,\n",
411
- " \"sum\": np.sum(sums),\n",
412
- " \"mean\": mean,\n",
413
- " }\n",
414
- "\n",
415
- "\n",
416
- "def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
417
- " \"\"\"\n",
418
- " Get track means for normalization.\n",
419
- " Computes statistics per chromosome and aggregates using weighted averaging,\n",
420
- " \n",
421
- " Args:\n",
422
- " bigwig_tracks_list: List of pyBigWig file objects\n",
423
- " \n",
424
- " Returns:\n",
425
- " Array of track means, one per bigwig file\n",
426
- " \"\"\"\n",
427
- " track_means = []\n",
428
- " \n",
429
- " for bigwig_track in bigwig_tracks_list:\n",
430
- " chrom_lengths = bigwig_track.chroms()\n",
431
- " all_chr_stats = []\n",
432
- " \n",
433
- " # Compute statistics for each chromosome\n",
434
- " for chrom_name, chrom_length in chrom_lengths.items():\n",
435
- " try:\n",
436
- " # Get chromosome data as numpy array\n",
437
- " bw_array = np.array(\n",
438
- " bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
439
- " dtype=np.float32\n",
440
- " )\n",
441
- " # Replace NaN with 0\n",
442
- " bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
443
- " \n",
444
- " # Compute chromosome-level statistics\n",
445
- " chr_stats = compute_chromosome_stats(bw_array)\n",
446
- " all_chr_stats.append(chr_stats)\n",
447
- " except Exception as e:\n",
448
- " # Skip chromosomes that fail to load\n",
449
- " print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
450
- " continue\n",
451
- " \n",
452
- " if not all_chr_stats:\n",
453
- " raise ValueError(f\"No valid chromosomes found for bigwig track\")\n",
454
- " \n",
455
- " # Aggregate chromosome-level stats into file-level stats\n",
456
- " file_stats = aggregate_file_statistics(all_chr_stats)\n",
457
- " \n",
458
- " # Use the weighted mean for normalization\n",
459
- " track_means.append(file_stats[\"mean\"])\n",
460
- " \n",
461
- " return np.array(track_means, dtype=np.float32)\n",
462
- "\n",
463
- "\n",
464
- "def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
465
- " \"\"\"\n",
466
- " Build a scaling function based on track means computed from bigwig files.\n",
467
- " \n",
468
- " Opens bigwig files, computes track statistics, and creates a transform function.\n",
469
- " The statistics are computed once and reused for all calls to the returned transform function.\n",
470
- " \n",
471
- " Args:\n",
472
- " bigwig_path_list: List of paths to bigwig files\n",
473
- " \n",
474
- " Returns:\n",
475
- " Transform function that scales input tensors\n",
476
- " \"\"\"\n",
477
- " # Open bigwig files and compute track statistics\n",
478
- " print(\"Computing track statistics (this may take a while)...\")\n",
479
- " bw_list = [\n",
480
- " pyBigWig.open(bigwig_path)\n",
481
- " for bigwig_path in bigwig_path_list\n",
482
- " ]\n",
483
- " track_means = get_track_means(bw_list)\n",
484
- " print(f\"Computed track means: {track_means}\")\n",
485
- " print(f\"Track means shape: {track_means.shape}\")\n",
486
- " \n",
487
- " # Create tensor from computed means\n",
488
- " track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
489
- " \n",
490
- " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
491
- " \"\"\"\n",
492
- " x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
493
- " \"\"\"\n",
494
- " # Move constants to correct device then normalize\n",
495
- " means = track_means_tensor.to(x.device)\n",
496
- " scaled = x / means\n",
497
- "\n",
498
- " # Smooth clipping: if > 10, apply formula\n",
499
- " clipped = torch.where(\n",
500
- " scaled > 10.0,\n",
501
- " 2.0 * torch.sqrt(scaled * 10.0) - 10.0,\n",
502
- " scaled,\n",
503
- " )\n",
504
- " return clipped\n",
505
- " \n",
506
- " return transform_fn"
507
  ]
508
  },
509
  {
510
  "cell_type": "markdown",
511
  "metadata": {},
512
  "source": [
513
- "# 4. \ud83d\udd04 Data loading\n",
514
  "\n",
515
- "Create PyTorch datasets and data loaders that efficiently sample random genomic windows from the reference genome and extract corresponding BigWig signal values. The dataset handles sequence tokenization, target scaling, and chromosome-based train/val/test splits."
516
  ]
517
  },
518
  {
519
  "cell_type": "code",
520
- "execution_count": null,
521
  "metadata": {},
522
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  "source": [
524
- "# Process-local cache for BigWig file handles (one per worker process)\n",
525
- "# This allows safe multi-worker DataLoader usage\n",
526
- "_bigwig_cache = {} # Maps (process_id, file_path) -> pyBigWig handle\n",
527
- "\n",
528
- "\n",
529
- "def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
530
- " \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
531
- " process_id = os.getpid()\n",
532
- " cache_key = (process_id, bigwig_path)\n",
533
- " \n",
534
- " if cache_key not in _bigwig_cache:\n",
535
- " _bigwig_cache[cache_key] = pyBigWig.open(bigwig_path)\n",
536
- " \n",
537
- " return _bigwig_cache[cache_key]\n",
538
- "\n",
539
- "\n",
540
- "class GenomeBigWigDataset(Dataset):\n",
541
- " \"\"\"\n",
542
- " Random genomic windows from a reference genome + bigWig signal.\n",
543
- "\n",
544
- " Each sample:\n",
545
- " - picks a chromosome/region (from `chroms` or `regions`),\n",
546
- " - picks a random window of length `sequence_length`,\n",
547
- " - returns (sequence, signal, chrom, start, end).\n",
548
- "\n",
549
- " This dataset is compatible with multi-worker DataLoaders. BigWig files\n",
550
- " are opened lazily using a process-local cache, ensuring each worker process\n",
551
- " has its own file handles and avoiding concurrent access issues.\n",
552
- "\n",
553
- " Args\n",
554
- " ----\n",
555
- " fasta_path : str\n",
556
- " Path to the reference genome FASTA (e.g. hg38.fna).\n",
557
- " bigwig_path_list : str\n",
558
- " Path to the bigWig file (e.g. ENCFF884LDL.bigWig).\n",
559
- " chroms : List[str]\n",
560
- " Chromosome names as they appear in the bigWig (e.g. [\"chr1\", \"chr2\", ...]).\n",
561
- " Used for backward compatibility or when regions=None.\n",
562
- " sequence_length : int\n",
563
- " Length of each random window (in bp).\n",
564
- " num_samples : int\n",
565
- " Number of samples the dataset will provide (len(dataset)).\n",
566
- " tokenizer : AutoTokenizer\n",
567
- " Tokenizer to use for tokenization.\n",
568
- " transform_fn : Callable\n",
569
- " Function to transform/scaling bigwig targets.\n",
570
- " keep_target_center_fraction : float\n",
571
- " Fraction of center sequence to keep for target prediction (crops edges to focus on center).\n",
572
- " regions : List[tuple[str, int, int]] | None\n",
573
- " Optional list of regions as (chromosome, start, end) tuples.\n",
574
- " If provided, samples are drawn randomly from within these regions only.\n",
575
- " This matches the JAX pipeline approach using BED file splits.\n",
576
- " If None, samples from entire chromosomes in `chroms`.\n",
577
- " \"\"\"\n",
578
- "\n",
579
- " def __init__(\n",
580
- " self,\n",
581
- " fasta_path: str,\n",
582
- " bigwig_path_list: list[str],\n",
583
- " chroms: List[str],\n",
584
- " sequence_length: int,\n",
585
- " num_samples: int,\n",
586
- " tokenizer: AutoTokenizer,\n",
587
- " transform_fn: Callable[[torch.Tensor], torch.Tensor],\n",
588
- " keep_target_center_fraction: float = 1.0,\n",
589
- " ):\n",
590
- " super().__init__()\n",
591
- "\n",
592
- " self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
593
- " # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
594
- " self.bigwig_path_list = bigwig_path_list\n",
595
- " self.sequence_length = sequence_length\n",
596
- " self.num_samples = num_samples\n",
597
- " self.tokenizer = tokenizer\n",
598
- " self.transform_fn = transform_fn # Use pre-computed transform function\n",
599
- " self.keep_target_center_fraction = keep_target_center_fraction\n",
600
- " self.chroms = chroms\n",
601
- "\n",
602
- " # Get chromosome lengths from first BigWig file (lazy, cached per process)\n",
603
- " # We need this for validation, so open temporarily\n",
604
- " bw_handle = _get_bigwig_handle(bigwig_path_list[0])\n",
605
- " bw_chrom_lengths = bw_handle.chroms() # dict: chrom -> length\n",
606
- "\n",
607
- " self.valid_chroms = []\n",
608
- " self.chrom_lengths = {}\n",
609
- "\n",
610
- " for c in chroms:\n",
611
- " if c not in bw_chrom_lengths or c not in self.fasta:\n",
612
- " continue\n",
613
- "\n",
614
- " fa_len = len(self.fasta[c])\n",
615
- " bw_len = bw_chrom_lengths[c]\n",
616
- " L = min(fa_len, bw_len)\n",
617
- "\n",
618
- " if L > self.sequence_length:\n",
619
- " self.valid_chroms.append(c)\n",
620
- " self.chrom_lengths[c] = L\n",
621
- "\n",
622
- " if not self.valid_chroms:\n",
623
- " raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n",
624
- "\n",
625
- " def __len__(self):\n",
626
- " return self.num_samples\n",
627
- "\n",
628
- " def __getitem__(self, idx):\n",
629
- "\n",
630
- " # Sample from entire chromosomes\n",
631
- " chrom = random.choice(self.valid_chroms)\n",
632
- " chrom_len = self.chrom_lengths[chrom]\n",
633
- " max_start = chrom_len - self.sequence_length\n",
634
- " start = random.randint(0, max_start)\n",
635
- " end = start + self.sequence_length\n",
636
- "\n",
637
- " # Sequence\n",
638
- " seq = self.fasta[chrom][start:end] # string slice\n",
639
- " # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
640
- " tokenized = self.tokenizer(\n",
641
- " seq,\n",
642
- " padding=\"max_length\",\n",
643
- " truncation=True,\n",
644
- " max_length=self.sequence_length,\n",
645
- " return_tensors=\"pt\",\n",
646
- " )\n",
647
- " tokens = tokenized[\"input_ids\"][0] # Shape: (max_length,)\n",
648
- "\n",
649
- " # Signal from bigWig tracks (numpy array) -> torch tensor\n",
650
- " # Get BigWig handles lazily (cached per worker process)\n",
651
- " bigwig_targets = np.array([\n",
652
- " _get_bigwig_handle(bw_path).values(chrom, start, end, numpy=True)\n",
653
- " for bw_path in self.bigwig_path_list\n",
654
- " ]) # shape (num_tracks, seq_len)\n",
655
- " # Transpose to (seq_len, num_tracks)\n",
656
- " bigwig_targets = bigwig_targets.T\n",
657
- " # pyBigWig returns NaN where no data; turn NaN into 0\n",
658
- " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n",
659
- " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n",
660
- " \n",
661
- " # Crop targets to center fraction\n",
662
- " if self.keep_target_center_fraction < 1.0:\n",
663
- " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n",
664
- " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n",
665
- " target_length = seq_len - 2 * target_offset\n",
666
- " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n",
667
  "\n",
668
- " # Apply scaling to targets\n",
669
- " bigwig_targets = self.transform_fn(bigwig_targets)\n",
 
 
 
 
670
  "\n",
671
- " sample = {\n",
672
- " \"tokens\": tokens,\n",
673
- " \"bigwig_targets\": bigwig_targets,\n",
674
- " \"chrom\": chrom,\n",
675
- " \"start\": start,\n",
676
- " \"end\": end,\n",
677
- " }\n",
678
- " return sample"
679
- ]
680
- },
681
- {
682
- "cell_type": "code",
683
- "execution_count": null,
684
- "metadata": {},
685
- "outputs": [],
686
- "source": [
687
- "# Create scaling function\n",
688
- "targets_transform_fn = create_targets_scaling_fn(bigwig_path_list)"
689
  ]
690
  },
691
  {
@@ -694,74 +744,92 @@
694
  "metadata": {},
695
  "outputs": [],
696
  "source": [
697
- "# Create datasets & dataloaders\n",
698
- "create_dataset_fn = functools.partial(\n",
699
- " GenomeBigWigDataset,\n",
700
- " fasta_path=fasta_path,\n",
701
- " bigwig_path_list=bigwig_path_list,\n",
702
- " sequence_length=config[\"sequence_length\"],\n",
703
- " tokenizer=tokenizer,\n",
704
- " transform_fn=targets_transform_fn,\n",
705
- " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
706
- ")\n",
707
- "\n",
708
- "train_dataset = create_dataset_fn(\n",
709
- " chroms=chrom_splits[\"train\"],\n",
710
- " num_samples=config[\"num_steps_training\"] * config[\"batch_size\"],\n",
711
- ")\n",
712
- "\n",
713
- "val_dataset = create_dataset_fn(\n",
714
- " chroms=chrom_splits[\"val\"],\n",
715
- " num_samples=config[\"num_validation_samples\"],\n",
716
- ")\n",
 
 
 
 
 
717
  "\n",
718
- "test_dataset = create_dataset_fn(\n",
719
- " chroms=chrom_splits[\"test\"],\n",
720
- " num_samples=config[\"num_test_samples\"],\n",
 
 
 
721
  ")\n",
722
  "\n",
723
- "# Create dataloaders\n",
724
- "train_loader = DataLoader(\n",
725
- " train_dataset,\n",
726
- " batch_size=config[\"batch_size\"],\n",
727
- " shuffle=True,\n",
728
- " num_workers=config[\"num_workers\"],\n",
729
- ")\n",
730
  "\n",
731
- "val_loader = DataLoader(\n",
732
- " val_dataset,\n",
733
- " batch_size=config[\"batch_size\"],\n",
734
- " shuffle=False,\n",
735
- " num_workers=config[\"num_workers\"],\n",
736
- ")\n",
 
 
737
  "\n",
738
- "test_loader = DataLoader(\n",
739
- " test_dataset,\n",
740
- " batch_size=config[\"batch_size\"],\n",
741
- " shuffle=False,\n",
742
- " num_workers=config[\"num_workers\"],\n",
743
- ")\n",
744
  "\n",
745
- "print(f\"Train samples: {len(train_dataset)}\")\n",
746
- "print(f\"Val samples: {len(val_dataset)}\")\n",
747
- "print(f\"Test samples: {len(test_dataset)}\")"
 
748
  ]
749
  },
750
  {
751
  "cell_type": "markdown",
752
  "metadata": {},
753
  "source": [
754
- "# 5. \u2699\ufe0f Optimizer setup\n",
755
  "\n",
756
- "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function.\n",
757
- "\n"
758
  ]
759
  },
760
  {
761
  "cell_type": "code",
762
- "execution_count": null,
763
  "metadata": {},
764
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
  "source": [
766
  "# Training setup\n",
767
  "print(f\"Training configuration:\")\n",
@@ -785,14 +853,14 @@
785
  "cell_type": "markdown",
786
  "metadata": {},
787
  "source": [
788
- "# 6. \ud83d\udcca Metrics setup\n",
789
  "\n",
790
  "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
791
  ]
792
  },
793
  {
794
  "cell_type": "code",
795
- "execution_count": null,
796
  "metadata": {},
797
  "outputs": [],
798
  "source": [
@@ -865,7 +933,7 @@
865
  },
866
  {
867
  "cell_type": "code",
868
- "execution_count": null,
869
  "metadata": {},
870
  "outputs": [],
871
  "source": [
@@ -878,14 +946,14 @@
878
  "cell_type": "markdown",
879
  "metadata": {},
880
  "source": [
881
- "# 7. \ud83d\udcc9 Loss functions\n",
882
  "\n",
883
  "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
884
  ]
885
  },
886
  {
887
  "cell_type": "code",
888
- "execution_count": null,
889
  "metadata": {},
890
  "outputs": [],
891
  "source": [
@@ -960,14 +1028,14 @@
960
  "cell_type": "markdown",
961
  "metadata": {},
962
  "source": [
963
- "# 8. \ud83c\udfc3 Training loop\n",
964
  "\n",
965
  "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
966
  ]
967
  },
968
  {
969
  "cell_type": "code",
970
- "execution_count": null,
971
  "metadata": {},
972
  "outputs": [],
973
  "source": [
@@ -1035,7 +1103,7 @@
1035
  },
1036
  {
1037
  "cell_type": "code",
1038
- "execution_count": null,
1039
  "metadata": {},
1040
  "outputs": [],
1041
  "source": [
@@ -1192,7 +1260,7 @@
1192
  "cell_type": "markdown",
1193
  "metadata": {},
1194
  "source": [
1195
- "# 9. \ud83e\uddea Test evaluation\n",
1196
  "\n",
1197
  "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
1198
  ]
@@ -1270,4 +1338,4 @@
1270
  },
1271
  "nbformat": 4,
1272
  "nbformat_minor": 2
1273
- }
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# 🧬 Fine-Tuning a Model on BigWig Tracks Prediction\n",
8
  "\n",
9
  "This notebook demonstrates a **simplified fine-tuning setup** that enables training of a pre-trained Nucleotide Transformer v3 (NTv3) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a pre-trained NTv3 backbone as a feature extractor and adds a custom prediction head that outputs single-nucleotide resolution signal values for various genomic tracks (e.g., ChIP-seq, ATAC-seq, RNA-seq).\n",
10
  "\n",
11
+ "**⚡ Key Advantage**: This simplified pipeline achieves **close performance to more complex training approaches** while enabling **relatively fast fine-tuning in approximately one hour**. The setup is designed for rapid experimentation and iteration, making it ideal for adapting pre-trained models to your specific genomic tracks or experimental conditions without the overhead of complex distributed training infrastructure.\n",
12
  "\n",
13
+ "**🔧 Main Simplifications**: Compared to the full supervised tracks pipeline, this notebook simplifies several aspects to enable faster iteration:\n",
14
  "\n",
15
  "- **Data splits**: Uses simple chromosome-based train/val/test splits (e.g., assigning entire chromosomes to each split) instead of more complex region-based splits\n",
16
  "- **Random sequence sampling**: The dataset randomly samples sequences from chromosomes/regions on-the-fly, rather than using pre-computed sliding windows\n",
 
30
  "\n",
31
  "If you're interested in using pre-trained models for inference without fine-tuning, or exploring different model architectures, please refer to other notebooks in this collection. This notebook focuses specifically on the simplified fine-tuning process, which is useful when you want to quickly adapt a pre-trained model to your specific genomic tracks or improve performance on particular cell types or experimental conditions.\n",
32
  "\n",
33
+ "📝 Note for Google Colab users: This notebook is compatible with Colab! For faster training, make sure to enable GPU: Runtime Change runtime type GPU (T4 or better recommended).\n"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "# 0. 📦 Imports"
41
  ]
42
  },
43
  {
44
  "cell_type": "code",
45
+ "execution_count": 1,
46
  "metadata": {},
47
+ "outputs": [
48
+ {
49
+ "name": "stdout",
50
+ "output_type": "stream",
51
+ "text": [
52
+ "Collecting datasets\n",
53
+ " Downloading datasets-4.4.1-py3-none-any.whl.metadata (19 kB)\n",
54
+ "Requirement already satisfied: transformers in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (4.57.3)\n",
55
+ "Requirement already satisfied: torchmetrics in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (1.8.2)\n",
56
+ "Requirement already satisfied: plotly in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (6.5.0)\n",
57
+ "Requirement already satisfied: filelock in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (3.20.0)\n",
58
+ "Requirement already satisfied: numpy>=1.17 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (2.3.5)\n",
59
+ "Collecting pyarrow>=21.0.0 (from datasets)\n",
60
+ " Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)\n",
61
+ "Collecting dill<0.4.1,>=0.3.0 (from datasets)\n",
62
+ " Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)\n",
63
+ "Collecting pandas (from datasets)\n",
64
+ " Downloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)\n",
65
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.2/91.2 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
66
+ "\u001b[?25hRequirement already satisfied: requests>=2.32.2 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (2.32.5)\n",
67
+ "Collecting httpx<1.0.0 (from datasets)\n",
68
+ " Using cached httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)\n",
69
+ "Requirement already satisfied: tqdm>=4.66.3 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (4.67.1)\n",
70
+ "Collecting xxhash (from datasets)\n",
71
+ " Downloading xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)\n",
72
+ "Collecting multiprocess<0.70.19 (from datasets)\n",
73
+ " Downloading multiprocess-0.70.18-py312-none-any.whl.metadata (7.5 kB)\n",
74
+ "Collecting fsspec<=2025.10.0,>=2023.1.0 (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
75
+ " Downloading fsspec-2025.10.0-py3-none-any.whl.metadata (10 kB)\n",
76
+ "Requirement already satisfied: huggingface-hub<2.0,>=0.25.0 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (0.36.0)\n",
77
+ "Requirement already satisfied: packaging in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (25.0)\n",
78
+ "Requirement already satisfied: pyyaml>=5.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from datasets) (6.0.3)\n",
79
+ "Requirement already satisfied: regex!=2019.12.17 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from transformers) (2025.11.3)\n",
80
+ "Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from transformers) (0.22.1)\n",
81
+ "Requirement already satisfied: safetensors>=0.4.3 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from transformers) (0.7.0)\n",
82
+ "Requirement already satisfied: torch>=2.0.0 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torchmetrics) (2.9.1)\n",
83
+ "Requirement already satisfied: lightning-utilities>=0.8.0 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torchmetrics) (0.15.2)\n",
84
+ "Requirement already satisfied: narwhals>=1.15.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from plotly) (2.13.0)\n",
85
+ "Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
86
+ " Downloading aiohttp-3.13.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (8.1 kB)\n",
87
+ "Collecting anyio (from httpx<1.0.0->datasets)\n",
88
+ " Downloading anyio-4.12.0-py3-none-any.whl.metadata (4.3 kB)\n",
89
+ "Requirement already satisfied: certifi in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (2025.11.12)\n",
90
+ "Collecting httpcore==1.* (from httpx<1.0.0->datasets)\n",
91
+ " Using cached httpcore-1.0.9-py3-none-any.whl.metadata (21 kB)\n",
92
+ "Requirement already satisfied: idna in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from httpx<1.0.0->datasets) (3.11)\n",
93
+ "Collecting h11>=0.16 (from httpcore==1.*->httpx<1.0.0->datasets)\n",
94
+ " Using cached h11-0.16.0-py3-none-any.whl.metadata (8.3 kB)\n",
95
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (4.15.0)\n",
96
+ "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from huggingface-hub<2.0,>=0.25.0->datasets) (1.2.0)\n",
97
+ "Requirement already satisfied: setuptools in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from lightning-utilities>=0.8.0->torchmetrics) (80.9.0)\n",
98
+ "Requirement already satisfied: charset_normalizer<4,>=2 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.4.4)\n",
99
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2.6.1)\n",
100
+ "Requirement already satisfied: sympy>=1.13.3 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (1.14.0)\n",
101
+ "Requirement already satisfied: networkx>=2.5.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (3.6.1)\n",
102
+ "Requirement already satisfied: jinja2 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (3.1.6)\n",
103
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.8.93)\n",
104
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.8.90)\n",
105
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.8.90)\n",
106
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (9.10.2.21)\n",
107
+ "Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.8.4.1)\n",
108
+ "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (11.3.3.83)\n",
109
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (10.3.9.90)\n",
110
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (11.7.3.90)\n",
111
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.5.8.93)\n",
112
+ "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (0.7.1)\n",
113
+ "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (2.27.5)\n",
114
+ "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (3.3.20)\n",
115
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.8.90)\n",
116
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (12.8.93)\n",
117
+ "Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (1.13.1.3)\n",
118
+ "Requirement already satisfied: triton==3.5.1 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from torch>=2.0.0->torchmetrics) (3.5.1)\n",
119
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from pandas->datasets) (2.9.0.post0)\n",
120
+ "Collecting pytz>=2020.1 (from pandas->datasets)\n",
121
+ " Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)\n",
122
+ "Collecting tzdata>=2022.7 (from pandas->datasets)\n",
123
+ " Downloading tzdata-2025.3-py2.py3-none-any.whl.metadata (1.4 kB)\n",
124
+ "Collecting aiohappyeyeballs>=2.5.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
125
+ " Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)\n",
126
+ "Collecting aiosignal>=1.4.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
127
+ " Downloading aiosignal-1.4.0-py3-none-any.whl.metadata (3.7 kB)\n",
128
+ "Collecting attrs>=17.3.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
129
+ " Downloading attrs-25.4.0-py3-none-any.whl.metadata (10 kB)\n",
130
+ "Collecting frozenlist>=1.1.1 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
131
+ " Downloading frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (20 kB)\n",
132
+ "Collecting multidict<7.0,>=4.5 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
133
+ " Downloading multidict-6.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (5.3 kB)\n",
134
+ "Collecting propcache>=0.2.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
135
+ " Downloading propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)\n",
136
+ "Collecting yarl<2.0,>=1.17.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.10.0,>=2023.1.0->datasets)\n",
137
+ " Downloading yarl-1.22.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (75 kB)\n",
138
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.1/75.1 kB\u001b[0m \u001b[31m23.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
139
+ "\u001b[?25hRequirement already satisfied: six>=1.5 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n",
140
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from sympy>=1.13.3->torch>=2.0.0->torchmetrics) (1.3.0)\n",
141
+ "Requirement already satisfied: MarkupSafe>=2.0 in /home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages (from jinja2->torch>=2.0.0->torchmetrics) (3.0.3)\n",
142
+ "Downloading datasets-4.4.1-py3-none-any.whl (511 kB)\n",
143
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m511.6/511.6 kB\u001b[0m \u001b[31m32.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
144
+ "\u001b[?25hDownloading dill-0.4.0-py3-none-any.whl (119 kB)\n",
145
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m119.7/119.7 kB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
146
+ "\u001b[?25hDownloading fsspec-2025.10.0-py3-none-any.whl (200 kB)\n",
147
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m201.0/201.0 kB\u001b[0m \u001b[31m16.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
148
+ "\u001b[?25hUsing cached httpx-0.28.1-py3-none-any.whl (73 kB)\n",
149
+ "Using cached httpcore-1.0.9-py3-none-any.whl (78 kB)\n",
150
+ "Downloading multiprocess-0.70.18-py312-none-any.whl (150 kB)\n",
151
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.3/150.3 kB\u001b[0m \u001b[31m18.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
152
+ "\u001b[?25hDownloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl (47.7 MB)\n",
153
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.7/47.7 MB\u001b[0m \u001b[31m33.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n",
154
+ "\u001b[?25hDownloading pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (12.4 MB)\n",
155
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.4/12.4 MB\u001b[0m \u001b[31m43.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m0:01\u001b[0m01\u001b[0m\n",
156
+ "\u001b[?25hDownloading xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (193 kB)\n",
157
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m193.9/193.9 kB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
158
+ "\u001b[?25hDownloading aiohttp-3.13.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (1.8 MB)\n",
159
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m38.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m31m44.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n",
160
+ "\u001b[?25hDownloading pytz-2025.2-py2.py3-none-any.whl (509 kB)\n",
161
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m509.2/509.2 kB\u001b[0m \u001b[31m31.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
162
+ "\u001b[?25hDownloading tzdata-2025.3-py2.py3-none-any.whl (348 kB)\n",
163
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m348.5/348.5 kB\u001b[0m \u001b[31m40.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
164
+ "\u001b[?25hDownloading anyio-4.12.0-py3-none-any.whl (113 kB)\n",
165
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m113.4/113.4 kB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
166
+ "\u001b[?25hDownloading aiohappyeyeballs-2.6.1-py3-none-any.whl (15 kB)\n",
167
+ "Downloading aiosignal-1.4.0-py3-none-any.whl (7.5 kB)\n",
168
+ "Downloading attrs-25.4.0-py3-none-any.whl (67 kB)\n",
169
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.6/67.6 kB\u001b[0m \u001b[31m18.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
170
+ "\u001b[?25hDownloading frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl (242 kB)\n",
171
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━��━━━━━━━━━━━━━\u001b[0m \u001b[32m242.4/242.4 kB\u001b[0m \u001b[31m25.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
172
+ "\u001b[?25hUsing cached h11-0.16.0-py3-none-any.whl (37 kB)\n",
173
+ "Downloading multidict-6.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (256 kB)\n",
174
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m256.1/256.1 kB\u001b[0m \u001b[31m23.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
175
+ "\u001b[?25hDownloading propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (221 kB)\n",
176
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m221.6/221.6 kB\u001b[0m \u001b[31m28.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
177
+ "\u001b[?25hDownloading yarl-1.22.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (377 kB)\n",
178
+ "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m377.3/377.3 kB\u001b[0m \u001b[31m31.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
179
+ "\u001b[?25hInstalling collected packages: pytz, xxhash, tzdata, pyarrow, propcache, multidict, h11, fsspec, frozenlist, dill, attrs, anyio, aiohappyeyeballs, yarl, pandas, multiprocess, httpcore, aiosignal, httpx, aiohttp, datasets\n",
180
+ " Attempting uninstall: fsspec\n",
181
+ " Found existing installation: fsspec 2025.12.0\n",
182
+ " Uninstalling fsspec-2025.12.0:\n",
183
+ " Successfully uninstalled fsspec-2025.12.0\n",
184
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
185
+ "genomix-research 0.1.0 requires absl-py==2.1.0, which is not installed.\n",
186
+ "genomix-research 0.1.0 requires aiobotocore==2.21.1, which is not installed.\n",
187
+ "genomix-research 0.1.0 requires aioitertools==0.12.0, which is not installed.\n",
188
+ "genomix-research 0.1.0 requires antlr4-python3-runtime==4.9.3, which is not installed.\n",
189
+ "genomix-research 0.1.0 requires argon2-cffi==23.1.0, which is not installed.\n",
190
+ "genomix-research 0.1.0 requires argon2-cffi-bindings==21.2.0, which is not installed.\n",
191
+ "genomix-research 0.1.0 requires array-record==0.8.1, which is not installed.\n",
192
+ "genomix-research 0.1.0 requires arrow==1.3.0, which is not installed.\n",
193
+ "genomix-research 0.1.0 requires astunparse==1.6.3, which is not installed.\n",
194
+ "genomix-research 0.1.0 requires async-lru==2.0.5, which is not installed.\n",
195
+ "genomix-research 0.1.0 requires babel==2.17.0, which is not installed.\n",
196
+ "genomix-research 0.1.0 requires beautifulsoup4==4.13.3, which is not installed.\n",
197
+ "genomix-research 0.1.0 requires biopython==1.85, which is not installed.\n",
198
+ "genomix-research 0.1.0 requires bleach==6.2.0, which is not installed.\n",
199
+ "genomix-research 0.1.0 requires boto3==1.37.1, which is not installed.\n",
200
+ "genomix-research 0.1.0 requires botocore==1.37.1, which is not installed.\n",
201
+ "genomix-research 0.1.0 requires bravado==11.1.0, which is not installed.\n",
202
+ "genomix-research 0.1.0 requires bravado-core==5.16.1, which is not installed.\n",
203
+ "genomix-research 0.1.0 requires bx-python==0.13.0, which is not installed.\n",
204
+ "genomix-research 0.1.0 requires cachetools==5.5.2, which is not installed.\n",
205
+ "genomix-research 0.1.0 requires cffi==1.17.1, which is not installed.\n",
206
+ "genomix-research 0.1.0 requires cfgv==3.4.0, which is not installed.\n",
207
+ "genomix-research 0.1.0 requires chex==0.1.88, which is not installed.\n",
208
+ "genomix-research 0.1.0 requires click==8.1.8, which is not installed.\n",
209
+ "genomix-research 0.1.0 requires cloudpickle==3.1.1, which is not installed.\n",
210
+ "genomix-research 0.1.0 requires defusedxml==0.7.1, which is not installed.\n",
211
+ "genomix-research 0.1.0 requires distlib==0.3.9, which is not installed.\n",
212
+ "genomix-research 0.1.0 requires distrax>=0.1.5, which is not installed.\n",
213
+ "genomix-research 0.1.0 requires dm-tree==0.1.9, which is not installed.\n",
214
+ "genomix-research 0.1.0 requires etils==1.12.1, which is not installed.\n",
215
+ "genomix-research 0.1.0 requires fastjsonschema==2.21.1, which is not installed.\n",
216
+ "genomix-research 0.1.0 requires flatbuffers==25.2.10, which is not installed.\n",
217
+ "genomix-research 0.1.0 requires flax==0.10.4, which is not installed.\n",
218
+ "genomix-research 0.1.0 requires fqdn==1.5.1, which is not installed.\n",
219
+ "genomix-research 0.1.0 requires future==1.0.0, which is not installed.\n",
220
+ "genomix-research 0.1.0 requires gast==0.6.0, which is not installed.\n",
221
+ "genomix-research 0.1.0 requires gcsfs==2025.3.0, which is not installed.\n",
222
+ "genomix-research 0.1.0 requires gitdb==4.0.12, which is not installed.\n",
223
+ "genomix-research 0.1.0 requires gitpython==3.1.44, which is not installed.\n",
224
+ "genomix-research 0.1.0 requires google-api-core==2.24.1, which is not installed.\n",
225
+ "genomix-research 0.1.0 requires google-api-python-client==2.165.0, which is not installed.\n",
226
+ "genomix-research 0.1.0 requires google-auth==2.38.0, which is not installed.\n",
227
+ "genomix-research 0.1.0 requires google-auth-httplib2==0.2.0, which is not installed.\n",
228
+ "genomix-research 0.1.0 requires google-auth-oauthlib==1.2.1, which is not installed.\n",
229
+ "genomix-research 0.1.0 requires google-cloud-core==2.4.2, which is not installed.\n",
230
+ "genomix-research 0.1.0 requires google-cloud-storage==3.1.0, which is not installed.\n",
231
+ "genomix-research 0.1.0 requires google-crc32c==1.6.0, which is not installed.\n",
232
+ "genomix-research 0.1.0 requires google-pasta==0.2.0, which is not installed.\n",
233
+ "genomix-research 0.1.0 requires google-resumable-media==2.7.2, which is not installed.\n",
234
+ "genomix-research 0.1.0 requires googleapis-common-protos==1.69.1, which is not installed.\n",
235
+ "genomix-research 0.1.0 requires grain==0.2.11, which is not installed.\n",
236
+ "genomix-research 0.1.0 requires grigri==0.0.2, which is not installed.\n",
237
+ "genomix-research 0.1.0 requires grpcio==1.71.0, which is not installed.\n",
238
+ "genomix-research 0.1.0 requires h5py==3.13.0, which is not installed.\n",
239
+ "genomix-research 0.1.0 requires httplib2==0.22.0, which is not installed.\n",
240
+ "genomix-research 0.1.0 requires humanize==4.12.1, which is not installed.\n",
241
+ "genomix-research 0.1.0 requires hydra-core==1.3.2, which is not installed.\n",
242
+ "genomix-research 0.1.0 requires identify==2.6.9, which is not installed.\n",
243
+ "genomix-research 0.1.0 requires importlib-resources==6.5.2, which is not installed.\n",
244
+ "genomix-research 0.1.0 requires iniconfig==2.0.0, which is not installed.\n",
245
+ "genomix-research 0.1.0 requires isoduration==20.11.0, which is not installed.\n",
246
+ "genomix-research 0.1.0 requires jax==0.5.3, which is not installed.\n",
247
+ "genomix-research 0.1.0 requires jaxlib==0.5.3, which is not installed.\n",
248
+ "genomix-research 0.1.0 requires jaxtyping==0.2.38, which is not installed.\n",
249
+ "genomix-research 0.1.0 requires jmespath==1.0.1, which is not installed.\n",
250
+ "genomix-research 0.1.0 requires json5==0.10.0, which is not installed.\n",
251
+ "genomix-research 0.1.0 requires jsonpointer==3.0.0, which is not installed.\n",
252
+ "genomix-research 0.1.0 requires jsonref==1.1.0, which is not installed.\n",
253
+ "genomix-research 0.1.0 requires jsonschema==4.23.0, which is not installed.\n",
254
+ "genomix-research 0.1.0 requires jsonschema-specifications==2024.10.1, which is not installed.\n",
255
+ "genomix-research 0.1.0 requires jupyter==1.1.1, which is not installed.\n",
256
+ "genomix-research 0.1.0 requires jupyter-console==6.6.3, which is not installed.\n",
257
+ "genomix-research 0.1.0 requires jupyter-events==0.12.0, which is not installed.\n",
258
+ "genomix-research 0.1.0 requires jupyter-lsp==2.2.5, which is not installed.\n",
259
+ "genomix-research 0.1.0 requires jupyter-server==2.15.0, which is not installed.\n",
260
+ "genomix-research 0.1.0 requires jupyter-server-terminals==0.5.3, which is not installed.\n",
261
+ "genomix-research 0.1.0 requires jupyterlab==4.3.6, which is not installed.\n",
262
+ "genomix-research 0.1.0 requires jupyterlab-pygments==0.3.0, which is not installed.\n",
263
+ "genomix-research 0.1.0 requires jupyterlab-server==2.27.3, which is not installed.\n",
264
+ "genomix-research 0.1.0 requires keras>=3.11.3, which is not installed.\n",
265
+ "genomix-research 0.1.0 requires libclang==18.1.1, which is not installed.\n",
266
+ "genomix-research 0.1.0 requires markdown==3.7, which is not installed.\n",
267
+ "genomix-research 0.1.0 requires markdown-it-py==3.0.0, which is not installed.\n",
268
+ "genomix-research 0.1.0 requires mdurl==0.1.2, which is not installed.\n",
269
+ "genomix-research 0.1.0 requires mistune==3.1.3, which is not installed.\n",
270
+ "genomix-research 0.1.0 requires ml-dtypes==0.5.1, which is not installed.\n",
271
+ "genomix-research 0.1.0 requires monotonic==1.6, which is not installed.\n",
272
+ "genomix-research 0.1.0 requires more-itertools==10.6.0, which is not installed.\n",
273
+ "genomix-research 0.1.0 requires msgpack==1.1.0, which is not installed.\n",
274
+ "genomix-research 0.1.0 requires namex==0.0.8, which is not installed.\n",
275
+ "genomix-research 0.1.0 requires natsort==8.4.0, which is not installed.\n",
276
+ "genomix-research 0.1.0 requires nbclient==0.10.2, which is not installed.\n",
277
+ "genomix-research 0.1.0 requires nbconvert==7.16.6, which is not installed.\n",
278
+ "genomix-research 0.1.0 requires nbformat==5.10.4, which is not installed.\n",
279
+ "genomix-research 0.1.0 requires ncls==0.0.68, which is not installed.\n",
280
+ "genomix-research 0.1.0 requires neptune==1.13.0, which is not installed.\n",
281
+ "genomix-research 0.1.0 requires nodeenv==1.9.1, which is not installed.\n",
282
+ "genomix-research 0.1.0 requires notebook==7.3.3, which is not installed.\n",
283
+ "genomix-research 0.1.0 requires notebook-shim==0.2.4, which is not installed.\n",
284
+ "genomix-research 0.1.0 requires oauthlib==3.2.2, which is not installed.\n",
285
+ "genomix-research 0.1.0 requires omegaconf==2.3.0, which is not installed.\n",
286
+ "genomix-research 0.1.0 requires opt-einsum==3.4.0, which is not installed.\n",
287
+ "genomix-research 0.1.0 requires optax==0.2.4, which is not installed.\n",
288
+ "genomix-research 0.1.0 requires optree==0.14.1, which is not installed.\n",
289
+ "genomix-research 0.1.0 requires orbax==0.1.9, which is not installed.\n",
290
+ "genomix-research 0.1.0 requires orbax-checkpoint==0.11.8, which is not installed.\n",
291
+ "genomix-research 0.1.0 requires overrides==7.7.0, which is not installed.\n",
292
+ "genomix-research 0.1.0 requires pandocfilters==1.5.1, which is not installed.\n",
293
+ "genomix-research 0.1.0 requires pluggy==1.5.0, which is not installed.\n",
294
+ "genomix-research 0.1.0 requires pre-commit==4.1.0, which is not installed.\n",
295
+ "genomix-research 0.1.0 requires prometheus-client==0.21.1, which is not installed.\n",
296
+ "genomix-research 0.1.0 requires proto-plus==1.26.0, which is not installed.\n",
297
+ "genomix-research 0.1.0 requires protobuf==4.25.7, which is not installed.\n",
298
+ "genomix-research 0.1.0 requires pyasn1==0.6.1, which is not installed.\n",
299
+ "genomix-research 0.1.0 requires pyasn1-modules==0.4.1, which is not installed.\n",
300
+ "genomix-research 0.1.0 requires pycparser==2.22, which is not installed.\n",
301
+ "genomix-research 0.1.0 requires pyjwt==2.10.1, which is not installed.\n",
302
+ "genomix-research 0.1.0 requires pyranges==0.1.4, which is not installed.\n",
303
+ "genomix-research 0.1.0 requires pysam==0.23.0, which is not installed.\n",
304
+ "genomix-research 0.1.0 requires pytest==8.3.5, which is not installed.\n",
305
+ "genomix-research 0.1.0 requires pytest-randomly>=3.16.0, which is not installed.\n",
306
+ "genomix-research 0.1.0 requires pytest-split>=0.10.0, which is not installed.\n",
307
+ "genomix-research 0.1.0 requires python-json-logger==3.3.0, which is not installed.\n",
308
+ "genomix-research 0.1.0 requires ray[default]>=2.49.0, which is not installed.\n",
309
+ "genomix-research 0.1.0 requires referencing==0.36.2, which is not installed.\n",
310
+ "genomix-research 0.1.0 requires requests-oauthlib==2.0.0, which is not installed.\n",
311
+ "genomix-research 0.1.0 requires rfc3339-validator==0.1.4, which is not installed.\n",
312
+ "genomix-research 0.1.0 requires rfc3986-validator==0.1.1, which is not installed.\n",
313
+ "genomix-research 0.1.0 requires rfc3987==1.3.8, which is not installed.\n",
314
+ "genomix-research 0.1.0 requires rich==13.9.4, which is not installed.\n",
315
+ "genomix-research 0.1.0 requires rpds-py==0.23.1, which is not installed.\n",
316
+ "genomix-research 0.1.0 requires rsa==4.7, which is not installed.\n",
317
+ "genomix-research 0.1.0 requires s3fs==2025.3.0, which is not installed.\n",
318
+ "genomix-research 0.1.0 requires s3transfer==0.11.3, which is not installed.\n",
319
+ "genomix-research 0.1.0 requires scikit-learn>=1.6.1, which is not installed.\n",
320
+ "genomix-research 0.1.0 requires scipy==1.15.2, which is not installed.\n",
321
+ "genomix-research 0.1.0 requires seaborn>=0.13.2, which is not installed.\n",
322
+ "genomix-research 0.1.0 requires send2trash==1.8.3, which is not installed.\n",
323
+ "genomix-research 0.1.0 requires simplejson==3.20.1, which is not installed.\n",
324
+ "genomix-research 0.1.0 requires smmap==5.0.2, which is not installed.\n",
325
+ "genomix-research 0.1.0 requires sniffio==1.3.1, which is not installed.\n",
326
+ "genomix-research 0.1.0 requires sorted-nearest==0.0.39, which is not installed.\n",
327
+ "genomix-research 0.1.0 requires soupsieve==2.6, which is not installed.\n",
328
+ "genomix-research 0.1.0 requires swagger-spec-validator==3.0.4, which is not installed.\n",
329
+ "genomix-research 0.1.0 requires tabulate==0.9.0, which is not installed.\n",
330
+ "genomix-research 0.1.0 requires tenacity>=9.1.2, which is not installed.\n",
331
+ "genomix-research 0.1.0 requires tensorboard==2.19.0, which is not installed.\n",
332
+ "genomix-research 0.1.0 requires tensorboard-data-server==0.7.2, which is not installed.\n",
333
+ "genomix-research 0.1.0 requires tensorboard-plugin-profile==2.20.6, which is not installed.\n",
334
+ "genomix-research 0.1.0 requires tensorflow==2.19.0, which is not installed.\n",
335
+ "genomix-research 0.1.0 requires tensorflow-io==0.37.1, which is not installed.\n",
336
+ "genomix-research 0.1.0 requires tensorflow-io-gcs-filesystem==0.37.1, which is not installed.\n",
337
+ "genomix-research 0.1.0 requires tensorstore==0.1.71, which is not installed.\n",
338
+ "genomix-research 0.1.0 requires termcolor==2.5.0, which is not installed.\n",
339
+ "genomix-research 0.1.0 requires terminado==0.18.1, which is not installed.\n",
340
+ "genomix-research 0.1.0 requires tinycss2==1.4.0, which is not installed.\n",
341
+ "genomix-research 0.1.0 requires toolz==1.0.0, which is not installed.\n",
342
+ "genomix-research 0.1.0 requires treescope==0.1.9, which is not installed.\n",
343
+ "genomix-research 0.1.0 requires types-python-dateutil==2.9.0.20241206, which is not installed.\n",
344
+ "genomix-research 0.1.0 requires umap-learn>=0.5.9.post2, which is not installed.\n",
345
+ "genomix-research 0.1.0 requires uri-template==1.3.0, which is not installed.\n",
346
+ "genomix-research 0.1.0 requires uritemplate==4.1.1, which is not installed.\n",
347
+ "genomix-research 0.1.0 requires virtualenv==20.29.3, which is not installed.\n",
348
+ "genomix-research 0.1.0 requires wadler-lindig==0.1.4, which is not installed.\n",
349
+ "genomix-research 0.1.0 requires waffle==0.4.0, which is not installed.\n",
350
+ "genomix-research 0.1.0 requires webcolors==24.11.1, which is not installed.\n",
351
+ "genomix-research 0.1.0 requires webencodings==0.5.1, which is not installed.\n",
352
+ "genomix-research 0.1.0 requires websocket-client==1.8.0, which is not installed.\n",
353
+ "genomix-research 0.1.0 requires werkzeug==3.1.3, which is not installed.\n",
354
+ "genomix-research 0.1.0 requires wheel==0.45.1, which is not installed.\n",
355
+ "genomix-research 0.1.0 requires wrapt==1.17.2, which is not installed.\n",
356
+ "genomix-research 0.1.0 requires zipp==3.21.0, which is not installed.\n",
357
+ "genomix-research 0.1.0 requires aiohappyeyeballs==2.5.0, but you have aiohappyeyeballs 2.6.1 which is incompatible.\n",
358
+ "genomix-research 0.1.0 requires aiohttp==3.11.13, but you have aiohttp 3.13.2 which is incompatible.\n",
359
+ "genomix-research 0.1.0 requires aiosignal==1.3.2, but you have aiosignal 1.4.0 which is incompatible.\n",
360
+ "genomix-research 0.1.0 requires anyio==4.9.0, but you have anyio 4.12.0 which is incompatible.\n",
361
+ "genomix-research 0.1.0 requires asttokens==3.0.0, but you have asttokens 3.0.1 which is incompatible.\n",
362
+ "genomix-research 0.1.0 requires attrs==25.1.0, but you have attrs 25.4.0 which is incompatible.\n",
363
+ "genomix-research 0.1.0 requires certifi==2025.1.31, but you have certifi 2025.11.12 which is incompatible.\n",
364
+ "genomix-research 0.1.0 requires charset-normalizer==3.4.1, but you have charset-normalizer 3.4.4 which is incompatible.\n",
365
+ "genomix-research 0.1.0 requires comm==0.2.2, but you have comm 0.2.3 which is incompatible.\n",
366
+ "genomix-research 0.1.0 requires debugpy==1.8.13, but you have debugpy 1.8.17 which is incompatible.\n",
367
+ "genomix-research 0.1.0 requires executing==2.2.0, but you have executing 2.2.1 which is incompatible.\n",
368
+ "genomix-research 0.1.0 requires filelock==3.17.0, but you have filelock 3.20.0 which is incompatible.\n",
369
+ "genomix-research 0.1.0 requires frozenlist==1.5.0, but you have frozenlist 1.8.0 which is incompatible.\n",
370
+ "genomix-research 0.1.0 requires fsspec==2025.3.0, but you have fsspec 2025.10.0 which is incompatible.\n",
371
+ "genomix-research 0.1.0 requires h11==0.14.0, but you have h11 0.16.0 which is incompatible.\n",
372
+ "genomix-research 0.1.0 requires httpcore==1.0.7, but you have httpcore 1.0.9 which is incompatible.\n",
373
+ "genomix-research 0.1.0 requires idna==3.10, but you have idna 3.11 which is incompatible.\n",
374
+ "genomix-research 0.1.0 requires ipykernel==6.29.5, but you have ipykernel 7.1.0 which is incompatible.\n",
375
+ "genomix-research 0.1.0 requires ipython==9.0.2, but you have ipython 9.8.0 which is incompatible.\n",
376
+ "genomix-research 0.1.0 requires ipywidgets==8.1.5, but you have ipywidgets 8.1.8 which is incompatible.\n",
377
+ "genomix-research 0.1.0 requires jupyter-core==5.7.2, but you have jupyter-core 5.9.1 which is incompatible.\n",
378
+ "genomix-research 0.1.0 requires jupyterlab-widgets==3.0.13, but you have jupyterlab-widgets 3.0.16 which is incompatible.\n",
379
+ "genomix-research 0.1.0 requires markupsafe==3.0.2, but you have markupsafe 3.0.3 which is incompatible.\n",
380
+ "genomix-research 0.1.0 requires matplotlib-inline==0.1.7, but you have matplotlib-inline 0.2.1 which is incompatible.\n",
381
+ "genomix-research 0.1.0 requires multidict==6.1.0, but you have multidict 6.7.0 which is incompatible.\n",
382
+ "genomix-research 0.1.0 requires numpy==2.1.3, but you have numpy 2.3.5 which is incompatible.\n",
383
+ "genomix-research 0.1.0 requires packaging==24.2, but you have packaging 25.0 which is incompatible.\n",
384
+ "genomix-research 0.1.0 requires pandas==2.2.3, but you have pandas 2.3.3 which is incompatible.\n",
385
+ "genomix-research 0.1.0 requires parso==0.8.4, but you have parso 0.8.5 which is incompatible.\n",
386
+ "genomix-research 0.1.0 requires pillow==11.1.0, but you have pillow 12.0.0 which is incompatible.\n",
387
+ "genomix-research 0.1.0 requires platformdirs==4.3.6, but you have platformdirs 4.5.1 which is incompatible.\n",
388
+ "genomix-research 0.1.0 requires prompt-toolkit==3.0.50, but you have prompt-toolkit 3.0.52 which is incompatible.\n",
389
+ "genomix-research 0.1.0 requires propcache==0.3.0, but you have propcache 0.4.1 which is incompatible.\n",
390
+ "genomix-research 0.1.0 requires psutil==7.0.0, but you have psutil 7.1.3 which is incompatible.\n",
391
+ "genomix-research 0.1.0 requires pygments==2.19.1, but you have pygments 2.19.2 which is incompatible.\n",
392
+ "genomix-research 0.1.0 requires pyparsing==3.2.1, but you have pyparsing 3.2.5 which is incompatible.\n",
393
+ "genomix-research 0.1.0 requires pytz==2025.1, but you have pytz 2025.2 which is incompatible.\n",
394
+ "genomix-research 0.1.0 requires pyyaml==6.0.2, but you have pyyaml 6.0.3 which is incompatible.\n",
395
+ "genomix-research 0.1.0 requires pyzmq==26.3.0, but you have pyzmq 27.1.0 which is incompatible.\n",
396
+ "genomix-research 0.1.0 requires regex==2024.11.6, but you have regex 2025.11.3 which is incompatible.\n",
397
+ "genomix-research 0.1.0 requires requests==2.32.3, but you have requests 2.32.5 which is incompatible.\n",
398
+ "genomix-research 0.1.0 requires setuptools==77.0.1, but you have setuptools 80.9.0 which is incompatible.\n",
399
+ "genomix-research 0.1.0 requires tornado==6.4.2, but you have tornado 6.5.2 which is incompatible.\n",
400
+ "genomix-research 0.1.0 requires typing-extensions==4.12.2, but you have typing-extensions 4.15.0 which is incompatible.\n",
401
+ "genomix-research 0.1.0 requires tzdata==2025.1, but you have tzdata 2025.3 which is incompatible.\n",
402
+ "genomix-research 0.1.0 requires urllib3==2.3.0, but you have urllib3 2.6.1 which is incompatible.\n",
403
+ "genomix-research 0.1.0 requires wcwidth==0.2.13, but you have wcwidth 0.2.14 which is incompatible.\n",
404
+ "genomix-research 0.1.0 requires widgetsnbextension==4.0.13, but you have widgetsnbextension 4.0.15 which is incompatible.\n",
405
+ "genomix-research 0.1.0 requires yarl==1.18.3, but you have yarl 1.22.0 which is incompatible.\u001b[0m\u001b[31m\n",
406
+ "\u001b[0mSuccessfully installed aiohappyeyeballs-2.6.1 aiohttp-3.13.2 aiosignal-1.4.0 anyio-4.12.0 attrs-25.4.0 datasets-4.4.1 dill-0.4.0 frozenlist-1.8.0 fsspec-2025.10.0 h11-0.16.0 httpcore-1.0.9 httpx-0.28.1 multidict-6.7.0 multiprocess-0.70.18 pandas-2.3.3 propcache-0.4.1 pyarrow-22.0.0 pytz-2025.2 tzdata-2025.3 xxhash-3.6.0 yarl-1.22.0\n"
407
+ ]
408
+ }
409
+ ],
410
  "source": [
411
  "# Install dependencies\n",
412
+ "!pip install datasets transformers torchmetrics plotly "
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": 7,
418
+ "metadata": {},
419
+ "outputs": [],
420
+ "source": [
421
+ "# Imports\n",
422
+ "from typing import List, Dict\n",
423
+ "import os\n",
424
+ "\n",
425
+ "import torch\n",
426
+ "import torch.nn as nn\n",
427
+ "import torch.nn.functional as F\n",
428
+ "from torch.utils.data import DataLoader\n",
429
+ "from torch.optim import AdamW\n",
430
+ "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
431
+ "from datasets import load_dataset\n",
432
+ "import numpy as np\n",
433
+ "from torchmetrics import PearsonCorrCoef\n",
434
+ "import plotly.graph_objects as go\n",
435
+ "from IPython.display import display\n",
436
+ "from tqdm import tqdm"
437
  ]
438
  },
439
  {
440
  "cell_type": "markdown",
441
  "metadata": {},
442
  "source": [
443
+ "# 1. ⚙️ Configuration\n",
444
  "\n",
445
  "## Configuration Parameters\n",
446
  "\n",
 
474
  },
475
  {
476
  "cell_type": "code",
477
+ "execution_count": 13,
478
  "metadata": {},
479
+ "outputs": [
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "Using device: cpu\n"
485
+ ]
486
+ }
487
+ ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  "source": [
489
  "config = {\n",
490
  " # Model\n",
491
  " \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
492
  " \n",
493
+ " # Data - Hugging Face Dataset Configuration\n",
494
+ " \"dataset_name\": \"InstaDeepAI/bigwig_tracks\", # Hugging Face dataset name or path to script\n",
495
  " \"data_cache_dir\": \"./data\",\n",
496
  " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
497
  " \"bigwig_url_list\": [\n",
 
526
  "\n",
527
  "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
528
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  "# Create bigwig_file_ids from filenames (without extension)\n",
530
  "config[\"bigwig_file_ids\"] = [\n",
 
 
531
  " \"ENCSR325NFE\",\n",
532
  " \"ENCSR962OTG\",\n",
533
  " \"ENCSR619DQO_P\",\n",
 
547
  "cell_type": "markdown",
548
  "metadata": {},
549
  "source": [
550
+ "# 2. 🧠 Model and tokenizer setup\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  " \n",
552
  "In this section, we set up the model and tokenizer. \n",
553
  " \n",
 
557
  "This linear head is trained for regression on a set of genomic tracks, \n",
558
  "allowing the model to make predictions for each track at single nucleotide resolution.\n",
559
  " \n",
560
+ "The following code wraps the HuggingFace model together with this regression head for the end-to-end task."
561
  ]
562
  },
563
  {
564
  "cell_type": "code",
565
+ "execution_count": 9,
566
  "metadata": {},
567
  "outputs": [],
568
  "source": [
 
630
  },
631
  {
632
  "cell_type": "code",
633
+ "execution_count": 10,
634
  "metadata": {},
635
+ "outputs": [
636
+ {
637
+ "name": "stdout",
638
+ "output_type": "stream",
639
+ "text": [
640
+ "Model loaded: InstaDeepAI/NTv3_8M_pre\n",
641
+ "Number of bigwig tracks: 4\n",
642
+ "Model parameters: 7,694,015\n"
643
+ ]
644
+ }
645
+ ],
646
  "source": [
647
  "# Load tokenizer\n",
648
  "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n",
 
658
  "\n",
659
  "print(f\"Model loaded: {config['model_name']}\")\n",
660
  "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n",
661
+ "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  ]
663
  },
664
  {
665
  "cell_type": "markdown",
666
  "metadata": {},
667
  "source": [
668
+ "# 3. 📥 Dataset setup\n",
669
  "\n",
670
+ "Load the Hugging Face dataset and set up the data pipeline. The dataset automatically handles downloading FASTA and BigWig files, normalizing tracks, and sampling random genomic windows."
671
  ]
672
  },
673
  {
674
  "cell_type": "code",
675
+ "execution_count": 14,
676
  "metadata": {},
677
+ "outputs": [
678
+ {
679
+ "name": "stdout",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "Loading dataset from InstaDeepAI/bigwig_tracks...\n"
683
+ ]
684
+ },
685
+ {
686
+ "data": {
687
+ "application/vnd.jupyter.widget-view+json": {
688
+ "model_id": "d9e36ca0c8e544339833c04f68f485aa",
689
+ "version_major": 2,
690
+ "version_minor": 0
691
+ },
692
+ "text/plain": [
693
+ "README.md: 0%| | 0.00/4.24k [00:00<?, ?B/s]"
694
+ ]
695
+ },
696
+ "metadata": {},
697
+ "output_type": "display_data"
698
+ },
699
+ {
700
+ "ename": "FileNotFoundError",
701
+ "evalue": "Couldn't find any data file at /home/y-bornachot/ntv3/notebooks/InstaDeepAI/bigwig_tracks. Couldn't find 'InstaDeepAI/bigwig_tracks' on the Hugging Face Hub either: FileNotFoundError: Unable to find 'hf://datasets/InstaDeepAI/bigwig_tracks@7fe68eaafda66223c3fe392f5fa2ad81173047a1/./data/chr1' with any supported extension ['.csv', '.tsv', '.json', '.jsonl', '.ndjson', '.parquet', '.geoparquet', '.gpq', '.arrow', '.txt', '.tar', '.xml', '.hdf5', '.h5', '.blp', '.bmp', '.dib', '.bufr', '.cur', '.pcx', '.dcx', '.dds', '.ps', '.eps', '.fit', '.fits', '.fli', '.flc', '.ftc', '.ftu', '.gbr', '.gif', '.grib', '.png', '.apng', '.jp2', '.j2k', '.jpc', '.jpf', '.jpx', '.j2c', '.icns', '.ico', '.im', '.iim', '.tif', '.tiff', '.jfif', '.jpe', '.jpg', '.jpeg', '.mpg', '.mpeg', '.msp', '.pcd', '.pxr', '.pbm', '.pgm', '.ppm', '.pnm', '.psd', '.bw', '.rgb', '.rgba', '.sgi', '.ras', '.tga', '.icb', '.vda', '.vst', '.webp', '.wmf', '.emf', '.xbm', '.xpm', '.BLP', '.BMP', '.DIB', '.BUFR', '.CUR', '.PCX', '.DCX', '.DDS', '.PS', '.EPS', '.FIT', '.FITS', '.FLI', '.FLC', '.FTC', '.FTU', '.GBR', '.GIF', '.GRIB', '.PNG', '.APNG', '.JP2', '.J2K', '.JPC', '.JPF', '.JPX', '.J2C', '.ICNS', '.ICO', '.IM', '.IIM', '.TIF', '.TIFF', '.JFIF', '.JPE', '.JPG', '.JPEG', '.MPG', '.MPEG', '.MSP', '.PCD', '.PXR', '.PBM', '.PGM', '.PPM', '.PNM', '.PSD', '.BW', '.RGB', '.RGBA', '.SGI', '.RAS', '.TGA', '.ICB', '.VDA', '.VST', '.WEBP', '.WMF', '.EMF', '.XBM', '.XPM', '.aiff', '.au', '.avr', '.caf', '.flac', '.htk', '.svx', '.mat4', '.mat5', '.mpc2k', '.ogg', '.paf', '.pvf', '.raw', '.rf64', '.sd2', '.sds', '.ircam', '.voc', '.w64', '.wav', '.nist', '.wavex', '.wve', '.xi', '.mp3', '.opus', '.3gp', '.3g2', '.avi', '.asf', '.flv', '.mp4', '.mov', '.m4v', '.mkv', '.webm', '.f4v', '.wmv', '.wma', '.ogm', '.mxf', '.nut', '.AIFF', '.AU', '.AVR', '.CAF', '.FLAC', '.HTK', '.SVX', '.MAT4', '.MAT5', '.MPC2K', '.OGG', '.PAF', '.PVF', '.RAW', '.RF64', '.SD2', '.SDS', '.IRCAM', '.VOC', '.W64', '.WAV', '.NIST', '.WAVEX', '.WVE', '.XI', '.MP3', '.OPUS', '.3GP', '.3G2', '.AVI', '.ASF', '.FLV', '.MP4', '.MOV', '.M4V', '.MKV', '.WEBM', '.F4V', '.WMV', '.WMA', '.OGM', '.MXF', '.NUT', '.pdf', '.PDF', '.nii', '.nii.gz', '.NII', '.NII.GZ', '.zip']",
702
+ "output_type": "error",
703
+ "traceback": [
704
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
705
+ "\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)",
706
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 9\u001b[39m num_samples = {\n\u001b[32m 10\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtrain\u001b[39m\u001b[33m\"\u001b[39m: config[\u001b[33m\"\u001b[39m\u001b[33mnum_steps_training\u001b[39m\u001b[33m\"\u001b[39m] * config[\u001b[33m\"\u001b[39m\u001b[33mbatch_size\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m 11\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mval\u001b[39m\u001b[33m\"\u001b[39m: config[\u001b[33m\"\u001b[39m\u001b[33mnum_validation_samples\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m 12\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mtest\u001b[39m\u001b[33m\"\u001b[39m: config[\u001b[33m\"\u001b[39m\u001b[33mnum_test_samples\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m 13\u001b[39m }\n\u001b[32m 15\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mLoading dataset from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig[\u001b[33m'\u001b[39m\u001b[33mdataset_name\u001b[39m\u001b[33m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m...\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m dataset = \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 17\u001b[39m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdataset_name\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 18\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[43m=\u001b[49m\u001b[43mchrom_splits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 19\u001b[39m \u001b[43m \u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 20\u001b[39m \u001b[43m \u001b[49m\u001b[43mfasta_url\u001b[49m\u001b[43m=\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mfasta_url\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 21\u001b[39m \u001b[43m \u001b[49m\u001b[43mbigwig_urls\u001b[49m\u001b[43m=\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mbigwig_url_list\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 22\u001b[39m \u001b[43m \u001b[49m\u001b[43msequence_length\u001b[49m\u001b[43m=\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43msequence_length\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 23\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m=\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdata_cache_dir\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 24\u001b[39m \u001b[43m)\u001b[49m\n",
707
+ "\u001b[36mFile \u001b[39m\u001b[32m~/venvs/ntv3-env/lib/python3.12/site-packages/datasets/load.py:1397\u001b[39m, in \u001b[36mload_dataset\u001b[39m\u001b[34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, keep_in_memory, save_infos, revision, token, streaming, num_proc, storage_options, **config_kwargs)\u001b[39m\n\u001b[32m 1392\u001b[39m verification_mode = VerificationMode(\n\u001b[32m 1393\u001b[39m (verification_mode \u001b[38;5;129;01mor\u001b[39;00m VerificationMode.BASIC_CHECKS) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m save_infos \u001b[38;5;28;01melse\u001b[39;00m VerificationMode.ALL_CHECKS\n\u001b[32m 1394\u001b[39m )\n\u001b[32m 1396\u001b[39m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1397\u001b[39m builder_instance = \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1398\u001b[39m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1399\u001b[39m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1400\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1401\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1402\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1403\u001b[39m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1404\u001b[39m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1405\u001b[39m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1406\u001b[39m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1407\u001b[39m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1408\u001b[39m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[43m=\u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1409\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1410\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1412\u001b[39m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[32m 1413\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n",
708
+ "\u001b[36mFile \u001b[39m\u001b[32m~/venvs/ntv3-env/lib/python3.12/site-packages/datasets/load.py:1137\u001b[39m, in \u001b[36mload_dataset_builder\u001b[39m\u001b[34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, storage_options, **config_kwargs)\u001b[39m\n\u001b[32m 1135\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 1136\u001b[39m features = _fix_for_backward_compatible_features(features)\n\u001b[32m-> \u001b[39m\u001b[32m1137\u001b[39m dataset_module = \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1138\u001b[39m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1139\u001b[39m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1140\u001b[39m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1141\u001b[39m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1142\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1143\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1144\u001b[39m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1145\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1146\u001b[39m \u001b[38;5;66;03m# Get dataset builder class\u001b[39;00m\n\u001b[32m 1147\u001b[39m builder_kwargs = dataset_module.builder_kwargs\n",
709
+ "\u001b[36mFile \u001b[39m\u001b[32m~/venvs/ntv3-env/lib/python3.12/site-packages/datasets/load.py:1032\u001b[39m, in \u001b[36mdataset_module_factory\u001b[39m\u001b[34m(path, revision, download_config, download_mode, data_dir, data_files, cache_dir, **download_kwargs)\u001b[39m\n\u001b[32m 1030\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e1 \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1031\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e1, \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m):\n\u001b[32m-> \u001b[39m\u001b[32m1032\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[32m 1033\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCouldn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt find any data file at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrelative_to_absolute_path(path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1034\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCouldn\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt find \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m on the Hugging Face Hub either: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(e1).\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me1\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 1035\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1036\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e1 \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1037\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
710
+ "\u001b[31mFileNotFoundError\u001b[39m: Couldn't find any data file at /home/y-bornachot/ntv3/notebooks/InstaDeepAI/bigwig_tracks. Couldn't find 'InstaDeepAI/bigwig_tracks' on the Hugging Face Hub either: FileNotFoundError: Unable to find 'hf://datasets/InstaDeepAI/bigwig_tracks@7fe68eaafda66223c3fe392f5fa2ad81173047a1/./data/chr1' with any supported extension ['.csv', '.tsv', '.json', '.jsonl', '.ndjson', '.parquet', '.geoparquet', '.gpq', '.arrow', '.txt', '.tar', '.xml', '.hdf5', '.h5', '.blp', '.bmp', '.dib', '.bufr', '.cur', '.pcx', '.dcx', '.dds', '.ps', '.eps', '.fit', '.fits', '.fli', '.flc', '.ftc', '.ftu', '.gbr', '.gif', '.grib', '.png', '.apng', '.jp2', '.j2k', '.jpc', '.jpf', '.jpx', '.j2c', '.icns', '.ico', '.im', '.iim', '.tif', '.tiff', '.jfif', '.jpe', '.jpg', '.jpeg', '.mpg', '.mpeg', '.msp', '.pcd', '.pxr', '.pbm', '.pgm', '.ppm', '.pnm', '.psd', '.bw', '.rgb', '.rgba', '.sgi', '.ras', '.tga', '.icb', '.vda', '.vst', '.webp', '.wmf', '.emf', '.xbm', '.xpm', '.BLP', '.BMP', '.DIB', '.BUFR', '.CUR', '.PCX', '.DCX', '.DDS', '.PS', '.EPS', '.FIT', '.FITS', '.FLI', '.FLC', '.FTC', '.FTU', '.GBR', '.GIF', '.GRIB', '.PNG', '.APNG', '.JP2', '.J2K', '.JPC', '.JPF', '.JPX', '.J2C', '.ICNS', '.ICO', '.IM', '.IIM', '.TIF', '.TIFF', '.JFIF', '.JPE', '.JPG', '.JPEG', '.MPG', '.MPEG', '.MSP', '.PCD', '.PXR', '.PBM', '.PGM', '.PPM', '.PNM', '.PSD', '.BW', '.RGB', '.RGBA', '.SGI', '.RAS', '.TGA', '.ICB', '.VDA', '.VST', '.WEBP', '.WMF', '.EMF', '.XBM', '.XPM', '.aiff', '.au', '.avr', '.caf', '.flac', '.htk', '.svx', '.mat4', '.mat5', '.mpc2k', '.ogg', '.paf', '.pvf', '.raw', '.rf64', '.sd2', '.sds', '.ircam', '.voc', '.w64', '.wav', '.nist', '.wavex', '.wve', '.xi', '.mp3', '.opus', '.3gp', '.3g2', '.avi', '.asf', '.flv', '.mp4', '.mov', '.m4v', '.mkv', '.webm', '.f4v', '.wmv', '.wma', '.ogm', '.mxf', '.nut', '.AIFF', '.AU', '.AVR', '.CAF', '.FLAC', '.HTK', '.SVX', '.MAT4', '.MAT5', '.MPC2K', '.OGG', '.PAF', '.PVF', '.RAW', '.RF64', '.SD2', '.SDS', '.IRCAM', '.VOC', '.W64', '.WAV', '.NIST', '.WAVEX', '.WVE', '.XI', '.MP3', '.OPUS', '.3GP', '.3G2', '.AVI', '.ASF', '.FLV', '.MP4', '.MOV', '.M4V', '.MKV', '.WEBM', '.F4V', '.WMV', '.WMA', '.OGM', '.MXF', '.NUT', '.pdf', '.PDF', '.nii', '.nii.gz', '.NII', '.NII.GZ', '.zip']"
711
+ ]
712
+ }
713
+ ],
714
  "source": [
715
+ "# Chromosomes split definition\n",
716
+ "chrom_splits = {\n",
717
+ " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n",
718
+ " \"val\": ['chr22'],\n",
719
+ " \"test\": ['chr21']\n",
720
+ "}\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  "\n",
722
+ "# Number of desired samples per split\n",
723
+ "num_samples = {\n",
724
+ " \"train\": config[\"num_steps_training\"] * config[\"batch_size\"],\n",
725
+ " \"val\": config[\"num_validation_samples\"],\n",
726
+ " \"test\": config[\"num_test_samples\"],\n",
727
+ "}\n",
728
  "\n",
729
+ "print(f\"Loading dataset from {config['dataset_name']}...\")\n",
730
+ "dataset = load_dataset(\n",
731
+ " config[\"dataset_name\"],\n",
732
+ " data_files=chrom_splits,\n",
733
+ " num_samples=num_samples,\n",
734
+ " fasta_url=config[\"fasta_url\"],\n",
735
+ " bigwig_urls=config[\"bigwig_url_list\"],\n",
736
+ " sequence_length=config[\"sequence_length\"],\n",
737
+ " data_dir=config[\"data_cache_dir\"],\n",
738
+ ")"
 
 
 
 
 
 
 
 
739
  ]
740
  },
741
  {
 
744
  "metadata": {},
745
  "outputs": [],
746
  "source": [
747
+ "# Tokenization function\n",
748
+ "def tokenize_examples(examples):\n",
749
+ " \"\"\"Tokenize sequences and prepare targets.\"\"\"\n",
750
+ " sequences = examples[\"sequence\"]\n",
751
+ " \n",
752
+ " # Tokenize sequences\n",
753
+ " tokenized = tokenizer(\n",
754
+ " sequences,\n",
755
+ " max_length=config[\"sequence_length\"],\n",
756
+ " padding=\"max_length\",\n",
757
+ " truncation=True,\n",
758
+ " return_tensors=None,\n",
759
+ " )\n",
760
+ " \n",
761
+ " # Crop targets to center fraction if needed\n",
762
+ " if config[\"keep_target_center_fraction\"] < 1.0:\n",
763
+ " seq_len = examples[\"bigwig_targets\"].shape[0]\n",
764
+ " target_offset = int(seq_len * (1 - config[\"keep_target_center_fraction\"]) // 2)\n",
765
+ " target_length = seq_len - 2 * target_offset\n",
766
+ " examples[\"bigwig_targets\"] = examples[\"bigwig_targets\"][target_offset:target_offset + target_length, :]\n",
767
+ " \n",
768
+ " return {\n",
769
+ " \"tokens\": tokenized[\"input_ids\"],\n",
770
+ " \"bigwig_targets\": examples[\"bigwig_targets\"],\n",
771
+ " }\n",
772
  "\n",
773
+ "# Apply tokenization\n",
774
+ "print(\"Tokenizing sequences...\")\n",
775
+ "dataset = dataset.map(\n",
776
+ " tokenize_examples,\n",
777
+ " batched=True,\n",
778
+ " remove_columns=[\"sequence\"], # Remove original sequence after tokenization\n",
779
  ")\n",
780
  "\n",
781
+ "# Format for PyTorch\n",
782
+ "dataset = dataset.with_format(\"torch\")\n",
 
 
 
 
 
783
  "\n",
784
+ "dataloaders = {}\n",
785
+ "for split_name in chrom_splits.keys():\n",
786
+ " dataloaders[split_name] = DataLoader(\n",
787
+ " dataset[split_name],\n",
788
+ " batch_size=config[\"batch_size\"],\n",
789
+ " shuffle=(split_name == \"train\"),\n",
790
+ " num_workers=config[\"num_workers\"],\n",
791
+ " )\n",
792
  "\n",
793
+ "# Extract DataLoaders\n",
794
+ "train_loader = dataloaders[\"train\"]\n",
795
+ "val_loader = dataloaders[\"val\"]\n",
796
+ "test_loader = dataloaders[\"test\"]\n",
 
 
797
  "\n",
798
+ "print(f\"\\nData pipeline created successfully!\")\n",
799
+ "print(f\"Train batches: {len(train_loader)}\")\n",
800
+ "print(f\"Val batches: {len(val_loader)}\")\n",
801
+ "print(f\"Test batches: {len(test_loader)}\")"
802
  ]
803
  },
804
  {
805
  "cell_type": "markdown",
806
  "metadata": {},
807
  "source": [
808
+ "# 4. ⚙️ Optimizer setup\n",
809
  "\n",
810
+ "Configure the AdamW optimizer with learning rate and weight decay hyperparameters. This optimizer will update the model parameters during training to minimize the loss function."
 
811
  ]
812
  },
813
  {
814
  "cell_type": "code",
815
+ "execution_count": 15,
816
  "metadata": {},
817
+ "outputs": [
818
+ {
819
+ "name": "stdout",
820
+ "output_type": "stream",
821
+ "text": [
822
+ "Training configuration:\n",
823
+ " Batch size: 32\n",
824
+ " Total training steps: 19932\n",
825
+ " Log metrics every: 40 steps\n",
826
+ " Validate every: 400 steps\n",
827
+ "\n",
828
+ "Optimizer setup:\n",
829
+ " Learning rate: 1e-05\n"
830
+ ]
831
+ }
832
+ ],
833
  "source": [
834
  "# Training setup\n",
835
  "print(f\"Training configuration:\")\n",
 
853
  "cell_type": "markdown",
854
  "metadata": {},
855
  "source": [
856
+ "# 5. 📊 Metrics setup\n",
857
  "\n",
858
  "Set up evaluation metrics to track model performance during training and validation. We use Pearson correlation coefficients to measure how well the predicted BigWig signals match the ground truth signals."
859
  ]
860
  },
861
  {
862
  "cell_type": "code",
863
+ "execution_count": 16,
864
  "metadata": {},
865
  "outputs": [],
866
  "source": [
 
933
  },
934
  {
935
  "cell_type": "code",
936
+ "execution_count": 17,
937
  "metadata": {},
938
  "outputs": [],
939
  "source": [
 
946
  "cell_type": "markdown",
947
  "metadata": {},
948
  "source": [
949
+ "# 6. 📉 Loss function\n",
950
  "\n",
951
  "Define the Poisson-Multinomial loss function that captures both the scale (total signal) and shape (distribution) of BigWig tracks. This loss is specifically designed for count-based genomic signal data."
952
  ]
953
  },
954
  {
955
  "cell_type": "code",
956
+ "execution_count": 18,
957
  "metadata": {},
958
  "outputs": [],
959
  "source": [
 
1028
  "cell_type": "markdown",
1029
  "metadata": {},
1030
  "source": [
1031
+ "# 7. 🏃 Training loop\n",
1032
  "\n",
1033
  "Run the main training loop that iterates through batches, computes gradients, and updates model parameters. The loop includes periodic validation checks and real-time metric visualization to monitor training progress."
1034
  ]
1035
  },
1036
  {
1037
  "cell_type": "code",
1038
+ "execution_count": 19,
1039
  "metadata": {},
1040
  "outputs": [],
1041
  "source": [
 
1103
  },
1104
  {
1105
  "cell_type": "code",
1106
+ "execution_count": 18,
1107
  "metadata": {},
1108
  "outputs": [],
1109
  "source": [
 
1260
  "cell_type": "markdown",
1261
  "metadata": {},
1262
  "source": [
1263
+ "# 9. 🧪 Test evaluation\n",
1264
  "\n",
1265
  "Evaluate the fine-tuned model on the held-out test set to assess final performance. This provides an unbiased estimate of how well the model generalizes to unseen genomic regions."
1266
  ]
 
1338
  },
1339
  "nbformat": 4,
1340
  "nbformat_minor": 2
1341
+ }