bernardo-de-almeida committed on
Commit
0af9d02
·
1 Parent(s): 4c9b328

fix: restrict bed elements to trained ones

Browse files
Files changed (2) hide show
  1. app.py +42 -4
  2. ntv3_tracks_pipeline.py +119 -4
app.py CHANGED
@@ -308,6 +308,21 @@ def _get_bigwig_names(species: str) -> list[str]:
308
  return _BIGWIG_CACHE[species]
309
 
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  def _has_bigwigs(species: str) -> bool:
312
  """Check if a species has BigWig tracks available in the current model."""
313
  try:
@@ -652,7 +667,7 @@ def predict(
652
  if bed_logits is not None and bed_elements:
653
  probs = _softmax_last(bed_logits)
654
  for ename in bed_elements:
655
- display_name = ename.replace("_", " ")
656
  eidx = bed_names.index(ename)
657
  series.append((display_name, probs[:, eidx, 1][::stride].astype(float)))
658
 
@@ -751,7 +766,8 @@ observer.observe(document.body, {
751
  """
752
 
753
  # BED list is small enough to keep as dropdown
754
- _init_bed = pipe.available_bed_element_names()
 
755
 
756
  # Default BigWig tracks
757
  DEFAULT_BIGWIG_TRACKS = [
@@ -781,6 +797,12 @@ _init_bigwig_selected = [
781
  # Filter default BED elements to only those available
782
  _init_bed_selected = [elem for elem in DEFAULT_BED_ELEMENTS if elem in _init_bed]
783
 
 
 
 
 
 
 
784
  # Default coordinates per species
785
  DEFAULT_COORDS = {
786
  "human": {"chrom": "chr19", "start": 6_700_000, "end": 6_831_072},
@@ -1063,8 +1085,8 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1063
  gr.Markdown("# Select genome annotation elements")
1064
 
1065
  bed_elements = gr.Dropdown(
1066
- choices=_init_bed,
1067
- value=_init_bed_selected if _init_bed_selected else [],
1068
  multiselect=True,
1069
  label="Genome annotation elements (search + select)",
1070
  elem_id="bed_elements_dropdown",
@@ -1268,6 +1290,17 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1268
  except Exception:
1269
  pass
1270
 
 
 
 
 
 
 
 
 
 
 
 
1271
  return (
1272
  gr.update(visible=show_coords, value=coords["chrom"]),
1273
  gr.update(visible=show_coords, value=coords["start"]),
@@ -1291,6 +1324,10 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1291
  gr.update(visible=has_bigwigs), # Show bigwig query if available
1292
  gr.update(visible=has_bigwigs), # Show bigwig results if available
1293
  gr.update(visible=has_bigwigs), # Show bigwig buttons if available
 
 
 
 
1294
  )
1295
 
1296
  # Update input type radio visibility and value when species changes
@@ -1339,6 +1376,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1339
  bigwig_query,
1340
  bigwig_results,
1341
  bigwig_buttons_row,
 
1342
  ],
1343
  )
1344
 
 
308
  return _BIGWIG_CACHE[species]
309
 
310
 
311
def _get_bed_element_names(species: str) -> list[str]:
    """Return the BED element names the loaded pipeline reports for ``species``.

    Falls back to an empty list when no pipeline is loaded (``pipe is None``)
    or when the pipeline lookup raises ``ValueError`` or ``AttributeError``.
    """
    names: list[str] = []
    if pipe is not None:
        try:
            names = pipe.available_bed_element_names(species)
        except (ValueError, AttributeError):
            # Best-effort lookup: treat a failed query as "no elements".
            names = []
    return names
319
+
320
+
321
+ def _format_bed_element_for_display(element_name: str) -> str:
322
+ """Format BED element name for display: replace underscores with spaces and capitalize."""
323
+ return element_name.replace("_", " ").title()
324
+
325
+
326
  def _has_bigwigs(species: str) -> bool:
327
  """Check if a species has BigWig tracks available in the current model."""
328
  try:
 
667
  if bed_logits is not None and bed_elements:
668
  probs = _softmax_last(bed_logits)
669
  for ename in bed_elements:
670
+ display_name = ename.replace("_", " ").lower()
671
  eidx = bed_names.index(ename)
672
  series.append((display_name, probs[:, eidx, 1][::stride].astype(float)))
673
 
 
766
  """
767
 
768
  # BED list is small enough to keep as dropdown
769
+ # Filter by default species to show only elements available for training
770
+ _init_bed = pipe.available_bed_element_names(DEFAULT_SPECIES)
771
 
772
  # Default BigWig tracks
773
  DEFAULT_BIGWIG_TRACKS = [
 
797
  # Filter default BED elements to only those available
798
  _init_bed_selected = [elem for elem in DEFAULT_BED_ELEMENTS if elem in _init_bed]
799
 
800
+ # Format BED elements for display: use tuples (display_name, value) for dropdown
801
+ _init_bed_choices = [
802
+ (_format_bed_element_for_display(elem), elem) for elem in _init_bed
803
+ ]
804
+ _init_bed_selected_values = _init_bed_selected # Keep original values for selection
805
+
806
  # Default coordinates per species
807
  DEFAULT_COORDS = {
808
  "human": {"chrom": "chr19", "start": 6_700_000, "end": 6_831_072},
 
1085
  gr.Markdown("# Select genome annotation elements")
1086
 
1087
  bed_elements = gr.Dropdown(
1088
+ choices=_init_bed_choices,
1089
+ value=_init_bed_selected_values if _init_bed_selected_values else [],
1090
  multiselect=True,
1091
  label="Genome annotation elements (search + select)",
1092
  elem_id="bed_elements_dropdown",
 
1290
  except Exception:
1291
  pass
1292
 
1293
+ # Get BED elements available for this species
1294
+ bed_element_names = _get_bed_element_names(species)
1295
+ # Filter default BED elements to only those available for this species
1296
+ default_bed_selected = [
1297
+ elem for elem in DEFAULT_BED_ELEMENTS if elem in bed_element_names
1298
+ ]
1299
+ # Format BED elements for display: use tuples (display_name, value)
1300
+ bed_element_choices = [
1301
+ (_format_bed_element_for_display(elem), elem) for elem in bed_element_names
1302
+ ]
1303
+
1304
  return (
1305
  gr.update(visible=show_coords, value=coords["chrom"]),
1306
  gr.update(visible=show_coords, value=coords["start"]),
 
1324
  gr.update(visible=has_bigwigs), # Show bigwig query if available
1325
  gr.update(visible=has_bigwigs), # Show bigwig results if available
1326
  gr.update(visible=has_bigwigs), # Show bigwig buttons if available
1327
+ gr.update(
1328
+ choices=bed_element_choices,
1329
+ value=default_bed_selected,
1330
+ ), # Update BED elements dropdown with species-specific elements
1331
  )
1332
 
1333
  # Update input type radio visibility and value when species changes
 
1376
  bigwig_query,
1377
  bigwig_results,
1378
  bigwig_buttons_row,
1379
+ bed_elements,
1380
  ],
1381
  )
1382
 
ntv3_tracks_pipeline.py CHANGED
@@ -110,6 +110,79 @@ BED_ELEMENT_COLORS = {
110
  }
111
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def _sanitize_dna(seq: str) -> str:
114
  seq = seq.upper()
115
  return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
@@ -373,11 +446,26 @@ class NTv3TracksPipeline(Pipeline):
373
 
374
  return list(self.config.bigwigs_per_species[species])
375
 
376
- def available_bed_element_names(self) -> list[str]:
377
  """
378
- Return BED element names available in this checkpoint (no forward pass).
 
 
 
 
 
 
 
 
 
 
 
 
379
  """
380
- return list(self.bed_element_names or [])
 
 
 
381
 
382
  def _sanitize_parameters(self, **kwargs):
383
  return {}, {}, {}
@@ -523,6 +611,33 @@ class NTv3TracksPipeline(Pipeline):
523
  if mlm_np.ndim == 3:
524
  mlm_np = mlm_np[0]
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  return NTv3TracksOutput(
527
  bigwig_tracks_logits=bigwig_np,
528
  bed_tracks_logits=bed_np,
@@ -533,7 +648,7 @@ class NTv3TracksPipeline(Pipeline):
533
  species=meta.get("species"),
534
  assembly=meta.get("assembly"),
535
  bigwig_track_names=meta.get("bigwig_track_names"),
536
- bed_element_names=self.bed_element_names,
537
  window_len=meta.get("window_len"),
538
  pred_start=meta.get("pred_start"),
539
  pred_end=meta.get("pred_end"),
 
110
  }
111
 
112
 
113
+ def _filter_bed_elements_by_species(
114
+ bed_element_names: list[str], species: str
115
+ ) -> list[str]:
116
+ """
117
+ Filter BED element names based on species-specific training data availability.
118
+
119
+ Rules:
120
+ - Human: all tracks
121
+ - Mouse: only polyA_signal
122
+ - Other species: everything except promoter, enhancer, ctcf, lncrna
123
+
124
+ Parameters
125
+ ----------
126
+ bed_element_names : list[str]
127
+ Full list of BED element names from the model config
128
+ species : str
129
+ Species name (e.g., "human", "mouse", "drosophila_melanogaster")
130
+
131
+ Returns
132
+ -------
133
+ list[str]
134
+ Filtered list of BED element names available for this species
135
+ """
136
+ if not bed_element_names:
137
+ return []
138
+
139
+ # Elements to exclude for "other species" (everything except human and mouse)
140
+ excluded_for_other_species = {
141
+ "promoter Tissue specific",
142
+ "promoter Tissue invariant",
143
+ "enhancer Tissue specific",
144
+ "enhancer Tissue invariant",
145
+ "CTCF-bound",
146
+ "lncRNA",
147
+ }
148
+
149
+ # Normalize element names (handle both with/without underscores/spaces)
150
+ normalized_excluded = set()
151
+ for elem in excluded_for_other_species:
152
+ normalized_excluded.add(elem)
153
+ normalized_excluded.add(elem.replace(" ", "_"))
154
+
155
+ if species == "human":
156
+ # Human: all tracks
157
+ return list(bed_element_names)
158
+ else:
159
+ # Other species: everything except promoter, enhancer, ctcf, lncrna
160
+ # Normalize element names for comparison (handle spaces, underscores, case)
161
+ normalized_bed_names = {
162
+ elem.lower().replace("_", " "): elem
163
+ for elem in bed_element_names
164
+ }
165
+ normalized_excluded_lower = {
166
+ elem.lower().replace("_", " ")
167
+ for elem in excluded_for_other_species
168
+ }
169
+
170
+ # Also check for keywords in element names
171
+ excluded_keywords = ["promoter", "enhancer", "ctcf", "lnc"]
172
+
173
+ filtered_normalized = [
174
+ norm_name
175
+ for norm_name, orig_elem in normalized_bed_names.items()
176
+ if norm_name not in normalized_excluded_lower
177
+ and not any(keyword in norm_name for keyword in excluded_keywords)
178
+ ]
179
+
180
+ # Return original element names (preserving original format)
181
+ return [
182
+ normalized_bed_names[norm_name] for norm_name in filtered_normalized
183
+ ]
184
+
185
+
186
  def _sanitize_dna(seq: str) -> str:
187
  seq = seq.upper()
188
  return "".join(ch if ch in ("A", "C", "G", "T", "N") else "N" for ch in seq)
 
446
 
447
  return list(self.config.bigwigs_per_species[species])
448
 
449
def available_bed_element_names(self, species: str | None = None) -> list[str]:
    """
    List the BED element names of this checkpoint, optionally per species.

    Parameters
    ----------
    species : str | None
        Species name (e.g., "human", "mouse"). When None, the complete
        list is returned unfiltered (backward-compatible behaviour).

    Returns
    -------
    list[str]
        A fresh list copied from ``self.bed_element_names`` (empty when that
        attribute is falsy); when ``species`` is given, the list is passed
        through ``_filter_bed_elements_by_species`` first.
    """
    checkpoint_elements = list(self.bed_element_names or [])
    return (
        checkpoint_elements
        if species is None
        else _filter_bed_elements_by_species(checkpoint_elements, species)
    )
469
 
470
  def _sanitize_parameters(self, **kwargs):
471
  return {}, {}, {}
 
611
  if mlm_np.ndim == 3:
612
  mlm_np = mlm_np[0]
613
 
614
+ # Filter BED elements based on species
615
+ species = meta.get("species")
616
+ all_bed_element_names = self.bed_element_names or []
617
+ if species and all_bed_element_names:
618
+ filtered_bed_element_names = _filter_bed_elements_by_species(
619
+ all_bed_element_names, species
620
+ )
621
+ # Filter bed_tracks_logits to only include elements available for this species
622
+ if filtered_bed_element_names != all_bed_element_names:
623
+ # Create mapping from filtered element names to original indices
624
+ element_indices = [
625
+ all_bed_element_names.index(elem)
626
+ for elem in filtered_bed_element_names
627
+ if elem in all_bed_element_names
628
+ ]
629
+ if element_indices:
630
+ # bed_np shape is (L, E, C) where E is number of elements
631
+ bed_np = bed_np[:, element_indices, :]
632
+ # Update filtered list to only include elements that were found
633
+ filtered_bed_element_names = [
634
+ elem
635
+ for elem in filtered_bed_element_names
636
+ if elem in all_bed_element_names
637
+ ]
638
+ else:
639
+ filtered_bed_element_names = all_bed_element_names
640
+
641
  return NTv3TracksOutput(
642
  bigwig_tracks_logits=bigwig_np,
643
  bed_tracks_logits=bed_np,
 
648
  species=meta.get("species"),
649
  assembly=meta.get("assembly"),
650
  bigwig_track_names=meta.get("bigwig_track_names"),
651
+ bed_element_names=filtered_bed_element_names,
652
  window_len=meta.get("window_len"),
653
  pred_start=meta.get("pred_start"),
654
  pred_end=meta.get("pred_end"),