bernardo-de-almeida commited on
Commit
9759882
·
1 Parent(s): 1dc15bb

fix: notebooks

Browse files
notebooks_pipelines/01_functional_track_prediction.ipynb CHANGED
@@ -64,13 +64,13 @@
64
  },
65
  {
66
  "cell_type": "code",
67
- "execution_count": 3,
68
  "id": "423af70a",
69
  "metadata": {},
70
  "outputs": [],
71
  "source": [
72
  "# Define the model and genomic window\n",
73
- "model_name = \"InstaDeepAI/NTv3_650M_pos\"\n",
74
  "\n",
75
  "species = \"human\" # will use for condition the model on species\n",
76
  "assembly = \"hg38\" # will use for fetching the chromosome sequence\n",
 
64
  },
65
  {
66
  "cell_type": "code",
67
+ "execution_count": null,
68
  "id": "423af70a",
69
  "metadata": {},
70
  "outputs": [],
71
  "source": [
72
  "# Define the model and genomic window\n",
73
+ "model_name = \"InstaDeepAI/NTv3_650M_post\"\n",
74
  "\n",
75
  "species = \"human\" # will use for condition the model on species\n",
76
  "assembly = \"hg38\" # will use for fetching the chromosome sequence\n",
notebooks_tutorials/00_quickstart_inference.ipynb CHANGED
@@ -10,7 +10,7 @@
10
  "This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n",
11
  "\n",
12
  "- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n",
13
- "- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_pos`, `InstaDeepAI/NTv3_650M_pos`\n",
14
  "\n",
15
  "We show how to:\n",
16
  "\n",
@@ -31,7 +31,7 @@
31
  },
32
  {
33
  "cell_type": "code",
34
- "execution_count": null,
35
  "id": "38cc32a9",
36
  "metadata": {},
37
  "outputs": [],
@@ -41,7 +41,7 @@
41
  },
42
  {
43
  "cell_type": "code",
44
- "execution_count": 3,
45
  "id": "d56c105b",
46
  "metadata": {},
47
  "outputs": [
@@ -95,156 +95,15 @@
95
  },
96
  {
97
  "cell_type": "code",
98
- "execution_count": null,
99
  "id": "336bb40c",
100
  "metadata": {},
101
  "outputs": [
102
- {
103
- "data": {
104
- "application/vnd.jupyter.widget-view+json": {
105
- "model_id": "411ee47e94ae467f9685c35b65e3e52d",
106
- "version_major": 2,
107
- "version_minor": 0
108
- },
109
- "text/plain": [
110
- "tokenizer_config.json: 0%| | 0.00/1.48k [00:00<?, ?B/s]"
111
- ]
112
- },
113
- "metadata": {},
114
- "output_type": "display_data"
115
- },
116
- {
117
- "data": {
118
- "application/vnd.jupyter.widget-view+json": {
119
- "model_id": "30447edb44b849bd936290f3a6b1b863",
120
- "version_major": 2,
121
- "version_minor": 0
122
- },
123
- "text/plain": [
124
- "tokenization_ntv3.py: 0%| | 0.00/12.0k [00:00<?, ?B/s]"
125
- ]
126
- },
127
- "metadata": {},
128
- "output_type": "display_data"
129
- },
130
- {
131
- "name": "stderr",
132
- "output_type": "stream",
133
- "text": [
134
- "A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/ntv3_base_model:\n",
135
- "- tokenization_ntv3.py\n",
136
- ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
137
- ]
138
- },
139
- {
140
- "data": {
141
- "application/vnd.jupyter.widget-view+json": {
142
- "model_id": "766f183dcc84421588e5cf0241d3efe7",
143
- "version_major": 2,
144
- "version_minor": 0
145
- },
146
- "text/plain": [
147
- "vocab.json: 0%| | 0.00/138 [00:00<?, ?B/s]"
148
- ]
149
- },
150
- "metadata": {},
151
- "output_type": "display_data"
152
- },
153
- {
154
- "data": {
155
- "application/vnd.jupyter.widget-view+json": {
156
- "model_id": "b0db83f7cb824d3288a30bebf7891a63",
157
- "version_major": 2,
158
- "version_minor": 0
159
- },
160
- "text/plain": [
161
- "special_tokens_map.json: 0%| | 0.00/149 [00:00<?, ?B/s]"
162
- ]
163
- },
164
- "metadata": {},
165
- "output_type": "display_data"
166
- },
167
- {
168
- "data": {
169
- "application/vnd.jupyter.widget-view+json": {
170
- "model_id": "33cf5391dcc549f088e4e927651d1cdb",
171
- "version_major": 2,
172
- "version_minor": 0
173
- },
174
- "text/plain": [
175
- "config.json: 0%| | 0.00/1.70k [00:00<?, ?B/s]"
176
- ]
177
- },
178
- "metadata": {},
179
- "output_type": "display_data"
180
- },
181
- {
182
- "data": {
183
- "application/vnd.jupyter.widget-view+json": {
184
- "model_id": "85772d5369234ca286cfa518e1725b12",
185
- "version_major": 2,
186
- "version_minor": 0
187
- },
188
- "text/plain": [
189
- "configuration_ntv3.py: 0%| | 0.00/5.90k [00:00<?, ?B/s]"
190
- ]
191
- },
192
- "metadata": {},
193
- "output_type": "display_data"
194
- },
195
- {
196
- "name": "stderr",
197
- "output_type": "stream",
198
- "text": [
199
- "A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/ntv3_base_model:\n",
200
- "- configuration_ntv3.py\n",
201
- ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
202
- ]
203
- },
204
- {
205
- "data": {
206
- "application/vnd.jupyter.widget-view+json": {
207
- "model_id": "ec1153d073e444c5b255ee5adea6ba68",
208
- "version_major": 2,
209
- "version_minor": 0
210
- },
211
- "text/plain": [
212
- "modeling_ntv3_base.py: 0%| | 0.00/33.9k [00:00<?, ?B/s]"
213
- ]
214
- },
215
- "metadata": {},
216
- "output_type": "display_data"
217
- },
218
- {
219
- "name": "stderr",
220
- "output_type": "stream",
221
- "text": [
222
- "A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/ntv3_base_model:\n",
223
- "- modeling_ntv3_base.py\n",
224
- ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
225
- ]
226
- },
227
- {
228
- "data": {
229
- "application/vnd.jupyter.widget-view+json": {
230
- "model_id": "94b9bb7fe0da4f4994adb9127d9af7e6",
231
- "version_major": 2,
232
- "version_minor": 0
233
- },
234
- "text/plain": [
235
- "model.safetensors: 0%| | 0.00/30.8M [00:00<?, ?B/s]"
236
- ]
237
- },
238
- "metadata": {},
239
- "output_type": "display_data"
240
- },
241
  {
242
  "name": "stdout",
243
  "output_type": "stream",
244
  "text": [
245
  "torch.Size([2, 128, 11])\n",
246
- "16\n",
247
- "2\n",
248
  "MLM logits shape: (2, 128, 11)\n"
249
  ]
250
  }
@@ -259,11 +118,9 @@
259
  "# Example: human sequence\n",
260
  "seqs = [\"ATCGNATCG\", \"ACGT\"]\n",
261
  "batch = tok_pre(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
262
- "out = model_pre(**batch, output_hidden_states=True, output_attentions=True)\n",
263
  "\n",
264
  "print(out.logits.shape) # (B, L, V = 11)\n",
265
- "print(len(out.hidden_states)) # convs + transformers + deconvs\n",
266
- "print(len(out.attentions))\n",
267
  "\n",
268
  "# Access MLM logits\n",
269
  "mlm_logits = out[\"logits\"]\n",
@@ -279,10 +136,6 @@
279
  "\n",
280
  "Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
281
  "\n",
282
- "In particular:\n",
283
- "- `species_tokenizer` is used to tokenize a species condition like `\"human\"`\n",
284
- "- `species_ids` passes the species tokens to the model\n",
285
- "\n",
286
  "Expected outputs:\n",
287
  "- `bigwig_tracks_logits`: functional track predictions\n",
288
  "- `bed_tracks_logits`: genome annotation predictions\n",
@@ -291,31 +144,7 @@
291
  },
292
  {
293
  "cell_type": "code",
294
- "execution_count": 9,
295
- "id": "bdb8c4d1",
296
- "metadata": {},
297
- "outputs": [
298
- {
299
- "name": "stdout",
300
- "output_type": "stream",
301
- "text": [
302
- "Model supported species: TO BE DONE\n"
303
- ]
304
- }
305
- ],
306
- "source": [
307
- "# Inspect config and supported species\n",
308
- "post_trained_model_name = \"InstaDeepAI/NTv3_100M_pos\"\n",
309
- "\n",
310
- "cfg_post = AutoConfig.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
311
- "\n",
312
- "species = \"TO BE DONE\"\n",
313
- "print(\"Model supported species:\", species)"
314
- ]
315
- },
316
- {
317
- "cell_type": "code",
318
- "execution_count": null,
319
  "id": "6cc5f2df",
320
  "metadata": {},
321
  "outputs": [
@@ -323,29 +152,33 @@
323
  "name": "stdout",
324
  "output_type": "stream",
325
  "text": [
326
- "torch.Size([1, 768, 7362])\n",
327
- "torch.Size([1, 768, 21, 2])\n",
328
- "torch.Size([1, 2048, 11])\n"
 
329
  ]
330
  }
331
  ],
332
  "source": [
 
 
 
333
  "tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
334
- "cond_tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, subfolder='species_tokenizer', trust_remote_code=True)\n",
335
  "model_post = AutoModel.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
336
  "\n",
337
  "# Prepare inputs\n",
338
  "batch = tok_post([\"ATCGNATCG\", \"ACGT\"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
339
  "\n",
340
- "# Condition tokens (e.g., species)\n",
341
- "species = 'human'\n",
342
- "species_ids = cond_tok_post([species] * len(batch['input_ids']), add_special_tokens=False, return_tensors='pt')\n",
 
 
343
  "\n",
344
  "# Forward pass\n",
345
  "out = model_post(\n",
346
  " input_ids=batch[\"input_ids\"],\n",
347
- " species_ids=species_ids['input_ids'],\n",
348
- " return_dict=True\n",
349
  ")\n",
350
  "\n",
351
  "# 7k human tracks over 37.5 % center region of the input sequence\n",
@@ -355,6 +188,14 @@
355
  "# Language model logits for whole sequence over vocabulary\n",
356
  "print(\"language model logits:\", tuple(out[\"logits\"].shape))\n"
357
  ]
 
 
 
 
 
 
 
 
358
  }
359
  ],
360
  "metadata": {
 
10
  "This notebook demonstrates how to run **quick inference** with both the pre- and post-trained NTv3 checkpoints:\n",
11
  "\n",
12
  "- **Pre-trained (MLM-focused):** `InstaDeepAI/NTv3_8M_pre`, `InstaDeepAI/NTv3_100M_pre`, `InstaDeepAI/NTv3_650M_pre`\n",
13
+ "- **Post-trained (functional tracks and genome annotation):** `InstaDeepAI/NTv3_100M_post`, `InstaDeepAI/NTv3_650M_post`\n",
14
  "\n",
15
  "We show how to:\n",
16
  "\n",
 
31
  },
32
  {
33
  "cell_type": "code",
34
+ "execution_count": 1,
35
  "id": "38cc32a9",
36
  "metadata": {},
37
  "outputs": [],
 
41
  },
42
  {
43
  "cell_type": "code",
44
+ "execution_count": 2,
45
  "id": "d56c105b",
46
  "metadata": {},
47
  "outputs": [
 
95
  },
96
  {
97
  "cell_type": "code",
98
+ "execution_count": 3,
99
  "id": "336bb40c",
100
  "metadata": {},
101
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  {
103
  "name": "stdout",
104
  "output_type": "stream",
105
  "text": [
106
  "torch.Size([2, 128, 11])\n",
 
 
107
  "MLM logits shape: (2, 128, 11)\n"
108
  ]
109
  }
 
118
  "# Example: human sequence\n",
119
  "seqs = [\"ATCGNATCG\", \"ACGT\"]\n",
120
  "batch = tok_pre(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
121
+ "out = model_pre(**batch)\n",
122
  "\n",
123
  "print(out.logits.shape) # (B, L, V = 11)\n",
 
 
124
  "\n",
125
  "# Access MLM logits\n",
126
  "mlm_logits = out[\"logits\"]\n",
 
136
  "\n",
137
  "Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
138
  "\n",
 
 
 
 
139
  "Expected outputs:\n",
140
  "- `bigwig_tracks_logits`: functional track predictions\n",
141
  "- `bed_tracks_logits`: genome annotation predictions\n",
 
144
  },
145
  {
146
  "cell_type": "code",
147
+ "execution_count": 4,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  "id": "6cc5f2df",
149
  "metadata": {},
150
  "outputs": [
 
152
  "name": "stdout",
153
  "output_type": "stream",
154
  "text": [
155
+ "Supported species: dict_keys(['<bos>', '<cls>', '<eos>', '<mask>', '<pad>', '<unk>', 'amphiprion_ocellaris', 'arabidopsis_thaliana', 'bison_bison_bison', 'caenorhabditis_elegans', 'canis_lupus_familiaris', 'chinchilla_lanigera', 'ciona_intestinalis', 'danio_rerio', 'drosophila_melanogaster', 'felis_catus', 'gallus_gallus', 'glycine_max', 'gorilla_gorilla', 'gossypium_hirsutum', 'human', 'macaca_nemestrina', 'mouse', 'oryza_sativa', 'rattus_norvegicus', 'salmo_trutta', 'serinus_canaria', 'tetraodon_nigroviridis', 'triticum_aestivum', 'zea_mays'])\n",
156
+ "bigwig_tracks_logits: (2, 48, 7362)\n",
157
+ "bed_tracks_logits: (2, 48, 21, 2)\n",
158
+ "language model logits: (2, 128, 11)\n"
159
  ]
160
  }
161
  ],
162
  "source": [
163
+ "# Load model\n",
164
+ "post_trained_model_name = \"InstaDeepAI/NTv3_100M_post\"\n",
165
+ "\n",
166
  "tok_post = AutoTokenizer.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
 
167
  "model_post = AutoModel.from_pretrained(post_trained_model_name, trust_remote_code=True)\n",
168
  "\n",
169
  "# Prepare inputs\n",
170
  "batch = tok_post([\"ATCGNATCG\", \"ACGT\"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
171
  "\n",
172
+ "# To show all supported species: \n",
173
+ "print(\"Supported species:\", model_post.config.species_to_token_id.keys())\n",
174
+ "# Species tokens\n",
175
+ "species = ['human', 'mouse']\n",
176
+ "species_ids = model_post.encode_species(species)\n",
177
  "\n",
178
  "# Forward pass\n",
179
  "out = model_post(\n",
180
  " input_ids=batch[\"input_ids\"],\n",
181
+ " species_ids=species_ids,\n",
 
182
  ")\n",
183
  "\n",
184
  "# 7k human tracks over 37.5 % center region of the input sequence\n",
 
188
  "# Language model logits for whole sequence over vocabulary\n",
189
  "print(\"language model logits:\", tuple(out[\"logits\"].shape))\n"
190
  ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "037076cd",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": []
199
  }
200
  ],
201
  "metadata": {
notebooks_tutorials/01_tracks_prediction.ipynb CHANGED
@@ -116,7 +116,7 @@
116
  "# -----------------------------\n",
117
  "# User inputs\n",
118
  "# -----------------------------\n",
119
- "model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/NTv3_100M_pos\" or \"InstaDeepAI/NTv3_650M_pos\"\n",
120
  "\n",
121
  "# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
122
  "species = \"human\" # will use for condition the model on species\n",
@@ -173,22 +173,19 @@
173
  },
174
  {
175
  "cell_type": "code",
176
- "execution_count": 6,
177
  "id": "e09f0469",
178
  "metadata": {},
179
  "outputs": [
180
  {
181
  "data": {
182
  "text/plain": [
183
- "NTv3Model(\n",
184
- " (core): Core(\n",
185
- " (embed_layer): Embedding(11, 16, padding_idx=1)\n",
186
  " (stem): Stem(\n",
187
  " (conv): Conv1d(16, 768, kernel_size=(15,), stride=(1,), padding=same)\n",
188
  " )\n",
189
- " (cond_tables): ModuleList(\n",
190
- " (0): Embedding(30, 16)\n",
191
- " )\n",
192
  " (conv_tower_blocks): ModuleList(\n",
193
  " (0-6): 7 x ConditionedConvTowerBlock(\n",
194
  " (conv): AdaptiveConvBlock(\n",
@@ -279,6 +276,16 @@
279
  " )\n",
280
  " )\n",
281
  " )\n",
 
 
 
 
 
 
 
 
 
 
282
  " (bigwig_head): MultiSpeciesHead(\n",
283
  " (species_heads): ModuleList(\n",
284
  " (0-4): 5 x ZeroHead()\n",
@@ -329,13 +336,6 @@
329
  " (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
330
  " (head): Linear(in_features=768, out_features=42, bias=True)\n",
331
  " )\n",
332
- " (conditions_heads): ModuleList(\n",
333
- " (0): Linear(in_features=768, out_features=30, bias=True)\n",
334
- " )\n",
335
- " (lm_head): ModuleDict(\n",
336
- " (hidden_layers): ModuleList()\n",
337
- " (head): Linear(in_features=768, out_features=11, bias=True)\n",
338
- " )\n",
339
  " )\n",
340
  ")"
341
  ]
@@ -347,24 +347,18 @@
347
  ],
348
  "source": [
349
  "# Load model\n",
350
- "cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
351
  "model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)\n",
352
  "\n",
353
  "# Load tokenizer\n",
354
  "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
355
  "\n",
356
- "# Load condition tokenizer\n",
357
- "species_tokenizer = AutoTokenizer.from_pretrained(\n",
358
- " model_name, subfolder=\"species_tokenizer\", trust_remote_code=True,\n",
359
- ")\n",
360
- "\n",
361
  "# Set model to evaluation mode\n",
362
  "model.eval()"
363
  ]
364
  },
365
  {
366
  "cell_type": "code",
367
- "execution_count": 7,
368
  "id": "43154959",
369
  "metadata": {},
370
  "outputs": [
@@ -372,15 +366,16 @@
372
  "name": "stdout",
373
  "output_type": "stream",
374
  "text": [
375
- "7362 functional tracks for hg38. First 10: ['kai1', 'kai2', 'kai3', 'kai4', 'kai5', 'kai6', 'kai7', 'kai8', 'kai10', 'kai9']\n",
376
  "Genomic elements predicted: ['protein_coding_gene', 'lncRNA', 'exon', 'intron', 'splice_donor', 'splice_acceptor', 'CTCF-bound', 'polyA_signal', 'enhancer_Tissue_specific', 'enhancer_Tissue_invariant', 'promoter_Tissue_specific', 'promoter_Tissue_invariant', '5UTR+', '5UTR-', '3UTR+', '3UTR-', 'skipped_exon', 'always_on_exon', 'start_codon', 'stop_codon', 'ORF']\n"
377
  ]
378
  }
379
  ],
380
  "source": [
381
  "# Inspect output functional tracks\n",
382
- "bigwig_names = cfg.bigwigs_per_file_assembly[assembly]\n",
383
- "print(f\"{len(bigwig_names)} functional tracks for {assembly}. First 10:\", bigwig_names[:10])\n",
 
384
  "\n",
385
  "# Inspect output genomic elements\n",
386
  "bed_element_names = cfg.bed_elements_names\n",
@@ -408,7 +403,7 @@
408
  },
409
  {
410
  "cell_type": "code",
411
- "execution_count": 8,
412
  "id": "6765a9b9",
413
  "metadata": {},
414
  "outputs": [
@@ -429,13 +424,12 @@
429
  "\n",
430
  "# Condition tokens (e.g., species)\n",
431
  "species = 'human'\n",
432
- "species_ids = species_tokenizer([species] * len(batch['input_ids']), add_special_tokens=False, return_tensors='pt')\n",
433
  "\n",
434
  "# Run inference\n",
435
  "out = model(\n",
436
  " input_ids=input_ids,\n",
437
- " species_ids=species_ids['input_ids'],\n",
438
- " return_dict=True\n",
439
  ")\n",
440
  "\n",
441
  "# 7k human tracks over 37.5 % center region of the input sequence\n",
@@ -465,7 +459,7 @@
465
  },
466
  {
467
  "cell_type": "code",
468
- "execution_count": 9,
469
  "id": "a26e9dcc",
470
  "metadata": {},
471
  "outputs": [],
@@ -482,7 +476,7 @@
482
  },
483
  {
484
  "cell_type": "code",
485
- "execution_count": 10,
486
  "id": "717539e2",
487
  "metadata": {},
488
  "outputs": [],
@@ -527,7 +521,7 @@
527
  },
528
  {
529
  "cell_type": "code",
530
- "execution_count": 12,
531
  "id": "7ba9a397",
532
  "metadata": {},
533
  "outputs": [
@@ -577,15 +571,6 @@
577
  "plot_tracks(all_tracks, prediction_start, prediction_end)\n",
578
  "plt.show()\n"
579
  ]
580
- },
581
- {
582
- "cell_type": "markdown",
583
- "id": "1ce34dc4",
584
- "metadata": {},
585
- "source": [
586
- "# 💡 To improve\n",
587
- "- Add gene annotation at top"
588
- ]
589
  }
590
  ],
591
  "metadata": {
 
116
  "# -----------------------------\n",
117
  "# User inputs\n",
118
  "# -----------------------------\n",
119
+ "model_name = \"InstaDeepAI/NTv3_100M_post\" # options: \"InstaDeepAI/NTv3_100M_post\" or \"InstaDeepAI/NTv3_650M_post\"\n",
120
  "\n",
121
  "# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
122
  "species = \"human\" # will use for condition the model on species\n",
 
173
  },
174
  {
175
  "cell_type": "code",
176
+ "execution_count": null,
177
  "id": "e09f0469",
178
  "metadata": {},
179
  "outputs": [
180
  {
181
  "data": {
182
  "text/plain": [
183
+ "NTv3PostTrained(\n",
184
+ " (core): NTv3PostTrainedCore(\n",
185
+ " (embed_layer): Embedding(11, 16)\n",
186
  " (stem): Stem(\n",
187
  " (conv): Conv1d(16, 768, kernel_size=(15,), stride=(1,), padding=same)\n",
188
  " )\n",
 
 
 
189
  " (conv_tower_blocks): ModuleList(\n",
190
  " (0-6): 7 x ConditionedConvTowerBlock(\n",
191
  " (conv): AdaptiveConvBlock(\n",
 
276
  " )\n",
277
  " )\n",
278
  " )\n",
279
+ " (lm_head): ModuleDict(\n",
280
+ " (hidden_layers): ModuleList()\n",
281
+ " (head): Linear(in_features=768, out_features=11, bias=True)\n",
282
+ " )\n",
283
+ " (cond_tables): ModuleList(\n",
284
+ " (0): Embedding(30, 16)\n",
285
+ " )\n",
286
+ " (conditions_heads): ModuleList(\n",
287
+ " (0): Linear(in_features=768, out_features=30, bias=True)\n",
288
+ " )\n",
289
  " (bigwig_head): MultiSpeciesHead(\n",
290
  " (species_heads): ModuleList(\n",
291
  " (0-4): 5 x ZeroHead()\n",
 
336
  " (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
337
  " (head): Linear(in_features=768, out_features=42, bias=True)\n",
338
  " )\n",
 
 
 
 
 
 
 
339
  " )\n",
340
  ")"
341
  ]
 
347
  ],
348
  "source": [
349
  "# Load model\n",
 
350
  "model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)\n",
351
  "\n",
352
  "# Load tokenizer\n",
353
  "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
354
  "\n",
 
 
 
 
 
355
  "# Set model to evaluation mode\n",
356
  "model.eval()"
357
  ]
358
  },
359
  {
360
  "cell_type": "code",
361
+ "execution_count": 10,
362
  "id": "43154959",
363
  "metadata": {},
364
  "outputs": [
 
366
  "name": "stdout",
367
  "output_type": "stream",
368
  "text": [
369
+ "7362 functional tracks for human. First 10: ['kai1', 'kai2', 'kai3', 'kai4', 'kai5', 'kai6', 'kai7', 'kai8', 'kai10', 'kai9']\n",
370
  "Genomic elements predicted: ['protein_coding_gene', 'lncRNA', 'exon', 'intron', 'splice_donor', 'splice_acceptor', 'CTCF-bound', 'polyA_signal', 'enhancer_Tissue_specific', 'enhancer_Tissue_invariant', 'promoter_Tissue_specific', 'promoter_Tissue_invariant', '5UTR+', '5UTR-', '3UTR+', '3UTR-', 'skipped_exon', 'always_on_exon', 'start_codon', 'stop_codon', 'ORF']\n"
371
  ]
372
  }
373
  ],
374
  "source": [
375
  "# Inspect output functional tracks\n",
376
+ "cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n",
377
+ "bigwig_names = cfg.bigwigs_per_species[species]\n",
378
+ "print(f\"{len(bigwig_names)} functional tracks for {species}. First 10:\", bigwig_names[:10])\n",
379
  "\n",
380
  "# Inspect output genomic elements\n",
381
  "bed_element_names = cfg.bed_elements_names\n",
 
403
  },
404
  {
405
  "cell_type": "code",
406
+ "execution_count": 11,
407
  "id": "6765a9b9",
408
  "metadata": {},
409
  "outputs": [
 
424
  "\n",
425
  "# Condition tokens (e.g., species)\n",
426
  "species = 'human'\n",
427
+ "species_ids = model.encode_species(species)\n",
428
  "\n",
429
  "# Run inference\n",
430
  "out = model(\n",
431
  " input_ids=input_ids,\n",
432
+ " species_ids=species_ids,\n",
 
433
  ")\n",
434
  "\n",
435
  "# 7k human tracks over 37.5 % center region of the input sequence\n",
 
459
  },
460
  {
461
  "cell_type": "code",
462
+ "execution_count": 12,
463
  "id": "a26e9dcc",
464
  "metadata": {},
465
  "outputs": [],
 
476
  },
477
  {
478
  "cell_type": "code",
479
+ "execution_count": 13,
480
  "id": "717539e2",
481
  "metadata": {},
482
  "outputs": [],
 
521
  },
522
  {
523
  "cell_type": "code",
524
+ "execution_count": 14,
525
  "id": "7ba9a397",
526
  "metadata": {},
527
  "outputs": [
 
571
  "plot_tracks(all_tracks, prediction_start, prediction_end)\n",
572
  "plt.show()\n"
573
  ]
 
 
 
 
 
 
 
 
 
574
  }
575
  ],
576
  "metadata": {
tabs/home.html CHANGED
@@ -125,16 +125,10 @@ tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
125
  batch = tok(["ATCGNATCG", "ACGT"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
126
 
127
  # Run model
128
- out = model(
129
- **batch,
130
- output_hidden_states=True,
131
- output_attentions=True
132
- )
133
 
134
  # Print output shapes
135
  print(out.logits.shape) # (B, L, V = 11)
136
- print(len(out.hidden_states)) # convs + transformers + deconvs
137
- print(len(out.attentions)) # equals transformer layers = 12
138
  </code></pre></div>
139
  <p>Model embeddings can be used for fine-tuning on downstream tasks.</p>
140
 
 
125
  batch = tok(["ATCGNATCG", "ACGT"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
126
 
127
  # Run model
128
+ out = model(**batch)
 
 
 
 
129
 
130
  # Print output shapes
131
  print(out.logits.shape) # (B, L, V = 11)
 
 
132
  </code></pre></div>
133
  <p>Model embeddings can be used for fine-tuning on downstream tasks.</p>
134