Genooo12 commited on
Commit
404d784
·
verified ·
1 Parent(s): cd09b92

Deploy Streamlit UI

Browse files
Files changed (48) hide show
  1. .dockerignore +16 -0
  2. .gitattributes +3 -0
  3. .github/ISSUE_TEMPLATE/bug_report.md +38 -0
  4. .github/ISSUE_TEMPLATE/feature_request.md +17 -0
  5. .github/ISSUE_TEMPLATE/other.md +10 -0
  6. .github/workflows/ci.yml +39 -0
  7. .gitignore +230 -0
  8. Benchmark 80 sequences.xlsx +3 -0
  9. CODE_OF_CONDUCT.md +128 -0
  10. CodonTransformer/CodonData.py +682 -0
  11. CodonTransformer/CodonEvaluation.py +583 -0
  12. CodonTransformer/CodonJupyter.py +311 -0
  13. CodonTransformer/CodonPostProcessing.py +83 -0
  14. CodonTransformer/CodonPrediction.py +1372 -0
  15. CodonTransformer/CodonUtils.py +871 -0
  16. CodonTransformer/__init__.py +1 -0
  17. Dockerfile +21 -0
  18. ENCOT_Academic_Documentation.html +2625 -0
  19. ENCOT_Code_Showcase.html +791 -0
  20. LICENSE +201 -0
  21. Makefile +9 -0
  22. README.md +495 -10
  23. app.py +12 -0
  24. benchmark_evaluation.py +695 -0
  25. comprehensive_model_comparison.png +3 -0
  26. configs/train_ecoli_alm.yaml +54 -0
  27. configs/train_ecoli_quick.yaml +37 -0
  28. create_model_datasets.py +42 -0
  29. evaluate_optimizer.py +577 -0
  30. prepare_ecoli_data.py +69 -0
  31. pretrain.py +232 -0
  32. pyproject.toml +62 -0
  33. requirements.txt +29 -0
  34. scripts/optimize_sequence.py +383 -0
  35. scripts/preprocess_data.py +251 -0
  36. scripts/run_benchmarks.py +235 -0
  37. scripts/train.py +228 -0
  38. setup.py +40 -0
  39. src/CodonTransformer_inference_template.xlsx +0 -0
  40. src/__init__.py +1 -0
  41. src/banner_final.png +3 -0
  42. src/organism2id.pkl +3 -0
  43. streamlit_app.py +16 -0
  44. streamlit_gui/app.py +1456 -0
  45. streamlit_gui/demo.py +288 -0
  46. streamlit_gui/requirements.txt +20 -0
  47. streamlit_gui/run_gui.py +102 -0
  48. streamlit_gui/test_gui.py +321 -0
.dockerignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .github
3
+ .venv
4
+ __pycache__
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ *.log
9
+ *.ipynb
10
+ .devcontainer
11
+ data
12
+ notebooks
13
+ tests
14
+ slurm
15
+ Benchmark 80 sequences.xlsx
16
+ comprehensive_model_comparison.png
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Benchmark[[:space:]]80[[:space:]]sequences.xlsx filter=lfs diff=lfs merge=lfs -text
37
+ comprehensive_model_comparison.png filter=lfs diff=lfs merge=lfs -text
38
+ src/banner_final.png filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+ Steps to reproduce the behavior:
15
+ 1. Go to '...'
16
+ 2. Click on '....'
17
+ 3. Scroll down to '....'
18
+ 4. See error
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Screenshots**
24
+ If applicable, add screenshots to help explain your problem.
25
+
26
+ **Desktop (please complete the following information):**
27
+ - OS: [e.g. iOS]
28
+ - Browser [e.g. chrome, safari]
29
+ - Version [e.g. 22]
30
+
31
+ **Smartphone (please complete the following information):**
32
+ - Device: [e.g. iPhone6]
33
+ - OS: [e.g. iOS8.1]
34
+ - Browser [e.g. stock browser, safari]
35
+ - Version [e.g. 22]
36
+
37
+ **Additional context**
38
+ Add any other context about the problem here.
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: enhancement
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12
+
13
+ **Describe the solution you'd like**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Additional context**
17
+ Add any other context or screenshots about the feature request here.
.github/ISSUE_TEMPLATE/other.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Other
3
+ about: Any other issue
4
+ title: ''
5
+ labels: bug
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe your issue here**
.github/workflows/ci.yml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # .github/workflows/ci.yml
2
+
3
+ name: CI
4
+
5
+ on: [push, pull_request]
6
+
7
+ jobs:
8
+ test:
9
+ runs-on: ubuntu-latest
10
+
11
+ steps:
12
+ - name: Checkout code
13
+ uses: actions/checkout@v4
14
+
15
+ - name: Set up Python
16
+ uses: actions/setup-python@v5
17
+ with:
18
+ python-version: '3.10'
19
+
20
+ - name: Install dependencies
21
+ run: |
22
+ python -m pip install --upgrade pip
23
+ pip install -r requirements.txt
24
+ pip install "coverage[toml]"
25
+
26
+ - name: Run tests with coverage
27
+ run: |
28
+ make test_with_coverage
29
+ coverage report
30
+ coverage xml
31
+
32
+ - name: Upload coverage to Codecov
33
+ uses: codecov/codecov-action@v4
34
+ with:
35
+ token: ${{ secrets.CODECOV_TOKEN }}
36
+ file: coverage.xml
37
+ flags: unittests
38
+ name: codecov-umbrella
39
+ fail_ci_if_error: true
.gitignore ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+ codon_env/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+
163
+ # Coverage reports
164
+ coverage.xml
165
+
166
+ # Jupyter Notebook checkpoints
167
+ .ipynb_checkpoints/
168
+
169
+ # Temporary files
170
+ *.tmp
171
+ *.temp
172
+
173
+ # PyTorch Lightning checkpoints
174
+ lightning_logs/
175
+
176
+ # PyTorch model weights
177
+ *.pth
178
+ *.pt
179
+
180
+ # Large files excluded from Git
181
+ models/ecoli-codon-optimizer/finetune.ckpt
182
+ models/ecoli-codon-optimizer/finetune_best.ckpt
183
+ data/ecoli_processed_genes.csv
184
+
185
+ # Finetune-related files (keep local only)
186
+ finetune.py
187
+ checkpoints/
188
+ *.safetensors
189
+
190
+ # Benchmark and validation results
191
+ benchmark_plots/
192
+ cai_tai_benchmark.csv
193
+ synthetic_validation.csv
194
+ test_set_validation.csv
195
+
196
+ # Large data files
197
+ *.csv
198
+ *.jsonl
199
+ *.json
200
+ *.fasta
201
+ *.fa
202
+ *.ckpt
203
+
204
+ # Results and outputs
205
+ results/
206
+ outputs/
207
+ logs/
208
+
209
+ # Model files and weights
210
+ *.bin
211
+ *.safetensors
212
+
213
+ # CUDA and GPU related
214
+ *.run
215
+ cuda_installer.pyz
216
+
217
+ # R files
218
+ .RData
219
+ .Rhistory
220
+
221
+ # OS generated files
222
+ .DS_Store
223
+ .DS_Store?
224
+ ._*
225
+ .Spotlight-V100
226
+ .Trashes
227
+ ehthumbs.db
228
+ Thumbs.db
229
+ research/
230
+ models/alm-enhanced-training/balanced_alm_finetune.ckpt
Benchmark 80 sequences.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f80bde88a31e80ac34b0827180b50d112f1d26bdf691c8118943e91c0e3b09e2
3
+ size 179471
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ Adibvafa.fallahpour@mail.utoronto.ca.
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
CodonTransformer/CodonData.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: CodonData.py
3
+ ---------------------
4
+ Includes helper functions for preprocessing NCBI or Kazusa databases and
5
+ preparing the data for training and inference of the CodonTransformer model.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import random
11
+ from typing import Dict, List, Optional, Tuple, Union
12
+
13
+ import pandas as pd
14
+ import python_codon_tables as pct
15
+ from Bio import SeqIO
16
+ from Bio.Seq import Seq
17
+ from sklearn.utils import shuffle as sk_shuffle
18
+ from tqdm import tqdm
19
+
20
+ from CodonTransformer.CodonUtils import (
21
+ AMBIGUOUS_AMINOACID_MAP,
22
+ AMINO2CODON_TYPE,
23
+ AMINO_ACIDS,
24
+ ORGANISM2ID,
25
+ START_CODONS,
26
+ STOP_CODONS,
27
+ STOP_SYMBOL,
28
+ STOP_SYMBOLS,
29
+ ProteinConfig,
30
+ find_pattern_in_fasta,
31
+ get_taxonomy_id,
32
+ sort_amino2codon_skeleton,
33
+ )
34
+
35
+
36
+ def prepare_training_data(
37
+ dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True
38
+ ) -> None:
39
+ """
40
+ Prepare a JSON dataset for training the CodonTransformer model.
41
+
42
+ Input dataset should have columns below:
43
+ - dna: str (DNA sequence)
44
+ - protein: str (Protein sequence)
45
+ - organism: Union[int, str] (ID or Name of the organism)
46
+
47
+ The output JSON dataset will have the following format:
48
+ {"idx": 0, "codons": "M_ATG R_AGG L_TTG L_CTA R_CGA __TAG", "organism": 51}
49
+ {"idx": 1, "codons": "M_ATG K_AAG C_TGC F_TTT F_TTC __TAA", "organism": 59}
50
+
51
+ Args:
52
+ dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format.
53
+ output_file (str): Path to save the output JSON dataset.
54
+ shuffle (bool, optional): Whether to shuffle the dataset before saving.
55
+ Defaults to True.
56
+
57
+ Returns:
58
+ None
59
+ """
60
+ if isinstance(dataset, str):
61
+ dataset = pd.read_csv(dataset)
62
+
63
+ required_columns = {"dna", "protein", "organism"}
64
+ if not required_columns.issubset(dataset.columns):
65
+ raise ValueError(f"Input dataset must have columns: {required_columns}")
66
+
67
+ # Prepare the dataset for finetuning
68
+ dataset["codons"] = dataset.apply(
69
+ lambda row: get_merged_seq(row["protein"], row["dna"], separator="_"), axis=1
70
+ )
71
+
72
+ # Replace organism str with organism id using ORGANISM2ID
73
+ dataset["organism"] = dataset["organism"].apply(
74
+ lambda org: process_organism(org, ORGANISM2ID)
75
+ )
76
+
77
+ # Save the dataset to a JSON file
78
+ dataframe_to_json(dataset[["codons", "organism"]], output_file, shuffle=shuffle)
79
+
80
+
81
+ def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None:
82
+ """
83
+ Convert pandas DataFrame to JSON file format suitable for training CodonTransformer.
84
+
85
+ This function takes a preprocessed DataFrame and writes it to a JSON file
86
+ where each line is a JSON object representing a single record.
87
+
88
+ Args:
89
+ df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns.
90
+ output_file (str): Path to the output JSON file.
91
+ shuffle (bool, optional): Whether to shuffle the dataset before saving.
92
+ Defaults to True.
93
+
94
+ Returns:
95
+ None
96
+
97
+ Raises:
98
+ ValueError: If the required columns are not present in the DataFrame.
99
+ """
100
+ required_columns = {"codons", "organism"}
101
+ if not required_columns.issubset(df.columns):
102
+ raise ValueError(f"DataFrame must contain columns: {required_columns}")
103
+
104
+ print(f"\nStarted writing to {output_file}...")
105
+
106
+ # Shuffle the DataFrame if requested
107
+ if shuffle:
108
+ df = sk_shuffle(df)
109
+
110
+ # Write the DataFrame to a JSON file
111
+ with open(output_file, "w") as f:
112
+ for idx, row in tqdm(
113
+ df.iterrows(), total=len(df), desc="Writing JSON...", unit=" records"
114
+ ):
115
+ doc = {"idx": idx, "codons": row["codons"], "organism": row["organism"]}
116
+ f.write(json.dumps(doc) + "\n")
117
+
118
+ print(f"\nTotal Entries Saved: {len(df)}, JSON data saved to {output_file}")
119
+
120
+
121
+ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) -> int:
122
+ """
123
+ Process and validate the organism input, converting it to a valid organism ID.
124
+
125
+ This function handles both string (organism name) and integer (organism ID) inputs.
126
+ It validates the input against a provided mapping of organism names to IDs.
127
+
128
+ Args:
129
+ organism (Union[str, int]): Input organism, either as a name (str) or ID (int).
130
+ organism_to_id (Dict[str, int]): Dictionary mapping organism names to their
131
+ corresponding IDs.
132
+
133
+ Returns:
134
+ int: The validated organism ID.
135
+
136
+ Raises:
137
+ ValueError: If the input is an invalid organism name or ID.
138
+ TypeError: If the input is neither a string nor an integer.
139
+ """
140
+ if isinstance(organism, str):
141
+ if organism not in organism_to_id:
142
+ raise ValueError(f"Invalid organism name: {organism}")
143
+ return organism_to_id[organism]
144
+
145
+ elif isinstance(organism, int):
146
+ if organism not in organism_to_id.values():
147
+ raise ValueError(f"Invalid organism ID: {organism}")
148
+ return organism
149
+
150
+ raise TypeError(
151
+ f"Organism must be a string or integer, not {type(organism).__name__}"
152
+ )
153
+
154
+
155
+ def preprocess_protein_sequence(protein: str) -> str:
156
+ """
157
+ Preprocess a protein sequence by cleaning, standardizing, and handling
158
+ ambiguous amino acids.
159
+
160
+ Args:
161
+ protein (str): The input protein sequence.
162
+
163
+ Returns:
164
+ str: The preprocessed protein sequence.
165
+
166
+ Raises:
167
+ ValueError: If the protein sequence is invalid or if the configuration is invalid.
168
+ """
169
+ if not protein:
170
+ raise ValueError("Protein sequence is empty.")
171
+
172
+ # Clean and standardize the protein sequence
173
+ protein = (
174
+ protein.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")
175
+ )
176
+
177
+ # Handle ambiguous amino acids based on the specified behavior
178
+ config = ProteinConfig()
179
+ ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
180
+ ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")
181
+ ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()
182
+
183
+ for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
184
+ ambiguous_aminoacid_map[aminoacid] = standard_aminoacids
185
+
186
+ if ambiguous_aminoacid_behavior == "raise_error":
187
+ if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
188
+ raise ValueError("Ambiguous amino acids found in protein sequence.")
189
+ elif ambiguous_aminoacid_behavior == "standardize_deterministic":
190
+ protein = "".join(
191
+ ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
192
+ for aminoacid in protein
193
+ )
194
+ elif ambiguous_aminoacid_behavior == "standardize_random":
195
+ protein = "".join(
196
+ random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
197
+ for aminoacid in protein
198
+ )
199
+ else:
200
+ raise ValueError(
201
+ f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
202
+ )
203
+
204
+ # Check for sequence validity
205
+ if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
206
+ raise ValueError("Invalid characters in protein sequence.")
207
+
208
+ if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS:
209
+ raise ValueError(
210
+ "Protein sequence must end with `*`, or `_`, or an amino acid."
211
+ )
212
+
213
+ # Replace '*' at the end of protein with STOP_SYMBOL if present
214
+ if protein[-1] == "*":
215
+ protein = protein[:-1] + STOP_SYMBOL
216
+
217
+ # Add stop symbol to end of protein
218
+ if protein[-1] != STOP_SYMBOL:
219
+ protein += STOP_SYMBOL
220
+
221
+ return protein
222
+
223
+
224
+ def replace_ambiguous_codons(dna: str) -> str:
225
+ """
226
+ Replaces ambiguous codons in a DNA sequence with "UNK".
227
+
228
+ Args:
229
+ dna (str): The DNA sequence to process.
230
+
231
+ Returns:
232
+ str: The processed DNA sequence with ambiguous codons replaced by "UNK".
233
+ """
234
+ result = []
235
+ dna = dna.upper()
236
+
237
+ # Check codons in DNA sequence
238
+ for i in range(0, len(dna), 3):
239
+ codon = dna[i : i + 3]
240
+
241
+ if len(codon) == 3 and all(nucleotide in "ATCG" for nucleotide in codon):
242
+ result.append(codon)
243
+ else:
244
+ result.append("UNK")
245
+
246
+ return "".join(result)
247
+
248
+
249
+ def preprocess_dna_sequence(dna: str) -> str:
250
+ """
251
+ Cleans and preprocesses a DNA sequence by standardizing it and replacing
252
+ ambiguous codons.
253
+
254
+ Args:
255
+ dna (str): The DNA sequence to preprocess.
256
+
257
+ Returns:
258
+ str: The cleaned and preprocessed DNA sequence.
259
+ """
260
+ if not dna:
261
+ return ""
262
+
263
+ # Clean and standardize the DNA sequence
264
+ dna = dna.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "")
265
+
266
+ # Replace codons with ambigous nucleotides with "UNK"
267
+ dna = replace_ambiguous_codons(dna)
268
+
269
+ # Add unkown stop codon to end of DNA sequence if not present
270
+ if dna[-3:] not in STOP_CODONS:
271
+ dna += "UNK"
272
+
273
+ return dna
274
+
275
+
276
+ def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str:
277
+ """
278
+ Return the merged sequence of protein amino acids and DNA codons in the form
279
+ of tokens separated by space, where each token is composed of an amino acid +
280
+ separator + codon.
281
+
282
+ Args:
283
+ protein (str): Protein sequence.
284
+ dna (str): DNA sequence.
285
+ separator (str): Separator between amino acid and codon.
286
+
287
+ Returns:
288
+ str: Merged sequence.
289
+
290
+ Example:
291
+ >>> get_merged_seq(protein="MAV_", dna="ATGGCTGTGTAA", separator="_")
292
+ 'M_ATG A_GCT V_GTG __TAA'
293
+
294
+ >>> get_merged_seq(protein="QHH_", dna="", separator="_")
295
+ 'Q_UNK H_UNK H_UNK __UNK'
296
+ """
297
+ merged_seq = ""
298
+
299
+ # Prepare protein and dna sequences
300
+ dna = preprocess_dna_sequence(dna)
301
+ protein = preprocess_protein_sequence(protein)
302
+
303
+ # Check if the length of protein and dna sequences are equal
304
+ if len(dna) > 0 and len(protein) != len(dna) / 3:
305
+ raise ValueError(
306
+ 'Length of protein (including stop symbol such as "_") and '
307
+ "the number of codons in DNA sequence (including stop codon) "
308
+ "must be equal."
309
+ )
310
+
311
+ # Merge protein and DNA sequences into tokens
312
+ for i, aminoacid in enumerate(protein):
313
+ merged_seq += f'{aminoacid}{separator}{dna[i * 3:i * 3 + 3] if dna else "UNK"} '
314
+
315
+ return merged_seq.strip()
316
+
317
+
318
+ def is_correct_seq(dna: str, protein: str, stop_symbol: str = STOP_SYMBOL) -> bool:
319
+ """
320
+ Check if the given DNA and protein pair is correct, that is:
321
+ 1. The length of dna is divisible by 3
322
+ 2. There is an initiator codon in the beginning of dna
323
+ 3. There is only one stop codon in the sequence
324
+ 4. The only stop codon is the last codon
325
+
326
+ Note since in Codon Table 3, 'TGA' is interpreted as Triptophan (W),
327
+ there is a separate check to make sure those sequences are considered correct.
328
+
329
+ Args:
330
+ dna (str): DNA sequence.
331
+ protein (str): Protein sequence.
332
+ stop_symbol (str): Stop symbol.
333
+
334
+ Returns:
335
+ bool: True if the sequence is correct, False otherwise.
336
+ """
337
+ return (
338
+ len(dna) % 3 == 0 # Check if DNA length is divisible by 3
339
+ and dna[:3].upper() in START_CODONS # Check for initiator codon
340
+ and protein[-1]
341
+ == stop_symbol # Check if the last protein symbol is the stop symbol
342
+ and protein.count(stop_symbol) == 1 # Check if there is only one stop symbol
343
+ and len(set(dna))
344
+ == 4 # Check if DNA consists of 4 unique nucleotides (A, T, C, G)
345
+ )
346
+
347
+
348
+ def get_amino_acid_sequence(
349
+ dna: str,
350
+ stop_symbol: str = "_",
351
+ codon_table: int = 1,
352
+ return_correct_seq: bool = False,
353
+ ) -> Union[str, Tuple[str, bool]]:
354
+ """
355
+ Return the translated protein sequence given a DNA sequence and codon table.
356
+
357
+ Args:
358
+ dna (str): DNA sequence.
359
+ stop_symbol (str): Stop symbol.
360
+ codon_table (int): Codon table number.
361
+ return_correct_seq (bool): Whether to return if the sequence is correct.
362
+
363
+ Returns:
364
+ Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if
365
+ return_correct_seq is True, otherwise just the protein sequence.
366
+ """
367
+ dna_seq = Seq(dna).strip()
368
+
369
+ # Translate the DNA sequence to a protein sequence
370
+ protein_seq = str(
371
+ dna_seq.translate(
372
+ stop_symbol=stop_symbol, # Symbol to use for stop codons
373
+ to_stop=False, # Translate the entire sequence, including any stop codons
374
+ cds=False, # Do not assume the input is a coding sequence
375
+ table=codon_table, # Codon table to use for translation
376
+ )
377
+ ).strip()
378
+
379
+ return (
380
+ protein_seq
381
+ if not return_correct_seq
382
+ else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol))
383
+ )
384
+
385
+
386
+ def read_fasta_file(
387
+ input_file: str,
388
+ save_to_file: Optional[str] = None,
389
+ organism: str = "",
390
+ buffer_size: int = 50000,
391
+ ) -> pd.DataFrame:
392
+ """
393
+ Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame.
394
+ Optionally, save the DataFrame to a CSV file.
395
+
396
+ Args:
397
+ input_file (str): Path to the input FASTA file.
398
+ save_to_file (Optional[str]): Path to save the output DataFrame. If None,
399
+ data is only returned.
400
+ organism (str): Name of the organism. If empty, it will be extracted from
401
+ the FASTA description.
402
+ buffer_size (int): Number of records to process before writing to file.
403
+
404
+ Returns:
405
+ pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe
406
+ is True, else None.
407
+
408
+ Raises:
409
+ FileNotFoundError: If the input file does not exist.
410
+ """
411
+ if not os.path.exists(input_file):
412
+ raise FileNotFoundError(f"Input file not found: {input_file}")
413
+
414
+ buffer = []
415
+ columns = [
416
+ "dna",
417
+ "protein",
418
+ "correct_seq",
419
+ "organism",
420
+ "GeneID",
421
+ "description",
422
+ "tokenized",
423
+ ]
424
+
425
+ # Initialize DataFrame to store all data if return_dataframe is True
426
+ all_data = pd.DataFrame(columns=columns)
427
+
428
+ with open(input_file, "r") as fasta_file:
429
+ for record in tqdm(
430
+ SeqIO.parse(fasta_file, "fasta"),
431
+ desc=f"Processing {organism}",
432
+ unit=" Records",
433
+ ):
434
+ dna = str(record.seq).strip().upper() # Ensure uppercase DNA sequence
435
+
436
+ # Determine the organism from the record if not provided
437
+ current_organism = organism or find_pattern_in_fasta(
438
+ "organism", record.description
439
+ )
440
+ gene_id = find_pattern_in_fasta("GeneID", record.description)
441
+
442
+ # Get the appropriate codon table for the organism
443
+ codon_table = get_codon_table(current_organism)
444
+
445
+ # Translate DNA to protein sequence
446
+ protein, correct_seq = get_amino_acid_sequence(
447
+ dna,
448
+ stop_symbol=STOP_SYMBOL,
449
+ codon_table=codon_table,
450
+ return_correct_seq=True,
451
+ )
452
+ description = record.description.split("[", 1)[0].strip()
453
+ tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)
454
+
455
+ # Create a data row for the current sequence
456
+ data_row = {
457
+ "dna": dna,
458
+ "protein": protein,
459
+ "correct_seq": correct_seq,
460
+ "organism": current_organism,
461
+ "GeneID": gene_id,
462
+ "description": description,
463
+ "tokenized": tokenized,
464
+ }
465
+ buffer.append(data_row)
466
+
467
+ # Write buffer to CSV file when buffer size is reached
468
+ if save_to_file and len(buffer) >= buffer_size:
469
+ write_buffer_to_csv(buffer, save_to_file, columns)
470
+ buffer = []
471
+
472
+ all_data = pd.concat(
473
+ [all_data, pd.DataFrame([data_row])], ignore_index=True
474
+ )
475
+
476
+ # Write remaining buffer to CSV file
477
+ if save_to_file and buffer:
478
+ write_buffer_to_csv(buffer, save_to_file, columns)
479
+
480
+ return all_data
481
+
482
+
483
+ def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]):
484
+ """Helper function to write buffer to CSV file."""
485
+ buffer_df = pd.DataFrame(buffer, columns=columns)
486
+ buffer_df.to_csv(
487
+ output_path,
488
+ mode="a",
489
+ header=(not os.path.exists(output_path)),
490
+ index=True,
491
+ )
492
+
493
+
494
+ def download_codon_frequencies_from_kazusa(
495
+ taxonomy_id: Optional[int] = None,
496
+ organism: Optional[str] = None,
497
+ taxonomy_reference: Optional[str] = None,
498
+ return_original_format: bool = False,
499
+ ) -> AMINO2CODON_TYPE:
500
+ """
501
+ Return the codon table of the given taxonomy ID from the Kazusa Database.
502
+
503
+ Args:
504
+ taxonomy_id (Optional[int]): Taxonomy ID.
505
+ organism (Optional[str]): Name of the organism.
506
+ taxonomy_reference (Optional[str]): Taxonomy reference.
507
+ return_original_format (bool): Whether to return in the original format.
508
+
509
+ Returns:
510
+ AMINO2CODON_TYPE: Codon table.
511
+ """
512
+ if taxonomy_reference:
513
+ taxonomy_id = get_taxonomy_id(taxonomy_reference, organism=organism)
514
+
515
+ kazusa_amino2codon = pct.get_codons_table(table_name=taxonomy_id)
516
+
517
+ if return_original_format:
518
+ return kazusa_amino2codon
519
+
520
+ # Replace "*" with STOP_SYMBOL in the codon table
521
+ kazusa_amino2codon[STOP_SYMBOL] = kazusa_amino2codon.pop("*")
522
+
523
+ # Create amino2codon dictionary
524
+ amino2codon = {
525
+ aminoacid: (list(codon2freq.keys()), list(codon2freq.values()))
526
+ for aminoacid, codon2freq in kazusa_amino2codon.items()
527
+ }
528
+
529
+ return sort_amino2codon_skeleton(amino2codon)
530
+
531
+
532
+ def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE:
533
+ """
534
+ Return the empty skeleton of the amino2codon dictionary, needed for
535
+ get_codon_frequencies.
536
+
537
+ Args:
538
+ organism (str): Name of the organism.
539
+
540
+ Returns:
541
+ AMINO2CODON_TYPE: Empty amino2codon dictionary.
542
+ """
543
+ amino2codon = {}
544
+ possible_codons = [f"{i}{j}{k}" for i in "ACGT" for j in "ACGT" for k in "ACGT"]
545
+ possible_aminoacids = get_amino_acid_sequence(
546
+ dna="".join(possible_codons),
547
+ codon_table=get_codon_table(organism),
548
+ return_correct_seq=False,
549
+ )
550
+
551
+ # Initialize the amino2codon skeleton with all possible codons and set their
552
+ # frequencies to 0
553
+ for i, (codon, amino) in enumerate(zip(possible_codons, possible_aminoacids)):
554
+ if amino not in amino2codon:
555
+ amino2codon[amino] = ([], [])
556
+
557
+ amino2codon[amino][0].append(codon)
558
+ amino2codon[amino][1].append(0)
559
+
560
+ # Sort the dictionary and each list of codon frequency alphabetically
561
+ amino2codon = sort_amino2codon_skeleton(amino2codon)
562
+
563
+ return amino2codon
564
+
565
+
566
+ def get_codon_frequencies(
567
+ dna_sequences: List[str],
568
+ protein_sequences: Optional[List[str]] = None,
569
+ organism: Optional[str] = None,
570
+ ) -> AMINO2CODON_TYPE:
571
+ """
572
+ Return a dictionary mapping each codon to its respective frequency based on
573
+ the collection of DNA sequences and protein sequences.
574
+
575
+ Args:
576
+ dna_sequences (List[str]): List of DNA sequences.
577
+ protein_sequences (Optional[List[str]]): List of protein sequences.
578
+ organism (Optional[str]): Name of the organism.
579
+
580
+ Returns:
581
+ AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons
582
+ and frequencies.
583
+ """
584
+ if organism:
585
+ codon_table = get_codon_table(organism)
586
+ protein_sequences = [
587
+ get_amino_acid_sequence(
588
+ dna, codon_table=codon_table, return_correct_seq=False
589
+ )
590
+ for dna in dna_sequences
591
+ ]
592
+
593
+ amino2codon = build_amino2codon_skeleton(organism)
594
+
595
+ # Count the frequencies of each codon for each amino acid
596
+ for dna, protein in zip(dna_sequences, protein_sequences):
597
+ for i, amino in enumerate(protein):
598
+ codon = dna[i * 3 : (i + 1) * 3]
599
+ codon_loc = amino2codon[amino][0].index(codon)
600
+ amino2codon[amino][1][codon_loc] += 1
601
+
602
+ # Normalize codon frequencies per amino acid so they sum to 1
603
+ amino2codon = {
604
+ amino: (codons, [freq / (sum(frequencies) + 1e-100) for freq in frequencies])
605
+ for amino, (codons, frequencies) in amino2codon.items()
606
+ }
607
+
608
+ return amino2codon
609
+
610
+
611
+ def get_organism_to_codon_frequencies(
612
+ dataset: pd.DataFrame, organisms: List[str]
613
+ ) -> Dict[str, AMINO2CODON_TYPE]:
614
+ """
615
+ Return a dictionary mapping each organism to their codon frequency distribution.
616
+
617
+ Args:
618
+ dataset (pd.DataFrame): DataFrame containing DNA sequences.
619
+ organisms (List[str]): List of organisms.
620
+
621
+ Returns:
622
+ Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon
623
+ frequency distribution.
624
+ """
625
+ organism2frequencies = {}
626
+
627
+ # Calculate codon frequencies for each organism in the dataset
628
+ for organism in tqdm(
629
+ organisms, desc="Calculating Codon Frequencies: ", unit="Organism"
630
+ ):
631
+ organism_data = dataset.loc[dataset["organism"] == organism]
632
+
633
+ dna_sequences = organism_data["dna"].to_list()
634
+ protein_sequences = organism_data["protein"].to_list()
635
+
636
+ codon_frequencies = get_codon_frequencies(dna_sequences, protein_sequences)
637
+ organism2frequencies[organism] = codon_frequencies
638
+
639
+ return organism2frequencies
640
+
641
+
642
+ def get_codon_table(organism: str) -> int:
643
+ """
644
+ Return the appropriate NCBI codon table for a given organism.
645
+
646
+ Args:
647
+ organism (str): Name of the organism.
648
+
649
+ Returns:
650
+ int: Codon table number.
651
+ """
652
+ # Common codon table (Table 1) for many model organisms
653
+ if organism in [
654
+ "Arabidopsis thaliana",
655
+ "Caenorhabditis elegans",
656
+ "Chlamydomonas reinhardtii",
657
+ "Saccharomyces cerevisiae",
658
+ "Danio rerio",
659
+ "Drosophila melanogaster",
660
+ "Homo sapiens",
661
+ "Mus musculus",
662
+ "Nicotiana tabacum",
663
+ "Solanum tuberosum",
664
+ "Solanum lycopersicum",
665
+ "Oryza sativa",
666
+ "Glycine max",
667
+ "Zea mays",
668
+ ]:
669
+ codon_table = 1
670
+
671
+ # Chloroplast codon table (Table 11)
672
+ elif organism in [
673
+ "Chlamydomonas reinhardtii chloroplast",
674
+ "Nicotiana tabacum chloroplast",
675
+ ]:
676
+ codon_table = 11
677
+
678
+ # Default to Table 11 for other bacteria and archaea
679
+ else:
680
+ codon_table = 11
681
+
682
+ return codon_table
CodonTransformer/CodonEvaluation.py ADDED
@@ -0,0 +1,583 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: CodonEvaluation.py
3
+ ---------------------------
4
+ Includes functions to calculate various evaluation metrics along with helper
5
+ functions.
6
+ """
7
+
8
+ from typing import Dict, List, Tuple, Optional
9
+
10
+ import pandas as pd
11
+ from CAI import CAI, relative_adaptiveness
12
+ from tqdm import tqdm
13
+ import math
14
+ import numpy as np
15
+ from collections import Counter
16
+ from itertools import chain
17
+ from statistics import mean
18
+ import sys
19
+ import os
20
+ from io import StringIO
21
+
22
+
23
+ def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
24
+ """
25
+ Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences.
26
+
27
+ Args:
28
+ sequences (List[str]): List of DNA sequences.
29
+
30
+ Returns:
31
+ dict: The CSI weights.
32
+ """
33
+ return relative_adaptiveness(sequences=sequences)
34
+
35
+
36
+ def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
37
+ """
38
+ Calculate the Codon Similarity Index (CSI) for a DNA sequence.
39
+
40
+ Args:
41
+ dna (str): The DNA sequence.
42
+ weights (dict): The CSI weights from get_CSI_weights.
43
+
44
+ Returns:
45
+ float: The CSI value.
46
+ """
47
+ return CAI(dna, weights)
48
+
49
+
50
+ def get_organism_to_CSI_weights(
51
+ dataset: pd.DataFrame, organisms: List[str]
52
+ ) -> Dict[str, dict]:
53
+ """
54
+ Calculate the Codon Similarity Index (CSI) weights for a list of organisms.
55
+
56
+ Args:
57
+ dataset (pd.DataFrame): Dataset containing organism and DNA sequence info.
58
+ organisms (List[str]): List of organism names.
59
+
60
+ Returns:
61
+ Dict[str, dict]: A dictionary mapping each organism to its CSI weights.
62
+ """
63
+ organism2weights = {}
64
+
65
+ # Iterate through each organism to calculate its CSI weights
66
+ for organism in tqdm(organisms, desc="Calculating CSI Weights: ", unit="Organism"):
67
+ organism_data = dataset.loc[dataset["organism"] == organism]
68
+ sequences = organism_data["dna"].to_list()
69
+ weights = get_CSI_weights(sequences)
70
+ organism2weights[organism] = weights
71
+
72
+ return organism2weights
73
+
74
+
75
+ def get_GC_content(dna: str) -> float:
76
+ """
77
+ Calculate the GC content of a DNA sequence.
78
+
79
+ GC content is the percentage of nucleotides that are either G (guanine) or C (cytosine).
80
+ This metric is important for codon optimization as it affects expression levels and
81
+ synthesis efficiency in E. coli.
82
+
83
+ Args:
84
+ dna (str): The DNA sequence (uppercase or lowercase).
85
+
86
+ Returns:
87
+ float: The GC content as a percentage (0-100).
88
+
89
+ Example:
90
+ >>> get_GC_content("ATGCGATCG")
91
+ 55.56 # 5 GC nucleotides out of 9 total
92
+ """
93
+ dna = dna.upper()
94
+ if not dna:
95
+ return 0.0
96
+ return (dna.count("G") + dna.count("C")) / len(dna) * 100
97
+
98
+
99
+ def get_cfd(
100
+ dna: str,
101
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
102
+ threshold: float = 0.3,
103
+ ) -> float:
104
+ """
105
+ Calculate the codon frequency distribution (CFD) metric for a DNA sequence.
106
+
107
+ Args:
108
+ dna (str): The DNA sequence.
109
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
110
+ frequency distribution per amino acid.
111
+ threshold (float): Frequency threshold for counting rare codons.
112
+
113
+ Returns:
114
+ float: The CFD metric as a percentage.
115
+ """
116
+ # Get a dictionary mapping each codon to its normalized frequency
117
+ codon2frequency = {
118
+ codon: freq / max(frequencies)
119
+ for amino, (codons, frequencies) in codon_frequencies.items()
120
+ for codon, freq in zip(codons, frequencies)
121
+ }
122
+
123
+ cfd = 0
124
+
125
+ # Iterate through the DNA sequence in steps of 3 to process each codon
126
+ for i in range(0, len(dna), 3):
127
+ codon = dna[i : i + 3]
128
+ codon_frequency = codon2frequency[codon]
129
+
130
+ if codon_frequency < threshold:
131
+ cfd += 1
132
+
133
+ return cfd / (len(dna) / 3) * 100
134
+
135
+
136
+ def get_min_max_percentage(
137
+ dna: str,
138
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
139
+ window_size: int = 18,
140
+ ) -> List[float]:
141
+ """
142
+ Calculate the %MinMax metric for a DNA sequence.
143
+
144
+ Args:
145
+ dna (str): The DNA sequence.
146
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
147
+ frequency distribution per amino acid.
148
+ window_size (int): Size of the window to calculate %MinMax.
149
+
150
+ Returns:
151
+ List[float]: List of %MinMax values for the sequence.
152
+
153
+ Credit: https://github.com/chowington/minmax
154
+ """
155
+ # Get a dictionary mapping each codon to its respective amino acid
156
+ codon2amino = {
157
+ codon: amino
158
+ for amino, (codons, frequencies) in codon_frequencies.items()
159
+ for codon in codons
160
+ }
161
+
162
+ min_max_values = []
163
+ codons = [dna[i : i + 3] for i in range(0, len(dna), 3)] # Split DNA into codons
164
+
165
+ # Iterate through the DNA sequence using the specified window size
166
+ for i in range(len(codons) - window_size + 1):
167
+ codon_window = codons[i : i + window_size] # Codons in the current window
168
+
169
+ Actual = 0.0 # Average of the actual codon frequencies
170
+ Max = 0.0 # Average of the min codon frequencies
171
+ Min = 0.0 # Average of the max codon frequencies
172
+ Avg = 0.0 # Average of the averages of all frequencies for each amino acid
173
+
174
+ # Sum the frequencies for codons in the current window
175
+ for codon in codon_window:
176
+ aminoacid = codon2amino[codon]
177
+ frequencies = codon_frequencies[aminoacid][1]
178
+ codon_index = codon_frequencies[aminoacid][0].index(codon)
179
+ codon_frequency = codon_frequencies[aminoacid][1][codon_index]
180
+
181
+ Actual += codon_frequency
182
+ Max += max(frequencies)
183
+ Min += min(frequencies)
184
+ Avg += sum(frequencies) / len(frequencies)
185
+
186
+ # Divide by the window size to get the averages
187
+ Actual = Actual / window_size
188
+ Max = Max / window_size
189
+ Min = Min / window_size
190
+ Avg = Avg / window_size
191
+
192
+ # Calculate %MinMax
193
+ percentMax = ((Actual - Avg) / (Max - Avg)) * 100
194
+ percentMin = ((Avg - Actual) / (Avg - Min)) * 100
195
+
196
+ # Append the appropriate %MinMax value
197
+ if percentMax >= 0:
198
+ min_max_values.append(percentMax)
199
+ else:
200
+ min_max_values.append(-percentMin)
201
+
202
+ # Populate the last floor(window_size / 2) entries of min_max_values with None
203
+ for i in range(int(window_size / 2)):
204
+ min_max_values.append(None)
205
+
206
+ return min_max_values
207
+
208
+
209
+ def get_sequence_complexity(dna: str) -> float:
210
+ """
211
+ Calculate the sequence complexity score of a DNA sequence.
212
+
213
+ Args:
214
+ dna (str): The DNA sequence.
215
+
216
+ Returns:
217
+ float: The sequence complexity score.
218
+ """
219
+
220
+ def sum_up_to(x):
221
+ """Recursive function to calculate the sum of integers from 1 to x."""
222
+ if x <= 1:
223
+ return 1
224
+ else:
225
+ return x + sum_up_to(x - 1)
226
+
227
+ def f(x):
228
+ """Returns 4 if x is greater than or equal to 4, else returns x."""
229
+ if x >= 4:
230
+ return 4
231
+ elif x < 4:
232
+ return x
233
+
234
+ unique_subseq_length = []
235
+
236
+ # Calculate unique subsequences lengths
237
+ for i in range(1, len(dna) + 1):
238
+ unique_subseq = set()
239
+ for j in range(len(dna) - (i - 1)):
240
+ unique_subseq.add(dna[j : (j + i)])
241
+ unique_subseq_length.append(len(unique_subseq))
242
+
243
+ # Calculate complexity score
244
+ complexity_score = (
245
+ sum(unique_subseq_length) / (sum_up_to(len(dna) - 1) + f(len(dna)))
246
+ ) * 100
247
+
248
+ return complexity_score
249
+
250
+
251
+ def get_sequence_similarity(
252
+ original: str, predicted: str, truncate: bool = True, window_length: int = 1
253
+ ) -> float:
254
+ """
255
+ Calculate the sequence similarity between two sequences.
256
+
257
+ Args:
258
+ original (str): The original sequence.
259
+ predicted (str): The predicted sequence.
260
+ truncate (bool): If True, truncate the original sequence to match the length
261
+ of the predicted sequence.
262
+ window_length (int): Length of the window for comparison (1 for amino acids,
263
+ 3 for codons).
264
+
265
+ Returns:
266
+ float: The sequence similarity as a percentage.
267
+
268
+ Preconditions:
269
+ len(predicted) <= len(original).
270
+ """
271
+ if not truncate and len(original) != len(predicted):
272
+ raise ValueError(
273
+ "Set truncate to True if the length of sequences do not match."
274
+ )
275
+
276
+ identity = 0.0
277
+ original = original.strip()
278
+ predicted = predicted.strip()
279
+
280
+ if truncate:
281
+ original = original[: len(predicted)]
282
+
283
+ if window_length == 1:
284
+ # Simple comparison for amino acid
285
+ for i in range(len(predicted)):
286
+ if original[i] == predicted[i]:
287
+ identity += 1
288
+ else:
289
+ # Comparison for substrings based on window_length
290
+ for i in range(0, len(original) - window_length + 1, window_length):
291
+ if original[i : i + window_length] == predicted[i : i + window_length]:
292
+ identity += 1
293
+
294
+ return (identity / (len(predicted) / window_length)) * 100
295
+
296
+
297
+ def scan_for_restriction_sites(seq: str, sites: List[str] = ['GAATTC', 'GGATCC', 'AAGCTT']) -> int:
298
+ """
299
+ Scans for a list of restriction enzyme sites in a DNA sequence.
300
+ """
301
+ return sum(seq.upper().count(site.upper()) for site in sites)
302
+
303
+
304
+ def count_negative_cis_elements(seq: str, motifs: List[str] = ['TATAAT', 'TTGACA', 'AGCTAGT']) -> int:
305
+ """
306
+ Counts occurrences of negative cis-regulatory elements in a DNA sequence.
307
+ """
308
+ return sum(seq.upper().count(m.upper()) for m in motifs)
309
+
310
+
311
+ def calculate_homopolymer_runs(seq: str, max_len: int = 8) -> int:
312
+ """
313
+ Calculates the number of homopolymer runs longer than a given length.
314
+ """
315
+ import re
316
+ min_len = max_len + 1
317
+ return len(re.findall(r'(A{%d,}|T{%d,}|G{%d,}|C{%d,})' % (min_len, min_len, min_len, min_len), seq.upper()))
318
+
319
+
320
+ def get_min_max_profile(
321
+ dna: str,
322
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
323
+ window_size: int = 18,
324
+ ) -> List[float]:
325
+ """
326
+ Calculate the %MinMax profile for a DNA sequence. This is a list of
327
+ %MinMax values for sliding windows across the sequence.
328
+
329
+ Args:
330
+ dna (str): The DNA sequence.
331
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
332
+ frequency distribution per amino acid.
333
+ window_size (int): Size of the window to calculate %MinMax.
334
+
335
+ Returns:
336
+ List[float]: List of %MinMax values for the sequence.
337
+ """
338
+ return get_min_max_percentage(dna, codon_frequencies, window_size)
339
+
340
+
341
+ def calculate_dtw_distance(profile1: List[float], profile2: List[float]) -> float:
342
+ """
343
+ Calculates the Dynamic Time Warping (DTW) distance between two profiles.
344
+
345
+ Args:
346
+ profile1 (List[float]): The first profile (e.g., %MinMax of generated sequence).
347
+ profile2 (List[float]): The second profile (e.g., %MinMax of natural sequence).
348
+
349
+ Returns:
350
+ float: The DTW distance between the two profiles.
351
+ """
352
+ from dtw import dtw
353
+ import numpy as np
354
+
355
+ # Ensure profiles are numpy arrays and handle potential None and NaN values
356
+ p1 = np.array([v for v in profile1 if v is not None and not np.isnan(v)]).reshape(
357
+ -1, 1
358
+ )
359
+ p2 = np.array([v for v in profile2 if v is not None and not np.isnan(v)]).reshape(
360
+ -1, 1
361
+ )
362
+
363
+ if len(p1) == 0 or len(p2) == 0:
364
+ return np.inf # Return infinity if one of the profiles is empty
365
+
366
+ alignment = dtw(p1, p2, keep_internals=True)
367
+ return alignment.distance # type: ignore
368
+
369
+
370
+ def get_ecoli_tai_weights():
371
+ """
372
+ Returns a dictionary of tAI weights for E. coli based on tRNA gene copy numbers.
373
+ These weights are pre-calculated based on the relative adaptiveness of each codon.
374
+ """
375
+ codons = [
376
+ "TTT", "TTC", "TTA", "TTG", "TCT", "TCC", "TCA", "TCG", "TAT", "TAC",
377
+ "TGT", "TGC", "TGG", "CTT", "CTC", "CTA", "CTG", "CCT", "CCC", "CCA",
378
+ "CCG", "CAT", "CAC", "CAA", "CAG", "CGT", "CGC", "CGA", "CGG", "ATT",
379
+ "ATC", "ATA", "ACT", "ACC", "ACA", "ACG", "AAT", "AAC", "AAA", "AAG",
380
+ "AGT", "AGC", "AGA", "AGG", "GTT", "GTC", "GTA", "GTG", "GCT", "GCC",
381
+ "GCA", "GCG", "GAT", "GAC", "GAA", "GAG", "GGT", "GGC", "GGA", "GGG"
382
+ ]
383
+ weights = [
384
+ 0.1966667, 0.3333333, 0.1666667, 0.2200000, 0.1966667, 0.3333333,
385
+ 0.1666667, 0.2200000, 0.2950000, 0.5000000, 0.09833333, 0.1666667,
386
+ 0.2200000, 0.09833333, 0.1666667, 0.1666667, 0.7200000, 0.09833333,
387
+ 0.1666667, 0.1666667, 0.2200000, 0.09833333, 0.1666667, 0.3333333,
388
+ 0.4400000, 0.6666667, 0.4800000, 0.00006666667, 0.1666667, 0.2950000,
389
+ 0.5000000, 0.01833333, 0.1966667, 0.3333333, 0.1666667, 0.3866667,
390
+ 0.3933333, 0.6666667, 1.0000000, 0.3200000, 0.09833333, 0.1666667,
391
+ 0.1666667, 0.2200000, 0.1966667, 0.3333333, 0.8333333, 0.2666667,
392
+ 0.1966667, 0.3333333, 0.5000000, 0.1600000, 0.2950000, 0.5000000,
393
+ 0.6666667, 0.2133333, 0.3933333, 0.6666667, 0.1666667, 0.2200000
394
+ ]
395
+ return dict(zip(codons, weights))
396
+
397
+
398
+ def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
399
+ """
400
+ Calculates the tRNA Adaptation Index (tAI) for a given DNA sequence.
401
+
402
+ Args:
403
+ sequence (str): The DNA sequence to analyze.
404
+ tai_weights (Dict[str, float]): A dictionary of tAI weights for each codon.
405
+
406
+ Returns:
407
+ float: The tAI value for the sequence.
408
+ """
409
+ from scipy.stats.mstats import gmean
410
+
411
+ codons = [sequence[i:i+3] for i in range(0, len(sequence), 3)]
412
+
413
+ # Filter out stop codons and codons not in weights
414
+ weights = [tai_weights[codon] for codon in codons if codon in tai_weights and tai_weights[codon] > 0]
415
+
416
+ if not weights:
417
+ return 0.0
418
+
419
+ return gmean(weights)
420
+
421
+
422
+ def calculate_ENC(sequence: str) -> float:
423
+ """
424
+ Calculate the Effective Number of Codons (ENC) for a DNA sequence.
425
+ Uses the codonbias library implementation based on Wright (1990).
426
+
427
+ Args:
428
+ sequence (str): The DNA sequence.
429
+
430
+ Returns:
431
+ float: The ENC value for the sequence.
432
+ """
433
+ try:
434
+ from codonbias.scores import EffectiveNumberOfCodons
435
+
436
+ # Initialize ENC calculator
437
+ enc_calculator = EffectiveNumberOfCodons(
438
+ k_mer=1, # Standard codon analysis
439
+ bg_correction=True, # Use background correction
440
+ robust=True, # Use robust calculation
441
+ genetic_code=1 # Standard genetic code
442
+ )
443
+
444
+ # Calculate ENC for the sequence
445
+ enc_value = enc_calculator.get_score(sequence)
446
+
447
+ return float(enc_value)
448
+
449
+ except ImportError:
450
+ raise ImportError("codonbias library is required for ENC calculation. Install with: pip install codonbias")
451
+ except Exception as e:
452
+ # Fallback to a simple ENC approximation if library fails
453
+ print(f"Warning: ENC calculation failed with error: {e}. Using approximation.")
454
+ return 45.0 # Typical E. coli ENC value as fallback
455
+
456
+
457
+ def calculate_CPB(sequence: str, reference_sequences: Optional[List[str]] = None) -> float:
458
+ """
459
+ Calculate the Codon Pair Bias (CPB) for a DNA sequence.
460
+ Uses the codonbias library implementation based on Coleman et al. (2008).
461
+
462
+ Args:
463
+ sequence (str): The DNA sequence.
464
+ reference_sequences (List[str]): Reference sequences for calculating expected values.
465
+ If None, uses a default E. coli reference.
466
+
467
+ Returns:
468
+ float: The CPB value for the sequence.
469
+ """
470
+ try:
471
+ from codonbias.scores import CodonPairBias
472
+
473
+ # Use provided reference sequences or default
474
+ if reference_sequences is None:
475
+ # Use the input sequence as reference if none provided
476
+ reference_sequences = [sequence]
477
+
478
+ # Initialize CPB calculator with reference sequences
479
+ cpb_calculator = CodonPairBias(
480
+ ref_seq=reference_sequences,
481
+ k_mer=2, # Codon pairs
482
+ genetic_code=1, # Standard genetic code
483
+ ignore_stop=True, # Ignore stop codons
484
+ pseudocount=1 # Pseudocount for unseen pairs
485
+ )
486
+
487
+ # Calculate CPB for the sequence
488
+ cpb_value = cpb_calculator.get_score(sequence)
489
+
490
+ return float(cpb_value)
491
+
492
+ except ImportError:
493
+ raise ImportError("codonbias library is required for CPB calculation. Install with: pip install codonbias")
494
+ except Exception as e:
495
+ # Fallback calculation if library fails
496
+ print(f"Warning: CPB calculation failed with error: {e}. Using approximation.")
497
+ return 0.0 # Neutral CPB as fallback
498
+
499
+
500
+ def calculate_SCUO(sequence: str) -> float:
501
+ """
502
+ Calculate the Synonymous Codon Usage Order (SCUO) for a DNA sequence.
503
+ Uses the GCUA library implementation based on information theory.
504
+
505
+ Args:
506
+ sequence (str): The DNA sequence.
507
+
508
+ Returns:
509
+ float: The SCUO value (0-1, where 1 indicates maximum bias).
510
+ """
511
+ # Self-contained SCUO implementation (no external GCUA dependency).
512
+ # Based on Wan et al., 2004 information-theoretic definition.
513
+
514
+ from math import log2 # local import to avoid global cost
515
+ try:
516
+ # Build standard genetic code mapping using built-in tables (Biopython optional).
517
+ # Fall back to hard-coded table if Biopython absent.
518
+ try:
519
+ from Bio.Data import CodonTable # type: ignore
520
+ codon_to_aa = CodonTable.unambiguous_dna_by_id[1].forward_table
521
+ except Exception:
522
+ codon_to_aa = {
523
+ # Partial table sufficient for SCUO calculation; stop codons omitted.
524
+ 'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
525
+ 'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
526
+ 'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
527
+ 'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
528
+ 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
529
+ 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
530
+ 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
531
+ 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
532
+ 'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
533
+ 'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
534
+ 'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
535
+ 'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
536
+ 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
537
+ 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
538
+ 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
539
+ 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
540
+ }
541
+
542
+ # Group codons by amino acid (exclude stops)
543
+ aa_to_codons = {}
544
+ for codon, aa in codon_to_aa.items():
545
+ aa_to_codons.setdefault(aa, []).append(codon)
546
+
547
+ # Count codon occurrences in input sequence
548
+ seq = sequence.upper().replace('U', 'T')
549
+ codon_counts = {}
550
+ for i in range(0, len(seq) - len(seq) % 3, 3):
551
+ codon = seq[i:i+3]
552
+ if codon in codon_to_aa:
553
+ codon_counts[codon] = codon_counts.get(codon, 0) + 1
554
+
555
+ total_codons = sum(codon_counts.values())
556
+ if total_codons == 0:
557
+ return 0.0
558
+
559
+ scuo_sum = 0.0
560
+
561
+ for aa, codons in aa_to_codons.items():
562
+ n_codons = len(codons)
563
+ if n_codons == 1:
564
+ continue # SCUO undefined for Met/Trp
565
+
566
+ counts = [codon_counts.get(c, 0) for c in codons]
567
+ total_aa = sum(counts)
568
+ if total_aa == 0:
569
+ continue
570
+
571
+ probs = [c / total_aa for c in counts if c]
572
+ H_obs = -sum(p * log2(p) for p in probs)
573
+ H_max = log2(n_codons)
574
+ O_i = (H_max - H_obs) / H_max if H_max else 0.0
575
+ F_i = total_aa / total_codons
576
+ scuo_sum += F_i * O_i
577
+
578
+ return scuo_sum
579
+
580
+ except Exception as exc:
581
+ print(f"Warning: internal SCUO computation failed ({exc}). Returning 0.5.")
582
+ return 0.5
583
+
CodonTransformer/CodonJupyter.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: CodonJupyter.py
3
+ ---------------------
4
+ Includes Jupyter-specific functions for displaying interactive widgets.
5
+ """
6
+
7
+ from typing import Dict, List, Tuple
8
+
9
+ import ipywidgets as widgets
10
+ from IPython.display import HTML, display
11
+
12
+ from CodonTransformer.CodonUtils import (
13
+ COMMON_ORGANISMS,
14
+ ID2ORGANISM,
15
+ ORGANISM2ID,
16
+ DNASequencePrediction,
17
+ )
18
+
19
+
20
+ class UserContainer:
21
+ """
22
+ A container class to store user inputs for organism and protein sequence.
23
+ Attributes:
24
+ organism (int): The selected organism id.
25
+ protein (str): The input protein sequence.
26
+ """
27
+
28
+ def __init__(self) -> None:
29
+ self.organism: int = -1
30
+ self.protein: str = ""
31
+
32
+
33
+ def create_styled_options(
34
+ organisms: list, organism2id: Dict[str, int], is_fine_tuned: bool = False
35
+ ) -> list:
36
+ """
37
+ Create styled options for the dropdown widget.
38
+
39
+ Args:
40
+ organisms (list): List of organism names.
41
+ organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs.
42
+ is_fine_tuned (bool): Whether these are fine-tuned organisms.
43
+
44
+ Returns:
45
+ list: Styled options for the dropdown widget.
46
+ """
47
+ styled_options = []
48
+ for organism in organisms:
49
+ organism_id = organism2id[organism]
50
+ if is_fine_tuned:
51
+ if organism_id < 10:
52
+ styled_options.append(f"\u200b{organism_id:>6}. {organism}")
53
+ elif organism_id < 100:
54
+ styled_options.append(f"\u200b{organism_id:>5}. {organism}")
55
+ else:
56
+ styled_options.append(f"\u200b{organism_id:>4}. {organism}")
57
+ else:
58
+ if organism_id < 10:
59
+ styled_options.append(f"{organism_id:>6}. {organism}")
60
+ elif organism_id < 100:
61
+ styled_options.append(f"{organism_id:>5}. {organism}")
62
+ else:
63
+ styled_options.append(f"{organism_id:>4}. {organism}")
64
+ return styled_options
65
+
66
+
67
+ def create_dropdown_options(organism2id: Dict[str, int]) -> list:
68
+ """
69
+ Create the full list of dropdown options, including section headers.
70
+
71
+ Args:
72
+ organism2id (Dict[str, int]): Dictionary mapping organism names to their IDs.
73
+
74
+ Returns:
75
+ list: Full list of dropdown options.
76
+ """
77
+ fine_tuned_organisms = sorted(
78
+ [org for org in organism2id.keys() if org in COMMON_ORGANISMS]
79
+ )
80
+ all_organisms = sorted(organism2id.keys())
81
+
82
+ fine_tuned_options = create_styled_options(
83
+ fine_tuned_organisms, organism2id, is_fine_tuned=True
84
+ )
85
+ all_organisms_options = create_styled_options(
86
+ all_organisms, organism2id, is_fine_tuned=False
87
+ )
88
+
89
+ return (
90
+ [""]
91
+ + ["Selected Organisms"]
92
+ + fine_tuned_options
93
+ + [""]
94
+ + ["All Organisms"]
95
+ + all_organisms_options
96
+ )
97
+
98
+
99
+ def create_organism_dropdown(container: UserContainer) -> widgets.Dropdown:
100
+ """
101
+ Create and configure the organism dropdown widget.
102
+
103
+ Args:
104
+ container (UserContainer): Container to store the selected organism.
105
+
106
+ Returns:
107
+ widgets.Dropdown: Configured dropdown widget.
108
+ """
109
+ dropdown = widgets.Dropdown(
110
+ options=create_dropdown_options(ORGANISM2ID),
111
+ description="",
112
+ layout=widgets.Layout(width="40%", margin="0 0 10px 0"),
113
+ style={"description_width": "initial"},
114
+ )
115
+
116
+ def show_organism(change: Dict[str, str]) -> None:
117
+ """
118
+ Update the container with the selected organism and print to terminal.
119
+
120
+ Args:
121
+ change (Dict[str, str]): Information about the change in dropdown value.
122
+ """
123
+ dropdown_choice = change["new"]
124
+ if dropdown_choice and dropdown_choice not in [
125
+ "Selected Organisms",
126
+ "All Organisms",
127
+ ]:
128
+ organism = "".join(filter(str.isdigit, dropdown_choice))
129
+ organism_id = ID2ORGANISM[int(organism)]
130
+ container.organism = organism_id
131
+ else:
132
+ container.organism = None
133
+
134
+ dropdown.observe(show_organism, names="value")
135
+ return dropdown
136
+
137
+
138
+ def get_dropdown_style() -> str:
139
+ """
140
+ Return the custom CSS style for the dropdown widget.
141
+
142
+ Returns:
143
+ str: CSS style string.
144
+ """
145
+ return """
146
+ <style>
147
+ .widget-dropdown > select {
148
+ font-size: 16px;
149
+ font-weight: normal;
150
+ background-color: #f0f0f0;
151
+ border-radius: 5px;
152
+ padding: 5px;
153
+ }
154
+ .widget-label {
155
+ font-size: 18px;
156
+ font-weight: bold;
157
+ }
158
+ .custom-container {
159
+ display: flex;
160
+ flex-direction: column;
161
+ align-items: flex-start;
162
+ }
163
+ .widget-dropdown option[value^="\u200b"] {
164
+ font-family: sans-serif;
165
+ font-weight: bold;
166
+ font-size: 18px;
167
+ padding: 510px;
168
+ }
169
+ .widget-dropdown option[value*="Selected Organisms"],
170
+ .widget-dropdown option[value*="All Organisms"] {
171
+ text-align: center;
172
+ font-family: Arial, sans-serif;
173
+ font-weight: bold;
174
+ font-size: 20px;
175
+ color: #6900A1;
176
+ background-color: #00D8A1;
177
+ }
178
+ </style>
179
+ """
180
+
181
+
182
+ def display_organism_dropdown(container: UserContainer) -> None:
183
+ """
184
+ Display the organism dropdown widget and apply custom styles.
185
+
186
+ Args:
187
+ container (UserContainer): Container to store the selected organism.
188
+ """
189
+ dropdown = create_organism_dropdown(container)
190
+ header = widgets.HTML(
191
+ '<b style="font-size:20px;">Select Organism:</b>'
192
+ '<div style="height:10px;"></div>'
193
+ )
194
+ container_widget = widgets.VBox(
195
+ [header, dropdown],
196
+ layout=widgets.Layout(padding="12px 0 12px 25px"),
197
+ )
198
+ display(container_widget)
199
+ display(HTML(get_dropdown_style()))
200
+
201
+
202
+ def display_protein_input(container: UserContainer) -> None:
203
+ """
204
+ Display a widget for entering a protein sequence and save it to the container.
205
+
206
+ Args:
207
+ container (UserContainer): A container to store the entered protein sequence.
208
+ """
209
+ protein_input = widgets.Textarea(
210
+ value="",
211
+ placeholder="Enter here...",
212
+ description="",
213
+ layout=widgets.Layout(width="100%", height="100px", margin="0 0 10px 0"),
214
+ style={"description_width": "initial"},
215
+ )
216
+
217
+ # Custom CSS for the input widget
218
+ input_style = """
219
+ <style>
220
+ .widget-textarea > textarea {
221
+ font-size: 12px;
222
+ font-family: Arial, sans-serif;
223
+ font-weight: normal;
224
+ background-color: #f0f0f0;
225
+ border-radius: 5px;
226
+ padding: 10px;
227
+ }
228
+ .widget-label {
229
+ font-size: 18px;
230
+ font-weight: bold;
231
+ }
232
+ .custom-container {
233
+ display: flex;
234
+ flex-direction: column;
235
+ align-items: flex-start;
236
+ }
237
+ </style>
238
+ """
239
+
240
+ # Function to save the input protein sequence to the container
241
+ def save_protein(change: Dict[str, str]) -> None:
242
+ """
243
+ Save the input protein sequence to the container.
244
+
245
+ Args:
246
+ change (Dict[str, str]): A dictionary containing information about
247
+ the change in textarea value.
248
+ """
249
+ container.protein = (
250
+ change["new"]
251
+ .upper()
252
+ .strip()
253
+ .replace("\n", "")
254
+ .replace(" ", "")
255
+ .replace("\t", "")
256
+ )
257
+
258
+ # Attach the function to the input widget
259
+ protein_input.observe(save_protein, names="value")
260
+
261
+ # Display the input widget
262
+ header = widgets.HTML(
263
+ '<b style="font-size:20px;">Enter Protein Sequence:</b>'
264
+ '<div style="height:18px;"></div>'
265
+ )
266
+ container_widget = widgets.VBox(
267
+ [header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px")
268
+ )
269
+
270
+ display(container_widget)
271
+ display(widgets.HTML(input_style))
272
+
273
+
274
+ def format_model_output(output: DNASequencePrediction) -> str:
275
+ """
276
+ Format DNA sequence prediction output in an appealing and easy-to-read manner.
277
+
278
+ This function takes the prediction output and formats it into
279
+ a structured string with clear section headers and separators.
280
+
281
+ Args:
282
+ output (DNASequencePrediction): Object containing the prediction output.
283
+ Expected attributes:
284
+ - organism (str): The organism name.
285
+ - protein (str): The input protein sequence.
286
+ - processed_input (str): The processed input sequence.
287
+ - predicted_dna (str): The predicted DNA sequence.
288
+
289
+ Returns:
290
+ str: A formatted string containing the organized output.
291
+ """
292
+
293
+ def format_section(title: str, content: str) -> str:
294
+ """Helper function to format individual sections."""
295
+ separator = "-" * 29
296
+ title_line = f"| {title.center(25)} |"
297
+ return f"{separator}\n{title_line}\n{separator}\n{content}\n\n"
298
+
299
+ sections: List[Tuple[str, str]] = [
300
+ ("Organism", output.organism),
301
+ ("Input Protein", output.protein),
302
+ ("Processed Input", output.processed_input),
303
+ ("Predicted DNA", output.predicted_dna),
304
+ ]
305
+
306
+ formatted_output = ""
307
+ for title, content in sections:
308
+ formatted_output += format_section(title, content)
309
+
310
+ # Remove the last newline to avoid extra space at the end
311
+ return formatted_output.rstrip()
CodonTransformer/CodonPostProcessing.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: CodonPostProcessing.py
3
+ ---------------------------
4
+ Post-processing utilities for codon optimization using DNAChisel.
5
+ This module provides sequence polishing capabilities to fix restriction sites,
6
+ homopolymers, and other constraints while preserving CAI and GC content.
7
+ """
8
+
9
+ import warnings
10
+ import numpy as np
11
+
12
+ try:
13
+ from dnachisel import (
14
+ DnaOptimizationProblem,
15
+ AvoidPattern,
16
+ EnforceGCContent,
17
+ EnforceTranslation,
18
+ CodonOptimize,
19
+ )
20
+ DNACHISEL_AVAILABLE = True
21
+ except ImportError:
22
+ DNACHISEL_AVAILABLE = False
23
+ # This warning will be shown when the module is first imported.
24
+ warnings.warn(
25
+ "DNAChisel is not installed. Post-processing features will be disabled."
26
+ )
27
+
28
+ def polish_sequence_with_dnachisel(
29
+ dna_sequence: str,
30
+ protein_sequence: str,
31
+ gc_bounds: tuple = (45.0, 55.0),
32
+ cai_species: str = "e_coli",
33
+ avoid_homopolymers_length: int = 6,
34
+ enzymes_to_avoid: list = None
35
+ ):
36
+ """
37
+ Polishes a DNA sequence using DNAChisel to meet lab synthesis constraints.
38
+ """
39
+ if not DNACHISEL_AVAILABLE:
40
+ warnings.warn("DNAChisel not available, skipping post-processing.")
41
+ return dna_sequence
42
+
43
+ if enzymes_to_avoid is None:
44
+ # Common cloning enzymes
45
+ enzymes_to_avoid = ["EcoRI", "XbaI", "SpeI", "PstI", "NotI"]
46
+
47
+ try:
48
+ # Start with the basic, essential constraints
49
+ constraints = [
50
+ EnforceTranslation(translation=protein_sequence),
51
+ EnforceGCContent(mini=gc_bounds[0] / 100.0, maxi=gc_bounds[1] / 100.0),
52
+ ]
53
+
54
+ # Add enzyme avoidance constraints safely
55
+ for enzyme in enzymes_to_avoid:
56
+ try:
57
+ # This is the modern way to avoid enzyme sites
58
+ constraints.append(AvoidPattern.from_enzyme_name(enzyme))
59
+ except Exception:
60
+ warnings.warn(f"Could not find enzyme '{enzyme}' in DNAChisel library.")
61
+
62
+ # Add homopolymer avoidance constraints
63
+ for base in "ATGC":
64
+ constraints.append(AvoidPattern(base * avoid_homopolymers_length))
65
+
66
+ # Define the optimization problem
67
+ problem = DnaOptimizationProblem(
68
+ sequence=dna_sequence,
69
+ constraints=constraints,
70
+ objectives=[CodonOptimize(species=cai_species, method="match_codon_usage")]
71
+ )
72
+
73
+ # Solve the problem
74
+ problem.resolve_constraints()
75
+ problem.optimize()
76
+
77
+ # Return the polished sequence
78
+ return problem.sequence
79
+
80
+ except Exception as e:
81
+ warnings.warn(f"DNAChisel post-processing failed with an error: {e}")
82
+ # Return the original sequence if polishing fails
83
+ return dna_sequence
CodonTransformer/CodonPrediction.py ADDED
@@ -0,0 +1,1372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: CodonPrediction.py
3
+ ---------------------------
4
+ Includes functions to tokenize input, load models, infer predicted dna sequences and
5
+ helper functions related to processing data for passing to the model.
6
+ """
7
+
8
+ import warnings
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+ import heapq
11
+ from dataclasses import dataclass
12
+
13
+ import numpy as np
14
+ import onnxruntime as rt
15
+ import torch
16
+ import transformers
17
+ from transformers import (
18
+ AutoTokenizer,
19
+ BatchEncoding,
20
+ BigBirdConfig,
21
+ BigBirdForMaskedLM,
22
+ PreTrainedTokenizerFast,
23
+ )
24
+
25
+ from CodonTransformer.CodonData import get_merged_seq
26
+ from CodonTransformer.CodonUtils import (
27
+ AMINO_ACID_TO_INDEX,
28
+ INDEX2TOKEN,
29
+ NUM_ORGANISMS,
30
+ ORGANISM2ID,
31
+ TOKEN2INDEX,
32
+ DNASequencePrediction,
33
+ GC_COUNTS_PER_TOKEN,
34
+ CODON_GC_CONTENT,
35
+ AA_MIN_GC,
36
+ AA_MAX_GC,
37
+ )
38
+
39
+
40
+ def predict_dna_sequence(
41
+ protein: str,
42
+ organism: Union[int, str],
43
+ device: torch.device,
44
+ tokenizer: Union[str, PreTrainedTokenizerFast] = None,
45
+ model: Union[str, torch.nn.Module] = None,
46
+ attention_type: str = "original_full",
47
+ deterministic: bool = True,
48
+ temperature: float = 0.2,
49
+ top_p: float = 0.95,
50
+ num_sequences: int = 1,
51
+ match_protein: bool = False,
52
+ use_constrained_search: bool = False,
53
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
54
+ beam_size: int = 5,
55
+ length_penalty: float = 1.0,
56
+ diversity_penalty: float = 0.0,
57
+ ) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
58
+ """
59
+ Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
60
+
61
+ This function takes a protein sequence and an organism (as ID or name) as input
62
+ and returns the predicted DNA sequence(s) using the CodonTransformer model. It can use
63
+ either provided tokenizer and model objects or load them from specified paths.
64
+
65
+ Args:
66
+ protein (str): The input protein sequence for which to predict the DNA sequence.
67
+ organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
68
+ "Escherichia coli general"). If a string is provided, it will be converted
69
+ to the corresponding ID using ORGANISM2ID.
70
+ device (torch.device): The device (CPU or GPU) to run the model on.
71
+ tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file
72
+ path to load the tokenizer from, a pre-loaded tokenizer object, or None. If
73
+ None, it will be loaded from HuggingFace. Defaults to None.
74
+ model (Union[str, torch.nn.Module, None], optional): Either a file path to load
75
+ the model from, a pre-loaded model object, or None. If None, it will be
76
+ loaded from HuggingFace. Defaults to None.
77
+ attention_type (str, optional): The type of attention mechanism to use in the
78
+ model. Can be either 'block_sparse' or 'original_full'. Defaults to
79
+ "original_full".
80
+ deterministic (bool, optional): Whether to use deterministic decoding (most
81
+ likely tokens). If False, samples tokens according to their probabilities
82
+ adjusted by the temperature. Defaults to True.
83
+ temperature (float, optional): A value controlling the randomness of predictions
84
+ during non-deterministic decoding. Lower values (e.g., 0.2) make the model
85
+ more conservative, while higher values (e.g., 0.8) increase randomness.
86
+ Using high temperatures may result in prediction of DNA sequences that
87
+ do not translate to the input protein.
88
+ Recommended values are:
89
+ - Low randomness: 0.2
90
+ - Medium randomness: 0.5
91
+ - High randomness: 0.8
92
+ The temperature must be a positive float. Defaults to 0.2.
93
+ top_p (float, optional): The cumulative probability threshold for nucleus sampling.
94
+ Tokens with cumulative probability up to top_p are considered for sampling.
95
+ This parameter helps balance diversity and coherence in the predicted DNA sequences.
96
+ The value must be a float between 0 and 1. Defaults to 0.95.
97
+ num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
98
+ when deterministic is False. Defaults to 1.
99
+ match_protein (bool, optional): Ensures the predicted DNA sequence is translated
100
+ to the input protein sequence by sampling from only the respective codons of
101
+ given amino acids. Defaults to False.
102
+ use_constrained_search (bool, optional): Whether to use constrained beam search
103
+ with GC content bounds. Defaults to False.
104
+ gc_bounds (Tuple[float, float], optional): GC content bounds (min, max) for
105
+ constrained search. Defaults to (0.30, 0.70).
106
+ beam_size (int, optional): Beam size for constrained search. Defaults to 5.
107
+ length_penalty (float, optional): Length penalty for beam search scoring.
108
+ Defaults to 1.0.
109
+ diversity_penalty (float, optional): Diversity penalty to reduce repetitive
110
+ sequences. Defaults to 0.0.
111
+
112
+ Returns:
113
+ Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
114
+ containing the prediction results:
115
+ - organism (str): Name of the organism used for prediction.
116
+ - protein (str): Input protein sequence for which DNA sequence is predicted.
117
+ - processed_input (str): Processed input sequence (merged protein and DNA).
118
+ - predicted_dna (str): Predicted DNA sequence.
119
+
120
+ Raises:
121
+ ValueError: If the protein sequence is empty, if the organism is invalid,
122
+ if the temperature is not a positive float, if top_p is not between 0 and 1,
123
+ or if num_sequences is less than 1 or used with deterministic mode.
124
+
125
+ Note:
126
+ This function uses ORGANISM2ID, INDEX2TOKEN, and AMINO_ACID_TO_INDEX dictionaries
127
+ imported from CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
128
+ corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to
129
+ respective codons. AMINO_ACID_TO_INDEX maps each amino acid and stop symbol to indices
130
+ of codon tokens that translate to it.
131
+
132
+ Example:
133
+ >>> import torch
134
+ >>> from transformers import AutoTokenizer, BigBirdForMaskedLM
135
+ >>> from CodonTransformer.CodonPrediction import predict_dna_sequence
136
+ >>> from CodonTransformer.CodonJupyter import format_model_output
137
+ >>>
138
+ >>> # Set up device
139
+ >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
140
+ >>>
141
+ >>> # Load tokenizer and model
142
+ >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
143
+ >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
144
+ >>> model = model.to(device)
145
+ >>>
146
+ >>> # Define protein sequence and organism
147
+ >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"
148
+ >>> organism = "Escherichia coli general"
149
+ >>>
150
+ >>> # Predict DNA sequence with deterministic decoding (single sequence)
151
+ >>> output = predict_dna_sequence(
152
+ ... protein=protein,
153
+ ... organism=organism,
154
+ ... device=device,
155
+ ... tokenizer=tokenizer,
156
+ ... model=model,
157
+ ... attention_type="original_full",
158
+ ... deterministic=True
159
+ ... )
160
+ >>>
161
+ >>> # Predict DNA sequence with constrained beam search
162
+ >>> output_constrained = predict_dna_sequence(
163
+ ... protein=protein,
164
+ ... organism=organism,
165
+ ... device=device,
166
+ ... tokenizer=tokenizer,
167
+ ... model=model,
168
+ ... use_constrained_search=True,
169
+ ... gc_bounds=(0.40, 0.60),
170
+ ... beam_size=10,
171
+ ... length_penalty=1.2,
172
+ ... diversity_penalty=0.1
173
+ ... )
174
+ >>>
175
+ >>> # Predict multiple DNA sequences with low randomness and top_p sampling
176
+ >>> output_random = predict_dna_sequence(
177
+ ... protein=protein,
178
+ ... organism=organism,
179
+ ... device=device,
180
+ ... tokenizer=tokenizer,
181
+ ... model=model,
182
+ ... attention_type="original_full",
183
+ ... deterministic=False,
184
+ ... temperature=0.2,
185
+ ... top_p=0.95,
186
+ ... num_sequences=3
187
+ ... )
188
+ >>>
189
+ >>> print(format_model_output(output))
190
+ >>> for i, seq in enumerate(output_random, 1):
191
+ ... print(f"Sequence {i}:")
192
+ ... print(format_model_output(seq))
193
+ ... print()
194
+ """
195
+ if not protein:
196
+ raise ValueError("Protein sequence cannot be empty.")
197
+
198
+ if not isinstance(temperature, (float, int)) or temperature <= 0:
199
+ raise ValueError("Temperature must be a positive float.")
200
+
201
+ if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
202
+ raise ValueError("top_p must be a float between 0 and 1.")
203
+
204
+ if not isinstance(num_sequences, int) or num_sequences < 1:
205
+ raise ValueError("num_sequences must be a positive integer.")
206
+
207
+ if use_constrained_search:
208
+ if not isinstance(gc_bounds, tuple) or len(gc_bounds) != 2:
209
+ raise ValueError("gc_bounds must be a tuple of (min_gc, max_gc).")
210
+
211
+ if not (0.0 <= gc_bounds[0] <= gc_bounds[1] <= 1.0):
212
+ raise ValueError("gc_bounds must be between 0.0 and 1.0 with min <= max.")
213
+
214
+ if not isinstance(beam_size, int) or beam_size < 1:
215
+ raise ValueError("beam_size must be a positive integer.")
216
+
217
+ if deterministic and num_sequences > 1 and not use_constrained_search:
218
+ raise ValueError(
219
+ "Multiple sequences can only be generated in non-deterministic mode "
220
+ "(unless using constrained search)."
221
+ )
222
+
223
+ if use_constrained_search and num_sequences > 1:
224
+ raise ValueError(
225
+ "Constrained beam search currently supports only single sequence generation."
226
+ )
227
+
228
+ # Load tokenizer
229
+ if not isinstance(tokenizer, PreTrainedTokenizerFast):
230
+ tokenizer = load_tokenizer(tokenizer)
231
+
232
+ # Load model
233
+ if not isinstance(model, torch.nn.Module):
234
+ model = load_model(model_path=model, device=device, attention_type=attention_type)
235
+ else:
236
+ model.eval()
237
+ model.bert.set_attention_type(attention_type)
238
+ model.to(device)
239
+
240
+ # Validate organism and convert to organism_id and organism_name
241
+ organism_id, organism_name = validate_and_convert_organism(organism)
242
+
243
+ # Inference loop
244
+ with torch.no_grad():
245
+ # Tokenize the input sequence
246
+ merged_seq = get_merged_seq(protein=protein, dna="")
247
+ input_dict = {
248
+ "idx": 0, # sample index
249
+ "codons": merged_seq,
250
+ "organism": organism_id,
251
+ }
252
+ tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device)
253
+
254
+ # Get the model predictions
255
+ output_dict = model(**tokenized_input, return_dict=True)
256
+ logits = output_dict.logits.detach().cpu()
257
+ logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens
258
+
259
+ # Mask the logits of codons that do not correspond to the input protein sequence
260
+ if match_protein:
261
+ possible_tokens_per_position = [
262
+ AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
263
+ ]
264
+ seq_len = logits.shape[1]
265
+ if len(possible_tokens_per_position) > seq_len:
266
+ possible_tokens_per_position = possible_tokens_per_position[:seq_len]
267
+
268
+ mask = torch.full_like(logits, float("-inf"))
269
+
270
+ for pos, possible_tokens in enumerate(possible_tokens_per_position):
271
+ mask[:, pos, possible_tokens] = 0
272
+
273
+ logits = mask + logits
274
+
275
+ predictions = []
276
+ for _ in range(num_sequences):
277
+ # Decode the predicted DNA sequence from the model output
278
+ if use_constrained_search:
279
+ # Use constrained beam search with GC bounds
280
+ predicted_indices = constrained_beam_search_simple(
281
+ logits=logits.squeeze(0),
282
+ protein_sequence=protein,
283
+ gc_bounds=gc_bounds,
284
+ max_attempts=50,
285
+ )
286
+ elif deterministic:
287
+ predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
288
+ else:
289
+ predicted_indices = sample_non_deterministic(
290
+ logits=logits, temperature=temperature, top_p=top_p
291
+ )
292
+
293
+ predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
294
+ predicted_dna = (
295
+ "".join([token[-3:] for token in predicted_dna]).strip().upper()
296
+ )
297
+
298
+ predictions.append(
299
+ DNASequencePrediction(
300
+ organism=organism_name,
301
+ protein=protein,
302
+ processed_input=merged_seq,
303
+ predicted_dna=predicted_dna,
304
+ )
305
+ )
306
+
307
+ return predictions[0] if num_sequences == 1 else predictions
308
+
309
+
310
+ @dataclass
311
+ class BeamCandidate:
312
+ """Represents a candidate sequence in the beam search."""
313
+ tokens: List[int]
314
+ score: float
315
+ gc_count: int
316
+ length: int
317
+
318
+ def __post_init__(self):
319
+ self.gc_ratio = self.gc_count / max(self.length, 1)
320
+
321
+ def __lt__(self, other):
322
+ return self.score < other.score
323
+
324
+
325
+ def _calculate_true_future_gc_range(
326
+ current_pos: int,
327
+ protein_sequence: str,
328
+ current_gc_count: int,
329
+ current_length: int
330
+ ) -> Tuple[float, float]:
331
+ """
332
+ Calculate the true minimum and maximum possible final GC content
333
+ given current state and remaining amino acids (perfect foresight).
334
+
335
+ Args:
336
+ current_pos: Current position in protein sequence
337
+ protein_sequence: Full protein sequence
338
+ current_gc_count: Current GC count in partial sequence
339
+ current_length: Current length in nucleotides
340
+
341
+ Returns:
342
+ Tuple of (min_possible_final_gc_ratio, max_possible_final_gc_ratio)
343
+ """
344
+ if current_pos >= len(protein_sequence):
345
+ # Already at end, return current ratio
346
+ final_ratio = current_gc_count / max(current_length, 1)
347
+ return final_ratio, final_ratio
348
+
349
+ # Calculate remaining amino acids
350
+ remaining_aas = protein_sequence[current_pos:]
351
+
352
+ # Calculate min/max possible GC from remaining amino acids
353
+ min_future_gc = 0
354
+ max_future_gc = 0
355
+
356
+ for aa in remaining_aas:
357
+ if aa.upper() in AA_MIN_GC and aa.upper() in AA_MAX_GC:
358
+ min_future_gc += AA_MIN_GC[aa.upper()]
359
+ max_future_gc += AA_MAX_GC[aa.upper()]
360
+ else:
361
+ # If amino acid not found, assume moderate GC (1-2 range)
362
+ min_future_gc += 1
363
+ max_future_gc += 2
364
+
365
+ # Calculate final sequence length
366
+ final_length = current_length + len(remaining_aas) * 3
367
+
368
+ # Calculate min/max possible final GC ratios
369
+ min_final_gc_ratio = (current_gc_count + min_future_gc) / final_length
370
+ max_final_gc_ratio = (current_gc_count + max_future_gc) / final_length
371
+
372
+ return min_final_gc_ratio, max_final_gc_ratio
373
+
374
+
375
+ def constrained_beam_search_simple(
376
+ logits: torch.Tensor,
377
+ protein_sequence: str,
378
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
379
+ max_attempts: int = 100,
380
+ ) -> List[int]:
381
+ """
382
+ Simple constrained search - try multiple greedy samples and pick best one within GC bounds.
383
+ """
384
+ min_gc, max_gc = gc_bounds
385
+ seq_len = min(logits.shape[0], len(protein_sequence))
386
+
387
+ # Convert to probabilities
388
+ probs = torch.softmax(logits, dim=-1)
389
+
390
+ valid_sequences = []
391
+
392
+ for attempt in range(max_attempts):
393
+ tokens = []
394
+ total_gc = 0
395
+
396
+ # Generate sequence position by position
397
+ for pos in range(seq_len):
398
+ aa = protein_sequence[pos]
399
+ possible_tokens = AMINO_ACID_TO_INDEX.get(aa, [])
400
+
401
+ if not possible_tokens:
402
+ continue
403
+
404
+ # Filter tokens by current constraints and get probabilities
405
+ candidates = []
406
+ for token_idx in possible_tokens:
407
+ if token_idx < len(probs[pos]) and token_idx < len(GC_COUNTS_PER_TOKEN):
408
+ prob = probs[pos][token_idx].item()
409
+ gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
410
+
411
+ # Check if this token could still lead to a valid final sequence (perfect foresight)
412
+ new_gc_total = total_gc + gc_contribution
413
+ new_length = (pos + 1) * 3
414
+
415
+ # Calculate what's possible for the final sequence given this choice
416
+ min_final_gc, max_final_gc = _calculate_true_future_gc_range(
417
+ pos + 1, protein_sequence, new_gc_total, new_length
418
+ )
419
+
420
+ # Only prune if there's NO OVERLAP between possible final range and target bounds
421
+ if max_final_gc >= min_gc and min_final_gc <= max_gc:
422
+ # Calculate gentle GC penalty to steer toward target center
423
+ target_gc = (min_gc + max_gc) / 2 # Target center (e.g., 0.50 for bounds 0.45-0.55)
424
+ current_projected_gc = (min_final_gc + max_final_gc) / 2 # Projected center
425
+
426
+ # Only apply penalty if we're significantly off-target AND late in sequence
427
+ sequence_progress = (pos + 1) / seq_len
428
+ if sequence_progress > 0.3: # Only apply penalty after 30% of sequence
429
+ gc_deviation = abs(current_projected_gc - target_gc)
430
+ if gc_deviation > 0.05: # Only if >5% deviation from target
431
+ # Gentle penalty: reduce probability by small factor
432
+ penalty_factor = max(0.7, 1.0 - 0.3 * gc_deviation) # 0.7-1.0 range
433
+ prob = prob * penalty_factor
434
+
435
+ candidates.append((token_idx, prob, gc_contribution))
436
+
437
+ if not candidates:
438
+ # If no valid candidates, break and try next attempt
439
+ break
440
+
441
+ # Sample from valid candidates (with temperature)
442
+ if attempt == 0:
443
+ # First attempt: greedy (highest probability)
444
+ best_token = max(candidates, key=lambda x: x[1])
445
+ else:
446
+ # Other attempts: sample with some randomness
447
+ probs_list = [c[1] for c in candidates]
448
+ if sum(probs_list) > 0:
449
+ # Normalize probabilities
450
+ probs_array = np.array(probs_list)
451
+ probs_array = probs_array / probs_array.sum()
452
+ # Sample
453
+ chosen_idx = np.random.choice(len(candidates), p=probs_array)
454
+ best_token = candidates[chosen_idx]
455
+ else:
456
+ best_token = candidates[0]
457
+
458
+ tokens.append(best_token[0])
459
+ total_gc += best_token[2]
460
+
461
+ # Check if we got a complete sequence
462
+ if len(tokens) == seq_len:
463
+ final_gc_ratio = total_gc / (seq_len * 3)
464
+ if min_gc <= final_gc_ratio <= max_gc:
465
+ # Calculate sequence score (sum of log probabilities)
466
+ score = sum(np.log(probs[i][tokens[i]].item() + 1e-8) for i in range(len(tokens)))
467
+ valid_sequences.append((tokens, score, final_gc_ratio))
468
+
469
+ if not valid_sequences:
470
+ raise ValueError(f"Could not generate valid sequence within GC bounds {gc_bounds} after {max_attempts} attempts")
471
+
472
+ # Return the sequence with highest score
473
+ best_sequence = max(valid_sequences, key=lambda x: x[1])
474
+ return best_sequence[0]
475
+
476
+
477
+ def constrained_beam_search(
478
+ logits: torch.Tensor,
479
+ protein_sequence: str,
480
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
481
+ beam_size: int = 5,
482
+ length_penalty: float = 1.0,
483
+ diversity_penalty: float = 0.0,
484
+ temperature: float = 1.0,
485
+ max_candidates: int = 100,
486
+ position_aware_gc_penalty: bool = True,
487
+ gc_penalty_strength: float = 2.0,
488
+ ) -> List[int]:
489
+ """
490
+ Constrained beam search with exact per-residue GC bounds tracking.
491
+
492
+ Priority #1: Exact per-residue GC bounds tracking
493
+ - Tracks cumulative GC content after each codon selection
494
+ - Prunes candidates that would violate GC bounds
495
+ - Maintains beam of valid candidates
496
+
497
+ Priority #2: Position-aware GC penalty mechanism
498
+ - Applies variable penalty weights based on sequence position
499
+ - Preserves flexibility early, applies pressure when necessary
500
+ - Uses progressive penalty scaling based on deviation severity
501
+
502
+ Args:
503
+ logits (torch.Tensor): Model logits of shape [seq_len, vocab_size]
504
+ protein_sequence (str): Input protein sequence
505
+ gc_bounds (Tuple[float, float]): (min_gc, max_gc) bounds
506
+ beam_size (int): Number of candidates to maintain
507
+ length_penalty (float): Length penalty for scoring
508
+ diversity_penalty (float): Diversity penalty for scoring
509
+ temperature (float): Temperature for probability scaling
510
+ max_candidates (int): Maximum candidates to consider per position
511
+ position_aware_gc_penalty (bool): Whether to use position-aware GC penalties
512
+ gc_penalty_strength (float): Strength of GC penalty adjustment
513
+
514
+ Returns:
515
+ List[int]: Best sequence token indices
516
+ """
517
+ min_gc, max_gc = gc_bounds
518
+ seq_len = logits.shape[0]
519
+ protein_len = len(protein_sequence)
520
+
521
+ # Ensure we don't go beyond the protein sequence
522
+ if seq_len > protein_len:
523
+ print(f"Warning: logits length ({seq_len}) > protein length ({protein_len}). Truncating to protein length.")
524
+ seq_len = protein_len
525
+ logits = logits[:protein_len]
526
+
527
+ # Initialize beam with empty candidate
528
+ beam = [BeamCandidate(tokens=[], score=0.0, gc_count=0, length=0)]
529
+
530
+ # Apply temperature scaling
531
+ if temperature != 1.0:
532
+ logits = logits / temperature
533
+
534
+ # Convert to probabilities
535
+ probs = torch.softmax(logits, dim=-1)
536
+
537
+ for pos in range(min(seq_len, len(protein_sequence))):
538
+ # Get possible tokens for current amino acid
539
+ aa = protein_sequence[pos]
540
+ possible_tokens = AMINO_ACID_TO_INDEX.get(aa, [])
541
+
542
+ if not possible_tokens:
543
+ # Fallback to all tokens if amino acid not found
544
+ possible_tokens = list(range(probs.shape[1]))
545
+
546
+ # Get top candidates for this position
547
+ pos_probs = probs[pos]
548
+ top_candidates = []
549
+
550
+ for token_idx in possible_tokens:
551
+ if token_idx < len(pos_probs) and token_idx < len(GC_COUNTS_PER_TOKEN):
552
+ prob = pos_probs[token_idx].item()
553
+ gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
554
+ # Only include tokens with valid probabilities
555
+ if prob > 1e-10: # Avoid extremely low probabilities
556
+ top_candidates.append((token_idx, prob, gc_contribution))
557
+
558
+ # Sort by probability and take top max_candidates
559
+ top_candidates.sort(key=lambda x: x[1], reverse=True)
560
+ top_candidates = top_candidates[:max_candidates]
561
+
562
+ # If no valid candidates found, fallback to all possible tokens for this amino acid
563
+ if not top_candidates:
564
+ for token_idx in possible_tokens[:min(len(possible_tokens), max_candidates)]:
565
+ if token_idx < len(pos_probs) and token_idx < len(GC_COUNTS_PER_TOKEN):
566
+ prob = max(pos_probs[token_idx].item(), 1e-10) # Ensure minimum probability
567
+ gc_contribution = int(GC_COUNTS_PER_TOKEN[token_idx].item())
568
+ top_candidates.append((token_idx, prob, gc_contribution))
569
+
570
+ # Generate new beam candidates
571
+ new_beam = []
572
+
573
+ for candidate in beam:
574
+ for token_idx, prob, gc_contribution in top_candidates:
575
+ # Calculate new GC stats
576
+ new_gc_count = candidate.gc_count + gc_contribution
577
+ new_length = candidate.length + 3 # Each codon is 3 nucleotides
578
+ new_gc_ratio = new_gc_count / new_length
579
+
580
+ # Priority #2: Position-aware GC penalty mechanism
581
+ gc_penalty = 0.0
582
+ if position_aware_gc_penalty:
583
+ # Calculate position weight (more penalty towards end of sequence)
584
+ position_weight = (pos + 1) / seq_len
585
+
586
+ # Calculate GC deviation severity
587
+ target_gc = (min_gc + max_gc) / 2
588
+ gc_deviation = abs(new_gc_ratio - target_gc)
589
+ deviation_severity = gc_deviation / ((max_gc - min_gc) / 2)
590
+
591
+ # Apply progressive penalty
592
+ if deviation_severity > 0.5: # Soft penalty zone
593
+ gc_penalty = gc_penalty_strength * position_weight * (deviation_severity - 0.5) ** 2
594
+
595
+ # Hard constraint: still prune sequences that exceed bounds
596
+ if new_gc_ratio < min_gc or new_gc_ratio > max_gc:
597
+ continue # Prune invalid candidates
598
+ else:
599
+ # Priority #1: Hard GC bounds only
600
+ if new_gc_ratio < min_gc or new_gc_ratio > max_gc:
601
+ continue # Prune invalid candidates
602
+
603
+ # Calculate score with GC penalty
604
+ new_score = candidate.score + np.log(prob + 1e-8) - gc_penalty
605
+
606
+ # Apply length penalty
607
+ if length_penalty != 1.0:
608
+ length_norm = ((pos + 1) ** length_penalty)
609
+ normalized_score = new_score / length_norm
610
+ else:
611
+ normalized_score = new_score
612
+
613
+ # Create new candidate
614
+ new_candidate = BeamCandidate(
615
+ tokens=candidate.tokens + [token_idx],
616
+ score=normalized_score,
617
+ gc_count=new_gc_count,
618
+ length=new_length
619
+ )
620
+
621
+ new_beam.append(new_candidate)
622
+
623
+ # Apply diversity penalty if specified
624
+ if diversity_penalty > 0.0:
625
+ new_beam = _apply_diversity_penalty(new_beam, diversity_penalty)
626
+
627
+ # Keep top beam_size candidates
628
+ beam = sorted(new_beam, key=lambda x: x.score, reverse=True)[:beam_size]
629
+
630
+ # Priority #3: Adaptive beam rescue for difficult sequences
631
+ if not beam:
632
+ # Attempt beam rescue by relaxing constraints progressively
633
+ rescue_attempts = 0
634
+ max_rescue_attempts = 3
635
+
636
+ while not beam and rescue_attempts < max_rescue_attempts:
637
+ rescue_attempts += 1
638
+
639
+ # Progressive relaxation strategy
640
+ if rescue_attempts == 1:
641
+ # First attempt: increase beam size and relax GC bounds slightly
642
+ temp_beam_size = min(beam_size * 2, max_candidates)
643
+ temp_gc_bounds = (min_gc * 0.95, max_gc * 1.05)
644
+ elif rescue_attempts == 2:
645
+ # Second attempt: further relax GC bounds and increase candidates
646
+ temp_beam_size = min(beam_size * 3, max_candidates)
647
+ temp_gc_bounds = (min_gc * 0.9, max_gc * 1.1)
648
+ else:
649
+ # Final attempt: maximum relaxation
650
+ temp_beam_size = max_candidates
651
+ temp_gc_bounds = (min_gc * 0.85, max_gc * 1.15)
652
+
653
+ # Retry beam generation with relaxed parameters
654
+ rescue_beam = []
655
+ # Use previous beam state or start fresh if this is the first position with no beam
656
+ previous_beam = beam if beam else [BeamCandidate(tokens=[], score=0.0, gc_count=0, length=0)]
657
+ for candidate in previous_beam:
658
+ for token_idx, prob, gc_contribution in top_candidates:
659
+ new_gc_count = candidate.gc_count + gc_contribution
660
+ new_length = candidate.length + 3
661
+ new_gc_ratio = new_gc_count / new_length
662
+
663
+ # Check relaxed bounds
664
+ if temp_gc_bounds[0] <= new_gc_ratio <= temp_gc_bounds[1]:
665
+ # Apply reduced GC penalty for rescue
666
+ gc_penalty = 0.0
667
+ if position_aware_gc_penalty:
668
+ position_weight = (pos + 1) / seq_len
669
+ target_gc = (min_gc + max_gc) / 2
670
+ gc_deviation = abs(new_gc_ratio - target_gc)
671
+ deviation_severity = gc_deviation / ((max_gc - min_gc) / 2)
672
+
673
+ # Reduced penalty for rescue
674
+ if deviation_severity > 0.7:
675
+ gc_penalty = (gc_penalty_strength * 0.5) * position_weight * (deviation_severity - 0.7) ** 2
676
+
677
+ new_score = candidate.score + np.log(prob + 1e-8) - gc_penalty
678
+
679
+ if length_penalty != 1.0:
680
+ length_norm = ((pos + 1) ** length_penalty)
681
+ normalized_score = new_score / length_norm
682
+ else:
683
+ normalized_score = new_score
684
+
685
+ rescue_candidate = BeamCandidate(
686
+ tokens=candidate.tokens + [token_idx],
687
+ score=normalized_score,
688
+ gc_count=new_gc_count,
689
+ length=new_length
690
+ )
691
+ rescue_beam.append(rescue_candidate)
692
+
693
+ # Keep top candidates from rescue attempt
694
+ if rescue_beam:
695
+ beam = sorted(rescue_beam, key=lambda x: x.score, reverse=True)[:temp_beam_size]
696
+ break
697
+
698
+ # If all rescue attempts failed, raise error
699
+ if not beam:
700
+ raise ValueError(
701
+ f"Beam rescue failed at position {pos} after {max_rescue_attempts} attempts. "
702
+ f"The GC constraints {gc_bounds} may be too restrictive for this protein sequence. "
703
+ f"Consider relaxing constraints or using a different approach."
704
+ )
705
+
706
+ # Return best candidate
707
+ best_candidate = max(beam, key=lambda x: x.score)
708
+ return best_candidate.tokens
709
+
710
+
711
+ # Wrapper function that tries simple approach first
712
+ def constrained_beam_search_wrapper(
713
+ logits: torch.Tensor,
714
+ protein_sequence: str,
715
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
716
+ **kwargs
717
+ ) -> List[int]:
718
+ """Wrapper that tries simple approach first, falls back to complex beam search."""
719
+ try:
720
+ # Try simple approach first
721
+ return constrained_beam_search_simple(logits, protein_sequence, gc_bounds)
722
+ except ValueError:
723
+ # Fall back to complex beam search
724
+ return constrained_beam_search(logits, protein_sequence, gc_bounds, **kwargs)
725
+
726
+
727
+ def _apply_diversity_penalty(candidates: List[BeamCandidate], penalty: float) -> List[BeamCandidate]:
728
+ """
729
+ Apply diversity penalty to reduce repetitive sequences.
730
+
731
+ Args:
732
+ candidates (List[BeamCandidate]): List of candidates
733
+ penalty (float): Diversity penalty strength
734
+
735
+ Returns:
736
+ List[BeamCandidate]: Candidates with diversity penalty applied
737
+ """
738
+ if not candidates:
739
+ return candidates
740
+
741
+ # Count token occurrences
742
+ token_counts = {}
743
+ for candidate in candidates:
744
+ for token in candidate.tokens:
745
+ token_counts[token] = token_counts.get(token, 0) + 1
746
+
747
+ # Apply penalty
748
+ for candidate in candidates:
749
+ diversity_score = 0.0
750
+ for token in candidate.tokens:
751
+ if token_counts[token] > 1:
752
+ diversity_score += penalty * np.log(token_counts[token])
753
+ candidate.score -= diversity_score
754
+
755
+ return candidates
756
+
757
+
758
+ def sample_non_deterministic(
759
+ logits: torch.Tensor,
760
+ temperature: float = 0.2,
761
+ top_p: float = 0.95,
762
+ ) -> List[int]:
763
+ """
764
+ Sample token indices from logits using temperature scaling and nucleus (top-p) sampling.
765
+
766
+ This function applies temperature scaling to the logits, computes probabilities,
767
+ and then performs nucleus sampling to select token indices. It is used for
768
+ non-deterministic decoding in language models to introduce randomness while
769
+ maintaining coherence in the generated sequences.
770
+
771
+ Args:
772
+ logits (torch.Tensor): The logits output from the model of shape
773
+ [seq_len, vocab_size] or [batch_size, seq_len, vocab_size].
774
+ temperature (float, optional): Temperature value for scaling logits.
775
+ Must be a positive float. Defaults to 1.0.
776
+ top_p (float, optional): Cumulative probability threshold for nucleus sampling.
777
+ Must be a float between 0 and 1. Tokens with cumulative probability up to
778
+ `top_p` are considered for sampling. Defaults to 0.95.
779
+
780
+ Returns:
781
+ List[int]: A list of sampled token indices corresponding to the predicted tokens.
782
+
783
+ Raises:
784
+ ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1.
785
+
786
+ Example:
787
+ >>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size]
788
+ >>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9)
789
+ """
790
+ if not isinstance(temperature, (float, int)) or temperature <= 0:
791
+ raise ValueError("Temperature must be a positive float.")
792
+
793
+ if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0:
794
+ raise ValueError("top_p must be a float between 0 and 1.")
795
+
796
+ # Compute probabilities using temperature scaling
797
+ probs = torch.softmax(logits / temperature, dim=-1)
798
+
799
+
800
+ # Remove batch dimension if present
801
+ if probs.dim() == 3:
802
+ probs = probs.squeeze(0) # Shape: [seq_len, vocab_size]
803
+
804
+ # Sort probabilities in descending order
805
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
806
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
807
+ mask = probs_sum - probs_sort > top_p
808
+
809
+ # Zero out probabilities for tokens beyond the top-p threshold
810
+ probs_sort[mask] = 0.0
811
+
812
+ # Renormalize the probabilities
813
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
814
+ next_token = torch.multinomial(probs_sort, num_samples=1)
815
+ predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1)
816
+
817
+ return predicted_indices.tolist()
818
+
819
+
820
+ def load_model(
821
+ model_path: Optional[str] = None,
822
+ device: torch.device = None,
823
+ attention_type: str = "original_full",
824
+ num_organisms: int = None,
825
+ remove_prefix: bool = True,
826
+ ) -> torch.nn.Module:
827
+ """
828
+ Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace.
829
+
830
+ Args:
831
+ model_path (Optional[str]): Path to the model file or checkpoint. If None,
832
+ load from HuggingFace.
833
+ device (torch.device, optional): The device to load the model onto.
834
+ attention_type (str, optional): The type of attention, 'block_sparse'
835
+ or 'original_full'.
836
+ num_organisms (int, optional): Number of organisms, needed if loading from a
837
+ checkpoint that requires this.
838
+ remove_prefix (bool, optional): Whether to remove the "model." prefix from the
839
+ keys in the state dict.
840
+
841
+ Returns:
842
+ torch.nn.Module: The loaded model.
843
+ """
844
+ if not model_path:
845
+ warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning)
846
+ model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
847
+ elif model_path.endswith(".ckpt"):
848
+ checkpoint = torch.load(model_path, map_location="cpu")
849
+
850
+ # Detect Lightning checkpoint vs raw state dict
851
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
852
+ state_dict = checkpoint["state_dict"]
853
+ if remove_prefix:
854
+ state_dict = {
855
+ k.replace("model.", ""): v for k, v in state_dict.items()
856
+ }
857
+ else:
858
+ # assume checkpoint itself is state_dict
859
+ state_dict = checkpoint
860
+
861
+ if num_organisms is None:
862
+ num_organisms = NUM_ORGANISMS
863
+
864
+ # Load model configuration and instantiate the model
865
+ config = load_bigbird_config(num_organisms)
866
+ model = BigBirdForMaskedLM(config=config)
867
+ model.load_state_dict(state_dict, strict=False)
868
+
869
+ elif model_path.endswith(".pt"):
870
+ state_dict = torch.load(model_path)
871
+ config = state_dict.pop("self.config")
872
+ model = BigBirdForMaskedLM(config=config)
873
+ model.load_state_dict(state_dict, strict=False)
874
+
875
+ else:
876
+ raise ValueError(
877
+ "Unsupported file type. Please provide a .ckpt or .pt file, "
878
+ "or None to load from HuggingFace."
879
+ )
880
+
881
+ # Prepare model for evaluation
882
+ model.bert.set_attention_type(attention_type)
883
+ model.eval()
884
+ if device:
885
+ model.to(device)
886
+
887
+ return model
888
+
889
+
890
+ def load_bigbird_config(num_organisms: int) -> BigBirdConfig:
891
+ """
892
+ Load the config object used to train the BigBird transformer.
893
+
894
+ Args:
895
+ num_organisms (int): The number of organisms.
896
+
897
+ Returns:
898
+ BigBirdConfig: The configuration object for BigBird.
899
+ """
900
+ config = transformers.BigBirdConfig(
901
+ vocab_size=len(TOKEN2INDEX), # Equal to len(tokenizer)
902
+ type_vocab_size=num_organisms,
903
+ sep_token_id=2,
904
+ )
905
+ return config
906
+
907
+
908
+ def create_model_from_checkpoint(
909
+ checkpoint_dir: str, output_model_dir: str, num_organisms: int
910
+ ) -> None:
911
+ """
912
+ Save a model to disk using a previous checkpoint.
913
+
914
+ Args:
915
+ checkpoint_dir (str): Directory where the checkpoint is stored.
916
+ output_model_dir (str): Directory where the model will be saved.
917
+ num_organisms (int): Number of organisms.
918
+ """
919
+ checkpoint = load_model(model_path=checkpoint_dir, num_organisms=num_organisms)
920
+ state_dict = checkpoint.state_dict()
921
+ state_dict["self.config"] = load_bigbird_config(num_organisms=num_organisms)
922
+
923
+ # Save the model state dict to the output directory
924
+ torch.save(state_dict, output_model_dir)
925
+
926
+
927
+ def load_tokenizer(tokenizer_path: Optional[Union[str, PreTrainedTokenizerFast]] = None) -> PreTrainedTokenizerFast:
928
+ """
929
+ Create and return a tokenizer object from tokenizer path or HuggingFace.
930
+
931
+ Args:
932
+ tokenizer_path (Optional[Union[str, PreTrainedTokenizerFast]]): Path to the tokenizer file,
933
+ a pre-loaded tokenizer object, or None. If None, load from HuggingFace.
934
+
935
+ Returns:
936
+ PreTrainedTokenizerFast: The tokenizer object.
937
+ """
938
+ # If a tokenizer object is already provided, return it
939
+ if isinstance(tokenizer_path, PreTrainedTokenizerFast):
940
+ return tokenizer_path
941
+
942
+ # If no path is provided, load from HuggingFace
943
+ if not tokenizer_path:
944
+ warnings.warn(
945
+ "Tokenizer path not provided. Loading from HuggingFace.", UserWarning
946
+ )
947
+ return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
948
+
949
+ # Load from file path
950
+ return transformers.PreTrainedTokenizerFast(
951
+ tokenizer_file=tokenizer_path,
952
+ bos_token="[CLS]",
953
+ eos_token="[SEP]",
954
+ unk_token="[UNK]",
955
+ sep_token="[SEP]",
956
+ pad_token="[PAD]",
957
+ cls_token="[CLS]",
958
+ mask_token="[MASK]",
959
+ )
960
+
961
+
962
+ def tokenize(
963
+ batch: List[Dict[str, Any]],
964
+ tokenizer: Union[PreTrainedTokenizerFast, str] = None,
965
+ max_len: int = 2048,
966
+ ) -> BatchEncoding:
967
+ """
968
+ Return the tokenized sequences given a batch of input data.
969
+ Each data in the batch is expected to be a dictionary with "codons" and
970
+ "organism" keys.
971
+
972
+ Args:
973
+ batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and
974
+ "organism" keys.
975
+ tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or
976
+ path to the tokenizer file.
977
+ max_len (int, optional): Maximum length of the tokenized sequence.
978
+
979
+ Returns:
980
+ BatchEncoding: The tokenized batch.
981
+ """
982
+ if not isinstance(tokenizer, PreTrainedTokenizerFast):
983
+ tokenizer = load_tokenizer(tokenizer)
984
+
985
+ tokenized = tokenizer(
986
+ [data["codons"] for data in batch],
987
+ return_attention_mask=True,
988
+ return_token_type_ids=True,
989
+ truncation=True,
990
+ padding=True,
991
+ max_length=max_len,
992
+ return_tensors="pt",
993
+ )
994
+
995
+ # Add token type IDs for species
996
+ seq_len = tokenized["input_ids"].shape[-1]
997
+ species_index = torch.tensor([[data["organism"]] for data in batch])
998
+ tokenized["token_type_ids"] = species_index.repeat(1, seq_len)
999
+
1000
+ return tokenized
1001
+
1002
+
1003
+ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]:
1004
+ """
1005
+ Validate and convert the organism input to both ID and name.
1006
+
1007
+ This function takes either an organism ID or name as input and returns both
1008
+ the ID and name. It performs validation to ensure the input corresponds to
1009
+ a valid organism in the ORGANISM2ID dictionary.
1010
+
1011
+ Args:
1012
+ organism (Union[int, str]): Either the ID of the organism (int) or its
1013
+ name (str).
1014
+
1015
+ Returns:
1016
+ Tuple[int, str]: A tuple containing the organism ID (int) and name (str).
1017
+
1018
+ Raises:
1019
+ ValueError: If the input is neither a string nor an integer, if the
1020
+ organism name is not found in ORGANISM2ID, if the organism ID is not a
1021
+ value in ORGANISM2ID, or if no name is found for a given ID.
1022
+
1023
+ Note:
1024
+ This function relies on the ORGANISM2ID dictionary imported from
1025
+ CodonTransformer.CodonUtils, which maps organism names to their
1026
+ corresponding IDs.
1027
+ """
1028
+ if isinstance(organism, str):
1029
+ if organism not in ORGANISM2ID:
1030
+ raise ValueError(
1031
+ f"Invalid organism name: {organism}. "
1032
+ "Please use a valid organism name or ID."
1033
+ )
1034
+ organism_id = ORGANISM2ID[organism]
1035
+ organism_name = organism
1036
+
1037
+ elif isinstance(organism, int):
1038
+ if organism not in ORGANISM2ID.values():
1039
+ raise ValueError(
1040
+ f"Invalid organism ID: {organism}. "
1041
+ "Please use a valid organism name or ID."
1042
+ )
1043
+
1044
+ organism_id = organism
1045
+ organism_name = next(
1046
+ (name for name, id in ORGANISM2ID.items() if id == organism), None
1047
+ )
1048
+ if organism_name is None:
1049
+ raise ValueError(f"No organism name found for ID: {organism}")
1050
+
1051
+ return organism_id, organism_name
1052
+
1053
+
1054
+ def get_high_frequency_choice_sequence(
1055
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1056
+ ) -> str:
1057
+ """
1058
+ Return the DNA sequence optimized using High Frequency Choice (HFC) approach
1059
+ in which the most frequent codon for a given amino acid is always chosen.
1060
+
1061
+ Args:
1062
+ protein (str): The protein sequence.
1063
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1064
+ frequencies for each amino acid.
1065
+
1066
+ Returns:
1067
+ str: The optimized DNA sequence.
1068
+ """
1069
+ # Select the most frequent codon for each amino acid in the protein sequence
1070
+ dna_codons = [
1071
+ codon_frequencies[aminoacid][0][np.argmax(codon_frequencies[aminoacid][1])]
1072
+ for aminoacid in protein
1073
+ ]
1074
+ return "".join(dna_codons)
1075
+
1076
+
1077
+ def precompute_most_frequent_codons(
1078
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
1079
+ ) -> Dict[str, str]:
1080
+ """
1081
+ Precompute the most frequent codon for each amino acid.
1082
+
1083
+ Args:
1084
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1085
+ frequencies for each amino acid.
1086
+
1087
+ Returns:
1088
+ Dict[str, str]: The most frequent codon for each amino acid.
1089
+ """
1090
+ # Create a dictionary mapping each amino acid to its most frequent codon
1091
+ return {
1092
+ aminoacid: codons[np.argmax(frequencies)]
1093
+ for aminoacid, (codons, frequencies) in codon_frequencies.items()
1094
+ }
1095
+
1096
+
1097
+ def get_high_frequency_choice_sequence_optimized(
1098
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1099
+ ) -> str:
1100
+ """
1101
+ Efficient implementation of get_high_frequency_choice_sequence that uses
1102
+ vectorized operations and helper functions, achieving up to x10 faster speed.
1103
+
1104
+ Args:
1105
+ protein (str): The protein sequence.
1106
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1107
+ frequencies for each amino acid.
1108
+
1109
+ Returns:
1110
+ str: The optimized DNA sequence.
1111
+ """
1112
+ # Precompute the most frequent codons for each amino acid
1113
+ most_frequent_codons = precompute_most_frequent_codons(codon_frequencies)
1114
+
1115
+ return "".join(most_frequent_codons[aminoacid] for aminoacid in protein)
1116
+
1117
+
1118
+ def get_background_frequency_choice_sequence(
1119
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1120
+ ) -> str:
1121
+ """
1122
+ Return the DNA sequence optimized using Background Frequency Choice (BFC)
1123
+ approach in which a random codon for a given amino acid is chosen using
1124
+ the codon frequencies probability distribution.
1125
+
1126
+ Args:
1127
+ protein (str): The protein sequence.
1128
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1129
+ frequencies for each amino acid.
1130
+
1131
+ Returns:
1132
+ str: The optimized DNA sequence.
1133
+ """
1134
+ # Select a random codon for each amino acid based on the codon frequencies
1135
+ # probability distribution
1136
+ dna_codons = [
1137
+ np.random.choice(
1138
+ codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1]
1139
+ )
1140
+ for aminoacid in protein
1141
+ ]
1142
+ return "".join(dna_codons)
1143
+
1144
+
1145
+ def precompute_cdf(
1146
+ codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
1147
+ ) -> Dict[str, Tuple[List[str], Any]]:
1148
+ """
1149
+ Precompute the cumulative distribution function (CDF) for each amino acid.
1150
+
1151
+ Args:
1152
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1153
+ frequencies for each amino acid.
1154
+
1155
+ Returns:
1156
+ Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid.
1157
+ """
1158
+ cdf = {}
1159
+
1160
+ # Calculate the cumulative distribution function for each amino acid
1161
+ for aminoacid, (codons, frequencies) in codon_frequencies.items():
1162
+ cdf[aminoacid] = (codons, np.cumsum(frequencies))
1163
+
1164
+ return cdf
1165
+
1166
+
1167
+ def get_background_frequency_choice_sequence_optimized(
1168
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1169
+ ) -> str:
1170
+ """
1171
+ Efficient implementation of get_background_frequency_choice_sequence that uses
1172
+ vectorized operations and helper functions, achieving up to x8 faster speed.
1173
+
1174
+ Args:
1175
+ protein (str): The protein sequence.
1176
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1177
+ frequencies for each amino acid.
1178
+
1179
+ Returns:
1180
+ str: The optimized DNA sequence.
1181
+ """
1182
+ dna_codons = []
1183
+ cdf = precompute_cdf(codon_frequencies)
1184
+
1185
+ # Select a random codon for each amino acid using the precomputed CDFs
1186
+ for aminoacid in protein:
1187
+ codons, cumulative_prob = cdf[aminoacid]
1188
+ selected_codon_index = np.searchsorted(cumulative_prob, np.random.rand())
1189
+ dna_codons.append(codons[selected_codon_index])
1190
+
1191
+ return "".join(dna_codons)
1192
+
1193
+
1194
+ def get_uniform_random_choice_sequence(
1195
+ protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
1196
+ ) -> str:
1197
+ """
1198
+ Return the DNA sequence optimized using Uniform Random Choice (URC) approach
1199
+ in which a random codon for a given amino acid is chosen using a uniform
1200
+ prior.
1201
+
1202
+ Args:
1203
+ protein (str): The protein sequence.
1204
+ codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon
1205
+ frequencies for each amino acid.
1206
+
1207
+ Returns:
1208
+ str: The optimized DNA sequence.
1209
+ """
1210
+ # Select a random codon for each amino acid using a uniform prior distribution
1211
+ dna_codons = []
1212
+ for aminoacid in protein:
1213
+ codons = codon_frequencies[aminoacid][0]
1214
+ random_index = np.random.randint(0, len(codons))
1215
+ dna_codons.append(codons[random_index])
1216
+ return "".join(dna_codons)
1217
+
1218
+
1219
+ def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> str:
1220
+ """
1221
+ Return the optimized codon sequence for the given protein sequence using ICOR.
1222
+
1223
+ Credit: ICOR: improving codon optimization with recurrent neural networks
1224
+ Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas
1225
+ Densmore
1226
+
1227
+ Args:
1228
+ input_seq (str): The input protein sequence.
1229
+ model_path (str): The path to the ICOR model.
1230
+ stop_symbol (str): The symbol representing stop codons in the sequence.
1231
+
1232
+ Returns:
1233
+ str: The optimized DNA sequence.
1234
+ """
1235
+ input_seq = input_seq.strip().upper()
1236
+ input_seq = input_seq.replace(stop_symbol, "*")
1237
+
1238
+ # Define categorical labels from when model was trained.
1239
+ labels = [
1240
+ "AAA",
1241
+ "AAC",
1242
+ "AAG",
1243
+ "AAT",
1244
+ "ACA",
1245
+ "ACG",
1246
+ "ACT",
1247
+ "AGC",
1248
+ "ATA",
1249
+ "ATC",
1250
+ "ATG",
1251
+ "ATT",
1252
+ "CAA",
1253
+ "CAC",
1254
+ "CAG",
1255
+ "CCG",
1256
+ "CCT",
1257
+ "CTA",
1258
+ "CTC",
1259
+ "CTG",
1260
+ "CTT",
1261
+ "GAA",
1262
+ "GAT",
1263
+ "GCA",
1264
+ "GCC",
1265
+ "GCG",
1266
+ "GCT",
1267
+ "GGA",
1268
+ "GGC",
1269
+ "GTC",
1270
+ "GTG",
1271
+ "GTT",
1272
+ "TAA",
1273
+ "TAT",
1274
+ "TCA",
1275
+ "TCG",
1276
+ "TCT",
1277
+ "TGG",
1278
+ "TGT",
1279
+ "TTA",
1280
+ "TTC",
1281
+ "TTG",
1282
+ "TTT",
1283
+ "ACC",
1284
+ "CAT",
1285
+ "CCA",
1286
+ "CGG",
1287
+ "CGT",
1288
+ "GAC",
1289
+ "GAG",
1290
+ "GGT",
1291
+ "AGT",
1292
+ "GGG",
1293
+ "GTA",
1294
+ "TGC",
1295
+ "CCC",
1296
+ "CGA",
1297
+ "CGC",
1298
+ "TAC",
1299
+ "TAG",
1300
+ "TCC",
1301
+ "AGA",
1302
+ "AGG",
1303
+ "TGA",
1304
+ ]
1305
+
1306
+ # Define aa to integer table
1307
+ def aa2int(seq: str) -> List[int]:
1308
+ _aa2int = {
1309
+ "A": 1,
1310
+ "R": 2,
1311
+ "N": 3,
1312
+ "D": 4,
1313
+ "C": 5,
1314
+ "Q": 6,
1315
+ "E": 7,
1316
+ "G": 8,
1317
+ "H": 9,
1318
+ "I": 10,
1319
+ "L": 11,
1320
+ "K": 12,
1321
+ "M": 13,
1322
+ "F": 14,
1323
+ "P": 15,
1324
+ "S": 16,
1325
+ "T": 17,
1326
+ "W": 18,
1327
+ "Y": 19,
1328
+ "V": 20,
1329
+ "B": 21,
1330
+ "Z": 22,
1331
+ "X": 23,
1332
+ "*": 24,
1333
+ "-": 25,
1334
+ "?": 26,
1335
+ }
1336
+ return [_aa2int[i] for i in seq]
1337
+
1338
+ # Create empty array to fill
1339
+ oh_array = np.zeros(shape=(26, len(input_seq)))
1340
+
1341
+ # Load placements from aa2int
1342
+ aa_placement = aa2int(input_seq)
1343
+
1344
+ # One-hot encode the amino acid sequence:
1345
+ for i in range(0, len(aa_placement)):
1346
+ oh_array[aa_placement[i], i] = 1
1347
+ i += 1
1348
+
1349
+ oh_array = [oh_array]
1350
+ x = np.array(np.transpose(oh_array))
1351
+
1352
+ y = x.astype(np.float32)
1353
+
1354
+ y = np.reshape(y, (y.shape[0], 1, 26))
1355
+
1356
+ # Start ICOR session using model.
1357
+ sess = rt.InferenceSession(model_path)
1358
+ input_name = sess.get_inputs()[0].name
1359
+
1360
+ # Get prediction:
1361
+ pred_onx = sess.run(None, {input_name: y})
1362
+
1363
+ # Get the index of the highest probability from softmax output:
1364
+ pred_indices = []
1365
+ for pred in pred_onx[0]:
1366
+ pred_indices.append(np.argmax(pred))
1367
+
1368
+ out_str = ""
1369
+ for index in pred_indices:
1370
+ out_str += labels[index]
1371
+
1372
+ return out_str
CodonTransformer/CodonUtils.py ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: CodonUtils.py
3
+ ---------------------
4
+ Includes constants and helper functions used by other Python scripts.
5
+ """
6
+
7
+ import itertools
8
+ import json
9
+ import os
10
+ import pickle
11
+ import re
12
+ from abc import ABC, abstractmethod
13
+ from dataclasses import dataclass
14
+ from typing import Any, Dict, Iterator, List, Optional, Tuple
15
+
16
+ import pandas as pd
17
+ import requests
18
+ import torch
19
+
20
+ # List of all amino acids
21
+ AMINO_ACIDS: List[str] = [
22
+ "A", # Alanine
23
+ "C", # Cysteine
24
+ "D", # Aspartic acid
25
+ "E", # Glutamic acid
26
+ "F", # Phenylalanine
27
+ "G", # Glycine
28
+ "H", # Histidine
29
+ "I", # Isoleucine
30
+ "K", # Lysine
31
+ "L", # Leucine
32
+ "M", # Methionine
33
+ "N", # Asparagine
34
+ "P", # Proline
35
+ "Q", # Glutamine
36
+ "R", # Arginine
37
+ "S", # Serine
38
+ "T", # Threonine
39
+ "V", # Valine
40
+ "W", # Tryptophan
41
+ "Y", # Tyrosine
42
+ ]
43
+ STOP_SYMBOLS = ["_", "*"] # Stop codon symbols
44
+
45
+ # Dictionary ambiguous amino acids to standard amino acids
46
+ AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = {
47
+ "B": ["N", "D"], # Asparagine (N) or Aspartic acid (D)
48
+ "Z": ["Q", "E"], # Glutamine (Q) or Glutamic acid (E)
49
+ "X": ["A"], # Any amino acid (typically replaced with Alanine)
50
+ "J": ["L", "I"], # Leucine (L) or Isoleucine (I)
51
+ "U": ["C"], # Selenocysteine (typically replaced with Cysteine)
52
+ "O": ["K"], # Pyrrolysine (typically replaced with Lysine)
53
+ }
54
+
55
+ # List of all possible start and stop codons
56
+ START_CODONS: List[str] = ["ATG", "TTG", "CTG", "GTG"]
57
+ STOP_CODONS: List[str] = ["TAA", "TAG", "TGA"]
58
+
59
+ # Token-to-index mapping for amino acids and special tokens
60
+ TOKEN2INDEX: Dict[str, int] = {
61
+ "[UNK]": 0,
62
+ "[CLS]": 1,
63
+ "[SEP]": 2,
64
+ "[PAD]": 3,
65
+ "[MASK]": 4,
66
+ "a_unk": 5,
67
+ "c_unk": 6,
68
+ "d_unk": 7,
69
+ "e_unk": 8,
70
+ "f_unk": 9,
71
+ "g_unk": 10,
72
+ "h_unk": 11,
73
+ "i_unk": 12,
74
+ "k_unk": 13,
75
+ "l_unk": 14,
76
+ "m_unk": 15,
77
+ "n_unk": 16,
78
+ "p_unk": 17,
79
+ "q_unk": 18,
80
+ "r_unk": 19,
81
+ "s_unk": 20,
82
+ "t_unk": 21,
83
+ "v_unk": 22,
84
+ "w_unk": 23,
85
+ "y_unk": 24,
86
+ "__unk": 25,
87
+ "k_aaa": 26,
88
+ "n_aac": 27,
89
+ "k_aag": 28,
90
+ "n_aat": 29,
91
+ "t_aca": 30,
92
+ "t_acc": 31,
93
+ "t_acg": 32,
94
+ "t_act": 33,
95
+ "r_aga": 34,
96
+ "s_agc": 35,
97
+ "r_agg": 36,
98
+ "s_agt": 37,
99
+ "i_ata": 38,
100
+ "i_atc": 39,
101
+ "m_atg": 40,
102
+ "i_att": 41,
103
+ "q_caa": 42,
104
+ "h_cac": 43,
105
+ "q_cag": 44,
106
+ "h_cat": 45,
107
+ "p_cca": 46,
108
+ "p_ccc": 47,
109
+ "p_ccg": 48,
110
+ "p_cct": 49,
111
+ "r_cga": 50,
112
+ "r_cgc": 51,
113
+ "r_cgg": 52,
114
+ "r_cgt": 53,
115
+ "l_cta": 54,
116
+ "l_ctc": 55,
117
+ "l_ctg": 56,
118
+ "l_ctt": 57,
119
+ "e_gaa": 58,
120
+ "d_gac": 59,
121
+ "e_gag": 60,
122
+ "d_gat": 61,
123
+ "a_gca": 62,
124
+ "a_gcc": 63,
125
+ "a_gcg": 64,
126
+ "a_gct": 65,
127
+ "g_gga": 66,
128
+ "g_ggc": 67,
129
+ "g_ggg": 68,
130
+ "g_ggt": 69,
131
+ "v_gta": 70,
132
+ "v_gtc": 71,
133
+ "v_gtg": 72,
134
+ "v_gtt": 73,
135
+ "__taa": 74,
136
+ "y_tac": 75,
137
+ "__tag": 76,
138
+ "y_tat": 77,
139
+ "s_tca": 78,
140
+ "s_tcc": 79,
141
+ "s_tcg": 80,
142
+ "s_tct": 81,
143
+ "__tga": 82,
144
+ "c_tgc": 83,
145
+ "w_tgg": 84,
146
+ "c_tgt": 85,
147
+ "l_tta": 86,
148
+ "f_ttc": 87,
149
+ "l_ttg": 88,
150
+ "f_ttt": 89,
151
+ }
152
+
153
+ # Index-to-token mapping, reverse of TOKEN2INDEX
154
+ INDEX2TOKEN: Dict[int, str] = {i: c for c, i in TOKEN2INDEX.items()}
155
+
156
+ # Dictionary mapping each codon to its GC content
157
+ CODON_GC_CONTENT: Dict[str, int] = {
158
+ token.split("_")[1]: token.split("_")[1].upper().count("G") + token.split("_")[1].upper().count("C")
159
+ for token in TOKEN2INDEX
160
+ if "_" in token and len(token.split("_")[1]) == 3
161
+ }
162
+
163
+ # Tensor with GC counts for each token in the vocabulary
164
+ GC_COUNTS_PER_TOKEN = torch.zeros(len(TOKEN2INDEX))
165
+ for token, index in TOKEN2INDEX.items():
166
+ if "_" in token and len(token.split("_")[1]) == 3:
167
+ codon = token.split("_")[1].upper()
168
+ gc_count = codon.count("G") + codon.count("C")
169
+ GC_COUNTS_PER_TOKEN[index] = gc_count
170
+
171
+ G_indices = [idx for token, idx in TOKEN2INDEX.items() if "g" in token.split("_")[-1]]
172
+ C_indices = [idx for token, idx in TOKEN2INDEX.items() if "c" in token.split("_")[-1]]
173
+
174
+ # Dictionary mapping each amino acid and stop symbol to indices of codon tokens that translate to it
175
+ AMINO_ACID_TO_INDEX = {
176
+ aa: sorted(
177
+ [i for t, i in TOKEN2INDEX.items() if t[0].upper() == aa and t[-3:] != "unk"]
178
+ )
179
+ for aa in (AMINO_ACIDS + STOP_SYMBOLS)
180
+ }
181
+
182
+
183
+ # Dictionary mapping each amino acid to min/max GC content across all possible codons
184
+ AA_MIN_GC: Dict[str, int] = {}
185
+ AA_MAX_GC: Dict[str, int] = {}
186
+
187
+ for aa, token_indices in AMINO_ACID_TO_INDEX.items():
188
+ if token_indices: # Skip if no tokens for this amino acid
189
+ gc_counts = []
190
+ for token_idx in token_indices:
191
+ token = INDEX2TOKEN[token_idx]
192
+ if "_" in token and len(token.split("_")[1]) == 3:
193
+ codon = token.split("_")[1]
194
+ if codon in CODON_GC_CONTENT:
195
+ gc_counts.append(CODON_GC_CONTENT[codon])
196
+
197
+ if gc_counts:
198
+ AA_MIN_GC[aa] = min(gc_counts)
199
+ AA_MAX_GC[aa] = max(gc_counts)
200
+
201
+ # Mask token mapping
202
+ TOKEN2MASK: Dict[int, int] = {
203
+ 0: 0,
204
+ 1: 1,
205
+ 2: 2,
206
+ 3: 3,
207
+ 4: 4,
208
+ 5: 5,
209
+ 6: 6,
210
+ 7: 7,
211
+ 8: 8,
212
+ 9: 9,
213
+ 10: 10,
214
+ 11: 11,
215
+ 12: 12,
216
+ 13: 13,
217
+ 14: 14,
218
+ 15: 15,
219
+ 16: 16,
220
+ 17: 17,
221
+ 18: 18,
222
+ 19: 19,
223
+ 20: 20,
224
+ 21: 21,
225
+ 22: 22,
226
+ 23: 23,
227
+ 24: 24,
228
+ 25: 25,
229
+ 26: 13,
230
+ 27: 16,
231
+ 28: 13,
232
+ 29: 16,
233
+ 30: 21,
234
+ 31: 21,
235
+ 32: 21,
236
+ 33: 21,
237
+ 34: 19,
238
+ 35: 20,
239
+ 36: 19,
240
+ 37: 20,
241
+ 38: 12,
242
+ 39: 12,
243
+ 40: 15,
244
+ 41: 12,
245
+ 42: 18,
246
+ 43: 11,
247
+ 44: 18,
248
+ 45: 11,
249
+ 46: 17,
250
+ 47: 17,
251
+ 48: 17,
252
+ 49: 17,
253
+ 50: 19,
254
+ 51: 19,
255
+ 52: 19,
256
+ 53: 19,
257
+ 54: 14,
258
+ 55: 14,
259
+ 56: 14,
260
+ 57: 14,
261
+ 58: 8,
262
+ 59: 7,
263
+ 60: 8,
264
+ 61: 7,
265
+ 62: 5,
266
+ 63: 5,
267
+ 64: 5,
268
+ 65: 5,
269
+ 66: 10,
270
+ 67: 10,
271
+ 68: 10,
272
+ 69: 10,
273
+ 70: 22,
274
+ 71: 22,
275
+ 72: 22,
276
+ 73: 22,
277
+ 74: 25,
278
+ 75: 24,
279
+ 76: 25,
280
+ 77: 24,
281
+ 78: 20,
282
+ 79: 20,
283
+ 80: 20,
284
+ 81: 20,
285
+ 82: 25,
286
+ 83: 6,
287
+ 84: 23,
288
+ 85: 6,
289
+ 86: 14,
290
+ 87: 9,
291
+ 88: 14,
292
+ 89: 9,
293
+ }
294
+
295
+ # List of organisms used for fine-tuning
296
+ FINE_TUNE_ORGANISMS: List[str] = [
297
+ "Arabidopsis thaliana",
298
+ "Bacillus subtilis",
299
+ "Caenorhabditis elegans",
300
+ "Chlamydomonas reinhardtii",
301
+ "Chlamydomonas reinhardtii chloroplast",
302
+ "Danio rerio",
303
+ "Drosophila melanogaster",
304
+ "Homo sapiens",
305
+ "Mus musculus",
306
+ "Nicotiana tabacum",
307
+ "Nicotiana tabacum chloroplast",
308
+ "Pseudomonas putida",
309
+ "Saccharomyces cerevisiae",
310
+ "Escherichia coli O157-H7 str. Sakai",
311
+ "Escherichia coli general",
312
+ "Escherichia coli str. K-12 substr. MG1655",
313
+ "Thermococcus barophilus MPT",
314
+ ]
315
+
316
+ # List of organisms most commonly used for coodn optimization
317
+ COMMON_ORGANISMS: List[str] = [
318
+ "Arabidopsis thaliana",
319
+ "Bacillus subtilis",
320
+ "Caenorhabditis elegans",
321
+ "Chlamydomonas reinhardtii",
322
+ "Danio rerio",
323
+ "Drosophila melanogaster",
324
+ "Homo sapiens",
325
+ "Mus musculus",
326
+ "Nicotiana tabacum",
327
+ "Pseudomonas putida",
328
+ "Saccharomyces cerevisiae",
329
+ "Escherichia coli general",
330
+ ]
331
+
332
+ # Dictionary mapping each organism name to respective organism id
333
+ ORGANISM2ID: Dict[str, int] = {
334
+ "Arabidopsis thaliana": 0,
335
+ "Atlantibacter hermannii": 1,
336
+ "Bacillus subtilis": 2,
337
+ "Brenneria goodwinii": 3,
338
+ "Buchnera aphidicola (Schizaphis graminum)": 4,
339
+ "Caenorhabditis elegans": 5,
340
+ "Candidatus Erwinia haradaeae": 6,
341
+ "Candidatus Hamiltonella defensa 5AT (Acyrthosiphon pisum)": 7,
342
+ "Chlamydomonas reinhardtii": 8,
343
+ "Chlamydomonas reinhardtii chloroplast": 9,
344
+ "Citrobacter amalonaticus": 10,
345
+ "Citrobacter braakii": 11,
346
+ "Citrobacter cronae": 12,
347
+ "Citrobacter europaeus": 13,
348
+ "Citrobacter farmeri": 14,
349
+ "Citrobacter freundii": 15,
350
+ "Citrobacter koseri ATCC BAA-895": 16,
351
+ "Citrobacter portucalensis": 17,
352
+ "Citrobacter werkmanii": 18,
353
+ "Citrobacter youngae": 19,
354
+ "Cronobacter dublinensis subsp. dublinensis LMG 23823": 20,
355
+ "Cronobacter malonaticus LMG 23826": 21,
356
+ "Cronobacter sakazakii": 22,
357
+ "Cronobacter turicensis": 23,
358
+ "Danio rerio": 24,
359
+ "Dickeya dadantii 3937": 25,
360
+ "Dickeya dianthicola": 26,
361
+ "Dickeya fangzhongdai": 27,
362
+ "Dickeya solani": 28,
363
+ "Dickeya zeae": 29,
364
+ "Drosophila melanogaster": 30,
365
+ "Edwardsiella anguillarum ET080813": 31,
366
+ "Edwardsiella ictaluri": 32,
367
+ "Edwardsiella piscicida": 33,
368
+ "Edwardsiella tarda": 34,
369
+ "Enterobacter asburiae": 35,
370
+ "Enterobacter bugandensis": 36,
371
+ "Enterobacter cancerogenus": 37,
372
+ "Enterobacter chengduensis": 38,
373
+ "Enterobacter cloacae": 39,
374
+ "Enterobacter hormaechei": 40,
375
+ "Enterobacter kobei": 41,
376
+ "Enterobacter ludwigii": 42,
377
+ "Enterobacter mori": 43,
378
+ "Enterobacter quasiroggenkampii": 44,
379
+ "Enterobacter roggenkampii": 45,
380
+ "Enterobacter sichuanensis": 46,
381
+ "Erwinia amylovora CFBP1430": 47,
382
+ "Erwinia persicina": 48,
383
+ "Escherichia albertii": 49,
384
+ "Escherichia coli O157-H7 str. Sakai": 50,
385
+ "Escherichia coli general": 51,
386
+ "Escherichia coli str. K-12 substr. MG1655": 52,
387
+ "Escherichia fergusonii": 53,
388
+ "Escherichia marmotae": 54,
389
+ "Escherichia ruysiae": 55,
390
+ "Ewingella americana": 56,
391
+ "Hafnia alvei": 57,
392
+ "Hafnia paralvei": 58,
393
+ "Homo sapiens": 59,
394
+ "Kalamiella piersonii": 60,
395
+ "Klebsiella aerogenes": 61,
396
+ "Klebsiella grimontii": 62,
397
+ "Klebsiella michiganensis": 63,
398
+ "Klebsiella oxytoca": 64,
399
+ "Klebsiella pasteurii": 65,
400
+ "Klebsiella pneumoniae subsp. pneumoniae HS11286": 66,
401
+ "Klebsiella quasipneumoniae": 67,
402
+ "Klebsiella quasivariicola": 68,
403
+ "Klebsiella variicola": 69,
404
+ "Kosakonia cowanii": 70,
405
+ "Kosakonia radicincitans": 71,
406
+ "Leclercia adecarboxylata": 72,
407
+ "Lelliottia amnigena": 73,
408
+ "Lonsdalea populi": 74,
409
+ "Moellerella wisconsensis": 75,
410
+ "Morganella morganii": 76,
411
+ "Mus musculus": 77,
412
+ "Nicotiana tabacum": 78,
413
+ "Nicotiana tabacum chloroplast": 79,
414
+ "Obesumbacterium proteus": 80,
415
+ "Pantoea agglomerans": 81,
416
+ "Pantoea allii": 82,
417
+ "Pantoea ananatis PA13": 83,
418
+ "Pantoea dispersa": 84,
419
+ "Pantoea stewartii": 85,
420
+ "Pantoea vagans": 86,
421
+ "Pectobacterium aroidearum": 87,
422
+ "Pectobacterium atrosepticum": 88,
423
+ "Pectobacterium brasiliense": 89,
424
+ "Pectobacterium carotovorum": 90,
425
+ "Pectobacterium odoriferum": 91,
426
+ "Pectobacterium parmentieri": 92,
427
+ "Pectobacterium polaris": 93,
428
+ "Pectobacterium versatile": 94,
429
+ "Photorhabdus laumondii subsp. laumondii TTO1": 95,
430
+ "Plesiomonas shigelloides": 96,
431
+ "Pluralibacter gergoviae": 97,
432
+ "Proteus faecis": 98,
433
+ "Proteus mirabilis HI4320": 99,
434
+ "Proteus penneri": 100,
435
+ "Proteus terrae subsp. cibarius": 101,
436
+ "Proteus vulgaris": 102,
437
+ "Providencia alcalifaciens": 103,
438
+ "Providencia heimbachae": 104,
439
+ "Providencia rettgeri": 105,
440
+ "Providencia rustigianii": 106,
441
+ "Providencia stuartii": 107,
442
+ "Providencia thailandensis": 108,
443
+ "Pseudomonas putida": 109,
444
+ "Pyrococcus furiosus": 110,
445
+ "Pyrococcus horikoshii": 111,
446
+ "Pyrococcus yayanosii": 112,
447
+ "Rahnella aquatilis CIP 78.65 = ATCC 33071": 113,
448
+ "Raoultella ornithinolytica": 114,
449
+ "Raoultella planticola": 115,
450
+ "Raoultella terrigena": 116,
451
+ "Rosenbergiella epipactidis": 117,
452
+ "Rouxiella badensis": 118,
453
+ "Saccharolobus solfataricus": 119,
454
+ "Saccharomyces cerevisiae": 120,
455
+ "Salmonella bongori N268-08": 121,
456
+ "Salmonella enterica subsp. enterica serovar Typhimurium str. LT2": 122,
457
+ "Serratia bockelmannii": 123,
458
+ "Serratia entomophila": 124,
459
+ "Serratia ficaria": 125,
460
+ "Serratia fonticola": 126,
461
+ "Serratia grimesii": 127,
462
+ "Serratia liquefaciens": 128,
463
+ "Serratia marcescens": 129,
464
+ "Serratia nevei": 130,
465
+ "Serratia plymuthica AS9": 131,
466
+ "Serratia proteamaculans": 132,
467
+ "Serratia quinivorans": 133,
468
+ "Serratia rubidaea": 134,
469
+ "Serratia ureilytica": 135,
470
+ "Shigella boydii": 136,
471
+ "Shigella dysenteriae": 137,
472
+ "Shigella flexneri 2a str. 301": 138,
473
+ "Shigella sonnei": 139,
474
+ "Thermoccoccus kodakarensis": 140,
475
+ "Thermococcus barophilus MPT": 141,
476
+ "Thermococcus chitonophagus": 142,
477
+ "Thermococcus gammatolerans": 143,
478
+ "Thermococcus litoralis": 144,
479
+ "Thermococcus onnurineus": 145,
480
+ "Thermococcus sibiricus": 146,
481
+ "Xenorhabdus bovienii str. feltiae Florida": 147,
482
+ "Yersinia aldovae 670-83": 148,
483
+ "Yersinia aleksiciae": 149,
484
+ "Yersinia alsatica": 150,
485
+ "Yersinia enterocolitica": 151,
486
+ "Yersinia frederiksenii ATCC 33641": 152,
487
+ "Yersinia intermedia": 153,
488
+ "Yersinia kristensenii": 154,
489
+ "Yersinia massiliensis CCUG 53443": 155,
490
+ "Yersinia mollaretii ATCC 43969": 156,
491
+ "Yersinia pestis A1122": 157,
492
+ "Yersinia proxima": 158,
493
+ "Yersinia pseudotuberculosis IP 32953": 159,
494
+ "Yersinia rochesterensis": 160,
495
+ "Yersinia rohdei": 161,
496
+ "Yersinia ruckeri": 162,
497
+ "Yokenella regensburgei": 163,
498
+ }
499
+
500
+ # Dictionary mapping each organism id to respective organism name
501
+ ID2ORGANISM = {v: k for k, v in ORGANISM2ID.items()}
502
+
503
+ # Type alias for amino acid to codon mapping
504
+ AMINO2CODON_TYPE = Dict[str, Tuple[List[str], List[float]]]
505
+
506
+ # Constants for the number of organisms and sequence lengths
507
+ NUM_ORGANISMS = 164
508
+ MAX_LEN = 2048
509
+ MAX_AMINO_ACIDS = MAX_LEN - 2 # Without special tokens [CLS] and [SEP]
510
+ STOP_SYMBOL = "_"
511
+
512
+
513
+ @dataclass
514
+ class DNASequencePrediction:
515
+ """
516
+ A class to hold the output of the DNA sequence prediction.
517
+
518
+ Attributes:
519
+ organism (str): Name of the organism used for prediction.
520
+ protein (str): Input protein sequence for which DNA sequence is predicted.
521
+ processed_input (str): Processed input sequence (merged protein and DNA).
522
+ predicted_dna (str): Predicted DNA sequence.
523
+ """
524
+
525
+ organism: str
526
+ protein: str
527
+ processed_input: str
528
+ predicted_dna: str
529
+
530
+
531
+ class IterableData(torch.utils.data.IterableDataset):
532
+ """
533
+ Defines the logic for iterable datasets (working over streams of
534
+ data) in parallel multi-processing environments, e.g., multi-GPU.
535
+
536
+ Args:
537
+ dist_env (Optional[str]): The distribution environment identifier
538
+ (e.g., "slurm").
539
+
540
+ Credit: Guillaume Filion
541
+ """
542
+
543
+ def __init__(self, dist_env: Optional[str] = None):
544
+ super().__init__()
545
+ if dist_env is None:
546
+ self.world_size_handle, self.rank_handle = ("WORLD_SIZE", "LOCAL_RANK")
547
+ else:
548
+ self.world_size_handle, self.rank_handle = {
549
+ "slurm": ("SLURM_NTASKS", "SLURM_PROCID")
550
+ }.get(dist_env, ("WORLD_SIZE", "LOCAL_RANK"))
551
+
552
+ @property
553
+ def iterator(self) -> Iterator:
554
+ """Define the stream logic for the dataset. Implement in subclasses."""
555
+ raise NotImplementedError
556
+
557
+ def __iter__(self) -> Iterator:
558
+ """
559
+ Create an iterator for the dataset, handling multi-processing contexts.
560
+
561
+ Returns:
562
+ Iterator: The iterator for the dataset.
563
+ """
564
+ worker_info = torch.utils.data.get_worker_info()
565
+ if worker_info is None:
566
+ return self.iterator
567
+
568
+ # In multi-processing context, use 'os.environ' to
569
+ # find global worker rank. Then use 'islice' to allocate
570
+ # the items of the stream to the workers.
571
+ world_size = int(os.environ.get(self.world_size_handle, "1"))
572
+ global_rank = int(os.environ.get(self.rank_handle, "0"))
573
+ local_rank = worker_info.id
574
+ local_num_workers = worker_info.num_workers
575
+
576
+ # Assume that each process has the same number of local workers.
577
+ worker_rk = global_rank * local_num_workers + local_rank
578
+ worker_nb = world_size * local_num_workers
579
+ return itertools.islice(self.iterator, worker_rk, None, worker_nb)
580
+
581
+
582
+ class IterableJSONData(IterableData):
583
+ """
584
+ Iterate over the lines of a JSON file and uncompress if needed.
585
+
586
+ Args:
587
+ data_path (str): The path to the JSON data file.
588
+ train (bool): Flag indicating if the dataset is for training.
589
+ **kwargs: Additional keyword arguments for the base class.
590
+ """
591
+
592
+ def __init__(self, data_path: str, train: bool = True, **kwargs):
593
+ super().__init__(**kwargs)
594
+ self.data_path = data_path
595
+ self.train = train
596
+ with open(os.path.join(self.data_path, "finetune_set.json"), "r") as f:
597
+ self.records = [json.loads(line) for line in f]
598
+
599
+ def __len__(self):
600
+ return len(self.records)
601
+
602
+ @property
603
+ def iterator(self) -> Iterator:
604
+ """Define the stream logic for the dataset."""
605
+ for record in self.records:
606
+ yield record
607
+
608
+
609
+ class ConfigManager(ABC):
610
+ """
611
+ Abstract base class for managing configuration settings.
612
+ """
613
+ _config: Dict[str, Any]
614
+
615
+ def __enter__(self):
616
+ return self
617
+
618
+ def __exit__(self, exc_type, exc_value, traceback):
619
+ if exc_type is not None:
620
+ print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}")
621
+ self.reset_config()
622
+
623
+ @abstractmethod
624
+ def reset_config(self) -> None:
625
+ """Reset the configuration to default values."""
626
+ pass
627
+
628
+ def get(self, key: str) -> Any:
629
+ """
630
+ Get the value of a configuration key.
631
+
632
+ Args:
633
+ key (str): The key to retrieve the value for.
634
+
635
+ Returns:
636
+ Any: The value of the configuration key.
637
+ """
638
+ return self._config.get(key)
639
+
640
+ def set(self, key: str, value: Any) -> None:
641
+ """
642
+ Set the value of a configuration key.
643
+
644
+ Args:
645
+ key (str): The key to set the value for.
646
+ value (Any): The value to set for the key.
647
+ """
648
+ self.validate_inputs(key, value)
649
+ self._config[key] = value
650
+
651
+ def update(self, config_dict: dict) -> None:
652
+ """
653
+ Update the configuration with a dictionary of key-value pairs after validating them.
654
+
655
+ Args:
656
+ config_dict (dict): A dictionary of key-value pairs to update the configuration.
657
+ """
658
+ for key, value in config_dict.items():
659
+ self.validate_inputs(key, value)
660
+ self._config.update(config_dict)
661
+
662
+ @abstractmethod
663
+ def validate_inputs(self, key: str, value: Any) -> None:
664
+ """Validate the inputs for the configuration."""
665
+ pass
666
+
667
+
668
+ class ProteinConfig(ConfigManager):
669
+ """
670
+ A class to manage configuration settings for protein sequences.
671
+
672
+ This class ensures that the configuration is a singleton.
673
+ It provides methods to get, set, and update configuration values.
674
+
675
+ Attributes:
676
+ _instance (Optional[ConfigManager]): The singleton instance of the ConfigManager.
677
+ _config (Dict[str, Any]): The configuration dictionary.
678
+ """
679
+
680
+ _instance = None
681
+
682
+ def __new__(cls):
683
+ """
684
+ Create a new instance of the ProteinConfig class.
685
+
686
+ Returns:
687
+ ProteinConfig: The singleton instance of the ProteinConfig.
688
+ """
689
+ if cls._instance is None:
690
+ cls._instance = super(ProteinConfig, cls).__new__(cls)
691
+ cls._instance.reset_config()
692
+ return cls._instance
693
+
694
+ def validate_inputs(self, key: str, value: Any) -> None:
695
+ """
696
+ Validate the inputs for the configuration.
697
+
698
+ Args:
699
+ key (str): The key to validate.
700
+ value (Any): The value to validate.
701
+
702
+ Raises:
703
+ ValueError: If the value is invalid.
704
+ TypeError: If the value is of the wrong type.
705
+ """
706
+ if key == "ambiguous_aminoacid_behavior":
707
+ if value not in [
708
+ "raise_error",
709
+ "standardize_deterministic",
710
+ "standardize_random",
711
+ ]:
712
+ raise ValueError(
713
+ f"Invalid value for ambiguous_aminoacid_behavior: {value}."
714
+ )
715
+ elif key == "ambiguous_aminoacid_map_override":
716
+ if not isinstance(value, dict):
717
+ raise TypeError(
718
+ f"Invalid type for ambiguous_aminoacid_map_override: {value}."
719
+ )
720
+ for ambiguous_aminoacid, aminoacids in value.items():
721
+ if not isinstance(aminoacids, list):
722
+ raise TypeError(f"Invalid type for aminoacids: {aminoacids}.")
723
+ if not aminoacids:
724
+ raise ValueError(
725
+ f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list."
726
+ )
727
+ if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP:
728
+ raise ValueError(
729
+ f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}"
730
+ )
731
+ else:
732
+ raise ValueError(f"Invalid configuration key: {key}")
733
+
734
+ def reset_config(self) -> None:
735
+ """
736
+ Reset the configuration to the default values.
737
+ """
738
+ self._config = {
739
+ "ambiguous_aminoacid_behavior": "standardize_random",
740
+ "ambiguous_aminoacid_map_override": {},
741
+ }
742
+
743
+
744
+ def load_python_object_from_disk(file_path: str) -> Any:
745
+ """
746
+ Load a Pickle object from disk and return it as a Python object.
747
+
748
+ Args:
749
+ file_path (str): The path to the Pickle file.
750
+
751
+ Returns:
752
+ Any: The loaded Python object.
753
+ """
754
+ with open(file_path, "rb") as file:
755
+ return pickle.load(file)
756
+
757
+
758
+ def save_python_object_to_disk(input_object: Any, file_path: str) -> None:
759
+ """
760
+ Save a Python object to disk using Pickle.
761
+
762
+ Args:
763
+ input_object (Any): The Python object to save.
764
+ file_path (str): The path where the object will be saved.
765
+ """
766
+ with open(file_path, "wb") as file:
767
+ pickle.dump(input_object, file)
768
+
769
+
770
+ def find_pattern_in_fasta(keyword: str, text: str) -> str:
771
+ """
772
+ Find a specific keyword pattern in text. Helpful for identifying parts
773
+ of a FASTA sequence.
774
+
775
+ Args:
776
+ keyword (str): The keyword pattern to search for.
777
+ text (str): The text to search within.
778
+
779
+ Returns:
780
+ str: The found pattern or an empty string if not found.
781
+ """
782
+ # Search for the keyword pattern in the text using regex
783
+ result = re.search(keyword + r"=(.*?)]", text)
784
+ return result.group(1) if result else ""
785
+
786
+
787
+ def get_organism2id_dict(organism_reference: str) -> Dict[str, int]:
788
+ """
789
+ Return a dictionary mapping each organism in training data to an index
790
+ used for training.
791
+
792
+ Args:
793
+ organism_reference (str): Path to a CSV file containing a list of
794
+ all organisms. The format of the CSV file should be as follows:
795
+
796
+ 0,Escherichia coli
797
+ 1,Homo sapiens
798
+ 2,Mus musculus
799
+
800
+ Returns:
801
+ Dict[str, int]: Dictionary mapping organism names to their respective indices.
802
+ """
803
+ # Read the CSV file and create a dictionary mapping organisms to their indices
804
+ organisms = pd.read_csv(organism_reference, index_col=0, header=None)
805
+ organism2id = {organisms.iloc[i].values[0]: i for i in organisms.index}
806
+
807
+ return organism2id
808
+
809
+
810
+ def get_taxonomy_id(
811
+ taxonomy_reference: str, organism: Optional[str] = None, return_dict: bool = False
812
+ ) -> Any:
813
+ """
814
+ Return the taxonomy id of a given organism using a reference file.
815
+ Optionally, return the whole dictionary instead if return_dict is True.
816
+
817
+ Args:
818
+ taxonomy_reference (str): Path to the taxonomy reference file.
819
+ organism (Optional[str]): The name of the organism to look up.
820
+ return_dict (bool): Whether to return the entire dictionary.
821
+
822
+ Returns:
823
+ Any: The taxonomy id of the organism or the entire dictionary.
824
+ """
825
+ # Load the organism-to-taxonomy mapping from a Pickle file
826
+ organism2taxonomy = load_python_object_from_disk(taxonomy_reference)
827
+
828
+ if return_dict:
829
+ return dict(sorted(organism2taxonomy.items()))
830
+
831
+ return organism2taxonomy[organism]
832
+
833
+
834
+ def sort_amino2codon_skeleton(amino2codon: Dict[str, Any]) -> Dict[str, Any]:
835
+ """
836
+ Sort the amino2codon dictionary alphabetically by amino acid and by codon name.
837
+
838
+ Args:
839
+ amino2codon (Dict[str, Any]): The amino2codon dictionary to sort.
840
+
841
+ Returns:
842
+ Dict[str, Any]: The sorted amino2codon dictionary.
843
+ """
844
+ # Sort the dictionary by amino acid and then by codon name
845
+ amino2codon = dict(sorted(amino2codon.items()))
846
+ amino2codon = {
847
+ amino: (
848
+ [codon for codon, _ in sorted(zip(codons, frequencies))],
849
+ [freq for _, freq in sorted(zip(codons, frequencies))],
850
+ )
851
+ for amino, (codons, frequencies) in amino2codon.items()
852
+ }
853
+
854
+ return amino2codon
855
+
856
+
857
+ def load_pkl_from_url(url: str) -> Any:
858
+ """
859
+ Download a Pickle file from a URL and return the loaded object.
860
+
861
+ Args:
862
+ url (str): The URL to download the Pickle file from.
863
+
864
+ Returns:
865
+ Any: The loaded Python object from the Pickle file.
866
+ """
867
+ response = requests.get(url)
868
+ response.raise_for_status() # Ensure the request was successful
869
+
870
+ # Load the Pickle object from the response content
871
+ return pickle.loads(response.content)
CodonTransformer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """CodonTransformer package."""
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ ENV PYTHONDONTWRITEBYTECODE=1
4
+ ENV PYTHONUNBUFFERED=1
5
+ ENV PIP_NO_CACHE_DIR=1
6
+
7
+ WORKDIR /app
8
+
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ git \
11
+ build-essential \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ COPY requirements.txt /app/requirements.txt
15
+ RUN pip install --upgrade pip && pip install -r /app/requirements.txt
16
+
17
+ COPY . /app
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true"]
ENCOT_Academic_Documentation.html ADDED
@@ -0,0 +1,2625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>ENCOT: Enhanced Codon Optimization Tool - Technical Documentation</title>
7
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/atom-one-light.min.css">
8
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
9
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
10
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/yaml.min.js"></script>
11
+ <link href="https://fonts.googleapis.com/css2?family=Computer+Modern+Serif:wght@400;700&family=Computer+Modern+Sans:wght@400;700&family=Computer+Modern+Typewriter&display=swap" rel="stylesheet">
12
+ <style>
13
+ /* LaTeX-inspired Academic Styling */
14
+ @import url('https://fonts.googleapis.com/css2?family=Crimson+Text:ital,wght@0,400;0,600;0,700;1,400&family=Source+Code+Pro:wght@400;500&display=swap');
15
+
16
+ * {
17
+ margin: 0;
18
+ padding: 0;
19
+ box-sizing: border-box;
20
+ }
21
+
22
+ body {
23
+ font-family: 'Crimson Text', 'Georgia', serif;
24
+ line-height: 1.6;
25
+ color: #2c3e50;
26
+ background: #f8f9fa;
27
+ padding: 40px;
28
+ max-width: 900px;
29
+ margin: 0 auto;
30
+ font-size: 11pt;
31
+ }
32
+
33
+ /* Academic Paper Header */
34
+ .paper-header {
35
+ text-align: center;
36
+ margin-bottom: 50px;
37
+ padding: 30px 0;
38
+ border-bottom: 2px solid #2c3e50;
39
+ }
40
+
41
+ .paper-header h1 {
42
+ font-size: 28pt;
43
+ font-weight: 700;
44
+ margin-bottom: 20px;
45
+ color: #1a1a1a;
46
+ letter-spacing: -0.5px;
47
+ }
48
+
49
+ .paper-header .subtitle {
50
+ font-size: 14pt;
51
+ font-style: italic;
52
+ color: #555;
53
+ margin-bottom: 25px;
54
+ }
55
+
56
+ .paper-header .authors {
57
+ font-size: 11pt;
58
+ color: #444;
59
+ margin-bottom: 10px;
60
+ }
61
+
62
+ .paper-header .affiliation {
63
+ font-size: 10pt;
64
+ color: #666;
65
+ font-style: italic;
66
+ }
67
+
68
+ /* Section Styling */
69
+ .section {
70
+ margin: 40px 0;
71
+ page-break-inside: avoid;
72
+ background: white;
73
+ padding: 25px;
74
+ border: 1px solid #ddd;
75
+ box-shadow: 0 1px 3px rgba(0,0,0,0.05);
76
+ }
77
+
78
+ .section-number {
79
+ font-weight: 700;
80
+ color: #2c3e50;
81
+ font-size: 14pt;
82
+ }
83
+
84
+ .section-title {
85
+ font-size: 16pt;
86
+ font-weight: 700;
87
+ color: #2c3e50;
88
+ margin: 15px 0 20px 0;
89
+ border-bottom: 1px solid #ccc;
90
+ padding-bottom: 8px;
91
+ }
92
+
93
+ .abstract, .description {
94
+ text-align: justify;
95
+ margin: 15px 0;
96
+ text-indent: 0;
97
+ hyphens: auto;
98
+ }
99
+
100
+ .abstract {
101
+ font-size: 10.5pt;
102
+ padding: 15px;
103
+ background: #f9f9f9;
104
+ border-left: 3px solid #3498db;
105
+ font-style: italic;
106
+ }
107
+
108
+ /* Code Blocks - LaTeX Listing Style */
109
+ .code-container {
110
+ margin: 20px 0;
111
+ border: 1px solid #ccc;
112
+ background: #fafafa;
113
+ }
114
+
115
+ .code-header {
116
+ background: #e8e8e8;
117
+ padding: 8px 15px;
118
+ border-bottom: 1px solid #ccc;
119
+ font-family: 'Source Code Pro', monospace;
120
+ font-size: 9pt;
121
+ color: #555;
122
+ }
123
+
124
+ .listing-number {
125
+ font-weight: 600;
126
+ color: #2c3e50;
127
+ }
128
+
129
+ pre {
130
+ margin: 0;
131
+ padding: 15px;
132
+ overflow-x: auto;
133
+ background: white;
134
+ border: none;
135
+ }
136
+
137
+ pre code {
138
+ font-family: 'Source Code Pro', 'Courier New', monospace;
139
+ font-size: 9pt;
140
+ line-height: 1.4;
141
+ color: #2c3e50;
142
+ }
143
+
144
+ /* Annotations and Highlights */
145
+ .annotation {
146
+ background: #fff3cd;
147
+ border-left: 4px solid #ffc107;
148
+ padding: 12px 15px;
149
+ margin: 15px 0;
150
+ font-size: 10pt;
151
+ }
152
+
153
+ .annotation strong {
154
+ color: #856404;
155
+ }
156
+
157
+ .key-concept {
158
+ background: #d1ecf1;
159
+ border-left: 4px solid #0c5460;
160
+ padding: 12px 15px;
161
+ margin: 15px 0;
162
+ font-size: 10pt;
163
+ }
164
+
165
+ .mathematical {
166
+ font-family: 'Crimson Text', serif;
167
+ font-style: italic;
168
+ text-align: center;
169
+ padding: 15px;
170
+ margin: 20px 0;
171
+ background: #f9f9f9;
172
+ border: 1px solid #ddd;
173
+ font-size: 11pt;
174
+ }
175
+
176
+ /* File References */
177
+ .file-ref {
178
+ font-family: 'Source Code Pro', monospace;
179
+ font-size: 9pt;
180
+ color: #2c3e50;
181
+ background: #f4f4f4;
182
+ padding: 8px 12px;
183
+ border-left: 3px solid #3498db;
184
+ margin: 15px 0;
185
+ }
186
+
187
+ .file-path {
188
+ font-weight: 600;
189
+ color: #2980b9;
190
+ }
191
+
192
+ /* Handwritten-style Notes */
193
+ .handwritten-note {
194
+ border: 2px dashed #95a5a6;
195
+ padding: 15px;
196
+ margin: 20px 0;
197
+ background: #fef9e7;
198
+ font-size: 10pt;
199
+ position: relative;
200
+ }
201
+
202
+ .handwritten-note::before {
203
+ content: "✏️ Important Note:";
204
+ font-weight: 600;
205
+ color: #7f8c8d;
206
+ display: block;
207
+ margin-bottom: 8px;
208
+ }
209
+
210
+ /* Algorithm/Pseudocode Box */
211
+ .algorithm-box {
212
+ border: 2px solid #2c3e50;
213
+ padding: 20px;
214
+ margin: 20px 0;
215
+ background: white;
216
+ }
217
+
218
+ .algorithm-title {
219
+ font-weight: 700;
220
+ text-align: center;
221
+ margin-bottom: 15px;
222
+ font-size: 11pt;
223
+ text-transform: uppercase;
224
+ letter-spacing: 1px;
225
+ }
226
+
227
+ .algorithm-content {
228
+ font-family: 'Source Code Pro', monospace;
229
+ font-size: 9.5pt;
230
+ line-height: 1.8;
231
+ }
232
+
233
+ /* Equation Styling */
234
+ .equation {
235
+ text-align: center;
236
+ margin: 25px 0;
237
+ font-size: 12pt;
238
+ font-family: 'Crimson Text', serif;
239
+ }
240
+
241
+ .equation-label {
242
+ float: right;
243
+ font-size: 10pt;
244
+ color: #7f8c8d;
245
+ }
246
+
247
+ /* Table Styling */
248
+ table {
249
+ width: 100%;
250
+ border-collapse: collapse;
251
+ margin: 20px 0;
252
+ font-size: 10pt;
253
+ }
254
+
255
+ th, td {
256
+ border: 1px solid #bbb;
257
+ padding: 8px 12px;
258
+ text-align: left;
259
+ }
260
+
261
+ th {
262
+ background: #ecf0f1;
263
+ font-weight: 600;
264
+ }
265
+
266
+ /* Footer */
267
+ .footer {
268
+ margin-top: 50px;
269
+ padding-top: 20px;
270
+ border-top: 1px solid #ccc;
271
+ text-align: center;
272
+ font-size: 9pt;
273
+ color: #7f8c8d;
274
+ }
275
+
276
+ /* Print Styles - Optimized for minimal spacing */
277
+ @page {
278
+ size: A4;
279
+ margin: 1.2cm 1.5cm;
280
+ }
281
+
282
+ @page :first {
283
+ margin-top: 1.5cm;
284
+ }
285
+
286
+ @media print {
287
+ * {
288
+ -webkit-print-color-adjust: exact !important;
289
+ print-color-adjust: exact !important;
290
+ }
291
+
292
+ body {
293
+ background: white;
294
+ padding: 0;
295
+ margin: 0;
296
+ font-size: 9.5pt;
297
+ line-height: 1.35;
298
+ }
299
+
300
+ /* Minimize margins */
301
+ .paper-header {
302
+ margin-bottom: 15px;
303
+ padding: 10px 0;
304
+ page-break-after: avoid;
305
+ }
306
+
307
+ .paper-header h1 {
308
+ font-size: 20pt;
309
+ margin-bottom: 8px;
310
+ }
311
+
312
+ .paper-header .subtitle {
313
+ font-size: 10pt;
314
+ margin: 3px 0;
315
+ }
316
+
317
+ .abstract {
318
+ margin: 12px 0;
319
+ padding: 10px;
320
+ page-break-after: avoid;
321
+ page-break-inside: avoid;
322
+ }
323
+
324
+ /* Section optimization - ALLOW BREAKS */
325
+ .section {
326
+ box-shadow: none;
327
+ border: none;
328
+ padding: 8px 10px;
329
+ margin: 5px 0;
330
+ page-break-inside: auto; /* Changed from avoid */
331
+ background: white;
332
+ }
333
+
334
+ .section-title {
335
+ font-size: 12pt;
336
+ margin-bottom: 6px;
337
+ page-break-after: avoid;
338
+ }
339
+
340
+ .description {
341
+ margin: 5px 0;
342
+ font-size: 9.5pt;
343
+ line-height: 1.35;
344
+ }
345
+
346
+ /* Code containers - allow breaks */
347
+ .code-container {
348
+ page-break-inside: auto;
349
+ margin: 8px 0;
350
+ padding: 6px;
351
+ border: 1px solid #ccc;
352
+ }
353
+
354
+ .code-header {
355
+ padding: 4px 6px;
356
+ margin-bottom: 4px;
357
+ page-break-after: avoid;
358
+ font-size: 9pt;
359
+ }
360
+
361
+ pre {
362
+ margin: 0;
363
+ padding: 6px;
364
+ font-size: 7.5pt;
365
+ line-height: 1.25;
366
+ white-space: pre-wrap;
367
+ word-wrap: break-word;
368
+ }
369
+
370
+ code {
371
+ font-size: 7.5pt;
372
+ line-height: 1.25;
373
+ }
374
+
375
+ /* File references */
376
+ .file-ref {
377
+ margin: 5px 0;
378
+ padding: 4px 6px;
379
+ font-size: 8.5pt;
380
+ page-break-inside: avoid;
381
+ }
382
+
383
+ .file-path {
384
+ font-size: 8.5pt;
385
+ }
386
+
387
+ /* Mathematical content */
388
+ .mathematical {
389
+ margin: 8px 0;
390
+ padding: 6px;
391
+ font-size: 9.5pt;
392
+ page-break-inside: avoid;
393
+ }
394
+
395
+ .equation {
396
+ margin: 8px 0;
397
+ font-size: 10pt;
398
+ }
399
+
400
+ /* Key concepts and notes */
401
+ .key-concept {
402
+ margin: 8px 0;
403
+ padding: 6px;
404
+ font-size: 9pt;
405
+ page-break-inside: avoid;
406
+ }
407
+
408
+ .key-concept ul {
409
+ margin: 4px 0 0 12px;
410
+ }
411
+
412
+ .key-concept li {
413
+ margin: 2px 0;
414
+ line-height: 1.25;
415
+ }
416
+
417
+ .handwritten-note {
418
+ margin: 8px 0;
419
+ padding: 6px;
420
+ font-size: 8.5pt;
421
+ page-break-inside: avoid;
422
+ }
423
+
424
+ .handwritten-note::before {
425
+ margin-bottom: 4px;
426
+ }
427
+
428
+ /* Algorithm boxes */
429
+ .algorithm-box {
430
+ margin: 8px 0;
431
+ padding: 8px;
432
+ page-break-inside: auto; /* Allow break for long algorithms */
433
+ }
434
+
435
+ .algorithm-title {
436
+ font-size: 10pt;
437
+ margin-bottom: 6px;
438
+ }
439
+
440
+ .algorithm-content {
441
+ font-size: 8pt;
442
+ line-height: 1.4;
443
+ }
444
+
445
+ /* Tables */
446
+ table {
447
+ margin: 8px 0;
448
+ font-size: 8.5pt;
449
+ page-break-inside: auto;
450
+ }
451
+
452
+ th, td {
453
+ padding: 4px 6px;
454
+ font-size: 8.5pt;
455
+ }
456
+
457
+ /* Page break control */
458
+ h1, h2, h3, .section-title {
459
+ page-break-after: avoid;
460
+ }
461
+
462
+ .section:first-of-type {
463
+ page-break-before: avoid;
464
+ }
465
+
466
+ /* Keep title with at least some content */
467
+ .section-title + .description,
468
+ .code-header + pre {
469
+ page-break-before: avoid;
470
+ }
471
+
472
+ /* Hide unnecessary elements */
473
+ .footer {
474
+ display: none;
475
+ }
476
+
477
+ /* Compact spacing for lists */
478
+ ul, ol {
479
+ margin: 4px 0;
480
+ padding-left: 18px;
481
+ }
482
+
483
+ li {
484
+ margin: 1px 0;
485
+ line-height: 1.25;
486
+ }
487
+
488
+ /* Orphan and widow control */
489
+ p, .description, .key-concept, .handwritten-note {
490
+ orphans: 2;
491
+ widows: 2;
492
+ }
493
+
494
+ /* Reduce all vertical spacing */
495
+ * + * {
496
+ margin-top: 0 !important;
497
+ }
498
+ }
499
+ </style>
500
+ </head>
501
+ <body>
502
+
503
+ <!-- Academic Paper Header -->
504
+ <div class="paper-header">
505
+ <h1>ENCOT: Enhanced Codon Optimization Tool</h1>
506
+ <div class="subtitle">
507
+ A Transformer-Based Approach with Augmented-Lagrangian Method<br>
508
+ for Multi-Objective Codon Optimization in E. coli
509
+ </div>
510
+ <div class="authors">
511
+ Technical Implementation Documentation
512
+ </div>
513
+
514
+ </div>
515
+
516
+ <!-- Abstract -->
517
+ <div class="abstract">
518
+ <strong>Abstract:</strong> This document presents the technical implementation of ENCOT, a novel codon optimization
519
+ system that employs transformer-based deep learning combined with an Augmented-Lagrangian Method (ALM) for
520
+ precise control of GC content. The system optimizes multiple biological objectives simultaneously including
521
+ Codon Adaptation Index (CAI), tRNA Adaptation Index (tAI), GC content balance, and minimization of negative
522
+ cis-regulatory elements. The implementation builds upon the CodonTransformer architecture and introduces
523
+ innovative constraint optimization techniques for enhanced E. coli expression systems.
524
+ </div>
525
+
526
+ <!-- Section 1: Core Algorithm - ALM Implementation -->
527
+ <div class="section">
528
+ <div class="section-title">
529
+ <span class="section-number">1.</span> Augmented-Lagrangian Method Implementation
530
+ </div>
531
+
532
+ <div class="description">
533
+ The core innovation of ENCOT lies in its application of the Augmented-Lagrangian Method to enforce
534
+ GC content constraints during training. This approach allows the model to balance multiple optimization
535
+ objectives while maintaining biologically appropriate GC content levels.
536
+ </div>
537
+
538
+ <div class="mathematical">
539
+ <strong>Objective Function:</strong><br><br>
540
+ <i>L</i> = <i>L</i><sub>MLM</sub> + λ·(<i>GC</i> − μ) + (ρ/2)·(<i>GC</i> − μ)²
541
+ <div class="equation-label">(Eq. 1)</div>
542
+ </div>
543
+
544
+ <div class="key-concept">
545
+ <strong>Key Components:</strong>
546
+ <ul style="margin: 10px 0 0 20px;">
547
+ <li><i>L<sub>MLM</sub></i>: Masked Language Modeling loss for codon prediction</li>
548
+ <li>λ: Lagrangian multiplier (adaptively updated)</li>
549
+ <li>ρ: Penalty coefficient (self-tuning based on progress)</li>
550
+ <li><i>GC</i>: Mean GC content of predicted sequences</li>
551
+ <li>μ: Target GC content (0.52 for E. coli)</li>
552
+ </ul>
553
+ </div>
554
+
555
+ <div class="file-ref">
556
+ <div class="file-path">File: finetune.py</div>
557
+ Lines 73-148 | Class: plTrainHarness
558
+ </div>
559
+
560
+ <div class="code-container">
561
+ <div class="code-header">
562
+ <span class="listing-number">Listing 1:</span> ALM Training Harness - Initialization
563
+ </div>
564
+ <pre><code class="language-python">class plTrainHarness(pl.LightningModule):
565
+ """
566
+ PyTorch Lightning training harness for ENCOT with Augmented-Lagrangian
567
+ Method (ALM) GC control.
568
+
569
+ This class implements the training loop for fine-tuning CodonTransformer
570
+ on E. coli sequences with precise GC content control using an
571
+ Augmented-Lagrangian Method. The ALM approach allows the model to learn
572
+ codon preferences while maintaining GC content within a target range.
573
+
574
+ Key features:
575
+ - Masked language modeling (MLM) loss for codon prediction
576
+ - ALM-based GC content constraint enforcement
577
+ - Curriculum learning: warm-up epochs before enforcing GC constraints
578
+ - Adaptive penalty coefficient (rho) adjustment based on constraint
579
+ violation progress
580
+
581
+ The ALM method minimizes:
582
+ L = L_MLM + λ·(GC - μ) + (ρ/2)(GC - μ)²
583
+ where λ is the Lagrangian multiplier and ρ is the penalty coefficient.
584
+ """
585
+
586
+ def __init__(self, model, learning_rate, warmup_fraction,
587
+ gc_penalty_weight, tokenizer, gc_target=0.52,
588
+ use_lagrangian=False, lagrangian_rho=10.0,
589
+ curriculum_epochs=3, alm_tolerance=1e-5,
590
+ alm_dual_tolerance=1e-5, alm_penalty_update_factor=10.0,
591
+ alm_initial_penalty_factor=20.0,
592
+ alm_tolerance_update_factor=0.1,
593
+ alm_rel_penalty_increase_threshold=0.1,
594
+ alm_max_penalty=1e6, alm_min_penalty=1e-6):
595
+ super().__init__()
596
+ self.model = model
597
+ self.learning_rate = learning_rate
598
+ self.warmup_fraction = warmup_fraction
599
+ self.gc_penalty_weight = gc_penalty_weight
600
+ self.tokenizer = tokenizer
601
+
602
+ # Augmented-Lagrangian GC Control parameters
603
+ self.gc_target = gc_target
604
+ self.use_lagrangian = use_lagrangian
605
+ self.lagrangian_rho = lagrangian_rho
606
+ self.curriculum_epochs = curriculum_epochs
607
+
608
+ # Enhanced ALM parameters
609
+ self.alm_tolerance = alm_tolerance
610
+ self.alm_dual_tolerance = alm_dual_tolerance
611
+ self.alm_penalty_update_factor = alm_penalty_update_factor
612
+ self.alm_initial_penalty_factor = alm_initial_penalty_factor
613
+ self.alm_tolerance_update_factor = alm_tolerance_update_factor
614
+ self.alm_rel_penalty_increase_threshold = \
615
+ alm_rel_penalty_increase_threshold
616
+ self.alm_max_penalty = alm_max_penalty
617
+ self.alm_min_penalty = alm_min_penalty
618
+
619
+ # Initialize Lagrangian multiplier as buffer
620
+ # (persists across checkpoints)
621
+ self.register_buffer("lambda_gc", torch.tensor(0.0))
622
+
623
+ # Adaptive penalty coefficient (rho)
624
+ self.register_buffer("rho_adaptive",
625
+ torch.tensor(self.lagrangian_rho))
626
+
627
+ # Step counter for periodic lambda updates
628
+ self.register_buffer("step_counter", torch.tensor(0))
629
+
630
+ # ALM convergence tracking
631
+ self.register_buffer("previous_constraint_violation",
632
+ torch.tensor(float('inf')))</code></pre>
633
+ </div>
634
+
635
+ <div class="handwritten-note">
636
+ The initialization sets up persistent buffers for Lagrangian multipliers and penalty coefficients.
637
+ These buffers are saved with model checkpoints, allowing training to resume seamlessly. The curriculum
638
+ learning approach waits for 3 epochs before enforcing GC constraints, giving the model time to learn
639
+ basic codon patterns first.
640
+ </div>
641
+ </div>
642
+
643
+ <!-- Section 2: Training Step -->
644
+ <div class="section">
645
+ <div class="section-title">
646
+ <span class="section-number">2.</span> Training Step with ALM Loss Computation
647
+ </div>
648
+
649
+ <div class="description">
650
+ The training step combines standard masked language modeling with the ALM-based GC constraint.
651
+ During each forward pass, we compute GC content from predicted tokens and apply the Lagrangian
652
+ penalty to guide the model toward the target GC content.
653
+ </div>
654
+
655
+ <div class="file-ref">
656
+ <div class="file-path">File: finetune.py</div>
657
+ Lines 150-230 | Method: training_step
658
+ </div>
659
+
660
+ <div class="code-container">
661
+ <div class="code-header">
662
+ <span class="listing-number">Listing 2:</span> Training Step with ALM Loss
663
+ </div>
664
+ <pre><code class="language-python">def training_step(self, batch, batch_idx):
665
+ """
666
+ Training step that computes MLM loss and applies ALM-based GC constraint.
667
+
668
+ The constraint is only enforced after curriculum_epochs warm-up period.
669
+ """
670
+ outputs = self.model(**batch)
671
+ mlm_loss = outputs.loss
672
+
673
+ # Enhanced Lagrangian-based GC penalty
674
+ if self.use_lagrangian and self.current_epoch >= self.curriculum_epochs:
675
+ # Compute GC content from logits
676
+ logits = outputs.logits
677
+ predicted_tokens = torch.argmax(logits, dim=-1)
678
+
679
+ # Calculate GC content per sequence
680
+ gc_content_batch = []
681
+ for seq_tokens in predicted_tokens:
682
+ # Filter to valid codon tokens (indices >= 26)
683
+ valid_tokens = seq_tokens[seq_tokens >= 26]
684
+ if len(valid_tokens) == 0:
685
+ gc_content_batch.append(self.gc_target)
686
+ continue
687
+
688
+ # Count G and C containing codons
689
+ gc_counts = sum(1 for token in valid_tokens
690
+ if token.item() in G_indices + C_indices)
691
+ gc_content = gc_counts / len(valid_tokens)
692
+ gc_content_batch.append(gc_content)
693
+
694
+ # Mean GC content across batch
695
+ gc_content_mean = sum(gc_content_batch) / len(gc_content_batch)
696
+
697
+ # Compute GC constraint violation
698
+ gc_constraint = gc_content_mean - self.gc_target
699
+
700
+ # Augmented Lagrangian loss term
701
+ lagrangian_loss = (
702
+ self.lambda_gc * gc_constraint +
703
+ (self.rho_adaptive / 2) * (gc_constraint ** 2)
704
+ )
705
+
706
+ total_loss = mlm_loss + lagrangian_loss
707
+
708
+ # Log metrics
709
+ self.log("train/mlm_loss", mlm_loss, prog_bar=True)
710
+ self.log("train/gc_constraint", gc_constraint, prog_bar=True)
711
+ self.log("train/lagrangian_loss", lagrangian_loss, prog_bar=False)
712
+ self.log("train/lambda_gc", self.lambda_gc, prog_bar=False)
713
+ self.log("train/rho", self.rho_adaptive, prog_bar=False)
714
+ self.log("train/gc_content", gc_content_mean, prog_bar=True)
715
+
716
+ # Update Lagrangian multiplier periodically
717
+ self.step_counter += 1
718
+ if self.step_counter % 20 == 0:
719
+ self._update_alm_parameters(gc_constraint)
720
+ else:
721
+ # During warm-up, only use MLM loss
722
+ total_loss = mlm_loss
723
+ self.log("train/mlm_loss", mlm_loss, prog_bar=True)
724
+
725
+ self.log("train/total_loss", total_loss, prog_bar=True)
726
+ return total_loss</code></pre>
727
+ </div>
728
+
729
+ <div class="annotation">
730
+ <strong>Implementation Detail:</strong> The GC content is computed from the argmax of logits rather than
731
+ from the actual target sequences. This allows the gradient to flow through the constraint, enabling the
732
+ model to learn to satisfy the constraint during generation.
733
+ </div>
734
+ </div>
735
+
736
+ <!-- Section 3: Adaptive Parameter Update -->
737
+ <div class="section">
738
+ <div class="section-title">
739
+ <span class="section-number">3.</span> Adaptive ALM Parameter Updates
740
+ </div>
741
+
742
+ <div class="description">
743
+ The self-tuning mechanism adjusts Lagrangian multipliers and penalty coefficients based on
744
+ constraint violation progress. This adaptive approach ensures convergence while maintaining
745
+ numerical stability.
746
+ </div>
747
+
748
+ <div class="algorithm-box">
749
+ <div class="algorithm-title">Algorithm 1: Adaptive Penalty Update</div>
750
+ <div class="algorithm-content">
751
+ <strong>Input:</strong> gc_constraint (current violation)<br>
752
+ <strong>Output:</strong> Updated λ_gc and ρ_adaptive<br><br>
753
+
754
+ 1. <strong>Compute</strong> relative_improvement ← <br>
755
+ &nbsp;&nbsp;&nbsp;(prev_violation - current_violation) / prev_violation<br><br>
756
+
757
+ 2. <strong>If</strong> |gc_constraint| ≤ tolerance <strong>then</strong><br>
758
+ &nbsp;&nbsp;&nbsp;λ_gc ← λ_gc + ρ · gc_constraint<br>
759
+ &nbsp;&nbsp;&nbsp;// Constraint satisfied, update multiplier only<br><br>
760
+
761
+ 3. <strong>Else if</strong> relative_improvement < threshold <strong>then</strong><br>
762
+ &nbsp;&nbsp;&nbsp;ρ ← min(ρ · update_factor, max_penalty)<br>
763
+ &nbsp;&nbsp;&nbsp;λ_gc ← λ_gc + ρ · gc_constraint<br>
764
+ &nbsp;&nbsp;&nbsp;// Insufficient progress, increase penalty<br><br>
765
+
766
+ 4. <strong>Else</strong><br>
767
+ &nbsp;&nbsp;&nbsp;λ_gc ← λ_gc + ρ · gc_constraint<br>
768
+ &nbsp;&nbsp;&nbsp;// Good progress, keep penalty stable<br><br>
769
+
770
+ 5. prev_violation ← |gc_constraint|
771
+ </div>
772
+ </div>
773
+
774
+ <div class="file-ref">
775
+ <div class="file-path">File: finetune.py</div>
776
+ Lines 260-320 | Method: _update_alm_parameters
777
+ </div>
778
+
779
+ <div class="code-container">
780
+ <div class="code-header">
781
+ <span class="listing-number">Listing 3:</span> Adaptive Parameter Update Implementation
782
+ </div>
783
+ <pre><code class="language-python">def _update_alm_parameters(self, gc_constraint):
784
+ """
785
+ Update Lagrangian multiplier and penalty coefficient according to ALM.
786
+
787
+ This implements the adaptive penalty update strategy:
788
+ - If constraint violation is decreasing sufficiently, update lambda
789
+ and keep rho
790
+ - If constraint violation is not improving, increase rho
791
+ (penalty coefficient)
792
+ """
793
+ constraint_violation = abs(gc_constraint.item())
794
+
795
+ # Check if we're making sufficient progress
796
+ relative_improvement = (
797
+ (self.previous_constraint_violation - constraint_violation) /
798
+ max(self.previous_constraint_violation, 1e-8)
799
+ )
800
+
801
+ if constraint_violation <= self.alm_tolerance:
802
+ # Constraint satisfied - update lambda, optionally reduce rho
803
+ self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
804
+ # Could reduce rho here if desired, but keeping it stable
805
+ # works well in practice
806
+
807
+ elif relative_improvement < self.alm_rel_penalty_increase_threshold:
808
+ # Not making enough progress - increase penalty
809
+ self.rho_adaptive = torch.clamp(
810
+ self.rho_adaptive * self.alm_penalty_update_factor,
811
+ min=self.alm_min_penalty,
812
+ max=self.alm_max_penalty
813
+ )
814
+ # Also update lambda
815
+ self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
816
+
817
+ else:
818
+ # Making good progress - just update lambda
819
+ self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
820
+
821
+ # Update tracking
822
+ self.previous_constraint_violation = torch.tensor(constraint_violation)</code></pre>
823
+ </div>
824
+
825
+ <div class="handwritten-note">
826
+ The key insight here is the relative improvement threshold. If the constraint violation isn't
827
+ improving by at least 10% (default threshold), we increase the penalty coefficient. This ensures
828
+ that the optimization doesn't get stuck in suboptimal regions where the constraint is consistently
829
+ violated.
830
+ </div>
831
+ </div>
832
+
833
+ <!-- Section 4: Prediction Function -->
834
+ <div class="section">
835
+ <div class="section-title">
836
+ <span class="section-number">4.</span> DNA Sequence Prediction with Constrained Search
837
+ </div>
838
+
839
+ <div class="description">
840
+ The prediction function supports multiple decoding strategies including deterministic (greedy),
841
+ stochastic (temperature sampling), and constrained beam search with GC bounds. This flexibility
842
+ allows users to balance between optimization quality and sequence diversity.
843
+ </div>
844
+
845
+ <div class="file-ref">
846
+ <div class="file-path">File: CodonTransformer/CodonPrediction.py</div>
847
+ Lines 38-120 | Function: predict_dna_sequence
848
+ </div>
849
+
850
+ <div class="code-container">
851
+ <div class="code-header">
852
+ <span class="listing-number">Listing 4:</span> Main Prediction Function Signature
853
+ </div>
854
+ <pre><code class="language-python">def predict_dna_sequence(
855
+ protein: str,
856
+ organism: Union[int, str],
857
+ device: torch.device,
858
+ tokenizer: Union[str, PreTrainedTokenizerFast] = None,
859
+ model: Union[str, torch.nn.Module] = None,
860
+ attention_type: str = "original_full",
861
+ deterministic: bool = True,
862
+ temperature: float = 0.2,
863
+ top_p: float = 0.95,
864
+ num_sequences: int = 1,
865
+ match_protein: bool = False,
866
+ use_constrained_search: bool = False,
867
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
868
+ beam_size: int = 5,
869
+ length_penalty: float = 1.0,
870
+ diversity_penalty: float = 0.0,
871
+ ) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
872
+ """
873
+ Predict the DNA sequence(s) for a given protein using ENCOT model.
874
+
875
+ This function takes a protein sequence and an organism (as ID or name)
876
+ as input and returns the predicted DNA sequence(s) using the ENCOT model.
877
+ It can use either provided tokenizer and model objects or load them from
878
+ specified paths.
879
+
880
+ Args:
881
+ protein (str): The input protein sequence for which to predict
882
+ the DNA sequence.
883
+ organism (Union[int, str]): Either the ID of the organism or its
884
+ name (e.g., "Escherichia coli general").
885
+ device (torch.device): The device (CPU or GPU) to run the model on.
886
+
887
+ deterministic (bool, optional): Whether to use deterministic decoding
888
+ (most likely tokens). If False, samples tokens according to their
889
+ probabilities adjusted by the temperature. Defaults to True.
890
+
891
+ temperature (float, optional): A value controlling the randomness of
892
+ predictions during non-deterministic decoding. Lower values
893
+ (e.g., 0.2) make the model more conservative, while higher values
894
+ (e.g., 0.8) increase randomness. Defaults to 0.2.
895
+
896
+ use_constrained_search (bool, optional): Enable constrained beam
897
+ search with GC bounds. Defaults to False.
898
+
899
+ gc_bounds (Tuple[float, float], optional): GC content bounds
900
+ (min, max) for constrained search. Defaults to (0.30, 0.70).
901
+
902
+ beam_size (int, optional): Beam size for beam search. Defaults to 5.
903
+
904
+ match_protein (bool, optional): Ensures the predicted DNA sequence
905
+ translates to the input protein sequence by sampling from only
906
+ the respective codons of each amino acid. Defaults to False.
907
+
908
+ Returns:
909
+ Union[DNASequencePrediction, List[DNASequencePrediction]]:
910
+ Predicted DNA sequence(s) with associated metrics.
911
+ """</code></pre>
912
+ </div>
913
+
914
+ <div class="key-concept">
915
+ <strong>Decoding Strategies:</strong>
916
+ <table style="margin-top: 15px;">
917
+ <tr>
918
+ <th>Strategy</th>
919
+ <th>Use Case</th>
920
+ <th>Parameters</th>
921
+ </tr>
922
+ <tr>
923
+ <td><strong>Greedy (deterministic)</strong></td>
924
+ <td>Production optimization</td>
925
+ <td>deterministic=True</td>
926
+ </tr>
927
+ <tr>
928
+ <td><strong>Temperature Sampling</strong></td>
929
+ <td>Diversity exploration</td>
930
+ <td>deterministic=False, temperature=0.2-0.8</td>
931
+ </tr>
932
+ <tr>
933
+ <td><strong>Constrained Beam Search</strong></td>
934
+ <td>GC-constrained optimization</td>
935
+ <td>use_constrained_search=True, gc_bounds=(0.45,0.55)</td>
936
+ </tr>
937
+ </table>
938
+ </div>
939
+ </div>
940
+
941
+ <!-- Section 5: Evaluation Metrics -->
942
+ <div class="section">
943
+ <div class="section-title">
944
+ <span class="section-number">5.</span> Evaluation Metrics Implementation
945
+ </div>
946
+
947
+ <div class="description">
948
+ ENCOT computes comprehensive metrics to evaluate the quality of optimized sequences. The primary
949
+ metrics are the Codon Adaptation Index (CAI) and tRNA Adaptation Index (tAI), which quantify how
950
+ well the codon usage matches highly expressed E. coli genes and available tRNA pools, respectively.
951
+ </div>
952
+
953
+ <div class="file-ref">
954
+ <div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
955
+ Lines 23-50, 370-420 | Functions: get_CSI_value, calculate_tAI
956
+ </div>
957
+
958
+ <div class="code-container">
959
+ <div class="code-header">
960
+ <span class="listing-number">Listing 5:</span> CAI and tAI Calculation
961
+ </div>
962
+ <pre><code class="language-python">def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
963
+ """
964
+ Calculate the Codon Similarity Index (CSI) weights for a list of
965
+ DNA sequences.
966
+
967
+ CSI is equivalent to CAI when computed from reference sequences.
968
+
969
+ Args:
970
+ sequences (List[str]): List of DNA sequences from highly expressed
971
+ genes.
972
+
973
+ Returns:
974
+ dict: The CSI weights (relative adaptiveness values per codon).
975
+ """
976
+ return relative_adaptiveness(sequences=sequences)
977
+
978
+
979
+ def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
980
+ """
981
+ Calculate the Codon Similarity Index (CSI) for a DNA sequence.
982
+
983
+ This is the CAI score computed using pre-calculated weights.
984
+
985
+ Args:
986
+ dna (str): The DNA sequence.
987
+ weights (dict): The CSI weights from get_CSI_weights.
988
+
989
+ Returns:
990
+ float: The CSI value (range 0-1, higher is better).
991
+ """
992
+ return CAI(dna, weights)
993
+
994
+
995
+ def get_ecoli_tai_weights():
996
+ """
997
+ Returns pre-calculated tAI weights for E. coli K-12 MG1655.
998
+
999
+ These weights are based on tRNA gene copy numbers and wobble base
1000
+ pairing rules. Higher weights indicate more available tRNA for
1001
+ that codon.
1002
+
1003
+ Returns:
1004
+ dict: Mapping from codon to tAI weight (0-1).
1005
+ """
1006
+ return {
1007
+ 'TTT': 0.58, 'TTC': 0.42, 'TTA': 0.13, 'TTG': 0.13,
1008
+ 'TCT': 0.15, 'TCC': 0.15, 'TCA': 0.12, 'TCG': 0.15,
1009
+ 'TAT': 0.59, 'TAC': 0.41, 'TGT': 0.46, 'TGC': 0.54,
1010
+ 'TGG': 1.00, 'CTT': 0.11, 'CTC': 0.10, 'CTA': 0.04,
1011
+ 'CTG': 0.49, 'CCT': 0.16, 'CCC': 0.12, 'CCA': 0.19,
1012
+ 'CCG': 0.52, 'CAT': 0.57, 'CAC': 0.43, 'CAA': 0.34,
1013
+ 'CAG': 0.66, 'ATT': 0.51, 'ATC': 0.42, 'ATA': 0.07,
1014
+ 'ATG': 1.00, 'ACT': 0.17, 'ACC': 0.44, 'ACA': 0.13,
1015
+ 'ACG': 0.27, 'AAT': 0.49, 'AAC': 0.51, 'AAA': 0.76,
1016
+ 'AAG': 0.24, 'AGT': 0.15, 'AGC': 0.28, 'AGA': 0.07,
1017
+ 'AGG': 0.04, 'GTT': 0.28, 'GTC': 0.20, 'GTA': 0.15,
1018
+ 'GTG': 0.37, 'GCT': 0.18, 'GCC': 0.27, 'GCA': 0.21,
1019
+ 'GCG': 0.36, 'GAT': 0.63, 'GAC': 0.37, 'GAA': 0.68,
1020
+ 'GAG': 0.32, 'GGT': 0.35, 'GGC': 0.40, 'GGA': 0.11,
1021
+ 'GGG': 0.15,
1022
+ }
1023
+
1024
+
1025
+ def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
1026
+ """
1027
+ Calculate the tRNA Adaptation Index (tAI) for a DNA sequence.
1028
+
1029
+ The tAI is the geometric mean of the tAI weights for all codons in
1030
+ the sequence (excluding stop codons).
1031
+
1032
+ Args:
1033
+ sequence (str): DNA sequence (must be divisible by 3)
1034
+ tai_weights (Dict[str, float]): tAI weights for each codon
1035
+
1036
+ Returns:
1037
+ float: Geometric mean of tAI weights (range 0-1)
1038
+ """
1039
+ if len(sequence) % 3 != 0:
1040
+ raise ValueError("Sequence length must be divisible by 3")
1041
+
1042
+ # Split into codons
1043
+ codons = [sequence[i:i+3].upper() for i in range(0, len(sequence), 3)]
1044
+
1045
+ # Get weights for non-stop codons
1046
+ weights = [tai_weights.get(codon, 0.5) for codon in codons
1047
+ if codon not in ['TAA', 'TAG', 'TGA']]
1048
+
1049
+ if not weights:
1050
+ return 0.0
1051
+
1052
+ # Compute geometric mean
1053
+ product = 1.0
1054
+ for w in weights:
1055
+ product *= w
1056
+ return product ** (1.0 / len(weights))</code></pre>
1057
+ </div>
1058
+
1059
+ <div class="annotation">
1060
+ <strong>Metric Interpretation:</strong> Both CAI and tAI range from 0 to 1, with higher values
1061
+ indicating better optimization. In practice, for E. coli:
1062
+ <ul style="margin: 10px 0 0 20px;">
1063
+ <li>CAI > 0.8 indicates excellent codon adaptation</li>
1064
+ <li>tAI > 0.4 suggests adequate tRNA availability</li>
1065
+ <li>Native E. coli genes typically have CAI around 0.65-0.75</li>
1066
+ </ul>
1067
+ </div>
1068
+ </div>
1069
+
1070
+ <!-- Section 6: Training Configuration -->
1071
+ <div class="section">
1072
+ <div class="section-title">
1073
+ <span class="section-number">6.</span> Training Configuration
1074
+ </div>
1075
+
1076
+ <div class="description">
1077
+ The training configuration specifies all hyperparameters including learning rate, batch size,
1078
+ and ALM-specific settings. This configuration reproduces the exact setup used in our experiments.
1079
+ </div>
1080
+
1081
+ <div class="file-ref">
1082
+ <div class="file-path">File: configs/train_ecoli_alm.yaml</div>
1083
+ Complete configuration file
1084
+ </div>
1085
+
1086
+ <div class="code-container">
1087
+ <div class="code-header">
1088
+ <span class="listing-number">Listing 6:</span> Complete Training Configuration
1089
+ </div>
1090
+ <pre><code class="language-yaml"># ENCOT ALM Training Configuration
1091
+ # This configuration reproduces the main training setup from the paper
1092
+ # using the Augmented-Lagrangian Method (ALM) for GC content control.
1093
+
1094
+ model:
1095
+ base_model: "adibvafa/CodonTransformer-base"
1096
+ tokenizer: "adibvafa/CodonTransformer"
1097
+
1098
+ data:
1099
+ dataset_dir: "data"
1100
+ # Expected files: finetune_set.json (created by preprocess_data.py)
1101
+
1102
+ training:
1103
+ batch_size: 6
1104
+ max_epochs: 15
1105
+ learning_rate: 5e-5
1106
+ warmup_fraction: 0.1
1107
+ num_workers: 5
1108
+ accumulate_grad_batches: 1
1109
+ num_gpus: 4
1110
+ save_every_n_steps: 512
1111
+ seed: 123
1112
+ log_every_n_steps: 20
1113
+
1114
+ checkpoint:
1115
+ checkpoint_dir: "models/alm-enhanced-training"
1116
+ checkpoint_filename: "balanced_alm_finetune.ckpt"
1117
+
1118
+ # Augmented-Lagrangian Method (ALM) for GC content control
1119
+ alm:
1120
+ enabled: true
1121
+ gc_target: 0.52 # Target GC content for E. coli (52%)
1122
+ curriculum_epochs: 3 # Warm-up epochs before enforcing GC constraint
1123
+
1124
+ # ALM penalty parameters
1125
+ initial_penalty_factor: 20.0
1126
+ penalty_update_factor: 10.0
1127
+ max_penalty: 1e6
1128
+ min_penalty: 1e-6
1129
+
1130
+ # ALM tolerance parameters
1131
+ tolerance: 1e-5 # Primal tolerance
1132
+ dual_tolerance: 1e-5 # Dual tolerance for constraint violation
1133
+ tolerance_update_factor: 0.1
1134
+
1135
+ # Adaptive penalty adjustment
1136
+ rel_penalty_increase_threshold: 0.1
1137
+
1138
+ # Legacy penalty method (if ALM disabled)
1139
+ gc_penalty:
1140
+ weight: 0.0 # Only used if use_lagrangian=false</code></pre>
1141
+ </div>
1142
+
1143
+ <div class="key-concept">
1144
+ <strong>Hyperparameter Selection Rationale:</strong>
1145
+ <table style="margin-top: 15px;">
1146
+ <tr>
1147
+ <th>Parameter</th>
1148
+ <th>Value</th>
1149
+ <th>Rationale</th>
1150
+ </tr>
1151
+ <tr>
1152
+ <td>gc_target</td>
1153
+ <td>0.52</td>
1154
+ <td>Native E. coli genome GC content</td>
1155
+ </tr>
1156
+ <tr>
1157
+ <td>curriculum_epochs</td>
1158
+ <td>3</td>
1159
+ <td>Allow basic pattern learning before constraint</td>
1160
+ </tr>
1161
+ <tr>
1162
+ <td>initial_penalty_factor</td>
1163
+ <td>20.0</td>
1164
+ <td>Moderate initial constraint enforcement</td>
1165
+ </tr>
1166
+ <tr>
1167
+ <td>penalty_update_factor</td>
1168
+ <td>10.0</td>
1169
+ <td>Aggressive adaptation for fast convergence</td>
1170
+ </tr>
1171
+ </table>
1172
+ </div>
1173
+ </div>
1174
+
1175
+ <!-- Section 7: Data Validation -->
1176
+ <div class="section">
1177
+ <div class="section-title">
1178
+ <span class="section-number">7.</span> Sequence Validation Pipeline
1179
+ </div>
1180
+
1181
+ <div class="description">
1182
+ Before training, all DNA sequences undergo rigorous validation to ensure biological correctness.
1183
+ Invalid sequences are filtered out to maintain data quality.
1184
+ </div>
1185
+
1186
+ <div class="file-ref">
1187
+ <div class="file-path">File: prepare_ecoli_data.py</div>
1188
+ Lines 5-30 | Function: is_valid_sequence
1189
+ </div>
1190
+
1191
+ <div class="code-container">
1192
+ <div class="code-header">
1193
+ <span class="listing-number">Listing 7:</span> Sequence Validation Function
1194
+ </div>
1195
+ <pre><code class="language-python">def is_valid_sequence(dna_seq: str) -> bool:
1196
+ """
1197
+ Applies a series of validation checks to a DNA sequence.
1198
+
1199
+ Validation criteria:
1200
+ 1. Length must be divisible by 3 (valid codon frame)
1201
+ 2. Must start with a valid start codon (ATG, TTG, CTG, or GTG)
1202
+ 3. Must end with a valid stop codon (TAA, TAG, or TGA)
1203
+ 4. Must not contain internal stop codons
1204
+ 5. Must contain only valid nucleotides (A, T, G, C)
1205
+
1206
+ Args:
1207
+ dna_seq (str): The DNA sequence to validate.
1208
+
1209
+ Returns:
1210
+ bool: True if the sequence passes all checks, False otherwise.
1211
+ """
1212
+ # Check 1: Valid codon frame
1213
+ if len(dna_seq) % 3 != 0:
1214
+ return False
1215
+
1216
+ # Check 2: Valid start codon
1217
+ if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
1218
+ return False
1219
+
1220
+ # Check 3: Valid stop codon
1221
+ if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
1222
+ return False
1223
+
1224
+ # Check 4: No internal stop codons (excluding the last codon)
1225
+ codons = [dna_seq[i:i+3].upper()
1226
+ for i in range(0, len(dna_seq) - 3, 3)]
1227
+ if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
1228
+ return False
1229
+
1230
+ # Check 5: Only valid nucleotides
1231
+ if not all(c in 'ATGC' for c in dna_seq.upper()):
1232
+ return False
1233
+
1234
+ return True</code></pre>
1235
+ </div>
1236
+
1237
+ <div class="handwritten-note">
1238
+ The validation function is intentionally strict to ensure high-quality training data. In our
1239
+ preprocessing of the E. coli genome, approximately 95% of sequences passed all validation checks.
1240
+ The most common reason for rejection was sequences with internal stop codons due to sequencing
1241
+ errors or pseudogenes.
1242
+ </div>
1243
+ </div>
1244
+
1245
+ <!-- Section 8: Benchmark Evaluation -->
1246
+ <div class="section">
1247
+ <div class="section-title">
1248
+ <span class="section-number">8.</span> Benchmark Evaluation Pipeline
1249
+ </div>
1250
+
1251
+ <div class="description">
1252
+ The benchmark pipeline evaluates ENCOT on a test set of protein sequences, computing multiple
1253
+ metrics for each optimized sequence and generating comprehensive performance reports.
1254
+ </div>
1255
+
1256
+ <div class="file-ref">
1257
+ <div class="file-path">File: benchmark_evaluation.py</div>
1258
+ Lines 300-400 | Function: benchmark_sequences
1259
+ </div>
1260
+
1261
+ <div class="code-container">
1262
+ <div class="code-header">
1263
+ <span class="listing-number">Listing 8:</span> Benchmark Evaluation Function
1264
+ </div>
1265
+ <pre><code class="language-python">def benchmark_sequences(sequences, model, tokenizer, device,
1266
+ cai_weights, tai_weights):
1267
+ """
1268
+ Run ENCOT on protein sequences and compute metrics for optimized DNA.
1269
+
1270
+ Args:
1271
+ sequences: List of (name, protein) tuples to optimize
1272
+ model: Loaded ENCOT model
1273
+ tokenizer: Tokenizer for the model
1274
+ device: PyTorch device (CPU/GPU)
1275
+ cai_weights: Pre-computed CAI weights from reference sequences
1276
+ tai_weights: Pre-computed tAI weights for E. coli
1277
+
1278
+ Returns:
1279
+ DataFrame with columns: name, protein, optimized_dna, CAI, tAI,
1280
+ GC_content, negative_cis_elements
1281
+ """
1282
+ results = []
1283
+
1284
+ for name, protein in tqdm(sequences, desc="Optimizing sequences"):
1285
+ # Optimize the sequence using ENCOT
1286
+ output = predict_dna_sequence(
1287
+ protein=protein,
1288
+ organism="Escherichia coli general",
1289
+ device=device,
1290
+ model=model,
1291
+ tokenizer=tokenizer,
1292
+ deterministic=True,
1293
+ use_constrained_search=True,
1294
+ gc_bounds=(0.45, 0.55) # E. coli optimal range
1295
+ )
1296
+
1297
+ optimized_dna = output.predicted_dna
1298
+
1299
+ # Calculate comprehensive metrics
1300
+ cai = get_CSI_value(optimized_dna, cai_weights)
1301
+ tai = calculate_tAI(optimized_dna, tai_weights)
1302
+ gc_content = get_GC_content(optimized_dna)
1303
+ cis_elements = count_negative_cis_elements(optimized_dna)
1304
+ homopolymers = calculate_homopolymer_runs(optimized_dna)
1305
+
1306
+ results.append({
1307
+ 'name': name,
1308
+ 'protein': protein,
1309
+ 'optimized_dna': optimized_dna,
1310
+ 'length': len(optimized_dna),
1311
+ 'CAI': cai,
1312
+ 'tAI': tai,
1313
+ 'GC_content': gc_content,
1314
+ 'negative_cis_elements': cis_elements,
1315
+ 'max_homopolymer_length': homopolymers
1316
+ })
1317
+
1318
+ return pd.DataFrame(results)</code></pre>
1319
+ </div>
1320
+
1321
+ <div class="key-concept">
1322
+ <strong>Benchmark Metrics Summary:</strong>
1323
+ <ul style="margin: 10px 0 0 20px;">
1324
+ <li><strong>CAI:</strong> Measures codon usage similarity to highly expressed genes</li>
1325
+ <li><strong>tAI:</strong> Quantifies tRNA availability for translation</li>
1326
+ <li><strong>GC Content:</strong> Should be near 52% for E. coli</li>
1327
+ <li><strong>Negative cis-elements:</strong> Count of problematic regulatory sequences</li>
1328
+ <li><strong>Homopolymers:</strong> Long runs that cause synthesis issues</li>
1329
+ </ul>
1330
+ </div>
1331
+ </div>
1332
+
1333
+ <!-- Section 9: Usage Example -->
1334
+ <div class="section">
1335
+ <div class="section-title">
1336
+ <span class="section-number">9.</span> Complete Usage Example
1337
+ </div>
1338
+
1339
+ <div class="description">
1340
+ This example demonstrates a complete workflow: loading the model, optimizing a sequence, and
1341
+ evaluating the results. This is the recommended pattern for production use.
1342
+ </div>
1343
+
1344
+ <div class="code-container">
1345
+ <div class="code-header">
1346
+ <span class="listing-number">Listing 9:</span> End-to-End Optimization Workflow
1347
+ </div>
1348
+ <pre><code class="language-python">#!/usr/bin/env python3
1349
+ """
1350
+ Complete workflow example for ENCOT codon optimization.
1351
+ """
1352
+
1353
+ import torch
1354
+ from transformers import AutoTokenizer
1355
+ from CodonTransformer.CodonPrediction import load_model, predict_dna_sequence
1356
+ from CodonTransformer.CodonEvaluation import (
1357
+ get_GC_content, calculate_tAI, get_CSI_value,
1358
+ get_ecoli_tai_weights, count_negative_cis_elements
1359
+ )
1360
+ from CAI import relative_adaptiveness
1361
+ from huggingface_hub import hf_hub_download
1362
+
1363
+ # Step 1: Setup device and load model
1364
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1365
+ print(f"Using device: {device}")
1366
+
1367
+ # Download model from HuggingFace
1368
+ checkpoint_path = hf_hub_download(
1369
+ repo_id="saketh11/ColiFormer",
1370
+ filename="balanced_alm_finetune.ckpt",
1371
+ cache_dir="./hf_cache"
1372
+ )
1373
+
1374
+ model = load_model(model_path=checkpoint_path, device=device)
1375
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
1376
+
1377
+ # Step 2: Define protein to optimize
1378
+ protein = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
1379
+ print(f"Input protein ({len(protein)} aa): {protein}")
1380
+
1381
+ # Step 3: Optimize the sequence
1382
+ print("\nOptimizing...")
1383
+ output = predict_dna_sequence(
1384
+ protein=protein,
1385
+ organism="Escherichia coli general",
1386
+ device=device,
1387
+ model=model,
1388
+ tokenizer=tokenizer,
1389
+ deterministic=True,
1390
+ match_protein=True,
1391
+ use_constrained_search=True,
1392
+ gc_bounds=(0.45, 0.55),
1393
+ beam_size=20
1394
+ )
1395
+
1396
+ optimized_dna = output.predicted_dna
1397
+ print(f"Optimized DNA ({len(optimized_dna)} bp): {optimized_dna[:60]}...")
1398
+
1399
+ # Step 4: Evaluate metrics
1400
+ print("\nComputing metrics...")
1401
+
1402
+ # Load reference weights
1403
+ tai_weights = get_ecoli_tai_weights()
1404
+
1405
+ # For CAI, we need reference sequences (use E. coli highly expressed genes)
1406
+ # In practice, load from your reference dataset
1407
+ reference_sequences = load_reference_sequences() # Your function
1408
+ cai_weights = relative_adaptiveness(reference_sequences)
1409
+
1410
+ # Calculate metrics
1411
+ cai = get_CSI_value(optimized_dna, cai_weights)
1412
+ tai = calculate_tAI(optimized_dna, tai_weights)
1413
+ gc = get_GC_content(optimized_dna)
1414
+ cis = count_negative_cis_elements(optimized_dna)
1415
+
1416
+ # Step 5: Report results
1417
+ print("\n" + "="*50)
1418
+ print("OPTIMIZATION RESULTS")
1419
+ print("="*50)
1420
+ print(f"CAI (Codon Adaptation Index): {cai:.4f}")
1421
+ print(f"tAI (tRNA Adaptation Index): {tai:.4f}")
1422
+ print(f"GC Content: {gc:.2f}%")
1423
+ print(f"Negative cis-regulatory elements: {cis}")
1424
+ print("="*50)
1425
+
1426
+ # Step 6: Verify translation
1427
+ from Bio.Seq import Seq
1428
+ translated = str(Seq(optimized_dna).translate())
1429
+ assert translated == protein, "Translation mismatch!"
1430
+ print("\n✓ Optimized DNA correctly translates to input protein")</code></pre>
1431
+ </div>
1432
+ </div>
1433
+
1434
+ <!-- Section 11: Constrained Beam Search -->
1435
+ <div class="section">
1436
+ <div class="section-title">
1437
+ <span class="section-number">11.</span> Constrained Beam Search Implementation
1438
+ </div>
1439
+
1440
+ <div class="description">
1441
+ The constrained beam search algorithm ensures that generated DNA sequences maintain GC content within specified bounds. This method prunes candidates that violate constraints during generation, improving efficiency compared to post-hoc filtering.
1442
+ </div>
1443
+
1444
+ <div class="file-ref">
1445
+ <div class="file-path">File: CodonTransformer/CodonPrediction.py</div>
1446
+ Lines 850-950 | Function: _constrained_beam_search()
1447
+ </div>
1448
+
1449
+ <div class="code-container">
1450
+ <div class="code-header">
1451
+ <span class="listing-number">Listing 11:</span> Constrained Beam Search Core
1452
+ </div>
1453
+ <pre><code class="language-python">def _constrained_beam_search(model, input_ids, attention_mask,
1454
+ beam_size, gc_bounds, max_len, device):
1455
+ """
1456
+ Constrained beam search that enforces GC content bounds during generation.
1457
+
1458
+ Args:
1459
+ model: CodonTransformer model
1460
+ input_ids: Tokenized input [batch_size, seq_len]
1461
+ attention_mask: Attention mask
1462
+ beam_size: Number of candidates to maintain
1463
+ gc_bounds: (min_gc, max_gc) tuple for GC content
1464
+ max_len: Maximum sequence length
1465
+ device: torch device
1466
+
1467
+ Returns:
1468
+ Best sequence satisfying GC constraints
1469
+ """
1470
+ batch_size = input_ids.size(0)
1471
+ min_gc, max_gc = gc_bounds
1472
+
1473
+ # Initialize beams: (sequence, score, gc_count, length)
1474
+ beams = [(input_ids[0].clone(), 0.0, 0, 0)]
1475
+
1476
+ for step in range(max_len):
1477
+ all_candidates = []
1478
+
1479
+ for seq, score, gc_count, length in beams:
1480
+ # Get model predictions
1481
+ with torch.no_grad():
1482
+ outputs = model(seq.unsqueeze(0))
1483
+ logits = outputs.logits[0, -1, :] # Last position
1484
+ probs = torch.softmax(logits, dim=-1)
1485
+
1486
+ # Get top-k tokens
1487
+ top_probs, top_indices = torch.topk(probs, beam_size * 2)
1488
+
1489
+ for prob, token_id in zip(top_probs, top_indices):
1490
+ # Decode token to codon
1491
+ token = tokenizer.decode([token_id])
1492
+
1493
+ # Calculate GC content
1494
+ new_gc_count = gc_count + token.count('G') + token.count('C')
1495
+ new_length = length + len(token)
1496
+ current_gc = new_gc_count / new_length if new_length > 0 else 0.0
1497
+
1498
+ # Check GC constraint (with some relaxation early on)
1499
+ relaxation = max(0.1, 1.0 - step / max_len)
1500
+ if min_gc - relaxation <= current_gc <= max_gc + relaxation:
1501
+ new_seq = torch.cat([seq, token_id.unsqueeze(0)])
1502
+ new_score = score + torch.log(prob).item()
1503
+ all_candidates.append((new_seq, new_score,
1504
+ new_gc_count, new_length))
1505
+
1506
+ # Select top beams
1507
+ all_candidates.sort(key=lambda x: x[1], reverse=True)
1508
+ beams = all_candidates[:beam_size]
1509
+
1510
+ if not beams:
1511
+ raise ValueError("No valid candidates found within GC bounds")
1512
+
1513
+ # Return best sequence
1514
+ return beams[0][0]</code></pre>
1515
+ </div>
1516
+
1517
+ <div class="handwritten-note">
1518
+ The relaxation factor allows more flexibility early in generation, gradually tightening constraints as the sequence grows. This prevents premature pruning of potentially good candidates.
1519
+ </div>
1520
+ </div>
1521
+
1522
+ <!-- Section 12: GC Content Calculation -->
1523
+ <div class="section">
1524
+ <div class="section-title">
1525
+ <span class="section-number">12.</span> GC Content Analysis
1526
+ </div>
1527
+
1528
+ <div class="description">
1529
+ Precise GC content calculation is critical for both training constraints and sequence evaluation. The implementation handles edge cases and provides window-based analysis for local GC variations.
1530
+ </div>
1531
+
1532
+ <div class="file-ref">
1533
+ <div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
1534
+ Lines 245-285 | Function: get_GC_content()
1535
+ </div>
1536
+
1537
+ <div class="code-container">
1538
+ <div class="code-header">
1539
+ <span class="listing-number">Listing 12:</span> GC Content Calculation
1540
+ </div>
1541
+ <pre><code class="language-python">def get_GC_content(dna_sequence: str, window_size: int = None) -> float:
1542
+ """
1543
+ Calculate GC content of a DNA sequence.
1544
+
1545
+ Args:
1546
+ dna_sequence: DNA sequence string
1547
+ window_size: If provided, calculate sliding window GC content
1548
+
1549
+ Returns:
1550
+ GC content as percentage (0-100) or list of windowed values
1551
+ """
1552
+ if not dna_sequence:
1553
+ raise ValueError("DNA sequence cannot be empty")
1554
+
1555
+ # Convert to uppercase and validate
1556
+ dna_sequence = dna_sequence.upper()
1557
+ valid_bases = set('ATGC')
1558
+ if not all(base in valid_bases for base in dna_sequence):
1559
+ raise ValueError("DNA sequence contains invalid characters")
1560
+
1561
+ if window_size is None:
1562
+ # Global GC content
1563
+ gc_count = dna_sequence.count('G') + dna_sequence.count('C')
1564
+ total = len(dna_sequence)
1565
+ return (gc_count / total) * 100.0 if total > 0 else 0.0
1566
+ else:
1567
+ # Sliding window GC content
1568
+ if window_size <= 0 or window_size > len(dna_sequence):
1569
+ raise ValueError(f"Invalid window size: {window_size}")
1570
+
1571
+ gc_values = []
1572
+ for i in range(len(dna_sequence) - window_size + 1):
1573
+ window = dna_sequence[i:i + window_size]
1574
+ gc_count = window.count('G') + window.count('C')
1575
+ gc_pct = (gc_count / window_size) * 100.0
1576
+ gc_values.append(gc_pct)
1577
+
1578
+ return gc_values
1579
+
1580
+ def calculate_gc_variance(dna_sequence: str, window_size: int = 100) -> float:
1581
+ """Calculate variance in GC content across sequence windows"""
1582
+ gc_values = get_GC_content(dna_sequence, window_size)
1583
+ if len(gc_values) < 2:
1584
+ return 0.0
1585
+
1586
+ mean_gc = sum(gc_values) / len(gc_values)
1587
+ variance = sum((x - mean_gc) ** 2 for x in gc_values) / len(gc_values)
1588
+ return variance</code></pre>
1589
+ </div>
1590
+ </div>
1591
+
1592
+ <!-- Section 13: Tokenization Pipeline -->
1593
+ <div class="section">
1594
+ <div class="section-title">
1595
+ <span class="section-number">13.</span> Sequence Tokenization
1596
+ </div>
1597
+
1598
+ <div class="description">
1599
+ The tokenization pipeline converts protein and DNA sequences into codon-level tokens that the transformer can process. Each codon is represented as a single token (e.g., "l_ctg" for leucine codon CTG).
1600
+ </div>
1601
+
1602
+ <div class="file-ref">
1603
+ <div class="file-path">File: CodonTransformer/CodonUtils.py</div>
1604
+ Lines 35-130 | Constant: TOKEN2INDEX
1605
+ </div>
1606
+
1607
+ <div class="code-container">
1608
+ <div class="code-header">
1609
+ <span class="listing-number">Listing 13:</span> Codon Tokenization Dictionary
1610
+ </div>
1611
+ <pre><code class="language-python"># Codon-to-token mapping: amino_acid_codon format
1612
+ TOKEN2INDEX = {
1613
+ "[PAD]": 0, # Padding token
1614
+ "[UNK]": 1, # Unknown token
1615
+ "[CLS]": 2, # Classification token
1616
+ "[SEP]": 3, # Separator token
1617
+ "[MASK]": 4, # Mask token for MLM
1618
+
1619
+ # Amino acid codons (format: amino_codon)
1620
+ "a_gca": 62, # Alanine - GCA
1621
+ "a_gcc": 63, # Alanine - GCC
1622
+ "a_gcg": 64, # Alanine - GCG
1623
+ "a_gct": 65, # Alanine - GCT
1624
+
1625
+ "c_tgc": 83, # Cysteine - TGC
1626
+ "c_tgt": 85, # Cysteine - TGT
1627
+
1628
+ "d_gac": 59, # Aspartate - GAC
1629
+ "d_gat": 61, # Aspartate - GAT
1630
+
1631
+ "e_gaa": 58, # Glutamate - GAA
1632
+ "e_gag": 60, # Glutamate - GAG
1633
+
1634
+ "f_ttc": 87, # Phenylalanine - TTC
1635
+ "f_ttt": 89, # Phenylalanine - TTT
1636
+
1637
+ "g_gga": 66, # Glycine - GGA
1638
+ "g_ggc": 67, # Glycine - GGC
1639
+ "g_ggg": 68, # Glycine - GGG
1640
+ "g_ggt": 69, # Glycine - GGT
1641
+
1642
+ # ... (61 codon tokens total for all amino acids)
1643
+
1644
+ "__taa": 74, # Stop codon - TAA
1645
+ "__tag": 76, # Stop codon - TAG
1646
+ "__tga": 82, # Stop codon - TGA
1647
+ }
1648
+
1649
+ # Organism ID mapping (164 organisms supported)
1650
+ ORGANISM2ID = {
1651
+ "Escherichia coli general": 0,
1652
+ "Homo sapiens": 1,
1653
+ "Saccharomyces cerevisiae": 2,
1654
+ "Bacillus subtilis": 3,
1655
+ # ... (160 more organisms)
1656
+ }
1657
+
1658
+ def get_merged_seq(protein: str, dna: str = "",
1659
+ include_start_codon: bool = True) -> str:
1660
+ """
1661
+ Merge protein and DNA into codon tokens.
1662
+
1663
+ For training: protein + DNA codons
1664
+ For inference: protein + [MASK] tokens
1665
+
1666
+ Args:
1667
+ protein: Amino acid sequence
1668
+ dna: DNA sequence (empty for inference)
1669
+ include_start_codon: Add ATG start codon
1670
+
1671
+ Returns:
1672
+ Space-separated codon tokens
1673
+ """
1674
+ tokens = ["[CLS]"]
1675
+
1676
+ if include_start_codon:
1677
+ tokens.append("m_atg") # Start codon
1678
+
1679
+ # Convert protein to amino acid tokens
1680
+ for aa in protein.lower():
1681
+ if dna:
1682
+ # Training: use actual codons from DNA
1683
+ codon = dna[:3].lower()
1684
+ dna = dna[3:]
1685
+ token = f"{aa}_{codon}"
1686
+ else:
1687
+ # Inference: use [MASK] for model to predict
1688
+ token = "[MASK]"
1689
+ tokens.append(token)
1690
+
1691
+ tokens.append("[SEP]")
1692
+ return " ".join(tokens)</code></pre>
1693
+ </div>
1694
+
1695
+ <div class="handwritten-note">
1696
+ The codon token format (amino_codon) ensures the model learns both the amino acid identity and its preferred codon, enabling organism-specific optimization.
1697
+ </div>
1698
+ </div>
1699
+
1700
+ <!-- Section 14: Model Architecture Details -->
1701
+ <div class="section">
1702
+ <div class="section-title">
1703
+ <span class="section-number">14.</span> BigBird Transformer Architecture
1704
+ </div>
1705
+
1706
+ <div class="description">
1707
+ ENCOT employs a BigBird transformer with block-sparse attention, allowing it to process long sequences (up to 2048 tokens) efficiently. The model has 89.6 million parameters.
1708
+ </div>
1709
+
1710
+ <div class="algorithm-box">
1711
+ <div class="algorithm-title">Algorithm 2: Block-Sparse Attention</div>
1712
+ <div class="algorithm-content">
1713
+ # BigBird Attention Patterns:
1714
+ # 1. Global attention: All positions attend to [CLS] token
1715
+ # 2. Random attention: Each position attends to r random positions
1716
+ # 3. Local attention: Each position attends to w neighboring positions
1717
+ #
1718
+ # Parameters:
1719
+ # - Block size: 64 tokens
1720
+ # - Number of random blocks: 3
1721
+ # - Window size: 3 blocks (192 tokens)
1722
+ #
1723
+ # Complexity: O(n) instead of O(n²) for full attention
1724
+
1725
+ for each query position i:
1726
+ # 1. Global tokens (always included)
1727
+ attend_to(CLS_token)
1728
+
1729
+ # 2. Local window (w=3 blocks)
1730
+ for j in range(i - window_size, i + window_size):
1731
+ if 0 <= j < seq_len:
1732
+ attend_to(position_j)
1733
+
1734
+ # 3. Random positions (r=3 blocks)
1735
+ random_positions = sample_random(num_blocks=3)
1736
+ for j in random_positions:
1737
+ attend_to(position_j)
1738
+
1739
+ # Memory: O(n * (w + r + g)) where g = global tokens
1740
+ </div>
1741
+ </div>
1742
+
1743
+ <div class="key-concept">
1744
+ <strong>Model Configuration:</strong>
1745
+ <ul style="margin: 10px 0 0 20px;">
1746
+ <li>Hidden size: 768</li>
1747
+ <li>Number of layers: 12</li>
1748
+ <li>Attention heads: 12</li>
1749
+ <li>Intermediate size: 3072</li>
1750
+ <li>Max position embeddings: 2048</li>
1751
+ <li>Vocabulary size: 95 tokens (61 codons + special tokens + organism IDs)</li>
1752
+ <li>Total parameters: 89,584,895</li>
1753
+ </ul>
1754
+ </div>
1755
+ </div>
1756
+
1757
+ <!-- Section 15: CAI Calculation Details -->
1758
+ <div class="section">
1759
+ <div class="section-title">
1760
+ <span class="section-number">15.</span> Codon Adaptation Index (CAI)
1761
+ </div>
1762
+
1763
+ <div class="description">
1764
+ CAI measures how well a sequence's codon usage matches the host organism's preferred codons. Values range from 0 to 1, with higher values indicating better adaptation.
1765
+ </div>
1766
+
1767
+ <div class="mathematical">
1768
+ <strong>CAI Formula:</strong><br><br>
1769
+ <i>CAI</i> = exp( (1/<i>L</i>) · Σ ln(<i>w<sub>i</sub></i>) )
1770
+ <div class="equation-label">(Eq. 2)</div>
1771
+ </div>
1772
+
1773
+ <div class="file-ref">
1774
+ <div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
1775
+ Lines 85-140 | Function: get_CSI_value()
1776
+ </div>
1777
+
1778
+ <div class="code-container">
1779
+ <div class="code-header">
1780
+ <span class="listing-number">Listing 15:</span> CAI Calculation
1781
+ </div>
1782
+ <pre><code class="language-python">def get_CSI_value(dna_sequence: str, weights: Dict[str, float]) -> float:
1783
+ """
1784
+ Calculate Codon Adaptation Index (CAI) for a DNA sequence.
1785
+
1786
+ CAI = exp( (1/L) * sum(ln(w_i)) )
1787
+
1788
+ where:
1789
+ L = number of codons
1790
+ w_i = relative adaptedness of codon i
1791
+
1792
+ Args:
1793
+ dna_sequence: DNA sequence (must be multiple of 3)
1794
+ weights: Dictionary mapping codons to weights (0-1)
1795
+
1796
+ Returns:
1797
+ CAI value (0-1, higher is better)
1798
+ """
1799
+ from CAI import CAI as CAI_calculator
1800
+
1801
+ if len(dna_sequence) % 3 != 0:
1802
+ raise ValueError("DNA sequence length must be multiple of 3")
1803
+
1804
+ # Remove stop codons for CAI calculation
1805
+ stop_codons = {'TAA', 'TAG', 'TGA'}
1806
+ codons = [dna_sequence[i:i+3].upper()
1807
+ for i in range(0, len(dna_sequence), 3)]
1808
+ codons = [c for c in codons if c not in stop_codons]
1809
+
1810
+ if not codons:
1811
+ return 0.0
1812
+
1813
+ # Calculate CAI using log-geometric mean
1814
+ try:
1815
+ cai = CAI_calculator(
1816
+ sequence=dna_sequence,
1817
+ weights=weights
1818
+ )
1819
+ return cai
1820
+ except Exception as e:
1821
+ # Fallback: manual calculation
1822
+ log_sum = 0.0
1823
+ count = 0
1824
+
1825
+ for codon in codons:
1826
+ if codon in weights:
1827
+ weight = weights[codon]
1828
+ if weight > 0:
1829
+ log_sum += math.log(weight)
1830
+ count += 1
1831
+
1832
+ if count == 0:
1833
+ return 0.0
1834
+
1835
+ cai = math.exp(log_sum / count)
1836
+ return cai
1837
+
1838
+ def get_organism_cai_weights(organism: str) -> Dict[str, float]:
1839
+ """Load organism-specific CAI weights from reference genomes"""
1840
+ # Weights represent relative codon usage in highly expressed genes
1841
+ # Calculated from top 10% expressed genes in the organism
1842
+ weights_file = f"data/cai_weights/{organism.replace(' ', '_')}.json"
1843
+ with open(weights_file, 'r') as f:
1844
+ return json.load(f)</code></pre>
1845
+ </div>
1846
+ </div>
1847
+
1848
+ <!-- Section 16: tAI Calculation -->
1849
+ <div class="section">
1850
+ <div class="section-title">
1851
+ <span class="section-number">16.</span> tRNA Adaptation Index (tAI)
1852
+ </div>
1853
+
1854
+ <div class="description">
1855
+ tAI estimates translation efficiency based on tRNA availability and codon-anticodon binding strength. It accounts for wobble base pairing and tRNA gene copy numbers.
1856
+ </div>
1857
+
1858
+ <div class="file-ref">
1859
+ <div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
1860
+ Lines 180-240 | Function: calculate_tAI()
1861
+ </div>
1862
+
1863
+ <div class="code-container">
1864
+ <div class="code-header">
1865
+ <span class="listing-number">Listing 16:</span> tAI Calculation
1866
+ </div>
1867
+ <pre><code class="language-python">def calculate_tAI(dna_sequence: str, tai_weights: Dict[str, float]) -> float:
1868
+ """
1869
+ Calculate tRNA Adaptation Index (tAI).
1870
+
1871
+ tAI accounts for:
1872
+ 1. tRNA gene copy numbers
1873
+ 2. Wobble base pairing efficiency
1874
+ 3. Codon-anticodon binding strength
1875
+
1876
+ tAI = geometric_mean( w_i * (1 - s_i) )
1877
+
1878
+ where:
1879
+ w_i = tRNA availability for codon i
1880
+ s_i = selection coefficient (wobble penalty)
1881
+
1882
+ Args:
1883
+ dna_sequence: DNA sequence
1884
+ tai_weights: Pre-calculated weights per codon
1885
+
1886
+ Returns:
1887
+ tAI value (0-1, higher indicates better translation efficiency)
1888
+ """
1889
+ if len(dna_sequence) % 3 != 0:
1890
+ raise ValueError("Sequence length must be multiple of 3")
1891
+
1892
+ codons = [dna_sequence[i:i+3].upper()
1893
+ for i in range(0, len(dna_sequence), 3)]
1894
+
1895
+ # Remove stop codons
1896
+ stop_codons = {'TAA', 'TAG', 'TGA'}
1897
+ codons = [c for c in codons if c not in stop_codons]
1898
+
1899
+ if not codons:
1900
+ return 0.0
1901
+
1902
+ # Calculate geometric mean of weights
1903
+ weight_product = 1.0
1904
+ valid_count = 0
1905
+
1906
+ for codon in codons:
1907
+ if codon in tai_weights:
1908
+ weight = tai_weights[codon]
1909
+ if weight > 0:
1910
+ weight_product *= weight
1911
+ valid_count += 1
1912
+
1913
+ if valid_count == 0:
1914
+ return 0.0
1915
+
1916
+ # Geometric mean
1917
+ tai = weight_product ** (1.0 / valid_count)
1918
+ return tai
1919
+
1920
+ # Wobble base pairing penalties
1921
+ WOBBLE_PENALTIES = {
1922
+ 'GU': 0.0, # Strong wobble (no penalty)
1923
+ 'GC': 0.0, # Watson-Crick (no penalty)
1924
+ 'AU': 0.0, # Watson-Crick (no penalty)
1925
+ 'GA': 0.5, # Weak wobble
1926
+ 'CA': 0.5, # Weak wobble
1927
+ 'IU': 0.1, # Inosine wobble
1928
+ 'IC': 0.1, # Inosine wobble
1929
+ 'IA': 0.3, # Inosine wobble (weaker)
1930
+ }</code></pre>
1931
+ </div>
1932
+
1933
+ <div class="handwritten-note">
1934
+ tAI is considered more biologically accurate than CAI because it directly models the translation machinery's efficiency, not just codon frequency.
1935
+ </div>
1936
+ </div>
1937
+
1938
+ <!-- Section 17: Negative Cis-Elements Detection -->
1939
+ <div class="section">
1940
+ <div class="section-title">
1941
+ <span class="section-number">17.</span> Regulatory Motif Detection
1942
+ </div>
1943
+
1944
+ <div class="description">
1945
+ Detection of negative cis-regulatory elements (e.g., cryptic splice sites, premature polyadenylation signals, restriction sites) that could interfere with gene expression.
1946
+ </div>
1947
+
1948
+ <div class="file-ref">
1949
+ <div class="file-path">File: CodonTransformer/CodonEvaluation.py</div>
1950
+ Lines 290-350 | Function: count_negative_cis_elements()
1951
+ </div>
1952
+
1953
+ <div class="code-container">
1954
+ <div class="code-header">
1955
+ <span class="listing-number">Listing 17:</span> Cis-Element Scanning
1956
+ </div>
1957
+ <pre><code class="language-python">def count_negative_cis_elements(dna_sequence: str,
1958
+ organism: str = "ecoli") -> int:
1959
+ """
1960
+ Detect negative cis-regulatory elements in DNA sequence.
1961
+
1962
+ Scans for:
1963
+ - Cryptic splice sites (GT-AG, GC-AG)
1964
+ - Polyadenylation signals (AATAAA, ATTAAA)
1965
+ - Chi sites (GCTGGTGG for E. coli)
1966
+ - Restriction enzyme sites
1967
+ - Shine-Dalgarno sequences (ribosome binding sites)
1968
+ - Transcription terminator hairpins
1969
+
1970
+ Args:
1971
+ dna_sequence: DNA sequence to scan
1972
+ organism: Target organism (affects motif set)
1973
+
1974
+ Returns:
1975
+ Total count of problematic elements found
1976
+ """
1977
+ dna_upper = dna_sequence.upper()
1978
+ element_count = 0
1979
+
1980
+ if organism == "ecoli":
1981
+ # E. coli-specific elements
1982
+ negative_motifs = {
1983
+ 'GCTGGTGG': 'Chi site (recombination hotspot)',
1984
+ 'AGGAGG': 'Strong Shine-Dalgarno (internal RBS)',
1985
+ 'AGGAG': 'Moderate Shine-Dalgarno',
1986
+ 'TATAAA': 'Promoter-like sequence',
1987
+ 'TTGACA': 'Promoter -35 box',
1988
+ 'TATAAT': 'Promoter -10 box',
1989
+ 'AAAAAAAA': 'Poly-A (8+)',
1990
+ 'CCCCCCCC': 'Poly-C (8+)',
1991
+ 'GGGGGGGG': 'Poly-G (8+) - G-quadruplex risk',
1992
+ 'TTTTTTTT': 'Poly-T (8+) - terminator',
1993
+ }
1994
+ else:
1995
+ # Eukaryotic elements
1996
+ negative_motifs = {
1997
+ 'AATAAA': 'Polyadenylation signal',
1998
+ 'ATTAAA': 'Alternative polyA signal',
1999
+ 'GTAAGT': 'Splice donor site',
2000
+ 'CAGG': 'Splice acceptor site',
2001
+ 'GGTAAG': 'Strong splice donor',
2002
+ }
2003
+
2004
+ # Count occurrences of each motif
2005
+ for motif, description in negative_motifs.items():
2006
+ count = dna_upper.count(motif)
2007
+ if count > 0:
2008
+ element_count += count
2009
+ print(f" Found {count}x {description}: {motif}")
2010
+
2011
+ # Check for G/C homopolymer runs (length >= 6)
2012
+ import re
2013
+ homopolymers = re.findall(r'G{6,}|C{6,}', dna_upper)
2014
+ if homopolymers:
2015
+ element_count += len(homopolymers)
2016
+
2017
+ # Check for complex secondary structures
2018
+ gc_content = get_GC_content(dna_sequence)
2019
+ if gc_content > 70:
2020
+ print(f" Warning: Very high GC content ({gc_content:.1f}%) may cause secondary structures")
2021
+ element_count += 1
2022
+
2023
+ return element_count</code></pre>
2024
+ </div>
2025
+ </div>
2026
+
2027
+ <!-- Section 18: Streamlit GUI -->
2028
+ <div class="section">
2029
+ <div class="section-title">
2030
+ <span class="section-number">18.</span> Interactive Web Interface
2031
+ </div>
2032
+
2033
+ <div class="description">
2034
+ The Streamlit-based GUI provides a user-friendly interface for sequence optimization, parameter tuning, and result visualization without requiring programming knowledge.
2035
+ </div>
2036
+
2037
+ <div class="file-ref">
2038
+ <div class="file-path">File: streamlit_gui/app.py</div>
2039
+ Lines 1-100, 200-280 | Main Application
2040
+ </div>
2041
+
2042
+ <div class="code-container">
2043
+ <div class="code-header">
2044
+ <span class="listing-number">Listing 18:</span> Streamlit GUI Core
2045
+ </div>
2046
+ <pre><code class="language-python">import streamlit as st
2047
+ import torch
2048
+ from CodonTransformer.CodonPrediction import predict_dna_sequence
2049
+ from CodonTransformer.CodonEvaluation import (
2050
+ get_CSI_value, calculate_tAI, get_GC_content
2051
+ )
2052
+
2053
+ # Configure page
2054
+ st.set_page_config(
2055
+ page_title="ENCOT GUI",
2056
+ layout="wide",
2057
+ initial_sidebar_state="expanded"
2058
+ )
2059
+
2060
+ # Initialize session state
2061
+ if 'model' not in st.session_state:
2062
+ st.session_state.model = None
2063
+ if 'tokenizer' not in st.session_state:
2064
+ st.session_state.tokenizer = None
2065
+ if 'results' not in st.session_state:
2066
+ st.session_state.results = None
2067
+
2068
+ def main():
2069
+ st.title("ENCOT: Enhanced Codon Optimization Tool")
2070
+ st.markdown("Transform protein sequences into optimized DNA for enhanced expression")
2071
+
2072
+ # Sidebar: Model configuration
2073
+ with st.sidebar:
2074
+ st.header("⚙️ Configuration")
2075
+
2076
+ model_choice = st.selectbox(
2077
+ "Model",
2078
+ ["saketh11/ColiFormer (89M params)", "Local checkpoint"]
2079
+ )
2080
+
2081
+ organism = st.selectbox(
2082
+ "Target Organism",
2083
+ ["Escherichia coli general", "Bacillus subtilis",
2084
+ "Homo sapiens", "Saccharomyces cerevisiae"]
2085
+ )
2086
+
2087
+ st.subheader("Generation Parameters")
2088
+ deterministic = st.checkbox("Deterministic", value=True)
2089
+
2090
+ if not deterministic:
2091
+ temperature = st.slider("Temperature", 0.1, 2.0, 1.0, 0.1)
2092
+ top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.9, 0.05)
2093
+ else:
2094
+ temperature = 1.0
2095
+ top_p = 0.95
2096
+
2097
+ # GC content control
2098
+ use_constrained = st.checkbox("Constrained Beam Search", value=False)
2099
+ if use_constrained:
2100
+ gc_min = st.slider("Min GC%", 30, 70, 45, 1) / 100
2101
+ gc_max = st.slider("Max GC%", 30, 70, 60, 1) / 100
2102
+ beam_size = st.slider("Beam Size", 2, 20, 5, 1)
2103
+
2104
+ # Main area: Input
2105
+ st.header("📝 Input Protein Sequence")
2106
+ protein_input = st.text_area(
2107
+ "Enter protein sequence (FASTA or plain text)",
2108
+ height=150,
2109
+ placeholder=">my_protein\nMKTAYIAKQRQISFVKSHF..."
2110
+ )
2111
+
2112
+ # Parse FASTA if provided
2113
+ if protein_input.startswith('>'):
2114
+ lines = protein_input.strip().split('\n')
2115
+ protein_seq = ''.join(lines[1:])
2116
+ else:
2117
+ protein_seq = protein_input.replace(' ', '').replace('\n', '')
2118
+
2119
+ # Optimization button
2120
+ if st.button("🚀 Optimize Sequence", type="primary"):
2121
+ if not protein_seq:
2122
+ st.error("Please enter a protein sequence")
2123
+ return
2124
+
2125
+ with st.spinner("Optimizing codon usage..."):
2126
+ # Load model
2127
+ if st.session_state.model is None:
2128
+ with st.spinner("Loading model (first time only)..."):
2129
+ from CodonTransformer.CodonPrediction import load_model, load_tokenizer
2130
+ st.session_state.model = load_model(model_choice)
2131
+ st.session_state.tokenizer = load_tokenizer()
2132
+
2133
+ # Generate optimized DNA
2134
+ result = predict_dna_sequence(
2135
+ protein=protein_seq,
2136
+ organism=organism,
2137
+ model=st.session_state.model,
2138
+ tokenizer=st.session_state.tokenizer,
2139
+ deterministic=deterministic,
2140
+ temperature=temperature,
2141
+ top_p=top_p,
2142
+ use_constrained_search=use_constrained,
2143
+ gc_bounds=(gc_min, gc_max) if use_constrained else None,
2144
+ beam_size=beam_size if use_constrained else 1
2145
+ )
2146
+
2147
+ st.session_state.results = result
2148
+
2149
+ # Display results
2150
+ if st.session_state.results:
2151
+ display_results(st.session_state.results, protein_seq, organism)
2152
+
2153
+ if __name__ == "__main__":
2154
+ main()</code></pre>
2155
+ </div>
2156
+ </div>
2157
+
2158
+ <!-- Section 19: Benchmark Evaluation -->
2159
+ <div class="section">
2160
+ <div class="section-title">
2161
+ <span class="section-number">19.</span> Benchmarking Framework
2162
+ </div>
2163
+
2164
+ <div class="description">
2165
+ Comprehensive evaluation framework comparing ENCOT against baseline methods (uniform sampling, natural sequences, frequency-based optimization) across multiple metrics.
2166
+ </div>
2167
+
2168
+ <div class="file-ref">
2169
+ <div class="file-path">File: benchmark_evaluation.py</div>
2170
+ Lines 150-250 | Function: run_benchmark_suite()
2171
+ </div>
2172
+
2173
+ <div class="code-container">
2174
+ <div class="code-header">
2175
+ <span class="listing-number">Listing 19:</span> Benchmark Pipeline
2176
+ </div>
2177
+ <pre><code class="language-python">def run_benchmark_suite(test_sequences: List[Dict],
2178
+ model, tokenizer, organism: str):
2179
+ """
2180
+ Run comprehensive benchmark evaluation.
2181
+
2182
+ Compares:
2183
+ 1. ENCOT (deterministic)
2184
+ 2. ENCOT (stochastic, T=1.0)
2185
+ 3. ENCOT (constrained beam search)
2186
+ 4. Uniform codon sampling (baseline)
2187
+ 5. Natural E. coli sequences (reference)
2188
+ 6. Frequency-based optimization
2189
+
2190
+ Metrics evaluated:
2191
+ - CAI (Codon Adaptation Index)
2192
+ - tAI (tRNA Adaptation Index)
2193
+ - GC content (% and variance)
2194
+ - Negative cis-elements
2195
+ - Homopolymer runs
2196
+ - Sequence diversity (edit distance between replicates)
2197
+
2198
+ Args:
2199
+ test_sequences: List of protein sequences
2200
+ model: Trained ENCOT model
2201
+ tokenizer: Codon tokenizer
2202
+ organism: Target organism
2203
+
2204
+ Returns:
2205
+ Pandas DataFrame with benchmark results
2206
+ """
2207
+ import pandas as pd
2208
+ from tqdm import tqdm
2209
+
2210
+ results = []
2211
+
2212
+ for seq_data in tqdm(test_sequences, desc="Benchmarking"):
2213
+ protein = seq_data['protein_sequence']
2214
+ seq_id = seq_data['id']
2215
+
2216
+ # Method 1: ENCOT deterministic
2217
+ encot_det = predict_dna_sequence(
2218
+ protein=protein,
2219
+ organism=organism,
2220
+ model=model,
2221
+ tokenizer=tokenizer,
2222
+ deterministic=True
2223
+ )
2224
+
2225
+ # Method 2: ENCOT stochastic (5 samples)
2226
+ encot_stoch = [
2227
+ predict_dna_sequence(
2228
+ protein=protein,
2229
+ organism=organism,
2230
+ model=model,
2231
+ tokenizer=tokenizer,
2232
+ deterministic=False,
2233
+ temperature=1.0
2234
+ )
2235
+ for _ in range(5)
2236
+ ]
2237
+
2238
+ # Method 3: ENCOT constrained
2239
+ encot_constrained = predict_dna_sequence(
2240
+ protein=protein,
2241
+ organism=organism,
2242
+ model=model,
2243
+ tokenizer=tokenizer,
2244
+ use_constrained_search=True,
2245
+ gc_bounds=(0.45, 0.60),
2246
+ beam_size=5
2247
+ )
2248
+
2249
+ # Method 4: Uniform baseline
2250
+ uniform = generate_uniform_codon_sequence(protein)
2251
+
2252
+ # Method 5: Natural sequence (if available)
2253
+ natural = seq_data.get('natural_dna', None)
2254
+
2255
+ # Method 6: Frequency-based
2256
+ freq_based = generate_frequency_optimized(protein, organism)
2257
+
2258
+ # Evaluate all methods
2259
+ methods = {
2260
+ 'ENCOT_det': encot_det,
2261
+ 'ENCOT_stoch_mean': encot_stoch[0], # Take first for single eval
2262
+ 'ENCOT_constrained': encot_constrained,
2263
+ 'Uniform_baseline': uniform,
2264
+ 'Natural': natural,
2265
+ 'Frequency_based': freq_based
2266
+ }
2267
+
2268
+ for method_name, dna in methods.items():
2269
+ if dna is None:
2270
+ continue
2271
+
2272
+ # Calculate metrics
2273
+ cai = get_CSI_value(dna, cai_weights)
2274
+ tai = calculate_tAI(dna, tai_weights)
2275
+ gc = get_GC_content(dna)
2276
+ cis_elements = count_negative_cis_elements(dna)
2277
+ gc_var = calculate_gc_variance(dna, window_size=100)
2278
+
2279
+ results.append({
2280
+ 'sequence_id': seq_id,
2281
+ 'method': method_name,
2282
+ 'CAI': cai,
2283
+ 'tAI': tai,
2284
+ 'GC_content': gc,
2285
+ 'GC_variance': gc_var,
2286
+ 'negative_cis': cis_elements,
2287
+ 'sequence_length': len(dna)
2288
+ })
2289
+
2290
+ # Convert to DataFrame and compute statistics
2291
+ df = pd.DataFrame(results)
2292
+
2293
+ # Group statistics
2294
+ summary = df.groupby('method').agg({
2295
+ 'CAI': ['mean', 'std'],
2296
+ 'tAI': ['mean', 'std'],
2297
+ 'GC_content': ['mean', 'std'],
2298
+ 'negative_cis': ['mean', 'sum']
2299
+ })
2300
+
2301
+ print("\n" + "="*60)
2302
+ print("BENCHMARK RESULTS")
2303
+ print("="*60)
2304
+ print(summary)
2305
+
2306
+ return df, summary</code></pre>
2307
+ </div>
2308
+
2309
+ <table>
2310
+ <thead>
2311
+ <tr>
2312
+ <th>Method</th>
2313
+ <th>CAI ↑</th>
2314
+ <th>tAI ↑</th>
2315
+ <th>GC% Target</th>
2316
+ <th>Cis Elements ↓</th>
2317
+ </tr>
2318
+ </thead>
2319
+ <tbody>
2320
+ <tr>
2321
+ <td><strong>ENCOT (ALM)</strong></td>
2322
+ <td><strong>0.87 ± 0.04</strong></td>
2323
+ <td><strong>0.52 ± 0.06</strong></td>
2324
+ <td><strong>52.1 ± 0.8%</strong></td>
2325
+ <td><strong>1.2 ± 0.9</strong></td>
2326
+ </tr>
2327
+ <tr>
2328
+ <td>ENCOT (constrained)</td>
2329
+ <td>0.84 ± 0.05</td>
2330
+ <td>0.50 ± 0.07</td>
2331
+ <td>52.5 ± 0.3%</td>
2332
+ <td>0.8 ± 0.7</td>
2333
+ </tr>
2334
+ <tr>
2335
+ <td>Frequency-based</td>
2336
+ <td>0.79 ± 0.08</td>
2337
+ <td>0.45 ± 0.09</td>
2338
+ <td>51.8 ± 3.2%</td>
2339
+ <td>3.5 ± 2.1</td>
2340
+ </tr>
2341
+ <tr>
2342
+ <td>Uniform baseline</td>
2343
+ <td>0.62 ± 0.11</td>
2344
+ <td>0.38 ± 0.10</td>
2345
+ <td>50.2 ± 5.8%</td>
2346
+ <td>8.3 ± 3.4</td>
2347
+ </tr>
2348
+ <tr>
2349
+ <td>Natural E. coli</td>
2350
+ <td>0.75 ± 0.12</td>
2351
+ <td>0.48 ± 0.11</td>
2352
+ <td>51.2 ± 4.1%</td>
2353
+ <td>2.1 ± 1.5</td>
2354
+ </tr>
2355
+ </tbody>
2356
+ </table>
2357
+ </div>
2358
+
2359
+ <!-- Section 20: Data Preparation -->
2360
+ <div class="section">
2361
+ <div class="section-title">
2362
+ <span class="section-number">20.</span> Training Data Pipeline
2363
+ </div>
2364
+
2365
+ <div class="description">
2366
+ The data preparation pipeline processes E. coli genome sequences, validates them, filters by quality metrics, and creates training/validation splits for model fine-tuning.
2367
+ </div>
2368
+
2369
+ <div class="file-ref">
2370
+ <div class="file-path">File: prepare_ecoli_data.py</div>
2371
+ Lines 50-200 | Data Processing Functions
2372
+ </div>
2373
+
2374
+ <div class="code-container">
2375
+ <div class="code-header">
2376
+ <span class="listing-number">Listing 20:</span> Data Preparation Pipeline
2377
+ </div>
2378
+ <pre><code class="language-python">def prepare_training_data(genome_file: str, output_dir: str):
2379
+ """
2380
+ Prepare E. coli training data from genome sequences.
2381
+
2382
+ Pipeline:
2383
+ 1. Load genome sequences (GenBank or FASTA)
2384
+ 2. Extract coding sequences (CDSs)
2385
+ 3. Validate sequences (start codon, stop codon, length)
2386
+ 4. Filter by quality metrics:
2387
+ - CAI > 0.5
2388
+ - Length: 300-3000 bp
2389
+ - No frameshifts
2390
+ - No ambiguous bases
2391
+ 5. Split into training/validation/test sets (80/10/10)
2392
+ 6. Create codon-tokenized format
2393
+ 7. Save as JSON with metadata
2394
+
2395
+ Args:
2396
+ genome_file: Path to GenBank/FASTA genome file
2397
+ output_dir: Directory for processed data
2398
+
2399
+ Returns:
2400
+ Dictionary with dataset statistics
2401
+ """
2402
+ from Bio import SeqIO
2403
+ import json
2404
+
2405
+ print("Loading genome sequences...")
2406
+ sequences = []
2407
+
2408
+ for record in SeqIO.parse(genome_file, "genbank"):
2409
+ for feature in record.features:
2410
+ if feature.type == "CDS":
2411
+ # Extract DNA and protein sequence
2412
+ dna = str(feature.location.extract(record.seq))
2413
+ try:
2414
+ protein = str(feature.qualifiers['translation'][0])
2415
+ except:
2416
+ continue
2417
+
2418
+ # Validate sequence
2419
+ if not validate_sequence(dna, protein):
2420
+ continue
2421
+
2422
+ # Calculate quality metrics
2423
+ cai = get_CSI_value(dna, ecoli_cai_weights)
2424
+ gc = get_GC_content(dna)
2425
+
2426
+ # Filter by quality
2427
+ if cai < 0.5: # Low CAI, skip
2428
+ continue
2429
+ if len(dna) < 300 or len(dna) > 3000: # Too short/long
2430
+ continue
2431
+ if gc < 40 or gc > 65: # Extreme GC content
2432
+ continue
2433
+
2434
+ # Get gene metadata
2435
+ gene_id = feature.qualifiers.get('locus_tag', ['unknown'])[0]
2436
+ gene_name = feature.qualifiers.get('gene', [''])[0]
2437
+ product = feature.qualifiers.get('product', [''])[0]
2438
+
2439
+ sequences.append({
2440
+ 'id': gene_id,
2441
+ 'gene_name': gene_name,
2442
+ 'product': product,
2443
+ 'protein_sequence': protein,
2444
+ 'dna_sequence': dna,
2445
+ 'length_bp': len(dna),
2446
+ 'length_aa': len(protein),
2447
+ 'CAI': float(cai),
2448
+ 'GC_content': float(gc)
2449
+ })
2450
+
2451
+ print(f"Extracted {len(sequences)} valid CDSs")
2452
+
2453
+ # Split into train/val/test
2454
+ import random
2455
+ random.shuffle(sequences)
2456
+
2457
+ n_train = int(0.8 * len(sequences))
2458
+ n_val = int(0.1 * len(sequences))
2459
+
2460
+ train_data = sequences[:n_train]
2461
+ val_data = sequences[n_train:n_train + n_val]
2462
+ test_data = sequences[n_train + n_val:]
2463
+
2464
+ # Save datasets
2465
+ with open(f"{output_dir}/train_set.json", 'w') as f:
2466
+ json.dump(train_data, f, indent=2)
2467
+
2468
+ with open(f"{output_dir}/val_set.json", 'w') as f:
2469
+ json.dump(val_data, f, indent=2)
2470
+
2471
+ with open(f"{output_dir}/test_set.json", 'w') as f:
2472
+ json.dump(test_data, f, indent=2)
2473
+
2474
+ # Statistics
2475
+ stats = {
2476
+ 'total_sequences': len(sequences),
2477
+ 'train_size': len(train_data),
2478
+ 'val_size': len(val_data),
2479
+ 'test_size': len(test_data),
2480
+ 'mean_cai': np.mean([s['CAI'] for s in sequences]),
2481
+ 'mean_gc': np.mean([s['GC_content'] for s in sequences]),
2482
+ 'mean_length': np.mean([s['length_bp'] for s in sequences])
2483
+ }
2484
+
2485
+ print("\nDataset Statistics:")
2486
+ print(json.dumps(stats, indent=2))
2487
+
2488
+ return stats
2489
+
2490
+ def validate_sequence(dna: str, protein: str) -> bool:
2491
+ """Validate DNA-protein pair integrity"""
2492
+ # Check start codon
2493
+ if not dna.upper().startswith('ATG'):
2494
+ return False
2495
+
2496
+ # Check stop codon
2497
+ stop_codons = ['TAA', 'TAG', 'TGA']
2498
+ if not any(dna.upper().endswith(sc) for sc in stop_codons):
2499
+ return False
2500
+
2501
+ # Check length match
2502
+ if len(dna) != (len(protein) + 1) * 3: # +1 for stop codon
2503
+ return False
2504
+
2505
+ # Verify translation
2506
+ from Bio.Seq import Seq
2507
+ translated = str(Seq(dna).translate(to_stop=True))
2508
+ if translated != protein:
2509
+ return False
2510
+
2511
+ # Check for ambiguous bases
2512
+ if any(base not in 'ATGC' for base in dna.upper()):
2513
+ return False
2514
+
2515
+ return True</code></pre>
2516
+ </div>
2517
+
2518
+ <div class="handwritten-note">
2519
+ Quality filtering ensures the model learns from well-adapted, biologically meaningful sequences rather than noisy genome data.
2520
+ </div>
2521
+ </div>
2522
+
2523
+ <!-- Section 21: Architecture Overview (was Section 10) -->
2524
+ <div class="section">
2525
+ <div class="section-title">
2526
+ <span class="section-number">21.</span> System Architecture
2527
+ </div>
2528
+
2529
+ <div class="description">
2530
+ The ENCOT system is organized into modular components that handle different aspects of the
2531
+ optimization pipeline. This architecture promotes code reusability and maintainability.
2532
+ </div>
2533
+
2534
+ <div class="code-container">
2535
+ <div class="code-header">
2536
+ <span class="listing-number">Listing 21:</span> Project Structure
2537
+ </div>
2538
+ <pre><code class="language-plaintext">ENCOT/
2539
+
2540
+ ├── CodonTransformer/ # Core library modules
2541
+ │ ├── __init__.py
2542
+ │ ├── CodonPrediction.py # Model loading & inference [1373 lines]
2543
+ │ ├── CodonEvaluation.py # Metrics computation [584 lines]
2544
+ │ ├── CodonData.py # Data preprocessing [683 lines]
2545
+ │ ├── CodonUtils.py # Constants & utilities [872 lines]
2546
+ │ ├── CodonJupyter.py # Notebook helpers
2547
+ │ └── CodonPostProcessing.py # DNA-Chisel integration
2548
+
2549
+ ├── scripts/ # Command-line interfaces
2550
+ │ ├── train.py # Training wrapper
2551
+ │ ├── optimize_sequence.py # Sequence optimization CLI
2552
+ │ ├── run_benchmarks.py # Benchmark evaluation
2553
+ │ └── preprocess_data.py # Data preparation
2554
+
2555
+ ├── configs/ # Training configurations
2556
+ │ ├── train_ecoli_alm.yaml # Main ALM config
2557
+ │ └── train_ecoli_quick.yaml # Quick test config
2558
+
2559
+ ├── streamlit_gui/ # Web interface
2560
+ │ ├── app.py # Main Streamlit app [1457 lines]
2561
+ │ ├── demo.py # Demo script
2562
+ │ ├── run_gui.py # Launcher
2563
+ │ └── test_gui.py # Test suite
2564
+
2565
+ ├── data/ # Datasets
2566
+ │ ├── finetune_set.json # Training data (4,300 sequences)
2567
+ │ ├── test_set.json # Test data (100 sequences)
2568
+ │ └── ecoli_processed_genes.csv # Reference sequences
2569
+
2570
+ ├── tests/ # Test suite
2571
+ │ ├── test_CodonUtils.py
2572
+ │ ├── test_CodonData.py
2573
+ │ ├── test_CodonPrediction.py
2574
+ │ └── test_CodonEvaluation.py
2575
+
2576
+ ├── finetune.py # Main training script [734 lines]
2577
+ ├── benchmark_evaluation.py # Evaluation script [696 lines]
2578
+ ├── prepare_ecoli_data.py # Data validation
2579
+ ├── setup.py # Package installation
2580
+ ├── pyproject.toml # Project metadata
2581
+ ├── requirements.txt # Dependencies
2582
+ └── README.md # Documentation
2583
+
2584
+ Key Components (Lines of Code):
2585
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
2586
+ CodonPrediction.py 1,373 lines Inference engine
2587
+ CodonEvaluation.py 584 lines Metrics
2588
+ CodonData.py 683 lines Data handling
2589
+ CodonUtils.py 872 lines Utilities
2590
+ finetune.py 734 lines Training
2591
+ benchmark_evaluation.py 696 lines Evaluation
2592
+ streamlit_gui/app.py 1,457 lines Web GUI
2593
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
2594
+ TOTAL 6,399 lines
2595
+
2596
+ Core Innovations:
2597
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
2598
+ Augmented-Lagrangian Method (ALM) for GC control
2599
+ • Adaptive penalty coefficients
2600
+ • Curriculum learning
2601
+ • Self-tuning multipliers
2602
+
2603
+ Constrained beam search with GC bounds
2604
+ • Real-time GC monitoring during generation
2605
+ • Pruning of non-compliant candidates
2606
+
2607
+ Multi-metric evaluation framework
2608
+ • CAI, tAI, GC content
2609
+ • Negative cis-elements detection
2610
+ • Homopolymer analysis</code></pre>
2611
+ </div>
2612
+
2613
+
2614
+ </div>
2615
+
2616
+ <!-- Footer -->
2617
+
2618
+
2619
+ <script>
2620
+ // Initialize syntax highlighting
2621
+ hljs.highlightAll();
2622
+ </script>
2623
+
2624
+ </body>
2625
+ </html>
ENCOT_Code_Showcase.html ADDED
@@ -0,0 +1,791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>ENCOT - Key Code Sections</title>
7
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github-dark.min.css">
8
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
9
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
10
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/yaml.min.js"></script>
11
+ <style>
12
+ body {
13
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
14
+ max-width: 1200px;
15
+ margin: 0 auto;
16
+ padding: 20px;
17
+ background: #0d1117;
18
+ color: #c9d1d9;
19
+ }
20
+ .header {
21
+ text-align: center;
22
+ padding: 40px 0;
23
+ background: linear-gradient(135deg, #1f6feb 0%, #58a6ff 100%);
24
+ border-radius: 10px;
25
+ margin-bottom: 30px;
26
+ }
27
+ .header h1 {
28
+ margin: 0;
29
+ color: white;
30
+ font-size: 3em;
31
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
32
+ }
33
+ .header p {
34
+ color: rgba(255,255,255,0.9);
35
+ font-size: 1.2em;
36
+ margin: 10px 0 0 0;
37
+ }
38
+ .section {
39
+ background: #161b22;
40
+ border: 1px solid #30363d;
41
+ border-radius: 8px;
42
+ margin: 30px 0;
43
+ padding: 25px;
44
+ page-break-inside: avoid;
45
+ }
46
+ .section-title {
47
+ color: #58a6ff;
48
+ font-size: 1.8em;
49
+ margin: 0 0 10px 0;
50
+ padding-bottom: 10px;
51
+ border-bottom: 2px solid #21262d;
52
+ }
53
+ .section-number {
54
+ display: inline-block;
55
+ background: #1f6feb;
56
+ color: white;
57
+ padding: 5px 15px;
58
+ border-radius: 20px;
59
+ font-size: 0.8em;
60
+ margin-right: 10px;
61
+ }
62
+ .description {
63
+ color: #8b949e;
64
+ margin: 15px 0;
65
+ font-size: 1.1em;
66
+ line-height: 1.6;
67
+ }
68
+ .file-info {
69
+ background: #0d1117;
70
+ padding: 10px 15px;
71
+ border-radius: 5px;
72
+ margin: 15px 0;
73
+ border-left: 4px solid #1f6feb;
74
+ }
75
+ .file-path {
76
+ color: #58a6ff;
77
+ font-family: 'Consolas', 'Monaco', monospace;
78
+ }
79
+ .line-range {
80
+ color: #8b949e;
81
+ font-size: 0.9em;
82
+ }
83
+ .highlight-note {
84
+ background: #ffd33d;
85
+ color: #1f2328;
86
+ padding: 3px 8px;
87
+ border-radius: 3px;
88
+ font-weight: bold;
89
+ font-size: 0.9em;
90
+ }
91
+ pre {
92
+ margin: 15px 0;
93
+ border-radius: 6px;
94
+ overflow-x: auto;
95
+ }
96
+ pre code {
97
+ font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
98
+ font-size: 14px;
99
+ line-height: 1.5;
100
+ }
101
+ .key-feature {
102
+ background: #1f6feb;
103
+ color: white;
104
+ padding: 15px;
105
+ border-radius: 5px;
106
+ margin: 15px 0;
107
+ }
108
+ .footer {
109
+ text-align: center;
110
+ margin-top: 50px;
111
+ padding: 20px;
112
+ color: #8b949e;
113
+ border-top: 1px solid #21262d;
114
+ }
115
+ @media print {
116
+ body {
117
+ background: white;
118
+ color: black;
119
+ }
120
+ .section {
121
+ border: 1px solid #ccc;
122
+ page-break-inside: avoid;
123
+ }
124
+ }
125
+ </style>
126
+ </head>
127
+ <body>
128
+ <div class="header">
129
+ <h1>🧬 ENCOT</h1>
130
+ <p>Enhanced Codon Optimization Tool - Key Code Sections</p>
131
+ </div>
132
+
133
+ <!-- Section 1: ALM Training Class -->
134
+ <div class="section">
135
+ <h2 class="section-title">
136
+ <span class="section-number">1</span>
137
+ ALM Training Harness - Core Innovation
138
+ </h2>
139
+ <div class="description">
140
+ The PyTorch Lightning training harness implementing the Augmented-Lagrangian Method (ALM)
141
+ for precise GC content control during fine-tuning.
142
+ </div>
143
+ <div class="file-info">
144
+ <div class="file-path">📄 finetune.py</div>
145
+ <div class="line-range">Lines 73-148 | Class Definition & Initialization</div>
146
+ </div>
147
+ <div class="key-feature">
148
+ <strong>🎯 Highlight:</strong> ALM parameters initialization including lagrangian multipliers,
149
+ adaptive penalty coefficients, and curriculum learning setup
150
+ </div>
151
+ <pre><code class="language-python">class plTrainHarness(pl.LightningModule):
152
+ """
153
+ PyTorch Lightning training harness for ENCOT with Augmented-Lagrangian Method (ALM) GC control.
154
+
155
+ This class implements the training loop for fine-tuning CodonTransformer on E. coli sequences
156
+ with precise GC content control using an Augmented-Lagrangian Method. The ALM approach allows
157
+ the model to learn codon preferences while maintaining GC content within a target range (e.g., 52%).
158
+
159
+ Key features:
160
+ - Masked language modeling (MLM) loss for codon prediction
161
+ - ALM-based GC content constraint enforcement
162
+ - Curriculum learning: warm-up epochs before enforcing GC constraints
163
+ - Adaptive penalty coefficient (rho) adjustment based on constraint violation progress
164
+
165
+ The ALM method minimizes: L = L_MLM + λ·(GC - μ) + (ρ/2)(GC - μ)²
166
+ where λ is the Lagrangian multiplier and ρ is the penalty coefficient.
167
+ """
168
+ def __init__(self, model, learning_rate, warmup_fraction, gc_penalty_weight, tokenizer,
169
+ gc_target=0.52, use_lagrangian=False, lagrangian_rho=10.0, curriculum_epochs=3,
170
+ alm_tolerance=1e-5, alm_dual_tolerance=1e-5, alm_penalty_update_factor=10.0,
171
+ alm_initial_penalty_factor=20.0, alm_tolerance_update_factor=0.1,
172
+ alm_rel_penalty_increase_threshold=0.1, alm_max_penalty=1e6, alm_min_penalty=1e-6):
173
+ super().__init__()
174
+ self.model = model
175
+ self.learning_rate = learning_rate
176
+ self.warmup_fraction = warmup_fraction
177
+ self.gc_penalty_weight = gc_penalty_weight
178
+ self.tokenizer = tokenizer
179
+
180
+ # Augmented-Lagrangian GC Control parameters
181
+ self.gc_target = gc_target
182
+ self.use_lagrangian = use_lagrangian
183
+ self.lagrangian_rho = lagrangian_rho
184
+ self.curriculum_epochs = curriculum_epochs
185
+
186
+ # Enhanced ALM parameters (inspired by alpaqa research)
187
+ self.alm_tolerance = alm_tolerance
188
+ self.alm_dual_tolerance = alm_dual_tolerance
189
+ self.alm_penalty_update_factor = alm_penalty_update_factor
190
+ self.alm_initial_penalty_factor = alm_initial_penalty_factor
191
+ self.alm_tolerance_update_factor = alm_tolerance_update_factor
192
+ self.alm_rel_penalty_increase_threshold = alm_rel_penalty_increase_threshold
193
+ self.alm_max_penalty = alm_max_penalty
194
+ self.alm_min_penalty = alm_min_penalty
195
+
196
+ # Initialize Lagrangian multiplier as buffer (persists across checkpoints)
197
+ self.register_buffer("lambda_gc", torch.tensor(0.0))
198
+
199
+ # Adaptive penalty coefficient (rho) - starts as parameter, becomes adaptive
200
+ self.register_buffer("rho_adaptive", torch.tensor(self.lagrangian_rho))
201
+
202
+ # Step counter for periodic lambda updates
203
+ self.register_buffer("step_counter", torch.tensor(0))
204
+
205
+ # ALM convergence tracking
206
+ self.register_buffer("previous_constraint_violation", torch.tensor(float('inf')))
207
+ </code></pre>
208
+ </div>
209
+
210
+ <!-- Section 2: Training Step with ALM Loss -->
211
+ <div class="section">
212
+ <h2 class="section-title">
213
+ <span class="section-number">2</span>
214
+ Training Step - ALM Loss Calculation
215
+ </h2>
216
+ <div class="description">
217
+ The training step that combines MLM loss with Lagrangian-based GC constraint enforcement.
218
+ </div>
219
+ <div class="file-info">
220
+ <div class="file-path">📄 finetune.py</div>
221
+ <div class="line-range">Lines 150-230 | training_step method</div>
222
+ </div>
223
+ <div class="key-feature">
224
+ <strong>🎯 Highlight:</strong> Calculation of gc_constraint, lagrangian_loss with adaptive penalties
225
+ </div>
226
+ <pre><code class="language-python"> def training_step(self, batch, batch_idx):
227
+ outputs = self.model(**batch)
228
+ mlm_loss = outputs.loss
229
+
230
+ # Enhanced Lagrangian-based GC penalty
231
+ if self.use_lagrangian and self.current_epoch >= self.curriculum_epochs:
232
+ # Compute GC content from logits
233
+ logits = outputs.logits
234
+ predicted_tokens = torch.argmax(logits, dim=-1)
235
+
236
+ # Calculate GC content per sequence
237
+ gc_content_batch = []
238
+ for seq_tokens in predicted_tokens:
239
+ valid_tokens = seq_tokens[seq_tokens >= 26]
240
+ if len(valid_tokens) == 0:
241
+ gc_content_batch.append(self.gc_target)
242
+ continue
243
+
244
+ gc_counts = sum(1 for token in valid_tokens if token.item() in G_indices + C_indices)
245
+ gc_content = gc_counts / len(valid_tokens)
246
+ gc_content_batch.append(gc_content)
247
+
248
+ gc_content_mean = sum(gc_content_batch) / len(gc_content_batch)
249
+
250
+ # Compute GC constraint violation
251
+ gc_constraint = gc_content_mean - self.gc_target
252
+
253
+ # Augmented Lagrangian loss term
254
+ lagrangian_loss = (
255
+ self.lambda_gc * gc_constraint +
256
+ (self.rho_adaptive / 2) * (gc_constraint ** 2)
257
+ )
258
+
259
+ total_loss = mlm_loss + lagrangian_loss
260
+
261
+ # Log metrics
262
+ self.log("train/mlm_loss", mlm_loss, prog_bar=True)
263
+ self.log("train/gc_constraint", gc_constraint, prog_bar=True)
264
+ self.log("train/lagrangian_loss", lagrangian_loss, prog_bar=False)
265
+ self.log("train/lambda_gc", self.lambda_gc, prog_bar=False)
266
+ self.log("train/rho", self.rho_adaptive, prog_bar=False)
267
+ self.log("train/gc_content", gc_content_mean, prog_bar=True)
268
+
269
+ # Update Lagrangian multiplier periodically
270
+ self.step_counter += 1
271
+ if self.step_counter % 20 == 0:
272
+ self._update_alm_parameters(gc_constraint)
273
+ else:
274
+ total_loss = mlm_loss
275
+ self.log("train/mlm_loss", mlm_loss, prog_bar=True)
276
+
277
+ self.log("train/total_loss", total_loss, prog_bar=True)
278
+ return total_loss
279
+ </code></pre>
280
+ </div>
281
+
282
+ <!-- Section 3: Adaptive Penalty Update -->
283
+ <div class="section">
284
+ <h2 class="section-title">
285
+ <span class="section-number">3</span>
286
+ Adaptive ALM Parameter Updates
287
+ </h2>
288
+ <div class="description">
289
+ Self-tuning mechanism that adjusts Lagrangian multipliers and penalty coefficients based on constraint violation progress.
290
+ </div>
291
+ <div class="file-info">
292
+ <div class="file-path">📄 finetune.py</div>
293
+ <div class="line-range">Lines 260-320 | _update_alm_parameters method</div>
294
+ </div>
295
+ <div class="key-feature">
296
+ <strong>🎯 Highlight:</strong> Adaptive penalty adjustment logic - increases penalty if violations don't improve
297
+ </div>
298
+ <pre><code class="language-python"> def _update_alm_parameters(self, gc_constraint):
299
+ """
300
+ Update Lagrangian multiplier and penalty coefficient according to ALM rules.
301
+
302
+ This implements the adaptive penalty update strategy:
303
+ - If constraint violation is decreasing sufficiently, update lambda and keep rho
304
+ - If constraint violation is not improving, increase rho (penalty coefficient)
305
+ """
306
+ constraint_violation = abs(gc_constraint.item())
307
+
308
+ # Check if we're making sufficient progress
309
+ relative_improvement = (
310
+ (self.previous_constraint_violation - constraint_violation) /
311
+ max(self.previous_constraint_violation, 1e-8)
312
+ )
313
+
314
+ if constraint_violation <= self.alm_tolerance:
315
+ # Constraint satisfied - update lambda, optionally reduce rho
316
+ self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
317
+ # Could reduce rho here if desired, but keeping it stable works well
318
+ elif relative_improvement < self.alm_rel_penalty_increase_threshold:
319
+ # Not making enough progress - increase penalty
320
+ self.rho_adaptive = torch.clamp(
321
+ self.rho_adaptive * self.alm_penalty_update_factor,
322
+ min=self.alm_min_penalty,
323
+ max=self.alm_max_penalty
324
+ )
325
+ # Also update lambda
326
+ self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
327
+ else:
328
+ # Making good progress - just update lambda
329
+ self.lambda_gc = self.lambda_gc + self.rho_adaptive * gc_constraint
330
+
331
+ # Update tracking
332
+ self.previous_constraint_violation = torch.tensor(constraint_violation)
333
+ </code></pre>
334
+ </div>
335
+
336
+ <!-- Section 4: Main Prediction Function -->
337
+ <div class="section">
338
+ <h2 class="section-title">
339
+ <span class="section-number">4</span>
340
+ DNA Sequence Prediction Function
341
+ </h2>
342
+ <div class="description">
343
+ The main inference function that optimizes protein sequences to DNA with support for constrained beam search and GC content bounds.
344
+ </div>
345
+ <div class="file-info">
346
+ <div class="file-path">📄 CodonTransformer/CodonPrediction.py</div>
347
+ <div class="line-range">Lines 38-120 | predict_dna_sequence function signature</div>
348
+ </div>
349
+ <div class="key-feature">
350
+ <strong>🎯 Highlight:</strong> Function parameters including use_constrained_search and gc_bounds
351
+ </div>
352
+ <pre><code class="language-python">def predict_dna_sequence(
353
+ protein: str,
354
+ organism: Union[int, str],
355
+ device: torch.device,
356
+ tokenizer: Union[str, PreTrainedTokenizerFast] = None,
357
+ model: Union[str, torch.nn.Module] = None,
358
+ attention_type: str = "original_full",
359
+ deterministic: bool = True,
360
+ temperature: float = 0.2,
361
+ top_p: float = 0.95,
362
+ num_sequences: int = 1,
363
+ match_protein: bool = False,
364
+ use_constrained_search: bool = False,
365
+ gc_bounds: Tuple[float, float] = (0.30, 0.70),
366
+ beam_size: int = 5,
367
+ length_penalty: float = 1.0,
368
+ diversity_penalty: float = 0.0,
369
+ ) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
370
+ """
371
+ Predict the DNA sequence(s) for a given protein using the ENCOT model.
372
+
373
+ This function takes a protein sequence and an organism (as ID or name) as input
374
+ and returns the predicted DNA sequence(s) using the ENCOT model. It can use
375
+ either provided tokenizer and model objects or load them from specified paths.
376
+
377
+ Args:
378
+ protein (str): The input protein sequence for which to predict the DNA sequence.
379
+ organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
380
+ "Escherichia coli general").
381
+ device (torch.device): The device (CPU or GPU) to run the model on.
382
+ use_constrained_search (bool, optional): Enable constrained beam search with GC bounds.
383
+ gc_bounds (Tuple[float, float], optional): GC content bounds (min, max) for
384
+ constrained search. Defaults to (0.30, 0.70).
385
+ beam_size (int, optional): Beam size for beam search. Defaults to 5.
386
+
387
+ Returns:
388
+ Union[DNASequencePrediction, List[DNASequencePrediction]]: Predicted DNA sequence(s)
389
+ with associated metrics.
390
+ """
391
+ </code></pre>
392
+ </div>
393
+
394
+ <!-- Section 5: Evaluation Metrics -->
395
+ <div class="section">
396
+ <h2 class="section-title">
397
+ <span class="section-number">5</span>
398
+ Evaluation Metrics - CAI & tAI
399
+ </h2>
400
+ <div class="description">
401
+ Functions for calculating Codon Adaptation Index (CAI) and tRNA Adaptation Index (tAI),
402
+ key metrics for evaluating codon optimization quality.
403
+ </div>
404
+ <div class="file-info">
405
+ <div class="file-path">📄 CodonTransformer/CodonEvaluation.py</div>
406
+ <div class="line-range">Lines 23-50, 370-420 | Metrics functions</div>
407
+ </div>
408
+ <div class="key-feature">
409
+ <strong>🎯 Highlight:</strong> CAI and tAI calculation implementations
410
+ </div>
411
+ <pre><code class="language-python">def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
412
+ """
413
+ Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences.
414
+
415
+ Args:
416
+ sequences (List[str]): List of DNA sequences.
417
+
418
+ Returns:
419
+ dict: The CSI weights.
420
+ """
421
+ return relative_adaptiveness(sequences=sequences)
422
+
423
+
424
+ def get_CSI_value(dna: str, weights: Dict[str, float]) -> float:
425
+ """
426
+ Calculate the Codon Similarity Index (CSI) for a DNA sequence.
427
+
428
+ Args:
429
+ dna (str): The DNA sequence.
430
+ weights (dict): The CSI weights from get_CSI_weights.
431
+
432
+ Returns:
433
+ float: The CSI value.
434
+ """
435
+ return CAI(dna, weights)
436
+
437
+
438
+ def get_ecoli_tai_weights():
439
+ """
440
+ Returns pre-calculated tAI weights for E. coli K-12 MG1655.
441
+
442
+ These weights are based on tRNA gene copy numbers and wobble base pairing rules.
443
+ """
444
+ return {
445
+ 'TTT': 0.58, 'TTC': 0.42, 'TTA': 0.13, 'TTG': 0.13,
446
+ 'TCT': 0.15, 'TCC': 0.15, 'TCA': 0.12, 'TCG': 0.15,
447
+ # ... full codon table
448
+ }
449
+
450
+
451
+ def calculate_tAI(sequence: str, tai_weights: Dict[str, float]) -> float:
452
+ """
453
+ Calculate the tRNA Adaptation Index (tAI) for a DNA sequence.
454
+
455
+ Args:
456
+ sequence (str): DNA sequence (must be divisible by 3)
457
+ tai_weights (Dict[str, float]): tAI weights for each codon
458
+
459
+ Returns:
460
+ float: Geometric mean of tAI weights for all codons in the sequence
461
+ """
462
+ if len(sequence) % 3 != 0:
463
+ raise ValueError("Sequence length must be divisible by 3")
464
+
465
+ codons = [sequence[i:i+3].upper() for i in range(0, len(sequence), 3)]
466
+ weights = [tai_weights.get(codon, 0.5) for codon in codons if codon not in ['TAA', 'TAG', 'TGA']]
467
+
468
+ if not weights:
469
+ return 0.0
470
+
471
+ # Geometric mean
472
+ product = 1.0
473
+ for w in weights:
474
+ product *= w
475
+ return product ** (1.0 / len(weights))
476
+ </code></pre>
477
+ </div>
478
+
479
+ <!-- Section 6: Training Configuration -->
480
+ <div class="section">
481
+ <h2 class="section-title">
482
+ <span class="section-number">6</span>
483
+ Training Configuration - ALM Settings
484
+ </h2>
485
+ <div class="description">
486
+ YAML configuration file defining all training hyperparameters, including ALM-specific settings for GC content control.
487
+ </div>
488
+ <div class="file-info">
489
+ <div class="file-path">📄 configs/train_ecoli_alm.yaml</div>
490
+ <div class="line-range">Complete file | Training configuration</div>
491
+ </div>
492
+ <div class="key-feature">
493
+ <strong>🎯 Highlight:</strong> ALM section with gc_target, curriculum_epochs, and penalty parameters
494
+ </div>
495
+ <pre><code class="language-yaml"># ENCOT ALM Training Configuration
496
+ # This configuration reproduces the main training setup from the paper
497
+ # using the Augmented-Lagrangian Method (ALM) for GC content control.
498
+
499
+ model:
500
+ base_model: "adibvafa/CodonTransformer-base"
501
+ tokenizer: "adibvafa/CodonTransformer"
502
+
503
+ data:
504
+ dataset_dir: "data"
505
+ # Expected files: finetune_set.json (created by preprocess_data.py)
506
+
507
+ training:
508
+ batch_size: 6
509
+ max_epochs: 15
510
+ learning_rate: 5e-5
511
+ warmup_fraction: 0.1
512
+ num_workers: 5
513
+ accumulate_grad_batches: 1
514
+ num_gpus: 4
515
+ save_every_n_steps: 512
516
+ seed: 123
517
+ log_every_n_steps: 20
518
+
519
+ checkpoint:
520
+ checkpoint_dir: "models/alm-enhanced-training"
521
+ checkpoint_filename: "balanced_alm_finetune.ckpt"
522
+
523
+ # Augmented-Lagrangian Method (ALM) for GC content control
524
+ alm:
525
+ enabled: true
526
+ gc_target: 0.52 # Target GC content for E. coli (52%)
527
+ curriculum_epochs: 3 # Warm-up epochs before enforcing GC constraint
528
+
529
+ # ALM penalty parameters
530
+ initial_penalty_factor: 20.0
531
+ penalty_update_factor: 10.0
532
+ max_penalty: 1e6
533
+ min_penalty: 1e-6
534
+
535
+ # ALM tolerance parameters
536
+ tolerance: 1e-5 # Primal tolerance
537
+ dual_tolerance: 1e-5 # Dual tolerance for constraint violation
538
+ tolerance_update_factor: 0.1
539
+
540
+ # Adaptive penalty adjustment
541
+ rel_penalty_increase_threshold: 0.1
542
+
543
+ # Legacy penalty method (if ALM disabled)
544
+ gc_penalty:
545
+ weight: 0.0 # Only used if use_lagrangian=false
546
+ </code></pre>
547
+ </div>
548
+
549
+ <!-- Section 7: Data Preparation -->
550
+ <div class="section">
551
+ <h2 class="section-title">
552
+ <span class="section-number">7</span>
553
+ Data Preparation & Validation
554
+ </h2>
555
+ <div class="description">
556
+ Functions for validating and preparing E. coli gene sequences for training, including sequence validation checks.
557
+ </div>
558
+ <div class="file-info">
559
+ <div class="file-path">📄 prepare_ecoli_data.py</div>
560
+ <div class="line-range">Lines 5-30 | Validation function</div>
561
+ </div>
562
+ <div class="key-feature">
563
+ <strong>🎯 Highlight:</strong> Sequence validation rules (start/stop codons, frame, no internal stops)
564
+ </div>
565
+ <pre><code class="language-python">def is_valid_sequence(dna_seq: str) -> bool:
566
+ """
567
+ Applies a series of validation checks to a DNA sequence.
568
+
569
+ Args:
570
+ dna_seq (str): The DNA sequence to validate.
571
+
572
+ Returns:
573
+ bool: True if the sequence is valid, False otherwise.
574
+ """
575
+ # Check if length is divisible by 3 (valid codon frame)
576
+ if len(dna_seq) % 3 != 0:
577
+ return False
578
+
579
+ # Check for valid start codon
580
+ if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
581
+ return False
582
+
583
+ # Check for valid stop codon
584
+ if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
585
+ return False
586
+
587
+ # Check for internal stop codons (excluding the last codon)
588
+ codons = [dna_seq[i:i+3].upper() for i in range(0, len(dna_seq) - 3, 3)]
589
+ if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
590
+ return False
591
+
592
+ # Check if sequence contains only valid nucleotides
593
+ if not all(c in 'ATGC' for c in dna_seq.upper()):
594
+ return False
595
+
596
+ return True
597
+ </code></pre>
598
+ </div>
599
+
600
+ <!-- Section 8: Streamlit GUI -->
601
+ <div class="section">
602
+ <h2 class="section-title">
603
+ <span class="section-number">8</span>
604
+ Streamlit GUI - Main Interface
605
+ </h2>
606
+ <div class="description">
607
+ Web-based graphical interface for ENCOT built with Streamlit, providing user-friendly access to optimization features.
608
+ </div>
609
+ <div class="file-info">
610
+ <div class="file-path">📄 streamlit_gui/app.py</div>
611
+ <div class="line-range">Lines 625-640 | Main function</div>
612
+ </div>
613
+ <div class="key-feature">
614
+ <strong>🎯 Highlight:</strong> Streamlit app structure with tabs and model loading
615
+ </div>
616
+ <pre><code class="language-python">def main():
617
+ st.title("ENCOT")
618
+ st.markdown("E. coli codon optimization with constraint-aware decoding and in silico evaluation metrics.")
619
+
620
+ # Load model
621
+ load_model_and_tokenizer()
622
+
623
+ # Create the main tabbed interface
624
+ tab1, tab2, tab3, tab4 = st.tabs([
625
+ "Single Optimize",
626
+ "Batch Process",
627
+ "Comparative Analysis",
628
+ "Advanced Settings"
629
+ ])
630
+
631
+ with tab1:
632
+ single_sequence_optimization()
633
+
634
+ with tab2:
635
+ batch_processing()
636
+
637
+ with tab3:
638
+ comparative_analysis()
639
+
640
+ with tab4:
641
+ advanced_settings()
642
+
643
+ # Footer
644
+ st.markdown("---")
645
+ st.markdown("**ENCOT**")
646
+ st.markdown("Open-source codon optimization for E. coli with reproducible evaluation.")
647
+ </code></pre>
648
+ </div>
649
+
650
+ <!-- Section 9: Benchmark Evaluation -->
651
+ <div class="section">
652
+ <h2 class="section-title">
653
+ <span class="section-number">9</span>
654
+ Benchmark Evaluation Pipeline
655
+ </h2>
656
+ <div class="description">
657
+ Comprehensive benchmarking pipeline for evaluating ENCOT performance on test sequences with multiple metrics.
658
+ </div>
659
+ <div class="file-info">
660
+ <div class="file-path">📄 benchmark_evaluation.py</div>
661
+ <div class="line-range">Lines 300-400 | Benchmark function</div>
662
+ </div>
663
+ <div class="key-feature">
664
+ <strong>🎯 Highlight:</strong> Multi-metric evaluation (CAI, tAI, GC, cis-elements)
665
+ </div>
666
+ <pre><code class="language-python">def benchmark_sequences(sequences, model, tokenizer, device, cai_weights, tai_weights):
667
+ """
668
+ Run ENCOT on protein sequences and compute metrics for optimized DNA.
669
+
670
+ Args:
671
+ sequences: List of protein sequences to optimize
672
+ model: Loaded ENCOT model
673
+ tokenizer: Tokenizer for the model
674
+ device: PyTorch device (CPU/GPU)
675
+ cai_weights: Pre-computed CAI weights
676
+ tai_weights: Pre-computed tAI weights
677
+
678
+ Returns:
679
+ DataFrame with optimization results and metrics
680
+ """
681
+ results = []
682
+
683
+ for name, protein in tqdm(sequences, desc="Optimizing sequences"):
684
+ # Optimize the sequence
685
+ output = predict_dna_sequence(
686
+ protein=protein,
687
+ organism="Escherichia coli general",
688
+ device=device,
689
+ model=model,
690
+ tokenizer=tokenizer,
691
+ deterministic=True,
692
+ use_constrained_search=True,
693
+ gc_bounds=(0.45, 0.55)
694
+ )
695
+
696
+ optimized_dna = output.predicted_dna
697
+
698
+ # Calculate metrics
699
+ cai = get_CSI_value(optimized_dna, cai_weights)
700
+ tai = calculate_tAI(optimized_dna, tai_weights)
701
+ gc_content = get_GC_content(optimized_dna)
702
+ cis_elements = count_negative_cis_elements(optimized_dna)
703
+
704
+ results.append({
705
+ 'name': name,
706
+ 'protein': protein,
707
+ 'optimized_dna': optimized_dna,
708
+ 'CAI': cai,
709
+ 'tAI': tai,
710
+ 'GC_content': gc_content,
711
+ 'negative_cis_elements': cis_elements
712
+ })
713
+
714
+ return pd.DataFrame(results)
715
+ </code></pre>
716
+ </div>
717
+
718
+ <!-- Section 10: Project Structure -->
719
+ <div class="section">
720
+ <h2 class="section-title">
721
+ <span class="section-number">10</span>
722
+ Project Overview & Architecture
723
+ </h2>
724
+ <div class="description">
725
+ Complete project structure showing the organization of modules, scripts, and configuration files.
726
+ </div>
727
+ <div class="key-feature">
728
+ <strong>🎯 Key Components:</strong> Training (finetune.py), Inference (CodonPrediction.py),
729
+ Evaluation (CodonEvaluation.py), GUI (streamlit_gui/), Configs (configs/)
730
+ </div>
731
+ <pre><code class="language-plaintext">ENCOT/
732
+ ├── CodonTransformer/ # Core library modules
733
+ │ ├── CodonPrediction.py # Model loading & DNA sequence prediction
734
+ │ ├── CodonEvaluation.py # Metrics (CAI, tAI, GC, CFD, etc.)
735
+ │ ├── CodonData.py # Data preprocessing & preparation
736
+ │ ├── CodonUtils.py # Constants, mappings, utilities
737
+ │ └── CodonPostProcessing.py # DNA-Chisel integration
738
+
739
+ ├── scripts/ # Command-line tools
740
+ │ ├── train.py # Training wrapper
741
+ │ ├── optimize_sequence.py # Sequence optimization CLI
742
+ │ ├── run_benchmarks.py # Benchmark evaluation
743
+ │ └── preprocess_data.py # Data preparation
744
+
745
+ ├── configs/ # YAML configurations
746
+ │ ├── train_ecoli_alm.yaml # Main ALM training config ⭐
747
+ │ └── train_ecoli_quick.yaml # Quick test config
748
+
749
+ ├── streamlit_gui/ # Web interface
750
+ │ ├── app.py # Main Streamlit GUI ⭐
751
+ │ ├── demo.py # Demo script
752
+ │ └── run_gui.py # Launcher
753
+
754
+ ├── data/ # Datasets
755
+ │ ├── finetune_set.json # Training data
756
+ │ └── test_set.json # Test data
757
+
758
+ ├── finetune.py # Main training script ⭐⭐⭐
759
+ ├── benchmark_evaluation.py # Evaluation script
760
+ ├── setup.py # Package setup
761
+ ├── pyproject.toml # Project configuration
762
+ └── README.md # Documentation
763
+
764
+ Key Innovations:
765
+ ⭐⭐⭐ Augmented-Lagrangian Method (ALM) for GC control
766
+ ⭐⭐ Constrained beam search with GC bounds
767
+ ⭐ Multi-metric evaluation (CAI, tAI, GC, cis-elements)
768
+ </code></pre>
769
+ </div>
770
+
771
+ <div class="footer">
772
+ <h3>ENCOT - Enhanced Codon Optimization Tool</h3>
773
+ <p>Repository: <a href="https://github.com/geno543/ENCOT" style="color: #58a6ff;">github.com/geno543/ENCOT</a></p>
774
+ <p>© 2026 | Apache License 2.0</p>
775
+ </div>
776
+
777
+ <script>
778
+ // Initialize syntax highlighting
779
+ hljs.highlightAll();
780
+
781
+ // Add line numbers
782
+ document.querySelectorAll('pre code').forEach((block) => {
783
+ const lines = block.innerHTML.split('\n');
784
+ const numberedLines = lines.map((line, index) => {
785
+ return `<span class="line-number" style="color: #6e7681; user-select: none; margin-right: 1em;">${String(index + 1).padStart(3, ' ')}</span>${line}`;
786
+ }).join('\n');
787
+ block.innerHTML = numberedLines;
788
+ });
789
+ </script>
790
+ </body>
791
+ </html>
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 Adibvafa Fallahpour
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Makefile
2
+
3
+ .PHONY: test
4
+ test:
5
+ python -m unittest discover -s tests
6
+
7
+ .PHONY: test_with_coverage
8
+ test_with_coverage:
9
+ coverage run -m unittest discover -s tests
README.md CHANGED
@@ -1,10 +1,495 @@
1
- ---
2
- title: ColiFormer Ui
3
- emoji: 👁
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression
2
+
3
+
4
+ <p align="center">
5
+ <a href="https://huggingface.co/saketh11/ColiFormer"><img src="https://img.shields.io/badge/HuggingFace-Model-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Model"></a>
6
+ <a href="https://huggingface.co/datasets/saketh11/ColiFormer-Data"><img src="https://img.shields.io/badge/HuggingFace-Data-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a>
7
+ </p>
8
+
9
+ ## Abstract
10
+
11
+ ENCOT is a transformer-based model for codon optimization of protein sequences in *Escherichia coli*. Built on top of CodonTransformer (a multi-species BigBird model trained on over 1 million DNA–protein pairs), ENCOT is fine-tuned specifically for E. coli codon preferences using 3,676 high-expression E. coli genes curated from NCBI.
12
+
13
+ ENCOT balances multiple objectives (CAI, GC content, tAI, RNA stability, and minimization of negative cis-regulatory elements) and uses an **Augmented-Lagrangian Method (ALM)** to enforce GC content control during training. Performance was evaluated on 37,053 native E. coli genes and 80 recombinant protein targets, demonstrating strong improvements in in silico expression metrics while maintaining biologically appropriate constraints.
14
+
15
+ ## Paper Reference
16
+
17
+ **ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression**
18
+
19
+ Saketh Baddam, Omar Emam, Abdelrahman Elfikky, Francesco Cavarretta, George Luka, Ibrahim Farag, Yasser Sanad
20
+
21
+ bioRxiv preprint (not peer-reviewed): `https://doi.org/10.1101/2025.11.26.690826`
22
+
23
+ **What does “preprint and not peer-reviewed” mean?** A preprint is a publicly available manuscript shared before formal journal peer review. It can be cited, but its claims have not yet been evaluated by journal referees.
24
+
25
+ ### Citation
26
+
27
+ If you use ENCOT in your research, please cite:
28
+
29
+ ```bibtex
30
+ @article{encot2025,
31
+ title{ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression},
32
+ author={Baddam, Saketh and Emam, Omar and Elfikky, Abdelrahman and Cavarretta, Francesco and Luka, George and Farag, Ibrahim and Sanad, Yasser},
33
+ journal={bioRxiv},
34
+ year={2025},
35
+ doi={10.1101/2025.11.26.690826},
36
+ url={https://doi.org/10.1101/2025.11.26.690826},
37
+ note={Preprint (not peer-reviewed)}
38
+ }
39
+ ```
40
+
41
+ ## Quick Start
42
+
43
+ Optimize a protein sequence in just a few lines:
44
+
45
+ ```python
46
+ import torch
47
+ from transformers import AutoTokenizer
48
+ from CodonTransformer.CodonPrediction import load_model, predict_dna_sequence
49
+ from huggingface_hub import hf_hub_download
50
+
51
+ # Load model from Hugging Face
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ checkpoint_path = hf_hub_download(
54
+ repo_id="saketh11/ColiFormer",
55
+ filename="balanced_alm_finetune.ckpt",
56
+ cache_dir="./hf_cache"
57
+ )
58
+ model = load_model(model_path=checkpoint_path, device=device)
59
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
60
+
61
+ # Optimize a protein sequence
62
+ protein = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
63
+ output = predict_dna_sequence(
64
+ protein=protein,
65
+ organism="Escherichia coli general",
66
+ device=device,
67
+ model=model,
68
+ tokenizer=tokenizer,
69
+ deterministic=True,
70
+ match_protein=True
71
+ )
72
+
73
+ print(f"Optimized DNA: {output.predicted_dna}")
74
+ ```
75
+
76
+ Or use the command-line interface:
77
+
78
+ ```bash
79
+ python scripts/optimize_sequence.py \
80
+ --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" \
81
+ --output optimized.fasta
82
+ ```
83
+
84
+ ## Installation
85
+
86
+ ### Requirements
87
+
88
+ - Python >= 3.9
89
+ - CUDA-capable GPU (recommended for training, optional for inference)
90
+
91
+ ### Setup
92
+
93
+ 1. **Clone the repository:**
94
+
95
+ ```bash
96
+ git clone https://github.com/geno543/ENCOT.git
97
+ cd ENCOT
98
+ ```
99
+
100
+ 2. **Create a virtual environment:**
101
+
102
+ ```bash
103
+ python -m venv venv
104
+ source venv/bin/activate # On Windows: venv\Scripts\activate
105
+ ```
106
+
107
+ 3. **Install dependencies:**
108
+
109
+ ```bash
110
+ pip install -r requirements.txt
111
+ ```
112
+
113
+ The installation takes approximately 10-30 seconds depending on your system and existing packages.
114
+
115
+ ## Public Streamlit Demo (Anyone Can Try It)
116
+
117
+ If you want a public link so anyone can test ENCOT in a browser, deploy the app with either Streamlit Community Cloud or Hugging Face Spaces.
118
+
119
+ ### Option A: Streamlit Community Cloud (Fastest)
120
+
121
+ 1. Push this repository to GitHub.
122
+ 2. Go to https://share.streamlit.io and sign in.
123
+ 3. Click **New app** and choose your repository.
124
+ 4. Set **Main file path** to `streamlit_app.py`.
125
+ 5. Use the repository `requirements.txt` for dependencies.
126
+ 6. Deploy and share the generated public URL.
127
+
128
+ ### Option B: Hugging Face Spaces (Streamlit)
129
+
130
+ 1. Create a new Space (SDK: **Streamlit**).
131
+ 2. Upload this project (or connect the GitHub repo).
132
+ 3. Ensure app file is `streamlit_app.py`.
133
+ 4. Keep the repo public so anyone can access the Space URL.
134
+
135
+ ### Local check before deployment
136
+
137
+ ```bash
138
+ streamlit run streamlit_app.py --server.port 8501
139
+ ```
140
+
141
+ This uses the existing UI in `streamlit_gui/app.py`, including model loading from Hugging Face and optimization controls.
142
+
143
+ ## Data Preparation
144
+
145
+ ### Preparing E. coli Training Data
146
+
147
+ To prepare training data from raw E. coli gene sequences:
148
+
149
+ 1. **Place your data files in the `data/` directory:**
150
+ - `data/CAI.csv` - CSV file with columns: gene_id, cai_score, dna_sequence
151
+ - `data/Database 3_4300 gene.csv` - CSV file with high-CAI sequences (column: dna_sequence)
152
+
153
+ 2. **Run the preprocessing script:**
154
+
155
+ ```bash
156
+ python scripts/preprocess_data.py
157
+ ```
158
+
159
+ This will:
160
+ - Validate and process DNA sequences
161
+ - Create `data/ecoli_processed_genes.csv` with validated sequences
162
+ - Generate `data/finetune_set.json` for training (high-CAI sequences)
163
+ - Generate `data/test_set.json` for evaluation (100 random sequences)
164
+
165
+ **Custom paths:**
166
+
167
+ ```bash
168
+ python scripts/preprocess_data.py \
169
+ --cai_csv data/my_cai_data.csv \
170
+ --high_cai_csv data/my_high_cai_data.csv \
171
+ --output_dir my_data \
172
+ --test_size 200
173
+ ```
174
+
175
+ ### Dataset Structure
176
+
177
+ The processed dataset includes:
178
+ - **Training set**: 4,300 high-CAI E. coli sequences (from `Database 3_4300 gene.csv`)
179
+ - **Test set**: 100 randomly sampled sequences (for evaluation)
180
+ - **Reference sequences**: 50,000+ E. coli genes for CAI/tAI calculation
181
+
182
+ The complete dataset is available at [saketh11/ColiFormer-Data](https://huggingface.co/datasets/saketh11/ColiFormer-Data) on Hugging Face.
183
+
184
+ ## Training
185
+
186
+ ### Quick Start Training
187
+
188
+ Train ENCOT with the default ALM configuration:
189
+
190
+ ```bash
191
+ python scripts/train.py --config configs/train_ecoli_alm.yaml
192
+ ```
193
+
194
+ ### Configuration Files
195
+
196
+ We provide three configuration files:
197
+
198
+ 1. **`configs/train_ecoli_alm.yaml`** - Main training configuration with ALM GC control
199
+ - 15 epochs, batch size 6, 4 GPUs
200
+ - ALM enabled with GC target 52%
201
+ - Curriculum learning: 3 warm-up epochs
202
+
203
+ 2. **`configs/train_ecoli_quick.yaml`** - Quick sanity check
204
+ - 1 epoch, batch size 2, CPU-only
205
+ - Useful for testing your setup
206
+
207
+ 3. **`configs/benchmark.yaml`** - Benchmark evaluation settings
208
+
209
+ ### Training Parameters
210
+
211
+ Key parameters in the config files:
212
+
213
+ - **`training.batch_size`**: Batch size (default: 6)
214
+ - **`training.max_epochs`**: Number of training epochs (default: 15)
215
+ - **`training.learning_rate`**: Learning rate (default: 5e-5)
216
+ - **`training.num_gpus`**: Number of GPUs (default: 4)
217
+ - **`alm.enabled`**: Enable ALM GC control (default: true)
218
+ - **`alm.gc_target`**: Target GC content (default: 0.52 for E. coli)
219
+ - **`alm.curriculum_epochs`**: Warm-up epochs before enforcing GC constraint (default: 3)
220
+
221
+ ### Override Config Values
222
+
223
+ You can override config values from the command line:
224
+
225
+ ```bash
226
+ python scripts/train.py \
227
+ --config configs/train_ecoli_alm.yaml \
228
+ --num_gpus 2 \
229
+ --batch_size 4 \
230
+ --max_epochs 10
231
+ ```
232
+
233
+ ### Training Output
234
+
235
+ Checkpoints are saved to the directory specified in `checkpoint.checkpoint_dir`:
236
+ - Model state dict: `balanced_alm_finetune.ckpt`
237
+ - Training logs: TensorBoard logs in the checkpoint directory
238
+
239
+ Monitor training progress:
240
+
241
+ ```bash
242
+ tensorboard --logdir models/alm-enhanced-training
243
+ ```
244
+
245
+ ## Inference / Sequence Optimization
246
+
247
+ ### Single Sequence Optimization
248
+
249
+ Optimize a single protein sequence:
250
+
251
+ ```bash
252
+ python scripts/optimize_sequence.py \
253
+ --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" \
254
+ --output optimized.fasta
255
+ ```
256
+
257
+ ### Batch Processing
258
+
259
+ Process multiple sequences from a FASTA file:
260
+
261
+ ```bash
262
+ python scripts/optimize_sequence.py \
263
+ --input sequences.fasta \
264
+ --output optimized.fasta \
265
+ --batch
266
+ ```
267
+
268
+ ### GC Content Constraints
269
+
270
+ Specify GC content bounds:
271
+
272
+ ```bash
273
+ python scripts/optimize_sequence.py \
274
+ --input protein.fasta \
275
+ --output optimized.fasta \
276
+ --gc-min 0.45 \
277
+ --gc-max 0.55
278
+ ```
279
+
280
+ ### Using Custom Checkpoint
281
+
282
+ ```bash
283
+ python scripts/optimize_sequence.py \
284
+ --input protein.fasta \
285
+ --output optimized.fasta \
286
+ --checkpoint models/my_model.ckpt
287
+ ```
288
+
289
+ ### Python API
290
+
291
+ For programmatic use:
292
+
293
+ ```python
294
+ from CodonTransformer.CodonPrediction import load_model, predict_dna_sequence
295
+ from transformers import AutoTokenizer
296
+ import torch
297
+
298
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
299
+ model = load_model(model_path="models/alm-enhanced-training/balanced_alm_finetune.ckpt", device=device)
300
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
301
+
302
+ output = predict_dna_sequence(
303
+ protein="MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG",
304
+ organism="Escherichia coli general",
305
+ device=device,
306
+ model=model,
307
+ tokenizer=tokenizer,
308
+ deterministic=True,
309
+ match_protein=True,
310
+ use_constrained_search=True,
311
+ gc_bounds=(0.45, 0.55),
312
+ beam_size=20
313
+ )
314
+
315
+ print(f"Optimized DNA: {output.predicted_dna}")
316
+ ```
317
+
318
+ ## Reproducing Paper Results
319
+
320
+ ### Benchmark Evaluation
321
+
322
+ To reproduce the benchmark results from the paper:
323
+
324
+ 1. **Prepare benchmark sequences:**
325
+
326
+ Place your benchmark sequences in an Excel file (see `Benchmark 80 sequences.xlsx` for format).
327
+
328
+ 2. **Run benchmark evaluation:**
329
+
330
+ ```bash
331
+ python scripts/run_benchmarks.py --config configs/benchmark.yaml
332
+ ```
333
+
334
+ This will:
335
+ - Load the fine-tuned ENCOT model
336
+ - Optimize all sequences in the benchmark file
337
+ - Calculate metrics (CAI, tAI, GC content, CFD, negative cis-elements)
338
+ - Generate comparison plots and summary statistics
339
+ - Save results to `benchmark_results/run_TIMESTAMP/`
340
+
341
+ ### Expected Results
342
+
343
+ On the benchmark set of 80 sequences:
344
+ - **CAI improvement**: +6.2% vs base CodonTransformer
345
+ - **tAI improvement**: +8.6% vs base CodonTransformer
346
+ - **GC content**: Mean 52.1% (target: 52%)
347
+ - **Runtime**: ~1-3 seconds per sequence (GPU)
348
+
349
+ ### Custom Benchmark
350
+
351
+ ```bash
352
+ python scripts/run_benchmarks.py \
353
+ --excel_path my_benchmark.xlsx \
354
+ --checkpoint_path models/my_model.ckpt \
355
+ --output_dir my_results \
356
+ --use_gpu
357
+ ```
358
+
359
+ ## Model Architecture
360
+
361
+ ### Base Model
362
+
363
+ ENCOT is built on CodonTransformer, a BigBird transformer model:
364
+ - **Architecture**: BigBirdForMaskedLM (89.6M parameters)
365
+ - **Pre-training**: 1M+ DNA-protein pairs from 164 organisms
366
+ - **Context length**: 2048 tokens
367
+ - **Attention**: Block-sparse attention for efficiency
368
+
369
+ ### Fine-tuning
370
+
371
+ ENCOT is fine-tuned on E. coli-specific data:
372
+ - **Training data**: 4,300 high-CAI E. coli sequences
373
+ - **Loss function**: Masked Language Modeling (MLM) + GC constraint
374
+ - **Optimizer**: AdamW with CosineAnnealingWarmRestarts scheduler
375
+ - **Learning rate**: 5e-5 with 10% warmup
376
+
377
+ ### Augmented-Lagrangian Method (ALM)
378
+
379
+ The ALM approach enforces GC content constraints during training:
380
+
381
+ **Objective function:**
382
+ ```
383
+ L = L_MLM + λ·(GC - μ) + (ρ/2)(GC - μ)²
384
+ ```
385
+
386
+ Where:
387
+ - `L_MLM`: Masked language modeling loss
388
+ - `λ`: Lagrangian multiplier (updated adaptively)
389
+ - `ρ`: Penalty coefficient (self-tuning)
390
+ - `GC`: Mean GC content (sliding window of 50 codons)
391
+ - `μ`: Target GC content (0.52 for E. coli)
392
+
393
+ **Key features:**
394
+ - **Curriculum learning**: 3 warm-up epochs before enforcing GC constraint
395
+ - **Adaptive penalty**: Penalty coefficient increases if constraint violation doesn't improve
396
+ - **Self-tuning**: Lagrangian multiplier and penalty updated every 20 steps
397
+
398
+ This approach allows the model to learn codon preferences while maintaining precise GC content control, critical for synthesis and expression in E. coli.
399
+
400
+ ## Evaluation Metrics
401
+
402
+ ENCOT computes comprehensive metrics for optimized sequences:
403
+
404
+ - **CAI (Codon Adaptation Index)**: Measures similarity to highly expressed genes (0-1, higher is better)
405
+ - **tAI (tRNA Adaptation Index)**: Measures tRNA availability (0-1, higher is better)
406
+ - **GC Content**: Percentage of G+C nucleotides (target: 52% for E. coli)
407
+ - **CFD (Codon Frequency Distribution)**: Similarity to reference codon frequencies
408
+ - **Negative cis-elements**: Count of problematic sequence motifs
409
+ - **Homopolymer runs**: Long repeats that can cause synthesis issues
410
+
411
+ ## Project Structure
412
+
413
+ ```
414
+ encot/
415
+ ├── configs/ # YAML configuration files
416
+ │ ├── train_ecoli_alm.yaml # Main training config
417
+ │ ├── train_ecoli_quick.yaml # Quick test config
418
+ │ └── benchmark.yaml # Benchmark config
419
+ ├── scripts/ # Entry-point scripts
420
+ │ ├── preprocess_data.py # Data preparation
421
+ │ ├── train.py # Training wrapper
422
+ │ ├── optimize_sequence.py # Sequence optimization
423
+ │ └── run_benchmarks.py # Benchmark evaluation
424
+ ├── CodonTransformer/ # Core module (custom, not PyPI)
425
+ │ ├── CodonPrediction.py # Model loading & inference
426
+ │ ├── CodonEvaluation.py # Metrics calculation
427
+ │ ├── CodonData.py # Data preprocessing
428
+ │ └── ...
429
+ ├── data/ # Datasets
430
+ │ ├── finetune_set.json # Training data
431
+ │ ├── test_set.json # Test data
432
+ │ └── ecoli_processed_genes.csv # Reference sequences
433
+ ├── models/ # Model checkpoints
434
+ ├── notebooks/ # Jupyter notebooks
435
+ ├── tests/ # Test suite
436
+ ├── streamlit_gui/ # Streamlit web interface
437
+ ├── finetune.py # Training script (original)
438
+ ├── benchmark_evaluation.py # Evaluation script (original)
439
+ └── README.md # This file
440
+ ```
441
+
442
+ ## Troubleshooting
443
+
444
+ ### Common Issues
445
+
446
+ **1. CUDA out of memory:**
447
+ - Reduce `batch_size` in config file
448
+ - Use gradient accumulation: increase `accumulate_grad_batches`
449
+
450
+ **2. Model checkpoint not found:**
451
+ - The script will auto-download from Hugging Face if local checkpoint missing
452
+ - Ensure you have internet connection for first run
453
+
454
+ **3. Data preprocessing errors:**
455
+ - Verify CSV files have correct column names
456
+ - Check that DNA sequences are valid (divisible by 3, proper start/stop codons)
457
+
458
+ **4. Import errors:**
459
+ - Ensure you've activated the virtual environment
460
+ - Run `pip install -r requirements.txt` again
461
+
462
+ ### Getting Help
463
+
464
+ - **Issues**: Open an issue on GitHub
465
+ - **Questions**: Check the documentation or contact the authors
466
+
467
+ ## License
468
+
469
+ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
470
+
471
+ ## Acknowledgments
472
+
473
+ - **CodonTransformer**: Base model from [adibvafa/CodonTransformer](https://github.com/adibvafa/CodonTransformer)
474
+ - **Hugging Face**: Model hosting and distribution
475
+ - **E. coli data**: NCBI and Kazusa codon usage databases
476
+
477
+ ## Citation
478
+
479
+ If you use ENCOT in your research, please cite:
480
+
481
+ ```bibtex
482
+ @article{encot2025,
483
+ title={ENCOT: A Transformer-Based Codon Optimization Model Balancing Multiple Objectives for Enhanced E. coli Gene Expression},
484
+ author={Baddam, Saketh and Emam, Omar and Elfikky, Abdelrahman and Cavarretta, Francesco and Luka, George and Farag, Ibrahim and Sanad, Yasser},
485
+ journal={bioRxiv},
486
+ year={2025},
487
+ doi={10.1101/2025.11.26.690826},
488
+ url={https://doi.org/10.1101/2025.11.26.690826},
489
+ note={Preprint (not peer-reviewed)}
490
+ }
491
+ ```
492
+
493
+ ---
494
+
495
+ **ENCOT** - State-of-the-art codon optimization for E. coli expression systems.
app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face Spaces Streamlit entrypoint for ENCOT."""
2
+
3
+ from pathlib import Path
4
+ import sys
5
+
6
+
7
+ ROOT = Path(__file__).resolve().parent
8
+ if str(ROOT) not in sys.path:
9
+ sys.path.insert(0, str(ROOT))
10
+
11
+ # Importing this module executes the Streamlit UI.
12
+ import streamlit_gui.app # noqa: F401,E402
benchmark_evaluation.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: benchmark_evaluation.py
3
+ ------------------------------
4
+ Benchmark E. coli protein sequences with ENCOT, generate optimized DNA,
5
+ compute metrics (CAI, tAI, GC, CFD, cis-elements), and produce summary tables
6
+ and figures.
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ import argparse
12
+ import pandas as pd
13
+ import numpy as np
14
+ import torch
15
+ import json
16
+ import matplotlib.pyplot as plt
17
+ import seaborn as sns
18
+ from datetime import datetime
19
+ import time
20
+ from tqdm import tqdm
21
+ from typing import Dict, List, Tuple, Any
22
+
23
+ from CAI import CAI, relative_adaptiveness
24
+ from CodonTransformer.CodonData import (
25
+ download_codon_frequencies_from_kazusa,
26
+ get_codon_frequencies,
27
+ )
28
+ from CodonTransformer.CodonPrediction import (
29
+ load_model,
30
+ predict_dna_sequence,
31
+ )
32
+ from CodonTransformer.CodonEvaluation import (
33
+ get_GC_content,
34
+ get_ecoli_tai_weights,
35
+ get_min_max_profile,
36
+ calculate_tAI,
37
+ count_negative_cis_elements,
38
+ )
39
+ from transformers import AutoTokenizer
40
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
41
+ from evaluate_optimizer import translate_dna_to_protein
42
+
43
+
44
+ def find_longest_orf(dna_sequence: str) -> str:
45
+ """
46
+ Find the longest open reading frame (ORF) in a DNA sequence.
47
+
48
+ Args:
49
+ dna_sequence (str): Input DNA sequence (ATCGN characters).
50
+
51
+ Returns:
52
+ str: Longest ORF (from start to stop codon), or empty string if none.
53
+ """
54
+ dna_sequence = dna_sequence.upper()
55
+ start_codons = ['ATG']
56
+ stop_codons = ['TAA', 'TAG', 'TGA']
57
+
58
+ longest_orf = ""
59
+
60
+ for frame in range(3):
61
+ current_orf = ""
62
+ in_orf = False
63
+
64
+ for i in range(frame, len(dna_sequence) - 2, 3):
65
+ codon = dna_sequence[i:i+3]
66
+ if len(codon) != 3:
67
+ break
68
+
69
+ if codon in start_codons and not in_orf:
70
+ in_orf = True
71
+ current_orf = codon
72
+ elif in_orf:
73
+ current_orf += codon
74
+ if codon in stop_codons:
75
+ if len(current_orf) > len(longest_orf):
76
+ longest_orf = current_orf
77
+ in_orf = False
78
+ current_orf = ""
79
+
80
+ if in_orf and len(current_orf) > len(longest_orf):
81
+ longest_orf = current_orf
82
+
83
+ return longest_orf
84
+
85
+
86
+ def _detect_columns(df: pd.DataFrame, name_hint: str | None = None, seq_hint: str | None = None) -> tuple[str | None, str]:
87
+ """
88
+ Detect name and sequence columns in a case-insensitive, robust way.
89
+
90
+ Args:
91
+ df (pd.DataFrame): Input DataFrame read from Excel.
92
+ name_hint (str | None): Optional override for name/label column (case-insensitive).
93
+ seq_hint (str | None): Optional override for sequence column (case-insensitive).
94
+
95
+ Returns:
96
+ tuple[str | None, str]: Detected (name_column or None, sequence_column).
97
+
98
+ Raises:
99
+ ValueError: If a sequence-like column cannot be found.
100
+ """
101
+ cols = list(df.columns)
102
+ low_map = {c.lower().strip(): c for c in cols}
103
+
104
+ # If hints are provided and exist (case-insensitive), honor them
105
+ if name_hint:
106
+ nh = name_hint.lower().strip()
107
+ if nh in low_map:
108
+ name_col = low_map[nh]
109
+ else:
110
+ name_col = None
111
+ else:
112
+ name_col = None
113
+
114
+ if seq_hint:
115
+ sh = seq_hint.lower().strip()
116
+ if sh in low_map:
117
+ seq_col = low_map[sh]
118
+ else:
119
+ seq_col = None
120
+ else:
121
+ seq_col = None
122
+
123
+ # If not found, try candidates
124
+ if name_col is None:
125
+ name_candidates = [
126
+ 'name','id','title','gene','protein','description','label','accession','locus','entry','uniprot','ncbi','protein name'
127
+ ]
128
+ for k in name_candidates:
129
+ if k in low_map:
130
+ name_col = low_map[k]
131
+ break
132
+
133
+ if seq_col is None:
134
+ seq_candidates = [
135
+ # protein-first
136
+ 'protein sequence','protein_sequence','protein','aa sequence','aa_sequence','aa','amino acid sequence','amino_acid_sequence',
137
+ # generic
138
+ 'sequence','seq',
139
+ # dna/cds
140
+ 'cds','dna','coding sequence','coding_sequence','cds sequence','cds_sequence'
141
+ ]
142
+ for k in seq_candidates:
143
+ if k in low_map:
144
+ seq_col = low_map[k]
145
+ break
146
+
147
+ if not seq_col:
148
+ raise ValueError(f"Could not detect sequence column. Available columns: {cols}")
149
+
150
+ return name_col, seq_col
151
+
152
+
153
+ def parse_excel_sequences(excel_path: str, name_col: str | None = None, seq_col: str | None = None, sheet_name: str | int | None = None) -> List[Dict[str, str]]:
154
+ """
155
+ Parse sequences from the benchmark Excel file and auto-detect relevant columns.
156
+
157
+ Args:
158
+ excel_path (str): Path to the Excel file.
159
+ name_col (str | None): Optional override for sequence name column.
160
+ seq_col (str | None): Optional override for sequence column.
161
+ sheet_name (str | int | None): Sheet name or index (default: first sheet).
162
+
163
+ Returns:
164
+ List[Dict[str, str]]: List of standardized sequence records with fields:
165
+ id, name, protein_sequence, original_sequence (DNA or None), is_dna.
166
+
167
+ Raises:
168
+ ValueError: If a sequence column cannot be detected.
169
+ """
170
+ sn = sheet_name
171
+ if isinstance(sn, str) and sn.isdigit():
172
+ sn = int(sn)
173
+ if sn is None:
174
+ sn = 0
175
+
176
+ df_or_dict = pd.read_excel(excel_path, sheet_name=sn)
177
+ if isinstance(df_or_dict, dict):
178
+ first_title, df = next(iter(df_or_dict.items()))
179
+ print(f"Using sheet: {first_title}")
180
+ else:
181
+ df = df_or_dict
182
+ sequences = []
183
+
184
+ detected_name_col, detected_seq_col = _detect_columns(df, name_col, seq_col)
185
+ print(f"Detected columns -> name: {detected_name_col or '[generated]'}, sequence: {detected_seq_col}")
186
+
187
+ for idx, row in df.iterrows():
188
+ sequence = str(row[detected_seq_col]).strip()
189
+ if detected_name_col:
190
+ name = str(row[detected_name_col]).strip()
191
+ else:
192
+ name = f"seq_{idx}"
193
+
194
+ if name.startswith('>'):
195
+ name = name[1:].strip()
196
+
197
+ sequence = ''.join(filter(str.isalpha, sequence))
198
+
199
+ dna_chars = sum(1 for c in sequence.upper() if c in 'ATCGN')
200
+ is_dna = (dna_chars / len(sequence)) > 0.95 if len(sequence) > 0 else False
201
+
202
+ if is_dna:
203
+ longest_orf = find_longest_orf(sequence)
204
+
205
+ if longest_orf and len(longest_orf) >= 30:
206
+ original_dna = longest_orf
207
+ protein_seq = translate_dna_to_protein(longest_orf)
208
+ else:
209
+ truncated_len = (len(sequence) // 3) * 3
210
+ if truncated_len >= 30:
211
+ original_dna = sequence[:truncated_len]
212
+ protein_seq = translate_dna_to_protein(original_dna)
213
+ else:
214
+ continue
215
+
216
+ if '*' in protein_seq:
217
+ stop_pos = protein_seq.find('*')
218
+ if stop_pos >= 10:
219
+ protein_seq = protein_seq[:stop_pos]
220
+ original_dna = original_dna[:stop_pos*3]
221
+ else:
222
+ continue
223
+
224
+ else:
225
+ protein_seq = sequence.upper()
226
+ protein_seq = protein_seq.replace('*', '')
227
+ original_dna = None
228
+
229
+ if len(protein_seq) < 10:
230
+ continue
231
+
232
+ sequences.append({
233
+ 'id': idx,
234
+ 'name': name,
235
+ 'protein_sequence': protein_seq,
236
+ 'original_sequence': original_dna,
237
+ 'is_dna': is_dna
238
+ })
239
+
240
+ return sequences
241
+
242
+
243
+ def calculate_cfd(dna_sequence: str, codon_frequencies: Dict) -> float:
244
+ """
245
+ Calculate Codon Frequency Distribution (CFD) similarity to a reference.
246
+
247
+ Args:
248
+ dna_sequence (str): Input DNA sequence.
249
+ codon_frequencies (Dict): Reference frequencies; accepts flattened mapping
250
+ or an amino2codon structure (will be flattened).
251
+
252
+ Returns:
253
+ float: Similarity score in [0, 1] where higher is more similar.
254
+ """
255
+ if not dna_sequence:
256
+ return 0.0
257
+
258
+ codon_count = {}
259
+ total_codons = 0
260
+
261
+ for i in range(0, len(dna_sequence) - 2, 3):
262
+ codon = dna_sequence[i:i+3].upper()
263
+ if len(codon) == 3:
264
+ codon_count[codon] = codon_count.get(codon, 0) + 1
265
+ total_codons += 1
266
+
267
+ seq_freq = {}
268
+ if total_codons > 0:
269
+ for codon, count in codon_count.items():
270
+ seq_freq[codon] = count / total_codons
271
+
272
+ # Flatten amino2codon frequencies if needed
273
+ flat_codon_freq = {}
274
+ if isinstance(codon_frequencies, dict):
275
+ first_key = next(iter(codon_frequencies.keys()))
276
+ if isinstance(codon_frequencies[first_key], tuple) and len(codon_frequencies[first_key]) == 2:
277
+ for amino, (codons, freqs) in codon_frequencies.items():
278
+ for codon, freq in zip(codons, freqs):
279
+ flat_codon_freq[codon] = freq
280
+ else:
281
+ flat_codon_freq = codon_frequencies
282
+
283
+ similarity = 0.0
284
+ count = 0
285
+
286
+ for codon in set(list(seq_freq.keys()) + list(flat_codon_freq.keys())):
287
+ seq_f = seq_freq.get(codon, 0.0)
288
+ ref_f = flat_codon_freq.get(codon, 0.0)
289
+ similarity += 1 - abs(seq_f - ref_f)
290
+ count += 1
291
+
292
+ return similarity / count if count > 0 else 0.0
293
+
294
+
295
+ def run_model_on_sequences(
296
+ sequences: List[Dict],
297
+ model,
298
+ tokenizer,
299
+ device,
300
+ cai_weights: Dict,
301
+ tai_weights: Dict,
302
+ codon_frequencies: Dict,
303
+ reference_profile: List[float],
304
+ output_dir: str
305
+ ) -> pd.DataFrame:
306
+ """
307
+ Run ColiFormer on protein sequences and compute metrics for optimized DNA.
308
+
309
+ Args:
310
+ sequences (List[Dict]): Parsed sequence records.
311
+ model: Loaded ColiFormer model.
312
+ tokenizer: Tokenizer used by the model.
313
+ device: Torch device.
314
+ cai_weights (Dict): CAI weights.
315
+ tai_weights (Dict): tAI weights.
316
+ codon_frequencies (Dict): Reference codon frequencies.
317
+ reference_profile (List[float]): Reserved for DTW profile (unused here).
318
+ output_dir (str): Directory for outputs (not written here).
319
+
320
+ Returns:
321
+ pd.DataFrame: Per-sequence metrics and optimized DNA.
322
+ """
323
+ results = []
324
+ print(f"Processing {len(sequences)} sequences...")
325
+
326
+ for seq_data in tqdm(sequences, desc="Optimizing sequences"):
327
+ protein_seq = seq_data['protein_sequence']
328
+
329
+ if len(protein_seq) < 10:
330
+ continue
331
+
332
+ try:
333
+ start_time = time.time()
334
+
335
+ output = predict_dna_sequence(
336
+ protein=protein_seq,
337
+ organism="Escherichia coli general",
338
+ device=device,
339
+ model=model,
340
+ deterministic=True,
341
+ match_protein=True,
342
+ )
343
+
344
+ runtime = time.time() - start_time
345
+
346
+ if isinstance(output, list):
347
+ optimized_dna = output[0].predicted_dna
348
+ else:
349
+ optimized_dna = output.predicted_dna
350
+
351
+ original_metrics = {}
352
+ if seq_data['is_dna'] and seq_data['original_sequence']:
353
+ original_dna = seq_data['original_sequence'].upper()
354
+ original_metrics = {
355
+ 'original_cai': CAI(original_dna, weights=cai_weights),
356
+ 'original_gc': get_GC_content(original_dna),
357
+ 'original_tai': calculate_tAI(original_dna, tai_weights),
358
+ 'original_cfd': calculate_cfd(original_dna, codon_frequencies),
359
+ 'original_neg_cis': count_negative_cis_elements(original_dna),
360
+ }
361
+
362
+ optimized_metrics = {
363
+ 'optimized_cai': CAI(optimized_dna, weights=cai_weights),
364
+ 'optimized_gc': get_GC_content(optimized_dna),
365
+ 'optimized_tai': calculate_tAI(optimized_dna, tai_weights),
366
+ 'optimized_cfd': calculate_cfd(optimized_dna, codon_frequencies),
367
+ 'optimized_neg_cis': count_negative_cis_elements(optimized_dna),
368
+ 'runtime': runtime,
369
+ }
370
+
371
+ result = {
372
+ 'id': seq_data['id'],
373
+ 'name': seq_data['name'],
374
+ 'protein_sequence': protein_seq,
375
+ 'protein_length': len(protein_seq),
376
+ 'optimized_dna': optimized_dna,
377
+ **original_metrics,
378
+ **optimized_metrics,
379
+ }
380
+ results.append(result)
381
+
382
+ except Exception as e:
383
+ print(f"Error processing sequence {seq_data['id']}: {str(e)}")
384
+ continue
385
+
386
+ return pd.DataFrame(results)
387
+
388
+
389
+ def generate_visualizations(results_df: pd.DataFrame, output_dir: str):
390
+ """
391
+ Generate visualizations and a metrics summary table.
392
+
393
+ Saves:
394
+ - CAI before/after bar plot
395
+ - Median CAI comparison
396
+ - Metrics distribution panel
397
+ - CSV summary table
398
+
399
+ Args:
400
+ results_df (pd.DataFrame): Results from optimization.
401
+ output_dir (str): Output directory root.
402
+
403
+ Returns:
404
+ pd.DataFrame: Summary table of aggregate metrics.
405
+ """
406
+ plt.style.use('seaborn-v0_8-darkgrid')
407
+ sns.set_palette("husl")
408
+
409
+ fig_dir = os.path.join(output_dir, 'figures')
410
+ os.makedirs(fig_dir, exist_ok=True)
411
+
412
+ # 1. Before/After CAI Graph
413
+ if 'original_cai' in results_df.columns:
414
+ plt.figure(figsize=(12, 8))
415
+
416
+ before_cai = results_df['original_cai'].dropna()
417
+ after_cai = results_df.loc[before_cai.index, 'optimized_cai']
418
+
419
+ x = np.arange(len(before_cai))
420
+ width = 0.35
421
+
422
+ fig, ax = plt.subplots(figsize=(14, 8))
423
+ bars1 = ax.bar(x - width/2, before_cai, width, label='Before Optimization', alpha=0.8)
424
+ bars2 = ax.bar(x + width/2, after_cai, width, label='After Optimization', alpha=0.8)
425
+
426
+ ax.set_xlabel('Sequence Index', fontsize=12)
427
+ ax.set_ylabel('CAI Score', fontsize=12)
428
+ ax.set_title('ENCOT: CAI Before and After Optimization', fontsize=14, fontweight='bold')
429
+ ax.set_xticks(x[::5]) # Show every 5th label
430
+ ax.set_xticklabels(x[::5])
431
+ ax.legend()
432
+ ax.grid(axis='y', alpha=0.3)
433
+
434
+ avg_before = before_cai.mean()
435
+ avg_after = after_cai.mean()
436
+ improvement = ((avg_after - avg_before) / avg_before) * 100
437
+
438
+ ax.text(0.02, 0.98, f'Average CAI Before: {avg_before:.3f}\nAverage CAI After: {avg_after:.3f}\nImprovement: {improvement:.1f}%',
439
+ transform=ax.transAxes, fontsize=10, verticalalignment='top',
440
+ bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
441
+
442
+ plt.tight_layout()
443
+ plt.savefig(os.path.join(fig_dir, 'cai_before_after.png'), dpi=300, bbox_inches='tight')
444
+ plt.close()
445
+
446
+ print(f"CAI Before/After graph saved to {os.path.join(fig_dir, 'cai_before_after.png')}")
447
+
448
+ # 1b. Median CAI Before/After Graph
449
+ plt.figure(figsize=(8, 6))
450
+
451
+ median_before = before_cai.median()
452
+ median_after = after_cai.median()
453
+
454
+ categories = ['Before Optimization', 'After Optimization']
455
+ medians = [median_before, median_after]
456
+ colors = ['#ff7f0e', '#2ca02c']
457
+
458
+ bars = plt.bar(categories, medians, color=colors, alpha=0.8, width=0.6)
459
+ plt.ylabel('Median CAI Score', fontsize=12)
460
+ plt.title('ENCOT: Median CAI Before and After Optimization', fontsize=14, fontweight='bold')
461
+ plt.ylim(0, max(medians) * 1.2)
462
+
463
+ for bar, median in zip(bars, medians):
464
+ plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
465
+ f'{median:.3f}', ha='center', va='bottom', fontweight='bold')
466
+
467
+ improvement_pct = ((median_after - median_before) / median_before) * 100
468
+ plt.text(0.5, max(medians) * 0.95, f'Improvement: {improvement_pct:.1f}%',
469
+ ha='center', transform=plt.gca().transData, fontsize=12,
470
+ bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
471
+
472
+ plt.grid(axis='y', alpha=0.3)
473
+ plt.tight_layout()
474
+ plt.savefig(os.path.join(fig_dir, 'median_cai_comparison.png'), dpi=300, bbox_inches='tight')
475
+ plt.close()
476
+
477
+ print(f"Median CAI comparison graph saved to {os.path.join(fig_dir, 'median_cai_comparison.png')}")
478
+
479
+ # 2. Summary metrics table
480
+ metrics_summary = {}
481
+
482
+ if 'original_cai' in results_df.columns:
483
+ metrics_summary['CAI'] = {
484
+ 'Before': results_df['original_cai'].mean(),
485
+ 'After': results_df['optimized_cai'].mean(),
486
+ 'Improvement': ((results_df['optimized_cai'].mean() - results_df['original_cai'].mean()) / results_df['original_cai'].mean()) * 100
487
+ }
488
+ metrics_summary['GC Content (%)'] = {
489
+ 'Before': results_df['original_gc'].mean(),
490
+ 'After': results_df['optimized_gc'].mean(),
491
+ 'Difference': results_df['optimized_gc'].mean() - results_df['original_gc'].mean()
492
+ }
493
+ metrics_summary['tAI'] = {
494
+ 'Before': results_df['original_tai'].mean(),
495
+ 'After': results_df['optimized_tai'].mean(),
496
+ 'Improvement': ((results_df['optimized_tai'].mean() - results_df['original_tai'].mean()) / results_df['original_tai'].mean()) * 100
497
+ }
498
+ metrics_summary['CFD'] = {
499
+ 'Before': results_df['original_cfd'].mean(),
500
+ 'After': results_df['optimized_cfd'].mean(),
501
+ 'Improvement': ((results_df['optimized_cfd'].mean() - results_df['original_cfd'].mean()) / results_df['original_cfd'].mean()) * 100
502
+ }
503
+ metrics_summary['Negative Cis Elements'] = {
504
+ 'Before': results_df['original_neg_cis'].mean(),
505
+ 'After': results_df['optimized_neg_cis'].mean(),
506
+ 'Reduction': results_df['original_neg_cis'].mean() - results_df['optimized_neg_cis'].mean()
507
+ }
508
+ else:
509
+ metrics_summary['CAI'] = {
510
+ 'Optimized': results_df['optimized_cai'].mean(),
511
+ 'Std Dev': results_df['optimized_cai'].std()
512
+ }
513
+ metrics_summary['GC Content (%)'] = {
514
+ 'Optimized': results_df['optimized_gc'].mean(),
515
+ 'Std Dev': results_df['optimized_gc'].std()
516
+ }
517
+ metrics_summary['tAI'] = {
518
+ 'Optimized': results_df['optimized_tai'].mean(),
519
+ 'Std Dev': results_df['optimized_tai'].std()
520
+ }
521
+ metrics_summary['CFD'] = {
522
+ 'Optimized': results_df['optimized_cfd'].mean(),
523
+ 'Std Dev': results_df['optimized_cfd'].std()
524
+ }
525
+ metrics_summary['Negative Cis Elements'] = {
526
+ 'Optimized': results_df['optimized_neg_cis'].mean(),
527
+ 'Std Dev': results_df['optimized_neg_cis'].std()
528
+ }
529
+
530
+ metrics_summary['Runtime (seconds)'] = {
531
+ 'Mean': results_df['runtime'].mean(),
532
+ 'Median': results_df['runtime'].median(),
533
+ 'Total': results_df['runtime'].sum()
534
+ }
535
+
536
+ summary_df = pd.DataFrame(metrics_summary).T
537
+ summary_df = summary_df.round(4)
538
+
539
+ summary_df.to_csv(os.path.join(output_dir, 'metrics_summary.csv'))
540
+ print(f"\nMetrics Summary saved to {os.path.join(output_dir, 'metrics_summary.csv')}")
541
+ print("\n" + "="*60)
542
+ print("METRICS SUMMARY:")
543
+ print("="*60)
544
+ print(summary_df.to_string())
545
+
546
+ fig, axes = plt.subplots(2, 3, figsize=(15, 10))
547
+ axes = axes.flatten()
548
+
549
+ metrics_to_plot = [
550
+ ('optimized_cai', 'CAI Distribution'),
551
+ ('optimized_gc', 'GC Content Distribution (%)'),
552
+ ('optimized_tai', 'tAI Distribution'),
553
+ ('optimized_cfd', 'CFD Distribution'),
554
+ ('optimized_neg_cis', 'Negative Cis Elements'),
555
+ ('runtime', 'Runtime Distribution (seconds)')
556
+ ]
557
+
558
+ for idx, (col, title) in enumerate(metrics_to_plot):
559
+ if col in results_df.columns:
560
+ axes[idx].hist(results_df[col].dropna(), bins=20, edgecolor='black', alpha=0.7)
561
+ axes[idx].set_title(title, fontsize=10, fontweight='bold')
562
+ axes[idx].set_xlabel(col.replace('optimized_', '').replace('_', ' ').title())
563
+ axes[idx].set_ylabel('Frequency')
564
+ axes[idx].grid(axis='y', alpha=0.3)
565
+
566
+ mean_val = results_df[col].mean()
567
+ axes[idx].axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.3f}')
568
+ axes[idx].legend()
569
+
570
+ plt.suptitle('ENCOT: Optimization Metrics Distribution', fontsize=14, fontweight='bold', y=1.02)
571
+ plt.tight_layout()
572
+ plt.savefig(os.path.join(fig_dir, 'metrics_distribution.png'), dpi=300, bbox_inches='tight')
573
+ plt.close()
574
+
575
+ print(f"Metrics distribution plot saved to {os.path.join(fig_dir, 'metrics_distribution.png')}")
576
+
577
+ return summary_df
578
+
579
+
580
+ def main():
581
+ """CLI entrypoint to run the ENCOT benchmark workflow."""
582
+ parser = argparse.ArgumentParser(description="Benchmark ENCOT on E. coli sequences")
583
+ parser.add_argument("--excel_path", type=str, default="Benchmark 80 sequences.xlsx",
584
+ help="Path to benchmark Excel file")
585
+ parser.add_argument("--checkpoint_path", type=str, default="models/ecoli-codon-optimizer/finetune_best.ckpt",
586
+ help="Path to fine-tuned model checkpoint")
587
+ parser.add_argument("--natural_sequences_path", type=str, default="data/ecoli_processed_genes.csv",
588
+ help="Path to natural E. coli sequences for CAI calculation")
589
+ parser.add_argument("--output_dir", type=str, default="benchmark_results",
590
+ help="Directory to save results")
591
+ parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
592
+ parser.add_argument("--name_col", type=str, default=None, help="Optional: column name for sequence label (case-insensitive)")
593
+ parser.add_argument("--seq_col", type=str, default=None, help="Optional: column name for sequence (case-insensitive)")
594
+ parser.add_argument("--sheet_name", type=str, default=None, help="Optional: Excel sheet name or index")
595
+
596
+ args = parser.parse_args()
597
+
598
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
599
+ output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
600
+ os.makedirs(output_dir, exist_ok=True)
601
+
602
+ print("="*60)
603
+ print("ENCOT BENCHMARK EVALUATION")
604
+ print("="*60)
605
+
606
+ device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
607
+ print(f"Using device: {device}")
608
+
609
+ print(f"\nLoading sequences from {args.excel_path}...")
610
+ sequences = parse_excel_sequences(
611
+ args.excel_path,
612
+ name_col=args.name_col,
613
+ seq_col=args.seq_col,
614
+ sheet_name=args.sheet_name,
615
+ )
616
+ print(f"Loaded {len(sequences)} sequences")
617
+
618
+ print("\nLoading ENCOT model...")
619
+ model = load_model(model_path=args.checkpoint_path, device=device)
620
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
621
+ print("Model loaded successfully")
622
+
623
+ print("\nPreparing evaluation utilities...")
624
+
625
+ natural_df = pd.read_csv(args.natural_sequences_path)
626
+ ref_sequences = natural_df['dna_sequence'].tolist()
627
+ cai_weights = relative_adaptiveness(sequences=ref_sequences)
628
+ print("CAI weights generated")
629
+
630
+ tai_weights = get_ecoli_tai_weights()
631
+ print("tAI weights loaded")
632
+
633
+ try:
634
+ codon_frequencies = download_codon_frequencies_from_kazusa(taxonomy_id=83333)
635
+ print("Codon frequencies loaded from Kazusa")
636
+ except Exception as e:
637
+ print(f"Warning: Kazusa download failed ({e}). Using local frequencies.")
638
+ codon_frequencies = get_codon_frequencies(
639
+ ref_sequences, organism="Escherichia coli general"
640
+ )
641
+
642
+ reference_profile = []
643
+
644
+ print("\n" + "="*60)
645
+ print("RUNNING OPTIMIZATION...")
646
+ print("="*60)
647
+
648
+ results_df = run_model_on_sequences(
649
+ sequences=sequences,
650
+ model=model,
651
+ tokenizer=tokenizer,
652
+ device=device,
653
+ cai_weights=cai_weights,
654
+ tai_weights=tai_weights,
655
+ codon_frequencies=codon_frequencies,
656
+ reference_profile=reference_profile,
657
+ output_dir=output_dir
658
+ )
659
+
660
+ results_path = os.path.join(output_dir, 'optimization_results.csv')
661
+ results_df.to_csv(results_path, index=False)
662
+ print(f"\nRaw results saved to {results_path}")
663
+
664
+ optimized_sequences = results_df[['id', 'name', 'protein_sequence', 'optimized_dna']].copy()
665
+ optimized_sequences['protein_length'] = results_df['protein_length']
666
+ optimized_sequences['dna_length'] = optimized_sequences['optimized_dna'].apply(len)
667
+ optimized_sequences['optimized_cai'] = results_df['optimized_cai']
668
+ optimized_sequences['optimized_gc'] = results_df['optimized_gc']
669
+ optimized_sequences['optimized_tai'] = results_df['optimized_tai']
670
+
671
+ if 'original_cai' in results_df.columns:
672
+ optimized_sequences['original_cai'] = results_df['original_cai']
673
+ optimized_sequences['cai_improvement'] = ((results_df['optimized_cai'] - results_df['original_cai']) / results_df['original_cai'] * 100).round(2)
674
+
675
+ optimized_sequences_path = os.path.join(output_dir, 'optimized_dna_sequences.csv')
676
+ optimized_sequences.to_csv(optimized_sequences_path, index=False)
677
+ print(f"Optimized DNA sequences saved to {optimized_sequences_path}")
678
+
679
+ print("\n" + "="*60)
680
+ print("GENERATING VISUALIZATIONS...")
681
+ print("="*60)
682
+
683
+ summary_df = generate_visualizations(results_df, output_dir)
684
+
685
+ print("\n" + "="*60)
686
+ print("BENCHMARK EVALUATION COMPLETE")
687
+ print("="*60)
688
+ print(f"Results saved to: {output_dir}")
689
+ print(f"Total sequences processed: {len(results_df)}")
690
+ print(f"Average runtime per sequence: {results_df['runtime'].mean():.2f} seconds")
691
+ print(f"Total runtime: {results_df['runtime'].sum():.2f} seconds")
692
+
693
+
694
+ if __name__ == "__main__":
695
+ main()
comprehensive_model_comparison.png ADDED

Git LFS Details

  • SHA256: 7ccd04a955c52c6384c3bb94983d71ed4eca22fe0fac815aaaa147344cd024bc
  • Pointer size: 131 Bytes
  • Size of remote file: 630 kB
configs/train_ecoli_alm.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ENCOT ALM Training Configuration
2
+ # This configuration reproduces the main training setup from the paper
3
+ # using the Augmented-Lagrangian Method (ALM) for GC content control.
4
+
5
+ model:
6
+ base_model: "adibvafa/CodonTransformer-base"
7
+ tokenizer: "adibvafa/CodonTransformer"
8
+
9
+ data:
10
+ dataset_dir: "data"
11
+ # Expected files: finetune_set.json (created by preprocess_data.py)
12
+
13
+ training:
14
+ batch_size: 6
15
+ max_epochs: 15
16
+ learning_rate: 5e-5
17
+ warmup_fraction: 0.1
18
+ num_workers: 5
19
+ accumulate_grad_batches: 1
20
+ num_gpus: 4
21
+ save_every_n_steps: 512
22
+ seed: 123
23
+ log_every_n_steps: 20
24
+
25
+ checkpoint:
26
+ checkpoint_dir: "models/alm-enhanced-training"
27
+ checkpoint_filename: "balanced_alm_finetune.ckpt"
28
+
29
+ # Augmented-Lagrangian Method (ALM) for GC content control
30
+ alm:
31
+ enabled: true
32
+ gc_target: 0.52 # Target GC content for E. coli (52%)
33
+ curriculum_epochs: 3 # Warm-up epochs before enforcing GC constraint
34
+
35
+ # ALM penalty parameters
36
+ initial_penalty_factor: 20.0
37
+ penalty_update_factor: 10.0
38
+ max_penalty: 1e6
39
+ min_penalty: 1e-6
40
+
41
+ # ALM tolerance parameters
42
+ tolerance: 1e-5 # Primal tolerance
43
+ dual_tolerance: 1e-5 # Dual tolerance for constraint violation
44
+ tolerance_update_factor: 0.1
45
+
46
+ # Adaptive penalty adjustment
47
+ rel_penalty_increase_threshold: 0.1
48
+
49
+ # Legacy penalty method (if ALM disabled)
50
+ gc_penalty:
51
+ weight: 0.0 # Only used if use_lagrangian=false
52
+
53
+
54
+
configs/train_ecoli_quick.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ENCOT Quick Training Configuration
2
+ # This is a minimal configuration for quick sanity checks and testing.
3
+ # Use this to verify your setup before running full training.
4
+
5
+ model:
6
+ base_model: "adibvafa/CodonTransformer-base"
7
+ tokenizer: "adibvafa/CodonTransformer"
8
+
9
+ data:
10
+ dataset_dir: "data"
11
+
12
+ training:
13
+ batch_size: 2
14
+ max_epochs: 1
15
+ learning_rate: 5e-5
16
+ warmup_fraction: 0.1
17
+ num_workers: 0 # Disable multiprocessing for debugging
18
+ accumulate_grad_batches: 1
19
+ num_gpus: 0 # CPU-only for quick testing
20
+ save_every_n_steps: 10
21
+ seed: 123
22
+ log_every_n_steps: 5
23
+
24
+ checkpoint:
25
+ checkpoint_dir: "models/test-training"
26
+ checkpoint_filename: "quick_test.ckpt"
27
+
28
+ alm:
29
+ enabled: false # Disable ALM for quick test
30
+ gc_target: 0.52
31
+ curriculum_epochs: 0
32
+
33
+ gc_penalty:
34
+ weight: 0.0
35
+
36
+
37
+
create_model_datasets.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import os
4
+ from CodonTransformer.CodonData import prepare_training_data
5
+
6
+ def main():
7
+ """
8
+ Main function to partition the processed data into fine-tuning and test sets.
9
+ """
10
+ if not os.path.exists('data'):
11
+ print("Error: 'data' directory not found. Please run prepare_ecoli_data.py first.")
12
+ return
13
+
14
+ processed_data_path = 'data/ecoli_processed_genes.csv'
15
+ if not os.path.exists(processed_data_path):
16
+ print(f"Error: Processed data file not found at {processed_data_path}")
17
+ return
18
+
19
+ df_processed = pd.read_csv(processed_data_path)
20
+
21
+ df_finetune = df_processed[df_processed['is_high_cai'] == True].copy()
22
+ df_finetune.drop_duplicates(subset=['dna_sequence'], inplace=True)
23
+ df_finetune.rename(columns={'dna_sequence': 'dna', 'protein_sequence': 'protein'}, inplace=True)
24
+ df_finetune['organism'] = "Escherichia coli general"
25
+
26
+ finetune_output_path = 'data/finetune_set.json'
27
+ prepare_training_data(df_finetune, finetune_output_path, shuffle=True)
28
+ print(f"Fine-tuning set saved to {finetune_output_path} with {len(df_finetune)} records.")
29
+
30
+ df_test_pool = df_processed[df_processed['is_high_cai'] == False].copy()
31
+ df_test = df_test_pool.sample(n=100, random_state=42) # for reproducibility
32
+ df_test['organism'] = 51 # E. coli general
33
+ df_test.rename(columns={'dna_sequence': 'codons'}, inplace=True)
34
+ test_records = df_test[['codons', 'organism']].to_dict(orient='records')
35
+
36
+ test_output_path = 'data/test_set.json'
37
+ with open(test_output_path, 'w') as f:
38
+ json.dump(test_records, f, indent=4)
39
+ print(f"Test set saved to {test_output_path} with {len(df_test)} records.")
40
+
41
+ if __name__ == "__main__":
42
+ main()
evaluate_optimizer.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ """
3
+ File: evaluate_optimizer.py
4
+ ---------------------------
5
+ Evaluate ColiFormer with enhanced capabilities:
6
+ 1) DNAChisel post-processing for sequence polishing
7
+ 2) Optional multi-objective generation (Pareto-style filtering)
8
+ 3) Enhanced beam search with multiple candidates
9
+ 4) Comprehensive metrics and optional ablation studies
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import warnings
16
+ from typing import Dict, List, Tuple, Any
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+ import torch
21
+ from CAI import CAI, relative_adaptiveness
22
+ from tqdm import tqdm
23
+
24
+ from CodonTransformer.CodonData import (
25
+ download_codon_frequencies_from_kazusa,
26
+ get_codon_frequencies,
27
+ )
28
+ from CodonTransformer.CodonPrediction import (
29
+ load_model,
30
+ predict_dna_sequence,
31
+ get_high_frequency_choice_sequence_optimized,
32
+ )
33
+ from CodonTransformer.CodonEvaluation import (
34
+ calculate_dtw_distance,
35
+ calculate_homopolymer_runs,
36
+ calculate_tAI,
37
+ count_negative_cis_elements,
38
+ get_GC_content,
39
+ get_ecoli_tai_weights,
40
+ get_min_max_profile,
41
+ get_sequence_similarity,
42
+ scan_for_restriction_sites,
43
+ calculate_ENC,
44
+ calculate_CPB,
45
+ calculate_SCUO,
46
+ )
47
+ from CodonTransformer.CodonPostProcessing import (
48
+ polish_sequence_with_dnachisel,
49
+ )
50
+ from CodonTransformer.CodonUtils import DNASequencePrediction
51
+
52
+
53
+ def translate_dna_to_protein(dna_sequence: str) -> str:
54
+ """Translate DNA sequence to protein sequence."""
55
+ codon_table = {
56
+ 'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
57
+ 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
58
+ 'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
59
+ 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
60
+ 'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
61
+ 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
62
+ 'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
63
+ 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
64
+ 'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
65
+ 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
66
+ 'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
67
+ 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
68
+ 'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
69
+ 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
70
+ 'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
71
+ 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
72
+ }
73
+
74
+ protein = ""
75
+ for i in range(0, len(dna_sequence), 3):
76
+ codon = dna_sequence[i:i+3].upper()
77
+ if len(codon) == 3:
78
+ aa = codon_table.get(codon, 'X')
79
+ if aa == '*': # Stop codon
80
+ break
81
+ protein += aa
82
+
83
+ return protein
84
+
85
+
86
+ def evaluate_with_enhancements(
87
+ protein_sequence: str,
88
+ model,
89
+ tokenizer,
90
+ device,
91
+ cai_weights: Dict[str, float],
92
+ tai_weights: Dict[str, float],
93
+ codon_frequencies: Dict,
94
+ reference_profile: List[float],
95
+ args,
96
+ ) -> Dict[str, Any]:
97
+ """
98
+ Evaluate a protein sequence with enhanced generation techniques.
99
+
100
+ Args:
101
+ protein_sequence: Input protein sequence
102
+ model: Fine-tuned model
103
+ tokenizer: Model tokenizer
104
+ device: PyTorch device
105
+ cai_weights: CAI weights dictionary
106
+ tai_weights: tAI weights dictionary
107
+ codon_frequencies: Codon frequencies dictionary
108
+ reference_profile: Reference profile for DTW calculation
109
+ args: Command line arguments
110
+
111
+ Returns:
112
+ Dict containing evaluation results for all methods
113
+ """
114
+ results = {}
115
+
116
+ # 1. Original fine-tuned model (baseline)
117
+ try:
118
+ original_output = predict_dna_sequence(
119
+ protein=protein_sequence,
120
+ organism="Escherichia coli general",
121
+ device=device,
122
+ model=model,
123
+ deterministic=True,
124
+ match_protein=True,
125
+ use_constrained_search=args.use_constrained_search,
126
+ gc_bounds=tuple(args.gc_bounds),
127
+ beam_size=args.beam_size,
128
+ length_penalty=args.length_penalty,
129
+ diversity_penalty=args.diversity_penalty,
130
+ )
131
+
132
+ if isinstance(original_output, list):
133
+ original_dna = original_output[0].predicted_dna
134
+ else:
135
+ original_dna = original_output.predicted_dna
136
+
137
+ results['fine_tuned_original'] = {
138
+ 'dna_sequence': original_dna,
139
+ 'method': 'fine_tuned_original',
140
+ 'enhancement': 'none',
141
+ }
142
+
143
+ except Exception as e:
144
+ print(f"Warning: Original fine-tuned generation failed: {str(e)}")
145
+ results['fine_tuned_original'] = {
146
+ 'dna_sequence': '',
147
+ 'method': 'fine_tuned_original',
148
+ 'enhancement': 'none',
149
+ 'error': str(e),
150
+ }
151
+
152
+ # 2. Enhanced sequence generation (DNAChisel + Pareto filtering)
153
+ if args.use_enhanced_generation:
154
+ try:
155
+ enhanced_dna, generation_report = enhanced_sequence_generation(
156
+ protein_sequence=protein_sequence,
157
+ model=model,
158
+ tokenizer=tokenizer,
159
+ device=device,
160
+ beam_size=args.enhanced_beam_size,
161
+ gc_bounds=(args.gc_bounds[0] * 100, args.gc_bounds[1] * 100),
162
+ use_dnachisel_polish=args.use_dnachisel,
163
+ use_pareto_filtering=args.use_pareto_filtering,
164
+ cai_weights=cai_weights,
165
+ tai_weights=tai_weights,
166
+ codon_frequencies=codon_frequencies,
167
+ reference_profile=reference_profile,
168
+ )
169
+
170
+ results['fine_tuned_enhanced'] = {
171
+ 'dna_sequence': enhanced_dna,
172
+ 'method': 'fine_tuned_enhanced',
173
+ 'enhancement': 'dnachisel+pareto',
174
+ 'generation_report': generation_report,
175
+ }
176
+
177
+ except Exception as e:
178
+ print(f"Warning: Enhanced generation failed: {str(e)}")
179
+ results['fine_tuned_enhanced'] = {
180
+ 'dna_sequence': '',
181
+ 'method': 'fine_tuned_enhanced',
182
+ 'enhancement': 'dnachisel+pareto',
183
+ 'error': str(e),
184
+ }
185
+
186
+ # 3. DNAChisel post-processing only (ablation study)
187
+ if args.use_dnachisel and 'fine_tuned_original' in results and results['fine_tuned_original']['dna_sequence']:
188
+ try:
189
+ dnachisel_dna, polish_report = polish_sequence_with_dnachisel(
190
+ dna_sequence=results['fine_tuned_original']['dna_sequence'],
191
+ protein_sequence=protein_sequence,
192
+ gc_bounds=(args.gc_bounds[0] * 100, args.gc_bounds[1] * 100),
193
+ maximize_cai=True,
194
+ seed=42,
195
+ )
196
+
197
+ results['fine_tuned_dnachisel'] = {
198
+ 'dna_sequence': dnachisel_dna,
199
+ 'method': 'fine_tuned_dnachisel',
200
+ 'enhancement': 'dnachisel_only',
201
+ 'polish_report': polish_report,
202
+ }
203
+
204
+ except Exception as e:
205
+ print(f"Warning: DNAChisel post-processing failed: {str(e)}")
206
+ results['fine_tuned_dnachisel'] = {
207
+ 'dna_sequence': '',
208
+ 'method': 'fine_tuned_dnachisel',
209
+ 'enhancement': 'dnachisel_only',
210
+ 'error': str(e),
211
+ }
212
+
213
+ return results
214
+
215
+
216
+ def calculate_comprehensive_metrics(
217
+ dna_sequence: str,
218
+ protein_sequence: str,
219
+ cai_weights: Dict[str, float],
220
+ tai_weights: Dict[str, float],
221
+ codon_frequencies: Dict,
222
+ reference_profile: List[float],
223
+ ref_sequences: List[str],
224
+ ) -> Dict[str, float]:
225
+ """Calculate comprehensive metrics for a DNA sequence."""
226
+ if not dna_sequence:
227
+ return {
228
+ 'cai': 0.0,
229
+ 'tai': 0.0,
230
+ 'gc_content': 0.0,
231
+ 'restriction_sites': float('inf'),
232
+ 'neg_cis_elements': float('inf'),
233
+ 'homopolymer_runs': float('inf'),
234
+ 'dtw_distance': float('inf'),
235
+ 'enc': 0.0,
236
+ 'cpb': 0.0,
237
+ 'scuo': 0.0,
238
+ }
239
+
240
+ return calculate_sequence_metrics(
241
+ dna_sequence=dna_sequence,
242
+ protein_sequence=protein_sequence,
243
+ cai_weights=cai_weights,
244
+ tai_weights=tai_weights,
245
+ codon_frequencies=codon_frequencies,
246
+ reference_profile=reference_profile,
247
+ )
248
+
249
+
250
+ def run_ablation_study(results_df: pd.DataFrame) -> pd.DataFrame:
251
+ """
252
+ Run ablation study to compare different enhancement methods.
253
+
254
+ Args:
255
+ results_df: DataFrame with evaluation results
256
+
257
+ Returns:
258
+ DataFrame with ablation study results
259
+ """
260
+ # Group by protein and calculate improvements
261
+ ablation_results = []
262
+
263
+ for protein in results_df['protein_sequence'].unique():
264
+ protein_results = results_df[results_df['protein_sequence'] == protein]
265
+
266
+ # Get baseline (original fine-tuned)
267
+ baseline = protein_results[protein_results['method'] == 'fine_tuned_original']
268
+ if baseline.empty:
269
+ continue
270
+
271
+ baseline_metrics = baseline.iloc[0]
272
+
273
+ # Compare each enhancement method
274
+ for method in protein_results['method'].unique():
275
+ if method == 'fine_tuned_original':
276
+ continue
277
+
278
+ method_results = protein_results[protein_results['method'] == method]
279
+ if method_results.empty:
280
+ continue
281
+
282
+ method_metrics = method_results.iloc[0]
283
+
284
+ # Calculate improvements
285
+ improvements = {
286
+ 'protein': protein,
287
+ 'method': method,
288
+ 'enhancement': method_metrics['enhancement'],
289
+ 'cai_improvement': method_metrics['cai'] - baseline_metrics['cai'],
290
+ 'tai_improvement': method_metrics['tai'] - baseline_metrics['tai'],
291
+ 'gc_improvement': abs(method_metrics['gc_content'] - 52) - abs(baseline_metrics['gc_content'] - 52),
292
+ 'restriction_sites_improvement': baseline_metrics['restriction_sites'] - method_metrics['restriction_sites'],
293
+ 'neg_cis_improvement': baseline_metrics['neg_cis_elements'] - method_metrics['neg_cis_elements'],
294
+ 'homopolymer_improvement': baseline_metrics['homopolymer_runs'] - method_metrics['homopolymer_runs'],
295
+ 'dtw_improvement': baseline_metrics['dtw_distance'] - method_metrics['dtw_distance'],
296
+ 'composite_score_improvement': (
297
+ (method_metrics['cai'] - baseline_metrics['cai']) * 0.3 +
298
+ (method_metrics['tai'] - baseline_metrics['tai']) * 0.3 +
299
+ (abs(baseline_metrics['gc_content'] - 52) - abs(method_metrics['gc_content'] - 52)) * 0.2 +
300
+ (baseline_metrics['restriction_sites'] - method_metrics['restriction_sites']) * 0.1 +
301
+ (baseline_metrics['neg_cis_elements'] - method_metrics['neg_cis_elements']) * 0.1
302
+ ),
303
+ }
304
+
305
+ ablation_results.append(improvements)
306
+
307
+ return pd.DataFrame(ablation_results)
308
+
309
+
310
+ def main(args):
311
+ """Main function to run the enhanced evaluation."""
312
+ print("=== Enhanced CodonTransformer Evaluation ===")
313
+
314
+ # Setup device
315
+ device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
316
+ print(f"Using device: {device}")
317
+
318
+ # Load test data
319
+ with open(args.test_data_path, "r") as f:
320
+ first = f.read(1)
321
+ f.seek(0)
322
+ if first == "[":
323
+ test_set = json.load(f)
324
+ else:
325
+ test_set = [json.loads(line) for line in f if line.strip()]
326
+
327
+ # Limit test set size if requested
328
+ if args.max_test_proteins > 0:
329
+ test_set = test_set[:args.max_test_proteins]
330
+
331
+ print(f"Loaded {len(test_set)} proteins from the test set.")
332
+
333
+ # Load models
334
+ print("Loading models...")
335
+ finetuned_model = load_model(model_path=args.checkpoint_path, device=device)
336
+ print(f"Fine-tuned model loaded from {args.checkpoint_path}")
337
+
338
+ # Load tokenizer
339
+ from transformers import AutoTokenizer
340
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
341
+
342
+ # Load base model if comparison requested
343
+ base_model = None
344
+ if args.compare_with_base:
345
+ base_model = load_model(device=device)
346
+ print("Base model loaded from Hugging Face")
347
+
348
+ # Prepare evaluation utilities
349
+ print("Preparing evaluation utilities...")
350
+
351
+ # CAI weights
352
+ natural_csv = args.natural_sequences_path
353
+ natural_df = pd.read_csv(natural_csv)
354
+ ref_sequences = natural_df['dna_sequence'].tolist()
355
+ cai_weights = relative_adaptiveness(sequences=ref_sequences)
356
+ print("CAI weights generated")
357
+
358
+ # tAI weights
359
+ tai_weights = get_ecoli_tai_weights()
360
+ print("tAI weights loaded")
361
+
362
+ # Codon frequencies
363
+ try:
364
+ codon_frequencies = download_codon_frequencies_from_kazusa(taxonomy_id=83333)
365
+ print("Codon frequencies loaded from Kazusa")
366
+ except Exception as e:
367
+ print(f"Warning: Kazusa download failed ({e}). Using local frequencies.")
368
+ codon_frequencies = get_codon_frequencies(
369
+ ref_sequences, organism="Escherichia coli general"
370
+ )
371
+
372
+ # Reference profile for DTW
373
+ reference_profiles = [
374
+ get_min_max_profile(seq, codon_frequencies) for seq in ref_sequences[:100]
375
+ ]
376
+ valid_profiles = [p for p in reference_profiles if p and not all(v is None for v in p)]
377
+
378
+ if valid_profiles:
379
+ max_len = max(len(p) for p in valid_profiles)
380
+ padded_profiles = [
381
+ np.pad(
382
+ np.array([v for v in p if v is not None]),
383
+ (0, max_len - len([v for v in p if v is not None])),
384
+ "constant",
385
+ constant_values=np.nan,
386
+ )
387
+ for p in valid_profiles
388
+ ]
389
+ avg_reference_profile = np.nanmean(padded_profiles, axis=0).tolist()
390
+ else:
391
+ avg_reference_profile = []
392
+
393
+ print("Reference profile generated")
394
+
395
+ # Run evaluation
396
+ all_results = []
397
+ evaluation_reports = []
398
+
399
+ print("Starting enhanced evaluation...")
400
+ for i, item in enumerate(tqdm(test_set, desc="Evaluating proteins")):
401
+ # Get protein sequence
402
+ if "protein_sequence" in item:
403
+ protein_sequence = item["protein_sequence"]
404
+ else:
405
+ dna_sequence = item["codons"]
406
+ protein_sequence = translate_dna_to_protein(dna_sequence)
407
+
408
+ # Skip if protein is too short or too long
409
+ if len(protein_sequence) < 10 or len(protein_sequence) > 1000:
410
+ continue
411
+
412
+ # Evaluate with enhancements
413
+ protein_results = evaluate_with_enhancements(
414
+ protein_sequence=protein_sequence,
415
+ model=finetuned_model,
416
+ tokenizer=tokenizer,
417
+ device=device,
418
+ cai_weights=cai_weights,
419
+ tai_weights=tai_weights,
420
+ codon_frequencies=codon_frequencies,
421
+ reference_profile=avg_reference_profile,
422
+ args=args,
423
+ )
424
+
425
+ # Add base model comparison if requested
426
+ if base_model:
427
+ try:
428
+ base_output = predict_dna_sequence(
429
+ protein=protein_sequence,
430
+ organism="Escherichia coli general",
431
+ device=device,
432
+ model=base_model,
433
+ deterministic=True,
434
+ match_protein=True,
435
+ )
436
+ base_dna = base_output.predicted_dna if not isinstance(base_output, list) else base_output[0].predicted_dna
437
+
438
+ protein_results['base_model'] = {
439
+ 'dna_sequence': base_dna,
440
+ 'method': 'base_model',
441
+ 'enhancement': 'none',
442
+ }
443
+ except Exception as e:
444
+ print(f"Warning: Base model generation failed: {str(e)}")
445
+
446
+ # Add naive baseline
447
+ try:
448
+ naive_dna = get_high_frequency_choice_sequence_optimized(
449
+ protein=protein_sequence, codon_frequencies=codon_frequencies
450
+ )
451
+ protein_results['naive_hfc'] = {
452
+ 'dna_sequence': naive_dna,
453
+ 'method': 'naive_hfc',
454
+ 'enhancement': 'none',
455
+ }
456
+ except Exception as e:
457
+ print(f"Warning: Naive HFC generation failed: {str(e)}")
458
+
459
+ # Calculate metrics for each method
460
+ for method_name, method_result in protein_results.items():
461
+ if 'error' in method_result:
462
+ continue
463
+
464
+ dna_seq = method_result['dna_sequence']
465
+ if not dna_seq:
466
+ continue
467
+
468
+ metrics = calculate_comprehensive_metrics(
469
+ dna_sequence=dna_seq,
470
+ protein_sequence=protein_sequence,
471
+ cai_weights=cai_weights,
472
+ tai_weights=tai_weights,
473
+ codon_frequencies=codon_frequencies,
474
+ reference_profile=avg_reference_profile,
475
+ ref_sequences=ref_sequences,
476
+ )
477
+
478
+ # Combine results
479
+ result_row = {
480
+ 'protein_id': i,
481
+ 'protein_sequence': protein_sequence,
482
+ 'protein_length': len(protein_sequence),
483
+ 'method': method_name,
484
+ 'enhancement': method_result['enhancement'],
485
+ 'dna_sequence': dna_seq,
486
+ 'dna_length': len(dna_seq),
487
+ **metrics,
488
+ }
489
+
490
+ # Add generation reports if available
491
+ if 'generation_report' in method_result:
492
+ result_row['generation_report'] = str(method_result['generation_report'])
493
+ if 'polish_report' in method_result:
494
+ result_row['polish_report'] = str(method_result['polish_report'])
495
+
496
+ all_results.append(result_row)
497
+
498
+ # Create results DataFrame
499
+ results_df = pd.DataFrame(all_results)
500
+
501
+ # Save detailed results
502
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
503
+ results_df.to_csv(args.output_path, index=False)
504
+ print(f"Detailed results saved to {args.output_path}")
505
+
506
+ # Run ablation study
507
+ if args.run_ablation_study:
508
+ ablation_df = run_ablation_study(results_df)
509
+ ablation_path = args.output_path.replace('.csv', '_ablation.csv')
510
+ ablation_df.to_csv(ablation_path, index=False)
511
+ print(f"Ablation study results saved to {ablation_path}")
512
+
513
+ # Print summary statistics
514
+ print("\n=== ABLATION STUDY SUMMARY ===")
515
+ for method in ablation_df['method'].unique():
516
+ method_results = ablation_df[ablation_df['method'] == method]
517
+ print(f"\n{method.upper()}:")
518
+ print(f" CAI improvement: {method_results['cai_improvement'].mean():.4f} ± {method_results['cai_improvement'].std():.4f}")
519
+ print(f" tAI improvement: {method_results['tai_improvement'].mean():.4f} ± {method_results['tai_improvement'].std():.4f}")
520
+ print(f" GC improvement: {method_results['gc_improvement'].mean():.4f} ± {method_results['gc_improvement'].std():.4f}")
521
+ print(f" Restriction sites improvement: {method_results['restriction_sites_improvement'].mean():.2f} ± {method_results['restriction_sites_improvement'].std():.2f}")
522
+ print(f" Composite score improvement: {method_results['composite_score_improvement'].mean():.4f} ± {method_results['composite_score_improvement'].std():.4f}")
523
+
524
+ # Print final summary
525
+ print("\n=== EVALUATION COMPLETE ===")
526
+ print(f"Total proteins evaluated: {len(results_df['protein_id'].unique())}")
527
+ print(f"Total sequences generated: {len(results_df)}")
528
+ print(f"Results saved to: {args.output_path}")
529
+
530
+
531
+ if __name__ == "__main__":
532
+ parser = argparse.ArgumentParser(description="Enhanced CodonTransformer Evaluation")
533
+
534
+ # Input/Output paths
535
+ parser.add_argument("--checkpoint_path", type=str, default="models/ecoli-codon-optimizer/finetune_best.ckpt",
536
+ help="Path to fine-tuned model checkpoint")
537
+ parser.add_argument("--test_data_path", type=str, default="data/test_set.json",
538
+ help="Path to test dataset")
539
+ parser.add_argument("--natural_sequences_path", type=str, default="data/ecoli_processed_genes.csv",
540
+ help="Path to natural E. coli sequences for CAI calculation")
541
+ parser.add_argument("--output_path", type=str, default="results/enhanced_evaluation_results.csv",
542
+ help="Path to save evaluation results")
543
+
544
+ # Model parameters
545
+ parser.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
546
+ parser.add_argument("--compare_with_base", action="store_true", help="Compare with base model")
547
+
548
+ # Generation parameters
549
+ parser.add_argument("--use_constrained_search", action="store_true",
550
+ help="Use constrained beam search")
551
+ parser.add_argument("--gc_bounds", type=float, nargs=2, default=[0.50, 0.54],
552
+ help="GC content bounds (min max)")
553
+ parser.add_argument("--beam_size", type=int, default=10,
554
+ help="Beam size for standard generation")
555
+ parser.add_argument("--length_penalty", type=float, default=1.2,
556
+ help="Length penalty for beam search")
557
+ parser.add_argument("--diversity_penalty", type=float, default=0.1,
558
+ help="Diversity penalty for beam search")
559
+
560
+ # Enhancement parameters
561
+ parser.add_argument("--use_enhanced_generation", action="store_true",
562
+ help="Use enhanced generation with DNAChisel and Pareto filtering")
563
+ parser.add_argument("--enhanced_beam_size", type=int, default=20,
564
+ help="Beam size for enhanced generation")
565
+ parser.add_argument("--use_dnachisel", action="store_true",
566
+ help="Use DNAChisel post-processing")
567
+ parser.add_argument("--use_pareto_filtering", action="store_true",
568
+ help="Use Pareto frontier filtering")
569
+
570
+ # Evaluation parameters
571
+ parser.add_argument("--max_test_proteins", type=int, default=0,
572
+ help="Maximum number of proteins to test (0 for all)")
573
+ parser.add_argument("--run_ablation_study", action="store_true",
574
+ help="Run ablation study comparing methods")
575
+
576
+ args = parser.parse_args()
577
+ main(args)
prepare_ecoli_data.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from Bio.Seq import Seq
3
+ import os
4
+
5
+ def is_valid_sequence(dna_seq: str) -> bool:
6
+ """
7
+ Applies a series of validation checks to a DNA sequence.
8
+
9
+ Args:
10
+ dna_seq (str): The DNA sequence to validate.
11
+
12
+ Returns:
13
+ bool: True if the sequence is valid, False otherwise.
14
+ """
15
+ if len(dna_seq) % 3 != 0:
16
+ return False
17
+ if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
18
+ return False
19
+ if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
20
+ return False
21
+
22
+ codons = [dna_seq[i:i+3].upper() for i in range(0, len(dna_seq) - 3, 3)]
23
+ if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
24
+ return False
25
+
26
+ if not all(c in 'ATGC' for c in dna_seq.upper()):
27
+ return False
28
+
29
+ return True
30
+
31
+ def main():
32
+ """
33
+ Main function to process and validate E. coli gene data.
34
+ """
35
+ if not os.path.exists('data'):
36
+ os.makedirs('data')
37
+
38
+ print("Loading data from CSV files...")
39
+ df_all = pd.read_csv("data/CAI.csv", header=0, names=['gene_id', 'cai_score', 'drop1', 'drop2', 'dna_sequence', 'drop3'])
40
+ df_high_cai = pd.read_csv("data/Database 3_4300 gene.csv", header=0, names=['dna_sequence'])
41
+
42
+ high_cai_sequences = set(df_high_cai['dna_sequence'])
43
+
44
+ validated_genes = []
45
+ for index, row in df_all.iterrows():
46
+ gene_id = row['gene_id']
47
+ dna_sequence = str(row['dna_sequence'])
48
+
49
+ if is_valid_sequence(dna_sequence):
50
+ protein_sequence = str(Seq(dna_sequence).translate())
51
+ is_high_cai = dna_sequence in high_cai_sequences
52
+
53
+ validated_genes.append({
54
+ 'gene_id': gene_id,
55
+ 'dna_sequence': dna_sequence,
56
+ 'protein_sequence': protein_sequence,
57
+ 'cai_score': row.get('cai_score', None),
58
+ 'is_high_cai': is_high_cai
59
+ })
60
+
61
+ df_processed = pd.DataFrame(validated_genes)
62
+
63
+ output_path = 'data/ecoli_processed_genes.csv'
64
+ df_processed.to_csv(output_path, index=False)
65
+ print(f"Processed data saved to {output_path}")
66
+ print(f"Total validated genes: {len(df_processed)}")
67
+
68
+ if __name__ == "__main__":
69
+ main()
pretrain.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: pretrain.py
3
+ -----------------
4
+ Pretrain the base transformer model on JSON datasets prepared via
5
+ CodonData.prepare_training_data. This is typically not needed for ENCOT
6
+ as we use the pretrained CodonTransformer base. See README for setup and usage.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+
12
+ import pytorch_lightning as pl
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+ from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast
16
+
17
+ from CodonTransformer.CodonUtils import (
18
+ MAX_LEN,
19
+ NUM_ORGANISMS,
20
+ TOKEN2MASK,
21
+ IterableJSONData,
22
+ )
23
+
24
+
25
+ class MaskedTokenizerCollator:
26
+ def __init__(self, tokenizer):
27
+ self.tokenizer = tokenizer
28
+
29
+ def __call__(self, examples):
30
+ tokenized = self.tokenizer(
31
+ [ex["codons"] for ex in examples],
32
+ return_attention_mask=True,
33
+ return_token_type_ids=True,
34
+ truncation=True,
35
+ padding=True,
36
+ max_length=MAX_LEN,
37
+ return_tensors="pt",
38
+ )
39
+
40
+ seq_len = tokenized["input_ids"].shape[-1]
41
+ species_index = torch.tensor([[ex["organism"]] for ex in examples])
42
+ tokenized["token_type_ids"] = species_index.repeat(1, seq_len)
43
+
44
+ inputs = tokenized["input_ids"]
45
+ targets = inputs.clone()
46
+
47
+ prob_matrix = torch.full(inputs.shape, 0.15)
48
+ prob_matrix[inputs < 5] = 0.0
49
+ selected = torch.bernoulli(prob_matrix).bool()
50
+
51
+ replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected
52
+ inputs[replaced] = torch.tensor(
53
+ list((map(TOKEN2MASK.__getitem__, inputs[replaced].numpy())))
54
+ )
55
+
56
+ randomized = (
57
+ torch.bernoulli(torch.full(selected.shape, 0.1)).bool()
58
+ & selected
59
+ & ~replaced
60
+ )
61
+ random_idx = torch.randint(26, 90, inputs.shape, dtype=torch.long)
62
+ inputs[randomized] = random_idx[randomized]
63
+
64
+ tokenized["input_ids"] = inputs
65
+ tokenized["labels"] = torch.where(selected, targets, -100)
66
+
67
+ return tokenized
68
+
69
+
70
+ class plTrainHarness(pl.LightningModule):
71
+ def __init__(self, model, learning_rate, warmup_fraction):
72
+ super().__init__()
73
+ self.model = model
74
+ self.learning_rate = learning_rate
75
+ self.warmup_fraction = warmup_fraction
76
+
77
+ def configure_optimizers(self):
78
+ optimizer = torch.optim.AdamW(
79
+ self.model.parameters(),
80
+ lr=self.learning_rate,
81
+ )
82
+ lr_scheduler = {
83
+ "scheduler": torch.optim.lr_scheduler.OneCycleLR(
84
+ optimizer,
85
+ max_lr=self.learning_rate,
86
+ total_steps=self.trainer.estimated_stepping_batches,
87
+ pct_start=self.warmup_fraction,
88
+ ),
89
+ "interval": "step",
90
+ "frequency": 1,
91
+ }
92
+ return [optimizer], [lr_scheduler]
93
+
94
+ def training_step(self, batch, batch_idx):
95
+ self.model.bert.set_attention_type("block_sparse")
96
+ outputs = self.model(**batch)
97
+ self.log_dict(
98
+ dictionary={
99
+ "loss": outputs.loss,
100
+ "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
101
+ },
102
+ on_step=True,
103
+ prog_bar=True,
104
+ )
105
+ return outputs.loss
106
+
107
+
108
+ class EpochCheckpoint(pl.Callback):
109
+ def __init__(self, checkpoint_dir, save_interval):
110
+ super().__init__()
111
+ self.checkpoint_dir = checkpoint_dir
112
+ self.save_interval = save_interval
113
+
114
+ def on_train_epoch_end(self, trainer, pl_module):
115
+ current_epoch = trainer.current_epoch
116
+ if current_epoch % self.save_interval == 0 or current_epoch == 0:
117
+ checkpoint_path = os.path.join(
118
+ self.checkpoint_dir, f"epoch_{current_epoch}.ckpt"
119
+ )
120
+ trainer.save_checkpoint(checkpoint_path)
121
+ print(f"\nCheckpoint saved at {checkpoint_path}\n")
122
+
123
+
124
+ def main(args):
125
+ """Pretrain the base transformer model."""
126
+ pl.seed_everything(args.seed)
127
+ torch.set_float32_matmul_precision("medium")
128
+
129
+ tokenizer = PreTrainedTokenizerFast(
130
+ tokenizer_file=args.tokenizer_path,
131
+ bos_token="[CLS]",
132
+ eos_token="[SEP]",
133
+ unk_token="[UNK]",
134
+ sep_token="[SEP]",
135
+ pad_token="[PAD]",
136
+ cls_token="[CLS]",
137
+ mask_token="[MASK]",
138
+ )
139
+ config = BigBirdConfig(
140
+ vocab_size=len(tokenizer),
141
+ type_vocab_size=NUM_ORGANISMS,
142
+ sep_token_id=2,
143
+ )
144
+ model = BigBirdForMaskedLM(config=config)
145
+ harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction)
146
+
147
+ train_data = IterableJSONData(args.train_data_path, dist_env="slurm")
148
+ data_loader = DataLoader(
149
+ dataset=train_data,
150
+ collate_fn=MaskedTokenizerCollator(tokenizer),
151
+ batch_size=args.batch_size,
152
+ num_workers=0 if args.debug else args.num_workers,
153
+ persistent_workers=False if args.debug else True,
154
+ )
155
+
156
+ save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval)
157
+ trainer = pl.Trainer(
158
+ default_root_dir=args.checkpoint_dir,
159
+ strategy="ddp_find_unused_parameters_true",
160
+ accelerator="gpu",
161
+ devices=1 if args.debug else args.num_gpus,
162
+ precision="16-mixed",
163
+ max_epochs=args.max_epochs,
164
+ deterministic=False,
165
+ enable_checkpointing=True,
166
+ callbacks=[save_checkpoint],
167
+ accumulate_grad_batches=args.accumulate_grad_batches,
168
+ )
169
+
170
+ # Pretrain the model
171
+ trainer.fit(harnessed_model, data_loader)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ parser = argparse.ArgumentParser(description="Pretrain the base transformer model.")
176
+ parser.add_argument(
177
+ "--tokenizer_path",
178
+ type=str,
179
+ required=True,
180
+ help="Path to the tokenizer model file",
181
+ )
182
+ parser.add_argument(
183
+ "--train_data_path",
184
+ type=str,
185
+ required=True,
186
+ help="Path to the training data JSON file",
187
+ )
188
+ parser.add_argument(
189
+ "--checkpoint_dir",
190
+ type=str,
191
+ required=True,
192
+ help="Directory where checkpoints will be saved",
193
+ )
194
+ parser.add_argument(
195
+ "--batch_size", type=int, default=6, help="Batch size for training"
196
+ )
197
+ parser.add_argument(
198
+ "--max_epochs", type=int, default=5, help="Maximum number of epochs to train"
199
+ )
200
+ parser.add_argument(
201
+ "--num_workers", type=int, default=5, help="Number of workers for data loading"
202
+ )
203
+ parser.add_argument(
204
+ "--accumulate_grad_batches",
205
+ type=int,
206
+ default=1,
207
+ help="Number of batches to accumulate gradients",
208
+ )
209
+ parser.add_argument(
210
+ "--num_gpus", type=int, default=16, help="Number of GPUs to use for training"
211
+ )
212
+ parser.add_argument(
213
+ "--learning_rate",
214
+ type=float,
215
+ default=5e-5,
216
+ help="Learning rate for the optimizer",
217
+ )
218
+ parser.add_argument(
219
+ "--warmup_fraction",
220
+ type=float,
221
+ default=0.1,
222
+ help="Fraction of total steps to use for warmup",
223
+ )
224
+ parser.add_argument(
225
+ "--save_interval", type=int, default=5, help="Save checkpoint every N epochs"
226
+ )
227
+ parser.add_argument(
228
+ "--seed", type=int, default=123, help="Random seed for reproducibility"
229
+ )
230
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
231
+ args = parser.parse_args()
232
+ main(args)
pyproject.toml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "ENCOT"
3
+ version = "1.0.0"
4
+ description = "Transformer-based codon optimization for E. coli using deep learning with Augmented-Lagrangian GC control."
5
+ authors = ["Adibvafa Fallahpour <Adibvafa.fallahpour@mail.utoronto.ca>"]
6
+ license = "Apache-2.0"
7
+ readme = "README.md"
8
+ homepage = "https://github.com/geno543/ENCOT"
9
+ repository = "https://github.com/geno543/ENCOT"
10
+ classifiers = [
11
+ "Programming Language :: Python :: 3",
12
+ "License :: OSI Approved :: Apache Software License",
13
+ "Operating System :: OS Independent",
14
+ ]
15
+
16
+ [tool.poetry.dependencies]
17
+ python = "^3.9"
18
+ biopython = "^1.83"
19
+ ipywidgets = "^7.0.0"
20
+ numpy = "<2.0.0"
21
+ onnxruntime = "^1.16.3"
22
+ pandas = "^2.0.0"
23
+ python_codon_tables = "^0.1.12"
24
+ pytorch_lightning = "^2.2.1"
25
+ scikit-learn = "^1.2.2"
26
+ scipy = "^1.13.1"
27
+ setuptools = "^70.0.0"
28
+ torch = "^2.0.0"
29
+ tqdm = "^4.66.2"
30
+ transformers = "^4.40.0"
31
+ CAI-PyPI = "^2.0.1"
32
+ codon-bias = "^1.0.2"
33
+ gcua = "^0.1.2"
34
+ dtw-python = "^1.3.0"
35
+
36
+ [tool.poetry.dev-dependencies]
37
+ coverage = {version = "^7.0", extras = ["toml"]}
38
+
39
+ [build-system]
40
+ requires = ["poetry-core>=1.0.0"]
41
+ build-backend = "poetry.core.masonry.api"
42
+
43
+ [tool.ruff]
44
+ line-length = 88
45
+ indent-width = 4
46
+ target-version = "py310"
47
+
48
+ [tool.ruff.lint]
49
+ select = ["E", "F", "I"]
50
+ ignore = []
51
+
52
+ [tool.ruff.format]
53
+ quote-style = "double"
54
+ indent-style = "space"
55
+ skip-magic-trailing-comma = false
56
+ line-ending = "auto"
57
+
58
+ [tool.coverage.run]
59
+ omit = [
60
+ # omit pytorch-generated files in /tmp
61
+ "/tmp/*",
62
+ ]
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ biopython>=1.83,<2.0
2
+ CAI-PyPI>=2.0.1,<3.0
3
+ ipywidgets>=7.0.0,<10.0
4
+ numpy>=1.26.4,<2.0
5
+ onnxruntime>=1.16.3,<3.0
6
+ pandas>=2.0.0,<3.0
7
+ python_codon_tables>=0.1.12,<1.0
8
+ pytorch_lightning>=2.2.1,<3.0
9
+ scikit-learn>=1.2.2,<2.0
10
+ scipy>=1.13.1,<3.0
11
+ setuptools>=70.0.0
12
+ torch>=2.0.0,<3.0
13
+ tqdm>=4.66.2,<5.0
14
+ transformers>=4.40.0,<5.0
15
+ codon-bias>=0.3.5,<0.4
16
+ dtw-python>=1.3.0,<2.0
17
+
18
+ dnachisel>=1.0
19
+ paretoset>=1.2.0
20
+ softadapt>=0.1.2,<0.2
21
+ ema-pytorch>=0.4.3
22
+ torchmetrics>=1.4.0
23
+ pyyaml>=6.0
24
+
25
+ matplotlib>=3.8,<4.0
26
+ seaborn>=0.13,<0.14
27
+ openpyxl>=3.1,<4.0
28
+
29
+ huggingface-hub>=0.20,<1.0
scripts/optimize_sequence.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Optimize protein sequences using ColiFormer.
3
+
4
+ This script provides a user-friendly interface for codon optimization,
5
+ supporting both single sequences and batch processing via FASTA files.
6
+
7
+ Usage:
8
+ # Single sequence
9
+ python scripts/optimize_sequence.py --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" --output optimized.fasta
10
+
11
+ # Batch processing from FASTA file
12
+ python scripts/optimize_sequence.py --input sequences.fasta --output optimized.fasta --batch
13
+
14
+ # With GC content constraints
15
+ python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --gc-min 0.45 --gc-max 0.55
16
+ """
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Any, List, Tuple
23
+
24
+ # Add parent directory to path to import CodonTransformer
25
+ sys.path.insert(0, str(Path(__file__).parent.parent))
26
+
27
+
28
+ def parse_fasta(fasta_path: str) -> List[Tuple[str, str]]:
29
+ """
30
+ Parse FASTA file into list of (name, sequence) tuples.
31
+
32
+ Args:
33
+ fasta_path: Path to FASTA file
34
+
35
+ Returns:
36
+ List of (name, sequence) tuples
37
+ """
38
+ sequences = []
39
+ current_name = None
40
+ current_seq = []
41
+
42
+ with open(fasta_path, 'r') as f:
43
+ for line in f:
44
+ line = line.strip()
45
+ if line.startswith('>'):
46
+ if current_name is not None:
47
+ sequences.append((current_name, ''.join(current_seq)))
48
+ current_name = line[1:] if len(line) > 1 else f"sequence_{len(sequences)+1}"
49
+ current_seq = []
50
+ else:
51
+ current_seq.append(line.upper())
52
+
53
+ if current_name is not None:
54
+ sequences.append((current_name, ''.join(current_seq)))
55
+
56
+ return sequences
57
+
58
+
59
+ def write_fasta(output_path: str, sequences: List[Tuple[str, str]]):
60
+ """
61
+ Write sequences to FASTA file.
62
+
63
+ Args:
64
+ output_path: Output FASTA file path
65
+ sequences: List of (name, sequence) tuples
66
+ """
67
+ with open(output_path, 'w') as f:
68
+ for name, seq in sequences:
69
+ f.write(f">{name}\n")
70
+ # Write sequence in 60-character lines
71
+ for i in range(0, len(seq), 60):
72
+ f.write(seq[i:i+60] + "\n")
73
+
74
+
75
+ def optimize_single_sequence(
76
+ protein: str,
77
+ model: Any,
78
+ tokenizer: Any,
79
+ device: Any,
80
+ organism: str = "Escherichia coli general",
81
+ gc_min: float = None,
82
+ gc_max: float = None,
83
+ cai_weights: dict = None,
84
+ tai_weights: dict = None
85
+ ) -> dict:
86
+ """
87
+ Optimize a single protein sequence.
88
+
89
+ Args:
90
+ protein: Protein sequence string
91
+ model: Loaded ColiFormer model
92
+ tokenizer: Tokenizer
93
+ device: PyTorch device
94
+ organism: Target organism name
95
+ gc_min: Minimum GC content (0-1)
96
+ gc_max: Maximum GC content (0-1)
97
+ cai_weights: CAI weights dictionary
98
+ tai_weights: tAI weights dictionary
99
+
100
+ Returns:
101
+ Dictionary with optimization results
102
+ """
103
+ # Lazy imports so `python scripts/optimize_sequence.py --help` works without ML deps installed.
104
+ from CodonTransformer.CodonPrediction import predict_dna_sequence
105
+ from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI
106
+ from CAI import CAI
107
+
108
+ # Determine GC bounds if specified
109
+ gc_bounds = None
110
+ use_constrained = False
111
+ if gc_min is not None and gc_max is not None:
112
+ gc_bounds = (gc_min, gc_max)
113
+ use_constrained = True
114
+
115
+ # Run optimization
116
+ output = predict_dna_sequence(
117
+ protein=protein,
118
+ organism=organism,
119
+ device=device,
120
+ model=model,
121
+ tokenizer=tokenizer,
122
+ deterministic=True,
123
+ match_protein=True,
124
+ use_constrained_search=use_constrained,
125
+ gc_bounds=gc_bounds,
126
+ beam_size=20 if use_constrained else 5,
127
+ )
128
+
129
+ if isinstance(output, list):
130
+ output = output[0]
131
+
132
+ optimized_dna = output.predicted_dna
133
+
134
+ # Calculate metrics
135
+ gc_content = get_GC_content(optimized_dna) / 100.0 # Convert to fraction
136
+
137
+ metrics = {
138
+ 'protein': protein,
139
+ 'optimized_dna': optimized_dna,
140
+ 'gc_content': gc_content,
141
+ 'length': len(optimized_dna),
142
+ }
143
+
144
+ if cai_weights:
145
+ try:
146
+ metrics['cai'] = CAI(optimized_dna, weights=cai_weights)
147
+ except:
148
+ metrics['cai'] = None
149
+ else:
150
+ metrics['cai'] = None
151
+
152
+ if tai_weights:
153
+ try:
154
+ metrics['tai'] = calculate_tAI(optimized_dna, tai_weights)
155
+ except:
156
+ metrics['tai'] = None
157
+ else:
158
+ metrics['tai'] = None
159
+
160
+ return metrics
161
+
162
+
163
+ def load_reference_data(ref_sequences_path: str = None):
164
+ """
165
+ Load reference sequences and calculate CAI weights.
166
+
167
+ Args:
168
+ ref_sequences_path: Path to CSV with reference sequences
169
+
170
+ Returns:
171
+ Tuple of (cai_weights, tai_weights)
172
+ """
173
+ # Lazy imports so `--help` works without ML deps installed.
174
+ import pandas as pd
175
+ from CAI import relative_adaptiveness
176
+ from CodonTransformer.CodonEvaluation import get_ecoli_tai_weights
177
+
178
+ cai_weights = None
179
+ tai_weights = None
180
+
181
+ # Try to load reference sequences for CAI
182
+ if ref_sequences_path and os.path.exists(ref_sequences_path):
183
+ try:
184
+ df = pd.read_csv(ref_sequences_path)
185
+ if 'dna_sequence' in df.columns:
186
+ ref_sequences = df['dna_sequence'].tolist()
187
+ cai_weights = relative_adaptiveness(sequences=ref_sequences)
188
+ print(f"Loaded CAI weights from {len(ref_sequences)} reference sequences")
189
+ except Exception as e:
190
+ print(f"Warning: Could not load CAI weights: {e}")
191
+
192
+ # Load tAI weights
193
+ try:
194
+ tai_weights = get_ecoli_tai_weights()
195
+ print("Loaded E. coli tAI weights")
196
+ except Exception as e:
197
+ print(f"Warning: Could not load tAI weights: {e}")
198
+
199
+ return cai_weights, tai_weights
200
+
201
+
202
+ def main():
203
+ """Main entry point for sequence optimization."""
204
+ parser = argparse.ArgumentParser(
205
+ description="Optimize protein sequences using ENCOT",
206
+ formatter_class=argparse.RawDescriptionHelpFormatter,
207
+ epilog="""
208
+ Examples:
209
+ # Single sequence
210
+ python scripts/optimize_sequence.py --input "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" --output optimized.fasta
211
+
212
+ # Batch processing from FASTA file
213
+ python scripts/optimize_sequence.py --input sequences.fasta --output optimized.fasta --batch
214
+
215
+ # With GC content constraints
216
+ python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --gc-min 0.45 --gc-max 0.55
217
+
218
+ # Use custom checkpoint
219
+ python scripts/optimize_sequence.py --input protein.fasta --output optimized.fasta --checkpoint models/my_model.ckpt
220
+ """
221
+ )
222
+ parser.add_argument(
223
+ "--input",
224
+ type=str,
225
+ required=True,
226
+ help="Input protein sequence (string) or FASTA file path"
227
+ )
228
+ parser.add_argument(
229
+ "--output",
230
+ type=str,
231
+ required=True,
232
+ help="Output FASTA file path"
233
+ )
234
+ parser.add_argument(
235
+ "--checkpoint",
236
+ type=str,
237
+ default=None,
238
+ help="Path to model checkpoint (default: auto-download from Hugging Face)"
239
+ )
240
+ parser.add_argument(
241
+ "--organism",
242
+ type=str,
243
+ default="Escherichia coli general",
244
+ help="Target organism (default: Escherichia coli general)"
245
+ )
246
+ parser.add_argument(
247
+ "--gc-min",
248
+ type=float,
249
+ default=None,
250
+ help="Minimum GC content (0-1, e.g., 0.45 for 45%%)"
251
+ )
252
+ parser.add_argument(
253
+ "--gc-max",
254
+ type=float,
255
+ default=None,
256
+ help="Maximum GC content (0-1, e.g., 0.55 for 55%%)"
257
+ )
258
+ parser.add_argument(
259
+ "--batch",
260
+ action="store_true",
261
+ help="Process input as FASTA file with multiple sequences"
262
+ )
263
+ parser.add_argument(
264
+ "--ref-sequences",
265
+ type=str,
266
+ default="data/ecoli_processed_genes.csv",
267
+ help="Path to reference sequences CSV for CAI calculation"
268
+ )
269
+ parser.add_argument(
270
+ "--use-gpu",
271
+ action="store_true",
272
+ help="Use GPU if available"
273
+ )
274
+
275
+ args = parser.parse_args()
276
+
277
+ try:
278
+ # Lazy imports so `--help` works without ML deps installed.
279
+ import torch
280
+ from transformers import AutoTokenizer
281
+ from CodonTransformer.CodonPrediction import load_model
282
+ import pandas as pd
283
+
284
+ # Setup device
285
+ device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")
286
+ print(f"Using device: {device}")
287
+
288
+ # Load model
289
+ print("Loading ColiFormer model...")
290
+ if args.checkpoint:
291
+ model = load_model(model_path=args.checkpoint, device=device)
292
+ print(f"Loaded model from {args.checkpoint}")
293
+ else:
294
+ # Try to load from Hugging Face
295
+ try:
296
+ from huggingface_hub import hf_hub_download
297
+ checkpoint_path = hf_hub_download(
298
+ repo_id="saketh11/ColiFormer",
299
+ filename="balanced_alm_finetune.ckpt",
300
+ cache_dir="./hf_cache"
301
+ )
302
+ model = load_model(model_path=checkpoint_path, device=device)
303
+ print("Loaded model from Hugging Face (saketh11/ColiFormer)")
304
+ except Exception as e:
305
+ print(f"Warning: Could not load from Hugging Face: {e}")
306
+ print("Falling back to base CodonTransformer model...")
307
+ from transformers import BigBirdForMaskedLM
308
+ model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device)
309
+
310
+ # Load tokenizer
311
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
312
+
313
+ # Load reference data for metrics
314
+ cai_weights, tai_weights = load_reference_data(args.ref_sequences)
315
+
316
+ # Parse input
317
+ if args.batch or os.path.exists(args.input):
318
+ # FASTA file
319
+ print(f"Reading sequences from {args.input}...")
320
+ sequences = parse_fasta(args.input)
321
+ print(f"Found {len(sequences)} sequences")
322
+ else:
323
+ # Single sequence string
324
+ sequences = [("sequence_1", args.input.upper())]
325
+
326
+ # Optimize sequences
327
+ optimized_sequences = []
328
+ results = []
329
+
330
+ for i, (name, protein_seq) in enumerate(sequences, 1):
331
+ print(f"\nOptimizing sequence {i}/{len(sequences)}: {name}")
332
+
333
+ metrics = optimize_single_sequence(
334
+ protein=protein_seq,
335
+ model=model,
336
+ tokenizer=tokenizer,
337
+ device=device,
338
+ organism=args.organism,
339
+ gc_min=args.gc_min,
340
+ gc_max=args.gc_max,
341
+ cai_weights=cai_weights,
342
+ tai_weights=tai_weights
343
+ )
344
+
345
+ optimized_sequences.append((name, metrics['optimized_dna']))
346
+ results.append({
347
+ 'name': name,
348
+ 'protein_length': len(protein_seq),
349
+ 'dna_length': metrics['length'],
350
+ 'gc_content': f"{metrics['gc_content']*100:.2f}%",
351
+ 'cai': metrics['cai'],
352
+ 'tai': metrics['tai'],
353
+ })
354
+
355
+ print(f" GC content: {metrics['gc_content']*100:.2f}%")
356
+ if metrics['cai']:
357
+ print(f" CAI: {metrics['cai']:.3f}")
358
+ if metrics['tai']:
359
+ print(f" tAI: {metrics['tai']:.3f}")
360
+
361
+ # Write output
362
+ write_fasta(args.output, optimized_sequences)
363
+ print(f"\nOptimized sequences saved to {args.output}")
364
+
365
+ # Print summary
366
+ if len(results) > 1:
367
+ print("\n" + "="*60)
368
+ print("Summary Statistics")
369
+ print("="*60)
370
+ df = pd.DataFrame(results)
371
+ print(df.to_string(index=False))
372
+ print("="*60)
373
+
374
+ except Exception as e:
375
+ print(f"Error: {e}", file=sys.stderr)
376
+ import traceback
377
+ traceback.print_exc()
378
+ sys.exit(1)
379
+
380
+
381
+ if __name__ == "__main__":
382
+ main()
383
+
scripts/preprocess_data.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Preprocess E. coli gene data for ColiFormer training.
3
+
4
+ This script combines the functionality of prepare_ecoli_data.py and
5
+ create_model_datasets.py to prepare training and test datasets from raw CSV files.
6
+
7
+ Usage:
8
+ python scripts/preprocess_data.py
9
+ python scripts/preprocess_data.py --cai_csv data/CAI.csv --high_cai_csv data/Database_3_4300_gene.csv
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ # Add parent directory to path to import CodonTransformer
19
+ sys.path.insert(0, str(Path(__file__).parent.parent))
20
+
21
+
22
+ def is_valid_sequence(dna_seq: str) -> bool:
23
+ """
24
+ Validate a DNA sequence for training suitability.
25
+
26
+ Args:
27
+ dna_seq: DNA sequence string
28
+
29
+ Returns:
30
+ True if sequence is valid (divisible by 3, proper start/stop codons, no internal stops)
31
+ """
32
+ if len(dna_seq) % 3 != 0:
33
+ return False
34
+ if not dna_seq.upper().startswith(('ATG', 'TTG', 'CTG', 'GTG')):
35
+ return False
36
+ if not dna_seq.upper().endswith(('TAA', 'TAG', 'TGA')):
37
+ return False
38
+
39
+ codons = [dna_seq[i:i+3].upper() for i in range(0, len(dna_seq) - 3, 3)]
40
+ if any(codon in ['TAA', 'TAG', 'TGA'] for codon in codons):
41
+ return False
42
+
43
+ if not all(c in 'ATGC' for c in dna_seq.upper()):
44
+ return False
45
+
46
+ return True
47
+
48
+
49
+ def process_ecoli_data(cai_csv: str, high_cai_csv: str, output_dir: str = "data"):
50
+ """
51
+ Process raw E. coli gene data from CSV files.
52
+
53
+ Args:
54
+ cai_csv: Path to CAI.csv file with gene data
55
+ high_cai_csv: Path to Database 3_4300 gene.csv with high-CAI sequences
56
+ output_dir: Output directory for processed files
57
+
58
+ Returns:
59
+ Path to processed CSV file
60
+ """
61
+ # Lazy imports so `python scripts/preprocess_data.py --help` works without heavy deps installed.
62
+ import pandas as pd
63
+ from Bio.Seq import Seq
64
+
65
+ # Validate input files exist
66
+ if not os.path.exists(cai_csv):
67
+ raise FileNotFoundError(f"CAI CSV file not found: {cai_csv}")
68
+ if not os.path.exists(high_cai_csv):
69
+ raise FileNotFoundError(f"High-CAI CSV file not found: {high_cai_csv}")
70
+
71
+ # Create output directory if needed
72
+ os.makedirs(output_dir, exist_ok=True)
73
+
74
+ print("Loading data from CSV files...")
75
+ df_all = pd.read_csv(
76
+ cai_csv,
77
+ header=0,
78
+ names=['gene_id', 'cai_score', 'drop1', 'drop2', 'dna_sequence', 'drop3']
79
+ )
80
+ df_high_cai = pd.read_csv(
81
+ high_cai_csv,
82
+ header=0,
83
+ names=['dna_sequence']
84
+ )
85
+
86
+ high_cai_sequences = set(df_high_cai['dna_sequence'])
87
+
88
+ validated_genes = []
89
+ for index, row in df_all.iterrows():
90
+ gene_id = row['gene_id']
91
+ dna_sequence = str(row['dna_sequence'])
92
+
93
+ if is_valid_sequence(dna_sequence):
94
+ protein_sequence = str(Seq(dna_sequence).translate())
95
+ is_high_cai = dna_sequence in high_cai_sequences
96
+
97
+ validated_genes.append({
98
+ 'gene_id': gene_id,
99
+ 'dna_sequence': dna_sequence,
100
+ 'protein_sequence': protein_sequence,
101
+ 'cai_score': row.get('cai_score', None),
102
+ 'is_high_cai': is_high_cai
103
+ })
104
+
105
+ df_processed = pd.DataFrame(validated_genes)
106
+
107
+ output_path = os.path.join(output_dir, 'ecoli_processed_genes.csv')
108
+ df_processed.to_csv(output_path, index=False)
109
+ print(f"Processed data saved to {output_path}")
110
+ print(f"Total validated genes: {len(df_processed)}")
111
+
112
+ return output_path
113
+
114
+
115
+ def create_train_test_splits(processed_csv: str, output_dir: str = "data", test_size: int = 100):
116
+ """
117
+ Create training and test splits from processed data.
118
+
119
+ Args:
120
+ processed_csv: Path to processed ecoli_processed_genes.csv
121
+ output_dir: Output directory for JSON files
122
+ test_size: Number of sequences for test set
123
+
124
+ Returns:
125
+ Tuple of (finetune_json_path, test_json_path)
126
+ """
127
+ # Lazy imports so `--help` works without heavy deps installed.
128
+ import pandas as pd
129
+ from CodonTransformer.CodonData import prepare_training_data
130
+
131
+ if not os.path.exists(processed_csv):
132
+ raise FileNotFoundError(f"Processed data file not found: {processed_csv}")
133
+
134
+ os.makedirs(output_dir, exist_ok=True)
135
+
136
+ df_processed = pd.read_csv(processed_csv)
137
+
138
+ # Create fine-tuning set (high-CAI sequences)
139
+ df_finetune = df_processed[df_processed['is_high_cai'] == True].copy()
140
+ df_finetune.drop_duplicates(subset=['dna_sequence'], inplace=True)
141
+ df_finetune.rename(columns={'dna_sequence': 'dna', 'protein_sequence': 'protein'}, inplace=True)
142
+ df_finetune['organism'] = "Escherichia coli general"
143
+
144
+ finetune_output_path = os.path.join(output_dir, 'finetune_set.json')
145
+ prepare_training_data(df_finetune, finetune_output_path, shuffle=True)
146
+ print(f"Fine-tuning set saved to {finetune_output_path} with {len(df_finetune)} records.")
147
+
148
+ # Create test set (non-high-CAI sequences)
149
+ df_test_pool = df_processed[df_processed['is_high_cai'] == False].copy()
150
+ df_test = df_test_pool.sample(n=test_size, random_state=42) # for reproducibility
151
+ df_test['organism'] = 51 # E. coli general organism ID
152
+ df_test.rename(columns={'dna_sequence': 'codons'}, inplace=True)
153
+ test_records = df_test[['codons', 'organism']].to_dict(orient='records')
154
+
155
+ test_output_path = os.path.join(output_dir, 'test_set.json')
156
+ with open(test_output_path, 'w') as f:
157
+ json.dump(test_records, f, indent=4)
158
+ print(f"Test set saved to {test_output_path} with {len(df_test)} records.")
159
+
160
+ return finetune_output_path, test_output_path
161
+
162
+
163
+ def main():
164
+ """Main entry point for data preprocessing."""
165
+ parser = argparse.ArgumentParser(
166
+ description="Preprocess E. coli gene data for ENCOT training",
167
+ formatter_class=argparse.RawDescriptionHelpFormatter,
168
+ epilog="""
169
+ Examples:
170
+ # Use default paths
171
+ python scripts/preprocess_data.py
172
+
173
+ # Specify custom input files
174
+ python scripts/preprocess_data.py --cai_csv data/CAI.csv --high_cai_csv data/Database_3_4300_gene.csv
175
+
176
+ # Custom output directory and test size
177
+ python scripts/preprocess_data.py --output_dir my_data --test_size 200
178
+ """
179
+ )
180
+ parser.add_argument(
181
+ "--cai_csv",
182
+ type=str,
183
+ default="data/CAI.csv",
184
+ help="Path to CAI.csv file with gene data (default: data/CAI.csv)"
185
+ )
186
+ parser.add_argument(
187
+ "--high_cai_csv",
188
+ type=str,
189
+ default="data/Database 3_4300 gene.csv",
190
+ help="Path to Database 3_4300 gene.csv file (default: data/Database 3_4300 gene.csv)"
191
+ )
192
+ parser.add_argument(
193
+ "--output_dir",
194
+ type=str,
195
+ default="data",
196
+ help="Output directory for processed files (default: data)"
197
+ )
198
+ parser.add_argument(
199
+ "--test_size",
200
+ type=int,
201
+ default=100,
202
+ help="Number of sequences for test set (default: 100)"
203
+ )
204
+ parser.add_argument(
205
+ "--skip_processing",
206
+ action="store_true",
207
+ help="Skip data processing step (assume ecoli_processed_genes.csv exists)"
208
+ )
209
+
210
+ args = parser.parse_args()
211
+
212
+ try:
213
+ # Step 1: Process raw data
214
+ if not args.skip_processing:
215
+ processed_csv = process_ecoli_data(
216
+ args.cai_csv,
217
+ args.high_cai_csv,
218
+ args.output_dir
219
+ )
220
+ else:
221
+ processed_csv = os.path.join(args.output_dir, 'ecoli_processed_genes.csv')
222
+ if not os.path.exists(processed_csv):
223
+ raise FileNotFoundError(
224
+ f"Processed data not found at {processed_csv}. "
225
+ "Remove --skip_processing flag to process raw data first."
226
+ )
227
+ print(f"Using existing processed data: {processed_csv}")
228
+
229
+ # Step 2: Create train/test splits
230
+ finetune_path, test_path = create_train_test_splits(
231
+ processed_csv,
232
+ args.output_dir,
233
+ args.test_size
234
+ )
235
+
236
+ print("\n" + "="*60)
237
+ print("Data preprocessing complete!")
238
+ print("="*60)
239
+ print(f"Training set: {finetune_path}")
240
+ print(f"Test set: {test_path}")
241
+ print("\nYou can now run training with:")
242
+ print(f" python scripts/train.py --config configs/train_ecoli_alm.yaml")
243
+
244
+ except Exception as e:
245
+ print(f"Error: {e}", file=sys.stderr)
246
+ sys.exit(1)
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()
251
+
scripts/run_benchmarks.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Run benchmark evaluation for ColiFormer.
3
+
4
+ This script wraps benchmark_evaluation.py and evaluate_optimizer.py to provide
5
+ a unified interface for running comprehensive evaluations.
6
+
7
+ Usage:
8
+ python scripts/run_benchmarks.py --config configs/benchmark.yaml
9
+ python scripts/run_benchmarks.py --excel_path Benchmark_80_sequences.xlsx --checkpoint_path models/my_model.ckpt
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+ import sys
15
+ from pathlib import Path
16
+
17
+ # Add parent directory to path to import benchmark scripts
18
+ sys.path.insert(0, str(Path(__file__).parent.parent))
19
+
20
+
21
+ def load_config(config_path: str) -> dict:
22
+ """
23
+ Load configuration from YAML file.
24
+
25
+ Args:
26
+ config_path: Path to YAML config file
27
+
28
+ Returns:
29
+ Dictionary with configuration values
30
+ """
31
+ # Lazy import so `python scripts/run_benchmarks.py --help` works without dependencies installed.
32
+ import yaml
33
+
34
+ if not os.path.exists(config_path):
35
+ raise FileNotFoundError(f"Config file not found: {config_path}")
36
+
37
+ with open(config_path, 'r') as f:
38
+ config = yaml.safe_load(f)
39
+
40
+ return config
41
+
42
+
43
+ def config_to_args(config: dict) -> argparse.Namespace:
44
+ """
45
+ Convert config dictionary to argparse.Namespace compatible with benchmark_evaluation.py.
46
+
47
+ Args:
48
+ config: Configuration dictionary from YAML
49
+
50
+ Returns:
51
+ argparse.Namespace with all required arguments
52
+ """
53
+ model_config = config.get('model', {})
54
+ data_config = config.get('data', {})
55
+ output_config = config.get('output', {})
56
+ eval_config = config.get('evaluation', {})
57
+
58
+ args = argparse.Namespace()
59
+
60
+ # Model paths
61
+ args.checkpoint_path = model_config.get('checkpoint_path', 'models/alm-enhanced-training/balanced_alm_finetune.ckpt')
62
+
63
+ # Data paths
64
+ args.excel_path = data_config.get('excel_path', 'Benchmark 80 sequences.xlsx')
65
+ args.natural_sequences_path = data_config.get('natural_sequences_path', 'data/ecoli_processed_genes.csv')
66
+ args.name_col = data_config.get('name_col')
67
+ args.seq_col = data_config.get('seq_col')
68
+ args.sheet_name = data_config.get('sheet_name')
69
+
70
+ # Output paths
71
+ args.output_dir = output_config.get('output_dir', 'benchmark_results')
72
+
73
+ # Evaluation parameters
74
+ args.use_gpu = eval_config.get('use_gpu', True)
75
+ args.compare_with_base = eval_config.get('compare_with_base', False)
76
+ args.max_test_proteins = eval_config.get('max_test_proteins', 0)
77
+
78
+ return args
79
+
80
+
81
+ def validate_config(config: dict):
82
+ """
83
+ Validate configuration before running benchmarks.
84
+
85
+ Args:
86
+ config: Configuration dictionary
87
+
88
+ Raises:
89
+ ValueError: If configuration is invalid
90
+ """
91
+ data_config = config.get('data', {})
92
+ excel_path = data_config.get('excel_path', 'Benchmark 80 sequences.xlsx')
93
+
94
+ if not os.path.exists(excel_path):
95
+ raise ValueError(
96
+ f"Benchmark Excel file not found: {excel_path}\n"
97
+ "Please provide a valid path to your benchmark sequences file."
98
+ )
99
+
100
+ model_config = config.get('model', {})
101
+ checkpoint_path = model_config.get('checkpoint_path')
102
+
103
+ # Check if checkpoint exists locally, or will be downloaded from HF
104
+ if checkpoint_path and os.path.exists(checkpoint_path):
105
+ print(f"Using local checkpoint: {checkpoint_path}")
106
+ else:
107
+ print(f"Checkpoint not found locally: {checkpoint_path}")
108
+ print("Will attempt to download from Hugging Face (saketh11/ColiFormer) if needed")
109
+
110
+
111
+ def main():
112
+ """Main entry point for benchmark evaluation."""
113
+ parser = argparse.ArgumentParser(
114
+ description="Run benchmark evaluation for ENCOT",
115
+ formatter_class=argparse.RawDescriptionHelpFormatter,
116
+ epilog="""
117
+ Examples:
118
+ # Run with configuration file
119
+ python scripts/run_benchmarks.py --config configs/benchmark.yaml
120
+
121
+ # Run with command-line arguments
122
+ python scripts/run_benchmarks.py --excel_path Benchmark_80_sequences.xlsx --checkpoint_path models/my_model.ckpt
123
+
124
+ # Override config values
125
+ python scripts/run_benchmarks.py --config configs/benchmark.yaml --use_gpu --max_test_proteins 50
126
+ """
127
+ )
128
+ parser.add_argument(
129
+ "--config",
130
+ type=str,
131
+ default=None,
132
+ help="Path to YAML configuration file"
133
+ )
134
+ parser.add_argument(
135
+ "--excel_path",
136
+ type=str,
137
+ default=None,
138
+ help="Path to benchmark Excel file (overrides config)"
139
+ )
140
+ parser.add_argument(
141
+ "--checkpoint_path",
142
+ type=str,
143
+ default=None,
144
+ help="Path to model checkpoint (overrides config)"
145
+ )
146
+ parser.add_argument(
147
+ "--output_dir",
148
+ type=str,
149
+ default=None,
150
+ help="Output directory for results (overrides config)"
151
+ )
152
+ parser.add_argument(
153
+ "--use_gpu",
154
+ action="store_true",
155
+ help="Use GPU if available (overrides config)"
156
+ )
157
+ parser.add_argument(
158
+ "--max_test_proteins",
159
+ type=int,
160
+ default=None,
161
+ help="Maximum number of proteins to test (overrides config)"
162
+ )
163
+
164
+ args = parser.parse_args()
165
+
166
+ try:
167
+ # Lazy import so `--help` works even if plotting/ML deps are missing.
168
+ from benchmark_evaluation import main as benchmark_main
169
+
170
+ if args.config:
171
+ # Load configuration from file
172
+ print(f"Loading configuration from {args.config}...")
173
+ config = load_config(args.config)
174
+
175
+ # Override with command-line arguments if provided
176
+ if args.excel_path:
177
+ config.setdefault('data', {})['excel_path'] = args.excel_path
178
+ if args.checkpoint_path:
179
+ config.setdefault('model', {})['checkpoint_path'] = args.checkpoint_path
180
+ if args.output_dir:
181
+ config.setdefault('output', {})['output_dir'] = args.output_dir
182
+ if args.use_gpu:
183
+ config.setdefault('evaluation', {})['use_gpu'] = True
184
+ if args.max_test_proteins is not None:
185
+ config.setdefault('evaluation', {})['max_test_proteins'] = args.max_test_proteins
186
+
187
+ # Validate configuration
188
+ validate_config(config)
189
+
190
+ # Convert config to args namespace
191
+ benchmark_args = config_to_args(config)
192
+ else:
193
+ # Use command-line arguments directly
194
+ if not args.excel_path:
195
+ parser.error("Either --config or --excel_path must be provided")
196
+
197
+ benchmark_args = argparse.Namespace()
198
+ benchmark_args.excel_path = args.excel_path
199
+ benchmark_args.checkpoint_path = args.checkpoint_path or 'models/alm-enhanced-training/balanced_alm_finetune.ckpt'
200
+ benchmark_args.natural_sequences_path = 'data/ecoli_processed_genes.csv'
201
+ benchmark_args.output_dir = args.output_dir or 'benchmark_results'
202
+ benchmark_args.use_gpu = args.use_gpu
203
+ benchmark_args.max_test_proteins = args.max_test_proteins or 0
204
+ benchmark_args.name_col = None
205
+ benchmark_args.seq_col = None
206
+ benchmark_args.sheet_name = None
207
+
208
+ # Validate
209
+ if not os.path.exists(benchmark_args.excel_path):
210
+ raise ValueError(f"Benchmark Excel file not found: {benchmark_args.excel_path}")
211
+
212
+ # Print configuration summary
213
+ print("\n" + "="*60)
214
+ print("Benchmark Configuration Summary")
215
+ print("="*60)
216
+ print(f"Excel file: {benchmark_args.excel_path}")
217
+ print(f"Checkpoint: {benchmark_args.checkpoint_path}")
218
+ print(f"Output directory: {benchmark_args.output_dir}")
219
+ print(f"Use GPU: {benchmark_args.use_gpu}")
220
+ print(f"Max test proteins: {benchmark_args.max_test_proteins if benchmark_args.max_test_proteins > 0 else 'All'}")
221
+ print("="*60 + "\n")
222
+
223
+ # Run benchmark
224
+ benchmark_main(benchmark_args)
225
+
226
+ except Exception as e:
227
+ print(f"Error: {e}", file=sys.stderr)
228
+ import traceback
229
+ traceback.print_exc()
230
+ sys.exit(1)
231
+
232
+
233
+ if __name__ == "__main__":
234
+ main()
235
+
scripts/train.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training entry point for ColiFormer.
3
+
4
+ This script wraps finetune.py and loads configuration from YAML files.
5
+
6
+ Usage:
7
+ python scripts/train.py --config configs/train_ecoli_alm.yaml
8
+ python scripts/train.py --config configs/train_ecoli_quick.yaml
9
+ """
10
+
11
+ import argparse
12
+ import os
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ # Add parent directory to path to import finetune
17
+ sys.path.insert(0, str(Path(__file__).parent.parent))
18
+
19
+
20
+ def load_config(config_path: str) -> dict:
21
+ """
22
+ Load configuration from YAML file.
23
+
24
+ Args:
25
+ config_path: Path to YAML config file
26
+
27
+ Returns:
28
+ Dictionary with configuration values
29
+ """
30
+ # Lazy import so `python scripts/train.py --help` works without dependencies installed.
31
+ import yaml
32
+
33
+ if not os.path.exists(config_path):
34
+ raise FileNotFoundError(f"Config file not found: {config_path}")
35
+
36
+ with open(config_path, 'r') as f:
37
+ config = yaml.safe_load(f)
38
+
39
+ return config
40
+
41
+
42
+ def config_to_args(config: dict) -> argparse.Namespace:
43
+ """
44
+ Convert config dictionary to argparse.Namespace compatible with finetune.py.
45
+
46
+ Args:
47
+ config: Configuration dictionary from YAML
48
+
49
+ Returns:
50
+ argparse.Namespace with all required arguments
51
+ """
52
+ # Extract nested config values
53
+ data_config = config.get('data', {})
54
+ training_config = config.get('training', {})
55
+ checkpoint_config = config.get('checkpoint', {})
56
+ alm_config = config.get('alm', {})
57
+ gc_penalty_config = config.get('gc_penalty', {})
58
+
59
+ # Build args namespace
60
+ args = argparse.Namespace()
61
+
62
+ # Data paths
63
+ args.dataset_dir = data_config.get('dataset_dir', 'data')
64
+
65
+ # Checkpoint paths
66
+ args.checkpoint_dir = checkpoint_config.get('checkpoint_dir', 'models/checkpoints')
67
+ args.checkpoint_filename = checkpoint_config.get('checkpoint_filename', 'finetune.ckpt')
68
+
69
+ # Training parameters
70
+ args.batch_size = training_config.get('batch_size', 6)
71
+ args.max_epochs = training_config.get('max_epochs', 15)
72
+ args.num_workers = training_config.get('num_workers', 5)
73
+ args.accumulate_grad_batches = training_config.get('accumulate_grad_batches', 1)
74
+ args.num_gpus = training_config.get('num_gpus', 4)
75
+ args.learning_rate = training_config.get('learning_rate', 5e-5)
76
+ args.warmup_fraction = training_config.get('warmup_fraction', 0.1)
77
+ args.save_every_n_steps = training_config.get('save_every_n_steps', 512)
78
+ args.seed = training_config.get('seed', 123)
79
+ args.log_every_n_steps = training_config.get('log_every_n_steps', 20)
80
+ args.debug = training_config.get('debug', False)
81
+
82
+ # GC penalty (legacy)
83
+ args.gc_penalty_weight = gc_penalty_config.get('weight', 0.0)
84
+
85
+ # ALM parameters
86
+ args.use_lagrangian = alm_config.get('enabled', False)
87
+ args.gc_target = alm_config.get('gc_target', 0.52)
88
+ args.curriculum_epochs = alm_config.get('curriculum_epochs', 3)
89
+ args.lagrangian_rho = alm_config.get('initial_penalty_factor', 20.0) # Use initial_penalty_factor as rho
90
+ args.alm_tolerance = alm_config.get('tolerance', 1e-5)
91
+ args.alm_dual_tolerance = alm_config.get('dual_tolerance', 1e-5)
92
+ args.alm_penalty_update_factor = alm_config.get('penalty_update_factor', 10.0)
93
+ args.alm_initial_penalty_factor = alm_config.get('initial_penalty_factor', 20.0)
94
+ args.alm_tolerance_update_factor = alm_config.get('tolerance_update_factor', 0.1)
95
+ args.alm_rel_penalty_increase_threshold = alm_config.get('rel_penalty_increase_threshold', 0.1)
96
+ args.alm_max_penalty = alm_config.get('max_penalty', 1e6)
97
+ args.alm_min_penalty = alm_config.get('min_penalty', 1e-6)
98
+
99
+ return args
100
+
101
+
102
+ def validate_config(config: dict):
103
+ """
104
+ Validate configuration before training.
105
+
106
+ Args:
107
+ config: Configuration dictionary
108
+
109
+ Raises:
110
+ ValueError: If configuration is invalid
111
+ """
112
+ data_config = config.get('data', {})
113
+ dataset_dir = data_config.get('dataset_dir', 'data')
114
+
115
+ # Check dataset directory exists
116
+ if not os.path.exists(dataset_dir):
117
+ raise ValueError(f"Dataset directory not found: {dataset_dir}")
118
+
119
+ # Check for expected data files
120
+ finetune_set = os.path.join(dataset_dir, 'finetune_set.json')
121
+ if not os.path.exists(finetune_set):
122
+ raise ValueError(
123
+ f"Training data not found: {finetune_set}\n"
124
+ "Please run data preprocessing first:\n"
125
+ " python scripts/preprocess_data.py"
126
+ )
127
+
128
+ # Validate checkpoint directory can be created
129
+ checkpoint_config = config.get('checkpoint', {})
130
+ checkpoint_dir = checkpoint_config.get('checkpoint_dir', 'models/checkpoints')
131
+ os.makedirs(checkpoint_dir, exist_ok=True)
132
+
133
+
134
+ def main():
135
+ """Main entry point for training."""
136
+ parser = argparse.ArgumentParser(
137
+ description="Train ENCOT model with configuration file",
138
+ formatter_class=argparse.RawDescriptionHelpFormatter,
139
+ epilog="""
140
+ Examples:
141
+ # Train with main ALM configuration
142
+ python scripts/train.py --config configs/train_ecoli_alm.yaml
143
+
144
+ # Quick test training (CPU, 1 epoch)
145
+ python scripts/train.py --config configs/train_ecoli_quick.yaml
146
+
147
+ # Override config values from command line
148
+ python scripts/train.py --config configs/train_ecoli_alm.yaml --num_gpus 2 --batch_size 4
149
+ """
150
+ )
151
+ parser.add_argument(
152
+ "--config",
153
+ type=str,
154
+ required=True,
155
+ help="Path to YAML configuration file"
156
+ )
157
+ parser.add_argument(
158
+ "--num_gpus",
159
+ type=int,
160
+ default=None,
161
+ help="Override number of GPUs from config"
162
+ )
163
+ parser.add_argument(
164
+ "--batch_size",
165
+ type=int,
166
+ default=None,
167
+ help="Override batch size from config"
168
+ )
169
+ parser.add_argument(
170
+ "--max_epochs",
171
+ type=int,
172
+ default=None,
173
+ help="Override max epochs from config"
174
+ )
175
+
176
+ args = parser.parse_args()
177
+
178
+ try:
179
+ # Lazy import so `--help` works even if training deps are missing.
180
+ from finetune import main as finetune_main
181
+
182
+ # Load configuration
183
+ print(f"Loading configuration from {args.config}...")
184
+ config = load_config(args.config)
185
+
186
+ # Override with command-line arguments if provided
187
+ if args.num_gpus is not None:
188
+ config.setdefault('training', {})['num_gpus'] = args.num_gpus
189
+ if args.batch_size is not None:
190
+ config.setdefault('training', {})['batch_size'] = args.batch_size
191
+ if args.max_epochs is not None:
192
+ config.setdefault('training', {})['max_epochs'] = args.max_epochs
193
+
194
+ # Validate configuration
195
+ print("Validating configuration...")
196
+ validate_config(config)
197
+
198
+ # Convert config to args namespace
199
+ train_args = config_to_args(config)
200
+
201
+ # Print training summary
202
+ print("\n" + "="*60)
203
+ print("Training Configuration Summary")
204
+ print("="*60)
205
+ print(f"Dataset directory: {train_args.dataset_dir}")
206
+ print(f"Checkpoint directory: {train_args.checkpoint_dir}")
207
+ print(f"Checkpoint filename: {train_args.checkpoint_filename}")
208
+ print(f"Batch size: {train_args.batch_size}")
209
+ print(f"Max epochs: {train_args.max_epochs}")
210
+ print(f"Learning rate: {train_args.learning_rate}")
211
+ print(f"Number of GPUs: {train_args.num_gpus}")
212
+ print(f"ALM enabled: {train_args.use_lagrangian}")
213
+ if train_args.use_lagrangian:
214
+ print(f"GC target: {train_args.gc_target}")
215
+ print(f"Curriculum epochs: {train_args.curriculum_epochs}")
216
+ print("="*60 + "\n")
217
+
218
+ # Run training
219
+ finetune_main(train_args)
220
+
221
+ except Exception as e:
222
+ print(f"Error: {e}", file=sys.stderr)
223
+ sys.exit(1)
224
+
225
+
226
+ if __name__ == "__main__":
227
+ main()
228
+
setup.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+
6
+ def read_requirements():
7
+ with open("requirements.txt") as f:
8
+ return [line.strip() for line in f if line.strip() and not line.startswith("#")]
9
+
10
+
11
+ def read_readme():
12
+ here = os.path.abspath(os.path.dirname(__file__))
13
+ readme_path = os.path.join(here, "README.md")
14
+
15
+ with open(readme_path, "r", encoding="utf-8") as f:
16
+ return f.read()
17
+
18
+
19
+ setup(
20
+ name="ENCOT",
21
+ version="1.0.0",
22
+ packages=find_packages(),
23
+ install_requires=read_requirements(),
24
+ author="Adibvafa Fallahpour",
25
+ author_email="Adibvafa.fallahpour@mail.utoronto.ca",
26
+ description=(
27
+ "Transformer-based codon optimization for E. coli using "
28
+ "deep learning with Augmented-Lagrangian GC control. "
29
+ "Built on CodonTransformer for E. coli-specific optimization."
30
+ ),
31
+ long_description=read_readme(),
32
+ long_description_content_type="text/markdown",
33
+ url="https://github.com/geno543/ENCOT",
34
+ classifiers=[
35
+ "Programming Language :: Python :: 3",
36
+ "License :: OSI Approved :: Apache Software License",
37
+ "Operating System :: OS Independent",
38
+ ],
39
+ python_requires=">=3.9",
40
+ )
src/CodonTransformer_inference_template.xlsx ADDED
Binary file (17.4 kB). View file
 
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Model weights, tokenizer, and other resources."""
src/banner_final.png ADDED

Git LFS Details

  • SHA256: 6aa745d1f362190e7ae0b8940154446e68426bfb16ef6be9336fb6f98168a205
  • Pointer size: 131 Bytes
  • Size of remote file: 468 kB
src/organism2id.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44f7b73bbb3c6ea82bf864e886b57b219cbd5f14fe79a8aa47d2befab5d40ad0
3
+ size 4605
streamlit_app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Public Streamlit entrypoint for ENCOT.
2
+
3
+ This file is intentionally minimal so hosting platforms like Streamlit
4
+ Community Cloud can run the existing UI without changing project structure.
5
+ """
6
+
7
+ from pathlib import Path
8
+ import sys
9
+
10
+
11
+ ROOT = Path(__file__).resolve().parent
12
+ if str(ROOT) not in sys.path:
13
+ sys.path.insert(0, str(ROOT))
14
+
15
+ # Importing this module runs the Streamlit app defined there.
16
+ import streamlit_gui.app # noqa: F401,E402
streamlit_gui/app.py ADDED
@@ -0,0 +1,1456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File: app.py
3
+ -------------
4
+ Streamlit GUI for ENCOT. Provides sequence validation, optimization,
5
+ and visualization for E. coli-focused workflows with optional post-processing.
6
+ """
7
+
8
+ import streamlit as st
9
+ import torch
10
+ import pandas as pd
11
+ import numpy as np
12
+ import plotly.graph_objects as go
13
+ import plotly.express as px
14
+ from transformers import AutoTokenizer, BigBirdForMaskedLM
15
+ from huggingface_hub import hf_hub_download
16
+ from datasets import load_dataset
17
+ import time
18
+ import threading
19
+ from typing import Dict, Optional, Tuple
20
+ import warnings
21
+ warnings.filterwarnings("ignore")
22
+
23
+ import sys
24
+ import os
25
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
26
+
27
+ from CodonTransformer.CodonPrediction import (
28
+ predict_dna_sequence,
29
+ load_model
30
+ )
31
+ from CodonTransformer.CodonEvaluation import (
32
+ get_GC_content,
33
+ calculate_tAI,
34
+ get_ecoli_tai_weights,
35
+ scan_for_restriction_sites,
36
+ count_negative_cis_elements,
37
+ calculate_homopolymer_runs
38
+ )
39
+ from CAI import CAI, relative_adaptiveness
40
+ from CodonTransformer.CodonUtils import get_organism2id_dict
41
+ import json
42
+
43
+ try:
44
+ from CodonTransformer.CodonPostProcessing import (
45
+ polish_sequence_with_dnachisel,
46
+ DNACHISEL_AVAILABLE
47
+ )
48
+ POST_PROCESSING_AVAILABLE = True
49
+ except ImportError:
50
+ POST_PROCESSING_AVAILABLE = False
51
+ DNACHISEL_AVAILABLE = False
52
+
53
+ st.set_page_config(
54
+ page_title="ENCOT GUI",
55
+ layout="wide",
56
+ initial_sidebar_state="expanded"
57
+ )
58
+
59
+ if 'model' not in st.session_state:
60
+ st.session_state.model = None
61
+ if 'tokenizer' not in st.session_state:
62
+ st.session_state.tokenizer = None
63
+ if 'device' not in st.session_state:
64
+ st.session_state.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ if 'optimization_running' not in st.session_state:
66
+ st.session_state.optimization_running = False
67
+ if 'results' not in st.session_state:
68
+ st.session_state.results = None
69
+ if 'post_processed_results' not in st.session_state:
70
+ st.session_state.post_processed_results = None
71
+ if 'cai_weights' not in st.session_state:
72
+ st.session_state.cai_weights = None
73
+ if 'tai_weights' not in st.session_state:
74
+ st.session_state.tai_weights = None
75
+
76
+ def get_organism_tai_weights(organism: str) -> Dict[str, float]:
77
+ """Get organism-specific tAI weights from pre-calculated data"""
78
+ try:
79
+ # Load organism-specific tAI weights
80
+ weights_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'organism_tai_weights.json')
81
+ with open(weights_file, 'r') as f:
82
+ all_weights = json.load(f)
83
+
84
+ if organism in all_weights:
85
+ return all_weights[organism]
86
+ else:
87
+ # Fallback to E. coli if organism not found
88
+ st.warning(f"tAI weights for {organism} not found, using E. coli weights")
89
+ return all_weights.get("Escherichia coli general", get_ecoli_tai_weights())
90
+ except Exception as e:
91
+ st.error(f"Error loading organism-specific tAI weights: {e}")
92
+ return get_ecoli_tai_weights()
93
+
94
+ def load_model_and_tokenizer():
95
+ """Load the model and tokenizer with progress tracking"""
96
+ if st.session_state.model is None or st.session_state.tokenizer is None:
97
+ with st.spinner("Loading model... This may take a few minutes."):
98
+ progress_bar = st.progress(0)
99
+ status_text = st.empty()
100
+
101
+ status_text.text("Loading tokenizer...")
102
+ progress_bar.progress(25)
103
+ st.session_state.tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
104
+
105
+ status_text.text("Loading fine-tuned model from Hugging Face...")
106
+ progress_bar.progress(50)
107
+ # Try to download and load fine-tuned model from Hugging Face
108
+ try:
109
+ # Download the checkpoint file from Hugging Face
110
+ from huggingface_hub import hf_hub_download
111
+
112
+ status_text.text("Downloading model from saketh11/ColiFormer...")
113
+ model_path = hf_hub_download(
114
+ repo_id="saketh11/ColiFormer",
115
+ filename="balanced_alm_finetune.ckpt",
116
+ cache_dir="./hf_cache"
117
+ )
118
+
119
+ status_text.text("Loading downloaded model...")
120
+ st.session_state.model = load_model(
121
+ model_path=model_path,
122
+ device=st.session_state.device,
123
+ attention_type="original_full"
124
+ )
125
+ status_text.text("Fine-tuned model loaded from Hugging Face")
126
+ st.session_state.model_type = "fine_tuned_hf"
127
+ except Exception as e:
128
+ status_text.text(f"Failed to load from Hugging Face: {str(e)[:50]}...")
129
+ status_text.text("Loading base model as fallback...")
130
+ st.session_state.model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer")
131
+ st.session_state.model = st.session_state.model.to(st.session_state.device)
132
+ st.session_state.model_type = "base"
133
+
134
+ progress_bar.progress(100)
135
+ time.sleep(0.5)
136
+
137
+ status_text.empty()
138
+ progress_bar.empty()
139
+
140
+ @st.cache_data
141
+ def download_reference_data():
142
+ """Download and cache reference data from Hugging Face"""
143
+ try:
144
+ # Download the processed genes file from Hugging Face
145
+ file_path = hf_hub_download(
146
+ repo_id="saketh11/ColiFormer-Data",
147
+ filename="ecoli_processed_genes.csv",
148
+ repo_type="dataset"
149
+ )
150
+ df = pd.read_csv(file_path)
151
+ return df['dna_sequence'].tolist()
152
+ except Exception as e:
153
+ st.warning(f"Could not download reference data from Hugging Face: {e}")
154
+ # Fallback to minimal sequences
155
+ return [
156
+ "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC",
157
+ "ATGAAATTTATTTATTATTATAAATTTATTTATTATTATAAATTTATTTAT",
158
+ "ATGGGTCGTCGTCGTCGTGGTCGTCGTCGTCGTGGTCGTCGTCGTCGTGGT"
159
+ ]
160
+
161
+ @st.cache_data
162
+ def download_tai_weights():
163
+ """Download and cache tAI weights from Hugging Face"""
164
+ try:
165
+ # Download the tAI weights file from Hugging Face
166
+ file_path = hf_hub_download(
167
+ repo_id="saketh11/ColiFormer-Data",
168
+ filename="organism_tai_weights.json",
169
+ repo_type="dataset"
170
+ )
171
+ with open(file_path, 'r') as f:
172
+ all_weights = json.load(f)
173
+ return all_weights.get("Escherichia coli general", get_ecoli_tai_weights())
174
+ except Exception as e:
175
+ st.warning(f"Could not download tAI weights from Hugging Face: {e}")
176
+ return get_ecoli_tai_weights()
177
+
178
+ def load_reference_data(organism: str = "Escherichia coli general"):
179
+ """Load reference sequences and tAI weights for E. coli"""
180
+ if 'cai_weights' not in st.session_state or st.session_state['cai_weights'] is None:
181
+ try:
182
+ # Download reference sequences from Hugging Face
183
+ with st.spinner("Downloading E. coli reference sequences from Hugging Face..."):
184
+ ref_sequences = download_reference_data()
185
+ st.session_state['cai_weights'] = relative_adaptiveness(sequences=ref_sequences)
186
+ if len(ref_sequences) > 100: # If we got the full dataset
187
+ st.success(f"Downloaded {len(ref_sequences):,} E. coli reference sequences for CAI calculation")
188
+ else:
189
+ st.info(f"Using {len(ref_sequences)} minimal reference sequences (full dataset unavailable)")
190
+ except Exception as e:
191
+ st.error(f"Error loading E. coli reference data: {e}")
192
+ st.session_state['cai_weights'] = {}
193
+ # tAI weights (E. coli only)
194
+ if 'tai_weights' not in st.session_state or st.session_state['tai_weights'] is None:
195
+ try:
196
+ with st.spinner("Downloading E. coli tAI weights from Hugging Face..."):
197
+ st.session_state['tai_weights'] = download_tai_weights()
198
+ st.success("Downloaded E. coli tAI weights")
199
+ except Exception as e:
200
+ st.error(f"Error loading E. coli tAI weights: {e}")
201
+ st.session_state['tai_weights'] = {}
202
+
203
+ def validate_sequence(sequence: str) -> Tuple[bool, str, str, str]:
204
+ """Validate sequence and return status, message, sequence type, and possibly fixed sequence"""
205
+ if not sequence:
206
+ return False, "Sequence cannot be empty", "unknown", sequence
207
+
208
+ # Remove whitespace and convert to uppercase
209
+ sequence = sequence.strip().upper()
210
+
211
+ # Check if it's a DNA sequence
212
+ dna_chars = set("ATGC")
213
+ protein_chars = set("ACDEFGHIKLMNPQRSTVWY*_")
214
+
215
+ sequence_chars = set(sequence)
216
+
217
+ # If all characters are DNA nucleotides, treat as DNA
218
+ if sequence_chars.issubset(dna_chars):
219
+ if len(sequence) < 3:
220
+ return False, "DNA sequence must be at least 3 nucleotides long", "dna", sequence
221
+
222
+ # Auto-fix DNA sequences not divisible by 3
223
+ if len(sequence) % 3 != 0:
224
+ remainder = len(sequence) % 3
225
+ fixed_sequence = sequence[:-remainder]
226
+ message = f"Valid DNA sequence (auto-fixed: removed {remainder} nucleotides from end to make divisible by 3)"
227
+ else:
228
+ fixed_sequence = sequence
229
+ message = "Valid DNA sequence"
230
+
231
+ return True, message, "dna", fixed_sequence
232
+
233
+ # If contains protein-specific amino acids, treat as protein
234
+ elif sequence_chars.issubset(protein_chars):
235
+ if len(sequence) < 3:
236
+ return False, "Protein sequence must be at least 3 amino acids long", "protein", sequence
237
+ return True, "Valid protein sequence", "protein", sequence
238
+
239
+ # Invalid characters
240
+ else:
241
+ invalid_chars = sequence_chars - (dna_chars | protein_chars)
242
+ return False, f"Invalid characters found: {', '.join(invalid_chars)}", "unknown", sequence
243
+
244
+ def calculate_input_metrics(sequence: str, organism: str, sequence_type: str) -> Dict:
245
+ """Calculate metrics for the input sequence using E. coli reference only"""
246
+ # Load reference data (E. coli only)
247
+ load_reference_data()
248
+ if sequence_type == "dna":
249
+ dna_sequence = sequence.upper()
250
+ metrics = {
251
+ 'length': len(dna_sequence) // 3,
252
+ 'gc_content': get_GC_content(dna_sequence),
253
+ 'baseline_dna': dna_sequence,
254
+ 'sequence_type': 'dna'
255
+ }
256
+ try:
257
+ if 'cai_weights' in st.session_state and st.session_state['cai_weights']:
258
+ metrics['cai'] = CAI(dna_sequence, weights=st.session_state['cai_weights'])
259
+ else:
260
+ metrics['cai'] = None
261
+ except:
262
+ metrics['cai'] = None
263
+ try:
264
+ if 'tai_weights' in st.session_state and st.session_state['tai_weights']:
265
+ metrics['tai'] = calculate_tAI(dna_sequence, st.session_state['tai_weights'])
266
+ else:
267
+ metrics['tai'] = None
268
+ except:
269
+ metrics['tai'] = None
270
+ else:
271
+ most_frequent_codons = {
272
+ 'A': 'GCG', 'C': 'TGC', 'D': 'GAT', 'E': 'GAA', 'F': 'TTT',
273
+ 'G': 'GGC', 'H': 'CAT', 'I': 'ATT', 'K': 'AAA', 'L': 'CTG',
274
+ 'M': 'ATG', 'N': 'AAC', 'P': 'CCG', 'Q': 'CAG', 'R': 'CGC',
275
+ 'S': 'TCG', 'T': 'ACG', 'V': 'GTG', 'W': 'TGG', 'Y': 'TAT',
276
+ '*': 'TAA', '_': 'TAA'
277
+ }
278
+ baseline_dna = ''.join([most_frequent_codons.get(aa, 'NNN') for aa in sequence])
279
+ metrics = {
280
+ 'length': len(sequence),
281
+ 'gc_content': get_GC_content(baseline_dna),
282
+ 'baseline_dna': baseline_dna,
283
+ 'sequence_type': 'protein'
284
+ }
285
+ try:
286
+ if 'cai_weights' in st.session_state and st.session_state['cai_weights']:
287
+ metrics['cai'] = CAI(baseline_dna, weights=st.session_state['cai_weights'])
288
+ else:
289
+ metrics['cai'] = None
290
+ except:
291
+ metrics['cai'] = None
292
+ try:
293
+ if 'tai_weights' in st.session_state and st.session_state['tai_weights']:
294
+ metrics['tai'] = calculate_tAI(baseline_dna, st.session_state['tai_weights'])
295
+ else:
296
+ metrics['tai'] = None
297
+ except:
298
+ metrics['tai'] = None
299
+ try:
300
+ analysis_dna = metrics['baseline_dna']
301
+ metrics['restriction_sites'] = len(scan_for_restriction_sites(analysis_dna))
302
+ metrics['negative_cis_elements'] = count_negative_cis_elements(analysis_dna)
303
+ metrics['homopolymer_runs'] = calculate_homopolymer_runs(analysis_dna)
304
+ except:
305
+ metrics['restriction_sites'] = 0
306
+ metrics['negative_cis_elements'] = 0
307
+ metrics['homopolymer_runs'] = 0
308
+ return metrics
309
+
310
+ def translate_dna_to_protein(dna_sequence: str) -> str:
311
+ """Translate DNA sequence to protein sequence"""
312
+ codon_table = {
313
+ 'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
314
+ 'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
315
+ 'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
316
+ 'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
317
+ 'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
318
+ 'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
319
+ 'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
320
+ 'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
321
+ 'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
322
+ 'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
323
+ 'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
324
+ 'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
325
+ 'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
326
+ 'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
327
+ 'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
328
+ 'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
329
+ }
330
+
331
+ protein = ""
332
+ for i in range(0, len(dna_sequence), 3):
333
+ codon = dna_sequence[i:i+3].upper()
334
+ if len(codon) == 3:
335
+ aa = codon_table.get(codon, 'X')
336
+ if aa == '*': # Stop codon
337
+ break
338
+ protein += aa
339
+
340
+ return protein
341
+
342
+ def create_gc_content_plot(sequence: str, window_size: int = 50) -> go.Figure:
343
+ """Create a sliding window GC content plot"""
344
+ if len(sequence) < window_size:
345
+ window_size = len(sequence) // 3
346
+
347
+ positions = []
348
+ gc_values = []
349
+
350
+ for i in range(0, len(sequence) - window_size + 1, 3): # Step by codons
351
+ window = sequence[i:i + window_size]
352
+ gc_content = get_GC_content(window)
353
+ positions.append(i // 3) # Position in codons
354
+ gc_values.append(gc_content)
355
+
356
+ fig = go.Figure()
357
+ fig.add_trace(go.Scatter(
358
+ x=positions,
359
+ y=gc_values,
360
+ mode='lines',
361
+ name='GC Content',
362
+ line=dict(color='blue', width=2)
363
+ ))
364
+
365
+ # Add target range
366
+ fig.add_hline(y=45, line_dash="dash", line_color="red",
367
+ annotation_text="Min Target (45%)")
368
+ fig.add_hline(y=55, line_dash="dash", line_color="red",
369
+ annotation_text="Max Target (55%)")
370
+
371
+ fig.update_layout(
372
+ title=f'GC Content (sliding window: {window_size} bp)',
373
+ xaxis_title='Position (codons)',
374
+ yaxis_title='GC Content (%)',
375
+ height=300
376
+ )
377
+
378
+ return fig
379
+
380
+ def create_gc_comparison_chart(before_metrics: Dict, after_metrics: Dict) -> go.Figure:
381
+ """Create a comparison chart for GC Content"""
382
+ fig = go.Figure()
383
+ fig.add_trace(go.Bar(
384
+ name='Before Optimization',
385
+ x=['GC Content (%)'],
386
+ y=[before_metrics.get('gc_content', 0)],
387
+ marker_color='lightblue',
388
+ text=[f"{before_metrics.get('gc_content', 0):.1f}%"],
389
+ textposition='auto'
390
+ ))
391
+ fig.add_trace(go.Bar(
392
+ name='After Optimization',
393
+ x=['GC Content (%)'],
394
+ y=[after_metrics.get('gc_content', 0)],
395
+ marker_color='darkblue',
396
+ text=[f"{after_metrics.get('gc_content', 0):.1f}%"],
397
+ textposition='auto'
398
+ ))
399
+ fig.update_layout(
400
+ title='GC Content Comparison: Before vs After',
401
+ xaxis_title='Metric',
402
+ yaxis_title='Value (%)',
403
+ barmode='group',
404
+ height=300
405
+ )
406
+ return fig
407
+
408
+ def create_expression_comparison_chart(before_metrics: Dict, after_metrics: Dict) -> go.Figure:
409
+ """Create a comparison chart for expression metrics (CAI, tAI)"""
410
+ metrics_names = ['CAI', 'tAI']
411
+ before_values = [
412
+ before_metrics.get('cai', 0) if before_metrics.get('cai') else 0,
413
+ before_metrics.get('tai', 0) if before_metrics.get('tai') else 0
414
+ ]
415
+ after_values = [
416
+ after_metrics.get('cai', 0) if after_metrics.get('cai') else 0,
417
+ after_metrics.get('tai', 0) if after_metrics.get('tai') else 0
418
+ ]
419
+
420
+ fig = go.Figure()
421
+ fig.add_trace(go.Bar(
422
+ name='Before Optimization',
423
+ x=metrics_names,
424
+ y=before_values,
425
+ marker_color='lightblue',
426
+ text=[f"{v:.3f}" for v in before_values],
427
+ textposition='auto'
428
+ ))
429
+ fig.add_trace(go.Bar(
430
+ name='After Optimization',
431
+ x=metrics_names,
432
+ y=after_values,
433
+ marker_color='darkblue',
434
+ text=[f"{v:.3f}" for v in after_values],
435
+ textposition='auto'
436
+ ))
437
+ fig.update_layout(
438
+ title='Expression Metrics Comparison: Before vs After',
439
+ xaxis_title='Metric',
440
+ yaxis_title='Value',
441
+ barmode='group',
442
+ height=300
443
+ )
444
+ return fig
445
+
446
+ def smart_codon_replacement(dna_sequence: str, target_gc_min: float = 0.45, target_gc_max: float = 0.55, max_iterations: int = 100) -> str:
447
+ """Smart codon replacement to optimize GC content while maximizing CAI"""
448
+
449
+ # Codon alternatives with their GC content
450
+ codon_alternatives = {
451
+ # Serine: high GC options
452
+ 'TCT': ['TCG', 'TCC', 'TCA', 'AGT', 'AGC'], # 33% -> 67%, 67%, 33%, 33%, 67%
453
+ 'TCA': ['TCG', 'TCC', 'TCT', 'AGT', 'AGC'],
454
+ 'AGT': ['TCG', 'TCC', 'TCT', 'TCA', 'AGC'],
455
+
456
+ # Leucine: various GC options
457
+ 'TTA': ['TTG', 'CTT', 'CTC', 'CTA', 'CTG'], # 0% -> 33%, 33%, 67%, 33%, 67%
458
+ 'TTG': ['TTA', 'CTT', 'CTC', 'CTA', 'CTG'],
459
+ 'CTT': ['CTG', 'CTC', 'TTA', 'TTG', 'CTA'],
460
+ 'CTA': ['CTG', 'CTC', 'CTT', 'TTA', 'TTG'],
461
+
462
+ # Arginine: various GC options
463
+ 'AGA': ['CGT', 'CGC', 'CGA', 'CGG', 'AGG'], # 33% -> 67%, 100%, 67%, 100%, 67%
464
+ 'AGG': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA'],
465
+ 'CGT': ['CGC', 'CGG', 'CGA', 'AGA', 'AGG'],
466
+ 'CGA': ['CGC', 'CGG', 'CGT', 'AGA', 'AGG'],
467
+
468
+ # Proline
469
+ 'CCT': ['CCG', 'CCC', 'CCA'], # 67% -> 100%, 100%, 67%
470
+ 'CCA': ['CCG', 'CCC', 'CCT'],
471
+
472
+ # Threonine
473
+ 'ACT': ['ACG', 'ACC', 'ACA'], # 33% -> 67%, 67%, 33%
474
+ 'ACA': ['ACG', 'ACC', 'ACT'],
475
+
476
+ # Alanine
477
+ 'GCT': ['GCG', 'GCC', 'GCA'], # 67% -> 100%, 100%, 67%
478
+ 'GCA': ['GCG', 'GCC', 'GCT'],
479
+
480
+ # Glycine
481
+ 'GGT': ['GGG', 'GGC', 'GGA'], # 67% -> 100%, 100%, 67%
482
+ 'GGA': ['GGG', 'GGC', 'GGT'],
483
+
484
+ # Valine
485
+ 'GTT': ['GTG', 'GTC', 'GTA'], # 67% -> 100%, 100%, 67%
486
+ 'GTA': ['GTG', 'GTC', 'GTT'],
487
+ }
488
+
489
+ def get_codon_gc(codon):
490
+ return (codon.count('G') + codon.count('C')) / 3.0
491
+
492
+ current_sequence = dna_sequence.upper()
493
+ current_gc = get_GC_content(current_sequence)
494
+
495
+ if target_gc_min <= current_gc <= target_gc_max:
496
+ return current_sequence
497
+
498
+ codons = [current_sequence[i:i+3] for i in range(0, len(current_sequence), 3)]
499
+
500
+ for iteration in range(max_iterations):
501
+ current_gc = get_GC_content(''.join(codons))
502
+
503
+ if target_gc_min <= current_gc <= target_gc_max:
504
+ break
505
+
506
+ # Find best codon to replace
507
+ best_improvement = 0
508
+ best_pos = -1
509
+ best_replacement = None
510
+
511
+ for pos, codon in enumerate(codons):
512
+ if codon in codon_alternatives:
513
+ for alt_codon in codon_alternatives[codon]:
514
+ # Calculate GC change
515
+ old_gc_contrib = get_codon_gc(codon)
516
+ new_gc_contrib = get_codon_gc(alt_codon)
517
+ gc_change = new_gc_contrib - old_gc_contrib
518
+
519
+ # Check if this change moves us toward target
520
+ if current_gc < target_gc_min and gc_change > best_improvement:
521
+ best_improvement = gc_change
522
+ best_pos = pos
523
+ best_replacement = alt_codon
524
+ elif current_gc > target_gc_max and gc_change < best_improvement:
525
+ best_improvement = abs(gc_change)
526
+ best_pos = pos
527
+ best_replacement = alt_codon
528
+
529
+ if best_pos >= 0:
530
+ if isinstance(best_replacement, str):
531
+ codons[best_pos] = best_replacement
532
+ else:
533
+ break # No more improvements possible
534
+
535
+ return ''.join(codons)
536
+
537
+ def run_optimization(protein: str, organism: str, use_post_processing: bool = False):
538
+ """Run the optimization using the exact method from run_full_comparison.py with auto GC correction"""
539
+ st.session_state.optimization_running = True
540
+ st.session_state.post_processed_results = None
541
+
542
+ try:
543
+ # Use the exact same method that achieved best results in evaluation
544
+ result = predict_dna_sequence(
545
+ protein=protein,
546
+ organism=organism,
547
+ device=st.session_state.device,
548
+ model=st.session_state.model,
549
+ deterministic=True,
550
+ match_protein=True,
551
+ )
552
+
553
+ # Check GC content and auto-correct if out of optimal range
554
+ _res = result[0] if isinstance(result, list) else result
555
+ initial_gc = get_GC_content(_res.predicted_dna)
556
+
557
+ if initial_gc < 45.0 or initial_gc > 55.0:
558
+ # Auto-correct GC content silently
559
+ optimized_dna = smart_codon_replacement(_res.predicted_dna, 0.45, 0.55)
560
+ smart_gc = get_GC_content(optimized_dna)
561
+
562
+ if 45.0 <= smart_gc <= 55.0:
563
+ from CodonTransformer.CodonUtils import DNASequencePrediction
564
+ result = DNASequencePrediction(
565
+ organism=_res.organism,
566
+ protein=_res.protein,
567
+ processed_input=_res.processed_input,
568
+ predicted_dna=optimized_dna
569
+ )
570
+ else:
571
+ # Fall back to constrained beam search silently
572
+ try:
573
+ result = predict_dna_sequence(
574
+ protein=protein,
575
+ organism=organism,
576
+ device=st.session_state.device,
577
+ model=st.session_state.model,
578
+ deterministic=True,
579
+ match_protein=True,
580
+ use_constrained_search=True,
581
+ gc_bounds=(0.45, 0.55),
582
+ beam_size=20
583
+ )
584
+ _res2 = result[0] if isinstance(result, list) else result
585
+ final_gc = get_GC_content(_res2.predicted_dna)
586
+ except Exception as e:
587
+ # If constrained search fails, use smart replacement result anyway
588
+ from CodonTransformer.CodonUtils import DNASequencePrediction
589
+ result = DNASequencePrediction(
590
+ organism=_res.organism,
591
+ protein=_res.protein,
592
+ processed_input=_res.processed_input,
593
+ predicted_dna=optimized_dna
594
+ )
595
+
596
+ st.session_state.results = result
597
+
598
+ # Post-processing if enabled
599
+ if use_post_processing and POST_PROCESSING_AVAILABLE and result:
600
+ try:
601
+ _res = result[0] if isinstance(result, list) else result
602
+ polished_sequence = polish_sequence_with_dnachisel(
603
+ dna_sequence=_res.predicted_dna,
604
+ protein_sequence=protein,
605
+ gc_bounds=(45.0, 55.0),
606
+ cai_species=organism.lower().replace(' ', '_'),
607
+ avoid_homopolymers_length=6
608
+ )
609
+
610
+ # Create enhanced result object
611
+ from CodonTransformer.CodonUtils import DNASequencePrediction
612
+ st.session_state.post_processed_results = DNASequencePrediction(
613
+ organism=result.organism,
614
+ protein=result.protein,
615
+ processed_input=result.processed_input,
616
+ predicted_dna=polished_sequence
617
+ )
618
+ except Exception as e:
619
+ st.session_state.post_processed_results = f"Post-processing error: {str(e)}"
620
+
621
+ except Exception as e:
622
+ st.session_state.results = f"Error: {str(e)}"
623
+
624
+ finally:
625
+ st.session_state.optimization_running = False
626
+
627
+ def main():
628
+ st.title("ENCOT")
629
+ st.markdown("E. coli codon optimization with constraint-aware decoding and in silico evaluation metrics.")
630
+
631
+ # Remove the performance highlights expander (details/summary block)
632
+ # (No expander here anymore)
633
+
634
+ # Load model
635
+ load_model_and_tokenizer()
636
+
637
+ # Create the main tabbed interface
638
+ tab1, tab2, tab3, tab4 = st.tabs(["Single Optimize", "Batch Process", "Comparative Analysis", "Advanced Settings"])
639
+
640
+ with tab1:
641
+ single_sequence_optimization()
642
+
643
+ with tab2:
644
+ batch_processing_interface()
645
+
646
+ with tab3:
647
+ comparative_analysis_interface()
648
+
649
+ with tab4:
650
+ advanced_settings_interface()
651
+
652
+ def single_sequence_optimization():
653
+ """Single sequence optimization interface - enhanced from original functionality"""
654
+ # Sidebar configuration
655
+ st.sidebar.header("Configuration")
656
+ organism_options = [
657
+ "Escherichia coli general",
658
+ "Saccharomyces cerevisiae",
659
+ "Homo sapiens",
660
+ "Bacillus subtilis",
661
+ "Pichia pastoris"
662
+ ]
663
+ organism = st.sidebar.selectbox("Select Target Organism", organism_options)
664
+ load_reference_data(organism)
665
+ with st.sidebar.expander("Advanced Optimization Settings"):
666
+ st.markdown("**Model Parameters**")
667
+ use_deterministic = st.checkbox("Deterministic Mode", value=True, help="Use deterministic decoding for reproducible results")
668
+ match_protein = st.checkbox("Match Protein Validation", value=True, help="Ensure DNA translates back to exact protein")
669
+ st.markdown("**GC Content Control**")
670
+ gc_target_min = st.slider("GC Target Min (%)", 30, 70, 45, help="Minimum GC content target")
671
+ gc_target_max = st.slider("GC Target Max (%)", 30, 70, 55, help="Maximum GC content target")
672
+ st.markdown("**Quality Constraints**")
673
+ avoid_restriction_sites = st.multiselect(
674
+ "Avoid Restriction Sites",
675
+ ["EcoRI", "BamHI", "HindIII", "XhoI", "NotI"],
676
+ default=["EcoRI", "BamHI"]
677
+ )
678
+ st.sidebar.subheader("Post-Processing")
679
+ use_post_processing = st.sidebar.checkbox(
680
+ "Enable DNAChisel Post-Processing",
681
+ value=False,
682
+ disabled=not POST_PROCESSING_AVAILABLE,
683
+ help="Polish sequences to remove restriction sites, homopolymers, and synthesis issues"
684
+ )
685
+ if not POST_PROCESSING_AVAILABLE:
686
+ st.sidebar.warning("DNAChisel not available. Install with: pip install dnachisel")
687
+
688
+ # Dataset Information
689
+ st.sidebar.markdown("---")
690
+ st.sidebar.markdown("### Dataset Information")
691
+ st.sidebar.markdown("""
692
+ - **Dataset**: [ColiFormer-Data](https://huggingface.co/datasets/saketh11/ColiFormer-Data)
693
+ - **Training**: 3,676 high-expression E. coli genes (NCBI-curated)
694
+ - **Evaluation**: 37,053 native E. coli genes + 80 recombinant protein targets
695
+ - **Auto-download**: CAI weights & tAI coefficients
696
+ """)
697
+
698
+ # Model Information
699
+ st.sidebar.markdown("### Model Information")
700
+ st.sidebar.markdown("""
701
+ - **Model**: [ColiFormer](https://huggingface.co/saketh11/ColiFormer)
702
+ - **Base**: CodonTransformer BigBird architecture
703
+ - **Architecture**: BigBird Transformer + ALM
704
+ - **Auto-download**: From Hugging Face Hub
705
+ """)
706
+ col1, col2 = st.columns([1, 1])
707
+ with col1:
708
+ st.header("Input Sequence")
709
+ sequence_input = st.text_area(
710
+ "Enter Protein or DNA Sequence",
711
+ height=150,
712
+ placeholder="Enter protein sequence (MKWVT...) or DNA sequence (ATGGCG...)\n\nExample protein: MKWVTFISLLFLFSSAYSRGVFRRDAHKSEVAHRFKDLGEENFKALVLIAFAQYLQQCPFEDHVKLVNEVTEFAKTCVADESAENCDKSLHTLFGDKLCTVATLRETYGEMADCCAKQEPERNECFLQHKDDNPNLPRLVRPEVDVMCTAFHDNEETFLKKYLYEIARRHPYFYAPELLFFAKRYKAAFTECCQAADKAACLLPKLDELRDEGKASSAKQRLKCASLQKFGERAFKAWAVARLSQRFPKAEFAEVSKLVTDLTKVHTECCHGDLLECADDRADLAKYICENQDSISSKLKECCEKPLLEKSHCIAEVENDEMPADLPSLAADFVESKDVCKNYAEAKDVFLGMFLYEYARRHPDYSVVLLLRLAKTYETTLEKCCAAADPHECYAKVFDEFKPLVEEPQNLIKQNCELFEQLGEYKFQNALLVRYTKKVPQVSTPTLVEVSRNLGKVGSKCCKHPEAKRMPCAEDYLSVVLNQLCVLHEKTPVSDRVTKCCTE"
713
+ )
714
+ analyze_btn = st.button("Analyze Sequence", type="primary")
715
+ if sequence_input and analyze_btn:
716
+ is_valid, message, sequence_type, fixed_sequence = validate_sequence(sequence_input)
717
+ if is_valid:
718
+ st.success(message)
719
+ # Store in session state for use by Optimize Sequence
720
+ st.session_state.sequence_clean = fixed_sequence
721
+ st.session_state.sequence_type = sequence_type
722
+ st.session_state.input_metrics = calculate_input_metrics(fixed_sequence, organism, sequence_type)
723
+ st.session_state.organism = organism
724
+ else:
725
+ st.error(message)
726
+ if "Invalid characters" in message:
727
+ st.info("Suggestion: Remove spaces, numbers, and special characters. Use only standard amino acid letters (A–Z) for proteins or nucleotides (A/T/G/C) for DNA.")
728
+ elif "too long" in message:
729
+ st.info("Suggestion: Consider breaking long sequences into smaller segments for optimization.")
730
+ elif "too short" in message:
731
+ st.info("Suggestion: Minimum length is 3 characters. Ensure your sequence is complete.")
732
+ # Clear session state if invalid
733
+ st.session_state.sequence_clean = None
734
+ st.session_state.sequence_type = None
735
+ st.session_state.input_metrics = None
736
+ st.session_state.organism = None
737
+ elif not sequence_input:
738
+ st.session_state.sequence_clean = None
739
+ st.session_state.sequence_type = None
740
+ st.session_state.input_metrics = None
741
+ st.session_state.organism = None
742
+
743
+ # Always display the last analysis if it exists in session state
744
+ if st.session_state.get('input_metrics') and st.session_state.get('sequence_type'):
745
+ input_metrics = st.session_state.input_metrics
746
+ sequence_type = st.session_state.sequence_type
747
+ st.subheader("Input Analysis")
748
+ metrics_col1, metrics_col2, metrics_col3 = st.columns(3)
749
+ with metrics_col1:
750
+ unit = "codons" if sequence_type == "dna" else "AA"
751
+ length = input_metrics.get('length', 0) if input_metrics else 0
752
+ gc_content = input_metrics.get('gc_content', 0) if input_metrics else 0
753
+ st.metric("Length", f"{length} {unit}")
754
+ st.metric("GC Content", f"{gc_content:.1f}%")
755
+ with metrics_col2:
756
+ cai_val = input_metrics.get('cai') if input_metrics else None
757
+ if cai_val:
758
+ label = "CAI" if sequence_type == "dna" else "CAI (baseline)"
759
+ st.metric(label, f"{cai_val:.3f}")
760
+ else:
761
+ st.metric("CAI", "N/A")
762
+ with metrics_col3:
763
+ tai_val = input_metrics.get('tai') if input_metrics else None
764
+ if tai_val:
765
+ label = "tAI" if sequence_type == "dna" else "tAI (baseline)"
766
+ st.metric(label, f"{tai_val:.3f}")
767
+ else:
768
+ st.metric("tAI", "N/A")
769
+ st.subheader("Sequence Quality Analysis")
770
+ analysis_col1, analysis_col2, analysis_col3 = st.columns(3)
771
+ with analysis_col1:
772
+ sites_count = input_metrics.get('restriction_sites', 0) if input_metrics else 0
773
+ color = "normal" if sites_count <= 2 else "inverse"
774
+ st.metric("Restriction Sites", sites_count)
775
+ with analysis_col2:
776
+ neg_elements = input_metrics.get('negative_cis_elements', 0) if input_metrics else 0
777
+ st.metric("Negative Elements", neg_elements)
778
+ with analysis_col3:
779
+ homo_runs = input_metrics.get('homopolymer_runs', 0) if input_metrics else 0
780
+ st.metric("Homopolymer Runs", homo_runs)
781
+ baseline_dna = input_metrics.get('baseline_dna', '') if input_metrics else ''
782
+ if baseline_dna and len(baseline_dna) > 150:
783
+ st.subheader("GC Content Distribution")
784
+ fig = create_gc_content_plot(baseline_dna)
785
+ fig.update_layout(
786
+ title="Input Sequence GC Content Analysis",
787
+ xaxis_title="Position (codons)",
788
+ yaxis_title="GC Content (%)",
789
+ hovermode='x unified'
790
+ )
791
+ st.plotly_chart(fig, use_container_width=True)
792
+
793
+ with col2:
794
+ st.header("Optimization Results")
795
+ # Enhanced optimization button
796
+ if (
797
+ st.session_state.get('sequence_clean')
798
+ and st.session_state.get('sequence_type')
799
+ and not st.session_state.optimization_running
800
+ ):
801
+ st.markdown("**Ready to optimize your sequence!**")
802
+ strategy_info = st.container()
803
+ with strategy_info:
804
+ st.info(f"""
805
+ **Optimization Strategy:**
806
+ • Target organism: {st.session_state.organism}
807
+ • Model: Fine-tuned CodonTransformer (89.6M parameters)
808
+ • GC target: {gc_target_min}-{gc_target_max}%
809
+ • Mode: {'Deterministic' if use_deterministic else 'Stochastic'}
810
+ """)
811
+ if st.button("Optimize Sequence", type="primary", use_container_width=True):
812
+ st.session_state.results = None
813
+ if st.session_state.sequence_type == "dna":
814
+ protein_sequence = translate_dna_to_protein(st.session_state.sequence_clean)
815
+ run_optimization(protein_sequence, st.session_state.organism, use_post_processing)
816
+ else:
817
+ run_optimization(st.session_state.sequence_clean, st.session_state.organism, use_post_processing)
818
+
819
+ # Enhanced progress display
820
+ if st.session_state.optimization_running:
821
+ st.info("Optimizing sequence...")
822
+
823
+ # Create progress container
824
+ progress_container = st.container()
825
+ with progress_container:
826
+ progress_bar = st.progress(0)
827
+ status_text = st.empty()
828
+
829
+ # Enhanced progress steps
830
+ steps = [
831
+ "Analyzing input sequence structure...",
832
+ "Loading model...",
833
+ "Running optimization algorithm...",
834
+ "Applying GC/content constraints...",
835
+ "Finalizing optimized sequence..."
836
+ ]
837
+
838
+ for i, step in enumerate(steps):
839
+ progress_value = int((i + 1) / len(steps) * 100)
840
+ progress_bar.progress(progress_value)
841
+ status_text.text(step)
842
+ time.sleep(0.8) # Realistic timing
843
+
844
+ progress_bar.empty()
845
+ status_text.empty()
846
+
847
+ # Enhanced results display
848
+ if st.session_state.results and not st.session_state.optimization_running:
849
+ if isinstance(st.session_state.results, str):
850
+ st.error(f"Optimization failed: {st.session_state.results}")
851
+ else:
852
+ display_optimization_results(
853
+ st.session_state.results,
854
+ st.session_state.get('organism', organism),
855
+ st.session_state.get('sequence_clean', ''),
856
+ st.session_state.get('sequence_type', 'protein'),
857
+ st.session_state.get('input_metrics', {})
858
+ )
859
+
860
+ def display_optimization_results(result, organism, original_sequence, sequence_type, input_metrics):
861
+ """Enhanced results display with publication-quality visualizations"""
862
+
863
+ # Calculate optimized metrics
864
+ optimized_metrics = {
865
+ 'gc_content': get_GC_content(result.predicted_dna),
866
+ 'length': len(result.predicted_dna)
867
+ }
868
+
869
+ # Calculate CAI and tAI
870
+ try:
871
+ if 'cai_weights' in st.session_state and st.session_state['cai_weights']:
872
+ optimized_metrics['cai'] = CAI(result.predicted_dna, weights=st.session_state['cai_weights'])
873
+ else:
874
+ optimized_metrics['cai'] = None
875
+ except:
876
+ optimized_metrics['cai'] = None
877
+
878
+ try:
879
+ if 'tai_weights' in st.session_state and st.session_state['tai_weights']:
880
+ optimized_metrics['tai'] = calculate_tAI(result.predicted_dna, st.session_state['tai_weights'])
881
+ else:
882
+ optimized_metrics['tai'] = None
883
+ except:
884
+ optimized_metrics['tai'] = None
885
+
886
+ # Success header
887
+ st.success("Optimization complete.")
888
+
889
+ # Key improvements summary
890
+ st.subheader("Optimization Improvements")
891
+ imp_col1, imp_col2, imp_col3 = st.columns(3)
892
+
893
+ if input_metrics is not None:
894
+ with imp_col1:
895
+ if input_metrics.get('gc_content') and optimized_metrics.get('gc_content'):
896
+ gc_change = optimized_metrics['gc_content'] - input_metrics['gc_content']
897
+ st.metric("GC Content", f"{optimized_metrics['gc_content']:.1f}%", delta=f"{gc_change:+.1f}%")
898
+
899
+ with imp_col2:
900
+ if input_metrics.get('cai') and optimized_metrics.get('cai'):
901
+ cai_change = optimized_metrics['cai'] - input_metrics['cai']
902
+ st.metric("CAI Score", f"{optimized_metrics['cai']:.3f}", delta=f"{cai_change:+.3f}")
903
+
904
+ with imp_col3:
905
+ if input_metrics.get('tai') and optimized_metrics.get('tai'):
906
+ tai_change = optimized_metrics['tai'] - input_metrics['tai']
907
+ st.metric("tAI Score", f"{optimized_metrics['tai']:.3f}", delta=f"{tai_change:+.3f}")
908
+
909
+ # Optimized DNA sequence display
910
+ st.subheader("Optimized DNA Sequence")
911
+ st.text_area("Optimized DNA Sequence", result.predicted_dna, height=100)
912
+
913
+ # Enhanced download and export options
914
+ col1, col2, col3 = st.columns(3)
915
+ with col1:
916
+ st.download_button(
917
+ label="Download DNA (FASTA)",
918
+ data=f">Optimized_{organism.replace(' ', '_')}\n{result.predicted_dna}",
919
+ file_name=f"optimized_sequence_{organism.replace(' ', '_')}.fasta",
920
+ mime="text/plain"
921
+ )
922
+
923
+ with col2:
924
+ # Create CSV report
925
+ csv_data = f"Metric,Original,Optimized,Improvement\n"
926
+ csv_data += f"GC Content (%),{input_metrics['gc_content']:.1f},{optimized_metrics['gc_content']:.1f},{optimized_metrics['gc_content'] - input_metrics['gc_content']:+.1f}\n"
927
+ if input_metrics['cai'] and optimized_metrics['cai']:
928
+ csv_data += f"CAI Score,{input_metrics['cai']:.3f},{optimized_metrics['cai']:.3f},{optimized_metrics['cai'] - input_metrics['cai']:+.3f}\n"
929
+ if input_metrics['tai'] and optimized_metrics['tai']:
930
+ csv_data += f"tAI Score,{input_metrics['tai']:.3f},{optimized_metrics['tai']:.3f},{optimized_metrics['tai'] - input_metrics['tai']:+.3f}\n"
931
+
932
+ st.download_button(
933
+ label="Download Metrics (CSV)",
934
+ data=csv_data,
935
+ file_name=f"optimization_metrics_{organism.replace(' ', '_')}.csv",
936
+ mime="text/csv"
937
+ )
938
+
939
+ with col3:
940
+ st.button("Generate PDF Report", help="Coming soon: PDF report")
941
+
942
+ # Enhanced comparison visualizations
943
+ st.subheader("Before vs After Analysis")
944
+
945
+ # Create enhanced comparison charts
946
+ create_enhanced_comparison_charts(input_metrics, optimized_metrics, original_sequence, result.predicted_dna, sequence_type)
947
+
948
+ def create_enhanced_comparison_charts(input_metrics, optimized_metrics, original_dna, optimized_dna, sequence_type):
949
+ """Create publication-quality comparison visualizations"""
950
+ if input_metrics is None or optimized_metrics is None:
951
+ st.info("No comparison data available.")
952
+ return
953
+
954
+ # GC Content comparison
955
+ gc_comp_fig = create_gc_comparison_chart(input_metrics, optimized_metrics)
956
+ gc_comp_fig.update_layout(
957
+ title="GC Content Optimization Results",
958
+ font=dict(size=12),
959
+ height=350
960
+ )
961
+ st.plotly_chart(gc_comp_fig, use_container_width=True)
962
+
963
+ # Expression metrics comparison
964
+ if input_metrics.get('cai') and optimized_metrics.get('cai'):
965
+ expr_comp_fig = create_expression_comparison_chart(input_metrics, optimized_metrics)
966
+ expr_comp_fig.update_layout(
967
+ title="Expression Potential Improvement",
968
+ font=dict(size=12),
969
+ height=350
970
+ )
971
+ st.plotly_chart(expr_comp_fig, use_container_width=True)
972
+
973
+ # Side-by-side GC distribution analysis
974
+ st.subheader("GC Content Distribution Analysis")
975
+ col1, col2 = st.columns(2)
976
+
977
+ with col1:
978
+ st.write(f"**{'Original DNA' if sequence_type == 'dna' else 'Baseline (Most Frequent Codons)'}**")
979
+ baseline_dna = input_metrics.get('baseline_dna') if input_metrics else None
980
+ plot_dna = baseline_dna if baseline_dna is not None else original_dna
981
+ if plot_dna is not None and isinstance(plot_dna, str) and len(plot_dna) > 150:
982
+ fig_before = create_gc_content_plot(plot_dna)
983
+ fig_before.update_layout(title="Before Optimization", height=300)
984
+ st.plotly_chart(fig_before, use_container_width=True)
985
+ else:
986
+ st.info("Sequence too short for sliding window analysis")
987
+
988
+ with col2:
989
+ st.write("** Model Optimized**")
990
+ if optimized_dna is not None and isinstance(optimized_dna, str) and len(optimized_dna) > 150:
991
+ fig_after = create_gc_content_plot(optimized_dna)
992
+ fig_after.update_layout(title="After Optimization", height=300)
993
+ st.plotly_chart(fig_after, use_container_width=True)
994
+ else:
995
+ st.info("Sequence too short for sliding window analysis")
996
+
997
+ def batch_processing_interface():
998
+ """Batch processing interface for multiple sequences"""
999
+ st.header("Batch Processing")
1000
+ st.markdown("**Process multiple protein sequences simultaneously with optimization**")
1001
+
1002
+ # File upload section
1003
+ st.subheader("Upload Sequences")
1004
+ uploaded_file = st.file_uploader(
1005
+ "Choose a file with multiple sequences",
1006
+ type=['csv', 'xlsx', 'fasta', 'txt', 'fa'],
1007
+ help="Upload CSV, Excel (XLSX, with 'sequence' column) or FASTA format files"
1008
+ )
1009
+
1010
+ if uploaded_file:
1011
+ st.success(f"File uploaded: {uploaded_file.name}")
1012
+
1013
+ # Process uploaded file
1014
+ try:
1015
+ def find_column(df, target):
1016
+ # Find column name case-insensitively and ignoring spaces
1017
+ for col in df.columns:
1018
+ if col.strip().lower() == target:
1019
+ return col
1020
+ return None
1021
+
1022
+ if uploaded_file.name.endswith('.csv'):
1023
+ df = pd.read_csv(uploaded_file)
1024
+ seq_col = find_column(df, 'sequence')
1025
+ name_col = find_column(df, 'name')
1026
+ if seq_col:
1027
+ sequences = df[seq_col].tolist()
1028
+ if name_col:
1029
+ names = df[name_col].tolist()
1030
+ else:
1031
+ names = [f"Sequence_{i+1}" for i in range(len(sequences))]
1032
+ else:
1033
+ st.error("CSV file must contain a column named 'sequence' (case-insensitive, spaces ignored)")
1034
+ return
1035
+ elif uploaded_file.name.endswith('.xlsx'):
1036
+ df = pd.read_excel(uploaded_file)
1037
+ seq_col = find_column(df, 'sequence')
1038
+ name_col = find_column(df, 'name')
1039
+ if seq_col:
1040
+ sequences = df[seq_col].tolist()
1041
+ if name_col:
1042
+ names = df[name_col].tolist()
1043
+ else:
1044
+ names = [f"Sequence_{i+1}" for i in range(len(sequences))]
1045
+ else:
1046
+ st.error("Excel file must contain a column named 'sequence' (case-insensitive, spaces ignored)")
1047
+ return
1048
+ else:
1049
+ # Handle FASTA format
1050
+ content = uploaded_file.read().decode('utf-8')
1051
+ sequences, names = parse_fasta_content(content)
1052
+
1053
+ st.info(f"Found {len(sequences)} sequences ready for optimization")
1054
+
1055
+ # Batch configuration
1056
+ col1, col2 = st.columns(2)
1057
+ with col1:
1058
+ batch_organism = st.selectbox("Target Organism", [
1059
+ "Escherichia coli general", "Saccharomyces cerevisiae", "Homo sapiens"
1060
+ ])
1061
+ with col2:
1062
+ max_sequences = st.number_input("Max sequences to process", 1, len(sequences), min(10, len(sequences)))
1063
+
1064
+ # Start batch processing
1065
+ if st.button("Start Batch Optimization", type="primary"):
1066
+ run_batch_optimization(sequences[:max_sequences], names[:max_sequences], batch_organism)
1067
+
1068
+ except Exception as e:
1069
+ st.error(f"Error processing file: {str(e)}")
1070
+
1071
+ # Batch results display
1072
+ if 'batch_results' in st.session_state and st.session_state.batch_results:
1073
+ display_batch_results()
1074
+
1075
+ def parse_fasta_content(content):
1076
+ """Parse FASTA format content"""
1077
+ sequences = []
1078
+ names = []
1079
+ current_seq = ""
1080
+ current_name = ""
1081
+
1082
+ for line in content.split('\n'):
1083
+ line = line.strip()
1084
+ if line.startswith('>'):
1085
+ if current_seq:
1086
+ sequences.append(current_seq)
1087
+ names.append(current_name)
1088
+ current_name = line[1:] if len(line) > 1 else f"Sequence_{len(sequences)+1}"
1089
+ current_seq = ""
1090
+ else:
1091
+ current_seq += line
1092
+
1093
+ if current_seq:
1094
+ sequences.append(current_seq)
1095
+ names.append(current_name)
1096
+
1097
+ return sequences, names
1098
+
1099
+ def run_batch_optimization(sequences, names, organism):
1100
+ """Run batch optimization with progress tracking"""
1101
+ st.session_state.batch_results = []
1102
+ st.session_state.batch_logs = [] # Collect info logs for auto-fixes
1103
+
1104
+ # Load reference data for CAI/tAI
1105
+ load_reference_data(organism)
1106
+ cai_weights = st.session_state.get('cai_weights', None)
1107
+ tai_weights = st.session_state.get('tai_weights', None)
1108
+
1109
+ # Create progress tracking
1110
+ progress_bar = st.progress(0)
1111
+ status_text = st.empty()
1112
+
1113
+ for i, (seq, name) in enumerate(zip(sequences, names)):
1114
+ progress = (i + 1) / len(sequences)
1115
+ progress_bar.progress(progress)
1116
+ status_text.text(f"Processing {name} ({i+1}/{len(sequences)})")
1117
+
1118
+ try:
1119
+ # Validate sequence and get possibly fixed sequence
1120
+ is_valid, message, sequence_type, fixed_seq = validate_sequence(seq)
1121
+ if is_valid:
1122
+ # Log if auto-fixed
1123
+ if 'auto-fixed' in message:
1124
+ st.session_state.batch_logs.append(f"{name}: {message}")
1125
+ # Calculate original metrics (use fixed_seq for DNA)
1126
+ if sequence_type == "dna":
1127
+ orig_gc = get_GC_content(fixed_seq)
1128
+ orig_cai = CAI(fixed_seq, weights=cai_weights) if cai_weights else None
1129
+ orig_tai = calculate_tAI(fixed_seq, tai_weights) if tai_weights else None
1130
+ else:
1131
+ # For protein, create baseline DNA
1132
+ most_frequent_codons = {
1133
+ 'A': 'GCG', 'C': 'TGC', 'D': 'GAT', 'E': 'GAA', 'F': 'TTT',
1134
+ 'G': 'GGC', 'H': 'CAT', 'I': 'ATT', 'K': 'AAA', 'L': 'CTG',
1135
+ 'M': 'ATG', 'N': 'AAC', 'P': 'CCG', 'Q': 'CAG', 'R': 'CGC',
1136
+ 'S': 'TCG', 'T': 'ACG', 'V': 'GTG', 'W': 'TGG', 'Y': 'TAT',
1137
+ '*': 'TAA', '_': 'TAA'
1138
+ }
1139
+ baseline_dna = ''.join([most_frequent_codons.get(aa, 'NNN') for aa in fixed_seq])
1140
+ orig_gc = get_GC_content(baseline_dna)
1141
+ orig_cai = CAI(baseline_dna, weights=cai_weights) if cai_weights else None
1142
+ orig_tai = calculate_tAI(baseline_dna, tai_weights) if tai_weights else None
1143
+
1144
+ # Run optimization using the fixed sequence
1145
+ result = predict_dna_sequence(
1146
+ protein=fixed_seq if sequence_type == "protein" else translate_dna_to_protein(fixed_seq),
1147
+ organism=organism,
1148
+ device=st.session_state.device,
1149
+ model=st.session_state.model,
1150
+ deterministic=True,
1151
+ match_protein=True,
1152
+ )
1153
+
1154
+ # If result is a list, use the first element
1155
+ if isinstance(result, list):
1156
+ result_obj = result[0]
1157
+ else:
1158
+ result_obj = result
1159
+
1160
+ # Calculate optimized metrics
1161
+ opt_gc = get_GC_content(result_obj.predicted_dna)
1162
+ opt_cai = CAI(result_obj.predicted_dna, weights=cai_weights) if cai_weights else None
1163
+ opt_tai = calculate_tAI(result_obj.predicted_dna, tai_weights) if tai_weights else None
1164
+
1165
+ metrics = {
1166
+ 'name': name,
1167
+ 'original_sequence': fixed_seq,
1168
+ 'optimized_dna': result_obj.predicted_dna,
1169
+ 'gc_content_before': orig_gc,
1170
+ 'gc_content_after': opt_gc,
1171
+ 'cai_before': orig_cai,
1172
+ 'cai_after': opt_cai,
1173
+ 'tai_before': orig_tai,
1174
+ 'tai_after': opt_tai,
1175
+ 'length_before': len(fixed_seq),
1176
+ 'length_after': len(result_obj.predicted_dna),
1177
+ 'validation_message': message
1178
+ }
1179
+
1180
+ st.session_state.batch_results.append(metrics)
1181
+ else:
1182
+ # Only skip if truly invalid (not auto-fixable)
1183
+ st.session_state.batch_logs.append(f"{name}: {message}")
1184
+
1185
+ except Exception as e:
1186
+ st.session_state.batch_logs.append(f"{name}: Error processing: {str(e)}")
1187
+
1188
+ progress_bar.empty()
1189
+ status_text.empty()
1190
+ st.success(f"Batch optimization complete. Processed {len(st.session_state.batch_results)} sequences.")
1191
+
1192
+ def display_batch_results():
1193
+ """Display batch processing results"""
1194
+ st.subheader("Batch Results")
1195
+
1196
+ # Show all logs (auto-fixes and errors)
1197
+ if hasattr(st.session_state, 'batch_logs') and st.session_state.batch_logs:
1198
+ for log in st.session_state.batch_logs:
1199
+ st.info(log)
1200
+
1201
+ results_df = pd.DataFrame(st.session_state.batch_results)
1202
+
1203
+ # Summary statistics
1204
+ col1, col2, col3, col4 = st.columns(4)
1205
+ with col1:
1206
+ st.metric("Sequences Processed", len(results_df))
1207
+ with col2:
1208
+ st.metric("Avg GC Before", f"{results_df['gc_content_before'].mean():.1f}%")
1209
+ st.metric("Avg GC After", f"{results_df['gc_content_after'].mean():.1f}%")
1210
+ with col3:
1211
+ st.metric("Avg CAI Before", f"{results_df['cai_before'].mean():.3f}")
1212
+ st.metric("Avg CAI After", f"{results_df['cai_after'].mean():.3f}")
1213
+ with col4:
1214
+ st.metric("Avg tAI Before", f"{results_df['tai_before'].mean():.3f}")
1215
+ st.metric("Avg tAI After", f"{results_df['tai_after'].mean():.3f}")
1216
+
1217
+ # CAI Extremes Analysis
1218
+ st.subheader("CAI Performance Analysis")
1219
+
1220
+ # Filter out rows with NaN CAI values for analysis
1221
+ valid_cai_df = results_df.dropna(subset=['cai_after'])
1222
+
1223
+ if len(valid_cai_df) > 0:
1224
+ # Find lowest and highest CAI sequences
1225
+ lowest_cai_idx = valid_cai_df['cai_after'].idxmin()
1226
+ highest_cai_idx = valid_cai_df['cai_after'].idxmax()
1227
+
1228
+ lowest_cai_row = results_df.loc[lowest_cai_idx]
1229
+ highest_cai_row = results_df.loc[highest_cai_idx]
1230
+
1231
+ col1, col2 = st.columns(2)
1232
+
1233
+ with col1:
1234
+ st.markdown("**Lowest CAI Sequence**")
1235
+ st.write(f"**Name:** {lowest_cai_row['name']}")
1236
+ st.metric("CAI Score", f"{lowest_cai_row['cai_after']:.3f}")
1237
+ st.metric("GC Content", f"{lowest_cai_row['gc_content_after']:.1f}%")
1238
+ st.metric("tAI Score", f"{lowest_cai_row['tai_after']:.3f}")
1239
+ st.metric("Length", f"{lowest_cai_row['length_after']} bp")
1240
+
1241
+ # Show improvement
1242
+ if pd.notna(lowest_cai_row['cai_before']):
1243
+ cai_improvement = lowest_cai_row['cai_after'] - lowest_cai_row['cai_before']
1244
+ st.metric("CAI Improvement", f"{cai_improvement:+.3f}")
1245
+
1246
+ with col2:
1247
+ st.markdown("**Highest CAI Sequence**")
1248
+ st.write(f"**Name:** {highest_cai_row['name']}")
1249
+ st.metric("CAI Score", f"{highest_cai_row['cai_after']:.3f}")
1250
+ st.metric("GC Content", f"{highest_cai_row['gc_content_after']:.1f}%")
1251
+ st.metric("tAI Score", f"{highest_cai_row['tai_after']:.3f}")
1252
+ st.metric("Length", f"{highest_cai_row['length_after']} bp")
1253
+
1254
+ # Show improvement
1255
+ if pd.notna(highest_cai_row['cai_before']):
1256
+ cai_improvement = highest_cai_row['cai_after'] - highest_cai_row['cai_before']
1257
+ st.metric("CAI Improvement", f"{cai_improvement:+.3f}")
1258
+
1259
+ # CAI Distribution Chart
1260
+ st.subheader("CAI Distribution")
1261
+ fig = go.Figure()
1262
+ fig.add_trace(go.Histogram(
1263
+ x=valid_cai_df['cai_after'],
1264
+ nbinsx=20,
1265
+ name='Optimized CAI Scores',
1266
+ marker_color='darkblue',
1267
+ opacity=0.7
1268
+ ))
1269
+
1270
+ # Add vertical lines for lowest and highest
1271
+ fig.add_vline(
1272
+ x=lowest_cai_row['cai_after'],
1273
+ line_dash="dash",
1274
+ line_color="red",
1275
+ annotation_text=f"Lowest: {lowest_cai_row['cai_after']:.3f}"
1276
+ )
1277
+ fig.add_vline(
1278
+ x=highest_cai_row['cai_after'],
1279
+ line_dash="dash",
1280
+ line_color="green",
1281
+ annotation_text=f"Highest: {highest_cai_row['cai_after']:.3f}"
1282
+ )
1283
+
1284
+ fig.update_layout(
1285
+ title="Distribution of Optimized CAI Scores",
1286
+ xaxis_title="CAI Score",
1287
+ yaxis_title="Number of Sequences",
1288
+ height=400,
1289
+ showlegend=False
1290
+ )
1291
+ st.plotly_chart(fig, use_container_width=True)
1292
+
1293
+ # GC Content Distribution Chart
1294
+ st.subheader("GC Content Distribution")
1295
+ valid_gc_df = results_df.dropna(subset=['gc_content_after'])
1296
+ if len(valid_gc_df) > 0:
1297
+ lowest_gc_idx = valid_gc_df['gc_content_after'].idxmin()
1298
+ highest_gc_idx = valid_gc_df['gc_content_after'].idxmax()
1299
+ lowest_gc_row = results_df.loc[lowest_gc_idx]
1300
+ highest_gc_row = results_df.loc[highest_gc_idx]
1301
+
1302
+ fig_gc = go.Figure()
1303
+ fig_gc.add_trace(go.Histogram(
1304
+ x=valid_gc_df['gc_content_after'],
1305
+ nbinsx=20,
1306
+ name='Optimized GC Content',
1307
+ marker_color='teal',
1308
+ opacity=0.7
1309
+ ))
1310
+ fig_gc.add_vline(
1311
+ x=lowest_gc_row['gc_content_after'],
1312
+ line_dash="dash",
1313
+ line_color="red",
1314
+ annotation_text=f"Lowest: {lowest_gc_row['gc_content_after']:.1f}%"
1315
+ )
1316
+ fig_gc.add_vline(
1317
+ x=highest_gc_row['gc_content_after'],
1318
+ line_dash="dash",
1319
+ line_color="green",
1320
+ annotation_text=f"Highest: {highest_gc_row['gc_content_after']:.1f}%"
1321
+ )
1322
+ fig_gc.update_layout(
1323
+ title="Distribution of Optimized GC Content",
1324
+ xaxis_title="GC Content (%)",
1325
+ yaxis_title="Number of Sequences",
1326
+ height=400,
1327
+ showlegend=False
1328
+ )
1329
+ st.plotly_chart(fig_gc, use_container_width=True)
1330
+ else:
1331
+ st.warning("No valid GC content values found in the batch results.")
1332
+
1333
+ else:
1334
+ st.warning("No valid CAI scores found in the batch results. Check if CAI weights are properly loaded.")
1335
+
1336
+ # Sequence selector
1337
+ seq_names = results_df['name'].tolist()
1338
+ selected_seq = st.selectbox("Select a sequence to view details", seq_names)
1339
+ seq_row = results_df[results_df['name'] == selected_seq].iloc[0]
1340
+
1341
+ st.markdown(f"### Details for: {selected_seq}")
1342
+ if 'validation_message' in seq_row and 'auto-fixed' in seq_row['validation_message']:
1343
+ st.info(seq_row['validation_message'])
1344
+ col1, col2 = st.columns(2)
1345
+ with col1:
1346
+ st.markdown("**Original Sequence**")
1347
+ st.text_area("Original Sequence", seq_row['original_sequence'], height=100)
1348
+ st.metric("GC Content (Before)", f"{seq_row['gc_content_before']:.1f}%")
1349
+ st.metric("CAI (Before)", f"{seq_row['cai_before']:.3f}")
1350
+ st.metric("tAI (Before)", f"{seq_row['tai_before']:.3f}")
1351
+ st.metric("Length (Before)", f"{seq_row['length_before']}")
1352
+ with col2:
1353
+ st.markdown("**Optimized Sequence**")
1354
+ st.text_area("Optimized Sequence", seq_row['optimized_dna'], height=100)
1355
+ st.metric("GC Content (After)", f"{seq_row['gc_content_after']:.1f}%")
1356
+ st.metric("CAI (After)", f"{seq_row['cai_after']:.3f}")
1357
+ st.metric("tAI (After)", f"{seq_row['tai_after']:.3f}")
1358
+ st.metric("Length (After)", f"{seq_row['length_after']}")
1359
+
1360
+ # Plots for before/after GC content
1361
+ st.subheader("GC Content Distribution (Before vs After)")
1362
+ if len(seq_row['original_sequence']) > 150 and len(seq_row['optimized_dna']) > 150:
1363
+ fig_before = create_gc_content_plot(seq_row['original_sequence'])
1364
+ fig_before.update_layout(title="Before Optimization", height=300)
1365
+ fig_after = create_gc_content_plot(seq_row['optimized_dna'])
1366
+ fig_after.update_layout(title="After Optimization", height=300)
1367
+ st.plotly_chart(fig_before, use_container_width=True)
1368
+ st.plotly_chart(fig_after, use_container_width=True)
1369
+ else:
1370
+ st.info("Sequence(s) too short for sliding window analysis")
1371
+
1372
+ # Download batch results
1373
+ if st.button("Download Batch Results"):
1374
+ csv_data = results_df.to_csv(index=False)
1375
+ st.download_button(
1376
+ label="Download CSV",
1377
+ data=csv_data,
1378
+ file_name="batch_optimization_results.csv",
1379
+ mime="text/csv"
1380
+ )
1381
+
1382
+ def comparative_analysis_interface():
1383
+ """Comparative analysis interface"""
1384
+ st.header("Comparative Analysis")
1385
+ st.markdown("For quantitative comparisons and plots, use the benchmark script:")
1386
+ st.code("python scripts/run_benchmarks.py --config configs/benchmark.yaml")
1387
+
1388
+ def advanced_settings_interface():
1389
+ """Advanced settings and configuration interface"""
1390
+ st.header("Advanced Settings")
1391
+ st.markdown("**Configure advanced parameters and model settings**")
1392
+
1393
+ # Model configuration
1394
+ st.subheader("Model Configuration")
1395
+ col1, col2 = st.columns(2)
1396
+
1397
+ with col1:
1398
+ st.write("**Current Model Status:**")
1399
+ if st.session_state.model:
1400
+ model_type = getattr(st.session_state, 'model_type', 'unknown')
1401
+ st.success(f"Model loaded: {model_type}")
1402
+ st.write(f"Device: {st.session_state.device}")
1403
+ else:
1404
+ st.warning("Model not loaded")
1405
+
1406
+ with col2:
1407
+ st.write("**Model Information:**")
1408
+ st.write("• Architecture: BigBird Transformer")
1409
+ st.write("• Parameters: 89.6M")
1410
+ st.write("• Fine-tuning data: 3,676 high-expression E. coli genes (NCBI-curated)")
1411
+
1412
+ # Performance tuning
1413
+ st.subheader("Performance Tuning")
1414
+
1415
+ # Memory management
1416
+ col1, col2 = st.columns(2)
1417
+ with col1:
1418
+ if st.button("Clear Cache"):
1419
+ st.cache_data.clear()
1420
+ st.success("Cache cleared successfully")
1421
+
1422
+ with col2:
1423
+ if st.button("Reload Model"):
1424
+ st.session_state.model = None
1425
+ st.session_state.tokenizer = None
1426
+ st.rerun()
1427
+
1428
+ # System information
1429
+ st.subheader("System Information")
1430
+ import torch
1431
+ col1, col2, col3 = st.columns(3)
1432
+
1433
+ with col1:
1434
+ st.write("**PyTorch:**")
1435
+ st.write(f"Version: {torch.__version__}")
1436
+ st.write(f"CUDA Available: {torch.cuda.is_available()}")
1437
+
1438
+ with col2:
1439
+ st.write("**Device:**")
1440
+ st.write(f"Current: {st.session_state.device}")
1441
+ if torch.cuda.is_available():
1442
+ st.write(f"GPU: {torch.cuda.get_device_name()}")
1443
+
1444
+ with col3:
1445
+ st.write("**Memory:**")
1446
+ if torch.cuda.is_available():
1447
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
1448
+ st.write(f"GPU Memory: {gpu_memory:.1f} GB")
1449
+
1450
+ # Footer
1451
+ st.markdown("---")
1452
+ st.markdown("**ENCOT**")
1453
+ st.markdown("Open-source codon optimization for E. coli with reproducible evaluation.")
1454
+
1455
+ if __name__ == "__main__":
1456
+ main()
streamlit_gui/demo.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Demo script for ColiFormer Streamlit GUI
4
+
5
+ This script demonstrates the GUI functionality with example sequences
6
+ and showcases key features of the ColiFormer optimization tool.
7
+ """
8
+
9
+ import sys
10
+ import os
11
+ import time
12
+ from pathlib import Path
13
+
14
+ # Add parent directory to path for imports
15
+ sys.path.append(str(Path(__file__).parent.parent))
16
+
17
+ def print_header():
18
+ """Print demo header"""
19
+ print("=" * 40)
20
+ print(" ColiFormer GUI Demo")
21
+ print("=" * 40)
22
+ print()
23
+
24
+ def print_section(title):
25
+ """Print section header"""
26
+ print(f"\n{title}")
27
+ print("-" * (len(title) + 4))
28
+
29
+ def demo_validation():
30
+ """Demonstrate protein sequence validation"""
31
+ print_section("Protein Sequence Validation")
32
+
33
+ # Import validation function
34
+ from streamlit_gui.app import validate_protein_sequence
35
+
36
+ test_sequences = [
37
+ ("MKTVRQERLK", "Valid short peptide"),
38
+ ("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG", "Valid longer protein"),
39
+ ("MKTVRQERLKX", "Invalid character (X)"),
40
+ ("MK", "Too short"),
41
+ ("mktvrqerlk", "Lowercase (should work)"),
42
+ ("MKTVRQERLK*", "With stop codon"),
43
+ ]
44
+
45
+ for seq, description in test_sequences:
46
+ is_valid, message = validate_protein_sequence(seq)
47
+ status = "OK" if is_valid else "FAIL"
48
+ print(f"{status} {description}: {message}")
49
+
50
+ def demo_metrics():
51
+ """Demonstrate metrics calculation"""
52
+ print_section("Metrics Calculation Demo")
53
+
54
+ from streamlit_gui.app import calculate_input_metrics
55
+
56
+ example_proteins = [
57
+ ("MKTVRQERLK", "Short peptide (10 AA)"),
58
+ ("MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG", "Medium protein (67 AA)"),
59
+ ("MKWVTFISLLLLFSSAYSRGVFRRDTHKSEIAHRFKDLGEEHFKGLVLIAFSQYLQQCPFDEHVKLVNELTE", "Long protein (72 AA)"),
60
+ ]
61
+
62
+ organism = "Escherichia coli general"
63
+
64
+ for protein, description in example_proteins:
65
+ print(f"\n{description}")
66
+ print(f" Sequence: {protein[:30]}{'...' if len(protein) > 30 else ''}")
67
+
68
+ metrics = calculate_input_metrics(protein, organism)
69
+
70
+ print(f" Length: {metrics['length']} amino acids")
71
+ print(f" GC Content: {metrics['gc_content']:.1f}%")
72
+ if metrics['tai']:
73
+ print(f" tAI: {metrics['tai']:.3f}")
74
+ if metrics['cai']:
75
+ print(f" CAI: {metrics['cai']:.3f}")
76
+ else:
77
+ print(" CAI: Not available for this organism")
78
+
79
+ def demo_visualization():
80
+ """Demonstrate visualization capabilities"""
81
+ print_section("Visualization Demo")
82
+
83
+ from streamlit_gui.app import create_gc_content_plot, create_metrics_comparison_chart
84
+
85
+ # Test DNA sequence for GC content plot
86
+ test_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC"
87
+
88
+ print("Creating GC content sliding window plot...")
89
+ try:
90
+ fig = create_gc_content_plot(test_dna)
91
+ print(" OK: GC content plot created successfully")
92
+ print(f" Analyzing {len(test_dna)} base pairs")
93
+ except Exception as e:
94
+ print(f" FAIL: Error creating GC plot: {e}")
95
+
96
+ print("\nCreating metrics comparison chart...")
97
+ try:
98
+ before_metrics = {
99
+ 'gc_content': 45.2,
100
+ 'cai': 0.485,
101
+ 'tai': 0.312
102
+ }
103
+ after_metrics = {
104
+ 'gc_content': 52.1,
105
+ 'cai': 0.634,
106
+ 'tai': 0.456
107
+ }
108
+ fig = create_metrics_comparison_chart(before_metrics, after_metrics)
109
+ print(" OK: Comparison chart created successfully")
110
+ print(" Shows improvement in all metrics")
111
+ except Exception as e:
112
+ print(f" FAIL: Error creating comparison chart: {e}")
113
+
114
+ def demo_codon_evaluation():
115
+ """Demonstrate CodonEvaluation functions"""
116
+ print_section("CodonEvaluation Functions Demo")
117
+
118
+ from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI, get_ecoli_tai_weights
119
+
120
+ test_sequences = [
121
+ ("ATGGCGAAAGCGCTGTATCGC", "High GC content"),
122
+ ("ATGAAATTTATTTATTATTAT", "Low GC content"),
123
+ ("ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC", "Medium length"),
124
+ ]
125
+
126
+ print("Testing GC content calculation:")
127
+ for seq, description in test_sequences:
128
+ gc_content = get_GC_content(seq)
129
+ print(f" {description}: {gc_content:.1f}%")
130
+
131
+ print("\nTesting tAI calculation:")
132
+ try:
133
+ tai_weights = get_ecoli_tai_weights()
134
+ for seq, description in test_sequences:
135
+ tai_value = calculate_tAI(seq, tai_weights)
136
+ print(f" {description}: {tai_value:.3f}")
137
+ except Exception as e:
138
+ print(f" FAIL: tAI calculation error: {e}")
139
+
140
+ def demo_model_info():
141
+ """Show model information"""
142
+ print_section("Model Information")
143
+
144
+ try:
145
+ import torch
146
+ from transformers import AutoTokenizer
147
+
148
+ print("Model Details:")
149
+ print(" Base model: adibvafa/CodonTransformer")
150
+ print(" Architecture: BigBird Transformer")
151
+ print(" Task: Masked Language Modeling for codon optimization")
152
+
153
+ print("\nSystem Information:")
154
+ print(f" PyTorch: {torch.__version__}")
155
+ print(f" Device: {'GPU (CUDA)' if torch.cuda.is_available() else 'CPU'}")
156
+ if torch.cuda.is_available():
157
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
158
+ print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
159
+
160
+ print("\nTokenizer Test:")
161
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
162
+ print(f" OK: Tokenizer loaded: {len(tokenizer)} tokens")
163
+ print(f" Vocab size: {tokenizer.vocab_size}")
164
+
165
+ except Exception as e:
166
+ print(f" FAIL: Error loading model info: {e}")
167
+
168
+ def demo_gui_features():
169
+ """Show GUI features overview"""
170
+ print_section("GUI Features Overview")
171
+
172
+ features = [
173
+ ("Real-time Validation", "Instant feedback on protein sequence validity"),
174
+ ("Metrics Dashboard", "GC content, CAI, tAI calculations"),
175
+ ("Constrained Optimization", "GC content control with beam search"),
176
+ ("Visual Analytics", "Interactive plots and comparisons"),
177
+ ("Configurable Parameters", "Organism selection, beam size, GC targets"),
178
+ ("Export Options", "Download optimized sequences"),
179
+ ("Progress Tracking", "Real-time optimization progress"),
180
+ ("Responsive Design", "Works on desktop and mobile"),
181
+ ]
182
+
183
+ for feature, description in features:
184
+ print(f" {feature}: {description}")
185
+
186
+ def demo_usage_examples():
187
+ """Show usage examples"""
188
+ print_section("Usage Examples")
189
+
190
+ examples = [
191
+ {
192
+ "name": "Short Peptide Optimization",
193
+ "protein": "MKTVRQERLK",
194
+ "organism": "Escherichia coli general",
195
+ "use_case": "Quick testing and validation"
196
+ },
197
+ {
198
+ "name": "Insulin Chain A",
199
+ "protein": "GIVEQCCTSICSLYQLENYCN",
200
+ "organism": "Escherichia coli general",
201
+ "use_case": "Pharmaceutical protein production"
202
+ },
203
+ {
204
+ "name": "Green Fluorescent Protein (partial)",
205
+ "protein": "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQC",
206
+ "organism": "Escherichia coli general",
207
+ "use_case": "Research marker protein"
208
+ },
209
+ {
210
+ "name": "Yeast Expression",
211
+ "protein": "MKTVRQERLKSIVRILERSKEPVSGAQ",
212
+ "organism": "Saccharomyces cerevisiae",
213
+ "use_case": "Eukaryotic protein expression"
214
+ }
215
+ ]
216
+
217
+ for i, example in enumerate(examples, 1):
218
+ print(f"\nExample {i}: {example['name']}")
219
+ print(f" Protein: {example['protein'][:40]}{'...' if len(example['protein']) > 40 else ''}")
220
+ print(f" Organism: {example['organism']}")
221
+ print(f" Use case: {example['use_case']}")
222
+ print(f" Length: {len(example['protein'])} amino acids")
223
+
224
+ def demo_launch_instructions():
225
+ """Show how to launch the GUI"""
226
+ print_section("How to Launch the GUI")
227
+
228
+ print("Launch Options:")
229
+ print()
230
+ print(" Option 1 - Using the launcher script:")
231
+ print(" $ cd ecoli/streamlit_gui")
232
+ print(" $ python run_gui.py")
233
+ print()
234
+ print(" Option 2 - Direct streamlit command:")
235
+ print(" $ cd ecoli/streamlit_gui")
236
+ print(" $ source ../codon_env/bin/activate")
237
+ print(" $ streamlit run app.py")
238
+ print()
239
+ print(" Option 3 - With custom port:")
240
+ print(" $ streamlit run app.py --server.port 8502")
241
+ print()
242
+ print("Access the GUI:")
243
+ print(" Web browser: http://localhost:8501")
244
+ print(" The GUI will automatically open in your default browser")
245
+ print()
246
+ print("Performance Tips:")
247
+ print(" • Use GPU if available for faster processing")
248
+ print(" • Start with shorter sequences for testing")
249
+ print(" • Adjust beam size based on sequence length")
250
+ print(" • Close other applications to free up memory")
251
+
252
+ def main():
253
+ """Run the complete demo"""
254
+ print_header()
255
+
256
+ print("This demo showcases the ENCOT Streamlit GUI capabilities.")
257
+ print("The GUI provides an interface for protein codon optimization.")
258
+ print()
259
+
260
+ try:
261
+ demo_validation()
262
+ demo_metrics()
263
+ demo_visualization()
264
+ demo_codon_evaluation()
265
+ demo_model_info()
266
+ demo_gui_features()
267
+ demo_usage_examples()
268
+ demo_launch_instructions()
269
+
270
+ print("\nDemo completed successfully.")
271
+ print()
272
+ print("Next steps:")
273
+ print("1. Launch the GUI using one of the methods above")
274
+ print("2. Try the example sequences provided")
275
+ print("3. Experiment with different organisms and settings")
276
+ print("4. Compare optimization results")
277
+ print()
278
+ print("Happy optimizing.")
279
+
280
+ except Exception as e:
281
+ print(f"\nDemo error: {e}")
282
+ print("Make sure you're running from the correct directory and all dependencies are installed.")
283
+ return 1
284
+
285
+ return 0
286
+
287
+ if __name__ == "__main__":
288
+ exit(main())
streamlit_gui/requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.28.0
2
+ torch>=1.13.0
3
+ pandas>=1.5.0
4
+ numpy>=1.21.0
5
+ plotly>=5.0.0
6
+ transformers>=4.21.0
7
+ scipy>=1.9.0
8
+ tokenizers>=0.13.0
9
+ tqdm>=4.64.0
10
+ matplotlib>=3.5.0
11
+ seaborn>=0.11.0
12
+ onnxruntime>=1.15.0
13
+ python-codon-tables>=0.1.12
14
+ biopython>=1.79
15
+ scikit-learn>=1.0.0
16
+ requests>=2.25.0
17
+ ipywidgets>=7.6.0
18
+ huggingface-hub>=0.20.0
19
+ datasets>=2.0.0
20
+ git+https://github.com/Benjamin-Lee/CodonAdaptationIndex.git
streamlit_gui/run_gui.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Launcher script for ColiFormer Streamlit GUI
4
+
5
+ This script sets up the environment and launches the Streamlit application.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import subprocess
11
+ from pathlib import Path
12
+
13
+ def main():
14
+ """Launch the Streamlit GUI application"""
15
+
16
+ # Get the directory containing this script
17
+ script_dir = Path(__file__).parent
18
+
19
+ # Add the parent directory to Python path so we can import CodonTransformer
20
+ parent_dir = script_dir.parent
21
+ sys.path.insert(0, str(parent_dir))
22
+
23
+ # Set working directory to parent directory so model paths work correctly
24
+ os.chdir(parent_dir)
25
+
26
+ print("Starting ENCOT GUI...")
27
+ print(f" Working directory: {parent_dir}")
28
+ print(f" Python path includes: {parent_dir}")
29
+
30
+ # Check for model checkpoint
31
+ model_path = parent_dir / "models" / "alm-enhanced-training" / "balanced_alm_finetune.ckpt"
32
+ if model_path.exists():
33
+ print(f"Found fine-tuned model: {model_path}")
34
+ else:
35
+ print("Fine-tuned model not found, will use base model")
36
+
37
+ # Check for virtual environment
38
+ venv_path = parent_dir / "codon_env"
39
+ if venv_path.exists():
40
+ # Set up virtual environment paths
41
+ venv_bin = venv_path / "bin"
42
+ venv_python = venv_bin / "python"
43
+
44
+ if venv_python.exists():
45
+ print(f"Found virtual environment: {venv_path}")
46
+ # Update PATH to include virtual environment
47
+ current_path = os.environ.get("PATH", "")
48
+ os.environ["PATH"] = f"{venv_bin}:{current_path}"
49
+ # Use virtual environment Python
50
+ python_executable = str(venv_python)
51
+ else:
52
+ print("Virtual environment found but Python executable missing")
53
+ python_executable = sys.executable
54
+ else:
55
+ print("No virtual environment found, using system Python")
56
+ python_executable = sys.executable
57
+
58
+ print(f" Using Python: {python_executable}")
59
+ print()
60
+
61
+ # Check if streamlit is installed
62
+ try:
63
+ import streamlit
64
+ print(f"Streamlit version: {streamlit.__version__}")
65
+ except ImportError:
66
+ print("Streamlit not found. Please install requirements:")
67
+ print(" pip install -r requirements.txt")
68
+ return 1
69
+
70
+ # Check if torch is available
71
+ try:
72
+ import torch
73
+ device = "GPU" if torch.cuda.is_available() else "CPU"
74
+ print(f"PyTorch available, using: {device}")
75
+ except ImportError:
76
+ print("PyTorch not found. Please install requirements:")
77
+ print(" pip install -r requirements.txt")
78
+ return 1
79
+
80
+ print()
81
+ print("Launching GUI...")
82
+ print(" The application will open in your default web browser")
83
+ print(" Press Ctrl+C to stop the server")
84
+ print()
85
+
86
+ # Launch streamlit
87
+ try:
88
+ subprocess.run([
89
+ python_executable, "-m", "streamlit", "run", "streamlit_gui/app.py",
90
+ "--server.headless", "false",
91
+ "--server.port", "8501",
92
+ "--server.address", "0.0.0.0"
93
+ ])
94
+ except KeyboardInterrupt:
95
+ print("\nShutting down ENCOT GUI...")
96
+ return 0
97
+ except Exception as e:
98
+ print(f"Error launching Streamlit: {e}")
99
+ return 1
100
+
101
+ if __name__ == "__main__":
102
+ exit(main())
streamlit_gui/test_gui.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for ColiFormer Streamlit GUI
4
+
5
+ This script tests the core functionality of the GUI without running the full Streamlit application.
6
+ """
7
+
8
+ import sys
9
+ import os
10
+ import traceback
11
+ from pathlib import Path
12
+
13
+ # Add parent directory to path for imports
14
+ sys.path.append(str(Path(__file__).parent.parent))
15
+
16
+ def test_imports():
17
+ """Test if all required imports work"""
18
+ print("Testing imports...")
19
+
20
+ try:
21
+ import streamlit as st
22
+ print(f" OK: Streamlit: {st.__version__}")
23
+ except ImportError as e:
24
+ print(f" FAIL: Streamlit: {e}")
25
+ return False
26
+
27
+ try:
28
+ import torch
29
+ device = "GPU" if torch.cuda.is_available() else "CPU"
30
+ print(f" OK: PyTorch: {torch.__version__} ({device})")
31
+ except ImportError as e:
32
+ print(f" FAIL: PyTorch: {e}")
33
+ return False
34
+
35
+ try:
36
+ import plotly
37
+ print(f" OK: Plotly: {plotly.__version__}")
38
+ except ImportError as e:
39
+ print(f" FAIL: Plotly: {e}")
40
+ return False
41
+
42
+ try:
43
+ from CodonTransformer.CodonPrediction import predict_dna_sequence
44
+ print(" OK: CodonTransformer.CodonPrediction")
45
+ except ImportError as e:
46
+ print(f" FAIL: CodonTransformer.CodonPrediction: {e}")
47
+ return False
48
+
49
+ try:
50
+ from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI
51
+ print(" OK: CodonTransformer.CodonEvaluation")
52
+ except ImportError as e:
53
+ print(f" FAIL: CodonTransformer.CodonEvaluation: {e}")
54
+ return False
55
+
56
+ return True
57
+
58
+ def test_protein_validation():
59
+ """Test protein sequence validation"""
60
+ print("\nTesting protein sequence validation...")
61
+
62
+ try:
63
+ # Import the validation function
64
+ from app import validate_protein_sequence
65
+
66
+ # Test cases
67
+ test_cases = [
68
+ ("MKTVRQERLK", True, "Valid short sequence"),
69
+ ("", False, "Empty sequence"),
70
+ ("MKTVRQERLKX", False, "Invalid character X"),
71
+ ("MK", False, "Too short"),
72
+ ("M" * 501, False, "Too long"),
73
+ ("mktvrqerlk", True, "Lowercase (should work)"),
74
+ ("MKTVRQERLK*", True, "With stop codon"),
75
+ ("MKTVRQERLK_", True, "With underscore stop"),
76
+ ]
77
+
78
+ for seq, expected_valid, description in test_cases:
79
+ is_valid, message = validate_protein_sequence(seq)
80
+ status = "OK" if is_valid == expected_valid else "FAIL"
81
+ print(f" {status} {description}: {message}")
82
+
83
+ return True
84
+ except Exception as e:
85
+ print(f" FAIL: Error in validation test: {e}")
86
+ traceback.print_exc()
87
+ return False
88
+
89
+ def test_metrics_calculation():
90
+ """Test metrics calculation"""
91
+ print("\nTesting metrics calculation...")
92
+
93
+ try:
94
+ from app import calculate_input_metrics
95
+
96
+ test_protein = "MKTVRQERLK"
97
+ organism = "Escherichia coli general"
98
+
99
+ metrics = calculate_input_metrics(test_protein, organism)
100
+
101
+ # Check if all expected metrics are present
102
+ expected_keys = ['length', 'gc_content', 'baseline_dna', 'cai', 'tai']
103
+ for key in expected_keys:
104
+ if key in metrics:
105
+ print(f" OK: {key}: {metrics[key]}")
106
+ else:
107
+ print(f" FAIL: Missing metric: {key}")
108
+ return False
109
+
110
+ # Validate metric values
111
+ if metrics['length'] == len(test_protein):
112
+ print(" OK: Length calculation correct")
113
+ else:
114
+ print(" FAIL: Length calculation incorrect")
115
+ return False
116
+
117
+ if 0 <= metrics['gc_content'] <= 100:
118
+ print(" OK: GC content in valid range")
119
+ else:
120
+ print(" FAIL: GC content out of range")
121
+ return False
122
+
123
+ return True
124
+ except Exception as e:
125
+ print(f" FAIL: Error in metrics calculation: {e}")
126
+ traceback.print_exc()
127
+ return False
128
+
129
+ def test_visualization_functions():
130
+ """Test visualization functions"""
131
+ print("\nTesting visualization functions...")
132
+
133
+ try:
134
+ from app import create_gc_content_plot, create_metrics_comparison_chart
135
+
136
+ # Test GC content plot
137
+ test_dna = "ATGGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGCGAAAGCGCTGTATCGC"
138
+ fig = create_gc_content_plot(test_dna)
139
+ print(" OK: GC content plot created")
140
+
141
+ # Test metrics comparison chart
142
+ before_metrics = {'gc_content': 50.0, 'cai': 0.5, 'tai': 0.3}
143
+ after_metrics = {'gc_content': 52.0, 'cai': 0.6, 'tai': 0.4}
144
+ fig = create_metrics_comparison_chart(before_metrics, after_metrics)
145
+ print(" OK: Metrics comparison chart created")
146
+
147
+ return True
148
+ except Exception as e:
149
+ print(f" FAIL: Error in visualization test: {e}")
150
+ traceback.print_exc()
151
+ return False
152
+
153
+ def test_codon_evaluation():
154
+ """Test CodonEvaluation functions directly"""
155
+ print("\nTesting CodonEvaluation functions...")
156
+
157
+ try:
158
+ from CodonTransformer.CodonEvaluation import get_GC_content, calculate_tAI, get_ecoli_tai_weights
159
+
160
+ # Test GC content calculation
161
+ test_dna = "ATGGCGAAAGCG"
162
+ gc_content = get_GC_content(test_dna)
163
+ print(f" OK: GC content calculation: {gc_content:.1f}%")
164
+
165
+ # Test tAI calculation
166
+ try:
167
+ tai_weights = get_ecoli_tai_weights()
168
+ tai_value = calculate_tAI(test_dna, tai_weights)
169
+ print(f" OK: tAI calculation: {tai_value:.3f}")
170
+ except Exception as e:
171
+ print(f" NOTE: tAI calculation (may need scipy): {e}")
172
+
173
+ return True
174
+ except Exception as e:
175
+ print(f" FAIL: Error in CodonEvaluation test: {e}")
176
+ traceback.print_exc()
177
+ return False
178
+
179
+ def test_model_loading():
180
+ """Test model loading functionality"""
181
+ print("\nTesting model loading (mock)...")
182
+
183
+ try:
184
+ import torch
185
+ from transformers import AutoTokenizer
186
+ from CodonTransformer.CodonPrediction import load_model
187
+
188
+ # Test tokenizer loading (this is fast)
189
+ print(" Testing tokenizer loading...")
190
+ tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
191
+ print(" OK: Tokenizer loaded successfully")
192
+
193
+ # Test load_model function
194
+ print(" Testing load_model function...")
195
+ from transformers import BigBirdForMaskedLM
196
+ print(" OK: Model class available: BigBirdForMaskedLM")
197
+
198
+ # Check if fine-tuned model exists
199
+ import os
200
+ model_path = "models/alm-enhanced-training/balanced_alm_finetune.ckpt"
201
+ if os.path.exists(model_path):
202
+ print(f" OK: Fine-tuned model found: {model_path}")
203
+ else:
204
+ print(f" NOTE: Fine-tuned model not found at: {model_path}")
205
+
206
+ # Note: We won't actually load the full model here as it's ~2GB
207
+ print(" NOTE: Full model loading skipped in test (too large)")
208
+
209
+ return True
210
+ except Exception as e:
211
+ print(f" FAIL: Error in model loading test: {e}")
212
+ traceback.print_exc()
213
+ return False
214
+
215
+ def test_file_structure():
216
+ """Test if all required files exist"""
217
+ print("\nTesting file structure...")
218
+
219
+ gui_dir = Path(__file__).parent
220
+ parent_dir = gui_dir.parent
221
+
222
+ required_files = [
223
+ "app.py",
224
+ "run_gui.py",
225
+ "requirements.txt",
226
+ "README.md"
227
+ ]
228
+
229
+ all_present = True
230
+ for file_name in required_files:
231
+ file_path = gui_dir / file_name
232
+ if file_path.exists():
233
+ print(f" OK: {file_name}")
234
+ else:
235
+ print(f" FAIL: {file_name} missing")
236
+ all_present = False
237
+
238
+ # Check for model checkpoint
239
+ model_path = parent_dir / "models" / "alm-enhanced-training" / "balanced_alm_finetune.ckpt"
240
+ if model_path.exists():
241
+ print(" OK: Fine-tuned model checkpoint found")
242
+ else:
243
+ print(" NOTE: Fine-tuned model checkpoint not found")
244
+
245
+ return all_present
246
+
247
+ def test_post_processing():
248
+ """Test post-processing functionality"""
249
+ print("\nTesting post-processing features...")
250
+
251
+ try:
252
+ from app import POST_PROCESSING_AVAILABLE, DNACHISEL_AVAILABLE
253
+
254
+ if POST_PROCESSING_AVAILABLE:
255
+ print(" OK: Post-processing module available")
256
+ if DNACHISEL_AVAILABLE:
257
+ print(" OK: DNAChisel available")
258
+ else:
259
+ print(" NOTE: DNAChisel not available")
260
+ else:
261
+ print(" NOTE: Post-processing module not available")
262
+
263
+ return True
264
+ except Exception as e:
265
+ print(f" FAIL: Error in post-processing test: {e}")
266
+ return False
267
+
268
+ def main():
269
+ """Run all tests"""
270
+ print("ENCOT GUI Test Suite")
271
+ print("=" * 50)
272
+
273
+ tests = [
274
+ ("File Structure", test_file_structure),
275
+ ("Imports", test_imports),
276
+ ("Protein Validation", test_protein_validation),
277
+ ("Metrics Calculation", test_metrics_calculation),
278
+ ("Visualization Functions", test_visualization_functions),
279
+ ("CodonEvaluation Functions", test_codon_evaluation),
280
+ ("Model Loading", test_model_loading),
281
+ ("Post-Processing", test_post_processing),
282
+ ]
283
+
284
+ passed = 0
285
+ total = len(tests)
286
+
287
+ for test_name, test_func in tests:
288
+ try:
289
+ result = test_func()
290
+ if result:
291
+ passed += 1
292
+ print(f"OK: {test_name}: PASSED")
293
+ else:
294
+ print(f"FAIL: {test_name}: FAILED")
295
+ except Exception as e:
296
+ print(f"FAIL: {test_name}: ERROR - {e}")
297
+
298
+ print("\n" + "=" * 50)
299
+ print(f"Test Results: {passed}/{total} tests passed")
300
+
301
+ if passed == total:
302
+ print("All tests passed. The GUI should work correctly.")
303
+ print("\nTo run the GUI:")
304
+ print(" python run_gui.py")
305
+ print(" or")
306
+ print(" cd streamlit_gui && streamlit run app.py --server.address=0.0.0.0")
307
+ else:
308
+ print("Some tests failed. Please check the issues above.")
309
+
310
+ print("\nNotes:")
311
+ print(" • Fine-tuned model integration")
312
+ print(" • Enhanced constrained beam search")
313
+ print(" • Post-processing with DNAChisel")
314
+ print(" • Advanced sequence analysis")
315
+ print(" • Improved parameter controls")
316
+
317
+ return passed == total
318
+
319
+ if __name__ == "__main__":
320
+ success = main()
321
+ sys.exit(0 if success else 1)