bernardo-de-almeida committed on
Commit
3367165
·
1 Parent(s): a10b560

feat: add inference and track prediction notebooks

Browse files
notebooks/00_quickstart_inference.ipynb CHANGED
@@ -2,51 +2,243 @@
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
 
5
  "metadata": {},
6
  "source": [
7
- "# NTv3 Quickstart Inference\n",
8
  "\n",
9
- "This notebook demonstrates how to load and run inference with NTv3 models.\n"
 
 
 
 
 
 
 
 
 
10
  ]
11
  },
12
  {
13
  "cell_type": "markdown",
 
14
  "metadata": {},
15
  "source": [
16
- "## Install Dependencies"
 
 
 
 
 
 
 
 
 
 
 
 
17
  ]
18
  },
19
  {
20
  "cell_type": "markdown",
 
 
 
 
 
 
 
 
 
 
21
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
22
  "source": [
23
- "## Load Model and Tokenizer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ]
25
  },
26
  {
27
  "cell_type": "markdown",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  "source": [
30
- "## Run Inference"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  ]
32
  },
33
  {
34
  "cell_type": "markdown",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
 
36
  "source": [
37
- "## Next Steps\n",
38
  "\n",
39
- "- Try different sequences and models\n",
40
- "- Explore model outputs\n",
41
- "- Check out other notebooks for tracks prediction, annotation, and more\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  ]
43
  }
44
  ],
45
  "metadata": {
 
 
 
 
 
46
  "language_info": {
47
- "name": "python"
 
 
 
 
 
 
 
 
 
48
  }
49
  },
50
  "nbformat": 4,
51
- "nbformat_minor": 2
52
  }
 
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
+ "id": "024bb8a8",
6
  "metadata": {},
7
  "source": [
8
+ "# NTv3 Quickstart — Pre-trained and Post-trained models\n",
9
  "\n",
10
+ "This notebook demonstrates how to run **quick inference** with both pre- and post-trained NTv3 checkpoints:\n",
11
+ "\n",
12
+ "- **Pre-trained (MLM-focused):** `InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb`, `InstaDeepAI/ntv3_106M_7downsample_pretrained_le_1mb`, `InstaDeepAI/ntv3_650M_7downsample_pretrained_le_1mb`\n",
13
+ "- **Post-trained (task heads):** `InstaDeepAI/ntv3_106M_7downsample_post_trained_1mb`, `InstaDeepAI/ntv3_650M_7downsample_post_trained_1mb`\n",
14
+ "\n",
15
+ "We show how to:\n",
16
+ "\n",
17
+ "1. Load tokenizers + models\n",
18
+ "2. Run a forward pass on a DNA sequence window\n",
19
+ "3. Inspect key outputs"
20
  ]
21
  },
22
  {
23
  "cell_type": "markdown",
24
+ "id": "5d58bf1d",
25
  "metadata": {},
26
  "source": [
27
+ "## 0) Install dependencies\n",
28
+ "\n",
29
+ "Skip if already installed."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "38cc32a9",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "!pip -q install \"transformers>=4.40\" \"huggingface_hub>=0.23\" safetensors torch numpy"
40
  ]
41
  },
42
  {
43
  "cell_type": "markdown",
44
+ "id": "5827af7e",
45
+ "metadata": {},
46
+ "source": [
47
+ "## 1) Imports + setup"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 7,
53
+ "id": "d56c105b",
54
  "metadata": {},
55
+ "outputs": [
56
+ {
57
+ "name": "stdout",
58
+ "output_type": "stream",
59
+ "text": [
60
+ "device: cpu\n",
61
+ "torch_dtype: torch.float32\n"
62
+ ]
63
+ }
64
+ ],
65
  "source": [
66
+ "import os\n",
67
+ "import torch\n",
68
+ "import numpy as np\n",
69
+ "\n",
70
+ "from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForMaskedLM\n",
71
+ "\n",
72
+ "# Optional: if the model is gated/private, set HF_TOKEN to a PERSONAL token (hf_...)\n",
73
+ "HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n",
74
+ "\n",
75
+ "# -----------------------------\n",
76
+ "# Device\n",
77
+ "# -----------------------------\n",
78
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
79
+ "print(\"device:\", device)\n",
80
+ "\n",
81
+ "# Choose dtype (bf16 if supported; else fp16 on GPU; else fp32)\n",
82
+ "if device == \"cuda\":\n",
83
+ " major, minor = torch.cuda.get_device_capability(0)\n",
84
+ " torch_dtype = torch.bfloat16 if major >= 8 else torch.float16\n",
85
+ "else:\n",
86
+ " torch_dtype = torch.float32\n",
87
+ "\n",
88
+ "print(\"torch_dtype:\", torch_dtype)"
89
  ]
90
  },
91
  {
92
  "cell_type": "markdown",
93
+ "id": "82146876",
94
+ "metadata": {},
95
+ "source": [
96
+ "## 2) Pre-trained checkpoint (MLM-focused)\n",
97
+ "\n",
98
+ "This shows the simplest usage: load model + tokenizer, then run a forward pass.\n",
99
+ "\n",
100
+ "Expected output:\n",
101
+ "- `logits`: masked language modeling logits"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "336bb40c",
108
  "metadata": {},
109
+ "outputs": [
110
+ {
111
+ "name": "stdout",
112
+ "output_type": "stream",
113
+ "text": [
114
+ "torch.Size([2, 128, 11])\n",
115
+ "16\n",
116
+ "2\n",
117
+ "MLM logits shape: (2, 128, 11)\n"
118
+ ]
119
+ },
120
+ {
121
+ "name": "stderr",
122
+ "output_type": "stream",
123
+ "text": [
124
+ "/opt/anaconda3/envs/hf-finetune/lib/python3.10/site-packages/torch/amp/autocast_mode.py:283: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n",
125
+ "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n",
126
+ " warnings.warn(error_message)\n"
127
+ ]
128
+ }
129
+ ],
130
  "source": [
131
+ "pretrained_model_name = \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\"\n",
132
+ "\n",
133
+ "# Load tokenizer/model\n",
134
+ "tok_pre = AutoTokenizer.from_pretrained(pretrained_model_name, trust_remote_code=True)\n",
135
+ "model_pre = AutoModelForMaskedLM.from_pretrained(pretrained_model_name, trust_remote_code=True)\n",
136
+ "\n",
137
+ "# Example: two short toy DNA sequences, padded to a multiple of 128 tokens\n",
138
+ "seqs = [\"ATCGNATCG\", \"ACGT\"]\n",
139
+ "batch = tok_pre(seqs, add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors=\"pt\")\n",
140
+ "out = model_pre(**batch, output_hidden_states=True, output_attentions=True)\n",
141
+ "\n",
142
+ "print(out.logits.shape) # (B, L, V = 11)\n",
143
+ "print(len(out.hidden_states)) # convs + transformers + deconvs\n",
144
+ "print(len(out.attentions))\n",
145
+ "\n",
146
+ "# Access MLM logits\n",
147
+ "mlm_logits = out[\"logits\"]\n",
148
+ "print(\"MLM logits shape:\", tuple(mlm_logits.shape))"
149
  ]
150
  },
151
  {
152
  "cell_type": "markdown",
153
+ "id": "60a01798",
154
+ "metadata": {},
155
+ "source": [
156
+ "## 3) Post-trained checkpoint (task heads: BigWig + BED)\n",
157
+ "\n",
158
+ "Post-trained checkpoints add task-specific heads.\n",
159
+ "\n",
160
+ "In particular:\n",
161
+ "- `condition_tokenizer` is used to tokenize a species condition like `\"human\"`\n",
162
+ "- `file_assembly_idx` selects the assembly (e.g., `hg38`)\n",
163
+ "\n",
164
+ "Expected outputs:\n",
165
+ "- `bigwig_tracks_logits`\n",
166
+ "- `bed_tracks_logits`\n",
167
+ "- `logits` (MLM)\n",
168
+ "\n",
169
+ "> If your post-trained checkpoint supports multiple assemblies, the config typically exposes a mapping like `cfg.bigwigs_per_file_assembly`."
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "id": "6cc5f2df",
176
  "metadata": {},
177
+ "outputs": [
178
+ {
179
+ "name": "stdout",
180
+ "output_type": "stream",
181
+ "text": [
182
+ "torch.Size([1, 768, 7362])\n",
183
+ "torch.Size([1, 768, 21, 2])\n",
184
+ "torch.Size([1, 2048, 11])\n"
185
+ ]
186
+ }
187
+ ],
188
  "source": [
189
+ "posttrained_model_name = \"InstaDeepAI/ntv3_106M_7downsample_post_trained_1mb\"\n",
190
  "\n",
191
+ "# Load config/tokenizers/model\n",
192
+ "cfg_pos = AutoConfig.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
193
+ "tok_pos = AutoTokenizer.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
194
+ "model_pos = AutoModel.from_pretrained(posttrained_model_name, trust_remote_code=True)\n",
195
+ "condition_tokenizer = AutoTokenizer.from_pretrained(\n",
196
+ " posttrained_model_name, subfolder=\"condition_tokenizer\", trust_remote_code=True\n",
197
+ ")\n",
198
+ "\n",
199
+ "# Example: human sequence (its length must be a multiple of 128 because of the model's 7 downsampling steps)\n",
200
+ "seq = \"ATCG\" * 512\n",
201
+ "batch = tok_pos([seq], add_special_tokens=False, return_tensors=\"pt\")\n",
202
+ "condition = condition_tokenizer([\"human\"], return_tensors=\"pt\")\n",
203
+ "\n",
204
+ "# Get assembly index for human (hg38)\n",
205
+ "assemblies = list(cfg_pos.bigwigs_per_file_assembly.keys())\n",
206
+ "assembly_idx = torch.tensor([assemblies.index(\"hg38\")])\n",
207
+ "\n",
208
+ "out = model_pos(\n",
209
+ " input_ids=batch[\"input_ids\"],\n",
210
+ " condition_ids=[condition[\"input_ids\"][0]],\n",
211
+ " file_assembly_idx=assembly_idx,\n",
212
+ " output_hidden_states=True,\n",
213
+ " output_attentions=True,\n",
214
+ ")\n",
215
+ "\n",
216
+ "# Access model outputs\n",
217
+ "print(out[\"bigwig_tracks_logits\"].shape) # per-assembly functional track predictions\n",
218
+ "print(out[\"bed_tracks_logits\"].shape) # genomic element classifications\n",
219
+ "print(out[\"logits\"].shape) # masked LM logits"
220
  ]
221
  }
222
  ],
223
  "metadata": {
224
+ "kernelspec": {
225
+ "display_name": "hf-finetune",
226
+ "language": "python",
227
+ "name": "python3"
228
+ },
229
  "language_info": {
230
+ "codemirror_mode": {
231
+ "name": "ipython",
232
+ "version": 3
233
+ },
234
+ "file_extension": ".py",
235
+ "mimetype": "text/x-python",
236
+ "name": "python",
237
+ "nbconvert_exporter": "python",
238
+ "pygments_lexer": "ipython3",
239
+ "version": "3.10.18"
240
  }
241
  },
242
  "nbformat": 4,
243
+ "nbformat_minor": 5
244
  }
notebooks/01_tracks_prediction.ipynb ADDED
The diff for this file is too large to render. See raw diff