Spaces:
Running
Running
Commit
·
1094a5f
1
Parent(s):
2101d19
fix: notebooks with new post-trained model formats
Browse files- index.html +9 -6
- notebooks/00_quickstart_inference.ipynb +46 -33
- notebooks/01_tracks_prediction.ipynb +193 -47
index.html
CHANGED
|
@@ -262,8 +262,8 @@
|
|
| 262 |
</li>
|
| 263 |
<li>🎯 Post-trained checkpoints:
|
| 264 |
<div style="margin-top: 8px; margin-left: 0;">
|
| 265 |
-
<div><a href="https://huggingface.co/InstaDeepAI/
|
| 266 |
-
<div><a href="https://huggingface.co/InstaDeepAI/
|
| 267 |
</div>
|
| 268 |
</li>
|
| 269 |
</ul>
|
|
@@ -309,7 +309,7 @@
|
|
| 309 |
<ul>
|
| 310 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/00_quickstart_inference.ipynb" target="_blank" rel="noopener">🚀 00 — Quickstart inference</a></li>
|
| 311 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/01_tracks_prediction.ipynb" target="_blank" rel="noopener">📊 01 — Tracks prediction</a></li>
|
| 312 |
-
<li>🏷️ 02 — Genome annotation / segmentation</li>
|
| 313 |
<li>🎯 03 — Fine-tune on bigwig tracks</li>
|
| 314 |
<li>🔍 04 — Model interpretation</li>
|
| 315 |
<li>🧪 05 — Sequence generation</li>
|
|
@@ -380,9 +380,12 @@ out = pipe(
|
|
| 380 |
)
|
| 381 |
|
| 382 |
# Print output shapes
|
| 383 |
-
|
| 384 |
-
print(out.
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
| 386 |
<p>Predictions can also be plotted for a subset of functional tracks and genomic elements:</p>
|
| 387 |
<div class="code"><pre><code class="language-python">tracks_to_plot = {
|
| 388 |
"K562 RNA-seq": "ENCSR056HPM",
|
|
|
|
| 262 |
</li>
|
| 263 |
<li>🎯 Post-trained checkpoints:
|
| 264 |
<div style="margin-top: 8px; margin-left: 0;">
|
| 265 |
+
<div><a href="https://huggingface.co/InstaDeepAI/NTv3_100M_pos"><code>InstaDeepAI/NTv3_100M_pos</code></a></div>
|
| 266 |
+
<div><a href="https://huggingface.co/InstaDeepAI/NTv3_650M_pos"><code>InstaDeepAI/NTv3_650M_pos</code></a></div>
|
| 267 |
</div>
|
| 268 |
</li>
|
| 269 |
</ul>
|
|
|
|
| 309 |
<ul>
|
| 310 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/00_quickstart_inference.ipynb" target="_blank" rel="noopener">🚀 00 — Quickstart inference</a></li>
|
| 311 |
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/01_tracks_prediction.ipynb" target="_blank" rel="noopener">📊 01 — Tracks prediction</a></li>
|
| 312 |
+
<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/02_genome_annotation.ipynb" target="_blank" rel="noopener">🏷️ 02 — Genome annotation / segmentation</a></li>
|
| 313 |
<li>🎯 03 — Fine-tune on bigwig tracks</li>
|
| 314 |
<li>🔍 04 — Model interpretation</li>
|
| 315 |
<li>🧪 05 — Sequence generation</li>
|
|
|
|
| 380 |
)
|
| 381 |
|
| 382 |
# Print output shapes
|
| 383 |
+
# 7k human tracks over the central 37.5% of the input sequence
|
| 384 |
+
print("bigwig_tracks_logits:", tuple(out.bigwig_tracks_logits.shape))
|
| 385 |
+
# Location of 21 genomic elements over the central 37.5% of the input sequence
|
| 386 |
+
print("bed_tracks_logits:", tuple(out.bed_tracks_logits.shape))
|
| 387 |
+
# Language model logits for whole sequence over vocabulary
|
| 388 |
+
print("language model logits:", tuple(out.mlm_logits.shape))</code></pre></div>
|
| 389 |
<p>Predictions can also be plotted for a subset of functional tracks and genomic elements:</p>
|
| 390 |
<div class="code"><pre><code class="language-python">tracks_to_plot = {
|
| 391 |
"K562 RNA-seq": "ENCSR056HPM",
|
notebooks/00_quickstart_inference.ipynb
CHANGED
|
@@ -10,7 +10,7 @@
|
|
| 10 |
"This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n",
|
| 11 |
"\n",
|
| 12 |
"- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n",
|
| 13 |
-
"- **Post-trained (
|
| 14 |
"\n",
|
| 15 |
"We show how to:\n",
|
| 16 |
"\n",
|
|
@@ -51,7 +51,7 @@
|
|
| 51 |
},
|
| 52 |
{
|
| 53 |
"cell_type": "code",
|
| 54 |
-
"execution_count":
|
| 55 |
"id": "d56c105b",
|
| 56 |
"metadata": {},
|
| 57 |
"outputs": [
|
|
@@ -287,18 +287,40 @@
|
|
| 287 |
"source": [
|
| 288 |
"## 3) 🧠 Post-trained checkpoint (task heads: BigWig + BED)\n",
|
| 289 |
"\n",
|
| 290 |
-
"Post-trained checkpoints add task-specific heads.\n",
|
| 291 |
"\n",
|
| 292 |
"In particular:\n",
|
| 293 |
-
"- `
|
| 294 |
-
"- `
|
| 295 |
"\n",
|
| 296 |
"Expected outputs:\n",
|
| 297 |
-
"- `bigwig_tracks_logits
|
| 298 |
-
"- `bed_tracks_logits
|
| 299 |
-
"- `logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
"\n",
|
| 301 |
-
"
|
|
|
|
| 302 |
]
|
| 303 |
},
|
| 304 |
{
|
|
@@ -318,39 +340,30 @@
|
|
| 318 |
}
|
| 319 |
],
|
| 320 |
"source": [
|
| 321 |
-
"
|
| 322 |
-
"\n",
|
| 323 |
-
"
|
| 324 |
-
"cfg_pos = AutoConfig.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
|
| 325 |
-
"tok_pos = AutoTokenizer.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
|
| 326 |
-
"model_pos = AutoModel.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
|
| 327 |
-
"condition_tokenizer = AutoTokenizer.from_pretrained(\n",
|
| 328 |
-
" posttrained_model_name, subfolder=\"condition_tokenizer\", trust_remote_code=True\n",
|
| 329 |
-
")\n",
|
| 330 |
"\n",
|
| 331 |
-
"#
|
| 332 |
-
"
|
| 333 |
-
"batch = tok_pos([seq], add_special_tokens=False, return_tensors=\"pt\")\n",
|
| 334 |
-
"condition = condition_tokenizer([\"human\"], return_tensors=\"pt\")\n",
|
| 335 |
"\n",
|
| 336 |
-
"#
|
| 337 |
-
"
|
| 338 |
-
"
|
| 339 |
"\n",
|
| 340 |
-
"
|
|
|
|
| 341 |
" input_ids=batch[\"input_ids\"],\n",
|
| 342 |
-
"
|
| 343 |
-
"
|
| 344 |
-
" output_hidden_states=True,\n",
|
| 345 |
-
" output_attentions=True,\n",
|
| 346 |
")\n",
|
| 347 |
"\n",
|
| 348 |
"# 7k human tracks over 37.5 % center region of the input sequence\n",
|
| 349 |
-
"print(\"bigwig_tracks_logits:\", out[\"bigwig_tracks_logits\"].shape)\n",
|
| 350 |
"# Location of 21 genomic elements over 37.5 % center region of the input sequence\n",
|
| 351 |
-
"print(\"bed_tracks_logits:\", out[\"bed_tracks_logits\"].shape)\n",
|
| 352 |
"# Language model logits for whole sequence over vocabulary\n",
|
| 353 |
-
"print(\"language model logits:\", out[\"logits\"].shape)"
|
| 354 |
]
|
| 355 |
}
|
| 356 |
],
|
|
|
|
| 10 |
"This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n",
|
| 11 |
"\n",
|
| 12 |
"- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n",
|
| 13 |
+
"- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_pos`, `InstaDeepAI/NTv3_650M_pos`\n",
|
| 14 |
"\n",
|
| 15 |
"We show how to:\n",
|
| 16 |
"\n",
|
|
|
|
| 51 |
},
|
| 52 |
{
|
| 53 |
"cell_type": "code",
|
| 54 |
+
"execution_count": 3,
|
| 55 |
"id": "d56c105b",
|
| 56 |
"metadata": {},
|
| 57 |
"outputs": [
|
|
|
|
| 287 |
"source": [
|
| 288 |
"## 3) 🧠 Post-trained checkpoint (task heads: BigWig + BED)\n",
|
| 289 |
"\n",
|
| 290 |
+
"Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
|
| 291 |
"\n",
|
| 292 |
"In particular:\n",
|
| 293 |
+
"- `species_tokenizer` is used to tokenize a species condition like `\"human\"`\n",
|
| 294 |
+
"- `species_ids` passes the species tokens to the model\n",
|
| 295 |
"\n",
|
| 296 |
"Expected outputs:\n",
|
| 297 |
+
"- `bigwig_tracks_logits`: functional track predictions\n",
|
| 298 |
+
"- `bed_tracks_logits`: genome annotation predictions\n",
|
| 299 |
+
"- `logits`: masked language modeling logits"
|
| 300 |
+
]
|
| 301 |
+
},
|
| 302 |
+
{
|
| 303 |
+
"cell_type": "code",
|
| 304 |
+
"execution_count": 9,
|
| 305 |
+
"id": "bdb8c4d1",
|
| 306 |
+
"metadata": {},
|
| 307 |
+
"outputs": [
|
| 308 |
+
{
|
| 309 |
+
"name": "stdout",
|
| 310 |
+
"output_type": "stream",
|
| 311 |
+
"text": [
|
| 312 |
+
"Model supported species: TO BE DONE\n"
|
| 313 |
+
]
|
| 314 |
+
}
|
| 315 |
+
],
|
| 316 |
+
"source": [
|
| 317 |
+
"# Inspect config and supported species\n",
|
| 318 |
+
"post_trained_model_name = \"InstaDeepAI/NTv3_100M_pos\"\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"cfg_post = AutoConfig.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
|
| 321 |
"\n",
|
| 322 |
+
"species = \"TO BE DONE\"\n",
|
| 323 |
+
"print(\"Model supported species:\", species)"
|
| 324 |
]
|
| 325 |
},
|
| 326 |
{
|
|
|
|
| 340 |
}
|
| 341 |
],
|
| 342 |
"source": [
|
| 343 |
+
"tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
|
| 344 |
+
"cond_tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, subfolder='species_tokenizer', trust_remote_code=True)\n",
|
| 345 |
+
"model_post = AutoModel.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
"\n",
|
| 347 |
+
"# Prepare inputs\n",
|
| 348 |
+
"batch = tok_post([\"ATCGNATCG\", \"ACGT\"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
|
|
|
|
|
|
|
| 349 |
"\n",
|
| 350 |
+
"# Condition tokens (e.g., species)\n",
|
| 351 |
+
"species = 'human'\n",
|
| 352 |
+
"species_ids = cond_tok_post([species] * len(batch['input_ids']), add_special_tokens=False, return_tensors='pt')\n",
|
| 353 |
"\n",
|
| 354 |
+
"# Forward pass\n",
|
| 355 |
+
"out = model_post(\n",
|
| 356 |
" input_ids=batch[\"input_ids\"],\n",
|
| 357 |
+
" species_ids=species_ids['input_ids'],\n",
|
| 358 |
+
" return_dict=True\n",
|
|
|
|
|
|
|
| 359 |
")\n",
|
| 360 |
"\n",
|
| 361 |
"# 7k human tracks over 37.5 % center region of the input sequence\n",
|
| 362 |
+
"print(\"bigwig_tracks_logits:\", tuple(out[\"bigwig_tracks_logits\"].shape))\n",
|
| 363 |
"# Location of 21 genomic elements over 37.5 % center region of the input sequence\n",
|
| 364 |
+
"print(\"bed_tracks_logits:\", tuple(out[\"bed_tracks_logits\"].shape))\n",
|
| 365 |
"# Language model logits for whole sequence over vocabulary\n",
|
| 366 |
+
"print(\"language model logits:\", tuple(out[\"logits\"].shape))\n"
|
| 367 |
]
|
| 368 |
}
|
| 369 |
],
|
notebooks/01_tracks_prediction.ipynb
CHANGED
|
@@ -19,6 +19,8 @@
|
|
| 19 |
"- **Genomic element annotations** (`bed_tracks_logits`): Classification predictions for genomic elements such as genes, exons, introns, splice sites, promoters, enhancers, and more\n",
|
| 20 |
"- **Masked Language Model logits** (`logits`): Standard transformer language model outputs\n",
|
| 21 |
"\n",
|
|
|
|
|
|
|
| 22 |
"## 📚 Notebook Structure\n",
|
| 23 |
"\n",
|
| 24 |
"1. **Setup**: Install dependencies and define the genomic window of interest\n",
|
|
@@ -33,16 +35,6 @@
|
|
| 33 |
"- Supports the 24 species that NTv3 was post-trained on"
|
| 34 |
]
|
| 35 |
},
|
| 36 |
-
{
|
| 37 |
-
"cell_type": "markdown",
|
| 38 |
-
"id": "4997c547",
|
| 39 |
-
"metadata": {},
|
| 40 |
-
"source": [
|
| 41 |
-
"## 0) Colab Setup (if running on Google Colab)\n",
|
| 42 |
-
"\n",
|
| 43 |
-
"This cell detects if you're running on Google Colab and sets up the environment accordingly."
|
| 44 |
-
]
|
| 45 |
-
},
|
| 46 |
{
|
| 47 |
"cell_type": "code",
|
| 48 |
"execution_count": null,
|
|
@@ -65,7 +57,7 @@
|
|
| 65 |
},
|
| 66 |
{
|
| 67 |
"cell_type": "code",
|
| 68 |
-
"execution_count":
|
| 69 |
"id": "608d67e1",
|
| 70 |
"metadata": {},
|
| 71 |
"outputs": [],
|
|
@@ -96,7 +88,7 @@
|
|
| 96 |
},
|
| 97 |
{
|
| 98 |
"cell_type": "code",
|
| 99 |
-
"execution_count":
|
| 100 |
"id": "795a576f",
|
| 101 |
"metadata": {},
|
| 102 |
"outputs": [
|
|
@@ -112,15 +104,16 @@
|
|
| 112 |
"# -----------------------------\n",
|
| 113 |
"# User inputs\n",
|
| 114 |
"# -----------------------------\n",
|
| 115 |
-
"model_name = \"InstaDeepAI/
|
| 116 |
"\n",
|
| 117 |
"# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
|
| 118 |
-
"
|
|
|
|
| 119 |
"chrom = \"chr19\"\n",
|
| 120 |
"start = 6_700_000\n",
|
| 121 |
"end = 6_831_072\n",
|
| 122 |
"\n",
|
| 123 |
-
"# Optional
|
| 124 |
"HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n",
|
| 125 |
"\n",
|
| 126 |
"assert end > start, \"end must be > start\"\n",
|
|
@@ -138,7 +131,7 @@
|
|
| 138 |
},
|
| 139 |
{
|
| 140 |
"cell_type": "code",
|
| 141 |
-
"execution_count":
|
| 142 |
"id": "2354e2aa",
|
| 143 |
"metadata": {},
|
| 144 |
"outputs": [
|
|
@@ -175,7 +168,8 @@
|
|
| 175 |
"name": "stdout",
|
| 176 |
"output_type": "stream",
|
| 177 |
"text": [
|
| 178 |
-
"
|
|
|
|
| 179 |
"Sequence preview: GTCAACAATAACAAATGACATATTAGTAGTAAATTATAATTATACATTACAACAAAATTA...\n",
|
| 180 |
"Valid DNA: True\n"
|
| 181 |
]
|
|
@@ -234,47 +228,198 @@
|
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
-
"execution_count":
|
| 238 |
"id": "e09f0469",
|
| 239 |
"metadata": {},
|
| 240 |
"outputs": [
|
| 241 |
{
|
| 242 |
-
"
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
}
|
| 248 |
],
|
| 249 |
"source": [
|
| 250 |
"# Load model\n",
|
| 251 |
-
"cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True
|
| 252 |
-
"model = AutoModel.from_pretrained(model_name, trust_remote_code=True
|
| 253 |
"\n",
|
| 254 |
"# Load tokenizer\n",
|
| 255 |
-
"tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True
|
| 256 |
"\n",
|
| 257 |
"# Load condition tokenizer\n",
|
| 258 |
-
"
|
| 259 |
-
" model_name, subfolder=\"
|
| 260 |
")\n",
|
| 261 |
"\n",
|
| 262 |
"# Set model to evaluation mode\n",
|
| 263 |
-
"model.eval()
|
| 264 |
-
"\n",
|
| 265 |
-
"# Get assembly index\n",
|
| 266 |
-
"assemblies = list(cfg.bigwigs_per_file_assembly.keys())\n",
|
| 267 |
-
"print(\"Model supported assemblies:\", assemblies)\n",
|
| 268 |
-
"assembly_idx = torch.tensor([assemblies.index(assembly)])\n",
|
| 269 |
-
"\n",
|
| 270 |
-
"# Condition token (species)\n",
|
| 271 |
-
"condition = condition_tokenizer([\"human\"], return_tensors=\"pt\")\n",
|
| 272 |
-
"condition_ids = [condition[\"input_ids\"][0].to(device)]"
|
| 273 |
]
|
| 274 |
},
|
| 275 |
{
|
| 276 |
"cell_type": "code",
|
| 277 |
-
"execution_count":
|
| 278 |
"id": "43154959",
|
| 279 |
"metadata": {},
|
| 280 |
"outputs": [
|
|
@@ -307,8 +452,7 @@
|
|
| 307 |
"We pass:\n",
|
| 308 |
"\n",
|
| 309 |
"- `input_ids`: tokenized DNA window\n",
|
| 310 |
-
"- `
|
| 311 |
-
"- `file_assembly_idx`: select the assembly (`hg38`)\n",
|
| 312 |
"\n",
|
| 313 |
"Outputs include:\n",
|
| 314 |
"\n",
|
|
@@ -319,7 +463,7 @@
|
|
| 319 |
},
|
| 320 |
{
|
| 321 |
"cell_type": "code",
|
| 322 |
-
"execution_count":
|
| 323 |
"id": "6765a9b9",
|
| 324 |
"metadata": {},
|
| 325 |
"outputs": [
|
|
@@ -338,13 +482,15 @@
|
|
| 338 |
"batch = tokenizer([seq], add_special_tokens=False, return_tensors=\"pt\")\n",
|
| 339 |
"input_ids = batch[\"input_ids\"].to(device)\n",
|
| 340 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
"# Run inference\n",
|
| 342 |
"out = model(\n",
|
| 343 |
" input_ids=input_ids,\n",
|
| 344 |
-
"
|
| 345 |
-
"
|
| 346 |
-
" output_hidden_states=False,\n",
|
| 347 |
-
" output_attentions=False,\n",
|
| 348 |
")\n",
|
| 349 |
"\n",
|
| 350 |
"# 7k human tracks over 37.5 % center region of the input sequence\n",
|
|
@@ -391,7 +537,7 @@
|
|
| 391 |
},
|
| 392 |
{
|
| 393 |
"cell_type": "code",
|
| 394 |
-
"execution_count":
|
| 395 |
"id": "717539e2",
|
| 396 |
"metadata": {},
|
| 397 |
"outputs": [],
|
|
|
|
| 19 |
"- **Genomic element annotations** (`bed_tracks_logits`): Classification predictions for genomic elements such as genes, exons, introns, splice sites, promoters, enhancers, and more\n",
|
| 20 |
"- **Masked Language Model logits** (`logits`): Standard transformer language model outputs\n",
|
| 21 |
"\n",
|
| 22 |
+
"> 💡 **Note:** Functional tracks and genomic element annotations are predicted only for the center 37.5% of the input sequence, where the model is more confident due to having full context on both sides.\n",
|
| 23 |
+
"\n",
|
| 24 |
"## 📚 Notebook Structure\n",
|
| 25 |
"\n",
|
| 26 |
"1. **Setup**: Install dependencies and define the genomic window of interest\n",
|
|
|
|
| 35 |
"- Supports the 24 species that NTv3 was post-trained on"
|
| 36 |
]
|
| 37 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
"execution_count": null,
|
|
|
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
+
"execution_count": 7,
|
| 61 |
"id": "608d67e1",
|
| 62 |
"metadata": {},
|
| 63 |
"outputs": [],
|
|
|
|
| 88 |
},
|
| 89 |
{
|
| 90 |
"cell_type": "code",
|
| 91 |
+
"execution_count": 8,
|
| 92 |
"id": "795a576f",
|
| 93 |
"metadata": {},
|
| 94 |
"outputs": [
|
|
|
|
| 104 |
"# -----------------------------\n",
|
| 105 |
"# User inputs\n",
|
| 106 |
"# -----------------------------\n",
|
| 107 |
+
"model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/NTv3_100M_pos\" or \"InstaDeepAI/NTv3_650M_pos\"\n",
|
| 108 |
"\n",
|
| 109 |
"# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
|
| 110 |
+
"species = \"human\" # used to condition the model on species\n",
|
| 111 |
+
"assembly = \"hg38\" # used to fetch the chromosome sequence\n",
|
| 112 |
"chrom = \"chr19\"\n",
|
| 113 |
"start = 6_700_000\n",
|
| 114 |
"end = 6_831_072\n",
|
| 115 |
"\n",
|
| 116 |
+
"# Optional\n",
|
| 117 |
"HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n",
|
| 118 |
"\n",
|
| 119 |
"assert end > start, \"end must be > start\"\n",
|
|
|
|
| 131 |
},
|
| 132 |
{
|
| 133 |
"cell_type": "code",
|
| 134 |
+
"execution_count": 4,
|
| 135 |
"id": "2354e2aa",
|
| 136 |
"metadata": {},
|
| 137 |
"outputs": [
|
|
|
|
| 168 |
"name": "stdout",
|
| 169 |
"output_type": "stream",
|
| 170 |
"text": [
|
| 171 |
+
"Downloading: https://hgdownload.soe.ucsc.edu/goldenPath/hg38/chromosomes/chr19.fa.gz\n",
|
| 172 |
+
"Using downloaded chromosome FASTA: ./hg38/chr19.fa\n",
|
| 173 |
"Sequence preview: GTCAACAATAACAAATGACATATTAGTAGTAAATTATAATTATACATTACAACAAAATTA...\n",
|
| 174 |
"Valid DNA: True\n"
|
| 175 |
]
|
|
|
|
| 228 |
},
|
| 229 |
{
|
| 230 |
"cell_type": "code",
|
| 231 |
+
"execution_count": 11,
|
| 232 |
"id": "e09f0469",
|
| 233 |
"metadata": {},
|
| 234 |
"outputs": [
|
| 235 |
{
|
| 236 |
+
"data": {
|
| 237 |
+
"text/plain": [
|
| 238 |
+
"NTv3Model(\n",
|
| 239 |
+
" (core): Core(\n",
|
| 240 |
+
" (embed_layer): Embedding(11, 16, padding_idx=1)\n",
|
| 241 |
+
" (stem): Stem(\n",
|
| 242 |
+
" (conv): Conv1d(16, 768, kernel_size=(15,), stride=(1,), padding=same)\n",
|
| 243 |
+
" )\n",
|
| 244 |
+
" (cond_tables): ModuleList(\n",
|
| 245 |
+
" (0): Embedding(30, 16)\n",
|
| 246 |
+
" )\n",
|
| 247 |
+
" (conv_tower_blocks): ModuleList(\n",
|
| 248 |
+
" (0-6): 7 x ConditionedConvTowerBlock(\n",
|
| 249 |
+
" (conv): AdaptiveConvBlock(\n",
|
| 250 |
+
" (conv): Conv1d(768, 768, kernel_size=(5,), stride=(1,), padding=same)\n",
|
| 251 |
+
" (layer_norm): AdaptiveLayerNorm(\n",
|
| 252 |
+
" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
|
| 253 |
+
" (modulation_layers): ModuleList(\n",
|
| 254 |
+
" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
|
| 255 |
+
" )\n",
|
| 256 |
+
" )\n",
|
| 257 |
+
" )\n",
|
| 258 |
+
" (res_conv): AdaptiveResidualConvBlock(\n",
|
| 259 |
+
" (conv_block): AdaptiveConvBlock(\n",
|
| 260 |
+
" (conv): Conv1d(768, 768, kernel_size=(1,), stride=(1,), padding=same)\n",
|
| 261 |
+
" (layer_norm): AdaptiveLayerNorm(\n",
|
| 262 |
+
" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
|
| 263 |
+
" (modulation_layers): ModuleList(\n",
|
| 264 |
+
" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
|
| 265 |
+
" )\n",
|
| 266 |
+
" )\n",
|
| 267 |
+
" )\n",
|
| 268 |
+
" (modulation_layers): ModuleList(\n",
|
| 269 |
+
" (0): Linear(in_features=16, out_features=768, bias=True)\n",
|
| 270 |
+
" )\n",
|
| 271 |
+
" )\n",
|
| 272 |
+
" (avg_pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))\n",
|
| 273 |
+
" )\n",
|
| 274 |
+
" )\n",
|
| 275 |
+
" (transformer_blocks): ModuleList(\n",
|
| 276 |
+
" (0-5): 6 x AdaptiveSelfAttentionBlock(\n",
|
| 277 |
+
" (self_attention_layer_norm): AdaptiveLayerNorm(\n",
|
| 278 |
+
" (768,), eps=1e-05, elementwise_affine=True\n",
|
| 279 |
+
" (modulation_layers): ModuleList(\n",
|
| 280 |
+
" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
|
| 281 |
+
" )\n",
|
| 282 |
+
" )\n",
|
| 283 |
+
" (final_layer_norm): AdaptiveLayerNorm(\n",
|
| 284 |
+
" (768,), eps=1e-05, elementwise_affine=True\n",
|
| 285 |
+
" (modulation_layers): ModuleList(\n",
|
| 286 |
+
" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
|
| 287 |
+
" )\n",
|
| 288 |
+
" )\n",
|
| 289 |
+
" (sa_layer): MultiHeadAttention(\n",
|
| 290 |
+
" (query_head): LinearProjectionHeInit(\n",
|
| 291 |
+
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
|
| 292 |
+
" )\n",
|
| 293 |
+
" (key_head): LinearProjectionHeInit(\n",
|
| 294 |
+
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
|
| 295 |
+
" )\n",
|
| 296 |
+
" (value_head): LinearProjectionHeInit(\n",
|
| 297 |
+
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
|
| 298 |
+
" )\n",
|
| 299 |
+
" (mha_output): Linear(in_features=768, out_features=768, bias=True)\n",
|
| 300 |
+
" (rotary_embedding): RotaryEmbedding()\n",
|
| 301 |
+
" )\n",
|
| 302 |
+
" (fc1): Linear(in_features=768, out_features=6144, bias=False)\n",
|
| 303 |
+
" (fc2): Linear(in_features=3072, out_features=768, bias=False)\n",
|
| 304 |
+
" (_ffn_activation_fn): SiLU()\n",
|
| 305 |
+
" (modulation_layers): ModuleList(\n",
|
| 306 |
+
" (0): Linear(in_features=16, out_features=768, bias=True)\n",
|
| 307 |
+
" )\n",
|
| 308 |
+
" )\n",
|
| 309 |
+
" )\n",
|
| 310 |
+
" (deconv_tower_blocks): ModuleList(\n",
|
| 311 |
+
" (0-6): 7 x ConditionedDeConvTowerBlock(\n",
|
| 312 |
+
" (conv): AdaptiveDeConvBlock(\n",
|
| 313 |
+
" (conv): Conv1d(768, 768, kernel_size=(5,), stride=(1,), padding=same)\n",
|
| 314 |
+
" (layer_norm): AdaptiveLayerNorm(\n",
|
| 315 |
+
" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
|
| 316 |
+
" (modulation_layers): ModuleList(\n",
|
| 317 |
+
" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
|
| 318 |
+
" )\n",
|
| 319 |
+
" )\n",
|
| 320 |
+
" )\n",
|
| 321 |
+
" (res_conv): AdaptiveResidualDeConvBlock(\n",
|
| 322 |
+
" (conv_block): AdaptiveDeConvBlock(\n",
|
| 323 |
+
" (conv): ConvTranspose1d(768, 768, kernel_size=(1,), stride=(1,))\n",
|
| 324 |
+
" (layer_norm): AdaptiveLayerNorm(\n",
|
| 325 |
+
" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
|
| 326 |
+
" (modulation_layers): ModuleList(\n",
|
| 327 |
+
" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
|
| 328 |
+
" )\n",
|
| 329 |
+
" )\n",
|
| 330 |
+
" )\n",
|
| 331 |
+
" (modulation_layers): ModuleList(\n",
|
| 332 |
+
" (0): Linear(in_features=16, out_features=768, bias=True)\n",
|
| 333 |
+
" )\n",
|
| 334 |
+
" )\n",
|
| 335 |
+
" )\n",
|
| 336 |
+
" )\n",
|
| 337 |
+
" (bigwig_head): MultiSpeciesHead(\n",
|
| 338 |
+
" (species_heads): ModuleList(\n",
|
| 339 |
+
" (0-4): 5 x ZeroHead()\n",
|
| 340 |
+
" (5): LinearHead(\n",
|
| 341 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 342 |
+
" (head): Linear(in_features=768, out_features=590, bias=True)\n",
|
| 343 |
+
" )\n",
|
| 344 |
+
" (6): LinearHead(\n",
|
| 345 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 346 |
+
" (head): Linear(in_features=768, out_features=319, bias=True)\n",
|
| 347 |
+
" )\n",
|
| 348 |
+
" (7): LinearHead(\n",
|
| 349 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 350 |
+
" (head): Linear(in_features=768, out_features=1392, bias=True)\n",
|
| 351 |
+
" )\n",
|
| 352 |
+
" (8): LinearHead(\n",
|
| 353 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 354 |
+
" (head): Linear(in_features=768, out_features=776, bias=True)\n",
|
| 355 |
+
" )\n",
|
| 356 |
+
" (9-12): 4 x ZeroHead()\n",
|
| 357 |
+
" (13): LinearHead(\n",
|
| 358 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 359 |
+
" (head): Linear(in_features=768, out_features=1899, bias=True)\n",
|
| 360 |
+
" )\n",
|
| 361 |
+
" (14-15): 2 x ZeroHead()\n",
|
| 362 |
+
" (16): LinearHead(\n",
|
| 363 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 364 |
+
" (head): Linear(in_features=768, out_features=921, bias=True)\n",
|
| 365 |
+
" )\n",
|
| 366 |
+
" (17): ZeroHead()\n",
|
| 367 |
+
" (18): LinearHead(\n",
|
| 368 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 369 |
+
" (head): Linear(in_features=768, out_features=180, bias=True)\n",
|
| 370 |
+
" )\n",
|
| 371 |
+
" (19-20): 2 x ZeroHead()\n",
|
| 372 |
+
" (21): LinearHead(\n",
|
| 373 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 374 |
+
" (head): Linear(in_features=768, out_features=7362, bias=True)\n",
|
| 375 |
+
" )\n",
|
| 376 |
+
" (22): ZeroHead()\n",
|
| 377 |
+
" (23): LinearHead(\n",
|
| 378 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 379 |
+
" (head): Linear(in_features=768, out_features=2450, bias=True)\n",
|
| 380 |
+
" )\n",
|
| 381 |
+
" )\n",
|
| 382 |
+
" )\n",
|
| 383 |
+
" (bed_head): ClassificationHead(\n",
|
| 384 |
+
" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
| 385 |
+
" (head): Linear(in_features=768, out_features=42, bias=True)\n",
|
| 386 |
+
" )\n",
|
| 387 |
+
" (conditions_heads): ModuleList(\n",
|
| 388 |
+
" (0): Linear(in_features=768, out_features=30, bias=True)\n",
|
| 389 |
+
" )\n",
|
| 390 |
+
" (lm_head): ModuleDict(\n",
|
| 391 |
+
" (hidden_layers): ModuleList()\n",
|
| 392 |
+
" (head): Linear(in_features=768, out_features=11, bias=True)\n",
|
| 393 |
+
" )\n",
|
| 394 |
+
" )\n",
|
| 395 |
+
")"
|
| 396 |
+
]
|
| 397 |
+
},
|
| 398 |
+
"execution_count": 11,
|
| 399 |
+
"metadata": {},
|
| 400 |
+
"output_type": "execute_result"
|
| 401 |
}
|
| 402 |
],
|
| 403 |
"source": [
|
| 404 |
"# Load model\n",
|
| 405 |
+
"cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
|
| 406 |
+
"model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)\n",
|
| 407 |
"\n",
|
| 408 |
"# Load tokenizer\n",
|
| 409 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
| 410 |
"\n",
|
| 411 |
"# Load condition tokenizer\n",
|
| 412 |
+
"species_tokenizer = AutoTokenizer.from_pretrained(\n",
|
| 413 |
+
" model_name, subfolder=\"species_tokenizer\", trust_remote_code=True,\n",
|
| 414 |
")\n",
|
| 415 |
"\n",
|
| 416 |
"# Set model to evaluation mode\n",
|
| 417 |
+
"model.eval()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
]
|
| 419 |
},
|
| 420 |
{
|
| 421 |
"cell_type": "code",
|
| 422 |
+
"execution_count": 12,
|
| 423 |
"id": "43154959",
|
| 424 |
"metadata": {},
|
| 425 |
"outputs": [
|
|
|
|
| 452 |
"We pass:\n",
|
| 453 |
"\n",
|
| 454 |
"- `input_ids`: tokenized DNA window\n",
|
| 455 |
+
"- `species_ids`: species tokens (`human`)\n",
|
|
|
|
| 456 |
"\n",
|
| 457 |
"Outputs include:\n",
|
| 458 |
"\n",
|
|
|
|
| 463 |
},
|
| 464 |
{
|
| 465 |
"cell_type": "code",
|
| 466 |
+
"execution_count": 13,
|
| 467 |
"id": "6765a9b9",
|
| 468 |
"metadata": {},
|
| 469 |
"outputs": [
|
|
|
|
| 482 |
"batch = tokenizer([seq], add_special_tokens=False, return_tensors=\"pt\")\n",
|
| 483 |
"input_ids = batch[\"input_ids\"].to(device)\n",
|
| 484 |
"\n",
|
| 485 |
+
"# Condition tokens (e.g., species)\n",
|
| 486 |
+
"species = 'human'\n",
|
| 487 |
+
"species_ids = species_tokenizer([species] * len(batch['input_ids']), add_special_tokens=False, return_tensors='pt')\n",
|
| 488 |
+
"\n",
|
| 489 |
"# Run inference\n",
|
| 490 |
"out = model(\n",
|
| 491 |
" input_ids=input_ids,\n",
|
| 492 |
+
" species_ids=species_ids['input_ids'],\n",
|
| 493 |
+
" return_dict=True\n",
|
|
|
|
|
|
|
| 494 |
")\n",
|
| 495 |
"\n",
|
| 496 |
"# 7k human tracks over 37.5 % center region of the input sequence\n",
|
|
|
|
| 537 |
},
|
| 538 |
{
|
| 539 |
"cell_type": "code",
|
| 540 |
+
"execution_count": 15,
|
| 541 |
"id": "717539e2",
|
| 542 |
"metadata": {},
|
| 543 |
"outputs": [],
|