Commit 2997d61
Parent(s): e994268

initialised app

Files changed:
- .gitignore +216 -0
- README.md +9 -8
- app.py +468 -0
- evo/__init__.py +6 -0
- evo/configs/evo-1-131k-base_inference.yml +40 -0
- evo/configs/evo-1-8k-base_inference.yml +38 -0
- evo/generation.py +297 -0
- evo/models.py +122 -0
- evo/scoring.py +131 -0
- evo/utils.py +183 -0
- requirements.txt +7 -0
.gitignore
ADDED
@@ -0,0 +1,216 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/

# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# Redis
*.rdb
*.aof
*.pid

# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/

# ActiveMQ
activemq-data/

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/

# Streamlit
.streamlit/secrets.toml
README.md
CHANGED
@@ -1,14 +1,15 @@
 ---
-title: Evo
-emoji:
-colorFrom:
-colorTo:
+title: Evo Model Interface
+emoji: 🧬
+colorFrom: blue
+colorTo: green
 sdk: gradio
-sdk_version:
+sdk_version: 4.44.0
 app_file: app.py
 pinned: false
-license:
-
+license: apache-2.0
+python_version: 3.11
 ---

-Check
+Check configuration
+We'll verify that the model and space are configured correctly from a few properties in the README's YAML metadata.
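The new closing lines point at the YAML front matter above, which is what Hugging Face Spaces reads to configure the app. A minimal sketch of such a configuration check, assuming PyYAML is available; the helper name and the exact set of keys checked are illustrative:

import yaml  # PyYAML

def check_space_config(readme_path: str = "README.md") -> dict:
    """Validate the Spaces metadata keys set in the front matter above."""
    text = open(readme_path, encoding="utf-8").read()
    # The front matter sits between the first two '---' delimiters.
    _, front_matter, _ = text.split("---", 2)
    meta = yaml.safe_load(front_matter)
    for key in ("sdk", "sdk_version", "app_file", "license"):
        assert key in meta, f"missing metadata key: {key}"
    assert meta["sdk"] == "gradio" and meta["app_file"] == "app.py"
    return meta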
app.py
ADDED
@@ -0,0 +1,468 @@
"""
Evo Model Web Interface
A simple Gradio app for testing Evo's predictive and generative capabilities.
"""
import gradio as gr
import torch
import numpy as np
from evo import Evo
from evo.scoring import score_sequences
from evo.generation import generate
from typing import List, Tuple, Dict
import io


# Global model variables
model = None
tokenizer = None
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def load_model():
    """Load Evo model once at startup."""
    global model, tokenizer
    if model is None:
        print("Loading Evo model...")
        evo_model = Evo('evo-1-8k-base')
        model, tokenizer = evo_model.model, evo_model.tokenizer
        model.to(device)
        model.eval()
        print("✓ Model loaded successfully")


# ============================================================================
# TASK 1: Function Prediction
# ============================================================================

def detect_sequence_type(seq: str) -> str:
    """Detect if sequence is DNA, RNA, or protein."""
    seq_upper = seq.upper()
    if any(c in set('EFILPQZ') for c in seq_upper):
        return 'protein'
    if 'U' in seq_upper:
        return 'RNA'
    if all(c in set('ACGTN') for c in seq_upper):
        return 'DNA'
    return 'unknown'


def parse_fasta_text(text: str) -> List[Tuple[str, str]]:
    """Parse FASTA format text into (id, sequence) tuples."""
    sequences = []
    current_id = None
    current_seq = []

    for line in text.strip().split('\n'):
        line = line.strip()
        if line.startswith('>'):
            if current_id is not None:
                sequences.append((current_id, ''.join(current_seq)))
            current_id = line[1:].split('|')[0].strip()
            current_seq = []
        else:
            current_seq.append(line)

    if current_id is not None:
        sequences.append((current_id, ''.join(current_seq)))

    return sequences


def predict_function(sequences_text: str, threshold: float) -> str:
    """Predict sequence functionality."""
    load_model()

    if not sequences_text.strip():
        return "⚠️ Please enter sequences in FASTA format or paste sequences directly."

    # Parse input
    if sequences_text.startswith('>'):
        # FASTA format
        seq_data = parse_fasta_text(sequences_text)
    else:
        # Single sequence
        seq_data = [("sequence_1", sequences_text.strip().replace('\n', ''))]

    if not seq_data:
        return "⚠️ No valid sequences found."

    # Score sequences
    sequences = [seq for _, seq in seq_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    # Format results
    results = ["# Function Prediction Results\n"]
    results.append(f"{'Sequence ID':<20} {'Type':<10} {'Score':<12} {'Prediction':<15} {'Length':<10}")
    results.append("-" * 70)

    for (seq_id, seq), score in zip(seq_data, scores):
        seq_type = detect_sequence_type(seq)
        prediction = "✓ Functional" if score > threshold else "✗ Non-functional"
        results.append(f"{seq_id:<20} {seq_type:<10} {score:<12.4f} {prediction:<15} {len(seq):<10}")

    results.append("\n" + "=" * 70)
    results.append(f"Total sequences: {len(seq_data)}")
    results.append(f"Functional: {sum(1 for s in scores if s > threshold)}")
    results.append(f"Non-functional: {sum(1 for s in scores if s <= threshold)}")
    results.append(f"Average score: {np.mean(scores):.4f}")

    return "\n".join(results)


# ============================================================================
# TASK 2: Gene Essentiality
# ============================================================================

def predict_essentiality(genes_text: str) -> str:
    """Predict gene essentiality."""
    load_model()

    if not genes_text.strip():
        return "⚠️ Please enter gene sequences in FASTA format."

    # Parse FASTA
    if not genes_text.startswith('>'):
        return "⚠️ Please use FASTA format: >gene_id|organism|function\\nATGC..."

    gene_data = parse_fasta_text(genes_text)
    if not gene_data:
        return "⚠️ No valid genes found."

    # Score genes
    sequences = [seq for _, seq in gene_data]
    scores = score_sequences(sequences, model, tokenizer, reduce_method='mean', device=device)

    # Calculate statistics
    scores_mean = np.mean(scores)
    scores_std = np.std(scores)

    # Format results
    results = ["# Gene Essentiality Prediction\n"]
    results.append(f"{'Gene ID':<20} {'Z-Score':<10} {'Score':<12} {'Essentiality':<15} {'Confidence':<12}")
    results.append("-" * 70)

    essential_count = 0
    for (gene_id, seq), score in zip(gene_data, scores):
        z_score = (score - scores_mean) / scores_std if scores_std > 0 else 0

        if z_score > 0.5:
            essentiality = "✓ Essential"
            confidence = "High" if z_score > 1.0 else "Medium"
            essential_count += 1
        elif z_score < -0.5:
            essentiality = "✗ Non-essential"
            confidence = "High" if z_score < -1.0 else "Medium"
        else:
            essentiality = "? Uncertain"
            confidence = "Low"

        results.append(f"{gene_id:<20} {z_score:<10.2f} {score:<12.4f} {essentiality:<15} {confidence:<12}")

    results.append("\n" + "=" * 70)
    results.append(f"Total genes: {len(gene_data)}")
    results.append(f"Essential: {essential_count}")
    results.append(f"Mean score: {scores_mean:.4f} (std: {scores_std:.4f})")

    return "\n".join(results)


# ============================================================================
# TASK 3: CRISPR Generation
# ============================================================================

def generate_crispr(n_systems: int, cas_type: str, target_seq: str, cas_length: int) -> str:
    """Generate CRISPR-Cas systems."""
    load_model()

    # Templates
    cas9_start = 'ATGAACAAGAAC'
    cas12_start = 'ATGAGCAAGCTG'

    results = ["# CRISPR-Cas System Generation\n"]

    cas_types = ['cas9', 'cas12'] if cas_type == 'Both' else [cas_type.lower()]

    for i in range(n_systems):
        current_cas = cas_types[i % len(cas_types)]
        prompt = cas9_start if current_cas == 'cas9' else cas12_start

        results.append(f"\n{'='*70}")
        results.append(f"System {i+1}: {current_cas.upper()}")
        results.append('='*70)

        # Generate Cas protein
        output_seqs, _ = generate(
            [prompt],
            model,
            tokenizer,
            n_tokens=cas_length,
            temperature=0.8,
            top_k=4,
            device=device,
            verbose=0
        )
        cas_protein = output_seqs[0]

        # Generate gRNA spacer
        if target_seq:
            complement = {'A': 'U', 'T': 'A', 'G': 'C', 'C': 'G'}
            spacer = ''.join(complement.get(b, 'N') for b in reversed(target_seq[:20]))
        else:
            spacer_seqs, _ = generate(['G'], model, tokenizer, n_tokens=19, temperature=0.7,
                                      top_k=4, device=device, verbose=0)
            spacer = spacer_seqs[0][:20].replace('T', 'U')

        # PAM sequence
        pam = 'NGG' if current_cas == 'cas9' else 'TTTN'

        results.append(f"\n{current_cas.upper()} Protein ({len(cas_protein)} nt):")
        results.append(f"{cas_protein[:80]}..." if len(cas_protein) > 80 else cas_protein)
        results.append(f"\ngRNA Spacer: {spacer}")
        results.append(f"PAM Sequence: {pam}")
        if current_cas == 'cas9':
            results.append(f"tracrRNA: AGCAUAGCAAGUUAAAAUAAGGCUAGUCCGU")

    return "\n".join(results)


# ============================================================================
# TASK 4: Regulatory Design
# ============================================================================

def generate_spacer_simple(length: int) -> str:
    """Generate a simple random spacer."""
    bases = ['A', 'T', 'G', 'C']
    return ''.join(np.random.choice(bases) for _ in range(length))


def design_regulatory(n_designs: int, expression_level: str) -> str:
    """Design regulatory sequences."""
    load_model()

    # Templates
    promoter_templates = {
        'High': ('TTGACA', 'TATAAT'),
        'Medium': ('TTGACT', 'TATACT'),
        'Low': ('TTGCCA', 'TATGAT')
    }

    rbs_templates = {
        'High': 'AGGAGGU',
        'Medium': 'AGGAGG',
        'Low': 'AGGA'
    }

    results = ["# Regulatory Sequences Design\n"]

    levels = ['High', 'Medium', 'Low']

    for i in range(n_designs):
        if expression_level == 'Mixed':
            level = levels[i % 3]
        else:
            level = expression_level

        results.append(f"\n{'='*70}")
        results.append(f"Design {i+1}: {level} Expression")
        results.append('='*70)

        # Get promoter boxes
        box_35, box_10 = promoter_templates[level]

        # Generate spacers
        spacer_35_10 = generate_spacer_simple(17)
        spacer_10_rbs = generate_spacer_simple(7)

        # Get RBS
        rbs = rbs_templates[level]

        # Generate RBS-ATG spacer
        spacer_rbs_atg = generate_spacer_simple(7)

        # Assemble
        promoter = box_35 + spacer_35_10 + box_10
        full_region = promoter + spacer_10_rbs + rbs + spacer_rbs_atg + 'ATG'

        gc_content = 100 * (full_region.count('G') + full_region.count('C')) / len(full_region)

        results.append(f"\nComponents:")
        results.append(f"  -35 box: {box_35}")
        results.append(f"  -10 box: {box_10}")
        results.append(f"  RBS (Shine-Dalgarno): {rbs}")
        results.append(f"  Start codon: ATG")
        results.append(f"\nFull Regulatory Region ({len(full_region)} bp, GC={gc_content:.1f}%):")
        results.append(full_region)
        results.append(f"\nPromoter only:")
        results.append(promoter)

    return "\n".join(results)


# ============================================================================
# Gradio Interface
# ============================================================================

def create_interface():
    """Create the Gradio interface."""

    with gr.Blocks(title="Evo Model Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🧬 Evo Model Interface")
        gr.Markdown("### Test Evo's predictive and generative capabilities")

        with gr.Tabs():
            # Task 1: Function Prediction
            with gr.Tab("🔍 Function Prediction"):
                gr.Markdown("### Predict if sequences are functional")
                gr.Markdown("*Enter sequences in FASTA format or paste a single sequence*")

                with gr.Row():
                    with gr.Column():
                        func_input = gr.Textbox(
                            label="Input Sequences",
                            placeholder=">seq1|description\nATCGATCGATCG...\n\nOr paste a single sequence directly",
                            lines=8
                        )
                        func_threshold = gr.Slider(
                            minimum=-3.0,
                            maximum=0.0,
                            value=-1.5,
                            step=0.1,
                            label="Functionality Threshold"
                        )
                        func_btn = gr.Button("Predict Function", variant="primary")

                    with gr.Column():
                        func_output = gr.Textbox(
                            label="Results",
                            lines=15,
                            show_copy_button=True
                        )

                func_btn.click(
                    fn=predict_function,
                    inputs=[func_input, func_threshold],
                    outputs=func_output
                )

                gr.Examples(
                    examples=[
                        [">functional_gene\nATGGCACAACCCGCGCCGAACTGGTTGACCTGAAAACCACCGCCGCACTGCGTCAGGCCAGCCAGGCGGAACAA", -1.5],
                        [">noncoding\nGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC", -1.5],
                    ],
                    inputs=[func_input, func_threshold]
                )

            # Task 2: Gene Essentiality
            with gr.Tab("🧬 Gene Essentiality"):
                gr.Markdown("### Predict essential genes in bacteria/phages")
                gr.Markdown("*Input format: >gene_id|organism|function*")

                with gr.Row():
                    with gr.Column():
                        ess_input = gr.Textbox(
                            label="Gene Sequences (FASTA)",
                            placeholder=">dnaA|E.coli|Replication initiator\nATGTCGAAAGCCGCAT...",
                            lines=8
                        )
                        ess_btn = gr.Button("Predict Essentiality", variant="primary")

                    with gr.Column():
                        ess_output = gr.Textbox(
                            label="Results",
                            lines=15,
                            show_copy_button=True
                        )

                ess_btn.click(
                    fn=predict_essentiality,
                    inputs=ess_input,
                    outputs=ess_output
                )

            # Task 3: CRISPR Generation
            with gr.Tab("✂️ CRISPR Generation"):
                gr.Markdown("### Generate synthetic CRISPR-Cas systems")

                with gr.Row():
                    with gr.Column():
                        crispr_n = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=2,
                            step=1,
                            label="Number of Systems"
                        )
                        crispr_type = gr.Radio(
                            choices=["Cas9", "Cas12", "Both"],
                            value="Both",
                            label="Cas Type"
                        )
                        crispr_target = gr.Textbox(
                            label="Target Sequence (optional)",
                            placeholder="ATCGATCGATCGATCG",
                            lines=2
                        )
                        crispr_length = gr.Slider(
                            minimum=500,
                            maximum=2000,
                            value=1000,
                            step=100,
                            label="Cas Protein Length"
                        )
                        crispr_btn = gr.Button("Generate CRISPR Systems", variant="primary")

                    with gr.Column():
                        crispr_output = gr.Textbox(
                            label="Generated Systems",
                            lines=15,
                            show_copy_button=True
                        )

                crispr_btn.click(
                    fn=generate_crispr,
                    inputs=[crispr_n, crispr_type, crispr_target, crispr_length],
                    outputs=crispr_output
                )

            # Task 4: Regulatory Design
            with gr.Tab("🎛️ Regulatory Design"):
                gr.Markdown("### Design promoter-RBS pairs for gene expression")

                with gr.Row():
                    with gr.Column():
                        reg_n = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=3,
                            step=1,
                            label="Number of Designs"
                        )
                        reg_level = gr.Radio(
                            choices=["High", "Medium", "Low", "Mixed"],
                            value="Mixed",
                            label="Expression Level"
                        )
                        reg_btn = gr.Button("Design Regulatory Sequences", variant="primary")

                    with gr.Column():
                        reg_output = gr.Textbox(
                            label="Designed Sequences",
                            lines=15,
                            show_copy_button=True
                        )

                reg_btn.click(
                    fn=design_regulatory,
                    inputs=[reg_n, reg_level],
                    outputs=reg_output
                )

        gr.Markdown("---")
        gr.Markdown("💡 **Tips:** Higher scores = more functional/essential | All outputs can be copied | Model: evo-1-8k-base")

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
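Each tab handler above is a plain function of its inputs, so the app logic can be exercised without the Gradio UI. A minimal sketch, assuming the file is saved as app.py and the checkpoint can be downloaded; the sequence and threshold here are illustrative:

from app import predict_function

# load_model() runs on the first call and fetches the evo-1-8k-base checkpoint.
report = predict_function(">toy\nATGGCACAACCCGCGCCGAACTG", -1.5)
print(report)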
evo/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .version import version as __version__

from .models import Evo

from .generation import generate
from .scoring import score_sequences, positional_entropies
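These exports form the package's public surface. A minimal end-to-end sketch of the intended usage, assuming a CUDA-capable machine and that the checkpoint downloads from HuggingFace on first use:

import torch
from evo import Evo, generate, score_sequences

evo_model = Evo('evo-1-8k-base')
model, tokenizer = evo_model.model, evo_model.tokenizer
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device).eval()

# Log-likelihood scoring of DNA sequences.
scores = score_sequences(['ACGTACGT'], model, tokenizer, device=device)

# Autoregressive sampling from a DNA prompt.
seqs, seq_scores = generate(['ACGT'], model, tokenizer, n_tokens=16,
                            temperature=1.0, top_k=4, device=device)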
evo/configs/evo-1-131k-base_inference.yml
ADDED
@@ -0,0 +1,40 @@
vocab_size: 512
hidden_size: 4096
num_filters: 4096
max_sequence_len: 8192
attn_layer_idxs: [8, 16, 24]
hyena_layer_idxs: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31]
num_layers: 32
short_filter_length: 3
num_attention_heads: 32
short_filter_bias: True
mlp_init_method: torch.nn.init.zeros_
mlp_output_init_method: torch.nn.init.zeros_
eps: 1.0e-6
state_size: 8
inner_size_multiple_of: 16 # force GLU inner_size to be a multiple of this value
smeared_gqa: False
make_vocab_size_divisible_by: 8
log_intermediate_values: False
proj_groups: 1 # GQA
hyena_filter_groups: 1
split_k0: True
model_parallel_size: 1
pile_parallel_size: 1
tie_embeddings: True
inner_mlp_size: null # set to None, so it auto-fills
mha_out_proj_bias: True
qkv_proj_bias: True
final_norm: True
rng_fork: False
use_flash_attn: False
use_flash_rmsnorm: False
use_flash_depthwise: False
use_flashfft: False
column_split: True # only affects outputs when proj_groups > 1
inference_mode: True
tokenizer_type: CharLevelTokenizer
prefill_style: fft
mlp_activation: gelu
use_interpolated_rotary_pos_emb: true # turn this on for linear interpolated context extension
rotary_emb_scaling_factor: 16 # scaling factor for time indices in rotary embeddings
evo/configs/evo-1-8k-base_inference.yml
ADDED
@@ -0,0 +1,38 @@
vocab_size: 512
hidden_size: 4096
num_filters: 4096
max_sequence_len: 8192
attn_layer_idxs: [8, 16, 24]
hyena_layer_idxs: [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31]
num_layers: 32
short_filter_length: 3
num_attention_heads: 32
short_filter_bias: True
mlp_init_method: torch.nn.init.zeros_
mlp_output_init_method: torch.nn.init.zeros_
eps: 1.0e-6
state_size: 8
inner_size_multiple_of: 16 # force GLU inner_size to be a multiple of this value
smeared_gqa: False
make_vocab_size_divisible_by: 8
log_intermediate_values: False
proj_groups: 1 # GQA
hyena_filter_groups: 1
split_k0: True
model_parallel_size: 1
pile_parallel_size: 1
tie_embeddings: True
inner_mlp_size: null # set to None, so it auto-fills
mha_out_proj_bias: True
qkv_proj_bias: True
final_norm: True
rng_fork: False
use_flash_attn: False
use_flash_rmsnorm: False
use_flash_depthwise: False
use_flashfft: False
column_split: True # only affects outputs when proj_groups > 1
inference_mode: True
tokenizer_type: CharLevelTokenizer
prefill_style: fft
mlp_activation: gelu
evo/generation.py
ADDED
@@ -0,0 +1,297 @@
import numpy as np
import sys
import torch
from typing import List, Tuple, Union

from stripedhyena.model import StripedHyena
from stripedhyena.sample import sample
from stripedhyena.tokenizer import CharLevelTokenizer

from .scoring import logits_to_logprobs, prepare_batch


class Generator:
    '''
    Adapted from https://github.com/togethercomputer/stripedhyena.

    Modifications include:
    - `generate()` accepts and returns the recurrent cache state, letting the user
      keep track of it across sampling runs.
    - Able to sample with long token prompts in which the cache is initialized with
      recurrent teacher forcing.
    '''
    def __init__(
        self,
        model: StripedHyena,
        tokenizer: CharLevelTokenizer,
        top_k: int = 50,
        top_p: float = 0.7,
        temperature: float = 1.,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.untils = ['\n\n']

    def generate(
        self,
        device: str,
        input_string: str = None,
        input_ids: torch.tensor = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int = 128,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        max_seqlen: int = None,
        inference_params_dict: dict = None,
    ) -> Tuple[torch.tensor, torch.tensor, dict]:
        """
        A version of the generate() method that enables passing in and that returns the
        `inference_params_dict` for replaying cached sampling from a given state.
        """
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            # is a tensor
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        if input_ids is None:
            input = self.tokenizer.tokenize(input_string)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            # is a tensor
            else:
                input = input.unsqueeze(0).to(device)

        else:
            input = input_ids
        x = input

        if max_seqlen is not None:
            x = x[:, -max_seqlen:]

        num_tokens = int(num_tokens)
        batch_size = x.shape[0]

        prompt_length = x.shape[1]
        prompt_forcing = prompt_length > force_prompt_threshold
        if prompt_forcing:
            forced_prompt_length = prompt_length - force_prompt_threshold
            x_force = x[:, force_prompt_threshold:]
            x = x[:, :force_prompt_threshold]
        else:
            forced_prompt_length = 0

        generation = torch.empty(
            x.shape[0],
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )

        scores = torch.empty(
            x.shape[0],
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        if inference_params_dict is not None:
            cached_generation = True
            prefilled = True
            # Ensure that the cached data is loaded on the correct device.
            for key, data in inference_params_dict['mha'].key_value_memory_dict.items():
                inference_params_dict['mha'].key_value_memory_dict[key] = data.to(x.device)
            for key, data in inference_params_dict['hyena'].fir_state_dict.items():
                inference_params_dict['hyena'].fir_state_dict[key] = data.to(x.device)
            for key, data in inference_params_dict['hyena'].state_dict.items():
                inference_params_dict['hyena'].state_dict[key] = data.to(x.device)

        elif cached_generation:
            inference_params_dict = self.model.initialize_inference_params()
            inference_params_dict['mha'].max_batch_size = batch_size
            inference_params_dict['hyena'].max_batch_size = batch_size
            prefilled = False

        if verbose:
            mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f'Memory after tokenization: {mem_after_tok} GB')
            print('Starting generation...')
            if input_string is not None:
                print('Prompt: ' + input_string)
            else:
                print(f'Prompt ids: {input_ids} {input_ids.shape}')

        for i in range(forced_prompt_length + num_tokens):
            if prefilled:
                post_prefill = True
            else:
                post_prefill = cached_generation and i > 0

            # prefill then process only the last token
            if post_prefill:
                x = x[:, -1:]
                seqlen_offset = inference_params_dict['mha'].seqlen_offset

                if seqlen_offset == 0:
                    seqlen_offset = input.shape[-1]
                    inference_params_dict['hyena'].seqlen_offset = seqlen_offset
                    inference_params_dict['mha'].seqlen_offset = seqlen_offset
                else:
                    inference_params_dict['mha'].seqlen_offset += 1
                    inference_params_dict['hyena'].seqlen_offset += 1

            # do forward pass with no gradient
            with torch.inference_mode():
                logits, inference_params_dict = self.model(
                    x,
                    inference_params_dict=inference_params_dict,
                )

            last_logits = logits[:, -1]

            if prompt_forcing and i < forced_prompt_length:
                new_idx = x_force[:, i]
            else:
                new_idx = sample(
                    last_logits,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                )

            if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
                print('Stopping generation at EOS')

            if print_generation and verbose and batch_size == 1:
                print(
                    f'{self.tokenizer.detokenize([new_idx.item()])}',
                    end=' ',
                )

            if prompt_forcing:
                if i >= forced_prompt_length:
                    scores[:, i - forced_prompt_length] = last_logits
                    generation[:, i - forced_prompt_length] = new_idx
            else:
                scores[:, i] = last_logits
                generation[:, i] = new_idx

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

        if verbose:
            y = self.tokenizer.detokenize_batch(generation[:, : i + 1])

            for until in self.untils:
                if until in y:
                    y = y.split(until)[0]
                    break

            print(f'\nInput: {input_string}, Output: {y}')

            mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
            print(f'Memory after generation: {mem_end} GB')

        return generation[:, : i + 1], scores[:, : i + 1], inference_params_dict


def generate(
    prompt_seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    n_tokens: int = 100,
    temperature: float = 0.,
    top_k: int = 1,
    top_p: float = 1.,
    batched: bool = True,
    prepend_bos: bool = False,
    cached_generation: bool = False,
    force_prompt_threshold: int = 128,
    verbose: int = 1,
    device: str = 'cuda:0',
    **kwargs,
) -> Tuple[List[str], List[float]]:
    """
    Performs generation from a list of prompts.
    If all prompts are the same length, this can do batched generation.
    Also supports cached generation for efficient sampling.
    """
    model.eval()

    g = Generator(
        model,
        tokenizer,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

    uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

    if batched and uniform_lengths:
        input_ids_list = [
            prepare_batch(
                prompt_seqs,
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
        ]
    else:
        if verbose:
            if not uniform_lengths:
                sys.stderr.write('Note: Prompts are of different lengths.\n')
            sys.stderr.write('Note: Will not do batched generation.\n')
        input_ids_list = [
            prepare_batch(
                [ prompt_seq ],
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for prompt_seq in prompt_seqs
        ]

    generated_seqs, generated_scores = [], []
    for input_ids in input_ids_list:
        batch_size = input_ids.shape[0]

        output_ids, logits, _ = g.generate(
            input_ids=input_ids,
            num_tokens=n_tokens,
            cached_generation=cached_generation,
            force_prompt_threshold=force_prompt_threshold,
            device=device,
            print_generation=(verbose > 1),
            verbose=(verbose > 1),
            stop_at_eos=False,
        )
        if verbose > 1:
            print('input_ids.shape', input_ids.shape)
            print('output_ids.shape', output_ids.shape)
            print('logits.shape', logits.shape)

        generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
        assert len(generated_seqs_batch) == batch_size
        generated_seqs += generated_seqs_batch

        logprobs = logits_to_logprobs(logits, output_ids)
        logprobs = logprobs.float().cpu().numpy()

        generated_scores += [ np.mean(logprobs[idx]) for idx in range(batch_size) ]

    assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
    if verbose:
        for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
            print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

    return generated_seqs, generated_scores
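The key modification called out in the `Generator` docstring is that the recurrent cache (`inference_params_dict`) round-trips through `generate()`. A sketch of how that is presumably meant to be used to resume sampling from a saved state without re-prefilling; `model`, `tokenizer`, and `device` from the setup above are assumed, and the prompt and token counts are illustrative:

from evo.generation import Generator

g = Generator(model, tokenizer, top_k=4, top_p=1.0, temperature=1.0)

# First call prefills the prompt and returns the cache alongside the samples.
ids1, logits1, cache = g.generate(
    device=device, input_string='ATGAACAAGAAC',
    num_tokens=64, cached_generation=True, verbose=False,
)

# Second call passes the cache back in and continues from the saved state,
# feeding only the most recently sampled token.
ids2, logits2, cache = g.generate(
    device=device, input_ids=ids1[:, -1:],
    num_tokens=64, inference_params_dict=cache, verbose=False,
)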
evo/models.py
ADDED
@@ -0,0 +1,122 @@
import pkgutil
import re
from transformers import AutoConfig, AutoModelForCausalLM
import yaml

from stripedhyena.utils import dotdict
from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer


MODEL_NAMES = [
    'evo-1.5-8k-base',
    'evo-1-8k-base',
    'evo-1-131k-base',
    'evo-1-8k-crispr',
    'evo-1-8k-transposon',
]

class Evo:
    def __init__(self, model_name: str = MODEL_NAMES[1], device: str = None):
        """
        Loads an Evo model checkpoint given a model name.
        If the checkpoint does not exist, we automatically download it from HuggingFace.
        """
        self.device = device

        # Check model name.

        if model_name not in MODEL_NAMES:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # Assign config path.

        if model_name == 'evo-1-8k-base' or \
           model_name == 'evo-1-8k-crispr' or \
           model_name == 'evo-1-8k-transposon' or \
           model_name == 'evo-1.5-8k-base':
            config_path = 'configs/evo-1-8k-base_inference.yml'
        elif model_name == 'evo-1-131k-base':
            config_path = 'configs/evo-1-131k-base_inference.yml'
        else:
            raise ValueError(
                f'Invalid model name {model_name}. Should be one of: '
                f'{", ".join(MODEL_NAMES)}.'
            )

        # Load model.

        self.model = load_checkpoint(
            model_name=model_name,
            config_path=config_path,
            device=self.device
        )

        # Load tokenizer.

        self.tokenizer = CharLevelTokenizer(512)


HF_MODEL_NAME_MAP = {
    'evo-1.5-8k-base': 'evo-design/evo-1.5-8k-base',
    'evo-1-8k-base': 'togethercomputer/evo-1-8k-base',
    'evo-1-131k-base': 'togethercomputer/evo-1-131k-base',
    'evo-1-8k-crispr': 'LongSafari/evo-1-8k-crispr',
    'evo-1-8k-transposon': 'LongSafari/evo-1-8k-transposon',
}

def load_checkpoint(
    model_name: str = MODEL_NAMES[1],
    config_path: str = 'evo/configs/evo-1-131k-base_inference.yml',
    device: str = None,
    *args, **kwargs
):
    """
    Load checkpoint from HuggingFace and place it into SH model.
    """

    # Map model name to HuggingFace model name.

    hf_model_name = HF_MODEL_NAME_MAP[model_name]

    # Load model config.

    model_config = AutoConfig.from_pretrained(
        hf_model_name,
        trust_remote_code=True,
        revision='1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main',
    )
    model_config.use_cache = True

    # Load model.

    model = AutoModelForCausalLM.from_pretrained(
        hf_model_name,
        config=model_config,
        trust_remote_code=True,
        revision='1.1_fix' if re.match(r'evo-1-.*-base', model_name) else 'main',
    )

    # Load model state dict & cleanup.

    state_dict = model.backbone.state_dict()
    del model
    del model_config

    # Load SH config.

    config = yaml.safe_load(pkgutil.get_data(__name__, config_path))
    global_config = dotdict(config, Loader=yaml.FullLoader)

    # Load SH Model.

    model = StripedHyena(global_config)
    model.load_state_dict(state_dict, strict=True)
    model.to_bfloat16_except_poles_residues()
    if device is not None:
        model = model.to(device)

    return model
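`load_checkpoint` is also usable on its own when only the StripedHyena module is wanted; a minimal sketch (the config path is resolved relative to the evo package via `pkgutil.get_data`, matching the paths the `Evo` class passes in):

from evo.models import load_checkpoint
from stripedhyena.tokenizer import CharLevelTokenizer

model = load_checkpoint(
    model_name='evo-1-8k-base',
    config_path='configs/evo-1-8k-base_inference.yml',  # relative to the evo package
    device='cuda:0',
)
tokenizer = CharLevelTokenizer(512)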
evo/scoring.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
from stripedhyena.model import StripedHyena
|
| 6 |
+
from stripedhyena.tokenizer import CharLevelTokenizer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def prepare_batch(
|
| 10 |
+
seqs: List[str],
|
| 11 |
+
tokenizer: CharLevelTokenizer,
|
| 12 |
+
prepend_bos: bool = True,
|
| 13 |
+
device: str = 'cuda:0'
|
| 14 |
+
) -> Tuple[torch.Tensor, List[int]]:
|
| 15 |
+
"""
|
| 16 |
+
Takes in a list of sequences, tokenizes them, and puts them in a tensor batch.
|
| 17 |
+
If the sequences have differing lengths, then pad up to the maximum sequence length.
|
| 18 |
+
"""
|
| 19 |
+
seq_lengths = [ len(seq) for seq in seqs ]
|
| 20 |
+
max_seq_length = max(seq_lengths)
|
| 21 |
+
|
| 22 |
+
input_ids = []
|
| 23 |
+
for seq in seqs:
|
| 24 |
+
padding = [tokenizer.pad_id] * (max_seq_length - len(seq))
|
| 25 |
+
input_ids.append(
|
| 26 |
+
torch.tensor(
|
| 27 |
+
([tokenizer.eod_id] * int(prepend_bos)) + tokenizer.tokenize(seq) + padding,
|
| 28 |
+
dtype=torch.long,
|
| 29 |
+
).to(device).unsqueeze(0)
|
| 30 |
+
)
|
| 31 |
+
input_ids = torch.cat(input_ids, dim=0)
|
| 32 |
+
|
| 33 |
+
return input_ids, seq_lengths
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def logits_to_logprobs(
|
| 37 |
+
logits: torch.Tensor,
|
| 38 |
+
input_ids: torch.Tensor,
|
| 39 |
+
trim_bos: bool = True,
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
"""
|
| 42 |
+
Takes in a tensor of logits of dimension (batch, length, vocab).
|
| 43 |
+
Computes the log-likelihoods using a softmax along the vocab dimension.
|
| 44 |
+
Uses the `input_ids` to index into the log-likelihoods and returns the likelihood
|
| 45 |
+
of the provided sequence at each position with dimension (batch, length).
|
| 46 |
+
"""
|
| 47 |
+
softmax_logprobs = torch.log_softmax(logits, dim=-1)
|
| 48 |
+
if trim_bos:
|
| 49 |
+
softmax_logprobs = softmax_logprobs[:, :-1] # Remove last prediction.
|
| 50 |
+
input_ids = input_ids[:, 1:] # Trim BOS added by tokenizer.
|
| 51 |
+
assert(softmax_logprobs.shape[1] == input_ids.shape[1])
|
| 52 |
+
|
| 53 |
+
logprobs = torch.gather(
|
| 54 |
+
softmax_logprobs, # Gather likelihoods...
|
| 55 |
+
2, # along the vocab dimension...
|
| 56 |
+
input_ids.unsqueeze(-1) # using the token ids to index.
|
| 57 |
+
).squeeze(-1)
|
| 58 |
+
|
| 59 |
+
return logprobs
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def score_sequences(
|
| 63 |
+
seqs: List[str],
|
| 64 |
+
model: StripedHyena,
|
| 65 |
+
tokenizer: CharLevelTokenizer,
|
| 66 |
+
reduce_method: str = 'mean',
|
| 67 |
+
device: str = 'cuda:0',
|
| 68 |
+
) -> List[float]:
|
| 69 |
+
"""
|
| 70 |
+
Computes the model log-likelihood scores for sequences in `seqs`.
|
| 71 |
+
Uses `reduce_method` to take the mean or sum across the likelihoods at each
|
| 72 |
+
position (default: `'mean'`).
|
| 73 |
+
|
| 74 |
+
Returns a list of scalar scores corresponding to the reduced log-likelihoods for
|
| 75 |
+
each sequence.
|
| 76 |
+
"""
|
| 77 |
+
input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
|
| 78 |
+
assert(len(seq_lengths) == input_ids.shape[0])
|
| 79 |
+
|
| 80 |
+
with torch.inference_mode():
|
| 81 |
+
logits, _ = model(input_ids) # (batch, length, vocab)
|
| 82 |
+
|
| 83 |
+
logprobs = logits_to_logprobs(logits, input_ids, trim_bos=True)
|
| 84 |
+
logprobs = logprobs.float().cpu().numpy()
|
| 85 |
+
|
| 86 |
+
if reduce_method == 'mean':
|
| 87 |
+
reduce_func = np.mean
|
| 88 |
+
elif reduce_method == 'sum':
|
| 89 |
+
reduce_func = np.sum
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError(f'Invalid reduce_method {reduce_method}')
|
| 92 |
+
|
| 93 |
+
return [
|
| 94 |
+
reduce_func(logprobs[idx][:seq_lengths[idx]])
|
| 95 |
+
for idx in range(len(seq_lengths))
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def positional_entropies(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    device: str = 'cuda:0',
) -> List[np.ndarray]:
    """
    Computes the positional entropies for sequences in `seqs`.

    Returns a list of arrays, one per sequence, where each array has the same
    length as the corresponding sequence. Each entry is the per-position
    entropy of the model's predictive distribution across the vocab dimension.
    """
    input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)  # (batch, length, vocab)

    # Tokenizer prepends BOS, remember to remove last prediction.
    softmax_logprobs = torch.log_softmax(logits, dim=-1)[:, :-1]

    entropies = -torch.sum(torch.exp(softmax_logprobs) * softmax_logprobs, dim=-1)
    entropies = entropies.float().cpu().numpy()

    sequence_entropies = [
        entropies[idx][:seq_lengths[idx]] for idx in range(len(seq_lengths))
    ]
    assert all(
        len(seq) == len(entropy) for seq, entropy in zip(seqs, sequence_entropies)
    )

    return sequence_entropies
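Illustrative only: assuming the same `model` and `tokenizer` as above, high-entropy positions flag where the model is least certain:

entropies = positional_entropies(['ACGTACGT'], model, tokenizer)
# entropies[0] has length 8; values range from 0 (fully confident)
# to log(vocab_size) nats (a uniform prediction).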
evo/utils.py
ADDED
@@ -0,0 +1,183 @@
import numpy as np
import pandas as pd
from typing import Callable


NTs = 'ACGT'

AAs = 'ACDEFGHIKLMNPQRSTVWY'

AA_TO_CODON = {
    '*': ['TAA', 'TAG', 'TGA'],  # Stop.
    'A': ['GCT', 'GCC', 'GCA', 'GCG'],  # Ala.
    'C': ['TGT', 'TGC'],  # Cys.
    'D': ['GAT', 'GAC'],  # Asp.
    'E': ['GAA', 'GAG'],  # Glu.
    'F': ['TTT', 'TTC'],  # Phe.
    'G': ['GGT', 'GGC', 'GGA', 'GGG'],  # Gly.
    'H': ['CAT', 'CAC'],  # His.
    'I': ['ATT', 'ATC', 'ATA'],  # Ile.
    'K': ['AAA', 'AAG'],  # Lys.
    'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],  # Leu.
    'M': ['ATG'],  # Met.
    'N': ['AAT', 'AAC'],  # Asn.
    'P': ['CCT', 'CCC', 'CCA', 'CCG'],  # Pro.
    'Q': ['CAA', 'CAG'],  # Gln.
    'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],  # Arg.
    'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],  # Ser.
    'T': ['ACT', 'ACC', 'ACA', 'ACG'],  # Thr.
    'V': ['GTT', 'GTC', 'GTA', 'GTG'],  # Val.
    'W': ['TGG'],  # Trp.
    'Y': ['TAT', 'TAC'],  # Tyr.
}

CODON_TO_AA = {
    codon: aa
    for aa, codon_list in AA_TO_CODON.items()
    for codon in codon_list
}

AA_3_TO_1 = {
    "Ala": "A",  # Alanine
    "Arg": "R",  # Arginine
    "Asn": "N",  # Asparagine
    "Asp": "D",  # Aspartic acid
    "Cys": "C",  # Cysteine
    "Gln": "Q",  # Glutamine
    "Glu": "E",  # Glutamic acid
    "Gly": "G",  # Glycine
    "His": "H",  # Histidine
    "Ile": "I",  # Isoleucine
    "Leu": "L",  # Leucine
    "Lys": "K",  # Lysine
    "Met": "M",  # Methionine
    "Phe": "F",  # Phenylalanine
    "Pro": "P",  # Proline
    "Ser": "S",  # Serine
    "Thr": "T",  # Threonine
    "Trp": "W",  # Tryptophan
    "Tyr": "Y",  # Tyrosine
    "Val": "V",  # Valine
}
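As a quick illustration of how these tables compose (hypothetical helper, not part of this module):

def translate(dna: str) -> str:
    # Translate an in-frame DNA sequence; stop codons render as '*'.
    return ''.join(
        CODON_TO_AA[dna[i:i + 3]]
        for i in range(0, len(dna) - len(dna) % 3, 3)
    )

assert translate('ATGGGTTAA') == 'MG*'  # Met-Gly-Stop.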
def nucleotide_deep_mutational_scan(sequence: str, ignore_wt: bool = True):
    """
    Yields a (wild-type, mutant, position) tuple for every possible
    single-nucleotide substitution in `sequence`.
    """
    for idx, wt in enumerate(sequence):
        for mt in NTs:
            if ignore_wt and wt == mt:
                continue
            yield (wt, mt, idx)
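Illustrative only: a sketch of wiring this generator to the scoring utilities (`score_sequences` comes from evo/scoring.py; `model` and `tokenizer` are assumed to be in scope):

wt_seq = 'ACGTACGT'
mutants, labels = [], []
for wt, mt, idx in nucleotide_deep_mutational_scan(wt_seq):
    mutants.append(wt_seq[:idx] + mt + wt_seq[idx + 1:])
    labels.append(f'{wt}{idx + 1}{mt}')
# scores = score_sequences(mutants, model, tokenizer)  # One score per mutant.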
def parse_blast_output(output_path: str) -> pd.DataFrame:
    """
    Parses standard blast tabular output (`-outfmt 6`).
    """
    # blast default tabular output fields (12 columns).
    blast_table_header = [
        'qacc', 'sacc', 'pident', 'length', 'mismatch', 'gapopen', 'qstart',
        'qend', 'sstart', 'send', 'evalue', 'bitscore',
    ]

    data = []
    with open(output_path, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            if line.strip() == '':
                continue
            fields = line.strip().split()
            data.append(dict(zip(blast_table_header, fields)))

    df = pd.DataFrame(data)
    if len(df) == 0:
        return df
    df['evalue'] = df['evalue'].astype(float)

    return df
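Illustrative only: output in this format would typically come from a command such as `blastn -query query.fa -db somedb -outfmt 6 -out hits.txt` (file and database names here are placeholders), after which:

df = parse_blast_output('hits.txt')
best_hits = df.sort_values('evalue').groupby('qacc').head(1)  # Top hit per query.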
def parse_erpin_output(output_path: str, name: str) -> pd.DataFrame:
    """
    Parses ERPIN output. For an example, see `eval/data/example_rho_output.txt`.
    """
    # ERPIN format output fields.
    output_fields = ['strand', 'index', 'interval', 'score', 'evalue']

    data = []
    with open(output_path, 'r') as f:
        for line in f:
            if line.startswith(f'>{name}'):
                meta = dict(zip(output_fields, f.readline().rstrip().split()))
                sequence = f.readline().rstrip()
                start, end = meta['interval'].split('..')
                data.append([
                    f"{name}_{meta['index']}_{meta['strand']}",
                    sequence,
                    int(start),
                    int(end),
                    '+' if meta['strand'] == 'FW' else '-',
                    meta['score'],
                    float(meta['evalue']),
                ])

    return pd.DataFrame(
        data,
        columns=[
            'id',
            'seq',
            'start',
            'end',
            'strand',
            'score',
            'evalue',
        ],
    )
def parse_hmmsearch_output(output_path: str) -> pd.DataFrame:
    """
    Parses standard hmmsearch per-domain tabular output (`--domtblout`).
    """
    # hmmsearch per-domain table output fields.
    hmmsearch_table_header = [
        'target', 'target_acc', 'tlen', 'query', 'query_acc', 'qlen',
        'evalue', 'score', 'bias', 'num', 'of', 'cevalue', 'ievalue',
        'dscore', 'dbias', 'hmm_from', 'hmm_to', 'ali_from', 'ali_to',
        'env_from', 'env_to', 'acc', 'desc',
    ]

    data = []
    with open(output_path, 'r') as f:
        for line in f:
            if line.startswith('#'):
                continue
            fields = line.strip().split()
            # Note: whitespace splitting truncates `desc` at its first word.
            data.append(dict(zip(hmmsearch_table_header, fields)))

    return pd.DataFrame(data)
def permutation_test(
    score_func: Callable[[np.ndarray, np.ndarray], float],
    x1: np.ndarray,
    x2: np.ndarray,
    n_permutations: int = 100_000,
) -> float:
    """
    Returns a permutation-based P value. Computes the null distribution by
    shuffling `x2` and recomputing `score_func`.
    """
    if n_permutations < 1:
        raise ValueError('Number of permutations must be positive.')

    x1, x2 = np.array(x1), np.array(x2)

    observed_score = score_func(x1, x2)

    null_distribution = np.array([
        score_func(x1, np.random.permutation(x2))
        for _ in range(n_permutations)
    ])

    # One-sided P value: the fraction of null scores at least as large as
    # the observed score.
    return np.mean(null_distribution >= observed_score)
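Illustrative only: because the null is built by permuting `x2` alone, `score_func` should depend on the pairing between the two arrays (a statistic such as a difference of means, which is invariant to shuffling one array, would give a degenerate null distribution). For example, testing a correlation:

x = np.random.randn(200)
y = x + np.random.randn(200)  # Correlated with x by construction.
p = permutation_test(
    lambda a, b: np.corrcoef(a, b)[0, 1],  # Pearson correlation as the statistic.
    x, y,
    n_permutations=10_000,
)
# A small p indicates the observed correlation is unlikely under shuffled pairings.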
evo/version.py
ADDED
@@ -0,0 +1 @@
version = '0.4'
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio==4.44.0
torch==2.1.0
numpy==1.24.3
transformers==4.36.0
einops==0.7.0
pyyaml==6.0.1
git+https://github.com/togethercomputer/stripedhyena.git
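Setup note (illustrative): the pinned environment installs with `pip install -r requirements.txt`; the final entry builds stripedhyena directly from GitHub, so `git` must be available.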