bernardo-de-almeida commited on
Commit
25532d4
·
1 Parent(s): 9daf043

feat: add features to handle better different species

Browse files
Files changed (2) hide show
  1. app.py +171 -31
  2. ntv3_tracks_pipeline.py +14 -5
app.py CHANGED
@@ -12,7 +12,12 @@ matplotlib.use('Agg')
12
 
13
  import matplotlib.pyplot as plt
14
 
15
- from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline, BED_ELEMENT_COLORS
 
 
 
 
 
16
  from bigwig_export import create_bigwig_zip, _softmax_last
17
 
18
 
@@ -116,6 +121,28 @@ def _get_bigwig_names(species: str) -> list[str]:
116
  return _BIGWIG_CACHE[species]
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  def _rank_search(query: str, names: list[str], limit: int) -> list[str]:
120
  """
121
  Return up to `limit` candidate track IDs matching `query` using a fast,
@@ -180,7 +207,11 @@ def update_coords_on_species_change(species: str):
180
 
181
  def reset_on_species_change(species: str):
182
  # Clear results + selected when species changes (avoids mismatched IDs)
183
- _get_bigwig_names(species) # warms cache
 
 
 
 
184
  return (
185
  gr.update(value=""), # query textbox
186
  gr.update(choices=[], value=[]), # results list
@@ -201,7 +232,18 @@ def predict(
201
  bigwig_selected: list[str],
202
  bed_elements: list[str],
203
  ):
 
 
 
 
204
  if use_coords:
 
 
 
 
 
 
 
205
  if not chrom:
206
  raise gr.Error("chrom is required when use_coords=True")
207
  if start is None or end is None or int(end) <= int(start):
@@ -211,7 +253,11 @@ def predict(
211
  if not seq or not seq.strip():
212
  raise gr.Error("seq is required when use_coords=False")
213
  inputs = {"seq": seq.strip(), "species": species}
214
-
 
 
 
 
215
  out = pipe(inputs)
216
 
217
  bw_names = out.bigwig_track_names or []
@@ -219,11 +265,18 @@ def predict(
219
  bed_names = out.bed_element_names or []
220
  bed_logits = out.bed_tracks_logits
221
 
222
- if bw is None or not bw_names:
223
- raise gr.Error("No BigWig tracks available in model output.")
 
 
 
 
 
 
 
224
 
225
  # Defaults if user picked none
226
- if not bigwig_selected:
227
  default_bigwig_tracks = [
228
  "ENCSR056HPM", # K562 RNA-seq
229
  "ENCSR921NMD", # K562 DNAse
@@ -236,21 +289,31 @@ def predict(
236
  ]
237
  # Filter to only include tracks that are available for this species/assembly
238
  bigwig_selected = [tid for tid in default_bigwig_tracks if tid in bw_names]
 
239
  if (not bed_elements) and bed_names:
240
  default_bed_elements = ["protein_coding_gene", "exon", "intron"]
241
  # Filter to only include elements that are available
242
  bed_elements = [elem for elem in default_bed_elements if elem in bed_names]
243
 
244
  # Validate (important for API usage)
245
- missing_tracks = [t for t in bigwig_selected if t not in bw_names]
246
- if missing_tracks:
247
- raise gr.Error(f"Unknown BigWig track id(s): {missing_tracks}")
248
-
249
- missing_elems = [e for e in bed_elements if e not in bed_names]
250
- if missing_elems:
251
- raise gr.Error(f"Unknown BED element(s): {missing_elems}")
252
-
253
- L = bw.shape[0]
 
 
 
 
 
 
 
 
 
254
  stride = _global_stride(L, PLOT_TARGET_POINTS)
255
 
256
  x0 = int(out.pred_start or 0)
@@ -258,10 +321,14 @@ def predict(
258
  x = np.linspace(x0, x1, num=L, endpoint=False)[::stride]
259
 
260
  series: list[tuple[str, np.ndarray]] = []
261
- for tid in bigwig_selected:
262
- idx = bw_names.index(tid)
263
- series.append((tid, bw[:, idx][::stride].astype(float)))
 
 
 
264
 
 
265
  if bed_logits is not None and bed_elements:
266
  probs = _softmax_last(bed_logits)
267
  for ename in bed_elements:
@@ -532,9 +599,24 @@ DEFAULT_COORDS = {
532
  # Get default coordinates for default species
533
  _default_coords = DEFAULT_COORDS.get(DEFAULT_SPECIES, DEFAULT_COORDS["human"])
534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  with gr.Blocks(title="NTv3 Tracks Demo") as demo:
536
  gr.Markdown(
537
- """
538
  <div class="intro-hero">
539
 
540
  <div class="intro-title">
@@ -543,9 +625,6 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
543
  Predict and visualize functional genomics signals directly from DNA using
544
  <strong>Nucleotide Transformer v3</strong>.
545
  </p>
546
- <p style="margin-top: 8px; font-size: 0.95rem; opacity: 0.85;">
547
- <strong>Currently available species:</strong> Human, Mouse, Drosophila melanogaster, Arabidopsis thaliana, Gorilla
548
- </p>
549
  </div>
550
 
551
  <div class="intro-grid">
@@ -581,6 +660,12 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
581
  <span><strong>Tip:</strong> The demo includes default settings that you can use to get started, taking ~ 1 minute to run.</span>
582
  </div>
583
 
 
 
 
 
 
 
584
  </div>
585
  """,
586
  elem_id="intro_markdown",
@@ -609,17 +694,26 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
609
 
610
  model_status = gr.Markdown("", visible=False)
611
 
612
- gr.Markdown("## Input sequence (Genomic coordinate or DNA sequence)")
 
 
 
 
 
613
 
614
  with gr.Row():
615
  species = gr.Dropdown(
616
- ["human", "mouse", "drosophila_melanogaster", "arabidopsis_thaliana", "gorilla_gorilla"],
617
  value=DEFAULT_SPECIES,
618
  label="Species",
619
  )
620
- use_coords = gr.Checkbox(True, label="Use genome coordinates")
 
 
 
 
621
 
622
- with gr.Row():
623
  chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
624
  start = gr.Number(label="Start", value=_default_coords["start"], precision=0)
625
  end = gr.Number(label="End", value=_default_coords["end"], precision=0)
@@ -650,6 +744,11 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
650
  )
651
 
652
  gr.Markdown("## Select functional tracks")
 
 
 
 
 
653
 
654
  bigwig_selected = gr.CheckboxGroup(
655
  choices=_init_bigwig_selected,
@@ -667,7 +766,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
667
  label="Results (click to add to Selected)",
668
  )
669
 
670
- with gr.Row():
671
  bigwig_clear_btn = gr.Button("Clear results")
672
  bigwig_remove_btn = gr.Button("Remove checked from Selected")
673
 
@@ -743,11 +842,52 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
743
  outputs=[bigwig_query, bigwig_results, bigwig_selected],
744
  )
745
 
746
- # Update coordinates when species changes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
747
  species.change(
748
- fn=update_coords_on_species_change,
749
- inputs=[species],
750
- outputs=[chrom, start, end],
 
 
 
 
 
 
 
 
 
751
  )
752
 
753
  btn.click(
 
12
 
13
  import matplotlib.pyplot as plt
14
 
15
+ from ntv3_tracks_pipeline import (
16
+ load_ntv3_tracks_pipeline,
17
+ BED_ELEMENT_COLORS,
18
+ ASSEMBLY_TO_SPECIES,
19
+ SPECIES_WITH_COORDINATE_SUPPORT,
20
+ )
21
  from bigwig_export import create_bigwig_zip, _softmax_last
22
 
23
 
 
121
  return _BIGWIG_CACHE[species]
122
 
123
 
124
+ def _has_bigwigs(species: str) -> bool:
125
+ """Check if a species has BigWig tracks available in the current model."""
126
+ try:
127
+ tracks = _get_bigwig_names(species)
128
+ return len(tracks) > 0
129
+ except (ValueError, AttributeError):
130
+ # Species not in config or pipeline not loaded
131
+ return False
132
+
133
+
134
+ def _get_species_with_bigwigs() -> set[str]:
135
+ """Get set of species that have BigWig tracks available in the current model."""
136
+ if pipe is None:
137
+ return set()
138
+
139
+ species_with_bigwigs = set()
140
+ for species in ASSEMBLY_TO_SPECIES.values():
141
+ if _has_bigwigs(species):
142
+ species_with_bigwigs.add(species)
143
+ return species_with_bigwigs
144
+
145
+
146
  def _rank_search(query: str, names: list[str], limit: int) -> list[str]:
147
  """
148
  Return up to `limit` candidate track IDs matching `query` using a fast,
 
207
 
208
  def reset_on_species_change(species: str):
209
  # Clear results + selected when species changes (avoids mismatched IDs)
210
+ try:
211
+ _get_bigwig_names(species) # warms cache if available
212
+ except (ValueError, AttributeError):
213
+ # Species doesn't have bigwigs, that's okay
214
+ pass
215
  return (
216
  gr.update(value=""), # query textbox
217
  gr.update(choices=[], value=[]), # results list
 
232
  bigwig_selected: list[str],
233
  bed_elements: list[str],
234
  ):
235
+ # Debug: verify species is being passed
236
+ if not species:
237
+ raise gr.Error("Species parameter is missing. Please select a species.")
238
+
239
  if use_coords:
240
+ # Check if this species supports coordinate-based fetching
241
+ if species not in SPECIES_WITH_COORDINATE_SUPPORT:
242
+ raise gr.Error(
243
+ f"Species '{species}' does not support coordinate-based sequence fetching. "
244
+ f"Please provide a DNA sequence directly or use one of the supported species: "
245
+ f"{', '.join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))}"
246
+ )
247
  if not chrom:
248
  raise gr.Error("chrom is required when use_coords=True")
249
  if start is None or end is None or int(end) <= int(start):
 
253
  if not seq or not seq.strip():
254
  raise gr.Error("seq is required when use_coords=False")
255
  inputs = {"seq": seq.strip(), "species": species}
256
+
257
+ # Verify species is in inputs before calling pipeline
258
+ if "species" not in inputs:
259
+ raise gr.Error(f"Internal error: species not found in inputs dict. Inputs: {list(inputs.keys())}")
260
+
261
  out = pipe(inputs)
262
 
263
  bw_names = out.bigwig_track_names or []
 
265
  bed_names = out.bed_element_names or []
266
  bed_logits = out.bed_tracks_logits
267
 
268
+ # Check if we have any tracks/elements to plot
269
+ has_bigwigs = bw is not None and len(bw_names) > 0
270
+ has_bed = bed_logits is not None and len(bed_names) > 0
271
+
272
+ if not has_bigwigs and not has_bed:
273
+ raise gr.Error("No BigWig tracks or BED elements available for this species in the current model.")
274
+
275
+ if not has_bigwigs and bigwig_selected:
276
+ raise gr.Error("No BigWig tracks available for this species, but BigWig tracks were selected. Please deselect BigWig tracks or choose a different species.")
277
 
278
  # Defaults if user picked none
279
+ if has_bigwigs and not bigwig_selected:
280
  default_bigwig_tracks = [
281
  "ENCSR056HPM", # K562 RNA-seq
282
  "ENCSR921NMD", # K562 DNAse
 
289
  ]
290
  # Filter to only include tracks that are available for this species/assembly
291
  bigwig_selected = [tid for tid in default_bigwig_tracks if tid in bw_names]
292
+
293
  if (not bed_elements) and bed_names:
294
  default_bed_elements = ["protein_coding_gene", "exon", "intron"]
295
  # Filter to only include elements that are available
296
  bed_elements = [elem for elem in default_bed_elements if elem in bed_names]
297
 
298
  # Validate (important for API usage)
299
+ if has_bigwigs and bigwig_selected:
300
+ missing_tracks = [t for t in bigwig_selected if t not in bw_names]
301
+ if missing_tracks:
302
+ raise gr.Error(f"Unknown BigWig track id(s): {missing_tracks}")
303
+
304
+ if bed_elements:
305
+ missing_elems = [e for e in bed_elements if e not in bed_names]
306
+ if missing_elems:
307
+ raise gr.Error(f"Unknown BED element(s): {missing_elems}")
308
+
309
+ # Determine sequence length from available data
310
+ if has_bigwigs:
311
+ L = bw.shape[0]
312
+ elif has_bed:
313
+ L = bed_logits.shape[0]
314
+ else:
315
+ raise gr.Error("No data available for plotting.")
316
+
317
  stride = _global_stride(L, PLOT_TARGET_POINTS)
318
 
319
  x0 = int(out.pred_start or 0)
 
321
  x = np.linspace(x0, x1, num=L, endpoint=False)[::stride]
322
 
323
  series: list[tuple[str, np.ndarray]] = []
324
+
325
+ # Add BigWig tracks if available and selected
326
+ if has_bigwigs and bigwig_selected:
327
+ for tid in bigwig_selected:
328
+ idx = bw_names.index(tid)
329
+ series.append((tid, bw[:, idx][::stride].astype(float)))
330
 
331
+ # Add BED elements if available and selected
332
  if bed_logits is not None and bed_elements:
333
  probs = _softmax_last(bed_logits)
334
  for ename in bed_elements:
 
599
  # Get default coordinates for default species
600
  _default_coords = DEFAULT_COORDS.get(DEFAULT_SPECIES, DEFAULT_COORDS["human"])
601
 
602
+ # Format species names for display (replace underscores with spaces, capitalize)
603
+ def _format_species_name(species: str) -> str:
604
+ """Format species name for display."""
605
+ return species.replace("_", " ").title()
606
+
607
+ # Get all available species and format them
608
+ _all_species = sorted(ASSEMBLY_TO_SPECIES.values())
609
+ _all_species_formatted = [_format_species_name(s) for s in _all_species]
610
+ _all_species_list = ", ".join(_all_species_formatted)
611
+
612
+ # Get species with BigWig tracks
613
+ _species_with_bigwigs = _get_species_with_bigwigs()
614
+ _bigwig_species_formatted = sorted([_format_species_name(s) for s in _species_with_bigwigs])
615
+ _bigwig_species_list = ", ".join(_bigwig_species_formatted) if _bigwig_species_formatted else "None (BED elements only)"
616
+
617
  with gr.Blocks(title="NTv3 Tracks Demo") as demo:
618
  gr.Markdown(
619
+ f"""
620
  <div class="intro-hero">
621
 
622
  <div class="intro-title">
 
625
  Predict and visualize functional genomics signals directly from DNA using
626
  <strong>Nucleotide Transformer v3</strong>.
627
  </p>
 
 
 
628
  </div>
629
 
630
  <div class="intro-grid">
 
660
  <span><strong>Tip:</strong> The demo includes default settings that you can use to get started, taking ~ 1 minute to run.</span>
661
  </div>
662
 
663
+ <div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03); border-radius: 12px; font-size: 0.95rem;">
664
+ <strong>Available species:</strong> {_all_species_list}<br>
665
+ <br>
666
+ <strong>Species with functional tracks:</strong> {_bigwig_species_list}
667
+ </div>
668
+
669
  </div>
670
  """,
671
  elem_id="intro_markdown",
 
694
 
695
  model_status = gr.Markdown("", visible=False)
696
 
697
+ gr.Markdown("## Input sequence (Genomic coordinate or DNA sequence)\n"
698
+ "Supported species for coordinate-based sequence fetching: " + ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT)) + "\n"
699
+ )
700
+
701
+ # Get all available species from the pipeline
702
+ all_species = sorted(ASSEMBLY_TO_SPECIES.values())
703
 
704
  with gr.Row():
705
  species = gr.Dropdown(
706
+ choices=all_species,
707
  value=DEFAULT_SPECIES,
708
  label="Species",
709
  )
710
+ use_coords = gr.Checkbox(
711
+ True,
712
+ label="Use genome coordinate (only for supported species)",
713
+ visible=DEFAULT_SPECIES in SPECIES_WITH_COORDINATE_SUPPORT
714
+ )
715
 
716
+ with gr.Row(visible=True) as coords_row:
717
  chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
718
  start = gr.Number(label="Start", value=_default_coords["start"], precision=0)
719
  end = gr.Number(label="End", value=_default_coords["end"], precision=0)
 
744
  )
745
 
746
  gr.Markdown("## Select functional tracks")
747
+
748
+ bigwig_no_tracks_msg = gr.Markdown(
749
+ "⚠️ No functional genomic tracks available for this species in the current model.",
750
+ visible=False,
751
+ )
752
 
753
  bigwig_selected = gr.CheckboxGroup(
754
  choices=_init_bigwig_selected,
 
766
  label="Results (click to add to Selected)",
767
  )
768
 
769
+ with gr.Row(visible=True) as bigwig_buttons_row:
770
  bigwig_clear_btn = gr.Button("Clear results")
771
  bigwig_remove_btn = gr.Button("Remove checked from Selected")
772
 
 
842
  outputs=[bigwig_query, bigwig_results, bigwig_selected],
843
  )
844
 
845
+ # Update coordinates visibility and values when species changes
846
+ def update_on_species_change(species: str, use_coords_val: bool):
847
+ """Update coordinates visibility and values when species changes."""
848
+ is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
849
+ has_bigwigs = _has_bigwigs(species)
850
+ coords = DEFAULT_COORDS.get(species, DEFAULT_COORDS["human"])
851
+ # Show coordinates only if species is supported AND use_coords is True
852
+ show_coords = is_supported and use_coords_val
853
+ return (
854
+ gr.update(visible=show_coords, value=coords["chrom"]),
855
+ gr.update(visible=show_coords, value=coords["start"]),
856
+ gr.update(visible=show_coords, value=coords["end"]),
857
+ gr.update(value=is_supported, visible=is_supported), # Show/hide and enable use_coords only if supported
858
+ gr.update(visible=show_coords), # Show/hide the row
859
+ gr.update(visible=not has_bigwigs), # Show "no tracks" message if no bigwigs
860
+ gr.update(visible=has_bigwigs), # Show bigwig selection if available
861
+ gr.update(visible=has_bigwigs), # Show bigwig query if available
862
+ gr.update(visible=has_bigwigs), # Show bigwig results if available
863
+ gr.update(visible=has_bigwigs), # Show bigwig buttons if available
864
+ )
865
+
866
+ # Update coordinates visibility when checkbox changes
867
+ def update_coords_visibility(use_coords_val: bool, species: str):
868
+ """Update coordinates visibility when checkbox changes."""
869
+ is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
870
+ show_coords = is_supported and use_coords_val
871
+ return (
872
+ gr.update(visible=show_coords),
873
+ gr.update(visible=show_coords),
874
+ gr.update(visible=show_coords),
875
+ gr.update(visible=show_coords), # Show/hide the row
876
+ )
877
+
878
  species.change(
879
+ fn=update_on_species_change,
880
+ inputs=[species, use_coords],
881
+ outputs=[
882
+ chrom, start, end, use_coords, coords_row,
883
+ bigwig_no_tracks_msg, bigwig_selected, bigwig_query, bigwig_results, bigwig_buttons_row
884
+ ],
885
+ )
886
+
887
+ use_coords.change(
888
+ fn=update_coords_visibility,
889
+ inputs=[use_coords, species],
890
+ outputs=[chrom, start, end, coords_row],
891
  )
892
 
893
  btn.click(
ntv3_tracks_pipeline.py CHANGED
@@ -38,8 +38,6 @@ ASSEMBLY_TO_SPECIES = {
38
  "Glycine_max_v2.1": "glycine_max",
39
  "IWGSC": "triticum_aestivum",
40
  "Gossypium_hirsutum_v2.1": "gossypium_hirsutum",
41
- "ASM228892v3": "delphinapterus_leucas",
42
- "ASM334442v1": "ursus_americanus",
43
  "AmpOce1": "amphiprion_ocellaris",
44
  "Bison_UMD1": "bison_bison_bison",
45
  "ChiLan1": "chinchilla_lanigera",
@@ -47,7 +45,6 @@ ASSEMBLY_TO_SPECIES = {
47
  "GRCz11": "danio_rerio",
48
  "KH": "ciona_intestinalis",
49
  "Mnem_1": "macaca_nemestrina",
50
- "R64": "saccharomyces_cerevisiae",
51
  "ROS_Cfam_1": "canis_lupus_familiaris",
52
  "SCA1": "serinus_canaria",
53
  "TETRAODON8": "tetraodon_nigroviridis",
@@ -56,11 +53,23 @@ ASSEMBLY_TO_SPECIES = {
56
  "fSalTru1": "salmo_trutta",
57
  "gorGor4": "gorilla_gorilla",
58
  "mRatBN7": "rattus_norvegicus",
59
- "SL3": "solanum_lycopersicum",
60
- "ARS-UCD2.0": "bos_taurus",
61
  }
62
  SPECIES_TO_ASSEMBLY = {v: k for k, v in ASSEMBLY_TO_SPECIES.items()}
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # ---------------------------------------------------------------------
65
  # Assembly -> API URL template mapping
66
  # ---------------------------------------------------------------------
 
38
  "Glycine_max_v2.1": "glycine_max",
39
  "IWGSC": "triticum_aestivum",
40
  "Gossypium_hirsutum_v2.1": "gossypium_hirsutum",
 
 
41
  "AmpOce1": "amphiprion_ocellaris",
42
  "Bison_UMD1": "bison_bison_bison",
43
  "ChiLan1": "chinchilla_lanigera",
 
45
  "GRCz11": "danio_rerio",
46
  "KH": "ciona_intestinalis",
47
  "Mnem_1": "macaca_nemestrina",
 
48
  "ROS_Cfam_1": "canis_lupus_familiaris",
49
  "SCA1": "serinus_canaria",
50
  "TETRAODON8": "tetraodon_nigroviridis",
 
53
  "fSalTru1": "salmo_trutta",
54
  "gorGor4": "gorilla_gorilla",
55
  "mRatBN7": "rattus_norvegicus",
 
 
56
  }
57
  SPECIES_TO_ASSEMBLY = {v: k for k, v in ASSEMBLY_TO_SPECIES.items()}
58
 
59
+ # ---------------------------------------------------------------------
60
+ # Species that support coordinate-based sequence fetching
61
+ # ---------------------------------------------------------------------
62
+ # List of species that can fetch DNA sequences from genomic coordinates via API.
63
+ # Species not in this list can still be used but require direct DNA sequence input.
64
+ SPECIES_WITH_COORDINATE_SUPPORT = {
65
+ "human", # hg38 - UCSC API
66
+ "mouse", # mm10 - UCSC API
67
+ "drosophila_melanogaster", # dm6 - UCSC API
68
+ "arabidopsis_thaliana", # TAIR10 - UCSC hub API
69
+ "gorilla_gorilla", # gorGor4 - UCSC API
70
+ # Add more species as API URLs are configured
71
+ }
72
+
73
  # ---------------------------------------------------------------------
74
  # Assembly -> API URL template mapping
75
  # ---------------------------------------------------------------------