Spaces:
Sleeping
Sleeping
Gilmullin Almaz
commited on
Commit
·
72a3513
1
Parent(s):
3e5f8cc
Refactor code structure for improved readability and maintainability
Browse files- .gitattributes +0 -35
- Dockerfile +0 -21
- README.md +11 -14
- app.py +1349 -0
- pre-requirements.txt +7 -0
- requirements.txt +6 -3
- src/streamlit_app.py +0 -40
- synplan/__init__.py +3 -0
- synplan/chem/__init__.py +3 -0
- synplan/chem/data/__init__.py +0 -0
- synplan/chem/data/filtering.py +962 -0
- synplan/chem/data/standardizing.py +1187 -0
- synplan/chem/precursor.py +100 -0
- synplan/chem/reaction.py +125 -0
- synplan/chem/reaction_routes/__init__.py +0 -0
- synplan/chem/reaction_routes/clustering.py +859 -0
- synplan/chem/reaction_routes/io.py +286 -0
- synplan/chem/reaction_routes/leaving_groups.py +131 -0
- synplan/chem/reaction_routes/route_cgr.py +570 -0
- synplan/chem/reaction_routes/visualisation.py +903 -0
- synplan/chem/reaction_rules/__init__.py +0 -0
- synplan/chem/reaction_rules/extraction.py +744 -0
- synplan/chem/reaction_rules/manual/__init__.py +6 -0
- synplan/chem/reaction_rules/manual/decompositions.py +413 -0
- synplan/chem/reaction_rules/manual/transformations.py +532 -0
- synplan/chem/utils.py +225 -0
- synplan/interfaces/__init__.py +0 -0
- synplan/interfaces/cli.py +506 -0
- synplan/interfaces/gui.py +1323 -0
- synplan/mcts/__init__.py +8 -0
- synplan/mcts/evaluation.py +45 -0
- synplan/mcts/expansion.py +96 -0
- synplan/mcts/node.py +47 -0
- synplan/mcts/search.py +199 -0
- synplan/mcts/tree.py +635 -0
- synplan/ml/__init__.py +0 -0
- synplan/ml/networks/__init__.py +0 -0
- synplan/ml/networks/modules.py +234 -0
- synplan/ml/networks/policy.py +137 -0
- synplan/ml/networks/value.py +67 -0
- synplan/ml/training/__init__.py +11 -0
- synplan/ml/training/preprocessing.py +516 -0
- synplan/ml/training/reinforcement.py +379 -0
- synplan/ml/training/supervised.py +153 -0
- synplan/utils/__init__.py +4 -0
- synplan/utils/config.py +543 -0
- synplan/utils/files.py +226 -0
- synplan/utils/loading.py +151 -0
- synplan/utils/logging.py +179 -0
- synplan/utils/visualisation.py +1365 -0
.gitattributes
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Dockerfile
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
FROM python:3.9-slim
|
| 2 |
-
|
| 3 |
-
WORKDIR /app
|
| 4 |
-
|
| 5 |
-
RUN apt-get update && apt-get install -y \
|
| 6 |
-
build-essential \
|
| 7 |
-
curl \
|
| 8 |
-
software-properties-common \
|
| 9 |
-
git \
|
| 10 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
-
|
| 12 |
-
COPY requirements.txt ./
|
| 13 |
-
COPY src/ ./src/
|
| 14 |
-
|
| 15 |
-
RUN pip3 install -r requirements.txt
|
| 16 |
-
|
| 17 |
-
EXPOSE 8501
|
| 18 |
-
|
| 19 |
-
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
-
|
| 21 |
-
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,20 +1,17 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
- streamlit
|
| 10 |
pinned: false
|
| 11 |
-
short_description: Developers mode for synplanner
|
| 12 |
license: mit
|
|
|
|
| 13 |
---
|
| 14 |
|
| 15 |
-
#
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 20 |
-
forums](https://discuss.streamlit.io).
|
|
|
|
| 1 |
---
|
| 2 |
+
title: SynPlanner GUI
|
| 3 |
+
emoji: 🧪
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.37.0
|
| 8 |
+
app_file: app.py
|
|
|
|
| 9 |
pinned: false
|
|
|
|
| 10 |
license: mit
|
| 11 |
+
python_version: 3.11.9
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# SynPlanner Graphical User Interface (GUI)
|
| 15 |
+
Try the GUI to find reaction paths...
|
| 16 |
|
| 17 |
+
**documentation to be done**
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,1349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import pickle
|
| 3 |
+
import re
|
| 4 |
+
import uuid
|
| 5 |
+
import io
|
| 6 |
+
import zipfile
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import streamlit as st
|
| 10 |
+
from CGRtools.files import SMILESRead
|
| 11 |
+
from streamlit_ketcher import st_ketcher
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from huggingface_hub.utils import disable_progress_bars
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from synplan.mcts.expansion import PolicyNetworkFunction
|
| 17 |
+
from synplan.mcts.search import extract_tree_stats
|
| 18 |
+
from synplan.mcts.tree import Tree
|
| 19 |
+
from synplan.chem.utils import mol_from_smiles
|
| 20 |
+
from synplan.chem.reaction_routes.route_cgr import *
|
| 21 |
+
from synplan.chem.reaction_routes.clustering import *
|
| 22 |
+
|
| 23 |
+
from synplan.utils.visualisation import (
|
| 24 |
+
routes_clustering_report,
|
| 25 |
+
routes_subclustering_report,
|
| 26 |
+
generate_results_html,
|
| 27 |
+
html_top_routes_cluster,
|
| 28 |
+
get_route_svg,
|
| 29 |
+
get_route_svg_from_json,
|
| 30 |
+
get_route_svg_mod
|
| 31 |
+
)
|
| 32 |
+
from synplan.utils.config import TreeConfig, PolicyNetworkConfig
|
| 33 |
+
from synplan.utils.loading import load_reaction_rules, load_building_blocks
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
import psutil
|
| 37 |
+
import gc
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
disable_progress_bars("huggingface_hub")
|
| 41 |
+
|
| 42 |
+
smiles_parser = SMILESRead.create_parser(ignore=True)
|
| 43 |
+
DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# --- Helper Functions ---
|
| 47 |
+
def download_button(
    object_to_download, download_filename, button_text, pickle_it=False
):
    """
    Generate an HTML download button (inline CSS + anchor tag) for the given object.

    Params:
    ------
    object_to_download: The object to be downloaded. May be a str, bytes,
        a pandas DataFrame (serialized as CSV), or — with pickle_it=True —
        any picklable object.
    download_filename (str): filename and extension of file. e.g. mydata.csv,
        some_txt_output.txt
    button_text (str): Text to display on download button (e.g. 'click here to download file')
    pickle_it (bool): If True, pickle file.

    Returns:
    -------
    (str): the anchor tag to download object_to_download, or None if
        pickling was requested and failed.

    Examples:
    --------
    download_link(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_link(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            # Surface the pickling problem in the UI instead of raising.
            st.write(e)
            return None

    else:
        if isinstance(object_to_download, bytes):
            pass  # already raw bytes, ready for base64 encoding

        elif isinstance(object_to_download, pd.DataFrame):
            object_to_download = object_to_download.to_csv(index=False).encode("utf-8")

    # Strings need to be encoded first; bytes are base64-encoded directly.
    try:
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        b64 = base64.b64encode(object_to_download).decode()

    # Build a unique, letters-only element id so the CSS below targets only
    # this button. BUGFIX: use a raw string for the regex — "\d" inside a
    # normal string literal is an invalid escape sequence (SyntaxWarning in
    # modern Python); r"\d+" is the correct form.
    button_uuid = str(uuid.uuid4()).replace("-", "")
    button_id = re.sub(r"\d+", "", button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = (
        custom_css
        + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    )
    return dl_link
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@st.cache_resource
def load_planning_resources_cached():
    """Download (and cache) the three resource files needed for planning.

    Returns a tuple of local file paths in this order:
    (building blocks .smi, ranking policy network .ckpt, reaction rules .pickle).
    Decorated with st.cache_resource so the Hub is contacted at most once
    per Streamlit session.
    """
    repo = "Laboratoire-De-Chemoinformatique/SynPlanner"
    # (filename, subfolder) for each artefact, in the order they are returned.
    artefacts = (
        ("building_blocks_em_sa_ln.smi", "building_blocks"),
        ("ranking_policy_network.ckpt", "uspto/weights"),
        ("uspto_reaction_rules.pickle", "uspto"),
    )
    resolved_paths = [
        hf_hub_download(
            repo_id=repo,
            filename=artefact_name,
            subfolder=artefact_subfolder,
            local_dir=".",
        )
        for artefact_name, artefact_subfolder in artefacts
    ]
    return tuple(resolved_paths)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# --- GUI Sections ---
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def initialize_app():
    """1. Initialization: Setting up the main window, layout, and initial widgets."""
    st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

    # One table of default values for every session-state key the app uses.
    # Seeding from a dict is semantically identical to the per-key
    # `if key not in st.session_state:` checks: a key is only written when
    # it is absent, so existing state survives reruns.
    session_defaults = {
        # Planning state
        "planning_done": False,
        "tree": None,
        "res": None,
        "target_smiles": "",  # initial value, might be overwritten by ketcher
        # Clustering state
        "clustering_done": False,
        "clusters": None,
        "reactions_dict": None,
        "num_clusters_setting": 10,  # store the setting used
        "route_cgrs_dict": None,
        "sb_cgrs_dict": None,
        "route_json": None,
        # Subclustering state
        "subclustering_done": False,
        "subclusters": None,  # renamed from 'sub' for clarity
        # Download state (less critical now with direct download links)
        "clusters_downloaded": False,
        # Ketcher editor persistence
        "ketcher": DEFAULT_MOL,
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    intro_text = """
    This is a demo of the graphical user interface of
    [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
    SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

    More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
    """
    st.title("`SynPlanner GUI`")
    st.write(intro_text)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def setup_sidebar():
    """2. Sidebar: Handling the widgets and logic within the sidebar area."""
    # st.sidebar.image("img/logo.png") # Assuming img/logo.png is available

    # Each sidebar section is a (title, markdown body) pair; rendering them
    # from a tuple keeps the layout declarative while producing exactly the
    # same sequence of title/markdown calls.
    sidebar_sections = (
        ("Docs", "https://synplanner.readthedocs.io/en/latest/"),
        (
            "Tutorials",
            "https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials",
        ),
        (
            "Paper",
            "https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796",
        ),
        (
            "Issues",
            "[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)",
        ),
    )
    for section_title, section_body in sidebar_sections:
        st.sidebar.title(section_title)
        st.sidebar.markdown(section_body)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def handle_molecule_input():
    """3. Molecule Input: Managing the input area for molecule data with two-way synchronization.

    Keeps the SMILES text box and the ketcher drawing widget in sync via
    st.session_state.shared_smiles, and returns the SMILES string that the
    planning step should use.
    """
    st.header("Molecule input")
    st.markdown(
        """
    You can provide a molecular structure by either providing:
    * SMILES string + Enter
    * Draw it + Apply
    """
    )

    # Single source of truth for the current structure; seeded from the
    # persisted ketcher value (or the default molecule) on first run.
    if "shared_smiles" not in st.session_state:
        st.session_state.shared_smiles = st.session_state.get("ketcher", DEFAULT_MOL)

    # Bumping this counter changes the ketcher widget key, forcing a fresh
    # render when the text input changes the structure.
    if "ketcher_render_count" not in st.session_state:
        st.session_state.ketcher_render_count = 0

    def text_input_changed_callback():
        # Fired when the user edits the SMILES text box; pushes the new value
        # into shared state and forces the ketcher widget to re-render.
        new_text_value = (
            st.session_state.smiles_text_input_key_for_sync
        )  # Key of the text_input
        if new_text_value != st.session_state.shared_smiles:
            st.session_state.shared_smiles = new_text_value
            st.session_state.ketcher = new_text_value
            st.session_state.ketcher_render_count += 1

    # SMILES Text Input
    st.text_input(
        "SMILES:",
        value=st.session_state.shared_smiles,
        key="smiles_text_input_key_for_sync",  # Unique key for this widget
        on_change=text_input_changed_callback,
        help="Enter SMILES string and press Enter. The drawing will update, and vice-versa.",
    )

    # The render-count suffix makes the key unique per forced re-render.
    ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
    smile_code_output_from_ketcher = st_ketcher(
        st.session_state.shared_smiles, key=ketcher_key
    )

    # The editor returned a different structure (user pressed Apply):
    # adopt it and rerun so the text input picks up the new value.
    if smile_code_output_from_ketcher != st.session_state.shared_smiles:
        st.session_state.shared_smiles = smile_code_output_from_ketcher
        st.session_state.ketcher = smile_code_output_from_ketcher
        st.rerun()

    current_smiles_for_planning = st.session_state.shared_smiles

    # Warn when displayed results belong to a previously planned molecule.
    last_planned_smiles = st.session_state.get("target_smiles")
    if (
        last_planned_smiles
        and current_smiles_for_planning != last_planned_smiles
        and st.session_state.get("planning_done", False)
    ):
        st.warning(
            "Molecule structure has changed since the last successful planning run. "
            "Results shown below (if any) are for the previous molecule. "
            "Please re-run planning for the current structure."
        )

    # Ensure st.session_state.ketcher is consistent for other parts of the app
    if st.session_state.get("ketcher") != current_smiles_for_planning:
        st.session_state.ketcher = current_smiles_for_planning

    return current_smiles_for_planning
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def setup_planning_options():
    """4. Planning: Encapsulating the logic related to the "planning" functionality.

    Renders the planning-option widgets and, when the user presses the start
    button, loads the cached resources, builds the MCTS tree, runs the
    iterations with a progress bar, and stores tree/stats in session state.
    """
    st.header("Launch calculation")
    st.markdown(
        """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
    )

    st.markdown(
        f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
    )

    st.subheader("Planning options")
    st.markdown(
        """
    The description of each option can be found in the
    [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
    """
    )

    # Two-column layout: search-behaviour options on the left,
    # iteration/size limits on the right.
    col_options_1, col_options_2 = st.columns(2, gap="medium")
    with col_options_1:
        search_strategy_input = st.selectbox(
            label="Search strategy",
            options=(
                "Expansion first",
                "Evaluation first",
            ),
            index=0,
            key="search_strategy_input",
        )
        ucb_type = st.selectbox(
            label="UCB type",
            options=("uct", "puct", "value"),
            index=0,
            key="ucb_type_input",
        )
        c_ucb = st.number_input(
            "C coefficient of UCB",
            value=0.1,
            placeholder="Type a number...",
            key="c_ucb_input",
        )

    with col_options_2:
        max_iterations = st.slider(
            "Total number of MCTS iterations",
            min_value=50,
            max_value=3000,
            value=1000,
            key="max_iterations_slider",
        )
        max_depth = st.slider(
            "Maximal number of reaction steps",
            min_value=3,
            max_value=9,
            value=6,
            key="max_depth_slider",
        )
        min_mol_size = st.slider(
            "Minimum size of a molecule to be precursor",
            min_value=0,
            max_value=7,
            value=0,
            key="min_mol_size_slider",
            help="Number of non-hydrogen atoms in molecule",
        )

    # Map the human-readable selectbox labels onto the identifiers
    # expected by TreeConfig.
    search_strategy_translator = {
        "Expansion first": "expansion_first",
        "Evaluation first": "evaluation_first",
    }
    search_strategy = search_strategy_translator[search_strategy_input]

    planning_params = {
        "search_strategy": search_strategy,
        "ucb_type": ucb_type,
        "c_ucb": c_ucb,
        "max_iterations": max_iterations,
        "max_depth": max_depth,
        "min_mol_size": min_mol_size,
    }

    if st.button("Start retrosynthetic planning", key="submit_planning_button"):
        # Reset downstream states if replanning
        st.session_state.planning_done = False
        st.session_state.clustering_done = False
        st.session_state.subclustering_done = False
        st.session_state.tree = None
        st.session_state.res = None
        st.session_state.clusters = None
        st.session_state.reactions_dict = None
        st.session_state.subclusters = None
        st.session_state.route_cgrs_dict = None
        st.session_state.sb_cgrs_dict = None
        st.session_state.route_json = None
        active_smile_code = st.session_state.get(
            "ketcher", DEFAULT_MOL
        )  # Get current SMILES
        st.session_state.target_smiles = (
            active_smile_code  # Store the SMILES used for this run
        )

        try:
            target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
            if target_molecule is None:
                raise ValueError(f"Could not parse the input SMILES: {active_smile_code}")

            (
                building_blocks_path,
                ranking_policy_weights_path,
                reaction_rules_path,
            ) = load_planning_resources_cached()
            with st.spinner("Running retrosynthetic planning..."):
                with st.status("Loading resources...", expanded=False) as status:
                    st.write("Loading building blocks...")
                    building_blocks = load_building_blocks(
                        building_blocks_path, standardize=False
                    )
                    st.write("Loading reaction rules...")
                    reaction_rules = load_reaction_rules(reaction_rules_path)
                    st.write("Loading policy network...")
                    policy_config = PolicyNetworkConfig(
                        weights_path=ranking_policy_weights_path
                    )
                    policy_function = PolicyNetworkFunction(
                        policy_config=policy_config
                    )
                    status.update(label="Resources loaded!", state="complete")

                tree_config = TreeConfig(
                    search_strategy=planning_params["search_strategy"],
                    evaluation_type="rollout",
                    max_iterations=planning_params["max_iterations"],
                    max_depth=planning_params["max_depth"],
                    min_mol_size=planning_params["min_mol_size"],
                    init_node_value=0.5,
                    ucb_type=planning_params["ucb_type"],
                    c_ucb=planning_params["c_ucb"],
                    silent=True,
                )

                tree = Tree(
                    target=target_molecule,
                    config=tree_config,
                    reaction_rules=reaction_rules,
                    building_blocks=building_blocks,
                    expansion_function=policy_function,
                    evaluation_function=None,
                )

                mcts_progress_text = "Running MCTS iterations..."
                mcts_bar = st.progress(0, text=mcts_progress_text)
                # Iterating the Tree object performs one MCTS iteration per
                # step; the bar is clamped to 1.0 in case the tree yields
                # more steps than max_iterations.
                for step, (solved, route_id) in enumerate(tree):
                    progress_value = min(
                        1.0, (step + 1) / planning_params["max_iterations"]
                    )
                    mcts_bar.progress(
                        progress_value,
                        text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
                    )

                res = extract_tree_stats(tree, target_molecule)

                st.session_state["tree"] = tree
                st.session_state["res"] = res
                st.session_state.planning_done = True
                st.rerun()

        except (ValueError, KeyError, FileNotFoundError, TypeError) as e:
            st.error(f"An error occurred during planning: {e}")
            st.session_state.planning_done = False
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def display_planning_results():
    """5. Planning Results Display: Handling the presentation of results.

    Reads the planning outcome (``res``) and the MCTS tree from
    ``st.session_state``. When the target was solved, shows the number of
    unique winning routes and renders up to three route SVGs; otherwise
    shows a sample of unfinished pathways taken from the tree nodes.
    """
    if st.session_state.get("planning_done", False):
        res = st.session_state.res
        tree = st.session_state.tree

        # Both objects are produced by the planning step; without them
        # nothing can be displayed, so reset the flag and bail out.
        if res is None or tree is None:
            st.error(
                "Planning results are missing from session state. Please re-run planning."
            )
            st.session_state.planning_done = False  # Reset state
            return  # Exit this function if no results

        if res.get("solved", False):  # Use .get for safety
            st.header("Planning Results")
            # De-duplicate and order winning nodes; tolerate trees without
            # a populated winning_nodes attribute.
            winning_nodes = (
                sorted(set(tree.winning_nodes))
                if hasattr(tree, "winning_nodes") and tree.winning_nodes
                else []
            )
            st.subheader(f"Number of unique routes found: {len(winning_nodes)}")

            st.subheader("Examples of found retrosynthetic routes")
            image_counter = 0
            visualised_route_ids = set()

            if not winning_nodes:
                st.warning(
                    "Planning solved, but no winning nodes found in the tree object."
                )
            else:
                # Render at most three distinct routes as SVG images.
                for n, route_id in enumerate(winning_nodes):
                    if image_counter >= 3:
                        break
                    if route_id not in visualised_route_ids:
                        try:
                            visualised_route_ids.add(route_id)
                            num_steps = len(tree.synthesis_route(route_id))
                            route_score = round(tree.route_score(route_id), 3)
                            svg = get_route_svg(tree, route_id)
                            if svg:
                                st.image(
                                    svg,
                                    caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
                                )
                                image_counter += 1
                            else:
                                st.warning(
                                    f"Could not generate SVG for route {route_id}."
                                )
                        except Exception as e:
                            st.error(f"Error displaying route {route_id}: {e}")
        else:  # Not solved
            st.header("Planning Results")
            st.warning(
                "No reaction path found for the target molecule with the current settings."
            )
            st.write(
                "Find below the unfinished pathways"
            )
            image_counter = 0
            # Sample every 50th tree node (skipping the root at index 0) up
            # to max_iterations, and display up to 20 partial pathways.
            for route_id in list(tree.nodes.keys())[1:tree.config.max_iterations:50]:
                svg = get_route_svg_mod(tree, route_id)
                # display(SVG(get_route_svg_mod(tree, route_id)))
                if svg:
                    st.image(
                        svg,
                        caption=f"Route {route_id};",
                    )
                    image_counter += 1
                    reactions = tree.synthesis_route(route_id)
                    for reaction in reactions:
                        st.write(reaction)
                else:
                    st.warning(
                        f"Could not generate SVG for route {route_id}."
                    )
                if image_counter >= 20:
                    break

            # Kept from an earlier revision of the unsolved-branch UI:
            # st.warning(
            #     "No reaction path found for the target molecule with the current settings."
            # )
            # st.write(
            #     "Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
            # )
            # stat_col, _ = st.columns(2)
            # with stat_col:
            #     st.subheader("Run Statistics (No Solution)")
            #     try:
            #         if (
            #             "target_smiles" not in res
            #             and "target_smiles" in st.session_state
            #         ):
            #             res["target_smiles"] = st.session_state.target_smiles
            #         cols_to_show = [
            #             col
            #             for col in [
            #                 "target_smiles",
            #                 "num_nodes",
            #                 "num_iter",
            #                 "search_time",
            #             ]
            #             if col in res
            #         ]
            #         if cols_to_show:
            #             df = pd.DataFrame(res, index=[0])[cols_to_show]
            #             st.dataframe(df)
            #         else:
            #             st.write("No statistics to display for the unsuccessful run.")
            #     except Exception as e:
            #         st.error(f"Error displaying statistics: {e}")
            #         st.write(res)
def download_planning_results():
    """6. Planning Results Download: Providing functionality to download.

    Offers the solved planning run as an HTML report and the run statistics
    as a CSV, both rendered as in-page download links. Does nothing unless
    planning finished with a solved target.
    """
    has_solved_run = (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    )
    if not has_solved_run:
        return

    run_stats = st.session_state.res
    search_tree = st.session_state.tree

    try:
        # Full HTML report of the found routes.
        report_html = generate_results_html(search_tree, html_path=None, extended=True)
        html_link = download_button(
            report_html,
            f"results_synplanner_{st.session_state.target_smiles}.html",
            "Download results (HTML)",
        )
        if html_link:
            st.markdown(html_link, unsafe_allow_html=True)

        # Statistics CSV is best-effort: a failure here must not remove
        # the already-rendered HTML link.
        try:
            stats_frame = pd.DataFrame(run_stats, index=[0])
            csv_link = download_button(
                stats_frame,
                f"stats_synplanner_{st.session_state.target_smiles}.csv",
                "Download statistics (CSV)",
            )
            if csv_link:
                st.markdown(csv_link, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Could not prepare statistics CSV for download: {e}")

    except Exception as e:
        st.error(f"Error generating download links for planning results: {e}")
def setup_clustering():
    """7. Clustering: Encapsulating the logic related to the "clustering" functionality.

    Shown only after a solved planning run. On button press it rebuilds
    the clustering state from scratch: composes RouteCGRs and SB-CGRs from
    the MCTS tree, clusters the routes, extracts reactions, stores all
    intermediates in ``st.session_state`` and triggers a rerun.
    """
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        st.divider()
        st.header("Clustering the retrosynthetic routes")

        if st.button("Run Clustering", key="submit_clustering_button"):
            # st.session_state.num_clusters_setting = num_clusters_input
            # Invalidate all derived state before recomputing, so a failure
            # mid-way cannot leave a stale mixture of old and new results.
            st.session_state.clustering_done = False
            st.session_state.subclustering_done = False
            st.session_state.clusters = None
            st.session_state.reactions_dict = None
            st.session_state.subclusters = None
            st.session_state.route_cgrs_dict = None
            st.session_state.sb_cgrs_dict = None
            st.session_state.route_json = None

            with st.spinner("Performing clustering..."):
                try:
                    current_tree = st.session_state.tree
                    if not current_tree:
                        st.error("Tree object not found. Please re-run planning.")
                        return

                    st.write("Calculating RoutesCGRs...")
                    route_cgrs_dict = compose_all_route_cgrs(current_tree)
                    st.write("Processing SB-CGRs...")
                    sb_cgrs_dict = compose_all_sb_cgrs(route_cgrs_dict)

                    results = cluster_routes(
                        sb_cgrs_dict, use_strat=False
                    )  # num_clusters was removed from args
                    # Order clusters by their numeric key for stable display.
                    results = dict(sorted(results.items(), key=lambda x: float(x[0])))

                    st.session_state.clusters = results
                    st.session_state.route_cgrs_dict = route_cgrs_dict
                    st.session_state.sb_cgrs_dict = sb_cgrs_dict
                    st.write("Extracting reactions...")
                    st.session_state.reactions_dict = extract_reactions(current_tree)
                    st.session_state.route_json = make_json(st.session_state.reactions_dict)

                    if (
                        st.session_state.clusters is not None
                        and st.session_state.reactions_dict is not None
                    ):  # Check for None explicitly
                        st.session_state.clustering_done = True
                        st.success(
                            f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
                        )
                    else:
                        st.error("Clustering failed or returned empty results.")
                        st.session_state.clustering_done = False

                    del results  # route_cgrs_dict, sb_cgrs_dict are stored
                    gc.collect()
                    # Rerun so downstream sections pick up clustering_done.
                    st.rerun()
                except Exception as e:
                    st.error(f"An error occurred during clustering: {e}")
                    st.session_state.clustering_done = False
def _render_cluster_entry(tree, cluster_num, group_data, in_expander=False):
    """Render one cluster's summary: its SB-CGR and the SVG of its best route.

    :param tree: the MCTS tree holding the found routes (used for route
        length and score).
    :param cluster_num: key of the cluster in the clusters dict (captions).
    :param group_data: dict with at least ``route_ids``; optionally
        ``group_size`` and ``sb_cgr``.
    :param in_expander: True when rendering inside the "... more clusters"
        expander — only changes the wording of the missing-data warning.
    """
    if (
        not group_data
        or "route_ids" not in group_data
        or not group_data["route_ids"]
    ):
        if in_expander:
            st.warning(
                f"Cluster {cluster_num} in expansion has no data or route_ids."
            )
        else:
            st.warning(f"Cluster {cluster_num} has no data or route_ids.")
        return

    st.markdown(
        f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
    )
    # The first route id is taken as the representative ("best") route.
    route_id = group_data["route_ids"][0]
    try:
        num_steps = len(tree.synthesis_route(route_id))
        route_score = round(tree.route_score(route_id), 3)
        svg = get_route_svg_from_json(st.session_state.route_json, route_id)
        sb_cgr = group_data.get("sb_cgr")  # Safely get sb_cgr
        sb_cgr_svg = None
        if sb_cgr:
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

        if svg and sb_cgr_svg:
            col1, col2 = st.columns([0.2, 0.8])
            with col1:
                st.image(sb_cgr_svg, caption="SB-CGR")
            with col2:
                st.image(
                    svg,
                    caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
                )
        elif svg:  # Only route SVG available
            st.image(
                svg,
                caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
            )
            st.warning(
                f"SB-CGR could not be displayed for cluster {cluster_num}."
            )
        else:
            st.warning(
                f"Could not generate SVG for route {route_id} or its SB-CGR."
            )
    except Exception as e:
        st.error(
            f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
        )


def display_clustering_results():
    """8. Clustering Results Display: Handling the presentation of results.

    Shows the best route of each found cluster: the first
    MAX_DISPLAY_CLUSTERS_DATA clusters inline, the rest collapsed inside an
    expander. Requires ``clusters`` and ``tree`` in ``st.session_state``;
    resets ``clustering_done`` when they are missing.

    The per-cluster rendering is shared with the expander via
    :func:`_render_cluster_entry` (previously this body was duplicated).
    """
    if st.session_state.get("clustering_done", False):
        clusters = st.session_state.clusters
        # reactions_dict = st.session_state.reactions_dict # Needed for download, not directly for display here
        tree = st.session_state.tree
        MAX_DISPLAY_CLUSTERS_DATA = 10

        if (
            clusters is None or tree is None
        ):  # reactions_dict removed as not critical for display part
            st.error(
                "Clustering results (clusters or tree) are missing. Please re-run clustering."
            )
            st.session_state.clustering_done = False
            return

        st.subheader(f"Best routes from {len(clusters)} Found Clusters")
        clusters_items = list(clusters.items())
        first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
        remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]

        for cluster_num, group_data in first_items:
            _render_cluster_entry(tree, cluster_num, group_data)

        if remaining_items:
            with st.expander(f"... and {len(remaining_items)} more clusters"):
                for cluster_num, group_data in remaining_items:
                    _render_cluster_entry(
                        tree, cluster_num, group_data, in_expander=True
                    )
def download_clustering_results():
    """10. Clustering Results Download: Providing functionality to download.

    Renders one HTML-report download button per cluster (first ten inline,
    the rest inside an expander) plus a single ZIP containing every
    cluster report. Requires the MCTS tree and the clusters dict in
    ``st.session_state``; the SB-CGRs dict is passed through to the report
    generator.
    """
    if st.session_state.get("clustering_done", False):
        tree_for_html = st.session_state.get("tree")
        clusters_for_html = st.session_state.get("clusters")
        sb_cgrs_for_html = st.session_state.get(
            "sb_cgrs_dict"
        )  # This was used instead of reactions_dict in the original for report

        if not tree_for_html:
            st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
            return
        if not clusters_for_html:
            st.warning("Cluster data not found. Cannot generate cluster reports.")
            return
        # sb_cgrs_for_html is optional for routes_clustering_report if not essential

        st.subheader("Cluster Reports")  # Changed subheader in original
        st.write("Generate downloadable HTML reports for each cluster:")

        MAX_DOWNLOAD_LINKS_DISPLAYED = 10
        num_clusters_total = len(clusters_for_html)
        clusters_items = list(clusters_for_html.items())

        # Inline download buttons for the first few clusters.
        for i, (cluster_idx, group_data) in enumerate(
            clusters_items
        ):  # group_data might not be needed here if report uses cluster_idx
            if i >= MAX_DOWNLOAD_LINKS_DISPLAYED:
                break
            try:
                html_content = routes_clustering_report(
                    tree_for_html,
                    clusters_for_html,  # Pass the whole dict
                    str(cluster_idx),  # Pass the key of the cluster
                    sb_cgrs_for_html,  # Pass the sb_cgrs dict
                    aam=False,
                )
                st.download_button(
                    label=f"Download report for cluster {cluster_idx}",
                    data=html_content,
                    file_name=f"cluster_{cluster_idx}_{st.session_state.target_smiles}.html",
                    mime="text/html",
                    key=f"download_cluster_{cluster_idx}",
                )
            except Exception as e:
                st.error(f"Error generating report for cluster {cluster_idx}: {e}")

        # Remaining clusters go behind an expander to keep the page short.
        if num_clusters_total > MAX_DOWNLOAD_LINKS_DISPLAYED:
            remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
            remaining_count = len(remaining_items)
            expander_label = f"Show remaining {remaining_count} cluster reports"
            with st.expander(expander_label):
                for (
                    group_index,
                    _,
                ) in remaining_items:  # group_data not needed here either
                    try:
                        html_content = routes_clustering_report(
                            tree_for_html,
                            clusters_for_html,
                            str(group_index),
                            sb_cgrs_for_html,
                            aam=False,
                        )
                        st.download_button(
                            label=f"Download report for cluster {group_index}",
                            data=html_content,
                            file_name=f"cluster_{group_index}_{st.session_state.target_smiles}.html",
                            mime="text/html",
                            key=f"download_cluster_expanded_{group_index}",
                        )
                    except Exception as e:
                        st.error(
                            f"Error generating report for cluster {group_index} (expanded): {e}"
                        )

        # Bundle every cluster report (not just the displayed ones) into
        # an in-memory ZIP archive.
        try:
            buffer = io.BytesIO()
            with zipfile.ZipFile(
                buffer, mode="w", compression=zipfile.ZIP_DEFLATED
            ) as zf:
                for idx, _ in clusters_items:  # group_data not needed
                    html_content_zip = routes_clustering_report(
                        tree_for_html,
                        clusters_for_html,
                        str(idx),
                        sb_cgrs_for_html,
                        aam=False,
                    )
                    filename = f"cluster_{idx}_{st.session_state.target_smiles}.html"
                    zf.writestr(filename, html_content_zip)
            # Rewind so download_button streams from the start of the archive.
            buffer.seek(0)

            st.download_button(
                label="📦 Download all cluster reports as ZIP",
                data=buffer,
                file_name=f"all_cluster_reports_{st.session_state.target_smiles}.zip",
                mime="application/zip",
                key="download_all_clusters_zip",
            )
        except Exception as e:
            st.error(f"Error generating ZIP file for cluster reports: {e}")
def setup_subclustering():
    """11. Subclustering: Encapsulating the logic related to the "subclustering" functionality.

    Available once clustering finished. On button press it runs
    ``subcluster_all_clusters`` over the stored clusters, SB-CGRs and
    RouteCGRs dictionaries, stores the result in
    ``st.session_state.subclusters`` and triggers a rerun.
    """
    if st.session_state.get(
        "clustering_done", False
    ):  # Subclustering depends on clustering being done
        st.divider()
        st.header("Sub-Clustering within a selected Cluster")

        if st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
            # Invalidate previous subclustering results before recomputing.
            st.session_state.subclustering_done = False
            st.session_state.subclusters = None
            with st.spinner("Performing subclustering analysis..."):
                try:
                    clusters_for_sub = st.session_state.get("clusters")
                    sb_cgrs_dict_for_sub = st.session_state.get("sb_cgrs_dict")
                    route_cgrs_dict_for_sub = st.session_state.get("route_cgrs_dict")

                    if (
                        clusters_for_sub
                        and sb_cgrs_dict_for_sub
                        and route_cgrs_dict_for_sub
                    ):  # Ensure all are present
                        all_subgroups = subcluster_all_clusters(
                            clusters_for_sub,
                            sb_cgrs_dict_for_sub,
                            route_cgrs_dict_for_sub,
                        )
                        st.session_state.subclusters = all_subgroups
                        st.session_state.subclustering_done = True
                        st.success("Subclustering analysis complete.")
                        gc.collect()
                        # Rerun so downstream sections pick up subclustering_done.
                        st.rerun()
                    else:
                        # Report exactly which prerequisite is missing.
                        missing = []
                        if not clusters_for_sub:
                            missing.append("clusters")
                        if not sb_cgrs_dict_for_sub:
                            missing.append("SB-CGRs dictionary")
                        if not route_cgrs_dict_for_sub:
                            missing.append("RouteCGRs dictionary")
                        st.error(
                            f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
                        )
                        st.session_state.subclustering_done = False

                except Exception as e:
                    st.error(f"An error occurred during subclustering: {e}")
                    st.session_state.subclustering_done = False
def display_subclustering_results():
    """12. Subclustering Results Display: Handling the presentation of results.

    Two-column layout: the left column holds cluster/subcluster selectboxes
    (their choices persist via the ``subcluster_num_select_key`` and
    ``subcluster_index_select_key`` widget keys, which the download section
    reads later) plus the parent cluster's SB-CGR; the right column shows
    the selected subcluster's synthon pseudo-reaction and its routes.
    """
    if st.session_state.get("subclustering_done", False):
        sub = st.session_state.get("subclusters")
        tree = st.session_state.get("tree")
        # clusters_for_sub_display = st.session_state.get('clusters') # Not directly used in display logic from original code snippet

        if not sub or not tree:
            st.error(
                "Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
            )
            st.session_state.subclustering_done = False
            return

        sub_input_col, sub_display_col = st.columns([0.25, 0.75])

        with sub_input_col:
            st.subheader("Select Cluster and Subcluster")
            available_cluster_nums = list(sub.keys())
            if not available_cluster_nums:
                st.warning("No clusters available in subclustering results.")
                return  # Exit if no clusters to select

            user_input_cluster_num_display = st.selectbox(
                "Select Cluster #:",
                options=sorted(available_cluster_nums),
                key="subcluster_num_select_key",
            )

            # Default used when the cluster has no selectable subclusters.
            selected_subcluster_idx = 0

            if user_input_cluster_num_display in sub:
                sub_step_cluster = sub[user_input_cluster_num_display]
                allowed_subclusters_indices = sorted(list(sub_step_cluster.keys()))

                if not allowed_subclusters_indices:
                    st.warning(
                        f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
                    )
                else:
                    selected_subcluster_idx = st.selectbox(
                        "Select Subcluster Index:",
                        options=allowed_subclusters_indices,
                        key="subcluster_index_select_key",
                    )
                    if selected_subcluster_idx in sub[user_input_cluster_num_display]:
                        current_subcluster_data = sub[user_input_cluster_num_display][
                            selected_subcluster_idx
                        ]
                        if "sb_cgr" in current_subcluster_data:
                            cluster_sb_cgr_display = current_subcluster_data["sb_cgr"]
                            cluster_sb_cgr_display.clean2d()
                            st.image(
                                cluster_sb_cgr_display.depict(),
                                caption=f"SB-CGR of parent Cluster {user_input_cluster_num_display}",
                            )
                        else:
                            st.warning("SB-CGR for this subcluster not found.")
            else:
                st.warning(
                    f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
                )
                return

        with sub_display_col:
            st.subheader("Subcluster Details")
            if (
                user_input_cluster_num_display in sub
                and selected_subcluster_idx in sub[user_input_cluster_num_display]
            ):

                subcluster_content = sub[user_input_cluster_num_display][
                    selected_subcluster_idx
                ]

                # subcluster_to_display = post_process_subgroup(subcluster_content) #Under development
                subcluster_to_display = subcluster_content
                if (
                    not subcluster_to_display
                    or "routes_data" not in subcluster_to_display
                    or not subcluster_to_display["routes_data"]
                ):
                    st.info("No routes or data found for this subcluster selection.")
                else:
                    # Show the first few routes directly; overflow goes
                    # into an expander below.
                    MAX_ROUTES_PER_SUBCLUSTER = 5
                    all_route_ids_in_subcluster = list(
                        subcluster_to_display["routes_data"].keys()
                    )
                    routes_to_display_direct = all_route_ids_in_subcluster[
                        :MAX_ROUTES_PER_SUBCLUSTER
                    ]
                    remaining_routes_sub = all_route_ids_in_subcluster[
                        MAX_ROUTES_PER_SUBCLUSTER:
                    ]

                    st.markdown(
                        f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
                    )

                    if "synthon_reaction" in subcluster_to_display:
                        synthon_reaction = subcluster_to_display["synthon_reaction"]
                        try:
                            synthon_reaction.clean2d()
                            st.image(
                                depict_custom_reaction(synthon_reaction),
                                caption=f"Markush-like pseudo reaction of subcluster",
                            )  # Assuming depict_custom_reaction
                        except Exception as e_depict:
                            st.warning(f"Could not depict synthon reaction: {e_depict}")
                    else:
                        st.info("No synthon reaction data for this subcluster.")
                    with st.container(height=500):
                        for route_id in routes_to_display_direct:
                            try:
                                route_score_sub = round(tree.route_score(route_id), 3)
                                # svg_sub = get_route_svg(tree, route_id)
                                svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
                                if svg_sub:
                                    st.image(
                                        svg_sub,
                                        caption=f"Route {route_id}; Score: {route_score_sub}",
                                    )
                                else:
                                    st.warning(
                                        f"Could not generate SVG for route {route_id}."
                                    )
                            except Exception as e:
                                st.error(
                                    f"Error displaying route {route_id} in subcluster: {e}"
                                )

                    if remaining_routes_sub:
                        with st.expander(
                            f"... and {len(remaining_routes_sub)} more routes in this subcluster"
                        ):
                            for route_id in remaining_routes_sub:
                                try:
                                    route_score_sub = round(
                                        tree.route_score(route_id), 3
                                    )
                                    # svg_sub = get_route_svg(tree, route_id)
                                    svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
                                    if svg_sub:
                                        st.image(
                                            svg_sub,
                                            caption=f"Route {route_id}; Score: {route_score_sub}",
                                        )
                                    else:
                                        st.warning(
                                            f"Could not generate SVG for route {route_id}."
                                        )
                                except Exception as e:
                                    st.error(
                                        f"Error displaying route {route_id} in subcluster (expanded): {e}"
                                    )
            else:
                st.info("Select a valid cluster and subcluster index to see details.")
def download_subclustering_results():
    """13. Subclustering Results Download: Providing functionality to download.

    Builds an HTML report for the cluster/subcluster currently selected in
    the display section (read back via the ``subcluster_num_select_key``
    and ``subcluster_index_select_key`` widget keys) and offers it as a
    download button.
    """
    if (
        st.session_state.get("subclustering_done", False)
        and "subcluster_num_select_key" in st.session_state
        and "subcluster_index_select_key" in st.session_state
    ):

        sub = st.session_state.get("subclusters")
        tree = st.session_state.get("tree")
        sb_cgrs_for_report = st.session_state.get(
            "sb_cgrs_dict"
        )  # Used by routes_subclustering_report

        user_input_cluster_num_display = st.session_state.subcluster_num_select_key
        selected_subcluster_idx = st.session_state.subcluster_index_select_key

        if not tree or not sub or not sb_cgrs_for_report:
            st.warning(
                "Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
            )
            return

        if (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        ):

            subcluster_data_for_report = sub[user_input_cluster_num_display][
                selected_subcluster_idx
            ]
            # Apply the same post-processing as in display
            # NOTE(review): the display function currently skips
            # post_process_subgroup (commented out there as "Under
            # development"), so the downloaded report may differ from what
            # is shown on screen — confirm which behavior is intended.
            processed_subcluster_data = post_process_subgroup(
                subcluster_data_for_report
            )
            if "routes_data" in subcluster_data_for_report and isinstance(
                subcluster_data_for_report["routes_data"], dict
            ):
                processed_subcluster_data["group_lgs"] = group_by_identical_values(
                    subcluster_data_for_report["routes_data"]
                )
            else:
                processed_subcluster_data["group_lgs"] = {}

            try:
                subcluster_html_content = routes_subclustering_report(
                    tree,
                    processed_subcluster_data,  # Pass the specific post-processed subcluster data
                    user_input_cluster_num_display,
                    selected_subcluster_idx,
                    sb_cgrs_for_report,  # Pass the whole sb_cgrs dict
                    if_lg_group=True,  # This parameter was in the original call
                )
                st.download_button(
                    label=f"Download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}",
                    data=subcluster_html_content,
                    file_name=f"subcluster_{user_input_cluster_num_display}.{selected_subcluster_idx}_{st.session_state.target_smiles}.html",
                    mime="text/html",
                    key=f"download_subcluster_{user_input_cluster_num_display}_{selected_subcluster_idx}",
                )
            except Exception as e:
                st.error(
                    f"Error generating download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}: {e}"
                )
        # else:
        # This case is handled by the display logic mostly, download button just won't appear or will be for previous valid selection.
def implement_restart():
    """14. Restart: Implementing the logic to reset or restart the application state.

    Renders a restart button; when pressed, drops every cached result and
    widget-state key from ``st.session_state``, restores the sketcher to
    its default molecule, clears the target SMILES and reruns the app.
    """
    st.divider()
    st.header("Restart Application State")
    if not st.button("Clear All Results & Restart", key="restart_button"):
        return

    # Everything the app caches between reruns: planning/clustering results,
    # derived dictionaries, and explicit widget-state keys.
    stale_keys = (
        "planning_done",
        "tree",
        "res",
        "target_smiles",
        "clustering_done",
        "clusters",
        "reactions_dict",
        "num_clusters_setting",
        "route_cgrs_dict",
        "sb_cgrs_dict",
        "route_json",
        "subclustering_done",
        "subclusters",
        "clusters_downloaded",
        "ketcher_widget",
        "smiles_text_input_key",
        "subcluster_num_select_key",
        "subcluster_index_select_key",
    )
    for stale_key in stale_keys:
        # pop() tolerates keys that were never set during this session.
        st.session_state.pop(stale_key, None)

    # Reset the ketcher sketcher to its default molecule and clear the
    # target so no stale SMILES survives the restart.
    st.session_state.ketcher = DEFAULT_MOL
    st.session_state.target_smiles = ""

    st.rerun()
# --- Main Application Flow ---
def main():
    """Render the whole Streamlit page in order: molecule input, planning,
    clustering, subclustering and the restart section.

    Streamlit re-runs this function top-to-bottom on every user interaction,
    so all cross-run state is kept in ``st.session_state`` and each section
    is gated on the ``*_done`` flags set by the previous one.
    """
    initialize_app()
    setup_sidebar()
    current_smile_code = handle_molecule_input()
    # Update session_state.ketcher if current_smile_code has changed from ketcher output
    if st.session_state.get("ketcher") != current_smile_code:
        st.session_state.ketcher = current_smile_code
        # No rerun here, let the flow continue. handle_molecule_input already warns.

    setup_planning_options()  # This function now also handles the button press and logic for planning

    # Display planning results and download options together
    if st.session_state.get("planning_done", False):
        display_planning_results()  # Displays stats and routes
        # Statistics/downloads are only meaningful for a solved search.
        if st.session_state.res and st.session_state.res.get("solved", False):
            stat_col, download_col = st.columns(
                2, gap="medium"
            )  # Placeholder for download column
            with stat_col:
                st.subheader("Statistics")
                try:
                    res = st.session_state.res
                    # Backfill target_smiles into the result dict so it shows
                    # up in the statistics table.
                    if (
                        "target_smiles" not in res
                        and "target_smiles" in st.session_state
                    ):
                        res["target_smiles"] = st.session_state.target_smiles
                    cols_to_show = [
                        col
                        for col in [
                            "target_smiles",
                            "num_routes",
                            "num_nodes",
                            "num_iter",
                            "search_time",
                        ]
                        if col in res
                    ]
                    if cols_to_show:  # Ensure there are columns to show
                        df = pd.DataFrame(res, index=[0])[cols_to_show]
                        st.dataframe(df)
                    else:
                        st.write("No statistics to display from planning results.")
                except Exception as e:
                    st.error(f"Error displaying statistics: {e}")
                    st.write(res)  # Show raw dict if DataFrame fails
            with download_col:
                st.subheader("Planning Downloads")  # Adding a subheader for clarity
                download_planning_results()

    # Clustering section (setup button, display, download)
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        setup_clustering()  # Contains the "Run Clustering" button and logic
        if st.session_state.get("clustering_done", False):
            display_clustering_results()  # Displays cluster routes and stats
            cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")

            with cluster_stat_col:
                clusters = st.session_state.clusters
                cluster_sizes = [
                    cluster.get("group_size", 0)
                    for cluster in clusters.values()
                    if cluster
                ]  # Safe get
                st.subheader("Cluster Statistics")
                if cluster_sizes:
                    cluster_df = pd.DataFrame(
                        {
                            "Cluster": [
                                k for k, v in clusters.items() if v
                            ],  # Filter out empty clusters
                            "Number of Routes": [
                                v["group_size"] for v in clusters.values() if v
                            ],
                        }
                    )
                    if not cluster_df.empty:
                        # 1-based index reads better in the UI table.
                        cluster_df.index += 1
                        st.dataframe(cluster_df)
                        best_route_html = html_top_routes_cluster(
                            clusters,
                            st.session_state.tree,
                            st.session_state.target_smiles,
                        )
                        st.download_button(
                            label=f"Download best route from each cluster",
                            data=best_route_html,
                            file_name=f"cluster_best_{st.session_state.target_smiles}.html",
                            mime="text/html",
                            key=f"download_cluster_best",
                        )
                    else:
                        st.write("No valid cluster data to display statistics for.")
                    # download_top_routes_cluster()
                else:
                    st.write("No cluster data to display statistics for.")
            with cluster_download_col:
                download_clustering_results()

    # Subclustering section (setup button, display, download)
    if st.session_state.get("clustering_done", False):  # Depends on clustering
        setup_subclustering()  # Contains "Run Subclustering" button
        if st.session_state.get("subclustering_done", False):
            display_subclustering_results()  # Displays subcluster details and routes
            download_subclustering_results()  # This needs to be called after selections are made in display.

    implement_restart()
|
| 1346 |
+
|
| 1347 |
+
|
| 1348 |
+
# Script entry point: run the Streamlit page flow when executed directly.
if __name__ == "__main__":
    main()
|
pre-requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--find-links https://download.pytorch.org/whl/torch_stable.html
|
| 2 |
+
torch==2.2.2+cpu
|
| 3 |
+
scikit-learn==1.5.1
|
| 4 |
+
scipy==1.14.0
|
| 5 |
+
fastcluster==1.2.6
|
| 6 |
+
matplotlib==3.10.1
|
| 7 |
+
seaborn==0.13.2
|
requirements.txt
CHANGED
|
@@ -1,3 +1,6 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
streamlit_ketcher
|
| 3 |
+
git+https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner.git
|
| 4 |
+
|
| 5 |
+
git+https://github.com/cimm-kzn/StructureFingerprint.git
|
| 6 |
+
scikit-learn
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
synplan/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Re-export the MCTS subpackage's public names at the package root so callers
# can write ``from synplan import Tree``.
from .mcts import *

# Explicit public API of the top-level package.
__all__ = ["Tree"]
|
synplan/chem/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from CGRtools.files import SMILESRead

# Shared CGRtools SMILES parser for the whole ``synplan.chem`` package.
# ``ignore=True`` is forwarded to CGRtools — presumably to tolerate minor
# parsing issues instead of raising; confirm against the CGRtools docs.
smiles_parser = SMILESRead.create_parser(ignore=True)
|
synplan/chem/data/__init__.py
ADDED
|
File without changes
|
synplan/chem/data/filtering.py
ADDED
|
@@ -0,0 +1,962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing classes abd functions for reactions filtering."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from io import TextIOWrapper
|
| 6 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import ray
|
| 10 |
+
import yaml
|
| 11 |
+
from CGRtools.containers import CGRContainer, MoleculeContainer, ReactionContainer
|
| 12 |
+
from chython.algorithms.fingerprints.morgan import MorganFingerprint
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from synplan.chem.data.standardizing import (
|
| 16 |
+
AromaticFormStandardizer,
|
| 17 |
+
KekuleFormStandardizer,
|
| 18 |
+
RemoveReagentsStandardizer,
|
| 19 |
+
)
|
| 20 |
+
from synplan.chem.utils import cgrtools_to_chython_molecule
|
| 21 |
+
from synplan.utils.config import ConfigABC, convert_config_to_dict
|
| 22 |
+
from synplan.utils.files import ReactionReader, ReactionWriter
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class CompeteProductsConfig(ConfigABC):
    """Configuration for :class:`CompeteProductsFilter`."""

    # Fingerprint Tanimoto similarity above which a reagent/product pair is
    # suspected of competing (cheap pre-filter before the MCS check).
    fingerprint_tanimoto_threshold: float = 0.3
    # MCS-based Tanimoto similarity above which the pair counts as competing.
    mcs_tanimoto_threshold: float = 0.6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "CompeteProductsConfig":
        """Create an instance of CompeteProductsConfig from a dictionary."""
        return CompeteProductsConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "CompeteProductsConfig":
        """Deserialize a YAML file into a CompeteProductsConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return CompeteProductsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        :param params: Mapping of parameter name to value.
        :raises ValueError: If either threshold is not a float in [0, 1].
        """
        if not isinstance(params.get("fingerprint_tanimoto_threshold"), float) or not (
            0 <= params["fingerprint_tanimoto_threshold"] <= 1
        ):
            raise ValueError(
                "Invalid 'fingerprint_tanimoto_threshold'; expected a float between 0 and 1"
            )

        if not isinstance(params.get("mcs_tanimoto_threshold"), float) or not (
            0 <= params["mcs_tanimoto_threshold"] <= 1
        ):
            raise ValueError(
                "Invalid 'mcs_tanimoto_threshold'; expected a float between 0 and 1"
            )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class CompeteProductsFilter:
    """Checks if there are compete reactions.

    A reagent is flagged as a competing product when it is highly similar to
    one of the actual products: first by Morgan-fingerprint Tanimoto
    similarity (cheap pre-filter), then confirmed by maximum-common-substructure
    (MCS) Tanimoto similarity.
    """

    def __init__(
        self,
        fingerprint_tanimoto_threshold: float = 0.3,
        mcs_tanimoto_threshold: float = 0.6,
    ):
        # Minimum fingerprint similarity before the (expensive) MCS check runs.
        self.fingerprint_tanimoto_threshold = fingerprint_tanimoto_threshold
        # Minimum MCS-based similarity for a pair to count as competing.
        self.mcs_tanimoto_threshold = mcs_tanimoto_threshold

    @staticmethod
    def from_config(config: "CompeteProductsConfig") -> "CompeteProductsFilter":
        """Creates an instance of CompeteProductsFilter from a configuration object."""
        return CompeteProductsFilter(
            config.fingerprint_tanimoto_threshold, config.mcs_tanimoto_threshold
        )

    @staticmethod
    def _tanimoto(fp1, fp2) -> float:
        """Tanimoto kernel value between two fingerprint vectors.

        Fix: the original code called ``tanimoto_kernel(molf, other_molf)[0][0]``
        but ``tanimoto_kernel`` was never imported in this module, so the filter
        raised ``NameError`` at runtime. This local helper computes the same
        scalar: dot(a, b) / (dot(a, a) + dot(b, b) - dot(a, b)).

        :param fp1: First fingerprint (any array-like; flattened).
        :param fp2: Second fingerprint (any array-like; flattened).
        :return: Similarity in [0, 1]; 0.0 for two all-zero vectors.
        """
        a = np.asarray(fp1, dtype=float).ravel()
        b = np.asarray(fp2, dtype=float).ravel()
        dot = float(np.dot(a, b))
        denom = float(np.dot(a, a)) + float(np.dot(b, b)) - dot
        return dot / denom if denom else 0.0

    def __call__(self, reaction: "ReactionContainer") -> bool:
        """Checks if the reaction has competing products.

        :param reaction: Input reaction.
        :return: Returns True if the reaction has competing products, else False.
        """
        mf = MorganFingerprint()

        # Compare every (reagent, product) pair; molecules with <= 6 atoms are
        # skipped, matching the original size guard.
        for mol in reaction.reagents:
            for other_mol in reaction.products:
                if len(mol) <= 6 or len(other_mol) <= 6:
                    continue

                # Cheap pre-filter: fingerprint similarity.
                molf = mf.transform([cgrtools_to_chython_molecule(mol)])
                other_molf = mf.transform([cgrtools_to_chython_molecule(other_mol)])
                fingerprint_tanimoto = self._tanimoto(molf, other_molf)
                if fingerprint_tanimoto <= self.fingerprint_tanimoto_threshold:
                    continue

                try:
                    # Find the maximum common substructure (MCS) and its size;
                    # StopIteration means no mapping was found within the limit.
                    clique_size = len(next(mol.get_mcs_mapping(other_mol, limit=100)))
                except StopIteration:
                    continue

                # MCS similarity based on MCS size (Tanimoto on atom counts).
                mcs_tanimoto = clique_size / (len(mol) + len(other_mol) - clique_size)
                if mcs_tanimoto > self.mcs_tanimoto_threshold:
                    return True

        return False
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@dataclass
class DynamicBondsConfig(ConfigABC):
    """Configuration for :class:`DynamicBondsFilter`."""

    # Inclusive lower bound on the number of dynamic bonds in the CGR.
    min_bonds_number: int = 1
    # Inclusive upper bound on the number of dynamic bonds in the CGR.
    max_bonds_number: int = 6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "DynamicBondsConfig":
        """Create an instance of DynamicBondsConfig from a dictionary."""
        return DynamicBondsConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "DynamicBondsConfig":
        """Deserialize a YAML file into a DynamicBondsConfig object."""
        with open(file_path, "r") as file:
            config_dict = yaml.safe_load(file)
        return DynamicBondsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        :param params: Mapping of parameter name to value.
        :raises ValueError: If either bound is not a non-negative int, or the
            bounds are inverted.
        """
        if (
            not isinstance(params.get("min_bonds_number"), int)
            or params["min_bonds_number"] < 0
        ):
            raise ValueError(
                "Invalid 'min_bonds_number'; expected a non-negative integer"
            )

        if (
            not isinstance(params.get("max_bonds_number"), int)
            or params["max_bonds_number"] < 0
        ):
            raise ValueError(
                "Invalid 'max_bonds_number'; expected a non-negative integer"
            )

        if params["min_bonds_number"] > params["max_bonds_number"]:
            raise ValueError(
                "'min_bonds_number' cannot be greater than 'max_bonds_number'"
            )
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class DynamicBondsFilter:
    """Checks if there is an unacceptable number of dynamic bonds in CGR."""

    def __init__(self, min_bonds_number: int = 1, max_bonds_number: int = 6):
        # Inclusive bounds on the allowed number of dynamic (changed) bonds.
        self.min_bonds_number = min_bonds_number
        self.max_bonds_number = max_bonds_number

    @staticmethod
    def from_config(config: DynamicBondsConfig):
        """Creates an instance of DynamicBondsChecker from a configuration object."""
        return DynamicBondsFilter(config.min_bonds_number, config.max_bonds_number)

    def __call__(self, reaction: ReactionContainer) -> bool:
        # Condense the reaction into a CGR and count the bonds that change.
        condensed = ~reaction
        dynamic_bonds = len(condensed.center_bonds)
        within_bounds = self.min_bonds_number <= dynamic_bonds <= self.max_bonds_number
        return not within_bounds
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass
class SmallMoleculesConfig(ConfigABC):
    """Configuration for :class:`SmallMoleculesFilter`."""

    # Molecules with at most this many atoms are considered "small".
    mol_max_size: int = 6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "SmallMoleculesConfig":
        """Creates an instance of SmallMoleculesConfig from a dictionary."""
        return SmallMoleculesConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "SmallMoleculesConfig":
        """Deserialize a YAML file into a SmallMoleculesConfig object."""
        with open(file_path, "r") as file:
            config_dict = yaml.safe_load(file)
        return SmallMoleculesConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        :param params: Mapping of parameter name to value.
        :raises ValueError: If ``mol_max_size`` is not a positive int.
        """
        if (
            not isinstance(params.get("mol_max_size"), int)
            or params["mol_max_size"] < 1
        ):
            raise ValueError("Invalid 'mol_max_size'; expected a positive integer")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class SmallMoleculesFilter:
    """Checks if there are only small molecules in the reaction or if there is only one
    small reactant or product."""

    def __init__(self, mol_max_size: int = 6):
        # Molecules with at most this many atoms count as "small".
        self.limit = mol_max_size

    @staticmethod
    def from_config(config: SmallMoleculesConfig) -> "SmallMoleculesFilter":
        """Creates an instance of SmallMoleculesChecker from a configuration object."""
        return SmallMoleculesFilter(config.mol_max_size)

    def __call__(self, reaction: ReactionContainer) -> bool:
        reactants_small = self.are_only_small_molecules(reaction.reactants)
        products_small = self.are_only_small_molecules(reaction.products)

        # Flag the reaction when the single reactant is small, the single
        # product is small, or both sides consist solely of small molecules.
        if reactants_small and len(reaction.reactants) == 1:
            return True
        if products_small and len(reaction.products) == 1:
            return True
        return reactants_small and products_small

    def are_only_small_molecules(self, molecules: Iterable[MoleculeContainer]) -> bool:
        """Checks if all molecules in the given iterable are small molecules."""
        for molecule in molecules:
            if len(molecule) > self.limit:
                return False
        return True
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@dataclass
class CGRConnectedComponentsConfig:
    """Marker config enabling :class:`CGRConnectedComponentsFilter`; no parameters."""

    pass
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class CGRConnectedComponentsFilter:
    """Checks if CGR contains unrelated components (without reagents)."""

    @staticmethod
    def from_config(
        config: CGRConnectedComponentsConfig,
    ) -> "CGRConnectedComponentsFilter":
        """Creates an instance of CGRConnectedComponentsChecker from a configuration
        object."""
        return CGRConnectedComponentsFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        # Rebuild the reaction with reactants and products only (reagents
        # dropped), condense it and inspect the number of disconnected parts.
        stripped = ReactionContainer(reaction.reactants, reaction.products)
        condensed = ~stripped
        return condensed.connected_components_count > 1
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@dataclass
class RingsChangeConfig:
    """Marker config enabling :class:`RingsChangeFilter`; no parameters."""

    pass
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class RingsChangeFilter:
    """Checks if there is changing rings number in the reaction."""

    @staticmethod
    def from_config(config: RingsChangeConfig) -> "RingsChangeFilter":
        """Creates an instance of RingsChecker from a configuration object."""
        return RingsChangeFilter()

    def __call__(self, reaction: ReactionContainer):
        """Flags reactions whose ring counts differ between sides.

        True when the number of all rings or the number of aromatic rings does
        not match between reactants and products (a "reaction in rings").

        :param reaction: Input reaction.
        :return: True if the ring counts mismatch, else False.
        """
        reactant_counts = self._calc_rings(reaction.reactants)
        product_counts = self._calc_rings(reaction.products)
        # Tuple comparison covers both the total-ring and aromatic-ring checks.
        return reactant_counts != product_counts

    @staticmethod
    def _calc_rings(molecules: Iterable) -> Tuple[int, int]:
        """Sums ring statistics over a collection of molecules.

        :param molecules: Molecules to inspect (single pass; generators OK).
        :return: ``(total_rings, total_aromatic_rings)``.
        """
        total_rings, total_aromatic = 0, 0
        for molecule in molecules:
            total_rings += molecule.rings_count
            total_aromatic += len(molecule.aromatic_rings)
        return total_rings, total_aromatic
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
@dataclass
class StrangeCarbonsConfig:
    """Marker config enabling :class:`StrangeCarbonsFilter`."""

    # currently empty, but can be extended in the future if needed
    pass
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class StrangeCarbonsFilter:
    """Checks if there are 'strange' carbons in the reaction."""

    @staticmethod
    def from_config(config: StrangeCarbonsConfig) -> "StrangeCarbonsFilter":
        """Creates an instance of StrangeCarbonsChecker from a configuration object."""
        return StrangeCarbonsFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """Returns True when any reactant/product is a suspicious all-carbon
        molecule: either methane-like (a single C atom) or a carbon-only
        molecule whose bonds are all of one non-aromatic type.

        :param reaction: Input reaction.
        :return: True if a 'strange' carbon-only molecule is found, else False.
        """
        for molecule in reaction.reactants + reaction.products:
            atoms_types = {
                a.atomic_symbol for _, a in molecule.atoms()
            }  # atoms types in molecule
            # set.pop() is safe here: the set has exactly one element and is
            # rebuilt on the next loop iteration.
            if len(atoms_types) == 1 and atoms_types.pop() == "C":
                if len(molecule) == 1:  # methane
                    return True
                # Bond order 4 denotes aromatic in CGRtools' int encoding.
                bond_types = {int(b) for _, _, b in molecule.bonds()}
                if len(bond_types) == 1 and bond_types.pop() != 4:
                    return True  # C molecules with only one type of bond (not aromatic)
        return False
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
@dataclass
class NoReactionConfig:
    """Marker config enabling :class:`NoReactionFilter`."""

    # Currently empty, but can be extended in the future if needed
    pass
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class NoReactionFilter:
    """Checks if there is no reaction in the provided reaction container."""

    @staticmethod
    def from_config(config: NoReactionConfig) -> "NoReactionFilter":
        """Creates an instance of NoReactionChecker from a configuration object."""
        return NoReactionFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        # A CGR with neither changed atoms nor changed bonds means nothing
        # actually happened in this "reaction".
        condensed = ~reaction
        has_changes = bool(condensed.center_atoms) or bool(condensed.center_bonds)
        return not has_changes
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
@dataclass
class MultiCenterConfig:
    """Marker config enabling :class:`MultiCenterFilter`; no parameters."""

    pass
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class MultiCenterFilter:
    """Checks if there is a multicenter reaction."""

    @staticmethod
    def from_config(config: MultiCenterConfig) -> "MultiCenterFilter":
        return MultiCenterFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        # More than one separate reaction centre in the condensed graph marks
        # the reaction as multicenter.
        centers = (~reaction).centers_list
        return len(centers) > 1
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@dataclass
class WrongCHBreakingConfig:
    """Marker config enabling :class:`WrongCHBreakingFilter`; no parameters."""

    pass
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class WrongCHBreakingFilter:
    """Checks for incorrect C-C bond formation from breaking a C-H bond."""

    @staticmethod
    def from_config(config: WrongCHBreakingConfig) -> "WrongCHBreakingFilter":
        return WrongCHBreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Determines if a reaction involves incorrect C-C bond formation from breaking
        a C-H bond.

        :param reaction: The reaction to be filtered.
        :return: True if incorrect C-C bond formation is found, False otherwise.

        """

        # Reactions with valence errors are skipped (not flagged) here.
        if reaction.check_valence():
            return False

        # Work on a copy: explicit hydrogens are required so C-H bond breaking
        # is visible in the CGR.
        copy_reaction = reaction.copy()
        copy_reaction.explicify_hydrogens()
        cgr = ~copy_reaction
        # Restrict the check to the reaction centre plus its first shell.
        reduced_cgr = cgr.augmented_substructure(cgr.center_atoms, deep=1)

        return self.is_wrong_c_h_breaking(reduced_cgr)

    @staticmethod
    def is_wrong_c_h_breaking(cgr: CGRContainer) -> bool:
        """
        Checks for incorrect C-C bond formation from breaking a C-H bond in a CGR.

        :param cgr: The CGR with explicified hydrogens.
        :return: True if incorrect C-C bond formation is found, False otherwise.

        """
        for atom_id in cgr.center_atoms:
            if cgr.atom(atom_id).atomic_symbol == "C":
                is_c_h_breaking, is_c_c_formation = False, False
                c_with_h_id, another_c_id = None, None

                # NOTE(review): reaches into the private ``_bonds`` mapping of
                # CGRContainer; fragile against CGRtools internals changes.
                for neighbour_id, bond in cgr._bonds[atom_id].items():
                    neighbour = cgr.atom(neighbour_id)

                    # order set / p_order unset => the bond is broken; to an H
                    # neighbour this is a C-H cleavage at this carbon.
                    if (
                        bond.order
                        and not bond.p_order
                        and neighbour.atomic_symbol == "H"
                    ):
                        is_c_h_breaking = True
                        c_with_h_id = atom_id

                    # order unset / p_order set => the bond is formed; to a C
                    # neighbour this is a new C-C bond at this carbon.
                    elif (
                        not bond.order
                        and bond.p_order
                        and neighbour.atomic_symbol == "C"
                    ):
                        is_c_c_formation = True
                        another_c_id = neighbour_id

                    if is_c_h_breaking and is_c_c_formation:
                        # check for presence of heteroatoms in the first environment of 2 bonding carbons
                        if any(
                            cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
                            for neighbour_id in cgr._bonds[c_with_h_id]
                        ) or any(
                            cgr.atom(neighbour_id).atomic_symbol not in ("C", "H")
                            for neighbour_id in cgr._bonds[another_c_id]
                        ):
                            # Heteroatoms nearby make the transformation
                            # plausible, so do not flag it.
                            return False
                        return True

        return False
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@dataclass
class CCsp3BreakingConfig:
    """Marker config enabling :class:`CCsp3BreakingFilter`; no parameters."""

    pass
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
class CCsp3BreakingFilter:
    """Checks if there is C(sp3)-C bond breaking."""

    @staticmethod
    def from_config(config: CCsp3BreakingConfig) -> "CCsp3BreakingFilter":
        return CCsp3BreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Returns True if there is C(sp3)-C bonds breaking, else False.

        :param reaction: Input reaction
        :return: Returns True if there is C(sp3)-C bonds breaking, else False.

        """
        cgr = ~reaction
        # Inspect only the reaction centre plus its first shell.
        reaction_center = cgr.augmented_substructure(cgr.center_atoms, deep=1)
        for atom_id, neighbour_id, bond in reaction_center.bonds():
            atom = reaction_center.atom(atom_id)
            neighbour = reaction_center.atom(neighbour_id)

            # order set / p_order unset => the bond disappears in the products.
            is_bond_broken = bond.order is not None and bond.p_order is None
            are_atoms_carbons = (
                atom.atomic_symbol == "C" and neighbour.atomic_symbol == "C"
            )
            # hybridization == 1 encodes sp3 in CGRtools' numbering — TODO confirm.
            is_atom_sp3 = atom.hybridization == 1 or neighbour.hybridization == 1

            if is_bond_broken and are_atoms_carbons and is_atom_sp3:
                return True
        return False
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
@dataclass
class CCRingBreakingConfig:
    """
    Object to pass to ReactionFilterConfig if you want to enable C-C ring breaking filter.
    The filter takes no parameters, so this is an empty marker class.

    """

    pass
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
class CCRingBreakingFilter:
    """Checks if a reaction involves ring C-C bond breaking."""

    @staticmethod
    def from_config(config: CCRingBreakingConfig):
        """Builds the filter from its (empty) configuration object."""
        return CCRingBreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Returns True if the reaction involves ring C-C bond breaking, else
        False.

        :param reaction: Input reaction
        :return: Returns True if the reaction involves ring C-C bond breaking,
            else False.

        """
        cgr = ~reaction

        # collect, from the reactant side, the atoms that belong to the
        # reaction centre together with every ring (SSSR) of the reactants
        centre_atoms_in_reactants = {}
        reactant_rings = set()
        for molecule in reaction.reactants:
            reactant_rings.update(molecule.sssr)
            for number, atom in molecule.atoms():
                if number in cgr.center_atoms:
                    centre_atoms_in_reactants[number] = atom

        # identify reaction center based on center atoms
        reaction_center = cgr.augmented_substructure(atoms=cgr.center_atoms, deep=0)

        for atom_id, neighbour_id, bond in reaction_center.bonds():
            atom = centre_atoms_in_reactants.get(atom_id)
            neighbour = centre_atoms_in_reactants.get(neighbour_id)
            if atom is None or neighbour is None:
                # bond end not found among reactant centre atoms
                continue

            # the bond must disappear during the reaction
            if bond.order is None or bond.p_order is not None:
                continue

            # both ends must be carbons
            if atom.atomic_symbol != "C" or neighbour.atomic_symbol != "C":
                continue

            # both carbons must sit in rings of size 5, 6 or 7 and share a ring
            if not set(atom.ring_sizes).intersection({5, 6, 7}):
                continue
            if not set(neighbour.ring_sizes).intersection({5, 6, 7}):
                continue
            if any(
                atom_id in ring and neighbour_id in ring
                for ring in reactant_rings
            ):
                return True

        return False
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
@dataclass
class ReactionFilterConfig(ConfigABC):
    """
    Configuration class for reaction filtering. This class manages configuration
    settings for various reaction filters, including paths, file formats, and filter-
    specific parameters. A filter is enabled by assigning a non-None configuration
    object to the corresponding attribute.

    :ivar dynamic_bonds_config: Configuration for dynamic bonds checking.
    :ivar small_molecules_config: Configuration for small molecules checking.
    :ivar strange_carbons_config: Configuration for strange carbons checking.
    :ivar compete_products_config: Configuration for competing products checking.
    :ivar cgr_connected_components_config: Configuration for CGR connected components checking.
    :ivar rings_change_config: Configuration for rings change checking.
    :ivar no_reaction_config: Configuration for no reaction checking.
    :ivar multi_center_config: Configuration for multi-center checking.
    :ivar wrong_ch_breaking_config: Configuration for wrong C-H breaking checking.
    :ivar cc_sp3_breaking_config: Configuration for CC sp3 breaking checking.
    :ivar cc_ring_breaking_config: Configuration for CC ring breaking checking.

    """

    # configuration for reaction filters
    dynamic_bonds_config: Optional[DynamicBondsConfig] = None
    small_molecules_config: Optional[SmallMoleculesConfig] = None
    strange_carbons_config: Optional[StrangeCarbonsConfig] = None
    compete_products_config: Optional[CompeteProductsConfig] = None
    cgr_connected_components_config: Optional[CGRConnectedComponentsConfig] = None
    rings_change_config: Optional[RingsChangeConfig] = None
    no_reaction_config: Optional[NoReactionConfig] = None
    multi_center_config: Optional[MultiCenterConfig] = None
    wrong_ch_breaking_config: Optional[WrongCHBreakingConfig] = None
    cc_sp3_breaking_config: Optional[CCsp3BreakingConfig] = None
    cc_ring_breaking_config: Optional[CCRingBreakingConfig] = None

    # Single source of truth for (config class, filter class, takes-parameters)
    # per attribute. Not annotated, so the dataclass machinery ignores it.
    # Dict order defines the order in which filters are applied (create_filters).
    _FILTER_SPECS = {
        "dynamic_bonds_config": (DynamicBondsConfig, DynamicBondsFilter, True),
        "small_molecules_config": (SmallMoleculesConfig, SmallMoleculesFilter, True),
        "strange_carbons_config": (StrangeCarbonsConfig, StrangeCarbonsFilter, False),
        "compete_products_config": (CompeteProductsConfig, CompeteProductsFilter, True),
        "cgr_connected_components_config": (
            CGRConnectedComponentsConfig,
            CGRConnectedComponentsFilter,
            False,
        ),
        "rings_change_config": (RingsChangeConfig, RingsChangeFilter, False),
        "no_reaction_config": (NoReactionConfig, NoReactionFilter, False),
        "multi_center_config": (MultiCenterConfig, MultiCenterFilter, False),
        "wrong_ch_breaking_config": (WrongCHBreakingConfig, WrongCHBreakingFilter, False),
        "cc_sp3_breaking_config": (CCsp3BreakingConfig, CCsp3BreakingFilter, False),
        "cc_ring_breaking_config": (CCRingBreakingConfig, CCRingBreakingFilter, False),
    }

    # Key order used when serializing to a dictionary (kept identical to the
    # historical to_dict output order, which differs from the filter order).
    _SERIALIZATION_ORDER = (
        "dynamic_bonds_config",
        "small_molecules_config",
        "compete_products_config",
        "cgr_connected_components_config",
        "rings_change_config",
        "strange_carbons_config",
        "no_reaction_config",
        "multi_center_config",
        "wrong_ch_breaking_config",
        "cc_sp3_breaking_config",
        "cc_ring_breaking_config",
    )

    def to_dict(self):
        """Converts the configuration into a dictionary, omitting disabled filters."""
        config_dict = {}
        for name in self._SERIALIZATION_ORDER:
            config_class, _filter_class, has_params = self._FILTER_SPECS[name]
            value = getattr(self, name)
            if has_params:
                # parameterized configs are expanded into their parameter dicts
                config_dict[name] = convert_config_to_dict(value, config_class)
            else:
                # parameterless configs serialize as an empty dict marker
                config_dict[name] = {} if value is not None else None

        return {k: v for k, v in config_dict.items() if v is not None}

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ReactionFilterConfig":
        """Create an instance of ReactionFilterConfig from a dictionary.

        Presence of a key enables the corresponding filter; for parameterless
        filters the value content is ignored.
        """
        kwargs = {}
        for name, (config_class, _filter_class, has_params) in (
            ReactionFilterConfig._FILTER_SPECS.items()
        ):
            if name not in config_dict:
                kwargs[name] = None
            elif has_params:
                kwargs[name] = config_class(**config_dict[name])
            else:
                kwargs[name] = config_class()
        return ReactionFilterConfig(**kwargs)

    @staticmethod
    def from_yaml(file_path: str) -> "ReactionFilterConfig":
        """Deserializes a YAML file into a ReactionFilterConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return ReactionFilterConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """No extra validation is performed for filter configurations."""
        pass

    def create_filters(self):
        """Instantiates a filter object for every enabled (non-None) configuration.

        :return: The list of filter instances, in application order.
        """
        filter_instances = []
        for name, (_config_class, filter_class, _has_params) in (
            self._FILTER_SPECS.items()
        ):
            config = getattr(self, name)
            if config is not None:
                filter_instances.append(filter_class.from_config(config))
        return filter_instances
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def tanimoto_kernel(x: "MorganFingerprint", y: "MorganFingerprint") -> "np.ndarray":
    """Calculate the Tanimoto coefficient between each element of arrays x and y.

    :param x: 2D fingerprint array of shape (n_x, n_bits).
    :param y: 2D fingerprint array of shape (n_y, n_bits).
    :return: Matrix of pairwise Tanimoto coefficients, shape (n_x, n_y). Pairs
        with an empty union (zero denominator) are set to 0 instead of NaN.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)

    # for binary fingerprints: x_dot[i, j] = |x_i & y_j|
    x_dot = np.dot(x, y.T)
    x2 = np.sum(x**2, axis=1)
    y2 = np.sum(y**2, axis=1)

    # |x_i| + |y_j| - |x_i & y_j| via broadcasting; this avoids the O(n^2)
    # Python-level tiling of np.array([x2] * len(y2)).T
    denominator = x2[:, None] + y2[None, :] - x_dot
    result = np.divide(
        x_dot, denominator, out=np.zeros_like(x_dot), where=denominator != 0
    )

    return result
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def filter_reaction(
    reaction: ReactionContainer, config: ReactionFilterConfig, filters: list
) -> Tuple[bool, ReactionContainer]:
    """Checks the input reaction. Returns True if reaction is detected as erroneous and
    returns reaction itself, which sometimes is modified and does not necessarily
    correspond to the initial reaction.

    :param reaction: Reaction to be filtered.
    :param config: Reaction filtration configuration.
    :param filters: The list of reaction filters.
    :return: False and reaction if reaction is correct and True and reaction if reaction
        is filtered (erroneous).
    """

    # reaction standardization pipeline applied before any filtration
    for reaction_standardizer in (
        RemoveReagentsStandardizer(),
        KekuleFormStandardizer(),
        AromaticFormStandardizer(),
    ):
        reaction = reaction_standardizer(reaction)
        if not reaction:
            # standardization failed -> treat the reaction as erroneous
            return True, reaction

    # run reaction filtration; all filters are evaluated (no early exit), so
    # filtration_log reflects the last filter that rejected the reaction
    is_filtered = False
    for reaction_filter in filters:
        try:  # CGRTools ValueError: mapping of graphs is not disjoint
            if reaction_filter(reaction):
                # if filter returns True it means the reaction doesn't pass the filter
                reaction.meta["filtration_log"] = reaction_filter.__class__.__name__
                is_filtered = True
        except Exception as exc:
            logging.debug(exc)
            is_filtered = True

    return is_filtered, reaction
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
@ray.remote
def process_batch(
    batch: List[ReactionContainer],
    config: ReactionFilterConfig,
    filters: list,
) -> List[Tuple[bool, ReactionContainer]]:
    """
    Standardizes and filters a batch of reactions based on the given
    configuration. This function operates as a remote task in a distributed system using
    Ray.

    :param batch: A list of ReactionContainer objects to be filtered.
    :param config: Reaction filtration configuration.
    :param filters: The list of reaction filters.
    :return: The list of tuples where each tuple contains a flag telling whether
        the reaction was filtered out (True/False) and the (possibly modified)
        reaction itself, in the same order as the input batch.

    """

    processed_reaction_list = []
    for reaction in batch:
        try:  # CGRtools.exceptions.MappingError: atoms with number {52} not equal
            is_filtered, processed_reaction = filter_reaction(reaction, config, filters)
            processed_reaction_list.append((is_filtered, processed_reaction))
        except Exception as e:
            # on any unexpected failure keep the original reaction, marked as filtered
            logging.debug(e)
            processed_reaction_list.append((True, reaction))
    return processed_reaction_list
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def process_completed_batch(
    futures: Dict,
    result_file: TextIOWrapper,
    n_filtered: int = 0,
) -> int:
    """
    Waits for one running batch task to finish, writes its surviving reactions to
    the result file and removes the finished future from the pool.

    :param futures: A dictionary of futures representing ongoing batch processing tasks.
    :param result_file: Writer for the file where reactions that passed all filters
        are stored.
    :param n_filtered: Running count of reactions that passed filtration so far.
    :return: The updated count of reactions that passed filtration.

    """

    # block until at least one of the submitted batches completes
    ready_ids, _ = ray.wait(list(futures.keys()), num_returns=1)
    completed_batch = ray.get(ready_ids[0])

    # write only reactions that were NOT flagged by any filter
    for is_filtered, reaction in completed_batch:
        if not is_filtered:
            result_file.write(reaction)
            n_filtered += 1

    # drop the completed future so a new batch can be scheduled in its place
    del futures[ready_ids[0]]

    return n_filtered
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def filter_reactions_from_file(
    config: ReactionFilterConfig,
    input_reaction_data_path: str,
    filtered_reaction_data_path: str = "reaction_data_filtered.smi",
    num_cpus: int = 1,
    batch_size: int = 100,
) -> None:
    """
    Processes reaction data, applying reaction filters based on the provided
    configuration, and writes the results to specified files. Batches of reactions
    are dispatched to Ray workers, keeping at most `num_cpus` batches in flight.

    :param config: ReactionFilterConfig object containing all filtration configuration
        settings.
    :param input_reaction_data_path: Path to the reaction data file.
    :param filtered_reaction_data_path: Name for the file that will contain filtered
        reactions.
    :param num_cpus: Number of CPUs to use for processing.
    :param batch_size: Size of the batch for processing reactions.
    :return: None. The function writes the processed reactions to specified RDF/smi
        files.

    """

    filters = config.create_filters()

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
    max_concurrent_batches = num_cpus  # limit the number of concurrent batches
    lines_counter = 0
    with ReactionReader(input_reaction_data_path) as reactions, ReactionWriter(
        filtered_reaction_data_path
    ) as result_file:

        # futures of submitted batches (used as an ordered set) and current batch
        batches_to_process, batch = {}, []
        n_filtered = 0
        for index, reaction in tqdm(
            enumerate(reactions),
            desc="Number of reactions processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):
            lines_counter += 1
            batch.append(reaction)
            if len(batch) == batch_size:
                batch_results = process_batch.remote(batch, config, filters)
                batches_to_process[batch_results] = None
                batch = []

            # check and process completed tasks if we've reached the concurrency limit
            while len(batches_to_process) >= max_concurrent_batches:
                n_filtered = process_completed_batch(
                    batches_to_process,
                    result_file,
                    n_filtered,
                )

        # process the last batch if it's not empty
        if batch:
            batch_results = process_batch.remote(batch, config, filters)
            batches_to_process[batch_results] = None

        # process remaining batches
        while batches_to_process:
            n_filtered = process_completed_batch(
                batches_to_process,
                result_file,
                n_filtered,
            )

    ray.shutdown()
    print(f"Initial number of reactions: {lines_counter}")
    # NOTE: n_filtered is incremented for reactions that PASSED filtration (see
    # process_completed_batch), i.e. it reports how many reactions were kept.
    print(f"Filtered number of reactions: {n_filtered}")
|
synplan/chem/data/standardizing.py
ADDED
|
@@ -0,0 +1,1187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing classes and functions for reactions standardizing.
|
| 2 |
+
|
| 3 |
+
This module contains the open-source code from
|
| 4 |
+
https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning/blob/master/scripts/standardizer.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from contextlib import suppress
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from io import TextIOWrapper
|
| 13 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Sequence, TextIO
|
| 14 |
+
from abc import ABC, abstractmethod
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import ray
|
| 20 |
+
import yaml
|
| 21 |
+
from CGRtools import smiles as smiles_cgrtools
|
| 22 |
+
from CGRtools.containers import MoleculeContainer
|
| 23 |
+
from CGRtools.containers import ReactionContainer
|
| 24 |
+
from CGRtools.containers import ReactionContainer as ReactionContainerCGRTools
|
| 25 |
+
from chython import ReactionContainer as ReactionContainerChython
|
| 26 |
+
from chython import smiles as smiles_chython
|
| 27 |
+
from tqdm.auto import tqdm
|
| 28 |
+
|
| 29 |
+
from synplan.chem.utils import unite_molecules
|
| 30 |
+
from synplan.utils.config import ConfigABC
|
| 31 |
+
from synplan.utils.files import ReactionReader, ReactionWriter
|
| 32 |
+
from synplan.utils.logging import init_logger, init_ray_logging
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger("synplan.chem.data.standardizing")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class StandardizationError(RuntimeError):
    """Error raised when a standardization stage fails.

    Stores the stage name, the textual form of the offending reaction and the
    underlying exception so callers can inspect the root cause.
    """

    def __init__(self, stage: str, reaction: str, original: Exception):
        self.stage = stage
        self.reaction = reaction
        self.original = original
        message = f"{stage} failed on {reaction}: {original}"
        super().__init__(message)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class BaseStandardizer(ABC):
    """Template: subclasses override `_run` only.

    `__call__` wraps `_run` so every concrete standardizer gets uniform error
    handling: any failure is logged and re-raised as StandardizationError with
    the original exception chained as its cause.
    """

    @classmethod
    def from_config(cls, _cfg: object) -> "BaseStandardizer":
        """Builds the standardizer; the configuration object is ignored by default."""
        return cls()

    @abstractmethod
    def _run(self, rxn: "ReactionContainer") -> "ReactionContainer":
        """Run the standardization step on the reaction.

        Args:
            rxn: The reaction to standardize

        Returns:
            The standardized reaction

        Raises:
            StandardizationError: If standardization fails
        """
        ...

    def __call__(self, rxn: "ReactionContainer") -> "ReactionContainer":
        """Execute the standardization step with proper error handling.

        Args:
            rxn: The reaction to standardize

        Returns:
            The standardized reaction

        Raises:
            StandardizationError: If standardization fails
        """
        try:
            return self._run(rxn)
        except Exception as exc:
            logging.debug("%s: %s", self.__class__.__name__, exc, exc_info=True)
            # chain the original exception so tracebacks show the root cause
            raise StandardizationError(
                self.__class__.__name__, str(rxn), exc
            ) from exc
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Configuration classes
|
| 89 |
+
@dataclass
class ReactionMappingConfig:
    """Parameter-free marker config enabling the atom-to-atom mapping step."""
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ReactionMappingStandardizer(BaseStandardizer):
    """Maps atoms of the reaction using chython (chytorch)."""

    def _map_and_remove_reagents(
        self, reaction: ReactionContainerChython
    ) -> ReactionContainerChython:
        """Map and remove reagents from the reaction (mutates in place).

        Args:
            reaction: Input chython reaction

        Returns:
            The same reaction, mapped and with reagents removed
        """
        reaction.reset_mapping()
        reaction.remove_reagents()
        return reaction

    def _run(self, rxn: ReactionContainerCGRTools) -> ReactionContainerCGRTools:
        """Map atoms of the reaction using chython.

        Args:
            rxn: Input reaction (CGRtools container, or a reaction SMILES string)

        Returns:
            The mapped reaction as a CGRtools container

        Raises:
            StandardizationError: If mapping fails
        """
        try:
            # Convert to chython format
            if isinstance(rxn, str):
                chython_reaction = smiles_chython(rxn)
            else:
                # Convert CGRtools reaction to SMILES string, preserving reagents
                reactants = ".".join(str(m) for m in rxn.reactants)
                reagents = ".".join(str(m) for m in rxn.reagents)
                products = ".".join(str(m) for m in rxn.products)
                smiles = f"{reactants}>{reagents}>{products}"
                # Parse SMILES string with chython
                chython_reaction = smiles_chython(smiles)

            # Map and remove reagents (modifies chython_reaction in place)
            reaction_mapped = self._map_and_remove_reagents(chython_reaction)
            if not reaction_mapped:
                raise StandardizationError(
                    "ReactionMapping", str(rxn), ValueError("Mapping failed")
                )

            # Convert back to CGRtools format; "m" renders the atom mapping
            mapped_smiles = format(reaction_mapped, "m")
            result = smiles_cgrtools(mapped_smiles)
            result.meta.update(rxn.meta)  # Preserve metadata
            return result
        except StandardizationError:
            # Keep the specific error raised above; do not re-wrap it in a
            # second StandardizationError (the previous code did).
            raise
        except Exception as e:
            # Chain the original exception so the traceback survives.
            raise StandardizationError("ReactionMapping", str(rxn), e) from e
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@dataclass
class FunctionalGroupsConfig:
    """Parameter-free marker config enabling functional-group standardization."""
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class FunctionalGroupsStandardizer(BaseStandardizer):
    """Functional groups standardization."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Normalize functional groups via CGRtools' ``standardize`` and return
        the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with standardized functional groups

        Raises:
            StandardizationError: If standardization fails
        """
        reaction = rxn
        reaction.standardize()
        return reaction
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@dataclass
class KekuleFormConfig:
    """Parameter-free marker config enabling the kekulization step."""
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class KekuleFormStandardizer(BaseStandardizer):
    """Reactants/reagents/products kekulization."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Convert the reaction to Kekule form and return it.

        Args:
            rxn: The reaction to kekulize

        Returns:
            The kekulized reaction

        Raises:
            StandardizationError: If kekulization fails
        """
        reaction = rxn
        reaction.kekule()
        return reaction
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dataclass
class CheckValenceConfig:
    """Parameter-free marker config enabling the valence check step."""
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class CheckValenceStandardizer(BaseStandardizer):
    """Check valence."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Verify atom valences across every molecule of the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The unmodified reaction when all valences are correct

        Raises:
            StandardizationError: If any molecule reports valence mistakes
        """
        all_molecules = rxn.reactants + rxn.products + rxn.reagents
        for mol in all_molecules:
            mistakes = mol.check_valence()
            if mistakes:
                raise StandardizationError(
                    "CheckValence",
                    str(rxn),
                    ValueError(f"Valence errors: {mistakes}"),
                )
        return rxn
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@dataclass
class ImplicifyHydrogensConfig:
    """Parameter-free marker config enabling the hydrogen-implicification step."""
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class ImplicifyHydrogensStandardizer(BaseStandardizer):
    """Implicify hydrogens."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Call ``implicify_hydrogens`` on the reaction and return it.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with implicified hydrogens

        Raises:
            StandardizationError: If hydrogen implicification fails
        """
        reaction = rxn
        reaction.implicify_hydrogens()
        return reaction
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@dataclass
class CheckIsotopesConfig:
    """Parameter-free marker config enabling the isotope check/clean step."""
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class CheckIsotopesStandardizer(BaseStandardizer):
    """Check isotopes."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Strip isotope labels when any atom of the reaction carries one.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with isotopes cleaned (only touched if needed)

        Raises:
            StandardizationError: If isotope check/cleaning fails
        """
        # Short-circuits on the first isotope-labelled atom found.
        has_isotope = any(
            atom.isotope
            for molecule in rxn.reactants + rxn.products
            for _, atom in molecule.atoms()
        )
        if has_isotope:
            rxn.clean_isotopes()
        return rxn
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
@dataclass
class SplitIonsConfig:
    """Parameter-free marker config enabling the ion-splitting step."""
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class SplitIonsStandardizer(BaseStandardizer):
    """Split salts/ion pairs into separate components and verify the result
    keeps the reaction charge-balanced."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Split ions in the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with split ions

        Raises:
            StandardizationError: If ion splitting leaves the reaction imbalanced
        """
        reaction, return_code = self._split_ions(rxn)
        if return_code == 2:  # ions were split but the reaction is imbalanced
            raise StandardizationError(
                "SplitIons",
                str(rxn),
                ValueError("Reaction is imbalanced after ion splitting"),
            )
        return reaction

    def _calc_charge(self, molecule: MoleculeContainer) -> int:
        """Compute total charge of a molecule.

        Args:
            molecule: Input molecule

        Returns:
            The total charge of the molecule
        """
        # NOTE: reaches into CGRtools' private `_charges` mapping (atom -> charge).
        return sum(molecule._charges.values())

    def _split_ions(self, reaction: ReactionContainer) -> Tuple[ReactionContainer, int]:
        """Split ions in a reaction.

        Args:
            reaction: Input reaction

        Returns:
            A tuple containing:
            - The reaction with split ions
            - Return code (0: nothing changed, 1: ions split, 2: ions split but imbalanced)
        """
        meta = reaction.meta
        reaction_parts = []
        return_codes = []

        # Each of the three reaction parts is processed and scored independently.
        for molecules in (reaction.reactants, reaction.reagents, reaction.products):
            # Split molecules into individual components
            divided_molecules = []
            for molecule in molecules:
                if isinstance(molecule, str):
                    # If it's a string, try to parse it as a molecule
                    try:
                        molecule: MoleculeContainer = smiles_cgrtools(molecule)
                    except Exception as e:
                        # Unparseable entries are skipped (best-effort behavior).
                        logging.warning("Failed to parse molecule %s: %s", molecule, e)
                        continue

                # Use the split method from CGRtools
                try:
                    components = molecule.split()
                    divided_molecules.extend(components)
                except Exception as e:
                    # Keep the molecule whole if splitting fails.
                    logging.warning("Failed to split molecule %s: %s", molecule, e)
                    divided_molecules.append(molecule)

            total_charge = 0
            ions_present = False
            for molecule in divided_molecules:
                try:
                    mol_charge = self._calc_charge(molecule)
                    total_charge += mol_charge
                    if mol_charge != 0:
                        ions_present = True
                except Exception as e:
                    logging.warning(
                        "Failed to calculate charge for molecule %s: %s", molecule, e
                    )
                    continue

            # Nonzero net charge with ions present means this part is imbalanced.
            if ions_present and total_charge:
                return_codes.append(2)
            elif ions_present:
                return_codes.append(1)
            else:
                return_codes.append(0)

            reaction_parts.append(tuple(divided_molecules))

        # The worst (highest) code across reactants/reagents/products wins.
        return (
            ReactionContainer(
                reactants=reaction_parts[0],
                reagents=reaction_parts[1],
                products=reaction_parts[2],
                meta=meta,
            ),
            max(return_codes),
        )
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
@dataclass
class AromaticFormConfig:
    """Parameter-free marker config enabling the aromatization step."""
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class AromaticFormStandardizer(BaseStandardizer):
    """Aromatize molecules in reaction."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Convert molecules to the aromatic (Thiele) form and return the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with aromatized molecules

        Raises:
            StandardizationError: If aromatization fails
        """
        reaction = rxn
        reaction.thiele()
        return reaction
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
@dataclass
class MappingFixConfig:
    """Parameter-free marker config enabling the mapping-fix step."""
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class MappingFixStandardizer(BaseStandardizer):
    """Fix atom-to-atom mapping in reaction."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Apply CGRtools' ``fix_mapping`` and return the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with fixed atom-to-atom mapping

        Raises:
            StandardizationError: If mapping fix fails
        """
        reaction = rxn
        reaction.fix_mapping()
        return reaction
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@dataclass
class UnchangedPartsConfig:
    """Parameter-free marker config enabling removal of unchanged parts."""
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
class UnchangedPartsStandardizer(BaseStandardizer):
    """Ungroup molecules, remove unchanged parts from reactants and products."""

    def __init__(
        self,
        add_reagents_to_reactants: bool = False,
        keep_reagents: bool = False,
    ):
        # If True, current reagents are merged into the reactant list first.
        self.add_reagents_to_reactants = add_reagents_to_reactants
        # If False, all reagents (including newly demoted molecules) are dropped.
        self.keep_reagents = keep_reagents

    @classmethod
    def from_config(cls, config: UnchangedPartsConfig) -> "UnchangedPartsStandardizer":
        # UnchangedPartsConfig carries no parameters; constructor defaults apply.
        return cls()

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Remove unchanged parts from the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction with unchanged parts removed

        Raises:
            StandardizationError: If no reactants/products remain afterwards
        """
        meta = rxn.meta
        new_reactants = list(rxn.reactants)
        new_reagents = list(rxn.reagents)
        if self.add_reagents_to_reactants:
            new_reactants.extend(new_reagents)
            new_reagents = []
        reactants = new_reactants.copy()
        new_products = list(rxn.products)

        # A molecule present on both sides did not react: demote it to reagent.
        # Iterating over the copy makes removal from new_reactants safe;
        # list.remove drops one occurrence per match, preserving duplicates.
        for reactant in reactants:
            if reactant in new_products:
                new_reagents.append(reactant)
                new_reactants.remove(reactant)
                new_products.remove(reactant)
        if not self.keep_reagents:
            new_reagents = []

        if not new_reactants and new_products:
            raise StandardizationError(
                "UnchangedParts", str(rxn), ValueError("No reactants left")
            )
        if not new_products and new_reactants:
            raise StandardizationError(
                "UnchangedParts", str(rxn), ValueError("No products left")
            )
        if not new_reactants and not new_products:
            raise StandardizationError(
                "UnchangedParts", str(rxn), ValueError("No molecules left")
            )

        new_reaction = ReactionContainer(
            reactants=tuple(new_reactants),
            reagents=tuple(new_reagents),
            products=tuple(new_products),
            meta=meta,
        )
        new_reaction.name = rxn.name
        return new_reaction
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
@dataclass
class SmallMoleculesConfig:
    """Configuration for the small-molecule removal step."""

    # Molecules with at most this many heavy atoms are treated as "small".
    mol_max_size: int = 6

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "SmallMoleculesConfig":
        """Create an instance of SmallMoleculesConfig from a dictionary."""
        return SmallMoleculesConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "SmallMoleculesConfig":
        """Deserialize a YAML file into a SmallMoleculesConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return SmallMoleculesConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If 'mol_max_size' is not a positive integer.
        """
        mol_max_size = params.get("mol_max_size", self.mol_max_size)
        # bool is a subclass of int, so reject it explicitly; the message now
        # matches the actual check (> 0), which the old text contradicted.
        if (
            not isinstance(mol_max_size, int)
            or isinstance(mol_max_size, bool)
            or mol_max_size <= 0
        ):
            raise ValueError("Invalid 'mol_max_size'; expected a positive integer")
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class SmallMoleculesStandardizer(BaseStandardizer):
    """Remove small molecule from reaction."""

    def __init__(self, mol_max_size: int = 6):
        # Molecules with at most this many heavy atoms count as "small".
        self.mol_max_size = mol_max_size

    @classmethod
    def from_config(cls, config: SmallMoleculesConfig) -> "SmallMoleculesStandardizer":
        """Build the standardizer from its configuration object."""
        return cls(config.mol_max_size)

    def _split_molecules(
        self, molecules: Iterable, number_of_atoms: int
    ) -> Tuple[List[MoleculeContainer], List[MoleculeContainer]]:
        """Partition molecules by heavy-atom count.

        Args:
            molecules: Iterable of molecules
            number_of_atoms: Size threshold

        Returns:
            Tuple of (molecules larger than the threshold, the remaining ones)
        """
        pool = list(molecules)  # materialize once so we can scan twice
        big_molecules = [m for m in pool if len(m) > number_of_atoms]
        small_molecules = [m for m in pool if len(m) <= number_of_atoms]
        return big_molecules, small_molecules

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Drop small molecules from the reaction, recording them in meta.

        Args:
            rxn: Input reaction

        Returns:
            The reaction without small molecules

        Raises:
            StandardizationError: If nothing remains on either side
        """
        kept_reactants, small_reactants = self._split_molecules(
            rxn.reactants, self.mol_max_size
        )
        kept_products, small_products = self._split_molecules(
            rxn.products, self.mol_max_size
        )

        if not (kept_reactants and kept_products):
            raise StandardizationError(
                "SmallMolecules",
                str(rxn),
                ValueError("No molecules left after removing small ones"),
            )

        new_reaction = ReactionContainer(
            kept_reactants, kept_products, rxn.reagents, rxn.meta
        )
        new_reaction.name = rxn.name

        # Save small molecules to meta for traceability.
        new_reaction.meta["small_reactants"] = str(unite_molecules(small_reactants))
        new_reaction.meta["small_products"] = str(unite_molecules(small_products))

        return new_reaction
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
@dataclass
class RemoveReagentsConfig:
    """Configuration for the reagent removal step."""

    # Reagents larger than this heavy-atom count are dropped entirely.
    reagent_max_size: int = 7

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "RemoveReagentsConfig":
        """Create an instance of RemoveReagentsConfig from a dictionary."""
        return RemoveReagentsConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "RemoveReagentsConfig":
        """Deserialize a YAML file into a RemoveReagentsConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return RemoveReagentsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Raises:
            ValueError: If 'reagent_max_size' is not a positive integer.
        """
        reagent_max_size = params.get("reagent_max_size", self.reagent_max_size)
        # bool is a subclass of int, so reject it explicitly; the message now
        # matches the actual check (> 0), which the old text contradicted.
        if (
            not isinstance(reagent_max_size, int)
            or isinstance(reagent_max_size, bool)
            or reagent_max_size <= 0
        ):
            raise ValueError(
                "Invalid 'reagent_max_size'; expected a positive integer"
            )
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
class RemoveReagentsStandardizer(BaseStandardizer):
    """Remove reagents from reaction."""

    def __init__(self, reagent_max_size: int = 7):
        # Reagents larger than this heavy-atom count are discarded entirely.
        self.reagent_max_size = reagent_max_size

    @classmethod
    def from_config(cls, config: RemoveReagentsConfig) -> "RemoveReagentsStandardizer":
        return cls(config.reagent_max_size)

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Remove reagents from the reaction.

        Args:
            rxn: Input reaction

        Returns:
            The reaction without reagents

        Raises:
            StandardizationError: If no reactants or products remain
        """
        # Molecules appearing unchanged on both sides are reagent candidates.
        not_changed_molecules = set(rxn.reactants).intersection(rxn.products)
        # ~rxn builds the condensed graph of the reaction (CGRtools operator);
        # its center_atoms are assumed to mark where bonds/charges change —
        # NOTE(review): confirm against CGRtools CGR docs.
        cgr = ~rxn
        center_atoms = set(cgr.center_atoms)

        new_reactants = []
        new_products = []
        new_reagents = []

        # A molecule sharing no atoms with the reaction center (or unchanged
        # on both sides) is demoted to reagent.
        for molecule in rxn.reactants:
            if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
                new_reagents.append(molecule)
            else:
                new_reactants.append(molecule)

        for molecule in rxn.products:
            if center_atoms.isdisjoint(molecule) or molecule in not_changed_molecules:
                new_reagents.append(molecule)
            else:
                new_products.append(molecule)

        if not new_reactants or not new_products:
            raise StandardizationError(
                "RemoveReagents",
                str(rxn),
                ValueError("No molecules left after removing reagents"),
            )

        # Filter reagents by size
        # NOTE: this set comprehension also deduplicates reagents and drops order.
        new_reagents = {
            molecule
            for molecule in new_reagents
            if len(molecule) <= self.reagent_max_size
        }

        new_reaction = ReactionContainer(
            new_reactants, new_products, new_reagents, rxn.meta
        )
        new_reaction.name = rxn.name

        return new_reaction
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
@dataclass
class RebalanceReactionConfig:
    """Parameter-free marker config enabling the reaction rebalancing step."""
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class RebalanceReactionStandardizer(BaseStandardizer):
    """Rebalance reaction."""

    @classmethod
    def from_config(
        cls, config: RebalanceReactionConfig
    ) -> "RebalanceReactionStandardizer":
        """Build the standardizer from its (parameter-free) configuration."""
        return cls()

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Rebalances the reaction by assembling CGR and then decomposing it. Works for
        all reactions for which the correct CGR can be assembled.

        Args:
            rxn: Input reaction

        Returns:
            The rebalanced reaction

        Raises:
            StandardizationError: If rebalancing fails
        """
        try:
            # Compose reactants/products into a CGR and decompose it back;
            # this re-derives balanced reactant/product sets.
            tmp_rxn = ReactionContainer(rxn.reactants, rxn.products)
            cgr = ~tmp_rxn
            reactants, products = ~cgr
            new_rxn = ReactionContainer(
                reactants.split(), products.split(), rxn.reagents, rxn.meta
            )
            new_rxn.name = rxn.name
            return new_rxn
        except Exception as e:
            # Keep the real cause (the old code discarded it in favour of a
            # generic ValueError) and chain it; lazy %-style logging.
            logging.debug("Rebalancing attempt failed: %s", e)
            raise StandardizationError("RebalanceReaction", str(rxn), e) from e
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
@dataclass
class DuplicateReactionConfig:
    """Parameter-free marker config enabling duplicate-reaction removal."""
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
class DuplicateReactionStandardizer(BaseStandardizer):
    """Cluster-wide duplicate removal via a Ray actor."""

    def __init__(self, dedup_actor: "ray.actor.ActorHandle"):
        self._actor = dedup_actor  # global singleton handle
        # local fast-path cache to avoid actor call on obvious repeats *in
        # the same worker*; purely an optimisation, not required.
        self._local_seen: set[int] = set()

    @classmethod
    def from_config(cls, config: DuplicateReactionConfig):
        """Resolve the shared dedup actor, or fall back to local-only mode."""
        if ray.is_initialized():
            dedup_actor = ray.get_actor("duplicate_rxn_actor")
        else:
            dedup_actor = None
        return cls(dedup_actor)

    # ------------------------------------------------------------------
    def safe_reaction_smiles(self, reaction: ReactionContainer) -> str:
        """'reactants>>products' SMILES key (reagents deliberately ignored)."""
        reactants_smi = ".".join(str(i) for i in reaction.reactants)
        products_smi = ".".join(str(i) for i in reaction.products)
        return f"{reactants_smi}>>{products_smi}"

    @staticmethod
    def _stable_hash(smiles: str) -> int:
        """Process-independent hash of a reaction SMILES.

        Builtin ``hash(str)`` is randomized per interpreter (PYTHONHASHSEED),
        so hashes from different Ray workers are not comparable — which
        silently broke cluster-wide deduplication. SHA-256 is stable.
        """
        import hashlib

        digest = hashlib.sha256(smiles.encode("utf-8")).digest()
        return int.from_bytes(digest[:16], "big")

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        h = self._stable_hash(self.safe_reaction_smiles(rxn))

        # local cache fast-path (helps in large batches processed by same
        # worker; no correctness impact).
        if h in self._local_seen:
            raise StandardizationError(
                "DuplicateReaction", str(rxn), ValueError("Duplicate reaction found")
            )

        # ------------------- cluster-wide check ------------------------
        if self._actor is None:  # single-CPU fall-back
            is_new = True  # already known not to be in the local cache
        else:
            # synchronous, returns True/False
            is_new = ray.get(self._actor.check_and_add.remote(h))

        if is_new:
            self._local_seen.add(h)
            return rxn

        raise StandardizationError(
            "DuplicateReaction", str(rxn), ValueError("Duplicate reaction found")
        )
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
@ray.remote
class DedupActor:
    """Cluster-wide set of reaction hashes."""

    def __init__(self):
        # All hashes ever registered with this actor.
        self._seen: set[int] = set()

    def check_and_add(self, h: int) -> bool:
        """Return True iff *h* was not present yet (it is now stored).

        Cluster-wide uniqueness holds because actor methods execute
        serially inside the actor process.
        """
        size_before = len(self._seen)
        self._seen.add(h)
        # The set grew only if the hash was genuinely new.
        return len(self._seen) > size_before
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
# Registry mapping config field names to standardizer classes
# NOTE: insertion order matters — ReactionStandardizationConfig.create_standardizers
# iterates this dict, so enabled steps run in exactly this order.
STANDARDIZER_REGISTRY = {
    "reaction_mapping_config": ReactionMappingStandardizer,
    "functional_groups_config": FunctionalGroupsStandardizer,
    "kekule_form_config": KekuleFormStandardizer,
    "check_valence_config": CheckValenceStandardizer,
    "implicify_hydrogens_config": ImplicifyHydrogensStandardizer,
    "check_isotopes_config": CheckIsotopesStandardizer,
    "split_ions_config": SplitIonsStandardizer,
    "aromatic_form_config": AromaticFormStandardizer,
    "mapping_fix_config": MappingFixStandardizer,
    "unchanged_parts_config": UnchangedPartsStandardizer,
    "small_molecules_config": SmallMoleculesStandardizer,
    "remove_reagents_config": RemoveReagentsStandardizer,
    "rebalance_reaction_config": RebalanceReactionStandardizer,
    "duplicate_reaction_config": DuplicateReactionStandardizer,
}
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
@dataclass
class ReactionStandardizationConfig(ConfigABC):
    """Configuration for the reaction standardization pipeline. Each non-None
    field enables the corresponding standardizer from STANDARDIZER_REGISTRY;
    registry order defines the order in which the steps run.

    :param reaction_mapping_config: Configuration for reaction mapping.
    :param functional_groups_config: Configuration for functional groups
        standardization.
    :param kekule_form_config: Configuration for reactants/reagents/products
        kekulization.
    :param check_valence_config: Configuration for atom valence checking.
    :param implicify_hydrogens_config: Configuration for hydrogens removal.
    :param check_isotopes_config: Configuration for isotopes checking and cleaning.
    :param split_ions_config: Configuration for computing charge of molecule.
    :param aromatic_form_config: Configuration for molecules aromatization.
    :param mapping_fix_config: Configuration for atom-to-atom mapping fixing.
    :param unchanged_parts_config: Configuration for removal of unchanged parts in
        reaction.
    :param small_molecules_config: Configuration for removal of small molecule from
        reaction.
    :param remove_reagents_config: Configuration for removal of reagents from reaction.
    :param rebalance_reaction_config: Configuration for reaction rebalancing.
    :param duplicate_reaction_config: Configuration for removal of duplicate reactions.
    """

    # configuration for reaction standardizers
    reaction_mapping_config: Optional[ReactionMappingConfig] = None
    functional_groups_config: Optional[FunctionalGroupsConfig] = None
    kekule_form_config: Optional[KekuleFormConfig] = None
    check_valence_config: Optional[CheckValenceConfig] = None
    implicify_hydrogens_config: Optional[ImplicifyHydrogensConfig] = None
    check_isotopes_config: Optional[CheckIsotopesConfig] = None
    split_ions_config: Optional[SplitIonsConfig] = None
    aromatic_form_config: Optional[AromaticFormConfig] = None
    mapping_fix_config: Optional[MappingFixConfig] = None
    unchanged_parts_config: Optional[UnchangedPartsConfig] = None
    small_molecules_config: Optional[SmallMoleculesConfig] = None
    remove_reagents_config: Optional[RemoveReagentsConfig] = None
    rebalance_reaction_config: Optional[RebalanceReactionConfig] = None
    duplicate_reaction_config: Optional[DuplicateReactionConfig] = None

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Delegate validation to every enabled sub-configuration."""
        for field_name, config in self.__dict__.items():
            if config is not None and hasattr(config, "_validate_params"):
                config._validate_params(params.get(field_name, {}))

    def to_dict(self):
        """Converts the configuration into a dictionary.

        Enabled steps are marked with an empty mapping; parameter values are
        not serialized here.
        """
        config_dict = {}
        for field_name in STANDARDIZER_REGISTRY:
            config = getattr(self, field_name)
            if config is not None:
                config_dict[field_name] = {}
        return config_dict

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ReactionStandardizationConfig":
        """Create an instance of ReactionStandardizationConfig from a dictionary.

        Fix: the previous implementation *called the name string* of the config
        class (``"XConfig"()``), which raises TypeError; resolve the class
        object instead and forward any nested parameters.
        """
        config_kwargs = {}
        for field_name, std_cls in STANDARDIZER_REGISTRY.items():
            if field_name in config_dict:
                cfg_cls = globals()[std_cls.__name__.replace("Standardizer", "Config")]
                params = config_dict[field_name] or {}
                config_kwargs[field_name] = (
                    cfg_cls(**params) if isinstance(params, dict) else cfg_cls()
                )
        return ReactionStandardizationConfig(**config_kwargs)

    @staticmethod
    def from_yaml(file_path: str) -> "ReactionStandardizationConfig":
        """Deserializes a YAML file into a ReactionStandardizationConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return ReactionStandardizationConfig.from_dict(config_dict)

    def create_standardizers(self):
        """Create standardizer instances based on configuration, in registry order."""
        standardizers = []
        for field_name, std_cls in STANDARDIZER_REGISTRY.items():
            config = getattr(self, field_name)
            if config is not None:
                standardizers.append(std_cls.from_config(config))
        return standardizers
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
def standardize_reaction(
    reaction: ReactionContainer,
    standardizers: Sequence,
) -> ReactionContainer | None:
    """Run *reaction* through every standardizer in sequence.

    Returns
    -------
    ReactionContainer | None
        The fully standardised reaction, or None as soon as any
        standardizer soft-filters the reaction out.

    Raises
    ------
    StandardizationError
        Re-raised unchanged so the caller can decide what to do.
    """
    current = reaction
    for standardizer in standardizers:
        std_name = standardizer.__class__.__name__
        logger.debug(" › %s(%s)", std_name, current)
        try:
            result = standardizer(current)  # may return None
        except StandardizationError as exc:
            # Log once, then re-raise the same object with its traceback intact.
            logger.warning("%s failed on reaction %s : %s", std_name, current, exc)
            raise
        if result is None:  # soft filter
            logger.info("%s filtered out reaction", std_name)
            return None
        current = result
    return current
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def safe_standardize(
    item: str | ReactionContainer,
    standardizers: Sequence,
) -> Tuple[ReactionContainer, bool]:
    """Standardize *item* without letting standardization errors escape.

    :param item: A reaction SMILES string or an already parsed
        ReactionContainer.
    :param standardizers: Sequence of standardizer callables applied in order.
    :return: ``(reaction, success)`` — the standardised reaction and True on
        success; on failure or soft-filtering, the original (parsed) reaction
        and False.

    Note: if *item* is a string that cannot be parsed at all, the fallback
    re-parse in the error path raises as well; only standardization errors
    (not parse errors) are guaranteed to be absorbed.
    """
    try:
        # Parse only if needed
        reaction = (
            item if isinstance(item, ReactionContainer) else smiles_cgrtools(item)
        )
        std = standardize_reaction(reaction, standardizers)
        if std is None:
            return reaction, False  # filtered → keep original
        return std, True
    except Exception as exc:  # noqa: BLE001
        # Previously the exception was swallowed silently (`exc` unused);
        # record it so failures are diagnosable from the logs.
        logger.warning("safe_standardize failed on %s: %s", item, exc)
        # keep the original container (parse if it was a string)
        if isinstance(item, ReactionContainer):
            return item, False
        return smiles_cgrtools(item), False
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
def _process_batch(
    batch: Sequence[str | ReactionContainer],
    standardizers: Sequence,
) -> Tuple[List[ReactionContainer], int]:
    """Standardize a batch of reactions serially.

    :param batch: Items to process (SMILES strings or ReactionContainers).
    :param standardizers: Standardizer callables applied to each item.
    :return: ``(reactions, n_success)`` — one reaction per input item and
        the count of items that were fully standardised.
    """
    processed: List[ReactionContainer] = []
    success_count = 0
    for entry in batch:
        reaction, succeeded = safe_standardize(entry, standardizers)
        processed.append(reaction)
        if succeeded:
            success_count += 1
    return processed, success_count
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
@ray.remote
def process_batch_remote(
    batch: Sequence[str | ReactionContainer],
    std_param: ray.ObjectRef,  # <-- receives a ref
    log_file_path: str | Path | None = None,
) -> Tuple[List[ReactionContainer], int]:
    """Ray task: standardize one batch of reactions inside a worker process.

    :param batch: Items to process (SMILES strings or ReactionContainers).
    :param std_param: Either a ``ray.ObjectRef`` to the broadcast list of
        standardizers (preferred — fetched once per worker) or the plain
        list itself.
    :param log_file_path: Optional file the worker appends its log records to.
    :return: ``(reactions, n_success)`` as produced by ``_process_batch``.
    """
    # Ray keeps a local cache of fetched objects, so the list is
    # deserialised only once per worker process, not once per task.
    if isinstance(std_param, ray.ObjectRef):  # handle? get it
        standardizers = ray.get(std_param)  # • O(once)
    else:  # plain list? use as is
        standardizers = std_param

    # --- Worker-specific logging setup ---
    worker_logger = logging.getLogger("synplan.chem.data.standardizing")
    if log_file_path:
        log_file_path = Path(log_file_path)  # Ensure it's a Path object
        # Check if a handler for this file already exists for this logger
        # (workers are reused across tasks, so avoid stacking duplicates).
        handler_exists = any(
            isinstance(h, logging.FileHandler) and Path(h.baseFilename) == log_file_path
            for h in worker_logger.handlers
        )
        if not handler_exists:
            try:
                fh = logging.FileHandler(log_file_path, encoding="utf-8")
                # Use a simple format for worker logs, or match driver's format
                formatter = logging.Formatter(
                    "%(asctime)s | %(name)s (worker) | %(levelname)-8s | %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
                fh.setFormatter(formatter)
                fh.setLevel(logging.INFO)  # Or DEBUG, or use worker_log_level if passed
                worker_logger.addHandler(fh)
                worker_logger.setLevel(
                    logging.INFO
                )  # Ensure logger passes messages to handler
                worker_logger.propagate = (
                    False  # Avoid double logging if driver also logs
                )
                # Optional: Log that the handler was added
                # worker_logger.info(f"Worker process attached file handler: {log_file_path}")
            except Exception as e:
                # Log error if handler creation fails (e.g., permissions)
                logging.error(
                    f"Worker failed to create file handler {log_file_path}: {e}"
                )

    return _process_batch(batch, standardizers)
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
def chunked(iterable: Iterable, size: int):
    """Yield successive lists of at most *size* items from *iterable*.

    The final chunk may be shorter than *size* when the item count is not
    an exact multiple.

    :param iterable: Source of items; consumed lazily.
    :param size: Maximum chunk length; must be a positive integer.
    :raises ValueError: If *size* is smaller than 1 (the previous
        implementation silently emitted everything as one unbounded chunk).
    """
    if size < 1:
        raise ValueError(f"chunk size must be >= 1, got {size}")
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def standardize_reactions_from_file(
    config: "ReactionStandardizationConfig",
    input_reaction_data_path: str | Path,
    standardized_reaction_data_path: str | Path = "reaction_data_standardized.smi",
    *,
    num_cpus: int = 1,
    batch_size: int = 1_000,  # larger batches amortise overhead
    silent: bool = True,
    max_pending_factor: int = 4,  # tasks in flight = factor × CPUs
    worker_log_level: int | str = logging.WARNING,
    log_file_path: str | Path | None = None,
) -> None:
    """
    Reads reactions, standardises them in parallel with Ray, writes results.

    The function keeps at most `max_pending_factor * num_cpus` Ray tasks in
    flight to avoid flooding the scheduler and blowing up the object store.
    Standardisers are broadcast once with `ray.put`, removing per‑task
    pickling cost. All other logic is unchanged.

    Args:
        config: Configuration object for standardizers.
        input_reaction_data_path: Path to the input reaction data file.
        standardized_reaction_data_path: Path to save the standardized reactions.
        num_cpus: Number of CPU cores to use for parallel processing.
        batch_size: Number of reactions to process in each batch.
        silent: If True, suppress the progress bar.
        max_pending_factor: Controls the number of pending Ray tasks.
        worker_log_level: Logging level for Ray workers (e.g., logging.INFO, logging.WARNING).
        log_file_path: Path to the log file for workers to write to.
    """
    output_path = Path(standardized_reaction_data_path)
    standardizers = config.create_standardizers()

    logger.info(
        "Standardizers: %s",
        ", ".join(s.__class__.__name__ for s in standardizers),
    )

    # ----------------------- Ray initialisation -----------------------
    if num_cpus > 1:
        if not ray.is_initialized():
            ray.init(
                num_cpus=num_cpus,
                ignore_reinit_error=True,
                logging_level=worker_log_level,
                log_to_driver=False,
            )

        DEDUP_NAME = "duplicate_rxn_actor"

        # NOTE(review): dedup_actor is created (or looked up) but never used
        # in the rest of this function — confirm whether deduplication is
        # meant to happen here or in a downstream step.
        try:
            dedup_actor = ray.get_actor(DEDUP_NAME)  # already running?
        except ValueError:
            dedup_actor = DedupActor.options(
                name=DEDUP_NAME, lifetime="detached"  # survives driver exit
            ).remote()

    std_ref: ray.ObjectRef | None = None
    if num_cpus > 1 and std_ref is None:  # broadcast once
        std_ref = ray.put(standardizers)

    max_pending = max_pending_factor * num_cpus
    pending: Dict[ray.ObjectRef, None] = {}  # insertion-ordered set of in-flight refs

    n_processed = n_std = 0
    bar = tqdm(
        total=0,
        unit="rxn",
        desc="Standardising",
        disable=silent,
        dynamic_ncols=True,
    )

    # ------------------------ Helper function ------------------------
    def _flush(ref: ray.ObjectRef, write_fn) -> None:
        """Fetch finished task, write its results, update counters & bar."""
        nonlocal n_processed, n_std
        res, ok = ray.get(ref)
        write_fn(res)
        bar.update(len(res))
        n_processed += len(res)
        n_std += ok

    # ----------------------------- I/O -------------------------------
    with ReactionReader(input_reaction_data_path) as reader, ReactionWriter(
        output_path
    ) as writer:

        write_fn = lambda reactions: [writer.write(r) for r in reactions]

        # --------------------- Main read/compute loop -----------------
        for chunk in chunked(reader, batch_size):
            bar.total += len(chunk)
            bar.refresh()

            if num_cpus > 1:
                # ---------- back‑pressure: keep ≤ max_pending ----------
                while len(pending) >= max_pending:
                    done, _ = ray.wait(list(pending), num_returns=1)
                    _flush(done[0], write_fn)
                    pending.pop(done[0], None)

                # ----------- schedule new task -------------------------
                ref = process_batch_remote.remote(chunk, std_ref, log_file_path)
                pending[ref] = None
            else:
                # --------------- serial fall‑back ----------------------
                res, ok = _process_batch(chunk, standardizers)
                write_fn(res)
                bar.update(len(res))
                n_processed += len(res)
                n_std += ok

        # ------------------ Drain remaining Ray tasks -----------------
        # (pending is empty in the serial path, so this loop is a no-op there.)
        while pending:
            done, _ = ray.wait(list(pending), num_returns=1)
            _flush(done[0], write_fn)
            pending.pop(done[0], None)

    bar.close()
    # NOTE(review): ray.shutdown() is called even when num_cpus == 1 (Ray was
    # never initialised here) and will also tear down a cluster a caller may
    # have initialised — confirm this is intended.
    ray.shutdown()

    logger.info(
        "Finished: processed %d, standardised %d, filtered %d",
        n_processed,
        n_std,
        n_processed - n_std,
    )
|
synplan/chem/precursor.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing a class Precursor that represents a precursor (extend molecule object) in
|
| 2 |
+
the search tree."""
|
| 3 |
+
|
| 4 |
+
from typing import Set
|
| 5 |
+
|
| 6 |
+
from CGRtools.containers import MoleculeContainer
|
| 7 |
+
|
| 8 |
+
from synplan.chem.utils import safe_canonicalization
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Precursor:
    """Precursor class is used to extend the molecule behavior needed for
    interaction with a tree in MCTS."""

    def __init__(self, molecule: MoleculeContainer, canonicalize: bool = True):
        """It initializes a Precursor object with a molecule container as a parameter.

        :param molecule: A molecule.
        :param canonicalize: If True (default), store the safely
            canonicalized form of the molecule.
        """
        self.molecule = safe_canonicalization(molecule) if canonicalize else molecule
        # Precursors preceding this one in the search tree (filled externally).
        self.prev_precursors = []

    def __len__(self) -> int:
        """Return the number of atoms in Precursor."""
        return len(self.molecule)

    def __hash__(self) -> int:
        """Returns the hash value of Precursor.

        The return annotation was previously ``-> hash`` (the builtin
        function, not a type); ``int`` is the correct annotation.
        """
        return hash(self.molecule)

    def __str__(self) -> str:
        """Returns a SMILES of the Precursor."""
        return str(self.molecule)

    def __eq__(self, other: "Precursor") -> bool:
        """Checks if the current Precursor is equal to another Precursor."""
        return self.molecule == other.molecule

    def __repr__(self) -> str:
        """Returns a SMILES of the Precursor."""
        return str(self.molecule)

    def is_building_block(self, bb_stock: Set[str], min_mol_size: int = 6) -> bool:
        """Checks if a Precursor is a building block.

        :param bb_stock: The list of building blocks. Each building block is represented
            by a canonical SMILES.
        :param min_mol_size: If the size of the Precursor is equal or smaller than
            min_mol_size it is automatically classified as building block.
        :return: True is Precursor is a building block.
        """
        if len(self.molecule) <= min_mol_size:
            return True

        return str(self.molecule) in bb_stock
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def compose_precursors(
    precursors: list = None, exclude_small: bool = True, min_mol_size: int = 6
) -> MoleculeContainer:
    """
    Takes a list of precursors, excludes small precursors if specified, and composes
    them into a single molecule. The composed molecule then is used for the prediction
    of synthesisability of the route including the nodes with the given precursor.

    :param precursors: The list of precursors to be composed.
    :param exclude_small: If True, only precursors with more than min_mol_size
        atoms are composed (unless that would exclude everything).
    :param min_mol_size: The size threshold used with exclude_small.
    :return: A composed precursor as a MoleculeContainer object, or None when
        no precursors are given.
    """
    # Previously a None argument crashed in len() and an empty list fell
    # through to an implicit None; make both cases return None explicitly.
    if not precursors:
        return None

    if len(precursors) == 1:
        return precursors[0].molecule

    if exclude_small:
        big_precursors = [
            precursor
            for precursor in precursors
            if len(precursor.molecule) > min_mol_size
        ]
        # Keep the original list if filtering would drop every precursor.
        if big_precursors:
            precursors = big_precursors

    tmp_mol = precursors[0].molecule.copy()
    transition_mapping = {}
    for mol in precursors[1:]:
        # NOTE(review): add_atom(atom.atomic_symbol) copies only the element
        # symbol — charge/isotope information (if any) is not transferred;
        # confirm this is acceptable for the downstream prediction.
        for n, atom in mol.molecule.atoms():
            new_number = tmp_mol.add_atom(atom.atomic_symbol)
            transition_mapping[n] = new_number
        for atom, neighbor, bond in mol.molecule.bonds():
            tmp_mol.add_bond(
                transition_mapping[atom], transition_mapping[neighbor], bond
            )
        transition_mapping = {}

    return tmp_mol
|
synplan/chem/reaction.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing classes and functions for manipulating reactions and reaction
|
| 2 |
+
rules."""
|
| 3 |
+
|
| 4 |
+
from typing import Any, Iterator, List, Optional
|
| 5 |
+
|
| 6 |
+
from CGRtools.containers import MoleculeContainer, ReactionContainer
|
| 7 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
| 8 |
+
from CGRtools.reactor import Reactor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Reaction(ReactionContainer):
    """General-purpose reaction representation.

    A thin subclass of ``ReactionContainer`` kept as an extension point;
    it adds no state or behavior of its own.
    """

    def __init__(self, *args, **kwargs):
        # Delegate construction entirely to ReactionContainer.
        super().__init__(*args, **kwargs)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def add_small_mols(
    big_mol: MoleculeContainer, small_molecules: Optional[Any] = None
) -> List[MoleculeContainer]:
    """Graft every small molecule onto a copy of *big_mol* and split the result.

    :param big_mol: The main molecule.
    :param small_molecules: Iterable of small molecules to add; a falsy value
        means there is nothing to add.
    :return: The split components of the combined graph, or ``[big_mol]``
        unchanged when no small molecules are given.
    """
    if not small_molecules:
        return [big_mol]

    combined = big_mol.copy()
    for small_mol in small_molecules:
        # Map each small-molecule atom number to its number in the
        # combined graph; the mapping is rebuilt per molecule.
        atom_map = {}
        for number, atom in small_mol.atoms():
            atom_map[number] = combined.add_atom(atom.atomic_symbol)
        for source, target, bond in small_mol.bonds():
            combined.add_bond(atom_map[source], atom_map[target], bond)

    return combined.split()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def apply_reaction_rule(
    molecule: MoleculeContainer,
    reaction_rule: Reactor,
    sort_reactions: bool = False,
    top_reactions_num: int = 3,
    validate_products: bool = True,
    rebuild_with_cgr: bool = False,
) -> Iterator[List[MoleculeContainer,]]:
    """Applies a reaction rule to a given molecule.

    :param molecule: A molecule to which reaction rule will be applied.
    :param reaction_rule: A reaction rule to be applied.
    :param sort_reactions: If True, rank all generated reactions by the number
        of "large" (>6 atom) products before taking the top N; otherwise take
        the first N as generated.
    :param top_reactions_num: The maximum amount of reactions after the application of
        reaction rule.
    :param validate_products: If True, validates the final products.
    :param rebuild_with_cgr: If True, the products are extracted from CGR decomposition.
    :return: An iterator yielding the products of reaction rule application
        (lists of molecules, or None for products that failed validation).
    """

    # small_molecules=False means "nothing to add" — reactants is [molecule].
    reactants = add_small_mols(molecule, small_molecules=False)

    try:
        if sort_reactions:
            unsorted_reactions = list(reaction_rule(reactants))
            sorted_reactions = sorted(
                unsorted_reactions,
                key=lambda react: len(
                    list(filter(lambda mol: len(mol) > 6, react.products))
                ),
                reverse=True,
            )

            # take top-N reactions from reactor
            reactions = sorted_reactions[:top_reactions_num]
        else:
            reactions = []
            for reaction in reaction_rule(reactants):
                reactions.append(reaction)
                if len(reactions) == top_reactions_num:
                    break
    except IndexError:
        # The Reactor can raise IndexError during enumeration; treat it as
        # "no applicable reactions".
        reactions = []

    for reaction in reactions:

        # temporary solution - incorrect leaving groups
        reactant_atom_nums = []
        for i in reaction.reactants:
            reactant_atom_nums.extend(i.atoms_numbers)
        product_atom_nums = []
        for i in reaction.products:
            product_atom_nums.extend(i.atoms_numbers)
        leaving_atom_nums = set(reactant_atom_nums) - set(product_atom_nums)
        # Skip reactions where the leaving group is larger than what remains.
        if len(leaving_atom_nums) > len(product_atom_nums):
            continue

        # check reaction
        if rebuild_with_cgr:
            cgr = reaction.compose()
            reactants = cgr.decompose()[1].split()
        else:
            reactants = reaction.products  # reactants are products in retro reaction
        reactants = [mol for mol in reactants if len(mol) > 0]

        # validate products
        # NOTE(review): on a validation failure this yields None but does NOT
        # skip the reaction — the same iteration still yields `reactants`
        # below (possibly after several None values). Confirm whether a
        # `continue` to the next reaction was intended instead.
        if validate_products:
            for mol in reactants:
                try:
                    mol.kekule()
                    if mol.check_valence():
                        yield None
                    mol.thiele()
                except InvalidAromaticRing:
                    yield None

        yield reactants
|
synplan/chem/reaction_routes/__init__.py
ADDED
|
File without changes
|
synplan/chem/reaction_routes/clustering.py
ADDED
|
@@ -0,0 +1,859 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import pickle
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
from CGRtools.containers import ReactionContainer, CGRContainer
|
| 8 |
+
from CGRtools.containers.bonds import DynamicBond
|
| 9 |
+
|
| 10 |
+
from synplan.chem.reaction_routes.leaving_groups import *
|
| 11 |
+
from synplan.chem.reaction_routes.visualisation import *
|
| 12 |
+
from synplan.chem.reaction_routes.route_cgr import *
|
| 13 |
+
from synplan.chem.reaction_routes.io import (
|
| 14 |
+
read_routes_csv,
|
| 15 |
+
read_routes_json,
|
| 16 |
+
make_dict,
|
| 17 |
+
make_json,
|
| 18 |
+
)
|
| 19 |
+
from synplan.utils.visualisation import (
|
| 20 |
+
routes_clustering_report,
|
| 21 |
+
routes_subclustering_report,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_cluster_cli(
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool = False,
    subcluster_results_dir: Path = None,
):
    """
    Read routes from a CSV or JSON file, perform clustering, and optionally subclustering.

    Args:
        routes_file: Path to the input routes file (.csv or .json). The file
            name must contain an index of the form ``_<digits>.`` (e.g.
            ``routes_3.csv``), which is reused in all output file names.
        cluster_results_dir: Directory where clustering results are stored.
        perform_subcluster: Whether to run subclustering on each cluster.
        subcluster_results_dir: Subdirectory for subclustering results (if enabled).

    Raises:
        ValueError: If the index cannot be extracted from the file name or
            the file extension is neither .csv nor .json.
    """
    # Imported lazily so the module does not require click unless the CLI
    # entry point is actually used.
    import click

    routes_file = Path(routes_file)
    match = re.search(r"_(\d+)\.", routes_file.name)
    if not match:
        raise ValueError(f"Could not extract index from filename: {routes_file.name}")
    file_index = int(match.group(1))
    ext = routes_file.suffix.lower()
    # Whatever the input format, build both representations: the dict form
    # (for CGR composition) and the JSON form (for the HTML reports).
    if ext == ".csv":
        routes_dict = read_routes_csv(str(routes_file))
        routes_json = make_json(routes_dict)
    elif ext == ".json":
        routes_json = read_routes_json(str(routes_file))
        routes_dict = make_dict(routes_json)
    else:
        raise ValueError(f"Unsupported file type: {ext}")

    # Compose condensed graph representations
    route_cgrs = compose_all_route_cgrs(routes_dict)
    click.echo(f"Generating RouteCGR")
    reduced_cgrs = compose_all_sb_cgrs(route_cgrs)
    click.echo(f"Generating ReducedRouteCGR")

    # Perform clustering
    click.echo(f"\nClustering")
    clusters = cluster_routes(reduced_cgrs, use_strat=False)

    click.echo(f"Total number of routes: {len(routes_dict)}")
    click.echo(f"Found number of clusters: {len(clusters)} ({list(clusters.keys())})")

    # Ensure output directory exists
    cluster_results_dir = Path(cluster_results_dir)
    cluster_results_dir.mkdir(parents=True, exist_ok=True)

    # Save clusters to pickle
    with open(cluster_results_dir / f"clusters_{file_index}.pickle", "wb") as f:
        pickle.dump(clusters, f)

    # Generate HTML reports for each cluster
    for idx in clusters:
        report_path = cluster_results_dir / f"{file_index}_cluster_{idx}.html"
        routes_clustering_report(
            routes_json, clusters, idx, reduced_cgrs, html_path=str(report_path)
        )

    # Optional subclustering (Under development)
    if perform_subcluster and subcluster_results_dir:
        click.echo("\nSubClustering")
        sub_dir = cluster_results_dir / subcluster_results_dir
        sub_dir.mkdir(parents=True, exist_ok=True)

        subclusters = subcluster_all_clusters(clusters, reduced_cgrs, route_cgrs)
        for cluster_idx, sub in subclusters.items():
            click.echo(f"Cluster {cluster_idx} has {len(sub)} subclusters")
            for sub_idx, subcluster in sub.items():
                subreport_path = (
                    sub_dir / f"{file_index}_subcluster_{cluster_idx}.{sub_idx}.html"
                )
                routes_subclustering_report(
                    routes_json,
                    subcluster,
                    cluster_idx,
                    sub_idx,
                    reduced_cgrs,
                    aam=False,
                    html_path=str(subreport_path),
                )
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def cluster_route_from_csv(routes_file: str):
    """Load retrosynthetic routes from a CSV file and cluster them.

    The routes are converted to Condensed Graphs of Reactions (RouteCGRs),
    reduced to their simplified form (ReducedRouteCGRs), and clustered on
    that reduced representation (``use_strat=False``).

    Args:
        routes_file (str): Path to the CSV file containing the route data.

    Returns:
        object: The clustering result produced by ``cluster_routes``.
    """
    parsed_routes = read_routes_csv(routes_file)
    full_cgrs = compose_all_route_cgrs(parsed_routes)
    reduced_cgrs = compose_all_sb_cgrs(full_cgrs)
    return cluster_routes(reduced_cgrs, use_strat=False)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def cluster_route_from_json(routes_file: str):
    """Load retrosynthetic routes from a JSON file and cluster them.

    Same pipeline as ``cluster_route_from_csv`` except the input is JSON:
    the parsed JSON is first converted to the internal routes dictionary,
    then RouteCGRs are composed, reduced to strategic-bond form, and
    clustered on the reduced representations.

    Args:
        routes_file: Path to the JSON file with retrosynthetic route data.

    Returns:
        The clustering result as produced by ``cluster_routes``.
    """
    raw_json = read_routes_json(routes_file)
    parsed_routes = make_dict(raw_json)
    full_cgrs = compose_all_route_cgrs(parsed_routes)
    reduced_cgrs = compose_all_sb_cgrs(full_cgrs)
    return cluster_routes(reduced_cgrs, use_strat=False)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def extract_strat_bonds(target_cgr: CGRContainer):
    """Return the sorted list of strategic bonds found in a CGR.

    A strategic bond is one that is absent from the reactants
    (``bond.order is None``) but present in the product
    (``bond.p_order is not None``). Each bond is reported exactly once as
    a tuple of its two atom indices in ascending order.

    Args:
        target_cgr: The CGRContainer to scan.

    Returns:
        Sorted list of ``(atom_a, atom_b)`` tuples with ``atom_a < atom_b``.
    """
    strategic = set()
    for first, partners in target_cgr._bonds.items():
        for second, dyn_bond in partners.items():
            # Each undirected bond appears twice in the adjacency mapping;
            # only process the canonical (low, high) direction.
            if first >= second:
                continue
            if dyn_bond.order is None and dyn_bond.p_order is not None:
                strategic.add((first, second))
    return sorted(strategic)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def cluster_routes(sb_cgrs: dict, use_strat=False):
    """Cluster route objects by strategic bonds or by CGR signature.

    Args:
        sb_cgrs: Dictionary mapping node_id to ReducedRouteCGR objects.
        use_strat: If True, group routes by the tuple of their strategic
            bonds; otherwise group by the string signature of the reduced
            CGR (atom-mapping sensitive).

    Returns:
        Dictionary keyed by ``'{n_strat_bonds}.{index}'`` whose values hold
        ``'sb_cgr'`` (representative CGR), ``'node_ids'`` (sorted),
        ``'strat_bonds'`` and ``'group_size'``.
    """
    temp_groups = defaultdict(
        lambda: {"node_ids": [], "sb_cgr": None, "strat_bonds": None}
    )

    # 1. Initial grouping based on the content of strategic bonds
    #    (or the CGR string signature when use_strat is False).
    for node_id, sb_cgr in sb_cgrs.items():
        strat_bonds_list = extract_strat_bonds(sb_cgr)
        # FIX: idiomatic truth test instead of `== True`.
        if use_strat:
            group_key = tuple(strat_bonds_list)
        else:
            group_key = str(sb_cgr)

        group = temp_groups[group_key]
        if not group["node_ids"]:  # first member defines the representative
            group["sb_cgr"] = sb_cgr
            group["strat_bonds"] = strat_bonds_list
        group["node_ids"].append(node_id)

    # FIX: sort node_ids once per group instead of after every append
    # (the original re-sorted the growing list on each insertion).
    for group in temp_groups.values():
        group["node_ids"].sort()
        group["group_size"] = len(group["node_ids"])

    # 2. Format the output dictionary with '{length}.{index}' keys.
    final_grouped_results = {}
    group_indices = defaultdict(int)  # per-length running index

    # Sorting by (key length, key) gives a deterministic ordering
    # regardless of dict insertion order.
    sorted_groups = sorted(
        temp_groups.items(), key=lambda item: (len(item[0]), item[0])
    )

    for _, group_data in sorted_groups:
        num_bonds = len(group_data["strat_bonds"])
        group_indices[num_bonds] += 1  # 1-based index within this length
        final_key = f"{num_bonds}.{group_indices[num_bonds]}"
        final_grouped_results[final_key] = group_data

    return final_grouped_results
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def lg_process_reset(lg_cgr: CGRContainer, atom_num: int):
    """Normalize bonds of an extracted leaving-group fragment and flag its attachment atom.

    Every bond with a defined ``order`` but undefined ``p_order`` is replaced
    by a ``DynamicBond`` carrying the same integer order on both sides. The
    atom at ``atom_num`` (the former attachment point) is then marked as a
    radical.

    Parameters
    ----------
    lg_cgr : CGRContainer
        The CGR representing the isolated leaving-group fragment.
    atom_num : int
        Index of the attachment atom to mark as a radical.

    Returns
    -------
    CGRContainer
        The same ``lg_cgr``, mutated in place, with normalized bonds and the
        attachment atom flagged as a radical.
    """
    # Snapshot the adjacency before mutating, since delete/add change it.
    for left, neighbours in list(lg_cgr._bonds.items()):
        for right, bond in list(neighbours.items()):
            if bond.p_order is None and bond.order is not None:
                kept_order = int(bond.order)
                lg_cgr.delete_bond(left, right)
                lg_cgr.add_bond(left, right, DynamicBond(kept_order, kept_order))
    lg_cgr._atoms[atom_num].is_radical = True
    return lg_cgr
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def lg_replacer(route_cgr: CGRContainer):
    """Extract dynamic leaving-groups from a RouteCGR and mark attachment points.

    Scans the first connected component of ``route_cgr`` for bonds lacking an
    explicit ``p_order`` (leaving-group attachments), severs those bonds,
    captures each detached leaving-group as its own CGRContainer, and inserts
    ``DynamicX`` marker atoms at the attachment sites. Finally, the marker
    labels are re-indexed per reactant to ensure unique, stable labels.

    Parameters
    ----------
    route_cgr : CGRContainer
        A CGR representing the full synthetic route.

    Returns
    -------
    synthon_cgr : CGRContainer
        The core synthon CGR with ``DynamicX`` atoms marking each former
        leaving-group site.
    lg_groups : dict[int, tuple[CGRContainer, int]]
        Mapping from each marker label to ``(leaving-group CGR, attachment
        atom index)``.
    """
    lg_groups = {}

    # Work on the first connected component (the product side backbone).
    cgr_prods = [route_cgr.substructure(c) for c in route_cgr.connected_components]
    target_cgr = cgr_prods[0]

    # Snapshot adjacency before mutation; atom indices <= max_in_target_mol
    # belong to the target molecule (product).
    bond_items = list(target_cgr._bonds.items())
    reaction = ReactionContainer.from_cgr(target_cgr)
    target_mol = reaction.products[0]
    max_in_target_mol = max(target_mol._atoms)

    k = 1  # running leaving-group label
    atom_nums = []  # attachment atom indices, in discovery order
    checked_atoms = set()  # undirected bonds already processed

    for atom1, bond_set in bond_items:
        bond_set_items = list(bond_set.items())
        for atom2, bond in bond_set_items:
            if (
                bond.p_order is None
                and bond.order is not None
                and tuple(sorted([atom1, atom2])) not in checked_atoms
            ):
                if atom1 <= max_in_target_mol:
                    lg = DynamicX()
                    lg.mark = k
                    lg.isotope = k
                    order = bond.order
                    p_order = bond.p_order
                    target_cgr.delete_bond(atom1, atom2)
                    lg_cgrs = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ]
                    checked_atoms.add(tuple(sorted([atom1, atom2])))
                    # Only proceed if cutting the bond actually split the
                    # graph into core + leaving-group fragments.
                    if len(lg_cgrs) == 2:
                        lg_cgr = lg_cgrs[1]
                        lg_cgr = lg_process_reset(lg_cgr, atom2)
                        lg_cgr.clean2d()
                    else:
                        continue
                    lg_groups[k] = (lg_cgr, atom2)
                    # Keep only the core component after the cut.
                    target_cgr = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ][0]
                    target_cgr.add_atom(lg, atom2)
                    # FIX: `p_order == None` -> `is None` (E711); aromatic
                    # (order 4) attachments are downgraded to single bonds.
                    if order == 4 and p_order is None:
                        order = 1
                    target_cgr.add_bond(atom1, atom2, DynamicBond(order, p_order))
                    target_cgr = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ][0]
                    k += 1
                    atom_nums.append(atom2)

    synthon_cgr = [
        target_cgr.substructure(c) for c in target_cgr.connected_components
    ][0]
    reaction = ReactionContainer.from_cgr(synthon_cgr)
    reactants = reaction.reactants

    # Re-label markers reactant-by-reactant so labels are stable and unique.
    atom_mark_map = {}  # attachment atom index -> new mark
    g = 1
    # FIX: dropped unused `enumerate` index.
    for r in reactants:
        for atom_num in atom_nums:
            if atom_num in r._atoms:
                synthon_cgr._atoms[atom_num].mark = g
                atom_mark_map[atom_num] = g
                g += 1

    # Rebuild lg_groups keyed by the new marks.
    new_lg_groups = {}
    for original_mark in lg_groups:
        cgr_obj, a_num = lg_groups[original_mark]
        new_mark = atom_mark_map.get(a_num)
        if new_mark is not None:
            new_lg_groups[new_mark] = (cgr_obj, a_num)
    lg_groups = new_lg_groups

    return synthon_cgr, lg_groups
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def lg_reaction_replacer(
    synthon_reaction: ReactionContainer, lg_groups: dict, max_in_target_mol: int
):
    """Replace leaving-group placeholder atoms in reactants with ``MarkedAt`` atoms.

    For each reactant in ``synthon_reaction``, atom indices greater than
    ``max_in_target_mol`` are treated as leaving-group placeholders. Any such
    atom matching an entry in ``lg_groups`` is swapped for a ``MarkedAt``
    atom labeled with the leaving-group key; the original bond to the
    attachment neighbour is preserved.

    Parameters
    ----------
    synthon_reaction : ReactionContainer
        Reaction whose reactants contain X placeholders.
    lg_groups : dict[int, tuple[CGRContainer, int]]
        Mapping from X label to ``(X CGR, attachment atom index)``.
    max_in_target_mol : int
        Highest atom index of the core product; any atom index above this is
        a placeholder.

    Returns
    -------
    list
        Reactant molecules with ``MarkedAt`` atoms reinserted at X
        attachment sites.
    """
    new_reactants = []
    for reactant in synthon_reaction.reactants:
        for atom_num in list(reactant._atoms.keys()):
            if atom_num <= max_in_target_mol:
                continue  # core atom, not a placeholder
            for mark, (_, attach_idx) in lg_groups.items():
                if atom_num != attach_idx:
                    continue
                # FIX: construct the MarkedAt atom only when the placeholder
                # actually matches (the original allocated one per candidate
                # mark on every iteration).
                lg = MarkedAt()
                lg.mark = mark
                lg.isotope = mark
                neighbour = list(reactant._bonds[atom_num].keys())[0]
                bond = reactant._bonds[atom_num][neighbour]
                reactant.delete_bond(neighbour, atom_num)
                reactant.delete_atom(atom_num)
                reactant.add_atom(lg, atom_num)
                reactant.add_bond(neighbour, atom_num, bond)
        new_reactants.append(reactant)
    return new_reactants
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class SubclusterError(Exception):
    """Signals that ``subcluster_one_cluster`` could not complete successfully."""
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def subcluster_one_cluster(group, sb_cgrs_dict, route_cgrs_dict):
    """
    Generate synthon data for each route in a single cluster.

    For each route (node ID) in ``group['node_ids']``, replaces RouteCGRs
    with SynthonCGR, builds ReactionContainers before and after X
    replacement, and collects relevant data.

    Parameters
    ----------
    group : dict
        Must include ``'node_ids'``, a list of node identifiers.
    sb_cgrs_dict : dict
        Maps node IDs to their ReducedRouteCGR.
    route_cgrs_dict : dict
        Maps node IDs to their RouteCGR.

    Returns
    -------
    dict
        Maps each ``node_id`` to a tuple
        ``(sb_cgr, original_reaction, synthon_cgr, new_reaction, lg_groups)``.

    Raises
    ------
    SubclusterError
        If ``node_ids`` is malformed or any step (X replacement or reaction
        parsing) fails for a node.
    """
    node_ids = group.get("node_ids")
    if not isinstance(node_ids, (list, tuple)):
        raise SubclusterError(
            f"'node_ids' must be a list or tuple, got {type(node_ids).__name__}"
        )

    result = {}
    for node_id in node_ids:
        sb_cgr = sb_cgrs_dict[node_id]
        route_cgr = route_cgrs_dict[node_id]

        # 1) Replace leaving groups in RouteCGR
        try:
            synthon_cgr, lg_groups = lg_replacer(route_cgr)
        except (KeyError, ValueError) as e:
            raise SubclusterError(f"LG replacement failed for node {node_id}") from e

        # 2) Build ReactionContainer for abstracted RouteCGR.
        # FIX: was a bare `except:` whose handler referenced an unbound name
        # `e` (guaranteed NameError when triggered) and also swallowed
        # SystemExit/KeyboardInterrupt. Now binds the exception explicitly.
        try:
            synthon_rxn = ReactionContainer.from_cgr(synthon_cgr)
        except Exception as e:
            raise SubclusterError(
                f"Failed to parse synthon CGR for node {node_id}"
            ) from e

        # 3) Prepare for X-based reaction replacement
        try:
            old_reactants = synthon_rxn.reactants
            target_mol = synthon_rxn.products[0]
            max_atom_idx = max(target_mol._atoms)
            new_reactants = lg_reaction_replacer(synthon_rxn, lg_groups, max_atom_idx)
            new_rxn = ReactionContainer(reactants=new_reactants, products=[target_mol])
        except (IndexError, TypeError) as e:
            raise SubclusterError(
                f"Leaving group (X) reaction replacement failed for node {node_id}"
            ) from e

        result[node_id] = (
            sb_cgr,
            ReactionContainer(reactants=old_reactants, products=[target_mol]),
            synthon_cgr,
            new_rxn,
            lg_groups,
        )

    return result
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def group_nodes_by_synthon_detail(data_dict: dict):
    """
    Group nodes by their first four result entries (CGRs and reactions).

    Nodes whose result lists share the same first four elements are grouped
    together; each group also records the per-node fifth element (lg_groups)
    in ``'nodes_data'``.

    Args:
        data_dict: Dictionary
            ``{node_id: [sb_cgr, unlabeled_reaction, synthon_cgr,
            synthon_reaction, node_data, ...]}``.

    Returns:
        Dictionary
        ``{group_index: {'sb_cgr': ..., 'unlabeled_reaction': ...,
        'synthon_cgr': ..., 'synthon_reaction': ...,
        'nodes_data': {node_id: node_data, ...},
        'post_processed': False}}`` with 1-based integer group indices.
    """
    temp_groups = defaultdict(list)

    for node_id, result_list in data_dict.items():
        if len(result_list) < 4:
            # FIX: pad to a 4-tuple; the original built a 2-tuple here,
            # which crashed the 4-way unpack below with a ValueError.
            group_key = (result_list[0], None, None, None)
        else:
            group_key = (
                result_list[0],
                result_list[1],
                result_list[2],
                result_list[3],
            )

        # FIX: hashability (TypeError) is raised by the dict access, not by
        # tuple construction — the original try/except wrapped the wrong
        # statement and could never catch anything.
        try:
            temp_groups[group_key].append(node_id)
        except TypeError:
            print(
                f"Warning: Skipping node {node_id} because group key is not hashable"
            )
            continue

    # 2. Format the output dictionary with sequential integer keys
    # and include the node-specific data (index 4) in a sub-dictionary.
    final_grouped_results = {}
    group_index = 1

    # Deterministic ordering: sort groups by their (sorted-insertion) node lists.
    sorted_temp_groups = sorted(temp_groups.items(), key=lambda item: item[1])
    for group_key, node_ids in sorted_temp_groups:

        sb_cgr, unlabeled_reaction, synthon_cgr, synthon_reaction = group_key
        nodes_data_dict = {}

        # Iterate through the node IDs belonging to this group.
        for node_id in sorted(node_ids):  # sorted for consistent dict order
            original_result = data_dict.get(node_id, [])
            node_specific_data = None  # default when index 4 is missing
            if len(original_result) > 4:
                node_specific_data = original_result[4]

            nodes_data_dict[node_id] = node_specific_data

        final_grouped_results[group_index] = {
            "sb_cgr": sb_cgr,
            "unlabeled_reaction": unlabeled_reaction,
            "synthon_cgr": synthon_cgr,
            "synthon_reaction": synthon_reaction,
            "nodes_data": nodes_data_dict,
            "post_processed": False,
        }
        group_index += 1

    return final_grouped_results
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def subcluster_all_clusters(groups, sb_cgrs_dict, route_cgrs_dict):
    """Subdivide every reaction cluster into synthon-based subgroups.

    Applies ``subcluster_one_cluster`` to each cluster in ``groups`` and
    organizes the resulting per-node synthon data with
    ``group_nodes_by_synthon_detail``.

    Parameters
    ----------
    groups : dict
        Mapping of cluster indices to cluster data.
    sb_cgrs_dict : dict
        Dictionary of ReducedRouteCGRs keyed by node ID.
    route_cgrs_dict : dict
        Dictionary of RouteCGRs keyed by node ID.

    Returns
    -------
    dict or None
        Mapping of each cluster index to its subgroups dict, or None if any
        cluster yields no synthon data.
    """
    all_subgroups = {}
    for cluster_key, cluster in groups.items():
        synthons = subcluster_one_cluster(cluster, sb_cgrs_dict, route_cgrs_dict)
        if synthons is None:
            return None
        all_subgroups[cluster_key] = group_nodes_by_synthon_detail(synthons)
    return all_subgroups
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def all_lg_collect(subgroup):
    """Collect the distinct leaving-group containers observed at each position.

    Scans every per-route table in ``subgroup['nodes_data']`` and, for each
    position index, gathers the unique CGRContainer objects seen there
    (uniqueness is judged by the container's string form).

    Parameters
    ----------
    subgroup : dict
        Must contain ``'nodes_data'``: a dict mapping route keys to dicts of
        ``{position_index: (CGRContainer, ...)}``.

    Returns
    -------
    dict[int, list]
        For each position index, the list of distinct containers in
        first-seen order.
    """
    route_tables = subgroup["nodes_data"].values()

    # Pre-create one bucket (and one fingerprint set) per position index.
    collected = {}
    fingerprints = {}
    for table in route_tables:
        for position in table:
            collected.setdefault(position, [])
            fingerprints.setdefault(position, set())

    # Keep the first container for every distinct string signature.
    for table in route_tables:
        for position, payload in table.items():
            container = payload[0]
            signature = str(container)
            if signature not in fingerprints[position]:
                fingerprints[position].add(signature)
                collected[position].append(container)
    return collected
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def replace_leaving_groups_in_synthon(subgroup, to_remove):  # Under development
    """
    Replace specified leaving groups (LG) in a synthon CGR with new fragments
    and return the updated CGR along with a mapping from adjusted LG marks to
    their atom indices.

    Parameters:
        subgroup (dict): Must contain:
            - 'synthon_cgr': the CGR object representing the synthon graph
            - 'nodes_data': mapping of node indices to LG replacement data
        to_remove (List[int]): List of LG marks to remove and replace.

    Returns:
        Tuple[CGR, Dict[int, int]]:
            - The updated CGR with replacements
            - A dict mapping new LG marks to their atom indices in the
              updated CGR
    """
    # Extract the original CGR and leaving group replacement table.
    # NOTE(review): the table is taken from an arbitrary (first) route —
    # assumes all routes in the subgroup share the same LG table; confirm.
    original_cgr = subgroup["synthon_cgr"]
    lg_table = next(iter(subgroup["nodes_data"].values()))

    # NOTE(review): this aliases, not copies — the subgroup's CGR is
    # mutated in place below.
    updated_cgr = original_cgr

    removed_count = 0  # how many marks have been consumed so far
    new_lgs = {}  # new mark -> atom index in the updated CGR

    # Iterate through all atoms (index, atom_obj) in the CGR; list() guards
    # against mutation during iteration.
    for atom_idx, atom_obj in list(updated_cgr.atoms()):
        # Skip non-X (non-placeholder) atoms.
        if atom_obj.__class__.__name__ != "DynamicX":
            continue

        current_mark = atom_obj.mark
        if current_mark in to_remove:
            # Remove old LG (X): delete its single bond and the atom itself.
            neighbors = list(updated_cgr._bonds[atom_idx].keys())
            if neighbors:
                neighbor_idx = neighbors[0]
                bond = updated_cgr._bonds[atom_idx][neighbor_idx]
                updated_cgr.delete_bond(atom_idx, neighbor_idx)
                updated_cgr.delete_atom(atom_idx)

                # Attach the replacement LG (X) fragment from the table.
                # NOTE(review): union() is assumed to preserve the fragment's
                # atom numbering so atom_idx refers to its attachment atom —
                # TODO confirm against CGRtools semantics.
                lg_fragment = lg_table[current_mark][0]
                updated_cgr = updated_cgr.union(lg_fragment)
                # Reset the radical flag on the new atom and restore the bond.
                updated_cgr._atoms[atom_idx].is_radical = False
                updated_cgr.add_bond(atom_idx, neighbor_idx, bond)

                removed_count += 1
        else:
            # Adjust the marks of remaining LGs to account for removed ones.
            atom_obj.mark -= removed_count
            new_lgs[atom_obj.mark] = atom_idx

    # Reorder the atoms dict (stable ascending indices) for depiction.
    updated_cgr._atoms = dict(sorted(updated_cgr._atoms.items()))

    return updated_cgr, new_lgs
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def new_lg_reaction_replacer(synthon_reaction, new_lgs, max_in_target_mol):
    """
    Replace placeholder atom indices with marked leaving-group atoms in reactants.

    Iterates through each reactant in a ``ReactionContainer``, finds atom
    indices corresponding to detached leaving-groups (those greater than the
    core's maximum index), and replaces them with ``MarkedAt`` atoms bearing
    the correct X labels and isotopes. Bonds to the original attachment
    points are preserved.

    Parameters
    ----------
    synthon_reaction : ReactionContainer
        Reaction whose ``reactants`` contain placeholder atoms (by index)
        marking where leaving-groups were removed.
    new_lgs : dict[int, int]
        Mapping from leaving-group label to the placeholder atom index.
    max_in_target_mol : int
        Highest atom index used by the core product; any atom index above
        this is treated as a leaving-group placeholder.

    Returns
    -------
    list
        Reactant molecules where each placeholder has been replaced by a
        ``MarkedAt`` atom with ``.mark``/``.isotope`` set to its label.
    """
    new_reactants = []
    for reactant in synthon_reaction.reactants:
        for atom_num in list(reactant._atoms.keys()):
            if atom_num <= max_in_target_mol:
                continue  # core atom, not a placeholder
            for mark, placeholder_idx in new_lgs.items():
                if atom_num != placeholder_idx:
                    continue
                # FIX: construct the MarkedAt atom only on a match (the
                # original allocated one per candidate mark per atom) —
                # consistent with lg_reaction_replacer.
                lg = MarkedAt()
                lg.mark = mark
                lg.isotope = mark
                neighbour = list(reactant._bonds[atom_num].keys())[0]
                bond = reactant._bonds[atom_num][neighbour]
                reactant.delete_bond(neighbour, atom_num)
                reactant.delete_atom(atom_num)
                reactant.add_atom(lg, atom_num)
                reactant.add_bond(neighbour, atom_num, bond)
        new_reactants.append(reactant)

    return new_reactants
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def post_process_subgroup(
    subgroup,
):  # Under development: Error in replace_leaving_groups_in_synthon , 'cuz synthon_reaction.clean2d crashes
    """
    Drop leaving-groups common to all pathways and rebuild a minimal synthon.

    Scans the subgroup for leaving-groups present in every route, removes
    those from the CGR, re-assembles a clean ReactionContainer with the
    original core, updates ``nodes_data``, and flags the dict as processed.
    Idempotent: a subgroup already flagged ``post_processed`` is returned
    unchanged.

    Parameters
    ----------
    subgroup : dict
        A subcluster entry holding ``'synthon_cgr'`` and ``'nodes_data'``.

    Returns
    -------
    dict
        The same dict, now with:
        - ``'synthon_reaction'``: cleaned ReactionContainer
        - ``'nodes_data'``: filtered node table
        - ``'group_lgs'``: routes grouped by identical leaving-group sets
        - ``'post_processed'``: True
    """
    # FIX: idiomatic membership/truth test instead of
    # `"post_processed" in subgroup.keys() and subgroup[...] == True`.
    if subgroup.get("post_processed"):
        return subgroup
    result = all_lg_collect(subgroup)
    # Constant leaving-groups (a single variant across all routes) carry no
    # discriminating information and are removed.
    to_remove = [ind for ind, cgr_set in result.items() if len(cgr_set) == 1]
    new_synthon_cgr, new_lgs = replace_leaving_groups_in_synthon(subgroup, to_remove)
    synthon_reaction = ReactionContainer.from_cgr(new_synthon_cgr)
    synthon_reaction.clean2d()
    old_reactants = ReactionContainer.from_cgr(new_synthon_cgr).reactants
    target_mol = synthon_reaction.products[0]  # TO DO: target_mol might be non 0
    max_in_target_mol = max(target_mol._atoms)
    new_reactants = new_lg_reaction_replacer(
        synthon_reaction, new_lgs, max_in_target_mol
    )
    new_synthon_reaction = ReactionContainer(
        reactants=new_reactants, products=[target_mol]
    )
    new_synthon_reaction.clean2d()
    subgroup["synthon_reaction"] = new_synthon_reaction
    subgroup["nodes_data"] = remove_and_shift(subgroup["nodes_data"], to_remove)
    subgroup["post_processed"] = True
    subgroup["group_lgs"] = group_by_identical_values(subgroup["nodes_data"])
    return subgroup
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def group_by_identical_values(nodes_data):  # Under development
    """
    Group routes whose inner tables carry identical core values.

    Routes whose inner dictionaries contain the same ordered sequence of
    leaving groups (first tuple element per subkey, ordered by subkey) are
    collapsed into one entry.

    Args:
        nodes_data (dict): Maps outer keys to inner dictionaries; each inner
            dictionary maps subkeys to ``(value_obj, other_info)`` tuples.
            ``value_obj`` drives the grouping, ``other_info`` is ignored.
            Example: ``{'route_1': {'pos_a': (1, 'infoA')}, ...}``

    Returns:
        dict: Keys are tuples of the grouped outer keys; values map each
        subkey to the ``value_obj`` taken from the group's first outer key.
        Sorted descending by group size.
        Example: ``{('route_1', 'route_2'): {'pos_a': 1, 'pos_b': 2}, ...}``
    """
    # Step 1: fingerprint each route by its subkey-ordered core values.
    buckets = defaultdict(list)
    for route_id, inner in nodes_data.items():
        ordered_items = sorted(inner.items(), key=lambda kv: kv[0])
        fingerprint = tuple(pair[1][0] for pair in ordered_items)
        buckets[fingerprint].append(route_id)

    # Step 2: one entry per fingerprint, represented by the first route.
    merged = {}
    for route_ids in buckets.values():
        representative = nodes_data[route_ids[0]]
        merged[tuple(route_ids)] = {
            subkey: payload[0] for subkey, payload in representative.items()
        }

    # Largest groups first.
    return dict(sorted(merged.items(), key=lambda item: len(item[0]), reverse=True))
|
synplan/chem/reaction_routes/io.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import json
|
| 3 |
+
import pickle
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from CGRtools import smiles as read_smiles
|
| 7 |
+
from synplan.mcts.tree import Tree
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def make_dict(routes_json):
    """
    Convert JSON route trees (as produced by make_json()) into flat reaction maps.

    Args:
        routes_json (dict | list): Either a mapping of route index -> tree-dict,
            or a list of tree-dicts. Each tree is a nested structure of
            "mol"/"reaction" nodes with optional "children".

    Returns:
        dict: Mapping of each route index (int) to a sub-dict
        {new_step_id: ReactionContainer}, where step IDs run from the earliest
        reaction (0) up to the final (max), i.e. reactions are numbered in
        leaf-to-root (post-order) traversal order.
    """
    # The original implementation duplicated the whole traversal for the
    # dict and list inputs; only the (index, tree) pairing differs.
    if isinstance(routes_json, dict):
        route_items = routes_json.items()
    else:
        route_items = enumerate(routes_json)

    routes_dict = {}
    for route_idx, tree in route_items:
        rxn_list = []

        def _postorder(node):
            # First dive into any children, then record this reaction,
            # so reactions come out in leaf -> root order.
            for child in node.get("children", []):
                _postorder(child)
            if node["type"] == "reaction":
                rxn_list.append(read_smiles(node["smiles"]))
            # mol-nodes simply recurse (no record)

        _postorder(tree)

        # Assign 0, 1, 2, ... in traversal order.
        routes_dict[int(route_idx)] = {i: rxn for i, rxn in enumerate(rxn_list)}

    return routes_dict
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def read_routes_json(file_path="routes.csv", to_dict=False):
    """Load retrosynthetic routes from a JSON file.

    Args:
        file_path (str): Path to the JSON file.
            NOTE(review): the default name ends in ".csv" even though the
            content is parsed as JSON — looks like a copy-paste from
            read_routes_csv; confirm before relying on the default.
        to_dict (bool): If True, convert the loaded trees with make_dict()
            into {route_id: {step_id: ReactionContainer}}; otherwise return
            the raw JSON structure unchanged.

    Returns:
        list | dict: The parsed routes.
    """
    with open(file_path, "r") as file:
        routes_json = json.load(file)
    if to_dict:
        return make_dict(routes_json)
    return routes_json
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def read_routes_csv(file_path="routes.csv"):
    """
    Parse a routes CSV into a nested mapping of reactions.

    The file must have the columns route_id, step_id, smiles, meta.
    The meta column is ignored for now (it could be extracted later).

    Args:
        file_path (str): Path to the CSV file.

    Returns:
        dict: route_id (int) -> step_id (int) -> ReactionContainer.
    """
    parsed = {}
    with open(file_path, newline="") as handle:
        for record in csv.DictReader(handle):
            rid = int(record["route_id"])
            sid = int(record["step_id"])
            # adjust this constructor to your actual API
            parsed.setdefault(rid, {})[sid] = read_smiles(record["smiles"])
    return parsed
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def make_json(routes_dict, keep_ids=True):
    """
    Convert routes into a nested JSON tree of reaction and molecule nodes.

    Args:
        routes_dict (dict[int, dict[int, Reaction]]): Mapping route IDs to steps (step_id -> Reaction).
        keep_ids (bool): If True, returns a dict mapping route IDs to trees;
            otherwise returns a plain list of route trees.

    Returns:
        dict or list: JSON-like tree(s) of routes. Each node is a dict with
        "type" ("mol" or "reaction"), "smiles", and "children"; leaf "mol"
        nodes additionally carry "in_stock": True.

    Note:
        Calling kekule()/implicify_hydrogens()/thiele() mutates the product
        and reactant molecules in place to canonicalize their SMILES.
    """
    # Prepare output
    all_routes = {} if keep_ids else []

    for route_id, steps in routes_dict.items():
        if not steps:
            continue

        # Determine target molecule atoms from the final step of this route
        final_step = max(steps)
        target = steps[final_step].products[0]
        atom_nums = set(target._atoms.keys())

        # Precompute canonical SMILES and producer mapping for all products
        prod_map = {}  # smiles -> list of step_ids
        for sid, rxn in steps.items():
            for prod in rxn.products:
                # Canonicalize in place so SMILES strings are comparable.
                prod.kekule()
                prod.implicify_hydrogens()
                prod.thiele()
                s = str(prod)
                prod_map.setdefault(s, []).append(sid)

        def transform(mol):
            # Canonicalize a molecule in place and return its SMILES.
            mol.kekule()
            mol.implicify_hydrogens()
            mol.thiele()
            return str(mol)

        def build_mol_node(sid):
            """Find the product with any overlap to target atoms and recurse into its reaction."""
            rxn = steps[sid]
            for p in rxn.products:
                # A product sharing atom numbers with the target is on the main path.
                if atom_nums & set(p._atoms.keys()):
                    smiles = str(p)
                    return {
                        "type": "mol",
                        "smiles": smiles,
                        "children": [build_reaction_node(sid)],
                        "in_stock": False,
                    }
            # Shouldn't reach here if tree is consistent
            return None

        def build_reaction_node(sid):
            """Build reaction node and recurse into reactant molecule nodes."""
            rxn = steps[sid]
            node = {"type": "reaction", "smiles": format(rxn, "m"), "children": []}

            for react in rxn.reactants:
                r_smi = transform(react)
                # Look up any prior step producing this reactant
                prior = [ps for ps in prod_map.get(r_smi, []) if ps < sid]
                if prior:
                    # Recurse into the most recent producing step.
                    node["children"].append(build_mol_node(max(prior)))
                else:
                    # No producer found: treat as a purchasable starting material.
                    node["children"].append(
                        {"type": "mol", "smiles": r_smi, "in_stock": True}
                    )

            return node

        # Build route tree and store
        tree = build_mol_node(final_step)
        if keep_ids:
            all_routes[int(route_id)] = tree
        else:
            all_routes.append(tree)

    return all_routes
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def write_routes_json(routes_dict, file_path):
    """Serialize reaction routes to a JSON file.

    Args:
        routes_dict: Mapping route_id -> {step_id: reaction}, as accepted by make_json().
        file_path: Destination path for the pretty-printed JSON.
    """
    serializable = make_json(routes_dict)
    with open(file_path, "w") as handle:
        json.dump(serializable, handle, indent=2)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def write_routes_csv(routes_dict, file_path="routes.csv"):
    """
    Dump nested routes to a CSV file.

    Input has the shape {route_id: {step_id: reaction_obj, ...}, ...}.
    The output columns are route_id, step_id, smiles, meta, where smiles is
    format(reaction, 'm') and meta is left blank. Routes and steps are
    emitted in sorted order so output is deterministic.
    """
    with open(file_path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["route_id", "step_id", "smiles", "meta"])
        for route_id in sorted(routes_dict):
            route_steps = routes_dict[route_id]
            for step_id in sorted(route_steps):
                rendered = format(route_steps[step_id], "m")
                # meta is reserved for future use and left empty for now.
                writer.writerow([route_id, step_id, rendered, ""])
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class TreeWrapper:
    """Pickle-friendly wrapper around a retrosynthesis search Tree.

    Strips non-picklable attributes (neural networks, tqdm handle) from the
    tree state on save, and rebuilds the Tree skeleton on load without
    running Tree.__init__.
    """

    def __init__(self, tree, mol_id=1, config=1, path="planning_results/forest"):
        """Initializes the TreeWrapper.

        Args:
            tree: The search Tree instance to wrap.
            mol_id: Molecule identifier, used in the pickle file name.
            config: Configuration identifier, used in the pickle file name.
            path: Directory for the pickle file (created if missing).
        """
        self.tree = tree
        self.mol_id = mol_id
        self.config = config
        self.path = path
        # Ensure the directory exists before creating the filename
        os.makedirs(self.path, exist_ok=True)
        self.filename = os.path.join(self.path, f"tree_{mol_id}_{config}.pkl")

    def __getstate__(self):
        state = self.__dict__.copy()
        tree_state = self.tree.__dict__.copy()
        # Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
        if "_tqdm" in tree_state:
            tree_state["_tqdm"] = True  # Reset to a simple flag
        for attr in ["policy_network", "value_network"]:
            if attr in tree_state:
                tree_state[attr] = None
        state["tree_state"] = tree_state
        del state["tree"]
        return state

    def __setstate__(self, state):
        tree_state = state.pop("tree_state")
        self.__dict__.update(state)
        # Rebuild the Tree without invoking __init__ (which needs networks).
        new_tree = Tree.__new__(Tree)
        new_tree.__dict__.update(tree_state)
        self.tree = new_tree

    def save_tree(self):
        """Saves the TreeWrapper instance (including the tree state) to a file."""
        try:
            with open(self.filename, "wb") as f:
                pickle.dump(self, f)
            print(
                f"Tree wrapper for mol_id '{self.mol_id}', config '{self.config}' saved to '{self.filename}'."
            )
        except Exception as e:
            print(f"Error saving tree to {self.filename}: {e}")

    @classmethod
    def load_tree_from_id(cls, mol_id, config=1, path="planning_results/forest"):
        """
        Loads a Tree object from a saved file using mol_id and config.

        Args:
            mol_id: The molecule ID used for saving.
            config: The configuration used for saving.
            path: The directory where the file is located.

        Returns:
            The loaded Tree object, or None if loading fails.
        """
        filename = os.path.join(path, f"tree_{mol_id}_{config}.pkl")
        # Fix: the original messages printed the literal text "(unknown)"
        # instead of interpolating the actual file path.
        print(f"Attempting to load tree from: {filename}")
        try:
            # Ensure the 'Tree' class is defined in the current scope
            if "Tree" not in globals() and "Tree" not in locals():
                raise NameError(
                    "The 'Tree' class definition is required to load the object."
                )

            with open(filename, "rb") as f:
                loaded_wrapper = pickle.load(f)  # This implicitly calls __setstate__

            print(
                f"Tree object for mol_id '{mol_id}', config '{config}' successfully loaded from '{filename}'."
            )
            # The __setstate__ method already reconstructed the tree inside the wrapper
            return loaded_wrapper.tree

        except FileNotFoundError:
            print(f"Error: File not found at {filename}")
            return None
        except (pickle.UnpicklingError, EOFError) as e:
            print(
                f"Error: Could not unpickle file {filename}. It might be corrupted or empty. Details: {e}"
            )
            return None
        except NameError as e:
            print(f"Error during loading: {e}. Ensure 'Tree' class is defined.")
            return None
        except Exception as e:
            print(f"An unexpected error occurred loading tree from {filename}: {e}")
            return None
|
synplan/chem/reaction_routes/leaving_groups.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from CGRtools.periodictable import Core, At, DynamicElement
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Marked(Core):
    # Mixin over CGRtools' Core atom adding a user-settable `mark` and an
    # `isotope` that is always an int (default 0), so downstream code can
    # rely on isotope being defined.
    __slots__ = "__mark", "_isotope"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__mark = None
        self._isotope = 0  # Make sure this exists

    @property
    def mark(self):
        # Arbitrary label attached to the atom; None until explicitly set.
        return self.__mark

    @mark.setter
    def mark(self, mark):
        self.__mark = mark

    @property
    def isotope(self):
        return getattr(self, "_isotope", 0)  # Always returns int

    @isotope.setter
    def isotope(self, value):
        # Coerce to int so arithmetic/hashing on isotope is safe.
        self._isotope = int(value)

    def __repr__(self):
        return f"{self.symbol}({self.isotope})"

    @property
    def atomic_symbol(self) -> str:
        # Derive the element symbol from the subclass name,
        # e.g. "MarkedAt" -> "At" (drops the "Marked" prefix).
        return self.__class__.__name__[6:]

    @property
    def symbol(self) -> str:
        return "X"  # For human-readable representation

    def __len__(self):
        # Defer to Core's length semantics unchanged.
        return super().__len__()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MarkedAt(Marked, At):
    # Astatine-based pseudo-atom rendered as "X"; At is rare in real
    # chemistry, so its atomic number is reused as a leaving-group marker.
    atomic_number = At.atomic_number

    @property
    def atomic_symbol(self):
        # Real element symbol (used where chemistry must stay valid).
        return "At"

    @property
    def symbol(self):
        # Display symbol for human-readable output.
        return "X"

    def __repr__(self):
        return f"X({self.isotope})"

    def __str__(self):
        return f"X({self.isotope})"

    def __hash__(self):
        # Hash on the same fields CGRtools atoms typically compare by;
        # getattr defaults keep this safe if a field is not initialized.
        return hash(
            (
                self.isotope,
                getattr(self, "atomic_number", 0),
                getattr(self, "charge", 0),
                getattr(self, "is_radical", False),
            )
        )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DynamicX(DynamicElement):
    """Dynamic pseudo-element "X" for use inside CGRs.

    Stands in for a leaving group on the dynamic (CGR) side; reuses
    astatine's atomic number so CGRtools accepts it as a valid element.
    """

    __slots__ = ("_mark", "_isotope")

    atomic_number = 85
    mass = 0.0
    group = 0
    period = 0
    # NOTE(review): CGRtools elements normally define isotope data as
    # mappings; a plain list/int here relies on these fields never being
    # consulted for X atoms — confirm against CGRtools' periodictable API.
    isotopes_distribution = list(range(20))
    atomic_radius = 0.5
    isotopes_masses = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._isotope = None
        self._mark = None

    @property
    def mark(self):
        # Arbitrary label; None until set.
        return getattr(self, "_mark", None)

    @mark.setter
    def mark(self, value):
        self._mark = value

    @property
    def isotope(self):
        return getattr(self, "_isotope", None)

    @isotope.setter
    def isotope(self, value):
        self._isotope = value

    @property
    def symbol(self) -> str:
        return "X"

    def valence_rules(
        self, charge: int = 0, is_radical: bool = False, valence: int = 0
    ) -> tuple:
        # The original branched on (charge, is_radical, valence), but every
        # branch returned an empty tuple — X imposes no valence constraints.
        return tuple()

    def __repr__(self):
        return f"Dynamic{self.symbol}()"

    @property
    def p_charge(self) -> int:
        # Product-side charge mirrors the reactant side (X never changes).
        return self.charge

    @property
    def p_is_radical(self) -> bool:
        # Product-side radical state mirrors the reactant side.
        return self.is_radical

    @property
    def p_hybridization(self) -> Optional[int]:
        # Product-side hybridization mirrors the reactant side.
        return self.hybridization
|
synplan/chem/reaction_routes/route_cgr.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from CGRtools.containers.bonds import DynamicBond
|
| 2 |
+
from CGRtools.containers import ReactionContainer, CGRContainer, MoleculeContainer
|
| 3 |
+
from synplan.mcts.tree import Tree
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def find_next_atom_num(reactions: list):
    """
    Find the next available atom number across a list of reactions.

    Composes each reaction into its Condensed Graph of Reaction (CGR),
    takes the largest atom index present in any CGR, and returns that
    maximum plus one, giving a globally unused atom number.

    Args:
        reactions (list): A list of ReactionContainer objects.

    Returns:
        int: One greater than the maximum atom index found across all
        reaction CGRs (1 when the list is empty).
    """
    highest = max(
        (max(rxn.compose()._atoms.keys()) for rxn in reactions),
        default=0,
    )
    return highest + 1
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_clean_mapping(
    curr_prod: MoleculeContainer, prod: MoleculeContainer, reverse: bool = False
):
    """
    Get a 'clean' atom mapping between two molecules, avoiding conflicts.

    Takes the first mapping produced by ``curr_prod.get_mapping(prod)`` and
    filters it: identity pairs are dropped, pairs whose target also appears
    as a source mapped elsewhere are dropped (asymmetric cycles), and pairs
    whose target index already exists in the destination molecule are
    dropped to avoid index collisions.

    Args:
        curr_prod (MoleculeContainer): The first molecule.
        prod (MoleculeContainer): The second molecule.
        reverse (bool, optional): If True, build the mapping from `prod` to
            `curr_prod` instead of from `curr_prod` to `prod`. Defaults to False.

    Returns:
        dict: Source atom index -> target atom index. Empty when no mapping
        is found.
    """
    candidate_maps = list(curr_prod.get_mapping(prod))
    if not candidate_maps:
        return {}

    curr_atoms = set(curr_prod._atoms.keys())
    prod_atoms = set(prod._atoms.keys())
    first_map = candidate_maps[0]
    # Targets already used as atom indices in the destination molecule
    # would collide, so they are excluded.
    blocked_targets = curr_atoms if reverse else prod_atoms

    clean = {}
    for src, dst in first_map.items():
        if src == dst:
            continue
        # Skip pairs forming asymmetric cycles: dst maps somewhere other
        # than back to src.
        if dst in first_map and first_map[dst] != src:
            continue
        source, target = (dst, src) if reverse else (src, dst)
        if target in blocked_targets:
            continue
        clean[source] = target

    return clean
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def validate_molecule_components(curr_mol: MoleculeContainer, node_id: int):
    """
    Validate that a molecule has only one connected component.

    Extracts the connected fragments of `curr_mol` and prints an error
    message (it does not raise) when more than one fragment is found,
    indicating a malformed molecule in the given tree node.

    Args:
        curr_mol (MoleculeContainer): The molecule to validate.
        node_id (int): ID of the associated tree node, used only in the
            error message.
    """
    fragments = []
    for component in curr_mol.connected_components:
        fragments.append(curr_mol.substructure(component))
    if len(fragments) > 1:
        print(f"Error tree {node_id}: We have more than one molecule in one node")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_leaving_groups(products: list):
    """
    Extract leaving group atom numbers from a list of reaction products.

    The first molecule in `products` is assumed to be the main product;
    every subsequent molecule is treated as a leaving group, and the atom
    indices of all those molecules are collected.

    Args:
        products (list): MoleculeContainer objects; element 0 is the main
            product.

    Returns:
        list: Atom indices belonging to the leaving-group molecules.
    """
    # Skip element 0 (the main product); iterating _atoms yields its keys.
    return [atom for fragment in products[1:] for atom in fragment._atoms]
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def process_first_reaction(first_react: ReactionContainer, tree: Tree, node_id: int):
    """
    Process the first reaction in a retrosynthetic route and initialize the building block set.

    Iterates over the reaction's reactants, validates that each is a single
    connected component, and collects the atom indices of every reactant
    that qualifies as a building block. A reactant qualifies when its size
    is at most ``tree.config.min_mol_size`` or its SMILES is present in
    ``tree.building_blocks``.

    Args:
        first_react (ReactionContainer): The first ReactionContainer object in the route.
        tree (Tree): The Tree object containing the retrosynthetic search tree
            and configuration (including `min_mol_size` and `building_blocks`).
        node_id (int): The ID of the tree node associated with this reaction,
            used for validation reporting.

    Returns:
        set: Atom indices belonging to building-block reactants of the
        first reaction.
    """
    bb_set = set()

    for curr_mol in first_react.reactants:
        react_key_set = set(curr_mol._atoms)

        if (
            len(curr_mol) <= tree.config.min_mol_size
            or str(curr_mol) in tree.building_blocks
        ):
            # Fix: accumulate atoms of every qualifying reactant. The
            # original assignment (bb_set = react_key_set) discarded earlier
            # building blocks when several reactants qualified; sibling
            # update_reaction_dict already unions, so this matches it.
            bb_set |= react_key_set

        validate_molecule_components(curr_mol, node_id)

    return bb_set
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def update_reaction_dict(
    reaction: ReactionContainer,
    node_id: int,
    mapping: dict,
    react_dict: dict,
    tree: Tree,
    bb_set: set,
    prev_remap: dict = None,
):
    """
    Update a reaction dictionary with atom mappings and identify building blocks.

    For every reactant of `reaction`: validates it is a single connected
    component, records a mapping restricted to that reactant's atoms in
    `react_dict` (keyed by the tuple of its atom indices), and unions the
    reactant's atoms into `bb_set` when it qualifies as a building block
    (small enough, or listed in the tree's building blocks).

    Args:
        reaction (ReactionContainer): The reaction being processed.
        node_id (int): ID of the associated synthetic-route node, used for
            validation reporting.
        mapping (dict): Primary atom mapping to filter per reactant.
        react_dict (dict): Dictionary updated in place; keys are tuples of
            reactant atom indices, values are the filtered mappings.
        tree (Tree): Search tree carrying `config.min_mol_size` and
            `building_blocks`.
        bb_set (set): Building-block atom indices accumulated so far.
        prev_remap (dict, optional): Previous remapping whose entries (also
            filtered to the reactant's atoms) override the primary mapping.

    Returns:
        tuple: (updated `react_dict`, updated `bb_set`).
    """
    for reactant in reaction.reactants:
        atom_key = tuple(reactant._atoms)
        atom_set = set(atom_key)

        validate_molecule_components(reactant, node_id)

        is_building_block = (
            len(reactant) <= tree.config.min_mol_size
            or str(reactant) in tree.building_blocks
        )
        if is_building_block:
            bb_set = bb_set | atom_set

        # Restrict both mappings to atoms of this reactant; prev_remap wins
        # on overlapping keys.
        relevant = {k: v for k, v in mapping.items() if k in atom_set}
        if prev_remap:
            relevant.update({k: v for k, v in prev_remap.items() if k in atom_set})
        react_dict[atom_key] = relevant

    return react_dict, bb_set
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def process_target_blocks(
    curr_products: list,
    curr_prod: "MoleculeContainer",
    lg_atom_nums: list,
    curr_lg_atom_nums: list,
    bb_set: set,
):
    """
    Identifies and collects atom indices for target blocks based on leaving groups and building blocks.

    Iterates over `curr_products`, skipping any product whose atom-index set is
    identical to that of the reference molecule `curr_prod`, and collects the
    indices of atoms that appear in either leaving-group list (`lg_atom_nums`,
    `curr_lg_atom_nums`) or in the building-block set (`bb_set`). The collected
    indices mark parts of molecules that should be treated as 'target blocks'
    during a remapping or analysis process.

    Args:
        curr_products (list): MoleculeContainer objects representing the current products.
        curr_prod (MoleculeContainer): Reference molecule (typically the main
            product) whose atom keys are compared against each product.
        lg_atom_nums (list): Atom indices identified as leaving-group atoms.
        curr_lg_atom_nums (list): Additional leaving-group atom indices,
            potentially from a different context than `lg_atom_nums`.
        bb_set (set): Atom indices identified as building-block atoms.

    Returns:
        list: Atom indices identified as target blocks. An index may appear
        twice when it is both a leaving-group and a building-block atom.
        Empty when `curr_products` contains at most one molecule.
    """
    # NOTE(review): a `get_clean_mapping(curr_prod, prod)` call was previously
    # made here per product but its result was never used; removed as dead code.
    target_block = []
    # With a single product there is nothing to disambiguate.
    if len(curr_products) > 1:
        for prod in curr_products:
            # A product with exactly the same atom keys is the reference
            # itself; only differing products are inspected.
            if prod._atoms.keys() != curr_prod._atoms.keys():
                for key in list(prod._atoms.keys()):
                    if key in lg_atom_nums or key in curr_lg_atom_nums:
                        target_block.append(key)
                    if key in bb_set:
                        target_block.append(key)
    return target_block
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def compose_route_cgr(tree_or_routes, node_id):
    """
    Process a single synthesis route maintaining consistent state.

    Folds the route's reactions backwards — from the final (target-forming)
    step to the first — composing each step's CGR onto the accumulated one,
    while remapping atom numbers so leaving groups and building blocks from
    different steps do not collide.

    Parameters
    ----------
    tree_or_routes : synplan.mcts.tree.Tree
        or dict mapping route_id -> {step_id: ReactionContainer}
    node_id : int
        the route index (in the Tree’s winning_nodes, or the dict’s keys)

    Returns
    -------
    dict or None
        - if successful: { 'cgr': <composed CGR>, 'reactions_dict': {step: ReactionContainer,…} }
        - on error: None (tree-based branch only; the dict-based branch raises
          KeyError for an unknown route instead of returning None)
    """
    # ----------- dict-based branch ------------
    # Pre-assembled routes: no remapping machinery, just compose step CGRs
    # back-to-front in step-id order.
    if isinstance(tree_or_routes, dict):
        routes_dict = tree_or_routes
        if node_id not in routes_dict:
            raise KeyError(f"Route {node_id} not in provided dict.")
        # grab and sort the ReactionContainers in chronological order
        step_map = routes_dict[node_id]
        sorted_ids = sorted(step_map)
        reactions = [step_map[i] for i in sorted_ids]

        # start from the last (final) reaction
        accum_cgr = reactions[-1].compose()
        reactions_dict = {len(reactions) - 1: reactions[-1]}
        # now fold backwards through the earlier steps
        for idx in range(len(reactions) - 2, -1, -1):
            rxn = reactions[idx]
            curr_cgr = rxn.compose()
            accum_cgr = curr_cgr.compose(accum_cgr)
            reactions_dict[idx] = rxn

        return {"cgr": accum_cgr, "reactions_dict": reactions_dict}

    # ----------- tree-based branch ------------
    tree = tree_or_routes
    try:
        # original tree-based logic:
        reactions = tree.synthesis_route(node_id)

        # The last element is the final (target-forming) reaction; seed the
        # per-step dict and the accumulated CGR with it.
        first_react = reactions[-1]
        reactions_dict = {len(reactions) - 1: first_react}

        accum_cgr = first_react.compose()
        bb_set = process_first_reaction(first_react, tree, node_id)
        # react_dict: tuple(atom indices of a reactant) -> remapping to apply
        # when that reactant reappears as a product of an earlier step.
        react_dict = {}
        # max_num: next free atom number for resolving index collisions.
        max_num = find_next_atom_num(reactions)

        for step in range(len(reactions) - 2, -1, -1):
            reaction = reactions[step]
            curr_cgr = reaction.compose()
            curr_prod = reaction.products[0]

            # decompose()[1] is the products side of the CGR; split() yields
            # its individual molecules.
            accum_products = accum_cgr.decompose()[1].split()
            lg_atom_nums = get_leaving_groups(accum_products)
            curr_products = curr_cgr.decompose()[1].split()

            # Look up a remapping recorded when this product was seen as a
            # reactant of a later (already processed) step.
            tuple_atoms = tuple(curr_prod._atoms)
            prev_remap = react_dict.get(tuple_atoms, {})

            if prev_remap:
                curr_cgr = curr_cgr.remap(prev_remap, copy=True)

            # identify new atom‐numbers for any overlap
            # NOTE(review): the 4th argument is a list of lists of atom keys,
            # but process_target_blocks tests `key in curr_lg_atom_nums` with
            # an int key, which can never match a list — possibly intended to
            # be flattened; TODO confirm.
            target_block = process_target_blocks(
                curr_products,
                curr_prod,
                lg_atom_nums,
                [list(p._atoms.keys()) for p in curr_products[1:]],
                bb_set,
            )
            # Assign each colliding atom a fresh number from max_num upward.
            mapping = {}
            for atom_num in sorted(target_block):
                if atom_num in accum_cgr._atoms and atom_num not in mapping:
                    mapping[atom_num] = max_num
                    max_num += 1

            # carry forward any clean remap on the product itself
            dict_map = {}
            for ap in accum_products:
                clean_map = get_clean_mapping(curr_prod, ap, reverse=True)
                if clean_map:
                    dict_map = clean_map
                    break
            if dict_map:
                curr_cgr = curr_cgr.remap(dict_map, copy=False)

            # update our react_dict & bb_set
            react_dict, bb_set = update_reaction_dict(
                reaction, node_id, mapping, react_dict, tree, bb_set, prev_remap
            )

            # apply the new overlap‐mapping
            if mapping:
                curr_cgr = curr_cgr.remap(mapping, copy=False)

            reactions_dict[step] = ReactionContainer.from_cgr(curr_cgr)
            accum_cgr = curr_cgr.compose(accum_cgr)

        return {"cgr": accum_cgr, "reactions_dict": reactions_dict}

    except Exception as e:
        # NOTE(review): broad catch deliberately converts any per-route
        # failure into None; callers (compose_all_route_cgrs,
        # extract_reactions) rely on this None contract.
        print(f"Error processing node {node_id}: {e}")
        return None
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def compose_all_route_cgrs(tree_or_routes, node_id=None):
    """
    Process routes (reassign atom mappings) to compose RouteCGR.

    Parameters
    ----------
    tree_or_routes : synplan.mcts.tree.Tree
        or dict mapping route_id -> {step_id: ReactionContainer}
    node_id : int or None
        if None, do *all* winning routes (or all keys of the dict);
        otherwise only that specific route.

    Returns
    -------
    dict or None
        - if node_id is None: {route_id: CGR, …}
        - if node_id is given: {node_id: CGR}
        - returns None on error
    """
    # --- dict of pre-assembled routes ---
    if isinstance(tree_or_routes, dict):
        routes = tree_or_routes

        if node_id is not None and node_id not in routes:
            raise KeyError(f"Route {node_id} not in provided dict.")

        def _cgr_of(route_id):
            # compose_route_cgr yields None on failure; keep the None.
            outcome = compose_route_cgr(routes, route_id)
            return outcome["cgr"] if outcome else None

        if node_id is not None:
            return {node_id: _cgr_of(node_id)}
        return {route_id: _cgr_of(route_id) for route_id in sorted(routes)}

    # --- retrosynthetic tree ---
    tree = tree_or_routes

    if node_id is not None:
        single = compose_route_cgr(tree, node_id)
        if not single:
            # Unlike the dict branch, a failed single route maps to None overall.
            return None
        return {node_id: single["cgr"]}

    collected = {}
    for route_id in sorted(set(tree.winning_nodes)):
        outcome = compose_route_cgr(tree, route_id)
        if outcome:
            collected[route_id] = outcome["cgr"]
    return collected
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def extract_reactions(tree: Tree, node_id=None):
    """
    Collect mapped reaction sequences from a synthesis tree.

    Traverses either a single branch (if `node_id` is given) or all winning
    routes, composing CGR-based reactions for each, and returns a dict of
    reaction mappings. Atom indices in every extracted reaction are uniquely
    mapped (no overlaps).

    Parameters
    ----------
    tree : ReactionTree
        A retrosynthetic tree object with a `.winning_nodes` attribute and
        supporting `compose_route_cgr(...)`.
    node_id : hashable, optional
        If provided, only extract reactions for this specific node/route.

    Returns
    -------
    dict[node_id, dict]
        Maps each route terminal node ID to its `reactions_dict` (as returned
        by `compose_route_cgr`). Returns `None` if the specified `node_id`
        fails to produce valid reactions.
    """
    if node_id is not None:
        result = compose_route_cgr(tree, node_id)
        if not result:
            # Propagate failure of the requested route as None.
            return None
        return {node_id: result["reactions_dict"]}

    collected = {}
    for terminal in set(tree.winning_nodes):
        outcome = compose_route_cgr(tree, terminal)
        if outcome:
            collected[terminal] = outcome["reactions_dict"]
    # Deterministic ordering of routes for downstream consumers.
    return dict(sorted(collected.items()))
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def compose_sb_cgr(route_cgr: CGRContainer):
    """
    Reduces a Routes Condensed Graph of reaction (RouteCGR) by performing the following steps:

    1. Extracts substructures corresponding to connected components from the input RouteCGR.
    2. Selects the first substructure as the target to work on.
    3. Iterates over all bonds in the target RouteCGR:
       - If a bond is identified as a "leaving group" (its product order is None while its original order is defined),
         the bond is removed.
       - If a bond has a modified order (both product and original orders are integers that differ),
         the bond is deleted and then re-added as a dynamic bond using the product order for both states
         (this updates the bond to the reduced form).
    4. After bond modifications, re-extracts the substructure from the target RouteCGR (now called the reduced RouteCGR or ReducedRouteCGR).
    5. If the charge distributions (_p_charges vs. _charges) differ, neutralizes the non-zero charges by setting them to zero.

    Args:
        route_cgr (CGRContainer): The input RouteCGR object to be reduced.

    Returns:
        CGRContainer: The reduced RouteCGR object.
    """
    # Get all connected components of the RouteCGR as separate substructures.
    cgr_prods = [route_cgr.substructure(c) for c in route_cgr.connected_components]
    target_cgr = cgr_prods[
        0
    ]  # Choose the first substructure (main product) for further reduction.

    # Iterate over each bond in the target RouteCGR.
    # Snapshot the outer adjacency so deletions do not invalidate iteration.
    bond_items = list(target_cgr._bonds.items())
    for atom1, bond_set in bond_items:
        # Inner snapshot is taken lazily per atom1; presumably delete_bond
        # removes both directions of the symmetric adjacency, so a bond
        # removed via (a, b) no longer appears under (b, a) — TODO confirm.
        bond_set_items = list(bond_set.items())
        for atom2, bond in bond_set_items:

            # Removing bonds corresponding to leaving groups:
            # If product bond order is None (indicating a leaving group) but an original bond order exists,
            # delete the bond.
            if bond.p_order is None and bond.order is not None:
                target_cgr.delete_bond(atom1, atom2)

            # For bonds that have been modified (not leaving groups) where the product order differs from the original:
            # Remove the bond and re-add it using a DynamicBond with the product order for both bond orders.
            elif (
                type(bond.p_order) is int
                and type(bond.order) is int
                and bond.p_order != bond.order
            ):
                p_order = int(bond.p_order)
                target_cgr.delete_bond(atom1, atom2)
                target_cgr.add_bond(atom1, atom2, DynamicBond(p_order, p_order))

    # After modifying bonds, extract the reduced RouteCGR from the target's connected components.
    # Bond deletions may have split the graph; keep only the first component.
    reduced_route_cgr = [
        target_cgr.substructure(c) for c in target_cgr.connected_components
    ][0]

    # Neutralize charges if the product charges and current charges differ.
    if reduced_route_cgr._p_charges != reduced_route_cgr._charges:
        for num, charge in reduced_route_cgr._charges.items():
            if charge != 0:
                reduced_route_cgr._atoms[num].charge = 0

    return reduced_route_cgr
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def compose_all_sb_cgrs(route_cgrs_dict: dict):
    """
    Produce the reduced form (ReducedRouteCGR) of every RouteCGR in a collection.

    Applies `compose_sb_cgr` to each value of the input mapping, preserving
    the original identifiers as keys.

    Args:
        route_cgrs_dict (dict): Mapping of identifiers (e.g. route numbers)
            to RouteCGR objects.

    Returns:
        dict: Mapping of the same identifiers to the corresponding
        ReducedRouteCGR objects.
    """
    return {
        route_id: compose_sb_cgr(route_cgr)
        for route_id, route_cgr in route_cgrs_dict.items()
    }
|
synplan/chem/reaction_routes/visualisation.py
ADDED
|
@@ -0,0 +1,903 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from CGRtools.algorithms.depict import (
|
| 2 |
+
Depict,
|
| 3 |
+
DepictMolecule,
|
| 4 |
+
DepictCGR,
|
| 5 |
+
rotate_vector,
|
| 6 |
+
_render_charge,
|
| 7 |
+
)
|
| 8 |
+
from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer
|
| 9 |
+
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from uuid import uuid4
|
| 12 |
+
from math import hypot
|
| 13 |
+
from functools import partial
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class WideBondDepictCGR(DepictCGR):
|
| 17 |
+
"""
|
| 18 |
+
Like DepictCGR, but all DynamicBonds
|
| 19 |
+
are drawn 2.5× wider than the standard bond width.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
__slots__ = ()
|
| 23 |
+
|
| 24 |
+
def _render_bonds(self):
|
| 25 |
+
"""
|
| 26 |
+
Renders the bonds of the CGR as SVG lines, with DynamicBonds drawn wider.
|
| 27 |
+
|
| 28 |
+
This method overrides the base `_render_bonds` to apply a wider stroke
|
| 29 |
+
to DynamicBonds, highlighting changes in bond order during a reaction.
|
| 30 |
+
It iterates through all bonds, calculates their positions based on
|
| 31 |
+
2D coordinates, and generates SVG `<line>` elements with appropriate
|
| 32 |
+
styles (color, width, dash array) based on the bond's original (`order`)
|
| 33 |
+
and primary (`p_order`) states. Aromatic bonds are handled separately
|
| 34 |
+
using a helper method.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
list: A list of strings, where each string is an SVG element
|
| 38 |
+
representing a bond.
|
| 39 |
+
"""
|
| 40 |
+
plane = self._plane
|
| 41 |
+
config = self._render_config
|
| 42 |
+
|
| 43 |
+
# get the normal width (default 1.0) and compute a 4× wide stroke
|
| 44 |
+
normal_width = config.get("bond_width", 0.02)
|
| 45 |
+
wide_width = normal_width * 2.5
|
| 46 |
+
|
| 47 |
+
broken = config["broken_color"]
|
| 48 |
+
formed = config["formed_color"]
|
| 49 |
+
dash1, dash2 = config["dashes"]
|
| 50 |
+
double_space = config["double_space"]
|
| 51 |
+
triple_space = config["triple_space"]
|
| 52 |
+
|
| 53 |
+
svg = []
|
| 54 |
+
ar_bond_colors = defaultdict(dict)
|
| 55 |
+
|
| 56 |
+
for n, m, bond in self.bonds():
|
| 57 |
+
order, p_order = bond.order, bond.p_order
|
| 58 |
+
nx, ny = plane[n]
|
| 59 |
+
mx, my = plane[m]
|
| 60 |
+
# invert Y for SVG
|
| 61 |
+
ny, my = -ny, -my
|
| 62 |
+
rv = partial(rotate_vector, 0, x2=mx - nx, y2=ny - my)
|
| 63 |
+
if order == 1:
|
| 64 |
+
if p_order == 1:
|
| 65 |
+
svg.append(
|
| 66 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 67 |
+
)
|
| 68 |
+
elif p_order == 4:
|
| 69 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
|
| 70 |
+
svg.append(
|
| 71 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 72 |
+
)
|
| 73 |
+
elif p_order == 2:
|
| 74 |
+
dx, dy = rv(double_space)
|
| 75 |
+
svg.append(
|
| 76 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 77 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 78 |
+
)
|
| 79 |
+
svg.append(
|
| 80 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 81 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 82 |
+
)
|
| 83 |
+
elif p_order == 3:
|
| 84 |
+
dx, dy = rv(triple_space)
|
| 85 |
+
svg.append(
|
| 86 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 87 |
+
f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 88 |
+
)
|
| 89 |
+
svg.append(
|
| 90 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 91 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke-width="{wide_width:.2f}"/>'
|
| 92 |
+
)
|
| 93 |
+
svg.append(
|
| 94 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 95 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 96 |
+
)
|
| 97 |
+
elif p_order is None:
|
| 98 |
+
svg.append(
|
| 99 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 100 |
+
f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 101 |
+
)
|
| 102 |
+
else:
|
| 103 |
+
dx, dy = rv(double_space)
|
| 104 |
+
svg.append(
|
| 105 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
|
| 106 |
+
f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 107 |
+
)
|
| 108 |
+
svg.append(
|
| 109 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 110 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 111 |
+
)
|
| 112 |
+
elif order == 4:
|
| 113 |
+
if p_order == 4:
|
| 114 |
+
svg.append(
|
| 115 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 116 |
+
)
|
| 117 |
+
elif p_order == 1:
|
| 118 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
|
| 119 |
+
svg.append(
|
| 120 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 121 |
+
)
|
| 122 |
+
elif p_order == 2:
|
| 123 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
|
| 124 |
+
dx, dy = rv(double_space)
|
| 125 |
+
svg.append(
|
| 126 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 127 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 128 |
+
)
|
| 129 |
+
svg.append(
|
| 130 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 131 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 132 |
+
)
|
| 133 |
+
elif p_order == 3:
|
| 134 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
|
| 135 |
+
dx, dy = rv(triple_space)
|
| 136 |
+
svg.append(
|
| 137 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 138 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 139 |
+
)
|
| 140 |
+
svg.append(
|
| 141 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 142 |
+
)
|
| 143 |
+
svg.append(
|
| 144 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 145 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 146 |
+
)
|
| 147 |
+
elif p_order is None:
|
| 148 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
|
| 149 |
+
svg.append(
|
| 150 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 151 |
+
f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
|
| 155 |
+
svg.append(
|
| 156 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 157 |
+
f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 158 |
+
)
|
| 159 |
+
elif order == 2:
|
| 160 |
+
if p_order == 2:
|
| 161 |
+
dx, dy = rv(double_space)
|
| 162 |
+
svg.append(
|
| 163 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 164 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 165 |
+
)
|
| 166 |
+
svg.append(
|
| 167 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 168 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}"/>'
|
| 169 |
+
)
|
| 170 |
+
elif p_order == 1:
|
| 171 |
+
dx, dy = rv(double_space)
|
| 172 |
+
svg.append(
|
| 173 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 174 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 175 |
+
)
|
| 176 |
+
svg.append(
|
| 177 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 178 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 179 |
+
)
|
| 180 |
+
elif p_order == 4:
|
| 181 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
|
| 182 |
+
dx, dy = rv(double_space)
|
| 183 |
+
svg.append(
|
| 184 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 185 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 186 |
+
)
|
| 187 |
+
svg.append(
|
| 188 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 189 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 190 |
+
)
|
| 191 |
+
elif p_order == 3:
|
| 192 |
+
dx, dy = rv(triple_space)
|
| 193 |
+
svg.append(
|
| 194 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 195 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 196 |
+
)
|
| 197 |
+
svg.append(
|
| 198 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 199 |
+
)
|
| 200 |
+
svg.append(
|
| 201 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 202 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed} stroke-width="{wide_width:.2f}""/>'
|
| 203 |
+
)
|
| 204 |
+
elif p_order is None:
|
| 205 |
+
dx, dy = rv(double_space)
|
| 206 |
+
svg.append(
|
| 207 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 208 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 209 |
+
)
|
| 210 |
+
svg.append(
|
| 211 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 212 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
dx, dy = rv(triple_space)
|
| 216 |
+
svg.append(
|
| 217 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
|
| 218 |
+
f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 219 |
+
)
|
| 220 |
+
svg.append(
|
| 221 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 222 |
+
f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 223 |
+
)
|
| 224 |
+
svg.append(
|
| 225 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 226 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 227 |
+
)
|
| 228 |
+
elif order == 3:
|
| 229 |
+
if p_order == 3:
|
| 230 |
+
dx, dy = rv(triple_space)
|
| 231 |
+
svg.append(
|
| 232 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 233 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 234 |
+
)
|
| 235 |
+
svg.append(
|
| 236 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 237 |
+
)
|
| 238 |
+
svg.append(
|
| 239 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 240 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}"/>'
|
| 241 |
+
)
|
| 242 |
+
elif p_order == 1:
|
| 243 |
+
dx, dy = rv(triple_space)
|
| 244 |
+
svg.append(
|
| 245 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 246 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 247 |
+
)
|
| 248 |
+
svg.append(
|
| 249 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 250 |
+
f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 251 |
+
)
|
| 252 |
+
svg.append(
|
| 253 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 254 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" '
|
| 255 |
+
f'stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 256 |
+
)
|
| 257 |
+
elif p_order == 4:
|
| 258 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
|
| 259 |
+
dx, dy = rv(triple_space)
|
| 260 |
+
svg.append(
|
| 261 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}" '
|
| 262 |
+
f'y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 263 |
+
)
|
| 264 |
+
svg.append(
|
| 265 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 266 |
+
)
|
| 267 |
+
svg.append(
|
| 268 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" x2="{mx - dx:.2f}" '
|
| 269 |
+
f'y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 270 |
+
)
|
| 271 |
+
elif p_order == 2:
|
| 272 |
+
dx, dy = rv(triple_space)
|
| 273 |
+
svg.append(
|
| 274 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 275 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
|
| 276 |
+
)
|
| 277 |
+
svg.append(
|
| 278 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
|
| 279 |
+
)
|
| 280 |
+
svg.append(
|
| 281 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 282 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 283 |
+
)
|
| 284 |
+
elif p_order is None:
|
| 285 |
+
dx, dy = rv(triple_space)
|
| 286 |
+
svg.append(
|
| 287 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 288 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 289 |
+
)
|
| 290 |
+
svg.append(
|
| 291 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" '
|
| 292 |
+
f'x2="{mx:.2f}" y2="{my:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 293 |
+
)
|
| 294 |
+
svg.append(
|
| 295 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 296 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
dx, dy = rv(double_space)
|
| 300 |
+
dx3 = 3 * dx
|
| 301 |
+
dy3 = 3 * dy
|
| 302 |
+
svg.append(
|
| 303 |
+
f' <line x1="{nx + dx3:.2f}" y1="{ny - dy3:.2f}" x2="{mx + dx3:.2f}" '
|
| 304 |
+
f'y2="{my - dy3:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 305 |
+
)
|
| 306 |
+
svg.append(
|
| 307 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 308 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 309 |
+
)
|
| 310 |
+
svg.append(
|
| 311 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 312 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 313 |
+
)
|
| 314 |
+
svg.append(
|
| 315 |
+
f' <line x1="{nx - dx3:.2f}" y1="{ny + dy3:.2f}" x2="{mx - dx3:.2f}" '
|
| 316 |
+
f'y2="{my + dy3:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 317 |
+
)
|
| 318 |
+
elif order is None:
|
| 319 |
+
if p_order == 1:
|
| 320 |
+
svg.append(
|
| 321 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 322 |
+
f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 323 |
+
)
|
| 324 |
+
elif p_order == 4:
|
| 325 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
|
| 326 |
+
svg.append(
|
| 327 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 328 |
+
f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 329 |
+
)
|
| 330 |
+
elif p_order == 2:
|
| 331 |
+
dx, dy = rv(double_space)
|
| 332 |
+
# dx = dx // 1.4
|
| 333 |
+
# dy = dy // 1.4
|
| 334 |
+
svg.append(
|
| 335 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}" '
|
| 336 |
+
f'y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 337 |
+
)
|
| 338 |
+
svg.append(
|
| 339 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" x2="{mx - dx:.2f}" '
|
| 340 |
+
f'y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 341 |
+
)
|
| 342 |
+
elif p_order == 3:
|
| 343 |
+
dx, dy = rv(triple_space)
|
| 344 |
+
svg.append(
|
| 345 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 346 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 347 |
+
)
|
| 348 |
+
svg.append(
|
| 349 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 350 |
+
f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 351 |
+
)
|
| 352 |
+
svg.append(
|
| 353 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 354 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
svg.append(
|
| 358 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
|
| 359 |
+
f'stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 360 |
+
)
|
| 361 |
+
else:
|
| 362 |
+
if p_order == 8:
|
| 363 |
+
svg.append(
|
| 364 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
|
| 365 |
+
f'stroke-dasharray="{dash1:.2f} {dash2:.2f}"/>'
|
| 366 |
+
)
|
| 367 |
+
elif p_order == 1:
|
| 368 |
+
dx, dy = rv(double_space)
|
| 369 |
+
svg.append(
|
| 370 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
|
| 371 |
+
f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 372 |
+
)
|
| 373 |
+
svg.append(
|
| 374 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 375 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 376 |
+
)
|
| 377 |
+
elif p_order == 4:
|
| 378 |
+
ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
|
| 379 |
+
svg.append(
|
| 380 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 381 |
+
f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 382 |
+
)
|
| 383 |
+
elif p_order == 2:
|
| 384 |
+
dx, dy = rv(triple_space)
|
| 385 |
+
svg.append(
|
| 386 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
|
| 387 |
+
f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 388 |
+
)
|
| 389 |
+
svg.append(
|
| 390 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
|
| 391 |
+
f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 392 |
+
)
|
| 393 |
+
svg.append(
|
| 394 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 395 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 396 |
+
)
|
| 397 |
+
elif p_order == 3:
|
| 398 |
+
dx, dy = rv(double_space)
|
| 399 |
+
dx3 = 3 * dx
|
| 400 |
+
dy3 = 3 * dy
|
| 401 |
+
svg.append(
|
| 402 |
+
f' <line x1="{nx + dx3:.2f}" y1="{ny - dy3:.2f}" x2="{mx + dx3:.2f}" '
|
| 403 |
+
f'y2="{my - dy3:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 404 |
+
)
|
| 405 |
+
svg.append(
|
| 406 |
+
f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
|
| 407 |
+
f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 408 |
+
)
|
| 409 |
+
svg.append(
|
| 410 |
+
f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
|
| 411 |
+
f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 412 |
+
)
|
| 413 |
+
svg.append(
|
| 414 |
+
f' <line x1="{nx - dx3:.2f}" y1="{ny + dy3:.2f}" '
|
| 415 |
+
f'x2="{mx - dx3:.2f}" y2="{my + dy3:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
|
| 416 |
+
)
|
| 417 |
+
else:
|
| 418 |
+
svg.append(
|
| 419 |
+
f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
|
| 420 |
+
f'stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# aromatic rings - unchanged
|
| 424 |
+
for ring in self.aromatic_rings:
|
| 425 |
+
cx = sum(plane[x][0] for x in ring) / len(ring)
|
| 426 |
+
cy = sum(plane[x][1] for x in ring) / len(ring)
|
| 427 |
+
|
| 428 |
+
for n, m in zip(ring, ring[1:]):
|
| 429 |
+
nx, ny = plane[n]
|
| 430 |
+
mx, my = plane[m]
|
| 431 |
+
aromatic = self.__render_aromatic_bond(
|
| 432 |
+
nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
|
| 433 |
+
)
|
| 434 |
+
if aromatic:
|
| 435 |
+
svg.append(aromatic)
|
| 436 |
+
|
| 437 |
+
n, m = ring[-1], ring[0]
|
| 438 |
+
nx, ny = plane[n]
|
| 439 |
+
mx, my = plane[m]
|
| 440 |
+
aromatic = self.__render_aromatic_bond(
|
| 441 |
+
nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
|
| 442 |
+
)
|
| 443 |
+
if aromatic:
|
| 444 |
+
svg.append(aromatic)
|
| 445 |
+
return svg
|
| 446 |
+
|
| 447 |
+
def __render_aromatic_bond(self, n_x, n_y, m_x, m_y, c_x, c_y, color):
    """Return the SVG for the inner dashed stroke of one aromatic ring bond.

    The bond n->m belongs to an aromatic ring whose centre is (c_x, c_y).
    The inner stroke is drawn parallel to the bond, shifted towards the ring
    centre by ``cgr_aromatic_space``.  ``color`` selects the style:
    a truthy color string draws a wide colored dashed stroke (changed bond),
    ``None`` draws the plain aromatic dash, and any other falsy value draws
    nothing.  Returns the ``<line .../>`` string or ``None``.
    """
    cfg = self._render_config

    dash1, dash2 = cfg["dashes"]
    dash3, dash4 = cfg["aromatic_dashes"]
    aromatic_space = cfg["cgr_aromatic_space"]
    wide_width = cfg.get("bond_width", 0.02) * 2

    # translate so that atom n sits at the origin
    mn_x, mn_y = m_x - n_x, m_y - n_y
    cn_x, cn_y = c_x - n_x, c_y - n_y

    # rotate so the n->m bond lies along the positive x axis
    mr_x = hypot(mn_x, mn_y)
    cr_x, cr_y = rotate_vector(cn_x, cn_y, mn_x, -mn_y)

    # ring centre on (or too close to) the bond line: nothing to draw
    if not cr_y or aromatic_space / cr_y >= 0.65:
        return None

    if cr_y > 0:
        r_y = aromatic_space
    else:
        r_y = -aromatic_space
        cr_y = -cr_y

    # x positions of the trimmed inner stroke, in the rotated frame
    ar_x = aromatic_space * cr_x / cr_y
    br_x = mr_x - aromatic_space * (mr_x - cr_x) / cr_y

    # rotate the endpoints back into the molecule frame
    an_x, an_y = rotate_vector(ar_x, r_y, mn_x, mn_y)
    bn_x, bn_y = rotate_vector(br_x, r_y, mn_x, mn_y)

    if color:
        # changed aromatic bond: wide colored dashed stroke
        return (
            f' <line x1="{an_x + n_x:.2f}" y1="{-an_y - n_y:.2f}" x2="{bn_x + n_x:.2f}" '
            f'y2="{-bn_y - n_y:.2f}" stroke-dasharray="{dash3:.2f} {dash4:.2f}" stroke="{color}" stroke-width="{wide_width:.2f}"/>'
        )
    if color is None:
        # unchanged aromatic bond: plain dashed stroke with the default dashes
        dash3, dash4 = dash1, dash2
        return (
            f' <line x1="{an_x + n_x:.2f}" y1="{-an_y - n_y:.2f}"'
            f' x2="{bn_x + n_x:.2f}" y2="{-bn_y - n_y:.2f}" stroke-dasharray="{dash3:.2f} {dash4:.2f}"/>'
        )
    return None
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def cgr_display(cgr: CGRContainer) -> str:
    """
    Generates an SVG string for displaying a CGR with wider DynamicBonds.

    Patches the rendering hooks of ``CGRContainer`` so that ``depict()`` uses
    the bond-drawing logic of ``WideBondDepictCGR`` (which draws DynamicBonds
    with a wider stroke), cleans the 2D coordinates of the input CGR and
    returns the resulting SVG.

    NOTE(review): the patch is applied to the class itself and is never
    reverted, so every later ``CGRContainer.depict()`` call in the process is
    affected as well — presumably intentional here, but worth confirming.

    Args:
        cgr (CGRContainer): The CGRContainer object to be depicted.

    Returns:
        str: An SVG string representing the depiction of the CGR
            with wider DynamicBonds.
    """
    wide_aromatic = WideBondDepictCGR._WideBondDepictCGR__render_aromatic_bond
    # Install the wide renderer under both name-mangled slots: the one
    # CGRContainer's own code resolves, and the one the borrowed
    # _render_bonds resolves (mangled with the WideBondDepictCGR name).
    CGRContainer._CGRContainer__render_aromatic_bond = wide_aromatic
    CGRContainer._WideBondDepictCGR__render_aromatic_bond = wide_aromatic
    CGRContainer._render_bonds = WideBondDepictCGR._render_bonds
    cgr.clean2d()
    return cgr.depict()
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class CustomDepictMolecule(DepictMolecule):
    """
    Custom molecule depiction class that uses atom.symbol for rendering.

    Overrides ``_render_atoms`` of ``DepictMolecule`` so the displayed label
    comes from ``atom.symbol`` when available (falling back to
    ``atom.atomic_symbol``), and adds defensive fallbacks around attributes
    that may be missing on some atom/container implementations.
    """

    def _render_atoms(self):
        # Returns (svg_fragments, mask) where svg_fragments is a list of SVG
        # element strings and mask maps layer names ("center", "symbols",
        # "span", "other", "aam") to lists of SVG mask elements.
        bonds = self._bonds
        plane = self._plane  # atom id -> (x, y) 2D coordinates
        charges = self._charges
        radicals = self._radicals
        hydrogens = self._hydrogens
        config = self._render_config

        carbon = config["carbon"]
        mapping = config["mapping"]
        span_size = config["span_size"]
        font_size = config["font_size"]
        monochrome = config["monochrome"]
        other_size = config["other_size"]
        atoms_colors = config["atoms_colors"]
        mapping_font = config["mapping_size"]
        dx_m, dy_m = config["dx_m"], config["dy_m"]
        dx_ci, dy_ci = config["dx_ci"], config["dy_ci"]
        symbols_font_style = config["symbols_font_style"]

        # for cumulenes: middle atoms of cumulated chains are always labeled
        try:
            # Check if _cumulenes method exists and handle potential errors
            cumulenes = {
                y
                for x in self._cumulenes(heteroatoms=True)
                if len(x) > 2
                for y in x[1:-1]
            }
        except AttributeError:
            cumulenes = set()  # Fallback if _cumulenes is not available or fails

        if monochrome:
            map_fill = other_fill = "black"
        else:
            map_fill = config["mapping_color"]
            other_fill = config["other_color"]

        svg = []
        maps = []
        others = []
        # precomputed fractions of the font size used for label placement
        font2 = 0.2 * font_size
        font3 = 0.3 * font_size
        font4 = 0.4 * font_size
        font5 = 0.5 * font_size
        font6 = 0.6 * font_size
        font7 = 0.7 * font_size
        font15 = 0.15 * font_size
        font25 = 0.25 * font_size
        mask = defaultdict(list)
        for n, atom in self._atoms.items():
            x, y = plane[n]
            y = -y  # SVG y axis points down; molecule coordinates point up

            # --- KEY CHANGE HERE ---
            # Use atom.symbol if it exists, otherwise fallback to atomic_symbol
            try:
                symbol = atom.symbol
            except AttributeError:
                symbol = atom.atomic_symbol  # Fallback if .symbol doesn't exist
            # --- END KEY CHANGE ---

            # Draw an explicit label for anything that is not a plain
            # skeletal carbon (isolated atoms, heteroatoms, charged/radical/
            # isotopic atoms, cumulene middles, or when carbon labels are on).
            if (
                not bonds.get(n)
                or symbol != "C"
                or carbon
                or atom.charge
                or atom.is_radical
                or atom.isotope
                or n in cumulenes
            ):  # Added bonds.get(n) check for single atoms
                # Calculate hydrogens if the attribute exists, otherwise default to 0
                try:
                    h = hydrogens[n]
                except (KeyError, AttributeError):
                    h = 0  # Default if _hydrogens is missing or key n is not present

                if h == 1:
                    h_str = "H"
                    span = ""
                elif h and h > 1:  # Check if h is not None and greater than 1
                    span = f'<tspan dy="{config["span_dy"]:.2f}" font-size="{span_size:.2f}">{h}</tspan>'
                    h_str = "H"
                else:
                    h_str = ""
                    span = ""

                # Handle charges and radicals safely
                charge_val = charges.get(n, 0)
                is_radical = radicals.get(n, False)

                if charge_val:
                    t = f'{_render_charge.get(charge_val, "")}{"↑" if is_radical else ""}'  # Use .get for safety
                    if t:  # Only add if charge text is generated
                        others.append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">'
                            f"{t}</text>"
                        )
                        mask["other"].append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">'
                            f"{t}</text>"
                        )
                elif is_radical:
                    others.append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">↑</text>'
                    )
                    mask["other"].append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}"'
                        f' dy="-{dy_ci:.2f}">↑</text>'
                    )

                # Handle isotope safely
                try:
                    iso = atom.isotope
                    if iso:
                        t = iso
                        others.append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_ci:.2f}" dy="-{dy_ci:.2f}" '
                            f'text-anchor="end">{t}</text>'
                        )
                        mask["other"].append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_ci:.2f}"'
                            f' dy="-{dy_ci:.2f}" text-anchor="end">{t}</text>'
                        )
                except AttributeError:
                    pass  # Atom might not have isotope attribute

                # Determine atom color based on atomic_number, default to black if monochrome or not found
                atom_color = "black"
                if not monochrome:
                    try:
                        an = atom.atomic_number
                        if 0 < an <= len(atoms_colors):
                            atom_color = atoms_colors[an - 1]
                        else:
                            atom_color = atoms_colors[
                                5
                            ]  # Default to Carbon color if out of range
                    except AttributeError:
                        atom_color = atoms_colors[
                            5
                        ]  # Default to Carbon color if no atomic_number

                svg.append(
                    f' <g fill="{atom_color}" '
                    f'font-family="{symbols_font_style }">'
                )

                # Adjust dx based on symbol length for better centering
                if len(symbol) > 1:
                    dx = font7
                    dx_mm = dx_m + font5
                    if symbol[-1].lower() in (
                        "l",
                        "i",
                        "r",
                        "t",
                    ):  # Heuristic for narrow last letters
                        rx = font6
                        ax = font25
                    else:
                        rx = font7
                        ax = font15
                    mask["center"].append(
                        f' <ellipse cx="{x - ax:.2f}" cy="{y:.2f}" rx="{rx}" ry="{font4}"/>'
                    )
                else:
                    if symbol == "I":  # Special case for 'I'
                        dx = font15
                        dx_mm = dx_m
                    else:  # Single character
                        dx = font4
                        dx_mm = dx_m + font2
                    mask["center"].append(
                        f' <circle cx="{x:.2f}" cy="{y:.2f}" r="{font4:.2f}"/>'
                    )

                svg.append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" dy="{font4:.2f}" '
                    f'font-size="{font_size:.2f}">{symbol}{h_str}{span}</text>'
                )
                mask["symbols"].append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" '
                    f'dy="{font4:.2f}">{symbol}{h_str}</text>'
                )
                if span:
                    mask["span"].append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" dy="{font4:.2f}">'
                        f"{symbol}{h_str}{span}</text>"
                    )
                svg.append(" </g>")

                if mapping:
                    maps.append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m + font3:.2f}" '
                        f'text-anchor="end">{n}</text>'
                    )
                    mask["aam"].append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" '
                        f'dy="{dy_m + font3:.2f}" text-anchor="end">{n}</text>'
                    )

            elif mapping:
                # Determine dx_mm for mapping based on symbol length even if atom itself isn't drawn
                if len(symbol) > 1:
                    dx_mm = dx_m + font5
                else:
                    dx_mm = dx_m + font2 if symbol != "I" else dx_m

                maps.append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m:.2f}" '
                    f'text-anchor="end">{n}</text>'
                )
                mask["aam"].append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m:.2f}" '
                    f'text-anchor="end">{n}</text>'
                )
        if others:
            svg.append(
                f' <g font-family="{config["other_font_style"]}" fill="{other_fill}" '
                f'font-size="{other_size:.2f}">'
            )
            svg.extend(others)
            svg.append(" </g>")
        if mapping:
            svg.append(f' <g fill="{map_fill}" font-size="{mapping_font:.2f}">')
            svg.extend(maps)
            svg.append(" </g>")
        return svg, mask
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def depict_custom_reaction(reaction: ReactionContainer):
    """
    Depicts a ReactionContainer using custom atom rendering logic (replace At to X).

    This function generates an SVG string representing a reaction. It
    temporarily modifies the classes of the molecules within the reaction
    to use a custom depiction logic (`CustomDepictMolecule`) that alters
    how atoms are rendered (specifically, it seems to use `atom.symbol`
    instead of `atom.atomic_symbol`, potentially for replacing 'At' with 'X'
    as mentioned in the original comment). After depicting each molecule
    with the temporary class, it restores the original classes. The function
    then combines the individual molecule depictions, reaction arrow, and
    reaction signs into a single SVG.

    Args:
        reaction (ReactionContainer): The ReactionContainer object to be depicted.

    Returns:
        str: An SVG string representing the depiction of the reaction
            with custom atom rendering.
    """
    if not reaction._arrow:
        reaction.fix_positions()  # Ensure positions are calculated

    # per-molecule depiction results and the overall bounding box
    r_atoms = []
    r_bonds = []
    r_masks = []
    r_max_x = r_max_y = r_min_y = 0
    original_classes = {}  # Store original classes to restore later

    try:
        # Temporarily change the class of molecules to use the custom depiction
        for mol in reaction.molecules():
            if isinstance(mol, (MoleculeContainer, CGRContainer)):
                original_classes[mol] = mol.__class__
                custom_class_name = (
                    f"TempCustom_{mol.__class__.__name__}_{uuid4().hex}"  # Unique name
                )
                # Combine custom depiction with original class methods
                # Ensure the custom _render_atoms takes precedence
                new_bases = (CustomDepictMolecule,) + original_classes[mol].__bases__
                # Filter out DepictMolecule if it's already a base to avoid MRO issues
                new_bases = tuple(b for b in new_bases if b is not DepictMolecule)
                # If DepictMolecule wasn't a direct base, ensure its methods are accessible
                if CustomDepictMolecule not in original_classes[mol].__mro__:
                    # Prioritize CustomDepictMolecule's methods
                    new_bases = (CustomDepictMolecule, original_classes[mol])
                else:
                    # If DepictMolecule was a base, CustomDepictMolecule is already first
                    new_bases = (CustomDepictMolecule,) + tuple(
                        b
                        for b in original_classes[mol].__bases__
                        if b is not DepictMolecule
                    )

                # Create the temporary class
                mol.__class__ = type(custom_class_name, new_bases, {})

            # Depict using the (potentially) modified class
            atoms, bonds, masks, min_x, min_y, max_x, max_y = mol.depict(embedding=True)
            r_atoms.append(atoms)
            r_bonds.append(bonds)
            r_masks.append(masks)
            if max_x > r_max_x:
                r_max_x = max_x
            if max_y > r_max_y:
                r_max_y = max_y
            if min_y < r_min_y:
                r_min_y = min_y

    finally:
        # Restore original classes
        for mol, original_class in original_classes.items():
            mol.__class__ = original_class

    config = DepictMolecule._render_config  # Access via the imported class

    # overall canvas geometry derived from the union of molecule boxes
    font_size = config["font_size"]
    font125 = 1.25 * font_size
    width = r_max_x + 3.0 * font_size
    height = r_max_y - r_min_y + 2.5 * font_size
    viewbox_x = -font125
    viewbox_y = -r_max_y - font125

    svg = [
        f'<svg width="{width:.2f}cm" height="{height:.2f}cm" '
        f'viewBox="{viewbox_x:.2f} {viewbox_y:.2f} {width:.2f} '
        f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">\n'
        ' <defs>\n <marker id="arrow" markerWidth="10" markerHeight="10" '
        'refX="0" refY="3" orient="auto">\n <path d="M0,0 L0,6 L9,3"/>\n </marker>\n </defs>\n'
        f' <line x1="{reaction._arrow[0]:.2f}" y1="0" x2="{reaction._arrow[1]:.2f}" y2="0" '
        'fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>'
    ]

    # x positions of the "+" signs between reactants/products
    sings_plus = reaction._signs
    if sings_plus:
        svg.append(f' <g fill="none" stroke="black" stroke-width=".04">')
        for x in sings_plus:
            svg.append(
                f' <line x1="{x + .35:.2f}" y1="0" x2="{x + .65:.2f}" y2="0"/>'
            )
            svg.append(
                f' <line x1="{x + .5:.2f}" y1="0.15" x2="{x + .5:.2f}" y2="-0.15"/>'
            )
        svg.append(" </g>")

    for atoms, bonds, masks in zip(r_atoms, r_bonds, r_masks):
        # Use the static method from Depict directly
        svg.extend(
            Depict._graph_svg(atoms, bonds, masks, viewbox_x, viewbox_y, width, height)
        )
    svg.append("</svg>")
    return "\n".join(svg)
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
def remove_and_shift(nested_dict, to_remove):  # Under development
    """
    Removes specified inner keys from a nested dictionary and renumbers the remaining keys.

    For every inner dictionary of ``nested_dict``, drops the entries whose
    keys appear in ``to_remove`` and shifts each surviving key down by the
    number of removed keys smaller than it, closing the gaps left behind.
    The input is not modified; a new nested dictionary is returned.

    Args:
        nested_dict (dict): The input nested dictionary (dict of dicts).
        to_remove (list): A list of keys to remove from the inner dictionaries.

    Returns:
        dict: A new nested dictionary with the specified keys removed from
            inner dictionaries and the remaining inner keys renumbered.
    """
    removed = set(to_remove)

    renumbered = {}
    for outer_key, inner in nested_dict.items():
        kept = {}
        for key, value in inner.items():
            if key in removed:
                continue
            # shift down by the count of removed keys below this one
            offset = sum(1 for r in removed if r < key)
            kept[key - offset] = value
        renumbered[outer_key] = kept
    return renumbered
|
synplan/chem/reaction_rules/__init__.py
ADDED
|
File without changes
|
synplan/chem/reaction_rules/extraction.py
ADDED
|
@@ -0,0 +1,744 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing functions for protocol of reaction rules extraction."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import pickle
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from itertools import islice
|
| 7 |
+
from os.path import splitext
|
| 8 |
+
from typing import Dict, List, Set, Tuple
|
| 9 |
+
|
| 10 |
+
import ray
|
| 11 |
+
from chython import smarts
|
| 12 |
+
from chython import QueryContainer as QueryContainerChython
|
| 13 |
+
from CGRtools.containers.cgr import CGRContainer
|
| 14 |
+
from CGRtools.containers.molecule import MoleculeContainer
|
| 15 |
+
from CGRtools.containers.query import QueryContainer
|
| 16 |
+
from CGRtools.containers.reaction import ReactionContainer
|
| 17 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
| 18 |
+
from CGRtools.reactor import Reactor
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
from synplan.chem.data.standardizing import RemoveReagentsStandardizer
|
| 22 |
+
from synplan.chem.utils import (
|
| 23 |
+
reverse_reaction,
|
| 24 |
+
cgrtools_to_chython_molecule,
|
| 25 |
+
chython_query_to_cgrtools,
|
| 26 |
+
)
|
| 27 |
+
from synplan.utils.config import RuleExtractionConfig
|
| 28 |
+
from synplan.utils.files import ReactionReader
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def add_environment_atoms(
    cgr: CGRContainer, center_atoms: Set[int], environment_atom_count: int
) -> Set[int]:
    """Extend the reaction-center atoms with their surrounding environment.

    :param cgr: Condensed graph representation of the reaction.
    :param center_atoms: Atom ids of the reaction center.
    :param environment_atom_count: Depth of the environment to include around
        the reaction center. 0 keeps the reaction center only, 1 adds the
        first shell of surrounding atoms, and so on.
    :return: Atom ids of the reaction center plus its environment up to the
        requested depth (the original set when the depth is 0).
    """
    if not environment_atom_count:
        # depth 0: the rule is restricted to the reaction center itself
        return center_atoms
    environment = cgr.augmented_substructure(center_atoms, deep=environment_atom_count)
    return center_atoms | set(environment)
def add_functional_groups(
    reaction: ReactionContainer,
    center_atoms: Set[int],
    func_groups_list: List[QueryContainerChython],
) -> Set[int]:
    """Extend the rule atoms with matched functional groups.

    Every functional group pattern that overlaps the reaction center in any
    molecule of the reaction is merged whole into the rule atoms.

    :param reaction: Reaction whose molecules are scanned for the groups.
    :param center_atoms: Atom ids of the reaction center.
    :param func_groups_list: Query patterns of the functional groups to match.
    :return: Rule atom ids extended with all overlapping functional groups.
    """
    rule_atoms = set(center_atoms)
    for molecule in reaction.molecules():
        chython_molecule = cgrtools_to_chython_molecule(molecule)
        for group in func_groups_list:
            for mapping in group.get_mapping(chython_molecule):
                # renumber the pattern into the molecule's atom numbering
                group.remap(mapping)
                group_atoms = set(group.atoms_numbers)
                # a group sharing atoms with the reaction center is kept whole
                if group_atoms & center_atoms:
                    rule_atoms |= group_atoms
                # restore the original numbering before the next match
                group.remap({new: old for old, new in mapping.items()})
    return rule_atoms
def add_ring_structures(cgr: CGRContainer, rule_atoms: Set[int]) -> Set[int]:
    """Merge into the rule atoms every SSSR ring that touches them.

    :param cgr: Condensed graph representation of the reaction.
    :param rule_atoms: Current rule atom ids.
    :return: Rule atom ids extended with all rings that intersect them.
    """
    for ring in cgr.sssr:
        ring_atoms = set(ring)
        # a ring sharing at least one atom with the rule is included whole
        if ring_atoms & rule_atoms:
            rule_atoms |= ring_atoms
    return rule_atoms
def add_leaving_incoming_groups(
    reaction: ReactionContainer,
    rule_atoms: Set[int],
    keep_leaving_groups: bool,
    keep_incoming_groups: bool,
) -> Tuple[Set[int], Dict[str, Set]]:
    """Attach leaving and/or incoming groups to the rule atoms.

    Leaving groups are reactant atoms absent from the products; incoming
    groups are product atoms absent from the reactants.

    :param reaction: Reaction from which the groups are derived.
    :param rule_atoms: Current rule atom ids.
    :param keep_leaving_groups: Whether to add leaving-group atoms to the rule.
    :param keep_incoming_groups: Whether to add incoming-group atoms to the rule.
    :return: The updated rule atom ids and a metadata dict with the newly
        added "leaving" and "incoming" atom ids.
    """
    meta_debug = {"leaving": set(), "incoming": set()}

    reactant_atoms = {atom for molecule in reaction.reactants for atom in molecule}
    product_atoms = {atom for molecule in reaction.products for atom in molecule}

    if keep_leaving_groups:
        # leaving groups: present in reactants, absent from products
        leaving_atoms = reactant_atoms - product_atoms
        meta_debug["leaving"] |= leaving_atoms - rule_atoms
        rule_atoms |= leaving_atoms

    if keep_incoming_groups:
        # incoming groups: present in products, absent from reactants
        incoming_atoms = product_atoms - reactant_atoms
        meta_debug["incoming"] |= incoming_atoms - rule_atoms
        rule_atoms |= incoming_atoms

    return rule_atoms, meta_debug
def clean_molecules(
    rule_molecules: List[MoleculeContainer],
    reaction_molecules: Tuple[MoleculeContainer],
    reaction_center_atoms: Set[int],
    atom_retention_details: Dict[str, Dict[str, bool]],
) -> List[QueryContainer]:
    """Convert rule molecules to queries and strip non-retained atom marks.

    Each rule molecule is matched to the first reaction molecule containing
    it, converted to a query through that parent (so the query marks come
    from the full molecular context), and then cleaned according to the
    retention flags for reaction-center and environment atoms.

    :param rule_molecules: Rule molecules extracted from the reaction.
    :param reaction_molecules: Molecules of the reaction the rules came from.
    :param reaction_center_atoms: Atom ids of the reaction center.
    :param atom_retention_details: Two nested flag dicts under the keys
        "reaction_center" and "environment"; each maps an atom attribute
        ("neighbors", "hybridization", "implicit_hydrogens", "ring_sizes")
        to True (keep the mark) or False (erase it).
    :return: The cleaned query molecules.
    """
    cleaned_rule_molecules = []

    for rule_molecule in rule_molecules:
        rule_atom_numbers = set(rule_molecule.atoms_numbers)
        for parent in reaction_molecules:
            if not rule_atom_numbers <= set(parent.atoms_numbers):
                continue
            # build the query through the parent molecule so the atom marks
            # are derived from the complete molecular context
            parent_query = parent.substructure(parent, as_query=True)
            rule_query = parent_query.substructure(rule_molecule)

            # strip marks from reaction-center atoms unless all are retained
            center_flags = atom_retention_details["reaction_center"]
            if not all(center_flags.values()):
                for atom_number in rule_atom_numbers & reaction_center_atoms:
                    rule_query = clean_atom(rule_query, center_flags, atom_number)

            # strip marks from environment atoms unless all are retained
            environment_flags = atom_retention_details["environment"]
            if not all(environment_flags.values()):
                for atom_number in rule_atom_numbers - reaction_center_atoms:
                    rule_query = clean_atom(rule_query, environment_flags, atom_number)

            cleaned_rule_molecules.append(rule_query)
            break

    return cleaned_rule_molecules
def clean_atom(
    query_molecule: QueryContainer,
    attributes_to_keep: Dict[str, bool],
    atom_number: int,
) -> QueryContainer:
    """Strip non-retained query marks from one atom of a query molecule.

    :param query_molecule: Query molecule to modify.
    :param attributes_to_keep: Flags per attribute name ("neighbors",
        "hybridization", "implicit_hydrogens", "ring_sizes"); a False value
        means the corresponding mark is erased from the atom.
    :param atom_number: Number of the atom to clean.
    :return: The same query molecule, with the atom's marks updated.
    """
    target_atom = query_molecule.atom(atom_number)
    for attribute in ("neighbors", "hybridization", "implicit_hydrogens", "ring_sizes"):
        if not attributes_to_keep[attribute]:
            # erasing the mark makes the query indifferent to this property
            setattr(target_atom, attribute, None)
    return query_molecule
def create_substructures_and_reagents(
    reaction: ReactionContainer,
    rule_atoms: Set[int],
    as_query_container: bool,
    keep_reagents: bool,
) -> Tuple[List[MoleculeContainer], List[MoleculeContainer], List]:
    """Cut out the rule substructures from reactants, products and reagents.

    :param reaction: Reaction to extract substructures from.
    :param rule_atoms: Atom ids of the rule; molecules with no overlap are
        dropped.
    :param as_query_container: If True, reagents are converted to query
        containers (reactant/product substructures are converted later).
    :param keep_reagents: If False, the returned reagent list is empty.
    :return: A tuple of (reactant substructures, product substructures,
        reagents).
    """

    def _overlapping_substructures(molecules):
        # keep, for every molecule, only the part overlapping the rule atoms
        substructures = []
        for molecule in molecules:
            overlap = rule_atoms.intersection(molecule.atoms_numbers)
            if overlap:
                substructures.append(molecule.substructure(overlap))
        return substructures

    reactant_substructures = _overlapping_substructures(reaction.reactants)
    product_substructures = _overlapping_substructures(reaction.products)

    if not keep_reagents:
        reagents = []
    elif as_query_container:
        reagents = [
            reagent.substructure(reagent, as_query=True)
            for reagent in reaction.reagents
        ]
    else:
        reagents = reaction.reagents

    return reactant_substructures, product_substructures, reagents
def assemble_final_rule(
    reactant_substructures: List[QueryContainer],
    product_substructures: List[QueryContainer],
    reagents: List,
    meta_debug: Dict[str, Set],
    keep_metadata: bool,
    reaction: ReactionContainer,
) -> ReactionContainer:
    """Assemble the final reaction rule from substructures and metadata.

    :param reactant_substructures: Substructures cut out of the reactants.
    :param product_substructures: Substructures cut out of the products.
    :param reagents: Reagents to carry along (possibly empty, possibly
        already converted to queries).
    :param meta_debug: Extra metadata about the rule (e.g. leaving/incoming
        group atoms); not mutated by this function.
    :param keep_metadata: Whether to copy the reaction's metadata and name
        into the rule.
    :param reaction: The source reaction providing metadata and name.
    :return: The assembled ReactionContainer rule with a flushed cache.
    """
    if keep_metadata:
        # build a fresh dict: the original code aliased meta_debug and then
        # update()-ed it, mutating the caller's dictionary as a side effect
        rule_metadata = {**meta_debug, **reaction.meta}
    else:
        rule_metadata = {}

    rule = ReactionContainer(
        reactant_substructures, product_substructures, reagents, rule_metadata
    )

    if keep_metadata:
        rule.name = reaction.name

    # invalidate cached derived data so the rule is re-canonicalized on use
    rule.flush_cache()
    return rule
def validate_rule(rule: ReactionContainer, reaction: ReactionContainer) -> bool:
    """
    Validates a reaction rule by ensuring it can correctly generate the products from
    the reactants. The function uses a chemical reactor to simulate the reaction based
    on the provided rule. It then compares the products generated by the simulation
    with the actual products of the reaction. If they match, the rule is considered
    valid.

    :param rule: The reaction rule to be validated. This is a ReactionContainer object
        representing a chemical reaction rule, which includes the necessary
        information to perform a reaction.
    :param reaction: The original reaction object (ReactionContainer) against which
        the rule is to be validated. This object contains the actual reactants and
        products of the reaction.

    :return: True if the rule regenerates exactly the original products from the
        reactants; False otherwise, including when the reactor fails internally
        (no exception is propagated to the caller).
    """

    # create a reactor with the given rule
    reactor = Reactor(rule)
    try:
        for result_reaction in reactor(reaction.reactants):
            result_products = []
            for result_product in result_reaction.products:
                # work on a copy: kekule() mutates the molecule
                tmp = result_product.copy()
                try:
                    tmp.kekule()
                    # discard generated products with broken valence
                    if tmp.check_valence():
                        continue
                except InvalidAromaticRing:
                    # discard generated products that cannot be kekulized
                    continue
                result_products.append(result_product)
            # exact match: same distinct products AND same product count
            if set(reaction.products) == set(result_products) and len(
                reaction.products
            ) == len(result_products):
                return True

    except (KeyError, IndexError):
        # KeyError - iteration over reactor is finished and products are different from the original reaction
        # IndexError - mistake in __contract_ions, possibly problems with charges in reaction rule
        return False

    return False
def create_rule(
    config: RuleExtractionConfig, reaction: ReactionContainer
) -> ReactionContainer:
    """Build a single reaction rule from a reaction according to the configuration.

    Pipeline: build the reaction CGR and pick the reaction-center atoms,
    optionally grow them with environment atoms, functional groups and rings,
    attach leaving/incoming groups, cut out the matching substructures,
    optionally convert them to cleaned query containers, assemble the final
    rule, optionally reverse it, and optionally validate it with a reactor.

    :param config: Rule extraction settings (environment depth, functional
        groups, rings, leaving/incoming groups, query conversion, metadata,
        reversal and validation flags).
    :param reaction: Reaction from which the rule is created.
    :return: The extracted reaction rule; when reactor validation is enabled,
        its meta contains "reactor_validation" set to "passed" or "failed".
    """
    # reaction CGR, its reaction-center atoms and the configured environment
    cgr = ~reaction
    center_atoms = add_environment_atoms(
        cgr, set(cgr.center_atoms), config.environment_atom_count
    )

    # optionally grow the rule with matched functional groups
    if config.include_func_groups and config.func_groups_list:
        rule_atoms = add_functional_groups(
            reaction, center_atoms, config.func_groups_list
        )
    else:
        rule_atoms = center_atoms.copy()

    # optionally absorb whole rings touching the rule atoms
    if config.include_rings:
        rule_atoms = add_ring_structures(cgr, rule_atoms)

    # attach leaving/incoming groups as configured
    rule_atoms, meta_debug = add_leaving_incoming_groups(
        reaction, rule_atoms, config.keep_leaving_groups, config.keep_incoming_groups
    )

    # cut out the substructures matching the rule atoms
    reactant_substructures, product_substructures, reagents = (
        create_substructures_and_reagents(
            reaction, rule_atoms, config.as_query_container, config.keep_reagents
        )
    )

    # when queries are requested, strip non-retained atom marks
    if config.as_query_container:
        reactant_substructures = clean_molecules(
            reactant_substructures,
            reaction.reactants,
            center_atoms,
            config.atom_info_retention,
        )
        product_substructures = clean_molecules(
            product_substructures,
            reaction.products,
            center_atoms,
            config.atom_info_retention,
        )

    rule = assemble_final_rule(
        reactant_substructures,
        product_substructures,
        reagents,
        meta_debug,
        config.keep_metadata,
        reaction,
    )

    # retrosynthetic direction: flip both the rule and the reaction so that
    # the validation below compares matching directions
    if config.reverse_rule:
        rule = reverse_reaction(rule)
        reaction = reverse_reaction(reaction)

    if config.reactor_validation:
        rule.meta["reactor_validation"] = (
            "passed" if validate_rule(rule, reaction) else "failed"
        )

    return rule
def extract_rules(
    config: RuleExtractionConfig, reaction: ReactionContainer
) -> List[ReactionContainer]:
    """Extract reaction rules from a reaction according to the configuration.

    :param config: Rule extraction settings (multicenter handling, functional
        groups, rings, leaving/incoming groups, etc.).
    :param reaction: Reaction to extract rules from.
    :return: Either a single rule covering all reaction centers
        (config.multicenter_rules is True) or one rule per distinct reaction
        center, limited to the first 15 centers.
    """
    # reagents are stripped here; they are re-attached later only if the
    # rule specification asks for them
    reaction = RemoveReagentsStandardizer()(reaction)

    if config.multicenter_rules:
        # one rule encompassing every reaction center at once
        return [create_rule(config, reaction)]

    # one rule per distinct reaction center, deduplicated via a set
    return list(
        {
            create_rule(config, center_reaction)
            for center_reaction in islice(reaction.enumerate_centers(), 15)
        }
    )
@ray.remote
def process_reaction_batch(
    batch: List[Tuple[int, ReactionContainer]], config: RuleExtractionConfig
) -> List[Tuple[int, List[ReactionContainer]]]:
    """Extract rules for a batch of indexed reactions (Ray remote task).

    Runs as a distributed Ray task so multiple batches can be processed in
    parallel.

    :param batch: Pairs of (reaction index, reaction); the index tracks the
        reaction's position in the full dataset.
    :param config: Rule extraction settings.
    :return: Pairs of (reaction index, extracted rules). Reactions whose
        extraction raised are skipped; the error is logged at debug level so
        one bad reaction does not kill the whole batch.
    """
    results = []
    for index, reaction in batch:
        try:
            results.append((index, extract_rules(config, reaction)))
        except Exception as e:  # deliberate best-effort: skip failing reactions
            logging.debug(e)
    return results
def process_completed_batch(
    futures: Dict,
    rules_statistics: Dict,
) -> None:
    """
    Waits for one completed batch of reactions (processed in parallel with
    Ray) and folds its results into the rules statistics. For every extracted
    rule the index of the source reaction is appended to the rule's
    statistics; the first reaction that produced a given rule is recorded in
    the rule's meta. The completed future is removed from the futures
    dictionary. (The original docstring claimed this function also writes
    rules to a file and updates a progress bar — it does neither.)

    :param futures: A dictionary of futures representing ongoing batch
        processing tasks; the finished one is deleted from it.
    :param rules_statistics: A defaultdict(list) mapping each rule to the
        indices of the reactions it was extracted from.
    :return: None
    """
    # block until at least one submitted batch has finished
    ready_ids, _ = ray.wait(list(futures.keys()), num_returns=1)
    completed_batch = ray.get(ready_ids[0])

    for index, extracted_rules in completed_batch:
        for rule in extracted_rules:
            prev_stats_len = len(rules_statistics)
            rules_statistics[rule].append(index)
            # the statistics dict grew, so this rule is seen for the first
            # time: remember which reaction produced it
            if len(rules_statistics) != prev_stats_len:
                rule.meta["first_reaction_index"] = index

    del futures[ready_ids[0]]
def sort_rules(
    rules_stats: Dict, min_popularity: int, single_reactant_only: bool
) -> List[Tuple[ReactionContainer, List[int]]]:
    """Filter extracted rules and sort them by popularity.

    Keeps only the rules that passed reactor validation, occur at least
    ``min_popularity`` times and — when ``single_reactant_only`` is set —
    have exactly one reactant. The surviving rules are ordered from most to
    least popular.

    :param rules_stats: Mapping from a rule to the list of reaction indices
        the rule was extracted from.
    :param min_popularity: Minimum number of occurrences for a rule to be
        kept.
    :param single_reactant_only: If True, keep only rules with a single
        reactant molecule.
    :return: (rule, indices) pairs sorted by descending popularity.
    """

    def _keep(rule, indices):
        # popularity threshold, reactor validation, optional reactant count
        if len(indices) < min_popularity:
            return False
        if rule.meta["reactor_validation"] != "passed":
            return False
        return not single_reactant_only or len(rule.reactants) == 1

    selected = [item for item in rules_stats.items() if _keep(*item)]
    selected.sort(key=lambda item: len(item[1]), reverse=True)
    return selected
def extract_rules_from_reactions(
    config: RuleExtractionConfig,
    reaction_data_path: str,
    reaction_rules_path: str,
    num_cpus: int,
    batch_size: int,
) -> None:
    """
    Extracts reaction rules from a set of reactions based on the given configuration.
    This function initializes a Ray environment for distributed computing and processes
    each reaction in the provided reaction database to extract reaction rules. It
    handles the reactions in batches, parallelizing the rule extraction process.
    Extracted rules and their statistics are accumulated per rule, then the rules
    are filtered and sorted by popularity with ``sort_rules`` and pickled to
    ``<reaction_rules_path>.pickle``.

    :param config: Configuration settings for rule extraction; also provides
        ``min_popularity`` and ``single_reactant_only`` used for sorting.
    :param reaction_data_path: Path to the file containing the reaction database.
    :param reaction_rules_path: Name of the file to store the extracted rules
        (its extension, if any, is replaced with ``.pickle``).
    :param num_cpus: Number of CPU cores to use for processing (also the maximum
        number of batches in flight at once).
    :param batch_size: Number of reactions to process in each batch.
    :return: None
    """

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)

    # Drop any extension: the result is always written as <name>.pickle below.
    reaction_rules_path, _ = splitext(reaction_rules_path)
    with ReactionReader(reaction_data_path) as reactions:

        futures = {}  # in-flight Ray futures, used as an insertion-ordered set
        batch = []
        max_concurrent_batches = num_cpus
        extracted_rules_and_statistics = defaultdict(list)

        for index, reaction in tqdm(
            enumerate(reactions),
            desc="Number of reactions processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):

            # reaction ready to use
            batch.append((index, reaction))
            if len(batch) == batch_size:
                future = process_reaction_batch.remote(batch, config)

                futures[future] = None
                batch = []

                # Backpressure: block until at least one batch finishes before
                # submitting more work.
                while len(futures) >= max_concurrent_batches:
                    process_completed_batch(
                        futures,
                        extracted_rules_and_statistics,
                    )

        # Submit the final, possibly incomplete batch.
        if batch:
            future = process_reaction_batch.remote(batch, config)
            futures[future] = None

        # Drain every remaining in-flight batch.
        while futures:
            process_completed_batch(
                futures,
                extracted_rules_and_statistics,
            )

    sorted_rules = sort_rules(
        extracted_rules_and_statistics,
        min_popularity=config.min_popularity,
        single_reactant_only=config.single_reactant_only,
    )

    ray.shutdown()

    with open(f"{reaction_rules_path}.pickle", "wb") as statistics_file:
        pickle.dump(sorted_rules, statistics_file)

    print(f"Number of extracted reaction rules: {len(sorted_rules)}")
|
synplan/chem/reaction_rules/manual/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Aggregation of the manually written (hardcoded) reaction rules."""

from .decompositions import rules as d_rules
from .transformations import rules as t_rules

# Transformation rules first, then decomposition rules.
hardcoded_rules = t_rules + d_rules

__all__ = ["hardcoded_rules"]
|
synplan/chem/reaction_rules/manual/decompositions.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing hardcoded decomposition reaction rules."""
|
| 2 |
+
|
| 3 |
+
from CGRtools import QueryContainer, ReactionContainer
|
| 4 |
+
from CGRtools.periodictable import ListElement
|
| 5 |
+
|
| 6 |
+
rules = []
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def prepare():
    """Create the query containers for a new two-product decomposition rule.

    Builds one reactant-side query container and two product-side query
    containers, wraps them in a ReactionContainer that is appended to the
    module-level ``rules`` list, and returns the three (still empty)
    containers so the caller can populate them with atoms and bonds.

    :return: tuple ``(q_, p1_, p2_)`` of empty QueryContainer objects.
    """
    q_ = QueryContainer()
    p1_ = QueryContainer()
    p2_ = QueryContainer()
    rules.append(ReactionContainer((q_,), (p1_, p2_)))

    return q_, p1_, p2_
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Each section below registers one decomposition rule: ``prepare()`` appends a
# new ReactionContainer to ``rules`` and returns the empty query containers,
# which are then populated atom-by-atom. The comment above each section names
# the reaction; where present, a SMARTS-like line sketches the pattern.

# R-amide/ester formation
# [C](-[N,O;D23;Zs])(-[C])=[O]>>[A].[C]-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom(ListElement(["N", "O"]), hybridization=1, neighbors=(2, 3))
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O", _map=5)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 2)
p1.add_bond(2, 5, 1)

p2.add_atom("A", _map=4)

# acyl group addition with aromatic carbon's case (Friedel-Crafts)
# [C;Za]-[C](-[C])=[O]>>[C].[C]-[C](-[Cl])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("C", hybridization=4)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("Cl", _map=5)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 2)
p1.add_bond(2, 5, 1)

p2.add_atom("C", _map=4)

# Williamson reaction
# [C;Za]-[O]-[C;Zs;W0]>>[C]-[Br].[C]-[O]
q, p1, p2 = prepare()
q.add_atom("C", hybridization=4)
q.add_atom("O")
q.add_atom("C", hybridization=1, heteroatoms=1)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)

p1.add_atom("C")
p1.add_atom("O")
p1.add_bond(1, 2, 1)

p2.add_atom("C", _map=3)
p2.add_atom("Br")
p2.add_bond(3, 4, 1)

# Buchwald-Hartwig amination
# [N;D23;Zs;W0]-[C;Za]>>[C]-[Br].[N]
q, p1, p2 = prepare()
q.add_atom("N", heteroatoms=0, hybridization=1, neighbors=(2, 3))
q.add_atom("C", hybridization=4)
q.add_bond(1, 2, 1)

p1.add_atom("C", _map=2)
p1.add_atom("Br")
p1.add_bond(2, 3, 1)

p2.add_atom("N")

# imidazole imine atom's alkylation
# [C;r5](:[N;r5]-[C;Zs;W1]):[N;D2;r5]>>[C]-[Br].[N]:[C]:[N]
q, p1, p2 = prepare()
q.add_atom("N", rings_sizes=5)
q.add_atom("C", rings_sizes=5)
q.add_atom("N", rings_sizes=5, neighbors=2)
q.add_atom("C", hybridization=1, heteroatoms=(1, 2))
q.add_bond(1, 2, 4)
q.add_bond(2, 3, 4)
q.add_bond(1, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 4)
p1.add_bond(2, 3, 4)

p2.add_atom("C", _map=4)
p2.add_atom("Br")
p2.add_bond(4, 5, 1)

# Knoevenagel condensation (nitryl and carboxyl case)
# [C]=[C](-[C]#[N])-[C](-[O])=[O]>>[C]=[O].[C](-[C]#[N])-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 3)
q.add_bond(2, 5, 1)
q.add_bond(5, 6, 2)
q.add_bond(5, 7, 1)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 3)
p1.add_bond(2, 5, 1)
p1.add_bond(5, 6, 2)
p1.add_bond(5, 7, 1)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=8)
p2.add_bond(1, 8, 2)

# Knoevenagel condensation (double nitryl case)
# [C]=[C](-[C]#[N])-[C]#[N]>>[C]=[O].[C](-[C]#[N])-[C]#[N]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("N")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 3)
q.add_bond(2, 5, 1)
q.add_bond(5, 6, 3)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 3)
p1.add_bond(2, 5, 1)
p1.add_bond(5, 6, 3)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=8)
p2.add_bond(1, 8, 2)

# Knoevenagel condensation (double carboxyl case)
# [C]=[C](-[C](-[O])=[O])-[C](-[O])=[O]>>[C]=[O].[C](-[C](-[O])=[O])-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 2)
q.add_bond(3, 5, 1)
q.add_bond(2, 6, 1)
q.add_bond(6, 7, 2)
q.add_bond(6, 8, 1)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 2)
p1.add_bond(3, 5, 1)
p1.add_bond(2, 6, 1)
p1.add_bond(6, 7, 2)
p1.add_bond(6, 8, 1)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=9)
p2.add_bond(1, 9, 2)

# heterocyclization with guanidine
# [c]((-[N;W0;Zs])@[n]@[c](-[N;D1])@[c;W0])@[n]@[c]-[O; D1]>>[C](-[N])(=[N])-[N].[C](#[N])-[C]-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("N", heteroatoms=0, hybridization=1)
q.add_atom("N")
q.add_atom("C")
q.add_atom("N", neighbors=1)
q.add_atom("C", heteroatoms=0)
q.add_atom("N")
q.add_atom("C")
q.add_atom("O", neighbors=1)
q.add_bond(1, 2, 1)
q.add_bond(1, 3, 4)
q.add_bond(3, 4, 4)
q.add_bond(4, 5, 1)
q.add_bond(4, 6, 4)
q.add_bond(1, 7, 4)
q.add_bond(7, 8, 4)
q.add_bond(8, 9, 1)

p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("N")
p1.add_atom("N", _map=7)
p1.add_bond(1, 2, 1)
p1.add_bond(1, 3, 2)
p1.add_bond(1, 7, 1)

p2.add_atom("C", _map=4)
p2.add_atom("N")
p2.add_atom("C")
p2.add_atom("C", _map=8)
p2.add_atom("O", _map=9)
p2.add_atom("O")
p2.add_bond(4, 5, 3)
p2.add_bond(4, 6, 1)
p2.add_bond(6, 8, 1)
p2.add_bond(8, 9, 2)
p2.add_bond(8, 10, 1)

# alkylation of amine
# [C]-[N]-[C]>>[C]-[N].[C]-[Br]
# NOTE(review): the second product actually uses Cl below, not Br as in the
# SMARTS sketch above — confirm which halide was intended.
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("C")
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 1)

p2.add_atom("C", _map=4)
p2.add_atom("Cl")
p2.add_bond(4, 5, 1)

# Synthesis of guanidines
#
q, p1, p2 = prepare()
q.add_atom("N")
q.add_atom("C")
q.add_atom("N", hybridization=1)
q.add_atom("N", hybridization=1)
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(2, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 3)
p1.add_bond(2, 3, 1)

p2.add_atom("N", _map=4)

# Grignard reaction with nitrile
#
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("C")
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 3)

p2.add_atom("C", _map=4)
p2.add_atom("Br")
p2.add_bond(4, 5, 1)

# Alkylation of alpha-carbon atom of nitrile
#
q, p1, p2 = prepare()
q.add_atom("N")
q.add_atom("C")
q.add_atom("C", neighbors=(3, 4))
q.add_atom("C", hybridization=1)
q.add_bond(1, 2, 3)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("C")
p1.add_bond(1, 2, 3)
p1.add_bond(2, 3, 1)

p2.add_atom("C", _map=4)
p2.add_atom("Cl")
p2.add_bond(4, 5, 1)

# Gomberg-Bachmann reaction
#
q, p1, p2 = prepare()
q.add_atom("C", hybridization=4, heteroatoms=0)
q.add_atom("C", hybridization=4, heteroatoms=0)
q.add_bond(1, 2, 1)

p1.add_atom("C")
p1.add_atom("N", _map=3)
p1.add_bond(1, 3, 1)

p2.add_atom("C", _map=2)

# Cyclocondensation
#
q, p1, p2 = prepare()
q.add_atom("N", neighbors=2)
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("C")
q.add_atom("O", neighbors=1)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 1)
q.add_bond(4, 5, 2)
q.add_bond(5, 6, 1)
q.add_bond(6, 7, 1)
q.add_bond(7, 8, 2)
q.add_bond(1, 7, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O", _map=9)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 1)
p1.add_bond(4, 9, 2)

p2.add_atom("N", _map=5)
p2.add_atom("C")
p2.add_atom("C")
p2.add_atom("O")
p2.add_atom("O", _map=10)
p2.add_bond(5, 6, 1)
p2.add_bond(6, 7, 1)
p2.add_bond(7, 8, 2)
p2.add_bond(7, 10, 1)

# heterocyclization dicarboxylic acids
#
q, p1, p2 = prepare()
q.add_atom("C", rings_sizes=(5, 6))
q.add_atom("O")
q.add_atom(ListElement(["O", "N"]))
q.add_atom("C", rings_sizes=(5, 6))
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(1, 3, 1)
q.add_bond(3, 4, 1)
q.add_bond(4, 5, 2)

p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O", _map=6)
p1.add_bond(1, 2, 2)
p1.add_bond(1, 6, 1)

p2.add_atom("C", _map=4)
p2.add_atom("O")
p2.add_atom("O", _map=7)
p2.add_bond(4, 5, 2)
p2.add_bond(4, 7, 1)

__all__ = ["rules"]
|
synplan/chem/reaction_rules/manual/transformations.py
ADDED
|
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing hardcoded transformation reaction rules."""
|
| 2 |
+
|
| 3 |
+
from CGRtools import QueryContainer, ReactionContainer
|
| 4 |
+
from CGRtools.periodictable import ListElement
|
| 5 |
+
|
| 6 |
+
rules = []
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def prepare():
    """Create the query containers for a new single-product transformation rule.

    Builds one reactant-side query container and one product-side query
    container, wraps them in a ReactionContainer that is appended to the
    module-level ``rules`` list, and returns the two (still empty) containers
    so the caller can populate them with atoms and bonds.

    :return: tuple ``(q_, p_)`` of empty QueryContainer objects.
    """
    q_ = QueryContainer()
    p_ = QueryContainer()
    rules.append(ReactionContainer((q_,), (p_,)))
    return q_, p_
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# aryl nitro reduction
|
| 19 |
+
# [C;Za;W1]-[N;D1]>>[O-]-[N+](-[C])=[O]
|
| 20 |
+
q, p = prepare()
|
| 21 |
+
q.add_atom("N", neighbors=1)
|
| 22 |
+
q.add_atom("C", hybridization=4, heteroatoms=1)
|
| 23 |
+
q.add_bond(1, 2, 1)
|
| 24 |
+
|
| 25 |
+
p.add_atom("N", charge=1)
|
| 26 |
+
p.add_atom("C")
|
| 27 |
+
p.add_atom("O", charge=-1)
|
| 28 |
+
p.add_atom("O")
|
| 29 |
+
p.add_bond(1, 2, 1)
|
| 30 |
+
p.add_bond(1, 3, 1)
|
| 31 |
+
p.add_bond(1, 4, 2)
|
| 32 |
+
|
| 33 |
+
# aryl nitration
|
| 34 |
+
# [O-]-[N+](=[O])-[C;Za;W12]>>[C]
|
| 35 |
+
q, p = prepare()
|
| 36 |
+
q.add_atom("N", charge=1)
|
| 37 |
+
q.add_atom("C", hybridization=4, heteroatoms=(1, 2))
|
| 38 |
+
q.add_atom("O", charge=-1)
|
| 39 |
+
q.add_atom("O")
|
| 40 |
+
q.add_bond(1, 2, 1)
|
| 41 |
+
q.add_bond(1, 3, 1)
|
| 42 |
+
q.add_bond(1, 4, 2)
|
| 43 |
+
|
| 44 |
+
p.add_atom("C", _map=2)
|
| 45 |
+
|
| 46 |
+
# Beckmann rearrangement (oxime -> amide)
|
| 47 |
+
# [C]-[N;D2]-[C]=[O]>>[O]-[N]=[C]-[C]
|
| 48 |
+
q, p = prepare()
|
| 49 |
+
q.add_atom("C")
|
| 50 |
+
q.add_atom("N", neighbors=2)
|
| 51 |
+
q.add_atom("O")
|
| 52 |
+
q.add_atom("C")
|
| 53 |
+
q.add_bond(1, 2, 1)
|
| 54 |
+
q.add_bond(1, 3, 2)
|
| 55 |
+
q.add_bond(2, 4, 1)
|
| 56 |
+
|
| 57 |
+
p.add_atom("C")
|
| 58 |
+
p.add_atom("N")
|
| 59 |
+
p.add_atom("O")
|
| 60 |
+
p.add_atom("C")
|
| 61 |
+
p.add_bond(1, 2, 2)
|
| 62 |
+
p.add_bond(2, 3, 1)
|
| 63 |
+
p.add_bond(1, 4, 1)
|
| 64 |
+
|
| 65 |
+
# aldehydes or ketones into oxime/imine reaction
|
| 66 |
+
# [C;Zd;W1]=[N]>>[C]=[O]
|
| 67 |
+
q, p = prepare()
|
| 68 |
+
q.add_atom("C", hybridization=2, heteroatoms=1)
|
| 69 |
+
q.add_atom("N")
|
| 70 |
+
q.add_bond(1, 2, 2)
|
| 71 |
+
|
| 72 |
+
p.add_atom("C")
|
| 73 |
+
p.add_atom("O", _map=3)
|
| 74 |
+
p.add_bond(1, 3, 2)
|
| 75 |
+
|
| 76 |
+
# addition of halogen atom into phenol ring (orto)
|
| 77 |
+
# [C](-[Cl,F,Br,I;D1]):[C]-[O,N;Zs]>>[C](-[A]):[C]
|
| 78 |
+
q, p = prepare()
|
| 79 |
+
q.add_atom(ListElement(["O", "N"]), hybridization=1)
|
| 80 |
+
q.add_atom("C")
|
| 81 |
+
q.add_atom("C")
|
| 82 |
+
q.add_atom(ListElement(["Cl", "F", "Br", "I"]), neighbors=1)
|
| 83 |
+
q.add_bond(1, 2, 1)
|
| 84 |
+
q.add_bond(2, 3, 4)
|
| 85 |
+
q.add_bond(3, 4, 1)
|
| 86 |
+
|
| 87 |
+
p.add_atom("A")
|
| 88 |
+
p.add_atom("C")
|
| 89 |
+
p.add_atom("C")
|
| 90 |
+
p.add_bond(1, 2, 1)
|
| 91 |
+
p.add_bond(2, 3, 4)
|
| 92 |
+
|
| 93 |
+
# addition of halogen atom into phenol ring (para)
|
| 94 |
+
# [C](:[C]:[C]:[C]-[O,N;Zs])-[Cl,F,Br,I;D1]>>[A]-[C]:[C]:[C]:[C]
|
| 95 |
+
q, p = prepare()
|
| 96 |
+
q.add_atom(ListElement(["O", "N"]), hybridization=1)
|
| 97 |
+
q.add_atom("C")
|
| 98 |
+
q.add_atom("C")
|
| 99 |
+
q.add_atom("C")
|
| 100 |
+
q.add_atom("C")
|
| 101 |
+
q.add_atom(ListElement(["Cl", "F", "Br", "I"]), neighbors=1)
|
| 102 |
+
q.add_bond(1, 2, 1)
|
| 103 |
+
q.add_bond(2, 3, 4)
|
| 104 |
+
q.add_bond(3, 4, 4)
|
| 105 |
+
q.add_bond(4, 5, 4)
|
| 106 |
+
q.add_bond(5, 6, 1)
|
| 107 |
+
|
| 108 |
+
p.add_atom("A")
|
| 109 |
+
p.add_atom("C")
|
| 110 |
+
p.add_atom("C")
|
| 111 |
+
p.add_atom("C")
|
| 112 |
+
p.add_atom("C")
|
| 113 |
+
p.add_bond(1, 2, 1)
|
| 114 |
+
p.add_bond(2, 3, 4)
|
| 115 |
+
p.add_bond(3, 4, 4)
|
| 116 |
+
p.add_bond(4, 5, 4)
|
| 117 |
+
|
| 118 |
+
# hard reduction of Ar-ketones
|
| 119 |
+
# [C;Za]-[C;D2;Zs;W0]>>[C]-[C]=[O]
|
| 120 |
+
q, p = prepare()
|
| 121 |
+
q.add_atom("C", hybridization=4)
|
| 122 |
+
q.add_atom("C", hybridization=1, neighbors=2, heteroatoms=0)
|
| 123 |
+
q.add_bond(1, 2, 1)
|
| 124 |
+
|
| 125 |
+
p.add_atom("C")
|
| 126 |
+
p.add_atom("C")
|
| 127 |
+
p.add_atom("O")
|
| 128 |
+
p.add_bond(1, 2, 1)
|
| 129 |
+
p.add_bond(2, 3, 2)
|
| 130 |
+
|
| 131 |
+
# reduction of alpha-hydroxy pyridine
|
| 132 |
+
# [C;W1]:[N;H0;r6]>>[C](:[N])-[O]
|
| 133 |
+
q, p = prepare()
|
| 134 |
+
q.add_atom("C", heteroatoms=1)
|
| 135 |
+
q.add_atom("N", rings_sizes=6, hydrogens=0)
|
| 136 |
+
q.add_bond(1, 2, 4)
|
| 137 |
+
|
| 138 |
+
p.add_atom("C")
|
| 139 |
+
p.add_atom("N")
|
| 140 |
+
p.add_atom("O")
|
| 141 |
+
p.add_bond(1, 2, 4)
|
| 142 |
+
p.add_bond(1, 3, 1)
|
| 143 |
+
|
| 144 |
+
# Reduction of alkene
|
| 145 |
+
# [C]-[C;D23;Zs;W0]-[C;D123;Zs;W0]>>[C](-[C])=[C]
|
| 146 |
+
q, p = prepare()
|
| 147 |
+
q.add_atom("C")
|
| 148 |
+
q.add_atom("C", heteroatoms=0, neighbors=(2, 3), hybridization=1)
|
| 149 |
+
q.add_atom("C", heteroatoms=0, neighbors=(1, 2, 3), hybridization=1)
|
| 150 |
+
q.add_bond(1, 2, 1)
|
| 151 |
+
q.add_bond(2, 3, 1)
|
| 152 |
+
|
| 153 |
+
p.add_atom("C")
|
| 154 |
+
p.add_atom("C")
|
| 155 |
+
p.add_atom("C")
|
| 156 |
+
p.add_bond(1, 2, 1)
|
| 157 |
+
p.add_bond(2, 3, 2)
|
| 158 |
+
|
| 159 |
+
# Kolbe-Schmitt reaction
|
| 160 |
+
# [C](:[C]-[O;D1])-[C](=[O])-[O;D1]>>[C](-[O]):[C]
|
| 161 |
+
q, p = prepare()
|
| 162 |
+
q.add_atom("O", neighbors=1)
|
| 163 |
+
q.add_atom("C")
|
| 164 |
+
q.add_atom("C")
|
| 165 |
+
q.add_atom("C")
|
| 166 |
+
q.add_atom("O", neighbors=1)
|
| 167 |
+
q.add_atom("O")
|
| 168 |
+
q.add_bond(1, 2, 1)
|
| 169 |
+
q.add_bond(2, 3, 4)
|
| 170 |
+
q.add_bond(3, 4, 1)
|
| 171 |
+
q.add_bond(4, 5, 1)
|
| 172 |
+
q.add_bond(4, 6, 2)
|
| 173 |
+
|
| 174 |
+
p.add_atom("O")
|
| 175 |
+
p.add_atom("C")
|
| 176 |
+
p.add_atom("C")
|
| 177 |
+
p.add_bond(1, 2, 1)
|
| 178 |
+
p.add_bond(2, 3, 4)
|
| 179 |
+
|
| 180 |
+
# reduction of carboxylic acid
|
| 181 |
+
# [O;D1]-[C;D2]-[C]>>[C]-[C](-[O])=[O]
|
| 182 |
+
q, p = prepare()
|
| 183 |
+
q.add_atom("C")
|
| 184 |
+
q.add_atom("C", neighbors=2)
|
| 185 |
+
q.add_atom("O", neighbors=1)
|
| 186 |
+
q.add_bond(1, 2, 1)
|
| 187 |
+
q.add_bond(2, 3, 1)
|
| 188 |
+
|
| 189 |
+
p.add_atom("C")
|
| 190 |
+
p.add_atom("C")
|
| 191 |
+
p.add_atom("O")
|
| 192 |
+
p.add_atom("O")
|
| 193 |
+
p.add_bond(1, 2, 1)
|
| 194 |
+
p.add_bond(2, 3, 1)
|
| 195 |
+
p.add_bond(2, 4, 2)
|
| 196 |
+
|
| 197 |
+
# halogenation of alcohols
|
| 198 |
+
# [C;Zs]-[Cl,Br;D1]>>[C]-[O]
|
| 199 |
+
q, p = prepare()
|
| 200 |
+
q.add_atom("C", hybridization=1, heteroatoms=1)
|
| 201 |
+
q.add_atom(ListElement(["Cl", "Br"]), neighbors=1)
|
| 202 |
+
q.add_bond(1, 2, 1)
|
| 203 |
+
|
| 204 |
+
p.add_atom("C")
|
| 205 |
+
p.add_atom("O", _map=3)
|
| 206 |
+
p.add_bond(1, 3, 1)
|
| 207 |
+
|
| 208 |
+
# Kolbe nitrilation
|
| 209 |
+
# [N]#[C]-[C;Zs;W0]>>[Br]-[C]
|
| 210 |
+
q, p = prepare()
|
| 211 |
+
q.add_atom("C", heteroatoms=0, hybridization=1)
|
| 212 |
+
q.add_atom("C")
|
| 213 |
+
q.add_atom("N")
|
| 214 |
+
q.add_bond(1, 2, 1)
|
| 215 |
+
q.add_bond(2, 3, 3)
|
| 216 |
+
|
| 217 |
+
p.add_atom("C")
|
| 218 |
+
p.add_atom("Br", _map=4)
|
| 219 |
+
p.add_bond(1, 4, 1)
|
| 220 |
+
|
| 221 |
+
# Nitrile hydrolysis
|
| 222 |
+
# [O;D1]-[C]=[O]>>[N]#[C]
|
| 223 |
+
q, p = prepare()
|
| 224 |
+
q.add_atom("C")
|
| 225 |
+
q.add_atom("O", neighbors=1)
|
| 226 |
+
q.add_atom("O")
|
| 227 |
+
q.add_bond(1, 2, 1)
|
| 228 |
+
q.add_bond(1, 3, 2)
|
| 229 |
+
|
| 230 |
+
p.add_atom("C")
|
| 231 |
+
p.add_atom("N", _map=4)
|
| 232 |
+
p.add_bond(1, 4, 3)
|
| 233 |
+
|
| 234 |
+
# sulfamidation
|
| 235 |
+
# [c]-[S](=[O])(=[O])-[N]>>[c]
|
| 236 |
+
q, p = prepare()
|
| 237 |
+
q.add_atom("C", hybridization=4)
|
| 238 |
+
q.add_atom("S")
|
| 239 |
+
q.add_atom("O")
|
| 240 |
+
q.add_atom("O")
|
| 241 |
+
q.add_atom("N", neighbors=1)
|
| 242 |
+
q.add_bond(1, 2, 1)
|
| 243 |
+
q.add_bond(2, 3, 2)
|
| 244 |
+
q.add_bond(2, 4, 2)
|
| 245 |
+
q.add_bond(2, 5, 1)
|
| 246 |
+
|
| 247 |
+
p.add_atom("C")
|
| 248 |
+
|
| 249 |
+
# Ring expansion rearrangement
|
| 250 |
+
#
|
| 251 |
+
q, p = prepare()
|
| 252 |
+
q.add_atom("C")
|
| 253 |
+
q.add_atom("N")
|
| 254 |
+
q.add_atom("C", rings_sizes=6)
|
| 255 |
+
q.add_atom("C")
|
| 256 |
+
q.add_atom("O")
|
| 257 |
+
q.add_atom("C")
|
| 258 |
+
q.add_atom("C")
|
| 259 |
+
q.add_bond(1, 2, 1)
|
| 260 |
+
q.add_bond(2, 3, 1)
|
| 261 |
+
q.add_bond(3, 4, 1)
|
| 262 |
+
q.add_bond(4, 5, 2)
|
| 263 |
+
q.add_bond(3, 6, 1)
|
| 264 |
+
q.add_bond(4, 7, 1)
|
| 265 |
+
|
| 266 |
+
p.add_atom("C")
|
| 267 |
+
p.add_atom("N")
|
| 268 |
+
p.add_atom("C")
|
| 269 |
+
p.add_atom("C")
|
| 270 |
+
p.add_atom("O")
|
| 271 |
+
p.add_atom("C")
|
| 272 |
+
p.add_atom("C")
|
| 273 |
+
p.add_bond(1, 2, 1)
|
| 274 |
+
p.add_bond(2, 3, 2)
|
| 275 |
+
p.add_bond(3, 4, 1)
|
| 276 |
+
p.add_bond(4, 5, 1)
|
| 277 |
+
p.add_bond(4, 6, 1)
|
| 278 |
+
p.add_bond(4, 7, 1)
|
| 279 |
+
|
| 280 |
+
# hydrolysis of bromide alkyl
|
| 281 |
+
#
|
| 282 |
+
q, p = prepare()
|
| 283 |
+
q.add_atom("C", hybridization=1)
|
| 284 |
+
q.add_atom("O", neighbors=1)
|
| 285 |
+
q.add_bond(1, 2, 1)
|
| 286 |
+
|
| 287 |
+
p.add_atom("C")
|
| 288 |
+
p.add_atom("Br")
|
| 289 |
+
p.add_bond(1, 2, 1)
|
| 290 |
+
|
| 291 |
+
# Condensation of ketones/aldehydes and amines into imines
|
| 292 |
+
#
|
| 293 |
+
q, p = prepare()
|
| 294 |
+
q.add_atom("N", neighbors=(1, 2))
|
| 295 |
+
q.add_atom("C", neighbors=(2, 3), heteroatoms=1)
|
| 296 |
+
q.add_bond(1, 2, 2)
|
| 297 |
+
|
| 298 |
+
p.add_atom("C", _map=2)
|
| 299 |
+
p.add_atom("O")
|
| 300 |
+
p.add_bond(2, 3, 2)
|
| 301 |
+
|
| 302 |
+
# Halogenation of alkanes
|
| 303 |
+
#
|
| 304 |
+
q, p = prepare()
|
| 305 |
+
q.add_atom("C", hybridization=1)
|
| 306 |
+
q.add_atom(ListElement(["F", "Cl", "Br"]))
|
| 307 |
+
q.add_bond(1, 2, 1)
|
| 308 |
+
|
| 309 |
+
p.add_atom("C")
|
| 310 |
+
|
| 311 |
+
# heterocyclization
|
| 312 |
+
#
|
| 313 |
+
q, p = prepare()
|
| 314 |
+
q.add_atom("N", heteroatoms=0, hybridization=1, neighbors=(2, 3))
|
| 315 |
+
q.add_atom("C", heteroatoms=2)
|
| 316 |
+
q.add_atom("N", heteroatoms=0, neighbors=2)
|
| 317 |
+
q.add_bond(1, 2, 1)
|
| 318 |
+
q.add_bond(2, 3, 2)
|
| 319 |
+
|
| 320 |
+
p.add_atom("N")
|
| 321 |
+
p.add_atom("C")
|
| 322 |
+
p.add_atom("N")
|
| 323 |
+
p.add_atom("O")
|
| 324 |
+
p.add_bond(1, 2, 1)
|
| 325 |
+
p.add_bond(2, 4, 2)
|
| 326 |
+
|
| 327 |
+
# Reduction of nitrile
|
| 328 |
+
#
|
| 329 |
+
q, p = prepare()
|
| 330 |
+
q.add_atom("N", neighbors=1)
|
| 331 |
+
q.add_atom("C")
|
| 332 |
+
q.add_atom("C", hybridization=1)
|
| 333 |
+
q.add_bond(1, 2, 1)
|
| 334 |
+
q.add_bond(2, 3, 1)
|
| 335 |
+
|
| 336 |
+
p.add_atom("N")
|
| 337 |
+
p.add_atom("C")
|
| 338 |
+
p.add_atom("C")
|
| 339 |
+
p.add_bond(1, 2, 3)
|
| 340 |
+
p.add_bond(2, 3, 1)
|
| 341 |
+
|
| 342 |
+
# SPECIAL CASE
|
| 343 |
+
# Reduction of nitrile into methylamine
|
| 344 |
+
#
|
| 345 |
+
q, p = prepare()
|
| 346 |
+
q.add_atom("C", neighbors=1)
|
| 347 |
+
q.add_atom("N", neighbors=2)
|
| 348 |
+
q.add_atom("C")
|
| 349 |
+
q.add_atom("C", hybridization=1)
|
| 350 |
+
q.add_bond(1, 2, 1)
|
| 351 |
+
q.add_bond(2, 3, 1)
|
| 352 |
+
q.add_bond(3, 4, 1)
|
| 353 |
+
|
| 354 |
+
p.add_atom("N", _map=2)
|
| 355 |
+
p.add_atom("C")
|
| 356 |
+
p.add_atom("C")
|
| 357 |
+
p.add_bond(2, 3, 3)
|
| 358 |
+
p.add_bond(3, 4, 1)
|
| 359 |
+
|
| 360 |
+
# methylation of amides
|
| 361 |
+
#
|
| 362 |
+
q, p = prepare()
|
| 363 |
+
q.add_atom("O")
|
| 364 |
+
q.add_atom("C")
|
| 365 |
+
q.add_atom("N")
|
| 366 |
+
q.add_atom("C", neighbors=1)
|
| 367 |
+
q.add_bond(1, 2, 2)
|
| 368 |
+
q.add_bond(2, 3, 1)
|
| 369 |
+
q.add_bond(3, 4, 1)
|
| 370 |
+
|
| 371 |
+
p.add_atom("O")
|
| 372 |
+
p.add_atom("C")
|
| 373 |
+
p.add_atom("N")
|
| 374 |
+
p.add_bond(1, 2, 2)
|
| 375 |
+
p.add_bond(2, 3, 1)
|
| 376 |
+
|
| 377 |
+
# hydrocyanation of alkenes
|
| 378 |
+
#
|
| 379 |
+
q, p = prepare()
|
| 380 |
+
q.add_atom("C", hybridization=1)
|
| 381 |
+
q.add_atom("C")
|
| 382 |
+
q.add_atom("C")
|
| 383 |
+
q.add_atom("N")
|
| 384 |
+
q.add_bond(1, 2, 1)
|
| 385 |
+
q.add_bond(2, 3, 1)
|
| 386 |
+
q.add_bond(3, 4, 3)
|
| 387 |
+
|
| 388 |
+
p.add_atom("C")
|
| 389 |
+
p.add_atom("C")
|
| 390 |
+
p.add_bond(1, 2, 2)
|
| 391 |
+
|
| 392 |
+
# decarbocylation (alpha atom of nitrile)
|
| 393 |
+
#
|
| 394 |
+
q, p = prepare()
|
| 395 |
+
q.add_atom("N")
|
| 396 |
+
q.add_atom("C")
|
| 397 |
+
q.add_atom("C", neighbors=2)
|
| 398 |
+
q.add_bond(1, 2, 3)
|
| 399 |
+
q.add_bond(2, 3, 1)
|
| 400 |
+
|
| 401 |
+
p.add_atom("N")
|
| 402 |
+
p.add_atom("C")
|
| 403 |
+
p.add_atom("C")
|
| 404 |
+
p.add_atom("C")
|
| 405 |
+
p.add_atom("O")
|
| 406 |
+
p.add_atom("O")
|
| 407 |
+
p.add_bond(1, 2, 3)
|
| 408 |
+
p.add_bond(2, 3, 1)
|
| 409 |
+
p.add_bond(3, 4, 1)
|
| 410 |
+
p.add_bond(4, 5, 2)
|
| 411 |
+
p.add_bond(4, 6, 1)
|
| 412 |
+
|
| 413 |
+
# Bichler-Napieralski reaction
|
| 414 |
+
#
|
| 415 |
+
q, p = prepare()
|
| 416 |
+
q.add_atom("C", rings_sizes=(6,))
|
| 417 |
+
q.add_atom("C", rings_sizes=(6,))
|
| 418 |
+
q.add_atom("N", rings_sizes=(6,), neighbors=2)
|
| 419 |
+
q.add_atom("C")
|
| 420 |
+
q.add_atom("C")
|
| 421 |
+
q.add_atom("C")
|
| 422 |
+
q.add_atom("O")
|
| 423 |
+
q.add_atom("O")
|
| 424 |
+
q.add_atom("C")
|
| 425 |
+
q.add_atom("O", neighbors=1)
|
| 426 |
+
q.add_bond(1, 2, 4)
|
| 427 |
+
q.add_bond(2, 3, 1)
|
| 428 |
+
q.add_bond(3, 4, 1)
|
| 429 |
+
q.add_bond(4, 5, 2)
|
| 430 |
+
q.add_bond(5, 6, 1)
|
| 431 |
+
q.add_bond(6, 7, 2)
|
| 432 |
+
q.add_bond(6, 8, 1)
|
| 433 |
+
q.add_bond(5, 9, 4)
|
| 434 |
+
q.add_bond(9, 10, 1)
|
| 435 |
+
q.add_bond(1, 9, 1)
|
| 436 |
+
|
| 437 |
+
p.add_atom("C")
|
| 438 |
+
p.add_atom("C")
|
| 439 |
+
p.add_atom("N")
|
| 440 |
+
p.add_atom("C")
|
| 441 |
+
p.add_atom("C")
|
| 442 |
+
p.add_atom("C")
|
| 443 |
+
p.add_atom("O")
|
| 444 |
+
p.add_atom("O")
|
| 445 |
+
p.add_atom("C")
|
| 446 |
+
p.add_atom("O")
|
| 447 |
+
p.add_atom("O")
|
| 448 |
+
p.add_bond(1, 2, 4)
|
| 449 |
+
p.add_bond(2, 3, 1)
|
| 450 |
+
p.add_bond(3, 4, 1)
|
| 451 |
+
p.add_bond(4, 5, 2)
|
| 452 |
+
p.add_bond(5, 6, 1)
|
| 453 |
+
p.add_bond(6, 7, 2)
|
| 454 |
+
p.add_bond(6, 8, 1)
|
| 455 |
+
p.add_bond(5, 9, 1)
|
| 456 |
+
p.add_bond(9, 10, 2)
|
| 457 |
+
p.add_bond(9, 11, 1)
|
| 458 |
+
|
| 459 |
+
# heterocyclization in Prins reaction
|
| 460 |
+
#
|
| 461 |
+
q, p = prepare()
|
| 462 |
+
q.add_atom("C")
|
| 463 |
+
q.add_atom("O")
|
| 464 |
+
q.add_atom("C")
|
| 465 |
+
q.add_atom(ListElement(["N", "O"]), neighbors=2)
|
| 466 |
+
q.add_atom("C")
|
| 467 |
+
q.add_atom("C")
|
| 468 |
+
q.add_bond(1, 2, 1)
|
| 469 |
+
q.add_bond(2, 3, 1)
|
| 470 |
+
q.add_bond(3, 4, 1)
|
| 471 |
+
q.add_bond(4, 5, 1)
|
| 472 |
+
q.add_bond(5, 6, 1)
|
| 473 |
+
q.add_bond(1, 6, 1)
|
| 474 |
+
|
| 475 |
+
p.add_atom("C")
|
| 476 |
+
p.add_atom("C", _map=5)
|
| 477 |
+
p.add_bond(1, 5, 2)
|
| 478 |
+
|
| 479 |
+
# recyclization of tetrahydropyran through an opening the ring and dehydration
|
| 480 |
+
#
|
| 481 |
+
q, p = prepare()
|
| 482 |
+
q.add_atom("C")
|
| 483 |
+
q.add_atom("C")
|
| 484 |
+
q.add_atom("C")
|
| 485 |
+
q.add_atom(ListElement(["N", "O"]))
|
| 486 |
+
q.add_atom("C")
|
| 487 |
+
q.add_atom("C")
|
| 488 |
+
q.add_bond(1, 2, 1)
|
| 489 |
+
q.add_bond(2, 3, 1)
|
| 490 |
+
q.add_bond(3, 4, 1)
|
| 491 |
+
q.add_bond(4, 5, 1)
|
| 492 |
+
q.add_bond(5, 6, 1)
|
| 493 |
+
q.add_bond(1, 6, 2)
|
| 494 |
+
|
| 495 |
+
p.add_atom("C")
|
| 496 |
+
p.add_atom("C")
|
| 497 |
+
p.add_atom("C")
|
| 498 |
+
p.add_atom("A")
|
| 499 |
+
p.add_atom("C")
|
| 500 |
+
p.add_atom("C")
|
| 501 |
+
p.add_atom("O")
|
| 502 |
+
p.add_bond(1, 2, 1)
|
| 503 |
+
p.add_bond(1, 7, 1)
|
| 504 |
+
p.add_bond(3, 7, 1)
|
| 505 |
+
p.add_bond(3, 4, 1)
|
| 506 |
+
p.add_bond(4, 5, 1)
|
| 507 |
+
p.add_bond(5, 6, 1)
|
| 508 |
+
p.add_bond(1, 6, 1)
|
| 509 |
+
|
| 510 |
+
# alkenes + h2o/hHal
|
| 511 |
+
#
|
| 512 |
+
q, p = prepare()
|
| 513 |
+
q.add_atom("C", hybridization=1)
|
| 514 |
+
q.add_atom("C", hybridization=1)
|
| 515 |
+
q.add_atom(ListElement(["O", "F", "Cl", "Br", "I"]), neighbors=1)
|
| 516 |
+
q.add_bond(1, 2, 1)
|
| 517 |
+
q.add_bond(2, 3, 1)
|
| 518 |
+
|
| 519 |
+
p.add_atom("C")
|
| 520 |
+
p.add_atom("C")
|
| 521 |
+
p.add_bond(1, 2, 2)
|
| 522 |
+
|
| 523 |
+
# methylation of dimethylamines
|
| 524 |
+
#
|
| 525 |
+
q, p = prepare()
|
| 526 |
+
q.add_atom("C", neighbors=1)
|
| 527 |
+
q.add_atom("N", neighbors=3)
|
| 528 |
+
q.add_bond(1, 2, 1)
|
| 529 |
+
|
| 530 |
+
p.add_atom("N", _map=2)
|
| 531 |
+
|
| 532 |
+
__all__ = ["rules"]
|
synplan/chem/utils.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing additional functions needed in different reaction data processing
|
| 2 |
+
protocols."""
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Iterable
|
| 6 |
+
|
| 7 |
+
from CGRtools.containers import (
|
| 8 |
+
CGRContainer,
|
| 9 |
+
MoleculeContainer,
|
| 10 |
+
QueryContainer,
|
| 11 |
+
ReactionContainer,
|
| 12 |
+
)
|
| 13 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from synplan.chem import smiles_parser
|
| 17 |
+
from synplan.utils.files import MoleculeReader, MoleculeWriter
|
| 18 |
+
|
| 19 |
+
from chython import MoleculeContainer as MoleculeContainerChython
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def mol_from_smiles(
|
| 23 |
+
smiles: str,
|
| 24 |
+
standardize: bool = True,
|
| 25 |
+
clean_stereo: bool = True,
|
| 26 |
+
clean2d: bool = True,
|
| 27 |
+
) -> MoleculeContainer:
|
| 28 |
+
"""Converts a SMILES string to a `MoleculeContainer` object and optionally
|
| 29 |
+
standardizes, cleans stereochemistry, and cleans 2D coordinates.
|
| 30 |
+
|
| 31 |
+
:param smiles: The SMILES string representing the molecule.
|
| 32 |
+
:param standardize: Whether to standardize the molecule (default is True).
|
| 33 |
+
:param clean_stereo: Whether to remove the stereo marks on atoms of the molecule (default is True).
|
| 34 |
+
:param clean2d: Whether to clean the 2D coordinates of the molecule (default is True).
|
| 35 |
+
:return: The processed molecule object.
|
| 36 |
+
:raises ValueError: If the SMILES string could not be processed by CGRtools.
|
| 37 |
+
"""
|
| 38 |
+
molecule = smiles_parser(smiles)
|
| 39 |
+
|
| 40 |
+
if not isinstance(molecule, MoleculeContainer):
|
| 41 |
+
raise ValueError("SMILES string was not processed by CGRtools")
|
| 42 |
+
|
| 43 |
+
tmp = molecule.copy()
|
| 44 |
+
try:
|
| 45 |
+
if standardize:
|
| 46 |
+
tmp.canonicalize()
|
| 47 |
+
if clean_stereo:
|
| 48 |
+
tmp.clean_stereo()
|
| 49 |
+
if clean2d:
|
| 50 |
+
tmp.clean2d()
|
| 51 |
+
molecule = tmp
|
| 52 |
+
except InvalidAromaticRing:
|
| 53 |
+
logging.warning(
|
| 54 |
+
"CGRtools was not able to standardize molecule due to invalid aromatic ring"
|
| 55 |
+
)
|
| 56 |
+
return molecule
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def query_to_mol(query: QueryContainer) -> MoleculeContainer:
|
| 60 |
+
"""Converts a QueryContainer object into a MoleculeContainer object.
|
| 61 |
+
|
| 62 |
+
:param query: A QueryContainer object representing the query structure.
|
| 63 |
+
:return: A MoleculeContainer object that replicates the structure of the query.
|
| 64 |
+
"""
|
| 65 |
+
new_mol = MoleculeContainer()
|
| 66 |
+
for n, atom in query.atoms():
|
| 67 |
+
new_mol.add_atom(
|
| 68 |
+
atom.atomic_symbol, n, charge=atom.charge, is_radical=atom.is_radical
|
| 69 |
+
)
|
| 70 |
+
for i, j, bond in query.bonds():
|
| 71 |
+
new_mol.add_bond(i, j, int(bond))
|
| 72 |
+
return new_mol
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def reaction_query_to_reaction(reaction_rule: ReactionContainer) -> ReactionContainer:
|
| 76 |
+
"""Converts a ReactionContainer object with query structures into a
|
| 77 |
+
ReactionContainer with molecular structures.
|
| 78 |
+
|
| 79 |
+
:param reaction_rule: A ReactionContainer object where reactants and products are
|
| 80 |
+
QueryContainer objects.
|
| 81 |
+
:return: A new ReactionContainer object where reactants and products are
|
| 82 |
+
MoleculeContainer objects.
|
| 83 |
+
"""
|
| 84 |
+
reactants = [query_to_mol(q) for q in reaction_rule.reactants]
|
| 85 |
+
products = [query_to_mol(q) for q in reaction_rule.products]
|
| 86 |
+
reagents = [
|
| 87 |
+
query_to_mol(q) for q in reaction_rule.reagents
|
| 88 |
+
] # Assuming reagents are also part of the rule
|
| 89 |
+
reaction = ReactionContainer(reactants, products, reagents, reaction_rule.meta)
|
| 90 |
+
reaction.name = reaction_rule.name
|
| 91 |
+
return reaction
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def unite_molecules(molecules: Iterable[MoleculeContainer]) -> MoleculeContainer:
|
| 95 |
+
"""Unites a list of MoleculeContainer objects into a single MoleculeContainer. This
|
| 96 |
+
function takes multiple molecules and combines them into one larger molecule. The
|
| 97 |
+
first molecule in the list is taken as the base, and subsequent molecules are united
|
| 98 |
+
with it sequentially.
|
| 99 |
+
|
| 100 |
+
:param molecules: A list of MoleculeContainer objects to be united.
|
| 101 |
+
:return: A single MoleculeContainer object representing the union of all input
|
| 102 |
+
molecules.
|
| 103 |
+
"""
|
| 104 |
+
new_mol = MoleculeContainer()
|
| 105 |
+
for mol in molecules:
|
| 106 |
+
new_mol = new_mol.union(mol)
|
| 107 |
+
return new_mol
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def safe_canonicalization(molecule: MoleculeContainer) -> MoleculeContainer:
|
| 111 |
+
"""Attempts to canonicalize a molecule, handling any exceptions. If the
|
| 112 |
+
canonicalization process fails due to an InvalidAromaticRing exception, it safely
|
| 113 |
+
returns the original molecule.
|
| 114 |
+
|
| 115 |
+
:param molecule: The given molecule to be canonicalized.
|
| 116 |
+
:return: The canonicalized molecule if successful, otherwise the original molecule.
|
| 117 |
+
"""
|
| 118 |
+
molecule._atoms = dict(sorted(molecule._atoms.items()))
|
| 119 |
+
|
| 120 |
+
molecule_copy = molecule.copy()
|
| 121 |
+
try:
|
| 122 |
+
molecule_copy.canonicalize()
|
| 123 |
+
molecule_copy.clean_stereo()
|
| 124 |
+
return molecule_copy
|
| 125 |
+
except InvalidAromaticRing:
|
| 126 |
+
return molecule
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def standardize_building_blocks(input_file: str, output_file: str) -> str:
|
| 130 |
+
"""Standardizes custom building blocks.
|
| 131 |
+
|
| 132 |
+
:param input_file: The path to the file that stores the original building blocks.
|
| 133 |
+
:param output_file: The path to the file that will store the standardized building
|
| 134 |
+
blocks.
|
| 135 |
+
:return: The path to the file with standardized building blocks.
|
| 136 |
+
"""
|
| 137 |
+
if input_file == output_file:
|
| 138 |
+
raise ValueError("input_file name and output_file name cannot be the same.")
|
| 139 |
+
|
| 140 |
+
with MoleculeReader(input_file) as inp_file, MoleculeWriter(
|
| 141 |
+
output_file
|
| 142 |
+
) as out_file:
|
| 143 |
+
for mol in tqdm(
|
| 144 |
+
inp_file,
|
| 145 |
+
desc="Number of building blocks processed: ",
|
| 146 |
+
bar_format="{desc}{n} [{elapsed}]",
|
| 147 |
+
):
|
| 148 |
+
try:
|
| 149 |
+
mol = safe_canonicalization(mol)
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logging.debug(e)
|
| 152 |
+
continue
|
| 153 |
+
out_file.write(mol)
|
| 154 |
+
|
| 155 |
+
return output_file
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def cgr_from_reaction_rule(reaction_rule: ReactionContainer) -> CGRContainer:
|
| 159 |
+
"""Creates a CGR from the given reaction rule.
|
| 160 |
+
|
| 161 |
+
:param reaction_rule: The reaction rule to be converted.
|
| 162 |
+
:return: The resulting CGR.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
reaction_rule = reaction_query_to_reaction(reaction_rule)
|
| 166 |
+
cgr_rule = ~reaction_rule
|
| 167 |
+
|
| 168 |
+
return cgr_rule
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def hash_from_reaction_rule(reaction_rule: ReactionContainer) -> hash:
|
| 172 |
+
"""Generates hash for the given reaction rule.
|
| 173 |
+
|
| 174 |
+
:param reaction_rule: The reaction rule to be converted.
|
| 175 |
+
:return: The resulting hash.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
reactants_hash = tuple(sorted(hash(r) for r in reaction_rule.reactants))
|
| 179 |
+
reagents_hash = tuple(sorted(hash(r) for r in reaction_rule.reagents))
|
| 180 |
+
products_hash = tuple(sorted(hash(r) for r in reaction_rule.products))
|
| 181 |
+
|
| 182 |
+
return hash((reactants_hash, reagents_hash, products_hash))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def reverse_reaction(
|
| 186 |
+
reaction: ReactionContainer,
|
| 187 |
+
) -> ReactionContainer:
|
| 188 |
+
"""Reverses the given reaction.
|
| 189 |
+
|
| 190 |
+
:param reaction: The reaction to be reversed.
|
| 191 |
+
:return: The reversed reaction.
|
| 192 |
+
"""
|
| 193 |
+
reversed_reaction = ReactionContainer(
|
| 194 |
+
reaction.products, reaction.reactants, reaction.reagents, reaction.meta
|
| 195 |
+
)
|
| 196 |
+
reversed_reaction.name = reaction.name
|
| 197 |
+
|
| 198 |
+
return reversed_reaction
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def cgrtools_to_chython_molecule(molecule):
|
| 202 |
+
molecule_chython = MoleculeContainerChython()
|
| 203 |
+
for n, atom in molecule.atoms():
|
| 204 |
+
molecule_chython.add_atom(atom.atomic_symbol, n)
|
| 205 |
+
|
| 206 |
+
for n, m, bond in molecule.bonds():
|
| 207 |
+
molecule_chython.add_bond(n, m, int(bond))
|
| 208 |
+
|
| 209 |
+
return molecule_chython
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def chython_query_to_cgrtools(query):
|
| 213 |
+
cgrtools_query = QueryContainer()
|
| 214 |
+
for n, atom in query.atoms():
|
| 215 |
+
cgrtools_query.add_atom(
|
| 216 |
+
atom=atom.atomic_symbol,
|
| 217 |
+
charge=atom.charge,
|
| 218 |
+
neighbors=atom.neighbors,
|
| 219 |
+
hybridization=atom.hybridization,
|
| 220 |
+
_map=n,
|
| 221 |
+
)
|
| 222 |
+
for n, m, bond in query.bonds():
|
| 223 |
+
cgrtools_query.add_bond(n, m, int(bond))
|
| 224 |
+
|
| 225 |
+
return cgrtools_query
|
synplan/interfaces/__init__.py
ADDED
|
File without changes
|
synplan/interfaces/cli.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing commands line scripts for training and planning steps."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import click
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from synplan.chem.data.filtering import ReactionFilterConfig, filter_reactions_from_file
|
| 11 |
+
from synplan.chem.data.standardizing import (
|
| 12 |
+
ReactionStandardizationConfig,
|
| 13 |
+
standardize_reactions_from_file,
|
| 14 |
+
)
|
| 15 |
+
from synplan.chem.reaction_rules.extraction import extract_rules_from_reactions
|
| 16 |
+
from synplan.chem.reaction_routes.clustering import run_cluster_cli
|
| 17 |
+
from synplan.chem.utils import standardize_building_blocks
|
| 18 |
+
from synplan.mcts.search import run_search
|
| 19 |
+
from synplan.ml.training.supervised import create_policy_dataset, run_policy_training
|
| 20 |
+
from synplan.ml.training.reinforcement import run_updating
|
| 21 |
+
from synplan.utils.config import (
|
| 22 |
+
PolicyNetworkConfig,
|
| 23 |
+
RuleExtractionConfig,
|
| 24 |
+
TreeConfig,
|
| 25 |
+
TuningConfig,
|
| 26 |
+
ValueNetworkConfig,
|
| 27 |
+
)
|
| 28 |
+
from synplan.utils.loading import download_all_data
|
| 29 |
+
from synplan.utils.visualisation import (
|
| 30 |
+
routes_clustering_report,
|
| 31 |
+
routes_subclustering_report,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
warnings.filterwarnings("ignore")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@click.group(name="synplan")
|
| 38 |
+
def synplan():
|
| 39 |
+
"""SynPlanner command line interface."""
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@synplan.command(name="download_all_data")
|
| 43 |
+
@click.option(
|
| 44 |
+
"--save_to",
|
| 45 |
+
"save_to",
|
| 46 |
+
help="Path to the folder where downloaded data will be stored.",
|
| 47 |
+
)
|
| 48 |
+
def download_all_data_cli(save_to: str = ".") -> None:
|
| 49 |
+
"""Downloads all data for training, planning and benchmarking SynPlanner."""
|
| 50 |
+
download_all_data(save_to=save_to)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@synplan.command(name="building_blocks_standardizing")
|
| 54 |
+
@click.option(
|
| 55 |
+
"--input",
|
| 56 |
+
"input_file",
|
| 57 |
+
required=True,
|
| 58 |
+
type=click.Path(exists=True),
|
| 59 |
+
help="Path to the file with building blocks to be standardized.",
|
| 60 |
+
)
|
| 61 |
+
@click.option(
|
| 62 |
+
"--output",
|
| 63 |
+
"output_file",
|
| 64 |
+
required=True,
|
| 65 |
+
type=click.Path(),
|
| 66 |
+
help="Path to the file where standardized building blocks will be stored.",
|
| 67 |
+
)
|
| 68 |
+
def building_blocks_standardizing_cli(input_file: str, output_file: str) -> None:
|
| 69 |
+
"""Standardizes building blocks."""
|
| 70 |
+
standardize_building_blocks(input_file=input_file, output_file=output_file)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@synplan.command(name="reaction_standardizing")
|
| 74 |
+
@click.option(
|
| 75 |
+
"--config",
|
| 76 |
+
"config_path",
|
| 77 |
+
required=True,
|
| 78 |
+
type=click.Path(exists=True),
|
| 79 |
+
help="Path to the configuration file for reactions standardizing.",
|
| 80 |
+
)
|
| 81 |
+
@click.option(
|
| 82 |
+
"--input",
|
| 83 |
+
"input_file",
|
| 84 |
+
required=True,
|
| 85 |
+
type=click.Path(exists=True),
|
| 86 |
+
help="Path to the file with reactions to be standardized.",
|
| 87 |
+
)
|
| 88 |
+
@click.option(
|
| 89 |
+
"--output",
|
| 90 |
+
"output_file",
|
| 91 |
+
type=click.Path(),
|
| 92 |
+
help="Path to the file where standardized reactions will be stored.",
|
| 93 |
+
)
|
| 94 |
+
@click.option(
|
| 95 |
+
"--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
|
| 96 |
+
)
|
| 97 |
+
def reaction_standardizing_cli(
|
| 98 |
+
config_path: str, input_file: str, output_file: str, num_cpus: int
|
| 99 |
+
) -> None:
|
| 100 |
+
"""Standardizes reactions and remove duplicates."""
|
| 101 |
+
stand_config = ReactionStandardizationConfig.from_yaml(config_path)
|
| 102 |
+
standardize_reactions_from_file(
|
| 103 |
+
config=stand_config,
|
| 104 |
+
input_reaction_data_path=input_file,
|
| 105 |
+
standardized_reaction_data_path=output_file,
|
| 106 |
+
num_cpus=num_cpus,
|
| 107 |
+
batch_size=100,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@synplan.command(name="reaction_filtering")
|
| 112 |
+
@click.option(
|
| 113 |
+
"--config",
|
| 114 |
+
"config_path",
|
| 115 |
+
required=True,
|
| 116 |
+
type=click.Path(exists=True),
|
| 117 |
+
help="Path to the configuration file for reactions filtering.",
|
| 118 |
+
)
|
| 119 |
+
@click.option(
|
| 120 |
+
"--input",
|
| 121 |
+
"input_file",
|
| 122 |
+
required=True,
|
| 123 |
+
type=click.Path(exists=True),
|
| 124 |
+
help="Path to the file with reactions to be filtered.",
|
| 125 |
+
)
|
| 126 |
+
@click.option(
|
| 127 |
+
"--output",
|
| 128 |
+
"output_file",
|
| 129 |
+
default=Path("./"),
|
| 130 |
+
type=click.Path(),
|
| 131 |
+
help="Path to the file where successfully filtered reactions will be stored.",
|
| 132 |
+
)
|
| 133 |
+
@click.option(
|
| 134 |
+
"--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
|
| 135 |
+
)
|
| 136 |
+
def reaction_filtering_cli(
|
| 137 |
+
config_path: str, input_file: str, output_file: str, num_cpus: int
|
| 138 |
+
):
|
| 139 |
+
"""Filters erroneous reactions."""
|
| 140 |
+
reaction_check_config = ReactionFilterConfig().from_yaml(config_path)
|
| 141 |
+
filter_reactions_from_file(
|
| 142 |
+
config=reaction_check_config,
|
| 143 |
+
input_reaction_data_path=input_file,
|
| 144 |
+
filtered_reaction_data_path=output_file,
|
| 145 |
+
num_cpus=num_cpus,
|
| 146 |
+
batch_size=100,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@synplan.command(name="rule_extracting")
|
| 151 |
+
@click.option(
|
| 152 |
+
"--config",
|
| 153 |
+
"config_path",
|
| 154 |
+
required=True,
|
| 155 |
+
type=click.Path(exists=True),
|
| 156 |
+
help="Path to the configuration file for reaction rules extracting.",
|
| 157 |
+
)
|
| 158 |
+
@click.option(
|
| 159 |
+
"--input",
|
| 160 |
+
"input_file",
|
| 161 |
+
required=True,
|
| 162 |
+
type=click.Path(exists=True),
|
| 163 |
+
help="Path to the file with reactions for reaction rules extraction.",
|
| 164 |
+
)
|
| 165 |
+
@click.option(
|
| 166 |
+
"--output",
|
| 167 |
+
"output_file",
|
| 168 |
+
required=True,
|
| 169 |
+
type=click.Path(),
|
| 170 |
+
help="Path to the file where extracted reaction rules will be stored.",
|
| 171 |
+
)
|
| 172 |
+
@click.option(
|
| 173 |
+
"--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
|
| 174 |
+
)
|
| 175 |
+
def rule_extracting_cli(
|
| 176 |
+
config_path: str, input_file: str, output_file: str, num_cpus: int
|
| 177 |
+
):
|
| 178 |
+
"""Reaction rules extraction."""
|
| 179 |
+
reaction_rule_config = RuleExtractionConfig.from_yaml(config_path)
|
| 180 |
+
extract_rules_from_reactions(
|
| 181 |
+
config=reaction_rule_config,
|
| 182 |
+
reaction_data_path=input_file,
|
| 183 |
+
reaction_rules_path=output_file,
|
| 184 |
+
num_cpus=num_cpus,
|
| 185 |
+
batch_size=100,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@synplan.command(name="ranking_policy_training")
|
| 190 |
+
@click.option(
|
| 191 |
+
"--config",
|
| 192 |
+
"config_path",
|
| 193 |
+
required=True,
|
| 194 |
+
type=click.Path(exists=True),
|
| 195 |
+
help="Path to the configuration file for ranking policy training.",
|
| 196 |
+
)
|
| 197 |
+
@click.option(
|
| 198 |
+
"--reaction_data",
|
| 199 |
+
required=True,
|
| 200 |
+
type=click.Path(exists=True),
|
| 201 |
+
help="Path to the file with reactions for ranking policy training.",
|
| 202 |
+
)
|
| 203 |
+
@click.option(
|
| 204 |
+
"--reaction_rules",
|
| 205 |
+
required=True,
|
| 206 |
+
type=click.Path(exists=True),
|
| 207 |
+
help="Path to the file with extracted reaction rules.",
|
| 208 |
+
)
|
| 209 |
+
@click.option(
|
| 210 |
+
"--results_dir",
|
| 211 |
+
default=Path("."),
|
| 212 |
+
type=click.Path(),
|
| 213 |
+
help="Path to the directory where the trained policy network will be stored.",
|
| 214 |
+
)
|
| 215 |
+
@click.option(
|
| 216 |
+
"--num_cpus",
|
| 217 |
+
default=4,
|
| 218 |
+
type=int,
|
| 219 |
+
help="The number of CPUs to use for training set preparation.",
|
| 220 |
+
)
|
| 221 |
+
def ranking_policy_training_cli(
|
| 222 |
+
config_path: str,
|
| 223 |
+
reaction_data: str,
|
| 224 |
+
reaction_rules: str,
|
| 225 |
+
results_dir: str,
|
| 226 |
+
num_cpus: int,
|
| 227 |
+
) -> None:
|
| 228 |
+
"""Ranking policy network training."""
|
| 229 |
+
policy_config = PolicyNetworkConfig.from_yaml(config_path)
|
| 230 |
+
policy_config.policy_type = "ranking"
|
| 231 |
+
policy_dataset_file = os.path.join(results_dir, "policy_dataset.dt")
|
| 232 |
+
|
| 233 |
+
datamodule = create_policy_dataset(
|
| 234 |
+
reaction_rules_path=reaction_rules,
|
| 235 |
+
molecules_or_reactions_path=reaction_data,
|
| 236 |
+
output_path=policy_dataset_file,
|
| 237 |
+
dataset_type="ranking",
|
| 238 |
+
batch_size=policy_config.batch_size,
|
| 239 |
+
num_cpus=num_cpus,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
run_policy_training(datamodule, config=policy_config, results_path=results_dir)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@synplan.command(name="filtering_policy_training")
|
| 246 |
+
@click.option(
|
| 247 |
+
"--config",
|
| 248 |
+
"config_path",
|
| 249 |
+
required=True,
|
| 250 |
+
type=click.Path(exists=True),
|
| 251 |
+
help="Path to the configuration file for filtering policy training.",
|
| 252 |
+
)
|
| 253 |
+
@click.option(
|
| 254 |
+
"--molecule_data",
|
| 255 |
+
required=True,
|
| 256 |
+
type=click.Path(exists=True),
|
| 257 |
+
help="Path to the file with molecules for filtering policy training.",
|
| 258 |
+
)
|
| 259 |
+
@click.option(
|
| 260 |
+
"--reaction_rules",
|
| 261 |
+
required=True,
|
| 262 |
+
type=click.Path(exists=True),
|
| 263 |
+
help="Path to the file with extracted reaction rules.",
|
| 264 |
+
)
|
| 265 |
+
@click.option(
|
| 266 |
+
"--results_dir",
|
| 267 |
+
default=Path("."),
|
| 268 |
+
type=click.Path(),
|
| 269 |
+
help="Path to the directory where the trained policy network will be stored.",
|
| 270 |
+
)
|
| 271 |
+
@click.option(
|
| 272 |
+
"--num_cpus",
|
| 273 |
+
default=8,
|
| 274 |
+
type=int,
|
| 275 |
+
help="The number of CPUs to use for training set preparation.",
|
| 276 |
+
)
|
| 277 |
+
def filtering_policy_training_cli(
|
| 278 |
+
config_path: str,
|
| 279 |
+
molecule_data: str,
|
| 280 |
+
reaction_rules: str,
|
| 281 |
+
results_dir: str,
|
| 282 |
+
num_cpus: int,
|
| 283 |
+
):
|
| 284 |
+
"""Filtering policy network training."""
|
| 285 |
+
|
| 286 |
+
policy_config = PolicyNetworkConfig.from_yaml(config_path)
|
| 287 |
+
policy_config.policy_type = "filtering"
|
| 288 |
+
policy_dataset_file = os.path.join(results_dir, "policy_dataset.ckpt")
|
| 289 |
+
|
| 290 |
+
datamodule = create_policy_dataset(
|
| 291 |
+
reaction_rules_path=reaction_rules,
|
| 292 |
+
molecules_or_reactions_path=molecule_data,
|
| 293 |
+
output_path=policy_dataset_file,
|
| 294 |
+
dataset_type="filtering",
|
| 295 |
+
batch_size=policy_config.batch_size,
|
| 296 |
+
num_cpus=num_cpus,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
run_policy_training(datamodule, config=policy_config, results_path=results_dir)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@synplan.command(name="value_network_tuning")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for value network training.",
)
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for planning simulations.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules. Needed for planning simulations.",
)
@click.option(
    "--building_blocks",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks. Needed for planning simulations.",
)
@click.option(
    "--policy_network",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with trained policy network. Needed for planning simulations.",
)
@click.option(
    "--value_network",
    default=None,
    type=click.Path(exists=True),
    help="Path to the file with trained value network. Needed in case of additional value network fine-tuning",
)
@click.option(
    "--results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the directory where the trained value network will be stored.",
)
def value_network_tuning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Value network tuning.

    Loads the YAML configuration, assembles the policy/value/tree/tuning
    configs, and launches the reinforcement-style value-network update loop
    via ``run_updating``.

    :param config_path: YAML file with ``node_expansion``, ``value_network``,
        ``tree`` and ``tuning`` sections.
    :param targets: file with target molecules used in planning simulations.
    :param reaction_rules: pickled reaction rules for planning simulations.
    :param building_blocks: building-blocks file for planning simulations.
    :param policy_network: checkpoint of the trained policy network.
    :param value_network: optional checkpoint of a previously trained value
        network to fine-tune; when omitted, a fresh network is trained and
        stored under ``results_dir/weights``.
    :param results_dir: directory where the trained value network is stored.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    policy_config = PolicyNetworkConfig.from_dict(config["node_expansion"])
    policy_config.weights_path = policy_network

    value_config = ValueNetworkConfig.from_dict(config["value_network"])
    if value_network is None:
        # No pre-trained value network supplied: train from scratch and store
        # the checkpoint under the results directory.
        value_config.weights_path = os.path.join(
            results_dir, "weights", "value_network.ckpt"
        )
    else:
        # Fix: the supplied --value_network checkpoint was previously ignored;
        # use it as the starting weights for fine-tuning.
        value_config.weights_path = value_network

    tree_config = TreeConfig.from_dict(config["tree"])
    tuning_config = TuningConfig.from_dict(config["tuning"])

    run_updating(
        targets_path=targets,
        tree_config=tree_config,
        policy_config=policy_config,
        value_config=value_config,
        reinforce_config=tuning_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        results_root=results_dir,
    )
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
@synplan.command(name="planning")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for retrosynthetic planning.",
)
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for retrosynthetic planning.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--building_blocks",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks.",
)
@click.option(
    "--policy_network",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with trained policy network.",
)
@click.option(
    "--value_network",
    default=None,
    type=click.Path(exists=True),
    help="Path to the file with trained value network.",
)
@click.option(
    "--results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where retrosynthetic planning results will be stored.",
)
def planning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Retrosynthetic planning.

    Reads the YAML configuration, merges the tree and node-evaluation
    sections into one search configuration, wires the policy-network
    checkpoint into the node-expansion settings, and starts the search.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        raw_config = yaml.safe_load(config_file)

    # Tree-search and node-evaluation settings travel together as one dict.
    search_config = {**raw_config["tree"], **raw_config["node_evaluation"]}

    # Inject the policy checkpoint path into the expansion settings before
    # building the typed config object.
    expansion_settings = dict(raw_config["node_expansion"])
    expansion_settings["weights_path"] = policy_network
    policy_config = PolicyNetworkConfig.from_dict(expansion_settings)

    run_search(
        targets_path=targets,
        search_config=search_config,
        policy_config=policy_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        value_network_path=value_network,
        results_root=results_dir,
    )
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
@synplan.command(name="clustering")
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for retrosynthetic planning.",
)
@click.option(
    "--routes_file",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where the planning results are stored.",
)
@click.option(
    "--cluster_results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where clustering results will be stored.",
)
@click.option(
    "--perform_subcluster",
    default=None,
    type=click.Path(exists=False),
    help="Perform subclustering.",
)
@click.option(
    "--subcluster_results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where subclustering results will be stored.",
)
def cluster_route_from_file_cli(
    targets: str,
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool,
    subcluster_results_dir: str,
):
    """Clustering the routes from planning.

    Thin CLI wrapper that forwards its options to ``run_cluster_cli``.

    NOTE(review): ``targets`` is a required option but is never used in this
    body — confirm whether it should be forwarded to ``run_cluster_cli`` or
    removed from the interface.
    NOTE(review): ``--perform_subcluster`` is declared as ``click.Path`` yet
    annotated and used as a boolean; any non-empty string enables
    subclustering. Consider ``is_flag=True`` in a future interface revision.
    """
    run_cluster_cli(
        routes_file=routes_file,
        cluster_results_dir=cluster_results_dir,
        perform_subcluster=perform_subcluster,
        # Subcluster output directory is only meaningful when subclustering
        # is requested; otherwise pass None.
        subcluster_results_dir=subcluster_results_dir if perform_subcluster else None,
    )
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
# Entry point for direct execution: dispatch to the click command group.
if __name__ == "__main__":
    synplan()
|
synplan/interfaces/gui.py
ADDED
|
@@ -0,0 +1,1323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import pickle
|
| 3 |
+
import re
|
| 4 |
+
import uuid
|
| 5 |
+
import io
|
| 6 |
+
import zipfile
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import streamlit as st
|
| 10 |
+
from CGRtools.files import SMILESRead
|
| 11 |
+
from streamlit_ketcher import st_ketcher
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
from huggingface_hub.utils import disable_progress_bars
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from synplan.mcts.expansion import PolicyNetworkFunction
|
| 17 |
+
from synplan.mcts.search import extract_tree_stats
|
| 18 |
+
from synplan.mcts.tree import Tree
|
| 19 |
+
from synplan.chem.utils import mol_from_smiles
|
| 20 |
+
from synplan.chem.reaction_routes.route_cgr import *
|
| 21 |
+
from synplan.chem.reaction_routes.clustering import *
|
| 22 |
+
|
| 23 |
+
from synplan.utils.visualisation import (
|
| 24 |
+
routes_clustering_report,
|
| 25 |
+
routes_subclustering_report,
|
| 26 |
+
generate_results_html,
|
| 27 |
+
html_top_routes_cluster,
|
| 28 |
+
get_route_svg,
|
| 29 |
+
get_route_svg_from_json
|
| 30 |
+
)
|
| 31 |
+
from synplan.utils.config import TreeConfig, PolicyNetworkConfig
|
| 32 |
+
from synplan.utils.loading import load_reaction_rules, load_building_blocks
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import psutil
|
| 36 |
+
import gc
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Silence huggingface_hub download progress bars so they do not clutter the
# Streamlit UI / server logs.
disable_progress_bars("huggingface_hub")

# Lenient SMILES parser; ignore=True skips unparsable records instead of raising.
smiles_parser = SMILESRead.create_parser(ignore=True)
# Default molecule pre-loaded into the ketcher editor on first page load.
DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# --- Helper Functions ---
|
| 46 |
+
def download_button(
    object_to_download, download_filename, button_text, pickle_it=False
):
    """Build a styled HTML anchor that downloads ``object_to_download``.

    The payload is embedded in the page as a base64 ``data:`` URI, so no
    server-side file is created.

    :param object_to_download: object to serve. ``bytes`` are used as-is,
        a :class:`pandas.DataFrame` is exported as UTF-8 CSV, and any other
        object is expected to expose ``.encode()`` (i.e. ``str``).
    :param download_filename: filename (with extension) offered to the user,
        e.g. ``mydata.csv``.
    :param button_text: label shown on the download button.
    :param pickle_it: if True, serialize the object with :mod:`pickle` first.
    :return: HTML string (CSS + anchor tag), or ``None`` when pickling fails.

    Examples
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except pickle.PicklingError as e:
            st.write(e)
            return None
    else:
        if isinstance(object_to_download, bytes):
            pass  # already raw bytes, nothing to convert
        elif isinstance(object_to_download, pd.DataFrame):
            object_to_download = object_to_download.to_csv(index=False).encode("utf-8")

    # str payloads need encoding first; bytes go straight to base64.
    try:
        b64 = base64.b64encode(object_to_download.encode()).decode()
    except AttributeError:
        b64 = base64.b64encode(object_to_download).decode()

    # Unique CSS id per button; digits are stripped because CSS identifiers
    # must not start with a digit. (Fix: raw string for the regex — "\d+" was
    # an invalid escape sequence.)
    button_uuid = str(uuid.uuid4()).replace("-", "")
    button_id = re.sub(r"\d+", "", button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    dl_link = (
        custom_css
        + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    )
    return dl_link
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@st.cache_resource
def load_planning_resources_cached():  # Renamed to avoid conflict if main calls it directly
    """Download (once per process) the planning resources from the Hub.

    Fetches the building-blocks file, the ranking policy checkpoint and the
    reaction-rules pickle from the SynPlanner Hugging Face repository, caching
    the result via ``st.cache_resource``.

    :return: tuple ``(building_blocks_path, ranking_policy_weights_path,
        reaction_rules_path)`` of local file paths.
    """
    repo = "Laboratoire-De-Chemoinformatique/SynPlanner"
    # (filename, subfolder) for each resource, in return order.
    specs = (
        ("building_blocks_em_sa_ln.smi", "building_blocks"),
        ("ranking_policy_network.ckpt", "uspto/weights"),
        ("uspto_reaction_rules.pickle", "uspto"),
    )
    paths = tuple(
        hf_hub_download(repo_id=repo, filename=fname, subfolder=sub, local_dir=".")
        for fname, sub in specs
    )
    return paths
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# --- GUI Sections ---
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def initialize_app():
    """1. Initialization: Setting up the main window, layout, and initial widgets."""
    st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

    # Seed every session-state key with its default exactly once; existing
    # values survive reruns untouched.
    session_defaults = {
        # Planning state
        "planning_done": False,
        "tree": None,
        "res": None,
        "target_smiles": "",  # may be overwritten by the ketcher editor
        # Clustering state
        "clustering_done": False,
        "clusters": None,
        "reactions_dict": None,
        "num_clusters_setting": 10,  # the setting used for the last run
        "route_cgrs_dict": None,
        "sb_cgrs_dict": None,
        "route_json": None,
        # Subclustering state
        "subclustering_done": False,
        "subclusters": None,
        # Download state (less critical with direct download links)
        "clusters_downloaded": False,
        # Ketcher editor persistence
        "ketcher": DEFAULT_MOL,
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    intro_text = """
    This is a demo of the graphical user interface of
    [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
    SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

    More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
    """
    st.title("`SynPlanner GUI`")
    st.write(intro_text)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def setup_sidebar():
    """2. Sidebar: Handling the widgets and logic within the sidebar area."""
    # st.sidebar.image("img/logo.png") # Assuming img/logo.png is available
    # Each sidebar section is a title followed by a markdown link.
    sidebar_sections = (
        ("Docs", "https://synplanner.readthedocs.io/en/latest/"),
        (
            "Tutorials",
            "https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials",
        ),
        (
            "Paper",
            "https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796",
        ),
        (
            "Issues",
            "[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)",
        ),
    )
    for section_title, section_body in sidebar_sections:
        st.sidebar.title(section_title)
        st.sidebar.markdown(section_body)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def handle_molecule_input():
    """3. Molecule Input: Managing the input area for molecule data with two-way synchronization.

    Keeps a single source of truth (``st.session_state.shared_smiles``) in
    sync between the SMILES text input and the ketcher drawing widget, and
    returns the SMILES string currently selected for planning.
    """
    st.header("Molecule input")
    st.markdown(
        """
    You can provide a molecular structure by either providing:
    * SMILES string + Enter
    * Draw it + Apply
    """
    )

    # Canonical SMILES shared by both widgets; seeded from the ketcher state.
    if "shared_smiles" not in st.session_state:
        st.session_state.shared_smiles = st.session_state.get("ketcher", DEFAULT_MOL)

    # Incremented to change the ketcher widget key, forcing it to re-render
    # with the new value when the text input changes.
    if "ketcher_render_count" not in st.session_state:
        st.session_state.ketcher_render_count = 0

    def text_input_changed_callback():
        # Propagate an edited SMILES string from the text input to the shared
        # state and force the ketcher widget to pick it up on the next rerun.
        new_text_value = (
            st.session_state.smiles_text_input_key_for_sync
        )  # Key of the text_input
        if new_text_value != st.session_state.shared_smiles:
            st.session_state.shared_smiles = new_text_value
            st.session_state.ketcher = new_text_value
            st.session_state.ketcher_render_count += 1

    # SMILES Text Input
    st.text_input(
        "SMILES:",
        value=st.session_state.shared_smiles,
        key="smiles_text_input_key_for_sync",  # Unique key for this widget
        on_change=text_input_changed_callback,
        help="Enter SMILES string and press Enter. The drawing will update, and vice-versa.",
    )

    # Key embeds the render counter so a text edit replaces the widget
    # instance (ketcher does not re-read its value prop otherwise).
    ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
    smile_code_output_from_ketcher = st_ketcher(
        st.session_state.shared_smiles, key=ketcher_key
    )

    # Drawing changed in ketcher: sync back to shared state and rerun so the
    # text input reflects the new structure.
    if smile_code_output_from_ketcher != st.session_state.shared_smiles:
        st.session_state.shared_smiles = smile_code_output_from_ketcher
        st.session_state.ketcher = smile_code_output_from_ketcher
        st.rerun()

    current_smiles_for_planning = st.session_state.shared_smiles

    # Warn if results on screen belong to a previously planned molecule.
    last_planned_smiles = st.session_state.get("target_smiles")
    if (
        last_planned_smiles
        and current_smiles_for_planning != last_planned_smiles
        and st.session_state.get("planning_done", False)
    ):
        st.warning(
            "Molecule structure has changed since the last successful planning run. "
            "Results shown below (if any) are for the previous molecule. "
            "Please re-run planning for the current structure."
        )

    # Ensure st.session_state.ketcher is consistent for other parts of the app
    if st.session_state.get("ketcher") != current_smiles_for_planning:
        st.session_state.ketcher = current_smiles_for_planning

    return current_smiles_for_planning
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def setup_planning_options():
    """4. Planning: Encapsulating the logic related to the "planning" functionality.

    Renders the MCTS planning options, and on button press loads resources,
    builds the search tree, runs the iterations with a progress bar, and
    stores the resulting tree and stats in session state.
    """
    st.header("Launch calculation")
    st.markdown(
        """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
    )

    st.markdown(
        f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
    )

    st.subheader("Planning options")
    st.markdown(
        """
    The description of each option can be found in the
    [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
    """
    )

    # Two-column layout: search behaviour on the left, size limits on the right.
    col_options_1, col_options_2 = st.columns(2, gap="medium")
    with col_options_1:
        search_strategy_input = st.selectbox(
            label="Search strategy",
            options=(
                "Expansion first",
                "Evaluation first",
            ),
            index=0,
            key="search_strategy_input",
        )
        ucb_type = st.selectbox(
            label="UCB type",
            options=("uct", "puct", "value"),
            index=0,
            key="ucb_type_input",
        )  # Fixed label
        c_ucb = st.number_input(
            "C coefficient of UCB",
            value=0.1,
            placeholder="Type a number...",
            key="c_ucb_input",
        )

    with col_options_2:
        max_iterations = st.slider(
            "Total number of MCTS iterations",
            min_value=50,
            max_value=1000,
            value=300,
            key="max_iterations_slider",
        )
        max_depth = st.slider(
            "Maximal number of reaction steps",
            min_value=3,
            max_value=9,
            value=6,
            key="max_depth_slider",
        )
        min_mol_size = st.slider(
            "Minimum size of a molecule to be precursor",
            min_value=0,
            max_value=7,
            value=0,
            key="min_mol_size_slider",
            help="Number of non-hydrogen atoms in molecule",
        )

    # Map the human-readable strategy label to the TreeConfig identifier.
    search_strategy_translator = {
        "Expansion first": "expansion_first",
        "Evaluation first": "evaluation_first",
    }
    search_strategy = search_strategy_translator[search_strategy_input]

    planning_params = {
        "search_strategy": search_strategy,
        "ucb_type": ucb_type,
        "c_ucb": c_ucb,
        "max_iterations": max_iterations,
        "max_depth": max_depth,
        "min_mol_size": min_mol_size,
    }

    if st.button("Start retrosynthetic planning", key="submit_planning_button"):
        # Reset downstream states if replanning
        st.session_state.planning_done = False
        st.session_state.clustering_done = False
        st.session_state.subclustering_done = False
        st.session_state.tree = None
        st.session_state.res = None
        st.session_state.clusters = None
        st.session_state.reactions_dict = None
        st.session_state.subclusters = None
        st.session_state.route_cgrs_dict = None
        st.session_state.sb_cgrs_dict = None
        st.session_state.route_json = None
        active_smile_code = st.session_state.get(
            "ketcher", DEFAULT_MOL
        )  # Get current SMILES
        st.session_state.target_smiles = (
            active_smile_code  # Store the SMILES used for this run
        )

        try:
            target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
            if target_molecule is None:
                st.error(f"Could not parse the input SMILES: {active_smile_code}")
            else:
                (
                    building_blocks_path,
                    ranking_policy_weights_path,
                    reaction_rules_path,
                ) = load_planning_resources_cached()
                with st.spinner("Running retrosynthetic planning..."):
                    with st.status("Loading resources...", expanded=False) as status:
                        st.write("Loading building blocks...")
                        building_blocks = load_building_blocks(
                            building_blocks_path, standardize=False
                        )
                        st.write("Loading reaction rules...")
                        reaction_rules = load_reaction_rules(reaction_rules_path)
                        st.write("Loading policy network...")
                        policy_config = PolicyNetworkConfig(
                            weights_path=ranking_policy_weights_path
                        )
                        policy_function = PolicyNetworkFunction(
                            policy_config=policy_config
                        )
                        status.update(label="Resources loaded!", state="complete")

                    tree_config = TreeConfig(
                        search_strategy=planning_params["search_strategy"],
                        evaluation_type="rollout",  # This was hardcoded, keeping it.
                        max_iterations=planning_params["max_iterations"],
                        max_depth=planning_params["max_depth"],
                        min_mol_size=planning_params["min_mol_size"],
                        init_node_value=0.5,  # This was hardcoded
                        ucb_type=planning_params["ucb_type"],
                        c_ucb=planning_params["c_ucb"],
                        silent=True,  # This was hardcoded
                    )

                    tree = Tree(
                        target=target_molecule,
                        config=tree_config,
                        reaction_rules=reaction_rules,
                        building_blocks=building_blocks,
                        expansion_function=policy_function,
                        evaluation_function=None,  # This was hardcoded
                    )

                    # Iterating the Tree object runs one MCTS iteration per step.
                    mcts_progress_text = "Running MCTS iterations..."
                    mcts_bar = st.progress(0, text=mcts_progress_text)
                    for step, (solved, route_id) in enumerate(tree):
                        progress_value = min(
                            1.0, (step + 1) / planning_params["max_iterations"]
                        )
                        mcts_bar.progress(
                            progress_value,
                            text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
                        )

                    res = extract_tree_stats(tree, target_molecule)

                    st.session_state["tree"] = tree
                    st.session_state["res"] = res
                    st.session_state.planning_done = True
                    st.rerun()

        except Exception as e:
            st.error(f"An error occurred during planning: {e}")
            st.session_state.planning_done = False
|
| 462 |
+
|
| 463 |
+
def display_planning_results():
    """5. Planning Results Display: Handling the presentation of results.

    Reads ``res`` (stats dict) and ``tree`` (MCTS tree) from session state.
    On success, shows up to 3 unique route SVGs; otherwise shows run
    statistics for the unsuccessful search. No return value.
    """
    if st.session_state.get("planning_done", False):
        res = st.session_state.res
        tree = st.session_state.tree

        if res is None or tree is None:
            st.error(
                "Planning results are missing from session state. Please re-run planning."
            )
            st.session_state.planning_done = False  # Reset state
            return  # Exit this function if no results

        if res.get("solved", False):  # Use .get for safety
            st.header("Planning Results")
            # winning_nodes may be absent or empty on a degenerate tree.
            winning_nodes = (
                sorted(set(tree.winning_nodes))
                if hasattr(tree, "winning_nodes") and tree.winning_nodes
                else []
            )
            st.subheader(f"Number of unique routes found: {len(winning_nodes)}")

            st.subheader("Examples of found retrosynthetic routes")
            image_counter = 0
            visualised_route_ids = set()

            if not winning_nodes:
                st.warning(
                    "Planning solved, but no winning nodes found in the tree object."
                )
            else:
                # NOTE: winning_nodes is already deduplicated via sorted(set(...));
                # the visualised_route_ids guard is kept as a defensive belt.
                for route_id in winning_nodes:
                    if image_counter >= 3:  # show at most three example routes
                        break
                    if route_id not in visualised_route_ids:
                        try:
                            visualised_route_ids.add(route_id)
                            num_steps = len(tree.synthesis_route(route_id))
                            route_score = round(tree.route_score(route_id), 3)
                            svg = get_route_svg(tree, route_id)
                            # svg = get_route_svg_from_json(st.session_state.route_json, route_id)
                            if svg:
                                st.image(
                                    svg,
                                    caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
                                )
                                image_counter += 1
                            else:
                                st.warning(
                                    f"Could not generate SVG for route {route_id}."
                                )
                        except Exception as e:
                            st.error(f"Error displaying route {route_id}: {e}")
        else:  # Not solved
            st.header("Planning Results")
            st.warning(
                "No reaction path found for the target molecule with the current settings."
            )
            st.write(
                "Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
            )
            stat_col, _ = st.columns(2)
            with stat_col:
                st.subheader("Run Statistics (No Solution)")
                try:
                    # Backfill the target SMILES if the stats dict lacks it.
                    if (
                        "target_smiles" not in res
                        and "target_smiles" in st.session_state
                    ):
                        res["target_smiles"] = st.session_state.target_smiles
                    cols_to_show = [
                        col
                        for col in [
                            "target_smiles",
                            "num_nodes",
                            "num_iter",
                            "search_time",
                        ]
                        if col in res
                    ]
                    if cols_to_show:
                        df = pd.DataFrame(res, index=[0])[cols_to_show]
                        st.dataframe(df)
                    else:
                        st.write("No statistics to display for the unsuccessful run.")
                except Exception as e:
                    st.error(f"Error displaying statistics: {e}")
                    st.write(res)
def download_planning_results():
    """6. Planning Results Download: Providing functionality to download.

    Offers the solved-run report (HTML) and the run statistics (CSV) as
    data-URI download links. Only active when planning finished successfully.
    """
    planning_ok = (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    )
    if not planning_ok:
        return

    stats = st.session_state.res
    search_tree = st.session_state.tree
    target = st.session_state.target_smiles

    try:
        # Full HTML report of the search tree (extended view).
        report_html = generate_results_html(search_tree, html_path=None, extended=True)
        html_link = download_button(
            report_html,
            f"results_synplanner_{target}.html",
            "Download results (HTML)",
        )
        if html_link:
            st.markdown(html_link, unsafe_allow_html=True)

        try:
            # One-row CSV with the run statistics.
            csv_link = download_button(
                pd.DataFrame(stats, index=[0]),
                f"stats_synplanner_{target}.csv",
                "Download statistics (CSV)",
            )
            if csv_link:
                st.markdown(csv_link, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Could not prepare statistics CSV for download: {e}")

    except Exception as e:
        st.error(f"Error generating download links for planning results: {e}")
def setup_clustering():
    """7. Clustering: Encapsulating the logic related to the "clustering" functionality.

    After a successful planning run, builds RouteCGRs and SB-CGRs for every
    found route, clusters the routes by SB-CGR, and stores all intermediate
    artifacts in session state for the display/download/subclustering steps.
    """
    solved = (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    )
    if not solved:
        return

    st.divider()
    st.header("Clustering the retrosynthetic routes")

    if not st.button("Run Clustering", key="submit_clustering_button"):
        return

    # Reset every clustering-related artifact before recomputing.
    for state_key in (
        "clusters",
        "reactions_dict",
        "subclusters",
        "route_cgrs_dict",
        "sb_cgrs_dict",
        "route_json",
    ):
        st.session_state[state_key] = None
    st.session_state.clustering_done = False
    st.session_state.subclustering_done = False

    with st.spinner("Performing clustering..."):
        try:
            search_tree = st.session_state.tree
            if not search_tree:
                st.error("Tree object not found. Please re-run planning.")
                return

            st.write("Calculating RoutesCGRs...")
            route_cgrs = compose_all_route_cgrs(search_tree)
            st.write("Processing SB-CGRs...")
            sb_cgrs = compose_all_sb_cgrs(route_cgrs)

            # Cluster by SB-CGR; sort cluster keys numerically for stable display.
            cluster_map = cluster_routes(sb_cgrs, use_strat=False)
            cluster_map = dict(sorted(cluster_map.items(), key=lambda x: float(x[0])))

            st.session_state.clusters = cluster_map
            st.session_state.route_cgrs_dict = route_cgrs
            st.session_state.sb_cgrs_dict = sb_cgrs
            st.write("Extracting reactions...")
            st.session_state.reactions_dict = extract_reactions(search_tree)
            st.session_state.route_json = make_json(st.session_state.reactions_dict)

            # Check for None explicitly — empty-but-present results still count.
            if (
                st.session_state.clusters is not None
                and st.session_state.reactions_dict is not None
            ):
                st.session_state.clustering_done = True
                st.success(
                    f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
                )
            else:
                st.error("Clustering failed or returned empty results.")
                st.session_state.clustering_done = False

            del cluster_map  # route_cgrs / sb_cgrs live on in session state
            gc.collect()
            st.rerun()
        except Exception as e:
            st.error(f"An error occurred during clustering: {e}")
            st.session_state.clustering_done = False
def _render_cluster_entry(tree, cluster_num, group_data, in_expander=False):
    """Render one cluster's best route (SB-CGR thumbnail + route SVG).

    :param tree: MCTS tree providing route lookup and scoring.
    :param cluster_num: cluster key, used in captions and messages.
    :param group_data: cluster dict with ``route_ids``, ``group_size``, ``sb_cgr``.
    :param in_expander: switches the "no data" warning wording to match the
        message originally used inside the expander branch.
    """
    if (
        not group_data
        or "route_ids" not in group_data
        or not group_data["route_ids"]
    ):
        if in_expander:
            st.warning(
                f"Cluster {cluster_num} in expansion has no data or route_ids."
            )
        else:
            st.warning(f"Cluster {cluster_num} has no data or route_ids.")
        return
    st.markdown(
        f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
    )
    # The first route of the cluster is taken as its representative.
    route_id = group_data["route_ids"][0]
    try:
        num_steps = len(tree.synthesis_route(route_id))
        route_score = round(tree.route_score(route_id), 3)
        # svg = get_route_svg(tree, route_id)
        svg = get_route_svg_from_json(st.session_state.route_json, route_id)
        sb_cgr = group_data.get("sb_cgr")  # Safely get sb_cgr
        sb_cgr_svg = None
        if sb_cgr:
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

        caption = f"Route {route_id}; {num_steps} steps; Route score: {route_score}"
        if svg and sb_cgr_svg:
            col1, col2 = st.columns([0.2, 0.8])
            with col1:
                st.image(sb_cgr_svg, caption="SB-CGR")
            with col2:
                st.image(svg, caption=caption)
        elif svg:  # Only route SVG available
            st.image(svg, caption=caption)
            st.warning(
                f"SB-CGR could not be displayed for cluster {cluster_num}."
            )
        else:
            st.warning(
                f"Could not generate SVG for route {route_id} or its SB-CGR."
            )
    except Exception as e:
        st.error(
            f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
        )


def display_clustering_results():
    """8. Clustering Results Display: Handling the presentation of results.

    Shows the representative route of each cluster: the first 10 inline,
    the remainder inside an expander. Rendering is delegated to
    ``_render_cluster_entry`` (previously duplicated inline for both lists).
    """
    if st.session_state.get("clustering_done", False):
        clusters = st.session_state.clusters
        tree = st.session_state.tree
        MAX_DISPLAY_CLUSTERS_DATA = 10

        if clusters is None or tree is None:
            st.error(
                "Clustering results (clusters or tree) are missing. Please re-run clustering."
            )
            st.session_state.clustering_done = False
            return

        st.subheader(f"Best routes from {len(clusters)} Found Clusters")
        clusters_items = list(clusters.items())
        first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
        remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]

        for cluster_num, group_data in first_items:
            _render_cluster_entry(tree, cluster_num, group_data)

        if remaining_items:
            with st.expander(f"... and {len(remaining_items)} more clusters"):
                for cluster_num, group_data in remaining_items:
                    _render_cluster_entry(
                        tree, cluster_num, group_data, in_expander=True
                    )
def download_clustering_results():
    """10. Clustering Results Download: Providing functionality to download.

    Builds one HTML report per cluster (first 10 inline, the rest inside an
    expander) plus a ZIP containing all of them. Report generation was
    previously duplicated in three places; it is now the local ``_report``
    helper.
    """
    if not st.session_state.get("clustering_done", False):
        return

    tree_for_html = st.session_state.get("tree")
    clusters_for_html = st.session_state.get("clusters")
    # sb_cgrs_dict was used instead of reactions_dict in the original report call.
    sb_cgrs_for_html = st.session_state.get("sb_cgrs_dict")

    if not tree_for_html:
        st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
        return
    if not clusters_for_html:
        st.warning("Cluster data not found. Cannot generate cluster reports.")
        return

    st.subheader("Cluster Reports")
    st.write("Generate downloadable HTML reports for each cluster:")

    MAX_DOWNLOAD_LINKS_DISPLAYED = 10
    clusters_items = list(clusters_for_html.items())
    target = st.session_state.target_smiles

    def _report(idx):
        # One HTML report per cluster key; aam=False matches the original calls.
        return routes_clustering_report(
            tree_for_html,
            clusters_for_html,
            str(idx),
            sb_cgrs_for_html,
            aam=False,
        )

    for i, (cluster_idx, _) in enumerate(clusters_items):
        if i >= MAX_DOWNLOAD_LINKS_DISPLAYED:
            break
        try:
            st.download_button(
                label=f"Download report for cluster {cluster_idx}",
                data=_report(cluster_idx),
                file_name=f"cluster_{cluster_idx}_{target}.html",
                mime="text/html",
                key=f"download_cluster_{cluster_idx}",
            )
        except Exception as e:
            st.error(f"Error generating report for cluster {cluster_idx}: {e}")

    if len(clusters_items) > MAX_DOWNLOAD_LINKS_DISPLAYED:
        remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
        with st.expander(f"Show remaining {len(remaining_items)} cluster reports"):
            for group_index, _ in remaining_items:
                try:
                    st.download_button(
                        label=f"Download report for cluster {group_index}",
                        data=_report(group_index),
                        file_name=f"cluster_{group_index}_{target}.html",
                        mime="text/html",
                        key=f"download_cluster_expanded_{group_index}",
                    )
                except Exception as e:
                    st.error(
                        f"Error generating report for cluster {group_index} (expanded): {e}"
                    )

    try:
        # Bundle every cluster report into a single downloadable ZIP.
        buffer = io.BytesIO()
        with zipfile.ZipFile(
            buffer, mode="w", compression=zipfile.ZIP_DEFLATED
        ) as zf:
            for idx, _ in clusters_items:
                zf.writestr(f"cluster_{idx}_{target}.html", _report(idx))
        buffer.seek(0)

        st.download_button(
            label="📦 Download all cluster reports as ZIP",
            data=buffer,
            file_name=f"all_cluster_reports_{target}.zip",
            mime="application/zip",
            key="download_all_clusters_zip",
        )
    except Exception as e:
        st.error(f"Error generating ZIP file for cluster reports: {e}")
def setup_subclustering():
    """11. Subclustering: Encapsulating the logic related to the "subclustering" functionality.

    Requires a completed clustering run; splits every cluster into subclusters
    using the stored SB-CGR and RouteCGR dictionaries.
    """
    # Subclustering depends on clustering being done first.
    if not st.session_state.get("clustering_done", False):
        return

    st.divider()
    st.header("Sub-Clustering within a selected Cluster")

    if not st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
        return

    st.session_state.subclustering_done = False
    st.session_state.subclusters = None
    with st.spinner("Performing subclustering analysis..."):
        try:
            clusters = st.session_state.get("clusters")
            sb_cgrs = st.session_state.get("sb_cgrs_dict")
            route_cgrs = st.session_state.get("route_cgrs_dict")

            # Collect whichever prerequisites are absent for the error message.
            missing = []
            if not clusters:
                missing.append("clusters")
            if not sb_cgrs:
                missing.append("SB-CGRs dictionary")
            if not route_cgrs:
                missing.append("RouteCGRs dictionary")

            if missing:
                st.error(
                    f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
                )
                st.session_state.subclustering_done = False
            else:
                st.session_state.subclusters = subcluster_all_clusters(
                    clusters,
                    sb_cgrs,
                    route_cgrs,
                )
                st.session_state.subclustering_done = True
                st.success("Subclustering analysis complete.")
                gc.collect()
                st.rerun()

        except Exception as e:
            st.error(f"An error occurred during subclustering: {e}")
            st.session_state.subclustering_done = False
def _show_subcluster_route(tree, route_id, expanded=False):
    """Render one route's SVG and score inside the subcluster panel.

    :param expanded: selects the error-message wording used for routes shown
        inside the "more routes" expander.
    """
    try:
        route_score_sub = round(tree.route_score(route_id), 3)
        # svg_sub = get_route_svg(tree, route_id)
        svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
        if svg_sub:
            st.image(
                svg_sub,
                caption=f"Route {route_id}; Score: {route_score_sub}",
            )
        else:
            st.warning(f"Could not generate SVG for route {route_id}.")
    except Exception as e:
        if expanded:
            st.error(
                f"Error displaying route {route_id} in subcluster (expanded): {e}"
            )
        else:
            st.error(f"Error displaying route {route_id} in subcluster: {e}")


def display_subclustering_results():
    """12. Subclustering Results Display: Handling the presentation of results.

    Left column: cluster/subcluster selectors plus the parent cluster's
    SB-CGR. Right column: the subcluster's Markush-like pseudo reaction and
    its routes (first 5 inline, the rest in an expander). Per-route rendering
    is delegated to ``_show_subcluster_route`` (previously duplicated inline).
    """
    if not st.session_state.get("subclustering_done", False):
        return

    sub = st.session_state.get("subclusters")
    tree = st.session_state.get("tree")

    if not sub or not tree:
        st.error(
            "Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
        )
        st.session_state.subclustering_done = False
        return

    sub_input_col, sub_display_col = st.columns([0.25, 0.75])

    with sub_input_col:
        st.subheader("Select Cluster and Subcluster")
        available_cluster_nums = list(sub.keys())
        if not available_cluster_nums:
            st.warning("No clusters available in subclustering results.")
            return  # Exit if no clusters to select

        user_input_cluster_num_display = st.selectbox(
            "Select Cluster #:",
            options=sorted(available_cluster_nums),
            key="subcluster_num_select_key",
        )

        selected_subcluster_idx = 0

        if user_input_cluster_num_display in sub:
            sub_step_cluster = sub[user_input_cluster_num_display]
            allowed_subclusters_indices = sorted(list(sub_step_cluster.keys()))

            if not allowed_subclusters_indices:
                st.warning(
                    f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
                )
            else:
                selected_subcluster_idx = st.selectbox(
                    "Select Subcluster Index:",
                    options=allowed_subclusters_indices,
                    key="subcluster_index_select_key",
                )
                if selected_subcluster_idx in sub[user_input_cluster_num_display]:
                    current_subcluster_data = sub[user_input_cluster_num_display][
                        selected_subcluster_idx
                    ]
                    if "sb_cgr" in current_subcluster_data:
                        cluster_sb_cgr_display = current_subcluster_data["sb_cgr"]
                        cluster_sb_cgr_display.clean2d()
                        st.image(
                            cluster_sb_cgr_display.depict(),
                            caption=f"SB-CGR of parent Cluster {user_input_cluster_num_display}",
                        )
                    else:
                        st.warning("SB-CGR for this subcluster not found.")
        else:
            st.warning(
                f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
            )
            return

    with sub_display_col:
        st.subheader("Subcluster Details")
        if (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        ):
            subcluster_content = sub[user_input_cluster_num_display][
                selected_subcluster_idx
            ]

            # subcluster_to_display = post_process_subgroup(subcluster_content)  # Under development
            subcluster_to_display = subcluster_content
            if (
                not subcluster_to_display
                or "routes_data" not in subcluster_to_display
                or not subcluster_to_display["routes_data"]
            ):
                st.info("No routes or data found for this subcluster selection.")
            else:
                MAX_ROUTES_PER_SUBCLUSTER = 5
                all_route_ids_in_subcluster = list(
                    subcluster_to_display["routes_data"].keys()
                )
                routes_to_display_direct = all_route_ids_in_subcluster[
                    :MAX_ROUTES_PER_SUBCLUSTER
                ]
                remaining_routes_sub = all_route_ids_in_subcluster[
                    MAX_ROUTES_PER_SUBCLUSTER:
                ]

                st.markdown(
                    f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
                )

                if "synthon_reaction" in subcluster_to_display:
                    synthon_reaction = subcluster_to_display["synthon_reaction"]
                    try:
                        synthon_reaction.clean2d()
                        st.image(
                            depict_custom_reaction(synthon_reaction),
                            caption="Markush-like pseudo reaction of subcluster",
                        )
                    except Exception as e_depict:
                        st.warning(f"Could not depict synthon reaction: {e_depict}")
                else:
                    st.info("No synthon reaction data for this subcluster.")
                with st.container(height=500):
                    for route_id in routes_to_display_direct:
                        _show_subcluster_route(tree, route_id)

                    if remaining_routes_sub:
                        with st.expander(
                            f"... and {len(remaining_routes_sub)} more routes in this subcluster"
                        ):
                            for route_id in remaining_routes_sub:
                                _show_subcluster_route(tree, route_id, expanded=True)
        else:
            st.info("Select a valid cluster and subcluster index to see details.")
def download_subclustering_results():
    """13. Subclustering Results Download: Providing functionality to download.

    Builds an HTML report for the currently selected cluster/subcluster pair
    (taken from the selectbox widget keys) and offers it via a download button.
    """
    ready = (
        st.session_state.get("subclustering_done", False)
        and "subcluster_num_select_key" in st.session_state
        and "subcluster_index_select_key" in st.session_state
    )
    if not ready:
        return

    sub = st.session_state.get("subclusters")
    tree = st.session_state.get("tree")
    # Used by routes_subclustering_report.
    sb_cgrs_for_report = st.session_state.get("sb_cgrs_dict")

    cluster_key = st.session_state.subcluster_num_select_key
    subcluster_key = st.session_state.subcluster_index_select_key

    if not tree or not sub or not sb_cgrs_for_report:
        st.warning(
            "Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
        )
        return

    if cluster_key not in sub or subcluster_key not in sub[cluster_key]:
        # Handled by the display logic; the download button simply does not appear.
        return

    raw_subcluster = sub[cluster_key][subcluster_key]
    # Apply the same post-processing as in display.
    report_data = post_process_subgroup(raw_subcluster)
    routes_data = raw_subcluster.get("routes_data")
    if isinstance(routes_data, dict):
        report_data["group_lgs"] = group_by_identical_values(routes_data)
    else:
        report_data["group_lgs"] = {}

    try:
        report_html = routes_subclustering_report(
            tree,
            report_data,  # the specific post-processed subcluster data
            cluster_key,
            subcluster_key,
            sb_cgrs_for_report,  # the whole sb_cgrs dict
            if_lg_group=True,
        )
        st.download_button(
            label=f"Download report for subcluster {cluster_key}.{subcluster_key}",
            data=report_html,
            file_name=f"subcluster_{cluster_key}.{subcluster_key}_{st.session_state.target_smiles}.html",
            mime="text/html",
            key=f"download_subcluster_{cluster_key}_{subcluster_key}",
        )
    except Exception as e:
        st.error(
            f"Error generating download report for subcluster {cluster_key}.{subcluster_key}: {e}"
        )
def implement_restart():
    """14. Restart: Implementing the logic to reset or restart the application state.

    Clears every planning/clustering/subclustering artifact and the widget
    keys from session state, restores the sketcher default, and reruns.
    """
    st.divider()
    st.header("Restart Application State")
    if st.button("Clear All Results & Restart", key="restart_button"):
        # pop(..., None) removes the key if present, silently otherwise.
        for key in (
            "planning_done",
            "tree",
            "res",
            "target_smiles",
            "clustering_done",
            "clusters",
            "reactions_dict",
            "num_clusters_setting",
            "route_cgrs_dict",
            "sb_cgrs_dict",
            "route_json",
            "subclustering_done",
            "subclusters",
            "clusters_downloaded",
            # Ketcher / widget keys that need a manual reset.
            "ketcher_widget",
            "smiles_text_input_key",
            "subcluster_num_select_key",
            "subcluster_index_select_key",
        ):
            st.session_state.pop(key, None)

        # Reset the sketcher to its default molecule and clear the stale target.
        st.session_state.ketcher = DEFAULT_MOL
        st.session_state.target_smiles = ""

        st.rerun()
# --- Main Application Flow ---
|
| 1209 |
+
def main():
    """Main Streamlit application flow.

    Orchestrates the pipeline stages in order: app initialization, sidebar,
    molecule input, retrosynthetic planning, clustering, subclustering, and
    the restart control. Stage gating is driven by boolean flags stored in
    ``st.session_state`` (``planning_done``, ``clustering_done``,
    ``subclustering_done``).
    """
    initialize_app()
    setup_sidebar()
    current_smile_code = handle_molecule_input()
    # Update session_state.ketcher if current_smile_code has changed from ketcher output
    if st.session_state.get("ketcher") != current_smile_code:
        st.session_state.ketcher = current_smile_code
        # No rerun here, let the flow continue. handle_molecule_input already warns.

    setup_planning_options()  # This function now also handles the button press and logic for planning

    # Display planning results and download options together
    if st.session_state.get("planning_done", False):
        display_planning_results()  # Displays stats and routes
        if st.session_state.res and st.session_state.res.get("solved", False):
            stat_col, download_col = st.columns(
                2, gap="medium"
            )  # Placeholder for download column
            with stat_col:
                st.subheader("Statistics")
                try:
                    res = st.session_state.res
                    # Backfill target_smiles into the stats dict if planning
                    # did not record it but the session has it.
                    if (
                        "target_smiles" not in res
                        and "target_smiles" in st.session_state
                    ):
                        res["target_smiles"] = st.session_state.target_smiles
                    cols_to_show = [
                        col
                        for col in [
                            "target_smiles",
                            "num_routes",
                            "num_nodes",
                            "num_iter",
                            "search_time",
                        ]
                        if col in res
                    ]
                    if cols_to_show:  # Ensure there are columns to show
                        df = pd.DataFrame(res, index=[0])[cols_to_show]
                        st.dataframe(df)
                    else:
                        st.write("No statistics to display from planning results.")
                except Exception as e:
                    st.error(f"Error displaying statistics: {e}")
                    st.write(res)  # Show raw dict if DataFrame fails
            with download_col:
                st.subheader("Planning Downloads")  # Adding a subheader for clarity
                download_planning_results()

    # Clustering section (setup button, display, download)
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        setup_clustering()  # Contains the "Run Clustering" button and logic
        if st.session_state.get("clustering_done", False):
            display_clustering_results()  # Displays cluster routes and stats
            cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")

            with cluster_stat_col:
                clusters = st.session_state.clusters
                cluster_sizes = [
                    cluster.get("group_size", 0)
                    for cluster in clusters.values()
                    if cluster
                ]  # Safe get
                st.subheader("Cluster Statistics")
                if cluster_sizes:
                    cluster_df = pd.DataFrame(
                        {
                            "Cluster": [
                                k for k, v in clusters.items() if v
                            ],  # Filter out empty clusters
                            "Number of Routes": [
                                v["group_size"] for v in clusters.values() if v
                            ],
                        }
                    )
                    if not cluster_df.empty:
                        # 1-based cluster numbering for display.
                        cluster_df.index += 1
                        st.dataframe(cluster_df)
                        best_route_html = html_top_routes_cluster(
                            clusters,
                            st.session_state.tree,
                            st.session_state.target_smiles,
                        )
                        st.download_button(
                            label=f"Download best route from each cluster",
                            data=best_route_html,
                            file_name=f"cluster_best_{st.session_state.target_smiles}.html",
                            mime="text/html",
                            key=f"download_cluster_best",
                        )
                    else:
                        st.write("No valid cluster data to display statistics for.")
                    # download_top_routes_cluster()
                else:
                    st.write("No cluster data to display statistics for.")
            with cluster_download_col:
                download_clustering_results()

    # Subclustering section (setup button, display, download)
    if st.session_state.get("clustering_done", False):  # Depends on clustering
        setup_subclustering()  # Contains "Run Subclustering" button
        if st.session_state.get("subclustering_done", False):
            display_subclustering_results()  # Displays subcluster details and routes
            download_subclustering_results()  # This needs to be called after selections are made in display.

    implement_restart()
|
| 1320 |
+
|
| 1321 |
+
|
| 1322 |
+
# Script entry point: launch the Streamlit application flow.
if __name__ == "__main__":
    main()
|
synplan/mcts/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MCTS subpackage: Monte-Carlo tree search for retrosynthetic route planning."""

from CGRtools.containers import MoleculeContainer

# Explicit imports instead of wildcard `import *`: the public API of this
# package is exactly what __all__ declares below.
from .node import Node
from .tree import Tree

# Disable atom-to-atom mapping numbers in molecule depictions package-wide.
MoleculeContainer.depict_settings(aam=False)

__all__ = ["Tree", "Node"]
|
synplan/mcts/evaluation.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing a class that represents a value function for prediction of
|
| 2 |
+
synthesisablity of new nodes in the tree search."""
|
| 3 |
+
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from synplan.chem.precursor import Precursor, compose_precursors
|
| 9 |
+
from synplan.ml.networks.value import ValueNetwork
|
| 10 |
+
from synplan.ml.training import mol_to_pyg
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ValueNetworkFunction:
    """Value function implemented as a value neural network for node evaluation
    (synthesisability prediction) in tree search."""

    def __init__(self, weights_path: str) -> None:
        """The value function predicts the probability to synthesize the target molecule
        with available building blocks starting from a given precursor.

        :param weights_path: The value network weights file path.
        """

        # Always load on CPU; the network is used for inference only.
        value_net = ValueNetwork.load_from_checkpoint(
            weights_path, map_location=torch.device("cpu")
        )
        self.value_network = value_net.eval()

    def predict_value(self, precursors: List[Precursor]) -> float:
        """Predicts a value based on the given precursors from the node. For prediction,
        precursors must be composed into a single molecule (product).

        :param precursors: The list of precursors.
        :return: The predicted float value ("synthesisability") of the node.
        """

        molecule = compose_precursors(precursors=precursors, exclude_small=True)
        pyg_graph = mol_to_pyg(molecule)
        if pyg_graph:
            with torch.no_grad():
                value_pred = self.value_network.forward(pyg_graph)[0].item()
        else:
            # Conversion to a graph failed: return a large negative sentinel so
            # the node is effectively never preferred during search.
            value_pred = -1e6

        return value_pred
|
synplan/mcts/expansion.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing a class that represents a policy function for node expansion in the
|
| 2 |
+
tree search."""
|
| 3 |
+
|
| 4 |
+
from typing import Iterator, List, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch_geometric
|
| 8 |
+
from CGRtools.reactor.reactor import Reactor
|
| 9 |
+
|
| 10 |
+
from synplan.chem.precursor import Precursor
|
| 11 |
+
from synplan.ml.networks.policy import PolicyNetwork
|
| 12 |
+
from synplan.ml.training import mol_to_pyg
|
| 13 |
+
from synplan.utils.config import PolicyNetworkConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PolicyNetworkFunction:
    """Policy function implemented as a policy neural network for node expansion in tree
    search."""

    def __init__(
        self, policy_config: PolicyNetworkConfig, compile: bool = False
    ) -> None:
        """Initializes the expansion function (ranking or filter policy network).

        :param policy_config: An expansion policy configuration.
        :param compile: Is supposed to speed up the training with model compilation.
            NOTE(review): the name shadows the ``compile`` builtin; kept for
            backward compatibility with existing keyword callers.
        """

        self.config = policy_config

        # Inference-only load on CPU with a single-sample batch and no dropout.
        policy_net = PolicyNetwork.load_from_checkpoint(
            self.config.weights_path,
            map_location=torch.device("cpu"),
            batch_size=1,
            dropout=0,
        )

        policy_net = policy_net.eval()
        if compile:
            self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
        else:
            self.policy_net = policy_net

    def predict_reaction_rules(
        self, precursor: Precursor, reaction_rules: List[Reactor]
    ) -> Iterator[Union[Iterator, Iterator[Tuple[float, Reactor, int]]]]:
        """The policy function predicts the list of reaction rules for a given precursor.

        :param precursor: The current precursor for which the reaction rules are predicted.
        :param reaction_rules: The list of reaction rules from which applicable reaction
            rules are predicted and selected.
        :return: Yielding the predicted probability for the reaction rule, reaction rule
            and reaction rule id.
        :raises ValueError: If the network output size does not match the number of
            reaction rules, or if the policy type is neither "filtering" nor "ranking".
        """

        out_dim = list(self.policy_net.modules())[-1].out_features
        if out_dim != len(reaction_rules):
            raise ValueError(
                f"The policy network output dimensionality is {out_dim}, but the number of reaction rules is {len(reaction_rules)}. "
                "Probably you use a different version of the policy network. Be sure to retain the policy network "
                "with the current set of reaction rules"
            )

        pyg_graph = mol_to_pyg(precursor.molecule, canonicalize=False)
        if not pyg_graph:
            # Molecule could not be converted to a graph: nothing to yield.
            return

        priority = None
        with torch.no_grad():
            if self.policy_net.policy_type == "filtering":
                probs, priority = self.policy_net.forward(pyg_graph)
            elif self.policy_net.policy_type == "ranking":
                probs = self.policy_net.forward(pyg_graph)
            else:
                # Previously an unknown policy type fell through to a NameError
                # on `probs`; fail explicitly instead.
                raise ValueError(
                    f"Unknown policy type: {self.policy_net.policy_type}"
                )
        del pyg_graph

        probs = probs[0].double()
        if self.policy_net.policy_type == "filtering":
            # Blend filtering probabilities with the priority head.
            priority = priority[0].double()
            priority_coef = self.config.priority_rules_fraction
            probs = (1 - priority_coef) * probs + priority_coef * priority

        sorted_probs, sorted_rules = torch.sort(probs, descending=True)
        sorted_probs, sorted_rules = (
            sorted_probs[: self.config.top_rules],
            sorted_rules[: self.config.top_rules],
        )

        if self.policy_net.policy_type == "filtering":
            # Renormalize the truncated top-k scores into a distribution.
            sorted_probs = torch.softmax(sorted_probs, -1)

        sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist()

        for prob, rule_id in zip(sorted_probs, sorted_rules):
            if (
                prob > self.config.rule_prob_threshold
            ):  # search may fail if rule_prob_threshold is too low (recommended value is 0.0)
                yield prob, reaction_rules[rule_id], rule_id
synplan/mcts/node.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing a class Node in the tree search."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Node:
    """Node class represents a node in the tree search."""

    def __init__(
        self, precursors_to_expand: tuple = None, new_precursors: tuple = None
    ) -> None:
        """The function initializes the new Node object.

        :param precursors_to_expand: The tuple of precursors to be expanded. The first
            precursor in the tuple is the current precursor which will be expanded (for
            which new precursors will be generated by applying the predicted reaction
            rules). When the first precursor has been successfully expanded, the second
            precursor becomes the current precursor to be expanded. ``None`` is treated
            as an empty tuple.
        :param new_precursors: The tuple of new precursors generated by applying the
            reaction rule. ``None`` is treated as an empty tuple.
        """

        # Coerce None defaults to empty tuples: the original code crashed with
        # TypeError on len(None) when arguments were omitted.
        self.precursors_to_expand = (
            precursors_to_expand if precursors_to_expand is not None else tuple()
        )
        self.new_precursors = new_precursors if new_precursors is not None else tuple()

        if self.precursors_to_expand:
            self.curr_precursor = self.precursors_to_expand[0]
            self.next_precursor = self.precursors_to_expand[1:]
        else:
            self.curr_precursor = tuple()
            # Always define next_precursor: previously it was missing on empty
            # (solved) nodes, a latent AttributeError.
            self.next_precursor = tuple()

    def __len__(self) -> int:
        """Returns the number of precursor in the node to expand."""
        return len(self.precursors_to_expand)

    def __repr__(self) -> str:
        """Returns the SMILES of each precursor in precursor_to_expand and new_precursor."""
        return (
            f"New precursors: {self.new_precursors}\n"
            f"Precursors to expand: {self.precursors_to_expand}\n"
        )

    def is_solved(self) -> bool:
        """If True, it is a terminal node.

        There are no precursors for expansion.
        """

        return len(self.precursors_to_expand) == 0
|
synplan/mcts/search.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing functions for running tree search for the set of target
|
| 2 |
+
molecules."""
|
| 3 |
+
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os.path
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
from CGRtools.containers import MoleculeContainer
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from synplan.chem.reaction_routes.route_cgr import extract_reactions
|
| 15 |
+
from synplan.chem.reaction_routes.io import write_routes_csv, write_routes_json
|
| 16 |
+
from synplan.chem.utils import mol_from_smiles
|
| 17 |
+
from synplan.mcts.evaluation import ValueNetworkFunction
|
| 18 |
+
from synplan.mcts.expansion import PolicyNetworkFunction
|
| 19 |
+
from synplan.mcts.tree import Tree, TreeConfig
|
| 20 |
+
from synplan.utils.config import PolicyNetworkConfig
|
| 21 |
+
from synplan.utils.loading import load_building_blocks, load_reaction_rules
|
| 22 |
+
from synplan.utils.visualisation import extract_routes, generate_results_html
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def extract_tree_stats(
    tree: "Tree", target: "Union[str, MoleculeContainer]", init_smiles: str = None
) -> dict:
    """Collects various statistics from a tree and returns them in a dictionary format.

    :param tree: The built search tree.
    :param target: The target molecule associated with the tree.
    :param init_smiles: initial SMILES of the molecule, optional. When given, it is
        reported instead of ``str(target)``.
    :return: A dictionary with the calculated statistics.
    """

    newick_tree, newick_meta = tree.newickify(visits_threshold=0)
    # Flatten per-node metadata into a single "nid,a,b,c" record list joined by ";".
    newick_meta_line = ";".join(
        f"{nid},{v[0]},{v[1]},{v[2]}" for nid, v in newick_meta.items()
    )

    return {
        "target_smiles": init_smiles if init_smiles is not None else str(target),
        "num_routes": len(tree.winning_nodes),
        "num_nodes": len(tree),
        "num_iter": tree.curr_iteration,
        "tree_depth": max(tree.nodes_depth.values()),
        "search_time": round(tree.curr_time, 1),
        "newick_tree": newick_tree,
        "newick_meta": newick_meta_line,
        # Idiomatic boolean instead of `True if ... else False`.
        "solved": len(tree.winning_nodes) > 0,
    }
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def run_search(
    targets_path: str,
    search_config: dict,
    policy_config: PolicyNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    value_network_path: str = None,
    results_root: str = "search_results",
) -> None:
    """Performs a tree search on a set of target molecules using specified configuration
    and reaction rules, logging the results and statistics.

    :param targets_path: The path to the file containing the target molecules (in SDF or
        SMILES format).
    :param search_config: The config object containing the configuration for the tree
        search.
    :param policy_config: The config object containing the configuration for the policy.
    :param reaction_rules_path: The path to the file containing reaction rules.
    :param building_blocks_path: The path to the file containing building blocks.
    :param value_network_path: The path to the file containing value weights (optional).
    :param results_root: The name of the folder where the results of the tree search
        will be saved.
    :return: None.
    """

    # results folder
    results_root = Path(results_root)
    if not results_root.exists():
        results_root.mkdir()

    # output files
    stats_file = results_root.joinpath("tree_search_stats.csv")
    routes_file = results_root.joinpath("extracted_routes.json")
    routes_folder = results_root.joinpath("extracted_routes_html")
    routes_folder.mkdir(exist_ok=True)

    # stats header
    stats_header = [
        "target_smiles",
        "num_routes",
        "num_nodes",
        "num_iter",
        "tree_depth",
        "search_time",
        "newick_tree",
        "newick_meta",
        "solved",
        "error",
    ]

    # config: value network is only used for the "gcn" evaluation type;
    # otherwise rollout evaluation is used (value_function stays None).
    policy_function = PolicyNetworkFunction(policy_config=policy_config)
    if search_config["evaluation_type"] == "gcn" and value_network_path:
        value_function = ValueNetworkFunction(weights_path=value_network_path)
    else:
        value_function = None

    reaction_rules = load_reaction_rules(reaction_rules_path)
    building_blocks = load_building_blocks(building_blocks_path, standardize=True)

    # run search
    n_solved = 0
    extracted_routes = []

    tree_config = TreeConfig.from_dict(search_config)
    tree_config.silent = True
    with (
        open(targets_path, "r", encoding="utf-8") as targets,
        open(stats_file, "w", encoding="utf-8", newline="\n") as csvfile,
    ):

        statswriter = csv.DictWriter(csvfile, delimiter=",", fieldnames=stats_header)
        statswriter.writeheader()

        # One target SMILES per input line.
        for ti, target_smi in tqdm(
            enumerate(targets),
            leave=True,
            desc="Number of target molecules processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):
            target_smi = target_smi.strip()
            target_mol = mol_from_smiles(target_smi)
            try:
                # run search; iterating the tree drives the MCTS iterations.
                tree = Tree(
                    target=target_mol,
                    config=tree_config,
                    reaction_rules=reaction_rules,
                    building_blocks=building_blocks,
                    expansion_function=policy_function,
                    evaluation_function=value_function,
                )

                _ = list(tree)

            except Exception as e:
                # On failure, record a placeholder unsolved route so output
                # indices stay aligned with the input targets, then move on.
                extracted_routes.append(
                    [
                        {
                            "type": "mol",
                            "smiles": target_smi,
                            "in_stock": False,
                            "children": [],
                        }
                    ]
                )
                logging.warning(
                    f"Retrosynthetic_planning {target_smi} failed with the following error: {e}"
                )

                continue

            # is solved
            n_solved += bool(tree.winning_nodes)
            if bool(tree.winning_nodes):

                # extract routes
                extracted_routes.append(extract_routes(tree))

                # save routes
                generate_results_html(
                    tree,
                    os.path.join(routes_folder, f"retroroutes_target_{ti}.html"),
                    extended=True,
                )

            # save stats; flush after each target so partial results survive crashes
            statswriter.writerow(extract_tree_stats(tree, target_smi))
            csvfile.flush()

            # save json routes (rewritten in full after every target)
            with open(routes_file, "w", encoding="utf-8") as f:
                json.dump(extracted_routes, f)

            # Save mapped reactions (CSV)
            routes_dict = extract_reactions(tree)
            write_routes_csv(
                routes_dict, os.path.join(routes_folder, f"mapped_routes_{ti}.csv")
            )

            # save mapped reactions (JSON)
            write_routes_json(
                routes_dict, os.path.join(routes_folder, f"mapped_routes_{ti}.json")
            )

    print(f"Number of solved target molecules: {n_solved}")
|
synplan/mcts/tree.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing a class Tree that used for tree search of retrosynthetic routes."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import warnings
|
| 5 |
+
from collections import defaultdict, deque
|
| 6 |
+
from math import sqrt
|
| 7 |
+
from random import choice, uniform
|
| 8 |
+
from time import time
|
| 9 |
+
from typing import Dict, List, Set, Tuple
|
| 10 |
+
|
| 11 |
+
from CGRtools.reactor import Reactor
|
| 12 |
+
from CGRtools.containers import MoleculeContainer
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
|
| 15 |
+
from synplan.chem.precursor import Precursor
|
| 16 |
+
from synplan.chem.reaction import Reaction, apply_reaction_rule
|
| 17 |
+
from synplan.mcts.evaluation import ValueNetworkFunction
|
| 18 |
+
from synplan.mcts.expansion import PolicyNetworkFunction
|
| 19 |
+
from synplan.mcts.node import Node
|
| 20 |
+
from synplan.utils.config import TreeConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Tree:
|
| 24 |
+
"""Tree class with attributes and methods for Monte-Carlo tree search."""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
target: MoleculeContainer,
|
| 29 |
+
config: TreeConfig,
|
| 30 |
+
reaction_rules: List[Reactor],
|
| 31 |
+
building_blocks: Set[str],
|
| 32 |
+
expansion_function: PolicyNetworkFunction,
|
| 33 |
+
evaluation_function: ValueNetworkFunction = None,
|
| 34 |
+
):
|
| 35 |
+
"""Initializes a tree object with optional parameters for tree search for target
|
| 36 |
+
molecule.
|
| 37 |
+
|
| 38 |
+
:param target: A target molecule for retrosynthetic routes search.
|
| 39 |
+
:param config: A tree configuration.
|
| 40 |
+
:param reaction_rules: A loaded reaction rules.
|
| 41 |
+
:param building_blocks: A loaded building blocks.
|
| 42 |
+
:param expansion_function: A loaded policy function.
|
| 43 |
+
:param evaluation_function: A loaded value function. If None, the rollout is
|
| 44 |
+
used as a default for node evaluation.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
# config parameters
|
| 48 |
+
self.config = config
|
| 49 |
+
|
| 50 |
+
assert isinstance(
|
| 51 |
+
target, MoleculeContainer
|
| 52 |
+
), "Target should be given as MoleculeContainer"
|
| 53 |
+
assert len(target) > 3, "Target molecule has less than 3 atoms"
|
| 54 |
+
|
| 55 |
+
target_molecule = Precursor(target)
|
| 56 |
+
target_molecule.prev_precursors.append(Precursor(target))
|
| 57 |
+
target_node = Node(
|
| 58 |
+
precursors_to_expand=(target_molecule,), new_precursors=(target_molecule,)
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# tree structure init
|
| 62 |
+
self.nodes: Dict[int, Node] = {1: target_node}
|
| 63 |
+
self.parents: Dict[int, int] = {1: 0}
|
| 64 |
+
self.children: Dict[int, Set[int]] = {1: set()}
|
| 65 |
+
self.winning_nodes: List[int] = []
|
| 66 |
+
self.visited_nodes: Set[int] = set()
|
| 67 |
+
self.expanded_nodes: Set[int] = set()
|
| 68 |
+
self.nodes_visit: Dict[int, int] = {1: 0}
|
| 69 |
+
self.nodes_depth: Dict[int, int] = {1: 0}
|
| 70 |
+
self.nodes_prob: Dict[int, float] = {1: 0.0}
|
| 71 |
+
self.nodes_rules: Dict[int, float] = {}
|
| 72 |
+
self.nodes_init_value: Dict[int, float] = {1: 0.0}
|
| 73 |
+
self.nodes_total_value: Dict[int, float] = {1: 0.0}
|
| 74 |
+
|
| 75 |
+
# tree building limits
|
| 76 |
+
self.curr_iteration: int = 0
|
| 77 |
+
self.curr_tree_size: int = 2
|
| 78 |
+
self.start_time: float = 0
|
| 79 |
+
self.curr_time: float = 0
|
| 80 |
+
|
| 81 |
+
# building blocks and reaction reaction_rules
|
| 82 |
+
self.reaction_rules = reaction_rules
|
| 83 |
+
self.building_blocks = building_blocks
|
| 84 |
+
|
| 85 |
+
# policy and value functions
|
| 86 |
+
self.policy_network = expansion_function
|
| 87 |
+
if self.config.evaluation_type == "gcn":
|
| 88 |
+
if evaluation_function is None:
|
| 89 |
+
raise ValueError(
|
| 90 |
+
"Value function not specified while evaluation type is 'gcn'"
|
| 91 |
+
)
|
| 92 |
+
if (
|
| 93 |
+
evaluation_function is not None
|
| 94 |
+
and self.config.evaluation_type == "rollout"
|
| 95 |
+
):
|
| 96 |
+
raise ValueError(
|
| 97 |
+
"Value function is not None while evaluation type is 'rollout'. What should be evaluation type ?"
|
| 98 |
+
)
|
| 99 |
+
self.value_network = evaluation_function
|
| 100 |
+
|
| 101 |
+
# utils
|
| 102 |
+
self._tqdm = True # needed to disable tqdm with multiprocessing module
|
| 103 |
+
|
| 104 |
+
target_smiles = str(self.nodes[1].curr_precursor.molecule)
|
| 105 |
+
if target_smiles in self.building_blocks:
|
| 106 |
+
self.building_blocks.remove(target_smiles)
|
| 107 |
+
print(
|
| 108 |
+
"Target was found in building blocks and removed from building blocks."
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def __len__(self) -> int:
|
| 112 |
+
"""Returns the current size (the number of nodes) in the tree."""
|
| 113 |
+
|
| 114 |
+
return self.curr_tree_size - 1
|
| 115 |
+
|
| 116 |
+
def __iter__(self) -> "Tree":
|
| 117 |
+
"""The function is defining an iterator for a Tree object.
|
| 118 |
+
|
| 119 |
+
Also needed for the bar progress display.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
self.start_time = time()
|
| 123 |
+
if self._tqdm:
|
| 124 |
+
self._tqdm = tqdm(
|
| 125 |
+
total=self.config.max_iterations, disable=self.config.silent
|
| 126 |
+
)
|
| 127 |
+
return self
|
| 128 |
+
|
| 129 |
+
def __repr__(self) -> str:
|
| 130 |
+
"""Returns a string representation of the tree (target SMILES, tree size, and
|
| 131 |
+
the number of found routes)."""
|
| 132 |
+
return self.report()
|
| 133 |
+
|
| 134 |
+
def __next__(self) -> Tuple[bool, List[int]]:
    """Performs one iteration (selection / expansion / backpropagation) of tree building.

    :return: ``(True, node_ids)`` when one or more routes were found this
        iteration (the ids of the solved nodes), otherwise ``(False, [node_id])``
        with the id of the last visited node.
    :raises StopIteration: When the iteration, tree-size, or time limit is exceeded.
    """
    # hard limits on the search; tqdm iteration stops on StopIteration
    if self.curr_iteration >= self.config.max_iterations:
        raise StopIteration("Iterations limit exceeded.")
    if self.curr_tree_size >= self.config.max_tree_size:
        raise StopIteration("Max tree size exceeded or all possible routes found.")
    if self.curr_time >= self.config.max_time:
        raise StopIteration("Time limit exceeded.")

    # start new iteration
    self.curr_iteration += 1
    self.curr_time = time() - self.start_time

    if self._tqdm:
        self._tqdm.update()

    curr_depth, node_id = 0, 1  # start from the root node_id

    explore_route = True
    while explore_route:
        self.visited_nodes.add(node_id)

        if self.nodes_visit[node_id]:  # already visited
            if not self.children[node_id]:  # dead node: expansion produced nothing
                self._update_visits(node_id)
                explore_route = False
            else:
                node_id = self._select_node(node_id)  # select the child node by UCB
                curr_depth += 1
        else:
            if self.nodes[node_id].is_solved():  # found route
                self._update_visits(
                    node_id
                )  # this prevents expanding of bb node_id
                self.winning_nodes.append(node_id)
                return True, [node_id]

            if (
                curr_depth < self.config.max_depth
            ):  # expand node if depth limit is not reached
                self._expand_node(node_id)
                if not self.children[node_id]:  # node was not expanded
                    # penalize the whole branch so it is deprioritized
                    value_to_backprop = -1.0
                else:
                    self.expanded_nodes.add(node_id)

                    if self.config.search_strategy == "evaluation_first":
                        # recalculate node value based on children synthesisability
                        # (children were already evaluated in _add_node)
                        child_values = [
                            self.nodes_init_value[child_id]
                            for child_id in self.children[node_id]
                        ]

                        if self.config.evaluation_agg == "max":
                            value_to_backprop = max(child_values)

                        elif self.config.evaluation_agg == "average":
                            value_to_backprop = sum(child_values) / len(
                                self.children[node_id]
                            )

                    elif self.config.search_strategy == "expansion_first":
                        # evaluate the expanded node itself, lazily
                        value_to_backprop = self._get_node_value(node_id)

                # backpropagation
                self._backpropagate(node_id, value_to_backprop)
                self._update_visits(node_id)
                explore_route = False

                if self.children[node_id]:
                    # routes may be completed immediately after expansion
                    found_after_expansion = set()
                    for child_id in iter(self.children[node_id]):
                        if self.nodes[child_id].is_solved():
                            found_after_expansion.add(child_id)
                            self.winning_nodes.append(child_id)

                    if found_after_expansion:
                        return True, list(found_after_expansion)

            else:
                # depth limit reached: re-backpropagate the current estimate
                self._backpropagate(node_id, self.nodes_total_value[node_id])
                self._update_visits(node_id)
                explore_route = False

    return False, [node_id]
|
| 225 |
+
|
| 226 |
+
def _ucb(self, node_id: int) -> float:
|
| 227 |
+
"""Calculates the Upper Confidence Bound (UCB) statistics for a given node.
|
| 228 |
+
|
| 229 |
+
:param node_id: The id of the node.
|
| 230 |
+
:return: The calculated UCB.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
prob = self.nodes_prob[node_id] # predicted by policy network score
|
| 234 |
+
visit = self.nodes_visit[node_id]
|
| 235 |
+
|
| 236 |
+
if self.config.ucb_type == "puct":
|
| 237 |
+
u = (
|
| 238 |
+
self.config.c_ucb * prob * sqrt(self.nodes_visit[self.parents[node_id]])
|
| 239 |
+
) / (visit + 1)
|
| 240 |
+
ucb_value = self.nodes_total_value[node_id] + u
|
| 241 |
+
|
| 242 |
+
if self.config.ucb_type == "uct":
|
| 243 |
+
u = (
|
| 244 |
+
self.config.c_ucb
|
| 245 |
+
* sqrt(self.nodes_visit[self.parents[node_id]])
|
| 246 |
+
/ (visit + 1)
|
| 247 |
+
)
|
| 248 |
+
ucb_value = self.nodes_total_value[node_id] + u
|
| 249 |
+
|
| 250 |
+
if self.config.ucb_type == "value":
|
| 251 |
+
ucb_value = self.nodes_init_value[node_id] / (visit + 1)
|
| 252 |
+
|
| 253 |
+
return ucb_value
|
| 254 |
+
|
| 255 |
+
def _select_node(self, node_id: int) -> int:
|
| 256 |
+
"""Selects a node based on its UCB value and returns the id of the node with the
|
| 257 |
+
highest UCB.
|
| 258 |
+
|
| 259 |
+
:param node_id: The id of the node.
|
| 260 |
+
:return: The id of the node with the highest UCB.
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
if self.config.epsilon > 0:
|
| 264 |
+
n = uniform(0, 1)
|
| 265 |
+
if n < self.config.epsilon:
|
| 266 |
+
return choice(list(self.children[node_id]))
|
| 267 |
+
|
| 268 |
+
best_score, best_children = None, []
|
| 269 |
+
for child_id in self.children[node_id]:
|
| 270 |
+
score = self._ucb(child_id)
|
| 271 |
+
if best_score is None or score > best_score:
|
| 272 |
+
best_score, best_children = score, [child_id]
|
| 273 |
+
elif score == best_score:
|
| 274 |
+
best_children.append(child_id)
|
| 275 |
+
|
| 276 |
+
# is needed for tree search reproducibility, when all child nodes has the same score
|
| 277 |
+
return best_children[0]
|
| 278 |
+
|
| 279 |
+
def _expand_node(self, node_id: int) -> None:
    """Expands the node by generating new precursors with the policy (expansion) function.

    For every reaction rule predicted by the policy network, applies the rule
    to the node's current precursor molecule and, for each product set that is
    new (not seen during this expansion and not among the node's previous
    precursors), adds a child node to the tree.

    :param node_id: The id of the node to be expanded.
    :return: None.
    :raises StopIteration: If the root (target) node could not be expanded at all.
    """
    curr_node = self.nodes[node_id]
    prev_precursor = curr_node.curr_precursor.prev_precursors

    tmp_precursor = set()  # products already produced during this expansion
    expanded = False
    for prob, rule, rule_id in self.policy_network.predict_reaction_rules(
        curr_node.curr_precursor, self.reaction_rules
    ):
        for products in apply_reaction_rule(
            curr_node.curr_precursor.molecule, rule
        ):
            # skip empty and fully repeated product sets
            if not products or not set(products) - tmp_precursor:
                continue
            tmp_precursor.update(products)

            # tag each product molecule with the rule that produced it
            for molecule in products:
                molecule.meta["reactor_id"] = rule_id

            new_precursor = tuple(Precursor(mol) for mol in products)
            # scale the policy probability by the number of non-trivial
            # (larger than min_mol_size) products
            scaled_prob = prob * len(
                list(filter(lambda x: len(x) > self.config.min_mol_size, products))
            )

            # reject products that merely regenerate an earlier precursor (loop)
            if set(prev_precursor).isdisjoint(new_precursor):
                # carry over the parent's pending precursors and add the new
                # non-building-block ones
                precursors_to_expand = (
                    *curr_node.next_precursor,
                    *(
                        x
                        for x in new_precursor
                        if not x.is_building_block(
                            self.building_blocks, self.config.min_mol_size
                        )
                    ),
                )

                child_node = Node(
                    precursors_to_expand=precursors_to_expand,
                    new_precursors=new_precursor,
                )

                # NOTE(review): the loop variable deliberately shadows the
                # tuple `new_precursor`; each precursor records itself plus
                # the history of its ancestors.
                for new_precursor in new_precursor:
                    new_precursor.prev_precursors = [new_precursor, *prev_precursor]

                self._add_node(node_id, child_node, scaled_prob, rule_id)

                expanded = True
    # a root that cannot be expanded means the whole search is hopeless
    if not expanded and node_id == 1:
        raise StopIteration("\nThe target molecule was not expanded.")
|
| 334 |
+
|
| 335 |
+
def _add_node(
|
| 336 |
+
self,
|
| 337 |
+
node_id: int,
|
| 338 |
+
new_node: Node,
|
| 339 |
+
policy_prob: float = None,
|
| 340 |
+
rule_id: int = None,
|
| 341 |
+
) -> None:
|
| 342 |
+
"""Adds a new node to the tree with probability of reaction rules predicted by
|
| 343 |
+
policy function and applied to the parent node of the new node.
|
| 344 |
+
|
| 345 |
+
:param node_id: The id of the parent node.
|
| 346 |
+
:param new_node: The new node to be added.
|
| 347 |
+
:param policy_prob: The probability of reaction rules predicted by policy
|
| 348 |
+
function for thr parent node.
|
| 349 |
+
:return: None.
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
new_node_id = self.curr_tree_size
|
| 353 |
+
|
| 354 |
+
self.nodes[new_node_id] = new_node
|
| 355 |
+
self.parents[new_node_id] = node_id
|
| 356 |
+
self.children[node_id].add(new_node_id)
|
| 357 |
+
self.children[new_node_id] = set()
|
| 358 |
+
self.nodes_visit[new_node_id] = 0
|
| 359 |
+
self.nodes_prob[new_node_id] = policy_prob
|
| 360 |
+
self.nodes_rules[new_node_id] = rule_id
|
| 361 |
+
self.nodes_depth[new_node_id] = self.nodes_depth[node_id] + 1
|
| 362 |
+
self.curr_tree_size += 1
|
| 363 |
+
|
| 364 |
+
if self.config.search_strategy == "evaluation_first":
|
| 365 |
+
node_value = self._get_node_value(new_node_id)
|
| 366 |
+
elif self.config.search_strategy == "expansion_first":
|
| 367 |
+
node_value = self.config.init_node_value
|
| 368 |
+
|
| 369 |
+
self.nodes_init_value[new_node_id] = node_value
|
| 370 |
+
self.nodes_total_value[new_node_id] = node_value
|
| 371 |
+
|
| 372 |
+
def _get_node_value(self, node_id: int) -> float:
|
| 373 |
+
"""Calculates the value for the given node (for example with rollout or value
|
| 374 |
+
network).
|
| 375 |
+
|
| 376 |
+
:param node_id: The id of the node to be evaluated.
|
| 377 |
+
:return: The estimated value of the node.
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
node = self.nodes[node_id]
|
| 381 |
+
|
| 382 |
+
if self.config.evaluation_type == "random":
|
| 383 |
+
node_value = uniform(0, 1)
|
| 384 |
+
|
| 385 |
+
elif self.config.evaluation_type == "rollout":
|
| 386 |
+
node_value = min(
|
| 387 |
+
(
|
| 388 |
+
self._rollout_node(
|
| 389 |
+
precursor, current_depth=self.nodes_depth[node_id]
|
| 390 |
+
)
|
| 391 |
+
for precursor in node.precursors_to_expand
|
| 392 |
+
),
|
| 393 |
+
default=1.0,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
elif self.config.evaluation_type == "gcn":
|
| 397 |
+
node_value = self.value_network.predict_value(node.new_precursors)
|
| 398 |
+
|
| 399 |
+
return node_value
|
| 400 |
+
|
| 401 |
+
def _update_visits(self, node_id: int) -> None:
|
| 402 |
+
"""Updates the number of visits from the current node to the root node.
|
| 403 |
+
|
| 404 |
+
:param node_id: The id of the current node.
|
| 405 |
+
:return: None.
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
while node_id:
|
| 409 |
+
self.nodes_visit[node_id] += 1
|
| 410 |
+
node_id = self.parents[node_id]
|
| 411 |
+
|
| 412 |
+
def _backpropagate(self, node_id: int, value: float) -> None:
|
| 413 |
+
"""Backpropagates the value through the tree from the current.
|
| 414 |
+
|
| 415 |
+
:param node_id: The id of the node from which to backpropagate the value.
|
| 416 |
+
:param value: The value to backpropagate.
|
| 417 |
+
:return: None.
|
| 418 |
+
"""
|
| 419 |
+
while node_id:
|
| 420 |
+
if self.config.backprop_type == "muzero":
|
| 421 |
+
self.nodes_total_value[node_id] = (
|
| 422 |
+
self.nodes_total_value[node_id] * self.nodes_visit[node_id] + value
|
| 423 |
+
) / (self.nodes_visit[node_id] + 1)
|
| 424 |
+
elif self.config.backprop_type == "cumulative":
|
| 425 |
+
self.nodes_total_value[node_id] += value
|
| 426 |
+
node_id = self.parents[node_id]
|
| 427 |
+
|
| 428 |
+
def _rollout_node(self, precursor: Precursor, current_depth: int = None) -> float:
    """Performs a rollout simulation from a given node in the tree.

    Given the current precursor, repeatedly applies the first successful
    reaction predicted by the policy network and tries to decompose every
    non-building-block product, until the queue is exhausted or a limit hits.

    Reward scheme:

    * ``1.0``  -- the precursor is a building block, or every product could be
      decomposed into building blocks within the depth budget.
    * ``-0.5`` -- decomposition did not finish within the remaining depth.
    * ``-1.0`` -- no reaction rule applied, or a reaction re-created an
      already-seen precursor (a loop).

    :param precursor: The precursor to be evaluated.
    :param current_depth: The current depth of the tree (used to compute the
        remaining depth budget for the rollout).
    :return: The reward (value) assigned to the precursor.
    """
    max_depth = self.config.max_depth - current_depth

    # trivial case: the precursor is already purchasable
    if precursor.is_building_block(self.building_blocks, self.config.min_mol_size):
        return 1.0

    if max_depth == 0:
        print("max depth reached in the beginning")

    # precursor simulating
    occurred_precursor = set()          # everything seen so far (loop detection)
    precursor_to_expand = deque([precursor])
    history = defaultdict(dict)         # rollout_depth -> {"target", "rule_index", "products"}
    rollout_depth = 0
    while precursor_to_expand:
        # Iterate through reactors and pick first successful reaction.
        # Check products of the reaction if you can find them in in-building_blocks data
        # If not, then add missed products to precursor_to_expand and try to decompose them
        if len(history) >= max_depth:
            reward = -0.5
            return reward

        current_precursor = precursor_to_expand.popleft()
        history[rollout_depth]["target"] = current_precursor
        occurred_precursor.add(current_precursor)

        # Pick the first successful reaction while iterating through reactors
        reaction_rule_applied = False
        for prob, rule, rule_id in self.policy_network.predict_reaction_rules(
            current_precursor, self.reaction_rules
        ):
            for products in apply_reaction_rule(current_precursor.molecule, rule):
                if products:
                    reaction_rule_applied = True
                    break

            if reaction_rule_applied:
                history[rollout_depth]["rule_index"] = rule_id
                break

        if not reaction_rule_applied:
            reward = -1.0
            return reward

        # `products` holds the product set from the successful rule application
        products = tuple(Precursor(product) for product in products)
        history[rollout_depth]["products"] = products

        # check loops: a rule may regenerate a precursor seen earlier
        if any(x in occurred_precursor for x in products) and products:
            # sometimes manual can create a loop, when
            # print('occurred_precursor')
            reward = -1.0
            return reward

        if occurred_precursor.isdisjoint(products):
            # queue only products that are not yet purchasable
            precursor_to_expand.extend(
                [
                    x
                    for x in products
                    if not x.is_building_block(
                        self.building_blocks, self.config.min_mol_size
                    )
                ]
            )
            rollout_depth += 1

    # queue exhausted: everything decomposed into building blocks
    reward = 1.0
    return reward
|
| 519 |
+
|
| 520 |
+
def report(self) -> str:
    """Builds a human-readable summary of the tree state.

    Includes the target SMILES, elapsed time, tree size, iteration count,
    visited-node count, and the number of found routes.
    """
    target = str(self.nodes[1].precursors_to_expand[0])
    summary_lines = [
        f"Tree for: {target}",
        f"Time: {round(self.curr_time, 1)} seconds",
        f"Number of nodes: {len(self)}",
        f"Number of iterations: {self.curr_iteration}",
        f"Number of visited nodes: {len(self.visited_nodes)}",
        f"Number of found routes: {len(self.winning_nodes)}",
    ]
    return "\n".join(summary_lines)
|
| 531 |
+
|
| 532 |
+
def route_score(self, node_id: int) -> float:
    """Calculates the score of the route from the given node to the root node.

    The score depends on the cumulated node values and the route length:
    the sum of total values along the route divided by length squared, so
    shorter routes are preferred.

    :param node_id: The id of the current given node.
    :return: The route score.
    """
    total_value = 0.0
    length = 0
    current = node_id
    while current:
        length += 1
        total_value += self.nodes_total_value[current]
        current = self.parents[current]

    return total_value / (length**2)
|
| 548 |
+
|
| 549 |
+
def route_to_node(self, node_id: int) -> List[Node,]:
|
| 550 |
+
"""Returns the route (list of id of nodes) to from the node current node to the
|
| 551 |
+
root node.
|
| 552 |
+
|
| 553 |
+
:param node_id: The id of the current node.
|
| 554 |
+
:return: The list of nodes.
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
nodes = []
|
| 558 |
+
while node_id:
|
| 559 |
+
nodes.append(node_id)
|
| 560 |
+
node_id = self.parents[node_id]
|
| 561 |
+
return [self.nodes[node_id] for node_id in reversed(nodes)]
|
| 562 |
+
|
| 563 |
+
def synthesis_route(self, node_id: int) -> Tuple[Reaction,]:
    """Given a node id, returns the reactions forming the retrosynthetic route.

    Each step pairs a parent node's current precursor (product) with the
    child node's newly generated precursors (reactants); the sequence is
    returned in forward-synthesis order (reversed tree order).

    :param node_id: The id of the current node.
    :return: The tuple of extracted reactions representing the synthesis route.
    """
    route_nodes = self.route_to_node(node_id)

    reactions = []
    for parent_node, child_node in zip(route_nodes, route_nodes[1:]):
        reactions.append(
            Reaction(
                [p.molecule for p in child_node.new_precursors],
                [parent_node.curr_precursor.molecule],
            )
        )

    for reaction in reactions:
        reaction.clean2d()  # recompute 2D coordinates for depiction
    return tuple(reversed(reactions))
|
| 584 |
+
|
| 585 |
+
def newickify(self, visits_threshold: int = 0, root_node_id: int = 1):
    """Serializes the search tree to a Newick string plus per-node metadata.

    Adopted from https://stackoverflow.com/questions/50003007/how-to-convert-python-dictionary-to-newick-form-format.

    :param visits_threshold: The minimum number of visits for a child node to
        be included in the output.
    :param root_node_id: The id of the node to use as the Newick root.

    :return: The newick string and a meta dict mapping node id to
        ``(total_value, init_value_rounded, visit_count)``.
    """
    visited_nodes = set()

    def newick_render_node(current_node_id: int) -> str:
        """Recursively generates a Newick string representation of the tree.

        :param current_node_id: The id of the current node.
        :return: A string representation of a node in a Newick format.
        """
        # revisiting a node would mean the parent/child maps contain a cycle
        assert (
            current_node_id not in visited_nodes
        ), "Error: The tree may not be circular!"
        node_visit = self.nodes_visit[current_node_id]

        visited_nodes.add(current_node_id)
        if self.children[current_node_id]:
            # internal node: render children above the visits threshold
            children = [
                child
                for child in list(self.children[current_node_id])
                if self.nodes_visit[child] >= visits_threshold
            ]
            children_strings = [newick_render_node(child) for child in children]
            children_strings = ",".join(children_strings)
            if children_strings:
                return f"({children_strings}){current_node_id}:{node_visit}"
            # all children filtered out: render as a leaf within threshold
            return f"{current_node_id}:{node_visit}"

        # true leaf
        return f"{current_node_id}:{node_visit}"

    newick_string = newick_render_node(root_node_id) + ";"

    # collect metadata only for the nodes that made it into the string
    meta = {}
    for node_id in iter(visited_nodes):
        node_value = round(self.nodes_total_value[node_id], 3)

        node_synthesisability = round(self.nodes_init_value[node_id])

        visit_in_node = self.nodes_visit[node_id]
        meta[node_id] = (node_value, node_synthesisability, visit_in_node)

    return newick_string, meta
|
synplan/ml/__init__.py
ADDED
|
File without changes
|
synplan/ml/networks/__init__.py
ADDED
|
File without changes
|
synplan/ml/networks/modules.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing basic pytorch architectures of policy and value neural networks."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Dict, List, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from adabelief_pytorch import AdaBelief
|
| 8 |
+
from pytorch_lightning import LightningModule
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.nn import GELU, Dropout, Linear, Module, ModuleDict, ModuleList
|
| 11 |
+
from torch.nn.functional import relu
|
| 12 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 13 |
+
from torch_geometric.data.batch import Batch
|
| 14 |
+
from torch_geometric.nn.conv import GCNConv
|
| 15 |
+
from torch_geometric.nn.pool import global_add_pool
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GraphEmbedding(Module):
    """Converts per-atom feature vectors of a molecule into a single embedding
    vector using residual graph convolutions."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5
    ):
        """Initializes a graph convolutional module. Needed to convert molecule atom
        vectors to the single vector using graph convolution.

        :param vector_dim: The dimensionality of the hidden layers and output layer of
            graph convolution module.
        :param dropout: Dropout is a regularization technique used in neural networks to
            prevent overfitting. It randomly sets a fraction of input units to 0 at each
            update during training time.
        :param num_conv_layers: The number of convolutional layers in a graph
            convolutional module.
        """
        super().__init__()
        # 11 is the per-atom input feature size produced by the upstream
        # featurizer -- TODO confirm against the preprocessing code.
        self.expansion = Linear(11, vector_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                GCNConv(
                    vector_dim,
                    vector_dim,
                    improved=True,
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Takes a graph as input and performs graph convolution on it.

        :param graph: The batch of molecular graphs, where each atom is represented by
            the atom/bond vector.
        :param batch_size: The size of the batch.
        :return: Graph embedding.
        """
        atoms, connections = graph.x.float(), graph.edge_index.long()
        # log(x + 1) compresses the raw (count-like) feature magnitudes
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)
        # residual connections: each layer adds its (dropped-out) output
        for gcn_conv in self.gcn_convs:
            atoms = atoms + self.dropout(relu(gcn_conv(atoms, connections)))

        # sum-pool atom vectors per molecule to get one vector per graph
        return global_add_pool(atoms, graph.batch, size=batch_size)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class GraphEmbeddingConcat(GraphEmbedding, Module):
    """Graph convolution embedder that concatenates per-layer outputs.

    Unlike :class:`GraphEmbedding` (a residual sum at full width), every GCN
    layer here works at ``vector_dim // num_conv_layers`` width and the
    outputs of all layers are concatenated into the final embedding.
    """

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8
    ):
        """Initializes the concatenating graph convolutional module.

        :param vector_dim: The target dimensionality of the concatenated
            embedding (split evenly across the convolution layers).
        :param dropout: The dropout rate applied after each convolution.
        :param num_conv_layers: The number of convolutional layers.
        """
        # Skip GraphEmbedding.__init__ on purpose: it would build a full set
        # of vector_dim-wide layers that we immediately replace below.
        # Initialize torch.nn.Module directly via the MRO.
        super(GraphEmbedding, self).__init__()

        # NOTE(review): if vector_dim is not divisible by num_conv_layers the
        # concatenated embedding is gcn_dim * num_conv_layers < vector_dim.
        gcn_dim = vector_dim // num_conv_layers

        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                ModuleDict(
                    {
                        "gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
                        "activation": GELU(),
                    }
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Takes a graph as input and performs graph convolution on it.

        :param graph: The batch of molecular graphs, where each atom is represented by
            the atom/bond vector.
        :param batch_size: The size of the batch.
        :return: Graph embedding (concatenation of all layer outputs, sum-pooled
            per molecule).
        """
        atoms, connections = graph.x.float(), graph.edge_index.long()
        # log(x + 1) compresses the raw (count-like) feature magnitudes
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)

        collected_atoms = []
        for gcn_convs in self.gcn_convs:
            atoms = gcn_convs["gcn"](atoms, connections)
            atoms = gcn_convs["activation"](atoms)
            atoms = self.dropout(atoms)
            collected_atoms.append(atoms)

        # concatenate every layer's output along the feature axis
        atoms = torch.cat(collected_atoms, dim=-1)

        return global_add_pool(atoms, graph.batch, size=batch_size)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class MCTSNetwork(LightningModule, ABC):
    """Basic class for policy and value networks."""

    def __init__(
        self,
        vector_dim: int,
        batch_size: int,
        dropout: float = 0.4,
        num_conv_layers: int = 5,
        learning_rate: float = 0.001,
        gcn_concat: bool = False,
    ):
        """The basic class for MCTS graph convolutional neural networks (policy and
        value network).

        :param vector_dim: The dimensionality of the hidden layers and output layer of
            graph convolution module.
        :param batch_size: The batch size used when logging metrics.
        :param dropout: Dropout is a regularization technique used in neural networks to
            prevent overfitting.
        :param num_conv_layers: The number of convolutional layers in a graph
            convolutional module.
        :param learning_rate: The learning rate determines how quickly the model learns
            from the training data.
        :param gcn_concat: If True, use the concatenating embedder
            (GraphEmbeddingConcat) instead of the residual-sum one -- TODO
            confirm intended effect with the authors.
        """
        super().__init__()
        if gcn_concat:
            self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
        else:
            self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
        self.batch_size = batch_size
        self.lr = learning_rate

    @abstractmethod
    def forward(self, batch: Batch) -> Tensor:
        """The forward function takes a batch of input data and performs forward
        propagation through the neural network.

        :param batch: The batch of molecular graphs processed together in a single
            forward pass through the neural network.
        """

    @abstractmethod
    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculate the loss (and auxiliary metrics) for a given batch of data.

        :param batch: The batch of input data that is used to compute the loss.
        :return: A dict of metric name to tensor; must contain the key "loss".
        """

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
        """Calculates the loss for a given training batch and logs the loss value.

        :param batch: The batch of data that is used for training.
        :param batch_idx: The index of the batch.
        :return: The value of the training loss.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log(
                "train_" + name,
                value,
                prog_bar=True,
                on_step=True,
                on_epoch=True,
                batch_size=self.batch_size,
            )
        return metrics["loss"]

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given validation batch and logs the loss value.

        :param batch: The batch of data that is used for validation.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("val_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given test batch and logs the loss value.

        :param batch: The batch of data that is used for testing.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("test_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def configure_optimizers(
        self,
    ) -> Tuple[List[AdaBelief], List[Dict[str, Union[bool, str, ReduceLROnPlateau]]]]:
        """Returns an optimizer and a learning rate scheduler for training a model using
        the AdaBelief optimizer and ReduceLROnPlateau scheduler.

        :return: The optimizer and a scheduler.
        """
        optimizer = AdaBelief(
            self.parameters(),
            lr=self.lr,
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=True,
            weight_decay=0.01,
            print_change_log=False,
        )

        # reduce the LR when the monitored validation loss plateaus
        lr_scheduler = ReduceLROnPlateau(
            optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True
        )
        scheduler = {
            "scheduler": lr_scheduler,
            "reduce_on_plateau": True,
            "monitor": "val_loss",
        }

        return [optimizer], [scheduler]
|
synplan/ml/networks/policy.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing main class for policy network."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from pytorch_lightning import LightningModule
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.nn import Linear
|
| 10 |
+
from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy, one_hot
|
| 11 |
+
from torch_geometric.data.batch import Batch
|
| 12 |
+
from torchmetrics.functional.classification import f1_score, recall, specificity
|
| 13 |
+
|
| 14 |
+
from synplan.ml.networks.modules import MCTSNetwork
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PolicyNetwork(MCTSNetwork, LightningModule, ABC):
    """Policy network predicting applicable reaction rules for a molecule.

    Two modes are supported:

    * ``"ranking"`` — a softmax over all rules, trained with cross-entropy;
    * ``"filtering"`` — independent per-rule sigmoids plus a separate
      "priority rule" head, both trained with binary cross-entropy.
    """

    def __init__(
        self,
        *args,
        n_rules: int,
        vector_dim: int,
        policy_type: str = "ranking",
        **kwargs
    ):
        """Initializes a policy network with the given number of reaction rules
        (output dimension) and vector graph embedding dimension, and creates
        linear layers for predicting the regular and priority reaction rules.

        :param n_rules: The number of reaction rules in the policy network.
        :param vector_dim: The dimensionality of the input vectors.
        :param policy_type: Either ``"ranking"`` or ``"filtering"``.
        :raises ValueError: If ``policy_type`` is not a supported mode.
        """
        super().__init__(vector_dim, *args, **kwargs)
        # Fail fast: an unknown policy_type previously surfaced only later as
        # an obscure NameError in _get_loss / an implicit None from forward.
        if policy_type not in ("ranking", "filtering"):
            raise ValueError(
                f"policy_type must be 'ranking' or 'filtering', got {policy_type!r}"
            )
        self.save_hyperparameters()
        self.policy_type = policy_type
        self.n_rules = n_rules
        self.y_predictor = Linear(vector_dim, n_rules)

        if self.policy_type == "filtering":
            # extra head scoring "priority" rules independently of y_predictor
            self.priority_predictor = Linear(vector_dim, n_rules)

    def forward(self, batch: Batch) -> Tensor:
        """Takes a molecular graph, applies a graph convolution and softmax or
        sigmoid layers to predict regular (and, in filtering mode, priority)
        reaction rules.

        :param batch: The input batch of molecular graphs.
        :return: In ranking mode, the softmax probability vector over rules; in
            filtering mode, a tuple of sigmoid probability vectors for regular
            and priority reaction rules.
        """
        x = self.embedder(batch, self.batch_size)
        y = self.y_predictor(x)

        if self.policy_type == "ranking":
            return torch.softmax(y, dim=-1)

        # filtering mode (only other mode allowed by __init__)
        y = torch.sigmoid(y)
        priority = torch.sigmoid(self.priority_predictor(x))
        return y, priority

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and various classification metrics for a given
        batch for reaction rules prediction.

        :param batch: The batch of molecular graphs.
        :return: A dictionary with loss value and balanced accuracy / F1 score
            of reaction rules prediction (plus priority-head metrics in
            filtering mode).
        """
        true_y = batch.y_rules.long()
        x = self.embedder(batch, self.batch_size)
        pred_y = self.y_predictor(x)

        if self.policy_type == "ranking":
            # cross_entropy with class indices is numerically identical to the
            # one-hot-probability formulation, without materializing the
            # one-hot matrix
            loss = cross_entropy(pred_y, true_y)
            ba_y = (
                recall(pred_y, true_y, task="multiclass", num_classes=self.n_rules)
                + specificity(
                    pred_y, true_y, task="multiclass", num_classes=self.n_rules
                )
            ) / 2
            f1_y = f1_score(pred_y, true_y, task="multiclass", num_classes=self.n_rules)

            return {"loss": loss, "balanced_accuracy_y": ba_y, "f1_score_y": f1_y}

        # filtering mode: multilabel prediction of applicable rules plus a
        # separate multilabel prediction of priority rules
        loss_y = binary_cross_entropy_with_logits(pred_y, true_y.float())

        ba_y = (
            recall(pred_y, true_y, task="multilabel", num_labels=self.n_rules)
            + specificity(pred_y, true_y, task="multilabel", num_labels=self.n_rules)
        ) / 2
        f1_y = f1_score(pred_y, true_y, task="multilabel", num_labels=self.n_rules)

        true_priority = batch.y_priority.float()
        pred_priority = self.priority_predictor(x)
        loss_priority = binary_cross_entropy_with_logits(pred_priority, true_priority)

        # total loss is the unweighted sum of both heads
        loss = loss_y + loss_priority

        true_priority = true_priority.long()
        ba_priority = (
            recall(
                pred_priority,
                true_priority,
                task="multilabel",
                num_labels=self.n_rules,
            )
            + specificity(
                pred_priority,
                true_priority,
                task="multilabel",
                num_labels=self.n_rules,
            )
        ) / 2

        f1_priority = f1_score(
            pred_priority, true_priority, task="multilabel", num_labels=self.n_rules
        )

        return {
            "loss": loss,
            "balanced_accuracy_y": ba_y,
            "f1_score_y": f1_y,
            "balanced_accuracy_priority": ba_priority,
            "f1_score_priority": f1_priority,
        }
|
synplan/ml/networks/value.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing main class for value network."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from typing import Any, Dict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from pytorch_lightning import LightningModule
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.nn import Linear
|
| 10 |
+
from torch.nn.functional import binary_cross_entropy_with_logits
|
| 11 |
+
from torch_geometric.data.batch import Batch
|
| 12 |
+
from torchmetrics.functional.classification import (
|
| 13 |
+
binary_f1_score,
|
| 14 |
+
binary_recall,
|
| 15 |
+
binary_specificity,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
from synplan.ml.networks.modules import MCTSNetwork
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ValueNetwork(MCTSNetwork, LightningModule, ABC):
    """Value network.

    Estimates the synthesisability of a precursor molecule from its graph
    embedding via a single linear head followed by a sigmoid.
    """

    def __init__(self, vector_dim: int, *args: Any, **kwargs: Any) -> None:
        """Initializes a value network and creates the linear layer predicting
        the synthesisability of a precursor represented by a molecular graph.

        :param vector_dim: The dimensionality of the output linear layer.
        """
        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        self.predictor = Linear(vector_dim, 1)

    def forward(self, batch) -> torch.Tensor:
        """Applies the graph embedder and the linear head to a batch of
        molecular graphs and returns the synthesisability probability given by
        a sigmoid.

        :param batch: The batch of molecular graphs.
        :return: The predicted synthesisability (between 0 and 1).
        """
        embedding = self.embedder(batch, self.batch_size)
        return torch.sigmoid(self.predictor(embedding))

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Computes the loss and classification metrics for one batch of the
        precursor synthesisability prediction task.

        :param batch: The batch of molecular graphs.
        :return: A dictionary with the loss value, balanced accuracy and F1
            score of the precursor synthesisability prediction.
        """
        labels = torch.unsqueeze(batch.y.float(), -1)
        embedding = self.embedder(batch, self.batch_size)
        logits = self.predictor(embedding)

        # BCE on raw logits (the sigmoid is folded into the loss)
        loss = binary_cross_entropy_with_logits(logits, labels)

        labels = labels.long()
        balanced_accuracy = (
            binary_recall(logits, labels) + binary_specificity(logits, labels)
        ) / 2
        return {
            "loss": loss,
            "balanced_accuracy": balanced_accuracy,
            "f1_score": binary_f1_score(logits, labels),
        }
|
synplan/ml/training/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public API of the ``synplan.ml.training`` subpackage."""

# NOTE(review): the wildcard import overlaps the explicit .supervised import
# below and re-exports names not listed in __all__ — confirm whether callers
# rely on those extra names before removing it.
from .supervised import *
from .preprocessing import ValueNetworkDataset, mol_to_pyg, MENDEL_INFO
from .supervised import create_policy_dataset, run_policy_training

__all__ = [
    "ValueNetworkDataset",
    "mol_to_pyg",
    "MENDEL_INFO",
    "create_policy_dataset",
    "run_policy_training",
]
|
synplan/ml/training/preprocessing.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing functions for preparation of the training sets for policy and value
|
| 2 |
+
network."""
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
from abc import ABC
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import ray
|
| 11 |
+
import torch
|
| 12 |
+
from CGRtools import smiles
|
| 13 |
+
from CGRtools.containers import MoleculeContainer
|
| 14 |
+
from CGRtools.exceptions import InvalidAromaticRing
|
| 15 |
+
from CGRtools.reactor import Reactor
|
| 16 |
+
from ray.util.queue import Empty, Queue
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from torch_geometric.data import InMemoryDataset
|
| 19 |
+
from torch_geometric.data.data import Data
|
| 20 |
+
from torch_geometric.data.makedirs import makedirs
|
| 21 |
+
from torch_geometric.transforms import ToUndirected
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
from synplan.chem.utils import unite_molecules
|
| 25 |
+
from synplan.utils.files import ReactionReader
|
| 26 |
+
from synplan.utils.loading import load_reaction_rules
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ValueNetworkDataset(InMemoryDataset, ABC):
    """Value network dataset built from precursors extracted from search trees."""

    def __init__(self, extracted_precursor: Dict[str, float]) -> None:
        """Initializes a value network dataset object.

        :param extracted_precursor: The dictionary mapping SMILES of the
            precursors extracted from the built search trees to their labels.
        """
        super().__init__(None, None, None)

        if extracted_precursor:
            self.data, self.slices = self.graphs_from_extracted_precursor(
                extracted_precursor
            )

    @staticmethod
    def mol_to_graph(molecule: MoleculeContainer, label: float) -> Optional[Data]:
        """Converts a molecule to a PyTorch Geometric graph and attaches the
        reward value (label) to it.

        :param molecule: The input molecule.
        :param label: The label (solved/unsolved routes in the tree) of the
            molecule (precursor).
        :return: A PyTorch Geometric graph representation of a molecule, or
            None for very small or unconvertible molecules.
        """
        if len(molecule) <= 2:  # skip precursors that are just one or two atoms
            return None
        graph = mol_to_pyg(molecule)
        if not graph:
            return None
        graph.y = torch.tensor([label])
        return graph

    def graphs_from_extracted_precursor(
        self, extracted_precursor: Dict[str, float]
    ) -> Tuple[Data, Dict]:
        """Converts the precursors extracted from the search trees to collated
        PyTorch Geometric graphs.

        :param extracted_precursor: The dictionary with the extracted from the
            built search trees precursor and their labels.
        :return: The PyTorch Geometric graphs and slices.
        """
        graphs = [
            graph
            for graph in (
                self.mol_to_graph(smiles(smi), label)
                for smi, label in extracted_precursor.items()
            )
            if graph
        ]
        return self.collate(graphs)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class RankingPolicyDataset(InMemoryDataset):
    """Ranking policy network dataset."""

    def __init__(self, reactions_path: str, reaction_rules_path: str, output_path: str):
        """Initializes a policy network dataset.

        :param reactions_path: The path to the file containing the reaction data
            used for extraction of reaction rules.
        :param reaction_rules_path: The path to the file containing the reaction
            rules.
        :param output_path: The output path to the file where policy network
            dataset will be saved.
        """
        super().__init__(None, None, None)

        self.reactions_path = reactions_path
        self.reaction_rules_path = reaction_rules_path
        self.output_path = output_path

        # reuse a previously prepared dataset if one exists on disk
        if output_path and os.path.exists(output_path):
            self.data, self.slices = torch.load(self.output_path)
        else:
            self.data, self.slices = self.prepare_data()

    @property
    def num_classes(self) -> int:
        # one class per distinct reaction rule
        return self._infer_num_classes(self._data.y_rules)

    def prepare_data(self) -> Tuple[Data, Dict[str, Tensor]]:
        """Prepares data by loading reaction rules, preprocessing the molecules,
        collating the data, and returning the data and slices.

        :return: The PyTorch geometric graphs and slices.
        """

        with open(self.reaction_rules_path, "rb") as inp:
            reaction_rules = pickle.load(inp)
        # most popular rules first, so rule index 0 is the MOST frequent rule
        reaction_rules = sorted(reaction_rules, key=lambda x: len(x[1]), reverse=True)

        # map every reaction id to the index of the rule extracted from it
        reaction_rule_pairs = {}
        for rule_i, (_, reactions_ids) in enumerate(reaction_rules):
            for reaction_id in reactions_ids:
                reaction_rule_pairs[reaction_id] = rule_i
        reaction_rule_pairs = dict(sorted(reaction_rule_pairs.items()))

        list_of_graphs = []
        with ReactionReader(self.reactions_path) as reactions:

            for reaction_id, reaction in tqdm(
                enumerate(reactions),
                desc="Number of reactions processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):

                rule_id = reaction_rule_pairs.get(reaction_id)
                # BUG FIX: `if rule_id:` silently dropped rule index 0, which
                # after the popularity sort above is the most frequent rule;
                # only None means "no rule for this reaction".
                if rule_id is None:
                    continue

                try:  # MENDEL_INFO does not contain cadmium (Cd) properties
                    molecule = unite_molecules(reaction.products)
                    pyg_graph = mol_to_pyg(molecule)

                except (
                    Exception
                ) as e:  # TypeError: can't assign a NoneType to a torch.ByteTensor
                    logging.debug(e)
                    continue

                if pyg_graph is not None:
                    pyg_graph.y_rules = torch.tensor([rule_id], dtype=torch.long)
                    list_of_graphs.append(pyg_graph)

        data, slices = self.collate(list_of_graphs)
        if self.output_path:
            makedirs(os.path.dirname(self.output_path))
            torch.save((data, slices), self.output_path)

        return data, slices
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class FilteringPolicyDataset(InMemoryDataset):
    """Filtering policy network dataset.

    Labels are produced by applying every reaction rule to every input molecule
    in parallel with Ray worker processes.
    """

    def __init__(
        self,
        molecules_path: str,
        reaction_rules_path: str,
        output_path: str,
        num_cpus: int,
    ) -> None:
        """Initializes a policy network dataset object.

        :param molecules_path: The path to the file containing the molecules for
            reaction rule appliance.
        :param reaction_rules_path: The path to the file containing the reaction rules.
        :param output_path: The output path to the file where policy network dataset
            will be stored.
        :param num_cpus: The number of CPUs to be used for the dataset preparation.
        :return: None.
        """
        super().__init__(None, None, None)

        self.molecules_path = molecules_path
        self.reaction_rules_path = reaction_rules_path
        self.output_path = output_path
        self.num_cpus = num_cpus
        # per-CPU multiplier for the work-queue capacity (see prepare_data)
        self.batch_size = 100

        # reuse a previously prepared dataset if one exists on disk
        if output_path and os.path.exists(output_path):
            self.data, self.slices = torch.load(self.output_path)
        else:
            self.data, self.slices = self.prepare_data()

    @property
    def num_classes(self) -> int:
        # one output column per reaction rule
        return self._data.y_rules.shape[1]

    def prepare_data(self) -> Tuple[Data, Dict]:
        """Prepares data by loading reaction rules, initializing Ray, preprocessing the
        molecules, collating the data, and returning the data and slices.

        :return: The PyTorch geometric graphs and slices.
        """

        ray.init(num_cpus=self.num_cpus, ignore_reinit_error=True)
        reaction_rules = load_reaction_rules(self.reaction_rules_path)
        # share the (large) rules list with all workers via the object store
        reaction_rules_ids = ray.put(reaction_rules)

        # bounded queue: the producer loop below blocks when workers fall behind
        to_process = Queue(maxsize=self.batch_size * self.num_cpus)
        processed_data = []
        # one worker per CPU; each drains the queue until it stays empty for
        # 30 s (see preprocess_filtering_policy_molecules)
        results_ids = [
            preprocess_filtering_policy_molecules.remote(to_process, reaction_rules_ids)
            for _ in range(self.num_cpus)
        ]

        with open(self.molecules_path, "r", encoding="utf-8") as inp_data:
            for molecule in tqdm(
                inp_data.read().splitlines(),
                desc="Number of molecules processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):

                to_process.put(molecule)

        # block until every worker returns, then flatten their graph lists
        results = [graph for res in ray.get(results_ids) if res for graph in res]
        processed_data.extend(results)

        ray.shutdown()

        # densify the sparse per-rule label vectors before collation
        for pyg in processed_data:
            pyg.y_rules = pyg.y_rules.to_dense()
            pyg.y_priority = pyg.y_priority.to_dense()

        data, slices = self.collate(processed_data)
        if self.output_path:
            makedirs(os.path.dirname(self.output_path))
            torch.save((data, slices), self.output_path)

        return data, slices
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def reaction_rules_appliance(
    molecule: MoleculeContainer, reaction_rules: List[Reactor]
) -> Tuple[List[int], List[int]]:
    """Applies each reaction rule from the list of reaction rules to a given molecule
    and returns the indexes of the successfully applied regular and prioritized reaction
    rules.

    :param molecule: The input molecule.
    :param reaction_rules: The list of reaction rules.
    :return: The two lists of indexes of successfully applied regular reaction rules and
        priority reaction rules.
    """

    applied_rules, priority_rules = [], []
    for i, rule in enumerate(reaction_rules):

        rule_applied = False
        rule_prioritized = False

        try:
            for reaction in rule([molecule]):
                for prod in reaction.products:
                    prod.kekule()
                    if prod.check_valence():
                        # check_valence() returns a non-empty list of broken-valence
                        # atoms; stop inspecting this reaction's products
                        break
                # NOTE(review): rule_applied is set even when the valence check
                # above broke out of the product loop — if invalid products were
                # meant to disqualify the reaction, a for/else is needed here.
                # Confirm the intended semantics before changing.
                rule_applied = True

                # check priority rules
                if len(reaction.products) > 1:
                    # check coupling retro manual
                    if all(len(mol) > 6 for mol in reaction.products):
                        if (
                            sum(len(mol) for mol in reaction.products)
                            - len(reaction.reactants[0])
                            < 6
                        ):
                            rule_prioritized = True
                else:
                    # check cyclization retro manual
                    if sum(len(mol.sssr) for mol in reaction.products) < sum(
                        len(mol.sssr) for mol in reaction.reactants
                    ):
                        rule_prioritized = True
            # record the rule index if it generated at least one product set
            if rule_applied:
                applied_rules.append(i)
            # record the rule index if a coupling/cyclization heuristic fired
            if rule_prioritized:
                priority_rules.append(i)
        except Exception as e:
            logging.debug(e)
            continue

    return applied_rules, priority_rules
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@ray.remote
def preprocess_filtering_policy_molecules(
    to_process: Queue, reaction_rules: List[Reactor]
) -> List[Optional[Data]]:
    """Preprocesses a list of molecules by applying reaction rules and converting
    molecules into PyTorch geometric graphs. Successfully applied reaction rules are
    converted to binary vectors for policy network training.

    :param to_process: The queue containing SMILES of molecules to be converted to the
        training data.
    :param reaction_rules: The list of reaction rules.
    :return: The list of PyGraph objects.
    """

    pyg_graphs = []
    while True:
        try:
            molecule = smiles(to_process.get(timeout=30))
            if not isinstance(molecule, MoleculeContainer):
                # parse failure or multi-component input: skip this entry
                continue

            # reaction reaction_rules application
            applied_rules, priority_rules = reaction_rules_appliance(
                molecule, reaction_rules
            )

            # sparse 0/1 label vectors over all rules (densified later by the
            # dataset before collation)
            y_rules = torch.sparse_coo_tensor(
                [applied_rules],
                torch.ones(len(applied_rules)),
                (len(reaction_rules),),
                dtype=torch.uint8,
            )
            y_priority = torch.sparse_coo_tensor(
                [priority_rules],
                torch.ones(len(priority_rules)),
                (len(reaction_rules),),
                dtype=torch.uint8,
            )

            # add a leading batch dimension of size 1
            y_rules = torch.unsqueeze(y_rules, 0)
            y_priority = torch.unsqueeze(y_priority, 0)

            pyg_graph = mol_to_pyg(molecule)
            if not pyg_graph:
                continue
            pyg_graph.y_rules = y_rules
            pyg_graph.y_priority = y_priority
            pyg_graphs.append(pyg_graph)

        except Empty:
            # the queue has been idle for 30 s: assume the producer is done
            break

    return pyg_graphs
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def atom_to_vector(atom: Any) -> Tensor:
    """Given an atom, return a vector of length 8 with the following
    information:

    1. Atomic number
    2. Period
    3. Group
    4. Number of electrons + atom's charge
    5. Shell
    6. Total number of hydrogens
    7. Whether the atom is in a ring
    8. Number of neighbors

    :param atom: The atom object.

    :return: The vector of the atom.
    """
    period, group, shell, electrons = MENDEL_INFO[atom.atomic_symbol]
    features = [
        atom.atomic_number,
        period,
        group,
        electrons + atom.charge,
        shell,
        atom.total_hydrogens,
        int(atom.in_ring),
        atom.neighbors,
    ]
    return torch.tensor(features, dtype=torch.uint8)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def bonds_to_vector(molecule: MoleculeContainer, atom_ind: int) -> Tensor:
    """Counts the orders of the bonds attached to one atom of a molecule.

    :param molecule: The given molecule.
    :param atom_ind: The index of the atom in the molecule to be converted to
        the bond vector.
    :return: A torch tensor of size 3 whose i-th element is the number of bonds
        of order i+1 connected to the atom with the given index.
    """
    counts = torch.zeros(3, dtype=torch.uint8)
    orders = [int(bond) for bond in molecule._bonds[atom_ind].values()]
    for order in orders:
        counts[order - 1] += 1
    return counts
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def mol_to_matrix(molecule: MoleculeContainer) -> Tensor:
    """Builds the per-atom feature matrix of a molecule: one row per atom,
    8 atom-descriptor columns followed by 3 bond-order-count columns.

    :param molecule: The molecule to be converted to a vector.
    :return: The atoms vectors array of shape (n_atoms, 11).
    """
    features = torch.zeros((len(molecule), 11), dtype=torch.uint8)
    # atom numbering is 1-based in the molecule, 0-based in the matrix
    for idx, atom in molecule.atoms():
        features[idx - 1][:8] = atom_to_vector(atom)
        features[idx - 1][8:] = bonds_to_vector(molecule, idx)
    return features
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def mol_to_pyg(
    molecule: MoleculeContainer, canonicalize: bool = True
) -> Optional[Data]:
    """Converts a molecule to an undirected PyTorch Geometric graph with
    per-atom feature vectors.

    :param molecule: The molecule to be converted to PyTorch Geometric graph.
    :param canonicalize: If True, the input molecule is canonicalized.
    :return: The PyGraph object, or None if the molecule is a single atom or
        cannot be kekulized / has invalid valences.
    """

    if len(molecule) == 1:  # to avoid a precursor to be a single atom
        return None

    # work on a copy: canonicalize/kekule/remap mutate the molecule in place
    tmp_molecule = molecule.copy()
    try:
        if canonicalize:
            tmp_molecule.canonicalize()
        tmp_molecule.kekule()
        # check_valence() returns the (truthy) list of broken-valence atoms
        if tmp_molecule.check_valence():
            return None
    except InvalidAromaticRing:
        return None

    # remapping target for torch_geometric because
    # it is necessary that the elements in edge_index only hold nodes_idx in the range { 0, ..., num_nodes - 1}
    new_mappings = {n: i for i, (n, _) in enumerate(tmp_molecule.atoms(), 1)}
    tmp_molecule.remap(new_mappings)

    # get edge indexes from target mapping (0-based atom indices)
    # NOTE(review): a molecule with no bonds would yield an empty edge_index of
    # the wrong shape for Data — confirm such inputs cannot reach this point.
    edge_index = []
    for atom, neighbour, bond in tmp_molecule.bonds():
        edge_index.append([atom - 1, neighbour - 1])
    edge_index = torch.tensor(edge_index, dtype=torch.long)

    # per-atom feature matrix (atom descriptors + bond-order counts)
    x = mol_to_matrix(tmp_molecule)

    mol_pyg_graph = Data(x=x, edge_index=edge_index.t().contiguous())
    # bonds are stored once per pair; mirror each edge to make the graph undirected
    mol_pyg_graph = ToUndirected()(mol_pyg_graph)

    assert mol_pyg_graph.is_undirected()

    return mol_pyg_graph
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# Per-element lookup used by atom_to_vector():
#   atomic symbol -> (period, group, shell, electrons)
# NOTE(review): the exact meaning of the "shell"/"electrons" fields is inferred
# from their use in atom_to_vector (vector slots 4 and 3) — confirm against the
# original featurization scheme. `None` group entries are lanthanides.
MENDEL_INFO = {
    "Ag": (5, 11, 1, 1),
    "Al": (3, 13, 2, 1),
    "Ar": (3, 18, 2, 6),
    "As": (4, 15, 2, 3),
    "B": (2, 13, 2, 1),
    "Ba": (6, 2, 1, 2),
    "Bi": (6, 15, 2, 3),
    "Br": (4, 17, 2, 5),
    "C": (2, 14, 2, 2),
    "Ca": (4, 2, 1, 2),
    "Ce": (6, None, 1, 2),
    "Cl": (3, 17, 2, 5),
    "Cr": (4, 6, 1, 1),
    "Cs": (6, 1, 1, 1),
    "Cu": (4, 11, 1, 1),
    "Dy": (6, None, 1, 2),
    "Er": (6, None, 1, 2),
    "F": (2, 17, 2, 5),
    "Fe": (4, 8, 1, 2),
    "Ga": (4, 13, 2, 1),
    "Gd": (6, None, 1, 2),
    "Ge": (4, 14, 2, 2),
    "Hg": (6, 12, 1, 2),
    "I": (5, 17, 2, 5),
    "In": (5, 13, 2, 1),
    "K": (4, 1, 1, 1),
    "La": (6, 3, 1, 2),
    "Li": (2, 1, 1, 1),
    "Mg": (3, 2, 1, 2),
    "Mn": (4, 7, 1, 2),
    "N": (2, 15, 2, 3),
    "Na": (3, 1, 1, 1),
    "Nd": (6, None, 1, 2),
    "O": (2, 16, 2, 4),
    "P": (3, 15, 2, 3),
    "Pb": (6, 14, 2, 2),
    "Pd": (5, 10, 3, 10),
    "Pr": (6, None, 1, 2),
    "Rb": (5, 1, 1, 1),
    "S": (3, 16, 2, 4),
    "Sb": (5, 15, 2, 3),
    "Se": (4, 16, 2, 4),
    "Si": (3, 14, 2, 2),
    "Sm": (6, None, 1, 2),
    "Sn": (5, 14, 2, 2),
    "Sr": (5, 2, 1, 2),
    "Te": (5, 16, 2, 4),
    "Ti": (4, 4, 1, 2),
    "Tl": (6, 13, 2, 1),
    "Yb": (6, None, 1, 2),
    "Zn": (4, 12, 1, 2),
}
|
synplan/ml/training/reinforcement.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing functions for running value network tuning with reinforcement learning
|
| 2 |
+
approach."""
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from random import shuffle
|
| 9 |
+
from typing import Dict, List
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from CGRtools.containers import MoleculeContainer
|
| 13 |
+
from pytorch_lightning import Trainer
|
| 14 |
+
from torch.utils.data import random_split
|
| 15 |
+
from torch_geometric.data.lightning import LightningDataset
|
| 16 |
+
|
| 17 |
+
from synplan.chem.precursor import compose_precursors
|
| 18 |
+
from synplan.mcts.evaluation import ValueNetworkFunction
|
| 19 |
+
from synplan.mcts.expansion import PolicyNetworkFunction
|
| 20 |
+
from synplan.mcts.tree import Tree
|
| 21 |
+
from synplan.ml.networks.value import ValueNetwork
|
| 22 |
+
from synplan.ml.training.preprocessing import ValueNetworkDataset
|
| 23 |
+
from synplan.utils.config import (
|
| 24 |
+
PolicyNetworkConfig,
|
| 25 |
+
TuningConfig,
|
| 26 |
+
TreeConfig,
|
| 27 |
+
ValueNetworkConfig,
|
| 28 |
+
)
|
| 29 |
+
from synplan.utils.files import MoleculeReader
|
| 30 |
+
from synplan.utils.loading import (
|
| 31 |
+
load_building_blocks,
|
| 32 |
+
load_reaction_rules,
|
| 33 |
+
load_value_net,
|
| 34 |
+
)
|
| 35 |
+
from synplan.utils.logging import DisableLogger, HiddenPrints
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_value_network(value_config: ValueNetworkConfig) -> ValueNetwork:
    """Builds a fresh value network and writes its initial checkpoint to disk.

    :param value_config: The value network configuration.
    :return: The newly created ValueNetwork to be trained/tuned.
    """

    checkpoint_file = Path(value_config.weights_path)
    network = ValueNetwork(
        vector_dim=value_config.vector_dim,
        batch_size=value_config.batch_size,
        dropout=value_config.dropout,
        num_conv_layers=value_config.num_conv_layers,
        learning_rate=value_config.learning_rate,
    )

    # persist the untrained weights so that tuning can start from this checkpoint;
    # Lightning chatter is suppressed while doing so
    with DisableLogger(), HiddenPrints():
        saver = Trainer()
        saver.strategy.connect(network)
        saver.save_checkpoint(checkpoint_file)

    return network
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def create_targets_batch(
    targets: List[MoleculeContainer], batch_size: int
) -> List[List[MoleculeContainer]]:
    """Splits the list of target molecules into batches for planning simulations
    and value network tuning.

    :param targets: The list of target molecules.
    :param batch_size: The size of each target batch.
    :return: The list of batches; the last batch may be smaller than batch_size.
    """

    num_targets = len(targets)
    # ceil division: one extra batch for the remainder, if any
    num_batches = num_targets // batch_size + int(bool(num_targets % batch_size))

    if num_targets < batch_size:
        print(f"1 batch was created with {num_targets} molecules")
    else:
        print(
            f"{num_batches} batches were created with {batch_size} molecules each"
        )

    # plain slicing replaces the previous per-index list building, which
    # recomputed valid indices for every batch
    return [
        targets[start : start + batch_size]
        for start in range(0, num_targets, batch_size)
    ]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def run_tree_search(
    target: MoleculeContainer,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
) -> Tree:
    """Builds and exhausts a retrosynthesis search tree for one target molecule.

    :param target: The target molecule.
    :param tree_config: The planning configuration of tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :return: The built search tree for the given molecule.
    """

    # load expansion/evaluation functions and retrosynthesis data
    expansion = PolicyNetworkFunction(policy_config=policy_config)
    evaluation = ValueNetworkFunction(weights_path=value_config.weights_path)
    rules = load_reaction_rules(reaction_rules_path)
    blocks = load_building_blocks(building_blocks_path, standardize=True)

    # force value-network ("gcn") evaluation and quiet mode for simulations
    tree_config.evaluation_type = "gcn"
    tree_config.silent = True
    search_tree = Tree(
        target=target,
        config=tree_config,
        reaction_rules=rules,
        building_blocks=blocks,
        expansion_function=expansion,
        evaluation_function=evaluation,
    )
    search_tree._tqdm = False  # suppress the per-tree progress bar

    # the target itself must not count as a purchasable building block
    if str(target) in search_tree.building_blocks:
        search_tree.building_blocks.remove(str(target))

    # exhausting the iterator performs the whole search
    _ = list(search_tree)

    return search_tree
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def extract_tree_precursor(tree_list: List[Tree]) -> Dict[str, float]:
    """Extracts labeled precursors from built search trees for value network tuning.

    Precursors lying on found retrosynthetic routes are labeled as the positive
    class (1.0); precursors from unsolved nodes are labeled as the negative
    class (0.0).

    :param tree_list: The list of built search trees.
    :return: The dictionary with precursor SMILES and its class label, in
        shuffled order.
    """
    labels = defaultdict(float)
    for tree in tree_list:
        for node_id, node in tree.nodes.items():
            if node.is_solved():
                # walk up from the solved node to the root (node 1), labeling
                # every precursor along the route as positive
                current = node_id
                while current and current != 1:
                    smi = str(
                        compose_precursors(tree.nodes[current].new_precursors)
                    )
                    labels[smi] = 1.0
                    current = tree.parents[current]
            else:
                smi = str(compose_precursors(tree.nodes[node_id].new_precursors))
                labels[smi] = 0.0

    # return the collected precursors in random order
    keys = list(labels)
    shuffle(keys)
    return {smi: labels[smi] for smi in keys}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def balance_extracted_precursor(extracted_precursor):
    """Balances the positive and negative classes in the extracted precursor set.

    Keeps every positive (solved) precursor and randomly samples an equal number
    of negative precursors, so the tuning set is not dominated by the (usually
    much larger) negative class.

    Bug fix: the previous implementation popped entries from the negative list
    but never added any negatives to the returned dictionary, so the "balanced"
    result contained only positive examples.

    :param extracted_precursor: The dictionary mapping precursor SMILES to its
        class label (1 - positive, 0 - negative).
    :return: The balanced dictionary of precursor SMILES and labels.
    """
    positives = {smi: lbl for smi, lbl in extracted_precursor.items() if lbl == 1}
    negative_keys = [smi for smi, lbl in extracted_precursor.items() if lbl == 0]

    # sample at most as many negatives as there are positives
    num_negatives = min(len(positives), len(negative_keys))
    balanced = dict(positives)
    for smi in random.sample(negative_keys, num_negatives):
        balanced[smi] = extracted_precursor[smi]
    return balanced
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def create_updating_set(
    extracted_precursor: Dict[str, float], batch_size: int = 1
) -> LightningDataset:
    """Builds the value network updating dataset from precursors extracted from
    the planning simulation.

    :param extracted_precursor: The dictionary with the extracted precursors and
        their labels.
    :param batch_size: The size of the batch in value network updating.
    :return: A LightningDataset object containing the tuning set for value
        network tuning.
    """

    # equalize positive/negative class sizes before splitting
    balanced = balance_extracted_precursor(extracted_precursor)

    dataset = ValueNetworkDataset(balanced)
    n_train = int(0.6 * len(dataset))
    n_val = len(dataset) - n_train

    # fixed seed keeps the train/validation split reproducible
    train_subset, val_subset = random_split(
        dataset, [n_train, n_val], torch.Generator().manual_seed(42)
    )

    print(f"Training set size: {len(train_subset)}")
    print(f"Validation set size: {len(val_subset)}")

    return LightningDataset(
        train_subset, val_subset, batch_size=batch_size, pin_memory=True, drop_last=True
    )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def tune_value_network(
    datamodule: LightningDataset, value_config: ValueNetworkConfig
) -> None:
    """Fine-tunes the value network on the given tuning data and overwrites the
    stored checkpoint with the updated weights.

    :param datamodule: The tuning dataset (LightningDataset).
    :param value_config: The value network configuration.
    :return: None.
    """

    weights_file = value_config.weights_path
    network = load_value_net(ValueNetwork, weights_file)

    # train quietly: all Lightning output is suppressed
    with DisableLogger(), HiddenPrints():
        trainer = Trainer(
            accelerator="gpu",
            devices=[0],
            max_epochs=value_config.num_epoch,
            enable_checkpointing=False,
            logger=False,
            gradient_clip_val=1.0,
            enable_progress_bar=False,
        )

        trainer.fit(network, datamodule)
        metrics = trainer.validate(network, datamodule.val_dataloader())[0]
        trainer.save_checkpoint(weights_file)

    print(f"Value network balanced accuracy: {metrics['val_balanced_accuracy']}")
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def run_training(
    extracted_precursor: Dict[str, float] = None,
    value_config: ValueNetworkConfig = None,
) -> None:
    """Runs the training stage of value network tuning: builds the updating set
    and retrains the network.

    :param extracted_precursor: The precursors extracted from the planning
        simulations.
    :param value_config: The value network configuration.
    :return: None.
    """

    updating_set = create_updating_set(
        extracted_precursor=extracted_precursor, batch_size=value_config.batch_size
    )
    tune_value_network(datamodule=updating_set, value_config=value_config)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def run_planning(
    targets_batch: List[MoleculeContainer],
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    targets_batch_id: int,
):
    """Performs the planning stage (tree search) for a batch of target molecules;
    the built trees are later mined for precursors used to tune the value
    network in the training stage.

    :param targets_batch: The batch of target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param targets_batch_id: The number of the batch (used for progress output).
    :return: The list of successfully built search trees.
    """
    from tqdm import tqdm

    print(f"\nProcess batch number {targets_batch_id}")
    tree_config.silent = False
    built_trees = []
    for target in tqdm(targets_batch):
        # a failure on one target must not abort the whole batch
        try:
            built_trees.append(
                run_tree_search(
                    target=target,
                    tree_config=tree_config,
                    policy_config=policy_config,
                    value_config=value_config,
                    reaction_rules_path=reaction_rules_path,
                    building_blocks_path=building_blocks_path,
                )
            )
        except Exception as err:
            print(err)
            continue

    num_solved = sum(len(tree.winning_nodes) > 0 for tree in built_trees)
    print(f"Planning is finished with {num_solved} solved targets")

    return built_trees
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def run_updating(
    targets_path: str,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reinforce_config: TuningConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    results_root: str = None,
) -> None:
    """Runs the full reinforcement loop updating the value network: alternates
    planning simulations on target batches with retraining on the extracted
    precursors.

    :param targets_path: The path to the file with target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reinforce_config: The value network tuning configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param results_root: The path to the directory where the trained value
        network will be saved.
    :return: None.
    """

    # make sure the output folder exists
    results_dir = Path(results_root)
    if not results_dir.exists():
        results_dir.mkdir()

    # read all target molecules up front
    with MoleculeReader(targets_path) as reader:
        target_molecules = list(reader)

    # the value network checkpoint is (re)written inside the results folder
    value_config.weights_path = os.path.join(results_dir, "value_network.ckpt")
    create_value_network(value_config)

    batches = create_targets_batch(
        target_molecules, batch_size=reinforce_config.batch_size
    )

    # alternate planning and training, one batch at a time
    for batch_id, batch in enumerate(batches, start=1):

        # planning simulations for the current batch of targets
        built_trees = run_planning(
            targets_batch=batch,
            tree_config=tree_config,
            policy_config=policy_config,
            value_config=value_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            targets_batch_id=batch_id,
        )

        # label precursors (positive/negative) from the built trees
        precursor_labels = extract_tree_precursor(built_trees)

        # retrain the value network on the freshly labeled precursors
        run_training(extracted_precursor=precursor_labels, value_config=value_config)
|
synplan/ml/training/supervised.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module for the preparation and training of a policy network used in the expansion of
|
| 2 |
+
nodes in tree search.
|
| 3 |
+
|
| 4 |
+
This module includes functions for creating training datasets and running the training
|
| 5 |
+
process for the policy network.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import warnings
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Union, List
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import torch
|
| 14 |
+
from pytorch_lightning import Trainer
|
| 15 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 16 |
+
from torch.utils.data import random_split
|
| 17 |
+
from torch_geometric.data.lightning import LightningDataset
|
| 18 |
+
|
| 19 |
+
from synplan.ml.networks.policy import PolicyNetwork
|
| 20 |
+
from synplan.ml.training.preprocessing import (
|
| 21 |
+
FilteringPolicyDataset,
|
| 22 |
+
RankingPolicyDataset,
|
| 23 |
+
)
|
| 24 |
+
from synplan.utils.config import PolicyNetworkConfig
|
| 25 |
+
from synplan.utils.logging import DisableLogger, HiddenPrints
|
| 26 |
+
|
| 27 |
+
warnings.filterwarnings("ignore")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def create_policy_dataset(
    reaction_rules_path: str,
    molecules_or_reactions_path: str,
    output_path: str,
    dataset_type: str = "filtering",
    batch_size: int = 100,
    num_cpus: int = 1,
    training_data_ratio: float = 0.8,
):
    """
    Create a training dataset for a policy network.

    :param reaction_rules_path: Path to the reaction rules file.
    :param molecules_or_reactions_path: Path to the molecules or reactions file used to create the training set.
    :param output_path: Path to store the processed dataset.
    :param dataset_type: Type of the dataset to be created ('ranking' or 'filtering').
    :param batch_size: The size of batch of molecules/reactions.
    :param training_data_ratio: Ratio of training data to total data.
    :param num_cpus: Number of CPUs to use for data processing.
    :raises ValueError: If dataset_type is not 'filtering' or 'ranking'.

    :return: A `LightningDataset` object containing training and validation datasets.

    """

    # validate early: previously an unknown dataset_type fell through both
    # branches and surfaced as an obscure NameError on full_dataset
    if dataset_type not in ("filtering", "ranking"):
        raise ValueError(
            f"dataset_type must be 'filtering' or 'ranking', got {dataset_type!r}"
        )

    with DisableLogger(), HiddenPrints():
        if dataset_type == "filtering":
            full_dataset = FilteringPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                molecules_path=molecules_or_reactions_path,
                output_path=output_path,
                num_cpus=num_cpus,
            )
        else:  # "ranking"
            full_dataset = RankingPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                reactions_path=molecules_or_reactions_path,
                output_path=output_path,
            )

    train_size = int(training_data_ratio * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # fixed seed keeps the train/validation split reproducible
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
    )
    print(
        f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}"
    )

    datamodule = LightningDataset(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )

    return datamodule
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def run_policy_training(
    datamodule: LightningDataset,
    config: PolicyNetworkConfig,
    results_path: str,
    weights_file_name: str = "policy_network",
    accelerator: str = "gpu",
    devices: Union[List[int], str, int] = "auto",
    silent: bool = False,
) -> None:
    """
    Trains a policy network using a given datamodule and training configuration.

    :param datamodule: A PyTorch Lightning `DataModule` instance responsible for loading, processing, and preparing the training data for the model.
    :param config: Configuration settings for the policy training process.
    :param results_path: Path to store the training results and logs.
    :param accelerator: Accelerator type ("cpu", "gpu", "tpu", "hpu", "mps", "auto") or a custom accelerator instance. Default: "gpu".
    :param devices: The devices to use: a positive number (int or str), a sequence of device indices (list or str), -1 for all available devices, or "auto". Default: "auto".
    :param silent: Run in silent mode with no progress bars. Default: False.
    :param weights_file_name: The name of the weights file to be saved. Default: "policy_network".

    :return: None.

    """
    out_dir = Path(results_path)
    out_dir.mkdir(exist_ok=True)

    policy_net = PolicyNetwork(
        vector_dim=config.vector_dim,
        n_rules=datamodule.train_dataset.dataset.num_classes,
        batch_size=config.batch_size,
        dropout=config.dropout,
        num_conv_layers=config.num_conv_layers,
        learning_rate=config.learning_rate,
        policy_type=config.policy_type,
    )

    # keep the checkpoint with the lowest validation loss
    best_ckpt = ModelCheckpoint(
        dirpath=out_dir, filename=weights_file_name, monitor="val_loss", mode="min"
    )

    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=config.num_epoch,
        callbacks=[best_ckpt],
        logger=False,
        gradient_clip_val=1.0,
        enable_progress_bar=not silent,
    )

    if silent:
        with DisableLogger(), HiddenPrints():
            trainer.fit(policy_net, datamodule)
    else:
        trainer.fit(policy_net, datamodule)

    ba = round(trainer.logged_metrics["train_balanced_accuracy_y_step"].item(), 3)
    print(f"Policy network balanced accuracy: {ba}")
|
synplan/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
from os import PathLike

# Filesystem path argument type accepted throughout the package:
# either a plain string or any os.PathLike object (e.g. pathlib.Path).
path_type = Union[str, PathLike]
|
synplan/utils/config.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing configuration classes."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List, Union
|
| 7 |
+
from chython import smarts
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
from CGRtools.containers import MoleculeContainer, QueryContainer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class ConfigABC(ABC):
    """Abstract base class for configuration classes.

    Subclasses must implement dictionary/YAML deserialization and parameter
    validation; serialization (`to_dict`/`to_yaml`) is provided here.
    """

    @staticmethod
    @abstractmethod
    def from_dict(config_dict: Dict[str, Any]):
        """Create an instance of the configuration from a dictionary."""

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration into a dictionary.

        Path-valued fields are converted to strings so the result is
        directly YAML-serializable.
        """
        return {
            k: str(v) if isinstance(v, Path) else v for k, v in self.__dict__.items()
        }

    @staticmethod
    @abstractmethod
    def from_yaml(file_path: str):
        """Deserialize a YAML file into a configuration object."""

    def to_yaml(self, file_path: str):
        """Serializes the configuration to a YAML file.

        :param file_path: The path to the output YAML file.
        """
        with open(file_path, "w", encoding="utf-8") as file:
            yaml.dump(self.to_dict(), file)

    @abstractmethod
    def _validate_params(self, params: Dict[str, Any]):
        """Validate configuration parameters."""

    def __post_init__(self):
        """Validates the configuration parameters."""
        # call _validate_params method after initialization
        params = self.to_dict()
        self._validate_params(params)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
class RuleExtractionConfig(ConfigABC):
    """Configuration class for extracting reaction rules.

    :param multicenter_rules: If True, extracts a single rule
        encompassing all centers. If False, extracts separate reaction
        rules for each reaction center in a multicenter reaction.
    :param as_query_container: If True, the extracted rules are
        generated as QueryContainer objects, analogous to SMARTS objects
        for pattern matching in chemical structures.
    :param reverse_rule: If True, reverses the direction of the reaction
        for rule extraction.
    :param reactor_validation: If True, validates each generated rule in
        a chemical reactor to ensure correct generation of products from
        reactants.
    :param include_func_groups: If True, includes specific functional
        groups in the reaction rule in addition to the reaction center
        and its environment.
    :param func_groups_list: A list of functional groups (SMARTS strings)
        to be considered when include_func_groups is True. After
        initialization the strings are replaced by parsed query objects.
    :param include_rings: If True, includes ring structures in the
        reaction rules.
    :param keep_leaving_groups: If True, retains leaving groups in the
        extracted reaction rule.
    :param keep_incoming_groups: If True, retains incoming groups in the
        extracted reaction rule.
    :param keep_reagents: If True, includes reagents in the extracted
        reaction rule.
    :param environment_atom_count: Defines the size of the environment
        around the reaction center to be included in the rule (0 for
        only the reaction center, 1 for the first environment, etc.).
    :param min_popularity: Minimum number of times a rule must be
        applied to be considered for further analysis.
    :param keep_metadata: If True, retains metadata associated with the
        reaction in the extracted rule.
    :param single_reactant_only: If True, includes only reaction rules
        with a single reactant molecule.
    :param atom_info_retention: Controls the amount of information about
        each atom to retain, keyed by 'reaction_center' and
        'environment'. Missing sub-keys are filled from the defaults.
    """

    # default low-level parameters
    single_reactant_only: bool = True
    keep_metadata: bool = False
    reactor_validation: bool = True
    reverse_rule: bool = True
    as_query_container: bool = True
    include_func_groups: bool = False
    func_groups_list: List[str] = field(default_factory=list)

    # adjustable parameters
    environment_atom_count: int = 1
    min_popularity: int = 3
    include_rings: bool = True
    multicenter_rules: bool = True
    keep_leaving_groups: bool = True
    keep_incoming_groups: bool = True
    keep_reagents: bool = False
    atom_info_retention: Dict[str, Dict[str, bool]] = field(default_factory=dict)

    def __post_init__(self):
        # Fill in atom-info defaults BEFORE validation, otherwise the empty
        # default dict fails the required-keys check in _validate_params.
        self._initialize_default_atom_info_retention()
        # Single validation pass (ConfigABC.__post_init__ calls _validate_params).
        super().__post_init__()
        # Parse SMARTS strings last: validation expects raw strings.
        self._parse_functional_groups()

    def _initialize_default_atom_info_retention(self):
        """Populate ``atom_info_retention``, merging user values over defaults."""
        default_atom_info = {
            "reaction_center": {
                "neighbors": True,
                "hybridization": True,
                "implicit_hydrogens": False,
                "ring_sizes": False,
            },
            "environment": {
                "neighbors": False,
                "hybridization": False,
                "implicit_hydrogens": False,
                "ring_sizes": False,
            },
        }

        if not self.atom_info_retention:
            self.atom_info_retention = default_atom_info
        else:
            # Start from the defaults and overlay any user-provided values,
            # so partially specified dictionaries are completed instead of
            # raising (the original code updated the user dict with itself,
            # which was a no-op and failed on missing keys).
            for key, defaults in default_atom_info.items():
                merged = {**defaults, **self.atom_info_retention.get(key, {})}
                self.atom_info_retention[key] = merged

    def _parse_functional_groups(self):
        """Convert SMARTS strings in ``func_groups_list`` into query objects.

        Groups that fail to parse are reported and skipped.
        """
        parsed_groups = []
        for group_smarts in self.func_groups_list:
            try:
                query = smarts(group_smarts)
                parsed_groups.append(query)
            except Exception as e:
                print(f"Functional group {group_smarts} was not parsed because of {e}")
        self.func_groups_list = parsed_groups

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "RuleExtractionConfig":
        """Create a configuration instance from a dictionary."""
        return RuleExtractionConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "RuleExtractionConfig":
        """Deserialize a YAML file into a configuration instance."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return RuleExtractionConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Check types and structure of all parameters; raise ValueError on error."""

        if not isinstance(params["multicenter_rules"], bool):
            raise ValueError("multicenter_rules must be a boolean.")

        if not isinstance(params["as_query_container"], bool):
            raise ValueError("as_query_container must be a boolean.")

        if not isinstance(params["reverse_rule"], bool):
            raise ValueError("reverse_rule must be a boolean.")

        if not isinstance(params["reactor_validation"], bool):
            raise ValueError("reactor_validation must be a boolean.")

        if not isinstance(params["include_func_groups"], bool):
            raise ValueError("include_func_groups must be a boolean.")

        if params["func_groups_list"] is not None and not all(
            isinstance(group, str) for group in params["func_groups_list"]
        ):
            raise ValueError("func_groups_list must be a list of SMARTS.")

        if not isinstance(params["include_rings"], bool):
            raise ValueError("include_rings must be a boolean.")

        if not isinstance(params["keep_leaving_groups"], bool):
            raise ValueError("keep_leaving_groups must be a boolean.")

        if not isinstance(params["keep_incoming_groups"], bool):
            raise ValueError("keep_incoming_groups must be a boolean.")

        if not isinstance(params["keep_reagents"], bool):
            raise ValueError("keep_reagents must be a boolean.")

        if not isinstance(params["environment_atom_count"], int):
            raise ValueError("environment_atom_count must be an integer.")

        if not isinstance(params["min_popularity"], int):
            raise ValueError("min_popularity must be an integer.")

        if not isinstance(params["keep_metadata"], bool):
            raise ValueError("keep_metadata must be a boolean.")

        if not isinstance(params["single_reactant_only"], bool):
            raise ValueError("single_reactant_only must be a boolean.")

        if params["atom_info_retention"] is not None:
            if not isinstance(params["atom_info_retention"], dict):
                raise ValueError("atom_info_retention must be a dictionary.")

            required_keys = {"reaction_center", "environment"}
            if not required_keys.issubset(params["atom_info_retention"]):
                missing_keys = required_keys - set(params["atom_info_retention"].keys())
                raise ValueError(
                    f"atom_info_retention missing required keys: {missing_keys}"
                )

            for key, value in params["atom_info_retention"].items():
                if key not in required_keys:
                    raise ValueError(f"Unexpected key in atom_info_retention: {key}")

                expected_subkeys = {
                    "neighbors",
                    "hybridization",
                    "implicit_hydrogens",
                    "ring_sizes",
                }
                if not isinstance(value, dict) or not expected_subkeys.issubset(value):
                    missing_subkeys = expected_subkeys - set(value.keys())
                    raise ValueError(
                        f"Invalid structure for {key} in atom_info_retention. Missing subkeys: {missing_subkeys}"
                    )

                for subkey, subvalue in value.items():
                    if not isinstance(subvalue, bool):
                        raise ValueError(
                            f"Value for {subkey} in {key} of atom_info_retention must be boolean."
                        )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@dataclass
class PolicyNetworkConfig(ConfigABC):
    """Configuration class for the policy network.

    :param vector_dim: Dimension of the input vectors.
    :param batch_size: Number of samples per batch.
    :param dropout: Dropout rate for regularization.
    :param learning_rate: Learning rate for the optimizer.
    :param num_conv_layers: Number of convolutional layers in the network.
    :param num_epoch: Number of training epochs.
    :param policy_type: Mode of operation, either 'filtering' or 'ranking'.
    """

    policy_type: str = "ranking"
    vector_dim: int = 256
    batch_size: int = 500
    dropout: float = 0.4
    learning_rate: float = 0.008
    num_conv_layers: int = 5
    num_epoch: int = 100
    weights_path: str = None  # optional path to pretrained weights

    # for filtering policy
    priority_rules_fraction: float = 0.5
    rule_prob_threshold: float = 0.0
    top_rules: int = 50

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "PolicyNetworkConfig":
        """Build a configuration object from a plain dictionary."""
        return PolicyNetworkConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "PolicyNetworkConfig":
        """Build a configuration object from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as source:
            loaded = yaml.safe_load(source)
        return PolicyNetworkConfig.from_dict(loaded)

    def _validate_params(self, params: Dict[str, Any]):
        """Check types and value ranges of every parameter; raise ValueError otherwise."""
        if params["policy_type"] not in ("filtering", "ranking"):
            raise ValueError("policy_type must be either 'filtering' or 'ranking'.")

        # integer hyper-parameters that must be strictly positive
        for name in ("vector_dim", "batch_size", "num_conv_layers", "num_epoch"):
            if not isinstance(params[name], int) or params[name] <= 0:
                raise ValueError(f"{name} must be a positive integer.")

        dropout_value = params["dropout"]
        if not isinstance(dropout_value, float) or not 0.0 <= dropout_value <= 1.0:
            raise ValueError("dropout must be a float between 0.0 and 1.0.")

        lr_value = params["learning_rate"]
        if not isinstance(lr_value, float) or lr_value <= 0.0:
            raise ValueError("learning_rate must be a positive float.")

        prf_value = params["priority_rules_fraction"]
        if not isinstance(prf_value, float) or prf_value < 0.0:
            raise ValueError(
                "priority_rules_fraction must be a non-negative positive float."
            )

        rpt_value = params["rule_prob_threshold"]
        if not isinstance(rpt_value, float) or rpt_value < 0.0:
            raise ValueError("rule_prob_threshold must be a non-negative float.")

        if not isinstance(params["top_rules"], int) or params["top_rules"] <= 0:
            raise ValueError("top_rules must be a positive integer.")
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@dataclass
class ValueNetworkConfig(ConfigABC):
    """Configuration class for the value network.

    :param vector_dim: Dimension of the input vectors.
    :param batch_size: Number of samples per batch.
    :param dropout: Dropout rate for regularization.
    :param learning_rate: Learning rate for the optimizer.
    :param num_conv_layers: Number of convolutional layers in the network.
    :param num_epoch: Number of training epochs.
    """

    weights_path: str = None  # optional path to pretrained weights
    vector_dim: int = 256
    batch_size: int = 500
    dropout: float = 0.4
    learning_rate: float = 0.008
    num_conv_layers: int = 5
    num_epoch: int = 100

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ValueNetworkConfig":
        """Build a configuration object from a plain dictionary."""
        return ValueNetworkConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "ValueNetworkConfig":
        """Build a configuration object from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as source:
            loaded = yaml.safe_load(source)
        return ValueNetworkConfig.from_dict(loaded)

    def to_yaml(self, file_path: str):
        """Serialize the configuration to a YAML file."""
        with open(file_path, "w", encoding="utf-8") as out:
            yaml.dump(self.to_dict(), out)

    def _validate_params(self, params: Dict[str, Any]):
        """Check types and value ranges of every parameter; raise ValueError otherwise."""
        # integer hyper-parameters that must be strictly positive
        for name in ("vector_dim", "batch_size", "num_conv_layers", "num_epoch"):
            if not isinstance(params[name], int) or params[name] <= 0:
                raise ValueError(f"{name} must be a positive integer.")

        dropout_value = params["dropout"]
        if not isinstance(dropout_value, float) or not 0.0 <= dropout_value <= 1.0:
            raise ValueError("dropout must be a float between 0.0 and 1.0.")

        lr_value = params["learning_rate"]
        if not isinstance(lr_value, float) or lr_value <= 0.0:
            raise ValueError("learning_rate must be a positive float.")
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@dataclass
class TuningConfig(ConfigABC):
    """Configuration class for the network training.

    :param batch_size: The number of targets per batch in the planning simulation step.
    :param num_simulations: The number of planning simulations.
    """

    batch_size: int = 100
    num_simulations: int = 1

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TuningConfig":
        """Build a configuration object from a plain dictionary."""
        return TuningConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "TuningConfig":
        """Build a configuration object from a YAML file."""
        with open(file_path, "r", encoding="utf-8") as source:
            loaded = yaml.safe_load(source)
        return TuningConfig.from_dict(loaded)

    def _validate_params(self, params: Dict[str, Any]):
        """Check that batch_size is a strictly positive integer."""
        batch = params["batch_size"]
        if not isinstance(batch, int) or batch <= 0:
            raise ValueError("batch_size must be a positive integer.")
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
@dataclass
class TreeConfig(ConfigABC):
    """Configuration class for the tree search algorithm.

    :param max_iterations: The number of iterations to run the algorithm
        for.
    :param max_tree_size: The maximum number of nodes in the tree.
    :param max_time: The time limit (in seconds) for the algorithm to
        run.
    :param max_depth: The maximum depth of the tree.
    :param ucb_type: Type of UCB used in the search algorithm. Options
        are "puct", "uct", "value", defaults to "uct".
    :param c_ucb: The exploration-exploitation balance coefficient used
        in Upper Confidence Bound (UCB).
    :param backprop_type: Type of backpropagation algorithm. Options are
        "muzero", "cumulative", defaults to "muzero".
    :param search_strategy: The strategy used for tree search. Options
        are "expansion_first", "evaluation_first".
    :param exclude_small: Whether to exclude small molecules during the
        search.
    :param evaluation_agg: Method for aggregating evaluation scores.
        Options are "max", "average", defaults to "max".
    :param evaluation_type: The method used for evaluating nodes.
        Options are "random", "rollout", "gcn".
    :param init_node_value: Initial value for a new node.
    :param epsilon: A parameter in the epsilon-greedy search strategy
        representing the chance of random selection of reaction rules
        during the selection stage in Monte Carlo Tree Search,
        specifically during Upper Confidence Bound estimation. It
        balances between exploration and exploitation.
    :param min_mol_size: Defines the minimum size of a molecule that is
        have to be synthesized. Molecules with 6 or fewer heavy atoms
        are assumed to be building blocks by definition, thus setting
        the threshold for considering larger molecules in the search,
        defaults to 6.
    :param silent: Whether to suppress progress output.
    """

    max_iterations: int = 100
    max_tree_size: int = 1000000
    max_time: float = 600
    max_depth: int = 6
    ucb_type: str = "uct"
    c_ucb: float = 0.1
    backprop_type: str = "muzero"
    search_strategy: str = "expansion_first"
    exclude_small: bool = True
    evaluation_agg: str = "max"
    evaluation_type: str = "gcn"
    init_node_value: float = 0.0
    epsilon: float = 0.0
    min_mol_size: int = 6
    silent: bool = False

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TreeConfig":
        """Create a configuration instance from a dictionary."""
        return TreeConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "TreeConfig":
        """Deserialize a YAML file into a configuration instance."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return TreeConfig.from_dict(config_dict)

    def _validate_params(self, params):
        """Check types and value ranges of all parameters; raise on error."""
        if params["ucb_type"] not in ["puct", "uct", "value"]:
            raise ValueError(
                "Invalid ucb_type. Allowed values are 'puct', 'uct', 'value'."
            )
        if params["backprop_type"] not in ["muzero", "cumulative"]:
            raise ValueError(
                "Invalid backprop_type. Allowed values are 'muzero', 'cumulative'."
            )
        if params["evaluation_type"] not in ["random", "rollout", "gcn"]:
            raise ValueError(
                "Invalid evaluation_type. Allowed values are 'random', 'rollout', 'gcn'."
            )
        if params["evaluation_agg"] not in ["max", "average"]:
            raise ValueError(
                "Invalid evaluation_agg. Allowed values are 'max', 'average'."
            )
        if not isinstance(params["c_ucb"], float):
            raise TypeError("c_ucb must be a float.")
        if not isinstance(params["max_depth"], int) or params["max_depth"] < 1:
            raise ValueError("max_depth must be a positive integer.")
        if not isinstance(params["max_tree_size"], int) or params["max_tree_size"] < 1:
            raise ValueError("max_tree_size must be a positive integer.")
        if (
            not isinstance(params["max_iterations"], int)
            or params["max_iterations"] < 1
        ):
            raise ValueError("max_iterations must be a positive integer.")
        # max_time is declared as float; accept both int and float values
        # (the original check rejected legal float values like 600.0).
        if (
            not isinstance(params["max_time"], (int, float))
            or params["max_time"] < 1
        ):
            raise ValueError("max_time must be a positive number.")
        if not isinstance(params["exclude_small"], bool):
            raise TypeError("exclude_small must be a boolean.")
        if not isinstance(params["silent"], bool):
            raise TypeError("silent must be a boolean.")
        if not isinstance(params["init_node_value"], float):
            raise TypeError("init_node_value must be a float if provided.")
        if params["search_strategy"] not in ["expansion_first", "evaluation_first"]:
            raise ValueError(
                f"Invalid search_strategy: {params['search_strategy']}: "
                f"Allowed values are 'expansion_first', 'evaluation_first'"
            )
        # The original chained comparison (0 >= epsilon >= 1) could never be
        # true, so out-of-range epsilon values were silently accepted.
        if not isinstance(params["epsilon"], float) or not (
            0.0 <= params["epsilon"] <= 1.0
        ):
            raise ValueError("epsilon must be a float between 0 and 1.")
        if not isinstance(params["min_mol_size"], int) or params["min_mol_size"] < 0:
            raise ValueError("min_mol_size must be a non-negative integer.")
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def convert_config_to_dict(config_attr: "ConfigABC", config_type) -> Union[Dict, None]:
    """Converts a configuration attribute to a dictionary if it's either a dictionary or
    an instance of a specified configuration type.

    The ``ConfigABC`` annotation is a quoted forward reference and the return
    annotation uses ``Union`` instead of ``Dict | None`` (which requires
    Python >= 3.10 for non-generic unions), keeping the function importable on
    older interpreters and independent of definition order.

    :param config_attr: The configuration attribute to be converted.
    :param config_type: The type to check against for conversion.
    :return: The configuration attribute as a dictionary, or None if it's not an
        instance of the given type or dict.
    """
    if isinstance(config_attr, dict):
        return config_attr
    if isinstance(config_attr, config_type):
        return config_attr.to_dict()
    return None
|
synplan/utils/files.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing classes and functions needed for reactions/molecules data
|
| 2 |
+
reading/writing."""
|
| 3 |
+
|
| 4 |
+
from os.path import splitext
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Iterable, Union
|
| 7 |
+
|
| 8 |
+
from CGRtools import smiles
|
| 9 |
+
from CGRtools.containers import CGRContainer, MoleculeContainer, ReactionContainer
|
| 10 |
+
from CGRtools.files.RDFrw import RDFRead, RDFWrite
|
| 11 |
+
from CGRtools.files.SDFrw import SDFRead, SDFWrite
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FileHandler:
    """General class to handle chemical files.

    Maps the file extension to an internal file-type tag used by the
    Reader/Writer subclasses: .smi/.smiles -> SMI, .rdf -> RDF, .sdf -> SDF.
    """

    def __init__(self, filename: Union[str, Path], **kwargs):
        """General class to handle chemical files.

        :param filename: The path and name of the file.
        :raises ValueError: If the file extension is not recognized.
        :return: None.
        """
        self._file = None  # set by subclasses to the concrete reader/writer
        _, ext = splitext(filename)
        file_types = {".smi": "SMI", ".smiles": "SMI", ".rdf": "RDF", ".sdf": "SDF"}
        try:
            self._file_type = file_types[ext]
        except KeyError as err:
            # chain the original KeyError so the cause is visible in tracebacks
            raise ValueError("I don't know the file extension,", ext) from err

    def close(self):
        """Close the underlying file object."""
        self._file.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Reader(FileHandler):
    """General class to read reactions/molecules data files.

    Delegates iteration and length queries to the wrapped reader object
    created by subclasses.
    """

    def __init__(self, filename: Union[str, Path], **kwargs):
        """General class to read reactions/molecules data files.

        :param filename: The path and name of the file.
        :return: None.
        """
        super().__init__(filename, **kwargs)

    def __enter__(self):
        # expose the wrapped reader directly inside `with` blocks
        return self._file

    def __iter__(self):
        yield from self._file

    def __next__(self):
        return next(self._file)

    def __len__(self):
        return len(self._file)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SMILESRead:
    """Simplified reader for files containing one SMILES (molecule or reaction)
    string per line."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Simplified class to read files containing a SMILES (Molecules or Reaction)
        string per line.

        :param filename: The path and name of the SMILES file to parse.
        :return: None.
        """
        # **kwargs are accepted for interface compatibility with the other
        # readers (e.g. `ignore=True`) but are not used here.
        filename = str(Path(filename).resolve(strict=True))
        self._file = open(filename, "r", encoding="utf-8")
        self._data = self.__data()

    def __data(
        self,
    ) -> "Iterable[Union[ReactionContainer, CGRContainer, MoleculeContainer]]":
        """Lazily parse lines; blank lines are skipped, parsed objects get their
        original SMILES stored under meta['init_smiles']."""
        for line in iter(self._file.readline, ""):
            line = line.strip()
            if not line:
                # skip blank lines instead of feeding "" to the SMILES parser
                continue
            x = smiles(line)
            if isinstance(x, (ReactionContainer, CGRContainer, MoleculeContainer)):
                x.meta["init_smiles"] = line
                yield x

    def __enter__(self):
        return self

    def read(self):
        """Parse the whole SMILES file.

        :return: List of parsed molecules or reactions.
        """
        return list(self)

    def __iter__(self):
        return self._data

    def __next__(self):
        # advance the single shared generator (the original built a fresh
        # generator expression on every call, which obscured the state)
        return next(self._data)

    def close(self):
        """Close the underlying file object."""
        self._file.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class Writer(FileHandler):
    """General class to write chemical files."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """General class to write chemical files.

        :param filename: The path and name of the file.
        :param mapping: Whenever to save mapping or not.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        # remember whether atom-atom mapping should be preserved on output
        self._mapping = mapping

    def __enter__(self):
        return self
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class ReactionReader(Reader):
    """Reader for reaction files (.smi/.smiles or .rdf)."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Class to read reaction files.

        :param filename: The path and name of the file.
        :raises ValueError: If the file type is not SMI or RDF.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        if self._file_type == "RDF":
            self._file = RDFRead(filename, indexable=True, **kwargs)
        elif self._file_type == "SMI":
            self._file = SMILESRead(filename, **kwargs)
        else:
            raise ValueError("File type incompatible -", filename)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class ReactionWriter(Writer):
    """Writer for reaction files (.smi/.smiles or .rdf)."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """Class to write reaction files.

        :param filename: The path and name of the file.
        :param mapping: Whenever to save mapping or not.
        :raises ValueError: If the file type is not SMI or RDF.
        :return: None.
        """
        super().__init__(filename, mapping, **kwargs)
        if self._file_type == "RDF":
            self._file = RDFWrite(filename, append=False, **kwargs)
        elif self._file_type == "SMI":
            self._file = open(filename, "w", encoding="utf-8", **kwargs)
        else:
            raise ValueError("File type incompatible -", filename)

    def write(self, reaction: ReactionContainer):
        """Function to write a specific reaction to the file.

        :param reaction: The path and name of the file.
        :return: None.
        """
        if self._file_type == "RDF":
            self._file.write(reaction)
        elif self._file_type == "SMI":
            record = to_reaction_smiles_record(reaction)
            self._file.write(record + "\n")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class MoleculeReader(Reader):
    """Reader for molecule files (.smi/.smiles or .sdf)."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Class to read molecule files.

        :param filename: The path and name of the file.
        :raises ValueError: If the file type is not SMI or SDF.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        if self._file_type == "SDF":
            self._file = SDFRead(filename, indexable=True, **kwargs)
        elif self._file_type == "SMI":
            self._file = SMILESRead(filename, ignore=True, **kwargs)
        else:
            raise ValueError("File type incompatible -", filename)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class MoleculeWriter(Writer):
    """Writer for molecule files (.smi/.smiles or .sdf)."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """Class to write molecule files.

        :param filename: The path and name of the file.
        :param mapping: Whenever to save mapping or not.
        :raises ValueError: If the file type is not SMI or SDF.
        :return: None.
        """
        super().__init__(filename, mapping, **kwargs)
        if self._file_type == "SDF":
            self._file = SDFWrite(filename, append=False, **kwargs)
        elif self._file_type == "SMI":
            self._file = open(filename, "w", encoding="utf-8", **kwargs)
        else:
            raise ValueError("File type incompatible -", filename)

    def write(self, molecule: MoleculeContainer):
        """Function to write a specific molecule to the file.

        :param molecule: The path and name of the file.
        :return: None.
        """
        if self._file_type == "SDF":
            self._file.write(molecule)
        elif self._file_type == "SMI":
            self._file.write(str(molecule) + "\n")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def to_reaction_smiles_record(reaction: "ReactionContainer") -> str:
    """Converts the reaction to the SMILES record. Needed for reaction/molecule writers.

    The record is the mapped reaction SMILES followed by the reaction meta
    values (sorted by meta key), all tab-separated. Newlines inside a meta
    value are replaced with ";" so the record stays on a single line.

    :param reaction: The reaction to be written (a ReactionContainer, or an
        already-formatted SMILES string which is returned unchanged).
    :return: The SMILES record to be written.
    """

    # Already a SMILES string: nothing to convert.
    if isinstance(reaction, str):
        return reaction

    # "m" format spec produces the atom-mapped reaction SMILES.
    reaction_record = [format(reaction, "m")]
    # Sort by meta key so the column order is deterministic across reactions.
    sorted_meta = sorted(reaction.meta.items(), key=lambda x: x[0])
    for _, meta_info in sorted_meta:
        # BUGFIX: the value was previously overwritten with "" before being
        # written, which emitted empty columns for every meta entry.
        # Keep the record single-line: join multi-line values with ";".
        reaction_record.append(";".join(str(meta_info).split("\n")))
    return "\t".join(reaction_record)
|
synplan/utils/loading.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing functions for loading reaction rules, building blocks and
|
| 2 |
+
retrosynthetic models."""
|
| 3 |
+
|
| 4 |
+
import functools
|
| 5 |
+
import pickle
|
| 6 |
+
import zipfile
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Set, Union
|
| 9 |
+
|
| 10 |
+
from CGRtools.reactor.reactor import Reactor
|
| 11 |
+
from torch import device
|
| 12 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from synplan.ml.networks.policy import PolicyNetwork
|
| 16 |
+
from synplan.ml.networks.value import ValueNetwork
|
| 17 |
+
from synplan.utils.files import MoleculeReader
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def download_unpack_data(filename, subfolder, save_to="."):
    """Downloads a data file from the SynPlanner HF Hub repository, unpacking zips.

    :param filename: The name of the file in the repository.
    :param subfolder: The repository subfolder containing the file.
    :param save_to: The local directory to store the file in. Default ".".
    :return: The Path of the downloaded (and extracted, if zipped) file.
    """
    # Normalize to an absolute Path regardless of whether a str or Path was
    # passed (previously Path inputs were not resolved).
    save_to = Path(save_to).resolve()
    # parents=True so nested target directories are created as needed.
    save_to.mkdir(parents=True, exist_ok=True)

    # Download the (possibly zipped) file from the repository.
    file_path = Path(
        hf_hub_download(
            repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
            filename=filename,
            subfolder=subfolder,
            local_dir=save_to,
        )
    )

    if file_path.suffix != ".zip":
        return file_path

    with zipfile.ZipFile(file_path, "r") as zip_ref:
        # The archive is expected to contain a single file; extract it and
        # report its path.
        zip_ref.extractall(save_to)
        extracted_file = save_to / zip_ref.namelist()[0]

    # Remove the archive once its content is extracted.
    file_path.unlink()
    return extracted_file
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def download_all_data(save_to="."):
    """Downloads the full SynPlanner data snapshot and unpacks every archive in it.

    Archive members that already exist on disk are left untouched, so the
    function is safe to re-run.

    :param save_to: The local directory to store the snapshot in. Default ".".
    :return: None.
    """
    snapshot_dir = Path(
        snapshot_download(
            repo_id="Laboratoire-De-Chemoinformatique/SynPlanner", local_dir=save_to
        )
    ).resolve()
    for zip_file in snapshot_dir.rglob("*.zip"):
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            for file_name in zip_ref.namelist():
                # Only extract members that are not already present.
                if (zip_file.parent / file_name).exists():
                    continue
                zip_ref.extract(file_name, zip_file.parent)
                print(f"Extracted {file_name} to {zip_file.parent}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@functools.lru_cache(maxsize=None)
def load_reaction_rules(file: str) -> List[Reactor]:
    """Loads the reaction rules from a pickle file and converts them into a list of
    Reactor objects if necessary.

    Results are cached per file path, so repeated calls do not re-read the file.

    :param file: The path to the pickle file that stores the reaction rules.
    :return: A list of reaction rules as Reactor objects.
    """

    with open(file, "rb") as rules_file:
        rules = pickle.load(rules_file)

    # Stored rules may be (template, metadata) pairs rather than ready-made
    # Reactor objects; wrap the templates when that is the case.
    if not isinstance(rules[0][0], Reactor):
        rules = [Reactor(template) for template, _ in rules]

    return rules
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@functools.lru_cache(maxsize=None)
def load_building_blocks(
    building_blocks_path: Union[str, Path], standardize: bool = True
) -> Set[str]:
    """Loads building blocks data from a file and returns a set of building
    block SMILES.

    Results are cached per (path, standardize) pair.

    :param building_blocks_path: The path to the file containing the building
        blocks (one SMILES per line, optionally followed by other columns).
    :param standardize: Flag if building blocks have to be standardized
        (canonicalized, stereo cleaned) before loading. Default=True.
    :return: The set of building blocks smiles.
    """

    building_blocks_path = Path(building_blocks_path).resolve()
    assert building_blocks_path.suffix in (".smi", ".smiles")

    building_blocks_smiles = set()
    if standardize:
        with MoleculeReader(building_blocks_path) as molecules:
            for mol in tqdm(
                molecules,
                desc="Number of building blocks processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):
                try:
                    mol.canonicalize()
                    mol.clean_stereo()
                    building_blocks_smiles.add(str(mol))
                except Exception:  # e.g. InvalidAromaticRing from canonicalize()
                    # Best-effort: skip molecules that cannot be standardized.
                    # (Was a bare except, which also swallowed KeyboardInterrupt.)
                    pass
    else:
        with open(building_blocks_path, "r", encoding="utf-8") as inp:
            for line in inp:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                # First whitespace-separated column is the SMILES.
                building_blocks_smiles.add(fields[0])

    return building_blocks_smiles
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def load_value_net(
    model_class: ValueNetwork, value_network_path: Union[str, Path]
) -> ValueNetwork:
    """Loads the value network from a checkpoint file.

    The weights are always mapped onto the CPU.

    :param value_network_path: The path to the file storing value network weights.
    :param model_class: The model class to be loaded.
    :return: The loaded value network.
    """

    return model_class.load_from_checkpoint(value_network_path, device("cpu"))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def load_policy_net(
    model_class: PolicyNetwork, policy_network_path: Union[str, Path]
) -> PolicyNetwork:
    """Loads the policy network from a checkpoint file.

    The weights are always mapped onto the CPU and the batch size is fixed to 1
    for inference.

    :param policy_network_path: The path to the file storing policy network weights.
    :param model_class: The model class to be loaded.
    :return: The loaded policy network.
    """

    return model_class.load_from_checkpoint(
        policy_network_path, device("cpu"), batch_size=1
    )
|
synplan/utils/logging.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generic logging helpers for scripts, notebooks and Ray clusters.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
import logging, sys, os, warnings
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Iterable, Optional
|
| 10 |
+
from IPython import get_ipython
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# --------------------------------------------------------------------------- #
|
| 14 |
+
# Helper classes #
|
| 15 |
+
# --------------------------------------------------------------------------- #
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DisableLogger:
    """Context manager that suppresses *all* logging inside its scope."""

    def __enter__(self):
        # Raise the global disable threshold so no record gets emitted.
        logging.disable(logging.CRITICAL)

    def __exit__(self, *exc_info):
        # Restore normal logging no matter how the block exited.
        logging.disable(logging.NOTSET)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class HiddenPrints:
    """Context manager that suppresses *print* output inside its scope."""

    def __enter__(self):
        # Swap stdout for a null sink; keep the original to restore later.
        self._orig = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, *exc_info):
        # Close the null sink, then put the original stream back.
        sys.stdout.close()
        sys.stdout = self._orig
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# --------------------------------------------------------------------------- #
|
| 41 |
+
# Notebook‑aware console handler #
|
| 42 |
+
# --------------------------------------------------------------------------- #
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _in_notebook() -> bool:
    """Return True when running inside a Jupyter (ZMQ) kernel."""
    shell = get_ipython()
    if not shell:
        # Plain script / terminal IPython absent.
        return False
    # Jupyter kernels expose the ZMQInteractiveShell class.
    return shell.__class__.__name__ == "ZMQInteractiveShell"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TqdmHandler(logging.StreamHandler):
    """Console handler that emits via tqdm.write so log lines don't break progress bars."""

    def emit(self, record):
        try:
            from tqdm import tqdm
        except ModuleNotFoundError:
            # tqdm unavailable: fall back to the plain stream handler.
            super().emit(record)
        else:
            tqdm.write(self.format(record), end=self.terminator)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# --------------------------------------------------------------------------- #
|
| 63 |
+
# Public initialisation API #
|
| 64 |
+
# --------------------------------------------------------------------------- #
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def init_logger(
    *,
    name: str = "app",
    console_level: str | int = "ERROR",
    file_level: str | int = "INFO",
    log_dir: str | os.PathLike = ".",
    redirect_tqdm: bool = True,
) -> tuple[logging.Logger, str | None]:
    """
    Initialise (or fetch) a namespaced logger that works in scripts &
    notebooks. Idempotent - safe to call multiple times.

    Returns
    -------
    tuple[logging.Logger, str | None]
        The configured logger and the path of its session log file
        (``None`` only if the logger was configured earlier without a
        file handler).
    """
    logger = logging.getLogger(name)
    if logger.handlers:  # already configured: reuse existing handlers
        # BUGFIX: this path used to return a bare logger while the normal
        # path returns a (logger, path) tuple; recover the session log file
        # from the existing file handler so the return shape is consistent.
        log_file_path = next(
            (
                h.baseFilename
                for h in logger.handlers
                if isinstance(h, logging.FileHandler)
            ),
            None,
        )
        return logger, log_file_path

    logger.setLevel("DEBUG")  # capture everything; handlers filter

    # console / notebook handler
    if _in_notebook() or (redirect_tqdm and "tqdm" in sys.modules):
        ch: logging.Handler = TqdmHandler()
    else:
        ch = logging.StreamHandler(sys.stderr)
    ch.setLevel(console_level)
    ch.setFormatter(
        logging.Formatter(
            "%(asctime)s | %(levelname)-8s | %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    logger.addHandler(ch)

    # file handler (a fresh, timestamped file per session)
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fh = logging.FileHandler(Path(log_dir) / f"{name}_{stamp}.log", encoding="utf-8")
    fh.setLevel(file_level)
    fh.setFormatter(
        logging.Formatter(
            "%(asctime)s | %(name)s | %(levelname)-8s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(fh)

    log_file_path = fh.baseFilename
    logger.info("Logging initialised → %s", log_file_path)
    return logger, log_file_path
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# --------------------------------------------------------------------------- #
|
| 124 |
+
# Optional Ray‑specific configuration helpers #
|
| 125 |
+
# --------------------------------------------------------------------------- #
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def init_ray_logging(
    *,
    python_level: str | int = "ERROR",
    backend_level: str = "error",
    log_to_driver: bool = False,
    filter_userwarnings: bool = True,
) -> "ray.LoggingConfig":
    """
    Prepare environment + Ray LoggingConfig **before** `ray.init()`.

    Parameters
    ----------
    python_level : str | int
        Level applied to the Python-side ray loggers (and propagated to
        workers via the returned LoggingConfig).
    backend_level : str
        Level for the C++ backend, set through ``RAY_BACKEND_LOG_LEVEL``.
    log_to_driver : bool
        Whether worker logs are forwarded to the driver process.
    filter_userwarnings : bool
        If True, globally ignore ``UserWarning`` (process-wide side effect).

    Returns
    -------
    ray.LoggingConfig
        Pass as `logging_config=` argument to `ray.init()`.
    """
    # 1) silence C++ backend (raylet / plasma) BEFORE importing ray
    # (the env var is only read at import/startup time)
    os.environ.setdefault("RAY_BACKEND_LOG_LEVEL", backend_level)

    # 2) optional warnings filter
    if filter_userwarnings:
        warnings.filterwarnings("ignore", category=UserWarning)

    import ray  # local import to avoid hard dep

    # 3) global Python logger levels for every worker
    ray_logger_names: Iterable[str] = (
        "ray",
        "ray.worker",
        "ray.runtime",
        "ray.dashboard",
        "ray.tune",
        "ray.serve",
    )
    for n in ray_logger_names:
        logging.getLogger(n).setLevel(python_level)

    # 4) build LoggingConfig that propagates to workers
    return ray.LoggingConfig(
        log_to_driver=log_to_driver,
        log_level=python_level,
    )
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def silence_logger(
    logger_name: str,
    level: int | str = logging.ERROR,
):
    """
    Raise the threshold of a chatty library logger.

    Call at the *top* of every `@ray.remote` function or actor `__init__`
    to silence the named library **inside the worker**.
    """
    target = logging.getLogger(logger_name)
    target.setLevel(level)
|
synplan/utils/visualisation.py
ADDED
|
@@ -0,0 +1,1365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Module containing functions for analysis and visualization of the built tree."""
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
from itertools import count, islice
|
| 5 |
+
from collections import deque
|
| 6 |
+
from typing import Any, Dict, List, Union
|
| 7 |
+
|
| 8 |
+
from CGRtools.containers.molecule import MoleculeContainer
|
| 9 |
+
from CGRtools import smiles as read_smiles
|
| 10 |
+
|
| 11 |
+
from synplan.chem.reaction_routes.visualisation import (
|
| 12 |
+
cgr_display,
|
| 13 |
+
depict_custom_reaction,
|
| 14 |
+
)
|
| 15 |
+
from synplan.chem.reaction_routes.io import make_dict
|
| 16 |
+
from synplan.mcts.tree import Tree
|
| 17 |
+
|
| 18 |
+
from IPython.display import display, HTML
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_child_nodes(
    tree: Tree,
    molecule: MoleculeContainer,
    graph: Dict[MoleculeContainer, List[MoleculeContainer]],
) -> Dict[str, Any]:
    """Extracts the child nodes of the given molecule.

    :param tree: The built tree.
    :param molecule: The molecule in the tree from which to extract child nodes.
    :param graph: The relationship between the given molecule and child nodes.
    :return: The dict with extracted child nodes, or an empty list for a leaf
        molecule (callers test truthiness of the result).
    """

    # Leaf molecule: no outgoing reaction step recorded for it.
    if molecule not in graph:
        return []

    nodes = []
    for precursor in graph[molecule]:
        entry = {
            "smiles": str(precursor),
            "type": "mol",
            "in_stock": str(precursor) in tree.building_blocks,
        }
        # Recurse: attach the precursor's own reaction node when it has one.
        subtree = get_child_nodes(tree, precursor, graph)
        if subtree:
            entry["children"] = [subtree]
        nodes.append(entry)
    return {"type": "reaction", "children": nodes}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def extract_routes(
    tree: Tree, extended: bool = False, min_mol_size: int = 0
) -> List[Dict[str, Any]]:
    """Takes the target and the dictionary of successors and predecessors and returns a
    list of dictionaries that contain the target and the list of successors.

    :param tree: The built tree.
    :param extended: If True, generates the extended route representation.
    :param min_mol_size: If the size of the Precursor is equal or smaller than
        min_mol_size it is automatically classified as building block.
    :return: A list of dictionaries. Each dictionary contains a target, a list of
        children, and a boolean indicating whether the target is in building_blocks.
    """
    # Node 1 is the tree root; its first precursor is the target molecule.
    target = tree.nodes[1].precursors_to_expand[0].molecule
    target_in_stock = tree.nodes[1].curr_precursor.is_building_block(
        tree.building_blocks, min_mol_size
    )

    # append encoded routes to list
    routes_block = []
    winning_nodes = []
    if extended:
        # collect routes: every solved node counts, not just the recorded winners
        for i, node in tree.nodes.items():
            if node.is_solved():
                winning_nodes.append(i)
    else:
        winning_nodes = tree.winning_nodes
    if winning_nodes:
        for winning_node in winning_nodes:
            # Create graph for route: map each expanded molecule to the
            # precursors produced at the next step, and record predecessors.
            nodes = tree.route_to_node(winning_node)
            graph, pred = {}, {}
            for before, after in zip(nodes, nodes[1:]):
                before = before.curr_precursor.molecule
                # Chained assignment: `after` is rebound to the molecule list,
                # which is also stored as graph[before].
                graph[before] = after = [x.molecule for x in after.new_precursors]
                for x in after:
                    pred[x] = before

            # Encode this route rooted at the target.
            routes_block.append(
                {
                    "type": "mol",
                    "smiles": str(target),
                    "in_stock": target_in_stock,
                    "children": [get_child_nodes(tree, target, graph)],
                }
            )
    else:
        # No solved routes: emit the bare target with no children.
        routes_block = [
            {
                "type": "mol",
                "smiles": str(target),
                "in_stock": target_in_stock,
                "children": [],
            }
        ]
    return routes_block
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def render_svg(pred, columns, box_colors):
    """
    Renders an SVG representation of a retrosynthetic route.

    This function takes the predicted reaction steps, the molecules organized
    into columns representing reaction stages, and a mapping of molecule status
    to box colors, and generates an SVG string visualizing the route. It
    calculates positions for molecules and arrows, and constructs the SVG
    elements.

    Args:
        pred (tuple): A tuple of tuples representing the predicted reaction
                      steps. Each inner tuple is (source_molecule_index,
                      target_molecule_index). The indices correspond to the
                      flattened list of molecules across all columns.
        columns (list): A list of lists, where each inner list contains
                        Molecule objects for a specific stage (column) in the
                        retrosynthetic route.
        box_colors (dict): A dictionary mapping molecule status strings (e.g.,
                           'target', 'mulecule', 'instock') to SVG color strings
                           for the boxes around the molecules.

    Returns:
        str: A string containing the complete SVG code for the retrosynthetic
             route visualization.
    """
    x_shift = 0.0
    c_max_x = 0.0
    c_max_y = 0.0
    render = []
    # cx and cy both enumerate molecules in flattened column order: cx is
    # consumed in the first (X-layout) pass, cy in the second (Y-layout) pass,
    # so arrow_points[i] ends up as [left_x, right_x, center_y, ...].
    cx = count()
    cy = count()
    arrow_points = {}
    for ms in columns:
        heights = []
        for m in ms:
            # clean2d() computes 2D coordinates; m._plane maps atom -> (x, y).
            # NOTE(review): relies on CGRtools private attribute _plane.
            m.clean2d()
            # X-shift for target
            min_x = min(x for x, y in m._plane.values()) - x_shift
            min_y = min(y for x, y in m._plane.values())
            m._plane = {n: (x - min_x, y - min_y) for n, (x, y) in m._plane.items()}
            max_x = max(x for x, y in m._plane.values())

            c_max_x = max(c_max_x, max_x)

            arrow_points[next(cx)] = [x_shift, max_x]
            heights.append(max(y for x, y in m._plane.values()))

        x_shift = c_max_x + 5.0  # between columns gap
        # calculate Y-shift
        y_shift = sum(heights) + 3.0 * (len(heights) - 1)

        c_max_y = max(c_max_y, y_shift)

        y_shift /= 2.0
        for m, h in zip(ms, heights):
            # Center the column vertically around zero.
            m._plane = {n: (x, y - y_shift) for n, (x, y) in m._plane.items()}

            # calculate coordinates for boxes
            max_x = max(x for x, y in m._plane.values()) + 0.9  # max x
            min_x = min(x for x, y in m._plane.values()) - 0.6  # min x
            max_y = -(max(y for x, y in m._plane.values()) + 0.45)  # max y
            min_y = -(min(y for x, y in m._plane.values()) - 0.45)  # min y
            x_delta = abs(max_x - min_x)
            y_delta = abs(max_y - min_y)
            # Colored status box drawn behind each molecule.
            box = (
                f'<rect x="{min_x}" y="{max_y}" rx="{y_delta * 0.1}" ry="{y_delta * 0.1}" width="{x_delta}" height="{y_delta}"'
                f' stroke="black" stroke-width=".0025" fill="{box_colors[m.meta["status"]]}" fill-opacity="0.30"/>'
            )
            arrow_points[next(cy)].append(y_shift - h / 2.0)
            y_shift -= h + 3.0
            depicted_molecule = list(m.depict(embedding=True))[:3]
            depicted_molecule.append(box)
            render.append(depicted_molecule)

    # calculate mid-X coordinate to draw square arrows
    graph = {}
    for s, p in pred:
        try:
            graph[s].append(p)
        except KeyError:
            graph[s] = [p]
    for s, ps in graph.items():
        # Use one shared elbow X per source so sibling arrows align.
        mid_x = float("-inf")
        for p in ps:
            s_min_x, s_max, s_y = arrow_points[s][:3]  # s
            p_min_x, p_max, p_y = arrow_points[p][:3]  # p
            p_max += 1
            mid = p_max + (s_min_x - p_max) / 3
            mid_x = max(mid_x, mid)
        for p in ps:
            arrow_points[p].append(mid_x)

    config = MoleculeContainer._render_config
    font_size = config["font_size"]
    font125 = 1.25 * font_size
    width = c_max_x + 4.0 * font_size  # 3.0 by default
    height = c_max_y + 3.5 * font_size  # 2.5 by default
    box_y = height / 2.0
    svg = [
        f'<svg width="{0.6 * width:.2f}cm" height="{0.6 * height:.2f}cm" '
        f'viewBox="{-font125:.2f} {-box_y:.2f} {width:.2f} '
        f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">',
        ' <defs>\n <marker id="arrow" markerWidth="10" markerHeight="10" '
        'refX="0" refY="3" orient="auto">\n <path d="M0,0 L0,6 L9,3"/>\n </marker>\n </defs>',
    ]

    for s, p in pred:
        s_min_x, s_max, s_y = arrow_points[s][:3]
        p_min_x, p_max, p_y = arrow_points[p][:3]
        p_max += 1
        mid_x = arrow_points[p][-1]  # p_max + (s_min_x - p_max) / 3
        arrow = f""" <polyline points="{p_max:.2f} {p_y:.2f}, {mid_x:.2f} {p_y:.2f}, {mid_x:.2f} {s_y:.2f}, {s_min_x - 1.:.2f} {s_y:.2f}"
 fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>"""
        if p_y != s_y:
            # Dot at the elbow where the arrow turns between rows.
            arrow += f' <circle cx="{mid_x}" cy="{p_y}" r="0.1"/>'
        svg.append(arrow)
    for atoms, bonds, masks, box in render:
        molecule_svg = MoleculeContainer._graph_svg(
            atoms, bonds, masks, -font125, -box_y, width, height
        )
        molecule_svg.insert(1, box)
        svg.extend(molecule_svg)
    svg.append("</svg>")
    return "\n".join(svg)
|
| 236 |
+
|
| 237 |
+
def get_route_svg_mod(tree: Tree, node_id: int) -> str:
    """
    Visualizes the full retrosynthetic route from the target to a given node.

    Generates an SVG image for the synthetic path from the target molecule to
    the specified node_id, correctly handling paths that have not been fully
    resolved to building blocks. The layout follows standard retrosynthetic
    analysis: the target on the right, precursors arranged in columns to the
    left.

    :param tree: The built MCTS tree.
    :param node_id: The ID of the node to which the route should be visualized.
    :return: A string containing the SVG visualization of the route.
    """
    # Box colors for molecule status (same palette as get_route_svg).
    box_colors = {
        "target": "#98EEFF",  # Light Blue for the main target
        "mulecule": "#F0AB90",  # Peach for intermediates not in stock
        "instock": "#9BFAB3",  # Light Green for building blocks
    }

    # Reaction steps in retrosynthetic order (target-side first).
    retro_reactions = list(reversed(tree.synthesis_route(node_id)))

    # Root node with no preceding reactions: render the bare target molecule.
    if not retro_reactions:
        target_node = tree.nodes.get(node_id)
        if not target_node:
            return ""
        molecule = target_node.curr_precursor.molecule
        molecule.meta["status"] = "target"
        return render_svg(tuple(), [[molecule]], box_colors)

    # Map every unique molecule SMILES to its MoleculeContainer object.
    mol_map = {str(m): m for rxn in retro_reactions for m in rxn.reactants + rxn.products}

    # Tag each molecule with its stock status.
    for smiles, molecule in mol_map.items():
        molecule.meta["status"] = "instock" if smiles in tree.building_blocks else "mulecule"

    # The final target is the product of the first retrosynthetic reaction.
    target_molecule = retro_reactions[0].products[0]
    target_molecule.meta["status"] = "target"
    mol_map[str(target_molecule)] = target_molecule

    # --- Build columns from left to right based on reaction dependencies ---
    # Molecules that appear as a product of some step in the path.
    products_smiles = {str(p) for rxn in retro_reactions for p in rxn.products}

    # Leftmost column: reactants that are not produced by any other step.
    leftmost_smiles = {str(m) for rxn in retro_reactions for m in rxn.reactants} - products_smiles
    if not leftmost_smiles:  # Fallback for simple A->B routes
        leftmost_smiles = {str(m) for m in retro_reactions[-1].reactants}

    columns = [[mol_map[s] for s in leftmost_smiles]]
    placed_smiles = set(leftmost_smiles)

    # Grow columns until every molecule has been placed.
    while len(placed_smiles) < len(mol_map):
        next_products = set()
        for rxn in retro_reactions:
            # A reaction whose reactants are all placed contributes its
            # not-yet-placed products to the next column.
            if all(str(reactant) in placed_smiles for reactant in rxn.reactants):
                for product in rxn.products:
                    if str(product) not in placed_smiles:
                        next_products.add(str(product))

        if not next_products:
            break  # Safety break if no new column can be formed

        columns.append([mol_map[s] for s in next_products])
        placed_smiles.update(next_products)

    # --- Prepare data for rendering ---
    # Flatten the columns to get a single list of molecules for indexing.
    flat_mols = [mol for col in columns for mol in col]
    mol_to_idx = {str(mol): i for i, mol in enumerate(flat_mols)}

    # Connections (precursor -> product) for the SVG rendering.
    # The arrow in render_svg points from 'p' (reactant) to 's' (product).
    pred = []
    for reaction in retro_reactions:
        for product in reaction.products:
            if str(product) in mol_to_idx:
                s_idx = mol_to_idx[str(product)]  # product sits on the right
                for reactant in reaction.reactants:
                    if str(reactant) in mol_to_idx:
                        pred.append((s_idx, mol_to_idx[str(reactant)]))

    return render_svg(tuple(pred), columns, box_colors)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def get_route_svg(tree: Tree, node_id: int) -> str:
    """Visualizes the retrosynthetic route.

    :param tree: The built tree.
    :param node_id: The id of the node from which to visualize the route.
    :return: The SVG string.
    """
    nodes = tree.route_to_node(node_id)
    # Tag every precursor so the renderer can pick the right box color.
    for n in nodes:
        for precursor in n.new_precursors:
            precursor.molecule.meta["status"] = (
                "instock"
                if precursor.is_building_block(tree.building_blocks)
                else "mulecule"
            )
    nodes[0].curr_precursor.molecule.meta["status"] = "target"
    # Box colors
    box_colors = {
        "target": "#98EEFF",  # 152, 238, 255
        "mulecule": "#F0AB90",  # 240, 171, 144
        "instock": "#9BFAB3",  # 155, 250, 179
    }

    # First column is the target; second column the first new precursors.
    columns = [
        [nodes[0].curr_precursor.molecule],
        [x.molecule for x in nodes[1].new_precursors],
    ]
    # pred maps a molecule index to the index of the node it was produced from.
    pred = {x: 0 for x in range(1, len(columns[1]) + 1)}
    # cx: FIFO of indices of precursors that still await expansion.
    cx = [
        n
        for n, x in enumerate(nodes[1].new_precursors, 1)
        if not x.is_building_block(tree.building_blocks)
    ]
    size = len(cx)
    nodes = iter(nodes[2:])
    cy = count(len(columns[1]) + 1)  # running index for newly added molecules
    while size:
        layer = []
        # Consume exactly one tree node per pending precursor in this layer.
        for s in islice(nodes, size):
            n = cx.pop(0)
            for x in s.new_precursors:
                layer.append(x)
                m = next(cy)
                if not x.is_building_block(tree.building_blocks):
                    cx.append(m)
                pred[m] = n
        size = len(cx)
        columns.append([x.molecule for x in layer])

    # Reverse both axes to obtain the retrosynthetic (right-to-left) layout.
    columns = [col[::-1] for col in columns[::-1]]
    # Convert the parent map to (source, target) pairs with mirrored indices,
    # so multiple precursors per node are supported by the renderer.
    pred = tuple(
        (abs(source - len(pred)), abs(target - len(pred)))
        for target, source in pred.items()
    )
    return render_svg(pred, columns, box_colors)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def get_route_svg_from_json(routes_json: dict, route_id: int) -> str:
    """
    Visualizes the retrosynthetic route described in routes_json[route_id].

    :param routes_json: A dict mapping route IDs to nested JSON trees of molecules/reactions.
    :param route_id: The id of the route from which to visualize the route.
    :return: The SVG string.
    :raises ValueError: If neither route_id nor str(route_id) is a key of routes_json.
    """
    # 1) Fetch the route root; keys may be stored as int or str (EAFP, int first
    #    to preserve the original lookup preference).
    try:
        root = routes_json[route_id]
    except KeyError:
        try:
            root = routes_json[str(route_id)]
        except KeyError as err:
            raise ValueError(f"Route ID {route_id} not found in routes_json.") from err

    # 2) BFS the JSON tree into per-depth lists of mol-dicts, remembering parent links.
    levels = []  # levels[d] = list of mol-dicts at depth d
    parent_of = {}  # id(mol_dict) -> parent mol-dict (None for the root)
    queue = deque([(root, 0, None)])
    while queue:
        node, depth, parent = queue.popleft()
        if node.get("type") != "mol":
            continue
        if len(levels) <= depth:
            levels.append([])
        levels[depth].append(node)
        parent_of[id(node)] = parent
        # Molecules are children of "reaction" nodes, one level deeper.
        for child in node.get("children", []):
            if child.get("type") == "reaction":
                for mol_child in child.get("children", []):
                    if mol_child.get("type") == "mol":
                        queue.append((mol_child, depth + 1, node))

    # 3) Build MoleculeContainer objects & set meta["status"].
    mol_container = {}
    for depth, mols in enumerate(levels):
        for mol in mols:
            m = read_smiles(mol["smiles"])
            # target at depth=0, else instock vs mulecule
            if depth == 0:
                m.meta["status"] = "target"
            else:
                m.meta["status"] = (
                    "instock" if mol.get("in_stock", False) else "mulecule"
                )
            mol_container[id(mol)] = m

    # 4) Mirror columns left<->right at the JSON level (retrosynthetic layout).
    json_columns = levels[::-1]

    # 5) Flatten JSON node IDs in that mirrored order (flat_index keys = id(mol_dict)).
    flat_node_ids = [id(m) for lvl in json_columns for m in lvl]
    flat_index = {nid: idx for idx, nid in enumerate(flat_node_ids)}

    # 6) Build pred (parent index, child index) pairs from those JSON-node IDs.
    pred = tuple(
        (flat_index[id(parent)], flat_index[child_id])
        for child_id, parent in parent_of.items()
        if parent is not None
    )

    # 7) Map JSON columns to MoleculeContainer columns for layout.
    columns = [[mol_container[id(m)] for m in lvl] for lvl in json_columns]

    box_colors = {
        "target": "#98EEFF",
        "mulecule": "#F0AB90",
        "instock": "#9BFAB3",
    }

    return render_svg(pred, columns, box_colors)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def generate_results_html(
    tree: Tree, html_path: str, aam: bool = False, extended: bool = False
) -> Union[str, None]:
    """Writes an HTML page with the synthesis routes in SVG format and corresponding
    reactions in SMILES format.

    :param tree: The built tree.
    :param extended: If True, reports every solved node; otherwise only the winning nodes.
    :param html_path: The path to the file where to store resulting HTML. If None,
        the HTML table is returned as a string instead of being written to disk.
    :param aam: If True, depict atom-to-atom mapping.
    :return: None when the report is written to html_path, otherwise the HTML table string.
    """
    # Toggle atom-to-atom mapping in molecule depictions.
    MoleculeContainer.depict_settings(aam=bool(aam))

    # Gather routes to report.
    if extended:
        routes = [idx for idx, node in tree.nodes.items() if node.is_solved()]
    else:
        routes = tree.winning_nodes

    # HTML Tags
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_head = "<font style='font-weight: bold; font-size: 18px'>"
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = """
    <!doctype html>
    <html lang="en">
    <head>
      <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
            rel="stylesheet"
            integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
            crossorigin="anonymous">
      <script
            src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
            integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
            crossorigin="anonymous">
      </script>
      <meta charset="utf-8">
      <meta name="viewport" content="width=device-width, initial-scale=1">
      <title>Predicted Paths Report</title>
      <meta name="description" content="A simple HTML5 Template for new projects.">
      <meta name="author" content="SitePoint">
    </head>
    <body>
    """
    template_end = """
    </body>
    </html>
    """
    # SVG legend mark; "rgb()" is substituted with a concrete color below.
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg">
      <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """
    table = """
    <table class="table table-striped table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report</h3></caption>
    <tbody>"""

    # Summary rows.
    table += f"<tr>{td}{font_normal}Target Molecule: {str(tree.nodes[1].curr_precursor)}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Tree Size: {len(tree)}{font_close} nodes</td></tr>"
    table += f"<tr>{td}{font_normal}Number of visited nodes: {len(tree.visited_nodes)}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Found paths: {len(routes)}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Time: {round(tree.curr_time, 4)}{font_close} seconds</td></tr>"
    table += f"""
    <tr>{td}
    <div>
    {box_mark.replace("rgb()", "rgb(152, 238, 255)")}
    Target Molecule
    {box_mark.replace("rgb()", "rgb(240, 171, 144)")}
    Molecule Not In Stock
    {box_mark.replace("rgb()", "rgb(155, 250, 179)")}
    Molecule In Stock
    </div>
    </td></tr>
    """

    for route in routes:
        svg = get_route_svg(tree, route)  # get SVG
        full_route = tree.synthesis_route(route)  # get route
        # SMILES of all reactions in the synthesis path.
        reactions = "".join(
            f"<b>Step {step}:</b> {str(synth_step)}<br>"
            for step, synth_step in enumerate(full_route, 1)
        )
        # Concatenate all content of the path.
        route_score = round(tree.route_score(route), 3)
        table += (
            f'<tr style="line-height: 250%">{td}{font_head}Route {route}; '
            f"Steps: {len(full_route)}; "
            f"Cumulated nodes' value: {route_score}{font_close}</td></tr>"
        )
        table += f"<tr>{td}{svg}</td></tr>"
        table += f"<tr>{td}{reactions}</td></tr>"
    table += "</tbody>"

    if html_path is None:
        return table
    with open(html_path, "w", encoding="utf-8") as html_file:
        html_file.write(template_begin)
        html_file.write(table)
        html_file.write(template_end)
    return None
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def html_top_routes_cluster(clusters: dict, tree: Tree, target_smiles: str) -> str:
    """Builds a styled HTML report showing the best route from each cluster.

    :param clusters: Clustering results; each value is expected to carry a
        'node_ids' list (first entry is the cluster's best route) and
        optionally an 'sb_cgr' object.
    :param tree: The built tree used to render each cluster's best route.
    :param target_smiles: SMILES of the target molecule, shown in the header.
    :return: The complete HTML report as a string.
    """
    # Compute summary
    total_routes = sum(len(data.get("node_ids", [])) for data in clusters.values())
    total_clusters = len(clusters)

    # Build styled HTML report using Bootstrap
    html = []

    html.append("<!doctype html><html lang='en'><head>")
    html.append(
        "<meta charset='utf-8'><meta name='viewport' content='width=device-width, initial-scale=1'>"
    )
    html.append(
        "<link href='https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css' rel='stylesheet'>"
    )
    html.append("<title>Clustering Results Report</title>")
    html.append(
        "<style> svg{max-width:100%;height:auto;} .report-table th,.report-table td{vertical-align:top;border:1px solid #dee2e6;} </style>"
    )
    html.append("</head><body><div class='container my-4'>")
    # Report header
    html.append("<h1 class='mb-3'>Best route from each cluster</h1>")
    html.append(f"<p><strong>Target molecule (SMILES):</strong> {target_smiles}</p>")
    html.append(f"<p><strong>Total number of routes:</strong> {total_routes}</p>")
    html.append(f"<p><strong>Total number of clusters:</strong> {total_clusters}</p>")
    # Table header: a single colgroup declares the four column widths.
    # (The previous markup accidentally opened <colgroup> twice.)
    html.append(
        "<table class='table report-table'><colgroup><col style='width:5%'><col style='width:5%'><col style='width:15%'><col style='width:75%'></colgroup><thead><tr>"
    )
    html.append(
        "<th>Cluster index</th><th>Size</th><th>ReducedRouteCGR</th><th>Best Route</th>"
    )
    html.append("</tr></thead><tbody>")

    # Rows per cluster
    for cluster_num, group_data in clusters.items():
        node_ids = group_data.get("node_ids", [])
        if not node_ids:
            continue
        node_id = node_ids[0]  # first node id is the cluster's best route
        # Get SVGs
        svg = get_route_svg(tree, node_id)
        r_cgr = group_data.get("sb_cgr")
        r_cgr_svg = None
        if r_cgr:
            r_cgr.clean2d()
            r_cgr_svg = cgr_display(r_cgr)
        # Start row
        html.append(f"<tr><td>{cluster_num}</td>")
        html.append(f"<td>{len(node_ids)}</td>")
        # ReducedRouteCGR cell (inlined as base64 so the report is self-contained)
        html.append("<td>")
        if r_cgr_svg:
            b64_r = base64.b64encode(r_cgr_svg.encode("utf-8")).decode()
            html.append(
                f"<img src='data:image/svg+xml;base64,{b64_r}' alt='ReducedRouteCGR' class='img-fluid'/>"
            )
        html.append("</td>")
        # Best Route cell
        html.append("<td>")
        if svg:
            b64_svg = base64.b64encode(svg.encode("utf-8")).decode()
            html.append(
                f"<img src='data:image/svg+xml;base64,{b64_svg}' alt='Route {node_id}' class='img-fluid'/>"
            )
        html.append("</td></tr>")

    # Close table and HTML
    html.append("</tbody></table>")
    html.append("</div></body></html>")

    return "".join(html)
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def routes_clustering_report(
    source: Union[Tree, dict],
    clusters: dict,
    group_index: str,
    sb_cgrs_dict: dict,
    aam: bool = False,
    html_path: str = None,
) -> str:
    """
    Generates an HTML report visualizing a cluster of retrosynthetic routes.

    Takes a source of retrosynthetic routes (either a Tree object or a dict of
    routes in JSON format), cluster information, and a dictionary of
    ReducedRouteCGRs, and produces an HTML report with cluster details, a
    representative ReducedRouteCGR, and SVG visualizations of each route in
    the specified cluster.

    Args:
        source (Union[Tree, dict]): The source of retrosynthetic routes — a
            Tree object or a dictionary loaded from a routes JSON file.
        clusters (dict): Clustering results; each cluster entry typically
            includes a list of 'node_ids'.
        group_index (str): The key identifying the cluster within `clusters`
            for which the report should be generated.
        sb_cgrs_dict (dict): Maps route IDs to ReducedRouteCGR objects, used to
            display a representative ReducedRouteCGR for the cluster.
        aam (bool, optional): Whether to enable atom-atom mapping in molecule
            depictions. Defaults to False.
        html_path (str, optional): If given, the report is written to this file
            and a confirmation message is returned; otherwise the HTML string
            itself is returned. Defaults to None.

    Returns:
        str: The HTML report, a confirmation message when `html_path` is set,
        or an error-message HTML string for invalid inputs.
    """
    # --- Depict Settings ---
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception:
        # Depiction settings are cosmetic; never let them abort the report.
        pass

    # --- Figure out what `source` is ---
    using_tree = False
    if hasattr(source, "nodes") and hasattr(source, "route_to_node"):
        tree = source
        using_tree = True
    elif isinstance(source, dict):
        routes_json = source
        tree = None
    else:
        return "<html><body>Error: first argument must be a Tree or a routes_json dict.</body></html>"

    # --- Validate clusters ---
    if not isinstance(clusters, dict):
        return "<html><body>Error: clusters must be a dict.</body></html>"

    group = clusters.get(group_index)
    if group is None:
        return f"<html><body>Error: no group with index {group_index!r}.</body></html>"

    cluster_node_ids = group.get("node_ids", [])
    # Filter valid routes
    valid_routes = []
    if using_tree:
        for nid in cluster_node_ids:
            if nid in tree.nodes and tree.nodes[nid].is_solved():
                valid_routes.append(nid)
    else:
        # JSON mode: keep only node IDs present in the routes dict
        routes_dict = make_dict(routes_json)
        for nid in cluster_node_ids:
            if nid in routes_dict.keys():
                valid_routes.append(nid)
    if not valid_routes:
        return f"""
        <!doctype html><html><body>
        <h3>Cluster {group_index} Report</h3>
        <p>No valid routes found in this cluster.</p>
        </body></html>
        """

    # --- Determine the target molecule SMILES ---
    if using_tree:
        try:
            target_smiles = str(tree.nodes[1].curr_precursor)
        except Exception:
            target_smiles = "N/A"
    else:
        # JSON mode: root smiles of the first valid route; route keys may be
        # stored as str or int, so accept both instead of crashing on KeyError.
        route_root = routes_json.get(
            str(valid_routes[0]), routes_json.get(valid_routes[0])
        )
        target_smiles = route_root["smiles"] if route_root else "N/A"

    # --- HTML Templates & Tags ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_head = "<font style='font-weight: bold; font-size: 18px'>"
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = f"""
    <!doctype html>
    <html lang="en">
    <head>
      <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
            rel="stylesheet"
            integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
            crossorigin="anonymous">
      <meta charset="utf-8">
      <meta name="viewport" content="width=device-width, initial-scale=1">
      <title>Cluster {group_index} Routes Report</title>
      <style>
        /* Optional: Add some basic styling */
        .table {{ border-collapse: collapse; width: 100%; }}
        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
        tr:nth-child(even) {{ background-color: #ffffff; }}
        caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
        svg {{ max-width: 100%; height: auto; }}
      </style>
    </head>
    <body>
    <div class="container"> """

    template_end = """
    </div> <script
          src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
          integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
          crossorigin="anonymous">
    </script>
    </body>
    </html>
    """

    # Legend mark; "rgb()" is substituted with a concrete color below.
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
      <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """

    # --- Build HTML Table ---
    table = f"""
    <table class="table table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report - Cluster {group_index}</h3></caption>
    <tbody>"""

    table += (
        f"<tr>{td}{font_normal}Target Molecule: {target_smiles}{font_close}</td></tr>"
    )
    table += f"<tr>{td}{font_normal}Group index: {group_index}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes)} routes{font_close} </td></tr>"

    # --- Add ReducedRouteCGR Image ---
    first_route_id = valid_routes[0] if valid_routes else None

    if first_route_id and sb_cgrs_dict:
        try:
            sb_cgr = sb_cgrs_dict[first_route_id]
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

            if sb_cgr_svg.strip().startswith("<svg"):
                table += f"<tr>{td}{font_normal}Identified Strategic Bonds{font_close}<br>{sb_cgr_svg}</td></tr>"
            else:
                table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
                print(
                    f"Warning: Expected SVG for ReducedRouteCGR of node {first_route_id}, but got: {sb_cgr_svg[:100]}..."
                )
        except Exception as e:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying ReducedRouteCGR: {e}</i></td></tr>"
    else:
        if first_route_id:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided ReducedRouteCGR dictionary.</i></td></tr>"
        else:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR:{font_close}<br><i>No valid routes in cluster to select from.</i></td></tr>"

    # Legend row
    table += f"""
    <tr>{td}
    <div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
    <span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
    <span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
    <span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
    </div>
    </td></tr>
    """
    for route_id in valid_routes:
        if using_tree:
            # 1) SVG from Tree
            svg = get_route_svg(tree, route_id)
            # 2) Reaction steps & score
            steps = tree.synthesis_route(route_id)
            score = round(tree.route_score(route_id), 3)
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in enumerate(steps)
            )
            header = f"Route {route_id} — {len(steps)} steps, score={score}"
        else:
            # 1) SVG from JSON
            svg = get_route_svg_from_json(routes_json, route_id)
            # 2) Reaction steps (dict of step index -> reaction)
            steps = routes_dict[route_id]
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in steps.items()
            )
            header = f"Route {route_id} — {len(steps)} steps"

        table += f"<tr><td><b>{header}</b></td></tr>"
        table += f"<tr><td>{svg}</td></tr>"
        table += f"<tr><td>{reac_html}</td></tr>"

    table += "</tbody></table>"

    html = template_begin + table + template_end

    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)
        return f"Written to {html_path}"
    return html
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def _lg_row_html(node_id, node_data, all_marks):
    """Render one <tr> for a route: the ID cell followed by one cell per mark.

    Marks absent from ``node_data`` produce an empty cell so every row has
    the same number of columns.
    """
    row = f"<tr><td style='border: 1px solid black; padding: 4px;'>{node_id}</td>"
    for mark in all_marks:
        row += "<td style='border: 1px solid black; padding: 4px;'>"
        if mark in node_data:
            # First element of the mark's data is a depictable object whose
            # .depict() returns SVG markup as a string.
            row += node_data[mark][0].depict()
        row += "</td>"
    row += "</tr>"
    return row


def lg_table_2_html(subcluster, nodes_to_display=None, if_display=True):
    """
    Generates an HTML table visualizing leaving groups (X) 'marks' for routes within a subcluster.

    Each row represents a route from the subcluster (or a subset of nodes),
    and each column represents one of the unique 'marks' found across the
    nodes. Cells contain the SVG depiction of the corresponding mark.

    Args:
        subcluster (dict): Subcluster data; expected to have a 'nodes_data'
            key mapping node IDs to dicts of marks, where the first element
            of each mark's value is a depictable object.
        nodes_to_display (list, optional): Specific node IDs to include in
            the table. If None or empty, all nodes in
            ``subcluster["nodes_data"]`` are included. Defaults to None
            (previously a mutable ``[]`` default — replaced to avoid the
            shared-mutable-default pitfall; behavior is unchanged).
        if_display (bool, optional): If True, the generated HTML is shown
            directly via ``display(HTML(...))``. Defaults to True.

    Returns:
        str: The generated HTML string for the table.
    """
    if nodes_to_display is None:
        nodes_to_display = []

    # Table header; the first column is always the route ID.
    html = "<table style='border-collapse: collapse;'><tr><th style='border: 1px solid black; padding: 4px;'>Route ID</th>"

    # Union of all mark names across nodes -> one consistent column set.
    all_marks = set()
    for node_data in subcluster["nodes_data"].values():
        all_marks.update(node_data.keys())
    all_marks = sorted(all_marks)  # sort for deterministic column ordering

    for mark in all_marks:
        html += f"<th style='border: 1px solid black; padding: 4px;'>{mark}</th>"
    html += "</tr>"

    # Rows: either every node in the subcluster, or only the requested ones.
    if not nodes_to_display:
        for node_id, node_data in subcluster["nodes_data"].items():
            html += _lg_row_html(node_id, node_data, all_marks)
    else:
        for node_id in nodes_to_display:
            if node_id in subcluster["nodes_data"]:
                html += _lg_row_html(
                    node_id, subcluster["nodes_data"][node_id], all_marks
                )
            else:
                # Flag unknown route IDs instead of silently skipping them.
                html += f"<tr><td colspan='{len(all_marks)+1}' style='border: 1px solid black; padding: 4px; color:red;'>Route ID {node_id} not found.</td></tr>"

    html += "</table>"

    if if_display:
        display(HTML(html))  # IPython rich display (notebook context)

    return html
def group_lg_table_2_html_fixed(
    grouped: dict,
    groups_to_display=None,
    if_display=False,
    max_group_col_width: int = 200,
) -> str:
    """
    Generates an HTML table visualizing leaving-group (X) 'marks' for representative routes in grouped data.

    Each row corresponds to one group (e.g. a tuple of route node IDs); the
    first column lists the group's members, and the remaining columns show
    the SVG depiction (via ``.depict()``) or string form of each mark held
    by the group's representative.

    Args:
        grouped (dict): Maps group identifiers (e.g. tuples of route node
            IDs) to the representative's marks dict. Mark values either
            expose ``.depict()`` or are rendered with ``str()``.
        groups_to_display (list, optional): Specific group identifiers to
            render; unknown identifiers are skipped. If None, all groups
            are rendered. Defaults to None.
        if_display (bool, optional): If True, show the HTML directly via
            ``display(HTML(...))``. Defaults to False.
        max_group_col_width (int, optional): Maximum pixel width of the
            group-identifier column. Defaults to 200.

    Returns:
        str: The generated HTML string for the table.
    """
    # Select the groups to render, preserving the caller's requested order.
    selected = (
        list(grouped.keys())
        if groups_to_display is None
        else [g for g in groups_to_display if g in grouped]
    )

    # Union of all mark names across representatives -> stable sorted columns.
    mark_names = sorted({name for rep in grouped.values() for name in rep.keys()})

    header_cells = "".join(
        f"<th style='border:1px solid #ccc; padding:4px; text-align:center;'>{name}</th>"
        for name in mark_names
    )
    pieces = [
        "<table style='width:100%; table-layout:auto; border-collapse: collapse;'>",
        "<thead><tr>",
        "<th style='border:1px solid #ccc; padding:4px;'>Route IDs</th>",
        header_cells,
        "</tr></thead><tbody>",
    ]

    # Styles: the label column wraps and is width-capped; mark cells center
    # their (usually SVG) content.
    label_style = (
        "border:1px solid #ccc; padding:4px; "
        "white-space: normal; overflow-wrap: break-word; "
        f"max-width:{max_group_col_width}px;"
    )
    cell_style = (
        "border:1px solid #ccc; padding:4px; text-align:center; vertical-align:middle;"
    )

    for key in selected:
        rep = grouped[key]
        route_label = ",".join(str(member) for member in key)
        cells = [f"<td style='{label_style}'>{route_label}</td>"]
        for name in mark_names:
            if name in rep:
                value = rep[name]
                body = value.depict() if hasattr(value, "depict") else str(value)
            else:
                body = ""
            cells.append(f"<td style='{cell_style}'>{body}</td>")
        pieces.append("<tr>" + "".join(cells) + "</tr>")

    pieces.append("</tbody></table>")
    result = "".join(pieces)
    if if_display:
        display(HTML(result))

    return result
def routes_subclustering_report(
    source: "Union[Tree, dict]",
    subcluster: dict,
    group_index: str,
    cluster_num: int,
    sb_cgrs_dict: dict,
    if_lg_group: bool = False,
    aam: bool = False,
    html_path: str = None,
) -> str:
    """
    Generates an HTML report visualizing a specific subcluster of retrosynthetic routes.

    The report contains general cluster information, a representative
    ReducedRouteCGR, the synthon pseudo reaction, a table of leaving groups
    (per node, or grouped when ``if_lg_group`` is True), and an SVG
    visualization of every valid route in the subcluster.

    Args:
        source (Union[Tree, dict]): Source of retrosynthetic routes — either
            a Tree object containing the full search tree, or a dictionary
            loaded from a routes JSON file.
        subcluster (dict): Subcluster data. Expected keys: 'nodes_data'
            (node ID -> mark data), 'synthon_reaction', and optionally
            'group_lgs' when ``if_lg_group`` is True.
        group_index (str): Index of the main cluster this subcluster belongs
            to. Used for report titling.
        cluster_num (int): Identifier of the subcluster within its group.
            Used for report titling.
        sb_cgrs_dict (dict): Maps route IDs to ReducedRouteCGR objects; used
            to display a representative ReducedRouteCGR for the cluster.
        if_lg_group (bool, optional): If True, the leaving-groups table shows
            grouped leaving groups from ``subcluster['group_lgs']``;
            otherwise per-node groups from ``subcluster['nodes_data']``.
            Defaults to False.
        aam (bool, optional): Enable atom-atom mapping in molecule
            depictions. Defaults to False.
        html_path (str, optional): If given, the report is written to this
            file and a confirmation message is returned. If None, the HTML
            string itself is returned. Defaults to None.

    Returns:
        str: The HTML report (or a confirmation message when ``html_path``
        is given). A minimal page is returned when the subcluster has no
        valid/solved routes; an error page is returned for invalid inputs.
    """
    # --- Depict settings (best-effort; ignore toolkits without this hook) ---
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception:
        pass

    # --- Figure out what `source` is: Tree object or routes-JSON dict ---
    using_tree = False
    if hasattr(source, "nodes") and hasattr(source, "route_to_node"):
        tree = source
        using_tree = True
    elif isinstance(source, dict):
        routes_json = source
        tree = None
    else:
        return "<html><body>Error: first argument must be a Tree or a routes_json dict.</body></html>"

    # --- Validate subcluster ---
    if not isinstance(subcluster, dict):
        return "<html><body>Error: groups must be a dict.</body></html>"

    subcluster_node_ids = list(subcluster["nodes_data"].keys())

    # Keep only routes that are actually usable from the chosen source.
    valid_routes = []
    if using_tree:
        for nid in subcluster_node_ids:
            if nid in tree.nodes and tree.nodes[nid].is_solved():
                valid_routes.append(nid)
    else:
        # JSON mode: just keep those IDs present in the JSON.
        # NOTE(review): assumes nodes_data keys and routes_json keys share
        # the same type (e.g. both int or both str) — confirm with callers.
        for nid in subcluster_node_ids:
            if nid in routes_json:
                valid_routes.append(nid)
        routes_dict = make_dict(routes_json)

    if not valid_routes:
        # Minimal page: nothing to report for this cluster.
        return f"""
        <!doctype html><html lang="en"><head><meta charset="utf-8">
        <title>Cluster {group_index}.{cluster_num} Report</title></head><body>
        <h3>Cluster {group_index}.{cluster_num} Report</h3>
        <p>No valid/solved routes found for this cluster.</p>
        </body></html>"""

    # --- Resolve target molecule SMILES for the info rows ---
    if using_tree:
        try:
            target_smiles = str(tree.nodes[1].curr_precursor)
        except Exception:
            target_smiles = "N/A"
    else:
        # JSON mode: take the root smiles of the first valid route.
        target_smiles = routes_json[valid_routes[0]]["smiles"]

    # --- Reusable HTML tags ---
    # (The previously defined `th`/`font_head` tags and a first throwaway
    # template/table were dead code — removed.)
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    # --- Page boilerplate (Bootstrap + basic table/SVG styling) ---
    template_begin = f"""
    <!doctype html>
    <html lang="en">
      <head>
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
              rel="stylesheet"
              integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
              crossorigin="anonymous">
        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <title>SubCluster {group_index}.{cluster_num} Routes Report</title>
        <style>
          /* Optional: Add some basic styling */
          .table {{ border-collapse: collapse; width: 100%; }}
          th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
          tr:nth-child(even) {{ background-color: #ffffff; }}
          caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
          svg {{ max-width: 100%; height: auto; }}
        </style>
      </head>
      <body>
        <div class="container"> """

    template_end = """
    </div> <script
      src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
      integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
      crossorigin="anonymous">
    </script>
    </body>
    </html>
    """

    # Colored circle used in the legend; "rgb()" is replaced per legend entry.
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
      <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """

    # --- Build HTML table: caption + general cluster info ---
    table = f"""
    <table class="table table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report - Cluster {group_index}.{cluster_num}</h3></caption>
    <tbody>"""

    table += (
        f"<tr>{td}{font_normal}Target Molecule: {target_smiles}{font_close}</td></tr>"
    )
    table += f"<tr>{td}{font_normal}Group index: {group_index}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Cluster Number: {cluster_num}{font_close}</td></tr>"
    table += f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes)} routes{font_close} </td></tr>"

    # --- Representative ReducedRouteCGR image (taken from the first route) ---
    first_route_id = valid_routes[0] if valid_routes else None

    if first_route_id and sb_cgrs_dict:
        try:
            sb_cgr = sb_cgrs_dict[first_route_id]
            sb_cgr.clean2d()
            sb_cgr_svg = cgr_display(sb_cgr)

            if sb_cgr_svg.strip().startswith("<svg"):
                table += f"<tr>{td}{font_normal}Identified Strategic Bonds{font_close}<br>{sb_cgr_svg}</td></tr>"
            else:
                table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
                print(
                    f"Warning: Expected SVG for ReducedRouteCGR of node {first_route_id}, but got: {sb_cgr_svg[:100]}..."
                )
        except Exception as e:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying ReducedRouteCGR: {e}</i></td></tr>"
    else:
        if first_route_id:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided ReducedRouteCGR dictionary.</i></td></tr>"
        else:
            table += f"<tr>{td}{font_normal}Cluster Representative ReducedRouteCGR:{font_close}<br><i>No valid routes in cluster to select from.</i></td></tr>"

    # --- Synthon pseudo reaction row ---
    try:
        synthon_reaction = subcluster["synthon_reaction"]
        synthon_reaction.clean2d()
        synthon_svg = depict_custom_reaction(synthon_reaction)

        extra_synthon = f"<tr>{td}{font_normal}Synthon pseudo reaction:{font_close}<br>{synthon_svg}</td></tr>"
        table += extra_synthon
    except Exception as e:
        table += f"<tr><td colspan='1' style='color: red;'>Error displaying synthon reaction: {e}</td></tr>"

    # --- Leaving-groups table row (grouped or per node) ---
    try:
        if if_lg_group:
            grouped_lgs = subcluster["group_lgs"]
            lg_table_html = group_lg_table_2_html_fixed(grouped_lgs, if_display=False)
        else:
            lg_table_html = lg_table_2_html(subcluster, if_display=False)
        extra_lg = f"<tr>{td}{font_normal}Leaving Groups table:{font_close}<br>{lg_table_html}</td></tr>"
        table += extra_lg
    except Exception as e:
        table += f"<tr><td colspan='1' style='color: red;'>Error displaying leaving groups: {e}</td></tr>"

    # --- Legend row ---
    table += f"""
    <tr>{td}
    <div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
        <span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
        <span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
        <span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
    </div>
    </td></tr>
    """

    # --- One section per valid route: header, route SVG, reaction steps ---
    for route_id in valid_routes:
        if using_tree:
            svg = get_route_svg(tree, route_id)
            steps = tree.synthesis_route(route_id)
            score = round(tree.route_score(route_id), 3)
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in enumerate(steps)
            )
            header = f"Route {route_id} — {len(steps)} steps, score={score}"
        else:
            svg = get_route_svg_from_json(routes_json, route_id)
            steps = routes_dict[route_id]
            # NOTE(review): here `i` is the step *key* of the routes dict, so
            # the rendered number is key+1 — confirm keys start at 0.
            reac_html = "".join(
                f"<b>Step {i+1}:</b> {str(r)}<br>" for i, r in steps.items()
            )
            header = f"Route {route_id} — {len(steps)} steps"

        table += f"<tr><td><b>{header}</b></td></tr>"
        table += f"<tr><td>{svg}</td></tr>"
        table += f"<tr><td>{reac_html}</td></tr>"

    table += "</tbody></table>"

    html = template_begin + table + template_end

    if html_path:
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html)
        return f"Written to {html_path}"
    return html
|