# 🚀 NTv3 Quickstart — Pre-trained and Post-trained models

This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:

- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`
- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_post`, `InstaDeepAI/NTv3_650M_post`

We show how to:

1. Load tokenizers + models
2. Run a forward pass on a DNA sequence window
3. Inspect key outputs

> 📝 **Note for Google Colab users:** This notebook runs on Colab. For faster inference, enable a GPU: Runtime → Change runtime type → GPU (T4 or better recommended).

## 0) 📦 Imports + setup

In [None]:
# Log in to Hugging Face (required for gated models)
from huggingface_hub import login
login()

In [1]:
!pip -q install "transformers>=4.40" "huggingface_hub>=0.23" safetensors torch numpy

In [2]:
import os
import torch
import numpy as np

from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForMaskedLM

# Optional: if the model is gated/private, set HF_TOKEN to a PERSONAL token (hf_...)
HF_TOKEN = os.getenv("HF_TOKEN", None)

# -----------------------------
# Device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# Choose dtype (bf16 if supported; else fp16 on GPU; else fp32)
if device == "cuda":
    major, minor = torch.cuda.get_device_capability(0)
    torch_dtype = torch.bfloat16 if major >= 8 else torch.float16
else:
    torch_dtype = torch.float32

print("torch_dtype:", torch_dtype)

device: cpu
torch_dtype: torch.float32


In [3]:
# Dummy DNA sequences
seqs = [
    "ACGT" * 32,
    "ACGT" * 128,
]

print("Sequence lengths:", [len(s) for s in seqs])

Sequence lengths: [128, 512]


## 1) 🎯 Pre-trained checkpoint (MLM-focused)

This section shows the simplest usage: load the model and tokenizer, then run a forward pass.

Expected output:
- `logits`: masked language modeling logits, with shape (batch, sequence length, vocabulary size)

In [4]:
pretrained_model_name = "InstaDeepAI/NTv3_8M_pre"

# Load tokenizer/model
tok_pre = AutoTokenizer.from_pretrained(pretrained_model_name, trust_remote_code=True)
model_pre = AutoModelForMaskedLM.from_pretrained(pretrained_model_name, trust_remote_code=True)

# Example inference
# The tokenizer pads every sequence to the longest one, rounded up to a multiple of 128
batch = tok_pre(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
# No gradients are needed for inference
with torch.no_grad():
    out = model_pre(**batch)

# Access MLM logits
mlm_logits = out["logits"]
print("MLM logits shape:", tuple(mlm_logits.shape))

MLM logits shape: (2, 512, 11)
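
The sequence dimension of 512 in the shape above follows from the padding settings. The sketch below reproduces that arithmetic in plain Python. It assumes one token per nucleotide and no special tokens (consistent with `add_special_tokens=False` and the printed shape); `padded_length` is a hypothetical helper, not part of the tokenizer API.

```python
import math


def padded_length(seq_lengths, multiple=128):
    """Length that padding=True with pad_to_multiple_of=multiple pads a batch to.

    Assumption: one token per nucleotide and no special tokens.
    """
    longest = max(seq_lengths)
    # Round the longest sequence up to the next multiple.
    return math.ceil(longest / multiple) * multiple


print(padded_length([128, 512]))  # 512, matching the logits shape above
print(padded_length([128, 130]))  # 256
```

Both sequences in a batch are padded to the same length, so the 128-nucleotide sequence is padded up to 512 here.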


## 2) 🧠 Post-trained checkpoint (task heads: BigWig + BED)

Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.

Expected outputs:
- `bigwig_tracks_logits`: functional track predictions
- `bed_tracks_logits`: genome annotation predictions
- `logits`: masked language modeling logits

In [5]:
# Load model
post_trained_model_name = "InstaDeepAI/NTv3_100M_post"

tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, trust_remote_code=True)
model_post = AutoModel.from_pretrained(post_trained_model_name, trust_remote_code=True)

# Prepare inputs: the tokenizer pads every sequence to the longest one, rounded up to a multiple of 128
batch = tok_post(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")

# List all supported species
print("Supported species:", model_post.config.species_to_token_id.keys())
# Species tokens (one per sequence)
species = ['human', 'mouse']
species_ids = model_post.encode_species(species)

# Forward pass (no gradients are needed for inference)
with torch.no_grad():
    out = model_post(
        input_ids=batch["input_ids"],
        species_ids=species_ids,
    )

# ~7k human tracks over the central 37.5% of the input sequence (192 of 512 positions)
print("bigwig_tracks_logits:", tuple(out["bigwig_tracks_logits"].shape))
# Locations of 21 genomic elements over the same central 37.5% of the input sequence
print("bed_tracks_logits:", tuple(out["bed_tracks_logits"].shape))
# Language model logits over the vocabulary for the whole sequence
print("language model logits:", tuple(out["logits"].shape))
print("language model logits:", tuple(out["logits"].shape))


Supported species: dict_keys(['', '', '', '', '', '', 'amphiprion_ocellaris', 'arabidopsis_thaliana', 'bison_bison_bison', 'caenorhabditis_elegans', 'canis_lupus_familiaris', 'chinchilla_lanigera', 'ciona_intestinalis', 'danio_rerio', 'drosophila_melanogaster', 'felis_catus', 'gallus_gallus', 'glycine_max', 'gorilla_gorilla', 'gossypium_hirsutum', 'human', 'macaca_nemestrina', 'mouse', 'oryza_sativa', 'rattus_norvegicus', 'salmo_trutta', 'serinus_canaria', 'tetraodon_nigroviridis', 'triticum_aestivum', 'zea_mays'])
bigwig_tracks_logits: (2, 192, 7362)
bed_tracks_logits: (2, 192, 21, 2)
language model logits: (2, 512, 11)
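
The track outputs cover 192 positions because they span only the central 37.5% of the 512-position input. The sketch below shows one way to map that window back to input coordinates. It assumes the crop is symmetric around the center and that each output position corresponds to one input position; `center_window` is a hypothetical helper, not a model method.

```python
def center_window(input_len, keep_fraction=0.375):
    """Return (start, end) of the central window in input coordinates.

    Assumption: the crop is symmetric, so equal margins are dropped on each side.
    """
    out_len = int(input_len * keep_fraction)
    start = (input_len - out_len) // 2
    return start, start + out_len


start, end = center_window(512)
print(start, end, end - start)  # 160 352 192
```

Under these assumptions, output position `i` of `bigwig_tracks_logits` or `bed_tracks_logits` would correspond to input position `160 + i` for a 512-position input.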
