Spaces:
Running
Running
Commit
·
76bb74c
1
Parent(s):
fb89b3c
uniformize notebooks
Browse files
notebooks/00_quickstart_inference.ipynb
CHANGED
|
@@ -23,12 +23,10 @@
|
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "markdown",
|
| 26 |
-
"id": "
|
| 27 |
"metadata": {},
|
| 28 |
"source": [
|
| 29 |
-
"## 0)
|
| 30 |
-
"\n",
|
| 31 |
-
"This cell detects if you're running on Google Colab and sets up the environment accordingly."
|
| 32 |
]
|
| 33 |
},
|
| 34 |
{
|
|
@@ -41,14 +39,6 @@
|
|
| 41 |
"!pip -q install \"transformers>=4.40\" \"huggingface_hub>=0.23\" safetensors torch numpy"
|
| 42 |
]
|
| 43 |
},
|
| 44 |
-
{
|
| 45 |
-
"cell_type": "markdown",
|
| 46 |
-
"id": "5827af7e",
|
| 47 |
-
"metadata": {},
|
| 48 |
-
"source": [
|
| 49 |
-
"## 1) 📦 Imports + setup"
|
| 50 |
-
]
|
| 51 |
-
},
|
| 52 |
{
|
| 53 |
"cell_type": "code",
|
| 54 |
"execution_count": 3,
|
|
@@ -95,7 +85,7 @@
|
|
| 95 |
"id": "82146876",
|
| 96 |
"metadata": {},
|
| 97 |
"source": [
|
| 98 |
-
"##
|
| 99 |
"\n",
|
| 100 |
"This shows the simplest usage: load model + tokenizer, then run a forward pass.\n",
|
| 101 |
"\n",
|
|
@@ -285,7 +275,7 @@
|
|
| 285 |
"id": "60a01798",
|
| 286 |
"metadata": {},
|
| 287 |
"source": [
|
| 288 |
-
"##
|
| 289 |
"\n",
|
| 290 |
"Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
|
| 291 |
"\n",
|
|
|
|
| 23 |
},
|
| 24 |
{
|
| 25 |
"cell_type": "markdown",
|
| 26 |
+
"id": "5827af7e",
|
| 27 |
"metadata": {},
|
| 28 |
"source": [
|
| 29 |
+
"## 0) 📦 Imports + setup"
|
|
|
|
|
|
|
| 30 |
]
|
| 31 |
},
|
| 32 |
{
|
|
|
|
| 39 |
"!pip -q install \"transformers>=4.40\" \"huggingface_hub>=0.23\" safetensors torch numpy"
|
| 40 |
]
|
| 41 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
{
|
| 43 |
"cell_type": "code",
|
| 44 |
"execution_count": 3,
|
|
|
|
| 85 |
"id": "82146876",
|
| 86 |
"metadata": {},
|
| 87 |
"source": [
|
| 88 |
+
"## 1) 🎯 Pre-trained checkpoint (MLM-focused)\n",
|
| 89 |
"\n",
|
| 90 |
"This shows the simplest usage: load model + tokenizer, then run a forward pass.\n",
|
| 91 |
"\n",
|
|
|
|
| 275 |
"id": "60a01798",
|
| 276 |
"metadata": {},
|
| 277 |
"source": [
|
| 278 |
+
"## 2) 🧠 Post-trained checkpoint (task heads: BigWig + BED)\n",
|
| 279 |
"\n",
|
| 280 |
"Post-trained checkpoints add task-specific heads for functional track prediction and genome annotation.\n",
|
| 281 |
"\n",
|
notebooks/01_tracks_prediction.ipynb
CHANGED
|
@@ -35,21 +35,20 @@
|
|
| 35 |
"- Supports the 24 species that NTv3 was post-trained on"
|
| 36 |
]
|
| 37 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
-
"execution_count":
|
| 41 |
"id": "0ff509fd",
|
| 42 |
"metadata": {},
|
| 43 |
-
"outputs": [
|
| 44 |
-
{
|
| 45 |
-
"name": "stdout",
|
| 46 |
-
"output_type": "stream",
|
| 47 |
-
"text": [
|
| 48 |
-
"\u001b[33mWARNING: 401 Error, Credentials not correct for https://gitlab.com/api/v4/projects/36813343/packages/pypi/simple/seaborn/\u001b[0m\u001b[33m\n",
|
| 49 |
-
"\u001b[0m"
|
| 50 |
-
]
|
| 51 |
-
}
|
| 52 |
-
],
|
| 53 |
"source": [
|
| 54 |
"# Install dependencies\n",
|
| 55 |
"!pip -q install \"transformers>=4.55\" \"huggingface_hub>=0.23\" safetensors torch pyfaidx requests seaborn matplotlib"
|
|
@@ -57,7 +56,7 @@
|
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
-
"execution_count":
|
| 61 |
"id": "608d67e1",
|
| 62 |
"metadata": {},
|
| 63 |
"outputs": [],
|
|
@@ -76,35 +75,48 @@
|
|
| 76 |
"import seaborn as sns"
|
| 77 |
]
|
| 78 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
{
|
| 80 |
"cell_type": "markdown",
|
| 81 |
"id": "19db4774",
|
| 82 |
"metadata": {},
|
| 83 |
"source": [
|
| 84 |
-
"## 1) 📦
|
| 85 |
"\n",
|
| 86 |
"Set your NTv3 model and genomic window here"
|
| 87 |
]
|
| 88 |
},
|
| 89 |
{
|
| 90 |
"cell_type": "code",
|
| 91 |
-
"execution_count":
|
| 92 |
"id": "795a576f",
|
| 93 |
"metadata": {},
|
| 94 |
-
"outputs": [
|
| 95 |
-
{
|
| 96 |
-
"name": "stdout",
|
| 97 |
-
"output_type": "stream",
|
| 98 |
-
"text": [
|
| 99 |
-
"window length: 131072\n"
|
| 100 |
-
]
|
| 101 |
-
}
|
| 102 |
-
],
|
| 103 |
"source": [
|
| 104 |
"# -----------------------------\n",
|
| 105 |
"# User inputs\n",
|
| 106 |
"# -----------------------------\n",
|
| 107 |
-
"model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/
|
| 108 |
"\n",
|
| 109 |
"# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
|
| 110 |
"species = \"human\" # will use for condition the model on species\n",
|
|
@@ -114,40 +126,7 @@
|
|
| 114 |
"end = 6_831_072\n",
|
| 115 |
"\n",
|
| 116 |
"# Optional\n",
|
| 117 |
-
"HF_TOKEN = os.getenv(\"HF_TOKEN\", None)
|
| 118 |
-
"\n",
|
| 119 |
-
"assert end > start, \"end must be > start\"\n",
|
| 120 |
-
"window_len = end - start\n",
|
| 121 |
-
"assert window_len % 128 == 0, f\"window length ({window_len}) must be a multiple of 128\"\n",
|
| 122 |
-
"print(\"window length:\", window_len)\n",
|
| 123 |
-
"\n",
|
| 124 |
-
"# Simple DNA sanitization\n",
|
| 125 |
-
"DNA_RE = re.compile(r\"^[ACGTNacgtn]+$\")\n",
|
| 126 |
-
"def sanitize_dna(seq: str) -> str:\n",
|
| 127 |
-
" seq = seq.upper()\n",
|
| 128 |
-
" seq = re.sub(r\"[^ACGTN]\", \"N\", seq)\n",
|
| 129 |
-
" return seq"
|
| 130 |
-
]
|
| 131 |
-
},
|
| 132 |
-
{
|
| 133 |
-
"cell_type": "code",
|
| 134 |
-
"execution_count": 4,
|
| 135 |
-
"id": "2354e2aa",
|
| 136 |
-
"metadata": {},
|
| 137 |
-
"outputs": [
|
| 138 |
-
{
|
| 139 |
-
"name": "stdout",
|
| 140 |
-
"output_type": "stream",
|
| 141 |
-
"text": [
|
| 142 |
-
"device: cpu dtype: torch.float16\n"
|
| 143 |
-
]
|
| 144 |
-
}
|
| 145 |
-
],
|
| 146 |
-
"source": [
|
| 147 |
-
"# Device\n",
|
| 148 |
-
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 149 |
-
"dtype = torch.bfloat16 if (device == \"cuda\" and torch.cuda.get_device_capability(0)[0] >= 8) else torch.float16\n",
|
| 150 |
-
"print(\"device:\", device, \"dtype:\", dtype)"
|
| 151 |
]
|
| 152 |
},
|
| 153 |
{
|
|
@@ -160,62 +139,28 @@
|
|
| 160 |
},
|
| 161 |
{
|
| 162 |
"cell_type": "code",
|
| 163 |
-
"execution_count":
|
| 164 |
-
"id": "
|
| 165 |
"metadata": {},
|
| 166 |
"outputs": [
|
| 167 |
{
|
| 168 |
"name": "stdout",
|
| 169 |
"output_type": "stream",
|
| 170 |
"text": [
|
| 171 |
-
"
|
| 172 |
-
"
|
| 173 |
-
"Sequence preview: GTCAACAATAACAAATGACATATTAGTAGTAAATTATAATTATACATTACAACAAAATTA...\n",
|
| 174 |
-
"Valid DNA: True\n"
|
| 175 |
]
|
| 176 |
}
|
| 177 |
],
|
| 178 |
"source": [
|
| 179 |
-
"
|
| 180 |
-
"
|
| 181 |
-
"
|
| 182 |
-
"
|
| 183 |
-
"
|
| 184 |
-
"\n",
|
| 185 |
-
"
|
| 186 |
-
"
|
| 187 |
-
"\n",
|
| 188 |
-
" # UCSC chrom fasta (chromFa/)\n",
|
| 189 |
-
" url = f\"https://hgdownload.soe.ucsc.edu/goldenPath/{assembly}/chromosomes/{chrom}.fa.gz\"\n",
|
| 190 |
-
" print(\"Downloading:\", url)\n",
|
| 191 |
-
" r = requests.get(url, stream=True)\n",
|
| 192 |
-
" r.raise_for_status()\n",
|
| 193 |
-
" with open(gz_path, \"wb\") as f:\n",
|
| 194 |
-
" for chunk in r.iter_content(chunk_size=1024 * 1024):\n",
|
| 195 |
-
" if chunk:\n",
|
| 196 |
-
" f.write(chunk)\n",
|
| 197 |
-
"\n",
|
| 198 |
-
" # Decompress\n",
|
| 199 |
-
" import gzip\n",
|
| 200 |
-
" with gzip.open(gz_path, \"rb\") as fin, open(fa_path, \"wb\") as fout:\n",
|
| 201 |
-
" fout.write(fin.read())\n",
|
| 202 |
-
"\n",
|
| 203 |
-
" return fa_path\n",
|
| 204 |
-
"\n",
|
| 205 |
-
"def fetch_window_sequence(chrom: str, start: int, end: int, fasta_path: str = \"\") -> str:\n",
|
| 206 |
-
" \"\"\"Fetch [start,end) sequence from fasta. If fasta_path is a whole genome file, chrom must match record name.\"\"\"\n",
|
| 207 |
-
" fasta = Fasta(fasta_path, rebuild=True)\n",
|
| 208 |
-
" seq = fasta[chrom][start:end].seq\n",
|
| 209 |
-
" return sanitize_dna(seq)\n",
|
| 210 |
-
"\n",
|
| 211 |
-
"# Download chromosome\n",
|
| 212 |
-
"fasta_path = download_ucsc_chrom_fasta(chrom, assembly)\n",
|
| 213 |
-
"print(\"Using downloaded chromosome FASTA:\", fasta_path)\n",
|
| 214 |
-
"\n",
|
| 215 |
-
"seq = fetch_window_sequence(chrom, start, end, fasta_path=fasta_path)\n",
|
| 216 |
-
"print(\"Sequence preview:\", seq[:60] + (\"...\" if len(seq) > 60 else \"\"))\n",
|
| 217 |
-
"print(\"Valid DNA:\", bool(DNA_RE.match(seq)))\n",
|
| 218 |
-
"assert len(seq) == (end - start), \"Fetched sequence length mismatch\""
|
| 219 |
]
|
| 220 |
},
|
| 221 |
{
|
|
@@ -228,7 +173,7 @@
|
|
| 228 |
},
|
| 229 |
{
|
| 230 |
"cell_type": "code",
|
| 231 |
-
"execution_count":
|
| 232 |
"id": "e09f0469",
|
| 233 |
"metadata": {},
|
| 234 |
"outputs": [
|
|
@@ -395,7 +340,7 @@
|
|
| 395 |
")"
|
| 396 |
]
|
| 397 |
},
|
| 398 |
-
"execution_count":
|
| 399 |
"metadata": {},
|
| 400 |
"output_type": "execute_result"
|
| 401 |
}
|
|
@@ -419,7 +364,7 @@
|
|
| 419 |
},
|
| 420 |
{
|
| 421 |
"cell_type": "code",
|
| 422 |
-
"execution_count":
|
| 423 |
"id": "43154959",
|
| 424 |
"metadata": {},
|
| 425 |
"outputs": [
|
|
@@ -463,7 +408,7 @@
|
|
| 463 |
},
|
| 464 |
{
|
| 465 |
"cell_type": "code",
|
| 466 |
-
"execution_count":
|
| 467 |
"id": "6765a9b9",
|
| 468 |
"metadata": {},
|
| 469 |
"outputs": [
|
|
@@ -520,7 +465,7 @@
|
|
| 520 |
},
|
| 521 |
{
|
| 522 |
"cell_type": "code",
|
| 523 |
-
"execution_count":
|
| 524 |
"id": "a26e9dcc",
|
| 525 |
"metadata": {},
|
| 526 |
"outputs": [],
|
|
@@ -537,7 +482,7 @@
|
|
| 537 |
},
|
| 538 |
{
|
| 539 |
"cell_type": "code",
|
| 540 |
-
"execution_count":
|
| 541 |
"id": "717539e2",
|
| 542 |
"metadata": {},
|
| 543 |
"outputs": [],
|
|
@@ -582,7 +527,7 @@
|
|
| 582 |
},
|
| 583 |
{
|
| 584 |
"cell_type": "code",
|
| 585 |
-
"execution_count":
|
| 586 |
"id": "7ba9a397",
|
| 587 |
"metadata": {},
|
| 588 |
"outputs": [
|
|
@@ -620,6 +565,7 @@
|
|
| 620 |
"\n",
|
| 621 |
"# Model predicts for middle 37.5% of input sequence\n",
|
| 622 |
"# So predictions start at: start + (window_len - window_len * 0.375) / 2 = start + window_len * 0.3125\n",
|
|
|
|
| 623 |
"prediction_start = start + int(window_len * 0.3125)\n",
|
| 624 |
"prediction_end = prediction_start + int(window_len * 0.375)\n",
|
| 625 |
"x = np.arange(prediction_start, prediction_end)\n",
|
|
|
|
| 35 |
"- Supports the 24 species that NTv3 was post-trained on"
|
| 36 |
]
|
| 37 |
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "markdown",
|
| 40 |
+
"id": "77046e68",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"source": [
|
| 43 |
+
"## 0) 📦 Imports + setup"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
{
|
| 47 |
"cell_type": "code",
|
| 48 |
+
"execution_count": 1,
|
| 49 |
"id": "0ff509fd",
|
| 50 |
"metadata": {},
|
| 51 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"source": [
|
| 53 |
"# Install dependencies\n",
|
| 54 |
"!pip -q install \"transformers>=4.55\" \"huggingface_hub>=0.23\" safetensors torch pyfaidx requests seaborn matplotlib"
|
|
|
|
| 56 |
},
|
| 57 |
{
|
| 58 |
"cell_type": "code",
|
| 59 |
+
"execution_count": 2,
|
| 60 |
"id": "608d67e1",
|
| 61 |
"metadata": {},
|
| 62 |
"outputs": [],
|
|
|
|
| 75 |
"import seaborn as sns"
|
| 76 |
]
|
| 77 |
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "code",
|
| 80 |
+
"execution_count": 3,
|
| 81 |
+
"id": "2354e2aa",
|
| 82 |
+
"metadata": {},
|
| 83 |
+
"outputs": [
|
| 84 |
+
{
|
| 85 |
+
"name": "stdout",
|
| 86 |
+
"output_type": "stream",
|
| 87 |
+
"text": [
|
| 88 |
+
"device: cpu dtype: torch.float16\n"
|
| 89 |
+
]
|
| 90 |
+
}
|
| 91 |
+
],
|
| 92 |
+
"source": [
|
| 93 |
+
"# Device\n",
|
| 94 |
+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 95 |
+
"dtype = torch.bfloat16 if (device == \"cuda\" and torch.cuda.get_device_capability(0)[0] >= 8) else torch.float16\n",
|
| 96 |
+
"print(\"device:\", device, \"dtype:\", dtype)"
|
| 97 |
+
]
|
| 98 |
+
},
|
| 99 |
{
|
| 100 |
"cell_type": "markdown",
|
| 101 |
"id": "19db4774",
|
| 102 |
"metadata": {},
|
| 103 |
"source": [
|
| 104 |
+
"## 1) 📦 Configuration\n",
|
| 105 |
"\n",
|
| 106 |
"Set your NTv3 model and genomic window here"
|
| 107 |
]
|
| 108 |
},
|
| 109 |
{
|
| 110 |
"cell_type": "code",
|
| 111 |
+
"execution_count": 4,
|
| 112 |
"id": "795a576f",
|
| 113 |
"metadata": {},
|
| 114 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
"source": [
|
| 116 |
"# -----------------------------\n",
|
| 117 |
"# User inputs\n",
|
| 118 |
"# -----------------------------\n",
|
| 119 |
+
"model_name = \"InstaDeepAI/NTv3_100M_pos\" # options: \"InstaDeepAI/NTv3_100M_pos\" or \"InstaDeepAI/NTv3_650M_pos\"\n",
|
| 120 |
"\n",
|
| 121 |
"# Example window from a given species (edit these) - needs to be multiple of 128 due to the model downsampling\n",
|
| 122 |
"species = \"human\" # will use for condition the model on species\n",
|
|
|
|
| 126 |
"end = 6_831_072\n",
|
| 127 |
"\n",
|
| 128 |
"# Optional\n",
|
| 129 |
+
"HF_TOKEN = os.getenv(\"HF_TOKEN\", None)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
]
|
| 131 |
},
|
| 132 |
{
|
|
|
|
| 139 |
},
|
| 140 |
{
|
| 141 |
"cell_type": "code",
|
| 142 |
+
"execution_count": 5,
|
| 143 |
+
"id": "2e0026e4",
|
| 144 |
"metadata": {},
|
| 145 |
"outputs": [
|
| 146 |
{
|
| 147 |
"name": "stdout",
|
| 148 |
"output_type": "stream",
|
| 149 |
"text": [
|
| 150 |
+
"Original sequence length: 131072\n",
|
| 151 |
+
"Cropped sequence length: 131072, 1024.0 transformer tokens\n"
|
|
|
|
|
|
|
| 152 |
]
|
| 153 |
}
|
| 154 |
],
|
| 155 |
"source": [
|
| 156 |
+
"# Get the sequence from the UCSC API\n",
|
| 157 |
+
"url = f\"https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}\"\n",
|
| 158 |
+
"seq = requests.get(url).json()[\"dna\"].upper()\n",
|
| 159 |
+
"print(f\"Original sequence length: {len(seq)}\")\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"# Crop to multiple of 128 (the pipeline will crop again, but this is a no-op once divisible)\n",
|
| 162 |
+
"seq = seq[:int(len(seq) // 128) * 128]\n",
|
| 163 |
+
"print(f\"Cropped sequence length: {len(seq)}, {len(seq) / 128} transformer tokens\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
]
|
| 165 |
},
|
| 166 |
{
|
|
|
|
| 173 |
},
|
| 174 |
{
|
| 175 |
"cell_type": "code",
|
| 176 |
+
"execution_count": 6,
|
| 177 |
"id": "e09f0469",
|
| 178 |
"metadata": {},
|
| 179 |
"outputs": [
|
|
|
|
| 340 |
")"
|
| 341 |
]
|
| 342 |
},
|
| 343 |
+
"execution_count": 6,
|
| 344 |
"metadata": {},
|
| 345 |
"output_type": "execute_result"
|
| 346 |
}
|
|
|
|
| 364 |
},
|
| 365 |
{
|
| 366 |
"cell_type": "code",
|
| 367 |
+
"execution_count": 7,
|
| 368 |
"id": "43154959",
|
| 369 |
"metadata": {},
|
| 370 |
"outputs": [
|
|
|
|
| 408 |
},
|
| 409 |
{
|
| 410 |
"cell_type": "code",
|
| 411 |
+
"execution_count": 8,
|
| 412 |
"id": "6765a9b9",
|
| 413 |
"metadata": {},
|
| 414 |
"outputs": [
|
|
|
|
| 465 |
},
|
| 466 |
{
|
| 467 |
"cell_type": "code",
|
| 468 |
+
"execution_count": 9,
|
| 469 |
"id": "a26e9dcc",
|
| 470 |
"metadata": {},
|
| 471 |
"outputs": [],
|
|
|
|
| 482 |
},
|
| 483 |
{
|
| 484 |
"cell_type": "code",
|
| 485 |
+
"execution_count": 10,
|
| 486 |
"id": "717539e2",
|
| 487 |
"metadata": {},
|
| 488 |
"outputs": [],
|
|
|
|
| 527 |
},
|
| 528 |
{
|
| 529 |
"cell_type": "code",
|
| 530 |
+
"execution_count": 12,
|
| 531 |
"id": "7ba9a397",
|
| 532 |
"metadata": {},
|
| 533 |
"outputs": [
|
|
|
|
| 565 |
"\n",
|
| 566 |
"# Model predicts for middle 37.5% of input sequence\n",
|
| 567 |
"# So predictions start at: start + (window_len - window_len * 0.375) / 2 = start + window_len * 0.3125\n",
|
| 568 |
+
"window_len = end - start\n",
|
| 569 |
"prediction_start = start + int(window_len * 0.3125)\n",
|
| 570 |
"prediction_end = prediction_start + int(window_len * 0.375)\n",
|
| 571 |
"x = np.arange(prediction_start, prediction_end)\n",
|
notebooks/02_genome_annotation.ipynb
CHANGED
|
@@ -11,19 +11,17 @@
|
|
| 11 |
"\n",
|
| 12 |
"The pipeline abstracts away all the underlying steps: running inference with the model, retrieving and processing the predicted probabilities, and applying the HMM to generate a consistent annotation. It returns a ready-to-use GFF file that can be visualized in any genome browser for the sequence of interest.\n",
|
| 13 |
"\n",
|
| 14 |
-
"If you’re interested in exploring the intermediate probabilities, please refer to the track-prediction notebooks. These probabilities can be useful for assessing model confidence and identifying potentially interesting biological regions. This notebook focuses on the higher-level task of producing gene annotations directly from raw DNA.\n",
|
| 15 |
"\n",
|
| 16 |
"> 📝 **Note for Google Colab users:** This notebook is compatible with Colab! For faster inference, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended)."
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"cell_type": "markdown",
|
| 21 |
-
"id": "
|
| 22 |
"metadata": {},
|
| 23 |
"source": [
|
| 24 |
-
"## 0)
|
| 25 |
-
"\n",
|
| 26 |
-
"This cell detects if you're running on Google Colab and sets up the environment accordingly."
|
| 27 |
]
|
| 28 |
},
|
| 29 |
{
|
|
@@ -37,16 +35,6 @@
|
|
| 37 |
"!pip -q install \"transformers>=4.55\" \"huggingface_hub>=0.23\" safetensors torch pyfaidx requests seaborn matplotlib igv_notebook"
|
| 38 |
]
|
| 39 |
},
|
| 40 |
-
{
|
| 41 |
-
"cell_type": "markdown",
|
| 42 |
-
"id": "36d32e97",
|
| 43 |
-
"metadata": {},
|
| 44 |
-
"source": [
|
| 45 |
-
"## 1) 📦 Imports + configuration\n",
|
| 46 |
-
"\n",
|
| 47 |
-
"Set your NTv3 model and genomic window here"
|
| 48 |
-
]
|
| 49 |
-
},
|
| 50 |
{
|
| 51 |
"cell_type": "code",
|
| 52 |
"execution_count": null,
|
|
@@ -61,6 +49,16 @@
|
|
| 61 |
"from transformers import pipeline"
|
| 62 |
]
|
| 63 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
{
|
| 65 |
"cell_type": "code",
|
| 66 |
"execution_count": null,
|
|
@@ -69,7 +67,7 @@
|
|
| 69 |
"outputs": [],
|
| 70 |
"source": [
|
| 71 |
"# Define the model and genomic window\n",
|
| 72 |
-
"model_name = \"InstaDeepAI/
|
| 73 |
"assembly = \"hg38\"\n",
|
| 74 |
"chrom = \"chr19\"\n",
|
| 75 |
"start = 6_700_000\n",
|
|
@@ -98,7 +96,7 @@
|
|
| 98 |
"\n",
|
| 99 |
"# Crop to multiple of 128 (the pipeline will crop again, but this is a no-op once divisible)\n",
|
| 100 |
"seq = seq[:int(len(seq) // 128) * 128]\n",
|
| 101 |
-
"print(f\"Cropped sequence length: {len(seq)}, {len(seq) / 128} tokens\")"
|
| 102 |
]
|
| 103 |
},
|
| 104 |
{
|
|
|
|
| 11 |
"\n",
|
| 12 |
"The pipeline abstracts away all the underlying steps: running inference with the model, retrieving and processing the predicted probabilities, and applying the HMM to generate a consistent annotation. It returns a ready-to-use GFF file that can be visualized in any genome browser for the sequence of interest.\n",
|
| 13 |
"\n",
|
| 14 |
+
"If you’re interested in exploring the intermediate probabilities, please refer to the [track-prediction notebook](https://huggingface.co/spaces/InstaDeepAI/ntv3/blob/main/notebooks/01_tracks_prediction.ipynb). These probabilities can be useful for assessing model confidence and identifying potentially interesting biological regions. This notebook focuses on the higher-level task of producing gene annotations directly from raw DNA.\n",
|
| 15 |
"\n",
|
| 16 |
"> 📝 **Note for Google Colab users:** This notebook is compatible with Colab! For faster inference, make sure to enable GPU: Runtime → Change runtime type → GPU (T4 or better recommended)."
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"cell_type": "markdown",
|
| 21 |
+
"id": "94c46695",
|
| 22 |
"metadata": {},
|
| 23 |
"source": [
|
| 24 |
+
"## 0) 📦 Imports + setup"
|
|
|
|
|
|
|
| 25 |
]
|
| 26 |
},
|
| 27 |
{
|
|
|
|
| 35 |
"!pip -q install \"transformers>=4.55\" \"huggingface_hub>=0.23\" safetensors torch pyfaidx requests seaborn matplotlib igv_notebook"
|
| 36 |
]
|
| 37 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
| 40 |
"execution_count": null,
|
|
|
|
| 49 |
"from transformers import pipeline"
|
| 50 |
]
|
| 51 |
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "markdown",
|
| 54 |
+
"id": "9d29bb77",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"source": [
|
| 57 |
+
"## 1) 📦 Configuration\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"Set your NTv3 model and genomic window here"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
{
|
| 63 |
"cell_type": "code",
|
| 64 |
"execution_count": null,
|
|
|
|
| 67 |
"outputs": [],
|
| 68 |
"source": [
|
| 69 |
"# Define the model and genomic window\n",
|
| 70 |
+
"model_name = \"InstaDeepAI/NTv3_650M_pos\"\n",
|
| 71 |
"assembly = \"hg38\"\n",
|
| 72 |
"chrom = \"chr19\"\n",
|
| 73 |
"start = 6_700_000\n",
|
|
|
|
| 96 |
"\n",
|
| 97 |
"# Crop to multiple of 128 (the pipeline will crop again, but this is a no-op once divisible)\n",
|
| 98 |
"seq = seq[:int(len(seq) // 128) * 128]\n",
|
| 99 |
+
"print(f\"Cropped sequence length: {len(seq)}, {len(seq) / 128} transformer tokens\")"
|
| 100 |
]
|
| 101 |
},
|
| 102 |
{
|