Gilmullin Almaz committed on
Commit
914ea41
·
1 Parent(s): f2f3593

Refactor code structure and remove redundant sections for improved readability and maintainability

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +1148 -591
  2. cluster/clustering.py +0 -174
  3. cluster/generalized_cgr.py +0 -204
  4. cluster/reduced_g_cgr.py +0 -159
  5. cluster/subcluster.py +0 -33
  6. cluster/utils.py +0 -314
  7. cluster/visualize.py +0 -481
  8. synplan/__init__.py +3 -0
  9. synplan/chem/__init__.py +3 -0
  10. {cluster → synplan/chem/data}/__init__.py +0 -0
  11. synplan/chem/data/filtering.py +962 -0
  12. synplan/chem/data/standardizing.py +1187 -0
  13. synplan/chem/precursor.py +100 -0
  14. synplan/chem/reaction.py +125 -0
  15. synplan/chem/reaction_routes/__init__.py +0 -0
  16. synplan/chem/reaction_routes/clustering.py +857 -0
  17. synplan/chem/reaction_routes/io.py +286 -0
  18. synplan/chem/reaction_routes/leaving_groups.py +131 -0
  19. synplan/chem/reaction_routes/route_cgr.py +570 -0
  20. synplan/chem/reaction_routes/visualisation.py +903 -0
  21. synplan/chem/reaction_rules/__init__.py +0 -0
  22. synplan/chem/reaction_rules/extraction.py +744 -0
  23. synplan/chem/reaction_rules/manual/__init__.py +6 -0
  24. synplan/chem/reaction_rules/manual/decompositions.py +413 -0
  25. synplan/chem/reaction_rules/manual/transformations.py +532 -0
  26. synplan/chem/utils.py +225 -0
  27. synplan/interfaces/__init__.py +0 -0
  28. synplan/interfaces/building_blocks/building_blocks_em_sa_ln.smi +0 -0
  29. synplan/interfaces/cli.py +506 -0
  30. synplan/interfaces/gui.py +1304 -0
  31. synplan/interfaces/uspto/uspto_reaction_rules.pickle +3 -0
  32. synplan/interfaces/uspto/weights/ranking_policy_network.ckpt +3 -0
  33. synplan/mcts/__init__.py +8 -0
  34. synplan/mcts/evaluation.py +45 -0
  35. synplan/mcts/expansion.py +96 -0
  36. synplan/mcts/node.py +47 -0
  37. synplan/mcts/search.py +199 -0
  38. synplan/mcts/tree.py +635 -0
  39. synplan/ml/__init__.py +0 -0
  40. synplan/ml/networks/__init__.py +0 -0
  41. synplan/ml/networks/modules.py +234 -0
  42. synplan/ml/networks/policy.py +137 -0
  43. synplan/ml/networks/value.py +67 -0
  44. synplan/ml/training/__init__.py +11 -0
  45. synplan/ml/training/preprocessing.py +516 -0
  46. synplan/ml/training/reinforcement.py +379 -0
  47. synplan/ml/training/supervised.py +153 -0
  48. synplan/utils/__init__.py +4 -0
  49. synplan/utils/config.py +543 -0
  50. synplan/utils/files.py +226 -0
app.py CHANGED
@@ -2,6 +2,8 @@ import base64
2
  import pickle
3
  import re
4
  import uuid
 
 
5
 
6
  import pandas as pd
7
  import streamlit as st
@@ -15,18 +17,19 @@ from synplan.mcts.expansion import PolicyNetworkFunction
15
  from synplan.mcts.search import extract_tree_stats
16
  from synplan.mcts.tree import Tree
17
  from synplan.chem.utils import mol_from_smiles
 
 
 
 
 
 
 
 
 
 
18
  from synplan.utils.config import TreeConfig, PolicyNetworkConfig
19
  from synplan.utils.loading import load_reaction_rules, load_building_blocks
20
- from synplan.utils.visualisation import generate_results_html, get_route_svg
21
-
22
 
23
- from cluster.generalized_cgr import *
24
- from cluster.reduced_g_cgr import *
25
- from cluster.clustering import *
26
- from cluster.visualize import *
27
- from cluster.utils import *
28
- from cluster.subcluster import *
29
- from StructureFingerprint import MorganFingerprint
30
 
31
  import psutil
32
  import gc
@@ -35,8 +38,13 @@ import gc
35
  disable_progress_bars("huggingface_hub")
36
 
37
  smiles_parser = SMILESRead.create_parser(ignore=True)
 
38
 
39
- def download_button(object_to_download, download_filename, button_text, pickle_it=False):
 
 
 
 
40
  """
41
  Issued from
42
  Generates a link to download the given object_to_download.
@@ -68,21 +76,17 @@ def download_button(object_to_download, download_filename, button_text, pickle_i
68
  pass
69
 
70
  elif isinstance(object_to_download, pd.DataFrame):
71
- object_to_download = object_to_download.to_csv(index=False).encode('utf-8')
72
-
73
- # Try JSON encode for everything else # else: # object_to_download = json.dumps(object_to_download)
74
 
75
  try:
76
- # some strings <-> bytes conversions necessary here
77
  b64 = base64.b64encode(object_to_download.encode()).decode()
78
-
79
  except AttributeError:
80
  b64 = base64.b64encode(object_to_download).decode()
81
 
82
- button_uuid = str(uuid.uuid4()).replace('-', '')
83
- button_id = re.sub('\d+', '', button_uuid)
84
 
85
- custom_css = f"""
86
  <style>
87
  #{button_id} {{
88
  background-color: rgb(255, 255, 255);
@@ -93,7 +97,7 @@ def download_button(object_to_download, download_filename, button_text, pickle_i
93
  border-style: solid;
94
  border-color: rgb(230, 234, 241);
95
  border-image: initial;
96
- }}
97
  #{button_id}:hover {{
98
  border-color: rgb(246, 51, 102);
99
  color: rgb(246, 51, 102);
@@ -105,644 +109,1197 @@ def download_button(object_to_download, download_filename, button_text, pickle_i
105
  }}
106
  </style> """
107
 
108
- dl_link = custom_css + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
109
-
 
 
110
  return dl_link
111
 
112
 
113
- st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")
114
-
115
- # Initialize session state variables if they don't exist.
116
- if "planning_done" not in st.session_state:
117
- st.session_state.planning_done = False
118
- if "tree" not in st.session_state:
119
- st.session_state.tree = None
120
- if "res" not in st.session_state:
121
- st.session_state.res = None
122
- if "target_smiles" not in st.session_state:
123
- st.session_state.target_smiles = ''
124
-
125
- # Clustering state
126
- if "clustering_done" not in st.session_state:
127
- st.session_state.clustering_done = False
128
- if "clusters" not in st.session_state:
129
- st.session_state.clusters = None
130
- if "reactions_dict" not in st.session_state:
131
- st.session_state.reactions_dict = None
132
- if "num_clusters_setting" not in st.session_state: # Store the setting used
133
- st.session_state.num_clusters_setting = 10
134
-
135
- # Subclustering state
136
- if "subclustering_done" not in st.session_state:
137
- st.session_state.subclustering_done = False
138
- if "sub" not in st.session_state:
139
- st.session_state.sub = None
140
-
141
- # Download state (less critical now with direct download links)
142
- if 'clusters_downloaded' not in st.session_state: # Example, might not be needed
143
- st.session_state.clusters_downloaded = False
144
-
145
- intro_text = '''
146
- This is a demo of the graphical user interface of
147
- [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
148
- SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.
149
-
150
- More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
151
- '''
152
-
153
- st.title("`SynPlanner GUI`")
154
-
155
- st.write(intro_text)
156
-
157
- st.header('Molecule input')
158
- st.markdown(
159
- '''
160
- You can provide a molecular structure by either providing:
161
- * SMILES string + Enter
162
- * Draw it + Apply
163
- '''
164
- )
165
-
166
- DEFAULT_MOL = 'c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O'
167
- molecule = st.text_input("SMILES:", DEFAULT_MOL)
168
- smile_code = st_ketcher(molecule)
169
- target_molecule = mol_from_smiles(smile_code)
170
-
171
- if 'target_smiles' in st.session_state and smile_code != st.session_state.target_smiles:
172
- # If the SMILES changes, invalidate previous results
173
- st.warning("Molecule structure changed. Please re-run planning.")
174
- st.session_state.planning_done = False
175
- st.session_state.clustering_done = False
176
- st.session_state.subclustering_done = False
177
- st.session_state.tree = None
178
- st.session_state.res = None
179
- st.session_state.clusters = None
180
- st.session_state.reactions_dict = None
181
- st.session_state.sub = None
182
-
183
  @st.cache_resource
184
- def load_planning_resources():
185
  building_blocks_path = hf_hub_download(
186
- repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
187
- filename="building_blocks_em_sa_ln.smi",
188
- subfolder="building_blocks",
189
- local_dir="."
190
- )
191
  ranking_policy_weights_path = hf_hub_download(
192
- repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
193
- filename="ranking_policy_network.ckpt",
194
- subfolder="uspto/weights",
195
- local_dir="."
196
- )
197
  reaction_rules_path = hf_hub_download(
198
- repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
199
- filename="uspto_reaction_rules.pickle",
200
- subfolder="uspto",
201
- local_dir="."
202
- )
203
  return building_blocks_path, ranking_policy_weights_path, reaction_rules_path
204
 
205
- building_blocks_path, ranking_policy_weights_path, reaction_rules_path = load_planning_resources()
206
 
207
- st.header('Launch calculation')
208
- st.markdown(
209
- '''If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor).'''
210
- )
211
- st.markdown(f"The molecule SMILES is actually: ``{smile_code}``")
212
 
213
- st.subheader('Planning options')
214
 
215
- st.markdown(
216
- '''
217
- The description of each option can be found in the
218
- [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
219
- '''
220
- )
221
 
222
- col_options_1, col_options_2 = st.columns(2, gap="medium")
223
-
224
- with col_options_1:
225
- search_strategy_input = st.selectbox(label='Search strategy', options=('Expansion first', 'Evaluation first',), index=0)
226
- ucb_type = st.selectbox(label='Search strategy', options=('uct', 'puct', 'value'), index=0)
227
- c_ucb = st.number_input("C coefficient of UCB", value=0.1, placeholder="Type a number...")
228
-
229
- with col_options_2:
230
- max_iterations = st.slider('Total number of MCTS iterations', min_value=50, max_value=1000, value=300)
231
- max_depth = st.slider('Maximal number of reaction steps', min_value=3, max_value=9, value=6)
232
- min_mol_size = st.slider('Minimum size of a molecule to be precursor', min_value=0, max_value=7, value=0)
233
-
234
- search_strategy_translator = {
235
- "Expansion first": "expansion_first",
236
- "Evaluation first": "evaluation_first",
237
- }
238
- search_strategy = search_strategy_translator[search_strategy_input]
239
-
240
- submit_planning = st.button('Start retrosynthetic planning')
241
-
242
- if submit_planning:
243
- # Reset downstream states if replanning
244
- st.session_state.planning_done = False
245
- st.session_state.clustering_done = False
246
- st.session_state.subclustering_done = False
247
- st.session_state.tree = None
248
- st.session_state.res = None
249
- st.session_state.clusters = None
250
- st.session_state.reactions_dict = None
251
- st.session_state.sub = None
252
- st.session_state.target_smiles = smile_code # Store the SMILES used for this run
253
 
254
- try:
255
- target_molecule = mol_from_smiles(smile_code)
256
- if target_molecule is None:
257
- st.error(f"Could not parse the input SMILES: {smile_code}")
258
- else:
259
- with st.spinner("Running retrosynthetic planning..."):
260
- with st.status("Loading resources...", expanded=False) as status:
261
- st.write("Loading building blocks...")
262
- building_blocks = load_building_blocks(building_blocks_path, standardize=False)
263
- st.write('Loading reaction rules...')
264
- reaction_rules = load_reaction_rules(reaction_rules_path)
265
- st.write('Loading policy network...')
266
- policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
267
- policy_function = PolicyNetworkFunction(policy_config=policy_config)
268
- status.update(label="Resources loaded!", state="complete")
269
-
270
- tree_config = TreeConfig(
271
- search_strategy=search_strategy,
272
- evaluation_type="rollout",
273
- max_iterations=max_iterations,
274
- max_depth=max_depth,
275
- min_mol_size=min_mol_size,
276
- init_node_value=0.5,
277
- ucb_type=ucb_type,
278
- c_ucb=c_ucb,
279
- silent=True
280
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- tree = Tree(
283
- target=target_molecule,
284
- config=tree_config,
285
- reaction_rules=reaction_rules,
286
- building_blocks=building_blocks,
287
- expansion_function=policy_function,
288
- evaluation_function=None,
289
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
- mcts_progress_text = "Running MCTS iterations..."
292
- mcts_bar = st.progress(0, text=mcts_progress_text)
293
- for step, (solved, node_id) in enumerate(tree):
294
- progress_value = min(1.0, (step + 1) / max_iterations)
295
- mcts_bar.progress(progress_value, text=f"{mcts_progress_text} ({step+1}/{max_iterations})")
296
-
297
- res = extract_tree_stats(tree, target_molecule)
298
-
299
- # Store planning outputs in session_state
300
- st.session_state['tree'] = tree
301
- st.session_state['res'] = res
302
- st.session_state.planning_done = True
303
- st.rerun() # Rerun to display results cleanly
304
-
305
- except Exception as e:
306
- st.error(f"An error occurred during planning: {e}")
307
- st.session_state.planning_done = False # Ensure state reflects failure
308
-
309
- # Display results if planning has been completed
310
- if st.session_state.get('planning_done', False):
311
- res = st.session_state.res
312
- tree = st.session_state.tree
313
-
314
- if res is None or tree is None:
315
- st.error("Planning results are missing from session state. Please re-run planning.")
316
- st.session_state.planning_done = False # Reset state
317
- elif res["solved"]:
318
- st.header('Planning Results')
319
- # st.balloons() # Optional fun
320
- winning_nodes = sorted(set(tree.winning_nodes)) if hasattr(tree, 'winning_nodes') and tree.winning_nodes else []
321
- st.subheader(f"Number of unique routes found: {len(winning_nodes)}")
322
-
323
- st.subheader("Examples of found retrosynthetic routes")
324
- image_counter = 0
325
- visualised_node_ids = set()
326
- # Ensure winning_nodes is iterable and not empty
327
-
328
- if not winning_nodes:
329
- st.warning("Planning solved, but no winning nodes found in the tree object.")
330
- else:
331
- for n, node_id in enumerate(winning_nodes):
332
- if image_counter >= 3: # Use >= for clarity
333
- break
334
- # Simple display logic: show first 3 unique routes
335
- if node_id not in visualised_node_ids:
336
- try:
337
- visualised_node_ids.add(node_id)
338
- num_steps = len(tree.synthesis_route(node_id))
339
- route_score = round(tree.route_score(node_id), 3)
340
- svg = get_route_svg(tree, node_id)
341
- if svg:
342
- st.image(svg, caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
343
- image_counter += 1
344
- else:
345
- st.warning(f"Could not generate SVG for route {node_id}.")
346
- except Exception as e:
347
- st.error(f"Error displaying route {node_id}: {e}")
348
-
349
- stat_col, download_col = st.columns(2, gap="medium")
350
- with stat_col:
351
- st.subheader("Statistics")
352
- try:
353
- # Ensure 'target_smiles' exists in res, if not, use the stored one
354
- if 'target_smiles' not in res:
355
- res['target_smiles'] = st.session_state.target_smiles
356
- # Select only existing columns safely
357
- cols_to_show = [col for col in ["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"] if col in res]
358
- df = pd.DataFrame(res, index=[0])[cols_to_show]
359
- st.dataframe(df) # Use dataframe for better display
360
- except Exception as e:
361
- st.error(f"Error displaying statistics: {e}")
362
- st.write(res) # Show raw dict if DataFrame fails
363
-
364
- with download_col:
365
- st.subheader("Downloads")
366
- try:
367
- html_body = generate_results_html(tree, html_path=None, extended=True)
368
- dl_html = download_button(html_body, 'results_synplanner.html', 'Download results (HTML)')
369
- if dl_html: st.markdown(dl_html, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
- # Ensure res is suitable for DataFrame before creating/downloading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  try:
373
- res_df = pd.DataFrame(res, index=[0])
374
- dl_csv = download_button(res_df, 'results_synplanner.csv', 'Download statistics (CSV)')
375
- if dl_csv: st.markdown(dl_csv, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  except Exception as e:
377
- st.error(f"Could not prepare statistics CSV for download: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
 
 
 
 
 
 
 
 
 
379
  except Exception as e:
380
- st.error(f"Error generating download links: {e}")
381
 
382
- st.divider()
 
383
 
 
 
 
 
 
 
 
 
 
384
  st.header("Clustering the retrosynthetic routes")
385
 
386
- if 'num_clusters' not in st.session_state:
387
- st.session_state['num_clusters'] = 10
388
-
389
- cluster_box, _ = st.columns(2, gap="medium")
390
- with cluster_box:
391
- num_clusters_input = st.slider(
392
- 'Max number of clusters to generate',
393
- min_value=2,
394
- max_value=min(50, res.get("num_routes", 50)), # Sensible max based on routes found
395
- value=st.session_state.num_clusters_setting,
396
- key='cluster_slider'
397
- )
398
- if st.button('Run Clustering', key='submit_clustering'):
399
- # Update the setting in session state when the button is clicked
400
- st.session_state.num_clusters_setting = num_clusters_input
401
- # Reset downstream states
402
  st.session_state.clustering_done = False
403
  st.session_state.subclustering_done = False
404
  st.session_state.clusters = None
405
  st.session_state.reactions_dict = None
406
- st.session_state.sub = None
 
 
407
 
408
  with st.spinner("Performing clustering..."):
409
  try:
410
- # Ensure tree is available from session state
411
  current_tree = st.session_state.tree
412
  if not current_tree:
413
  st.error("Tree object not found. Please re-run planning.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  else:
415
- st.write("Calculating Generalized CGRs...")
416
- g_cgrs_dict = reassign_nums(current_tree) # Assuming this needs the tree
417
- st.write("Processing RG-CGRs...")
418
- reduced_g_cgrs_dict = process_all_rg_cgrs(g_cgrs_dict) # Assuming this uses the previous output
419
-
420
- mfp = MorganFingerprint()
421
- st.write(f"Clustering into max {st.session_state.num_clusters_setting} clusters...")
422
- results = cluster_molecules(reduced_g_cgrs_dict, mfp, max_clusters=st.session_state.num_clusters_setting)
423
-
424
- st.session_state.clusters = results.get('clusters_dict')
425
- st.session_state.rg_cgrs_dict = reduced_g_cgrs_dict
426
- # Extract reactions *after* clustering if needed, ensure tree is passed
427
- st.write("Extracting reactions...")
428
- st.session_state.reactions_dict = extract_reactions(current_tree)
429
-
430
- if st.session_state.clusters and st.session_state.reactions_dict:
431
- st.session_state.clustering_done = True
432
- st.success(f"Clustering complete. Found {len(st.session_state.clusters)} clusters.")
433
- else:
434
- st.error("Clustering failed or returned empty results.")
435
- st.session_state.clustering_done = False
436
-
437
- # Clean up large intermediate objects if possible
438
- del g_cgrs_dict
439
- del results
440
- gc.collect()
441
 
442
- st.rerun() # Rerun to display clustering results cleanly
 
 
443
  except Exception as e:
444
  st.error(f"An error occurred during clustering: {e}")
445
  st.session_state.clustering_done = False
446
 
447
- # --- Display Clustering Results (if done) ---
448
- if st.session_state.get('clustering_done', False):
449
- clusters = st.session_state.clusters
450
- reactions_dict = st.session_state.reactions_dict # Needed for download
451
-
452
- tree = st.session_state.tree # Needed for display and download
453
 
454
- if not clusters or not reactions_dict or not tree:
455
- st.error("Clustering results are missing from session state. Please re-run clustering.")
456
- st.session_state.clustering_done = False # Reset flag
457
- else:
458
- st.subheader(f"Best routes from {len(clusters)} Found Clusters")
459
- # Display first route from first few clusters
460
- displayed_clusters = 0
461
- for cluster_num, node_id_list in clusters.items():
462
- if displayed_clusters >= 10: # Limit displayed clusters
463
- st.write(f"... and {len(clusters) - displayed_clusters} more clusters.")
464
- break
465
- if not node_id_list: continue # Skip empty clusters
466
-
467
- st.markdown(f"**Cluster {cluster_num}** (Size: {len(node_id_list)}) - Example Route:")
468
- node_id = node_id_list[0] # Display the first route as example
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  try:
470
  num_steps = len(tree.synthesis_route(node_id))
471
  route_score = round(tree.route_score(node_id), 3)
472
  svg = get_route_svg(tree, node_id)
473
- if svg:
474
- st.image(svg, caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  else:
476
- st.warning(f"Could not generate SVG for route {node_id}.")
477
- displayed_clusters += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  except Exception as e:
479
- st.error(f"Error displaying route {node_id} for cluster {cluster_num+1}: {e}")
 
 
480
 
481
- cluster_sizes = [len(cluster) for cluster in clusters.values()]
482
- cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
 
484
- with cluster_stat_col:
485
- st.subheader("Cluster Statistics")
486
- if cluster_sizes:
487
- cluster_df = pd.DataFrame({'Cluster': range(1, len(cluster_sizes) + 1), 'Number of Routes': cluster_sizes})
488
- st.dataframe(cluster_df)
489
 
490
- # Display Pie Chart using Matplotlib
491
- # try:
492
- # fig, ax = plt.subplots(figsize=(5, 4)) # Adjust size if needed
493
- # ax.pie(cluster_sizes, labels=[f'C{i+1}' for i in range(len(cluster_sizes))], autopct='%1.1f%%', startangle=90)
494
- # ax.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle.
495
- # st.pyplot(fig)
496
- # plt.close(fig) # Close the figure to free memory
497
- # except Exception as e:
498
- # st.error(f"Could not generate pie chart: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  else:
500
- st.write("No cluster data to display statistics for.")
 
 
 
 
 
 
 
 
 
 
501
 
502
- with cluster_download_col:
503
- st.subheader("Cluster Reports") # Changed subheader
 
504
 
505
- # Retrieve necessary data from session state
506
- tree_for_html = st.session_state.get('tree')
507
- clusters_for_html = st.session_state.get('clusters')
508
- rg_cgrs_for_html = st.session_state.get('rg_cgrs_dict')
509
 
510
- if not tree_for_html:
511
- st.warning("MCTS Tree data not found. Cannot generate reports.")
512
- elif not clusters_for_html:
513
- st.warning("Cluster data not found. Cannot generate reports.")
514
- else:
515
- st.write("Generate downloadable HTML reports for each cluster:")
516
 
517
- # Limit the number of download links shown directly if there are many clusters
518
- MAX_DOWNLOAD_LINKS_DISPLAYED = 15 # Adjust as needed
519
- num_clusters_total = len(clusters_for_html)
520
- clusters_items = list(clusters_for_html.items()) # Get items to slice
 
 
521
 
522
- for i, (cluster_num_idx, node_ids) in enumerate(clusters_items):
523
- if i >= MAX_DOWNLOAD_LINKS_DISPLAYED:
524
- st.caption(f"... plus {num_clusters_total - MAX_DOWNLOAD_LINKS_DISPLAYED} more cluster reports available.")
525
- # Consider adding a button to download all as a zip if needed
526
- break
527
 
528
- cluster_num_display = int(cluster_num_idx) # Use 1-based index
 
 
 
 
 
529
 
530
- if not node_ids: # Skip empty clusters
531
- st.caption(f"Cluster {cluster_num_display} is empty, no report generated.")
532
- continue
 
 
533
 
534
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
  try:
536
- cluster_html_content = generate_cluster_html(
537
- tree=tree_for_html,
538
- cluster_node_ids=node_ids,
539
- cluster_num=cluster_num_display,
540
- rg_cgrs_dict=rg_cgrs_for_html,
541
- aam=False
542
  )
 
 
 
 
 
 
 
 
 
 
543
  except Exception as e:
544
- st.error(f"Error generating report/link for Cluster {cluster_num_display}: {e}")
545
-
546
-
547
- # --- Create the download button using the existing function ---
548
- download_filename = f"cluster_{cluster_num_display}_report.html"
549
- button_text = f"Cluster {cluster_num_display} Report (HTML)"
550
- dl_button_html = download_button(
551
- object_to_download=cluster_html_content,
552
- download_filename=download_filename,
553
- button_text=button_text
554
- # pickle_it=False # Ensure it's not pickled
555
- )
556
 
557
- if dl_button_html:
558
- st.markdown(dl_button_html, unsafe_allow_html=True)
559
- else:
560
- # Error message if button creation failed (e.g., encoding error)
561
- st.error(f"Failed to create download link for Cluster {cluster_num_display}.")
562
-
563
- except Exception as e:
564
- # Catch errors during HTML generation or button creation for a specific cluster
565
- st.error(f"Error generating report/link for Cluster {cluster_num_display}: {e}")
566
- # Optionally add more detailed logging here:
567
- # import traceback
568
- # st.error(traceback.format_exc())
569
-
570
- if num_clusters_total > MAX_DOWNLOAD_LINKS_DISPLAYED:
571
- # Optional: Add a button here to generate and download a ZIP file
572
- # containing all cluster reports. This requires more implementation
573
- # (using zipfile library in memory).
574
- # e.g., if st.button("Download All Reports as ZIP"): ...
575
- pass
576
-
577
-
578
- st.divider()
579
-
580
- # --- Subclustering Section ---
581
- st.header("Sub-Clustering within a selected Cluster")
582
-
583
- # Button to trigger the subclustering calculation
584
- if st.button("Run Subclustering Analysis", key="submit_subclustering"):
585
- st.session_state.subclustering_done = False # Reset flag
586
- st.session_state.sub = None # Clear old results
587
- with st.spinner("Performing subclustering analysis..."):
588
- try:
589
- # Retrieve necessary data from session state
590
- clusters_for_sub = st.session_state.get('clusters')
591
- reactions_for_sub = st.session_state.get('reactions_dict')
592
- if clusters_for_sub and reactions_for_sub:
593
- sub = sublcuster_all(clusters_for_sub, reactions_for_sub)
594
- st.session_state.sub = sub
595
- st.session_state.subclustering_done = True
596
- st.success("Subclustering analysis complete.")
597
- # Clean up intermediates if possible
598
- gc.collect()
599
- st.rerun() # Rerun to display results/inputs cleanly
600
- else:
601
- st.error("Missing cluster or reaction data needed for subclustering.")
602
- except Exception as e:
603
- st.error(f"An error occurred during subclustering: {e}")
604
- st.session_state.subclustering_done = False
605
-
606
-
607
- # Display subclustering inputs and results ONLY if subclustering is done
608
- if st.session_state.get('subclustering_done', False):
609
- sub = st.session_state.sub
610
- tree = st.session_state.tree
611
- clusters_for_sub = st.session_state.get('clusters')
612
-
613
- if not sub or not tree:
614
- st.error("Subclustering results are missing from session state. Please re-run subclustering.")
615
- st.session_state.subclustering_done = False
616
- else:
617
- sub_input_col, sub_display_col = st.columns([0.2, 0.8]) # Adjust column ratio if needed
618
-
619
- with sub_input_col:
620
- st.subheader("Select Cluster and Step")
621
- # Cluster selection (use cluster numbers as displayed, usually 1-based)
622
- available_cluster_nums = [int(k) for k in sub.keys()] # Use 1-based indexing for UI
623
- if not available_cluster_nums:
624
- st.warning("No clusters available in subclustering results.")
625
- else:
626
- # Key is essential here to maintain state across reruns
627
- user_input_cluster_num_display = st.selectbox(
628
- "Select Cluster #:",
629
- options=sorted(available_cluster_nums),
630
- key='subcluster_num_select'
631
- )
632
 
633
- # Convert back to 0-based index for accessing 'sub' dictionary
634
- selected_cluster_idx = user_input_cluster_num_display
 
 
 
 
 
635
 
636
- if selected_cluster_idx in sub:
637
- sub_step_cluster = sub[selected_cluster_idx]
638
- allowed_step_numbers = sorted(list(sub_step_cluster.keys()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
- if not allowed_step_numbers:
641
- st.warning(f"No reaction steps found for Cluster {user_input_cluster_num_display}.")
642
- else:
643
- # Key is essential here
644
- selected_step_number = st.selectbox(
645
- "Select Number of Steps:",
646
- options=allowed_step_numbers,
647
- key='subcluster_steps_select'
648
- )
649
- # --- Display logic moved to the right column ---
650
- rg_cgrs = st.session_state.get('rg_cgrs_dict')
651
- cluster_rg_cgr = rg_cgrs[clusters_for_sub[user_input_cluster_num_display][0]]
652
- cluster_rg_cgr.clean2d()
653
- st.image(cluster_rg_cgr.depict(), caption=f"RG-CGR of cluster")
654
-
655
-
656
- else:
657
- st.warning(f"Selected cluster {user_input_cluster_num_display} (index {selected_cluster_idx}) not found in subclustering results.")
658
-
659
-
660
-
661
- with sub_display_col:
662
- st.subheader("Subcluster Results")
663
- # Check if inputs are valid before trying to display
664
- if 'user_input_cluster_num_display' in locals() and \
665
- 'selected_cluster_idx' in locals() and \
666
- selected_cluster_idx in sub and \
667
- 'selected_step_number' in locals() and \
668
- selected_step_number in sub[selected_cluster_idx]:
669
-
670
- subclusters = sub[selected_cluster_idx][selected_step_number]
671
- st.write(f"Displaying **{len(subclusters)}** subclusters for **Cluster {user_input_cluster_num_display}** with **{selected_step_number} steps**:")
672
-
673
- if not subclusters:
674
- st.info("No subclusters found for this selection.")
675
- else:
676
- # Limit the display if there are too many subclusters/routes
677
- MAX_DISPLAY_SUBCLUSTERS = 20
678
- MAX_ROUTES_PER_SUBCLUSTER = 10
679
-
680
- for subcluster_num, subcluster_set in enumerate(subclusters):
681
- if subcluster_num >= MAX_DISPLAY_SUBCLUSTERS:
682
- st.write(f"... and {len(subclusters) - MAX_DISPLAY_SUBCLUSTERS} more subclusters.")
683
- break
684
-
685
- st.markdown(f"--- \n**Subcluster {subcluster_num + 1}** (Size: {len(subcluster_set)})")
686
- routes_shown = 0
687
- for route_id in subcluster_set:
688
- if routes_shown >= MAX_ROUTES_PER_SUBCLUSTER:
689
- st.write(f"(Showing first {MAX_ROUTES_PER_SUBCLUSTER} routes)")
690
- break
691
- try:
692
- # Need num_steps and route_score for caption (optional but nice)
693
- num_steps_sub = len(tree.synthesis_route(route_id))
694
- route_score_sub = round(tree.route_score(route_id), 3)
695
- svg_sub = get_route_svg(tree, route_id)
696
- if svg_sub:
697
- st.image(svg_sub, caption=f"Route {route_id}; {num_steps_sub} steps; Score: {route_score_sub}")
698
- else:
699
- st.warning(f"Could not generate SVG for route {route_id}.")
700
- routes_shown += 1
701
- except Exception as e:
702
- st.error(f"Error displaying route {route_id} in subcluster {subcluster_num+1}: {e}")
703
- else:
704
- st.info("Select a cluster and step number to view subclusters.")
705
-
706
-
707
- # --- Handling No Solution Case ---
708
- elif not st.session_state.get('planning_done', False):
709
- # Only show this if planning was attempted but failed (or not run yet)
710
- # Avoid showing it if just molecule changed
711
- if submit_planning: # Check if the button was just pressed
712
- st.warning("Planning did not complete successfully or is still running.")
713
-
714
- else: # Planning done, but res["solved"] is False
715
- st.header('Planning Results')
716
- st.warning("No reaction path found for the target molecule with the current settings.")
717
- st.write("Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity).")
718
- # Optionally display basic stats even if not solved
719
- stat_col, _ = st.columns(2)
720
- with stat_col:
721
- st.subheader("Run Statistics (No Solution)")
722
  try:
723
- if 'target_smiles' not in res: res['target_smiles'] = st.session_state.target_smiles
724
- cols_to_show = [col for col in ["target_smiles", "num_nodes", "num_iter", "search_time"] if col in res]
725
- df = pd.DataFrame(res, index=[0])[cols_to_show]
726
- st.dataframe(df)
 
 
 
 
 
 
 
 
 
 
 
727
  except Exception as e:
728
- st.error(f"Error displaying statistics: {e}")
729
- st.write(res)
730
-
731
-
732
- # --- Restart Button ---
733
- st.divider()
734
- st.header('Restart Application State')
735
- if st.button("Clear All Results & Restart"):
736
- # Clear all relevant session state keys
737
- keys_to_clear = [
738
- "planning_done", "tree", "res", "target_smiles",
739
- "clustering_done", "clusters", "reactions_dict", "num_clusters_setting", "rg_cgrs_dict",
740
- "subclustering_done", "sub",
741
- "clusters_downloaded" # Add any other state keys you use
742
- ]
743
- for key in keys_to_clear:
744
- if key in st.session_state:
745
- del st.session_state[key]
746
- # Clear ketcher state by assigning a default value (or empty string)
747
- st.session_state.ketcher = DEFAULT_MOL # Reset ketcher to default
748
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pickle
3
  import re
4
  import uuid
5
+ import io
6
+ import zipfile
7
 
8
  import pandas as pd
9
  import streamlit as st
 
17
  from synplan.mcts.search import extract_tree_stats
18
  from synplan.mcts.tree import Tree
19
  from synplan.chem.utils import mol_from_smiles
20
+ from synplan.chem.reaction_routes.route_cgr import *
21
+ from synplan.chem.reaction_routes.clustering import *
22
+
23
+ from synplan.utils.visualisation import (
24
+ routes_clustering_report,
25
+ routes_subclustering_report,
26
+ generate_results_html,
27
+ html_top_routes_cluster,
28
+ get_route_svg,
29
+ )
30
  from synplan.utils.config import TreeConfig, PolicyNetworkConfig
31
  from synplan.utils.loading import load_reaction_rules, load_building_blocks
 
 
32
 
 
 
 
 
 
 
 
33
 
34
  import psutil
35
  import gc
 
38
  disable_progress_bars("huggingface_hub")
39
 
40
  smiles_parser = SMILESRead.create_parser(ignore=True)
41
+ DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
42
 
43
+
44
+ # --- Helper Functions ---
45
+ def download_button(
46
+ object_to_download, download_filename, button_text, pickle_it=False
47
+ ):
48
  """
49
  Issued from
50
  Generates a link to download the given object_to_download.
 
76
  pass
77
 
78
  elif isinstance(object_to_download, pd.DataFrame):
79
+ object_to_download = object_to_download.to_csv(index=False).encode("utf-8")
 
 
80
 
81
  try:
 
82
  b64 = base64.b64encode(object_to_download.encode()).decode()
 
83
  except AttributeError:
84
  b64 = base64.b64encode(object_to_download).decode()
85
 
86
+ button_uuid = str(uuid.uuid4()).replace("-", "")
87
+ button_id = re.sub("\d+", "", button_uuid)
88
 
89
+ custom_css = f"""
90
  <style>
91
  #{button_id} {{
92
  background-color: rgb(255, 255, 255);
 
97
  border-style: solid;
98
  border-color: rgb(230, 234, 241);
99
  border-image: initial;
100
+ }}
101
  #{button_id}:hover {{
102
  border-color: rgb(246, 51, 102);
103
  color: rgb(246, 51, 102);
 
109
  }}
110
  </style> """
111
 
112
+ dl_link = (
113
+ custom_css
114
+ + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
115
+ )
116
  return dl_link
117
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
@st.cache_resource
def load_planning_resources_cached():
    """Download the files needed for planning from the Hugging Face Hub.

    The building blocks, the ranking policy network weights and the USPTO
    reaction rules are fetched from the SynPlanner repository into the
    current directory (``local_dir="."``). The function is decorated with
    ``st.cache_resource``.

    Returns:
        tuple: ``(building_blocks_path, ranking_policy_weights_path,
        reaction_rules_path)`` exactly as returned by ``hf_hub_download``.
    """
    repository = "Laboratoire-De-Chemoinformatique/SynPlanner"
    # (filename, subfolder) pairs, listed in the order of the returned tuple.
    wanted_files = (
        ("building_blocks_em_sa_ln.smi", "building_blocks"),
        ("ranking_policy_network.ckpt", "uspto/weights"),
        ("uspto_reaction_rules.pickle", "uspto"),
    )
    building_blocks_path, ranking_policy_weights_path, reaction_rules_path = (
        hf_hub_download(
            repo_id=repository,
            filename=file_name,
            subfolder=folder,
            local_dir=".",
        )
        for file_name, folder in wanted_files
    )
    return building_blocks_path, ranking_policy_weights_path, reaction_rules_path
140
 
 
141
 
142
+ # --- GUI Sections ---
 
 
 
 
143
 
 
144
 
145
def initialize_app():
    """1. Initialization: configure the page, seed session state, show the intro.

    Every session-state key is only written when it is not present yet, so
    values stored by a previous script run are left untouched.
    """
    st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

    session_defaults = {
        # Planning state
        "planning_done": False,
        "tree": None,
        "res": None,
        "target_smiles": "",  # Initial value, might be overwritten by ketcher
        # Clustering state
        "clustering_done": False,
        "clusters": None,
        "reactions_dict": None,
        "num_clusters_setting": 10,  # Stores the setting used
        "route_cgrs_dict": None,
        "r_route_cgrs_dict": None,
        # Subclustering state
        "subclustering_done": False,
        "subclusters": None,
        # Download state
        "clusters_downloaded": False,
        # Molecule editor persistence
        "ketcher": DEFAULT_MOL,
    }
    for state_key, initial_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value

    intro_text = """
    This is a demo of the graphical user interface of
    [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
    SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

    More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
    """
    st.title("`SynPlanner GUI`")
    st.write(intro_text)
197
+
198
+
199
def setup_sidebar():
    """2. Sidebar: render the titled link sections shown in the sidebar."""
    # (title, markdown body) pairs, rendered top to bottom.
    sidebar_sections = (
        ("Docs", "https://synplanner.readthedocs.io/en/latest/"),
        (
            "Tutorials",
            "https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials",
        ),
        (
            "Paper",
            "https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796",
        ),
        (
            "Issues",
            "[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)",
        ),
    )
    for section_title, section_body in sidebar_sections:
        st.sidebar.title(section_title)
        st.sidebar.markdown(section_body)
219
+
220
+
221
def handle_molecule_input():
    """3. Molecule Input: Managing the input area for molecule data.

    Renders a SMILES text box and the Ketcher editor. The value returned by
    the Ketcher widget is treated as the definitive SMILES and is stored in
    ``st.session_state.ketcher``.

    If a previous planning run exists (``target_smiles`` is non-empty) and
    the molecule now differs from the one that run used, all planning,
    clustering and subclustering results are invalidated and a warning is
    shown.

    Returns:
        The SMILES value returned by the Ketcher widget.
    """
    st.header("Molecule input")
    st.markdown(
        """
    You can provide a molecular structure by either providing:
    * SMILES string + Enter
    * Draw it + Apply
    """
    )
    # st.session_state.ketcher persists the drawn molecule between reruns.
    molecule_text_input = st.text_input(
        "SMILES:", value=st.session_state.ketcher, key="smiles_text_input_key"
    )
    current_smile_code = st_ketcher(molecule_text_input, key="ketcher_widget")

    # ``target_smiles`` is "" until planning has been started once; comparing
    # against that placeholder would report a "change" on every fresh page
    # load, so only a real, previously planned target triggers the reset.
    planned_smiles = st.session_state.get("target_smiles")
    if planned_smiles and current_smile_code != planned_smiles:
        st.warning("Molecule structure changed. Please re-run planning.")
        for flag in ("planning_done", "clustering_done", "subclustering_done"):
            st.session_state[flag] = False
        # Same set of result keys that is cleared when planning is restarted,
        # including the RouteCGR dictionaries derived from the old tree.
        for stale_key in (
            "tree",
            "res",
            "clusters",
            "reactions_dict",
            "subclusters",
            "route_cgrs_dict",
            "r_route_cgrs_dict",
        ):
            st.session_state[stale_key] = None

    st.session_state.ketcher = current_smile_code
    return current_smile_code
267
+
268
+
269
def _render_planning_option_widgets():
    """Draw the planning option widgets and return their values as a dict."""
    col_options_1, col_options_2 = st.columns(2, gap="medium")
    with col_options_1:
        search_strategy_input = st.selectbox(
            label="Search strategy",
            options=(
                "Expansion first",
                "Evaluation first",
            ),
            index=0,
            key="search_strategy_input",
        )
        ucb_type = st.selectbox(
            label="UCB type",
            options=("uct", "puct", "value"),
            index=0,
            key="ucb_type_input",
        )
        c_ucb = st.number_input(
            "C coefficient of UCB",
            value=0.1,
            placeholder="Type a number...",
            key="c_ucb_input",
        )

    with col_options_2:
        max_iterations = st.slider(
            "Total number of MCTS iterations",
            min_value=50,
            max_value=1000,
            value=300,
            key="max_iterations_slider",
        )
        max_depth = st.slider(
            "Maximal number of reaction steps",
            min_value=3,
            max_value=9,
            value=6,
            key="max_depth_slider",
        )
        min_mol_size = st.slider(
            "Minimum size of a molecule to be precursor",
            min_value=0,
            max_value=7,
            value=0,
            key="min_mol_size_slider",
            help="Number of non-hydrogen atoms in molecule",
        )

    # Map the human-readable selectbox labels to the config identifiers.
    search_strategy_translator = {
        "Expansion first": "expansion_first",
        "Evaluation first": "evaluation_first",
    }
    return {
        "search_strategy": search_strategy_translator[search_strategy_input],
        "ucb_type": ucb_type,
        "c_ucb": c_ucb,
        "max_iterations": max_iterations,
        "max_depth": max_depth,
        "min_mol_size": min_mol_size,
    }


def _reset_planning_state():
    """Clear the stored results of planning, clustering and subclustering."""
    for flag in ("planning_done", "clustering_done", "subclustering_done"):
        st.session_state[flag] = False
    for stale_key in (
        "tree",
        "res",
        "clusters",
        "reactions_dict",
        "subclusters",
        "route_cgrs_dict",
        "r_route_cgrs_dict",
    ):
        st.session_state[stale_key] = None


def _run_retrosynthetic_planning(active_smile_code, planning_params):
    """Run the tree search for one SMILES and store the outcome in session state.

    On success the tree and its statistics are saved, ``planning_done`` is
    set and the app is rerun. Any exception is reported with ``st.error``
    and leaves ``planning_done`` as False.
    """
    try:
        target_molecule = mol_from_smiles(active_smile_code)
        if target_molecule is None:
            st.error(f"Could not parse the input SMILES: {active_smile_code}")
            return

        (
            building_blocks_path,
            ranking_policy_weights_path,
            reaction_rules_path,
        ) = load_planning_resources_cached()
        with st.spinner("Running retrosynthetic planning..."):
            with st.status("Loading resources...", expanded=False) as status:
                st.write("Loading building blocks...")
                building_blocks = load_building_blocks(
                    building_blocks_path, standardize=False
                )
                st.write("Loading reaction rules...")
                reaction_rules = load_reaction_rules(reaction_rules_path)
                st.write("Loading policy network...")
                policy_config = PolicyNetworkConfig(
                    weights_path=ranking_policy_weights_path
                )
                policy_function = PolicyNetworkFunction(policy_config=policy_config)
                status.update(label="Resources loaded!", state="complete")

            # evaluation_type, init_node_value and silent are fixed here and
            # are not exposed as GUI options.
            tree_config = TreeConfig(
                search_strategy=planning_params["search_strategy"],
                evaluation_type="rollout",
                max_iterations=planning_params["max_iterations"],
                max_depth=planning_params["max_depth"],
                min_mol_size=planning_params["min_mol_size"],
                init_node_value=0.5,
                ucb_type=planning_params["ucb_type"],
                c_ucb=planning_params["c_ucb"],
                silent=True,
            )

            tree = Tree(
                target=target_molecule,
                config=tree_config,
                reaction_rules=reaction_rules,
                building_blocks=building_blocks,
                expansion_function=policy_function,
                evaluation_function=None,
            )

            mcts_progress_text = "Running MCTS iterations..."
            mcts_bar = st.progress(0, text=mcts_progress_text)
            total_iterations = planning_params["max_iterations"]
            # Iterating the tree drives the search; the yielded values are
            # only needed here to count steps for the progress bar.
            for step, _ in enumerate(tree):
                mcts_bar.progress(
                    min(1.0, (step + 1) / total_iterations),
                    text=f"{mcts_progress_text} ({step+1}/{total_iterations})",
                )

            res = extract_tree_stats(tree, target_molecule)

            st.session_state["tree"] = tree
            st.session_state["res"] = res
            st.session_state.planning_done = True
            st.rerun()

    except Exception as e:
        st.error(f"An error occurred during planning: {e}")
        st.session_state.planning_done = False


def setup_planning_options():
    """4. Planning: Encapsulating the logic related to the "planning" functionality.

    Shows the current SMILES and the planning option widgets. When the start
    button is pressed, previous results are cleared, the SMILES held in
    ``st.session_state.ketcher`` is recorded as ``target_smiles`` and the
    planning run is launched.
    """
    st.header("Launch calculation")
    st.markdown(
        """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
    )
    # The displayed SMILES is read from session state rather than passed in.
    st.markdown(
        f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
    )

    st.subheader("Planning options")
    st.markdown(
        """
    The description of each option can be found in the
    [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
    """
    )

    planning_params = _render_planning_option_widgets()

    if st.button("Start retrosynthetic planning", key="submit_planning_button"):
        # Reset downstream states if replanning.
        _reset_planning_state()
        active_smile_code = st.session_state.get("ketcher", DEFAULT_MOL)
        # Remember which SMILES this run was started for.
        st.session_state.target_smiles = active_smile_code
        _run_retrosynthetic_planning(active_smile_code, planning_params)
441
+
442
+
443
def display_planning_results():
    """5. Planning Results Display: Handling the presentation of results.

    Does nothing until ``planning_done`` is set. For a solved run up to three
    winning routes are drawn; for an unsolved run a warning and the available
    run statistics are shown.
    """
    if not st.session_state.get("planning_done", False):
        return

    res = st.session_state.res
    tree = st.session_state.tree
    if res is None or tree is None:
        st.error(
            "Planning results are missing from session state. Please re-run planning."
        )
        st.session_state.planning_done = False  # Reset state
        return

    st.header("Planning Results")

    if res.get("solved", False):
        has_winners = hasattr(tree, "winning_nodes") and tree.winning_nodes
        winning_nodes = sorted(set(tree.winning_nodes)) if has_winners else []
        st.subheader(f"Number of unique routes found: {len(winning_nodes)}")
        st.subheader("Examples of found retrosynthetic routes")

        if not winning_nodes:
            st.warning(
                "Planning solved, but no winning nodes found in the tree object."
            )
            return

        shown_images = 0
        seen_node_ids = set()
        for node_id in winning_nodes:
            if shown_images >= 3:
                break
            if node_id in seen_node_ids:
                continue
            try:
                seen_node_ids.add(node_id)
                num_steps = len(tree.synthesis_route(node_id))
                route_score = round(tree.route_score(node_id), 3)
                svg = get_route_svg(tree, node_id)
                if not svg:
                    st.warning(f"Could not generate SVG for route {node_id}.")
                    continue
                st.image(
                    svg,
                    caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
                )
                shown_images += 1
            except Exception as e:
                st.error(f"Error displaying route {node_id}: {e}")
        return

    # Not solved
    st.warning(
        "No reaction path found for the target molecule with the current settings."
    )
    st.write(
        "Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
    )
    stat_col, _ = st.columns(2)
    with stat_col:
        st.subheader("Run Statistics (No Solution)")
        try:
            target_missing = "target_smiles" not in res
            if target_missing and "target_smiles" in st.session_state:
                res["target_smiles"] = st.session_state.target_smiles
            wanted_columns = ("target_smiles", "num_nodes", "num_iter", "search_time")
            cols_to_show = [col for col in wanted_columns if col in res]
            if not cols_to_show:
                st.write("No statistics to display for the unsuccessful run.")
            else:
                st.dataframe(pd.DataFrame(res, index=[0])[cols_to_show])
        except Exception as e:
            st.error(f"Error displaying statistics: {e}")
            st.write(res)
530
+
531
+
532
def download_planning_results():
    """6. Planning Results Download: Providing functionality to download.

    Only active for a finished, solved planning run. Emits an HTML download
    link for the full results and, if that succeeded, a CSV download link for
    the run statistics.
    """
    state = st.session_state
    planning_solved = (
        state.get("planning_done", False)
        and state.res
        and state.res.get("solved", False)
    )
    if not planning_solved:
        return

    res = state.res
    tree = state.tree

    # The links are written at the current position in the page; no dedicated
    # column or subheader is created for them here.
    try:
        html_body = generate_results_html(tree, html_path=None, extended=True)
        html_link = download_button(
            html_body,
            f"results_synplanner_{st.session_state.target_smiles}.html",
            "Download results (HTML)",
        )
        if html_link:
            st.markdown(html_link, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Error generating download links for planning results: {e}")
        return  # the statistics link is only attempted after the HTML link

    try:
        stats_frame = pd.DataFrame(res, index=[0])
        csv_link = download_button(
            stats_frame,
            f"stats_synplanner_{st.session_state.target_smiles}.csv",
            "Download statistics (CSV)",
        )
        if csv_link:
            st.markdown(csv_link, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Could not prepare statistics CSV for download: {e}")
574
 
575
+
576
def setup_clustering():
    """7. Clustering: Encapsulating the logic related to the "clustering" functionality.

    Shown only after a solved planning run. Pressing the button clears earlier
    clustering/subclustering results, builds the RouteCGR and ReducedRouteCGR
    dictionaries from the stored tree, clusters the routes, stores everything
    in session state and reruns the app.
    """
    state = st.session_state
    if not (
        state.get("planning_done", False)
        and state.res
        and state.res.get("solved", False)
    ):
        return

    st.divider()
    st.header("Clustering the retrosynthetic routes")

    if not st.button("Run Clustering", key="submit_clustering_button"):
        return

    state.clustering_done = False
    state.subclustering_done = False
    for stale_key in (
        "clusters",
        "reactions_dict",
        "subclusters",
        "route_cgrs_dict",
        "r_route_cgrs_dict",
    ):
        state[stale_key] = None

    with st.spinner("Performing clustering..."):
        try:
            current_tree = state.tree
            if not current_tree:
                st.error("Tree object not found. Please re-run planning.")
                return

            st.write("Calculating RoutesCGRs...")
            route_cgrs_dict = compose_all_route_cgrs(current_tree)
            st.write("Processing ReducedRoutesCGRs...")
            r_route_cgrs_dict = compose_all_reduced_route_cgrs(route_cgrs_dict)

            raw_clusters = cluster_routes(r_route_cgrs_dict, use_strat=False)
            # Order the clusters by the numeric value of their key.
            ordered_clusters = dict(
                sorted(raw_clusters.items(), key=lambda item: float(item[0]))
            )

            state.clusters = ordered_clusters
            state.route_cgrs_dict = route_cgrs_dict
            state.r_route_cgrs_dict = r_route_cgrs_dict
            st.write("Extracting reactions...")
            state.reactions_dict = extract_reactions(current_tree)

            clustering_ok = (
                state.clusters is not None and state.reactions_dict is not None
            )
            if clustering_ok:
                state.clustering_done = True
                st.success(
                    f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
                )
            else:
                st.error("Clustering failed or returned empty results.")
                state.clustering_done = False

            # The CGR dictionaries stay referenced from session state; only
            # the local cluster mappings are released here.
            del raw_clusters, ordered_clusters
            gc.collect()
            st.rerun()
        except Exception as e:
            st.error(f"An error occurred during clustering: {e}")
            state.clustering_done = False
643
 
 
 
 
 
 
 
644
 
645
def _render_cluster_preview(tree, cluster_num, group_data, location_note=""):
    """Show the first route of one cluster beside its ReducedRouteCGR image.

    Args:
        tree: Search tree used to look up and draw the route.
        cluster_num: Key of the cluster, used in headings and messages.
        group_data: Mapping for the cluster; ``node_ids`` is required,
            ``group_size`` and ``r_route_cgr`` are read when present.
        location_note: Text inserted into the "no data" warning to tell
            where the cluster was being rendered.
    """
    if not group_data or "node_ids" not in group_data or not group_data["node_ids"]:
        st.warning(f"Cluster {cluster_num}{location_note} has no data or node_ids.")
        return

    st.markdown(
        f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
    )
    # The first node id of the group is the route that gets displayed.
    node_id = group_data["node_ids"][0]
    try:
        num_steps = len(tree.synthesis_route(node_id))
        route_score = round(tree.route_score(node_id), 3)
        svg = get_route_svg(tree, node_id)
        route_caption = (
            f"Route {node_id}; {num_steps} steps; Route score: {route_score}"
        )

        r_route_cgr = group_data.get("r_route_cgr")
        r_route_cgr_svg = None
        if r_route_cgr:
            r_route_cgr.clean2d()
            r_route_cgr_svg = cgr_display(r_route_cgr)

        if svg and r_route_cgr_svg:
            col1, col2 = st.columns([0.2, 0.8])
            with col1:
                st.image(r_route_cgr_svg, caption="ReducedRouteCGR")
            with col2:
                st.image(svg, caption=route_caption)
        elif svg:  # Only the route SVG is available
            st.image(svg, caption=route_caption)
            st.warning(
                f"ReducedRouteCGR could not be displayed for cluster {cluster_num}."
            )
        else:
            st.warning(
                f"Could not generate SVG for route {node_id} or its ReducedRouteCGR."
            )
    except Exception as e:
        st.error(f"Error displaying route {node_id} for cluster {cluster_num}: {e}")


def display_clustering_results():
    """8. Clustering Results Display: Handling the presentation of results.

    Does nothing until ``clustering_done`` is set. The first
    ``MAX_DISPLAY_CLUSTERS_DATA`` clusters are rendered directly; any further
    clusters are rendered inside a collapsed expander.
    """
    if not st.session_state.get("clustering_done", False):
        return

    clusters = st.session_state.clusters
    tree = st.session_state.tree
    MAX_DISPLAY_CLUSTERS_DATA = 10

    if clusters is None or tree is None:
        st.error(
            "Clustering results (clusters or tree) are missing. Please re-run clustering."
        )
        st.session_state.clustering_done = False
        return

    st.subheader(f"Best routes from {len(clusters)} Found Clusters")
    clusters_items = list(clusters.items())
    first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
    remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]

    for cluster_num, group_data in first_items:
        _render_cluster_preview(tree, cluster_num, group_data)

    if remaining_items:
        with st.expander(f"... and {len(remaining_items)} more clusters"):
            for cluster_num, group_data in remaining_items:
                _render_cluster_preview(
                    tree, cluster_num, group_data, location_note=" in expansion"
                )
766
+
767
+
768
def download_clustering_results():
    """10. Clustering Results Download: Providing functionality to download.

    Does nothing until ``clustering_done`` is set. Offers one HTML report
    download button per cluster (the first ``MAX_DOWNLOAD_LINKS_DISPLAYED``
    directly, the rest inside an expander) plus one ZIP archive containing
    all cluster reports.

    Each report is generated at most once per script run and reused for both
    its own button and the ZIP archive.
    """
    if not st.session_state.get("clustering_done", False):
        return

    tree_for_html = st.session_state.get("tree")
    clusters_for_html = st.session_state.get("clusters")
    r_route_cgrs_for_html = st.session_state.get("r_route_cgrs_dict")

    if not tree_for_html:
        st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
        return
    if not clusters_for_html:
        st.warning("Cluster data not found. Cannot generate cluster reports.")
        return

    st.subheader("Cluster Reports")
    st.write("Generate downloadable HTML reports for each cluster:")

    MAX_DOWNLOAD_LINKS_DISPLAYED = 10
    clusters_items = list(clusters_for_html.items())

    # SMILES may contain '/' and '\' (stereo bonds) and other characters that
    # are path separators or invalid in file names; inside the ZIP they would
    # silently create nested folders. Replace them before building names.
    smiles_tag = re.sub(
        r'[\\/:*?"<>|]', "_", str(st.session_state.get("target_smiles", ""))
    )

    report_cache = {}

    def build_report(cluster_key):
        """Return the HTML report for one cluster, generating it on first use."""
        if cluster_key not in report_cache:
            report_cache[cluster_key] = routes_clustering_report(
                tree_for_html,
                clusters_for_html,  # the whole cluster dict
                str(cluster_key),  # key of the cluster to report on
                r_route_cgrs_for_html,
                aam=False,
            )
        return report_cache[cluster_key]

    def offer_report(cluster_key, widget_key, error_note=""):
        """Render the download button for one cluster report."""
        try:
            st.download_button(
                label=f"Download report for cluster {cluster_key}",
                data=build_report(cluster_key),
                file_name=f"cluster_{cluster_key}_{smiles_tag}.html",
                mime="text/html",
                key=widget_key,
            )
        except Exception as e:
            st.error(
                f"Error generating report for cluster {cluster_key}{error_note}: {e}"
            )

    for cluster_idx, _ in clusters_items[:MAX_DOWNLOAD_LINKS_DISPLAYED]:
        offer_report(cluster_idx, f"download_cluster_{cluster_idx}")

    remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
    if remaining_items:
        expander_label = f"Show remaining {len(remaining_items)} cluster reports"
        with st.expander(expander_label):
            for group_index, _ in remaining_items:
                offer_report(
                    group_index,
                    f"download_cluster_expanded_{group_index}",
                    error_note=" (expanded)",
                )

    try:
        buffer = io.BytesIO()
        with zipfile.ZipFile(
            buffer, mode="w", compression=zipfile.ZIP_DEFLATED
        ) as zf:
            for idx, _ in clusters_items:
                # A report that failed above is retried here; if it fails
                # again the whole archive is reported as failed.
                zf.writestr(f"cluster_{idx}_{smiles_tag}.html", build_report(idx))
        buffer.seek(0)

        st.download_button(
            label="📦 Download all cluster reports as ZIP",
            data=buffer,
            file_name=f"all_cluster_reports_{smiles_tag}.zip",
            mime="application/zip",
            key="download_all_clusters_zip",
        )
    except Exception as e:
        st.error(f"Error generating ZIP file for cluster reports: {e}")
870
 
 
 
 
 
 
871
 
872
def setup_subclustering():
    """11. Subclustering: render the "Run Subclustering Analysis" control and run it.

    Nothing is drawn unless ``st.session_state["clustering_done"]`` is truthy.
    When the button is pressed, the previous subclustering state is reset, the
    three inputs (clusters, ReducedRouteCGRs, RouteCGRs) are read from the
    session state and passed to ``subcluster_all_clusters``.  On success the
    result is stored under ``subclusters`` and the script is rerun; if any
    input is missing or empty an error naming it is shown instead.
    """
    # Subclustering depends on clustering being done.
    if not st.session_state.get("clustering_done", False):
        return

    st.divider()
    st.header("Sub-Clustering within a selected Cluster")

    if not st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
        return

    # Invalidate any earlier result before starting a new run.
    st.session_state.subclustering_done = False
    st.session_state.subclusters = None

    with st.spinner("Performing subclustering analysis..."):
        try:
            # Human-readable label -> value; the labels are reused verbatim in
            # the "missing data" message below.
            required_inputs = {
                "clusters": st.session_state.get("clusters"),
                "ReducedRouteCGRs dictionary": st.session_state.get(
                    "r_route_cgrs_dict"
                ),
                "RouteCGRs dictionary": st.session_state.get("route_cgrs_dict"),
            }
            missing = [
                label for label, value in required_inputs.items() if not value
            ]

            if missing:
                st.error(
                    f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
                )
                st.session_state.subclustering_done = False
            else:
                clusters_for_sub, reduced_cgrs, full_cgrs = required_inputs.values()
                st.session_state.subclusters = subcluster_all_clusters(
                    clusters_for_sub,
                    reduced_cgrs,
                    full_cgrs,
                )
                st.session_state.subclustering_done = True
                st.success("Subclustering analysis complete.")
                gc.collect()
                st.rerun()

        except Exception as e:
            st.error(f"An error occurred during subclustering: {e}")
            st.session_state.subclustering_done = False
922
 
 
 
 
 
923
 
924
def display_subclustering_results():
    """12. Subclustering Results Display: Handling the presentation of results.

    Reads ``subclusters`` and ``tree`` from the session state and lays the
    page out in two columns: the left one holds the cluster / subcluster
    select boxes and the ReducedRouteCGR depiction, the right one shows the
    synthon reaction and the routes of the selected subcluster (the first
    five directly, the rest inside an expander).

    If either ``subclusters`` or ``tree`` is missing, an error is shown and
    ``subclustering_done`` is reset to ``False``.
    """
    if not st.session_state.get("subclustering_done", False):
        return

    sub = st.session_state.get("subclusters")
    tree = st.session_state.get("tree")

    if not (sub and tree):
        st.error(
            "Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
        )
        st.session_state.subclustering_done = False
        return

    def render_routes(route_ids, error_suffix):
        # Shared by the directly shown routes and the ones in the expander;
        # only the error wording differs between the two call sites.
        for route_id in route_ids:
            try:
                route_score_sub = round(tree.route_score(route_id), 3)
                svg_sub = get_route_svg(tree, route_id)
                if svg_sub:
                    st.image(
                        svg_sub,
                        caption=f"Route {route_id}; Score: {route_score_sub}",
                    )
                else:
                    st.warning(f"Could not generate SVG for route {route_id}.")
            except Exception as e:
                st.error(
                    f"Error displaying route {route_id} in subcluster{error_suffix}: {e}"
                )

    sub_input_col, sub_display_col = st.columns([0.25, 0.75])

    with sub_input_col:
        st.subheader("Select Cluster and Subcluster")
        available_cluster_nums = list(sub.keys())
        if not available_cluster_nums:
            st.warning("No clusters available in subclustering results.")
            return  # Nothing to select from.

        user_input_cluster_num_display = st.selectbox(
            "Select Cluster #:",
            options=sorted(available_cluster_nums),
            key="subcluster_num_select_key",
        )

        # Fallback index used when the cluster offers no subclusters.
        selected_subcluster_idx = 0

        if user_input_cluster_num_display not in sub:
            st.warning(
                f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
            )
            return

        allowed_subclusters_indices = sorted(
            sub[user_input_cluster_num_display].keys()
        )
        if not allowed_subclusters_indices:
            st.warning(
                f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
            )
        else:
            selected_subcluster_idx = st.selectbox(
                "Select Subcluster Index:",
                options=allowed_subclusters_indices,
                key="subcluster_index_select_key",
            )
            if selected_subcluster_idx in sub[user_input_cluster_num_display]:
                current_subcluster_data = sub[user_input_cluster_num_display][
                    selected_subcluster_idx
                ]
                if "r_route_cgr" in current_subcluster_data:
                    parent_cgr = current_subcluster_data["r_route_cgr"]
                    parent_cgr.clean2d()
                    st.image(
                        parent_cgr.depict(),
                        caption=f"ReducedRouteCGR of parent Cluster {user_input_cluster_num_display}",
                    )
                else:
                    st.warning("ReducedRouteCGR for this subcluster not found.")

    with sub_display_col:
        st.subheader("Subcluster Details")

        selection_is_valid = (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        )
        if not selection_is_valid:
            st.info("Select a valid cluster and subcluster index to see details.")
            return

        # NOTE: post_process_subgroup() is deliberately not applied here; the
        # original marks that step as "Under development".
        subcluster_to_display = sub[user_input_cluster_num_display][
            selected_subcluster_idx
        ]

        has_routes = (
            bool(subcluster_to_display)
            and "nodes_data" in subcluster_to_display
            and bool(subcluster_to_display["nodes_data"])
        )
        if not has_routes:
            st.info("No routes or data found for this subcluster selection.")
            return

        MAX_ROUTES_PER_SUBCLUSTER = 5
        all_route_ids_in_subcluster = list(subcluster_to_display["nodes_data"].keys())
        routes_to_display_direct = all_route_ids_in_subcluster[
            :MAX_ROUTES_PER_SUBCLUSTER
        ]
        remaining_routes_sub = all_route_ids_in_subcluster[MAX_ROUTES_PER_SUBCLUSTER:]

        st.markdown(
            f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
        )

        if "synthon_reaction" in subcluster_to_display:
            synthon_reaction = subcluster_to_display["synthon_reaction"]
            synthon_reaction.clean2d()
            try:
                st.image(
                    depict_custom_reaction(synthon_reaction),
                    caption="Markush-like pseudo reaction of subcluster",
                )
            except Exception as e_depict:
                st.warning(f"Could not depict synthon reaction: {e_depict}")
        else:
            st.info("No synthon reaction data for this subcluster.")

        render_routes(routes_to_display_direct, "")

        if remaining_routes_sub:
            with st.expander(
                f"... and {len(remaining_routes_sub)} more routes in this subcluster"
            ):
                render_routes(remaining_routes_sub, " (expanded)")
 
 
 
 
 
 
 
1081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1082
 
1083
def download_subclustering_results():
    """13. Subclustering Results Download: Providing functionality to download.

    Offers an HTML report for the subcluster currently selected through the
    ``subcluster_num_select_key`` / ``subcluster_index_select_key`` widgets.
    Nothing is rendered unless subclustering is done and both widget keys are
    present in the session state.

    The report is built from a post-processed copy of the subcluster
    (``post_process_subgroup``) extended with a ``group_lgs`` entry computed
    by ``group_by_identical_values`` from its ``nodes_data``.  Any failure
    while preparing the data or generating the report is shown with
    ``st.error`` instead of aborting the page.
    """
    state = st.session_state
    if not (
        state.get("subclustering_done", False)
        and "subcluster_num_select_key" in state
        and "subcluster_index_select_key" in state
    ):
        return

    sub = state.get("subclusters")
    tree = state.get("tree")
    # Used by routes_subclustering_report.
    r_route_cgrs_for_report = state.get("r_route_cgrs_dict")

    user_input_cluster_num_display = state.subcluster_num_select_key
    selected_subcluster_idx = state.subcluster_index_select_key

    if not tree or not sub or not r_route_cgrs_for_report:
        st.warning(
            "Missing data for subclustering report generation (tree, subclusters, or ReducedRouteCGRs)."
        )
        return

    if not (
        user_input_cluster_num_display in sub
        and selected_subcluster_idx in sub[user_input_cluster_num_display]
    ):
        # An invalid selection is reported by the display logic; there is
        # simply no download button in that case.
        return

    subcluster_data_for_report = sub[user_input_cluster_num_display][
        selected_subcluster_idx
    ]

    try:
        # Data preparation lives inside the try block so that a failure in
        # post-processing is reported the same way as a failure in report
        # generation.
        processed_subcluster_data = post_process_subgroup(subcluster_data_for_report)
        nodes_data = (
            subcluster_data_for_report["nodes_data"]
            if "nodes_data" in subcluster_data_for_report
            else None
        )
        if isinstance(nodes_data, dict):
            processed_subcluster_data["group_lgs"] = group_by_identical_values(
                nodes_data
            )
        else:
            processed_subcluster_data["group_lgs"] = {}

        subcluster_html_content = routes_subclustering_report(
            tree,
            processed_subcluster_data,  # the specific post-processed subcluster
            user_input_cluster_num_display,
            selected_subcluster_idx,
            r_route_cgrs_for_report,  # the whole r_route_cgrs dict
            if_lg_group=True,
        )
        st.download_button(
            label=f"Download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}",
            data=subcluster_html_content,
            file_name=f"subcluster_{user_input_cluster_num_display}.{selected_subcluster_idx}_{st.session_state.target_smiles}.html",
            mime="text/html",
            key=f"download_subcluster_{user_input_cluster_num_display}_{selected_subcluster_idx}",
        )
    except Exception as e:
        st.error(
            f"Error generating download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}: {e}"
        )
1149
+
1150
+
1151
def implement_restart():
    """14. Restart: Implementing the logic to reset or restart the application state.

    Renders the "Clear All Results & Restart" button.  When it is pressed,
    every result and widget key listed below is removed from the session
    state, the ketcher input is set back to ``DEFAULT_MOL``, ``target_smiles``
    is set to an empty string, and the script is rerun.
    """
    st.divider()
    st.header("Restart Application State")

    if not st.button("Clear All Results & Restart", key="restart_button"):
        return

    resettable_keys = (
        # Planning results
        "planning_done",
        "tree",
        "res",
        "target_smiles",
        # Clustering results
        "clustering_done",
        "clusters",
        "reactions_dict",
        "num_clusters_setting",
        "route_cgrs_dict",
        "r_route_cgrs_dict",
        # Subclustering results ("sub" was renamed to "subclusters")
        "subclustering_done",
        "subclusters",
        "clusters_downloaded",
        # Widget keys, cleared explicitly so a full reset does not keep
        # stale selections around.
        "ketcher_widget",
        "smiles_text_input_key",
        "subcluster_num_select_key",
        "subcluster_index_select_key",
    )
    for stale_key in resettable_keys:
        if stale_key not in st.session_state:
            continue
        del st.session_state[stale_key]

    # Reset ketcher input to default by resetting its session state variable.
    st.session_state.ketcher = DEFAULT_MOL
    # Explicitly blank target_smiles to avoid stale data.
    st.session_state.target_smiles = ""

    st.rerun()
1188
+
1189
+
1190
+ # --- Main Application Flow ---
1191
def _render_planning_statistics():
    """Show the one-row planning statistics table in the current container."""
    st.subheader("Statistics")
    # Bound before the try block so the fallback in the handler can always
    # display it.
    res = st.session_state.res
    try:
        if "target_smiles" not in res and "target_smiles" in st.session_state:
            res["target_smiles"] = st.session_state.target_smiles
        cols_to_show = [
            col
            for col in (
                "target_smiles",
                "num_routes",
                "num_nodes",
                "num_iter",
                "search_time",
            )
            if col in res
        ]
        if cols_to_show:
            st.dataframe(pd.DataFrame(res, index=[0])[cols_to_show])
        else:
            st.write("No statistics to display from planning results.")
    except Exception as e:
        st.error(f"Error displaying statistics: {e}")
        st.write(res)  # Show the raw dict if the DataFrame fails.


def _render_cluster_statistics():
    """Show the cluster size table and the "best route per cluster" download."""
    clusters = st.session_state.clusters
    # Empty clusters are left out of both columns so the two stay aligned.
    populated = {number: data for number, data in clusters.items() if data}

    st.subheader("Cluster Statistics")
    if not populated:
        st.write("No cluster data to display statistics for.")
        return

    cluster_df = pd.DataFrame(
        {
            "Cluster": list(populated),
            # .get() keeps a cluster without "group_size" from raising KeyError.
            "Number of Routes": [
                data.get("group_size", 0) for data in populated.values()
            ],
        }
    )
    cluster_df.index += 1  # 1-based row labels
    st.dataframe(cluster_df)

    best_route_html = html_top_routes_cluster(
        clusters,
        st.session_state.tree,
        st.session_state.target_smiles,
    )
    st.download_button(
        label="Download best route from each cluster",
        data=best_route_html,
        file_name=f"cluster_best_{st.session_state.target_smiles}.html",
        mime="text/html",
        key="download_cluster_best",
    )


# --- Main Application Flow ---
def main():
    """Run one pass of the page: input, planning, clustering, subclustering, restart.

    Each section is only rendered once the session state says the previous
    stage has produced results: planning output needs ``planning_done``,
    clustering needs a solved planning result, and subclustering needs
    ``clustering_done``.  The restart control is always rendered last.
    """
    initialize_app()
    setup_sidebar()
    current_smile_code = handle_molecule_input()
    # Keep session_state.ketcher in sync with the ketcher output.  No rerun
    # here; handle_molecule_input already warns.
    if st.session_state.get("ketcher") != current_smile_code:
        st.session_state.ketcher = current_smile_code

    # Also handles the button press and the planning logic.
    setup_planning_options()

    # Planning results and download options together.
    if st.session_state.get("planning_done", False):
        display_planning_results()
        if st.session_state.res and st.session_state.res.get("solved", False):
            stat_col, download_col = st.columns(2, gap="medium")
            with stat_col:
                _render_planning_statistics()
            with download_col:
                st.subheader("Planning Downloads")
                download_planning_results()

    # Clustering section (setup button, display, download).
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        setup_clustering()  # Contains the "Run Clustering" button and logic.
        if st.session_state.get("clustering_done", False):
            display_clustering_results()
            cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
            with cluster_stat_col:
                _render_cluster_statistics()
            with cluster_download_col:
                download_clustering_results()

    # Subclustering section (setup button, display, download).
    if st.session_state.get("clustering_done", False):
        setup_subclustering()  # Contains the "Run Subclustering" button.
        if st.session_state.get("subclustering_done", False):
            display_subclustering_results()
            # Must come after display: it reads the selections made there.
            download_subclustering_results()

    implement_restart()
1302
+
1303
+
1304
# Start the application only when this file is executed as a script.
if __name__ == "__main__":
    main()
cluster/clustering.py DELETED
@@ -1,174 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- from scipy.spatial.distance import squareform
4
- from scipy.cluster.hierarchy import fcluster
5
- from sklearn.metrics import silhouette_score, calinski_harabasz_score
6
- import fastcluster
7
-
8
def tanimoto_similarity_continuous(matrix_1, matrix_2):
    """
    Pairwise continuous Tanimoto coefficient between the rows of two matrices.

    For rows ``x`` and ``y`` the coefficient is
    ``x.y / (|x|^2 + |y|^2 - x.y)``.  It is also known as the Jaccard index.

    Adopted from https://github.com/cimm-kzn/CIMtools/blob/master/CIMtools/metrics/pairwise.py

    :param matrix_1: 2D array-like of features, one row per item.
    :param matrix_2: 2D array-like of features, one row per item.
    :return: Array of shape ``(len(matrix_1), len(matrix_2))``.  Entries whose
        denominator is zero (both rows all-zero) are 0.  When both inputs have
        the same shape the diagonal is set to 1.0.
    """
    matrix_1 = np.asarray(matrix_1)
    matrix_2 = np.asarray(matrix_2)

    x_dot = np.dot(matrix_1, matrix_2.T)

    x2 = (matrix_1**2).sum(axis=1)
    y2 = (matrix_2**2).sum(axis=1)

    # Broadcasting builds the (n, m) denominator without replicating rows
    # through Python lists.
    denominator = x2[:, np.newaxis] + y2[np.newaxis, :] - x_dot

    # 0/0 is expected for pairs of all-zero rows; it is mapped to 0 below, so
    # the floating-point warning would only be noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = x_dot / denominator
    result[np.isnan(result)] = 0

    if matrix_1.shape == matrix_2.shape:
        np.fill_diagonal(result, 1.0)

    return result
36
-
37
def calculate_fingerprints(cgrs, fingerprint_method):
    """Compute one fingerprint per CGR.

    Args:
        cgrs (dict): Dictionary of CGRs; only the values are used, in
            dictionary order.
        fingerprint_method: Object exposing ``transform(list)``; it is called
            once per CGR with a single-element list and the first row of the
            result is kept.

    Returns:
        np.ndarray: Array of fingerprints, one row per CGR.
    """
    return np.array(
        [fingerprint_method.transform([cgr])[0] for cgr in cgrs.values()]
    )
52
-
53
def create_similarity_matrix(fingerprints, labels):
    """Build a labelled all-against-all similarity table.

    Args:
        fingerprints (np.ndarray): Array of fingerprints, one row per item.
        labels (list): Labels used for both the rows and the columns.

    Returns:
        pd.DataFrame: Similarity of every fingerprint to every other one, as
        computed by ``tanimoto_similarity_continuous``.
    """
    pairwise = tanimoto_similarity_continuous(fingerprints, fingerprints)
    return pd.DataFrame(pairwise, index=labels, columns=labels)
65
-
66
def calculate_linkage(similarity_df, method='average'):
    """Compute the hierarchical-clustering linkage matrix.

    The similarity matrix is turned into a distance matrix (``1 - similarity``),
    condensed with ``squareform`` and passed to ``fastcluster.linkage``.

    Args:
        similarity_df (pd.DataFrame): Square similarity matrix.
        method (str): Linkage method forwarded to ``fastcluster.linkage``.

    Returns:
        np.ndarray: Linkage matrix.
    """
    condensed_distance = squareform(1 - similarity_df)
    return fastcluster.linkage(condensed_distance, method=method)
79
-
80
def optimal_cluster_num(Z, distance_matrix, max_clusters=10):
    """Find the number of clusters with the best silhouette score.

    Every candidate from 2 up to and including ``max_clusters`` is tried by
    cutting the tree with ``fcluster(..., criterion='maxclust')``.  Cuts that
    yield fewer than 2 or more than ``n_samples - 1`` distinct labels are
    skipped, because ``silhouette_score`` rejects them.  On ties the smallest
    candidate wins.

    Args:
        Z (np.ndarray): Linkage matrix.
        distance_matrix: Square precomputed distance matrix.
        max_clusters (int): Maximum number of clusters to consider (inclusive).

    Returns:
        int: Optimal number of clusters.

    Raises:
        ValueError: If no candidate produces a labelling that can be scored.
    """
    n_samples = len(distance_matrix)
    best_n_clusters = None
    best_score = float("-inf")

    for n_clusters in range(2, max_clusters + 1):
        cluster_labels = fcluster(Z, n_clusters, criterion='maxclust')
        n_labels = len(np.unique(cluster_labels))
        if not 2 <= n_labels <= n_samples - 1:
            continue
        score = silhouette_score(distance_matrix, cluster_labels, metric='precomputed')
        # Strict comparison keeps the first (smallest) candidate on ties.
        if score > best_score:
            best_n_clusters, best_score = n_clusters, score

    if best_n_clusters is None:
        raise ValueError(
            f"No cluster count in [2, {max_clusters}] yields a valid silhouette "
            f"score for {n_samples} samples"
        )
    return best_n_clusters
100
-
101
def perform_clustering(Z, similarity_df, threshold=0.0, max_clusters=10):
    """Cut the linkage tree, falling back to a silhouette-optimised cut.

    The tree is first cut at ``threshold`` (distance criterion).  If that
    yields a cluster label larger than ``max_clusters``, the number of
    clusters is chosen by ``optimal_cluster_num`` instead and the tree is
    re-cut with the ``maxclust`` criterion.

    Args:
        Z (np.ndarray): Linkage matrix.
        similarity_df: Similarity matrix; only used (as ``1 - similarity_df``)
            when the fallback is triggered.
        threshold (float): Distance threshold for initial clustering.
        max_clusters (int): Maximum number of clusters.

    Returns:
        np.ndarray: Cluster labels.
    """
    initial_labels = fcluster(Z, t=threshold, criterion='distance')
    if max(np.unique(initial_labels)) <= max_clusters:
        return initial_labels

    n_clusters = optimal_cluster_num(Z, 1 - similarity_df, max_clusters)
    return fcluster(Z, n_clusters, criterion='maxclust')
120
-
121
def create_clusters_dict(cluster_labels, labels):
    """Group item labels by the cluster they were assigned to.

    Args:
        cluster_labels (array-like): Cluster assignment of every item.
        labels (array-like): Labels of the items, in the same order.  Plain
            lists are accepted as well as arrays.

    Returns:
        dict: Maps each distinct cluster number (in ascending order) to the
        list of member labels, in their original order.
    """
    cluster_labels = np.asarray(cluster_labels)
    # Boolean-mask indexing below needs an array; a plain list would raise.
    labels = np.asarray(labels)

    return {
        cluster: list(labels[cluster_labels == cluster])
        for cluster in np.unique(cluster_labels)
    }
139
-
140
def cluster_molecules(cgrs, fingerprint_method, threshold=0.0, max_clusters=10, linkage_method='average'):
    """Run the whole clustering pipeline on a dictionary of CGRs.

    Steps: fingerprints -> similarity matrix -> linkage -> cluster labels ->
    clusters dictionary.

    Args:
        cgrs (dict): Dictionary of CGRs; its keys become the item labels.
        fingerprint_method: Initialized fingerprint calculator.
        threshold (float): Distance threshold for clustering.
        max_clusters (int): Maximum number of clusters.
        linkage_method (str): Method for hierarchical clustering.

    Returns:
        dict: ``clusters_dict``, ``cluster_labels``, ``similarity_matrix`` and
        ``linkage_matrix``.
    """
    item_labels = list(cgrs.keys())

    similarity_df = create_similarity_matrix(
        calculate_fingerprints(cgrs, fingerprint_method), item_labels
    )
    linkage_matrix = calculate_linkage(similarity_df, method=linkage_method)
    cluster_labels = perform_clustering(
        linkage_matrix, similarity_df, threshold, max_clusters
    )

    return {
        'clusters_dict': create_clusters_dict(cluster_labels, np.array(item_labels)),
        'cluster_labels': cluster_labels,
        'similarity_matrix': similarity_df,
        'linkage_matrix': linkage_matrix,
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster/generalized_cgr.py DELETED
@@ -1,204 +0,0 @@
1
def find_next_atom_num(accum_cgr, reactions):
    """Return the first atom number not used by any reaction of the route.

    Each reaction is composed into a CGR and the largest key of its ``_atoms``
    mapping is taken; the result is one more than the overall maximum (1 when
    ``reactions`` is empty).  ``accum_cgr`` is accepted but not read.
    """
    highest_per_reaction = (
        max(reaction.compose()._atoms.keys()) for reaction in reactions
    )
    return max(0, max(highest_per_reaction, default=0)) + 1
8
-
9
def get_clean_mapping(curr_prod, prod, reverse=False):
    """Get a conflict-free atom renumbering between two molecules.

    Only the first mapping yielded by ``curr_prod.get_mapping(prod)`` is used.
    Pairs are dropped when they are the identity, when the mapped-to number is
    itself a key that does not map straight back (such pairs could clash), or
    when the resulting target number already exists in the molecule being
    renumbered into.

    Args:
        curr_prod: Molecule providing ``get_mapping`` and ``_atoms``.
        prod: Molecule providing ``_atoms``.
        reverse (bool): If ``False`` the result maps key -> value and targets
            are checked against ``prod``'s atoms; if ``True`` it maps
            value -> key and targets are checked against ``curr_prod``'s atoms.

    Returns:
        dict: Source atom number -> target atom number; empty when no mapping
        exists.
    """
    # Take just the first mapping instead of materialising all of them; the
    # rest were never used.
    first_mapping = next(iter(curr_prod.get_mapping(prod)), None)
    if first_mapping is None:
        return {}

    # Atom numbers that a target must not collide with.
    occupied = set(curr_prod._atoms.keys() if reverse else prod._atoms.keys())

    dict_map = {}
    for key, value in first_mapping.items():
        if key == value:
            continue
        if value in first_mapping and first_mapping[value] != key:
            # Skip chained mappings that could cause conflicts.
            continue

        source, target = (value, key) if reverse else (key, value)
        if target in occupied:
            continue

        dict_map[source] = target

    return dict_map
42
-
43
def validate_molecule_components(curr_mol, node_id):
    """Report (by printing) when a molecule has more than one connected component.

    A substructure is extracted for every entry of
    ``curr_mol.connected_components``; if there is more than one, a message
    naming ``node_id`` is printed.  Nothing is returned or raised.
    """
    fragments = list(map(curr_mol.substructure, curr_mol.connected_components))
    if len(fragments) > 1:
        print(f'Error tree {node_id}: We have more than one molecule in one node')
48
-
49
def get_leaving_groups(products):
    """Collect the atom numbers of every product except the first.

    The first product is treated as the main product; the keys of ``_atoms``
    of all remaining products are returned as one flat list, in order.
    """
    return [
        atom_num
        for position, product in enumerate(products)
        if position  # position 0 is the main product
        for atom_num in product._atoms.keys()
    ]
56
-
57
def process_first_reaction(first_react, tree, node_id, min_mol_size):
    """Process first reaction in the route and initialize building block set.

    A reactant counts as a building block when it has at most
    ``min_mol_size`` atoms or its string form is in ``tree.building_blocks``.
    The atom numbers of all such reactants are accumulated, matching how
    ``update_reaction_dict`` extends the set for later reactions.  Every
    reactant is also passed to ``validate_molecule_components``.

    Args:
        first_react: Reaction whose ``reactants`` are inspected.
        tree: Object exposing ``building_blocks``.
        node_id: Identifier used in validation messages.
        min_mol_size (int): Size at or below which a reactant is a building block.

    Returns:
        set: Atom numbers belonging to building-block reactants.
    """
    bb_set = set()

    for curr_mol in first_react.reactants:
        if len(curr_mol) <= min_mol_size or str(curr_mol) in tree.building_blocks:
            # Union rather than assignment, so earlier building blocks are
            # not discarded when a later reactant also qualifies.
            bb_set |= set(curr_mol._atoms)

        validate_molecule_components(curr_mol, node_id)

    return bb_set
71
-
72
def update_reaction_dict(reaction, node_id, mapping, react_dict, tree, min_mol_size, bb_set, prev_remap=None):
    """Record, per reactant, the part of the renumbering that touches its atoms.

    For every reactant the tuple of its atom numbers becomes a key of
    ``react_dict`` (updated in place); the value holds the entries of
    ``mapping`` whose keys are atoms of that reactant, overlaid with the
    matching entries of ``prev_remap`` when one is given.  Reactants that are
    building blocks (at most ``min_mol_size`` atoms, or listed in
    ``tree.building_blocks``) have their atoms added to the returned set.

    Returns:
        tuple: ``(react_dict, bb_set)`` where ``bb_set`` is a new set whenever
        it was extended.
    """
    for reactant in reaction.reactants:
        atom_key = tuple(reactant._atoms)
        atom_numbers = set(atom_key)

        validate_molecule_components(reactant, node_id)

        is_building_block = (
            len(reactant) <= min_mol_size or str(reactant) in tree.building_blocks
        )
        if is_building_block:
            bb_set = bb_set.union(atom_numbers)

        # Keep only the renumberings that concern this reactant's atoms.
        relevant = {
            old: new for old, new in mapping.items() if old in atom_numbers
        }
        if prev_remap:
            relevant.update(
                {old: new for old, new in prev_remap.items() if old in atom_numbers}
            )
        react_dict[atom_key] = relevant

    return react_dict, bb_set
91
-
92
def process_target_blocks(curr_products, curr_prod, lg_atom_nums, curr_lg_atom_nums, bb_set):
    """Process and collect target blocks for remapping.

    Only relevant when the current step has more than one product.  For every
    product whose atom numbers differ from those of ``curr_prod``, an atom
    number is collected once if it is a leaving-group atom (listed in either
    ``lg_atom_nums`` or ``curr_lg_atom_nums``) and once more if it is in
    ``bb_set``; an atom satisfying both therefore appears twice.

    Args:
        curr_products: Products of the current step.
        curr_prod: Main product of the current step.
        lg_atom_nums: Leaving-group atom numbers of the accumulated route.
        curr_lg_atom_nums: Leaving-group atom numbers of the current step.
        bb_set: Atom numbers belonging to building blocks.

    Returns:
        list: Atom numbers that are candidates for renumbering.
    """
    target_block = []
    if len(curr_products) <= 1:
        return target_block

    # Build the lookup once instead of scanning two lists per atom.
    leaving_atoms = set(lg_atom_nums).union(curr_lg_atom_nums)
    main_atoms = curr_prod._atoms.keys()

    for prod in curr_products:
        if prod._atoms.keys() == main_atoms:
            continue
        for key in prod._atoms.keys():
            if key in leaving_atoms:
                target_block.append(key)
            if key in bb_set:
                target_block.append(key)

    return target_block
105
-
106
def process_single_route(tree, node_id, min_mol_size=6):
    """Compose all reactions of one synthesis route into a single CGR.

    The route returned by ``tree.synthesis_route(node_id)`` is walked from
    its last reaction towards its first.  At each step the atoms of the
    current reaction that would clash with leaving-group or building-block
    atoms already present in the accumulated CGR are given fresh numbers
    before the step is composed onto the accumulated CGR.

    Args:
        tree: Synthesis tree providing ``synthesis_route`` and ``building_blocks``.
        node_id: Node whose route is processed.
        min_mol_size (int): Size at or below which a reactant counts as a
            building block.

    Returns:
        dict | None: ``{'cgr': <accumulated CGR>}``, or ``None`` when any step
        raises (the error is printed).
    """
    try:
        reactions = tree.synthesis_route(node_id)

        first_react = reactions[-1]
        accum_cgr = first_react.compose()
        bb_set = process_first_reaction(first_react, tree, node_id, min_mol_size)

        # Maps a molecule's atom-number tuple to the renumbering applied to it.
        react_dict = {}
        max_num = find_next_atom_num(accum_cgr, reactions)

        # Walk the remaining reactions from the end of the list to its start.
        for reaction in reversed(reactions[:-1]):
            curr_cgr = reaction.compose()
            curr_prod = reaction.products[0]

            accum_products = accum_cgr.decompose()[1].split()
            lg_atom_nums = get_leaving_groups(accum_products)

            curr_products = curr_cgr.decompose()[1].split()

            # Re-apply the renumbering recorded earlier for this molecule.
            prev_remap = react_dict.get(tuple(curr_prod._atoms)) or {}
            if prev_remap:
                curr_cgr = curr_cgr.remap(prev_remap, copy=True)

            curr_lg_atom_nums = get_leaving_groups(curr_products)

            target_block = process_target_blocks(
                curr_products, curr_prod, lg_atom_nums, curr_lg_atom_nums, bb_set
            )

            # Hand out fresh numbers to clashing atoms, in ascending order.
            mapping = {}
            for atom_num in sorted(target_block):
                if atom_num in accum_cgr._atoms and atom_num not in mapping:
                    mapping[atom_num] = max_num
                    max_num += 1

            for accum_prod in accum_products:
                dict_map = get_clean_mapping(curr_prod, accum_prod, reverse=True)
                if dict_map:
                    curr_cgr.remap(dict_map)

            # maybe remap, then decompose and to BB
            react_dict, bb_set = update_reaction_dict(
                reaction, node_id, mapping, react_dict, tree, min_mol_size, bb_set, prev_remap
            )

            if mapping:
                curr_cgr.remap(mapping)

            accum_cgr = curr_cgr.compose(accum_cgr)

        return {
            'cgr': accum_cgr,
        }

    except Exception as e:
        print(f"Error processing node {node_id}: {e}")
        return None
176
-
177
def reassign_nums(tree, node_id=None, min_mol_size=6):
    """
    Process routes and reassign atom numbers.

    Args:
        tree: Synthesis tree
        node_id: Optional specific node ID to process. If None, processes all winning nodes
        min_mol_size: Minimum size for building blocks

    Returns:
        If node_id is None:
            dict: Dictionary mapping node IDs (sorted ascending) to their
            processed CGRs; routes that failed to process are omitted
        If node_id is specified:
            dict: ``{'cgr': ...}`` for the processed route, or None if
            processing failed
    """
    if node_id is not None:
        return process_single_route(tree, node_id, min_mol_size)

    complex_cgr_dict = {}
    # set() drops duplicate winning nodes so each route is processed once.
    for winning_node_id in set(tree.winning_nodes):
        result = process_single_route(tree, winning_node_id, min_mol_size)
        if result:
            complex_cgr_dict[winning_node_id] = result['cgr']

    return dict(sorted(complex_cgr_dict.items()))
204
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster/reduced_g_cgr.py DELETED
@@ -1,159 +0,0 @@
1
- from CGRtools.containers.bonds import DynamicBond
2
-
3
def reducing_g_cgr(g_cgr):
    """
    Build the reduced G-CGR (RG-CGR) from a generalized CGR.

    Only the first connected component of ``g_cgr`` is processed:

    1. Every bond whose ``p_order`` is None while its ``order`` is set is deleted
       (the original author treats these as leaving-group bonds).
    2. Every bond whose ``order`` and ``p_order`` are both plain ints and differ
       is replaced by ``DynamicBond(p_order, p_order)``.
    3. The first connected component of the edited graph is extracted.
    4. If ``_p_charges`` differs from ``_charges``, each atom with a non-zero
       entry in ``_charges`` gets ``charge`` reset to 0.

    Returns:
        The reduced G-CGR produced by the steps above.

    Raises:
        IndexError: If a graph has no connected components.
    """
    components = [g_cgr.substructure(component) for component in g_cgr.connected_components]
    target_cgr = components[0]

    # Both levels of the adjacency mapping are snapshotted because bonds are
    # deleted and re-added while walking over them.
    for atom1, neighbours in list(target_cgr._bonds.items()):
        for atom2, bond in list(neighbours.items()):
            if bond.p_order is None:
                # Bond exists only on the ``order`` side: drop it.
                if bond.order is not None:
                    target_cgr.delete_bond(atom1, atom2)
                continue

            changed = (
                type(bond.p_order) is int
                and type(bond.order) is int
                and bond.p_order != bond.order
            )
            if changed:
                kept_order = bond.p_order
                target_cgr.delete_bond(atom1, atom2)
                target_cgr.add_bond(atom1, atom2, DynamicBond(kept_order, kept_order))

    # Deleting bonds may have split the graph; keep the first component again.
    rg_cgr = [
        target_cgr.substructure(component)
        for component in target_cgr.connected_components
    ][0]

    if rg_cgr._p_charges != rg_cgr._charges:
        for num, charge in rg_cgr._charges.items():
            if charge != 0:
                rg_cgr._atoms[num].charge = 0

    return rg_cgr
52
-
53
-
54
def process_all_rg_cgrs(g_cgrs_dict):
    """
    Reduce every G-CGR in a mapping to its RG-CGR.

    Args:
        g_cgrs_dict: Mapping from a key to a G-CGR.

    Returns:
        A new dict with the same keys, in the same order, whose values are the
        results of ``reducing_g_cgr`` applied to each G-CGR.
    """
    return {num: reducing_g_cgr(cgr) for num, cgr in g_cgrs_dict.items()}
71
-
72
-
73
def report_strategic_bonds(result, target_cgr):
    """
    Print one line per strategic bond.

    Args:
        result: Iterable of two-item entries ``[atom_pair, bond_order]`` where
            ``atom_pair`` holds the two atom indices of the bond.
        target_cgr: Object whose ``_atoms`` mapping resolves the atom indices.

    Each line is tab-prefixed and shows the two atoms looked up in
    ``target_cgr._atoms`` followed by the bond order.
    """
    for entry in result:
        first, second = entry[0][0], entry[0][1]
        print('\t', target_cgr._atoms[first], target_cgr._atoms[second], entry[1])
87
-
88
-
89
def extract_strategic_bonds(target_cgr, report=True):
    """
    Collect the strategic bonds of a reduced G-CGR (RG-CGR).

    A bond counts as strategic when its ``order`` is None while its ``p_order``
    is set. The adjacency mapping ``target_cgr._bonds`` lists every bond from
    both of its atoms, so each bond is keyed by the sorted atom pair and only
    its first occurrence is kept.

    Args:
        target_cgr: RG-CGR whose ``_bonds`` mapping is scanned.
        report: When true, print a header and the bonds found.

    Returns:
        A list of ``[bond_key, p_order]`` pairs in discovery order, where
        ``bond_key`` is the sorted tuple of the two atom indices.
    """
    first_seen = {}
    for atom1, neighbours in target_cgr._bonds.items():
        for atom2, bond in neighbours.items():
            if bond.order is not None or bond.p_order is None:
                continue
            bond_key = tuple(sorted((atom1, atom2)))
            # setdefault keeps the order recorded on the first visit of the bond.
            first_seen.setdefault(bond_key, bond.p_order)

    result = [[bond_key, p_order] for bond_key, p_order in first_seen.items()]

    if report:
        print('Strategic bonds in RG-CGR:')
        report_strategic_bonds(result, target_cgr)
    return result
119
-
120
-
121
def compare_rg_cgr_by_strategic_bonds(rg_cgr1, rg_cgr2, report=True):
    """
    Compare two reduced G-CGRs (RG-CGRs) by their strategic bonds.

    The strategic bonds of each RG-CGR are turned into sets of
    ``(atom_pair, bond_order)`` tuples and split into three groups: bonds
    present in both, bonds only in the first, and bonds only in the second.

    Args:
        rg_cgr1: First RG-CGR.
        rg_cgr2: Second RG-CGR.
        report: When true, print the three groups via
            ``report_strategic_bonds``. Nothing is printed otherwise.

    Returns:
        None. The comparison is only reported, not returned.
    """
    bonds_1 = {
        (tuple(item[0]), item[1])
        for item in extract_strategic_bonds(rg_cgr1, report=False)
    }
    bonds_2 = {
        (tuple(item[0]), item[1])
        for item in extract_strategic_bonds(rg_cgr2, report=False)
    }

    # Sorting makes the printed order stable instead of following set order.
    common_list = [[atom_pair, order] for atom_pair, order in sorted(bonds_1 & bonds_2)]
    unique_l1_list = [[atom_pair, order] for atom_pair, order in sorted(bonds_1 - bonds_2)]
    unique_l2_list = [[atom_pair, order] for atom_pair, order in sorted(bonds_2 - bonds_1)]

    if report:
        print("Common:")
        report_strategic_bonds(common_list, rg_cgr1)
        print("Unique for first RG-CGR:")
        report_strategic_bonds(unique_l1_list, rg_cgr1)
        print("Unique for second RG-CGR:")
        # Atoms of bonds that exist only in the second RG-CGR must be looked up
        # in the second RG-CGR, not the first.
        report_strategic_bonds(unique_l2_list, rg_cgr2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster/subcluster.py DELETED
@@ -1,33 +0,0 @@
1
- from collections import defaultdict
2
-
3
def split_ids_by_length(ids, data):
    """
    Bucket IDs by the length of their entry in ``data``.

    Args:
        ids: Iterable of IDs to distribute.
        data: Mapping from ID to a sized value; IDs missing from it are skipped.

    Returns:
        A ``defaultdict(list)`` mapping ``len(data[id])`` to the IDs with that
        length, in the order they were encountered.
    """
    buckets = defaultdict(list)
    for candidate in ids:
        if candidate not in data:
            continue
        size = len(data[candidate])
        buckets[size].append(candidate)
    return buckets
10
-
11
-
12
def group_ids_by_intermediate_products(ids, reactions_dict):
    """
    Group IDs whose routes share the same sequence of first products.

    Args:
        ids: Iterable of IDs; each must be a key of ``reactions_dict``.
        reactions_dict: Mapping from ID to an iterable of reaction objects, each
            exposing an indexable ``products`` attribute.

    Returns:
        A list of ID lists. IDs land in the same list when the tuples made of
        ``reaction.products[0]`` for every reaction of their route are equal.
        Groups appear in order of first occurrence.
    """
    by_signature = defaultdict(list)
    for route_id in ids:
        signature = tuple(step.products[0] for step in reactions_dict[route_id])
        by_signature[signature].append(route_id)
    return [members for members in by_signature.values()]
21
-
22
-
23
def sublcuster_all(cluster_dict, reactions_dict):
    """
    Split every cluster into sub-clusters by route length and intermediates.

    Args:
        cluster_dict: Mapping from cluster number to the IDs in that cluster.
        reactions_dict: Mapping from ID to its reactions; passed to both
            ``split_ids_by_length`` and ``group_ids_by_intermediate_products``.

    Returns:
        A dict ``{cluster number: {length: [[ids], ...]}}``. For each cluster the
        IDs are first bucketed by ``split_ids_by_length`` and each bucket is then
        grouped by ``group_ids_by_intermediate_products``.
    """
    return {
        num: {
            steps: group_ids_by_intermediate_products(same_length_ids, reactions_dict)
            for steps, same_length_ids in split_ids_by_length(cluster, reactions_dict).items()
        }
        for num, cluster in cluster_dict.items()
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster/utils.py DELETED
@@ -1,314 +0,0 @@
1
- from synplan.mcts.tree import Tree
2
- from synplan.utils.visualisation import get_route_svg
3
- from CGRtools.containers import MoleculeContainer
4
- import pickle
5
- import os
6
-
7
def extract_reactions(tree):
    """
    Map every distinct winning node of ``tree`` to its synthesis route.

    Args:
        tree: Object exposing ``winning_nodes`` and ``synthesis_route(node_id)``.

    Returns:
        A dict ``{node_id: tree.synthesis_route(node_id)}`` with one entry per
        distinct winning node.
    """
    return {
        winner: tree.synthesis_route(winner)
        for winner in set(tree.winning_nodes)
    }
13
-
14
def extract_rules_from_route(node_id, tree):
    """
    Collect the rule IDs recorded along the route to ``node_id``.

    For every node returned by ``tree.route_to_node(node_id)`` the first entry
    of ``new_precursors`` is inspected. If that precursor is non-empty and its
    ``molecule.meta`` holds a ``'reactor_id'``, the value is collected.

    Args:
        node_id: ID of the node whose route is inspected.
        tree: Object exposing ``route_to_node(node_id)``.

    Returns:
        The collected ``'reactor_id'`` values, in reverse of the order in which
        ``route_to_node`` yields the nodes.
    """
    rule_ids = []
    for node in tree.route_to_node(node_id):
        precursor = node.new_precursors[0]
        if len(precursor) == 0:
            continue
        meta = precursor.molecule.meta
        if 'reactor_id' in meta:
            rule_ids.append(meta['reactor_id'])
    rule_ids.reverse()
    return rule_ids
23
-
24
def save_smarts(mol_id, config, reactions_dict):
    """
    Write the reactions of every route to a text file.

    The file is ``smarts/smarts_mol_{mol_id}_{config}.txt``; the ``smarts``
    directory is not created here. For each entry of ``reactions_dict`` the
    node ID is written on its own line, followed by one line per reaction.

    Args:
        mol_id: Molecule identifier used in the file name.
        config: Configuration label used in the file name.
        reactions_dict: Mapping from node ID to an iterable of reactions.
    """
    target_path = f'smarts/smarts_mol_{mol_id}_{config}.txt'
    with open(target_path, "w") as out:
        for route_node, route_reactions in reactions_dict.items():
            out.write(f"{route_node}\n")
            for step in route_reactions:
                out.write(f"{step}\n")
30
-
31
-
32
def get_highest_route_nodes(tree, node_dict):
    """
    Pick the best-scoring route nodes of every group.

    Scores come from ``tree.route_score(node_id)`` rounded to three decimals,
    so nodes whose rounded scores coincide are all kept.

    Args:
        tree: Object exposing ``route_score(node_id)``.
        node_dict: Mapping from a group key to an iterable of node IDs.

    Returns:
        A dict mapping each group key to the list of node IDs sharing the
        highest rounded score, in input order. Empty groups map to ``[]``.
    """
    best_per_group = {}
    for group_key, candidates in node_dict.items():
        leaders, top_score = [], float('-inf')
        for candidate in candidates:
            current = round(tree.route_score(candidate), 3)
            if current == top_score:
                leaders.append(candidate)
            elif current > top_score:
                leaders, top_score = [candidate], current
        best_per_group[group_key] = leaders
    return best_per_group
46
-
47
-
48
class TreeWrapper:
    """
    Pickle-friendly wrapper around a search tree.

    The wrapper stores the tree together with the molecule ID and configuration
    label it belongs to, and pickles itself to
    ``<BASE_DIR>/tree_{mol_id}_{config}.pkl``. When pickled, the tree's
    ``_tqdm``, ``policy_network`` and ``value_network`` attributes are replaced
    by simple placeholders.
    """

    # Directory (relative to the working directory) holding the pickled trees.
    BASE_DIR = 'forest'

    def __init__(self, tree, mol_id, config):
        """Store the tree and its identifiers and make sure BASE_DIR exists."""
        self.tree = tree
        self.mol_id = mol_id
        self.config = config
        # Created eagerly so that save_tree can open the target file directly.
        os.makedirs(self.BASE_DIR, exist_ok=True)
        self.filename = os.path.join(self.BASE_DIR, f'tree_{mol_id}_{config}.pkl')

    def __getstate__(self):
        """
        Return the picklable state of the wrapper.

        The tree object itself is dropped; a shallow copy of its ``__dict__``
        is stored under ``'tree_state'`` with ``_tqdm`` reset to True and the
        policy/value networks set to None (only when those attributes exist).
        The live tree is left untouched.
        """
        tree_state = dict(vars(self.tree))
        if '_tqdm' in tree_state:
            tree_state['_tqdm'] = True
        for network_attr in ('policy_network', 'value_network'):
            if network_attr in tree_state:
                tree_state[network_attr] = None

        state = {name: value for name, value in vars(self).items() if name != 'tree'}
        state['tree_state'] = tree_state
        return state

    def __setstate__(self, state):
        """Restore the wrapper and rebuild its tree without calling ``Tree.__init__``."""
        tree_state = state.pop('tree_state')
        vars(self).update(state)
        rebuilt = Tree.__new__(Tree)
        vars(rebuilt).update(tree_state)
        self.tree = rebuilt

    def save_tree(self):
        """Pickle this wrapper to ``self.filename``; failures are printed, not raised."""
        try:
            with open(self.filename, 'wb') as handle:
                pickle.dump(self, handle)
            print(f"Tree wrapper for mol_id '{self.mol_id}', config '{self.config}' saved to '{self.filename}'.")
        except Exception as e:
            print(f"Error saving tree to {self.filename}: {e}")

    @classmethod
    def load_tree_from_id(cls, mol_id, config):
        """
        Load the tree previously saved for ``mol_id`` and ``config``.

        Args:
            mol_id: The molecule ID used for saving.
            config: The configuration used for saving.

        Returns:
            The tree held by the unpickled wrapper, or None when loading fails
            for any reason (the reason is printed).
        """
        filename = os.path.join(cls.BASE_DIR, f'tree_{mol_id}_{config}.pkl')
        print(f"Attempting to load tree from: {filename}")
        try:
            # Unpickling rebuilds the tree through __setstate__, which needs Tree.
            if not ('Tree' in globals() or 'Tree' in locals()):
                raise NameError("The 'Tree' class definition is required to load the object.")

            with open(filename, 'rb') as handle:
                loaded_wrapper = pickle.load(handle)

            if not isinstance(loaded_wrapper, cls):
                print(f"Warning: Loaded object from {filename} is not a TreeWrapper instance.")
                return None

            print(f"Tree object for mol_id '{mol_id}', config '{config}' successfully loaded from '{filename}'.")
            return loaded_wrapper.tree

        except FileNotFoundError:
            print(f"Error: File not found at {filename}")
            return None
        except (pickle.UnpicklingError, EOFError) as e:
            print(f"Error: Could not unpickle file {filename}. It might be corrupted or empty. Details: {e}")
            return None
        except NameError as e:
            print(f"Error during loading: {e}. Ensure 'Tree' class is defined.")
            return None
        except Exception as e:
            print(f"An unexpected error occurred loading tree from {filename}: {e}")
            return None
133
-
134
def generate_cluster_html(
    tree: "Tree",
    cluster_node_ids: list,
    cluster_num: int,
    rg_cgrs_dict: dict,
    aam: bool = False,
) -> str:
    """
    Generates an HTML page report for a specific cluster's synthesis routes.

    Text that does not originate from an SVG renderer (target SMILES, reaction
    strings, exception messages, the cluster number) is HTML-escaped before it
    is inserted into the page.

    :param tree: The built MCTS tree.
    :param cluster_node_ids: List of route node IDs belonging to this cluster.
        IDs that are not in ``tree.nodes`` or whose node is not solved are
        ignored.
    :param cluster_num: The identifier number for this cluster (used in
        title/header).
    :param rg_cgrs_dict: Mapping from route node ID to its RG-CGR. The RG-CGR
        of the first valid route is depicted as the cluster representative;
        note that ``clean2d()`` is called on that object.
    :param aam: If True, depict atom-to-atom mapping in route SVGs.
    :return: A string containing the complete HTML report. For invalid input a
        minimal HTML error page is returned instead of raising.
    """
    # Local import keeps the block self-contained; html is standard library.
    from html import escape

    def _text(value):
        """Render ``value`` as HTML-safe text."""
        return escape(str(value), quote=False)

    # --- Depict Settings ---
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except NameError:
        # If MoleculeContainer isn't available/needed, just pass
        pass
    except Exception as e:
        print(f"Warning: Error setting MoleculeContainer depict settings: {e}")

    # --- Validate Input ---
    if not isinstance(cluster_node_ids, list):
        return "<html><body>Error: cluster_node_ids must be a list.</body></html>"
    if not tree or not isinstance(tree, Tree):
        return "<html><body>Error: Invalid tree object provided.</body></html>"

    cluster_label = _text(cluster_num)

    # Keep only node IDs that are present and solved in the tree.
    valid_routes_in_cluster = [
        node_id
        for node_id in cluster_node_ids
        if node_id in tree.nodes and tree.nodes[node_id].is_solved()
    ]

    if not valid_routes_in_cluster:
        # Minimal page indicating that there is nothing to report.
        return f"""
    <!doctype html><html lang="en"><head><meta charset="utf-8">
    <title>Cluster {cluster_label} Report</title></head><body>
    <h3>Cluster {cluster_label} Report</h3>
    <p>No valid/solved routes found for this cluster.</p>
    </body></html>"""

    # --- HTML Templates & Tags ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_head = "<font style='font-weight: bold; font-size: 18px'>"
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = f"""
    <!doctype html>
    <html lang="en">
    <head>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
    rel="stylesheet"
    integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
    crossorigin="anonymous">
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Cluster {cluster_label} Routes Report</title>
    <style>
    /* Optional: Add some basic styling */
    .table {{ border-collapse: collapse; width: 100%; }}
    th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
    tr:nth-child(even) {{ background-color: #f2f2f2; }}
    caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
    svg {{ max-width: 100%; height: auto; }} /* Make SVGs responsive */
    </style>
    </head>
    <body>
    <div class="container"> """

    template_end = """
    </div> <script
    src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
    integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
    crossorigin="anonymous">
    </script>
    </body>
    </html>
    """

    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
    <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """

    # --- Build HTML Table ---
    parts = [
        f"""
    <table class="table table-striped table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report - Cluster {cluster_label}</h3></caption>
    <tbody>"""
    ]

    try:
        target_smiles_str = str(tree.nodes[1].curr_precursor) if 1 in tree.nodes else "N/A"
    except Exception:
        target_smiles_str = "Error retrieving target SMILES"
    parts.append(f"<tr>{td}{font_normal}Target Molecule: {_text(target_smiles_str)}{font_close}</td></tr>")
    parts.append(f"<tr>{td}{font_normal}Cluster Number: {cluster_label}{font_close}</td></tr>")
    parts.append(
        f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes_in_cluster)}{font_close} routes</td></tr>"
    )

    # --- Add RG-CGR Image ---
    # The first valid route of the cluster serves as its representative.
    first_route_id = valid_routes_in_cluster[0]
    rg_header = (
        f"{font_normal}Cluster Representative RG-CGR "
        f"(from Route {_text(first_route_id)}):{font_close}<br>"
    )
    if rg_cgrs_dict and first_route_id in rg_cgrs_dict:
        try:
            rg_cgr = rg_cgrs_dict[first_route_id]
            rg_cgr.clean2d()
            rg_cgr_svg = rg_cgr.depict()

            # Basic check that the depiction really is SVG markup.
            if rg_cgr_svg.strip().startswith("<svg"):
                rg_cell = rg_cgr_svg
            else:
                rg_cell = "<i>Invalid SVG format retrieved.</i>"
                print(f"Warning: Expected SVG for RG-CGR of node {first_route_id}, but got: {rg_cgr_svg[:100]}...")
        except Exception as e:
            rg_cell = f"<i>Error retrieving/displaying RG-CGR: {_text(e)}</i>"
    else:
        rg_cell = "<i>Not found in provided RG-CGR dictionary.</i>"
    parts.append(f"<tr>{td}{rg_header}{rg_cell}</td></tr>")

    # --- Legend ---
    parts.append(
        f"""
    <tr>{td}
    <div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
    <span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
    <span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
    <span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
    </div>
    </td></tr>
    """
    )

    # --- Add Routes for this Cluster ---
    for route_id in valid_routes_in_cluster:
        try:
            svg = get_route_svg(tree, route_id)
            full_route = tree.synthesis_route(route_id)
            reactions = "".join(
                f"<b>Step {step_no}:</b> {_text(synth_step)}<br>"
                for step_no, synth_step in enumerate(full_route, start=1)
            )
            route_score = round(tree.route_score(route_id), 3)

            route_rows = (
                f'<tr style="line-height: 1.8;">{td}{font_head}Route {route_id} | '
                f"Steps: {len(full_route)} | "
                f"Score: {route_score}{font_close}</td></tr>"
                f"<tr>{td}{svg if svg else '<i>Error generating route visualization</i>'}</td></tr>"
                f"<tr>{td}{reactions if reactions else '<i>No reaction steps found</i>'}</td></tr>"
            )
        except Exception as e:
            # A failing route is reported in place; the other routes still render.
            route_rows = (
                f'<tr><td colspan="1" style="color: red;">'
                f"Error processing route {_text(route_id)}: {_text(e)}</td></tr>"
            )
        parts.append(route_rows)

    parts.append("</tbody></table>")

    # --- Combine and Return Full HTML ---
    return template_begin + "".join(parts) + template_end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster/visualize.py DELETED
@@ -1,481 +0,0 @@
1
- import os
2
- import re
3
- from collections import Counter
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- import seaborn as sns
7
- from IPython.display import SVG, display
8
-
9
- import io
10
- import sys
11
-
12
- from synplan.utils.visualisation import get_route_svg
13
- from scipy.cluster.hierarchy import dendrogram
14
-
15
- from .reduced_g_cgr import extract_strategic_bonds, compare_rg_cgr_by_strategic_bonds
16
-
17
- def report_2_dissimilar(similarity_df, tree, rg_cgrs_dict):
18
- min_index = similarity_df.stack().idxmin()
19
- row_index, col_index = min_index
20
-
21
- print(f'Most dissimilar routes are {row_index} and {col_index}, Tanimoto index = {"%.2f" % similarity_df[row_index][col_index]}')
22
-
23
- print('Route ID', row_index)
24
- rg_cgr_1 = rg_cgrs_dict[row_index]
25
- rg_cgr_1.clean2d()
26
- display(SVG(rg_cgr_1.depict()))
27
- extract_strategic_bonds(rg_cgr_1)
28
- display(SVG(get_route_svg(tree, row_index)))
29
-
30
-
31
- print('Route ID', col_index)
32
- rg_cgr_2 = rg_cgrs_dict[col_index]
33
- rg_cgr_2.clean2d()
34
- display(SVG(rg_cgr_2.depict()))
35
- extract_strategic_bonds(rg_cgr_2)
36
- display(SVG(get_route_svg(tree, col_index)))
37
-
38
- print('Summary:')
39
- compare_rg_cgr_by_strategic_bonds(rg_cgr_1, rg_cgr_2)
40
-
41
- def save_clusters_html(clusters, best_by_score, tree, rg_cgrs_dict, mol_id, config):
42
- # Prepare a list to accumulate HTML parts for each cluster
43
- os.makedirs("./final_clusters", exist_ok=True)
44
- html_parts = []
45
-
46
- # Loop over your clusters
47
- for cluster_num, node_id_list in clusters.items():
48
- parts = [] # to accumulate parts for this cluster
49
-
50
- # Generate text output
51
- best_route_in_cluster = best_by_score[cluster_num][0]
52
- score = round(tree.route_score(best_route_in_cluster), 3)
53
- parts.append(f"{cluster_num} ||| Size: {len(clusters[cluster_num])}\n")
54
- parts.append(f"Example: {best_route_in_cluster} Route score: {score}\n")
55
-
56
- # Insert the first SVG immediately after its marker text
57
- svg1 = get_route_svg(tree, best_route_in_cluster)
58
- parts.append(svg1 + "\n")
59
-
60
- # Continue with additional text and SVGs
61
- parts.append("The RG-CGR:\n")
62
- rg_cgr = rg_cgrs_dict[best_route_in_cluster]
63
- rg_cgr.clean2d()
64
- svg2 = rg_cgr.depict()
65
- parts.append(svg2 + "\n")
66
-
67
- # Capture output from extract_strategic_bonds, if it prints something
68
- buf = io.StringIO()
69
- old_stdout = sys.stdout
70
- sys.stdout = buf
71
- extract_strategic_bonds(rg_cgr)
72
- sys.stdout = old_stdout
73
- strategic_text = buf.getvalue()
74
- parts.append(strategic_text + "\n")
75
-
76
- # Wrap this cluster's output in a <pre> tag for formatting and add some spacing
77
- cluster_html = f'<div class="cluster" style="margin-bottom: 2em;"><pre>{"".join(parts)}</pre></div>'
78
- html_parts.append(cluster_html)
79
-
80
- # Combine all parts into a full HTML document
81
- html_content = f"""
82
- <html>
83
- <head>
84
- <meta charset="utf-8">
85
- <title>Captured Cluster Outputs</title>
86
- </head>
87
- <body>
88
- {''.join(html_parts)}
89
- </body>
90
- </html>
91
- """
92
-
93
-
94
- # Write the HTML content to a file
95
- with open(f"final_clusters/htmls/mol_{mol_id}_{config}.html", "w", encoding="utf-8") as f:
96
- f.write(html_content)
97
-
98
- def report_2_dissimilar_to_html(similarity_df, tree, rg_cgrs_dict, mol_id=1, config=2,output_filename=None):
99
- """Generates an HTML report of the two most dissimilar routes based on a similarity DataFrame."""
100
-
101
- os.makedirs("./dissimilars", exist_ok=True)
102
- output_filename=f"dissimilars/report_dissimilar_mol_{mol_id}_{config}.html"
103
- # Identify the two most dissimilar routes
104
- min_index = similarity_df.stack().idxmin()
105
- row_index, col_index = min_index
106
-
107
- # Capture text output in a buffer
108
- buf = io.StringIO()
109
- old_stdout = sys.stdout
110
- sys.stdout = buf
111
-
112
- print(f'Most dissimilar routes are {row_index} and {col_index}, Tanimoto index = {"%.2f" % similarity_df[row_index][col_index]}')
113
-
114
- # Store HTML content
115
- html_parts = []
116
-
117
- # Function to capture and append text, SVGs, and function outputs
118
- def capture_route_info(route_id):
119
- rg_cgr = rg_cgrs_dict[route_id]
120
- rg_cgr.clean2d()
121
-
122
- # Capture the first SVG (RG-CGR depiction)
123
- svg1 = rg_cgr.depict()
124
-
125
- # Capture the second SVG (Route depiction)
126
- svg2 = get_route_svg(tree, route_id)
127
-
128
- # Capture output of extract_strategic_bonds
129
- buf_extract = io.StringIO()
130
- sys.stdout = buf_extract
131
- extract_strategic_bonds(rg_cgr)
132
- sys.stdout = old_stdout
133
- extract_output = buf_extract.getvalue()
134
-
135
- # Store text + SVGs in HTML format
136
- html_parts.append(f"""
137
- <div class="route-section">
138
- <pre>{buf.getvalue()}</pre>
139
- <div class="svg1">{svg1}</div>
140
- <pre>{extract_output}</pre>
141
- <div class="svg2">{svg2}</div>
142
- </div>
143
- """)
144
- buf.truncate(0) # Clear buffer for next route
145
- buf.seek(0)
146
-
147
- # Process the first route
148
- capture_route_info(row_index)
149
-
150
- # Process the second route
151
- capture_route_info(col_index)
152
-
153
- # Capture and store final summary
154
- buf_summary = io.StringIO()
155
- sys.stdout = buf_summary
156
- compare_rg_cgr_by_strategic_bonds(rg_cgrs_dict[row_index], rg_cgrs_dict[col_index])
157
- sys.stdout = old_stdout
158
- summary_output = buf_summary.getvalue()
159
- html_parts.append(f"<h2>Summary</h2><pre>{summary_output}</pre>")
160
-
161
- # Restore standard stdout
162
- sys.stdout = old_stdout
163
-
164
- # Build the full HTML file
165
- html_content = f"""
166
- <html>
167
- <head>
168
- <meta charset="utf-8">
169
- <title>Route Dissimilarity Report</title>
170
- </head>
171
- <body>
172
- {''.join(html_parts)}
173
- </body>
174
- </html>
175
- """
176
-
177
- # Write the HTML file
178
- with open(output_filename, "w", encoding="utf-8") as f:
179
- f.write(html_content)
180
-
181
- print(f"Report saved as {output_filename}")
182
-
183
- def pie_chart(cluster_sizes, sub='', input_cluster_num=1, input_step_nums=None):
184
- labels = [f'{sub}Cluster {i+1}' for i in range(len(cluster_sizes))]
185
-
186
- sns.set_style("whitegrid")
187
-
188
- fig, ax = plt.subplots(figsize=(6, 6))
189
- wedges, texts, autotexts = ax.pie(
190
- cluster_sizes, labels=None, autopct='%1.1f%%', colors=sns.color_palette("pastel"),
191
- startangle=140, wedgeprops={'edgecolor': 'black'}
192
- )
193
- ax.legend(wedges, labels, title=f"{sub}Clusters", loc="center left", bbox_to_anchor=(1, 0.5))
194
- if sub == '':
195
- ax.set_title(f"{sub}Cluster Size Distribution for {sum(cluster_sizes)} routes")
196
- else:
197
- ax.set_title(f"{sub}cluster Size Distribution for {sum(cluster_sizes)} routes in cluster {input_cluster_num} with number of steps {input_step_nums}")
198
-
199
- plt.close(fig)
200
-
201
- # plt.show()
202
- return fig
203
-
204
- def save_dendrogram(df, Z, mol_id, config):
205
- plt.figure(figsize=(14, 7)) # figsize=(14, 7)
206
-
207
- dendrogram(Z, labels=df.columns, leaf_rotation=90)
208
- plt.title(f"Hierarchical Clustering Dendrogram for routes generated for molecule #{mol_id}")
209
- plt.xlabel("Route node id")
210
- plt.ylabel("Distance (1 - Similarity)")
211
-
212
- # Get current y-axis limits and add a gap below zero
213
- ax = plt.gca()
214
- ymin, ymax = ax.get_ylim()
215
- # Add a gap that is 5% of the current y-range below zero
216
- gap = 0.05 * (ymax - ymin)
217
- ax.set_ylim(ymin - gap, ymax)
218
- ax.grid(False)
219
- ax.autoscale(enable=None, axis="x", tight=True)
220
-
221
- plt.tight_layout()
222
- plt.savefig(f'dendrograms/av_link_mol{mol_id}_{config}.png', dpi=100)
223
-
224
-
225
- def distribution_by_depth(tree, complex_cgr_dict):
226
- if len(complex_cgr_dict) == 0:
227
- print('Error: Empty dictionary')
228
- return None
229
- depths = np.zeros(len(complex_cgr_dict))
230
- for n, node in enumerate(complex_cgr_dict.keys()):
231
- reactions = tree.synthesis_route(node)
232
- depths[n] = len(reactions)
233
- return depths
234
-
235
- def histogram_by_depth(depths, mol_id=1, config=1, save=False):
236
-
237
- if len(depths) == 0:
238
- print('Error no depths')
239
- return None
240
- # Count frequency of each depth
241
- counter = Counter(depths)
242
- bins, counts = zip(*sorted(counter.items()))
243
-
244
- # Plot the histogram
245
- plt.bar(bins, counts, width=0.5, color='skyblue', edgecolor='black')
246
- plt.xlabel('Number of reactions')
247
- plt.ylabel('Frequency')
248
- plt.title(f'Frequency Histogram of Number of reactions in one tree of total {len(depths)}')
249
- plt.xticks(bins)
250
- if save:
251
- plt.savefig(f'histograms/by_depth_mol{mol_id}_{config}.png', dpi=100)
252
- else:
253
- plt.show()
254
-
255
-
256
- def group_routes_by_depth(depths):
257
- """
258
- Group route IDs by their reaction count (depth).
259
-
260
- Args:
261
- depths: Dictionary with node_ids as keys and reaction tuples as values
262
-
263
- Returns:
264
- dict: Dictionary with depths as keys and lists of node_ids as values
265
- """
266
- depth_groups = {}
267
- for node_id, reactions in depths.items():
268
- depth = len(reactions)
269
- if depth not in depth_groups:
270
- depth_groups[depth] = []
271
- depth_groups[depth].append(node_id)
272
- return depth_groups
273
-
274
- def create_route_svg(tree, node_ids, mol_id, config, depths, depth=None):
275
- """Create SVG file for specified routes with optimized spacing."""
276
-
277
- # First pass: analyze all SVGs to find maximum width
278
- max_width_cm = 0
279
- all_route_svgs = [] # Store SVGs to avoid calling get_route_svg twice
280
-
281
- for g in node_ids:
282
- route_svg = get_route_svg(tree, g)
283
- all_route_svgs.append(route_svg)
284
-
285
- # Extract the actual SVG content
286
- svg_match = re.search(r'<svg[^>]*>', route_svg)
287
- if svg_match:
288
- svg_header = svg_match.group(0)
289
-
290
- # Try to get width from cm attribute
291
- width_match = re.search(r'width="([0-9.]+)cm"', svg_header)
292
- if width_match:
293
- try:
294
- width_cm = float(width_match.group(1))
295
- max_width_cm = max(max_width_cm, width_cm)
296
- except ValueError:
297
- pass
298
-
299
- # Convert cm to pixels (1cm ≈ 37.8 pixels)
300
- CM_TO_PX = 37.8
301
- max_width_px = max_width_cm * CM_TO_PX
302
-
303
- # Add margins
304
- left_margin = 50
305
- right_margin = 100
306
- composite_width = max_width_px + left_margin + right_margin
307
-
308
- # Continue with SVG creation using calculated width
309
- vertical_spacing = 20
310
- text_height = 20
311
- route_spacing = 250
312
- current_y = 30
313
- entries = []
314
-
315
- size = len(node_ids)
316
-
317
- for num, (g, route_svg_str) in enumerate(zip(node_ids, all_route_svgs), 1):
318
- # Calculate dimensions
319
- route_px_height = 200
320
-
321
- # Create entry with optimized spacing
322
- entry_parts = []
323
- entry_parts.append(f'<g transform="translate({left_margin}, {current_y})">')
324
- entry_parts.append(f' <text x="0" y="{text_height}" font-size="12" fill="black">{num} (Node ID: {g}, Number of reactions: {len(depths[g])})</text>')
325
-
326
- inner_y = text_height + 25
327
- entry_parts.append(f' <g transform="translate(0, {inner_y})">{route_svg_str}</g>')
328
-
329
- total_entry_height = inner_y + route_px_height + 250
330
- entry_parts.append('</g>')
331
-
332
- entry_block = "\n".join(entry_parts)
333
- entry_bottom_y = current_y + total_entry_height
334
- entries.append((entry_block, entry_bottom_y))
335
-
336
- current_y = entry_bottom_y + route_spacing - 50
337
-
338
- # Create master SVG with adjusted dimensions
339
- master_width = composite_width
340
- master_height = current_y + vertical_spacing
341
-
342
- final_parts = []
343
- for entry_block, bottom_y in entries:
344
- final_parts.append(entry_block)
345
- final_parts.append(f'<line x1="0" y1="{bottom_y}" x2="{master_width}" y2="{bottom_y}" stroke="black" stroke-width="1" />')
346
-
347
- master_svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="{master_width}" height="{master_height}" viewBox="0 0 {master_width} {master_height}">\n'
348
- master_svg += "\n".join(final_parts)
349
- master_svg += "\n</svg>"
350
-
351
- # Save file with appropriate name
352
- if depth is None:
353
- path_name = f"./routes_img/mol_{mol_id}/mol{mol_id}_{config}_all_{size}.svg"
354
- else:
355
- path_name = f"./routes_img/mol_{mol_id}/mol{mol_id}_{config}_depth_{depth}_{size}.svg"
356
-
357
- with open(path_name, "w") as f:
358
- f.write(master_svg)
359
-
360
- print(f"Saved: {path_name}")
361
-
362
-
363
- def create_route_svg_cluster(tree, node_ids, mol_id, config, depths, cluster_num):
364
- """
365
- Create SVG file for specified routes with optimized spacing, grouped by cluster.
366
- """
367
- # First pass: analyze all SVGs to find maximum width
368
- max_width_cm = 0
369
- all_route_svgs = [] # Store SVGs to avoid calling get_route_svg twice
370
-
371
- for g in node_ids:
372
- route_svg = get_route_svg(tree, g)
373
- all_route_svgs.append(route_svg)
374
-
375
- # Extract the actual SVG content
376
- svg_match = re.search(r'<svg[^>]*>', route_svg)
377
- if svg_match:
378
- svg_header = svg_match.group(0)
379
-
380
- # Try to get width from cm attribute
381
- width_match = re.search(r'width="([0-9.]+)cm"', svg_header)
382
- if width_match:
383
- try:
384
- width_cm = float(width_match.group(1))
385
- max_width_cm = max(max_width_cm, width_cm)
386
- except ValueError:
387
- pass
388
-
389
- # Convert cm to pixels (1cm ≈ 37.8 pixels)
390
- CM_TO_PX = 37.8
391
- max_width_px = max_width_cm * CM_TO_PX
392
-
393
- # Add margins
394
- left_margin = 50
395
- right_margin = 100
396
- composite_width = max_width_px + left_margin + right_margin
397
-
398
- # Continue with SVG creation using calculated width
399
- vertical_spacing = 20
400
- text_height = 20
401
- route_spacing = 250
402
- current_y = 30
403
- entries = []
404
-
405
- size = len(node_ids)
406
-
407
- for num, (g, route_svg_str) in enumerate(zip(node_ids, all_route_svgs), 1):
408
- # Calculate dimensions
409
- route_px_height = 200
410
-
411
- # Create entry with optimized spacing
412
- entry_parts = []
413
- entry_parts.append(f'<g transform="translate({left_margin}, {current_y})">')
414
- entry_parts.append(f' <text x="0" y="{text_height}" font-size="12" fill="black">{num} (Node ID: {g}, Number of reactions: {len(depths[g])})</text>')
415
-
416
- inner_y = text_height + 25
417
- entry_parts.append(f' <g transform="translate(0, {inner_y})">{route_svg_str}</g>')
418
-
419
- total_entry_height = inner_y + route_px_height + 350
420
- entry_parts.append('</g>')
421
-
422
- entry_block = "\n".join(entry_parts)
423
- entry_bottom_y = current_y + total_entry_height
424
- entries.append((entry_block, entry_bottom_y))
425
-
426
- current_y = entry_bottom_y + route_spacing - 50
427
-
428
- # Create master SVG with adjusted dimensions
429
- master_width = composite_width
430
- master_height = current_y + vertical_spacing
431
-
432
- final_parts = []
433
- for entry_block, bottom_y in entries:
434
- final_parts.append(entry_block)
435
- final_parts.append(f'<line x1="0" y1="{bottom_y}" x2="{master_width}" y2="{bottom_y}" stroke="black" stroke-width="1" />')
436
-
437
- master_svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="{master_width}" height="{master_height}" viewBox="0 0 {master_width} {master_height}">\n'
438
- master_svg += "\n".join(final_parts)
439
- master_svg += "\n</svg>"
440
-
441
- # Save file with cluster-specific name
442
- path_name = f"./routes_img/mol_{mol_id}/mol{mol_id}_{config}_cluster_{cluster_num}_{size}.svg"
443
-
444
- with open(path_name, "w") as f:
445
- f.write(master_svg)
446
-
447
- print(f"Saved: {path_name}")
448
-
449
- def save_route_images(tree, depths, mol_id, config, cluster_dict=None):
450
- """
451
- Save route images grouped by depth and/or cluster.
452
-
453
- Args:
454
- tree: Synthesis tree
455
- routes: Dictionary of routes
456
- depths: Dictionary of reaction depths
457
- mol_id: Molecule ID
458
- config: Configuration value
459
- cluster_dict: Optional dictionary mapping cluster numbers to lists of node_ids
460
- """
461
- # Create directory if it doesn't exist
462
- os.makedirs("./routes_img", exist_ok=True)
463
- os.makedirs(f"./routes_img/mol_{mol_id}", exist_ok=True)
464
-
465
- # Save complete image with all routes
466
- all_node_ids = sorted(depths.keys())
467
- create_route_svg(tree, all_node_ids, mol_id, config, depths)
468
-
469
- # Group routes by depth and save separate images
470
- depth_groups = group_routes_by_depth(depths)
471
- for depth, node_ids in depth_groups.items():
472
- create_route_svg(tree, sorted(node_ids), mol_id, config, depths, depth)
473
-
474
- # If cluster dictionary is provided, save routes grouped by cluster
475
- if cluster_dict is not None:
476
- for cluster_num, node_ids in cluster_dict.items():
477
- # Filter node_ids to only include those that exist in routes
478
- valid_node_ids = [nid for nid in node_ids if nid in depths]
479
- if valid_node_ids:
480
- create_route_svg_cluster(tree, sorted(valid_node_ids),
481
- mol_id, config, depths, cluster_num)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
synplan/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .mcts import *
2
+
3
+ __all__ = ["Tree"]
synplan/chem/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from CGRtools.files import SMILESRead
2
+
3
+ smiles_parser = SMILESRead.create_parser(ignore=True)
{cluster → synplan/chem/data}/__init__.py RENAMED
File without changes
synplan/chem/data/filtering.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions for reactions filtering."""
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from io import TextIOWrapper
6
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ import ray
10
+ import yaml
11
+ from CGRtools.containers import CGRContainer, MoleculeContainer, ReactionContainer
12
+ from chython.algorithms.fingerprints.morgan import MorganFingerprint
13
+ from tqdm import tqdm
14
+
15
+ from synplan.chem.data.standardizing import (
16
+ AromaticFormStandardizer,
17
+ KekuleFormStandardizer,
18
+ RemoveReagentsStandardizer,
19
+ )
20
+ from synplan.chem.utils import cgrtools_to_chython_molecule
21
+ from synplan.utils.config import ConfigABC, convert_config_to_dict
22
+ from synplan.utils.files import ReactionReader, ReactionWriter
23
+
24
+
25
+ @dataclass
26
+ class CompeteProductsConfig(ConfigABC):
27
+ fingerprint_tanimoto_threshold: float = 0.3
28
+ mcs_tanimoto_threshold: float = 0.6
29
+
30
+ @staticmethod
31
+ def from_dict(config_dict: Dict[str, Any]) -> "CompeteProductsConfig":
32
+ """Create an instance of CompeteProductsConfig from a dictionary."""
33
+ return CompeteProductsConfig(**config_dict)
34
+
35
+ @staticmethod
36
+ def from_yaml(file_path: str) -> "CompeteProductsConfig":
37
+ """Deserialize a YAML file into a CompeteProductsConfig object."""
38
+ with open(file_path, "r", encoding="utf-8") as file:
39
+ config_dict = yaml.safe_load(file)
40
+ return CompeteProductsConfig.from_dict(config_dict)
41
+
42
+ def _validate_params(self, params: Dict[str, Any]) -> None:
43
+ """Validate configuration parameters."""
44
+ if not isinstance(params.get("fingerprint_tanimoto_threshold"), float) or not (
45
+ 0 <= params["fingerprint_tanimoto_threshold"] <= 1
46
+ ):
47
+ raise ValueError(
48
+ "Invalid 'fingerprint_tanimoto_threshold'; expected a float between 0 and 1"
49
+ )
50
+
51
+ if not isinstance(params.get("mcs_tanimoto_threshold"), float) or not (
52
+ 0 <= params["mcs_tanimoto_threshold"] <= 1
53
+ ):
54
+ raise ValueError(
55
+ "Invalid 'mcs_tanimoto_threshold'; expected a float between 0 and 1"
56
+ )
57
+
58
+
59
+ class CompeteProductsFilter:
60
+ """Checks if there are compete reactions."""
61
+
62
+ def __init__(
63
+ self,
64
+ fingerprint_tanimoto_threshold: float = 0.3,
65
+ mcs_tanimoto_threshold: float = 0.6,
66
+ ):
67
+ self.fingerprint_tanimoto_threshold = fingerprint_tanimoto_threshold
68
+ self.mcs_tanimoto_threshold = mcs_tanimoto_threshold
69
+
70
+ @staticmethod
71
+ def from_config(config: CompeteProductsConfig) -> "CompeteProductsFilter":
72
+ """Creates an instance of CompeteProductsFilter from a configuration object."""
73
+ return CompeteProductsFilter(
74
+ config.fingerprint_tanimoto_threshold, config.mcs_tanimoto_threshold
75
+ )
76
+
77
+ def __call__(self, reaction: ReactionContainer) -> bool:
78
+ """Checks if the reaction has competing products.
79
+
80
+ :param reaction: Input reaction.
81
+ :return: Returns True if the reaction has competing products, else False.
82
+ """
83
+ mf = MorganFingerprint()
84
+ is_compete = False
85
+
86
+ # check for compete products using both fingerprint similarity and maximum common substructure (MCS) similarity
87
+ for mol in reaction.reagents:
88
+ for other_mol in reaction.products:
89
+ if len(mol) > 6 and len(other_mol) > 6:
90
+ # compute fingerprint similarity
91
+ molf = mf.transform([cgrtools_to_chython_molecule(mol)])
92
+ other_molf = mf.transform([cgrtools_to_chython_molecule(other_mol)])
93
+ fingerprint_tanimoto = tanimoto_kernel(molf, other_molf)[0][0]
94
+
95
+ # if fingerprint similarity is high enough, check for MCS similarity
96
+ if fingerprint_tanimoto > self.fingerprint_tanimoto_threshold:
97
+ try:
98
+ # find the maximum common substructure (MCS) and compute its size
99
+ clique_size = len(
100
+ next(mol.get_mcs_mapping(other_mol, limit=100))
101
+ )
102
+
103
+ # calculate MCS similarity based on MCS size
104
+ mcs_tanimoto = clique_size / (
105
+ len(mol) + len(other_mol) - clique_size
106
+ )
107
+
108
+ # if MCS similarity is also high enough, mark the reaction as having compete products
109
+ if mcs_tanimoto > self.mcs_tanimoto_threshold:
110
+ is_compete = True
111
+ break
112
+ except StopIteration:
113
+ continue
114
+
115
+ return is_compete
116
+
117
+
118
+ @dataclass
119
+ class DynamicBondsConfig(ConfigABC):
120
+ min_bonds_number: int = 1
121
+ max_bonds_number: int = 6
122
+
123
+ @staticmethod
124
+ def from_dict(config_dict: Dict[str, Any]) -> "DynamicBondsConfig":
125
+ """Create an instance of DynamicBondsConfig from a dictionary."""
126
+ return DynamicBondsConfig(**config_dict)
127
+
128
+ @staticmethod
129
+ def from_yaml(file_path: str) -> "DynamicBondsConfig":
130
+ """Deserialize a YAML file into a DynamicBondsConfig object."""
131
+ with open(file_path, "r") as file:
132
+ config_dict = yaml.safe_load(file)
133
+ return DynamicBondsConfig.from_dict(config_dict)
134
+
135
+ def _validate_params(self, params: Dict[str, Any]) -> None:
136
+ """Validate configuration parameters."""
137
+ if (
138
+ not isinstance(params.get("min_bonds_number"), int)
139
+ or params["min_bonds_number"] < 0
140
+ ):
141
+ raise ValueError(
142
+ "Invalid 'min_bonds_number'; expected a non-negative integer"
143
+ )
144
+
145
+ if (
146
+ not isinstance(params.get("max_bonds_number"), int)
147
+ or params["max_bonds_number"] < 0
148
+ ):
149
+ raise ValueError(
150
+ "Invalid 'max_bonds_number'; expected a non-negative integer"
151
+ )
152
+
153
+ if params["min_bonds_number"] > params["max_bonds_number"]:
154
+ raise ValueError(
155
+ "'min_bonds_number' cannot be greater than 'max_bonds_number'"
156
+ )
157
+
158
+
159
+ class DynamicBondsFilter:
160
+ """Checks if there is an unacceptable number of dynamic bonds in CGR."""
161
+
162
+ def __init__(self, min_bonds_number: int = 1, max_bonds_number: int = 6):
163
+ self.min_bonds_number = min_bonds_number
164
+ self.max_bonds_number = max_bonds_number
165
+
166
+ @staticmethod
167
+ def from_config(config: DynamicBondsConfig):
168
+ """Creates an instance of DynamicBondsFilter from a configuration object."""
169
+ return DynamicBondsFilter(config.min_bonds_number, config.max_bonds_number)
170
+
171
+ def __call__(self, reaction: ReactionContainer) -> bool:
172
+ cgr = ~reaction
173
+ return not (
174
+ self.min_bonds_number <= len(cgr.center_bonds) <= self.max_bonds_number
175
+ )
176
+
177
+
178
+ @dataclass
179
+ class SmallMoleculesConfig(ConfigABC):
180
+ mol_max_size: int = 6
181
+
182
+ @staticmethod
183
+ def from_dict(config_dict: Dict[str, Any]) -> "SmallMoleculesConfig":
184
+ """Creates an instance of SmallMoleculesConfig from a dictionary."""
185
+ return SmallMoleculesConfig(**config_dict)
186
+
187
+ @staticmethod
188
+ def from_yaml(file_path: str) -> "SmallMoleculesConfig":
189
+ """Deserialize a YAML file into a SmallMoleculesConfig object."""
190
+ with open(file_path, "r") as file:
191
+ config_dict = yaml.safe_load(file)
192
+ return SmallMoleculesConfig.from_dict(config_dict)
193
+
194
+ def _validate_params(self, params: Dict[str, Any]) -> None:
195
+ """Validate configuration parameters."""
196
+ if (
197
+ not isinstance(params.get("mol_max_size"), int)
198
+ or params["mol_max_size"] < 1
199
+ ):
200
+ raise ValueError("Invalid 'mol_max_size'; expected a positive integer")
201
+
202
+
203
+ class SmallMoleculesFilter:
204
+ """Checks if there are only small molecules in the reaction or if there is only one
205
+ small reactant or product."""
206
+
207
+ def __init__(self, mol_max_size: int = 6):
208
+ self.limit = mol_max_size
209
+
210
+ @staticmethod
211
+ def from_config(config: SmallMoleculesConfig) -> "SmallMoleculesFilter":
212
+ """Creates an instance of SmallMoleculesFilter from a configuration object."""
213
+ return SmallMoleculesFilter(config.mol_max_size)
214
+
215
+ def __call__(self, reaction: ReactionContainer) -> bool:
216
+ if (
217
+ (
218
+ len(reaction.reactants) == 1
219
+ and self.are_only_small_molecules(reaction.reactants)
220
+ )
221
+ or (
222
+ len(reaction.products) == 1
223
+ and self.are_only_small_molecules(reaction.products)
224
+ )
225
+ or (
226
+ self.are_only_small_molecules(reaction.reactants)
227
+ and self.are_only_small_molecules(reaction.products)
228
+ )
229
+ ):
230
+ return True
231
+ return False
232
+
233
+ def are_only_small_molecules(self, molecules: Iterable[MoleculeContainer]) -> bool:
234
+ """Checks if all molecules in the given iterable are small molecules."""
235
+ return all(len(molecule) <= self.limit for molecule in molecules)
236
+
237
+
238
+ @dataclass
239
+ class CGRConnectedComponentsConfig:
240
+ pass
241
+
242
+
243
+ class CGRConnectedComponentsFilter:
244
+ """Checks if CGR contains unrelated components (without reagents)."""
245
+
246
+ @staticmethod
247
+ def from_config(
248
+ config: CGRConnectedComponentsConfig,
249
+ ) -> "CGRConnectedComponentsFilter":
250
+ """Creates an instance of CGRConnectedComponentsFilter from a configuration
251
+ object."""
252
+ return CGRConnectedComponentsFilter()
253
+
254
+ def __call__(self, reaction: ReactionContainer) -> bool:
255
+ tmp_reaction = ReactionContainer(reaction.reactants, reaction.products)
256
+ cgr = ~tmp_reaction
257
+ return cgr.connected_components_count > 1
258
+
259
+
260
+ @dataclass
261
+ class RingsChangeConfig:
262
+ pass
263
+
264
+
265
+ class RingsChangeFilter:
266
+ """Checks if there is changing rings number in the reaction."""
267
+
268
+ @staticmethod
269
+ def from_config(config: RingsChangeConfig) -> "RingsChangeFilter":
270
+ """Creates an instance of RingsChangeFilter from a configuration object."""
271
+ return RingsChangeFilter()
272
+
273
+ def __call__(self, reaction: ReactionContainer):
274
+ """
275
+ Returns True if the reaction changes ring counts, i.e. there is a
276
+ mismatch in the number of all rings or aromatic rings between reactants and
277
+ products (reaction in rings)
278
+
279
+ :param reaction: Input reaction.
280
+ :return: Returns True if the number of all rings or of aromatic rings differs between reactants and products.
281
+
282
+ """
283
+
284
+ r_rings, r_arom_rings = self._calc_rings(reaction.reactants)
285
+ p_rings, p_arom_rings = self._calc_rings(reaction.products)
286
+
287
+ return (r_arom_rings != p_arom_rings) or (r_rings != p_rings)
288
+
289
+ @staticmethod
290
+ def _calc_rings(molecules: Iterable) -> Tuple[int, int]:
291
+ """
292
+ Calculates number of all rings and number of aromatic rings in molecules.
293
+
294
+ :param molecules: Set of molecules.
295
+ :return: Number of all rings and number of aromatic rings in molecules
296
+ """
297
+ rings, arom_rings = 0, 0
298
+ for mol in molecules:
299
+ rings += mol.rings_count
300
+ arom_rings += len(mol.aromatic_rings)
301
+ return rings, arom_rings
302
+
303
+
304
+ @dataclass
305
+ class StrangeCarbonsConfig:
306
+ # currently empty, but can be extended in the future if needed
307
+ pass
308
+
309
+
310
+ class StrangeCarbonsFilter:
311
+ """Checks if there are 'strange' carbons in the reaction."""
312
+
313
+ @staticmethod
314
+ def from_config(config: StrangeCarbonsConfig) -> "StrangeCarbonsFilter":
315
+ """Creates an instance of StrangeCarbonsFilter from a configuration object."""
316
+ return StrangeCarbonsFilter()
317
+
318
+ def __call__(self, reaction: ReactionContainer) -> bool:
319
+ for molecule in reaction.reactants + reaction.products:
320
+ atoms_types = {
321
+ a.atomic_symbol for _, a in molecule.atoms()
322
+ } # atoms types in molecule
323
+ if len(atoms_types) == 1 and atoms_types.pop() == "C":
324
+ if len(molecule) == 1: # methane
325
+ return True
326
+ bond_types = {int(b) for _, _, b in molecule.bonds()}
327
+ if len(bond_types) == 1 and bond_types.pop() != 4:
328
+ return True # C molecules with only one type of bond (not aromatic)
329
+ return False
330
+
331
+
332
@dataclass
class NoReactionConfig:
    """
    Configuration object for NoReactionFilter.

    Currently has no fields; it can be extended in the future if needed.
    """
336
+
337
+
338
class NoReactionFilter:
    """Checks if there is no reaction in the provided reaction container."""

    @staticmethod
    def from_config(config: NoReactionConfig) -> "NoReactionFilter":
        """
        Builds a NoReactionFilter.

        :param config: Configuration object for the filter (not read).
        :return: A new NoReactionFilter instance.
        """
        return NoReactionFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Flags reactions whose CGR (``~reaction``) has neither center atoms nor
        center bonds.

        :param reaction: Input reaction.
        :return: True if the CGR reports no center atoms and no center bonds.
        """
        cgr = ~reaction
        has_center = cgr.center_atoms or cgr.center_bonds
        return not has_center
349
+
350
+
351
@dataclass
class MultiCenterConfig:
    """Configuration object for MultiCenterFilter; currently has no fields."""
354
+
355
+
356
class MultiCenterFilter:
    """Checks if there is a multicenter reaction."""

    @staticmethod
    def from_config(config: MultiCenterConfig) -> "MultiCenterFilter":
        """
        Builds a MultiCenterFilter.

        :param config: Configuration object for the filter (not read).
        :return: A new MultiCenterFilter instance.
        """
        return MultiCenterFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Flags reactions whose CGR (``~reaction``) lists more than one center.

        :param reaction: Input reaction.
        :return: True if ``centers_list`` of the CGR has more than one entry.
        """
        centers = (~reaction).centers_list
        return len(centers) > 1
366
+
367
+
368
@dataclass
class WrongCHBreakingConfig:
    """Configuration object for WrongCHBreakingFilter; currently has no fields."""
371
+
372
+
373
class WrongCHBreakingFilter:
    """Checks for incorrect C-C bond formation from breaking a C-H bond."""

    @staticmethod
    def from_config(config: WrongCHBreakingConfig) -> "WrongCHBreakingFilter":
        """
        Builds a WrongCHBreakingFilter.

        :param config: Configuration object for the filter (not read).
        :return: A new WrongCHBreakingFilter instance.
        """
        return WrongCHBreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Determines if a reaction involves incorrect C-C bond formation from breaking
        a C-H bond.

        Reactions for which ``check_valence()`` returns a truthy value are not
        analysed and are reported as not flagged. Otherwise a copy of the reaction
        gets explicit hydrogens, and the first environment of the CGR center atoms
        is inspected.

        :param reaction: The reaction to be filtered.
        :return: True if incorrect C-C bond formation is found, False otherwise.
        """
        if reaction.check_valence():
            return False

        # Work on a copy so that the caller's reaction keeps implicit hydrogens.
        explicit_reaction = reaction.copy()
        explicit_reaction.explicify_hydrogens()
        full_cgr = ~explicit_reaction
        center_environment = full_cgr.augmented_substructure(
            full_cgr.center_atoms, deep=1
        )
        return self.is_wrong_c_h_breaking(center_environment)

    @staticmethod
    def is_wrong_c_h_breaking(cgr: CGRContainer) -> bool:
        """
        Checks for incorrect C-C bond formation from breaking a C-H bond in a CGR.

        The first center carbon that both loses a bond to hydrogen and gains a bond
        to carbon decides the result: it is flagged unless either of the two bonding
        carbons has a neighbour other than C or H.

        :param cgr: The CGR with explicified hydrogens.
        :return: True if incorrect C-C bond formation is found, False otherwise.
        """

        def has_heteroatom_neighbour(carbon_id) -> bool:
            # first environment of one of the two bonding carbons
            return any(
                cgr.atom(adjacent_id).atomic_symbol not in ("C", "H")
                for adjacent_id in cgr._bonds[carbon_id]
            )

        for atom_id in cgr.center_atoms:
            if cgr.atom(atom_id).atomic_symbol != "C":
                continue

            loses_hydrogen = False
            gains_carbon = False
            partner_carbon_id = None

            for neighbour_id, bond in cgr._bonds[atom_id].items():
                neighbour = cgr.atom(neighbour_id)
                broken = bond.order and not bond.p_order
                formed = not bond.order and bond.p_order

                if broken and neighbour.atomic_symbol == "H":
                    loses_hydrogen = True
                elif formed and neighbour.atomic_symbol == "C":
                    gains_carbon = True
                    partner_carbon_id = neighbour_id

            if loses_hydrogen and gains_carbon:
                # heteroatoms next to either bonding carbon make the step acceptable
                near_heteroatom = has_heteroatom_neighbour(
                    atom_id
                ) or has_heteroatom_neighbour(partner_carbon_id)
                return not near_heteroatom

        return False
446
+
447
+
448
@dataclass
class CCsp3BreakingConfig:
    """Configuration object for CCsp3BreakingFilter; currently has no fields."""
451
+
452
+
453
class CCsp3BreakingFilter:
    """Checks if there is C(sp3)-C bond breaking."""

    @staticmethod
    def from_config(config: CCsp3BreakingConfig) -> "CCsp3BreakingFilter":
        """
        Builds a CCsp3BreakingFilter.

        :param config: Configuration object for the filter (not read).
        :return: A new CCsp3BreakingFilter instance.
        """
        return CCsp3BreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Returns True if there is C(sp3)-C bonds breaking, else False.

        A bond counts when it exists in the reactants but not in the products
        (``order`` set, ``p_order`` is None), both ends are carbons and at least one
        end has ``hybridization == 1``.

        :param reaction: Input reaction
        :return: Returns True if there is C(sp3)-C bonds breaking, else False.
        """
        cgr = ~reaction
        center = cgr.augmented_substructure(cgr.center_atoms, deep=1)

        for first_id, second_id, bond in center.bonds():
            first = center.atom(first_id)
            second = center.atom(second_id)

            if bond.order is None or bond.p_order is not None:
                continue  # bond is not broken in the reaction
            if first.atomic_symbol != "C" or second.atomic_symbol != "C":
                continue  # not a carbon-carbon bond
            if first.hybridization == 1 or second.hybridization == 1:
                return True
        return False
483
+
484
+
485
@dataclass
class CCRingBreakingConfig:
    """
    Object to pass to ReactionFilterConfig if you want to enable the C-C ring
    breaking filter. It currently has no fields.
    """
493
+
494
+
495
class CCRingBreakingFilter:
    """Checks if a reaction involves ring C-C bond breaking."""

    @staticmethod
    def from_config(config: CCRingBreakingConfig) -> "CCRingBreakingFilter":
        """
        Builds a CCRingBreakingFilter.

        :param config: Configuration object for the filter (not read).
        :return: A new CCRingBreakingFilter instance.
        """
        return CCRingBreakingFilter()

    def __call__(self, reaction: ReactionContainer) -> bool:
        """
        Returns True if the reaction involves ring C-C bond breaking, else False.

        A bond is reported when it is broken in the reaction (``order`` set,
        ``p_order`` is None), both ends are reactant carbons belonging to rings of
        size 5, 6 or 7, and both ends are members of one common ring from the
        reactants' ``sssr``.

        :param reaction: Input reaction
        :return: Returns True if the reaction involves ring C-C bond breaking, else
            False.
        """
        cgr = ~reaction

        # Build the lookup set once instead of scanning center_atoms per atom.
        center_ids = set(cgr.center_atoms)

        # Collect the reactants' rings and their atoms that are reaction centers.
        reactants_center_atoms = {}
        reactants_rings = set()
        for reactant in reaction.reactants:
            reactants_rings.update(reactant.sssr)
            for n, atom in reactant.atoms():
                if n in center_ids:
                    reactants_center_atoms[n] = atom

        # Reaction center restricted to the center atoms themselves.
        reaction_center = cgr.augmented_substructure(atoms=cgr.center_atoms, deep=0)

        ring_sizes_of_interest = {5, 6, 7}
        for atom_id, neighbour_id, bond in reaction_center.bonds():
            atom = reactants_center_atoms.get(atom_id)
            neighbour = reactants_center_atoms.get(neighbour_id)
            if atom is None or neighbour is None:
                continue  # one end of the bond is not a reactant center atom

            # Cheap checks first; ring membership is only examined for broken
            # carbon-carbon bonds.
            if bond.order is None or bond.p_order is not None:
                continue
            if atom.atomic_symbol != "C" or neighbour.atomic_symbol != "C":
                continue

            if (
                ring_sizes_of_interest.intersection(atom.ring_sizes)
                and ring_sizes_of_interest.intersection(neighbour.ring_sizes)
                and any(
                    atom_id in ring and neighbour_id in ring
                    for ring in reactants_rings
                )
            ):
                return True

        return False
553
+
554
+
555
@dataclass
class ReactionFilterConfig(ConfigABC):
    """
    Configuration class for reaction filtering. Each field holds the configuration
    of one reaction filter; a filter is enabled when its field is not None.

    :ivar dynamic_bonds_config: Configuration for dynamic bonds checking.
    :ivar small_molecules_config: Configuration for small molecules checking.
    :ivar strange_carbons_config: Configuration for strange carbons checking.
    :ivar compete_products_config: Configuration for competing products checking.
    :ivar cgr_connected_components_config: Configuration for CGR connected components checking.
    :ivar rings_change_config: Configuration for rings change checking.
    :ivar no_reaction_config: Configuration for no reaction checking.
    :ivar multi_center_config: Configuration for multi-center checking.
    :ivar wrong_ch_breaking_config: Configuration for wrong C-H breaking checking.
    :ivar cc_sp3_breaking_config: Configuration for CC sp3 breaking checking.
    :ivar cc_ring_breaking_config: Configuration for CC ring breaking checking.

    """

    # configuration for reaction filters
    dynamic_bonds_config: Optional[DynamicBondsConfig] = None
    small_molecules_config: Optional[SmallMoleculesConfig] = None
    strange_carbons_config: Optional[StrangeCarbonsConfig] = None
    compete_products_config: Optional[CompeteProductsConfig] = None
    cgr_connected_components_config: Optional[CGRConnectedComponentsConfig] = None
    rings_change_config: Optional[RingsChangeConfig] = None
    no_reaction_config: Optional[NoReactionConfig] = None
    multi_center_config: Optional[MultiCenterConfig] = None
    wrong_ch_breaking_config: Optional[WrongCHBreakingConfig] = None
    cc_sp3_breaking_config: Optional[CCsp3BreakingConfig] = None
    cc_ring_breaking_config: Optional[CCRingBreakingConfig] = None

    def to_dict(self):
        """
        Converts the configuration into a dictionary.

        Configurations with parameters are converted with ``convert_config_to_dict``;
        parameter-free configurations are written as an empty dictionary. Entries
        whose value is None are left out.

        :return: Dictionary with one entry per enabled filter.
        """
        config_dict = {
            "dynamic_bonds_config": convert_config_to_dict(
                self.dynamic_bonds_config, DynamicBondsConfig
            ),
            "small_molecules_config": convert_config_to_dict(
                self.small_molecules_config, SmallMoleculesConfig
            ),
            "compete_products_config": convert_config_to_dict(
                self.compete_products_config, CompeteProductsConfig
            ),
        }

        # Parameter-free configurations: only their presence is serialized.
        for name in (
            "cgr_connected_components_config",
            "rings_change_config",
            "strange_carbons_config",
            "no_reaction_config",
            "multi_center_config",
            "wrong_ch_breaking_config",
            "cc_sp3_breaking_config",
            "cc_ring_breaking_config",
        ):
            config_dict[name] = {} if getattr(self, name) is not None else None

        return {k: v for k, v in config_dict.items() if v is not None}

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ReactionFilterConfig":
        """
        Create an instance of ReactionFilterConfig from a dictionary.

        A filter configuration is instantiated when its key is present in the
        dictionary. For configurations with parameters the value is passed as keyword
        arguments; a missing value (None) is treated as "no parameters given".

        :param config_dict: Mapping from configuration name to its parameters. None
            or an empty mapping yields a configuration with no filters enabled.
        :return: The populated ReactionFilterConfig.
        """
        config_dict = config_dict or {}

        with_parameters = {
            "dynamic_bonds_config": DynamicBondsConfig,
            "small_molecules_config": SmallMoleculesConfig,
            "compete_products_config": CompeteProductsConfig,
        }
        without_parameters = {
            "cgr_connected_components_config": CGRConnectedComponentsConfig,
            "rings_change_config": RingsChangeConfig,
            "strange_carbons_config": StrangeCarbonsConfig,
            "no_reaction_config": NoReactionConfig,
            "multi_center_config": MultiCenterConfig,
            "wrong_ch_breaking_config": WrongCHBreakingConfig,
            "cc_sp3_breaking_config": CCsp3BreakingConfig,
            "cc_ring_breaking_config": CCRingBreakingConfig,
        }

        kwargs = {}
        for name, config_cls in with_parameters.items():
            # `key:` with no value in YAML loads as None; treat it as no parameters.
            kwargs[name] = (
                config_cls(**(config_dict[name] or {}))
                if name in config_dict
                else None
            )
        for name, config_cls in without_parameters.items():
            kwargs[name] = config_cls() if name in config_dict else None

        return ReactionFilterConfig(**kwargs)

    @staticmethod
    def from_yaml(file_path: str) -> "ReactionFilterConfig":
        """
        Deserializes a YAML file into a ReactionFilterConfig object.

        :param file_path: Path to the YAML file.
        :return: The configuration described by the file; an empty file gives a
            configuration with no filters enabled.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            # safe_load returns None for an empty document
            config_dict = yaml.safe_load(file) or {}
        return ReactionFilterConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]):
        """No validation is performed for this configuration."""
        pass

    def create_filters(self):
        """
        Instantiates the filters whose configuration is set.

        :return: List of filter instances in the fixed order in which they are
            applied.
        """
        configured = (
            (self.dynamic_bonds_config, DynamicBondsFilter),
            (self.small_molecules_config, SmallMoleculesFilter),
            (self.strange_carbons_config, StrangeCarbonsFilter),
            (self.compete_products_config, CompeteProductsFilter),
            (self.cgr_connected_components_config, CGRConnectedComponentsFilter),
            (self.rings_change_config, RingsChangeFilter),
            (self.no_reaction_config, NoReactionFilter),
            (self.multi_center_config, MultiCenterFilter),
            (self.wrong_ch_breaking_config, WrongCHBreakingFilter),
            (self.cc_sp3_breaking_config, CCsp3BreakingFilter),
            (self.cc_ring_breaking_config, CCRingBreakingFilter),
        )
        return [
            filter_cls.from_config(filter_config)
            for filter_config, filter_cls in configured
            if filter_config is not None
        ]
768
+
769
+
770
def tanimoto_kernel(x: MorganFingerprint, y: MorganFingerprint) -> np.ndarray:
    """
    Calculate the Tanimoto coefficient between each row of x and each row of y.

    For rows a and b the coefficient is ``a.b / (|a|^2 + |b|^2 - a.b)``; pairs with
    a zero denominator get the value 0.

    :param x: 2-D array with one fingerprint per row.
    :param y: 2-D array with one fingerprint per row and the same number of columns
        as x.
    :return: Array of shape (len(x), len(y)) with the pairwise coefficients.
    """
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    x_dot = np.dot(x, y.T)
    x2 = np.sum(x**2, axis=1)
    y2 = np.sum(y**2, axis=1)

    # Broadcast the squared norms to the (len(x), len(y)) grid instead of stacking
    # Python lists of arrays; this also works when either input has zero rows.
    denominator = x2[:, np.newaxis] + y2[np.newaxis, :] - x_dot
    result = np.divide(
        x_dot, denominator, out=np.zeros_like(x_dot), where=denominator != 0
    )

    return result
784
+
785
+
786
def filter_reaction(
    reaction: ReactionContainer, config: ReactionFilterConfig, filters: list
) -> Tuple[bool, ReactionContainer]:
    """Checks the input reaction. Returns True if reaction is detected as erroneous and
    returns reaction itself, which sometimes is modified and does not necessarily
    correspond to the initial reaction.

    The reaction is first standardized (reagent removal, Kekule form, aromatic form)
    and then passed through the filters in order. Checking stops at the first filter
    that rejects the reaction or raises, so ``filtration_log`` in the reaction
    metadata names the filter that rejected it.

    :param reaction: Reaction to be filtered.
    :param config: Reaction filtration configuration. It is not read here; the
        filters are expected to be already built from it.
    :param filters: The list of reaction filters.
    :return: False and reaction if reaction is correct and True and reaction if reaction
        is filtered (erroneous).
    """

    # run reaction standardization
    standardizers = [
        RemoveReagentsStandardizer(),
        KekuleFormStandardizer(),
        AromaticFormStandardizer(),
    ]

    for reaction_standardizer in standardizers:
        reaction = reaction_standardizer(reaction)
        if not reaction:
            # standardization produced nothing usable
            return True, reaction

    # run reaction filtration
    for reaction_filter in filters:
        try:  # CGRTools ValueError: mapping of graphs is not disjoint
            if reaction_filter(reaction):
                # if filter returns True it means the reaction doesn't pass the filter
                reaction.meta["filtration_log"] = reaction_filter.__class__.__name__
                return True, reaction
        except Exception as e:
            # a filter that cannot process the reaction counts as a rejection
            logging.debug(e)
            return True, reaction

    return False, reaction
829
+
830
+
831
@ray.remote
def process_batch(
    batch: List[ReactionContainer],
    config: ReactionFilterConfig,
    filters: list,
) -> List[Tuple[bool, ReactionContainer]]:
    """
    Filters a batch of reactions based on the given configuration. This function
    operates as a remote task in a distributed system using Ray.

    :param batch: The list of reactions to be filtered.
    :param config: Reaction filtration configuration.
    :param filters: The list of reaction filters.
    :return: The list of tuples, one per input reaction, holding whether the reaction
        is filtered (True/False) and the reaction itself. A reaction whose processing
        raises is reported as filtered together with the unprocessed reaction.

    """

    processed_reaction_list = []
    for reaction in batch:
        try:  # CGRtools.exceptions.MappingError: atoms with number {52} not equal
            is_filtered, processed_reaction = filter_reaction(reaction, config, filters)
            processed_reaction_list.append((is_filtered, processed_reaction))
        except Exception as e:
            logging.debug(e)
            processed_reaction_list.append((True, reaction))
    return processed_reaction_list
861
+
862
+
863
def process_completed_batch(
    futures: Dict,
    result_file: TextIOWrapper,
    n_filtered: int = 0,
) -> int:
    """
    Waits for one batch to finish and writes its accepted reactions.

    :param futures: A dictionary whose keys are the futures of the ongoing batch
        processing tasks; the finished one is removed from it.
    :param result_file: The opened writer that receives reactions which were not
        filtered out.
    :param n_filtered: The number of reactions written so far.
    :return: The updated number of written reactions.

    """

    # block until at least one of the pending batches is done
    finished_refs, _ = ray.wait(list(futures.keys()), num_returns=1)
    finished_ref = finished_refs[0]

    # write results of the completed batch to file
    for is_filtered, reaction in ray.get(finished_ref):
        if is_filtered:
            continue
        result_file.write(reaction)
        n_filtered += 1

    # the batch is handled, so stop tracking its future
    del futures[finished_ref]

    return n_filtered
891
+
892
+
893
def filter_reactions_from_file(
    config: ReactionFilterConfig,
    input_reaction_data_path: str,
    filtered_reaction_data_path: str = "reaction_data_filtered.smi",
    num_cpus: int = 1,
    batch_size: int = 100,
) -> None:
    """
    Processes reaction data, applying reaction filters based on the provided
    configuration, and writes the results to specified files.

    :param config: ReactionFilterConfig object containing all filtration configuration
        settings.
    :param input_reaction_data_path: Path to the reaction data file.
    :param filtered_reaction_data_path: Name for the file that will contain filtered
        reactions.
    :param num_cpus: Number of CPUs to use for processing.
    :param batch_size: Size of the batch for processing reactions.
    :return: None. The function writes the processed reactions to specified RDF/smi
        files.

    """

    filters = config.create_filters()

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)
    max_concurrent_batches = num_cpus  # limit the number of concurrent batches
    lines_counter = 0
    n_filtered = 0
    try:
        with ReactionReader(input_reaction_data_path) as reactions, ReactionWriter(
            filtered_reaction_data_path
        ) as result_file:

            batches_to_process, batch = {}, []
            # enumerate from 1 so the loop variable is the number of reactions read
            for lines_counter, reaction in tqdm(
                enumerate(reactions, start=1),
                desc="Number of reactions processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):
                batch.append(reaction)
                if len(batch) == batch_size:
                    batch_results = process_batch.remote(batch, config, filters)
                    batches_to_process[batch_results] = None
                    batch = []

                    # check and process completed tasks if we've reached the concurrency limit
                    while len(batches_to_process) >= max_concurrent_batches:
                        n_filtered = process_completed_batch(
                            batches_to_process,
                            result_file,
                            n_filtered,
                        )

            # process the last batch if it's not empty
            if batch:
                batch_results = process_batch.remote(batch, config, filters)
                batches_to_process[batch_results] = None

            # process remaining batches
            while batches_to_process:
                n_filtered = process_completed_batch(
                    batches_to_process,
                    result_file,
                    n_filtered,
                )
    finally:
        # release the Ray runtime even if reading, filtering or writing failed
        ray.shutdown()

    print(f"Initial number of reactions: {lines_counter}")
    print(f"Filtered number of reactions: {n_filtered}")
synplan/chem/data/standardizing.py ADDED
@@ -0,0 +1,1187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions for reactions standardizing.
2
+
3
+ This module contains the open-source code from
4
+ https://github.com/Laboratoire-de-Chemoinformatique/Reaction_Data_Cleaning/blob/master/scripts/standardizer.py
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from contextlib import suppress
11
+ from dataclasses import dataclass
12
+ from io import TextIOWrapper
13
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Sequence, TextIO
14
+ from abc import ABC, abstractmethod
15
+ from pathlib import Path
16
+ import sys
17
+
18
+
19
+ import ray
20
+ import yaml
21
+ from CGRtools import smiles as smiles_cgrtools
22
+ from CGRtools.containers import MoleculeContainer
23
+ from CGRtools.containers import ReactionContainer
24
+ from CGRtools.containers import ReactionContainer as ReactionContainerCGRTools
25
+ from chython import ReactionContainer as ReactionContainerChython
26
+ from chython import smiles as smiles_chython
27
+ from tqdm.auto import tqdm
28
+
29
+ from synplan.chem.utils import unite_molecules
30
+ from synplan.utils.config import ConfigABC
31
+ from synplan.utils.files import ReactionReader, ReactionWriter
32
+ from synplan.utils.logging import init_logger, init_ray_logging
33
+
34
+ logger = logging.getLogger("synplan.chem.data.standardizing")
35
+
36
+
37
class StandardizationError(RuntimeError):
    """Error of a standardization stage.

    Keeps the stage name, the string form of the failing reaction and the
    original exception; the message combines all three.
    """

    def __init__(self, stage: str, reaction: str, original: Exception):
        message = f"{stage} failed on {reaction}: {original}"
        super().__init__(message)
        self.stage, self.reaction, self.original = stage, reaction, original
45
+
46
+
47
class BaseStandardizer(ABC):
    """Template: subclasses override `_run` only."""

    @classmethod
    def from_config(cls, _cfg: object) -> "BaseStandardizer":
        """Create the standardizer; the configuration object is not read.

        Args:
            _cfg: Configuration object of the standardization step (unused)

        Returns:
            A new instance of the class the method is called on
        """
        return cls()

    @abstractmethod
    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Run the standardization step on the reaction.

        Args:
            rxn: The reaction to standardize

        Returns:
            The standardized reaction

        Raises:
            StandardizationError: If standardization fails
        """
        ...

    def __call__(self, rxn: ReactionContainer) -> ReactionContainer:
        """Execute the standardization step with proper error handling.

        Args:
            rxn: The reaction to standardize

        Returns:
            The standardized reaction

        Raises:
            StandardizationError: If standardization fails. An error of this type
                raised by `_run` is propagated unchanged; any other exception is
                wrapped, with the original attached as its cause.
        """
        try:
            return self._run(rxn)
        except StandardizationError:
            # already carries stage, reaction and original; do not wrap it twice
            raise
        except Exception as exc:
            logger.debug("%s: %s", self.__class__.__name__, exc, exc_info=True)
            raise StandardizationError(
                self.__class__.__name__, str(rxn), exc
            ) from exc
86
+
87
+
88
+ # Configuration classes
89
@dataclass
class ReactionMappingConfig:
    """Configuration object for ReactionMappingStandardizer; it has no fields."""
92
+
93
+
94
class ReactionMappingStandardizer(BaseStandardizer):
    """Maps atoms of the reaction using chython (chytorch)."""

    def _map_and_remove_reagents(
        self, reaction: ReactionContainerChython
    ) -> ReactionContainerChython:
        """Map and remove reagents from the reaction.

        Calls ``reset_mapping()`` and then ``remove_reagents()`` on the reaction,
        modifying it in place.

        Args:
            reaction: Input reaction

        Returns:
            The same reaction object after both calls
        """
        reaction.reset_mapping()
        reaction.remove_reagents()
        return reaction

    def _run(self, rxn: ReactionContainerCGRTools) -> ReactionContainerCGRTools:
        """Map atoms of the reaction using chython.

        The reaction is written as a reaction SMILES, parsed and mapped with
        chython, and the mapped SMILES is parsed back with CGRtools.

        Args:
            rxn: Input reaction, or a reaction SMILES string

        Returns:
            The mapped reaction; metadata of the input reaction is copied to it
            when the input has any

        Raises:
            StandardizationError: If mapping fails
        """
        try:
            # Convert to chython format
            if isinstance(rxn, str):
                chython_reaction = smiles_chython(rxn)
            else:
                # Convert CGRtools reaction to SMILES string, preserving reagents
                reactants = ".".join(str(m) for m in rxn.reactants)
                reagents = ".".join(str(m) for m in rxn.reagents)
                products = ".".join(str(m) for m in rxn.products)
                chython_reaction = smiles_chython(f"{reactants}>{reagents}>{products}")

            # Map and remove reagents
            reaction_mapped = self._map_and_remove_reagents(chython_reaction)
            if not reaction_mapped:
                raise StandardizationError(
                    "ReactionMapping", str(rxn), ValueError("Mapping failed")
                )

            # Convert back to CGRtools format
            result = smiles_cgrtools(format(reaction_mapped, "m"))

            # Preserve metadata; a plain SMILES string input has none
            meta = getattr(rxn, "meta", None)
            if meta:
                result.meta.update(meta)
            return result
        except StandardizationError:
            # raised above with the proper message; do not wrap it again
            raise
        except Exception as e:
            raise StandardizationError("ReactionMapping", str(rxn), e) from e
151
+
152
+
153
@dataclass
class FunctionalGroupsConfig:
    """Configuration object for FunctionalGroupsStandardizer; it has no fields."""
156
+
157
+
158
class FunctionalGroupsStandardizer(BaseStandardizer):
    """Functional groups standardization."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Apply ``standardize()`` to the reaction in place.

        Args:
            rxn: Input reaction

        Returns:
            The same reaction object after standardization

        Raises:
            StandardizationError: If standardization fails (raised by the
                calling wrapper when this step throws)
        """
        reaction = rxn
        reaction.standardize()
        return reaction
175
+
176
+
177
@dataclass
class KekuleFormConfig:
    """Configuration object for KekuleFormStandardizer; it has no fields."""
180
+
181
+
182
class KekuleFormStandardizer(BaseStandardizer):
    """Reactants/reagents/products kekulization."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Apply ``kekule()`` to the reaction in place.

        Args:
            rxn: The reaction to kekulize

        Returns:
            The same reaction object after kekulization

        Raises:
            StandardizationError: If kekulization fails (raised by the calling
                wrapper when this step throws)
        """
        reaction = rxn
        reaction.kekule()
        return reaction
199
+
200
+
201
@dataclass
class CheckValenceConfig:
    """Configuration object for CheckValenceStandardizer; it has no fields."""
204
+
205
+
206
class CheckValenceStandardizer(BaseStandardizer):
    """Check valence."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Verify that no molecule of the reaction reports valence mistakes.

        Reactants, products and reagents are checked in that order; the first
        molecule with mistakes stops the check.

        Args:
            rxn: Input reaction

        Returns:
            The unchanged reaction if valences are correct

        Raises:
            StandardizationError: If valence check fails
        """
        all_molecules = rxn.reactants + rxn.products + rxn.reagents
        for molecule in all_molecules:
            valence_mistakes = molecule.check_valence()
            if not valence_mistakes:
                continue
            raise StandardizationError(
                "CheckValence",
                str(rxn),
                ValueError(f"Valence errors: {valence_mistakes}"),
            )
        return rxn
230
+
231
+
232
@dataclass
class ImplicifyHydrogensConfig:
    """Configuration object for ImplicifyHydrogensStandardizer; it has no fields."""
235
+
236
+
237
class ImplicifyHydrogensStandardizer(BaseStandardizer):
    """Implicify hydrogens."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Apply ``implicify_hydrogens()`` to the reaction in place.

        Args:
            rxn: Input reaction

        Returns:
            The same reaction object with implicified hydrogens

        Raises:
            StandardizationError: If hydrogen implicification fails (raised by
                the calling wrapper when this step throws)
        """
        reaction = rxn
        reaction.implicify_hydrogens()
        return reaction
254
+
255
+
256
@dataclass
class CheckIsotopesConfig:
    """Marker config: its presence enables isotope cleaning (no parameters)."""
259
+
260
+
261
class CheckIsotopesStandardizer(BaseStandardizer):
    """Clean isotope labels when reactants or products carry any."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Call ``rxn.clean_isotopes()`` if an isotope-labelled atom is found.

        Only reactants and products are scanned; the scan stops at the first
        atom whose ``isotope`` attribute is truthy.

        Args:
            rxn: Reaction to inspect.

        Returns:
            The ``rxn`` object that was passed in.
        """
        has_isotope = any(
            atom.isotope
            for mol in rxn.reactants + rxn.products
            for _, atom in mol.atoms()
        )
        if has_isotope:
            rxn.clean_isotopes()
        return rxn
289
+
290
+
291
@dataclass
class SplitIonsConfig:
    """Marker config: its presence enables the ion-splitting step (no parameters)."""
294
+
295
+
296
class SplitIonsStandardizer(BaseStandardizer):
    """Split multi-component molecules into separate species and check charges."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Split the molecules of the reaction and reject charge-imbalanced results.

        Args:
            rxn: Input reaction

        Returns:
            A new reaction built from the split components

        Raises:
            StandardizationError: If ``_split_ions`` reports return code 2
                (charged species present and a non-zero net charge)
        """
        reaction, return_code = self._split_ions(rxn)
        if return_code == 2:  # ions were split but the reaction is imbalanced
            raise StandardizationError(
                "SplitIons",
                str(rxn),
                ValueError("Reaction is imbalanced after ion splitting"),
            )
        return reaction

    def _calc_charge(self, molecule: MoleculeContainer) -> int:
        """Compute total charge of a molecule.

        Args:
            molecule: Input molecule

        Returns:
            The sum of the per-atom charges stored in ``molecule._charges``
        """
        return sum(molecule._charges.values())

    def _split_ions(self, reaction: ReactionContainer) -> Tuple[ReactionContainer, int]:
        """Split every molecule of a reaction into its components.

        Reactants, reagents and products are processed independently. For each
        group a code is computed and the maximum over the three groups is
        returned.

        Args:
            reaction: Input reaction

        Returns:
            A tuple containing:
            - A new reaction made of the split components (meta and name are
              carried over from ``reaction``)
            - Return code (0: no charged species, 1: charged species present
              and the group's net charge is zero, 2: charged species present
              and the group's net charge is non-zero)
        """
        meta = reaction.meta
        reaction_parts = []
        return_codes = []

        for molecules in (reaction.reactants, reaction.reagents, reaction.products):
            # Split molecules into individual components
            divided_molecules = []
            for molecule in molecules:
                if isinstance(molecule, str):
                    # If it's a string, try to parse it as a molecule
                    try:
                        molecule: MoleculeContainer = smiles_cgrtools(molecule)
                    except Exception as e:
                        # Unparsable entries are dropped from the reaction.
                        logging.warning("Failed to parse molecule %s: %s", molecule, e)
                        continue

                # Use the split method from CGRtools
                try:
                    components = molecule.split()
                    divided_molecules.extend(components)
                except Exception as e:
                    # Best effort: keep the molecule unsplit rather than lose it.
                    logging.warning("Failed to split molecule %s: %s", molecule, e)
                    divided_molecules.append(molecule)

            total_charge = 0
            ions_present = False
            for molecule in divided_molecules:
                try:
                    mol_charge = self._calc_charge(molecule)
                except Exception as e:
                    logging.warning(
                        "Failed to calculate charge for molecule %s: %s", molecule, e
                    )
                    continue
                total_charge += mol_charge
                if mol_charge != 0:
                    ions_present = True

            if ions_present and total_charge:
                return_codes.append(2)
            elif ions_present:
                return_codes.append(1)
            else:
                return_codes.append(0)

            reaction_parts.append(tuple(divided_molecules))

        new_reaction = ReactionContainer(
            reactants=reaction_parts[0],
            reagents=reaction_parts[1],
            products=reaction_parts[2],
            meta=meta,
        )
        # Keep the reaction name, as the other standardizers that rebuild the
        # reaction (UnchangedParts, SmallMolecules, RemoveReagents, Rebalance) do.
        new_reaction.name = reaction.name
        return new_reaction, max(return_codes)
398
+
399
+
400
@dataclass
class AromaticFormConfig:
    """Marker config: its presence enables the aromatization step (no parameters)."""
403
+
404
+
405
class AromaticFormStandardizer(BaseStandardizer):
    """Bring the molecules of a reaction to aromatic form."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Call ``rxn.thiele()`` and hand the same reaction object back.

        Args:
            rxn: Reaction to aromatize.

        Returns:
            The ``rxn`` object that was passed in, after ``thiele()`` was
            invoked on it.
        """
        rxn.thiele()
        return rxn
422
+
423
+
424
@dataclass
class MappingFixConfig:
    """Marker config: its presence enables the mapping-fix step (no parameters)."""
427
+
428
+
429
class MappingFixStandardizer(BaseStandardizer):
    """Repair the atom-to-atom mapping of a reaction."""

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Call ``rxn.fix_mapping()`` and hand the same reaction object back.

        Args:
            rxn: Reaction whose mapping should be fixed.

        Returns:
            The ``rxn`` object that was passed in, after ``fix_mapping()`` was
            invoked on it.
        """
        rxn.fix_mapping()
        return rxn
446
+
447
+
448
@dataclass
class UnchangedPartsConfig:
    """Marker config: its presence enables unchanged-parts removal (no parameters)."""
451
+
452
+
453
class UnchangedPartsStandardizer(BaseStandardizer):
    """Drop molecules that occur on both sides of the reaction."""

    def __init__(
        self,
        add_reagents_to_reactants: bool = False,
        keep_reagents: bool = False,
    ):
        # When set, the original reagents are merged into the reactants first.
        self.add_reagents_to_reactants = add_reagents_to_reactants
        # When not set, the rebuilt reaction carries no reagents at all.
        self.keep_reagents = keep_reagents

    @classmethod
    def from_config(cls, config: UnchangedPartsConfig) -> "UnchangedPartsStandardizer":
        # The config has no fields, so the defaults of __init__ are used.
        return cls()

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Move molecules present in both reactants and products to the reagents.

        Args:
            rxn: Input reaction

        Returns:
            A new reaction (same meta and name) without the unchanged
            molecules among reactants and products. Reagents are emptied
            unless ``keep_reagents`` is set.

        Raises:
            StandardizationError: If the reactants, the products, or both end
                up empty
        """
        meta = rxn.meta
        reactants = list(rxn.reactants)
        reagents = list(rxn.reagents)
        if self.add_reagents_to_reactants:
            reactants.extend(reagents)
            reagents = []
        products = list(rxn.products)

        # Iterate over a snapshot because `reactants` is edited inside the loop.
        for candidate in tuple(reactants):
            if candidate not in products:
                continue
            reagents.append(candidate)
            reactants.remove(candidate)
            products.remove(candidate)

        if not self.keep_reagents:
            reagents = []

        if not reactants or not products:
            if reactants:
                reason = "No products left"
            elif products:
                reason = "No reactants left"
            else:
                reason = "No molecules left"
            raise StandardizationError("UnchangedParts", str(rxn), ValueError(reason))

        result = ReactionContainer(
            reactants=tuple(reactants),
            reagents=tuple(reagents),
            products=tuple(products),
            meta=meta,
        )
        result.name = rxn.name
        return result
518
+
519
+
520
@dataclass
class SmallMoleculesConfig:
    """Settings for the small-molecule removal step."""

    # Size threshold handed to SmallMoleculesStandardizer (default 6).
    mol_max_size: int = 6

    @staticmethod
    def from_dict(config_dict: Optional[Dict[str, Any]]) -> "SmallMoleculesConfig":
        """Create an instance of SmallMoleculesConfig from a dictionary.

        ``None`` (e.g. the result of loading an empty YAML document) is
        treated like an empty dictionary, so all defaults apply.
        """
        return SmallMoleculesConfig(**(config_dict or {}))

    @staticmethod
    def from_yaml(file_path: str) -> "SmallMoleculesConfig":
        """Deserialize a YAML file into a SmallMoleculesConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return SmallMoleculesConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Args:
            params: Mapping that may contain ``mol_max_size``; the current
                attribute value is used when the key is absent.

        Raises:
            ValueError: If ``mol_max_size`` is not an integer >= 1.
        """
        mol_max_size = params.get("mol_max_size", self.mol_max_size)
        if not isinstance(mol_max_size, int) or mol_max_size < 1:
            # The check accepts 1, so the message must not claim "more than 1".
            raise ValueError("Invalid 'mol_max_size'; expected a positive integer")
541
+
542
+
543
class SmallMoleculesStandardizer(BaseStandardizer):
    """Strip small molecules from the reactants and products of a reaction."""

    def __init__(self, mol_max_size: int = 6):
        # Molecules with len() <= mol_max_size count as "small".
        self.mol_max_size = mol_max_size

    @classmethod
    def from_config(cls, config: SmallMoleculesConfig) -> "SmallMoleculesStandardizer":
        return cls(config.mol_max_size)

    def _split_molecules(
        self, molecules: Iterable, number_of_atoms: int
    ) -> Tuple[List[MoleculeContainer], List[MoleculeContainer]]:
        """Partition molecules by size.

        Args:
            molecules: Iterable of molecules
            number_of_atoms: Size threshold compared against ``len(molecule)``

        Returns:
            ``(big, small)`` where ``big`` holds molecules with
            ``len(molecule) > number_of_atoms`` and ``small`` holds the rest;
            input order is kept inside each list
        """
        big: List[MoleculeContainer] = []
        small: List[MoleculeContainer] = []
        for mol in molecules:
            target = big if len(mol) > number_of_atoms else small
            target.append(mol)
        return big, small

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Build a reaction that keeps only the "big" reactants and products.

        The removed molecules are united with ``unite_molecules`` and their
        string form is stored in the meta of the new reaction under
        ``small_reactants`` / ``small_products``. Reagents are passed through.

        Args:
            rxn: Input reaction

        Returns:
            A new reaction with the name of ``rxn``

        Raises:
            StandardizationError: If no reactant or no product is left
        """
        threshold = self.mol_max_size
        kept_reactants, dropped_reactants = self._split_molecules(
            rxn.reactants, threshold
        )
        kept_products, dropped_products = self._split_molecules(
            rxn.products, threshold
        )

        if not (kept_reactants and kept_products):
            raise StandardizationError(
                "SmallMolecules",
                str(rxn),
                ValueError("No molecules left after removing small ones"),
            )

        result = ReactionContainer(
            kept_reactants, kept_products, rxn.reagents, rxn.meta
        )
        result.name = rxn.name

        # Remember what was removed so the information is not lost.
        result.meta["small_reactants"] = str(unite_molecules(dropped_reactants))
        result.meta["small_products"] = str(unite_molecules(dropped_products))

        return result
611
+
612
+
613
@dataclass
class RemoveReagentsConfig:
    """Settings for the reagent removal step."""

    # Size threshold handed to RemoveReagentsStandardizer (default 7).
    reagent_max_size: int = 7

    @staticmethod
    def from_dict(config_dict: Optional[Dict[str, Any]]) -> "RemoveReagentsConfig":
        """Create an instance of RemoveReagentsConfig from a dictionary.

        ``None`` (e.g. the result of loading an empty YAML document) is
        treated like an empty dictionary, so all defaults apply.
        """
        return RemoveReagentsConfig(**(config_dict or {}))

    @staticmethod
    def from_yaml(file_path: str) -> "RemoveReagentsConfig":
        """Deserialize a YAML file into a RemoveReagentsConfig object."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return RemoveReagentsConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Validate configuration parameters.

        Args:
            params: Mapping that may contain ``reagent_max_size``; the current
                attribute value is used when the key is absent.

        Raises:
            ValueError: If ``reagent_max_size`` is not an integer >= 1.
        """
        reagent_max_size = params.get("reagent_max_size", self.reagent_max_size)
        if not isinstance(reagent_max_size, int) or reagent_max_size < 1:
            # The check accepts 1, so the message must not claim "more than 1".
            raise ValueError(
                "Invalid 'reagent_max_size'; expected a positive integer"
            )
636
+
637
+
638
class RemoveReagentsStandardizer(BaseStandardizer):
    """Move molecules that do not take part in the reaction to the reagents."""

    def __init__(self, reagent_max_size: int = 7):
        # Spectator molecules larger than this are dropped, not kept as reagents.
        self.reagent_max_size = reagent_max_size

    @classmethod
    def from_config(cls, config: RemoveReagentsConfig) -> "RemoveReagentsStandardizer":
        return cls(config.reagent_max_size)

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Separate reacting molecules from spectators.

        A reactant or product is a spectator when it shares no atom with the
        reaction center of the CGR (``~rxn``) or when it occurs unchanged on
        both sides. Spectators with ``len(molecule) <= reagent_max_size``
        become the reagents of the new reaction; larger ones are discarded.

        NOTE(review): ``rxn.reagents`` is not read here, so reagents already
        present on the input are not carried over - confirm this is intended.

        Args:
            rxn: Input reaction

        Returns:
            A new reaction with the meta and name of ``rxn``

        Raises:
            StandardizationError: If no reactant or no product is left
        """
        not_changed_molecules = set(rxn.reactants).intersection(rxn.products)
        cgr = ~rxn
        center_atoms = set(cgr.center_atoms)

        def is_spectator(molecule) -> bool:
            return (
                center_atoms.isdisjoint(molecule)
                or molecule in not_changed_molecules
            )

        new_reactants = []
        new_products = []
        spectators = []

        for molecule in rxn.reactants:
            (spectators if is_spectator(molecule) else new_reactants).append(molecule)

        for molecule in rxn.products:
            (spectators if is_spectator(molecule) else new_products).append(molecule)

        if not new_reactants or not new_products:
            raise StandardizationError(
                "RemoveReagents",
                str(rxn),
                ValueError("No molecules left after removing reagents"),
            )

        # Filter reagents by size and drop duplicates. dict.fromkeys keeps the
        # first-seen order, whereas a set would make the reagent order (and so
        # the written reaction) vary between runs.
        new_reagents = tuple(
            dict.fromkeys(
                molecule
                for molecule in spectators
                if len(molecule) <= self.reagent_max_size
            )
        )

        new_reaction = ReactionContainer(
            new_reactants, new_products, new_reagents, rxn.meta
        )
        new_reaction.name = rxn.name

        return new_reaction
700
+
701
+
702
@dataclass
class RebalanceReactionConfig:
    """Marker config: its presence enables reaction rebalancing (no parameters)."""
705
+
706
+
707
class RebalanceReactionStandardizer(BaseStandardizer):
    """Rebalance reaction."""

    @classmethod
    def from_config(
        cls, config: RebalanceReactionConfig
    ) -> "RebalanceReactionStandardizer":
        # The config has no fields; nothing to forward.
        return cls()

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Rebalances the reaction by assembling CGR and then decomposing it. Works for
        all reactions for which the correct CGR can be assembled.

        Args:
            rxn: Input reaction

        Returns:
            A new reaction built from the decomposed CGR, carrying the
            reagents, meta and name of ``rxn``

        Raises:
            StandardizationError: If any step of the rebalancing fails; the
                underlying exception is attached as ``__cause__``
        """
        try:
            # Reagents are left out of the CGR and re-attached afterwards.
            tmp_rxn = ReactionContainer(rxn.reactants, rxn.products)
            cgr = ~tmp_rxn
            reactants, products = ~cgr
            new_rxn = ReactionContainer(
                reactants.split(), products.split(), rxn.reagents, rxn.meta
            )
            new_rxn.name = rxn.name
            return new_rxn
        except Exception as e:
            logging.debug("Rebalancing attempt failed: %s", e)
            # Chain the original error so the root cause stays in the traceback.
            raise StandardizationError(
                "RebalanceReaction",
                str(rxn),
                ValueError("Failed to rebalance reaction"),
            ) from e
745
+
746
+
747
@dataclass
class DuplicateReactionConfig:
    """Marker config: its presence enables duplicate removal (no parameters)."""
750
+
751
+
752
class DuplicateReactionStandardizer(BaseStandardizer):
    """Duplicate removal, cluster-wide when a Ray dedup actor is available."""

    def __init__(self, dedup_actor: "ray.actor.ActorHandle"):
        # Handle of the shared actor, or None for purely local deduplication.
        self._actor = dedup_actor
        # Keys already accepted by *this* instance; lets obvious repeats be
        # rejected without a round trip to the actor.
        self._local_seen: set[int] = set()

    @classmethod
    def from_config(cls, config: DuplicateReactionConfig):
        """Create the standardizer, attaching the named actor when it exists.

        Falls back to local-only deduplication when Ray is not initialised or
        when no actor named ``duplicate_rxn_actor`` is registered.
        """
        dedup_actor = None
        if ray.is_initialized():
            try:
                dedup_actor = ray.get_actor("duplicate_rxn_actor")
            except ValueError:
                # Ray is running but nobody created the actor.
                logging.getLogger(__name__).warning(
                    "Ray actor 'duplicate_rxn_actor' not found; "
                    "duplicates are only detected locally"
                )
        return cls(dedup_actor)

    # ------------------------------------------------------------------
    def safe_reaction_smiles(self, reaction: ReactionContainer) -> str:
        """Return ``reactants>>products`` built from ``str()`` of each molecule.

        Reagents are not part of the string.
        """
        reactants_smi = ".".join(str(i) for i in reaction.reactants)
        products_smi = ".".join(str(i) for i in reaction.products)
        return f"{reactants_smi}>>{products_smi}"

    def _reaction_key(self, rxn: ReactionContainer) -> int:
        """Return a process-independent 64-bit integer key for the reaction.

        The builtin ``hash()`` of a ``str`` is salted per interpreter process
        unless ``PYTHONHASHSEED`` is fixed, so its values cannot be relied on
        to match between worker processes. A digest of the SMILES text can.
        """
        import hashlib

        digest = hashlib.blake2b(
            self.safe_reaction_smiles(rxn).encode("utf-8"), digest_size=8
        ).digest()
        return int.from_bytes(digest, "big")

    def _run(self, rxn: ReactionContainer) -> ReactionContainer:
        """Pass the reaction through if it was not seen before.

        Args:
            rxn: Input reaction

        Returns:
            The unchanged ``rxn`` when it is new

        Raises:
            StandardizationError: If the reaction key is already known locally
                or to the shared actor
        """
        h = self._reaction_key(rxn)

        # Local fast path: a key in `_local_seen` is certainly a duplicate.
        # Past this check the key is new to this instance, so without an
        # actor the reaction is accepted directly.
        is_new = h not in self._local_seen
        if is_new and self._actor is not None:
            # synchronous, returns True/False
            is_new = ray.get(self._actor.check_and_add.remote(h))

        if is_new:
            self._local_seen.add(h)
            return rxn

        raise StandardizationError(
            "DuplicateReaction", str(rxn), ValueError("Duplicate reaction found")
        )
800
+
801
+
802
@ray.remote
class DedupActor:
    """Ray actor holding the set of reaction hashes seen so far."""

    def __init__(self):
        self._seen: set[int] = set()

    def check_and_add(self, h: int) -> bool:
        """Register ``h`` and report whether it was new.

        Returns:
            ``True`` if ``h`` was absent (it is stored by this call),
            ``False`` if it had been stored before.
        """
        first_time = h not in self._seen
        if first_time:
            self._seen.add(h)
        return first_time
819
+
820
+
821
# Registry mapping config field names to standardizer classes.
# The keys match the field names of ReactionStandardizationConfig. The
# insertion order matters: ReactionStandardizationConfig.create_standardizers
# iterates this dict, so it defines the order in which standardizers are
# created (and therefore applied by standardize_reaction).
STANDARDIZER_REGISTRY = {
    "reaction_mapping_config": ReactionMappingStandardizer,
    "functional_groups_config": FunctionalGroupsStandardizer,
    "kekule_form_config": KekuleFormStandardizer,
    "check_valence_config": CheckValenceStandardizer,
    "implicify_hydrogens_config": ImplicifyHydrogensStandardizer,
    "check_isotopes_config": CheckIsotopesStandardizer,
    "split_ions_config": SplitIonsStandardizer,
    "aromatic_form_config": AromaticFormStandardizer,
    "mapping_fix_config": MappingFixStandardizer,
    "unchanged_parts_config": UnchangedPartsStandardizer,
    "small_molecules_config": SmallMoleculesStandardizer,
    "remove_reagents_config": RemoveReagentsStandardizer,
    "rebalance_reaction_config": RebalanceReactionStandardizer,
    "duplicate_reaction_config": DuplicateReactionStandardizer,
}
838
+
839
+
840
@dataclass
class ReactionStandardizationConfig(ConfigABC):
    """Configuration class for reaction standardization. Each field holds the
    configuration of one standardizer; a field left as ``None`` disables that
    standardizer.

    :param reaction_mapping_config: Configuration for reaction mapping.
    :param functional_groups_config: Configuration for functional groups
        standardization.
    :param kekule_form_config: Configuration for reactants/reagents/products
        kekulization.
    :param check_valence_config: Configuration for atom valence checking.
    :param implicify_hydrogens_config: Configuration for hydrogens removal.
    :param check_isotopes_config: Configuration for isotopes checking and cleaning.
    :param split_ions_config: Configuration for ion splitting.
    :param aromatic_form_config: Configuration for molecules aromatization.
    :param mapping_fix_config: Configuration for atom-to-atom mapping fixing.
    :param unchanged_parts_config: Configuration for removal of unchanged parts in
        reaction.
    :param small_molecules_config: Configuration for removal of small molecule from
        reaction.
    :param remove_reagents_config: Configuration for removal of reagents from reaction.
    :param rebalance_reaction_config: Configuration for reaction rebalancing.
    :param duplicate_reaction_config: Configuration for removal of duplicate reactions.
    """

    # configuration for reaction standardizers
    reaction_mapping_config: Optional[ReactionMappingConfig] = None
    functional_groups_config: Optional[FunctionalGroupsConfig] = None
    kekule_form_config: Optional[KekuleFormConfig] = None
    check_valence_config: Optional[CheckValenceConfig] = None
    implicify_hydrogens_config: Optional[ImplicifyHydrogensConfig] = None
    check_isotopes_config: Optional[CheckIsotopesConfig] = None
    split_ions_config: Optional[SplitIonsConfig] = None
    aromatic_form_config: Optional[AromaticFormConfig] = None
    mapping_fix_config: Optional[MappingFixConfig] = None
    unchanged_parts_config: Optional[UnchangedPartsConfig] = None
    small_molecules_config: Optional[SmallMoleculesConfig] = None
    remove_reagents_config: Optional[RemoveReagentsConfig] = None
    rebalance_reaction_config: Optional[RebalanceReactionConfig] = None
    duplicate_reaction_config: Optional[DuplicateReactionConfig] = None

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Delegate validation to every sub-config that defines ``_validate_params``.

        :param params: Mapping of field name to that sub-config's parameters;
            a missing or empty entry makes the sub-config validate its own
            current values.
        """
        for field_name, config in self.__dict__.items():
            if config is not None and hasattr(config, "_validate_params"):
                config._validate_params(params.get(field_name) or {})

    def to_dict(self):
        """Converts the configuration into a dictionary.

        Only enabled standardizers appear. Each value is a copy of the
        sub-config's attributes (an empty dict for parameterless configs), so
        settings such as ``mol_max_size`` are not lost.
        """
        config_dict = {}
        for field_name in STANDARDIZER_REGISTRY:
            config = getattr(self, field_name)
            if config is not None:
                config_dict[field_name] = dict(getattr(config, "__dict__", {}))
        return config_dict

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ReactionStandardizationConfig":
        """Create an instance of ReactionStandardizationConfig from a dictionary.

        A key that is present enables the corresponding standardizer. Its
        value may be a dict of parameters for the sub-config, an already built
        sub-config instance, or anything else (e.g. ``None`` from an empty
        YAML entry), in which case the sub-config defaults are used.
        """
        config_dict = config_dict or {}
        module_namespace = globals()
        config_kwargs = {}
        for field_name, std_cls in STANDARDIZER_REGISTRY.items():
            if field_name not in config_dict:
                continue
            # XStandardizer -> XConfig; look the class up by name in this
            # module (calling the name string itself raises TypeError).
            config_cls = module_namespace[
                std_cls.__name__.replace("Standardizer", "Config")
            ]
            params = config_dict[field_name]
            if isinstance(params, config_cls):
                config_kwargs[field_name] = params
            elif isinstance(params, dict):
                config_kwargs[field_name] = config_cls(**params)
            else:
                config_kwargs[field_name] = config_cls()
        return ReactionStandardizationConfig(**config_kwargs)

    @staticmethod
    def from_yaml(file_path: str) -> "ReactionStandardizationConfig":
        """Deserializes a YAML file into a ReactionStandardizationConfig object.

        An empty file yields a configuration with every standardizer disabled.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return ReactionStandardizationConfig.from_dict(config_dict or {})

    def create_standardizers(self):
        """Create standardizer instances based on configuration.

        :return: One standardizer per non-``None`` field, in the order of
            ``STANDARDIZER_REGISTRY``.
        """
        standardizers = []
        for field_name, std_cls in STANDARDIZER_REGISTRY.items():
            config = getattr(self, field_name)
            if config is not None:
                standardizers.append(std_cls.from_config(config))
        return standardizers
922
+
923
+
924
def standardize_reaction(
    reaction: ReactionContainer,
    standardizers: Sequence,
) -> ReactionContainer | None:
    """
    Run a reaction through the standardizers, one after another.

    Returns
    -------
    ReactionContainer | None
        - the output of the last standardizer, or
        - None as soon as one standardizer returns None (soft filter).

    Raises
    ------
    StandardizationError
        Logged once at WARNING level and then re-raised unchanged.
    """
    current = reaction
    for step in standardizers:
        step_name = step.__class__.__name__
        logger.debug(" › %s(%s)", step_name, current)
        try:
            outcome = step(current)  # may be None
        except StandardizationError as exc:
            # `current` is still the input of the failing step here.
            logger.warning(
                "%s failed on reaction %s : %s",
                step_name,
                current,
                exc,
            )
            raise
        if outcome is None:
            logger.info("%s filtered out reaction", step_name)
            return None
        current = outcome
    return current
960
+
961
+
962
def safe_standardize(
    item: str | ReactionContainer,
    standardizers: Sequence,
) -> Tuple[ReactionContainer, bool]:
    """
    Standardize one reaction without letting standardizer errors escape.

    Args:
        item: A reaction container, or a string that ``smiles_cgrtools`` can parse.
        standardizers: Standardizers applied in order by ``standardize_reaction``.

    Returns:
        ``(reaction, True)`` with the standardized reaction on success, or
        ``(original, False)`` when the reaction was filtered out or a
        standardizer raised.

    Raises:
        Exception: Whatever ``smiles_cgrtools`` raises for an unparsable
            string; without a parsed reaction there is nothing to return.
    """
    already_parsed = isinstance(item, ReactionContainer)
    # Parse only if needed. Kept outside the try block so that a parse failure
    # surfaces directly instead of being retried in the error handler.
    reaction = item if already_parsed else smiles_cgrtools(item)

    try:
        std = standardize_reaction(reaction, standardizers)
    except StandardizationError:
        # Already logged by standardize_reaction.
        pass
    except Exception:  # noqa: BLE001
        # Unexpected failure: keep going, but leave a trace in the log
        # instead of swallowing it silently.
        logger.warning(
            "Unexpected error while standardizing %s", item, exc_info=True
        )
    else:
        if std is None:
            return reaction, False  # filtered -> keep original
        return std, True

    # Failure path: keep the original container. A string is parsed again,
    # presumably because standardizers may have modified `reaction` in place
    # before the failure - confirm.
    if already_parsed:
        return item, False
    return smiles_cgrtools(item), False
983
+
984
+
985
def _process_batch(
    batch: Sequence[str | ReactionContainer],
    standardizers: Sequence,
) -> Tuple[List[ReactionContainer], int]:
    """Apply ``safe_standardize`` to every item of a batch.

    Returns:
        A list with one reaction per input item (items flagged as not
        standardized are included as well) and the number of items that were
        flagged as successfully standardized.
    """
    outcomes = [safe_standardize(entry, standardizers) for entry in batch]
    reactions: List[ReactionContainer] = [rxn for rxn, _ in outcomes]
    n_ok = sum(flag for _, flag in outcomes)
    return reactions, n_ok
996
+
997
+
998
@ray.remote
def process_batch_remote(
    batch: Sequence[str | ReactionContainer],
    std_param: ray.ObjectRef,
    log_file_path: str | Path | None = None,
) -> Tuple[List[ReactionContainer], int]:
    """Ray task wrapper around ``_process_batch``.

    Args:
        batch: Reactions (containers or strings) to standardize.
        std_param: Either an ``ObjectRef`` to the standardizer list or the
            list itself; a reference is resolved with ``ray.get``.
        log_file_path: When given, a ``FileHandler`` for this file is attached
            to the ``synplan.chem.data.standardizing`` logger of the worker
            (once per logger).

    Returns:
        The result of ``_process_batch(batch, standardizers)``.
    """
    standardizers = (
        ray.get(std_param) if isinstance(std_param, ray.ObjectRef) else std_param
    )

    # --- Worker-specific logging setup ---
    worker_logger = logging.getLogger("synplan.chem.data.standardizing")
    if log_file_path:
        log_file_path = Path(log_file_path)

        # Do not attach a second handler for the same file.
        already_attached = False
        for handler in worker_logger.handlers:
            if (
                isinstance(handler, logging.FileHandler)
                and Path(handler.baseFilename) == log_file_path
            ):
                already_attached = True
                break

        if not already_attached:
            try:
                file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
                file_handler.setFormatter(
                    logging.Formatter(
                        "%(asctime)s | %(name)s (worker) | %(levelname)-8s | %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                    )
                )
                file_handler.setLevel(logging.INFO)
                worker_logger.addHandler(file_handler)
                # The logger level must let records reach the handler.
                worker_logger.setLevel(logging.INFO)
                # Records stay in the file handler and are not passed upwards.
                worker_logger.propagate = False
            except Exception as e:
                # e.g. missing permissions: report and continue without file log
                logging.error(
                    f"Worker failed to create file handler {log_file_path}: {e}"
                )

    return _process_batch(batch, standardizers)
1046
+
1047
+
1048
def chunked(iterable: Iterable, size: int):
    """Yield consecutive lists of ``size`` items; the last list may be shorter."""
    bucket = []
    for element in iterable:
        bucket.append(element)
        if len(bucket) != size:
            continue
        yield bucket
        bucket = []  # start a fresh list, the yielded one belongs to the caller
    if bucket:
        yield bucket
1057
+
1058
+
1059
def standardize_reactions_from_file(
    config: "ReactionStandardizationConfig",
    input_reaction_data_path: str | Path,
    standardized_reaction_data_path: str | Path = "reaction_data_standardized.smi",
    *,
    num_cpus: int = 1,
    batch_size: int = 1_000,  # larger batches amortise overhead
    silent: bool = True,
    max_pending_factor: int = 4,  # tasks in flight = factor x CPUs
    worker_log_level: int | str = logging.WARNING,
    log_file_path: str | Path | None = None,
) -> None:
    """
    Reads reactions, standardises them (in parallel with Ray if requested),
    writes results.

    With ``num_cpus > 1`` the function keeps at most
    ``max_pending_factor * num_cpus`` (at least one) Ray tasks in flight and
    shares the standardizers with the workers through a single ``ray.put``.
    Results are written in the order in which tasks finish.

    Args:
        config: Configuration object for standardizers.
        input_reaction_data_path: Path to the input reaction data file.
        standardized_reaction_data_path: Path to save the standardized reactions.
        num_cpus: Number of CPU cores; values above 1 enable Ray.
        batch_size: Number of reactions to process in each batch.
        silent: If True, suppress the progress bar.
        max_pending_factor: Controls the number of pending Ray tasks.
        worker_log_level: Logging level passed to ``ray.init``.
        log_file_path: Path to the log file for workers to write to.
    """
    output_path = Path(standardized_reaction_data_path)
    parallel = num_cpus > 1

    # ----------------------- Ray initialisation -----------------------
    if parallel:
        if not ray.is_initialized():
            ray.init(
                num_cpus=num_cpus,
                ignore_reinit_error=True,
                logging_level=worker_log_level,
                log_to_driver=False,
            )

        dedup_name = "duplicate_rxn_actor"
        try:
            dedup_actor = ray.get_actor(dedup_name)  # already running?
        except ValueError:
            dedup_actor = DedupActor.options(
                name=dedup_name, lifetime="detached"  # survives driver exit
            ).remote()
        # NOTE(review): an actor that already existed keeps the hashes of
        # earlier runs - confirm that reuse across runs is intended.

    # Standardizers are created only now: DuplicateReactionStandardizer
    # looks the dedup actor up by name in from_config, which requires Ray to
    # be initialised and the actor to exist.
    standardizers = config.create_standardizers()

    logger.info(
        "Standardizers: %s",
        ", ".join(s.__class__.__name__ for s in standardizers),
    )

    # Broadcast once instead of pickling the list for every task.
    std_ref = ray.put(standardizers) if parallel else None

    # At least one task must be allowed, otherwise the back-pressure loop
    # below would wait on an empty task list.
    max_pending = max(1, max_pending_factor * num_cpus)
    pending: Dict[ray.ObjectRef, None] = {}

    n_processed = n_std = 0
    bar = tqdm(
        total=0,
        unit="rxn",
        desc="Standardising",
        disable=silent,
        dynamic_ncols=True,
    )

    try:
        # ----------------------------- I/O -------------------------------
        with ReactionReader(input_reaction_data_path) as reader, ReactionWriter(
            output_path
        ) as writer:

            def _record(reactions, ok) -> None:
                """Write one batch result and update counters and progress bar."""
                nonlocal n_processed, n_std
                for rxn in reactions:
                    writer.write(rxn)
                bar.update(len(reactions))
                n_processed += len(reactions)
                n_std += ok

            def _flush_one() -> None:
                """Wait for any pending task, record its result, forget it."""
                done, _ = ray.wait(list(pending), num_returns=1)
                ref = done[0]
                _record(*ray.get(ref))
                pending.pop(ref, None)

            # --------------------- Main read/compute loop -----------------
            for chunk in chunked(reader, batch_size):
                bar.total += len(chunk)
                bar.refresh()

                if parallel:
                    # back-pressure: keep at most max_pending tasks in flight
                    while len(pending) >= max_pending:
                        _flush_one()
                    ref = process_batch_remote.remote(chunk, std_ref, log_file_path)
                    pending[ref] = None
                else:
                    # serial fall-back
                    _record(*_process_batch(chunk, standardizers))

            # ------------------ Drain remaining Ray tasks -----------------
            while pending:
                _flush_one()
    finally:
        # Release the progress bar and Ray even when reading, writing or a
        # task failed.
        bar.close()
        ray.shutdown()

    logger.info(
        "Finished: processed %d, standardised %d, filtered %d",
        n_processed,
        n_std,
        n_processed - n_std,
    )
synplan/chem/precursor.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class Precursor that represents a precursor (extend molecule object) in
2
+ the search tree."""
3
+
4
+ from typing import Set
5
+
6
+ from CGRtools.containers import MoleculeContainer
7
+
8
+ from synplan.chem.utils import safe_canonicalization
9
+
10
+
11
class Precursor:
    """Precursor class is used to extend the molecule behavior needed for interaction with
    a tree in MCTS.

    Length, hashing, equality and string conversion are all delegated to the wrapped
    molecule, so two Precursor objects wrapping equal molecules are interchangeable as
    set members and dictionary keys.
    """

    def __init__(self, molecule: MoleculeContainer, canonicalize: bool = True):
        """It initializes a Precursor object with a molecule container as a parameter.

        :param molecule: A molecule.
        :param canonicalize: If True, the molecule is passed through
            ``safe_canonicalization`` before it is stored; if False it is stored as is.
        """
        self.molecule = safe_canonicalization(molecule) if canonicalize else molecule
        self.prev_precursors = []

    def __len__(self) -> int:
        """Return the length of the wrapped molecule (its number of atoms)."""
        return len(self.molecule)

    def __hash__(self) -> int:
        """Return the hash of the wrapped molecule."""
        return hash(self.molecule)

    def __str__(self) -> str:
        """Return the string form of the wrapped molecule."""
        return str(self.molecule)

    def __eq__(self, other: object) -> bool:
        """Check whether two Precursor objects wrap equal molecules.

        Comparison with an object that is not a Precursor returns ``NotImplemented``
        so that Python falls back to its default handling (``==`` evaluates to False)
        instead of failing with an AttributeError on ``other.molecule``.
        """
        if not isinstance(other, Precursor):
            return NotImplemented
        return self.molecule == other.molecule

    def __repr__(self) -> str:
        """Return the string form of the wrapped molecule."""
        return str(self.molecule)

    def is_building_block(self, bb_stock: Set[str], min_mol_size: int = 6) -> bool:
        """Checks if a Precursor is a building block.

        :param bb_stock: The collection of building blocks. Each building block is
            represented by a string that is compared with ``str(self.molecule)``.
        :param min_mol_size: If the size of the Precursor is equal or smaller than
            min_mol_size it is automatically classified as building block.
        :return: True if Precursor is a building block.
        """
        if len(self.molecule) <= min_mol_size:
            return True

        return str(self.molecule) in bb_stock
56
+
57
+
58
def compose_precursors(
    precursors: list = None, exclude_small: bool = True, min_mol_size: int = 6
) -> MoleculeContainer:
    """
    Takes a list of precursors, excludes small precursors if specified, and composes them
    into a single molecule. The composed molecule is then used to predict the
    synthesisability, which characterizes the possible success of the route including
    the nodes with the given precursors.

    :param precursors: The list of precursors to be composed. Must contain at least
        one precursor.
    :param exclude_small: Determines whether small precursors are excluded from the
        composition. If True, only precursors whose molecule is larger than
        min_mol_size are composed; if no precursor is that large, all of them are kept.
    :param min_mol_size: The size threshold used with exclude_small.

    :return: A composed precursor as a MoleculeContainer object. When a single
        precursor is passed, its own molecule object is returned (not a copy).
    :raises ValueError: If precursors is None or empty.
    """
    # Previously None failed on len() and an empty list hit an unbound local;
    # fail early with an explicit message instead.
    if not precursors:
        raise ValueError("compose_precursors() requires at least one precursor")

    if len(precursors) == 1:
        return precursors[0].molecule

    if exclude_small:
        big_precursors = [
            precursor
            for precursor in precursors
            if len(precursor.molecule) > min_mol_size
        ]
        if big_precursors:
            precursors = big_precursors

    tmp_mol = precursors[0].molecule.copy()
    for precursor in precursors[1:]:
        # Atom numbers are local to each precursor, so the old -> new number map
        # is rebuilt for every precursor that gets merged in.
        transition_mapping = {}
        for n, atom in precursor.molecule.atoms():
            transition_mapping[n] = tmp_mol.add_atom(atom.atomic_symbol)
        for atom, neighbor, bond in precursor.molecule.bonds():
            tmp_mol.add_bond(
                transition_mapping[atom], transition_mapping[neighbor], bond
            )

    return tmp_mol
synplan/chem/reaction.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions for manipulating reactions and reaction
2
+ rules."""
3
+
4
+ from typing import Any, Iterator, List, Optional
5
+
6
+ from CGRtools.containers import MoleculeContainer, ReactionContainer
7
+ from CGRtools.exceptions import InvalidAromaticRing
8
+ from CGRtools.reactor import Reactor
9
+
10
+
11
class Reaction(ReactionContainer):
    """General-purpose reaction representation built on top of ReactionContainer."""

    def __init__(self, *args, **kwargs):
        # No additional state: all arguments are forwarded to ReactionContainer.
        super().__init__(*args, **kwargs)
16
+
17
+
18
def add_small_mols(
    big_mol: MoleculeContainer, small_molecules: Optional[Any] = None
) -> List[MoleculeContainer]:
    """Merge small molecules into a copy of a molecule and split the result.

    :param big_mol: A molecule.
    :param small_molecules: An iterable of small molecules to be added to the
        molecule. Any falsy value (None, False, empty list) means nothing is added.
    :return: ``[big_mol]`` when there is nothing to add, otherwise the result of
        ``split()`` on the merged copy.
    """
    if not small_molecules:
        return [big_mol]

    merged = big_mol.copy()
    for fragment in small_molecules:
        # Atom numbers are local to each fragment, so the map from old to new
        # numbers is built from scratch for every fragment.
        renumbering = {}
        for old_number, atom in fragment.atoms():
            renumbering[old_number] = merged.add_atom(atom.atomic_symbol)

        for first, second, bond in fragment.bonds():
            merged.add_bond(renumbering[first], renumbering[second], bond)

    return merged.split()
47
+
48
+
49
def apply_reaction_rule(
    molecule: MoleculeContainer,
    reaction_rule: Reactor,
    sort_reactions: bool = False,
    top_reactions_num: int = 3,
    validate_products: bool = True,
    rebuild_with_cgr: bool = False,
) -> Iterator[Optional[List[MoleculeContainer]]]:
    """Applies a reaction rule to a given molecule.

    :param molecule: A molecule to which reaction rule will be applied.
    :param reaction_rule: A reaction rule to be applied.
    :param sort_reactions: If True, all reactions produced by the rule are collected
        and sorted by the number of products with more than 6 atoms (descending)
        before the top ones are taken; if False, the first reactions produced by the
        rule are taken as they come.
    :param top_reactions_num: The maximum amount of reactions after the application of
        reaction rule.
    :param validate_products: If True, validates the final products.
    :param rebuild_with_cgr: If True, the products are extracted from CGR decomposition.
    :return: An iterator yielding, for every considered reaction, either the list of
        product molecules or None when the products failed validation.
    """

    reactants = add_small_mols(molecule, small_molecules=False)

    try:
        if sort_reactions:
            sorted_reactions = sorted(
                reaction_rule(reactants),
                key=lambda react: sum(1 for mol in react.products if len(mol) > 6),
                reverse=True,
            )

            # take top-N reactions from reactor
            reactions = sorted_reactions[:top_reactions_num]
        else:
            reactions = []
            for reaction in reaction_rule(reactants):
                reactions.append(reaction)
                if len(reactions) == top_reactions_num:
                    break
    except IndexError:
        # the reactor failed on this molecule: treat it as "no reactions"
        reactions = []

    for reaction in reactions:

        # temporary solution - incorrect leaving groups
        reactant_atom_nums = set()
        for mol in reaction.reactants:
            reactant_atom_nums.update(mol.atoms_numbers)
        product_atom_nums = []
        for mol in reaction.products:
            product_atom_nums.extend(mol.atoms_numbers)
        leaving_atom_nums = reactant_atom_nums - set(product_atom_nums)
        if len(leaving_atom_nums) > len(product_atom_nums):
            continue

        # check reaction
        if rebuild_with_cgr:
            cgr = reaction.compose()
            products = cgr.decompose()[1].split()
        else:
            products = reaction.products  # reactants are products in retro reaction
        products = [mol for mol in products if len(mol) > 0]

        # validate products
        if validate_products:
            is_valid = True
            for mol in products:
                try:
                    mol.kekule()
                    if mol.check_valence():
                        is_valid = False
                        break
                    mol.thiele()
                except InvalidAromaticRing:
                    is_valid = False
                    break
            if not is_valid:
                # Signal the failure once and do NOT fall through: previously the
                # invalid product list was still yielded right after the None.
                yield None
                continue

        yield products
synplan/chem/reaction_routes/__init__.py ADDED
File without changes
synplan/chem/reaction_routes/clustering.py ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
+ from pathlib import Path
4
+ import pickle
5
+ import re
6
+
7
+ from CGRtools.containers import ReactionContainer, CGRContainer
8
+ from CGRtools.containers.bonds import DynamicBond
9
+
10
+ from synplan.chem.reaction_routes.leaving_groups import *
11
+ from synplan.chem.reaction_routes.visualisation import *
12
+ from synplan.chem.reaction_routes.route_cgr import *
13
+ from synplan.chem.reaction_routes.io import (
14
+ read_routes_csv,
15
+ read_routes_json,
16
+ make_dict,
17
+ make_json,
18
+ )
19
+ from synplan.utils.visualisation import (
20
+ routes_clustering_report,
21
+ routes_subclustering_report,
22
+ )
23
+
24
+
25
def run_cluster_cli(
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool = False,
    subcluster_results_dir: Path = None,
):
    """
    Cluster the routes stored in a CSV or JSON file and write the results to disk.

    The numeric index found in the file name (pattern ``_<digits>.``) is reused in
    the names of all output files. A pickle with the clusters and one HTML report
    per cluster are written; subclustering reports are written only on request.

    Args:
        routes_file: Path to the input routes file (.csv or .json).
        cluster_results_dir: Directory where clustering results are stored.
        perform_subcluster: Whether to run subclustering on each cluster.
        subcluster_results_dir: Subdirectory (inside cluster_results_dir) for
            subclustering results (if enabled).

    Raises:
        ValueError: If no index can be extracted from the file name or the file
            extension is neither .csv nor .json.
    """
    import click

    routes_file = Path(routes_file)
    index_match = re.search(r"_(\d+)\.", routes_file.name)
    if index_match is None:
        raise ValueError(f"Could not extract index from filename: {routes_file.name}")
    file_index = int(index_match.group(1))

    suffix = routes_file.suffix.lower()
    if suffix == ".csv":
        routes_dict = read_routes_csv(str(routes_file))
        routes_json = make_json(routes_dict)
    elif suffix == ".json":
        routes_json = read_routes_json(str(routes_file))
        routes_dict = make_dict(routes_json)
    else:
        raise ValueError(f"Unsupported file type: {suffix}")

    # Compose condensed graph representations
    route_cgrs = compose_all_route_cgrs(routes_dict)
    click.echo("Generating RouteCGR")
    reduced_cgrs = compose_all_reduced_route_cgrs(route_cgrs)
    click.echo("Generating ReducedRouteCGR")

    # Perform clustering
    click.echo("\nClustering")
    clusters = cluster_routes(reduced_cgrs, use_strat=False)

    click.echo(f"Total number of routes: {len(routes_dict)}")
    click.echo(f"Found number of clusters: {len(clusters)} ({list(clusters.keys())})")

    # Ensure output directory exists
    cluster_results_dir = Path(cluster_results_dir)
    cluster_results_dir.mkdir(parents=True, exist_ok=True)

    # Save clusters to pickle
    pickle_path = cluster_results_dir / f"clusters_{file_index}.pickle"
    with open(pickle_path, "wb") as pickle_file:
        pickle.dump(clusters, pickle_file)

    # Generate HTML reports for each cluster
    for cluster_idx in clusters:
        report_path = cluster_results_dir / f"{file_index}_cluster_{cluster_idx}.html"
        routes_clustering_report(
            routes_json, clusters, cluster_idx, reduced_cgrs, html_path=str(report_path)
        )

    # Optional subclustering (Under development)
    if not (perform_subcluster and subcluster_results_dir):
        return

    click.echo("\nSubClustering")
    sub_dir = cluster_results_dir / subcluster_results_dir
    sub_dir.mkdir(parents=True, exist_ok=True)

    subclusters = subcluster_all_clusters(clusters, reduced_cgrs, route_cgrs)
    for cluster_idx, sub in subclusters.items():
        click.echo(f"Cluster {cluster_idx} has {len(sub)} subclusters")
        for sub_idx, subcluster in sub.items():
            subreport_path = (
                sub_dir / f"{file_index}_subcluster_{cluster_idx}.{sub_idx}.html"
            )
            routes_subclustering_report(
                routes_json,
                subcluster,
                cluster_idx,
                sub_idx,
                reduced_cgrs,
                aam=False,
                html_path=str(subreport_path),
            )
107
+
108
+
109
def cluster_route_from_csv(routes_file: str):
    """
    Load retrosynthetic routes from a CSV file and cluster them.

    The routes are read with `read_routes_csv`, turned into RouteCGRs with
    `compose_all_route_cgrs`, reduced with `compose_all_reduced_route_cgrs`, and
    finally passed to `cluster_routes` with `use_strat=False`.

    Args:
        routes_file (str): The path to the CSV file containing the retrosynthetic
                           route data.

    Returns:
        object: Whatever `cluster_routes` returns for the reduced RouteCGRs of the
                routes in the file.
    """
    reduced_route_cgrs_dict = compose_all_reduced_route_cgrs(
        compose_all_route_cgrs(read_routes_csv(routes_file))
    )
    return cluster_routes(reduced_route_cgrs_dict, use_strat=False)
135
+
136
+
137
def cluster_route_from_json(routes_file: str):
    """
    Load retrosynthetic routes from a JSON file and cluster them.

    The JSON is read with `read_routes_json` and converted with `make_dict`; the
    routes are then turned into RouteCGRs with `compose_all_route_cgrs`, reduced
    with `compose_all_reduced_route_cgrs`, and passed to `cluster_routes` with
    `use_strat=False`.

    Args:
        routes_file (str): The path to the JSON file containing the retrosynthetic
                           route data.

    Returns:
        object: Whatever `cluster_routes` returns for the reduced RouteCGRs of the
                routes in the file.
    """
    routes_dict = make_dict(read_routes_json(routes_file))
    reduced_route_cgrs_dict = compose_all_reduced_route_cgrs(
        compose_all_route_cgrs(routes_dict)
    )
    return cluster_routes(reduced_route_cgrs_dict, use_strat=False)
163
+
164
+
165
def extract_strat_bonds(target_cgr: CGRContainer):
    """
    Collect the strategic bonds of a CGRContainer.

    A bond counts as strategic when its `order` is None while its `p_order` is
    not None. The adjacency mapping stores every bond under both of its atoms;
    only the entry whose first atom index is the smaller one is kept, so each
    bond is reported exactly once.

    Args:
        target_cgr (CGRContainer): The CGRContainer object from which to extract
                                   strategic bonds.

    Returns:
        list: A sorted list of `(smaller_atom_index, larger_atom_index)` tuples,
              one per strategic bond.
    """
    strat_bonds = [
        (atom, neighbor)
        for atom, neighbors in target_cgr._bonds.items()
        for neighbor, bond in neighbors.items()
        if atom < neighbor and bond.order is None and bond.p_order is not None
    ]
    strat_bonds.sort()
    return strat_bonds
197
+
198
+
199
def cluster_routes(r_route_cgrs: dict, use_strat=False):
    """
    Cluster routes by their strategic bonds or by the string signature of their
    ReducedRouteCGR.

    Args:
        r_route_cgrs: Dictionary mapping node_id to r_route_cgr objects.
        use_strat: If truthy, routes are grouped by the tuple of their strategic
            bonds (see `extract_strat_bonds`); otherwise they are grouped by
            `str(r_route_cgr)`.

    Returns:
        Dictionary with groups keyed by '{length}.{index}', where length is the
        number of strategic bonds of the group representative and index is a
        1-based counter per length. Each group contains 'node_ids' (sorted),
        'r_route_cgr' (the first CGR seen for the group), 'strat_bonds' and
        'group_size'.
    """
    groups = {}

    # 1. Initial grouping; the first route seen for a key is its representative.
    for node_id, r_route_cgr in r_route_cgrs.items():
        strat_bonds_list = extract_strat_bonds(r_route_cgr)
        group_key = tuple(strat_bonds_list) if use_strat else str(r_route_cgr)

        group = groups.get(group_key)
        if group is None:
            group = groups[group_key] = {
                "node_ids": [],
                "r_route_cgr": r_route_cgr,
                "strat_bonds": strat_bonds_list,
            }
        group["node_ids"].append(node_id)

    # Sort each node list once at the end instead of after every insertion.
    for group in groups.values():
        group["node_ids"].sort()
        group["group_size"] = len(group["node_ids"])

    # 2. Format the output dictionary with desired keys '{length}.{index}'.
    # Sorting by (len(key), key) gives a deterministic order of the groups.
    final_grouped_results = {}
    group_indices = defaultdict(int)  # running 1-based index per bond count
    for group_key in sorted(groups, key=lambda key: (len(key), key)):
        group_data = groups[group_key]
        num_bonds = len(group_data["strat_bonds"])
        group_indices[num_bonds] += 1
        final_grouped_results[f"{num_bonds}.{group_indices[num_bonds]}"] = group_data

    return final_grouped_results
256
+
257
+
258
def lg_process_reset(lg_cgr: CGRContainer, atom_num: int):
    """
    Normalize the bonds of an extracted leaving-group (X) fragment and flag its
    attachment atom as a radical.

    Every bond whose `p_order` is None while its `order` is set is deleted and
    re-added as a `DynamicBond` with `int(order)` used for both orders. The atom
    `atom_num` then gets `is_radical = True`. The container is modified in place.

    Parameters
    ----------
    lg_cgr : CGRContainer
        The CGR representing the isolated leaving-group fragment.
    atom_num : int
        Index of the attachment atom to mark as a radical.

    Returns
    -------
    CGRContainer
        The same `lg_cgr` object, with normalized bonds and the specified atom
        flagged as a radical.
    """
    # Iterate over snapshots: bonds are deleted and re-added inside the loop.
    for atom, neighbors in list(lg_cgr._bonds.items()):
        for neighbor, bond in list(neighbors.items()):
            if bond.p_order is not None or bond.order is None:
                continue
            fixed_order = int(bond.order)
            lg_cgr.delete_bond(atom, neighbor)
            lg_cgr.add_bond(atom, neighbor, DynamicBond(fixed_order, fixed_order))

    lg_cgr._atoms[atom_num].is_radical = True
    return lg_cgr
289
+
290
+
291
def lg_replacer(route_cgr: CGRContainer):
    """
    Extract dynamic leaving-groups from a CGR and mark attachment points.

    Scans the first connected component of the input CGRContainer for bonds that have
    an `order` but no `p_order` (i.e., leaving-group attachments), severs those bonds,
    captures each leaving-group as its own CGRContainer, and inserts DynamicX markers
    at the attachment sites. Finally, reindexes the markers following the order of the
    reactants of the resulting synthon CGR.

    Parameters
    ----------
    route_cgr : CGRContainer
        A CGR representing the full synthetic route.

    Returns
    -------
    synthon_cgr : CGRContainer
        The core synthon CGR with DynamicX atoms marking each former leaving-group site.
    lg_groups : dict[int, tuple[CGRContainer, int]]
        Mapping from each (reindexed) marker label to a tuple of:
        - the extracted leaving-group CGRContainer
        - the atom index where it was attached.
    """
    lg_groups = {}

    # Only the first connected component of the route CGR is processed.
    cgr_prods = [route_cgr.substructure(c) for c in route_cgr.connected_components]
    target_cgr = cgr_prods[0]

    # Snapshot of the bonds: target_cgr is mutated and reassigned inside the loop.
    bond_items = list(target_cgr._bonds.items())
    reaction = ReactionContainer.from_cgr(target_cgr)
    target_mol = reaction.products[0]
    # Largest atom number of the first product; only bonds starting from atoms with
    # a number up to this value are considered below.
    max_in_target_mol = max(target_mol._atoms)

    k = 1  # provisional marker label, incremented per extracted leaving group
    atom_nums = []  # attachment atom numbers, in extraction order

    for atom1, bond_set in bond_items:
        bond_set_items = list(bond_set.items())
        for atom2, bond in bond_set_items:
            if bond.p_order is None and bond.order is not None:
                if atom1 <= max_in_target_mol:
                    lg = DynamicX()
                    lg.mark = k
                    lg.isotope = k
                    order = bond.order
                    p_order = bond.p_order
                    target_cgr.delete_bond(atom1, atom2)
                    lg_cgrs = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ]
                    if len(lg_cgrs) == 2:
                        # The second component is taken as the leaving group.
                        lg_cgr = lg_cgrs[1]
                        lg_cgr = lg_process_reset(lg_cgr, atom2)
                        lg_cgr.clean2d()
                    else:
                        # NOTE(review): the bond deleted above is not restored on
                        # this path - confirm this is intended.
                        continue
                    lg_groups[k] = (lg_cgr, atom2)
                    # Keep the first component and put a DynamicX placeholder at
                    # the number of the removed attachment atom.
                    target_cgr = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ][0]
                    target_cgr.add_atom(lg, atom2)
                    # p_order is always None here; order 4 presumably denotes an
                    # aromatic bond and is downgraded to a single bond - TODO confirm.
                    if order == 4 and p_order == None:
                        order = 1
                    target_cgr.add_bond(atom1, atom2, DynamicBond(order, p_order))
                    target_cgr = [
                        target_cgr.substructure(c)
                        for c in target_cgr.connected_components
                    ][0]
                    k += 1
                    atom_nums.append(atom2)

    synthon_cgr = [target_cgr.substructure(c) for c in target_cgr.connected_components][
        0
    ]
    reaction = ReactionContainer.from_cgr(synthon_cgr)
    reactants = reaction.reactants

    # Renumber the marks so that they follow the order of the reactants.
    atom_mark_map = {}  # To map atom numbers to their new marks
    g = 1
    for n, r in enumerate(reactants):
        for atom_num in atom_nums:
            if atom_num in r._atoms:
                synthon_cgr._atoms[atom_num].mark = g
                atom_mark_map[atom_num] = g
                g += 1

    # Re-key the leaving groups by their new marks; groups whose attachment atom
    # is not found in any reactant are dropped.
    new_lg_groups = {}
    for original_mark in lg_groups:
        cgr_obj, a_num = lg_groups[original_mark]
        new_mark = atom_mark_map.get(a_num)
        if new_mark is not None:
            new_lg_groups[new_mark] = (cgr_obj, a_num)
    lg_groups = new_lg_groups

    return synthon_cgr, lg_groups
387
+
388
+
389
def lg_reaction_replacer(
    synthon_reaction: ReactionContainer, lg_groups: dict, max_in_target_mol: int
):
    """
    Put marked leaving-group (X) atoms into the synthon reactants.

    In every reactant of `synthon_reaction`, each atom whose number is greater
    than `max_in_target_mol` and equals the attachment index of an entry in
    `lg_groups` is deleted and re-added as a `MarkedAt` atom whose `mark` and
    `isotope` are the entry's key. The bond to its first listed neighbour is
    re-created. Reactants are modified in place.

    Parameters
    ----------
    synthon_reaction : ReactionContainer
        Reaction containing reactants with X placeholders.
    lg_groups : dict[int, tuple[CGRContainer, int]]
        Mapping from X label to (X CGR, attachment atom index).
    max_in_target_mol : int
        Highest atom index of the core product; any atom_num above this is a placeholder.

    Returns
    -------
    list
        The reactants of `synthon_reaction` (same objects, same order) after the
        replacement.
    """
    processed = []
    for reactant in synthon_reaction.reactants:
        # Snapshot of the atom numbers: atoms are deleted and re-added below.
        for atom_num in list(reactant._atoms.keys()):
            if atom_num <= max_in_target_mol:
                continue
            for label, entry in lg_groups.items():
                marked_atom = MarkedAt()
                if atom_num != entry[1]:
                    continue
                marked_atom.mark = label
                marked_atom.isotope = label
                neighbor = list(reactant._bonds[atom_num].keys())[0]
                bond = reactant._bonds[atom_num][neighbor]
                reactant.delete_bond(neighbor, atom_num)
                reactant.delete_atom(atom_num)
                reactant.add_atom(marked_atom, atom_num)
                reactant.add_bond(neighbor, atom_num, bond)
        processed.append(reactant)
    return processed
432
+
433
+
434
class SubclusterError(Exception):
    """Error signalling that subcluster_one_cluster could not finish its work."""
436
+
437
+
438
def subcluster_one_cluster(group, r_route_cgrs_dict, route_cgrs_dict):
    """
    Generate synthon data for each route in a single cluster.

    For each route (node ID) in `group['node_ids']`, replaces the leaving groups of
    its RouteCGR to obtain a SynthonCGR, builds ReactionContainers before and after
    X replacement, and collects relevant data.

    Parameters
    ----------
    group : dict
        Must include `'node_ids'`, a list or tuple of node identifiers.
    r_route_cgrs_dict : dict
        Maps node IDs to their ReducedRouteCGR.
    route_cgrs_dict : dict
        Maps node IDs to their RouteCGR.

    Returns
    -------
    dict
        Maps each `node_id` to a tuple:
        `(r_route_cgr, original_reaction, synthon_cgr, new_reaction, lg_groups)`.

    Raises
    ------
    SubclusterError
        If `'node_ids'` is not a list or tuple, or if any step (X replacement or
        reaction parsing) fails for a node. The original exception is attached as
        the cause.
    """

    node_ids = group.get("node_ids")
    if not isinstance(node_ids, (list, tuple)):
        raise SubclusterError(
            f"'node_ids' must be a list or tuple, got {type(node_ids).__name__}"
        )

    result = {}
    for node_id in node_ids:
        r_route_cgr = r_route_cgrs_dict[node_id]
        route_cgr = route_cgrs_dict[node_id]

        # 1) Replace leaving groups in RouteCGR
        try:
            synthon_cgr, lg_groups = lg_replacer(route_cgr)
        except (KeyError, ValueError) as e:
            raise SubclusterError(f"LG replacement failed for node {node_id}") from e

        # 2) Build ReactionContainer for Abstracted RouteCGR
        try:
            synthon_rxn = ReactionContainer.from_cgr(synthon_cgr)
        except Exception as e:
            # The exception must be bound here: the name from the previous
            # handler no longer exists, so "from e" used to raise NameError.
            raise SubclusterError(
                f"Failed to parse synthon CGR for node {node_id}"
            ) from e

        # 3) Prepare for X-based reaction replacement
        try:
            old_reactants = synthon_rxn.reactants
            target_mol = synthon_rxn.products[0]
            max_atom_idx = max(target_mol._atoms)
            new_reactants = lg_reaction_replacer(synthon_rxn, lg_groups, max_atom_idx)
            new_rxn = ReactionContainer(reactants=new_reactants, products=[target_mol])
        except (IndexError, TypeError) as e:
            raise SubclusterError(
                f"Leaving group (X) reaction replacement failed for node {node_id}"
            ) from e

        result[node_id] = (
            r_route_cgr,
            ReactionContainer(reactants=old_reactants, products=[target_mol]),
            synthon_cgr,
            new_rxn,
            lg_groups,
        )

    return result
511
+
512
+
513
def group_nodes_by_synthon_detail(data_dict: dict):
    """
    Groups nodes whose first four result elements are equal.

    The four elements are interpreted as (r_route_cgr, unlabeled_reaction,
    synthon_cgr, synthon_reaction). The fifth element of each result, if present,
    is the node-specific data stored under 'nodes_data'.

    Args:
        data_dict: Dictionary {node_id: (r_route_cgr, unlabeled_reaction,
            synthon_cgr, synthon_reaction, node_data)}. A result with fewer than
            four elements is grouped by its first element only. A node whose
            grouping elements are not hashable is skipped with a printed warning.

    Returns:
        Dictionary {group_index: {'r_route_cgr': ..., 'unlabeled_reaction': ...,
        'synthon_cgr': ..., 'synthon_reaction': ...,
        'nodes_data': {node_id1: node_data1, ...}, 'post_processed': False}},
        with 1-based group indices assigned in the order of the groups' node ID lists.
    """
    temp_groups = defaultdict(list)

    # 1. Group node IDs by the four shared elements of their results.
    for node_id, result_list in data_dict.items():
        if len(result_list) < 4:
            # Missing reaction data: group by the first element only, padded so
            # that the key still unpacks into four values below.
            group_key = (result_list[0], None, None, None)
        else:
            group_key = tuple(result_list[:4])

        try:
            # Hashing happens on the dictionary lookup, so it must be guarded here.
            temp_groups[group_key].append(node_id)
        except TypeError:
            print(
                f"Warning: Skipping node {node_id} because its grouping data is not hashable"
            )
            continue

    # 2. Format the output dictionary with sequential integer keys
    #    and include the node-specific data in a sub-dictionary.
    final_grouped_results = {}
    sorted_temp_groups = sorted(temp_groups.items(), key=lambda item: item[1])
    for group_index, (group_key, node_ids) in enumerate(sorted_temp_groups, start=1):
        r_route_cgr, unlabeled_reaction, synthon_cgr, synthon_reaction = group_key

        nodes_data_dict = {}
        for node_id in sorted(node_ids):  # Sort node IDs for consistent dict order
            original_result = data_dict[node_id]
            # The fifth element is the node-specific data; None when it is missing.
            nodes_data_dict[node_id] = (
                original_result[4] if len(original_result) > 4 else None
            )

        final_grouped_results[group_index] = {
            "r_route_cgr": r_route_cgr,
            "unlabeled_reaction": unlabeled_reaction,
            "synthon_cgr": synthon_cgr,
            "synthon_reaction": synthon_reaction,
            "nodes_data": nodes_data_dict,
            "post_processed": False,
        }

    return final_grouped_results
579
+
580
+
581
def subcluster_all_clusters(groups, r_route_cgrs_dict, route_cgrs_dict):
    """
    Split every cluster into synthon-based subgroups.

    Each cluster in `groups` is passed to `subcluster_one_cluster`; the resulting
    per-node synthon data is then grouped with `group_nodes_by_synthon_detail`.

    Parameters
    ----------
    groups : dict
        Mapping of cluster indices to cluster data.
    r_route_cgrs_dict : dict
        Dictionary of ReducedRouteCGRs
    route_cgrs_dict : dict
        Dictionary of RouteCGRs

    Returns
    -------
    dict or None
        A dict mapping each cluster index to its subgroups dict, or None as soon
        as `subcluster_one_cluster` returns None for a cluster.
    """
    subgroups_by_cluster = {}
    for cluster_index in groups:
        synthons = subcluster_one_cluster(
            groups[cluster_index], r_route_cgrs_dict, route_cgrs_dict
        )
        if synthons is None:
            return None
        subgroups_by_cluster[cluster_index] = group_nodes_by_synthon_detail(synthons)
    return subgroups_by_cluster
612
+
613
+
614
def all_lg_collect(subgroup):
    """
    Collect the distinct leaving-group containers seen at each node index.

    Walks every per-route table in `subgroup['nodes_data']` and, for each
    node index, keeps the first element of the stored tuple. Two containers
    count as the same when their `str()` forms are equal; only the first
    occurrence is kept.

    Parameters
    ----------
    subgroup : dict
        Must contain 'nodes_data', a dict mapping pathway keys to
        dicts of {node_index: (CGRContainer, …)}.

    Returns
    -------
    dict[int, list[CGRContainer]]
        For each node index, the unique containers in first-seen order.
    """
    nodes_data = subgroup["nodes_data"]

    # Gather every index first so each one gets an entry up front.
    indices = set()
    for per_route in nodes_data.values():
        indices.update(per_route.keys())

    # Per index: str(container) -> container. A dict keeps insertion order,
    # and setdefault keeps the first container seen for a given string.
    unique_by_index = {index: {} for index in indices}
    for per_route in nodes_data.values():
        for index, entry in per_route.items():
            container = entry[0]
            unique_by_index[index].setdefault(str(container), container)

    return {index: list(found.values()) for index, found in unique_by_index.items()}
650
+
651
+
652
def replace_leaving_groups_in_synthon(subgroup, to_remove):  # Under development
    """
    Replace specified leaving groups (LG) in a synthon CGR with new fragments and return the updated CGR
    along with a mapping from adjusted LG marks to their atom indices.

    Only atoms whose class is named "DynamicX" are considered. An X atom whose
    mark is listed in `to_remove` is deleted and the fragment stored for that
    mark in the first entry of `subgroup['nodes_data']` is merged in; every
    other X atom has its mark lowered by the number of X atoms replaced so far.

    Parameters:
        subgroup (dict): Must contain:
            - 'synthon_cgr': the CGR object representing the synthon graph
            - 'nodes_data': mapping of node indices to LG replacement data
        to_remove (List[int]): List of LG marks to remove and replace.

    Returns:
        Tuple[CGR, Dict[int, int]]:
            - The updated CGR with replacements
            - A dict mapping new LG marks to their atom indices in the updated CGR
    """
    # Extract the original CGR and leaving group replacement table
    original_cgr = subgroup["synthon_cgr"]
    # Only the first route's table is consulted for replacement fragments.
    lg_table = next(iter(subgroup["nodes_data"].values()))

    # NOTE(review): this is an alias, not a copy — delete_bond/delete_atom below
    # are called on the very object stored in subgroup['synthon_cgr'] until the
    # first union() rebinds the name. Confirm that mutating the input is intended.
    updated_cgr = original_cgr

    removed_count = 0
    new_lgs = {}

    # Iterate through all atoms (index, atom_obj) in the CGR; the list() snapshot
    # lets the loop body delete and re-add atoms safely.
    for atom_idx, atom_obj in list(updated_cgr.atoms()):
        # Skip non-X atoms
        if atom_obj.__class__.__name__ != "DynamicX":
            continue

        current_mark = atom_obj.mark
        if current_mark in to_remove:
            # Remove old LG (X): delete bond and atom
            neighbors = list(updated_cgr._bonds[atom_idx].keys())
            if neighbors:
                neighbor_idx = neighbors[0]
                bond = updated_cgr._bonds[atom_idx][neighbor_idx]
                updated_cgr.delete_bond(atom_idx, neighbor_idx)
            updated_cgr.delete_atom(atom_idx)

            # Attach new LG(X) fragment from the table
            lg_fragment = lg_table[current_mark][0]
            updated_cgr = updated_cgr.union(lg_fragment)
            # Reset radical flag on the new atom and restore the bond
            # (presumably the fragment carries an atom with the same index
            # as the removed X — TODO confirm).
            updated_cgr._atoms[atom_idx].is_radical = False
            # NOTE(review): if the X atom had no neighbours, neighbor_idx and
            # bond were not assigned in this iteration, so this call would use
            # stale values from an earlier atom or raise NameError — confirm
            # that every X atom always has exactly one bond.
            updated_cgr.add_bond(atom_idx, neighbor_idx, bond)

            removed_count += 1
        else:
            # Adjust the marks of remaining LGs to account for removed ones
            atom_obj.mark -= removed_count
            new_lgs[atom_obj.mark] = atom_idx

    # Reorder the atoms dict by atom index
    updated_cgr._atoms = dict(sorted(updated_cgr._atoms.items()))

    return updated_cgr, new_lgs
710
+
711
+
712
def new_lg_reaction_replacer(synthon_reaction, new_lgs, max_in_target_mol):
    """
    Swap leaving-group placeholder atoms in the reactants for `MarkedAt` atoms.

    For every reactant, each atom whose index is greater than
    `max_in_target_mol` and equals one of the indices in `new_lgs` is deleted
    and re-created as a `MarkedAt` atom at the same index, with `.mark` and
    `.isotope` set to the leaving-group label. The bond to its first
    neighbour is re-attached with the original bond object.

    Parameters
    ----------
    synthon_reaction : ReactionContainer
        Reaction whose `reactants` are edited in place.
    new_lgs : dict[int, int]
        Mapping from leaving-group label to the atom index to replace.
    max_in_target_mol : int
        Highest atom index of the core product; only atoms with a larger
        index are candidates for replacement.

    Returns
    -------
    List[Molecule]
        The same reactant objects, in their original order, after editing.
    """
    processed = []
    for molecule in synthon_reaction.reactants:
        # Snapshot the keys: atoms are deleted and re-added during the walk.
        for atom_num in list(molecule._atoms.keys()):
            if atom_num <= max_in_target_mol:
                continue
            for label, lg_atom_idx in new_lgs.items():
                marked_atom = MarkedAt()
                if atom_num != lg_atom_idx:
                    continue
                marked_atom.mark = label
                marked_atom.isotope = label
                # Remember the attachment point and bond before removal.
                attachment = list(molecule._bonds[atom_num].keys())[0]
                bond = molecule._bonds[atom_num][attachment]
                molecule.delete_bond(attachment, atom_num)
                molecule.delete_atom(atom_num)
                molecule.add_atom(marked_atom, atom_num)
                molecule.add_bond(attachment, atom_num, bond)
        processed.append(molecule)

    return processed
760
+
761
+
762
def post_process_subgroup(
    subgroup,
):  # Under development: Error in replace_leaving_groups_in_synthon , 'cuz synthon_reaction.clean2d crashes
    """
    Drop leaving-groups common to all pathways and rebuild a minimal synthon.

    A leaving-group index is treated as constant when `all_lg_collect` finds
    exactly one distinct container for it. Those indices are replaced in the
    synthon CGR, a new ReactionContainer is assembled around the first
    product, `nodes_data` is filtered with `remove_and_shift`, and the dict
    is flagged as processed.

    Parameters
    ----------
    subgroup : dict
        Must include `nodes_data` and `synthon_cgr`. If `post_processed` is
        already set to a true value the dict is returned untouched.

    Returns
    -------
    dict
        The same dict object, now with:
        - `'synthon_reaction'`: cleaned ReactionContainer
        - `'nodes_data'`: filtered node table
        - `'post_processed'`: True
        - `'group_lgs'`: routes grouped by identical leaving groups
    """
    if subgroup.get("post_processed"):
        return subgroup

    lgs_by_index = all_lg_collect(subgroup)
    # Indices with a single distinct leaving group are constant across routes.
    to_remove = [ind for ind, cgrs in lgs_by_index.items() if len(cgrs) == 1]

    new_synthon_cgr, new_lgs = replace_leaving_groups_in_synthon(subgroup, to_remove)
    synthon_reaction = ReactionContainer.from_cgr(new_synthon_cgr)
    synthon_reaction.clean2d()

    target_mol = synthon_reaction.products[0]  # TO DO: target_mol might be non 0
    max_in_target_mol = max(target_mol._atoms)
    new_reactants = new_lg_reaction_replacer(
        synthon_reaction, new_lgs, max_in_target_mol
    )
    new_synthon_reaction = ReactionContainer(
        reactants=new_reactants, products=[target_mol]
    )
    new_synthon_reaction.clean2d()

    subgroup["synthon_reaction"] = new_synthon_reaction
    subgroup["nodes_data"] = remove_and_shift(subgroup["nodes_data"], to_remove)
    subgroup["post_processed"] = True
    subgroup["group_lgs"] = group_by_identical_values(subgroup["nodes_data"])
    return subgroup
810
+
811
+
812
def group_by_identical_values(nodes_data):  # Under development
    """
    Collapse routes that carry the same sequence of leaving groups.

    A route's signature is the tuple of the first elements of its inner
    values, taken in sorted-subkey order. Routes with equal signatures are
    merged into one entry.

    Args:
        nodes_data (dict): Maps outer keys (route IDs) to inner dictionaries.
            Each inner dictionary maps subkeys to a tuple
            `(value_obj, other_info)`; only `value_obj` takes part in grouping.
            Example: {'route_1': {'pos_a': (1, 'infoA'), 'pos_b': (2, 'infoB')}, ...}

    Returns:
        dict: Keys are tuples of the outer keys that were grouped together.
            Values map each subkey of the first route in the group to its
            `value_obj`. Entries are ordered by group size, largest first;
            ties keep first-seen order.
            Example: {('route_1', 'route_2'): {'pos_a': 1, 'pos_b': 2}, ...}
    """
    # Bucket route IDs by signature.
    routes_by_signature = {}
    for route_id, inner in nodes_data.items():
        signature = tuple(inner[subkey][0] for subkey in sorted(inner))
        routes_by_signature.setdefault(signature, []).append(route_id)

    # One output entry per bucket, described by its first route.
    collapsed = {}
    for route_ids in routes_by_signature.values():
        representative = nodes_data[route_ids[0]]
        collapsed[tuple(route_ids)] = {
            subkey: pair[0] for subkey, pair in representative.items()
        }

    # Stable sort on negative size gives largest-first with ties in place.
    ranked = sorted(collapsed.items(), key=lambda entry: -len(entry[0]))
    return dict(ranked)
synplan/chem/reaction_routes/io.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import pickle
4
+ import os
5
+
6
+ from CGRtools import smiles as read_smiles
7
+ from synplan.mcts.tree import Tree
8
+
9
+
10
def make_dict(routes_json):
    """
    Flatten route trees into per-route dictionaries of reactions.

    routes_json : list of tree-dicts as produced by make_json(), or a dict
    mapping route indices to such trees.

    Returns a dict mapping each route index (int) to a sub-dict
    of {new_step_id: ReactionContainer}, where the step IDs run
    from the earliest reaction (0) up to the final (max). For a list input
    the route index is the position in the list; for a dict input it is
    `int(key)`.
    """

    def _postorder(node, collected):
        # First dive into any children, then record this reaction, so the
        # reactions come out in leaf -> root order. Mol nodes only recurse.
        for child in node.get("children", []):
            _postorder(child, collected)
        if node["type"] == "reaction":
            collected.append(read_smiles(node["smiles"]))

    # Both input shapes are handled by the same loop; only the source of the
    # (index, tree) pairs differs.
    if isinstance(routes_json, dict):
        indexed_trees = routes_json.items()
    else:
        indexed_trees = enumerate(routes_json)

    routes_dict = {}
    for route_idx, tree in indexed_trees:
        rxn_list = []
        _postorder(tree, rxn_list)
        # Assign 0, 1, 2, ... in traversal order.
        routes_dict[int(route_idx)] = dict(enumerate(rxn_list))

    return routes_dict
59
+
60
+
61
def read_routes_json(file_path="routes.csv", to_dict=False):
    """
    Load routes from a JSON file.

    Args:
        file_path (str): Path of the JSON file to read.
        to_dict (bool): If True, pass the parsed JSON through `make_dict`
            and return its result; otherwise return the parsed JSON as is.

    Returns:
        The parsed JSON object, or the output of `make_dict` for it.
    """
    with open(file_path, "r") as handle:
        loaded = json.load(handle)
    return make_dict(loaded) if to_dict else loaded
67
+
68
+
69
def read_routes_csv(file_path="routes.csv"):
    """
    Load routes from a CSV file into a nested dictionary.

    The file must have the columns route_id, step_id, smiles and meta.
    Each `smiles` value is parsed with `read_smiles`; `meta` is ignored.

    Returns:
        dict: route_id (int) -> step_id (int) -> parsed reaction.
    """
    routes_dict = {}
    with open(file_path, newline="") as csvfile:
        for record in csv.DictReader(csvfile):
            route_id = int(record["route_id"])
            step_id = int(record["step_id"])
            reaction = read_smiles(record["smiles"])
            steps = routes_dict.get(route_id)
            if steps is None:
                steps = routes_dict[route_id] = {}
            steps[step_id] = reaction
    return routes_dict
87
+
88
+
89
def _canonical_smiles(mol):
    """Standardise `mol` in place (kekule, implicit H, thiele) and return its SMILES string."""
    mol.kekule()
    mol.implicify_hydrogens()
    mol.thiele()
    return str(mol)


def make_json(routes_dict, keep_ids=True):
    """
    Convert routes into a nested JSON tree of reaction and molecule nodes.

    Molecules in the supplied reactions are standardised in place while the
    trees are built. Routes with no steps are skipped.

    Args:
        routes_dict (dict[int, dict[int, Reaction]]): Mapping route IDs to steps (step_id -> Reaction).
        keep_ids (bool): If True (default), returns a dict mapping `int(route_id)`
            to the route tree; if False, returns a list of route trees in
            iteration order.

    Returns:
        dict or list: JSON-like tree(s) of routes.
    """
    # Prepare output
    all_routes = {} if keep_ids else []

    for route_id, steps in routes_dict.items():
        if not steps:
            continue

        # Determine target molecule atoms from the final step of this route
        final_step = max(steps)
        target = steps[final_step].products[0]
        atom_nums = set(target._atoms.keys())

        # Precompute canonical SMILES and producer mapping for all products
        prod_map = {}  # smiles -> list of step_ids
        for sid, rxn in steps.items():
            for prod in rxn.products:
                prod_map.setdefault(_canonical_smiles(prod), []).append(sid)

        def build_mol_node(sid):
            """Find the product with any overlap to target atoms and recurse into its reaction."""
            rxn = steps[sid]
            for p in rxn.products:
                if atom_nums & set(p._atoms.keys()):
                    # Products were already standardised while filling prod_map.
                    return {
                        "type": "mol",
                        "smiles": str(p),
                        "children": [build_reaction_node(sid)],
                        "in_stock": False,
                    }
            # Shouldn't reach here if tree is consistent
            return None

        def build_reaction_node(sid):
            """Build reaction node and recurse into reactant molecule nodes."""
            rxn = steps[sid]
            node = {"type": "reaction", "smiles": format(rxn, "m"), "children": []}

            for react in rxn.reactants:
                r_smi = _canonical_smiles(react)
                # Look up any prior step producing this reactant
                prior = [ps for ps in prod_map.get(r_smi, []) if ps < sid]
                if prior:
                    node["children"].append(build_mol_node(max(prior)))
                else:
                    # Nothing in this route makes it: treat as a stock leaf.
                    node["children"].append(
                        {"type": "mol", "smiles": r_smi, "in_stock": True}
                    )

            return node

        # Build route tree and store
        tree = build_mol_node(final_step)
        if keep_ids:
            all_routes[int(route_id)] = tree
        else:
            all_routes.append(tree)

    return all_routes
169
+
170
+
171
def write_routes_json(routes_dict, file_path):
    """Build the route trees with `make_json` and dump them to `file_path` as indented JSON."""
    tree_payload = make_json(routes_dict)
    with open(file_path, "w") as out:
        json.dump(tree_payload, out, indent=2)
176
+
177
+
178
def write_routes_csv(routes_dict, file_path="routes.csv"):
    """
    Dump a nested routes dictionary to CSV.

    `routes_dict` has the form
        { route_id: { step_id: reaction_obj, ... }, ... }
    and is written with the columns route_id, step_id, smiles, meta,
    where smiles is format(reaction, 'm') and meta is always empty.
    Routes and steps are emitted in sorted order so output is deterministic.
    """

    def _rows():
        for route_id in sorted(routes_dict):
            steps = routes_dict[route_id]
            for step_id in sorted(steps):
                # meta is a placeholder column, left blank for now
                yield [route_id, step_id, format(steps[step_id], "m"), ""]

    with open(file_path, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["route_id", "step_id", "smiles", "meta"])
        writer.writerows(_rows())
197
+
198
+
199
class TreeWrapper:
    """
    Pickle helper for a search `Tree`.

    The wrapper stores a copy of the tree's `__dict__` with attributes that
    cannot be pickled reset, and rebuilds a `Tree` from it on load without
    calling `Tree.__init__`.
    """

    def __init__(self, tree, mol_id=1, config=1, path="planning_results/forest"):
        """Initializes the TreeWrapper and creates the output directory if needed."""
        self.tree = tree
        self.mol_id = mol_id
        self.config = config
        self.path = path
        # Ensure the directory exists before creating the filename
        os.makedirs(self.path, exist_ok=True)
        self.filename = os.path.join(self.path, f"tree_{mol_id}_{config}.pkl")

    def __getstate__(self):
        """Return the wrapper state with the tree replaced by a sanitised copy of its attributes."""
        state = self.__dict__.copy()
        # Shallow copy: the live tree keeps its own attributes untouched.
        tree_state = self.tree.__dict__.copy()
        # Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
        if "_tqdm" in tree_state:
            tree_state["_tqdm"] = True  # Reset to a simple flag
        for attr in ["policy_network", "value_network"]:
            if attr in tree_state:
                tree_state[attr] = None
        state["tree_state"] = tree_state
        del state["tree"]
        return state

    def __setstate__(self, state):
        """Restore the wrapper and rebuild the tree from its saved attribute dict."""
        tree_state = state.pop("tree_state")
        self.__dict__.update(state)
        # Bypass Tree.__init__: only the saved attributes are restored.
        new_tree = Tree.__new__(Tree)
        new_tree.__dict__.update(tree_state)
        self.tree = new_tree

    def save_tree(self):
        """Saves the TreeWrapper instance (including the tree state) to a file.

        Failures are reported on stdout and not re-raised (best effort).
        """
        try:
            with open(self.filename, "wb") as f:
                pickle.dump(self, f)
            print(
                f"Tree wrapper for mol_id '{self.mol_id}', config '{self.config}' saved to '{self.filename}'."
            )
        except Exception as e:
            print(f"Error saving tree to {self.filename}: {e}")

    @classmethod
    def load_tree_from_id(cls, mol_id, config=1, path="planning_results/forest"):
        """
        Loads a Tree object from a saved file using mol_id and config.

        Args:
            mol_id: The molecule ID used for saving.
            config: The configuration used for saving.
            path: The directory where the file is located

        Returns:
            The loaded Tree object, or None if loading fails. Failures are
            reported on stdout.
        """
        filename = os.path.join(path, f"tree_{mol_id}_{config}.pkl")
        print(f"Attempting to load tree from: {filename}")
        try:
            # SECURITY: unpickling executes code embedded in the file; only
            # load tree files that this application wrote itself.
            with open(filename, "rb") as f:
                loaded_wrapper = pickle.load(f)  # This implicitly calls __setstate__

            print(
                f"Tree object for mol_id '{mol_id}', config '{config}' successfully loaded from '{filename}'."
            )
            # The __setstate__ method already reconstructed the tree inside the wrapper
            return loaded_wrapper.tree

        except FileNotFoundError:
            print(f"Error: File not found at {filename}")
            return None
        except (pickle.UnpicklingError, EOFError) as e:
            print(
                f"Error: Could not unpickle file {filename}. It might be corrupted or empty. Details: {e}"
            )
            return None
        except NameError as e:
            # Raised by __setstate__ when the Tree class is not importable.
            print(f"Error during loading: {e}. Ensure 'Tree' class is defined.")
            return None
        except Exception as e:
            print(f"An unexpected error occurred loading tree from {filename}: {e}")
            return None
synplan/chem/reaction_routes/leaving_groups.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.periodictable import Core, At, DynamicElement
2
+ from typing import Optional
3
+
4
+
5
class Marked(Core):
    """
    Element base that carries a `mark` label and an integer `isotope`.

    Instances render as "X" followed by the isotope in parentheses.
    """

    # "__mark" is name-mangled by Python, matching the self.__mark accesses below.
    __slots__ = "__mark", "_isotope"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__mark = None
        self._isotope = 0  # Make sure this exists

    @property
    def mark(self):
        """Label attached to this atom (None until set)."""
        return self.__mark

    @mark.setter
    def mark(self, mark):
        self.__mark = mark

    @property
    def isotope(self):
        """Isotope value; falls back to 0 when the slot was never assigned."""
        return getattr(self, "_isotope", 0)  # Always returns int

    @isotope.setter
    def isotope(self, value):
        # Coerce to int so the getter's "always int" promise holds.
        self._isotope = int(value)

    def __repr__(self):
        return f"{self.symbol}({self.isotope})"

    @property
    def atomic_symbol(self) -> str:
        # Drops the first six characters of the class name, i.e. the
        # "Marked" prefix (e.g. "MarkedAt" -> "At").
        return self.__class__.__name__[6:]

    @property
    def symbol(self) -> str:
        return "X"  # For human-readable representation

    def __len__(self):
        # Pure delegation to the parent implementation.
        return super().__len__()
42
+
43
+
44
class MarkedAt(Marked, At):
    """
    Marked atom that reuses the atomic number of `At` and is shown as "X".

    `repr()` and `str()` both give "X(<isotope>)".
    """

    atomic_number = At.atomic_number

    @property
    def atomic_symbol(self):
        return "At"

    @property
    def symbol(self):
        return "X"

    def __repr__(self):
        return f"X({self.isotope})"

    def __str__(self):
        return f"X({self.isotope})"

    def __hash__(self):
        # Hash on isotope, atomic number, charge and radical flag; getattr
        # defaults cover attributes that have not been set on the instance.
        # NOTE(review): `mark` is not part of the hash — presumably because
        # it is always set equal to isotope by callers; confirm.
        return hash(
            (
                self.isotope,
                getattr(self, "atomic_number", 0),
                getattr(self, "charge", 0),
                getattr(self, "is_radical", False),
            )
        )
70
+
71
+
72
class DynamicX(DynamicElement):
    """
    Dynamic pseudo-element shown as "X", carrying a `mark` and an `isotope`.

    The `p_*` properties return the same values as their non-`p_`
    counterparts, and `valence_rules` always yields an empty tuple.
    """

    __slots__ = ("_mark", "_isotope")

    # Same number MarkedAt takes from At — presumably intentional; confirm.
    atomic_number = 85
    mass = 0.0
    group = 0
    period = 0
    isotopes_distribution = list(range(20))
    atomic_radius = 0.5
    isotopes_masses = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._isotope = None
        self._mark = None

    @property
    def mark(self):
        """Label attached to this atom; None when the slot is unset."""
        return getattr(self, "_mark", None)

    @mark.setter
    def mark(self, value):
        self._mark = value

    @property
    def isotope(self):
        """Isotope value; None when the slot is unset."""
        return getattr(self, "_isotope", None)

    @isotope.setter
    def isotope(self, value):
        self._isotope = value

    @property
    def symbol(self) -> str:
        return "X"

    def valence_rules(
        self, charge: int = 0, is_radical: bool = False, valence: int = 0
    ) -> tuple:
        """Return an empty tuple for every combination of arguments."""
        # The previous if/elif/else returned tuple() on every branch.
        return tuple()

    def __repr__(self):
        return f"Dynamic{self.symbol}()"

    @property
    def p_charge(self) -> int:
        return self.charge

    @property
    def p_is_radical(self) -> bool:
        return self.is_radical

    @property
    def p_hybridization(self) -> Optional[int]:
        return self.hybridization
synplan/chem/reaction_routes/route_cgr.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.containers.bonds import DynamicBond
2
+ from CGRtools.containers import ReactionContainer, CGRContainer, MoleculeContainer
3
+ from synplan.mcts.tree import Tree
4
+
5
+
6
def find_next_atom_num(reactions: list):
    """
    Find the next available atom number across a list of reactions.

    Each reaction is composed into its Condensed Graph of Reaction (CGR) and
    the largest atom index of every CGR is inspected. The result is one more
    than the largest index found overall.

    Args:
        reactions (list): A list of ReactionContainer objects.

    Returns:
        int: One greater than the maximum atom index found in any of the
        reaction CGRs. Returns 1 when `reactions` is empty or no CGR has atoms.
    """
    max_num = 0
    for reaction in reactions:
        cgr = reaction.compose()
        # default=0 keeps a CGR without atoms from raising ValueError.
        max_num = max(max_num, max(cgr._atoms, default=0))
    return max_num + 1
28
+
29
+
30
def get_clean_mapping(
    curr_prod: MoleculeContainer, prod: MoleculeContainer, reverse: bool = False
):
    """
    Get a 'clean' atom mapping between two molecules, avoiding conflicts.

    Takes the first mapping produced by `curr_prod.get_mapping(prod)` and
    keeps only the pairs that pass all of these filters:

    - the two indices differ;
    - the value is not itself a key that maps to something other than the key
      (i.e. the pair is not part of a longer renumbering chain);
    - the resulting target index is not an atom index of `prod`
      (or of `curr_prod` when `reverse` is True).

    Args:
        curr_prod (MoleculeContainer): The first MoleculeContainer object.
        prod (MoleculeContainer): The second MoleculeContainer object.
        reverse (bool, optional): If True, each kept pair is stored as
                                  value -> key instead of key -> value.
                                  Defaults to False.

    Returns:
        dict: The filtered mapping of source index to target index. Empty if
              `get_mapping` yields nothing or no pair passes the filters.
    """
    # Only the first mapping is ever used, so pull a single item instead of
    # enumerating every possible mapping.
    first_mapping = next(iter(curr_prod.get_mapping(prod)), None)
    if first_mapping is None:
        return {}

    # Atom indices that a target must not collide with.
    # NOTE(review): if get_mapping keys are curr_prod atoms and values are
    # prod atoms, every target falls in this set and the result is always
    # empty — confirm the orientation of get_mapping against the callers.
    occupied = set((curr_prod if reverse else prod)._atoms.keys())

    dict_map = {}
    for key, value in first_mapping.items():
        if key == value:
            continue
        if value in first_mapping and first_mapping[value] != key:
            continue

        source, target = (value, key) if reverse else (key, value)
        if target in occupied:
            continue

        dict_map[source] = target

    return dict_map
86
+
87
+
88
def validate_molecule_components(curr_mol: MoleculeContainer, node_id: int):
    """
    Report a molecule that consists of more than one connected component.

    Counts the entries of `curr_mol.connected_components` and prints an error
    message when there is more than one. Nothing is returned or raised.

    Args:
        curr_mol (MoleculeContainer): The MoleculeContainer object to validate.
        node_id (int): The ID of the tree node associated with this molecule,
                       used only in the printed message.
    """
    # Counting is enough; no need to build a substructure per component.
    component_count = sum(1 for _ in curr_mol.connected_components)
    if component_count > 1:
        print(f"Error tree {node_id}: We have more than one molecule in one node")
106
+
107
+
108
def get_leaving_groups(products: list):
    """
    Collect the atom numbers of every product except the first.

    The first entry of `products` is taken to be the main product; all
    following molecules are treated as leaving groups and the keys of their
    `_atoms` dictionaries are gathered.

    Args:
        products (list): MoleculeContainer objects produced by a reaction,
                         main product first.

    Returns:
        list: Atom indices of the leaving-group molecules, in product order.
    """
    return [
        atom_num
        for position, molecule in enumerate(products)
        if position != 0  # position 0 is the main product
        for atom_num in molecule._atoms.keys()
    ]
133
+
134
+
135
def process_first_reaction(first_react: ReactionContainer, tree: Tree, node_id: int):
    """
    Process the first reaction in a retrosynthetic route and initialize the building block set.

    Every reactant of the reaction is checked with
    `validate_molecule_components`. A reactant counts as a building block
    when its size is at most `tree.config.min_mol_size` or its SMILES string
    is in `tree.building_blocks`; the atom indices of all such reactants are
    collected into one set.

    Args:
        first_react (ReactionContainer): The first ReactionContainer object in the route.
        tree (Tree): The Tree object providing `config.min_mol_size` and
                     `building_blocks`.
        node_id (int): The ID of the tree node associated with this reaction,
                       used for validation reporting.

    Returns:
        set: Atom indices of all reactants identified as building blocks.
    """
    bb_set = set()

    for curr_mol in first_react.reactants:
        if (
            len(curr_mol) <= tree.config.min_mol_size
            or str(curr_mol) in tree.building_blocks
        ):
            # Accumulate: plain assignment here would keep only the last
            # building block and drop the atoms of earlier ones.
            bb_set |= set(curr_mol._atoms)

        validate_molecule_components(curr_mol, node_id)

    return bb_set
173
+
174
+
175
def update_reaction_dict(
    reaction: ReactionContainer,
    node_id: int,
    mapping: dict,
    react_dict: dict,
    tree: Tree,
    bb_set: set,
    prev_remap: dict = None,
):
    """
    Record per-reactant atom mappings and extend the building-block atom set.

    For every reactant of ``reaction`` this function:

    1. passes the reactant to ``validate_molecule_components`` (defined
       elsewhere in this module);
    2. adds the reactant's atom numbers to the building-block set when the
       reactant is small enough (``len(reactant) <= tree.config.min_mol_size``)
       or its string form is in ``tree.building_blocks``;
    3. stores in ``react_dict`` -- under the tuple of the reactant's atom
       numbers -- the entries of ``mapping`` whose keys belong to that
       reactant, overlaid with the matching entries of ``prev_remap`` (entries
       from ``prev_remap`` win on key clashes).

    Args:
        reaction (ReactionContainer): The reaction whose reactants are processed.
        node_id (int): The ID of the tree node of the route, forwarded to the
            validation helper.
        mapping (dict): Atom mapping to restrict to each reactant's atoms.
        react_dict (dict): Dictionary updated in place; keys are tuples of
            atom numbers, values are the restricted mappings.
        tree (Tree): Provides ``config.min_mol_size`` and ``building_blocks``.
        bb_set (set): Building-block atom indices collected so far. It is not
            modified in place; a new set is returned.
        prev_remap (dict, optional): Earlier remapping to overlay on the
            restricted mapping. Ignored when empty or None. Defaults to None.

    Returns:
        tuple: ``(react_dict, bb_set)`` -- the same ``react_dict`` object after
            the update and the (possibly enlarged) building-block set.
    """
    for molecule in reaction.reactants:
        atom_key = tuple(molecule._atoms)
        atom_ids = set(atom_key)

        validate_molecule_components(molecule, node_id)

        is_building_block = (
            len(molecule) <= tree.config.min_mol_size
            or str(molecule) in tree.building_blocks
        )
        if is_building_block:
            bb_set = bb_set.union(atom_ids)

        # Keep only the mapping entries that concern this reactant's atoms.
        restricted = {}
        for source, target in mapping.items():
            if source in atom_ids:
                restricted[source] = target

        # Overlay the previous remapping, restricted the same way.
        if prev_remap:
            for source, target in prev_remap.items():
                if source in atom_ids:
                    restricted[source] = target

        react_dict[atom_key] = restricted

    return react_dict, bb_set
237
+
238
+
239
def process_target_blocks(
    curr_products: list,
    curr_prod: MoleculeContainer,
    lg_atom_nums: list,
    curr_lg_atom_nums: list,
    bb_set: set,
):
    """
    Collect atom indices of side products that are leaving-group or building-block atoms.

    Nothing is collected unless ``curr_products`` holds more than one
    molecule. Otherwise every product whose set of atom numbers differs from
    that of ``curr_prod`` is scanned, and each of its atom numbers is appended
    to the result when it is found in ``lg_atom_nums`` or
    ``curr_lg_atom_nums``, and appended (again) when it is found in
    ``bb_set``. An atom satisfying both conditions therefore appears twice.

    Args:
        curr_products (list): MoleculeContainer objects of the current step.
        curr_prod (MoleculeContainer): Reference molecule; products sharing
            exactly its atom numbers are skipped.
        lg_atom_nums (list): Atom indices treated as leaving-group atoms.
        curr_lg_atom_nums (list): Further atom indices treated as
            leaving-group atoms. Membership is tested with ``in``, so the
            elements must be comparable with single atom numbers.
        bb_set (set): Atom indices treated as building-block atoms.

    Returns:
        list: The collected atom indices, in scan order, possibly with
            duplicates.
    """
    collected = []
    if len(curr_products) <= 1:
        return collected

    for product in curr_products:
        # NOTE(review): the mapping returned here is never used; the call is
        # retained only so that behaviour stays identical.
        get_clean_mapping(curr_prod, product)

        if product._atoms.keys() == curr_prod._atoms.keys():
            continue

        for atom_num in list(product._atoms.keys()):
            is_leaving = atom_num in lg_atom_nums or atom_num in curr_lg_atom_nums
            if is_leaving:
                collected.append(atom_num)
            if atom_num in bb_set:
                collected.append(atom_num)

    return collected
281
+
282
+
283
def compose_route_cgr(tree_or_routes, node_id):
    """
    Compose the RouteCGR of a single synthesis route.

    Parameters
    ----------
    tree_or_routes : synplan.mcts.tree.Tree
        or dict mapping route_id -> {step_id: ReactionContainer}
    node_id : int
        the route index (in the Tree's winning_nodes, or the dict's keys)

    Returns
    -------
    dict or None
        - if successful: { 'cgr': <composed CGR>, 'reactions_dict': {step: ReactionContainer,...} }
        - on error in the tree-based branch: None (the error is printed)

    Raises
    ------
    KeyError
        In the dict-based branch, when ``node_id`` is not a key of the dict.

    Notes
    -----
    Dict-based branch: the reactions are taken in ascending ``step_id`` order
    and composed as they are, starting from the last one and folding towards
    the first. The keys of the returned ``reactions_dict`` are positions in
    that sorted order (0..n-1), not the original ``step_id`` values.

    Tree-based branch: reactions come from ``tree.synthesis_route(node_id)``.
    Before each earlier step is folded into the accumulated CGR, atom numbers
    of side-product / building-block atoms that already occur in the
    accumulated CGR are replaced by fresh numbers (starting from
    ``find_next_atom_num(reactions)``), and the stored reactions are rebuilt
    from the remapped CGRs.
    """
    # ----------- dict-based branch ------------
    if isinstance(tree_or_routes, dict):
        routes_dict = tree_or_routes
        if node_id not in routes_dict:
            raise KeyError(f"Route {node_id} not in provided dict.")
        # grab the ReactionContainers in ascending step order
        step_map = routes_dict[node_id]
        reactions = [step_map[step_id] for step_id in sorted(step_map)]

        # start from the last (final) reaction
        accum_cgr = reactions[-1].compose()
        reactions_dict = {len(reactions) - 1: reactions[-1]}
        # now fold backwards through the earlier steps
        for idx in range(len(reactions) - 2, -1, -1):
            rxn = reactions[idx]
            accum_cgr = rxn.compose().compose(accum_cgr)
            reactions_dict[idx] = rxn

        return {"cgr": accum_cgr, "reactions_dict": reactions_dict}

    # ----------- tree-based branch ------------
    tree = tree_or_routes
    try:
        reactions = tree.synthesis_route(node_id)

        first_react = reactions[-1]
        reactions_dict = {len(reactions) - 1: first_react}

        accum_cgr = first_react.compose()
        bb_set = process_first_reaction(first_react, tree, node_id)
        react_dict = {}
        max_num = find_next_atom_num(reactions)

        for step in range(len(reactions) - 2, -1, -1):
            reaction = reactions[step]
            curr_cgr = reaction.compose()
            curr_prod = reaction.products[0]

            accum_products = accum_cgr.decompose()[1].split()
            lg_atom_nums = get_leaving_groups(accum_products)
            curr_products = curr_cgr.decompose()[1].split()
            # Flat list of atom numbers of the current side products. Passing
            # a list of per-product lists here made every integer membership
            # test inside process_target_blocks evaluate to False.
            curr_lg_atom_nums = get_leaving_groups(curr_products)

            tuple_atoms = tuple(curr_prod._atoms)
            prev_remap = react_dict.get(tuple_atoms, {})

            if prev_remap:
                curr_cgr = curr_cgr.remap(prev_remap, copy=True)

            # identify atoms that need new numbers because of an overlap
            target_block = process_target_blocks(
                curr_products,
                curr_prod,
                lg_atom_nums,
                curr_lg_atom_nums,
                bb_set,
            )
            mapping = {}
            for atom_num in sorted(target_block):
                if atom_num in accum_cgr._atoms and atom_num not in mapping:
                    mapping[atom_num] = max_num
                    max_num += 1

            # carry forward the first non-empty clean remap found against the
            # accumulated products
            dict_map = {}
            for ap in accum_products:
                clean_map = get_clean_mapping(curr_prod, ap, reverse=True)
                if clean_map:
                    dict_map = clean_map
                    break
            if dict_map:
                curr_cgr = curr_cgr.remap(dict_map, copy=False)

            # update our react_dict & bb_set
            react_dict, bb_set = update_reaction_dict(
                reaction, node_id, mapping, react_dict, tree, bb_set, prev_remap
            )

            # apply the new overlap mapping
            if mapping:
                curr_cgr = curr_cgr.remap(mapping, copy=False)

            reactions_dict[step] = ReactionContainer.from_cgr(curr_cgr)
            accum_cgr = curr_cgr.compose(accum_cgr)

        return {"cgr": accum_cgr, "reactions_dict": reactions_dict}

    except Exception as e:
        # Deliberate best-effort boundary: a route that cannot be composed is
        # reported and skipped by the callers instead of aborting the batch.
        print(f"Error processing node {node_id}: {e}")
        return None
392
+
393
+
394
def compose_all_route_cgrs(tree_or_routes, node_id=None):
    """
    Compose RouteCGRs for one route or for every available route.

    Parameters
    ----------
    tree_or_routes : synplan.mcts.tree.Tree
        or dict mapping route_id -> {step_id: ReactionContainer}
    node_id : int or None
        if None, process every route (the sorted unique ``winning_nodes`` of
        the tree, or the sorted keys of the dict); otherwise only that route.

    Returns
    -------
    dict or None
        - dict input: {route_id: CGR}; a route whose composition yields a
          falsy result is stored as None.
        - tree input, node_id given: {node_id: CGR}, or None when the route
          could not be composed.
        - tree input, node_id None: {route_id: CGR} containing only the
          routes that could be composed.

    Raises
    ------
    KeyError
        For dict input, when ``node_id`` is given but is not a key.
    """
    # dict-based input
    if isinstance(tree_or_routes, dict):
        if node_id is None:
            route_ids = sorted(tree_or_routes)
        elif node_id in tree_or_routes:
            route_ids = [node_id]
        else:
            raise KeyError(f"Route {node_id} not in provided dict.")

        composed = {}
        for route_id in route_ids:
            outcome = compose_route_cgr(tree_or_routes, route_id)
            composed[route_id] = outcome["cgr"] if outcome else None
        return composed

    # tree-based input, single route
    if node_id is not None:
        outcome = compose_route_cgr(tree_or_routes, node_id)
        if not outcome:
            return None
        return {node_id: outcome["cgr"]}

    # tree-based input, all winning routes
    composed = {}
    for route_id in sorted(set(tree_or_routes.winning_nodes)):
        outcome = compose_route_cgr(tree_or_routes, route_id)
        if outcome:
            composed[route_id] = outcome["cgr"]
    return composed
448
+
449
+
450
def extract_reactions(tree: Tree, node_id=None):
    """
    Collect the per-step reactions of one route or of all winning routes.

    Each route is processed with ``compose_route_cgr`` and the
    ``"reactions_dict"`` entry of its result is kept.

    Parameters
    ----------
    tree : Tree
        A retrosynthetic tree object exposing ``winning_nodes`` and accepted
        by ``compose_route_cgr``.
    node_id : hashable, optional
        If provided, only this route is processed.

    Returns
    -------
    dict or None
        - node_id given: ``{node_id: reactions_dict}``, or None when the
          route could not be composed.
        - node_id None: ``{route_id: reactions_dict}`` for every unique
          winning node that could be composed, ordered by route id.
    """
    if node_id is not None:
        outcome = compose_route_cgr(tree, node_id)
        if not outcome:
            return None
        return {node_id: outcome["reactions_dict"]}

    collected = {}
    for route_id in set(tree.winning_nodes):
        outcome = compose_route_cgr(tree, route_id)
        if outcome:
            collected[route_id] = outcome["reactions_dict"]

    return dict(sorted(collected.items()))
488
+
489
+
490
def compose_reduced_route_cgr(route_cgr: CGRContainer):
    """
    Build the reduced form of a RouteCGR (ReducedRouteCGR).

    Steps performed:

    1. A substructure is extracted for every connected component of
       ``route_cgr``; the first one is used as the working CGR.
    2. Every bond of the working CGR is visited:
       - a bond whose ``p_order`` is None while its ``order`` is set is
         deleted;
       - a bond whose ``p_order`` and ``order`` are both plain ``int`` values
         and differ from each other is replaced by
         ``DynamicBond(p_order, p_order)``.
    3. The first connected component of the modified working CGR is extracted
       again; this is the reduced CGR.
    4. If the reduced CGR's ``_p_charges`` and ``_charges`` differ, the charge
       of every atom with a non-zero entry in ``_charges`` is set to 0.

    Args:
        route_cgr: The input RouteCGR object to be reduced.

    Returns:
        The reduced RouteCGR object.
    """
    components = [
        route_cgr.substructure(atoms) for atoms in route_cgr.connected_components
    ]
    working_cgr = components[0]

    # Snapshots are iterated because bonds are deleted/re-added on the fly.
    for first_atom, neighbours in list(working_cgr._bonds.items()):
        for second_atom, bond in list(neighbours.items()):
            # Bond present only on the `order` side: drop it.
            if bond.p_order is None and bond.order is not None:
                working_cgr.delete_bond(first_atom, second_atom)
                continue

            both_int = type(bond.p_order) is int and type(bond.order) is int
            if both_int and bond.p_order != bond.order:
                # Changed bond: keep it, but with `p_order` on both sides.
                kept_order = int(bond.p_order)
                working_cgr.delete_bond(first_atom, second_atom)
                working_cgr.add_bond(
                    first_atom, second_atom, DynamicBond(kept_order, kept_order)
                )

    # Deleting bonds may have split the graph; keep the first component only.
    reduced = [
        working_cgr.substructure(atoms) for atoms in working_cgr.connected_components
    ][0]

    # Reset non-zero charges when the two charge tables disagree.
    if reduced._p_charges != reduced._charges:
        for atom_num, charge in reduced._charges.items():
            if charge != 0:
                reduced._atoms[atom_num].charge = 0

    return reduced
551
+
552
+
553
def compose_all_reduced_route_cgrs(route_cgrs_dict: dict):
    """
    Reduce every RouteCGR of a dictionary.

    Applies ``compose_reduced_route_cgr`` to each value of
    ``route_cgrs_dict`` and returns the results under the same keys, in the
    same order.

    Args:
        route_cgrs_dict (dict): Maps identifiers (e.g., route numbers) to
            RouteCGR objects.

    Returns:
        dict: A new dictionary mapping each identifier of ``route_cgrs_dict``
            to the corresponding ReducedRouteCGR object.
    """
    return {
        route_id: compose_reduced_route_cgr(route_cgr)
        for route_id, route_cgr in route_cgrs_dict.items()
    }
synplan/chem/reaction_routes/visualisation.py ADDED
@@ -0,0 +1,903 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.algorithms.depict import (
2
+ Depict,
3
+ DepictMolecule,
4
+ DepictCGR,
5
+ rotate_vector,
6
+ _render_charge,
7
+ )
8
+ from CGRtools.containers import ReactionContainer, MoleculeContainer, CGRContainer
9
+
10
+ from collections import defaultdict
11
+ from uuid import uuid4
12
+ from math import hypot
13
+ from functools import partial
14
+
15
+
16
+ class WideBondDepictCGR(DepictCGR):
17
+ """
18
+ Like DepictCGR, but all DynamicBonds
19
+ are drawn 2.5× wider than the standard bond width.
20
+ """
21
+
22
+ __slots__ = ()
23
+
24
+ def _render_bonds(self):
25
+ """
26
+ Renders the bonds of the CGR as SVG lines, with DynamicBonds drawn wider.
27
+
28
+ This method overrides the base `_render_bonds` to apply a wider stroke
29
+ to DynamicBonds, highlighting changes in bond order during a reaction.
30
+ It iterates through all bonds, calculates their positions based on
31
+ 2D coordinates, and generates SVG `<line>` elements with appropriate
32
+ styles (color, width, dash array) based on the bond's original (`order`)
33
+ and primary (`p_order`) states. Aromatic bonds are handled separately
34
+ using a helper method.
35
+
36
+ Returns:
37
+ list: A list of strings, where each string is an SVG element
38
+ representing a bond.
39
+ """
40
+ plane = self._plane
41
+ config = self._render_config
42
+
43
+ # get the normal width (default 1.0) and compute a 4× wide stroke
44
+ normal_width = config.get("bond_width", 0.02)
45
+ wide_width = normal_width * 2.5
46
+
47
+ broken = config["broken_color"]
48
+ formed = config["formed_color"]
49
+ dash1, dash2 = config["dashes"]
50
+ double_space = config["double_space"]
51
+ triple_space = config["triple_space"]
52
+
53
+ svg = []
54
+ ar_bond_colors = defaultdict(dict)
55
+
56
+ for n, m, bond in self.bonds():
57
+ order, p_order = bond.order, bond.p_order
58
+ nx, ny = plane[n]
59
+ mx, my = plane[m]
60
+ # invert Y for SVG
61
+ ny, my = -ny, -my
62
+ rv = partial(rotate_vector, 0, x2=mx - nx, y2=ny - my)
63
+ if order == 1:
64
+ if p_order == 1:
65
+ svg.append(
66
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
67
+ )
68
+ elif p_order == 4:
69
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
70
+ svg.append(
71
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
72
+ )
73
+ elif p_order == 2:
74
+ dx, dy = rv(double_space)
75
+ svg.append(
76
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
77
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
78
+ )
79
+ svg.append(
80
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
81
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
82
+ )
83
+ elif p_order == 3:
84
+ dx, dy = rv(triple_space)
85
+ svg.append(
86
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
87
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
88
+ )
89
+ svg.append(
90
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
91
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke-width="{wide_width:.2f}"/>'
92
+ )
93
+ svg.append(
94
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
95
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
96
+ )
97
+ elif p_order is None:
98
+ svg.append(
99
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
100
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
101
+ )
102
+ else:
103
+ dx, dy = rv(double_space)
104
+ svg.append(
105
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
106
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
107
+ )
108
+ svg.append(
109
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
110
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
111
+ )
112
+ elif order == 4:
113
+ if p_order == 4:
114
+ svg.append(
115
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
116
+ )
117
+ elif p_order == 1:
118
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
119
+ svg.append(
120
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
121
+ )
122
+ elif p_order == 2:
123
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
124
+ dx, dy = rv(double_space)
125
+ svg.append(
126
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
127
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
128
+ )
129
+ svg.append(
130
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
131
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
132
+ )
133
+ elif p_order == 3:
134
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
135
+ dx, dy = rv(triple_space)
136
+ svg.append(
137
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
138
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
139
+ )
140
+ svg.append(
141
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
142
+ )
143
+ svg.append(
144
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
145
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
146
+ )
147
+ elif p_order is None:
148
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = broken
149
+ svg.append(
150
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
151
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
152
+ )
153
+ else:
154
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
155
+ svg.append(
156
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
157
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
158
+ )
159
+ elif order == 2:
160
+ if p_order == 2:
161
+ dx, dy = rv(double_space)
162
+ svg.append(
163
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
164
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
165
+ )
166
+ svg.append(
167
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
168
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}"/>'
169
+ )
170
+ elif p_order == 1:
171
+ dx, dy = rv(double_space)
172
+ svg.append(
173
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
174
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
175
+ )
176
+ svg.append(
177
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
178
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
179
+ )
180
+ elif p_order == 4:
181
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
182
+ dx, dy = rv(double_space)
183
+ svg.append(
184
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
185
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
186
+ )
187
+ svg.append(
188
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
189
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
190
+ )
191
+ elif p_order == 3:
192
+ dx, dy = rv(triple_space)
193
+ svg.append(
194
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
195
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
196
+ )
197
+ svg.append(
198
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
199
+ )
200
+ svg.append(
201
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
202
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed} stroke-width="{wide_width:.2f}""/>'
203
+ )
204
+ elif p_order is None:
205
+ dx, dy = rv(double_space)
206
+ svg.append(
207
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
208
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
209
+ )
210
+ svg.append(
211
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
212
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
213
+ )
214
+ else:
215
+ dx, dy = rv(triple_space)
216
+ svg.append(
217
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
218
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
219
+ )
220
+ svg.append(
221
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
222
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
223
+ )
224
+ svg.append(
225
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
226
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
227
+ )
228
+ elif order == 3:
229
+ if p_order == 3:
230
+ dx, dy = rv(triple_space)
231
+ svg.append(
232
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
233
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
234
+ )
235
+ svg.append(
236
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
237
+ )
238
+ svg.append(
239
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
240
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}"/>'
241
+ )
242
+ elif p_order == 1:
243
+ dx, dy = rv(triple_space)
244
+ svg.append(
245
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
246
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
247
+ )
248
+ svg.append(
249
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
250
+ f' stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
251
+ )
252
+ svg.append(
253
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
254
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" '
255
+ f'stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
256
+ )
257
+ elif p_order == 4:
258
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
259
+ dx, dy = rv(triple_space)
260
+ svg.append(
261
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}" '
262
+ f'y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
263
+ )
264
+ svg.append(
265
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
266
+ )
267
+ svg.append(
268
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" x2="{mx - dx:.2f}" '
269
+ f'y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
270
+ )
271
+ elif p_order == 2:
272
+ dx, dy = rv(triple_space)
273
+ svg.append(
274
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
275
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}"/>'
276
+ )
277
+ svg.append(
278
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"/>'
279
+ )
280
+ svg.append(
281
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
282
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
283
+ )
284
+ elif p_order is None:
285
+ dx, dy = rv(triple_space)
286
+ svg.append(
287
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
288
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
289
+ )
290
+ svg.append(
291
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" '
292
+ f'x2="{mx:.2f}" y2="{my:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
293
+ )
294
+ svg.append(
295
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
296
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
297
+ )
298
+ else:
299
+ dx, dy = rv(double_space)
300
+ dx3 = 3 * dx
301
+ dy3 = 3 * dy
302
+ svg.append(
303
+ f' <line x1="{nx + dx3:.2f}" y1="{ny - dy3:.2f}" x2="{mx + dx3:.2f}" '
304
+ f'y2="{my - dy3:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
305
+ )
306
+ svg.append(
307
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
308
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
309
+ )
310
+ svg.append(
311
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
312
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
313
+ )
314
+ svg.append(
315
+ f' <line x1="{nx - dx3:.2f}" y1="{ny + dy3:.2f}" x2="{mx - dx3:.2f}" '
316
+ f'y2="{my + dy3:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
317
+ )
318
+ elif order is None:
319
+ if p_order == 1:
320
+ svg.append(
321
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
322
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
323
+ )
324
+ elif p_order == 4:
325
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = formed
326
+ svg.append(
327
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
328
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
329
+ )
330
+ elif p_order == 2:
331
+ dx, dy = rv(double_space)
332
+ # dx = dx // 1.4
333
+ # dy = dy // 1.4
334
+ svg.append(
335
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}" '
336
+ f'y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
337
+ )
338
+ svg.append(
339
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" x2="{mx - dx:.2f}" '
340
+ f'y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
341
+ )
342
+ elif p_order == 3:
343
+ dx, dy = rv(triple_space)
344
+ svg.append(
345
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
346
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
347
+ )
348
+ svg.append(
349
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
350
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
351
+ )
352
+ svg.append(
353
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
354
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
355
+ )
356
+ else:
357
+ svg.append(
358
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
359
+ f'stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
360
+ )
361
+ else:
362
+ if p_order == 8:
363
+ svg.append(
364
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
365
+ f'stroke-dasharray="{dash1:.2f} {dash2:.2f}"/>'
366
+ )
367
+ elif p_order == 1:
368
+ dx, dy = rv(double_space)
369
+ svg.append(
370
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
371
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
372
+ )
373
+ svg.append(
374
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
375
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
376
+ )
377
+ elif p_order == 4:
378
+ ar_bond_colors[n][m] = ar_bond_colors[m][n] = None
379
+ svg.append(
380
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
381
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
382
+ )
383
+ elif p_order == 2:
384
+ dx, dy = rv(triple_space)
385
+ svg.append(
386
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" x2="{mx + dx:.2f}"'
387
+ f' y2="{my - dy:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
388
+ )
389
+ svg.append(
390
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}"'
391
+ f' stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
392
+ )
393
+ svg.append(
394
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
395
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
396
+ )
397
+ elif p_order == 3:
398
+ dx, dy = rv(double_space)
399
+ dx3 = 3 * dx
400
+ dy3 = 3 * dy
401
+ svg.append(
402
+ f' <line x1="{nx + dx3:.2f}" y1="{ny - dy3:.2f}" x2="{mx + dx3:.2f}" '
403
+ f'y2="{my - dy3:.2f}" stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
404
+ )
405
+ svg.append(
406
+ f' <line x1="{nx + dx:.2f}" y1="{ny - dy:.2f}" '
407
+ f'x2="{mx + dx:.2f}" y2="{my - dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
408
+ )
409
+ svg.append(
410
+ f' <line x1="{nx - dx:.2f}" y1="{ny + dy:.2f}" '
411
+ f'x2="{mx - dx:.2f}" y2="{my + dy:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
412
+ )
413
+ svg.append(
414
+ f' <line x1="{nx - dx3:.2f}" y1="{ny + dy3:.2f}" '
415
+ f'x2="{mx - dx3:.2f}" y2="{my + dy3:.2f}" stroke="{formed}" stroke-width="{wide_width:.2f}"/>'
416
+ )
417
+ else:
418
+ svg.append(
419
+ f' <line x1="{nx:.2f}" y1="{ny:.2f}" x2="{mx:.2f}" y2="{my:.2f}" '
420
+ f'stroke-dasharray="{dash1:.2f} {dash2:.2f}" stroke="{broken}" stroke-width="{wide_width:.2f}"/>'
421
+ )
422
+
423
+ # aromatic rings - unchanged
424
+ for ring in self.aromatic_rings:
425
+ cx = sum(plane[x][0] for x in ring) / len(ring)
426
+ cy = sum(plane[x][1] for x in ring) / len(ring)
427
+
428
+ for n, m in zip(ring, ring[1:]):
429
+ nx, ny = plane[n]
430
+ mx, my = plane[m]
431
+ aromatic = self.__render_aromatic_bond(
432
+ nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
433
+ )
434
+ if aromatic:
435
+ svg.append(aromatic)
436
+
437
+ n, m = ring[-1], ring[0]
438
+ nx, ny = plane[n]
439
+ mx, my = plane[m]
440
+ aromatic = self.__render_aromatic_bond(
441
+ nx, ny, mx, my, cx, cy, ar_bond_colors[n].get(m)
442
+ )
443
+ if aromatic:
444
+ svg.append(aromatic)
445
+ return svg
446
+
447
def __render_aromatic_bond(self, n_x, n_y, m_x, m_y, c_x, c_y, color):
    """
    Build the SVG ``<line>`` for the inner (ring-side) stroke of one aromatic bond.

    NOTE(review): this method was reconstructed from a flattened diff view in
    which indentation was lost. The nesting below, in particular the placement
    of the final ``return`` after the ``elif``, is inferred and should be
    confirmed against the repository.

    Args:
        n_x, n_y: Coordinates of the first atom of the bond.
        m_x, m_y: Coordinates of the second atom of the bond.
        c_x, c_y: Coordinates of the ring centre.
        color: Stroke colour for the line. A truthy value produces a coloured,
            double-width line drawn with the "aromatic_dashes" pattern;
            ``None`` switches the pattern to the regular "dashes" one and
            emits no stroke attributes.

    Returns:
        The SVG line as a string, or ``None`` (implicit) when the ring centre
        lies on the bond axis or too close to it for the inner line to be drawn.
    """
    config = self._render_config

    dash1, dash2 = config["dashes"]
    dash3, dash4 = config["aromatic_dashes"]
    aromatic_space = config["cgr_aromatic_space"]

    # coloured (changed) aromatic bonds are drawn twice as wide as "bond_width"
    normal_width = config.get("bond_width", 0.02)
    wide_width = normal_width * 2

    # n aligned xy
    mn_x, mn_y, cn_x, cn_y = m_x - n_x, m_y - n_y, c_x - n_x, c_y - n_y

    # nm reoriented xy
    # presumably rotate_vector turns the centre into a frame where n->m is the
    # x axis; the helper is defined elsewhere — confirm
    mr_x, mr_y = hypot(mn_x, mn_y), 0
    cr_x, cr_y = rotate_vector(cn_x, cn_y, mn_x, -mn_y)

    # skip when the centre is on the bond axis (cr_y == 0) or when the offset
    # would reach 65% or more of the distance to the centre
    if cr_y and aromatic_space / cr_y < 0.65:
        if cr_y > 0:
            r_y = aromatic_space
        else:
            # centre is on the other side: mirror the offset, keep cr_y positive
            r_y = -aromatic_space
            cr_y = -cr_y

        # x positions of the two inner-line end points in the reoriented frame
        ar_x = aromatic_space * cr_x / cr_y
        br_x = mr_x - aromatic_space * (mr_x - cr_x) / cr_y

        # backward reorienting
        an_x, an_y = rotate_vector(ar_x, r_y, mn_x, mn_y)
        bn_x, bn_y = rotate_vector(br_x, r_y, mn_x, mn_y)

        if color:
            # print('color')
            return (
                f' <line x1="{an_x + n_x:.2f}" y1="{-an_y - n_y:.2f}" x2="{bn_x + n_x:.2f}" '
                f'y2="{-bn_y - n_y:.2f}" stroke-dasharray="{dash3:.2f} {dash4:.2f}" stroke="{color}" stroke-width="{wide_width:.2f}"/>'
            )
        elif color is None:
            # no colour recorded for this bond: fall back to the regular dash pattern
            dash3, dash4 = dash1, dash2
        return (
            f' <line x1="{an_x + n_x:.2f}" y1="{-an_y - n_y:.2f}"'
            f' x2="{bn_x + n_x:.2f}" y2="{-bn_y - n_y:.2f}" stroke-dasharray="{dash3:.2f} {dash4:.2f}"/>'
        )
490
+
491
+
492
def cgr_display(cgr: CGRContainer) -> str:
    """
    Render a CGR as SVG with DynamicBonds drawn using wider strokes.

    The bond-rendering methods of ``WideBondDepictCGR`` are installed on the
    ``CGRContainer`` class itself. The patch is class-wide and is not undone
    when this function returns, so later ``depict()`` calls on any
    ``CGRContainer`` also use the wide-bond renderer. After patching, the 2D
    layout of ``cgr`` is recomputed with ``clean2d()`` and the depiction is
    produced by ``depict()``.

    Args:
        cgr (CGRContainer): The CGR to draw. Its 2D coordinates are
            recalculated in place.

    Returns:
        str: The SVG markup returned by ``cgr.depict()``.
    """
    wide_aromatic = WideBondDepictCGR._WideBondDepictCGR__render_aromatic_bond

    # The aromatic renderer is registered under both mangled names: the one
    # CGRContainer's own code looks up and the one the transplanted
    # ``_render_bonds`` (compiled inside WideBondDepictCGR) looks up.
    overrides = (
        ("_CGRContainer__render_aromatic_bond", wide_aromatic),
        ("_render_bonds", WideBondDepictCGR._render_bonds),
        ("_WideBondDepictCGR__render_aromatic_bond", wide_aromatic),
    )
    for attribute, implementation in overrides:
        setattr(CGRContainer, attribute, implementation)

    cgr.clean2d()
    return cgr.depict()
519
+
520
+
521
class CustomDepictMolecule(DepictMolecule):
    """
    Custom molecule depiction class that uses atom.symbol for rendering.

    NOTE(review): this class was reconstructed from a flattened diff view in
    which indentation was lost; the nesting below is inferred from the control
    flow and should be confirmed against the repository.
    """

    def _render_atoms(self):
        """
        Build the SVG fragments for atom labels, charges, isotopes and mapping numbers.

        Returns:
            tuple: ``(svg, mask)`` where ``svg`` is a list of SVG text lines and
            ``mask`` is a ``defaultdict(list)`` with the keys "center",
            "symbols", "span", "other" and "aam" filled as needed.
        """
        bonds = self._bonds
        plane = self._plane
        charges = self._charges
        radicals = self._radicals
        hydrogens = self._hydrogens
        config = self._render_config

        carbon = config["carbon"]
        mapping = config["mapping"]
        span_size = config["span_size"]
        font_size = config["font_size"]
        monochrome = config["monochrome"]
        other_size = config["other_size"]
        atoms_colors = config["atoms_colors"]
        mapping_font = config["mapping_size"]
        dx_m, dy_m = config["dx_m"], config["dy_m"]
        dx_ci, dy_ci = config["dx_ci"], config["dy_ci"]
        symbols_font_style = config["symbols_font_style"]

        # for cumulenes
        try:
            # Collect the inner atoms of every cumulene chain longer than two
            # atoms; these are always drawn with an explicit label below.
            cumulenes = {
                y
                for x in self._cumulenes(heteroatoms=True)
                if len(x) > 2
                for y in x[1:-1]
            }
        except AttributeError:
            cumulenes = set()  # Fallback if _cumulenes is not available or fails

        if monochrome:
            map_fill = other_fill = "black"
        else:
            map_fill = config["mapping_color"]
            other_fill = config["other_color"]

        svg = []
        maps = []
        others = []
        # frequently used fractions of the font size
        font2 = 0.2 * font_size
        font3 = 0.3 * font_size
        font4 = 0.4 * font_size
        font5 = 0.5 * font_size
        font6 = 0.6 * font_size
        font7 = 0.7 * font_size
        font15 = 0.15 * font_size
        font25 = 0.25 * font_size
        mask = defaultdict(list)
        for n, atom in self._atoms.items():
            x, y = plane[n]
            y = -y  # flip the y axis for SVG output

            # --- KEY CHANGE HERE ---
            # Use atom.symbol if it exists, otherwise fallback to atomic_symbol
            try:
                symbol = atom.symbol
            except AttributeError:
                symbol = atom.atomic_symbol  # Fallback if .symbol doesn't exist
            # --- END KEY CHANGE ---

            # A label is drawn for every atom except a plain bonded carbon
            # (unless the "carbon" option forces carbon labels).
            if (
                not bonds.get(n)
                or symbol != "C"
                or carbon
                or atom.charge
                or atom.is_radical
                or atom.isotope
                or n in cumulenes
            ):  # Added bonds.get(n) check for single atoms
                # Calculate hydrogens if the attribute exists, otherwise default to 0
                try:
                    h = hydrogens[n]
                except (KeyError, AttributeError):
                    h = 0  # Default if _hydrogens is missing or key n is not present

                if h == 1:
                    h_str = "H"
                    span = ""
                elif h and h > 1:  # Check if h is not None and greater than 1
                    span = f'<tspan dy="{config["span_dy"]:.2f}" font-size="{span_size:.2f}">{h}</tspan>'
                    h_str = "H"
                else:
                    h_str = ""
                    span = ""

                # Handle charges and radicals safely
                charge_val = charges.get(n, 0)
                is_radical = radicals.get(n, False)

                if charge_val:
                    t = f'{_render_charge.get(charge_val, "")}{"↑" if is_radical else ""}'  # Use .get for safety
                    if t:  # Only add if charge text is generated
                        others.append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">'
                            f"{t}</text>"
                        )
                        mask["other"].append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">'
                            f"{t}</text>"
                        )
                elif is_radical:
                    others.append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}" dy="-{dy_ci:.2f}">↑</text>'
                    )
                    mask["other"].append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="{dx_ci:.2f}"'
                        f' dy="-{dy_ci:.2f}">↑</text>'
                    )

                # Handle isotope safely
                try:
                    iso = atom.isotope
                    if iso:
                        t = iso
                        others.append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_ci:.2f}" dy="-{dy_ci:.2f}" '
                            f'text-anchor="end">{t}</text>'
                        )
                        mask["other"].append(
                            f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_ci:.2f}"'
                            f' dy="-{dy_ci:.2f}" text-anchor="end">{t}</text>'
                        )
                except AttributeError:
                    pass  # Atom might not have isotope attribute

                # Determine atom color based on atomic_number, default to black if monochrome or not found
                atom_color = "black"
                if not monochrome:
                    try:
                        an = atom.atomic_number
                        if 0 < an <= len(atoms_colors):
                            atom_color = atoms_colors[an - 1]
                        else:
                            atom_color = atoms_colors[
                                5
                            ]  # Default to Carbon color if out of range
                    except AttributeError:
                        atom_color = atoms_colors[
                            5
                        ]  # Default to Carbon color if no atomic_number

                svg.append(
                    f' <g fill="{atom_color}" '
                    f'font-family="{symbols_font_style }">'
                )

                # Adjust dx based on symbol length for better centering
                if len(symbol) > 1:
                    dx = font7
                    dx_mm = dx_m + font5
                    if symbol[-1].lower() in (
                        "l",
                        "i",
                        "r",
                        "t",
                    ):  # Heuristic for narrow last letters
                        rx = font6
                        ax = font25
                    else:
                        rx = font7
                        ax = font15
                    mask["center"].append(
                        f' <ellipse cx="{x - ax:.2f}" cy="{y:.2f}" rx="{rx}" ry="{font4}"/>'
                    )
                else:
                    if symbol == "I":  # Special case for 'I'
                        dx = font15
                        dx_mm = dx_m
                    else:  # Single character
                        dx = font4
                        dx_mm = dx_m + font2
                    mask["center"].append(
                        f' <circle cx="{x:.2f}" cy="{y:.2f}" r="{font4:.2f}"/>'
                    )

                svg.append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" dy="{font4:.2f}" '
                    f'font-size="{font_size:.2f}">{symbol}{h_str}{span}</text>'
                )
                mask["symbols"].append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" '
                    f'dy="{font4:.2f}">{symbol}{h_str}</text>'
                )
                if span:
                    mask["span"].append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx:.2f}" dy="{font4:.2f}">'
                        f"{symbol}{h_str}{span}</text>"
                    )
                svg.append(" </g>")

                if mapping:
                    maps.append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m + font3:.2f}" '
                        f'text-anchor="end">{n}</text>'
                    )
                    mask["aam"].append(
                        f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" '
                        f'dy="{dy_m + font3:.2f}" text-anchor="end">{n}</text>'
                    )

            elif mapping:
                # Determine dx_mm for mapping based on symbol length even if atom itself isn't drawn
                if len(symbol) > 1:
                    dx_mm = dx_m + font5
                else:
                    dx_mm = dx_m + font2 if symbol != "I" else dx_m

                maps.append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m:.2f}" '
                    f'text-anchor="end">{n}</text>'
                )
                mask["aam"].append(
                    f' <text x="{x:.2f}" y="{y:.2f}" dx="-{dx_mm:.2f}" dy="{dy_m:.2f}" '
                    f'text-anchor="end">{n}</text>'
                )
        # charges, radicals and isotopes share one group with their own font settings
        if others:
            svg.append(
                f' <g font-family="{config["other_font_style"]}" fill="{other_fill}" '
                f'font-size="{other_size:.2f}">'
            )
            svg.extend(others)
            svg.append(" </g>")
        if mapping:
            svg.append(f' <g fill="{map_fill}" font-size="{mapping_font:.2f}">')
            svg.extend(maps)
            svg.append(" </g>")
        return svg, mask
755
+
756
+
757
def depict_custom_reaction(reaction: ReactionContainer):
    """
    Depict a ReactionContainer using ``CustomDepictMolecule`` atom rendering.

    Every molecule of the reaction that is a ``MoleculeContainer`` or
    ``CGRContainer`` is given a throw-away subclass that puts
    ``CustomDepictMolecule`` first in the MRO, so its ``_render_atoms`` (which
    prefers ``atom.symbol``) is used while the molecule is drawn. The original
    classes are always put back, even if depiction raises. The per-molecule
    drawings, the reaction arrow and the "+" signs are then combined into a
    single SVG document.

    Args:
        reaction (ReactionContainer): The reaction to draw. If it has no arrow
            position yet, ``fix_positions()`` is called on it first.

    Returns:
        str: The SVG markup of the whole reaction.
    """
    if not reaction._arrow:
        reaction.fix_positions()  # Ensure positions are calculated

    r_atoms = []
    r_bonds = []
    r_masks = []
    r_max_x = r_max_y = r_min_y = 0

    # (molecule, original class) pairs. A list is used instead of a dict keyed
    # by the molecule so that restoring works per object (by identity) and does
    # not depend on how the containers implement __hash__/__eq__: two
    # molecules that compare equal would otherwise share one entry and only
    # one of them would get its class back.
    patched = []

    try:
        # Temporarily change the class of molecules to use the custom depiction
        for mol in reaction.molecules():
            if isinstance(mol, (MoleculeContainer, CGRContainer)):
                original_class = mol.__class__
                if CustomDepictMolecule not in original_class.__mro__:
                    # Prioritize CustomDepictMolecule's methods
                    new_bases = (CustomDepictMolecule, original_class)
                else:
                    # CustomDepictMolecule is already an ancestor: rebuild the
                    # bases with it first and without a direct DepictMolecule
                    new_bases = (CustomDepictMolecule,) + tuple(
                        base
                        for base in original_class.__bases__
                        if base is not DepictMolecule
                    )

                # Unique name so that temporary classes never clash
                temp_class = type(
                    f"TempCustom_{original_class.__name__}_{uuid4().hex}",
                    new_bases,
                    {},
                )
                # record before patching so the finally-block can always undo it
                patched.append((mol, original_class))
                mol.__class__ = temp_class

            # Depict using the (potentially) modified class
            atoms, bonds, masks, _min_x, min_y, max_x, max_y = mol.depict(
                embedding=True
            )
            r_atoms.append(atoms)
            r_bonds.append(bonds)
            r_masks.append(masks)
            if max_x > r_max_x:
                r_max_x = max_x
            if max_y > r_max_y:
                r_max_y = max_y
            if min_y < r_min_y:
                r_min_y = min_y

    finally:
        # Restore original classes (last patched first)
        for mol, original_class in reversed(patched):
            mol.__class__ = original_class

    config = DepictMolecule._render_config  # Access via the imported class

    font_size = config["font_size"]
    font125 = 1.25 * font_size
    width = r_max_x + 3.0 * font_size
    height = r_max_y - r_min_y + 2.5 * font_size
    viewbox_x = -font125
    viewbox_y = -r_max_y - font125

    # header, arrow-marker definition and the reaction arrow itself
    svg = [
        f'<svg width="{width:.2f}cm" height="{height:.2f}cm" '
        f'viewBox="{viewbox_x:.2f} {viewbox_y:.2f} {width:.2f} '
        f'{height:.2f}" xmlns="http://www.w3.org/2000/svg" version="1.1">\n'
        ' <defs>\n <marker id="arrow" markerWidth="10" markerHeight="10" '
        'refX="0" refY="3" orient="auto">\n <path d="M0,0 L0,6 L9,3"/>\n </marker>\n </defs>\n'
        f' <line x1="{reaction._arrow[0]:.2f}" y1="0" x2="{reaction._arrow[1]:.2f}" y2="0" '
        'fill="none" stroke="black" stroke-width=".04" marker-end="url(#arrow)"/>'
    ]

    # "+" signs between molecules: one horizontal and one vertical stroke each
    plus_signs = reaction._signs
    if plus_signs:
        svg.append(' <g fill="none" stroke="black" stroke-width=".04">')
        for x in plus_signs:
            svg.append(
                f' <line x1="{x + .35:.2f}" y1="0" x2="{x + .65:.2f}" y2="0"/>'
            )
            svg.append(
                f' <line x1="{x + .5:.2f}" y1="0.15" x2="{x + .5:.2f}" y2="-0.15"/>'
            )
        svg.append(" </g>")

    for atoms, bonds, masks in zip(r_atoms, r_bonds, r_masks):
        # Use the static method from Depict directly
        svg.extend(
            Depict._graph_svg(atoms, bonds, masks, viewbox_x, viewbox_y, width, height)
        )
    svg.append("</svg>")
    return "\n".join(svg)
870
+
871
+
872
def remove_and_shift(nested_dict, to_remove):  # Under development
    """
    Drop selected inner keys from a dict of dicts and close the gaps they leave.

    For every inner dictionary, entries whose key is listed in ``to_remove``
    are discarded. Each surviving key is decreased by the number of removed
    keys that are smaller than it. The keys are therefore shifted down, not
    renumbered from zero: ``{1: "a", 2: "b", 3: "c"}`` with ``to_remove=[2]``
    becomes ``{1: "a", 2: "c"}``. The input is not modified.

    Args:
        nested_dict (dict): The input nested dictionary (dict of dicts) whose
            inner keys support ``<`` and ``-`` with integers.
        to_remove (list): Inner keys to remove. Duplicates are ignored, and
            keys absent from an inner dictionary still count towards the shift
            of larger keys.

    Returns:
        dict: A new nested dictionary with the same outer keys and the
        filtered, shifted inner dictionaries.
    """
    from bisect import bisect_left

    removed = set(to_remove)
    # Sorted once, so the number of removed keys below a given key is a binary
    # search instead of a scan over every removed key.
    removed_sorted = sorted(removed)

    result = {}
    for outer_key, inner in nested_dict.items():
        result[outer_key] = {
            key - bisect_left(removed_sorted, key): value
            for key, value in inner.items()
            if key not in removed
        }
    return result
synplan/chem/reaction_rules/__init__.py ADDED
File without changes
synplan/chem/reaction_rules/extraction.py ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for protocol of reaction rules extraction."""
2
+
3
+ import logging
4
+ import pickle
5
+ from collections import defaultdict
6
+ from itertools import islice
7
+ from os.path import splitext
8
+ from typing import Dict, List, Set, Tuple
9
+
10
+ import ray
11
+ from chython import smarts
12
+ from chython import QueryContainer as QueryContainerChython
13
+ from CGRtools.containers.cgr import CGRContainer
14
+ from CGRtools.containers.molecule import MoleculeContainer
15
+ from CGRtools.containers.query import QueryContainer
16
+ from CGRtools.containers.reaction import ReactionContainer
17
+ from CGRtools.exceptions import InvalidAromaticRing
18
+ from CGRtools.reactor import Reactor
19
+ from tqdm import tqdm
20
+
21
+ from synplan.chem.data.standardizing import RemoveReagentsStandardizer
22
+ from synplan.chem.utils import (
23
+ reverse_reaction,
24
+ cgrtools_to_chython_molecule,
25
+ chython_query_to_cgrtools,
26
+ )
27
+ from synplan.utils.config import RuleExtractionConfig
28
+ from synplan.utils.files import ReactionReader
29
+
30
+
31
def add_environment_atoms(
    cgr: CGRContainer, center_atoms: Set[int], environment_atom_count: int
) -> Set[int]:
    """
    Extends the reaction center with its surrounding atoms up to a given depth.

    :param cgr: The condensed graph of reaction (CGRContainer object) in which the
        environment is looked up.
    :param center_atoms: A set of atom id forming the reaction center.
    :param environment_atom_count: The depth of the environment to include. With 0
        nothing is added; with 1 the first layer of atoms around the center is
        added, and so on.

    :return: A new set with the center atoms and the atoms of the augmented
        substructure, or the very same center_atoms object when the depth is 0.

    """
    # depth 0: nothing to add, hand the caller's set back untouched
    if not environment_atom_count:
        return center_atoms

    neighbourhood = cgr.augmented_substructure(
        center_atoms, deep=environment_atom_count
    )
    # iterating the substructure yields its atom id
    return center_atoms | set(neighbourhood)
58
+
59
+
60
def add_functional_groups(
    reaction: ReactionContainer,
    center_atoms: Set[int],
    func_groups_list: List[QueryContainerChython],
) -> Set[int]:
    """
    Extends the reaction center with functional groups that touch it.

    Every molecule of the reaction is converted with cgrtools_to_chython_molecule
    and searched for each functional group. A matched group is added in full as
    soon as at least one of its atoms belongs to the reaction center.

    :param reaction: The reaction object (ReactionContainer) whose molecules are
        searched.
    :param center_atoms: A set of atom id corresponding to the center atoms of the
        reaction. It is not modified.
    :param func_groups_list: A list of functional group queries. Each query is
        renumbered onto the match while it is inspected and renumbered back
        afterwards (this relies on remap acting in place).

    :return: A new set of atom id with the center atoms and the atoms of every
        functional group match that intersects them.

    """
    extended_atoms = center_atoms.copy()

    for molecule in reaction.molecules():
        chython_molecule = cgrtools_to_chython_molecule(molecule)
        for func_group in func_groups_list:
            for mapping in func_group.get_mapping(chython_molecule):
                # renumber the query so its atom numbers are those of the match
                func_group.remap(mapping)
                matched_atoms = set(func_group.atoms_numbers)
                if matched_atoms & center_atoms:
                    extended_atoms |= matched_atoms
                # undo the renumbering before the next match is requested
                inverse_mapping = {new: old for old, new in mapping.items()}
                func_group.remap(inverse_mapping)

    return extended_atoms
98
+
99
+
100
def add_ring_structures(cgr: CGRContainer, rule_atoms: Set[int]) -> Set[int]:
    """
    Adds whole rings to the rule atoms when a ring shares an atom with them.

    The rings of cgr.sssr are visited once, in order. Atoms added by an earlier
    ring count when later rings are tested, but earlier rings are not revisited.

    :param cgr: A condensed graph representation of a reaction (CGRContainer object).
    :param rule_atoms: A set of atom id of the rule collected so far. It is updated
        in place.

    :return: The same rule_atoms set, extended with the atoms of the included rings.

    """
    for ring in cgr.sssr:
        ring_atoms = set(ring)
        # a ring that touches the rule is taken as a whole
        if not rule_atoms.isdisjoint(ring_atoms):
            rule_atoms |= ring_atoms
    return rule_atoms
119
+
120
+
121
def add_leaving_incoming_groups(
    reaction: ReactionContainer,
    rule_atoms: Set[int],
    keep_leaving_groups: bool,
    keep_incoming_groups: bool,
) -> Tuple[Set[int], Dict[str, Set]]:
    """
    Adds leaving and/or incoming group atoms to the rule atoms.

    Leaving atoms are those present in the reactants but in no product; incoming
    atoms are those present in the products but in no reactant.

    :param reaction: The reaction object (ReactionContainer) whose reactants and
        products are compared.
    :param rule_atoms: A set of atom id of the rule collected so far. It is updated
        in place.
    :param keep_leaving_groups: If True, the leaving atoms are added to the rule.
    :param keep_incoming_groups: If True, the incoming atoms are added to the rule.

    :return: The same rule_atoms set after the additions, and a dictionary with the
        keys "leaving" and "incoming" holding only the atoms that were not yet in
        the rule when they were added.

    """
    # iterating a molecule yields its atom id
    reactant_atoms = set().union(*reaction.reactants)
    product_atoms = set().union(*reaction.products)

    newly_leaving = set()
    newly_incoming = set()

    if keep_leaving_groups:
        leaving_atoms = reactant_atoms - product_atoms
        newly_leaving = leaving_atoms - rule_atoms
        rule_atoms |= leaving_atoms

    # evaluated after the leaving atoms were merged into the rule
    if keep_incoming_groups:
        incoming_atoms = product_atoms - reactant_atoms
        newly_incoming = incoming_atoms - rule_atoms
        rule_atoms |= incoming_atoms

    return rule_atoms, {"leaving": newly_leaving, "incoming": newly_incoming}
170
+
171
+
172
def clean_molecules(
    rule_molecules: List[MoleculeContainer],
    reaction_molecules: Tuple[MoleculeContainer],
    reaction_center_atoms: Set[int],
    atom_retention_details: Dict[str, Dict[str, bool]],
) -> List[QueryContainer]:
    """
    Converts rule molecules to queries and strips the atom marks that should not be
    kept.

    Each rule molecule is matched to the first reaction molecule that contains all
    of its atoms. That reaction molecule is converted to a query, the rule fragment
    is cut out of it and its atoms are cleaned with clean_atom. Rule molecules
    without a containing reaction molecule are left out of the result.

    :param rule_molecules: A list of molecules representing the rule fragments.
    :param reaction_molecules: The molecules of the reaction the fragments were
        taken from.
    :param reaction_center_atoms: A set of atom id of the reaction center.
    :param atom_retention_details: A dictionary with the keys "reaction_center" and
        "environment". Each maps atom attribute names ("neighbors", "hybridization",
        "implicit_hydrogens", "ring_sizes") to booleans: True keeps the attribute,
        False removes it. If every value of a group is True, the atoms of that group
        are left untouched.

    :return: A list of QueryContainer objects representing the cleaned rule molecules.

    """
    cleaned = []

    for rule_molecule in rule_molecules:
        rule_atom_ids = set(rule_molecule.atoms_numbers)

        # the first reaction molecule that contains the whole fragment is its parent
        parent = next(
            (
                candidate
                for candidate in reaction_molecules
                if rule_atom_ids <= set(candidate.atoms_numbers)
            ),
            None,
        )
        if parent is None:
            continue

        # query marks are taken from the full molecule, then the fragment is cut out
        query_fragment = parent.substructure(parent, as_query=True).substructure(
            rule_molecule
        )

        # reaction center atoms first, then the remaining (environment) atoms
        for group, select_atoms in (
            ("reaction_center", rule_atom_ids.intersection),
            ("environment", rule_atom_ids.difference),
        ):
            retention = atom_retention_details[group]
            if all(retention.values()):
                continue  # everything is kept for this group
            for atom_number in select_atoms(reaction_center_atoms):
                query_fragment = clean_atom(query_fragment, retention, atom_number)

        cleaned.append(query_fragment)

    return cleaned
240
+
241
+
242
def clean_atom(
    query_molecule: QueryContainer,
    attributes_to_keep: Dict[str, bool],
    atom_number: int,
) -> QueryContainer:
    """
    Resets the atom marks that are not flagged to be kept for one atom of a query.

    :param query_molecule: The QueryContainer holding the atom. It is modified in
        place.
    :param attributes_to_keep: Dictionary telling which marks of the atom to keep.
        It must contain all of the keys below; True keeps the mark, False sets it
        to None:
        - "neighbors": the neighbors mark of the atom.
        - "hybridization": the hybridization mark of the atom.
        - "implicit_hydrogens": the implicit hydrogens mark of the atom.
        - "ring_sizes": the ring sizes mark of the atom.
    :param atom_number: The number of the atom to be modified in the query molecule.

    :return: The same query_molecule object.

    """
    target_atom = query_molecule.atom(atom_number)

    for attribute in (
        "neighbors",
        "hybridization",
        "implicit_hydrogens",
        "ring_sizes",
    ):
        if not attributes_to_keep[attribute]:
            setattr(target_atom, attribute, None)

    return query_molecule
275
+
276
+
277
def create_substructures_and_reagents(
    reaction: ReactionContainer,
    rule_atoms: Set[int],
    as_query_container: bool,
    keep_reagents: bool,
) -> Tuple[List[MoleculeContainer], List[MoleculeContainer], List]:
    """
    Cuts the rule atoms out of the reactants and products and collects the reagents.

    :param reaction: The reaction object (ReactionContainer) from which to extract
        substructures.
    :param rule_atoms: A set of atom id corresponding to the rule atoms.
    :param as_query_container: Only affects the reagents: if True, each kept
        reagent is converted to a query container.
    :param keep_reagents: If True, the reagents of the reaction are returned;
        otherwise the reagent list is empty.

    :return: A tuple containing three elements:
        - A list with one substructure per reactant that shares atoms with the rule.
        - A list with one substructure per product that shares atoms with the rule.
        - The reagents: an empty list, a list of query containers, or
          reaction.reagents itself as stored on the reaction.

    """

    def cut_out(molecules):
        # one substructure per molecule that shares at least one atom with the rule
        fragments = []
        for molecule in molecules:
            shared_atoms = rule_atoms.intersection(molecule.atoms_numbers)
            if shared_atoms:
                fragments.append(molecule.substructure(shared_atoms))
        return fragments

    reactant_substructures = cut_out(reaction.reactants)
    product_substructures = cut_out(reaction.products)

    if not keep_reagents:
        reagents = []
    elif as_query_container:
        reagents = [
            reagent.substructure(reagent, as_query=True)
            for reagent in reaction.reagents
        ]
    else:
        reagents = reaction.reagents

    return reactant_substructures, product_substructures, reagents
328
+
329
+
330
def assemble_final_rule(
    reactant_substructures: List[QueryContainer],
    product_substructures: List[QueryContainer],
    reagents: List,
    meta_debug: Dict[str, Set],
    keep_metadata: bool,
    reaction: ReactionContainer,
) -> ReactionContainer:
    """
    Assembles the final reaction rule from the prepared substructures and metadata.

    :param reactant_substructures: The substructures placed on the reactant side of
        the rule.
    :param product_substructures: The substructures placed on the product side of
        the rule.
    :param reagents: The reagents of the rule (may be empty).
    :param meta_debug: Additional metadata collected during rule extraction. It is
        only read, never modified.
    :param keep_metadata: If True, the rule metadata consists of meta_debug updated
        with the metadata of the original reaction (reaction keys win on conflicts)
        and the rule takes the name of the reaction. If False, the rule gets empty
        metadata.
    :param reaction: The original reaction the rule was extracted from.

    :return: The assembled rule as a ReactionContainer with its cache flushed.

    """

    if keep_metadata:
        # copy first: updating meta_debug in place would leak the reaction
        # metadata into the dictionary owned by the caller
        rule_metadata = dict(meta_debug)
        rule_metadata.update(reaction.meta)
    else:
        rule_metadata = {}

    rule = ReactionContainer(
        reactant_substructures, product_substructures, reagents, rule_metadata
    )

    if keep_metadata:
        rule.name = reaction.name

    rule.flush_cache()
    return rule
378
+
379
+
380
def validate_rule(rule: ReactionContainer, reaction: ReactionContainer) -> bool:
    """
    Checks whether a rule reproduces the products of the reaction it was extracted
    from. The rule is applied to the reactants of the reaction with a Reactor. For
    every generated reaction, products that cannot be kekulized or that report
    valence problems are discarded, and the remaining products are compared with the
    products of the original reaction.

    :param rule: The reaction rule to be validated.
    :param reaction: The original reaction providing the reactants to apply the rule
        to and the products to compare against.

    :return: True as soon as one generated reaction has exactly the products of the
        original reaction (same set and same number of molecules). False if no
        generated reaction matches, or if a KeyError or IndexError is raised while
        the reactor output is being processed.

    """

    def _is_sound(molecule) -> bool:
        # test a copy so the molecule produced by the reactor stays untouched
        probe = molecule.copy()
        try:
            probe.kekule()
            # check_valence() gives a truthy result when problems are found
            return not probe.check_valence()
        except InvalidAromaticRing:
            return False

    # the reactor is built outside the try block: its errors are not swallowed
    reactor = Reactor(rule)
    try:
        for outcome in reactor(reaction.reactants):
            kept_products = [
                product for product in outcome.products if _is_sound(product)
            ]
            if set(reaction.products) == set(kept_products) and len(
                reaction.products
            ) == len(kept_products):
                return True
    except (KeyError, IndexError):
        # KeyError - iteration over reactor is finished and products are different from the original reaction
        # IndexError - mistake in __contract_ions, possibly problems with charges in reaction rule
        return False

    return False
428
+
429
+
430
def create_rule(
    config: RuleExtractionConfig, reaction: ReactionContainer
) -> ReactionContainer:
    """
    Builds a reaction rule from a single reaction according to the configuration.
    The steps are: collect the reaction center from the CGR, extend it with the
    environment, optional functional groups and rings, add leaving/incoming groups,
    cut the corresponding substructures, optionally clean them, assemble the rule,
    optionally reverse it and optionally validate it with a reactor.

    :param config: A RuleExtractionConfig instance. The fields read here are
        environment_atom_count, include_func_groups, func_groups_list, include_rings,
        keep_leaving_groups, keep_incoming_groups, as_query_container, keep_reagents,
        atom_info_retention, keep_metadata, reverse_rule and reactor_validation.
    :param reaction: The reaction from which the rule is created.
    :return: The extracted rule. If config.reactor_validation is set, its metadata
        holds the key "reactor_validation" with the value "passed" or "failed".

    """

    # 1. reaction CGR and its center atoms
    reaction_cgr = ~reaction

    # 2. center atoms extended by the configured environment
    core_atoms = add_environment_atoms(
        reaction_cgr, set(reaction_cgr.center_atoms), config.environment_atom_count
    )

    # 3. functional groups, only when requested and a list of groups is provided
    if config.include_func_groups and config.func_groups_list:
        selected_atoms = add_functional_groups(
            reaction, core_atoms, config.func_groups_list
        )
    else:
        selected_atoms = core_atoms.copy()

    # 4. ring structures
    if config.include_rings:
        selected_atoms = add_ring_structures(reaction_cgr, selected_atoms)

    # 5. leaving and incoming groups
    selected_atoms, meta_debug = add_leaving_incoming_groups(
        reaction,
        selected_atoms,
        config.keep_leaving_groups,
        config.keep_incoming_groups,
    )

    # 6. substructures of reactants and products, plus reagents
    reactant_parts, product_parts, reagents = create_substructures_and_reagents(
        reaction, selected_atoms, config.as_query_container, config.keep_reagents
    )

    # 7. cleaning is done only for query containers; note that it receives the
    #    atoms from step 2, not the full selection
    if config.as_query_container:
        reactant_parts = clean_molecules(
            reactant_parts,
            reaction.reactants,
            core_atoms,
            config.atom_info_retention,
        )
        product_parts = clean_molecules(
            product_parts,
            reaction.products,
            core_atoms,
            config.atom_info_retention,
        )

    # 8. final rule, with metadata if requested
    rule = assemble_final_rule(
        reactant_parts,
        product_parts,
        reagents,
        meta_debug,
        config.keep_metadata,
        reaction,
    )

    # 9. reverse both the rule and the reaction so that validation stays consistent
    if config.reverse_rule:
        rule = reverse_reaction(rule)
        reaction = reverse_reaction(reaction)

    # 10. reactor validation
    if config.reactor_validation:
        rule.meta["reactor_validation"] = (
            "passed" if validate_rule(rule, reaction) else "failed"
        )

    return rule
525
+
526
+
527
def extract_rules(
    config: RuleExtractionConfig, reaction: ReactionContainer
) -> List[ReactionContainer]:
    """
    Extracts the reaction rules of a single reaction.

    :param config: A RuleExtractionConfig instance. config.multicenter_rules selects
        between one rule for the whole reaction and one rule per reaction center.
    :param reaction: The reaction to extract rules from. It is first passed through
        RemoveReagentsStandardizer.
    :return: A list of rules. If config.multicenter_rules is True, the list holds a
        single rule created from the whole reaction. Otherwise, a rule is created for
        each of at most the first 15 reactions given by reaction.enumerate_centers()
        and duplicates are dropped; the order of the list is not defined.

    """

    # NOTE(review): judging by its name the standardizer strips reagents - confirm
    reaction = RemoveReagentsStandardizer()(reaction)

    if config.multicenter_rules:
        # one rule covering all reaction centers
        return [create_rule(config, reaction)]

    # one rule per reaction center; the set removes equal rules
    return list(
        {
            create_rule(config, center_reaction)
            for center_reaction in islice(reaction.enumerate_centers(), 15)
        }
    )
564
+
565
+
566
@ray.remote
def process_reaction_batch(
    batch: List[Tuple[int, ReactionContainer]], config: RuleExtractionConfig
) -> List[Tuple[int, List[ReactionContainer]]]:
    """
    Ray remote task that extracts the rules of every reaction in a batch.

    :param batch: A list of (index, reaction) tuples.
    :param config: The RuleExtractionConfig instance passed on to extract_rules.
    :return: A list of (index, rules) tuples for the reactions that were processed
        successfully. A reaction whose extraction raises an exception is logged at
        debug level and left out of the result.

    """

    results = []
    for position, single_reaction in batch:
        try:
            results.append((position, extract_rules(config, single_reaction)))
        except Exception as error:
            # best effort: one failing reaction must not abort the whole batch
            logging.debug(error)
    return results
600
+
601
+
602
def process_completed_batch(
    futures: Dict,
    rules_statistics: Dict,
) -> None:
    """
    Waits for one pending batch to finish and merges its rules into the statistics.
    For every extracted rule the index of its reaction is appended to the list stored
    under that rule. When a rule is seen for the first time, the index is also stored
    in its metadata under "first_reaction_index". The finished future is removed from
    futures.

    :param futures: A dictionary whose keys are the Ray futures of the running
        batches. It is modified in place.
    :param rules_statistics: A mapping from rule to the list of reaction indices it
        was extracted from. It is modified in place and must create the list for an
        unknown rule on access (e.g. a defaultdict(list)).
    :return: None

    """

    finished, _ = ray.wait(list(futures.keys()), num_returns=1)
    done_future = finished[0]

    for index, extracted_rules in ray.get(done_future):
        for rule in extracted_rules:
            seen_before = rule in rules_statistics
            rules_statistics[rule].append(index)
            if not seen_before:
                rule.meta["first_reaction_index"] = index

    del futures[done_future]
629
+
630
+
631
def sort_rules(
    rules_stats: Dict, min_popularity: int, single_reactant_only: bool
) -> List[Tuple[ReactionContainer, List[int]]]:
    """
    Filters the extracted rules and sorts them by popularity, i.e. by the number of
    reactions they were extracted from.

    :param rules_stats: A dictionary mapping each rule to the list of indices of the
        reactions it was extracted from.
    :param min_popularity: The minimum number of indices a rule must have to be kept.
    :param single_reactant_only: If True, only rules with exactly one molecule in
        rule.reactants are kept.

    :return: A list of (rule, indices) tuples sorted by descending number of indices;
        rules with the same popularity keep their order from rules_stats. Only rules
        whose metadata entry "reactor_validation" equals "passed" are included. The
        entry is looked up for every rule that is popular enough, so such a rule
        without this entry raises a KeyError.

    """

    def _is_eligible(rule, indices) -> bool:
        # the checks run in this order on purpose: metadata is only read for
        # rules that are popular enough
        if len(indices) < min_popularity:
            return False
        if rule.meta["reactor_validation"] != "passed":
            return False
        return not single_reactant_only or len(rule.reactants) == 1

    selected = [
        (rule, indices)
        for rule, indices in rules_stats.items()
        if _is_eligible(rule, indices)
    ]
    # a stable sort: reverse=True keeps the original order of equal elements
    selected.sort(key=lambda entry: len(entry[1]), reverse=True)
    return selected
666
+
667
+
668
def extract_rules_from_reactions(
    config: RuleExtractionConfig,
    reaction_data_path: str,
    reaction_rules_path: str,
    num_cpus: int,
    batch_size: int,
) -> None:
    """
    Extracts reaction rules from a file of reactions in parallel with Ray. The
    reactions are read one by one, grouped into batches and submitted as remote
    tasks; at most num_cpus batches are pending at any time. The collected rules are
    filtered and sorted with sort_rules and pickled as a list of (rule, indices)
    tuples.

    :param config: The RuleExtractionConfig instance used for the extraction.
        config.min_popularity and config.single_reactant_only control the final
        filtering.
    :param reaction_data_path: Path to the file containing the reactions.
    :param reaction_rules_path: Path of the result file. Its extension is replaced
        by ".pickle".
    :param num_cpus: Number of CPUs given to Ray; also the maximum number of
        concurrently pending batches.
    :param batch_size: Number of reactions per batch.
    :return: None

    """

    ray.init(num_cpus=num_cpus, ignore_reinit_error=True, logging_level=logging.ERROR)

    reaction_rules_path, _ = splitext(reaction_rules_path)
    futures = {}
    batch = []
    max_concurrent_batches = num_cpus
    extracted_rules_and_statistics = defaultdict(list)

    # try/finally guarantees that Ray is shut down even if reading the reactions
    # or processing a batch fails
    try:
        with ReactionReader(reaction_data_path) as reactions:
            for index, reaction in tqdm(
                enumerate(reactions),
                desc="Number of reactions processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):
                batch.append((index, reaction))
                if len(batch) == batch_size:
                    futures[process_reaction_batch.remote(batch, config)] = None
                    batch = []

                    # throttle: collect results before submitting more work
                    while len(futures) >= max_concurrent_batches:
                        process_completed_batch(
                            futures,
                            extracted_rules_and_statistics,
                        )

            # the last batch may be smaller than batch_size
            if batch:
                futures[process_reaction_batch.remote(batch, config)] = None

            # drain the remaining tasks
            while futures:
                process_completed_batch(
                    futures,
                    extracted_rules_and_statistics,
                )
    finally:
        ray.shutdown()

    sorted_rules = sort_rules(
        extracted_rules_and_statistics,
        min_popularity=config.min_popularity,
        single_reactant_only=config.single_reactant_only,
    )

    with open(f"{reaction_rules_path}.pickle", "wb") as statistics_file:
        pickle.dump(sorted_rules, statistics_file)

    print(f"Number of extracted reaction rules: {len(sorted_rules)}")
synplan/chem/reaction_rules/manual/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .decompositions import rules as d_rules
2
+ from .transformations import rules as t_rules
3
+
4
# all manually written rules in one list: transformation rules first,
# decomposition rules after them
hardcoded_rules = t_rules + d_rules

__all__ = ["hardcoded_rules"]
synplan/chem/reaction_rules/manual/decompositions.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing hardcoded decomposition reaction rules."""
2
+
3
+ from CGRtools import QueryContainer, ReactionContainer
4
+ from CGRtools.periodictable import ListElement
5
+
6
+ rules = []
7
+
8
+
9
def prepare():
    """Registers a new, still empty decomposition rule (one query on the left side,
    two queries on the right side) in the "rules" list and returns its three query
    containers so that the caller can fill them with atoms and bonds."""
    query, first_part, second_part = (
        QueryContainer(),
        QueryContainer(),
        QueryContainer(),
    )
    rules.append(ReactionContainer((query,), (first_part, second_part)))
    return query, first_part, second_part
18
+
19
+
20
# Every rule below is built the same way: prepare() registers an empty rule and
# returns its containers (q - the matched pattern, p1 and p2 - the two resulting
# fragments), which are then filled with atoms and bonds.
# Atoms are referenced in add_bond() by number: atoms added without _map appear to
# be numbered consecutively, _map sets the number explicitly. The third argument of
# add_bond() is the bond order (4 is used for the aromatic bonds).
# Legend of the pattern comments, as used together with the calls in this file:
# D - neighbors, Z - hybridization (Zs = 1, Zd = 2, Za = 4), W - heteroatoms,
# r - rings_sizes.

# R-amide/ester formation
# [C](-[N,O;D23;Zs])(-[C])=[O]>>[A].[C]-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom(ListElement(["N", "O"]), hybridization=1, neighbors=(2, 3))
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O", _map=5)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 2)
p1.add_bond(2, 5, 1)

p2.add_atom("A", _map=4)

# acyl group addition with aromatic carbon's case (Friedel-Crafts)
# [C;Za]-[C](-[C])=[O]>>[C].[C]-[C](-[Cl])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("C", hybridization=4)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("Cl", _map=5)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 2)
p1.add_bond(2, 5, 1)

p2.add_atom("C", _map=4)

# Williamson reaction
# [C;Za]-[O]-[C;Zs;W1]>>[C]-[Br].[C]-[O]
# (W1: the alkyl carbon is added with heteroatoms=1)
q, p1, p2 = prepare()
q.add_atom("C", hybridization=4)
q.add_atom("O")
q.add_atom("C", hybridization=1, heteroatoms=1)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)

p1.add_atom("C")
p1.add_atom("O")
p1.add_bond(1, 2, 1)

p2.add_atom("C", _map=3)
p2.add_atom("Br")
p2.add_bond(3, 4, 1)

# Buchwald-Hartwig amination
# [N;D23;Zs;W0]-[C;Za]>>[C]-[Br].[N]
q, p1, p2 = prepare()
q.add_atom("N", heteroatoms=0, hybridization=1, neighbors=(2, 3))
q.add_atom("C", hybridization=4)
q.add_bond(1, 2, 1)

p1.add_atom("C", _map=2)
p1.add_atom("Br")
p1.add_bond(2, 3, 1)

p2.add_atom("N")

# imidazole imine atom's alkylation
# [C;r5](:[N;r5]-[C;Zs;W12]):[N;D2;r5]>>[C]-[Br].[N]:[C]:[N]
# (W12: the alkyl carbon is added with heteroatoms=(1, 2))
q, p1, p2 = prepare()
q.add_atom("N", rings_sizes=5)
q.add_atom("C", rings_sizes=5)
q.add_atom("N", rings_sizes=5, neighbors=2)
q.add_atom("C", hybridization=1, heteroatoms=(1, 2))
q.add_bond(1, 2, 4)
q.add_bond(2, 3, 4)
q.add_bond(1, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 4)
p1.add_bond(2, 3, 4)

p2.add_atom("C", _map=4)
p2.add_atom("Br")
p2.add_bond(4, 5, 1)

# Knoevenagel condensation (nitryl and carboxyl case)
# [C]=[C](-[C]#[N])-[C](-[O])=[O]>>[C]=[O].[C](-[C]#[N])-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 3)
q.add_bond(2, 5, 1)
q.add_bond(5, 6, 2)
q.add_bond(5, 7, 1)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 3)
p1.add_bond(2, 5, 1)
p1.add_bond(5, 6, 2)
p1.add_bond(5, 7, 1)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=8)
p2.add_bond(1, 8, 2)

# Knoevenagel condensation (double nitryl case)
# [C]=[C](-[C]#[N])-[C]#[N]>>[C]=[O].[C](-[C]#[N])-[C]#[N]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("N")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 3)
q.add_bond(2, 5, 1)
q.add_bond(5, 6, 3)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 3)
p1.add_bond(2, 5, 1)
p1.add_bond(5, 6, 3)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=8)
p2.add_bond(1, 8, 2)

# Knoevenagel condensation (double carboxyl case)
# [C]=[C](-[C](-[O])=[O])-[C](-[O])=[O]>>[C]=[O].[C](-[C](-[O])=[O])-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_atom("C")
q.add_atom("O")
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 2)
q.add_bond(3, 5, 1)
q.add_bond(2, 6, 1)
q.add_bond(6, 7, 2)
q.add_bond(6, 8, 1)

p1.add_atom("C", _map=2)
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O")
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 2)
p1.add_bond(3, 5, 1)
p1.add_bond(2, 6, 1)
p1.add_bond(6, 7, 2)
p1.add_bond(6, 8, 1)

p2.add_atom("C", _map=1)
p2.add_atom("O", _map=9)
p2.add_bond(1, 9, 2)

# heterocyclization with guanidine
# [c]((-[N;W0;Zs])@[n]@[c](-[N;D1])@[c;W0])@[n]@[c]-[O; D1]>>[C](-[N])(=[N])-[N].[C](#[N])-[C]-[C](-[O])=[O]
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("N", heteroatoms=0, hybridization=1)
q.add_atom("N")
q.add_atom("C")
q.add_atom("N", neighbors=1)
q.add_atom("C", heteroatoms=0)
q.add_atom("N")
q.add_atom("C")
q.add_atom("O", neighbors=1)
q.add_bond(1, 2, 1)
q.add_bond(1, 3, 4)
q.add_bond(3, 4, 4)
q.add_bond(4, 5, 1)
q.add_bond(4, 6, 4)
q.add_bond(1, 7, 4)
q.add_bond(7, 8, 4)
q.add_bond(8, 9, 1)

p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("N")
p1.add_atom("N", _map=7)
p1.add_bond(1, 2, 1)
p1.add_bond(1, 3, 2)
p1.add_bond(1, 7, 1)

p2.add_atom("C", _map=4)
p2.add_atom("N")
p2.add_atom("C")
p2.add_atom("C", _map=8)
p2.add_atom("O", _map=9)
p2.add_atom("O")
p2.add_bond(4, 5, 3)
p2.add_bond(4, 6, 1)
p2.add_bond(6, 8, 1)
p2.add_bond(8, 9, 2)
p2.add_bond(8, 10, 1)

# alkylation of amine
# [C]-[N](-[C])-[C]>>[C]-[N]-[C].[C]-[Cl]
# (the pattern is a nitrogen with three carbon substituents; the halide is Cl)
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("C")
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("N")
p1.add_atom("C")
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 1)

p2.add_atom("C", _map=4)
p2.add_atom("Cl")
p2.add_bond(4, 5, 1)

# Synthesis of guanidines
# [N]=[C](-[N;Zs])-[N;Zs]>>[N]#[C]-[N].[N]   (derived from the calls below)
q, p1, p2 = prepare()
q.add_atom("N")
q.add_atom("C")
q.add_atom("N", hybridization=1)
q.add_atom("N", hybridization=1)
q.add_bond(1, 2, 2)
q.add_bond(2, 3, 1)
q.add_bond(2, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 3)
p1.add_bond(2, 3, 1)

p2.add_atom("N", _map=4)

# Grignard reaction with nitrile
# [C]-[C](=[O])-[C]>>[C]-[C]#[N].[C]-[Br]   (derived from the calls below)
# atom 3 is the carbonyl oxygen in q and the nitrile nitrogen in p1
q, p1, p2 = prepare()
q.add_atom("C")
q.add_atom("C")
q.add_atom("O")
q.add_atom("C")
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 2)
q.add_bond(2, 4, 1)

p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("N")
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 3)

p2.add_atom("C", _map=4)
p2.add_atom("Br")
p2.add_bond(4, 5, 1)

# Alkylation of alpha-carbon atom of nitrile
# [N]#[C]-[C;D34]-[C;Zs]>>[N]#[C]-[C].[C]-[Cl]   (derived from the calls below)
q, p1, p2 = prepare()
q.add_atom("N")
q.add_atom("C")
q.add_atom("C", neighbors=(3, 4))
q.add_atom("C", hybridization=1)
q.add_bond(1, 2, 3)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("C")
p1.add_bond(1, 2, 3)
p1.add_bond(2, 3, 1)

p2.add_atom("C", _map=4)
p2.add_atom("Cl")
p2.add_bond(4, 5, 1)

# Gomberg-Bachmann reaction
# [C;Za;W0]-[C;Za;W0]>>[C]-[N].[C]   (derived from the calls below)
q, p1, p2 = prepare()
q.add_atom("C", hybridization=4, heteroatoms=0)
q.add_atom("C", hybridization=4, heteroatoms=0)
q.add_bond(1, 2, 1)

p1.add_atom("C")
p1.add_atom("N", _map=3)
p1.add_bond(1, 3, 1)

p2.add_atom("C", _map=2)

# Cyclocondensation
# [N;D2]1-[C]-[C]-[C]=[N]-[C]-[C]1=[O;D1]>>[N]-[C]-[C]-[C]=[O].[N]-[C]-[C](=[O])-[O]
# (derived from the calls below)
q, p1, p2 = prepare()
q.add_atom("N", neighbors=2)
q.add_atom("C")
q.add_atom("C")
q.add_atom("C")
q.add_atom("N")
q.add_atom("C")
q.add_atom("C")
q.add_atom("O", neighbors=1)
q.add_bond(1, 2, 1)
q.add_bond(2, 3, 1)
q.add_bond(3, 4, 1)
q.add_bond(4, 5, 2)
q.add_bond(5, 6, 1)
q.add_bond(6, 7, 1)
q.add_bond(7, 8, 2)
q.add_bond(1, 7, 1)

p1.add_atom("N")
p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("C")
p1.add_atom("O", _map=9)
p1.add_bond(1, 2, 1)
p1.add_bond(2, 3, 1)
p1.add_bond(3, 4, 1)
p1.add_bond(4, 9, 2)

p2.add_atom("N", _map=5)
p2.add_atom("C")
p2.add_atom("C")
p2.add_atom("O")
p2.add_atom("O", _map=10)
p2.add_bond(5, 6, 1)
p2.add_bond(6, 7, 1)
p2.add_bond(7, 8, 2)
p2.add_bond(7, 10, 1)

# heterocyclization dicarboxylic acids
# [C;r56](=[O])-[O,N]-[C;r56]=[O]>>[C](=[O])-[O].[C](=[O])-[O]
# (derived from the calls below)
q, p1, p2 = prepare()
q.add_atom("C", rings_sizes=(5, 6))
q.add_atom("O")
q.add_atom(ListElement(["O", "N"]))
q.add_atom("C", rings_sizes=(5, 6))
q.add_atom("O")
q.add_bond(1, 2, 2)
q.add_bond(1, 3, 1)
q.add_bond(3, 4, 1)
q.add_bond(4, 5, 2)

p1.add_atom("C")
p1.add_atom("O")
p1.add_atom("O", _map=6)
p1.add_bond(1, 2, 2)
p1.add_bond(1, 6, 1)

p2.add_atom("C", _map=4)
p2.add_atom("O")
p2.add_atom("O", _map=7)
p2.add_bond(4, 5, 2)
p2.add_bond(4, 7, 1)

__all__ = ["rules"]
synplan/chem/reaction_rules/manual/transformations.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing hardcoded transformation reaction rules."""
2
+
3
+ from CGRtools import QueryContainer, ReactionContainer
4
+ from CGRtools.periodictable import ListElement
5
+
6
+ rules = []
7
+
8
+
9
def prepare():
    """Registers a new, still empty transformation rule (one query on each side) in
    the "rules" list and returns its two query containers so that the caller can
    fill them with atoms and bonds."""
    query, outcome = QueryContainer(), QueryContainer()
    rules.append(ReactionContainer((query,), (outcome,)))
    return query, outcome
16
+
17
+
18
+ # aryl nitro reduction
19
+ # [C;Za;W1]-[N;D1]>>[O-]-[N+](-[C])=[O]
20
+ q, p = prepare()
21
+ q.add_atom("N", neighbors=1)
22
+ q.add_atom("C", hybridization=4, heteroatoms=1)
23
+ q.add_bond(1, 2, 1)
24
+
25
+ p.add_atom("N", charge=1)
26
+ p.add_atom("C")
27
+ p.add_atom("O", charge=-1)
28
+ p.add_atom("O")
29
+ p.add_bond(1, 2, 1)
30
+ p.add_bond(1, 3, 1)
31
+ p.add_bond(1, 4, 2)
32
+
33
+ # aryl nitration
34
+ # [O-]-[N+](=[O])-[C;Za;W12]>>[C]
35
+ q, p = prepare()
36
+ q.add_atom("N", charge=1)
37
+ q.add_atom("C", hybridization=4, heteroatoms=(1, 2))
38
+ q.add_atom("O", charge=-1)
39
+ q.add_atom("O")
40
+ q.add_bond(1, 2, 1)
41
+ q.add_bond(1, 3, 1)
42
+ q.add_bond(1, 4, 2)
43
+
44
+ p.add_atom("C", _map=2)
45
+
46
+ # Beckmann rearrangement (oxime -> amide)
47
+ # [C]-[N;D2]-[C]=[O]>>[O]-[N]=[C]-[C]
48
+ q, p = prepare()
49
+ q.add_atom("C")
50
+ q.add_atom("N", neighbors=2)
51
+ q.add_atom("O")
52
+ q.add_atom("C")
53
+ q.add_bond(1, 2, 1)
54
+ q.add_bond(1, 3, 2)
55
+ q.add_bond(2, 4, 1)
56
+
57
+ p.add_atom("C")
58
+ p.add_atom("N")
59
+ p.add_atom("O")
60
+ p.add_atom("C")
61
+ p.add_bond(1, 2, 2)
62
+ p.add_bond(2, 3, 1)
63
+ p.add_bond(1, 4, 1)
64
+
65
+ # aldehydes or ketones into oxime/imine reaction
66
+ # [C;Zd;W1]=[N]>>[C]=[O]
67
+ q, p = prepare()
68
+ q.add_atom("C", hybridization=2, heteroatoms=1)
69
+ q.add_atom("N")
70
+ q.add_bond(1, 2, 2)
71
+
72
+ p.add_atom("C")
73
+ p.add_atom("O", _map=3)
74
+ p.add_bond(1, 3, 2)
75
+
76
+ # addition of halogen atom into phenol ring (ortho)
77
+ # [C](-[Cl,F,Br,I;D1]):[C]-[O,N;Zs]>>[C](-[A]):[C]
78
+ q, p = prepare()
79
+ q.add_atom(ListElement(["O", "N"]), hybridization=1)
80
+ q.add_atom("C")
81
+ q.add_atom("C")
82
+ q.add_atom(ListElement(["Cl", "F", "Br", "I"]), neighbors=1)
83
+ q.add_bond(1, 2, 1)
84
+ q.add_bond(2, 3, 4)
85
+ q.add_bond(3, 4, 1)
86
+
87
+ p.add_atom("A")
88
+ p.add_atom("C")
89
+ p.add_atom("C")
90
+ p.add_bond(1, 2, 1)
91
+ p.add_bond(2, 3, 4)
92
+
93
+ # addition of halogen atom into phenol ring (para)
94
+ # [C](:[C]:[C]:[C]-[O,N;Zs])-[Cl,F,Br,I;D1]>>[A]-[C]:[C]:[C]:[C]
95
+ q, p = prepare()
96
+ q.add_atom(ListElement(["O", "N"]), hybridization=1)
97
+ q.add_atom("C")
98
+ q.add_atom("C")
99
+ q.add_atom("C")
100
+ q.add_atom("C")
101
+ q.add_atom(ListElement(["Cl", "F", "Br", "I"]), neighbors=1)
102
+ q.add_bond(1, 2, 1)
103
+ q.add_bond(2, 3, 4)
104
+ q.add_bond(3, 4, 4)
105
+ q.add_bond(4, 5, 4)
106
+ q.add_bond(5, 6, 1)
107
+
108
+ p.add_atom("A")
109
+ p.add_atom("C")
110
+ p.add_atom("C")
111
+ p.add_atom("C")
112
+ p.add_atom("C")
113
+ p.add_bond(1, 2, 1)
114
+ p.add_bond(2, 3, 4)
115
+ p.add_bond(3, 4, 4)
116
+ p.add_bond(4, 5, 4)
117
+
118
+ # hard reduction of Ar-ketones
119
+ # [C;Za]-[C;D2;Zs;W0]>>[C]-[C]=[O]
120
+ q, p = prepare()
121
+ q.add_atom("C", hybridization=4)
122
+ q.add_atom("C", hybridization=1, neighbors=2, heteroatoms=0)
123
+ q.add_bond(1, 2, 1)
124
+
125
+ p.add_atom("C")
126
+ p.add_atom("C")
127
+ p.add_atom("O")
128
+ p.add_bond(1, 2, 1)
129
+ p.add_bond(2, 3, 2)
130
+
131
+ # reduction of alpha-hydroxy pyridine
132
+ # [C;W1]:[N;H0;r6]>>[C](:[N])-[O]
133
+ q, p = prepare()
134
+ q.add_atom("C", heteroatoms=1)
135
+ q.add_atom("N", rings_sizes=6, hydrogens=0)
136
+ q.add_bond(1, 2, 4)
137
+
138
+ p.add_atom("C")
139
+ p.add_atom("N")
140
+ p.add_atom("O")
141
+ p.add_bond(1, 2, 4)
142
+ p.add_bond(1, 3, 1)
143
+
144
+ # Reduction of alkene
145
+ # [C]-[C;D23;Zs;W0]-[C;D123;Zs;W0]>>[C](-[C])=[C]
146
+ q, p = prepare()
147
+ q.add_atom("C")
148
+ q.add_atom("C", heteroatoms=0, neighbors=(2, 3), hybridization=1)
149
+ q.add_atom("C", heteroatoms=0, neighbors=(1, 2, 3), hybridization=1)
150
+ q.add_bond(1, 2, 1)
151
+ q.add_bond(2, 3, 1)
152
+
153
+ p.add_atom("C")
154
+ p.add_atom("C")
155
+ p.add_atom("C")
156
+ p.add_bond(1, 2, 1)
157
+ p.add_bond(2, 3, 2)
158
+
159
+ # Kolbe-Schmitt reaction
160
+ # [C](:[C]-[O;D1])-[C](=[O])-[O;D1]>>[C](-[O]):[C]
161
+ q, p = prepare()
162
+ q.add_atom("O", neighbors=1)
163
+ q.add_atom("C")
164
+ q.add_atom("C")
165
+ q.add_atom("C")
166
+ q.add_atom("O", neighbors=1)
167
+ q.add_atom("O")
168
+ q.add_bond(1, 2, 1)
169
+ q.add_bond(2, 3, 4)
170
+ q.add_bond(3, 4, 1)
171
+ q.add_bond(4, 5, 1)
172
+ q.add_bond(4, 6, 2)
173
+
174
+ p.add_atom("O")
175
+ p.add_atom("C")
176
+ p.add_atom("C")
177
+ p.add_bond(1, 2, 1)
178
+ p.add_bond(2, 3, 4)
179
+
180
+ # reduction of carboxylic acid
181
+ # [O;D1]-[C;D2]-[C]>>[C]-[C](-[O])=[O]
182
+ q, p = prepare()
183
+ q.add_atom("C")
184
+ q.add_atom("C", neighbors=2)
185
+ q.add_atom("O", neighbors=1)
186
+ q.add_bond(1, 2, 1)
187
+ q.add_bond(2, 3, 1)
188
+
189
+ p.add_atom("C")
190
+ p.add_atom("C")
191
+ p.add_atom("O")
192
+ p.add_atom("O")
193
+ p.add_bond(1, 2, 1)
194
+ p.add_bond(2, 3, 1)
195
+ p.add_bond(2, 4, 2)
196
+
197
+ # halogenation of alcohols
198
+ # [C;Zs]-[Cl,Br;D1]>>[C]-[O]
199
+ q, p = prepare()
200
+ q.add_atom("C", hybridization=1, heteroatoms=1)
201
+ q.add_atom(ListElement(["Cl", "Br"]), neighbors=1)
202
+ q.add_bond(1, 2, 1)
203
+
204
+ p.add_atom("C")
205
+ p.add_atom("O", _map=3)
206
+ p.add_bond(1, 3, 1)
207
+
208
+ # Kolbe nitrilation
209
+ # [N]#[C]-[C;Zs;W0]>>[Br]-[C]
210
+ q, p = prepare()
211
+ q.add_atom("C", heteroatoms=0, hybridization=1)
212
+ q.add_atom("C")
213
+ q.add_atom("N")
214
+ q.add_bond(1, 2, 1)
215
+ q.add_bond(2, 3, 3)
216
+
217
+ p.add_atom("C")
218
+ p.add_atom("Br", _map=4)
219
+ p.add_bond(1, 4, 1)
220
+
221
+ # Nitrile hydrolysis
222
+ # [O;D1]-[C]=[O]>>[N]#[C]
223
+ q, p = prepare()
224
+ q.add_atom("C")
225
+ q.add_atom("O", neighbors=1)
226
+ q.add_atom("O")
227
+ q.add_bond(1, 2, 1)
228
+ q.add_bond(1, 3, 2)
229
+
230
+ p.add_atom("C")
231
+ p.add_atom("N", _map=4)
232
+ p.add_bond(1, 4, 3)
233
+
234
+ # sulfamidation
235
+ # [c]-[S](=[O])(=[O])-[N]>>[c]
236
+ q, p = prepare()
237
+ q.add_atom("C", hybridization=4)
238
+ q.add_atom("S")
239
+ q.add_atom("O")
240
+ q.add_atom("O")
241
+ q.add_atom("N", neighbors=1)
242
+ q.add_bond(1, 2, 1)
243
+ q.add_bond(2, 3, 2)
244
+ q.add_bond(2, 4, 2)
245
+ q.add_bond(2, 5, 1)
246
+
247
+ p.add_atom("C")
248
+
249
+ # Ring expansion rearrangement
250
+ #
251
+ q, p = prepare()
252
+ q.add_atom("C")
253
+ q.add_atom("N")
254
+ q.add_atom("C", rings_sizes=6)
255
+ q.add_atom("C")
256
+ q.add_atom("O")
257
+ q.add_atom("C")
258
+ q.add_atom("C")
259
+ q.add_bond(1, 2, 1)
260
+ q.add_bond(2, 3, 1)
261
+ q.add_bond(3, 4, 1)
262
+ q.add_bond(4, 5, 2)
263
+ q.add_bond(3, 6, 1)
264
+ q.add_bond(4, 7, 1)
265
+
266
+ p.add_atom("C")
267
+ p.add_atom("N")
268
+ p.add_atom("C")
269
+ p.add_atom("C")
270
+ p.add_atom("O")
271
+ p.add_atom("C")
272
+ p.add_atom("C")
273
+ p.add_bond(1, 2, 1)
274
+ p.add_bond(2, 3, 2)
275
+ p.add_bond(3, 4, 1)
276
+ p.add_bond(4, 5, 1)
277
+ p.add_bond(4, 6, 1)
278
+ p.add_bond(4, 7, 1)
279
+
280
+ # hydrolysis of bromide alkyl
281
+ #
282
+ q, p = prepare()
283
+ q.add_atom("C", hybridization=1)
284
+ q.add_atom("O", neighbors=1)
285
+ q.add_bond(1, 2, 1)
286
+
287
+ p.add_atom("C")
288
+ p.add_atom("Br")
289
+ p.add_bond(1, 2, 1)
290
+
291
+ # Condensation of ketones/aldehydes and amines into imines
292
+ #
293
+ q, p = prepare()
294
+ q.add_atom("N", neighbors=(1, 2))
295
+ q.add_atom("C", neighbors=(2, 3), heteroatoms=1)
296
+ q.add_bond(1, 2, 2)
297
+
298
+ p.add_atom("C", _map=2)
299
+ p.add_atom("O")
300
+ p.add_bond(2, 3, 2)
301
+
302
+ # Halogenation of alkanes
303
+ #
304
+ q, p = prepare()
305
+ q.add_atom("C", hybridization=1)
306
+ q.add_atom(ListElement(["F", "Cl", "Br"]))
307
+ q.add_bond(1, 2, 1)
308
+
309
+ p.add_atom("C")
310
+
311
+ # heterocyclization
312
+ #
313
+ q, p = prepare()
314
+ q.add_atom("N", heteroatoms=0, hybridization=1, neighbors=(2, 3))
315
+ q.add_atom("C", heteroatoms=2)
316
+ q.add_atom("N", heteroatoms=0, neighbors=2)
317
+ q.add_bond(1, 2, 1)
318
+ q.add_bond(2, 3, 2)
319
+
320
+ p.add_atom("N")
321
+ p.add_atom("C")
322
+ p.add_atom("N")
323
+ p.add_atom("O")
324
+ p.add_bond(1, 2, 1)
325
+ p.add_bond(2, 4, 2)
326
+
327
+ # Reduction of nitrile
328
+ #
329
+ q, p = prepare()
330
+ q.add_atom("N", neighbors=1)
331
+ q.add_atom("C")
332
+ q.add_atom("C", hybridization=1)
333
+ q.add_bond(1, 2, 1)
334
+ q.add_bond(2, 3, 1)
335
+
336
+ p.add_atom("N")
337
+ p.add_atom("C")
338
+ p.add_atom("C")
339
+ p.add_bond(1, 2, 3)
340
+ p.add_bond(2, 3, 1)
341
+
342
+ # SPECIAL CASE
343
+ # Reduction of nitrile into methylamine
344
+ #
345
+ q, p = prepare()
346
+ q.add_atom("C", neighbors=1)
347
+ q.add_atom("N", neighbors=2)
348
+ q.add_atom("C")
349
+ q.add_atom("C", hybridization=1)
350
+ q.add_bond(1, 2, 1)
351
+ q.add_bond(2, 3, 1)
352
+ q.add_bond(3, 4, 1)
353
+
354
+ p.add_atom("N", _map=2)
355
+ p.add_atom("C")
356
+ p.add_atom("C")
357
+ p.add_bond(2, 3, 3)
358
+ p.add_bond(3, 4, 1)
359
+
360
+ # methylation of amides
361
+ #
362
+ q, p = prepare()
363
+ q.add_atom("O")
364
+ q.add_atom("C")
365
+ q.add_atom("N")
366
+ q.add_atom("C", neighbors=1)
367
+ q.add_bond(1, 2, 2)
368
+ q.add_bond(2, 3, 1)
369
+ q.add_bond(3, 4, 1)
370
+
371
+ p.add_atom("O")
372
+ p.add_atom("C")
373
+ p.add_atom("N")
374
+ p.add_bond(1, 2, 2)
375
+ p.add_bond(2, 3, 1)
376
+
377
+ # hydrocyanation of alkenes
378
+ #
379
+ q, p = prepare()
380
+ q.add_atom("C", hybridization=1)
381
+ q.add_atom("C")
382
+ q.add_atom("C")
383
+ q.add_atom("N")
384
+ q.add_bond(1, 2, 1)
385
+ q.add_bond(2, 3, 1)
386
+ q.add_bond(3, 4, 3)
387
+
388
+ p.add_atom("C")
389
+ p.add_atom("C")
390
+ p.add_bond(1, 2, 2)
391
+
392
+ # decarboxylation (alpha atom of nitrile)
393
+ #
394
+ q, p = prepare()
395
+ q.add_atom("N")
396
+ q.add_atom("C")
397
+ q.add_atom("C", neighbors=2)
398
+ q.add_bond(1, 2, 3)
399
+ q.add_bond(2, 3, 1)
400
+
401
+ p.add_atom("N")
402
+ p.add_atom("C")
403
+ p.add_atom("C")
404
+ p.add_atom("C")
405
+ p.add_atom("O")
406
+ p.add_atom("O")
407
+ p.add_bond(1, 2, 3)
408
+ p.add_bond(2, 3, 1)
409
+ p.add_bond(3, 4, 1)
410
+ p.add_bond(4, 5, 2)
411
+ p.add_bond(4, 6, 1)
412
+
413
+ # Bischler-Napieralski reaction
414
+ #
415
+ q, p = prepare()
416
+ q.add_atom("C", rings_sizes=(6,))
417
+ q.add_atom("C", rings_sizes=(6,))
418
+ q.add_atom("N", rings_sizes=(6,), neighbors=2)
419
+ q.add_atom("C")
420
+ q.add_atom("C")
421
+ q.add_atom("C")
422
+ q.add_atom("O")
423
+ q.add_atom("O")
424
+ q.add_atom("C")
425
+ q.add_atom("O", neighbors=1)
426
+ q.add_bond(1, 2, 4)
427
+ q.add_bond(2, 3, 1)
428
+ q.add_bond(3, 4, 1)
429
+ q.add_bond(4, 5, 2)
430
+ q.add_bond(5, 6, 1)
431
+ q.add_bond(6, 7, 2)
432
+ q.add_bond(6, 8, 1)
433
+ q.add_bond(5, 9, 4)
434
+ q.add_bond(9, 10, 1)
435
+ q.add_bond(1, 9, 1)
436
+
437
+ p.add_atom("C")
438
+ p.add_atom("C")
439
+ p.add_atom("N")
440
+ p.add_atom("C")
441
+ p.add_atom("C")
442
+ p.add_atom("C")
443
+ p.add_atom("O")
444
+ p.add_atom("O")
445
+ p.add_atom("C")
446
+ p.add_atom("O")
447
+ p.add_atom("O")
448
+ p.add_bond(1, 2, 4)
449
+ p.add_bond(2, 3, 1)
450
+ p.add_bond(3, 4, 1)
451
+ p.add_bond(4, 5, 2)
452
+ p.add_bond(5, 6, 1)
453
+ p.add_bond(6, 7, 2)
454
+ p.add_bond(6, 8, 1)
455
+ p.add_bond(5, 9, 1)
456
+ p.add_bond(9, 10, 2)
457
+ p.add_bond(9, 11, 1)
458
+
459
+ # heterocyclization in Prins reaction
460
+ #
461
+ q, p = prepare()
462
+ q.add_atom("C")
463
+ q.add_atom("O")
464
+ q.add_atom("C")
465
+ q.add_atom(ListElement(["N", "O"]), neighbors=2)
466
+ q.add_atom("C")
467
+ q.add_atom("C")
468
+ q.add_bond(1, 2, 1)
469
+ q.add_bond(2, 3, 1)
470
+ q.add_bond(3, 4, 1)
471
+ q.add_bond(4, 5, 1)
472
+ q.add_bond(5, 6, 1)
473
+ q.add_bond(1, 6, 1)
474
+
475
+ p.add_atom("C")
476
+ p.add_atom("C", _map=5)
477
+ p.add_bond(1, 5, 2)
478
+
479
+ # recyclization of tetrahydropyran through ring opening and dehydration
480
+ #
481
+ q, p = prepare()
482
+ q.add_atom("C")
483
+ q.add_atom("C")
484
+ q.add_atom("C")
485
+ q.add_atom(ListElement(["N", "O"]))
486
+ q.add_atom("C")
487
+ q.add_atom("C")
488
+ q.add_bond(1, 2, 1)
489
+ q.add_bond(2, 3, 1)
490
+ q.add_bond(3, 4, 1)
491
+ q.add_bond(4, 5, 1)
492
+ q.add_bond(5, 6, 1)
493
+ q.add_bond(1, 6, 2)
494
+
495
+ p.add_atom("C")
496
+ p.add_atom("C")
497
+ p.add_atom("C")
498
+ p.add_atom("A")
499
+ p.add_atom("C")
500
+ p.add_atom("C")
501
+ p.add_atom("O")
502
+ p.add_bond(1, 2, 1)
503
+ p.add_bond(1, 7, 1)
504
+ p.add_bond(3, 7, 1)
505
+ p.add_bond(3, 4, 1)
506
+ p.add_bond(4, 5, 1)
507
+ p.add_bond(5, 6, 1)
508
+ p.add_bond(1, 6, 1)
509
+
510
+ # alkenes + h2o/hHal
511
+ #
512
+ q, p = prepare()
513
+ q.add_atom("C", hybridization=1)
514
+ q.add_atom("C", hybridization=1)
515
+ q.add_atom(ListElement(["O", "F", "Cl", "Br", "I"]), neighbors=1)
516
+ q.add_bond(1, 2, 1)
517
+ q.add_bond(2, 3, 1)
518
+
519
+ p.add_atom("C")
520
+ p.add_atom("C")
521
+ p.add_bond(1, 2, 2)
522
+
523
+ # methylation of dimethylamines
524
+ #
525
+ q, p = prepare()
526
+ q.add_atom("C", neighbors=1)
527
+ q.add_atom("N", neighbors=3)
528
+ q.add_bond(1, 2, 1)
529
+
530
+ p.add_atom("N", _map=2)
531
+
532
+ __all__ = ["rules"]
synplan/chem/utils.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing additional functions needed in different reaction data processing
2
+ protocols."""
3
+
4
+ import logging
5
+ from typing import Iterable
6
+
7
+ from CGRtools.containers import (
8
+ CGRContainer,
9
+ MoleculeContainer,
10
+ QueryContainer,
11
+ ReactionContainer,
12
+ )
13
+ from CGRtools.exceptions import InvalidAromaticRing
14
+ from tqdm import tqdm
15
+
16
+ from synplan.chem import smiles_parser
17
+ from synplan.utils.files import MoleculeReader, MoleculeWriter
18
+
19
+ from chython import MoleculeContainer as MoleculeContainerChython
20
+
21
+
22
def mol_from_smiles(
    smiles: str,
    standardize: bool = True,
    clean_stereo: bool = True,
    clean2d: bool = True,
) -> MoleculeContainer:
    """Parse a SMILES string into a `MoleculeContainer` and optionally post-process it.

    The post-processing steps are applied to a copy of the parsed molecule. If
    any of them raises `InvalidAromaticRing`, a warning is logged and the
    parsed molecule is returned without any of the steps applied.

    :param smiles: The SMILES string representing the molecule.
    :param standardize: Whether to canonicalize the molecule (default is True).
    :param clean_stereo: Whether to remove the stereo marks of the molecule (default is True).
    :param clean2d: Whether to clean the 2D coordinates of the molecule (default is True).
    :return: The processed molecule, or the unprocessed parsed molecule if
        processing failed on an invalid aromatic ring.
    :raises ValueError: If parsing did not produce a `MoleculeContainer`.
    """
    parsed = smiles_parser(smiles)
    if not isinstance(parsed, MoleculeContainer):
        raise ValueError("SMILES string was not processed by CGRtools")

    # Work on a copy so that a failure half-way through leaves `parsed` intact.
    processed = parsed.copy()
    try:
        if standardize:
            processed.canonicalize()
        if clean_stereo:
            processed.clean_stereo()
        if clean2d:
            processed.clean2d()
    except InvalidAromaticRing:
        logging.warning(
            "CGRtools was not able to standardize molecule due to invalid aromatic ring"
        )
        return parsed
    return processed
57
+
58
+
59
def query_to_mol(query: QueryContainer) -> MoleculeContainer:
    """Build a `MoleculeContainer` that mirrors the atoms and bonds of a query.

    For each query atom the element symbol, atom number, charge and radical
    flag are copied; for each bond the integer bond order is copied. Other
    query constraints are not transferred.

    :param query: The query structure to convert.
    :return: A molecule with the same atom numbering and connectivity as the query.
    """
    molecule = MoleculeContainer()
    for atom_number, query_atom in query.atoms():
        molecule.add_atom(
            query_atom.atomic_symbol,
            atom_number,
            charge=query_atom.charge,
            is_radical=query_atom.is_radical,
        )
    for first, second, query_bond in query.bonds():
        molecule.add_bond(first, second, int(query_bond))
    return molecule
73
+
74
+
75
def reaction_query_to_reaction(reaction_rule: ReactionContainer) -> ReactionContainer:
    """Turn a reaction rule made of query structures into a reaction of molecules.

    Every reactant, product and reagent of the rule is converted with
    `query_to_mol`; the metadata and the name of the rule are carried over.

    :param reaction_rule: A ReactionContainer whose reactants, products and
        reagents are QueryContainer objects.
    :return: A new ReactionContainer whose reactants, products and reagents are
        MoleculeContainer objects.
    """
    reactant_mols = list(map(query_to_mol, reaction_rule.reactants))
    product_mols = list(map(query_to_mol, reaction_rule.products))
    reagent_mols = list(map(query_to_mol, reaction_rule.reagents))

    converted = ReactionContainer(
        reactant_mols, product_mols, reagent_mols, reaction_rule.meta
    )
    converted.name = reaction_rule.name
    return converted
92
+
93
+
94
def unite_molecules(molecules: Iterable[MoleculeContainer]) -> MoleculeContainer:
    """Merge several molecules into one `MoleculeContainer`.

    Starting from an empty container, each input molecule is united with the
    accumulated result in iteration order. An empty input therefore yields an
    empty molecule.

    :param molecules: The molecules to be united.
    :return: A single MoleculeContainer representing the union of all input
        molecules.
    """
    united = MoleculeContainer()
    for part in molecules:
        united = united.union(part)
    return united
108
+
109
+
110
def safe_canonicalization(molecule: MoleculeContainer) -> MoleculeContainer:
    """Canonicalize a copy of the molecule, falling back to the input on failure.

    The atoms mapping of the given molecule is first re-ordered in place by
    atom number. A copy is then canonicalized and cleaned of stereo marks. If
    either step raises `InvalidAromaticRing`, the (re-ordered) input molecule
    is returned instead.

    :param molecule: The given molecule to be canonicalized.
    :return: The canonicalized copy if successful, otherwise the original molecule.
    """
    # NOTE: this deliberately touches the private atoms dict of the input.
    molecule._atoms = dict(sorted(molecule._atoms.items()))

    candidate = molecule.copy()
    try:
        candidate.canonicalize()
        candidate.clean_stereo()
    except InvalidAromaticRing:
        return molecule
    return candidate
127
+
128
+
129
def standardize_building_blocks(input_file: str, output_file: str) -> str:
    """Standardize custom building blocks and write them to a new file.

    Each molecule read from the input file is passed through
    `safe_canonicalization` and written to the output file. Molecules whose
    standardization raises are logged at debug level and skipped.

    :param input_file: The path to the file that stores the original building blocks.
    :param output_file: The path to the file that will store the standardized
        building blocks.
    :return: The path to the file with standardized building blocks.
    :raises ValueError: If input and output paths are the same string.
    """
    if output_file == input_file:
        raise ValueError("input_file name and output_file name cannot be the same.")

    progress_options = {
        "desc": "Number of building blocks processed: ",
        "bar_format": "{desc}{n} [{elapsed}]",
    }
    with MoleculeReader(input_file) as inp_file, MoleculeWriter(
        output_file
    ) as out_file:
        for mol in tqdm(inp_file, **progress_options):
            try:
                standardized = safe_canonicalization(mol)
            except Exception as e:
                # Best effort: a single bad molecule must not abort the whole run.
                logging.debug(e)
            else:
                out_file.write(standardized)

    return output_file
156
+
157
+
158
def cgr_from_reaction_rule(reaction_rule: ReactionContainer) -> CGRContainer:
    """Compose a CGR from the given reaction rule.

    The rule is first converted from query structures to molecules with
    `reaction_query_to_reaction`; the resulting reaction is then composed with
    the ``~`` operator.

    :param reaction_rule: The reaction rule to be converted.
    :return: The resulting CGR.
    """
    return ~reaction_query_to_reaction(reaction_rule)
169
+
170
+
171
def hash_from_reaction_rule(reaction_rule: ReactionContainer) -> int:
    """Generate an order-independent hash for the given reaction rule.

    The hashes of the reactants, reagents and products are each sorted, so the
    result does not depend on the order of molecules within a role. The three
    roles themselves stay distinguishable because they are hashed as an
    ordered triple.

    :param reaction_rule: The reaction rule to be hashed.
    :return: The resulting integer hash.
    """

    def _sorted_hashes(structures) -> tuple:
        # Sorting makes the hash independent of the order inside one role.
        return tuple(sorted(hash(structure) for structure in structures))

    return hash(
        (
            _sorted_hashes(reaction_rule.reactants),
            _sorted_hashes(reaction_rule.reagents),
            _sorted_hashes(reaction_rule.products),
        )
    )
183
+
184
+
185
def reverse_reaction(
    reaction: ReactionContainer,
) -> ReactionContainer:
    """Build a new reaction with reactants and products swapped.

    Reagents, metadata and the name of the reaction are carried over unchanged.

    :param reaction: The reaction to be reversed.
    :return: The reversed reaction.
    """
    new_reactants = reaction.products
    new_products = reaction.reactants

    flipped = ReactionContainer(
        new_reactants, new_products, reaction.reagents, reaction.meta
    )
    flipped.name = reaction.name
    return flipped
199
+
200
+
201
def cgrtools_to_chython_molecule(molecule):
    """Rebuild a molecule as a chython ``MoleculeContainer``.

    Only the element symbol and number of each atom and the integer order of
    each bond are transferred; charges and radical flags are not passed on.

    :param molecule: The source molecule, iterated via its ``atoms()`` and
        ``bonds()`` methods.
    :return: The chython molecule with the same atom numbering and bonds.
    """
    converted = MoleculeContainerChython()

    for atom_number, source_atom in molecule.atoms():
        converted.add_atom(source_atom.atomic_symbol, atom_number)

    for first, second, source_bond in molecule.bonds():
        converted.add_bond(first, second, int(source_bond))

    return converted
210
+
211
+
212
def chython_query_to_cgrtools(query):
    """Rebuild a query structure as a CGRtools ``QueryContainer``.

    For each atom the element symbol, charge, neighbors and hybridization
    constraints are copied and the original atom number is kept as the map
    number; for each bond the integer bond order is copied.

    :param query: The source query, iterated via its ``atoms()`` and
        ``bonds()`` methods.
    :return: The CGRtools query with the same atom numbering and bonds.
    """
    converted = QueryContainer()

    for atom_number, source_atom in query.atoms():
        converted.add_atom(
            atom=source_atom.atomic_symbol,
            charge=source_atom.charge,
            neighbors=source_atom.neighbors,
            hybridization=source_atom.hybridization,
            _map=atom_number,
        )

    for first, second, source_bond in query.bonds():
        converted.add_bond(first, second, int(source_bond))

    return converted
synplan/interfaces/__init__.py ADDED
File without changes
synplan/interfaces/building_blocks/building_blocks_em_sa_ln.smi ADDED
The diff for this file is too large to render. See raw diff
 
synplan/interfaces/cli.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing command line scripts for training and planning steps."""
2
+
3
+ import os
4
+ import warnings
5
+ from pathlib import Path
6
+
7
+ import click
8
+ import yaml
9
+
10
+ from synplan.chem.data.filtering import ReactionFilterConfig, filter_reactions_from_file
11
+ from synplan.chem.data.standardizing import (
12
+ ReactionStandardizationConfig,
13
+ standardize_reactions_from_file,
14
+ )
15
+ from synplan.chem.reaction_rules.extraction import extract_rules_from_reactions
16
+ from synplan.chem.reaction_routes.clustering import run_cluster_cli
17
+ from synplan.chem.utils import standardize_building_blocks
18
+ from synplan.mcts.search import run_search
19
+ from synplan.ml.training.supervised import create_policy_dataset, run_policy_training
20
+ from synplan.ml.training.reinforcement import run_updating
21
+ from synplan.utils.config import (
22
+ PolicyNetworkConfig,
23
+ RuleExtractionConfig,
24
+ TreeConfig,
25
+ TuningConfig,
26
+ ValueNetworkConfig,
27
+ )
28
+ from synplan.utils.loading import download_all_data
29
+ from synplan.utils.visualisation import (
30
+ routes_clustering_report,
31
+ routes_subclustering_report,
32
+ )
33
+
34
+ warnings.filterwarnings("ignore")
35
+
36
+
37
# Root click command group; every sub-command below is registered on it via
# ``@synplan.command``. The docstring is passed through click as the help text.
@click.group(name="synplan")
def synplan():
    """SynPlanner command line interface."""
40
+
41
+
42
@synplan.command(name="download_all_data")
@click.option(
    "--save_to",
    "save_to",
    # The default must be declared on the option: click always passes the
    # parameter explicitly, so without it the function would receive ``None``
    # and the Python-level default below would never take effect.
    default=".",
    help="Path to the folder where downloaded data will be stored.",
)
def download_all_data_cli(save_to: str = ".") -> None:
    """Downloads all data for training, planning and benchmarking SynPlanner."""
    download_all_data(save_to=save_to)
51
+
52
+
53
# CLI wrapper around ``standardize_building_blocks``: both paths are required,
# and the input path must already exist (checked by click before the call).
@synplan.command(name="building_blocks_standardizing")
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks to be standardized.",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(),
    help="Path to the file where standardized building blocks will be stored.",
)
def building_blocks_standardizing_cli(input_file: str, output_file: str) -> None:
    """Standardizes building blocks."""
    standardize_building_blocks(input_file=input_file, output_file=output_file)
71
+
72
+
73
@synplan.command(name="reaction_standardizing")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for reactions standardizing.",
)
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions to be standardized.",
)
@click.option(
    "--output",
    "output_file",
    type=click.Path(),
    help="Path to the file where standardized reactions will be stored.",
)
@click.option(
    "--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
)
def reaction_standardizing_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
) -> None:
    """Standardizes reactions and remove duplicates."""
    # The configuration is loaded from YAML and handed straight to the worker;
    # the batch size is fixed at 100 for the command line interface.
    standardize_reactions_from_file(
        config=ReactionStandardizationConfig.from_yaml(config_path),
        input_reaction_data_path=input_file,
        standardized_reaction_data_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
109
+
110
+
111
@synplan.command(name="reaction_filtering")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for reactions filtering.",
)
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions to be filtered.",
)
@click.option(
    "--output",
    "output_file",
    default=Path("./"),
    type=click.Path(),
    help="Path to the file where successfully filtered reactions will be stored.",
)
@click.option(
    "--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
)
def reaction_filtering_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
) -> None:
    """Filters erroneous reactions."""
    # Load the configuration through the class, as the other commands do,
    # instead of first building a throwaway default instance.
    # NOTE(review): assumes ``from_yaml`` is a static/class method, as the
    # sibling config classes are used in this file; confirm in filtering.py.
    reaction_check_config = ReactionFilterConfig.from_yaml(config_path)
    filter_reactions_from_file(
        config=reaction_check_config,
        input_reaction_data_path=input_file,
        filtered_reaction_data_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
148
+
149
+
150
@synplan.command(name="rule_extracting")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for reaction rules extracting.",
)
@click.option(
    "--input",
    "input_file",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions for reaction rules extraction.",
)
@click.option(
    "--output",
    "output_file",
    required=True,
    type=click.Path(),
    help="Path to the file where extracted reaction rules will be stored.",
)
@click.option(
    "--num_cpus", default=4, type=int, help="The number of CPUs to use for processing."
)
def rule_extracting_cli(
    config_path: str, input_file: str, output_file: str, num_cpus: int
):
    """Reaction rules extraction."""
    # The configuration is loaded from YAML and handed straight to the
    # extractor; the batch size is fixed at 100 for the command line interface.
    extract_rules_from_reactions(
        config=RuleExtractionConfig.from_yaml(config_path),
        reaction_data_path=input_file,
        reaction_rules_path=output_file,
        num_cpus=num_cpus,
        batch_size=100,
    )
187
+
188
+
189
@synplan.command(name="ranking_policy_training")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for ranking policy training.",
)
@click.option(
    "--reaction_data",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with reactions for ranking policy training.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--results_dir",
    default=Path("."),
    type=click.Path(),
    help="Path to the directory where the trained policy network will be stored.",
)
@click.option(
    "--num_cpus",
    default=4,
    type=int,
    help="The number of CPUs to use for training set preparation.",
)
def ranking_policy_training_cli(
    config_path: str,
    reaction_data: str,
    reaction_rules: str,
    results_dir: str,
    num_cpus: int,
) -> None:
    """Ranking policy network training."""
    # The policy type is forced to "ranking" regardless of the YAML content.
    ranking_config = PolicyNetworkConfig.from_yaml(config_path)
    ranking_config.policy_type = "ranking"

    # The prepared dataset is stored next to the training results.
    dataset_path = os.path.join(results_dir, "policy_dataset.dt")
    training_data = create_policy_dataset(
        reaction_rules_path=reaction_rules,
        molecules_or_reactions_path=reaction_data,
        output_path=dataset_path,
        dataset_type="ranking",
        batch_size=ranking_config.batch_size,
        num_cpus=num_cpus,
    )

    run_policy_training(training_data, config=ranking_config, results_path=results_dir)
243
+
244
+
245
@synplan.command(name="filtering_policy_training")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for filtering policy training.",
)
@click.option(
    "--molecule_data",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with molecules for filtering policy training.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--results_dir",
    default=Path("."),
    type=click.Path(),
    help="Path to the directory where the trained policy network will be stored.",
)
@click.option(
    "--num_cpus",
    default=8,
    type=int,
    help="The number of CPUs to use for training set preparation.",
)
def filtering_policy_training_cli(
    config_path: str,
    molecule_data: str,
    reaction_rules: str,
    results_dir: str,
    num_cpus: int,
):
    """Filtering policy network training."""
    # The policy type is forced to "filtering" regardless of the YAML content.
    filtering_config = PolicyNetworkConfig.from_yaml(config_path)
    filtering_config.policy_type = "filtering"

    # The prepared dataset is stored next to the training results.
    dataset_path = os.path.join(results_dir, "policy_dataset.ckpt")
    training_data = create_policy_dataset(
        reaction_rules_path=reaction_rules,
        molecules_or_reactions_path=molecule_data,
        output_path=dataset_path,
        dataset_type="filtering",
        batch_size=filtering_config.batch_size,
        num_cpus=num_cpus,
    )

    run_policy_training(
        training_data, config=filtering_config, results_path=results_dir
    )
300
+
301
+
302
@synplan.command(name="value_network_tuning")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for value network training.",
)
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for planning simulations.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules. Needed for planning simulations.",
)
@click.option(
    "--building_blocks",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks. Needed for planning simulations.",
)
@click.option(
    "--policy_network",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with trained policy network. Needed for planning simulations.",
)
@click.option(
    "--value_network",
    default=None,
    type=click.Path(exists=True),
    help="Path to the file with trained value network. Needed in case of additional value network fine-tuning",
)
@click.option(
    "--results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the directory where the trained value network will be stored.",
)
def value_network_tuning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Value network tuning."""
    # The docstring above is what click prints as the command help, so the
    # longer explanation lives in comments instead.
    #
    # The YAML config must provide the sections "node_expansion",
    # "value_network", "tree" and "tuning"; a missing section raises KeyError.

    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    policy_config = PolicyNetworkConfig.from_dict(config["node_expansion"])
    policy_config.weights_path = policy_network

    value_config = ValueNetworkConfig.from_dict(config["value_network"])
    if value_network is None:
        # No existing network supplied: point at the default location inside
        # the results directory.
        value_config.weights_path = os.path.join(
            results_dir, "weights", "value_network.ckpt"
        )
    else:
        # Previously a supplied --value_network was silently dropped; honour
        # it the same way --policy_network is honoured above.
        # NOTE(review): confirm how run_updating uses weights_path (load only,
        # or also as the save target) before relying on the input file being
        # left untouched.
        value_config.weights_path = value_network

    tree_config = TreeConfig.from_dict(config["tree"])
    tuning_config = TuningConfig.from_dict(config["tuning"])

    run_updating(
        targets_path=targets,
        tree_config=tree_config,
        policy_config=policy_config,
        value_config=value_config,
        reinforce_config=tuning_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        results_root=results_dir,
    )
382
+
383
+
384
@synplan.command(name="planning")
@click.option(
    "--config",
    "config_path",
    required=True,
    type=click.Path(exists=True),
    help="Path to the configuration file for retrosynthetic planning.",
)
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for retrosynthetic planning.",
)
@click.option(
    "--reaction_rules",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with extracted reaction rules.",
)
@click.option(
    "--building_blocks",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with building blocks.",
)
@click.option(
    "--policy_network",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with trained policy network.",
)
@click.option(
    "--value_network",
    default=None,
    type=click.Path(exists=True),
    help="Path to the file with trained value network.",
)
@click.option(
    "--results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where retrosynthetic planning results will be stored.",
)
def planning_cli(
    config_path: str,
    targets: str,
    reaction_rules: str,
    building_blocks: str,
    policy_network: str,
    value_network: str,
    results_dir: str,
):
    """Retrosynthetic planning."""
    # The docstring above doubles as the click help text, so details are kept
    # in comments. The YAML config must contain the sections "tree",
    # "node_evaluation" and "node_expansion".

    with open(config_path, "r", encoding="utf-8") as config_file:
        settings = yaml.safe_load(config_file)

    # Search settings: the "tree" section, overlaid with "node_evaluation".
    search_config = dict(settings["tree"])
    search_config.update(settings["node_evaluation"])

    # Expansion settings: the weights path from the command line always wins
    # over whatever the config file specifies.
    expansion_settings = dict(settings["node_expansion"])
    expansion_settings["weights_path"] = policy_network
    policy_config = PolicyNetworkConfig.from_dict(expansion_settings)

    run_search(
        targets_path=targets,
        search_config=search_config,
        policy_config=policy_config,
        reaction_rules_path=reaction_rules,
        building_blocks_path=building_blocks,
        value_network_path=value_network,
        results_root=results_dir,
    )
456
+
457
+
458
@synplan.command(name="clustering")
@click.option(
    "--targets",
    required=True,
    type=click.Path(exists=True),
    help="Path to the file with target molecules for retrosynthetic planning.",
)
@click.option(
    "--routes_file",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where the planning results are stored.",
)
@click.option(
    "--cluster_results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where clustering results will be stored.",
)
@click.option(
    "--perform_subcluster",
    default=None,
    # This option is an on/off switch, not a path. Declared as click.Path it
    # came through as a raw string, so "--perform_subcluster False" was truthy
    # and enabled subclustering. click.BOOL still takes a value
    # (true/false/1/0/yes/no ...), which keeps existing invocations such as
    # "--perform_subcluster True" working.
    type=click.BOOL,
    help="Perform subclustering.",
)
@click.option(
    "--subcluster_results_dir",
    default=".",
    type=click.Path(exists=False),
    help="Path to the file where subclustering results will be stored.",
)
def cluster_route_from_file_cli(
    targets: str,
    routes_file: str,
    cluster_results_dir: str,
    perform_subcluster: bool,
    subcluster_results_dir: str,
):
    """Clustering the routes from planning"""
    # The docstring above doubles as the click help text, so details are kept
    # in comments.
    #
    # NOTE(review): ``targets`` is a required option but is not used by this
    # command; it is kept so existing invocations stay valid.
    # ``perform_subcluster`` is None when the option is omitted.
    run_cluster_cli(
        routes_file=routes_file,
        cluster_results_dir=cluster_results_dir,
        perform_subcluster=perform_subcluster,
        # The subcluster output location is only forwarded when subclustering
        # was actually requested.
        subcluster_results_dir=subcluster_results_dir if perform_subcluster else None,
    )
503
+
504
+
505
if __name__ == "__main__":
    # Executed as a script: hand control to the click command group.
    synplan()
synplan/interfaces/gui.py ADDED
@@ -0,0 +1,1304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import pickle
3
+ import re
4
+ import uuid
5
+ import io
6
+ import zipfile
7
+
8
+ import pandas as pd
9
+ import streamlit as st
10
+ from CGRtools.files import SMILESRead
11
+ from streamlit_ketcher import st_ketcher
12
+ from huggingface_hub import hf_hub_download
13
+ from huggingface_hub.utils import disable_progress_bars
14
+
15
+
16
+ from synplan.mcts.expansion import PolicyNetworkFunction
17
+ from synplan.mcts.search import extract_tree_stats
18
+ from synplan.mcts.tree import Tree
19
+ from synplan.chem.utils import mol_from_smiles
20
+ from synplan.chem.reaction_routes.route_cgr import *
21
+ from synplan.chem.reaction_routes.clustering import *
22
+
23
+ from synplan.utils.visualisation import (
24
+ routes_clustering_report,
25
+ routes_subclustering_report,
26
+ generate_results_html,
27
+ html_top_routes_cluster,
28
+ get_route_svg,
29
+ )
30
+ from synplan.utils.config import TreeConfig, PolicyNetworkConfig
31
+ from synplan.utils.loading import load_reaction_rules, load_building_blocks
32
+
33
+
34
+ import psutil
35
+ import gc
36
+
37
+
38
+ disable_progress_bars("huggingface_hub")
39
+
40
+ smiles_parser = SMILESRead.create_parser(ignore=True)
41
+ DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
42
+
43
+
44
+ # --- Helper Functions ---
45
def download_button(
    object_to_download, download_filename, button_text, pickle_it=False
):
    """
    Build an HTML snippet containing a styled link that downloads an object.

    The payload is embedded in the link itself as a base64 ``data:`` URI, so
    no server round trip is needed; the caller renders the returned markup
    (e.g. with ``st.markdown(..., unsafe_allow_html=True)``).

    Params:
    ------
    object_to_download: The object to be downloaded. A ``str`` is UTF-8
        encoded, ``bytes`` are used as they are and a ``pandas.DataFrame`` is
        written as CSV without the index. With ``pickle_it=True`` any
        picklable object is accepted.
    download_filename (str): filename and extension of file. e.g. mydata.csv,
        some_txt_output.txt
    button_text (str): Text to display on download button (e.g. 'click here to download file').
        It is inserted into the markup as given.
    pickle_it (bool): If True, pickle file.
    Returns:
    -------
    (str | None): the anchor tag (preceded by its ``<style>`` block) to
        download object_to_download, or None when pickling failed (the error
        is written to the page).
    Raises:
    ------
    TypeError: if the object is none of the supported types and
        ``pickle_it`` is False.
    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    import html  # local import: only this helper needs it

    if pickle_it:
        try:
            payload = pickle.dumps(object_to_download)
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            # pickle reports unpicklable objects with any of these types.
            st.write(e)
            return None
    elif isinstance(object_to_download, pd.DataFrame):
        payload = object_to_download.to_csv(index=False).encode("utf-8")
    else:
        payload = object_to_download

    if isinstance(payload, str):
        payload = payload.encode()
    b64 = base64.b64encode(payload).decode()

    # Digits are stripped so the id never starts with a digit (invalid in a
    # CSS selector); the fixed prefix keeps the id non-empty even if the
    # random hex string happened to consist of digits only.
    button_id = "dl_" + re.sub(r"\d+", "", uuid.uuid4().hex)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    # The file name may contain user-entered text (e.g. a SMILES string), so
    # it is escaped before being placed inside the quoted attribute.
    safe_filename = html.escape(str(download_filename), quote=True)

    dl_link = (
        custom_css
        + f'<a download="{safe_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    )
    return dl_link
117
+
118
+
119
@st.cache_resource
def load_planning_resources_cached():
    """Download the planning resources from the Hugging Face Hub.

    Fetches the building blocks, the ranking policy weights and the USPTO
    reaction rules into the current working directory. The function is
    wrapped in ``st.cache_resource``.

    Returns:
        tuple: ``(building_blocks_path, ranking_policy_weights_path,
        reaction_rules_path)`` as returned by ``hf_hub_download``.
    """
    repo_id = "Laboratoire-De-Chemoinformatique/SynPlanner"
    # (filename, subfolder) pairs, listed in the order the paths are returned.
    resources = (
        ("building_blocks_em_sa_ln.smi", "building_blocks"),
        ("ranking_policy_network.ckpt", "uspto/weights"),
        ("uspto_reaction_rules.pickle", "uspto"),
    )
    local_paths = [
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=".",
        )
        for filename, subfolder in resources
    ]
    building_blocks_path, ranking_policy_weights_path, reaction_rules_path = local_paths
    return building_blocks_path, ranking_policy_weights_path, reaction_rules_path
140
+
141
+
142
+ # --- GUI Sections ---
143
+
144
+
145
def initialize_app():
    """1. Initialization: page configuration, session defaults and the intro.

    Configures the Streamlit page, seeds every session-state key the other
    GUI sections read (existing values are left untouched) and renders the
    title with the introductory text.
    """
    st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

    session_defaults = {
        # Planning state
        "planning_done": False,
        "tree": None,
        "res": None,
        "target_smiles": "",  # Initial value, might be overwritten by ketcher
        # Clustering state
        "clustering_done": False,
        "clusters": None,
        "reactions_dict": None,
        "num_clusters_setting": 10,  # Store the setting used
        "route_cgrs_dict": None,
        "r_route_cgrs_dict": None,
        # Subclustering state
        "subclustering_done": False,
        "subclusters": None,
        # Download state (less critical now with direct download links)
        "clusters_downloaded": False,
        # Molecule currently held by the editor (ketcher persistence)
        "ketcher": DEFAULT_MOL,
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    intro_text = """
    This is a demo of the graphical user interface of
    [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
    SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

    More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
    """
    st.title("`SynPlanner GUI`")
    st.write(intro_text)
197
+
198
+
199
def setup_sidebar():
    """2. Sidebar: render the static link sections of the sidebar.

    Each section is a title followed by one markdown line.
    """
    sidebar = st.sidebar
    # (title, markdown body) pairs, rendered top to bottom.
    sections = (
        ("Docs", "https://synplanner.readthedocs.io/en/latest/"),
        (
            "Tutorials",
            "https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials",
        ),
        (
            "Paper",
            "https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796",
        ),
        (
            "Issues",
            "[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)",
        ),
    )
    for section_title, section_body in sections:
        sidebar.title(section_title)
        sidebar.markdown(section_body)
219
+
220
+
221
def handle_molecule_input():
    """3. Molecule Input: Managing the input area for molecule data.

    Renders a SMILES text field and the Ketcher editor (seeded with the text
    field's content). The editor output is treated as the definitive SMILES.
    If a planning run exists for a different molecule, all planning,
    clustering and subclustering results are discarded and the user is told
    to re-run planning.

    Returns:
        The SMILES string returned by the Ketcher widget; it is also stored
        in ``st.session_state.ketcher``.
    """
    st.header("Molecule input")
    st.markdown(
        """
        You can provide a molecular structure by either providing:
        * SMILES string + Enter
        * Draw it + Apply
        """
    )
    # Use st.session_state.ketcher to persist drawn molecule
    molecule_text_input = st.text_input(
        "SMILES:", value=st.session_state.ketcher, key="smiles_text_input_key"
    )

    # The output from ketcher is the definitive SMILES
    current_smile_code = st_ketcher(molecule_text_input, key="ketcher_widget")

    # ``target_smiles`` is "" until the first planning run; without a previous
    # run there is nothing to invalidate and no reason to ask for a re-run.
    planned_smiles = st.session_state.get("target_smiles", "")
    if planned_smiles and current_smile_code != planned_smiles:
        st.warning("Molecule structure changed. Please re-run planning.")
        for flag in ("planning_done", "clustering_done", "subclustering_done"):
            st.session_state[flag] = False
        # Same set of results that the planning button resets, including the
        # route CGR dictionaries, so nothing stale survives a molecule change.
        for stale_key in (
            "tree",
            "res",
            "clusters",
            "reactions_dict",
            "subclusters",
            "route_cgrs_dict",
            "r_route_cgrs_dict",
        ):
            st.session_state[stale_key] = None

    # Always remember the current molecule: the planning section reads it
    # from the session state.
    st.session_state.ketcher = current_smile_code

    return current_smile_code
267
+
268
+
269
def setup_planning_options():
    """4. Planning: option widgets and the retrosynthetic planning run.

    Renders the planning options and a start button. When the button is
    pressed, previous results are cleared, the SMILES stored in
    ``st.session_state.ketcher`` is parsed, the resources are loaded and the
    MCTS search is run with a progress bar. The tree and its statistics are
    stored in the session state and the script is rerun. Any exception is
    reported on the page.
    """
    st.header("Launch calculation")
    st.markdown(
        """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
    )
    # The molecule shown here is the one last stored by the molecule input
    # section in st.session_state.ketcher.
    st.markdown(
        f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
    )

    st.subheader("Planning options")
    st.markdown(
        """
        The description of each option can be found in the
        [Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
        """
    )

    left_column, right_column = st.columns(2, gap="medium")
    with left_column:
        search_strategy_input = st.selectbox(
            label="Search strategy",
            options=(
                "Expansion first",
                "Evaluation first",
            ),
            index=0,
            key="search_strategy_input",
        )
        ucb_type = st.selectbox(
            label="UCB type",
            options=("uct", "puct", "value"),
            index=0,
            key="ucb_type_input",
        )
        c_ucb = st.number_input(
            "C coefficient of UCB",
            value=0.1,
            placeholder="Type a number...",
            key="c_ucb_input",
        )

    with right_column:
        max_iterations = st.slider(
            "Total number of MCTS iterations",
            min_value=50,
            max_value=1000,
            value=300,
            key="max_iterations_slider",
        )
        max_depth = st.slider(
            "Maximal number of reaction steps",
            min_value=3,
            max_value=9,
            value=6,
            key="max_depth_slider",
        )
        min_mol_size = st.slider(
            "Minimum size of a molecule to be precursor",
            min_value=0,
            max_value=7,
            value=0,
            key="min_mol_size_slider",
            help="Number of non-hydrogen atoms in molecule",
        )

    # Map the human-readable label to the identifier used by TreeConfig.
    search_strategy = {
        "Expansion first": "expansion_first",
        "Evaluation first": "evaluation_first",
    }[search_strategy_input]

    if not st.button("Start retrosynthetic planning", key="submit_planning_button"):
        return

    # A new run invalidates everything derived from the previous one.
    for flag in ("planning_done", "clustering_done", "subclustering_done"):
        st.session_state[flag] = False
    for stale_key in (
        "tree",
        "res",
        "clusters",
        "reactions_dict",
        "subclusters",
        "route_cgrs_dict",
        "r_route_cgrs_dict",
    ):
        st.session_state[stale_key] = None

    # Remember which SMILES this run was started for.
    active_smile_code = st.session_state.get("ketcher", DEFAULT_MOL)
    st.session_state.target_smiles = active_smile_code

    try:
        target_molecule = mol_from_smiles(active_smile_code)
        if target_molecule is None:
            st.error(f"Could not parse the input SMILES: {active_smile_code}")
            return

        (
            building_blocks_path,
            ranking_policy_weights_path,
            reaction_rules_path,
        ) = load_planning_resources_cached()

        with st.spinner("Running retrosynthetic planning..."):
            with st.status("Loading resources...", expanded=False) as status:
                st.write("Loading building blocks...")
                building_blocks = load_building_blocks(
                    building_blocks_path, standardize=False
                )
                st.write("Loading reaction rules...")
                reaction_rules = load_reaction_rules(reaction_rules_path)
                st.write("Loading policy network...")
                policy_function = PolicyNetworkFunction(
                    policy_config=PolicyNetworkConfig(
                        weights_path=ranking_policy_weights_path
                    )
                )
                status.update(label="Resources loaded!", state="complete")

            # evaluation_type, init_node_value and silent are fixed for the
            # GUI; the remaining values come from the widgets above.
            tree_config = TreeConfig(
                search_strategy=search_strategy,
                evaluation_type="rollout",
                max_iterations=max_iterations,
                max_depth=max_depth,
                min_mol_size=min_mol_size,
                init_node_value=0.5,
                ucb_type=ucb_type,
                c_ucb=c_ucb,
                silent=True,
            )

            tree = Tree(
                target=target_molecule,
                config=tree_config,
                reaction_rules=reaction_rules,
                building_blocks=building_blocks,
                expansion_function=policy_function,
                evaluation_function=None,
            )

            mcts_progress_text = "Running MCTS iterations..."
            mcts_bar = st.progress(0, text=mcts_progress_text)
            # Iterating the tree drives the search; the yielded pair is only
            # unpacked, the loop just reports progress.
            for iteration, (_solved, _node_id) in enumerate(tree, start=1):
                mcts_bar.progress(
                    min(1.0, iteration / max_iterations),
                    text=f"{mcts_progress_text} ({iteration}/{max_iterations})",
                )

            res = extract_tree_stats(tree, target_molecule)

            st.session_state["tree"] = tree
            st.session_state["res"] = res
            st.session_state.planning_done = True
            st.rerun()

    except Exception as e:
        st.error(f"An error occurred during planning: {e}")
        st.session_state.planning_done = False
441
+
442
+
443
def display_planning_results():
    """5. Planning Results Display: show routes or the no-solution statistics.

    Does nothing until planning is marked as done. For a solved target, the
    number of unique winning nodes is shown together with up to three route
    images. Otherwise a warning and a small statistics table are displayed.
    """
    if not st.session_state.get("planning_done", False):
        return

    res = st.session_state.res
    tree = st.session_state.tree
    if res is None or tree is None:
        st.error(
            "Planning results are missing from session state. Please re-run planning."
        )
        st.session_state.planning_done = False  # Reset state
        return

    solved = res.get("solved", False)
    st.header("Planning Results")

    if not solved:
        st.warning(
            "No reaction path found for the target molecule with the current settings."
        )
        st.write(
            "Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
        )
        stat_col, _ = st.columns(2)
        with stat_col:
            st.subheader("Run Statistics (No Solution)")
            try:
                if "target_smiles" not in res and "target_smiles" in st.session_state:
                    res["target_smiles"] = st.session_state.target_smiles
                wanted_columns = ("target_smiles", "num_nodes", "num_iter", "search_time")
                cols_to_show = [name for name in wanted_columns if name in res]
                if cols_to_show:
                    st.dataframe(pd.DataFrame(res, index=[0])[cols_to_show])
                else:
                    st.write("No statistics to display for the unsuccessful run.")
            except Exception as e:
                st.error(f"Error displaying statistics: {e}")
                st.write(res)
        return

    # Solved: de-duplicate and order the winning nodes.
    found_nodes = getattr(tree, "winning_nodes", None)
    winning_nodes = sorted(set(found_nodes)) if found_nodes else []
    st.subheader(f"Number of unique routes found: {len(winning_nodes)}")
    st.subheader("Examples of found retrosynthetic routes")

    if not winning_nodes:
        st.warning("Planning solved, but no winning nodes found in the tree object.")
        return

    max_examples = 3
    shown_images = 0
    # winning_nodes holds unique ids, so each node is attempted at most once.
    for node_id in winning_nodes:
        if shown_images >= max_examples:
            break
        try:
            num_steps = len(tree.synthesis_route(node_id))
            route_score = round(tree.route_score(node_id), 3)
            svg = get_route_svg(tree, node_id)
            if svg:
                st.image(
                    svg,
                    caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
                )
                shown_images += 1
            else:
                st.warning(f"Could not generate SVG for route {node_id}.")
        except Exception as e:
            st.error(f"Error displaying route {node_id}: {e}")
530
+
531
+
532
def download_planning_results():
    """6. Planning Results Download: links for the HTML report and the stats CSV.

    Only active when planning is done and the target was solved. Renders one
    download link for the extended HTML report of the tree and one for the
    run statistics as CSV. Failures are reported on the page.
    """
    state = st.session_state
    has_solution = (
        state.get("planning_done", False)
        and state.res
        and state.res.get("solved", False)
    )
    if not has_solution:
        return

    res = state.res
    tree = state.tree

    # The links are rendered wherever this function is called; the caller
    # decides the surrounding layout.
    try:
        report_html = generate_results_html(tree, html_path=None, extended=True)
        report_link = download_button(
            report_html,
            f"results_synplanner_{st.session_state.target_smiles}.html",
            "Download results (HTML)",
        )
        if report_link:
            st.markdown(report_link, unsafe_allow_html=True)

        # A failure while preparing the CSV must not hide the HTML link above.
        try:
            stats_link = download_button(
                pd.DataFrame(res, index=[0]),
                f"stats_synplanner_{st.session_state.target_smiles}.csv",
                "Download statistics (CSV)",
            )
            if stats_link:
                st.markdown(stats_link, unsafe_allow_html=True)
        except Exception as e:
            st.error(f"Could not prepare statistics CSV for download: {e}")

    except Exception as e:
        st.error(f"Error generating download links for planning results: {e}")
574
+
575
+
576
def setup_clustering():
    """7. Clustering: button and logic that cluster the found routes.

    Only shown when planning is done and the target was solved. Pressing the
    button clears earlier clustering output, builds the route CGRs and the
    reduced route CGRs from the tree, clusters them, stores the results
    (sorted by cluster key) in the session state and reruns the script.
    """
    state = st.session_state
    has_solution = (
        state.get("planning_done", False)
        and state.res
        and state.res.get("solved", False)
    )
    if not has_solution:
        return

    st.divider()
    st.header("Clustering the retrosynthetic routes")

    if not st.button("Run Clustering", key="submit_clustering_button"):
        return

    # Drop everything derived from a previous clustering run.
    for flag in ("clustering_done", "subclustering_done"):
        st.session_state[flag] = False
    for stale_key in (
        "clusters",
        "reactions_dict",
        "subclusters",
        "route_cgrs_dict",
        "r_route_cgrs_dict",
    ):
        st.session_state[stale_key] = None

    with st.spinner("Performing clustering..."):
        try:
            current_tree = st.session_state.tree
            if not current_tree:
                st.error("Tree object not found. Please re-run planning.")
                return

            st.write("Calculating RoutesCGRs...")
            route_cgrs_dict = compose_all_route_cgrs(current_tree)
            st.write("Processing ReducedRoutesCGRs...")
            r_route_cgrs_dict = compose_all_reduced_route_cgrs(route_cgrs_dict)

            raw_clusters = cluster_routes(r_route_cgrs_dict, use_strat=False)
            # Order the clusters by their key interpreted as a number.
            ordered_clusters = dict(
                sorted(raw_clusters.items(), key=lambda item: float(item[0]))
            )

            st.session_state.clusters = ordered_clusters
            st.session_state.route_cgrs_dict = route_cgrs_dict
            st.session_state.r_route_cgrs_dict = r_route_cgrs_dict
            st.write("Extracting reactions...")
            st.session_state.reactions_dict = extract_reactions(current_tree)

            clustering_ok = (
                st.session_state.clusters is not None
                and st.session_state.reactions_dict is not None
            )
            if clustering_ok:
                st.session_state.clustering_done = True
                st.success(
                    f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
                )
            else:
                st.error("Clustering failed or returned empty results.")
                st.session_state.clustering_done = False

            # The results now live in the session state; release the local
            # references before rerunning the script.
            del raw_clusters, ordered_clusters
            gc.collect()
            st.rerun()
        except Exception as e:
            st.error(f"An error occurred during clustering: {e}")
            st.session_state.clustering_done = False
643
+
644
+
645
def display_clustering_results():
    """8. Clustering Results Display: Handling the presentation of results.

    For every cluster, the route of its first listed node (``node_ids[0]``) is
    drawn next to the cluster's ReducedRouteCGR. The first
    ``MAX_DISPLAY_CLUSTERS_DATA`` clusters are rendered directly, the remaining
    ones inside an expander. Nothing is rendered unless
    ``st.session_state.clustering_done`` is set; if the clusters or the tree
    are missing, an error is shown and the flag is reset.
    """
    if not st.session_state.get("clustering_done", False):
        return

    clusters = st.session_state.clusters
    tree = st.session_state.tree
    MAX_DISPLAY_CLUSTERS_DATA = 10

    if clusters is None or tree is None:
        st.error(
            "Clustering results (clusters or tree) are missing. Please re-run clustering."
        )
        st.session_state.clustering_done = False
        return

    st.subheader(f"Best routes from {len(clusters)} Found Clusters")
    clusters_items = list(clusters.items())
    first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
    remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]

    for cluster_num, group_data in first_items:
        _render_cluster_route(
            tree,
            cluster_num,
            group_data,
            f"Cluster {cluster_num} has no data or node_ids.",
        )

    if remaining_items:
        with st.expander(f"... and {len(remaining_items)} more clusters"):
            for cluster_num, group_data in remaining_items:
                _render_cluster_route(
                    tree,
                    cluster_num,
                    group_data,
                    f"Cluster {cluster_num} in expansion has no data or node_ids.",
                )


def _render_cluster_route(tree, cluster_num, group_data, empty_warning):
    """Render the header, representative route and ReducedRouteCGR of one cluster.

    :param tree: The search tree used to look up and draw the route.
    :param cluster_num: The cluster key, used in captions and messages.
    :param group_data: The cluster record; ``node_ids`` is required, while
        ``group_size`` and ``r_route_cgr`` are optional.
    :param empty_warning: The warning text shown when the record has no usable
        ``node_ids``.
    """
    if not group_data or "node_ids" not in group_data or not group_data["node_ids"]:
        st.warning(empty_warning)
        return

    st.markdown(
        f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
    )
    node_id = group_data["node_ids"][0]
    try:
        num_steps = len(tree.synthesis_route(node_id))
        route_score = round(tree.route_score(node_id), 3)
        svg = get_route_svg(tree, node_id)
        r_route_cgr = group_data.get("r_route_cgr")  # optional in the record
        r_route_cgr_svg = None
        if r_route_cgr:
            r_route_cgr.clean2d()
            r_route_cgr_svg = cgr_display(r_route_cgr)

        route_caption = (
            f"Route {node_id}; {num_steps} steps; Route score: {route_score}"
        )
        if svg and r_route_cgr_svg:
            col1, col2 = st.columns([0.2, 0.8])
            with col1:
                st.image(r_route_cgr_svg, caption="ReducedRouteCGR")
            with col2:
                st.image(svg, caption=route_caption)
        elif svg:  # only the route SVG is available
            st.image(svg, caption=route_caption)
            st.warning(
                f"ReducedRouteCGR could not be displayed for cluster {cluster_num}."
            )
        else:
            st.warning(
                f"Could not generate SVG for route {node_id} or its ReducedRouteCGR."
            )
    except Exception as e:
        # Best effort: a failing cluster must not abort rendering of the others.
        st.error(f"Error displaying route {node_id} for cluster {cluster_num}: {e}")
766
+
767
+
768
def download_clustering_results():
    """10. Clustering Results Download: Providing functionality to download.

    Renders one HTML-report download button per cluster (the first
    ``MAX_DOWNLOAD_LINKS_DISPLAYED`` directly, the rest inside an expander) and
    a button that bundles all reports into a ZIP archive.

    Each report is generated once and reused for the ZIP archive. The target
    SMILES is sanitised before it is used in file names, because characters
    such as ``/`` and ``\\`` would otherwise create nested paths inside the
    archive.
    """
    if not st.session_state.get("clustering_done", False):
        return

    tree_for_html = st.session_state.get("tree")
    clusters_for_html = st.session_state.get("clusters")
    # Passed through to routes_clustering_report; not validated here.
    r_route_cgrs_for_html = st.session_state.get("r_route_cgrs_dict")

    if not tree_for_html:
        st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
        return
    if not clusters_for_html:
        st.warning("Cluster data not found. Cannot generate cluster reports.")
        return

    st.subheader("Cluster Reports")
    st.write("Generate downloadable HTML reports for each cluster:")

    MAX_DOWNLOAD_LINKS_DISPLAYED = 10
    cluster_keys = list(clusters_for_html)
    smiles_tag = _filename_safe(st.session_state.get("target_smiles", ""))
    reports = {}  # cluster key -> generated HTML, reused for the ZIP archive

    for cluster_idx in cluster_keys[:MAX_DOWNLOAD_LINKS_DISPLAYED]:
        html_content = _offer_cluster_report(
            tree_for_html,
            clusters_for_html,
            r_route_cgrs_for_html,
            cluster_idx,
            file_name=f"cluster_{cluster_idx}_{smiles_tag}.html",
            key=f"download_cluster_{cluster_idx}",
            error_label=f"cluster {cluster_idx}",
        )
        if html_content is not None:
            reports[cluster_idx] = html_content

    remaining_keys = cluster_keys[MAX_DOWNLOAD_LINKS_DISPLAYED:]
    if remaining_keys:
        expander_label = f"Show remaining {len(remaining_keys)} cluster reports"
        with st.expander(expander_label):
            for group_index in remaining_keys:
                html_content = _offer_cluster_report(
                    tree_for_html,
                    clusters_for_html,
                    r_route_cgrs_for_html,
                    group_index,
                    file_name=f"cluster_{group_index}_{smiles_tag}.html",
                    key=f"download_cluster_expanded_{group_index}",
                    error_label=f"cluster {group_index} (expanded)",
                )
                if html_content is not None:
                    reports[group_index] = html_content

    if not reports:
        # Every report failed; the individual errors were already shown above.
        return

    try:
        buffer = io.BytesIO()
        with zipfile.ZipFile(
            buffer, mode="w", compression=zipfile.ZIP_DEFLATED
        ) as zf:
            for idx, html_content in reports.items():
                zf.writestr(f"cluster_{idx}_{smiles_tag}.html", html_content)
        buffer.seek(0)

        st.download_button(
            label="📦 Download all cluster reports as ZIP",
            data=buffer,
            file_name=f"all_cluster_reports_{smiles_tag}.zip",
            mime="application/zip",
            key="download_all_clusters_zip",
        )
    except Exception as e:
        st.error(f"Error generating ZIP file for cluster reports: {e}")


def _offer_cluster_report(
    tree, clusters, r_route_cgrs, cluster_idx, file_name, key, error_label
):
    """Generate one cluster report and render its download button.

    :return: The generated HTML, or ``None`` if report generation failed (the
        error is shown in the UI instead of being raised).
    """
    html_content = None
    try:
        html_content = routes_clustering_report(
            tree,
            clusters,  # the whole clusters dict
            str(cluster_idx),  # key of the cluster to report on
            r_route_cgrs,
            aam=False,
        )
        st.download_button(
            label=f"Download report for cluster {cluster_idx}",
            data=html_content,
            file_name=file_name,
            mime="text/html",
            key=key,
        )
    except Exception as e:
        st.error(f"Error generating report for {error_label}: {e}")
    return html_content


def _filename_safe(text):
    """Return *text* with path separators and other unsafe characters replaced by ``_``."""
    import re  # local import: only needed for building download file names

    return re.sub(r'[\\/:*?"<>|\s]+', "_", str(text))
870
+
871
+
872
def setup_subclustering():
    """11. Subclustering: Encapsulating the logic related to the "subclustering" functionality.

    Renders the section header and the "Run Subclustering Analysis" button.
    When the button is pressed, previous subclustering results are discarded,
    ``subcluster_all_clusters`` is run on the stored clusters, ReducedRouteCGRs
    and RouteCGRs, the result is saved in the session state and the script is
    rerun. Missing inputs and exceptions are reported as errors in the UI.
    """
    # Subclustering is only offered once clustering has finished.
    if not st.session_state.get("clustering_done", False):
        return

    st.divider()
    st.header("Sub-Clustering within a selected Cluster")

    if not st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
        return

    # Invalidate previous results before recomputing.
    st.session_state.subclustering_done = False
    st.session_state.subclusters = None

    with st.spinner("Performing subclustering analysis..."):
        try:
            stored_clusters = st.session_state.get("clusters")
            stored_r_route_cgrs = st.session_state.get("r_route_cgrs_dict")
            stored_route_cgrs = st.session_state.get("route_cgrs_dict")

            labelled_inputs = (
                ("clusters", stored_clusters),
                ("ReducedRouteCGRs dictionary", stored_r_route_cgrs),
                ("RouteCGRs dictionary", stored_route_cgrs),
            )
            missing = [label for label, value in labelled_inputs if not value]

            if missing:
                st.error(
                    f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
                )
                st.session_state.subclustering_done = False
            else:
                st.session_state.subclusters = subcluster_all_clusters(
                    stored_clusters,
                    stored_r_route_cgrs,
                    stored_route_cgrs,
                )
                st.session_state.subclustering_done = True
                st.success("Subclustering analysis complete.")
                gc.collect()
                st.rerun()

        except Exception as e:
            st.error(f"An error occurred during subclustering: {e}")
            st.session_state.subclustering_done = False
922
+
923
+
924
def display_subclustering_results():
    """12. Subclustering Results Display: Handling the presentation of results.

    The left column holds the cluster and subcluster selectors and the
    ReducedRouteCGR of the selection. The right column shows the subcluster's
    pseudo reaction and its routes: the first ``MAX_ROUTES_PER_SUBCLUSTER``
    directly, the remainder inside an expander. Nothing is rendered unless
    ``st.session_state.subclustering_done`` is set; if the subclusters or the
    tree are missing, an error is shown and the flag is reset.
    """
    if not st.session_state.get("subclustering_done", False):
        return

    sub = st.session_state.get("subclusters")
    tree = st.session_state.get("tree")

    if not sub or not tree:
        st.error(
            "Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
        )
        st.session_state.subclustering_done = False
        return

    sub_input_col, sub_display_col = st.columns([0.25, 0.75])

    with sub_input_col:
        st.subheader("Select Cluster and Subcluster")
        available_cluster_nums = list(sub.keys())
        if not available_cluster_nums:
            st.warning("No clusters available in subclustering results.")
            return  # nothing to select from

        user_input_cluster_num_display = st.selectbox(
            "Select Cluster #:",
            options=sorted(available_cluster_nums),
            key="subcluster_num_select_key",
        )

        # Fallback index used when the cluster has no subclusters to choose from.
        selected_subcluster_idx = 0

        if user_input_cluster_num_display in sub:
            sub_step_cluster = sub[user_input_cluster_num_display]
            allowed_subclusters_indices = sorted(sub_step_cluster.keys())

            if not allowed_subclusters_indices:
                st.warning(
                    f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
                )
            else:
                selected_subcluster_idx = st.selectbox(
                    "Select Subcluster Index:",
                    options=allowed_subclusters_indices,
                    key="subcluster_index_select_key",
                )
                if selected_subcluster_idx in sub_step_cluster:
                    current_subcluster_data = sub_step_cluster[selected_subcluster_idx]
                    if "r_route_cgr" in current_subcluster_data:
                        cluster_r_route_cgr_display = current_subcluster_data[
                            "r_route_cgr"
                        ]
                        cluster_r_route_cgr_display.clean2d()
                        st.image(
                            cluster_r_route_cgr_display.depict(),
                            caption=f"ReducedRouteCGR of parent Cluster {user_input_cluster_num_display}",
                        )
                    else:
                        st.warning("ReducedRouteCGR for this subcluster not found.")
        else:
            st.warning(
                f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
            )
            return

    with sub_display_col:
        st.subheader("Subcluster Details")
        if not (
            user_input_cluster_num_display in sub
            and selected_subcluster_idx in sub[user_input_cluster_num_display]
        ):
            st.info("Select a valid cluster and subcluster index to see details.")
            return

        # post_process_subgroup() is intentionally not applied here (under development).
        subcluster_to_display = sub[user_input_cluster_num_display][
            selected_subcluster_idx
        ]
        if (
            not subcluster_to_display
            or "nodes_data" not in subcluster_to_display
            or not subcluster_to_display["nodes_data"]
        ):
            st.info("No routes or data found for this subcluster selection.")
            return

        MAX_ROUTES_PER_SUBCLUSTER = 5
        all_route_ids_in_subcluster = list(subcluster_to_display["nodes_data"].keys())
        routes_to_display_direct = all_route_ids_in_subcluster[
            :MAX_ROUTES_PER_SUBCLUSTER
        ]
        remaining_routes_sub = all_route_ids_in_subcluster[MAX_ROUTES_PER_SUBCLUSTER:]

        st.markdown(
            f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
        )

        if "synthon_reaction" in subcluster_to_display:
            synthon_reaction = subcluster_to_display["synthon_reaction"]
            try:
                st.image(
                    depict_custom_reaction(synthon_reaction),
                    caption="Markush-like pseudo reaction of subcluster",
                )
            except Exception as e_depict:
                st.warning(f"Could not depict synthon reaction: {e_depict}")
        else:
            st.info("No synthon reaction data for this subcluster.")

        for route_id in routes_to_display_direct:
            _render_subcluster_route(tree, route_id, "in subcluster")

        if remaining_routes_sub:
            with st.expander(
                f"... and {len(remaining_routes_sub)} more routes in this subcluster"
            ):
                for route_id in remaining_routes_sub:
                    _render_subcluster_route(
                        tree, route_id, "in subcluster (expanded)"
                    )


def _render_subcluster_route(tree, route_id, error_context):
    """Draw one route of a subcluster with its score, reporting failures in the UI.

    :param tree: The search tree used to score and draw the route.
    :param route_id: The id of the route (tree node) to draw.
    :param error_context: Text inserted into the error message to tell where
        the route was being rendered.
    """
    try:
        route_score_sub = round(tree.route_score(route_id), 3)
        svg_sub = get_route_svg(tree, route_id)
        if svg_sub:
            st.image(
                svg_sub,
                caption=f"Route {route_id}; Score: {route_score_sub}",
            )
        else:
            st.warning(f"Could not generate SVG for route {route_id}.")
    except Exception as e:
        # Best effort: one failing route must not hide the others.
        st.error(f"Error displaying route {route_id} {error_context}: {e}")
1080
+
1081
+
1082
def download_subclustering_results():
    """13. Subclustering Results Download: Providing functionality to download.

    Offers an HTML report for the cluster/subcluster pair currently chosen in
    the two selectboxes (read back from the session state via their widget
    keys). Renders nothing when subclustering has not run, when a selection is
    absent, or when the selection does not exist in the stored subclusters.
    """
    state = st.session_state
    selection_ready = (
        state.get("subclustering_done", False)
        and "subcluster_num_select_key" in state
        and "subcluster_index_select_key" in state
    )
    if not selection_ready:
        return

    sub = state.get("subclusters")
    tree = state.get("tree")
    r_route_cgrs_for_report = state.get("r_route_cgrs_dict")

    cluster_num = state.subcluster_num_select_key
    subcluster_idx = state.subcluster_index_select_key

    if not tree or not sub or not r_route_cgrs_for_report:
        st.warning(
            "Missing data for subclustering report generation (tree, subclusters, or ReducedRouteCGRs)."
        )
        return

    # An invalid selection is reported by the display step; here it simply
    # means no download button is rendered.
    if cluster_num not in sub or subcluster_idx not in sub[cluster_num]:
        return

    raw_subcluster = sub[cluster_num][subcluster_idx]
    processed_subcluster_data = post_process_subgroup(raw_subcluster)

    # Attach the leaving-group grouping computed from the unprocessed node data.
    nodes_data = raw_subcluster["nodes_data"] if "nodes_data" in raw_subcluster else None
    if isinstance(nodes_data, dict):
        processed_subcluster_data["group_lgs"] = group_by_identical_values(nodes_data)
    else:
        processed_subcluster_data["group_lgs"] = {}

    try:
        subcluster_html_content = routes_subclustering_report(
            tree,
            processed_subcluster_data,
            cluster_num,
            subcluster_idx,
            r_route_cgrs_for_report,
            if_lg_group=True,
        )
        st.download_button(
            label=f"Download report for subcluster {cluster_num}.{subcluster_idx}",
            data=subcluster_html_content,
            file_name=f"subcluster_{cluster_num}.{subcluster_idx}_{st.session_state.target_smiles}.html",
            mime="text/html",
            key=f"download_subcluster_{cluster_num}_{subcluster_idx}",
        )
    except Exception as e:
        st.error(
            f"Error generating download report for subcluster {cluster_num}.{subcluster_idx}: {e}"
        )
1148
+
1149
+
1150
def implement_restart():
    """14. Restart: Implementing the logic to reset or restart the application state.

    Renders the restart section. When the button is pressed, all result and
    widget keys are removed from the session state, the Ketcher input is reset
    to ``DEFAULT_MOL``, the target SMILES is cleared and the script is rerun.
    """
    st.divider()
    st.header("Restart Application State")
    if not st.button("Clear All Results & Restart", key="restart_button"):
        return

    stale_keys = (
        # planning results
        "planning_done",
        "tree",
        "res",
        "target_smiles",
        # clustering results
        "clustering_done",
        "clusters",
        "reactions_dict",
        "num_clusters_setting",
        "route_cgrs_dict",
        "r_route_cgrs_dict",
        # subclustering results
        "subclustering_done",
        "subclusters",
        "clusters_downloaded",
        # widget keys that need an explicit reset
        "ketcher_widget",
        "smiles_text_input_key",
        "subcluster_num_select_key",
        "subcluster_index_select_key",
    )
    for stale_key in stale_keys:
        if stale_key in st.session_state:
            del st.session_state[stale_key]

    # Restore the molecule editor to its default and drop the stale target.
    st.session_state.ketcher = DEFAULT_MOL
    st.session_state.target_smiles = ""

    st.rerun()
1187
+
1188
+
1189
+ # --- Main Application Flow ---
1190
def main():
    """Run the application flow: input, planning, clustering, subclustering, restart.

    Each section is rendered only when the session state says the previous
    stage produced usable results.
    """
    initialize_app()
    setup_sidebar()
    current_smile_code = handle_molecule_input()
    # Keep the stored Ketcher value in sync with the editor output (no rerun).
    if st.session_state.get("ketcher") != current_smile_code:
        st.session_state.ketcher = current_smile_code

    setup_planning_options()  # also handles the planning button press

    # Planning results, statistics and downloads.
    if st.session_state.get("planning_done", False):
        display_planning_results()
        if st.session_state.res and st.session_state.res.get("solved", False):
            stat_col, download_col = st.columns(2, gap="medium")
            with stat_col:
                _show_planning_statistics()
            with download_col:
                st.subheader("Planning Downloads")
                download_planning_results()

    # Clustering and subclustering are only offered for a solved target.
    if (
        st.session_state.get("planning_done", False)
        and st.session_state.res
        and st.session_state.res.get("solved", False)
    ):
        setup_clustering()  # contains the "Run Clustering" button and logic
        if st.session_state.get("clustering_done", False):
            display_clustering_results()
            cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
            with cluster_stat_col:
                _show_cluster_statistics()
            with cluster_download_col:
                download_clustering_results()

        # NOTE(review): assumed to be nested under the solved-planning check;
        # confirm against the original indentation.
        if st.session_state.get("clustering_done", False):
            setup_subclustering()  # contains the "Run Subclustering" button
            if st.session_state.get("subclustering_done", False):
                display_subclustering_results()
                # Must run after the display step, which creates the selectboxes
                # whose values the download step reads.
                download_subclustering_results()

    implement_restart()


def _show_planning_statistics():
    """Show a one-row table with the key statistics of the planning run."""
    st.subheader("Statistics")
    # Bound before the try block so the fallback below can always print it.
    res = st.session_state.res
    try:
        if "target_smiles" not in res and "target_smiles" in st.session_state:
            res["target_smiles"] = st.session_state.target_smiles
        cols_to_show = [
            col
            for col in [
                "target_smiles",
                "num_routes",
                "num_nodes",
                "num_iter",
                "search_time",
            ]
            if col in res
        ]
        if cols_to_show:
            df = pd.DataFrame(res, index=[0])[cols_to_show]
            st.dataframe(df)
        else:
            st.write("No statistics to display from planning results.")
    except Exception as e:
        st.error(f"Error displaying statistics: {e}")
        st.write(res)  # show the raw dict if the DataFrame fails


def _show_cluster_statistics():
    """Show the routes-per-cluster table and the best-routes download button."""
    # display_clustering_results() may have just found the clusters missing,
    # so tolerate None here instead of failing on ``None.values()``.
    clusters = st.session_state.get("clusters") or {}
    non_empty = {key: value for key, value in clusters.items() if value}
    cluster_sizes = [value.get("group_size", 0) for value in non_empty.values()]

    st.subheader("Cluster Statistics")
    if not cluster_sizes:
        st.write("No cluster data to display statistics for.")
        return

    cluster_df = pd.DataFrame(
        {
            "Cluster": list(non_empty),
            "Number of Routes": cluster_sizes,
        }
    )
    if cluster_df.empty:
        st.write("No valid cluster data to display statistics for.")
        return

    cluster_df.index += 1  # 1-based row labels for display
    st.dataframe(cluster_df)
    best_route_html = html_top_routes_cluster(
        clusters,
        st.session_state.tree,
        st.session_state.target_smiles,
    )
    st.download_button(
        label="Download best route from each cluster",
        data=best_route_html,
        file_name=f"cluster_best_{st.session_state.target_smiles}.html",
        mime="text/html",
        key="download_cluster_best",
    )
1301
+
1302
+
1303
# Script entry point: run the application flow only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
synplan/interfaces/uspto/uspto_reaction_rules.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d9f275781cc926eeff1d9a9564b4aa335b66506c621f943c82ec902460bf977
3
+ size 45489168
synplan/interfaces/uspto/weights/ranking_policy_network.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1c8852c4d8177538ba2a815d53e7b29f27e8a6067341f05b136a690bc46d53e
3
+ size 164172437
synplan/mcts/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.containers import MoleculeContainer
2
+ from .node import *
3
+ from .tree import *
4
+
5
+
6
# Import-time side effect: changes the depiction settings of MoleculeContainer.
# NOTE(review): `aam=False` presumably hides atom-to-atom mapping labels in
# depictions — confirm against the CGRtools documentation.
MoleculeContainer.depict_settings(aam=False)

# Public names exported by the package.
__all__ = ["Tree", "Node"]
synplan/mcts/evaluation.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class that represents a value function for prediction of
2
+ synthesisability of new nodes in the tree search."""
3
+
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+ from synplan.chem.precursor import Precursor, compose_precursors
9
+ from synplan.ml.networks.value import ValueNetwork
10
+ from synplan.ml.training import mol_to_pyg
11
+
12
+
13
class ValueNetworkFunction:
    """Value function backed by a value neural network; it scores nodes
    (synthesisability prediction) during the tree search."""

    def __init__(self, weights_path: str) -> None:
        """Load the value network from a checkpoint and switch it to evaluation mode.

        The checkpoint is mapped to the CPU device.

        :param weights_path: The value network weights file path.
        """

        self.value_network = ValueNetwork.load_from_checkpoint(
            weights_path, map_location=torch.device("cpu")
        ).eval()

    def predict_value(self, precursors: List[Precursor,]) -> float:
        """Predict the value of a node from its precursors.

        The precursors are first composed into a single molecule, which is then
        converted to a graph and passed through the value network.

        :param precursors: The list of precursors.
        :return: The predicted float value ("synthesisability") of the node, or
            ``-1e6`` when the composed molecule cannot be converted to a graph.
        """

        composed = compose_precursors(precursors=precursors, exclude_small=True)
        graph = mol_to_pyg(composed)
        if not graph:
            # Sentinel score for molecules that cannot be featurised.
            return -1e6

        with torch.no_grad():
            return self.value_network.forward(graph)[0].item()
synplan/mcts/expansion.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class that represents a policy function for node expansion in the
2
+ tree search."""
3
+
4
+ from typing import Iterator, List, Tuple, Union
5
+
6
+ import torch
7
+ import torch_geometric
8
+ from CGRtools.reactor.reactor import Reactor
9
+
10
+ from synplan.chem.precursor import Precursor
11
+ from synplan.ml.networks.policy import PolicyNetwork
12
+ from synplan.ml.training import mol_to_pyg
13
+ from synplan.utils.config import PolicyNetworkConfig
14
+
15
+
16
class PolicyNetworkFunction:
    """Policy function implemented as a policy neural network for node expansion in tree
    search."""

    def __init__(
        self, policy_config: PolicyNetworkConfig, compile: bool = False
    ) -> None:
        """Initializes the expansion function (ranking or filter policy network).

        The network is loaded from ``policy_config.weights_path`` onto the CPU
        and switched to evaluation mode.

        :param policy_config: An expansion policy configuration.
        :param compile: If True, the loaded network is wrapped with
            ``torch_geometric.compile`` (dynamic shapes enabled).
        """

        self.config = policy_config

        policy_net = PolicyNetwork.load_from_checkpoint(
            self.config.weights_path,
            map_location=torch.device("cpu"),
            batch_size=1,
            dropout=0,
        )

        policy_net = policy_net.eval()
        if compile:
            self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
        else:
            self.policy_net = policy_net

    def predict_reaction_rules(
        self, precursor: Precursor, reaction_rules: List[Reactor]
    ) -> Iterator[Tuple[float, Reactor, int]]:
        """The policy function predicts the list of reaction rules for a given precursor.

        This is a generator: nothing is computed (and no error is raised) until
        it is iterated.

        :param precursor: The current precursor for which the reaction rules are predicted.
        :param reaction_rules: The list of reaction rules from which applicable reaction
            rules are predicted and selected.
        :return: Yielding the predicted probability for the reaction rule, reaction rule
            and reaction rule id. Nothing is yielded when the precursor cannot be
            converted to a graph.
        :raises ValueError: If the network output size does not match the number of
            reaction rules, or if the network has an unknown policy type.
        """

        out_dim = list(self.policy_net.modules())[-1].out_features
        if out_dim != len(reaction_rules):
            raise ValueError(
                f"The policy network output dimensionality is {out_dim}, but the number of reaction rules is {len(reaction_rules)}. "
                "Probably you use a different version of the policy network. Be sure to retrain the policy network "
                "with the current set of reaction rules"
            )

        policy_type = self.policy_net.policy_type
        pyg_graph = mol_to_pyg(precursor.molecule, canonicalize=False)
        if not pyg_graph:
            # Inside a generator a bare return simply ends the iteration.
            return

        priority = None
        with torch.no_grad():
            if policy_type == "filtering":
                probs, priority = self.policy_net.forward(pyg_graph)
            elif policy_type == "ranking":
                probs = self.policy_net.forward(pyg_graph)
            else:
                # Previously this fell through and failed on an unbound `probs`.
                raise ValueError(f"Unknown policy type: {policy_type!r}")
        del pyg_graph

        probs = probs[0].double()
        if policy_type == "filtering":
            # Blend the rule probabilities with the priority head of the network.
            priority = priority[0].double()
            priority_coef = self.config.priority_rules_fraction
            probs = (1 - priority_coef) * probs + priority_coef * priority

        sorted_probs, sorted_rules = torch.sort(probs, descending=True)
        sorted_probs, sorted_rules = (
            sorted_probs[: self.config.top_rules],
            sorted_rules[: self.config.top_rules],
        )

        if policy_type == "filtering":
            # Renormalise the kept scores so that they sum to one.
            sorted_probs = torch.softmax(sorted_probs, -1)

        sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist()

        for prob, rule_id in zip(sorted_probs, sorted_rules):
            # The search may fail if rule_prob_threshold is too high, because no
            # rule would pass this filter (recommended value is 0.0).
            if prob > self.config.rule_prob_threshold:
                yield prob, reaction_rules[rule_id], rule_id
synplan/mcts/node.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class Node in the tree search."""
2
+
3
+
4
class Node:
    """Node class represents a node in the tree search."""

    def __init__(
        self, precursors_to_expand: tuple = None, new_precursors: tuple = None
    ) -> None:
        """The function initializes the new Node object.

        :param precursors_to_expand: The tuple of precursors to be expanded. The first precursor
            in the tuple is the current precursor which will be expanded (for which new
            precursors will be generated by applying the predicted reaction rules). When
            the first precursor has been successfully expanded, the second precursor becomes
            the current precursor to be expanded. ``None`` (the default) is treated as an
            empty tuple, i.e. a solved node.
        :param new_precursors: The tuple of new precursors generated by applying the reaction
            rule.
        """

        # The default None used to crash on len(None); normalise it to "nothing to expand".
        self.precursors_to_expand = (
            () if precursors_to_expand is None else precursors_to_expand
        )
        self.new_precursors = new_precursors

        if len(self.precursors_to_expand) == 0:
            self.curr_precursor = tuple()
        else:
            self.curr_precursor = self.precursors_to_expand[0]
        # Always defined; slicing an empty sequence just yields an empty one.
        self.next_precursor = self.precursors_to_expand[1:]

    def __len__(self) -> int:
        """Returns the number of precursors in the node to expand."""
        return len(self.precursors_to_expand)

    def __repr__(self) -> str:
        """Returns a two-line text showing new_precursors and precursors_to_expand."""
        return (
            f"New precursors: {self.new_precursors}\n"
            f"Precursors to expand: {self.precursors_to_expand}\n"
        )

    def is_solved(self) -> bool:
        """If True, it is a terminal node.

        There are no precursors for expansion.
        """

        return len(self.precursors_to_expand) == 0
synplan/mcts/search.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for running tree search for the set of target
2
+ molecules."""
3
+
4
+ import csv
5
+ import json
6
+ import logging
7
+ import os.path
8
+ from pathlib import Path
9
+ from typing import Union
10
+
11
+ from CGRtools.containers import MoleculeContainer
12
+ from tqdm import tqdm
13
+
14
+ from synplan.chem.reaction_routes.route_cgr import extract_reactions
15
+ from synplan.chem.reaction_routes.io import write_routes_csv, write_routes_json
16
+ from synplan.chem.utils import mol_from_smiles
17
+ from synplan.mcts.evaluation import ValueNetworkFunction
18
+ from synplan.mcts.expansion import PolicyNetworkFunction
19
+ from synplan.mcts.tree import Tree, TreeConfig
20
+ from synplan.utils.config import PolicyNetworkConfig
21
+ from synplan.utils.loading import load_building_blocks, load_reaction_rules
22
+ from synplan.utils.visualisation import extract_routes, generate_results_html
23
+
24
+
25
def extract_tree_stats(
    tree: Tree, target: Union[str, MoleculeContainer], init_smiles: str = None
):
    """Gathers summary statistics of a finished tree search into a dictionary.

    :param tree: The built search tree.
    :param target: The target molecule associated with the tree. Its string form is
        reported when ``init_smiles`` is not given.
    :param init_smiles: Initial SMILES of the molecule, optional. Takes precedence
        over ``target`` in the report.
    :return: A dictionary with the calculated statistics (target SMILES, number of
        routes/nodes/iterations, tree depth, search time, Newick tree string with
        its per-node metadata line, and the solved flag).
    """

    tree_string, node_meta = tree.newickify(visits_threshold=0)

    # One "id,v0,v1,v2" record per node, records joined with ';'.
    meta_records = []
    for nid, values in node_meta.items():
        meta_records.append(f"{nid},{values[0]},{values[1]},{values[2]}")
    meta_line = ";".join(meta_records)

    if init_smiles is None:
        reported_smiles = str(target)
    else:
        reported_smiles = init_smiles

    n_routes = len(tree.winning_nodes)

    stats = {}
    stats["target_smiles"] = reported_smiles
    stats["num_routes"] = n_routes
    stats["num_nodes"] = len(tree)
    stats["num_iter"] = tree.curr_iteration
    stats["tree_depth"] = max(tree.nodes_depth.values())
    stats["search_time"] = round(tree.curr_time, 1)
    stats["newick_tree"] = tree_string
    stats["newick_meta"] = meta_line
    stats["solved"] = n_routes > 0
    return stats
52
+
53
+
54
def _dump_routes_json(routes_file, extracted_routes) -> None:
    """Overwrites the routes JSON file with everything collected so far."""
    with open(routes_file, "w", encoding="utf-8") as f:
        json.dump(extracted_routes, f)


def run_search(
    targets_path: str,
    search_config: dict,
    policy_config: PolicyNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    value_network_path: str = None,
    results_root: str = "search_results",
) -> None:
    """Performs a tree search on a set of target molecules using specified configuration
    and reaction rules, logging the results and statistics.

    The targets file is read line by line; every stripped line is passed to
    ``mol_from_smiles``. A target that cannot be parsed or searched is logged,
    recorded in the statistics file with its error message, and skipped, so a
    single bad target does not abort the whole batch.

    :param targets_path: The path to the text file containing the target molecules,
        one SMILES per line.
    :param search_config: The dictionary with the configuration for the tree search
        (must contain the ``evaluation_type`` key).
    :param policy_config: The config object containing the configuration for the policy.
    :param reaction_rules_path: The path to the file containing reaction rules.
    :param building_blocks_path: The path to the file containing building blocks.
    :param value_network_path: The path to the file containing value weights (optional).
        Only used when the evaluation type is ``"gcn"``.
    :param results_root: The name of the folder where the results of the tree search
        will be saved.
    :return: None.
    """

    # results folder; parents=True lets nested output paths work and
    # exist_ok=True removes the check-then-create race
    results_root = Path(results_root)
    results_root.mkdir(parents=True, exist_ok=True)

    # output files
    stats_file = results_root.joinpath("tree_search_stats.csv")
    routes_file = results_root.joinpath("extracted_routes.json")
    routes_folder = results_root.joinpath("extracted_routes_html")
    routes_folder.mkdir(exist_ok=True)

    # stats header
    stats_header = [
        "target_smiles",
        "num_routes",
        "num_nodes",
        "num_iter",
        "tree_depth",
        "search_time",
        "newick_tree",
        "newick_meta",
        "solved",
        "error",
    ]

    # config
    policy_function = PolicyNetworkFunction(policy_config=policy_config)
    if search_config["evaluation_type"] == "gcn" and value_network_path:
        value_function = ValueNetworkFunction(weights_path=value_network_path)
    else:
        value_function = None

    reaction_rules = load_reaction_rules(reaction_rules_path)
    building_blocks = load_building_blocks(building_blocks_path, standardize=True)

    # run search
    n_solved = 0
    extracted_routes = []

    tree_config = TreeConfig.from_dict(search_config)
    tree_config.silent = True
    with (
        open(targets_path, "r", encoding="utf-8") as targets,
        open(stats_file, "w", encoding="utf-8", newline="\n") as csvfile,
    ):

        statswriter = csv.DictWriter(csvfile, delimiter=",", fieldnames=stats_header)
        statswriter.writeheader()

        for ti, target_smi in tqdm(
            enumerate(targets),
            leave=True,
            desc="Number of target molecules processed: ",
            bar_format="{desc}{n} [{elapsed}]",
        ):
            target_smi = target_smi.strip()
            try:
                # parsing is inside the try so that an unparsable line is
                # handled like any other failed target
                target_mol = mol_from_smiles(target_smi)

                # run search
                tree = Tree(
                    target=target_mol,
                    config=tree_config,
                    reaction_rules=reaction_rules,
                    building_blocks=building_blocks,
                    expansion_function=policy_function,
                    evaluation_function=value_function,
                )

                _ = list(tree)

            except Exception as e:
                # placeholder route: the bare target with no children
                extracted_routes.append(
                    [
                        {
                            "type": "mol",
                            "smiles": target_smi,
                            "in_stock": False,
                            "children": [],
                        }
                    ]
                )
                logging.warning(
                    "Retrosynthetic_planning %s failed with the following error: %s",
                    target_smi,
                    e,
                )

                # fill the "error" column; DictWriter leaves the columns
                # missing from the row empty
                statswriter.writerow(
                    {"target_smiles": target_smi, "solved": False, "error": str(e)}
                )
                csvfile.flush()

                # persist the placeholder even if no later target succeeds
                _dump_routes_json(routes_file, extracted_routes)

                continue

            # is solved
            solved = bool(tree.winning_nodes)
            n_solved += solved
            if solved:

                # extract routes
                extracted_routes.append(extract_routes(tree))

                # save routes
                generate_results_html(
                    tree,
                    os.path.join(routes_folder, f"retroroutes_target_{ti}.html"),
                    extended=True,
                )

            # save stats
            statswriter.writerow(extract_tree_stats(tree, target_smi))
            csvfile.flush()

            # save json routes (rewritten after every target as a checkpoint)
            _dump_routes_json(routes_file, extracted_routes)

            # save mapped reactions (CSV)
            routes_dict = extract_reactions(tree)
            write_routes_csv(
                routes_dict, os.path.join(routes_folder, f"mapped_routes_{ti}.csv")
            )

            # save mapped reactions (JSON)
            write_routes_json(
                routes_dict, os.path.join(routes_folder, f"mapped_routes_{ti}.json")
            )

    print(f"Number of solved target molecules: {n_solved}")
synplan/mcts/tree.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing a class Tree that used for tree search of retrosynthetic routes."""
2
+
3
+ import logging
4
+ import warnings
5
+ from collections import defaultdict, deque
6
+ from math import sqrt
7
+ from random import choice, uniform
8
+ from time import time
9
+ from typing import Dict, List, Set, Tuple
10
+
11
+ from CGRtools.reactor import Reactor
12
+ from CGRtools.containers import MoleculeContainer
13
+ from tqdm.auto import tqdm
14
+
15
+ from synplan.chem.precursor import Precursor
16
+ from synplan.chem.reaction import Reaction, apply_reaction_rule
17
+ from synplan.mcts.evaluation import ValueNetworkFunction
18
+ from synplan.mcts.expansion import PolicyNetworkFunction
19
+ from synplan.mcts.node import Node
20
+ from synplan.utils.config import TreeConfig
21
+
22
+
23
class Tree:
    """Tree class with attributes and methods for Monte-Carlo tree search.

    The tree is stored as a set of dictionaries keyed by integer node ids. The
    root (the target molecule) has id 1 and its parent id is 0, which serves as
    the stop marker for every walk from a node up to the root.

    The search is driven through the iterator protocol: ``iter(tree)`` starts
    the clock (and the progress bar) and every ``next(tree)`` performs one
    search iteration.
    """

    def __init__(
        self,
        target: MoleculeContainer,
        config: TreeConfig,
        reaction_rules: List[Reactor],
        building_blocks: Set[str],
        expansion_function: PolicyNetworkFunction,
        evaluation_function: ValueNetworkFunction = None,
    ):
        """Initializes a tree object with optional parameters for tree search for target
        molecule.

        :param target: A target molecule for retrosynthetic routes search.
        :param config: A tree configuration.
        :param reaction_rules: A loaded reaction rules.
        :param building_blocks: A loaded building blocks. The set passed in is never
            modified; if the target itself is a building block, the tree works on a
            private copy without it.
        :param expansion_function: A loaded policy function.
        :param evaluation_function: A loaded value function. Must be given when the
            evaluation type is 'gcn' and must be None when it is 'rollout'.
        :raises ValueError: If the value function and the configured evaluation type
            are inconsistent.
        """

        # config parameters
        self.config = config

        assert isinstance(
            target, MoleculeContainer
        ), "Target should be given as MoleculeContainer"
        assert len(target) > 3, "Target molecule should have more than 3 atoms"

        target_molecule = Precursor(target)
        target_molecule.prev_precursors.append(Precursor(target))
        target_node = Node(
            precursors_to_expand=(target_molecule,), new_precursors=(target_molecule,)
        )

        # tree structure init
        self.nodes: Dict[int, Node] = {1: target_node}
        self.parents: Dict[int, int] = {1: 0}
        self.children: Dict[int, Set[int]] = {1: set()}
        self.winning_nodes: List[int] = []
        self.visited_nodes: Set[int] = set()
        self.expanded_nodes: Set[int] = set()
        self.nodes_visit: Dict[int, int] = {1: 0}
        self.nodes_depth: Dict[int, int] = {1: 0}
        self.nodes_prob: Dict[int, float] = {1: 0.0}
        self.nodes_rules: Dict[int, int] = {}
        self.nodes_init_value: Dict[int, float] = {1: 0.0}
        self.nodes_total_value: Dict[int, float] = {1: 0.0}

        # tree building limits
        self.curr_iteration: int = 0
        self.curr_tree_size: int = 2  # id that the next added node will get
        self.start_time: float = 0
        self.curr_time: float = 0

        # building blocks and reaction reaction_rules
        self.reaction_rules = reaction_rules
        self.building_blocks = building_blocks

        # policy and value functions
        self.policy_network = expansion_function
        if self.config.evaluation_type == "gcn" and evaluation_function is None:
            raise ValueError(
                "Value function not specified while evaluation type is 'gcn'"
            )
        if (
            evaluation_function is not None
            and self.config.evaluation_type == "rollout"
        ):
            raise ValueError(
                "Value function is not None while evaluation type is 'rollout'. What should be evaluation type ?"
            )
        self.value_network = evaluation_function

        # utils
        self._tqdm = True  # needed to disable tqdm with multiprocessing module

        target_smiles = str(self.nodes[1].curr_precursor.molecule)
        if target_smiles in self.building_blocks:
            # Copy before removing: the caller may share one building-block
            # set between many trees, and removing the target in place would
            # silently change the stock seen by every later search.
            self.building_blocks = set(self.building_blocks)
            self.building_blocks.remove(target_smiles)
            print(
                "Target was found in building blocks and removed from building blocks."
            )

    def __len__(self) -> int:
        """Returns the current size (the number of nodes) in the tree."""

        return self.curr_tree_size - 1

    def __iter__(self) -> "Tree":
        """Starts the search clock and returns the tree itself as the iterator.

        Also creates the progress bar unless it was disabled by setting
        ``_tqdm`` to a false value.
        """

        self.start_time = time()
        if self._tqdm:
            self._tqdm = tqdm(
                total=self.config.max_iterations, disable=self.config.silent
            )
        return self

    def __repr__(self) -> str:
        """Returns a string representation of the tree (target SMILES, tree size, and
        the number of found routes)."""
        return self.report()

    def __next__(self) -> Tuple[bool, List[int]]:
        """The __next__ method is used to do one iteration of the tree building.

        ``iter(tree)`` has to be called first: it starts the clock and sets up the
        progress bar that this method updates.

        :return: Returns True and the list of node ids of the found winning nodes if
            a route was found. Otherwise, returns False and the id of the last
            visited node (as a one-element list).
        :raises StopIteration: If the iteration, tree size or time limit is reached,
            or the target molecule could not be expanded.
        """

        if self.curr_iteration >= self.config.max_iterations:
            raise StopIteration("Iterations limit exceeded.")
        if self.curr_tree_size >= self.config.max_tree_size:
            raise StopIteration("Max tree size exceeded or all possible routes found.")
        if self.curr_time >= self.config.max_time:
            raise StopIteration("Time limit exceeded.")

        # start new iteration
        self.curr_iteration += 1
        self.curr_time = time() - self.start_time

        if self._tqdm:
            self._tqdm.update()

        curr_depth, node_id = 0, 1  # start from the root node_id

        explore_route = True
        while explore_route:
            self.visited_nodes.add(node_id)

            if self.nodes_visit[node_id]:  # already visited
                if not self.children[node_id]:  # dead node
                    self._update_visits(node_id)
                    explore_route = False
                else:
                    node_id = self._select_node(node_id)  # select the child node
                    curr_depth += 1
            else:
                if self.nodes[node_id].is_solved():  # found route
                    # this prevents expanding of bb node_id
                    self._update_visits(node_id)
                    self.winning_nodes.append(node_id)
                    return True, [node_id]

                # expand node if depth limit is not reached
                if curr_depth < self.config.max_depth:
                    self._expand_node(node_id)
                    if not self.children[node_id]:  # node was not expanded
                        value_to_backprop = -1.0
                    else:
                        self.expanded_nodes.add(node_id)

                        if self.config.search_strategy == "evaluation_first":
                            # recalculate node value based on children
                            # synthesisability and backpropagation
                            child_values = [
                                self.nodes_init_value[child_id]
                                for child_id in self.children[node_id]
                            ]

                            if self.config.evaluation_agg == "max":
                                value_to_backprop = max(child_values)

                            elif self.config.evaluation_agg == "average":
                                value_to_backprop = sum(child_values) / len(
                                    self.children[node_id]
                                )

                            else:
                                raise ValueError(
                                    f"Unknown evaluation_agg: {self.config.evaluation_agg!r}"
                                )

                        elif self.config.search_strategy == "expansion_first":
                            value_to_backprop = self._get_node_value(node_id)

                    # backpropagation
                    self._backpropagate(node_id, value_to_backprop)
                    self._update_visits(node_id)
                    explore_route = False

                    if self.children[node_id]:
                        # found after expansion
                        found_after_expansion = set()
                        for child_id in self.children[node_id]:
                            if self.nodes[child_id].is_solved():
                                found_after_expansion.add(child_id)
                                self.winning_nodes.append(child_id)

                        if found_after_expansion:
                            return True, list(found_after_expansion)

                else:
                    # depth limit reached: re-propagate the value the node
                    # already has instead of expanding it
                    self._backpropagate(node_id, self.nodes_total_value[node_id])
                    self._update_visits(node_id)
                    explore_route = False

        return False, [node_id]

    def _ucb(self, node_id: int) -> float:
        """Calculates the Upper Confidence Bound (UCB) statistics for a given node.

        :param node_id: The id of the node.
        :return: The calculated UCB.
        :raises ValueError: If the configured ``ucb_type`` is not one of 'puct',
            'uct' or 'value'.
        """

        prob = self.nodes_prob[node_id]  # predicted by policy network score
        visit = self.nodes_visit[node_id]
        ucb_type = self.config.ucb_type

        if ucb_type == "puct":
            u = (
                self.config.c_ucb * prob * sqrt(self.nodes_visit[self.parents[node_id]])
            ) / (visit + 1)
            return self.nodes_total_value[node_id] + u

        if ucb_type == "uct":
            u = (
                self.config.c_ucb
                * sqrt(self.nodes_visit[self.parents[node_id]])
                / (visit + 1)
            )
            return self.nodes_total_value[node_id] + u

        if ucb_type == "value":
            return self.nodes_init_value[node_id] / (visit + 1)

        # previously this fell through to an UnboundLocalError
        raise ValueError(f"Unknown ucb_type: {ucb_type!r}")

    def _select_node(self, node_id: int) -> int:
        """Selects a child of the given node based on its UCB value.

        With probability ``config.epsilon`` a random child is returned instead.

        :param node_id: The id of the parent node.
        :return: The id of the child node with the highest UCB (the first one
            encountered if several children share the best score).
        """

        if self.config.epsilon > 0:
            n = uniform(0, 1)
            if n < self.config.epsilon:
                return choice(list(self.children[node_id]))

        best_score, best_children = None, []
        for child_id in self.children[node_id]:
            score = self._ucb(child_id)
            if best_score is None or score > best_score:
                best_score, best_children = score, [child_id]
            elif score == best_score:
                best_children.append(child_id)

        # is needed for tree search reproducibility, when all child nodes has the same score
        return best_children[0]

    def _expand_node(self, node_id: int) -> None:
        """Expands the node by generating new precursors with policy (expansion)
        function.

        Every accepted set of products of an applied reaction rule becomes a child
        node. Product sets that bring no molecule not already produced during this
        expansion are skipped, as are those containing a precursor already present
        in the history of the current precursor.

        :param node_id: The id of the node to be expanded.
        :return: None.
        :raises StopIteration: If the root node (id 1) could not be expanded at all.
        """
        curr_node = self.nodes[node_id]
        prev_precursor = curr_node.curr_precursor.prev_precursors

        tmp_precursor = set()
        expanded = False
        for prob, rule, rule_id in self.policy_network.predict_reaction_rules(
            curr_node.curr_precursor, self.reaction_rules
        ):
            for products in apply_reaction_rule(
                curr_node.curr_precursor.molecule, rule
            ):
                # check repeated products
                if not products or not set(products) - tmp_precursor:
                    continue
                tmp_precursor.update(products)

                for molecule in products:
                    molecule.meta["reactor_id"] = rule_id

                new_precursor = tuple(Precursor(mol) for mol in products)
                # the rule probability is scaled by the number of products
                # larger than the minimal molecule size
                scaled_prob = prob * len(
                    list(filter(lambda x: len(x) > self.config.min_mol_size, products))
                )

                if set(prev_precursor).isdisjoint(new_precursor):
                    precursors_to_expand = (
                        *curr_node.next_precursor,
                        *(
                            x
                            for x in new_precursor
                            if not x.is_building_block(
                                self.building_blocks, self.config.min_mol_size
                            )
                        ),
                    )

                    child_node = Node(
                        precursors_to_expand=precursors_to_expand,
                        new_precursors=new_precursor,
                    )

                    # a distinct loop name: the original loop reused the
                    # tuple's own name and rebound it to its last element
                    for precursor in new_precursor:
                        precursor.prev_precursors = [precursor, *prev_precursor]

                    self._add_node(node_id, child_node, scaled_prob, rule_id)

                    expanded = True
        if not expanded and node_id == 1:
            raise StopIteration("\nThe target molecule was not expanded.")

    def _add_node(
        self,
        node_id: int,
        new_node: Node,
        policy_prob: float = None,
        rule_id: int = None,
    ) -> None:
        """Adds a new node to the tree with probability of reaction rules predicted by
        policy function and applied to the parent node of the new node.

        :param node_id: The id of the parent node.
        :param new_node: The new node to be added.
        :param policy_prob: The probability of the reaction rule predicted by policy
            function for the parent node.
        :param rule_id: The id of the reaction rule that produced the new node.
        :return: None.
        :raises ValueError: If the configured ``search_strategy`` is unknown.
        """

        new_node_id = self.curr_tree_size

        self.nodes[new_node_id] = new_node
        self.parents[new_node_id] = node_id
        self.children[node_id].add(new_node_id)
        self.children[new_node_id] = set()
        self.nodes_visit[new_node_id] = 0
        self.nodes_prob[new_node_id] = policy_prob
        self.nodes_rules[new_node_id] = rule_id
        self.nodes_depth[new_node_id] = self.nodes_depth[node_id] + 1
        self.curr_tree_size += 1

        if self.config.search_strategy == "evaluation_first":
            node_value = self._get_node_value(new_node_id)
        elif self.config.search_strategy == "expansion_first":
            node_value = self.config.init_node_value
        else:
            raise ValueError(
                f"Unknown search_strategy: {self.config.search_strategy!r}"
            )

        self.nodes_init_value[new_node_id] = node_value
        self.nodes_total_value[new_node_id] = node_value

    def _get_node_value(self, node_id: int) -> float:
        """Calculates the value for the given node (for example with rollout or value
        network).

        :param node_id: The id of the node to be evaluated.
        :return: The estimated value of the node.
        :raises ValueError: If the configured ``evaluation_type`` is unknown.
        """

        node = self.nodes[node_id]

        if self.config.evaluation_type == "random":
            return uniform(0, 1)

        if self.config.evaluation_type == "rollout":
            # the node is as good as its worst precursor; a node with nothing
            # left to expand gets the maximal value
            return min(
                (
                    self._rollout_node(
                        precursor, current_depth=self.nodes_depth[node_id]
                    )
                    for precursor in node.precursors_to_expand
                ),
                default=1.0,
            )

        if self.config.evaluation_type == "gcn":
            return self.value_network.predict_value(node.new_precursors)

        raise ValueError(f"Unknown evaluation_type: {self.config.evaluation_type!r}")

    def _update_visits(self, node_id: int) -> None:
        """Increments the number of visits of every node from the current node up to
        the root node.

        :param node_id: The id of the current node.
        :return: None.
        """

        while node_id:  # parent id 0 of the root ends the walk
            self.nodes_visit[node_id] += 1
            node_id = self.parents[node_id]

    def _backpropagate(self, node_id: int, value: float) -> None:
        """Backpropagates the value through the tree from the current node up to
        the root node.

        :param node_id: The id of the node from which to backpropagate the value.
        :param value: The value to backpropagate.
        :return: None.
        """
        while node_id:
            if self.config.backprop_type == "muzero":
                # running average over the visits
                self.nodes_total_value[node_id] = (
                    self.nodes_total_value[node_id] * self.nodes_visit[node_id] + value
                ) / (self.nodes_visit[node_id] + 1)
            elif self.config.backprop_type == "cumulative":
                self.nodes_total_value[node_id] += value
            node_id = self.parents[node_id]

    def _rollout_node(self, precursor: Precursor, current_depth: int = None) -> float:
        """Performs a rollout simulation for the given precursor.

        Starting from the precursor, the first successful reaction rule is applied
        repeatedly to every product that is not a building block.

        Returned rewards:

        - 1.0 if the precursor is a building block or everything was decomposed
          to building blocks;
        - -0.5 if the number of rollout steps reaches the remaining depth
          (``config.max_depth - current_depth``);
        - -1.0 if no reaction rule could be applied to some precursor or a
          loop (an already seen precursor among the products) was detected.

        :param precursor: The precursor to be evaluated.
        :param current_depth: The current depth of the tree.
        :return: The reward (value) assigned to the precursor.
        """

        max_depth = self.config.max_depth - current_depth

        # precursor checking
        if precursor.is_building_block(self.building_blocks, self.config.min_mol_size):
            return 1.0

        if max_depth == 0:
            print("max depth reached in the beginning")

        # precursor simulating
        occurred_precursor = set()
        precursor_to_expand = deque([precursor])
        history = defaultdict(dict)
        rollout_depth = 0
        while precursor_to_expand:
            # Iterate through reactors and pick first successful reaction.
            # Check products of the reaction if you can find them in in-building_blocks data
            # If not, then add missed products to precursor_to_expand and try to decompose them
            if len(history) >= max_depth:
                return -0.5

            current_precursor = precursor_to_expand.popleft()
            history[rollout_depth]["target"] = current_precursor
            occurred_precursor.add(current_precursor)

            # Pick the first successful reaction while iterating through reactors
            reaction_rule_applied = False
            for prob, rule, rule_id in self.policy_network.predict_reaction_rules(
                current_precursor, self.reaction_rules
            ):
                for products in apply_reaction_rule(current_precursor.molecule, rule):
                    if products:
                        reaction_rule_applied = True
                        break

                if reaction_rule_applied:
                    history[rollout_depth]["rule_index"] = rule_id
                    break

            if not reaction_rule_applied:
                return -1.0

            products = tuple(Precursor(product) for product in products)
            history[rollout_depth]["products"] = products

            # check loops: a product that was already decomposed before
            if any(x in occurred_precursor for x in products) and products:
                return -1.0

            if occurred_precursor.isdisjoint(products):
                # added number of atoms check
                precursor_to_expand.extend(
                    [
                        x
                        for x in products
                        if not x.is_building_block(
                            self.building_blocks, self.config.min_mol_size
                        )
                    ]
                )
                rollout_depth += 1

        return 1.0

    def report(self) -> str:
        """Returns the string representation of the tree."""

        return (
            f"Tree for: {str(self.nodes[1].precursors_to_expand[0])}\n"
            f"Time: {round(self.curr_time, 1)} seconds\n"
            f"Number of nodes: {len(self)}\n"
            f"Number of iterations: {self.curr_iteration}\n"
            f"Number of visited nodes: {len(self.visited_nodes)}\n"
            f"Number of found routes: {len(self.winning_nodes)}"
        )

    def route_score(self, node_id: int) -> float:
        """Calculates the score of a given route from the current node to the root node.
        The score depends on cumulated node values and the route length.

        :param node_id: The id of the current given node.
        :return: The route score (sum of the total node values divided by the squared
            route length).
        """

        cumulated_nodes_value, route_length = 0, 0
        while node_id:
            route_length += 1

            cumulated_nodes_value += self.nodes_total_value[node_id]
            node_id = self.parents[node_id]

        return cumulated_nodes_value / (route_length**2)

    def route_to_node(self, node_id: int) -> List[Node]:
        """Returns the route (list of nodes) from the root node to the current node.

        :param node_id: The id of the current node.
        :return: The list of nodes, ordered from the root to the current node.
        """

        nodes = []
        while node_id:
            nodes.append(node_id)
            node_id = self.parents[node_id]
        return [self.nodes[node_id] for node_id in reversed(nodes)]

    def synthesis_route(self, node_id: int) -> Tuple[Reaction, ...]:
        """Given a node_id, return a tuple of reactions that represent the
        retrosynthetic route from the current node.

        :param node_id: The id of the current node.
        :return: The tuple of extracted reactions representing the synthesis route,
            ordered from the reaction leading to the given node to the one that
            produces the target.
        """

        nodes = self.route_to_node(node_id)

        # each pair of consecutive nodes gives one reaction: the new precursors
        # of the child go first, the expanded precursor of the parent second
        reaction_sequence = [
            Reaction(
                [x.molecule for x in after.new_precursors],
                [before.curr_precursor.molecule],
            )
            for before, after in zip(nodes, nodes[1:])
        ]

        for r in reaction_sequence:
            r.clean2d()
        return tuple(reversed(reaction_sequence))

    def newickify(self, visits_threshold: int = 0, root_node_id: int = 1):
        """Renders the tree in the Newick format.

        Adopted from https://stackoverflow.com/questions/50003007/how-to-convert-python-dictionary-to-newick-form-format.

        :param visits_threshold: The minimum number of visits for the given node.
        :param root_node_id: The id of the root node.

        :return: The newick string and meta dict, which maps every rendered node id
            to a tuple of (total node value rounded to 3 digits, initial node value
            rounded to an integer, number of visits).
        """
        visited_nodes = set()

        def newick_render_node(current_node_id: int) -> str:
            """Recursively generates a Newick string representation of the tree.

            :param current_node_id: The id of the current node.
            :return: A string representation of a node in a Newick format.
            """
            assert (
                current_node_id not in visited_nodes
            ), "Error: The tree may not be circular!"
            node_visit = self.nodes_visit[current_node_id]

            visited_nodes.add(current_node_id)
            if self.children[current_node_id]:
                # Nodes
                children = [
                    child
                    for child in self.children[current_node_id]
                    if self.nodes_visit[child] >= visits_threshold
                ]
                children_strings = ",".join(
                    newick_render_node(child) for child in children
                )
                if children_strings:
                    return f"({children_strings}){current_node_id}:{node_visit}"
                # leafs within threshold
                return f"{current_node_id}:{node_visit}"

            return f"{current_node_id}:{node_visit}"

        newick_string = newick_render_node(root_node_id) + ";"

        meta = {}
        for node_id in visited_nodes:
            node_value = round(self.nodes_total_value[node_id], 3)

            node_synthesisability = round(self.nodes_init_value[node_id])

            visit_in_node = self.nodes_visit[node_id]
            meta[node_id] = (node_value, node_synthesisability, visit_in_node)

        return newick_string, meta
synplan/ml/__init__.py ADDED
File without changes
synplan/ml/networks/__init__.py ADDED
File without changes
synplan/ml/networks/modules.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing basic pytorch architectures of policy and value neural networks."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, List, Tuple, Union
5
+
6
+ import torch
7
+ from adabelief_pytorch import AdaBelief
8
+ from pytorch_lightning import LightningModule
9
+ from torch import Tensor
10
+ from torch.nn import GELU, Dropout, Linear, Module, ModuleDict, ModuleList
11
+ from torch.nn.functional import relu
12
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
13
+ from torch_geometric.data.batch import Batch
14
+ from torch_geometric.nn.conv import GCNConv
15
+ from torch_geometric.nn.pool import global_add_pool
16
+
17
+
18
class GraphEmbedding(Module):
    """Graph-convolution encoder that turns the atom vectors of each molecule in a
    batch into a single vector per molecule."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5
    ):
        """Builds the input projection, the dropout layer and the stack of graph
        convolutions.

        :param vector_dim: The width of the hidden layers and of the returned
            embedding.
        :param dropout: The dropout probability applied after every convolution.
        :param num_conv_layers: The number of graph convolution layers.
        """
        super().__init__()
        # Atom rows have 11 features; they are projected to the hidden width first.
        self.expansion = Linear(11, vector_dim)
        self.dropout = Dropout(dropout)

        conv_layers = []
        for _ in range(num_conv_layers):
            conv_layers.append(GCNConv(vector_dim, vector_dim, improved=True))
        self.gcn_convs = ModuleList(conv_layers)

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Embeds every molecular graph of the batch.

        :param graph: The batch of molecular graphs, where each atom is represented by
            the atom/bond vector.
        :param batch_size: The number of graphs to pool into (``size`` of the pooling
            call).
        :return: Graph embedding, one row per graph.
        """
        node_features = graph.x.float()
        edges = graph.edge_index.long()

        # log(1 + x) compresses the raw integer features before the projection.
        hidden = self.expansion(torch.log(node_features + 1))
        for conv in self.gcn_convs:
            # Residual update: each layer adds its activated, dropped-out output.
            update = self.dropout(relu(conv(hidden, edges)))
            hidden = hidden + update

        return global_add_pool(hidden, graph.batch, size=batch_size)
66
+
67
+
68
class GraphEmbeddingConcat(GraphEmbedding, Module):
    """Graph-convolution encoder that concatenates the outputs of all convolution
    layers instead of summing them residually.

    Each layer works at width ``vector_dim // num_conv_layers``; the per-layer atom
    vectors are concatenated along the feature axis and sum-pooled per molecule.
    """

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8
    ):
        """Builds the narrow input projection and the stack of convolution/activation
        pairs.

        :param vector_dim: The target width of the concatenated embedding. It should
            be divisible by ``num_conv_layers``; otherwise the produced embedding is
            narrower than ``vector_dim`` and a warning is emitted.
        :param dropout: The dropout probability applied after every activation.
        :param num_conv_layers: The number of graph convolution layers.
        """
        import warnings

        # The parent initialiser runs with its own defaults; every layer it creates
        # is replaced below.
        super().__init__()

        gcn_dim = vector_dim // num_conv_layers
        if gcn_dim * num_conv_layers != vector_dim:
            # Integer division truncates, so the concatenation would silently come
            # out narrower than requested and break heads sized for vector_dim.
            warnings.warn(
                f"vector_dim={vector_dim} is not divisible by "
                f"num_conv_layers={num_conv_layers}: the concatenated embedding will "
                f"have {gcn_dim * num_conv_layers} features instead of {vector_dim}.",
                stacklevel=2,
            )

        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                ModuleDict(
                    {
                        "gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
                        "activation": GELU(),
                    }
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Embeds every molecular graph of the batch.

        :param graph: The batch of molecular graphs, where each atom is represented by
            the atom/bond vector.
        :param batch_size: The number of graphs to pool into (``size`` of the pooling
            call).
        :return: Graph embedding, one row per graph, of width
            ``(vector_dim // num_conv_layers) * num_conv_layers``.
        """

        atoms, connections = graph.x.float(), graph.edge_index.long()
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)

        # Each layer consumes the previous layer's output; all outputs are kept.
        collected_atoms = []
        for gcn_convs in self.gcn_convs:
            atoms = gcn_convs["gcn"](atoms, connections)
            atoms = gcn_convs["activation"](atoms)
            atoms = self.dropout(atoms)
            collected_atoms.append(atoms)

        atoms = torch.cat(collected_atoms, dim=-1)

        return global_add_pool(atoms, graph.batch, size=batch_size)
115
+
116
+
117
class MCTSNetwork(LightningModule, ABC):
    """Basic class for policy and value networks."""

    def __init__(
        self,
        vector_dim: int,
        batch_size: int,
        dropout: float = 0.4,
        num_conv_layers: int = 5,
        learning_rate: float = 0.001,
        gcn_concat: bool = False,
    ):
        """The basic class for MCTS graph convolutional neural networks (policy and
        value network).

        :param vector_dim: The dimensionality of the hidden layers and output layer of
            graph convolution module.
        :param batch_size: The batch size; passed to the embedder as the pooling size
            and to the logger for metric aggregation.
        :param dropout: Dropout is a regularization technique used in neural networks to
            prevent overfitting.
        :param num_conv_layers: The number of convolutional layers in a graph
            convolutional module.
        :param learning_rate: The learning rate determines how quickly the model learns
            from the training data.
        :param gcn_concat: If True, the embedder is ``GraphEmbeddingConcat`` (outputs
            of all convolution layers are concatenated); otherwise ``GraphEmbedding``.
        """
        super().__init__()
        if gcn_concat:
            self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
        else:
            self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
        self.batch_size = batch_size
        self.lr = learning_rate

    @abstractmethod
    def forward(self, batch: Batch) -> Tensor:
        """The forward function takes a batch of input data and performs forward
        propagation through the neural network.

        :param batch: The batch of molecular graphs processed together in a single
            forward pass through the neural network.
        """

    @abstractmethod
    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculate the loss and metrics for a given batch of data.

        :param batch: The batch of input data that is used to compute the loss.
        :return: A dictionary of named metrics; it must contain the key ``"loss"``,
            which ``training_step`` returns for optimisation.
        """

    def _log_stage_metrics(
        self, stage: str, metrics: Dict[str, Tensor], **log_kwargs
    ) -> None:
        """Log every metric under the name ``<stage>_<metric name>``.

        :param stage: The prefix of the logged names (``train``, ``val`` or ``test``).
        :param metrics: The metrics returned by ``_get_loss``.
        :param log_kwargs: Extra keyword arguments forwarded to ``self.log``.
        """
        for name, value in metrics.items():
            self.log(
                stage + "_" + name, value, batch_size=self.batch_size, **log_kwargs
            )

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
        """Calculates the loss for a given training batch and logs the metrics.

        :param batch: The batch of data that is used for training.
        :param batch_idx: The index of the batch.
        :return: The value of the training loss.
        """
        metrics = self._get_loss(batch)
        self._log_stage_metrics(
            "train", metrics, prog_bar=True, on_step=True, on_epoch=True
        )
        return metrics["loss"]

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given validation batch and logs the metrics.

        :param batch: The batch of data that is used for validation.
        :param batch_idx: The index of the batch.
        """
        self._log_stage_metrics("val", self._get_loss(batch), on_epoch=True)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given test batch and logs the metrics.

        :param batch: The batch of data that is used for testing.
        :param batch_idx: The index of the batch.
        """
        self._log_stage_metrics("test", self._get_loss(batch), on_epoch=True)

    def configure_optimizers(
        self,
    ) -> Tuple[List[AdaBelief], List[Dict[str, Union[bool, str, ReduceLROnPlateau]]]]:
        """Returns an optimizer and a learning rate scheduler for training a model using
        the AdaBelief optimizer and ReduceLROnPlateau scheduler.

        :return: The optimizer and a scheduler.
        """

        optimizer = AdaBelief(
            self.parameters(),
            lr=self.lr,
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=True,
            weight_decay=0.01,
            print_change_log=False,
        )

        lr_scheduler = ReduceLROnPlateau(
            optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True
        )
        # The scheduler follows "val_loss", i.e. the "loss" metric logged by
        # validation_step.
        scheduler = {
            "scheduler": lr_scheduler,
            "reduce_on_plateau": True,
            "monitor": "val_loss",
        }

        return [optimizer], [scheduler]
synplan/ml/networks/policy.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing main class for policy network."""
2
+
3
+ from abc import ABC
4
+ from typing import Dict
5
+
6
+ import torch
7
+ from pytorch_lightning import LightningModule
8
+ from torch import Tensor
9
+ from torch.nn import Linear
10
+ from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy, one_hot
11
+ from torch_geometric.data.batch import Batch
12
+ from torchmetrics.functional.classification import f1_score, recall, specificity
13
+
14
+ from synplan.ml.networks.modules import MCTSNetwork
15
+
16
+
17
class PolicyNetwork(MCTSNetwork, LightningModule, ABC):
    """Policy network."""

    # The policy types this network knows how to build heads and losses for.
    _POLICY_TYPES = ("ranking", "filtering")

    def __init__(
        self,
        *args,
        n_rules: int,
        vector_dim: int,
        policy_type: str = "ranking",
        **kwargs
    ):
        """Initializes a policy network with the given number of reaction rules (output
        dimension) and vector graph embedding dimension, and creates linear layers for
        predicting the regular and priority reaction rules.

        :param n_rules: The number of reaction rules in the policy network.
        :param vector_dim: The dimensionality of the input vectors.
        :param policy_type: Either ``"ranking"`` (single head over the rules) or
            ``"filtering"`` (an additional priority head is created).
        :raises ValueError: If ``policy_type`` is not one of the supported values.
        """
        if policy_type not in self._POLICY_TYPES:
            # Fail early: an unknown type would otherwise make forward() return None
            # and _get_loss() fail on an unbound variable.
            raise ValueError(
                f"Unknown policy_type {policy_type!r}; "
                f"expected one of {self._POLICY_TYPES}."
            )

        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        self.policy_type = policy_type
        self.n_rules = n_rules
        self.y_predictor = Linear(vector_dim, n_rules)

        if self.policy_type == "filtering":
            self.priority_predictor = Linear(vector_dim, n_rules)

    @staticmethod
    def _balanced_accuracy(pred: Tensor, true: Tensor, **task_kwargs) -> Tensor:
        """Return the mean of recall and specificity for the given task settings."""
        return (
            recall(pred, true, **task_kwargs) + specificity(pred, true, **task_kwargs)
        ) / 2

    def forward(self, batch: Batch) -> Tensor:
        """Takes a molecular graph, applies a graph convolution and predicts the
        reaction rules.

        :param batch: The input batch of molecular graphs.
        :return: For the ranking policy, the softmax distribution over reaction rules.
            For the filtering policy, a tuple of two tensors with the sigmoid outputs
            of the regular and of the priority reaction rule heads.
        :raises ValueError: If ``self.policy_type`` is not a supported value.
        """
        x = self.embedder(batch, self.batch_size)
        y = self.y_predictor(x)

        if self.policy_type == "ranking":
            return torch.softmax(y, dim=-1)

        if self.policy_type == "filtering":
            priority = torch.sigmoid(self.priority_predictor(x))
            return torch.sigmoid(y), priority

        raise ValueError(f"Unknown policy_type {self.policy_type!r}.")

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and various classification metrics for a given batch for
        reaction rules prediction.

        :param batch: The batch of molecular graphs.
        :return: A dictionary with loss value, balanced accuracy and F1 score of
            reaction rules prediction (plus the same metrics for the priority head of
            the filtering policy).
        :raises ValueError: If ``self.policy_type`` is not a supported value.
        """
        true_y = batch.y_rules.long()
        x = self.embedder(batch, self.batch_size)
        # Raw logits: the loss functions below apply softmax/sigmoid themselves.
        pred_y = self.y_predictor(x)

        if self.policy_type == "ranking":
            task = {"task": "multiclass", "num_classes": self.n_rules}

            true_one_hot = one_hot(true_y, num_classes=self.n_rules)
            loss = cross_entropy(pred_y, true_one_hot.float())
            ba_y = self._balanced_accuracy(pred_y, true_y, **task)
            f1_y = f1_score(pred_y, true_y, **task)

            return {"loss": loss, "balanced_accuracy_y": ba_y, "f1_score_y": f1_y}

        if self.policy_type == "filtering":
            task = {"task": "multilabel", "num_labels": self.n_rules}

            loss_y = binary_cross_entropy_with_logits(pred_y, true_y.float())
            ba_y = self._balanced_accuracy(pred_y, true_y, **task)
            f1_y = f1_score(pred_y, true_y, **task)

            true_priority = batch.y_priority.float()
            pred_priority = self.priority_predictor(x)
            loss_priority = binary_cross_entropy_with_logits(
                pred_priority, true_priority
            )

            # Both heads are trained jointly with equal weight.
            loss = loss_y + loss_priority

            true_priority = true_priority.long()
            ba_priority = self._balanced_accuracy(pred_priority, true_priority, **task)
            f1_priority = f1_score(pred_priority, true_priority, **task)

            return {
                "loss": loss,
                "balanced_accuracy_y": ba_y,
                "f1_score_y": f1_y,
                "balanced_accuracy_priority": ba_priority,
                "f1_score_priority": f1_priority,
            }

        raise ValueError(f"Unknown policy_type {self.policy_type!r}.")
synplan/ml/networks/value.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing main class for value network."""
2
+
3
+ from abc import ABC
4
+ from typing import Any, Dict
5
+
6
+ import torch
7
+ from pytorch_lightning import LightningModule
8
+ from torch import Tensor
9
+ from torch.nn import Linear
10
+ from torch.nn.functional import binary_cross_entropy_with_logits
11
+ from torch_geometric.data.batch import Batch
12
+ from torchmetrics.functional.classification import (
13
+ binary_f1_score,
14
+ binary_recall,
15
+ binary_specificity,
16
+ )
17
+
18
+ from synplan.ml.networks.modules import MCTSNetwork
19
+
20
+
21
class ValueNetwork(MCTSNetwork, LightningModule, ABC):
    """Value network."""

    def __init__(self, vector_dim: int, *args: Any, **kwargs: Any) -> None:
        """Sets up the shared graph embedder and a single-output linear head that
        scores the synthesisability of a precursor.

        :param vector_dim: The width of the graph embedding fed to the linear head.
        """
        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        self.predictor = Linear(vector_dim, 1)

    def forward(self, batch) -> torch.Tensor:
        """Embeds the batch of molecular graphs and maps every embedding to a
        synthesisability score through the linear head and a sigmoid.

        :param batch: The batch of molecular graphs.
        :return: The predicted synthesisability (between 0 and 1).
        """
        embedding = self.embedder(batch, self.batch_size)
        return torch.sigmoid(self.predictor(embedding))

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Computes the binary cross-entropy loss and classification metrics of the
        precursor synthesisability prediction for a batch.

        :param batch: The batch of molecular graphs.
        :return: The dictionary with the loss value, balanced accuracy and F1 score.
        """
        # Labels get a trailing axis so they line up with the (N, 1) head output.
        target = batch.y.float().unsqueeze(-1)
        logits = self.predictor(self.embedder(batch, self.batch_size))

        loss = binary_cross_entropy_with_logits(logits, target)

        labels = target.long()
        sensitivity = binary_recall(logits, labels)
        specificity_score = binary_specificity(logits, labels)
        f1 = binary_f1_score(logits, labels)

        return {
            "loss": loss,
            "balanced_accuracy": (sensitivity + specificity_score) / 2,
            "f1_score": f1,
        }
synplan/ml/training/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .supervised import *
2
+ from .preprocessing import ValueNetworkDataset, mol_to_pyg, MENDEL_INFO
3
+ from .supervised import create_policy_dataset, run_policy_training
4
+
5
+ __all__ = [
6
+ "ValueNetworkDataset",
7
+ "mol_to_pyg",
8
+ "MENDEL_INFO",
9
+ "create_policy_dataset",
10
+ "run_policy_training",
11
+ ]
synplan/ml/training/preprocessing.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for preparation of the training sets for policy and value
2
+ network."""
3
+
4
+ import logging
5
+ import os
6
+ import pickle
7
+ from abc import ABC
8
+ from typing import Any, Dict, List, Optional, Tuple
9
+
10
+ import ray
11
+ import torch
12
+ from CGRtools import smiles
13
+ from CGRtools.containers import MoleculeContainer
14
+ from CGRtools.exceptions import InvalidAromaticRing
15
+ from CGRtools.reactor import Reactor
16
+ from ray.util.queue import Empty, Queue
17
+ from torch import Tensor
18
+ from torch_geometric.data import InMemoryDataset
19
+ from torch_geometric.data.data import Data
20
+ from torch_geometric.data.makedirs import makedirs
21
+ from torch_geometric.transforms import ToUndirected
22
+ from tqdm import tqdm
23
+
24
+ from synplan.chem.utils import unite_molecules
25
+ from synplan.utils.files import ReactionReader
26
+ from synplan.utils.loading import load_reaction_rules
27
+
28
+
29
class ValueNetworkDataset(InMemoryDataset, ABC):
    """Value network dataset."""

    def __init__(self, extracted_precursor: Dict[str, float]) -> None:
        """Builds the dataset from labelled precursors.

        :param extracted_precursor: The dictionary mapping precursor SMILES extracted
            from the built search trees to their labels. When it is empty, no data is
            attached to the dataset.
        """
        super().__init__(None, None, None)

        if not extracted_precursor:
            return
        self.data, self.slices = self.graphs_from_extracted_precursor(
            extracted_precursor
        )

    @staticmethod
    def mol_to_graph(molecule: MoleculeContainer, label: float) -> Optional[Data]:
        """Converts a molecule to a PyTorch Geometric graph carrying its label.

        :param molecule: The input molecule.
        :param label: The label (solved/unsolved routes in the tree) of the molecule
            (precursor).
        :return: The labelled graph, or None when the molecule has at most two atoms
            or cannot be converted.
        """
        if len(molecule) <= 2:
            return None

        graph = mol_to_pyg(molecule)
        if not graph:
            return None

        graph.y = torch.tensor([label])
        return graph

    def graphs_from_extracted_precursor(
        self, extracted_precursor: Dict[str, float]
    ) -> Tuple[Data, Dict]:
        """Parses every precursor SMILES, converts it to a labelled graph and collates
        the successfully converted graphs.

        :param extracted_precursor: The dictionary mapping precursor SMILES extracted
            from the built search trees to their labels.
        :return: The PyTorch geometric graphs and slices.
        """
        graphs = []
        for smi, label in extracted_precursor.items():
            graph = self.mol_to_graph(smiles(smi), label)
            if graph:
                graphs.append(graph)

        data, slices = self.collate(graphs)
        return data, slices
81
+
82
+
83
class RankingPolicyDataset(InMemoryDataset):
    """Ranking policy network dataset."""

    def __init__(self, reactions_path: str, reaction_rules_path: str, output_path: str):
        """Initializes a policy network dataset.

        If ``output_path`` points to an existing file, the dataset is loaded from it;
        otherwise it is prepared from the reactions and the reaction rules.

        :param reactions_path: The path to the file containing the reaction data used
            for extraction of reaction rules.
        :param reaction_rules_path: The path to the file containing the reaction rules.
        :param output_path: The output path to the file where policy network dataset
            will be saved.
        """
        super().__init__(None, None, None)

        self.reactions_path = reactions_path
        self.reaction_rules_path = reaction_rules_path
        self.output_path = output_path

        if output_path and os.path.exists(output_path):
            self.data, self.slices = torch.load(self.output_path)
        else:
            self.data, self.slices = self.prepare_data()

    @property
    def num_classes(self) -> int:
        """The number of classes inferred from the rule labels of the dataset."""
        return self._infer_num_classes(self._data.y_rules)

    def prepare_data(self) -> Tuple[Data, Dict[str, Tensor]]:
        """Prepares data by loading reaction rules, preprocessing the molecules,
        collating the data, and returning the data and slices.

        :return: The PyTorch geometric graphs and slices.
        """

        # NOTE(security): pickle can execute arbitrary code while loading; only use
        # reaction rule files from trusted sources.
        with open(self.reaction_rules_path, "rb") as inp:
            reaction_rules = pickle.load(inp)
        # Rule indices follow popularity: index 0 is the rule with most reactions.
        reaction_rules = sorted(reaction_rules, key=lambda x: len(x[1]), reverse=True)

        # Map every reaction id to the index of its rule.
        reaction_rule_pairs = {}
        for rule_i, (_, reactions_ids) in enumerate(reaction_rules):
            for reaction_id in reactions_ids:
                reaction_rule_pairs[reaction_id] = rule_i

        list_of_graphs = []
        with ReactionReader(self.reactions_path) as reactions:

            for reaction_id, reaction in tqdm(
                enumerate(reactions),
                desc="Number of reactions processed: ",
                bar_format="{desc}{n} [{elapsed}]",
            ):

                rule_id = reaction_rule_pairs.get(reaction_id)
                # Compare with None explicitly: rule index 0 is a valid label and
                # must not be skipped as "missing".
                if rule_id is None:
                    continue

                try:  # MENDEL_INFO does not contain cadmium (Cd) properties
                    molecule = unite_molecules(reaction.products)
                    pyg_graph = mol_to_pyg(molecule)
                except (
                    Exception
                ) as e:  # TypeError: can't assign a NoneType to a torch.ByteTensor
                    # Best effort: a reaction that cannot be encoded is skipped.
                    logging.debug(e)
                    continue

                if pyg_graph is not None:
                    pyg_graph.y_rules = torch.tensor([rule_id], dtype=torch.long)
                    list_of_graphs.append(pyg_graph)

        data, slices = self.collate(list_of_graphs)
        if self.output_path:
            makedirs(os.path.dirname(self.output_path))
            torch.save((data, slices), self.output_path)

        return data, slices
160
+
161
+
162
class FilteringPolicyDataset(InMemoryDataset):
    """Filtering policy network dataset."""

    def __init__(
        self,
        molecules_path: str,
        reaction_rules_path: str,
        output_path: str,
        num_cpus: int,
    ) -> None:
        """Initializes a policy network dataset object.

        If ``output_path`` points to an existing file, the dataset is loaded from it;
        otherwise it is prepared by applying the reaction rules to the molecules.

        :param molecules_path: The path to the file containing the molecules for
            reaction rule appliance.
        :param reaction_rules_path: The path to the file containing the reaction rules.
        :param output_path: The output path to the file where policy network dataset
            will be stored.
        :param num_cpus: The number of CPUs to be used for the dataset preparation.
        :return: None.
        """
        super().__init__(None, None, None)

        self.molecules_path = molecules_path
        self.reaction_rules_path = reaction_rules_path
        self.output_path = output_path
        self.num_cpus = num_cpus
        # Together with num_cpus this bounds the size of the work queue.
        self.batch_size = 100

        if output_path and os.path.exists(output_path):
            self.data, self.slices = torch.load(self.output_path)
        else:
            self.data, self.slices = self.prepare_data()

    @property
    def num_classes(self) -> int:
        """The number of reaction rules (width of the rule label matrix)."""
        return self._data.y_rules.shape[1]

    def prepare_data(self) -> Tuple[Data, Dict]:
        """Prepares data by loading reaction rules, initializing Ray, preprocessing the
        molecules, collating the data, and returning the data and slices.

        :return: The PyTorch geometric graphs and slices.
        """

        ray.init(num_cpus=self.num_cpus, ignore_reinit_error=True)
        try:
            reaction_rules = load_reaction_rules(self.reaction_rules_path)
            reaction_rules_ids = ray.put(reaction_rules)

            # Bounded queue: the producer below blocks while the workers are busy.
            to_process = Queue(maxsize=self.batch_size * self.num_cpus)
            results_ids = [
                preprocess_filtering_policy_molecules.remote(
                    to_process, reaction_rules_ids
                )
                for _ in range(self.num_cpus)
            ]

            with open(self.molecules_path, "r", encoding="utf-8") as inp_data:
                for molecule in tqdm(
                    inp_data.read().splitlines(),
                    desc="Number of molecules processed: ",
                    bar_format="{desc}{n} [{elapsed}]",
                ):

                    to_process.put(molecule)

            # Flatten the per-worker lists, skipping workers without results.
            processed_data = [
                graph for res in ray.get(results_ids) if res for graph in res
            ]
        finally:
            # Shut Ray down even when loading, feeding or collecting fails.
            ray.shutdown()

        # Workers return sparse label vectors; densify them for collation.
        for pyg in processed_data:
            pyg.y_rules = pyg.y_rules.to_dense()
            pyg.y_priority = pyg.y_priority.to_dense()

        data, slices = self.collate(processed_data)
        if self.output_path:
            makedirs(os.path.dirname(self.output_path))
            torch.save((data, slices), self.output_path)

        return data, slices
241
+
242
+
243
def reaction_rules_appliance(
    molecule: MoleculeContainer, reaction_rules: List[Reactor]
) -> Tuple[List[int], List[int]]:
    """Applies each reaction rule from the list of reaction rules to a given molecule
    and returns the indexes of the successfully applied regular and prioritized reaction
    rules.

    A rule that raises any exception while being applied is logged at debug level and
    contributes to neither list.

    :param molecule: The input molecule.
    :param reaction_rules: The list of reaction rules.
    :return: The two lists of indexes of successfully applied regular reaction rules and
        priority reaction rules.
    """

    applied_rules, priority_rules = [], []
    for i, rule in enumerate(reaction_rules):

        # Flags are reset for every rule and raised by any reaction it produces.
        rule_applied = False
        rule_prioritized = False

        try:
            for reaction in rule([molecule]):
                for prod in reaction.products:
                    prod.kekule()
                    # A truthy valence check stops the inspection of the remaining
                    # products (presumably it reports atoms with invalid valences -
                    # TODO confirm against CGRtools).
                    if prod.check_valence():
                        break
                    rule_applied = True

                # check priority rules
                if len(reaction.products) > 1:
                    # check coupling retro manual: every product has more than 6
                    # atoms and the products together have fewer than 6 atoms more
                    # than the first reactant
                    if all(len(mol) > 6 for mol in reaction.products):
                        if (
                            sum(len(mol) for mol in reaction.products)
                            - len(reaction.reactants[0])
                            < 6
                        ):
                            rule_prioritized = True
                else:
                    # check cyclization retro manual: the products have fewer rings
                    # (sssr entries) than the reactants
                    if sum(len(mol.sssr) for mol in reaction.products) < sum(
                        len(mol.sssr) for mol in reaction.reactants
                    ):
                        rule_prioritized = True
            # record the rule index once, regardless of how many reactions matched
            if rule_applied:
                applied_rules.append(i)
            # priority is recorded independently of the applied flag
            if rule_prioritized:
                priority_rules.append(i)
        except Exception as e:
            # Best effort: a failing rule is skipped instead of aborting the molecule.
            logging.debug(e)
            continue

    return applied_rules, priority_rules
297
+
298
+
299
@ray.remote
def preprocess_filtering_policy_molecules(
    to_process: Queue, reaction_rules: List[Reactor]
) -> List[Optional[Data]]:
    """Worker that drains SMILES from a queue and turns every molecule into a PyTorch
    geometric graph labelled with the reaction rules applicable to it.

    The labels are sparse binary row vectors (one position per reaction rule) for the
    regular and for the priority rules. The worker stops once the queue stays empty
    for 30 seconds.

    :param to_process: The queue containing SMILES of molecules to be converted to the
        training data.
    :param reaction_rules: The list of reaction rules.
    :return: The list of PyGraph objects.
    """

    n_rules = len(reaction_rules)

    def sparse_row(rule_indices: List[int]) -> Tensor:
        """Builds a (1, n_rules) sparse uint8 vector with ones at the given indices."""
        flat = torch.sparse_coo_tensor(
            [rule_indices],
            torch.ones(len(rule_indices)),
            (n_rules,),
            dtype=torch.uint8,
        )
        return torch.unsqueeze(flat, 0)

    collected = []
    while True:
        try:
            raw_smiles = to_process.get(timeout=30)
        except Empty:
            # Nothing arrived within the timeout: the producer is done.
            break

        molecule = smiles(raw_smiles)
        if not isinstance(molecule, MoleculeContainer):
            continue

        # reaction reaction_rules application
        applied, prioritized = reaction_rules_appliance(molecule, reaction_rules)
        rules_label = sparse_row(applied)
        priority_label = sparse_row(prioritized)

        graph = mol_to_pyg(molecule)
        if not graph:
            continue

        graph.y_rules = rules_label
        graph.y_priority = priority_label
        collected.append(graph)

    return collected
352
+
353
+
354
def atom_to_vector(atom: Any) -> Tensor:
    """Encode an atom as a uint8 vector of length 8 holding, in this order:

    1. Atomic number
    2. Period
    3. Group
    4. Number of electrons + atom's charge
    5. Shell
    6. Total number of hydrogens
    7. Whether the atom is in a ring
    8. Number of neighbors

    Period, group, shell and electrons are looked up in ``MENDEL_INFO`` by the atomic
    symbol.

    :param atom: The atom object.

    :return: The vector of the atom.
    """
    encoded = torch.zeros(8, dtype=torch.uint8)
    period, group, shell, electrons = MENDEL_INFO[atom.atomic_symbol]

    def descriptors():
        # Yielded lazily so each value is read right before it is stored.
        yield atom.atomic_number
        yield period
        yield group
        yield electrons + atom.charge
        yield shell
        yield atom.total_hydrogens
        yield int(atom.in_ring)
        yield atom.neighbors

    for slot, value in enumerate(descriptors()):
        encoded[slot] = value
    return encoded
382
+
383
+
384
def bonds_to_vector(molecule: MoleculeContainer, atom_ind: int) -> Tensor:
    """Count the bonds of one atom by bond order.

    :param molecule: The given molecule.
    :param atom_ind: The index of the atom in the molecule whose bonds are counted.
    :return: The uint8 torch tensor of size 3; position ``k`` holds the number of
        bonds of order ``k + 1`` attached to the atom.
    """
    counts = torch.zeros(3, dtype=torch.uint8)
    attached_bonds = molecule._bonds[atom_ind]
    for order in attached_bonds.values():
        # Orders 1, 2, 3 map to slots 0, 1, 2.
        slot = int(order) - 1
        counts[slot] += 1
    return counts
399
+
400
+
401
def mol_to_matrix(molecule: MoleculeContainer) -> Tensor:
    """Build the atom feature matrix of a molecule.

    The result has shape ``(len(molecule), 11)`` and dtype uint8: columns 0-7 hold
    the atom vector and columns 8-10 the bond-order counts. The row of an atom is its
    number minus one.

    :param molecule: The molecule to be converted to a matrix.
    :return: The atoms vectors array.
    """
    matrix = torch.zeros((len(molecule), 11), dtype=torch.uint8)
    # Column views: writes through them land in ``matrix``.
    atom_part = matrix[:, :8]
    bond_part = matrix[:, 8:]

    for number, atom in molecule.atoms():
        atom_part[number - 1] = atom_to_vector(atom)
    for number, _ in molecule.atoms():
        bond_part[number - 1] = bonds_to_vector(molecule, number)

    return matrix
416
+
417
+
418
def mol_to_pyg(
    molecule: MoleculeContainer, canonicalize: bool = True
) -> Optional[Data]:
    """Converts a single molecule to an undirected PyTorch Geometric graph.

    The input molecule is not modified; the conversion works on a copy that is
    optionally canonicalized, kekulized and renumbered.

    :param molecule: The molecule to be converted to PyTorch Geometric graph.
    :param canonicalize: If True, the copy of the molecule is canonicalized first.
    :return: The graph with the atom feature matrix ``x`` and ``edge_index``, or None
        when the molecule has fewer than two atoms, fails the valence check or has an
        invalid aromatic ring.
    """

    if len(molecule) <= 1:  # to avoid a precursor to be empty or a single atom
        return None

    tmp_molecule = molecule.copy()
    try:
        if canonicalize:
            tmp_molecule.canonicalize()
        tmp_molecule.kekule()
        if tmp_molecule.check_valence():
            return None
    except InvalidAromaticRing:
        return None

    # remapping target for torch_geometric because
    # it is necessary that the elements in edge_index only hold nodes_idx in the range { 0, ..., num_nodes - 1}
    new_mappings = {n: i for i, (n, _) in enumerate(tmp_molecule.atoms(), 1)}
    tmp_molecule.remap(new_mappings)

    # get edge indexes from target mapping (atom numbers are 1-based)
    edge_pairs = [
        [atom - 1, neighbour - 1] for atom, neighbour, _ in tmp_molecule.bonds()
    ]
    # reshape keeps the (num_edges, 2) layout even when the molecule has no bonds,
    # so the transposed index is always (2, num_edges)
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2)

    x = mol_to_matrix(tmp_molecule)

    mol_pyg_graph = Data(x=x, edge_index=edge_index.t().contiguous())
    mol_pyg_graph = ToUndirected()(mol_pyg_graph)

    assert mol_pyg_graph.is_undirected()

    return mol_pyg_graph
462
+
463
+
464
# Per-element descriptors keyed by atomic symbol. Each value is the tuple
# (period, group, shell, electrons), in the order atom_to_vector unpacks it.
# The group is None for some entries; atom_to_vector cannot store None in its
# uint8 vector, so atoms of those elements fail to encode. Symbols missing from
# this table raise KeyError on lookup.
MENDEL_INFO = {
    "Ag": (5, 11, 1, 1),
    "Al": (3, 13, 2, 1),
    "Ar": (3, 18, 2, 6),
    "As": (4, 15, 2, 3),
    "B": (2, 13, 2, 1),
    "Ba": (6, 2, 1, 2),
    "Bi": (6, 15, 2, 3),
    "Br": (4, 17, 2, 5),
    "C": (2, 14, 2, 2),
    "Ca": (4, 2, 1, 2),
    "Ce": (6, None, 1, 2),
    "Cl": (3, 17, 2, 5),
    "Cr": (4, 6, 1, 1),
    "Cs": (6, 1, 1, 1),
    "Cu": (4, 11, 1, 1),
    "Dy": (6, None, 1, 2),
    "Er": (6, None, 1, 2),
    "F": (2, 17, 2, 5),
    "Fe": (4, 8, 1, 2),
    "Ga": (4, 13, 2, 1),
    "Gd": (6, None, 1, 2),
    "Ge": (4, 14, 2, 2),
    "Hg": (6, 12, 1, 2),
    "I": (5, 17, 2, 5),
    "In": (5, 13, 2, 1),
    "K": (4, 1, 1, 1),
    "La": (6, 3, 1, 2),
    "Li": (2, 1, 1, 1),
    "Mg": (3, 2, 1, 2),
    "Mn": (4, 7, 1, 2),
    "N": (2, 15, 2, 3),
    "Na": (3, 1, 1, 1),
    "Nd": (6, None, 1, 2),
    "O": (2, 16, 2, 4),
    "P": (3, 15, 2, 3),
    "Pb": (6, 14, 2, 2),
    "Pd": (5, 10, 3, 10),
    "Pr": (6, None, 1, 2),
    "Rb": (5, 1, 1, 1),
    "S": (3, 16, 2, 4),
    "Sb": (5, 15, 2, 3),
    "Se": (4, 16, 2, 4),
    "Si": (3, 14, 2, 2),
    "Sm": (6, None, 1, 2),
    "Sn": (5, 14, 2, 2),
    "Sr": (5, 2, 1, 2),
    "Te": (5, 16, 2, 4),
    "Ti": (4, 4, 1, 2),
    "Tl": (6, 13, 2, 1),
    "Yb": (6, None, 1, 2),
    "Zn": (4, 12, 1, 2),
}
synplan/ml/training/reinforcement.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing functions for running value network tuning with reinforcement learning
2
+ approach."""
3
+
4
+ import os
5
+ import random
6
+ from collections import defaultdict
7
+ from pathlib import Path
8
+ from random import shuffle
9
+ from typing import Dict, List
10
+
11
+ import torch
12
+ from CGRtools.containers import MoleculeContainer
13
+ from pytorch_lightning import Trainer
14
+ from torch.utils.data import random_split
15
+ from torch_geometric.data.lightning import LightningDataset
16
+
17
+ from synplan.chem.precursor import compose_precursors
18
+ from synplan.mcts.evaluation import ValueNetworkFunction
19
+ from synplan.mcts.expansion import PolicyNetworkFunction
20
+ from synplan.mcts.tree import Tree
21
+ from synplan.ml.networks.value import ValueNetwork
22
+ from synplan.ml.training.preprocessing import ValueNetworkDataset
23
+ from synplan.utils.config import (
24
+ PolicyNetworkConfig,
25
+ TuningConfig,
26
+ TreeConfig,
27
+ ValueNetworkConfig,
28
+ )
29
+ from synplan.utils.files import MoleculeReader
30
+ from synplan.utils.loading import (
31
+ load_building_blocks,
32
+ load_reaction_rules,
33
+ load_value_net,
34
+ )
35
+ from synplan.utils.logging import DisableLogger, HiddenPrints
36
+
37
+
38
def create_value_network(value_config: ValueNetworkConfig) -> ValueNetwork:
    """Build a freshly initialised value network and write its first checkpoint.

    :param value_config: The value network configuration; its hyper-parameters are
        passed to the network and its ``weights_path`` is where the checkpoint is saved.
    :return: The ValueNetwork instance to be trained/tuned.
    """

    checkpoint_path = Path(value_config.weights_path)
    network = ValueNetwork(
        vector_dim=value_config.vector_dim,
        batch_size=value_config.batch_size,
        dropout=value_config.dropout,
        num_conv_layers=value_config.num_conv_layers,
        learning_rate=value_config.learning_rate,
    )

    # The trainer is used only to attach the model and serialise it; nothing is fitted.
    with DisableLogger(), HiddenPrints():
        bootstrap_trainer = Trainer()
        bootstrap_trainer.strategy.connect(network)
        bootstrap_trainer.save_checkpoint(checkpoint_path)

    return network
60
+
61
+
62
def create_targets_batch(
    targets: List[MoleculeContainer], batch_size: int
) -> List[List[MoleculeContainer]]:
    """Creates the targets batches for planning simulations and value network tuning.

    The targets are split, in order, into consecutive chunks of ``batch_size``
    molecules; the last chunk may be shorter. An empty ``targets`` list yields an
    empty list of batches.

    :param targets: The list of target molecules.
    :param batch_size: The size of each target batch. Must be positive.
    :return: The list of lists corresponding to each target batch.
    :raises ValueError: If ``batch_size`` is not a positive integer.
    """

    # Fail loudly instead of a bare ZeroDivisionError (0) or a silent [] (negative).
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    num_targets = len(targets)
    targets_batch_list = [
        list(targets[start : start + batch_size])
        for start in range(0, num_targets, batch_size)
    ]

    if num_targets // batch_size == 0:
        print(f"1 batch were created with {num_targets} molecules")
    else:
        print(
            f"{len(targets_batch_list)} batches were created with {batch_size} molecules each"
        )

    return targets_batch_list
94
+
95
+
96
def run_tree_search(
    target: MoleculeContainer,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
) -> Tree:
    """Build and fully run a search tree for a single target molecule.

    Note that ``tree_config`` is modified in place: its ``evaluation_type`` is set to
    ``"gcn"`` and ``silent`` to True before the tree is constructed.

    :param target: The target molecule.
    :param tree_config: The planning configuration of tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration (its ``weights_path`` is used
        to load the value function).
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :return: The built search tree for the given molecule.
    """

    # load the expansion/evaluation functions and the chemistry resources
    expansion = PolicyNetworkFunction(policy_config=policy_config)
    evaluation = ValueNetworkFunction(weights_path=value_config.weights_path)
    rules = load_reaction_rules(reaction_rules_path)
    stock = load_building_blocks(building_blocks_path, standardize=True)

    # configure and initialize the tree
    tree_config.evaluation_type = "gcn"
    tree_config.silent = True
    tree = Tree(
        target=target,
        config=tree_config,
        reaction_rules=rules,
        building_blocks=stock,
        expansion_function=expansion,
        evaluation_function=evaluation,
    )
    tree._tqdm = False

    # the target itself must not count as an available building block
    target_smiles = str(target)
    if target_smiles in tree.building_blocks:
        tree.building_blocks.remove(target_smiles)

    # exhaust the tree iterator to run the whole search
    for _ in tree:
        pass

    return tree
142
+
143
+
144
def extract_tree_precursor(tree_list: List[Tree]) -> Dict[str, float]:
    """Takes the built trees and extracts the precursors for value network tuning.

    For every solved node, the node and all of its ancestors except the root (id 1)
    are labelled positive (1.0). Every other node is labelled negative (0.0), unless
    the same composed SMILES has already been labelled positive: a positive label is
    never overwritten by a negative one, so the result does not depend on the order
    in which trees and nodes are visited.

    :param tree_list: The list of built search trees.

    :return: The dictionary with the precursor SMILES and its class (positive - 1 or
        negative - 0), in shuffled key order.
    """
    extracted_precursor: Dict[str, float] = {}
    for tree in tree_list:
        for idx, node in tree.nodes.items():
            if node.is_solved():
                # walk up from the solved node and mark the whole route as positive
                parent = idx
                while parent and parent != 1:
                    composed_smi = str(
                        compose_precursors(tree.nodes[parent].new_precursors)
                    )
                    extracted_precursor[composed_smi] = 1.0
                    parent = tree.parents[parent]
            else:
                composed_smi = str(compose_precursors(node.new_precursors))
                # keep an existing (possibly positive) label instead of overwriting it
                extracted_precursor.setdefault(composed_smi, 0.0)

    # shuffle extracted precursor
    processed_keys = list(extracted_precursor)
    shuffle(processed_keys)

    return {key: extracted_precursor[key] for key in processed_keys}
175
+
176
+
177
def balance_extracted_precursor(
    extracted_precursor: Dict[str, float],
) -> Dict[str, float]:
    """Balance the positive and negative classes of the extracted precursors.

    All positive entries (label 1) are kept, and an equally sized random subset of
    the negative entries (label 0) is added. If there are fewer negatives than
    positives, all negatives are kept.

    :param extracted_precursor: The dictionary with the precursor SMILES and its label.
    :return: A new dictionary with all positives and the sampled negatives.
    """
    positives = [smi for smi, label in extracted_precursor.items() if label == 1]
    negatives = [smi for smi, label in extracted_precursor.items() if label == 0]

    # under-sample the negative class down to the size of the positive class
    kept_negatives = random.sample(negatives, min(len(positives), len(negatives)))

    extracted_precursor_balanced = {smi: extracted_precursor[smi] for smi in positives}
    for smi in kept_negatives:
        extracted_precursor_balanced[smi] = extracted_precursor[smi]
    return extracted_precursor_balanced
186
+
187
+
188
def create_updating_set(
    extracted_precursor: Dict[str, float], batch_size: int = 1
) -> LightningDataset:
    """Turn the precursors extracted from planning into a train/validation datamodule.

    The precursors are class-balanced first, wrapped into a ValueNetworkDataset and
    split 60/40 into training and validation subsets with a fixed random seed.

    :param extracted_precursor: The dictionary with the extracted precursor and their
        labels.
    :param batch_size: The size of the batch in value network updating.
    :return: A LightningDataset object, which contains the tuning set for value network
        tuning.
    """

    balanced = balance_extracted_precursor(extracted_precursor)
    full_dataset = ValueNetworkDataset(balanced)

    total = len(full_dataset)
    train_size = int(0.6 * total)
    split_sizes = [train_size, total - train_size]

    # fixed seed keeps the split reproducible between runs
    seeded_generator = torch.Generator().manual_seed(42)
    train_set, val_set = random_split(
        full_dataset, split_sizes, generator=seeded_generator
    )

    print(f"Training set size: {len(train_set)}")
    print(f"Validation set size: {len(val_set)}")

    datamodule = LightningDataset(
        train_set, val_set, batch_size=batch_size, pin_memory=True, drop_last=True
    )
    return datamodule
217
+
218
+
219
def tune_value_network(
    datamodule: LightningDataset, value_config: ValueNetworkConfig
) -> None:
    """Trains the value network using a given tuning data and saves the trained neural
    network.

    The network is loaded from ``value_config.weights_path``, fitted for
    ``value_config.num_epoch`` epochs, validated, and written back to the same path.
    Training runs on the first GPU when one is available and falls back to the CPU
    otherwise, so the tuning loop also works on machines without a GPU.

    :param datamodule: The tuning dataset (LightningDataset).
    :param value_config: The value network configuration.
    :return: None.
    """

    current_weights = value_config.weights_path
    value_network = load_value_net(ValueNetwork, current_weights)

    # "gpu" covers both CUDA and (where this torch build has it) Apple MPS devices
    mps_backend = getattr(torch.backends, "mps", None)
    has_gpu = torch.cuda.is_available() or (
        mps_backend is not None and mps_backend.is_available()
    )

    with DisableLogger(), HiddenPrints():
        trainer = Trainer(
            accelerator="gpu" if has_gpu else "cpu",
            devices=[0] if has_gpu else 1,
            max_epochs=value_config.num_epoch,
            enable_checkpointing=False,
            logger=False,
            gradient_clip_val=1.0,
            enable_progress_bar=False,
        )

        trainer.fit(value_network, datamodule)
        val_score = trainer.validate(value_network, datamodule.val_dataloader())[0]
        # overwrite the previous weights so the next planning round uses the tuned net
        trainer.save_checkpoint(current_weights)

    print(f"Value network balanced accuracy: {val_score['val_balanced_accuracy']}")
249
+
250
+
251
def run_training(
    extracted_precursor: Dict[str, float] = None,
    value_config: ValueNetworkConfig = None,
) -> None:
    """Run the training stage of value network tuning.

    Builds the tuning datamodule from the extracted precursors (using the batch size
    from ``value_config``) and retrains the value network on it.

    :param extracted_precursor: The precursor extracted from the planning simulations.
    :param value_config: The value network configuration.
    :return: None.
    """

    batch_size = value_config.batch_size
    datamodule = create_updating_set(
        extracted_precursor=extracted_precursor, batch_size=batch_size
    )
    tune_value_network(datamodule=datamodule, value_config=value_config)
269
+
270
+
271
def run_planning(
    targets_batch: List[MoleculeContainer],
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    targets_batch_id: int,
):
    """Run the planning stage (tree search) for one batch of target molecules.

    Targets whose search raises an exception are reported with ``print`` and skipped,
    so the returned list may be shorter than the batch.

    :param targets_batch: The target molecules of the current batch.
    :param tree_config: The search tree configuration passed to each tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param targets_batch_id: The number of the batch, used only in the progress message.
    :return: The list of search trees that were built successfully.
    """
    from tqdm import tqdm

    print(f"\nProcess batch number {targets_batch_id}")
    tree_config.silent = False

    built_trees = []
    for molecule in tqdm(targets_batch):
        try:
            built = run_tree_search(
                target=molecule,
                tree_config=tree_config,
                policy_config=policy_config,
                value_config=value_config,
                reaction_rules_path=reaction_rules_path,
                building_blocks_path=building_blocks_path,
            )
        except Exception as e:
            # best effort: one failing target must not abort the whole batch
            print(e)
        else:
            built_trees.append(built)

    num_solved = sum(1 for built in built_trees if len(built.winning_nodes) > 0)
    print(f"Planning is finished with {num_solved} solved targets")

    return built_trees
317
+
318
+
319
def run_updating(
    targets_path: str,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reinforce_config: TuningConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    results_root: str = None,
) -> None:
    """Performs updating of value network.

    An initial value network is created and saved as ``value_network.ckpt`` inside
    ``results_root``; ``value_config.weights_path`` is set to that file. Then, for each
    batch of targets, planning is run, precursors are extracted from the built trees
    and the value network is retrained on them.

    :param targets_path: The path to the file with target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reinforce_config: The value network tuning configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param results_root: The path to the directory where trained value network will be
        saved. Must be provided despite the ``None`` default; missing parent
        directories are created.
    :return: None.
    """

    # create results root folder (including parents; no error if it already exists)
    results_root = Path(results_root)
    results_root.mkdir(parents=True, exist_ok=True)

    # load targets list
    with MoleculeReader(targets_path) as reader:
        targets = list(reader)

    # create value neural network
    value_config.weights_path = str(results_root / "value_network.ckpt")
    create_value_network(value_config)

    # create targets batch
    targets_batch_list = create_targets_batch(
        targets, batch_size=reinforce_config.batch_size
    )

    # run value network tuning
    for batch_id, targets_batch in enumerate(targets_batch_list, start=1):

        # start tree planning simulation for batch of targets
        tree_list = run_planning(
            targets_batch=targets_batch,
            tree_config=tree_config,
            policy_config=policy_config,
            value_config=value_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            targets_batch_id=batch_id,
        )

        # extract pos and neg precursor from the list of built trees
        extracted_precursor = extract_tree_precursor(tree_list)

        # train value network for extracted precursor
        run_training(extracted_precursor=extracted_precursor, value_config=value_config)
synplan/ml/training/supervised.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for the preparation and training of a policy network used in the expansion of
2
+ nodes in tree search.
3
+
4
+ This module includes functions for creating training datasets and running the training
5
+ process for the policy network.
6
+ """
7
+
8
+ import warnings
9
+ from pathlib import Path
10
+ from typing import Union, List
11
+
12
+ import os
13
+ import torch
14
+ from pytorch_lightning import Trainer
15
+ from pytorch_lightning.callbacks import ModelCheckpoint
16
+ from torch.utils.data import random_split
17
+ from torch_geometric.data.lightning import LightningDataset
18
+
19
+ from synplan.ml.networks.policy import PolicyNetwork
20
+ from synplan.ml.training.preprocessing import (
21
+ FilteringPolicyDataset,
22
+ RankingPolicyDataset,
23
+ )
24
+ from synplan.utils.config import PolicyNetworkConfig
25
+ from synplan.utils.logging import DisableLogger, HiddenPrints
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+
30
def create_policy_dataset(
    reaction_rules_path: str,
    molecules_or_reactions_path: str,
    output_path: str,
    dataset_type: str = "filtering",
    batch_size: int = 100,
    num_cpus: int = 1,
    training_data_ratio: float = 0.8,
):
    """
    Create a training dataset for a policy network.

    :param reaction_rules_path: Path to the reaction rules file.
    :param molecules_or_reactions_path: Path to the molecules or reactions file used to create the training set.
    :param output_path: Path to store the processed dataset.
    :param dataset_type: Type of the dataset to be created ('ranking' or 'filtering').
    :param batch_size: The size of batch of molecules/reactions.
    :param training_data_ratio: Ratio of training data to total data.
    :param num_cpus: Number of CPUs to use for data processing (only passed to the filtering dataset).

    :return: A `LightningDataset` object containing training and validation datasets.
    :raises ValueError: If ``dataset_type`` is neither 'filtering' nor 'ranking'.

    """

    # Reject unknown types up front; otherwise neither branch below would run and
    # `full_dataset` would be unbound.
    if dataset_type not in ("filtering", "ranking"):
        raise ValueError(
            f"dataset_type must be either 'filtering' or 'ranking', got {dataset_type!r}."
        )

    with DisableLogger(), HiddenPrints():
        if dataset_type == "filtering":
            full_dataset = FilteringPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                molecules_path=molecules_or_reactions_path,
                output_path=output_path,
                num_cpus=num_cpus,
            )
        else:
            full_dataset = RankingPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                reactions_path=molecules_or_reactions_path,
                output_path=output_path,
            )

    train_size = int(training_data_ratio * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # fixed seed keeps the train/validation split reproducible
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
    )
    print(
        f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}"
    )

    datamodule = LightningDataset(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )

    return datamodule
89
+
90
+
91
def run_policy_training(
    datamodule: LightningDataset,
    config: PolicyNetworkConfig,
    results_path: str,
    weights_file_name: str = "policy_network",
    accelerator: str = "gpu",
    devices: Union[List[int], str, int] = "auto",
    silent: bool = False,
) -> None:
    """
    Trains a policy network using a given datamodule and training configuration.

    :param datamodule: A `LightningDataset` instance with the training and validation data.
    :param config: The policy network configuration with the hyper-parameters for the training process.
    :param results_path: Path to the directory where the trained weights are stored. It is created, including missing parents, if it does not exist.
    :param weights_file_name: The name of weights file to be saved. Default: "policy_network".
    :param accelerator: Supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances. Default: "gpu".
    :param devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
    :param silent: Run in the silent mode with no progress bars. Default: False.

    :return: None.

    """
    results_path = Path(results_path)
    results_path.mkdir(parents=True, exist_ok=True)

    network = PolicyNetwork(
        vector_dim=config.vector_dim,
        n_rules=datamodule.train_dataset.dataset.num_classes,
        batch_size=config.batch_size,
        dropout=config.dropout,
        num_conv_layers=config.num_conv_layers,
        learning_rate=config.learning_rate,
        policy_type=config.policy_type,
    )

    # keep only the checkpoint with the lowest validation loss
    checkpoint = ModelCheckpoint(
        dirpath=results_path, filename=weights_file_name, monitor="val_loss", mode="min"
    )

    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=config.num_epoch,
        callbacks=[checkpoint],
        logger=False,
        gradient_clip_val=1.0,
        enable_progress_bar=not silent,
    )

    if silent:
        with DisableLogger(), HiddenPrints():
            trainer.fit(network, datamodule)
    else:
        trainer.fit(network, datamodule)

    ba = round(trainer.logged_metrics["train_balanced_accuracy_y_step"].item(), 3)
    print(f"Policy network balanced accuracy: {ba}")
synplan/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from typing import Union
2
+ from os import PathLike
3
+
4
+ path_type = Union[str, PathLike]
synplan/utils/config.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing configuration classes."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Union
7
+ from chython import smarts
8
+
9
+ import yaml
10
+ from CGRtools.containers import MoleculeContainer, QueryContainer
11
+
12
+
13
@dataclass
class ConfigABC(ABC):
    """Common base for configuration dataclasses.

    Subclasses provide ``from_dict``, ``from_yaml`` and ``_validate_params``; the
    parameters are validated automatically right after initialization.
    """

    def __post_init__(self):
        """Run ``_validate_params`` on the freshly initialized configuration."""
        self._validate_params(self.to_dict())

    @staticmethod
    @abstractmethod
    def from_dict(config_dict: Dict[str, Any]):
        """Create an instance of the configuration from a dictionary."""

    @staticmethod
    @abstractmethod
    def from_yaml(file_path: str):
        """Deserialize a YAML file into a configuration object."""

    @abstractmethod
    def _validate_params(self, params: Dict[str, Any]):
        """Validate configuration parameters."""

    def to_dict(self) -> Dict[str, Any]:
        """Return the instance attributes as a dictionary, with Path values as strings."""
        serialised = {}
        for name, value in self.__dict__.items():
            serialised[name] = str(value) if isinstance(value, Path) else value
        return serialised

    def to_yaml(self, file_path: str):
        """Write the configuration dictionary to a YAML file.

        :param file_path: The path to the output YAML file.
        """
        payload = self.to_dict()
        with open(file_path, "w", encoding="utf-8") as handle:
            yaml.dump(payload, handle)
50
+
51
+
52
@dataclass
class RuleExtractionConfig(ConfigABC):
    """Configuration class for extracting reaction rules.

    :param multicenter_rules: If True, extracts a single rule
        encompassing all centers. If False, extracts separate reaction
        rules for each reaction center in a multicenter reaction.
    :param as_query_container: If True, the extracted rules are
        generated as QueryContainer objects, analogous to SMARTS objects
        for pattern matching in chemical structures.
    :param reverse_rule: If True, reverses the direction of the reaction
        for rule extraction.
    :param reactor_validation: If True, validates each generated rule in
        a chemical reactor to ensure correct generation of products from
        reactants.
    :param include_func_groups: If True, includes specific functional
        groups in the reaction rule in addition to the reaction center
        and its environment.
    :param func_groups_list: A list of functional groups (SMARTS strings)
        to be considered when include_func_groups is True. After
        initialization the strings are replaced by the parsed queries.
    :param include_rings: If True, includes ring structures in the
        reaction rules.
    :param keep_leaving_groups: If True, retains leaving groups in the
        extracted reaction rule.
    :param keep_incoming_groups: If True, retains incoming groups in the
        extracted reaction rule.
    :param keep_reagents: If True, includes reagents in the extracted
        reaction rule.
    :param environment_atom_count: Defines the size of the environment
        around the reaction center to be included in the rule (0 for
        only the reaction center, 1 for the first environment, etc.).
    :param min_popularity: Minimum number of times a rule must be
        applied to be considered for further analysis.
    :param keep_metadata: If True, retains metadata associated with the
        reaction in the extracted rule.
    :param single_reactant_only: If True, includes only reaction rules
        with a single reactant molecule.
    :param atom_info_retention: Controls the amount of information about
        each atom to retain. A dictionary with the keys 'reaction_center'
        and 'environment', each mapping the flags 'neighbors',
        'hybridization', 'implicit_hydrogens' and 'ring_sizes' to booleans.
        Missing entries are filled with the defaults.
    """

    # default low-level parameters
    single_reactant_only: bool = True
    keep_metadata: bool = False
    reactor_validation: bool = True
    reverse_rule: bool = True
    as_query_container: bool = True
    include_func_groups: bool = False
    func_groups_list: List[str] = field(default_factory=list)

    # adjustable parameters
    environment_atom_count: int = 1
    min_popularity: int = 3
    include_rings: bool = True
    multicenter_rules: bool = True
    keep_leaving_groups: bool = True
    keep_incoming_groups: bool = True
    keep_reagents: bool = False
    atom_info_retention: Dict[str, Dict[str, bool]] = field(default_factory=dict)

    def __post_init__(self):
        # Defaults are filled in first so that validation sees the complete
        # atom_info_retention structure; the base class then validates once.
        self._initialize_default_atom_info_retention()
        super().__post_init__()
        # Parsing replaces the SMARTS strings by query objects, so it has to
        # happen after validation (which expects strings).
        self._parse_functional_groups()

    def _initialize_default_atom_info_retention(self):
        """Fill missing parts of ``atom_info_retention`` with the default flags."""
        default_atom_info = {
            "reaction_center": {
                "neighbors": True,
                "hybridization": True,
                "implicit_hydrogens": False,
                "ring_sizes": False,
            },
            "environment": {
                "neighbors": False,
                "hybridization": False,
                "implicit_hydrogens": False,
                "ring_sizes": False,
            },
        }

        provided = self.atom_info_retention
        if provided is None or (isinstance(provided, dict) and not provided):
            self.atom_info_retention = default_atom_info
            return

        if not isinstance(provided, dict):
            # leave the value untouched; _validate_params reports the wrong type
            return

        merged = dict(provided)
        for key, default_flags in default_atom_info.items():
            user_flags = provided.get(key, {})
            # non-dict entries are kept as they are and rejected by validation
            if isinstance(user_flags, dict):
                merged[key] = {**default_flags, **user_flags}
        self.atom_info_retention = merged

    def _parse_functional_groups(self):
        """Replace the SMARTS strings in ``func_groups_list`` by parsed queries.

        Groups that cannot be parsed are reported with ``print`` and dropped.
        """
        func_groups_list = []
        for group_smarts in self.func_groups_list or []:
            try:
                query = smarts(group_smarts)
                func_groups_list.append(query)
            except Exception as e:
                print(f"Functional group {group_smarts} was not parsed because of {e}")
        self.func_groups_list = func_groups_list

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "RuleExtractionConfig":
        """Create the configuration from a dictionary of parameters."""
        return RuleExtractionConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "RuleExtractionConfig":
        """Create the configuration from a YAML file (an empty file gives the defaults)."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file) or {}
        return RuleExtractionConfig.from_dict(config_dict)

    def _validate_params(self, params: Dict[str, Any]) -> None:
        """Check the types and the structure of the configuration parameters.

        :param params: The configuration as a dictionary.
        :raises ValueError: If any parameter has an invalid type or structure.
        """
        boolean_params = (
            "multicenter_rules",
            "as_query_container",
            "reverse_rule",
            "reactor_validation",
            "include_func_groups",
            "include_rings",
            "keep_leaving_groups",
            "keep_incoming_groups",
            "keep_reagents",
            "keep_metadata",
            "single_reactant_only",
        )
        for name in boolean_params:
            if not isinstance(params[name], bool):
                raise ValueError(f"{name} must be a boolean.")

        if params["func_groups_list"] is not None and not all(
            isinstance(group, str) for group in params["func_groups_list"]
        ):
            raise ValueError("func_groups_list must be a list of SMARTS.")

        for name in ("environment_atom_count", "min_popularity"):
            if not isinstance(params[name], int):
                raise ValueError(f"{name} must be an integer.")

        retention = params["atom_info_retention"]
        if retention is None:
            return

        if not isinstance(retention, dict):
            raise ValueError("atom_info_retention must be a dictionary.")

        required_keys = {"reaction_center", "environment"}
        missing_keys = required_keys - set(retention)
        if missing_keys:
            raise ValueError(
                f"atom_info_retention missing required keys: {missing_keys}"
            )

        expected_subkeys = {
            "neighbors",
            "hybridization",
            "implicit_hydrogens",
            "ring_sizes",
        }
        for key, value in retention.items():
            if key not in required_keys:
                raise ValueError(f"Unexpected key in atom_info_retention: {key}")

            if not isinstance(value, dict):
                raise ValueError(
                    f"Invalid structure for {key} in atom_info_retention. Expected a dictionary."
                )

            missing_subkeys = expected_subkeys - set(value)
            if missing_subkeys:
                raise ValueError(
                    f"Invalid structure for {key} in atom_info_retention. Missing subkeys: {missing_subkeys}"
                )

            for subkey, subvalue in value.items():
                if not isinstance(subvalue, bool):
                    raise ValueError(
                        f"Value for {subkey} in {key} of atom_info_retention must be boolean."
                    )
241
+
242
+
243
@dataclass
class PolicyNetworkConfig(ConfigABC):
    """Configuration class for the policy network.

    :param policy_type: Mode of operation, either 'filtering' or 'ranking'.
    :param vector_dim: Dimension of the input vectors.
    :param batch_size: Number of samples per batch.
    :param dropout: Dropout rate for regularization (between 0 and 1).
    :param learning_rate: Learning rate for the optimizer.
    :param num_conv_layers: Number of convolutional layers in the network.
    :param num_epoch: Number of training epochs.
    :param weights_path: Path to the network weights; not validated here.
    :param priority_rules_fraction: Non-negative number, used for the filtering policy.
    :param rule_prob_threshold: Non-negative number, used for the filtering policy.
    :param top_rules: Positive integer, used for the filtering policy.
    """

    policy_type: str = "ranking"
    vector_dim: int = 256
    batch_size: int = 500
    dropout: float = 0.4
    learning_rate: float = 0.008
    num_conv_layers: int = 5
    num_epoch: int = 100
    weights_path: str = None

    # for filtering policy
    priority_rules_fraction: float = 0.5
    rule_prob_threshold: float = 0.0
    top_rules: int = 50

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "PolicyNetworkConfig":
        """Create the configuration from a dictionary of parameters."""
        return PolicyNetworkConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "PolicyNetworkConfig":
        """Create the configuration from a YAML file (an empty file gives the defaults)."""
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file) or {}
        return PolicyNetworkConfig.from_dict(config_dict)

    @staticmethod
    def _is_real_number(value: Any) -> bool:
        """Return True for int/float values, excluding booleans."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _validate_params(self, params: Dict[str, Any]):
        """Check the types and the value ranges of the configuration parameters.

        Real-valued parameters accept integers as well as floats, so that values
        such as ``dropout: 0`` coming from a YAML file are not rejected.

        :param params: The configuration as a dictionary.
        :raises ValueError: If any parameter has an invalid type or value.
        """

        if params["policy_type"] not in ["filtering", "ranking"]:
            raise ValueError("policy_type must be either 'filtering' or 'ranking'.")

        for name in ("vector_dim", "batch_size", "num_conv_layers", "num_epoch"):
            if not isinstance(params[name], int) or params[name] <= 0:
                raise ValueError(f"{name} must be a positive integer.")

        if not self._is_real_number(params["dropout"]) or not (
            0.0 <= params["dropout"] <= 1.0
        ):
            raise ValueError("dropout must be a float between 0.0 and 1.0.")

        if (
            not self._is_real_number(params["learning_rate"])
            or params["learning_rate"] <= 0.0
        ):
            raise ValueError("learning_rate must be a positive float.")

        if (
            not self._is_real_number(params["priority_rules_fraction"])
            or params["priority_rules_fraction"] < 0.0
        ):
            raise ValueError("priority_rules_fraction must be a non-negative float.")

        if (
            not self._is_real_number(params["rule_prob_threshold"])
            or params["rule_prob_threshold"] < 0.0
        ):
            raise ValueError("rule_prob_threshold must be a non-negative float.")

        if not isinstance(params["top_rules"], int) or params["top_rules"] <= 0:
            raise ValueError("top_rules must be a positive integer.")
327
+
328
+
329
@dataclass
class ValueNetworkConfig(ConfigABC):
    """Configuration class for the value network.

    :param weights_path: Presumably the path to saved network weights; it is
        not validated by this class. Defaults to None.
    :param vector_dim: Dimension of the input vectors.
    :param batch_size: Number of samples per batch.
    :param dropout: Dropout rate for regularization.
    :param learning_rate: Learning rate for the optimizer.
    :param num_conv_layers: Number of convolutional layers in the network.
    :param num_epoch: Number of training epochs.
    """

    weights_path: str = None
    vector_dim: int = 256
    batch_size: int = 500
    dropout: float = 0.4
    learning_rate: float = 0.008
    num_conv_layers: int = 5
    num_epoch: int = 100

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "ValueNetworkConfig":
        """Build a configuration from a mapping of field names to values.

        :param config_dict: Field values passed to the constructor as keywords.
        :return: The new configuration instance.
        """
        config = ValueNetworkConfig(**config_dict)
        return config

    @staticmethod
    def from_yaml(file_path: str) -> "ValueNetworkConfig":
        """Build a configuration from the contents of a YAML file.

        :param file_path: Path to the YAML file to read.
        :return: The new configuration instance.
        """
        with open(file_path, "r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        return ValueNetworkConfig.from_dict(loaded)

    def to_yaml(self, file_path: str):
        """Write this configuration to a YAML file.

        :param file_path: Path of the YAML file to (over)write.
        :return: None.
        """
        with open(file_path, "w", encoding="utf-8") as stream:
            yaml.dump(self.to_dict(), stream)

    def _validate_params(self, params: Dict[str, Any]):
        """Validate the parameter values, stopping at the first invalid one.

        :param params: Mapping of parameter names to the values to check.
        :raises ValueError: If a value has the wrong type or is out of range.
        """
        # Integer hyperparameters, checked in their original order.
        for name in ("vector_dim", "batch_size", "num_conv_layers", "num_epoch"):
            value = params[name]
            if not isinstance(value, int) or value < 1:
                raise ValueError(f"{name} must be a positive integer.")

        dropout = params["dropout"]
        if not (isinstance(dropout, float) and 0.0 <= dropout <= 1.0):
            raise ValueError("dropout must be a float between 0.0 and 1.0.")

        learning_rate = params["learning_rate"]
        if not isinstance(learning_rate, float) or learning_rate <= 0.0:
            raise ValueError("learning_rate must be a positive float.")
390
+
391
+
392
@dataclass
class TuningConfig(ConfigABC):
    """Configuration class for the network training.

    :param batch_size: The number of targets per batch in the planning simulation
        step. Must be a positive integer.
    :param num_simulations: The number of planning simulations. This value is
        not checked by ``_validate_params``.
    """

    batch_size: int = 100
    num_simulations: int = 1

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TuningConfig":
        """Build a configuration from a mapping of field names to values.

        :param config_dict: Field values passed to the constructor as keywords.
        :return: The new configuration instance.
        """
        config = TuningConfig(**config_dict)
        return config

    @staticmethod
    def from_yaml(file_path: str) -> "TuningConfig":
        """Build a configuration from the contents of a YAML file.

        :param file_path: Path to the YAML file to read.
        :return: The new configuration instance.
        """
        with open(file_path, "r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        return TuningConfig.from_dict(loaded)

    def _validate_params(self, params: Dict[str, Any]):
        """Validate ``batch_size``; no other parameter is checked.

        :param params: Mapping of parameter names to the values to check.
        :raises ValueError: If ``batch_size`` is not a positive integer.
        """
        batch_size = params["batch_size"]
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")
417
+
418
+
419
@dataclass
class TreeConfig(ConfigABC):
    """Configuration class for the tree search algorithm.

    :param max_iterations: The number of iterations to run the algorithm
        for.
    :param max_tree_size: The maximum number of nodes in the tree.
    :param max_time: The time limit (in seconds) for the algorithm to
        run. Any positive int or float is accepted.
    :param max_depth: The maximum depth of the tree.
    :param ucb_type: Type of UCB used in the search algorithm. Options
        are "puct", "uct", "value", defaults to "uct".
    :param c_ucb: The exploration-exploitation balance coefficient used
        in Upper Confidence Bound (UCB).
    :param backprop_type: Type of backpropagation algorithm. Options are
        "muzero", "cumulative", defaults to "muzero".
    :param search_strategy: The strategy used for tree search. Options
        are "expansion_first", "evaluation_first".
    :param exclude_small: Whether to exclude small molecules during the
        search.
    :param evaluation_agg: Method for aggregating evaluation scores.
        Options are "max", "average", defaults to "max".
    :param evaluation_type: The method used for evaluating nodes.
        Options are "random", "rollout", "gcn".
    :param init_node_value: Initial value for a new node.
    :param epsilon: A parameter in the epsilon-greedy search strategy
        representing the chance of random selection of reaction rules
        during the selection stage in Monte Carlo Tree Search,
        specifically during Upper Confidence Bound estimation. It
        balances between exploration and exploitation. Must be a float
        in the closed interval [0, 1].
    :param min_mol_size: Defines the minimum size of a molecule that
        has to be synthesized. Molecules with 6 or fewer heavy atoms
        are assumed to be building blocks by definition, thus setting
        the threshold for considering larger molecules in the search,
        defaults to 6.
    :param silent: Whether to suppress progress output.
    """

    max_iterations: int = 100
    max_tree_size: int = 1000000
    max_time: float = 600
    max_depth: int = 6
    ucb_type: str = "uct"
    c_ucb: float = 0.1
    backprop_type: str = "muzero"
    search_strategy: str = "expansion_first"
    exclude_small: bool = True
    evaluation_agg: str = "max"
    evaluation_type: str = "gcn"
    init_node_value: float = 0.0
    epsilon: float = 0.0
    min_mol_size: int = 6
    silent: bool = False

    @staticmethod
    def from_dict(config_dict: Dict[str, Any]) -> "TreeConfig":
        """Create a configuration from a dictionary of field values.

        :param config_dict: Mapping of field names to values.
        :return: The new configuration instance.
        """
        return TreeConfig(**config_dict)

    @staticmethod
    def from_yaml(file_path: str) -> "TreeConfig":
        """Create a configuration from a YAML file.

        :param file_path: Path to the YAML file holding the field values.
        :return: The new configuration instance.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            config_dict = yaml.safe_load(file)
        return TreeConfig.from_dict(config_dict)

    def _validate_params(self, params):
        """Validate the parameter values, stopping at the first invalid one.

        :param params: Mapping of parameter names to the values to check.
        :raises ValueError: If a value is not among the allowed options or is
            out of range.
        :raises TypeError: If ``c_ucb``, ``exclude_small``, ``silent`` or
            ``init_node_value`` has the wrong type.
        """
        if params["ucb_type"] not in ["puct", "uct", "value"]:
            raise ValueError(
                "Invalid ucb_type. Allowed values are 'puct', 'uct', 'value'."
            )
        if params["backprop_type"] not in ["muzero", "cumulative"]:
            raise ValueError(
                "Invalid backprop_type. Allowed values are 'muzero', 'cumulative'."
            )
        if params["evaluation_type"] not in ["random", "rollout", "gcn"]:
            raise ValueError(
                "Invalid evaluation_type. Allowed values are 'random', 'rollout', 'gcn'."
            )
        if params["evaluation_agg"] not in ["max", "average"]:
            raise ValueError(
                "Invalid evaluation_agg. Allowed values are 'max', 'average'."
            )
        if not isinstance(params["c_ucb"], float):
            raise TypeError("c_ucb must be a float.")
        if not isinstance(params["max_depth"], int) or params["max_depth"] < 1:
            raise ValueError("max_depth must be a positive integer.")
        if not isinstance(params["max_tree_size"], int) or params["max_tree_size"] < 1:
            raise ValueError("max_tree_size must be a positive integer.")
        if (
            not isinstance(params["max_iterations"], int)
            or params["max_iterations"] < 1
        ):
            raise ValueError("max_iterations must be a positive integer.")
        # max_time is declared as a float number of seconds, so accept both
        # ints and floats instead of rejecting values such as 600.0.
        if (
            not isinstance(params["max_time"], (int, float))
            or params["max_time"] <= 0
        ):
            raise ValueError("max_time must be a positive number.")
        if not isinstance(params["exclude_small"], bool):
            raise TypeError("exclude_small must be a boolean.")
        if not isinstance(params["silent"], bool):
            raise TypeError("silent must be a boolean.")
        if not isinstance(params["init_node_value"], float):
            raise TypeError("init_node_value must be a float if provided.")
        if params["search_strategy"] not in ["expansion_first", "evaluation_first"]:
            raise ValueError(
                f"Invalid search_strategy: {params['search_strategy']}: "
                f"Allowed values are 'expansion_first', 'evaluation_first'"
            )
        # The bounds are inclusive: the default epsilon of 0.0 must stay valid.
        if not isinstance(params["epsilon"], float) or not (
            0 <= params["epsilon"] <= 1
        ):
            raise ValueError("epsilon must be a float between 0 and 1.")
        if not isinstance(params["min_mol_size"], int) or params["min_mol_size"] < 0:
            raise ValueError("min_mol_size must be a non-negative integer.")
528
+
529
+
530
def convert_config_to_dict(config_attr: ConfigABC, config_type) -> Dict | None:
    """Return a configuration attribute as a dictionary when possible.

    A dictionary is returned unchanged (the same object, not a copy); an
    instance of ``config_type`` is converted with its ``to_dict()`` method.

    :param config_attr: The configuration attribute to be converted.
    :param config_type: The type to check against for conversion.
    :return: The configuration attribute as a dictionary, or None if it's not an
        instance of the given type or dict.
    """
    if isinstance(config_attr, dict):
        converted = config_attr
    elif isinstance(config_attr, config_type):
        converted = config_attr.to_dict()
    else:
        converted = None
    return converted
synplan/utils/files.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing classes and functions needed for reactions/molecules data
2
+ reading/writing."""
3
+
4
+ from os.path import splitext
5
+ from pathlib import Path
6
+ from typing import Iterable, Union
7
+
8
+ from CGRtools import smiles
9
+ from CGRtools.containers import CGRContainer, MoleculeContainer, ReactionContainer
10
+ from CGRtools.files.RDFrw import RDFRead, RDFWrite
11
+ from CGRtools.files.SDFrw import SDFRead, SDFWrite
12
+
13
+
14
class FileHandler:
    """General class to handle chemical files."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Determine the file type from the extension of ``filename``.

        The extension is matched case-insensitively, so ``mols.SDF`` is
        treated like ``mols.sdf``. No file is opened here; ``_file`` stays
        ``None`` until a subclass assigns it.

        :param filename: The path and name of the file.
        :param kwargs: Accepted for the benefit of subclasses; unused here.
        :raises ValueError: If the file extension is not a supported one.
        :return: None.
        """
        self._file = None
        _, ext = splitext(filename)
        file_types = {".smi": "SMI", ".smiles": "SMI", ".rdf": "RDF", ".sdf": "SDF"}
        file_type = file_types.get(ext.lower())
        if file_type is None:
            raise ValueError("I don't know the file extension,", ext)
        self._file_type = file_type

    def close(self):
        """Close the underlying file object, if one has been opened."""
        if self._file is not None:
            self._file.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the file when leaving a ``with`` block."""
        self.close()
36
+
37
+
38
class Reader(FileHandler):
    """Base reader that forwards iteration to the wrapped file object."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """General class to read reactions/molecules data files.

        :param filename: The path and name of the file.
        :return: None.
        """
        super().__init__(filename, **kwargs)

    def __enter__(self):
        # The context manager yields the wrapped file object, not this wrapper.
        wrapped = self._file
        return wrapped

    def __iter__(self):
        wrapped = self._file
        return iter(wrapped)

    def __next__(self):
        wrapped = self._file
        return next(wrapped)

    def __len__(self):
        wrapped = self._file
        return len(wrapped)
58
+
59
+
60
class SMILESRead:
    """Line-by-line reader for files holding one SMILES string per line."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Simplified class to read files containing a SMILES (Molecules or Reaction)
        string per line.

        The file is opened immediately and must exist. Extra keyword arguments
        are accepted but not used.

        :param filename: The path and name of the SMILES file to parse.
        :return: None.
        """
        resolved = Path(filename).resolve(strict=True)
        self._file = open(str(resolved), "r", encoding="utf-8")
        self._data = self.__data()

    def __data(
        self,
    ) -> Iterable[Union[ReactionContainer, CGRContainer, MoleculeContainer]]:
        """Lazily parse the file, yielding one container per recognised line.

        Each yielded object gets the stripped source line stored under
        ``meta["init_smiles"]``; lines whose parse result is not one of the
        three container types are skipped.
        """
        accepted = (ReactionContainer, CGRContainer, MoleculeContainer)
        while (raw_line := self._file.readline()) != "":
            text = raw_line.strip()
            parsed = smiles(text)
            if not isinstance(parsed, accepted):
                continue
            parsed.meta["init_smiles"] = text
            yield parsed

    def __enter__(self):
        return self

    def read(self):
        """Parse the whole SMILES file.

        :return: List of parsed molecules or reactions.
        """
        return [record for record in self]

    def __iter__(self):
        # All iterators share the single underlying generator in ``_data``.
        for record in self._data:
            yield record

    def __next__(self):
        iterator = iter(self)
        return next(iterator)

    def close(self):
        """Close the underlying text file."""
        self._file.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
103
+
104
+
105
class Writer(FileHandler):
    """Base writer; subclasses open the output file and implement ``write``."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """General class to write chemical files.

        :param filename: The path and name of the file.
        :param mapping: Whether to save mapping or not. The flag is only
            stored on the instance (``_mapping``) by this class.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        self._mapping = mapping

    def __enter__(self):
        # Unlike ``Reader``, the context manager yields the writer itself.
        return self
118
+
119
+
120
class ReactionReader(Reader):
    """Reader for reaction files in SMILES or RDF format."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Class to read reaction files.

        :param filename: The path and name of the file.
        :raises ValueError: If the file is neither a SMILES nor an RDF file.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        file_type = self._file_type
        if file_type not in ("SMI", "RDF"):
            raise ValueError("File type incompatible -", filename)
        if file_type == "SMI":
            self._file = SMILESRead(filename, **kwargs)
        else:
            self._file = RDFRead(filename, indexable=True, **kwargs)
134
+
135
+
136
class ReactionWriter(Writer):
    """Writer for reaction files in SMILES or RDF format."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """Class to write reaction files.

        The output file is opened (and truncated) immediately.

        :param filename: The path and name of the file.
        :param mapping: Whether to save mapping or not.
        :raises ValueError: If the file is neither a SMILES nor an RDF file.
        :return: None.
        """
        super().__init__(filename, mapping, **kwargs)
        file_type = self._file_type
        if file_type not in ("SMI", "RDF"):
            raise ValueError("File type incompatible -", filename)
        if file_type == "SMI":
            self._file = open(filename, "w", encoding="utf-8", **kwargs)
        else:
            self._file = RDFWrite(filename, append=False, **kwargs)

    def write(self, reaction: ReactionContainer):
        """Function to write a specific reaction to the file.

        :param reaction: The reaction to be written.
        :return: None.
        """
        file_type = self._file_type
        if file_type == "RDF":
            self._file.write(reaction)
        elif file_type == "SMI":
            # SMILES output is one record per line.
            record = to_reaction_smiles_record(reaction)
            self._file.write(record + "\n")
163
+
164
+
165
class MoleculeReader(Reader):
    """Reader for molecule files in SMILES or SDF format."""

    def __init__(self, filename: Union[str, Path], **kwargs):
        """Class to read molecule files.

        :param filename: The path and name of the file.
        :raises ValueError: If the file is neither a SMILES nor an SDF file.
        :return: None.
        """
        super().__init__(filename, **kwargs)
        file_type = self._file_type
        if file_type not in ("SMI", "SDF"):
            raise ValueError("File type incompatible -", filename)
        if file_type == "SMI":
            self._file = SMILESRead(filename, ignore=True, **kwargs)
        else:
            self._file = SDFRead(filename, indexable=True, **kwargs)
179
+
180
+
181
class MoleculeWriter(Writer):
    """Writer for molecule files in SMILES or SDF format."""

    def __init__(self, filename: Union[str, Path], mapping: bool = True, **kwargs):
        """Class to write molecule files.

        The output file is opened (and truncated) immediately.

        :param filename: The path and name of the file.
        :param mapping: Whether to save mapping or not.
        :raises ValueError: If the file is neither a SMILES nor an SDF file.
        :return: None.
        """
        super().__init__(filename, mapping, **kwargs)
        file_type = self._file_type
        if file_type not in ("SMI", "SDF"):
            raise ValueError("File type incompatible -", filename)
        if file_type == "SMI":
            self._file = open(filename, "w", encoding="utf-8", **kwargs)
        else:
            self._file = SDFWrite(filename, append=False, **kwargs)

    def write(self, molecule: MoleculeContainer):
        """Function to write a specific molecule to the file.

        :param molecule: The molecule to be written.
        :return: None.
        """
        file_type = self._file_type
        if file_type == "SDF":
            self._file.write(molecule)
        elif file_type == "SMI":
            # SMILES output is the molecule's string form, one per line.
            line = str(molecule) + "\n"
            self._file.write(line)
208
+
209
+
210
def to_reaction_smiles_record(reaction: ReactionContainer) -> str:
    """Converts the reaction to the SMILES record. Needed for reaction/molecule writers.

    A string argument is returned unchanged. Otherwise the record is the
    reaction formatted with the ``"m"`` format spec, followed by one empty
    tab-separated field per entry in ``reaction.meta``.

    :param reaction: The reaction to be written.
    :return: The SMILES record to be written.
    """
    if isinstance(reaction, str):
        return reaction

    fields = [format(reaction, "m")]
    # NOTE(review): metadata values are deliberately not emitted; every meta
    # entry contributes only an empty field. This looks like a placeholder:
    # confirm whether the values should be written before changing it.
    for _key, _value in sorted(reaction.meta.items(), key=lambda item: item[0]):
        fields.append("")
    return "\t".join(fields)