bernardo-de-almeida committed on
Commit
f7c9069
·
1 Parent(s): 7163437

feat: improve demo

Browse files
Files changed (4) hide show
  1. README.md +1 -12
  2. app.py +540 -97
  3. ntv3_tracks_pipeline.py +71 -52
  4. requirements.txt +1 -0
README.md CHANGED
@@ -11,15 +11,4 @@ pinned: false
11
 
12
  # NTv3 Tracks Demo
13
 
14
- This Space deploys the custom Hugging Face `Pipeline` in `ntv3_tracks_pipeline.py` and provides both:
15
- - a UI
16
- - a REST API (`/api/predict`, auto-generated by Gradio)
17
-
18
- ## Environment variables (optional)
19
-
20
- - `MODEL_ID` (default: `InstaDeepAI/NTv3_100M`)
21
- - `DEFAULT_SPECIES` (default: `human`)
22
-
23
- ## Notes
24
-
25
- Genome-coordinate mode may download and decompress large FASTA files. For a lightweight demo, send a DNA sequence directly via `seq`.
 
11
 
12
  # NTv3 Tracks Demo
13
 
14
+ This Space deploys the custom Hugging Face `Pipeline` in `ntv3_tracks_pipeline.py`.
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,33 +1,195 @@
1
  import os
 
 
2
  import numpy as np
3
  import gradio as gr
 
 
4
 
5
- # local file in the Space repo
6
- from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline
7
 
8
- MODEL_ID = os.environ.get("MODEL_ID", "InstaDeepAI/NTv3_650M_pos")
9
- DEFAULT_SPECIES = os.environ.get("DEFAULT_SPECIES", "human")
10
 
 
 
 
 
 
11
  HF_TOKEN = (
12
- os.environ.get("HF_TOKEN")
13
- or os.environ.get("HUGGINGFACEHUB_API_TOKEN") # also common in Spaces
 
14
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # Load once at startup (Space container)
17
- pipe = load_ntv3_tracks_pipeline(
18
- model=MODEL_ID,
19
- device="auto",
20
- default_species=DEFAULT_SPECIES,
21
- token=HF_TOKEN,
22
- verbose=False,
23
- )
24
 
25
- def _downsample_1d(arr: np.ndarray, max_points: int):
26
- if max_points is None or max_points <= 0 or arr.shape[0] <= max_points:
27
- return arr, 1
28
- stride = int(np.ceil(arr.shape[0] / max_points))
29
- return arr[::stride], stride
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def predict(
32
  seq: str,
33
  species: str,
@@ -35,125 +197,406 @@ def predict(
35
  start: int,
36
  end: int,
37
  use_coords: bool,
38
- tracks: str,
39
- elements: str,
40
- max_points: int,
41
  ):
42
- """
43
- Returns JSON-serializable dict (Gradio also exposes this at /api/predict by default).
44
- """
45
  if use_coords:
46
  if not chrom:
47
  raise gr.Error("chrom is required when use_coords=True")
48
- if start is None or end is None or end <= start:
49
  raise gr.Error("start/end must be set and end > start when use_coords=True")
50
  inputs = {"chrom": chrom, "start": int(start), "end": int(end), "species": species}
51
  else:
52
- if not seq or len(seq.strip()) == 0:
53
  raise gr.Error("seq is required when use_coords=False")
54
  inputs = {"seq": seq.strip(), "species": species}
55
 
56
  out = pipe(inputs)
57
 
58
- # Parse selection lists
59
- track_ids = [t.strip() for t in tracks.split(",") if t.strip()] if tracks else []
60
- element_names = [e.strip() for e in elements.split(",") if e.strip()] if elements else []
61
-
62
- # Bigwig tracks
63
- bigwig_names = out.bigwig_track_names or []
64
- bw = out.bigwig_tracks_logits # (L, T)
65
- bw_selected = {}
66
- for tid in track_ids:
67
- if tid not in bigwig_names:
68
- continue
69
- idx = bigwig_names.index(tid)
70
- y, stride = _downsample_1d(bw[:, idx], max_points)
71
- bw_selected[tid] = {"values": y.astype(float).tolist(), "stride": int(stride)}
72
-
73
- # BED elements (positive class probability)
74
- bed_selected = {}
75
- if out.bed_element_names is not None and element_names:
76
- logits = out.bed_tracks_logits # (L, E, C)
77
- # softmax over last axis
78
- logits = logits - logits.max(axis=-1, keepdims=True)
79
- probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
80
- for ename in element_names:
81
- if ename not in out.bed_element_names:
82
- continue
83
- eidx = out.bed_element_names.index(ename)
84
- y, stride = _downsample_1d(probs[:, eidx, 1], max_points)
85
- bed_selected[ename] = {"values": y.astype(float).tolist(), "stride": int(stride)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  meta = {
88
- "model_id": MODEL_ID,
89
  "species": out.species,
90
  "assembly": out.assembly,
91
  "chrom": out.chrom,
92
- "start": out.start,
93
- "end": out.end,
94
- "window_len": out.window_len,
95
  "pred_start": out.pred_start,
96
  "pred_end": out.pred_end,
 
 
 
 
97
  }
98
 
99
- return {
100
- "meta": meta,
101
- "bigwig_track_names_count": len(bigwig_names),
102
- "bigwig_selected": bw_selected,
103
- "bed_selected": bed_selected,
104
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  with gr.Blocks(title="NTv3 Tracks Demo") as demo:
107
  gr.Markdown(
108
- """# NTv3 tracks demo (Space)
 
109
 
110
- This Space runs your `NTv3TracksPipeline` and exposes:
111
- - an interactive UI
112
- - a REST API (Gradio auto-generated endpoint)
113
 
114
- **Tip:** For reliable, fast demos, pass a DNA **sequence** directly. Genome-coordinate mode may download a whole genome FASTA.
115
- """
 
 
 
 
116
  )
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  with gr.Row():
119
- use_coords = gr.Checkbox(value=False, label="Use genome coords instead of seq")
120
- species = gr.Dropdown(choices=["human","mouse","drosophila_melanogaster"], value=DEFAULT_SPECIES, label="species")
 
 
 
 
121
 
122
- seq = gr.Textbox(lines=4, label="DNA sequence (A/C/G/T/N)")
123
  with gr.Row():
124
- chrom = gr.Textbox(label="chrom (e.g. chr1)")
125
- start = gr.Number(label="start", value=0, precision=0)
126
- end = gr.Number(label="end", value=1024, precision=0)
127
 
128
- tracks = gr.Textbox(label="BigWig track IDs to return (comma-separated)", placeholder="ENCSR... , ENCSR...")
129
- elements = gr.Textbox(label="BED element names to return (comma-separated)", placeholder="e.g. CTCF, H3K27ac")
130
- max_points = gr.Slider(100, 5000, value=1000, step=100, label="Max points per returned series (downsample)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- btn = gr.Button("Predict")
133
- out = gr.JSON(label="Output JSON")
134
 
135
- btn.click(
136
- fn=predict,
137
- inputs=[seq, species, chrom, start, end, use_coords, tracks, elements, max_points],
138
- outputs=[out],
139
  )
140
 
141
- gr.Markdown(
142
- """## API usage
 
 
143
 
144
- After you deploy, Gradio exposes an endpoint like:
 
 
 
145
 
146
- - `POST https://<your-space>.hf.space/api/predict`
 
 
147
 
148
- with JSON body:
 
 
 
 
 
 
 
149
 
150
- ```json
151
- {"data": ["ACGT...", "human", "", 0, 0, false, "ENCSR...", "CTCF", 1000]}
152
- ```
153
 
154
- The response is a JSON dict with `meta`, plus any requested tracks/elements.
155
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  )
157
 
158
  if __name__ == "__main__":
159
- demo.launch()
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import uuid
3
+ import tempfile
4
  import numpy as np
5
  import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import asyncio
8
 
9
+ from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline, BED_ELEMENT_COLORS
 
10
 
 
 
11
 
12
+ # -----------------------------
13
+ # Env / auth
14
+ # -----------------------------
15
+ MODEL_ID = os.environ.get("MODEL_ID", "InstaDeepAI/NTv3_100M_pos")
16
+ DEFAULT_SPECIES = os.environ.get("DEFAULT_SPECIES", "human")
17
  HF_TOKEN = (
18
+ os.environ.get("NTV3_HF_TOKEN")
19
+ or os.environ.get("HF_TOKEN")
20
+ or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
21
  )
22
+ if HF_TOKEN is None:
23
+ raise RuntimeError("Missing Hugging Face token. Set NTV3_HF_TOKEN as a Space Secret.")
24
+
25
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
26
+
27
+ PLOT_TARGET_POINTS = int(os.environ.get("PLOT_TARGET_POINTS", "1500"))
28
+ SEARCH_MAX_RESULTS = int(os.environ.get("SEARCH_MAX_RESULTS", "50"))
29
+
30
+
31
+ # -----------------------------
32
+ # Load pipeline (reloadable)
33
+ # -----------------------------
34
+ pipe = None
35
+ current_model_id = MODEL_ID
36
+
37
def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
    """Build the NTv3 tracks pipeline for ``model_id`` and make it the active one.

    The new pipeline replaces the module-level ``pipe`` and ``current_model_id``
    is updated to match. Both globals are only rebound after the pipeline has
    been constructed, so a failed load leaves the previous pipeline in place.

    Args:
        model_id: Model identifier forwarded to ``load_ntv3_tracks_pipeline``.
        species: Passed through as the pipeline's ``default_species``.

    Returns:
        The newly loaded pipeline object.
    """
    global pipe, current_model_id

    pipeline_kwargs = {
        "model": model_id,
        "token": HF_TOKEN,
        "device": "auto",
        "default_species": species,
        "verbose": False,
    }
    pipe = load_ntv3_tracks_pipeline(**pipeline_kwargs)
    current_model_id = model_id
    return pipe
49
+
50
+ # Load initial pipeline
51
+ load_pipeline(MODEL_ID, DEFAULT_SPECIES)
52
+
53
+
54
+ # -----------------------------
55
+ # Helpers
56
+ # -----------------------------
57
+ def _softmax_last(x: np.ndarray) -> np.ndarray:
58
+ x = x - x.max(axis=-1, keepdims=True)
59
+ ex = np.exp(x)
60
+ return ex / ex.sum(axis=-1, keepdims=True)
61
+
62
+
63
+ def _global_stride(L: int, target: int) -> int:
64
+ if target <= 0 or L <= target:
65
+ return 1
66
+ return int(np.ceil(L / target))
67
+
68
+
69
def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]]):
    """Draw one stacked, x-sharing panel per ``(title, values)`` pair.

    Args:
        x: X coordinates shared by every panel.
        series: ``(title, y)`` pairs, drawn top to bottom in the given order.

    Returns:
        The matplotlib figure holding all panels.

    Raises:
        gr.Error: If ``series`` is empty.
    """
    if not series:
        raise gr.Error("Nothing to plot (no tracks/elements selected).")

    n_panels = len(series)
    # squeeze=False always yields a 2-D axes array, so the single-panel case
    # needs no special handling; column 0 holds the panels top to bottom.
    fig, axes_grid = plt.subplots(
        n_panels,
        1,
        figsize=(18, 1.35 * n_panels),
        sharex=True,
        squeeze=False,
    )
    panels = axes_grid[:, 0]

    # Titles listed in BED_ELEMENT_COLORS get their mapped colour; every
    # other series falls back to this blue.
    fallback_color = "#4A90E2"

    for panel, (title, values) in zip(panels, series):
        color = BED_ELEMENT_COLORS.get(title, fallback_color)

        panel.fill_between(x, values, color=color, alpha=0.3, linewidth=0)
        panel.plot(x, values, color=color, linewidth=0.8)
        panel.set_title(title, fontsize=10, loc="left")
        panel.grid(alpha=0.2)
        panel.set_yticks([])
        for side in ("top", "right"):
            panel.spines[side].set_visible(False)

    panels[-1].set_xlabel("Genomic position / index")
    fig.tight_layout()
    return fig
99
+
100
+
101
+ def _save_fig_png(fig) -> str:
102
+ tmpdir = tempfile.gettempdir()
103
+ out_path = os.path.join(tmpdir, f"ntv3_tracks_{uuid.uuid4().hex}.png")
104
+ fig.savefig(out_path, dpi=200, bbox_inches="tight")
105
+ return out_path
106
+
107
+
108
# Cache track lists so search is instant after the first lookup. Keys combine
# the active model id with the species: the pipeline can be reloaded with a
# different model at runtime, and a species-only key would keep returning the
# list fetched from the previously loaded model.
_BIGWIG_CACHE: dict[str, list[str]] = {}


def _get_bigwig_names(species: str) -> list[str]:
    """Return the BigWig track names the active pipeline reports for ``species``.

    The result of ``pipe.available_bigwig_track_names(species)`` is memoized
    per (current model id, species), so the pipeline is queried again after a
    model switch instead of serving a stale list.
    """
    cache_key = f"{current_model_id}::{species}"
    if cache_key not in _BIGWIG_CACHE:
        _BIGWIG_CACHE[cache_key] = pipe.available_bigwig_track_names(species)
    return _BIGWIG_CACHE[cache_key]
116
+
117
+
118
+ def _rank_search(query: str, names: list[str], limit: int) -> list[str]:
119
+ """
120
+ Return up to `limit` candidate track IDs matching `query` using a fast,
121
+ low-overhead ranking suitable for very large `names` lists.
122
+
123
+ Matching & ranking rules:
124
+ 1) Case-insensitive match.
125
+ 2) Items whose ID *starts with* the query are ranked first.
126
+ 3) Remaining items that merely *contain* the query are ranked after.
127
+ 4) Results preserve the original relative order within each group
128
+ (stable w.r.t. the input `names` order).
129
+ 5) If `query` is empty/whitespace, returns an empty list to avoid
130
+ flooding the UI with a huge default list.
131
+
132
+ Notes:
133
+ - `limit` only caps the number of returned results; it does not prevent
134
+ short queries (e.g. "E") from producing many matches—if you want that,
135
+ add a minimum query length check (e.g. `if len(q) < 2: return []`).
136
+ - Time complexity is O(len(names)) per call.
137
+ """
138
+ q = (query or "").strip().lower()
139
+ if not q:
140
+ return [] # don’t spam a giant default list
141
+
142
+ starts = []
143
+ contains = []
144
+
145
+ for n in names:
146
+ nl = n.lower()
147
+ if nl.startswith(q):
148
+ starts.append(n)
149
+ elif q in nl:
150
+ contains.append(n)
151
+
152
+ out = starts + contains
153
+ return out[:limit]
154
+
155
+
156
def search_bigwigs(species: str, query: str):
    """Search the BigWig track names of ``species`` for ``query``.

    Returns a Gradio update that replaces the component's choices with at most
    ``SEARCH_MAX_RESULTS`` ranked matches and clears its current selection.
    """
    matches = _rank_search(query, _get_bigwig_names(species), SEARCH_MAX_RESULTS)
    return gr.update(choices=matches, value=[])
160
+
161
+
162
def add_selected(current_selected: list[str], to_add: list[str]):
    """Merge ``to_add`` into ``current_selected`` without duplicates.

    Order is preserved: existing selections first, then new items in the order
    given. Either argument may be ``None`` or empty. The returned Gradio
    update sets both the choices and the checked values to the merged list.
    """
    combined = [*(current_selected or []), *(to_add or [])]
    # dict.fromkeys keeps the first occurrence of each item, in order.
    merged = list(dict.fromkeys(combined))
    return gr.update(choices=merged, value=merged)
168
+
169
+
170
def remove_selected(current_selected: list[str], to_remove: list[str]):
    """Drop every item of ``to_remove`` from ``current_selected``.

    Either argument may be ``None`` or empty. The relative order of the
    remaining items is kept. The returned Gradio update sets both the choices
    and the checked values to the remaining list.
    """
    # Build the lookup set once instead of once per element.
    doomed = set(to_remove or [])
    remaining = [item for item in (current_selected or []) if item not in doomed]
    return gr.update(choices=remaining, value=remaining)
173
 
 
 
 
 
 
 
 
 
174
 
175
def update_coords_on_species_change(species: str):
    """Return the default ``(chrom, start, end)`` for ``species``.

    Species without an entry in ``DEFAULT_COORDS`` use the human defaults.
    """
    human_defaults = DEFAULT_COORDS["human"]
    window = DEFAULT_COORDS.get(species, human_defaults)
    return tuple(window[field] for field in ("chrom", "start", "end"))
 
179
 
180
def reset_on_species_change(species: str):
    """Reset the track-search widgets after the species selection changes.

    Clearing the results and the selected tracks avoids carrying track IDs
    over from the previous species. The track-name cache for the new species
    is populated as a side effect.

    Returns:
        Three Gradio updates, in order: query textbox, results list,
        selected list.
    """
    _get_bigwig_names(species)  # populate the cache for the new species

    cleared_query = gr.update(value="")
    cleared_results = gr.update(choices=[], value=[])
    cleared_selected = gr.update(choices=[], value=[])
    return cleared_query, cleared_results, cleared_selected
188
+
189
+
190
+ # -----------------------------
191
+ # Predict
192
+ # -----------------------------
193
  def predict(
194
  seq: str,
195
  species: str,
 
197
  start: int,
198
  end: int,
199
  use_coords: bool,
200
+ bigwig_selected: list[str],
201
+ bed_elements: list[str],
 
202
  ):
 
 
 
203
  if use_coords:
204
  if not chrom:
205
  raise gr.Error("chrom is required when use_coords=True")
206
+ if start is None or end is None or int(end) <= int(start):
207
  raise gr.Error("start/end must be set and end > start when use_coords=True")
208
  inputs = {"chrom": chrom, "start": int(start), "end": int(end), "species": species}
209
  else:
210
+ if not seq or not seq.strip():
211
  raise gr.Error("seq is required when use_coords=False")
212
  inputs = {"seq": seq.strip(), "species": species}
213
 
214
  out = pipe(inputs)
215
 
216
+ bw_names = out.bigwig_track_names or []
217
+ bw = out.bigwig_tracks_logits
218
+ bed_names = out.bed_element_names or []
219
+ bed_logits = out.bed_tracks_logits
220
+
221
+ if bw is None or not bw_names:
222
+ raise gr.Error("No BigWig tracks available in model output.")
223
+
224
+ # Defaults if user picked none
225
+ if not bigwig_selected:
226
+ default_bigwig_tracks = [
227
+ "ENCSR056HPM", # K562 RNA-seq
228
+ "ENCSR921NMD", # K562 DNAse
229
+ "ENCSR000DWD", # K562 H3k4me3
230
+ "ENCSR000AKO", # K562 CTCF
231
+ "ENCSR561FEE_P", # HepG2 RNA-seq
232
+ "ENCSR000EJV", # HepG2 DNAse
233
+ "ENCSR000AMP", # HepG2 H3k4me3
234
+ "ENCSR000BIE", # HepG2 CTCF
235
+ ]
236
+ # Filter to only include tracks that are available for this species/assembly
237
+ bigwig_selected = [tid for tid in default_bigwig_tracks if tid in bw_names]
238
+ if (not bed_elements) and bed_names:
239
+ default_bed_elements = ["protein_coding_gene", "exon", "intron"]
240
+ # Filter to only include elements that are available
241
+ bed_elements = [elem for elem in default_bed_elements if elem in bed_names]
242
+
243
+ # Validate (important for API usage)
244
+ missing_tracks = [t for t in bigwig_selected if t not in bw_names]
245
+ if missing_tracks:
246
+ raise gr.Error(f"Unknown BigWig track id(s): {missing_tracks}")
247
+
248
+ missing_elems = [e for e in bed_elements if e not in bed_names]
249
+ if missing_elems:
250
+ raise gr.Error(f"Unknown BED element(s): {missing_elems}")
251
+
252
+ L = bw.shape[0]
253
+ stride = _global_stride(L, PLOT_TARGET_POINTS)
254
+
255
+ x0 = int(out.pred_start or 0)
256
+ x1 = int(out.pred_end or (x0 + L))
257
+ x = np.linspace(x0, x1, num=L, endpoint=False)[::stride]
258
+
259
+ series: list[tuple[str, np.ndarray]] = []
260
+ for tid in bigwig_selected:
261
+ idx = bw_names.index(tid)
262
+ series.append((tid, bw[:, idx][::stride].astype(float)))
263
+
264
+ if bed_logits is not None and bed_elements:
265
+ probs = _softmax_last(bed_logits)
266
+ for ename in bed_elements:
267
+ eidx = bed_names.index(ename)
268
+ series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
269
+
270
+ fig = _make_tracks_figure(x, series)
271
+
272
+ region = f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
273
+ if out.assembly:
274
+ region += f" ({out.assembly})"
275
+ fig.axes[-1].set_xlabel(region)
276
+
277
+ png_path = _save_fig_png(fig)
278
 
279
  meta = {
280
+ "model_id": current_model_id,
281
  "species": out.species,
282
  "assembly": out.assembly,
283
  "chrom": out.chrom,
 
 
 
284
  "pred_start": out.pred_start,
285
  "pred_end": out.pred_end,
286
+ "bigwig_selected": bigwig_selected,
287
+ "bed_selected": bed_elements,
288
+ "plot_stride": stride,
289
+ "plot_target_points": PLOT_TARGET_POINTS,
290
  }
291
 
292
+ return fig, png_path, meta
293
+
294
+
295
+ # -----------------------------
296
+ # UI: custom CSS/JS (adds a download-icon overlay on the plot)
297
+ # -----------------------------
298
+ CSS = """
299
+ #tracks_plot { position: relative; width: 100% !important; max-width: 100% !important; }
300
+ #tracks_plot .wrap, #tracks_plot .plot-container { width: 100% !important; max-width: 100% !important; }
301
+
302
+ #tracks_plot_download {
303
+ position: absolute;
304
+ top: 10px;
305
+ right: 12px;
306
+ z-index: 50;
307
+ background: rgba(0,0,0,0.55);
308
+ border: 1px solid rgba(255,255,255,0.15);
309
+ border-radius: 10px;
310
+ padding: 6px 8px;
311
+ cursor: pointer;
312
+ user-select: none;
313
+ }
314
+ #tracks_plot_download:hover { background: rgba(0,0,0,0.7); }
315
+ #tracks_plot_download svg { width: 18px; height: 18px; display: block; fill: white; }
316
+ #export_png_hidden { display: none !important; }
317
+
318
+ #predict_btn {
319
+ background-color: #FF6B35 !important;
320
+ color: white !important;
321
+ border: none !important;
322
+ }
323
+ #predict_btn:hover {
324
+ background-color: #E55A2B !important;
325
+ }
326
+
327
+ #intro_markdown {
328
+ font-size: 1.3em !important;
329
+ line-height: 1.7 !important;
330
+ }
331
+ #intro_markdown h1 {
332
+ font-size: 2.8em !important;
333
+ margin-bottom: 0.6em !important;
334
+ }
335
+ #intro_markdown h2, #intro_markdown h3 {
336
+ font-size: 1.8em !important;
337
+ }
338
+ #intro_markdown p, #intro_markdown li {
339
+ font-size: 1.2em !important;
340
+ }
341
+ """
342
+
343
+ JS = """
344
+ function addDownloadIcon() {
345
+ const plot = document.querySelector("#tracks_plot");
346
+ if (!plot) return;
347
+ if (document.querySelector("#tracks_plot_download")) return;
348
+
349
+ const btn = document.createElement("div");
350
+ btn.id = "tracks_plot_download";
351
+ btn.title = "Download PNG";
352
+ btn.innerHTML = `
353
+ <svg viewBox="0 0 24 24" aria-hidden="true">
354
+ <path d="M5 20h14v-2H5v2zm7-18v10.17l3.59-3.58L17 10l-5 5-5-5 1.41-1.41L11 12.17V2h1z"/>
355
+ </svg>
356
+ `;
357
+ btn.onclick = () => {
358
+ const link = document.querySelector("#export_png_hidden a");
359
+ if (link) link.click();
360
+ };
361
+ plot.appendChild(btn);
362
+ }
363
+ function setup() {
364
+ addDownloadIcon();
365
+ const obs = new MutationObserver(() => addDownloadIcon());
366
+ obs.observe(document.body, { childList: true, subtree: true });
367
+ }
368
+ setup();
369
+ """
370
+
371
+ # BED list is small enough to keep as dropdown
372
+ _init_bed = pipe.available_bed_element_names()
373
+
374
+ # Default BigWig tracks
375
+ DEFAULT_BIGWIG_TRACKS = [
376
+ "ENCSR056HPM", # K562 RNA-seq
377
+ "ENCSR921NMD", # K562 DNAse
378
+ "ENCSR000DWD", # K562 H3k4me3
379
+ "ENCSR000AKO", # K562 CTCF
380
+ "ENCSR561FEE_P", # HepG2 RNA-seq
381
+ "ENCSR000EJV", # HepG2 DNAse
382
+ "ENCSR000AMP", # HepG2 H3k4me3
383
+ "ENCSR000BIE", # HepG2 CTCF
384
+ ]
385
+
386
+ # Default BED elements
387
+ DEFAULT_BED_ELEMENTS = ["protein_coding_gene", "exon", "intron"]
388
+
389
+ # Get available BigWig tracks for default species and filter defaults
390
+ _init_bigwig = _get_bigwig_names(DEFAULT_SPECIES)
391
+ _init_bigwig_selected = [tid for tid in DEFAULT_BIGWIG_TRACKS if tid in _init_bigwig]
392
+
393
+ # Filter default BED elements to only those available
394
+ _init_bed_selected = [elem for elem in DEFAULT_BED_ELEMENTS if elem in _init_bed]
395
+
396
# Default genome window per species, used to pre-fill the chromosome/start/end
# inputs. This table was previously defined twice in a row with different
# mouse/drosophila windows; the first copy was dead code because the second
# unconditionally overwrote it, so only the effective values are kept.
DEFAULT_COORDS = {
    "human": {"chrom": "chr19", "start": 6_700_000, "end": 6_831_072},
    "mouse": {"chrom": "chr1", "start": 0, "end": 32_768},
    "drosophila_melanogaster": {"chrom": "chr2L", "start": 0, "end": 32_768},
}

# Window for the configured default species; unknown species fall back to human.
_default_coords = DEFAULT_COORDS.get(DEFAULT_SPECIES, DEFAULT_COORDS["human"])
415
 
416
  with gr.Blocks(title="NTv3 Tracks Demo") as demo:
417
  gr.Markdown(
418
+ """
419
+ # 🧬 NTv3 Tracks Demo
420
 
421
+ **Predict functional genomics tracks and genome annotation elements from DNA sequences using NTv3 (Nucleotide Transformer v3).**
 
 
422
 
423
+ This demo allows you to:
424
+ - **Input**: Provide a DNA sequence directly or specify genomic coordinates (chromosome, start, end)
425
+ - **Select tracks**: Choose from hundreds of BigWig functional tracks (e.g., RNA-seq, ChIP-seq, DNase) and genome annotation elements (e.g., exons, introns, promoters)
426
+ - **Visualize**: View NTv3 predictions across the input sequence
427
+ """,
428
+ elem_id="intro_markdown",
429
  )
430
 
431
+ gr.Markdown("## Select NTv3 post-trained model")
432
+
433
+ # Model display names (without InstaDeepAI/ prefix) and their full IDs
434
+ MODEL_OPTIONS = {
435
+ "NTv3 650M (pos)": "InstaDeepAI/NTv3_650M_pos",
436
+ "NTv3 100M (pos)": "InstaDeepAI/NTv3_100M_pos",
437
+ }
438
+
439
+ # Reverse mapping: full ID -> display name
440
+ MODEL_ID_TO_DISPLAY = {v: k for k, v in MODEL_OPTIONS.items()}
441
+
442
+ # Get display name for current model
443
+ current_display_name = MODEL_ID_TO_DISPLAY.get(current_model_id, "NTv3 100M (pos)")
444
+
445
+ model_selector = gr.Dropdown(
446
+ choices=list(MODEL_OPTIONS.keys()),
447
+ value=current_display_name,
448
+ label="Model",
449
+ )
450
+
451
+ model_status = gr.Markdown("", visible=False)
452
+
453
+ gr.Markdown("## Input sequence (Genomic coordinate or DNA sequence)")
454
+
455
  with gr.Row():
456
+ species = gr.Dropdown(
457
+ ["human", "mouse", "drosophila_melanogaster"],
458
+ value=DEFAULT_SPECIES,
459
+ label="Species",
460
+ )
461
+ use_coords = gr.Checkbox(True, label="Use genome coordinates")
462
 
 
463
  with gr.Row():
464
+ chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
465
+ start = gr.Number(label="Start", value=_default_coords["start"], precision=0)
466
+ end = gr.Number(label="End", value=_default_coords["end"], precision=0)
467
 
468
+ seq = gr.Textbox(lines=4, label="Input DNA sequence", placeholder="ACGT...")
469
+
470
+ def change_model(display_name: str, species: str):
471
+ """Reload pipeline with new model."""
472
+ try:
473
+ # Convert display name to full model ID
474
+ if display_name in MODEL_OPTIONS:
475
+ model_id = MODEL_OPTIONS[display_name]
476
+ else:
477
+ # Fallback: assume it's already a model ID or custom value
478
+ model_id = display_name
479
+
480
+ load_pipeline(model_id, species)
481
+ # Update available tracks/elements
482
+ _get_bigwig_names(species) # warm cache
483
+ return gr.update(value="✅ Model loaded successfully"), gr.update(visible=True)
484
+ except Exception as e:
485
+ return gr.update(value=f"❌ Error loading model: {str(e)}"), gr.update(visible=True)
486
+
487
+ model_selector.change(
488
+ fn=change_model,
489
+ inputs=[model_selector, species],
490
+ outputs=[model_status, model_status],
491
+ )
492
 
493
+ gr.Markdown("## Select functional tracks")
 
494
 
495
+ bigwig_selected = gr.CheckboxGroup(
496
+ choices=_init_bigwig_selected,
497
+ value=_init_bigwig_selected,
498
+ label="Selected functional tracks (used for prediction)",
499
  )
500
 
501
+ bigwig_query = gr.Textbox(
502
+ label="Search functional tracks (auto-search while typing)",
503
+ placeholder="Type to search… (e.g. ENCSR056HPM for K562 RNA-seq)",
504
+ )
505
 
506
+ bigwig_results = gr.CheckboxGroup(
507
+ choices=[],
508
+ label="Results (click to add to Selected)",
509
+ )
510
 
511
+ with gr.Row():
512
+ bigwig_clear_btn = gr.Button("Clear results")
513
+ bigwig_remove_btn = gr.Button("Remove checked from Selected")
514
 
515
+ gr.Markdown("## Select genome annotation elements")
516
+
517
+ bed_elements = gr.Dropdown(
518
+ choices=_init_bed,
519
+ value=_init_bed_selected if _init_bed_selected else [],
520
+ multiselect=True,
521
+ label="Genome annotation elements (search + select)",
522
+ )
523
 
524
+ btn = gr.Button("Predict", elem_id="predict_btn")
 
 
525
 
526
+ gr.Markdown("## NTv3 predictions for selected tracks and elements")
527
+
528
+ plot = gr.Plot(label="", elem_id="tracks_plot")
529
+ export_png = gr.File(elem_id="export_png_hidden", interactive=False)
530
+
531
+ with gr.Accordion("Meta (click to expand)", open=False):
532
+ meta = gr.JSON(label="Meta")
533
+
534
+ # --- wiring (live search + auto-add) ---
535
+
536
+ # Live search on every keystroke
537
+ bigwig_query.input(
538
+ fn=search_bigwigs,
539
+ inputs=[species, bigwig_query],
540
+ outputs=[bigwig_results],
541
+ )
542
+
543
+ # Auto-add: whenever user checks items in results, add them to Selected,
544
+ # then clear results selection (so it feels like "click to add")
545
+ def _auto_add(selected_now: list[str], results_checked: list[str]):
546
+ upd = add_selected(selected_now, results_checked) # reuses your function
547
+ # clear checks in results, keep choices
548
+ return upd, gr.update(value=[])
549
+
550
+ bigwig_results.change(
551
+ fn=_auto_add,
552
+ inputs=[bigwig_selected, bigwig_results],
553
+ outputs=[bigwig_selected, bigwig_results],
554
+ )
555
+
556
+ # Clear results list (handy when query is short)
557
+ def _clear_results():
558
+ return gr.update(choices=[], value=[]), gr.update(value="")
559
+
560
+ bigwig_clear_btn.click(
561
+ fn=_clear_results,
562
+ inputs=[],
563
+ outputs=[bigwig_results, bigwig_query],
564
+ )
565
+
566
+ # Remove: check items in Selected, then click Remove
567
+ bigwig_remove_btn.click(
568
+ fn=remove_selected,
569
+ inputs=[bigwig_selected, bigwig_selected],
570
+ outputs=[bigwig_selected],
571
+ )
572
+
573
+ species.change(
574
+ fn=reset_on_species_change,
575
+ inputs=[species],
576
+ outputs=[bigwig_query, bigwig_results, bigwig_selected],
577
+ )
578
+
579
+ # Update coordinates when species changes
580
+ species.change(
581
+ fn=update_coords_on_species_change,
582
+ inputs=[species],
583
+ outputs=[chrom, start, end],
584
+ )
585
+
586
+ btn.click(
587
+ fn=predict,
588
+ inputs=[seq, species, chrom, start, end, use_coords, bigwig_selected, bed_elements],
589
+ outputs=[plot, export_png, meta],
590
+ api_name="predict",
591
  )
592
 
593
  if __name__ == "__main__":
594
+ demo.launch(
595
+ server_name="0.0.0.0",
596
+ server_port=7860,
597
+ ssr_mode=False,
598
+ show_error=True,
599
+ allowed_paths=[tempfile.gettempdir()],
600
+ css=CSS,
601
+ js=JS,
602
+ )
ntv3_tracks_pipeline.py CHANGED
@@ -24,11 +24,6 @@ try:
24
  except Exception:
25
  plt = None
26
 
27
- try:
28
- import seaborn as sns
29
- except Exception:
30
- sns = None
31
-
32
 
33
  # ---------------------------------------------------------------------
34
  # Assembly <-> species mapping
@@ -66,29 +61,42 @@ ASSEMBLY_TO_SPECIES = {
66
  }
67
  SPECIES_TO_ASSEMBLY = {v: k for k, v in ASSEMBLY_TO_SPECIES.items()}
68
 
69
- # Minimal UCSC FASTA sources (extend as needed)
70
- ASSEMBLY_TO_UCSC_FA_GZ = {
71
- "hg38": "https://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz",
72
- "mm10": "https://hgdownload.soe.ucsc.edu/goldenPath/mm10/bigZips/mm10.fa.gz",
73
- "dm6": "https://hgdownload.soe.ucsc.edu/goldenPath/dm6/bigZips/dm6.fa.gz",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  }
75
 
76
-
77
  def _sanitize_dna(seq: str) -> str:
78
  seq = seq.upper()
79
  return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
80
 
81
 
82
- def _download_file(url: str, dst: Path) -> None:
83
  if requests is None:
84
  raise ImportError("requests is required for genome download. Install with: pip install requests")
85
- dst.parent.mkdir(parents=True, exist_ok=True)
86
- with requests.get(url, stream=True, timeout=60) as r:
87
- r.raise_for_status()
88
- with open(dst, "wb") as f:
89
- for chunk in r.iter_content(chunk_size=1024 * 1024):
90
- if chunk:
91
- f.write(chunk)
92
 
93
 
94
  def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Path:
@@ -112,11 +120,6 @@ def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Pa
112
  f"Either pass fasta_path explicitly, or extend ASSEMBLY_TO_UCSC_FA_GZ."
113
  )
114
 
115
- url = ASSEMBLY_TO_UCSC_FA_GZ[assembly]
116
- if not gz_path.exists():
117
- print(f"Downloading {url} -> {gz_path}")
118
- _download_file(url, gz_path)
119
-
120
  import gzip
121
  print(f"Decompressing {gz_path} -> {fa_path}")
122
  with gzip.open(gz_path, "rb") as fin, open(fa_path, "wb") as fout:
@@ -128,19 +131,6 @@ def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Pa
128
 
129
  return fa_path
130
 
131
-
132
- def _fetch_from_fasta(fasta_path: Union[str, Path], chrom: str, start: int, end: int) -> str:
133
- if Fasta is None:
134
- raise ImportError("pyfaidx is required for fasta windows. Install with: pip install pyfaidx")
135
-
136
- fasta_path = Path(fasta_path)
137
- if fasta_path.suffix == ".gz":
138
- raise ValueError(f"Got '{fasta_path}' (gz). Please pass an uncompressed .fa (auto-download returns .fa).")
139
-
140
- fasta = Fasta(str(fasta_path), rebuild=True)
141
- return _sanitize_dna(fasta[chrom][start:end].seq)
142
-
143
-
144
  def _pick_device(device: Union[str, int, torch.device]) -> torch.device:
145
  # Handle torch.device objects
146
  if isinstance(device, torch.device):
@@ -191,8 +181,6 @@ def _plot_tracks_fillbetween(
191
  ):
192
  if plt is None:
193
  raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
194
- if sns is None:
195
- raise ImportError("seaborn is required for notebook-style plots. Install with: pip install seaborn")
196
 
197
  n = len(tracks)
198
  if n == 0:
@@ -205,10 +193,25 @@ def _plot_tracks_fillbetween(
205
  any_track = next(iter(tracks.values()))
206
  x = np.linspace(start, end, num=len(any_track), endpoint=False)
207
 
 
 
 
 
208
  for ax, (title, y) in zip(axes, tracks.items()):
209
- ax.fill_between(x, y)
210
- ax.set_title(title)
211
- sns.despine(top=True, right=True, bottom=True)
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  label = f"{chrom}:{start}-{end}" if chrom is not None else f"{start}-{end}"
214
  if assembly is not None:
@@ -263,12 +266,6 @@ class NTv3TracksPipeline(Pipeline):
263
  self.pred_center_fraction = float(pred_center_fraction)
264
  self.pred_center_offset_fraction = float(pred_center_offset_fraction)
265
 
266
- if self.default_species not in SPECIES_TO_ASSEMBLY:
267
- raise ValueError(
268
- f"default_species='{self.default_species}' is not supported. "
269
- f"Supported species: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
270
- )
271
-
272
  if isinstance(model, str):
273
  self.config = AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
274
  self.model = AutoModel.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
@@ -350,6 +347,30 @@ class NTv3TracksPipeline(Pipeline):
350
  return torch.device("cpu")
351
  return dev
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  def preprocess(self, inputs: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
354
  species, assembly = self._resolve_species_and_assembly(inputs)
355
 
@@ -365,10 +386,8 @@ class NTv3TracksPipeline(Pipeline):
365
  start = int(inputs["start"])
366
  end = int(inputs["end"])
367
  window_len = end - start
368
- fasta_path = inputs.get("fasta_path")
369
- if fasta_path is None:
370
- fasta_path = _ensure_fasta_for_assembly(assembly, self.genome_cache_dir)
371
- seq = _fetch_from_fasta(fasta_path, chrom, start, end)
372
 
373
  # Tokenize with padding
374
  batch = self.tokenizer([seq], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
 
24
  except Exception:
25
  plt = None
26
 
 
 
 
 
 
27
 
28
  # ---------------------------------------------------------------------
29
  # Assembly <-> species mapping
 
61
  }
62
  SPECIES_TO_ASSEMBLY = {v: k for k, v in ASSEMBLY_TO_SPECIES.items()}
63
 
64
+ # Hex color for each BED element track, keyed by element name (shared between pipeline and app); when plotting, titles not listed here fall back to the default BigWig color
65
+ BED_ELEMENT_COLORS = {
66
+ "protein_coding_gene": "#E74C3C", # Red
67
+ "lncRNA": "#2ECC71", # Green
68
+ "exon": "#9B59B6", # Purple
69
+ "intron": "#F39C12", # Orange
70
+ "splice_donor": "#1ABC9C", # Teal
71
+ "splice_acceptor": "#E67E22", # Dark orange
72
+ "CTCF-bound": "#3498DB", # Light blue
73
+ "polyA_signal": "#95A5A6", # Gray
74
+ "enhancer_Tissue_specific": "#D35400", # Burnt orange
75
+ "enhancer_Tissue_invariant": "#16A085", # Dark teal
76
+ "promoter_Tissue_specific": "#C0392B", # Dark red
77
+ "promoter_Tissue_invariant": "#27AE60", # Dark green
78
+ "5UTR+": "#8E44AD", # Dark purple
79
+ "5UTR-": "#D68910", # Dark orange 2
80
+ "3UTR+": "#138D75", # Dark teal 2
81
+ "3UTR-": "#2874A6", # Dark blue
82
+ "skipped_exon": "#7D3C98", # Purple 2
83
+ "always_on_exon": "#A93226", # Red 2
84
+ "start_codon": "#196F3D", # Green 2
85
+ "stop_codon": "#B9770E", # Brown
86
+ "ORF": "#1F618D", # Blue 2
87
  }
88
 
 
89
  def _sanitize_dna(seq: str) -> str:
90
  seq = seq.upper()
91
  return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
92
 
93
 
94
def _get_dna_sequence(assembly: str, chrom: str, start: int, end: int, timeout: float = 60.0) -> str:
    """Fetch a genomic window from the UCSC REST API and return it upper-cased.

    Args:
        assembly: Value sent as the ``genome`` query parameter.
        chrom: Value sent as the ``chrom`` query parameter.
        start: Value sent unchanged as the ``start`` query parameter.
        end: Value sent unchanged as the ``end`` query parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``dna`` field of the JSON response, upper-cased.

    Raises:
        ImportError: If ``requests`` could not be imported.
        ValueError: If the response is not JSON or has no ``dna`` field
            (for example when the API reports an error for the request).
    """
    if requests is None:
        raise ImportError("requests is required for genome download. Install with: pip install requests")

    # The UCSC endpoint is queried with ';'-separated arguments, as in the original URL.
    url = f"https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}"
    region = f"{assembly} {chrom}:{start}-{end}"

    # Without a timeout a stalled connection would block the caller indefinitely.
    response = requests.get(url, timeout=timeout)

    try:
        payload = response.json()
    except ValueError as err:
        # JSON decoding errors raised by requests derive from ValueError.
        raise ValueError(
            f"UCSC API returned a non-JSON response (HTTP {response.status_code}) for {region}"
        ) from err

    dna = payload.get("dna") if isinstance(payload, dict) else None
    if dna is None:
        # Surface the API's own message instead of an opaque KeyError('dna').
        detail = payload.get("error") if isinstance(payload, dict) else None
        raise ValueError(
            f"UCSC API returned no sequence (HTTP {response.status_code}) for {region}: "
            f"{detail or 'missing dna field in response'}"
        )

    return dna.upper()
 
 
 
 
100
 
101
 
102
  def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Path:
 
120
  f"Either pass fasta_path explicitly, or extend ASSEMBLY_TO_UCSC_FA_GZ."
121
  )
122
 
 
 
 
 
 
123
  import gzip
124
  print(f"Decompressing {gz_path} -> {fa_path}")
125
  with gzip.open(gz_path, "rb") as fin, open(fa_path, "wb") as fout:
 
131
 
132
  return fa_path
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def _pick_device(device: Union[str, int, torch.device]) -> torch.device:
135
  # Handle torch.device objects
136
  if isinstance(device, torch.device):
 
181
  ):
182
  if plt is None:
183
  raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
 
 
184
 
185
  n = len(tracks)
186
  if n == 0:
 
193
  any_track = next(iter(tracks.values()))
194
  x = np.linspace(start, end, num=len(any_track), endpoint=False)
195
 
196
+ # Define color schemes
197
+ # BigWig tracks: use blue/gray tones
198
+ bigwig_color = "#4A90E2" # Blue
199
+
200
  for ax, (title, y) in zip(axes, tracks.items()):
201
+ # Determine color based on track type
202
+ if title in BED_ELEMENT_COLORS:
203
+ color = BED_ELEMENT_COLORS[title]
204
+ else:
205
+ color = bigwig_color
206
+
207
+ ax.fill_between(x, y, color=color, alpha=0.3, linewidth=0)
208
+ ax.plot(x, y, color=color, linewidth=0.8)
209
+ ax.set_title(title, fontsize=10, loc="left")
210
+ ax.grid(alpha=0.2)
211
+ ax.set_yticks([])
212
+ # minimal "despine"
213
+ ax.spines["top"].set_visible(False)
214
+ ax.spines["right"].set_visible(False)
215
 
216
  label = f"{chrom}:{start}-{end}" if chrom is not None else f"{start}-{end}"
217
  if assembly is not None:
 
266
  self.pred_center_fraction = float(pred_center_fraction)
267
  self.pred_center_offset_fraction = float(pred_center_offset_fraction)
268
 
 
 
 
 
 
 
269
  if isinstance(model, str):
270
  self.config = AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
271
  self.model = AutoModel.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
 
347
  return torch.device("cpu")
348
  return dev
349
 
350
def available_bigwig_track_names(self, species: str | None = None) -> list[str]:
    """
    List the BigWig track IDs the checkpoint config holds for a species' assembly.

    `species` falls back to `self.default_species` when it is empty or None.
    Only the config is consulted; the model is not run.

    Raises ValueError when the species has no known assembly, or when that
    assembly is absent from `self.config.bigwigs_per_file_assembly`.
    """
    resolved_species = species if species else self.default_species

    assembly = SPECIES_TO_ASSEMBLY.get(resolved_species)
    if assembly is None:
        supported = sorted(SPECIES_TO_ASSEMBLY.keys())
        raise ValueError(f"Unknown species={resolved_species}. Supported: {supported}")

    if assembly in self.config.bigwigs_per_file_assembly:
        return list(self.config.bigwigs_per_file_assembly[assembly])

    raise ValueError(
        f"Assembly {assembly} not found in checkpoint config. "
        f"Available: {list(self.config.bigwigs_per_file_assembly.keys())}"
    )
367
+
368
def available_bed_element_names(self) -> List[str]:
    """
    List the BED element names stored on this pipeline (no forward pass).

    Returns a new list; an unset or empty `self.bed_element_names` gives `[]`.
    """
    names = self.bed_element_names
    if not names:
        return []
    return list(names)
373
+
374
  def preprocess(self, inputs: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
375
  species, assembly = self._resolve_species_and_assembly(inputs)
376
 
 
386
  start = int(inputs["start"])
387
  end = int(inputs["end"])
388
  window_len = end - start
389
+ seq = _get_dna_sequence(assembly, chrom, start, end)
390
+ seq = _sanitize_dna(seq)
 
 
391
 
392
  # Tokenize with padding
393
  batch = self.tokenizer([seq], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
requirements.txt CHANGED
@@ -4,3 +4,4 @@ numpy
4
  gradio>=4.0.0
5
  pyfaidx
6
  requests
 
 
4
  gradio>=4.0.0
5
  pyfaidx
6
  requests
7
+ matplotlib