bernardo-de-almeida committed on
Commit 161de31 · 1 Parent(s): bf78c8f

feat: add pipeline api

Files changed (4)
  1. README.md +19 -7
  2. app.py +153 -0
  3. ntv3_tracks_pipeline.py +567 -0
  4. requirements.txt +6 -0
README.md CHANGED
@@ -1,13 +1,25 @@
  ---
- title: Ntv3 Tracks
+ title: NTv3 Tracks Demo
- emoji: 👁
+ emoji: 🧬
- colorFrom: purple
+ colorFrom: blue
- colorTo: indigo
+ colorTo: green
  sdk: gradio
- sdk_version: 6.1.0
+ sdk_version: 4.0.0
  app_file: app.py
  pinned: false
- short_description: NTv3 Post-Trained Functional Track Prediction
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # NTv3 Tracks Demo
+
+ This Space deploys the custom Hugging Face `Pipeline` in `ntv3_tracks_pipeline.py` and provides both:
+ - an interactive UI
+ - a REST API (`/api/predict`, auto-generated by Gradio)
+
+ ## Environment variables (optional)
+
+ - `MODEL_ID` (default: `InstaDeepAI/NTv3_650M_pos`)
+ - `DEFAULT_SPECIES` (default: `human`)
+
+ ## Notes
+
+ Genome-coordinate mode may download and decompress large FASTA files. For a lightweight demo, send a DNA sequence directly via `seq`.
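
A minimal sketch of the two input shapes the pipeline accepts (field names taken from `preprocess` in `ntv3_tracks_pipeline.py`; the sequence and coordinates are arbitrary examples):

```python
# Direct-sequence mode: lightweight, nothing to download.
seq_inputs = {"seq": "ACGT" * 256, "species": "human"}

# Genome-coordinate mode: may trigger a whole-genome FASTA download on first use.
coord_inputs = {"chrom": "chr1", "start": 0, "end": 1024, "species": "human"}
```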
app.py ADDED
@@ -0,0 +1,153 @@
+ import os
+ import numpy as np
+ import gradio as gr
+
+ # local file in the Space repo
+ from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline
+
+ MODEL_ID = os.environ.get("MODEL_ID", "InstaDeepAI/NTv3_650M_pos")
+ DEFAULT_SPECIES = os.environ.get("DEFAULT_SPECIES", "human")
+
+ # Load once at startup (Space container)
+ pipe = load_ntv3_tracks_pipeline(
+     model=MODEL_ID,
+     device="auto",
+     default_species=DEFAULT_SPECIES,
+     verbose=False,
+ )
+
+ def _downsample_1d(arr: np.ndarray, max_points: int):
+     if max_points is None or max_points <= 0 or arr.shape[0] <= max_points:
+         return arr, 1
+     stride = int(np.ceil(arr.shape[0] / max_points))
+     return arr[::stride], stride
+
+ def predict(
+     seq: str,
+     species: str,
+     chrom: str,
+     start: int,
+     end: int,
+     use_coords: bool,
+     tracks: str,
+     elements: str,
+     max_points: int,
+ ):
+     """
+     Returns a JSON-serializable dict (Gradio also exposes this at /api/predict by default).
+     """
+     if use_coords:
+         if not chrom:
+             raise gr.Error("chrom is required when use_coords=True")
+         if start is None or end is None or end <= start:
+             raise gr.Error("start/end must be set and end > start when use_coords=True")
+         inputs = {"chrom": chrom, "start": int(start), "end": int(end), "species": species}
+     else:
+         if not seq or len(seq.strip()) == 0:
+             raise gr.Error("seq is required when use_coords=False")
+         inputs = {"seq": seq.strip(), "species": species}
+
+     out = pipe(inputs)
+
+     # Parse selection lists
+     track_ids = [t.strip() for t in tracks.split(",") if t.strip()] if tracks else []
+     element_names = [e.strip() for e in elements.split(",") if e.strip()] if elements else []
+
+     # BigWig tracks
+     bigwig_names = out.bigwig_track_names or []
+     bw = out.bigwig_tracks_logits  # (L, T)
+     bw_selected = {}
+     for tid in track_ids:
+         if tid not in bigwig_names:
+             continue
+         idx = bigwig_names.index(tid)
+         y, stride = _downsample_1d(bw[:, idx], max_points)
+         bw_selected[tid] = {"values": y.astype(float).tolist(), "stride": int(stride)}
+
+     # BED elements (positive-class probability)
+     bed_selected = {}
+     if out.bed_element_names is not None and element_names:
+         logits = out.bed_tracks_logits  # (L, E, C)
+         # numerically stable softmax over the last (class) axis
+         logits = logits - logits.max(axis=-1, keepdims=True)
+         probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
+         for ename in element_names:
+             if ename not in out.bed_element_names:
+                 continue
+             eidx = out.bed_element_names.index(ename)
+             y, stride = _downsample_1d(probs[:, eidx, 1], max_points)
+             bed_selected[ename] = {"values": y.astype(float).tolist(), "stride": int(stride)}
+
+     meta = {
+         "model_id": MODEL_ID,
+         "species": out.species,
+         "assembly": out.assembly,
+         "chrom": out.chrom,
+         "start": out.start,
+         "end": out.end,
+         "window_len": out.window_len,
+         "pred_start": out.pred_start,
+         "pred_end": out.pred_end,
+     }
+
+     return {
+         "meta": meta,
+         "bigwig_track_names_count": len(bigwig_names),
+         "bigwig_selected": bw_selected,
+         "bed_selected": bed_selected,
+     }
+
+ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
+     gr.Markdown(
+         """# NTv3 tracks demo (Space)
+
+ This Space runs the custom `NTv3TracksPipeline` and exposes:
+ - an interactive UI
+ - a REST API (Gradio auto-generated endpoint)
+
+ **Tip:** For reliable, fast demos, pass a DNA **sequence** directly. Genome-coordinate mode may download a whole-genome FASTA.
+ """
+     )
+
+     with gr.Row():
+         use_coords = gr.Checkbox(value=False, label="Use genome coords instead of seq")
+         species = gr.Dropdown(choices=["human", "mouse", "drosophila_melanogaster"], value=DEFAULT_SPECIES, label="species")
+
+     seq = gr.Textbox(lines=4, label="DNA sequence (A/C/G/T/N)")
+     with gr.Row():
+         chrom = gr.Textbox(label="chrom (e.g. chr1)")
+         start = gr.Number(label="start", value=0, precision=0)
+         end = gr.Number(label="end", value=1024, precision=0)
+
+     tracks = gr.Textbox(label="BigWig track IDs to return (comma-separated)", placeholder="ENCSR..., ENCSR...")
+     elements = gr.Textbox(label="BED element names to return (comma-separated)", placeholder="e.g. CTCF, H3K27ac")
+     max_points = gr.Slider(100, 5000, value=1000, step=100, label="Max points per returned series (downsample)")
+
+     btn = gr.Button("Predict")
+     out = gr.JSON(label="Output JSON")
+
+     btn.click(
+         fn=predict,
+         inputs=[seq, species, chrom, start, end, use_coords, tracks, elements, max_points],
+         outputs=[out],
+     )
+
+     gr.Markdown(
+         """## API usage
+
+ After you deploy, Gradio exposes an endpoint like:
+
+ - `POST https://<your-space>.hf.space/api/predict`
+
+ with JSON body:
+
+ ```json
+ {"data": ["ACGT...", "human", "", 0, 0, false, "ENCSR...", "CTCF", 1000]}
+ ```
+
+ The response is a JSON dict with `meta`, plus any requested tracks/elements.
+ """
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
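
A sketch of calling the deployed Space from Python with `requests`, mirroring the `/api/predict` route and positional `data` list documented in the in-app markdown above. The Space URL is a placeholder, and the exact route shape can differ across Gradio versions:

```python
import requests

SPACE_URL = "https://<your-space>.hf.space"  # placeholder, as in the app's markdown

# Positional inputs in the order wired into btn.click:
# seq, species, chrom, start, end, use_coords, tracks, elements, max_points
payload = {"data": ["ACGT" * 256, "human", "", 0, 0, False, "", "", 1000]}

resp = requests.post(f"{SPACE_URL}/api/predict", json=payload, timeout=300)
resp.raise_for_status()
result = resp.json()  # Gradio typically wraps the function output in a "data" list
print(result)
```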
ntv3_tracks_pipeline.py ADDED
@@ -0,0 +1,567 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
+ from transformers.pipelines import Pipeline
+
+ try:
+     from pyfaidx import Fasta
+ except Exception:
+     Fasta = None
+
+ try:
+     import requests
+ except Exception:
+     requests = None
+
+ try:
+     import matplotlib.pyplot as plt
+ except Exception:
+     plt = None
+
+ try:
+     import seaborn as sns
+ except Exception:
+     sns = None
+
+
+ # ---------------------------------------------------------------------
+ # Assembly <-> species mapping
+ # ---------------------------------------------------------------------
+ ASSEMBLY_TO_SPECIES = {
+     "hg38": "human",
+     "mm10": "mouse",
+     "dm6": "drosophila_melanogaster",
+     "TAIR10": "arabidopsis_thaliana",
+     "Zm-B73-REFERENCE-NAM-5.0": "zea_mays",
+     "IRGSP-1.0": "oryza_sativa",
+     "Glycine_max_v2.1": "glycine_max",
+     "IWGSC": "triticum_aestivum",
+     "Gossypium_hirsutum_v2.1": "gossypium_hirsutum",
+     "ASM228892v3": "delphinapterus_leucas",
+     "ASM334442v1": "ursus_americanus",
+     "AmpOce1": "amphiprion_ocellaris",
+     "Bison_UMD1": "bison_bison_bison",
+     "ChiLan1": "chinchilla_lanigera",
+     "Felis_catus_9": "felis_catus",
+     "GRCz11": "danio_rerio",
+     "KH": "ciona_intestinalis",
+     "Mnem_1": "macaca_nemestrina",
+     "R64": "saccharomyces_cerevisiae",
+     "ROS_Cfam_1": "canis_lupus_familiaris",
+     "SCA1": "serinus_canaria",
+     "TETRAODON8": "tetraodon_nigroviridis",
+     "WBcel235": "caenorhabditis_elegans",
+     "bGalGal1": "gallus_gallus",
+     "fSalTru1": "salmo_trutta",
+     "gorGor4": "gorilla_gorilla",
+     "mRatBN7": "rattus_norvegicus",
+     "SL3": "solanum_lycopersicum",
+     "ARS-UCD2.0": "bos_taurus",
+ }
+ SPECIES_TO_ASSEMBLY = {v: k for k, v in ASSEMBLY_TO_SPECIES.items()}
+
+ # Minimal UCSC FASTA sources (extend as needed)
+ ASSEMBLY_TO_UCSC_FA_GZ = {
+     "hg38": "https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz",
+     "mm10": "https://hgdownload.soe.ucsc.edu/goldenPath/mm10/bigZips/mm10.fa.gz",
+     "dm6": "https://hgdownload.soe.ucsc.edu/goldenPath/dm6/bigZips/dm6.fa.gz",
+ }
+
+
+ def _sanitize_dna(seq: str) -> str:
+     seq = seq.upper()
+     return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
+
+
+ def _download_file(url: str, dst: Path) -> None:
+     if requests is None:
+         raise ImportError("requests is required for genome download. Install with: pip install requests")
+     dst.parent.mkdir(parents=True, exist_ok=True)
+     with requests.get(url, stream=True, timeout=60) as r:
+         r.raise_for_status()
+         with open(dst, "wb") as f:
+             for chunk in r.iter_content(chunk_size=1024 * 1024):
+                 if chunk:
+                     f.write(chunk)
+
+
+ def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Path:
+     """
+     Download <assembly>.fa.gz, decompress to <assembly>.fa, and return the .fa path.
+     pyfaidx works reliably on uncompressed FASTA.
+     """
+     cache_dir = Path(cache_dir).expanduser().resolve()
+     cache_dir.mkdir(parents=True, exist_ok=True)
+
+     fa_path = cache_dir / f"{assembly}.fa"
+     gz_path = cache_dir / f"{assembly}.fa.gz"
+
+     if fa_path.exists():
+         return fa_path
+
+     if assembly not in ASSEMBLY_TO_UCSC_FA_GZ:
+         raise ValueError(
+             f"No download URL configured for assembly='{assembly}'. "
+             f"Supported for auto-download: {sorted(ASSEMBLY_TO_UCSC_FA_GZ.keys())}. "
+             f"Either pass fasta_path explicitly, or extend ASSEMBLY_TO_UCSC_FA_GZ."
+         )
+
+     url = ASSEMBLY_TO_UCSC_FA_GZ[assembly]
+     if not gz_path.exists():
+         print(f"Downloading {url} -> {gz_path}")
+         _download_file(url, gz_path)
+
+     import gzip
+
+     print(f"Decompressing {gz_path} -> {fa_path}")
+     with gzip.open(gz_path, "rb") as fin, open(fa_path, "wb") as fout:
+         while True:
+             chunk = fin.read(1024 * 1024)
+             if not chunk:
+                 break
+             fout.write(chunk)
+
+     return fa_path
+
+
+ def _fetch_from_fasta(fasta_path: Union[str, Path], chrom: str, start: int, end: int) -> str:
+     if Fasta is None:
+         raise ImportError("pyfaidx is required for fasta windows. Install with: pip install pyfaidx")
+
+     fasta_path = Path(fasta_path)
+     if fasta_path.suffix == ".gz":
+         raise ValueError(f"Got '{fasta_path}' (gz). Please pass an uncompressed .fa (auto-download returns .fa).")
+
+     fasta = Fasta(str(fasta_path), rebuild=True)
+     return _sanitize_dna(fasta[chrom][start:end].seq)
+
+
+ def _pick_device(device: Union[str, int, torch.device]) -> torch.device:
+     # Handle torch.device objects
+     if isinstance(device, torch.device):
+         return device
+
+     # Handle integer device IDs (transformers pipeline convention)
+     if isinstance(device, int):
+         if device == -1:
+             return torch.device("cpu")
+         elif device >= 0:
+             if torch.cuda.is_available():
+                 return torch.device(f"cuda:{device}")
+             else:
+                 return torch.device("cpu")
+         else:
+             raise ValueError(f"Invalid device integer: {device}")
+
+     # Handle string device names
+     if isinstance(device, str):
+         d = device.lower()
+         if d == "auto":
+             if torch.cuda.is_available():
+                 return torch.device("cuda")
+             if torch.backends.mps.is_available():
+                 return torch.device("mps")
+             return torch.device("cpu")
+         if d in ("cuda", "cpu", "mps"):
+             return torch.device(d)
+         raise ValueError("device must be one of: 'auto', 'cpu', 'cuda', 'mps', or an integer")
+
+     raise ValueError(f"device must be a string, integer, or torch.device, got {type(device)}")
+
+
+ def _softmax_last(x: np.ndarray) -> np.ndarray:
+     x = x - x.max(axis=-1, keepdims=True)
+     ex = np.exp(x)
+     return ex / ex.sum(axis=-1, keepdims=True)
+
+
+ def _plot_tracks_fillbetween(
+     tracks: Dict[str, np.ndarray],
+     chrom: Optional[str],
+     start: int,
+     end: int,
+     assembly: Optional[str],
+     height: float = 1.0,
+     figsize_x: float = 20.0,
+ ):
+     if plt is None:
+         raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
+     if sns is None:
+         raise ImportError("seaborn is required for notebook-style plots. Install with: pip install seaborn")
+
+     n = len(tracks)
+     if n == 0:
+         raise ValueError("No tracks to plot.")
+
+     fig, axes = plt.subplots(n, 1, figsize=(figsize_x, height * n), sharex=True)
+     if n == 1:
+         axes = [axes]
+
+     any_track = next(iter(tracks.values()))
+     x = np.linspace(start, end, num=len(any_track), endpoint=False)
+
+     for ax, (title, y) in zip(axes, tracks.items()):
+         ax.fill_between(x, y)
+         ax.set_title(title)
+         sns.despine(top=True, right=True, bottom=True)
+
+     label = f"{chrom}:{start}-{end}" if chrom is not None else f"{start}-{end}"
+     if assembly is not None:
+         label += f" ({assembly})"
+     axes[-1].set_xlabel(label)
+
+     plt.tight_layout()
+     return fig, axes
+
+
+ @dataclass
+ class NTv3TracksOutput:
+     bigwig_tracks_logits: np.ndarray  # (L_pred, T)
+     bed_tracks_logits: np.ndarray  # (L_pred, E, C)
+     mlm_logits: np.ndarray
+     chrom: Optional[str] = None
+     start: Optional[int] = None
+     end: Optional[int] = None
+     species: Optional[str] = None
+     assembly: Optional[str] = None
+     bigwig_track_names: Optional[List[str]] = None  # from cfg.bigwigs_per_file_assembly[assembly]
+     bed_element_names: Optional[List[str]] = None
+     window_len: Optional[int] = None
+     pred_start: Optional[int] = None
+     pred_end: Optional[int] = None
+
+
+ class NTv3TracksPipeline(Pipeline):
+     def __init__(
+         self,
+         model: Union[str, torch.nn.Module],
+         tokenizer: Optional[Union[str, Any]] = None,
+         trust_remote_code: bool = True,
+         token: Optional[str] = None,
+         default_species: str = "human",
+         genome_cache_dir: Union[str, Path] = "~/.cache/ntv3/genomes",
+         device: str = "auto",
+         mps_force_cpu: bool = True,
+         mps_force_cpu_length: int = 16384,
+         verbose: bool = True,
+         # Constants for the "middle 37.5%" prediction span used in the reference notebook
+         pred_center_fraction: float = 0.375,
+         pred_center_offset_fraction: float = 0.3125,
+         **kwargs: Any,
+     ):
+         self.model_id = model if isinstance(model, str) else None
+         self.default_species = default_species
+         self.genome_cache_dir = Path(genome_cache_dir)
+         self.mps_force_cpu = bool(mps_force_cpu)
+         self.mps_force_cpu_length = int(mps_force_cpu_length)
+         self.verbose = bool(verbose)
+         self.pred_center_fraction = float(pred_center_fraction)
+         self.pred_center_offset_fraction = float(pred_center_offset_fraction)
+
+         if self.default_species not in SPECIES_TO_ASSEMBLY:
+             raise ValueError(
+                 f"default_species='{self.default_species}' is not supported. "
+                 f"Supported species: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
+             )
+
+         if isinstance(model, str):
+             self.config = AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
+             self.model = AutoModel.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
+         else:
+             self.model = model
+             self.config = getattr(model, "config", None)
+
+         if tokenizer is None:
+             if not self.model_id:
+                 raise ValueError("If passing a model module, pass tokenizer explicitly.")
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, token=token)
+         elif isinstance(tokenizer, str):
+             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=trust_remote_code, token=token)
+         else:
+             self.tokenizer = tokenizer
+
+         # Extract model_id from config if not already set (following ntv3_gff_pipeline.py pattern)
+         if self.model_id is None and self.config is not None:
+             self.model_id = getattr(self.config, "_name_or_path", None) or getattr(self.config, "name_or_path", None)
+
+         # Load species_tokenizer (following ntv3_gff_pipeline.py pattern)
+         if self.model_id:
+             self.species_tokenizer = AutoTokenizer.from_pretrained(
+                 self.model_id,
+                 subfolder="species_tokenizer",
+                 trust_remote_code=trust_remote_code,
+                 token=token,
+             )
+         else:
+             # pop (not get) so the kwarg is not forwarded to Pipeline.__init__ below
+             self.species_tokenizer = kwargs.pop("species_tokenizer", None)
+             if self.species_tokenizer is None:
+                 raise ValueError("Pass species_tokenizer=... when constructing with a model module.")
+
+         # BED element names (the config may expose either attribute spelling)
+         self.bed_element_names = (
+             getattr(self.config, "bed_elements_names", None)
+             or getattr(self.config, "bed_element_names", None)
+         )
+
+         self._target_device = _pick_device(device)
+         self.model.to(self._target_device)
+         self.model.eval()
+
+         super().__init__(model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs)
+
+     def _sanitize_parameters(self, **kwargs):
+         return {}, {}, {}
+
+     def _get_model_device(self) -> torch.device:
+         return next(self.model.parameters()).device
+
+     def _resolve_species_and_assembly(self, inputs: Dict[str, Any]) -> tuple[str, str]:
+         species = inputs.get("species", self.default_species)
+         if species not in SPECIES_TO_ASSEMBLY:
+             raise ValueError(f"Unsupported species='{species}'. Supported species: {sorted(SPECIES_TO_ASSEMBLY.keys())}")
+         assembly = SPECIES_TO_ASSEMBLY[species]
+
+         cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
+         if assembly not in cfg_assemblies:
+             raise ValueError(
+                 f"Species '{species}' maps to assembly '{assembly}', but that assembly is not available in this checkpoint. "
+                 f"Available assemblies: {cfg_assemblies}"
+             )
+         return species, assembly
+
+     def _maybe_force_cpu_for_mps_long(self, input_ids_cpu: torch.Tensor) -> torch.device:
+         dev = self._get_model_device()
+         if self.mps_force_cpu and dev.type == "mps":
+             seq_len = int(input_ids_cpu.shape[-1])
+             if seq_len >= self.mps_force_cpu_length:
+                 if self.verbose:
+                     print(
+                         f"[NTv3TracksPipeline] MPS detected and input is long (tokens={seq_len}). "
+                         "Switching model + inputs to CPU for this run."
+                     )
+                 self.model.to("cpu")
+                 self.model.eval()
+                 return torch.device("cpu")
+         return dev
+
+     def preprocess(self, inputs: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+         species, assembly = self._resolve_species_and_assembly(inputs)
+
+         # Resolve sequence
+         if "seq" in inputs and inputs["seq"] is not None:
+             seq = _sanitize_dna(inputs["seq"])
+             chrom = None
+             start = 0
+             end = len(seq)
+             window_len = len(seq)
+         else:
+             chrom = inputs["chrom"]
+             start = int(inputs["start"])
+             end = int(inputs["end"])
+             window_len = end - start
+             fasta_path = inputs.get("fasta_path")
+             if fasta_path is None:
+                 fasta_path = _ensure_fasta_for_assembly(assembly, self.genome_cache_dir)
+             seq = _fetch_from_fasta(fasta_path, chrom, start, end)
+
+         # Tokenize with padding
+         batch = self.tokenizer([seq], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
+         input_ids_cpu = batch["input_ids"]
+
+         # MPS-long fallback decision
+         device = self._maybe_force_cpu_for_mps_long(input_ids_cpu)
+
+         # Move inputs
+         input_ids = input_ids_cpu.to(device)
+         # Species tokenization - match batch size
+         batch_size = input_ids.shape[0]
+         species_ids = self.species_tokenizer([species] * batch_size, add_special_tokens=False, return_tensors="pt")
+         species_ids_tensor = species_ids["input_ids"].to(device)
+
+         # Prediction interval (not used for slicing logits, just the x-axis)
+         pred_start = start + int(window_len * self.pred_center_offset_fraction)
+         pred_end = pred_start + int(window_len * self.pred_center_fraction)
+
+         # The source of truth for track IDs/names
+         bigwig_track_names = list(self.config.bigwigs_per_file_assembly[assembly])
+
+         return {
+             "input_ids": input_ids,
+             "species_ids": species_ids_tensor,
+             "meta": {
+                 "chrom": chrom,
+                 "start": start,
+                 "end": end,
+                 "species": species,
+                 "assembly": assembly,
+                 "window_len": window_len,
+                 "pred_start": pred_start,
+                 "pred_end": pred_end,
+                 "bigwig_track_names": bigwig_track_names,
+             },
+         }
+
+     # prevent Pipeline from moving tensors to its own device
+     def forward(self, model_inputs, **forward_params):
+         return self._forward(model_inputs, **forward_params)
+
+     def _forward(self, model_inputs: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+         meta = model_inputs.pop("meta")
+         if self.verbose:
+             print(f"Running on device: {self._get_model_device()}")
+         with torch.no_grad():
+             out = self.model(
+                 input_ids=model_inputs["input_ids"],
+                 species_ids=model_inputs["species_ids"],
+                 return_dict=True,
+             )
+         out["meta"] = meta
+         return out
+
+     def postprocess(self, model_outputs: Dict[str, Any], **kwargs: Any) -> NTv3TracksOutput:
+         meta = model_outputs.pop("meta", {})
+
+         def to_np(x):
+             return x.detach().float().cpu().numpy()
+
+         bigwig_np = to_np(model_outputs["bigwig_tracks_logits"])
+         bed_np = to_np(model_outputs["bed_tracks_logits"])
+         mlm_np = to_np(model_outputs["logits"])
+
+         # Normalize shapes to remove batch/(optional assembly) dims
+         if bigwig_np.ndim == 3:
+             bigwig_np = bigwig_np[0]  # (L, T)
+         elif bigwig_np.ndim == 4:
+             bigwig_np = bigwig_np[0, 0]  # (L, T) if (B, A, L, T)
+         else:
+             raise ValueError(f"Unexpected bigwig_tracks_logits ndim: {bigwig_np.ndim}")
+
+         if bed_np.ndim == 4:
+             bed_np = bed_np[0]  # (L, E, C)
+         elif bed_np.ndim == 5:
+             bed_np = bed_np[0, 0]  # (L, E, C) if (B, A, L, E, C)
+         else:
+             raise ValueError(f"Unexpected bed_tracks_logits ndim: {bed_np.ndim}")
+
+         if mlm_np.ndim == 3:
+             mlm_np = mlm_np[0]
+
+         return NTv3TracksOutput(
+             bigwig_tracks_logits=bigwig_np,
+             bed_tracks_logits=bed_np,
+             mlm_logits=mlm_np,
+             chrom=meta.get("chrom"),
+             start=meta.get("start"),
+             end=meta.get("end"),
+             species=meta.get("species"),
+             assembly=meta.get("assembly"),
+             bigwig_track_names=meta.get("bigwig_track_names"),
+             bed_element_names=self.bed_element_names,
+             window_len=meta.get("window_len"),
+             pred_start=meta.get("pred_start"),
+             pred_end=meta.get("pred_end"),
+         )
+
+     def __call__(
+         self,
+         inputs,
+         *args,
+         plot: bool = False,
+         tracks_to_plot: Optional[Dict[str, str]] = None,  # title -> track_id (ENCSR...)
+         elements_to_plot: Optional[List[str]] = None,  # element names
+         plot_height: float = 1.0,
+         plot_figsize_x: float = 20.0,
+         **kwargs,
+     ):
+         """
+         One-step call that can optionally plot and always returns NTv3TracksOutput.
+         """
+         out: NTv3TracksOutput = super().__call__(inputs, *args, **kwargs)
+
+         if plot:
+             if out.bigwig_track_names is None:
+                 raise ValueError("bigwig_track_names missing; expected cfg.bigwigs_per_file_assembly[assembly].")
+             if out.bed_element_names is None:
+                 raise ValueError("bed element names missing from config.")
+             tracks_to_plot = tracks_to_plot or {}
+             elements_to_plot = elements_to_plot or []
+
+             bigwig_names = out.bigwig_track_names
+             bed_element_names = out.bed_element_names
+
+             # Validate
+             missing_tracks = [tid for tid in tracks_to_plot.values() if tid not in bigwig_names]
+             if missing_tracks:
+                 raise ValueError(
+                     f"The following tracks are not available in bigwig_names: {missing_tracks}\n"
+                     f"First 50 available: {bigwig_names[:50]}{'...' if len(bigwig_names) > 50 else ''}"
+                 )
+
+             missing_elements = [e for e in elements_to_plot if e not in bed_element_names]
+             if missing_elements:
+                 raise ValueError(
+                     f"The following elements are not available in bed_element_names: {missing_elements}\n"
+                     f"First 50 available: {bed_element_names[:50]}{'...' if len(bed_element_names) > 50 else ''}"
+                 )
+
+             # Build bigwig tracks dict (title -> y)
+             bigwig_tracks: Dict[str, np.ndarray] = {}
+             bigwig = out.bigwig_tracks_logits  # (L_pred, T)
+             for title, track_id in tracks_to_plot.items():
+                 track_idx = bigwig_names.index(track_id)
+                 bigwig_tracks[title] = bigwig[:, track_idx]
+
+             # BED positive-class probabilities (title -> y)
+             bed_probs: Dict[str, np.ndarray] = {}
+             probs = _softmax_last(out.bed_tracks_logits)  # (L_pred, E, C)
+             for element_name in elements_to_plot:
+                 element_idx = bed_element_names.index(element_name)
+                 bed_probs[element_name] = probs[:, element_idx, 1]
+
+             all_tracks = {**bigwig_tracks, **bed_probs}
+
+             plot_start = int(out.pred_start or 0)
+             plot_end = int(out.pred_end or (plot_start + len(next(iter(all_tracks.values())))))
+
+             _plot_tracks_fillbetween(
+                 all_tracks,
+                 chrom=out.chrom,
+                 start=plot_start,
+                 end=plot_end,
+                 assembly=out.assembly,
+                 height=plot_height,
+                 figsize_x=plot_figsize_x,
+             )
+
+         return out
+
+
+ def load_ntv3_tracks_pipeline(
+     model: str,
+     device: str = "auto",
+     **pipeline_kwargs: Any,
+ ):
+     """
+     Convenience helper to build an NTv3TracksPipeline for any NTv3 checkpoint.
+
+     Parameters
+     ----------
+     model:
+         Checkpoint id, e.g. "InstaDeepAI/NTv3_100M", "InstaDeepAI/NTv3_650M", ...
+     device:
+         "auto", "cpu", "cuda", "mps"
+     pipeline_kwargs:
+         Extra kwargs passed to NTv3TracksPipeline (default_species, genome_cache_dir, etc.).
+     """
+     pipe = NTv3TracksPipeline(
+         model=model,
+         trust_remote_code=True,
+         device=device,
+         **pipeline_kwargs,
+     )
+     return pipe
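
For completeness, a minimal local-usage sketch of the helper above. The checkpoint id and sequence length are illustrative, and the printed shapes follow the `NTv3TracksOutput` comments rather than a verified run:

```python
from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline

# Illustrative checkpoint id; see the load_ntv3_tracks_pipeline docstring.
pipe = load_ntv3_tracks_pipeline("InstaDeepAI/NTv3_100M", device="auto")

out = pipe({"seq": "ACGT" * 2048, "species": "human"})
print(out.bigwig_tracks_logits.shape)  # expected (L_pred, T)
print(out.bed_tracks_logits.shape)     # expected (L_pred, E, C)
print((out.bigwig_track_names or [])[:5])

# Optional plotting (requires the matplotlib + seaborn extras):
# pipe({"seq": "ACGT" * 2048, "species": "human"}, plot=True,
#      tracks_to_plot={"My track": "ENCSR..."}, elements_to_plot=["CTCF"])
```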
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers>=4.41.0
+ torch
+ numpy
+ gradio>=4.0.0
+ pyfaidx
+ requests