bernardo-de-almeida commited on
Commit
d71c881
·
1 Parent(s): 11da9d3

feat: add bigwig export option

Browse files
Files changed (2) hide show
  1. app.py +34 -9
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,10 +3,17 @@ import uuid
3
  import tempfile
4
  import numpy as np
5
  import gradio as gr
6
- import matplotlib.pyplot as plt
7
  import asyncio
8
 
 
 
 
 
 
 
 
9
  from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline, BED_ELEMENT_COLORS
 
10
 
11
 
12
  # -----------------------------
@@ -54,12 +61,6 @@ load_pipeline(MODEL_ID, DEFAULT_SPECIES)
54
  # -----------------------------
55
  # Helpers
56
  # -----------------------------
57
- def _softmax_last(x: np.ndarray) -> np.ndarray:
58
- x = x - x.max(axis=-1, keepdims=True)
59
- ex = np.exp(x)
60
- return ex / ex.sum(axis=-1, keepdims=True)
61
-
62
-
63
  def _global_stride(L: int, target: int) -> int:
64
  if target <= 0 or L <= target:
65
  return 1
@@ -289,7 +290,7 @@ def predict(
289
  "plot_target_points": PLOT_TARGET_POINTS,
290
  }
291
 
292
- return fig, png_path, meta
293
 
294
 
295
  # -----------------------------
@@ -691,6 +692,14 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
691
 
692
  plot = gr.Plot(label="", elem_id="tracks_plot")
693
  export_png = gr.File(elem_id="export_png_hidden", interactive=False)
 
 
 
 
 
 
 
 
694
 
695
  with gr.Accordion("Meta (click to expand)", open=False):
696
  meta = gr.JSON(label="Meta")
@@ -750,9 +759,25 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
750
  btn.click(
751
  fn=predict,
752
  inputs=[seq, species, chrom, start, end, use_coords, bigwig_selected, bed_elements],
753
- outputs=[plot, export_png, meta],
754
  api_name="predict",
755
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
 
757
  if __name__ == "__main__":
758
  demo.launch(
 
3
  import tempfile
4
  import numpy as np
5
  import gradio as gr
 
6
  import asyncio
7
 
8
+ # Set matplotlib to use non-interactive backend before importing pyplot
9
+ # This is required for Gradio which runs on worker threads
10
+ import matplotlib
11
+ matplotlib.use('Agg')
12
+
13
+ import matplotlib.pyplot as plt
14
+
15
  from ntv3_tracks_pipeline import load_ntv3_tracks_pipeline, BED_ELEMENT_COLORS
16
+ from bigwig_export import create_bigwig_zip, _softmax_last
17
 
18
 
19
  # -----------------------------
 
61
  # -----------------------------
62
  # Helpers
63
  # -----------------------------
 
 
 
 
 
 
64
  def _global_stride(L: int, target: int) -> int:
65
  if target <= 0 or L <= target:
66
  return 1
 
290
  "plot_target_points": PLOT_TARGET_POINTS,
291
  }
292
 
293
+ return fig, png_path, meta, out, bigwig_selected, bed_elements
294
 
295
 
296
  # -----------------------------
 
692
 
693
  plot = gr.Plot(label="", elem_id="tracks_plot")
694
  export_png = gr.File(elem_id="export_png_hidden", interactive=False)
695
+
696
+ # State to store prediction output and selections for BigWig export
697
+ prediction_state = gr.State(value=None)
698
+ bigwig_selected_state = gr.State(value=[])
699
+ bed_elements_state = gr.State(value=[])
700
+
701
+ download_bigwig_btn = gr.Button("📥 Download tracks as BigWig files (ZIP)", variant="secondary")
702
+ export_bigwig = gr.File(label="Download BigWig files", visible=False)
703
 
704
  with gr.Accordion("Meta (click to expand)", open=False):
705
  meta = gr.JSON(label="Meta")
 
759
  btn.click(
760
  fn=predict,
761
  inputs=[seq, species, chrom, start, end, use_coords, bigwig_selected, bed_elements],
762
+ outputs=[plot, export_png, meta, prediction_state, bigwig_selected_state, bed_elements_state],
763
  api_name="predict",
764
  )
765
+
766
+ def download_bigwig_zip(out, bw_selected, bed_selected):
767
+ """Create and return BigWig zip file."""
768
+ try:
769
+ zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
770
+ return gr.update(value=zip_path, visible=True)
771
+ except ImportError as e:
772
+ raise gr.Error("pyBigWig is required for BigWig export. Install with: pip install pyBigWig")
773
+ except Exception as e:
774
+ raise gr.Error(f"Error creating BigWig files: {str(e)}")
775
+
776
+ download_bigwig_btn.click(
777
+ fn=download_bigwig_zip,
778
+ inputs=[prediction_state, bigwig_selected_state, bed_elements_state],
779
+ outputs=[export_bigwig],
780
+ )
781
 
782
  if __name__ == "__main__":
783
  demo.launch(
requirements.txt CHANGED
@@ -5,3 +5,4 @@ gradio>=4.0.0
5
  pyfaidx
6
  requests
7
  matplotlib
 
 
5
  pyfaidx
6
  requests
7
  matplotlib
8
+ pyBigWig