Commit 2b05bdb
Parent(s): 6db01cc
feat: made compat with HF dataset + refactor
notebooks/03_fine_tuning.ipynb CHANGED (+212 -339)
@@ -52,7 +52,7 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -60,8 +60,9 @@
 "import functools\n",
 "from typing import List, Dict, Callable\n",
 "import os\n",
-"import
-"from
+"import fnmatch\n",
+"from pathlib import Path\n",
+"from huggingface_hub import HfApi, snapshot_download\n",
 "\n",
 "import torch\n",
 "import torch.nn as nn\n",
@@ -69,6 +70,8 @@
 "from torch.utils.data import Dataset, DataLoader\n",
 "from torch.optim import AdamW\n",
 "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
+"import pandas as pd\n",
+"import matplotlib.pyplot as plt\n",
 "import numpy as np\n",
 "import pyBigWig\n",
 "from pyfaidx import Fasta\n",
@@ -90,10 +93,9 @@
 "- **`model_name`**: HuggingFace model name/identifier for the pretrained backbone model\n",
 "\n",
 "### Data\n",
+"- **`hf_repo_id`**: HuggingFace dataset repository ID containing the benchmark data\n",
+"- **`species`**: Species name (e.g., \"human\") to select data from the benchmark dataset\n",
 "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
-"- **`fasta_url`**: URL to download reference genome FASTA file\n",
-"- **`bigwig_url_list`**: List of URLs for bigWig track files to download\n",
-"- **`bigwig_file_ids`**: List of identifiers/names for bigWig tracks (set after downloading, used for model head and metrics)\n",
 "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
 "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
 "\n",
@@ -108,6 +110,9 @@
 "- **`validate_every_n_steps`**: Run validation every N steps\n",
 "- **`num_validation_samples`**: Number of samples to use for validation set\n",
 "\n",
+"### Test\n",
+"- **`num_test_samples`**: Number of samples to use for test set evaluation\n",
+"\n",
 "### General\n",
 "- **`seed`**: Random seed for reproducibility\n",
 "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
@@ -116,31 +121,18 @@
 },
 {
 "cell_type": "code",
-"execution_count":
+"execution_count": null,
 "metadata": {},
-"outputs": [
- {
-  "name": "stdout",
-  "output_type": "stream",
-  "text": [
-   "Using device: cpu\n"
-  ]
- }
-],
+"outputs": [],
 "source": [
 "config = {\n",
 "    # Model\n",
 "    \"model_name\": \"InstaDeepAI/NTv3_8M_pre\",\n",
 "    \n",
 "    # Data\n",
+"    \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+"    \"species\": \"arabidopsis\",\n",
 "    \"data_cache_dir\": \"./data\",\n",
-"    \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n",
-"    \"bigwig_url_list\": [\n",
-"        \"https://www.encodeproject.org/files/ENCFF055QKS/@@download/ENCFF055QKS.bigWig\",\n",
-"        \"https://www.encodeproject.org/files/ENCFF214GOQ/@@download/ENCFF214GOQ.bigWig\",\n",
-"        \"https://www.encodeproject.org/files/ENCFF592NIB/@@download/ENCFF592NIB.bigWig\",\n",
-"        \"https://www.encodeproject.org/files/ENCFF921PHQ/@@download/ENCFF921PHQ.bigWig\",\n",
-"    ],\n",
 "    \"sequence_length\": 32_768,\n",
 "    \"keep_target_center_fraction\": 0.375,\n",
 "    \n",
@@ -159,40 +151,11 @@
 "    \"num_test_samples\": 10000,\n",
 "    \n",
 "    # General\n",
-"    \"seed\":
+"    \"seed\": 0,\n",
 "    \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
 "    \"num_workers\": 16,\n",
 "}\n",
 "\n",
-"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
-"\n",
-"# Extract filenames from URLs\n",
-"def extract_filename_from_url(url: str) -> str:\n",
-"    \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n",
-"    # Remove query parameters if present\n",
-"    url_clean = url.split('?')[0]\n",
-"    # Get the last part of the URL path\n",
-"    return url_clean.split('/')[-1]\n",
-"\n",
-"# Create paths for downloaded files\n",
-"fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n",
-"bigwig_path_list = [\n",
-"    os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n",
-"    for url in config[\"bigwig_url_list\"]\n",
-"]\n",
-"\n",
-"\n",
-"# TODO: find a way to link the experiment accession to bigwig file ids\n",
-"# Create bigwig_file_ids from filenames (without extension)\n",
-"config[\"bigwig_file_ids\"] = [\n",
-"    # os.path.splitext(extract_filename_from_url(url))[0]\n",
-"    # for url in config[\"bigwig_url_list\"]\n",
-"    \"ENCSR325NFE\",\n",
-"    \"ENCSR962OTG\",\n",
-"    \"ENCSR619DQO_P\",\n",
-"    \"ENCSR619DQO_M\",\n",
-"]\n",
-"\n",
 "# Set random seed\n",
 "torch.manual_seed(config[\"seed\"])\n",
 "np.random.seed(config[\"seed\"])\n",
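Note on reproducibility: the hunk above seeds torch and numpy, while the refactored dataset further down samples windows with Python's stdlib random (random.choice / random.randint in __getitem__). A minimal sketch of seeding all three generators; the stdlib random.seed call is an assumption here, shown in case the notebook does not seed it outside the visible hunk:

import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # torch: model init and dropout; numpy: any np.random draws;
    # random: the stdlib generator used by the dataset's window sampling.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)  # assumption: not visible in the hunk above

seed_everything(0)  # matches the new "seed": 0 config value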
@@ -217,56 +180,99 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def
-"
-"
-"\n",
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"        download_tasks.append((bigwig_url, filepath))\n",
-"\n",
-"# Download files in parallel\n",
-"max_workers = min(len(download_tasks), 8)\n",
-"\n",
-"print(f\"Downloading {len(download_tasks)} files using {max_workers} workers...\")\n",
-"with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
-"    # Submit all download tasks\n",
-"    future_to_path = {\n",
-"        executor.submit(_download_file, url, path): path\n",
-"        for url, path in download_tasks\n",
-"    }\n",
+"def prepare_genomics_inputs(\n",
+"    species: str,\n",
+"    data_cache_dir: str | Path = \"data\",\n",
+"    hf_repo_id: str = \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+") -> tuple[str, list[str], list[str]]:\n",
+"    \"\"\"\n",
+"    Downloads:\n",
+"    1) FASTA from HF dataset under: <species>/genome.fasta\n",
+"    2) BigWigs from HF dataset under: <species>/functional_tracks/**\n",
+"    3) Splits from HF dataset under: <species>/splits.bed\n",
+"    4) Metadata from HF dataset under: benchmark_metadata.tsv\n",
+"    Returns:\n",
+"        (fasta_path, bigwig_path_list, bigwig_file_ids)\n",
+"    \"\"\"\n",
+"    cache = Path(data_cache_dir).expanduser().resolve()\n",
+"    cache.mkdir(parents=True, exist_ok=True)\n",
 "    \n",
-"    #
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-"
-
-
-
-
-
-
-
+"    # --- Download metadata + <species> files (FASTA, BigWigs, Splits) ---\n",
+"    api = HfApi()\n",
+"    files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
+"    \n",
+"    # Find all files to download: species directory + metadata at root\n",
+"    species_pattern = f\"{species}/**\"\n",
+"    metadata_file = \"benchmark_metadata.tsv\"\n",
+"    \n",
+"    species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
+"    if not species_files:\n",
+"        raise ValueError(f\"No files found matching '{species_pattern}' in '{hf_repo_id}'\")\n",
+"    \n",
+"    if metadata_file not in files:\n",
+"        raise ValueError(f\"No metadata file found at '{metadata_file}' in '{hf_repo_id}'\")\n",
+"    \n",
+"    # Download all needed files\n",
+"    download_patterns = [species_pattern, metadata_file]\n",
+"    local_dir = Path(\n",
+"        snapshot_download(\n",
+"            repo_id=hf_repo_id,\n",
+"            repo_type=\"dataset\",\n",
+"            allow_patterns=download_patterns,\n",
+"            local_dir=str(cache),\n",
+"        )\n",
+"    )\n",
+"    \n",
+"    # --- Organize outputs ---\n",
+"    # FASTA file\n",
+"    fasta_path_repo = f\"{species}/genome.fasta\"\n",
+"    fasta_path = str(local_dir / fasta_path_repo)\n",
+"    if not Path(fasta_path).is_file():\n",
+"        raise ValueError(f\"FASTA file not found at '{fasta_path}'\")\n",
+"    \n",
+"    # BigWig files\n",
+"    bigwig_paths, bigwig_ids = [], []\n",
+"    for repo_path in species_files:\n",
+"        lp = local_dir / repo_path\n",
+"        if lp.is_file() and lp.suffix == \".bigwig\":\n",
+"            bigwig_paths.append(str(lp))\n",
+"            bigwig_ids.append(lp.stem)\n",
+"    if not bigwig_paths:\n",
+"        raise ValueError(f\"Found no BigWig files in '{species_pattern}'\")\n",
+"    \n",
+"    # Splits file\n",
+"    splits_path_repo = f\"{species}/splits.bed\"\n",
+"    splits_path = local_dir / splits_path_repo\n",
+"    if not splits_path.is_file():\n",
+"        raise ValueError(f\"Splits file not found at '{splits_path}'\")\n",
+"    splits_df = pd.read_csv(\n",
+"        splits_path, \n",
+"        sep=\"\\t\", \n",
+"        header=None, \n",
+"        names=[\"chr_name\", \"start\", \"end\", \"split\"],\n",
+"        dtype={\"chr_name\": str, \"start\": int, \"end\": int, \"split\": str},\n",
+"    )\n",
+"    \n",
+"    # Metadata file\n",
+"    metadata_path = local_dir / metadata_file\n",
+"    if not metadata_path.is_file():\n",
+"        raise ValueError(f\"Metadata file not found at '{metadata_path}'\")\n",
+"    metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
+"\n",
+"    if \"species\" not in metadata_df.columns:\n",
+"        raise ValueError(\"benchmark_metadata.tsv has no 'species' column\")\n",
+"\n",
+"    # Filter metadata according to species\n",
+"    metadata_df = metadata_df[metadata_df[\"species\"] == species].reset_index(drop=True)\n",
+"\n",
+"    # Order metadata according to bigwig file ids\n",
+"    metadata_df = (\n",
+"        metadata_df.set_index(\"file_id\")\n",
+"        .loc[bigwig_ids]\n",
+"        .reset_index()\n",
+"    )\n",
+"\n",
+"    return fasta_path, bigwig_paths, bigwig_ids, splits_df, metadata_df"
 ]
 },
 {
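The new prepare_genomics_inputs replaces the old per-URL downloads with the huggingface_hub API. A condensed, standalone sketch of the two calls it is built on: listing the dataset repo's files and fetching only a species subtree via allow_patterns. The repo id and file layout are taken from the diff; network access is assumed:

from fnmatch import fnmatch
from huggingface_hub import HfApi, snapshot_download

repo_id = "InstaDeepAI/NTv3_benchmark_dataset"  # repo id as used in the diff
species = "arabidopsis"

# List every file in the dataset repo, then keep the per-species subtree.
files = HfApi().list_repo_files(repo_id=repo_id, repo_type="dataset")
wanted = [f for f in files if fnmatch(f, f"{species}/**")]

# Fetch only the matching files (plus the root metadata table) into ./data;
# snapshot_download returns the local directory containing the snapshot.
local_dir = snapshot_download(
    repo_id=repo_id,
    repo_type="dataset",
    allow_patterns=[f"{species}/**", "benchmark_metadata.tsv"],
    local_dir="./data",
)
print(local_dir, len(wanted))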
@@ -275,11 +281,20 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"
-"
-"
-"
-"
+"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
+"\n",
+"# Download all species files + load the splits, and metadata\n",
+"(\n",
+"    fasta_path, \n",
+"    bigwig_paths, \n",
+"    bigwig_ids, \n",
+"    species_splits_df,\n",
+"    metadata_df \n",
+") = prepare_genomics_inputs(\n",
+"    config[\"species\"], \n",
+"    config[\"data_cache_dir\"], \n",
+"    config[\"hf_repo_id\"]\n",
+")"
 ]
 },
 {
@@ -335,7 +350,7 @@
 "        self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
 "            model_name, \n",
 "            trust_remote_code=True,\n",
-"            config=self.config
+"            config=self.config,\n",
 "        )\n",
 "        \n",
 "        self.keep_target_center_fraction = keep_target_center_fraction\n",
@@ -351,7 +366,7 @@
 "        \n",
 "    def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
 "        # Forward through backbone\n",
-"        outputs = self.backbone(input_ids=tokens)\n",
+"        outputs = self.backbone(input_ids=tokens, output_hidden_states=True)\n",
 "        embedding = outputs.hidden_states[-1]  # Last hidden state\n",
 "        \n",
 "        # Crop to center fraction\n",
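The one-line change above passes output_hidden_states=True. For most transformers models, hidden states are only returned on request (unless enabled in the model config), so outputs.hidden_states would otherwise be None and indexing [-1] would fail. A minimal sketch of the fixed call, reusing the backbone/tokens names from the diff:

# Sketch only: `backbone` and `tokens` are as constructed earlier in the notebook.
outputs = backbone(input_ids=tokens, output_hidden_states=True)
embedding = outputs.hidden_states[-1]  # (batch, seq_len, hidden_dim), last layer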
@@ -379,14 +394,14 @@
 "# Create model\n",
 "model = HFModelWithHead(\n",
 "    model_name=config[\"model_name\"],\n",
-"    bigwig_track_names=
+"    bigwig_track_names=bigwig_ids,\n",
 "    keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
 ")\n",
 "model = model.to(device)\n",
 "model.train()\n",
 "\n",
 "print(f\"Model loaded: {config['model_name']}\")\n",
-"print(f\"Number of bigwig tracks: {len(
+"print(f\"Number of bigwig tracks: {len(bigwig_ids)}\")\n",
 "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")"
 ]
 },
@@ -426,8 +441,8 @@
 "    Random genomic windows from a reference genome + bigWig signal.\n",
 "\n",
 "    Each sample:\n",
-"    - picks a
-"    - picks a random window of length `sequence_length
+"    - picks a random region from the specified split,\n",
+"    - picks a random window of length `sequence_length` within that region,\n",
 "    - returns (sequence, signal, chrom, start, end).\n",
 "\n",
 "    This dataset is compatible with multi-worker DataLoaders. BigWig files\n",
@@ -438,11 +453,13 @@
 "    ----\n",
 "    fasta_path : str\n",
 "        Path to the reference genome FASTA (e.g. hg38.fna).\n",
-"    bigwig_path_list : str\n",
-"
-"
-"
-"
+"    bigwig_path_list : list[str]\n",
+"        List of paths to bigWig files.\n",
+"    chrom_regions : pd.DataFrame\n",
+"        DataFrame with columns: chr_name, start, end, split.\n",
+"        Contains all genomic regions with their split assignments.\n",
+"    split : str\n",
+"        Split name to filter regions (e.g., \"train\", \"val\", \"test\").\n",
 "    sequence_length : int\n",
 "        Length of each random window (in bp).\n",
 "    num_samples : int\n",
@@ -453,18 +470,14 @@
 "        Function to transform/scaling bigwig targets.\n",
 "    keep_target_center_fraction : float\n",
 "        Fraction of center sequence to keep for target prediction (crops edges to focus on center).\n",
-"    regions : List[tuple[str, int, int]] | None\n",
-"        Optional list of regions as (chromosome, start, end) tuples.\n",
-"        If provided, samples are drawn randomly from within these regions only.\n",
-"        This matches the JAX pipeline approach using BED file splits.\n",
-"        If None, samples from entire chromosomes in `chroms`.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(\n",
 "        self,\n",
 "        fasta_path: str,\n",
 "        bigwig_path_list: list[str],\n",
-"
+"        chrom_regions: pd.DataFrame,\n",
+"        split: str,\n",
 "        sequence_length: int,\n",
 "        num_samples: int,\n",
 "        tokenizer: AutoTokenizer,\n",
@@ -479,43 +492,37 @@
 "        self.sequence_length = sequence_length\n",
 "        self.num_samples = num_samples\n",
 "        self.tokenizer = tokenizer\n",
-"        self.transform_fn = transform_fn
+"        self.transform_fn = transform_fn\n",
 "        self.keep_target_center_fraction = keep_target_center_fraction\n",
-"        self.
+"        self.chrom_regions = chrom_regions\n",
 "\n",
-"        #
-"
-"        bw_handle = _get_bigwig_handle(bigwig_path_list[0])\n",
-"        bw_chrom_lengths = bw_handle.chroms()  # dict: chrom -> length\n",
+"        # Filter regions by split\n",
+"        split_regions = self.chrom_regions[self.chrom_regions[\"split\"] == split].copy()\n",
 "\n",
-"
-"        self.
+"        # Filter valid regions (must be large enough for sequence_length)\n",
+"        self.valid_regions = []\n",
+"        for _, row in split_regions.iterrows():\n",
 "\n",
-"
-"        if
+"            region_length = row.end - row.start\n",
+"            if region_length < self.sequence_length:\n",
 "                continue\n",
+"            \n",
+"            # Store valid region\n",
+"            self.valid_regions.append((row.chr_name, row.start, row.end))\n",
 "\n",
-"
-"
-"            L = min(fa_len, bw_len)\n",
-"\n",
-"            if L > self.sequence_length:\n",
-"                self.valid_chroms.append(c)\n",
-"                self.chrom_lengths[c] = L\n",
-"\n",
-"        if not self.valid_chroms:\n",
-"            raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n",
+"        if not self.valid_regions:\n",
+"            raise ValueError(f\"No valid regions found for split '{split}'\")\n",
 "\n",
 "    def __len__(self):\n",
 "        return self.num_samples\n",
 "\n",
 "    def __getitem__(self, idx):\n",
-"\n",
-"
-"
-"
-"        max_start =
-"        start = random.randint(
+"        # Sample a random region from the valid regions\n",
+"        chrom, region_start, region_end = random.choice(self.valid_regions)\n",
+"        \n",
+"        # Sample a random window within this region\n",
+"        max_start = region_end - self.sequence_length\n",
+"        start = random.randint(region_start, max_start)\n",
 "        end = start + self.sequence_length\n",
 "\n",
 "        # Sequence\n",
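The rewritten __init__/__getitem__ pair samples from BED-split regions instead of whole chromosomes: keep regions at least sequence_length long, pick one uniformly, then pick a start so the window stays inside it. A self-contained sketch of that logic with made-up regions; note random.randint is inclusive on both ends, so end never exceeds region_end:

import random

# Standalone sketch of the new sampling logic, with hypothetical regions.
sequence_length = 8
regions = [("chr1", 0, 100), ("chr2", 50, 60)]  # (chrom, start, end)

# Keep only regions that can hold a full window, as __init__ does.
valid = [(c, s, e) for (c, s, e) in regions if (e - s) >= sequence_length]

# One draw of __getitem__'s window sampling.
chrom, region_start, region_end = random.choice(valid)
max_start = region_end - sequence_length         # randint is inclusive on both ends,
start = random.randint(region_start, max_start)  # so end <= region_end always holds
end = start + sequence_length
print(chrom, start, end)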
@@ -575,133 +582,26 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"
-"
-"
-"    Compute minimal statistics needed for weighted mean computation.\n",
-"    \n",
-"    Args:\n",
-"        track_data: numpy array of track values for a chromosome\n",
-"    \n",
-"    Returns:\n",
-"        Dictionary with statistics: sum, mean, total_count\n",
+"def create_targets_scaling_fn(\n",
+"    metadata_df: pd.DataFrame\n",
+") -> Callable[[torch.Tensor], torch.Tensor]:\n",
 "    \"\"\"\n",
-"
-"    \n",
-"    # Compute statistics\n",
-"    sum_all = np.sum(track_data)\n",
-"    total_count = track_data.size\n",
-"    mean_all = sum_all / total_count if total_count > 0 else 0.0\n",
-"    \n",
-"    return {\n",
-"        \"sum\": sum_all,\n",
-"        \"mean\": mean_all,\n",
-"        \"total_count\": total_count,\n",
-"    }\n",
+"    Build a scaling function based on track means contained in the metadata.\n",
 "\n",
-"\n",
-"def aggregate_file_statistics(chr_stats_list: List[dict]) -> dict:\n",
-"    \"\"\"\n",
-"    Aggregate chromosome-level statistics into file-level statistics.\n",
-"    \n",
 "    Args:\n",
-"
-"    \n",
-"    Returns:\n",
-"        Dictionary with aggregated file-level statistics (only mean)\n",
-"    \"\"\"\n",
-"    # Convert to arrays for easier computation\n",
-"    total_counts = np.array([s[\"total_count\"] for s in chr_stats_list], dtype=np.int64)\n",
-"    means = np.array([s[\"mean\"] for s in chr_stats_list], dtype=np.float32)\n",
-"    sums = np.array([s[\"sum\"] for s in chr_stats_list], dtype=np.float32)\n",
-"    \n",
-"    # Aggregate total count\n",
-"    total_count = np.sum(total_counts)\n",
-"    \n",
-"    # Weighted mean: mean = sum(mean_chr * total_count_chr) / sum(total_count_chr)\n",
-"    mean = np.sum(means * total_counts) / total_count if total_count > 0 else 0.0\n",
-"    \n",
-"    return {\n",
-"        \"total_count\": total_count,\n",
-"        \"sum\": np.sum(sums),\n",
-"        \"mean\": mean,\n",
-"    }\n",
-"\n",
+"        metadata_df: pandas.DataFrame with track means\n",
 "\n",
-"def get_track_means(bigwig_tracks_list: List[pyBigWig.pyBigWig]) -> np.ndarray:\n",
-"    \"\"\"\n",
-"    Get track means for normalization.\n",
-"    Computes statistics per chromosome and aggregates using weighted averaging,\n",
-"    \n",
-"    Args:\n",
-"        bigwig_tracks_list: List of pyBigWig file objects\n",
-"    \n",
-"    Returns:\n",
-"        Array of track means, one per bigwig file\n",
-"    \"\"\"\n",
-"    track_means = []\n",
-"    \n",
-"    for bigwig_track in bigwig_tracks_list:\n",
-"        chrom_lengths = bigwig_track.chroms()\n",
-"        all_chr_stats = []\n",
-"        \n",
-"        # Compute statistics for each chromosome\n",
-"        for chrom_name, chrom_length in chrom_lengths.items():\n",
-"            try:\n",
-"                # Get chromosome data as numpy array\n",
-"                bw_array = np.array(\n",
-"                    bigwig_track.values(chrom_name, 0, chrom_length, numpy=True),\n",
-"                    dtype=np.float32\n",
-"                )\n",
-"                # Replace NaN with 0\n",
-"                bw_array = np.nan_to_num(bw_array, nan=0.0)\n",
-"                \n",
-"                # Compute chromosome-level statistics\n",
-"                chr_stats = compute_chromosome_stats(bw_array)\n",
-"                all_chr_stats.append(chr_stats)\n",
-"            except Exception as e:\n",
-"                # Skip chromosomes that fail to load\n",
-"                print(f\"Warning: Failed to load chromosome {chrom_name}: {e}\")\n",
-"                continue\n",
-"        \n",
-"        if not all_chr_stats:\n",
-"            raise ValueError(f\"No valid chromosomes found for bigwig track\")\n",
-"        \n",
-"        # Aggregate chromosome-level stats into file-level stats\n",
-"        file_stats = aggregate_file_statistics(all_chr_stats)\n",
-"        \n",
-"        # Use the weighted mean for normalization\n",
-"        track_means.append(file_stats[\"mean\"])\n",
-"        \n",
-"    return np.array(track_means, dtype=np.float32)\n",
-"\n",
-"\n",
-"def create_targets_scaling_fn(bigwig_path_list: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n",
-"    \"\"\"\n",
-"    Build a scaling function based on track means computed from bigwig files.\n",
-"    \n",
-"    Opens bigwig files, computes track statistics, and creates a transform function.\n",
-"    The statistics are computed once and reused for all calls to the returned transform function.\n",
-"    \n",
-"    Args:\n",
-"        bigwig_path_list: List of paths to bigwig files\n",
-"    \n",
 "    Returns:\n",
 "        Transform function that scales input tensors\n",
 "    \"\"\"\n",
 "    # Open bigwig files and compute track statistics\n",
-"
-"
-"
-"
-"    ]\n",
-"    track_means = get_track_means(bw_list)\n",
-"    print(f\"Computed track means: {track_means}\")\n",
-"    print(f\"Track means shape: {track_means.shape}\")\n",
-"    \n",
+"    track_means = metadata_df[\"mean\"].to_numpy()\n",
+"    print(f\"Track means: {track_means}\")\n",
+"    print(f\"Number of tracks: {track_means.shape}\")\n",
+"\n",
 "    # Create tensor from computed means\n",
 "    track_means_tensor = torch.tensor(track_means, dtype=torch.float32)\n",
-"
+"\n",
 "    def transform_fn(x: torch.Tensor) -> torch.Tensor:\n",
 "        \"\"\"\n",
 "        x: torch.Tensor, shape (seq_len, num_tracks) or (batch, seq_len, num_tracks)\n",
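create_targets_scaling_fn now takes per-track means straight from benchmark_metadata.tsv instead of scanning every bigWig at startup. A standalone sketch of the returned closure, assuming the scale step divides by the per-track mean (as the variable names suggest) and using an illustrative clip value, since the actual threshold sits outside the visible hunk:

import torch

track_means = torch.tensor([2.0, 0.5, 1.0])  # hypothetical means, one per track

def transform_fn(x: torch.Tensor) -> torch.Tensor:
    # x: (seq_len, num_tracks) or (batch, seq_len, num_tracks); the trailing
    # tracks axis broadcasts against track_means.
    scaled = x / track_means               # assumption: mean normalization
    return torch.clamp(scaled, max=64.0)   # illustrative clip value only

print(transform_fn(torch.ones(4, 3)))  # each track scaled by 1/mean, then clipped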
@@ -717,20 +617,10 @@
 "            scaled,\n",
 "        )\n",
 "        return clipped\n",
-"
+"\n",
 "    return transform_fn"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"# Create scaling function\n",
-"targets_transform_fn = create_targets_scaling_fn(bigwig_path_list)"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -741,25 +631,26 @@
 "create_dataset_fn = functools.partial(\n",
 "    GenomeBigWigDataset,\n",
 "    fasta_path=fasta_path,\n",
-"    bigwig_path_list=
+"    bigwig_path_list=bigwig_paths,\n",
+"    chrom_regions=species_splits_df,\n",
 "    sequence_length=config[\"sequence_length\"],\n",
 "    tokenizer=tokenizer,\n",
-"    transform_fn=
+"    transform_fn=create_targets_scaling_fn(metadata_df),\n",
 "    keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
 ")\n",
 "\n",
 "train_dataset = create_dataset_fn(\n",
-"
+"    split=\"train\",\n",
 "    num_samples=config[\"num_steps_training\"] * config[\"batch_size\"],\n",
 ")\n",
 "\n",
 "val_dataset = create_dataset_fn(\n",
-"
+"    split=\"val\",\n",
 "    num_samples=config[\"num_validation_samples\"],\n",
 ")\n",
 "\n",
 "test_dataset = create_dataset_fn(\n",
-"
+"    split=\"test\",\n",
 "    num_samples=config[\"num_test_samples\"],\n",
 ")\n",
 "\n",
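Since the three datasets differ only in split and num_samples, the diff binds every shared argument once with functools.partial. A toy sketch of that pattern, with a stand-in constructor in place of GenomeBigWigDataset:

import functools

# Hypothetical constructor standing in for GenomeBigWigDataset.
def make_dataset(source: str, split: str, num_samples: int) -> dict:
    return {"source": source, "split": split, "num_samples": num_samples}

# Bind the shared kwargs once; only the per-split arguments vary below.
create_fn = functools.partial(make_dataset, source="genome.fasta")
train = create_fn(split="train", num_samples=100)
val = create_fn(split="val", num_samples=10)
print(train, val)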
@@ -785,7 +676,7 @@
 "    num_workers=config[\"num_workers\"],\n",
 ")\n",
 "\n",
-"print(f\"
+"print(f\"\\nTrain samples: {len(train_dataset)}\")\n",
 "print(f\"Val samples: {len(val_dataset)}\")\n",
 "print(f\"Test samples: {len(test_dataset)}\")"
 ]
@@ -912,9 +803,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"train_metrics = TracksMetrics(
-"val_metrics = TracksMetrics(
-"test_metrics = TracksMetrics(
+"train_metrics = TracksMetrics(bigwig_ids)\n",
+"val_metrics = TracksMetrics(bigwig_ids)\n",
+"test_metrics = TracksMetrics(bigwig_ids)"
 ]
 },
 {
@@ -1098,47 +989,6 @@
 "val_losses = []\n",
 "val_pearson_scores = []\n",
 "\n",
-"# Initialize interactive plots using FigureWidget for real-time updates\n",
-"from plotly.graph_objects import FigureWidget\n",
-"from plotly.subplots import make_subplots\n",
-"\n",
-"# Create base figure with subplots\n",
-"fig_base = make_subplots(\n",
-"    rows=1, cols=2,\n",
-"    subplot_titles=('Loss', 'Mean Pearson Correlation'),\n",
-"    horizontal_spacing=0.15,\n",
-")\n",
-"\n",
-"# Add empty traces for train and val metrics\n",
-"fig_base.add_trace(\n",
-"    go.Scatter(x=[], y=[], mode='lines+markers', name='Train Loss', line=dict(color='blue')),\n",
-"    row=1, col=1\n",
-")\n",
-"fig_base.add_trace(\n",
-"    go.Scatter(x=[], y=[], mode='lines+markers', name='Val Loss', line=dict(color='red')),\n",
-"    row=1, col=1\n",
-")\n",
-"fig_base.add_trace(\n",
-"    go.Scatter(x=[], y=[], mode='lines+markers', name='Train Pearson', line=dict(color='green')),\n",
-"    row=1, col=2\n",
-")\n",
-"fig_base.add_trace(\n",
-"    go.Scatter(x=[], y=[], mode='lines+markers', name='Val Pearson', line=dict(color='orange')),\n",
-"    row=1, col=2\n",
-")\n",
-"\n",
-"fig_base.update_xaxes(title_text=\"Step\", row=1, col=1)\n",
-"fig_base.update_xaxes(title_text=\"Step\", row=1, col=2)\n",
-"fig_base.update_yaxes(title_text=\"Loss\", row=1, col=1)\n",
-"fig_base.update_yaxes(title_text=\"Pearson Correlation\", row=1, col=2)\n",
-"fig_base.update_layout(height=800, width=1600, showlegend=True)\n",
-"\n",
-"# Convert to FigureWidget for interactive updates\n",
-"fig = FigureWidget(fig_base)\n",
-"\n",
-"# Display initial plot (will update in place during training)\n",
-"display(fig)\n",
-"\n",
 "# Create iterator for training data (will cycle if needed)\n",
 "train_iter = iter(train_loader)\n",
 "\n",
@@ -1183,11 +1033,6 @@
 "        train_losses.append(mean_loss)\n",
 "        train_pearson_scores.append(train_metrics_dict['mean/pearson'])\n",
 "        \n",
-"        # Update plots - direct assignment to FigureWidget data updates the plot automatically\n",
-"        fig.data[0].x = train_steps\n",
-"        fig.data[0].y = train_losses\n",
-"        fig.data[2].x = train_steps\n",
-"        fig.data[2].y = train_pearson_scores\n",
 "        \n",
 "        print(\n",
 "            f\"Step {step_idx + 1}/{config['num_steps_training']} | \"\n",
@@ -1215,11 +1060,6 @@
 "            val_losses.append(val_metrics_dict['loss'])\n",
 "            val_pearson_scores.append(val_pearson_mean)\n",
 "            \n",
-"            # Update plots with validation data - direct assignment updates the plot automatically\n",
-"            fig.data[1].x = val_steps\n",
-"            fig.data[1].y = val_losses\n",
-"            fig.data[3].x = val_steps\n",
-"            fig.data[3].y = val_pearson_scores\n",
 "            \n",
 "            print(f\"  Validation Loss: {val_metrics_dict['loss']:.4f}\")\n",
 "            print(f\"  Validation Mean Pearson: {val_pearson_mean:.4f}\")\n",
@@ -1228,7 +1068,40 @@
 "    \n",
 "    model.train()  # Back to training mode\n",
 "\n",
-"print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")"
+"print(f\"\\nTraining completed after {config['num_steps_training']} steps.\")\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# Plot training results\n",
+"fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
+"\n",
+"# Plot Loss\n",
+"axes[0].plot(train_steps, train_losses, 'b-o', label='Train Loss', markersize=4, linewidth=1.5)\n",
+"if val_steps:\n",
+"    axes[0].plot(val_steps, val_losses, 'r-s', label='Val Loss', markersize=4, linewidth=1.5)\n",
+"axes[0].set_xlabel('Step')\n",
+"axes[0].set_ylabel('Loss')\n",
+"axes[0].set_title('Loss')\n",
+"axes[0].legend()\n",
+"axes[0].grid(True, alpha=0.3)\n",
+"\n",
+"# Plot Pearson Correlation\n",
+"axes[1].plot(train_steps, train_pearson_scores, 'g-o', label='Train Pearson', markersize=4, linewidth=1.5)\n",
+"if val_steps:\n",
+"    axes[1].plot(val_steps, val_pearson_scores, 'orange', marker='s', label='Val Pearson', markersize=4, linewidth=1.5)\n",
+"axes[1].set_xlabel('Step')\n",
+"axes[1].set_ylabel('Pearson Correlation')\n",
+"axes[1].set_title('Mean Pearson Correlation')\n",
+"axes[1].legend()\n",
+"axes[1].grid(True, alpha=0.3)\n",
+"\n",
+"plt.tight_layout()\n",
+"plt.show()\n"
 ]
 },
 {