genomenet Claude Opus 4.5 committed on
Commit
3cc5297
·
1 Parent(s): 6b4e599

Minimalist monochrome redesign with Geist Mono font

Browse files

- Monochrome/grayscale color scheme throughout
- Geist Mono font for code and sequence display
- Simplified UI text: lowercase labels, minimal descriptions
- Grayscale Plotly charts with subtle styling
- Minimal header: "crispr-detect" with brief description
- Compact API and About tabs
- Zinc-based Gradio theme

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +420 -292
  2. inference/inference.py +35 -7
  3. inference/tokenizer.py +31 -5
app.py CHANGED
@@ -3,7 +3,10 @@ CRISPR Array Detection - HuggingFace Spaces App
3
  """
4
 
5
  import os
 
 
6
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 
7
 
8
  import gradio as gr
9
  import numpy as np
@@ -23,24 +26,100 @@ from inference.model_loader import get_model, warmup_model, get_gpu_status
23
  from inference.tokenizer import validate_sequence, strip_fasta_header
24
  from inference.inference import detect_crispr_regions
25
 
26
- # Custom CSS for better fonts
 
 
 
 
 
27
  CUSTOM_CSS = """
28
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=JetBrains+Mono&display=swap');
 
 
 
 
 
 
 
 
 
 
29
 
30
  * {
31
- font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
32
  }
33
 
34
- code, pre, .code, textarea {
35
- font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
36
  }
37
 
38
- h1, h2, h3 {
39
- font-weight: 600 !important;
 
 
 
 
 
 
40
  }
41
 
42
  .gradio-container {
43
  max-width: 1200px !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  }
45
  """
46
 
@@ -68,6 +147,100 @@ EMBEDDING_CRISPR_EXAMPLE = """GACAGGTACAAGAAGGAGTATGCATCAATGTGGTCGTGTGGAACAAACGC
68
  EMBEDDING_RANDOM_EXAMPLE = """ATGCGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCT
AGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCT"""
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def create_prediction_plot(positions, probabilities, threshold=0.3, regions=None):
72
  """Create a matplotlib figure showing the prediction curve (for PNG/PDF export)."""
73
  fig, ax = plt.subplots(figsize=(12, 4))
@@ -89,7 +262,7 @@ def create_prediction_plot(positions, probabilities, threshold=0.3, regions=None
89
  ax.set_ylabel('CRISPR Probability')
90
  ax.set_title('CRISPR Array Detection Score')
91
  ax.set_ylim(0, 1)
92
- ax.set_xlim(0, max(positions) if positions else 1000)
93
  ax.legend(loc='upper right')
94
  ax.grid(True, alpha=0.3)
95
 
@@ -101,102 +274,96 @@ def create_interactive_prediction_plot(positions, probabilities, threshold=0.3,
101
  """Create an interactive Plotly figure showing the prediction curve with minimap."""
102
  fig = go.Figure()
103
 
 
104
  max_pos = max(positions) if positions else 1000
105
 
106
- # Main probability curve with fill
107
  fig.add_trace(go.Scatter(
108
  x=positions,
109
  y=probabilities,
110
  mode='lines',
111
- name='Prediction Score',
112
- line=dict(color='#2563eb', width=1.5),
113
  fill='tozeroy',
114
- fillcolor='rgba(37, 99, 235, 0.15)',
115
  hovertemplate='Position: %{x:,} bp<br>Score: %{y:.3f}<extra></extra>'
116
  ))
117
 
118
- # Add threshold line
119
  fig.add_hline(
120
  y=threshold,
121
  line_dash="dash",
122
- line_color="#dc2626",
123
- annotation_text=f"Threshold ({threshold})",
124
  annotation_position="top right",
125
- annotation_font_size=11
 
126
  )
127
 
128
- # Highlight detected CRISPR regions
129
  if regions:
130
  for r in regions:
131
  fig.add_vrect(
132
  x0=r['start'], x1=r['end'],
133
- fillcolor="rgba(220, 38, 38, 0.12)",
134
  layer="below",
135
  line_width=1,
136
- line_color="rgba(220, 38, 38, 0.3)",
137
- annotation_text=f"CRISPR {r['region_id']}",
138
  annotation_position="top left",
139
- annotation_font_size=10,
140
- annotation_font_color="#dc2626"
141
  )
142
 
143
  fig.update_layout(
144
- title=dict(
145
- text='CRISPR Array Detection',
146
- font=dict(size=14, color='#1f2937'),
147
- x=0.5,
148
- xanchor='center'
149
- ),
150
  xaxis=dict(
151
- title='Position (bp)',
152
- range=[0, max_pos],
153
- gridcolor='#e5e7eb',
154
  showgrid=True,
155
  zeroline=False,
156
- # Rangeslider for minimap navigation
 
157
  rangeslider=dict(
158
  visible=True,
159
- thickness=0.08,
160
- bgcolor='#f3f4f6',
161
- bordercolor='#d1d5db',
162
  borderwidth=1
163
  ),
164
- # Range selector buttons for quick zoom
165
  rangeselector=dict(
166
  buttons=list([
167
  dict(count=500, label="500bp", step="all", stepmode="backward"),
168
  dict(count=1000, label="1kb", step="all", stepmode="backward"),
169
  dict(count=5000, label="5kb", step="all", stepmode="backward"),
170
- dict(step="all", label="Full")
171
  ]),
172
- bgcolor='#f9fafb',
173
- bordercolor='#d1d5db',
174
- font=dict(size=10),
 
175
  x=0,
176
- y=1.15
177
  )
178
  ),
179
  yaxis=dict(
180
- title='CRISPR Probability',
181
  range=[0, 1.05],
182
- gridcolor='#e5e7eb',
183
  showgrid=True,
184
  zeroline=False,
 
 
185
  tickformat='.1f'
186
  ),
187
  hovermode='x unified',
188
- showlegend=True,
189
- legend=dict(
190
- yanchor="top", y=0.99,
191
- xanchor="right", x=0.99,
192
- bgcolor='rgba(255,255,255,0.8)',
193
- bordercolor='#e5e7eb',
194
- borderwidth=1
195
- ),
196
- height=480,
197
- plot_bgcolor='white',
198
- paper_bgcolor='white',
199
- margin=dict(t=80, b=60)
200
  )
201
 
202
  return fig
@@ -221,9 +388,8 @@ def create_embedding_heatmap(embedding, title="Sequence Embedding", cols=30):
221
  # Create figure
222
  fig, ax = plt.subplots(figsize=(14, max(3, rows * 0.25)))
223
 
224
- # Use diverging colormap centered at 0
225
- vmax = max(abs(np.nanmin(embedding)), abs(np.nanmax(embedding)))
226
- norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
227
 
228
  im = ax.imshow(grid, cmap='RdBu_r', norm=norm, aspect='auto')
229
 
@@ -262,9 +428,8 @@ def create_trajectory_heatmap(embeddings, title="Embedding Trajectory"):
262
 
263
  fig, ax = plt.subplots(figsize=(14, max(4, n_windows * 0.3)))
264
 
265
- # Use diverging colormap
266
- vmax = max(abs(embeddings.min()), abs(embeddings.max()))
267
- norm = TwoSlopeNorm(vmin=-vmax, vcenter=0, vmax=vmax)
268
 
269
  im = ax.imshow(embeddings, cmap='RdBu_r', norm=norm, aspect='auto')
270
 
@@ -442,17 +607,17 @@ def create_sequence_cluster_map(cluster_labels, stride=100, window_size=1000):
442
 
443
  def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=False):
444
  """
445
- Create interactive Plotly State-Dynamic Plot with 2D or 3D UMAP.
446
  """
447
  embeddings = np.array(embeddings)
448
  n_windows, n_dims = embeddings.shape
449
 
450
  if n_windows < 5:
451
- # Not enough data
452
  fig = go.Figure()
453
  fig.add_annotation(text="Need longer sequence (minimum ~1500 bp)",
454
  xref="paper", yref="paper", x=0.5, y=0.5,
455
- showarrow=False, font=dict(size=16))
 
456
  return fig
457
 
458
  # UMAP reduction
@@ -477,70 +642,74 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
477
  hover_text = [f"Window {i}<br>Position: {pos}-{pos+1000} bp<br>Cluster: {c}"
478
  for i, (pos, c) in enumerate(zip(positions, cluster_labels))]
479
 
480
- # Color palette
481
- colors = px.colors.qualitative.Set1[:n_clusters]
482
- color_map = [colors[c] for c in cluster_labels]
483
 
484
  if use_3d:
485
- # 3D Plot
486
  fig = go.Figure()
487
 
488
- # Add trajectory line
489
  fig.add_trace(go.Scatter3d(
490
  x=embedding_reduced[:, 0],
491
  y=embedding_reduced[:, 1],
492
  z=embedding_reduced[:, 2],
493
  mode='lines',
494
- line=dict(color='rgba(100,100,100,0.3)', width=2),
495
  name='Trajectory',
496
  hoverinfo='skip'
497
  ))
498
 
499
- # Add points colored by cluster
500
  fig.add_trace(go.Scatter3d(
501
  x=embedding_reduced[:, 0],
502
  y=embedding_reduced[:, 1],
503
  z=embedding_reduced[:, 2],
504
  mode='markers',
505
  marker=dict(
506
- size=6,
507
  color=cluster_labels,
508
- colorscale='Set1',
509
- opacity=0.8,
510
- line=dict(width=1, color='white')
511
  ),
512
  text=hover_text,
513
  hovertemplate='%{text}<extra></extra>',
514
  name='Windows'
515
  ))
516
 
517
- # Mark start and end
518
  fig.add_trace(go.Scatter3d(
519
  x=[embedding_reduced[0, 0]],
520
  y=[embedding_reduced[0, 1]],
521
  z=[embedding_reduced[0, 2]],
522
  mode='markers',
523
- marker=dict(size=12, color='green', symbol='diamond'),
524
- name="Start (5')"
525
  ))
 
526
  fig.add_trace(go.Scatter3d(
527
  x=[embedding_reduced[-1, 0]],
528
  y=[embedding_reduced[-1, 1]],
529
  z=[embedding_reduced[-1, 2]],
530
  mode='markers',
531
- marker=dict(size=12, color='red', symbol='square'),
532
- name="End (3')"
533
  ))
534
 
535
  fig.update_layout(
536
- title=dict(text='3D State-Dynamic Plot (drag to rotate)', font=dict(size=16)),
537
  scene=dict(
538
- xaxis_title='UMAP 1',
539
- yaxis_title='UMAP 2',
540
- zaxis_title='UMAP 3'
541
  ),
542
- height=600,
543
- showlegend=True
 
 
 
 
544
  )
545
 
546
  else:
@@ -549,7 +718,7 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
549
  rows=2, cols=2,
550
  specs=[[{"type": "scatter"}, {"type": "scatter"}],
551
  [{"type": "scatter", "colspan": 2}, None]],
552
- subplot_titles=('By Cluster', 'By Position', 'Sequence Map'),
553
  row_heights=[0.6, 0.4],
554
  vertical_spacing=0.12
555
  )
@@ -559,7 +728,7 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
559
  x=embedding_reduced[:, 0],
560
  y=embedding_reduced[:, 1],
561
  mode='lines',
562
- line=dict(color='rgba(100,100,100,0.2)', width=1),
563
  hoverinfo='skip',
564
  showlegend=False
565
  ), row=1, col=1)
@@ -570,36 +739,37 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
570
  x=embedding_reduced[mask, 0],
571
  y=embedding_reduced[mask, 1],
572
  mode='markers',
573
- marker=dict(size=8, color=colors[c], opacity=0.8,
574
- line=dict(width=1, color='white')),
575
  text=[hover_text[i] for i in np.where(mask)[0]],
576
  hovertemplate='%{text}<extra></extra>',
577
- name=f'Cluster {c}',
578
  legendgroup=f'c{c}'
579
  ), row=1, col=1)
580
 
581
  # Start/End markers
582
  fig.add_trace(go.Scatter(
583
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
584
- mode='markers', marker=dict(size=15, color='green', symbol='triangle-up',
585
- line=dict(width=2, color='black')),
586
- name="Start (5')", showlegend=True
587
  ), row=1, col=1)
588
  fig.add_trace(go.Scatter(
589
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
590
- mode='markers', marker=dict(size=15, color='red', symbol='square',
591
- line=dict(width=2, color='black')),
592
- name="End (3')", showlegend=True
593
  ), row=1, col=1)
594
 
595
- # Right plot: by position
596
  fig.add_trace(go.Scatter(
597
  x=embedding_reduced[:, 0],
598
  y=embedding_reduced[:, 1],
599
  mode='lines+markers',
600
- line=dict(color='rgba(100,100,100,0.3)', width=1),
601
- marker=dict(size=8, color=np.arange(n_windows), colorscale='Viridis',
602
- showscale=True, colorbar=dict(title='Window', x=1.02)),
 
603
  text=hover_text,
604
  hovertemplate='%{text}<extra></extra>',
605
  showlegend=False
@@ -607,47 +777,59 @@ def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=F
607
 
608
  fig.add_trace(go.Scatter(
609
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
610
- mode='markers', marker=dict(size=15, color='green', symbol='triangle-up',
611
- line=dict(width=2, color='black')),
612
  showlegend=False
613
  ), row=1, col=2)
614
  fig.add_trace(go.Scatter(
615
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
616
- mode='markers', marker=dict(size=15, color='red', symbol='square',
617
- line=dict(width=2, color='black')),
618
  showlegend=False
619
  ), row=1, col=2)
620
 
621
- # Bottom: sequence map (horizontal bar)
622
  window_size = 1000
623
  for i, (cluster, pos) in enumerate(zip(cluster_labels, positions)):
624
  fig.add_trace(go.Scatter(
625
  x=[pos, pos + window_size, pos + window_size, pos, pos],
626
  y=[0, 0, 1, 1, 0],
627
  fill='toself',
628
- fillcolor=colors[cluster],
629
  line=dict(width=0),
630
- opacity=0.7,
631
  hoverinfo='text',
632
  text=f'Position {pos}-{pos+window_size} bp<br>Cluster {cluster}',
633
  showlegend=False
634
  ), row=2, col=1)
635
 
636
- fig.update_xaxes(title_text='UMAP 1', row=1, col=1)
637
- fig.update_yaxes(title_text='UMAP 2', row=1, col=1)
638
- fig.update_xaxes(title_text='UMAP 1', row=1, col=2)
639
- fig.update_yaxes(title_text='UMAP 2', row=1, col=2)
640
- fig.update_xaxes(title_text='Position (bp)', row=2, col=1)
 
 
 
 
 
641
  fig.update_yaxes(visible=False, row=2, col=1)
642
 
643
  fig.update_layout(
644
- title=dict(text='Interactive State-Dynamic Plot (hover for details, zoom/pan available)',
645
- font=dict(size=14)),
646
- height=700,
647
  showlegend=True,
648
- legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1)
 
 
 
 
 
649
  )
650
 
 
 
 
 
651
  return fig
652
 
653
 
@@ -655,15 +837,23 @@ def parse_fasta_file(file_path):
655
  """Parse a FASTA file and return the sequence."""
656
  if file_path is None:
657
  return None
658
- with open(file_path, 'r') as f:
 
 
 
 
659
  content = f.read()
660
- return strip_fasta_header(content.strip())
 
 
 
 
661
 
662
 
663
- def create_gff3_export(regions, sequence_length, sequence_id="input_sequence"):
664
  """Create GFF3 format annotation file for detected CRISPR regions."""
665
- import tempfile
666
- gff_path = os.path.join(tempfile.gettempdir(), "crispr_regions.gff3")
667
 
668
  with open(gff_path, 'w') as f:
669
  # GFF3 header
@@ -673,59 +863,52 @@ def create_gff3_export(regions, sequence_length, sequence_id="input_sequence"):
673
  for r in regions:
674
  # GFF3 format: seqid source type start end score strand phase attributes
675
  attributes = f"ID=CRISPR_{r['region_id']};Name=CRISPR_array_{r['region_id']};score={r['mean_score']:.3f}"
676
- f.write(f"{sequence_id}\tCRISPR-BERT\tCRISPR_array\t{r['start']+1}\t{r['end']}\t{r['mean_score']:.3f}\t.\t.\t{attributes}\n")
677
 
678
  return gff_path
679
 
680
 
681
  def create_sequence_viewer_html(sequence, positions, probabilities, threshold=0.3, chunk_size=100):
682
- """Create an HTML visualization of the sequence with color-coded scores."""
683
- # Interpolate scores to per-nucleotide level
684
- import numpy as np
685
-
686
  seq_len = len(sequence)
687
- per_base_scores = np.zeros(seq_len)
688
-
689
- # Map window scores to positions
690
- for i, (pos, prob) in enumerate(zip(positions, probabilities)):
691
- start = pos
692
- end = min(pos + 1000, seq_len) # window size
693
- # Average with existing scores for overlapping windows
694
- for j in range(start, end):
695
- if per_base_scores[j] == 0:
696
- per_base_scores[j] = prob
697
- else:
698
- per_base_scores[j] = (per_base_scores[j] + prob) / 2
699
-
700
- # Generate HTML
701
- html_parts = ['<div style="font-family: monospace; font-size: 12px; line-height: 1.8; background: #f8f9fa; padding: 15px; border-radius: 8px; max-height: 400px; overflow-y: auto;">']
702
- html_parts.append('<div style="margin-bottom: 10px; font-family: sans-serif; font-size: 13px;">')
703
- html_parts.append('<span style="background: linear-gradient(to right, #3b82f6, #fbbf24, #ef4444); padding: 2px 20px; border-radius: 3px; color: white;">Low → Medium → High CRISPR Score</span>')
704
- html_parts.append(f'<span style="margin-left: 15px;">Threshold: {threshold}</span>')
705
  html_parts.append('</div>')
706
 
707
- # Process sequence in chunks with position markers
708
  for chunk_start in range(0, seq_len, chunk_size):
709
  chunk_end = min(chunk_start + chunk_size, seq_len)
710
  chunk_seq = sequence[chunk_start:chunk_end]
711
  chunk_scores = per_base_scores[chunk_start:chunk_end]
712
 
713
  # Position marker
714
- html_parts.append(f'<div><span style="color: #666; width: 60px; display: inline-block; font-size: 11px;">{chunk_start+1:,}</span>')
715
 
716
  for i, (base, score) in enumerate(zip(chunk_seq, chunk_scores)):
717
- # Color based on score: blue (low) -> yellow (medium) -> red (high)
718
- if score < threshold * 0.5:
719
- color = "#3b82f6" # blue
720
- elif score < threshold:
721
- color = "#fbbf24" # yellow
722
- elif score < threshold * 1.5:
723
- color = "#f97316" # orange
724
- else:
725
- color = "#ef4444" # red
726
-
727
- bg_opacity = min(0.3 + score * 0.7, 1.0)
728
- html_parts.append(f'<span style="color: {color}; background-color: rgba(0,0,0,{bg_opacity * 0.1}); font-weight: {"bold" if score >= threshold else "normal"};" title="Pos {chunk_start + i + 1}: {score:.3f}">{base}</span>')
729
 
730
  html_parts.append('</div>')
731
 
@@ -957,75 +1140,72 @@ Blue = negative activation, Red = positive activation.
957
  # Build interface
958
  with gr.Blocks(title="CRISPR Array Detection") as demo:
959
  gr.Markdown("""
960
- # CRISPR Array Detection
961
-
962
- A deep learning approach for identifying CRISPR arrays in prokaryotic genome sequences. This tool employs a 24-layer BERT transformer architecture (~430M parameters) that was pre-trained on metagenomic contigs and complete microbial genomes, then fine-tuned on annotated CRISPR array sequences.
963
 
964
- **Method**: Input sequences are processed using a sliding window approach (1000 bp window, configurable stride). For each window, the model outputs a probability score ∈ [0,1] indicating the likelihood that the central region contains part of a CRISPR array. Overlapping predictions are aggregated to produce per-position scores across the full sequence length.
965
-
966
- **Output**: Detected CRISPR regions are reported with genomic coordinates, mean prediction scores, and can be exported in standard formats (GFF3, CSV) for downstream analysis.
967
  """)
968
 
969
  with gr.Tab("Prediction"):
970
  with gr.Row():
971
  with gr.Column(scale=1):
972
  seq_input = gr.Textbox(
973
- label="Input Sequence",
974
  placeholder="Paste DNA sequence (FASTA format accepted)...",
975
  lines=6,
976
  value=FLANKED_CRISPR_EXAMPLE,
977
- info="Minimum length: 1000 bp. Accepts raw sequence or FASTA format."
978
  )
979
  file_upload = gr.File(
980
- label="Upload FASTA File",
981
  file_types=[".fasta", ".fa", ".fna", ".txt"],
982
  type="filepath"
983
  )
984
  with gr.Row():
985
  stride_input = gr.Slider(
986
  minimum=50, maximum=500, value=100, step=50,
987
- label="Stride (bp)",
988
- info="Step size between consecutive windows. Lower values increase resolution but require more computation."
989
  )
990
  threshold_input = gr.Slider(
991
  minimum=0.1, maximum=0.9, value=0.3, step=0.05,
992
- label="Detection Threshold",
993
- info="Minimum score to classify a region as CRISPR. Lower = more sensitive, higher = more specific."
994
  )
995
  with gr.Row():
996
- predict_btn = gr.Button("Run Analysis", variant="primary", size="lg")
997
- gr.Markdown("**Example sequences:**")
998
  with gr.Row():
999
- gr.Button("Flanked CRISPR").click(
1000
  lambda: FLANKED_CRISPR_EXAMPLE, outputs=seq_input
1001
  )
1002
- gr.Button("E. coli K-12 CRISPR I-E").click(
1003
  lambda: ECOLI_CRISPR_EXAMPLE, outputs=seq_input
1004
  )
1005
  with gr.Row():
1006
- gr.Button("CRISPR Array").click(
1007
  lambda: CRISPR_EXAMPLE, outputs=seq_input
1008
  )
1009
- gr.Button("Negative Control").click(
1010
  lambda: NON_CRISPR_EXAMPLE, outputs=seq_input
1011
  )
1012
  result_summary = gr.Markdown()
1013
- with gr.Accordion("Export Results", open=False, visible=False) as download_accordion:
1014
- gr.Markdown("**Figures:**")
1015
  with gr.Row():
1016
- pred_download_png = gr.File(label="PNG", interactive=False)
1017
- pred_download_pdf = gr.File(label="PDF", interactive=False)
1018
- gr.Markdown("**Data:**")
1019
  with gr.Row():
1020
- pred_download_csv = gr.File(label="CSV", interactive=False)
1021
- pred_download_gff = gr.File(label="GFF3", interactive=False)
1022
  with gr.Row():
1023
- pred_download_summary = gr.File(label="Summary", interactive=False)
1024
  with gr.Column(scale=2):
1025
- plot_output = gr.Plot(label="Prediction Score Profile")
1026
- with gr.Accordion("Sequence Viewer", open=False, visible=False) as seq_viewer_accordion:
1027
- gr.Markdown("*Color scale: blue (low score) → yellow (medium) → red (high score). Hover over nucleotides for exact values.*")
1028
- seq_viewer_html = gr.HTML(label="Color-coded sequence")
1029
  regions_output = gr.JSON(label="Detected Regions", visible=False)
1030
 
1031
  # Handle file upload - load content into textbox
@@ -1056,52 +1236,49 @@ A deep learning approach for identifying CRISPR arrays in prokaryotic genome seq
1056
 
1057
  with gr.Tab("Embeddings"):
1058
  gr.Markdown("""
1059
- ### Hidden State Analysis
1060
-
1061
- Extract and visualize the model's internal representations (embeddings) from the transformer layers. The **State-Dynamics** mode applies UMAP dimensionality reduction to project the 768-dimensional embeddings into 2D/3D space, then performs agglomerative clustering to identify regions with similar activation patterns.
1062
 
1063
- **Biological interpretation**: In CRISPR arrays, repeat sequences share conserved motifs and should cluster together, while unique spacer sequences form distinct clusters. This creates a characteristic alternating pattern in the sequence map visualization.
 
1064
  """)
1065
  with gr.Row():
1066
  with gr.Column(scale=1):
1067
  embed_seq = gr.Textbox(
1068
- label="Input Sequence",
1069
  placeholder="Paste DNA sequence...",
1070
  lines=6,
1071
  value=EMBEDDING_CRISPR_EXAMPLE,
1072
- info="Longer sequences (>2000 bp) provide better clustering resolution."
1073
  )
1074
  embed_mode = gr.Radio(
1075
  choices=["state-dynamics", "mean", "max", "trajectory"],
1076
  value="state-dynamics",
1077
- label="Visualization Mode",
1078
- info="state-dynamics: UMAP clustering | mean/max: pooled embedding | trajectory: per-window heatmap"
1079
  )
1080
  use_3d = gr.Checkbox(
1081
- label="3D UMAP Projection",
1082
  value=False,
1083
- info="Project embeddings to 3D space (interactive rotation)",
1084
  visible=True
1085
  )
1086
  with gr.Row():
1087
- embed_btn = gr.Button("Extract Embeddings", variant="primary")
1088
  with gr.Row():
1089
- gr.Button("CRISPR Example (3kb)").click(
1090
  lambda: EMBEDDING_CRISPR_EXAMPLE, outputs=embed_seq
1091
  )
1092
- gr.Button("Control Sequence (3kb)").click(
1093
  lambda: EMBEDDING_RANDOM_EXAMPLE, outputs=embed_seq
1094
  )
1095
- gr.Markdown("""
1096
- **Example structure:** 600 bp upstream | CRISPR array (25 repeats + 24 spacers) | 600 bp downstream
1097
- """)
1098
  embed_summary = gr.Markdown()
1099
- with gr.Accordion("Export Results", open=False, visible=False) as embed_download_accordion:
1100
  with gr.Row():
1101
- download_png = gr.File(label="PNG", interactive=False)
1102
- download_pdf = gr.File(label="PDF", interactive=False)
1103
  with gr.Column(scale=2):
1104
- embed_plot = gr.Plot(label="Embedding Visualization")
1105
 
1106
  # Show/hide 3D checkbox based on mode
1107
  embed_mode.change(
@@ -1122,119 +1299,64 @@ Extract and visualize the model's internal representations (embeddings) from the
1122
 
1123
  with gr.Tab("API"):
1124
  gr.Markdown("""
1125
- ### Programmatic Access
1126
-
1127
- This tool can be accessed programmatically using the Gradio Python client or via HTTP requests.
1128
-
1129
- #### Python Client
1130
 
1131
  ```python
1132
  from gradio_client import Client
1133
 
1134
- # Connect to the API
1135
  client = Client("genomenet/crispr-array-detection")
1136
 
1137
- # Run prediction
1138
  result = client.predict(
1139
- sequence="ATGC...", # DNA sequence (min 1000 bp)
1140
- stride=100, # Window stride in bp
1141
- threshold=0.3, # Detection threshold
1142
  api_name="/predict"
1143
  )
1144
 
1145
- # result contains: (plot, summary, regions, png_path, pdf_path, csv_path, summary_path, gff_path, seq_viewer_html)
1146
- ```
1147
-
1148
- #### Extract Embeddings
1149
-
1150
- ```python
1151
  result = client.predict(
1152
  sequence="ATGC...",
1153
- mode="state-dynamics", # or "mean", "max", "trajectory"
1154
  use_3d=False,
1155
  api_name="/get_embedding"
1156
  )
1157
  ```
1158
 
1159
- #### cURL Example
1160
-
1161
- ```bash
1162
- curl -X POST "https://genomenet-crispr-array-detection.hf.space/api/predict" \\
1163
- -H "Content-Type: application/json" \\
1164
- -d '{"data": ["ATGCATGC...", 100, 0.3]}'
1165
- ```
1166
-
1167
- #### Output Formats
1168
-
1169
- | Format | Description |
1170
- |--------|-------------|
1171
- | CSV | Per-position scores: `position, probability, above_threshold` |
1172
- | GFF3 | Standard genome annotation format for detected regions |
1173
- | TXT | Human-readable summary with statistics |
1174
- | PNG/PDF | Publication-ready figures |
1175
-
1176
- #### Rate Limits
1177
-
1178
- - Free tier: Standard HuggingFace rate limits apply
1179
- - For high-throughput analysis, consider running the model locally
1180
-
1181
- #### Local Installation
1182
 
 
1183
  ```bash
1184
  git clone https://huggingface.co/spaces/genomenet/crispr-array-detection
1185
- cd crispr-array-detection
1186
- pip install -r requirements.txt
1187
- python app.py
1188
  ```
1189
  """)
1190
 
1191
  with gr.Tab("About"):
1192
  gr.Markdown("""
1193
- ### Model Architecture
1194
-
1195
- | Component | Specification |
1196
- |-----------|--------------|
1197
- | Base model | BERT (Bidirectional Encoder Representations from Transformers) |
1198
- | Layers | 24 transformer blocks |
1199
- | Hidden size | 768 dimensions |
1200
- | Attention heads | 12 |
1201
- | Parameters | ~430 million |
1202
- | Classification head | Bottleneck architecture |
1203
-
1204
- ### Training
1205
-
1206
- **Pre-training corpus**: Metagenomic contigs and complete microbial genomes from public databases.
1207
 
1208
- **Fine-tuning data**: Annotated CRISPR arrays from bacterial and archaeal genomes, including positive examples from CRISPRCasdb and negative examples from non-CRISPR genomic regions.
 
 
 
 
 
1209
 
1210
- **Embedding extraction**: Hidden states are extracted from transformer layer 21 (768 dimensions per position).
1211
 
1212
- ### Parameters
 
 
 
1213
 
1214
- | Parameter | Range | Default | Description |
1215
- |-----------|-------|---------|-------------|
1216
- | Stride | 50-500 bp | 100 bp | Step size between windows. Lower = higher resolution, more computation |
1217
- | Threshold | 0.1-0.9 | 0.3 | Detection cutoff. Lower = more sensitive, higher = more specific |
1218
- | Window size | Fixed | 1000 bp | Input window for the transformer model |
1219
 
1220
- ### Performance Considerations
1221
 
1222
- - **GPU recommended**: T4 or better for interactive use
1223
- - **CPU inference**: Functional but slower (~10-30s per analysis)
1224
- - **Memory**: ~2GB GPU memory required
1225
 
1226
- ### Citation
1227
-
1228
- If you use this tool in your research, please cite:
1229
-
1230
- > Mu, Z. (2024). Deep Learning-Based CRISPR Array Detection. Master's Thesis, Helmholtz Centre for Infection Research.
1231
-
1232
- ### Acknowledgements
1233
-
1234
- - Ziyu Mu - Model development (Master's Thesis, HZI BIFO)
1235
- - DFG SPP 2141 "Much more than Defence" (Project MC 172)
1236
- - BMBF de.NBI / GenomeNet
1237
- - Helmholtz Centre for Infection Research (HZI)
1238
  """)
1239
 
1240
 
@@ -1246,6 +1368,12 @@ if __name__ == "__main__":
1246
  demo.launch(
1247
  server_name="0.0.0.0",
1248
  server_port=7860,
1249
- theme=gr.themes.Soft(),
 
 
 
 
 
 
1250
  css=CUSTOM_CSS
1251
  )
 
3
  """
4
 
5
  import os
6
+ import html
7
+ import tempfile
8
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9
+ os.environ.setdefault("MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "matplotlib"))
10
 
11
  import gradio as gr
12
  import numpy as np
 
26
  from inference.tokenizer import validate_sequence, strip_fasta_header
27
  from inference.inference import detect_crispr_regions
28
 
29
+ MAX_SEQUENCE_LENGTH = int(os.environ.get("MAX_SEQUENCE_LENGTH", "50000"))
30
+ MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", str(2 * 1024 * 1024)))
31
+ MAX_SEQUENCE_VIEWER_LENGTH = int(os.environ.get("MAX_SEQUENCE_VIEWER_LENGTH", "20000"))
32
+ QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "8"))
33
+
34
+ # Custom CSS - Minimal monochrome design with Geist fonts
35
  CUSTOM_CSS = """
36
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap');
37
+ @font-face {
38
+ font-family: 'Geist Mono';
39
+ src: url('https://cdn.jsdelivr.net/npm/geist@1.2.0/dist/fonts/geist-mono/GeistMono-Regular.woff2') format('woff2');
40
+ font-weight: 400;
41
+ }
42
+ @font-face {
43
+ font-family: 'Geist Mono';
44
+ src: url('https://cdn.jsdelivr.net/npm/geist@1.2.0/dist/fonts/geist-mono/GeistMono-Medium.woff2') format('woff2');
45
+ font-weight: 500;
46
+ }
47
 
48
  * {
49
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, system-ui, sans-serif !important;
50
  }
51
 
52
+ code, pre, .code, textarea, .prose code {
53
+ font-family: 'Geist Mono', 'SF Mono', Consolas, monospace !important;
54
  }
55
 
56
+ h1 {
57
+ font-weight: 500 !important;
58
+ letter-spacing: -0.02em !important;
59
+ }
60
+
61
+ h2, h3, h4 {
62
+ font-weight: 500 !important;
63
+ color: #18181b !important;
64
  }
65
 
66
  .gradio-container {
67
  max-width: 1200px !important;
68
+ background: #fafafa !important;
69
+ }
70
+
71
+ .gr-button-primary {
72
+ background: #18181b !important;
73
+ border: none !important;
74
+ }
75
+
76
+ .gr-button-primary:hover {
77
+ background: #27272a !important;
78
+ }
79
+
80
+ .gr-button-secondary {
81
+ background: #fff !important;
82
+ border: 1px solid #e4e4e7 !important;
83
+ color: #18181b !important;
84
+ }
85
+
86
+ .gr-panel {
87
+ border: 1px solid #e4e4e7 !important;
88
+ background: #fff !important;
89
+ }
90
+
91
+ /* Minimal table styling */
92
+ table {
93
+ border-collapse: collapse !important;
94
+ }
95
+
96
+ th, td {
97
+ border-bottom: 1px solid #e4e4e7 !important;
98
+ padding: 8px 12px !important;
99
+ }
100
+
101
+ th {
102
+ font-weight: 500 !important;
103
+ text-transform: uppercase !important;
104
+ font-size: 11px !important;
105
+ letter-spacing: 0.05em !important;
106
+ color: #71717a !important;
107
+ }
108
+
109
+ /* Slider styling */
110
+ input[type="range"] {
111
+ accent-color: #18181b !important;
112
+ }
113
+
114
+ /* Tab styling */
115
+ .tab-nav button {
116
+ font-weight: 400 !important;
117
+ color: #52525b !important;
118
+ }
119
+
120
+ .tab-nav button.selected {
121
+ color: #18181b !important;
122
+ border-bottom: 2px solid #18181b !important;
123
  }
124
  """
125
 
 
147
  EMBEDDING_RANDOM_EXAMPLE = """ATGCGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCT
AGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCT"""
148
 
149
 
150
+ def _count_fasta_records(text: str) -> int:
151
+ return sum(1 for line in text.splitlines() if line.strip().startswith(">"))
152
+
153
+
154
def normalize_sequence_input(sequence: str) -> tuple[bool, str, str]:
    """Clean and validate a single-sequence FASTA/raw DNA input.

    Args:
        sequence: Raw user input; ``None`` and whitespace-only input are
            treated as empty.

    Returns:
        A ``(is_valid, cleaned_sequence, error_message)`` triple. On success
        the error message is ``""``. On failure the cleaned sequence is ``""``
        when the input was rejected before header stripping, otherwise it is
        the stripped sequence that failed validation.
    """
    raw = "" if sequence is None else str(sequence).strip()
    if not raw:
        return False, "", "Sequence is empty"

    # Only one record may be submitted per request.
    if _count_fasta_records(raw) > 1:
        return False, "", "Multi-FASTA input is not supported. Please submit one sequence at a time."

    dna = strip_fasta_header(raw)
    ok, reason = validate_sequence(dna)
    if not ok:
        return False, dna, reason

    n_bases = len(dna)
    if n_bases > MAX_SEQUENCE_LENGTH:
        too_long = f"Sequence too long: {n_bases:,} bp > {MAX_SEQUENCE_LENGTH:,} bp limit"
        return False, dna, too_long

    return True, dna, ""
179
+
180
+
181
def validate_stride(stride) -> tuple[bool, int, str]:
    """Check that *stride* is a whole number in the inclusive range 50-500.

    Returns:
        ``(is_valid, stride_as_int, error_message)``. When the value cannot be
        interpreted as an integer the middle element is ``0``; when it is an
        integer outside the range, the parsed integer is returned alongside
        the error.
    """
    not_an_integer = (False, 0, "Stride must be an integer between 50 and 500 bp")

    # bool is an int subclass; it is never a meaningful stride.
    if isinstance(stride, bool):
        return not_an_integer
    # Refuse fractional (and non-finite) floats instead of truncating them.
    if isinstance(stride, float) and not stride.is_integer():
        return not_an_integer

    try:
        step = int(stride)
    except (TypeError, ValueError):
        return not_an_integer

    if step < 50 or step > 500:
        return False, step, "Stride must be between 50 and 500 bp"
    return True, step, ""
194
+
195
+
196
def validate_threshold(threshold) -> tuple[bool, float, str]:
    """Check that *threshold* is a number in the closed interval [0, 1].

    Args:
        threshold: Candidate detection cutoff; anything ``float()`` accepts
            (numbers, numeric strings), except booleans.

    Returns:
        ``(is_valid, threshold_as_float, error_message)``. When the value is
        not numeric the middle element is ``0.0``; when it is numeric but out
        of range (NaN included), the converted float is returned alongside
        the error.
    """
    not_a_number = (False, 0.0, "Threshold must be a number between 0 and 1")

    # bool is an int subclass, so float(True) would silently become 1.0;
    # reject it explicitly, as validate_stride does.
    if isinstance(threshold, bool):
        return not_a_number

    try:
        threshold = float(threshold)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: integers too large for a float (e.g. 10**400).
        return not_a_number

    # The chained comparison is False for NaN, so NaN is rejected here too.
    if not 0.0 <= threshold <= 1.0:
        return False, threshold, "Threshold must be between 0 and 1"
    return True, threshold, ""
205
+
206
+
207
def validate_min_length(min_length) -> tuple[bool, int, str]:
    """Check that *min_length* is a whole number of at least 1.

    Args:
        min_length: Candidate minimum region length; ints, integral floats
            and integer strings are accepted, booleans are not.

    Returns:
        ``(is_valid, min_length_as_int, error_message)``. When the value
        cannot be interpreted as an integer the middle element is ``0``; when
        it is an integer below 1, the parsed integer is returned alongside
        the error.
    """
    not_an_integer = (False, 0, "Minimum region length must be an integer")

    # bool is an int subclass, so int(True) would silently become 1;
    # reject it explicitly, as validate_stride does.
    if isinstance(min_length, bool):
        return not_an_integer

    try:
        # Refuse fractional (and non-finite) floats instead of truncating them.
        if isinstance(min_length, float) and not min_length.is_integer():
            raise ValueError
        min_length = int(min_length)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: e.g. int(Decimal("Infinity")).
        return not_an_integer

    if min_length < 1:
        return False, min_length, "Minimum region length must be at least 1 bp"
    return True, min_length, ""
218
+
219
+
220
def prediction_error_outputs(message: str):
    """Build the nine-element output tuple returned when prediction fails.

    The second element is the Markdown-formatted error text, the third an
    empty list, the last an empty string, and every other slot is ``None``.
    NOTE(review): the slots presumably line up with the prediction handler's
    Gradio outputs; confirm against the ``.click`` wiring.
    """
    error_markdown = f"**Error**: {message}"
    return (None, error_markdown, [], None, None, None, None, None, "")
222
+
223
+
224
def embedding_error_outputs(message: str):
    """Build the four-element output tuple returned when embedding extraction fails.

    Only the second element carries content (the Markdown-formatted error
    text); the remaining slots are ``None``.
    NOTE(review): the slots presumably line up with the embedding handler's
    Gradio outputs; confirm against the ``.click`` wiring.
    """
    error_markdown = f"**Error**: {message}"
    return (None, error_markdown, None, None)
226
+
227
+
228
def make_output_dir(prefix: str) -> str:
    """Create a fresh temporary directory and return its path.

    The directory name starts with ``<prefix>_``; the caller is responsible
    for removing it.
    """
    dir_prefix = "{}_".format(prefix)
    return tempfile.mkdtemp(prefix=dir_prefix)
230
+
231
+
232
def symmetric_activation_norm(values) -> TwoSlopeNorm:
    """Return a ``TwoSlopeNorm`` centred on 0 with symmetric limits.

    The limit is the largest absolute finite value in *values*. When there
    are no finite values, or they are all zero, a limit of 1.0 is used so the
    norm always has a non-zero span.
    """
    data = np.asarray(values, dtype=float)
    finite_values = data[np.isfinite(data)]

    # Largest magnitude among the finite entries; unit span if there are none.
    limit = float(np.abs(finite_values).max()) if finite_values.size else 1.0
    if limit <= 0:
        limit = 1.0

    return TwoSlopeNorm(vmin=-limit, vcenter=0, vmax=limit)
242
+
243
+
244
  def create_prediction_plot(positions, probabilities, threshold=0.3, regions=None):
245
  """Create a matplotlib figure showing the prediction curve (for PNG/PDF export)."""
246
  fig, ax = plt.subplots(figsize=(12, 4))
 
262
  ax.set_ylabel('CRISPR Probability')
263
  ax.set_title('CRISPR Array Detection Score')
264
  ax.set_ylim(0, 1)
265
+ ax.set_xlim(min(positions) if positions else 1, max(positions) if positions else 1000)
266
  ax.legend(loc='upper right')
267
  ax.grid(True, alpha=0.3)
268
 
 
274
  """Create an interactive Plotly figure showing the prediction curve with minimap."""
275
  fig = go.Figure()
276
 
277
+ min_pos = min(positions) if positions else 1
278
  max_pos = max(positions) if positions else 1000
279
 
280
+ # Main probability curve with fill - monochrome
281
  fig.add_trace(go.Scatter(
282
  x=positions,
283
  y=probabilities,
284
  mode='lines',
285
+ name='Score',
286
+ line=dict(color='#18181b', width=1.5),
287
  fill='tozeroy',
288
+ fillcolor='rgba(24, 24, 27, 0.08)',
289
  hovertemplate='Position: %{x:,} bp<br>Score: %{y:.3f}<extra></extra>'
290
  ))
291
 
292
+ # Add threshold line - dashed gray
293
  fig.add_hline(
294
  y=threshold,
295
  line_dash="dash",
296
+ line_color="#71717a",
297
+ annotation_text=f"threshold={threshold}",
298
  annotation_position="top right",
299
+ annotation_font_size=10,
300
+ annotation_font_color="#71717a"
301
  )
302
 
303
+ # Highlight detected CRISPR regions - subtle gray
304
  if regions:
305
  for r in regions:
306
  fig.add_vrect(
307
  x0=r['start'], x1=r['end'],
308
+ fillcolor="rgba(24, 24, 27, 0.06)",
309
  layer="below",
310
  line_width=1,
311
+ line_color="rgba(24, 24, 27, 0.2)",
312
+ annotation_text=f"#{r['region_id']}",
313
  annotation_position="top left",
314
+ annotation_font_size=9,
315
+ annotation_font_color="#52525b"
316
  )
317
 
318
  fig.update_layout(
319
+ title=None,
 
 
 
 
 
320
  xaxis=dict(
321
+ title=dict(text='Position (bp)', font=dict(size=11, color='#52525b')),
322
+ range=[min_pos, max_pos],
323
+ gridcolor='#f4f4f5',
324
  showgrid=True,
325
  zeroline=False,
326
+ linecolor='#e4e4e7',
327
+ tickfont=dict(size=10, color='#71717a'),
328
  rangeslider=dict(
329
  visible=True,
330
+ thickness=0.06,
331
+ bgcolor='#fafafa',
332
+ bordercolor='#e4e4e7',
333
  borderwidth=1
334
  ),
 
335
  rangeselector=dict(
336
  buttons=list([
337
  dict(count=500, label="500bp", step="all", stepmode="backward"),
338
  dict(count=1000, label="1kb", step="all", stepmode="backward"),
339
  dict(count=5000, label="5kb", step="all", stepmode="backward"),
340
+ dict(step="all", label="all")
341
  ]),
342
+ bgcolor='#fafafa',
343
+ bordercolor='#e4e4e7',
344
+ activecolor='#e4e4e7',
345
+ font=dict(size=9, color='#52525b'),
346
  x=0,
347
+ y=1.12
348
  )
349
  ),
350
  yaxis=dict(
351
+ title=dict(text='Score', font=dict(size=11, color='#52525b')),
352
  range=[0, 1.05],
353
+ gridcolor='#f4f4f5',
354
  showgrid=True,
355
  zeroline=False,
356
+ linecolor='#e4e4e7',
357
+ tickfont=dict(size=10, color='#71717a'),
358
  tickformat='.1f'
359
  ),
360
  hovermode='x unified',
361
+ showlegend=False,
362
+ height=420,
363
+ plot_bgcolor='#fafafa',
364
+ paper_bgcolor='#fafafa',
365
+ margin=dict(t=50, b=60, l=50, r=20),
366
+ font=dict(family='Inter, system-ui, sans-serif')
 
 
 
 
 
 
367
  )
368
 
369
  return fig
 
388
  # Create figure
389
  fig, ax = plt.subplots(figsize=(14, max(3, rows * 0.25)))
390
 
391
+ # Use diverging colormap centered at 0; constant embeddings need a non-zero span.
392
+ norm = symmetric_activation_norm(embedding)
 
393
 
394
  im = ax.imshow(grid, cmap='RdBu_r', norm=norm, aspect='auto')
395
 
 
428
 
429
  fig, ax = plt.subplots(figsize=(14, max(4, n_windows * 0.3)))
430
 
431
+ # Use diverging colormap; constant embeddings need a non-zero span.
432
+ norm = symmetric_activation_norm(embeddings)
 
433
 
434
  im = ax.imshow(embeddings, cmap='RdBu_r', norm=norm, aspect='auto')
435
 
 
607
 
608
  def create_interactive_state_plot(embeddings, n_clusters=8, stride=100, use_3d=False):
609
  """
610
+ Create interactive Plotly State-Dynamic Plot with 2D or 3D UMAP - monochrome style.
611
  """
612
  embeddings = np.array(embeddings)
613
  n_windows, n_dims = embeddings.shape
614
 
615
  if n_windows < 5:
 
616
  fig = go.Figure()
617
  fig.add_annotation(text="Need longer sequence (minimum ~1500 bp)",
618
  xref="paper", yref="paper", x=0.5, y=0.5,
619
+ showarrow=False, font=dict(size=14, color='#71717a'))
620
+ fig.update_layout(plot_bgcolor='#fafafa', paper_bgcolor='#fafafa')
621
  return fig
622
 
623
  # UMAP reduction
 
642
  hover_text = [f"Window {i}<br>Position: {pos}-{pos+1000} bp<br>Cluster: {c}"
643
  for i, (pos, c) in enumerate(zip(positions, cluster_labels))]
644
 
645
+ # Monochrome grayscale palette for clusters
646
+ grays = [f'rgba({int(40 + i * 180 / n_clusters)}, {int(40 + i * 180 / n_clusters)}, {int(40 + i * 180 / n_clusters)}, 0.8)'
647
+ for i in range(n_clusters)]
648
 
649
  if use_3d:
 
650
  fig = go.Figure()
651
 
652
+ # Trajectory line
653
  fig.add_trace(go.Scatter3d(
654
  x=embedding_reduced[:, 0],
655
  y=embedding_reduced[:, 1],
656
  z=embedding_reduced[:, 2],
657
  mode='lines',
658
+ line=dict(color='rgba(113,113,122,0.3)', width=2),
659
  name='Trajectory',
660
  hoverinfo='skip'
661
  ))
662
 
663
+ # Points - grayscale colorscale
664
  fig.add_trace(go.Scatter3d(
665
  x=embedding_reduced[:, 0],
666
  y=embedding_reduced[:, 1],
667
  z=embedding_reduced[:, 2],
668
  mode='markers',
669
  marker=dict(
670
+ size=5,
671
  color=cluster_labels,
672
+ colorscale='Greys',
673
+ opacity=0.85,
674
+ line=dict(width=0.5, color='white')
675
  ),
676
  text=hover_text,
677
  hovertemplate='%{text}<extra></extra>',
678
  name='Windows'
679
  ))
680
 
681
+ # Start marker - dark
682
  fig.add_trace(go.Scatter3d(
683
  x=[embedding_reduced[0, 0]],
684
  y=[embedding_reduced[0, 1]],
685
  z=[embedding_reduced[0, 2]],
686
  mode='markers',
687
+ marker=dict(size=10, color='#18181b', symbol='diamond'),
688
+ name="5' start"
689
  ))
690
+ # End marker - medium gray
691
  fig.add_trace(go.Scatter3d(
692
  x=[embedding_reduced[-1, 0]],
693
  y=[embedding_reduced[-1, 1]],
694
  z=[embedding_reduced[-1, 2]],
695
  mode='markers',
696
+ marker=dict(size=10, color='#71717a', symbol='square'),
697
+ name="3' end"
698
  ))
699
 
700
  fig.update_layout(
701
+ title=None,
702
  scene=dict(
703
+ xaxis=dict(title='UMAP 1', gridcolor='#e4e4e7', backgroundcolor='#fafafa'),
704
+ yaxis=dict(title='UMAP 2', gridcolor='#e4e4e7', backgroundcolor='#fafafa'),
705
+ zaxis=dict(title='UMAP 3', gridcolor='#e4e4e7', backgroundcolor='#fafafa'),
706
  ),
707
+ height=550,
708
+ showlegend=True,
709
+ legend=dict(font=dict(size=10), bgcolor='rgba(250,250,250,0.9)'),
710
+ plot_bgcolor='#fafafa',
711
+ paper_bgcolor='#fafafa',
712
+ font=dict(family='Inter, system-ui, sans-serif', color='#52525b')
713
  )
714
 
715
  else:
 
718
  rows=2, cols=2,
719
  specs=[[{"type": "scatter"}, {"type": "scatter"}],
720
  [{"type": "scatter", "colspan": 2}, None]],
721
+ subplot_titles=('by cluster', 'by position', 'sequence map'),
722
  row_heights=[0.6, 0.4],
723
  vertical_spacing=0.12
724
  )
 
728
  x=embedding_reduced[:, 0],
729
  y=embedding_reduced[:, 1],
730
  mode='lines',
731
+ line=dict(color='rgba(113,113,122,0.15)', width=1),
732
  hoverinfo='skip',
733
  showlegend=False
734
  ), row=1, col=1)
 
739
  x=embedding_reduced[mask, 0],
740
  y=embedding_reduced[mask, 1],
741
  mode='markers',
742
+ marker=dict(size=7, color=grays[c],
743
+ line=dict(width=0.5, color='white')),
744
  text=[hover_text[i] for i in np.where(mask)[0]],
745
  hovertemplate='%{text}<extra></extra>',
746
+ name=f'{c}',
747
  legendgroup=f'c{c}'
748
  ), row=1, col=1)
749
 
750
  # Start/End markers
751
  fig.add_trace(go.Scatter(
752
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
753
+ mode='markers', marker=dict(size=12, color='#18181b', symbol='triangle-up',
754
+ line=dict(width=1, color='white')),
755
+ name="5'", showlegend=True
756
  ), row=1, col=1)
757
  fig.add_trace(go.Scatter(
758
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
759
+ mode='markers', marker=dict(size=12, color='#71717a', symbol='square',
760
+ line=dict(width=1, color='white')),
761
+ name="3'", showlegend=True
762
  ), row=1, col=1)
763
 
764
+ # Right plot: by position - grayscale gradient
765
  fig.add_trace(go.Scatter(
766
  x=embedding_reduced[:, 0],
767
  y=embedding_reduced[:, 1],
768
  mode='lines+markers',
769
+ line=dict(color='rgba(113,113,122,0.2)', width=1),
770
+ marker=dict(size=7, color=np.arange(n_windows), colorscale='Greys',
771
+ showscale=True, colorbar=dict(title=dict(text='window', font=dict(size=10)),
772
+ x=1.02, tickfont=dict(size=9))),
773
  text=hover_text,
774
  hovertemplate='%{text}<extra></extra>',
775
  showlegend=False
 
777
 
778
  fig.add_trace(go.Scatter(
779
  x=[embedding_reduced[0, 0]], y=[embedding_reduced[0, 1]],
780
+ mode='markers', marker=dict(size=12, color='#18181b', symbol='triangle-up',
781
+ line=dict(width=1, color='white')),
782
  showlegend=False
783
  ), row=1, col=2)
784
  fig.add_trace(go.Scatter(
785
  x=[embedding_reduced[-1, 0]], y=[embedding_reduced[-1, 1]],
786
+ mode='markers', marker=dict(size=12, color='#71717a', symbol='square',
787
+ line=dict(width=1, color='white')),
788
  showlegend=False
789
  ), row=1, col=2)
790
 
791
+ # Bottom: sequence map - grayscale blocks
792
  window_size = 1000
793
  for i, (cluster, pos) in enumerate(zip(cluster_labels, positions)):
794
  fig.add_trace(go.Scatter(
795
  x=[pos, pos + window_size, pos + window_size, pos, pos],
796
  y=[0, 0, 1, 1, 0],
797
  fill='toself',
798
+ fillcolor=grays[cluster],
799
  line=dict(width=0),
 
800
  hoverinfo='text',
801
  text=f'Position {pos}-{pos+window_size} bp<br>Cluster {cluster}',
802
  showlegend=False
803
  ), row=2, col=1)
804
 
805
+ fig.update_xaxes(title_text='UMAP 1', row=1, col=1, gridcolor='#f4f4f5',
806
+ tickfont=dict(size=9, color='#71717a'))
807
+ fig.update_yaxes(title_text='UMAP 2', row=1, col=1, gridcolor='#f4f4f5',
808
+ tickfont=dict(size=9, color='#71717a'))
809
+ fig.update_xaxes(title_text='UMAP 1', row=1, col=2, gridcolor='#f4f4f5',
810
+ tickfont=dict(size=9, color='#71717a'))
811
+ fig.update_yaxes(title_text='UMAP 2', row=1, col=2, gridcolor='#f4f4f5',
812
+ tickfont=dict(size=9, color='#71717a'))
813
+ fig.update_xaxes(title_text='position (bp)', row=2, col=1, gridcolor='#f4f4f5',
814
+ tickfont=dict(size=9, color='#71717a'))
815
  fig.update_yaxes(visible=False, row=2, col=1)
816
 
817
  fig.update_layout(
818
+ title=None,
819
+ height=650,
 
820
  showlegend=True,
821
+ legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1,
822
+ font=dict(size=9), bgcolor='rgba(250,250,250,0.9)'),
823
+ plot_bgcolor='#fafafa',
824
+ paper_bgcolor='#fafafa',
825
+ font=dict(family='Inter, system-ui, sans-serif', color='#52525b', size=11),
826
+ margin=dict(t=40, b=40)
827
  )
828
 
829
+ # Style subplot titles
830
+ for annotation in fig['layout']['annotations']:
831
+ annotation['font'] = dict(size=11, color='#52525b')
832
+
833
  return fig
834
 
835
 
 
837
  """Parse a FASTA file and return the sequence."""
838
  if file_path is None:
839
  return None
840
+ size = os.path.getsize(file_path)
841
+ if size > MAX_UPLOAD_BYTES:
842
+ raise gr.Error(f"Uploaded file is too large ({size:,} bytes > {MAX_UPLOAD_BYTES:,} byte limit).")
843
+
844
+ with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
845
  content = f.read()
846
+
847
+ is_valid, cleaned, error = normalize_sequence_input(content)
848
+ if not is_valid:
849
+ raise gr.Error(error)
850
+ return cleaned
851
 
852
 
853
+ def create_gff3_export(regions, sequence_length, sequence_id="input_sequence", output_dir=None):
854
  """Create GFF3 format annotation file for detected CRISPR regions."""
855
+ output_dir = output_dir or make_output_dir("crispr_export")
856
+ gff_path = os.path.join(output_dir, "crispr_regions.gff3")
857
 
858
  with open(gff_path, 'w') as f:
859
  # GFF3 header
 
863
  for r in regions:
864
  # GFF3 format: seqid source type start end score strand phase attributes
865
  attributes = f"ID=CRISPR_{r['region_id']};Name=CRISPR_array_{r['region_id']};score={r['mean_score']:.3f}"
866
+ f.write(f"{sequence_id}\tCRISPR-BERT\tCRISPR_array\t{r['start']}\t{r['end']}\t{r['mean_score']:.3f}\t.\t.\t{attributes}\n")
867
 
868
  return gff_path
869
 
870
 
871
  def create_sequence_viewer_html(sequence, positions, probabilities, threshold=0.3, chunk_size=100):
872
+ """Create an HTML visualization of the sequence with grayscale intensity scores."""
 
 
 
873
  seq_len = len(sequence)
874
+ if seq_len > MAX_SEQUENCE_VIEWER_LENGTH:
875
+ return (
876
+ '<div style="background: #fafafa; padding: 16px; border: 1px solid #e4e4e7;">'
877
+ f'Sequence viewer disabled for sequences longer than {MAX_SEQUENCE_VIEWER_LENGTH:,} bp '
878
+ f'(current sequence: {seq_len:,} bp). Use the plot and downloads for full results.'
879
+ '</div>'
880
+ )
881
+
882
+ per_base_scores = np.asarray(probabilities, dtype=float)
883
+ if len(per_base_scores) != seq_len:
884
+ per_base_scores = np.resize(per_base_scores, seq_len)
885
+
886
+ # Generate HTML - monochrome style
887
+ html_parts = ['<div style="font-family: \'Geist Mono\', \'SF Mono\', Consolas, monospace; font-size: 11px; line-height: 1.9; background: #fafafa; padding: 16px; border: 1px solid #e4e4e7; max-height: 400px; overflow-y: auto;">']
888
+ html_parts.append('<div style="margin-bottom: 12px; font-family: Inter, system-ui, sans-serif; font-size: 11px; color: #71717a;">')
889
+ html_parts.append('<span style="background: linear-gradient(to right, #fafafa, #18181b); padding: 3px 24px; border: 1px solid #e4e4e7; display: inline-block;">low → high</span>')
890
+ html_parts.append(f'<span style="margin-left: 12px;">threshold: {threshold}</span>')
 
891
  html_parts.append('</div>')
892
 
893
+ # Process sequence in chunks
894
  for chunk_start in range(0, seq_len, chunk_size):
895
  chunk_end = min(chunk_start + chunk_size, seq_len)
896
  chunk_seq = sequence[chunk_start:chunk_end]
897
  chunk_scores = per_base_scores[chunk_start:chunk_end]
898
 
899
  # Position marker
900
+ html_parts.append(f'<div><span style="color: #a1a1aa; width: 55px; display: inline-block; font-size: 10px;">{chunk_start+1:,}</span>')
901
 
902
  for i, (base, score) in enumerate(zip(chunk_seq, chunk_scores)):
903
+ # Grayscale intensity based on score
904
+ intensity = int(255 - score * 200) # Higher score = darker
905
+ color = f'rgb({intensity},{intensity},{intensity})'
906
+ bg_intensity = int(250 - score * 40)
907
+ bg_color = f'rgb({bg_intensity},{bg_intensity},{bg_intensity})'
908
+ font_weight = '600' if score >= threshold else '400'
909
+
910
+ safe_base = html.escape(base)
911
+ html_parts.append(f'<span style="color: {color}; background-color: {bg_color}; font-weight: {font_weight};" title="pos {chunk_start + i + 1}: {score:.3f}">{safe_base}</span>')
 
 
 
912
 
913
  html_parts.append('</div>')
914
 
 
1140
  # Build interface
1141
  with gr.Blocks(title="CRISPR Array Detection") as demo:
1142
  gr.Markdown("""
1143
+ # crispr-detect
 
 
1144
 
1145
+ BERT-based CRISPR array detection. 24-layer transformer (430M params) trained on metagenomic sequences.
1146
+ Sliding window analysis with per-position probability scores. Export to GFF3/CSV.
 
1147
  """)
1148
 
1149
  with gr.Tab("Prediction"):
1150
  with gr.Row():
1151
  with gr.Column(scale=1):
1152
  seq_input = gr.Textbox(
1153
+ label="sequence",
1154
  placeholder="Paste DNA sequence (FASTA format accepted)...",
1155
  lines=6,
1156
  value=FLANKED_CRISPR_EXAMPLE,
1157
+ info="min 1000 bp"
1158
  )
1159
  file_upload = gr.File(
1160
+ label="upload fasta",
1161
  file_types=[".fasta", ".fa", ".fna", ".txt"],
1162
  type="filepath"
1163
  )
1164
  with gr.Row():
1165
  stride_input = gr.Slider(
1166
  minimum=50, maximum=500, value=100, step=50,
1167
+ label="stride",
1168
+ info="lower = higher resolution"
1169
  )
1170
  threshold_input = gr.Slider(
1171
  minimum=0.1, maximum=0.9, value=0.3, step=0.05,
1172
+ label="threshold",
1173
+ info="lower = sensitive, higher = specific"
1174
  )
1175
  with gr.Row():
1176
+ predict_btn = gr.Button("run", variant="primary", size="lg")
1177
+ gr.Markdown("*examples:*")
1178
  with gr.Row():
1179
+ gr.Button("flanked", size="sm").click(
1180
  lambda: FLANKED_CRISPR_EXAMPLE, outputs=seq_input
1181
  )
1182
+ gr.Button("e.coli", size="sm").click(
1183
  lambda: ECOLI_CRISPR_EXAMPLE, outputs=seq_input
1184
  )
1185
  with gr.Row():
1186
+ gr.Button("crispr", size="sm").click(
1187
  lambda: CRISPR_EXAMPLE, outputs=seq_input
1188
  )
1189
+ gr.Button("control", size="sm").click(
1190
  lambda: NON_CRISPR_EXAMPLE, outputs=seq_input
1191
  )
1192
  result_summary = gr.Markdown()
1193
+ with gr.Accordion("export", open=False, visible=False) as download_accordion:
1194
+
1195
  with gr.Row():
1196
+ pred_download_png = gr.File(label="png", interactive=False)
1197
+ pred_download_pdf = gr.File(label="pdf", interactive=False)
1198
+
1199
  with gr.Row():
1200
+ pred_download_csv = gr.File(label="csv", interactive=False)
1201
+ pred_download_gff = gr.File(label="gff3", interactive=False)
1202
  with gr.Row():
1203
+ pred_download_summary = gr.File(label="summary", interactive=False)
1204
  with gr.Column(scale=2):
1205
+ plot_output = gr.Plot(label="prediction")
1206
+ with gr.Accordion("sequence", open=False, visible=False) as seq_viewer_accordion:
1207
+ gr.Markdown("*grayscale intensity = score. hover for values.*")
1208
+ seq_viewer_html = gr.HTML(label="sequence")
1209
  regions_output = gr.JSON(label="Detected Regions", visible=False)
1210
 
1211
  # Handle file upload - load content into textbox
 
1236
 
1237
  with gr.Tab("Embeddings"):
1238
  gr.Markdown("""
1239
+ ### embeddings
 
 
1240
 
1241
+ 768-dim hidden states from transformer layer 21. UMAP projection + agglomerative clustering.
1242
+ Repeats cluster together, spacers form distinct groups.
1243
  """)
1244
  with gr.Row():
1245
  with gr.Column(scale=1):
1246
  embed_seq = gr.Textbox(
1247
+ label="sequence",
1248
  placeholder="Paste DNA sequence...",
1249
  lines=6,
1250
  value=EMBEDDING_CRISPR_EXAMPLE,
1251
+ info="min ~2000 bp for clustering"
1252
  )
1253
  embed_mode = gr.Radio(
1254
  choices=["state-dynamics", "mean", "max", "trajectory"],
1255
  value="state-dynamics",
1256
+ label="mode",
1257
+ info=""
1258
  )
1259
  use_3d = gr.Checkbox(
1260
+ label="3D",
1261
  value=False,
1262
+ info="",
1263
  visible=True
1264
  )
1265
  with gr.Row():
1266
+ embed_btn = gr.Button("extract", variant="primary")
1267
  with gr.Row():
1268
+ gr.Button("crispr 3kb", size="sm").click(
1269
  lambda: EMBEDDING_CRISPR_EXAMPLE, outputs=embed_seq
1270
  )
1271
+ gr.Button("control 3kb", size="sm").click(
1272
  lambda: EMBEDDING_RANDOM_EXAMPLE, outputs=embed_seq
1273
  )
1274
+ gr.Markdown("*example: 600bp upstream | 25 repeats + 24 spacers | 600bp downstream*")
 
 
1275
  embed_summary = gr.Markdown()
1276
+ with gr.Accordion("export", open=False, visible=False) as embed_download_accordion:
1277
  with gr.Row():
1278
+ download_png = gr.File(label="png", interactive=False)
1279
+ download_pdf = gr.File(label="pdf", interactive=False)
1280
  with gr.Column(scale=2):
1281
+ embed_plot = gr.Plot(label="embedding")
1282
 
1283
  # Show/hide 3D checkbox based on mode
1284
  embed_mode.change(
 
1299
 
1300
  with gr.Tab("API"):
1301
  gr.Markdown("""
1302
+ ### api
 
 
 
 
1303
 
1304
  ```python
1305
  from gradio_client import Client
1306
 
 
1307
  client = Client("genomenet/crispr-array-detection")
1308
 
1309
+ # predict
1310
  result = client.predict(
1311
+ sequence="ATGC...",
1312
+ stride=100,
1313
+ threshold=0.3,
1314
  api_name="/predict"
1315
  )
1316
 
1317
+ # embeddings
 
 
 
 
 
1318
  result = client.predict(
1319
  sequence="ATGC...",
1320
+ mode="state-dynamics",
1321
  use_3d=False,
1322
  api_name="/get_embedding"
1323
  )
1324
  ```
1325
 
1326
+ **output formats**: CSV (scores), GFF3 (annotations), PNG/PDF (figures)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1327
 
1328
+ **local**:
1329
  ```bash
1330
  git clone https://huggingface.co/spaces/genomenet/crispr-array-detection
1331
+ pip install -r requirements.txt && python app.py
 
 
1332
  ```
1333
  """)
1334
 
1335
  with gr.Tab("About"):
1336
  gr.Markdown("""
1337
+ ### about
 
 
 
 
 
 
 
 
 
 
 
 
 
1338
 
1339
+ | | |
1340
+ |---|---|
1341
+ | architecture | BERT, 24 layers, 768 hidden, 12 heads, 430M params |
1342
+ | training | metagenomic contigs, microbial genomes, CRISPRCasdb |
1343
+ | window | 1000 bp |
1344
+ | embedding | layer 21 (768-dim) |
1345
 
1346
+ **parameters**
1347
 
1348
+ | param | default | range |
1349
+ |-------|---------|-------|
1350
+ | stride | 100 bp | 50-500 |
1351
+ | threshold | 0.3 | 0.1-0.9 |
1352
 
1353
+ **citation**
 
 
 
 
1354
 
1355
+ Mu, Z. (2024). Deep Learning-Based CRISPR Array Detection. Master's Thesis, HZI.
1356
 
1357
+ **acknowledgements**
 
 
1358
 
1359
+ DFG SPP 2141 (MC 172) / BMBF de.NBI GenomeNet / HZI BIFO
 
 
 
 
 
 
 
 
 
 
 
1360
  """)
1361
 
1362
 
 
1368
  demo.launch(
1369
  server_name="0.0.0.0",
1370
  server_port=7860,
1371
+ theme=gr.themes.Base(
1372
+ primary_hue=gr.themes.colors.zinc,
1373
+ secondary_hue=gr.themes.colors.zinc,
1374
+ neutral_hue=gr.themes.colors.zinc,
1375
+ font=gr.themes.GoogleFont("Inter"),
1376
+ font_mono=gr.themes.GoogleFont("Geist Mono"),
1377
+ ),
1378
  css=CUSTOM_CSS
1379
  )
inference/inference.py CHANGED
@@ -74,6 +74,9 @@ def predict_batch(
74
  Returns:
75
  Predictions of shape (N, window_size) with probabilities
76
  """
 
 
 
77
  expected_dtype = model.inputs[0].dtype
78
  windows = cast_for_model(windows, expected_dtype)
79
 
@@ -121,6 +124,9 @@ def aggregate_predictions(
121
  Returns:
122
  Per-position probability array of shape (seq_length,)
123
  """
 
 
 
124
  scores = np.zeros(seq_length, dtype=np.float32)
125
  counts = np.zeros(seq_length, dtype=np.int32)
126
 
@@ -162,8 +168,10 @@ def predict_sequence(
162
  Returns:
163
  PredictionResult with per-position probabilities
164
  """
165
- if model is None:
166
- model = get_model()
 
 
167
 
168
  # Tokenize sequence
169
  tokens = encode_sequence(sequence)
@@ -172,6 +180,9 @@ def predict_sequence(
172
  # Create sliding windows
173
  windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)
174
 
 
 
 
175
  logger.info(f"Processing sequence: {seq_length} bp, {len(windows)} windows (stride={stride})")
176
 
177
  # Run batched prediction
@@ -211,6 +222,9 @@ def embed_batch(
211
  Returns:
212
  Embeddings of shape (N, window_size, embed_dim) or (N, embed_dim)
213
  """
 
 
 
214
  expected_dtype = model.inputs[0].dtype
215
  windows = cast_for_model(windows, expected_dtype)
216
 
@@ -249,8 +263,10 @@ def embed_sequence(
249
  Returns:
250
  EmbeddingResult (for mean/cls/max) or TrajectoryResult (for trajectory)
251
  """
252
- if model is None:
253
- model = get_embedding_model()
 
 
254
 
255
  # Tokenize sequence
256
  tokens = encode_sequence(sequence)
@@ -259,6 +275,9 @@ def embed_sequence(
259
  # Create windows
260
  windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)
261
 
 
 
 
262
  logger.info(f"Extracting embeddings: {seq_length} bp, {len(windows)} windows")
263
 
264
  # Get embeddings (shape: N, window_size, embed_dim)
@@ -306,7 +325,8 @@ def detect_crispr_regions(
306
  min_length: int = 160,
307
  merge_gap: int = 80,
308
  stride: int = 100,
309
- model: Optional[tf.keras.Model] = None
 
310
  ) -> list[dict]:
311
  """
312
  Detect CRISPR array regions in a sequence.
@@ -322,8 +342,16 @@ def detect_crispr_regions(
322
  Returns:
323
  List of detected regions with coordinates and scores
324
  """
325
- # Get per-position predictions
326
- result = predict_sequence(sequence, stride=stride, model=model)
 
 
 
 
 
 
 
 
327
  scores = np.array(result.probabilities)
328
 
329
  # Threshold to binary mask
 
74
  Returns:
75
  Predictions of shape (N, window_size) with probabilities
76
  """
77
+ if batch_size <= 0:
78
+ raise ValueError("batch_size must be a positive integer")
79
+
80
  expected_dtype = model.inputs[0].dtype
81
  windows = cast_for_model(windows, expected_dtype)
82
 
 
124
  Returns:
125
  Per-position probability array of shape (seq_length,)
126
  """
127
+ if aggregation not in {"mean", "max"}:
128
+ raise ValueError("aggregation must be 'mean' or 'max'")
129
+
130
  scores = np.zeros(seq_length, dtype=np.float32)
131
  counts = np.zeros(seq_length, dtype=np.int32)
132
 
 
168
  Returns:
169
  PredictionResult with per-position probabilities
170
  """
171
+ if aggregation not in {"mean", "max"}:
172
+ raise ValueError("aggregation must be 'mean' or 'max'")
173
+ if batch_size <= 0:
174
+ raise ValueError("batch_size must be a positive integer")
175
 
176
  # Tokenize sequence
177
  tokens = encode_sequence(sequence)
 
180
  # Create sliding windows
181
  windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)
182
 
183
+ if model is None:
184
+ model = get_model()
185
+
186
  logger.info(f"Processing sequence: {seq_length} bp, {len(windows)} windows (stride={stride})")
187
 
188
  # Run batched prediction
 
222
  Returns:
223
  Embeddings of shape (N, window_size, embed_dim) or (N, embed_dim)
224
  """
225
+ if batch_size <= 0:
226
+ raise ValueError("batch_size must be a positive integer")
227
+
228
  expected_dtype = model.inputs[0].dtype
229
  windows = cast_for_model(windows, expected_dtype)
230
 
 
263
  Returns:
264
  EmbeddingResult (for mean/cls/max) or TrajectoryResult (for trajectory)
265
  """
266
+ if mode not in {"mean", "cls", "max", "trajectory"}:
267
+ raise ValueError("mode must be one of: mean, cls, max, trajectory")
268
+ if batch_size <= 0:
269
+ raise ValueError("batch_size must be a positive integer")
270
 
271
  # Tokenize sequence
272
  tokens = encode_sequence(sequence)
 
275
  # Create windows
276
  windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)
277
 
278
+ if model is None:
279
+ model = get_embedding_model()
280
+
281
  logger.info(f"Extracting embeddings: {seq_length} bp, {len(windows)} windows")
282
 
283
  # Get embeddings (shape: N, window_size, embed_dim)
 
325
  min_length: int = 160,
326
  merge_gap: int = 80,
327
  stride: int = 100,
328
+ model: Optional[tf.keras.Model] = None,
329
+ prediction_result: Optional[PredictionResult] = None
330
  ) -> list[dict]:
331
  """
332
  Detect CRISPR array regions in a sequence.
 
342
  Returns:
343
  List of detected regions with coordinates and scores
344
  """
345
+ if not 0.0 <= threshold <= 1.0:
346
+ raise ValueError("threshold must be between 0 and 1")
347
+ if min_length < 1:
348
+ raise ValueError("min_length must be at least 1")
349
+ if merge_gap < 0:
350
+ raise ValueError("merge_gap must be non-negative")
351
+
352
+ # Get per-position predictions, or reuse a caller-provided result to avoid
353
+ # running the model twice in UI flows that need both scores and regions.
354
+ result = prediction_result or predict_sequence(sequence, stride=stride, model=model)
355
  scores = np.array(result.probabilities)
356
 
357
  # Threshold to binary mask
inference/tokenizer.py CHANGED
@@ -11,7 +11,6 @@ Token mapping:
11
  """
12
 
13
  import numpy as np
14
- from typing import Union
15
 
16
  VOCAB_SIZE = 6
17
  WINDOW_SIZE = 1000
@@ -30,6 +29,22 @@ _LUT[ord("g")] = 3
30
  _LUT[ord("t")] = 4
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def encode_sequence(sequence: str) -> np.ndarray:
34
  """
35
  Convert DNA sequence string to integer token array.
@@ -43,7 +58,10 @@ def encode_sequence(sequence: str) -> np.ndarray:
43
  # Convert to uppercase for consistency
44
  seq_upper = sequence.upper()
45
  # Convert to bytes and apply lookup
46
- seq_bytes = np.frombuffer(seq_upper.encode("ascii"), dtype=np.uint8)
 
 
 
47
  return _LUT[seq_bytes]
48
 
49
 
@@ -69,7 +87,8 @@ def validate_sequence(sequence: str) -> tuple[bool, str]:
69
  invalid_chars = seq_chars - valid_chars
70
 
71
  if invalid_chars:
72
- return False, f"Invalid characters in sequence: {invalid_chars}"
 
73
 
74
  return True, ""
75
 
@@ -84,8 +103,13 @@ def strip_fasta_header(text: str) -> str:
84
  Returns:
85
  Sequence string with headers removed
86
  """
87
- lines = text.strip().split("\n")
88
- sequence_lines = [line.strip() for line in lines if not line.startswith(">")]
 
 
 
 
 
89
  return "".join(sequence_lines)
90
 
91
 
@@ -105,6 +129,8 @@ def create_windows(
105
  Returns:
106
  Tuple of (windows array, start positions array)
107
  """
 
 
108
  seq_len = len(tokens)
109
 
110
  if seq_len < window_size:
 
11
  """
12
 
13
  import numpy as np
 
14
 
15
  VOCAB_SIZE = 6
16
  WINDOW_SIZE = 1000
 
29
  _LUT[ord("t")] = 4
30
 
31
 
32
def _coerce_positive_int(name: str, value) -> int:
    """Coerce an int-like value to a strictly positive built-in ``int``.

    UI widgets and JSON APIs frequently deliver whole numbers as floats
    (``100.0``) or as NumPy scalars, so those are accepted as long as they
    represent an exact integer. Everything else is rejected instead of being
    silently truncated.

    Args:
        name: Parameter name used in the error message (e.g. ``"stride"``).
        value: Candidate value; a Python/NumPy integer or an integral
            Python/NumPy float.

    Returns:
        The value as a built-in ``int`` greater than zero.

    Raises:
        ValueError: If ``value`` is a bool, is not an integer or integral
            float (including NaN/inf and strings), or is not positive.
    """
    message = f"{name} must be a positive integer"

    # bool is a subclass of int; True/False are almost certainly a caller
    # mistake here, so reject them before the integer check.
    if isinstance(value, bool):
        raise ValueError(message)

    if isinstance(value, (int, np.integer)):
        parsed = int(value)
    elif isinstance(value, (float, np.floating)) and float(value).is_integer():
        # np.float64 subclasses float, but np.float32/np.float16 do not, so
        # check np.floating explicitly. NaN and +/-inf are never "integer".
        parsed = int(value)
    else:
        raise ValueError(message)

    if parsed <= 0:
        raise ValueError(message)
    return parsed
46
+
47
+
48
  def encode_sequence(sequence: str) -> np.ndarray:
49
  """
50
  Convert DNA sequence string to integer token array.
 
58
  # Convert to uppercase for consistency
59
  seq_upper = sequence.upper()
60
  # Convert to bytes and apply lookup
61
+ try:
62
+ seq_bytes = np.frombuffer(seq_upper.encode("ascii"), dtype=np.uint8)
63
+ except UnicodeEncodeError as exc:
64
+ raise ValueError("Sequence contains non-ASCII characters") from exc
65
  return _LUT[seq_bytes]
66
 
67
 
 
87
  invalid_chars = seq_chars - valid_chars
88
 
89
  if invalid_chars:
90
+ invalid = ", ".join(repr(c) for c in sorted(invalid_chars))
91
+ return False, f"Invalid characters in sequence: {invalid}"
92
 
93
  return True, ""
94
 
 
103
  Returns:
104
  Sequence string with headers removed
105
  """
106
+ lines = text.strip().splitlines()
107
+ sequence_lines = []
108
+ for line in lines:
109
+ line = line.strip()
110
+ if not line or line.startswith(">"):
111
+ continue
112
+ sequence_lines.append(line)
113
  return "".join(sequence_lines)
114
 
115
 
 
129
  Returns:
130
  Tuple of (windows array, start positions array)
131
  """
132
+ window_size = _coerce_positive_int("window_size", window_size)
133
+ stride = _coerce_positive_int("stride", stride)
134
  seq_len = len(tokens)
135
 
136
  if seq_len < window_size: