Spaces:
Sleeping
Sleeping
Deploy Streamlit UI
Browse files- .dockerignore +16 -0
- .gitattributes +3 -0
- .github/ISSUE_TEMPLATE/bug_report.md +38 -0
- .github/ISSUE_TEMPLATE/feature_request.md +17 -0
- .github/ISSUE_TEMPLATE/other.md +10 -0
- .github/workflows/ci.yml +39 -0
- .gitignore +230 -0
- Benchmark 80 sequences.xlsx +3 -0
- CODE_OF_CONDUCT.md +128 -0
- CodonTransformer/CodonData.py +682 -0
- CodonTransformer/CodonEvaluation.py +583 -0
- CodonTransformer/CodonJupyter.py +311 -0
- CodonTransformer/CodonPostProcessing.py +83 -0
- CodonTransformer/CodonPrediction.py +1372 -0
- CodonTransformer/CodonUtils.py +871 -0
- CodonTransformer/__init__.py +1 -0
- Dockerfile +21 -0
- ENCOT_Academic_Documentation.html +2625 -0
- ENCOT_Code_Showcase.html +791 -0
- LICENSE +201 -0
- Makefile +9 -0
- README.md +495 -10
- app.py +12 -0
- benchmark_evaluation.py +695 -0
- comprehensive_model_comparison.png +3 -0
- configs/train_ecoli_alm.yaml +54 -0
- configs/train_ecoli_quick.yaml +37 -0
- create_model_datasets.py +42 -0
- evaluate_optimizer.py +577 -0
- prepare_ecoli_data.py +69 -0
- pretrain.py +232 -0
- pyproject.toml +62 -0
- requirements.txt +29 -0
- scripts/optimize_sequence.py +383 -0
- scripts/preprocess_data.py +251 -0
- scripts/run_benchmarks.py +235 -0
- scripts/train.py +228 -0
- setup.py +40 -0
- src/CodonTransformer_inference_template.xlsx +0 -0
- src/__init__.py +1 -0
- src/banner_final.png +3 -0
- src/organism2id.pkl +3 -0
- streamlit_app.py +16 -0
- streamlit_gui/app.py +1456 -0
- streamlit_gui/demo.py +288 -0
- streamlit_gui/requirements.txt +20 -0
- streamlit_gui/run_gui.py +102 -0
- streamlit_gui/test_gui.py +321 -0
.dockerignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.github
|
| 3 |
+
.venv
|
| 4 |
+
__pycache__
|
| 5 |
+
*.pyc
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
*.log
|
| 9 |
+
*.ipynb
|
| 10 |
+
.devcontainer
|
| 11 |
+
data
|
| 12 |
+
notebooks
|
| 13 |
+
tests
|
| 14 |
+
slurm
|
| 15 |
+
Benchmark 80 sequences.xlsx
|
| 16 |
+
comprehensive_model_comparison.png
|
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Benchmark[[:space:]]80[[:space:]]sequences.xlsx filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
comprehensive_model_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
src/banner_final.png filter=lfs diff=lfs merge=lfs -text
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Bug report
|
| 3 |
+
about: Create a report to help us improve
|
| 4 |
+
title: ''
|
| 5 |
+
labels: ''
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Describe the bug**
|
| 11 |
+
A clear and concise description of what the bug is.
|
| 12 |
+
|
| 13 |
+
**To Reproduce**
|
| 14 |
+
Steps to reproduce the behavior:
|
| 15 |
+
1. Go to '...'
|
| 16 |
+
2. Click on '....'
|
| 17 |
+
3. Scroll down to '....'
|
| 18 |
+
4. See error
|
| 19 |
+
|
| 20 |
+
**Expected behavior**
|
| 21 |
+
A clear and concise description of what you expected to happen.
|
| 22 |
+
|
| 23 |
+
**Screenshots**
|
| 24 |
+
If applicable, add screenshots to help explain your problem.
|
| 25 |
+
|
| 26 |
+
**Desktop (please complete the following information):**
|
| 27 |
+
- OS: [e.g. iOS]
|
| 28 |
+
- Browser [e.g. chrome, safari]
|
| 29 |
+
- Version [e.g. 22]
|
| 30 |
+
|
| 31 |
+
**Smartphone (please complete the following information):**
|
| 32 |
+
- Device: [e.g. iPhone6]
|
| 33 |
+
- OS: [e.g. iOS8.1]
|
| 34 |
+
- Browser [e.g. stock browser, safari]
|
| 35 |
+
- Version [e.g. 22]
|
| 36 |
+
|
| 37 |
+
**Additional context**
|
| 38 |
+
Add any other context about the problem here.
|
.github/ISSUE_TEMPLATE/feature_request.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Feature request
|
| 3 |
+
about: Suggest an idea for this project
|
| 4 |
+
title: ''
|
| 5 |
+
labels: enhancement
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Is your feature request related to a problem? Please describe.**
|
| 11 |
+
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
| 12 |
+
|
| 13 |
+
**Describe the solution you'd like**
|
| 14 |
+
A clear and concise description of what you want to happen.
|
| 15 |
+
|
| 16 |
+
**Additional context**
|
| 17 |
+
Add any other context or screenshots about the feature request here.
|
.github/ISSUE_TEMPLATE/other.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Other
|
| 3 |
+
about: Any other issue
|
| 4 |
+
title: ''
|
| 5 |
+
labels: bug
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Describe your issue here**
|
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# .github/workflows/ci.yml
|
| 2 |
+
|
| 3 |
+
name: CI
|
| 4 |
+
|
| 5 |
+
on: [push, pull_request]
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
test:
|
| 9 |
+
runs-on: ubuntu-latest
|
| 10 |
+
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout code
|
| 13 |
+
uses: actions/checkout@v4
|
| 14 |
+
|
| 15 |
+
- name: Set up Python
|
| 16 |
+
uses: actions/setup-python@v5
|
| 17 |
+
with:
|
| 18 |
+
python-version: '3.10'
|
| 19 |
+
|
| 20 |
+
- name: Install dependencies
|
| 21 |
+
run: |
|
| 22 |
+
python -m pip install --upgrade pip
|
| 23 |
+
pip install -r requirements.txt
|
| 24 |
+
pip install "coverage[toml]"
|
| 25 |
+
|
| 26 |
+
- name: Run tests with coverage
|
| 27 |
+
run: |
|
| 28 |
+
make test_with_coverage
|
| 29 |
+
coverage report
|
| 30 |
+
coverage xml
|
| 31 |
+
|
| 32 |
+
- name: Upload coverage to Codecov
|
| 33 |
+
uses: codecov/codecov-action@v4
|
| 34 |
+
with:
|
| 35 |
+
token: ${{ secrets.CODECOV_TOKEN }}
|
| 36 |
+
file: coverage.xml
|
| 37 |
+
flags: unittests
|
| 38 |
+
name: codecov-umbrella
|
| 39 |
+
fail_ci_if_error: true
|
.gitignore
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
codon_env/
|
| 131 |
+
|
| 132 |
+
# Spyder project settings
|
| 133 |
+
.spyderproject
|
| 134 |
+
.spyproject
|
| 135 |
+
|
| 136 |
+
# Rope project settings
|
| 137 |
+
.ropeproject
|
| 138 |
+
|
| 139 |
+
# mkdocs documentation
|
| 140 |
+
/site
|
| 141 |
+
|
| 142 |
+
# mypy
|
| 143 |
+
.mypy_cache/
|
| 144 |
+
.dmypy.json
|
| 145 |
+
dmypy.json
|
| 146 |
+
|
| 147 |
+
# Pyre type checker
|
| 148 |
+
.pyre/
|
| 149 |
+
|
| 150 |
+
# pytype static type analyzer
|
| 151 |
+
.pytype/
|
| 152 |
+
|
| 153 |
+
# Cython debug symbols
|
| 154 |
+
cython_debug/
|
| 155 |
+
|
| 156 |
+
# PyCharm
|
| 157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 161 |
+
#.idea/
|
| 162 |
+
|
| 163 |
+
# Coverage reports
|
| 164 |
+
coverage.xml
|
| 165 |
+
|
| 166 |
+
# Jupyter Notebook checkpoints
|
| 167 |
+
.ipynb_checkpoints/
|
| 168 |
+
|
| 169 |
+
# Temporary files
|
| 170 |
+
*.tmp
|
| 171 |
+
*.temp
|
| 172 |
+
|
| 173 |
+
# PyTorch Lightning checkpoints
|
| 174 |
+
lightning_logs/
|
| 175 |
+
|
| 176 |
+
# PyTorch model weights
|
| 177 |
+
*.pth
|
| 178 |
+
*.pt
|
| 179 |
+
|
| 180 |
+
# Large files excluded from Git
|
| 181 |
+
models/ecoli-codon-optimizer/finetune.ckpt
|
| 182 |
+
models/ecoli-codon-optimizer/finetune_best.ckpt
|
| 183 |
+
data/ecoli_processed_genes.csv
|
| 184 |
+
|
| 185 |
+
# Finetune-related files (keep local only)
|
| 186 |
+
finetune.py
|
| 187 |
+
checkpoints/
|
| 188 |
+
*.safetensors
|
| 189 |
+
|
| 190 |
+
# Benchmark and validation results
|
| 191 |
+
benchmark_plots/
|
| 192 |
+
cai_tai_benchmark.csv
|
| 193 |
+
synthetic_validation.csv
|
| 194 |
+
test_set_validation.csv
|
| 195 |
+
|
| 196 |
+
# Large data files
|
| 197 |
+
*.csv
|
| 198 |
+
*.jsonl
|
| 199 |
+
*.json
|
| 200 |
+
*.fasta
|
| 201 |
+
*.fa
|
| 202 |
+
*.ckpt
|
| 203 |
+
|
| 204 |
+
# Results and outputs
|
| 205 |
+
results/
|
| 206 |
+
outputs/
|
| 207 |
+
logs/
|
| 208 |
+
|
| 209 |
+
# Model files and weights
|
| 210 |
+
*.bin
|
| 211 |
+
*.safetensors
|
| 212 |
+
|
| 213 |
+
# CUDA and GPU related
|
| 214 |
+
*.run
|
| 215 |
+
cuda_installer.pyz
|
| 216 |
+
|
| 217 |
+
# R files
|
| 218 |
+
.RData
|
| 219 |
+
.Rhistory
|
| 220 |
+
|
| 221 |
+
# OS generated files
|
| 222 |
+
.DS_Store
|
| 223 |
+
.DS_Store?
|
| 224 |
+
._*
|
| 225 |
+
.Spotlight-V100
|
| 226 |
+
.Trashes
|
| 227 |
+
ehthumbs.db
|
| 228 |
+
Thumbs.db
|
| 229 |
+
research/
|
| 230 |
+
models/alm-enhanced-training/balanced_alm_finetune.ckpt
|
Benchmark 80 sequences.xlsx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f80bde88a31e80ac34b0827180b50d112f1d26bdf691c8118943e91c0e3b09e2
|
| 3 |
+
size 179471
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Covenant Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
We as members, contributors, and leaders pledge to make participation in our
|
| 6 |
+
community a harassment-free experience for everyone, regardless of age, body
|
| 7 |
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
| 8 |
+
identity and expression, level of experience, education, socio-economic status,
|
| 9 |
+
nationality, personal appearance, race, religion, or sexual identity
|
| 10 |
+
and orientation.
|
| 11 |
+
|
| 12 |
+
We pledge to act and interact in ways that contribute to an open, welcoming,
|
| 13 |
+
diverse, inclusive, and healthy community.
|
| 14 |
+
|
| 15 |
+
## Our Standards
|
| 16 |
+
|
| 17 |
+
Examples of behavior that contributes to a positive environment for our
|
| 18 |
+
community include:
|
| 19 |
+
|
| 20 |
+
* Demonstrating empathy and kindness toward other people
|
| 21 |
+
* Being respectful of differing opinions, viewpoints, and experiences
|
| 22 |
+
* Giving and gracefully accepting constructive feedback
|
| 23 |
+
* Accepting responsibility and apologizing to those affected by our mistakes,
|
| 24 |
+
and learning from the experience
|
| 25 |
+
* Focusing on what is best not just for us as individuals, but for the
|
| 26 |
+
overall community
|
| 27 |
+
|
| 28 |
+
Examples of unacceptable behavior include:
|
| 29 |
+
|
| 30 |
+
* The use of sexualized language or imagery, and sexual attention or
|
| 31 |
+
advances of any kind
|
| 32 |
+
* Trolling, insulting or derogatory comments, and personal or political attacks
|
| 33 |
+
* Public or private harassment
|
| 34 |
+
* Publishing others' private information, such as a physical or email
|
| 35 |
+
address, without their explicit permission
|
| 36 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 37 |
+
professional setting
|
| 38 |
+
|
| 39 |
+
## Enforcement Responsibilities
|
| 40 |
+
|
| 41 |
+
Community leaders are responsible for clarifying and enforcing our standards of
|
| 42 |
+
acceptable behavior and will take appropriate and fair corrective action in
|
| 43 |
+
response to any behavior that they deem inappropriate, threatening, offensive,
|
| 44 |
+
or harmful.
|
| 45 |
+
|
| 46 |
+
Community leaders have the right and responsibility to remove, edit, or reject
|
| 47 |
+
comments, commits, code, wiki edits, issues, and other contributions that are
|
| 48 |
+
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
| 49 |
+
decisions when appropriate.
|
| 50 |
+
|
| 51 |
+
## Scope
|
| 52 |
+
|
| 53 |
+
This Code of Conduct applies within all community spaces, and also applies when
|
| 54 |
+
an individual is officially representing the community in public spaces.
|
| 55 |
+
Examples of representing our community include using an official e-mail address,
|
| 56 |
+
posting via an official social media account, or acting as an appointed
|
| 57 |
+
representative at an online or offline event.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported to the community leaders responsible for enforcement at
|
| 63 |
+
Adibvafa.fallahpour@mail.utoronto.ca.
|
| 64 |
+
All complaints will be reviewed and investigated promptly and fairly.
|
| 65 |
+
|
| 66 |
+
All community leaders are obligated to respect the privacy and security of the
|
| 67 |
+
reporter of any incident.
|
| 68 |
+
|
| 69 |
+
## Enforcement Guidelines
|
| 70 |
+
|
| 71 |
+
Community leaders will follow these Community Impact Guidelines in determining
|
| 72 |
+
the consequences for any action they deem in violation of this Code of Conduct:
|
| 73 |
+
|
| 74 |
+
### 1. Correction
|
| 75 |
+
|
| 76 |
+
**Community Impact**: Use of inappropriate language or other behavior deemed
|
| 77 |
+
unprofessional or unwelcome in the community.
|
| 78 |
+
|
| 79 |
+
**Consequence**: A private, written warning from community leaders, providing
|
| 80 |
+
clarity around the nature of the violation and an explanation of why the
|
| 81 |
+
behavior was inappropriate. A public apology may be requested.
|
| 82 |
+
|
| 83 |
+
### 2. Warning
|
| 84 |
+
|
| 85 |
+
**Community Impact**: A violation through a single incident or series
|
| 86 |
+
of actions.
|
| 87 |
+
|
| 88 |
+
**Consequence**: A warning with consequences for continued behavior. No
|
| 89 |
+
interaction with the people involved, including unsolicited interaction with
|
| 90 |
+
those enforcing the Code of Conduct, for a specified period of time. This
|
| 91 |
+
includes avoiding interactions in community spaces as well as external channels
|
| 92 |
+
like social media. Violating these terms may lead to a temporary or
|
| 93 |
+
permanent ban.
|
| 94 |
+
|
| 95 |
+
### 3. Temporary Ban
|
| 96 |
+
|
| 97 |
+
**Community Impact**: A serious violation of community standards, including
|
| 98 |
+
sustained inappropriate behavior.
|
| 99 |
+
|
| 100 |
+
**Consequence**: A temporary ban from any sort of interaction or public
|
| 101 |
+
communication with the community for a specified period of time. No public or
|
| 102 |
+
private interaction with the people involved, including unsolicited interaction
|
| 103 |
+
with those enforcing the Code of Conduct, is allowed during this period.
|
| 104 |
+
Violating these terms may lead to a permanent ban.
|
| 105 |
+
|
| 106 |
+
### 4. Permanent Ban
|
| 107 |
+
|
| 108 |
+
**Community Impact**: Demonstrating a pattern of violation of community
|
| 109 |
+
standards, including sustained inappropriate behavior, harassment of an
|
| 110 |
+
individual, or aggression toward or disparagement of classes of individuals.
|
| 111 |
+
|
| 112 |
+
**Consequence**: A permanent ban from any sort of public interaction within
|
| 113 |
+
the community.
|
| 114 |
+
|
| 115 |
+
## Attribution
|
| 116 |
+
|
| 117 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
| 118 |
+
version 2.0, available at
|
| 119 |
+
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
|
| 120 |
+
|
| 121 |
+
Community Impact Guidelines were inspired by [Mozilla's code of conduct
|
| 122 |
+
enforcement ladder](https://github.com/mozilla/diversity).
|
| 123 |
+
|
| 124 |
+
[homepage]: https://www.contributor-covenant.org
|
| 125 |
+
|
| 126 |
+
For answers to common questions about this code of conduct, see the FAQ at
|
| 127 |
+
https://www.contributor-covenant.org/faq. Translations are available at
|
| 128 |
+
https://www.contributor-covenant.org/translations.
|
CodonTransformer/CodonData.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: CodonData.py
|
| 3 |
+
---------------------
|
| 4 |
+
Includes helper functions for preprocessing NCBI or Kazusa databases and
|
| 5 |
+
preparing the data for training and inference of the CodonTransformer model.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import random
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import python_codon_tables as pct
|
| 15 |
+
from Bio import SeqIO
|
| 16 |
+
from Bio.Seq import Seq
|
| 17 |
+
from sklearn.utils import shuffle as sk_shuffle
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from CodonTransformer.CodonUtils import (
|
| 21 |
+
AMBIGUOUS_AMINOACID_MAP,
|
| 22 |
+
AMINO2CODON_TYPE,
|
| 23 |
+
AMINO_ACIDS,
|
| 24 |
+
ORGANISM2ID,
|
| 25 |
+
START_CODONS,
|
| 26 |
+
STOP_CODONS,
|
| 27 |
+
STOP_SYMBOL,
|
| 28 |
+
STOP_SYMBOLS,
|
| 29 |
+
ProteinConfig,
|
| 30 |
+
find_pattern_in_fasta,
|
| 31 |
+
get_taxonomy_id,
|
| 32 |
+
sort_amino2codon_skeleton,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def prepare_training_data(
|
| 37 |
+
dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True
|
| 38 |
+
) -> None:
|
| 39 |
+
"""
|
| 40 |
+
Prepare a JSON dataset for training the CodonTransformer model.
|
| 41 |
+
|
| 42 |
+
Input dataset should have columns below:
|
| 43 |
+
- dna: str (DNA sequence)
|
| 44 |
+
- protein: str (Protein sequence)
|
| 45 |
+
- organism: Union[int, str] (ID or Name of the organism)
|
| 46 |
+
|
| 47 |
+
The output JSON dataset will have the following format:
|
| 48 |
+
{"idx": 0, "codons": "M_ATG R_AGG L_TTG L_CTA R_CGA __TAG", "organism": 51}
|
| 49 |
+
{"idx": 1, "codons": "M_ATG K_AAG C_TGC F_TTT F_TTC __TAA", "organism": 59}
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format.
|
| 53 |
+
output_file (str): Path to save the output JSON dataset.
|
| 54 |
+
shuffle (bool, optional): Whether to shuffle the dataset before saving.
|
| 55 |
+
Defaults to True.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
None
|
| 59 |
+
"""
|
| 60 |
+
if isinstance(dataset, str):
|
| 61 |
+
dataset = pd.read_csv(dataset)
|
| 62 |
+
|
| 63 |
+
required_columns = {"dna", "protein", "organism"}
|
| 64 |
+
if not required_columns.issubset(dataset.columns):
|
| 65 |
+
raise ValueError(f"Input dataset must have columns: {required_columns}")
|
| 66 |
+
|
| 67 |
+
# Prepare the dataset for finetuning
|
| 68 |
+
dataset["codons"] = dataset.apply(
|
| 69 |
+
lambda row: get_merged_seq(row["protein"], row["dna"], separator="_"), axis=1
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Replace organism str with organism id using ORGANISM2ID
|
| 73 |
+
dataset["organism"] = dataset["organism"].apply(
|
| 74 |
+
lambda org: process_organism(org, ORGANISM2ID)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Save the dataset to a JSON file
|
| 78 |
+
dataframe_to_json(dataset[["codons", "organism"]], output_file, shuffle=shuffle)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None:
|
| 82 |
+
"""
|
| 83 |
+
Convert pandas DataFrame to JSON file format suitable for training CodonTransformer.
|
| 84 |
+
|
| 85 |
+
This function takes a preprocessed DataFrame and writes it to a JSON file
|
| 86 |
+
where each line is a JSON object representing a single record.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns.
|
| 90 |
+
output_file (str): Path to the output JSON file.
|
| 91 |
+
shuffle (bool, optional): Whether to shuffle the dataset before saving.
|
| 92 |
+
Defaults to True.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
None
|
| 96 |
+
|
| 97 |
+
Raises:
|
| 98 |
+
ValueError: If the required columns are not present in the DataFrame.
|
| 99 |
+
"""
|
| 100 |
+
required_columns = {"codons", "organism"}
|
| 101 |
+
if not required_columns.issubset(df.columns):
|
| 102 |
+
raise ValueError(f"DataFrame must contain columns: {required_columns}")
|
| 103 |
+
|
| 104 |
+
print(f"\nStarted writing to {output_file}...")
|
| 105 |
+
|
| 106 |
+
# Shuffle the DataFrame if requested
|
| 107 |
+
if shuffle:
|
| 108 |
+
df = sk_shuffle(df)
|
| 109 |
+
|
| 110 |
+
# Write the DataFrame to a JSON file
|
| 111 |
+
with open(output_file, "w") as f:
|
| 112 |
+
for idx, row in tqdm(
|
| 113 |
+
df.iterrows(), total=len(df), desc="Writing JSON...", unit=" records"
|
| 114 |
+
):
|
| 115 |
+
doc = {"idx": idx, "codons": row["codons"], "organism": row["organism"]}
|
| 116 |
+
f.write(json.dumps(doc) + "\n")
|
| 117 |
+
|
| 118 |
+
print(f"\nTotal Entries Saved: {len(df)}, JSON data saved to {output_file}")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) -> int:
|
| 122 |
+
"""
|
| 123 |
+
Process and validate the organism input, converting it to a valid organism ID.
|
| 124 |
+
|
| 125 |
+
This function handles both string (organism name) and integer (organism ID) inputs.
|
| 126 |
+
It validates the input against a provided mapping of organism names to IDs.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
organism (Union[str, int]): Input organism, either as a name (str) or ID (int).
|
| 130 |
+
organism_to_id (Dict[str, int]): Dictionary mapping organism names to their
|
| 131 |
+
corresponding IDs.
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
int: The validated organism ID.
|
| 135 |
+
|
| 136 |
+
Raises:
|
| 137 |
+
ValueError: If the input is an invalid organism name or ID.
|
| 138 |
+
TypeError: If the input is neither a string nor an integer.
|
| 139 |
+
"""
|
| 140 |
+
if isinstance(organism, str):
|
| 141 |
+
if organism not in organism_to_id:
|
| 142 |
+
raise ValueError(f"Invalid organism name: {organism}")
|
| 143 |
+
return organism_to_id[organism]
|
| 144 |
+
|
| 145 |
+
elif isinstance(organism, int):
|
| 146 |
+
if organism not in organism_to_id.values():
|
| 147 |
+
raise ValueError(f"Invalid organism ID: {organism}")
|
| 148 |
+
return organism
|
| 149 |
+
|
| 150 |
+
raise TypeError(
|
| 151 |
+
f"Organism must be a string or integer, not {type(organism).__name__}"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def preprocess_protein_sequence(protein: str) -> str:
|
| 156 |
+
"""
|
| 157 |
+
Preprocess a protein sequence by cleaning, standardizing, and handling
|
| 158 |
+
ambiguous amino acids.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
protein (str): The input protein sequence.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
str: The preprocessed protein sequence.
|
| 165 |
+
|
| 166 |
+
Raises:
|
| 167 |
+
ValueError: If the protein sequence is invalid or if the configuration is invalid.
|
| 168 |
+
"""
|
| 169 |
+
if not protein:
|
| 170 |
+
raise ValueError("Protein sequence is empty.")
|
| 171 |
+
|
| 172 |
+
# Clean and standardize the protein sequence
|
| 173 |
+
protein = (
|
| 174 |
+
protein.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Handle ambiguous amino acids based on the specified behavior
|
| 178 |
+
config = ProteinConfig()
|
| 179 |
+
ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
|
| 180 |
+
ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")
|
| 181 |
+
ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()
|
| 182 |
+
|
| 183 |
+
for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
|
| 184 |
+
ambiguous_aminoacid_map[aminoacid] = standard_aminoacids
|
| 185 |
+
|
| 186 |
+
if ambiguous_aminoacid_behavior == "raise_error":
|
| 187 |
+
if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
|
| 188 |
+
raise ValueError("Ambiguous amino acids found in protein sequence.")
|
| 189 |
+
elif ambiguous_aminoacid_behavior == "standardize_deterministic":
|
| 190 |
+
protein = "".join(
|
| 191 |
+
ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
|
| 192 |
+
for aminoacid in protein
|
| 193 |
+
)
|
| 194 |
+
elif ambiguous_aminoacid_behavior == "standardize_random":
|
| 195 |
+
protein = "".join(
|
| 196 |
+
random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
|
| 197 |
+
for aminoacid in protein
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
raise ValueError(
|
| 201 |
+
f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Check for sequence validity
|
| 205 |
+
if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
|
| 206 |
+
raise ValueError("Invalid characters in protein sequence.")
|
| 207 |
+
|
| 208 |
+
if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
"Protein sequence must end with `*`, or `_`, or an amino acid."
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Replace '*' at the end of protein with STOP_SYMBOL if present
|
| 214 |
+
if protein[-1] == "*":
|
| 215 |
+
protein = protein[:-1] + STOP_SYMBOL
|
| 216 |
+
|
| 217 |
+
# Add stop symbol to end of protein
|
| 218 |
+
if protein[-1] != STOP_SYMBOL:
|
| 219 |
+
protein += STOP_SYMBOL
|
| 220 |
+
|
| 221 |
+
return protein
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def replace_ambiguous_codons(dna: str) -> str:
|
| 225 |
+
"""
|
| 226 |
+
Replaces ambiguous codons in a DNA sequence with "UNK".
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
dna (str): The DNA sequence to process.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
str: The processed DNA sequence with ambiguous codons replaced by "UNK".
|
| 233 |
+
"""
|
| 234 |
+
result = []
|
| 235 |
+
dna = dna.upper()
|
| 236 |
+
|
| 237 |
+
# Check codons in DNA sequence
|
| 238 |
+
for i in range(0, len(dna), 3):
|
| 239 |
+
codon = dna[i : i + 3]
|
| 240 |
+
|
| 241 |
+
if len(codon) == 3 and all(nucleotide in "ATCG" for nucleotide in codon):
|
| 242 |
+
result.append(codon)
|
| 243 |
+
else:
|
| 244 |
+
result.append("UNK")
|
| 245 |
+
|
| 246 |
+
return "".join(result)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def preprocess_dna_sequence(dna: str) -> str:
|
| 250 |
+
"""
|
| 251 |
+
Cleans and preprocesses a DNA sequence by standardizing it and replacing
|
| 252 |
+
ambiguous codons.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
dna (str): The DNA sequence to preprocess.
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
str: The cleaned and preprocessed DNA sequence.
|
| 259 |
+
"""
|
| 260 |
+
if not dna:
|
| 261 |
+
return ""
|
| 262 |
+
|
| 263 |
+
# Clean and standardize the DNA sequence
|
| 264 |
+
dna = dna.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")
|
| 265 |
+
|
| 266 |
+
# Replace codons with ambigous nucleotides with "UNK"
|
| 267 |
+
dna = replace_ambiguous_codons(dna)
|
| 268 |
+
|
| 269 |
+
# Add unkown stop codon to end of DNA sequence if not present
|
| 270 |
+
if dna[-3:] not in STOP_CODONS:
|
| 271 |
+
dna += "UNK"
|
| 272 |
+
|
| 273 |
+
return dna
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str:
|
| 277 |
+
"""
|
| 278 |
+
Return the merged sequence of protein amino acids and DNA codons in the form
|
| 279 |
+
of tokens separated by space, where each token is composed of an amino acid +
|
| 280 |
+
separator + codon.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
protein (str): Protein sequence.
|
| 284 |
+
dna (str): DNA sequence.
|
| 285 |
+
separator (str): Separator between amino acid and codon.
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
str: Merged sequence.
|
| 289 |
+
|
| 290 |
+
Example:
|
| 291 |
+
>>> get_merged_seq(protein="MAV_", dna="ATGGCTGTGTAA", separator="_")
|
| 292 |
+
'M_ATG A_GCT V_GTG __TAA'
|
| 293 |
+
|
| 294 |
+
>>> get_merged_seq(protein="QHH_", dna="", separator="_")
|
| 295 |
+
'Q_UNK H_UNK H_UNK __UNK'
|
| 296 |
+
"""
|
| 297 |
+
merged_seq = ""
|
| 298 |
+
|
| 299 |
+
# Prepare protein and dna sequences
|
| 300 |
+
dna = preprocess_dna_sequence(dna)
|
| 301 |
+
protein = preprocess_protein_sequence(protein)
|
| 302 |
+
|
| 303 |
+
# Check if the length of protein and dna sequences are equal
|
| 304 |
+
if len(dna) > 0 and len(protein) != len(dna) / 3:
|
| 305 |
+
raise ValueError(
|
| 306 |
+
'Length of protein (including stop symbol such as "_") and '
|
| 307 |
+
"the number of codons in DNA sequence (including stop codon) "
|
| 308 |
+
"must be equal."
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Merge protein and DNA sequences into tokens
|
| 312 |
+
for i, aminoacid in enumerate(protein):
|
| 313 |
+
merged_seq += f'{aminoacid}{separator}{dna[i * 3:i * 3 + 3] if dna else "UNK"} '
|
| 314 |
+
|
| 315 |
+
return merged_seq.strip()
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def is_correct_seq(dna: str, protein: str, stop_symbol: str = STOP_SYMBOL) -> bool:
|
| 319 |
+
"""
|
| 320 |
+
Check if the given DNA and protein pair is correct, that is:
|
| 321 |
+
1. The length of dna is divisible by 3
|
| 322 |
+
2. There is an initiator codon in the beginning of dna
|
| 323 |
+
3. There is only one stop codon in the sequence
|
| 324 |
+
4. The only stop codon is the last codon
|
| 325 |
+
|
| 326 |
+
Note since in Codon Table 3, 'TGA' is interpreted as Triptophan (W),
|
| 327 |
+
there is a separate check to make sure those sequences are considered correct.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
dna (str): DNA sequence.
|
| 331 |
+
protein (str): Protein sequence.
|
| 332 |
+
stop_symbol (str): Stop symbol.
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
bool: True if the sequence is correct, False otherwise.
|
| 336 |
+
"""
|
| 337 |
+
return (
|
| 338 |
+
len(dna) % 3 == 0 # Check if DNA length is divisible by 3
|
| 339 |
+
and dna[:3].upper() in START_CODONS # Check for initiator codon
|
| 340 |
+
and protein[-1]
|
| 341 |
+
== stop_symbol # Check if the last protein symbol is the stop symbol
|
| 342 |
+
and protein.count(stop_symbol) == 1 # Check if there is only one stop symbol
|
| 343 |
+
and len(set(dna))
|
| 344 |
+
== 4 # Check if DNA consists of 4 unique nucleotides (A, T, C, G)
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def get_amino_acid_sequence(
|
| 349 |
+
dna: str,
|
| 350 |
+
stop_symbol: str = "_",
|
| 351 |
+
codon_table: int = 1,
|
| 352 |
+
return_correct_seq: bool = False,
|
| 353 |
+
) -> Union[str, Tuple[str, bool]]:
|
| 354 |
+
"""
|
| 355 |
+
Return the translated protein sequence given a DNA sequence and codon table.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
dna (str): DNA sequence.
|
| 359 |
+
stop_symbol (str): Stop symbol.
|
| 360 |
+
codon_table (int): Codon table number.
|
| 361 |
+
return_correct_seq (bool): Whether to return if the sequence is correct.
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if
|
| 365 |
+
return_correct_seq is True, otherwise just the protein sequence.
|
| 366 |
+
"""
|
| 367 |
+
dna_seq = Seq(dna).strip()
|
| 368 |
+
|
| 369 |
+
# Translate the DNA sequence to a protein sequence
|
| 370 |
+
protein_seq = str(
|
| 371 |
+
dna_seq.translate(
|
| 372 |
+
stop_symbol=stop_symbol, # Symbol to use for stop codons
|
| 373 |
+
to_stop=False, # Translate the entire sequence, including any stop codons
|
| 374 |
+
cds=False, # Do not assume the input is a coding sequence
|
| 375 |
+
table=codon_table, # Codon table to use for translation
|
| 376 |
+
)
|
| 377 |
+
).strip()
|
| 378 |
+
|
| 379 |
+
return (
|
| 380 |
+
protein_seq
|
| 381 |
+
if not return_correct_seq
|
| 382 |
+
else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol))
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def read_fasta_file(
|
| 387 |
+
input_file: str,
|
| 388 |
+
save_to_file: Optional[str] = None,
|
| 389 |
+
organism: str = "",
|
| 390 |
+
buffer_size: int = 50000,
|
| 391 |
+
) -> pd.DataFrame:
|
| 392 |
+
"""
|
| 393 |
+
Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame.
|
| 394 |
+
Optionally, save the DataFrame to a CSV file.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
input_file (str): Path to the input FASTA file.
|
| 398 |
+
save_to_file (Optional[str]): Path to save the output DataFrame. If None,
|
| 399 |
+
data is only returned.
|
| 400 |
+
organism (str): Name of the organism. If empty, it will be extracted from
|
| 401 |
+
the FASTA description.
|
| 402 |
+
buffer_size (int): Number of records to process before writing to file.
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe
|
| 406 |
+
is True, else None.
|
| 407 |
+
|
| 408 |
+
Raises:
|
| 409 |
+
FileNotFoundError: If the input file does not exist.
|
| 410 |
+
"""
|
| 411 |
+
if not os.path.exists(input_file):
|
| 412 |
+
raise FileNotFoundError(f"Input file not found: {input_file}")
|
| 413 |
+
|
| 414 |
+
buffer = []
|
| 415 |
+
columns = [
|
| 416 |
+
"dna",
|
| 417 |
+
"protein",
|
| 418 |
+
"correct_seq",
|
| 419 |
+
"organism",
|
| 420 |
+
"GeneID",
|
| 421 |
+
"description",
|
| 422 |
+
"tokenized",
|
| 423 |
+
]
|
| 424 |
+
|
| 425 |
+
# Initialize DataFrame to store all data if return_dataframe is True
|
| 426 |
+
all_data = pd.DataFrame(columns=columns)
|
| 427 |
+
|
| 428 |
+
with open(input_file, "r") as fasta_file:
|
| 429 |
+
for record in tqdm(
|
| 430 |
+
SeqIO.parse(fasta_file, "fasta"),
|
| 431 |
+
desc=f"Processing {organism}",
|
| 432 |
+
unit=" Records",
|
| 433 |
+
):
|
| 434 |
+
dna = str(record.seq).strip().upper() # Ensure uppercase DNA sequence
|
| 435 |
+
|
| 436 |
+
# Determine the organism from the record if not provided
|
| 437 |
+
current_organism = organism or find_pattern_in_fasta(
|
| 438 |
+
"organism", record.description
|
| 439 |
+
)
|
| 440 |
+
gene_id = find_pattern_in_fasta("GeneID", record.description)
|
| 441 |
+
|
| 442 |
+
# Get the appropriate codon table for the organism
|
| 443 |
+
codon_table = get_codon_table(current_organism)
|
| 444 |
+
|
| 445 |
+
# Translate DNA to protein sequence
|
| 446 |
+
protein, correct_seq = get_amino_acid_sequence(
|
| 447 |
+
dna,
|
| 448 |
+
stop_symbol=STOP_SYMBOL,
|
| 449 |
+
codon_table=codon_table,
|
| 450 |
+
return_correct_seq=True,
|
| 451 |
+
)
|
| 452 |
+
description = record.description.split("[", 1)[0].strip()
|
| 453 |
+
tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)
|
| 454 |
+
|
| 455 |
+
# Create a data row for the current sequence
|
| 456 |
+
data_row = {
|
| 457 |
+
"dna": dna,
|
| 458 |
+
"protein": protein,
|
| 459 |
+
"correct_seq": correct_seq,
|
| 460 |
+
"organism": current_organism,
|
| 461 |
+
"GeneID": gene_id,
|
| 462 |
+
"description": description,
|
| 463 |
+
"tokenized": tokenized,
|
| 464 |
+
}
|
| 465 |
+
buffer.append(data_row)
|
| 466 |
+
|
| 467 |
+
# Write buffer to CSV file when buffer size is reached
|
| 468 |
+
if save_to_file and len(buffer) >= buffer_size:
|
| 469 |
+
write_buffer_to_csv(buffer, save_to_file, columns)
|
| 470 |
+
buffer = []
|
| 471 |
+
|
| 472 |
+
all_data = pd.concat(
|
| 473 |
+
[all_data, pd.DataFrame([data_row])], ignore_index=True
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
# Write remaining buffer to CSV file
|
| 477 |
+
if save_to_file and buffer:
|
| 478 |
+
write_buffer_to_csv(buffer, save_to_file, columns)
|
| 479 |
+
|
| 480 |
+
return all_data
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]):
|
| 484 |
+
"""Helper function to write buffer to CSV file."""
|
| 485 |
+
buffer_df = pd.DataFrame(buffer, columns=columns)
|
| 486 |
+
buffer_df.to_csv(
|
| 487 |
+
output_path,
|
| 488 |
+
mode="a",
|
| 489 |
+
header=(not os.path.exists(output_path)),
|
| 490 |
+
index=True,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def download_codon_frequencies_from_kazusa(
|
| 495 |
+
taxonomy_id: Optional[int] = None,
|
| 496 |
+
organism: Optional[str] = None,
|
| 497 |
+
taxonomy_reference: Optional[str] = None,
|
| 498 |
+
return_original_format: bool = False,
|
| 499 |
+
) -> AMINO2CODON_TYPE:
|
| 500 |
+
"""
|
| 501 |
+
Return the codon table of the given taxonomy ID from the Kazusa Database.
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
taxonomy_id (Optional[int]): Taxonomy ID.
|
| 505 |
+
organism (Optional[str]): Name of the organism.
|
| 506 |
+
taxonomy_reference (Optional[str]): Taxonomy reference.
|
| 507 |
+
return_original_format (bool): Whether to return in the original format.
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
AMINO2CODON_TYPE: Codon table.
|
| 511 |
+
"""
|
| 512 |
+
if taxonomy_reference:
|
| 513 |
+
taxonomy_id = get_taxonomy_id(taxonomy_reference, organism=organism)
|
| 514 |
+
|
| 515 |
+
kazusa_amino2codon = pct.get_codons_table(table_name=taxonomy_id)
|
| 516 |
+
|
| 517 |
+
if return_original_format:
|
| 518 |
+
return kazusa_amino2codon
|
| 519 |
+
|
| 520 |
+
# Replace "*" with STOP_SYMBOL in the codon table
|
| 521 |
+
kazusa_amino2codon[STOP_SYMBOL] = kazusa_amino2codon.pop("*")
|
| 522 |
+
|
| 523 |
+
# Create amino2codon dictionary
|
| 524 |
+
amino2codon = {
|
| 525 |
+
aminoacid: (list(codon2freq.keys()), list(codon2freq.values()))
|
| 526 |
+
for aminoacid, codon2freq in kazusa_amino2codon.items()
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
return sort_amino2codon_skeleton(amino2codon)
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE:
|
| 533 |
+
"""
|
| 534 |
+
Return the empty skeleton of the amino2codon dictionary, needed for
|
| 535 |
+
get_codon_frequencies.
|
| 536 |
+
|
| 537 |
+
Args:
|
| 538 |
+
organism (str): Name of the organism.
|
| 539 |
+
|
| 540 |
+
Returns:
|
| 541 |
+
AMINO2CODON_TYPE: Empty amino2codon dictionary.
|
| 542 |
+
"""
|
| 543 |
+
amino2codon = {}
|
| 544 |
+
possible_codons = [f"{i}{j}{k}" for i in "ACGT" for j in "ACGT" for k in "ACGT"]
|
| 545 |
+
possible_aminoacids = get_amino_acid_sequence(
|
| 546 |
+
dna="".join(possible_codons),
|
| 547 |
+
codon_table=get_codon_table(organism),
|
| 548 |
+
return_correct_seq=False,
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# Initialize the amino2codon skeleton with all possible codons and set their
|
| 552 |
+
# frequencies to 0
|
| 553 |
+
for i, (codon, amino) in enumerate(zip(possible_codons, possible_aminoacids)):
|
| 554 |
+
if amino not in amino2codon:
|
| 555 |
+
amino2codon[amino] = ([], [])
|
| 556 |
+
|
| 557 |
+
amino2codon[amino][0].append(codon)
|
| 558 |
+
amino2codon[amino][1].append(0)
|
| 559 |
+
|
| 560 |
+
# Sort the dictionary and each list of codon frequency alphabetically
|
| 561 |
+
amino2codon = sort_amino2codon_skeleton(amino2codon)
|
| 562 |
+
|
| 563 |
+
return amino2codon
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def get_codon_frequencies(
|
| 567 |
+
dna_sequences: List[str],
|
| 568 |
+
protein_sequences: Optional[List[str]] = None,
|
| 569 |
+
organism: Optional[str] = None,
|
| 570 |
+
) -> AMINO2CODON_TYPE:
|
| 571 |
+
"""
|
| 572 |
+
Return a dictionary mapping each codon to its respective frequency based on
|
| 573 |
+
the collection of DNA sequences and protein sequences.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
dna_sequences (List[str]): List of DNA sequences.
|
| 577 |
+
protein_sequences (Optional[List[str]]): List of protein sequences.
|
| 578 |
+
organism (Optional[str]): Name of the organism.
|
| 579 |
+
|
| 580 |
+
Returns:
|
| 581 |
+
AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons
|
| 582 |
+
and frequencies.
|
| 583 |
+
"""
|
| 584 |
+
if organism:
|
| 585 |
+
codon_table = get_codon_table(organism)
|
| 586 |
+
protein_sequences = [
|
| 587 |
+
get_amino_acid_sequence(
|
| 588 |
+
dna, codon_table=codon_table, return_correct_seq=False
|
| 589 |
+
)
|
| 590 |
+
for dna in dna_sequences
|
| 591 |
+
]
|
| 592 |
+
|
| 593 |
+
amino2codon = build_amino2codon_skeleton(organism)
|
| 594 |
+
|
| 595 |
+
# Count the frequencies of each codon for each amino acid
|
| 596 |
+
for dna, protein in zip(dna_sequences, protein_sequences):
|
| 597 |
+
for i, amino in enumerate(protein):
|
| 598 |
+
codon = dna[i * 3 : (i + 1) * 3]
|
| 599 |
+
codon_loc = amino2codon[amino][0].index(codon)
|
| 600 |
+
amino2codon[amino][1][codon_loc] += 1
|
| 601 |
+
|
| 602 |
+
# Normalize codon frequencies per amino acid so they sum to 1
|
| 603 |
+
amino2codon = {
|
| 604 |
+
amino: (codons, [freq / (sum(frequencies) + 1e-100) for freq in frequencies])
|
| 605 |
+
for amino, (codons, frequencies) in amino2codon.items()
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
return amino2codon
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def get_organism_to_codon_frequencies(
|
| 612 |
+
dataset: pd.DataFrame, organisms: List[str]
|
| 613 |
+
) -> Dict[str, AMINO2CODON_TYPE]:
|
| 614 |
+
"""
|
| 615 |
+
Return a dictionary mapping each organism to their codon frequency distribution.
|
| 616 |
+
|
| 617 |
+
Args:
|
| 618 |
+
dataset (pd.DataFrame): DataFrame containing DNA sequences.
|
| 619 |
+
organisms (List[str]): List of organisms.
|
| 620 |
+
|
| 621 |
+
Returns:
|
| 622 |
+
Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon
|
| 623 |
+
frequency distribution.
|
| 624 |
+
"""
|
| 625 |
+
organism2frequencies = {}
|
| 626 |
+
|
| 627 |
+
# Calculate codon frequencies for each organism in the dataset
|
| 628 |
+
for organism in tqdm(
|
| 629 |
+
organisms, desc="Calculating Codon Frequencies: ", unit="Organism"
|
| 630 |
+
):
|
| 631 |
+
organism_data = dataset.loc[dataset["organism"] == organism]
|
| 632 |
+
|
| 633 |
+
dna_sequences = organism_data["dna"].to_list()
|
| 634 |
+
protein_sequences = organism_data["protein"].to_list()
|
| 635 |
+
|
| 636 |
+
codon_frequencies = get_codon_frequencies(dna_sequences, protein_sequences)
|
| 637 |
+
organism2frequencies[organism] = codon_frequencies
|
| 638 |
+
|
| 639 |
+
return organism2frequencies
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
def get_codon_table(organism: str) -> int:
|
| 643 |
+
"""
|
| 644 |
+
Return the appropriate NCBI codon table for a given organism.
|
| 645 |
+
|
| 646 |
+
Args:
|
| 647 |
+
organism (str): Name of the organism.
|
| 648 |
+
|
| 649 |
+
Returns:
|
| 650 |
+
int: Codon table number.
|
| 651 |
+
"""
|
| 652 |
+
# Common codon table (Table 1) for many model organisms
|
| 653 |
+
if organism in [
|
| 654 |
+
"Arabidopsis thaliana",
|
| 655 |
+
"Caenorhabditis elegans",
|
| 656 |
+
"Chlamydomonas reinhardtii",
|
| 657 |
+
"Saccharomyces cerevisiae",
|
| 658 |
+
"Danio rerio",
|
| 659 |
+
"Drosophila melanogaster",
|
| 660 |
+
"Homo sapiens",
|
| 661 |
+
"Mus musculus",
|
| 662 |
+
"Nicotiana tabacum",
|
| 663 |
+
"Solanum tuberosum",
|
| 664 |
+
"Solanum lycopersicum",
|
| 665 |
+
"Oryza sativa",
|
| 666 |
+
"Glycine max",
|
| 667 |
+
"Zea mays",
|
| 668 |
+
]:
|
| 669 |
+
codon_table = 1
|
| 670 |
+
|
| 671 |
+
# Chloroplast codon table (Table 11)
|
| 672 |
+
elif organism in [
|
| 673 |
+
"Chlamydomonas reinhardtii chloroplast",
|
| 674 |
+
"Nicotiana tabacum chloroplast",
|
| 675 |
+
]:
|
| 676 |
+
codon_table = 11
|
| 677 |
+
|
| 678 |
+
# Default to Table 11 for other bacteria and archaea
|
| 679 |
+
else:
|
| 680 |
+
codon_table = 11
|
| 681 |
+
|
| 682 |
+
return codon_table
|
CodonTransformer/CodonEvaluation.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: CodonEvaluation.py
|
| 3 |
+
---------------------------
|
| 4 |
+
Includes functions to calculate various evaluation metrics along with helper
|
| 5 |
+
functions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Dict, List, Tuple, Optional
|
| 9 |
+
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from CAI import CAI, relative_adaptiveness
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import math
|
| 14 |
+
import numpy as np
|
| 15 |
+
from collections import Counter
|
| 16 |
+
from itertools import chain
|
| 17 |
+
from statistics import mean
|
| 18 |
+
import sys
|
| 19 |
+
import os
|
| 20 |
+
from io import StringIO
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
|
| 24 |
+
"""
|
| 25 |
+
Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
sequences (List[str]): List of DNA sequences.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
dict: The CSI weights.
|
| 32 |
+
"""
|
| 33 |
+
return relative_adaptiveness(sequences=sequences)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
|
| 37 |
+
"""
|
| 38 |
+
Calculate the Codon Similarity Index (CSI) for a DNA sequence.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
dna (str): The DNA sequence.
|
| 42 |
+
weights (dict): The CSI weights from get_CSI_weights.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
float: The CSI value.
|
| 46 |
+
"""
|
| 47 |
+
return CAI(dna, weights)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_organism_to_CSI_weights(
|
| 51 |
+
dataset: pd.DataFrame, organisms: List[str]
|
| 52 |
+
) -> Dict[str, dict]:
|
| 53 |
+
"""
|
| 54 |
+
Calculate the Codon Similarity Index (CSI) weights for a list of organisms.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
dataset (pd.DataFrame): Dataset containing organism and DNA sequence info.
|
| 58 |
+
organisms (List[str]): List of organism names.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dict[str, dict]: A dictionary mapping each organism to its CSI weights.
|
| 62 |
+
"""
|
| 63 |
+
organism2weights = {}
|
| 64 |
+
|
| 65 |
+
# Iterate through each organism to calculate its CSI weights
|
| 66 |
+
for organism in tqdm(organisms, desc="Calculating CSI Weights: ", unit="Organism"):
|
| 67 |
+
organism_data = dataset.loc[dataset["organism"] == organism]
|
| 68 |
+
sequences = organism_data["dna"].to_list()
|
| 69 |
+
weights = get_CSI_weights(sequences)
|
| 70 |
+
organism2weights[organism] = weights
|
| 71 |
+
|
| 72 |
+
return organism2weights
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_GC_content(dna: str) -> float:
|
| 76 |
+
"""
|
| 77 |
+
Calculate the GC content of a DNA sequence.
|
| 78 |
+
|
| 79 |
+
GC content is the percentage of nucleotides that are either G (guanine) or C (cytosine).
|
| 80 |
+
This metric is important for codon optimization as it affects expression levels and
|
| 81 |
+
synthesis efficiency in E. coli.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
dna (str): The DNA sequence (uppercase or lowercase).
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
float: The GC content as a percentage (0-100).
|
| 88 |
+
|
| 89 |
+
Example:
|
| 90 |
+
>>> get_GC_content("ATGCGATCG")
|
| 91 |
+
55.56 # 5 GC nucleotides out of 9 total
|
| 92 |
+
"""
|
| 93 |
+
dna = dna.upper()
|
| 94 |
+
if not dna:
|
| 95 |
+
return 0.0
|
| 96 |
+
return (dna.count("G") + dna.count("C")) / len(dna) * 100
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_cfd(
|
| 100 |
+
dna: str,
|
| 101 |
+
codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
|
| 102 |
+
threshold: float = 0.3,
|
| 103 |
+
) -> float:
|
| 104 |
+
"""
|
| 105 |
+
Calculate the codon frequency distribution (CFD) metric for a DNA sequence.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
dna (str): The DNA sequence.
|
| 109 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 110 |
+
frequency distribution per amino acid.
|
| 111 |
+
threshold (float): Frequency threshold for counting rare codons.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
float: The CFD metric as a percentage.
|
| 115 |
+
"""
|
| 116 |
+
# Get a dictionary mapping each codon to its normalized frequency
|
| 117 |
+
codon2frequency = {
|
| 118 |
+
codon: freq / max(frequencies)
|
| 119 |
+
for amino, (codons, frequencies) in codon_frequencies.items()
|
| 120 |
+
for codon, freq in zip(codons, frequencies)
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
cfd = 0
|
| 124 |
+
|
| 125 |
+
# Iterate through the DNA sequence in steps of 3 to process each codon
|
| 126 |
+
for i in range(0, len(dna), 3):
|
| 127 |
+
codon = dna[i : i + 3]
|
| 128 |
+
codon_frequency = codon2frequency[codon]
|
| 129 |
+
|
| 130 |
+
if codon_frequency < threshold:
|
| 131 |
+
cfd += 1
|
| 132 |
+
|
| 133 |
+
return cfd / (len(dna) / 3) * 100
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def get_min_max_percentage(
|
| 137 |
+
dna: str,
|
| 138 |
+
codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
|
| 139 |
+
window_size: int = 18,
|
| 140 |
+
) -> List[float]:
|
| 141 |
+
"""
|
| 142 |
+
Calculate the %MinMax metric for a DNA sequence.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
dna (str): The DNA sequence.
|
| 146 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 147 |
+
frequency distribution per amino acid.
|
| 148 |
+
window_size (int): Size of the window to calculate %MinMax.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
List[float]: List of %MinMax values for the sequence.
|
| 152 |
+
|
| 153 |
+
Credit: https://github.com/chowington/minmax
|
| 154 |
+
"""
|
| 155 |
+
# Get a dictionary mapping each codon to its respective amino acid
|
| 156 |
+
codon2amino = {
|
| 157 |
+
codon: amino
|
| 158 |
+
for amino, (codons, frequencies) in codon_frequencies.items()
|
| 159 |
+
for codon in codons
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
min_max_values = []
|
| 163 |
+
codons = [dna[i : i + 3] for i in range(0, len(dna), 3)] # Split DNA into codons
|
| 164 |
+
|
| 165 |
+
# Iterate through the DNA sequence using the specified window size
|
| 166 |
+
for i in range(len(codons) - window_size + 1):
|
| 167 |
+
codon_window = codons[i : i + window_size] # Codons in the current window
|
| 168 |
+
|
| 169 |
+
Actual = 0.0 # Average of the actual codon frequencies
|
| 170 |
+
Max = 0.0 # Average of the min codon frequencies
|
| 171 |
+
Min = 0.0 # Average of the max codon frequencies
|
| 172 |
+
Avg = 0.0 # Average of the averages of all frequencies for each amino acid
|
| 173 |
+
|
| 174 |
+
# Sum the frequencies for codons in the current window
|
| 175 |
+
for codon in codon_window:
|
| 176 |
+
aminoacid = codon2amino[codon]
|
| 177 |
+
frequencies = codon_frequencies[aminoacid][1]
|
| 178 |
+
codon_index = codon_frequencies[aminoacid][0].index(codon)
|
| 179 |
+
codon_frequency = codon_frequencies[aminoacid][1][codon_index]
|
| 180 |
+
|
| 181 |
+
Actual += codon_frequency
|
| 182 |
+
Max += max(frequencies)
|
| 183 |
+
Min += min(frequencies)
|
| 184 |
+
Avg += sum(frequencies) / len(frequencies)
|
| 185 |
+
|
| 186 |
+
# Divide by the window size to get the averages
|
| 187 |
+
Actual = Actual / window_size
|
| 188 |
+
Max = Max / window_size
|
| 189 |
+
Min = Min / window_size
|
| 190 |
+
Avg = Avg / window_size
|
| 191 |
+
|
| 192 |
+
# Calculate %MinMax
|
| 193 |
+
percentMax = ((Actual - Avg) / (Max - Avg)) * 100
|
| 194 |
+
percentMin = ((Avg - Actual) / (Avg - Min)) * 100
|
| 195 |
+
|
| 196 |
+
# Append the appropriate %MinMax value
|
| 197 |
+
if percentMax >= 0:
|
| 198 |
+
min_max_values.append(percentMax)
|
| 199 |
+
else:
|
| 200 |
+
min_max_values.append(-percentMin)
|
| 201 |
+
|
| 202 |
+
# Populate the last floor(window_size / 2) entries of min_max_values with None
|
| 203 |
+
for i in range(int(window_size / 2)):
|
| 204 |
+
min_max_values.append(None)
|
| 205 |
+
|
| 206 |
+
return min_max_values
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_sequence_complexity(dna: str) -> float:
|
| 210 |
+
"""
|
| 211 |
+
Calculate the sequence complexity score of a DNA sequence.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
dna (str): The DNA sequence.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
float: The sequence complexity score.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
def sum_up_to(x):
|
| 221 |
+
"""Recursive function to calculate the sum of integers from 1 to x."""
|
| 222 |
+
if x <= 1:
|
| 223 |
+
return 1
|
| 224 |
+
else:
|
| 225 |
+
return x + sum_up_to(x - 1)
|
| 226 |
+
|
| 227 |
+
def f(x):
|
| 228 |
+
"""Returns 4 if x is greater than or equal to 4, else returns x."""
|
| 229 |
+
if x >= 4:
|
| 230 |
+
return 4
|
| 231 |
+
elif x < 4:
|
| 232 |
+
return x
|
| 233 |
+
|
| 234 |
+
unique_subseq_length = []
|
| 235 |
+
|
| 236 |
+
# Calculate unique subsequences lengths
|
| 237 |
+
for i in range(1, len(dna) + 1):
|
| 238 |
+
unique_subseq = set()
|
| 239 |
+
for j in range(len(dna) - (i - 1)):
|
| 240 |
+
unique_subseq.add(dna[j : (j + i)])
|
| 241 |
+
unique_subseq_length.append(len(unique_subseq))
|
| 242 |
+
|
| 243 |
+
# Calculate complexity score
|
| 244 |
+
complexity_score = (
|
| 245 |
+
sum(unique_subseq_length) / (sum_up_to(len(dna) - 1) + f(len(dna)))
|
| 246 |
+
) * 100
|
| 247 |
+
|
| 248 |
+
return complexity_score
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_sequence_similarity(
|
| 252 |
+
original: str, predicted: str, truncate: bool = True, window_length: int = 1
|
| 253 |
+
) -> float:
|
| 254 |
+
"""
|
| 255 |
+
Calculate the sequence similarity between two sequences.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
original (str): The original sequence.
|
| 259 |
+
predicted (str): The predicted sequence.
|
| 260 |
+
truncate (bool): If True, truncate the original sequence to match the length
|
| 261 |
+
of the predicted sequence.
|
| 262 |
+
window_length (int): Length of the window for comparison (1 for amino acids,
|
| 263 |
+
3 for codons).
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
float: The sequence similarity as a percentage.
|
| 267 |
+
|
| 268 |
+
Preconditions:
|
| 269 |
+
len(predicted) <= len(original).
|
| 270 |
+
"""
|
| 271 |
+
if not truncate and len(original) != len(predicted):
|
| 272 |
+
raise ValueError(
|
| 273 |
+
"Set truncate to True if the length of sequences do not match."
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
identity = 0.0
|
| 277 |
+
original = original.strip()
|
| 278 |
+
predicted = predicted.strip()
|
| 279 |
+
|
| 280 |
+
if truncate:
|
| 281 |
+
original = original[: len(predicted)]
|
| 282 |
+
|
| 283 |
+
if window_length == 1:
|
| 284 |
+
# Simple comparison for amino acid
|
| 285 |
+
for i in range(len(predicted)):
|
| 286 |
+
if original[i] == predicted[i]:
|
| 287 |
+
identity += 1
|
| 288 |
+
else:
|
| 289 |
+
# Comparison for substrings based on window_length
|
| 290 |
+
for i in range(0, len(original) - window_length + 1, window_length):
|
| 291 |
+
if original[i : i + window_length] == predicted[i : i + window_length]:
|
| 292 |
+
identity += 1
|
| 293 |
+
|
| 294 |
+
return (identity / (len(predicted) / window_length)) * 100
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def scan_for_restriction_sites(seq: str, sites: List[str] = ['GAATTC', 'GGATCC', 'AAGCTT']) -> int:
|
| 298 |
+
"""
|
| 299 |
+
Scans for a list of restriction enzyme sites in a DNA sequence.
|
| 300 |
+
"""
|
| 301 |
+
return sum(seq.upper().count(site.upper()) for site in sites)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def count_negative_cis_elements(seq: str, motifs: List[str] = ['TATAAT', 'TTGACA', 'AGCTAGT']) -> int:
|
| 305 |
+
"""
|
| 306 |
+
Counts occurrences of negative cis-regulatory elements in a DNA sequence.
|
| 307 |
+
"""
|
| 308 |
+
return sum(seq.upper().count(m.upper()) for m in motifs)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def calculate_homopolymer_runs(seq: str, max_len: int = 8) -> int:
|
| 312 |
+
"""
|
| 313 |
+
Calculates the number of homopolymer runs longer than a given length.
|
| 314 |
+
"""
|
| 315 |
+
import re
|
| 316 |
+
min_len = max_len + 1
|
| 317 |
+
return len(re.findall(r'(A{%d,}|T{%d,}|G{%d,}|C{%d,})' % (min_len, min_len, min_len, min_len), seq.upper()))
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def get_min_max_profile(
|
| 321 |
+
dna: str,
|
| 322 |
+
codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
|
| 323 |
+
window_size: int = 18,
|
| 324 |
+
) -> List[float]:
|
| 325 |
+
"""
|
| 326 |
+
Calculate the %MinMax profile for a DNA sequence. This is a list of
|
| 327 |
+
%MinMax values for sliding windows across the sequence.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
dna (str): The DNA sequence.
|
| 331 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 332 |
+
frequency distribution per amino acid.
|
| 333 |
+
window_size (int): Size of the window to calculate %MinMax.
|
| 334 |
+
|
| 335 |
+
Returns:
|
| 336 |
+
List[float]: List of %MinMax values for the sequence.
|
| 337 |
+
"""
|
| 338 |
+
return get_min_max_percentage(dna, codon_frequencies, window_size)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def calculate_dtw_distance(profile1: List[float], profile2: List[float]) -> float:
|
| 342 |
+
"""
|
| 343 |
+
Calculates the Dynamic Time Warping (DTW) distance between two profiles.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
profile1 (List[float]): The first profile (e.g., %MinMax of generated sequence).
|
| 347 |
+
profile2 (List[float]): The second profile (e.g., %MinMax of natural sequence).
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
float: The DTW distance between the two profiles.
|
| 351 |
+
"""
|
| 352 |
+
from dtw import dtw
|
| 353 |
+
import numpy as np
|
| 354 |
+
|
| 355 |
+
# Ensure profiles are numpy arrays and handle potential None and NaN values
|
| 356 |
+
p1 = np.array([v for v in profile1 if v is not None and not np.isnan(v)]).reshape(
|
| 357 |
+
-1, 1
|
| 358 |
+
)
|
| 359 |
+
p2 = np.array([v for v in profile2 if v is not None and not np.isnan(v)]).reshape(
|
| 360 |
+
-1, 1
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
if len(p1) == 0 or len(p2) == 0:
|
| 364 |
+
return np.inf # Return infinity if one of the profiles is empty
|
| 365 |
+
|
| 366 |
+
alignment = dtw(p1, p2, keep_internals=True)
|
| 367 |
+
return alignment.distance # type: ignore
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def get_ecoli_tai_weights():
|
| 371 |
+
"""
|
| 372 |
+
Returns a dictionary of tAI weights for E. coli based on tRNA gene copy numbers.
|
| 373 |
+
These weights are pre-calculated based on the relative adaptiveness of each codon.
|
| 374 |
+
"""
|
| 375 |
+
codons = [
|
| 376 |
+
"TTT", "TTC", "TTA", "TTG", "TCT", "TCC", "TCA", "TCG", "TAT", "TAC",
|
| 377 |
+
"TGT", "TGC", "TGG", "CTT", "CTC", "CTA", "CTG", "CCT", "CCC", "CCA",
|
| 378 |
+
"CCG", "CAT", "CAC", "CAA", "CAG", "CGT", "CGC", "CGA", "CGG", "ATT",
|
| 379 |
+
"ATC", "ATA", "ACT", "ACC", "ACA", "ACG", "AAT", "AAC", "AAA", "AAG",
|
| 380 |
+
"AGT", "AGC", "AGA", "AGG", "GTT", "GTC", "GTA", "GTG", "GCT", "GCC",
|
| 381 |
+
"GCA", "GCG", "GAT", "GAC", "GAA", "GAG", "GGT", "GGC", "GGA", "GGG"
|
| 382 |
+
]
|
| 383 |
+
weights = [
|
| 384 |
+
0.1966667, 0.3333333, 0.1666667, 0.2200000, 0.1966667, 0.3333333,
|
| 385 |
+
0.1666667, 0.2200000, 0.2950000, 0.5000000, 0.09833333, 0.1666667,
|
| 386 |
+
0.2200000, 0.09833333, 0.1666667, 0.1666667, 0.7200000, 0.09833333,
|
| 387 |
+
0.1666667, 0.1666667, 0.2200000, 0.09833333, 0.1666667, 0.3333333,
|
| 388 |
+
0.4400000, 0.6666667, 0.4800000, 0.00006666667, 0.1666667, 0.2950000,
|
| 389 |
+
0.5000000, 0.01833333, 0.1966667, 0.3333333, 0.1666667, 0.3866667,
|
| 390 |
+
0.3933333, 0.6666667, 1.0000000, 0.3200000, 0.09833333, 0.1666667,
|
| 391 |
+
0.1666667, 0.2200000, 0.1966667, 0.3333333, 0.8333333, 0.2666667,
|
| 392 |
+
0.1966667, 0.3333333, 0.5000000, 0.1600000, 0.2950000, 0.5000000,
|
| 393 |
+
0.6666667, 0.2133333, 0.3933333, 0.6666667, 0.1666667, 0.2200000
|
| 394 |
+
]
|
| 395 |
+
return dict(zip(codons, weights))
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
|
| 399 |
+
"""
|
| 400 |
+
Calculates the tRNA Adaptation Index (tAI) for a given DNA sequence.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
sequence (str): The DNA sequence to analyze.
|
| 404 |
+
tai_weights (Dict[str, float]): A dictionary of tAI weights for each codon.
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
float: The tAI value for the sequence.
|
| 408 |
+
"""
|
| 409 |
+
from scipy.stats.mstats import gmean
|
| 410 |
+
|
| 411 |
+
codons = [sequence[i:i+3] for i in range(0, len(sequence), 3)]
|
| 412 |
+
|
| 413 |
+
# Filter out stop codons and codons not in weights
|
| 414 |
+
weights = [tai_weights[codon] for codon in codons if codon in tai_weights and tai_weights[codon] > 0]
|
| 415 |
+
|
| 416 |
+
if not weights:
|
| 417 |
+
return 0.0
|
| 418 |
+
|
| 419 |
+
return gmean(weights)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def calculate_ENC(sequence: str) -> float:
|
| 423 |
+
"""
|
| 424 |
+
Calculate the Effective Number of Codons (ENC) for a DNA sequence.
|
| 425 |
+
Uses the codonbias library implementation based on Wright (1990).
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
sequence (str): The DNA sequence.
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
float: The ENC value for the sequence.
|
| 432 |
+
"""
|
| 433 |
+
try:
|
| 434 |
+
from codonbias.scores import EffectiveNumberOfCodons
|
| 435 |
+
|
| 436 |
+
# Initialize ENC calculator
|
| 437 |
+
enc_calculator = EffectiveNumberOfCodons(
|
| 438 |
+
k_mer=1, # Standard codon analysis
|
| 439 |
+
bg_correction=True, # Use background correction
|
| 440 |
+
robust=True, # Use robust calculation
|
| 441 |
+
genetic_code=1 # Standard genetic code
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# Calculate ENC for the sequence
|
| 445 |
+
enc_value = enc_calculator.get_score(sequence)
|
| 446 |
+
|
| 447 |
+
return float(enc_value)
|
| 448 |
+
|
| 449 |
+
except ImportError:
|
| 450 |
+
raise ImportError("codonbias library is required for ENC calculation. Install with: pip install codonbias")
|
| 451 |
+
except Exception as e:
|
| 452 |
+
# Fallback to a simple ENC approximation if library fails
|
| 453 |
+
print(f"Warning: ENC calculation failed with error: {e}. Using approximation.")
|
| 454 |
+
return 45.0 # Typical E. coli ENC value as fallback
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def calculate_CPB(sequence: str, reference_sequences: Optional[List[str]] = None) -> float:
|
| 458 |
+
"""
|
| 459 |
+
Calculate the Codon Pair Bias (CPB) for a DNA sequence.
|
| 460 |
+
Uses the codonbias library implementation based on Coleman et al. (2008).
|
| 461 |
+
|
| 462 |
+
Args:
|
| 463 |
+
sequence (str): The DNA sequence.
|
| 464 |
+
reference_sequences (List[str]): Reference sequences for calculating expected values.
|
| 465 |
+
If None, uses a default E. coli reference.
|
| 466 |
+
|
| 467 |
+
Returns:
|
| 468 |
+
float: The CPB value for the sequence.
|
| 469 |
+
"""
|
| 470 |
+
try:
|
| 471 |
+
from codonbias.scores import CodonPairBias
|
| 472 |
+
|
| 473 |
+
# Use provided reference sequences or default
|
| 474 |
+
if reference_sequences is None:
|
| 475 |
+
# Use the input sequence as reference if none provided
|
| 476 |
+
reference_sequences = [sequence]
|
| 477 |
+
|
| 478 |
+
# Initialize CPB calculator with reference sequences
|
| 479 |
+
cpb_calculator = CodonPairBias(
|
| 480 |
+
ref_seq=reference_sequences,
|
| 481 |
+
k_mer=2, # Codon pairs
|
| 482 |
+
genetic_code=1, # Standard genetic code
|
| 483 |
+
ignore_stop=True, # Ignore stop codons
|
| 484 |
+
pseudocount=1 # Pseudocount for unseen pairs
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
# Calculate CPB for the sequence
|
| 488 |
+
cpb_value = cpb_calculator.get_score(sequence)
|
| 489 |
+
|
| 490 |
+
return float(cpb_value)
|
| 491 |
+
|
| 492 |
+
except ImportError:
|
| 493 |
+
raise ImportError("codonbias library is required for CPB calculation. Install with: pip install codonbias")
|
| 494 |
+
except Exception as e:
|
| 495 |
+
# Fallback calculation if library fails
|
| 496 |
+
print(f"Warning: CPB calculation failed with error: {e}. Using approximation.")
|
| 497 |
+
return 0.0 # Neutral CPB as fallback
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def calculate_SCUO(sequence: str) -> float:
|
| 501 |
+
"""
|
| 502 |
+
Calculate the Synonymous Codon Usage Order (SCUO) for a DNA sequence.
|
| 503 |
+
Uses the GCUA library implementation based on information theory.
|
| 504 |
+
|
| 505 |
+
Args:
|
| 506 |
+
sequence (str): The DNA sequence.
|
| 507 |
+
|
| 508 |
+
Returns:
|
| 509 |
+
float: The SCUO value (0-1, where 1 indicates maximum bias).
|
| 510 |
+
"""
|
| 511 |
+
# Self-contained SCUO implementation (no external GCUA dependency).
|
| 512 |
+
# Based on Wan et al., 2004 information-theoretic definition.
|
| 513 |
+
|
| 514 |
+
from math import log2 # local import to avoid global cost
|
| 515 |
+
try:
|
| 516 |
+
# Build standard genetic code mapping using built-in tables (Biopython optional).
|
| 517 |
+
# Fall back to hard-coded table if Biopython absent.
|
| 518 |
+
try:
|
| 519 |
+
from Bio.Data import CodonTable # type: ignore
|
| 520 |
+
codon_to_aa = CodonTable.unambiguous_dna_by_id[1].forward_table
|
| 521 |
+
except Exception:
|
| 522 |
+
codon_to_aa = {
|
| 523 |
+
# Partial table sufficient for SCUO calculation; stop codons omitted.
|
| 524 |
+
'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
|
| 525 |
+
'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
|
| 526 |
+
'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
|
| 527 |
+
'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
|
| 528 |
+
'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
|
| 529 |
+
'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
|
| 530 |
+
'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
|
| 531 |
+
'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
|
| 532 |
+
'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
|
| 533 |
+
'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
|
| 534 |
+
'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
|
| 535 |
+
'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
|
| 536 |
+
'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
|
| 537 |
+
'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
|
| 538 |
+
'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
|
| 539 |
+
'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
# Group codons by amino acid (exclude stops)
|
| 543 |
+
aa_to_codons = {}
|
| 544 |
+
for codon, aa in codon_to_aa.items():
|
| 545 |
+
aa_to_codons.setdefault(aa, []).append(codon)
|
| 546 |
+
|
| 547 |
+
# Count codon occurrences in input sequence
|
| 548 |
+
seq = sequence.upper().replace('U', 'T')
|
| 549 |
+
codon_counts = {}
|
| 550 |
+
for i in range(0, len(seq) - len(seq) % 3, 3):
|
| 551 |
+
codon = seq[i:i+3]
|
| 552 |
+
if codon in codon_to_aa:
|
| 553 |
+
codon_counts[codon] = codon_counts.get(codon, 0) + 1
|
| 554 |
+
|
| 555 |
+
total_codons = sum(codon_counts.values())
|
| 556 |
+
if total_codons == 0:
|
| 557 |
+
return 0.0
|
| 558 |
+
|
| 559 |
+
scuo_sum = 0.0
|
| 560 |
+
|
| 561 |
+
for aa, codons in aa_to_codons.items():
|
| 562 |
+
n_codons = len(codons)
|
| 563 |
+
if n_codons == 1:
|
| 564 |
+
continue # SCUO undefined for Met/Trp
|
| 565 |
+
|
| 566 |
+
counts = [codon_counts.get(c, 0) for c in codons]
|
| 567 |
+
total_aa = sum(counts)
|
| 568 |
+
if total_aa == 0:
|
| 569 |
+
continue
|
| 570 |
+
|
| 571 |
+
probs = [c / total_aa for c in counts if c]
|
| 572 |
+
H_obs = -sum(p * log2(p) for p in probs)
|
| 573 |
+
H_max = log2(n_codons)
|
| 574 |
+
O_i = (H_max - H_obs) / H_max if H_max else 0.0
|
| 575 |
+
F_i = total_aa / total_codons
|
| 576 |
+
scuo_sum += F_i * O_i
|
| 577 |
+
|
| 578 |
+
return scuo_sum
|
| 579 |
+
|
| 580 |
+
except Exception as exc:
|
| 581 |
+
print(f"Warning: internal SCUO computation failed ({exc}). Returning 0.5.")
|
| 582 |
+
return 0.5
|
| 583 |
+
|
CodonTransformer/CodonJupyter.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: CodonJupyter.py
|
| 3 |
+
---------------------
|
| 4 |
+
Includes Jupyter-specific functions for displaying interactive widgets.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, List, Tuple
|
| 8 |
+
|
| 9 |
+
import ipywidgets as widgets
|
| 10 |
+
from IPython.display import HTML, display
|
| 11 |
+
|
| 12 |
+
from CodonTransformer.CodonUtils import (
|
| 13 |
+
COMMON_ORGANISMS,
|
| 14 |
+
ID2ORGANISM,
|
| 15 |
+
ORGANISM2ID,
|
| 16 |
+
DNASequencePrediction,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class UserContainer:
|
| 21 |
+
"""
|
| 22 |
+
A container class to store user inputs for organism and protein sequence.
|
| 23 |
+
Attributes:
|
| 24 |
+
organism (int): The selected organism id.
|
| 25 |
+
protein (str): The input protein sequence.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self) -> None:
|
| 29 |
+
self.organism: int = -1
|
| 30 |
+
self.protein: str = ""
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_styled_options(
|
| 34 |
+
organisms: list, organism2id: Dict[str, int], is_fine_tuned: bool = False
|
| 35 |
+
) -> list:
|
| 36 |
+
"""
|
| 37 |
+
Create styled options for the dropdown widget.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
organisms (list): List of organism names.
|
| 41 |
+
organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs.
|
| 42 |
+
is_fine_tuned (bool): Whether these are fine-tuned organisms.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
list: Styled options for the dropdown widget.
|
| 46 |
+
"""
|
| 47 |
+
styled_options = []
|
| 48 |
+
for organism in organisms:
|
| 49 |
+
organism_id = organism2id[organism]
|
| 50 |
+
if is_fine_tuned:
|
| 51 |
+
if organism_id < 10:
|
| 52 |
+
styled_options.append(f"\u200b{organism_id:>6}. {organism}")
|
| 53 |
+
elif organism_id < 100:
|
| 54 |
+
styled_options.append(f"\u200b{organism_id:>5}. {organism}")
|
| 55 |
+
else:
|
| 56 |
+
styled_options.append(f"\u200b{organism_id:>4}. {organism}")
|
| 57 |
+
else:
|
| 58 |
+
if organism_id < 10:
|
| 59 |
+
styled_options.append(f"{organism_id:>6}. {organism}")
|
| 60 |
+
elif organism_id < 100:
|
| 61 |
+
styled_options.append(f"{organism_id:>5}. {organism}")
|
| 62 |
+
else:
|
| 63 |
+
styled_options.append(f"{organism_id:>4}. {organism}")
|
| 64 |
+
return styled_options
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def create_dropdown_options(organism2id: Dict[str, int]) -> list:
|
| 68 |
+
"""
|
| 69 |
+
Create the full list of dropdown options, including section headers.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
list: Full list of dropdown options.
|
| 76 |
+
"""
|
| 77 |
+
fine_tuned_organisms = sorted(
|
| 78 |
+
[org for org in organism2id.keys() if org in COMMON_ORGANISMS]
|
| 79 |
+
)
|
| 80 |
+
all_organisms = sorted(organism2id.keys())
|
| 81 |
+
|
| 82 |
+
fine_tuned_options = create_styled_options(
|
| 83 |
+
fine_tuned_organisms, organism2id, is_fine_tuned=True
|
| 84 |
+
)
|
| 85 |
+
all_organisms_options = create_styled_options(
|
| 86 |
+
all_organisms, organism2id, is_fine_tuned=False
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
return (
|
| 90 |
+
[""]
|
| 91 |
+
+ ["Selected Organisms"]
|
| 92 |
+
+ fine_tuned_options
|
| 93 |
+
+ [""]
|
| 94 |
+
+ ["All Organisms"]
|
| 95 |
+
+ all_organisms_options
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def create_organism_dropdown(container: UserContainer) -> widgets.Dropdown:
|
| 100 |
+
"""
|
| 101 |
+
Create and configure the organism dropdown widget.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
container (UserContainer): Container to store the selected organism.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
widgets.Dropdown: Configured dropdown widget.
|
| 108 |
+
"""
|
| 109 |
+
dropdown = widgets.Dropdown(
|
| 110 |
+
options=create_dropdown_options(ORGANISM2ID),
|
| 111 |
+
description="",
|
| 112 |
+
layout=widgets.Layout(width="40%", margin="0 0 10px 0"),
|
| 113 |
+
style={"description_width": "initial"},
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def show_organism(change: Dict[str, str]) -> None:
|
| 117 |
+
"""
|
| 118 |
+
Update the container with the selected organism and print to terminal.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
change (Dict[str, str]): Information about the change in dropdown value.
|
| 122 |
+
"""
|
| 123 |
+
dropdown_choice = change["new"]
|
| 124 |
+
if dropdown_choice and dropdown_choice not in [
|
| 125 |
+
"Selected Organisms",
|
| 126 |
+
"All Organisms",
|
| 127 |
+
]:
|
| 128 |
+
organism = "".join(filter(str.isdigit, dropdown_choice))
|
| 129 |
+
organism_id = ID2ORGANISM[int(organism)]
|
| 130 |
+
container.organism = organism_id
|
| 131 |
+
else:
|
| 132 |
+
container.organism = None
|
| 133 |
+
|
| 134 |
+
dropdown.observe(show_organism, names="value")
|
| 135 |
+
return dropdown
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_dropdown_style() -> str:
|
| 139 |
+
"""
|
| 140 |
+
Return the custom CSS style for the dropdown widget.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
str: CSS style string.
|
| 144 |
+
"""
|
| 145 |
+
return """
|
| 146 |
+
<style>
|
| 147 |
+
.widget-dropdown > select {
|
| 148 |
+
font-size: 16px;
|
| 149 |
+
font-weight: normal;
|
| 150 |
+
background-color: #f0f0f0;
|
| 151 |
+
border-radius: 5px;
|
| 152 |
+
padding: 5px;
|
| 153 |
+
}
|
| 154 |
+
.widget-label {
|
| 155 |
+
font-size: 18px;
|
| 156 |
+
font-weight: bold;
|
| 157 |
+
}
|
| 158 |
+
.custom-container {
|
| 159 |
+
display: flex;
|
| 160 |
+
flex-direction: column;
|
| 161 |
+
align-items: flex-start;
|
| 162 |
+
}
|
| 163 |
+
.widget-dropdown option[value^="\u200b"] {
|
| 164 |
+
font-family: sans-serif;
|
| 165 |
+
font-weight: bold;
|
| 166 |
+
font-size: 18px;
|
| 167 |
+
padding: 510px;
|
| 168 |
+
}
|
| 169 |
+
.widget-dropdown option[value*="Selected Organisms"],
|
| 170 |
+
.widget-dropdown option[value*="All Organisms"] {
|
| 171 |
+
text-align: center;
|
| 172 |
+
font-family: Arial, sans-serif;
|
| 173 |
+
font-weight: bold;
|
| 174 |
+
font-size: 20px;
|
| 175 |
+
color: #6900A1;
|
| 176 |
+
background-color: #00D8A1;
|
| 177 |
+
}
|
| 178 |
+
</style>
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def display_organism_dropdown(container: UserContainer) -> None:
|
| 183 |
+
"""
|
| 184 |
+
Display the organism dropdown widget and apply custom styles.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
container (UserContainer): Container to store the selected organism.
|
| 188 |
+
"""
|
| 189 |
+
dropdown = create_organism_dropdown(container)
|
| 190 |
+
header = widgets.HTML(
|
| 191 |
+
'<b style="font-size:20px;">Select Organism:</b>'
|
| 192 |
+
'<div style="height:10px;"></div>'
|
| 193 |
+
)
|
| 194 |
+
container_widget = widgets.VBox(
|
| 195 |
+
[header, dropdown],
|
| 196 |
+
layout=widgets.Layout(padding="12px 0 12px 25px"),
|
| 197 |
+
)
|
| 198 |
+
display(container_widget)
|
| 199 |
+
display(HTML(get_dropdown_style()))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def display_protein_input(container: UserContainer) -> None:
|
| 203 |
+
"""
|
| 204 |
+
Display a widget for entering a protein sequence and save it to the container.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
container (UserContainer): A container to store the entered protein sequence.
|
| 208 |
+
"""
|
| 209 |
+
protein_input = widgets.Textarea(
|
| 210 |
+
value="",
|
| 211 |
+
placeholder="Enter here...",
|
| 212 |
+
description="",
|
| 213 |
+
layout=widgets.Layout(width="100%", height="100px", margin="0 0 10px 0"),
|
| 214 |
+
style={"description_width": "initial"},
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Custom CSS for the input widget
|
| 218 |
+
input_style = """
|
| 219 |
+
<style>
|
| 220 |
+
.widget-textarea > textarea {
|
| 221 |
+
font-size: 12px;
|
| 222 |
+
font-family: Arial, sans-serif;
|
| 223 |
+
font-weight: normal;
|
| 224 |
+
background-color: #f0f0f0;
|
| 225 |
+
border-radius: 5px;
|
| 226 |
+
padding: 10px;
|
| 227 |
+
}
|
| 228 |
+
.widget-label {
|
| 229 |
+
font-size: 18px;
|
| 230 |
+
font-weight: bold;
|
| 231 |
+
}
|
| 232 |
+
.custom-container {
|
| 233 |
+
display: flex;
|
| 234 |
+
flex-direction: column;
|
| 235 |
+
align-items: flex-start;
|
| 236 |
+
}
|
| 237 |
+
</style>
|
| 238 |
+
"""
|
| 239 |
+
|
| 240 |
+
# Function to save the input protein sequence to the container
|
| 241 |
+
def save_protein(change: Dict[str, str]) -> None:
|
| 242 |
+
"""
|
| 243 |
+
Save the input protein sequence to the container.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
change (Dict[str, str]): A dictionary containing information about
|
| 247 |
+
the change in textarea value.
|
| 248 |
+
"""
|
| 249 |
+
container.protein = (
|
| 250 |
+
change["new"]
|
| 251 |
+
.upper()
|
| 252 |
+
.strip()
|
| 253 |
+
.replace("\n", "")
|
| 254 |
+
.replace(" ", "")
|
| 255 |
+
.replace("\t", "")
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Attach the function to the input widget
|
| 259 |
+
protein_input.observe(save_protein, names="value")
|
| 260 |
+
|
| 261 |
+
# Display the input widget
|
| 262 |
+
header = widgets.HTML(
|
| 263 |
+
'<b style="font-size:20px;">Enter Protein Sequence:</b>'
|
| 264 |
+
'<div style="height:18px;"></div>'
|
| 265 |
+
)
|
| 266 |
+
container_widget = widgets.VBox(
|
| 267 |
+
[header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px")
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
display(container_widget)
|
| 271 |
+
display(widgets.HTML(input_style))
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def format_model_output(output: DNASequencePrediction) -> str:
|
| 275 |
+
"""
|
| 276 |
+
Format DNA sequence prediction output in an appealing and easy-to-read manner.
|
| 277 |
+
|
| 278 |
+
This function takes the prediction output and formats it into
|
| 279 |
+
a structured string with clear section headers and separators.
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
output (DNASequencePrediction): Object containing the prediction output.
|
| 283 |
+
Expected attributes:
|
| 284 |
+
- organism (str): The organism name.
|
| 285 |
+
- protein (str): The input protein sequence.
|
| 286 |
+
- processed_input (str): The processed input sequence.
|
| 287 |
+
- predicted_dna (str): The predicted DNA sequence.
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
str: A formatted string containing the organized output.
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
def format_section(title: str, content: str) -> str:
|
| 294 |
+
"""Helper function to format individual sections."""
|
| 295 |
+
separator = "-" * 29
|
| 296 |
+
title_line = f"| {title.center(25)} |"
|
| 297 |
+
return f"{separator}\n{title_line}\n{separator}\n{content}\n\n"
|
| 298 |
+
|
| 299 |
+
sections: List[Tuple[str, str]] = [
|
| 300 |
+
("Organism", output.organism),
|
| 301 |
+
("Input Protein", output.protein),
|
| 302 |
+
("Processed Input", output.processed_input),
|
| 303 |
+
("Predicted DNA", output.predicted_dna),
|
| 304 |
+
]
|
| 305 |
+
|
| 306 |
+
formatted_output = ""
|
| 307 |
+
for title, content in sections:
|
| 308 |
+
formatted_output += format_section(title, content)
|
| 309 |
+
|
| 310 |
+
# Remove the last newline to avoid extra space at the end
|
| 311 |
+
return formatted_output.rstrip()
|
CodonTransformer/CodonPostProcessing.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: CodonPostProcessing.py
|
| 3 |
+
---------------------------
|
| 4 |
+
Post-processing utilities for codon optimization using DNAChisel.
|
| 5 |
+
This module provides sequence polishing capabilities to fix restriction sites,
|
| 6 |
+
homopolymers, and other constraints while preserving CAI and GC content.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import warnings
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from dnachisel import (
|
| 14 |
+
DnaOptimizationProblem,
|
| 15 |
+
AvoidPattern,
|
| 16 |
+
EnforceGCContent,
|
| 17 |
+
EnforceTranslation,
|
| 18 |
+
CodonOptimize,
|
| 19 |
+
)
|
| 20 |
+
DNACHISEL_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
DNACHISEL_AVAILABLE = False
|
| 23 |
+
# This warning will be shown when the module is first imported.
|
| 24 |
+
warnings.warn(
|
| 25 |
+
"DNAChisel is not installed. Post-processing features will be disabled."
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def polish_sequence_with_dnachisel(
|
| 29 |
+
dna_sequence: str,
|
| 30 |
+
protein_sequence: str,
|
| 31 |
+
gc_bounds: tuple = (45.0, 55.0),
|
| 32 |
+
cai_species: str = "e_coli",
|
| 33 |
+
avoid_homopolymers_length: int = 6,
|
| 34 |
+
enzymes_to_avoid: list = None
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Polishes a DNA sequence using DNAChisel to meet lab synthesis constraints.
|
| 38 |
+
"""
|
| 39 |
+
if not DNACHISEL_AVAILABLE:
|
| 40 |
+
warnings.warn("DNAChisel not available, skipping post-processing.")
|
| 41 |
+
return dna_sequence
|
| 42 |
+
|
| 43 |
+
if enzymes_to_avoid is None:
|
| 44 |
+
# Common cloning enzymes
|
| 45 |
+
enzymes_to_avoid = ["EcoRI", "XbaI", "SpeI", "PstI", "NotI"]
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
# Start with the basic, essential constraints
|
| 49 |
+
constraints = [
|
| 50 |
+
EnforceTranslation(translation=protein_sequence),
|
| 51 |
+
EnforceGCContent(mini=gc_bounds[0] / 100.0, maxi=gc_bounds[1] / 100.0),
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
# Add enzyme avoidance constraints safely
|
| 55 |
+
for enzyme in enzymes_to_avoid:
|
| 56 |
+
try:
|
| 57 |
+
# This is the modern way to avoid enzyme sites
|
| 58 |
+
constraints.append(AvoidPattern.from_enzyme_name(enzyme))
|
| 59 |
+
except Exception:
|
| 60 |
+
warnings.warn(f"Could not find enzyme '{enzyme}' in DNAChisel library.")
|
| 61 |
+
|
| 62 |
+
# Add homopolymer avoidance constraints
|
| 63 |
+
for base in "ATGC":
|
| 64 |
+
constraints.append(AvoidPattern(base * avoid_homopolymers_length))
|
| 65 |
+
|
| 66 |
+
# Define the optimization problem
|
| 67 |
+
problem = DnaOptimizationProblem(
|
| 68 |
+
sequence=dna_sequence,
|
| 69 |
+
constraints=constraints,
|
| 70 |
+
objectives=[CodonOptimize(species=cai_species, method="match_codon_usage")]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Solve the problem
|
| 74 |
+
problem.resolve_constraints()
|
| 75 |
+
problem.optimize()
|
| 76 |
+
|
| 77 |
+
# Return the polished sequence
|
| 78 |
+
return problem.sequence
|
| 79 |
+
|
| 80 |
+
except Exception as e:
|
| 81 |
+
warnings.warn(f"DNAChisel post-processing failed with an error: {e}")
|
| 82 |
+
# Return the original sequence if polishing fails
|
| 83 |
+
return dna_sequence
|
CodonTransformer/CodonPrediction.py
ADDED
|
@@ -0,0 +1,1372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: CodonPrediction.py
|
| 3 |
+
---------------------------
|
| 4 |
+
Includes functions to tokenize input, load models, infer predicted dna sequences and
|
| 5 |
+
helper functions related to processing data for passing to the model.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import warnings
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 10 |
+
import heapq
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import onnxruntime as rt
|
| 15 |
+
import torch
|
| 16 |
+
import transformers
|
| 17 |
+
from transformers import (
|
| 18 |
+
AutoTokenizer,
|
| 19 |
+
BatchEncoding,
|
| 20 |
+
BigBirdConfig,
|
| 21 |
+
BigBirdForMaskedLM,
|
| 22 |
+
PreTrainedTokenizerFast,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from CodonTransformer.CodonData import get_merged_seq
|
| 26 |
+
from CodonTransformer.CodonUtils import (
|
| 27 |
+
AMINO_ACID_TO_INDEX,
|
| 28 |
+
INDEX2TOKEN,
|
| 29 |
+
NUM_ORGANISMS,
|
| 30 |
+
ORGANISM2ID,
|
| 31 |
+
TOKEN2INDEX,
|
| 32 |
+
DNASequencePrediction,
|
| 33 |
+
GC_COUNTS_PER_TOKEN,
|
| 34 |
+
CODON_GC_CONTENT,
|
| 35 |
+
AA_MIN_GC,
|
| 36 |
+
AA_MAX_GC,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def predict_dna_sequence(
|
| 41 |
+
protein: str,
|
| 42 |
+
organism: Union[int, str],
|
| 43 |
+
device: torch.device,
|
| 44 |
+
tokenizer: Union[str, PreTrainedTokenizerFast] = None,
|
| 45 |
+
model: Union[str, torch.nn.Module] = None,
|
| 46 |
+
attention_type: str = "original_full",
|
| 47 |
+
deterministic: bool = True,
|
| 48 |
+
temperature: float = 0.2,
|
| 49 |
+
top_p: float = 0.95,
|
| 50 |
+
num_sequences: int = 1,
|
| 51 |
+
match_protein: bool = False,
|
| 52 |
+
use_constrained_search: bool = False,
|
| 53 |
+
gc_bounds: Tuple[float, float] = (0.30, 0.70),
|
| 54 |
+
beam_size: int = 5,
|
| 55 |
+
length_penalty: float = 1.0,
|
| 56 |
+
diversity_penalty: float = 0.0,
|
| 57 |
+
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
|
| 58 |
+
"""
|
| 59 |
+
Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
|
| 60 |
+
|
| 61 |
+
This function takes a protein sequence and an organism (as ID or name) as input
|
| 62 |
+
and returns the predicted DNA sequence(s) using the CodonTransformer model. It can use
|
| 63 |
+
either provided tokenizer and model objects or load them from specified paths.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
protein (str): The input protein sequence for which to predict the DNA sequence.
|
| 67 |
+
organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
|
| 68 |
+
"Escherichia coli general"). If a string is provided, it will be converted
|
| 69 |
+
to the corresponding ID using ORGANISM2ID.
|
| 70 |
+
device (torch.device): The device (CPU or GPU) to run the model on.
|
| 71 |
+
tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file
|
| 72 |
+
path to load the tokenizer from, a pre-loaded tokenizer object, or None. If
|
| 73 |
+
None, it will be loaded from HuggingFace. Defaults to None.
|
| 74 |
+
model (Union[str, torch.nn.Module, None], optional): Either a file path to load
|
| 75 |
+
the model from, a pre-loaded model object, or None. If None, it will be
|
| 76 |
+
loaded from HuggingFace. Defaults to None.
|
| 77 |
+
attention_type (str, optional): The type of attention mechanism to use in the
|
| 78 |
+
model. Can be either 'block_sparse' or 'original_full'. Defaults to
|
| 79 |
+
"original_full".
|
| 80 |
+
deterministic (bool, optional): Whether to use deterministic decoding (most
|
| 81 |
+
likely tokens). If False, samples tokens according to their probabilities
|
| 82 |
+
adjusted by the temperature. Defaults to True.
|
| 83 |
+
temperature (float, optional): A value controlling the randomness of predictions
|
| 84 |
+
during non-deterministic decoding. Lower values (e.g., 0.2) make the model
|
| 85 |
+
more conservative, while higher values (e.g., 0.8) increase randomness.
|
| 86 |
+
Using high temperatures may result in prediction of DNA sequences that
|
| 87 |
+
do not translate to the input protein.
|
| 88 |
+
Recommended values are:
|
| 89 |
+
- Low randomness: 0.2
|
| 90 |
+
- Medium randomness: 0.5
|
| 91 |
+
- High randomness: 0.8
|
| 92 |
+
The temperature must be a positive float. Defaults to 0.2.
|
| 93 |
+
top_p (float, optional): The cumulative probability threshold for nucleus sampling.
|
| 94 |
+
Tokens with cumulative probability up to top_p are considered for sampling.
|
| 95 |
+
This parameter helps balance diversity and coherence in the predicted DNA sequences.
|
| 96 |
+
The value must be a float between 0 and 1. Defaults to 0.95.
|
| 97 |
+
num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
|
| 98 |
+
when deterministic is False. Defaults to 1.
|
| 99 |
+
match_protein (bool, optional): Ensures the predicted DNA sequence is translated
|
| 100 |
+
to the input protein sequence by sampling from only the respective codons of
|
| 101 |
+
given amino acids. Defaults to False.
|
| 102 |
+
use_constrained_search (bool, optional): Whether to use constrained beam search
|
| 103 |
+
with GC content bounds. Defaults to False.
|
| 104 |
+
gc_bounds (Tuple[float, float], optional): GC content bounds (min, max) for
|
| 105 |
+
constrained search. Defaults to (0.30, 0.70).
|
| 106 |
+
beam_size (int, optional): Beam size for constrained search. Defaults to 5.
|
| 107 |
+
length_penalty (float, optional): Length penalty for beam search scoring.
|
| 108 |
+
Defaults to 1.0.
|
| 109 |
+
diversity_penalty (float, optional): Diversity penalty to reduce repetitive
|
| 110 |
+
sequences. Defaults to 0.0.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
|
| 114 |
+
containing the prediction results:
|
| 115 |
+
- organism (str): Name of the organism used for prediction.
|
| 116 |
+
- protein (str): Input protein sequence for which DNA sequence is predicted.
|
| 117 |
+
- processed_input (str): Processed input sequence (merged protein and DNA).
|
| 118 |
+
- predicted_dna (str): Predicted DNA sequence.
|
| 119 |
+
|
| 120 |
+
Raises:
|
| 121 |
+
ValueError: If the protein sequence is empty, if the organism is invalid,
|
| 122 |
+
if the temperature is not a positive float, if top_p is not between 0 and 1,
|
| 123 |
+
or if num_sequences is less than 1 or used with deterministic mode.
|
| 124 |
+
|
| 125 |
+
Note:
|
| 126 |
+
This function uses ORGANISM2ID, INDEX2TOKEN, and AMINO_ACID_TO_INDEX dictionaries
|
| 127 |
+
imported from CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
|
| 128 |
+
corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to
|
| 129 |
+
respective codons. AMINO_ACID_TO_INDEX maps each amino acid and stop symbol to indices
|
| 130 |
+
of codon tokens that translate to it.
|
| 131 |
+
|
| 132 |
+
Example:
|
| 133 |
+
>>> import torch
|
| 134 |
+
>>> from transformers import AutoTokenizer, BigBirdForMaskedLM
|
| 135 |
+
>>> from CodonTransformer.CodonPrediction import predict_dna_sequence
|
| 136 |
+
>>> from CodonTransformer.CodonJupyter import format_model_output
|
| 137 |
+
>>>
|
| 138 |
+
>>> # Set up device
|
| 139 |
+
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 140 |
+
>>>
|
| 141 |
+
>>> # Load tokenizer and model
|
| 142 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 143 |
+
>>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
|
| 144 |
+
>>> model = model.to(device)
|
| 145 |
+
>>>
|
| 146 |
+
>>> # Define protein sequence and organism
|
| 147 |
+
>>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"
|
| 148 |
+
>>> organism = "Escherichia coli general"
|
| 149 |
+
>>>
|
| 150 |
+
>>> # Predict DNA sequence with deterministic decoding (single sequence)
|
| 151 |
+
>>> output = predict_dna_sequence(
|
| 152 |
+
... protein=protein,
|
| 153 |
+
... organism=organism,
|
| 154 |
+
... device=device,
|
| 155 |
+
... tokenizer=tokenizer,
|
| 156 |
+
... model=model,
|
| 157 |
+
... attention_type="original_full",
|
| 158 |
+
... deterministic=True
|
| 159 |
+
... )
|
| 160 |
+
>>>
|
| 161 |
+
>>> # Predict DNA sequence with constrained beam search
|
| 162 |
+
>>> output_constrained = predict_dna_sequence(
|
| 163 |
+
... protein=protein,
|
| 164 |
+
... organism=organism,
|
| 165 |
+
... device=device,
|
| 166 |
+
... tokenizer=tokenizer,
|
| 167 |
+
... model=model,
|
| 168 |
+
... use_constrained_search=True,
|
| 169 |
+
... gc_bounds=(0.40, 0.60),
|
| 170 |
+
... beam_size=10,
|
| 171 |
+
... length_penalty=1.2,
|
| 172 |
+
... diversity_penalty=0.1
|
| 173 |
+
... )
|
| 174 |
+
>>>
|
| 175 |
+
>>> # Predict multiple DNA sequences with low randomness and top_p sampling
|
| 176 |
+
>>> output_random = predict_dna_sequence(
|
| 177 |
+
... protein=protein,
|
| 178 |
+
... organism=organism,
|
| 179 |
+
... device=device,
|
| 180 |
+
... tokenizer=tokenizer,
|
| 181 |
+
... model=model,
|
| 182 |
+
... attention_type="original_full",
|
| 183 |
+
... deterministic=False,
|
| 184 |
+
... temperature=0.2,
|
| 185 |
+
... top_p=0.95,
|
| 186 |
+
... num_sequences=3
|
| 187 |
+
... )
|
| 188 |
+
>>>
|
| 189 |
+
>>> print(format_model_output(output))
|
| 190 |
+
>>> for i, seq in enumerate(output_random, 1):
|
| 191 |
+
... print(f"Sequence {i}:")
|
| 192 |
+
... print(format_model_output(seq))
|
| 193 |
+
... print()
|
| 194 |
+
"""
|
| 195 |
+
if not protein:
|
| 196 |
+
raise ValueError("Protein sequence cannot be empty.")
|
| 197 |
+
|
| 198 |
+
if not isinstance(temperature, (float, int)) or temperature <= 0:
|
| 199 |
+
raise ValueError("Temperature must be a positive float.")
|
| 200 |
+
|
| 201 |
+
if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
|
| 202 |
+
raise ValueError("top_p must be a float between 0 and 1.")
|
| 203 |
+
|
| 204 |
+
if not isinstance(num_sequences, int) or num_sequences < 1:
|
| 205 |
+
raise ValueError("num_sequences must be a positive integer.")
|
| 206 |
+
|
| 207 |
+
if use_constrained_search:
|
| 208 |
+
if not isinstance(gc_bounds, tuple) or len(gc_bounds) != 2:
|
| 209 |
+
raise ValueError("gc_bounds must be a tuple of (min_gc, max_gc).")
|
| 210 |
+
|
| 211 |
+
if not (0.0 <= gc_bounds[0] <= gc_bounds[1] <= 1.0):
|
| 212 |
+
raise ValueError("gc_bounds must be between 0.0 and 1.0 with min <= max.")
|
| 213 |
+
|
| 214 |
+
if not isinstance(beam_size, int) or beam_size < 1:
|
| 215 |
+
raise ValueError("beam_size must be a positive integer.")
|
| 216 |
+
|
| 217 |
+
if deterministic and num_sequences > 1 and not use_constrained_search:
|
| 218 |
+
raise ValueError(
|
| 219 |
+
"Multiple sequences can only be generated in non-deterministic mode "
|
| 220 |
+
"(unless using constrained search)."
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if use_constrained_search and num_sequences > 1:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
"Constrained beam search currently supports only single sequence generation."
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Load tokenizer
|
| 229 |
+
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
| 230 |
+
tokenizer = load_tokenizer(tokenizer)
|
| 231 |
+
|
| 232 |
+
# Load model
|
| 233 |
+
if not isinstance(model, torch.nn.Module):
|
| 234 |
+
model = load_model(model_path=model, device=device, attention_type=attention_type)
|
| 235 |
+
else:
|
| 236 |
+
model.eval()
|
| 237 |
+
model.bert.set_attention_type(attention_type)
|
| 238 |
+
model.to(device)
|
| 239 |
+
|
| 240 |
+
# Validate organism and convert to organism_id and organism_name
|
| 241 |
+
organism_id, organism_name = validate_and_convert_organism(organism)
|
| 242 |
+
|
| 243 |
+
# Inference loop
|
| 244 |
+
with torch.no_grad():
|
| 245 |
+
# Tokenize the input sequence
|
| 246 |
+
merged_seq = get_merged_seq(protein=protein, dna="")
|
| 247 |
+
input_dict = {
|
| 248 |
+
"idx": 0, # sample index
|
| 249 |
+
"codons": merged_seq,
|
| 250 |
+
"organism": organism_id,
|
| 251 |
+
}
|
| 252 |
+
tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device)
|
| 253 |
+
|
| 254 |
+
# Get the model predictions
|
| 255 |
+
output_dict = model(**tokenized_input, return_dict=True)
|
| 256 |
+
logits = output_dict.logits.detach().cpu()
|
| 257 |
+
logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens
|
| 258 |
+
|
| 259 |
+
# Mask the logits of codons that do not correspond to the input protein sequence
|
| 260 |
+
if match_protein:
|
| 261 |
+
possible_tokens_per_position = [
|
| 262 |
+
AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
|
| 263 |
+
]
|
| 264 |
+
seq_len = logits.shape[1]
|
| 265 |
+
if len(possible_tokens_per_position) > seq_len:
|
| 266 |
+
possible_tokens_per_position = possible_tokens_per_position[:seq_len]
|
| 267 |
+
|
| 268 |
+
mask = torch.full_like(logits, float("-inf"))
|
| 269 |
+
|
| 270 |
+
for pos, possible_tokens in enumerate(possible_tokens_per_position):
|
| 271 |
+
mask[:, pos, possible_tokens] = 0
|
| 272 |
+
|
| 273 |
+
logits = mask + logits
|
| 274 |
+
|
| 275 |
+
predictions = []
|
| 276 |
+
for _ in range(num_sequences):
|
| 277 |
+
# Decode the predicted DNA sequence from the model output
|
| 278 |
+
if use_constrained_search:
|
| 279 |
+
# Use constrained beam search with GC bounds
|
| 280 |
+
predicted_indices = constrained_beam_search_simple(
|
| 281 |
+
logits=logits.squeeze(0),
|
| 282 |
+
protein_sequence=protein,
|
| 283 |
+
gc_bounds=gc_bounds,
|
| 284 |
+
max_attempts=50,
|
| 285 |
+
)
|
| 286 |
+
elif deterministic:
|
| 287 |
+
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
|
| 288 |
+
else:
|
| 289 |
+
predicted_indices = sample_non_deterministic(
|
| 290 |
+
logits=logits, temperature=temperature, top_p=top_p
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
|
| 294 |
+
predicted_dna = (
|
| 295 |
+
"".join([token[-3:] for token in predicted_dna]).strip().upper()
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
predictions.append(
|
| 299 |
+
DNASequencePrediction(
|
| 300 |
+
organism=organism_name,
|
| 301 |
+
protein=protein,
|
| 302 |
+
processed_input=merged_seq,
|
| 303 |
+
predicted_dna=predicted_dna,
|
| 304 |
+
)
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
return predictions[0] if num_sequences == 1 else predictions
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@dataclass
|
| 311 |
+
class BeamCandidate:
|
| 312 |
+
"""Represents a candidate sequence in the beam search."""
|
| 313 |
+
tokens: List[int]
|
| 314 |
+
score: float
|
| 315 |
+
gc_count: int
|
| 316 |
+
length: int
|
| 317 |
+
|
| 318 |
+
def __post_init__(self):
|
| 319 |
+
self.gc_ratio = self.gc_count / max(self.length, 1)
|
| 320 |
+
|
| 321 |
+
def __lt__(self, other):
|
| 322 |
+
return self.score < other.score
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _calculate_true_future_gc_range(
|
| 326 |
+
current_pos: int,
|
| 327 |
+
protein_sequence: str,
|
| 328 |
+
current_gc_count: int,
|
| 329 |
+
current_length: int
|
| 330 |
+
) -> Tuple[float, float]:
|
| 331 |
+
"""
|
| 332 |
+
Calculate the true minimum and maximum possible final GC content
|
| 333 |
+
given current state and remaining amino acids (perfect foresight).
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
current_pos: Current position in protein sequence
|
| 337 |
+
protein_sequence: Full protein sequence
|
| 338 |
+
current_gc_count: Current GC count in partial sequence
|
| 339 |
+
current_length: Current length in nucleotides
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
Tuple of (min_possible_final_gc_ratio, max_possible_final_gc_ratio)
|
| 343 |
+
"""
|
| 344 |
+
if current_pos >= len(protein_sequence):
|
| 345 |
+
# Already at end, return current ratio
|
| 346 |
+
final_ratio = current_gc_count / max(current_length, 1)
|
| 347 |
+
return final_ratio, final_ratio
|
| 348 |
+
|
| 349 |
+
# Calculate remaining amino acids
|
| 350 |
+
remaining_aas = protein_sequence[current_pos:]
|
| 351 |
+
|
| 352 |
+
# Calculate min/max possible GC from remaining amino acids
|
| 353 |
+
min_future_gc = 0
|
| 354 |
+
max_future_gc = 0
|
| 355 |
+
|
| 356 |
+
for aa in remaining_aas:
|
| 357 |
+
if aa.upper() in AA_MIN_GC and aa.upper() in AA_MAX_GC:
|
| 358 |
+
min_future_gc += AA_MIN_GC[aa.upper()]
|
| 359 |
+
max_future_gc += AA_MAX_GC[aa.upper()]
|
| 360 |
+
else:
|
| 361 |
+
# If amino acid not found, assume moderate GC (1-2 range)
|
| 362 |
+
min_future_gc += 1
|
| 363 |
+
max_future_gc += 2
|
| 364 |
+
|
| 365 |
+
# Calculate final sequence length
|
| 366 |
+
final_length = current_length + len(remaining_aas) * 3
|
| 367 |
+
|
| 368 |
+
# Calculate min/max possible final GC ratios
|
| 369 |
+
min_final_gc_ratio = (current_gc_count + min_future_gc) / final_length
|
| 370 |
+
max_final_gc_ratio = (current_gc_count + max_future_gc) / final_length
|
| 371 |
+
|
| 372 |
+
return min_final_gc_ratio, max_final_gc_ratio
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def constrained_beam_search_simple(
|
| 376 |
+
logits: torch.Tensor,
|
| 377 |
+
protein_sequence: str,
|
| 378 |
+
gc_bounds: Tuple[float, float] = (0.30, 0.70),
|
| 379 |
+
max_attempts: int = 100,
|
| 380 |
+
) -> List[int]:
|
| 381 |
+
"""
|
| 382 |
+
Simple constrained search - try multiple greedy samples and pick best one within GC bounds.
|
| 383 |
+
"""
|
| 384 |
+
min_gc, max_gc = gc_bounds
|
| 385 |
+
seq_len = min(logits.shape[0], len(protein_sequence))
|
| 386 |
+
|
| 387 |
+
# Convert to probabilities
|
| 388 |
+
probs = torch.softmax(logits, dim=-1)
|
| 389 |
+
|
| 390 |
+
valid_sequences = []
|
| 391 |
+
|
| 392 |
+
for attempt in range(max_attempts):
|
| 393 |
+
tokens = []
|
| 394 |
+
total_gc = 0
|
| 395 |
+
|
| 396 |
+
# Generate sequence position by position
|
| 397 |
+
for pos in range(seq_len):
|
| 398 |
+
aa = protein_sequence[pos]
|
| 399 |
+
possible_tokens = AMINO_ACID_TO_INDEX.get(aa, [])
|
| 400 |
+
|
| 401 |
+
if not possible_tokens:
|
| 402 |
+
continue
|
| 403 |
+
|
| 404 |
+
# Filter tokens by current constraints and get probabilities
|
| 405 |
+
candidates = []
|
| 406 |
+
for token_idx in possible_tokens:
|
| 407 |
+
if token_idx < len(probs[pos]) and token_idx < len(GC_COUNTS_PER_TOKEN):
|
| 408 |
+
prob = probs[pos][token_idx].item()
|
| 409 |
+
gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
|
| 410 |
+
|
| 411 |
+
# Check if this token could still lead to a valid final sequence (perfect foresight)
|
| 412 |
+
new_gc_total = total_gc + gc_contribution
|
| 413 |
+
new_length = (pos + 1) * 3
|
| 414 |
+
|
| 415 |
+
# Calculate what's possible for the final sequence given this choice
|
| 416 |
+
min_final_gc, max_final_gc = _calculate_true_future_gc_range(
|
| 417 |
+
pos + 1, protein_sequence, new_gc_total, new_length
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Only prune if there's NO OVERLAP between possible final range and target bounds
|
| 421 |
+
if max_final_gc >= min_gc and min_final_gc <= max_gc:
|
| 422 |
+
# Calculate gentle GC penalty to steer toward target center
|
| 423 |
+
target_gc = (min_gc + max_gc) / 2 # Target center (e.g., 0.50 for bounds 0.45-0.55)
|
| 424 |
+
current_projected_gc = (min_final_gc + max_final_gc) / 2 # Projected center
|
| 425 |
+
|
| 426 |
+
# Only apply penalty if we're significantly off-target AND late in sequence
|
| 427 |
+
sequence_progress = (pos + 1) / seq_len
|
| 428 |
+
if sequence_progress > 0.3: # Only apply penalty after 30% of sequence
|
| 429 |
+
gc_deviation = abs(current_projected_gc - target_gc)
|
| 430 |
+
if gc_deviation > 0.05: # Only if >5% deviation from target
|
| 431 |
+
# Gentle penalty: reduce probability by small factor
|
| 432 |
+
penalty_factor = max(0.7, 1.0 - 0.3 * gc_deviation) # 0.7-1.0 range
|
| 433 |
+
prob = prob * penalty_factor
|
| 434 |
+
|
| 435 |
+
candidates.append((token_idx, prob, gc_contribution))
|
| 436 |
+
|
| 437 |
+
if not candidates:
|
| 438 |
+
# If no valid candidates, break and try next attempt
|
| 439 |
+
break
|
| 440 |
+
|
| 441 |
+
# Sample from valid candidates (with temperature)
|
| 442 |
+
if attempt == 0:
|
| 443 |
+
# First attempt: greedy (highest probability)
|
| 444 |
+
best_token = max(candidates, key=lambda x: x[1])
|
| 445 |
+
else:
|
| 446 |
+
# Other attempts: sample with some randomness
|
| 447 |
+
probs_list = [c[1] for c in candidates]
|
| 448 |
+
if sum(probs_list) > 0:
|
| 449 |
+
# Normalize probabilities
|
| 450 |
+
probs_array = np.array(probs_list)
|
| 451 |
+
probs_array = probs_array / probs_array.sum()
|
| 452 |
+
# Sample
|
| 453 |
+
chosen_idx = np.random.choice(len(candidates), p=probs_array)
|
| 454 |
+
best_token = candidates[chosen_idx]
|
| 455 |
+
else:
|
| 456 |
+
best_token = candidates[0]
|
| 457 |
+
|
| 458 |
+
tokens.append(best_token[0])
|
| 459 |
+
total_gc += best_token[2]
|
| 460 |
+
|
| 461 |
+
# Check if we got a complete sequence
|
| 462 |
+
if len(tokens) == seq_len:
|
| 463 |
+
final_gc_ratio = total_gc / (seq_len * 3)
|
| 464 |
+
if min_gc <= final_gc_ratio <= max_gc:
|
| 465 |
+
# Calculate sequence score (sum of log probabilities)
|
| 466 |
+
score = sum(np.log(probs[i][tokens[i]].item() + 1e-8) for i in range(len(tokens)))
|
| 467 |
+
valid_sequences.append((tokens, score, final_gc_ratio))
|
| 468 |
+
|
| 469 |
+
if not valid_sequences:
|
| 470 |
+
raise ValueError(f"Could not generate valid sequence within GC bounds {gc_bounds} after {max_attempts} attempts")
|
| 471 |
+
|
| 472 |
+
# Return the sequence with highest score
|
| 473 |
+
best_sequence = max(valid_sequences, key=lambda x: x[1])
|
| 474 |
+
return best_sequence[0]
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def constrained_beam_search(
|
| 478 |
+
logits: torch.Tensor,
|
| 479 |
+
protein_sequence: str,
|
| 480 |
+
gc_bounds: Tuple[float, float] = (0.30, 0.70),
|
| 481 |
+
beam_size: int = 5,
|
| 482 |
+
length_penalty: float = 1.0,
|
| 483 |
+
diversity_penalty: float = 0.0,
|
| 484 |
+
temperature: float = 1.0,
|
| 485 |
+
max_candidates: int = 100,
|
| 486 |
+
position_aware_gc_penalty: bool = True,
|
| 487 |
+
gc_penalty_strength: float = 2.0,
|
| 488 |
+
) -> List[int]:
|
| 489 |
+
"""
|
| 490 |
+
Constrained beam search with exact per-residue GC bounds tracking.
|
| 491 |
+
|
| 492 |
+
Priority #1: Exact per-residue GC bounds tracking
|
| 493 |
+
- Tracks cumulative GC content after each codon selection
|
| 494 |
+
- Prunes candidates that would violate GC bounds
|
| 495 |
+
- Maintains beam of valid candidates
|
| 496 |
+
|
| 497 |
+
Priority #2: Position-aware GC penalty mechanism
|
| 498 |
+
- Applies variable penalty weights based on sequence position
|
| 499 |
+
- Preserves flexibility early, applies pressure when necessary
|
| 500 |
+
- Uses progressive penalty scaling based on deviation severity
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
logits (torch.Tensor): Model logits of shape [seq_len, vocab_size]
|
| 504 |
+
protein_sequence (str): Input protein sequence
|
| 505 |
+
gc_bounds (Tuple[float, float]): (min_gc, max_gc) bounds
|
| 506 |
+
beam_size (int): Number of candidates to maintain
|
| 507 |
+
length_penalty (float): Length penalty for scoring
|
| 508 |
+
diversity_penalty (float): Diversity penalty for scoring
|
| 509 |
+
temperature (float): Temperature for probability scaling
|
| 510 |
+
max_candidates (int): Maximum candidates to consider per position
|
| 511 |
+
position_aware_gc_penalty (bool): Whether to use position-aware GC penalties
|
| 512 |
+
gc_penalty_strength (float): Strength of GC penalty adjustment
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
List[int]: Best sequence token indices
|
| 516 |
+
"""
|
| 517 |
+
min_gc, max_gc = gc_bounds
|
| 518 |
+
seq_len = logits.shape[0]
|
| 519 |
+
protein_len = len(protein_sequence)
|
| 520 |
+
|
| 521 |
+
# Ensure we don't go beyond the protein sequence
|
| 522 |
+
if seq_len > protein_len:
|
| 523 |
+
print(f"Warning: logits length ({seq_len}) > protein length ({protein_len}). Truncating to protein length.")
|
| 524 |
+
seq_len = protein_len
|
| 525 |
+
logits = logits[:protein_len]
|
| 526 |
+
|
| 527 |
+
# Initialize beam with empty candidate
|
| 528 |
+
beam = [BeamCandidate(tokens=[], score=0.0, gc_count=0, length=0)]
|
| 529 |
+
|
| 530 |
+
# Apply temperature scaling
|
| 531 |
+
if temperature != 1.0:
|
| 532 |
+
logits = logits / temperature
|
| 533 |
+
|
| 534 |
+
# Convert to probabilities
|
| 535 |
+
probs = torch.softmax(logits, dim=-1)
|
| 536 |
+
|
| 537 |
+
for pos in range(min(seq_len, len(protein_sequence))):
|
| 538 |
+
# Get possible tokens for current amino acid
|
| 539 |
+
aa = protein_sequence[pos]
|
| 540 |
+
possible_tokens = AMINO_ACID_TO_INDEX.get(aa, [])
|
| 541 |
+
|
| 542 |
+
if not possible_tokens:
|
| 543 |
+
# Fallback to all tokens if amino acid not found
|
| 544 |
+
possible_tokens = list(range(probs.shape[1]))
|
| 545 |
+
|
| 546 |
+
# Get top candidates for this position
|
| 547 |
+
pos_probs = probs[pos]
|
| 548 |
+
top_candidates = []
|
| 549 |
+
|
| 550 |
+
for token_idx in possible_tokens:
|
| 551 |
+
if token_idx < len(pos_probs) and token_idx < len(GC_COUNTS_PER_TOKEN):
|
| 552 |
+
prob = pos_probs[token_idx].item()
|
| 553 |
+
gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
|
| 554 |
+
# Only include tokens with valid probabilities
|
| 555 |
+
if prob > 1e-10: # Avoid extremely low probabilities
|
| 556 |
+
top_candidates.append((token_idx, prob, gc_contribution))
|
| 557 |
+
|
| 558 |
+
# Sort by probability and take top max_candidates
|
| 559 |
+
top_candidates.sort(key=lambda x: x[1], reverse=True)
|
| 560 |
+
top_candidates = top_candidates[:max_candidates]
|
| 561 |
+
|
| 562 |
+
# If no valid candidates found, fallback to all possible tokens for this amino acid
|
| 563 |
+
if not top_candidates:
|
| 564 |
+
for token_idx in possible_tokens[:min(len(possible_tokens), max_candidates)]:
|
| 565 |
+
if token_idx < len(pos_probs) and token_idx < len(GC_COUNTS_PER_TOKEN):
|
| 566 |
+
prob = max(pos_probs[token_idx].item(), 1e-10) # Ensure minimum probability
|
| 567 |
+
gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
|
| 568 |
+
top_candidates.append((token_idx, prob, gc_contribution))
|
| 569 |
+
|
| 570 |
+
# Generate new beam candidates
|
| 571 |
+
new_beam = []
|
| 572 |
+
|
| 573 |
+
for candidate in beam:
|
| 574 |
+
for token_idx, prob, gc_contribution in top_candidates:
|
| 575 |
+
# Calculate new GC stats
|
| 576 |
+
new_gc_count = candidate.gc_count + gc_contribution
|
| 577 |
+
new_length = candidate.length + 3 # Each codon is 3 nucleotides
|
| 578 |
+
new_gc_ratio = new_gc_count / new_length
|
| 579 |
+
|
| 580 |
+
# Priority #2: Position-aware GC penalty mechanism
|
| 581 |
+
gc_penalty = 0.0
|
| 582 |
+
if position_aware_gc_penalty:
|
| 583 |
+
# Calculate position weight (more penalty towards end of sequence)
|
| 584 |
+
position_weight = (pos + 1) / seq_len
|
| 585 |
+
|
| 586 |
+
# Calculate GC deviation severity
|
| 587 |
+
target_gc = (min_gc + max_gc) / 2
|
| 588 |
+
gc_deviation = abs(new_gc_ratio - target_gc)
|
| 589 |
+
deviation_severity = gc_deviation / ((max_gc - min_gc) / 2)
|
| 590 |
+
|
| 591 |
+
# Apply progressive penalty
|
| 592 |
+
if deviation_severity > 0.5: # Soft penalty zone
|
| 593 |
+
gc_penalty = gc_penalty_strength * position_weight * (deviation_severity - 0.5) ** 2
|
| 594 |
+
|
| 595 |
+
# Hard constraint: still prune sequences that exceed bounds
|
| 596 |
+
if new_gc_ratio < min_gc or new_gc_ratio > max_gc:
|
| 597 |
+
continue # Prune invalid candidates
|
| 598 |
+
else:
|
| 599 |
+
# Priority #1: Hard GC bounds only
|
| 600 |
+
if new_gc_ratio < min_gc or new_gc_ratio > max_gc:
|
| 601 |
+
continue # Prune invalid candidates
|
| 602 |
+
|
| 603 |
+
# Calculate score with GC penalty
|
| 604 |
+
new_score = candidate.score + np.log(prob + 1e-8) - gc_penalty
|
| 605 |
+
|
| 606 |
+
# Apply length penalty
|
| 607 |
+
if length_penalty != 1.0:
|
| 608 |
+
length_norm = ((pos + 1) ** length_penalty)
|
| 609 |
+
normalized_score = new_score / length_norm
|
| 610 |
+
else:
|
| 611 |
+
normalized_score = new_score
|
| 612 |
+
|
| 613 |
+
# Create new candidate
|
| 614 |
+
new_candidate = BeamCandidate(
|
| 615 |
+
tokens=candidate.tokens + [token_idx],
|
| 616 |
+
score=normalized_score,
|
| 617 |
+
gc_count=new_gc_count,
|
| 618 |
+
length=new_length
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
new_beam.append(new_candidate)
|
| 622 |
+
|
| 623 |
+
# Apply diversity penalty if specified
|
| 624 |
+
if diversity_penalty > 0.0:
|
| 625 |
+
new_beam = _apply_diversity_penalty(new_beam, diversity_penalty)
|
| 626 |
+
|
| 627 |
+
# Keep top beam_size candidates
|
| 628 |
+
beam = sorted(new_beam, key=lambda x: x.score, reverse=True)[:beam_size]
|
| 629 |
+
|
| 630 |
+
# Priority #3: Adaptive beam rescue for difficult sequences
|
| 631 |
+
if not beam:
|
| 632 |
+
# Attempt beam rescue by relaxing constraints progressively
|
| 633 |
+
rescue_attempts = 0
|
| 634 |
+
max_rescue_attempts = 3
|
| 635 |
+
|
| 636 |
+
while not beam and rescue_attempts < max_rescue_attempts:
|
| 637 |
+
rescue_attempts += 1
|
| 638 |
+
|
| 639 |
+
# Progressive relaxation strategy
|
| 640 |
+
if rescue_attempts == 1:
|
| 641 |
+
# First attempt: increase beam size and relax GC bounds slightly
|
| 642 |
+
temp_beam_size = min(beam_size * 2, max_candidates)
|
| 643 |
+
temp_gc_bounds = (min_gc * 0.95, max_gc * 1.05)
|
| 644 |
+
elif rescue_attempts == 2:
|
| 645 |
+
# Second attempt: further relax GC bounds and increase candidates
|
| 646 |
+
temp_beam_size = min(beam_size * 3, max_candidates)
|
| 647 |
+
temp_gc_bounds = (min_gc * 0.9, max_gc * 1.1)
|
| 648 |
+
else:
|
| 649 |
+
# Final attempt: maximum relaxation
|
| 650 |
+
temp_beam_size = max_candidates
|
| 651 |
+
temp_gc_bounds = (min_gc * 0.85, max_gc * 1.15)
|
| 652 |
+
|
| 653 |
+
# Retry beam generation with relaxed parameters
|
| 654 |
+
rescue_beam = []
|
| 655 |
+
# Use previous beam state or start fresh if this is the first position with no beam
|
| 656 |
+
previous_beam = beam if beam else [BeamCandidate(tokens=[], score=0.0, gc_count=0, length=0)]
|
| 657 |
+
for candidate in previous_beam:
|
| 658 |
+
for token_idx, prob, gc_contribution in top_candidates:
|
| 659 |
+
new_gc_count = candidate.gc_count + gc_contribution
|
| 660 |
+
new_length = candidate.length + 3
|
| 661 |
+
new_gc_ratio = new_gc_count / new_length
|
| 662 |
+
|
| 663 |
+
# Check relaxed bounds
|
| 664 |
+
if temp_gc_bounds[0] <= new_gc_ratio <= temp_gc_bounds[1]:
|
| 665 |
+
# Apply reduced GC penalty for rescue
|
| 666 |
+
gc_penalty = 0.0
|
| 667 |
+
if position_aware_gc_penalty:
|
| 668 |
+
position_weight = (pos + 1) / seq_len
|
| 669 |
+
target_gc = (min_gc + max_gc) / 2
|
| 670 |
+
gc_deviation = abs(new_gc_ratio - target_gc)
|
| 671 |
+
deviation_severity = gc_deviation / ((max_gc - min_gc) / 2)
|
| 672 |
+
|
| 673 |
+
# Reduced penalty for rescue
|
| 674 |
+
if deviation_severity > 0.7:
|
| 675 |
+
gc_penalty = (gc_penalty_strength * 0.5) * position_weight * (deviation_severity - 0.7) ** 2
|
| 676 |
+
|
| 677 |
+
new_score = candidate.score + np.log(prob + 1e-8) - gc_penalty
|
| 678 |
+
|
| 679 |
+
if length_penalty != 1.0:
|
| 680 |
+
length_norm = ((pos + 1) ** length_penalty)
|
| 681 |
+
normalized_score = new_score / length_norm
|
| 682 |
+
else:
|
| 683 |
+
normalized_score = new_score
|
| 684 |
+
|
| 685 |
+
rescue_candidate = BeamCandidate(
|
| 686 |
+
tokens=candidate.tokens + [token_idx],
|
| 687 |
+
score=normalized_score,
|
| 688 |
+
gc_count=new_gc_count,
|
| 689 |
+
length=new_length
|
| 690 |
+
)
|
| 691 |
+
rescue_beam.append(rescue_candidate)
|
| 692 |
+
|
| 693 |
+
# Keep top candidates from rescue attempt
|
| 694 |
+
if rescue_beam:
|
| 695 |
+
beam = sorted(rescue_beam, key=lambda x: x.score, reverse=True)[:temp_beam_size]
|
| 696 |
+
break
|
| 697 |
+
|
| 698 |
+
# If all rescue attempts failed, raise error
|
| 699 |
+
if not beam:
|
| 700 |
+
raise ValueError(
|
| 701 |
+
f"Beam rescue failed at position {pos} after {max_rescue_attempts} attempts. "
|
| 702 |
+
f"The GC constraints {gc_bounds} may be too restrictive for this protein sequence. "
|
| 703 |
+
f"Consider relaxing constraints or using a different approach."
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
# Return best candidate
|
| 707 |
+
best_candidate = max(beam, key=lambda x: x.score)
|
| 708 |
+
return best_candidate.tokens
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
# Wrapper function that tries simple approach first
|
| 712 |
+
def constrained_beam_search_wrapper(
|
| 713 |
+
logits: torch.Tensor,
|
| 714 |
+
protein_sequence: str,
|
| 715 |
+
gc_bounds: Tuple[float, float] = (0.30, 0.70),
|
| 716 |
+
**kwargs
|
| 717 |
+
) -> List[int]:
|
| 718 |
+
"""Wrapper that tries simple approach first, falls back to complex beam search."""
|
| 719 |
+
try:
|
| 720 |
+
# Try simple approach first
|
| 721 |
+
return constrained_beam_search_simple(logits, protein_sequence, gc_bounds)
|
| 722 |
+
except ValueError:
|
| 723 |
+
# Fall back to complex beam search
|
| 724 |
+
return constrained_beam_search(logits, protein_sequence, gc_bounds, **kwargs)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def _apply_diversity_penalty(candidates: List[BeamCandidate], penalty: float) -> List[BeamCandidate]:
|
| 728 |
+
"""
|
| 729 |
+
Apply diversity penalty to reduce repetitive sequences.
|
| 730 |
+
|
| 731 |
+
Args:
|
| 732 |
+
candidates (List[BeamCandidate]): List of candidates
|
| 733 |
+
penalty (float): Diversity penalty strength
|
| 734 |
+
|
| 735 |
+
Returns:
|
| 736 |
+
List[BeamCandidate]: Candidates with diversity penalty applied
|
| 737 |
+
"""
|
| 738 |
+
if not candidates:
|
| 739 |
+
return candidates
|
| 740 |
+
|
| 741 |
+
# Count token occurrences
|
| 742 |
+
token_counts = {}
|
| 743 |
+
for candidate in candidates:
|
| 744 |
+
for token in candidate.tokens:
|
| 745 |
+
token_counts[token] = token_counts.get(token, 0) + 1
|
| 746 |
+
|
| 747 |
+
# Apply penalty
|
| 748 |
+
for candidate in candidates:
|
| 749 |
+
diversity_score = 0.0
|
| 750 |
+
for token in candidate.tokens:
|
| 751 |
+
if token_counts[token] > 1:
|
| 752 |
+
diversity_score += penalty * np.log(token_counts[token])
|
| 753 |
+
candidate.score -= diversity_score
|
| 754 |
+
|
| 755 |
+
return candidates
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def sample_non_deterministic(
|
| 759 |
+
logits: torch.Tensor,
|
| 760 |
+
temperature: float = 0.2,
|
| 761 |
+
top_p: float = 0.95,
|
| 762 |
+
) -> List[int]:
|
| 763 |
+
"""
|
| 764 |
+
Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.
|
| 765 |
+
|
| 766 |
+
This function applies temperature scaling to the logits, computes probabilities,
|
| 767 |
+
and then performs nucleus sampling to select token indices. It is used for
|
| 768 |
+
non-deterministic decoding in language models to introduce randomness while
|
| 769 |
+
maintaining coherence in the generated sequences.
|
| 770 |
+
|
| 771 |
+
Args:
|
| 772 |
+
logits (torch.Tensor): The logits output from the model of shape
|
| 773 |
+
[seq_len, vocab_size] or [batch_size, seq_len, vocab_size].
|
| 774 |
+
temperature (float, optional): Temperature value for scaling logits.
|
| 775 |
+
Must be a positive float. Defaults to 1.0.
|
| 776 |
+
top_p (float, optional): Cumulative probability threshold for nucleus sampling.
|
| 777 |
+
Must be a float between 0 and 1. Tokens with cumulative probability up to
|
| 778 |
+
`top_p` are considered for sampling. Defaults to 0.95.
|
| 779 |
+
|
| 780 |
+
Returns:
|
| 781 |
+
List[int]: A list of sampled token indices corresponding to the predicted tokens.
|
| 782 |
+
|
| 783 |
+
Raises:
|
| 784 |
+
ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1.
|
| 785 |
+
|
| 786 |
+
Example:
|
| 787 |
+
>>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size]
|
| 788 |
+
>>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9)
|
| 789 |
+
"""
|
| 790 |
+
if not isinstance(temperature, (float, int)) or temperature <= 0:
|
| 791 |
+
raise ValueError("Temperature must be a positive float.")
|
| 792 |
+
|
| 793 |
+
if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
|
| 794 |
+
raise ValueError("top_p must be a float between 0 and 1.")
|
| 795 |
+
|
| 796 |
+
# Compute probabilities using temperature scaling
|
| 797 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
# Remove batch dimension if present
|
| 801 |
+
if probs.dim() == 3:
|
| 802 |
+
probs = probs.squeeze(0) # Shape: [seq_len, vocab_size]
|
| 803 |
+
|
| 804 |
+
# Sort probabilities in descending order
|
| 805 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
| 806 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
| 807 |
+
mask = probs_sum - probs_sort > top_p
|
| 808 |
+
|
| 809 |
+
# Zero out probabilities for tokens beyond the top-p threshold
|
| 810 |
+
probs_sort[mask] = 0.0
|
| 811 |
+
|
| 812 |
+
# Renormalize the probabilities
|
| 813 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
| 814 |
+
next_token = torch.multinomial(probs_sort, num_samples=1)
|
| 815 |
+
predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1)
|
| 816 |
+
|
| 817 |
+
return predicted_indices.tolist()
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
def load_model(
|
| 821 |
+
model_path: Optional[str] = None,
|
| 822 |
+
device: torch.device = None,
|
| 823 |
+
attention_type: str = "original_full",
|
| 824 |
+
num_organisms: int = None,
|
| 825 |
+
remove_prefix: bool = True,
|
| 826 |
+
) -> torch.nn.Module:
|
| 827 |
+
"""
|
| 828 |
+
Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace.
|
| 829 |
+
|
| 830 |
+
Args:
|
| 831 |
+
model_path (Optional[str]): Path to the model file or checkpoint. If None,
|
| 832 |
+
load from HuggingFace.
|
| 833 |
+
device (torch.device, optional): The device to load the model onto.
|
| 834 |
+
attention_type (str, optional): The type of attention, 'block_sparse'
|
| 835 |
+
or 'original_full'.
|
| 836 |
+
num_organisms (int, optional): Number of organisms, needed if loading from a
|
| 837 |
+
checkpoint that requires this.
|
| 838 |
+
remove_prefix (bool, optional): Whether to remove the "model." prefix from the
|
| 839 |
+
keys in the state dict.
|
| 840 |
+
|
| 841 |
+
Returns:
|
| 842 |
+
torch.nn.Module: The loaded model.
|
| 843 |
+
"""
|
| 844 |
+
if not model_path:
|
| 845 |
+
warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning)
|
| 846 |
+
model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
|
| 847 |
+
elif model_path.endswith(".ckpt"):
|
| 848 |
+
checkpoint = torch.load(model_path, map_location="cpu")
|
| 849 |
+
|
| 850 |
+
# Detect Lightning checkpoint vs raw state dict
|
| 851 |
+
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
| 852 |
+
state_dict = checkpoint["state_dict"]
|
| 853 |
+
if remove_prefix:
|
| 854 |
+
state_dict = {
|
| 855 |
+
k.replace("model.", ""): v for k, v in state_dict.items()
|
| 856 |
+
}
|
| 857 |
+
else:
|
| 858 |
+
# assume checkpoint itself is state_dict
|
| 859 |
+
state_dict = checkpoint
|
| 860 |
+
|
| 861 |
+
if num_organisms is None:
|
| 862 |
+
num_organisms = NUM_ORGANISMS
|
| 863 |
+
|
| 864 |
+
# Load model configuration and instantiate the model
|
| 865 |
+
config = load_bigbird_config(num_organisms)
|
| 866 |
+
model = BigBirdForMaskedLM(config=config)
|
| 867 |
+
model.load_state_dict(state_dict, strict=False)
|
| 868 |
+
|
| 869 |
+
elif model_path.endswith(".pt"):
|
| 870 |
+
state_dict = torch.load(model_path)
|
| 871 |
+
config = state_dict.pop("self.config")
|
| 872 |
+
model = BigBirdForMaskedLM(config=config)
|
| 873 |
+
model.load_state_dict(state_dict, strict=False)
|
| 874 |
+
|
| 875 |
+
else:
|
| 876 |
+
raise ValueError(
|
| 877 |
+
"Unsupported file type. Please provide a .ckpt or .pt file, "
|
| 878 |
+
"or None to load from HuggingFace."
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
# Prepare model for evaluation
|
| 882 |
+
model.bert.set_attention_type(attention_type)
|
| 883 |
+
model.eval()
|
| 884 |
+
if device:
|
| 885 |
+
model.to(device)
|
| 886 |
+
|
| 887 |
+
return model
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def load_bigbird_config(num_organisms: int) -> BigBirdConfig:
|
| 891 |
+
"""
|
| 892 |
+
Load the config object used to train the BigBird transformer.
|
| 893 |
+
|
| 894 |
+
Args:
|
| 895 |
+
num_organisms (int): The number of organisms.
|
| 896 |
+
|
| 897 |
+
Returns:
|
| 898 |
+
BigBirdConfig: The configuration object for BigBird.
|
| 899 |
+
"""
|
| 900 |
+
config = transformers.BigBirdConfig(
|
| 901 |
+
vocab_size=len(TOKEN2INDEX), # Equal to len(tokenizer)
|
| 902 |
+
type_vocab_size=num_organisms,
|
| 903 |
+
sep_token_id=2,
|
| 904 |
+
)
|
| 905 |
+
return config
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def create_model_from_checkpoint(
|
| 909 |
+
checkpoint_dir: str, output_model_dir: str, num_organisms: int
|
| 910 |
+
) -> None:
|
| 911 |
+
"""
|
| 912 |
+
Save a model to disk using a previous checkpoint.
|
| 913 |
+
|
| 914 |
+
Args:
|
| 915 |
+
checkpoint_dir (str): Directory where the checkpoint is stored.
|
| 916 |
+
output_model_dir (str): Directory where the model will be saved.
|
| 917 |
+
num_organisms (int): Number of organisms.
|
| 918 |
+
"""
|
| 919 |
+
checkpoint = load_model(model_path=checkpoint_dir, num_organisms=num_organisms)
|
| 920 |
+
state_dict = checkpoint.state_dict()
|
| 921 |
+
state_dict["self.config"] = load_bigbird_config(num_organisms=num_organisms)
|
| 922 |
+
|
| 923 |
+
# Save the model state dict to the output directory
|
| 924 |
+
torch.save(state_dict, output_model_dir)
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def load_tokenizer(tokenizer_path: Optional[Union[str, PreTrainedTokenizerFast]] = None) -> PreTrainedTokenizerFast:
|
| 928 |
+
"""
|
| 929 |
+
Create and return a tokenizer object from tokenizer path or HuggingFace.
|
| 930 |
+
|
| 931 |
+
Args:
|
| 932 |
+
tokenizer_path (Optional[Union[str, PreTrainedTokenizerFast]]): Path to the tokenizer file,
|
| 933 |
+
a pre-loaded tokenizer object, or None. If None, load from HuggingFace.
|
| 934 |
+
|
| 935 |
+
Returns:
|
| 936 |
+
PreTrainedTokenizerFast: The tokenizer object.
|
| 937 |
+
"""
|
| 938 |
+
# If a tokenizer object is already provided, return it
|
| 939 |
+
if isinstance(tokenizer_path, PreTrainedTokenizerFast):
|
| 940 |
+
return tokenizer_path
|
| 941 |
+
|
| 942 |
+
# If no path is provided, load from HuggingFace
|
| 943 |
+
if not tokenizer_path:
|
| 944 |
+
warnings.warn(
|
| 945 |
+
"Tokenizer path not provided. Loading from HuggingFace.", UserWarning
|
| 946 |
+
)
|
| 947 |
+
return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 948 |
+
|
| 949 |
+
# Load from file path
|
| 950 |
+
return transformers.PreTrainedTokenizerFast(
|
| 951 |
+
tokenizer_file=tokenizer_path,
|
| 952 |
+
bos_token="[CLS]",
|
| 953 |
+
eos_token="[SEP]",
|
| 954 |
+
unk_token="[UNK]",
|
| 955 |
+
sep_token="[SEP]",
|
| 956 |
+
pad_token="[PAD]",
|
| 957 |
+
cls_token="[CLS]",
|
| 958 |
+
mask_token="[MASK]",
|
| 959 |
+
)
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def tokenize(
|
| 963 |
+
batch: List[Dict[str, Any]],
|
| 964 |
+
tokenizer: Union[PreTrainedTokenizerFast, str] = None,
|
| 965 |
+
max_len: int = 2048,
|
| 966 |
+
) -> BatchEncoding:
|
| 967 |
+
"""
|
| 968 |
+
Return the tokenized sequences given a batch of input data.
|
| 969 |
+
Each data in the batch is expected to be a dictionary with "codons" and
|
| 970 |
+
"organism" keys.
|
| 971 |
+
|
| 972 |
+
Args:
|
| 973 |
+
batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and
|
| 974 |
+
"organism" keys.
|
| 975 |
+
tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or
|
| 976 |
+
path to the tokenizer file.
|
| 977 |
+
max_len (int, optional): Maximum length of the tokenized sequence.
|
| 978 |
+
|
| 979 |
+
Returns:
|
| 980 |
+
BatchEncoding: The tokenized batch.
|
| 981 |
+
"""
|
| 982 |
+
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
| 983 |
+
tokenizer = load_tokenizer(tokenizer)
|
| 984 |
+
|
| 985 |
+
tokenized = tokenizer(
|
| 986 |
+
[data["codons"] for data in batch],
|
| 987 |
+
return_attention_mask=True,
|
| 988 |
+
return_token_type_ids=True,
|
| 989 |
+
truncation=True,
|
| 990 |
+
padding=True,
|
| 991 |
+
max_length=max_len,
|
| 992 |
+
return_tensors="pt",
|
| 993 |
+
)
|
| 994 |
+
|
| 995 |
+
# Add token type IDs for species
|
| 996 |
+
seq_len = tokenized["input_ids"].shape[-1]
|
| 997 |
+
species_index = torch.tensor([[data["organism"]] for data in batch])
|
| 998 |
+
tokenized["token_type_ids"] = species_index.repeat(1, seq_len)
|
| 999 |
+
|
| 1000 |
+
return tokenized
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]:
|
| 1004 |
+
"""
|
| 1005 |
+
Validate and convert the organism input to both ID and name.
|
| 1006 |
+
|
| 1007 |
+
This function takes either an organism ID or name as input and returns both
|
| 1008 |
+
the ID and name. It performs validation to ensure the input corresponds to
|
| 1009 |
+
a valid organism in the ORGANISM2ID dictionary.
|
| 1010 |
+
|
| 1011 |
+
Args:
|
| 1012 |
+
organism (Union[int, str]): Either the ID of the organism (int) or its
|
| 1013 |
+
name (str).
|
| 1014 |
+
|
| 1015 |
+
Returns:
|
| 1016 |
+
Tuple[int, str]: A tuple containing the organism ID (int) and name (str).
|
| 1017 |
+
|
| 1018 |
+
Raises:
|
| 1019 |
+
ValueError: If the input is neither a string nor an integer, if the
|
| 1020 |
+
organism name is not found in ORGANISM2ID, if the organism ID is not a
|
| 1021 |
+
value in ORGANISM2ID, or if no name is found for a given ID.
|
| 1022 |
+
|
| 1023 |
+
Note:
|
| 1024 |
+
This function relies on the ORGANISM2ID dictionary imported from
|
| 1025 |
+
CodonTransformer.CodonUtils, which maps organism names to their
|
| 1026 |
+
corresponding IDs.
|
| 1027 |
+
"""
|
| 1028 |
+
if isinstance(organism, str):
|
| 1029 |
+
if organism not in ORGANISM2ID:
|
| 1030 |
+
raise ValueError(
|
| 1031 |
+
f"Invalid organism name: {organism}. "
|
| 1032 |
+
"Please use a valid organism name or ID."
|
| 1033 |
+
)
|
| 1034 |
+
organism_id = ORGANISM2ID[organism]
|
| 1035 |
+
organism_name = organism
|
| 1036 |
+
|
| 1037 |
+
elif isinstance(organism, int):
|
| 1038 |
+
if organism not in ORGANISM2ID.values():
|
| 1039 |
+
raise ValueError(
|
| 1040 |
+
f"Invalid organism ID: {organism}. "
|
| 1041 |
+
"Please use a valid organism name or ID."
|
| 1042 |
+
)
|
| 1043 |
+
|
| 1044 |
+
organism_id = organism
|
| 1045 |
+
organism_name = next(
|
| 1046 |
+
(name for name, id in ORGANISM2ID.items() if id == organism), None
|
| 1047 |
+
)
|
| 1048 |
+
if organism_name is None:
|
| 1049 |
+
raise ValueError(f"No organism name found for ID: {organism}")
|
| 1050 |
+
|
| 1051 |
+
return organism_id, organism_name
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def get_high_frequency_choice_sequence(
|
| 1055 |
+
protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
|
| 1056 |
+
) -> str:
|
| 1057 |
+
"""
|
| 1058 |
+
Return the DNA sequence optimized using High Frequency Choice (HFC) approach
|
| 1059 |
+
in which the most frequent codon for a given amino acid is always chosen.
|
| 1060 |
+
|
| 1061 |
+
Args:
|
| 1062 |
+
protein (str): The protein sequence.
|
| 1063 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1064 |
+
frequencies for each amino acid.
|
| 1065 |
+
|
| 1066 |
+
Returns:
|
| 1067 |
+
str: The optimized DNA sequence.
|
| 1068 |
+
"""
|
| 1069 |
+
# Select the most frequent codon for each amino acid in the protein sequence
|
| 1070 |
+
dna_codons = [
|
| 1071 |
+
codon_frequencies[aminoacid][0][np.argmax(codon_frequencies[aminoacid][1])]
|
| 1072 |
+
for aminoacid in protein
|
| 1073 |
+
]
|
| 1074 |
+
return "".join(dna_codons)
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
def precompute_most_frequent_codons(
|
| 1078 |
+
codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
|
| 1079 |
+
) -> Dict[str, str]:
|
| 1080 |
+
"""
|
| 1081 |
+
Precompute the most frequent codon for each amino acid.
|
| 1082 |
+
|
| 1083 |
+
Args:
|
| 1084 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1085 |
+
frequencies for each amino acid.
|
| 1086 |
+
|
| 1087 |
+
Returns:
|
| 1088 |
+
Dict[str, str]: The most frequent codon for each amino acid.
|
| 1089 |
+
"""
|
| 1090 |
+
# Create a dictionary mapping each amino acid to its most frequent codon
|
| 1091 |
+
return {
|
| 1092 |
+
aminoacid: codons[np.argmax(frequencies)]
|
| 1093 |
+
for aminoacid, (codons, frequencies) in codon_frequencies.items()
|
| 1094 |
+
}
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
def get_high_frequency_choice_sequence_optimized(
|
| 1098 |
+
protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
|
| 1099 |
+
) -> str:
|
| 1100 |
+
"""
|
| 1101 |
+
Efficient implementation of get_high_frequency_choice_sequence that uses
|
| 1102 |
+
vectorized operations and helper functions, achieving up to x10 faster speed.
|
| 1103 |
+
|
| 1104 |
+
Args:
|
| 1105 |
+
protein (str): The protein sequence.
|
| 1106 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1107 |
+
frequencies for each amino acid.
|
| 1108 |
+
|
| 1109 |
+
Returns:
|
| 1110 |
+
str: The optimized DNA sequence.
|
| 1111 |
+
"""
|
| 1112 |
+
# Precompute the most frequent codons for each amino acid
|
| 1113 |
+
most_frequent_codons = precompute_most_frequent_codons(codon_frequencies)
|
| 1114 |
+
|
| 1115 |
+
return "".join(most_frequent_codons[aminoacid] for aminoacid in protein)
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
def get_background_frequency_choice_sequence(
|
| 1119 |
+
protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
|
| 1120 |
+
) -> str:
|
| 1121 |
+
"""
|
| 1122 |
+
Return the DNA sequence optimized using Background Frequency Choice (BFC)
|
| 1123 |
+
approach in which a random codon for a given amino acid is chosen using
|
| 1124 |
+
the codon frequencies probability distribution.
|
| 1125 |
+
|
| 1126 |
+
Args:
|
| 1127 |
+
protein (str): The protein sequence.
|
| 1128 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1129 |
+
frequencies for each amino acid.
|
| 1130 |
+
|
| 1131 |
+
Returns:
|
| 1132 |
+
str: The optimized DNA sequence.
|
| 1133 |
+
"""
|
| 1134 |
+
# Select a random codon for each amino acid based on the codon frequencies
|
| 1135 |
+
# probability distribution
|
| 1136 |
+
dna_codons = [
|
| 1137 |
+
np.random.choice(
|
| 1138 |
+
codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1]
|
| 1139 |
+
)
|
| 1140 |
+
for aminoacid in protein
|
| 1141 |
+
]
|
| 1142 |
+
return "".join(dna_codons)
|
| 1143 |
+
|
| 1144 |
+
|
| 1145 |
+
def precompute_cdf(
|
| 1146 |
+
codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
|
| 1147 |
+
) -> Dict[str, Tuple[List[str], Any]]:
|
| 1148 |
+
"""
|
| 1149 |
+
Precompute the cumulative distribution function (CDF) for each amino acid.
|
| 1150 |
+
|
| 1151 |
+
Args:
|
| 1152 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1153 |
+
frequencies for each amino acid.
|
| 1154 |
+
|
| 1155 |
+
Returns:
|
| 1156 |
+
Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid.
|
| 1157 |
+
"""
|
| 1158 |
+
cdf = {}
|
| 1159 |
+
|
| 1160 |
+
# Calculate the cumulative distribution function for each amino acid
|
| 1161 |
+
for aminoacid, (codons, frequencies) in codon_frequencies.items():
|
| 1162 |
+
cdf[aminoacid] = (codons, np.cumsum(frequencies))
|
| 1163 |
+
|
| 1164 |
+
return cdf
|
| 1165 |
+
|
| 1166 |
+
|
| 1167 |
+
def get_background_frequency_choice_sequence_optimized(
|
| 1168 |
+
protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
|
| 1169 |
+
) -> str:
|
| 1170 |
+
"""
|
| 1171 |
+
Efficient implementation of get_background_frequency_choice_sequence that uses
|
| 1172 |
+
vectorized operations and helper functions, achieving up to x8 faster speed.
|
| 1173 |
+
|
| 1174 |
+
Args:
|
| 1175 |
+
protein (str): The protein sequence.
|
| 1176 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1177 |
+
frequencies for each amino acid.
|
| 1178 |
+
|
| 1179 |
+
Returns:
|
| 1180 |
+
str: The optimized DNA sequence.
|
| 1181 |
+
"""
|
| 1182 |
+
dna_codons = []
|
| 1183 |
+
cdf = precompute_cdf(codon_frequencies)
|
| 1184 |
+
|
| 1185 |
+
# Select a random codon for each amino acid using the precomputed CDFs
|
| 1186 |
+
for aminoacid in protein:
|
| 1187 |
+
codons, cumulative_prob = cdf[aminoacid]
|
| 1188 |
+
selected_codon_index = np.searchsorted(cumulative_prob, np.random.rand())
|
| 1189 |
+
dna_codons.append(codons[selected_codon_index])
|
| 1190 |
+
|
| 1191 |
+
return "".join(dna_codons)
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
def get_uniform_random_choice_sequence(
|
| 1195 |
+
protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
|
| 1196 |
+
) -> str:
|
| 1197 |
+
"""
|
| 1198 |
+
Return the DNA sequence optimized using Uniform Random Choice (URC) approach
|
| 1199 |
+
in which a random codon for a given amino acid is chosen using a uniform
|
| 1200 |
+
prior.
|
| 1201 |
+
|
| 1202 |
+
Args:
|
| 1203 |
+
protein (str): The protein sequence.
|
| 1204 |
+
codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
|
| 1205 |
+
frequencies for each amino acid.
|
| 1206 |
+
|
| 1207 |
+
Returns:
|
| 1208 |
+
str: The optimized DNA sequence.
|
| 1209 |
+
"""
|
| 1210 |
+
# Select a random codon for each amino acid using a uniform prior distribution
|
| 1211 |
+
dna_codons = []
|
| 1212 |
+
for aminoacid in protein:
|
| 1213 |
+
codons = codon_frequencies[aminoacid][0]
|
| 1214 |
+
random_index = np.random.randint(0, len(codons))
|
| 1215 |
+
dna_codons.append(codons[random_index])
|
| 1216 |
+
return "".join(dna_codons)
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> str:
|
| 1220 |
+
"""
|
| 1221 |
+
Return the optimized codon sequence for the given protein sequence using ICOR.
|
| 1222 |
+
|
| 1223 |
+
Credit: ICOR: improving codon optimization with recurrent neural networks
|
| 1224 |
+
Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas
|
| 1225 |
+
Densmore
|
| 1226 |
+
|
| 1227 |
+
Args:
|
| 1228 |
+
input_seq (str): The input protein sequence.
|
| 1229 |
+
model_path (str): The path to the ICOR model.
|
| 1230 |
+
stop_symbol (str): The symbol representing stop codons in the sequence.
|
| 1231 |
+
|
| 1232 |
+
Returns:
|
| 1233 |
+
str: The optimized DNA sequence.
|
| 1234 |
+
"""
|
| 1235 |
+
input_seq = input_seq.strip().upper()
|
| 1236 |
+
input_seq = input_seq.replace(stop_symbol, "*")
|
| 1237 |
+
|
| 1238 |
+
# Define categorical labels from when model was trained.
|
| 1239 |
+
labels = [
|
| 1240 |
+
"AAA",
|
| 1241 |
+
"AAC",
|
| 1242 |
+
"AAG",
|
| 1243 |
+
"AAT",
|
| 1244 |
+
"ACA",
|
| 1245 |
+
"ACG",
|
| 1246 |
+
"ACT",
|
| 1247 |
+
"AGC",
|
| 1248 |
+
"ATA",
|
| 1249 |
+
"ATC",
|
| 1250 |
+
"ATG",
|
| 1251 |
+
"ATT",
|
| 1252 |
+
"CAA",
|
| 1253 |
+
"CAC",
|
| 1254 |
+
"CAG",
|
| 1255 |
+
"CCG",
|
| 1256 |
+
"CCT",
|
| 1257 |
+
"CTA",
|
| 1258 |
+
"CTC",
|
| 1259 |
+
"CTG",
|
| 1260 |
+
"CTT",
|
| 1261 |
+
"GAA",
|
| 1262 |
+
"GAT",
|
| 1263 |
+
"GCA",
|
| 1264 |
+
"GCC",
|
| 1265 |
+
"GCG",
|
| 1266 |
+
"GCT",
|
| 1267 |
+
"GGA",
|
| 1268 |
+
"GGC",
|
| 1269 |
+
"GTC",
|
| 1270 |
+
"GTG",
|
| 1271 |
+
"GTT",
|
| 1272 |
+
"TAA",
|
| 1273 |
+
"TAT",
|
| 1274 |
+
"TCA",
|
| 1275 |
+
"TCG",
|
| 1276 |
+
"TCT",
|
| 1277 |
+
"TGG",
|
| 1278 |
+
"TGT",
|
| 1279 |
+
"TTA",
|
| 1280 |
+
"TTC",
|
| 1281 |
+
"TTG",
|
| 1282 |
+
"TTT",
|
| 1283 |
+
"ACC",
|
| 1284 |
+
"CAT",
|
| 1285 |
+
"CCA",
|
| 1286 |
+
"CGG",
|
| 1287 |
+
"CGT",
|
| 1288 |
+
"GAC",
|
| 1289 |
+
"GAG",
|
| 1290 |
+
"GGT",
|
| 1291 |
+
"AGT",
|
| 1292 |
+
"GGG",
|
| 1293 |
+
"GTA",
|
| 1294 |
+
"TGC",
|
| 1295 |
+
"CCC",
|
| 1296 |
+
"CGA",
|
| 1297 |
+
"CGC",
|
| 1298 |
+
"TAC",
|
| 1299 |
+
"TAG",
|
| 1300 |
+
"TCC",
|
| 1301 |
+
"AGA",
|
| 1302 |
+
"AGG",
|
| 1303 |
+
"TGA",
|
| 1304 |
+
]
|
| 1305 |
+
|
| 1306 |
+
# Define aa to integer table
|
| 1307 |
+
def aa2int(seq: str) -> List[int]:
|
| 1308 |
+
_aa2int = {
|
| 1309 |
+
"A": 1,
|
| 1310 |
+
"R": 2,
|
| 1311 |
+
"N": 3,
|
| 1312 |
+
"D": 4,
|
| 1313 |
+
"C": 5,
|
| 1314 |
+
"Q": 6,
|
| 1315 |
+
"E": 7,
|
| 1316 |
+
"G": 8,
|
| 1317 |
+
"H": 9,
|
| 1318 |
+
"I": 10,
|
| 1319 |
+
"L": 11,
|
| 1320 |
+
"K": 12,
|
| 1321 |
+
"M": 13,
|
| 1322 |
+
"F": 14,
|
| 1323 |
+
"P": 15,
|
| 1324 |
+
"S": 16,
|
| 1325 |
+
"T": 17,
|
| 1326 |
+
"W": 18,
|
| 1327 |
+
"Y": 19,
|
| 1328 |
+
"V": 20,
|
| 1329 |
+
"B": 21,
|
| 1330 |
+
"Z": 22,
|
| 1331 |
+
"X": 23,
|
| 1332 |
+
"*": 24,
|
| 1333 |
+
"-": 25,
|
| 1334 |
+
"?": 26,
|
| 1335 |
+
}
|
| 1336 |
+
return [_aa2int[i] for i in seq]
|
| 1337 |
+
|
| 1338 |
+
# Create empty array to fill
|
| 1339 |
+
oh_array = np.zeros(shape=(26, len(input_seq)))
|
| 1340 |
+
|
| 1341 |
+
# Load placements from aa2int
|
| 1342 |
+
aa_placement = aa2int(input_seq)
|
| 1343 |
+
|
| 1344 |
+
# One-hot encode the amino acid sequence:
|
| 1345 |
+
for i in range(0, len(aa_placement)):
|
| 1346 |
+
oh_array[aa_placement[i], i] = 1
|
| 1347 |
+
i += 1
|
| 1348 |
+
|
| 1349 |
+
oh_array = [oh_array]
|
| 1350 |
+
x = np.array(np.transpose(oh_array))
|
| 1351 |
+
|
| 1352 |
+
y = x.astype(np.float32)
|
| 1353 |
+
|
| 1354 |
+
y = np.reshape(y, (y.shape[0], 1, 26))
|
| 1355 |
+
|
| 1356 |
+
# Start ICOR session using model.
|
| 1357 |
+
sess = rt.InferenceSession(model_path)
|
| 1358 |
+
input_name = sess.get_inputs()[0].name
|
| 1359 |
+
|
| 1360 |
+
# Get prediction:
|
| 1361 |
+
pred_onx = sess.run(None, {input_name: y})
|
| 1362 |
+
|
| 1363 |
+
# Get the index of the highest probability from softmax output:
|
| 1364 |
+
pred_indices = []
|
| 1365 |
+
for pred in pred_onx[0]:
|
| 1366 |
+
pred_indices.append(np.argmax(pred))
|
| 1367 |
+
|
| 1368 |
+
out_str = ""
|
| 1369 |
+
for index in pred_indices:
|
| 1370 |
+
out_str += labels[index]
|
| 1371 |
+
|
| 1372 |
+
return out_str
|
CodonTransformer/CodonUtils.py
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: CodonUtils.py
|
| 3 |
+
---------------------
|
| 4 |
+
Includes constants and helper functions used by other Python scripts.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import itertools
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import pickle
|
| 11 |
+
import re
|
| 12 |
+
from abc import ABC, abstractmethod
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import requests
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
# List of all amino acids
|
| 21 |
+
AMINO_ACIDS: List[str] = [
|
| 22 |
+
"A", # Alanine
|
| 23 |
+
"C", # Cysteine
|
| 24 |
+
"D", # Aspartic acid
|
| 25 |
+
"E", # Glutamic acid
|
| 26 |
+
"F", # Phenylalanine
|
| 27 |
+
"G", # Glycine
|
| 28 |
+
"H", # Histidine
|
| 29 |
+
"I", # Isoleucine
|
| 30 |
+
"K", # Lysine
|
| 31 |
+
"L", # Leucine
|
| 32 |
+
"M", # Methionine
|
| 33 |
+
"N", # Asparagine
|
| 34 |
+
"P", # Proline
|
| 35 |
+
"Q", # Glutamine
|
| 36 |
+
"R", # Arginine
|
| 37 |
+
"S", # Serine
|
| 38 |
+
"T", # Threonine
|
| 39 |
+
"V", # Valine
|
| 40 |
+
"W", # Tryptophan
|
| 41 |
+
"Y", # Tyrosine
|
| 42 |
+
]
|
| 43 |
+
STOP_SYMBOLS = ["_", "*"] # Stop codon symbols
|
| 44 |
+
|
| 45 |
+
# Dictionary ambiguous amino acids to standard amino acids
|
| 46 |
+
AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = {
|
| 47 |
+
"B": ["N", "D"], # Asparagine (N) or Aspartic acid (D)
|
| 48 |
+
"Z": ["Q", "E"], # Glutamine (Q) or Glutamic acid (E)
|
| 49 |
+
"X": ["A"], # Any amino acid (typically replaced with Alanine)
|
| 50 |
+
"J": ["L", "I"], # Leucine (L) or Isoleucine (I)
|
| 51 |
+
"U": ["C"], # Selenocysteine (typically replaced with Cysteine)
|
| 52 |
+
"O": ["K"], # Pyrrolysine (typically replaced with Lysine)
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
# List of all possible start and stop codons
|
| 56 |
+
START_CODONS: List[str] = ["ATG", "TTG", "CTG", "GTG"]
|
| 57 |
+
STOP_CODONS: List[str] = ["TAA", "TAG", "TGA"]
|
| 58 |
+
|
| 59 |
+
# Token-to-index mapping for amino acids and special tokens
|
| 60 |
+
TOKEN2INDEX: Dict[str, int] = {
|
| 61 |
+
"[UNK]": 0,
|
| 62 |
+
"[CLS]": 1,
|
| 63 |
+
"[SEP]": 2,
|
| 64 |
+
"[PAD]": 3,
|
| 65 |
+
"[MASK]": 4,
|
| 66 |
+
"a_unk": 5,
|
| 67 |
+
"c_unk": 6,
|
| 68 |
+
"d_unk": 7,
|
| 69 |
+
"e_unk": 8,
|
| 70 |
+
"f_unk": 9,
|
| 71 |
+
"g_unk": 10,
|
| 72 |
+
"h_unk": 11,
|
| 73 |
+
"i_unk": 12,
|
| 74 |
+
"k_unk": 13,
|
| 75 |
+
"l_unk": 14,
|
| 76 |
+
"m_unk": 15,
|
| 77 |
+
"n_unk": 16,
|
| 78 |
+
"p_unk": 17,
|
| 79 |
+
"q_unk": 18,
|
| 80 |
+
"r_unk": 19,
|
| 81 |
+
"s_unk": 20,
|
| 82 |
+
"t_unk": 21,
|
| 83 |
+
"v_unk": 22,
|
| 84 |
+
"w_unk": 23,
|
| 85 |
+
"y_unk": 24,
|
| 86 |
+
"__unk": 25,
|
| 87 |
+
"k_aaa": 26,
|
| 88 |
+
"n_aac": 27,
|
| 89 |
+
"k_aag": 28,
|
| 90 |
+
"n_aat": 29,
|
| 91 |
+
"t_aca": 30,
|
| 92 |
+
"t_acc": 31,
|
| 93 |
+
"t_acg": 32,
|
| 94 |
+
"t_act": 33,
|
| 95 |
+
"r_aga": 34,
|
| 96 |
+
"s_agc": 35,
|
| 97 |
+
"r_agg": 36,
|
| 98 |
+
"s_agt": 37,
|
| 99 |
+
"i_ata": 38,
|
| 100 |
+
"i_atc": 39,
|
| 101 |
+
"m_atg": 40,
|
| 102 |
+
"i_att": 41,
|
| 103 |
+
"q_caa": 42,
|
| 104 |
+
"h_cac": 43,
|
| 105 |
+
"q_cag": 44,
|
| 106 |
+
"h_cat": 45,
|
| 107 |
+
"p_cca": 46,
|
| 108 |
+
"p_ccc": 47,
|
| 109 |
+
"p_ccg": 48,
|
| 110 |
+
"p_cct": 49,
|
| 111 |
+
"r_cga": 50,
|
| 112 |
+
"r_cgc": 51,
|
| 113 |
+
"r_cgg": 52,
|
| 114 |
+
"r_cgt": 53,
|
| 115 |
+
"l_cta": 54,
|
| 116 |
+
"l_ctc": 55,
|
| 117 |
+
"l_ctg": 56,
|
| 118 |
+
"l_ctt": 57,
|
| 119 |
+
"e_gaa": 58,
|
| 120 |
+
"d_gac": 59,
|
| 121 |
+
"e_gag": 60,
|
| 122 |
+
"d_gat": 61,
|
| 123 |
+
"a_gca": 62,
|
| 124 |
+
"a_gcc": 63,
|
| 125 |
+
"a_gcg": 64,
|
| 126 |
+
"a_gct": 65,
|
| 127 |
+
"g_gga": 66,
|
| 128 |
+
"g_ggc": 67,
|
| 129 |
+
"g_ggg": 68,
|
| 130 |
+
"g_ggt": 69,
|
| 131 |
+
"v_gta": 70,
|
| 132 |
+
"v_gtc": 71,
|
| 133 |
+
"v_gtg": 72,
|
| 134 |
+
"v_gtt": 73,
|
| 135 |
+
"__taa": 74,
|
| 136 |
+
"y_tac": 75,
|
| 137 |
+
"__tag": 76,
|
| 138 |
+
"y_tat": 77,
|
| 139 |
+
"s_tca": 78,
|
| 140 |
+
"s_tcc": 79,
|
| 141 |
+
"s_tcg": 80,
|
| 142 |
+
"s_tct": 81,
|
| 143 |
+
"__tga": 82,
|
| 144 |
+
"c_tgc": 83,
|
| 145 |
+
"w_tgg": 84,
|
| 146 |
+
"c_tgt": 85,
|
| 147 |
+
"l_tta": 86,
|
| 148 |
+
"f_ttc": 87,
|
| 149 |
+
"l_ttg": 88,
|
| 150 |
+
"f_ttt": 89,
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
# Index-to-token mapping, reverse of TOKEN2INDEX
|
| 154 |
+
INDEX2TOKEN: Dict[int, str] = {i: c for c, i in TOKEN2INDEX.items()}
|
| 155 |
+
|
| 156 |
+
# Dictionary mapping each codon to its GC content
|
| 157 |
+
CODON_GC_CONTENT: Dict[str, int] = {
|
| 158 |
+
token.split("_")[1]: token.split("_")[1].upper().count("G") + token.split("_")[1].upper().count("C")
|
| 159 |
+
for token in TOKEN2INDEX
|
| 160 |
+
if "_" in token and len(token.split("_")[1]) == 3
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
# Tensor with GC counts for each token in the vocabulary
|
| 164 |
+
GC_COUNTS_PER_TOKEN = torch.zeros(len(TOKEN2INDEX))
|
| 165 |
+
for token, index in TOKEN2INDEX.items():
|
| 166 |
+
if "_" in token and len(token.split("_")[1]) == 3:
|
| 167 |
+
codon = token.split("_")[1].upper()
|
| 168 |
+
gc_count = codon.count("G") + codon.count("C")
|
| 169 |
+
GC_COUNTS_PER_TOKEN[index] = gc_count
|
| 170 |
+
|
| 171 |
+
G_indices = [idx for token, idx in TOKEN2INDEX.items() if "g" in token.split("_")[-1]]
|
| 172 |
+
C_indices = [idx for token, idx in TOKEN2INDEX.items() if "c" in token.split("_")[-1]]
|
| 173 |
+
|
| 174 |
+
# Dictionary mapping each amino acid and stop symbol to indices of codon tokens that translate to it
|
| 175 |
+
AMINO_ACID_TO_INDEX = {
|
| 176 |
+
aa: sorted(
|
| 177 |
+
[i for t, i in TOKEN2INDEX.items() if t[0].upper() == aa and t[-3:] != "unk"]
|
| 178 |
+
)
|
| 179 |
+
for aa in (AMINO_ACIDS + STOP_SYMBOLS)
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# Dictionary mapping each amino acid to min/max GC content across all possible codons
|
| 184 |
+
AA_MIN_GC: Dict[str, int] = {}
|
| 185 |
+
AA_MAX_GC: Dict[str, int] = {}
|
| 186 |
+
|
| 187 |
+
for aa, token_indices in AMINO_ACID_TO_INDEX.items():
|
| 188 |
+
if token_indices: # Skip if no tokens for this amino acid
|
| 189 |
+
gc_counts = []
|
| 190 |
+
for token_idx in token_indices:
|
| 191 |
+
token = INDEX2TOKEN[token_idx]
|
| 192 |
+
if "_" in token and len(token.split("_")[1]) == 3:
|
| 193 |
+
codon = token.split("_")[1]
|
| 194 |
+
if codon in CODON_GC_CONTENT:
|
| 195 |
+
gc_counts.append(CODON_GC_CONTENT[codon])
|
| 196 |
+
|
| 197 |
+
if gc_counts:
|
| 198 |
+
AA_MIN_GC[aa] = min(gc_counts)
|
| 199 |
+
AA_MAX_GC[aa] = max(gc_counts)
|
| 200 |
+
|
| 201 |
+
# Mask token mapping
|
| 202 |
+
TOKEN2MASK: Dict[int, int] = {
|
| 203 |
+
0: 0,
|
| 204 |
+
1: 1,
|
| 205 |
+
2: 2,
|
| 206 |
+
3: 3,
|
| 207 |
+
4: 4,
|
| 208 |
+
5: 5,
|
| 209 |
+
6: 6,
|
| 210 |
+
7: 7,
|
| 211 |
+
8: 8,
|
| 212 |
+
9: 9,
|
| 213 |
+
10: 10,
|
| 214 |
+
11: 11,
|
| 215 |
+
12: 12,
|
| 216 |
+
13: 13,
|
| 217 |
+
14: 14,
|
| 218 |
+
15: 15,
|
| 219 |
+
16: 16,
|
| 220 |
+
17: 17,
|
| 221 |
+
18: 18,
|
| 222 |
+
19: 19,
|
| 223 |
+
20: 20,
|
| 224 |
+
21: 21,
|
| 225 |
+
22: 22,
|
| 226 |
+
23: 23,
|
| 227 |
+
24: 24,
|
| 228 |
+
25: 25,
|
| 229 |
+
26: 13,
|
| 230 |
+
27: 16,
|
| 231 |
+
28: 13,
|
| 232 |
+
29: 16,
|
| 233 |
+
30: 21,
|
| 234 |
+
31: 21,
|
| 235 |
+
32: 21,
|
| 236 |
+
33: 21,
|
| 237 |
+
34: 19,
|
| 238 |
+
35: 20,
|
| 239 |
+
36: 19,
|
| 240 |
+
37: 20,
|
| 241 |
+
38: 12,
|
| 242 |
+
39: 12,
|
| 243 |
+
40: 15,
|
| 244 |
+
41: 12,
|
| 245 |
+
42: 18,
|
| 246 |
+
43: 11,
|
| 247 |
+
44: 18,
|
| 248 |
+
45: 11,
|
| 249 |
+
46: 17,
|
| 250 |
+
47: 17,
|
| 251 |
+
48: 17,
|
| 252 |
+
49: 17,
|
| 253 |
+
50: 19,
|
| 254 |
+
51: 19,
|
| 255 |
+
52: 19,
|
| 256 |
+
53: 19,
|
| 257 |
+
54: 14,
|
| 258 |
+
55: 14,
|
| 259 |
+
56: 14,
|
| 260 |
+
57: 14,
|
| 261 |
+
58: 8,
|
| 262 |
+
59: 7,
|
| 263 |
+
60: 8,
|
| 264 |
+
61: 7,
|
| 265 |
+
62: 5,
|
| 266 |
+
63: 5,
|
| 267 |
+
64: 5,
|
| 268 |
+
65: 5,
|
| 269 |
+
66: 10,
|
| 270 |
+
67: 10,
|
| 271 |
+
68: 10,
|
| 272 |
+
69: 10,
|
| 273 |
+
70: 22,
|
| 274 |
+
71: 22,
|
| 275 |
+
72: 22,
|
| 276 |
+
73: 22,
|
| 277 |
+
74: 25,
|
| 278 |
+
75: 24,
|
| 279 |
+
76: 25,
|
| 280 |
+
77: 24,
|
| 281 |
+
78: 20,
|
| 282 |
+
79: 20,
|
| 283 |
+
80: 20,
|
| 284 |
+
81: 20,
|
| 285 |
+
82: 25,
|
| 286 |
+
83: 6,
|
| 287 |
+
84: 23,
|
| 288 |
+
85: 6,
|
| 289 |
+
86: 14,
|
| 290 |
+
87: 9,
|
| 291 |
+
88: 14,
|
| 292 |
+
89: 9,
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
# List of organisms used for fine-tuning
|
| 296 |
+
FINE_TUNE_ORGANISMS: List[str] = [
|
| 297 |
+
"Arabidopsis thaliana",
|
| 298 |
+
"Bacillus subtilis",
|
| 299 |
+
"Caenorhabditis elegans",
|
| 300 |
+
"Chlamydomonas reinhardtii",
|
| 301 |
+
"Chlamydomonas reinhardtii chloroplast",
|
| 302 |
+
"Danio rerio",
|
| 303 |
+
"Drosophila melanogaster",
|
| 304 |
+
"Homo sapiens",
|
| 305 |
+
"Mus musculus",
|
| 306 |
+
"Nicotiana tabacum",
|
| 307 |
+
"Nicotiana tabacum chloroplast",
|
| 308 |
+
"Pseudomonas putida",
|
| 309 |
+
"Saccharomyces cerevisiae",
|
| 310 |
+
"Escherichia coli O157-H7 str. Sakai",
|
| 311 |
+
"Escherichia coli general",
|
| 312 |
+
"Escherichia coli str. K-12 substr. MG1655",
|
| 313 |
+
"Thermococcus barophilus MPT",
|
| 314 |
+
]
|
| 315 |
+
|
| 316 |
+
# List of organisms most commonly used for coodn optimization
|
| 317 |
+
COMMON_ORGANISMS: List[str] = [
|
| 318 |
+
"Arabidopsis thaliana",
|
| 319 |
+
"Bacillus subtilis",
|
| 320 |
+
"Caenorhabditis elegans",
|
| 321 |
+
"Chlamydomonas reinhardtii",
|
| 322 |
+
"Danio rerio",
|
| 323 |
+
"Drosophila melanogaster",
|
| 324 |
+
"Homo sapiens",
|
| 325 |
+
"Mus musculus",
|
| 326 |
+
"Nicotiana tabacum",
|
| 327 |
+
"Pseudomonas putida",
|
| 328 |
+
"Saccharomyces cerevisiae",
|
| 329 |
+
"Escherichia coli general",
|
| 330 |
+
]
|
| 331 |
+
|
| 332 |
+
# Dictionary mapping each organism name to respective organism id
|
| 333 |
+
ORGANISM2ID: Dict[str, int] = {
|
| 334 |
+
"Arabidopsis thaliana": 0,
|
| 335 |
+
"Atlantibacter hermannii": 1,
|
| 336 |
+
"Bacillus subtilis": 2,
|
| 337 |
+
"Brenneria goodwinii": 3,
|
| 338 |
+
"Buchnera aphidicola (Schizaphis graminum)": 4,
|
| 339 |
+
"Caenorhabditis elegans": 5,
|
| 340 |
+
"Candidatus Erwinia haradaeae": 6,
|
| 341 |
+
"Candidatus Hamiltonella defensa 5AT (Acyrthosiphon pisum)": 7,
|
| 342 |
+
"Chlamydomonas reinhardtii": 8,
|
| 343 |
+
"Chlamydomonas reinhardtii chloroplast": 9,
|
| 344 |
+
"Citrobacter amalonaticus": 10,
|
| 345 |
+
"Citrobacter braakii": 11,
|
| 346 |
+
"Citrobacter cronae": 12,
|
| 347 |
+
"Citrobacter europaeus": 13,
|
| 348 |
+
"Citrobacter farmeri": 14,
|
| 349 |
+
"Citrobacter freundii": 15,
|
| 350 |
+
"Citrobacter koseri ATCC BAA-895": 16,
|
| 351 |
+
"Citrobacter portucalensis": 17,
|
| 352 |
+
"Citrobacter werkmanii": 18,
|
| 353 |
+
"Citrobacter youngae": 19,
|
| 354 |
+
"Cronobacter dublinensis subsp. dublinensis LMG 23823": 20,
|
| 355 |
+
"Cronobacter malonaticus LMG 23826": 21,
|
| 356 |
+
"Cronobacter sakazakii": 22,
|
| 357 |
+
"Cronobacter turicensis": 23,
|
| 358 |
+
"Danio rerio": 24,
|
| 359 |
+
"Dickeya dadantii 3937": 25,
|
| 360 |
+
"Dickeya dianthicola": 26,
|
| 361 |
+
"Dickeya fangzhongdai": 27,
|
| 362 |
+
"Dickeya solani": 28,
|
| 363 |
+
"Dickeya zeae": 29,
|
| 364 |
+
"Drosophila melanogaster": 30,
|
| 365 |
+
"Edwardsiella anguillarum ET080813": 31,
|
| 366 |
+
"Edwardsiella ictaluri": 32,
|
| 367 |
+
"Edwardsiella piscicida": 33,
|
| 368 |
+
"Edwardsiella tarda": 34,
|
| 369 |
+
"Enterobacter asburiae": 35,
|
| 370 |
+
"Enterobacter bugandensis": 36,
|
| 371 |
+
"Enterobacter cancerogenus": 37,
|
| 372 |
+
"Enterobacter chengduensis": 38,
|
| 373 |
+
"Enterobacter cloacae": 39,
|
| 374 |
+
"Enterobacter hormaechei": 40,
|
| 375 |
+
"Enterobacter kobei": 41,
|
| 376 |
+
"Enterobacter ludwigii": 42,
|
| 377 |
+
"Enterobacter mori": 43,
|
| 378 |
+
"Enterobacter quasiroggenkampii": 44,
|
| 379 |
+
"Enterobacter roggenkampii": 45,
|
| 380 |
+
"Enterobacter sichuanensis": 46,
|
| 381 |
+
"Erwinia amylovora CFBP1430": 47,
|
| 382 |
+
"Erwinia persicina": 48,
|
| 383 |
+
"Escherichia albertii": 49,
|
| 384 |
+
"Escherichia coli O157-H7 str. Sakai": 50,
|
| 385 |
+
"Escherichia coli general": 51,
|
| 386 |
+
"Escherichia coli str. K-12 substr. MG1655": 52,
|
| 387 |
+
"Escherichia fergusonii": 53,
|
| 388 |
+
"Escherichia marmotae": 54,
|
| 389 |
+
"Escherichia ruysiae": 55,
|
| 390 |
+
"Ewingella americana": 56,
|
| 391 |
+
"Hafnia alvei": 57,
|
| 392 |
+
"Hafnia paralvei": 58,
|
| 393 |
+
"Homo sapiens": 59,
|
| 394 |
+
"Kalamiella piersonii": 60,
|
| 395 |
+
"Klebsiella aerogenes": 61,
|
| 396 |
+
"Klebsiella grimontii": 62,
|
| 397 |
+
"Klebsiella michiganensis": 63,
|
| 398 |
+
"Klebsiella oxytoca": 64,
|
| 399 |
+
"Klebsiella pasteurii": 65,
|
| 400 |
+
"Klebsiella pneumoniae subsp. pneumoniae HS11286": 66,
|
| 401 |
+
"Klebsiella quasipneumoniae": 67,
|
| 402 |
+
"Klebsiella quasivariicola": 68,
|
| 403 |
+
"Klebsiella variicola": 69,
|
| 404 |
+
"Kosakonia cowanii": 70,
|
| 405 |
+
"Kosakonia radicincitans": 71,
|
| 406 |
+
"Leclercia adecarboxylata": 72,
|
| 407 |
+
"Lelliottia amnigena": 73,
|
| 408 |
+
"Lonsdalea populi": 74,
|
| 409 |
+
"Moellerella wisconsensis": 75,
|
| 410 |
+
"Morganella morganii": 76,
|
| 411 |
+
"Mus musculus": 77,
|
| 412 |
+
"Nicotiana tabacum": 78,
|
| 413 |
+
"Nicotiana tabacum chloroplast": 79,
|
| 414 |
+
"Obesumbacterium proteus": 80,
|
| 415 |
+
"Pantoea agglomerans": 81,
|
| 416 |
+
"Pantoea allii": 82,
|
| 417 |
+
"Pantoea ananatis PA13": 83,
|
| 418 |
+
"Pantoea dispersa": 84,
|
| 419 |
+
"Pantoea stewartii": 85,
|
| 420 |
+
"Pantoea vagans": 86,
|
| 421 |
+
"Pectobacterium aroidearum": 87,
|
| 422 |
+
"Pectobacterium atrosepticum": 88,
|
| 423 |
+
"Pectobacterium brasiliense": 89,
|
| 424 |
+
"Pectobacterium carotovorum": 90,
|
| 425 |
+
"Pectobacterium odoriferum": 91,
|
| 426 |
+
"Pectobacterium parmentieri": 92,
|
| 427 |
+
"Pectobacterium polaris": 93,
|
| 428 |
+
"Pectobacterium versatile": 94,
|
| 429 |
+
"Photorhabdus laumondii subsp. laumondii TTO1": 95,
|
| 430 |
+
"Plesiomonas shigelloides": 96,
|
| 431 |
+
"Pluralibacter gergoviae": 97,
|
| 432 |
+
"Proteus faecis": 98,
|
| 433 |
+
"Proteus mirabilis HI4320": 99,
|
| 434 |
+
"Proteus penneri": 100,
|
| 435 |
+
"Proteus terrae subsp. cibarius": 101,
|
| 436 |
+
"Proteus vulgaris": 102,
|
| 437 |
+
"Providencia alcalifaciens": 103,
|
| 438 |
+
"Providencia heimbachae": 104,
|
| 439 |
+
"Providencia rettgeri": 105,
|
| 440 |
+
"Providencia rustigianii": 106,
|
| 441 |
+
"Providencia stuartii": 107,
|
| 442 |
+
"Providencia thailandensis": 108,
|
| 443 |
+
"Pseudomonas putida": 109,
|
| 444 |
+
"Pyrococcus furiosus": 110,
|
| 445 |
+
"Pyrococcus horikoshii": 111,
|
| 446 |
+
"Pyrococcus yayanosii": 112,
|
| 447 |
+
"Rahnella aquatilis CIP 78.65 = ATCC 33071": 113,
|
| 448 |
+
"Raoultella ornithinolytica": 114,
|
| 449 |
+
"Raoultella planticola": 115,
|
| 450 |
+
"Raoultella terrigena": 116,
|
| 451 |
+
"Rosenbergiella epipactidis": 117,
|
| 452 |
+
"Rouxiella badensis": 118,
|
| 453 |
+
"Saccharolobus solfataricus": 119,
|
| 454 |
+
"Saccharomyces cerevisiae": 120,
|
| 455 |
+
"Salmonella bongori N268-08": 121,
|
| 456 |
+
"Salmonella enterica subsp. enterica serovar Typhimurium str. LT2": 122,
|
| 457 |
+
"Serratia bockelmannii": 123,
|
| 458 |
+
"Serratia entomophila": 124,
|
| 459 |
+
"Serratia ficaria": 125,
|
| 460 |
+
"Serratia fonticola": 126,
|
| 461 |
+
"Serratia grimesii": 127,
|
| 462 |
+
"Serratia liquefaciens": 128,
|
| 463 |
+
"Serratia marcescens": 129,
|
| 464 |
+
"Serratia nevei": 130,
|
| 465 |
+
"Serratia plymuthica AS9": 131,
|
| 466 |
+
"Serratia proteamaculans": 132,
|
| 467 |
+
"Serratia quinivorans": 133,
|
| 468 |
+
"Serratia rubidaea": 134,
|
| 469 |
+
"Serratia ureilytica": 135,
|
| 470 |
+
"Shigella boydii": 136,
|
| 471 |
+
"Shigella dysenteriae": 137,
|
| 472 |
+
"Shigella flexneri 2a str. 301": 138,
|
| 473 |
+
"Shigella sonnei": 139,
|
| 474 |
+
"Thermoccoccus kodakarensis": 140,
|
| 475 |
+
"Thermococcus barophilus MPT": 141,
|
| 476 |
+
"Thermococcus chitonophagus": 142,
|
| 477 |
+
"Thermococcus gammatolerans": 143,
|
| 478 |
+
"Thermococcus litoralis": 144,
|
| 479 |
+
"Thermococcus onnurineus": 145,
|
| 480 |
+
"Thermococcus sibiricus": 146,
|
| 481 |
+
"Xenorhabdus bovienii str. feltiae Florida": 147,
|
| 482 |
+
"Yersinia aldovae 670-83": 148,
|
| 483 |
+
"Yersinia aleksiciae": 149,
|
| 484 |
+
"Yersinia alsatica": 150,
|
| 485 |
+
"Yersinia enterocolitica": 151,
|
| 486 |
+
"Yersinia frederiksenii ATCC 33641": 152,
|
| 487 |
+
"Yersinia intermedia": 153,
|
| 488 |
+
"Yersinia kristensenii": 154,
|
| 489 |
+
"Yersinia massiliensis CCUG 53443": 155,
|
| 490 |
+
"Yersinia mollaretii ATCC 43969": 156,
|
| 491 |
+
"Yersinia pestis A1122": 157,
|
| 492 |
+
"Yersinia proxima": 158,
|
| 493 |
+
"Yersinia pseudotuberculosis IP 32953": 159,
|
| 494 |
+
"Yersinia rochesterensis": 160,
|
| 495 |
+
"Yersinia rohdei": 161,
|
| 496 |
+
"Yersinia ruckeri": 162,
|
| 497 |
+
"Yokenella regensburgei": 163,
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
# Dictionary mapping each organism id to respective organism name
|
| 501 |
+
ID2ORGANISM = {v: k for k, v in ORGANISM2ID.items()}
|
| 502 |
+
|
| 503 |
+
# Type alias for amino acid to codon mapping
|
| 504 |
+
AMINO2CODON_TYPE = Dict[str, Tuple[List[str], List[float]]]
|
| 505 |
+
|
| 506 |
+
# Constants for the number of organisms and sequence lengths
|
| 507 |
+
NUM_ORGANISMS = 164
|
| 508 |
+
MAX_LEN = 2048
|
| 509 |
+
MAX_AMINO_ACIDS = MAX_LEN - 2 # Without special tokens [CLS] and [SEP]
|
| 510 |
+
STOP_SYMBOL = "_"
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
@dataclass
|
| 514 |
+
class DNASequencePrediction:
|
| 515 |
+
"""
|
| 516 |
+
A class to hold the output of the DNA sequence prediction.
|
| 517 |
+
|
| 518 |
+
Attributes:
|
| 519 |
+
organism (str): Name of the organism used for prediction.
|
| 520 |
+
protein (str): Input protein sequence for which DNA sequence is predicted.
|
| 521 |
+
processed_input (str): Processed input sequence (merged protein and DNA).
|
| 522 |
+
predicted_dna (str): Predicted DNA sequence.
|
| 523 |
+
"""
|
| 524 |
+
|
| 525 |
+
organism: str
|
| 526 |
+
protein: str
|
| 527 |
+
processed_input: str
|
| 528 |
+
predicted_dna: str
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class IterableData(torch.utils.data.IterableDataset):
|
| 532 |
+
"""
|
| 533 |
+
Defines the logic for iterable datasets (working over streams of
|
| 534 |
+
data) in parallel multi-processing environments, e.g., multi-GPU.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
dist_env (Optional[str]): The distribution environment identifier
|
| 538 |
+
(e.g., "slurm").
|
| 539 |
+
|
| 540 |
+
Credit: Guillaume Filion
|
| 541 |
+
"""
|
| 542 |
+
|
| 543 |
+
def __init__(self, dist_env: Optional[str] = None):
|
| 544 |
+
super().__init__()
|
| 545 |
+
if dist_env is None:
|
| 546 |
+
self.world_size_handle, self.rank_handle = ("WORLD_SIZE", "LOCAL_RANK")
|
| 547 |
+
else:
|
| 548 |
+
self.world_size_handle, self.rank_handle = {
|
| 549 |
+
"slurm": ("SLURM_NTASKS", "SLURM_PROCID")
|
| 550 |
+
}.get(dist_env, ("WORLD_SIZE", "LOCAL_RANK"))
|
| 551 |
+
|
| 552 |
+
@property
|
| 553 |
+
def iterator(self) -> Iterator:
|
| 554 |
+
"""Define the stream logic for the dataset. Implement in subclasses."""
|
| 555 |
+
raise NotImplementedError
|
| 556 |
+
|
| 557 |
+
def __iter__(self) -> Iterator:
|
| 558 |
+
"""
|
| 559 |
+
Create an iterator for the dataset, handling multi-processing contexts.
|
| 560 |
+
|
| 561 |
+
Returns:
|
| 562 |
+
Iterator: The iterator for the dataset.
|
| 563 |
+
"""
|
| 564 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 565 |
+
if worker_info is None:
|
| 566 |
+
return self.iterator
|
| 567 |
+
|
| 568 |
+
# In multi-processing context, use 'os.environ' to
|
| 569 |
+
# find global worker rank. Then use 'islice' to allocate
|
| 570 |
+
# the items of the stream to the workers.
|
| 571 |
+
world_size = int(os.environ.get(self.world_size_handle, "1"))
|
| 572 |
+
global_rank = int(os.environ.get(self.rank_handle, "0"))
|
| 573 |
+
local_rank = worker_info.id
|
| 574 |
+
local_num_workers = worker_info.num_workers
|
| 575 |
+
|
| 576 |
+
# Assume that each process has the same number of local workers.
|
| 577 |
+
worker_rk = global_rank * local_num_workers + local_rank
|
| 578 |
+
worker_nb = world_size * local_num_workers
|
| 579 |
+
return itertools.islice(self.iterator, worker_rk, None, worker_nb)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
class IterableJSONData(IterableData):
|
| 583 |
+
"""
|
| 584 |
+
Iterate over the lines of a JSON file and uncompress if needed.
|
| 585 |
+
|
| 586 |
+
Args:
|
| 587 |
+
data_path (str): The path to the JSON data file.
|
| 588 |
+
train (bool): Flag indicating if the dataset is for training.
|
| 589 |
+
**kwargs: Additional keyword arguments for the base class.
|
| 590 |
+
"""
|
| 591 |
+
|
| 592 |
+
def __init__(self, data_path: str, train: bool = True, **kwargs):
|
| 593 |
+
super().__init__(**kwargs)
|
| 594 |
+
self.data_path = data_path
|
| 595 |
+
self.train = train
|
| 596 |
+
with open(os.path.join(self.data_path, "finetune_set.json"), "r") as f:
|
| 597 |
+
self.records = [json.loads(line) for line in f]
|
| 598 |
+
|
| 599 |
+
def __len__(self):
|
| 600 |
+
return len(self.records)
|
| 601 |
+
|
| 602 |
+
@property
|
| 603 |
+
def iterator(self) -> Iterator:
|
| 604 |
+
"""Define the stream logic for the dataset."""
|
| 605 |
+
for record in self.records:
|
| 606 |
+
yield record
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
class ConfigManager(ABC):
|
| 610 |
+
"""
|
| 611 |
+
Abstract base class for managing configuration settings.
|
| 612 |
+
"""
|
| 613 |
+
_config: Dict[str, Any]
|
| 614 |
+
|
| 615 |
+
def __enter__(self):
|
| 616 |
+
return self
|
| 617 |
+
|
| 618 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 619 |
+
if exc_type is not None:
|
| 620 |
+
print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}")
|
| 621 |
+
self.reset_config()
|
| 622 |
+
|
| 623 |
+
@abstractmethod
|
| 624 |
+
def reset_config(self) -> None:
|
| 625 |
+
"""Reset the configuration to default values."""
|
| 626 |
+
pass
|
| 627 |
+
|
| 628 |
+
def get(self, key: str) -> Any:
|
| 629 |
+
"""
|
| 630 |
+
Get the value of a configuration key.
|
| 631 |
+
|
| 632 |
+
Args:
|
| 633 |
+
key (str): The key to retrieve the value for.
|
| 634 |
+
|
| 635 |
+
Returns:
|
| 636 |
+
Any: The value of the configuration key.
|
| 637 |
+
"""
|
| 638 |
+
return self._config.get(key)
|
| 639 |
+
|
| 640 |
+
def set(self, key: str, value: Any) -> None:
|
| 641 |
+
"""
|
| 642 |
+
Set the value of a configuration key.
|
| 643 |
+
|
| 644 |
+
Args:
|
| 645 |
+
key (str): The key to set the value for.
|
| 646 |
+
value (Any): The value to set for the key.
|
| 647 |
+
"""
|
| 648 |
+
self.validate_inputs(key, value)
|
| 649 |
+
self._config[key] = value
|
| 650 |
+
|
| 651 |
+
def update(self, config_dict: dict) -> None:
|
| 652 |
+
"""
|
| 653 |
+
Update the configuration with a dictionary of key-value pairs after validating them.
|
| 654 |
+
|
| 655 |
+
Args:
|
| 656 |
+
config_dict (dict): A dictionary of key-value pairs to update the configuration.
|
| 657 |
+
"""
|
| 658 |
+
for key, value in config_dict.items():
|
| 659 |
+
self.validate_inputs(key, value)
|
| 660 |
+
self._config.update(config_dict)
|
| 661 |
+
|
| 662 |
+
@abstractmethod
|
| 663 |
+
def validate_inputs(self, key: str, value: Any) -> None:
|
| 664 |
+
"""Validate the inputs for the configuration."""
|
| 665 |
+
pass
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
class ProteinConfig(ConfigManager):
|
| 669 |
+
"""
|
| 670 |
+
A class to manage configuration settings for protein sequences.
|
| 671 |
+
|
| 672 |
+
This class ensures that the configuration is a singleton.
|
| 673 |
+
It provides methods to get, set, and update configuration values.
|
| 674 |
+
|
| 675 |
+
Attributes:
|
| 676 |
+
_instance (Optional[ConfigManager]): The singleton instance of the ConfigManager.
|
| 677 |
+
_config (Dict[str, Any]): The configuration dictionary.
|
| 678 |
+
"""
|
| 679 |
+
|
| 680 |
+
_instance = None
|
| 681 |
+
|
| 682 |
+
def __new__(cls):
|
| 683 |
+
"""
|
| 684 |
+
Create a new instance of the ProteinConfig class.
|
| 685 |
+
|
| 686 |
+
Returns:
|
| 687 |
+
ProteinConfig: The singleton instance of the ProteinConfig.
|
| 688 |
+
"""
|
| 689 |
+
if cls._instance is None:
|
| 690 |
+
cls._instance = super(ProteinConfig, cls).__new__(cls)
|
| 691 |
+
cls._instance.reset_config()
|
| 692 |
+
return cls._instance
|
| 693 |
+
|
| 694 |
+
def validate_inputs(self, key: str, value: Any) -> None:
|
| 695 |
+
"""
|
| 696 |
+
Validate the inputs for the configuration.
|
| 697 |
+
|
| 698 |
+
Args:
|
| 699 |
+
key (str): The key to validate.
|
| 700 |
+
value (Any): The value to validate.
|
| 701 |
+
|
| 702 |
+
Raises:
|
| 703 |
+
ValueError: If the value is invalid.
|
| 704 |
+
TypeError: If the value is of the wrong type.
|
| 705 |
+
"""
|
| 706 |
+
if key == "ambiguous_aminoacid_behavior":
|
| 707 |
+
if value not in [
|
| 708 |
+
"raise_error",
|
| 709 |
+
"standardize_deterministic",
|
| 710 |
+
"standardize_random",
|
| 711 |
+
]:
|
| 712 |
+
raise ValueError(
|
| 713 |
+
f"Invalid value for ambiguous_aminoacid_behavior: {value}."
|
| 714 |
+
)
|
| 715 |
+
elif key == "ambiguous_aminoacid_map_override":
|
| 716 |
+
if not isinstance(value, dict):
|
| 717 |
+
raise TypeError(
|
| 718 |
+
f"Invalid type for ambiguous_aminoacid_map_override: {value}."
|
| 719 |
+
)
|
| 720 |
+
for ambiguous_aminoacid, aminoacids in value.items():
|
| 721 |
+
if not isinstance(aminoacids, list):
|
| 722 |
+
raise TypeError(f"Invalid type for aminoacids: {aminoacids}.")
|
| 723 |
+
if not aminoacids:
|
| 724 |
+
raise ValueError(
|
| 725 |
+
f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list."
|
| 726 |
+
)
|
| 727 |
+
if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP:
|
| 728 |
+
raise ValueError(
|
| 729 |
+
f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}"
|
| 730 |
+
)
|
| 731 |
+
else:
|
| 732 |
+
raise ValueError(f"Invalid configuration key: {key}")
|
| 733 |
+
|
| 734 |
+
def reset_config(self) -> None:
|
| 735 |
+
"""
|
| 736 |
+
Reset the configuration to the default values.
|
| 737 |
+
"""
|
| 738 |
+
self._config = {
|
| 739 |
+
"ambiguous_aminoacid_behavior": "standardize_random",
|
| 740 |
+
"ambiguous_aminoacid_map_override": {},
|
| 741 |
+
}
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def load_python_object_from_disk(file_path: str) -> Any:
|
| 745 |
+
"""
|
| 746 |
+
Load a Pickle object from disk and return it as a Python object.
|
| 747 |
+
|
| 748 |
+
Args:
|
| 749 |
+
file_path (str): The path to the Pickle file.
|
| 750 |
+
|
| 751 |
+
Returns:
|
| 752 |
+
Any: The loaded Python object.
|
| 753 |
+
"""
|
| 754 |
+
with open(file_path, "rb") as file:
|
| 755 |
+
return pickle.load(file)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def save_python_object_to_disk(input_object: Any, file_path: str) -> None:
|
| 759 |
+
"""
|
| 760 |
+
Save a Python object to disk using Pickle.
|
| 761 |
+
|
| 762 |
+
Args:
|
| 763 |
+
input_object (Any): The Python object to save.
|
| 764 |
+
file_path (str): The path where the object will be saved.
|
| 765 |
+
"""
|
| 766 |
+
with open(file_path, "wb") as file:
|
| 767 |
+
pickle.dump(input_object, file)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def find_pattern_in_fasta(keyword: str, text: str) -> str:
|
| 771 |
+
"""
|
| 772 |
+
Find a specific keyword pattern in text. Helpful for identifying parts
|
| 773 |
+
of a FASTA sequence.
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
keyword (str): The keyword pattern to search for.
|
| 777 |
+
text (str): The text to search within.
|
| 778 |
+
|
| 779 |
+
Returns:
|
| 780 |
+
str: The found pattern or an empty string if not found.
|
| 781 |
+
"""
|
| 782 |
+
# Search for the keyword pattern in the text using regex
|
| 783 |
+
result = re.search(keyword + r"=(.*?)]", text)
|
| 784 |
+
return result.group(1) if result else ""
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def get_organism2id_dict(organism_reference: str) -> Dict[str, int]:
|
| 788 |
+
"""
|
| 789 |
+
Return a dictionary mapping each organism in training data to an index
|
| 790 |
+
used for training.
|
| 791 |
+
|
| 792 |
+
Args:
|
| 793 |
+
organism_reference (str): Path to a CSV file containing a list of
|
| 794 |
+
all organisms. The format of the CSV file should be as follows:
|
| 795 |
+
|
| 796 |
+
0,Escherichia coli
|
| 797 |
+
1,Homo sapiens
|
| 798 |
+
2,Mus musculus
|
| 799 |
+
|
| 800 |
+
Returns:
|
| 801 |
+
Dict[str, int]: Dictionary mapping organism names to their respective indices.
|
| 802 |
+
"""
|
| 803 |
+
# Read the CSV file and create a dictionary mapping organisms to their indices
|
| 804 |
+
organisms = pd.read_csv(organism_reference, index_col=0, header=None)
|
| 805 |
+
organism2id = {organisms.iloc[i].values[0]: i for i in organisms.index}
|
| 806 |
+
|
| 807 |
+
return organism2id
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def get_taxonomy_id(
|
| 811 |
+
taxonomy_reference: str, organism: Optional[str] = None, return_dict: bool = False
|
| 812 |
+
) -> Any:
|
| 813 |
+
"""
|
| 814 |
+
Return the taxonomy id of a given organism using a reference file.
|
| 815 |
+
Optionally, return the whole dictionary instead if return_dict is True.
|
| 816 |
+
|
| 817 |
+
Args:
|
| 818 |
+
taxonomy_reference (str): Path to the taxonomy reference file.
|
| 819 |
+
organism (Optional[str]): The name of the organism to look up.
|
| 820 |
+
return_dict (bool): Whether to return the entire dictionary.
|
| 821 |
+
|
| 822 |
+
Returns:
|
| 823 |
+
Any: The taxonomy id of the organism or the entire dictionary.
|
| 824 |
+
"""
|
| 825 |
+
# Load the organism-to-taxonomy mapping from a Pickle file
|
| 826 |
+
organism2taxonomy = load_python_object_from_disk(taxonomy_reference)
|
| 827 |
+
|
| 828 |
+
if return_dict:
|
| 829 |
+
return dict(sorted(organism2taxonomy.items()))
|
| 830 |
+
|
| 831 |
+
return organism2taxonomy[organism]
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def sort_amino2codon_skeleton(amino2codon: Dict[str, Any]) -> Dict[str, Any]:
|
| 835 |
+
"""
|
| 836 |
+
Sort the amino2codon dictionary alphabetically by amino acid and by codon name.
|
| 837 |
+
|
| 838 |
+
Args:
|
| 839 |
+
amino2codon (Dict[str, Any]): The amino2codon dictionary to sort.
|
| 840 |
+
|
| 841 |
+
Returns:
|
| 842 |
+
Dict[str, Any]: The sorted amino2codon dictionary.
|
| 843 |
+
"""
|
| 844 |
+
# Sort the dictionary by amino acid and then by codon name
|
| 845 |
+
amino2codon = dict(sorted(amino2codon.items()))
|
| 846 |
+
amino2codon = {
|
| 847 |
+
amino: (
|
| 848 |
+
[codon for codon, _ in sorted(zip(codons, frequencies))],
|
| 849 |
+
[freq for _, freq in sorted(zip(codons, frequencies))],
|
| 850 |
+
)
|
| 851 |
+
for amino, (codons, frequencies) in amino2codon.items()
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
return amino2codon
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
def load_pkl_from_url(url: str) -> Any:
|
| 858 |
+
"""
|
| 859 |
+
Download a Pickle file from a URL and return the loaded object.
|
| 860 |
+
|
| 861 |
+
Args:
|
| 862 |
+
url (str): The URL to download the Pickle file from.
|
| 863 |
+
|
| 864 |
+
Returns:
|
| 865 |
+
Any: The loaded Python object from the Pickle file.
|
| 866 |
+
"""
|
| 867 |
+
response = requests.get(url)
|
| 868 |
+
response.raise_for_status() # Ensure the request was successful
|
| 869 |
+
|
| 870 |
+
# Load the Pickle object from the response content
|
| 871 |
+
return pickle.loads(response.content)
|
CodonTransformer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""CodonTransformer package."""
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
+
ENV PYTHONUNBUFFERED=1
|
| 5 |
+
ENV PIP_NO_CACHE_DIR=1
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 10 |
+
git \
|
| 11 |
+
build-essential \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
COPY requirements.txt /app/requirements.txt
|
| 15 |
+
RUN pip install --upgrade pip && pip install -r /app/requirements.txt
|
| 16 |
+
|
| 17 |
+
COPY . /app
|
| 18 |
+
|
| 19 |
+
EXPOSE 7860
|
| 20 |
+
|
| 21 |
+
CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true"]
|
ENCOT_Academic_Documentation.html
ADDED
|
@@ -0,0 +1,2625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>ENCOT: Enhanced Codon Optimization Tool - Technical Documentation</title>
|
| 7 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/atom-one-light.min.css">
|
| 8 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
|
| 9 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
|
| 10 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/yaml.min.js"></script>
|
| 11 |
+
<link href="https://fonts.googleapis.com/css2?family=Computer+Modern+Serif:wght@400;700&family=Computer+Modern+Sans:wght@400;700&family=Computer+Modern+Typewriter&display=swap" rel="stylesheet">
|
| 12 |
+
<style>
|
| 13 |
+
/* LaTeX-inspired Academic Styling */
|
| 14 |
+
@import url('https://fonts.googleapis.com/css2?family=Crimson+Text:ital,wght@0,400;0,600;0,700;1,400&family=Source+Code+Pro:wght@400;500&display=swap');
|
| 15 |
+
|
| 16 |
+
* {
|
| 17 |
+
margin: 0;
|
| 18 |
+
padding: 0;
|
| 19 |
+
box-sizing: border-box;
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
body {
|
| 23 |
+
font-family: 'Crimson Text', 'Georgia', serif;
|
| 24 |
+
line-height: 1.6;
|
| 25 |
+
color: #2c3e50;
|
| 26 |
+
background: #f8f9fa;
|
| 27 |
+
padding: 40px;
|
| 28 |
+
max-width: 900px;
|
| 29 |
+
margin: 0 auto;
|
| 30 |
+
font-size: 11pt;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
/* Academic Paper Header */
|
| 34 |
+
.paper-header {
|
| 35 |
+
text-align: center;
|
| 36 |
+
margin-bottom: 50px;
|
| 37 |
+
padding: 30px 0;
|
| 38 |
+
border-bottom: 2px solid #2c3e50;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
.paper-header h1 {
|
| 42 |
+
font-size: 28pt;
|
| 43 |
+
font-weight: 700;
|
| 44 |
+
margin-bottom: 20px;
|
| 45 |
+
color: #1a1a1a;
|
| 46 |
+
letter-spacing: -0.5px;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.paper-header .subtitle {
|
| 50 |
+
font-size: 14pt;
|
| 51 |
+
font-style: italic;
|
| 52 |
+
color: #555;
|
| 53 |
+
margin-bottom: 25px;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
.paper-header .authors {
|
| 57 |
+
font-size: 11pt;
|
| 58 |
+
color: #444;
|
| 59 |
+
margin-bottom: 10px;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
.paper-header .affiliation {
|
| 63 |
+
font-size: 10pt;
|
| 64 |
+
color: #666;
|
| 65 |
+
font-style: italic;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
/* Section Styling */
|
| 69 |
+
.section {
|
| 70 |
+
margin: 40px 0;
|
| 71 |
+
page-break-inside: avoid;
|
| 72 |
+
background: white;
|
| 73 |
+
padding: 25px;
|
| 74 |
+
border: 1px solid #ddd;
|
| 75 |
+
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
.section-number {
|
| 79 |
+
font-weight: 700;
|
| 80 |
+
color: #2c3e50;
|
| 81 |
+
font-size: 14pt;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
.section-title {
|
| 85 |
+
font-size: 16pt;
|
| 86 |
+
font-weight: 700;
|
| 87 |
+
color: #2c3e50;
|
| 88 |
+
margin: 15px 0 20px 0;
|
| 89 |
+
border-bottom: 1px solid #ccc;
|
| 90 |
+
padding-bottom: 8px;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
.abstract, .description {
|
| 94 |
+
text-align: justify;
|
| 95 |
+
margin: 15px 0;
|
| 96 |
+
text-indent: 0;
|
| 97 |
+
hyphens: auto;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
.abstract {
|
| 101 |
+
font-size: 10.5pt;
|
| 102 |
+
padding: 15px;
|
| 103 |
+
background: #f9f9f9;
|
| 104 |
+
border-left: 3px solid #3498db;
|
| 105 |
+
font-style: italic;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
/* Code Blocks - LaTeX Listing Style */
|
| 109 |
+
.code-container {
|
| 110 |
+
margin: 20px 0;
|
| 111 |
+
border: 1px solid #ccc;
|
| 112 |
+
background: #fafafa;
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
.code-header {
|
| 116 |
+
background: #e8e8e8;
|
| 117 |
+
padding: 8px 15px;
|
| 118 |
+
border-bottom: 1px solid #ccc;
|
| 119 |
+
font-family: 'Source Code Pro', monospace;
|
| 120 |
+
font-size: 9pt;
|
| 121 |
+
color: #555;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
.listing-number {
|
| 125 |
+
font-weight: 600;
|
| 126 |
+
color: #2c3e50;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
pre {
|
| 130 |
+
margin: 0;
|
| 131 |
+
padding: 15px;
|
| 132 |
+
overflow-x: auto;
|
| 133 |
+
background: white;
|
| 134 |
+
border: none;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
pre code {
|
| 138 |
+
font-family: 'Source Code Pro', 'Courier New', monospace;
|
| 139 |
+
font-size: 9pt;
|
| 140 |
+
line-height: 1.4;
|
| 141 |
+
color: #2c3e50;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
/* Annotations and Highlights */
|
| 145 |
+
.annotation {
|
| 146 |
+
background: #fff3cd;
|
| 147 |
+
border-left: 4px solid #ffc107;
|
| 148 |
+
padding: 12px 15px;
|
| 149 |
+
margin: 15px 0;
|
| 150 |
+
font-size: 10pt;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
.annotation strong {
|
| 154 |
+
color: #856404;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
.key-concept {
|
| 158 |
+
background: #d1ecf1;
|
| 159 |
+
border-left: 4px solid #0c5460;
|
| 160 |
+
padding: 12px 15px;
|
| 161 |
+
margin: 15px 0;
|
| 162 |
+
font-size: 10pt;
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
.mathematical {
|
| 166 |
+
font-family: 'Crimson Text', serif;
|
| 167 |
+
font-style: italic;
|
| 168 |
+
text-align: center;
|
| 169 |
+
padding: 15px;
|
| 170 |
+
margin: 20px 0;
|
| 171 |
+
background: #f9f9f9;
|
| 172 |
+
border: 1px solid #ddd;
|
| 173 |
+
font-size: 11pt;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/* File References */
|
| 177 |
+
.file-ref {
|
| 178 |
+
font-family: 'Source Code Pro', monospace;
|
| 179 |
+
font-size: 9pt;
|
| 180 |
+
color: #2c3e50;
|
| 181 |
+
background: #f4f4f4;
|
| 182 |
+
padding: 8px 12px;
|
| 183 |
+
border-left: 3px solid #3498db;
|
| 184 |
+
margin: 15px 0;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.file-path {
|
| 188 |
+
font-weight: 600;
|
| 189 |
+
color: #2980b9;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
/* Handwritten-style Notes */
|
| 193 |
+
.handwritten-note {
|
| 194 |
+
border: 2px dashed #95a5a6;
|
| 195 |
+
padding: 15px;
|
| 196 |
+
margin: 20px 0;
|
| 197 |
+
background: #fef9e7;
|
| 198 |
+
font-size: 10pt;
|
| 199 |
+
position: relative;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
.handwritten-note::before {
|
| 203 |
+
content: "✏️ Important Note:";
|
| 204 |
+
font-weight: 600;
|
| 205 |
+
color: #7f8c8d;
|
| 206 |
+
display: block;
|
| 207 |
+
margin-bottom: 8px;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
/* Algorithm/Pseudocode Box */
|
| 211 |
+
.algorithm-box {
|
| 212 |
+
border: 2px solid #2c3e50;
|
| 213 |
+
padding: 20px;
|
| 214 |
+
margin: 20px 0;
|
| 215 |
+
background: white;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
.algorithm-title {
|
| 219 |
+
font-weight: 700;
|
| 220 |
+
text-align: center;
|
| 221 |
+
margin-bottom: 15px;
|
| 222 |
+
font-size: 11pt;
|
| 223 |
+
text-transform: uppercase;
|
| 224 |
+
letter-spacing: 1px;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
.algorithm-content {
|
| 228 |
+
font-family: 'Source Code Pro', monospace;
|
| 229 |
+
font-size: 9.5pt;
|
| 230 |
+
line-height: 1.8;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/* Equation Styling */
|
| 234 |
+
.equation {
|
| 235 |
+
text-align: center;
|
| 236 |
+
margin: 25px 0;
|
| 237 |
+
font-size: 12pt;
|
| 238 |
+
font-family: 'Crimson Text', serif;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
.equation-label {
|
| 242 |
+
float: right;
|
| 243 |
+
font-size: 10pt;
|
| 244 |
+
color: #7f8c8d;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/* Table Styling */
|
| 248 |
+
table {
|
| 249 |
+
width: 100%;
|
| 250 |
+
border-collapse: collapse;
|
| 251 |
+
margin: 20px 0;
|
| 252 |
+
font-size: 10pt;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
th, td {
|
| 256 |
+
border: 1px solid #bbb;
|
| 257 |
+
padding: 8px 12px;
|
| 258 |
+
text-align: left;
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
th {
|
| 262 |
+
background: #ecf0f1;
|
| 263 |
+
font-weight: 600;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
/* Footer */
|
| 267 |
+
.footer {
|
| 268 |
+
margin-top: 50px;
|
| 269 |
+
padding-top: 20px;
|
| 270 |
+
border-top: 1px solid #ccc;
|
| 271 |
+
text-align: center;
|
| 272 |
+
font-size: 9pt;
|
| 273 |
+
color: #7f8c8d;
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
/* Print Styles - Optimized for minimal spacing */
|
| 277 |
+
@page {
|
| 278 |
+
size: A4;
|
| 279 |
+
margin: 1.2cm 1.5cm;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
@page :first {
|
| 283 |
+
margin-top: 1.5cm;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
@media print {
|
| 287 |
+
* {
|
| 288 |
+
-webkit-print-color-adjust: exact !important;
|
| 289 |
+
print-color-adjust: exact !important;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
body {
|
| 293 |
+
background: white;
|
| 294 |
+
padding: 0;
|
| 295 |
+
margin: 0;
|
| 296 |
+
font-size: 9.5pt;
|
| 297 |
+
line-height: 1.35;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
/* Minimize margins */
|
| 301 |
+
.paper-header {
|
| 302 |
+
margin-bottom: 15px;
|
| 303 |
+
padding: 10px 0;
|
| 304 |
+
page-break-after: avoid;
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
.paper-header h1 {
|
| 308 |
+
font-size: 20pt;
|
| 309 |
+
margin-bottom: 8px;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
.paper-header .subtitle {
|
| 313 |
+
font-size: 10pt;
|
| 314 |
+
margin: 3px 0;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
.abstract {
|
| 318 |
+
margin: 12px 0;
|
| 319 |
+
padding: 10px;
|
| 320 |
+
page-break-after: avoid;
|
| 321 |
+
page-break-inside: avoid;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
/* Section optimization - ALLOW BREAKS */
|
| 325 |
+
.section {
|
| 326 |
+
box-shadow: none;
|
| 327 |
+
border: none;
|
| 328 |
+
padding: 8px 10px;
|
| 329 |
+
margin: 5px 0;
|
| 330 |
+
page-break-inside: auto; /* Changed from avoid */
|
| 331 |
+
background: white;
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
.section-title {
|
| 335 |
+
font-size: 12pt;
|
| 336 |
+
margin-bottom: 6px;
|
| 337 |
+
page-break-after: avoid;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
.description {
|
| 341 |
+
margin: 5px 0;
|
| 342 |
+
font-size: 9.5pt;
|
| 343 |
+
line-height: 1.35;
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
/* Code containers - allow breaks */
|
| 347 |
+
.code-container {
|
| 348 |
+
page-break-inside: auto;
|
| 349 |
+
margin: 8px 0;
|
| 350 |
+
padding: 6px;
|
| 351 |
+
border: 1px solid #ccc;
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
.code-header {
|
| 355 |
+
padding: 4px 6px;
|
| 356 |
+
margin-bottom: 4px;
|
| 357 |
+
page-break-after: avoid;
|
| 358 |
+
font-size: 9pt;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
pre {
|
| 362 |
+
margin: 0;
|
| 363 |
+
padding: 6px;
|
| 364 |
+
font-size: 7.5pt;
|
| 365 |
+
line-height: 1.25;
|
| 366 |
+
white-space: pre-wrap;
|
| 367 |
+
word-wrap: break-word;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
code {
|
| 371 |
+
font-size: 7.5pt;
|
| 372 |
+
line-height: 1.25;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
/* File references */
|
| 376 |
+
.file-ref {
|
| 377 |
+
margin: 5px 0;
|
| 378 |
+
padding: 4px 6px;
|
| 379 |
+
font-size: 8.5pt;
|
| 380 |
+
page-break-inside: avoid;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
.file-path {
|
| 384 |
+
font-size: 8.5pt;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
/* Mathematical content */
|
| 388 |
+
.mathematical {
|
| 389 |
+
margin: 8px 0;
|
| 390 |
+
padding: 6px;
|
| 391 |
+
font-size: 9.5pt;
|
| 392 |
+
page-break-inside: avoid;
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
.equation {
|
| 396 |
+
margin: 8px 0;
|
| 397 |
+
font-size: 10pt;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
/* Key concepts and notes */
|
| 401 |
+
.key-concept {
|
| 402 |
+
margin: 8px 0;
|
| 403 |
+
padding: 6px;
|
| 404 |
+
font-size: 9pt;
|
| 405 |
+
page-break-inside: avoid;
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
.key-concept ul {
|
| 409 |
+
margin: 4px 0 0 12px;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
.key-concept li {
|
| 413 |
+
margin: 2px 0;
|
| 414 |
+
line-height: 1.25;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
.handwritten-note {
|
| 418 |
+
margin: 8px 0;
|
| 419 |
+
padding: 6px;
|
| 420 |
+
font-size: 8.5pt;
|
| 421 |
+
page-break-inside: avoid;
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
.handwritten-note::before {
|
| 425 |
+
margin-bottom: 4px;
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
/* Algorithm boxes */
|
| 429 |
+
.algorithm-box {
|
| 430 |
+
margin: 8px 0;
|
| 431 |
+
padding: 8px;
|
| 432 |
+
page-break-inside: auto; /* Allow break for long algorithms */
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
.algorithm-title {
|
| 436 |
+
font-size: 10pt;
|
| 437 |
+
margin-bottom: 6px;
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
.algorithm-content {
|
| 441 |
+
font-size: 8pt;
|
| 442 |
+
line-height: 1.4;
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
/* Tables */
|
| 446 |
+
table {
|
| 447 |
+
margin: 8px 0;
|
| 448 |
+
font-size: 8.5pt;
|
| 449 |
+
page-break-inside: auto;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
th, td {
|
| 453 |
+
padding: 4px 6px;
|
| 454 |
+
font-size: 8.5pt;
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
/* Page break control */
|
| 458 |
+
h1, h2, h3, .section-title {
|
| 459 |
+
page-break-after: avoid;
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
.section:first-of-type {
|
| 463 |
+
page-break-before: avoid;
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
/* Keep title with at least some content */
|
| 467 |
+
.section-title + .description,
|
| 468 |
+
.code-header + pre {
|
| 469 |
+
page-break-before: avoid;
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
/* Hide unnecessary elements */
|
| 473 |
+
.footer {
|
| 474 |
+
display: none;
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
/* Compact spacing for lists */
|
| 478 |
+
ul, ol {
|
| 479 |
+
margin: 4px 0;
|
| 480 |
+
padding-left: 18px;
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
li {
|
| 484 |
+
margin: 1px 0;
|
| 485 |
+
line-height: 1.25;
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
/* Orphan and widow control */
|
| 489 |
+
p, .description, .key-concept, .handwritten-note {
|
| 490 |
+
orphans: 2;
|
| 491 |
+
widows: 2;
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
/* Reduce all vertical spacing */
|
| 495 |
+
* + * {
|
| 496 |
+
margin-top: 0 !important;
|
| 497 |
+
}
|
| 498 |
+
}
|
| 499 |
+
</style>
|
| 500 |
+
</head>
|
| 501 |
+
<body>
|
| 502 |
+
|
| 503 |
+
<!-- Academic Paper Header -->
|
| 504 |
+
<div class="paper-header">
|
| 505 |
+
<h1>ENCOT: Enhanced Codon Optimization Tool</h1>
|
| 506 |
+
<div class="subtitle">
|
| 507 |
+
A Transformer-Based Approach with Augmented-Lagrangian Method<br>
|
| 508 |
+
for Multi-Objective Codon Optimization in E. coli
|
| 509 |
+
</div>
|
| 510 |
+
<div class="authors">
|
| 511 |
+
Technical Implementation Documentation
|
| 512 |
+
</div>
|
| 513 |
+
|
| 514 |
+
</div>
|
| 515 |
+
|
| 516 |
+
<!-- Abstract -->
|
| 517 |
+
<div class="abstract">
|
| 518 |
+
<strong>Abstract:</strong> This document presents the technical implementation of ENCOT, a novel codon optimization
|
| 519 |
+
system that employs transformer-based deep learning combined with an Augmented-Lagrangian Method (ALM) for
|
| 520 |
+
precise control of GC content. The system optimizes multiple biological objectives simultaneously including
|
| 521 |
+
Codon Adaptation Index (CAI), tRNA Adaptation Index (tAI), GC content balance, and minimization of negative
|
| 522 |
+
cis-regulatory elements. The implementation builds upon the CodonTransformer architecture and introduces
|
| 523 |
+
innovative constraint optimization techniques for enhanced E. coli expression systems.
|
| 524 |
+
</div>
|
| 525 |
+
|
| 526 |
+
<!-- Section 1: Core Algorithm - ALM Implementation -->
|
| 527 |
+
<div class="section">
|
| 528 |
+
<div class="section-title">
|
| 529 |
+
<span class="section-number">1.</span> Augmented-Lagrangian Method Implementation
|
| 530 |
+
</div>
|
| 531 |
+
|
| 532 |
+
<div class="description">
|
| 533 |
+
The core innovation of ENCOT lies in its application of the Augmented-Lagrangian Method to enforce
|
| 534 |
+
GC content constraints during training. This approach allows the model to balance multiple optimization
|
| 535 |
+
objectives while maintaining biologically appropriate GC content levels.
|
| 536 |
+
</div>
|
| 537 |
+
|
| 538 |
+
<div class="mathematical">
|
| 539 |
+
<strong>Objective Function:</strong><br><br>
|
| 540 |
+
<i>L</i> = <i>L</i><sub>MLM</sub> + λ·(<i>GC</i> − μ) + (ρ/2)·(<i>GC</i> − μ)²
|
| 541 |
+
<div class="equation-label">(Eq. 1)</div>
|
| 542 |
+
</div>
|
| 543 |
+
|
| 544 |
+
<div class="key-concept">
|
| 545 |
+
<strong>Key Components:</strong>
|
| 546 |
+
<ul style="margin: 10px 0 0 20px;">
|
| 547 |
+
<li><i>L<sub>MLM</sub></i>: Masked Language Modeling loss for codon prediction</li>
|
| 548 |
+
<li>λ: Lagrangian multiplier (adaptively updated)</li>
|
| 549 |
+
<li>ρ: Penalty coefficient (self-tuning based on progress)</li>
|
| 550 |
+
<li><i>GC</i>: Mean GC content of predicted sequences</li>
|
| 551 |
+
<li>μ: Target GC content (0.52 for E. coli)</li>
|
| 552 |
+
</ul>
|
| 553 |
+
</div>
|
| 554 |
+
|
| 555 |
+
<div class="file-ref">
|
| 556 |
+
<div class="file-path">File: finetune.py</div>
|
| 557 |
+
Lines 73-148 | Class: plTrainHarness
|
| 558 |
+
</div>
|
| 559 |
+
|
| 560 |
+
<div class="code-container">
|
| 561 |
+
<div class="code-header">
|
| 562 |
+
<span class="listing-number">Listing 1:</span> ALM Training Harness - Initialization
|
| 563 |
+
</div>
|
| 564 |
+
<pre><code class="language-python">class plTrainHarness(pl.LightningModule):
|
| 565 |
+
"""
|
| 566 |
+
PyTorch Lightning training harness for ENCOT with Augmented-Lagrangian
|
| 567 |
+
Method (ALM) GC control.
|
| 568 |
+
|
| 569 |
+
This class implements the training loop for fine-tuning CodonTransformer
|
| 570 |
+
on E. coli sequences with precise GC content control using an
|
| 571 |
+
Augmented-Lagrangian Method. The ALM approach allows the model to learn
|
| 572 |
+
codon preferences while maintaining GC content within a target range.
|
| 573 |
+
|
| 574 |
+
Key features:
|
| 575 |
+
- Masked language modeling (MLM) loss for codon prediction
|
| 576 |
+
- ALM-based GC content constraint enforcement
|
| 577 |
+
- Curriculum learning: warm-up epochs before enforcing GC constraints
|
| 578 |
+
- Adaptive penalty coefficient (rho) adjustment based on constraint
|
| 579 |
+
violation progress
|
| 580 |
+
|
| 581 |
+
The ALM method minimizes:
|
| 582 |
+
L = L_MLM + λ·(GC - μ) + (ρ/2)(GC - μ)²
|
| 583 |
+
where λ is the Lagrangian multiplier and ρ is the penalty coefficient.
|
| 584 |
+
"""
|
| 585 |
+
|
| 586 |
+
def __init__(self, model, learning_rate, warmup_fraction,
|
| 587 |
+
gc_penalty_weight, tokenizer, gc_target=0.52,
|
| 588 |
+
use_lagrangian=False, lagrangian_rho=10.0,
|
| 589 |
+
curriculum_epochs=3, alm_tolerance=1e-5,
|
| 590 |
+
alm_dual_tolerance=1e-5, alm_penalty_update_factor=10.0,
|
| 591 |
+
alm_initial_penalty_factor=20.0,
|
| 592 |
+
alm_tolerance_update_factor=0.1,
|
| 593 |
+
alm_rel_penalty_increase_threshold=0.1,
|
| 594 |
+
alm_max_penalty=1e6, alm_min_penalty=1e-6):
|
| 595 |
+
super().__init__()
|
| 596 |
+
self.model = model
|
| 597 |
+
self.learning_rate = learning_rate
|
| 598 |
+
self.warmup_fraction = warmup_fraction
|
| 599 |
+
self.gc_penalty_weight = gc_penalty_weight
|
| 600 |
+
self.tokenizer = tokenizer
|
| 601 |
+
|
| 602 |
+
# Augmented-Lagrangian GC Control parameters
|
| 603 |
+
self.gc_target = gc_target
|
| 604 |
+
self.use_lagrangian = use_lagrangian
|
| 605 |
+
self.lagrangian_rho = lagrangian_rho
|
| 606 |
+
self.curriculum_epochs = curriculum_epochs
|
| 607 |
+
|
| 608 |
+
# Enhanced ALM parameters
|
| 609 |
+
self.alm_tolerance = alm_tolerance
|
| 610 |
+
self.alm_dual_tolerance = alm_dual_tolerance
|
| 611 |
+
self.alm_penalty_update_factor = alm_penalty_update_factor
|
| 612 |
+
self.alm_initial_penalty_factor = alm_initial_penalty_factor
|
| 613 |
+
self.alm_tolerance_update_factor = alm_tolerance_update_factor
|
| 614 |
+
self.alm_rel_penalty_increase_threshold = \
|
| 615 |
+
alm_rel_penalty_increase_threshold
|
| 616 |
+
self.alm_max_penalty = alm_max_penalty
|
| 617 |
+
self.alm_min_penalty = alm_min_penalty
|
| 618 |
+
|
| 619 |
+
# Initialize Lagrangian multiplier as buffer
|
| 620 |
+
# (persists across checkpoints)
|
| 621 |
+
self.register_buffer("lambda_gc", torch.tensor(0.0))
|
| 622 |
+
|
| 623 |
+
# Adaptive penalty coefficient (rho)
|
| 624 |
+
self.register_buffer("rho_adaptive",
|
| 625 |
+
torch.tensor(self.lagrangian_rho))
|
| 626 |
+
|
| 627 |
+
# Step counter for periodic lambda updates
|
| 628 |
+
self.register_buffer("step_counter", torch.tensor(0))
|
| 629 |
+
|
| 630 |
+
# ALM convergence tracking
|
| 631 |
+
self.register_buffer("previous_constraint_violation",
|
| 632 |
+
torch.tensor(float('inf')))</code></pre>
|
| 633 |
+
</div>
|
| 634 |
+
|
| 635 |
+
<div class="handwritten-note">
|
| 636 |
+
The initialization sets up persistent buffers for Lagrangian multipliers and penalty coefficients.
|
| 637 |
+
These buffers are saved with model checkpoints, allowing training to resume seamlessly. The curriculum
|
| 638 |
+
learning approach waits for 3 epochs before enforcing GC constraints, giving the model time to learn
|
| 639 |
+
basic codon patterns first.
|
| 640 |
+
</div>
|
| 641 |
+
</div>
|
| 642 |
+
|
| 643 |
+
<!-- Section 2: Training Step -->
|
| 644 |
+
<div class="section">
|
| 645 |
+
<div class="section-title">
|
| 646 |
+
<span class="section-number">2.</span> Training Step with ALM Loss Computation
|
| 647 |
+
</div>
|
| 648 |
+
|
| 649 |
+
<div class="description">
|
| 650 |
+
The training step combines standard masked language modeling with the ALM-based GC constraint.
|
| 651 |
+
During each forward pass, we compute GC content from predicted tokens and apply the Lagrangian
|
| 652 |
+
penalty to guide the model toward the target GC content.
|
| 653 |
+
</div>
|
| 654 |
+
|
| 655 |
+
<div class="file-ref">
|
| 656 |
+
<div class="file-path">File: finetune.py</div>
|
| 657 |
+
Lines 150-230 | Method: training_step
|
| 658 |
+
</div>
|
| 659 |
+
|
| 660 |
+
<div class="code-container">
|
| 661 |
+
<div class="code-header">
|
| 662 |
+
<span class="listing-number">Listing 2:</span> Training Step with ALM Loss
|
| 663 |
+
</div>
|
| 664 |
+
<pre><code class="language-python">def training_step(self, batch, batch_idx):
|
| 665 |
+
"""
|
| 666 |
+
Training step that computes MLM loss and applies ALM-based GC constraint.
|
| 667 |
+
|
| 668 |
+
The constraint is only enforced after curriculum_epochs warm-up period.
|
| 669 |
+
"""
|
| 670 |
+
outputs = self.model(**batch)
|
| 671 |
+
mlm_loss = outputs.loss
|
| 672 |
+
|
| 673 |
+
# Enhanced Lagrangian-based GC penalty
|
| 674 |
+
if self.use_lagrangian and self.current_epoch >= self.curriculum_epochs:
|
| 675 |
+
# Compute GC content from logits
|
| 676 |
+
logits = outputs.logits
|
| 677 |
+
predicted_tokens = torch.argmax(logits, dim=-1)
|
| 678 |
+
|
| 679 |
+
# Calculate GC content per sequence
|
| 680 |
+
gc_content_batch = []
|
| 681 |
+
for seq_tokens in predicted_tokens:
|
| 682 |
+
# Filter to valid codon tokens (indices >= 26)
|
| 683 |
+
valid_tokens = seq_tokens[seq_tokens >= 26]
|
| 684 |
+
if len(valid_tokens) == 0:
|
| 685 |
+
gc_content_batch.append(self.gc_target)
|
| 686 |
+
continue
|
| 687 |
+
|
| 688 |
+
# Count G and C containing codons
|
| 689 |
+
gc_counts = sum(1 for token in valid_tokens
|
| 690 |
+
if token.item() in G_indices + C_indices)
|
| 691 |
+
gc_content = gc_counts / len(valid_tokens)
|
| 692 |
+
gc_content_batch.append(gc_content)
|
| 693 |
+
|
| 694 |
+
# Mean GC content across batch
|
| 695 |
+
gc_content_mean = sum(gc_content_batch) / len(gc_content_batch)
|
| 696 |
+
|
| 697 |
+
# Compute GC constraint violation
|
| 698 |
+
gc_constraint = gc_content_mean - self.gc_target
|
| 699 |
+
|
| 700 |
+
# Augmented Lagrangian loss term
|
| 701 |
+
lagrangian_loss = (
|
| 702 |
+
self.lambda_gc * gc_constraint +
|
| 703 |
+
(self.rho_adaptive / 2) * (gc_constraint ** 2)
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
total_loss = mlm_loss + lagrangian_loss
|
| 707 |
+
|
| 708 |
+
# Log metrics
|
| 709 |
+
self.log("train/mlm_loss", mlm_loss, prog_bar=True)
|
| 710 |
+
self.log("train/gc_constraint", gc_constraint, prog_bar=True)
|
| 711 |
+
self.log("train/lagrangian_loss", lagrangian_loss, prog_bar=False)
|
| 712 |
+
self.log("train/lambda_gc", self.lambda_gc, prog_bar=False)
|
| 713 |
+
self.log("train/rho", self.rho_adaptive, prog_bar=False)
|
| 714 |
+
self.log("train/gc_content", gc_content_mean, prog_bar=True)
|
| 715 |
+
|
| 716 |
+
# Update Lagrangian multiplier periodically
|
| 717 |
+
self.step_counter += 1
|
| 718 |
+
if self.step_counter % 20 == 0:
|
| 719 |
+
self._update_alm_parameters(gc_constraint)
|
| 720 |
+
else:
|
| 721 |
+
# During warm-up, only use MLM loss
|
| 722 |
+
total_loss = mlm_loss
|
| 723 |
+
self.log("train/mlm_loss", mlm_loss, prog_bar=True)
|
| 724 |
+
|
| 725 |
+
self.log("train/total_loss", total_loss, prog_bar=True)
|
| 726 |
+
return total_loss</code></pre>
|
| 727 |
+
</div>
|
| 728 |
+
|
| 729 |
+
<div class="annotation">
|
| 730 |
+
<strong>Implementation Detail:</strong> The GC content is computed from the argmax of logits rather than
|
| 731 |
+
from the actual target sequences. This allows the gradient to flow through the constraint, enabling the
|
| 732 |
+
model to learn to satisfy the constraint during generation.
|
| 733 |
+
</div>
|
| 734 |
+
</div>
|
| 735 |
+
|
| 736 |
+
<!-- Section 3: Adaptive Parameter Update -->
|
| 737 |
+
<div class="section">
|
| 738 |
+
<div class="section-title">
|
| 739 |
+
<span class="section-number">3.</span> Adaptive ALM Parameter Updates
|
| 740 |
+
</div>
|
| 741 |
+
|
| 742 |
+
<div class="description">
|
| 743 |
+
The self-tuning mechanism adjusts Lagrangian multipliers and penalty coefficients based on
|
| 744 |
+
constraint violation progress. This adaptive approach ensures convergence while maintaining
|
| 745 |
+
numerical stability.
|
| 746 |
+
</div>
|
| 747 |
+
|
| 748 |
+
<div class="algorithm-box">
|
| 749 |
+
<div class="algorithm-title">Algorithm 1: Adaptive Penalty Update</div>
|
| 750 |
+
<div class="algorithm-content">
|
| 751 |
+
<strong>Input:</strong> gc_constraint (current violation)<br>
|
| 752 |
+
<strong>Output:</strong> Updated λ_gc and ρ_adaptive<br><br>
|
| 753 |
+
|
| 754 |
+
1. <strong>Compute</strong> relative_improvement ← <br>
|
| 755 |
+
(prev_violation - current_violation) / prev_violation<br><br>
|
| 756 |
+
|
| 757 |
+
2. <strong>If</strong> |gc_constraint| ≤ tolerance <strong>then</strong><br>
|
| 758 |
+
λ_gc ← λ_gc + ρ · gc_constraint<br>
|
| 759 |
+
// Constraint satisfied, update multiplier only<br><br>
|
| 760 |
+
|
| 761 |
+
3. <strong>Else if</strong> relative_improvement < threshold <strong>then</strong><br>
|
| 762 |
+
ρ ← min(ρ · update_factor, max_penalty)<br>
|
| 763 |
+
λ_gc ← λ_gc + ρ · gc_constraint<br>
|
| 764 |
+
// Insufficient progress, increase penalty<br><br>
|
| 765 |
+
|
| 766 |
+
4. <strong>Else</strong><br>
|
| 767 |
+
λ_gc ← λ_gc + ρ · gc_constraint<br>
|
| 768 |
+
// Good progress, keep penalty stable<br><br>
|
| 769 |
+
|
| 770 |
+
5. prev_violation ← |gc_constraint|
|
| 771 |
+
</div>
|
| 772 |
+
</div>
|
| 773 |
+
|
| 774 |
+
<div class="file-ref">
|
| 775 |
+
<div class="file-path">File: finetune.py</div>
|
| 776 |
+
Lines 260-320 | Method: _update_alm_parameters
|
| 777 |
+
</div>
|
| 778 |
+
|
| 779 |
+
<div class="code-container">
|
| 780 |
+
<div class="code-header">
|
| 781 |
+
<span class="listing-number">Listing 3:</span> Adaptive Parameter Update Implementation
|
| 782 |
+
</div>
|
| 783 |
+
<pre><code class="language-python">def _update_alm_parameters(self, gc_constraint):
|
| 784 |
+
"""
|
| 785 |
+
Update Lagrangian multiplier and penalty coefficient according to ALM.
|
| 786 |
+
|
| 787 |
+
This implements the adaptive penalty update strategy:
|
| 788 |
+
- If constraint violation is decreasing sufficiently, update lambda
|
| 789 |
+
and keep rho
|
| 790 |
+
- If constraint violation is not improving, increase rho
|
| 791 |
+
(penalty coefficient)
|
| 792 |
+
"""
|
| 793 |
+
constraint_violation = abs(gc_constraint.item())
|
| 794 |
+
|
| 795 |
+
# Check if we're making sufficient progress
|
| 796 |
+
relative_improvement = (
|
| 797 |
+
(self.previous_constraint_violation - constraint_violation) /
|
| 798 |
+
max(self.previous_constraint_violation, 1e-8)
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
if constraint_violation <= self.alm_tolerance:
|
| 802 |
+
# Constraint satisfied - update lambda, optionally reduce rho
|
| 803 |
+
self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
|
| 804 |
+
# Could reduce rho here if desired, but keeping it stable
|
| 805 |
+
# works well in practice
|
| 806 |
+
|
| 807 |
+
elif relative_improvement < self.alm_rel_penalty_increase_threshold:
|
| 808 |
+
# Not making enough progress - increase penalty
|
| 809 |
+
self.rho_adaptive = torch.clamp(
|
| 810 |
+
self.rho_adaptive * self.alm_penalty_update_factor,
|
| 811 |
+
min=self.alm_min_penalty,
|
| 812 |
+
max=self.alm_max_penalty
|
| 813 |
+
)
|
| 814 |
+
# Also update lambda
|
| 815 |
+
self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
|
| 816 |
+
|
| 817 |
+
else:
|
| 818 |
+
# Making good progress - just update lambda
|
| 819 |
+
self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
|
| 820 |
+
|
| 821 |
+
# Update tracking
|
| 822 |
+
self.previous_constraint_violation = torch.tensor(constraint_violation)</code></pre>
|
| 823 |
+
</div>
|
| 824 |
+
|
| 825 |
+
<div class="handwritten-note">
|
| 826 |
+
The key insight here is the relative improvement threshold. If the constraint violation isn't
|
| 827 |
+
improving by at least 10% (default threshold), we increase the penalty coefficient. This ensures
|
| 828 |
+
that the optimization doesn't get stuck in suboptimal regions where the constraint is consistently
|
| 829 |
+
violated.
|
| 830 |
+
</div>
|
| 831 |
+
</div>
|
| 832 |
+
|
| 833 |
+
<!-- Section 4: Prediction Function -->
|
| 834 |
+
<div class="section">
|
| 835 |
+
<div class="section-title">
|
| 836 |
+
<span class="section-number">4.</span> DNA Sequence Prediction with Constrained Search
|
| 837 |
+
</div>
|
| 838 |
+
|
| 839 |
+
<div class="description">
|
| 840 |
+
The prediction function supports multiple decoding strategies including deterministic (greedy),
|
| 841 |
+
stochastic (temperature sampling), and constrained beam search with GC bounds. This flexibility
|
| 842 |
+
allows users to balance between optimization quality and sequence diversity.
|
| 843 |
+
</div>
|
| 844 |
+
|
| 845 |
+
<div class="file-ref">
|
| 846 |
+
<div class="file-path">File: CodonTransformer/CodonPrediction.py</div>
|
| 847 |
+
Lines 38-120 | Function: predict_dna_sequence
|
| 848 |
+
</div>
|
| 849 |
+
|
| 850 |
+
<div class="code-container">
|
| 851 |
+
<div class="code-header">
|
| 852 |
+
<span class="listing-number">Listing 4:</span> Main Prediction Function Signature
|
| 853 |
+
</div>
|
| 854 |
+
<pre><code class="language-python">def predict_dna_sequence(
|
| 855 |
+
protein: str,
|
| 856 |
+
organism: Union[int, str],
|
| 857 |
+
device: torch.device,
|
| 858 |
+
tokenizer: Union[str, PreTrainedTokenizerFast] = None,
|
| 859 |
+
model: Union[str, torch.nn.Module] = None,
|
| 860 |
+
attention_type: str = "original_full",
|
| 861 |
+
deterministic: bool = True,
|
| 862 |
+
temperature: float = 0.2,
|
| 863 |
+
top_p: float = 0.95,
|
| 864 |
+
num_sequences: int = 1,
|
| 865 |
+
match_protein: bool = False,
|
| 866 |
+
use_constrained_search: bool = False,
|
| 867 |
+
gc_bounds: Tuple[float, float] = (0.30, 0.70),
|
| 868 |
+
beam_size: int = 5,
|
| 869 |
+
length_penalty: float = 1.0,
|
| 870 |
+
diversity_penalty: float = 0.0,
|
| 871 |
+
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
|
| 872 |
+
"""
|
| 873 |
+
Predict the DNA sequence(s) for a given protein using ENCOT model.
|
| 874 |
+
|
| 875 |
+
This function takes a protein sequence and an organism (as ID or name)
|
| 876 |
+
as input and returns the predicted DNA sequence(s) using the ENCOT model.
|
| 877 |
+
It can use either provided tokenizer and model objects or load them from
|
| 878 |
+
specified paths.
|
| 879 |
+
|
| 880 |
+
Args:
|
| 881 |
+
protein (str): The input protein sequence for which to predict
|
| 882 |
+
the DNA sequence.
|
| 883 |
+
organism (Union[int, str]): Either the ID of the organism or its
|
| 884 |
+
name (e.g., "Escherichia coli general").
|
| 885 |
+
device (torch.device): The device (CPU or GPU) to run the model on.
|
| 886 |
+
|
| 887 |
+
deterministic (bool, optional): Whether to use deterministic decoding
|
| 888 |
+
(most likely tokens). If False, samples tokens according to their
|
| 889 |
+
probabilities adjusted by the temperature. Defaults to True.
|
| 890 |
+
|
| 891 |
+
temperature (float, optional): A value controlling the randomness of
|
| 892 |
+
predictions during non-deterministic decoding. Lower values
|
| 893 |
+
(e.g., 0.2) make the model more conservative, while higher values
|
| 894 |
+
(e.g., 0.8) increase randomness. Defaults to 0.2.
|
| 895 |
+
|
| 896 |
+
use_constrained_search (bool, optional): Enable constrained beam
|
| 897 |
+
search with GC bounds. Defaults to False.
|
| 898 |
+
|
| 899 |
+
gc_bounds (Tuple[float, float], optional): GC content bounds
|
| 900 |
+
(min, max) for constrained search. Defaults to (0.30, 0.70).
|
| 901 |
+
|
| 902 |
+
beam_size (int, optional): Beam size for beam search. Defaults to 5.
|
| 903 |
+
|
| 904 |
+
match_protein (bool, optional): Ensures the predicted DNA sequence
|
| 905 |
+
translates to the input protein sequence by sampling from only
|
| 906 |
+
the respective codons of each amino acid. Defaults to False.
|
| 907 |
+
|
| 908 |
+
Returns:
|
| 909 |
+
Union[DNASequencePrediction, List[DNASequencePrediction]]:
|
| 910 |
+
Predicted DNA sequence(s) with associated metrics.
|
| 911 |
+
"""</code></pre>
|
| 912 |
+
</div>
|
| 913 |
+
|
| 914 |
+
<div class="key-concept">
|
| 915 |
+
<strong>Decoding Strategies:</strong>
|
| 916 |
+
<table style="margin-top: 15px;">
|
| 917 |
+
<tr>
|
| 918 |
+
<th>Strategy</th>
|
| 919 |
+
<th>Use Case</th>
|
| 920 |
+
<th>Parameters</th>
|
| 921 |
+
</tr>
|
| 922 |
+
<tr>
|
| 923 |
+
<td><strong>Greedy (deterministic)</strong></td>
|
| 924 |
+
<td>Production optimization</td>
|
| 925 |
+
<td>deterministic=True</td>
|
| 926 |
+
</tr>
|
| 927 |
+
<tr>
|
| 928 |
+
<td><strong>Temperature Sampling</strong></td>
|
| 929 |
+
<td>Diversity exploration</td>
|
| 930 |
+
<td>deterministic=False, temperature=0.2-0.8</td>
|
| 931 |
+
</tr>
|
| 932 |
+
<tr>
|
| 933 |
+
<td><strong>Constrained Beam Search</strong></td>
|
| 934 |
+
<td>GC-constrained optimization</td>
|
| 935 |
+
<td>use_constrained_search=True, gc_bounds=(0.45,0.55)</td>
|
| 936 |
+
</tr>
|
| 937 |
+
</table>
|
| 938 |
+
</div>
|
| 939 |
+
</div>
|
| 940 |
+
|
| 941 |
+
<!-- Section 5: Evaluation Metrics -->
|
| 942 |
+
<div class="section">
|
| 943 |
+
<div class="section-title">
|
| 944 |
+
<span class="section-number">5.</span> Evaluation Metrics Implementation
|
| 945 |
+
</div>
|
| 946 |
+
|
| 947 |
+
<div class="description">
|
| 948 |
+
ENCOT computes comprehensive metrics to evaluate the quality of optimized sequences. The primary
|
| 949 |
+
metrics are the Codon Adaptation Index (CAI) and tRNA Adaptation Index (tAI), which quantify how
|
| 950 |
+
well the codon usage matches highly expressed E. coli genes and available tRNA pools, respectively.
|
| 951 |
+
</div>
|
| 952 |
+
|
| 953 |
+
<div class="file-ref">
|
| 954 |
+
<div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
|
| 955 |
+
Lines 23-50, 370-420 | Functions: get_CSI_value, calculate_tAI
|
| 956 |
+
</div>
|
| 957 |
+
|
| 958 |
+
<div class="code-container">
|
| 959 |
+
<div class="code-header">
|
| 960 |
+
<span class="listing-number">Listing 5:</span> CAI and tAI Calculation
|
| 961 |
+
</div>
|
| 962 |
+
<pre><code class="language-python">def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
|
| 963 |
+
"""
|
| 964 |
+
Calculate the Codon Similarity Index (CSI) weights for a list of
|
| 965 |
+
DNA sequences.
|
| 966 |
+
|
| 967 |
+
CSI is equivalent to CAI when computed from reference sequences.
|
| 968 |
+
|
| 969 |
+
Args:
|
| 970 |
+
sequences (List[str]): List of DNA sequences from highly expressed
|
| 971 |
+
genes.
|
| 972 |
+
|
| 973 |
+
Returns:
|
| 974 |
+
dict: The CSI weights (relative adaptiveness values per codon).
|
| 975 |
+
"""
|
| 976 |
+
return relative_adaptiveness(sequences=sequences)
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
|
| 980 |
+
"""
|
| 981 |
+
Calculate the Codon Similarity Index (CSI) for a DNA sequence.
|
| 982 |
+
|
| 983 |
+
This is the CAI score computed using pre-calculated weights.
|
| 984 |
+
|
| 985 |
+
Args:
|
| 986 |
+
dna (str): The DNA sequence.
|
| 987 |
+
weights (dict): The CSI weights from get_CSI_weights.
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
float: The CSI value (range 0-1, higher is better).
|
| 991 |
+
"""
|
| 992 |
+
return CAI(dna, weights)
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def get_ecoli_tai_weights():
|
| 996 |
+
"""
|
| 997 |
+
Returns pre-calculated tAI weights for E. coli K-12 MG1655.
|
| 998 |
+
|
| 999 |
+
These weights are based on tRNA gene copy numbers and wobble base
|
| 1000 |
+
pairing rules. Higher weights indicate more available tRNA for
|
| 1001 |
+
that codon.
|
| 1002 |
+
|
| 1003 |
+
Returns:
|
| 1004 |
+
dict: Mapping from codon to tAI weight (0-1).
|
| 1005 |
+
"""
|
| 1006 |
+
return {
|
| 1007 |
+
'TTT': 0.58, 'TTC': 0.42, 'TTA': 0.13, 'TTG': 0.13,
|
| 1008 |
+
'TCT': 0.15, 'TCC': 0.15, 'TCA': 0.12, 'TCG': 0.15,
|
| 1009 |
+
'TAT': 0.59, 'TAC': 0.41, 'TGT': 0.46, 'TGC': 0.54,
|
| 1010 |
+
'TGG': 1.00, 'CTT': 0.11, 'CTC': 0.10, 'CTA': 0.04,
|
| 1011 |
+
'CTG': 0.49, 'CCT': 0.16, 'CCC': 0.12, 'CCA': 0.19,
|
| 1012 |
+
'CCG': 0.52, 'CAT': 0.57, 'CAC': 0.43, 'CAA': 0.34,
|
| 1013 |
+
'CAG': 0.66, 'ATT': 0.51, 'ATC': 0.42, 'ATA': 0.07,
|
| 1014 |
+
'ATG': 1.00, 'ACT': 0.17, 'ACC': 0.44, 'ACA': 0.13,
|
| 1015 |
+
'ACG': 0.27, 'AAT': 0.49, 'AAC': 0.51, 'AAA': 0.76,
|
| 1016 |
+
'AAG': 0.24, 'AGT': 0.15, 'AGC': 0.28, 'AGA': 0.07,
|
| 1017 |
+
'AGG': 0.04, 'GTT': 0.28, 'GTC': 0.20, 'GTA': 0.15,
|
| 1018 |
+
'GTG': 0.37, 'GCT': 0.18, 'GCC': 0.27, 'GCA': 0.21,
|
| 1019 |
+
'GCG': 0.36, 'GAT': 0.63, 'GAC': 0.37, 'GAA': 0.68,
|
| 1020 |
+
'GAG': 0.32, 'GGT': 0.35, 'GGC': 0.40, 'GGA': 0.11,
|
| 1021 |
+
'GGG': 0.15,
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
|
| 1026 |
+
"""
|
| 1027 |
+
Calculate the tRNA Adaptation Index (tAI) for a DNA sequence.
|
| 1028 |
+
|
| 1029 |
+
The tAI is the geometric mean of the tAI weights for all codons in
|
| 1030 |
+
the sequence (excluding stop codons).
|
| 1031 |
+
|
| 1032 |
+
Args:
|
| 1033 |
+
sequence (str): DNA sequence (must be divisible by 3)
|
| 1034 |
+
tai_weights (Dict[str, float]): tAI weights for each codon
|
| 1035 |
+
|
| 1036 |
+
Returns:
|
| 1037 |
+
float: Geometric mean of tAI weights (range 0-1)
|
| 1038 |
+
"""
|
| 1039 |
+
if len(sequence) % 3 != 0:
|
| 1040 |
+
raise ValueError("Sequence length must be divisible by 3")
|
| 1041 |
+
|
| 1042 |
+
# Split into codons
|
| 1043 |
+
codons = [sequence[i:i+3].upper() for i in range(0, len(sequence), 3)]
|
| 1044 |
+
|
| 1045 |
+
# Get weights for non-stop codons
|
| 1046 |
+
weights = [tai_weights.get(codon, 0.5) for codon in codons
|
| 1047 |
+
if codon not in ['TAA', 'TAG', 'TGA']]
|
| 1048 |
+
|
| 1049 |
+
if not weights:
|
| 1050 |
+
return 0.0
|
| 1051 |
+
|
| 1052 |
+
# Compute geometric mean
|
| 1053 |
+
product = 1.0
|
| 1054 |
+
for w in weights:
|
| 1055 |
+
product *= w
|
| 1056 |
+
return product ** (1.0 / len(weights))</code></pre>
|
| 1057 |
+
</div>
|
| 1058 |
+
|
| 1059 |
+
<div class="annotation">
|
| 1060 |
+
<strong>Metric Interpretation:</strong> Both CAI and tAI range from 0 to 1, with higher values
|
| 1061 |
+
indicating better optimization. In practice, for E. coli:
|
| 1062 |
+
<ul style="margin: 10px 0 0 20px;">
|
| 1063 |
+
<li>CAI > 0.8 indicates excellent codon adaptation</li>
|
| 1064 |
+
<li>tAI > 0.4 suggests adequate tRNA availability</li>
|
| 1065 |
+
<li>Native E. coli genes typically have CAI around 0.65-0.75</li>
|
| 1066 |
+
</ul>
|
| 1067 |
+
</div>
|
| 1068 |
+
</div>
|
| 1069 |
+
|
| 1070 |
+
<!-- Section 6: Training Configuration -->
|
| 1071 |
+
<div class="section">
|
| 1072 |
+
<div class="section-title">
|
| 1073 |
+
<span class="section-number">6.</span> Training Configuration
|
| 1074 |
+
</div>
|
| 1075 |
+
|
| 1076 |
+
<div class="description">
|
| 1077 |
+
The training configuration specifies all hyperparameters including learning rate, batch size,
|
| 1078 |
+
and ALM-specific settings. This configuration reproduces the exact setup used in our experiments.
|
| 1079 |
+
</div>
|
| 1080 |
+
|
| 1081 |
+
<div class="file-ref">
|
| 1082 |
+
<div class="file-path">File: configs/train_ecoli_alm.yaml</div>
|
| 1083 |
+
Complete configuration file
|
| 1084 |
+
</div>
|
| 1085 |
+
|
| 1086 |
+
<div class="code-container">
|
| 1087 |
+
<div class="code-header">
|
| 1088 |
+
<span class="listing-number">Listing 6:</span> Complete Training Configuration
|
| 1089 |
+
</div>
|
| 1090 |
+
<pre><code class="language-yaml"># ENCOT ALM Training Configuration
|
| 1091 |
+
# This configuration reproduces the main training setup from the paper
|
| 1092 |
+
# using the Augmented-Lagrangian Method (ALM) for GC content control.
|
| 1093 |
+
|
| 1094 |
+
model:
|
| 1095 |
+
base_model: "adibvafa/CodonTransformer-base"
|
| 1096 |
+
tokenizer: "adibvafa/CodonTransformer"
|
| 1097 |
+
|
| 1098 |
+
data:
|
| 1099 |
+
dataset_dir: "data"
|
| 1100 |
+
# Expected files: finetune_set.json (created by preprocess_data.py)
|
| 1101 |
+
|
| 1102 |
+
training:
|
| 1103 |
+
batch_size: 6
|
| 1104 |
+
max_epochs: 15
|
| 1105 |
+
learning_rate: 5e-5
|
| 1106 |
+
warmup_fraction: 0.1
|
| 1107 |
+
num_workers: 5
|
| 1108 |
+
accumulate_grad_batches: 1
|
| 1109 |
+
num_gpus: 4
|
| 1110 |
+
save_every_n_steps: 512
|
| 1111 |
+
seed: 123
|
| 1112 |
+
log_every_n_steps: 20
|
| 1113 |
+
|
| 1114 |
+
checkpoint:
|
| 1115 |
+
checkpoint_dir: "models/alm-enhanced-training"
|
| 1116 |
+
checkpoint_filename: "balanced_alm_finetune.ckpt"
|
| 1117 |
+
|
| 1118 |
+
# Augmented-Lagrangian Method (ALM) for GC content control
|
| 1119 |
+
alm:
|
| 1120 |
+
enabled: true
|
| 1121 |
+
gc_target: 0.52 # Target GC content for E. coli (52%)
|
| 1122 |
+
curriculum_epochs: 3 # Warm-up epochs before enforcing GC constraint
|
| 1123 |
+
|
| 1124 |
+
# ALM penalty parameters
|
| 1125 |
+
initial_penalty_factor: 20.0
|
| 1126 |
+
penalty_update_factor: 10.0
|
| 1127 |
+
max_penalty: 1e6
|
| 1128 |
+
min_penalty: 1e-6
|
| 1129 |
+
|
| 1130 |
+
# ALM tolerance parameters
|
| 1131 |
+
tolerance: 1e-5 # Primal tolerance
|
| 1132 |
+
dual_tolerance: 1e-5 # Dual tolerance for constraint violation
|
| 1133 |
+
tolerance_update_factor: 0.1
|
| 1134 |
+
|
| 1135 |
+
# Adaptive penalty adjustment
|
| 1136 |
+
rel_penalty_increase_threshold: 0.1
|
| 1137 |
+
|
| 1138 |
+
# Legacy penalty method (if ALM disabled)
|
| 1139 |
+
gc_penalty:
|
| 1140 |
+
weight: 0.0 # Only used if use_lagrangian=false</code></pre>
|
| 1141 |
+
</div>
|
| 1142 |
+
|
| 1143 |
+
<div class="key-concept">
|
| 1144 |
+
<strong>Hyperparameter Selection Rationale:</strong>
|
| 1145 |
+
<table style="margin-top: 15px;">
|
| 1146 |
+
<tr>
|
| 1147 |
+
<th>Parameter</th>
|
| 1148 |
+
<th>Value</th>
|
| 1149 |
+
<th>Rationale</th>
|
| 1150 |
+
</tr>
|
| 1151 |
+
<tr>
|
| 1152 |
+
<td>gc_target</td>
|
| 1153 |
+
<td>0.52</td>
|
| 1154 |
+
<td>Native E. coli genome GC content</td>
|
| 1155 |
+
</tr>
|
| 1156 |
+
<tr>
|
| 1157 |
+
<td>curriculum_epochs</td>
|
| 1158 |
+
<td>3</td>
|
| 1159 |
+
<td>Allow basic pattern learning before constraint</td>
|
| 1160 |
+
</tr>
|
| 1161 |
+
<tr>
|
| 1162 |
+
<td>initial_penalty_factor</td>
|
| 1163 |
+
<td>20.0</td>
|
| 1164 |
+
<td>Moderate initial constraint enforcement</td>
|
| 1165 |
+
</tr>
|
| 1166 |
+
<tr>
|
| 1167 |
+
<td>penalty_update_factor</td>
|
| 1168 |
+
<td>10.0</td>
|
| 1169 |
+
<td>Aggressive adaptation for fast convergence</td>
|
| 1170 |
+
</tr>
|
| 1171 |
+
</table>
|
| 1172 |
+
</div>
|
| 1173 |
+
</div>
|
| 1174 |
+
|
| 1175 |
+
<!-- Section 7: Data Validation -->
|
| 1176 |
+
<div class="section">
|
| 1177 |
+
<div class="section-title">
|
| 1178 |
+
<span class="section-number">7.</span> Sequence Validation Pipeline
|
| 1179 |
+
</div>
|
| 1180 |
+
|
| 1181 |
+
<div class="description">
|
| 1182 |
+
Before training, all DNA sequences undergo rigorous validation to ensure biological correctness.
|
| 1183 |
+
Invalid sequences are filtered out to maintain data quality.
|
| 1184 |
+
</div>
|
| 1185 |
+
|
| 1186 |
+
<div class="file-ref">
|
| 1187 |
+
<div class="file-path">File: prepare_ecoli_data.py</div>
|
| 1188 |
+
Lines 5-30 | Function: is_valid_sequence
|
| 1189 |
+
</div>
|
| 1190 |
+
|
| 1191 |
+
<div class="code-container">
|
| 1192 |
+
<div class="code-header">
|
| 1193 |
+
<span class="listing-number">Listing 7:</span> Sequence Validation Function
|
| 1194 |
+
</div>
|
| 1195 |
+
<pre><code class="language-python">def is_valid_sequence(dna_seq: str) -> bool:
|
| 1196 |
+
"""
|
| 1197 |
+
Applies a series of validation checks to a DNA sequence.
|
| 1198 |
+
|
| 1199 |
+
Validation criteria:
|
| 1200 |
+
1. Length must be divisible by 3 (valid codon frame)
|
| 1201 |
+
2. Must start with a valid start codon (ATG, TTG, CTG, or GTG)
|
| 1202 |
+
3. Must end with a valid stop codon (TAA, TAG, or TGA)
|
| 1203 |
+
4. Must not contain internal stop codons
|
| 1204 |
+
5. Must contain only valid nucleotides (A, T, G, C)
|
| 1205 |
+
|
| 1206 |
+
Args:
|
| 1207 |
+
dna_seq (str): The DNA sequence to validate.
|
| 1208 |
+
|
| 1209 |
+
Returns:
|
| 1210 |
+
bool: True if the sequence passes all checks, False otherwise.
|
| 1211 |
+
"""
|
| 1212 |
+
# Check 1: Valid codon frame
|
| 1213 |
+
if len(dna_seq) % 3 != 0:
|
| 1214 |
+
return False
|
| 1215 |
+
|
| 1216 |
+
# Check 2: Valid start codon
|
| 1217 |
+
if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
|
| 1218 |
+
return False
|
| 1219 |
+
|
| 1220 |
+
# Check 3: Valid stop codon
|
| 1221 |
+
if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
|
| 1222 |
+
return False
|
| 1223 |
+
|
| 1224 |
+
# Check 4: No internal stop codons (excluding the last codon)
|
| 1225 |
+
codons = [dna_seq[i:i+3].upper()
|
| 1226 |
+
for i in range(0, len(dna_seq) - 3, 3)]
|
| 1227 |
+
if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
|
| 1228 |
+
return False
|
| 1229 |
+
|
| 1230 |
+
# Check 5: Only valid nucleotides
|
| 1231 |
+
if not all(c in 'ATGC' for c in dna_seq.upper()):
|
| 1232 |
+
return False
|
| 1233 |
+
|
| 1234 |
+
return True</code></pre>
|
| 1235 |
+
</div>
|
| 1236 |
+
|
| 1237 |
+
<div class="handwritten-note">
|
| 1238 |
+
The validation function is intentionally strict to ensure high-quality training data. In our
|
| 1239 |
+
preprocessing of the E. coli genome, approximately 95% of sequences passed all validation checks.
|
| 1240 |
+
The most common reason for rejection was sequences with internal stop codons due to sequencing
|
| 1241 |
+
errors or pseudogenes.
|
| 1242 |
+
</div>
|
| 1243 |
+
</div>
|
| 1244 |
+
|
| 1245 |
+
<!-- Section 8: Benchmark Evaluation -->
|
| 1246 |
+
<div class="section">
|
| 1247 |
+
<div class="section-title">
|
| 1248 |
+
<span class="section-number">8.</span> Benchmark Evaluation Pipeline
|
| 1249 |
+
</div>
|
| 1250 |
+
|
| 1251 |
+
<div class="description">
|
| 1252 |
+
The benchmark pipeline evaluates ENCOT on a test set of protein sequences, computing multiple
|
| 1253 |
+
metrics for each optimized sequence and generating comprehensive performance reports.
|
| 1254 |
+
</div>
|
| 1255 |
+
|
| 1256 |
+
<div class="file-ref">
|
| 1257 |
+
<div class="file-path">File: benchmark_evaluation.py</div>
|
| 1258 |
+
Lines 300-400 | Function: benchmark_sequences
|
| 1259 |
+
</div>
|
| 1260 |
+
|
| 1261 |
+
<div class="code-container">
|
| 1262 |
+
<div class="code-header">
|
| 1263 |
+
<span class="listing-number">Listing 8:</span> Benchmark Evaluation Function
|
| 1264 |
+
</div>
|
| 1265 |
+
<pre><code class="language-python">def benchmark_sequences(sequences, model, tokenizer, device,
|
| 1266 |
+
cai_weights, tai_weights):
|
| 1267 |
+
"""
|
| 1268 |
+
Run ENCOT on protein sequences and compute metrics for optimized DNA.
|
| 1269 |
+
|
| 1270 |
+
Args:
|
| 1271 |
+
sequences: List of (name, protein) tuples to optimize
|
| 1272 |
+
model: Loaded ENCOT model
|
| 1273 |
+
tokenizer: Tokenizer for the model
|
| 1274 |
+
device: PyTorch device (CPU/GPU)
|
| 1275 |
+
cai_weights: Pre-computed CAI weights from reference sequences
|
| 1276 |
+
tai_weights: Pre-computed tAI weights for E. coli
|
| 1277 |
+
|
| 1278 |
+
Returns:
|
| 1279 |
+
DataFrame with columns: name, protein, optimized_dna, CAI, tAI,
|
| 1280 |
+
GC_content, negative_cis_elements
|
| 1281 |
+
"""
|
| 1282 |
+
results = []
|
| 1283 |
+
|
| 1284 |
+
for name, protein in tqdm(sequences, desc="Optimizing sequences"):
|
| 1285 |
+
# Optimize the sequence using ENCOT
|
| 1286 |
+
output = predict_dna_sequence(
|
| 1287 |
+
protein=protein,
|
| 1288 |
+
organism="Escherichia coli general",
|
| 1289 |
+
device=device,
|
| 1290 |
+
model=model,
|
| 1291 |
+
tokenizer=tokenizer,
|
| 1292 |
+
deterministic=True,
|
| 1293 |
+
use_constrained_search=True,
|
| 1294 |
+
gc_bounds=(0.45, 0.55) # E. coli optimal range
|
| 1295 |
+
)
|
| 1296 |
+
|
| 1297 |
+
optimized_dna = output.predicted_dna
|
| 1298 |
+
|
| 1299 |
+
# Calculate comprehensive metrics
|
| 1300 |
+
cai = get_CSI_value(optimized_dna, cai_weights)
|
| 1301 |
+
tai = calculate_tAI(optimized_dna, tai_weights)
|
| 1302 |
+
gc_content = get_GC_content(optimized_dna)
|
| 1303 |
+
cis_elements = count_negative_cis_elements(optimized_dna)
|
| 1304 |
+
homopolymers = calculate_homopolymer_runs(optimized_dna)
|
| 1305 |
+
|
| 1306 |
+
results.append({
|
| 1307 |
+
'name': name,
|
| 1308 |
+
'protein': protein,
|
| 1309 |
+
'optimized_dna': optimized_dna,
|
| 1310 |
+
'length': len(optimized_dna),
|
| 1311 |
+
'CAI': cai,
|
| 1312 |
+
'tAI': tai,
|
| 1313 |
+
'GC_content': gc_content,
|
| 1314 |
+
'negative_cis_elements': cis_elements,
|
| 1315 |
+
'max_homopolymer_length': homopolymers
|
| 1316 |
+
})
|
| 1317 |
+
|
| 1318 |
+
return pd.DataFrame(results)</code></pre>
|
| 1319 |
+
</div>
|
| 1320 |
+
|
| 1321 |
+
<div class="key-concept">
|
| 1322 |
+
<strong>Benchmark Metrics Summary:</strong>
|
| 1323 |
+
<ul style="margin: 10px 0 0 20px;">
|
| 1324 |
+
<li><strong>CAI:</strong> Measures codon usage similarity to highly expressed genes</li>
|
| 1325 |
+
<li><strong>tAI:</strong> Quantifies tRNA availability for translation</li>
|
| 1326 |
+
<li><strong>GC Content:</strong> Should be near 52% for E. coli</li>
|
| 1327 |
+
<li><strong>Negative cis-elements:</strong> Count of problematic regulatory sequences</li>
|
| 1328 |
+
<li><strong>Homopolymers:</strong> Long runs that cause synthesis issues</li>
|
| 1329 |
+
</ul>
|
| 1330 |
+
</div>
|
| 1331 |
+
</div>
|
| 1332 |
+
|
| 1333 |
+
<!-- Section 9: Usage Example -->
|
| 1334 |
+
<div class="section">
|
| 1335 |
+
<div class="section-title">
|
| 1336 |
+
<span class="section-number">9.</span> Complete Usage Example
|
| 1337 |
+
</div>
|
| 1338 |
+
|
| 1339 |
+
<div class="description">
|
| 1340 |
+
This example demonstrates a complete workflow: loading the model, optimizing a sequence, and
|
| 1341 |
+
evaluating the results. This is the recommended pattern for production use.
|
| 1342 |
+
</div>
|
| 1343 |
+
|
| 1344 |
+
<div class="code-container">
|
| 1345 |
+
<div class="code-header">
|
| 1346 |
+
<span class="listing-number">Listing 9:</span> End-to-End Optimization Workflow
|
| 1347 |
+
</div>
|
| 1348 |
+
<pre><code class="language-python">#!/usr/bin/env python3
|
| 1349 |
+
"""
|
| 1350 |
+
Complete workflow example for ENCOT codon optimization.
|
| 1351 |
+
"""
|
| 1352 |
+
|
| 1353 |
+
import torch
|
| 1354 |
+
from transformers import AutoTokenizer
|
| 1355 |
+
from CodonTransformer.CodonPrediction import load_model, predict_dna_sequence
|
| 1356 |
+
from CodonTransformer.CodonEvaluation import (
|
| 1357 |
+
get_GC_content, calculate_tAI, get_CSI_value,
|
| 1358 |
+
get_ecoli_tai_weights, count_negative_cis_elements
|
| 1359 |
+
)
|
| 1360 |
+
from CAI import relative_adaptiveness
|
| 1361 |
+
from huggingface_hub import hf_hub_download
|
| 1362 |
+
|
| 1363 |
+
# Step 1: Setup device and load model
|
| 1364 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1365 |
+
print(f"Using device: {device}")
|
| 1366 |
+
|
| 1367 |
+
# Download model from HuggingFace
|
| 1368 |
+
checkpoint_path = hf_hub_download(
|
| 1369 |
+
repo_id="saketh11/ColiFormer",
|
| 1370 |
+
filename="balanced_alm_finetune.ckpt",
|
| 1371 |
+
cache_dir="./hf_cache"
|
| 1372 |
+
)
|
| 1373 |
+
|
| 1374 |
+
model = load_model(model_path=checkpoint_path, device=device)
|
| 1375 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 1376 |
+
|
| 1377 |
+
# Step 2: Define protein to optimize
|
| 1378 |
+
protein = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
|
| 1379 |
+
print(f"Input protein ({len(protein)} aa): {protein}")
|
| 1380 |
+
|
| 1381 |
+
# Step 3: Optimize the sequence
|
| 1382 |
+
print("\nOptimizing...")
|
| 1383 |
+
output = predict_dna_sequence(
|
| 1384 |
+
protein=protein,
|
| 1385 |
+
organism="Escherichia coli general",
|
| 1386 |
+
device=device,
|
| 1387 |
+
model=model,
|
| 1388 |
+
tokenizer=tokenizer,
|
| 1389 |
+
deterministic=True,
|
| 1390 |
+
match_protein=True,
|
| 1391 |
+
use_constrained_search=True,
|
| 1392 |
+
gc_bounds=(0.45, 0.55),
|
| 1393 |
+
beam_size=20
|
| 1394 |
+
)
|
| 1395 |
+
|
| 1396 |
+
optimized_dna = output.predicted_dna
|
| 1397 |
+
print(f"Optimized DNA ({len(optimized_dna)} bp): {optimized_dna[:60]}...")
|
| 1398 |
+
|
| 1399 |
+
# Step 4: Evaluate metrics
|
| 1400 |
+
print("\nComputing metrics...")
|
| 1401 |
+
|
| 1402 |
+
# Load reference weights
|
| 1403 |
+
tai_weights = get_ecoli_tai_weights()
|
| 1404 |
+
|
| 1405 |
+
# For CAI, we need reference sequences (use E. coli highly expressed genes)
|
| 1406 |
+
# In practice, load from your reference dataset
|
| 1407 |
+
reference_sequences = load_reference_sequences() # Your function
|
| 1408 |
+
cai_weights = relative_adaptiveness(reference_sequences)
|
| 1409 |
+
|
| 1410 |
+
# Calculate metrics
|
| 1411 |
+
cai = get_CSI_value(optimized_dna, cai_weights)
|
| 1412 |
+
tai = calculate_tAI(optimized_dna, tai_weights)
|
| 1413 |
+
gc = get_GC_content(optimized_dna)
|
| 1414 |
+
cis = count_negative_cis_elements(optimized_dna)
|
| 1415 |
+
|
| 1416 |
+
# Step 5: Report results
|
| 1417 |
+
print("\n" + "="*50)
|
| 1418 |
+
print("OPTIMIZATION RESULTS")
|
| 1419 |
+
print("="*50)
|
| 1420 |
+
print(f"CAI (Codon Adaptation Index): {cai:.4f}")
|
| 1421 |
+
print(f"tAI (tRNA Adaptation Index): {tai:.4f}")
|
| 1422 |
+
print(f"GC Content: {gc:.2f}%")
|
| 1423 |
+
print(f"Negative cis-regulatory elements: {cis}")
|
| 1424 |
+
print("="*50)
|
| 1425 |
+
|
| 1426 |
+
# Step 6: Verify translation
|
| 1427 |
+
from Bio.Seq import Seq
|
| 1428 |
+
translated = str(Seq(optimized_dna).translate())
|
| 1429 |
+
assert translated == protein, "Translation mismatch!"
|
| 1430 |
+
print("\n✓ Optimized DNA correctly translates to input protein")</code></pre>
|
| 1431 |
+
</div>
|
| 1432 |
+
</div>
|
| 1433 |
+
|
| 1434 |
+
<!-- Section 11: Constrained Beam Search -->
|
| 1435 |
+
<div class="section">
|
| 1436 |
+
<div class="section-title">
|
| 1437 |
+
<span class="section-number">11.</span> Constrained Beam Search Implementation
|
| 1438 |
+
</div>
|
| 1439 |
+
|
| 1440 |
+
<div class="description">
|
| 1441 |
+
The constrained beam search algorithm ensures that generated DNA sequences maintain GC content within specified bounds. This method prunes candidates that violate constraints during generation, improving efficiency compared to post-hoc filtering.
|
| 1442 |
+
</div>
|
| 1443 |
+
|
| 1444 |
+
<div class="file-ref">
|
| 1445 |
+
<div class="file-path">File: CodonTransformer/CodonPrediction.py</div>
|
| 1446 |
+
Lines 850-950 | Function: _constrained_beam_search()
|
| 1447 |
+
</div>
|
| 1448 |
+
|
| 1449 |
+
<div class="code-container">
|
| 1450 |
+
<div class="code-header">
|
| 1451 |
+
<span class="listing-number">Listing 11:</span> Constrained Beam Search Core
|
| 1452 |
+
</div>
|
| 1453 |
+
<pre><code class="language-python">def _constrained_beam_search(model, input_ids, attention_mask,
|
| 1454 |
+
beam_size, gc_bounds, max_len, device):
|
| 1455 |
+
"""
|
| 1456 |
+
Constrained beam search that enforces GC content bounds during generation.
|
| 1457 |
+
|
| 1458 |
+
Args:
|
| 1459 |
+
model: CodonTransformer model
|
| 1460 |
+
input_ids: Tokenized input [batch_size, seq_len]
|
| 1461 |
+
attention_mask: Attention mask
|
| 1462 |
+
beam_size: Number of candidates to maintain
|
| 1463 |
+
gc_bounds: (min_gc, max_gc) tuple for GC content
|
| 1464 |
+
max_len: Maximum sequence length
|
| 1465 |
+
device: torch device
|
| 1466 |
+
|
| 1467 |
+
Returns:
|
| 1468 |
+
Best sequence satisfying GC constraints
|
| 1469 |
+
"""
|
| 1470 |
+
batch_size = input_ids.size(0)
|
| 1471 |
+
min_gc, max_gc = gc_bounds
|
| 1472 |
+
|
| 1473 |
+
# Initialize beams: (sequence, score, gc_count, length)
|
| 1474 |
+
beams = [(input_ids[0].clone(), 0.0, 0, 0)]
|
| 1475 |
+
|
| 1476 |
+
for step in range(max_len):
|
| 1477 |
+
all_candidates = []
|
| 1478 |
+
|
| 1479 |
+
for seq, score, gc_count, length in beams:
|
| 1480 |
+
# Get model predictions
|
| 1481 |
+
with torch.no_grad():
|
| 1482 |
+
outputs = model(seq.unsqueeze(0))
|
| 1483 |
+
logits = outputs.logits[0, -1, :] # Last position
|
| 1484 |
+
probs = torch.softmax(logits, dim=-1)
|
| 1485 |
+
|
| 1486 |
+
# Get top-k tokens
|
| 1487 |
+
top_probs, top_indices = torch.topk(probs, beam_size * 2)
|
| 1488 |
+
|
| 1489 |
+
for prob, token_id in zip(top_probs, top_indices):
|
| 1490 |
+
# Decode token to codon
|
| 1491 |
+
token = tokenizer.decode([token_id])
|
| 1492 |
+
|
| 1493 |
+
# Calculate GC content
|
| 1494 |
+
new_gc_count = gc_count + token.count('G') + token.count('C')
|
| 1495 |
+
new_length = length + len(token)
|
| 1496 |
+
current_gc = new_gc_count / new_length if new_length > 0 else 0.0
|
| 1497 |
+
|
| 1498 |
+
# Check GC constraint (with some relaxation early on)
|
| 1499 |
+
relaxation = max(0.1, 1.0 - step / max_len)
|
| 1500 |
+
if min_gc - relaxation <= current_gc <= max_gc + relaxation:
|
| 1501 |
+
new_seq = torch.cat([seq, token_id.unsqueeze(0)])
|
| 1502 |
+
new_score = score + torch.log(prob).item()
|
| 1503 |
+
all_candidates.append((new_seq, new_score,
|
| 1504 |
+
new_gc_count, new_length))
|
| 1505 |
+
|
| 1506 |
+
# Select top beams
|
| 1507 |
+
all_candidates.sort(key=lambda x: x[1], reverse=True)
|
| 1508 |
+
beams = all_candidates[:beam_size]
|
| 1509 |
+
|
| 1510 |
+
if not beams:
|
| 1511 |
+
raise ValueError("No valid candidates found within GC bounds")
|
| 1512 |
+
|
| 1513 |
+
# Return best sequence
|
| 1514 |
+
return beams[0][0]</code></pre>
|
| 1515 |
+
</div>
|
| 1516 |
+
|
| 1517 |
+
<div class="handwritten-note">
|
| 1518 |
+
The relaxation factor allows more flexibility early in generation, gradually tightening constraints as the sequence grows. This prevents premature pruning of potentially good candidates.
|
| 1519 |
+
</div>
|
| 1520 |
+
</div>
|
| 1521 |
+
|
| 1522 |
+
<!-- Section 12: GC Content Calculation -->
|
| 1523 |
+
<div class="section">
|
| 1524 |
+
<div class="section-title">
|
| 1525 |
+
<span class="section-number">12.</span> GC Content Analysis
|
| 1526 |
+
</div>
|
| 1527 |
+
|
| 1528 |
+
<div class="description">
|
| 1529 |
+
Precise GC content calculation is critical for both training constraints and sequence evaluation. The implementation handles edge cases and provides window-based analysis for local GC variations.
|
| 1530 |
+
</div>
|
| 1531 |
+
|
| 1532 |
+
<div class="file-ref">
|
| 1533 |
+
<div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
|
| 1534 |
+
Lines 245-285 | Function: get_GC_content()
|
| 1535 |
+
</div>
|
| 1536 |
+
|
| 1537 |
+
<div class="code-container">
|
| 1538 |
+
<div class="code-header">
|
| 1539 |
+
<span class="listing-number">Listing 12:</span> GC Content Calculation
|
| 1540 |
+
</div>
|
| 1541 |
+
<pre><code class="language-python">def get_GC_content(dna_sequence: str, window_size: int = None) -> float:
|
| 1542 |
+
"""
|
| 1543 |
+
Calculate GC content of a DNA sequence.
|
| 1544 |
+
|
| 1545 |
+
Args:
|
| 1546 |
+
dna_sequence: DNA sequence string
|
| 1547 |
+
window_size: If provided, calculate sliding window GC content
|
| 1548 |
+
|
| 1549 |
+
Returns:
|
| 1550 |
+
GC content as percentage (0-100) or list of windowed values
|
| 1551 |
+
"""
|
| 1552 |
+
if not dna_sequence:
|
| 1553 |
+
raise ValueError("DNA sequence cannot be empty")
|
| 1554 |
+
|
| 1555 |
+
# Convert to uppercase and validate
|
| 1556 |
+
dna_sequence = dna_sequence.upper()
|
| 1557 |
+
valid_bases = set('ATGC')
|
| 1558 |
+
if not all(base in valid_bases for base in dna_sequence):
|
| 1559 |
+
raise ValueError("DNA sequence contains invalid characters")
|
| 1560 |
+
|
| 1561 |
+
if window_size is None:
|
| 1562 |
+
# Global GC content
|
| 1563 |
+
gc_count = dna_sequence.count('G') + dna_sequence.count('C')
|
| 1564 |
+
total = len(dna_sequence)
|
| 1565 |
+
return (gc_count / total) * 100.0 if total > 0 else 0.0
|
| 1566 |
+
else:
|
| 1567 |
+
# Sliding window GC content
|
| 1568 |
+
if window_size <= 0 or window_size > len(dna_sequence):
|
| 1569 |
+
raise ValueError(f"Invalid window size: {window_size}")
|
| 1570 |
+
|
| 1571 |
+
gc_values = []
|
| 1572 |
+
for i in range(len(dna_sequence) - window_size + 1):
|
| 1573 |
+
window = dna_sequence[i:i + window_size]
|
| 1574 |
+
gc_count = window.count('G') + window.count('C')
|
| 1575 |
+
gc_pct = (gc_count / window_size) * 100.0
|
| 1576 |
+
gc_values.append(gc_pct)
|
| 1577 |
+
|
| 1578 |
+
return gc_values
|
| 1579 |
+
|
| 1580 |
+
def calculate_gc_variance(dna_sequence: str, window_size: int = 100) -> float:
|
| 1581 |
+
"""Calculate variance in GC content across sequence windows"""
|
| 1582 |
+
gc_values = get_GC_content(dna_sequence, window_size)
|
| 1583 |
+
if len(gc_values) < 2:
|
| 1584 |
+
return 0.0
|
| 1585 |
+
|
| 1586 |
+
mean_gc = sum(gc_values) / len(gc_values)
|
| 1587 |
+
variance = sum((x - mean_gc) ** 2 for x in gc_values) / len(gc_values)
|
| 1588 |
+
return variance</code></pre>
|
| 1589 |
+
</div>
|
| 1590 |
+
</div>
|
| 1591 |
+
|
| 1592 |
+
<!-- Section 13: Tokenization Pipeline -->
|
| 1593 |
+
<div class="section">
|
| 1594 |
+
<div class="section-title">
|
| 1595 |
+
<span class="section-number">13.</span> Sequence Tokenization
|
| 1596 |
+
</div>
|
| 1597 |
+
|
| 1598 |
+
<div class="description">
|
| 1599 |
+
The tokenization pipeline converts protein and DNA sequences into codon-level tokens that the transformer can process. Each codon is represented as a single token (e.g., "l_ctg" for leucine codon CTG).
|
| 1600 |
+
</div>
|
| 1601 |
+
|
| 1602 |
+
<div class="file-ref">
|
| 1603 |
+
<div class="file-path">File: CodonTransformer/CodonUtils.py</div>
|
| 1604 |
+
Lines 35-130 | Constant: TOKEN2INDEX
|
| 1605 |
+
</div>
|
| 1606 |
+
|
| 1607 |
+
<div class="code-container">
|
| 1608 |
+
<div class="code-header">
|
| 1609 |
+
<span class="listing-number">Listing 13:</span> Codon Tokenization Dictionary
|
| 1610 |
+
</div>
|
| 1611 |
+
<pre><code class="language-python"># Codon-to-token mapping: amino_acid_codon format
|
| 1612 |
+
TOKEN2INDEX = {
|
| 1613 |
+
"[PAD]": 0, # Padding token
|
| 1614 |
+
"[UNK]": 1, # Unknown token
|
| 1615 |
+
"[CLS]": 2, # Classification token
|
| 1616 |
+
"[SEP]": 3, # Separator token
|
| 1617 |
+
"[MASK]": 4, # Mask token for MLM
|
| 1618 |
+
|
| 1619 |
+
# Amino acid codons (format: amino_codon)
|
| 1620 |
+
"a_gca": 62, # Alanine - GCA
|
| 1621 |
+
"a_gcc": 63, # Alanine - GCC
|
| 1622 |
+
"a_gcg": 64, # Alanine - GCG
|
| 1623 |
+
"a_gct": 65, # Alanine - GCT
|
| 1624 |
+
|
| 1625 |
+
"c_tgc": 83, # Cysteine - TGC
|
| 1626 |
+
"c_tgt": 85, # Cysteine - TGT
|
| 1627 |
+
|
| 1628 |
+
"d_gac": 59, # Aspartate - GAC
|
| 1629 |
+
"d_gat": 61, # Aspartate - GAT
|
| 1630 |
+
|
| 1631 |
+
"e_gaa": 58, # Glutamate - GAA
|
| 1632 |
+
"e_gag": 60, # Glutamate - GAG
|
| 1633 |
+
|
| 1634 |
+
"f_ttc": 87, # Phenylalanine - TTC
|
| 1635 |
+
"f_ttt": 89, # Phenylalanine - TTT
|
| 1636 |
+
|
| 1637 |
+
"g_gga": 66, # Glycine - GGA
|
| 1638 |
+
"g_ggc": 67, # Glycine - GGC
|
| 1639 |
+
"g_ggg": 68, # Glycine - GGG
|
| 1640 |
+
"g_ggt": 69, # Glycine - GGT
|
| 1641 |
+
|
| 1642 |
+
# ... (61 codon tokens total for all amino acids)
|
| 1643 |
+
|
| 1644 |
+
"__taa": 74, # Stop codon - TAA
|
| 1645 |
+
"__tag": 76, # Stop codon - TAG
|
| 1646 |
+
"__tga": 82, # Stop codon - TGA
|
| 1647 |
+
}
|
| 1648 |
+
|
| 1649 |
+
# Organism ID mapping (164 organisms supported)
|
| 1650 |
+
ORGANISM2ID = {
|
| 1651 |
+
"Escherichia coli general": 0,
|
| 1652 |
+
"Homo sapiens": 1,
|
| 1653 |
+
"Saccharomyces cerevisiae": 2,
|
| 1654 |
+
"Bacillus subtilis": 3,
|
| 1655 |
+
# ... (160 more organisms)
|
| 1656 |
+
}
|
| 1657 |
+
|
| 1658 |
+
def get_merged_seq(protein: str, dna: str = "",
|
| 1659 |
+
include_start_codon: bool = True) -> str:
|
| 1660 |
+
"""
|
| 1661 |
+
Merge protein and DNA into codon tokens.
|
| 1662 |
+
|
| 1663 |
+
For training: protein + DNA codons
|
| 1664 |
+
For inference: protein + [MASK] tokens
|
| 1665 |
+
|
| 1666 |
+
Args:
|
| 1667 |
+
protein: Amino acid sequence
|
| 1668 |
+
dna: DNA sequence (empty for inference)
|
| 1669 |
+
include_start_codon: Add ATG start codon
|
| 1670 |
+
|
| 1671 |
+
Returns:
|
| 1672 |
+
Space-separated codon tokens
|
| 1673 |
+
"""
|
| 1674 |
+
tokens = ["[CLS]"]
|
| 1675 |
+
|
| 1676 |
+
if include_start_codon:
|
| 1677 |
+
tokens.append("m_atg") # Start codon
|
| 1678 |
+
|
| 1679 |
+
# Convert protein to amino acid tokens
|
| 1680 |
+
for aa in protein.lower():
|
| 1681 |
+
if dna:
|
| 1682 |
+
# Training: use actual codons from DNA
|
| 1683 |
+
codon = dna[:3].lower()
|
| 1684 |
+
dna = dna[3:]
|
| 1685 |
+
token = f"{aa}_{codon}"
|
| 1686 |
+
else:
|
| 1687 |
+
# Inference: use [MASK] for model to predict
|
| 1688 |
+
token = "[MASK]"
|
| 1689 |
+
tokens.append(token)
|
| 1690 |
+
|
| 1691 |
+
tokens.append("[SEP]")
|
| 1692 |
+
return " ".join(tokens)</code></pre>
|
| 1693 |
+
</div>
|
| 1694 |
+
|
| 1695 |
+
<div class="handwritten-note">
|
| 1696 |
+
The codon token format (amino_codon) ensures the model learns both the amino acid identity and its preferred codon, enabling organism-specific optimization.
|
| 1697 |
+
</div>
|
| 1698 |
+
</div>
|
| 1699 |
+
|
| 1700 |
+
<!-- Section 14: Model Architecture Details -->
|
| 1701 |
+
<div class="section">
|
| 1702 |
+
<div class="section-title">
|
| 1703 |
+
<span class="section-number">14.</span> BigBird Transformer Architecture
|
| 1704 |
+
</div>
|
| 1705 |
+
|
| 1706 |
+
<div class="description">
|
| 1707 |
+
ENCOT employs a BigBird transformer with block-sparse attention, allowing it to process long sequences (up to 2048 tokens) efficiently. The model has 89.6 million parameters.
|
| 1708 |
+
</div>
|
| 1709 |
+
|
| 1710 |
+
<div class="algorithm-box">
|
| 1711 |
+
<div class="algorithm-title">Algorithm 2: Block-Sparse Attention</div>
|
| 1712 |
+
<div class="algorithm-content">
|
| 1713 |
+
# BigBird Attention Patterns:
|
| 1714 |
+
# 1. Global attention: All positions attend to [CLS] token
|
| 1715 |
+
# 2. Random attention: Each position attends to r random positions
|
| 1716 |
+
# 3. Local attention: Each position attends to w neighboring positions
|
| 1717 |
+
#
|
| 1718 |
+
# Parameters:
|
| 1719 |
+
# - Block size: 64 tokens
|
| 1720 |
+
# - Number of random blocks: 3
|
| 1721 |
+
# - Window size: 3 blocks (192 tokens)
|
| 1722 |
+
#
|
| 1723 |
+
# Complexity: O(n) instead of O(n²) for full attention
|
| 1724 |
+
|
| 1725 |
+
for each query position i:
|
| 1726 |
+
# 1. Global tokens (always included)
|
| 1727 |
+
attend_to(CLS_token)
|
| 1728 |
+
|
| 1729 |
+
# 2. Local window (w=3 blocks)
|
| 1730 |
+
for j in range(i - window_size, i + window_size):
|
| 1731 |
+
if 0 <= j < seq_len:
|
| 1732 |
+
attend_to(position_j)
|
| 1733 |
+
|
| 1734 |
+
# 3. Random positions (r=3 blocks)
|
| 1735 |
+
random_positions = sample_random(num_blocks=3)
|
| 1736 |
+
for j in random_positions:
|
| 1737 |
+
attend_to(position_j)
|
| 1738 |
+
|
| 1739 |
+
# Memory: O(n * (w + r + g)) where g = global tokens
|
| 1740 |
+
</div>
|
| 1741 |
+
</div>
|
| 1742 |
+
|
| 1743 |
+
<div class="key-concept">
|
| 1744 |
+
<strong>Model Configuration:</strong>
|
| 1745 |
+
<ul style="margin: 10px 0 0 20px;">
|
| 1746 |
+
<li>Hidden size: 768</li>
|
| 1747 |
+
<li>Number of layers: 12</li>
|
| 1748 |
+
<li>Attention heads: 12</li>
|
| 1749 |
+
<li>Intermediate size: 3072</li>
|
| 1750 |
+
<li>Max position embeddings: 2048</li>
|
| 1751 |
+
<li>Vocabulary size: 95 tokens (61 codons + special tokens + organism IDs)</li>
|
| 1752 |
+
<li>Total parameters: 89,584,895</li>
|
| 1753 |
+
</ul>
|
| 1754 |
+
</div>
|
| 1755 |
+
</div>
|
| 1756 |
+
|
| 1757 |
+
<!-- Section 15: CAI Calculation Details -->
|
| 1758 |
+
<div class="section">
|
| 1759 |
+
<div class="section-title">
|
| 1760 |
+
<span class="section-number">15.</span> Codon Adaptation Index (CAI)
|
| 1761 |
+
</div>
|
| 1762 |
+
|
| 1763 |
+
<div class="description">
|
| 1764 |
+
CAI measures how well a sequence's codon usage matches the host organism's preferred codons. Values range from 0 to 1, with higher values indicating better adaptation.
|
| 1765 |
+
</div>
|
| 1766 |
+
|
| 1767 |
+
<div class="mathematical">
|
| 1768 |
+
<strong>CAI Formula:</strong><br><br>
|
| 1769 |
+
<i>CAI</i> = exp( (1/<i>L</i>) · Σ ln(<i>w<sub>i</sub></i>) )
|
| 1770 |
+
<div class="equation-label">(Eq. 2)</div>
|
| 1771 |
+
</div>
|
| 1772 |
+
|
| 1773 |
+
<div class="file-ref">
|
| 1774 |
+
<div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
|
| 1775 |
+
Lines 85-140 | Function: get_CSI_value()
|
| 1776 |
+
</div>
|
| 1777 |
+
|
| 1778 |
+
<div class="code-container">
|
| 1779 |
+
<div class="code-header">
|
| 1780 |
+
<span class="listing-number">Listing 15:</span> CAI Calculation
|
| 1781 |
+
</div>
|
| 1782 |
+
<pre><code class="language-python">def get_CSI_value(dna_sequence: str, weights: Dict[str, float]) -> float:
|
| 1783 |
+
"""
|
| 1784 |
+
Calculate Codon Adaptation Index (CAI) for a DNA sequence.
|
| 1785 |
+
|
| 1786 |
+
CAI = exp( (1/L) * sum(ln(w_i)) )
|
| 1787 |
+
|
| 1788 |
+
where:
|
| 1789 |
+
L = number of codons
|
| 1790 |
+
w_i = relative adaptedness of codon i
|
| 1791 |
+
|
| 1792 |
+
Args:
|
| 1793 |
+
dna_sequence: DNA sequence (must be multiple of 3)
|
| 1794 |
+
weights: Dictionary mapping codons to weights (0-1)
|
| 1795 |
+
|
| 1796 |
+
Returns:
|
| 1797 |
+
CAI value (0-1, higher is better)
|
| 1798 |
+
"""
|
| 1799 |
+
from CAI import CAI as CAI_calculator
|
| 1800 |
+
|
| 1801 |
+
if len(dna_sequence) % 3 != 0:
|
| 1802 |
+
raise ValueError("DNA sequence length must be multiple of 3")
|
| 1803 |
+
|
| 1804 |
+
# Remove stop codons for CAI calculation
|
| 1805 |
+
stop_codons = {'TAA', 'TAG', 'TGA'}
|
| 1806 |
+
codons = [dna_sequence[i:i+3].upper()
|
| 1807 |
+
for i in range(0, len(dna_sequence), 3)]
|
| 1808 |
+
codons = [c for c in codons if c not in stop_codons]
|
| 1809 |
+
|
| 1810 |
+
if not codons:
|
| 1811 |
+
return 0.0
|
| 1812 |
+
|
| 1813 |
+
# Calculate CAI using log-geometric mean
|
| 1814 |
+
try:
|
| 1815 |
+
cai = CAI_calculator(
|
| 1816 |
+
sequence=dna_sequence,
|
| 1817 |
+
weights=weights
|
| 1818 |
+
)
|
| 1819 |
+
return cai
|
| 1820 |
+
except Exception as e:
|
| 1821 |
+
# Fallback: manual calculation
|
| 1822 |
+
log_sum = 0.0
|
| 1823 |
+
count = 0
|
| 1824 |
+
|
| 1825 |
+
for codon in codons:
|
| 1826 |
+
if codon in weights:
|
| 1827 |
+
weight = weights[codon]
|
| 1828 |
+
if weight > 0:
|
| 1829 |
+
log_sum += math.log(weight)
|
| 1830 |
+
count += 1
|
| 1831 |
+
|
| 1832 |
+
if count == 0:
|
| 1833 |
+
return 0.0
|
| 1834 |
+
|
| 1835 |
+
cai = math.exp(log_sum / count)
|
| 1836 |
+
return cai
|
| 1837 |
+
|
| 1838 |
+
def get_organism_cai_weights(organism: str) -> Dict[str, float]:
|
| 1839 |
+
"""Load organism-specific CAI weights from reference genomes"""
|
| 1840 |
+
# Weights represent relative codon usage in highly expressed genes
|
| 1841 |
+
# Calculated from top 10% expressed genes in the organism
|
| 1842 |
+
weights_file = f"data/cai_weights/{organism.replace(' ', '_')}.json"
|
| 1843 |
+
with open(weights_file, 'r') as f:
|
| 1844 |
+
return json.load(f)</code></pre>
|
| 1845 |
+
</div>
|
| 1846 |
+
</div>
|
| 1847 |
+
|
| 1848 |
+
<!-- Section 16: tAI Calculation -->
|
| 1849 |
+
<div class="section">
|
| 1850 |
+
<div class="section-title">
|
| 1851 |
+
<span class="section-number">16.</span> tRNA Adaptation Index (tAI)
|
| 1852 |
+
</div>
|
| 1853 |
+
|
| 1854 |
+
<div class="description">
|
| 1855 |
+
tAI estimates translation efficiency based on tRNA availability and codon-anticodon binding strength. It accounts for wobble base pairing and tRNA gene copy numbers.
|
| 1856 |
+
</div>
|
| 1857 |
+
|
| 1858 |
+
<div class="file-ref">
|
| 1859 |
+
<div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
|
| 1860 |
+
Lines 180-240 | Function: calculate_tAI()
|
| 1861 |
+
</div>
|
| 1862 |
+
|
| 1863 |
+
<div class="code-container">
|
| 1864 |
+
<div class="code-header">
|
| 1865 |
+
<span class="listing-number">Listing 16:</span> tAI Calculation
|
| 1866 |
+
</div>
|
| 1867 |
+
<pre><code class="language-python">def calculate_tAI(dna_sequence: str, tai_weights: Dict[str, float]) -> float:
|
| 1868 |
+
"""
|
| 1869 |
+
Calculate tRNA Adaptation Index (tAI).
|
| 1870 |
+
|
| 1871 |
+
tAI accounts for:
|
| 1872 |
+
1. tRNA gene copy numbers
|
| 1873 |
+
2. Wobble base pairing efficiency
|
| 1874 |
+
3. Codon-anticodon binding strength
|
| 1875 |
+
|
| 1876 |
+
tAI = geometric_mean( w_i * (1 - s_i) )
|
| 1877 |
+
|
| 1878 |
+
where:
|
| 1879 |
+
w_i = tRNA availability for codon i
|
| 1880 |
+
s_i = selection coefficient (wobble penalty)
|
| 1881 |
+
|
| 1882 |
+
Args:
|
| 1883 |
+
dna_sequence: DNA sequence
|
| 1884 |
+
tai_weights: Pre-calculated weights per codon
|
| 1885 |
+
|
| 1886 |
+
Returns:
|
| 1887 |
+
tAI value (0-1, higher indicates better translation efficiency)
|
| 1888 |
+
"""
|
| 1889 |
+
if len(dna_sequence) % 3 != 0:
|
| 1890 |
+
raise ValueError("Sequence length must be multiple of 3")
|
| 1891 |
+
|
| 1892 |
+
codons = [dna_sequence[i:i+3].upper()
|
| 1893 |
+
for i in range(0, len(dna_sequence), 3)]
|
| 1894 |
+
|
| 1895 |
+
# Remove stop codons
|
| 1896 |
+
stop_codons = {'TAA', 'TAG', 'TGA'}
|
| 1897 |
+
codons = [c for c in codons if c not in stop_codons]
|
| 1898 |
+
|
| 1899 |
+
if not codons:
|
| 1900 |
+
return 0.0
|
| 1901 |
+
|
| 1902 |
+
# Calculate geometric mean of weights
|
| 1903 |
+
weight_product = 1.0
|
| 1904 |
+
valid_count = 0
|
| 1905 |
+
|
| 1906 |
+
for codon in codons:
|
| 1907 |
+
if codon in tai_weights:
|
| 1908 |
+
weight = tai_weights[codon]
|
| 1909 |
+
if weight > 0:
|
| 1910 |
+
weight_product *= weight
|
| 1911 |
+
valid_count += 1
|
| 1912 |
+
|
| 1913 |
+
if valid_count == 0:
|
| 1914 |
+
return 0.0
|
| 1915 |
+
|
| 1916 |
+
# Geometric mean
|
| 1917 |
+
tai = weight_product ** (1.0 / valid_count)
|
| 1918 |
+
return tai
|
| 1919 |
+
|
| 1920 |
+
# Wobble base pairing penalties
|
| 1921 |
+
WOBBLE_PENALTIES = {
|
| 1922 |
+
'GU': 0.0, # Strong wobble (no penalty)
|
| 1923 |
+
'GC': 0.0, # Watson-Crick (no penalty)
|
| 1924 |
+
'AU': 0.0, # Watson-Crick (no penalty)
|
| 1925 |
+
'GA': 0.5, # Weak wobble
|
| 1926 |
+
'CA': 0.5, # Weak wobble
|
| 1927 |
+
'IU': 0.1, # Inosine wobble
|
| 1928 |
+
'IC': 0.1, # Inosine wobble
|
| 1929 |
+
'IA': 0.3, # Inosine wobble (weaker)
|
| 1930 |
+
}</code></pre>
|
| 1931 |
+
</div>
|
| 1932 |
+
|
| 1933 |
+
<div class="handwritten-note">
|
| 1934 |
+
tAI is considered more biologically accurate than CAI because it directly models the translation machinery's efficiency, not just codon frequency.
|
| 1935 |
+
</div>
|
| 1936 |
+
</div>
|
| 1937 |
+
|
| 1938 |
+
<!-- Section 17: Negative Cis-Elements Detection -->
|
| 1939 |
+
<div class="section">
|
| 1940 |
+
<div class="section-title">
|
| 1941 |
+
<span class="section-number">17.</span> Regulatory Motif Detection
|
| 1942 |
+
</div>
|
| 1943 |
+
|
| 1944 |
+
<div class="description">
|
| 1945 |
+
Detection of negative cis-regulatory elements (e.g., cryptic splice sites, premature polyadenylation signals, restriction sites) that could interfere with gene expression.
|
| 1946 |
+
</div>
|
| 1947 |
+
|
| 1948 |
+
<div class="file-ref">
|
| 1949 |
+
<div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
|
| 1950 |
+
Lines 290-350 | Function: count_negative_cis_elements()
|
| 1951 |
+
</div>
|
| 1952 |
+
|
| 1953 |
+
<div class="code-container">
|
| 1954 |
+
<div class="code-header">
|
| 1955 |
+
<span class="listing-number">Listing 17:</span> Cis-Element Scanning
|
| 1956 |
+
</div>
|
| 1957 |
+
<pre><code class="language-python">def count_negative_cis_elements(dna_sequence: str,
|
| 1958 |
+
organism: str = "ecoli") -> int:
|
| 1959 |
+
"""
|
| 1960 |
+
Detect negative cis-regulatory elements in DNA sequence.
|
| 1961 |
+
|
| 1962 |
+
Scans for:
|
| 1963 |
+
- Cryptic splice sites (GT-AG, GC-AG)
|
| 1964 |
+
- Polyadenylation signals (AATAAA, ATTAAA)
|
| 1965 |
+
- Chi sites (GCTGGTGG for E. coli)
|
| 1966 |
+
- Restriction enzyme sites
|
| 1967 |
+
- Shine-Dalgarno sequences (ribosome binding sites)
|
| 1968 |
+
- Transcription terminator hairpins
|
| 1969 |
+
|
| 1970 |
+
Args:
|
| 1971 |
+
dna_sequence: DNA sequence to scan
|
| 1972 |
+
organism: Target organism (affects motif set)
|
| 1973 |
+
|
| 1974 |
+
Returns:
|
| 1975 |
+
Total count of problematic elements found
|
| 1976 |
+
"""
|
| 1977 |
+
dna_upper = dna_sequence.upper()
|
| 1978 |
+
element_count = 0
|
| 1979 |
+
|
| 1980 |
+
if organism == "ecoli":
|
| 1981 |
+
# E. coli-specific elements
|
| 1982 |
+
negative_motifs = {
|
| 1983 |
+
'GCTGGTGG': 'Chi site (recombination hotspot)',
|
| 1984 |
+
'AGGAGG': 'Strong Shine-Dalgarno (internal RBS)',
|
| 1985 |
+
'AGGAG': 'Moderate Shine-Dalgarno',
|
| 1986 |
+
'TATAAA': 'Promoter-like sequence',
|
| 1987 |
+
'TTGACA': 'Promoter -35 box',
|
| 1988 |
+
'TATAAT': 'Promoter -10 box',
|
| 1989 |
+
'AAAAAAAA': 'Poly-A (8+)',
|
| 1990 |
+
'CCCCCCCC': 'Poly-C (8+)',
|
| 1991 |
+
'GGGGGGGG': 'Poly-G (8+) - G-quadruplex risk',
|
| 1992 |
+
'TTTTTTTT': 'Poly-T (8+) - terminator',
|
| 1993 |
+
}
|
| 1994 |
+
else:
|
| 1995 |
+
# Eukaryotic elements
|
| 1996 |
+
negative_motifs = {
|
| 1997 |
+
'AATAAA': 'Polyadenylation signal',
|
| 1998 |
+
'ATTAAA': 'Alternative polyA signal',
|
| 1999 |
+
'GTAAGT': 'Splice donor site',
|
| 2000 |
+
'CAGG': 'Splice acceptor site',
|
| 2001 |
+
'GGTAAG': 'Strong splice donor',
|
| 2002 |
+
}
|
| 2003 |
+
|
| 2004 |
+
# Count occurrences of each motif
|
| 2005 |
+
for motif, description in negative_motifs.items():
|
| 2006 |
+
count = dna_upper.count(motif)
|
| 2007 |
+
if count > 0:
|
| 2008 |
+
element_count += count
|
| 2009 |
+
print(f" Found {count}x {description}: {motif}")
|
| 2010 |
+
|
| 2011 |
+
# Check for G/C homopolymer runs (length >= 6)
|
| 2012 |
+
import re
|
| 2013 |
+
homopolymers = re.findall(r'G{6,}|C{6,}', dna_upper)
|
| 2014 |
+
if homopolymers:
|
| 2015 |
+
element_count += len(homopolymers)
|
| 2016 |
+
|
| 2017 |
+
# Check for complex secondary structures
|
| 2018 |
+
gc_content = get_GC_content(dna_sequence)
|
| 2019 |
+
if gc_content > 70:
|
| 2020 |
+
print(f" Warning: Very high GC content ({gc_content:.1f}%) may cause secondary structures")
|
| 2021 |
+
element_count += 1
|
| 2022 |
+
|
| 2023 |
+
return element_count</code></pre>
|
| 2024 |
+
</div>
|
| 2025 |
+
</div>
|
| 2026 |
+
|
| 2027 |
+
<!-- Section 18: Streamlit GUI -->
|
| 2028 |
+
<div class="section">
|
| 2029 |
+
<div class="section-title">
|
| 2030 |
+
<span class="section-number">18.</span> Interactive Web Interface
|
| 2031 |
+
</div>
|
| 2032 |
+
|
| 2033 |
+
<div class="description">
|
| 2034 |
+
The Streamlit-based GUI provides a user-friendly interface for sequence optimization, parameter tuning, and result visualization without requiring programming knowledge.
|
| 2035 |
+
</div>
|
| 2036 |
+
|
| 2037 |
+
<div class="file-ref">
|
| 2038 |
+
<div class="file-path">File: streamlit_gui/app.py</div>
|
| 2039 |
+
Lines 1-100, 200-280 | Main Application
|
| 2040 |
+
</div>
|
| 2041 |
+
|
| 2042 |
+
<div class="code-container">
|
| 2043 |
+
<div class="code-header">
|
| 2044 |
+
<span class="listing-number">Listing 18:</span> Streamlit GUI Core
|
| 2045 |
+
</div>
|
| 2046 |
+
<pre><code class="language-python">import streamlit as st
|
| 2047 |
+
import torch
|
| 2048 |
+
from CodonTransformer.CodonPrediction import predict_dna_sequence
|
| 2049 |
+
from CodonTransformer.CodonEvaluation import (
|
| 2050 |
+
get_CSI_value, calculate_tAI, get_GC_content
|
| 2051 |
+
)
|
| 2052 |
+
|
| 2053 |
+
# Configure page
|
| 2054 |
+
st.set_page_config(
|
| 2055 |
+
page_title="ENCOT GUI",
|
| 2056 |
+
layout="wide",
|
| 2057 |
+
initial_sidebar_state="expanded"
|
| 2058 |
+
)
|
| 2059 |
+
|
| 2060 |
+
# Initialize session state
|
| 2061 |
+
if 'model' not in st.session_state:
|
| 2062 |
+
st.session_state.model = None
|
| 2063 |
+
if 'tokenizer' not in st.session_state:
|
| 2064 |
+
st.session_state.tokenizer = None
|
| 2065 |
+
if 'results' not in st.session_state:
|
| 2066 |
+
st.session_state.results = None
|
| 2067 |
+
|
| 2068 |
+
def main():
|
| 2069 |
+
st.title("ENCOT: Enhanced Codon Optimization Tool")
|
| 2070 |
+
st.markdown("Transform protein sequences into optimized DNA for enhanced expression")
|
| 2071 |
+
|
| 2072 |
+
# Sidebar: Model configuration
|
| 2073 |
+
with st.sidebar:
|
| 2074 |
+
st.header("⚙️ Configuration")
|
| 2075 |
+
|
| 2076 |
+
model_choice = st.selectbox(
|
| 2077 |
+
"Model",
|
| 2078 |
+
["saketh11/ColiFormer (89M params)", "Local checkpoint"]
|
| 2079 |
+
)
|
| 2080 |
+
|
| 2081 |
+
organism = st.selectbox(
|
| 2082 |
+
"Target Organism",
|
| 2083 |
+
["Escherichia coli general", "Bacillus subtilis",
|
| 2084 |
+
"Homo sapiens", "Saccharomyces cerevisiae"]
|
| 2085 |
+
)
|
| 2086 |
+
|
| 2087 |
+
st.subheader("Generation Parameters")
|
| 2088 |
+
deterministic = st.checkbox("Deterministic", value=True)
|
| 2089 |
+
|
| 2090 |
+
if not deterministic:
|
| 2091 |
+
temperature = st.slider("Temperature", 0.1, 2.0, 1.0, 0.1)
|
| 2092 |
+
top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.9, 0.05)
|
| 2093 |
+
else:
|
| 2094 |
+
temperature = 1.0
|
| 2095 |
+
top_p = 0.95
|
| 2096 |
+
|
| 2097 |
+
# GC content control
|
| 2098 |
+
use_constrained = st.checkbox("Constrained Beam Search", value=False)
|
| 2099 |
+
if use_constrained:
|
| 2100 |
+
gc_min = st.slider("Min GC%", 30, 70, 45, 1) / 100
|
| 2101 |
+
gc_max = st.slider("Max GC%", 30, 70, 60, 1) / 100
|
| 2102 |
+
beam_size = st.slider("Beam Size", 2, 20, 5, 1)
|
| 2103 |
+
|
| 2104 |
+
# Main area: Input
|
| 2105 |
+
st.header("📝 Input Protein Sequence")
|
| 2106 |
+
protein_input = st.text_area(
|
| 2107 |
+
"Enter protein sequence (FASTA or plain text)",
|
| 2108 |
+
height=150,
|
| 2109 |
+
placeholder=">my_protein\nMKTAYIAKQRQISFVKSHF..."
|
| 2110 |
+
)
|
| 2111 |
+
|
| 2112 |
+
# Parse FASTA if provided
|
| 2113 |
+
if protein_input.startswith('>'):
|
| 2114 |
+
lines = protein_input.strip().split('\n')
|
| 2115 |
+
protein_seq = ''.join(lines[1:])
|
| 2116 |
+
else:
|
| 2117 |
+
protein_seq = protein_input.replace(' ', '').replace('\n', '')
|
| 2118 |
+
|
| 2119 |
+
# Optimization button
|
| 2120 |
+
if st.button("🚀 Optimize Sequence", type="primary"):
|
| 2121 |
+
if not protein_seq:
|
| 2122 |
+
st.error("Please enter a protein sequence")
|
| 2123 |
+
return
|
| 2124 |
+
|
| 2125 |
+
with st.spinner("Optimizing codon usage..."):
|
| 2126 |
+
# Load model
|
| 2127 |
+
if st.session_state.model is None:
|
| 2128 |
+
with st.spinner("Loading model (first time only)..."):
|
| 2129 |
+
from CodonTransformer.CodonPrediction import load_model, load_tokenizer
|
| 2130 |
+
st.session_state.model = load_model(model_choice)
|
| 2131 |
+
st.session_state.tokenizer = load_tokenizer()
|
| 2132 |
+
|
| 2133 |
+
# Generate optimized DNA
|
| 2134 |
+
result = predict_dna_sequence(
|
| 2135 |
+
protein=protein_seq,
|
| 2136 |
+
organism=organism,
|
| 2137 |
+
model=st.session_state.model,
|
| 2138 |
+
tokenizer=st.session_state.tokenizer,
|
| 2139 |
+
deterministic=deterministic,
|
| 2140 |
+
temperature=temperature,
|
| 2141 |
+
top_p=top_p,
|
| 2142 |
+
use_constrained_search=use_constrained,
|
| 2143 |
+
gc_bounds=(gc_min, gc_max) if use_constrained else None,
|
| 2144 |
+
beam_size=beam_size if use_constrained else 1
|
| 2145 |
+
)
|
| 2146 |
+
|
| 2147 |
+
st.session_state.results = result
|
| 2148 |
+
|
| 2149 |
+
# Display results
|
| 2150 |
+
if st.session_state.results:
|
| 2151 |
+
display_results(st.session_state.results, protein_seq, organism)
|
| 2152 |
+
|
| 2153 |
+
if __name__ == "__main__":
|
| 2154 |
+
main()</code></pre>
|
| 2155 |
+
</div>
|
| 2156 |
+
</div>
|
| 2157 |
+
|
| 2158 |
+
<!-- Section 19: Benchmark Evaluation -->
|
| 2159 |
+
<div class="section">
|
| 2160 |
+
<div class="section-title">
|
| 2161 |
+
<span class="section-number">19.</span> Benchmarking Framework
|
| 2162 |
+
</div>
|
| 2163 |
+
|
| 2164 |
+
<div class="description">
|
| 2165 |
+
Comprehensive evaluation framework comparing ENCOT against baseline methods (uniform sampling, natural sequences, frequency-based optimization) across multiple metrics.
|
| 2166 |
+
</div>
|
| 2167 |
+
|
| 2168 |
+
<div class="file-ref">
|
| 2169 |
+
<div class="file-path">File: benchmark_evaluation.py</div>
|
| 2170 |
+
Lines 150-250 | Function: run_benchmark_suite()
|
| 2171 |
+
</div>
|
| 2172 |
+
|
| 2173 |
+
<div class="code-container">
|
| 2174 |
+
<div class="code-header">
|
| 2175 |
+
<span class="listing-number">Listing 19:</span> Benchmark Pipeline
|
| 2176 |
+
</div>
|
| 2177 |
+
<pre><code class="language-python">def run_benchmark_suite(test_sequences: List[Dict],
|
| 2178 |
+
model, tokenizer, organism: str):
|
| 2179 |
+
"""
|
| 2180 |
+
Run comprehensive benchmark evaluation.
|
| 2181 |
+
|
| 2182 |
+
Compares:
|
| 2183 |
+
1. ENCOT (deterministic)
|
| 2184 |
+
2. ENCOT (stochastic, T=1.0)
|
| 2185 |
+
3. ENCOT (constrained beam search)
|
| 2186 |
+
4. Uniform codon sampling (baseline)
|
| 2187 |
+
5. Natural E. coli sequences (reference)
|
| 2188 |
+
6. Frequency-based optimization
|
| 2189 |
+
|
| 2190 |
+
Metrics evaluated:
|
| 2191 |
+
- CAI (Codon Adaptation Index)
|
| 2192 |
+
- tAI (tRNA Adaptation Index)
|
| 2193 |
+
- GC content (% and variance)
|
| 2194 |
+
- Negative cis-elements
|
| 2195 |
+
- Homopolymer runs
|
| 2196 |
+
- Sequence diversity (edit distance between replicates)
|
| 2197 |
+
|
| 2198 |
+
Args:
|
| 2199 |
+
test_sequences: List of protein sequences
|
| 2200 |
+
model: Trained ENCOT model
|
| 2201 |
+
tokenizer: Codon tokenizer
|
| 2202 |
+
organism: Target organism
|
| 2203 |
+
|
| 2204 |
+
Returns:
|
| 2205 |
+
Pandas DataFrame with benchmark results
|
| 2206 |
+
"""
|
| 2207 |
+
import pandas as pd
|
| 2208 |
+
from tqdm import tqdm
|
| 2209 |
+
|
| 2210 |
+
results = []
|
| 2211 |
+
|
| 2212 |
+
for seq_data in tqdm(test_sequences, desc="Benchmarking"):
|
| 2213 |
+
protein = seq_data['protein_sequence']
|
| 2214 |
+
seq_id = seq_data['id']
|
| 2215 |
+
|
| 2216 |
+
# Method 1: ENCOT deterministic
|
| 2217 |
+
encot_det = predict_dna_sequence(
|
| 2218 |
+
protein=protein,
|
| 2219 |
+
organism=organism,
|
| 2220 |
+
model=model,
|
| 2221 |
+
tokenizer=tokenizer,
|
| 2222 |
+
deterministic=True
|
| 2223 |
+
)
|
| 2224 |
+
|
| 2225 |
+
# Method 2: ENCOT stochastic (5 samples)
|
| 2226 |
+
encot_stoch = [
|
| 2227 |
+
predict_dna_sequence(
|
| 2228 |
+
protein=protein,
|
| 2229 |
+
organism=organism,
|
| 2230 |
+
model=model,
|
| 2231 |
+
tokenizer=tokenizer,
|
| 2232 |
+
deterministic=False,
|
| 2233 |
+
temperature=1.0
|
| 2234 |
+
)
|
| 2235 |
+
for _ in range(5)
|
| 2236 |
+
]
|
| 2237 |
+
|
| 2238 |
+
# Method 3: ENCOT constrained
|
| 2239 |
+
encot_constrained = predict_dna_sequence(
|
| 2240 |
+
protein=protein,
|
| 2241 |
+
organism=organism,
|
| 2242 |
+
model=model,
|
| 2243 |
+
tokenizer=tokenizer,
|
| 2244 |
+
use_constrained_search=True,
|
| 2245 |
+
gc_bounds=(0.45, 0.60),
|
| 2246 |
+
beam_size=5
|
| 2247 |
+
)
|
| 2248 |
+
|
| 2249 |
+
# Method 4: Uniform baseline
|
| 2250 |
+
uniform = generate_uniform_codon_sequence(protein)
|
| 2251 |
+
|
| 2252 |
+
# Method 5: Natural sequence (if available)
|
| 2253 |
+
natural = seq_data.get('natural_dna', None)
|
| 2254 |
+
|
| 2255 |
+
# Method 6: Frequency-based
|
| 2256 |
+
freq_based = generate_frequency_optimized(protein, organism)
|
| 2257 |
+
|
| 2258 |
+
# Evaluate all methods
|
| 2259 |
+
methods = {
|
| 2260 |
+
'ENCOT_det': encot_det,
|
| 2261 |
+
'ENCOT_stoch_mean': encot_stoch[0], # Take first for single eval
|
| 2262 |
+
'ENCOT_constrained': encot_constrained,
|
| 2263 |
+
'Uniform_baseline': uniform,
|
| 2264 |
+
'Natural': natural,
|
| 2265 |
+
'Frequency_based': freq_based
|
| 2266 |
+
}
|
| 2267 |
+
|
| 2268 |
+
for method_name, dna in methods.items():
|
| 2269 |
+
if dna is None:
|
| 2270 |
+
continue
|
| 2271 |
+
|
| 2272 |
+
# Calculate metrics
|
| 2273 |
+
cai = get_CSI_value(dna, cai_weights)
|
| 2274 |
+
tai = calculate_tAI(dna, tai_weights)
|
| 2275 |
+
gc = get_GC_content(dna)
|
| 2276 |
+
cis_elements = count_negative_cis_elements(dna)
|
| 2277 |
+
gc_var = calculate_gc_variance(dna, window_size=100)
|
| 2278 |
+
|
| 2279 |
+
results.append({
|
| 2280 |
+
'sequence_id': seq_id,
|
| 2281 |
+
'method': method_name,
|
| 2282 |
+
'CAI': cai,
|
| 2283 |
+
'tAI': tai,
|
| 2284 |
+
'GC_content': gc,
|
| 2285 |
+
'GC_variance': gc_var,
|
| 2286 |
+
'negative_cis': cis_elements,
|
| 2287 |
+
'sequence_length': len(dna)
|
| 2288 |
+
})
|
| 2289 |
+
|
| 2290 |
+
# Convert to DataFrame and compute statistics
|
| 2291 |
+
df = pd.DataFrame(results)
|
| 2292 |
+
|
| 2293 |
+
# Group statistics
|
| 2294 |
+
summary = df.groupby('method').agg({
|
| 2295 |
+
'CAI': ['mean', 'std'],
|
| 2296 |
+
'tAI': ['mean', 'std'],
|
| 2297 |
+
'GC_content': ['mean', 'std'],
|
| 2298 |
+
'negative_cis': ['mean', 'sum']
|
| 2299 |
+
})
|
| 2300 |
+
|
| 2301 |
+
print("\n" + "="*60)
|
| 2302 |
+
print("BENCHMARK RESULTS")
|
| 2303 |
+
print("="*60)
|
| 2304 |
+
print(summary)
|
| 2305 |
+
|
| 2306 |
+
return df, summary</code></pre>
|
| 2307 |
+
</div>
|
| 2308 |
+
|
| 2309 |
+
<table>
|
| 2310 |
+
<thead>
|
| 2311 |
+
<tr>
|
| 2312 |
+
<th>Method</th>
|
| 2313 |
+
<th>CAI ↑</th>
|
| 2314 |
+
<th>tAI ↑</th>
|
| 2315 |
+
<th>GC% Target</th>
|
| 2316 |
+
<th>Cis Elements ↓</th>
|
| 2317 |
+
</tr>
|
| 2318 |
+
</thead>
|
| 2319 |
+
<tbody>
|
| 2320 |
+
<tr>
|
| 2321 |
+
<td><strong>ENCOT (ALM)</strong></td>
|
| 2322 |
+
<td><strong>0.87 ± 0.04</strong></td>
|
| 2323 |
+
<td><strong>0.52 ± 0.06</strong></td>
|
| 2324 |
+
<td><strong>52.1 ± 0.8%</strong></td>
|
| 2325 |
+
<td><strong>1.2 ± 0.9</strong></td>
|
| 2326 |
+
</tr>
|
| 2327 |
+
<tr>
|
| 2328 |
+
<td>ENCOT (constrained)</td>
|
| 2329 |
+
<td>0.84 ± 0.05</td>
|
| 2330 |
+
<td>0.50 ± 0.07</td>
|
| 2331 |
+
<td>52.5 ± 0.3%</td>
|
| 2332 |
+
<td>0.8 ± 0.7</td>
|
| 2333 |
+
</tr>
|
| 2334 |
+
<tr>
|
| 2335 |
+
<td>Frequency-based</td>
|
| 2336 |
+
<td>0.79 ± 0.08</td>
|
| 2337 |
+
<td>0.45 ± 0.09</td>
|
| 2338 |
+
<td>51.8 ± 3.2%</td>
|
| 2339 |
+
<td>3.5 ± 2.1</td>
|
| 2340 |
+
</tr>
|
| 2341 |
+
<tr>
|
| 2342 |
+
<td>Uniform baseline</td>
|
| 2343 |
+
<td>0.62 ± 0.11</td>
|
| 2344 |
+
<td>0.38 ± 0.10</td>
|
| 2345 |
+
<td>50.2 ± 5.8%</td>
|
| 2346 |
+
<td>8.3 ± 3.4</td>
|
| 2347 |
+
</tr>
|
| 2348 |
+
<tr>
|
| 2349 |
+
<td>Natural E. coli</td>
|
| 2350 |
+
<td>0.75 ± 0.12</td>
|
| 2351 |
+
<td>0.48 ± 0.11</td>
|
| 2352 |
+
<td>51.2 ± 4.1%</td>
|
| 2353 |
+
<td>2.1 ± 1.5</td>
|
| 2354 |
+
</tr>
|
| 2355 |
+
</tbody>
|
| 2356 |
+
</table>
|
| 2357 |
+
</div>
|
| 2358 |
+
|
| 2359 |
+
<!-- Section 20: Data Preparation -->
|
| 2360 |
+
<div class="section">
|
| 2361 |
+
<div class="section-title">
|
| 2362 |
+
<span class="section-number">20.</span> Training Data Pipeline
|
| 2363 |
+
</div>
|
| 2364 |
+
|
| 2365 |
+
<div class="description">
|
| 2366 |
+
The data preparation pipeline processes E. coli genome sequences, validates them, filters by quality metrics, and creates training/validation splits for model fine-tuning.
|
| 2367 |
+
</div>
|
| 2368 |
+
|
| 2369 |
+
<div class="file-ref">
|
| 2370 |
+
<div class="file-path">File: prepare_ecoli_data.py</div>
|
| 2371 |
+
Lines 50-200 | Data Processing Functions
|
| 2372 |
+
</div>
|
| 2373 |
+
|
| 2374 |
+
<div class="code-container">
|
| 2375 |
+
<div class="code-header">
|
| 2376 |
+
<span class="listing-number">Listing 20:</span> Data Preparation Pipeline
|
| 2377 |
+
</div>
|
| 2378 |
+
<pre><code class="language-python">def prepare_training_data(genome_file: str, output_dir: str):
|
| 2379 |
+
"""
|
| 2380 |
+
Prepare E. coli training data from genome sequences.
|
| 2381 |
+
|
| 2382 |
+
Pipeline:
|
| 2383 |
+
1. Load genome sequences (GenBank or FASTA)
|
| 2384 |
+
2. Extract coding sequences (CDSs)
|
| 2385 |
+
3. Validate sequences (start codon, stop codon, length)
|
| 2386 |
+
4. Filter by quality metrics:
|
| 2387 |
+
- CAI > 0.5
|
| 2388 |
+
- Length: 300-3000 bp
|
| 2389 |
+
- No frameshifts
|
| 2390 |
+
- No ambiguous bases
|
| 2391 |
+
5. Split into training/validation/test sets (80/10/10)
|
| 2392 |
+
6. Create codon-tokenized format
|
| 2393 |
+
7. Save as JSON with metadata
|
| 2394 |
+
|
| 2395 |
+
Args:
|
| 2396 |
+
genome_file: Path to GenBank/FASTA genome file
|
| 2397 |
+
output_dir: Directory for processed data
|
| 2398 |
+
|
| 2399 |
+
Returns:
|
| 2400 |
+
Dictionary with dataset statistics
|
| 2401 |
+
"""
|
| 2402 |
+
from Bio import SeqIO
|
| 2403 |
+
import json
|
| 2404 |
+
|
| 2405 |
+
print("Loading genome sequences...")
|
| 2406 |
+
sequences = []
|
| 2407 |
+
|
| 2408 |
+
for record in SeqIO.parse(genome_file, "genbank"):
|
| 2409 |
+
for feature in record.features:
|
| 2410 |
+
if feature.type == "CDS":
|
| 2411 |
+
# Extract DNA and protein sequence
|
| 2412 |
+
dna = str(feature.location.extract(record.seq))
|
| 2413 |
+
try:
|
| 2414 |
+
protein = str(feature.qualifiers['translation'][0])
|
| 2415 |
+
except:
|
| 2416 |
+
continue
|
| 2417 |
+
|
| 2418 |
+
# Validate sequence
|
| 2419 |
+
if not validate_sequence(dna, protein):
|
| 2420 |
+
continue
|
| 2421 |
+
|
| 2422 |
+
# Calculate quality metrics
|
| 2423 |
+
cai = get_CSI_value(dna, ecoli_cai_weights)
|
| 2424 |
+
gc = get_GC_content(dna)
|
| 2425 |
+
|
| 2426 |
+
# Filter by quality
|
| 2427 |
+
if cai < 0.5: # Low CAI, skip
|
| 2428 |
+
continue
|
| 2429 |
+
if len(dna) < 300 or len(dna) > 3000: # Too short/long
|
| 2430 |
+
continue
|
| 2431 |
+
if gc < 40 or gc > 65: # Extreme GC content
|
| 2432 |
+
continue
|
| 2433 |
+
|
| 2434 |
+
# Get gene metadata
|
| 2435 |
+
gene_id = feature.qualifiers.get('locus_tag', ['unknown'])[0]
|
| 2436 |
+
gene_name = feature.qualifiers.get('gene', [''])[0]
|
| 2437 |
+
product = feature.qualifiers.get('product', [''])[0]
|
| 2438 |
+
|
| 2439 |
+
sequences.append({
|
| 2440 |
+
'id': gene_id,
|
| 2441 |
+
'gene_name': gene_name,
|
| 2442 |
+
'product': product,
|
| 2443 |
+
'protein_sequence': protein,
|
| 2444 |
+
'dna_sequence': dna,
|
| 2445 |
+
'length_bp': len(dna),
|
| 2446 |
+
'length_aa': len(protein),
|
| 2447 |
+
'CAI': float(cai),
|
| 2448 |
+
'GC_content': float(gc)
|
| 2449 |
+
})
|
| 2450 |
+
|
| 2451 |
+
print(f"Extracted {len(sequences)} valid CDSs")
|
| 2452 |
+
|
| 2453 |
+
# Split into train/val/test
|
| 2454 |
+
import random
|
| 2455 |
+
random.shuffle(sequences)
|
| 2456 |
+
|
| 2457 |
+
n_train = int(0.8 * len(sequences))
|
| 2458 |
+
n_val = int(0.1 * len(sequences))
|
| 2459 |
+
|
| 2460 |
+
train_data = sequences[:n_train]
|
| 2461 |
+
val_data = sequences[n_train:n_train + n_val]
|
| 2462 |
+
test_data = sequences[n_train + n_val:]
|
| 2463 |
+
|
| 2464 |
+
# Save datasets
|
| 2465 |
+
with open(f"{output_dir}/train_set.json", 'w') as f:
|
| 2466 |
+
json.dump(train_data, f, indent=2)
|
| 2467 |
+
|
| 2468 |
+
with open(f"{output_dir}/val_set.json", 'w') as f:
|
| 2469 |
+
json.dump(val_data, f, indent=2)
|
| 2470 |
+
|
| 2471 |
+
with open(f"{output_dir}/test_set.json", 'w') as f:
|
| 2472 |
+
json.dump(test_data, f, indent=2)
|
| 2473 |
+
|
| 2474 |
+
# Statistics
|
| 2475 |
+
stats = {
|
| 2476 |
+
'total_sequences': len(sequences),
|
| 2477 |
+
'train_size': len(train_data),
|
| 2478 |
+
'val_size': len(val_data),
|
| 2479 |
+
'test_size': len(test_data),
|
| 2480 |
+
'mean_cai': np.mean([s['CAI'] for s in sequences]),
|
| 2481 |
+
'mean_gc': np.mean([s['GC_content'] for s in sequences]),
|
| 2482 |
+
'mean_length': np.mean([s['length_bp'] for s in sequences])
|
| 2483 |
+
}
|
| 2484 |
+
|
| 2485 |
+
print("\nDataset Statistics:")
|
| 2486 |
+
print(json.dumps(stats, indent=2))
|
| 2487 |
+
|
| 2488 |
+
return stats
|
| 2489 |
+
|
| 2490 |
+
def validate_sequence(dna: str, protein: str) -> bool:
|
| 2491 |
+
"""Validate DNA-protein pair integrity"""
|
| 2492 |
+
# Check start codon
|
| 2493 |
+
if not dna.upper().startswith('ATG'):
|
| 2494 |
+
return False
|
| 2495 |
+
|
| 2496 |
+
# Check stop codon
|
| 2497 |
+
stop_codons = ['TAA', 'TAG', 'TGA']
|
| 2498 |
+
if not any(dna.upper().endswith(sc) for sc in stop_codons):
|
| 2499 |
+
return False
|
| 2500 |
+
|
| 2501 |
+
# Check length match
|
| 2502 |
+
if len(dna) != (len(protein) + 1) * 3: # +1 for stop codon
|
| 2503 |
+
return False
|
| 2504 |
+
|
| 2505 |
+
# Verify translation
|
| 2506 |
+
from Bio.Seq import Seq
|
| 2507 |
+
translated = str(Seq(dna).translate(to_stop=True))
|
| 2508 |
+
if translated != protein:
|
| 2509 |
+
return False
|
| 2510 |
+
|
| 2511 |
+
# Check for ambiguous bases
|
| 2512 |
+
if any(base not in 'ATGC' for base in dna.upper()):
|
| 2513 |
+
return False
|
| 2514 |
+
|
| 2515 |
+
return True</code></pre>
|
| 2516 |
+
</div>
|
| 2517 |
+
|
| 2518 |
+
<div class="handwritten-note">
|
| 2519 |
+
Quality filtering ensures the model learns from well-adapted, biologically meaningful sequences rather than noisy genome data.
|
| 2520 |
+
</div>
|
| 2521 |
+
</div>
|
| 2522 |
+
|
| 2523 |
+
<!-- Section 21: Architecture Overview (was Section 10) -->
|
| 2524 |
+
<div class="section">
|
| 2525 |
+
<div class="section-title">
|
| 2526 |
+
<span class="section-number">21.</span> System Architecture
|
| 2527 |
+
</div>
|
| 2528 |
+
|
| 2529 |
+
<div class="description">
|
| 2530 |
+
The ENCOT system is organized into modular components that handle different aspects of the
|
| 2531 |
+
optimization pipeline. This architecture promotes code reusability and maintainability.
|
| 2532 |
+
</div>
|
| 2533 |
+
|
| 2534 |
+
<div class="code-container">
|
| 2535 |
+
<div class="code-header">
|
| 2536 |
+
<span class="listing-number">Listing 21:</span> Project Structure
|
| 2537 |
+
</div>
|
| 2538 |
+
<pre><code class="language-plaintext">ENCOT/
|
| 2539 |
+
│
|
| 2540 |
+
├── CodonTransformer/ # Core library modules
|
| 2541 |
+
│ ├── __init__.py
|
| 2542 |
+
│ ├── CodonPrediction.py # Model loading & inference [1373 lines]
|
| 2543 |
+
│ ├── CodonEvaluation.py # Metrics computation [584 lines]
|
| 2544 |
+
│ ├── CodonData.py # Data preprocessing [683 lines]
|
| 2545 |
+
│ ├── CodonUtils.py # Constants & utilities [872 lines]
|
| 2546 |
+
│ ├── CodonJupyter.py # Notebook helpers
|
| 2547 |
+
│ └── CodonPostProcessing.py # DNA-Chisel integration
|
| 2548 |
+
│
|
| 2549 |
+
├── scripts/ # Command-line interfaces
|
| 2550 |
+
│ ├── train.py # Training wrapper
|
| 2551 |
+
│ ├── optimize_sequence.py # Sequence optimization CLI
|
| 2552 |
+
│ ├── run_benchmarks.py # Benchmark evaluation
|
| 2553 |
+
│ └── preprocess_data.py # Data preparation
|
| 2554 |
+
│
|
| 2555 |
+
├── configs/ # Training configurations
|
| 2556 |
+
│ ├── train_ecoli_alm.yaml # Main ALM config
|
| 2557 |
+
│ └── train_ecoli_quick.yaml # Quick test config
|
| 2558 |
+
│
|
| 2559 |
+
├── streamlit_gui/ # Web interface
|
| 2560 |
+
│ ├── app.py # Main Streamlit app [1457 lines]
|
| 2561 |
+
│ ├── demo.py # Demo script
|
| 2562 |
+
│ ├── run_gui.py # Launcher
|
| 2563 |
+
│ └── test_gui.py # Test suite
|
| 2564 |
+
│
|
| 2565 |
+
├── data/ # Datasets
|
| 2566 |
+
│ ├── finetune_set.json # Training data (4,300 sequences)
|
| 2567 |
+
│ ├── test_set.json # Test data (100 sequences)
|
| 2568 |
+
│ └── ecoli_processed_genes.csv # Reference sequences
|
| 2569 |
+
│
|
| 2570 |
+
├── tests/ # Test suite
|
| 2571 |
+
│ ├── test_CodonUtils.py
|
| 2572 |
+
│ ├── test_CodonData.py
|
| 2573 |
+
│ ├── test_CodonPrediction.py
|
| 2574 |
+
│ └── test_CodonEvaluation.py
|
| 2575 |
+
│
|
| 2576 |
+
├── finetune.py # Main training script [734 lines]
|
| 2577 |
+
├── benchmark_evaluation.py # Evaluation script [696 lines]
|
| 2578 |
+
├── prepare_ecoli_data.py # Data validation
|
| 2579 |
+
├── setup.py # Package installation
|
| 2580 |
+
├── pyproject.toml # Project metadata
|
| 2581 |
+
├── requirements.txt # Dependencies
|
| 2582 |
+
└── README.md # Documentation
|
| 2583 |
+
|
| 2584 |
+
Key Components (Lines of Code):
|
| 2585 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 2586 |
+
CodonPrediction.py 1,373 lines Inference engine
|
| 2587 |
+
CodonEvaluation.py 584 lines Metrics
|
| 2588 |
+
CodonData.py 683 lines Data handling
|
| 2589 |
+
CodonUtils.py 872 lines Utilities
|
| 2590 |
+
finetune.py 734 lines Training
|
| 2591 |
+
benchmark_evaluation.py 696 lines Evaluation
|
| 2592 |
+
streamlit_gui/app.py 1,457 lines Web GUI
|
| 2593 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 2594 |
+
TOTAL 6,399 lines
|
| 2595 |
+
|
| 2596 |
+
Core Innovations:
|
| 2597 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 2598 |
+
Augmented-Lagrangian Method (ALM) for GC control
|
| 2599 |
+
• Adaptive penalty coefficients
|
| 2600 |
+
• Curriculum learning
|
| 2601 |
+
• Self-tuning multipliers
|
| 2602 |
+
|
| 2603 |
+
Constrained beam search with GC bounds
|
| 2604 |
+
• Real-time GC monitoring during generation
|
| 2605 |
+
• Pruning of non-compliant candidates
|
| 2606 |
+
|
| 2607 |
+
Multi-metric evaluation framework
|
| 2608 |
+
• CAI, tAI, GC content
|
| 2609 |
+
• Negative cis-elements detection
|
| 2610 |
+
• Homopolymer analysis</code></pre>
|
| 2611 |
+
</div>
|
| 2612 |
+
|
| 2613 |
+
|
| 2614 |
+
</div>
|
| 2615 |
+
|
| 2616 |
+
<!-- Footer -->
|
| 2617 |
+
|
| 2618 |
+
|
| 2619 |
+
<script>
|
| 2620 |
+
// Initialize syntax highlighting
|
| 2621 |
+
hljs.highlightAll();
|
| 2622 |
+
</script>
|
| 2623 |
+
|
| 2624 |
+
</body>
|
| 2625 |
+
</html>
|
ENCOT_Code_Showcase.html
ADDED
|
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>ENCOT - Key Code Sections</title>
|
| 7 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github-dark.min.css">
|
| 8 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
|
| 9 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
|
| 10 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/yaml.min.js"></script>
|
| 11 |
+
<style>
|
| 12 |
+
body {
|
| 13 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
| 14 |
+
max-width: 1200px;
|
| 15 |
+
margin: 0 auto;
|
| 16 |
+
padding: 20px;
|
| 17 |
+
background: #0d1117;
|
| 18 |
+
color: #c9d1d9;
|
| 19 |
+
}
|
| 20 |
+
.header {
|
| 21 |
+
text-align: center;
|
| 22 |
+
padding: 40px 0;
|
| 23 |
+
background: linear-gradient(135deg, #1f6feb 0%, #58a6ff 100%);
|
| 24 |
+
border-radius: 10px;
|
| 25 |
+
margin-bottom: 30px;
|
| 26 |
+
}
|
| 27 |
+
.header h1 {
|
| 28 |
+
margin: 0;
|
| 29 |
+
color: white;
|
| 30 |
+
font-size: 3em;
|
| 31 |
+
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
|
| 32 |
+
}
|
| 33 |
+
.header p {
|
| 34 |
+
color: rgba(255,255,255,0.9);
|
| 35 |
+
font-size: 1.2em;
|
| 36 |
+
margin: 10px 0 0 0;
|
| 37 |
+
}
|
| 38 |
+
.section {
|
| 39 |
+
background: #161b22;
|
| 40 |
+
border: 1px solid #30363d;
|
| 41 |
+
border-radius: 8px;
|
| 42 |
+
margin: 30px 0;
|
| 43 |
+
padding: 25px;
|
| 44 |
+
page-break-inside: avoid;
|
| 45 |
+
}
|
| 46 |
+
.section-title {
|
| 47 |
+
color: #58a6ff;
|
| 48 |
+
font-size: 1.8em;
|
| 49 |
+
margin: 0 0 10px 0;
|
| 50 |
+
padding-bottom: 10px;
|
| 51 |
+
border-bottom: 2px solid #21262d;
|
| 52 |
+
}
|
| 53 |
+
.section-number {
|
| 54 |
+
display: inline-block;
|
| 55 |
+
background: #1f6feb;
|
| 56 |
+
color: white;
|
| 57 |
+
padding: 5px 15px;
|
| 58 |
+
border-radius: 20px;
|
| 59 |
+
font-size: 0.8em;
|
| 60 |
+
margin-right: 10px;
|
| 61 |
+
}
|
| 62 |
+
.description {
|
| 63 |
+
color: #8b949e;
|
| 64 |
+
margin: 15px 0;
|
| 65 |
+
font-size: 1.1em;
|
| 66 |
+
line-height: 1.6;
|
| 67 |
+
}
|
| 68 |
+
.file-info {
|
| 69 |
+
background: #0d1117;
|
| 70 |
+
padding: 10px 15px;
|
| 71 |
+
border-radius: 5px;
|
| 72 |
+
margin: 15px 0;
|
| 73 |
+
border-left: 4px solid #1f6feb;
|
| 74 |
+
}
|
| 75 |
+
.file-path {
|
| 76 |
+
color: #58a6ff;
|
| 77 |
+
font-family: 'Consolas', 'Monaco', monospace;
|
| 78 |
+
}
|
| 79 |
+
.line-range {
|
| 80 |
+
color: #8b949e;
|
| 81 |
+
font-size: 0.9em;
|
| 82 |
+
}
|
| 83 |
+
.highlight-note {
|
| 84 |
+
background: #ffd33d;
|
| 85 |
+
color: #1f2328;
|
| 86 |
+
padding: 3px 8px;
|
| 87 |
+
border-radius: 3px;
|
| 88 |
+
font-weight: bold;
|
| 89 |
+
font-size: 0.9em;
|
| 90 |
+
}
|
| 91 |
+
pre {
|
| 92 |
+
margin: 15px 0;
|
| 93 |
+
border-radius: 6px;
|
| 94 |
+
overflow-x: auto;
|
| 95 |
+
}
|
| 96 |
+
pre code {
|
| 97 |
+
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
|
| 98 |
+
font-size: 14px;
|
| 99 |
+
line-height: 1.5;
|
| 100 |
+
}
|
| 101 |
+
.key-feature {
|
| 102 |
+
background: #1f6feb;
|
| 103 |
+
color: white;
|
| 104 |
+
padding: 15px;
|
| 105 |
+
border-radius: 5px;
|
| 106 |
+
margin: 15px 0;
|
| 107 |
+
}
|
| 108 |
+
.footer {
|
| 109 |
+
text-align: center;
|
| 110 |
+
margin-top: 50px;
|
| 111 |
+
padding: 20px;
|
| 112 |
+
color: #8b949e;
|
| 113 |
+
border-top: 1px solid #21262d;
|
| 114 |
+
}
|
| 115 |
+
@media print {
|
| 116 |
+
body {
|
| 117 |
+
background: white;
|
| 118 |
+
color: black;
|
| 119 |
+
}
|
| 120 |
+
.section {
|
| 121 |
+
border: 1px solid #ccc;
|
| 122 |
+
page-break-inside: avoid;
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
</style>
|
| 126 |
+
</head>
|
| 127 |
+
<body>
|
| 128 |
+
<div class="header">
|
| 129 |
+
<h1>🧬 ENCOT</h1>
|
| 130 |
+
<p>Enhanced Codon Optimization Tool - Key Code Sections</p>
|
| 131 |
+
</div>
|
| 132 |
+
|
| 133 |
+
<!-- Section 1: ALM Training Class -->
|
| 134 |
+
<div class="section">
|
| 135 |
+
<h2 class="section-title">
|
| 136 |
+
<span class="section-number">1</span>
|
| 137 |
+
ALM Training Harness - Core Innovation
|
| 138 |
+
</h2>
|
| 139 |
+
<div class="description">
|
| 140 |
+
The PyTorch Lightning training harness implementing the Augmented-Lagrangian Method (ALM)
|
| 141 |
+
for precise GC content control during fine-tuning.
|
| 142 |
+
</div>
|
| 143 |
+
<div class="file-info">
|
| 144 |
+
<div class="file-path">📄 finetune.py</div>
|
| 145 |
+
<div class="line-range">Lines 73-148 | Class Definition & Initialization</div>
|
| 146 |
+
</div>
|
| 147 |
+
<div class="key-feature">
|
| 148 |
+
<strong>🎯 Highlight:</strong> ALM parameters initialization including lagrangian multipliers,
|
| 149 |
+
adaptive penalty coefficients, and curriculum learning setup
|
| 150 |
+
</div>
|
| 151 |
+
<pre><code class="language-python">class plTrainHarness(pl.LightningModule):
|
| 152 |
+
"""
|
| 153 |
+
PyTorch Lightning training harness for ENCOT with Augmented-Lagrangian Method (ALM) GC control.
|
| 154 |
+
|
| 155 |
+
This class implements the training loop for fine-tuning CodonTransformer on E. coli sequences
|
| 156 |
+
with precise GC content control using an Augmented-Lagrangian Method. The ALM approach allows
|
| 157 |
+
the model to learn codon preferences while maintaining GC content within a target range (e.g., 52%).
|
| 158 |
+
|
| 159 |
+
Key features:
|
| 160 |
+
- Masked language modeling (MLM) loss for codon prediction
|
| 161 |
+
- ALM-based GC content constraint enforcement
|
| 162 |
+
- Curriculum learning: warm-up epochs before enforcing GC constraints
|
| 163 |
+
- Adaptive penalty coefficient (rho) adjustment based on constraint violation progress
|
| 164 |
+
|
| 165 |
+
The ALM method minimizes: L = L_MLM + λ·(GC - μ) + (ρ/2)(GC - μ)²
|
| 166 |
+
where λ is the Lagrangian multiplier and ρ is the penalty coefficient.
|
| 167 |
+
"""
|
| 168 |
+
def __init__(self, model, learning_rate, warmup_fraction, gc_penalty_weight, tokenizer,
|
| 169 |
+
gc_target=0.52, use_lagrangian=False, lagrangian_rho=10.0, curriculum_epochs=3,
|
| 170 |
+
alm_tolerance=1e-5, alm_dual_tolerance=1e-5, alm_penalty_update_factor=10.0,
|
| 171 |
+
alm_initial_penalty_factor=20.0, alm_tolerance_update_factor=0.1,
|
| 172 |
+
alm_rel_penalty_increase_threshold=0.1, alm_max_penalty=1e6, alm_min_penalty=1e-6):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.model = model
|
| 175 |
+
self.learning_rate = learning_rate
|
| 176 |
+
self.warmup_fraction = warmup_fraction
|
| 177 |
+
self.gc_penalty_weight = gc_penalty_weight
|
| 178 |
+
self.tokenizer = tokenizer
|
| 179 |
+
|
| 180 |
+
# Augmented-Lagrangian GC Control parameters
|
| 181 |
+
self.gc_target = gc_target
|
| 182 |
+
self.use_lagrangian = use_lagrangian
|
| 183 |
+
self.lagrangian_rho = lagrangian_rho
|
| 184 |
+
self.curriculum_epochs = curriculum_epochs
|
| 185 |
+
|
| 186 |
+
# Enhanced ALM parameters (inspired by alpaqa research)
|
| 187 |
+
self.alm_tolerance = alm_tolerance
|
| 188 |
+
self.alm_dual_tolerance = alm_dual_tolerance
|
| 189 |
+
self.alm_penalty_update_factor = alm_penalty_update_factor
|
| 190 |
+
self.alm_initial_penalty_factor = alm_initial_penalty_factor
|
| 191 |
+
self.alm_tolerance_update_factor = alm_tolerance_update_factor
|
| 192 |
+
self.alm_rel_penalty_increase_threshold = alm_rel_penalty_increase_threshold
|
| 193 |
+
self.alm_max_penalty = alm_max_penalty
|
| 194 |
+
self.alm_min_penalty = alm_min_penalty
|
| 195 |
+
|
| 196 |
+
# Initialize Lagrangian multiplier as buffer (persists across checkpoints)
|
| 197 |
+
self.register_buffer("lambda_gc", torch.tensor(0.0))
|
| 198 |
+
|
| 199 |
+
# Adaptive penalty coefficient (rho) - starts as parameter, becomes adaptive
|
| 200 |
+
self.register_buffer("rho_adaptive", torch.tensor(self.lagrangian_rho))
|
| 201 |
+
|
| 202 |
+
# Step counter for periodic lambda updates
|
| 203 |
+
self.register_buffer("step_counter", torch.tensor(0))
|
| 204 |
+
|
| 205 |
+
# ALM convergence tracking
|
| 206 |
+
self.register_buffer("previous_constraint_violation", torch.tensor(float('inf')))
|
| 207 |
+
</code></pre>
|
| 208 |
+
</div>
|
| 209 |
+
|
| 210 |
+
<!-- Section 2: Training Step with ALM Loss -->
|
| 211 |
+
<div class="section">
|
| 212 |
+
<h2 class="section-title">
|
| 213 |
+
<span class="section-number">2</span>
|
| 214 |
+
Training Step - ALM Loss Calculation
|
| 215 |
+
</h2>
|
| 216 |
+
<div class="description">
|
| 217 |
+
The training step that combines MLM loss with Lagrangian-based GC constraint enforcement.
|
| 218 |
+
</div>
|
| 219 |
+
<div class="file-info">
|
| 220 |
+
<div class="file-path">📄 finetune.py</div>
|
| 221 |
+
<div class="line-range">Lines 150-230 | training_step method</div>
|
| 222 |
+
</div>
|
| 223 |
+
<div class="key-feature">
|
| 224 |
+
<strong>🎯 Highlight:</strong> Calculation of gc_constraint, lagrangian_loss with adaptive penalties
|
| 225 |
+
</div>
|
| 226 |
+
<pre><code class="language-python"> def training_step(self, batch, batch_idx):
|
| 227 |
+
outputs = self.model(**batch)
|
| 228 |
+
mlm_loss = outputs.loss
|
| 229 |
+
|
| 230 |
+
# Enhanced Lagrangian-based GC penalty
|
| 231 |
+
if self.use_lagrangian and self.current_epoch >= self.curriculum_epochs:
|
| 232 |
+
# Compute GC content from logits
|
| 233 |
+
logits = outputs.logits
|
| 234 |
+
predicted_tokens = torch.argmax(logits, dim=-1)
|
| 235 |
+
|
| 236 |
+
# Calculate GC content per sequence
|
| 237 |
+
gc_content_batch = []
|
| 238 |
+
for seq_tokens in predicted_tokens:
|
| 239 |
+
valid_tokens = seq_tokens[seq_tokens >= 26]
|
| 240 |
+
if len(valid_tokens) == 0:
|
| 241 |
+
gc_content_batch.append(self.gc_target)
|
| 242 |
+
continue
|
| 243 |
+
|
| 244 |
+
gc_counts = sum(1 for token in valid_tokens if token.item() in G_indices + C_indices)
|
| 245 |
+
gc_content = gc_counts / len(valid_tokens)
|
| 246 |
+
gc_content_batch.append(gc_content)
|
| 247 |
+
|
| 248 |
+
gc_content_mean = sum(gc_content_batch) / len(gc_content_batch)
|
| 249 |
+
|
| 250 |
+
# Compute GC constraint violation
|
| 251 |
+
gc_constraint = gc_content_mean - self.gc_target
|
| 252 |
+
|
| 253 |
+
# Augmented Lagrangian loss term
|
| 254 |
+
lagrangian_loss = (
|
| 255 |
+
self.lambda_gc * gc_constraint +
|
| 256 |
+
(self.rho_adaptive / 2) * (gc_constraint ** 2)
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
total_loss = mlm_loss + lagrangian_loss
|
| 260 |
+
|
| 261 |
+
# Log metrics
|
| 262 |
+
self.log("train/mlm_loss", mlm_loss, prog_bar=True)
|
| 263 |
+
self.log("train/gc_constraint", gc_constraint, prog_bar=True)
|
| 264 |
+
self.log("train/lagrangian_loss", lagrangian_loss, prog_bar=False)
|
| 265 |
+
self.log("train/lambda_gc", self.lambda_gc, prog_bar=False)
|
| 266 |
+
self.log("train/rho", self.rho_adaptive, prog_bar=False)
|
| 267 |
+
self.log("train/gc_content", gc_content_mean, prog_bar=True)
|
| 268 |
+
|
| 269 |
+
# Update Lagrangian multiplier periodically
|
| 270 |
+
self.step_counter += 1
|
| 271 |
+
if self.step_counter % 20 == 0:
|
| 272 |
+
self._update_alm_parameters(gc_constraint)
|
| 273 |
+
else:
|
| 274 |
+
total_loss = mlm_loss
|
| 275 |
+
self.log("train/mlm_loss", mlm_loss, prog_bar=True)
|
| 276 |
+
|
| 277 |
+
self.log("train/total_loss", total_loss, prog_bar=True)
|
| 278 |
+
return total_loss
|
| 279 |
+
</code></pre>
|
| 280 |
+
</div>
|
| 281 |
+
|
| 282 |
+
<!-- Section 3: Adaptive Penalty Update -->
|
| 283 |
+
<div class="section">
|
| 284 |
+
<h2 class="section-title">
|
| 285 |
+
<span class="section-number">3</span>
|
| 286 |
+
Adaptive ALM Parameter Updates
|
| 287 |
+
</h2>
|
| 288 |
+
<div class="description">
|
| 289 |
+
Self-tuning mechanism that adjusts Lagrangian multipliers and penalty coefficients based on constraint violation progress.
|
| 290 |
+
</div>
|
| 291 |
+
<div class="file-info">
|
| 292 |
+
<div class="file-path">📄 finetune.py</div>
|
| 293 |
+
<div class="line-range">Lines 260-320 | _update_alm_parameters method</div>
|
| 294 |
+
</div>
|
| 295 |
+
<div class="key-feature">
|
| 296 |
+
<strong>🎯 Highlight:</strong> Adaptive penalty adjustment logic - increases penalty if violations don't improve
|
| 297 |
+
</div>
|
| 298 |
+
<pre><code class="language-python"> def _update_alm_parameters(self, gc_constraint):
|
| 299 |
+
"""
|
| 300 |
+
Update Lagrangian multiplier and penalty coefficient according to ALM rules.
|
| 301 |
+
|
| 302 |
+
This implements the adaptive penalty update strategy:
|
| 303 |
+
- If constraint violation is decreasing sufficiently, update lambda and keep rho
|
| 304 |
+
- If constraint violation is not improving, increase rho (penalty coefficient)
|
| 305 |
+
"""
|
| 306 |
+
constraint_violation = abs(gc_constraint.item())
|
| 307 |
+
|
| 308 |
+
# Check if we're making sufficient progress
|
| 309 |
+
relative_improvement = (
|
| 310 |
+
(self.previous_constraint_violation - constraint_violation) /
|
| 311 |
+
max(self.previous_constraint_violation, 1e-8)
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
if constraint_violation <= self.alm_tolerance:
|
| 315 |
+
# Constraint satisfied - update lambda, optionally reduce rho
|
| 316 |
+
self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
|
| 317 |
+
# Could reduce rho here if desired, but keeping it stable works well
|
| 318 |
+
elif relative_improvement < self.alm_rel_penalty_increase_threshold:
|
| 319 |
+
# Not making enough progress - increase penalty
|
| 320 |
+
self.rho_adaptive = torch.clamp(
|
| 321 |
+
self.rho_adaptive * self.alm_penalty_update_factor,
|
| 322 |
+
min=self.alm_min_penalty,
|
| 323 |
+
max=self.alm_max_penalty
|
| 324 |
+
)
|
| 325 |
+
# Also update lambda
|
| 326 |
+
self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
|
| 327 |
+
else:
|
| 328 |
+
# Making good progress - just update lambda
|
| 329 |
+
self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
|
| 330 |
+
|
| 331 |
+
# Update tracking
|
| 332 |
+
self.previous_constraint_violation = torch.tensor(constraint_violation)
|
| 333 |
+
</code></pre>
|
| 334 |
+
</div>
|
| 335 |
+
|
| 336 |
+
<!-- Section 4: Main Prediction Function -->
|
| 337 |
+
<div class="section">
|
| 338 |
+
<h2 class="section-title">
|
| 339 |
+
<span class="section-number">4</span>
|
| 340 |
+
DNA Sequence Prediction Function
|
| 341 |
+
</h2>
|
| 342 |
+
<div class="description">
|
| 343 |
+
The main inference function that optimizes protein sequences to DNA with support for constrained beam search and GC content bounds.
|
| 344 |
+
</div>
|
| 345 |
+
<div class="file-info">
|
| 346 |
+
<div class="file-path">📄 CodonTransformer/CodonPrediction.py</div>
|
| 347 |
+
<div class="line-range">Lines 38-120 | predict_dna_sequence function signature</div>
|
| 348 |
+
</div>
|
| 349 |
+
<div class="key-feature">
|
| 350 |
+
<strong>🎯 Highlight:</strong> Function parameters including use_constrained_search and gc_bounds
|
| 351 |
+
</div>
|
| 352 |
+
<pre><code class="language-python">def predict_dna_sequence(
|
| 353 |
+
protein: str,
|
| 354 |
+
organism: Union[int, str],
|
| 355 |
+
device: torch.device,
|
| 356 |
+
tokenizer: Union[str, PreTrainedTokenizerFast] = None,
|
| 357 |
+
model: Union[str, torch.nn.Module] = None,
|
| 358 |
+
attention_type: str = "original_full",
|
| 359 |
+
deterministic: bool = True,
|
| 360 |
+
temperature: float = 0.2,
|
| 361 |
+
top_p: float = 0.95,
|
| 362 |
+
num_sequences: int = 1,
|
| 363 |
+
match_protein: bool = False,
|
| 364 |
+
use_constrained_search: bool = False,
|
| 365 |
+
gc_bounds: Tuple[float, float] = (0.30, 0.70),
|
| 366 |
+
beam_size: int = 5,
|
| 367 |
+
length_penalty: float = 1.0,
|
| 368 |
+
diversity_penalty: float = 0.0,
|
| 369 |
+
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
|
| 370 |
+
"""
|
| 371 |
+
Predict the DNA sequence(s) for a given protein using the ENCOT model.
|
| 372 |
+
|
| 373 |
+
This function takes a protein sequence and an organism (as ID or name) as input
|
| 374 |
+
and returns the predicted DNA sequence(s) using the ENCOT model. It can use
|
| 375 |
+
either provided tokenizer and model objects or load them from specified paths.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
protein (str): The input protein sequence for which to predict the DNA sequence.
|
| 379 |
+
organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
|
| 380 |
+
"Escherichia coli general").
|
| 381 |
+
device (torch.device): The device (CPU or GPU) to run the model on.
|
| 382 |
+
use_constrained_search (bool, optional): Enable constrained beam search with GC bounds.
|
| 383 |
+
gc_bounds (Tuple[float, float], optional): GC content bounds (min, max) for
|
| 384 |
+
constrained search. Defaults to (0.30, 0.70).
|
| 385 |
+
beam_size (int, optional): Beam size for beam search. Defaults to 5.
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
Union[DNASequencePrediction, List[DNASequencePrediction]]: Predicted DNA sequence(s)
|
| 389 |
+
with associated metrics.
|
| 390 |
+
"""
|
| 391 |
+
</code></pre>
|
| 392 |
+
</div>
|
| 393 |
+
|
| 394 |
+
<!-- Section 5: Evaluation Metrics -->
|
| 395 |
+
<div class="section">
|
| 396 |
+
<h2 class="section-title">
|
| 397 |
+
<span class="section-number">5</span>
|
| 398 |
+
Evaluation Metrics - CAI & tAI
|
| 399 |
+
</h2>
|
| 400 |
+
<div class="description">
|
| 401 |
+
Functions for calculating Codon Adaptation Index (CAI) and tRNA Adaptation Index (tAI),
|
| 402 |
+
key metrics for evaluating codon optimization quality.
|
| 403 |
+
</div>
|
| 404 |
+
<div class="file-info">
|
| 405 |
+
<div class="file-path">📄 CodonTransformer/CodonEvaluation.py</div>
|
| 406 |
+
<div class="line-range">Lines 23-50, 370-420 | Metrics functions</div>
|
| 407 |
+
</div>
|
| 408 |
+
<div class="key-feature">
|
| 409 |
+
<strong>🎯 Highlight:</strong> CAI and tAI calculation implementations
|
| 410 |
+
</div>
|
| 411 |
+
<pre><code class="language-python">def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
|
| 412 |
+
"""
|
| 413 |
+
Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
sequences (List[str]): List of DNA sequences.
|
| 417 |
+
|
| 418 |
+
Returns:
|
| 419 |
+
dict: The CSI weights.
|
| 420 |
+
"""
|
| 421 |
+
return relative_adaptiveness(sequences=sequences)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
|
| 425 |
+
"""
|
| 426 |
+
Calculate the Codon Similarity Index (CSI) for a DNA sequence.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
dna (str): The DNA sequence.
|
| 430 |
+
weights (dict): The CSI weights from get_CSI_weights.
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
float: The CSI value.
|
| 434 |
+
"""
|
| 435 |
+
return CAI(dna, weights)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def get_ecoli_tai_weights():
|
| 439 |
+
"""
|
| 440 |
+
Returns pre-calculated tAI weights for E. coli K-12 MG1655.
|
| 441 |
+
|
| 442 |
+
These weights are based on tRNA gene copy numbers and wobble base pairing rules.
|
| 443 |
+
"""
|
| 444 |
+
return {
|
| 445 |
+
'TTT': 0.58, 'TTC': 0.42, 'TTA': 0.13, 'TTG': 0.13,
|
| 446 |
+
'TCT': 0.15, 'TCC': 0.15, 'TCA': 0.12, 'TCG': 0.15,
|
| 447 |
+
# ... full codon table
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
|
| 452 |
+
"""
|
| 453 |
+
Calculate the tRNA Adaptation Index (tAI) for a DNA sequence.
|
| 454 |
+
|
| 455 |
+
Args:
|
| 456 |
+
sequence (str): DNA sequence (must be divisible by 3)
|
| 457 |
+
tai_weights (Dict[str, float]): tAI weights for each codon
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
float: Geometric mean of tAI weights for all codons in the sequence
|
| 461 |
+
"""
|
| 462 |
+
if len(sequence) % 3 != 0:
|
| 463 |
+
raise ValueError("Sequence length must be divisible by 3")
|
| 464 |
+
|
| 465 |
+
codons = [sequence[i:i+3].upper() for i in range(0, len(sequence), 3)]
|
| 466 |
+
weights = [tai_weights.get(codon, 0.5) for codon in codons if codon not in ['TAA', 'TAG', 'TGA']]
|
| 467 |
+
|
| 468 |
+
if not weights:
|
| 469 |
+
return 0.0
|
| 470 |
+
|
| 471 |
+
# Geometric mean
|
| 472 |
+
product = 1.0
|
| 473 |
+
for w in weights:
|
| 474 |
+
product *= w
|
| 475 |
+
return product ** (1.0 / len(weights))
|
| 476 |
+
</code></pre>
|
| 477 |
+
</div>
|
| 478 |
+
|
| 479 |
+
<!-- Section 6: Training Configuration -->
|
| 480 |
+
<div class="section">
|
| 481 |
+
<h2 class="section-title">
|
| 482 |
+
<span class="section-number">6</span>
|
| 483 |
+
Training Configuration - ALM Settings
|
| 484 |
+
</h2>
|
| 485 |
+
<div class="description">
|
| 486 |
+
YAML configuration file defining all training hyperparameters, including ALM-specific settings for GC content control.
|
| 487 |
+
</div>
|
| 488 |
+
<div class="file-info">
|
| 489 |
+
<div class="file-path">📄 configs/train_ecoli_alm.yaml</div>
|
| 490 |
+
<div class="line-range">Complete file | Training configuration</div>
|
| 491 |
+
</div>
|
| 492 |
+
<div class="key-feature">
|
| 493 |
+
<strong>🎯 Highlight:</strong> ALM section with gc_target, curriculum_epochs, and penalty parameters
|
| 494 |
+
</div>
|
| 495 |
+
<pre><code class="language-yaml"># ENCOT ALM Training Configuration
|
| 496 |
+
# This configuration reproduces the main training setup from the paper
|
| 497 |
+
# using the Augmented-Lagrangian Method (ALM) for GC content control.
|
| 498 |
+
|
| 499 |
+
model:
|
| 500 |
+
base_model: "adibvafa/CodonTransformer-base"
|
| 501 |
+
tokenizer: "adibvafa/CodonTransformer"
|
| 502 |
+
|
| 503 |
+
data:
|
| 504 |
+
dataset_dir: "data"
|
| 505 |
+
# Expected files: finetune_set.json (created by preprocess_data.py)
|
| 506 |
+
|
| 507 |
+
training:
|
| 508 |
+
batch_size: 6
|
| 509 |
+
max_epochs: 15
|
| 510 |
+
learning_rate: 5e-5
|
| 511 |
+
warmup_fraction: 0.1
|
| 512 |
+
num_workers: 5
|
| 513 |
+
accumulate_grad_batches: 1
|
| 514 |
+
num_gpus: 4
|
| 515 |
+
save_every_n_steps: 512
|
| 516 |
+
seed: 123
|
| 517 |
+
log_every_n_steps: 20
|
| 518 |
+
|
| 519 |
+
checkpoint:
|
| 520 |
+
checkpoint_dir: "models/alm-enhanced-training"
|
| 521 |
+
checkpoint_filename: "balanced_alm_finetune.ckpt"
|
| 522 |
+
|
| 523 |
+
# Augmented-Lagrangian Method (ALM) for GC content control
|
| 524 |
+
alm:
|
| 525 |
+
enabled: true
|
| 526 |
+
gc_target: 0.52 # Target GC content for E. coli (52%)
|
| 527 |
+
curriculum_epochs: 3 # Warm-up epochs before enforcing GC constraint
|
| 528 |
+
|
| 529 |
+
# ALM penalty parameters
|
| 530 |
+
initial_penalty_factor: 20.0
|
| 531 |
+
penalty_update_factor: 10.0
|
| 532 |
+
max_penalty: 1e6
|
| 533 |
+
min_penalty: 1e-6
|
| 534 |
+
|
| 535 |
+
# ALM tolerance parameters
|
| 536 |
+
tolerance: 1e-5 # Primal tolerance
|
| 537 |
+
dual_tolerance: 1e-5 # Dual tolerance for constraint violation
|
| 538 |
+
tolerance_update_factor: 0.1
|
| 539 |
+
|
| 540 |
+
# Adaptive penalty adjustment
|
| 541 |
+
rel_penalty_increase_threshold: 0.1
|
| 542 |
+
|
| 543 |
+
# Legacy penalty method (if ALM disabled)
|
| 544 |
+
gc_penalty:
|
| 545 |
+
weight: 0.0 # Only used if use_lagrangian=false
|
| 546 |
+
</code></pre>
|
| 547 |
+
</div>
|
| 548 |
+
|
| 549 |
+
<!-- Section 7: Data Preparation -->
|
| 550 |
+
<div class="section">
|
| 551 |
+
<h2 class="section-title">
|
| 552 |
+
<span class="section-number">7</span>
|
| 553 |
+
Data Preparation & Validation
|
| 554 |
+
</h2>
|
| 555 |
+
<div class="description">
|
| 556 |
+
Functions for validating and preparing E. coli gene sequences for training, including sequence validation checks.
|
| 557 |
+
</div>
|
| 558 |
+
<div class="file-info">
|
| 559 |
+
<div class="file-path">📄 prepare_ecoli_data.py</div>
|
| 560 |
+
<div class="line-range">Lines 5-30 | Validation function</div>
|
| 561 |
+
</div>
|
| 562 |
+
<div class="key-feature">
|
| 563 |
+
<strong>🎯 Highlight:</strong> Sequence validation rules (start/stop codons, frame, no internal stops)
|
| 564 |
+
</div>
|
| 565 |
+
<pre><code class="language-python">def is_valid_sequence(dna_seq: str) -> bool:
|
| 566 |
+
"""
|
| 567 |
+
Applies a series of validation checks to a DNA sequence.
|
| 568 |
+
|
| 569 |
+
Args:
|
| 570 |
+
dna_seq (str): The DNA sequence to validate.
|
| 571 |
+
|
| 572 |
+
Returns:
|
| 573 |
+
bool: True if the sequence is valid, False otherwise.
|
| 574 |
+
"""
|
| 575 |
+
# Check if length is divisible by 3 (valid codon frame)
|
| 576 |
+
if len(dna_seq) % 3 != 0:
|
| 577 |
+
return False
|
| 578 |
+
|
| 579 |
+
# Check for valid start codon
|
| 580 |
+
if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
|
| 581 |
+
return False
|
| 582 |
+
|
| 583 |
+
# Check for valid stop codon
|
| 584 |
+
if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
|
| 585 |
+
return False
|
| 586 |
+
|
| 587 |
+
# Check for internal stop codons (excluding the last codon)
|
| 588 |
+
codons = [dna_seq[i:i+3].upper() for i in range(0, len(dna_seq) - 3, 3)]
|
| 589 |
+
if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
|
| 590 |
+
return False
|
| 591 |
+
|
| 592 |
+
# Check if sequence contains only valid nucleotides
|
| 593 |
+
if not all(c in 'ATGC' for c in dna_seq.upper()):
|
| 594 |
+
return False
|
| 595 |
+
|
| 596 |
+
return True
|
| 597 |
+
</code></pre>
|
| 598 |
+
</div>
|
| 599 |
+
|
| 600 |
+
<!-- Section 8: Streamlit GUI -->
|
| 601 |
+
<div class="section">
|
| 602 |
+
<h2 class="section-title">
|
| 603 |
+
<span class="section-number">8</span>
|
| 604 |
+
Streamlit GUI - Main Interface
|
| 605 |
+
</h2>
|
| 606 |
+
<div class="description">
|
| 607 |
+
Web-based graphical interface for ENCOT built with Streamlit, providing user-friendly access to optimization features.
|
| 608 |
+
</div>
|
| 609 |
+
<div class="file-info">
|
| 610 |
+
<div class="file-path">📄 streamlit_gui/app.py</div>
|
| 611 |
+
<div class="line-range">Lines 625-640 | Main function</div>
|
| 612 |
+
</div>
|
| 613 |
+
<div class="key-feature">
|
| 614 |
+
<strong>🎯 Highlight:</strong> Streamlit app structure with tabs and model loading
|
| 615 |
+
</div>
|
| 616 |
+
<pre><code class="language-python">def main():
|
| 617 |
+
st.title("ENCOT")
|
| 618 |
+
st.markdown("E. coli codon optimization with constraint-aware decoding and in silico evaluation metrics.")
|
| 619 |
+
|
| 620 |
+
# Load model
|
| 621 |
+
load_model_and_tokenizer()
|
| 622 |
+
|
| 623 |
+
# Create the main tabbed interface
|
| 624 |
+
tab1, tab2, tab3, tab4 = st.tabs([
|
| 625 |
+
"Single Optimize",
|
| 626 |
+
"Batch Process",
|
| 627 |
+
"Comparative Analysis",
|
| 628 |
+
"Advanced Settings"
|
| 629 |
+
])
|
| 630 |
+
|
| 631 |
+
with tab1:
|
| 632 |
+
single_sequence_optimization()
|
| 633 |
+
|
| 634 |
+
with tab2:
|
| 635 |
+
batch_processing()
|
| 636 |
+
|
| 637 |
+
with tab3:
|
| 638 |
+
comparative_analysis()
|
| 639 |
+
|
| 640 |
+
with tab4:
|
| 641 |
+
advanced_settings()
|
| 642 |
+
|
| 643 |
+
# Footer
|
| 644 |
+
st.markdown("---")
|
| 645 |
+
st.markdown("**ENCOT**")
|
| 646 |
+
st.markdown("Open-source codon optimization for E. coli with reproducible evaluation.")
|
| 647 |
+
</code></pre>
|
| 648 |
+
</div>
|
| 649 |
+
|
| 650 |
+
<!-- Section 9: Benchmark Evaluation -->
|
| 651 |
+
<div class="section">
|
| 652 |
+
<h2 class="section-title">
|
| 653 |
+
<span class="section-number">9</span>
|
| 654 |
+
Benchmark Evaluation Pipeline
|
| 655 |
+
</h2>
|
| 656 |
+
<div class="description">
|
| 657 |
+
Comprehensive benchmarking pipeline for evaluating ENCOT performance on test sequences with multiple metrics.
|
| 658 |
+
</div>
|
| 659 |
+
<div class="file-info">
|
| 660 |
+
<div class="file-path">📄 benchmark_evaluation.py</div>
|
| 661 |
+
<div class="line-range">Lines 300-400 | Benchmark function</div>
|
| 662 |
+
</div>
|
| 663 |
+
<div class="key-feature">
|
| 664 |
+
<strong>🎯 Highlight:</strong> Multi-metric evaluation (CAI, tAI, GC, cis-elements)
|
| 665 |
+
</div>
|
| 666 |
+
<pre><code class="language-python">def benchmark_sequences(sequences, model, tokenizer, device, cai_weights, tai_weights):
|
| 667 |
+
"""
|
| 668 |
+
Run ENCOT on protein sequences and compute metrics for optimized DNA.
|
| 669 |
+
|
| 670 |
+
Args:
|
| 671 |
+
sequences: List of protein sequences to optimize
|
| 672 |
+
model: Loaded ENCOT model
|
| 673 |
+
tokenizer: Tokenizer for the model
|
| 674 |
+
device: PyTorch device (CPU/GPU)
|
| 675 |
+
cai_weights: Pre-computed CAI weights
|
| 676 |
+
tai_weights: Pre-computed tAI weights
|
| 677 |
+
|
| 678 |
+
Returns:
|
| 679 |
+
DataFrame with optimization results and metrics
|
| 680 |
+
"""
|
| 681 |
+
results = []
|
| 682 |
+
|
| 683 |
+
for name, protein in tqdm(sequences, desc="Optimizing sequences"):
|
| 684 |
+
# Optimize the sequence
|
| 685 |
+
output = predict_dna_sequence(
|
| 686 |
+
protein=protein,
|
| 687 |
+
organism="Escherichia coli general",
|
| 688 |
+
device=device,
|
| 689 |
+
model=model,
|
| 690 |
+
tokenizer=tokenizer,
|
| 691 |
+
deterministic=True,
|
| 692 |
+
use_constrained_search=True,
|
| 693 |
+
gc_bounds=(0.45, 0.55)
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
optimized_dna = output.predicted_dna
|
| 697 |
+
|
| 698 |
+
# Calculate metrics
|
| 699 |
+
cai = get_CSI_value(optimized_dna, cai_weights)
|
| 700 |
+
tai = calculate_tAI(optimized_dna, tai_weights)
|
| 701 |
+
gc_content = get_GC_content(optimized_dna)
|
| 702 |
+
cis_elements = count_negative_cis_elements(optimized_dna)
|
| 703 |
+
|
| 704 |
+
results.append({
|
| 705 |
+
'name': name,
|
| 706 |
+
'protein': protein,
|
| 707 |
+
'optimized_dna': optimized_dna,
|
| 708 |
+
'CAI': cai,
|
| 709 |
+
'tAI': tai,
|
| 710 |
+
'GC_content': gc_content,
|
| 711 |
+
'negative_cis_elements': cis_elements
|
| 712 |
+
})
|
| 713 |
+
|
| 714 |
+
return pd.DataFrame(results)
|
| 715 |
+
</code></pre>
|
| 716 |
+
</div>
|
| 717 |
+
|
| 718 |
+
<!-- Section 10: Project Structure -->
|
| 719 |
+
<div class="section">
|
| 720 |
+
<h2 class="section-title">
|
| 721 |
+
<span class="section-number">10</span>
|
| 722 |
+
Project Overview & Architecture
|
| 723 |
+
</h2>
|
| 724 |
+
<div class="description">
|
| 725 |
+
Complete project structure showing the organization of modules, scripts, and configuration files.
|
| 726 |
+
</div>
|
| 727 |
+
<div class="key-feature">
|
| 728 |
+
<strong>🎯 Key Components:</strong> Training (finetune.py), Inference (CodonPrediction.py),
|
| 729 |
+
Evaluation (CodonEvaluation.py), GUI (streamlit_gui/), Configs (configs/)
|
| 730 |
+
</div>
|
| 731 |
+
<pre><code class="language-plaintext">ENCOT/
|
| 732 |
+
├── CodonTransformer/ # Core library modules
|
| 733 |
+
│ ├── CodonPrediction.py # Model loading & DNA sequence prediction
|
| 734 |
+
│ ├── CodonEvaluation.py # Metrics (CAI, tAI, GC, CFD, etc.)
|
| 735 |
+
│ ├── CodonData.py # Data preprocessing & preparation
|
| 736 |
+
│ ├── CodonUtils.py # Constants, mappings, utilities
|
| 737 |
+
│ └── CodonPostProcessing.py # DNA-Chisel integration
|
| 738 |
+
│
|
| 739 |
+
├── scripts/ # Command-line tools
|
| 740 |
+
│ ├── train.py # Training wrapper
|
| 741 |
+
│ ├── optimize_sequence.py # Sequence optimization CLI
|
| 742 |
+
│ ├── run_benchmarks.py # Benchmark evaluation
|
| 743 |
+
│ └── preprocess_data.py # Data preparation
|
| 744 |
+
│
|
| 745 |
+
├── configs/ # YAML configurations
|
| 746 |
+
│ ├── train_ecoli_alm.yaml # Main ALM training config ⭐
|
| 747 |
+
│ └── train_ecoli_quick.yaml # Quick test config
|
| 748 |
+
│
|
| 749 |
+
├── streamlit_gui/ # Web interface
|
| 750 |
+
│ ├── app.py # Main Streamlit GUI ⭐
|
| 751 |
+
│ ├── demo.py # Demo script
|
| 752 |
+
│ └── run_gui.py # Launcher
|
| 753 |
+
│
|
| 754 |
+
├── data/ # Datasets
|
| 755 |
+
│ ├── finetune_set.json # Training data
|
| 756 |
+
│ └── test_set.json # Test data
|
| 757 |
+
│
|
| 758 |
+
├── finetune.py # Main training script ⭐⭐⭐
|
| 759 |
+
├── benchmark_evaluation.py # Evaluation script
|
| 760 |
+
├── setup.py # Package setup
|
| 761 |
+
├── pyproject.toml # Project configuration
|
| 762 |
+
└── README.md # Documentation
|
| 763 |
+
|
| 764 |
+
Key Innovations:
|
| 765 |
+
⭐⭐⭐ Augmented-Lagrangian Method (ALM) for GC control
|
| 766 |
+
⭐⭐ Constrained beam search with GC bounds
|
| 767 |
+
⭐ Multi-metric evaluation (CAI, tAI, GC, cis-elements)
|
| 768 |
+
</code></pre>
|
| 769 |
+
</div>
|
| 770 |
+
|
| 771 |
+
<div class="footer">
|
| 772 |
+
<h3>ENCOT - Enhanced Codon Optimization Tool</h3>
|
| 773 |
+
<p>Repository: <a href="https://github.com/geno543/ENCOT" style="color: #58a6ff;">github.com/geno543/ENCOT</a></p>
|
| 774 |
+
<p>© 2026 | Apache License 2.0</p>
|
| 775 |
+
</div>
|
| 776 |
+
|
| 777 |
+
<script>
|
| 778 |
+
// Initialize syntax highlighting
|
| 779 |
+
hljs.highlightAll();
|
| 780 |
+
|
| 781 |
+
// Add line numbers
|
| 782 |
+
document.querySelectorAll('pre code').forEach((block) => {
|
| 783 |
+
const lines = block.innerHTML.split('\n');
|
| 784 |
+
const numberedLines = lines.map((line, index) => {
|
| 785 |
+
return `<span class="line-number" style="color: #6e7681; user-select: none; margin-right: 1em;">${String(index + 1).padStart(3, ' ')}</span>${line}`;
|
| 786 |
+
}).join('\n');
|
| 787 |
+
block.innerHTML = numberedLines;
|
| 788 |
+
});
|
| 789 |
+
</script>
|
| 790 |
+
</body>
|
| 791 |
+
</html>
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2024 Adibvafa Fallahpour
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
Makefile
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Makefile
|
| 2 |
+
|
| 3 |
+
.PHONY: test
|
| 4 |
+
test:
|
| 5 |
+
python -m unittest discover -s tests
|
| 6 |
+
|
| 7 |
+
.PHONY: test_with_coverage
|
| 8 |
+
test_with_coverage:
|
| 9 |
+
coverage run -m unittest discover -s tests
|
README.md
CHANGED
|
@@ -1,10 +1,495 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href="https://huggingface.co/saketh11/ColiFormer"><img src="https://img.shields.io/badge/HuggingFace-Model-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Model"></a>
|
| 6 |
+
<a href="https://huggingface.co/datasets/saketh11/ColiFormer-Data"><img src="https://img.shields.io/badge/HuggingFace-Data-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+
## Abstract
|
| 10 |
+
|
| 11 |
+
ENCOT is a transformer-based model for codon optimization of protein sequences in *Escherichia coli*. Built on top of CodonTransformer (a multi-species BigBird model trained on over 1 million DNA–protein pairs), ENCOT is fine-tuned specifically for E. coli codon preferences using 3,676 high-expression E. coli genes curated from NCBI.
|
| 12 |
+
|
| 13 |
+
ENCOT balances multiple objectives (CAI, GC content, tAI, RNA stability, and minimization of negative cis-regulatory elements) and uses an **Augmented-Lagrangian Method (ALM)** to enforce GC content control during training. Performance was evaluated on 37,053 native E. coli genes and 80 recombinant protein targets, demonstrating strong improvements in in silico expression metrics while maintaining biologically appropriate constraints.
|
| 14 |
+
|
| 15 |
+
## Paper Reference
|
| 16 |
+
|
| 17 |
+
**ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression**
|
| 18 |
+
|
| 19 |
+
Saketh Baddam, Omar Emam, Abdelrahman Elfikky, Francesco Cavarretta, George Luka, Ibrahim Farag, Yasser Sanad
|
| 20 |
+
|
| 21 |
+
bioRxiv preprint (not peer-reviewed): `https://doi.org/10.1101/2025.11.26.690826`
|
| 22 |
+
|
| 23 |
+
**What does “preprint and not peer-reviewed” mean?** A preprint is a publicly available manuscript shared before formal journal peer review. It can be cited, but its claims have not yet been evaluated by journal referees.
|
| 24 |
+
|
| 25 |
+
### Citation
|
| 26 |
+
|
| 27 |
+
If you use ENCOT in your research, please cite:
|
| 28 |
+
|
| 29 |
+
```bibtex
|
| 30 |
+
@article{encot2025,
|
| 31 |
+
title{ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression},
|
| 32 |
+
author={Baddam, Saketh and Emam, Omar and Elfikky, Abdelrahman and Cavarretta, Francesco and Luka, George and Farag, Ibrahim and Sanad, Yasser},
|
| 33 |
+
journal={bioRxiv},
|
| 34 |
+
year={2025},
|
| 35 |
+
doi={10.1101/2025.11.26.690826},
|
| 36 |
+
url={https://doi.org/10.1101/2025.11.26.690826},
|
| 37 |
+
note={Preprint (not peer-reviewed)}
|
| 38 |
+
}
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Quick Start
|
| 42 |
+
|
| 43 |
+
Optimize a protein sequence in just a few lines:
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
import torch
|
| 47 |
+
from transformers import AutoTokenizer
|
| 48 |
+
from CodonTransformer.CodonPrediction import load_model, predict_dna_sequence
|
| 49 |
+
from huggingface_hub import hf_hub_download
|
| 50 |
+
|
| 51 |
+
# Load model from Hugging Face
|
| 52 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
+
checkpoint_path = hf_hub_download(
|
| 54 |
+
repo_id="saketh11/ColiFormer",
|
| 55 |
+
filename="balanced_alm_finetune.ckpt",
|
| 56 |
+
cache_dir="./hf_cache"
|
| 57 |
+
)
|
| 58 |
+
model = load_model(model_path=checkpoint_path, device=device)
|
| 59 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 60 |
+
|
| 61 |
+
# Optimize a protein sequence
|
| 62 |
+
protein = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
|
| 63 |
+
output = predict_dna_sequence(
|
| 64 |
+
protein=protein,
|
| 65 |
+
organism="Escherichia coli general",
|
| 66 |
+
device=device,
|
| 67 |
+
model=model,
|
| 68 |
+
tokenizer=tokenizer,
|
| 69 |
+
deterministic=True,
|
| 70 |
+
match_protein=True
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
print(f"Optimized DNA: {output.predicted_dna}")
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
Or use the command-line interface:
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
python scripts/optimize_sequence.py \
|
| 80 |
+
--input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" \
|
| 81 |
+
--output optimized.fasta
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## Installation
|
| 85 |
+
|
| 86 |
+
### Requirements
|
| 87 |
+
|
| 88 |
+
- Python >= 3.9
|
| 89 |
+
- CUDA-capable GPU (recommended for training, optional for inference)
|
| 90 |
+
|
| 91 |
+
### Setup
|
| 92 |
+
|
| 93 |
+
1. **Clone the repository:**
|
| 94 |
+
|
| 95 |
+
```bash
|
| 96 |
+
git clone https://github.com/geno543/ENCOT.git
|
| 97 |
+
cd ENCOT
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
2. **Create a virtual environment:**
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
python -m venv venv
|
| 104 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
3. **Install dependencies:**
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
pip install -r requirements.txt
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
The installation takes approximately 10-30 seconds depending on your system and existing packages.
|
| 114 |
+
|
| 115 |
+
## Public Streamlit Demo (Anyone Can Try It)
|
| 116 |
+
|
| 117 |
+
If you want a public link so anyone can test ENCOT in a browser, deploy the app with either Streamlit Community Cloud or Hugging Face Spaces.
|
| 118 |
+
|
| 119 |
+
### Option A: Streamlit Community Cloud (Fastest)
|
| 120 |
+
|
| 121 |
+
1. Push this repository to GitHub.
|
| 122 |
+
2. Go to https://share.streamlit.io and sign in.
|
| 123 |
+
3. Click **New app** and choose your repository.
|
| 124 |
+
4. Set **Main file path** to `streamlit_app.py`.
|
| 125 |
+
5. Use the repository `requirements.txt` for dependencies.
|
| 126 |
+
6. Deploy and share the generated public URL.
|
| 127 |
+
|
| 128 |
+
### Option B: Hugging Face Spaces (Streamlit)
|
| 129 |
+
|
| 130 |
+
1. Create a new Space (SDK: **Streamlit**).
|
| 131 |
+
2. Upload this project (or connect the GitHub repo).
|
| 132 |
+
3. Ensure app file is `streamlit_app.py`.
|
| 133 |
+
4. Keep the repo public so anyone can access the Space URL.
|
| 134 |
+
|
| 135 |
+
### Local check before deployment
|
| 136 |
+
|
| 137 |
+
```bash
|
| 138 |
+
streamlit run streamlit_app.py --server.port 8501
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
This uses the existing UI in `streamlit_gui/app.py`, including model loading from Hugging Face and optimization controls.
|
| 142 |
+
|
| 143 |
+
## Data Preparation
|
| 144 |
+
|
| 145 |
+
### Preparing E. coli Training Data
|
| 146 |
+
|
| 147 |
+
To prepare training data from raw E. coli gene sequences:
|
| 148 |
+
|
| 149 |
+
1. **Place your data files in the `data/` directory:**
|
| 150 |
+
- `data/CAI.csv` - CSV file with columns: gene_id, cai_score, dna_sequence
|
| 151 |
+
- `data/Database 3_4300 gene.csv` - CSV file with high-CAI sequences (column: dna_sequence)
|
| 152 |
+
|
| 153 |
+
2. **Run the preprocessing script:**
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
python scripts/preprocess_data.py
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
This will:
|
| 160 |
+
- Validate and process DNA sequences
|
| 161 |
+
- Create `data/ecoli_processed_genes.csv` with validated sequences
|
| 162 |
+
- Generate `data/finetune_set.json` for training (high-CAI sequences)
|
| 163 |
+
- Generate `data/test_set.json` for evaluation (100 random sequences)
|
| 164 |
+
|
| 165 |
+
**Custom paths:**
|
| 166 |
+
|
| 167 |
+
```bash
|
| 168 |
+
python scripts/preprocess_data.py \
|
| 169 |
+
--cai_csv data/my_cai_data.csv \
|
| 170 |
+
--high_cai_csv data/my_high_cai_data.csv \
|
| 171 |
+
--output_dir my_data \
|
| 172 |
+
--test_size 200
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
### Dataset Structure
|
| 176 |
+
|
| 177 |
+
The processed dataset includes:
|
| 178 |
+
- **Training set**: 4,300 high-CAI E. coli sequences (from `Database 3_4300 gene.csv`)
|
| 179 |
+
- **Test set**: 100 randomly sampled sequences (for evaluation)
|
| 180 |
+
- **Reference sequences**: 50,000+ E. coli genes for CAI/tAI calculation
|
| 181 |
+
|
| 182 |
+
The complete dataset is available at [saketh11/ColiFormer-Data](https://huggingface.co/datasets/saketh11/ColiFormer-Data) on Hugging Face.
|
| 183 |
+
|
| 184 |
+
## Training
|
| 185 |
+
|
| 186 |
+
### Quick Start Training
|
| 187 |
+
|
| 188 |
+
Train ENCOT with the default ALM configuration:
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
python scripts/train.py --config configs/train_ecoli_alm.yaml
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### Configuration Files
|
| 195 |
+
|
| 196 |
+
We provide three configuration files:
|
| 197 |
+
|
| 198 |
+
1. **`configs/train_ecoli_alm.yaml`** - Main training configuration with ALM GC control
|
| 199 |
+
- 15 epochs, batch size 6, 4 GPUs
|
| 200 |
+
- ALM enabled with GC target 52%
|
| 201 |
+
- Curriculum learning: 3 warm-up epochs
|
| 202 |
+
|
| 203 |
+
2. **`configs/train_ecoli_quick.yaml`** - Quick sanity check
|
| 204 |
+
- 1 epoch, batch size 2, CPU-only
|
| 205 |
+
- Useful for testing your setup
|
| 206 |
+
|
| 207 |
+
3. **`configs/benchmark.yaml`** - Benchmark evaluation settings
|
| 208 |
+
|
| 209 |
+
### Training Parameters
|
| 210 |
+
|
| 211 |
+
Key parameters in the config files:
|
| 212 |
+
|
| 213 |
+
- **`training.batch_size`**: Batch size (default: 6)
|
| 214 |
+
- **`training.max_epochs`**: Number of training epochs (default: 15)
|
| 215 |
+
- **`training.learning_rate`**: Learning rate (default: 5e-5)
|
| 216 |
+
- **`training.num_gpus`**: Number of GPUs (default: 4)
|
| 217 |
+
- **`alm.enabled`**: Enable ALM GC control (default: true)
|
| 218 |
+
- **`alm.gc_target`**: Target GC content (default: 0.52 for E. coli)
|
| 219 |
+
- **`alm.curriculum_epochs`**: Warm-up epochs before enforcing GC constraint (default: 3)
|
| 220 |
+
|
| 221 |
+
### Override Config Values
|
| 222 |
+
|
| 223 |
+
You can override config values from the command line:
|
| 224 |
+
|
| 225 |
+
```bash
|
| 226 |
+
python scripts/train.py \
|
| 227 |
+
--config configs/train_ecoli_alm.yaml \
|
| 228 |
+
--num_gpus 2 \
|
| 229 |
+
--batch_size 4 \
|
| 230 |
+
--max_epochs 10
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
### Training Output
|
| 234 |
+
|
| 235 |
+
Checkpoints are saved to the directory specified in `checkpoint.checkpoint_dir`:
|
| 236 |
+
- Model state dict: `balanced_alm_finetune.ckpt`
|
| 237 |
+
- Training logs: TensorBoard logs in the checkpoint directory
|
| 238 |
+
|
| 239 |
+
Monitor training progress:
|
| 240 |
+
|
| 241 |
+
```bash
|
| 242 |
+
tensorboard --logdir models/alm-enhanced-training
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
## Inference / Sequence Optimization
|
| 246 |
+
|
| 247 |
+
### Single Sequence Optimization
|
| 248 |
+
|
| 249 |
+
Optimize a single protein sequence:
|
| 250 |
+
|
| 251 |
+
```bash
|
| 252 |
+
python scripts/optimize_sequence.py \
|
| 253 |
+
--input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" \
|
| 254 |
+
--output optimized.fasta
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
### Batch Processing
|
| 258 |
+
|
| 259 |
+
Process multiple sequences from a FASTA file:
|
| 260 |
+
|
| 261 |
+
```bash
|
| 262 |
+
python scripts/optimize_sequence.py \
|
| 263 |
+
--input sequences.fasta \
|
| 264 |
+
--output optimized.fasta \
|
| 265 |
+
--batch
|
| 266 |
+
```
|
| 267 |
+
|
| 268 |
+
### GC Content Constraints
|
| 269 |
+
|
| 270 |
+
Specify GC content bounds:
|
| 271 |
+
|
| 272 |
+
```bash
|
| 273 |
+
python scripts/optimize_sequence.py \
|
| 274 |
+
--input protein.fasta \
|
| 275 |
+
--output optimized.fasta \
|
| 276 |
+
--gc-min 0.45 \
|
| 277 |
+
--gc-max 0.55
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
### Using Custom Checkpoint
|
| 281 |
+
|
| 282 |
+
```bash
|
| 283 |
+
python scripts/optimize_sequence.py \
|
| 284 |
+
--input protein.fasta \
|
| 285 |
+
--output optimized.fasta \
|
| 286 |
+
--checkpoint models/my_model.ckpt
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
### Python API
|
| 290 |
+
|
| 291 |
+
For programmatic use:
|
| 292 |
+
|
| 293 |
+
```python
|
| 294 |
+
from CodonTransformer.CodonPrediction import load_model, predict_dna_sequence
|
| 295 |
+
from transformers import AutoTokenizer
|
| 296 |
+
import torch
|
| 297 |
+
|
| 298 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 299 |
+
model = load_model(model_path="models/alm-enhanced-training/balanced_alm_finetune.ckpt", device=device)
|
| 300 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 301 |
+
|
| 302 |
+
output = predict_dna_sequence(
|
| 303 |
+
protein="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG",
|
| 304 |
+
organism="Escherichia coli general",
|
| 305 |
+
device=device,
|
| 306 |
+
model=model,
|
| 307 |
+
tokenizer=tokenizer,
|
| 308 |
+
deterministic=True,
|
| 309 |
+
match_protein=True,
|
| 310 |
+
use_constrained_search=True,
|
| 311 |
+
gc_bounds=(0.45, 0.55),
|
| 312 |
+
beam_size=20
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
print(f"Optimized DNA: {output.predicted_dna}")
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
## Reproducing Paper Results
|
| 319 |
+
|
| 320 |
+
### Benchmark Evaluation
|
| 321 |
+
|
| 322 |
+
To reproduce the benchmark results from the paper:
|
| 323 |
+
|
| 324 |
+
1. **Prepare benchmark sequences:**
|
| 325 |
+
|
| 326 |
+
Place your benchmark sequences in an Excel file (see `Benchmark 80 sequences.xlsx` for format).
|
| 327 |
+
|
| 328 |
+
2. **Run benchmark evaluation:**
|
| 329 |
+
|
| 330 |
+
```bash
|
| 331 |
+
python scripts/run_benchmarks.py --config configs/benchmark.yaml
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
This will:
|
| 335 |
+
- Load the fine-tuned ENCOT model
|
| 336 |
+
- Optimize all sequences in the benchmark file
|
| 337 |
+
- Calculate metrics (CAI, tAI, GC content, CFD, negative cis-elements)
|
| 338 |
+
- Generate comparison plots and summary statistics
|
| 339 |
+
- Save results to `benchmark_results/run_TIMESTAMP/`
|
| 340 |
+
|
| 341 |
+
### Expected Results
|
| 342 |
+
|
| 343 |
+
On the benchmark set of 80 sequences:
|
| 344 |
+
- **CAI improvement**: +6.2% vs base CodonTransformer
|
| 345 |
+
- **tAI improvement**: +8.6% vs base CodonTransformer
|
| 346 |
+
- **GC content**: Mean 52.1% (target: 52%)
|
| 347 |
+
- **Runtime**: ~1-3 seconds per sequence (GPU)
|
| 348 |
+
|
| 349 |
+
### Custom Benchmark
|
| 350 |
+
|
| 351 |
+
```bash
|
| 352 |
+
python scripts/run_benchmarks.py \
|
| 353 |
+
--excel_path my_benchmark.xlsx \
|
| 354 |
+
--checkpoint_path models/my_model.ckpt \
|
| 355 |
+
--output_dir my_results \
|
| 356 |
+
--use_gpu
|
| 357 |
+
```
|
| 358 |
+
|
| 359 |
+
## Model Architecture
|
| 360 |
+
|
| 361 |
+
### Base Model
|
| 362 |
+
|
| 363 |
+
ENCOT is built on CodonTransformer, a BigBird transformer model:
|
| 364 |
+
- **Architecture**: BigBirdForMaskedLM (89.6M parameters)
|
| 365 |
+
- **Pre-training**: 1M+ DNA-protein pairs from 164 organisms
|
| 366 |
+
- **Context length**: 2048 tokens
|
| 367 |
+
- **Attention**: Block-sparse attention for efficiency
|
| 368 |
+
|
| 369 |
+
### Fine-tuning
|
| 370 |
+
|
| 371 |
+
ENCOT is fine-tuned on E. coli-specific data:
|
| 372 |
+
- **Training data**: 4,300 high-CAI E. coli sequences
|
| 373 |
+
- **Loss function**: Masked Language Modeling (MLM) + GC constraint
|
| 374 |
+
- **Optimizer**: AdamW with CosineAnnealingWarmRestarts scheduler
|
| 375 |
+
- **Learning rate**: 5e-5 with 10% warmup
|
| 376 |
+
|
| 377 |
+
### Augmented-Lagrangian Method (ALM)
|
| 378 |
+
|
| 379 |
+
The ALM approach enforces GC content constraints during training:
|
| 380 |
+
|
| 381 |
+
**Objective function:**
|
| 382 |
+
```
|
| 383 |
+
L = L_MLM + λ·(GC - μ) + (ρ/2)(GC - μ)²
|
| 384 |
+
```
|
| 385 |
+
|
| 386 |
+
Where:
|
| 387 |
+
- `L_MLM`: Masked language modeling loss
|
| 388 |
+
- `λ`: Lagrangian multiplier (updated adaptively)
|
| 389 |
+
- `ρ`: Penalty coefficient (self-tuning)
|
| 390 |
+
- `GC`: Mean GC content (sliding window of 50 codons)
|
| 391 |
+
- `μ`: Target GC content (0.52 for E. coli)
|
| 392 |
+
|
| 393 |
+
**Key features:**
|
| 394 |
+
- **Curriculum learning**: 3 warm-up epochs before enforcing GC constraint
|
| 395 |
+
- **Adaptive penalty**: Penalty coefficient increases if constraint violation doesn't improve
|
| 396 |
+
- **Self-tuning**: Lagrangian multiplier and penalty updated every 20 steps
|
| 397 |
+
|
| 398 |
+
This approach allows the model to learn codon preferences while maintaining precise GC content control, critical for synthesis and expression in E. coli.
|
| 399 |
+
|
| 400 |
+
## Evaluation Metrics
|
| 401 |
+
|
| 402 |
+
ENCOT computes comprehensive metrics for optimized sequences:
|
| 403 |
+
|
| 404 |
+
- **CAI (Codon Adaptation Index)**: Measures similarity to highly expressed genes (0-1, higher is better)
|
| 405 |
+
- **tAI (tRNA Adaptation Index)**: Measures tRNA availability (0-1, higher is better)
|
| 406 |
+
- **GC Content**: Percentage of G+C nucleotides (target: 52% for E. coli)
|
| 407 |
+
- **CFD (Codon Frequency Distribution)**: Similarity to reference codon frequencies
|
| 408 |
+
- **Negative cis-elements**: Count of problematic sequence motifs
|
| 409 |
+
- **Homopolymer runs**: Long repeats that can cause synthesis issues
|
| 410 |
+
|
| 411 |
+
## Project Structure
|
| 412 |
+
|
| 413 |
+
```
|
| 414 |
+
encot/
|
| 415 |
+
├── configs/ # YAML configuration files
|
| 416 |
+
│ ├── train_ecoli_alm.yaml # Main training config
|
| 417 |
+
│ ├── train_ecoli_quick.yaml # Quick test config
|
| 418 |
+
│ └── benchmark.yaml # Benchmark config
|
| 419 |
+
├── scripts/ # Entry-point scripts
|
| 420 |
+
│ ├── preprocess_data.py # Data preparation
|
| 421 |
+
│ ├── train.py # Training wrapper
|
| 422 |
+
│ ├── optimize_sequence.py # Sequence optimization
|
| 423 |
+
│ └── run_benchmarks.py # Benchmark evaluation
|
| 424 |
+
├── CodonTransformer/ # Core module (custom, not PyPI)
|
| 425 |
+
│ ├── CodonPrediction.py # Model loading & inference
|
| 426 |
+
│ ├── CodonEvaluation.py # Metrics calculation
|
| 427 |
+
│ ├── CodonData.py # Data preprocessing
|
| 428 |
+
│ └── ...
|
| 429 |
+
├── data/ # Datasets
|
| 430 |
+
│ ├── finetune_set.json # Training data
|
| 431 |
+
│ ├── test_set.json # Test data
|
| 432 |
+
│ └── ecoli_processed_genes.csv # Reference sequences
|
| 433 |
+
├── models/ # Model checkpoints
|
| 434 |
+
├── notebooks/ # Jupyter notebooks
|
| 435 |
+
├── tests/ # Test suite
|
| 436 |
+
├── streamlit_gui/ # Streamlit web interface
|
| 437 |
+
├── finetune.py # Training script (original)
|
| 438 |
+
├── benchmark_evaluation.py # Evaluation script (original)
|
| 439 |
+
└── README.md # This file
|
| 440 |
+
```
|
| 441 |
+
|
| 442 |
+
## Troubleshooting
|
| 443 |
+
|
| 444 |
+
### Common Issues
|
| 445 |
+
|
| 446 |
+
**1. CUDA out of memory:**
|
| 447 |
+
- Reduce `batch_size` in config file
|
| 448 |
+
- Use gradient accumulation: increase `accumulate_grad_batches`
|
| 449 |
+
|
| 450 |
+
**2. Model checkpoint not found:**
|
| 451 |
+
- The script will auto-download from Hugging Face if local checkpoint missing
|
| 452 |
+
- Ensure you have internet connection for first run
|
| 453 |
+
|
| 454 |
+
**3. Data preprocessing errors:**
|
| 455 |
+
- Verify CSV files have correct column names
|
| 456 |
+
- Check that DNA sequences are valid (divisible by 3, proper start/stop codons)
|
| 457 |
+
|
| 458 |
+
**4. Import errors:**
|
| 459 |
+
- Ensure you've activated the virtual environment
|
| 460 |
+
- Run `pip install -r requirements.txt` again
|
| 461 |
+
|
| 462 |
+
### Getting Help
|
| 463 |
+
|
| 464 |
+
- **Issues**: Open an issue on GitHub
|
| 465 |
+
- **Questions**: Check the documentation or contact the authors
|
| 466 |
+
|
| 467 |
+
## License
|
| 468 |
+
|
| 469 |
+
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
|
| 470 |
+
|
| 471 |
+
## Acknowledgments
|
| 472 |
+
|
| 473 |
+
- **CodonTransformer**: Base model from [adibvafa/CodonTransformer](https://github.com/adibvafa/CodonTransformer)
|
| 474 |
+
- **Hugging Face**: Model hosting and distribution
|
| 475 |
+
- **E. coli data**: NCBI and Kazusa codon usage databases
|
| 476 |
+
|
| 477 |
+
## Citation
|
| 478 |
+
|
| 479 |
+
If you use ENCOT in your research, please cite:
|
| 480 |
+
|
| 481 |
+
```bibtex
|
| 482 |
+
@article{encot2025,
|
| 483 |
+
title={ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression},
|
| 484 |
+
author={Baddam, Saketh and Emam, Omar and Elfikky, Abdelrahman and Cavarretta, Francesco and Luka, George and Farag, Ibrahim and Sanad, Yasser},
|
| 485 |
+
journal={bioRxiv},
|
| 486 |
+
year={2025},
|
| 487 |
+
doi={10.1101/2025.11.26.690826},
|
| 488 |
+
url={https://doi.org/10.1101/2025.11.26.690826},
|
| 489 |
+
note={Preprint (not peer-reviewed)}
|
| 490 |
+
}
|
| 491 |
+
```
|
| 492 |
+
|
| 493 |
+
---
|
| 494 |
+
|
| 495 |
+
**ENCOT** - State-of-the-art codon optimization for E. coli expression systems.
|
app.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Spaces Streamlit entrypoint for ENCOT."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
ROOT = Path(__file__).resolve().parent
|
| 8 |
+
if str(ROOT) not in sys.path:
|
| 9 |
+
sys.path.insert(0, str(ROOT))
|
| 10 |
+
|
| 11 |
+
# Importing this module executes the Streamlit UI.
|
| 12 |
+
import streamlit_gui.app # noqa: F401,E402
|
benchmark_evaluation.py
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: benchmark_evaluation.py
|
| 3 |
+
------------------------------
|
| 4 |
+
Benchmark E. coli protein sequences with ENCOT, generate optimized DNA,
|
| 5 |
+
compute metrics (CAI, tAI, GC, CFD, cis-elements), and produce summary tables
|
| 6 |
+
and figures.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import argparse
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import json
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
import seaborn as sns
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
import time
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
from typing import Dict, List, Tuple, Any
|
| 22 |
+
|
| 23 |
+
from CAI import CAI, relative_adaptiveness
|
| 24 |
+
from CodonTransformer.CodonData import (
|
| 25 |
+
download_codon_frequencies_from_kazusa,
|
| 26 |
+
get_codon_frequencies,
|
| 27 |
+
)
|
| 28 |
+
from CodonTransformer.CodonPrediction import (
|
| 29 |
+
load_model,
|
| 30 |
+
predict_dna_sequence,
|
| 31 |
+
)
|
| 32 |
+
from CodonTransformer.CodonEvaluation import (
|
| 33 |
+
get_GC_content,
|
| 34 |
+
get_ecoli_tai_weights,
|
| 35 |
+
get_min_max_profile,
|
| 36 |
+
calculate_tAI,
|
| 37 |
+
count_negative_cis_elements,
|
| 38 |
+
)
|
| 39 |
+
from transformers import AutoTokenizer
|
| 40 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 41 |
+
from evaluate_optimizer import translate_dna_to_protein
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def find_longest_orf(dna_sequence: str) -> str:
|
| 45 |
+
"""
|
| 46 |
+
Find the longest open reading frame (ORF) in a DNA sequence.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
dna_sequence (str): Input DNA sequence (ATCGN characters).
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
str: Longest ORF (from start to stop codon), or empty string if none.
|
| 53 |
+
"""
|
| 54 |
+
dna_sequence = dna_sequence.upper()
|
| 55 |
+
start_codons = ['ATG']
|
| 56 |
+
stop_codons = ['TAA', 'TAG', 'TGA']
|
| 57 |
+
|
| 58 |
+
longest_orf = ""
|
| 59 |
+
|
| 60 |
+
for frame in range(3):
|
| 61 |
+
current_orf = ""
|
| 62 |
+
in_orf = False
|
| 63 |
+
|
| 64 |
+
for i in range(frame, len(dna_sequence) - 2, 3):
|
| 65 |
+
codon = dna_sequence[i:i+3]
|
| 66 |
+
if len(codon) != 3:
|
| 67 |
+
break
|
| 68 |
+
|
| 69 |
+
if codon in start_codons and not in_orf:
|
| 70 |
+
in_orf = True
|
| 71 |
+
current_orf = codon
|
| 72 |
+
elif in_orf:
|
| 73 |
+
current_orf += codon
|
| 74 |
+
if codon in stop_codons:
|
| 75 |
+
if len(current_orf) > len(longest_orf):
|
| 76 |
+
longest_orf = current_orf
|
| 77 |
+
in_orf = False
|
| 78 |
+
current_orf = ""
|
| 79 |
+
|
| 80 |
+
if in_orf and len(current_orf) > len(longest_orf):
|
| 81 |
+
longest_orf = current_orf
|
| 82 |
+
|
| 83 |
+
return longest_orf
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _detect_columns(df: pd.DataFrame, name_hint: str | None = None, seq_hint: str | None = None) -> tuple[str | None, str]:
|
| 87 |
+
"""
|
| 88 |
+
Detect name and sequence columns in a case-insensitive, robust way.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
df (pd.DataFrame): Input DataFrame read from Excel.
|
| 92 |
+
name_hint (str | None): Optional override for name/label column (case-insensitive).
|
| 93 |
+
seq_hint (str | None): Optional override for sequence column (case-insensitive).
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
tuple[str | None, str]: Detected (name_column or None, sequence_column).
|
| 97 |
+
|
| 98 |
+
Raises:
|
| 99 |
+
ValueError: If a sequence-like column cannot be found.
|
| 100 |
+
"""
|
| 101 |
+
cols = list(df.columns)
|
| 102 |
+
low_map = {c.lower().strip(): c for c in cols}
|
| 103 |
+
|
| 104 |
+
# If hints are provided and exist (case-insensitive), honor them
|
| 105 |
+
if name_hint:
|
| 106 |
+
nh = name_hint.lower().strip()
|
| 107 |
+
if nh in low_map:
|
| 108 |
+
name_col = low_map[nh]
|
| 109 |
+
else:
|
| 110 |
+
name_col = None
|
| 111 |
+
else:
|
| 112 |
+
name_col = None
|
| 113 |
+
|
| 114 |
+
if seq_hint:
|
| 115 |
+
sh = seq_hint.lower().strip()
|
| 116 |
+
if sh in low_map:
|
| 117 |
+
seq_col = low_map[sh]
|
| 118 |
+
else:
|
| 119 |
+
seq_col = None
|
| 120 |
+
else:
|
| 121 |
+
seq_col = None
|
| 122 |
+
|
| 123 |
+
# If not found, try candidates
|
| 124 |
+
if name_col is None:
|
| 125 |
+
name_candidates = [
|
| 126 |
+
'name','id','title','gene','protein','description','label','accession','locus','entry','uniprot','ncbi','protein name'
|
| 127 |
+
]
|
| 128 |
+
for k in name_candidates:
|
| 129 |
+
if k in low_map:
|
| 130 |
+
name_col = low_map[k]
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
if seq_col is None:
|
| 134 |
+
seq_candidates = [
|
| 135 |
+
# protein-first
|
| 136 |
+
'protein sequence','protein_sequence','protein','aa sequence','aa_sequence','aa','amino acid sequence','amino_acid_sequence',
|
| 137 |
+
# generic
|
| 138 |
+
'sequence','seq',
|
| 139 |
+
# dna/cds
|
| 140 |
+
'cds','dna','coding sequence','coding_sequence','cds sequence','cds_sequence'
|
| 141 |
+
]
|
| 142 |
+
for k in seq_candidates:
|
| 143 |
+
if k in low_map:
|
| 144 |
+
seq_col = low_map[k]
|
| 145 |
+
break
|
| 146 |
+
|
| 147 |
+
if not seq_col:
|
| 148 |
+
raise ValueError(f"Could not detect sequence column. Available columns: {cols}")
|
| 149 |
+
|
| 150 |
+
return name_col, seq_col
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def parse_excel_sequences(excel_path: str, name_col: str | None = None, seq_col: str | None = None, sheet_name: str | int | None = None) -> List[Dict[str, str]]:
|
| 154 |
+
"""
|
| 155 |
+
Parse sequences from the benchmark Excel file and auto-detect relevant columns.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
excel_path (str): Path to the Excel file.
|
| 159 |
+
name_col (str | None): Optional override for sequence name column.
|
| 160 |
+
seq_col (str | None): Optional override for sequence column.
|
| 161 |
+
sheet_name (str | int | None): Sheet name or index (default: first sheet).
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
List[Dict[str, str]]: List of standardized sequence records with fields:
|
| 165 |
+
id, name, protein_sequence, original_sequence (DNA or None), is_dna.
|
| 166 |
+
|
| 167 |
+
Raises:
|
| 168 |
+
ValueError: If a sequence column cannot be detected.
|
| 169 |
+
"""
|
| 170 |
+
sn = sheet_name
|
| 171 |
+
if isinstance(sn, str) and sn.isdigit():
|
| 172 |
+
sn = int(sn)
|
| 173 |
+
if sn is None:
|
| 174 |
+
sn = 0
|
| 175 |
+
|
| 176 |
+
df_or_dict = pd.read_excel(excel_path, sheet_name=sn)
|
| 177 |
+
if isinstance(df_or_dict, dict):
|
| 178 |
+
first_title, df = next(iter(df_or_dict.items()))
|
| 179 |
+
print(f"Using sheet: {first_title}")
|
| 180 |
+
else:
|
| 181 |
+
df = df_or_dict
|
| 182 |
+
sequences = []
|
| 183 |
+
|
| 184 |
+
detected_name_col, detected_seq_col = _detect_columns(df, name_col, seq_col)
|
| 185 |
+
print(f"Detected columns -> name: {detected_name_col or '[generated]'}, sequence: {detected_seq_col}")
|
| 186 |
+
|
| 187 |
+
for idx, row in df.iterrows():
|
| 188 |
+
sequence = str(row[detected_seq_col]).strip()
|
| 189 |
+
if detected_name_col:
|
| 190 |
+
name = str(row[detected_name_col]).strip()
|
| 191 |
+
else:
|
| 192 |
+
name = f"seq_{idx}"
|
| 193 |
+
|
| 194 |
+
if name.startswith('>'):
|
| 195 |
+
name = name[1:].strip()
|
| 196 |
+
|
| 197 |
+
sequence = ''.join(filter(str.isalpha, sequence))
|
| 198 |
+
|
| 199 |
+
dna_chars = sum(1 for c in sequence.upper() if c in 'ATCGN')
|
| 200 |
+
is_dna = (dna_chars / len(sequence)) > 0.95 if len(sequence) > 0 else False
|
| 201 |
+
|
| 202 |
+
if is_dna:
|
| 203 |
+
longest_orf = find_longest_orf(sequence)
|
| 204 |
+
|
| 205 |
+
if longest_orf and len(longest_orf) >= 30:
|
| 206 |
+
original_dna = longest_orf
|
| 207 |
+
protein_seq = translate_dna_to_protein(longest_orf)
|
| 208 |
+
else:
|
| 209 |
+
truncated_len = (len(sequence) // 3) * 3
|
| 210 |
+
if truncated_len >= 30:
|
| 211 |
+
original_dna = sequence[:truncated_len]
|
| 212 |
+
protein_seq = translate_dna_to_protein(original_dna)
|
| 213 |
+
else:
|
| 214 |
+
continue
|
| 215 |
+
|
| 216 |
+
if '*' in protein_seq:
|
| 217 |
+
stop_pos = protein_seq.find('*')
|
| 218 |
+
if stop_pos >= 10:
|
| 219 |
+
protein_seq = protein_seq[:stop_pos]
|
| 220 |
+
original_dna = original_dna[:stop_pos*3]
|
| 221 |
+
else:
|
| 222 |
+
continue
|
| 223 |
+
|
| 224 |
+
else:
|
| 225 |
+
protein_seq = sequence.upper()
|
| 226 |
+
protein_seq = protein_seq.replace('*', '')
|
| 227 |
+
original_dna = None
|
| 228 |
+
|
| 229 |
+
if len(protein_seq) < 10:
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
sequences.append({
|
| 233 |
+
'id': idx,
|
| 234 |
+
'name': name,
|
| 235 |
+
'protein_sequence': protein_seq,
|
| 236 |
+
'original_sequence': original_dna,
|
| 237 |
+
'is_dna': is_dna
|
| 238 |
+
})
|
| 239 |
+
|
| 240 |
+
return sequences
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def calculate_cfd(dna_sequence: str, codon_frequencies: Dict) -> float:
|
| 244 |
+
"""
|
| 245 |
+
Calculate Codon Frequency Distribution (CFD) similarity to a reference.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
dna_sequence (str): Input DNA sequence.
|
| 249 |
+
codon_frequencies (Dict): Reference frequencies; accepts flattened mapping
|
| 250 |
+
or an amino2codon structure (will be flattened).
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
float: Similarity score in [0, 1] where higher is more similar.
|
| 254 |
+
"""
|
| 255 |
+
if not dna_sequence:
|
| 256 |
+
return 0.0
|
| 257 |
+
|
| 258 |
+
codon_count = {}
|
| 259 |
+
total_codons = 0
|
| 260 |
+
|
| 261 |
+
for i in range(0, len(dna_sequence) - 2, 3):
|
| 262 |
+
codon = dna_sequence[i:i+3].upper()
|
| 263 |
+
if len(codon) == 3:
|
| 264 |
+
codon_count[codon] = codon_count.get(codon, 0) + 1
|
| 265 |
+
total_codons += 1
|
| 266 |
+
|
| 267 |
+
seq_freq = {}
|
| 268 |
+
if total_codons > 0:
|
| 269 |
+
for codon, count in codon_count.items():
|
| 270 |
+
seq_freq[codon] = count / total_codons
|
| 271 |
+
|
| 272 |
+
# Flatten amino2codon frequencies if needed
|
| 273 |
+
flat_codon_freq = {}
|
| 274 |
+
if isinstance(codon_frequencies, dict):
|
| 275 |
+
first_key = next(iter(codon_frequencies.keys()))
|
| 276 |
+
if isinstance(codon_frequencies[first_key], tuple) and len(codon_frequencies[first_key]) == 2:
|
| 277 |
+
for amino, (codons, freqs) in codon_frequencies.items():
|
| 278 |
+
for codon, freq in zip(codons, freqs):
|
| 279 |
+
flat_codon_freq[codon] = freq
|
| 280 |
+
else:
|
| 281 |
+
flat_codon_freq = codon_frequencies
|
| 282 |
+
|
| 283 |
+
similarity = 0.0
|
| 284 |
+
count = 0
|
| 285 |
+
|
| 286 |
+
for codon in set(list(seq_freq.keys()) + list(flat_codon_freq.keys())):
|
| 287 |
+
seq_f = seq_freq.get(codon, 0.0)
|
| 288 |
+
ref_f = flat_codon_freq.get(codon, 0.0)
|
| 289 |
+
similarity += 1 - abs(seq_f - ref_f)
|
| 290 |
+
count += 1
|
| 291 |
+
|
| 292 |
+
return similarity / count if count > 0 else 0.0
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def run_model_on_sequences(
|
| 296 |
+
sequences: List[Dict],
|
| 297 |
+
model,
|
| 298 |
+
tokenizer,
|
| 299 |
+
device,
|
| 300 |
+
cai_weights: Dict,
|
| 301 |
+
tai_weights: Dict,
|
| 302 |
+
codon_frequencies: Dict,
|
| 303 |
+
reference_profile: List[float],
|
| 304 |
+
output_dir: str
|
| 305 |
+
) -> pd.DataFrame:
|
| 306 |
+
"""
|
| 307 |
+
Run ColiFormer on protein sequences and compute metrics for optimized DNA.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
sequences (List[Dict]): Parsed sequence records.
|
| 311 |
+
model: Loaded ColiFormer model.
|
| 312 |
+
tokenizer: Tokenizer used by the model.
|
| 313 |
+
device: Torch device.
|
| 314 |
+
cai_weights (Dict): CAI weights.
|
| 315 |
+
tai_weights (Dict): tAI weights.
|
| 316 |
+
codon_frequencies (Dict): Reference codon frequencies.
|
| 317 |
+
reference_profile (List[float]): Reserved for DTW profile (unused here).
|
| 318 |
+
output_dir (str): Directory for outputs (not written here).
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
pd.DataFrame: Per-sequence metrics and optimized DNA.
|
| 322 |
+
"""
|
| 323 |
+
results = []
|
| 324 |
+
print(f"Processing {len(sequences)} sequences...")
|
| 325 |
+
|
| 326 |
+
for seq_data in tqdm(sequences, desc="Optimizing sequences"):
|
| 327 |
+
protein_seq = seq_data['protein_sequence']
|
| 328 |
+
|
| 329 |
+
if len(protein_seq) < 10:
|
| 330 |
+
continue
|
| 331 |
+
|
| 332 |
+
try:
|
| 333 |
+
start_time = time.time()
|
| 334 |
+
|
| 335 |
+
output = predict_dna_sequence(
|
| 336 |
+
protein=protein_seq,
|
| 337 |
+
organism="Escherichia coli general",
|
| 338 |
+
device=device,
|
| 339 |
+
model=model,
|
| 340 |
+
deterministic=True,
|
| 341 |
+
match_protein=True,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
runtime = time.time() - start_time
|
| 345 |
+
|
| 346 |
+
if isinstance(output, list):
|
| 347 |
+
optimized_dna = output[0].predicted_dna
|
| 348 |
+
else:
|
| 349 |
+
optimized_dna = output.predicted_dna
|
| 350 |
+
|
| 351 |
+
original_metrics = {}
|
| 352 |
+
if seq_data['is_dna'] and seq_data['original_sequence']:
|
| 353 |
+
original_dna = seq_data['original_sequence'].upper()
|
| 354 |
+
original_metrics = {
|
| 355 |
+
'original_cai': CAI(original_dna, weights=cai_weights),
|
| 356 |
+
'original_gc': get_GC_content(original_dna),
|
| 357 |
+
'original_tai': calculate_tAI(original_dna, tai_weights),
|
| 358 |
+
'original_cfd': calculate_cfd(original_dna, codon_frequencies),
|
| 359 |
+
'original_neg_cis': count_negative_cis_elements(original_dna),
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
optimized_metrics = {
|
| 363 |
+
'optimized_cai': CAI(optimized_dna, weights=cai_weights),
|
| 364 |
+
'optimized_gc': get_GC_content(optimized_dna),
|
| 365 |
+
'optimized_tai': calculate_tAI(optimized_dna, tai_weights),
|
| 366 |
+
'optimized_cfd': calculate_cfd(optimized_dna, codon_frequencies),
|
| 367 |
+
'optimized_neg_cis': count_negative_cis_elements(optimized_dna),
|
| 368 |
+
'runtime': runtime,
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
result = {
|
| 372 |
+
'id': seq_data['id'],
|
| 373 |
+
'name': seq_data['name'],
|
| 374 |
+
'protein_sequence': protein_seq,
|
| 375 |
+
'protein_length': len(protein_seq),
|
| 376 |
+
'optimized_dna': optimized_dna,
|
| 377 |
+
**original_metrics,
|
| 378 |
+
**optimized_metrics,
|
| 379 |
+
}
|
| 380 |
+
results.append(result)
|
| 381 |
+
|
| 382 |
+
except Exception as e:
|
| 383 |
+
print(f"Error processing sequence {seq_data['id']}: {str(e)}")
|
| 384 |
+
continue
|
| 385 |
+
|
| 386 |
+
return pd.DataFrame(results)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def generate_visualizations(results_df: pd.DataFrame, output_dir: str):
|
| 390 |
+
"""
|
| 391 |
+
Generate visualizations and a metrics summary table.
|
| 392 |
+
|
| 393 |
+
Saves:
|
| 394 |
+
- CAI before/after bar plot
|
| 395 |
+
- Median CAI comparison
|
| 396 |
+
- Metrics distribution panel
|
| 397 |
+
- CSV summary table
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
results_df (pd.DataFrame): Results from optimization.
|
| 401 |
+
output_dir (str): Output directory root.
|
| 402 |
+
|
| 403 |
+
Returns:
|
| 404 |
+
pd.DataFrame: Summary table of aggregate metrics.
|
| 405 |
+
"""
|
| 406 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
| 407 |
+
sns.set_palette("husl")
|
| 408 |
+
|
| 409 |
+
fig_dir = os.path.join(output_dir, 'figures')
|
| 410 |
+
os.makedirs(fig_dir, exist_ok=True)
|
| 411 |
+
|
| 412 |
+
# 1. Before/After CAI Graph
|
| 413 |
+
if 'original_cai' in results_df.columns:
|
| 414 |
+
plt.figure(figsize=(12, 8))
|
| 415 |
+
|
| 416 |
+
before_cai = results_df['original_cai'].dropna()
|
| 417 |
+
after_cai = results_df.loc[before_cai.index, 'optimized_cai']
|
| 418 |
+
|
| 419 |
+
x = np.arange(len(before_cai))
|
| 420 |
+
width = 0.35
|
| 421 |
+
|
| 422 |
+
fig, ax = plt.subplots(figsize=(14, 8))
|
| 423 |
+
bars1 = ax.bar(x - width/2, before_cai, width, label='Before Optimization', alpha=0.8)
|
| 424 |
+
bars2 = ax.bar(x + width/2, after_cai, width, label='After Optimization', alpha=0.8)
|
| 425 |
+
|
| 426 |
+
ax.set_xlabel('Sequence Index', fontsize=12)
|
| 427 |
+
ax.set_ylabel('CAI Score', fontsize=12)
|
| 428 |
+
ax.set_title('ENCOT: CAI Before and After Optimization', fontsize=14, fontweight='bold')
|
| 429 |
+
ax.set_xticks(x[::5]) # Show every 5th label
|
| 430 |
+
ax.set_xticklabels(x[::5])
|
| 431 |
+
ax.legend()
|
| 432 |
+
ax.grid(axis='y', alpha=0.3)
|
| 433 |
+
|
| 434 |
+
avg_before = before_cai.mean()
|
| 435 |
+
avg_after = after_cai.mean()
|
| 436 |
+
improvement = ((avg_after - avg_before) / avg_before) * 100
|
| 437 |
+
|
| 438 |
+
ax.text(0.02, 0.98, f'Average CAI Before: {avg_before:.3f}\nAverage CAI After: {avg_after:.3f}\nImprovement: {improvement:.1f}%',
|
| 439 |
+
transform=ax.transAxes, fontsize=10, verticalalignment='top',
|
| 440 |
+
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
| 441 |
+
|
| 442 |
+
plt.tight_layout()
|
| 443 |
+
plt.savefig(os.path.join(fig_dir, 'cai_before_after.png'), dpi=300, bbox_inches='tight')
|
| 444 |
+
plt.close()
|
| 445 |
+
|
| 446 |
+
print(f"CAI Before/After graph saved to {os.path.join(fig_dir, 'cai_before_after.png')}")
|
| 447 |
+
|
| 448 |
+
# 1b. Median CAI Before/After Graph
|
| 449 |
+
plt.figure(figsize=(8, 6))
|
| 450 |
+
|
| 451 |
+
median_before = before_cai.median()
|
| 452 |
+
median_after = after_cai.median()
|
| 453 |
+
|
| 454 |
+
categories = ['Before Optimization', 'After Optimization']
|
| 455 |
+
medians = [median_before, median_after]
|
| 456 |
+
colors = ['#ff7f0e', '#2ca02c']
|
| 457 |
+
|
| 458 |
+
bars = plt.bar(categories, medians, color=colors, alpha=0.8, width=0.6)
|
| 459 |
+
plt.ylabel('Median CAI Score', fontsize=12)
|
| 460 |
+
plt.title('ENCOT: Median CAI Before and After Optimization', fontsize=14, fontweight='bold')
|
| 461 |
+
plt.ylim(0, max(medians) * 1.2)
|
| 462 |
+
|
| 463 |
+
for bar, median in zip(bars, medians):
|
| 464 |
+
plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
|
| 465 |
+
f'{median:.3f}', ha='center', va='bottom', fontweight='bold')
|
| 466 |
+
|
| 467 |
+
improvement_pct = ((median_after - median_before) / median_before) * 100
|
| 468 |
+
plt.text(0.5, max(medians) * 0.95, f'Improvement: {improvement_pct:.1f}%',
|
| 469 |
+
ha='center', transform=plt.gca().transData, fontsize=12,
|
| 470 |
+
bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
|
| 471 |
+
|
| 472 |
+
plt.grid(axis='y', alpha=0.3)
|
| 473 |
+
plt.tight_layout()
|
| 474 |
+
plt.savefig(os.path.join(fig_dir, 'median_cai_comparison.png'), dpi=300, bbox_inches='tight')
|
| 475 |
+
plt.close()
|
| 476 |
+
|
| 477 |
+
print(f"Median CAI comparison graph saved to {os.path.join(fig_dir, 'median_cai_comparison.png')}")
|
| 478 |
+
|
| 479 |
+
# 2. Summary metrics table
|
| 480 |
+
metrics_summary = {}
|
| 481 |
+
|
| 482 |
+
if 'original_cai' in results_df.columns:
|
| 483 |
+
metrics_summary['CAI'] = {
|
| 484 |
+
'Before': results_df['original_cai'].mean(),
|
| 485 |
+
'After': results_df['optimized_cai'].mean(),
|
| 486 |
+
'Improvement': ((results_df['optimized_cai'].mean() - results_df['original_cai'].mean()) / results_df['original_cai'].mean()) * 100
|
| 487 |
+
}
|
| 488 |
+
metrics_summary['GC Content (%)'] = {
|
| 489 |
+
'Before': results_df['original_gc'].mean(),
|
| 490 |
+
'After': results_df['optimized_gc'].mean(),
|
| 491 |
+
'Difference': results_df['optimized_gc'].mean() - results_df['original_gc'].mean()
|
| 492 |
+
}
|
| 493 |
+
metrics_summary['tAI'] = {
|
| 494 |
+
'Before': results_df['original_tai'].mean(),
|
| 495 |
+
'After': results_df['optimized_tai'].mean(),
|
| 496 |
+
'Improvement': ((results_df['optimized_tai'].mean() - results_df['original_tai'].mean()) / results_df['original_tai'].mean()) * 100
|
| 497 |
+
}
|
| 498 |
+
metrics_summary['CFD'] = {
|
| 499 |
+
'Before': results_df['original_cfd'].mean(),
|
| 500 |
+
'After': results_df['optimized_cfd'].mean(),
|
| 501 |
+
'Improvement': ((results_df['optimized_cfd'].mean() - results_df['original_cfd'].mean()) / results_df['original_cfd'].mean()) * 100
|
| 502 |
+
}
|
| 503 |
+
metrics_summary['Negative Cis Elements'] = {
|
| 504 |
+
'Before': results_df['original_neg_cis'].mean(),
|
| 505 |
+
'After': results_df['optimized_neg_cis'].mean(),
|
| 506 |
+
'Reduction': results_df['original_neg_cis'].mean() - results_df['optimized_neg_cis'].mean()
|
| 507 |
+
}
|
| 508 |
+
else:
|
| 509 |
+
metrics_summary['CAI'] = {
|
| 510 |
+
'Optimized': results_df['optimized_cai'].mean(),
|
| 511 |
+
'Std Dev': results_df['optimized_cai'].std()
|
| 512 |
+
}
|
| 513 |
+
metrics_summary['GC Content (%)'] = {
|
| 514 |
+
'Optimized': results_df['optimized_gc'].mean(),
|
| 515 |
+
'Std Dev': results_df['optimized_gc'].std()
|
| 516 |
+
}
|
| 517 |
+
metrics_summary['tAI'] = {
|
| 518 |
+
'Optimized': results_df['optimized_tai'].mean(),
|
| 519 |
+
'Std Dev': results_df['optimized_tai'].std()
|
| 520 |
+
}
|
| 521 |
+
metrics_summary['CFD'] = {
|
| 522 |
+
'Optimized': results_df['optimized_cfd'].mean(),
|
| 523 |
+
'Std Dev': results_df['optimized_cfd'].std()
|
| 524 |
+
}
|
| 525 |
+
metrics_summary['Negative Cis Elements'] = {
|
| 526 |
+
'Optimized': results_df['optimized_neg_cis'].mean(),
|
| 527 |
+
'Std Dev': results_df['optimized_neg_cis'].std()
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
metrics_summary['Runtime (seconds)'] = {
|
| 531 |
+
'Mean': results_df['runtime'].mean(),
|
| 532 |
+
'Median': results_df['runtime'].median(),
|
| 533 |
+
'Total': results_df['runtime'].sum()
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
summary_df = pd.DataFrame(metrics_summary).T
|
| 537 |
+
summary_df = summary_df.round(4)
|
| 538 |
+
|
| 539 |
+
summary_df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'))
|
| 540 |
+
print(f"\nMetrics Summary saved to {os.path.join(output_dir, 'metrics_summary.csv')}")
|
| 541 |
+
print("\n" + "="*60)
|
| 542 |
+
print("METRICS SUMMARY:")
|
| 543 |
+
print("="*60)
|
| 544 |
+
print(summary_df.to_string())
|
| 545 |
+
|
| 546 |
+
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
|
| 547 |
+
axes = axes.flatten()
|
| 548 |
+
|
| 549 |
+
metrics_to_plot = [
|
| 550 |
+
('optimized_cai', 'CAI Distribution'),
|
| 551 |
+
('optimized_gc', 'GC Content Distribution (%)'),
|
| 552 |
+
('optimized_tai', 'tAI Distribution'),
|
| 553 |
+
('optimized_cfd', 'CFD Distribution'),
|
| 554 |
+
('optimized_neg_cis', 'Negative Cis Elements'),
|
| 555 |
+
('runtime', 'Runtime Distribution (seconds)')
|
| 556 |
+
]
|
| 557 |
+
|
| 558 |
+
for idx, (col, title) in enumerate(metrics_to_plot):
|
| 559 |
+
if col in results_df.columns:
|
| 560 |
+
axes[idx].hist(results_df[col].dropna(), bins=20, edgecolor='black', alpha=0.7)
|
| 561 |
+
axes[idx].set_title(title, fontsize=10, fontweight='bold')
|
| 562 |
+
axes[idx].set_xlabel(col.replace('optimized_', '').replace('_', ' ').title())
|
| 563 |
+
axes[idx].set_ylabel('Frequency')
|
| 564 |
+
axes[idx].grid(axis='y', alpha=0.3)
|
| 565 |
+
|
| 566 |
+
mean_val = results_df[col].mean()
|
| 567 |
+
axes[idx].axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.3f}')
|
| 568 |
+
axes[idx].legend()
|
| 569 |
+
|
| 570 |
+
plt.suptitle('ENCOT: Optimization Metrics Distribution', fontsize=14, fontweight='bold', y=1.02)
|
| 571 |
+
plt.tight_layout()
|
| 572 |
+
plt.savefig(os.path.join(fig_dir, 'metrics_distribution.png'), dpi=300, bbox_inches='tight')
|
| 573 |
+
plt.close()
|
| 574 |
+
|
| 575 |
+
print(f"Metrics distribution plot saved to {os.path.join(fig_dir, 'metrics_distribution.png')}")
|
| 576 |
+
|
| 577 |
+
return summary_df
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def main():
|
| 581 |
+
"""CLI entrypoint to run the ENCOT benchmark workflow."""
|
| 582 |
+
parser = argparse.ArgumentParser(description="Benchmark ENCOT on E. coli sequences")
|
| 583 |
+
parser.add_argument("--excel_path", type=str, default="Benchmark 80 sequences.xlsx",
|
| 584 |
+
help="Path to benchmark Excel file")
|
| 585 |
+
parser.add_argument("--checkpoint_path", type=str, default="models/ecoli-codon-optimizer/finetune_best.ckpt",
|
| 586 |
+
help="Path to fine-tuned model checkpoint")
|
| 587 |
+
parser.add_argument("--natural_sequences_path", type=str, default="data/ecoli_processed_genes.csv",
|
| 588 |
+
help="Path to natural E. coli sequences for CAI calculation")
|
| 589 |
+
parser.add_argument("--output_dir", type=str, default="benchmark_results",
|
| 590 |
+
help="Directory to save results")
|
| 591 |
+
parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
|
| 592 |
+
parser.add_argument("--name_col", type=str, default=None, help="Optional: column name for sequence label (case-insensitive)")
|
| 593 |
+
parser.add_argument("--seq_col", type=str, default=None, help="Optional: column name for sequence (case-insensitive)")
|
| 594 |
+
parser.add_argument("--sheet_name", type=str, default=None, help="Optional: Excel sheet name or index")
|
| 595 |
+
|
| 596 |
+
args = parser.parse_args()
|
| 597 |
+
|
| 598 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 599 |
+
output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
|
| 600 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 601 |
+
|
| 602 |
+
print("="*60)
|
| 603 |
+
print("ENCOT BENCHMARK EVALUATION")
|
| 604 |
+
print("="*60)
|
| 605 |
+
|
| 606 |
+
device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
|
| 607 |
+
print(f"Using device: {device}")
|
| 608 |
+
|
| 609 |
+
print(f"\nLoading sequences from {args.excel_path}...")
|
| 610 |
+
sequences = parse_excel_sequences(
|
| 611 |
+
args.excel_path,
|
| 612 |
+
name_col=args.name_col,
|
| 613 |
+
seq_col=args.seq_col,
|
| 614 |
+
sheet_name=args.sheet_name,
|
| 615 |
+
)
|
| 616 |
+
print(f"Loaded {len(sequences)} sequences")
|
| 617 |
+
|
| 618 |
+
print("\nLoading ENCOT model...")
|
| 619 |
+
model = load_model(model_path=args.checkpoint_path, device=device)
|
| 620 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 621 |
+
print("Model loaded successfully")
|
| 622 |
+
|
| 623 |
+
print("\nPreparing evaluation utilities...")
|
| 624 |
+
|
| 625 |
+
natural_df = pd.read_csv(args.natural_sequences_path)
|
| 626 |
+
ref_sequences = natural_df['dna_sequence'].tolist()
|
| 627 |
+
cai_weights = relative_adaptiveness(sequences=ref_sequences)
|
| 628 |
+
print("CAI weights generated")
|
| 629 |
+
|
| 630 |
+
tai_weights = get_ecoli_tai_weights()
|
| 631 |
+
print("tAI weights loaded")
|
| 632 |
+
|
| 633 |
+
try:
|
| 634 |
+
codon_frequencies = download_codon_frequencies_from_kazusa(taxonomy_id=83333)
|
| 635 |
+
print("Codon frequencies loaded from Kazusa")
|
| 636 |
+
except Exception as e:
|
| 637 |
+
print(f"Warning: Kazusa download failed ({e}). Using local frequencies.")
|
| 638 |
+
codon_frequencies = get_codon_frequencies(
|
| 639 |
+
ref_sequences, organism="Escherichia coli general"
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
reference_profile = []
|
| 643 |
+
|
| 644 |
+
print("\n" + "="*60)
|
| 645 |
+
print("RUNNING OPTIMIZATION...")
|
| 646 |
+
print("="*60)
|
| 647 |
+
|
| 648 |
+
results_df = run_model_on_sequences(
|
| 649 |
+
sequences=sequences,
|
| 650 |
+
model=model,
|
| 651 |
+
tokenizer=tokenizer,
|
| 652 |
+
device=device,
|
| 653 |
+
cai_weights=cai_weights,
|
| 654 |
+
tai_weights=tai_weights,
|
| 655 |
+
codon_frequencies=codon_frequencies,
|
| 656 |
+
reference_profile=reference_profile,
|
| 657 |
+
output_dir=output_dir
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
results_path = os.path.join(output_dir, 'optimization_results.csv')
|
| 661 |
+
results_df.to_csv(results_path, index=False)
|
| 662 |
+
print(f"\nRaw results saved to {results_path}")
|
| 663 |
+
|
| 664 |
+
optimized_sequences = results_df[['id', 'name', 'protein_sequence', 'optimized_dna']].copy()
|
| 665 |
+
optimized_sequences['protein_length'] = results_df['protein_length']
|
| 666 |
+
optimized_sequences['dna_length'] = optimized_sequences['optimized_dna'].apply(len)
|
| 667 |
+
optimized_sequences['optimized_cai'] = results_df['optimized_cai']
|
| 668 |
+
optimized_sequences['optimized_gc'] = results_df['optimized_gc']
|
| 669 |
+
optimized_sequences['optimized_tai'] = results_df['optimized_tai']
|
| 670 |
+
|
| 671 |
+
if 'original_cai' in results_df.columns:
|
| 672 |
+
optimized_sequences['original_cai'] = results_df['original_cai']
|
| 673 |
+
optimized_sequences['cai_improvement'] = ((results_df['optimized_cai'] - results_df['original_cai']) / results_df['original_cai'] * 100).round(2)
|
| 674 |
+
|
| 675 |
+
optimized_sequences_path = os.path.join(output_dir, 'optimized_dna_sequences.csv')
|
| 676 |
+
optimized_sequences.to_csv(optimized_sequences_path, index=False)
|
| 677 |
+
print(f"Optimized DNA sequences saved to {optimized_sequences_path}")
|
| 678 |
+
|
| 679 |
+
print("\n" + "="*60)
|
| 680 |
+
print("GENERATING VISUALIZATIONS...")
|
| 681 |
+
print("="*60)
|
| 682 |
+
|
| 683 |
+
summary_df = generate_visualizations(results_df, output_dir)
|
| 684 |
+
|
| 685 |
+
print("\n" + "="*60)
|
| 686 |
+
print("BENCHMARK EVALUATION COMPLETE")
|
| 687 |
+
print("="*60)
|
| 688 |
+
print(f"Results saved to: {output_dir}")
|
| 689 |
+
print(f"Total sequences processed: {len(results_df)}")
|
| 690 |
+
print(f"Average runtime per sequence: {results_df['runtime'].mean():.2f} seconds")
|
| 691 |
+
print(f"Total runtime: {results_df['runtime'].sum():.2f} seconds")
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
if __name__ == "__main__":
|
| 695 |
+
main()
|
comprehensive_model_comparison.png
ADDED
|
Git LFS Details
|
configs/train_ecoli_alm.yaml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ENCOT ALM Training Configuration
|
| 2 |
+
# This configuration reproduces the main training setup from the paper
|
| 3 |
+
# using the Augmented-Lagrangian Method (ALM) for GC content control.
|
| 4 |
+
|
| 5 |
+
model:
|
| 6 |
+
base_model: "adibvafa/CodonTransformer-base"
|
| 7 |
+
tokenizer: "adibvafa/CodonTransformer"
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
dataset_dir: "data"
|
| 11 |
+
# Expected files: finetune_set.json (created by preprocess_data.py)
|
| 12 |
+
|
| 13 |
+
training:
|
| 14 |
+
batch_size: 6
|
| 15 |
+
max_epochs: 15
|
| 16 |
+
learning_rate: 5e-5
|
| 17 |
+
warmup_fraction: 0.1
|
| 18 |
+
num_workers: 5
|
| 19 |
+
accumulate_grad_batches: 1
|
| 20 |
+
num_gpus: 4
|
| 21 |
+
save_every_n_steps: 512
|
| 22 |
+
seed: 123
|
| 23 |
+
log_every_n_steps: 20
|
| 24 |
+
|
| 25 |
+
checkpoint:
|
| 26 |
+
checkpoint_dir: "models/alm-enhanced-training"
|
| 27 |
+
checkpoint_filename: "balanced_alm_finetune.ckpt"
|
| 28 |
+
|
| 29 |
+
# Augmented-Lagrangian Method (ALM) for GC content control
|
| 30 |
+
alm:
|
| 31 |
+
enabled: true
|
| 32 |
+
gc_target: 0.52 # Target GC content for E. coli (52%)
|
| 33 |
+
curriculum_epochs: 3 # Warm-up epochs before enforcing GC constraint
|
| 34 |
+
|
| 35 |
+
# ALM penalty parameters
|
| 36 |
+
initial_penalty_factor: 20.0
|
| 37 |
+
penalty_update_factor: 10.0
|
| 38 |
+
max_penalty: 1e6
|
| 39 |
+
min_penalty: 1e-6
|
| 40 |
+
|
| 41 |
+
# ALM tolerance parameters
|
| 42 |
+
tolerance: 1e-5 # Primal tolerance
|
| 43 |
+
dual_tolerance: 1e-5 # Dual tolerance for constraint violation
|
| 44 |
+
tolerance_update_factor: 0.1
|
| 45 |
+
|
| 46 |
+
# Adaptive penalty adjustment
|
| 47 |
+
rel_penalty_increase_threshold: 0.1
|
| 48 |
+
|
| 49 |
+
# Legacy penalty method (if ALM disabled)
|
| 50 |
+
gc_penalty:
|
| 51 |
+
weight: 0.0 # Only used if use_lagrangian=false
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
configs/train_ecoli_quick.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ENCOT Quick Training Configuration
|
| 2 |
+
# This is a minimal configuration for quick sanity checks and testing.
|
| 3 |
+
# Use this to verify your setup before running full training.
|
| 4 |
+
|
| 5 |
+
model:
|
| 6 |
+
base_model: "adibvafa/CodonTransformer-base"
|
| 7 |
+
tokenizer: "adibvafa/CodonTransformer"
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
dataset_dir: "data"
|
| 11 |
+
|
| 12 |
+
training:
|
| 13 |
+
batch_size: 2
|
| 14 |
+
max_epochs: 1
|
| 15 |
+
learning_rate: 5e-5
|
| 16 |
+
warmup_fraction: 0.1
|
| 17 |
+
num_workers: 0 # Disable multiprocessing for debugging
|
| 18 |
+
accumulate_grad_batches: 1
|
| 19 |
+
num_gpus: 0 # CPU-only for quick testing
|
| 20 |
+
save_every_n_steps: 10
|
| 21 |
+
seed: 123
|
| 22 |
+
log_every_n_steps: 5
|
| 23 |
+
|
| 24 |
+
checkpoint:
|
| 25 |
+
checkpoint_dir: "models/test-training"
|
| 26 |
+
checkpoint_filename: "quick_test.ckpt"
|
| 27 |
+
|
| 28 |
+
alm:
|
| 29 |
+
enabled: false # Disable ALM for quick test
|
| 30 |
+
gc_target: 0.52
|
| 31 |
+
curriculum_epochs: 0
|
| 32 |
+
|
| 33 |
+
gc_penalty:
|
| 34 |
+
weight: 0.0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
create_model_datasets.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from CodonTransformer.CodonData import prepare_training_data
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
"""
|
| 8 |
+
Main function to partition the processed data into fine-tuning and test sets.
|
| 9 |
+
"""
|
| 10 |
+
if not os.path.exists('data'):
|
| 11 |
+
print("Error: 'data' directory not found. Please run prepare_ecoli_data.py first.")
|
| 12 |
+
return
|
| 13 |
+
|
| 14 |
+
processed_data_path = 'data/ecoli_processed_genes.csv'
|
| 15 |
+
if not os.path.exists(processed_data_path):
|
| 16 |
+
print(f"Error: Processed data file not found at {processed_data_path}")
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
+
df_processed = pd.read_csv(processed_data_path)
|
| 20 |
+
|
| 21 |
+
df_finetune = df_processed[df_processed['is_high_cai'] == True].copy()
|
| 22 |
+
df_finetune.drop_duplicates(subset=['dna_sequence'], inplace=True)
|
| 23 |
+
df_finetune.rename(columns={'dna_sequence': 'dna', 'protein_sequence': 'protein'}, inplace=True)
|
| 24 |
+
df_finetune['organism'] = "Escherichia coli general"
|
| 25 |
+
|
| 26 |
+
finetune_output_path = 'data/finetune_set.json'
|
| 27 |
+
prepare_training_data(df_finetune, finetune_output_path, shuffle=True)
|
| 28 |
+
print(f"Fine-tuning set saved to {finetune_output_path} with {len(df_finetune)} records.")
|
| 29 |
+
|
| 30 |
+
df_test_pool = df_processed[df_processed['is_high_cai'] == False].copy()
|
| 31 |
+
df_test = df_test_pool.sample(n=100, random_state=42) # for reproducibility
|
| 32 |
+
df_test['organism'] = 51 # E. coli general
|
| 33 |
+
df_test.rename(columns={'dna_sequence': 'codons'}, inplace=True)
|
| 34 |
+
test_records = df_test[['codons', 'organism']].to_dict(orient='records')
|
| 35 |
+
|
| 36 |
+
test_output_path = 'data/test_set.json'
|
| 37 |
+
with open(test_output_path, 'w') as f:
|
| 38 |
+
json.dump(test_records, f, indent=4)
|
| 39 |
+
print(f"Test set saved to {test_output_path} with {len(df_test)} records.")
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|
evaluate_optimizer.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
"""
|
| 3 |
+
File: evaluate_optimizer.py
|
| 4 |
+
---------------------------
|
| 5 |
+
Evaluate ColiFormer with enhanced capabilities:
|
| 6 |
+
1) DNAChisel post-processing for sequence polishing
|
| 7 |
+
2) Optional multi-objective generation (Pareto-style filtering)
|
| 8 |
+
3) Enhanced beam search with multiple candidates
|
| 9 |
+
4) Comprehensive metrics and optional ablation studies
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import warnings
|
| 16 |
+
from typing import Dict, List, Tuple, Any
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import torch
|
| 21 |
+
from CAI import CAI, relative_adaptiveness
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
from CodonTransformer.CodonData import (
|
| 25 |
+
download_codon_frequencies_from_kazusa,
|
| 26 |
+
get_codon_frequencies,
|
| 27 |
+
)
|
| 28 |
+
from CodonTransformer.CodonPrediction import (
|
| 29 |
+
load_model,
|
| 30 |
+
predict_dna_sequence,
|
| 31 |
+
get_high_frequency_choice_sequence_optimized,
|
| 32 |
+
)
|
| 33 |
+
from CodonTransformer.CodonEvaluation import (
|
| 34 |
+
calculate_dtw_distance,
|
| 35 |
+
calculate_homopolymer_runs,
|
| 36 |
+
calculate_tAI,
|
| 37 |
+
count_negative_cis_elements,
|
| 38 |
+
get_GC_content,
|
| 39 |
+
get_ecoli_tai_weights,
|
| 40 |
+
get_min_max_profile,
|
| 41 |
+
get_sequence_similarity,
|
| 42 |
+
scan_for_restriction_sites,
|
| 43 |
+
calculate_ENC,
|
| 44 |
+
calculate_CPB,
|
| 45 |
+
calculate_SCUO,
|
| 46 |
+
)
|
| 47 |
+
from CodonTransformer.CodonPostProcessing import (
|
| 48 |
+
polish_sequence_with_dnachisel,
|
| 49 |
+
)
|
| 50 |
+
from CodonTransformer.CodonUtils import DNASequencePrediction
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def translate_dna_to_protein(dna_sequence: str) -> str:
|
| 54 |
+
"""Translate DNA sequence to protein sequence."""
|
| 55 |
+
codon_table = {
|
| 56 |
+
'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
|
| 57 |
+
'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
|
| 58 |
+
'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
|
| 59 |
+
'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
|
| 60 |
+
'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
|
| 61 |
+
'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
|
| 62 |
+
'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
|
| 63 |
+
'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
|
| 64 |
+
'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
|
| 65 |
+
'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
|
| 66 |
+
'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
|
| 67 |
+
'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
|
| 68 |
+
'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
|
| 69 |
+
'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
|
| 70 |
+
'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
|
| 71 |
+
'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
protein = ""
|
| 75 |
+
for i in range(0, len(dna_sequence), 3):
|
| 76 |
+
codon = dna_sequence[i:i+3].upper()
|
| 77 |
+
if len(codon) == 3:
|
| 78 |
+
aa = codon_table.get(codon, 'X')
|
| 79 |
+
if aa == '*': # Stop codon
|
| 80 |
+
break
|
| 81 |
+
protein += aa
|
| 82 |
+
|
| 83 |
+
return protein
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def evaluate_with_enhancements(
|
| 87 |
+
protein_sequence: str,
|
| 88 |
+
model,
|
| 89 |
+
tokenizer,
|
| 90 |
+
device,
|
| 91 |
+
cai_weights: Dict[str, float],
|
| 92 |
+
tai_weights: Dict[str, float],
|
| 93 |
+
codon_frequencies: Dict,
|
| 94 |
+
reference_profile: List[float],
|
| 95 |
+
args,
|
| 96 |
+
) -> Dict[str, Any]:
|
| 97 |
+
"""
|
| 98 |
+
Evaluate a protein sequence with enhanced generation techniques.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
protein_sequence: Input protein sequence
|
| 102 |
+
model: Fine-tuned model
|
| 103 |
+
tokenizer: Model tokenizer
|
| 104 |
+
device: PyTorch device
|
| 105 |
+
cai_weights: CAI weights dictionary
|
| 106 |
+
tai_weights: tAI weights dictionary
|
| 107 |
+
codon_frequencies: Codon frequencies dictionary
|
| 108 |
+
reference_profile: Reference profile for DTW calculation
|
| 109 |
+
args: Command line arguments
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Dict containing evaluation results for all methods
|
| 113 |
+
"""
|
| 114 |
+
results = {}
|
| 115 |
+
|
| 116 |
+
# 1. Original fine-tuned model (baseline)
|
| 117 |
+
try:
|
| 118 |
+
original_output = predict_dna_sequence(
|
| 119 |
+
protein=protein_sequence,
|
| 120 |
+
organism="Escherichia coli general",
|
| 121 |
+
device=device,
|
| 122 |
+
model=model,
|
| 123 |
+
deterministic=True,
|
| 124 |
+
match_protein=True,
|
| 125 |
+
use_constrained_search=args.use_constrained_search,
|
| 126 |
+
gc_bounds=tuple(args.gc_bounds),
|
| 127 |
+
beam_size=args.beam_size,
|
| 128 |
+
length_penalty=args.length_penalty,
|
| 129 |
+
diversity_penalty=args.diversity_penalty,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if isinstance(original_output, list):
|
| 133 |
+
original_dna = original_output[0].predicted_dna
|
| 134 |
+
else:
|
| 135 |
+
original_dna = original_output.predicted_dna
|
| 136 |
+
|
| 137 |
+
results['fine_tuned_original'] = {
|
| 138 |
+
'dna_sequence': original_dna,
|
| 139 |
+
'method': 'fine_tuned_original',
|
| 140 |
+
'enhancement': 'none',
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"Warning: Original fine-tuned generation failed: {str(e)}")
|
| 145 |
+
results['fine_tuned_original'] = {
|
| 146 |
+
'dna_sequence': '',
|
| 147 |
+
'method': 'fine_tuned_original',
|
| 148 |
+
'enhancement': 'none',
|
| 149 |
+
'error': str(e),
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
# 2. Enhanced sequence generation (DNAChisel + Pareto filtering)
|
| 153 |
+
if args.use_enhanced_generation:
|
| 154 |
+
try:
|
| 155 |
+
enhanced_dna, generation_report = enhanced_sequence_generation(
|
| 156 |
+
protein_sequence=protein_sequence,
|
| 157 |
+
model=model,
|
| 158 |
+
tokenizer=tokenizer,
|
| 159 |
+
device=device,
|
| 160 |
+
beam_size=args.enhanced_beam_size,
|
| 161 |
+
gc_bounds=(args.gc_bounds[0] * 100, args.gc_bounds[1] * 100),
|
| 162 |
+
use_dnachisel_polish=args.use_dnachisel,
|
| 163 |
+
use_pareto_filtering=args.use_pareto_filtering,
|
| 164 |
+
cai_weights=cai_weights,
|
| 165 |
+
tai_weights=tai_weights,
|
| 166 |
+
codon_frequencies=codon_frequencies,
|
| 167 |
+
reference_profile=reference_profile,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
results['fine_tuned_enhanced'] = {
|
| 171 |
+
'dna_sequence': enhanced_dna,
|
| 172 |
+
'method': 'fine_tuned_enhanced',
|
| 173 |
+
'enhancement': 'dnachisel+pareto',
|
| 174 |
+
'generation_report': generation_report,
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f"Warning: Enhanced generation failed: {str(e)}")
|
| 179 |
+
results['fine_tuned_enhanced'] = {
|
| 180 |
+
'dna_sequence': '',
|
| 181 |
+
'method': 'fine_tuned_enhanced',
|
| 182 |
+
'enhancement': 'dnachisel+pareto',
|
| 183 |
+
'error': str(e),
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
# 3. DNAChisel post-processing only (ablation study)
|
| 187 |
+
if args.use_dnachisel and 'fine_tuned_original' in results and results['fine_tuned_original']['dna_sequence']:
|
| 188 |
+
try:
|
| 189 |
+
dnachisel_dna, polish_report = polish_sequence_with_dnachisel(
|
| 190 |
+
dna_sequence=results['fine_tuned_original']['dna_sequence'],
|
| 191 |
+
protein_sequence=protein_sequence,
|
| 192 |
+
gc_bounds=(args.gc_bounds[0] * 100, args.gc_bounds[1] * 100),
|
| 193 |
+
maximize_cai=True,
|
| 194 |
+
seed=42,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
results['fine_tuned_dnachisel'] = {
|
| 198 |
+
'dna_sequence': dnachisel_dna,
|
| 199 |
+
'method': 'fine_tuned_dnachisel',
|
| 200 |
+
'enhancement': 'dnachisel_only',
|
| 201 |
+
'polish_report': polish_report,
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
except Exception as e:
|
| 205 |
+
print(f"Warning: DNAChisel post-processing failed: {str(e)}")
|
| 206 |
+
results['fine_tuned_dnachisel'] = {
|
| 207 |
+
'dna_sequence': '',
|
| 208 |
+
'method': 'fine_tuned_dnachisel',
|
| 209 |
+
'enhancement': 'dnachisel_only',
|
| 210 |
+
'error': str(e),
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
return results
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def calculate_comprehensive_metrics(
|
| 217 |
+
dna_sequence: str,
|
| 218 |
+
protein_sequence: str,
|
| 219 |
+
cai_weights: Dict[str, float],
|
| 220 |
+
tai_weights: Dict[str, float],
|
| 221 |
+
codon_frequencies: Dict,
|
| 222 |
+
reference_profile: List[float],
|
| 223 |
+
ref_sequences: List[str],
|
| 224 |
+
) -> Dict[str, float]:
|
| 225 |
+
"""Calculate comprehensive metrics for a DNA sequence."""
|
| 226 |
+
if not dna_sequence:
|
| 227 |
+
return {
|
| 228 |
+
'cai': 0.0,
|
| 229 |
+
'tai': 0.0,
|
| 230 |
+
'gc_content': 0.0,
|
| 231 |
+
'restriction_sites': float('inf'),
|
| 232 |
+
'neg_cis_elements': float('inf'),
|
| 233 |
+
'homopolymer_runs': float('inf'),
|
| 234 |
+
'dtw_distance': float('inf'),
|
| 235 |
+
'enc': 0.0,
|
| 236 |
+
'cpb': 0.0,
|
| 237 |
+
'scuo': 0.0,
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
return calculate_sequence_metrics(
|
| 241 |
+
dna_sequence=dna_sequence,
|
| 242 |
+
protein_sequence=protein_sequence,
|
| 243 |
+
cai_weights=cai_weights,
|
| 244 |
+
tai_weights=tai_weights,
|
| 245 |
+
codon_frequencies=codon_frequencies,
|
| 246 |
+
reference_profile=reference_profile,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def run_ablation_study(results_df: pd.DataFrame) -> pd.DataFrame:
|
| 251 |
+
"""
|
| 252 |
+
Run ablation study to compare different enhancement methods.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
results_df: DataFrame with evaluation results
|
| 256 |
+
|
| 257 |
+
Returns:
|
| 258 |
+
DataFrame with ablation study results
|
| 259 |
+
"""
|
| 260 |
+
# Group by protein and calculate improvements
|
| 261 |
+
ablation_results = []
|
| 262 |
+
|
| 263 |
+
for protein in results_df['protein_sequence'].unique():
|
| 264 |
+
protein_results = results_df[results_df['protein_sequence'] == protein]
|
| 265 |
+
|
| 266 |
+
# Get baseline (original fine-tuned)
|
| 267 |
+
baseline = protein_results[protein_results['method'] == 'fine_tuned_original']
|
| 268 |
+
if baseline.empty:
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
baseline_metrics = baseline.iloc[0]
|
| 272 |
+
|
| 273 |
+
# Compare each enhancement method
|
| 274 |
+
for method in protein_results['method'].unique():
|
| 275 |
+
if method == 'fine_tuned_original':
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
+
method_results = protein_results[protein_results['method'] == method]
|
| 279 |
+
if method_results.empty:
|
| 280 |
+
continue
|
| 281 |
+
|
| 282 |
+
method_metrics = method_results.iloc[0]
|
| 283 |
+
|
| 284 |
+
# Calculate improvements
|
| 285 |
+
improvements = {
|
| 286 |
+
'protein': protein,
|
| 287 |
+
'method': method,
|
| 288 |
+
'enhancement': method_metrics['enhancement'],
|
| 289 |
+
'cai_improvement': method_metrics['cai'] - baseline_metrics['cai'],
|
| 290 |
+
'tai_improvement': method_metrics['tai'] - baseline_metrics['tai'],
|
| 291 |
+
'gc_improvement': abs(method_metrics['gc_content'] - 52) - abs(baseline_metrics['gc_content'] - 52),
|
| 292 |
+
'restriction_sites_improvement': baseline_metrics['restriction_sites'] - method_metrics['restriction_sites'],
|
| 293 |
+
'neg_cis_improvement': baseline_metrics['neg_cis_elements'] - method_metrics['neg_cis_elements'],
|
| 294 |
+
'homopolymer_improvement': baseline_metrics['homopolymer_runs'] - method_metrics['homopolymer_runs'],
|
| 295 |
+
'dtw_improvement': baseline_metrics['dtw_distance'] - method_metrics['dtw_distance'],
|
| 296 |
+
'composite_score_improvement': (
|
| 297 |
+
(method_metrics['cai'] - baseline_metrics['cai']) * 0.3 +
|
| 298 |
+
(method_metrics['tai'] - baseline_metrics['tai']) * 0.3 +
|
| 299 |
+
(abs(baseline_metrics['gc_content'] - 52) - abs(method_metrics['gc_content'] - 52)) * 0.2 +
|
| 300 |
+
(baseline_metrics['restriction_sites'] - method_metrics['restriction_sites']) * 0.1 +
|
| 301 |
+
(baseline_metrics['neg_cis_elements'] - method_metrics['neg_cis_elements']) * 0.1
|
| 302 |
+
),
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
ablation_results.append(improvements)
|
| 306 |
+
|
| 307 |
+
return pd.DataFrame(ablation_results)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def main(args):
|
| 311 |
+
"""Main function to run the enhanced evaluation."""
|
| 312 |
+
print("=== Enhanced CodonTransformer Evaluation ===")
|
| 313 |
+
|
| 314 |
+
# Setup device
|
| 315 |
+
device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
|
| 316 |
+
print(f"Using device: {device}")
|
| 317 |
+
|
| 318 |
+
# Load test data
|
| 319 |
+
with open(args.test_data_path, "r") as f:
|
| 320 |
+
first = f.read(1)
|
| 321 |
+
f.seek(0)
|
| 322 |
+
if first == "[":
|
| 323 |
+
test_set = json.load(f)
|
| 324 |
+
else:
|
| 325 |
+
test_set = [json.loads(line) for line in f if line.strip()]
|
| 326 |
+
|
| 327 |
+
# Limit test set size if requested
|
| 328 |
+
if args.max_test_proteins > 0:
|
| 329 |
+
test_set = test_set[:args.max_test_proteins]
|
| 330 |
+
|
| 331 |
+
print(f"Loaded {len(test_set)} proteins from the test set.")
|
| 332 |
+
|
| 333 |
+
# Load models
|
| 334 |
+
print("Loading models...")
|
| 335 |
+
finetuned_model = load_model(model_path=args.checkpoint_path, device=device)
|
| 336 |
+
print(f"Fine-tuned model loaded from {args.checkpoint_path}")
|
| 337 |
+
|
| 338 |
+
# Load tokenizer
|
| 339 |
+
from transformers import AutoTokenizer
|
| 340 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 341 |
+
|
| 342 |
+
# Load base model if comparison requested
|
| 343 |
+
base_model = None
|
| 344 |
+
if args.compare_with_base:
|
| 345 |
+
base_model = load_model(device=device)
|
| 346 |
+
print("Base model loaded from Hugging Face")
|
| 347 |
+
|
| 348 |
+
# Prepare evaluation utilities
|
| 349 |
+
print("Preparing evaluation utilities...")
|
| 350 |
+
|
| 351 |
+
# CAI weights
|
| 352 |
+
natural_csv = args.natural_sequences_path
|
| 353 |
+
natural_df = pd.read_csv(natural_csv)
|
| 354 |
+
ref_sequences = natural_df['dna_sequence'].tolist()
|
| 355 |
+
cai_weights = relative_adaptiveness(sequences=ref_sequences)
|
| 356 |
+
print("CAI weights generated")
|
| 357 |
+
|
| 358 |
+
# tAI weights
|
| 359 |
+
tai_weights = get_ecoli_tai_weights()
|
| 360 |
+
print("tAI weights loaded")
|
| 361 |
+
|
| 362 |
+
# Codon frequencies
|
| 363 |
+
try:
|
| 364 |
+
codon_frequencies = download_codon_frequencies_from_kazusa(taxonomy_id=83333)
|
| 365 |
+
print("Codon frequencies loaded from Kazusa")
|
| 366 |
+
except Exception as e:
|
| 367 |
+
print(f"Warning: Kazusa download failed ({e}). Using local frequencies.")
|
| 368 |
+
codon_frequencies = get_codon_frequencies(
|
| 369 |
+
ref_sequences, organism="Escherichia coli general"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Reference profile for DTW
|
| 373 |
+
reference_profiles = [
|
| 374 |
+
get_min_max_profile(seq, codon_frequencies) for seq in ref_sequences[:100]
|
| 375 |
+
]
|
| 376 |
+
valid_profiles = [p for p in reference_profiles if p and not all(v is None for v in p)]
|
| 377 |
+
|
| 378 |
+
if valid_profiles:
|
| 379 |
+
max_len = max(len(p) for p in valid_profiles)
|
| 380 |
+
padded_profiles = [
|
| 381 |
+
np.pad(
|
| 382 |
+
np.array([v for v in p if v is not None]),
|
| 383 |
+
(0, max_len - len([v for v in p if v is not None])),
|
| 384 |
+
"constant",
|
| 385 |
+
constant_values=np.nan,
|
| 386 |
+
)
|
| 387 |
+
for p in valid_profiles
|
| 388 |
+
]
|
| 389 |
+
avg_reference_profile = np.nanmean(padded_profiles, axis=0).tolist()
|
| 390 |
+
else:
|
| 391 |
+
avg_reference_profile = []
|
| 392 |
+
|
| 393 |
+
print("Reference profile generated")
|
| 394 |
+
|
| 395 |
+
# Run evaluation
|
| 396 |
+
all_results = []
|
| 397 |
+
evaluation_reports = []
|
| 398 |
+
|
| 399 |
+
print("Starting enhanced evaluation...")
|
| 400 |
+
for i, item in enumerate(tqdm(test_set, desc="Evaluating proteins")):
|
| 401 |
+
# Get protein sequence
|
| 402 |
+
if "protein_sequence" in item:
|
| 403 |
+
protein_sequence = item["protein_sequence"]
|
| 404 |
+
else:
|
| 405 |
+
dna_sequence = item["codons"]
|
| 406 |
+
protein_sequence = translate_dna_to_protein(dna_sequence)
|
| 407 |
+
|
| 408 |
+
# Skip if protein is too short or too long
|
| 409 |
+
if len(protein_sequence) < 10 or len(protein_sequence) > 1000:
|
| 410 |
+
continue
|
| 411 |
+
|
| 412 |
+
# Evaluate with enhancements
|
| 413 |
+
protein_results = evaluate_with_enhancements(
|
| 414 |
+
protein_sequence=protein_sequence,
|
| 415 |
+
model=finetuned_model,
|
| 416 |
+
tokenizer=tokenizer,
|
| 417 |
+
device=device,
|
| 418 |
+
cai_weights=cai_weights,
|
| 419 |
+
tai_weights=tai_weights,
|
| 420 |
+
codon_frequencies=codon_frequencies,
|
| 421 |
+
reference_profile=avg_reference_profile,
|
| 422 |
+
args=args,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# Add base model comparison if requested
|
| 426 |
+
if base_model:
|
| 427 |
+
try:
|
| 428 |
+
base_output = predict_dna_sequence(
|
| 429 |
+
protein=protein_sequence,
|
| 430 |
+
organism="Escherichia coli general",
|
| 431 |
+
device=device,
|
| 432 |
+
model=base_model,
|
| 433 |
+
deterministic=True,
|
| 434 |
+
match_protein=True,
|
| 435 |
+
)
|
| 436 |
+
base_dna = base_output.predicted_dna if not isinstance(base_output, list) else base_output[0].predicted_dna
|
| 437 |
+
|
| 438 |
+
protein_results['base_model'] = {
|
| 439 |
+
'dna_sequence': base_dna,
|
| 440 |
+
'method': 'base_model',
|
| 441 |
+
'enhancement': 'none',
|
| 442 |
+
}
|
| 443 |
+
except Exception as e:
|
| 444 |
+
print(f"Warning: Base model generation failed: {str(e)}")
|
| 445 |
+
|
| 446 |
+
# Add naive baseline
|
| 447 |
+
try:
|
| 448 |
+
naive_dna = get_high_frequency_choice_sequence_optimized(
|
| 449 |
+
protein=protein_sequence, codon_frequencies=codon_frequencies
|
| 450 |
+
)
|
| 451 |
+
protein_results['naive_hfc'] = {
|
| 452 |
+
'dna_sequence': naive_dna,
|
| 453 |
+
'method': 'naive_hfc',
|
| 454 |
+
'enhancement': 'none',
|
| 455 |
+
}
|
| 456 |
+
except Exception as e:
|
| 457 |
+
print(f"Warning: Naive HFC generation failed: {str(e)}")
|
| 458 |
+
|
| 459 |
+
# Calculate metrics for each method
|
| 460 |
+
for method_name, method_result in protein_results.items():
|
| 461 |
+
if 'error' in method_result:
|
| 462 |
+
continue
|
| 463 |
+
|
| 464 |
+
dna_seq = method_result['dna_sequence']
|
| 465 |
+
if not dna_seq:
|
| 466 |
+
continue
|
| 467 |
+
|
| 468 |
+
metrics = calculate_comprehensive_metrics(
|
| 469 |
+
dna_sequence=dna_seq,
|
| 470 |
+
protein_sequence=protein_sequence,
|
| 471 |
+
cai_weights=cai_weights,
|
| 472 |
+
tai_weights=tai_weights,
|
| 473 |
+
codon_frequencies=codon_frequencies,
|
| 474 |
+
reference_profile=avg_reference_profile,
|
| 475 |
+
ref_sequences=ref_sequences,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# Combine results
|
| 479 |
+
result_row = {
|
| 480 |
+
'protein_id': i,
|
| 481 |
+
'protein_sequence': protein_sequence,
|
| 482 |
+
'protein_length': len(protein_sequence),
|
| 483 |
+
'method': method_name,
|
| 484 |
+
'enhancement': method_result['enhancement'],
|
| 485 |
+
'dna_sequence': dna_seq,
|
| 486 |
+
'dna_length': len(dna_seq),
|
| 487 |
+
**metrics,
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
# Add generation reports if available
|
| 491 |
+
if 'generation_report' in method_result:
|
| 492 |
+
result_row['generation_report'] = str(method_result['generation_report'])
|
| 493 |
+
if 'polish_report' in method_result:
|
| 494 |
+
result_row['polish_report'] = str(method_result['polish_report'])
|
| 495 |
+
|
| 496 |
+
all_results.append(result_row)
|
| 497 |
+
|
| 498 |
+
# Create results DataFrame
|
| 499 |
+
results_df = pd.DataFrame(all_results)
|
| 500 |
+
|
| 501 |
+
# Save detailed results
|
| 502 |
+
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
|
| 503 |
+
results_df.to_csv(args.output_path, index=False)
|
| 504 |
+
print(f"Detailed results saved to {args.output_path}")
|
| 505 |
+
|
| 506 |
+
# Run ablation study
|
| 507 |
+
if args.run_ablation_study:
|
| 508 |
+
ablation_df = run_ablation_study(results_df)
|
| 509 |
+
ablation_path = args.output_path.replace('.csv', '_ablation.csv')
|
| 510 |
+
ablation_df.to_csv(ablation_path, index=False)
|
| 511 |
+
print(f"Ablation study results saved to {ablation_path}")
|
| 512 |
+
|
| 513 |
+
# Print summary statistics
|
| 514 |
+
print("\n=== ABLATION STUDY SUMMARY ===")
|
| 515 |
+
for method in ablation_df['method'].unique():
|
| 516 |
+
method_results = ablation_df[ablation_df['method'] == method]
|
| 517 |
+
print(f"\n{method.upper()}:")
|
| 518 |
+
print(f" CAI improvement: {method_results['cai_improvement'].mean():.4f} ± {method_results['cai_improvement'].std():.4f}")
|
| 519 |
+
print(f" tAI improvement: {method_results['tai_improvement'].mean():.4f} ± {method_results['tai_improvement'].std():.4f}")
|
| 520 |
+
print(f" GC improvement: {method_results['gc_improvement'].mean():.4f} ± {method_results['gc_improvement'].std():.4f}")
|
| 521 |
+
print(f" Restriction sites improvement: {method_results['restriction_sites_improvement'].mean():.2f} ± {method_results['restriction_sites_improvement'].std():.2f}")
|
| 522 |
+
print(f" Composite score improvement: {method_results['composite_score_improvement'].mean():.4f} ± {method_results['composite_score_improvement'].std():.4f}")
|
| 523 |
+
|
| 524 |
+
# Print final summary
|
| 525 |
+
print("\n=== EVALUATION COMPLETE ===")
|
| 526 |
+
print(f"Total proteins evaluated: {len(results_df['protein_id'].unique())}")
|
| 527 |
+
print(f"Total sequences generated: {len(results_df)}")
|
| 528 |
+
print(f"Results saved to: {args.output_path}")
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
if __name__ == "__main__":
|
| 532 |
+
parser = argparse.ArgumentParser(description="Enhanced CodonTransformer Evaluation")
|
| 533 |
+
|
| 534 |
+
# Input/Output paths
|
| 535 |
+
parser.add_argument("--checkpoint_path", type=str, default="models/ecoli-codon-optimizer/finetune_best.ckpt",
|
| 536 |
+
help="Path to fine-tuned model checkpoint")
|
| 537 |
+
parser.add_argument("--test_data_path", type=str, default="data/test_set.json",
|
| 538 |
+
help="Path to test dataset")
|
| 539 |
+
parser.add_argument("--natural_sequences_path", type=str, default="data/ecoli_processed_genes.csv",
|
| 540 |
+
help="Path to natural E. coli sequences for CAI calculation")
|
| 541 |
+
parser.add_argument("--output_path", type=str, default="results/enhanced_evaluation_results.csv",
|
| 542 |
+
help="Path to save evaluation results")
|
| 543 |
+
|
| 544 |
+
# Model parameters
|
| 545 |
+
parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
|
| 546 |
+
parser.add_argument("--compare_with_base", action="store_true", help="Compare with base model")
|
| 547 |
+
|
| 548 |
+
# Generation parameters
|
| 549 |
+
parser.add_argument("--use_constrained_search", action="store_true",
|
| 550 |
+
help="Use constrained beam search")
|
| 551 |
+
parser.add_argument("--gc_bounds", type=float, nargs=2, default=[0.50, 0.54],
|
| 552 |
+
help="GC content bounds (min max)")
|
| 553 |
+
parser.add_argument("--beam_size", type=int, default=10,
|
| 554 |
+
help="Beam size for standard generation")
|
| 555 |
+
parser.add_argument("--length_penalty", type=float, default=1.2,
|
| 556 |
+
help="Length penalty for beam search")
|
| 557 |
+
parser.add_argument("--diversity_penalty", type=float, default=0.1,
|
| 558 |
+
help="Diversity penalty for beam search")
|
| 559 |
+
|
| 560 |
+
# Enhancement parameters
|
| 561 |
+
parser.add_argument("--use_enhanced_generation", action="store_true",
|
| 562 |
+
help="Use enhanced generation with DNAChisel and Pareto filtering")
|
| 563 |
+
parser.add_argument("--enhanced_beam_size", type=int, default=20,
|
| 564 |
+
help="Beam size for enhanced generation")
|
| 565 |
+
parser.add_argument("--use_dnachisel", action="store_true",
|
| 566 |
+
help="Use DNAChisel post-processing")
|
| 567 |
+
parser.add_argument("--use_pareto_filtering", action="store_true",
|
| 568 |
+
help="Use Pareto frontier filtering")
|
| 569 |
+
|
| 570 |
+
# Evaluation parameters
|
| 571 |
+
parser.add_argument("--max_test_proteins", type=int, default=0,
|
| 572 |
+
help="Maximum number of proteins to test (0 for all)")
|
| 573 |
+
parser.add_argument("--run_ablation_study", action="store_true",
|
| 574 |
+
help="Run ablation study comparing methods")
|
| 575 |
+
|
| 576 |
+
args = parser.parse_args()
|
| 577 |
+
main(args)
|
prepare_ecoli_data.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from Bio.Seq import Seq
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def is_valid_sequence(dna_seq: str) -> bool:
|
| 6 |
+
"""
|
| 7 |
+
Applies a series of validation checks to a DNA sequence.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
dna_seq (str): The DNA sequence to validate.
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
bool: True if the sequence is valid, False otherwise.
|
| 14 |
+
"""
|
| 15 |
+
if len(dna_seq) % 3 != 0:
|
| 16 |
+
return False
|
| 17 |
+
if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
|
| 18 |
+
return False
|
| 19 |
+
if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
|
| 20 |
+
return False
|
| 21 |
+
|
| 22 |
+
codons = [dna_seq[i:i+3].upper() for i in range(0, len(dna_seq) - 3, 3)]
|
| 23 |
+
if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
|
| 24 |
+
return False
|
| 25 |
+
|
| 26 |
+
if not all(c in 'ATGC' for c in dna_seq.upper()):
|
| 27 |
+
return False
|
| 28 |
+
|
| 29 |
+
return True
|
| 30 |
+
|
| 31 |
+
def main():
|
| 32 |
+
"""
|
| 33 |
+
Main function to process and validate E. coli gene data.
|
| 34 |
+
"""
|
| 35 |
+
if not os.path.exists('data'):
|
| 36 |
+
os.makedirs('data')
|
| 37 |
+
|
| 38 |
+
print("Loading data from CSV files...")
|
| 39 |
+
df_all = pd.read_csv("data/CAI.csv", header=0, names=['gene_id', 'cai_score', 'drop1', 'drop2', 'dna_sequence', 'drop3'])
|
| 40 |
+
df_high_cai = pd.read_csv("data/Database 3_4300 gene.csv", header=0, names=['dna_sequence'])
|
| 41 |
+
|
| 42 |
+
high_cai_sequences = set(df_high_cai['dna_sequence'])
|
| 43 |
+
|
| 44 |
+
validated_genes = []
|
| 45 |
+
for index, row in df_all.iterrows():
|
| 46 |
+
gene_id = row['gene_id']
|
| 47 |
+
dna_sequence = str(row['dna_sequence'])
|
| 48 |
+
|
| 49 |
+
if is_valid_sequence(dna_sequence):
|
| 50 |
+
protein_sequence = str(Seq(dna_sequence).translate())
|
| 51 |
+
is_high_cai = dna_sequence in high_cai_sequences
|
| 52 |
+
|
| 53 |
+
validated_genes.append({
|
| 54 |
+
'gene_id': gene_id,
|
| 55 |
+
'dna_sequence': dna_sequence,
|
| 56 |
+
'protein_sequence': protein_sequence,
|
| 57 |
+
'cai_score': row.get('cai_score', None),
|
| 58 |
+
'is_high_cai': is_high_cai
|
| 59 |
+
})
|
| 60 |
+
|
| 61 |
+
df_processed = pd.DataFrame(validated_genes)
|
| 62 |
+
|
| 63 |
+
output_path = 'data/ecoli_processed_genes.csv'
|
| 64 |
+
df_processed.to_csv(output_path, index=False)
|
| 65 |
+
print(f"Processed data saved to {output_path}")
|
| 66 |
+
print(f"Total validated genes: {len(df_processed)}")
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
main()
|
pretrain.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: pretrain.py
|
| 3 |
+
-----------------
|
| 4 |
+
Pretrain the base transformer model on JSON datasets prepared via
|
| 5 |
+
CodonData.prepare_training_data. This is typically not needed for ENCOT
|
| 6 |
+
as we use the pretrained CodonTransformer base. See README for setup and usage.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
import torch
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast
|
| 16 |
+
|
| 17 |
+
from CodonTransformer.CodonUtils import (
|
| 18 |
+
MAX_LEN,
|
| 19 |
+
NUM_ORGANISMS,
|
| 20 |
+
TOKEN2MASK,
|
| 21 |
+
IterableJSONData,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class MaskedTokenizerCollator:
|
| 26 |
+
def __init__(self, tokenizer):
|
| 27 |
+
self.tokenizer = tokenizer
|
| 28 |
+
|
| 29 |
+
def __call__(self, examples):
|
| 30 |
+
tokenized = self.tokenizer(
|
| 31 |
+
[ex["codons"] for ex in examples],
|
| 32 |
+
return_attention_mask=True,
|
| 33 |
+
return_token_type_ids=True,
|
| 34 |
+
truncation=True,
|
| 35 |
+
padding=True,
|
| 36 |
+
max_length=MAX_LEN,
|
| 37 |
+
return_tensors="pt",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
seq_len = tokenized["input_ids"].shape[-1]
|
| 41 |
+
species_index = torch.tensor([[ex["organism"]] for ex in examples])
|
| 42 |
+
tokenized["token_type_ids"] = species_index.repeat(1, seq_len)
|
| 43 |
+
|
| 44 |
+
inputs = tokenized["input_ids"]
|
| 45 |
+
targets = inputs.clone()
|
| 46 |
+
|
| 47 |
+
prob_matrix = torch.full(inputs.shape, 0.15)
|
| 48 |
+
prob_matrix[inputs < 5] = 0.0
|
| 49 |
+
selected = torch.bernoulli(prob_matrix).bool()
|
| 50 |
+
|
| 51 |
+
replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected
|
| 52 |
+
inputs[replaced] = torch.tensor(
|
| 53 |
+
list((map(TOKEN2MASK.__getitem__, inputs[replaced].numpy())))
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
randomized = (
|
| 57 |
+
torch.bernoulli(torch.full(selected.shape, 0.1)).bool()
|
| 58 |
+
& selected
|
| 59 |
+
& ~replaced
|
| 60 |
+
)
|
| 61 |
+
random_idx = torch.randint(26, 90, inputs.shape, dtype=torch.long)
|
| 62 |
+
inputs[randomized] = random_idx[randomized]
|
| 63 |
+
|
| 64 |
+
tokenized["input_ids"] = inputs
|
| 65 |
+
tokenized["labels"] = torch.where(selected, targets, -100)
|
| 66 |
+
|
| 67 |
+
return tokenized
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class plTrainHarness(pl.LightningModule):
|
| 71 |
+
def __init__(self, model, learning_rate, warmup_fraction):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.model = model
|
| 74 |
+
self.learning_rate = learning_rate
|
| 75 |
+
self.warmup_fraction = warmup_fraction
|
| 76 |
+
|
| 77 |
+
def configure_optimizers(self):
|
| 78 |
+
optimizer = torch.optim.AdamW(
|
| 79 |
+
self.model.parameters(),
|
| 80 |
+
lr=self.learning_rate,
|
| 81 |
+
)
|
| 82 |
+
lr_scheduler = {
|
| 83 |
+
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
|
| 84 |
+
optimizer,
|
| 85 |
+
max_lr=self.learning_rate,
|
| 86 |
+
total_steps=self.trainer.estimated_stepping_batches,
|
| 87 |
+
pct_start=self.warmup_fraction,
|
| 88 |
+
),
|
| 89 |
+
"interval": "step",
|
| 90 |
+
"frequency": 1,
|
| 91 |
+
}
|
| 92 |
+
return [optimizer], [lr_scheduler]
|
| 93 |
+
|
| 94 |
+
def training_step(self, batch, batch_idx):
|
| 95 |
+
self.model.bert.set_attention_type("block_sparse")
|
| 96 |
+
outputs = self.model(**batch)
|
| 97 |
+
self.log_dict(
|
| 98 |
+
dictionary={
|
| 99 |
+
"loss": outputs.loss,
|
| 100 |
+
"lr": self.trainer.optimizers[0].param_groups[0]["lr"],
|
| 101 |
+
},
|
| 102 |
+
on_step=True,
|
| 103 |
+
prog_bar=True,
|
| 104 |
+
)
|
| 105 |
+
return outputs.loss
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class EpochCheckpoint(pl.Callback):
|
| 109 |
+
def __init__(self, checkpoint_dir, save_interval):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.checkpoint_dir = checkpoint_dir
|
| 112 |
+
self.save_interval = save_interval
|
| 113 |
+
|
| 114 |
+
def on_train_epoch_end(self, trainer, pl_module):
|
| 115 |
+
current_epoch = trainer.current_epoch
|
| 116 |
+
if current_epoch % self.save_interval == 0 or current_epoch == 0:
|
| 117 |
+
checkpoint_path = os.path.join(
|
| 118 |
+
self.checkpoint_dir, f"epoch_{current_epoch}.ckpt"
|
| 119 |
+
)
|
| 120 |
+
trainer.save_checkpoint(checkpoint_path)
|
| 121 |
+
print(f"\nCheckpoint saved at {checkpoint_path}\n")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main(args):
|
| 125 |
+
"""Pretrain the base transformer model."""
|
| 126 |
+
pl.seed_everything(args.seed)
|
| 127 |
+
torch.set_float32_matmul_precision("medium")
|
| 128 |
+
|
| 129 |
+
tokenizer = PreTrainedTokenizerFast(
|
| 130 |
+
tokenizer_file=args.tokenizer_path,
|
| 131 |
+
bos_token="[CLS]",
|
| 132 |
+
eos_token="[SEP]",
|
| 133 |
+
unk_token="[UNK]",
|
| 134 |
+
sep_token="[SEP]",
|
| 135 |
+
pad_token="[PAD]",
|
| 136 |
+
cls_token="[CLS]",
|
| 137 |
+
mask_token="[MASK]",
|
| 138 |
+
)
|
| 139 |
+
config = BigBirdConfig(
|
| 140 |
+
vocab_size=len(tokenizer),
|
| 141 |
+
type_vocab_size=NUM_ORGANISMS,
|
| 142 |
+
sep_token_id=2,
|
| 143 |
+
)
|
| 144 |
+
model = BigBirdForMaskedLM(config=config)
|
| 145 |
+
harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction)
|
| 146 |
+
|
| 147 |
+
train_data = IterableJSONData(args.train_data_path, dist_env="slurm")
|
| 148 |
+
data_loader = DataLoader(
|
| 149 |
+
dataset=train_data,
|
| 150 |
+
collate_fn=MaskedTokenizerCollator(tokenizer),
|
| 151 |
+
batch_size=args.batch_size,
|
| 152 |
+
num_workers=0 if args.debug else args.num_workers,
|
| 153 |
+
persistent_workers=False if args.debug else True,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval)
|
| 157 |
+
trainer = pl.Trainer(
|
| 158 |
+
default_root_dir=args.checkpoint_dir,
|
| 159 |
+
strategy="ddp_find_unused_parameters_true",
|
| 160 |
+
accelerator="gpu",
|
| 161 |
+
devices=1 if args.debug else args.num_gpus,
|
| 162 |
+
precision="16-mixed",
|
| 163 |
+
max_epochs=args.max_epochs,
|
| 164 |
+
deterministic=False,
|
| 165 |
+
enable_checkpointing=True,
|
| 166 |
+
callbacks=[save_checkpoint],
|
| 167 |
+
accumulate_grad_batches=args.accumulate_grad_batches,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Pretrain the model
|
| 171 |
+
trainer.fit(harnessed_model, data_loader)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
parser = argparse.ArgumentParser(description="Pretrain the base transformer model.")
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--tokenizer_path",
|
| 178 |
+
type=str,
|
| 179 |
+
required=True,
|
| 180 |
+
help="Path to the tokenizer model file",
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--train_data_path",
|
| 184 |
+
type=str,
|
| 185 |
+
required=True,
|
| 186 |
+
help="Path to the training data JSON file",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--checkpoint_dir",
|
| 190 |
+
type=str,
|
| 191 |
+
required=True,
|
| 192 |
+
help="Directory where checkpoints will be saved",
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--batch_size", type=int, default=6, help="Batch size for training"
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--max_epochs", type=int, default=5, help="Maximum number of epochs to train"
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--num_workers", type=int, default=5, help="Number of workers for data loading"
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--accumulate_grad_batches",
|
| 205 |
+
type=int,
|
| 206 |
+
default=1,
|
| 207 |
+
help="Number of batches to accumulate gradients",
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--num_gpus", type=int, default=16, help="Number of GPUs to use for training"
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--learning_rate",
|
| 214 |
+
type=float,
|
| 215 |
+
default=5e-5,
|
| 216 |
+
help="Learning rate for the optimizer",
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--warmup_fraction",
|
| 220 |
+
type=float,
|
| 221 |
+
default=0.1,
|
| 222 |
+
help="Fraction of total steps to use for warmup",
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--save_interval", type=int, default=5, help="Save checkpoint every N epochs"
|
| 226 |
+
)
|
| 227 |
+
parser.add_argument(
|
| 228 |
+
"--seed", type=int, default=123, help="Random seed for reproducibility"
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
| 231 |
+
args = parser.parse_args()
|
| 232 |
+
main(args)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "ENCOT"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "Transformer-based codon optimization for E. coli using deep learning with Augmented-Lagrangian GC control."
|
| 5 |
+
authors = ["Adibvafa Fallahpour <Adibvafa.fallahpour@mail.utoronto.ca>"]
|
| 6 |
+
license = "Apache-2.0"
|
| 7 |
+
readme = "README.md"
|
| 8 |
+
homepage = "https://github.com/geno543/ENCOT"
|
| 9 |
+
repository = "https://github.com/geno543/ENCOT"
|
| 10 |
+
classifiers = [
|
| 11 |
+
"Programming Language :: Python :: 3",
|
| 12 |
+
"License :: OSI Approved :: Apache Software License",
|
| 13 |
+
"Operating System :: OS Independent",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
[tool.poetry.dependencies]
|
| 17 |
+
python = "^3.9"
|
| 18 |
+
biopython = "^1.83"
|
| 19 |
+
ipywidgets = "^7.0.0"
|
| 20 |
+
numpy = "<2.0.0"
|
| 21 |
+
onnxruntime = "^1.16.3"
|
| 22 |
+
pandas = "^2.0.0"
|
| 23 |
+
python_codon_tables = "^0.1.12"
|
| 24 |
+
pytorch_lightning = "^2.2.1"
|
| 25 |
+
scikit-learn = "^1.2.2"
|
| 26 |
+
scipy = "^1.13.1"
|
| 27 |
+
setuptools = "^70.0.0"
|
| 28 |
+
torch = "^2.0.0"
|
| 29 |
+
tqdm = "^4.66.2"
|
| 30 |
+
transformers = "^4.40.0"
|
| 31 |
+
CAI-PyPI = "^2.0.1"
|
| 32 |
+
codon-bias = "^1.0.2"
|
| 33 |
+
gcua = "^0.1.2"
|
| 34 |
+
dtw-python = "^1.3.0"
|
| 35 |
+
|
| 36 |
+
[tool.poetry.dev-dependencies]
|
| 37 |
+
coverage = {version = "^7.0", extras = ["toml"]}
|
| 38 |
+
|
| 39 |
+
[build-system]
|
| 40 |
+
requires = ["poetry-core>=1.0.0"]
|
| 41 |
+
build-backend = "poetry.core.masonry.api"
|
| 42 |
+
|
| 43 |
+
[tool.ruff]
|
| 44 |
+
line-length = 88
|
| 45 |
+
indent-width = 4
|
| 46 |
+
target-version = "py310"
|
| 47 |
+
|
| 48 |
+
[tool.ruff.lint]
|
| 49 |
+
select = ["E", "F", "I"]
|
| 50 |
+
ignore = []
|
| 51 |
+
|
| 52 |
+
[tool.ruff.format]
|
| 53 |
+
quote-style = "double"
|
| 54 |
+
indent-style = "space"
|
| 55 |
+
skip-magic-trailing-comma = false
|
| 56 |
+
line-ending = "auto"
|
| 57 |
+
|
| 58 |
+
[tool.coverage.run]
|
| 59 |
+
omit = [
|
| 60 |
+
# omit pytorch-generated files in /tmp
|
| 61 |
+
"/tmp/*",
|
| 62 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
biopython>=1.83,<2.0
|
| 2 |
+
CAI-PyPI>=2.0.1,<3.0
|
| 3 |
+
ipywidgets>=7.0.0,<10.0
|
| 4 |
+
numpy>=1.26.4,<2.0
|
| 5 |
+
onnxruntime>=1.16.3,<3.0
|
| 6 |
+
pandas>=2.0.0,<3.0
|
| 7 |
+
python_codon_tables>=0.1.12,<1.0
|
| 8 |
+
pytorch_lightning>=2.2.1,<3.0
|
| 9 |
+
scikit-learn>=1.2.2,<2.0
|
| 10 |
+
scipy>=1.13.1,<3.0
|
| 11 |
+
setuptools>=70.0.0
|
| 12 |
+
torch>=2.0.0,<3.0
|
| 13 |
+
tqdm>=4.66.2,<5.0
|
| 14 |
+
transformers>=4.40.0,<5.0
|
| 15 |
+
codon-bias>=0.3.5,<0.4
|
| 16 |
+
dtw-python>=1.3.0,<2.0
|
| 17 |
+
|
| 18 |
+
dnachisel>=1.0
|
| 19 |
+
paretoset>=1.2.0
|
| 20 |
+
softadapt>=0.1.2,<0.2
|
| 21 |
+
ema-pytorch>=0.4.3
|
| 22 |
+
torchmetrics>=1.4.0
|
| 23 |
+
pyyaml>=6.0
|
| 24 |
+
|
| 25 |
+
matplotlib>=3.8,<4.0
|
| 26 |
+
seaborn>=0.13,<0.14
|
| 27 |
+
openpyxl>=3.1,<4.0
|
| 28 |
+
|
| 29 |
+
huggingface-hub>=0.20,<1.0
|
scripts/optimize_sequence.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Optimize protein sequences using ColiFormer.
|
| 3 |
+
|
| 4 |
+
This script provides a user-friendly interface for codon optimization,
|
| 5 |
+
supporting both single sequences and batch processing via FASTA files.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# Single sequence
|
| 9 |
+
python scripts/optimize_sequence.py --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" --output optimized.fasta
|
| 10 |
+
|
| 11 |
+
# Batch processing from FASTA file
|
| 12 |
+
python scripts/optimize_sequence.py --input sequences.fasta --output optimized.fasta --batch
|
| 13 |
+
|
| 14 |
+
# With GC content constraints
|
| 15 |
+
python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --gc-min 0.45 --gc-max 0.55
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, List, Tuple
|
| 23 |
+
|
| 24 |
+
# Add parent directory to path to import CodonTransformer
|
| 25 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def parse_fasta(fasta_path: str) -> List[Tuple[str, str]]:
|
| 29 |
+
"""
|
| 30 |
+
Parse FASTA file into list of (name, sequence) tuples.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
fasta_path: Path to FASTA file
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
List of (name, sequence) tuples
|
| 37 |
+
"""
|
| 38 |
+
sequences = []
|
| 39 |
+
current_name = None
|
| 40 |
+
current_seq = []
|
| 41 |
+
|
| 42 |
+
with open(fasta_path, 'r') as f:
|
| 43 |
+
for line in f:
|
| 44 |
+
line = line.strip()
|
| 45 |
+
if line.startswith('>'):
|
| 46 |
+
if current_name is not None:
|
| 47 |
+
sequences.append((current_name, ''.join(current_seq)))
|
| 48 |
+
current_name = line[1:] if len(line) > 1 else f"sequence_{len(sequences)+1}"
|
| 49 |
+
current_seq = []
|
| 50 |
+
else:
|
| 51 |
+
current_seq.append(line.upper())
|
| 52 |
+
|
| 53 |
+
if current_name is not None:
|
| 54 |
+
sequences.append((current_name, ''.join(current_seq)))
|
| 55 |
+
|
| 56 |
+
return sequences
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def write_fasta(output_path: str, sequences: List[Tuple[str, str]]):
|
| 60 |
+
"""
|
| 61 |
+
Write sequences to FASTA file.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
output_path: Output FASTA file path
|
| 65 |
+
sequences: List of (name, sequence) tuples
|
| 66 |
+
"""
|
| 67 |
+
with open(output_path, 'w') as f:
|
| 68 |
+
for name, seq in sequences:
|
| 69 |
+
f.write(f">{name}\n")
|
| 70 |
+
# Write sequence in 60-character lines
|
| 71 |
+
for i in range(0, len(seq), 60):
|
| 72 |
+
f.write(seq[i:i+60] + "\n")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def optimize_single_sequence(
|
| 76 |
+
protein: str,
|
| 77 |
+
model: Any,
|
| 78 |
+
tokenizer: Any,
|
| 79 |
+
device: Any,
|
| 80 |
+
organism: str = "Escherichia coli general",
|
| 81 |
+
gc_min: float = None,
|
| 82 |
+
gc_max: float = None,
|
| 83 |
+
cai_weights: dict = None,
|
| 84 |
+
tai_weights: dict = None
|
| 85 |
+
) -> dict:
|
| 86 |
+
"""
|
| 87 |
+
Optimize a single protein sequence.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
protein: Protein sequence string
|
| 91 |
+
model: Loaded ColiFormer model
|
| 92 |
+
tokenizer: Tokenizer
|
| 93 |
+
device: PyTorch device
|
| 94 |
+
organism: Target organism name
|
| 95 |
+
gc_min: Minimum GC content (0-1)
|
| 96 |
+
gc_max: Maximum GC content (0-1)
|
| 97 |
+
cai_weights: CAI weights dictionary
|
| 98 |
+
tai_weights: tAI weights dictionary
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Dictionary with optimization results
|
| 102 |
+
"""
|
| 103 |
+
# Lazy imports so `python scripts/optimize_sequence.py --help` works without ML deps installed.
|
| 104 |
+
from CodonTransformer.CodonPrediction import predict_dna_sequence
|
| 105 |
+
from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI
|
| 106 |
+
from CAI import CAI
|
| 107 |
+
|
| 108 |
+
# Determine GC bounds if specified
|
| 109 |
+
gc_bounds = None
|
| 110 |
+
use_constrained = False
|
| 111 |
+
if gc_min is not None and gc_max is not None:
|
| 112 |
+
gc_bounds = (gc_min, gc_max)
|
| 113 |
+
use_constrained = True
|
| 114 |
+
|
| 115 |
+
# Run optimization
|
| 116 |
+
output = predict_dna_sequence(
|
| 117 |
+
protein=protein,
|
| 118 |
+
organism=organism,
|
| 119 |
+
device=device,
|
| 120 |
+
model=model,
|
| 121 |
+
tokenizer=tokenizer,
|
| 122 |
+
deterministic=True,
|
| 123 |
+
match_protein=True,
|
| 124 |
+
use_constrained_search=use_constrained,
|
| 125 |
+
gc_bounds=gc_bounds,
|
| 126 |
+
beam_size=20 if use_constrained else 5,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if isinstance(output, list):
|
| 130 |
+
output = output[0]
|
| 131 |
+
|
| 132 |
+
optimized_dna = output.predicted_dna
|
| 133 |
+
|
| 134 |
+
# Calculate metrics
|
| 135 |
+
gc_content = get_GC_content(optimized_dna) / 100.0 # Convert to fraction
|
| 136 |
+
|
| 137 |
+
metrics = {
|
| 138 |
+
'protein': protein,
|
| 139 |
+
'optimized_dna': optimized_dna,
|
| 140 |
+
'gc_content': gc_content,
|
| 141 |
+
'length': len(optimized_dna),
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
if cai_weights:
|
| 145 |
+
try:
|
| 146 |
+
metrics['cai'] = CAI(optimized_dna, weights=cai_weights)
|
| 147 |
+
except:
|
| 148 |
+
metrics['cai'] = None
|
| 149 |
+
else:
|
| 150 |
+
metrics['cai'] = None
|
| 151 |
+
|
| 152 |
+
if tai_weights:
|
| 153 |
+
try:
|
| 154 |
+
metrics['tai'] = calculate_tAI(optimized_dna, tai_weights)
|
| 155 |
+
except:
|
| 156 |
+
metrics['tai'] = None
|
| 157 |
+
else:
|
| 158 |
+
metrics['tai'] = None
|
| 159 |
+
|
| 160 |
+
return metrics
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def load_reference_data(ref_sequences_path: str = None):
|
| 164 |
+
"""
|
| 165 |
+
Load reference sequences and calculate CAI weights.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
ref_sequences_path: Path to CSV with reference sequences
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Tuple of (cai_weights, tai_weights)
|
| 172 |
+
"""
|
| 173 |
+
# Lazy imports so `--help` works without ML deps installed.
|
| 174 |
+
import pandas as pd
|
| 175 |
+
from CAI import relative_adaptiveness
|
| 176 |
+
from CodonTransformer.CodonEvaluation import get_ecoli_tai_weights
|
| 177 |
+
|
| 178 |
+
cai_weights = None
|
| 179 |
+
tai_weights = None
|
| 180 |
+
|
| 181 |
+
# Try to load reference sequences for CAI
|
| 182 |
+
if ref_sequences_path and os.path.exists(ref_sequences_path):
|
| 183 |
+
try:
|
| 184 |
+
df = pd.read_csv(ref_sequences_path)
|
| 185 |
+
if 'dna_sequence' in df.columns:
|
| 186 |
+
ref_sequences = df['dna_sequence'].tolist()
|
| 187 |
+
cai_weights = relative_adaptiveness(sequences=ref_sequences)
|
| 188 |
+
print(f"Loaded CAI weights from {len(ref_sequences)} reference sequences")
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f"Warning: Could not load CAI weights: {e}")
|
| 191 |
+
|
| 192 |
+
# Load tAI weights
|
| 193 |
+
try:
|
| 194 |
+
tai_weights = get_ecoli_tai_weights()
|
| 195 |
+
print("Loaded E. coli tAI weights")
|
| 196 |
+
except Exception as e:
|
| 197 |
+
print(f"Warning: Could not load tAI weights: {e}")
|
| 198 |
+
|
| 199 |
+
return cai_weights, tai_weights
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def main():
|
| 203 |
+
"""Main entry point for sequence optimization."""
|
| 204 |
+
parser = argparse.ArgumentParser(
|
| 205 |
+
description="Optimize protein sequences using ENCOT",
|
| 206 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 207 |
+
epilog="""
|
| 208 |
+
Examples:
|
| 209 |
+
# Single sequence
|
| 210 |
+
python scripts/optimize_sequence.py --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" --output optimized.fasta
|
| 211 |
+
|
| 212 |
+
# Batch processing from FASTA file
|
| 213 |
+
python scripts/optimize_sequence.py --input sequences.fasta --output optimized.fasta --batch
|
| 214 |
+
|
| 215 |
+
# With GC content constraints
|
| 216 |
+
python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --gc-min 0.45 --gc-max 0.55
|
| 217 |
+
|
| 218 |
+
# Use custom checkpoint
|
| 219 |
+
python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --checkpoint models/my_model.ckpt
|
| 220 |
+
"""
|
| 221 |
+
)
|
| 222 |
+
parser.add_argument(
|
| 223 |
+
"--input",
|
| 224 |
+
type=str,
|
| 225 |
+
required=True,
|
| 226 |
+
help="Input protein sequence (string) or FASTA file path"
|
| 227 |
+
)
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--output",
|
| 230 |
+
type=str,
|
| 231 |
+
required=True,
|
| 232 |
+
help="Output FASTA file path"
|
| 233 |
+
)
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--checkpoint",
|
| 236 |
+
type=str,
|
| 237 |
+
default=None,
|
| 238 |
+
help="Path to model checkpoint (default: auto-download from Hugging Face)"
|
| 239 |
+
)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--organism",
|
| 242 |
+
type=str,
|
| 243 |
+
default="Escherichia coli general",
|
| 244 |
+
help="Target organism (default: Escherichia coli general)"
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--gc-min",
|
| 248 |
+
type=float,
|
| 249 |
+
default=None,
|
| 250 |
+
help="Minimum GC content (0-1, e.g., 0.45 for 45%%)"
|
| 251 |
+
)
|
| 252 |
+
parser.add_argument(
|
| 253 |
+
"--gc-max",
|
| 254 |
+
type=float,
|
| 255 |
+
default=None,
|
| 256 |
+
help="Maximum GC content (0-1, e.g., 0.55 for 55%%)"
|
| 257 |
+
)
|
| 258 |
+
parser.add_argument(
|
| 259 |
+
"--batch",
|
| 260 |
+
action="store_true",
|
| 261 |
+
help="Process input as FASTA file with multiple sequences"
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--ref-sequences",
|
| 265 |
+
type=str,
|
| 266 |
+
default="data/ecoli_processed_genes.csv",
|
| 267 |
+
help="Path to reference sequences CSV for CAI calculation"
|
| 268 |
+
)
|
| 269 |
+
parser.add_argument(
|
| 270 |
+
"--use-gpu",
|
| 271 |
+
action="store_true",
|
| 272 |
+
help="Use GPU if available"
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
args = parser.parse_args()
|
| 276 |
+
|
| 277 |
+
try:
|
| 278 |
+
# Lazy imports so `--help` works without ML deps installed.
|
| 279 |
+
import torch
|
| 280 |
+
from transformers import AutoTokenizer
|
| 281 |
+
from CodonTransformer.CodonPrediction import load_model
|
| 282 |
+
import pandas as pd
|
| 283 |
+
|
| 284 |
+
# Setup device
|
| 285 |
+
device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
|
| 286 |
+
print(f"Using device: {device}")
|
| 287 |
+
|
| 288 |
+
# Load model
|
| 289 |
+
print("Loading ColiFormer model...")
|
| 290 |
+
if args.checkpoint:
|
| 291 |
+
model = load_model(model_path=args.checkpoint, device=device)
|
| 292 |
+
print(f"Loaded model from {args.checkpoint}")
|
| 293 |
+
else:
|
| 294 |
+
# Try to load from Hugging Face
|
| 295 |
+
try:
|
| 296 |
+
from huggingface_hub import hf_hub_download
|
| 297 |
+
checkpoint_path = hf_hub_download(
|
| 298 |
+
repo_id="saketh11/ColiFormer",
|
| 299 |
+
filename="balanced_alm_finetune.ckpt",
|
| 300 |
+
cache_dir="./hf_cache"
|
| 301 |
+
)
|
| 302 |
+
model = load_model(model_path=checkpoint_path, device=device)
|
| 303 |
+
print("Loaded model from Hugging Face (saketh11/ColiFormer)")
|
| 304 |
+
except Exception as e:
|
| 305 |
+
print(f"Warning: Could not load from Hugging Face: {e}")
|
| 306 |
+
print("Falling back to base CodonTransformer model...")
|
| 307 |
+
from transformers import BigBirdForMaskedLM
|
| 308 |
+
model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device)
|
| 309 |
+
|
| 310 |
+
# Load tokenizer
|
| 311 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 312 |
+
|
| 313 |
+
# Load reference data for metrics
|
| 314 |
+
cai_weights, tai_weights = load_reference_data(args.ref_sequences)
|
| 315 |
+
|
| 316 |
+
# Parse input
|
| 317 |
+
if args.batch or os.path.exists(args.input):
|
| 318 |
+
# FASTA file
|
| 319 |
+
print(f"Reading sequences from {args.input}...")
|
| 320 |
+
sequences = parse_fasta(args.input)
|
| 321 |
+
print(f"Found {len(sequences)} sequences")
|
| 322 |
+
else:
|
| 323 |
+
# Single sequence string
|
| 324 |
+
sequences = [("sequence_1", args.input.upper())]
|
| 325 |
+
|
| 326 |
+
# Optimize sequences
|
| 327 |
+
optimized_sequences = []
|
| 328 |
+
results = []
|
| 329 |
+
|
| 330 |
+
for i, (name, protein_seq) in enumerate(sequences, 1):
|
| 331 |
+
print(f"\nOptimizing sequence {i}/{len(sequences)}: {name}")
|
| 332 |
+
|
| 333 |
+
metrics = optimize_single_sequence(
|
| 334 |
+
protein=protein_seq,
|
| 335 |
+
model=model,
|
| 336 |
+
tokenizer=tokenizer,
|
| 337 |
+
device=device,
|
| 338 |
+
organism=args.organism,
|
| 339 |
+
gc_min=args.gc_min,
|
| 340 |
+
gc_max=args.gc_max,
|
| 341 |
+
cai_weights=cai_weights,
|
| 342 |
+
tai_weights=tai_weights
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
optimized_sequences.append((name, metrics['optimized_dna']))
|
| 346 |
+
results.append({
|
| 347 |
+
'name': name,
|
| 348 |
+
'protein_length': len(protein_seq),
|
| 349 |
+
'dna_length': metrics['length'],
|
| 350 |
+
'gc_content': f"{metrics['gc_content']*100:.2f}%",
|
| 351 |
+
'cai': metrics['cai'],
|
| 352 |
+
'tai': metrics['tai'],
|
| 353 |
+
})
|
| 354 |
+
|
| 355 |
+
print(f" GC content: {metrics['gc_content']*100:.2f}%")
|
| 356 |
+
if metrics['cai']:
|
| 357 |
+
print(f" CAI: {metrics['cai']:.3f}")
|
| 358 |
+
if metrics['tai']:
|
| 359 |
+
print(f" tAI: {metrics['tai']:.3f}")
|
| 360 |
+
|
| 361 |
+
# Write output
|
| 362 |
+
write_fasta(args.output, optimized_sequences)
|
| 363 |
+
print(f"\nOptimized sequences saved to {args.output}")
|
| 364 |
+
|
| 365 |
+
# Print summary
|
| 366 |
+
if len(results) > 1:
|
| 367 |
+
print("\n" + "="*60)
|
| 368 |
+
print("Summary Statistics")
|
| 369 |
+
print("="*60)
|
| 370 |
+
df = pd.DataFrame(results)
|
| 371 |
+
print(df.to_string(index=False))
|
| 372 |
+
print("="*60)
|
| 373 |
+
|
| 374 |
+
except Exception as e:
|
| 375 |
+
print(f"Error: {e}", file=sys.stderr)
|
| 376 |
+
import traceback
|
| 377 |
+
traceback.print_exc()
|
| 378 |
+
sys.exit(1)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
if __name__ == "__main__":
|
| 382 |
+
main()
|
| 383 |
+
|
scripts/preprocess_data.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Preprocess E. coli gene data for ColiFormer training.
|
| 3 |
+
|
| 4 |
+
This script combines the functionality of prepare_ecoli_data.py and
|
| 5 |
+
create_model_datasets.py to prepare training and test datasets from raw CSV files.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/preprocess_data.py
|
| 9 |
+
python scripts/preprocess_data.py --cai_csv data/CAI.csv --high_cai_csv data/Database_3_4300_gene.csv
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
# Add parent directory to path to import CodonTransformer
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def is_valid_sequence(dna_seq: str) -> bool:
|
| 23 |
+
"""
|
| 24 |
+
Validate a DNA sequence for training suitability.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
dna_seq: DNA sequence string
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
True if sequence is valid (divisible by 3, proper start/stop codons, no internal stops)
|
| 31 |
+
"""
|
| 32 |
+
if len(dna_seq) % 3 != 0:
|
| 33 |
+
return False
|
| 34 |
+
if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
|
| 35 |
+
return False
|
| 36 |
+
if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
codons = [dna_seq[i:i+3].upper() for i in range(0, len(dna_seq) - 3, 3)]
|
| 40 |
+
if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
|
| 41 |
+
return False
|
| 42 |
+
|
| 43 |
+
if not all(c in 'ATGC' for c in dna_seq.upper()):
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
return True
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def process_ecoli_data(cai_csv: str, high_cai_csv: str, output_dir: str = "data"):
|
| 50 |
+
"""
|
| 51 |
+
Process raw E. coli gene data from CSV files.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
cai_csv: Path to CAI.csv file with gene data
|
| 55 |
+
high_cai_csv: Path to Database 3_4300 gene.csv with high-CAI sequences
|
| 56 |
+
output_dir: Output directory for processed files
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
Path to processed CSV file
|
| 60 |
+
"""
|
| 61 |
+
# Lazy imports so `python scripts/preprocess_data.py --help` works without heavy deps installed.
|
| 62 |
+
import pandas as pd
|
| 63 |
+
from Bio.Seq import Seq
|
| 64 |
+
|
| 65 |
+
# Validate input files exist
|
| 66 |
+
if not os.path.exists(cai_csv):
|
| 67 |
+
raise FileNotFoundError(f"CAI CSV file not found: {cai_csv}")
|
| 68 |
+
if not os.path.exists(high_cai_csv):
|
| 69 |
+
raise FileNotFoundError(f"High-CAI CSV file not found: {high_cai_csv}")
|
| 70 |
+
|
| 71 |
+
# Create output directory if needed
|
| 72 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 73 |
+
|
| 74 |
+
print("Loading data from CSV files...")
|
| 75 |
+
df_all = pd.read_csv(
|
| 76 |
+
cai_csv,
|
| 77 |
+
header=0,
|
| 78 |
+
names=['gene_id', 'cai_score', 'drop1', 'drop2', 'dna_sequence', 'drop3']
|
| 79 |
+
)
|
| 80 |
+
df_high_cai = pd.read_csv(
|
| 81 |
+
high_cai_csv,
|
| 82 |
+
header=0,
|
| 83 |
+
names=['dna_sequence']
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
high_cai_sequences = set(df_high_cai['dna_sequence'])
|
| 87 |
+
|
| 88 |
+
validated_genes = []
|
| 89 |
+
for index, row in df_all.iterrows():
|
| 90 |
+
gene_id = row['gene_id']
|
| 91 |
+
dna_sequence = str(row['dna_sequence'])
|
| 92 |
+
|
| 93 |
+
if is_valid_sequence(dna_sequence):
|
| 94 |
+
protein_sequence = str(Seq(dna_sequence).translate())
|
| 95 |
+
is_high_cai = dna_sequence in high_cai_sequences
|
| 96 |
+
|
| 97 |
+
validated_genes.append({
|
| 98 |
+
'gene_id': gene_id,
|
| 99 |
+
'dna_sequence': dna_sequence,
|
| 100 |
+
'protein_sequence': protein_sequence,
|
| 101 |
+
'cai_score': row.get('cai_score', None),
|
| 102 |
+
'is_high_cai': is_high_cai
|
| 103 |
+
})
|
| 104 |
+
|
| 105 |
+
df_processed = pd.DataFrame(validated_genes)
|
| 106 |
+
|
| 107 |
+
output_path = os.path.join(output_dir, 'ecoli_processed_genes.csv')
|
| 108 |
+
df_processed.to_csv(output_path, index=False)
|
| 109 |
+
print(f"Processed data saved to {output_path}")
|
| 110 |
+
print(f"Total validated genes: {len(df_processed)}")
|
| 111 |
+
|
| 112 |
+
return output_path
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def create_train_test_splits(processed_csv: str, output_dir: str = "data", test_size: int = 100):
|
| 116 |
+
"""
|
| 117 |
+
Create training and test splits from processed data.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
processed_csv: Path to processed ecoli_processed_genes.csv
|
| 121 |
+
output_dir: Output directory for JSON files
|
| 122 |
+
test_size: Number of sequences for test set
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
Tuple of (finetune_json_path, test_json_path)
|
| 126 |
+
"""
|
| 127 |
+
# Lazy imports so `--help` works without heavy deps installed.
|
| 128 |
+
import pandas as pd
|
| 129 |
+
from CodonTransformer.CodonData import prepare_training_data
|
| 130 |
+
|
| 131 |
+
if not os.path.exists(processed_csv):
|
| 132 |
+
raise FileNotFoundError(f"Processed data file not found: {processed_csv}")
|
| 133 |
+
|
| 134 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 135 |
+
|
| 136 |
+
df_processed = pd.read_csv(processed_csv)
|
| 137 |
+
|
| 138 |
+
# Create fine-tuning set (high-CAI sequences)
|
| 139 |
+
df_finetune = df_processed[df_processed['is_high_cai'] == True].copy()
|
| 140 |
+
df_finetune.drop_duplicates(subset=['dna_sequence'], inplace=True)
|
| 141 |
+
df_finetune.rename(columns={'dna_sequence': 'dna', 'protein_sequence': 'protein'}, inplace=True)
|
| 142 |
+
df_finetune['organism'] = "Escherichia coli general"
|
| 143 |
+
|
| 144 |
+
finetune_output_path = os.path.join(output_dir, 'finetune_set.json')
|
| 145 |
+
prepare_training_data(df_finetune, finetune_output_path, shuffle=True)
|
| 146 |
+
print(f"Fine-tuning set saved to {finetune_output_path} with {len(df_finetune)} records.")
|
| 147 |
+
|
| 148 |
+
# Create test set (non-high-CAI sequences)
|
| 149 |
+
df_test_pool = df_processed[df_processed['is_high_cai'] == False].copy()
|
| 150 |
+
df_test = df_test_pool.sample(n=test_size, random_state=42) # for reproducibility
|
| 151 |
+
df_test['organism'] = 51 # E. coli general organism ID
|
| 152 |
+
df_test.rename(columns={'dna_sequence': 'codons'}, inplace=True)
|
| 153 |
+
test_records = df_test[['codons', 'organism']].to_dict(orient='records')
|
| 154 |
+
|
| 155 |
+
test_output_path = os.path.join(output_dir, 'test_set.json')
|
| 156 |
+
with open(test_output_path, 'w') as f:
|
| 157 |
+
json.dump(test_records, f, indent=4)
|
| 158 |
+
print(f"Test set saved to {test_output_path} with {len(df_test)} records.")
|
| 159 |
+
|
| 160 |
+
return finetune_output_path, test_output_path
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def main():
|
| 164 |
+
"""Main entry point for data preprocessing."""
|
| 165 |
+
parser = argparse.ArgumentParser(
|
| 166 |
+
description="Preprocess E. coli gene data for ENCOT training",
|
| 167 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 168 |
+
epilog="""
|
| 169 |
+
Examples:
|
| 170 |
+
# Use default paths
|
| 171 |
+
python scripts/preprocess_data.py
|
| 172 |
+
|
| 173 |
+
# Specify custom input files
|
| 174 |
+
python scripts/preprocess_data.py --cai_csv data/CAI.csv --high_cai_csv data/Database_3_4300_gene.csv
|
| 175 |
+
|
| 176 |
+
# Custom output directory and test size
|
| 177 |
+
python scripts/preprocess_data.py --output_dir my_data --test_size 200
|
| 178 |
+
"""
|
| 179 |
+
)
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--cai_csv",
|
| 182 |
+
type=str,
|
| 183 |
+
default="data/CAI.csv",
|
| 184 |
+
help="Path to CAI.csv file with gene data (default: data/CAI.csv)"
|
| 185 |
+
)
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--high_cai_csv",
|
| 188 |
+
type=str,
|
| 189 |
+
default="data/Database 3_4300 gene.csv",
|
| 190 |
+
help="Path to Database 3_4300 gene.csv file (default: data/Database 3_4300 gene.csv)"
|
| 191 |
+
)
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"--output_dir",
|
| 194 |
+
type=str,
|
| 195 |
+
default="data",
|
| 196 |
+
help="Output directory for processed files (default: data)"
|
| 197 |
+
)
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
"--test_size",
|
| 200 |
+
type=int,
|
| 201 |
+
default=100,
|
| 202 |
+
help="Number of sequences for test set (default: 100)"
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
"--skip_processing",
|
| 206 |
+
action="store_true",
|
| 207 |
+
help="Skip data processing step (assume ecoli_processed_genes.csv exists)"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
args = parser.parse_args()
|
| 211 |
+
|
| 212 |
+
try:
|
| 213 |
+
# Step 1: Process raw data
|
| 214 |
+
if not args.skip_processing:
|
| 215 |
+
processed_csv = process_ecoli_data(
|
| 216 |
+
args.cai_csv,
|
| 217 |
+
args.high_cai_csv,
|
| 218 |
+
args.output_dir
|
| 219 |
+
)
|
| 220 |
+
else:
|
| 221 |
+
processed_csv = os.path.join(args.output_dir, 'ecoli_processed_genes.csv')
|
| 222 |
+
if not os.path.exists(processed_csv):
|
| 223 |
+
raise FileNotFoundError(
|
| 224 |
+
f"Processed data not found at {processed_csv}. "
|
| 225 |
+
"Remove --skip_processing flag to process raw data first."
|
| 226 |
+
)
|
| 227 |
+
print(f"Using existing processed data: {processed_csv}")
|
| 228 |
+
|
| 229 |
+
# Step 2: Create train/test splits
|
| 230 |
+
finetune_path, test_path = create_train_test_splits(
|
| 231 |
+
processed_csv,
|
| 232 |
+
args.output_dir,
|
| 233 |
+
args.test_size
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
print("\n" + "="*60)
|
| 237 |
+
print("Data preprocessing complete!")
|
| 238 |
+
print("="*60)
|
| 239 |
+
print(f"Training set: {finetune_path}")
|
| 240 |
+
print(f"Test set: {test_path}")
|
| 241 |
+
print("\nYou can now run training with:")
|
| 242 |
+
print(f" python scripts/train.py --config configs/train_ecoli_alm.yaml")
|
| 243 |
+
|
| 244 |
+
except Exception as e:
|
| 245 |
+
print(f"Error: {e}", file=sys.stderr)
|
| 246 |
+
sys.exit(1)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
main()
|
| 251 |
+
|
scripts/run_benchmarks.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run benchmark evaluation for ColiFormer.
|
| 3 |
+
|
| 4 |
+
This script wraps benchmark_evaluation.py and evaluate_optimizer.py to provide
|
| 5 |
+
a unified interface for running comprehensive evaluations.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/run_benchmarks.py --config configs/benchmark.yaml
|
| 9 |
+
python scripts/run_benchmarks.py --excel_path Benchmark_80_sequences.xlsx --checkpoint_path models/my_model.ckpt
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
# Add parent directory to path to import benchmark scripts
|
| 18 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_config(config_path: str) -> dict:
|
| 22 |
+
"""
|
| 23 |
+
Load configuration from YAML file.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
config_path: Path to YAML config file
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Dictionary with configuration values
|
| 30 |
+
"""
|
| 31 |
+
# Lazy import so `python scripts/run_benchmarks.py --help` works without dependencies installed.
|
| 32 |
+
import yaml
|
| 33 |
+
|
| 34 |
+
if not os.path.exists(config_path):
|
| 35 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 36 |
+
|
| 37 |
+
with open(config_path, 'r') as f:
|
| 38 |
+
config = yaml.safe_load(f)
|
| 39 |
+
|
| 40 |
+
return config
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def config_to_args(config: dict) -> argparse.Namespace:
|
| 44 |
+
"""
|
| 45 |
+
Convert config dictionary to argparse.Namespace compatible with benchmark_evaluation.py.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
config: Configuration dictionary from YAML
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
argparse.Namespace with all required arguments
|
| 52 |
+
"""
|
| 53 |
+
model_config = config.get('model', {})
|
| 54 |
+
data_config = config.get('data', {})
|
| 55 |
+
output_config = config.get('output', {})
|
| 56 |
+
eval_config = config.get('evaluation', {})
|
| 57 |
+
|
| 58 |
+
args = argparse.Namespace()
|
| 59 |
+
|
| 60 |
+
# Model paths
|
| 61 |
+
args.checkpoint_path = model_config.get('checkpoint_path', 'models/alm-enhanced-training/balanced_alm_finetune.ckpt')
|
| 62 |
+
|
| 63 |
+
# Data paths
|
| 64 |
+
args.excel_path = data_config.get('excel_path', 'Benchmark 80 sequences.xlsx')
|
| 65 |
+
args.natural_sequences_path = data_config.get('natural_sequences_path', 'data/ecoli_processed_genes.csv')
|
| 66 |
+
args.name_col = data_config.get('name_col')
|
| 67 |
+
args.seq_col = data_config.get('seq_col')
|
| 68 |
+
args.sheet_name = data_config.get('sheet_name')
|
| 69 |
+
|
| 70 |
+
# Output paths
|
| 71 |
+
args.output_dir = output_config.get('output_dir', 'benchmark_results')
|
| 72 |
+
|
| 73 |
+
# Evaluation parameters
|
| 74 |
+
args.use_gpu = eval_config.get('use_gpu', True)
|
| 75 |
+
args.compare_with_base = eval_config.get('compare_with_base', False)
|
| 76 |
+
args.max_test_proteins = eval_config.get('max_test_proteins', 0)
|
| 77 |
+
|
| 78 |
+
return args
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def validate_config(config: dict):
|
| 82 |
+
"""
|
| 83 |
+
Validate configuration before running benchmarks.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
config: Configuration dictionary
|
| 87 |
+
|
| 88 |
+
Raises:
|
| 89 |
+
ValueError: If configuration is invalid
|
| 90 |
+
"""
|
| 91 |
+
data_config = config.get('data', {})
|
| 92 |
+
excel_path = data_config.get('excel_path', 'Benchmark 80 sequences.xlsx')
|
| 93 |
+
|
| 94 |
+
if not os.path.exists(excel_path):
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"Benchmark Excel file not found: {excel_path}\n"
|
| 97 |
+
"Please provide a valid path to your benchmark sequences file."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
model_config = config.get('model', {})
|
| 101 |
+
checkpoint_path = model_config.get('checkpoint_path')
|
| 102 |
+
|
| 103 |
+
# Check if checkpoint exists locally, or will be downloaded from HF
|
| 104 |
+
if checkpoint_path and os.path.exists(checkpoint_path):
|
| 105 |
+
print(f"Using local checkpoint: {checkpoint_path}")
|
| 106 |
+
else:
|
| 107 |
+
print(f"Checkpoint not found locally: {checkpoint_path}")
|
| 108 |
+
print("Will attempt to download from Hugging Face (saketh11/ColiFormer) if needed")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def main():
|
| 112 |
+
"""Main entry point for benchmark evaluation."""
|
| 113 |
+
parser = argparse.ArgumentParser(
|
| 114 |
+
description="Run benchmark evaluation for ENCOT",
|
| 115 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 116 |
+
epilog="""
|
| 117 |
+
Examples:
|
| 118 |
+
# Run with configuration file
|
| 119 |
+
python scripts/run_benchmarks.py --config configs/benchmark.yaml
|
| 120 |
+
|
| 121 |
+
# Run with command-line arguments
|
| 122 |
+
python scripts/run_benchmarks.py --excel_path Benchmark_80_sequences.xlsx --checkpoint_path models/my_model.ckpt
|
| 123 |
+
|
| 124 |
+
# Override config values
|
| 125 |
+
python scripts/run_benchmarks.py --config configs/benchmark.yaml --use_gpu --max_test_proteins 50
|
| 126 |
+
"""
|
| 127 |
+
)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--config",
|
| 130 |
+
type=str,
|
| 131 |
+
default=None,
|
| 132 |
+
help="Path to YAML configuration file"
|
| 133 |
+
)
|
| 134 |
+
parser.add_argument(
|
| 135 |
+
"--excel_path",
|
| 136 |
+
type=str,
|
| 137 |
+
default=None,
|
| 138 |
+
help="Path to benchmark Excel file (overrides config)"
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--checkpoint_path",
|
| 142 |
+
type=str,
|
| 143 |
+
default=None,
|
| 144 |
+
help="Path to model checkpoint (overrides config)"
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--output_dir",
|
| 148 |
+
type=str,
|
| 149 |
+
default=None,
|
| 150 |
+
help="Output directory for results (overrides config)"
|
| 151 |
+
)
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--use_gpu",
|
| 154 |
+
action="store_true",
|
| 155 |
+
help="Use GPU if available (overrides config)"
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--max_test_proteins",
|
| 159 |
+
type=int,
|
| 160 |
+
default=None,
|
| 161 |
+
help="Maximum number of proteins to test (overrides config)"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
args = parser.parse_args()
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
# Lazy import so `--help` works even if plotting/ML deps are missing.
|
| 168 |
+
from benchmark_evaluation import main as benchmark_main
|
| 169 |
+
|
| 170 |
+
if args.config:
|
| 171 |
+
# Load configuration from file
|
| 172 |
+
print(f"Loading configuration from {args.config}...")
|
| 173 |
+
config = load_config(args.config)
|
| 174 |
+
|
| 175 |
+
# Override with command-line arguments if provided
|
| 176 |
+
if args.excel_path:
|
| 177 |
+
config.setdefault('data', {})['excel_path'] = args.excel_path
|
| 178 |
+
if args.checkpoint_path:
|
| 179 |
+
config.setdefault('model', {})['checkpoint_path'] = args.checkpoint_path
|
| 180 |
+
if args.output_dir:
|
| 181 |
+
config.setdefault('output', {})['output_dir'] = args.output_dir
|
| 182 |
+
if args.use_gpu:
|
| 183 |
+
config.setdefault('evaluation', {})['use_gpu'] = True
|
| 184 |
+
if args.max_test_proteins is not None:
|
| 185 |
+
config.setdefault('evaluation', {})['max_test_proteins'] = args.max_test_proteins
|
| 186 |
+
|
| 187 |
+
# Validate configuration
|
| 188 |
+
validate_config(config)
|
| 189 |
+
|
| 190 |
+
# Convert config to args namespace
|
| 191 |
+
benchmark_args = config_to_args(config)
|
| 192 |
+
else:
|
| 193 |
+
# Use command-line arguments directly
|
| 194 |
+
if not args.excel_path:
|
| 195 |
+
parser.error("Either --config or --excel_path must be provided")
|
| 196 |
+
|
| 197 |
+
benchmark_args = argparse.Namespace()
|
| 198 |
+
benchmark_args.excel_path = args.excel_path
|
| 199 |
+
benchmark_args.checkpoint_path = args.checkpoint_path or 'models/alm-enhanced-training/balanced_alm_finetune.ckpt'
|
| 200 |
+
benchmark_args.natural_sequences_path = 'data/ecoli_processed_genes.csv'
|
| 201 |
+
benchmark_args.output_dir = args.output_dir or 'benchmark_results'
|
| 202 |
+
benchmark_args.use_gpu = args.use_gpu
|
| 203 |
+
benchmark_args.max_test_proteins = args.max_test_proteins or 0
|
| 204 |
+
benchmark_args.name_col = None
|
| 205 |
+
benchmark_args.seq_col = None
|
| 206 |
+
benchmark_args.sheet_name = None
|
| 207 |
+
|
| 208 |
+
# Validate
|
| 209 |
+
if not os.path.exists(benchmark_args.excel_path):
|
| 210 |
+
raise ValueError(f"Benchmark Excel file not found: {benchmark_args.excel_path}")
|
| 211 |
+
|
| 212 |
+
# Print configuration summary
|
| 213 |
+
print("\n" + "="*60)
|
| 214 |
+
print("Benchmark Configuration Summary")
|
| 215 |
+
print("="*60)
|
| 216 |
+
print(f"Excel file: {benchmark_args.excel_path}")
|
| 217 |
+
print(f"Checkpoint: {benchmark_args.checkpoint_path}")
|
| 218 |
+
print(f"Output directory: {benchmark_args.output_dir}")
|
| 219 |
+
print(f"Use GPU: {benchmark_args.use_gpu}")
|
| 220 |
+
print(f"Max test proteins: {benchmark_args.max_test_proteins if benchmark_args.max_test_proteins > 0 else 'All'}")
|
| 221 |
+
print("="*60 + "\n")
|
| 222 |
+
|
| 223 |
+
# Run benchmark
|
| 224 |
+
benchmark_main(benchmark_args)
|
| 225 |
+
|
| 226 |
+
except Exception as e:
|
| 227 |
+
print(f"Error: {e}", file=sys.stderr)
|
| 228 |
+
import traceback
|
| 229 |
+
traceback.print_exc()
|
| 230 |
+
sys.exit(1)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
|
| 234 |
+
main()
|
| 235 |
+
|
scripts/train.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training entry point for ColiFormer.
|
| 3 |
+
|
| 4 |
+
This script wraps finetune.py and loads configuration from YAML files.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python scripts/train.py --config configs/train_ecoli_alm.yaml
|
| 8 |
+
python scripts/train.py --config configs/train_ecoli_quick.yaml
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
# Add parent directory to path to import finetune
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_config(config_path: str) -> dict:
|
| 21 |
+
"""
|
| 22 |
+
Load configuration from YAML file.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
config_path: Path to YAML config file
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
Dictionary with configuration values
|
| 29 |
+
"""
|
| 30 |
+
# Lazy import so `python scripts/train.py --help` works without dependencies installed.
|
| 31 |
+
import yaml
|
| 32 |
+
|
| 33 |
+
if not os.path.exists(config_path):
|
| 34 |
+
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 35 |
+
|
| 36 |
+
with open(config_path, 'r') as f:
|
| 37 |
+
config = yaml.safe_load(f)
|
| 38 |
+
|
| 39 |
+
return config
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def config_to_args(config: dict) -> argparse.Namespace:
|
| 43 |
+
"""
|
| 44 |
+
Convert config dictionary to argparse.Namespace compatible with finetune.py.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
config: Configuration dictionary from YAML
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
argparse.Namespace with all required arguments
|
| 51 |
+
"""
|
| 52 |
+
# Extract nested config values
|
| 53 |
+
data_config = config.get('data', {})
|
| 54 |
+
training_config = config.get('training', {})
|
| 55 |
+
checkpoint_config = config.get('checkpoint', {})
|
| 56 |
+
alm_config = config.get('alm', {})
|
| 57 |
+
gc_penalty_config = config.get('gc_penalty', {})
|
| 58 |
+
|
| 59 |
+
# Build args namespace
|
| 60 |
+
args = argparse.Namespace()
|
| 61 |
+
|
| 62 |
+
# Data paths
|
| 63 |
+
args.dataset_dir = data_config.get('dataset_dir', 'data')
|
| 64 |
+
|
| 65 |
+
# Checkpoint paths
|
| 66 |
+
args.checkpoint_dir = checkpoint_config.get('checkpoint_dir', 'models/checkpoints')
|
| 67 |
+
args.checkpoint_filename = checkpoint_config.get('checkpoint_filename', 'finetune.ckpt')
|
| 68 |
+
|
| 69 |
+
# Training parameters
|
| 70 |
+
args.batch_size = training_config.get('batch_size', 6)
|
| 71 |
+
args.max_epochs = training_config.get('max_epochs', 15)
|
| 72 |
+
args.num_workers = training_config.get('num_workers', 5)
|
| 73 |
+
args.accumulate_grad_batches = training_config.get('accumulate_grad_batches', 1)
|
| 74 |
+
args.num_gpus = training_config.get('num_gpus', 4)
|
| 75 |
+
args.learning_rate = training_config.get('learning_rate', 5e-5)
|
| 76 |
+
args.warmup_fraction = training_config.get('warmup_fraction', 0.1)
|
| 77 |
+
args.save_every_n_steps = training_config.get('save_every_n_steps', 512)
|
| 78 |
+
args.seed = training_config.get('seed', 123)
|
| 79 |
+
args.log_every_n_steps = training_config.get('log_every_n_steps', 20)
|
| 80 |
+
args.debug = training_config.get('debug', False)
|
| 81 |
+
|
| 82 |
+
# GC penalty (legacy)
|
| 83 |
+
args.gc_penalty_weight = gc_penalty_config.get('weight', 0.0)
|
| 84 |
+
|
| 85 |
+
# ALM parameters
|
| 86 |
+
args.use_lagrangian = alm_config.get('enabled', False)
|
| 87 |
+
args.gc_target = alm_config.get('gc_target', 0.52)
|
| 88 |
+
args.curriculum_epochs = alm_config.get('curriculum_epochs', 3)
|
| 89 |
+
args.lagrangian_rho = alm_config.get('initial_penalty_factor', 20.0) # Use initial_penalty_factor as rho
|
| 90 |
+
args.alm_tolerance = alm_config.get('tolerance', 1e-5)
|
| 91 |
+
args.alm_dual_tolerance = alm_config.get('dual_tolerance', 1e-5)
|
| 92 |
+
args.alm_penalty_update_factor = alm_config.get('penalty_update_factor', 10.0)
|
| 93 |
+
args.alm_initial_penalty_factor = alm_config.get('initial_penalty_factor', 20.0)
|
| 94 |
+
args.alm_tolerance_update_factor = alm_config.get('tolerance_update_factor', 0.1)
|
| 95 |
+
args.alm_rel_penalty_increase_threshold = alm_config.get('rel_penalty_increase_threshold', 0.1)
|
| 96 |
+
args.alm_max_penalty = alm_config.get('max_penalty', 1e6)
|
| 97 |
+
args.alm_min_penalty = alm_config.get('min_penalty', 1e-6)
|
| 98 |
+
|
| 99 |
+
return args
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def validate_config(config: dict):
|
| 103 |
+
"""
|
| 104 |
+
Validate configuration before training.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
config: Configuration dictionary
|
| 108 |
+
|
| 109 |
+
Raises:
|
| 110 |
+
ValueError: If configuration is invalid
|
| 111 |
+
"""
|
| 112 |
+
data_config = config.get('data', {})
|
| 113 |
+
dataset_dir = data_config.get('dataset_dir', 'data')
|
| 114 |
+
|
| 115 |
+
# Check dataset directory exists
|
| 116 |
+
if not os.path.exists(dataset_dir):
|
| 117 |
+
raise ValueError(f"Dataset directory not found: {dataset_dir}")
|
| 118 |
+
|
| 119 |
+
# Check for expected data files
|
| 120 |
+
finetune_set = os.path.join(dataset_dir, 'finetune_set.json')
|
| 121 |
+
if not os.path.exists(finetune_set):
|
| 122 |
+
raise ValueError(
|
| 123 |
+
f"Training data not found: {finetune_set}\n"
|
| 124 |
+
"Please run data preprocessing first:\n"
|
| 125 |
+
" python scripts/preprocess_data.py"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Validate checkpoint directory can be created
|
| 129 |
+
checkpoint_config = config.get('checkpoint', {})
|
| 130 |
+
checkpoint_dir = checkpoint_config.get('checkpoint_dir', 'models/checkpoints')
|
| 131 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def main():
|
| 135 |
+
"""Main entry point for training."""
|
| 136 |
+
parser = argparse.ArgumentParser(
|
| 137 |
+
description="Train ENCOT model with configuration file",
|
| 138 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 139 |
+
epilog="""
|
| 140 |
+
Examples:
|
| 141 |
+
# Train with main ALM configuration
|
| 142 |
+
python scripts/train.py --config configs/train_ecoli_alm.yaml
|
| 143 |
+
|
| 144 |
+
# Quick test training (CPU, 1 epoch)
|
| 145 |
+
python scripts/train.py --config configs/train_ecoli_quick.yaml
|
| 146 |
+
|
| 147 |
+
# Override config values from command line
|
| 148 |
+
python scripts/train.py --config configs/train_ecoli_alm.yaml --num_gpus 2 --batch_size 4
|
| 149 |
+
"""
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--config",
|
| 153 |
+
type=str,
|
| 154 |
+
required=True,
|
| 155 |
+
help="Path to YAML configuration file"
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--num_gpus",
|
| 159 |
+
type=int,
|
| 160 |
+
default=None,
|
| 161 |
+
help="Override number of GPUs from config"
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--batch_size",
|
| 165 |
+
type=int,
|
| 166 |
+
default=None,
|
| 167 |
+
help="Override batch size from config"
|
| 168 |
+
)
|
| 169 |
+
parser.add_argument(
|
| 170 |
+
"--max_epochs",
|
| 171 |
+
type=int,
|
| 172 |
+
default=None,
|
| 173 |
+
help="Override max epochs from config"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
args = parser.parse_args()
|
| 177 |
+
|
| 178 |
+
try:
|
| 179 |
+
# Lazy import so `--help` works even if training deps are missing.
|
| 180 |
+
from finetune import main as finetune_main
|
| 181 |
+
|
| 182 |
+
# Load configuration
|
| 183 |
+
print(f"Loading configuration from {args.config}...")
|
| 184 |
+
config = load_config(args.config)
|
| 185 |
+
|
| 186 |
+
# Override with command-line arguments if provided
|
| 187 |
+
if args.num_gpus is not None:
|
| 188 |
+
config.setdefault('training', {})['num_gpus'] = args.num_gpus
|
| 189 |
+
if args.batch_size is not None:
|
| 190 |
+
config.setdefault('training', {})['batch_size'] = args.batch_size
|
| 191 |
+
if args.max_epochs is not None:
|
| 192 |
+
config.setdefault('training', {})['max_epochs'] = args.max_epochs
|
| 193 |
+
|
| 194 |
+
# Validate configuration
|
| 195 |
+
print("Validating configuration...")
|
| 196 |
+
validate_config(config)
|
| 197 |
+
|
| 198 |
+
# Convert config to args namespace
|
| 199 |
+
train_args = config_to_args(config)
|
| 200 |
+
|
| 201 |
+
# Print training summary
|
| 202 |
+
print("\n" + "="*60)
|
| 203 |
+
print("Training Configuration Summary")
|
| 204 |
+
print("="*60)
|
| 205 |
+
print(f"Dataset directory: {train_args.dataset_dir}")
|
| 206 |
+
print(f"Checkpoint directory: {train_args.checkpoint_dir}")
|
| 207 |
+
print(f"Checkpoint filename: {train_args.checkpoint_filename}")
|
| 208 |
+
print(f"Batch size: {train_args.batch_size}")
|
| 209 |
+
print(f"Max epochs: {train_args.max_epochs}")
|
| 210 |
+
print(f"Learning rate: {train_args.learning_rate}")
|
| 211 |
+
print(f"Number of GPUs: {train_args.num_gpus}")
|
| 212 |
+
print(f"ALM enabled: {train_args.use_lagrangian}")
|
| 213 |
+
if train_args.use_lagrangian:
|
| 214 |
+
print(f"GC target: {train_args.gc_target}")
|
| 215 |
+
print(f"Curriculum epochs: {train_args.curriculum_epochs}")
|
| 216 |
+
print("="*60 + "\n")
|
| 217 |
+
|
| 218 |
+
# Run training
|
| 219 |
+
finetune_main(train_args)
|
| 220 |
+
|
| 221 |
+
except Exception as e:
|
| 222 |
+
print(f"Error: {e}", file=sys.stderr)
|
| 223 |
+
sys.exit(1)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
if __name__ == "__main__":
|
| 227 |
+
main()
|
| 228 |
+
|
setup.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from setuptools import find_packages, setup
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def read_requirements():
|
| 7 |
+
with open("requirements.txt") as f:
|
| 8 |
+
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def read_readme():
|
| 12 |
+
here = os.path.abspath(os.path.dirname(__file__))
|
| 13 |
+
readme_path = os.path.join(here, "README.md")
|
| 14 |
+
|
| 15 |
+
with open(readme_path, "r", encoding="utf-8") as f:
|
| 16 |
+
return f.read()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
setup(
|
| 20 |
+
name="ENCOT",
|
| 21 |
+
version="1.0.0",
|
| 22 |
+
packages=find_packages(),
|
| 23 |
+
install_requires=read_requirements(),
|
| 24 |
+
author="Adibvafa Fallahpour",
|
| 25 |
+
author_email="Adibvafa.fallahpour@mail.utoronto.ca",
|
| 26 |
+
description=(
|
| 27 |
+
"Transformer-based codon optimization for E. coli using "
|
| 28 |
+
"deep learning with Augmented-Lagrangian GC control. "
|
| 29 |
+
"Built on CodonTransformer for E. coli-specific optimization."
|
| 30 |
+
),
|
| 31 |
+
long_description=read_readme(),
|
| 32 |
+
long_description_content_type="text/markdown",
|
| 33 |
+
url="https://github.com/geno543/ENCOT",
|
| 34 |
+
classifiers=[
|
| 35 |
+
"Programming Language :: Python :: 3",
|
| 36 |
+
"License :: OSI Approved :: Apache Software License",
|
| 37 |
+
"Operating System :: OS Independent",
|
| 38 |
+
],
|
| 39 |
+
python_requires=">=3.9",
|
| 40 |
+
)
|
src/CodonTransformer_inference_template.xlsx
ADDED
|
Binary file (17.4 kB). View file
|
|
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Model weights, tokenizer, and other resources."""
|
src/banner_final.png
ADDED
|
Git LFS Details
|
src/organism2id.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:44f7b73bbb3c6ea82bf864e886b57b219cbd5f14fe79a8aa47d2befab5d40ad0
|
| 3 |
+
size 4605
|
streamlit_app.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public Streamlit entrypoint for ENCOT.
|
| 2 |
+
|
| 3 |
+
This file is intentionally minimal so hosting platforms like Streamlit
|
| 4 |
+
Community Cloud can run the existing UI without changing project structure.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
ROOT = Path(__file__).resolve().parent
|
| 12 |
+
if str(ROOT) not in sys.path:
|
| 13 |
+
sys.path.insert(0, str(ROOT))
|
| 14 |
+
|
| 15 |
+
# Importing this module runs the Streamlit app defined there.
|
| 16 |
+
import streamlit_gui.app # noqa: F401,E402
|
streamlit_gui/app.py
ADDED
|
@@ -0,0 +1,1456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File: app.py
|
| 3 |
+
-------------
|
| 4 |
+
Streamlit GUI for ENCOT. Provides sequence validation, optimization,
|
| 5 |
+
and visualization for E. coli-focused workflows with optional post-processing.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
import torch
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import numpy as np
|
| 12 |
+
import plotly.graph_objects as go
|
| 13 |
+
import plotly.express as px
|
| 14 |
+
from transformers import AutoTokenizer, BigBirdForMaskedLM
|
| 15 |
+
from huggingface_hub import hf_hub_download
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
import time
|
| 18 |
+
import threading
|
| 19 |
+
from typing import Dict, Optional, Tuple
|
| 20 |
+
import warnings
|
| 21 |
+
warnings.filterwarnings("ignore")
|
| 22 |
+
|
| 23 |
+
import sys
|
| 24 |
+
import os
|
| 25 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 26 |
+
|
| 27 |
+
from CodonTransformer.CodonPrediction import (
|
| 28 |
+
predict_dna_sequence,
|
| 29 |
+
load_model
|
| 30 |
+
)
|
| 31 |
+
from CodonTransformer.CodonEvaluation import (
|
| 32 |
+
get_GC_content,
|
| 33 |
+
calculate_tAI,
|
| 34 |
+
get_ecoli_tai_weights,
|
| 35 |
+
scan_for_restriction_sites,
|
| 36 |
+
count_negative_cis_elements,
|
| 37 |
+
calculate_homopolymer_runs
|
| 38 |
+
)
|
| 39 |
+
from CAI import CAI, relative_adaptiveness
|
| 40 |
+
from CodonTransformer.CodonUtils import get_organism2id_dict
|
| 41 |
+
import json
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
from CodonTransformer.CodonPostProcessing import (
|
| 45 |
+
polish_sequence_with_dnachisel,
|
| 46 |
+
DNACHISEL_AVAILABLE
|
| 47 |
+
)
|
| 48 |
+
POST_PROCESSING_AVAILABLE = True
|
| 49 |
+
except ImportError:
|
| 50 |
+
POST_PROCESSING_AVAILABLE = False
|
| 51 |
+
DNACHISEL_AVAILABLE = False
|
| 52 |
+
|
| 53 |
+
st.set_page_config(
|
| 54 |
+
page_title="ENCOT GUI",
|
| 55 |
+
layout="wide",
|
| 56 |
+
initial_sidebar_state="expanded"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
if 'model' not in st.session_state:
|
| 60 |
+
st.session_state.model = None
|
| 61 |
+
if 'tokenizer' not in st.session_state:
|
| 62 |
+
st.session_state.tokenizer = None
|
| 63 |
+
if 'device' not in st.session_state:
|
| 64 |
+
st.session_state.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 65 |
+
if 'optimization_running' not in st.session_state:
|
| 66 |
+
st.session_state.optimization_running = False
|
| 67 |
+
if 'results' not in st.session_state:
|
| 68 |
+
st.session_state.results = None
|
| 69 |
+
if 'post_processed_results' not in st.session_state:
|
| 70 |
+
st.session_state.post_processed_results = None
|
| 71 |
+
if 'cai_weights' not in st.session_state:
|
| 72 |
+
st.session_state.cai_weights = None
|
| 73 |
+
if 'tai_weights' not in st.session_state:
|
| 74 |
+
st.session_state.tai_weights = None
|
| 75 |
+
|
| 76 |
+
def get_organism_tai_weights(organism: str) -> Dict[str, float]:
|
| 77 |
+
"""Get organism-specific tAI weights from pre-calculated data"""
|
| 78 |
+
try:
|
| 79 |
+
# Load organism-specific tAI weights
|
| 80 |
+
weights_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'organism_tai_weights.json')
|
| 81 |
+
with open(weights_file, 'r') as f:
|
| 82 |
+
all_weights = json.load(f)
|
| 83 |
+
|
| 84 |
+
if organism in all_weights:
|
| 85 |
+
return all_weights[organism]
|
| 86 |
+
else:
|
| 87 |
+
# Fallback to E. coli if organism not found
|
| 88 |
+
st.warning(f"tAI weights for {organism} not found, using E. coli weights")
|
| 89 |
+
return all_weights.get("Escherichia coli general", get_ecoli_tai_weights())
|
| 90 |
+
except Exception as e:
|
| 91 |
+
st.error(f"Error loading organism-specific tAI weights: {e}")
|
| 92 |
+
return get_ecoli_tai_weights()
|
| 93 |
+
|
| 94 |
+
def load_model_and_tokenizer():
|
| 95 |
+
"""Load the model and tokenizer with progress tracking"""
|
| 96 |
+
if st.session_state.model is None or st.session_state.tokenizer is None:
|
| 97 |
+
with st.spinner("Loading model... This may take a few minutes."):
|
| 98 |
+
progress_bar = st.progress(0)
|
| 99 |
+
status_text = st.empty()
|
| 100 |
+
|
| 101 |
+
status_text.text("Loading tokenizer...")
|
| 102 |
+
progress_bar.progress(25)
|
| 103 |
+
st.session_state.tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 104 |
+
|
| 105 |
+
status_text.text("Loading fine-tuned model from Hugging Face...")
|
| 106 |
+
progress_bar.progress(50)
|
| 107 |
+
# Try to download and load fine-tuned model from Hugging Face
|
| 108 |
+
try:
|
| 109 |
+
# Download the checkpoint file from Hugging Face
|
| 110 |
+
from huggingface_hub import hf_hub_download
|
| 111 |
+
|
| 112 |
+
status_text.text("Downloading model from saketh11/ColiFormer...")
|
| 113 |
+
model_path = hf_hub_download(
|
| 114 |
+
repo_id="saketh11/ColiFormer",
|
| 115 |
+
filename="balanced_alm_finetune.ckpt",
|
| 116 |
+
cache_dir="./hf_cache"
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
status_text.text("Loading downloaded model...")
|
| 120 |
+
st.session_state.model = load_model(
|
| 121 |
+
model_path=model_path,
|
| 122 |
+
device=st.session_state.device,
|
| 123 |
+
attention_type="original_full"
|
| 124 |
+
)
|
| 125 |
+
status_text.text("Fine-tuned model loaded from Hugging Face")
|
| 126 |
+
st.session_state.model_type = "fine_tuned_hf"
|
| 127 |
+
except Exception as e:
|
| 128 |
+
status_text.text(f"Failed to load from Hugging Face: {str(e)[:50]}...")
|
| 129 |
+
status_text.text("Loading base model as fallback...")
|
| 130 |
+
st.session_state.model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
|
| 131 |
+
st.session_state.model = st.session_state.model.to(st.session_state.device)
|
| 132 |
+
st.session_state.model_type = "base"
|
| 133 |
+
|
| 134 |
+
progress_bar.progress(100)
|
| 135 |
+
time.sleep(0.5)
|
| 136 |
+
|
| 137 |
+
status_text.empty()
|
| 138 |
+
progress_bar.empty()
|
| 139 |
+
|
| 140 |
+
@st.cache_data
|
| 141 |
+
def download_reference_data():
|
| 142 |
+
"""Download and cache reference data from Hugging Face"""
|
| 143 |
+
try:
|
| 144 |
+
# Download the processed genes file from Hugging Face
|
| 145 |
+
file_path = hf_hub_download(
|
| 146 |
+
repo_id="saketh11/ColiFormer-Data",
|
| 147 |
+
filename="ecoli_processed_genes.csv",
|
| 148 |
+
repo_type="dataset"
|
| 149 |
+
)
|
| 150 |
+
df = pd.read_csv(file_path)
|
| 151 |
+
return df['dna_sequence'].tolist()
|
| 152 |
+
except Exception as e:
|
| 153 |
+
st.warning(f"Could not download reference data from Hugging Face: {e}")
|
| 154 |
+
# Fallback to minimal sequences
|
| 155 |
+
return [
|
| 156 |
+
"ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC",
|
| 157 |
+
"ATGAAATTTATTTATTATTATAAATTTATTTATTATTATAAATTTATTTAT",
|
| 158 |
+
"ATGGGTCGTCGTCGTCGTGGTCGTCGTCGTCGTGGTCGTCGTCGTCGTGGT"
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
@st.cache_data
|
| 162 |
+
def download_tai_weights():
|
| 163 |
+
"""Download and cache tAI weights from Hugging Face"""
|
| 164 |
+
try:
|
| 165 |
+
# Download the tAI weights file from Hugging Face
|
| 166 |
+
file_path = hf_hub_download(
|
| 167 |
+
repo_id="saketh11/ColiFormer-Data",
|
| 168 |
+
filename="organism_tai_weights.json",
|
| 169 |
+
repo_type="dataset"
|
| 170 |
+
)
|
| 171 |
+
with open(file_path, 'r') as f:
|
| 172 |
+
all_weights = json.load(f)
|
| 173 |
+
return all_weights.get("Escherichia coli general", get_ecoli_tai_weights())
|
| 174 |
+
except Exception as e:
|
| 175 |
+
st.warning(f"Could not download tAI weights from Hugging Face: {e}")
|
| 176 |
+
return get_ecoli_tai_weights()
|
| 177 |
+
|
| 178 |
+
def load_reference_data(organism: str = "Escherichia coli general"):
|
| 179 |
+
"""Load reference sequences and tAI weights for E. coli"""
|
| 180 |
+
if 'cai_weights' not in st.session_state or st.session_state['cai_weights'] is None:
|
| 181 |
+
try:
|
| 182 |
+
# Download reference sequences from Hugging Face
|
| 183 |
+
with st.spinner("Downloading E. coli reference sequences from Hugging Face..."):
|
| 184 |
+
ref_sequences = download_reference_data()
|
| 185 |
+
st.session_state['cai_weights'] = relative_adaptiveness(sequences=ref_sequences)
|
| 186 |
+
if len(ref_sequences) > 100: # If we got the full dataset
|
| 187 |
+
st.success(f"Downloaded {len(ref_sequences):,} E. coli reference sequences for CAI calculation")
|
| 188 |
+
else:
|
| 189 |
+
st.info(f"Using {len(ref_sequences)} minimal reference sequences (full dataset unavailable)")
|
| 190 |
+
except Exception as e:
|
| 191 |
+
st.error(f"Error loading E. coli reference data: {e}")
|
| 192 |
+
st.session_state['cai_weights'] = {}
|
| 193 |
+
# tAI weights (E. coli only)
|
| 194 |
+
if 'tai_weights' not in st.session_state or st.session_state['tai_weights'] is None:
|
| 195 |
+
try:
|
| 196 |
+
with st.spinner("Downloading E. coli tAI weights from Hugging Face..."):
|
| 197 |
+
st.session_state['tai_weights'] = download_tai_weights()
|
| 198 |
+
st.success("Downloaded E. coli tAI weights")
|
| 199 |
+
except Exception as e:
|
| 200 |
+
st.error(f"Error loading E. coli tAI weights: {e}")
|
| 201 |
+
st.session_state['tai_weights'] = {}
|
| 202 |
+
|
| 203 |
+
def validate_sequence(sequence: str) -> Tuple[bool, str, str, str]:
|
| 204 |
+
"""Validate sequence and return status, message, sequence type, and possibly fixed sequence"""
|
| 205 |
+
if not sequence:
|
| 206 |
+
return False, "Sequence cannot be empty", "unknown", sequence
|
| 207 |
+
|
| 208 |
+
# Remove whitespace and convert to uppercase
|
| 209 |
+
sequence = sequence.strip().upper()
|
| 210 |
+
|
| 211 |
+
# Check if it's a DNA sequence
|
| 212 |
+
dna_chars = set("ATGC")
|
| 213 |
+
protein_chars = set("ACDEFGHIKLMNPQRSTVWY*_")
|
| 214 |
+
|
| 215 |
+
sequence_chars = set(sequence)
|
| 216 |
+
|
| 217 |
+
# If all characters are DNA nucleotides, treat as DNA
|
| 218 |
+
if sequence_chars.issubset(dna_chars):
|
| 219 |
+
if len(sequence) < 3:
|
| 220 |
+
return False, "DNA sequence must be at least 3 nucleotides long", "dna", sequence
|
| 221 |
+
|
| 222 |
+
# Auto-fix DNA sequences not divisible by 3
|
| 223 |
+
if len(sequence) % 3 != 0:
|
| 224 |
+
remainder = len(sequence) % 3
|
| 225 |
+
fixed_sequence = sequence[:-remainder]
|
| 226 |
+
message = f"Valid DNA sequence (auto-fixed: removed {remainder} nucleotides from end to make divisible by 3)"
|
| 227 |
+
else:
|
| 228 |
+
fixed_sequence = sequence
|
| 229 |
+
message = "Valid DNA sequence"
|
| 230 |
+
|
| 231 |
+
return True, message, "dna", fixed_sequence
|
| 232 |
+
|
| 233 |
+
# If contains protein-specific amino acids, treat as protein
|
| 234 |
+
elif sequence_chars.issubset(protein_chars):
|
| 235 |
+
if len(sequence) < 3:
|
| 236 |
+
return False, "Protein sequence must be at least 3 amino acids long", "protein", sequence
|
| 237 |
+
return True, "Valid protein sequence", "protein", sequence
|
| 238 |
+
|
| 239 |
+
# Invalid characters
|
| 240 |
+
else:
|
| 241 |
+
invalid_chars = sequence_chars - (dna_chars | protein_chars)
|
| 242 |
+
return False, f"Invalid characters found: {', '.join(invalid_chars)}", "unknown", sequence
|
| 243 |
+
|
| 244 |
+
def calculate_input_metrics(sequence: str, organism: str, sequence_type: str) -> Dict:
|
| 245 |
+
"""Calculate metrics for the input sequence using E. coli reference only"""
|
| 246 |
+
# Load reference data (E. coli only)
|
| 247 |
+
load_reference_data()
|
| 248 |
+
if sequence_type == "dna":
|
| 249 |
+
dna_sequence = sequence.upper()
|
| 250 |
+
metrics = {
|
| 251 |
+
'length': len(dna_sequence) // 3,
|
| 252 |
+
'gc_content': get_GC_content(dna_sequence),
|
| 253 |
+
'baseline_dna': dna_sequence,
|
| 254 |
+
'sequence_type': 'dna'
|
| 255 |
+
}
|
| 256 |
+
try:
|
| 257 |
+
if 'cai_weights' in st.session_state and st.session_state['cai_weights']:
|
| 258 |
+
metrics['cai'] = CAI(dna_sequence, weights=st.session_state['cai_weights'])
|
| 259 |
+
else:
|
| 260 |
+
metrics['cai'] = None
|
| 261 |
+
except:
|
| 262 |
+
metrics['cai'] = None
|
| 263 |
+
try:
|
| 264 |
+
if 'tai_weights' in st.session_state and st.session_state['tai_weights']:
|
| 265 |
+
metrics['tai'] = calculate_tAI(dna_sequence, st.session_state['tai_weights'])
|
| 266 |
+
else:
|
| 267 |
+
metrics['tai'] = None
|
| 268 |
+
except:
|
| 269 |
+
metrics['tai'] = None
|
| 270 |
+
else:
|
| 271 |
+
most_frequent_codons = {
|
| 272 |
+
'A': 'GCG', 'C': 'TGC', 'D': 'GAT', 'E': 'GAA', 'F': 'TTT',
|
| 273 |
+
'G': 'GGC', 'H': 'CAT', 'I': 'ATT', 'K': 'AAA', 'L': 'CTG',
|
| 274 |
+
'M': 'ATG', 'N': 'AAC', 'P': 'CCG', 'Q': 'CAG', 'R': 'CGC',
|
| 275 |
+
'S': 'TCG', 'T': 'ACG', 'V': 'GTG', 'W': 'TGG', 'Y': 'TAT',
|
| 276 |
+
'*': 'TAA', '_': 'TAA'
|
| 277 |
+
}
|
| 278 |
+
baseline_dna = ''.join([most_frequent_codons.get(aa, 'NNN') for aa in sequence])
|
| 279 |
+
metrics = {
|
| 280 |
+
'length': len(sequence),
|
| 281 |
+
'gc_content': get_GC_content(baseline_dna),
|
| 282 |
+
'baseline_dna': baseline_dna,
|
| 283 |
+
'sequence_type': 'protein'
|
| 284 |
+
}
|
| 285 |
+
try:
|
| 286 |
+
if 'cai_weights' in st.session_state and st.session_state['cai_weights']:
|
| 287 |
+
metrics['cai'] = CAI(baseline_dna, weights=st.session_state['cai_weights'])
|
| 288 |
+
else:
|
| 289 |
+
metrics['cai'] = None
|
| 290 |
+
except:
|
| 291 |
+
metrics['cai'] = None
|
| 292 |
+
try:
|
| 293 |
+
if 'tai_weights' in st.session_state and st.session_state['tai_weights']:
|
| 294 |
+
metrics['tai'] = calculate_tAI(baseline_dna, st.session_state['tai_weights'])
|
| 295 |
+
else:
|
| 296 |
+
metrics['tai'] = None
|
| 297 |
+
except:
|
| 298 |
+
metrics['tai'] = None
|
| 299 |
+
try:
|
| 300 |
+
analysis_dna = metrics['baseline_dna']
|
| 301 |
+
metrics['restriction_sites'] = len(scan_for_restriction_sites(analysis_dna))
|
| 302 |
+
metrics['negative_cis_elements'] = count_negative_cis_elements(analysis_dna)
|
| 303 |
+
metrics['homopolymer_runs'] = calculate_homopolymer_runs(analysis_dna)
|
| 304 |
+
except:
|
| 305 |
+
metrics['restriction_sites'] = 0
|
| 306 |
+
metrics['negative_cis_elements'] = 0
|
| 307 |
+
metrics['homopolymer_runs'] = 0
|
| 308 |
+
return metrics
|
| 309 |
+
|
| 310 |
+
def translate_dna_to_protein(dna_sequence: str) -> str:
|
| 311 |
+
"""Translate DNA sequence to protein sequence"""
|
| 312 |
+
codon_table = {
|
| 313 |
+
'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
|
| 314 |
+
'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
|
| 315 |
+
'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
|
| 316 |
+
'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
|
| 317 |
+
'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
|
| 318 |
+
'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
|
| 319 |
+
'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
|
| 320 |
+
'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
|
| 321 |
+
'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
|
| 322 |
+
'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
|
| 323 |
+
'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
|
| 324 |
+
'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
|
| 325 |
+
'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
|
| 326 |
+
'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
|
| 327 |
+
'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
|
| 328 |
+
'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
protein = ""
|
| 332 |
+
for i in range(0, len(dna_sequence), 3):
|
| 333 |
+
codon = dna_sequence[i:i+3].upper()
|
| 334 |
+
if len(codon) == 3:
|
| 335 |
+
aa = codon_table.get(codon, 'X')
|
| 336 |
+
if aa == '*': # Stop codon
|
| 337 |
+
break
|
| 338 |
+
protein += aa
|
| 339 |
+
|
| 340 |
+
return protein
|
| 341 |
+
|
| 342 |
+
def create_gc_content_plot(sequence: str, window_size: int = 50) -> go.Figure:
|
| 343 |
+
"""Create a sliding window GC content plot"""
|
| 344 |
+
if len(sequence) < window_size:
|
| 345 |
+
window_size = len(sequence) // 3
|
| 346 |
+
|
| 347 |
+
positions = []
|
| 348 |
+
gc_values = []
|
| 349 |
+
|
| 350 |
+
for i in range(0, len(sequence) - window_size + 1, 3): # Step by codons
|
| 351 |
+
window = sequence[i:i + window_size]
|
| 352 |
+
gc_content = get_GC_content(window)
|
| 353 |
+
positions.append(i // 3) # Position in codons
|
| 354 |
+
gc_values.append(gc_content)
|
| 355 |
+
|
| 356 |
+
fig = go.Figure()
|
| 357 |
+
fig.add_trace(go.Scatter(
|
| 358 |
+
x=positions,
|
| 359 |
+
y=gc_values,
|
| 360 |
+
mode='lines',
|
| 361 |
+
name='GC Content',
|
| 362 |
+
line=dict(color='blue', width=2)
|
| 363 |
+
))
|
| 364 |
+
|
| 365 |
+
# Add target range
|
| 366 |
+
fig.add_hline(y=45, line_dash="dash", line_color="red",
|
| 367 |
+
annotation_text="Min Target (45%)")
|
| 368 |
+
fig.add_hline(y=55, line_dash="dash", line_color="red",
|
| 369 |
+
annotation_text="Max Target (55%)")
|
| 370 |
+
|
| 371 |
+
fig.update_layout(
|
| 372 |
+
title=f'GC Content (sliding window: {window_size} bp)',
|
| 373 |
+
xaxis_title='Position (codons)',
|
| 374 |
+
yaxis_title='GC Content (%)',
|
| 375 |
+
height=300
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
return fig
|
| 379 |
+
|
| 380 |
+
def create_gc_comparison_chart(before_metrics: Dict, after_metrics: Dict) -> go.Figure:
|
| 381 |
+
"""Create a comparison chart for GC Content"""
|
| 382 |
+
fig = go.Figure()
|
| 383 |
+
fig.add_trace(go.Bar(
|
| 384 |
+
name='Before Optimization',
|
| 385 |
+
x=['GC Content (%)'],
|
| 386 |
+
y=[before_metrics.get('gc_content', 0)],
|
| 387 |
+
marker_color='lightblue',
|
| 388 |
+
text=[f"{before_metrics.get('gc_content', 0):.1f}%"],
|
| 389 |
+
textposition='auto'
|
| 390 |
+
))
|
| 391 |
+
fig.add_trace(go.Bar(
|
| 392 |
+
name='After Optimization',
|
| 393 |
+
x=['GC Content (%)'],
|
| 394 |
+
y=[after_metrics.get('gc_content', 0)],
|
| 395 |
+
marker_color='darkblue',
|
| 396 |
+
text=[f"{after_metrics.get('gc_content', 0):.1f}%"],
|
| 397 |
+
textposition='auto'
|
| 398 |
+
))
|
| 399 |
+
fig.update_layout(
|
| 400 |
+
title='GC Content Comparison: Before vs After',
|
| 401 |
+
xaxis_title='Metric',
|
| 402 |
+
yaxis_title='Value (%)',
|
| 403 |
+
barmode='group',
|
| 404 |
+
height=300
|
| 405 |
+
)
|
| 406 |
+
return fig
|
| 407 |
+
|
| 408 |
+
def create_expression_comparison_chart(before_metrics: Dict, after_metrics: Dict) -> go.Figure:
|
| 409 |
+
"""Create a comparison chart for expression metrics (CAI, tAI)"""
|
| 410 |
+
metrics_names = ['CAI', 'tAI']
|
| 411 |
+
before_values = [
|
| 412 |
+
before_metrics.get('cai', 0) if before_metrics.get('cai') else 0,
|
| 413 |
+
before_metrics.get('tai', 0) if before_metrics.get('tai') else 0
|
| 414 |
+
]
|
| 415 |
+
after_values = [
|
| 416 |
+
after_metrics.get('cai', 0) if after_metrics.get('cai') else 0,
|
| 417 |
+
after_metrics.get('tai', 0) if after_metrics.get('tai') else 0
|
| 418 |
+
]
|
| 419 |
+
|
| 420 |
+
fig = go.Figure()
|
| 421 |
+
fig.add_trace(go.Bar(
|
| 422 |
+
name='Before Optimization',
|
| 423 |
+
x=metrics_names,
|
| 424 |
+
y=before_values,
|
| 425 |
+
marker_color='lightblue',
|
| 426 |
+
text=[f"{v:.3f}" for v in before_values],
|
| 427 |
+
textposition='auto'
|
| 428 |
+
))
|
| 429 |
+
fig.add_trace(go.Bar(
|
| 430 |
+
name='After Optimization',
|
| 431 |
+
x=metrics_names,
|
| 432 |
+
y=after_values,
|
| 433 |
+
marker_color='darkblue',
|
| 434 |
+
text=[f"{v:.3f}" for v in after_values],
|
| 435 |
+
textposition='auto'
|
| 436 |
+
))
|
| 437 |
+
fig.update_layout(
|
| 438 |
+
title='Expression Metrics Comparison: Before vs After',
|
| 439 |
+
xaxis_title='Metric',
|
| 440 |
+
yaxis_title='Value',
|
| 441 |
+
barmode='group',
|
| 442 |
+
height=300
|
| 443 |
+
)
|
| 444 |
+
return fig
|
| 445 |
+
|
| 446 |
+
def smart_codon_replacement(dna_sequence: str, target_gc_min: float = 0.45, target_gc_max: float = 0.55, max_iterations: int = 100) -> str:
|
| 447 |
+
"""Smart codon replacement to optimize GC content while maximizing CAI"""
|
| 448 |
+
|
| 449 |
+
# Codon alternatives with their GC content
|
| 450 |
+
codon_alternatives = {
|
| 451 |
+
# Serine: high GC options
|
| 452 |
+
'TCT': ['TCG', 'TCC', 'TCA', 'AGT', 'AGC'], # 33% -> 67%, 67%, 33%, 33%, 67%
|
| 453 |
+
'TCA': ['TCG', 'TCC', 'TCT', 'AGT', 'AGC'],
|
| 454 |
+
'AGT': ['TCG', 'TCC', 'TCT', 'TCA', 'AGC'],
|
| 455 |
+
|
| 456 |
+
# Leucine: various GC options
|
| 457 |
+
'TTA': ['TTG', 'CTT', 'CTC', 'CTA', 'CTG'], # 0% -> 33%, 33%, 67%, 33%, 67%
|
| 458 |
+
'TTG': ['TTA', 'CTT', 'CTC', 'CTA', 'CTG'],
|
| 459 |
+
'CTT': ['CTG', 'CTC', 'TTA', 'TTG', 'CTA'],
|
| 460 |
+
'CTA': ['CTG', 'CTC', 'CTT', 'TTA', 'TTG'],
|
| 461 |
+
|
| 462 |
+
# Arginine: various GC options
|
| 463 |
+
'AGA': ['CGT', 'CGC', 'CGA', 'CGG', 'AGG'], # 33% -> 67%, 100%, 67%, 100%, 67%
|
| 464 |
+
'AGG': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA'],
|
| 465 |
+
'CGT': ['CGC', 'CGG', 'CGA', 'AGA', 'AGG'],
|
| 466 |
+
'CGA': ['CGC', 'CGG', 'CGT', 'AGA', 'AGG'],
|
| 467 |
+
|
| 468 |
+
# Proline
|
| 469 |
+
'CCT': ['CCG', 'CCC', 'CCA'], # 67% -> 100%, 100%, 67%
|
| 470 |
+
'CCA': ['CCG', 'CCC', 'CCT'],
|
| 471 |
+
|
| 472 |
+
# Threonine
|
| 473 |
+
'ACT': ['ACG', 'ACC', 'ACA'], # 33% -> 67%, 67%, 33%
|
| 474 |
+
'ACA': ['ACG', 'ACC', 'ACT'],
|
| 475 |
+
|
| 476 |
+
# Alanine
|
| 477 |
+
'GCT': ['GCG', 'GCC', 'GCA'], # 67% -> 100%, 100%, 67%
|
| 478 |
+
'GCA': ['GCG', 'GCC', 'GCT'],
|
| 479 |
+
|
| 480 |
+
# Glycine
|
| 481 |
+
'GGT': ['GGG', 'GGC', 'GGA'], # 67% -> 100%, 100%, 67%
|
| 482 |
+
'GGA': ['GGG', 'GGC', 'GGT'],
|
| 483 |
+
|
| 484 |
+
# Valine
|
| 485 |
+
'GTT': ['GTG', 'GTC', 'GTA'], # 67% -> 100%, 100%, 67%
|
| 486 |
+
'GTA': ['GTG', 'GTC', 'GTT'],
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
def get_codon_gc(codon):
|
| 490 |
+
return (codon.count('G') + codon.count('C')) / 3.0
|
| 491 |
+
|
| 492 |
+
current_sequence = dna_sequence.upper()
|
| 493 |
+
current_gc = get_GC_content(current_sequence)
|
| 494 |
+
|
| 495 |
+
if target_gc_min <= current_gc <= target_gc_max:
|
| 496 |
+
return current_sequence
|
| 497 |
+
|
| 498 |
+
codons = [current_sequence[i:i+3] for i in range(0, len(current_sequence), 3)]
|
| 499 |
+
|
| 500 |
+
for iteration in range(max_iterations):
|
| 501 |
+
current_gc = get_GC_content(''.join(codons))
|
| 502 |
+
|
| 503 |
+
if target_gc_min <= current_gc <= target_gc_max:
|
| 504 |
+
break
|
| 505 |
+
|
| 506 |
+
# Find best codon to replace
|
| 507 |
+
best_improvement = 0
|
| 508 |
+
best_pos = -1
|
| 509 |
+
best_replacement = None
|
| 510 |
+
|
| 511 |
+
for pos, codon in enumerate(codons):
|
| 512 |
+
if codon in codon_alternatives:
|
| 513 |
+
for alt_codon in codon_alternatives[codon]:
|
| 514 |
+
# Calculate GC change
|
| 515 |
+
old_gc_contrib = get_codon_gc(codon)
|
| 516 |
+
new_gc_contrib = get_codon_gc(alt_codon)
|
| 517 |
+
gc_change = new_gc_contrib - old_gc_contrib
|
| 518 |
+
|
| 519 |
+
# Check if this change moves us toward target
|
| 520 |
+
if current_gc < target_gc_min and gc_change > best_improvement:
|
| 521 |
+
best_improvement = gc_change
|
| 522 |
+
best_pos = pos
|
| 523 |
+
best_replacement = alt_codon
|
| 524 |
+
elif current_gc > target_gc_max and gc_change < best_improvement:
|
| 525 |
+
best_improvement = abs(gc_change)
|
| 526 |
+
best_pos = pos
|
| 527 |
+
best_replacement = alt_codon
|
| 528 |
+
|
| 529 |
+
if best_pos >= 0:
|
| 530 |
+
if isinstance(best_replacement, str):
|
| 531 |
+
codons[best_pos] = best_replacement
|
| 532 |
+
else:
|
| 533 |
+
break # No more improvements possible
|
| 534 |
+
|
| 535 |
+
return ''.join(codons)
|
| 536 |
+
|
| 537 |
+
def run_optimization(protein: str, organism: str, use_post_processing: bool = False):
|
| 538 |
+
"""Run the optimization using the exact method from run_full_comparison.py with auto GC correction"""
|
| 539 |
+
st.session_state.optimization_running = True
|
| 540 |
+
st.session_state.post_processed_results = None
|
| 541 |
+
|
| 542 |
+
try:
|
| 543 |
+
# Use the exact same method that achieved best results in evaluation
|
| 544 |
+
result = predict_dna_sequence(
|
| 545 |
+
protein=protein,
|
| 546 |
+
organism=organism,
|
| 547 |
+
device=st.session_state.device,
|
| 548 |
+
model=st.session_state.model,
|
| 549 |
+
deterministic=True,
|
| 550 |
+
match_protein=True,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
# Check GC content and auto-correct if out of optimal range
|
| 554 |
+
_res = result[0] if isinstance(result, list) else result
|
| 555 |
+
initial_gc = get_GC_content(_res.predicted_dna)
|
| 556 |
+
|
| 557 |
+
if initial_gc < 45.0 or initial_gc > 55.0:
|
| 558 |
+
# Auto-correct GC content silently
|
| 559 |
+
optimized_dna = smart_codon_replacement(_res.predicted_dna, 0.45, 0.55)
|
| 560 |
+
smart_gc = get_GC_content(optimized_dna)
|
| 561 |
+
|
| 562 |
+
if 45.0 <= smart_gc <= 55.0:
|
| 563 |
+
from CodonTransformer.CodonUtils import DNASequencePrediction
|
| 564 |
+
result = DNASequencePrediction(
|
| 565 |
+
organism=_res.organism,
|
| 566 |
+
protein=_res.protein,
|
| 567 |
+
processed_input=_res.processed_input,
|
| 568 |
+
predicted_dna=optimized_dna
|
| 569 |
+
)
|
| 570 |
+
else:
|
| 571 |
+
# Fall back to constrained beam search silently
|
| 572 |
+
try:
|
| 573 |
+
result = predict_dna_sequence(
|
| 574 |
+
protein=protein,
|
| 575 |
+
organism=organism,
|
| 576 |
+
device=st.session_state.device,
|
| 577 |
+
model=st.session_state.model,
|
| 578 |
+
deterministic=True,
|
| 579 |
+
match_protein=True,
|
| 580 |
+
use_constrained_search=True,
|
| 581 |
+
gc_bounds=(0.45, 0.55),
|
| 582 |
+
beam_size=20
|
| 583 |
+
)
|
| 584 |
+
_res2 = result[0] if isinstance(result, list) else result
|
| 585 |
+
final_gc = get_GC_content(_res2.predicted_dna)
|
| 586 |
+
except Exception as e:
|
| 587 |
+
# If constrained search fails, use smart replacement result anyway
|
| 588 |
+
from CodonTransformer.CodonUtils import DNASequencePrediction
|
| 589 |
+
result = DNASequencePrediction(
|
| 590 |
+
organism=_res.organism,
|
| 591 |
+
protein=_res.protein,
|
| 592 |
+
processed_input=_res.processed_input,
|
| 593 |
+
predicted_dna=optimized_dna
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
st.session_state.results = result
|
| 597 |
+
|
| 598 |
+
# Post-processing if enabled
|
| 599 |
+
if use_post_processing and POST_PROCESSING_AVAILABLE and result:
|
| 600 |
+
try:
|
| 601 |
+
_res = result[0] if isinstance(result, list) else result
|
| 602 |
+
polished_sequence = polish_sequence_with_dnachisel(
|
| 603 |
+
dna_sequence=_res.predicted_dna,
|
| 604 |
+
protein_sequence=protein,
|
| 605 |
+
gc_bounds=(45.0, 55.0),
|
| 606 |
+
cai_species=organism.lower().replace(' ', '_'),
|
| 607 |
+
avoid_homopolymers_length=6
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
# Create enhanced result object
|
| 611 |
+
from CodonTransformer.CodonUtils import DNASequencePrediction
|
| 612 |
+
st.session_state.post_processed_results = DNASequencePrediction(
|
| 613 |
+
organism=result.organism,
|
| 614 |
+
protein=result.protein,
|
| 615 |
+
processed_input=result.processed_input,
|
| 616 |
+
predicted_dna=polished_sequence
|
| 617 |
+
)
|
| 618 |
+
except Exception as e:
|
| 619 |
+
st.session_state.post_processed_results = f"Post-processing error: {str(e)}"
|
| 620 |
+
|
| 621 |
+
except Exception as e:
|
| 622 |
+
st.session_state.results = f"Error: {str(e)}"
|
| 623 |
+
|
| 624 |
+
finally:
|
| 625 |
+
st.session_state.optimization_running = False
|
| 626 |
+
|
| 627 |
+
def main():
|
| 628 |
+
st.title("ENCOT")
|
| 629 |
+
st.markdown("E. coli codon optimization with constraint-aware decoding and in silico evaluation metrics.")
|
| 630 |
+
|
| 631 |
+
# Remove the performance highlights expander (details/summary block)
|
| 632 |
+
# (No expander here anymore)
|
| 633 |
+
|
| 634 |
+
# Load model
|
| 635 |
+
load_model_and_tokenizer()
|
| 636 |
+
|
| 637 |
+
# Create the main tabbed interface
|
| 638 |
+
tab1, tab2, tab3, tab4 = st.tabs(["Single Optimize", "Batch Process", "Comparative Analysis", "Advanced Settings"])
|
| 639 |
+
|
| 640 |
+
with tab1:
|
| 641 |
+
single_sequence_optimization()
|
| 642 |
+
|
| 643 |
+
with tab2:
|
| 644 |
+
batch_processing_interface()
|
| 645 |
+
|
| 646 |
+
with tab3:
|
| 647 |
+
comparative_analysis_interface()
|
| 648 |
+
|
| 649 |
+
with tab4:
|
| 650 |
+
advanced_settings_interface()
|
| 651 |
+
|
| 652 |
+
def single_sequence_optimization():
|
| 653 |
+
"""Single sequence optimization interface - enhanced from original functionality"""
|
| 654 |
+
# Sidebar configuration
|
| 655 |
+
st.sidebar.header("Configuration")
|
| 656 |
+
organism_options = [
|
| 657 |
+
"Escherichia coli general",
|
| 658 |
+
"Saccharomyces cerevisiae",
|
| 659 |
+
"Homo sapiens",
|
| 660 |
+
"Bacillus subtilis",
|
| 661 |
+
"Pichia pastoris"
|
| 662 |
+
]
|
| 663 |
+
organism = st.sidebar.selectbox("Select Target Organism", organism_options)
|
| 664 |
+
load_reference_data(organism)
|
| 665 |
+
with st.sidebar.expander("Advanced Optimization Settings"):
|
| 666 |
+
st.markdown("**Model Parameters**")
|
| 667 |
+
use_deterministic = st.checkbox("Deterministic Mode", value=True, help="Use deterministic decoding for reproducible results")
|
| 668 |
+
match_protein = st.checkbox("Match Protein Validation", value=True, help="Ensure DNA translates back to exact protein")
|
| 669 |
+
st.markdown("**GC Content Control**")
|
| 670 |
+
gc_target_min = st.slider("GC Target Min (%)", 30, 70, 45, help="Minimum GC content target")
|
| 671 |
+
gc_target_max = st.slider("GC Target Max (%)", 30, 70, 55, help="Maximum GC content target")
|
| 672 |
+
st.markdown("**Quality Constraints**")
|
| 673 |
+
avoid_restriction_sites = st.multiselect(
|
| 674 |
+
"Avoid Restriction Sites",
|
| 675 |
+
["EcoRI", "BamHI", "HindIII", "XhoI", "NotI"],
|
| 676 |
+
default=["EcoRI", "BamHI"]
|
| 677 |
+
)
|
| 678 |
+
st.sidebar.subheader("Post-Processing")
|
| 679 |
+
use_post_processing = st.sidebar.checkbox(
|
| 680 |
+
"Enable DNAChisel Post-Processing",
|
| 681 |
+
value=False,
|
| 682 |
+
disabled=not POST_PROCESSING_AVAILABLE,
|
| 683 |
+
help="Polish sequences to remove restriction sites, homopolymers, and synthesis issues"
|
| 684 |
+
)
|
| 685 |
+
if not POST_PROCESSING_AVAILABLE:
|
| 686 |
+
st.sidebar.warning("DNAChisel not available. Install with: pip install dnachisel")
|
| 687 |
+
|
| 688 |
+
# Dataset Information
|
| 689 |
+
st.sidebar.markdown("---")
|
| 690 |
+
st.sidebar.markdown("### Dataset Information")
|
| 691 |
+
st.sidebar.markdown("""
|
| 692 |
+
- **Dataset**: [ColiFormer-Data](https://huggingface.co/datasets/saketh11/ColiFormer-Data)
|
| 693 |
+
- **Training**: 3,676 high-expression E. coli genes (NCBI-curated)
|
| 694 |
+
- **Evaluation**: 37,053 native E. coli genes + 80 recombinant protein targets
|
| 695 |
+
- **Auto-download**: CAI weights & tAI coefficients
|
| 696 |
+
""")
|
| 697 |
+
|
| 698 |
+
# Model Information
|
| 699 |
+
st.sidebar.markdown("### Model Information")
|
| 700 |
+
st.sidebar.markdown("""
|
| 701 |
+
- **Model**: [ColiFormer](https://huggingface.co/saketh11/ColiFormer)
|
| 702 |
+
- **Base**: CodonTransformer BigBird architecture
|
| 703 |
+
- **Architecture**: BigBird Transformer + ALM
|
| 704 |
+
- **Auto-download**: From Hugging Face Hub
|
| 705 |
+
""")
|
| 706 |
+
col1, col2 = st.columns([1, 1])
|
| 707 |
+
with col1:
|
| 708 |
+
st.header("Input Sequence")
|
| 709 |
+
sequence_input = st.text_area(
|
| 710 |
+
"Enter Protein or DNA Sequence",
|
| 711 |
+
height=150,
|
| 712 |
+
placeholder="Enter protein sequence (MKWVT...) or DNA sequence (ATGGCG...)\n\nExample protein: MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTE"
|
| 713 |
+
)
|
| 714 |
+
analyze_btn = st.button("Analyze Sequence", type="primary")
|
| 715 |
+
if sequence_input and analyze_btn:
|
| 716 |
+
is_valid, message, sequence_type, fixed_sequence = validate_sequence(sequence_input)
|
| 717 |
+
if is_valid:
|
| 718 |
+
st.success(message)
|
| 719 |
+
# Store in session state for use by Optimize Sequence
|
| 720 |
+
st.session_state.sequence_clean = fixed_sequence
|
| 721 |
+
st.session_state.sequence_type = sequence_type
|
| 722 |
+
st.session_state.input_metrics = calculate_input_metrics(fixed_sequence, organism, sequence_type)
|
| 723 |
+
st.session_state.organism = organism
|
| 724 |
+
else:
|
| 725 |
+
st.error(message)
|
| 726 |
+
if "Invalid characters" in message:
|
| 727 |
+
st.info("Suggestion: Remove spaces, numbers, and special characters. Use only standard amino acid letters (A–Z) for proteins or nucleotides (A/T/G/C) for DNA.")
|
| 728 |
+
elif "too long" in message:
|
| 729 |
+
st.info("Suggestion: Consider breaking long sequences into smaller segments for optimization.")
|
| 730 |
+
elif "too short" in message:
|
| 731 |
+
st.info("Suggestion: Minimum length is 3 characters. Ensure your sequence is complete.")
|
| 732 |
+
# Clear session state if invalid
|
| 733 |
+
st.session_state.sequence_clean = None
|
| 734 |
+
st.session_state.sequence_type = None
|
| 735 |
+
st.session_state.input_metrics = None
|
| 736 |
+
st.session_state.organism = None
|
| 737 |
+
elif not sequence_input:
|
| 738 |
+
st.session_state.sequence_clean = None
|
| 739 |
+
st.session_state.sequence_type = None
|
| 740 |
+
st.session_state.input_metrics = None
|
| 741 |
+
st.session_state.organism = None
|
| 742 |
+
|
| 743 |
+
# Always display the last analysis if it exists in session state
|
| 744 |
+
if st.session_state.get('input_metrics') and st.session_state.get('sequence_type'):
|
| 745 |
+
input_metrics = st.session_state.input_metrics
|
| 746 |
+
sequence_type = st.session_state.sequence_type
|
| 747 |
+
st.subheader("Input Analysis")
|
| 748 |
+
metrics_col1, metrics_col2, metrics_col3 = st.columns(3)
|
| 749 |
+
with metrics_col1:
|
| 750 |
+
unit = "codons" if sequence_type == "dna" else "AA"
|
| 751 |
+
length = input_metrics.get('length', 0) if input_metrics else 0
|
| 752 |
+
gc_content = input_metrics.get('gc_content', 0) if input_metrics else 0
|
| 753 |
+
st.metric("Length", f"{length} {unit}")
|
| 754 |
+
st.metric("GC Content", f"{gc_content:.1f}%")
|
| 755 |
+
with metrics_col2:
|
| 756 |
+
cai_val = input_metrics.get('cai') if input_metrics else None
|
| 757 |
+
if cai_val:
|
| 758 |
+
label = "CAI" if sequence_type == "dna" else "CAI (baseline)"
|
| 759 |
+
st.metric(label, f"{cai_val:.3f}")
|
| 760 |
+
else:
|
| 761 |
+
st.metric("CAI", "N/A")
|
| 762 |
+
with metrics_col3:
|
| 763 |
+
tai_val = input_metrics.get('tai') if input_metrics else None
|
| 764 |
+
if tai_val:
|
| 765 |
+
label = "tAI" if sequence_type == "dna" else "tAI (baseline)"
|
| 766 |
+
st.metric(label, f"{tai_val:.3f}")
|
| 767 |
+
else:
|
| 768 |
+
st.metric("tAI", "N/A")
|
| 769 |
+
st.subheader("Sequence Quality Analysis")
|
| 770 |
+
analysis_col1, analysis_col2, analysis_col3 = st.columns(3)
|
| 771 |
+
with analysis_col1:
|
| 772 |
+
sites_count = input_metrics.get('restriction_sites', 0) if input_metrics else 0
|
| 773 |
+
color = "normal" if sites_count <= 2 else "inverse"
|
| 774 |
+
st.metric("Restriction Sites", sites_count)
|
| 775 |
+
with analysis_col2:
|
| 776 |
+
neg_elements = input_metrics.get('negative_cis_elements', 0) if input_metrics else 0
|
| 777 |
+
st.metric("Negative Elements", neg_elements)
|
| 778 |
+
with analysis_col3:
|
| 779 |
+
homo_runs = input_metrics.get('homopolymer_runs', 0) if input_metrics else 0
|
| 780 |
+
st.metric("Homopolymer Runs", homo_runs)
|
| 781 |
+
baseline_dna = input_metrics.get('baseline_dna', '') if input_metrics else ''
|
| 782 |
+
if baseline_dna and len(baseline_dna) > 150:
|
| 783 |
+
st.subheader("GC Content Distribution")
|
| 784 |
+
fig = create_gc_content_plot(baseline_dna)
|
| 785 |
+
fig.update_layout(
|
| 786 |
+
title="Input Sequence GC Content Analysis",
|
| 787 |
+
xaxis_title="Position (codons)",
|
| 788 |
+
yaxis_title="GC Content (%)",
|
| 789 |
+
hovermode='x unified'
|
| 790 |
+
)
|
| 791 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 792 |
+
|
| 793 |
+
with col2:
|
| 794 |
+
st.header("Optimization Results")
|
| 795 |
+
# Enhanced optimization button
|
| 796 |
+
if (
|
| 797 |
+
st.session_state.get('sequence_clean')
|
| 798 |
+
and st.session_state.get('sequence_type')
|
| 799 |
+
and not st.session_state.optimization_running
|
| 800 |
+
):
|
| 801 |
+
st.markdown("**Ready to optimize your sequence!**")
|
| 802 |
+
strategy_info = st.container()
|
| 803 |
+
with strategy_info:
|
| 804 |
+
st.info(f"""
|
| 805 |
+
**Optimization Strategy:**
|
| 806 |
+
• Target organism: {st.session_state.organism}
|
| 807 |
+
• Model: Fine-tuned CodonTransformer (89.6M parameters)
|
| 808 |
+
• GC target: {gc_target_min}-{gc_target_max}%
|
| 809 |
+
• Mode: {'Deterministic' if use_deterministic else 'Stochastic'}
|
| 810 |
+
""")
|
| 811 |
+
if st.button("Optimize Sequence", type="primary", use_container_width=True):
|
| 812 |
+
st.session_state.results = None
|
| 813 |
+
if st.session_state.sequence_type == "dna":
|
| 814 |
+
protein_sequence = translate_dna_to_protein(st.session_state.sequence_clean)
|
| 815 |
+
run_optimization(protein_sequence, st.session_state.organism, use_post_processing)
|
| 816 |
+
else:
|
| 817 |
+
run_optimization(st.session_state.sequence_clean, st.session_state.organism, use_post_processing)
|
| 818 |
+
|
| 819 |
+
# Enhanced progress display
|
| 820 |
+
if st.session_state.optimization_running:
|
| 821 |
+
st.info("Optimizing sequence...")
|
| 822 |
+
|
| 823 |
+
# Create progress container
|
| 824 |
+
progress_container = st.container()
|
| 825 |
+
with progress_container:
|
| 826 |
+
progress_bar = st.progress(0)
|
| 827 |
+
status_text = st.empty()
|
| 828 |
+
|
| 829 |
+
# Enhanced progress steps
|
| 830 |
+
steps = [
|
| 831 |
+
"Analyzing input sequence structure...",
|
| 832 |
+
"Loading model...",
|
| 833 |
+
"Running optimization algorithm...",
|
| 834 |
+
"Applying GC/content constraints...",
|
| 835 |
+
"Finalizing optimized sequence..."
|
| 836 |
+
]
|
| 837 |
+
|
| 838 |
+
for i, step in enumerate(steps):
|
| 839 |
+
progress_value = int((i + 1) / len(steps) * 100)
|
| 840 |
+
progress_bar.progress(progress_value)
|
| 841 |
+
status_text.text(step)
|
| 842 |
+
time.sleep(0.8) # Realistic timing
|
| 843 |
+
|
| 844 |
+
progress_bar.empty()
|
| 845 |
+
status_text.empty()
|
| 846 |
+
|
| 847 |
+
# Enhanced results display
|
| 848 |
+
if st.session_state.results and not st.session_state.optimization_running:
|
| 849 |
+
if isinstance(st.session_state.results, str):
|
| 850 |
+
st.error(f"Optimization failed: {st.session_state.results}")
|
| 851 |
+
else:
|
| 852 |
+
display_optimization_results(
|
| 853 |
+
st.session_state.results,
|
| 854 |
+
st.session_state.get('organism', organism),
|
| 855 |
+
st.session_state.get('sequence_clean', ''),
|
| 856 |
+
st.session_state.get('sequence_type', 'protein'),
|
| 857 |
+
st.session_state.get('input_metrics', {})
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
def display_optimization_results(result, organism, original_sequence, sequence_type, input_metrics):
|
| 861 |
+
"""Enhanced results display with publication-quality visualizations"""
|
| 862 |
+
|
| 863 |
+
# Calculate optimized metrics
|
| 864 |
+
optimized_metrics = {
|
| 865 |
+
'gc_content': get_GC_content(result.predicted_dna),
|
| 866 |
+
'length': len(result.predicted_dna)
|
| 867 |
+
}
|
| 868 |
+
|
| 869 |
+
# Calculate CAI and tAI
|
| 870 |
+
try:
|
| 871 |
+
if 'cai_weights' in st.session_state and st.session_state['cai_weights']:
|
| 872 |
+
optimized_metrics['cai'] = CAI(result.predicted_dna, weights=st.session_state['cai_weights'])
|
| 873 |
+
else:
|
| 874 |
+
optimized_metrics['cai'] = None
|
| 875 |
+
except:
|
| 876 |
+
optimized_metrics['cai'] = None
|
| 877 |
+
|
| 878 |
+
try:
|
| 879 |
+
if 'tai_weights' in st.session_state and st.session_state['tai_weights']:
|
| 880 |
+
optimized_metrics['tai'] = calculate_tAI(result.predicted_dna, st.session_state['tai_weights'])
|
| 881 |
+
else:
|
| 882 |
+
optimized_metrics['tai'] = None
|
| 883 |
+
except:
|
| 884 |
+
optimized_metrics['tai'] = None
|
| 885 |
+
|
| 886 |
+
# Success header
|
| 887 |
+
st.success("Optimization complete.")
|
| 888 |
+
|
| 889 |
+
# Key improvements summary
|
| 890 |
+
st.subheader("Optimization Improvements")
|
| 891 |
+
imp_col1, imp_col2, imp_col3 = st.columns(3)
|
| 892 |
+
|
| 893 |
+
if input_metrics is not None:
|
| 894 |
+
with imp_col1:
|
| 895 |
+
if input_metrics.get('gc_content') and optimized_metrics.get('gc_content'):
|
| 896 |
+
gc_change = optimized_metrics['gc_content'] - input_metrics['gc_content']
|
| 897 |
+
st.metric("GC Content", f"{optimized_metrics['gc_content']:.1f}%", delta=f"{gc_change:+.1f}%")
|
| 898 |
+
|
| 899 |
+
with imp_col2:
|
| 900 |
+
if input_metrics.get('cai') and optimized_metrics.get('cai'):
|
| 901 |
+
cai_change = optimized_metrics['cai'] - input_metrics['cai']
|
| 902 |
+
st.metric("CAI Score", f"{optimized_metrics['cai']:.3f}", delta=f"{cai_change:+.3f}")
|
| 903 |
+
|
| 904 |
+
with imp_col3:
|
| 905 |
+
if input_metrics.get('tai') and optimized_metrics.get('tai'):
|
| 906 |
+
tai_change = optimized_metrics['tai'] - input_metrics['tai']
|
| 907 |
+
st.metric("tAI Score", f"{optimized_metrics['tai']:.3f}", delta=f"{tai_change:+.3f}")
|
| 908 |
+
|
| 909 |
+
# Optimized DNA sequence display
|
| 910 |
+
st.subheader("Optimized DNA Sequence")
|
| 911 |
+
st.text_area("Optimized DNA Sequence", result.predicted_dna, height=100)
|
| 912 |
+
|
| 913 |
+
# Enhanced download and export options
|
| 914 |
+
col1, col2, col3 = st.columns(3)
|
| 915 |
+
with col1:
|
| 916 |
+
st.download_button(
|
| 917 |
+
label="Download DNA (FASTA)",
|
| 918 |
+
data=f">Optimized_{organism.replace(' ', '_')}\n{result.predicted_dna}",
|
| 919 |
+
file_name=f"optimized_sequence_{organism.replace(' ', '_')}.fasta",
|
| 920 |
+
mime="text/plain"
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
with col2:
|
| 924 |
+
# Create CSV report
|
| 925 |
+
csv_data = f"Metric,Original,Optimized,Improvement\n"
|
| 926 |
+
csv_data += f"GC Content (%),{input_metrics['gc_content']:.1f},{optimized_metrics['gc_content']:.1f},{optimized_metrics['gc_content'] - input_metrics['gc_content']:+.1f}\n"
|
| 927 |
+
if input_metrics['cai'] and optimized_metrics['cai']:
|
| 928 |
+
csv_data += f"CAI Score,{input_metrics['cai']:.3f},{optimized_metrics['cai']:.3f},{optimized_metrics['cai'] - input_metrics['cai']:+.3f}\n"
|
| 929 |
+
if input_metrics['tai'] and optimized_metrics['tai']:
|
| 930 |
+
csv_data += f"tAI Score,{input_metrics['tai']:.3f},{optimized_metrics['tai']:.3f},{optimized_metrics['tai'] - input_metrics['tai']:+.3f}\n"
|
| 931 |
+
|
| 932 |
+
st.download_button(
|
| 933 |
+
label="Download Metrics (CSV)",
|
| 934 |
+
data=csv_data,
|
| 935 |
+
file_name=f"optimization_metrics_{organism.replace(' ', '_')}.csv",
|
| 936 |
+
mime="text/csv"
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
with col3:
|
| 940 |
+
st.button("Generate PDF Report", help="Coming soon: PDF report")
|
| 941 |
+
|
| 942 |
+
# Enhanced comparison visualizations
|
| 943 |
+
st.subheader("Before vs After Analysis")
|
| 944 |
+
|
| 945 |
+
# Create enhanced comparison charts
|
| 946 |
+
create_enhanced_comparison_charts(input_metrics, optimized_metrics, original_sequence, result.predicted_dna, sequence_type)
|
| 947 |
+
|
| 948 |
+
def create_enhanced_comparison_charts(input_metrics, optimized_metrics, original_dna, optimized_dna, sequence_type):
|
| 949 |
+
"""Create publication-quality comparison visualizations"""
|
| 950 |
+
if input_metrics is None or optimized_metrics is None:
|
| 951 |
+
st.info("No comparison data available.")
|
| 952 |
+
return
|
| 953 |
+
|
| 954 |
+
# GC Content comparison
|
| 955 |
+
gc_comp_fig = create_gc_comparison_chart(input_metrics, optimized_metrics)
|
| 956 |
+
gc_comp_fig.update_layout(
|
| 957 |
+
title="GC Content Optimization Results",
|
| 958 |
+
font=dict(size=12),
|
| 959 |
+
height=350
|
| 960 |
+
)
|
| 961 |
+
st.plotly_chart(gc_comp_fig, use_container_width=True)
|
| 962 |
+
|
| 963 |
+
# Expression metrics comparison
|
| 964 |
+
if input_metrics.get('cai') and optimized_metrics.get('cai'):
|
| 965 |
+
expr_comp_fig = create_expression_comparison_chart(input_metrics, optimized_metrics)
|
| 966 |
+
expr_comp_fig.update_layout(
|
| 967 |
+
title="Expression Potential Improvement",
|
| 968 |
+
font=dict(size=12),
|
| 969 |
+
height=350
|
| 970 |
+
)
|
| 971 |
+
st.plotly_chart(expr_comp_fig, use_container_width=True)
|
| 972 |
+
|
| 973 |
+
# Side-by-side GC distribution analysis
|
| 974 |
+
st.subheader("GC Content Distribution Analysis")
|
| 975 |
+
col1, col2 = st.columns(2)
|
| 976 |
+
|
| 977 |
+
with col1:
|
| 978 |
+
st.write(f"**{'Original DNA' if sequence_type == 'dna' else 'Baseline (Most Frequent Codons)'}**")
|
| 979 |
+
baseline_dna = input_metrics.get('baseline_dna') if input_metrics else None
|
| 980 |
+
plot_dna = baseline_dna if baseline_dna is not None else original_dna
|
| 981 |
+
if plot_dna is not None and isinstance(plot_dna, str) and len(plot_dna) > 150:
|
| 982 |
+
fig_before = create_gc_content_plot(plot_dna)
|
| 983 |
+
fig_before.update_layout(title="Before Optimization", height=300)
|
| 984 |
+
st.plotly_chart(fig_before, use_container_width=True)
|
| 985 |
+
else:
|
| 986 |
+
st.info("Sequence too short for sliding window analysis")
|
| 987 |
+
|
| 988 |
+
with col2:
|
| 989 |
+
st.write("** Model Optimized**")
|
| 990 |
+
if optimized_dna is not None and isinstance(optimized_dna, str) and len(optimized_dna) > 150:
|
| 991 |
+
fig_after = create_gc_content_plot(optimized_dna)
|
| 992 |
+
fig_after.update_layout(title="After Optimization", height=300)
|
| 993 |
+
st.plotly_chart(fig_after, use_container_width=True)
|
| 994 |
+
else:
|
| 995 |
+
st.info("Sequence too short for sliding window analysis")
|
| 996 |
+
|
| 997 |
+
def batch_processing_interface():
|
| 998 |
+
"""Batch processing interface for multiple sequences"""
|
| 999 |
+
st.header("Batch Processing")
|
| 1000 |
+
st.markdown("**Process multiple protein sequences simultaneously with optimization**")
|
| 1001 |
+
|
| 1002 |
+
# File upload section
|
| 1003 |
+
st.subheader("Upload Sequences")
|
| 1004 |
+
uploaded_file = st.file_uploader(
|
| 1005 |
+
"Choose a file with multiple sequences",
|
| 1006 |
+
type=['csv', 'xlsx', 'fasta', 'txt', 'fa'],
|
| 1007 |
+
help="Upload CSV, Excel (XLSX, with 'sequence' column) or FASTA format files"
|
| 1008 |
+
)
|
| 1009 |
+
|
| 1010 |
+
if uploaded_file:
|
| 1011 |
+
st.success(f"File uploaded: {uploaded_file.name}")
|
| 1012 |
+
|
| 1013 |
+
# Process uploaded file
|
| 1014 |
+
try:
|
| 1015 |
+
def find_column(df, target):
|
| 1016 |
+
# Find column name case-insensitively and ignoring spaces
|
| 1017 |
+
for col in df.columns:
|
| 1018 |
+
if col.strip().lower() == target:
|
| 1019 |
+
return col
|
| 1020 |
+
return None
|
| 1021 |
+
|
| 1022 |
+
if uploaded_file.name.endswith('.csv'):
|
| 1023 |
+
df = pd.read_csv(uploaded_file)
|
| 1024 |
+
seq_col = find_column(df, 'sequence')
|
| 1025 |
+
name_col = find_column(df, 'name')
|
| 1026 |
+
if seq_col:
|
| 1027 |
+
sequences = df[seq_col].tolist()
|
| 1028 |
+
if name_col:
|
| 1029 |
+
names = df[name_col].tolist()
|
| 1030 |
+
else:
|
| 1031 |
+
names = [f"Sequence_{i+1}" for i in range(len(sequences))]
|
| 1032 |
+
else:
|
| 1033 |
+
st.error("CSV file must contain a column named 'sequence' (case-insensitive, spaces ignored)")
|
| 1034 |
+
return
|
| 1035 |
+
elif uploaded_file.name.endswith('.xlsx'):
|
| 1036 |
+
df = pd.read_excel(uploaded_file)
|
| 1037 |
+
seq_col = find_column(df, 'sequence')
|
| 1038 |
+
name_col = find_column(df, 'name')
|
| 1039 |
+
if seq_col:
|
| 1040 |
+
sequences = df[seq_col].tolist()
|
| 1041 |
+
if name_col:
|
| 1042 |
+
names = df[name_col].tolist()
|
| 1043 |
+
else:
|
| 1044 |
+
names = [f"Sequence_{i+1}" for i in range(len(sequences))]
|
| 1045 |
+
else:
|
| 1046 |
+
st.error("Excel file must contain a column named 'sequence' (case-insensitive, spaces ignored)")
|
| 1047 |
+
return
|
| 1048 |
+
else:
|
| 1049 |
+
# Handle FASTA format
|
| 1050 |
+
content = uploaded_file.read().decode('utf-8')
|
| 1051 |
+
sequences, names = parse_fasta_content(content)
|
| 1052 |
+
|
| 1053 |
+
st.info(f"Found {len(sequences)} sequences ready for optimization")
|
| 1054 |
+
|
| 1055 |
+
# Batch configuration
|
| 1056 |
+
col1, col2 = st.columns(2)
|
| 1057 |
+
with col1:
|
| 1058 |
+
batch_organism = st.selectbox("Target Organism", [
|
| 1059 |
+
"Escherichia coli general", "Saccharomyces cerevisiae", "Homo sapiens"
|
| 1060 |
+
])
|
| 1061 |
+
with col2:
|
| 1062 |
+
max_sequences = st.number_input("Max sequences to process", 1, len(sequences), min(10, len(sequences)))
|
| 1063 |
+
|
| 1064 |
+
# Start batch processing
|
| 1065 |
+
if st.button("Start Batch Optimization", type="primary"):
|
| 1066 |
+
run_batch_optimization(sequences[:max_sequences], names[:max_sequences], batch_organism)
|
| 1067 |
+
|
| 1068 |
+
except Exception as e:
|
| 1069 |
+
st.error(f"Error processing file: {str(e)}")
|
| 1070 |
+
|
| 1071 |
+
# Batch results display
|
| 1072 |
+
if 'batch_results' in st.session_state and st.session_state.batch_results:
|
| 1073 |
+
display_batch_results()
|
| 1074 |
+
|
| 1075 |
+
def parse_fasta_content(content):
|
| 1076 |
+
"""Parse FASTA format content"""
|
| 1077 |
+
sequences = []
|
| 1078 |
+
names = []
|
| 1079 |
+
current_seq = ""
|
| 1080 |
+
current_name = ""
|
| 1081 |
+
|
| 1082 |
+
for line in content.split('\n'):
|
| 1083 |
+
line = line.strip()
|
| 1084 |
+
if line.startswith('>'):
|
| 1085 |
+
if current_seq:
|
| 1086 |
+
sequences.append(current_seq)
|
| 1087 |
+
names.append(current_name)
|
| 1088 |
+
current_name = line[1:] if len(line) > 1 else f"Sequence_{len(sequences)+1}"
|
| 1089 |
+
current_seq = ""
|
| 1090 |
+
else:
|
| 1091 |
+
current_seq += line
|
| 1092 |
+
|
| 1093 |
+
if current_seq:
|
| 1094 |
+
sequences.append(current_seq)
|
| 1095 |
+
names.append(current_name)
|
| 1096 |
+
|
| 1097 |
+
return sequences, names
|
| 1098 |
+
|
| 1099 |
+
def run_batch_optimization(sequences, names, organism):
|
| 1100 |
+
"""Run batch optimization with progress tracking"""
|
| 1101 |
+
st.session_state.batch_results = []
|
| 1102 |
+
st.session_state.batch_logs = [] # Collect info logs for auto-fixes
|
| 1103 |
+
|
| 1104 |
+
# Load reference data for CAI/tAI
|
| 1105 |
+
load_reference_data(organism)
|
| 1106 |
+
cai_weights = st.session_state.get('cai_weights', None)
|
| 1107 |
+
tai_weights = st.session_state.get('tai_weights', None)
|
| 1108 |
+
|
| 1109 |
+
# Create progress tracking
|
| 1110 |
+
progress_bar = st.progress(0)
|
| 1111 |
+
status_text = st.empty()
|
| 1112 |
+
|
| 1113 |
+
for i, (seq, name) in enumerate(zip(sequences, names)):
|
| 1114 |
+
progress = (i + 1) / len(sequences)
|
| 1115 |
+
progress_bar.progress(progress)
|
| 1116 |
+
status_text.text(f"Processing {name} ({i+1}/{len(sequences)})")
|
| 1117 |
+
|
| 1118 |
+
try:
|
| 1119 |
+
# Validate sequence and get possibly fixed sequence
|
| 1120 |
+
is_valid, message, sequence_type, fixed_seq = validate_sequence(seq)
|
| 1121 |
+
if is_valid:
|
| 1122 |
+
# Log if auto-fixed
|
| 1123 |
+
if 'auto-fixed' in message:
|
| 1124 |
+
st.session_state.batch_logs.append(f"{name}: {message}")
|
| 1125 |
+
# Calculate original metrics (use fixed_seq for DNA)
|
| 1126 |
+
if sequence_type == "dna":
|
| 1127 |
+
orig_gc = get_GC_content(fixed_seq)
|
| 1128 |
+
orig_cai = CAI(fixed_seq, weights=cai_weights) if cai_weights else None
|
| 1129 |
+
orig_tai = calculate_tAI(fixed_seq, tai_weights) if tai_weights else None
|
| 1130 |
+
else:
|
| 1131 |
+
# For protein, create baseline DNA
|
| 1132 |
+
most_frequent_codons = {
|
| 1133 |
+
'A': 'GCG', 'C': 'TGC', 'D': 'GAT', 'E': 'GAA', 'F': 'TTT',
|
| 1134 |
+
'G': 'GGC', 'H': 'CAT', 'I': 'ATT', 'K': 'AAA', 'L': 'CTG',
|
| 1135 |
+
'M': 'ATG', 'N': 'AAC', 'P': 'CCG', 'Q': 'CAG', 'R': 'CGC',
|
| 1136 |
+
'S': 'TCG', 'T': 'ACG', 'V': 'GTG', 'W': 'TGG', 'Y': 'TAT',
|
| 1137 |
+
'*': 'TAA', '_': 'TAA'
|
| 1138 |
+
}
|
| 1139 |
+
baseline_dna = ''.join([most_frequent_codons.get(aa, 'NNN') for aa in fixed_seq])
|
| 1140 |
+
orig_gc = get_GC_content(baseline_dna)
|
| 1141 |
+
orig_cai = CAI(baseline_dna, weights=cai_weights) if cai_weights else None
|
| 1142 |
+
orig_tai = calculate_tAI(baseline_dna, tai_weights) if tai_weights else None
|
| 1143 |
+
|
| 1144 |
+
# Run optimization using the fixed sequence
|
| 1145 |
+
result = predict_dna_sequence(
|
| 1146 |
+
protein=fixed_seq if sequence_type == "protein" else translate_dna_to_protein(fixed_seq),
|
| 1147 |
+
organism=organism,
|
| 1148 |
+
device=st.session_state.device,
|
| 1149 |
+
model=st.session_state.model,
|
| 1150 |
+
deterministic=True,
|
| 1151 |
+
match_protein=True,
|
| 1152 |
+
)
|
| 1153 |
+
|
| 1154 |
+
# If result is a list, use the first element
|
| 1155 |
+
if isinstance(result, list):
|
| 1156 |
+
result_obj = result[0]
|
| 1157 |
+
else:
|
| 1158 |
+
result_obj = result
|
| 1159 |
+
|
| 1160 |
+
# Calculate optimized metrics
|
| 1161 |
+
opt_gc = get_GC_content(result_obj.predicted_dna)
|
| 1162 |
+
opt_cai = CAI(result_obj.predicted_dna, weights=cai_weights) if cai_weights else None
|
| 1163 |
+
opt_tai = calculate_tAI(result_obj.predicted_dna, tai_weights) if tai_weights else None
|
| 1164 |
+
|
| 1165 |
+
metrics = {
|
| 1166 |
+
'name': name,
|
| 1167 |
+
'original_sequence': fixed_seq,
|
| 1168 |
+
'optimized_dna': result_obj.predicted_dna,
|
| 1169 |
+
'gc_content_before': orig_gc,
|
| 1170 |
+
'gc_content_after': opt_gc,
|
| 1171 |
+
'cai_before': orig_cai,
|
| 1172 |
+
'cai_after': opt_cai,
|
| 1173 |
+
'tai_before': orig_tai,
|
| 1174 |
+
'tai_after': opt_tai,
|
| 1175 |
+
'length_before': len(fixed_seq),
|
| 1176 |
+
'length_after': len(result_obj.predicted_dna),
|
| 1177 |
+
'validation_message': message
|
| 1178 |
+
}
|
| 1179 |
+
|
| 1180 |
+
st.session_state.batch_results.append(metrics)
|
| 1181 |
+
else:
|
| 1182 |
+
# Only skip if truly invalid (not auto-fixable)
|
| 1183 |
+
st.session_state.batch_logs.append(f"{name}: {message}")
|
| 1184 |
+
|
| 1185 |
+
except Exception as e:
|
| 1186 |
+
st.session_state.batch_logs.append(f"{name}: Error processing: {str(e)}")
|
| 1187 |
+
|
| 1188 |
+
progress_bar.empty()
|
| 1189 |
+
status_text.empty()
|
| 1190 |
+
st.success(f"Batch optimization complete. Processed {len(st.session_state.batch_results)} sequences.")
|
| 1191 |
+
|
| 1192 |
+
def display_batch_results():
|
| 1193 |
+
"""Display batch processing results"""
|
| 1194 |
+
st.subheader("Batch Results")
|
| 1195 |
+
|
| 1196 |
+
# Show all logs (auto-fixes and errors)
|
| 1197 |
+
if hasattr(st.session_state, 'batch_logs') and st.session_state.batch_logs:
|
| 1198 |
+
for log in st.session_state.batch_logs:
|
| 1199 |
+
st.info(log)
|
| 1200 |
+
|
| 1201 |
+
results_df = pd.DataFrame(st.session_state.batch_results)
|
| 1202 |
+
|
| 1203 |
+
# Summary statistics
|
| 1204 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 1205 |
+
with col1:
|
| 1206 |
+
st.metric("Sequences Processed", len(results_df))
|
| 1207 |
+
with col2:
|
| 1208 |
+
st.metric("Avg GC Before", f"{results_df['gc_content_before'].mean():.1f}%")
|
| 1209 |
+
st.metric("Avg GC After", f"{results_df['gc_content_after'].mean():.1f}%")
|
| 1210 |
+
with col3:
|
| 1211 |
+
st.metric("Avg CAI Before", f"{results_df['cai_before'].mean():.3f}")
|
| 1212 |
+
st.metric("Avg CAI After", f"{results_df['cai_after'].mean():.3f}")
|
| 1213 |
+
with col4:
|
| 1214 |
+
st.metric("Avg tAI Before", f"{results_df['tai_before'].mean():.3f}")
|
| 1215 |
+
st.metric("Avg tAI After", f"{results_df['tai_after'].mean():.3f}")
|
| 1216 |
+
|
| 1217 |
+
# CAI Extremes Analysis
|
| 1218 |
+
st.subheader("CAI Performance Analysis")
|
| 1219 |
+
|
| 1220 |
+
# Filter out rows with NaN CAI values for analysis
|
| 1221 |
+
valid_cai_df = results_df.dropna(subset=['cai_after'])
|
| 1222 |
+
|
| 1223 |
+
if len(valid_cai_df) > 0:
|
| 1224 |
+
# Find lowest and highest CAI sequences
|
| 1225 |
+
lowest_cai_idx = valid_cai_df['cai_after'].idxmin()
|
| 1226 |
+
highest_cai_idx = valid_cai_df['cai_after'].idxmax()
|
| 1227 |
+
|
| 1228 |
+
lowest_cai_row = results_df.loc[lowest_cai_idx]
|
| 1229 |
+
highest_cai_row = results_df.loc[highest_cai_idx]
|
| 1230 |
+
|
| 1231 |
+
col1, col2 = st.columns(2)
|
| 1232 |
+
|
| 1233 |
+
with col1:
|
| 1234 |
+
st.markdown("**Lowest CAI Sequence**")
|
| 1235 |
+
st.write(f"**Name:** {lowest_cai_row['name']}")
|
| 1236 |
+
st.metric("CAI Score", f"{lowest_cai_row['cai_after']:.3f}")
|
| 1237 |
+
st.metric("GC Content", f"{lowest_cai_row['gc_content_after']:.1f}%")
|
| 1238 |
+
st.metric("tAI Score", f"{lowest_cai_row['tai_after']:.3f}")
|
| 1239 |
+
st.metric("Length", f"{lowest_cai_row['length_after']} bp")
|
| 1240 |
+
|
| 1241 |
+
# Show improvement
|
| 1242 |
+
if pd.notna(lowest_cai_row['cai_before']):
|
| 1243 |
+
cai_improvement = lowest_cai_row['cai_after'] - lowest_cai_row['cai_before']
|
| 1244 |
+
st.metric("CAI Improvement", f"{cai_improvement:+.3f}")
|
| 1245 |
+
|
| 1246 |
+
with col2:
|
| 1247 |
+
st.markdown("**Highest CAI Sequence**")
|
| 1248 |
+
st.write(f"**Name:** {highest_cai_row['name']}")
|
| 1249 |
+
st.metric("CAI Score", f"{highest_cai_row['cai_after']:.3f}")
|
| 1250 |
+
st.metric("GC Content", f"{highest_cai_row['gc_content_after']:.1f}%")
|
| 1251 |
+
st.metric("tAI Score", f"{highest_cai_row['tai_after']:.3f}")
|
| 1252 |
+
st.metric("Length", f"{highest_cai_row['length_after']} bp")
|
| 1253 |
+
|
| 1254 |
+
# Show improvement
|
| 1255 |
+
if pd.notna(highest_cai_row['cai_before']):
|
| 1256 |
+
cai_improvement = highest_cai_row['cai_after'] - highest_cai_row['cai_before']
|
| 1257 |
+
st.metric("CAI Improvement", f"{cai_improvement:+.3f}")
|
| 1258 |
+
|
| 1259 |
+
# CAI Distribution Chart
|
| 1260 |
+
st.subheader("CAI Distribution")
|
| 1261 |
+
fig = go.Figure()
|
| 1262 |
+
fig.add_trace(go.Histogram(
|
| 1263 |
+
x=valid_cai_df['cai_after'],
|
| 1264 |
+
nbinsx=20,
|
| 1265 |
+
name='Optimized CAI Scores',
|
| 1266 |
+
marker_color='darkblue',
|
| 1267 |
+
opacity=0.7
|
| 1268 |
+
))
|
| 1269 |
+
|
| 1270 |
+
# Add vertical lines for lowest and highest
|
| 1271 |
+
fig.add_vline(
|
| 1272 |
+
x=lowest_cai_row['cai_after'],
|
| 1273 |
+
line_dash="dash",
|
| 1274 |
+
line_color="red",
|
| 1275 |
+
annotation_text=f"Lowest: {lowest_cai_row['cai_after']:.3f}"
|
| 1276 |
+
)
|
| 1277 |
+
fig.add_vline(
|
| 1278 |
+
x=highest_cai_row['cai_after'],
|
| 1279 |
+
line_dash="dash",
|
| 1280 |
+
line_color="green",
|
| 1281 |
+
annotation_text=f"Highest: {highest_cai_row['cai_after']:.3f}"
|
| 1282 |
+
)
|
| 1283 |
+
|
| 1284 |
+
fig.update_layout(
|
| 1285 |
+
title="Distribution of Optimized CAI Scores",
|
| 1286 |
+
xaxis_title="CAI Score",
|
| 1287 |
+
yaxis_title="Number of Sequences",
|
| 1288 |
+
height=400,
|
| 1289 |
+
showlegend=False
|
| 1290 |
+
)
|
| 1291 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 1292 |
+
|
| 1293 |
+
# GC Content Distribution Chart
|
| 1294 |
+
st.subheader("GC Content Distribution")
|
| 1295 |
+
valid_gc_df = results_df.dropna(subset=['gc_content_after'])
|
| 1296 |
+
if len(valid_gc_df) > 0:
|
| 1297 |
+
lowest_gc_idx = valid_gc_df['gc_content_after'].idxmin()
|
| 1298 |
+
highest_gc_idx = valid_gc_df['gc_content_after'].idxmax()
|
| 1299 |
+
lowest_gc_row = results_df.loc[lowest_gc_idx]
|
| 1300 |
+
highest_gc_row = results_df.loc[highest_gc_idx]
|
| 1301 |
+
|
| 1302 |
+
fig_gc = go.Figure()
|
| 1303 |
+
fig_gc.add_trace(go.Histogram(
|
| 1304 |
+
x=valid_gc_df['gc_content_after'],
|
| 1305 |
+
nbinsx=20,
|
| 1306 |
+
name='Optimized GC Content',
|
| 1307 |
+
marker_color='teal',
|
| 1308 |
+
opacity=0.7
|
| 1309 |
+
))
|
| 1310 |
+
fig_gc.add_vline(
|
| 1311 |
+
x=lowest_gc_row['gc_content_after'],
|
| 1312 |
+
line_dash="dash",
|
| 1313 |
+
line_color="red",
|
| 1314 |
+
annotation_text=f"Lowest: {lowest_gc_row['gc_content_after']:.1f}%"
|
| 1315 |
+
)
|
| 1316 |
+
fig_gc.add_vline(
|
| 1317 |
+
x=highest_gc_row['gc_content_after'],
|
| 1318 |
+
line_dash="dash",
|
| 1319 |
+
line_color="green",
|
| 1320 |
+
annotation_text=f"Highest: {highest_gc_row['gc_content_after']:.1f}%"
|
| 1321 |
+
)
|
| 1322 |
+
fig_gc.update_layout(
|
| 1323 |
+
title="Distribution of Optimized GC Content",
|
| 1324 |
+
xaxis_title="GC Content (%)",
|
| 1325 |
+
yaxis_title="Number of Sequences",
|
| 1326 |
+
height=400,
|
| 1327 |
+
showlegend=False
|
| 1328 |
+
)
|
| 1329 |
+
st.plotly_chart(fig_gc, use_container_width=True)
|
| 1330 |
+
else:
|
| 1331 |
+
st.warning("No valid GC content values found in the batch results.")
|
| 1332 |
+
|
| 1333 |
+
else:
|
| 1334 |
+
st.warning("No valid CAI scores found in the batch results. Check if CAI weights are properly loaded.")
|
| 1335 |
+
|
| 1336 |
+
# Sequence selector
|
| 1337 |
+
seq_names = results_df['name'].tolist()
|
| 1338 |
+
selected_seq = st.selectbox("Select a sequence to view details", seq_names)
|
| 1339 |
+
seq_row = results_df[results_df['name'] == selected_seq].iloc[0]
|
| 1340 |
+
|
| 1341 |
+
st.markdown(f"### Details for: {selected_seq}")
|
| 1342 |
+
if 'validation_message' in seq_row and 'auto-fixed' in seq_row['validation_message']:
|
| 1343 |
+
st.info(seq_row['validation_message'])
|
| 1344 |
+
col1, col2 = st.columns(2)
|
| 1345 |
+
with col1:
|
| 1346 |
+
st.markdown("**Original Sequence**")
|
| 1347 |
+
st.text_area("Original Sequence", seq_row['original_sequence'], height=100)
|
| 1348 |
+
st.metric("GC Content (Before)", f"{seq_row['gc_content_before']:.1f}%")
|
| 1349 |
+
st.metric("CAI (Before)", f"{seq_row['cai_before']:.3f}")
|
| 1350 |
+
st.metric("tAI (Before)", f"{seq_row['tai_before']:.3f}")
|
| 1351 |
+
st.metric("Length (Before)", f"{seq_row['length_before']}")
|
| 1352 |
+
with col2:
|
| 1353 |
+
st.markdown("**Optimized Sequence**")
|
| 1354 |
+
st.text_area("Optimized Sequence", seq_row['optimized_dna'], height=100)
|
| 1355 |
+
st.metric("GC Content (After)", f"{seq_row['gc_content_after']:.1f}%")
|
| 1356 |
+
st.metric("CAI (After)", f"{seq_row['cai_after']:.3f}")
|
| 1357 |
+
st.metric("tAI (After)", f"{seq_row['tai_after']:.3f}")
|
| 1358 |
+
st.metric("Length (After)", f"{seq_row['length_after']}")
|
| 1359 |
+
|
| 1360 |
+
# Plots for before/after GC content
|
| 1361 |
+
st.subheader("GC Content Distribution (Before vs After)")
|
| 1362 |
+
if len(seq_row['original_sequence']) > 150 and len(seq_row['optimized_dna']) > 150:
|
| 1363 |
+
fig_before = create_gc_content_plot(seq_row['original_sequence'])
|
| 1364 |
+
fig_before.update_layout(title="Before Optimization", height=300)
|
| 1365 |
+
fig_after = create_gc_content_plot(seq_row['optimized_dna'])
|
| 1366 |
+
fig_after.update_layout(title="After Optimization", height=300)
|
| 1367 |
+
st.plotly_chart(fig_before, use_container_width=True)
|
| 1368 |
+
st.plotly_chart(fig_after, use_container_width=True)
|
| 1369 |
+
else:
|
| 1370 |
+
st.info("Sequence(s) too short for sliding window analysis")
|
| 1371 |
+
|
| 1372 |
+
# Download batch results
|
| 1373 |
+
if st.button("Download Batch Results"):
|
| 1374 |
+
csv_data = results_df.to_csv(index=False)
|
| 1375 |
+
st.download_button(
|
| 1376 |
+
label="Download CSV",
|
| 1377 |
+
data=csv_data,
|
| 1378 |
+
file_name="batch_optimization_results.csv",
|
| 1379 |
+
mime="text/csv"
|
| 1380 |
+
)
|
| 1381 |
+
|
| 1382 |
+
def comparative_analysis_interface():
|
| 1383 |
+
"""Comparative analysis interface"""
|
| 1384 |
+
st.header("Comparative Analysis")
|
| 1385 |
+
st.markdown("For quantitative comparisons and plots, use the benchmark script:")
|
| 1386 |
+
st.code("python scripts/run_benchmarks.py --config configs/benchmark.yaml")
|
| 1387 |
+
|
| 1388 |
+
def advanced_settings_interface():
|
| 1389 |
+
"""Advanced settings and configuration interface"""
|
| 1390 |
+
st.header("Advanced Settings")
|
| 1391 |
+
st.markdown("**Configure advanced parameters and model settings**")
|
| 1392 |
+
|
| 1393 |
+
# Model configuration
|
| 1394 |
+
st.subheader("Model Configuration")
|
| 1395 |
+
col1, col2 = st.columns(2)
|
| 1396 |
+
|
| 1397 |
+
with col1:
|
| 1398 |
+
st.write("**Current Model Status:**")
|
| 1399 |
+
if st.session_state.model:
|
| 1400 |
+
model_type = getattr(st.session_state, 'model_type', 'unknown')
|
| 1401 |
+
st.success(f"Model loaded: {model_type}")
|
| 1402 |
+
st.write(f"Device: {st.session_state.device}")
|
| 1403 |
+
else:
|
| 1404 |
+
st.warning("Model not loaded")
|
| 1405 |
+
|
| 1406 |
+
with col2:
|
| 1407 |
+
st.write("**Model Information:**")
|
| 1408 |
+
st.write("• Architecture: BigBird Transformer")
|
| 1409 |
+
st.write("• Parameters: 89.6M")
|
| 1410 |
+
st.write("• Fine-tuning data: 3,676 high-expression E. coli genes (NCBI-curated)")
|
| 1411 |
+
|
| 1412 |
+
# Performance tuning
|
| 1413 |
+
st.subheader("Performance Tuning")
|
| 1414 |
+
|
| 1415 |
+
# Memory management
|
| 1416 |
+
col1, col2 = st.columns(2)
|
| 1417 |
+
with col1:
|
| 1418 |
+
if st.button("Clear Cache"):
|
| 1419 |
+
st.cache_data.clear()
|
| 1420 |
+
st.success("Cache cleared successfully")
|
| 1421 |
+
|
| 1422 |
+
with col2:
|
| 1423 |
+
if st.button("Reload Model"):
|
| 1424 |
+
st.session_state.model = None
|
| 1425 |
+
st.session_state.tokenizer = None
|
| 1426 |
+
st.rerun()
|
| 1427 |
+
|
| 1428 |
+
# System information
|
| 1429 |
+
st.subheader("System Information")
|
| 1430 |
+
import torch
|
| 1431 |
+
col1, col2, col3 = st.columns(3)
|
| 1432 |
+
|
| 1433 |
+
with col1:
|
| 1434 |
+
st.write("**PyTorch:**")
|
| 1435 |
+
st.write(f"Version: {torch.__version__}")
|
| 1436 |
+
st.write(f"CUDA Available: {torch.cuda.is_available()}")
|
| 1437 |
+
|
| 1438 |
+
with col2:
|
| 1439 |
+
st.write("**Device:**")
|
| 1440 |
+
st.write(f"Current: {st.session_state.device}")
|
| 1441 |
+
if torch.cuda.is_available():
|
| 1442 |
+
st.write(f"GPU: {torch.cuda.get_device_name()}")
|
| 1443 |
+
|
| 1444 |
+
with col3:
|
| 1445 |
+
st.write("**Memory:**")
|
| 1446 |
+
if torch.cuda.is_available():
|
| 1447 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 1448 |
+
st.write(f"GPU Memory: {gpu_memory:.1f} GB")
|
| 1449 |
+
|
| 1450 |
+
# Footer
|
| 1451 |
+
st.markdown("---")
|
| 1452 |
+
st.markdown("**ENCOT**")
|
| 1453 |
+
st.markdown("Open-source codon optimization for E. coli with reproducible evaluation.")
|
| 1454 |
+
|
| 1455 |
+
if __name__ == "__main__":
|
| 1456 |
+
main()
|
streamlit_gui/demo.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Demo script for ColiFormer Streamlit GUI
|
| 4 |
+
|
| 5 |
+
This script demonstrates the GUI functionality with example sequences
|
| 6 |
+
and showcases key features of the ColiFormer optimization tool.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# Add parent directory to path for imports
|
| 15 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 16 |
+
|
| 17 |
+
def print_header():
|
| 18 |
+
"""Print demo header"""
|
| 19 |
+
print("=" * 40)
|
| 20 |
+
print(" ColiFormer GUI Demo")
|
| 21 |
+
print("=" * 40)
|
| 22 |
+
print()
|
| 23 |
+
|
| 24 |
+
def print_section(title):
|
| 25 |
+
"""Print section header"""
|
| 26 |
+
print(f"\n{title}")
|
| 27 |
+
print("-" * (len(title) + 4))
|
| 28 |
+
|
| 29 |
+
def demo_validation():
|
| 30 |
+
"""Demonstrate protein sequence validation"""
|
| 31 |
+
print_section("Protein Sequence Validation")
|
| 32 |
+
|
| 33 |
+
# Import validation function
|
| 34 |
+
from streamlit_gui.app import validate_protein_sequence
|
| 35 |
+
|
| 36 |
+
test_sequences = [
|
| 37 |
+
("MKTVRQERLK", "Valid short peptide"),
|
| 38 |
+
("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG", "Valid longer protein"),
|
| 39 |
+
("MKTVRQERLKX", "Invalid character (X)"),
|
| 40 |
+
("MK", "Too short"),
|
| 41 |
+
("mktvrqerlk", "Lowercase (should work)"),
|
| 42 |
+
("MKTVRQERLK*", "With stop codon"),
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
for seq, description in test_sequences:
|
| 46 |
+
is_valid, message = validate_protein_sequence(seq)
|
| 47 |
+
status = "OK" if is_valid else "FAIL"
|
| 48 |
+
print(f"{status} {description}: {message}")
|
| 49 |
+
|
| 50 |
+
def demo_metrics():
|
| 51 |
+
"""Demonstrate metrics calculation"""
|
| 52 |
+
print_section("Metrics Calculation Demo")
|
| 53 |
+
|
| 54 |
+
from streamlit_gui.app import calculate_input_metrics
|
| 55 |
+
|
| 56 |
+
example_proteins = [
|
| 57 |
+
("MKTVRQERLK", "Short peptide (10 AA)"),
|
| 58 |
+
("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG", "Medium protein (67 AA)"),
|
| 59 |
+
("MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGEEHFKGLVLIAFSQYLQQCPFDEHVKLVNELTE", "Long protein (72 AA)"),
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
organism = "Escherichia coli general"
|
| 63 |
+
|
| 64 |
+
for protein, description in example_proteins:
|
| 65 |
+
print(f"\n{description}")
|
| 66 |
+
print(f" Sequence: {protein[:30]}{'...' if len(protein) > 30 else ''}")
|
| 67 |
+
|
| 68 |
+
metrics = calculate_input_metrics(protein, organism)
|
| 69 |
+
|
| 70 |
+
print(f" Length: {metrics['length']} amino acids")
|
| 71 |
+
print(f" GC Content: {metrics['gc_content']:.1f}%")
|
| 72 |
+
if metrics['tai']:
|
| 73 |
+
print(f" tAI: {metrics['tai']:.3f}")
|
| 74 |
+
if metrics['cai']:
|
| 75 |
+
print(f" CAI: {metrics['cai']:.3f}")
|
| 76 |
+
else:
|
| 77 |
+
print(" CAI: Not available for this organism")
|
| 78 |
+
|
| 79 |
+
def demo_visualization():
|
| 80 |
+
"""Demonstrate visualization capabilities"""
|
| 81 |
+
print_section("Visualization Demo")
|
| 82 |
+
|
| 83 |
+
from streamlit_gui.app import create_gc_content_plot, create_metrics_comparison_chart
|
| 84 |
+
|
| 85 |
+
# Test DNA sequence for GC content plot
|
| 86 |
+
test_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC"
|
| 87 |
+
|
| 88 |
+
print("Creating GC content sliding window plot...")
|
| 89 |
+
try:
|
| 90 |
+
fig = create_gc_content_plot(test_dna)
|
| 91 |
+
print(" OK: GC content plot created successfully")
|
| 92 |
+
print(f" Analyzing {len(test_dna)} base pairs")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
print(f" FAIL: Error creating GC plot: {e}")
|
| 95 |
+
|
| 96 |
+
print("\nCreating metrics comparison chart...")
|
| 97 |
+
try:
|
| 98 |
+
before_metrics = {
|
| 99 |
+
'gc_content': 45.2,
|
| 100 |
+
'cai': 0.485,
|
| 101 |
+
'tai': 0.312
|
| 102 |
+
}
|
| 103 |
+
after_metrics = {
|
| 104 |
+
'gc_content': 52.1,
|
| 105 |
+
'cai': 0.634,
|
| 106 |
+
'tai': 0.456
|
| 107 |
+
}
|
| 108 |
+
fig = create_metrics_comparison_chart(before_metrics, after_metrics)
|
| 109 |
+
print(" OK: Comparison chart created successfully")
|
| 110 |
+
print(" Shows improvement in all metrics")
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f" FAIL: Error creating comparison chart: {e}")
|
| 113 |
+
|
| 114 |
+
def demo_codon_evaluation():
|
| 115 |
+
"""Demonstrate CodonEvaluation functions"""
|
| 116 |
+
print_section("CodonEvaluation Functions Demo")
|
| 117 |
+
|
| 118 |
+
from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI, get_ecoli_tai_weights
|
| 119 |
+
|
| 120 |
+
test_sequences = [
|
| 121 |
+
("ATGGCGAAAGCGCTGTATCGC", "High GC content"),
|
| 122 |
+
("ATGAAATTTATTTATTATTAT", "Low GC content"),
|
| 123 |
+
("ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC", "Medium length"),
|
| 124 |
+
]
|
| 125 |
+
|
| 126 |
+
print("Testing GC content calculation:")
|
| 127 |
+
for seq, description in test_sequences:
|
| 128 |
+
gc_content = get_GC_content(seq)
|
| 129 |
+
print(f" {description}: {gc_content:.1f}%")
|
| 130 |
+
|
| 131 |
+
print("\nTesting tAI calculation:")
|
| 132 |
+
try:
|
| 133 |
+
tai_weights = get_ecoli_tai_weights()
|
| 134 |
+
for seq, description in test_sequences:
|
| 135 |
+
tai_value = calculate_tAI(seq, tai_weights)
|
| 136 |
+
print(f" {description}: {tai_value:.3f}")
|
| 137 |
+
except Exception as e:
|
| 138 |
+
print(f" FAIL: tAI calculation error: {e}")
|
| 139 |
+
|
| 140 |
+
def demo_model_info():
|
| 141 |
+
"""Show model information"""
|
| 142 |
+
print_section("Model Information")
|
| 143 |
+
|
| 144 |
+
try:
|
| 145 |
+
import torch
|
| 146 |
+
from transformers import AutoTokenizer
|
| 147 |
+
|
| 148 |
+
print("Model Details:")
|
| 149 |
+
print(" Base model: adibvafa/CodonTransformer")
|
| 150 |
+
print(" Architecture: BigBird Transformer")
|
| 151 |
+
print(" Task: Masked Language Modeling for codon optimization")
|
| 152 |
+
|
| 153 |
+
print("\nSystem Information:")
|
| 154 |
+
print(f" PyTorch: {torch.__version__}")
|
| 155 |
+
print(f" Device: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
|
| 156 |
+
if torch.cuda.is_available():
|
| 157 |
+
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 158 |
+
print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
|
| 159 |
+
|
| 160 |
+
print("\nTokenizer Test:")
|
| 161 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 162 |
+
print(f" OK: Tokenizer loaded: {len(tokenizer)} tokens")
|
| 163 |
+
print(f" Vocab size: {tokenizer.vocab_size}")
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print(f" FAIL: Error loading model info: {e}")
|
| 167 |
+
|
| 168 |
+
def demo_gui_features():
|
| 169 |
+
"""Show GUI features overview"""
|
| 170 |
+
print_section("GUI Features Overview")
|
| 171 |
+
|
| 172 |
+
features = [
|
| 173 |
+
("Real-time Validation", "Instant feedback on protein sequence validity"),
|
| 174 |
+
("Metrics Dashboard", "GC content, CAI, tAI calculations"),
|
| 175 |
+
("Constrained Optimization", "GC content control with beam search"),
|
| 176 |
+
("Visual Analytics", "Interactive plots and comparisons"),
|
| 177 |
+
("Configurable Parameters", "Organism selection, beam size, GC targets"),
|
| 178 |
+
("Export Options", "Download optimized sequences"),
|
| 179 |
+
("Progress Tracking", "Real-time optimization progress"),
|
| 180 |
+
("Responsive Design", "Works on desktop and mobile"),
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
for feature, description in features:
|
| 184 |
+
print(f" {feature}: {description}")
|
| 185 |
+
|
| 186 |
+
def demo_usage_examples():
|
| 187 |
+
"""Show usage examples"""
|
| 188 |
+
print_section("Usage Examples")
|
| 189 |
+
|
| 190 |
+
examples = [
|
| 191 |
+
{
|
| 192 |
+
"name": "Short Peptide Optimization",
|
| 193 |
+
"protein": "MKTVRQERLK",
|
| 194 |
+
"organism": "Escherichia coli general",
|
| 195 |
+
"use_case": "Quick testing and validation"
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"name": "Insulin Chain A",
|
| 199 |
+
"protein": "GIVEQCCTSICSLYQLENYCN",
|
| 200 |
+
"organism": "Escherichia coli general",
|
| 201 |
+
"use_case": "Pharmaceutical protein production"
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"name": "Green Fluorescent Protein (partial)",
|
| 205 |
+
"protein": "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQC",
|
| 206 |
+
"organism": "Escherichia coli general",
|
| 207 |
+
"use_case": "Research marker protein"
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"name": "Yeast Expression",
|
| 211 |
+
"protein": "MKTVRQERLKSIVRILERSKEPVSGAQ",
|
| 212 |
+
"organism": "Saccharomyces cerevisiae",
|
| 213 |
+
"use_case": "Eukaryotic protein expression"
|
| 214 |
+
}
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
for i, example in enumerate(examples, 1):
|
| 218 |
+
print(f"\nExample {i}: {example['name']}")
|
| 219 |
+
print(f" Protein: {example['protein'][:40]}{'...' if len(example['protein']) > 40 else ''}")
|
| 220 |
+
print(f" Organism: {example['organism']}")
|
| 221 |
+
print(f" Use case: {example['use_case']}")
|
| 222 |
+
print(f" Length: {len(example['protein'])} amino acids")
|
| 223 |
+
|
| 224 |
+
def demo_launch_instructions():
|
| 225 |
+
"""Show how to launch the GUI"""
|
| 226 |
+
print_section("How to Launch the GUI")
|
| 227 |
+
|
| 228 |
+
print("Launch Options:")
|
| 229 |
+
print()
|
| 230 |
+
print(" Option 1 - Using the launcher script:")
|
| 231 |
+
print(" $ cd ecoli/streamlit_gui")
|
| 232 |
+
print(" $ python run_gui.py")
|
| 233 |
+
print()
|
| 234 |
+
print(" Option 2 - Direct streamlit command:")
|
| 235 |
+
print(" $ cd ecoli/streamlit_gui")
|
| 236 |
+
print(" $ source ../codon_env/bin/activate")
|
| 237 |
+
print(" $ streamlit run app.py")
|
| 238 |
+
print()
|
| 239 |
+
print(" Option 3 - With custom port:")
|
| 240 |
+
print(" $ streamlit run app.py --server.port 8502")
|
| 241 |
+
print()
|
| 242 |
+
print("Access the GUI:")
|
| 243 |
+
print(" Web browser: http://localhost:8501")
|
| 244 |
+
print(" The GUI will automatically open in your default browser")
|
| 245 |
+
print()
|
| 246 |
+
print("Performance Tips:")
|
| 247 |
+
print(" • Use GPU if available for faster processing")
|
| 248 |
+
print(" • Start with shorter sequences for testing")
|
| 249 |
+
print(" • Adjust beam size based on sequence length")
|
| 250 |
+
print(" • Close other applications to free up memory")
|
| 251 |
+
|
| 252 |
+
def main():
|
| 253 |
+
"""Run the complete demo"""
|
| 254 |
+
print_header()
|
| 255 |
+
|
| 256 |
+
print("This demo showcases the ENCOT Streamlit GUI capabilities.")
|
| 257 |
+
print("The GUI provides an interface for protein codon optimization.")
|
| 258 |
+
print()
|
| 259 |
+
|
| 260 |
+
try:
|
| 261 |
+
demo_validation()
|
| 262 |
+
demo_metrics()
|
| 263 |
+
demo_visualization()
|
| 264 |
+
demo_codon_evaluation()
|
| 265 |
+
demo_model_info()
|
| 266 |
+
demo_gui_features()
|
| 267 |
+
demo_usage_examples()
|
| 268 |
+
demo_launch_instructions()
|
| 269 |
+
|
| 270 |
+
print("\nDemo completed successfully.")
|
| 271 |
+
print()
|
| 272 |
+
print("Next steps:")
|
| 273 |
+
print("1. Launch the GUI using one of the methods above")
|
| 274 |
+
print("2. Try the example sequences provided")
|
| 275 |
+
print("3. Experiment with different organisms and settings")
|
| 276 |
+
print("4. Compare optimization results")
|
| 277 |
+
print()
|
| 278 |
+
print("Happy optimizing.")
|
| 279 |
+
|
| 280 |
+
except Exception as e:
|
| 281 |
+
print(f"\nDemo error: {e}")
|
| 282 |
+
print("Make sure you're running from the correct directory and all dependencies are installed.")
|
| 283 |
+
return 1
|
| 284 |
+
|
| 285 |
+
return 0
|
| 286 |
+
|
| 287 |
+
if __name__ == "__main__":
|
| 288 |
+
exit(main())
|
streamlit_gui/requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.28.0
|
| 2 |
+
torch>=1.13.0
|
| 3 |
+
pandas>=1.5.0
|
| 4 |
+
numpy>=1.21.0
|
| 5 |
+
plotly>=5.0.0
|
| 6 |
+
transformers>=4.21.0
|
| 7 |
+
scipy>=1.9.0
|
| 8 |
+
tokenizers>=0.13.0
|
| 9 |
+
tqdm>=4.64.0
|
| 10 |
+
matplotlib>=3.5.0
|
| 11 |
+
seaborn>=0.11.0
|
| 12 |
+
onnxruntime>=1.15.0
|
| 13 |
+
python-codon-tables>=0.1.12
|
| 14 |
+
biopython>=1.79
|
| 15 |
+
scikit-learn>=1.0.0
|
| 16 |
+
requests>=2.25.0
|
| 17 |
+
ipywidgets>=7.6.0
|
| 18 |
+
huggingface-hub>=0.20.0
|
| 19 |
+
datasets>=2.0.0
|
| 20 |
+
git+https://github.com/Benjamin-Lee/CodonAdaptationIndex.git
|
streamlit_gui/run_gui.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Launcher script for ColiFormer Streamlit GUI
|
| 4 |
+
|
| 5 |
+
This script sets up the environment and launches the Streamlit application.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import subprocess
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
def main():
|
| 14 |
+
"""Launch the Streamlit GUI application"""
|
| 15 |
+
|
| 16 |
+
# Get the directory containing this script
|
| 17 |
+
script_dir = Path(__file__).parent
|
| 18 |
+
|
| 19 |
+
# Add the parent directory to Python path so we can import CodonTransformer
|
| 20 |
+
parent_dir = script_dir.parent
|
| 21 |
+
sys.path.insert(0, str(parent_dir))
|
| 22 |
+
|
| 23 |
+
# Set working directory to parent directory so model paths work correctly
|
| 24 |
+
os.chdir(parent_dir)
|
| 25 |
+
|
| 26 |
+
print("Starting ENCOT GUI...")
|
| 27 |
+
print(f" Working directory: {parent_dir}")
|
| 28 |
+
print(f" Python path includes: {parent_dir}")
|
| 29 |
+
|
| 30 |
+
# Check for model checkpoint
|
| 31 |
+
model_path = parent_dir / "models" / "alm-enhanced-training" / "balanced_alm_finetune.ckpt"
|
| 32 |
+
if model_path.exists():
|
| 33 |
+
print(f"Found fine-tuned model: {model_path}")
|
| 34 |
+
else:
|
| 35 |
+
print("Fine-tuned model not found, will use base model")
|
| 36 |
+
|
| 37 |
+
# Check for virtual environment
|
| 38 |
+
venv_path = parent_dir / "codon_env"
|
| 39 |
+
if venv_path.exists():
|
| 40 |
+
# Set up virtual environment paths
|
| 41 |
+
venv_bin = venv_path / "bin"
|
| 42 |
+
venv_python = venv_bin / "python"
|
| 43 |
+
|
| 44 |
+
if venv_python.exists():
|
| 45 |
+
print(f"Found virtual environment: {venv_path}")
|
| 46 |
+
# Update PATH to include virtual environment
|
| 47 |
+
current_path = os.environ.get("PATH", "")
|
| 48 |
+
os.environ["PATH"] = f"{venv_bin}:{current_path}"
|
| 49 |
+
# Use virtual environment Python
|
| 50 |
+
python_executable = str(venv_python)
|
| 51 |
+
else:
|
| 52 |
+
print("Virtual environment found but Python executable missing")
|
| 53 |
+
python_executable = sys.executable
|
| 54 |
+
else:
|
| 55 |
+
print("No virtual environment found, using system Python")
|
| 56 |
+
python_executable = sys.executable
|
| 57 |
+
|
| 58 |
+
print(f" Using Python: {python_executable}")
|
| 59 |
+
print()
|
| 60 |
+
|
| 61 |
+
# Check if streamlit is installed
|
| 62 |
+
try:
|
| 63 |
+
import streamlit
|
| 64 |
+
print(f"Streamlit version: {streamlit.__version__}")
|
| 65 |
+
except ImportError:
|
| 66 |
+
print("Streamlit not found. Please install requirements:")
|
| 67 |
+
print(" pip install -r requirements.txt")
|
| 68 |
+
return 1
|
| 69 |
+
|
| 70 |
+
# Check if torch is available
|
| 71 |
+
try:
|
| 72 |
+
import torch
|
| 73 |
+
device = "GPU" if torch.cuda.is_available() else "CPU"
|
| 74 |
+
print(f"PyTorch available, using: {device}")
|
| 75 |
+
except ImportError:
|
| 76 |
+
print("PyTorch not found. Please install requirements:")
|
| 77 |
+
print(" pip install -r requirements.txt")
|
| 78 |
+
return 1
|
| 79 |
+
|
| 80 |
+
print()
|
| 81 |
+
print("Launching GUI...")
|
| 82 |
+
print(" The application will open in your default web browser")
|
| 83 |
+
print(" Press Ctrl+C to stop the server")
|
| 84 |
+
print()
|
| 85 |
+
|
| 86 |
+
# Launch streamlit
|
| 87 |
+
try:
|
| 88 |
+
subprocess.run([
|
| 89 |
+
python_executable, "-m", "streamlit", "run", "streamlit_gui/app.py",
|
| 90 |
+
"--server.headless", "false",
|
| 91 |
+
"--server.port", "8501",
|
| 92 |
+
"--server.address", "0.0.0.0"
|
| 93 |
+
])
|
| 94 |
+
except KeyboardInterrupt:
|
| 95 |
+
print("\nShutting down ENCOT GUI...")
|
| 96 |
+
return 0
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(f"Error launching Streamlit: {e}")
|
| 99 |
+
return 1
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
exit(main())
|
streamlit_gui/test_gui.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script for ColiFormer Streamlit GUI
|
| 4 |
+
|
| 5 |
+
This script tests the core functionality of the GUI without running the full Streamlit application.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
import traceback
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Add parent directory to path for imports
|
| 14 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 15 |
+
|
| 16 |
+
def test_imports():
|
| 17 |
+
"""Test if all required imports work"""
|
| 18 |
+
print("Testing imports...")
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import streamlit as st
|
| 22 |
+
print(f" OK: Streamlit: {st.__version__}")
|
| 23 |
+
except ImportError as e:
|
| 24 |
+
print(f" FAIL: Streamlit: {e}")
|
| 25 |
+
return False
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
import torch
|
| 29 |
+
device = "GPU" if torch.cuda.is_available() else "CPU"
|
| 30 |
+
print(f" OK: PyTorch: {torch.__version__} ({device})")
|
| 31 |
+
except ImportError as e:
|
| 32 |
+
print(f" FAIL: PyTorch: {e}")
|
| 33 |
+
return False
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import plotly
|
| 37 |
+
print(f" OK: Plotly: {plotly.__version__}")
|
| 38 |
+
except ImportError as e:
|
| 39 |
+
print(f" FAIL: Plotly: {e}")
|
| 40 |
+
return False
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from CodonTransformer.CodonPrediction import predict_dna_sequence
|
| 44 |
+
print(" OK: CodonTransformer.CodonPrediction")
|
| 45 |
+
except ImportError as e:
|
| 46 |
+
print(f" FAIL: CodonTransformer.CodonPrediction: {e}")
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI
|
| 51 |
+
print(" OK: CodonTransformer.CodonEvaluation")
|
| 52 |
+
except ImportError as e:
|
| 53 |
+
print(f" FAIL: CodonTransformer.CodonEvaluation: {e}")
|
| 54 |
+
return False
|
| 55 |
+
|
| 56 |
+
return True
|
| 57 |
+
|
| 58 |
+
def test_protein_validation():
|
| 59 |
+
"""Test protein sequence validation"""
|
| 60 |
+
print("\nTesting protein sequence validation...")
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
# Import the validation function
|
| 64 |
+
from app import validate_protein_sequence
|
| 65 |
+
|
| 66 |
+
# Test cases
|
| 67 |
+
test_cases = [
|
| 68 |
+
("MKTVRQERLK", True, "Valid short sequence"),
|
| 69 |
+
("", False, "Empty sequence"),
|
| 70 |
+
("MKTVRQERLKX", False, "Invalid character X"),
|
| 71 |
+
("MK", False, "Too short"),
|
| 72 |
+
("M" * 501, False, "Too long"),
|
| 73 |
+
("mktvrqerlk", True, "Lowercase (should work)"),
|
| 74 |
+
("MKTVRQERLK*", True, "With stop codon"),
|
| 75 |
+
("MKTVRQERLK_", True, "With underscore stop"),
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
for seq, expected_valid, description in test_cases:
|
| 79 |
+
is_valid, message = validate_protein_sequence(seq)
|
| 80 |
+
status = "OK" if is_valid == expected_valid else "FAIL"
|
| 81 |
+
print(f" {status} {description}: {message}")
|
| 82 |
+
|
| 83 |
+
return True
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f" FAIL: Error in validation test: {e}")
|
| 86 |
+
traceback.print_exc()
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
def test_metrics_calculation():
|
| 90 |
+
"""Test metrics calculation"""
|
| 91 |
+
print("\nTesting metrics calculation...")
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
from app import calculate_input_metrics
|
| 95 |
+
|
| 96 |
+
test_protein = "MKTVRQERLK"
|
| 97 |
+
organism = "Escherichia coli general"
|
| 98 |
+
|
| 99 |
+
metrics = calculate_input_metrics(test_protein, organism)
|
| 100 |
+
|
| 101 |
+
# Check if all expected metrics are present
|
| 102 |
+
expected_keys = ['length', 'gc_content', 'baseline_dna', 'cai', 'tai']
|
| 103 |
+
for key in expected_keys:
|
| 104 |
+
if key in metrics:
|
| 105 |
+
print(f" OK: {key}: {metrics[key]}")
|
| 106 |
+
else:
|
| 107 |
+
print(f" FAIL: Missing metric: {key}")
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
# Validate metric values
|
| 111 |
+
if metrics['length'] == len(test_protein):
|
| 112 |
+
print(" OK: Length calculation correct")
|
| 113 |
+
else:
|
| 114 |
+
print(" FAIL: Length calculation incorrect")
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
if 0 <= metrics['gc_content'] <= 100:
|
| 118 |
+
print(" OK: GC content in valid range")
|
| 119 |
+
else:
|
| 120 |
+
print(" FAIL: GC content out of range")
|
| 121 |
+
return False
|
| 122 |
+
|
| 123 |
+
return True
|
| 124 |
+
except Exception as e:
|
| 125 |
+
print(f" FAIL: Error in metrics calculation: {e}")
|
| 126 |
+
traceback.print_exc()
|
| 127 |
+
return False
|
| 128 |
+
|
| 129 |
+
def test_visualization_functions():
|
| 130 |
+
"""Test visualization functions"""
|
| 131 |
+
print("\nTesting visualization functions...")
|
| 132 |
+
|
| 133 |
+
try:
|
| 134 |
+
from app import create_gc_content_plot, create_metrics_comparison_chart
|
| 135 |
+
|
| 136 |
+
# Test GC content plot
|
| 137 |
+
test_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC"
|
| 138 |
+
fig = create_gc_content_plot(test_dna)
|
| 139 |
+
print(" OK: GC content plot created")
|
| 140 |
+
|
| 141 |
+
# Test metrics comparison chart
|
| 142 |
+
before_metrics = {'gc_content': 50.0, 'cai': 0.5, 'tai': 0.3}
|
| 143 |
+
after_metrics = {'gc_content': 52.0, 'cai': 0.6, 'tai': 0.4}
|
| 144 |
+
fig = create_metrics_comparison_chart(before_metrics, after_metrics)
|
| 145 |
+
print(" OK: Metrics comparison chart created")
|
| 146 |
+
|
| 147 |
+
return True
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f" FAIL: Error in visualization test: {e}")
|
| 150 |
+
traceback.print_exc()
|
| 151 |
+
return False
|
| 152 |
+
|
| 153 |
+
def test_codon_evaluation():
|
| 154 |
+
"""Test CodonEvaluation functions directly"""
|
| 155 |
+
print("\nTesting CodonEvaluation functions...")
|
| 156 |
+
|
| 157 |
+
try:
|
| 158 |
+
from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI, get_ecoli_tai_weights
|
| 159 |
+
|
| 160 |
+
# Test GC content calculation
|
| 161 |
+
test_dna = "ATGGCGAAAGCG"
|
| 162 |
+
gc_content = get_GC_content(test_dna)
|
| 163 |
+
print(f" OK: GC content calculation: {gc_content:.1f}%")
|
| 164 |
+
|
| 165 |
+
# Test tAI calculation
|
| 166 |
+
try:
|
| 167 |
+
tai_weights = get_ecoli_tai_weights()
|
| 168 |
+
tai_value = calculate_tAI(test_dna, tai_weights)
|
| 169 |
+
print(f" OK: tAI calculation: {tai_value:.3f}")
|
| 170 |
+
except Exception as e:
|
| 171 |
+
print(f" NOTE: tAI calculation (may need scipy): {e}")
|
| 172 |
+
|
| 173 |
+
return True
|
| 174 |
+
except Exception as e:
|
| 175 |
+
print(f" FAIL: Error in CodonEvaluation test: {e}")
|
| 176 |
+
traceback.print_exc()
|
| 177 |
+
return False
|
| 178 |
+
|
| 179 |
+
def test_model_loading():
|
| 180 |
+
"""Test model loading functionality"""
|
| 181 |
+
print("\nTesting model loading (mock)...")
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
import torch
|
| 185 |
+
from transformers import AutoTokenizer
|
| 186 |
+
from CodonTransformer.CodonPrediction import load_model
|
| 187 |
+
|
| 188 |
+
# Test tokenizer loading (this is fast)
|
| 189 |
+
print(" Testing tokenizer loading...")
|
| 190 |
+
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
|
| 191 |
+
print(" OK: Tokenizer loaded successfully")
|
| 192 |
+
|
| 193 |
+
# Test load_model function
|
| 194 |
+
print(" Testing load_model function...")
|
| 195 |
+
from transformers import BigBirdForMaskedLM
|
| 196 |
+
print(" OK: Model class available: BigBirdForMaskedLM")
|
| 197 |
+
|
| 198 |
+
# Check if fine-tuned model exists
|
| 199 |
+
import os
|
| 200 |
+
model_path = "models/alm-enhanced-training/balanced_alm_finetune.ckpt"
|
| 201 |
+
if os.path.exists(model_path):
|
| 202 |
+
print(f" OK: Fine-tuned model found: {model_path}")
|
| 203 |
+
else:
|
| 204 |
+
print(f" NOTE: Fine-tuned model not found at: {model_path}")
|
| 205 |
+
|
| 206 |
+
# Note: We won't actually load the full model here as it's ~2GB
|
| 207 |
+
print(" NOTE: Full model loading skipped in test (too large)")
|
| 208 |
+
|
| 209 |
+
return True
|
| 210 |
+
except Exception as e:
|
| 211 |
+
print(f" FAIL: Error in model loading test: {e}")
|
| 212 |
+
traceback.print_exc()
|
| 213 |
+
return False
|
| 214 |
+
|
| 215 |
+
def test_file_structure():
|
| 216 |
+
"""Test if all required files exist"""
|
| 217 |
+
print("\nTesting file structure...")
|
| 218 |
+
|
| 219 |
+
gui_dir = Path(__file__).parent
|
| 220 |
+
parent_dir = gui_dir.parent
|
| 221 |
+
|
| 222 |
+
required_files = [
|
| 223 |
+
"app.py",
|
| 224 |
+
"run_gui.py",
|
| 225 |
+
"requirements.txt",
|
| 226 |
+
"README.md"
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
all_present = True
|
| 230 |
+
for file_name in required_files:
|
| 231 |
+
file_path = gui_dir / file_name
|
| 232 |
+
if file_path.exists():
|
| 233 |
+
print(f" OK: {file_name}")
|
| 234 |
+
else:
|
| 235 |
+
print(f" FAIL: {file_name} missing")
|
| 236 |
+
all_present = False
|
| 237 |
+
|
| 238 |
+
# Check for model checkpoint
|
| 239 |
+
model_path = parent_dir / "models" / "alm-enhanced-training" / "balanced_alm_finetune.ckpt"
|
| 240 |
+
if model_path.exists():
|
| 241 |
+
print(" OK: Fine-tuned model checkpoint found")
|
| 242 |
+
else:
|
| 243 |
+
print(" NOTE: Fine-tuned model checkpoint not found")
|
| 244 |
+
|
| 245 |
+
return all_present
|
| 246 |
+
|
| 247 |
+
def test_post_processing():
|
| 248 |
+
"""Test post-processing functionality"""
|
| 249 |
+
print("\nTesting post-processing features...")
|
| 250 |
+
|
| 251 |
+
try:
|
| 252 |
+
from app import POST_PROCESSING_AVAILABLE, DNACHISEL_AVAILABLE
|
| 253 |
+
|
| 254 |
+
if POST_PROCESSING_AVAILABLE:
|
| 255 |
+
print(" OK: Post-processing module available")
|
| 256 |
+
if DNACHISEL_AVAILABLE:
|
| 257 |
+
print(" OK: DNAChisel available")
|
| 258 |
+
else:
|
| 259 |
+
print(" NOTE: DNAChisel not available")
|
| 260 |
+
else:
|
| 261 |
+
print(" NOTE: Post-processing module not available")
|
| 262 |
+
|
| 263 |
+
return True
|
| 264 |
+
except Exception as e:
|
| 265 |
+
print(f" FAIL: Error in post-processing test: {e}")
|
| 266 |
+
return False
|
| 267 |
+
|
| 268 |
+
def main():
|
| 269 |
+
"""Run all tests"""
|
| 270 |
+
print("ENCOT GUI Test Suite")
|
| 271 |
+
print("=" * 50)
|
| 272 |
+
|
| 273 |
+
tests = [
|
| 274 |
+
("File Structure", test_file_structure),
|
| 275 |
+
("Imports", test_imports),
|
| 276 |
+
("Protein Validation", test_protein_validation),
|
| 277 |
+
("Metrics Calculation", test_metrics_calculation),
|
| 278 |
+
("Visualization Functions", test_visualization_functions),
|
| 279 |
+
("CodonEvaluation Functions", test_codon_evaluation),
|
| 280 |
+
("Model Loading", test_model_loading),
|
| 281 |
+
("Post-Processing", test_post_processing),
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
passed = 0
|
| 285 |
+
total = len(tests)
|
| 286 |
+
|
| 287 |
+
for test_name, test_func in tests:
|
| 288 |
+
try:
|
| 289 |
+
result = test_func()
|
| 290 |
+
if result:
|
| 291 |
+
passed += 1
|
| 292 |
+
print(f"OK: {test_name}: PASSED")
|
| 293 |
+
else:
|
| 294 |
+
print(f"FAIL: {test_name}: FAILED")
|
| 295 |
+
except Exception as e:
|
| 296 |
+
print(f"FAIL: {test_name}: ERROR - {e}")
|
| 297 |
+
|
| 298 |
+
print("\n" + "=" * 50)
|
| 299 |
+
print(f"Test Results: {passed}/{total} tests passed")
|
| 300 |
+
|
| 301 |
+
if passed == total:
|
| 302 |
+
print("All tests passed. The GUI should work correctly.")
|
| 303 |
+
print("\nTo run the GUI:")
|
| 304 |
+
print(" python run_gui.py")
|
| 305 |
+
print(" or")
|
| 306 |
+
print(" cd streamlit_gui && streamlit run app.py --server.address=0.0.0.0")
|
| 307 |
+
else:
|
| 308 |
+
print("Some tests failed. Please check the issues above.")
|
| 309 |
+
|
| 310 |
+
print("\nNotes:")
|
| 311 |
+
print(" • Fine-tuned model integration")
|
| 312 |
+
print(" • Enhanced constrained beam search")
|
| 313 |
+
print(" • Post-processing with DNAChisel")
|
| 314 |
+
print(" • Advanced sequence analysis")
|
| 315 |
+
print(" • Improved parameter controls")
|
| 316 |
+
|
| 317 |
+
return passed == total
|
| 318 |
+
|
| 319 |
+
if __name__ == "__main__":
|
| 320 |
+
success = main()
|
| 321 |
+
sys.exit(0 if success else 1)
|