{ "cells": [ { "cell_type": "markdown", "id": "024bb8a8", "metadata": {}, "source": [ "# 🚀 NTv3 Quickstart — Pre-trained and Post-trained models\n", "\n", "This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n", "\n", "- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n", "- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_pos`, `InstaDeepAI/NTv3_650M_pos`\n", "\n", "We show how to:\n", "\n", "1. Load tokenizers + models\n", "2. Run a forward pass on a DNA sequence window\n", "3. Inspect key outputs\n", "\n", "> 📝 **Note for Google Colab users:** This notebook is compatible with Colab! For faster inference, enable a GPU: Runtime → Change runtime type → GPU (T4 or better recommended)." ] }, { "cell_type": "markdown", "id": "5827af7e", "metadata": {}, "source": [ "## 0) 📦 Imports + setup" ] }, { "cell_type": "code", "execution_count": null, "id": "38cc32a9", "metadata": {}, "outputs": [], "source": [ "!pip -q install \"transformers>=4.40\" \"huggingface_hub>=0.23\" safetensors torch numpy" ] }, { "cell_type": "code", "execution_count": 3, "id": "d56c105b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "device: cpu\n", "torch_dtype: torch.float32\n" ] } ], "source": [ "import os\n", "import torch\n", "import numpy as np\n", "\n", "from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForMaskedLM\n", "\n", "# Optional: if the model is gated/private, set HF_TOKEN to a personal access token (hf_...)\n", "HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n", "\n", "# -----------------------------\n", "# Device\n", "# -----------------------------\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(\"device:\", device)\n", "\n", "# Choose dtype: bf16 on GPUs with compute capability >= 8 (Ampere or newer), fp16 on older GPUs, fp32 on CPU\n", "if device == \"cuda\":\n", "    major, minor = torch.cuda.get_device_capability(0)\n", "    torch_dtype = torch.bfloat16 if major >= 8 else torch.float16\n", "else:\n", "    torch_dtype = torch.float32\n", "\n", "print(\"torch_dtype:\", torch_dtype)" ] }, { "cell_type": "markdown", "id": "82146876", "metadata": {}, "source": [ "## 1) 🎯 Pre-trained checkpoint (MLM-focused)\n", "\n", "This shows the simplest usage: load model + tokenizer, then run a forward pass.\n", "\n", "Expected output:\n", "- `logits`: masked language modeling logits" ] }, { "cell_type": "code", "execution_count": null, "id": "336bb40c", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "411ee47e94ae467f9685c35b65e3e52d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/1.48k [00:00