Gilmullin Almaz commited on
Commit
f2f3593
·
1 Parent(s): 2830c50
cluster/{super_cgr.py → generalized_cgr.py} RENAMED
File without changes
cluster/reduced_g_cgr.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from CGRtools.containers.bonds import DynamicBond
2
+
3
def reducing_g_cgr(g_cgr):
    """
    Reduce a Generalized Condensed Graph of Reaction (G-CGR) to its RG-CGR.

    Steps performed:

    1. The first connected component of ``g_cgr`` is extracted as the target.
    2. Every bond of the target is inspected:
       - a bond whose ``p_order`` is None while its ``order`` is defined is
         deleted (it is treated as a bond to a leaving group);
       - a bond whose ``order`` and ``p_order`` are both integers and differ
         (in either direction) is replaced by ``DynamicBond(p_order, p_order)``.
    3. The first connected component of the modified target is extracted; this
       is the reduced G-CGR (RG-CGR).
    4. If ``_p_charges`` differs from ``_charges``, the charge of every atom
       with a non-zero entry in ``_charges`` is set to zero.

    :param g_cgr: graph exposing ``connected_components``, ``substructure``,
        ``_bonds``, ``delete_bond`` and ``add_bond``; the extracted
        substructure must also expose ``_charges``, ``_p_charges`` and
        ``_atoms``.
    :return: the reduced G-CGR.
    :raises IndexError: if a graph has no connected components.
    """
    # Only the first component is used, so only that substructure is built.
    target_cgr = g_cgr.substructure(tuple(g_cgr.connected_components)[0])

    # Iterate over snapshots: bonds are deleted/re-added during the walk.
    for atom1, neighbours in list(target_cgr._bonds.items()):
        for atom2, bond in list(neighbours.items()):
            order, p_order = bond.order, bond.p_order

            if p_order is None and order is not None:
                # Bond has no product-side order: drop it.
                target_cgr.delete_bond(atom1, atom2)
            elif isinstance(order, int) and isinstance(p_order, int) and p_order != order:
                # Changed bond: keep it, but with the p_order on both sides.
                target_cgr.delete_bond(atom1, atom2)
                target_cgr.add_bond(atom1, atom2, DynamicBond(p_order, p_order))

    # Deleting bonds may have split the target; keep its first component.
    rg_cgr = target_cgr.substructure(tuple(target_cgr.connected_components)[0])

    # Neutralize charges when the two charge mappings disagree.
    if rg_cgr._p_charges != rg_cgr._charges:
        for num, charge in rg_cgr._charges.items():
            if charge != 0:
                rg_cgr._atoms[num].charge = 0

    return rg_cgr
52
+
53
+
54
def process_all_rg_cgrs(g_cgrs_dict):
    """
    Build the reduced form (RG-CGR) of every G-CGR in a dictionary.

    :param g_cgrs_dict: mapping of key -> G-CGR.
    :return: a new dictionary with the same keys, where each value is the
        result of ``reducing_g_cgr`` applied to the corresponding input value.
    """
    return {num: reducing_g_cgr(cgr) for num, cgr in g_cgrs_dict.items()}
71
+
72
+
73
def report_strategic_bonds(result, target_cgr):
    """
    Print one line per strategic bond.

    :param result: iterable of entries whose item 0 is an atom-index pair and
        whose item 1 is the bond's primary order (p_order).
    :param target_cgr: graph whose ``_atoms`` mapping resolves the atom
        indices to the atom objects that get printed.
    """
    for entry in result:
        pair, p_order = entry[0], entry[1]
        first_atom = target_cgr._atoms[pair[0]]
        second_atom = target_cgr._atoms[pair[1]]
        # Tab-prefixed line: both atoms followed by the primary bond order.
        print('\t', first_atom, second_atom, p_order)
87
+
88
+
89
def extract_strategic_bonds(target_cgr, report=True):
    """
    Collect the strategic bonds of a reduced G-CGR (RG-CGR).

    A bond counts as strategic when its ``order`` is None while its
    ``p_order`` is defined. Each bond is recorded once, keyed by the sorted
    pair of its atom indices, in the order in which it is first met.

    :param target_cgr: graph exposing ``_bonds`` as atom -> {neighbour: bond}.
    :param report: if True, print the collected bonds via
        ``report_strategic_bonds``.
    :return: list of ``[bond_key, p_order]`` pairs, ``bond_key`` being a
        sorted tuple of the two atom indices.
    """
    # An insertion-ordered dict both de-duplicates and keeps discovery order.
    found = {}
    for atom1, neighbours in target_cgr._bonds.items():
        for atom2, bond in neighbours.items():
            if bond.order is not None or bond.p_order is None:
                continue
            found.setdefault(tuple(sorted((atom1, atom2))), bond.p_order)

    result = [[bond_key, p_order] for bond_key, p_order in found.items()]

    if report:
        print('Strategic bonds in RG-CGR:')
        report_strategic_bonds(result, target_cgr)
    return result
119
+
120
+
121
def compare_rg_cgr_by_strategic_bonds(rg_cgr1, rg_cgr2, report=True):
    """
    Compare two reduced G-CGRs (RG-CGRs) by their strategic bonds.

    The strategic bonds of both graphs are extracted with
    ``extract_strategic_bonds`` and compared as sets of
    ``(atom_pair, p_order)`` tuples.

    :param rg_cgr1: first RG-CGR.
    :param rg_cgr2: second RG-CGR.
    :param report: if True, print the common bonds and the bonds unique to
        each graph via ``report_strategic_bonds``.
    :return: tuple ``(common, unique_first, unique_second)``; each item is a
        sorted list of ``(atom_pair, p_order)`` tuples.
    """
    bonds_first = {
        (tuple(pair), order)
        for pair, order in extract_strategic_bonds(rg_cgr1, report=False)
    }
    bonds_second = {
        (tuple(pair), order)
        for pair, order in extract_strategic_bonds(rg_cgr2, report=False)
    }

    # Sorted so that the report and the returned lists are deterministic.
    common = sorted(bonds_first & bonds_second)
    unique_first = sorted(bonds_first - bonds_second)
    unique_second = sorted(bonds_second - bonds_first)

    if report:
        print("Common:")
        report_strategic_bonds(common, rg_cgr1)
        print("Unique for first RG-CGR:")
        report_strategic_bonds(unique_first, rg_cgr1)
        print("Unique for second RG-CGR:")
        # These bonds exist only in the second graph, so their atoms must be
        # looked up there; the first graph may not contain them at all.
        report_strategic_bonds(unique_second, rg_cgr2)

    return common, unique_first, unique_second
cluster/rs_cgr.py DELETED
@@ -1,40 +0,0 @@
1
- from CGRtools.containers.bonds import DynamicBond
2
-
3
- def s_cgr2rs_cgr(s_cgr):
4
- cgr_prods = [s_cgr.substructure(c) for c in s_cgr.connected_components]
5
- target_cgr = cgr_prods[0]
6
-
7
- bond_items = list(target_cgr._bonds.items())
8
- for atom1, bond_set in bond_items:
9
- bond_set_items = list(bond_set.items())
10
- for atom2, bond in bond_set_items:
11
-
12
- # Leaving groups removal
13
- if bond.p_order == None and bond.order is not None:
14
- # print(atom1, atom2)
15
- # print(bond)
16
- target_cgr.delete_bond(atom1, atom2)
17
- # target_cgr.clean2d()
18
- # display(SVG(target_cgr.depict()))
19
-
20
- ## Modified bond, but not leaving group
21
- elif type(bond.p_order) is int and type(bond.order) is int and bond.p_order < bond.order:
22
- p_order = int(bond.p_order)
23
- target_cgr.delete_bond(atom1, atom2)
24
- target_cgr.add_bond(atom1, atom2, DynamicBond(p_order, p_order))
25
-
26
- rs_cgr = [target_cgr.substructure(c) for c in target_cgr.connected_components][0]
27
-
28
- # Charge neutralizer
29
- if rs_cgr._p_charges != rs_cgr._charges:
30
- for num, charge in rs_cgr._charges.items():
31
- if charge != 0:
32
- rs_cgr._atoms[num].charge = 0
33
-
34
- return rs_cgr
35
-
36
- def process_all_rs_cgrs(super_cgrs_dict):
37
- all_rs_cgrs = dict()
38
- for num, cgr in super_cgrs_dict.items():
39
- all_rs_cgrs[num] = s_cgr2rs_cgr(cgr)
40
- return all_rs_cgrs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cluster/utils.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  def extract_reactions(tree):
2
  reactions_dict = {}
3
  for node_id in set(tree.winning_nodes):
@@ -5,14 +11,55 @@ def extract_reactions(tree):
5
  reactions_dict[node_id] = reactions
6
  return reactions_dict
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class TreeWrapper:
10
- def __init__(self, tree):
 
 
 
 
11
  self.tree = tree
 
 
 
 
 
12
 
13
  def __getstate__(self):
14
  state = self.__dict__.copy()
15
- # Save the state of the tree
16
  tree_state = self.tree.__dict__.copy()
17
  # Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
18
  if '_tqdm' in tree_state:
@@ -20,18 +67,248 @@ class TreeWrapper:
20
  for attr in ['policy_network', 'value_network']:
21
  if attr in tree_state:
22
  tree_state[attr] = None
23
- # Store tree state separately
24
  state['tree_state'] = tree_state
25
- # Remove the actual tree instance from the state
26
  del state['tree']
27
  return state
28
 
29
  def __setstate__(self, state):
30
- # Retrieve the stored tree state
31
  tree_state = state.pop('tree_state')
32
- # Update the instance state
33
  self.__dict__.update(state)
34
- # Create a new Tree instance without calling __init__
35
  new_tree = Tree.__new__(Tree)
36
  new_tree.__dict__.update(tree_state)
37
- self.tree = new_tree
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from synplan.mcts.tree import Tree
2
+ from synplan.utils.visualisation import get_route_svg
3
+ from CGRtools.containers import MoleculeContainer
4
+ import pickle
5
+ import os
6
+
7
  def extract_reactions(tree):
8
  reactions_dict = {}
9
  for node_id in set(tree.winning_nodes):
 
11
  reactions_dict[node_id] = reactions
12
  return reactions_dict
13
 
14
def extract_rules_from_route(node_id, tree):
    """
    Collect the reaction-rule ids recorded along the route to a node.

    For every node returned by ``tree.route_to_node(node_id)`` the first
    entry of ``new_precursors`` is inspected; when it is non-empty and its
    ``molecule.meta`` has a ``'reactor_id'`` key, that value is collected.

    :param node_id: id of the node whose route is inspected.
    :param tree: object exposing ``route_to_node(node_id)``.
    :return: list of collected ``reactor_id`` values, in the reverse of the
        order in which ``route_to_node`` yields the nodes.
    """
    found_rules_ids = []
    for node in tree.route_to_node(node_id):
        precursor = node.new_precursors[0]
        if len(precursor) != 0 and 'reactor_id' in precursor.molecule.meta:
            found_rules_ids.append(precursor.molecule.meta['reactor_id'])
    found_rules_ids.reverse()
    return found_rules_ids
23
+
24
def save_smarts(mol_id, config, reactions_dict):
    """
    Write the reactions of every node to ``smarts/smarts_mol_<mol_id>_<config>.txt``.

    For each ``node_id -> reactions`` entry the file receives one line with
    the node id followed by one line per reaction.

    :param mol_id: molecule identifier used in the file name.
    :param config: configuration identifier used in the file name.
    :param reactions_dict: mapping of node id -> iterable of reactions.
    """
    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs('smarts', exist_ok=True)
    with open(f'smarts/smarts_mol_{mol_id}_{config}.txt', "w") as file:
        for node_id, reactions in reactions_dict.items():
            file.write(f"{node_id}\n")
            file.writelines(f"{reaction}\n" for reaction in reactions)
30
+
31
+
32
def get_highest_route_nodes(tree, node_dict):
    """
    For every group of nodes, keep the node ids with the best route score.

    Scores come from ``tree.route_score(node_id)`` and are rounded to three
    decimals before comparison, so ids whose rounded scores tie are all kept.

    :param tree: object exposing ``route_score(node_id)``.
    :param node_dict: mapping of key -> iterable of node ids.
    :return: mapping of key -> list of best-scoring node ids (empty list for
        an empty group).
    """
    highest_nodes = {}
    for key, node_ids in node_dict.items():
        top_score, winners = float('-inf'), []
        for node_id in node_ids:
            score = round(tree.route_score(node_id), 3)
            if score == top_score:
                winners.append(node_id)
            elif score > top_score:
                # Strictly better score: restart the list of winners.
                top_score, winners = score, [node_id]
        highest_nodes[key] = winners
    return highest_nodes
46
+
47
 
48
  class TreeWrapper:
49
+
50
+ BASE_DIR = 'forest'
51
+
52
+ def __init__(self, tree, mol_id, config):
53
+ """Initializes the TreeWrapper."""
54
  self.tree = tree
55
+ self.mol_id = mol_id
56
+ self.config = config
57
+ # Ensure the directory exists before creating the filename
58
+ os.makedirs(self.BASE_DIR, exist_ok=True)
59
+ self.filename = os.path.join(self.BASE_DIR, f'tree_{mol_id}_{config}.pkl')
60
 
61
  def __getstate__(self):
62
  state = self.__dict__.copy()
 
63
  tree_state = self.tree.__dict__.copy()
64
  # Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
65
  if '_tqdm' in tree_state:
 
67
  for attr in ['policy_network', 'value_network']:
68
  if attr in tree_state:
69
  tree_state[attr] = None
 
70
  state['tree_state'] = tree_state
 
71
  del state['tree']
72
  return state
73
 
74
  def __setstate__(self, state):
 
75
  tree_state = state.pop('tree_state')
 
76
  self.__dict__.update(state)
 
77
  new_tree = Tree.__new__(Tree)
78
  new_tree.__dict__.update(tree_state)
79
+ self.tree = new_tree
80
+
81
+ def save_tree(self):
82
+ """Saves the TreeWrapper instance (including the tree state) to a file."""
83
+ try:
84
+ with open(self.filename, 'wb') as f:
85
+ pickle.dump(self, f)
86
+ print(f"Tree wrapper for mol_id '{self.mol_id}', config '{self.config}' saved to '{self.filename}'.")
87
+ except Exception as e:
88
+ print(f"Error saving tree to {self.filename}: {e}")
89
+
90
+ @classmethod
91
+ def load_tree_from_id(cls, mol_id, config):
92
+ """
93
+ Loads a Tree object from a saved file using mol_id and config.
94
+
95
+ Args:
96
+ mol_id: The molecule ID used for saving.
97
+ config: The configuration used for saving.
98
+
99
+ Returns:
100
+ The loaded Tree object, or None if loading fails.
101
+ """
102
+ filename = os.path.join(cls.BASE_DIR, f'tree_{mol_id}_{config}.pkl')
103
+ print(f"Attempting to load tree from: {filename}")
104
+ try:
105
+ # Ensure the 'Tree' class is defined in the current scope
106
+ if 'Tree' not in globals() and 'Tree' not in locals():
107
+ raise NameError("The 'Tree' class definition is required to load the object.")
108
+
109
+ with open(filename, 'rb') as f:
110
+ loaded_wrapper = pickle.load(f) # This implicitly calls __setstate__
111
+
112
+ # Check if the loaded object is indeed a TreeWrapper instance (optional sanity check)
113
+ if not isinstance(loaded_wrapper, cls):
114
+ print(f"Warning: Loaded object from {filename} is not a TreeWrapper instance.")
115
+ return None # Or raise an error
116
+
117
+ print(f"Tree object for mol_id '{mol_id}', config '{config}' successfully loaded from '{filename}'.")
118
+ # The __setstate__ method already reconstructed the tree inside the wrapper
119
+ return loaded_wrapper.tree
120
+
121
+ except FileNotFoundError:
122
+ print(f"Error: File not found at {filename}")
123
+ return None
124
+ except (pickle.UnpicklingError, EOFError) as e:
125
+ print(f"Error: Could not unpickle file {filename}. It might be corrupted or empty. Details: {e}")
126
+ return None
127
+ except NameError as e:
128
+ print(f"Error during loading: {e}. Ensure 'Tree' class is defined.")
129
+ return None
130
+ except Exception as e:
131
+ print(f"An unexpected error occurred loading tree from {filename}: {e}")
132
+ return None
133
+
134
def generate_cluster_html(
    tree: Tree,
    cluster_node_ids: list,
    cluster_num: int,
    rg_cgrs_dict: dict,
    aam: bool = False,
) -> str:
    """
    Generate an HTML report page for the synthesis routes of one cluster.

    :param tree: The built MCTS tree.
    :param cluster_node_ids: List of route node IDs belonging to this cluster.
        IDs that are missing from ``tree.nodes`` or whose node is not solved
        are ignored.
    :param cluster_num: The identifier of this cluster (used in title/header).
    :param rg_cgrs_dict: Mapping of route node ID -> RG-CGR. The entry of the
        first valid route is depicted as the cluster representative; its
        ``clean2d()`` is called, so the stored object is modified.
    :param aam: If True, depict atom-to-atom mapping in route SVGs.
    :return: The complete HTML report as a string. A short notice page is
        returned instead when the input is invalid or no valid route remains.
    """
    # Function-scope import: only used to escape exception text below.
    from html import escape

    # --- Depict settings (best effort: a failure is reported, not raised) ---
    try:
        MoleculeContainer.depict_settings(aam=bool(aam))
    except Exception as e:
        print(f"Warning: Error setting MoleculeContainer depict settings: {e}")

    # --- Validate input ---
    if not isinstance(cluster_node_ids, list):
        return "<html><body>Error: cluster_node_ids must be a list.</body></html>"
    if not tree or not isinstance(tree, Tree):
        return "<html><body>Error: Invalid tree object provided.</body></html>"

    # Keep only node IDs that are present in the tree and solved.
    valid_routes_in_cluster = [
        node_id
        for node_id in cluster_node_ids
        if node_id in tree.nodes and tree.nodes[node_id].is_solved()
    ]

    if not valid_routes_in_cluster:
        return f"""
    <!doctype html><html lang="en"><head><meta charset="utf-8">
    <title>Cluster {cluster_num} Report</title></head><body>
    <h3>Cluster {cluster_num} Report</h3>
    <p>No valid/solved routes found for this cluster.</p>
    </body></html>"""

    # --- HTML templates & tags ---
    td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
    font_head = "<font style='font-weight: bold; font-size: 18px'>"
    font_normal = "<font style='font-weight: normal; font-size: 18px'>"
    font_close = "</font>"

    template_begin = f"""
    <!doctype html>
    <html lang="en">
    <head>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
    rel="stylesheet"
    integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
    crossorigin="anonymous">
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Cluster {cluster_num} Routes Report</title>
    <style>
    /* Optional: Add some basic styling */
    .table {{ border-collapse: collapse; width: 100%; }}
    th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
    tr:nth-child(even) {{ background-color: #f2f2f2; }}
    caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
    svg {{ max-width: 100%; height: auto; }} /* Make SVGs responsive */
    </style>
    </head>
    <body>
    <div class="container"> """

    template_end = """
    </div> <script
    src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
    integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
    crossorigin="anonymous">
    </script>
    </body>
    </html>
    """

    # Legend marker; the "rgb()" placeholder is replaced by a concrete colour.
    box_mark = """
    <svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
    <circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
    </svg>
    """

    # --- Build HTML table (pieces are collected and joined once) ---
    pieces = [
        f"""
    <table class="table table-striped table-hover caption-top">
    <caption><h3>Retrosynthetic Routes Report - Cluster {cluster_num}</h3></caption>
    <tbody>"""
    ]

    try:
        target_smiles_str = str(tree.nodes[1].curr_precursor) if 1 in tree.nodes else "N/A"
    except Exception:
        target_smiles_str = "Error retrieving target SMILES"
    pieces.append(f"<tr>{td}{font_normal}Target Molecule: {target_smiles_str}{font_close}</td></tr>")
    pieces.append(f"<tr>{td}{font_normal}Cluster Number: {cluster_num}{font_close}</td></tr>")
    pieces.append(
        f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes_in_cluster)}{font_close} routes</td></tr>"
    )

    # --- RG-CGR image of the first valid route ---
    # The list is non-empty here because of the early return above.
    first_route_id = valid_routes_in_cluster[0]
    rg_cell = (
        f"<tr>{td}{font_normal}Cluster Representative RG-CGR "
        f"(from Route {first_route_id}):{font_close}<br>"
    )
    if rg_cgrs_dict and first_route_id in rg_cgrs_dict:
        try:
            rg_cgr = rg_cgrs_dict[first_route_id]
            rg_cgr.clean2d()
            rg_cgr_svg = rg_cgr.depict()

            # Basic sanity check that the depiction looks like SVG.
            if rg_cgr_svg.strip().startswith("<svg"):
                pieces.append(f"{rg_cell}{rg_cgr_svg}</td></tr>")
            else:
                pieces.append(f"{rg_cell}<i>Invalid SVG format retrieved.</i></td></tr>")
                print(f"Warning: Expected SVG for RG-CGR of node {first_route_id}, but got: {rg_cgr_svg[:100]}...")
        except Exception as e:
            # Exception text may contain '<' or '&'; escape it to keep the markup valid.
            pieces.append(
                f"{rg_cell}<i>Error retrieving/displaying RG-CGR: {escape(str(e))}</i></td></tr>"
            )
    else:
        pieces.append(f"{rg_cell}<i>Not found in provided RG-CGR dictionary.</i></td></tr>")

    # --- Legend ---
    pieces.append(
        f"""
    <tr>{td}
    <div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
    <span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
    <span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
    <span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
    </div>
    </td></tr>
    """
    )

    # --- Routes of this cluster ---
    for route_id in valid_routes_in_cluster:
        try:
            svg = get_route_svg(tree, route_id)
            full_route = tree.synthesis_route(route_id)
            reactions = "".join(
                f"<b>Step {i}:</b> {str(synth_step)}<br>"
                for i, synth_step in enumerate(full_route, start=1)
            )
            route_score = round(tree.route_score(route_id), 3)

            pieces.append(
                f'<tr style="line-height: 1.8;">{td}{font_head}Route {route_id} | '
                f"Steps: {len(full_route)} | "
                f"Score: {route_score}{font_close}</td></tr>"
            )
            pieces.append(f"<tr>{td}{svg if svg else '<i>Error generating route visualization</i>'}</td></tr>")
            pieces.append(f"<tr>{td}{reactions if reactions else '<i>No reaction steps found</i>'}</td></tr>")
        except Exception as e:
            # One failing route must not abort the whole report.
            pieces.append(
                f'<tr><td colspan="1" style="color: red;">Error processing route {route_id}: {escape(str(e))}</td></tr>'
            )

    pieces.append("</tbody></table>")

    # --- Combine and return full HTML ---
    return template_begin + "".join(pieces) + template_end
cluster/visualize.py CHANGED
@@ -1,13 +1,187 @@
1
  import os
2
  import re
 
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
- from collections import Counter
6
- from synplan.utils.visualisation import get_route_svg
7
  import seaborn as sns
 
 
 
 
 
 
 
8
 
9
- def pie_chart(cluster_sizes):
10
- labels = [f'Cluster {i+1}' for i in range(len(cluster_sizes))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  sns.set_style("whitegrid")
13
 
@@ -16,11 +190,37 @@ def pie_chart(cluster_sizes):
16
  cluster_sizes, labels=None, autopct='%1.1f%%', colors=sns.color_palette("pastel"),
17
  startangle=140, wedgeprops={'edgecolor': 'black'}
18
  )
19
- ax.legend(wedges, labels, title="Clusters", loc="center left", bbox_to_anchor=(1, 0.5))
 
 
 
 
 
 
20
 
21
  # plt.show()
22
  return fig
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def distribution_by_depth(tree, complex_cgr_dict):
26
  if len(complex_cgr_dict) == 0:
@@ -32,7 +232,7 @@ def distribution_by_depth(tree, complex_cgr_dict):
32
  depths[n] = len(reactions)
33
  return depths
34
 
35
- def histogram_by_depth(depths, mol_id, config):
36
 
37
  if len(depths) == 0:
38
  print('Error no depths')
@@ -47,8 +247,10 @@ def histogram_by_depth(depths, mol_id, config):
47
  plt.ylabel('Frequency')
48
  plt.title(f'Frequency Histogram of Number of reactions in one tree of total {len(depths)}')
49
  plt.xticks(bins)
50
- # plt.show()
51
- plt.savefig(f'histograms/by_depth_mol{mol_id}_{config}.png', dpi=100)
 
 
52
 
53
 
54
  def group_routes_by_depth(depths):
@@ -244,7 +446,7 @@ def create_route_svg_cluster(tree, node_ids, mol_id, config, depths, cluster_num
244
 
245
  print(f"Saved: {path_name}")
246
 
247
- def save_route_images(tree, depths, mol_id=1, config=1, cluster_dict=None):
248
  """
249
  Save route images grouped by depth and/or cluster.
250
 
 
1
  import os
2
  import re
3
+ from collections import Counter
4
  import numpy as np
5
  import matplotlib.pyplot as plt
 
 
6
  import seaborn as sns
7
+ from IPython.display import SVG, display
8
+
9
+ import io
10
+ import sys
11
+
12
+ from synplan.utils.visualisation import get_route_svg
13
+ from scipy.cluster.hierarchy import dendrogram
14
 
15
+ from .reduced_g_cgr import extract_strategic_bonds, compare_rg_cgr_by_strategic_bonds
16
+
17
def report_2_dissimilar(similarity_df, tree, rg_cgrs_dict):
    """
    Display the two routes with the lowest similarity value.

    The pair is the index/column label pair of the minimum of
    ``similarity_df.stack()``. For each of the two routes the RG-CGR
    depiction, its strategic bonds and the route depiction are shown, followed
    by a strategic-bond comparison of the two RG-CGRs.

    :param similarity_df: DataFrame of pairwise route similarities.
    :param tree: tree passed on to ``get_route_svg``.
    :param rg_cgrs_dict: mapping of route id -> RG-CGR; ``clean2d()`` is
        called on the two selected entries.
    """
    row_index, col_index = similarity_df.stack().idxmin()
    tanimoto = "%.2f" % similarity_df[row_index][col_index]
    print(f'Most dissimilar routes are {row_index} and {col_index}, Tanimoto index = {tanimoto}')

    shown = []
    for route_id in (row_index, col_index):
        print('Route ID', route_id)
        rg_cgr = rg_cgrs_dict[route_id]
        rg_cgr.clean2d()
        display(SVG(rg_cgr.depict()))
        extract_strategic_bonds(rg_cgr)
        display(SVG(get_route_svg(tree, route_id)))
        shown.append(rg_cgr)

    print('Summary:')
    compare_rg_cgr_by_strategic_bonds(shown[0], shown[1])
40
+
41
def save_clusters_html(clusters, best_by_score, tree, rg_cgrs_dict, mol_id, config):
    """
    Write an HTML overview of all clusters to
    ``final_clusters/htmls/mol_<mol_id>_<config>.html``.

    For every cluster the page contains its size, the example route
    (``best_by_score[cluster_num][0]``) with its score and route SVG, the
    RG-CGR depiction of that route and the text printed by
    ``extract_strategic_bonds``.

    :param clusters: mapping of cluster id -> collection of route node ids.
    :param best_by_score: mapping of cluster id -> sequence whose first item
        is the route used as example for the cluster.
    :param tree: tree exposing ``route_score`` and accepted by ``get_route_svg``.
    :param rg_cgrs_dict: mapping of route id -> RG-CGR; ``clean2d()`` is
        called on the entries that are depicted.
    :param mol_id: molecule identifier used in the file name.
    :param config: configuration identifier used in the file name.
    """
    from contextlib import redirect_stdout

    # The file lives in a nested directory; create the whole path up front.
    output_dir = os.path.join("final_clusters", "htmls")
    os.makedirs(output_dir, exist_ok=True)

    html_parts = []
    for cluster_num, node_id_list in clusters.items():
        best_route_in_cluster = best_by_score[cluster_num][0]
        score = round(tree.route_score(best_route_in_cluster), 3)
        parts = [
            f"{cluster_num} ||| Size: {len(node_id_list)}\n",
            f"Example: {best_route_in_cluster} Route score: {score}\n",
            get_route_svg(tree, best_route_in_cluster) + "\n",
            "The RG-CGR:\n",
        ]

        rg_cgr = rg_cgrs_dict[best_route_in_cluster]
        rg_cgr.clean2d()
        parts.append(rg_cgr.depict() + "\n")

        # extract_strategic_bonds reports by printing; capture that text.
        # redirect_stdout restores sys.stdout even if the call raises.
        captured = io.StringIO()
        with redirect_stdout(captured):
            extract_strategic_bonds(rg_cgr)
        parts.append(captured.getvalue() + "\n")

        # <pre> keeps the line structure of the captured text.
        html_parts.append(
            f'<div class="cluster" style="margin-bottom: 2em;"><pre>{"".join(parts)}</pre></div>'
        )

    html_content = f"""
    <html>
    <head>
    <meta charset="utf-8">
    <title>Captured Cluster Outputs</title>
    </head>
    <body>
    {''.join(html_parts)}
    </body>
    </html>
    """

    output_path = os.path.join(output_dir, f"mol_{mol_id}_{config}.html")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(html_content)
97
+
98
def report_2_dissimilar_to_html(similarity_df, tree, rg_cgrs_dict, mol_id=1, config=2, output_filename=None):
    """
    Generate an HTML report of the two most dissimilar routes.

    The pair is the index/column label pair of the minimum of
    ``similarity_df.stack()``. For each of the two routes the report holds
    the RG-CGR depiction, the text printed by ``extract_strategic_bonds`` and
    the route depiction; a summary with the output of
    ``compare_rg_cgr_by_strategic_bonds`` closes the page.

    :param similarity_df: DataFrame of pairwise route similarities.
    :param tree: tree passed on to ``get_route_svg``.
    :param rg_cgrs_dict: mapping of route id -> RG-CGR; ``clean2d()`` is
        called on the two selected entries.
    :param mol_id: molecule identifier used in the default file name.
    :param config: configuration identifier used in the default file name.
    :param output_filename: path of the report; when None the report goes to
        ``dissimilars/report_dissimilar_mol_<mol_id>_<config>.html``.
    """
    from contextlib import redirect_stdout

    if output_filename is None:
        os.makedirs("./dissimilars", exist_ok=True)
        output_filename = f"dissimilars/report_dissimilar_mol_{mol_id}_{config}.html"
    else:
        # Honour the caller's path and make sure its directory exists.
        parent_dir = os.path.dirname(output_filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Identify the two most dissimilar routes.
    row_index, col_index = similarity_df.stack().idxmin()
    tanimoto = "%.2f" % similarity_df[row_index][col_index]
    headline = f'Most dissimilar routes are {row_index} and {col_index}, Tanimoto index = {tanimoto}\n'

    html_parts = []
    # The headline is shown in the first route section only.
    for route_id, intro in ((row_index, headline), (col_index, "")):
        rg_cgr = rg_cgrs_dict[route_id]
        rg_cgr.clean2d()
        svg1 = rg_cgr.depict()
        svg2 = get_route_svg(tree, route_id)

        # redirect_stdout restores sys.stdout even if the call raises.
        captured = io.StringIO()
        with redirect_stdout(captured):
            extract_strategic_bonds(rg_cgr)

        html_parts.append(f"""
    <div class="route-section">
    <pre>{intro}</pre>
    <div class="svg1">{svg1}</div>
    <pre>{captured.getvalue()}</pre>
    <div class="svg2">{svg2}</div>
    </div>
    """)

    summary = io.StringIO()
    with redirect_stdout(summary):
        compare_rg_cgr_by_strategic_bonds(rg_cgrs_dict[row_index], rg_cgrs_dict[col_index])
    html_parts.append(f"<h2>Summary</h2><pre>{summary.getvalue()}</pre>")

    html_content = f"""
    <html>
    <head>
    <meta charset="utf-8">
    <title>Route Dissimilarity Report</title>
    </head>
    <body>
    {''.join(html_parts)}
    </body>
    </html>
    """

    with open(output_filename, "w", encoding="utf-8") as f:
        f.write(html_content)

    print(f"Report saved as {output_filename}")
182
+
183
+ def pie_chart(cluster_sizes, sub='', input_cluster_num=1, input_step_nums=None):
184
+ labels = [f'{sub}Cluster {i+1}' for i in range(len(cluster_sizes))]
185
 
186
  sns.set_style("whitegrid")
187
 
 
190
  cluster_sizes, labels=None, autopct='%1.1f%%', colors=sns.color_palette("pastel"),
191
  startangle=140, wedgeprops={'edgecolor': 'black'}
192
  )
193
+ ax.legend(wedges, labels, title=f"{sub}Clusters", loc="center left", bbox_to_anchor=(1, 0.5))
194
+ if sub == '':
195
+ ax.set_title(f"{sub}Cluster Size Distribution for {sum(cluster_sizes)} routes")
196
+ else:
197
+ ax.set_title(f"{sub}cluster Size Distribution for {sum(cluster_sizes)} routes in cluster {input_cluster_num} with number of steps {input_step_nums}")
198
+
199
+ plt.close(fig)
200
 
201
  # plt.show()
202
  return fig
203
 
204
def save_dendrogram(df, Z, mol_id, config):
    """
    Plot a hierarchical-clustering dendrogram and save it as
    ``dendrograms/av_link_mol<mol_id>_<config>.png``.

    :param df: DataFrame whose columns are used as leaf labels.
    :param Z: linkage matrix passed to ``scipy.cluster.hierarchy.dendrogram``.
    :param mol_id: molecule identifier used in the title and file name.
    :param config: configuration identifier used in the file name.
    """
    # The output directory is not guaranteed to exist on a fresh checkout.
    os.makedirs('dendrograms', exist_ok=True)

    plt.figure(figsize=(14, 7))

    dendrogram(Z, labels=df.columns, leaf_rotation=90)
    plt.title(f"Hierarchical Clustering Dendrogram for routes generated for molecule #{mol_id}")
    plt.xlabel("Route node id")
    plt.ylabel("Distance (1 - Similarity)")

    # Extend the y-axis downwards by 5% of its range to leave a gap below zero.
    ax = plt.gca()
    ymin, ymax = ax.get_ylim()
    gap = 0.05 * (ymax - ymin)
    ax.set_ylim(ymin - gap, ymax)
    ax.grid(False)
    ax.autoscale(enable=None, axis="x", tight=True)

    plt.tight_layout()
    plt.savefig(f'dendrograms/av_link_mol{mol_id}_{config}.png', dpi=100)
223
+
224
 
225
  def distribution_by_depth(tree, complex_cgr_dict):
226
  if len(complex_cgr_dict) == 0:
 
232
  depths[n] = len(reactions)
233
  return depths
234
 
235
+ def histogram_by_depth(depths, mol_id=1, config=1, save=False):
236
 
237
  if len(depths) == 0:
238
  print('Error no depths')
 
247
  plt.ylabel('Frequency')
248
  plt.title(f'Frequency Histogram of Number of reactions in one tree of total {len(depths)}')
249
  plt.xticks(bins)
250
+ if save:
251
+ plt.savefig(f'histograms/by_depth_mol{mol_id}_{config}.png', dpi=100)
252
+ else:
253
+ plt.show()
254
 
255
 
256
  def group_routes_by_depth(depths):
 
446
 
447
  print(f"Saved: {path_name}")
448
 
449
+ def save_route_images(tree, depths, mol_id, config, cluster_dict=None):
450
  """
451
  Save route images grouped by depth and/or cluster.
452