Spaces:
Sleeping
Sleeping
Restore colorful clusters in state-dynamic plots
Browse files
- Keep colorful Set1 palette for cluster visualization
- Viridis gradient for position coloring
- Green/red start/end markers
- Monochrome styling only for UI elements, not data viz
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -642,9 +642,8 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 642 |
hover_text = [f"Window {i}<br>Position: {pos}-{pos+1000} bp<br>Cluster: {c}"
|
| 643 |
for i, (pos, c) in enumerate(zip(positions, cluster_labels))]
|
| 644 |
|
| 645 |
-
#
|
| 646 |
-
|
| 647 |
-
for i in range(n_clusters)]
|
| 648 |
|
| 649 |
if use_3d:
|
| 650 |
fig = go.Figure()
|
|
@@ -655,12 +654,12 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 655 |
y=embedding_reduced[:, 1],
|
| 656 |
z=embedding_reduced[:, 2],
|
| 657 |
mode='lines',
|
| 658 |
-
line=dict(color='rgba(
|
| 659 |
name='Trajectory',
|
| 660 |
hoverinfo='skip'
|
| 661 |
))
|
| 662 |
|
| 663 |
-
# Points -
|
| 664 |
fig.add_trace(go.Scatter3d(
|
| 665 |
x=embedding_reduced[:, 0],
|
| 666 |
y=embedding_reduced[:, 1],
|
|
@@ -669,7 +668,7 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 669 |
marker=dict(
|
| 670 |
size=5,
|
| 671 |
color=cluster_labels,
|
| 672 |
-
colorscale='
|
| 673 |
opacity=0.85,
|
| 674 |
line=dict(width=0.5, color='white')
|
| 675 |
),
|
|
@@ -678,22 +677,22 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 678 |
name='Windows'
|
| 679 |
))
|
| 680 |
|
| 681 |
-
# Start marker -
|
| 682 |
fig.add_trace(go.Scatter3d(
|
| 683 |
x=[embedding_reduced[0, 0]],
|
| 684 |
y=[embedding_reduced[0, 1]],
|
| 685 |
z=[embedding_reduced[0, 2]],
|
| 686 |
mode='markers',
|
| 687 |
-
marker=dict(size=10, color='
|
| 688 |
name="5' start"
|
| 689 |
))
|
| 690 |
-
# End marker -
|
| 691 |
fig.add_trace(go.Scatter3d(
|
| 692 |
x=[embedding_reduced[-1, 0]],
|
| 693 |
y=[embedding_reduced[-1, 1]],
|
| 694 |
z=[embedding_reduced[-1, 2]],
|
| 695 |
mode='markers',
|
| 696 |
-
marker=dict(size=10, color='
|
| 697 |
name="3' end"
|
| 698 |
))
|
| 699 |
|
|
@@ -739,7 +738,7 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 739 |
x=embedding_reduced[mask, 0],
|
| 740 |
y=embedding_reduced[mask, 1],
|
| 741 |
mode='markers',
|
| 742 |
-
marker=dict(size=7, color=
|
| 743 |
line=dict(width=0.5, color='white')),
|
| 744 |
text=[hover_text[i] for i in np.where(mask)[0]],
|
| 745 |
hovertemplate='%{text}<extra></extra>',
|
|
@@ -750,24 +749,24 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 750 |
# Start/End markers
|
| 751 |
fig.add_trace(go.Scatter(
|
| 752 |
x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
|
| 753 |
-
mode='markers', marker=dict(size=12, color='
|
| 754 |
-
line=dict(width=1, color='
|
| 755 |
name="5'", showlegend=True
|
| 756 |
), row=1, col=1)
|
| 757 |
fig.add_trace(go.Scatter(
|
| 758 |
x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
|
| 759 |
-
mode='markers', marker=dict(size=12, color='
|
| 760 |
-
line=dict(width=1, color='
|
| 761 |
name="3'", showlegend=True
|
| 762 |
), row=1, col=1)
|
| 763 |
|
| 764 |
-
# Right plot: by position -
|
| 765 |
fig.add_trace(go.Scatter(
|
| 766 |
x=embedding_reduced[:, 0],
|
| 767 |
y=embedding_reduced[:, 1],
|
| 768 |
mode='lines+markers',
|
| 769 |
-
line=dict(color='rgba(
|
| 770 |
-
marker=dict(size=7, color=np.arange(n_windows), colorscale='
|
| 771 |
showscale=True, colorbar=dict(title=dict(text='window', font=dict(size=10)),
|
| 772 |
x=1.02, tickfont=dict(size=9))),
|
| 773 |
text=hover_text,
|
|
@@ -777,25 +776,25 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
|
|
| 777 |
|
| 778 |
fig.add_trace(go.Scatter(
|
| 779 |
x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
|
| 780 |
-
mode='markers', marker=dict(size=12, color='
|
| 781 |
-
line=dict(width=1, color='
|
| 782 |
showlegend=False
|
| 783 |
), row=1, col=2)
|
| 784 |
fig.add_trace(go.Scatter(
|
| 785 |
x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
|
| 786 |
-
mode='markers', marker=dict(size=12, color='
|
| 787 |
-
line=dict(width=1, color='
|
| 788 |
showlegend=False
|
| 789 |
), row=1, col=2)
|
| 790 |
|
| 791 |
-
# Bottom: sequence map -
|
| 792 |
window_size = 1000
|
| 793 |
for i, (cluster, pos) in enumerate(zip(cluster_labels, positions)):
|
| 794 |
fig.add_trace(go.Scatter(
|
| 795 |
x=[pos, pos + window_size, pos + window_size, pos, pos],
|
| 796 |
y=[0, 0, 1, 1, 0],
|
| 797 |
fill='toself',
|
| 798 |
-
fillcolor=
|
| 799 |
line=dict(width=0),
|
| 800 |
hoverinfo='text',
|
| 801 |
text=f'Position {pos}-{pos+window_size} bp<br>Cluster {cluster}',
|
|
@@ -918,50 +917,64 @@ def create_sequence_viewer_html(sequence, positions, probabilities, threshold=0.
|
|
| 918 |
|
| 919 |
def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
|
| 920 |
"""Predict CRISPR array probability for each position."""
|
| 921 |
-
import tempfile
|
| 922 |
import csv
|
| 923 |
import time
|
| 924 |
|
| 925 |
start_time = time.time()
|
| 926 |
|
| 927 |
-
sequence =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 928 |
|
| 929 |
-
is_valid, error =
|
| 930 |
if not is_valid:
|
| 931 |
-
return
|
| 932 |
|
| 933 |
result = predict_sequence(sequence, stride=stride, aggregation="mean")
|
| 934 |
|
| 935 |
-
#
|
| 936 |
-
regions = detect_crispr_regions(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
|
| 938 |
# Create interactive Plotly plot
|
| 939 |
-
fig = create_interactive_prediction_plot(
|
| 940 |
|
| 941 |
# Create static matplotlib plot for PNG/PDF export
|
| 942 |
-
|
| 943 |
-
|
|
|
|
| 944 |
plt.close(static_fig)
|
| 945 |
|
| 946 |
# Create CSV with prediction data
|
| 947 |
-
|
| 948 |
-
csv_path = os.path.join(temp_dir, "crispr_predictions.csv")
|
| 949 |
with open(csv_path, 'w', newline='') as f:
|
| 950 |
writer = csv.writer(f)
|
| 951 |
-
writer.writerow(['
|
| 952 |
for pos, prob in zip(result.positions, result.probabilities):
|
| 953 |
-
writer.writerow([pos, f"{prob:.4f}", prob >= threshold])
|
| 954 |
|
| 955 |
# Create GFF3 export
|
| 956 |
-
gff_path = create_gff3_export(regions, result.sequence_length) if regions else None
|
| 957 |
|
| 958 |
# Create sequence viewer HTML
|
| 959 |
-
seq_viewer_html = create_sequence_viewer_html(sequence,
|
| 960 |
|
| 961 |
elapsed_time = time.time() - start_time
|
| 962 |
|
| 963 |
# Create summary text file
|
| 964 |
-
summary_path = os.path.join(
|
| 965 |
summary_text = f"""CRISPR Array Detection Summary
|
| 966 |
==============================
|
| 967 |
|
|
@@ -1008,9 +1021,15 @@ Detected CRISPR Regions: {len(regions)}
|
|
| 1008 |
|
| 1009 |
def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
|
| 1010 |
"""Detect CRISPR array regions."""
|
| 1011 |
-
sequence =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
|
| 1013 |
-
is_valid, error =
|
| 1014 |
if not is_valid:
|
| 1015 |
return [], f"**Error**: {error}"
|
| 1016 |
|
|
@@ -1031,20 +1050,16 @@ def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
|
|
| 1031 |
return regions, summary
|
| 1032 |
|
| 1033 |
|
| 1034 |
-
def save_figure_to_file(fig, prefix="plot"):
|
| 1035 |
"""Save matplotlib figure to temporary files for download."""
|
| 1036 |
-
|
| 1037 |
-
import os
|
| 1038 |
-
|
| 1039 |
-
# Create temp directory if needed
|
| 1040 |
-
temp_dir = tempfile.gettempdir()
|
| 1041 |
|
| 1042 |
# Save PNG
|
| 1043 |
-
png_path = os.path.join(
|
| 1044 |
fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
|
| 1045 |
|
| 1046 |
# Save PDF
|
| 1047 |
-
pdf_path = os.path.join(
|
| 1048 |
fig.savefig(pdf_path, bbox_inches='tight', facecolor='white')
|
| 1049 |
|
| 1050 |
return png_path, pdf_path
|
|
@@ -1052,14 +1067,19 @@ def save_figure_to_file(fig, prefix="plot"):
|
|
| 1052 |
|
| 1053 |
def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
|
| 1054 |
"""Extract hidden state embedding and visualize as heatmap."""
|
| 1055 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1056 |
|
| 1057 |
-
is_valid, error =
|
| 1058 |
if not is_valid:
|
| 1059 |
-
return
|
| 1060 |
|
| 1061 |
result = embed_sequence(sequence, mode="trajectory" if mode == "state-dynamics" else mode)
|
| 1062 |
png_path, pdf_path = None, None
|
|
|
|
| 1063 |
|
| 1064 |
if mode == "trajectory":
|
| 1065 |
# Create trajectory heatmap (windows x dimensions)
|
|
@@ -1067,7 +1087,7 @@ def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
|
|
| 1067 |
result.embeddings,
|
| 1068 |
title="Embedding Trajectory Across Sequence"
|
| 1069 |
)
|
| 1070 |
-
png_path, pdf_path = save_figure_to_file(fig, "trajectory_embedding")
|
| 1071 |
summary = f"""## Trajectory Embedding
|
| 1072 |
|
| 1073 |
| Property | Value |
|
|
@@ -1090,7 +1110,7 @@ Blue = negative activation, Red = positive activation.
|
|
| 1090 |
|
| 1091 |
# For downloads, create a static matplotlib version
|
| 1092 |
static_fig = create_state_dynamic_plot(embeddings, n_clusters=n_clusters, stride=100)
|
| 1093 |
-
png_path, pdf_path = save_figure_to_file(static_fig, "state_dynamic_plot")
|
| 1094 |
plt.close(static_fig)
|
| 1095 |
|
| 1096 |
dim_text = "3D" if use_3d else "2D"
|
|
@@ -1121,7 +1141,7 @@ Blue = negative activation, Red = positive activation.
|
|
| 1121 |
result.embedding,
|
| 1122 |
title=f"Sequence Embedding ({result.method})"
|
| 1123 |
)
|
| 1124 |
-
png_path, pdf_path = save_figure_to_file(fig, f"embedding_{mode}")
|
| 1125 |
summary = f"""## Embedding Extracted
|
| 1126 |
|
| 1127 |
| Property | Value |
|
|
@@ -1138,7 +1158,18 @@ Blue = negative activation, Red = positive activation.
|
|
| 1138 |
|
| 1139 |
|
| 1140 |
# Build interface
|
| 1141 |
-
with gr.Blocks(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1142 |
gr.Markdown("""
|
| 1143 |
# crispr-detect
|
| 1144 |
|
|
@@ -1224,14 +1255,17 @@ Sliding window analysis with per-position probability scores. Export to GFF3/CSV
|
|
| 1224 |
results = predict(*args)
|
| 1225 |
# results = (fig, summary, regions, png, pdf, csv, summary_txt, gff, seq_html)
|
| 1226 |
# Return results plus visibility updates for accordions
|
| 1227 |
-
|
|
|
|
| 1228 |
|
| 1229 |
predict_btn.click(
|
| 1230 |
predict_and_show_downloads,
|
| 1231 |
inputs=[seq_input, stride_input, threshold_input],
|
| 1232 |
outputs=[plot_output, result_summary, regions_output, pred_download_png, pred_download_pdf,
|
| 1233 |
pred_download_csv, pred_download_summary, pred_download_gff, seq_viewer_html,
|
| 1234 |
-
download_accordion, seq_viewer_accordion]
|
|
|
|
|
|
|
| 1235 |
)
|
| 1236 |
|
| 1237 |
with gr.Tab("Embeddings"):
|
|
@@ -1289,12 +1323,15 @@ Repeats cluster together, spacers form distinct groups.
|
|
| 1289 |
|
| 1290 |
def embed_and_show_downloads(*args):
|
| 1291 |
results = get_embedding(*args)
|
| 1292 |
-
|
|
|
|
| 1293 |
|
| 1294 |
embed_btn.click(
|
| 1295 |
embed_and_show_downloads,
|
| 1296 |
inputs=[embed_seq, embed_mode, use_3d],
|
| 1297 |
-
outputs=[embed_plot, embed_summary, download_png, download_pdf, embed_download_accordion]
|
|
|
|
|
|
|
| 1298 |
)
|
| 1299 |
|
| 1300 |
with gr.Tab("API"):
|
|
@@ -1365,15 +1402,10 @@ if __name__ == "__main__":
|
|
| 1365 |
model = get_model()
|
| 1366 |
warmup_model(model)
|
| 1367 |
print(f"Model ready! GPU: {get_gpu_status()}")
|
|
|
|
| 1368 |
demo.launch(
|
| 1369 |
server_name="0.0.0.0",
|
| 1370 |
server_port=7860,
|
| 1371 |
-
|
| 1372 |
-
|
| 1373 |
-
secondary_hue=gr.themes.colors.zinc,
|
| 1374 |
-
neutral_hue=gr.themes.colors.zinc,
|
| 1375 |
-
font=gr.themes.GoogleFont("Inter"),
|
| 1376 |
-
font_mono=gr.themes.GoogleFont("Geist Mono"),
|
| 1377 |
-
),
|
| 1378 |
-
css=CUSTOM_CSS
|
| 1379 |
)
|
|
|
|
| 642 |
hover_text = [f"Window {i}<br>Position: {pos}-{pos+1000} bp<br>Cluster: {c}"
|
| 643 |
for i, (pos, c) in enumerate(zip(positions, cluster_labels))]
|
| 644 |
|
| 645 |
+
# Colorful palette for clusters
|
| 646 |
+
colors = px.colors.qualitative.Set1[:n_clusters]
|
|
|
|
| 647 |
|
| 648 |
if use_3d:
|
| 649 |
fig = go.Figure()
|
|
|
|
| 654 |
y=embedding_reduced[:, 1],
|
| 655 |
z=embedding_reduced[:, 2],
|
| 656 |
mode='lines',
|
| 657 |
+
line=dict(color='rgba(100,100,100,0.3)', width=2),
|
| 658 |
name='Trajectory',
|
| 659 |
hoverinfo='skip'
|
| 660 |
))
|
| 661 |
|
| 662 |
+
# Points - colorful by cluster
|
| 663 |
fig.add_trace(go.Scatter3d(
|
| 664 |
x=embedding_reduced[:, 0],
|
| 665 |
y=embedding_reduced[:, 1],
|
|
|
|
| 668 |
marker=dict(
|
| 669 |
size=5,
|
| 670 |
color=cluster_labels,
|
| 671 |
+
colorscale='Set1',
|
| 672 |
opacity=0.85,
|
| 673 |
line=dict(width=0.5, color='white')
|
| 674 |
),
|
|
|
|
| 677 |
name='Windows'
|
| 678 |
))
|
| 679 |
|
| 680 |
+
# Start marker - green
|
| 681 |
fig.add_trace(go.Scatter3d(
|
| 682 |
x=[embedding_reduced[0, 0]],
|
| 683 |
y=[embedding_reduced[0, 1]],
|
| 684 |
z=[embedding_reduced[0, 2]],
|
| 685 |
mode='markers',
|
| 686 |
+
marker=dict(size=10, color='green', symbol='diamond'),
|
| 687 |
name="5' start"
|
| 688 |
))
|
| 689 |
+
# End marker - red
|
| 690 |
fig.add_trace(go.Scatter3d(
|
| 691 |
x=[embedding_reduced[-1, 0]],
|
| 692 |
y=[embedding_reduced[-1, 1]],
|
| 693 |
z=[embedding_reduced[-1, 2]],
|
| 694 |
mode='markers',
|
| 695 |
+
marker=dict(size=10, color='red', symbol='square'),
|
| 696 |
name="3' end"
|
| 697 |
))
|
| 698 |
|
|
|
|
| 738 |
x=embedding_reduced[mask, 0],
|
| 739 |
y=embedding_reduced[mask, 1],
|
| 740 |
mode='markers',
|
| 741 |
+
marker=dict(size=7, color=colors[c], opacity=0.8,
|
| 742 |
line=dict(width=0.5, color='white')),
|
| 743 |
text=[hover_text[i] for i in np.where(mask)[0]],
|
| 744 |
hovertemplate='%{text}<extra></extra>',
|
|
|
|
| 749 |
# Start/End markers
|
| 750 |
fig.add_trace(go.Scatter(
|
| 751 |
x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
|
| 752 |
+
mode='markers', marker=dict(size=12, color='green', symbol='triangle-up',
|
| 753 |
+
line=dict(width=1, color='black')),
|
| 754 |
name="5'", showlegend=True
|
| 755 |
), row=1, col=1)
|
| 756 |
fig.add_trace(go.Scatter(
|
| 757 |
x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
|
| 758 |
+
mode='markers', marker=dict(size=12, color='red', symbol='square',
|
| 759 |
+
line=dict(width=1, color='black')),
|
| 760 |
name="3'", showlegend=True
|
| 761 |
), row=1, col=1)
|
| 762 |
|
| 763 |
+
# Right plot: by position - viridis gradient
|
| 764 |
fig.add_trace(go.Scatter(
|
| 765 |
x=embedding_reduced[:, 0],
|
| 766 |
y=embedding_reduced[:, 1],
|
| 767 |
mode='lines+markers',
|
| 768 |
+
line=dict(color='rgba(100,100,100,0.3)', width=1),
|
| 769 |
+
marker=dict(size=7, color=np.arange(n_windows), colorscale='Viridis',
|
| 770 |
showscale=True, colorbar=dict(title=dict(text='window', font=dict(size=10)),
|
| 771 |
x=1.02, tickfont=dict(size=9))),
|
| 772 |
text=hover_text,
|
|
|
|
| 776 |
|
| 777 |
fig.add_trace(go.Scatter(
|
| 778 |
x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
|
| 779 |
+
mode='markers', marker=dict(size=12, color='green', symbol='triangle-up',
|
| 780 |
+
line=dict(width=1, color='black')),
|
| 781 |
showlegend=False
|
| 782 |
), row=1, col=2)
|
| 783 |
fig.add_trace(go.Scatter(
|
| 784 |
x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
|
| 785 |
+
mode='markers', marker=dict(size=12, color='red', symbol='square',
|
| 786 |
+
line=dict(width=1, color='black')),
|
| 787 |
showlegend=False
|
| 788 |
), row=1, col=2)
|
| 789 |
|
| 790 |
+
# Bottom: sequence map - colorful blocks
|
| 791 |
window_size = 1000
|
| 792 |
for i, (cluster, pos) in enumerate(zip(cluster_labels, positions)):
|
| 793 |
fig.add_trace(go.Scatter(
|
| 794 |
x=[pos, pos + window_size, pos + window_size, pos, pos],
|
| 795 |
y=[0, 0, 1, 1, 0],
|
| 796 |
fill='toself',
|
| 797 |
+
fillcolor=colors[cluster],
|
| 798 |
line=dict(width=0),
|
| 799 |
hoverinfo='text',
|
| 800 |
text=f'Position {pos}-{pos+window_size} bp<br>Cluster {cluster}',
|
|
|
|
| 917 |
|
| 918 |
def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
|
| 919 |
"""Predict CRISPR array probability for each position."""
|
|
|
|
| 920 |
import csv
|
| 921 |
import time
|
| 922 |
|
| 923 |
start_time = time.time()
|
| 924 |
|
| 925 |
+
is_valid, sequence, error = normalize_sequence_input(sequence)
|
| 926 |
+
if not is_valid:
|
| 927 |
+
return prediction_error_outputs(error)
|
| 928 |
+
|
| 929 |
+
is_valid, stride, error = validate_stride(stride)
|
| 930 |
+
if not is_valid:
|
| 931 |
+
return prediction_error_outputs(error)
|
| 932 |
|
| 933 |
+
is_valid, threshold, error = validate_threshold(threshold)
|
| 934 |
if not is_valid:
|
| 935 |
+
return prediction_error_outputs(error)
|
| 936 |
|
| 937 |
result = predict_sequence(sequence, stride=stride, aggregation="mean")
|
| 938 |
|
| 939 |
+
# Reuse the prediction result so the model only runs once per analysis.
|
| 940 |
+
regions = detect_crispr_regions(
|
| 941 |
+
sequence,
|
| 942 |
+
threshold=threshold,
|
| 943 |
+
min_length=100,
|
| 944 |
+
stride=stride,
|
| 945 |
+
prediction_result=result,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
# User-facing coordinates are 1-based. Core inference stays 0-based.
|
| 949 |
+
display_positions = [pos + 1 for pos in result.positions]
|
| 950 |
|
| 951 |
# Create interactive Plotly plot
|
| 952 |
+
fig = create_interactive_prediction_plot(display_positions, result.probabilities, threshold, regions)
|
| 953 |
|
| 954 |
# Create static matplotlib plot for PNG/PDF export
|
| 955 |
+
output_dir = make_output_dir("crispr_prediction")
|
| 956 |
+
static_fig = create_prediction_plot(display_positions, result.probabilities, threshold, regions)
|
| 957 |
+
png_path, pdf_path = save_figure_to_file(static_fig, "crispr_prediction", output_dir)
|
| 958 |
plt.close(static_fig)
|
| 959 |
|
| 960 |
# Create CSV with prediction data
|
| 961 |
+
csv_path = os.path.join(output_dir, "crispr_predictions.csv")
|
|
|
|
| 962 |
with open(csv_path, 'w', newline='') as f:
|
| 963 |
writer = csv.writer(f)
|
| 964 |
+
writer.writerow(['position_1based', 'probability', 'above_threshold'])
|
| 965 |
for pos, prob in zip(result.positions, result.probabilities):
|
| 966 |
+
writer.writerow([pos + 1, f"{prob:.4f}", prob >= threshold])
|
| 967 |
|
| 968 |
# Create GFF3 export
|
| 969 |
+
gff_path = create_gff3_export(regions, result.sequence_length, output_dir=output_dir) if regions else None
|
| 970 |
|
| 971 |
# Create sequence viewer HTML
|
| 972 |
+
seq_viewer_html = create_sequence_viewer_html(sequence, display_positions, result.probabilities, threshold)
|
| 973 |
|
| 974 |
elapsed_time = time.time() - start_time
|
| 975 |
|
| 976 |
# Create summary text file
|
| 977 |
+
summary_path = os.path.join(output_dir, "crispr_summary.txt")
|
| 978 |
summary_text = f"""CRISPR Array Detection Summary
|
| 979 |
==============================
|
| 980 |
|
|
|
|
| 1021 |
|
| 1022 |
def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
|
| 1023 |
"""Detect CRISPR array regions."""
|
| 1024 |
+
is_valid, sequence, error = normalize_sequence_input(sequence)
|
| 1025 |
+
if not is_valid:
|
| 1026 |
+
return [], f"**Error**: {error}"
|
| 1027 |
+
|
| 1028 |
+
is_valid, threshold, error = validate_threshold(threshold)
|
| 1029 |
+
if not is_valid:
|
| 1030 |
+
return [], f"**Error**: {error}"
|
| 1031 |
|
| 1032 |
+
is_valid, min_length, error = validate_min_length(min_length)
|
| 1033 |
if not is_valid:
|
| 1034 |
return [], f"**Error**: {error}"
|
| 1035 |
|
|
|
|
| 1050 |
return regions, summary
|
| 1051 |
|
| 1052 |
|
| 1053 |
+
def save_figure_to_file(fig, prefix="plot", output_dir=None):
|
| 1054 |
"""Save matplotlib figure to temporary files for download."""
|
| 1055 |
+
output_dir = output_dir or make_output_dir(prefix)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1056 |
|
| 1057 |
# Save PNG
|
| 1058 |
+
png_path = os.path.join(output_dir, f"{prefix}.png")
|
| 1059 |
fig.savefig(png_path, dpi=150, bbox_inches='tight', facecolor='white')
|
| 1060 |
|
| 1061 |
# Save PDF
|
| 1062 |
+
pdf_path = os.path.join(output_dir, f"{prefix}.pdf")
|
| 1063 |
fig.savefig(pdf_path, bbox_inches='tight', facecolor='white')
|
| 1064 |
|
| 1065 |
return png_path, pdf_path
|
|
|
|
| 1067 |
|
| 1068 |
def get_embedding(sequence: str, mode: str = "mean", use_3d: bool = False):
|
| 1069 |
"""Extract hidden state embedding and visualize as heatmap."""
|
| 1070 |
+
allowed_modes = {"state-dynamics", "mean", "max", "trajectory", "cls"}
|
| 1071 |
+
if mode not in allowed_modes:
|
| 1072 |
+
return embedding_error_outputs(
|
| 1073 |
+
"Mode must be one of: state-dynamics, mean, max, trajectory, cls"
|
| 1074 |
+
)
|
| 1075 |
|
| 1076 |
+
is_valid, sequence, error = normalize_sequence_input(sequence)
|
| 1077 |
if not is_valid:
|
| 1078 |
+
return embedding_error_outputs(error)
|
| 1079 |
|
| 1080 |
result = embed_sequence(sequence, mode="trajectory" if mode == "state-dynamics" else mode)
|
| 1081 |
png_path, pdf_path = None, None
|
| 1082 |
+
output_dir = make_output_dir("crispr_embedding")
|
| 1083 |
|
| 1084 |
if mode == "trajectory":
|
| 1085 |
# Create trajectory heatmap (windows x dimensions)
|
|
|
|
| 1087 |
result.embeddings,
|
| 1088 |
title="Embedding Trajectory Across Sequence"
|
| 1089 |
)
|
| 1090 |
+
png_path, pdf_path = save_figure_to_file(fig, "trajectory_embedding", output_dir)
|
| 1091 |
summary = f"""## Trajectory Embedding
|
| 1092 |
|
| 1093 |
| Property | Value |
|
|
|
|
| 1110 |
|
| 1111 |
# For downloads, create a static matplotlib version
|
| 1112 |
static_fig = create_state_dynamic_plot(embeddings, n_clusters=n_clusters, stride=100)
|
| 1113 |
+
png_path, pdf_path = save_figure_to_file(static_fig, "state_dynamic_plot", output_dir)
|
| 1114 |
plt.close(static_fig)
|
| 1115 |
|
| 1116 |
dim_text = "3D" if use_3d else "2D"
|
|
|
|
| 1141 |
result.embedding,
|
| 1142 |
title=f"Sequence Embedding ({result.method})"
|
| 1143 |
)
|
| 1144 |
+
png_path, pdf_path = save_figure_to_file(fig, f"embedding_{mode}", output_dir)
|
| 1145 |
summary = f"""## Embedding Extracted
|
| 1146 |
|
| 1147 |
| Property | Value |
|
|
|
|
| 1158 |
|
| 1159 |
|
| 1160 |
# Build interface
|
| 1161 |
+
with gr.Blocks(
|
| 1162 |
+
title="CRISPR Array Detection",
|
| 1163 |
+
theme=gr.themes.Base(
|
| 1164 |
+
primary_hue=gr.themes.colors.zinc,
|
| 1165 |
+
secondary_hue=gr.themes.colors.zinc,
|
| 1166 |
+
neutral_hue=gr.themes.colors.zinc,
|
| 1167 |
+
font=gr.themes.GoogleFont("Inter"),
|
| 1168 |
+
font_mono=gr.themes.GoogleFont("Geist Mono"),
|
| 1169 |
+
),
|
| 1170 |
+
css=CUSTOM_CSS,
|
| 1171 |
+
delete_cache=(3600, 86400),
|
| 1172 |
+
) as demo:
|
| 1173 |
gr.Markdown("""
|
| 1174 |
# crispr-detect
|
| 1175 |
|
|
|
|
| 1255 |
results = predict(*args)
|
| 1256 |
# results = (fig, summary, regions, png, pdf, csv, summary_txt, gff, seq_html)
|
| 1257 |
# Return results plus visibility updates for accordions
|
| 1258 |
+
success = results[0] is not None
|
| 1259 |
+
return results + (gr.update(visible=success), gr.update(visible=success))
|
| 1260 |
|
| 1261 |
predict_btn.click(
|
| 1262 |
predict_and_show_downloads,
|
| 1263 |
inputs=[seq_input, stride_input, threshold_input],
|
| 1264 |
outputs=[plot_output, result_summary, regions_output, pred_download_png, pred_download_pdf,
|
| 1265 |
pred_download_csv, pred_download_summary, pred_download_gff, seq_viewer_html,
|
| 1266 |
+
download_accordion, seq_viewer_accordion],
|
| 1267 |
+
api_name="predict",
|
| 1268 |
+
concurrency_limit=1,
|
| 1269 |
)
|
| 1270 |
|
| 1271 |
with gr.Tab("Embeddings"):
|
|
|
|
| 1323 |
|
| 1324 |
def embed_and_show_downloads(*args):
|
| 1325 |
results = get_embedding(*args)
|
| 1326 |
+
success = results[0] is not None
|
| 1327 |
+
return results + (gr.update(visible=success),)
|
| 1328 |
|
| 1329 |
embed_btn.click(
|
| 1330 |
embed_and_show_downloads,
|
| 1331 |
inputs=[embed_seq, embed_mode, use_3d],
|
| 1332 |
+
outputs=[embed_plot, embed_summary, download_png, download_pdf, embed_download_accordion],
|
| 1333 |
+
api_name="get_embedding",
|
| 1334 |
+
concurrency_limit=1,
|
| 1335 |
)
|
| 1336 |
|
| 1337 |
with gr.Tab("API"):
|
|
|
|
| 1402 |
model = get_model()
|
| 1403 |
warmup_model(model)
|
| 1404 |
print(f"Model ready! GPU: {get_gpu_status()}")
|
| 1405 |
+
demo.queue(max_size=QUEUE_MAX_SIZE, default_concurrency_limit=1)
|
| 1406 |
demo.launch(
|
| 1407 |
server_name="0.0.0.0",
|
| 1408 |
server_port=7860,
|
| 1409 |
+
max_threads=4,
|
| 1410 |
+
show_error=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1411 |
)
|