bernardo-de-almeida commited on
Commit
9dd80fe
·
1 Parent(s): eec0acd

feat: make robust to no GPU

Browse files
README.md CHANGED
@@ -11,4 +11,4 @@ pinned: false
11
 
12
  # NTv3 Tracks Demo
13
 
14
- This Space deploys the custom Hugging Face `Pipeline` in `ntv3_tracks_pipeline.py`.
 
11
 
12
  # NTv3 Tracks Demo
13
 
14
+ This Space deploys the custom Hugging Face `Pipeline` in `ntv3_tracks_pipeline.py`.
app.py CHANGED
@@ -4,26 +4,22 @@ import tempfile
4
  import time
5
  import uuid
6
  from pathlib import Path
7
- import torch
8
- import numpy as np
9
- import gradio as gr
10
- import spaces
11
 
12
- # Set matplotlib to use non-interactive backend before importing pyplot
13
- # This is required for Gradio which runs on worker threads
14
  import matplotlib
15
- matplotlib.use('Agg')
16
-
17
  import matplotlib.pyplot as plt
 
 
18
 
 
19
  from ntv3_tracks_pipeline import (
20
- load_ntv3_tracks_pipeline,
21
- BED_ELEMENT_COLORS,
22
  ASSEMBLY_TO_SPECIES,
 
23
  SPECIES_WITH_COORDINATE_SUPPORT,
 
24
  )
25
- from bigwig_export import create_bigwig_zip, _softmax_last
26
 
 
27
 
28
  # -----------------------------
29
  # Env / auth
@@ -36,7 +32,9 @@ HF_TOKEN = (
36
  or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
37
  )
38
  if HF_TOKEN is None:
39
- raise RuntimeError("Missing Hugging Face token. Set NTV3_HF_TOKEN as a Space Secret.")
 
 
40
 
41
  # asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
42
 
@@ -49,6 +47,7 @@ SEARCH_MAX_RESULTS = int(os.environ.get("SEARCH_MAX_RESULTS", "50"))
49
  pipe = None
50
  current_model_id = MODEL_ID
51
 
 
52
  def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
53
  """Load or reload the pipeline with a new model."""
54
  global pipe, current_model_id
@@ -62,6 +61,7 @@ def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
62
  current_model_id = model_id
63
  return pipe
64
 
 
65
  # Load initial pipeline
66
  load_pipeline(MODEL_ID, DEFAULT_SPECIES)
67
 
@@ -73,6 +73,7 @@ load_pipeline(MODEL_ID, DEFAULT_SPECIES)
73
  _t0 = None
74
  _tlast = None
75
 
 
76
  def tprint(msg: str):
77
  "Function to print timing information"
78
  global _t0, _tlast
@@ -87,6 +88,21 @@ def tprint(msg: str):
87
  print(f"[timing] {msg}: {now - _tlast:.3f}s (total {now - _t0:.3f}s)")
88
  _tlast = now
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def _global_stride(L: int, target: int) -> int:
91
  if target <= 0 or L <= target:
92
  return 1
@@ -111,7 +127,7 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]]):
111
  color = BED_ELEMENT_COLORS[title]
112
  else:
113
  color = bigwig_color
114
-
115
  ax.fill_between(x, y, color=color, alpha=0.3, linewidth=0)
116
  ax.plot(x, y, color=color, linewidth=0.8)
117
  ax.set_title(title, fontsize=10, loc="left")
@@ -143,52 +159,52 @@ def _load_track_metadata() -> dict[str, str]:
143
  """Load track metadata from CSV and create display name mapping."""
144
  if _TRACK_METADATA_CACHE:
145
  return _TRACK_METADATA_CACHE
146
-
147
  csv_path = Path(__file__).parent / "data" / "functional_tracks_metadata.csv"
148
  if not csv_path.exists():
149
  return {}
150
-
151
  metadata = {}
152
  try:
153
- with open(csv_path, 'r', encoding='utf-8') as f:
154
  reader = csv.DictReader(f)
155
  for row in reader:
156
- track_id = row['file_id']
157
- tissue = row.get('tissue', '').strip()
158
- assay = row.get('assay', '').strip()
159
- experiment_target = row.get('experiment_target', '').strip()
160
- biosample_type = row.get('biosample_type', '').strip()
161
- strand = row.get('strand', '').strip()
162
-
163
  # Build display name from available fields
164
  parts = []
165
- if biosample_type and biosample_type != 'tissue':
166
  parts.append(biosample_type)
167
  if tissue:
168
  parts.append(tissue)
169
  if assay:
170
  # For RNA-seq, include strand information if available
171
  if strand:
172
- if strand == 'plus':
173
- strand = '+'
174
- elif strand == 'minus':
175
- strand = '-'
176
  parts.append(f"{assay} {strand}")
177
  else:
178
  parts.append(assay)
179
- if experiment_target and experiment_target not in ('none', 'RNA-seq'):
180
  parts.append(experiment_target)
181
-
182
  if parts:
183
  display_name = " - ".join(parts)
184
  else:
185
  display_name = track_id # Fallback to ID if no metadata
186
-
187
  metadata[track_id] = display_name
188
  except Exception as e:
189
  print(f"Warning: Could not load track metadata: {e}")
190
  return {}
191
-
192
  _TRACK_METADATA_CACHE.update(metadata)
193
  return metadata
194
 
@@ -235,7 +251,7 @@ def _get_species_with_bigwigs() -> set[str]:
235
  """Get set of species that have BigWig tracks available in the current model."""
236
  if pipe is None:
237
  return set()
238
-
239
  species_with_bigwigs = set()
240
  for species in ASSEMBLY_TO_SPECIES.values():
241
  if _has_bigwigs(species):
@@ -287,32 +303,38 @@ def search_bigwigs(species: str, query: str, current_selected: list[str]):
287
  if query is None:
288
  query = ""
289
  query_stripped = query.strip()
290
-
291
  # If query is empty, return empty results immediately (don't show all tracks)
292
  if not query_stripped:
293
  displayed_selected = current_selected or []
294
  show_selected = bool(displayed_selected)
295
  return (
296
- gr.update(choices=[], value=[], interactive=True), # empty results, explicitly clear checked state
297
- gr.update(visible=show_selected, choices=displayed_selected, value=displayed_selected), # show ALL selected tracks
 
 
 
 
 
 
298
  )
299
-
300
  names = _get_bigwig_names(species)
301
  # Search in both track IDs and display names
302
  metadata = _load_track_metadata()
303
  query_lower = query_stripped.lower()
304
-
305
  # Show selected tracks section if user is typing or has selections
306
  show_selected = bool(query_stripped) or bool(current_selected)
307
-
308
  # Show ALL selected tracks (not limited to 20)
309
  displayed_selected = current_selected or []
310
-
311
  # Extract track IDs from already selected tracks (to exclude them from results)
312
  selected_track_ids = set()
313
  if current_selected:
314
  selected_track_ids = {_extract_track_id(x) for x in current_selected}
315
-
316
  # Build list of (display_format, track_id) tuples for searching
317
  track_display_pairs = []
318
  for track_id in names:
@@ -322,20 +344,26 @@ def search_bigwigs(species: str, query: str, current_selected: list[str]):
322
  display_name = metadata.get(track_id, track_id)
323
  display_format = _format_track_for_display(track_id)
324
  track_display_pairs.append((display_format, track_id, display_name))
325
-
326
  # Filter by query (search in display name, display format, and track_id)
327
  matching = []
328
  for display_format, track_id, display_name in track_display_pairs:
329
- if (query_lower in track_id.lower() or
330
- query_lower in display_name.lower() or
331
- query_lower in display_format.lower()):
 
 
332
  matching.append(display_format)
333
-
334
  # Limit search results
335
  results = matching[:SEARCH_MAX_RESULTS]
336
  return (
337
- gr.update(choices=results, value=[], interactive=True), # results - limited to SEARCH_MAX_RESULTS, explicitly clear checked state
338
- gr.update(visible=show_selected, choices=displayed_selected, value=displayed_selected), # show ALL selected tracks
 
 
 
 
339
  )
340
 
341
 
@@ -344,16 +372,16 @@ def add_selected(current_selected: list[str], to_add: list[str]):
344
  # Extract track IDs from current selection (in case they're in display format)
345
  cur_ids = [_extract_track_id(x) for x in (current_selected or [])]
346
  cur_display = [_format_track_for_display(tid) for tid in cur_ids]
347
-
348
  # Extract track IDs from items to add
349
  to_add_ids = [_extract_track_id(x) for x in (to_add or [])]
350
-
351
  # Add new track IDs
352
  for tid in to_add_ids:
353
  if tid not in cur_ids:
354
  cur_ids.append(tid)
355
  cur_display.append(_format_track_for_display(tid))
356
-
357
  # Show ALL selected tracks (no limit)
358
  return gr.update(choices=cur_display, value=cur_display) # show all selected tracks
359
 
@@ -371,29 +399,34 @@ def update_coords_on_species_change(species: str):
371
  coords = DEFAULT_COORDS.get(species, DEFAULT_COORDS["human"])
372
  return coords["chrom"], coords["start"], coords["end"]
373
 
 
374
  def reset_on_species_change(species: str):
375
  # Clear results + selected when species changes (avoids mismatched IDs)
376
  try:
377
  track_ids = _get_bigwig_names(species) # warms cache if available
378
  # Format available tracks for display
379
  formatted_tracks = [_format_track_for_display(tid) for tid in track_ids]
380
-
381
  # Get default tracks for this species (filter to what's available)
382
  default_track_ids = [tid for tid in DEFAULT_BIGWIG_TRACKS if tid in track_ids]
383
- default_formatted = [_format_track_for_display(tid) for tid in default_track_ids]
384
-
 
 
385
  # Show selected tracks section if there are default tracks
386
  show_selected = bool(default_formatted)
387
-
388
  return (
389
- gr.update(value=""), # query textbox
390
  gr.update(choices=[], value=[]), # results list
391
- gr.update(choices=formatted_tracks, value=default_formatted, visible=show_selected), # selected list with defaults
 
 
392
  )
393
  except (ValueError, AttributeError):
394
  # Species doesn't have bigwigs, that's okay
395
  return (
396
- gr.update(value=""), # query textbox
397
  gr.update(choices=[], value=[]), # results list
398
  gr.update(choices=[], value=[], visible=False), # selected list (hidden)
399
  )
@@ -402,7 +435,7 @@ def reset_on_species_change(species: str):
402
  # -----------------------------
403
  # Predict
404
  # -----------------------------
405
- @spaces.GPU
406
  def predict(
407
  seq: str,
408
  species: str,
@@ -418,13 +451,13 @@ def predict(
418
  # Debug: verify species is being passed
419
  if not species:
420
  raise gr.Error("Species parameter is missing. Please select a species.")
421
-
422
  # Extract track IDs from display format if needed
423
  bigwig_selected = [_extract_track_id(tid) for tid in bigwig_selected]
424
-
425
  # Determine if using coordinates based on input_type radio button
426
  use_coords = input_type == "Use genomic coordinates"
427
-
428
  if use_coords:
429
  # Check if this species supports coordinate-based fetching
430
  if species not in SPECIES_WITH_COORDINATE_SUPPORT:
@@ -437,7 +470,12 @@ def predict(
437
  raise gr.Error("chrom is required when use_coords=True")
438
  if start is None or end is None or int(end) <= int(start):
439
  raise gr.Error("start/end must be set and end > start when use_coords=True")
440
- inputs = {"chrom": chrom, "start": int(start), "end": int(end), "species": species}
 
 
 
 
 
441
  else:
442
  if not seq or not seq.strip():
443
  raise gr.Error("seq is required when use_coords=False")
@@ -445,7 +483,9 @@ def predict(
445
 
446
  # Verify species is in inputs before calling pipeline
447
  if "species" not in inputs:
448
- raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
 
 
449
 
450
  tprint("inputs prepared")
451
 
@@ -474,12 +514,16 @@ def predict(
474
  # Check if we have any tracks/elements to plot
475
  has_bigwigs = bw is not None and len(bw_names) > 0
476
  has_bed = bed_logits is not None and len(bed_names) > 0
477
-
478
  if not has_bigwigs and not has_bed:
479
- raise gr.Error("No BigWig tracks or BED elements available for this species in the current model.")
480
-
 
 
481
  if not has_bigwigs and bigwig_selected:
482
- raise gr.Error("No BigWig tracks available for this species, but BigWig tracks were selected. Please deselect BigWig tracks or choose a different species.")
 
 
483
 
484
  # Defaults if user picked none
485
  if has_bigwigs and not bigwig_selected:
@@ -495,7 +539,7 @@ def predict(
495
  ]
496
  # Filter to only include tracks that are available for this species/assembly
497
  bigwig_selected = [tid for tid in default_bigwig_tracks if tid in bw_names]
498
-
499
  if (not bed_elements) and bed_names:
500
  default_bed_elements = ["protein_coding_gene", "exon", "intron"]
501
  # Filter to only include elements that are available
@@ -519,7 +563,7 @@ def predict(
519
  L = bed_logits.shape[0]
520
  else:
521
  raise gr.Error("No data available for plotting.")
522
-
523
  stride = _global_stride(L, PLOT_TARGET_POINTS)
524
 
525
  x0 = int(out.pred_start or 0)
@@ -527,7 +571,7 @@ def predict(
527
  x = np.linspace(x0, x1, num=L, endpoint=False)[::stride]
528
 
529
  series: list[tuple[str, np.ndarray]] = []
530
-
531
  # Add BigWig tracks if available and selected
532
  if has_bigwigs and bigwig_selected:
533
  for tid in bigwig_selected:
@@ -545,7 +589,9 @@ def predict(
545
  fig = _make_tracks_figure(x, series)
546
  tprint("figure created")
547
 
548
- region = f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
 
 
549
  if out.assembly:
550
  region += f" ({out.assembly})"
551
  fig.axes[-1].set_xlabel(region)
@@ -846,9 +892,13 @@ DEFAULT_BED_ELEMENTS = ["protein_coding_gene", "exon", "intron"]
846
 
847
  # Get available BigWig tracks for default species and filter defaults
848
  _init_bigwig = _get_bigwig_names(DEFAULT_SPECIES)
849
- _init_bigwig_selected_ids = [tid for tid in DEFAULT_BIGWIG_TRACKS if tid in _init_bigwig]
 
 
850
  # Format for display
851
- _init_bigwig_selected = [_format_track_for_display(tid) for tid in _init_bigwig_selected_ids]
 
 
852
 
853
  # Filter default BED elements to only those available
854
  _init_bed_selected = [elem for elem in DEFAULT_BED_ELEMENTS if elem in _init_bed]
@@ -864,11 +914,13 @@ DEFAULT_COORDS = {
864
  # Get default coordinates for default species
865
  _default_coords = DEFAULT_COORDS.get(DEFAULT_SPECIES, DEFAULT_COORDS["human"])
866
 
 
867
  # Format species names for display (replace underscores with spaces, capitalize)
868
  def _format_species_name(species: str) -> str:
869
  """Format species name for display."""
870
  return species.replace("_", " ").title()
871
 
 
872
  # Get all available species and format them
873
  _all_species = sorted(ASSEMBLY_TO_SPECIES.values())
874
  _all_species_formatted = [_format_species_name(s) for s in _all_species]
@@ -876,12 +928,18 @@ _all_species_list = ", ".join(_all_species_formatted)
876
 
877
  # Get species with BigWig tracks
878
  _species_with_bigwigs = _get_species_with_bigwigs()
879
- _bigwig_species_formatted = sorted([_format_species_name(s) for s in _species_with_bigwigs])
880
- _bigwig_species_list = ", ".join(_bigwig_species_formatted) if _bigwig_species_formatted else "None (BED elements only)"
 
 
 
 
 
 
881
 
882
  with gr.Blocks(title="NTv3 Tracks Demo") as demo:
883
  gr.Markdown(
884
- f"""
885
  <div class="intro-hero">
886
 
887
  <div class="intro-title">
@@ -933,34 +991,33 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
933
 
934
  </div>
935
  """,
936
- elem_id="intro_markdown",
937
- )
938
-
939
 
940
  gr.Markdown("## Select NTv3 post-trained model")
941
-
942
  # Model display names (without InstaDeepAI/ prefix) and their full IDs
943
  MODEL_OPTIONS = {
944
  "NTv3 650M (pos)": "InstaDeepAI/NTv3_650M_pos",
945
  "NTv3 100M (pos)": "InstaDeepAI/NTv3_100M_pos",
946
  }
947
-
948
  # Reverse mapping: full ID -> display name
949
  MODEL_ID_TO_DISPLAY = {v: k for k, v in MODEL_OPTIONS.items()}
950
-
951
  # Get display name for current model
952
  current_display_name = MODEL_ID_TO_DISPLAY.get(current_model_id, "NTv3 100M (pos)")
953
-
954
  model_selector = gr.Dropdown(
955
  choices=list(MODEL_OPTIONS.keys()),
956
  value=current_display_name,
957
  label="Model",
958
  )
959
-
960
  model_status = gr.Markdown("", visible=False)
961
-
962
  gr.Markdown("## Input DNA sequence")
963
-
964
  # Get all available species from the pipeline
965
  all_species = sorted(ASSEMBLY_TO_SPECIES.values())
966
 
@@ -969,35 +1026,47 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
969
  value=DEFAULT_SPECIES,
970
  label="Species",
971
  )
972
-
973
  # Radio buttons for input type selection
974
  is_supported_default = DEFAULT_SPECIES in SPECIES_WITH_COORDINATE_SUPPORT
975
- initial_input_type = "Use genomic coordinates" if is_supported_default else "Enter DNA sequence"
 
 
976
  input_type = gr.Radio(
977
  choices=["Use genomic coordinates", "Enter DNA sequence"],
978
  value=initial_input_type,
979
  label="Input method",
980
  visible=is_supported_default, # Only show if species supports coordinates
981
  )
982
-
983
  # Coordinates section - visible only when "Use genomic coordinates" is selected
984
- with gr.Group(visible=is_supported_default and initial_input_type == "Use genomic coordinates", elem_id="coords_group") as coords_group:
985
- gr.Markdown("**Genomic coordinates** (supported species: " + ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT)) + ")")
 
 
 
 
 
 
 
 
986
  with gr.Row():
987
  chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
988
- start = gr.Number(label="Start", value=_default_coords["start"], precision=0)
 
 
989
  end = gr.Number(label="End", value=_default_coords["end"], precision=0)
990
-
991
  # DNA sequence section - visible only when "Enter DNA sequence" is selected
992
  # Using Textbox directly (not wrapped in Group) to avoid visual border/line
993
  seq = gr.Textbox(
994
- lines=4,
995
- label="Input DNA sequence",
996
  placeholder="ACGT...",
997
  visible=initial_input_type == "Enter DNA sequence",
998
- elem_id="dna_sequence_input"
999
  )
1000
-
1001
  def change_model(display_name: str, species: str):
1002
  """Reload pipeline with new model."""
1003
  try:
@@ -1007,14 +1076,18 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1007
  else:
1008
  # Fallback: assume it's already a model ID or custom value
1009
  model_id = display_name
1010
-
1011
  load_pipeline(model_id, species)
1012
  # Update available tracks/elements
1013
  _get_bigwig_names(species) # warm cache
1014
- return gr.update(value="✅ Model loaded successfully"), gr.update(visible=True)
 
 
1015
  except Exception as e:
1016
- return gr.update(value=f"❌ Error loading model: {str(e)}"), gr.update(visible=True)
1017
-
 
 
1018
  model_selector.change(
1019
  fn=change_model,
1020
  inputs=[model_selector, species],
@@ -1022,7 +1095,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1022
  )
1023
 
1024
  gr.Markdown("## Select functional tracks")
1025
-
1026
  # Button to download tracks metadata
1027
  def get_metadata_file_path():
1028
  """Return path to metadata CSV file for download."""
@@ -1030,7 +1103,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1030
  if csv_path.exists():
1031
  return str(csv_path)
1032
  return None
1033
-
1034
  metadata_file_path = get_metadata_file_path()
1035
  download_metadata_btn = gr.Button(
1036
  "📋 Download metadata for all functional tracks",
@@ -1041,19 +1114,19 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1041
  label="Tracks metadata",
1042
  visible=False,
1043
  )
1044
-
1045
  def download_metadata():
1046
  """Return metadata file for download."""
1047
  if metadata_file_path and Path(metadata_file_path).exists():
1048
  return gr.update(value=metadata_file_path, visible=True)
1049
  return gr.update(visible=False)
1050
-
1051
  download_metadata_btn.click(
1052
  fn=download_metadata,
1053
  inputs=[],
1054
  outputs=[metadata_download_file],
1055
  )
1056
-
1057
  bigwig_no_tracks_msg = gr.Markdown(
1058
  "⚠️ No functional genomic tracks available for this species in the current model.",
1059
  visible=False,
@@ -1063,7 +1136,9 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1063
  choices=_init_bigwig_selected,
1064
  value=_init_bigwig_selected,
1065
  label="Selected functional tracks (used for prediction)",
1066
- visible=bool(_init_bigwig_selected), # Show if there are default tracks, otherwise hidden
 
 
1067
  )
1068
 
1069
  bigwig_query = gr.Textbox(
@@ -1081,7 +1156,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1081
  bigwig_remove_btn = gr.Button("Remove all selected")
1082
 
1083
  gr.Markdown("## Select genome annotation elements")
1084
-
1085
  bed_elements = gr.Dropdown(
1086
  choices=_init_bed,
1087
  value=_init_bed_selected if _init_bed_selected else [],
@@ -1092,17 +1167,21 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1092
  btn = gr.Button("Predict", elem_id="predict_btn")
1093
 
1094
  gr.Markdown("## NTv3 predictions for selected tracks and elements")
1095
- gr.Markdown("Note: NTv3 predictions are for the 37.5% center of the input sequence.")
1096
-
 
 
1097
  plot = gr.Plot(label="", elem_id="tracks_plot")
1098
  export_png = gr.File(elem_id="export_png_hidden", interactive=False)
1099
-
1100
  # State to store prediction output and selections for BigWig export
1101
  prediction_state = gr.State(value=None)
1102
  bigwig_selected_state = gr.State(value=[])
1103
  bed_elements_state = gr.State(value=[])
1104
-
1105
- download_bigwig_btn = gr.Button("📥 Download tracks as BigWig files (ZIP)", variant="secondary")
 
 
1106
  export_bigwig = gr.File(label="Download BigWig files", visible=False)
1107
 
1108
  with gr.Accordion("Meta (click to expand)", open=False):
@@ -1124,24 +1203,26 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1124
  )
1125
 
1126
  # Helper function to get search results choices directly (without gr.update wrapper)
1127
- def _get_search_results_choices(species: str, query: str, current_selected: list[str]) -> list[str]:
 
 
1128
  """Get search results choices as a list, excluding selected tracks."""
1129
  if query is None:
1130
  query = ""
1131
  query_stripped = query.strip()
1132
-
1133
  if not query_stripped:
1134
  return []
1135
-
1136
  names = _get_bigwig_names(species)
1137
  metadata = _load_track_metadata()
1138
  query_lower = query_stripped.lower()
1139
-
1140
  # Extract track IDs from already selected tracks
1141
  selected_track_ids = set()
1142
  if current_selected:
1143
  selected_track_ids = {_extract_track_id(x) for x in current_selected}
1144
-
1145
  # Build and filter results
1146
  matching = []
1147
  for track_id in names:
@@ -1149,46 +1230,70 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1149
  continue
1150
  display_name = metadata.get(track_id, track_id)
1151
  display_format = _format_track_for_display(track_id)
1152
- if (query_lower in track_id.lower() or
1153
- query_lower in display_name.lower() or
1154
- query_lower in display_format.lower()):
 
 
1155
  matching.append(display_format)
1156
-
1157
  return matching[:SEARCH_MAX_RESULTS]
1158
-
1159
  # Auto-add: whenever user checks items in results, add them to Selected,
1160
  # then clear results selection (so it feels like "click to add")
1161
- def _auto_add(selected_now: list[str], results_checked: list[str], current_query: str, current_results: list[str], current_species: str):
 
 
 
 
 
 
1162
  upd = add_selected(selected_now, results_checked) # reuses your function
1163
  # Show selected tracks section if there are selections
1164
  show_selected = bool(upd["value"])
1165
-
1166
  # Get the new search results choices directly (excluding all selected tracks)
1167
- new_choices = _get_search_results_choices(current_species, current_query, upd["value"])
1168
-
 
 
1169
  # Create a completely fresh update with explicit empty value to prevent any checked state
1170
  # Force Gradio to clear checked state by explicitly setting value to empty list
1171
  # Use a workaround: set choices to empty first, then to new_choices to force a complete refresh
1172
  # But since we can only return one update, we'll ensure value is explicitly empty
1173
  # and that we're not preserving any state from the previous update
1174
-
1175
  # Ensure no items from results_checked are in new_choices (they should already be filtered, but double-check)
1176
  checked_track_ids = {_extract_track_id(x) for x in results_checked}
1177
- new_choices_filtered = [c for c in new_choices if _extract_track_id(c) not in checked_track_ids]
1178
-
 
 
1179
  # Create update with explicit empty value - this should force Gradio to clear all checked items
1180
  fresh_update = gr.update(
1181
  choices=new_choices_filtered,
1182
  value=[], # CRITICAL: Explicitly empty list to clear all checked state
1183
  )
1184
-
1185
  return gr.update(**upd, visible=show_selected), fresh_update
1186
 
1187
  # Use a wrapper that ensures results are cleared before updating
1188
- def _auto_add_wrapper(selected_now: list[str], results_checked: list[str], current_query: str, current_results: list[str], current_species: str):
 
 
 
 
 
 
1189
  # First, get the updates
1190
- selected_update, results_update = _auto_add(selected_now, results_checked, current_query, current_results, current_species)
1191
-
 
 
 
 
 
 
1192
  # Force the results update to have an explicit empty value
1193
  # Extract choices from results_update if it's a dict-like object
1194
  if isinstance(results_update, dict):
@@ -1197,21 +1302,26 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1197
  # If it's a gr.update object, we need to access it differently
1198
  # Try to get choices from the update
1199
  try:
1200
- results_choices = results_update.choices if hasattr(results_update, 'choices') else []
 
 
1201
  except:
1202
  # Fallback: get choices from the search function directly
1203
  results_choices = _get_search_results_choices(
1204
- current_species,
1205
- current_query,
1206
- selected_now + results_checked if isinstance(selected_now, list) and isinstance(results_checked, list) else []
 
 
 
1207
  )
1208
-
1209
  # Create a completely fresh update with explicit empty value
1210
  # This should force Gradio to clear all checked items
1211
  fresh_results_update = gr.update(choices=results_choices, value=[])
1212
-
1213
  return selected_update, fresh_results_update
1214
-
1215
  bigwig_results.change(
1216
  fn=_auto_add_wrapper,
1217
  inputs=[bigwig_selected, bigwig_results, bigwig_query, bigwig_results, species],
@@ -1219,20 +1329,24 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1219
  )
1220
 
1221
  # Update selected tracks immediately when user unchecks items
1222
- def _update_selected_tracks(selected_value: list[str], current_query: str, current_species: str):
 
 
1223
  """Update selected tracks when user checks/unchecks items directly."""
1224
  # selected_value contains only the currently checked items
1225
  # Update choices to match the current selections (so unchecked items are removed)
1226
  show_selected = bool(selected_value)
1227
-
1228
  # Also update search results to reflect the new selection (tracks that were unchecked can now appear in results)
1229
  search_updates = search_bigwigs(current_species, current_query, selected_value)
1230
-
1231
  return (
1232
- gr.update(choices=selected_value, value=selected_value, visible=show_selected), # Update selected tracks
 
 
1233
  search_updates[0], # Update search results
1234
  )
1235
-
1236
  bigwig_selected.change(
1237
  fn=_update_selected_tracks,
1238
  inputs=[bigwig_selected, bigwig_query, species],
@@ -1261,7 +1375,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1261
  inputs=[species],
1262
  outputs=[bigwig_query, bigwig_results, bigwig_selected],
1263
  )
1264
-
1265
  # Update coordinates visibility and values when species changes
1266
  def update_on_species_change(species: str, input_type_val: str):
1267
  """Update coordinates visibility and values when species changes."""
@@ -1272,15 +1386,19 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1272
  use_coords = input_type_val == "Use genomic coordinates"
1273
  show_coords = is_supported and use_coords
1274
  show_seq = not show_coords
1275
-
1276
  # Format available tracks for display if species has bigwigs
1277
  if has_bigwigs:
1278
  try:
1279
  track_ids = _get_bigwig_names(species)
1280
  formatted_tracks = [_format_track_for_display(tid) for tid in track_ids]
1281
  # Get default tracks for this species (filter to what's available)
1282
- default_track_ids = [tid for tid in DEFAULT_BIGWIG_TRACKS if tid in track_ids]
1283
- default_formatted = [_format_track_for_display(tid) for tid in default_track_ids]
 
 
 
 
1284
  # Show selected tracks section if there are default tracks
1285
  show_selected_tracks = bool(default_formatted)
1286
  except:
@@ -1291,29 +1409,42 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1291
  formatted_tracks = []
1292
  default_formatted = []
1293
  show_selected_tracks = False
1294
-
1295
  return (
1296
  gr.update(visible=show_coords, value=coords["chrom"]),
1297
  gr.update(visible=show_coords, value=coords["start"]),
1298
  gr.update(visible=show_coords, value=coords["end"]),
1299
- gr.update(visible=is_supported, value="Use genomic coordinates" if is_supported else "Enter DNA sequence"), # Update input_type radio
 
 
 
 
 
1300
  gr.update(visible=show_coords), # Show/hide coords_group
1301
- gr.update(visible=show_seq), # Show/hide seq
1302
- gr.update(visible=not has_bigwigs), # Show "no tracks" message if no bigwigs
1303
- gr.update(visible=show_selected_tracks, choices=formatted_tracks, value=default_formatted), # Show bigwig selection with defaults if available
 
 
 
 
 
 
1304
  gr.update(visible=has_bigwigs), # Show bigwig query if available
1305
  gr.update(visible=has_bigwigs), # Show bigwig results if available
1306
  gr.update(visible=has_bigwigs), # Show bigwig buttons if available
1307
  )
1308
-
1309
  # Update input type radio visibility and value when species changes
1310
  def update_input_type_on_species_change(species: str):
1311
  """Update input type radio when species changes."""
1312
  is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
1313
  # If species doesn't support coordinates, default to sequence input
1314
- default_value = "Use genomic coordinates" if is_supported else "Enter DNA sequence"
 
 
1315
  return gr.update(visible=is_supported, value=default_value)
1316
-
1317
  # Update input visibility when radio button changes
1318
  def update_input_visibility(input_type_val: str, species: str):
1319
  """Update input visibility when radio button changes."""
@@ -1321,15 +1452,21 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1321
  if input_type_val == "Enter DNA sequence":
1322
  # Hide coordinates, show sequence
1323
  return (
1324
- gr.update(visible=False), # coords_group - always hide when sequence is selected
1325
- gr.update(visible=True), # seq - always show when sequence is selected
 
 
1326
  )
1327
  elif input_type_val == "Use genomic coordinates":
1328
  # Show coordinates only if species supports it
1329
  is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
1330
  return (
1331
- gr.update(visible=is_supported), # coords_group - show only if supported
1332
- gr.update(visible=not is_supported), # seq - hide when coordinates are shown
 
 
 
 
1333
  )
1334
  else:
1335
  # Fallback: hide both (shouldn't happen)
@@ -1337,22 +1474,31 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1337
  gr.update(visible=False),
1338
  gr.update(visible=False),
1339
  )
1340
-
1341
  species.change(
1342
  fn=update_input_type_on_species_change,
1343
  inputs=[species],
1344
  outputs=[input_type],
1345
  )
1346
-
1347
  species.change(
1348
  fn=update_on_species_change,
1349
  inputs=[species, input_type],
1350
  outputs=[
1351
- chrom, start, end, input_type, coords_group, seq,
1352
- bigwig_no_tracks_msg, bigwig_selected, bigwig_query, bigwig_results, bigwig_buttons_row
 
 
 
 
 
 
 
 
 
1353
  ],
1354
  )
1355
-
1356
  input_type.change(
1357
  fn=update_input_visibility,
1358
  inputs=[input_type, species],
@@ -1361,21 +1507,39 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1361
 
1362
  btn.click(
1363
  fn=predict,
1364
- inputs=[seq, species, chrom, start, end, input_type, bigwig_selected, bed_elements],
1365
- outputs=[plot, export_png, meta, prediction_state, bigwig_selected_state, bed_elements_state],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1366
  api_name="predict",
1367
  )
1368
-
1369
  def download_bigwig_zip(out, bw_selected, bed_selected):
1370
  """Create and return BigWig zip file."""
1371
  try:
1372
  zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
1373
  return gr.update(value=zip_path, visible=True)
1374
  except ImportError as e:
1375
- raise gr.Error("pyBigWig is required for BigWig export. Install with: pip install pyBigWig")
 
 
1376
  except Exception as e:
1377
  raise gr.Error(f"Error creating BigWig files: {str(e)}")
1378
-
1379
  download_bigwig_btn.click(
1380
  fn=download_bigwig_zip,
1381
  inputs=[prediction_state, bigwig_selected_state, bed_elements_state],
@@ -1392,4 +1556,3 @@ if __name__ == "__main__":
1392
  css=CSS,
1393
  js=JS,
1394
  )
1395
-
 
4
  import time
5
  import uuid
6
  from pathlib import Path
 
 
 
 
7
 
8
+ import gradio as gr
 
9
  import matplotlib
 
 
10
  import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
 
14
+ from bigwig_export import _softmax_last, create_bigwig_zip
15
  from ntv3_tracks_pipeline import (
 
 
16
  ASSEMBLY_TO_SPECIES,
17
+ BED_ELEMENT_COLORS,
18
  SPECIES_WITH_COORDINATE_SUPPORT,
19
+ load_ntv3_tracks_pipeline,
20
  )
 
21
 
22
+ matplotlib.use("Agg")
23
 
24
  # -----------------------------
25
  # Env / auth
 
32
  or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
33
  )
34
  if HF_TOKEN is None:
35
+ raise RuntimeError(
36
+ "Missing Hugging Face token. Set NTV3_HF_TOKEN as a Space Secret."
37
+ )
38
 
39
  # asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
40
 
 
47
  pipe = None
48
  current_model_id = MODEL_ID
49
 
50
+
51
  def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
52
  """Load or reload the pipeline with a new model."""
53
  global pipe, current_model_id
 
61
  current_model_id = model_id
62
  return pipe
63
 
64
+
65
  # Load initial pipeline
66
  load_pipeline(MODEL_ID, DEFAULT_SPECIES)
67
 
 
73
  _t0 = None
74
  _tlast = None
75
 
76
+
77
  def tprint(msg: str):
78
  "Function to print timing information"
79
  global _t0, _tlast
 
88
  print(f"[timing] {msg}: {now - _tlast:.3f}s (total {now - _t0:.3f}s)")
89
  _tlast = now
90
 
91
+
92
# GPU decorator: use `spaces.GPU` when the `spaces` package can be imported,
# otherwise fall back to a no-op so the app still starts without it.
try:
    import spaces

    gpu = spaces.GPU
except Exception:  # kept broad on purpose: any import-time failure means "no GPU decorator"

    def gpu(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Works in both decorator forms:

        * bare, ``@gpu`` -- the decorated function is returned unchanged
          (this is how ``predict`` is decorated in this file);
        * called, ``@gpu()`` / ``@gpu(duration=...)`` -- the arguments are
          ignored and an identity decorator is returned.
        """
        # Bare usage passes the decorated callable as the only argument.
        # Returning `wrap` here would replace the function with the
        # decorator itself and break every later call to it.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def wrap(fn):
            return fn

        return wrap
104
+
105
+
106
  def _global_stride(L: int, target: int) -> int:
107
  if target <= 0 or L <= target:
108
  return 1
 
127
  color = BED_ELEMENT_COLORS[title]
128
  else:
129
  color = bigwig_color
130
+
131
  ax.fill_between(x, y, color=color, alpha=0.3, linewidth=0)
132
  ax.plot(x, y, color=color, linewidth=0.8)
133
  ax.set_title(title, fontsize=10, loc="left")
 
159
  """Load track metadata from CSV and create display name mapping."""
160
  if _TRACK_METADATA_CACHE:
161
  return _TRACK_METADATA_CACHE
162
+
163
  csv_path = Path(__file__).parent / "data" / "functional_tracks_metadata.csv"
164
  if not csv_path.exists():
165
  return {}
166
+
167
  metadata = {}
168
  try:
169
+ with open(csv_path, encoding="utf-8") as f:
170
  reader = csv.DictReader(f)
171
  for row in reader:
172
+ track_id = row["file_id"]
173
+ tissue = row.get("tissue", "").strip()
174
+ assay = row.get("assay", "").strip()
175
+ experiment_target = row.get("experiment_target", "").strip()
176
+ biosample_type = row.get("biosample_type", "").strip()
177
+ strand = row.get("strand", "").strip()
178
+
179
  # Build display name from available fields
180
  parts = []
181
+ if biosample_type and biosample_type != "tissue":
182
  parts.append(biosample_type)
183
  if tissue:
184
  parts.append(tissue)
185
  if assay:
186
  # For RNA-seq, include strand information if available
187
  if strand:
188
+ if strand == "plus":
189
+ strand = "+"
190
+ elif strand == "minus":
191
+ strand = "-"
192
  parts.append(f"{assay} {strand}")
193
  else:
194
  parts.append(assay)
195
+ if experiment_target and experiment_target not in ("none", "RNA-seq"):
196
  parts.append(experiment_target)
197
+
198
  if parts:
199
  display_name = " - ".join(parts)
200
  else:
201
  display_name = track_id # Fallback to ID if no metadata
202
+
203
  metadata[track_id] = display_name
204
  except Exception as e:
205
  print(f"Warning: Could not load track metadata: {e}")
206
  return {}
207
+
208
  _TRACK_METADATA_CACHE.update(metadata)
209
  return metadata
210
 
 
251
  """Get set of species that have BigWig tracks available in the current model."""
252
  if pipe is None:
253
  return set()
254
+
255
  species_with_bigwigs = set()
256
  for species in ASSEMBLY_TO_SPECIES.values():
257
  if _has_bigwigs(species):
 
303
  if query is None:
304
  query = ""
305
  query_stripped = query.strip()
306
+
307
  # If query is empty, return empty results immediately (don't show all tracks)
308
  if not query_stripped:
309
  displayed_selected = current_selected or []
310
  show_selected = bool(displayed_selected)
311
  return (
312
+ gr.update(
313
+ choices=[], value=[], interactive=True
314
+ ), # empty results, explicitly clear checked state
315
+ gr.update(
316
+ visible=show_selected,
317
+ choices=displayed_selected,
318
+ value=displayed_selected,
319
+ ), # show ALL selected tracks
320
  )
321
+
322
  names = _get_bigwig_names(species)
323
  # Search in both track IDs and display names
324
  metadata = _load_track_metadata()
325
  query_lower = query_stripped.lower()
326
+
327
  # Show selected tracks section if user is typing or has selections
328
  show_selected = bool(query_stripped) or bool(current_selected)
329
+
330
  # Show ALL selected tracks (not limited to 20)
331
  displayed_selected = current_selected or []
332
+
333
  # Extract track IDs from already selected tracks (to exclude them from results)
334
  selected_track_ids = set()
335
  if current_selected:
336
  selected_track_ids = {_extract_track_id(x) for x in current_selected}
337
+
338
  # Build list of (display_format, track_id) tuples for searching
339
  track_display_pairs = []
340
  for track_id in names:
 
344
  display_name = metadata.get(track_id, track_id)
345
  display_format = _format_track_for_display(track_id)
346
  track_display_pairs.append((display_format, track_id, display_name))
347
+
348
  # Filter by query (search in display name, display format, and track_id)
349
  matching = []
350
  for display_format, track_id, display_name in track_display_pairs:
351
+ if (
352
+ query_lower in track_id.lower()
353
+ or query_lower in display_name.lower()
354
+ or query_lower in display_format.lower()
355
+ ):
356
  matching.append(display_format)
357
+
358
  # Limit search results
359
  results = matching[:SEARCH_MAX_RESULTS]
360
  return (
361
+ gr.update(
362
+ choices=results, value=[], interactive=True
363
+ ), # results - limited to SEARCH_MAX_RESULTS, explicitly clear checked state
364
+ gr.update(
365
+ visible=show_selected, choices=displayed_selected, value=displayed_selected
366
+ ), # show ALL selected tracks
367
  )
368
 
369
 
 
372
  # Extract track IDs from current selection (in case they're in display format)
373
  cur_ids = [_extract_track_id(x) for x in (current_selected or [])]
374
  cur_display = [_format_track_for_display(tid) for tid in cur_ids]
375
+
376
  # Extract track IDs from items to add
377
  to_add_ids = [_extract_track_id(x) for x in (to_add or [])]
378
+
379
  # Add new track IDs
380
  for tid in to_add_ids:
381
  if tid not in cur_ids:
382
  cur_ids.append(tid)
383
  cur_display.append(_format_track_for_display(tid))
384
+
385
  # Show ALL selected tracks (no limit)
386
  return gr.update(choices=cur_display, value=cur_display) # show all selected tracks
387
 
 
399
  coords = DEFAULT_COORDS.get(species, DEFAULT_COORDS["human"])
400
  return coords["chrom"], coords["start"], coords["end"]
401
 
402
+
403
  def reset_on_species_change(species: str):
404
  # Clear results + selected when species changes (avoids mismatched IDs)
405
  try:
406
  track_ids = _get_bigwig_names(species) # warms cache if available
407
  # Format available tracks for display
408
  formatted_tracks = [_format_track_for_display(tid) for tid in track_ids]
409
+
410
  # Get default tracks for this species (filter to what's available)
411
  default_track_ids = [tid for tid in DEFAULT_BIGWIG_TRACKS if tid in track_ids]
412
+ default_formatted = [
413
+ _format_track_for_display(tid) for tid in default_track_ids
414
+ ]
415
+
416
  # Show selected tracks section if there are default tracks
417
  show_selected = bool(default_formatted)
418
+
419
  return (
420
+ gr.update(value=""), # query textbox
421
  gr.update(choices=[], value=[]), # results list
422
+ gr.update(
423
+ choices=formatted_tracks, value=default_formatted, visible=show_selected
424
+ ), # selected list with defaults
425
  )
426
  except (ValueError, AttributeError):
427
  # Species doesn't have bigwigs, that's okay
428
  return (
429
+ gr.update(value=""), # query textbox
430
  gr.update(choices=[], value=[]), # results list
431
  gr.update(choices=[], value=[], visible=False), # selected list (hidden)
432
  )
 
435
  # -----------------------------
436
  # Predict
437
  # -----------------------------
438
+ @gpu
439
  def predict(
440
  seq: str,
441
  species: str,
 
451
  # Debug: verify species is being passed
452
  if not species:
453
  raise gr.Error("Species parameter is missing. Please select a species.")
454
+
455
  # Extract track IDs from display format if needed
456
  bigwig_selected = [_extract_track_id(tid) for tid in bigwig_selected]
457
+
458
  # Determine if using coordinates based on input_type radio button
459
  use_coords = input_type == "Use genomic coordinates"
460
+
461
  if use_coords:
462
  # Check if this species supports coordinate-based fetching
463
  if species not in SPECIES_WITH_COORDINATE_SUPPORT:
 
470
  raise gr.Error("chrom is required when use_coords=True")
471
  if start is None or end is None or int(end) <= int(start):
472
  raise gr.Error("start/end must be set and end > start when use_coords=True")
473
+ inputs = {
474
+ "chrom": chrom,
475
+ "start": int(start),
476
+ "end": int(end),
477
+ "species": species,
478
+ }
479
  else:
480
  if not seq or not seq.strip():
481
  raise gr.Error("seq is required when use_coords=False")
 
483
 
484
  # Verify species is in inputs before calling pipeline
485
  if "species" not in inputs:
486
+ raise gr.Error(
487
+ f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}"
488
+ )
489
 
490
  tprint("inputs prepared")
491
 
 
514
  # Check if we have any tracks/elements to plot
515
  has_bigwigs = bw is not None and len(bw_names) > 0
516
  has_bed = bed_logits is not None and len(bed_names) > 0
517
+
518
  if not has_bigwigs and not has_bed:
519
+ raise gr.Error(
520
+ "No BigWig tracks or BED elements available for this species in the current model."
521
+ )
522
+
523
  if not has_bigwigs and bigwig_selected:
524
+ raise gr.Error(
525
+ "No BigWig tracks available for this species, but BigWig tracks were selected. Please deselect BigWig tracks or choose a different species."
526
+ )
527
 
528
  # Defaults if user picked none
529
  if has_bigwigs and not bigwig_selected:
 
539
  ]
540
  # Filter to only include tracks that are available for this species/assembly
541
  bigwig_selected = [tid for tid in default_bigwig_tracks if tid in bw_names]
542
+
543
  if (not bed_elements) and bed_names:
544
  default_bed_elements = ["protein_coding_gene", "exon", "intron"]
545
  # Filter to only include elements that are available
 
563
  L = bed_logits.shape[0]
564
  else:
565
  raise gr.Error("No data available for plotting.")
566
+
567
  stride = _global_stride(L, PLOT_TARGET_POINTS)
568
 
569
  x0 = int(out.pred_start or 0)
 
571
  x = np.linspace(x0, x1, num=L, endpoint=False)[::stride]
572
 
573
  series: list[tuple[str, np.ndarray]] = []
574
+
575
  # Add BigWig tracks if available and selected
576
  if has_bigwigs and bigwig_selected:
577
  for tid in bigwig_selected:
 
589
  fig = _make_tracks_figure(x, series)
590
  tprint("figure created")
591
 
592
+ region = (
593
+ f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
594
+ )
595
  if out.assembly:
596
  region += f" ({out.assembly})"
597
  fig.axes[-1].set_xlabel(region)
 
892
 
893
  # Get available BigWig tracks for default species and filter defaults
894
  _init_bigwig = _get_bigwig_names(DEFAULT_SPECIES)
895
+ _init_bigwig_selected_ids = [
896
+ tid for tid in DEFAULT_BIGWIG_TRACKS if tid in _init_bigwig
897
+ ]
898
  # Format for display
899
+ _init_bigwig_selected = [
900
+ _format_track_for_display(tid) for tid in _init_bigwig_selected_ids
901
+ ]
902
 
903
  # Filter default BED elements to only those available
904
  _init_bed_selected = [elem for elem in DEFAULT_BED_ELEMENTS if elem in _init_bed]
 
914
  # Get default coordinates for default species
915
  _default_coords = DEFAULT_COORDS.get(DEFAULT_SPECIES, DEFAULT_COORDS["human"])
916
 
917
+
918
  # Format species names for display (replace underscores with spaces, capitalize)
919
  def _format_species_name(species: str) -> str:
920
  """Format species name for display."""
921
  return species.replace("_", " ").title()
922
 
923
+
924
  # Get all available species and format them
925
  _all_species = sorted(ASSEMBLY_TO_SPECIES.values())
926
  _all_species_formatted = [_format_species_name(s) for s in _all_species]
 
928
 
929
  # Get species with BigWig tracks
930
  _species_with_bigwigs = _get_species_with_bigwigs()
931
+ _bigwig_species_formatted = sorted(
932
+ [_format_species_name(s) for s in _species_with_bigwigs]
933
+ )
934
+ _bigwig_species_list = (
935
+ ", ".join(_bigwig_species_formatted)
936
+ if _bigwig_species_formatted
937
+ else "None (BED elements only)"
938
+ )
939
 
940
  with gr.Blocks(title="NTv3 Tracks Demo") as demo:
941
  gr.Markdown(
942
+ f"""
943
  <div class="intro-hero">
944
 
945
  <div class="intro-title">
 
991
 
992
  </div>
993
  """,
994
+ elem_id="intro_markdown",
995
+ )
 
996
 
997
  gr.Markdown("## Select NTv3 post-trained model")
998
+
999
  # Model display names (without InstaDeepAI/ prefix) and their full IDs
1000
  MODEL_OPTIONS = {
1001
  "NTv3 650M (pos)": "InstaDeepAI/NTv3_650M_pos",
1002
  "NTv3 100M (pos)": "InstaDeepAI/NTv3_100M_pos",
1003
  }
1004
+
1005
  # Reverse mapping: full ID -> display name
1006
  MODEL_ID_TO_DISPLAY = {v: k for k, v in MODEL_OPTIONS.items()}
1007
+
1008
  # Get display name for current model
1009
  current_display_name = MODEL_ID_TO_DISPLAY.get(current_model_id, "NTv3 100M (pos)")
1010
+
1011
  model_selector = gr.Dropdown(
1012
  choices=list(MODEL_OPTIONS.keys()),
1013
  value=current_display_name,
1014
  label="Model",
1015
  )
1016
+
1017
  model_status = gr.Markdown("", visible=False)
1018
+
1019
  gr.Markdown("## Input DNA sequence")
1020
+
1021
  # Get all available species from the pipeline
1022
  all_species = sorted(ASSEMBLY_TO_SPECIES.values())
1023
 
 
1026
  value=DEFAULT_SPECIES,
1027
  label="Species",
1028
  )
1029
+
1030
  # Radio buttons for input type selection
1031
  is_supported_default = DEFAULT_SPECIES in SPECIES_WITH_COORDINATE_SUPPORT
1032
+ initial_input_type = (
1033
+ "Use genomic coordinates" if is_supported_default else "Enter DNA sequence"
1034
+ )
1035
  input_type = gr.Radio(
1036
  choices=["Use genomic coordinates", "Enter DNA sequence"],
1037
  value=initial_input_type,
1038
  label="Input method",
1039
  visible=is_supported_default, # Only show if species supports coordinates
1040
  )
1041
+
1042
  # Coordinates section - visible only when "Use genomic coordinates" is selected
1043
+ with gr.Group(
1044
+ visible=is_supported_default
1045
+ and initial_input_type == "Use genomic coordinates",
1046
+ elem_id="coords_group",
1047
+ ) as coords_group:
1048
+ gr.Markdown(
1049
+ "**Genomic coordinates** (supported species: "
1050
+ + ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
1051
+ + ")"
1052
+ )
1053
  with gr.Row():
1054
  chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
1055
+ start = gr.Number(
1056
+ label="Start", value=_default_coords["start"], precision=0
1057
+ )
1058
  end = gr.Number(label="End", value=_default_coords["end"], precision=0)
1059
+
1060
  # DNA sequence section - visible only when "Enter DNA sequence" is selected
1061
  # Using Textbox directly (not wrapped in Group) to avoid visual border/line
1062
  seq = gr.Textbox(
1063
+ lines=4,
1064
+ label="Input DNA sequence",
1065
  placeholder="ACGT...",
1066
  visible=initial_input_type == "Enter DNA sequence",
1067
+ elem_id="dna_sequence_input",
1068
  )
1069
+
1070
  def change_model(display_name: str, species: str):
1071
  """Reload pipeline with new model."""
1072
  try:
 
1076
  else:
1077
  # Fallback: assume it's already a model ID or custom value
1078
  model_id = display_name
1079
+
1080
  load_pipeline(model_id, species)
1081
  # Update available tracks/elements
1082
  _get_bigwig_names(species) # warm cache
1083
+ return gr.update(value="✅ Model loaded successfully"), gr.update(
1084
+ visible=True
1085
+ )
1086
  except Exception as e:
1087
+ return gr.update(value=f"❌ Error loading model: {str(e)}"), gr.update(
1088
+ visible=True
1089
+ )
1090
+
1091
  model_selector.change(
1092
  fn=change_model,
1093
  inputs=[model_selector, species],
 
1095
  )
1096
 
1097
  gr.Markdown("## Select functional tracks")
1098
+
1099
  # Button to download tracks metadata
1100
  def get_metadata_file_path():
1101
  """Return path to metadata CSV file for download."""
 
1103
  if csv_path.exists():
1104
  return str(csv_path)
1105
  return None
1106
+
1107
  metadata_file_path = get_metadata_file_path()
1108
  download_metadata_btn = gr.Button(
1109
  "📋 Download metadata for all functional tracks",
 
1114
  label="Tracks metadata",
1115
  visible=False,
1116
  )
1117
+
1118
  def download_metadata():
1119
  """Return metadata file for download."""
1120
  if metadata_file_path and Path(metadata_file_path).exists():
1121
  return gr.update(value=metadata_file_path, visible=True)
1122
  return gr.update(visible=False)
1123
+
1124
  download_metadata_btn.click(
1125
  fn=download_metadata,
1126
  inputs=[],
1127
  outputs=[metadata_download_file],
1128
  )
1129
+
1130
  bigwig_no_tracks_msg = gr.Markdown(
1131
  "⚠️ No functional genomic tracks available for this species in the current model.",
1132
  visible=False,
 
1136
  choices=_init_bigwig_selected,
1137
  value=_init_bigwig_selected,
1138
  label="Selected functional tracks (used for prediction)",
1139
+ visible=bool(
1140
+ _init_bigwig_selected
1141
+ ), # Show if there are default tracks, otherwise hidden
1142
  )
1143
 
1144
  bigwig_query = gr.Textbox(
 
1156
  bigwig_remove_btn = gr.Button("Remove all selected")
1157
 
1158
  gr.Markdown("## Select genome annotation elements")
1159
+
1160
  bed_elements = gr.Dropdown(
1161
  choices=_init_bed,
1162
  value=_init_bed_selected if _init_bed_selected else [],
 
1167
  btn = gr.Button("Predict", elem_id="predict_btn")
1168
 
1169
  gr.Markdown("## NTv3 predictions for selected tracks and elements")
1170
+ gr.Markdown(
1171
+ "Note: NTv3 predictions are for the 37.5% center of the input sequence."
1172
+ )
1173
+
1174
  plot = gr.Plot(label="", elem_id="tracks_plot")
1175
  export_png = gr.File(elem_id="export_png_hidden", interactive=False)
1176
+
1177
  # State to store prediction output and selections for BigWig export
1178
  prediction_state = gr.State(value=None)
1179
  bigwig_selected_state = gr.State(value=[])
1180
  bed_elements_state = gr.State(value=[])
1181
+
1182
+ download_bigwig_btn = gr.Button(
1183
+ "📥 Download tracks as BigWig files (ZIP)", variant="secondary"
1184
+ )
1185
  export_bigwig = gr.File(label="Download BigWig files", visible=False)
1186
 
1187
  with gr.Accordion("Meta (click to expand)", open=False):
 
1203
  )
1204
 
1205
  # Helper function to get search results choices directly (without gr.update wrapper)
1206
+ def _get_search_results_choices(
1207
+ species: str, query: str, current_selected: list[str]
1208
+ ) -> list[str]:
1209
  """Get search results choices as a list, excluding selected tracks."""
1210
  if query is None:
1211
  query = ""
1212
  query_stripped = query.strip()
1213
+
1214
  if not query_stripped:
1215
  return []
1216
+
1217
  names = _get_bigwig_names(species)
1218
  metadata = _load_track_metadata()
1219
  query_lower = query_stripped.lower()
1220
+
1221
  # Extract track IDs from already selected tracks
1222
  selected_track_ids = set()
1223
  if current_selected:
1224
  selected_track_ids = {_extract_track_id(x) for x in current_selected}
1225
+
1226
  # Build and filter results
1227
  matching = []
1228
  for track_id in names:
 
1230
  continue
1231
  display_name = metadata.get(track_id, track_id)
1232
  display_format = _format_track_for_display(track_id)
1233
+ if (
1234
+ query_lower in track_id.lower()
1235
+ or query_lower in display_name.lower()
1236
+ or query_lower in display_format.lower()
1237
+ ):
1238
  matching.append(display_format)
1239
+
1240
  return matching[:SEARCH_MAX_RESULTS]
1241
+
1242
  # Auto-add: whenever user checks items in results, add them to Selected,
1243
  # then clear results selection (so it feels like "click to add")
1244
def _auto_add(
    selected_now: list[str],
    results_checked: list[str],
    current_query: str,
    current_results: list[str],
    current_species: str,
):
    """Move the tracks ticked in the results list into the selected list.

    Returns a pair of ``gr.update`` payloads: first the selected list (the
    merged selection, visible only when it is non-empty), then the results
    list (choices recomputed for the current query, checked value reset to
    an empty list so nothing stays ticked).

    ``current_results`` is accepted but not read by this function.
    """
    merged = add_selected(selected_now, results_checked)
    chosen = merged["value"]

    # Re-run the search against the updated selection so already-selected
    # tracks are excluded from the offered choices.
    candidates = _get_search_results_choices(current_species, current_query, chosen)

    # Double-check: drop anything that was just ticked, even though the
    # search above should already have excluded it.
    just_ticked = {_extract_track_id(item) for item in results_checked}
    remaining = []
    for candidate in candidates:
        if _extract_track_id(candidate) not in just_ticked:
            remaining.append(candidate)

    # The explicit empty `value` is what clears the checked state.
    results_update = gr.update(choices=remaining, value=[])
    selected_update = gr.update(**merged, visible=bool(chosen))
    return selected_update, results_update
1279
 
1280
  # Use a wrapper that ensures results are cleared before updating
1281
+ def _auto_add_wrapper(
1282
+ selected_now: list[str],
1283
+ results_checked: list[str],
1284
+ current_query: str,
1285
+ current_results: list[str],
1286
+ current_species: str,
1287
+ ):
1288
  # First, get the updates
1289
+ selected_update, results_update = _auto_add(
1290
+ selected_now,
1291
+ results_checked,
1292
+ current_query,
1293
+ current_results,
1294
+ current_species,
1295
+ )
1296
+
1297
  # Force the results update to have an explicit empty value
1298
  # Extract choices from results_update if it's a dict-like object
1299
  if isinstance(results_update, dict):
 
1302
  # If it's a gr.update object, we need to access it differently
1303
  # Try to get choices from the update
1304
  try:
1305
+ results_choices = (
1306
+ results_update.choices if hasattr(results_update, "choices") else []
1307
+ )
1308
  except:
1309
  # Fallback: get choices from the search function directly
1310
  results_choices = _get_search_results_choices(
1311
+ current_species,
1312
+ current_query,
1313
+ selected_now + results_checked
1314
+ if isinstance(selected_now, list)
1315
+ and isinstance(results_checked, list)
1316
+ else [],
1317
  )
1318
+
1319
  # Create a completely fresh update with explicit empty value
1320
  # This should force Gradio to clear all checked items
1321
  fresh_results_update = gr.update(choices=results_choices, value=[])
1322
+
1323
  return selected_update, fresh_results_update
1324
+
1325
  bigwig_results.change(
1326
  fn=_auto_add_wrapper,
1327
  inputs=[bigwig_selected, bigwig_results, bigwig_query, bigwig_results, species],
 
1329
  )
1330
 
1331
  # Update selected tracks immediately when user unchecks items
1332
def _update_selected_tracks(
    selected_value: list[str], current_query: str, current_species: str
):
    """Sync the selected list and the search results after a (un)check.

    ``selected_value`` holds only the items that are still checked, so it is
    used both as the new choices and as the new value: unchecked items
    disappear from the selected list. The search is re-run with that
    selection so tracks that were just unchecked can show up in the results
    again.

    Returns ``(selected list update, results list update)``.
    """
    refreshed_search = search_bigwigs(current_species, current_query, selected_value)
    selected_update = gr.update(
        choices=selected_value,
        value=selected_value,
        visible=bool(selected_value),
    )
    # Only the first element (the results-list update) is forwarded.
    return selected_update, refreshed_search[0]
1349
+
1350
  bigwig_selected.change(
1351
  fn=_update_selected_tracks,
1352
  inputs=[bigwig_selected, bigwig_query, species],
 
1375
  inputs=[species],
1376
  outputs=[bigwig_query, bigwig_results, bigwig_selected],
1377
  )
1378
+
1379
  # Update coordinates visibility and values when species changes
1380
  def update_on_species_change(species: str, input_type_val: str):
1381
  """Update coordinates visibility and values when species changes."""
 
1386
  use_coords = input_type_val == "Use genomic coordinates"
1387
  show_coords = is_supported and use_coords
1388
  show_seq = not show_coords
1389
+
1390
  # Format available tracks for display if species has bigwigs
1391
  if has_bigwigs:
1392
  try:
1393
  track_ids = _get_bigwig_names(species)
1394
  formatted_tracks = [_format_track_for_display(tid) for tid in track_ids]
1395
  # Get default tracks for this species (filter to what's available)
1396
+ default_track_ids = [
1397
+ tid for tid in DEFAULT_BIGWIG_TRACKS if tid in track_ids
1398
+ ]
1399
+ default_formatted = [
1400
+ _format_track_for_display(tid) for tid in default_track_ids
1401
+ ]
1402
  # Show selected tracks section if there are default tracks
1403
  show_selected_tracks = bool(default_formatted)
1404
  except:
 
1409
  formatted_tracks = []
1410
  default_formatted = []
1411
  show_selected_tracks = False
1412
+
1413
  return (
1414
  gr.update(visible=show_coords, value=coords["chrom"]),
1415
  gr.update(visible=show_coords, value=coords["start"]),
1416
  gr.update(visible=show_coords, value=coords["end"]),
1417
+ gr.update(
1418
+ visible=is_supported,
1419
+ value="Use genomic coordinates"
1420
+ if is_supported
1421
+ else "Enter DNA sequence",
1422
+ ), # Update input_type radio
1423
  gr.update(visible=show_coords), # Show/hide coords_group
1424
+ gr.update(visible=show_seq), # Show/hide seq
1425
+ gr.update(
1426
+ visible=not has_bigwigs
1427
+ ), # Show "no tracks" message if no bigwigs
1428
+ gr.update(
1429
+ visible=show_selected_tracks,
1430
+ choices=formatted_tracks,
1431
+ value=default_formatted,
1432
+ ), # Show bigwig selection with defaults if available
1433
  gr.update(visible=has_bigwigs), # Show bigwig query if available
1434
  gr.update(visible=has_bigwigs), # Show bigwig results if available
1435
  gr.update(visible=has_bigwigs), # Show bigwig buttons if available
1436
  )
1437
+
1438
  # Update input type radio visibility and value when species changes
1439
  def update_input_type_on_species_change(species: str):
1440
  """Update input type radio when species changes."""
1441
  is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
1442
  # If species doesn't support coordinates, default to sequence input
1443
+ default_value = (
1444
+ "Use genomic coordinates" if is_supported else "Enter DNA sequence"
1445
+ )
1446
  return gr.update(visible=is_supported, value=default_value)
1447
+
1448
  # Update input visibility when radio button changes
1449
  def update_input_visibility(input_type_val: str, species: str):
1450
  """Update input visibility when radio button changes."""
 
1452
  if input_type_val == "Enter DNA sequence":
1453
  # Hide coordinates, show sequence
1454
  return (
1455
+ gr.update(
1456
+ visible=False
1457
+ ), # coords_group - always hide when sequence is selected
1458
+ gr.update(visible=True), # seq - always show when sequence is selected
1459
  )
1460
  elif input_type_val == "Use genomic coordinates":
1461
  # Show coordinates only if species supports it
1462
  is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
1463
  return (
1464
+ gr.update(
1465
+ visible=is_supported
1466
+ ), # coords_group - show only if supported
1467
+ gr.update(
1468
+ visible=not is_supported
1469
+ ), # seq - hide when coordinates are shown
1470
  )
1471
  else:
1472
  # Fallback: hide both (shouldn't happen)
 
1474
  gr.update(visible=False),
1475
  gr.update(visible=False),
1476
  )
1477
+
1478
  species.change(
1479
  fn=update_input_type_on_species_change,
1480
  inputs=[species],
1481
  outputs=[input_type],
1482
  )
1483
+
1484
  species.change(
1485
  fn=update_on_species_change,
1486
  inputs=[species, input_type],
1487
  outputs=[
1488
+ chrom,
1489
+ start,
1490
+ end,
1491
+ input_type,
1492
+ coords_group,
1493
+ seq,
1494
+ bigwig_no_tracks_msg,
1495
+ bigwig_selected,
1496
+ bigwig_query,
1497
+ bigwig_results,
1498
+ bigwig_buttons_row,
1499
  ],
1500
  )
1501
+
1502
  input_type.change(
1503
  fn=update_input_visibility,
1504
  inputs=[input_type, species],
 
1507
 
1508
  btn.click(
1509
  fn=predict,
1510
+ inputs=[
1511
+ seq,
1512
+ species,
1513
+ chrom,
1514
+ start,
1515
+ end,
1516
+ input_type,
1517
+ bigwig_selected,
1518
+ bed_elements,
1519
+ ],
1520
+ outputs=[
1521
+ plot,
1522
+ export_png,
1523
+ meta,
1524
+ prediction_state,
1525
+ bigwig_selected_state,
1526
+ bed_elements_state,
1527
+ ],
1528
  api_name="predict",
1529
  )
1530
+
1531
  def download_bigwig_zip(out, bw_selected, bed_selected):
1532
  """Create and return BigWig zip file."""
1533
  try:
1534
  zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
1535
  return gr.update(value=zip_path, visible=True)
1536
  except ImportError as e:
1537
+ raise gr.Error(
1538
+ "pyBigWig is required for BigWig export. Install with: pip install pyBigWig"
1539
+ )
1540
  except Exception as e:
1541
  raise gr.Error(f"Error creating BigWig files: {str(e)}")
1542
+
1543
  download_bigwig_btn.click(
1544
  fn=download_bigwig_zip,
1545
  inputs=[prediction_state, bigwig_selected_state, bed_elements_state],
 
1556
  css=CSS,
1557
  js=JS,
1558
  )
 
bigwig_export.py CHANGED
@@ -3,8 +3,8 @@ BigWig export functionality for NTv3 tracks.
3
  """
4
 
5
  import os
6
- import uuid
7
  import tempfile
 
8
  import zipfile
9
  from typing import TYPE_CHECKING
10
 
@@ -33,7 +33,7 @@ def create_bigwig_zip(
33
  ) -> str:
34
  """
35
  Create BigWig files for selected tracks and save them in a zip file.
36
-
37
  Parameters
38
  ----------
39
  out : NTv3TracksOutput
@@ -42,12 +42,12 @@ def create_bigwig_zip(
42
  List of BigWig track IDs to export.
43
  bed_elements : list[str]
44
  List of BED element names to export.
45
-
46
  Returns
47
  -------
48
  str
49
  Path to the created zip file containing BigWig files.
50
-
51
  Raises
52
  ------
53
  ImportError
@@ -56,46 +56,50 @@ def create_bigwig_zip(
56
  If no predictions are available or no tracks are selected.
57
  """
58
  if pyBigWig is None:
59
- raise ImportError("pyBigWig is required for BigWig export. Install with: pip install pyBigWig")
60
-
 
 
61
  if out is None:
62
  raise ValueError("No predictions available. Please run a prediction first.")
63
-
64
  bw_names = out.bigwig_track_names or []
65
  bw_logits = out.bigwig_tracks_logits
66
  bed_names = out.bed_element_names or []
67
  bed_logits = out.bed_tracks_logits
68
-
69
  if bw_logits is None or not bw_names:
70
  raise ValueError("No BigWig tracks available in model output.")
71
-
72
  # Get genomic coordinates
73
  chrom = out.chrom
74
  if chrom is None:
75
- raise ValueError("Chromosome information not available. Use genomic coordinates for BigWig export.")
76
-
 
 
77
  start = out.start
78
  end = out.end
79
  window_len = out.window_len or (end - start)
80
-
81
  # Calculate prediction region (center 37.5%)
82
  pred_start = out.pred_start or (start + int(window_len * 0.3125))
83
  pred_end = out.pred_end or (pred_start + int(window_len * 0.375))
84
-
85
  # Create temporary directory for BigWig files
86
  tmpdir = tempfile.gettempdir()
87
  output_dir = os.path.join(tmpdir, f"bigwig_outputs_{uuid.uuid4().hex}")
88
  os.makedirs(output_dir, exist_ok=True)
89
-
90
  # Prepare track data list
91
  track_data_list = []
92
-
93
  # Add BigWig tracks
94
  for track_id in bigwig_selected:
95
  if track_id in bw_names:
96
  idx = bw_names.index(track_id)
97
  track_data_list.append(("bigwig", track_id, idx, None))
98
-
99
  # Add BED elements (as probabilities)
100
  if bed_logits is not None and bed_elements:
101
  probs = _softmax_last(bed_logits)
@@ -104,10 +108,10 @@ def create_bigwig_zip(
104
  eidx = bed_names.index(elem_name)
105
  # Store as bed element with probability data
106
  track_data_list.append(("bed", elem_name, eidx, probs[:, eidx, 1]))
107
-
108
  if not track_data_list:
109
  raise ValueError("No tracks selected for export.")
110
-
111
  # Create BigWig files
112
  created_files = []
113
  for track_type, track_id, track_idx, bed_probs in track_data_list:
@@ -119,39 +123,39 @@ def create_bigwig_zip(
119
  continue
120
  track_data = bed_probs.astype(np.float32)
121
  display_name = track_id
122
-
123
  # Clean filename
124
  clean_name = display_name.replace(" ", "_").replace("/", "_").replace("-", "_")
125
  bw_filename = os.path.join(output_dir, f"{clean_name}.bw")
126
-
127
  # Create BigWig file
128
  bw = pyBigWig.open(bw_filename, "w")
129
-
130
  # Add header - use end of genomic window as chromosome size
131
  bw.addHeader([(chrom, end)])
132
-
133
  # Add entries
134
  num_positions = len(track_data)
135
  starts = np.arange(pred_start, pred_start + num_positions, dtype=np.int64)
136
  ends = starts + 1
137
  values = track_data.tolist()
138
-
139
  bw.addEntries(
140
  chroms=[chrom] * len(starts),
141
  starts=starts.tolist(),
142
  ends=ends.tolist(),
143
- values=values
144
  )
145
-
146
  bw.close()
147
  created_files.append(bw_filename)
148
-
149
  # Create zip file
150
  zip_path = os.path.join(tmpdir, f"ntv3_tracks_{uuid.uuid4().hex}.zip")
151
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
152
  for bw_file in created_files:
153
  zipf.write(bw_file, os.path.basename(bw_file))
154
-
155
  # Clean up individual BigWig files
156
  for bw_file in created_files:
157
  try:
@@ -162,6 +166,5 @@ def create_bigwig_zip(
162
  os.rmdir(output_dir)
163
  except:
164
  pass
165
-
166
- return zip_path
167
 
 
 
3
  """
4
 
5
  import os
 
6
  import tempfile
7
+ import uuid
8
  import zipfile
9
  from typing import TYPE_CHECKING
10
 
 
33
  ) -> str:
34
  """
35
  Create BigWig files for selected tracks and save them in a zip file.
36
+
37
  Parameters
38
  ----------
39
  out : NTv3TracksOutput
 
42
  List of BigWig track IDs to export.
43
  bed_elements : list[str]
44
  List of BED element names to export.
45
+
46
  Returns
47
  -------
48
  str
49
  Path to the created zip file containing BigWig files.
50
+
51
  Raises
52
  ------
53
  ImportError
 
56
  If no predictions are available or no tracks are selected.
57
  """
58
  if pyBigWig is None:
59
+ raise ImportError(
60
+ "pyBigWig is required for BigWig export. Install with: pip install pyBigWig"
61
+ )
62
+
63
  if out is None:
64
  raise ValueError("No predictions available. Please run a prediction first.")
65
+
66
  bw_names = out.bigwig_track_names or []
67
  bw_logits = out.bigwig_tracks_logits
68
  bed_names = out.bed_element_names or []
69
  bed_logits = out.bed_tracks_logits
70
+
71
  if bw_logits is None or not bw_names:
72
  raise ValueError("No BigWig tracks available in model output.")
73
+
74
  # Get genomic coordinates
75
  chrom = out.chrom
76
  if chrom is None:
77
+ raise ValueError(
78
+ "Chromosome information not available. Use genomic coordinates for BigWig export."
79
+ )
80
+
81
  start = out.start
82
  end = out.end
83
  window_len = out.window_len or (end - start)
84
+
85
  # Calculate prediction region (center 37.5%)
86
  pred_start = out.pred_start or (start + int(window_len * 0.3125))
87
  pred_end = out.pred_end or (pred_start + int(window_len * 0.375))
88
+
89
  # Create temporary directory for BigWig files
90
  tmpdir = tempfile.gettempdir()
91
  output_dir = os.path.join(tmpdir, f"bigwig_outputs_{uuid.uuid4().hex}")
92
  os.makedirs(output_dir, exist_ok=True)
93
+
94
  # Prepare track data list
95
  track_data_list = []
96
+
97
  # Add BigWig tracks
98
  for track_id in bigwig_selected:
99
  if track_id in bw_names:
100
  idx = bw_names.index(track_id)
101
  track_data_list.append(("bigwig", track_id, idx, None))
102
+
103
  # Add BED elements (as probabilities)
104
  if bed_logits is not None and bed_elements:
105
  probs = _softmax_last(bed_logits)
 
108
  eidx = bed_names.index(elem_name)
109
  # Store as bed element with probability data
110
  track_data_list.append(("bed", elem_name, eidx, probs[:, eidx, 1]))
111
+
112
  if not track_data_list:
113
  raise ValueError("No tracks selected for export.")
114
+
115
  # Create BigWig files
116
  created_files = []
117
  for track_type, track_id, track_idx, bed_probs in track_data_list:
 
123
  continue
124
  track_data = bed_probs.astype(np.float32)
125
  display_name = track_id
126
+
127
  # Clean filename
128
  clean_name = display_name.replace(" ", "_").replace("/", "_").replace("-", "_")
129
  bw_filename = os.path.join(output_dir, f"{clean_name}.bw")
130
+
131
  # Create BigWig file
132
  bw = pyBigWig.open(bw_filename, "w")
133
+
134
  # Add header - use end of genomic window as chromosome size
135
  bw.addHeader([(chrom, end)])
136
+
137
  # Add entries
138
  num_positions = len(track_data)
139
  starts = np.arange(pred_start, pred_start + num_positions, dtype=np.int64)
140
  ends = starts + 1
141
  values = track_data.tolist()
142
+
143
  bw.addEntries(
144
  chroms=[chrom] * len(starts),
145
  starts=starts.tolist(),
146
  ends=ends.tolist(),
147
+ values=values,
148
  )
149
+
150
  bw.close()
151
  created_files.append(bw_filename)
152
+
153
  # Create zip file
154
  zip_path = os.path.join(tmpdir, f"ntv3_tracks_{uuid.uuid4().hex}.zip")
155
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
156
  for bw_file in created_files:
157
  zipf.write(bw_file, os.path.basename(bw_file))
158
+
159
  # Clean up individual BigWig files
160
  for bw_file in created_files:
161
  try:
 
166
  os.rmdir(output_dir)
167
  except:
168
  pass
 
 
169
 
170
+ return zip_path
data/functional_tracks_metadata.csv CHANGED
@@ -15887,4 +15887,4 @@ GSM874952,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15887
  GSM874953,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15888
  GSM874954,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15889
  GSM874955,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15890
- GSM874956,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
 
15887
  GSM874953,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15888
  GSM874954,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15889
  GSM874955,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
15890
+ GSM874956,Unknown,,TF ChIP-seq,,RPB2,mouse,geo
ntv3_tracks_pipeline.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
- from typing import Any, Dict, List, Optional, Union
6
 
7
  import numpy as np
8
  import torch
@@ -109,6 +109,7 @@ BED_ELEMENT_COLORS = {
109
  "ORF": "#1F618D", # Blue 2
110
  }
111
 
 
112
  def _sanitize_dna(seq: str) -> str:
113
  seq = seq.upper()
114
  return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
@@ -117,24 +118,26 @@ def _sanitize_dna(seq: str) -> str:
117
  def _get_dna_sequence(assembly: str, chrom: str, start: int, end: int) -> str:
118
  """
119
  Fetch DNA sequence from API based on assembly, chromosome, and coordinates.
120
-
121
  Uses ASSEMBLY_TO_API_URL_TEMPLATE to determine the API URL format for each assembly.
122
  Falls back to DEFAULT_API_URL_TEMPLATE if assembly is not in the mapping.
123
  """
124
  if requests is None:
125
- raise ImportError("requests is required for genome download. Install with: pip install requests")
126
-
 
 
127
  # Get API URL template for this assembly, or use default
128
  url_template = ASSEMBLY_TO_API_URL_TEMPLATE.get(assembly, DEFAULT_API_URL_TEMPLATE)
129
-
130
  # Format the URL with the provided parameters
131
  url = url_template.format(assembly=assembly, chrom=chrom, start=start, end=end)
132
-
133
  seq = requests.get(url).json()["dna"].upper()
134
  return seq
135
 
136
 
137
- def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Path:
138
  """
139
  Download <assembly>.fa.gz, decompress to <assembly>.fa, return the .fa path.
140
  pyfaidx works reliably on uncompressed FASTA.
@@ -156,6 +159,7 @@ def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Pa
156
  )
157
 
158
  import gzip
 
159
  print(f"Decompressing {gz_path} -> {fa_path}")
160
  with gzip.open(gz_path, "rb") as fin, open(fa_path, "wb") as fout:
161
  while True:
@@ -166,11 +170,12 @@ def _ensure_fasta_for_assembly(assembly: str, cache_dir: Union[str, Path]) -> Pa
166
 
167
  return fa_path
168
 
169
- def _pick_device(device: Union[str, int, torch.device]) -> torch.device:
 
170
  # Handle torch.device objects
171
  if isinstance(device, torch.device):
172
  return device
173
-
174
  # Handle integer device IDs (transformers pipeline convention)
175
  if isinstance(device, int):
176
  if device == -1:
@@ -182,7 +187,7 @@ def _pick_device(device: Union[str, int, torch.device]) -> torch.device:
182
  return torch.device("cpu")
183
  else:
184
  raise ValueError(f"Invalid device integer: {device}")
185
-
186
  # Handle string device names
187
  if isinstance(device, str):
188
  d = device.lower()
@@ -194,9 +199,13 @@ def _pick_device(device: Union[str, int, torch.device]) -> torch.device:
194
  return torch.device("cpu")
195
  if d in ("cuda", "cpu", "mps"):
196
  return torch.device(d)
197
- raise ValueError("device must be one of: 'auto', 'cpu', 'cuda', 'mps', or an integer")
198
-
199
- raise ValueError(f"device must be a string, integer, or torch.device, got {type(device)}")
 
 
 
 
200
 
201
 
202
  def _softmax_last(x: np.ndarray) -> np.ndarray:
@@ -206,16 +215,18 @@ def _softmax_last(x: np.ndarray) -> np.ndarray:
206
 
207
 
208
  def _plot_tracks_fillbetween(
209
- tracks: Dict[str, np.ndarray],
210
- chrom: Optional[str],
211
  start: int,
212
  end: int,
213
- assembly: Optional[str],
214
  height: float = 1.0,
215
  figsize_x: float = 20.0,
216
  ):
217
  if plt is None:
218
- raise ImportError("matplotlib is required for plotting. Install with: pip install matplotlib")
 
 
219
 
220
  n = len(tracks)
221
  if n == 0:
@@ -238,7 +249,7 @@ def _plot_tracks_fillbetween(
238
  color = BED_ELEMENT_COLORS[title]
239
  else:
240
  color = bigwig_color
241
-
242
  ax.fill_between(x, y, color=color, alpha=0.3, linewidth=0)
243
  ax.plot(x, y, color=color, linewidth=0.8)
244
  ax.set_title(title, fontsize=10, loc="left")
@@ -260,29 +271,31 @@ def _plot_tracks_fillbetween(
260
  @dataclass
261
  class NTv3TracksOutput:
262
  bigwig_tracks_logits: np.ndarray # (L_pred, T)
263
- bed_tracks_logits: np.ndarray # (L_pred, E, C)
264
  mlm_logits: np.ndarray
265
- chrom: Optional[str] = None
266
- start: Optional[int] = None
267
- end: Optional[int] = None
268
- species: Optional[str] = None
269
- assembly: Optional[str] = None
270
- bigwig_track_names: Optional[List[str]] = None # from cfg.bigwigs_per_file_assembly[assembly]
271
- bed_element_names: Optional[List[str]] = None
272
- window_len: Optional[int] = None
273
- pred_start: Optional[int] = None
274
- pred_end: Optional[int] = None
 
 
275
 
276
 
277
  class NTv3TracksPipeline(Pipeline):
278
  def __init__(
279
  self,
280
- model: Union[str, torch.nn.Module],
281
- tokenizer: Optional[Union[str, Any]] = None,
282
  trust_remote_code: bool = True,
283
- token: Optional[str] = None,
284
  default_species: str = "human",
285
- genome_cache_dir: Union[str, Path] = "~/.cache/ntv3/genomes",
286
  device: str = "auto",
287
  mps_force_cpu: bool = True,
288
  mps_force_cpu_length: int = 16384,
@@ -302,24 +315,36 @@ class NTv3TracksPipeline(Pipeline):
302
  self.pred_center_offset_fraction = float(pred_center_offset_fraction)
303
 
304
  if isinstance(model, str):
305
- self.config = AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
306
- self.model = AutoModel.from_pretrained(model, trust_remote_code=trust_remote_code, token=token)
 
 
 
 
307
  else:
308
  self.model = model
309
  self.config = getattr(model, "config", None)
310
 
311
  if tokenizer is None:
312
  if not self.model_id:
313
- raise ValueError("If passing a model module, pass tokenizer explicitly.")
314
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, token=token)
 
 
 
 
315
  elif isinstance(tokenizer, str):
316
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=trust_remote_code, token=token)
 
 
317
  else:
318
  self.tokenizer = tokenizer
319
 
320
  # Extract model_id from config if not already set (following ntv3_gff_pipeline.py pattern)
321
  if self.model_id is None and self.config is not None:
322
- self.model_id = getattr(self.config, "_name_or_path", None) or getattr(self.config, "name_or_path", None)
 
 
323
 
324
  # Load species_tokenizer (following ntv3_gff_pipeline.py pattern)
325
  if self.model_id:
@@ -332,19 +357,22 @@ class NTv3TracksPipeline(Pipeline):
332
  else:
333
  self.species_tokenizer = kwargs.get("species_tokenizer", None)
334
  if self.species_tokenizer is None:
335
- raise ValueError("Pass species_tokenizer=... when constructing with a model module.")
 
 
336
 
337
  # bed names (your notebooks refer to bed_element_names)
338
- self.bed_element_names = (
339
- getattr(self.config, "bed_elements_names", None)
340
- or getattr(self.config, "bed_element_names", None)
341
- )
342
 
343
  self._target_device = _pick_device(device)
344
  self.model.to(self._target_device)
345
  self.model.eval()
346
 
347
- super().__init__(model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs)
 
 
348
 
349
  def _sanitize_parameters(self, **kwargs):
350
  return {}, {}, {}
@@ -352,10 +380,12 @@ class NTv3TracksPipeline(Pipeline):
352
  def _get_model_device(self) -> torch.device:
353
  return next(self.model.parameters()).device
354
 
355
- def _resolve_species_and_assembly(self, inputs: Dict[str, Any]) -> tuple[str, str]:
356
  species = inputs.get("species", self.default_species)
357
  if species not in SPECIES_TO_ASSEMBLY:
358
- raise ValueError(f"Unsupported species='{species}'. Supported species: {sorted(SPECIES_TO_ASSEMBLY.keys())}")
 
 
359
  assembly = SPECIES_TO_ASSEMBLY[species]
360
 
361
  cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
@@ -366,8 +396,9 @@ class NTv3TracksPipeline(Pipeline):
366
  )
367
  return species, assembly
368
 
369
-
370
- def _maybe_force_cpu_for_mps_long(self, input_ids_cpu: torch.Tensor) -> torch.device:
 
371
  dev = self._get_model_device()
372
  if self.mps_force_cpu and dev.type == "mps":
373
  seq_len = int(input_ids_cpu.shape[-1])
@@ -390,7 +421,9 @@ class NTv3TracksPipeline(Pipeline):
390
  sp = species or self.default_species
391
  assembly = SPECIES_TO_ASSEMBLY.get(sp)
392
  if assembly is None:
393
- raise ValueError(f"Unknown species={sp}. Supported: {sorted(SPECIES_TO_ASSEMBLY.keys())}")
 
 
394
 
395
  if assembly not in self.config.bigwigs_per_file_assembly:
396
  raise ValueError(
@@ -400,13 +433,13 @@ class NTv3TracksPipeline(Pipeline):
400
 
401
  return list(self.config.bigwigs_per_file_assembly[assembly])
402
 
403
- def available_bed_element_names(self) -> List[str]:
404
  """
405
  Return BED element names available in this checkpoint (no forward pass).
406
  """
407
  return list(self.bed_element_names or [])
408
-
409
- def preprocess(self, inputs: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
410
  species, assembly = self._resolve_species_and_assembly(inputs)
411
 
412
  # Resolve sequence
@@ -425,7 +458,13 @@ class NTv3TracksPipeline(Pipeline):
425
  seq = _sanitize_dna(seq)
426
 
427
  # Tokenize with padding
428
- batch = self.tokenizer([seq], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt")
 
 
 
 
 
 
429
  input_ids_cpu = batch["input_ids"]
430
 
431
  # MPS-long fallback decision
@@ -435,7 +474,9 @@ class NTv3TracksPipeline(Pipeline):
435
  input_ids = input_ids_cpu.to(device)
436
  # Species tokenization - match batch size
437
  batch_size = input_ids.shape[0]
438
- species_ids = self.species_tokenizer([species] * batch_size, add_special_tokens=False, return_tensors="pt")
 
 
439
  species_ids_tensor = species_ids["input_ids"].to(device)
440
 
441
  # Prediction interval (not used for slicing logits, just x-axis)
@@ -465,7 +506,7 @@ class NTv3TracksPipeline(Pipeline):
465
  def forward(self, model_inputs, **forward_params):
466
  return self._forward(model_inputs, **forward_params)
467
 
468
- def _forward(self, model_inputs: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
469
  meta = model_inputs.pop("meta")
470
  if self.verbose:
471
  print(f"Running on device: {self._get_model_device()}")
@@ -478,7 +519,9 @@ class NTv3TracksPipeline(Pipeline):
478
  out["meta"] = meta
479
  return out
480
 
481
- def postprocess(self, model_outputs: Dict[str, Any], **kwargs: Any) -> NTv3TracksOutput:
 
 
482
  meta = model_outputs.pop("meta", {})
483
 
484
  def to_np(x):
@@ -490,16 +533,16 @@ class NTv3TracksPipeline(Pipeline):
490
 
491
  # Normalize shapes to remove batch/(optional assembly) dims
492
  if bigwig_np.ndim == 3:
493
- bigwig_np = bigwig_np[0] # (L, T)
494
  elif bigwig_np.ndim == 4:
495
- bigwig_np = bigwig_np[0, 0] # (L, T) if (B, A, L, T)
496
  else:
497
  raise ValueError(f"Unexpected bigwig_tracks_logits ndim: {bigwig_np.ndim}")
498
 
499
  if bed_np.ndim == 4:
500
- bed_np = bed_np[0] # (L, E, C)
501
  elif bed_np.ndim == 5:
502
- bed_np = bed_np[0, 0] # (L, E, C) if (B, A, L, E, C)
503
  else:
504
  raise ValueError(f"Unexpected bed_tracks_logits ndim: {bed_np.ndim}")
505
 
@@ -527,8 +570,8 @@ class NTv3TracksPipeline(Pipeline):
527
  inputs,
528
  *args,
529
  plot: bool = False,
530
- tracks_to_plot: Optional[Dict[str, str]] = None, # title -> track_id (ENCSR...)
531
- elements_to_plot: Optional[List[str]] = None, # element names
532
  plot_height: float = 1.0,
533
  plot_figsize_x: float = 20.0,
534
  **kwargs,
@@ -540,7 +583,9 @@ class NTv3TracksPipeline(Pipeline):
540
 
541
  if plot:
542
  if out.bigwig_track_names is None:
543
- raise ValueError("bigwig_track_names missing; expected cfg.bigwigs_per_file_assembly[assembly].")
 
 
544
  if out.bed_element_names is None:
545
  raise ValueError("bed element names missing from config.")
546
  tracks_to_plot = tracks_to_plot or {}
@@ -550,14 +595,18 @@ class NTv3TracksPipeline(Pipeline):
550
  bed_element_names = out.bed_element_names
551
 
552
  # Validate
553
- missing_tracks = [tid for tid in tracks_to_plot.values() if tid not in bigwig_names]
 
 
554
  if missing_tracks:
555
  raise ValueError(
556
  f"The following tracks are not available in bigwig_names: {missing_tracks}\n"
557
  f"First 50 available: {bigwig_names[:50]}{'...' if len(bigwig_names) > 50 else ''}"
558
  )
559
 
560
- missing_elements = [e for e in elements_to_plot if e not in bed_element_names]
 
 
561
  if missing_elements:
562
  raise ValueError(
563
  f"The following elements are not available in bed_element_names: {missing_elements}\n"
@@ -565,14 +614,14 @@ class NTv3TracksPipeline(Pipeline):
565
  )
566
 
567
  # Build bigwig tracks dict (title -> y)
568
- bigwig_tracks: Dict[str, np.ndarray] = {}
569
  bigwig = out.bigwig_tracks_logits # (L_pred, T)
570
  for title, track_id in tracks_to_plot.items():
571
  track_idx = bigwig_names.index(track_id)
572
  bigwig_tracks[title] = bigwig[:, track_idx]
573
 
574
  # Bed positive class probabilities (title -> y)
575
- bed_probs: Dict[str, np.ndarray] = {}
576
  probs = _softmax_last(out.bed_tracks_logits) # (L_pred, E, C)
577
  for element_name in elements_to_plot:
578
  element_idx = bed_element_names.index(element_name)
@@ -581,8 +630,10 @@ class NTv3TracksPipeline(Pipeline):
581
  all_tracks = {**bigwig_tracks, **bed_probs}
582
 
583
  plot_start = int(out.pred_start or 0)
584
- plot_end = int(out.pred_end or (plot_start + len(next(iter(all_tracks.values())))))
585
-
 
 
586
  _plot_tracks_fillbetween(
587
  all_tracks,
588
  chrom=out.chrom,
@@ -595,6 +646,7 @@ class NTv3TracksPipeline(Pipeline):
595
 
596
  return out
597
 
 
598
  def load_ntv3_tracks_pipeline(
599
  model: str,
600
  device: str = "auto",
@@ -618,4 +670,4 @@ def load_ntv3_tracks_pipeline(
618
  device=device,
619
  **pipeline_kwargs,
620
  )
621
- return pipe
 
2
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
+ from typing import Any
6
 
7
  import numpy as np
8
  import torch
 
109
  "ORF": "#1F618D", # Blue 2
110
  }
111
 
112
+
113
  def _sanitize_dna(seq: str) -> str:
114
  seq = seq.upper()
115
  return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
 
118
  def _get_dna_sequence(assembly: str, chrom: str, start: int, end: int) -> str:
119
  """
120
  Fetch DNA sequence from API based on assembly, chromosome, and coordinates.
121
+
122
  Uses ASSEMBLY_TO_API_URL_TEMPLATE to determine the API URL format for each assembly.
123
  Falls back to DEFAULT_API_URL_TEMPLATE if assembly is not in the mapping.
124
  """
125
  if requests is None:
126
+ raise ImportError(
127
+ "requests is required for genome download. Install with: pip install requests"
128
+ )
129
+
130
  # Get API URL template for this assembly, or use default
131
  url_template = ASSEMBLY_TO_API_URL_TEMPLATE.get(assembly, DEFAULT_API_URL_TEMPLATE)
132
+
133
  # Format the URL with the provided parameters
134
  url = url_template.format(assembly=assembly, chrom=chrom, start=start, end=end)
135
+
136
  seq = requests.get(url).json()["dna"].upper()
137
  return seq
138
 
139
 
140
+ def _ensure_fasta_for_assembly(assembly: str, cache_dir: str | Path) -> Path:
141
  """
142
  Download <assembly>.fa.gz, decompress to <assembly>.fa, return the .fa path.
143
  pyfaidx works reliably on uncompressed FASTA.
 
159
  )
160
 
161
  import gzip
162
+
163
  print(f"Decompressing {gz_path} -> {fa_path}")
164
  with gzip.open(gz_path, "rb") as fin, open(fa_path, "wb") as fout:
165
  while True:
 
170
 
171
  return fa_path
172
 
173
+
174
+ def _pick_device(device: str | int | torch.device) -> torch.device:
175
  # Handle torch.device objects
176
  if isinstance(device, torch.device):
177
  return device
178
+
179
  # Handle integer device IDs (transformers pipeline convention)
180
  if isinstance(device, int):
181
  if device == -1:
 
187
  return torch.device("cpu")
188
  else:
189
  raise ValueError(f"Invalid device integer: {device}")
190
+
191
  # Handle string device names
192
  if isinstance(device, str):
193
  d = device.lower()
 
199
  return torch.device("cpu")
200
  if d in ("cuda", "cpu", "mps"):
201
  return torch.device(d)
202
+ raise ValueError(
203
+ "device must be one of: 'auto', 'cpu', 'cuda', 'mps', or an integer"
204
+ )
205
+
206
+ raise ValueError(
207
+ f"device must be a string, integer, or torch.device, got {type(device)}"
208
+ )
209
 
210
 
211
  def _softmax_last(x: np.ndarray) -> np.ndarray:
 
215
 
216
 
217
  def _plot_tracks_fillbetween(
218
+ tracks: dict[str, np.ndarray],
219
+ chrom: str | None,
220
  start: int,
221
  end: int,
222
+ assembly: str | None,
223
  height: float = 1.0,
224
  figsize_x: float = 20.0,
225
  ):
226
  if plt is None:
227
+ raise ImportError(
228
+ "matplotlib is required for plotting. Install with: pip install matplotlib"
229
+ )
230
 
231
  n = len(tracks)
232
  if n == 0:
 
249
  color = BED_ELEMENT_COLORS[title]
250
  else:
251
  color = bigwig_color
252
+
253
  ax.fill_between(x, y, color=color, alpha=0.3, linewidth=0)
254
  ax.plot(x, y, color=color, linewidth=0.8)
255
  ax.set_title(title, fontsize=10, loc="left")
 
271
  @dataclass
272
  class NTv3TracksOutput:
273
  bigwig_tracks_logits: np.ndarray # (L_pred, T)
274
+ bed_tracks_logits: np.ndarray # (L_pred, E, C)
275
  mlm_logits: np.ndarray
276
+ chrom: str | None = None
277
+ start: int | None = None
278
+ end: int | None = None
279
+ species: str | None = None
280
+ assembly: str | None = None
281
+ bigwig_track_names: list[str] | None = (
282
+ None # from cfg.bigwigs_per_file_assembly[assembly]
283
+ )
284
+ bed_element_names: list[str] | None = None
285
+ window_len: int | None = None
286
+ pred_start: int | None = None
287
+ pred_end: int | None = None
288
 
289
 
290
  class NTv3TracksPipeline(Pipeline):
291
  def __init__(
292
  self,
293
+ model: str | torch.nn.Module,
294
+ tokenizer: str | Any | None = None,
295
  trust_remote_code: bool = True,
296
+ token: str | None = None,
297
  default_species: str = "human",
298
+ genome_cache_dir: str | Path = "~/.cache/ntv3/genomes",
299
  device: str = "auto",
300
  mps_force_cpu: bool = True,
301
  mps_force_cpu_length: int = 16384,
 
315
  self.pred_center_offset_fraction = float(pred_center_offset_fraction)
316
 
317
  if isinstance(model, str):
318
+ self.config = AutoConfig.from_pretrained(
319
+ model, trust_remote_code=trust_remote_code, token=token
320
+ )
321
+ self.model = AutoModel.from_pretrained(
322
+ model, trust_remote_code=trust_remote_code, token=token
323
+ )
324
  else:
325
  self.model = model
326
  self.config = getattr(model, "config", None)
327
 
328
  if tokenizer is None:
329
  if not self.model_id:
330
+ raise ValueError(
331
+ "If passing a model module, pass tokenizer explicitly."
332
+ )
333
+ self.tokenizer = AutoTokenizer.from_pretrained(
334
+ self.model_id, trust_remote_code=trust_remote_code, token=token
335
+ )
336
  elif isinstance(tokenizer, str):
337
+ self.tokenizer = AutoTokenizer.from_pretrained(
338
+ tokenizer, trust_remote_code=trust_remote_code, token=token
339
+ )
340
  else:
341
  self.tokenizer = tokenizer
342
 
343
  # Extract model_id from config if not already set (following ntv3_gff_pipeline.py pattern)
344
  if self.model_id is None and self.config is not None:
345
+ self.model_id = getattr(self.config, "_name_or_path", None) or getattr(
346
+ self.config, "name_or_path", None
347
+ )
348
 
349
  # Load species_tokenizer (following ntv3_gff_pipeline.py pattern)
350
  if self.model_id:
 
357
  else:
358
  self.species_tokenizer = kwargs.get("species_tokenizer", None)
359
  if self.species_tokenizer is None:
360
+ raise ValueError(
361
+ "Pass species_tokenizer=... when constructing with a model module."
362
+ )
363
 
364
  # bed names (your notebooks refer to bed_element_names)
365
+ self.bed_element_names = getattr(
366
+ self.config, "bed_elements_names", None
367
+ ) or getattr(self.config, "bed_element_names", None)
 
368
 
369
  self._target_device = _pick_device(device)
370
  self.model.to(self._target_device)
371
  self.model.eval()
372
 
373
+ super().__init__(
374
+ model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs
375
+ )
376
 
377
  def _sanitize_parameters(self, **kwargs):
378
  return {}, {}, {}
 
380
  def _get_model_device(self) -> torch.device:
381
  return next(self.model.parameters()).device
382
 
383
+ def _resolve_species_and_assembly(self, inputs: dict[str, Any]) -> tuple[str, str]:
384
  species = inputs.get("species", self.default_species)
385
  if species not in SPECIES_TO_ASSEMBLY:
386
+ raise ValueError(
387
+ f"Unsupported species='{species}'. Supported species: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
388
+ )
389
  assembly = SPECIES_TO_ASSEMBLY[species]
390
 
391
  cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
 
396
  )
397
  return species, assembly
398
 
399
+ def _maybe_force_cpu_for_mps_long(
400
+ self, input_ids_cpu: torch.Tensor
401
+ ) -> torch.device:
402
  dev = self._get_model_device()
403
  if self.mps_force_cpu and dev.type == "mps":
404
  seq_len = int(input_ids_cpu.shape[-1])
 
421
  sp = species or self.default_species
422
  assembly = SPECIES_TO_ASSEMBLY.get(sp)
423
  if assembly is None:
424
+ raise ValueError(
425
+ f"Unknown species={sp}. Supported: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
426
+ )
427
 
428
  if assembly not in self.config.bigwigs_per_file_assembly:
429
  raise ValueError(
 
433
 
434
  return list(self.config.bigwigs_per_file_assembly[assembly])
435
 
436
+ def available_bed_element_names(self) -> list[str]:
437
  """
438
  Return BED element names available in this checkpoint (no forward pass).
439
  """
440
  return list(self.bed_element_names or [])
441
+
442
+ def preprocess(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
443
  species, assembly = self._resolve_species_and_assembly(inputs)
444
 
445
  # Resolve sequence
 
458
  seq = _sanitize_dna(seq)
459
 
460
  # Tokenize with padding
461
+ batch = self.tokenizer(
462
+ [seq],
463
+ add_special_tokens=False,
464
+ padding=True,
465
+ pad_to_multiple_of=128,
466
+ return_tensors="pt",
467
+ )
468
  input_ids_cpu = batch["input_ids"]
469
 
470
  # MPS-long fallback decision
 
474
  input_ids = input_ids_cpu.to(device)
475
  # Species tokenization - match batch size
476
  batch_size = input_ids.shape[0]
477
+ species_ids = self.species_tokenizer(
478
+ [species] * batch_size, add_special_tokens=False, return_tensors="pt"
479
+ )
480
  species_ids_tensor = species_ids["input_ids"].to(device)
481
 
482
  # Prediction interval (not used for slicing logits, just x-axis)
 
506
  def forward(self, model_inputs, **forward_params):
507
  return self._forward(model_inputs, **forward_params)
508
 
509
+ def _forward(self, model_inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
510
  meta = model_inputs.pop("meta")
511
  if self.verbose:
512
  print(f"Running on device: {self._get_model_device()}")
 
519
  out["meta"] = meta
520
  return out
521
 
522
+ def postprocess(
523
+ self, model_outputs: dict[str, Any], **kwargs: Any
524
+ ) -> NTv3TracksOutput:
525
  meta = model_outputs.pop("meta", {})
526
 
527
  def to_np(x):
 
533
 
534
  # Normalize shapes to remove batch/(optional assembly) dims
535
  if bigwig_np.ndim == 3:
536
+ bigwig_np = bigwig_np[0] # (L, T)
537
  elif bigwig_np.ndim == 4:
538
+ bigwig_np = bigwig_np[0, 0] # (L, T) if (B, A, L, T)
539
  else:
540
  raise ValueError(f"Unexpected bigwig_tracks_logits ndim: {bigwig_np.ndim}")
541
 
542
  if bed_np.ndim == 4:
543
+ bed_np = bed_np[0] # (L, E, C)
544
  elif bed_np.ndim == 5:
545
+ bed_np = bed_np[0, 0] # (L, E, C) if (B, A, L, E, C)
546
  else:
547
  raise ValueError(f"Unexpected bed_tracks_logits ndim: {bed_np.ndim}")
548
 
 
570
  inputs,
571
  *args,
572
  plot: bool = False,
573
+ tracks_to_plot: dict[str, str] | None = None, # title -> track_id (ENCSR...)
574
+ elements_to_plot: list[str] | None = None, # element names
575
  plot_height: float = 1.0,
576
  plot_figsize_x: float = 20.0,
577
  **kwargs,
 
583
 
584
  if plot:
585
  if out.bigwig_track_names is None:
586
+ raise ValueError(
587
+ "bigwig_track_names missing; expected cfg.bigwigs_per_file_assembly[assembly]."
588
+ )
589
  if out.bed_element_names is None:
590
  raise ValueError("bed element names missing from config.")
591
  tracks_to_plot = tracks_to_plot or {}
 
595
  bed_element_names = out.bed_element_names
596
 
597
  # Validate
598
+ missing_tracks = [
599
+ tid for tid in tracks_to_plot.values() if tid not in bigwig_names
600
+ ]
601
  if missing_tracks:
602
  raise ValueError(
603
  f"The following tracks are not available in bigwig_names: {missing_tracks}\n"
604
  f"First 50 available: {bigwig_names[:50]}{'...' if len(bigwig_names) > 50 else ''}"
605
  )
606
 
607
+ missing_elements = [
608
+ e for e in elements_to_plot if e not in bed_element_names
609
+ ]
610
  if missing_elements:
611
  raise ValueError(
612
  f"The following elements are not available in bed_element_names: {missing_elements}\n"
 
614
  )
615
 
616
  # Build bigwig tracks dict (title -> y)
617
+ bigwig_tracks: dict[str, np.ndarray] = {}
618
  bigwig = out.bigwig_tracks_logits # (L_pred, T)
619
  for title, track_id in tracks_to_plot.items():
620
  track_idx = bigwig_names.index(track_id)
621
  bigwig_tracks[title] = bigwig[:, track_idx]
622
 
623
  # Bed positive class probabilities (title -> y)
624
+ bed_probs: dict[str, np.ndarray] = {}
625
  probs = _softmax_last(out.bed_tracks_logits) # (L_pred, E, C)
626
  for element_name in elements_to_plot:
627
  element_idx = bed_element_names.index(element_name)
 
630
  all_tracks = {**bigwig_tracks, **bed_probs}
631
 
632
  plot_start = int(out.pred_start or 0)
633
+ plot_end = int(
634
+ out.pred_end or (plot_start + len(next(iter(all_tracks.values()))))
635
+ )
636
+
637
  _plot_tracks_fillbetween(
638
  all_tracks,
639
  chrom=out.chrom,
 
646
 
647
  return out
648
 
649
+
650
  def load_ntv3_tracks_pipeline(
651
  model: str,
652
  device: str = "auto",
 
670
  device=device,
671
  **pipeline_kwargs,
672
  )
673
+ return pipe
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- transformers>=4.41.0
2
- torch
3
- numpy
4
  gradio>=4.0.0
5
- pyfaidx
6
- requests
7
  matplotlib
 
8
  pyBigWig
 
 
 
 
 
 
 
 
1
  gradio>=4.0.0
 
 
2
  matplotlib
3
+ numpy
4
  pyBigWig
5
+ pyfaidx
6
+ requests
7
+ torch
8
+ transformers>=4.41.0