bernardo-de-almeida committed on
Commit 143dd5d · 1 Parent(s): 7436df3

fix: finetuning post-trained model on new species

notebooks_tutorials/03_fine_tuning_posttrained_model_biwig.ipynb CHANGED
@@ -4,10 +4,10 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "# 🧬 Fine-Tuning a Post-trained Model on BigWig Tracks Prediction (reproduce paper results)\n",
  "\n",
- "This notebook is designed to enable the reproduction of the results in the paper. In contrast to the simplified fine tuning setup in [02_fine_tuning_pretrained_model_biwig.ipynb](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning_pretrained_model_biwig.ipynb), this more complex setup is designed to mirror the internal JAX pipeline used to run the evaluations.\n",
- "As in the benchmark, the notebook finetunes a post-trained Nucleotide Transformer v3 (`NTv3_650M_post`) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a post-trained NTv3 backbone as a feature extractor. A new prediction head is added to the model, which outputs single-nucleotide resolution signal values for each of the 34 tracks in the NTv3 benchmark for the `human` species.\n",
  "\n",
  "**🦚 Features:**\n",
  "In addition to the simplified version, the following features are added:\n",
@@ -18,7 +18,7 @@
  "- Save the latest and best models for future use\n",
  "\n",
  "**🔦 JAX vs PyTorch:**\n",
- "The values achieved by this pipeline are close (within 0.01 mean Pearson) to those reported in the paper. They differ slightly due to using here a PyTorch pipeline to make it easier for users, as opposed to the JAX pipeline used for the results in the paper. For most accurate performance, it is recommended to use 3x seeds and average the results, as shown in the paper.\n",
  "\n",
  "**🚆 Training:**\n",
  "To run this training, you will need a large GPU (either A100 or H100). It takes around 28 hours on an H100 with the default settings. Tuning the number of DataLoader workers may further improve efficiency. Our JAX pipeline completes the training in around 12 hours.\n",
@@ -162,7 +162,7 @@
  "\n",
  "### Data\n",
  "- **`hf_repo_id`**: HuggingFace dataset repository ID containing the benchmark data\n",
- "- **`species`**: Species name (e.g., \"human\") to select data from the benchmark dataset\n",
  "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
  "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
  "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
@@ -185,7 +185,9 @@
  "### General\n",
  "- **`seed`**: Random seed for reproducibility\n",
  "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
- "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)\n"
  ]
  },
  {
@@ -209,7 +211,7 @@
  " \n",
  " # Data\n",
  " \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
- " \"species_name\": \"human\",\n",
  " \"data_cache_dir\": \"./data\",\n",
  " \"sequence_length\": 32_768,\n",
  " \"keep_target_center_fraction\": 0.375,\n",
@@ -255,7 +257,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 5,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -310,7 +312,7 @@
  " bigwig_paths = [str(bigwig_file) for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
  " bigwig_ids = [bigwig_file.stem for bigwig_file in bigwig_dir.glob(\"*.bigwig\")] \n",
  "\n",
- " # Splits file\n",
  " splits_path_repo = f\"{species}/splits.bed\"\n",
  " splits_path = local_dir / splits_path_repo\n",
  "\n",
@@ -327,7 +329,7 @@
  " metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
  "\n",
  " # Filter and order metadata \n",
- " metadata_df = metadata_df[metadata_df[\"species_name\"] == species].reset_index(drop=True)\n",
  " metadata_df = metadata_df.set_index(\"file_id\").loc[bigwig_ids].reset_index()\n",
  "\n",
  " return fasta_path, bigwig_paths, bigwig_ids, splits_df, metadata_df"
@@ -945,16 +947,24 @@
  " ):\n",
  " super().__init__()\n",
  " \n",
- " # Load config and model\n",
  " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
- " backbone = AutoModel.from_pretrained(\n",
  " model_name, \n",
  " trust_remote_code=True,\n",
  " config=self.config,\n",
  " )\n",
- " self.backbone = torch.compile(backbone)\n",
- " if species_str in self.backbone.supported_species:\n",
- " self.species_ids = self.backbone.encode_species(species_str)\n",
  " print(f\"Using species: {species_str} with ids: {self.species_ids}\")\n",
  " else:\n",
  " # Mask token id\n",
@@ -968,11 +978,13 @@
  " self.model_name = model_name\n",
  " \n",
  " def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
- " # Forward through backbone\n",
  " species_tokens = torch.repeat_interleave(self.species_ids, tokens.shape[0])\n",
  " species_tokens = species_tokens.to(tokens.device)\n",
- " outputs = self.backbone(input_ids=tokens, species_ids=species_tokens, output_hidden_states=True)\n",
- " embedding = outputs.hidden_states[-1] # Last hidden state\n",
  " \n",
  " # Crop to center fraction\n",
  " if self.keep_target_center_fraction < 1.0:\n",
@@ -1198,6 +1210,7 @@
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
  ")\n",
  "model = model.to(device)\n",
  "model.train()\n",
  "\n",
  "print(f\"Model loaded: {config['model_name']}\")\n",
@@ -1553,6 +1566,18 @@
  }
  ],
  "source": [
  "# Create datasets & dataloaders\n",
  "create_dataset_fn = functools.partial(\n",
  " GenomeBigWigDataset,\n",
@@ -2162,7 +2187,7 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
- "## Test set results obtained for reference\n",
  "\n",
  "===== Test Set Results =====\n",
  "\n",
@@ -2203,6 +2228,40 @@
  "- ENCSR100LIJ_P/pearson: 0.5643"
  ]
  },
  {
  "cell_type": "markdown",
  "metadata": {},
 
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+ "# 🧬 Fine-Tuning a Post-trained Model on Functional BigWig Tracks Prediction (reproduce paper results)\n",
  "\n",
+ "This notebook is designed to enable the reproduction of the paper's fine-tuning results on functional genomics tracks. In contrast to the simplified fine-tuning setup in [02_fine_tuning_pretrained_model_biwig.ipynb](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning_pretrained_model_biwig.ipynb), this more complex setup mirrors the internal JAX pipeline used to run the evaluations, reimplemented in PyTorch with our HuggingFace models.\n",
+ "As in the benchmark, the notebook fine-tunes the post-trained Nucleotide Transformer v3 (`NTv3_650M_post`) model to predict BigWig signal tracks directly from DNA sequences. The streamlined approach leverages a post-trained NTv3 backbone as a feature extractor. A new prediction head is added to the model, which outputs single-nucleotide resolution signal values for each of the functional bigwig tracks in the NTv3 benchmark for the selected species. The notebook uses the 34 tracks for the `human` species by default, but you can change the config to use any species from the benchmark.\n",
  "\n",
  "**🦚 Features:**\n",
  "In addition to the simplified version, the following features are added:\n",
 
  "- Save the latest and best models for future use\n",
  "\n",
  "**🔦 JAX vs PyTorch:**\n",
+ "The values achieved by this pipeline are close (within 0.01 mean Pearson for human) to those reported in the paper. They differ slightly because this notebook uses a PyTorch pipeline to make it easier for users, as opposed to the JAX pipeline used for the results in the paper. For the most accurate performance estimate, it is recommended to use 3x seeds and average the results, as done in the paper.\n",
  "\n",
  "**🚆 Training:**\n",
  "To run this training, you will need a large GPU (either A100 or H100). It takes around 28 hours on an H100 with the default settings. Tuning the number of DataLoader workers may further improve efficiency. Our JAX pipeline completes the training in around 12 hours.\n",
 
  "\n",
  "### Data\n",
  "- **`hf_repo_id`**: HuggingFace dataset repository ID containing the benchmark data\n",
+ "- **`species`**: Species name (e.g., \"human\", \"tomato\") to select bigwig data from the benchmark dataset\n",
  "- **`data_cache_dir`**: Directory where downloaded data files (FASTA, bigWig) will be stored\n",
  "- **`sequence_length`**: Length of input sequences in base pairs (bp)\n",
  "- **`keep_target_center_fraction`**: Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n",
 
  "### General\n",
  "- **`seed`**: Random seed for reproducibility\n",
  "- **`device`**: Device to run training on (\"cuda\" or \"cpu\")\n",
+ "- **`num_workers`**: Number of worker processes for DataLoader (0 = single-threaded)\n",
+ "\n",
+ "NOTE: the default parameters fine-tune the model on the human dataset. To fine-tune on the tomato dataset, set 'species_name' to 'tomato' in the config. You can also adjust the config parameters for the number of training and warmup tokens based on the species genome size, as done in our benchmark (see the paper for details), although this is not necessary to achieve top performance."
  ]
  },
  {
 
  " \n",
  " # Data\n",
  " \"hf_repo_id\": \"InstaDeepAI/NTv3_benchmark_dataset\",\n",
+ " \"species_name\": \"human\", # Select the species to train on, e.g. \"tomato\"\n",
  " \"data_cache_dir\": \"./data\",\n",
  " \"sequence_length\": 32_768,\n",
  " \"keep_target_center_fraction\": 0.375,\n",
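The `sequence_length` and `keep_target_center_fraction` defaults above imply a 12,288 bp target window centered in the 32,768 bp input. A minimal sketch of that center-crop arithmetic (the helper name `center_crop_bounds` is illustrative, not from the notebook):

```python
def center_crop_bounds(sequence_length: int, keep_fraction: float) -> tuple[int, int]:
    """Return (start, end) indices of the centered window kept for targets."""
    crop_length = int(sequence_length * keep_fraction)
    start = (sequence_length - crop_length) // 2
    return start, start + crop_length

# With the notebook defaults: keep the central 37.5% of a 32_768 bp sequence
start, end = center_crop_bounds(32_768, 0.375)
print(start, end, end - start)  # 10240 22528 12288
```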
 
  },
  {
  "cell_type": "code",
+ "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
 
  " bigwig_paths = [str(bigwig_file) for bigwig_file in bigwig_dir.glob(\"*.bigwig\")]\n",
  " bigwig_ids = [bigwig_file.stem for bigwig_file in bigwig_dir.glob(\"*.bigwig\")] \n",
  "\n",
+ " # Data splits file\n",
  " splits_path_repo = f\"{species}/splits.bed\"\n",
  " splits_path = local_dir / splits_path_repo\n",
  "\n",
 
  " metadata_df = pd.read_csv(metadata_path, sep=\"\\t\")\n",
  "\n",
  " # Filter and order metadata \n",
+ " metadata_df = metadata_df[metadata_df[\"species_common_name\"] == species].reset_index(drop=True)\n",
  " metadata_df = metadata_df.set_index(\"file_id\").loc[bigwig_ids].reset_index()\n",
  "\n",
  " return fasta_path, bigwig_paths, bigwig_ids, splits_df, metadata_df"
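The `set_index(...).loc[...]` chain above reorders the metadata rows to follow the order of the discovered bigwig files rather than the order in the metadata file. A small self-contained sketch of that pandas pattern (the toy `file_id` and `assay` values are made up for illustration):

```python
import pandas as pd

metadata_df = pd.DataFrame({
    "file_id": ["trackC", "trackA", "trackB"],
    "assay": ["ATAC", "DNase", "CAGE"],
})
# Reorder rows to match the bigwig file order
bigwig_ids = ["trackA", "trackB", "trackC"]
ordered = metadata_df.set_index("file_id").loc[bigwig_ids].reset_index()
print(ordered["file_id"].tolist())  # ['trackA', 'trackB', 'trackC']
```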
 
  " ):\n",
  " super().__init__()\n",
  " \n",
+ " # Load base model config and model\n",
  " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
+ " ntv3_base_model = AutoModel.from_pretrained(\n",
  " model_name, \n",
  " trust_remote_code=True,\n",
  " config=self.config,\n",
  " )\n",
+ "\n",
+ " # Extract the discrete conditioned model (i.e. remove the heads) for fine-tuning\n",
+ " discrete_conditioned_model = type(ntv3_base_model.core).__bases__[0]\n",
+ " self.core = discrete_conditioned_model(self.config) # follows name convention\n",
+ " # Load pre-trained weights (strict=False because we don't load the heads)\n",
+ " self.load_state_dict(ntv3_base_model.state_dict(), strict=False) \n",
+ "\n",
+ " self.supported_species = self.config.bigwigs_per_species.keys()\n",
+ " if species_str in self.config.species_to_token_id:\n",
+ " species_ids = self.config.species_to_token_id[species_str]\n",
+ " self.species_ids = torch.LongTensor([species_ids])\n",
  " print(f\"Using species: {species_str} with ids: {self.species_ids}\")\n",
  " else:\n",
  " # Mask token id\n",
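The head-removal trick in the added lines above relies on a plain Python pattern: `type(obj).__bases__[0]` recovers the parent class of an instance, which can then be re-instantiated without the subclass's extra modules. A minimal sketch with dummy classes (the class names are illustrative, not NTv3 internals):

```python
class Backbone:
    """Stands in for the discrete-conditioned core model."""
    def __init__(self, config):
        self.config = config

class BackboneWithHeads(Backbone):
    """Stands in for the full post-trained model with prediction heads."""
    def __init__(self, config):
        super().__init__(config)
        self.heads = ["bigwig_head"]  # extra modules we want to drop

full_model = BackboneWithHeads(config={"hidden_size": 1280})
parent_cls = type(full_model).__bases__[0]   # recovers Backbone
core = parent_cls(full_model.config)         # re-instantiate without the heads
print(parent_cls.__name__, hasattr(core, "heads"))  # Backbone False
```

In the notebook the pre-trained weights are then copied over with `load_state_dict(..., strict=False)`, which tolerates the head parameters missing from the new module.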
 
  " self.model_name = model_name\n",
  " \n",
  " def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n",
+ " # Prepare the species tokens\n",
  " species_tokens = torch.repeat_interleave(self.species_ids, tokens.shape[0])\n",
  " species_tokens = species_tokens.to(tokens.device)\n",
+ "\n",
+ " # Forward through core\n",
+ " outputs = self.core(tokens, [species_tokens], output_hidden_states=True)\n",
+ " embedding = outputs[\"hidden_states\"][-1]\n",
  " \n",
  " # Crop to center fraction\n",
  " if self.keep_target_center_fraction < 1.0:\n",
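`torch.repeat_interleave(self.species_ids, tokens.shape[0])` in the forward pass above simply broadcasts the single species id to one entry per sequence in the batch. The same idea with NumPy's `repeat` (used here only so the sketch runs without the GPU stack; the notebook itself uses torch, and the id value is illustrative):

```python
import numpy as np

species_ids = np.array([7])   # illustrative token id for one species
batch_size = 4                # tokens.shape[0] in the notebook
species_tokens = np.repeat(species_ids, batch_size)
print(species_tokens.tolist())  # [7, 7, 7, 7]
```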
 
  " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n",
  ")\n",
  "model = model.to(device)\n",
+ "model = torch.compile(model)\n",
  "model.train()\n",
  "\n",
  "print(f\"Model loaded: {config['model_name']}\")\n",
 
  }
  ],
  "source": [
+ "# Pre-build the FASTA index in the main process to avoid race conditions\n",
+ "# when multiple DataLoader workers try to create it simultaneously\n",
+ "print(f\"Pre-building FASTA index for {fasta_path}...\")\n",
+ "fai_path = Path(fasta_path + \".fai\")\n",
+ "if fai_path.exists():\n",
+ " # Remove potentially corrupted index from a previous failed run\n",
+ " print(f\"Removing existing FASTA index: {fai_path}\")\n",
+ " fai_path.unlink()\n",
+ "_prebuild_fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n",
+ "del _prebuild_fasta # Close the handle; workers will reopen with existing index\n",
+ "print(\"FASTA index built successfully.\")\n",
+ "\n",
  "# Create datasets & dataloaders\n",
  "create_dataset_fn = functools.partial(\n",
  " GenomeBigWigDataset,\n",
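The pre-build step added above exists because several DataLoader workers opening the same un-indexed FASTA can race to write the `.fai` file. The general pattern — delete a possibly-stale index, rebuild it once in the main process, then let workers reuse it — can be sketched without `pyfaidx` (the tiny one-line-per-record index format below is illustrative, not the real `.fai` layout):

```python
import tempfile
from pathlib import Path

def build_index(fasta_path: Path) -> Path:
    """Write a toy per-record index ('name<TAB>length'); a stand-in for .fai."""
    index_path = Path(str(fasta_path) + ".idx")
    records, name, length = [], None, 0
    for line in fasta_path.read_text().splitlines():
        if line.startswith(">"):
            if name is not None:
                records.append(f"{name}\t{length}")
            name, length = line[1:].split()[0], 0
        else:
            length += len(line.strip())
    if name is not None:
        records.append(f"{name}\t{length}")
    index_path.write_text("\n".join(records))
    return index_path

fasta = Path(tempfile.mkdtemp()) / "genome.fa"
fasta.write_text(">chr1\nACGTACGT\nACGT\n>chr2\nGGGG\n")

index_path = Path(str(fasta) + ".idx")
if index_path.exists():      # remove a potentially corrupted index first
    index_path.unlink()
build_index(fasta)           # build once in the main process; workers reuse it
print(index_path.read_text())
```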
 
  "cell_type": "markdown",
  "metadata": {},
  "source": [
+ "## Test set results obtained for reference (human)\n",
  "\n",
  "===== Test Set Results =====\n",
  "\n",
 
  "- ENCSR100LIJ_P/pearson: 0.5643"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Test set results obtained for a new species (tomato)\n",
+ "\n",
+ "===== Test Set Results =====\n",
+ "\n",
+ "Metrics:\n",
+ "Mean Pearson: 0.7596\n",
+ "- SRX29291439/pearson: 0.8581\n",
+ "- SRX27799718/pearson: 0.4512\n",
+ "- SRX29291446/pearson: 0.9152\n",
+ "- SRX29291430/pearson: 0.9069\n",
+ "- SRX27799731/pearson: 0.5254\n",
+ "- SRX27799719/pearson: 0.4435\n",
+ "- SRX29291442/pearson: 0.9151\n",
+ "- SRX27799733/pearson: 0.4725\n",
+ "- SRX27799722/pearson: 0.4795\n",
+ "- SRX29291444/pearson: 0.9139\n",
+ "- SRX29291440/pearson: 0.8633\n",
+ "- SRX27799727/pearson: 0.6209\n",
+ "- SRX29291438/pearson: 0.8770\n",
+ "- SRX27799703/pearson: 0.4722\n",
+ "- SRX29291448/pearson: 0.9160\n",
+ "- SRX29291441/pearson: 0.9169\n",
+ "- SRX29291447/pearson: 0.9171\n",
+ "- SRX29291445/pearson: 0.8990\n",
+ "- SRX29291431/pearson: 0.9181\n",
+ "- SRX29291443/pearson: 0.9103\n",
+ "\n",
+ "NOTE: to achieve these results, set 'species_name' to 'tomato' in the config."
+ ]
+ },
  {
  "cell_type": "markdown",
  "metadata": {},
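The reported tomato mean is the unweighted average of the 20 per-track Pearson values listed above, which can be checked directly:

```python
from statistics import mean

# Per-track test-set Pearson values from the tomato results above
pearsons = [
    0.8581, 0.4512, 0.9152, 0.9069, 0.5254, 0.4435, 0.9151, 0.4725,
    0.4795, 0.9139, 0.8633, 0.6209, 0.8770, 0.4722, 0.9160, 0.9169,
    0.9171, 0.8990, 0.9181, 0.9103,
]
print(round(mean(pearsons), 4))  # 0.7596
```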