genomenet Claude Opus 4.5 committed on
Commit
cde1aef
·
1 Parent(s): 5d93f52

Add interactive 3D State-Dynamic plots and download buttons

Browse files

- State-Dynamic plots now use interactive Plotly instead of static matplotlib
- Added 3D checkbox for 3D UMAP visualization (drag to rotate)
- Added PNG/PDF download buttons for all plots
- Download buttons appear after analysis is complete
- 3D checkbox only visible when state-dynamics mode is selected

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +83 -104
app.py CHANGED
@@ -543,13 +543,16 @@ def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
543
 
544
  is_valid, error = validate_sequence(sequence)
545
  if not is_valid:
546
- return None, f"**Error**: {error}", None
547
 
548
  result = predict_sequence(sequence, stride=stride, aggregation="mean")
549
 
550
  # Create plot
551
  fig = create_prediction_plot(result.positions, result.probabilities, threshold)
552
 
 
 
 
553
  # Detect regions
554
  regions = detect_crispr_regions(sequence, threshold=threshold, min_length=100, stride=stride)
555
 
@@ -569,7 +572,7 @@ def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
569
  for r in regions:
570
  summary += f"- **Region {r['region_id']}**: positions {r['start']:,}-{r['end']:,} ({r['length']} bp), score: {r['mean_score']:.3f}\n"
571
 
572
- return fig, summary, regions
573
 
574
 
575
  def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
@@ -597,15 +600,35 @@ def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
597
  return regions, summary
598
 
599
 
600
- def get_embedding(sequence: str, mode: str = "mean"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  """Extract hidden state embedding and visualize as heatmap."""
602
  sequence = strip_fasta_header(sequence.strip())
603
 
604
  is_valid, error = validate_sequence(sequence)
605
  if not is_valid:
606
- return None, f"**Error**: {error}"
607
 
608
  result = embed_sequence(sequence, mode="trajectory" if mode == "state-dynamics" else mode)
 
609
 
610
  if mode == "trajectory":
611
  # Create trajectory heatmap (windows x dimensions)
@@ -613,6 +636,7 @@ def get_embedding(sequence: str, mode: str = "mean"):
613
  result.embeddings,
614
  title="Embedding Trajectory Across Sequence"
615
  )
 
616
  summary = f"""## Trajectory Embedding
617
 
618
  | Property | Value |
@@ -625,111 +649,40 @@ Each row shows the embedding for one sliding window position.
625
  Blue = negative activation, Red = positive activation.
626
  """
627
  elif mode == "state-dynamics":
628
- # Create State-Dynamic Plot (UMAP + clustering)
629
  embeddings = np.array(result.embeddings)
630
  n_windows = embeddings.shape[0]
631
  n_clusters = min(8, max(3, n_windows // 3))
632
 
633
- # Perform clustering for both plots
634
- if n_windows >= 5:
635
- clustering = AgglomerativeClustering(n_clusters=n_clusters)
636
- cluster_labels = clustering.fit_predict(embeddings)
637
- else:
638
- cluster_labels = np.zeros(n_windows, dtype=int)
639
-
640
- # Create combined figure with state-dynamic plot and sequence map
641
- fig = plt.figure(figsize=(16, 10))
642
-
643
- # Top: State-dynamic plots
644
- ax1 = fig.add_subplot(2, 2, 1)
645
- ax2 = fig.add_subplot(2, 2, 2)
646
-
647
- if n_windows >= 5:
648
- # UMAP reduction
649
- n_neighbors = min(15, n_windows - 1)
650
- reducer = umap.UMAP(n_components=2, n_neighbors=n_neighbors,
651
- min_dist=0.1, random_state=42)
652
- embedding_2d = reducer.fit_transform(embeddings)
653
-
654
- # Colors for clusters
655
- colors = plt.cm.tab10(np.linspace(0, 1, n_clusters))
656
- cluster_cmap = ListedColormap(colors)
657
-
658
- # Left plot: by cluster
659
- points = embedding_2d.reshape(-1, 1, 2)
660
- segments = np.concatenate([points[:-1], points[1:]], axis=1)
661
- segment_colors = [colors[cluster_labels[i]] for i in range(len(segments))]
662
- lc = LineCollection(segments, colors=segment_colors, alpha=0.3, linewidths=1)
663
- ax1.add_collection(lc)
664
- scatter1 = ax1.scatter(embedding_2d[:, 0], embedding_2d[:, 1],
665
- c=cluster_labels, cmap=cluster_cmap,
666
- s=60, alpha=0.8, edgecolors='white', linewidths=0.5)
667
- ax1.scatter(embedding_2d[0, 0], embedding_2d[0, 1], c='green', s=200,
668
- marker='^', edgecolors='black', linewidths=2, label='Start', zorder=10)
669
- ax1.scatter(embedding_2d[-1, 0], embedding_2d[-1, 1], c='red', s=200,
670
- marker='s', edgecolors='black', linewidths=2, label='End', zorder=10)
671
- ax1.set_xlabel('UMAP 1')
672
- ax1.set_ylabel('UMAP 2')
673
- ax1.set_title('By Cluster', fontweight='bold')
674
- ax1.legend(loc='upper right', fontsize=8)
675
- plt.colorbar(scatter1, ax=ax1, shrink=0.8, label='Cluster')
676
-
677
- # Right plot: by position
678
- positions = np.arange(n_windows)
679
- segment_colors_pos = plt.cm.viridis(plt.Normalize(0, n_windows-1)(positions[:-1]))
680
- lc2 = LineCollection(segments, colors=segment_colors_pos, alpha=0.4, linewidths=1.5)
681
- ax2.add_collection(lc2)
682
- scatter2 = ax2.scatter(embedding_2d[:, 0], embedding_2d[:, 1],
683
- c=positions, cmap='viridis',
684
- s=60, alpha=0.8, edgecolors='white', linewidths=0.5)
685
- ax2.scatter(embedding_2d[0, 0], embedding_2d[0, 1], c='green', s=200,
686
- marker='^', edgecolors='black', linewidths=2, label="Start (5')", zorder=10)
687
- ax2.scatter(embedding_2d[-1, 0], embedding_2d[-1, 1], c='red', s=200,
688
- marker='s', edgecolors='black', linewidths=2, label="End (3')", zorder=10)
689
- ax2.set_xlabel('UMAP 1')
690
- ax2.set_ylabel('UMAP 2')
691
- ax2.set_title('By Position', fontweight='bold')
692
- ax2.legend(loc='upper right', fontsize=8)
693
- plt.colorbar(scatter2, ax=ax2, shrink=0.8, label='Window')
694
- else:
695
- ax1.text(0.5, 0.5, "Need longer sequence", ha='center', va='center')
696
- ax2.text(0.5, 0.5, "Need longer sequence", ha='center', va='center')
697
-
698
- # Bottom: Sequence cluster map (spans both columns)
699
- ax3 = fig.add_subplot(2, 1, 2)
700
- colors = plt.cm.tab10(np.linspace(0, 1, max(n_clusters, 10)))
701
- stride = 100
702
- window_size = 1000
703
- for i, cluster in enumerate(cluster_labels):
704
- start_pos = i * stride
705
- end_pos = start_pos + window_size
706
- ax3.axvspan(start_pos, end_pos, alpha=0.7, color=colors[cluster], linewidth=0)
707
 
708
- handles = [plt.Rectangle((0,0), 1, 1, color=colors[i], alpha=0.7) for i in range(n_clusters)]
709
- ax3.legend(handles, [f'Cluster {i}' for i in range(n_clusters)],
710
- loc='upper right', ncol=min(n_clusters, 5), fontsize=8)
711
- ax3.set_xlim(0, (n_windows - 1) * stride + window_size)
712
- ax3.set_ylim(0, 1)
713
- ax3.set_xlabel('Position (bp)', fontsize=11)
714
- ax3.set_yticks([])
715
- ax3.set_title('Sequence colored by embedding cluster (repeating patterns = structural elements)',
716
- fontsize=11, fontweight='bold')
717
 
718
- plt.tight_layout()
719
-
720
- summary = f"""## State-Dynamic Plot
721
 
722
  | Property | Value |
723
  |----------|-------|
724
  | Sequence length | {result.sequence_length:,} bp |
725
  | Windows analyzed | {result.num_windows} |
726
  | Clusters identified | {n_clusters} |
727
-
728
- **Top-Left**: UMAP projection colored by cluster. Similar activation patterns group together.
729
- **Top-Right**: Same projection colored by sequence position. Shows trajectory through embedding space.
730
- **Bottom**: Linear sequence map colored by cluster. Repeating color patterns indicate structural elements (e.g., repeats).
731
-
732
- If you see alternating colors (like in CRISPR arrays), this indicates the model detects repeating structural elements!
 
 
 
 
 
 
 
733
  """
734
  else:
735
  # Create single embedding heatmap
@@ -737,6 +690,7 @@ If you see alternating colors (like in CRISPR arrays), this indicates the model
737
  result.embedding,
738
  title=f"Sequence Embedding ({result.method})"
739
  )
 
740
  summary = f"""## Embedding Extracted
741
 
742
  | Property | Value |
@@ -749,7 +703,7 @@ Each cell represents one dimension of the {result.embedding_dim}-dimensional emb
749
  Blue = negative activation, Red = positive activation.
750
  """
751
 
752
- return fig, summary
753
 
754
 
755
  # Build interface
@@ -789,21 +743,25 @@ Detect CRISPR arrays in DNA sequences using a BERT-based deep learning model (43
789
  gr.Button("Load Non-CRISPR Example").click(
790
  lambda: NON_CRISPR_EXAMPLE, outputs=seq_input
791
  )
 
 
 
 
 
792
  with gr.Column(scale=2):
793
  plot_output = gr.Plot(label="CRISPR Score Profile")
794
- result_summary = gr.Markdown()
795
  regions_output = gr.JSON(label="Detected Regions", visible=False)
796
 
797
  predict_btn.click(
798
  predict,
799
  inputs=[seq_input, stride_input, threshold_input],
800
- outputs=[plot_output, result_summary, regions_output]
801
  )
802
 
803
  with gr.Tab("Embeddings"):
804
  gr.Markdown("""## State-Dynamic Plots
805
 
806
- Visualize how the model's internal representation changes across the sequence. The **State-Dynamics** mode projects embeddings to 2D using UMAP and clusters similar regions together.
807
 
808
  **For CRISPR arrays**: Expect to see alternating colors in the bottom bar where repeats and spacers alternate. Repeats should cluster together (conserved pattern), while spacers cluster separately.
809
  """)
@@ -819,7 +777,13 @@ Visualize how the model's internal representation changes across the sequence. T
819
  choices=["state-dynamics", "mean", "max", "trajectory"],
820
  value="state-dynamics",
821
  label="Visualization Mode",
822
- info="state-dynamics: UMAP clustering | mean/max: pooled heatmap | trajectory: per-window heatmap"
 
 
 
 
 
 
823
  )
824
  with gr.Row():
825
  embed_btn = gr.Button("Analyze Embeddings", variant="primary")
@@ -837,10 +801,25 @@ Visualize how the model's internal representation changes across the sequence. T
837
  - Downstream: 2364-2964 bp (random)
838
  """)
839
  embed_summary = gr.Markdown()
 
 
 
 
840
  with gr.Column(scale=2):
841
- embed_plot = gr.Plot(label="Embedding Visualization")
 
 
 
 
 
 
 
842
 
843
- embed_btn.click(get_embedding, inputs=[embed_seq, embed_mode], outputs=[embed_plot, embed_summary])
 
 
 
 
844
 
845
  with gr.Tab("About"):
846
  gr.Markdown("""
 
543
 
544
  is_valid, error = validate_sequence(sequence)
545
  if not is_valid:
546
+ return None, f"**Error**: {error}", None, None, None
547
 
548
  result = predict_sequence(sequence, stride=stride, aggregation="mean")
549
 
550
  # Create plot
551
  fig = create_prediction_plot(result.positions, result.probabilities, threshold)
552
 
553
+ # Save for download
554
+ png_path, pdf_path = save_figure_to_file(fig, "crispr_prediction")
555
+
556
  # Detect regions
557
  regions = detect_crispr_regions(sequence, threshold=threshold, min_length=100, stride=stride)
558
 
 
572
  for r in regions:
573
  summary += f"- **Region {r['region_id']}**: positions {r['start']:,}-{r['end']:,} ({r['length']} bp), score: {r['mean_score']:.3f}\n"
574
 
575
+ return fig, summary, regions, png_path, pdf_path
576
 
577
 
578
  def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
 
600
  return regions, summary
601
 
602
 
603
def save_figure_to_file(fig, prefix="plot"):
    """Save a figure as PNG and PDF files that can be offered for download.

    Every call writes into its own freshly created temporary directory, so
    two requests that use the same ``prefix`` (e.g. two users running the
    same analysis at once) can no longer overwrite each other's files.
    The file names themselves stay ``<prefix>.png`` and ``<prefix>.pdf``.

    Args:
        fig: Object exposing a matplotlib-style ``savefig(path, **kwargs)``
            method.
        prefix: Base file name (without extension) used for both files.

    Returns:
        Tuple ``(png_path, pdf_path)`` with the paths of the written files.

    Note:
        The per-call directories are not removed here; they live until the
        system temp directory is cleaned.
    """
    import os
    import tempfile

    # A unique directory per call avoids the shared "<tmp>/<prefix>.png"
    # path that concurrent requests would otherwise race on.
    out_dir = tempfile.mkdtemp(prefix="plot_download_")

    png_path = os.path.join(out_dir, f"{prefix}.png")
    pdf_path = os.path.join(out_dir, f"{prefix}.pdf")

    # White facecolor keeps the export readable regardless of UI theme.
    common = {"bbox_inches": "tight", "facecolor": "white"}
    fig.savefig(png_path, dpi=150, **common)
    fig.savefig(pdf_path, **common)

    return png_path, pdf_path
620
+
621
+
622
+ def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
623
  """Extract hidden state embedding and visualize as heatmap."""
624
  sequence = strip_fasta_header(sequence.strip())
625
 
626
  is_valid, error = validate_sequence(sequence)
627
  if not is_valid:
628
+ return None, f"**Error**: {error}", None, None
629
 
630
  result = embed_sequence(sequence, mode="trajectory" if mode == "state-dynamics" else mode)
631
+ png_path, pdf_path = None, None
632
 
633
  if mode == "trajectory":
634
  # Create trajectory heatmap (windows x dimensions)
 
636
  result.embeddings,
637
  title="Embedding Trajectory Across Sequence"
638
  )
639
+ png_path, pdf_path = save_figure_to_file(fig, "trajectory_embedding")
640
  summary = f"""## Trajectory Embedding
641
 
642
  | Property | Value |
 
649
  Blue = negative activation, Red = positive activation.
650
  """
651
  elif mode == "state-dynamics":
652
+ # Create interactive State-Dynamic Plot using Plotly
653
  embeddings = np.array(result.embeddings)
654
  n_windows = embeddings.shape[0]
655
  n_clusters = min(8, max(3, n_windows // 3))
656
 
657
+ # Use the interactive Plotly version
658
+ fig = create_interactive_state_plot(embeddings, n_clusters=n_clusters, stride=100, use_3d=use_3d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
 
660
+ # For downloads, create a static matplotlib version
661
+ static_fig = create_state_dynamic_plot(embeddings, n_clusters=n_clusters, stride=100)
662
+ png_path, pdf_path = save_figure_to_file(static_fig, "state_dynamic_plot")
663
+ plt.close(static_fig)
 
 
 
 
 
664
 
665
+ dim_text = "3D" if use_3d else "2D"
666
+ summary = f"""## Interactive State-Dynamic Plot ({dim_text})
 
667
 
668
  | Property | Value |
669
  |----------|-------|
670
  | Sequence length | {result.sequence_length:,} bp |
671
  | Windows analyzed | {result.num_windows} |
672
  | Clusters identified | {n_clusters} |
673
+ | Visualization | {dim_text} UMAP |
674
+
675
+ **Interactive controls:**
676
+ - **Hover** over points to see window position and cluster
677
+ - **Zoom** by scrolling or selecting region
678
+ - **Pan** by dragging
679
+ - **{"Rotate" if use_3d else "Double-click"}** to {"rotate 3D view" if use_3d else "reset zoom"}
680
+ - **Download**: Use buttons below for PNG/PDF, or camera icon in plot toolbar
681
+
682
+ **Interpretation:**
683
+ - Points colored by cluster - similar activation patterns group together
684
+ - Trajectory shows path through embedding space along the sequence
685
+ - Alternating colors in CRISPR arrays indicate repeating structural elements (repeats vs spacers)
686
  """
687
  else:
688
  # Create single embedding heatmap
 
690
  result.embedding,
691
  title=f"Sequence Embedding ({result.method})"
692
  )
693
+ png_path, pdf_path = save_figure_to_file(fig, f"embedding_{mode}")
694
  summary = f"""## Embedding Extracted
695
 
696
  | Property | Value |
 
703
  Blue = negative activation, Red = positive activation.
704
  """
705
 
706
+ return fig, summary, png_path, pdf_path
707
 
708
 
709
  # Build interface
 
743
  gr.Button("Load Non-CRISPR Example").click(
744
  lambda: NON_CRISPR_EXAMPLE, outputs=seq_input
745
  )
746
+ result_summary = gr.Markdown()
747
+ gr.Markdown("### Download Plot")
748
+ with gr.Row():
749
+ pred_download_png = gr.File(label="PNG", interactive=False)
750
+ pred_download_pdf = gr.File(label="PDF", interactive=False)
751
  with gr.Column(scale=2):
752
  plot_output = gr.Plot(label="CRISPR Score Profile")
 
753
  regions_output = gr.JSON(label="Detected Regions", visible=False)
754
 
755
  predict_btn.click(
756
  predict,
757
  inputs=[seq_input, stride_input, threshold_input],
758
+ outputs=[plot_output, result_summary, regions_output, pred_download_png, pred_download_pdf]
759
  )
760
 
761
  with gr.Tab("Embeddings"):
762
  gr.Markdown("""## State-Dynamic Plots
763
 
764
+ Visualize how the model's internal representation changes across the sequence. The **State-Dynamics** mode projects embeddings to 2D/3D using UMAP and clusters similar regions together.
765
 
766
  **For CRISPR arrays**: Expect to see alternating colors in the bottom bar where repeats and spacers alternate. Repeats should cluster together (conserved pattern), while spacers cluster separately.
767
  """)
 
777
  choices=["state-dynamics", "mean", "max", "trajectory"],
778
  value="state-dynamics",
779
  label="Visualization Mode",
780
+ info="state-dynamics: Interactive UMAP | mean/max: pooled heatmap | trajectory: per-window heatmap"
781
+ )
782
+ use_3d = gr.Checkbox(
783
+ label="3D Visualization",
784
+ value=False,
785
+ info="Enable 3D UMAP projection (drag to rotate)",
786
+ visible=True
787
  )
788
  with gr.Row():
789
  embed_btn = gr.Button("Analyze Embeddings", variant="primary")
 
801
  - Downstream: 2364-2964 bp (random)
802
  """)
803
  embed_summary = gr.Markdown()
804
+ gr.Markdown("### Download Plot")
805
+ with gr.Row():
806
+ download_png = gr.File(label="PNG", interactive=False)
807
+ download_pdf = gr.File(label="PDF", interactive=False)
808
  with gr.Column(scale=2):
809
+ embed_plot = gr.Plot(label="Embedding Visualization (Interactive)")
810
+
811
+ # Show/hide 3D checkbox based on mode
812
+ embed_mode.change(
813
+ lambda m: gr.update(visible=(m == "state-dynamics")),
814
+ inputs=[embed_mode],
815
+ outputs=[use_3d]
816
+ )
817
 
818
+ embed_btn.click(
819
+ get_embedding,
820
+ inputs=[embed_seq, embed_mode, use_3d],
821
+ outputs=[embed_plot, embed_summary, download_png, download_pdf]
822
+ )
823
 
824
  with gr.Tab("About"):
825
  gr.Markdown("""