genomenet Claude Opus 4.5 commited on
Commit
6050d79
·
1 Parent(s): 3cc5297

Restore colorful clusters in state-dynamic plots

Browse files

- Keep colorful Set1 palette for cluster visualization
- Viridis gradient for position coloring
- Green/red start/end markers
- Monochrome styling only for UI elements, not data viz

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +101 -69
app.py CHANGED
@@ -642,9 +642,8 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
642
  hover_text = [f"Window {i}<br>Position: {pos}-{pos+1000} bp<br>Cluster: {c}"
643
  for i, (pos, c) in enumerate(zip(positions, cluster_labels))]
644
 
645
- # Monochrome grayscale palette for clusters
646
- grays = [f'rgba({int(40 + i * 180 / n_clusters)}, {int(40 + i * 180 / n_clusters)}, {int(40 + i * 180 / n_clusters)}, 0.8)'
647
- for i in range(n_clusters)]
648
 
649
  if use_3d:
650
  fig = go.Figure()
@@ -655,12 +654,12 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
655
  y=embedding_reduced[:, 1],
656
  z=embedding_reduced[:, 2],
657
  mode='lines',
658
- line=dict(color='rgba(113,113,122,0.3)', width=2),
659
  name='Trajectory',
660
  hoverinfo='skip'
661
  ))
662
 
663
- # Points - grayscale colorscale
664
  fig.add_trace(go.Scatter3d(
665
  x=embedding_reduced[:, 0],
666
  y=embedding_reduced[:, 1],
@@ -669,7 +668,7 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
669
  marker=dict(
670
  size=5,
671
  color=cluster_labels,
672
- colorscale='Greys',
673
  opacity=0.85,
674
  line=dict(width=0.5, color='white')
675
  ),
@@ -678,22 +677,22 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
678
  name='Windows'
679
  ))
680
 
681
- # Start marker - dark
682
  fig.add_trace(go.Scatter3d(
683
  x=[embedding_reduced[0, 0]],
684
  y=[embedding_reduced[0, 1]],
685
  z=[embedding_reduced[0, 2]],
686
  mode='markers',
687
- marker=dict(size=10, color='#18181b', symbol='diamond'),
688
  name="5' start"
689
  ))
690
- # End marker - medium gray
691
  fig.add_trace(go.Scatter3d(
692
  x=[embedding_reduced[-1, 0]],
693
  y=[embedding_reduced[-1, 1]],
694
  z=[embedding_reduced[-1, 2]],
695
  mode='markers',
696
- marker=dict(size=10, color='#71717a', symbol='square'),
697
  name="3' end"
698
  ))
699
 
@@ -739,7 +738,7 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
739
  x=embedding_reduced[mask, 0],
740
  y=embedding_reduced[mask, 1],
741
  mode='markers',
742
- marker=dict(size=7, color=grays[c],
743
  line=dict(width=0.5, color='white')),
744
  text=[hover_text[i] for i in np.where(mask)[0]],
745
  hovertemplate='%{text}<extra></extra>',
@@ -750,24 +749,24 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
750
  # Start/End markers
751
  fig.add_trace(go.Scatter(
752
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
753
- mode='markers', marker=dict(size=12, color='#18181b', symbol='triangle-up',
754
- line=dict(width=1, color='white')),
755
  name="5'", showlegend=True
756
  ), row=1, col=1)
757
  fig.add_trace(go.Scatter(
758
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
759
- mode='markers', marker=dict(size=12, color='#71717a', symbol='square',
760
- line=dict(width=1, color='white')),
761
  name="3'", showlegend=True
762
  ), row=1, col=1)
763
 
764
- # Right plot: by position - grayscale gradient
765
  fig.add_trace(go.Scatter(
766
  x=embedding_reduced[:, 0],
767
  y=embedding_reduced[:, 1],
768
  mode='lines+markers',
769
- line=dict(color='rgba(113,113,122,0.2)', width=1),
770
- marker=dict(size=7, color=np.arange(n_windows), colorscale='Greys',
771
  showscale=True, colorbar=dict(title=dict(text='window', font=dict(size=10)),
772
  x=1.02, tickfont=dict(size=9))),
773
  text=hover_text,
@@ -777,25 +776,25 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
777
 
778
  fig.add_trace(go.Scatter(
779
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
780
- mode='markers', marker=dict(size=12, color='#18181b', symbol='triangle-up',
781
- line=dict(width=1, color='white')),
782
  showlegend=False
783
  ), row=1, col=2)
784
  fig.add_trace(go.Scatter(
785
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
786
- mode='markers', marker=dict(size=12, color='#71717a', symbol='square',
787
- line=dict(width=1, color='white')),
788
  showlegend=False
789
  ), row=1, col=2)
790
 
791
- # Bottom: sequence map - grayscale blocks
792
  window_size = 1000
793
  for i, (cluster, pos) in enumerate(zip(cluster_labels, positions)):
794
  fig.add_trace(go.Scatter(
795
  x=[pos, pos + window_size, pos + window_size, pos, pos],
796
  y=[0, 0, 1, 1, 0],
797
  fill='toself',
798
- fillcolor=grays[cluster],
799
  line=dict(width=0),
800
  hoverinfo='text',
801
  text=f'Position {pos}-{pos+window_size} bp<br>Cluster {cluster}',
@@ -918,50 +917,64 @@ def create_sequence_viewer_html(sequence, positions, probabilities, threshold=0.
918
 
919
  def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
920
  """Predict CRISPR array probability for each position."""
921
- import tempfile
922
  import csv
923
  import time
924
 
925
  start_time = time.time()
926
 
927
- sequence = strip_fasta_header(sequence.strip())
 
 
 
 
 
 
928
 
929
- is_valid, error = validate_sequence(sequence)
930
  if not is_valid:
931
- return None, f"**Error**: {error}", None, None, None, None, None, None, None
932
 
933
  result = predict_sequence(sequence, stride=stride, aggregation="mean")
934
 
935
- # Detect regions first (needed for plot annotations)
936
- regions = detect_crispr_regions(sequence, threshold=threshold, min_length=100, stride=stride)
 
 
 
 
 
 
 
 
 
937
 
938
  # Create interactive Plotly plot
939
- fig = create_interactive_prediction_plot(result.positions, result.probabilities, threshold, regions)
940
 
941
  # Create static matplotlib plot for PNG/PDF export
942
- static_fig = create_prediction_plot(result.positions, result.probabilities, threshold, regions)
943
- png_path, pdf_path = save_figure_to_file(static_fig, "crispr_prediction")
 
944
  plt.close(static_fig)
945
 
946
  # Create CSV with prediction data
947
- temp_dir = tempfile.gettempdir()
948
- csv_path = os.path.join(temp_dir, "crispr_predictions.csv")
949
  with open(csv_path, 'w', newline='') as f:
950
  writer = csv.writer(f)
951
- writer.writerow(['position', 'probability', 'above_threshold'])
952
  for pos, prob in zip(result.positions, result.probabilities):
953
- writer.writerow([pos, f"{prob:.4f}", prob >= threshold])
954
 
955
  # Create GFF3 export
956
- gff_path = create_gff3_export(regions, result.sequence_length) if regions else None
957
 
958
  # Create sequence viewer HTML
959
- seq_viewer_html = create_sequence_viewer_html(sequence, result.positions, result.probabilities, threshold)
960
 
961
  elapsed_time = time.time() - start_time
962
 
963
  # Create summary text file
964
- summary_path = os.path.join(temp_dir, "crispr_summary.txt")
965
  summary_text = f"""CRISPR Array Detection Summary
966
  ==============================
967
 
@@ -1008,9 +1021,15 @@ Detected CRISPR Regions: {len(regions)}
1008
 
1009
  def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
1010
  """Detect CRISPR array regions."""
1011
- sequence = strip_fasta_header(sequence.strip())
 
 
 
 
 
 
1012
 
1013
- is_valid, error = validate_sequence(sequence)
1014
  if not is_valid:
1015
  return [], f"**Error**: {error}"
1016
 
@@ -1031,20 +1050,16 @@ def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
1031
  return regions, summary
1032
 
1033
 
1034
- def save_figure_to_file(fig, prefix="plot"):
1035
  """Save matplotlib figure to temporary files for download."""
1036
- import tempfile
1037
- import os
1038
-
1039
- # Create temp directory if needed
1040
- temp_dir = tempfile.gettempdir()
1041
 
1042
  # Save PNG
1043
- png_path = os.path.join(temp_dir, f"{prefix}.png")
1044
  fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
1045
 
1046
  # Save PDF
1047
- pdf_path = os.path.join(temp_dir, f"{prefix}.pdf")
1048
  fig.savefig(pdf_path, bbox_inches='tight', facecolor='white')
1049
 
1050
  return png_path, pdf_path
@@ -1052,14 +1067,19 @@ def save_figure_to_file(fig, prefix="plot"):
1052
 
1053
  def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
1054
  """Extract hidden state embedding and visualize as heatmap."""
1055
- sequence = strip_fasta_header(sequence.strip())
 
 
 
 
1056
 
1057
- is_valid, error = validate_sequence(sequence)
1058
  if not is_valid:
1059
- return None, f"**Error**: {error}", None, None
1060
 
1061
  result = embed_sequence(sequence, mode="trajectory" if mode == "state-dynamics" else mode)
1062
  png_path, pdf_path = None, None
 
1063
 
1064
  if mode == "trajectory":
1065
  # Create trajectory heatmap (windows x dimensions)
@@ -1067,7 +1087,7 @@ def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
1067
  result.embeddings,
1068
  title="Embedding Trajectory Across Sequence"
1069
  )
1070
- png_path, pdf_path = save_figure_to_file(fig, "trajectory_embedding")
1071
  summary = f"""## Trajectory Embedding
1072
 
1073
  | Property | Value |
@@ -1090,7 +1110,7 @@ Blue = negative activation, Red = positive activation.
1090
 
1091
  # For downloads, create a static matplotlib version
1092
  static_fig = create_state_dynamic_plot(embeddings, n_clusters=n_clusters, stride=100)
1093
- png_path, pdf_path = save_figure_to_file(static_fig, "state_dynamic_plot")
1094
  plt.close(static_fig)
1095
 
1096
  dim_text = "3D" if use_3d else "2D"
@@ -1121,7 +1141,7 @@ Blue = negative activation, Red = positive activation.
1121
  result.embedding,
1122
  title=f"Sequence Embedding ({result.method})"
1123
  )
1124
- png_path, pdf_path = save_figure_to_file(fig, f"embedding_{mode}")
1125
  summary = f"""## Embedding Extracted
1126
 
1127
  | Property | Value |
@@ -1138,7 +1158,18 @@ Blue = negative activation, Red = positive activation.
1138
 
1139
 
1140
  # Build interface
1141
- with gr.Blocks(title="CRISPR Array Detection") as demo:
 
 
 
 
 
 
 
 
 
 
 
1142
  gr.Markdown("""
1143
  # crispr-detect
1144
 
@@ -1224,14 +1255,17 @@ Sliding window analysis with per-position probability scores. Export to GFF3/CSV
1224
  results = predict(*args)
1225
  # results = (fig, summary, regions, png, pdf, csv, summary_txt, gff, seq_html)
1226
  # Return results plus visibility updates for accordions
1227
- return results + (gr.update(visible=True), gr.update(visible=True))
 
1228
 
1229
  predict_btn.click(
1230
  predict_and_show_downloads,
1231
  inputs=[seq_input, stride_input, threshold_input],
1232
  outputs=[plot_output, result_summary, regions_output, pred_download_png, pred_download_pdf,
1233
  pred_download_csv, pred_download_summary, pred_download_gff, seq_viewer_html,
1234
- download_accordion, seq_viewer_accordion]
 
 
1235
  )
1236
 
1237
  with gr.Tab("Embeddings"):
@@ -1289,12 +1323,15 @@ Repeats cluster together, spacers form distinct groups.
1289
 
1290
  def embed_and_show_downloads(*args):
1291
  results = get_embedding(*args)
1292
- return results + (gr.update(visible=True),)
 
1293
 
1294
  embed_btn.click(
1295
  embed_and_show_downloads,
1296
  inputs=[embed_seq, embed_mode, use_3d],
1297
- outputs=[embed_plot, embed_summary, download_png, download_pdf, embed_download_accordion]
 
 
1298
  )
1299
 
1300
  with gr.Tab("API"):
@@ -1365,15 +1402,10 @@ if __name__ == "__main__":
1365
  model = get_model()
1366
  warmup_model(model)
1367
  print(f"Model ready! GPU: {get_gpu_status()}")
 
1368
  demo.launch(
1369
  server_name="0.0.0.0",
1370
  server_port=7860,
1371
- theme=gr.themes.Base(
1372
- primary_hue=gr.themes.colors.zinc,
1373
- secondary_hue=gr.themes.colors.zinc,
1374
- neutral_hue=gr.themes.colors.zinc,
1375
- font=gr.themes.GoogleFont("Inter"),
1376
- font_mono=gr.themes.GoogleFont("Geist Mono"),
1377
- ),
1378
- css=CUSTOM_CSS
1379
  )
 
642
  hover_text = [f"Window {i}<br>Position: {pos}-{pos+1000} bp<br>Cluster: {c}"
643
  for i, (pos, c) in enumerate(zip(positions, cluster_labels))]
644
 
645
+ # Colorful palette for clusters
646
+ colors = px.colors.qualitative.Set1[:n_clusters]
 
647
 
648
  if use_3d:
649
  fig = go.Figure()
 
654
  y=embedding_reduced[:, 1],
655
  z=embedding_reduced[:, 2],
656
  mode='lines',
657
+ line=dict(color='rgba(100,100,100,0.3)', width=2),
658
  name='Trajectory',
659
  hoverinfo='skip'
660
  ))
661
 
662
+ # Points - colorful by cluster
663
  fig.add_trace(go.Scatter3d(
664
  x=embedding_reduced[:, 0],
665
  y=embedding_reduced[:, 1],
 
668
  marker=dict(
669
  size=5,
670
  color=cluster_labels,
671
+ colorscale='Set1',
672
  opacity=0.85,
673
  line=dict(width=0.5, color='white')
674
  ),
 
677
  name='Windows'
678
  ))
679
 
680
+ # Start marker - green
681
  fig.add_trace(go.Scatter3d(
682
  x=[embedding_reduced[0, 0]],
683
  y=[embedding_reduced[0, 1]],
684
  z=[embedding_reduced[0, 2]],
685
  mode='markers',
686
+ marker=dict(size=10, color='green', symbol='diamond'),
687
  name="5' start"
688
  ))
689
+ # End marker - red
690
  fig.add_trace(go.Scatter3d(
691
  x=[embedding_reduced[-1, 0]],
692
  y=[embedding_reduced[-1, 1]],
693
  z=[embedding_reduced[-1, 2]],
694
  mode='markers',
695
+ marker=dict(size=10, color='red', symbol='square'),
696
  name="3' end"
697
  ))
698
 
 
738
  x=embedding_reduced[mask, 0],
739
  y=embedding_reduced[mask, 1],
740
  mode='markers',
741
+ marker=dict(size=7, color=colors[c], opacity=0.8,
742
  line=dict(width=0.5, color='white')),
743
  text=[hover_text[i] for i in np.where(mask)[0]],
744
  hovertemplate='%{text}<extra></extra>',
 
749
  # Start/End markers
750
  fig.add_trace(go.Scatter(
751
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
752
+ mode='markers', marker=dict(size=12, color='green', symbol='triangle-up',
753
+ line=dict(width=1, color='black')),
754
  name="5'", showlegend=True
755
  ), row=1, col=1)
756
  fig.add_trace(go.Scatter(
757
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
758
+ mode='markers', marker=dict(size=12, color='red', symbol='square',
759
+ line=dict(width=1, color='black')),
760
  name="3'", showlegend=True
761
  ), row=1, col=1)
762
 
763
+ # Right plot: by position - viridis gradient
764
  fig.add_trace(go.Scatter(
765
  x=embedding_reduced[:, 0],
766
  y=embedding_reduced[:, 1],
767
  mode='lines+markers',
768
+ line=dict(color='rgba(100,100,100,0.3)', width=1),
769
+ marker=dict(size=7, color=np.arange(n_windows), colorscale='Viridis',
770
  showscale=True, colorbar=dict(title=dict(text='window', font=dict(size=10)),
771
  x=1.02, tickfont=dict(size=9))),
772
  text=hover_text,
 
776
 
777
  fig.add_trace(go.Scatter(
778
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
779
+ mode='markers', marker=dict(size=12, color='green', symbol='triangle-up',
780
+ line=dict(width=1, color='black')),
781
  showlegend=False
782
  ), row=1, col=2)
783
  fig.add_trace(go.Scatter(
784
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
785
+ mode='markers', marker=dict(size=12, color='red', symbol='square',
786
+ line=dict(width=1, color='black')),
787
  showlegend=False
788
  ), row=1, col=2)
789
 
790
+ # Bottom: sequence map - colorful blocks
791
  window_size = 1000
792
  for i, (cluster, pos) in enumerate(zip(cluster_labels, positions)):
793
  fig.add_trace(go.Scatter(
794
  x=[pos, pos + window_size, pos + window_size, pos, pos],
795
  y=[0, 0, 1, 1, 0],
796
  fill='toself',
797
+ fillcolor=colors[cluster],
798
  line=dict(width=0),
799
  hoverinfo='text',
800
  text=f'Position {pos}-{pos+window_size} bp<br>Cluster {cluster}',
 
917
 
918
  def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
919
  """Predict CRISPR array probability for each position."""
 
920
  import csv
921
  import time
922
 
923
  start_time = time.time()
924
 
925
+ is_valid, sequence, error = normalize_sequence_input(sequence)
926
+ if not is_valid:
927
+ return prediction_error_outputs(error)
928
+
929
+ is_valid, stride, error = validate_stride(stride)
930
+ if not is_valid:
931
+ return prediction_error_outputs(error)
932
 
933
+ is_valid, threshold, error = validate_threshold(threshold)
934
  if not is_valid:
935
+ return prediction_error_outputs(error)
936
 
937
  result = predict_sequence(sequence, stride=stride, aggregation="mean")
938
 
939
+ # Reuse the prediction result so the model only runs once per analysis.
940
+ regions = detect_crispr_regions(
941
+ sequence,
942
+ threshold=threshold,
943
+ min_length=100,
944
+ stride=stride,
945
+ prediction_result=result,
946
+ )
947
+
948
+ # User-facing coordinates are 1-based. Core inference stays 0-based.
949
+ display_positions = [pos + 1 for pos in result.positions]
950
 
951
  # Create interactive Plotly plot
952
+ fig = create_interactive_prediction_plot(display_positions, result.probabilities, threshold, regions)
953
 
954
  # Create static matplotlib plot for PNG/PDF export
955
+ output_dir = make_output_dir("crispr_prediction")
956
+ static_fig = create_prediction_plot(display_positions, result.probabilities, threshold, regions)
957
+ png_path, pdf_path = save_figure_to_file(static_fig, "crispr_prediction", output_dir)
958
  plt.close(static_fig)
959
 
960
  # Create CSV with prediction data
961
+ csv_path = os.path.join(output_dir, "crispr_predictions.csv")
 
962
  with open(csv_path, 'w', newline='') as f:
963
  writer = csv.writer(f)
964
+ writer.writerow(['position_1based', 'probability', 'above_threshold'])
965
  for pos, prob in zip(result.positions, result.probabilities):
966
+ writer.writerow([pos + 1, f"{prob:.4f}", prob >= threshold])
967
 
968
  # Create GFF3 export
969
+ gff_path = create_gff3_export(regions, result.sequence_length, output_dir=output_dir) if regions else None
970
 
971
  # Create sequence viewer HTML
972
+ seq_viewer_html = create_sequence_viewer_html(sequence, display_positions, result.probabilities, threshold)
973
 
974
  elapsed_time = time.time() - start_time
975
 
976
  # Create summary text file
977
+ summary_path = os.path.join(output_dir, "crispr_summary.txt")
978
  summary_text = f"""CRISPR Array Detection Summary
979
  ==============================
980
 
 
1021
 
1022
  def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
1023
  """Detect CRISPR array regions."""
1024
+ is_valid, sequence, error = normalize_sequence_input(sequence)
1025
+ if not is_valid:
1026
+ return [], f"**Error**: {error}"
1027
+
1028
+ is_valid, threshold, error = validate_threshold(threshold)
1029
+ if not is_valid:
1030
+ return [], f"**Error**: {error}"
1031
 
1032
+ is_valid, min_length, error = validate_min_length(min_length)
1033
  if not is_valid:
1034
  return [], f"**Error**: {error}"
1035
 
 
1050
  return regions, summary
1051
 
1052
 
1053
+ def save_figure_to_file(fig, prefix="plot", output_dir=None):
1054
  """Save matplotlib figure to temporary files for download."""
1055
+ output_dir = output_dir or make_output_dir(prefix)
 
 
 
 
1056
 
1057
  # Save PNG
1058
+ png_path = os.path.join(output_dir, f"{prefix}.png")
1059
  fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
1060
 
1061
  # Save PDF
1062
+ pdf_path = os.path.join(output_dir, f"{prefix}.pdf")
1063
  fig.savefig(pdf_path, bbox_inches='tight', facecolor='white')
1064
 
1065
  return png_path, pdf_path
 
1067
 
1068
  def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
1069
  """Extract hidden state embedding and visualize as heatmap."""
1070
+ allowed_modes = {"state-dynamics", "mean", "max", "trajectory", "cls"}
1071
+ if mode not in allowed_modes:
1072
+ return embedding_error_outputs(
1073
+ "Mode must be one of: state-dynamics, mean, max, trajectory, cls"
1074
+ )
1075
 
1076
+ is_valid, sequence, error = normalize_sequence_input(sequence)
1077
  if not is_valid:
1078
+ return embedding_error_outputs(error)
1079
 
1080
  result = embed_sequence(sequence, mode="trajectory" if mode == "state-dynamics" else mode)
1081
  png_path, pdf_path = None, None
1082
+ output_dir = make_output_dir("crispr_embedding")
1083
 
1084
  if mode == "trajectory":
1085
  # Create trajectory heatmap (windows x dimensions)
 
1087
  result.embeddings,
1088
  title="Embedding Trajectory Across Sequence"
1089
  )
1090
+ png_path, pdf_path = save_figure_to_file(fig, "trajectory_embedding", output_dir)
1091
  summary = f"""## Trajectory Embedding
1092
 
1093
  | Property | Value |
 
1110
 
1111
  # For downloads, create a static matplotlib version
1112
  static_fig = create_state_dynamic_plot(embeddings, n_clusters=n_clusters, stride=100)
1113
+ png_path, pdf_path = save_figure_to_file(static_fig, "state_dynamic_plot", output_dir)
1114
  plt.close(static_fig)
1115
 
1116
  dim_text = "3D" if use_3d else "2D"
 
1141
  result.embedding,
1142
  title=f"Sequence Embedding ({result.method})"
1143
  )
1144
+ png_path, pdf_path = save_figure_to_file(fig, f"embedding_{mode}", output_dir)
1145
  summary = f"""## Embedding Extracted
1146
 
1147
  | Property | Value |
 
1158
 
1159
 
1160
  # Build interface
1161
+ with gr.Blocks(
1162
+ title="CRISPR Array Detection",
1163
+ theme=gr.themes.Base(
1164
+ primary_hue=gr.themes.colors.zinc,
1165
+ secondary_hue=gr.themes.colors.zinc,
1166
+ neutral_hue=gr.themes.colors.zinc,
1167
+ font=gr.themes.GoogleFont("Inter"),
1168
+ font_mono=gr.themes.GoogleFont("Geist Mono"),
1169
+ ),
1170
+ css=CUSTOM_CSS,
1171
+ delete_cache=(3600, 86400),
1172
+ ) as demo:
1173
  gr.Markdown("""
1174
  # crispr-detect
1175
 
 
1255
  results = predict(*args)
1256
  # results = (fig, summary, regions, png, pdf, csv, summary_txt, gff, seq_html)
1257
  # Return results plus visibility updates for accordions
1258
+ success = results[0] is not None
1259
+ return results + (gr.update(visible=success), gr.update(visible=success))
1260
 
1261
  predict_btn.click(
1262
  predict_and_show_downloads,
1263
  inputs=[seq_input, stride_input, threshold_input],
1264
  outputs=[plot_output, result_summary, regions_output, pred_download_png, pred_download_pdf,
1265
  pred_download_csv, pred_download_summary, pred_download_gff, seq_viewer_html,
1266
+ download_accordion, seq_viewer_accordion],
1267
+ api_name="predict",
1268
+ concurrency_limit=1,
1269
  )
1270
 
1271
  with gr.Tab("Embeddings"):
 
1323
 
1324
  def embed_and_show_downloads(*args):
1325
  results = get_embedding(*args)
1326
+ success = results[0] is not None
1327
+ return results + (gr.update(visible=success),)
1328
 
1329
  embed_btn.click(
1330
  embed_and_show_downloads,
1331
  inputs=[embed_seq, embed_mode, use_3d],
1332
+ outputs=[embed_plot, embed_summary, download_png, download_pdf, embed_download_accordion],
1333
+ api_name="get_embedding",
1334
+ concurrency_limit=1,
1335
  )
1336
 
1337
  with gr.Tab("API"):
 
1402
  model = get_model()
1403
  warmup_model(model)
1404
  print(f"Model ready! GPU: {get_gpu_status()}")
1405
+ demo.queue(max_size=QUEUE_MAX_SIZE, default_concurrency_limit=1)
1406
  demo.launch(
1407
  server_name="0.0.0.0",
1408
  server_port=7860,
1409
+ max_threads=4,
1410
+ show_error=True,
 
 
 
 
 
 
1411
  )