{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 🧬 Fine-Tuning a Post-trained Model on Functional BigWig Track Prediction (reproduce paper results)\n", "\n", "This notebook enables reproduction of the paper's fine-tuning results on functional genomics tracks. In contrast to the simplified fine-tuning setup in [02_fine_tuning_pretrained_model_biwig.ipynb](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks_tutorials/02_fine_tuning_pretrained_model_biwig.ipynb), this more complex setup mirrors, in PyTorch and with our HuggingFace models, the internal JAX pipeline used to run the paper's evaluations.\n", "As in the benchmark, the notebook fine-tunes the post-trained Nucleotide Transformer v3 (`NTv3_650M_post`) model to predict BigWig signal tracks directly from DNA sequences. The approach leverages the post-trained NTv3 backbone as a feature extractor: a new prediction head is added to the model, which outputs single-nucleotide-resolution signal values for each of the functional BigWig tracks in the NTv3 benchmark for the selected species. The notebook uses the 34 tracks for the `human` species by default, but you can change the config to use any species from the benchmark.\n", "\n", "**🦚 Features:**\n", "In addition to those in the simplified version, the following features are added:\n", "- Learning rate scheduling\n", "- Fixed dataset regions for training\n", "- Gradient accumulation for large effective batch sizes\n", "- Evaluation with the best model (selected via validation Pearson)\n", "- Saving of the latest and best models for future use\n", "\n", "**🔦 JAX vs PyTorch:**\n", "The values achieved by this pipeline are close (within 0.01 mean Pearson for human) to those reported in the paper. They differ slightly because this notebook uses a PyTorch pipeline to make it easier for users, whereas the paper results were produced with the JAX pipeline. For the most accurate performance estimate, run 3 seeds and average the results, as done in the paper.\n", "\n", "**🚆 Training:**\n", "To run this training, you will need a large GPU (an A100 or H100). With the default settings, training takes around 28 hours on an H100; tuning the number of data-loading workers may improve efficiency. Our JAX pipeline completes the training in around 12 hours.\n", "\n", "📝 Note for Google Colab users: This notebook is compatible with Colab, but it is designed to run on a high-performance GPU. The default parameters can be used on an H100 with 80GB of HBM.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 0. 📦 Import dependencies" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7def2b35ebeb45bc97960837f8a7041c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='