sochasticbackup committed
Commit 2997d61 · 1 Parent(s): e994268

initialised app
.gitignore ADDED
@@ -0,0 +1,216 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ # Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ # poetry.lock
+ # poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ # pdm.lock
+ # pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ # pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # Redis
+ *.rdb
+ *.aof
+ *.pid
+
+ # RabbitMQ
+ mnesia/
+ rabbitmq/
+ rabbitmq-data/
+
+ # ActiveMQ
+ activemq-data/
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ # .idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+
+ # Streamlit
+ .streamlit/secrets.toml
README.md CHANGED
@@ -1,14 +1,15 @@
  ---
- title: Evo App
- emoji: 😻
- colorFrom: indigo
- colorTo: red
+ title: Evo Model Interface
+ emoji: 🧬
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 5.49.1
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
- license: mit
- short_description: 4 Tasks of Evo
+ license: apache-2.0
+ python_version: 3.11
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check configuration
+ We'll verify that the model and space are configured correctly from a few properties in the README's YAML metadata.
app.py ADDED
@@ -0,0 +1,468 @@
+ """
+ Evo Model Web Interface
+ A simple Gradio app for testing Evo's predictive and generative capabilities.
+ """
+ import gradio as gr
+ import torch
+ import numpy as np
+ from evo import Evo
+ from evo.scoring import score_sequences
+ from evo.generation import generate
+ from typing import List, Tuple, Dict
+ import io
+
+
+ # Global model variables
+ model = None
+ tokenizer = None
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+
+ def load_model():
+     """Load Evo model once at startup."""
+     global model, tokenizer
+     if model is None:
+         print("Loading Evo model...")
+         evo_model = Evo('evo-1-8k-base')
+         model, tokenizer = evo_model.model, evo_model.tokenizer
+         model.to(device)
+         model.eval()
+         print("✓ Model loaded successfully")
+
+
+ # ============================================================================
+ # TASK 1: Function Prediction
+ # ============================================================================
+
+ def detect_sequence_type(seq: str) -> str:
+     """Detect if sequence is DNA, RNA, or protein."""
+     seq_upper = seq.upper()
+     if any(c in set('EFILPQZ') for c in seq_upper):
+         return 'protein'
+     if 'U' in seq_upper:
+         return 'RNA'
+     if all(c in set('ACGTN') for c in seq_upper):
+         return 'DNA'
+     return 'unknown'
+
+
+ def parse_fasta_text(text: str) -> List[Tuple[str, str]]:
+     """Parse FASTA format text into (id, sequence) tuples."""
+     sequences = []
+     current_id = None
+     current_seq = []
+
+     for line in text.strip().split('\n'):
+         line = line.strip()
+         if line.startswith('>'):
+             if current_id is not None:
+                 sequences.append((current_id, ''.join(current_seq)))
+             current_id = line[1:].split('|')[0].strip()
+             current_seq = []
+         else:
+             current_seq.append(line)
+
+     if current_id is not None:
+         sequences.append((current_id, ''.join(current_seq)))
+
+     return sequences
+
+
+ def predict_function(sequences_text: str, threshold: float) -> str:
+     """Predict sequence functionality."""
+     load_model()
+
+     if not sequences_text.strip():
+         return "⚠️ Please enter sequences in FASTA format or paste sequences directly."
+
+     # Parse input
+     if sequences_text.startswith('>'):
+         # FASTA format
+         seq_data = parse_fasta_text(sequences_text)
+     else:
+         # Single sequence
+         seq_data = [("sequence_1", sequences_text.strip().replace('\n', ''))]
+
+     if not seq_data:
+         return "⚠️ No valid sequences found."
+
+     # Score sequences
+     sequences = [seq for _, seq in seq_data]
+     scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)
+
+     # Format results
+     results = ["# Function Prediction Results\n"]
+     results.append(f"{'Sequence ID':<20} {'Type':<10} {'Score':<12} {'Prediction':<15} {'Length':<10}")
+     results.append("-" * 70)
+
+     for (seq_id, seq), score in zip(seq_data, scores):
+         seq_type = detect_sequence_type(seq)
+         prediction = "✓ Functional" if score > threshold else "✗ Non-functional"
+         results.append(f"{seq_id:<20} {seq_type:<10} {score:<12.4f} {prediction:<15} {len(seq):<10}")
+
+     results.append("\n" + "=" * 70)
+     results.append(f"Total sequences: {len(seq_data)}")
+     results.append(f"Functional: {sum(1 for s in scores if s > threshold)}")
+     results.append(f"Non-functional: {sum(1 for s in scores if s <= threshold)}")
+     results.append(f"Average score: {np.mean(scores):.4f}")
+
+     return "\n".join(results)
+
+
+ # ============================================================================
+ # TASK 2: Gene Essentiality
+ # ============================================================================
+
+ def predict_essentiality(genes_text: str) -> str:
+     """Predict gene essentiality."""
+     load_model()
+
+     if not genes_text.strip():
+         return "⚠️ Please enter gene sequences in FASTA format."
+
+     # Parse FASTA
+     if not genes_text.startswith('>'):
+         return "⚠️ Please use FASTA format: >gene_id|organism|function\\nATGC..."
+
+     gene_data = parse_fasta_text(genes_text)
+     if not gene_data:
+         return "⚠️ No valid genes found."
+
+     # Score genes
+     sequences = [seq for _, seq in gene_data]
+     scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)
+
+     # Calculate statistics
+     scores_mean = np.mean(scores)
+     scores_std = np.std(scores)
+
+     # Format results
+     results = ["# Gene Essentiality Prediction\n"]
+     results.append(f"{'Gene ID':<20} {'Z-Score':<10} {'Score':<12} {'Essentiality':<15} {'Confidence':<12}")
+     results.append("-" * 70)
+
+     essential_count = 0
+     for (gene_id, seq), score in zip(gene_data, scores):
+         z_score = (score - scores_mean) / scores_std if scores_std > 0 else 0
+
+         if z_score > 0.5:
+             essentiality = "✓ Essential"
+             confidence = "High" if z_score > 1.0 else "Medium"
+             essential_count += 1
+         elif z_score < -0.5:
+             essentiality = "✗ Non-essential"
+             confidence = "High" if z_score < -1.0 else "Medium"
+         else:
+             essentiality = "? Uncertain"
+             confidence = "Low"
+
+         results.append(f"{gene_id:<20} {z_score:<10.2f} {score:<12.4f} {essentiality:<15} {confidence:<12}")
+
+     results.append("\n" + "=" * 70)
+     results.append(f"Total genes: {len(gene_data)}")
+     results.append(f"Essential: {essential_count}")
+     results.append(f"Mean score: {scores_mean:.4f} (std: {scores_std:.4f})")
+
+     return "\n".join(results)
+
+
+ # ============================================================================
+ # TASK 3: CRISPR Generation
+ # ============================================================================
+
+ def generate_crispr(n_systems: int, cas_type: str, target_seq: str, cas_length: int) -> str:
+     """Generate CRISPR-Cas systems."""
+     load_model()
+
+     # Templates
+     cas9_start = 'ATGAACAAGAAC'
+     cas12_start = 'ATGAGCAAGCTG'
+
+     results = ["# CRISPR-Cas System Generation\n"]
+
+     cas_types = ['cas9', 'cas12'] if cas_type == 'Both' else [cas_type.lower()]
+
+     for i in range(n_systems):
+         current_cas = cas_types[i % len(cas_types)]
+         prompt = cas9_start if current_cas == 'cas9' else cas12_start
+
+         results.append(f"\n{'='*70}")
+         results.append(f"System {i+1}: {current_cas.upper()}")
+         results.append('='*70)
+
+         # Generate Cas protein
+         output_seqs, _ = generate(
+             [prompt],
+             model,
+             tokenizer,
+             n_tokens=cas_length,
+             temperature=0.8,
+             top_k=4,
+             device=device,
+             verbose=0
+         )
+         cas_protein = output_seqs[0]
+
+         # Generate gRNA spacer
+         if target_seq:
+             complement = {'A': 'U', 'T': 'A', 'G': 'C', 'C': 'G'}
+             spacer = ''.join(complement.get(b, 'N') for b in reversed(target_seq[:20]))
+         else:
+             spacer_seqs, _ = generate(['G'], model, tokenizer, n_tokens=19, temperature=0.7,
+                                       top_k=4, device=device, verbose=0)
+             spacer = spacer_seqs[0][:20].replace('T', 'U')
+
+         # PAM sequence
+         pam = 'NGG' if current_cas == 'cas9' else 'TTTN'
+
+         results.append(f"\n{current_cas.upper()} Protein ({len(cas_protein)} nt):")
+         results.append(f"{cas_protein[:80]}..." if len(cas_protein) > 80 else cas_protein)
+         results.append(f"\ngRNA Spacer: {spacer}")
+         results.append(f"PAM Sequence: {pam}")
+         if current_cas == 'cas9':
+             results.append(f"tracrRNA: AGCAUAGCAAGUUAAAAUAAGGCUAGUCCGU")
+
+     return "\n".join(results)
+
+
+ # ============================================================================
+ # TASK 4: Regulatory Design
+ # ============================================================================
+
+ def generate_spacer_simple(length: int) -> str:
+     """Generate a simple random spacer."""
+     bases = ['A', 'T', 'G', 'C']
+     return ''.join(np.random.choice(bases) for _ in range(length))
+
+
+ def design_regulatory(n_designs: int, expression_level: str) -> str:
+     """Design regulatory sequences."""
+     load_model()
+
+     # Templates
+     promoter_templates = {
+         'High': ('TTGACA', 'TATAAT'),
+         'Medium': ('TTGACT', 'TATACT'),
+         'Low': ('TTGCCA', 'TATGAT')
+     }
+
+     rbs_templates = {
+         'High': 'AGGAGGU',
+         'Medium': 'AGGAGG',
+         'Low': 'AGGA'
+     }
+
+     results = ["# Regulatory Sequences Design\n"]
+
+     levels = ['High', 'Medium', 'Low']
+
+     for i in range(n_designs):
+         if expression_level == 'Mixed':
+             level = levels[i % 3]
+         else:
+             level = expression_level
+
+         results.append(f"\n{'='*70}")
+         results.append(f"Design {i+1}: {level} Expression")
+         results.append('='*70)
+
+         # Get promoter boxes
+         box_35, box_10 = promoter_templates[level]
+
+         # Generate spacers
+         spacer_35_10 = generate_spacer_simple(17)
+         spacer_10_rbs = generate_spacer_simple(7)
+
+         # Get RBS
+         rbs = rbs_templates[level]
+
+         # Generate RBS-ATG spacer
+         spacer_rbs_atg = generate_spacer_simple(7)
+
+         # Assemble
+         promoter = box_35 + spacer_35_10 + box_10
+         full_region = promoter + spacer_10_rbs + rbs + spacer_rbs_atg + 'ATG'
+
+         gc_content = 100 * (full_region.count('G') + full_region.count('C')) / len(full_region)
+
+         results.append(f"\nComponents:")
+         results.append(f" -35 box: {box_35}")
+         results.append(f" -10 box: {box_10}")
+         results.append(f" RBS (Shine-Dalgarno): {rbs}")
+         results.append(f" Start codon: ATG")
+         results.append(f"\nFull Regulatory Region ({len(full_region)} bp, GC={gc_content:.1f}%):")
+         results.append(full_region)
+         results.append(f"\nPromoter only:")
+         results.append(promoter)
+
+     return "\n".join(results)
+
+
+ # ============================================================================
+ # Gradio Interface
+ # ============================================================================
+
+ def create_interface():
+     """Create the Gradio interface."""
+
+     with gr.Blocks(title="Evo Model Interface", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🧬 Evo Model Interface")
+         gr.Markdown("### Test Evo's predictive and generative capabilities")
+
+         with gr.Tabs():
+             # Task 1: Function Prediction
+             with gr.Tab("🔍 Function Prediction"):
+                 gr.Markdown("### Predict if sequences are functional")
+                 gr.Markdown("*Enter sequences in FASTA format or paste a single sequence*")
+
+                 with gr.Row():
+                     with gr.Column():
+                         func_input = gr.Textbox(
+                             label="Input Sequences",
+                             placeholder=">seq1|description\nATCGATCGATCG...\n\nOr paste a single sequence directly",
+                             lines=8
+                         )
+                         func_threshold = gr.Slider(
+                             minimum=-3.0,
+                             maximum=0.0,
+                             value=-1.5,
+                             step=0.1,
+                             label="Functionality Threshold"
+                         )
+                         func_btn = gr.Button("Predict Function", variant="primary")
+
+                     with gr.Column():
+                         func_output = gr.Textbox(
+                             label="Results",
+                             lines=15,
+                             show_copy_button=True
+                         )
+
+                 func_btn.click(
+                     fn=predict_function,
+                     inputs=[func_input, func_threshold],
+                     outputs=func_output
+                 )
+
+                 gr.Examples(
+                     examples=[
+                         [">functional_gene\nATGGCACAACCCGCGCCGAACTGGTTGACCTGAAAACCACCGCCGCACTGCGTCAGGCCAGCCAGGCGGAACAA", -1.5],
+                         [">noncoding\nGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC", -1.5],
+                     ],
+                     inputs=[func_input, func_threshold]
+                 )
+
+             # Task 2: Gene Essentiality
+             with gr.Tab("🧬 Gene Essentiality"):
+                 gr.Markdown("### Predict essential genes in bacteria/phages")
+                 gr.Markdown("*Input format: >gene_id|organism|function*")
+
+                 with gr.Row():
+                     with gr.Column():
+                         ess_input = gr.Textbox(
+                             label="Gene Sequences (FASTA)",
+                             placeholder=">dnaA|E.coli|Replication initiator\nATGTCGAAAGCCGCAT...",
+                             lines=8
+                         )
+                         ess_btn = gr.Button("Predict Essentiality", variant="primary")
+
+                     with gr.Column():
+                         ess_output = gr.Textbox(
+                             label="Results",
+                             lines=15,
+                             show_copy_button=True
+                         )
+
+                 ess_btn.click(
+                     fn=predict_essentiality,
+                     inputs=ess_input,
+                     outputs=ess_output
+                 )
+
+             # Task 3: CRISPR Generation
+             with gr.Tab("✂️ CRISPR Generation"):
+                 gr.Markdown("### Generate synthetic CRISPR-Cas systems")
+
+                 with gr.Row():
+                     with gr.Column():
+                         crispr_n = gr.Slider(
+                             minimum=1,
+                             maximum=5,
+                             value=2,
+                             step=1,
+                             label="Number of Systems"
+                         )
+                         crispr_type = gr.Radio(
+                             choices=["Cas9", "Cas12", "Both"],
+                             value="Both",
+                             label="Cas Type"
+                         )
+                         crispr_target = gr.Textbox(
+                             label="Target Sequence (optional)",
+                             placeholder="ATCGATCGATCGATCG",
+                             lines=2
+                         )
+                         crispr_length = gr.Slider(
+                             minimum=500,
+                             maximum=2000,
+                             value=1000,
+                             step=100,
+                             label="Cas Protein Length"
+                         )
+                         crispr_btn = gr.Button("Generate CRISPR Systems", variant="primary")
+
+                     with gr.Column():
+                         crispr_output = gr.Textbox(
+                             label="Generated Systems",
+                             lines=15,
+                             show_copy_button=True
+                         )
+
+                 crispr_btn.click(
+                     fn=generate_crispr,
+                     inputs=[crispr_n, crispr_type, crispr_target, crispr_length],
+                     outputs=crispr_output
+                 )
+
+             # Task 4: Regulatory Design
+             with gr.Tab("🎛️ Regulatory Design"):
+                 gr.Markdown("### Design promoter-RBS pairs for gene expression")
+
+                 with gr.Row():
+                     with gr.Column():
+                         reg_n = gr.Slider(
+                             minimum=1,
+                             maximum=10,
+                             value=3,
+                             step=1,
+                             label="Number of Designs"
+                         )
+                         reg_level = gr.Radio(
+                             choices=["High", "Medium", "Low", "Mixed"],
+                             value="Mixed",
+                             label="Expression Level"
+                         )
+                         reg_btn = gr.Button("Design Regulatory Sequences", variant="primary")
+
+                     with gr.Column():
+                         reg_output = gr.Textbox(
+                             label="Designed Sequences",
+                             lines=15,
+                             show_copy_button=True
+                         )
+
+                 reg_btn.click(
+                     fn=design_regulatory,
+                     inputs=[reg_n, reg_level],
+                     outputs=reg_output
+                 )
+
+         gr.Markdown("---")
+         gr.Markdown("💡 **Tips:** Higher scores = more functional/essential | All outputs can be copied | Model: evo-1-8k-base")
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()
evo/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .version import version as __version__
+
+ from .models import Evo
+
+ from .generation import generate
+ from .scoring import score_sequences, positional_entropies
evo/configs/evo-1-131k-base_inference.yml ADDED
@@ -0,0 +1,40 @@
+ vocab_size: 512
+ hidden_size: 4096
+ num_filters: 4096
+ max_sequence_len: 8192
+ attn_layer_idxs: [8, 16, 24]
+ hyena_layer_idxs: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31]
+ num_layers: 32
+ short_filter_length: 3
+ num_attention_heads: 32
+ short_filter_bias: True
+ mlp_init_method: torch.nn.init.zeros_
+ mlp_output_init_method: torch.nn.init.zeros_
+ eps: 1.0e-6
+ state_size: 8
+ inner_size_multiple_of: 16 # force GLU inner_size to be a multiple of
+ smeared_gqa: False
+ make_vocab_size_divisible_by: 8
+ log_intermediate_values: False
+ proj_groups: 1 # GQA
+ hyena_filter_groups: 1
+ split_k0: True
+ model_parallel_size: 1
+ pipe_parallel_size: 1
+ tie_embeddings: True
+ inner_mlp_size: null # set to None, so it auto-fills
+ mha_out_proj_bias: True
+ qkv_proj_bias: True
+ final_norm: True
+ rng_fork: False
+ use_flash_attn: False
+ use_flash_rmsnorm: False
+ use_flash_depthwise: False
+ use_flashfft: False
+ column_split: True # only affects outputs when proj_groups > 1
+ inference_mode: True
+ tokenizer_type: CharLevelTokenizer
+ prefill_style: fft
+ mlp_activation: gelu
+ use_interpolated_rotary_pos_emb: true # turn this on for linear interpolated context extension
+ rotary_emb_scaling_factor: 16 # scaling factor for time indices in rotary embeddings
evo/configs/evo-1-8k-base_inference.yml ADDED
@@ -0,0 +1,38 @@
+ vocab_size: 512
+ hidden_size: 4096
+ num_filters: 4096
+ max_sequence_len: 8192
+ attn_layer_idxs: [8, 16, 24]
+ hyena_layer_idxs: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31]
+ num_layers: 32
+ short_filter_length: 3
+ num_attention_heads: 32
+ short_filter_bias: True
+ mlp_init_method: torch.nn.init.zeros_
+ mlp_output_init_method: torch.nn.init.zeros_
+ eps: 1.0e-6
+ state_size: 8
+ inner_size_multiple_of: 16 # force GLU inner_size to be a multiple of
+ smeared_gqa: False
+ make_vocab_size_divisible_by: 8
+ log_intermediate_values: False
+ proj_groups: 1 # GQA
+ hyena_filter_groups: 1
+ split_k0: True
+ model_parallel_size: 1
+ pipe_parallel_size: 1
+ tie_embeddings: True
+ inner_mlp_size: null # set to None, so it auto-fills
+ mha_out_proj_bias: True
+ qkv_proj_bias: True
+ final_norm: True
+ rng_fork: False
+ use_flash_attn: False
+ use_flash_rmsnorm: False
+ use_flash_depthwise: False
+ use_flashfft: False
+ column_split: True # only affects outputs when proj_groups > 1
+ inference_mode: True
+ tokenizer_type: CharLevelTokenizer
+ prefill_style: fft
+ mlp_activation: gelu
evo/generation.py ADDED
@@ -0,0 +1,297 @@
+ import numpy as np
+ import sys
+ import torch
+ from typing import List, Tuple, Union
+
+ from stripedhyena.model import StripedHyena
+ from stripedhyena.sample import sample
+ from stripedhyena.tokenizer import CharLevelTokenizer
+
+ from .scoring import logits_to_logprobs, prepare_batch
+
+
+ class Generator:
+     '''
+     Adapted from https://github.com/togethercomputer/stripedhyena.
+
+     Modifications include:
+     - `generate()` accepts and returns the recurrent cache state, letting the user
+       keep track of it across sampling runs.
+     - Able to sample with long token prompts in which the cache is initialized with
+       recurrent teacher forcing.
+     '''
+     def __init__(
+         self,
+         model: StripedHyena,
+         tokenizer: CharLevelTokenizer,
+         top_k: int = 50,
+         top_p: float = 0.7,
+         temperature: float = 1.,
+     ):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.top_k = top_k
+         self.top_p = top_p
+         self.temperature = temperature
+         self.untils = ['\n\n']
+
+     def generate(
+         self,
+         device: str,
+         input_string: str = None,
+         input_ids: torch.tensor = None,
+         num_tokens: int = 32,
+         cached_generation: bool = True,
+         force_prompt_threshold: int = 128,
+         print_generation: bool = True,
+         verbose: bool = False,
+         skip_special_tokens: bool = False,
+         stop_at_eos: bool = True,
+         max_seqlen: int = None,
+         inference_params_dict: dict = None,
+     ) -> Tuple[torch.tensor, torch.tensor, dict]:
+         """
+         A version of the generate() method that enables passing in and that returns the
+         `inference_params_dict` for replaying cached sampling from a given state.
+         """
+         if isinstance(self.tokenizer.eos, int):
+             eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
+         else:
+             # is a tensor
+             eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
+
+         if input_ids is None:
+             input = self.tokenizer.tokenize(input_string)
+             if isinstance(input, list):
+                 input = torch.LongTensor(input).unsqueeze(0).to(device)
+             # is a tensor
+             else:
+                 input = input.unsqueeze(0).to(device)
+
+         else:
+             input = input_ids
+         x = input
+
+         if max_seqlen is not None:
+             x = x[:, -max_seqlen:]
+
+         num_tokens = int(num_tokens)
+         batch_size = x.shape[0]
+
+         prompt_length = x.shape[1]
+         prompt_forcing = prompt_length > force_prompt_threshold
+         if prompt_forcing:
+             forced_prompt_length = prompt_length - force_prompt_threshold
+             x_force = x[:, force_prompt_threshold:]
+             x = x[:, :force_prompt_threshold]
+         else:
+             forced_prompt_length = 0
+
+         generation = torch.empty(
+             x.shape[0],
+             num_tokens,
+             dtype=torch.long,
+             device=x.device,
+         )
+
+         scores = torch.empty(
+             x.shape[0],
+             num_tokens,
+             self.tokenizer.vocab_size,
+             dtype=torch.float,
+             device=x.device,
+         )
+
+         if inference_params_dict is not None:
+             cached_generation = True
+             prefilled = True
+             # Ensure that the cached data is loaded on the correct device.
+             for key, data in inference_params_dict['mha'].key_value_memory_dict.items():
+                 inference_params_dict['mha'].key_value_memory_dict[key] = data.to(x.device)
+             for key, data in inference_params_dict['hyena'].fir_state_dict.items():
+                 inference_params_dict['hyena'].fir_state_dict[key] = data.to(x.device)
+             for key, data in inference_params_dict['hyena'].state_dict.items():
+                 inference_params_dict['hyena'].state_dict[key] = data.to(x.device)
+
+         elif cached_generation:
+             inference_params_dict = self.model.initialize_inference_params()
+             inference_params_dict['mha'].max_batch_size = batch_size
+             inference_params_dict['hyena'].max_batch_size = batch_size
+             prefilled = False
+
+         if verbose:
+             mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
+             print(f'Memory after tokenization: {mem_after_tok} GB')
+             print('Starting generation...')
+             if input_string is not None:
+                 print('Prompt: ' + input_string)
+             else:
+                 print(f'Prompt ids: {input_ids} {input_ids.shape}')
+
+         for i in range(forced_prompt_length + num_tokens):
+             if prefilled:
+                 post_prefill = True
+             else:
+                 post_prefill = cached_generation and i > 0
+
+             # prefill then process only the last token
+             if post_prefill:
+                 x = x[:, -1:]
+                 seqlen_offset = inference_params_dict['mha'].seqlen_offset
+
+                 if seqlen_offset == 0:
+                     seqlen_offset = input.shape[-1]
+                     inference_params_dict['hyena'].seqlen_offset = seqlen_offset
+                     inference_params_dict['mha'].seqlen_offset = seqlen_offset
+                 else:
+                     inference_params_dict['mha'].seqlen_offset += 1
+                     inference_params_dict['hyena'].seqlen_offset += 1
+
+             # do forward pass with no gradient
+             with torch.inference_mode():
+                 logits, inference_params_dict = self.model(
+                     x,
+                     inference_params_dict=inference_params_dict,
+                 )
+
+             last_logits = logits[:, -1]
+
+             if prompt_forcing and i < forced_prompt_length:
+                 new_idx = x_force[:, i]
+             else:
+                 new_idx = sample(
+                     last_logits,
+                     top_k=self.top_k,
+                     top_p=self.top_p,
+                     temperature=self.temperature,
+                 )
+
+             if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
+                 print('Stopping generation at EOS')
+
+             if print_generation and verbose and batch_size == 1:
+                 print(
+                     f'{self.tokenizer.detokenize([new_idx.item()])}',
+                     end=' ',
+                 )
+
+             if prompt_forcing:
+                 if i >= forced_prompt_length:
+                     scores[:, i - forced_prompt_length] = last_logits
+                     generation[:, i - forced_prompt_length] = new_idx
+             else:
+                 scores[:, i] = last_logits
+                 generation[:, i] = new_idx
+
+             if post_prefill:
+                 x = new_idx[:, None]
+             else:
+                 x = torch.cat([x, new_idx[:, None]], dim=-1)
+
+         if verbose:
+             y = self.tokenizer.detokenize_batch(generation[:, : i + 1])
+
+             for until in self.untils:
+                 if until in y:
+                     y = y.split(until)[0]
+                     break
+
+             print(f'\nInput: {input_string}, Output: {y}')
+
+             mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
+             print(f'Memory after generation: {mem_end} GB')
+
+         return generation[:, : i + 1], scores[:, : i + 1], inference_params_dict
+
+
+ def generate(
+     prompt_seqs: List[str],
+     model: StripedHyena,
+     tokenizer: CharLevelTokenizer,
+     n_tokens: int = 100,
+     temperature: float = 0.,
+     top_k: int = 1,
+     top_p: float = 1.,
+     batched: bool = True,
+     prepend_bos: bool = False,
+     cached_generation: bool = False,
+     force_prompt_threshold: int = 128,
+     verbose: int = 1,
+     device: str = 'cuda:0',
+     **kwargs,
+ ) -> Tuple[List[str], List[float]]:
+     """
+     Performs generation from a list of prompts.
+     If all prompts are the same length, this can do batched generation.
+     Also supports cached generation for efficient sampling.
+     """
+     model.eval()
+
+     g = Generator(
+         model,
+         tokenizer,
+         top_k=top_k,
+         top_p=top_p,
+         temperature=temperature,
+     )
+
+     uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)
+
+     if batched and uniform_lengths:
+         input_ids_list = [
+             prepare_batch(
+                 prompt_seqs,
+                 tokenizer,
+                 prepend_bos=prepend_bos,
+                 device=device,
+             )[0]
+         ]
+     else:
+         if verbose:
+             if not uniform_lengths:
+                 sys.stderr.write('Note: Prompts are of different lengths.\n')
+             sys.stderr.write('Note: Will not do batched generation.\n')
+         input_ids_list = [
+             prepare_batch(
+                 [ prompt_seq ],
+                 tokenizer,
+                 prepend_bos=prepend_bos,
+                 device=device,
+             )[0]
+             for prompt_seq in prompt_seqs
+         ]
+
+     generated_seqs, generated_scores = [], []
+     for input_ids in input_ids_list:
+         batch_size = input_ids.shape[0]
+
+         output_ids, logits, _ = g.generate(
+             input_ids=input_ids,
+             num_tokens=n_tokens,
+             cached_generation=cached_generation,
+             force_prompt_threshold=force_prompt_threshold,
+             device=device,
+             print_generation=(verbose > 1),
+             verbose=(verbose > 1),
+             stop_at_eos=False,
+         )
+         if verbose > 1:
+             print('input_ids.shape', input_ids.shape)
+             print('output_ids.shape', output_ids.shape)
+             print('logits.shape', logits.shape)
+
+         generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
+         assert len(generated_seqs_batch) == batch_size
+         generated_seqs += generated_seqs_batch
+
+         logprobs = logits_to_logprobs(logits, output_ids)
+         logprobs = logprobs.float().cpu().numpy()
+
+         generated_scores += [ np.mean(logprobs[idx]) for idx in range(batch_size) ]
+
+     assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
+     if verbose:
+         for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
+             print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')
+
+     return generated_seqs, generated_scores
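For orientation, here is a minimal sketch of how the module-level `generate()` above is intended to be called, mirroring its use in `app.py`; the prompt string and sampling parameters are illustrative choices, not part of this commit.

```python
# Sketch: sample 100 tokens from a short DNA prompt.
# Assumes a CUDA device and that the evo-1-8k-base weights can be
# downloaded from HuggingFace on first use.
from evo import Evo
from evo.generation import generate

evo_model = Evo('evo-1-8k-base', device='cuda:0')

seqs, scores = generate(
    ['ATG'],              # list of prompt sequences
    evo_model.model,
    evo_model.tokenizer,
    n_tokens=100,         # tokens to sample beyond the prompt
    temperature=0.8,
    top_k=4,
    device='cuda:0',
    verbose=0,
)
print(seqs[0])            # generated continuation
print(scores[0])          # mean per-token log-probability of the sample
```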
evo/models.py ADDED
@@ -0,0 +1,122 @@
+ import pkgutil
+ import re
+ from transformers import AutoConfig, AutoModelForCausalLM
+ import yaml
+
+ from stripedhyena.utils import dotdict
+ from stripedhyena.model import StripedHyena
+ from stripedhyena.tokenizer import CharLevelTokenizer
+
+
+ MODEL_NAMES = [
+     'evo-1.5-8k-base',
+     'evo-1-8k-base',
+     'evo-1-131k-base',
+     'evo-1-8k-crispr',
+     'evo-1-8k-transposon',
+ ]
+
+ class Evo:
+     def __init__(self, model_name: str = MODEL_NAMES[1], device: str = None):
+         """
+         Loads an Evo model checkpoint given a model name.
+         If the checkpoint does not exist, we automatically download it from HuggingFace.
+         """
+         self.device = device
+
+         # Check model name.
+
+         if model_name not in MODEL_NAMES:
+             raise ValueError(
+                 f'Invalid model name {model_name}. Should be one of: '
+                 f'{", ".join(MODEL_NAMES)}.'
+             )
+
+         # Assign config path.
+
+         if model_name == 'evo-1-8k-base' or \
+            model_name == 'evo-1-8k-crispr' or \
+            model_name == 'evo-1-8k-transposon' or \
+            model_name == 'evo-1.5-8k-base':
+             config_path = 'configs/evo-1-8k-base_inference.yml'
+         elif model_name == 'evo-1-131k-base':
+             config_path = 'configs/evo-1-131k-base_inference.yml'
+         else:
+             raise ValueError(
+                 f'Invalid model name {model_name}. Should be one of: '
+                 f'{", ".join(MODEL_NAMES)}.'
+             )
+
+         # Load model.
+
+         self.model = load_checkpoint(
+             model_name=model_name,
+             config_path=config_path,
+             device=self.device
+         )
+
+         # Load tokenizer.
+
+         self.tokenizer = CharLevelTokenizer(512)
+
+
+ HF_MODEL_NAME_MAP = {
+     'evo-1.5-8k-base': 'evo-design/evo-1.5-8k-base',
+     'evo-1-8k-base': 'togethercomputer/evo-1-8k-base',
+     'evo-1-131k-base': 'togethercomputer/evo-1-131k-base',
+     'evo-1-8k-crispr': 'LongSafari/evo-1-8k-crispr',
+     'evo-1-8k-transposon': 'LongSafari/evo-1-8k-transposon',
+ }
+
+ def load_checkpoint(
+     model_name: str = MODEL_NAMES[1],
+     config_path: str = 'evo/configs/evo-1-131k-base_inference.yml',
+     device: str = None,
+     *args, **kwargs
+ ):
+     """
+     Load checkpoint from HuggingFace and place it into SH model.
+     """
+
+     # Map model name to HuggingFace model name.
+
+     hf_model_name = HF_MODEL_NAME_MAP[model_name]
+
+     # Load model config.
+
+     model_config = AutoConfig.from_pretrained(
+         hf_model_name,
+         trust_remote_code=True,
+         revision='1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main',
+     )
+     model_config.use_cache = True
+
+     # Load model.
+
+     model = AutoModelForCausalLM.from_pretrained(
+         hf_model_name,
+         config=model_config,
+         trust_remote_code=True,
+         revision='1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main',
+     )
+
+     # Load model state dict & cleanup.
+
+     state_dict = model.backbone.state_dict()
+     del model
+     del model_config
+
+     # Load SH config.
+
+     config = yaml.safe_load(pkgutil.get_data(__name__, config_path))
+     global_config = dotdict(config, Loader=yaml.FullLoader)
+
+     # Load SH Model.
+
+     model = StripedHyena(global_config)
+     model.load_state_dict(state_dict, strict=True)
+     model.to_bfloat16_except_poles_residues()
+     if device is not None:
+         model = model.to(device)
+
+     return model
evo/scoring.py ADDED
@@ -0,0 +1,131 @@
+ import numpy as np
+ import torch
+ from typing import List, Tuple
+
+ from stripedhyena.model import StripedHyena
+ from stripedhyena.tokenizer import CharLevelTokenizer
+
+
+ def prepare_batch(
+     seqs: List[str],
+     tokenizer: CharLevelTokenizer,
+     prepend_bos: bool = True,
+     device: str = 'cuda:0'
+ ) -> Tuple[torch.Tensor, List[int]]:
+     """
+     Takes in a list of sequences, tokenizes them, and puts them in a tensor batch.
+     If the sequences have differing lengths, then pad up to the maximum sequence length.
+     """
+     seq_lengths = [ len(seq) for seq in seqs ]
+     max_seq_length = max(seq_lengths)
+
+     input_ids = []
+     for seq in seqs:
+         padding = [tokenizer.pad_id] * (max_seq_length - len(seq))
+         input_ids.append(
+             torch.tensor(
+                 ([tokenizer.eod_id] * int(prepend_bos)) + tokenizer.tokenize(seq) + padding,
+                 dtype=torch.long,
+             ).to(device).unsqueeze(0)
+         )
+     input_ids = torch.cat(input_ids, dim=0)
+
+     return input_ids, seq_lengths
+
+
+ def logits_to_logprobs(
+     logits: torch.Tensor,
+     input_ids: torch.Tensor,
+     trim_bos: bool = True,
+ ) -> torch.Tensor:
+     """
+     Takes in a tensor of logits of dimension (batch, length, vocab).
+     Computes the log-likelihoods using a softmax along the vocab dimension.
+     Uses the `input_ids` to index into the log-likelihoods and returns the likelihood
+     of the provided sequence at each position with dimension (batch, length).
+     """
+     softmax_logprobs = torch.log_softmax(logits, dim=-1)
+     if trim_bos:
+         softmax_logprobs = softmax_logprobs[:, :-1] # Remove last prediction.
+         input_ids = input_ids[:, 1:] # Trim BOS added by tokenizer.
+     assert(softmax_logprobs.shape[1] == input_ids.shape[1])
+
+     logprobs = torch.gather(
+         softmax_logprobs, # Gather likelihoods...
+         2, # along the vocab dimension...
+         input_ids.unsqueeze(-1) # using the token ids to index.
+     ).squeeze(-1)
+
+     return logprobs
+
+
+ def score_sequences(
+     seqs: List[str],
+     model: StripedHyena,
+     tokenizer: CharLevelTokenizer,
+     reduce_method: str = 'mean',
+     device: str = 'cuda:0',
+ ) -> List[float]:
+     """
+     Computes the model log-likelihood scores for sequences in `seqs`.
+     Uses `reduce_method` to take the mean or sum across the likelihoods at each
+     position (default: `'mean'`).
+
+     Returns a list of scalar scores corresponding to the reduced log-likelihoods for
+     each sequence.
+     """
+     input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
+     assert(len(seq_lengths) == input_ids.shape[0])
+
+     with torch.inference_mode():
+         logits, _ = model(input_ids) # (batch, length, vocab)
+
+     logprobs = logits_to_logprobs(logits, input_ids, trim_bos=True)
+     logprobs = logprobs.float().cpu().numpy()
+
+     if reduce_method == 'mean':
+         reduce_func = np.mean
+     elif reduce_method == 'sum':
+         reduce_func = np.sum
+     else:
+         raise ValueError(f'Invalid reduce_method {reduce_method}')
+
+     return [
+         reduce_func(logprobs[idx][:seq_lengths[idx]])
+         for idx in range(len(seq_lengths))
+     ]
+
+
+ def positional_entropies(
+     seqs: List[str],
+     model: StripedHyena,
+     tokenizer: CharLevelTokenizer,
+     device: str = 'cuda:0',
+ ) -> List[np.array]:
+     """
+     Computes the positional entropies for sequences in `seqs`.
+
+     Returns a list of arrays, where each array is the same length as the
+     corresponding sequence length. Each array contains the per-position entropy
+     across the vocab dimension.
+     """
+     input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
+     assert(len(seq_lengths) == input_ids.shape[0])
+
+     with torch.inference_mode():
+         logits, _ = model(input_ids) # (batch, length, vocab)
+
+     # Tokenizer prepends BOS, remember to remove last prediction.
+     softmax_logprobs = torch.log_softmax(logits, dim=-1)[:, :-1]
+
+     entropies = -torch.sum(torch.exp(softmax_logprobs) * softmax_logprobs, dim=-1)
+     entropies = entropies.float().cpu().numpy()
+
+     sequence_entropies = [
+         entropies[idx][:seq_lengths[idx]] for idx in range(len(seq_lengths))
+     ]
+     assert all(
+         len(seq) == len(entropy) for seq, entropy in zip(seqs, sequence_entropies)
+     )
+
+     return sequence_entropies
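Likewise, a short sketch of the scoring entry points above, as `app.py` uses them; the toy sequence is illustrative only.

```python
# Sketch: mean log-likelihood score and per-position entropies for one sequence.
from evo import Evo
from evo.scoring import score_sequences, positional_entropies

evo_model = Evo('evo-1-8k-base', device='cuda:0')

seqs = ['ACGTACGTACGTACGT']
scores = score_sequences(
    seqs, evo_model.model, evo_model.tokenizer,
    reduce_method='mean',   # or 'sum' for the total log-likelihood
    device='cuda:0',
)
entropies = positional_entropies(
    seqs, evo_model.model, evo_model.tokenizer, device='cuda:0',
)
print(scores[0])            # one scalar score per sequence
print(len(entropies[0]))    # one entropy value per position
```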
evo/utils.py ADDED
@@ -0,0 +1,183 @@
+ import numpy as np
+ import pandas as pd
+ from typing import Callable
+
+
+ NTs = 'ACGT'
+
+ AAs = 'ACDEFGHIKLMNPQRSTVWY'
+
+ AA_TO_CODON = {
+     '*': ['TAA', 'TAG', 'TGA'], # Stop.
+     'A': ['GCT', 'GCC', 'GCA', 'GCG'], # Ala.
+     'C': ['TGT', 'TGC'], # Cys.
+     'D': ['GAT', 'GAC'], # Asp.
+     'E': ['GAA', 'GAG'], # Glu.
+     'F': ['TTT', 'TTC'], # Phe.
+     'G': ['GGT', 'GGC', 'GGA', 'GGG'], # Gly.
+     'H': ['CAT', 'CAC'], # His.
+     'I': ['ATT', 'ATC', 'ATA'], # Ile.
+     'K': ['AAA', 'AAG'], # Lys.
+     'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'], # Leu.
+     'M': ['ATG'], # Met.
+     'N': ['AAT', 'AAC'], # Asn.
+     'P': ['CCT', 'CCC', 'CCA', 'CCG'], # Pro.
+     'Q': ['CAA', 'CAG'], # Gln.
+     'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'], # Arg.
+     'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'], # Ser.
+     'T': ['ACT', 'ACC', 'ACA', 'ACG'], # Thr.
+     'V': ['GTT', 'GTC', 'GTA', 'GTG'], # Val.
+     'W': ['TGG'], # Trp.
+     'Y': ['TAT', 'TAC'], # Tyr.
+ }
+
+ CODON_TO_AA = {
+     codon: aa
+     for aa, codon_list in AA_TO_CODON.items()
+     for codon in codon_list
+ }
+
+ AA_3_TO_1 = {
+     "Ala": "A", # Alanine
+     "Arg": "R", # Arginine
+     "Asn": "N", # Asparagine
+     "Asp": "D", # Aspartic acid
+     "Cys": "C", # Cysteine
+     "Gln": "Q", # Glutamine
+     "Glu": "E", # Glutamic acid
+     "Gly": "G", # Glycine
+     "His": "H", # Histidine
+     "Ile": "I", # Isoleucine
+     "Leu": "L", # Leucine
+     "Lys": "K", # Lysine
+     "Met": "M", # Methionine
+     "Phe": "F", # Phenylalanine
+     "Pro": "P", # Proline
+     "Ser": "S", # Serine
+     "Thr": "T", # Threonine
+     "Trp": "W", # Tryptophan
+     "Tyr": "Y", # Tyrosine
+     "Val": "V" # Valine
+ }
+
+
+ def nucleotide_deep_mutational_scan(sequence: str, ignore_wt: bool = True):
+     for idx, wt in enumerate(sequence):
+         for mt in NTs:
+             if ignore_wt and wt == mt:
+                 continue
+             yield (wt, mt, idx)
+
+
+ def parse_blast_output(output_path: str) -> pd.DataFrame:
+     """
+     Parses standard blast output with `-outfmt 6`.
+     """
+     # blast default format output fields.
+     blast_table_header = [
+         'qacc', 'sacc', 'pident', 'length', 'mismatch', 'gapopen', 'qstart',
+         'qend', 'sstart', 'send', 'evalue',
+     ]
+
+     data = []
+     with open(output_path, 'r') as f:
+         for line in f:
+             if line.startswith("#"):
+                 continue
+             if line.strip() == '':
+                 continue
+             line = line.strip().split()
+             data.append(dict(zip(blast_table_header, line)))
+
+     df = pd.DataFrame(data)
+     if len(df) == 0:
+         return df
+     df['evalue'] = df['evalue'].astype(float)
+
+     return df
+
+
+ def parse_erpin_output(output_path: str, name: str) -> pd.DataFrame:
+     """
+     Parses ERPIN output. For an example, see `eval/data/example_rho_output.txt`.
+     """
+     # ERPIN format output fields.
+     output_fields = [ 'strand', 'index', 'interval', 'score', 'evalue' ]
+
+     data = []
+     with open(output_path, 'r') as f:
+         for line in f:
+             if line.startswith(f'>{name}'):
+                 meta = dict(zip(output_fields, f.readline().rstrip().split()))
+                 sequence = f.readline().rstrip()
+                 start, end = meta['interval'].split('..')
+                 data.append([
+                     f"{name}_{meta['index']}_{meta['strand']}",
+                     sequence,
+                     int(start),
+                     int(end),
+                     '+' if meta['strand'] == 'FW' else '-',
+                     meta['score'],
+                     float(meta['evalue']),
+                 ])
+
+     return pd.DataFrame(
+         data,
+         columns=[
+             'id',
+             'seq',
+             'start',
+             'end',
+             'strand',
+             'score',
+             'evalue',
+         ],
+     )
+
+
+ def parse_hmmsearch_output(output_path: str) -> pd.DataFrame:
+     """
+     Parses standard hmmsearch output.
+     """
+     # hmmsearch format output fields.
+     hmmsearch_table_header = [
+         'target', 'target_acc', 'tlen', 'query', 'query_acc', 'qlen',
+         'evalue', 'score', 'bias', 'num', 'of', 'cevalue', 'ievalue',
+         'dscore', 'dbias', 'hmm_from', 'hmm_to', 'ali_from', 'ali_to',
+         'env_from', 'env_to', 'acc', 'desc',
+     ]
+
+     data = []
+     with open(output_path, 'r') as f:
+         for line in f:
+             if line.startswith("#"):
+                 continue
+             line = line.strip().split()
+             data.append(dict(zip(hmmsearch_table_header, line)))
+
+     return pd.DataFrame(data)
+
+
+ def permutation_test(
+     score_func: Callable[[np.array, np.array], float],
+     x1: np.array,
+     x2: np.array,
+     n_permutations: int = 100_000,
+ ) -> float:
+     """
+     Returns a permutation-based P value. Computes the null distribution by
+     shuffling the provided data and recomputing the `score_func`.
+     """
+     if n_permutations < 1:
+         raise ValueError('Number of permutations must be positive.')
+
+     x1, x2 = np.array(x1), np.array(x2)
+
+     observed_score = score_func(x1, x2)
+
+     null_distribution = np.array([
+         score_func(x1, np.random.permutation(x2))
+         for _ in range(n_permutations)
+     ])
+
+     return np.mean(null_distribution >= observed_score)
evo/version.py ADDED
@@ -0,0 +1 @@
+ version = '0.4'
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio==4.44.0
+ torch==2.1.0
+ numpy==1.24.3
+ transformers==4.36.0
+ einops==0.7.0
+ pyyaml==6.0.1
+ git+https://github.com/togethercomputer/stripedhyena.git
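With the pinned dependencies above installed, a minimal sketch of running the Space locally; this mirrors the `__main__` block of `app.py`, and the `server_name`/`server_port` arguments are standard Gradio launch options that this commit does not set.

```python
# Sketch: launch the same Gradio UI that Hugging Face starts via app_file.
from app import create_interface

demo = create_interface()
demo.launch(server_name='0.0.0.0', server_port=7860)
```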