ybornachot committed
Commit 2f2efdf · 1 Parent(s): 35ab8fa

fix: corrected paths for restricted list of BigWigs
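
In short: prepare_genomics_inputs gains a bigwig_file_ids argument so the notebook can download and train on a restricted subset of functional tracks instead of every BigWig available for the species. A minimal sketch of the intended call, assuming the three-tuple return described in the docstring (the species value and the shortened track list are illustrative):

    # Hypothetical invocation mirroring the new config keys in the diff below
    fasta_path, bigwig_paths, bigwig_ids = prepare_genomics_inputs(
        species="human",
        data_cache_dir="./data",
        hf_repo_id="InstaDeepAI/NTv3_benchmark_dataset",
        bigwig_file_ids=["ENCSR325NFE", "ENCSR962OTG"],  # None -> all tracks
    )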

Files changed (1)
  1. notebooks/03_fine_tuning.ipynb +98 -31
notebooks/03_fine_tuning.ipynb CHANGED
@@ -139,10 +139,13 @@
     "    \"data_cache_dir\": \"./data\",\n",
     "    \"sequence_length\": 32_768,\n",
     "    \"keep_target_center_fraction\": 0.375,\n",
+    "    \"bigwig_file_ids\": [\n",
+    "        \"ENCSR325NFE\", \"ENCSR962OTG\", \"ENCSR619DQO_P\", \"ENCSR619DQO_M\"\n",
+    "    ],  # If None, will use all available tracks for the selected species\n",
     "    \n",
     "    # Training\n",
     "    \"batch_size\": 12,\n",
-    "    \"num_steps_training\": 5315,  # reproduce 10% of benchmark training length\n",
+    "    \"num_steps_training\": 2000,  # Consider increasing to improve training performance\n",
     "    \"log_every_n_steps\": 20,\n",
     "    \"learning_rate\": 1e-5,\n",
     "    \"weight_decay\": 0.01,\n",
@@ -196,13 +199,23 @@
     "    species: str,\n",
     "    data_cache_dir: str | Path = \"data\",\n",
     "    hf_repo_id: str = \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+    "    bigwig_file_ids: list[str] | None = None,\n",
     ") -> tuple[str, list[str], list[str]]:\n",
     "    \"\"\"\n",
     "    Downloads:\n",
     "    1) FASTA from HF dataset under: <species>/genome.fasta\n",
     "    2) BigWigs from HF dataset under: <species>/functional_tracks/**\n",
+    "       (filtered by bigwig_file_ids if provided)\n",
     "    3) Splits from HF dataset under: <species>/splits.bed\n",
     "    4) Metadata from HF dataset under: benchmark_metadata.tsv\n",
+    "    \n",
+    "    Args:\n",
+    "        species: Species name (e.g., \"human\", \"arabidopsis\")\n",
+    "        data_cache_dir: Directory where downloaded data files will be stored\n",
+    "        hf_repo_id: HuggingFace dataset repository ID\n",
+    "        bigwig_file_ids: Optional list of BigWig file IDs to download. If None,\n",
+    "            downloads all available BigWig files for the species.\n",
+    "    \n",
     "    Returns:\n",
     "        (fasta_path, bigwig_path_list, bigwig_file_ids)\n",
     "    \"\"\"\n",
@@ -210,16 +223,36 @@
     "    cache.mkdir(parents=True, exist_ok=True)\n",
     "    \n",
     "    # --- Download metadata + <species> files (FASTA, BigWigs, Splits) ---\n",
-    "    api = HfApi()\n",
-    "    files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
-    "    \n",
-    "    # Find all files to download: species directory + metadata at root\n",
-    "    species_pattern = f\"{species}/**\"\n",
     "    metadata_file = \"benchmark_metadata.tsv\"\n",
-    "    species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
+    "    download_patterns = [metadata_file, f\"{species}/genome.fasta\", f\"{species}/splits.bed\"]\n",
     "    \n",
-    "    # Download all needed files\n",
-    "    download_patterns = [species_pattern, metadata_file]\n",
+    "    if bigwig_file_ids is not None:\n",
+    "        # List repo files to validate that the requested BigWig files exist\n",
+    "        api = HfApi()\n",
+    "        files = api.list_repo_files(repo_id=hf_repo_id, repo_type=\"dataset\")\n",
+    "        species_pattern = f\"{species}/**\"\n",
+    "        species_files = [p for p in files if fnmatch.fnmatch(p, species_pattern)]\n",
+    "        \n",
+    "        # Get all available BigWig file IDs and their paths\n",
+    "        available_bigwig_files = {\n",
+    "            Path(p).stem: p for p in species_files\n",
+    "            if Path(p).suffix == \".bigwig\"\n",
+    "        }\n",
+    "        \n",
+    "        # Check that all requested files exist\n",
+    "        missing_files = set(bigwig_file_ids) - set(available_bigwig_files.keys())\n",
+    "        if missing_files:\n",
+    "            raise ValueError(\n",
+    "                f\"Requested BigWig files not found: {missing_files}. \"\n",
+    "                f\"Available files: {list(available_bigwig_files.keys())}\"\n",
+    "            )\n",
+    "        \n",
+    "        # Add patterns for the requested BigWig files only\n",
+    "        for file_id in bigwig_file_ids:\n",
+    "            download_patterns.append(available_bigwig_files[file_id])\n",
+    "    else:\n",
+    "        # Download all BigWig files\n",
+    "        download_patterns.append(f\"{species}/functional_tracks/*.bigwig\")\n",
     "    local_dir = Path(\n",
     "        snapshot_download(\n",
     "            repo_id=hf_repo_id,\n",
@@ -236,15 +269,16 @@
     "    if not Path(fasta_path).is_file():\n",
     "        raise ValueError(f\"FASTA file not found at '{fasta_path}'\")\n",
     "    \n",
-    "    # BigWig files\n",
-    "    bigwig_paths, bigwig_ids = [], []\n",
-    "    for repo_path in species_files:\n",
-    "        lp = local_dir / repo_path\n",
-    "        if lp.is_file() and lp.suffix == \".bigwig\":\n",
-    "            bigwig_paths.append(str(lp))\n",
-    "            bigwig_ids.append(lp.stem)\n",
-    "    if not bigwig_paths:\n",
-    "        raise ValueError(f\"Found no BigWig files in '{species_pattern}'\")\n",
+    "    # BigWig files - use downloaded files directly\n",
+    "    bigwig_dir = local_dir / species / \"functional_tracks\"\n",
+    "    \n",
+    "    if bigwig_file_ids is not None:\n",
+    "        bigwig_paths = [bigwig_dir / f\"{file_id}.bigwig\" for file_id in bigwig_file_ids]\n",
+    "        bigwig_ids = bigwig_file_ids\n",
+    "    else:\n",
+    "        # Find all downloaded BigWig files\n",
+    "        bigwig_paths = [bigwig_file for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
+    "        bigwig_ids = [bigwig_file.stem for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
     "    \n",
     "    # Splits file\n",
     "    splits_path_repo = f\"{species}/splits.bed\"\n",
@@ -284,7 +318,7 @@
    "source": [
     "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n",
     "\n",
-    "# Download all species files + load the splits, and metadata\n",
+    "# Download all requested species-related files + load the splits and metadata\n",
     "(\n",
     "    fasta_path, \n",
     "    bigwig_paths, \n",
@@ -294,7 +328,8 @@
     ") = prepare_genomics_inputs(\n",
     "    config[\"species\"], \n",
     "    config[\"data_cache_dir\"], \n",
-    "    config[\"hf_repo_id\"]\n",
+    "    config[\"hf_repo_id\"],\n",
+    "    bigwig_file_ids=config[\"bigwig_file_ids\"]\n",
     ")"
    ]
   },
@@ -348,11 +383,12 @@
     "        \n",
     "        # Load config and model\n",
     "        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
-    "        self.backbone = AutoModelForMaskedLM.from_pretrained(\n",
+    "        backbone = AutoModelForMaskedLM.from_pretrained(\n",
     "            model_name, \n",
     "            trust_remote_code=True,\n",
     "            config=self.config,\n",
     "        )\n",
+    "        self.backbone = torch.compile(backbone)\n",
     "        \n",
     "        self.keep_target_center_fraction = keep_target_center_fraction\n",
     "\n",
@@ -428,18 +464,48 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Process-local cache for BigWig file handles (one per worker process)\n",
+    "# Process-local cache for file handles (one per worker process)\n",
     "# This allows safe multi-worker DataLoader usage\n",
+    "_fasta_cache = {}  # Maps (process_id, file_path) -> Fasta handle\n",
     "_bigwig_cache = {}  # Maps (process_id, file_path) -> pyBigWig handle\n",
     "\n",
     "\n",
+    "def _get_fasta_handle(fasta_path: str) -> Fasta:\n",
+    "    \"\"\"Get or create a FASTA file handle for the current process.\"\"\"\n",
+    "    process_id = os.getpid()\n",
+    "    abs_path = str(Path(fasta_path).resolve())\n",
+    "    cache_key = (process_id, abs_path)\n",
+    "    \n",
+    "    if cache_key not in _fasta_cache:\n",
+    "        _fasta_cache[cache_key] = Fasta(abs_path, as_raw=True, sequence_always_upper=True)\n",
+    "    \n",
+    "    return _fasta_cache[cache_key]\n",
+    "\n",
+    "\n",
     "def _get_bigwig_handle(bigwig_path: str) -> pyBigWig.pyBigWig:\n",
     "    \"\"\"Get or create a BigWig file handle for the current process.\"\"\"\n",
     "    process_id = os.getpid()\n",
-    "    cache_key = (process_id, bigwig_path)\n",
+    "    abs_path = str(Path(bigwig_path).resolve())\n",
+    "    cache_key = (process_id, abs_path)\n",
     "    \n",
     "    if cache_key not in _bigwig_cache:\n",
-    "        _bigwig_cache[cache_key] = pyBigWig.open(bigwig_path)\n",
+    "        # Check if file exists before trying to open\n",
+    "        if not Path(abs_path).exists():\n",
+    "            raise FileNotFoundError(\n",
+    "                f\"BigWig file not found: {abs_path}\\n\"\n",
+    "                f\"Original path: {bigwig_path}\\n\"\n",
+    "                f\"Current working directory: {os.getcwd()}\"\n",
+    "            )\n",
+    "        \n",
+    "        try:\n",
+    "            _bigwig_cache[cache_key] = pyBigWig.open(abs_path)\n",
+    "        except Exception as e:\n",
+    "            raise RuntimeError(\n",
+    "                f\"Failed to open BigWig file: {abs_path}\\n\"\n",
+    "                f\"Error: {str(e)}\\n\"\n",
+    "                f\"File exists: {Path(abs_path).exists()}\\n\"\n",
+    "                f\"File size: {Path(abs_path).stat().st_size if Path(abs_path).exists() else 'N/A'} bytes\"\n",
+    "            ) from e\n",
     "    \n",
     "    return _bigwig_cache[cache_key]\n",
     "\n",
@@ -494,8 +560,8 @@
     "    ):\n",
     "        super().__init__()\n",
     "\n",
-    "        self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
     "        # Store paths instead of opening files immediately (for multi-worker compatibility)\n",
+    "        self.fasta_path = fasta_path\n",
     "        self.bigwig_path_list = bigwig_path_list\n",
     "        self.sequence_length = sequence_length\n",
     "        self.num_samples = num_samples\n",
@@ -533,8 +599,9 @@
     "        start = random.randint(region_start, max_start)\n",
     "        end = start + self.sequence_length\n",
     "\n",
-    "        # Sequence\n",
-    "        seq = self.fasta[chrom][start:end]  # string slice\n",
+    "        # Sequence - get FASTA handle lazily (cached per worker process)\n",
+    "        fasta = _get_fasta_handle(self.fasta_path)\n",
+    "        seq = fasta[chrom][start:end]  # string slice\n",
     "        # Tokenize with padding and truncation to ensure consistent lengths for batching\n",
     "        tokenized = self.tokenizer(\n",
     "            seq,\n",
@@ -788,7 +855,8 @@
     "        metrics_dict = {}\n",
     "        \n",
     "        # Compute Pearson correlation per track\n",
-    "        correlations = self.pearson_metric.compute().numpy()\n",
+    "        # Move to CPU before converting to numpy\n",
+    "        correlations = self.pearson_metric.compute().cpu().numpy()\n",
     "        for i, track_name in enumerate(self.track_names):\n",
     "            metrics_dict[f\"{track_name}/pearson\"] = correlations[i]\n",
     "        \n",
@@ -931,7 +999,6 @@
     "    loss.backward()\n",
     "    return loss.item()\n",
     "\n",
-    "\n",
     "def validation_step(\n",
     "    model: nn.Module,\n",
     "    batch: Dict[str, torch.Tensor],\n",
@@ -1149,10 +1216,10 @@
    "source": [
     "## Test set results\n",
     "\n",
-    "Performance reached at ~1.5B tokens (~1,500 steps in the current 32 kb sequence setup with batch_size=32)\n",
-    "\n",
     "**Hardware configuration**: These results were obtained on an **H100 GPU with 16 workers** for data loading, in approximately **10 minutes** of training.\n",
     "\n",
+    "Performance reached at ~1.5B tokens (~1,500 steps in the current 32 kb sequence setup with batch_size=32)\n",
+    "\n",
     "Mean Pearson: 0.5835\n",
     "- ENCSR325NFE/pearson: 0.6081\n",
     "- ENCSR962OTG/pearson: 0.7286\n",
 