bernardo-de-almeida committed
Commit 1094a5f · 1 Parent(s): 2101d19

fix: notebooks with new post-trained model formats

index.html CHANGED
@@ -262,8 +262,8 @@
 </li>
 <li>🎯 Post-trained checkpoints:
 <div style="margin-top: 8px; margin-left: 0;">
-<div><a href="https://huggingface.co/InstaDeepAI/NTv3_100M"><code>InstaDeepAI/NTv3_100M</code></a></div>
-<div><a href="https://huggingface.co/InstaDeepAI/NTv3_650M"><code>InstaDeepAI/NTv3_650M</code></a></div>
+<div><a href="https://huggingface.co/InstaDeepAI/NTv3_100M_pos"><code>InstaDeepAI/NTv3_100M_pos</code></a></div>
+<div><a href="https://huggingface.co/InstaDeepAI/NTv3_650M_pos"><code>InstaDeepAI/NTv3_650M_pos</code></a></div>
 </div>
 </li>
 </ul>
@@ -309,7 +309,7 @@
 <ul>
 <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/00_quickstart_inference.ipynb" target="_blank" rel="noopener">🚀 00 — Quickstart inference</a></li>
 <li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/01_tracks_prediction.ipynb" target="_blank" rel="noopener">📊 01 — Tracks prediction</a></li>
-<li>🏷️ 02 — Genome annotation / segmentation</li>
+<li><a href="https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/02_genome_annotation.ipynb" target="_blank" rel="noopener">🏷️ 02 — Genome annotation / segmentation</a></li>
 <li>🎯 03 — Fine-tune on bigwig tracks</li>
 <li>🔍 04 — Model interpretation</li>
 <li>🧪 05 — Sequence generation</li>
@@ -380,9 +380,12 @@ out = pipe(
 )
 
 # Print output shapes
-print(out.bigwig_tracks_logits.shape) # functional track predictions
-print(out.bed_tracks_logits.shape) # genome annotation predictions
-print(out.mlm_logits.shape) # MLM logits: (B, L, V = 11)</code></pre></div>
+# 7k human tracks over 37.5 % center region of the input sequence
+print("bigwig_tracks_logits:", tuple(out.bigwig_tracks_logits.shape))
+# Location of 21 genomic elements over 37.5 % center region of the input sequence
+print("bed_tracks_logits:", tuple(out.bed_tracks_logits.shape))
+# Language model logits for whole sequence over vocabulary
+print("language model logits:", tuple(out.mlm_logits.shape))</code></pre></div>
 <p>Predictions can also be plotted for a subset of functional tracks and genomic elements:</p>
 <div class="code"><pre><code class="language-python">tracks_to_plot = {
 "K562 RNA-seq": "ENCSR056HPM",
notebooks/00_quickstart_inference.ipynb CHANGED
@@ -10,7 +10,7 @@
 "This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n",
 "\n",
 "- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n",
-"- **Post-trained (task heads):** `InstaDeepAI/NTv3_100M`, `InstaDeepAI/NTv3_650M`\n",
+"- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_pos`, `InstaDeepAI/NTv3_650M_pos`\n",
 "\n",
 "We show how to:\n",
 "\n",
@@ -51,7 +51,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 3,
 "id": "d56c105b",
 "metadata": {},
 "outputs": [
@@ -287,18 +287,40 @@
 "source": [
 "## 3) 🧠 Post-trained checkpoint (task heads: BigWig + BED)\n",
 "\n",
-"Post-trained checkpoints add task-specific heads.\n",
+"Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
 "\n",
 "In particular:\n",
-"- `condition_tokenizer` is used to tokenize a species condition like `\"human\"`\n",
-"- `file_assembly_idx` selects the assembly (e.g., `hg38`)\n",
+"- `species_tokenizer` is used to tokenize a species condition like `\"human\"`\n",
+"- `species_ids` passes the species tokens to the model\n",
 "\n",
 "Expected outputs:\n",
-"- `bigwig_tracks_logits`\n",
-"- `bed_tracks_logits`\n",
-"- `logits` (MLM)\n",
+"- `bigwig_tracks_logits`: functional track predictions\n",
+"- `bed_tracks_logits`: genome annotation predictions\n",
+"- `logits`: masked language modeling logits"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "bdb8c4d1",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Model supported species: TO BE DONE\n"
+]
+}
+],
+"source": [
+"# Inspect config and supported species\n",
+"post_trained_model_name = \"InstaDeepAI/NTv3_100M_pos\"\n",
+"\n",
+"cfg_post = AutoConfig.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
 "\n",
-"> 💡 If your post-trained checkpoint supports multiple assemblies, the config typically exposes a mapping like `cfg.bigwigs_per_file_assembly`."
+"species = \"TO BE DONE\"\n",
+"print(\"Model supported species:\", species)"
 ]
 },
 {
@@ -318,39 +340,30 @@
 }
 ],
 "source": [
-"posttrained_model_name = \"InstaDeepAI/NTv3_100M\"\n",
-"\n",
-"# Load config/tokenizers/model\n",
-"cfg_pos = AutoConfig.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
-"tok_pos = AutoTokenizer.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
-"model_pos = AutoModel.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
-"condition_tokenizer = AutoTokenizer.from_pretrained(\n",
-" posttrained_model_name, subfolder=\"condition_tokenizer\", trust_remote_code=True\n",
-")\n",
+"tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
+"cond_tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, subfolder='species_tokenizer', trust_remote_code=True)\n",
+"model_post = AutoModel.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
 "\n",
-"# Example: human sequence (sequence needs to be multiple of 128 due to 7 downsampling in model)\n",
-"seq = \"ATCG\" * 512\n",
-"batch = tok_pos([seq], add_special_tokens=False, return_tensors=\"pt\")\n",
-"condition = condition_tokenizer([\"human\"], return_tensors=\"pt\")\n",
+"# Prepare inputs\n",
+"batch = tok_post([\"ATCGNATCG\", \"ACGT\"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
 "\n",
-"# Get assembly index for human (hg38)\n",
-"assemblies = list(cfg_pos.bigwigs_per_file_assembly.keys())\n",
-"assembly_idx = torch.tensor([assemblies.index(\"hg38\")])\n",
+"# Condition tokens (e.g., species)\n",
+"species = 'human'\n",
+"species_ids = cond_tok_post([species] * len(batch['input_ids']), add_special_tokens=False, return_tensors='pt')\n",
 "\n",
-"out = model_pos(\n",
+"# Forward pass\n",
+"out = model_post(\n",
 " input_ids=batch[\"input_ids\"],\n",
-" condition_ids=[condition[\"input_ids\"][0]],\n",
-" file_assembly_idx=assembly_idx,\n",
-" output_hidden_states=True,\n",
-" output_attentions=True,\n",
+" species_ids=species_ids['input_ids'],\n",
+" return_dict=True\n",
 ")\n",
 "\n",
 "# 7k human tracks over 37.5 % center region of the input sequence\n",
-"print(\"bigwig_tracks_logits:\", out[\"bigwig_tracks_logits\"].shape)\n",
+"print(\"bigwig_tracks_logits:\", tuple(out[\"bigwig_tracks_logits\"].shape))\n",
 "# Location of 21 genomic elements over 37.5 % center region of the input sequence\n",
-"print(\"bed_tracks_logits:\", out[\"bed_tracks_logits\"].shape)\n",
+"print(\"bed_tracks_logits:\", tuple(out[\"bed_tracks_logits\"].shape))\n",
 "# Language model logits for whole sequence over vocabulary\n",
-"print(\"language model logits:\", out[\"logits\"].shape)"
+"print(\"language model logits:\", tuple(out[\"logits\"].shape))\n"
 ]
 }
 ],
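
Continuing from the quickstart's final cell above, a hedged follow-on sketch for decoding the MLM head. It assumes `tok_post` behaves like a standard Hugging Face tokenizer, exposing `convert_ids_to_tokens` and returning an `attention_mask` when padding; neither is confirmed by this diff:

import torch

# Greedy decode: argmax over the 11-token vocabulary at every position
token_ids = out["logits"].argmax(dim=-1)  # (batch, seq_len)

# Drop padded positions of the first sequence, then map ids back to tokens
mask = batch["attention_mask"][0].bool()  # assumes the tokenizer returned a mask
tokens = tok_post.convert_ids_to_tokens(token_ids[0][mask].tolist())
print("".join(tokens)[:60])
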
notebooks/01_tracks_prediction.ipynb CHANGED
@@ -19,6 +19,8 @@
 "- **Genomic element annotations** (`bed_tracks_logits`): Classification predictions for genomic elements such as genes, exons, introns, splice sites, promoters, enhancers, and more\n",
 "- **Masked Language Model logits** (`logits`): Standard transformer language model outputs\n",
 "\n",
+"> 💡 **Note:** Functional tracks and genomic element annotations are predicted only for the center 37.5% of the input sequence, where the model is more confident due to having full context on both sides.\n",
+"\n",
 "## 📚 Notebook Structure\n",
 "\n",
 "1. **Setup**: Install dependencies and define the genomic window of interest\n",
@@ -33,16 +35,6 @@
 "- Supports the 24 species that NTv3 was post-trained on"
 ]
 },
-{
-"cell_type": "markdown",
-"id": "4997c547",
-"metadata": {},
-"source": [
-"## 0) Colab Setup (if running on Google Colab)\n",
-"\n",
-"This cell detects if you're running on Google Colab and sets up the environment accordingly."
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -65,7 +57,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 7,
 "id": "608d67e1",
 "metadata": {},
 "outputs": [],
@@ -96,7 +88,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 8,
 "id": "795a576f",
 "metadata": {},
 "outputs": [
@@ -112,15 +104,16 @@
 "# -----------------------------\n",
 "# User inputs\n",
 "# -----------------------------\n",
-"model_name = \"InstaDeepAI/NTv3_100M\" # options: \"InstaDeepAI/ntv3_106M_7downsample_post_trained_1mb\" or \"InstaDeepAI/ntv3_650M_7downsample_post_trained_1mb_v2\"\n",
+"model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/ntv3_106M_7downsample_post_trained_1mb\" or \"InstaDeepAI/ntv3_650M_7downsample_post_trained_1mb_v2\"\n",
 "\n",
 "# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
-"assembly = \"hg38\"\n",
+"species = \"human\" # will use for condition the model on species\n",
+"assembly = \"hg38\" # will use for fetching the chromosome sequence\n",
 "chrom = \"chr19\"\n",
 "start = 6_700_000\n",
 "end = 6_831_072\n",
 "\n",
-"# Optional: if the model is gated/private, set HF_TOKEN to a PERSONAL token (hf_...)\n",
+"# Optional\n",
 "HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n",
 "\n",
 "assert end > start, \"end must be > start\"\n",
@@ -138,7 +131,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 4,
 "id": "2354e2aa",
 "metadata": {},
 "outputs": [
@@ -175,7 +168,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Using downloaded chromosome FASTA: ./genomes/hg38/chr19.fa\n",
+"Downloading: https://hgdownload.soe.ucsc.edu/goldenPath/hg38/chromosomes/chr19.fa.gz\n",
+"Using downloaded chromosome FASTA: ./hg38/chr19.fa\n",
 "Sequence preview: GTCAACAATAACAAATGACATATTAGTAGTAAATTATAATTATACATTACAACAAAATTA...\n",
 "Valid DNA: True\n"
 ]
@@ -234,47 +228,198 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 11,
 "id": "e09f0469",
 "metadata": {},
 "outputs": [
 {
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Model supported assemblies: ['AmpOce1', 'Bison_UMD1', 'ChiLan1', 'Felis_catus_9', 'GRCz11', 'Glycine_max_v2.1', 'Gossypium_hirsutum_v2.1', 'IRGSP-1.0', 'IWGSC', 'KH', 'Mnem_1', 'ROS_Cfam_1', 'SCA1', 'TAIR10', 'TETRAODON8', 'WBcel235', 'Zm-B73-REFERENCE-NAM-5.0', 'bGalGal1', 'dm6', 'fSalTru1', 'gorGor4', 'hg38', 'mRatBN7', 'mm10']\n"
-]
+"data": {
+"text/plain": [
+"NTv3Model(\n",
+" (core): Core(\n",
+" (embed_layer): Embedding(11, 16, padding_idx=1)\n",
+" (stem): Stem(\n",
+" (conv): Conv1d(16, 768, kernel_size=(15,), stride=(1,), padding=same)\n",
+" )\n",
+" (cond_tables): ModuleList(\n",
+" (0): Embedding(30, 16)\n",
+" )\n",
+" (conv_tower_blocks): ModuleList(\n",
+" (0-6): 7 x ConditionedConvTowerBlock(\n",
+" (conv): AdaptiveConvBlock(\n",
+" (conv): Conv1d(768, 768, kernel_size=(5,), stride=(1,), padding=same)\n",
+" (layer_norm): AdaptiveLayerNorm(\n",
+" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" (res_conv): AdaptiveResidualConvBlock(\n",
+" (conv_block): AdaptiveConvBlock(\n",
+" (conv): Conv1d(768, 768, kernel_size=(1,), stride=(1,), padding=same)\n",
+" (layer_norm): AdaptiveLayerNorm(\n",
+" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=768, bias=True)\n",
+" )\n",
+" )\n",
+" (avg_pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))\n",
+" )\n",
+" )\n",
+" (transformer_blocks): ModuleList(\n",
+" (0-5): 6 x AdaptiveSelfAttentionBlock(\n",
+" (self_attention_layer_norm): AdaptiveLayerNorm(\n",
+" (768,), eps=1e-05, elementwise_affine=True\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
+" )\n",
+" )\n",
+" (final_layer_norm): AdaptiveLayerNorm(\n",
+" (768,), eps=1e-05, elementwise_affine=True\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
+" )\n",
+" )\n",
+" (sa_layer): MultiHeadAttention(\n",
+" (query_head): LinearProjectionHeInit(\n",
+" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
+" )\n",
+" (key_head): LinearProjectionHeInit(\n",
+" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
+" )\n",
+" (value_head): LinearProjectionHeInit(\n",
+" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
+" )\n",
+" (mha_output): Linear(in_features=768, out_features=768, bias=True)\n",
+" (rotary_embedding): RotaryEmbedding()\n",
+" )\n",
+" (fc1): Linear(in_features=768, out_features=6144, bias=False)\n",
+" (fc2): Linear(in_features=3072, out_features=768, bias=False)\n",
+" (_ffn_activation_fn): SiLU()\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=768, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" (deconv_tower_blocks): ModuleList(\n",
+" (0-6): 7 x ConditionedDeConvTowerBlock(\n",
+" (conv): AdaptiveDeConvBlock(\n",
+" (conv): Conv1d(768, 768, kernel_size=(5,), stride=(1,), padding=same)\n",
+" (layer_norm): AdaptiveLayerNorm(\n",
+" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" (res_conv): AdaptiveResidualDeConvBlock(\n",
+" (conv_block): AdaptiveDeConvBlock(\n",
+" (conv): ConvTranspose1d(768, 768, kernel_size=(1,), stride=(1,))\n",
+" (layer_norm): AdaptiveLayerNorm(\n",
+" (np.int64(768),), eps=1e-05, elementwise_affine=True\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=1536, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" (modulation_layers): ModuleList(\n",
+" (0): Linear(in_features=16, out_features=768, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" )\n",
+" (bigwig_head): MultiSpeciesHead(\n",
+" (species_heads): ModuleList(\n",
+" (0-4): 5 x ZeroHead()\n",
+" (5): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=590, bias=True)\n",
+" )\n",
+" (6): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=319, bias=True)\n",
+" )\n",
+" (7): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=1392, bias=True)\n",
+" )\n",
+" (8): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=776, bias=True)\n",
+" )\n",
+" (9-12): 4 x ZeroHead()\n",
+" (13): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=1899, bias=True)\n",
+" )\n",
+" (14-15): 2 x ZeroHead()\n",
+" (16): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=921, bias=True)\n",
+" )\n",
+" (17): ZeroHead()\n",
+" (18): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=180, bias=True)\n",
+" )\n",
+" (19-20): 2 x ZeroHead()\n",
+" (21): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=7362, bias=True)\n",
+" )\n",
+" (22): ZeroHead()\n",
+" (23): LinearHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=2450, bias=True)\n",
+" )\n",
+" )\n",
+" )\n",
+" (bed_head): ClassificationHead(\n",
+" (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+" (head): Linear(in_features=768, out_features=42, bias=True)\n",
+" )\n",
+" (conditions_heads): ModuleList(\n",
+" (0): Linear(in_features=768, out_features=30, bias=True)\n",
+" )\n",
+" (lm_head): ModuleDict(\n",
+" (hidden_layers): ModuleList()\n",
+" (head): Linear(in_features=768, out_features=11, bias=True)\n",
+" )\n",
+" )\n",
+")"
+]
+},
+"execution_count": 11,
+"metadata": {},
+"output_type": "execute_result"
 }
 ],
 "source": [
 "# Load model\n",
-"cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN)\n",
-"model = AutoModel.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN).to(device)\n",
+"cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
+"model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)\n",
 "\n",
 "# Load tokenizer\n",
-"tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, token=HF_TOKEN)\n",
+"tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
 "\n",
 "# Load condition tokenizer\n",
-"condition_tokenizer = AutoTokenizer.from_pretrained(\n",
-" model_name, subfolder=\"condition_tokenizer\", trust_remote_code=True, token=HF_TOKEN\n",
+"species_tokenizer = AutoTokenizer.from_pretrained(\n",
+" model_name, subfolder=\"species_tokenizer\", trust_remote_code=True,\n",
 ")\n",
 "\n",
 "# Set model to evaluation mode\n",
-"model.eval()\n",
-"\n",
-"# Get assembly index\n",
-"assemblies = list(cfg.bigwigs_per_file_assembly.keys())\n",
-"print(\"Model supported assemblies:\", assemblies)\n",
-"assembly_idx = torch.tensor([assemblies.index(assembly)])\n",
-"\n",
-"# Condition token (species)\n",
-"condition = condition_tokenizer([\"human\"], return_tensors=\"pt\")\n",
-"condition_ids = [condition[\"input_ids\"][0].to(device)]"
+"model.eval()"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 12,
 "id": "43154959",
 "metadata": {},
 "outputs": [
@@ -307,8 +452,7 @@
 "We pass:\n",
 "\n",
 "- `input_ids`: tokenized DNA window\n",
-"- `condition_ids`: species tokens (`human`)\n",
-"- `file_assembly_idx`: select the assembly (`hg38`)\n",
+"- `species_ids`: species tokens (`human`)\n",
 "\n",
 "Outputs include:\n",
 "\n",
@@ -319,7 +463,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 13,
 "id": "6765a9b9",
 "metadata": {},
 "outputs": [
@@ -338,13 +482,15 @@
 "batch = tokenizer([seq], add_special_tokens=False, return_tensors=\"pt\")\n",
 "input_ids = batch[\"input_ids\"].to(device)\n",
 "\n",
+"# Condition tokens (e.g., species)\n",
+"species = 'human'\n",
+"species_ids = species_tokenizer([species] * len(batch['input_ids']), add_special_tokens=False, return_tensors='pt')\n",
+"\n",
 "# Run inference\n",
 "out = model(\n",
 " input_ids=input_ids,\n",
-" condition_ids=condition_ids,\n",
-" file_assembly_idx=assembly_idx,\n",
-" output_hidden_states=False,\n",
-" output_attentions=False,\n",
+" species_ids=species_ids['input_ids'],\n",
+" return_dict=True\n",
 ")\n",
 "\n",
 "# 7k human tracks over 37.5 % center region of the input sequence\n",
@@ -391,7 +537,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 15,
 "id": "717539e2",
 "metadata": {},
 "outputs": [],