Spaces:
Running
Running
Commit
·
2f2efdf
1
Parent(s):
35ab8fa
fix: corrected paths for restricted list of bigwigs
Browse files- notebooks/03_fine_tuning.ipynb +98 -31
notebooks/03_fine_tuning.ipynb
CHANGED
|
@@ -139,10 +139,13 @@
|
|
| 139 |
" \"data_cache_dir\": \"./data\",\n",
|
| 140 |
" \"sequence_length\": 32_768,\n",
|
| 141 |
" \"keep_target_center_fraction\": 0.375,\n",
|
|
|
|
|
|
|
|
|
|
| 142 |
" \n",
|
| 143 |
" # Training\n",
|
| 144 |
" \"batch_size\": 12,\n",
|
| 145 |
-
" \"num_steps_training\":
|
| 146 |
" \"log_every_n_steps\": 20,\n",
|
| 147 |
" \"learning_rate\": 1e-5,\n",
|
| 148 |
" \"weight_decay\": 0.01,\n",
|
|
@@ -196,13 +199,23 @@
|
|
| 196 |
" species: str,\n",
|
| 197 |
" data_cache_dir: str | Path = \"data\",\n",
|
| 198 |
" hf_repo_id: str = \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
|
|
|
|
| 199 |
") -> tuple[str, list[str], list[str]]:\n",
|
| 200 |
" \"\"\"\n",
|
| 201 |
" Downloads:\n",
|
| 202 |
" 1) FASTA from HF dataset under: <species>/genome.fasta\n",
|
| 203 |
" 2) BigWigs from HF dataset under: <species>/functional_tracks/**\n",
|
|
|
|
| 204 |
" 3) Splits from HF dataset under: <species>/splits.bed\n",
|
| 205 |
" 4) Metadata from HF dataset under: benchmark_metadata.tsv\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
" Returns:\n",
|
| 207 |
" (fasta_path, bigwig_path_list, bigwig_file_ids)\n",
|
| 208 |
" \"\"\"\n",
|
|
@@ -210,16 +223,36 @@
|
|
| 210 |
" cache.mkdir(parents=True, exist_ok=True)\n",
|
| 211 |
" \n",
|
| 212 |
" # --- Download metadata + <species> files (FASTA, BigWigs, Splits) ---\n",
|
| 213 |
-
" api = HfApi()\n",
|
| 214 |
-
" files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
|
| 215 |
-
" \n",
|
| 216 |
-
" # Find all files to download: species directory + metadata at root\n",
|
| 217 |
-
" species_pattern = f\"{species}/**\"\n",
|
| 218 |
" metadata_file = \"benchmark_metadata.tsv\"\n",
|
| 219 |
-
"
|
| 220 |
" \n",
|
| 221 |
-
"
|
| 222 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
" local_dir = Path(\n",
|
| 224 |
" snapshot_download(\n",
|
| 225 |
" repo_id=hf_repo_id,\n",
|
|
@@ -236,15 +269,16 @@
|
|
| 236 |
" if not Path(fasta_path).is_file():\n",
|
| 237 |
" raise ValueError(f\"FASTA file not found at '{fasta_path}'\")\n",
|
| 238 |
" \n",
|
| 239 |
-
" # BigWig files\n",
|
| 240 |
-
"
|
| 241 |
-
"
|
| 242 |
-
"
|
| 243 |
-
"
|
| 244 |
-
"
|
| 245 |
-
"
|
| 246 |
-
"
|
| 247 |
-
"
|
|
|
|
| 248 |
" \n",
|
| 249 |
" # Splits file\n",
|
| 250 |
" splits_path_repo = f\"{species}/splits.bed\"\n",
|
|
@@ -284,7 +318,7 @@
|
|
| 284 |
"source": [
|
| 285 |
"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
|
| 286 |
"\n",
|
| 287 |
-
"# Download all species files + load the splits, and metadata\n",
|
| 288 |
"(\n",
|
| 289 |
" fasta_path, \n",
|
| 290 |
" bigwig_paths, \n",
|
|
@@ -294,7 +328,8 @@
|
|
| 294 |
") = prepare_genomics_inputs(\n",
|
| 295 |
" config[\"species\"], \n",
|
| 296 |
" config[\"data_cache_dir\"], \n",
|
| 297 |
-
" config[\"hf_repo_id\"]
|
|
|
|
| 298 |
")"
|
| 299 |
]
|
| 300 |
},
|
|
@@ -348,11 +383,12 @@
|
|
| 348 |
" \n",
|
| 349 |
" # Load config and model\n",
|
| 350 |
" self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
|
| 351 |
-
"
|
| 352 |
" model_name, \n",
|
| 353 |
" trust_remote_code=True,\n",
|
| 354 |
" config=self.config,\n",
|
| 355 |
" )\n",
|
|
|
|
| 356 |
" \n",
|
| 357 |
" self.keep_target_center_fraction = keep_target_center_fraction\n",
|
| 358 |
"\n",
|
|
@@ -428,18 +464,48 @@
|
|
| 428 |
"metadata": {},
|
| 429 |
"outputs": [],
|
| 430 |
"source": [
|
| 431 |
-
"# Process-local cache for
|
| 432 |
"# This allows safe multi-worker DataLoader usage\n",
|
|
|
|
| 433 |
"_bigwig_cache = {} # Maps (process_id, file_path) -> pyBigWig handle\n",
|
| 434 |
"\n",
|
| 435 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
"def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
|
| 437 |
" \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
|
| 438 |
" process_id = os.getpid()\n",
|
| 439 |
-
"
|
|
|
|
| 440 |
" \n",
|
| 441 |
" if cache_key not in _bigwig_cache:\n",
|
| 442 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
" \n",
|
| 444 |
" return _bigwig_cache[cache_key]\n",
|
| 445 |
"\n",
|
|
@@ -494,8 +560,8 @@
|
|
| 494 |
" ):\n",
|
| 495 |
" super().__init__()\n",
|
| 496 |
"\n",
|
| 497 |
-
" self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
|
| 498 |
" # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
|
|
|
|
| 499 |
" self.bigwig_path_list = bigwig_path_list\n",
|
| 500 |
" self.sequence_length = sequence_length\n",
|
| 501 |
" self.num_samples = num_samples\n",
|
|
@@ -533,8 +599,9 @@
|
|
| 533 |
" start = random.randint(region_start, max_start)\n",
|
| 534 |
" end = start + self.sequence_length\n",
|
| 535 |
"\n",
|
| 536 |
-
" # Sequence\n",
|
| 537 |
-
"
|
|
|
|
| 538 |
" # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
|
| 539 |
" tokenized = self.tokenizer(\n",
|
| 540 |
" seq,\n",
|
|
@@ -788,7 +855,8 @@
|
|
| 788 |
" metrics_dict = {}\n",
|
| 789 |
" \n",
|
| 790 |
" # Compute Pearson correlation per track\n",
|
| 791 |
-
"
|
|
|
|
| 792 |
" for i, track_name in enumerate(self.track_names):\n",
|
| 793 |
" metrics_dict[f\"{track_name}/pearson\"] = correlations[i]\n",
|
| 794 |
" \n",
|
|
@@ -931,7 +999,6 @@
|
|
| 931 |
" loss.backward()\n",
|
| 932 |
" return loss.item()\n",
|
| 933 |
"\n",
|
| 934 |
-
"\n",
|
| 935 |
"def validation_step(\n",
|
| 936 |
" model: nn.Module,\n",
|
| 937 |
" batch: Dict[str, torch.Tensor],\n",
|
|
@@ -1149,10 +1216,10 @@
|
|
| 1149 |
"source": [
|
| 1150 |
" ## Test set results\n",
|
| 1151 |
"\n",
|
| 1152 |
-
"Performance reached at ~1.5B tokens (~1500 steps in the current 32kb-sequence setup with batch_size=32)\n",
|
| 1153 |
-
"\n",
|
| 1154 |
"**Hardware configuration**: These results were obtained on an **H100 GPU with 16 workers** for data loading in approximately **10 minutes** of training.\n",
|
| 1155 |
"\n",
|
|
|
|
|
|
|
| 1156 |
"Mean Pearson: 0.5835\n",
|
| 1157 |
"- ENCSR325NFE/pearson: 0.6081\n",
|
| 1158 |
"- ENCSR962OTG/pearson: 0.7286\n",
|
|
|
|
| 139 |
" \"data_cache_dir\": \"./data\",\n",
|
| 140 |
" \"sequence_length\": 32_768,\n",
|
| 141 |
" \"keep_target_center_fraction\": 0.375,\n",
|
| 142 |
+
" \"bigwig_file_ids\": [\n",
|
| 143 |
+
" \"ENCSR325NFE\", \"ENCSR962OTG\", \"ENCSR619DQO_P\", \"ENCSR619DQO_M\"\n",
|
| 144 |
+
" ], # If None, will use all available tracks for selected species\n",
|
| 145 |
" \n",
|
| 146 |
" # Training\n",
|
| 147 |
" \"batch_size\": 12,\n",
|
| 148 |
+
" \"num_steps_training\": 2000, # Consider increasing to improve training performance\n",
|
| 149 |
" \"log_every_n_steps\": 20,\n",
|
| 150 |
" \"learning_rate\": 1e-5,\n",
|
| 151 |
" \"weight_decay\": 0.01,\n",
|
|
|
|
| 199 |
" species: str,\n",
|
| 200 |
" data_cache_dir: str | Path = \"data\",\n",
|
| 201 |
" hf_repo_id: str = \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
|
| 202 |
+
" bigwig_file_ids: list[str] | None = None,\n",
|
| 203 |
") -> tuple[str, list[str], list[str]]:\n",
|
| 204 |
" \"\"\"\n",
|
| 205 |
" Downloads:\n",
|
| 206 |
" 1) FASTA from HF dataset under: <species>/genome.fasta\n",
|
| 207 |
" 2) BigWigs from HF dataset under: <species>/functional_tracks/**\n",
|
| 208 |
+
" (filtered by bigwig_file_ids if provided)\n",
|
| 209 |
" 3) Splits from HF dataset under: <species>/splits.bed\n",
|
| 210 |
" 4) Metadata from HF dataset under: benchmark_metadata.tsv\n",
|
| 211 |
+
" \n",
|
| 212 |
+
" Args:\n",
|
| 213 |
+
" species: Species name (e.g., \"human\", \"arabidopsis\")\n",
|
| 214 |
+
" data_cache_dir: Directory where downloaded data files will be stored\n",
|
| 215 |
+
" hf_repo_id: HuggingFace dataset repository ID\n",
|
| 216 |
+
" bigwig_file_ids: Optional list of BigWig file IDs to download. If None,\n",
|
| 217 |
+
" downloads all available BigWig files for the species.\n",
|
| 218 |
+
" \n",
|
| 219 |
" Returns:\n",
|
| 220 |
" (fasta_path, bigwig_path_list, bigwig_file_ids)\n",
|
| 221 |
" \"\"\"\n",
|
|
|
|
| 223 |
" cache.mkdir(parents=True, exist_ok=True)\n",
|
| 224 |
" \n",
|
| 225 |
" # --- Download metadata + <species> files (FASTA, BigWigs, Splits) ---\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
" metadata_file = \"benchmark_metadata.tsv\"\n",
|
| 227 |
+
" download_patterns = [metadata_file, f\"{species}/genome.fasta\", f\"{species}/splits.bed\"]\n",
|
| 228 |
" \n",
|
| 229 |
+
" if bigwig_file_ids is not None:\n",
|
| 230 |
+
" # List files to validate requested BigWig files exist\n",
|
| 231 |
+
" api = HfApi()\n",
|
| 232 |
+
" files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
|
| 233 |
+
" species_pattern = f\"{species}/**\"\n",
|
| 234 |
+
" species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
|
| 235 |
+
" \n",
|
| 236 |
+
" # Get all available BigWig file IDs and their paths\n",
|
| 237 |
+
" available_bigwig_files = {\n",
|
| 238 |
+
" Path(p).stem: p for p in species_files \n",
|
| 239 |
+
" if Path(p).suffix == \".bigwig\"\n",
|
| 240 |
+
" }\n",
|
| 241 |
+
" \n",
|
| 242 |
+
" # Check that all requested files exist\n",
|
| 243 |
+
" missing_files = set(bigwig_file_ids) - set(available_bigwig_files.keys())\n",
|
| 244 |
+
" if missing_files:\n",
|
| 245 |
+
" raise ValueError(\n",
|
| 246 |
+
" f\"Requested BigWig files not found: {missing_files}. \"\n",
|
| 247 |
+
" f\"Available files: {list(available_bigwig_files.keys())}\"\n",
|
| 248 |
+
" )\n",
|
| 249 |
+
" \n",
|
| 250 |
+
" # Add specific patterns for requested BigWig files only\n",
|
| 251 |
+
" for file_id in bigwig_file_ids:\n",
|
| 252 |
+
" download_patterns.append(available_bigwig_files[file_id])\n",
|
| 253 |
+
" else:\n",
|
| 254 |
+
" # Download all BigWig files\n",
|
| 255 |
+
" download_patterns.append(f\"{species}/functional_tracks/*.bigwig\")\n",
|
| 256 |
" local_dir = Path(\n",
|
| 257 |
" snapshot_download(\n",
|
| 258 |
" repo_id=hf_repo_id,\n",
|
|
|
|
| 269 |
" if not Path(fasta_path).is_file():\n",
|
| 270 |
" raise ValueError(f\"FASTA file not found at '{fasta_path}'\")\n",
|
| 271 |
" \n",
|
| 272 |
+
" # BigWig files - use downloaded files directly\n",
|
| 273 |
+
" bigwig_dir = local_dir / species / \"functional_tracks\"\n",
|
| 274 |
+
" \n",
|
| 275 |
+
" if bigwig_file_ids is not None:\n",
|
| 276 |
+
" bigwig_paths = [bigwig_dir / f\"{file_id}.bigwig\" for file_id in bigwig_file_ids]\n",
|
| 277 |
+
" bigwig_ids = bigwig_file_ids\n",
|
| 278 |
+
" else:\n",
|
| 279 |
+
" # Find all downloaded BigWig files\n",
|
| 280 |
+
" bigwig_paths = [bigwig_file for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
|
| 281 |
+
" bigwig_ids = [bigwig_file.stem for bigwig_file in bigwig_dir.glob(\"*.bigwig\")] \n",
|
| 282 |
" \n",
|
| 283 |
" # Splits file\n",
|
| 284 |
" splits_path_repo = f\"{species}/splits.bed\"\n",
|
|
|
|
| 318 |
"source": [
|
| 319 |
"os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
|
| 320 |
"\n",
|
| 321 |
+
"# Download all requested species-related files and load the splits and metadata\n",
|
| 322 |
"(\n",
|
| 323 |
" fasta_path, \n",
|
| 324 |
" bigwig_paths, \n",
|
|
|
|
| 328 |
") = prepare_genomics_inputs(\n",
|
| 329 |
" config[\"species\"], \n",
|
| 330 |
" config[\"data_cache_dir\"], \n",
|
| 331 |
+
" config[\"hf_repo_id\"],\n",
|
| 332 |
+
" bigwig_file_ids=config[\"bigwig_file_ids\"]\n",
|
| 333 |
")"
|
| 334 |
]
|
| 335 |
},
|
|
|
|
| 383 |
" \n",
|
| 384 |
" # Load config and model\n",
|
| 385 |
" self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
|
| 386 |
+
" backbone = AutoModelForMaskedLM.from_pretrained(\n",
|
| 387 |
" model_name, \n",
|
| 388 |
" trust_remote_code=True,\n",
|
| 389 |
" config=self.config,\n",
|
| 390 |
" )\n",
|
| 391 |
+
" self.backbone = torch.compile(backbone)\n",
|
| 392 |
" \n",
|
| 393 |
" self.keep_target_center_fraction = keep_target_center_fraction\n",
|
| 394 |
"\n",
|
|
|
|
| 464 |
"metadata": {},
|
| 465 |
"outputs": [],
|
| 466 |
"source": [
|
| 467 |
+
"# Process-local cache for file handles (one per worker process)\n",
|
| 468 |
"# This allows safe multi-worker DataLoader usage\n",
|
| 469 |
+
"_fasta_cache = {} # Maps (process_id, file_path) -> Fasta handle\n",
|
| 470 |
"_bigwig_cache = {} # Maps (process_id, file_path) -> pyBigWig handle\n",
|
| 471 |
"\n",
|
| 472 |
"\n",
|
| 473 |
+
"def _get_fasta_handle(fasta_path: str) -> Fasta:\n",
|
| 474 |
+
" \"\"\"Get or create a FASTA file handle for the current process.\"\"\"\n",
|
| 475 |
+
" process_id = os.getpid()\n",
|
| 476 |
+
" abs_path = str(Path(fasta_path).resolve())\n",
|
| 477 |
+
" cache_key = (process_id, abs_path)\n",
|
| 478 |
+
" \n",
|
| 479 |
+
" if cache_key not in _fasta_cache:\n",
|
| 480 |
+
" _fasta_cache[cache_key] = Fasta(abs_path, as_raw=True, sequence_always_upper=True)\n",
|
| 481 |
+
" \n",
|
| 482 |
+
" return _fasta_cache[cache_key]\n",
|
| 483 |
+
"\n",
|
| 484 |
+
"\n",
|
| 485 |
"def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
|
| 486 |
" \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
|
| 487 |
" process_id = os.getpid()\n",
|
| 488 |
+
" abs_path = str(Path(bigwig_path).resolve())\n",
|
| 489 |
+
" cache_key = (process_id, abs_path)\n",
|
| 490 |
" \n",
|
| 491 |
" if cache_key not in _bigwig_cache:\n",
|
| 492 |
+
" # Check if file exists before trying to open\n",
|
| 493 |
+
" if not Path(abs_path).exists():\n",
|
| 494 |
+
" raise FileNotFoundError(\n",
|
| 495 |
+
" f\"BigWig file not found: {abs_path}\\n\"\n",
|
| 496 |
+
" f\"Original path: {bigwig_path}\\n\"\n",
|
| 497 |
+
" f\"Current working directory: {os.getcwd()}\"\n",
|
| 498 |
+
" )\n",
|
| 499 |
+
" \n",
|
| 500 |
+
" try:\n",
|
| 501 |
+
" _bigwig_cache[cache_key] = pyBigWig.open(abs_path)\n",
|
| 502 |
+
" except Exception as e:\n",
|
| 503 |
+
" raise RuntimeError(\n",
|
| 504 |
+
" f\"Failed to open BigWig file: {abs_path}\\n\"\n",
|
| 505 |
+
" f\"Error: {str(e)}\\n\"\n",
|
| 506 |
+
" f\"File exists: {Path(abs_path).exists()}\\n\"\n",
|
| 507 |
+
" f\"File size: {Path(abs_path).stat().st_size if Path(abs_path).exists() else 'N/A'} bytes\"\n",
|
| 508 |
+
" ) from e\n",
|
| 509 |
" \n",
|
| 510 |
" return _bigwig_cache[cache_key]\n",
|
| 511 |
"\n",
|
|
|
|
| 560 |
" ):\n",
|
| 561 |
" super().__init__()\n",
|
| 562 |
"\n",
|
|
|
|
| 563 |
" # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
|
| 564 |
+
" self.fasta_path = fasta_path\n",
|
| 565 |
" self.bigwig_path_list = bigwig_path_list\n",
|
| 566 |
" self.sequence_length = sequence_length\n",
|
| 567 |
" self.num_samples = num_samples\n",
|
|
|
|
| 599 |
" start = random.randint(region_start, max_start)\n",
|
| 600 |
" end = start + self.sequence_length\n",
|
| 601 |
"\n",
|
| 602 |
+
" # Sequence - get FASTA handle lazily (cached per worker process)\n",
|
| 603 |
+
" fasta = _get_fasta_handle(self.fasta_path)\n",
|
| 604 |
+
" seq = fasta[chrom][start:end] # string slice\n",
|
| 605 |
" # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
|
| 606 |
" tokenized = self.tokenizer(\n",
|
| 607 |
" seq,\n",
|
|
|
|
| 855 |
" metrics_dict = {}\n",
|
| 856 |
" \n",
|
| 857 |
" # Compute Pearson correlation per track\n",
|
| 858 |
+
" # Move to CPU before converting to numpy\n",
|
| 859 |
+
" correlations = self.pearson_metric.compute().cpu().numpy()\n",
|
| 860 |
" for i, track_name in enumerate(self.track_names):\n",
|
| 861 |
" metrics_dict[f\"{track_name}/pearson\"] = correlations[i]\n",
|
| 862 |
" \n",
|
|
|
|
| 999 |
" loss.backward()\n",
|
| 1000 |
" return loss.item()\n",
|
| 1001 |
"\n",
|
|
|
|
| 1002 |
"def validation_step(\n",
|
| 1003 |
" model: nn.Module,\n",
|
| 1004 |
" batch: Dict[str, torch.Tensor],\n",
|
|
|
|
| 1216 |
"source": [
|
| 1217 |
" ## Test set results\n",
|
| 1218 |
"\n",
|
|
|
|
|
|
|
| 1219 |
"**Hardware configuration**: These results were obtained on an **H100 GPU with 16 workers** for data loading in approximately **10 minutes** of training.\n",
|
| 1220 |
"\n",
|
| 1221 |
+
"Performance reached at ~1.5B tokens (~1500 steps in the current 32kb-sequence setup with batch_size=32)\n",
|
| 1222 |
+
"\n",
|
| 1223 |
"Mean Pearson: 0.5835\n",
|
| 1224 |
"- ENCSR325NFE/pearson: 0.6081\n",
|
| 1225 |
"- ENCSR962OTG/pearson: 0.7286\n",
|