bernardo-de-almeida committed on
Commit 76bb74c · 1 Parent(s): fb89b3c

uniformize notebooks

notebooks/00_quickstart_inference.ipynb CHANGED
@@ -23,12 +23,10 @@
  },
  {
  "cell_type": "markdown",
- "id": "5d58bf1d",
+ "id": "5827af7e",
  "metadata": {},
  "source": [
- "## 0) ⚙️ Colab Setup (if running on Google Colab)\n",
- "\n",
- "This cell detects if you're running on Google Colab and sets up the environment accordingly."
+ "## 0) 📦 Imports + setup"
  ]
  },
  {
@@ -41,14 +39,6 @@
  "!pip -q install \"transformers>=4.40\" \"huggingface_hub>=0.23\" safetensors torch numpy"
  ]
  },
- {
- "cell_type": "markdown",
- "id": "5827af7e",
- "metadata": {},
- "source": [
- "## 1) 📦 Imports + setup"
- ]
- },
  {
  "cell_type": "code",
  "execution_count": 3,
@@ -95,7 +85,7 @@
  "id": "82146876",
  "metadata": {},
  "source": [
- "## 2) 🎯 Pre-trained checkpoint (MLM-focused)\n",
+ "## 1) 🎯 Pre-trained checkpoint (MLM-focused)\n",
  "\n",
  "This shows the simplest usage: load model + tokenizer, then run a forward pass.\n",
  "\n",
@@ -285,7 +275,7 @@
  "id": "60a01798",
  "metadata": {},
  "source": [
- "## 3) 🧠 Post-trained checkpoint (task heads: BigWig + BED)\n",
+ "## 2) 🧠 Post-trained checkpoint (task heads: BigWig + BED)\n",
  "\n",
  "Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
  "\n",
notebooks/01_tracks_prediction.ipynb CHANGED
@@ -35,21 +35,20 @@
  "- Supports the 24 species that NTv3 was post-trained on"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "id": "77046e68",
+ "metadata": {},
+ "source": [
+ "## 0) 📦 Imports + setup"
+ ]
+ },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
  "id": "0ff509fd",
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[33mWARNING: 401 Error, Credentials not correct for https://gitlab.com/api/v4/projects/36813343/packages/pypi/simple/seaborn/\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "# Install dependencies\n",
  "!pip -q install \"transformers>=4.55\" \"huggingface_hub>=0.23\" safetensors torch pyfaidx requests seaborn matplotlib"
@@ -57,7 +56,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 7,
+ "execution_count": 2,
  "id": "608d67e1",
  "metadata": {},
  "outputs": [],
@@ -76,35 +75,48 @@
  "import seaborn as sns"
  ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "2354e2aa",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "device: cpu dtype: torch.float16\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Device\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "dtype = torch.bfloat16 if (device == \"cuda\" and torch.cuda.get_device_capability(0)[0] >= 8) else torch.float16\n",
+ "print(\"device:\", device, \"dtype:\", dtype)"
+ ]
+ },
  {
  "cell_type": "markdown",
  "id": "19db4774",
  "metadata": {},
  "source": [
- "## 1) 📦 Imports + configuration\n",
+ "## 1) 📦 Configuration\n",
  "\n",
  "Set your NTv3 model and genomic window here"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": 8,
+ "execution_count": 4,
  "id": "795a576f",
  "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "window length: 131072\n"
- ]
- }
- ],
+ "outputs": [],
  "source": [
  "# -----------------------------\n",
  "# User inputs\n",
  "# -----------------------------\n",
- "model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/ntv3_106M_7downsample_post_trained_1mb\" or \"InstaDeepAI/ntv3_650M_7downsample_post_trained_1mb_v2\"\n",
+ "model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/NTv3_100M_pos\" or \"InstaDeepAI/NTv3_650M_pos\"\n",
  "\n",
  "# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
  "species = \"human\" # will use for condition the model on species\n",
@@ -114,40 +126,7 @@
  "end = 6_831_072\n",
  "\n",
  "# Optional\n",
- "HF_TOKEN = os.getenv(\"HF_TOKEN\", None)\n",
- "\n",
- "assert end > start, \"end must be > start\"\n",
- "window_len = end - start\n",
- "assert window_len % 128 == 0, f\"window length ({window_len}) must be a multiple of 128\"\n",
- "print(\"window length:\", window_len)\n",
- "\n",
- "# Simple DNA sanitization\n",
- "DNA_RE = re.compile(r\"^[ACGTNacgtn]+$\")\n",
- "def sanitize_dna(seq: str) -> str:\n",
- " seq = seq.upper()\n",
- " seq = re.sub(r\"[^ACGTN]\", \"N\", seq)\n",
- " return seq"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "2354e2aa",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "device: cpu dtype: torch.float16\n"
- ]
- }
- ],
- "source": [
- "# Device\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "dtype = torch.bfloat16 if (device == \"cuda\" and torch.cuda.get_device_capability(0)[0] >= 8) else torch.float16\n",
- "print(\"device:\", device, \"dtype:\", dtype)"
+ "HF_TOKEN = os.getenv(\"HF_TOKEN\", None)"
  ]
  },
  {
@@ -160,62 +139,28 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
- "id": "8c20066a",
+ "execution_count": 5,
+ "id": "2e0026e4",
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Downloading: https://hgdownload.soe.ucsc.edu/goldenPath/hg38/chromosomes/chr19.fa.gz\n",
- "Using downloaded chromosome FASTA: ./hg38/chr19.fa\n",
- "Sequence preview: GTCAACAATAACAAATGACATATTAGTAGTAAATTATAATTATACATTACAACAAAATTA...\n",
- "Valid DNA: True\n"
+ "Original sequence length: 131072\n",
+ "Cropped sequence length: 131072, 1024.0 transformer tokens\n"
  ]
  }
  ],
  "source": [
- "def download_ucsc_chrom_fasta(chrom: str, assembly: str, out_dir: str = f\"./{assembly}\") -> str:\n",
- " \"\"\"Download a single chromosome FASTA from UCSC and return local path.\"\"\"\n",
- " os.makedirs(out_dir, exist_ok=True)\n",
- " gz_path = os.path.join(out_dir, f\"{chrom}.fa.gz\")\n",
- " fa_path = os.path.join(out_dir, f\"{chrom}.fa\")\n",
- "\n",
- " if os.path.exists(fa_path):\n",
- " return fa_path\n",
- "\n",
- " # UCSC chrom fasta (chromFa/)\n",
- " url = f\"https://hgdownload.soe.ucsc.edu/goldenPath/{assembly}/chromosomes/{chrom}.fa.gz\"\n",
- " print(\"Downloading:\", url)\n",
- " r = requests.get(url, stream=True)\n",
- " r.raise_for_status()\n",
- " with open(gz_path, \"wb\") as f:\n",
- " for chunk in r.iter_content(chunk_size=1024 * 1024):\n",
- " if chunk:\n",
- " f.write(chunk)\n",
- "\n",
- " # Decompress\n",
- " import gzip\n",
- " with gzip.open(gz_path, \"rb\") as fin, open(fa_path, \"wb\") as fout:\n",
- " fout.write(fin.read())\n",
- "\n",
- " return fa_path\n",
- "\n",
- "def fetch_window_sequence(chrom: str, start: int, end: int, fasta_path: str = \"\") -> str:\n",
- " \"\"\"Fetch [start,end) sequence from fasta. If fasta_path is a whole genome file, chrom must match record name.\"\"\"\n",
- " fasta = Fasta(fasta_path, rebuild=True)\n",
- " seq = fasta[chrom][start:end].seq\n",
- " return sanitize_dna(seq)\n",
- "\n",
- "# Download chromosome\n",
- "fasta_path = download_ucsc_chrom_fasta(chrom, assembly)\n",
- "print(\"Using downloaded chromosome FASTA:\", fasta_path)\n",
- "\n",
- "seq = fetch_window_sequence(chrom, start, end, fasta_path=fasta_path)\n",
- "print(\"Sequence preview:\", seq[:60] + (\"...\" if len(seq) > 60 else \"\"))\n",
- "print(\"Valid DNA:\", bool(DNA_RE.match(seq)))\n",
- "assert len(seq) == (end - start), \"Fetched sequence length mismatch\""
+ "# Get the sequence from the UCSC API\n",
+ "url = f\"https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}\"\n",
+ "seq = requests.get(url).json()[\"dna\"].upper()\n",
+ "print(f\"Original sequence length: {len(seq)}\")\n",
+ "\n",
+ "# Crop to multiple of 128 (the pipeline will crop again, but this is a no-op once divisible)\n",
+ "seq = seq[:int(len(seq) // 128) * 128]\n",
+ "print(f\"Cropped sequence length: {len(seq)}, {len(seq) / 128} transformer tokens\")"
  ]
  },
  {
@@ -228,7 +173,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 11,
+ "execution_count": 6,
  "id": "e09f0469",
  "metadata": {},
  "outputs": [
@@ -395,7 +340,7 @@
  ")"
  ]
  },
- "execution_count": 11,
+ "execution_count": 6,
  "metadata": {},
  "output_type": "execute_result"
  }
@@ -419,7 +364,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 7,
  "id": "43154959",
  "metadata": {},
  "outputs": [
@@ -463,7 +408,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 13,
+ "execution_count": 8,
  "id": "6765a9b9",
  "metadata": {},
  "outputs": [
@@ -520,7 +465,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 14,
+ "execution_count": 9,
  "id": "a26e9dcc",
  "metadata": {},
  "outputs": [],
@@ -537,7 +482,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 15,
+ "execution_count": 10,
  "id": "717539e2",
  "metadata": {},
  "outputs": [],
@@ -582,7 +527,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 16,
+ "execution_count": 12,
  "id": "7ba9a397",
  "metadata": {},
  "outputs": [
@@ -620,6 +565,7 @@
  "\n",
  "# Model predicts for middle 37.5% of input sequence\n",
  "# So predictions start at: start + (window_len - window_len * 0.375) / 2 = start + window_len * 0.3125\n",
+ "window_len = end - start\n",
  "prediction_start = start + int(window_len * 0.3125)\n",
  "prediction_end = prediction_start + int(window_len * 0.375)\n",
  "x = np.arange(prediction_start, prediction_end)\n",
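The plotting cell in this notebook derives the prediction interval from the fact that the model predicts only the middle 37.5% of the input window, so the margin on each side is (1 − 0.375) / 2 = 0.3125 of the window length. A standalone sketch of that arithmetic, using the example coordinates from the notebook (this snippet is illustrative only, not part of the diff):

```python
# Sanity-check the prediction-window arithmetic from the plotting cell.
# The model predicts for the middle 37.5% of the input window, so the
# margin on each side is (1 - 0.375) / 2 = 0.3125 of the window length.
start, end = 6_700_000, 6_831_072  # example window from the notebook
window_len = end - start
assert window_len % 128 == 0, "window length must be a multiple of 128"

prediction_start = start + int(window_len * 0.3125)
prediction_end = prediction_start + int(window_len * 0.375)

print(window_len, prediction_start, prediction_end)  # → 131072 6740960 6790112
# The predicted interval is centered: equal margins on both sides.
print(prediction_start - start == end - prediction_end)  # → True
```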
notebooks/02_genome_annotation.ipynb CHANGED
@@ -11,19 +11,17 @@
  "\n",
  "The pipeline abstracts away all the underlying steps: running inference with the model, retrieving and processing the predicted probabilities, and applying the HMM to generate a consistent annotation. It returns a ready-to-use GFF file that can be visualized in any genome browser for the sequence of interest.\n",
  "\n",
- "If you’re interested in exploring the intermediate probabilities, please refer to the track-prediction notebooks. These probabilities can be useful for assessing model confidence and identifying potentially interesting biological regions. This notebook focuses on the higher-level task of producing gene annotations directly from raw DNA.\n",
+ "If you’re interested in exploring the intermediate probabilities, please refer to the [track-prediction notebook](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/01_tracks_prediction.ipynb). These probabilities can be useful for assessing model confidence and identifying potentially interesting biological regions. This notebook focuses on the higher-level task of producing gene annotations directly from raw DNA.\n",
  "\n",
  "> 📝 **Note for Google Colab users:** This notebook is compatible with Colab! For faster inference, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended)."
  ]
  },
  {
  "cell_type": "markdown",
- "id": "71fac239",
+ "id": "94c46695",
  "metadata": {},
  "source": [
- "## 0) Colab Setup (if running on Google Colab)\n",
- "\n",
- "This cell detects if you're running on Google Colab and sets up the environment accordingly."
+ "## 0) 📦 Imports + setup"
  ]
  },
  {
@@ -37,16 +35,6 @@
  "!pip -q install \"transformers>=4.55\" \"huggingface_hub>=0.23\" safetensors torch pyfaidx requests seaborn matplotlib igv_notebook"
  ]
  },
- {
- "cell_type": "markdown",
- "id": "36d32e97",
- "metadata": {},
- "source": [
- "## 1) 📦 Imports + configuration\n",
- "\n",
- "Set your NTv3 model and genomic window here"
- ]
- },
  {
  "cell_type": "code",
  "execution_count": null,
@@ -61,6 +49,16 @@
  "from transformers import pipeline"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "id": "9d29bb77",
+ "metadata": {},
+ "source": [
+ "## 1) 📦 Configuration\n",
+ "\n",
+ "Set your NTv3 model and genomic window here"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": null,
@@ -69,7 +67,7 @@
  "outputs": [],
  "source": [
  "# Define the model and genomic window\n",
- "model_name = \"InstaDeepAI/NTv3_650M\"\n",
+ "model_name = \"InstaDeepAI/NTv3_650M_pos\"\n",
  "assembly = \"hg38\"\n",
  "chrom = \"chr19\"\n",
  "start = 6_700_000\n",
@@ -98,7 +96,7 @@
  "\n",
  "# Crop to multiple of 128 (the pipeline will crop again, but this is a no-op once divisible)\n",
  "seq = seq[:int(len(seq) // 128) * 128]\n",
- "print(f\"Cropped sequence length: {len(seq)}, {len(seq) / 128} tokens\")"
+ "print(f\"Cropped sequence length: {len(seq)}, {len(seq) / 128} transformer tokens\")"
  ]
  },
  {
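Both the tracks-prediction and genome-annotation notebooks crop the fetched sequence with `seq[:int(len(seq) // 128) * 128]` before inference, since the model downsamples 128 bp into one transformer token. A minimal sketch of that crop in isolation (the helper name is hypothetical, not part of the notebooks):

```python
def crop_to_multiple_of_128(seq: str) -> str:
    # Hypothetical helper wrapping the notebooks' crop expression:
    # drop trailing bases so the length is divisible by 128
    # (the model turns each 128 bp into one transformer token).
    return seq[: (len(seq) // 128) * 128]

print(len(crop_to_multiple_of_128("ACGT" * 33)))  # 132 bp → 128
print(len(crop_to_multiple_of_128("ACGT" * 32)))  # already divisible → 128
```

Once the length is divisible by 128 the crop is a no-op, which is why the pipeline can safely apply it again.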