bernardo-de-almeida commited on
Commit
beb6a82
·
1 Parent(s): 82970d1

feat: make plots interactive

Browse files
Files changed (2) hide show
  1. app.py +76 -21
  2. requirements.txt +2 -0
app.py CHANGED
@@ -7,8 +7,11 @@ from pathlib import Path
7
 
8
  import gradio as gr
9
  import matplotlib
 
10
  import matplotlib.pyplot as plt
11
  import numpy as np
 
 
12
  import torch
13
 
14
  from bigwig_export import _softmax_last, create_bigwig_zip
@@ -109,42 +112,93 @@ def _global_stride(L: int, target: int) -> int:
109
  return int(np.ceil(L / target))
110
 
111
 
112
- def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]]):
 
113
  if not series:
114
  raise gr.Error("Nothing to plot (no tracks/elements selected).")
115
 
116
  n = len(series)
117
- fig, axes = plt.subplots(n, 1, figsize=(18, 1.35 * n), sharex=True)
118
- if n == 1:
119
- axes = [axes]
 
 
 
 
 
 
120
 
121
  # Define color schemes
122
  bigwig_color = "#4A90E2" # Blue
123
 
124
- for ax, (title, y) in zip(axes, series):
125
  # Determine color based on track type
126
  if title in BED_ELEMENT_COLORS:
127
  color = BED_ELEMENT_COLORS[title]
128
  else:
129
  color = bigwig_color
130
 
131
- ax.fill_between(x, y, color=color, alpha=0.3, linewidth=0)
132
- ax.plot(x, y, color=color, linewidth=0.8)
133
- ax.set_title(title, fontsize=10, loc="left")
134
- ax.grid(alpha=0.2)
135
- ax.set_yticks([])
136
- ax.spines["top"].set_visible(False)
137
- ax.spines["right"].set_visible(False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- axes[-1].set_xlabel("Genomic position / index")
140
- fig.tight_layout()
141
  return fig
142
 
143
 
144
  def _save_fig_png(fig) -> str:
 
145
  tmpdir = tempfile.gettempdir()
146
  out_path = os.path.join(tmpdir, f"ntv3_tracks_{uuid.uuid4().hex}.png")
147
- fig.savefig(out_path, dpi=200, bbox_inches="tight")
 
148
  return out_path
149
 
150
 
@@ -499,7 +553,7 @@ def predict(
499
  tprint(f"model moved to {device}")
500
 
501
  pipe.model.eval()
502
- tprint(f"Running on {next(pipe.model.parameters()).device}")
503
  tprint("model ready to run inference")
504
 
505
  # run inference
@@ -591,15 +645,16 @@ def predict(
591
  series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
592
 
593
  tprint("figure data processed created")
594
- fig = _make_tracks_figure(x, series)
595
- tprint("figure created")
596
-
597
  region = (
598
  f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
599
  )
600
  if out.assembly:
601
  region += f" ({out.assembly})"
602
- fig.axes[-1].set_xlabel(region)
 
 
603
 
604
  png_path = _save_fig_png(fig)
605
  tprint("figure png saved")
@@ -1056,7 +1111,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
1056
  + ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
1057
  + ")"
1058
  )
1059
- with gr.Row():
1060
  chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
1061
  start = gr.Number(
1062
  label="Start", value=_default_coords["start"], precision=0
 
7
 
8
  import gradio as gr
9
  import matplotlib
10
+ import matplotlib.colors as mcolors
11
  import matplotlib.pyplot as plt
12
  import numpy as np
13
+ import plotly.graph_objects as go
14
+ from plotly.subplots import make_subplots
15
  import torch
16
 
17
  from bigwig_export import _softmax_last, create_bigwig_zip
 
112
  return int(np.ceil(L / target))
113
 
114
 
115
+ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], region: str = ""):
116
+ """Create an interactive plotly figure with multiple tracks."""
117
  if not series:
118
  raise gr.Error("Nothing to plot (no tracks/elements selected).")
119
 
120
  n = len(series)
121
+
122
+ # Create subplots with shared x-axis
123
+ fig = make_subplots(
124
+ rows=n,
125
+ cols=1,
126
+ shared_xaxes=True,
127
+ vertical_spacing=0.02,
128
+ subplot_titles=[title for title, _ in series],
129
+ )
130
 
131
  # Define color schemes
132
  bigwig_color = "#4A90E2" # Blue
133
 
134
+ for i, (title, y) in enumerate(series, 1):
135
  # Determine color based on track type
136
  if title in BED_ELEMENT_COLORS:
137
  color = BED_ELEMENT_COLORS[title]
138
  else:
139
  color = bigwig_color
140
 
141
+ # Convert color to rgba for fill
142
+ rgba = mcolors.to_rgba(color)
143
+ rgba_str = f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, 0.3)"
144
+
145
+ # Add filled area (fill_between equivalent)
146
+ fig.add_trace(
147
+ go.Scatter(
148
+ x=x,
149
+ y=y,
150
+ mode="lines",
151
+ name=title,
152
+ line=dict(color=color, width=1.5),
153
+ fill="tozeroy",
154
+ fillcolor=rgba_str,
155
+ hovertemplate=f"<b>{title}</b><br>" +
156
+ "Position: %{x}<br>" +
157
+ "Value: %{y:.4f}<extra></extra>",
158
+ showlegend=False,
159
+ ),
160
+ row=i,
161
+ col=1,
162
+ )
163
+
164
+ # Update layout for better appearance
165
+ fig.update_layout(
166
+ height=150 * n, # Adjust height based on number of tracks
167
+ width=1200,
168
+ margin=dict(l=80, r=20, t=40, b=60),
169
+ hovermode="x unified", # Show all values at same x position
170
+ template="plotly_white",
171
+ )
172
+
173
+ # Update y-axes to remove ticks and improve appearance
174
+ for i in range(1, n + 1):
175
+ fig.update_yaxes(
176
+ showticklabels=False,
177
+ showgrid=True,
178
+ gridcolor="rgba(0,0,0,0.1)",
179
+ row=i,
180
+ col=1,
181
+ )
182
+
183
+ # Update x-axis on the last subplot with region label
184
+ xaxis_title = region if region else "Genomic position / index"
185
+ fig.update_xaxes(
186
+ title_text=xaxis_title,
187
+ showgrid=True,
188
+ gridcolor="rgba(0,0,0,0.1)",
189
+ row=n,
190
+ col=1,
191
+ )
192
 
 
 
193
  return fig
194
 
195
 
196
  def _save_fig_png(fig) -> str:
197
+ """Save plotly figure as PNG."""
198
  tmpdir = tempfile.gettempdir()
199
  out_path = os.path.join(tmpdir, f"ntv3_tracks_{uuid.uuid4().hex}.png")
200
+ # Plotly figures can be saved directly as PNG
201
+ fig.write_image(out_path, width=1200, height=fig.layout.height, scale=2)
202
  return out_path
203
 
204
 
 
553
  tprint(f"model moved to {device}")
554
 
555
  pipe.model.eval()
556
+ print(f"Running on {next(pipe.model.parameters()).device}")
557
  tprint("model ready to run inference")
558
 
559
  # run inference
 
645
  series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
646
 
647
  tprint("figure data processed created")
648
+
649
+ # Build region string for x-axis label
 
650
  region = (
651
  f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
652
  )
653
  if out.assembly:
654
  region += f" ({out.assembly})"
655
+
656
+ fig = _make_tracks_figure(x, series, region=region)
657
+ tprint("figure created")
658
 
659
  png_path = _save_fig_png(fig)
660
  tprint("figure png saved")
 
1111
  + ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
1112
  + ")"
1113
  )
1114
+ with gr.Row():
1115
  chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
1116
  start = gr.Number(
1117
  label="Start", value=_default_coords["start"], precision=0
requirements.txt CHANGED
@@ -1,6 +1,8 @@
1
  gradio>=4.0.0
2
  matplotlib
3
  numpy
 
 
4
  pyBigWig
5
  pyfaidx
6
  requests
 
1
  gradio>=4.0.0
2
  matplotlib
3
  numpy
4
+ plotly
5
+ kaleido
6
  pyBigWig
7
  pyfaidx
8
  requests