bernardo-de-almeida commited on
Commit
b65f002
·
1 Parent(s): beb6a82

refactor: clean code

Browse files
Files changed (4) hide show
  1. app.py +81 -59
  2. bigwig_export.py +11 -7
  3. ntv3_tracks_pipeline.py +72 -61
  4. requirements.txt +1 -1
app.py CHANGED
@@ -8,11 +8,10 @@ from pathlib import Path
8
  import gradio as gr
9
  import matplotlib
10
  import matplotlib.colors as mcolors
11
- import matplotlib.pyplot as plt
12
  import numpy as np
13
  import plotly.graph_objects as go
14
- from plotly.subplots import make_subplots
15
  import torch
 
16
 
17
  from bigwig_export import _softmax_last, create_bigwig_zip
18
  from ntv3_tracks_pipeline import (
@@ -57,7 +56,7 @@ def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
57
  pipe = load_ntv3_tracks_pipeline(
58
  model=model_id,
59
  token=HF_TOKEN,
60
- device="cpu", # This prevents the pipeline constructor from doing model.to(cuda) during import.
61
  default_species=species,
62
  verbose=False,
63
  )
@@ -100,25 +99,29 @@ try:
100
  except Exception:
101
 
102
  def gpu(*args, **kwargs):
 
 
103
  def wrap(fn):
104
  return fn
105
 
106
  return wrap
107
 
108
 
109
- def _global_stride(L: int, target: int) -> int:
110
- if target <= 0 or L <= target:
111
  return 1
112
- return int(np.ceil(L / target))
113
 
114
 
115
- def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], region: str = ""):
 
 
116
  """Create an interactive plotly figure with multiple tracks."""
117
  if not series:
118
  raise gr.Error("Nothing to plot (no tracks/elements selected).")
119
 
120
  n = len(series)
121
-
122
  # Create subplots with shared x-axis
123
  fig = make_subplots(
124
  rows=n,
@@ -140,8 +143,10 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], reg
140
 
141
  # Convert color to rgba for fill
142
  rgba = mcolors.to_rgba(color)
143
- rgba_str = f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, 0.3)"
144
-
 
 
145
  # Add filled area (fill_between equivalent)
146
  fig.add_trace(
147
  go.Scatter(
@@ -149,12 +154,12 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], reg
149
  y=y,
150
  mode="lines",
151
  name=title,
152
- line=dict(color=color, width=1.5),
153
  fill="tozeroy",
154
  fillcolor=rgba_str,
155
- hovertemplate=f"<b>{title}</b><br>" +
156
- "Position: %{x}<br>" +
157
- "Value: %{y:.4f}<extra></extra>",
158
  showlegend=False,
159
  ),
160
  row=i,
@@ -165,7 +170,7 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], reg
165
  fig.update_layout(
166
  height=150 * n, # Adjust height based on number of tracks
167
  width=1200,
168
- margin=dict(l=80, r=20, t=40, b=60),
169
  hovermode="x unified", # Show all values at same x position
170
  template="plotly_white",
171
  )
@@ -278,7 +283,7 @@ def _format_track_for_display(track_id: str) -> str:
278
 
279
 
280
  def _extract_track_id(display_value: str) -> str:
281
- """Extract track ID from display format 'display_name (track_id)' or return as-is."""
282
  if " (" in display_value and display_value.endswith(")"):
283
  # Extract track_id from format "display_name (track_id)"
284
  return display_value.rsplit(" (", 1)[1][:-1]
@@ -455,6 +460,7 @@ def update_coords_on_species_change(species: str):
455
 
456
 
457
  def reset_on_species_change(species: str):
 
458
  # Clear results + selected when species changes (avoids mismatched IDs)
459
  try:
460
  track_ids = _get_bigwig_names(species) # warms cache if available
@@ -500,6 +506,7 @@ def predict(
500
  bigwig_selected: list[str],
501
  bed_elements: list[str],
502
  ):
 
503
  tprint("start")
504
 
505
  # Debug: verify species is being passed
@@ -515,10 +522,11 @@ def predict(
515
  if use_coords:
516
  # Check if this species supports coordinate-based fetching
517
  if species not in SPECIES_WITH_COORDINATE_SUPPORT:
 
518
  raise gr.Error(
519
- f"Species '{species}' does not support coordinate-based sequence fetching. "
520
- f"Please provide a DNA sequence directly or use one of the supported species: "
521
- f"{', '.join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))}"
522
  )
523
  if not chrom:
524
  raise gr.Error("chrom is required when use_coords=True")
@@ -537,8 +545,10 @@ def predict(
537
 
538
  # Verify species is in inputs before calling pipeline
539
  if "species" not in inputs:
 
540
  raise gr.Error(
541
- f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}"
 
542
  )
543
 
544
  tprint("inputs prepared")
@@ -576,12 +586,15 @@ def predict(
576
 
577
  if not has_bigwigs and not has_bed:
578
  raise gr.Error(
579
- "No BigWig tracks or BED elements available for this species in the current model."
 
580
  )
581
 
582
  if not has_bigwigs and bigwig_selected:
583
  raise gr.Error(
584
- "No BigWig tracks available for this species, but BigWig tracks were selected. Please deselect BigWig tracks or choose a different species."
 
 
585
  )
586
 
587
  # Defaults if user picked none
@@ -617,17 +630,17 @@ def predict(
617
 
618
  # Determine sequence length from available data
619
  if has_bigwigs:
620
- L = bw.shape[0]
621
  elif has_bed:
622
- L = bed_logits.shape[0]
623
  else:
624
  raise gr.Error("No data available for plotting.")
625
 
626
- stride = _global_stride(L, PLOT_TARGET_POINTS)
627
 
628
  x0 = int(out.pred_start or 0)
629
- x1 = int(out.pred_end or (x0 + L))
630
- x = np.linspace(x0, x1, num=L, endpoint=False)[::stride]
631
 
632
  series: list[tuple[str, np.ndarray]] = []
633
 
@@ -645,14 +658,14 @@ def predict(
645
  series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
646
 
647
  tprint("figure data processed created")
648
-
649
  # Build region string for x-axis label
650
  region = (
651
  f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
652
  )
653
  if out.assembly:
654
  region += f" ({out.assembly})"
655
-
656
  fig = _make_tracks_figure(x, series, region=region)
657
  tprint("figure created")
658
 
@@ -680,7 +693,10 @@ def predict(
680
  # -----------------------------
681
  CSS = """
682
  #tracks_plot { position: relative; width: 100% !important; max-width: 100% !important; }
683
- #tracks_plot .wrap, #tracks_plot .plot-container { width: 100% !important; max-width: 100% !important; }
 
 
 
684
 
685
  #tracks_plot_download {
686
  position: absolute;
@@ -916,7 +932,8 @@ function addDownloadIcon() {
916
  btn.title = "Download PNG";
917
  btn.innerHTML = `
918
  <svg viewBox="0 0 24 24" aria-hidden="true">
919
- <path d="M5 20h14v-2H5v2zm7-18v10.17l3.59-3.58L17 10l-5 5-5-5 1.41-1.41L11 12.17V2h1z"/>
 
920
  </svg>
921
  `;
922
  btn.onclick = () => {
@@ -1024,8 +1041,10 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1024
  <div class="intro-card">
1025
  <h3>2) Choose signals</h3>
1026
  <ul>
1027
- <li>Search & select <strong>BigWig functional tracks</strong> (RNA-seq, ChIP-seq, DNase…)</li>
1028
- <li>Select <strong>BED genome annotation elements</strong> (exons, introns, promoters…)</li>
 
 
1029
  </ul>
1030
  </div>
1031
 
@@ -1041,10 +1060,12 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1041
 
1042
  <div class="intro-tip">
1043
  <span class="intro-tip-icon">💡</span>
1044
- <span><strong>Tip:</strong> The demo includes default settings that you can use to get started, taking ~ 15 seconds to run for the example on human.</span>
 
1045
  </div>
1046
 
1047
- <div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03); border-radius: 12px; font-size: 0.95rem;">
 
1048
  <strong>Available species:</strong> {_all_species_list}<br>
1049
  <br>
1050
  <strong>Species with functional tracks:</strong> {_bigwig_species_list}
@@ -1059,8 +1080,8 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1059
 
1060
  # Model display names (without InstaDeepAI/ prefix) and their full IDs
1061
  MODEL_OPTIONS = {
1062
- "NTv3 650M (pos)": "InstaDeepAI/NTv3_650M_pos",
1063
- "NTv3 100M (pos)": "InstaDeepAI/NTv3_100M_pos",
1064
  }
1065
 
1066
  # Reverse mapping: full ID -> display name
@@ -1112,11 +1133,9 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1112
  + ")"
1113
  )
1114
  with gr.Row():
1115
- chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
1116
- start = gr.Number(
1117
- label="Start", value=_default_coords["start"], precision=0
1118
- )
1119
- end = gr.Number(label="End", value=_default_coords["end"], precision=0)
1120
 
1121
  # DNA sequence section - visible only when "Enter DNA sequence" is selected
1122
  # Using Textbox directly (not wrapped in Group) to avoid visual border/line
@@ -1189,7 +1208,8 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1189
  )
1190
 
1191
  bigwig_no_tracks_msg = gr.Markdown(
1192
- "⚠️ No functional genomic tracks available for this species in the current model.",
 
1193
  visible=False,
1194
  )
1195
 
@@ -1318,19 +1338,18 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1318
  current_species, current_query, upd["value"]
1319
  )
1320
 
1321
- # Create a completely fresh update with explicit empty value to prevent any checked state
1322
- # Force Gradio to clear checked state by explicitly setting value to empty list
1323
- # Use a workaround: set choices to empty first, then to new_choices to force a complete refresh
1324
- # But since we can only return one update, we'll ensure value is explicitly empty
1325
- # and that we're not preserving any state from the previous update
1326
-
1327
- # Ensure no items from results_checked are in new_choices (they should already be filtered, but double-check)
1328
  checked_track_ids = {_extract_track_id(x) for x in results_checked}
1329
  new_choices_filtered = [
1330
  c for c in new_choices if _extract_track_id(c) not in checked_track_ids
1331
  ]
1332
 
1333
- # Create update with explicit empty value - this should force Gradio to clear all checked items
 
1334
  fresh_update = gr.update(
1335
  choices=new_choices_filtered,
1336
  value=[], # CRITICAL: Explicitly empty list to clear all checked state
@@ -1366,7 +1385,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1366
  results_choices = (
1367
  results_update.choices if hasattr(results_update, "choices") else []
1368
  )
1369
- except:
1370
  # Fallback: get choices from the search function directly
1371
  results_choices = _get_search_results_choices(
1372
  current_species,
@@ -1395,10 +1414,12 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1395
  ):
1396
  """Update selected tracks when user checks/unchecks items directly."""
1397
  # selected_value contains only the currently checked items
1398
- # Update choices to match the current selections (so unchecked items are removed)
 
1399
  show_selected = bool(selected_value)
1400
 
1401
- # Also update search results to reflect the new selection (tracks that were unchecked can now appear in results)
 
1402
  search_updates = search_bigwigs(current_species, current_query, selected_value)
1403
 
1404
  return (
@@ -1462,7 +1483,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1462
  ]
1463
  # Show selected tracks section if there are default tracks
1464
  show_selected_tracks = bool(default_formatted)
1465
- except:
1466
  formatted_tracks = []
1467
  default_formatted = []
1468
  show_selected_tracks = False
@@ -1594,12 +1615,13 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1594
  try:
1595
  zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
1596
  return gr.update(value=zip_path, visible=True)
1597
- except ImportError as e:
1598
  raise gr.Error(
1599
- "pyBigWig is required for BigWig export. Install with: pip install pyBigWig"
 
1600
  )
1601
- except Exception as e:
1602
- raise gr.Error(f"Error creating BigWig files: {str(e)}")
1603
 
1604
  download_bigwig_btn.click(
1605
  fn=download_bigwig_zip,
 
8
  import gradio as gr
9
  import matplotlib
10
  import matplotlib.colors as mcolors
 
11
  import numpy as np
12
  import plotly.graph_objects as go
 
13
  import torch
14
+ from plotly.subplots import make_subplots
15
 
16
  from bigwig_export import _softmax_last, create_bigwig_zip
17
  from ntv3_tracks_pipeline import (
 
56
  pipe = load_ntv3_tracks_pipeline(
57
  model=model_id,
58
  token=HF_TOKEN,
59
+ device="cpu", # Prevents model.to(cuda) during import
60
  default_species=species,
61
  verbose=False,
62
  )
 
99
  except Exception:
100
 
101
  def gpu(*args, **kwargs):
102
+ """GPU decorator placeholder when spaces module is not available."""
103
+
104
  def wrap(fn):
105
  return fn
106
 
107
  return wrap
108
 
109
 
110
+ def _global_stride(length: int, target: int) -> int:
111
+ if target <= 0 or length <= target:
112
  return 1
113
+ return int(np.ceil(length / target))
114
 
115
 
116
+ def _make_tracks_figure(
117
+ x: np.ndarray, series: list[tuple[str, np.ndarray]], region: str = ""
118
+ ):
119
  """Create an interactive plotly figure with multiple tracks."""
120
  if not series:
121
  raise gr.Error("Nothing to plot (no tracks/elements selected).")
122
 
123
  n = len(series)
124
+
125
  # Create subplots with shared x-axis
126
  fig = make_subplots(
127
  rows=n,
 
143
 
144
  # Convert color to rgba for fill
145
  rgba = mcolors.to_rgba(color)
146
+ rgba_str = (
147
+ f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, 0.3)"
148
+ )
149
+
150
  # Add filled area (fill_between equivalent)
151
  fig.add_trace(
152
  go.Scatter(
 
154
  y=y,
155
  mode="lines",
156
  name=title,
157
+ line={"color": color, "width": 1.5},
158
  fill="tozeroy",
159
  fillcolor=rgba_str,
160
+ hovertemplate=f"<b>{title}</b><br>"
161
+ + "Position: %{x}<br>"
162
+ + "Value: %{y:.4f}<extra></extra>",
163
  showlegend=False,
164
  ),
165
  row=i,
 
170
  fig.update_layout(
171
  height=150 * n, # Adjust height based on number of tracks
172
  width=1200,
173
+ margin={"l": 80, "r": 20, "t": 40, "b": 60},
174
  hovermode="x unified", # Show all values at same x position
175
  template="plotly_white",
176
  )
 
283
 
284
 
285
  def _extract_track_id(display_value: str) -> str:
286
+ """Extract track ID from display format or return as-is."""
287
  if " (" in display_value and display_value.endswith(")"):
288
  # Extract track_id from format "display_name (track_id)"
289
  return display_value.rsplit(" (", 1)[1][:-1]
 
460
 
461
 
462
  def reset_on_species_change(species: str):
463
+ """Reset search and selected tracks when species changes."""
464
  # Clear results + selected when species changes (avoids mismatched IDs)
465
  try:
466
  track_ids = _get_bigwig_names(species) # warms cache if available
 
506
  bigwig_selected: list[str],
507
  bed_elements: list[str],
508
  ):
509
+ """Run prediction and return figure with tracks."""
510
  tprint("start")
511
 
512
  # Debug: verify species is being passed
 
522
  if use_coords:
523
  # Check if this species supports coordinate-based fetching
524
  if species not in SPECIES_WITH_COORDINATE_SUPPORT:
525
+ supported = ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
526
  raise gr.Error(
527
+ f"Species '{species}' does not support coordinate-based sequence "
528
+ f"fetching. Please provide a DNA sequence directly or use one of "
529
+ f"the supported species: {supported}"
530
  )
531
  if not chrom:
532
  raise gr.Error("chrom is required when use_coords=True")
 
545
 
546
  # Verify species is in inputs before calling pipeline
547
  if "species" not in inputs:
548
+ input_keys = list(inputs.keys())
549
  raise gr.Error(
550
+ f"Internal error: species not found in inputs dict. "
551
+ f"Inputs: {input_keys}"
552
  )
553
 
554
  tprint("inputs prepared")
 
586
 
587
  if not has_bigwigs and not has_bed:
588
  raise gr.Error(
589
+ "No BigWig tracks or BED elements available for this species "
590
+ "in the current model."
591
  )
592
 
593
  if not has_bigwigs and bigwig_selected:
594
  raise gr.Error(
595
+ "No BigWig tracks available for this species, but BigWig tracks "
596
+ "were selected. Please deselect BigWig tracks or choose a "
597
+ "different species."
598
  )
599
 
600
  # Defaults if user picked none
 
630
 
631
  # Determine sequence length from available data
632
  if has_bigwigs:
633
+ seq_length = bw.shape[0]
634
  elif has_bed:
635
+ seq_length = bed_logits.shape[0]
636
  else:
637
  raise gr.Error("No data available for plotting.")
638
 
639
+ stride = _global_stride(seq_length, PLOT_TARGET_POINTS)
640
 
641
  x0 = int(out.pred_start or 0)
642
+ x1 = int(out.pred_end or (x0 + seq_length))
643
+ x = np.linspace(x0, x1, num=seq_length, endpoint=False)[::stride]
644
 
645
  series: list[tuple[str, np.ndarray]] = []
646
 
 
658
  series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
659
 
660
  tprint("figure data processed created")
661
+
662
  # Build region string for x-axis label
663
  region = (
664
  f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
665
  )
666
  if out.assembly:
667
  region += f" ({out.assembly})"
668
+
669
  fig = _make_tracks_figure(x, series, region=region)
670
  tprint("figure created")
671
 
 
693
  # -----------------------------
694
  CSS = """
695
  #tracks_plot { position: relative; width: 100% !important; max-width: 100% !important; }
696
+ #tracks_plot .wrap, #tracks_plot .plot-container {
697
+ width: 100% !important;
698
+ max-width: 100% !important;
699
+ }
700
 
701
  #tracks_plot_download {
702
  position: absolute;
 
932
  btn.title = "Download PNG";
933
  btn.innerHTML = `
934
  <svg viewBox="0 0 24 24" aria-hidden="true">
935
+ <path d="M5 20h14v-2H5v2zm7-18v10.17l3.59-3.58L17 10l-5 5-5-5
936
+ 1.41-1.41L11 12.17V2h1z"/>
937
  </svg>
938
  `;
939
  btn.onclick = () => {
 
1041
  <div class="intro-card">
1042
  <h3>2) Choose signals</h3>
1043
  <ul>
1044
+ <li>Search & select <strong>BigWig functional tracks</strong>
1045
+ (RNA-seq, ChIP-seq, DNase…)</li>
1046
+ <li>Select <strong>BED genome annotation elements</strong>
1047
+ (exons, introns, promoters…)</li>
1048
  </ul>
1049
  </div>
1050
 
 
1060
 
1061
  <div class="intro-tip">
1062
  <span class="intro-tip-icon">💡</span>
1063
+ <span><strong>Tip:</strong> The demo includes default settings that you can use
1064
+ to get started, taking ~ 15 seconds to run for the example on human.</span>
1065
  </div>
1066
 
1067
+ <div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03);
1068
+ border-radius: 12px; font-size: 0.95rem;">
1069
  <strong>Available species:</strong> {_all_species_list}<br>
1070
  <br>
1071
  <strong>Species with functional tracks:</strong> {_bigwig_species_list}
 
1080
 
1081
  # Model display names (without InstaDeepAI/ prefix) and their full IDs
1082
  MODEL_OPTIONS = {
1083
+ "NTv3 650M (post)": "InstaDeepAI/NTv3_650M_pos",
1084
+ "NTv3 100M (post)": "InstaDeepAI/NTv3_100M_pos",
1085
  }
1086
 
1087
  # Reverse mapping: full ID -> display name
 
1133
  + ")"
1134
  )
1135
  with gr.Row():
1136
+ chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
1137
+ start = gr.Number(label="Start", value=_default_coords["start"], precision=0)
1138
+ end = gr.Number(label="End", value=_default_coords["end"], precision=0)
 
 
1139
 
1140
  # DNA sequence section - visible only when "Enter DNA sequence" is selected
1141
  # Using Textbox directly (not wrapped in Group) to avoid visual border/line
 
1208
  )
1209
 
1210
  bigwig_no_tracks_msg = gr.Markdown(
1211
+ "⚠️ No functional genomic tracks available for this species "
1212
+ "in the current model.",
1213
  visible=False,
1214
  )
1215
 
 
1338
  current_species, current_query, upd["value"]
1339
  )
1340
 
1341
+ # Create a completely fresh update with explicit empty value
1342
+ # to prevent any checked state. Force Gradio to clear checked state
1343
+ # by explicitly setting value to empty list.
1344
+ # Ensure no items from results_checked are in new_choices
1345
+ # (they should already be filtered, but double-check)
 
 
1346
  checked_track_ids = {_extract_track_id(x) for x in results_checked}
1347
  new_choices_filtered = [
1348
  c for c in new_choices if _extract_track_id(c) not in checked_track_ids
1349
  ]
1350
 
1351
+ # Create update with explicit empty value
1352
+ # This should force Gradio to clear all checked items
1353
  fresh_update = gr.update(
1354
  choices=new_choices_filtered,
1355
  value=[], # CRITICAL: Explicitly empty list to clear all checked state
 
1385
  results_choices = (
1386
  results_update.choices if hasattr(results_update, "choices") else []
1387
  )
1388
+ except Exception:
1389
  # Fallback: get choices from the search function directly
1390
  results_choices = _get_search_results_choices(
1391
  current_species,
 
1414
  ):
1415
  """Update selected tracks when user checks/unchecks items directly."""
1416
  # selected_value contains only the currently checked items
1417
+ # Update choices to match current selections
1418
+ # (unchecked items are removed)
1419
  show_selected = bool(selected_value)
1420
 
1421
+ # Also update search results to reflect new selection
1422
+ # (unchecked tracks can now appear in results)
1423
  search_updates = search_bigwigs(current_species, current_query, selected_value)
1424
 
1425
  return (
 
1483
  ]
1484
  # Show selected tracks section if there are default tracks
1485
  show_selected_tracks = bool(default_formatted)
1486
+ except Exception:
1487
  formatted_tracks = []
1488
  default_formatted = []
1489
  show_selected_tracks = False
 
1615
  try:
1616
  zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
1617
  return gr.update(value=zip_path, visible=True)
1618
+ except ImportError:
1619
  raise gr.Error(
1620
+ "pyBigWig is required for BigWig export. "
1621
+ "Install with: pip install pyBigWig"
1622
  )
1623
+ except Exception as exc:
1624
+ raise gr.Error(f"Error creating BigWig files: {str(exc)}")
1625
 
1626
  download_bigwig_btn.click(
1627
  fn=download_bigwig_zip,
bigwig_export.py CHANGED
@@ -11,9 +11,9 @@ from typing import TYPE_CHECKING
11
  import numpy as np
12
 
13
  try:
14
- import pyBigWig
15
  except ImportError:
16
- pyBigWig = None
17
 
18
  if TYPE_CHECKING:
19
  from ntv3_tracks_pipeline import NTv3TracksOutput
@@ -75,16 +75,20 @@ def create_bigwig_zip(
75
  chrom = out.chrom
76
  if chrom is None:
77
  raise ValueError(
78
- "Chromosome information not available. Use genomic coordinates for BigWig export."
79
  )
80
 
81
  start = out.start
82
  end = out.end
 
 
83
  window_len = out.window_len or (end - start)
84
 
85
  # Calculate prediction region (center 37.5%)
86
- pred_start = out.pred_start or (start + int(window_len * 0.3125))
87
- pred_end = out.pred_end or (pred_start + int(window_len * 0.375))
 
 
88
 
89
  # Create temporary directory for BigWig files
90
  tmpdir = tempfile.gettempdir()
@@ -160,11 +164,11 @@ def create_bigwig_zip(
160
  for bw_file in created_files:
161
  try:
162
  os.remove(bw_file)
163
- except:
164
  pass
165
  try:
166
  os.rmdir(output_dir)
167
- except:
168
  pass
169
 
170
  return zip_path
 
11
  import numpy as np
12
 
13
  try:
14
+ import pyBigWig # noqa: N816
15
  except ImportError:
16
+ pyBigWig = None # noqa: N816
17
 
18
  if TYPE_CHECKING:
19
  from ntv3_tracks_pipeline import NTv3TracksOutput
 
75
  chrom = out.chrom
76
  if chrom is None:
77
  raise ValueError(
78
+ "Chromosome information not available. Use genomic coordinates."
79
  )
80
 
81
  start = out.start
82
  end = out.end
83
+ if start is None or end is None:
84
+ raise ValueError("Start and end coordinates are required for BigWig export.")
85
  window_len = out.window_len or (end - start)
86
 
87
  # Calculate prediction region (center 37.5%)
88
+ if out.pred_start is not None:
89
+ pred_start = out.pred_start
90
+ else:
91
+ pred_start = start + int(window_len * 0.3125)
92
 
93
  # Create temporary directory for BigWig files
94
  tmpdir = tempfile.gettempdir()
 
164
  for bw_file in created_files:
165
  try:
166
  os.remove(bw_file)
167
+ except Exception:
168
  pass
169
  try:
170
  os.rmdir(output_dir)
171
+ except Exception:
172
  pass
173
 
174
  return zip_path
ntv3_tracks_pipeline.py CHANGED
@@ -74,13 +74,13 @@ SPECIES_WITH_COORDINATE_SUPPORT = {
74
  # Assembly -> API URL template mapping
75
  # ---------------------------------------------------------------------
76
  # Default API URL template (UCSC format) that works for most species
77
- DEFAULT_API_URL_TEMPLATE = "https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}"
78
 
79
  # for species with different format, add the assembly name to the mapping
80
  # The template should use {chrom}, {start}, and {end} as placeholders.
81
  ASSEMBLY_TO_API_URL_TEMPLATE = {
82
  # Arabidopsis thaliana (TAIR10) - uses hub URL format
83
- "TAIR10": "https://api.genome.ucsc.edu/getData/sequence?hubUrl=http://genome.ucsc.edu/goldenPath/help/examples/hubExamples/hubAssembly/plantAraTha1/hub.txt;genome=araTha1;chrom={chrom};start={start};end={end}",
84
  }
85
 
86
 
@@ -124,7 +124,8 @@ def _get_dna_sequence(assembly: str, chrom: str, start: int, end: int) -> str:
124
  """
125
  if requests is None:
126
  raise ImportError(
127
- "requests is required for genome download. Install with: pip install requests"
 
128
  )
129
 
130
  # Get API URL template for this assembly, or use default
@@ -151,12 +152,11 @@ def _ensure_fasta_for_assembly(assembly: str, cache_dir: str | Path) -> Path:
151
  if fa_path.exists():
152
  return fa_path
153
 
154
- if assembly not in ASSEMBLY_TO_UCSC_FA_GZ:
155
- raise ValueError(
156
- f"No download URL configured for assembly='{assembly}'. "
157
- f"Supported for auto-download: {sorted(ASSEMBLY_TO_UCSC_FA_GZ.keys())}. "
158
- f"Either pass fasta_path explicitly, or extend ASSEMBLY_TO_UCSC_FA_GZ."
159
- )
160
 
161
  import gzip
162
 
@@ -340,7 +340,8 @@ class NTv3TracksPipeline(Pipeline):
340
  else:
341
  self.tokenizer = tokenizer
342
 
343
- # Extract model_id from config if not already set (following ntv3_gff_pipeline.py pattern)
 
344
  if self.model_id is None and self.config is not None:
345
  self.model_id = getattr(self.config, "_name_or_path", None) or getattr(
346
  self.config, "name_or_path", None
@@ -374,29 +375,57 @@ class NTv3TracksPipeline(Pipeline):
374
  model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs
375
  )
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  def _sanitize_parameters(self, **kwargs):
378
  return {}, {}, {}
379
 
380
- def _get_model_device(self) -> torch.device:
381
  return next(self.model.parameters()).device
382
 
383
  def _resolve_species_and_assembly(self, inputs: dict[str, Any]) -> tuple[str, str]:
384
  species = inputs.get("species", self.default_species)
385
  if species not in SPECIES_TO_ASSEMBLY:
 
386
  raise ValueError(
387
- f"Unsupported species='{species}'. Supported species: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
388
  )
389
  assembly = SPECIES_TO_ASSEMBLY[species]
390
 
391
  cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
392
  if assembly not in cfg_assemblies:
393
  raise ValueError(
394
- f"Species '{species}' maps to assembly '{assembly}', but that assembly is not available in this checkpoint. "
 
395
  f"Available assemblies: {cfg_assemblies}"
396
  )
397
  return species, assembly
398
 
399
- def _maybe_force_cpu_for_mps_long(
400
  self, input_ids_cpu: torch.Tensor
401
  ) -> torch.device:
402
  dev = self._get_model_device()
@@ -405,40 +434,15 @@ class NTv3TracksPipeline(Pipeline):
405
  if seq_len >= self.mps_force_cpu_length:
406
  if self.verbose:
407
  print(
408
- f"[NTv3TracksPipeline] MPS detected and input is long (tokens={seq_len}). "
409
- "Switching model + inputs to CPU for this run."
 
410
  )
411
  self.model.to("cpu")
412
  self.model.eval()
413
  return torch.device("cpu")
414
  return dev
415
 
416
- def available_bigwig_track_names(self, species: str | None = None) -> list[str]:
417
- """
418
- Return BigWig track IDs for the assembly corresponding to `species`.
419
- No model forward pass.
420
- """
421
- sp = species or self.default_species
422
- assembly = SPECIES_TO_ASSEMBLY.get(sp)
423
- if assembly is None:
424
- raise ValueError(
425
- f"Unknown species={sp}. Supported: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
426
- )
427
-
428
- if assembly not in self.config.bigwigs_per_file_assembly:
429
- raise ValueError(
430
- f"Assembly {assembly} not found in checkpoint config. "
431
- f"Available: {list(self.config.bigwigs_per_file_assembly.keys())}"
432
- )
433
-
434
- return list(self.config.bigwigs_per_file_assembly[assembly])
435
-
436
- def available_bed_element_names(self) -> list[str]:
437
- """
438
- Return BED element names available in this checkpoint (no forward pass).
439
- """
440
- return list(self.bed_element_names or [])
441
-
442
  def preprocess(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
443
  species, assembly = self._resolve_species_and_assembly(inputs)
444
 
@@ -506,19 +510,6 @@ class NTv3TracksPipeline(Pipeline):
506
  def forward(self, model_inputs, **forward_params):
507
  return self._forward(model_inputs, **forward_params)
508
 
509
- def _forward(self, model_inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
510
- meta = model_inputs.pop("meta")
511
- if self.verbose:
512
- print(f"Running on device: {self._get_model_device()}")
513
- with torch.no_grad():
514
- out = self.model(
515
- input_ids=model_inputs["input_ids"],
516
- species_ids=model_inputs["species_ids"],
517
- return_dict=True,
518
- )
519
- out["meta"] = meta
520
- return out
521
-
522
  def postprocess(
523
  self, model_outputs: dict[str, Any], **kwargs: Any
524
  ) -> NTv3TracksOutput:
@@ -565,6 +556,19 @@ class NTv3TracksPipeline(Pipeline):
565
  pred_end=meta.get("pred_end"),
566
  )
567
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  def __call__(
569
  self,
570
  inputs,
@@ -584,7 +588,8 @@ class NTv3TracksPipeline(Pipeline):
584
  if plot:
585
  if out.bigwig_track_names is None:
586
  raise ValueError(
587
- "bigwig_track_names missing; expected cfg.bigwigs_per_file_assembly[assembly]."
 
588
  )
589
  if out.bed_element_names is None:
590
  raise ValueError("bed element names missing from config.")
@@ -600,17 +605,22 @@ class NTv3TracksPipeline(Pipeline):
600
  ]
601
  if missing_tracks:
602
  raise ValueError(
603
- f"The following tracks are not available in bigwig_names: {missing_tracks}\n"
604
- f"First 50 available: {bigwig_names[:50]}{'...' if len(bigwig_names) > 50 else ''}"
 
 
605
  )
606
 
607
  missing_elements = [
608
  e for e in elements_to_plot if e not in bed_element_names
609
  ]
610
  if missing_elements:
 
 
611
  raise ValueError(
612
- f"The following elements are not available in bed_element_names: {missing_elements}\n"
613
- f"First 50 available: {bed_element_names[:50]}{'...' if len(bed_element_names) > 50 else ''}"
 
614
  )
615
 
616
  # Build bigwig tracks dict (title -> y)
@@ -662,7 +672,8 @@ def load_ntv3_tracks_pipeline(
662
  device:
663
  "auto", "cpu", "cuda", "mps"
664
  pipeline_kwargs:
665
- Extra kwargs passed to NTv3TracksPipeline (default_species, genome_cache_dir, etc.).
 
666
  """
667
  pipe = NTv3TracksPipeline(
668
  model=model,
 
74
  # Assembly -> API URL template mapping
75
  # ---------------------------------------------------------------------
76
  # Default API URL template (UCSC format) that works for most species
77
+ DEFAULT_API_URL_TEMPLATE = "https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}" # noqa: E501
78
 
79
  # for species with different format, add the assembly name to the mapping
80
  # The template should use {chrom}, {start}, and {end} as placeholders.
81
  ASSEMBLY_TO_API_URL_TEMPLATE = {
82
  # Arabidopsis thaliana (TAIR10) - uses hub URL format
83
+ "TAIR10": "https://api.genome.ucsc.edu/getData/sequence?hubUrl=http://genome.ucsc.edu/goldenPath/help/examples/hubExamples/hubAssembly/plantAraTha1/hub.txt;genome=araTha1;chrom={chrom};start={start};end={end}", # noqa: E501
84
  }
85
 
86
 
 
124
  """
125
  if requests is None:
126
  raise ImportError(
127
+ "requests is required for genome download. "
128
+ "Install with: pip install requests"
129
  )
130
 
131
  # Get API URL template for this assembly, or use default
 
152
  if fa_path.exists():
153
  return fa_path
154
 
155
+ # This function is deprecated - use _get_dna_sequence with API instead
156
+ raise ValueError(
157
+ f"FASTA file download is no longer supported for assembly='{assembly}'. "
158
+ f"Please use _get_dna_sequence() with API-based sequence fetching instead."
159
+ )
 
160
 
161
  import gzip
162
 
 
340
  else:
341
  self.tokenizer = tokenizer
342
 
343
+ # Extract model_id from config if not already set
344
+ # (following ntv3_gff_pipeline.py pattern)
345
  if self.model_id is None and self.config is not None:
346
  self.model_id = getattr(self.config, "_name_or_path", None) or getattr(
347
  self.config, "name_or_path", None
 
375
  model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs
376
  )
377
 
378
+ def available_bigwig_track_names(self, species: str | None = None) -> list[str]:
379
+ """
380
+ Return BigWig track IDs for the assembly corresponding to `species`.
381
+ No model forward pass.
382
+ """
383
+ sp = species or self.default_species
384
+ assembly = SPECIES_TO_ASSEMBLY.get(sp)
385
+ if assembly is None:
386
+ raise ValueError(
387
+ f"Unknown species={sp}. Supported: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
388
+ )
389
+
390
+ if assembly not in self.config.bigwigs_per_file_assembly:
391
+ raise ValueError(
392
+ f"Assembly {assembly} not found in checkpoint config. "
393
+ f"Available: {list(self.config.bigwigs_per_file_assembly.keys())}"
394
+ )
395
+
396
+ return list(self.config.bigwigs_per_file_assembly[assembly])
397
+
398
+ def available_bed_element_names(self) -> list[str]:
399
+ """
400
+ Return BED element names available in this checkpoint (no forward pass).
401
+ """
402
+ return list(self.bed_element_names or [])
403
+
404
  def _sanitize_parameters(self, **kwargs):
405
  return {}, {}, {}
406
 
407
+ def _get_model_device(self) -> torch.device: # noqa: CCE001
408
  return next(self.model.parameters()).device
409
 
410
  def _resolve_species_and_assembly(self, inputs: dict[str, Any]) -> tuple[str, str]:
411
  species = inputs.get("species", self.default_species)
412
  if species not in SPECIES_TO_ASSEMBLY:
413
+ supported = sorted(SPECIES_TO_ASSEMBLY.keys())
414
  raise ValueError(
415
+ f"Unsupported species='{species}'. " f"Supported species: {supported}"
416
  )
417
  assembly = SPECIES_TO_ASSEMBLY[species]
418
 
419
  cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
420
  if assembly not in cfg_assemblies:
421
  raise ValueError(
422
+ f"Species '{species}' maps to assembly '{assembly}', "
423
+ f"but that assembly is not available in this checkpoint. "
424
  f"Available assemblies: {cfg_assemblies}"
425
  )
426
  return species, assembly
427
 
428
+ def _maybe_force_cpu_for_mps_long( # noqa: CCE001
429
  self, input_ids_cpu: torch.Tensor
430
  ) -> torch.device:
431
  dev = self._get_model_device()
 
434
  if seq_len >= self.mps_force_cpu_length:
435
  if self.verbose:
436
  print(
437
+ f"[NTv3TracksPipeline] MPS detected and input is long "
438
+ f"(tokens={seq_len}). Switching model + inputs to CPU "
439
+ "for this run."
440
  )
441
  self.model.to("cpu")
442
  self.model.eval()
443
  return torch.device("cpu")
444
  return dev
445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  def preprocess(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
447
  species, assembly = self._resolve_species_and_assembly(inputs)
448
 
 
510
  def forward(self, model_inputs, **forward_params):
511
  return self._forward(model_inputs, **forward_params)
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  def postprocess(
514
  self, model_outputs: dict[str, Any], **kwargs: Any
515
  ) -> NTv3TracksOutput:
 
556
  pred_end=meta.get("pred_end"),
557
  )
558
 
559
+ def _forward(self, model_inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
560
+ meta = model_inputs.pop("meta")
561
+ if self.verbose:
562
+ print(f"Running on device: {self._get_model_device()}")
563
+ with torch.no_grad():
564
+ out = self.model(
565
+ input_ids=model_inputs["input_ids"],
566
+ species_ids=model_inputs["species_ids"],
567
+ return_dict=True,
568
+ )
569
+ out["meta"] = meta
570
+ return out
571
+
572
  def __call__(
573
  self,
574
  inputs,
 
588
  if plot:
589
  if out.bigwig_track_names is None:
590
  raise ValueError(
591
+ "bigwig_track_names missing; expected "
592
+ "cfg.bigwigs_per_file_assembly[assembly]."
593
  )
594
  if out.bed_element_names is None:
595
  raise ValueError("bed element names missing from config.")
 
605
  ]
606
  if missing_tracks:
607
  raise ValueError(
608
+ f"The following tracks are not available in "
609
+ f"bigwig_names: {missing_tracks}\n"
610
+ f"First 50 available: {bigwig_names[:50]}"
611
+ f"{'...' if len(bigwig_names) > 50 else ''}"
612
  )
613
 
614
  missing_elements = [
615
  e for e in elements_to_plot if e not in bed_element_names
616
  ]
617
  if missing_elements:
618
+ first_50 = bed_element_names[:50]
619
+ ellipsis = "..." if len(bed_element_names) > 50 else ""
620
  raise ValueError(
621
+ f"The following elements are not available in "
622
+ f"bed_element_names: {missing_elements}\n"
623
+ f"First 50 available: {first_50}{ellipsis}"
624
  )
625
 
626
  # Build bigwig tracks dict (title -> y)
 
672
  device:
673
  "auto", "cpu", "cuda", "mps"
674
  pipeline_kwargs:
675
+ Extra kwargs passed to NTv3TracksPipeline
676
+ (default_species, genome_cache_dir, etc.).
677
  """
678
  pipe = NTv3TracksPipeline(
679
  model=model,
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  gradio>=4.0.0
 
2
  matplotlib
3
  numpy
4
  plotly
5
- kaleido
6
  pyBigWig
7
  pyfaidx
8
  requests
 
1
  gradio>=4.0.0
2
+ kaleido
3
  matplotlib
4
  numpy
5
  plotly
 
6
  pyBigWig
7
  pyfaidx
8
  requests