{ "cells": [ { "cell_type": "markdown", "id": "024bb8a8", "metadata": {}, "source": [ "# 🚀 NTv3 Quickstart — Pre-trained and Post-trained models\n", "\n", "This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n", "\n", "- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n", "- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_post`, `InstaDeepAI/NTv3_650M_post`\n", "\n", "We show how to:\n", "\n", "1. Load tokenizers + models\n", "2. Run a forward pass on a DNA sequence window\n", "3. Inspect key outputs\n", "\n", "> 📝 **Note for Google Colab users:** This notebook is compatible with Colab! For faster inference, enable a GPU: Runtime → Change runtime type → GPU (T4 or better recommended)." ] }, { "cell_type": "markdown", "id": "5827af7e", "metadata": {}, "source": [ "## 0) 📦 Imports + setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "38cc32a9", "metadata": {}, "outputs": [], "source": [ "!pip -q install \"transformers>=4.40\" \"huggingface_hub>=0.23\" safetensors torch numpy" ] }, { "cell_type": "code", "execution_count": null, "id": "0b354087", "metadata": {}, "outputs": [], "source": [ "# Log in to Hugging Face (only needed for gated checkpoints; you can skip this cell if the HF_TOKEN environment variable is set)\n", "from huggingface_hub import login\n", "login()" ] }, { "cell_type": "code", "execution_count": 2, "id": "d56c105b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "device: cpu\n", "torch_dtype: torch.float32\n" ] } ], "source": [ "import os\n", "import torch\n", "\n", "from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer\n", "\n", "# Optional: for gated/private checkpoints, set the HF_TOKEN environment variable to a personal access token (hf_...)\n", "HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n", "\n", "# -----------------------------\n", "# Device\n", "# -----------------------------\n", 
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(\"device:\", device)\n", "\n", "# Choose dtype: bf16 on GPUs with compute capability >= 8.0, fp16 on older GPUs, fp32 on CPU\n", "if device == \"cuda\":\n", "    major, minor = torch.cuda.get_device_capability(0)\n", "    torch_dtype = torch.bfloat16 if major >= 8 else torch.float16\n", "else:\n", "    torch_dtype = torch.float32\n", "\n", "print(\"torch_dtype:\", torch_dtype)" ] }, { "cell_type": "code", "execution_count": 3, "id": "ef0e6d69", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequence lengths: [128, 512]\n" ] } ], "source": [ "# Toy DNA sequences of different lengths (128 bp and 512 bp)\n", "seqs = [\n", "    \"ACGT\" * 32,\n", "    \"ACGT\" * 128,\n", "]\n", "\n", "print(\"Sequence lengths:\", [len(s) for s in seqs])" ] }, { "cell_type": "markdown", "id": "82146876", "metadata": {}, "source": [ "## 1) 🎯 Pre-trained checkpoint (MLM-focused)\n", "\n", "This shows the simplest usage: load the tokenizer and model, then run a forward pass.\n", "\n", "Expected output:\n", "- `logits`: masked language modeling (MLM) logits of shape `(batch, padded_length, vocab_size)`" ] }, { "cell_type": "code", "execution_count": 4, "id": "336bb40c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLM logits shape: (2, 512, 11)\n" ] } ], "source": [ "pretrained_model_name = \"InstaDeepAI/NTv3_8M_pre\"\n", "\n", "# Load tokenizer/model and move the model to the selected device and dtype\n", "tok_pre = AutoTokenizer.from_pretrained(pretrained_model_name, trust_remote_code=True, token=HF_TOKEN)\n", "model_pre = AutoModelForMaskedLM.from_pretrained(\n", "    pretrained_model_name, trust_remote_code=True, token=HF_TOKEN, torch_dtype=torch_dtype\n", ").to(device)\n", "\n", "# Example inference\n", "# Every sequence is padded to the longest one in the batch, rounded up to a multiple of 128\n", "batch = tok_pre(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\").to(device)\n", "with torch.no_grad():\n", "    out = model_pre(**batch)\n", "\n", "# Access MLM logits\n", "mlm_logits = out[\"logits\"]\n", "print(\"MLM logits shape:\", tuple(mlm_logits.shape))" ] }, { "cell_type": "markdown", "id": "60a01798", "metadata": {}, "source": [ "## 2) 🧠 
Post-trained checkpoint (task heads: BigWig + BED)\n", "\n", "Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n", "\n", "Expected outputs:\n", "- `bigwig_tracks_logits`: functional track (BigWig) predictions over the central 37.5% of the input window\n", "- `bed_tracks_logits`: genome annotation (BED element) predictions over the same central region\n", "- `logits`: masked language modeling logits for every position of the input" ] }, { "cell_type": "code", "execution_count": 5, "id": "6cc5f2df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Supported species: dict_keys(['', '', '', '', '', '', 'amphiprion_ocellaris', 'arabidopsis_thaliana', 'bison_bison_bison', 'caenorhabditis_elegans', 'canis_lupus_familiaris', 'chinchilla_lanigera', 'ciona_intestinalis', 'danio_rerio', 'drosophila_melanogaster', 'felis_catus', 'gallus_gallus', 'glycine_max', 'gorilla_gorilla', 'gossypium_hirsutum', 'human', 'macaca_nemestrina', 'mouse', 'oryza_sativa', 'rattus_norvegicus', 'salmo_trutta', 'serinus_canaria', 'tetraodon_nigroviridis', 'triticum_aestivum', 'zea_mays'])\n", "bigwig_tracks_logits: (2, 192, 7362)\n", "bed_tracks_logits: (2, 192, 21, 2)\n", "language model logits: (2, 512, 11)\n" ] } ], "source": [ "# Load tokenizer/model and move the model to the selected device and dtype\n", "post_trained_model_name = \"InstaDeepAI/NTv3_100M_post\"\n", "\n", "tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, trust_remote_code=True, token=HF_TOKEN)\n", "model_post = AutoModel.from_pretrained(\n", "    post_trained_model_name, trust_remote_code=True, token=HF_TOKEN, torch_dtype=torch_dtype\n", ").to(device)\n", "\n", "# Prepare inputs: every sequence is padded to the longest one in the batch, rounded up to a multiple of 128\n", "batch = tok_post(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\").to(device)\n", "\n", "# Show all supported species\n", "print(\"Supported species:\", model_post.config.species_to_token_id.keys())\n", "# One species label per sequence in the batch\n", "species = ['human', 'mouse']\n", "species_ids = torch.as_tensor(model_post.encode_species(species), device=device)\n", "\n", "# Forward pass (gradients are not needed for inference)\n", "with torch.no_grad():\n", "    out = model_post(\n", "        input_ids=batch[\"input_ids\"],\n", 
species_ids=species_ids,\n", ")\n", "\n", "# 7k human tracks over 37.5 % center region of the input sequence\n", "print(\"bigwig_tracks_logits:\", tuple(out[\"bigwig_tracks_logits\"].shape))\n", "# Location of 21 genomic elements over 37.5 % center region of the input sequence\n", "print(\"bed_tracks_logits:\", tuple(out[\"bed_tracks_logits\"].shape))\n", "# Language model logits for whole sequence over vocabulary\n", "print(\"language model logits:\", tuple(out[\"logits\"].shape))\n" ] } ], "metadata": { "kernelspec": { "display_name": "hf-finetune", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 5 }