the-puzzler committed on
Commit
08a4eed
·
1 Parent(s): 03fe355

Fix plots and simplify live community search

Browse files
Files changed (1) hide show
  1. app.py +50 -28
app.py CHANGED
@@ -8,6 +8,7 @@ from typing import Dict, List, Tuple
8
  import gradio as gr
9
  import numpy as np
10
  import plotly.express as px
 
11
  import torch
12
  from gradio_client import utils as gradio_client_utils
13
  from transformers import AutoModel, AutoTokenizer
@@ -93,6 +94,14 @@ CSS = """
93
  color: var(--muted);
94
  font-size: 0.95rem;
95
  }
 
 
 
 
 
 
 
 
96
  """
97
 
98
 
@@ -520,18 +529,35 @@ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
520
  )
521
  coords = reducer.fit_transform(vectors)
522
  norms = np.linalg.norm(vectors, axis=1)
523
-
524
- fig = px.scatter(
525
- x=coords[:, 0],
526
- y=coords[:, 1],
527
- color=norms,
528
- hover_name=labels,
529
- labels={"x": "UMAP 1", "y": "UMAP 2", "color": "vector norm"},
530
- title=title,
531
- color_continuous_scale="Viridis",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  )
533
- fig.update_traces(marker={"size": 10, "line": {"width": 0.6, "color": "#1d2a1f"}, "opacity": 0.9})
534
  fig.update_layout(
 
 
 
535
  paper_bgcolor="rgba(255,255,255,0)",
536
  plot_bgcolor="rgba(255,255,255,0.75)",
537
  margin={"l": 10, "r": 10, "t": 60, "b": 10},
@@ -540,13 +566,18 @@ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
540
 
541
 
542
  def _plot_logits(logits: np.ndarray):
543
- fig = px.histogram(
544
- x=logits,
545
- nbins=min(50, max(12, len(logits) // 4)),
546
- title="Logit Distribution Over Input DNA Embeddings",
547
- color_discrete_sequence=["#d8832f"],
 
 
 
 
548
  )
549
  fig.update_layout(
 
550
  xaxis_title="Logit",
551
  yaxis_title="Count",
552
  paper_bgcolor="rgba(255,255,255,0)",
@@ -759,25 +790,21 @@ with gr.Blocks(title="Microbiome Explorer", css=CSS, theme=gr.themes.Soft()) as
759
  with gr.Tab("Build A Community"):
760
  with gr.Column(elem_classes=["soft-card"]):
761
  gr.Markdown(
762
- "Search `otus.97.allinfo` by OTU ID, taxon label, or taxonomy string. Add matching OTUs to a custom community, then score the assembled set."
763
  )
764
  with gr.Row():
765
  taxa_query = gr.Textbox(
766
  label="Search taxa",
767
  placeholder="Try Nitrospira, Lysobacter, Gammaproteobacteria, 97_8697 ...",
768
- scale=5,
769
  )
770
- taxa_search_btn = gr.Button("Refresh", variant="secondary", scale=1)
771
 
772
  community_search_status = gr.Markdown(elem_classes=["section-note"])
773
- taxa_matches = gr.Dropdown(
774
  label="Matching OTUs",
775
  choices=[],
776
  value=[],
777
- multiselect=True,
778
- allow_custom_value=False,
779
- interactive=True,
780
- max_choices=MAX_GENES,
781
  )
782
 
783
  with gr.Row():
@@ -822,11 +849,6 @@ with gr.Blocks(title="Microbiome Explorer", css=CSS, theme=gr.themes.Soft()) as
822
  inputs=[microbeatlas_in],
823
  outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
824
  )
825
- taxa_search_btn.click(
826
- fn=search_taxa,
827
- inputs=[taxa_query],
828
- outputs=[taxa_matches, community_search_status],
829
- )
830
  taxa_query.change(
831
  fn=search_taxa,
832
  inputs=[taxa_query],
 
8
  import gradio as gr
9
  import numpy as np
10
  import plotly.express as px
11
+ import plotly.graph_objects as go
12
  import torch
13
  from gradio_client import utils as gradio_client_utils
14
  from transformers import AutoModel, AutoTokenizer
 
94
  color: var(--muted);
95
  font-size: 0.95rem;
96
  }
97
+ .search-results {
98
+ max-height: 320px;
99
+ overflow-y: auto;
100
+ border: 1px solid var(--line);
101
+ border-radius: 16px;
102
+ background: rgba(255, 255, 255, 0.72);
103
+ padding: 10px 12px 2px 12px;
104
+ }
105
  """
106
 
107
 
 
529
  )
530
  coords = reducer.fit_transform(vectors)
531
  norms = np.linalg.norm(vectors, axis=1)
532
+ x_values = [float(value) for value in coords[:, 0]]
533
+ y_values = [float(value) for value in coords[:, 1]]
534
+ color_values = [float(value) for value in norms]
535
+
536
+ fig = go.Figure(
537
+ data=[
538
+ go.Scatter(
539
+ x=x_values,
540
+ y=y_values,
541
+ mode="markers",
542
+ text=labels,
543
+ customdata=np.array(color_values).reshape(-1, 1),
544
+ hovertemplate="<b>%{text}</b><br>UMAP 1=%{x:.3f}<br>UMAP 2=%{y:.3f}<br>norm=%{customdata[0]:.3f}<extra></extra>",
545
+ marker={
546
+ "size": 10,
547
+ "color": color_values,
548
+ "colorscale": "Viridis",
549
+ "line": {"width": 0.6, "color": "#1d2a1f"},
550
+ "opacity": 0.92,
551
+ "showscale": True,
552
+ "colorbar": {"title": "vector norm"},
553
+ },
554
+ )
555
+ ]
556
  )
 
557
  fig.update_layout(
558
+ title=title,
559
+ xaxis_title="UMAP 1",
560
+ yaxis_title="UMAP 2",
561
  paper_bgcolor="rgba(255,255,255,0)",
562
  plot_bgcolor="rgba(255,255,255,0.75)",
563
  margin={"l": 10, "r": 10, "t": 60, "b": 10},
 
566
 
567
 
568
  def _plot_logits(logits: np.ndarray):
569
+ hist_values = [float(value) for value in logits]
570
+ fig = go.Figure(
571
+ data=[
572
+ go.Histogram(
573
+ x=hist_values,
574
+ nbinsx=min(50, max(12, len(hist_values) // 4)),
575
+ marker={"color": "#d8832f"},
576
+ )
577
+ ]
578
  )
579
  fig.update_layout(
580
+ title="Logit Distribution Over Input DNA Embeddings",
581
  xaxis_title="Logit",
582
  yaxis_title="Count",
583
  paper_bgcolor="rgba(255,255,255,0)",
 
790
  with gr.Tab("Build A Community"):
791
  with gr.Column(elem_classes=["soft-card"]):
792
  gr.Markdown(
793
+ "Search the OTU index by OTU ID, taxon label, or taxonomy string. Matching OTUs appear directly below as you type, so you can add them without opening another widget."
794
  )
795
  with gr.Row():
796
  taxa_query = gr.Textbox(
797
  label="Search taxa",
798
  placeholder="Try Nitrospira, Lysobacter, Gammaproteobacteria, 97_8697 ...",
799
+ scale=6,
800
  )
 
801
 
802
  community_search_status = gr.Markdown(elem_classes=["section-note"])
803
+ taxa_matches = gr.CheckboxGroup(
804
  label="Matching OTUs",
805
  choices=[],
806
  value=[],
807
+ elem_classes=["search-results"],
 
 
 
808
  )
809
 
810
  with gr.Row():
 
849
  inputs=[microbeatlas_in],
850
  outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
851
  )
 
 
 
 
 
852
  taxa_query.change(
853
  fn=search_taxa,
854
  inputs=[taxa_query],