Spaces:
Running
Running
Gilmullin Almaz
commited on
Commit
·
f2f3593
1
Parent(s):
2830c50
debugging
Browse files- cluster/{super_cgr.py → generalized_cgr.py} +0 -0
- cluster/reduced_g_cgr.py +159 -0
- cluster/rs_cgr.py +0 -40
- cluster/utils.py +285 -8
- cluster/visualize.py +211 -9
cluster/{super_cgr.py → generalized_cgr.py}
RENAMED
|
File without changes
|
cluster/reduced_g_cgr.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from CGRtools.containers.bonds import DynamicBond
|
| 2 |
+
|
| 3 |
+
def reducing_g_cgr(g_cgr):
|
| 4 |
+
"""
|
| 5 |
+
Reduces a Generalized Condensed Graph of reaction (G-CGR) by performing the following steps:
|
| 6 |
+
|
| 7 |
+
1. Extracts substructures corresponding to connected components from the input G-CGR.
|
| 8 |
+
2. Selects the first substructure as the target to work on.
|
| 9 |
+
3. Iterates over all bonds in the target G-CGR:
|
| 10 |
+
- If a bond is identified as a "leaving group" (its primary order is None while its original order is defined),
|
| 11 |
+
the bond is removed.
|
| 12 |
+
- If a bond has a modified order (both primary and original orders are integers) and the primary order is less than the original,
|
| 13 |
+
the bond is deleted and then re-added with a new dynamic bond using the primary order (this updates the bond to the reduced form).
|
| 14 |
+
4. After bond modifications, re-extracts the substructure from the target G-CGR (now called the reduced G-CGR or RG-CGR).
|
| 15 |
+
5. If the charge distributions (_p_charges vs. _charges) differ, neutralizes the charges by setting them to zero.
|
| 16 |
+
|
| 17 |
+
Finally, returns the reduced G-CGR.
|
| 18 |
+
"""
|
| 19 |
+
# Get all connected components of the G-CGR as separate substructures.
|
| 20 |
+
cgr_prods = [g_cgr.substructure(c) for c in g_cgr.connected_components]
|
| 21 |
+
target_cgr = cgr_prods[0] # Choose the first substructure (main product) for further reduction.
|
| 22 |
+
|
| 23 |
+
# Iterate over each bond in the target G-CGR.
|
| 24 |
+
bond_items = list(target_cgr._bonds.items())
|
| 25 |
+
for atom1, bond_set in bond_items:
|
| 26 |
+
bond_set_items = list(bond_set.items())
|
| 27 |
+
for atom2, bond in bond_set_items:
|
| 28 |
+
|
| 29 |
+
# Removing bonds corresponding to leaving groups:
|
| 30 |
+
# If product bond order is None (indicating a leaving group) but an original bond order exists,
|
| 31 |
+
# delete the bond.
|
| 32 |
+
if bond.p_order is None and bond.order is not None:
|
| 33 |
+
target_cgr.delete_bond(atom1, atom2)
|
| 34 |
+
|
| 35 |
+
# For bonds that have been modified (not leaving groups) where the new (primary) order is less than the original:
|
| 36 |
+
# Remove the bond and re-add it using the DynamicBond with the primary order for both bond orders.
|
| 37 |
+
elif type(bond.p_order) is int and type(bond.order) is int and bond.p_order != bond.order:
|
| 38 |
+
p_order = int(bond.p_order)
|
| 39 |
+
target_cgr.delete_bond(atom1, atom2)
|
| 40 |
+
target_cgr.add_bond(atom1, atom2, DynamicBond(p_order, p_order))
|
| 41 |
+
|
| 42 |
+
# After modifying bonds, extract the reduced G-CGR from the target's connected components.
|
| 43 |
+
rg_cgr = [target_cgr.substructure(c) for c in target_cgr.connected_components][0]
|
| 44 |
+
|
| 45 |
+
# Neutralize charges if the primary charges and current charges differ.
|
| 46 |
+
if rg_cgr._p_charges != rg_cgr._charges:
|
| 47 |
+
for num, charge in rg_cgr._charges.items():
|
| 48 |
+
if charge != 0:
|
| 49 |
+
rg_cgr._atoms[num].charge = 0
|
| 50 |
+
|
| 51 |
+
return rg_cgr
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def process_all_rg_cgrs(g_cgrs_dict):
|
| 55 |
+
"""
|
| 56 |
+
Processes a collection (dictionary) of G-CGRs to generate their reduced forms (RG-CGRs).
|
| 57 |
+
|
| 58 |
+
Iterates over each G-CGR in the provided dictionary and applies the reducing_g_cgr function.
|
| 59 |
+
|
| 60 |
+
Note: There is an apparent bug in the code since it uses an undefined variable 'super_cgrs_dict'
|
| 61 |
+
and assigns to 'all_rs_cgrs' instead of 'all_rg_cgrs'. The intended behavior is to iterate over
|
| 62 |
+
the input dictionary (g_cgrs_dict) and store the reduced RG-CGR for each key.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
A dictionary where each key corresponds to the RG-CGR obtained from the input G-CGR.
|
| 66 |
+
"""
|
| 67 |
+
all_rg_cgrs = dict()
|
| 68 |
+
for num, cgr in g_cgrs_dict.items():
|
| 69 |
+
all_rg_cgrs[num] = reducing_g_cgr(cgr)
|
| 70 |
+
return all_rg_cgrs
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def report_strategic_bonds(result, target_cgr):
|
| 74 |
+
"""
|
| 75 |
+
Reports strategic bonds from a provided result list.
|
| 76 |
+
|
| 77 |
+
Each element in 'result' is expected to be a list with two elements:
|
| 78 |
+
- A tuple (atom pair) indicating the connected atoms.
|
| 79 |
+
- The primary bond order (p_order) associated with that bond.
|
| 80 |
+
|
| 81 |
+
The function prints out the atoms (accessed from target_cgr._atoms) and the bond order.
|
| 82 |
+
"""
|
| 83 |
+
for value in result:
|
| 84 |
+
atom_pair = value[0]
|
| 85 |
+
# Print the two atoms and the associated primary bond order.
|
| 86 |
+
print('\t', target_cgr._atoms[atom_pair[0]], target_cgr._atoms[atom_pair[1]], value[1])
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def extract_strategic_bonds(target_cgr, report=True):
|
| 90 |
+
"""
|
| 91 |
+
Extracts and optionally reports strategic bonds from a reduced G-CGR (RG-CGR).
|
| 92 |
+
|
| 93 |
+
Strategic bonds are defined as those with:
|
| 94 |
+
- No current bond order (order is None) but a defined primary bond order (p_order is not None).
|
| 95 |
+
|
| 96 |
+
The function goes through all bonds in the target_cgr, collects each unique bond (avoiding duplicates by using a set)
|
| 97 |
+
along with its primary bond order, and optionally prints them out.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
A list where each element is a pair: [bond_key (tuple of atom indices), primary bond order]
|
| 101 |
+
"""
|
| 102 |
+
result = []
|
| 103 |
+
seen = set()
|
| 104 |
+
# Loop through all bonds in the RG-CGR.
|
| 105 |
+
for atom1, bond_set in target_cgr._bonds.items():
|
| 106 |
+
for atom2, bond in bond_set.items():
|
| 107 |
+
# Check for strategic bonds (order undefined but p_order defined).
|
| 108 |
+
if bond.order is None and bond.p_order is not None:
|
| 109 |
+
# Create a sorted tuple of the atom pair to ensure uniqueness.
|
| 110 |
+
bond_key = tuple(sorted((atom1, atom2)))
|
| 111 |
+
if bond_key not in seen:
|
| 112 |
+
seen.add(bond_key)
|
| 113 |
+
result.append([bond_key, bond.p_order])
|
| 114 |
+
# If reporting is enabled, print the strategic bonds.
|
| 115 |
+
if report:
|
| 116 |
+
print('Strategic bonds in RG-CGR:')
|
| 117 |
+
report_strategic_bonds(result, target_cgr)
|
| 118 |
+
return result
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def compare_rg_cgr_by_strategic_bonds(rg_cgr1, rg_cgr2, report=True):
|
| 122 |
+
"""
|
| 123 |
+
Compares two reduced G-CGRs (RG-CGRs) based on their strategic bonds.
|
| 124 |
+
|
| 125 |
+
The function performs the following steps:
|
| 126 |
+
|
| 127 |
+
1. Extracts the list of strategic bonds for each RG-CGR.
|
| 128 |
+
2. Converts each list into a set of tuples (bond key and bond order) for easy set operations.
|
| 129 |
+
3. Identifies common bonds, and bonds unique to each RG-CGR.
|
| 130 |
+
4. Converts these sets back into lists for reporting.
|
| 131 |
+
5. Prints out the common bonds, bonds unique to the first RG-CGR, and bonds unique to the second RG-CGR.
|
| 132 |
+
|
| 133 |
+
The reporting uses the report_strategic_bonds function to output the atom details and bond orders.
|
| 134 |
+
"""
|
| 135 |
+
# Extract strategic bonds from both RG-CGRs without reporting.
|
| 136 |
+
l1 = extract_strategic_bonds(rg_cgr1, report=False)
|
| 137 |
+
l2 = extract_strategic_bonds(rg_cgr2, report=False)
|
| 138 |
+
|
| 139 |
+
# Create sets of (atom pair, bond order) tuples for both RG-CGRs.
|
| 140 |
+
set_l1 = { (tuple(item[0]), item[1]) for item in l1 }
|
| 141 |
+
set_l2 = { (tuple(item[0]), item[1]) for item in l2 }
|
| 142 |
+
|
| 143 |
+
# Identify common bonds and bonds unique to each list.
|
| 144 |
+
common = set_l1 & set_l2
|
| 145 |
+
unique_l1 = set_l1 - set_l2
|
| 146 |
+
unique_l2 = set_l2 - set_l1
|
| 147 |
+
|
| 148 |
+
# Convert the sets back to list format for reporting.
|
| 149 |
+
common_list = [ [atom_pair, order] for atom_pair, order in common ]
|
| 150 |
+
unique_l1_list = [ [atom_pair, order] for atom_pair, order in unique_l1 ]
|
| 151 |
+
unique_l2_list = [ [atom_pair, order] for atom_pair, order in unique_l2 ]
|
| 152 |
+
|
| 153 |
+
if report:
|
| 154 |
+
print("Common:")
|
| 155 |
+
report_strategic_bonds(common_list, rg_cgr1)
|
| 156 |
+
print("Unique for first RG-CGR:")
|
| 157 |
+
report_strategic_bonds(unique_l1_list, rg_cgr1)
|
| 158 |
+
print("Unique for second RG-CGR:")
|
| 159 |
+
report_strategic_bonds(unique_l2_list, rg_cgr1)
|
cluster/rs_cgr.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
from CGRtools.containers.bonds import DynamicBond
|
| 2 |
-
|
| 3 |
-
def s_cgr2rs_cgr(s_cgr):
|
| 4 |
-
cgr_prods = [s_cgr.substructure(c) for c in s_cgr.connected_components]
|
| 5 |
-
target_cgr = cgr_prods[0]
|
| 6 |
-
|
| 7 |
-
bond_items = list(target_cgr._bonds.items())
|
| 8 |
-
for atom1, bond_set in bond_items:
|
| 9 |
-
bond_set_items = list(bond_set.items())
|
| 10 |
-
for atom2, bond in bond_set_items:
|
| 11 |
-
|
| 12 |
-
# Leaving groups removal
|
| 13 |
-
if bond.p_order == None and bond.order is not None:
|
| 14 |
-
# print(atom1, atom2)
|
| 15 |
-
# print(bond)
|
| 16 |
-
target_cgr.delete_bond(atom1, atom2)
|
| 17 |
-
# target_cgr.clean2d()
|
| 18 |
-
# display(SVG(target_cgr.depict()))
|
| 19 |
-
|
| 20 |
-
## Modified bond, but not leaving group
|
| 21 |
-
elif type(bond.p_order) is int and type(bond.order) is int and bond.p_order < bond.order:
|
| 22 |
-
p_order = int(bond.p_order)
|
| 23 |
-
target_cgr.delete_bond(atom1, atom2)
|
| 24 |
-
target_cgr.add_bond(atom1, atom2, DynamicBond(p_order, p_order))
|
| 25 |
-
|
| 26 |
-
rs_cgr = [target_cgr.substructure(c) for c in target_cgr.connected_components][0]
|
| 27 |
-
|
| 28 |
-
# Charge neutralizer
|
| 29 |
-
if rs_cgr._p_charges != rs_cgr._charges:
|
| 30 |
-
for num, charge in rs_cgr._charges.items():
|
| 31 |
-
if charge != 0:
|
| 32 |
-
rs_cgr._atoms[num].charge = 0
|
| 33 |
-
|
| 34 |
-
return rs_cgr
|
| 35 |
-
|
| 36 |
-
def process_all_rs_cgrs(super_cgrs_dict):
|
| 37 |
-
all_rs_cgrs = dict()
|
| 38 |
-
for num, cgr in super_cgrs_dict.items():
|
| 39 |
-
all_rs_cgrs[num] = s_cgr2rs_cgr(cgr)
|
| 40 |
-
return all_rs_cgrs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cluster/utils.py
CHANGED
|
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
def extract_reactions(tree):
|
| 2 |
reactions_dict = {}
|
| 3 |
for node_id in set(tree.winning_nodes):
|
|
@@ -5,14 +11,55 @@ def extract_reactions(tree):
|
|
| 5 |
reactions_dict[node_id] = reactions
|
| 6 |
return reactions_dict
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class TreeWrapper:
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
self.tree = tree
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def __getstate__(self):
|
| 14 |
state = self.__dict__.copy()
|
| 15 |
-
# Save the state of the tree
|
| 16 |
tree_state = self.tree.__dict__.copy()
|
| 17 |
# Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
|
| 18 |
if '_tqdm' in tree_state:
|
|
@@ -20,18 +67,248 @@ class TreeWrapper:
|
|
| 20 |
for attr in ['policy_network', 'value_network']:
|
| 21 |
if attr in tree_state:
|
| 22 |
tree_state[attr] = None
|
| 23 |
-
# Store tree state separately
|
| 24 |
state['tree_state'] = tree_state
|
| 25 |
-
# Remove the actual tree instance from the state
|
| 26 |
del state['tree']
|
| 27 |
return state
|
| 28 |
|
| 29 |
def __setstate__(self, state):
|
| 30 |
-
# Retrieve the stored tree state
|
| 31 |
tree_state = state.pop('tree_state')
|
| 32 |
-
# Update the instance state
|
| 33 |
self.__dict__.update(state)
|
| 34 |
-
# Create a new Tree instance without calling __init__
|
| 35 |
new_tree = Tree.__new__(Tree)
|
| 36 |
new_tree.__dict__.update(tree_state)
|
| 37 |
-
self.tree = new_tree
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from synplan.mcts.tree import Tree
|
| 2 |
+
from synplan.utils.visualisation import get_route_svg
|
| 3 |
+
from CGRtools.containers import MoleculeContainer
|
| 4 |
+
import pickle
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
def extract_reactions(tree):
|
| 8 |
reactions_dict = {}
|
| 9 |
for node_id in set(tree.winning_nodes):
|
|
|
|
| 11 |
reactions_dict[node_id] = reactions
|
| 12 |
return reactions_dict
|
| 13 |
|
| 14 |
+
def extract_rules_from_route(node_id, tree):
|
| 15 |
+
nodes = tree.route_to_node(node_id)
|
| 16 |
+
found_rules_ids = []
|
| 17 |
+
for i in range(len(nodes)):
|
| 18 |
+
precursor = nodes[i].new_precursors[0]
|
| 19 |
+
if len(precursor) != 0:
|
| 20 |
+
if 'reactor_id' in precursor.molecule.meta.keys():
|
| 21 |
+
found_rules_ids.append(precursor.molecule.meta['reactor_id'])
|
| 22 |
+
return found_rules_ids[::-1]
|
| 23 |
+
|
| 24 |
+
def save_smarts(mol_id, config, reactions_dict):
|
| 25 |
+
with open(f'smarts/smarts_mol_{mol_id}_{config}.txt', "w") as file:
|
| 26 |
+
for node_id, reactions in reactions_dict.items():
|
| 27 |
+
file.write(f"{node_id}\n")
|
| 28 |
+
for reaction in reactions:
|
| 29 |
+
file.write(f"{reaction}\n")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_highest_route_nodes(tree, node_dict):
|
| 33 |
+
highest_nodes = {}
|
| 34 |
+
for key, node_ids in node_dict.items():
|
| 35 |
+
max_score = float('-inf')
|
| 36 |
+
best_nodes = []
|
| 37 |
+
for node_id in node_ids:
|
| 38 |
+
score = round(tree.route_score(node_id), 3)
|
| 39 |
+
if score > max_score:
|
| 40 |
+
max_score = score
|
| 41 |
+
best_nodes = [node_id]
|
| 42 |
+
elif score == max_score:
|
| 43 |
+
best_nodes.append(node_id)
|
| 44 |
+
highest_nodes[key] = best_nodes
|
| 45 |
+
return highest_nodes
|
| 46 |
+
|
| 47 |
|
| 48 |
class TreeWrapper:
|
| 49 |
+
|
| 50 |
+
BASE_DIR = 'forest'
|
| 51 |
+
|
| 52 |
+
def __init__(self, tree, mol_id, config):
|
| 53 |
+
"""Initializes the TreeWrapper."""
|
| 54 |
self.tree = tree
|
| 55 |
+
self.mol_id = mol_id
|
| 56 |
+
self.config = config
|
| 57 |
+
# Ensure the directory exists before creating the filename
|
| 58 |
+
os.makedirs(self.BASE_DIR, exist_ok=True)
|
| 59 |
+
self.filename = os.path.join(self.BASE_DIR, f'tree_{mol_id}_{config}.pkl')
|
| 60 |
|
| 61 |
def __getstate__(self):
|
| 62 |
state = self.__dict__.copy()
|
|
|
|
| 63 |
tree_state = self.tree.__dict__.copy()
|
| 64 |
# Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
|
| 65 |
if '_tqdm' in tree_state:
|
|
|
|
| 67 |
for attr in ['policy_network', 'value_network']:
|
| 68 |
if attr in tree_state:
|
| 69 |
tree_state[attr] = None
|
|
|
|
| 70 |
state['tree_state'] = tree_state
|
|
|
|
| 71 |
del state['tree']
|
| 72 |
return state
|
| 73 |
|
| 74 |
def __setstate__(self, state):
|
|
|
|
| 75 |
tree_state = state.pop('tree_state')
|
|
|
|
| 76 |
self.__dict__.update(state)
|
|
|
|
| 77 |
new_tree = Tree.__new__(Tree)
|
| 78 |
new_tree.__dict__.update(tree_state)
|
| 79 |
+
self.tree = new_tree
|
| 80 |
+
|
| 81 |
+
def save_tree(self):
|
| 82 |
+
"""Saves the TreeWrapper instance (including the tree state) to a file."""
|
| 83 |
+
try:
|
| 84 |
+
with open(self.filename, 'wb') as f:
|
| 85 |
+
pickle.dump(self, f)
|
| 86 |
+
print(f"Tree wrapper for mol_id '{self.mol_id}', config '{self.config}' saved to '{self.filename}'.")
|
| 87 |
+
except Exception as e:
|
| 88 |
+
print(f"Error saving tree to {self.filename}: {e}")
|
| 89 |
+
|
| 90 |
+
@classmethod
|
| 91 |
+
def load_tree_from_id(cls, mol_id, config):
|
| 92 |
+
"""
|
| 93 |
+
Loads a Tree object from a saved file using mol_id and config.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
mol_id: The molecule ID used for saving.
|
| 97 |
+
config: The configuration used for saving.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
The loaded Tree object, or None if loading fails.
|
| 101 |
+
"""
|
| 102 |
+
filename = os.path.join(cls.BASE_DIR, f'tree_{mol_id}_{config}.pkl')
|
| 103 |
+
print(f"Attempting to load tree from: {filename}")
|
| 104 |
+
try:
|
| 105 |
+
# Ensure the 'Tree' class is defined in the current scope
|
| 106 |
+
if 'Tree' not in globals() and 'Tree' not in locals():
|
| 107 |
+
raise NameError("The 'Tree' class definition is required to load the object.")
|
| 108 |
+
|
| 109 |
+
with open(filename, 'rb') as f:
|
| 110 |
+
loaded_wrapper = pickle.load(f) # This implicitly calls __setstate__
|
| 111 |
+
|
| 112 |
+
# Check if the loaded object is indeed a TreeWrapper instance (optional sanity check)
|
| 113 |
+
if not isinstance(loaded_wrapper, cls):
|
| 114 |
+
print(f"Warning: Loaded object from {filename} is not a TreeWrapper instance.")
|
| 115 |
+
return None # Or raise an error
|
| 116 |
+
|
| 117 |
+
print(f"Tree object for mol_id '{mol_id}', config '{config}' successfully loaded from '{filename}'.")
|
| 118 |
+
# The __setstate__ method already reconstructed the tree inside the wrapper
|
| 119 |
+
return loaded_wrapper.tree
|
| 120 |
+
|
| 121 |
+
except FileNotFoundError:
|
| 122 |
+
print(f"Error: File not found at {filename}")
|
| 123 |
+
return None
|
| 124 |
+
except (pickle.UnpicklingError, EOFError) as e:
|
| 125 |
+
print(f"Error: Could not unpickle file {filename}. It might be corrupted or empty. Details: {e}")
|
| 126 |
+
return None
|
| 127 |
+
except NameError as e:
|
| 128 |
+
print(f"Error during loading: {e}. Ensure 'Tree' class is defined.")
|
| 129 |
+
return None
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f"An unexpected error occurred loading tree from {filename}: {e}")
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
def generate_cluster_html(
|
| 135 |
+
tree: Tree,
|
| 136 |
+
cluster_node_ids: list,
|
| 137 |
+
cluster_num: int,
|
| 138 |
+
rg_cgrs_dict: dict, # <--- New parameter
|
| 139 |
+
aam: bool = False,
|
| 140 |
+
) -> str:
|
| 141 |
+
# ... (initial setup, validation, filtering routes remains the same) ...
|
| 142 |
+
"""
|
| 143 |
+
Generates an HTML page report for a specific cluster's synthesis routes.
|
| 144 |
+
|
| 145 |
+
:param tree: The built MCTS tree.
|
| 146 |
+
:param cluster_node_ids: List of route node IDs belonging to this cluster.
|
| 147 |
+
:param cluster_num: The identifier number for this cluster (used in title/header).
|
| 148 |
+
:param aam: If True, depict atom-to-atom mapping in route SVGs.
|
| 149 |
+
# :param scg_svg: Optional SVG string for the cluster's representative SCG.
|
| 150 |
+
:return: A string containing the complete HTML report.
|
| 151 |
+
"""
|
| 152 |
+
# --- Depict Settings (Optional: Keep if get_route_svg depends on it) ---
|
| 153 |
+
# Uncomment if MoleculeContainer is used and needed:
|
| 154 |
+
try:
|
| 155 |
+
if aam:
|
| 156 |
+
MoleculeContainer.depict_settings(aam=True)
|
| 157 |
+
else:
|
| 158 |
+
MoleculeContainer.depict_settings(aam=False)
|
| 159 |
+
except NameError:
|
| 160 |
+
# If MoleculeContainer isn't available/needed, just pass
|
| 161 |
+
pass
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"Warning: Error setting MoleculeContainer depict settings: {e}")
|
| 164 |
+
|
| 165 |
+
# --- Validate Input ---
|
| 166 |
+
if not isinstance(cluster_node_ids, list):
|
| 167 |
+
return "<html><body>Error: cluster_node_ids must be a list.</body></html>"
|
| 168 |
+
if not tree or not isinstance(tree, Tree):
|
| 169 |
+
return "<html><body>Error: Invalid tree object provided.</body></html>"
|
| 170 |
+
|
| 171 |
+
# Filter out node IDs not actually present or not solved in the tree
|
| 172 |
+
valid_routes_in_cluster = []
|
| 173 |
+
for node_id in cluster_node_ids:
|
| 174 |
+
if node_id in tree.nodes and tree.nodes[node_id].is_solved():
|
| 175 |
+
valid_routes_in_cluster.append(node_id)
|
| 176 |
+
# Optionally log or warn about invalid/unsolved nodes removed
|
| 177 |
+
|
| 178 |
+
if not valid_routes_in_cluster:
|
| 179 |
+
# Return a minimal HTML page indicating no valid routes
|
| 180 |
+
return f"""
|
| 181 |
+
<!doctype html><html lang="en"><head><meta charset="utf-8">
|
| 182 |
+
<title>Cluster {cluster_num} Report</title></head><body>
|
| 183 |
+
<h3>Cluster {cluster_num} Report</h3>
|
| 184 |
+
<p>No valid/solved routes found for this cluster.</p>
|
| 185 |
+
</body></html>"""
|
| 186 |
+
|
| 187 |
+
# --- HTML Templates & Tags ---
|
| 188 |
+
# (Keep tags like th, td, fonts as they were)
|
| 189 |
+
th = '<th style="text-align: left; background-color:#978785; border: 1px solid black; border-spacing: 0">'
|
| 190 |
+
td = '<td style="text-align: left; border: 1px solid black; border-spacing: 0">'
|
| 191 |
+
# font_red = "<font color='red' style='font-weight: bold'>" # Consider using CSS classes instead
|
| 192 |
+
# font_green = "<font color='light-green' style='font-weight: bold'>"
|
| 193 |
+
font_head = "<font style='font-weight: bold; font-size: 18px'>"
|
| 194 |
+
font_normal = "<font style='font-weight: normal; font-size: 18px'>"
|
| 195 |
+
font_close = "</font>"
|
| 196 |
+
|
| 197 |
+
template_begin = f"""
|
| 198 |
+
<!doctype html>
|
| 199 |
+
<html lang="en">
|
| 200 |
+
<head>
|
| 201 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
|
| 202 |
+
rel="stylesheet"
|
| 203 |
+
integrity="sha384-1BmE4kWBq78iYhFldvKuhfTAU6auU8tT94WrHftjDbrCEXSU1oBoqyl2QvZ6jIW3"
|
| 204 |
+
crossorigin="anonymous">
|
| 205 |
+
<meta charset="utf-8">
|
| 206 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 207 |
+
<title>Cluster {cluster_num} Routes Report</title>
|
| 208 |
+
<style>
|
| 209 |
+
/* Optional: Add some basic styling */
|
| 210 |
+
.table {{ border-collapse: collapse; width: 100%; }}
|
| 211 |
+
th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
|
| 212 |
+
tr:nth-child(even) {{ background-color: #f2f2f2; }}
|
| 213 |
+
caption {{ caption-side: top; font-size: 1.5em; margin: 1em 0; }}
|
| 214 |
+
svg {{ max-width: 100%; height: auto; }} /* Make SVGs responsive */
|
| 215 |
+
</style>
|
| 216 |
+
</head>
|
| 217 |
+
<body>
|
| 218 |
+
<div class="container"> """
|
| 219 |
+
|
| 220 |
+
template_end = """
|
| 221 |
+
</div> <script
|
| 222 |
+
src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"
|
| 223 |
+
integrity="sha384-ka7Sk0Gln4gmtz2MlQnikT1wXgYsOg+OMhuP+IlRH9sENBO0LRn5q+8nbTov4+1p"
|
| 224 |
+
crossorigin="anonymous">
|
| 225 |
+
</script>
|
| 226 |
+
</body>
|
| 227 |
+
</html>
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
box_mark = """
|
| 231 |
+
<svg width="30" height="30" viewBox="0 0 1 1" xmlns="http://www.w3.org/2000/svg" style="vertical-align: middle; margin-right: 5px;">
|
| 232 |
+
<circle cx="0.5" cy="0.5" r="0.5" fill="rgb()" fill-opacity="0.35" />
|
| 233 |
+
</svg>
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
# --- Build HTML Table ---
|
| 237 |
+
table = f"""
|
| 238 |
+
<table class="table table-striped table-hover caption-top">
|
| 239 |
+
<caption><h3>Retrosynthetic Routes Report - Cluster {cluster_num}</h3></caption>
|
| 240 |
+
<tbody>"""
|
| 241 |
+
|
| 242 |
+
try:
|
| 243 |
+
target_smiles_str = str(tree.nodes[1].curr_precursor) if 1 in tree.nodes else "N/A"
|
| 244 |
+
except Exception:
|
| 245 |
+
target_smiles_str = "Error retrieving target SMILES"
|
| 246 |
+
table += f"<tr>{td}{font_normal}Target Molecule: {target_smiles_str}{font_close}</td></tr>"
|
| 247 |
+
table += f"<tr>{td}{font_normal}Cluster Number: {cluster_num}{font_close}</td></tr>"
|
| 248 |
+
table += f"<tr>{td}{font_normal}Size of Cluster: {len(valid_routes_in_cluster)}{font_close} routes</td></tr>"
|
| 249 |
+
|
| 250 |
+
# --- Add RG-CGR Image ---
|
| 251 |
+
# Get the node_id of the first valid route in the cluster
|
| 252 |
+
first_route_id = valid_routes_in_cluster[0] if valid_routes_in_cluster else None
|
| 253 |
+
|
| 254 |
+
if first_route_id and rg_cgrs_dict and first_route_id in rg_cgrs_dict:
|
| 255 |
+
try:
|
| 256 |
+
rg_cgr = rg_cgrs_dict[first_route_id]
|
| 257 |
+
rg_cgr.clean2d()
|
| 258 |
+
rg_cgr_svg = rg_cgr.depict()
|
| 259 |
+
|
| 260 |
+
# Validate if it looks like SVG (basic check)
|
| 261 |
+
if rg_cgr_svg.strip().startswith("<svg"):
|
| 262 |
+
table += f"<tr>{td}{font_normal}Cluster Representative RG-CGR (from Route {first_route_id}):{font_close}<br>{rg_cgr_svg}</td></tr>"
|
| 263 |
+
else:
|
| 264 |
+
# Handle case where it's not SVG as expected
|
| 265 |
+
table += f"<tr>{td}{font_normal}Cluster Representative RG-CGR (from Route {first_route_id}):{font_close}<br><i>Invalid SVG format retrieved.</i></td></tr>"
|
| 266 |
+
print(f"Warning: Expected SVG for RG-CGR of node {first_route_id}, but got: {rg_cgr_svg[:100]}...") # Log a warning
|
| 267 |
+
|
| 268 |
+
except Exception as e:
|
| 269 |
+
table += f"<tr>{td}{font_normal}Cluster Representative RG-CGR (from Route {first_route_id}):{font_close}<br><i>Error retrieving/displaying RG-CGR: {e}</i></td></tr>"
|
| 270 |
+
else:
|
| 271 |
+
# Handle cases where RG-CGR data is missing
|
| 272 |
+
if first_route_id:
|
| 273 |
+
table += f"<tr>{td}{font_normal}Cluster Representative RG-CGR (from Route {first_route_id}):{font_close}<br><i>Not found in provided RG-CGR dictionary.</i></td></tr>"
|
| 274 |
+
else:
|
| 275 |
+
# This case shouldn't happen due to earlier check, but as fallback:
|
| 276 |
+
table += f"<tr>{td}{font_normal}Cluster Representative RG-CGR:{font_close}<br><i>No valid routes in cluster to select from.</i></td></tr>"
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# --- Legend ---
|
| 280 |
+
table += f"""
|
| 281 |
+
<tr>{td}
|
| 282 |
+
<div style="display: flex; align-items: center; flex-wrap: wrap; gap: 15px;">
|
| 283 |
+
<span>{box_mark.replace("rgb()", "rgb(152, 238, 255)")} Target Molecule</span>
|
| 284 |
+
<span>{box_mark.replace("rgb()", "rgb(240, 171, 144)")} Molecule Not In Stock</span>
|
| 285 |
+
<span>{box_mark.replace("rgb()", "rgb(155, 250, 179)")} Molecule In Stock</span>
|
| 286 |
+
</div>
|
| 287 |
+
</td></tr>
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
# --- Add Routes for this Cluster ---
|
| 291 |
+
for route_id in valid_routes_in_cluster:
|
| 292 |
+
try:
|
| 293 |
+
svg = get_route_svg(tree, route_id) # get SVG
|
| 294 |
+
full_route = tree.synthesis_route(route_id) # get route steps
|
| 295 |
+
reactions = ""
|
| 296 |
+
for i, synth_step in enumerate(full_route):
|
| 297 |
+
reactions += f"<b>Step {i + 1}:</b> {str(synth_step)}<br>"
|
| 298 |
+
route_score = round(tree.route_score(route_id), 3)
|
| 299 |
+
|
| 300 |
+
table += (
|
| 301 |
+
f'<tr style="line-height: 1.8;">{td}{font_head}Route {route_id} | ' # Use | for separation
|
| 302 |
+
f"Steps: {len(full_route)} | "
|
| 303 |
+
f"Score: {route_score}{font_close}</td></tr>"
|
| 304 |
+
)
|
| 305 |
+
table += f"<tr>{td}{svg if svg else '<i>Error generating route visualization</i>'}</td></tr>"
|
| 306 |
+
table += f"<tr>{td}{reactions if reactions else '<i>No reaction steps found</i>'}</td></tr>"
|
| 307 |
+
except Exception as e:
|
| 308 |
+
table += f'<tr><td colspan="1" style="color: red;">Error processing route {route_id}: {e}</td></tr>' # Use colspan if needed based on final table structure
|
| 309 |
+
|
| 310 |
+
table += "</tbody></table>"
|
| 311 |
+
|
| 312 |
+
# --- Combine and Return Full HTML ---
|
| 313 |
+
full_html = template_begin + table + template_end
|
| 314 |
+
return full_html
|
cluster/visualize.py
CHANGED
|
@@ -1,13 +1,187 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
-
from collections import Counter
|
| 6 |
-
from synplan.utils.visualisation import get_route_svg
|
| 7 |
import seaborn as sns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
sns.set_style("whitegrid")
|
| 13 |
|
|
@@ -16,11 +190,37 @@ def pie_chart(cluster_sizes):
|
|
| 16 |
cluster_sizes, labels=None, autopct='%1.1f%%', colors=sns.color_palette("pastel"),
|
| 17 |
startangle=140, wedgeprops={'edgecolor': 'black'}
|
| 18 |
)
|
| 19 |
-
ax.legend(wedges, labels, title="Clusters", loc="center left", bbox_to_anchor=(1, 0.5))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# plt.show()
|
| 22 |
return fig
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def distribution_by_depth(tree, complex_cgr_dict):
|
| 26 |
if len(complex_cgr_dict) == 0:
|
|
@@ -32,7 +232,7 @@ def distribution_by_depth(tree, complex_cgr_dict):
|
|
| 32 |
depths[n] = len(reactions)
|
| 33 |
return depths
|
| 34 |
|
| 35 |
-
def histogram_by_depth(depths, mol_id, config):
|
| 36 |
|
| 37 |
if len(depths) == 0:
|
| 38 |
print('Error no depths')
|
|
@@ -47,8 +247,10 @@ def histogram_by_depth(depths, mol_id, config):
|
|
| 47 |
plt.ylabel('Frequency')
|
| 48 |
plt.title(f'Frequency Histogram of Number of reactions in one tree of total {len(depths)}')
|
| 49 |
plt.xticks(bins)
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def group_routes_by_depth(depths):
|
|
@@ -244,7 +446,7 @@ def create_route_svg_cluster(tree, node_ids, mol_id, config, depths, cluster_num
|
|
| 244 |
|
| 245 |
print(f"Saved: {path_name}")
|
| 246 |
|
| 247 |
-
def save_route_images(tree, depths, mol_id
|
| 248 |
"""
|
| 249 |
Save route images grouped by depth and/or cluster.
|
| 250 |
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
+
from collections import Counter
|
| 4 |
import numpy as np
|
| 5 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
| 6 |
import seaborn as sns
|
| 7 |
+
from IPython.display import SVG, display
|
| 8 |
+
|
| 9 |
+
import io
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
from synplan.utils.visualisation import get_route_svg
|
| 13 |
+
from scipy.cluster.hierarchy import dendrogram
|
| 14 |
|
| 15 |
+
from .reduced_g_cgr import extract_strategic_bonds, compare_rg_cgr_by_strategic_bonds
|
| 16 |
+
|
| 17 |
+
def report_2_dissimilar(similarity_df, tree, rg_cgrs_dict):
|
| 18 |
+
min_index = similarity_df.stack().idxmin()
|
| 19 |
+
row_index, col_index = min_index
|
| 20 |
+
|
| 21 |
+
print(f'Most dissimilar routes are {row_index} and {col_index}, Tanimoto index = {"%.2f" % similarity_df[row_index][col_index]}')
|
| 22 |
+
|
| 23 |
+
print('Route ID', row_index)
|
| 24 |
+
rg_cgr_1 = rg_cgrs_dict[row_index]
|
| 25 |
+
rg_cgr_1.clean2d()
|
| 26 |
+
display(SVG(rg_cgr_1.depict()))
|
| 27 |
+
extract_strategic_bonds(rg_cgr_1)
|
| 28 |
+
display(SVG(get_route_svg(tree, row_index)))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
print('Route ID', col_index)
|
| 32 |
+
rg_cgr_2 = rg_cgrs_dict[col_index]
|
| 33 |
+
rg_cgr_2.clean2d()
|
| 34 |
+
display(SVG(rg_cgr_2.depict()))
|
| 35 |
+
extract_strategic_bonds(rg_cgr_2)
|
| 36 |
+
display(SVG(get_route_svg(tree, col_index)))
|
| 37 |
+
|
| 38 |
+
print('Summary:')
|
| 39 |
+
compare_rg_cgr_by_strategic_bonds(rg_cgr_1, rg_cgr_2)
|
| 40 |
+
|
| 41 |
+
def save_clusters_html(clusters, best_by_score, tree, rg_cgrs_dict, mol_id, config):
|
| 42 |
+
# Prepare a list to accumulate HTML parts for each cluster
|
| 43 |
+
os.makedirs("./final_clusters", exist_ok=True)
|
| 44 |
+
html_parts = []
|
| 45 |
+
|
| 46 |
+
# Loop over your clusters
|
| 47 |
+
for cluster_num, node_id_list in clusters.items():
|
| 48 |
+
parts = [] # to accumulate parts for this cluster
|
| 49 |
+
|
| 50 |
+
# Generate text output
|
| 51 |
+
best_route_in_cluster = best_by_score[cluster_num][0]
|
| 52 |
+
score = round(tree.route_score(best_route_in_cluster), 3)
|
| 53 |
+
parts.append(f"{cluster_num} ||| Size: {len(clusters[cluster_num])}\n")
|
| 54 |
+
parts.append(f"Example: {best_route_in_cluster} Route score: {score}\n")
|
| 55 |
+
|
| 56 |
+
# Insert the first SVG immediately after its marker text
|
| 57 |
+
svg1 = get_route_svg(tree, best_route_in_cluster)
|
| 58 |
+
parts.append(svg1 + "\n")
|
| 59 |
+
|
| 60 |
+
# Continue with additional text and SVGs
|
| 61 |
+
parts.append("The RG-CGR:\n")
|
| 62 |
+
rg_cgr = rg_cgrs_dict[best_route_in_cluster]
|
| 63 |
+
rg_cgr.clean2d()
|
| 64 |
+
svg2 = rg_cgr.depict()
|
| 65 |
+
parts.append(svg2 + "\n")
|
| 66 |
+
|
| 67 |
+
# Capture output from extract_strategic_bonds, if it prints something
|
| 68 |
+
buf = io.StringIO()
|
| 69 |
+
old_stdout = sys.stdout
|
| 70 |
+
sys.stdout = buf
|
| 71 |
+
extract_strategic_bonds(rg_cgr)
|
| 72 |
+
sys.stdout = old_stdout
|
| 73 |
+
strategic_text = buf.getvalue()
|
| 74 |
+
parts.append(strategic_text + "\n")
|
| 75 |
+
|
| 76 |
+
# Wrap this cluster's output in a <pre> tag for formatting and add some spacing
|
| 77 |
+
cluster_html = f'<div class="cluster" style="margin-bottom: 2em;"><pre>{"".join(parts)}</pre></div>'
|
| 78 |
+
html_parts.append(cluster_html)
|
| 79 |
+
|
| 80 |
+
# Combine all parts into a full HTML document
|
| 81 |
+
html_content = f"""
|
| 82 |
+
<html>
|
| 83 |
+
<head>
|
| 84 |
+
<meta charset="utf-8">
|
| 85 |
+
<title>Captured Cluster Outputs</title>
|
| 86 |
+
</head>
|
| 87 |
+
<body>
|
| 88 |
+
{''.join(html_parts)}
|
| 89 |
+
</body>
|
| 90 |
+
</html>
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Write the HTML content to a file
|
| 95 |
+
with open(f"final_clusters/htmls/mol_{mol_id}_{config}.html", "w", encoding="utf-8") as f:
|
| 96 |
+
f.write(html_content)
|
| 97 |
+
|
| 98 |
+
def report_2_dissimilar_to_html(similarity_df, tree, rg_cgrs_dict, mol_id=1, config=2,output_filename=None):
|
| 99 |
+
"""Generates an HTML report of the two most dissimilar routes based on a similarity DataFrame."""
|
| 100 |
+
|
| 101 |
+
os.makedirs("./dissimilars", exist_ok=True)
|
| 102 |
+
output_filename=f"dissimilars/report_dissimilar_mol_{mol_id}_{config}.html"
|
| 103 |
+
# Identify the two most dissimilar routes
|
| 104 |
+
min_index = similarity_df.stack().idxmin()
|
| 105 |
+
row_index, col_index = min_index
|
| 106 |
+
|
| 107 |
+
# Capture text output in a buffer
|
| 108 |
+
buf = io.StringIO()
|
| 109 |
+
old_stdout = sys.stdout
|
| 110 |
+
sys.stdout = buf
|
| 111 |
+
|
| 112 |
+
print(f'Most dissimilar routes are {row_index} and {col_index}, Tanimoto index = {"%.2f" % similarity_df[row_index][col_index]}')
|
| 113 |
+
|
| 114 |
+
# Store HTML content
|
| 115 |
+
html_parts = []
|
| 116 |
+
|
| 117 |
+
# Function to capture and append text, SVGs, and function outputs
|
| 118 |
+
def capture_route_info(route_id):
|
| 119 |
+
rg_cgr = rg_cgrs_dict[route_id]
|
| 120 |
+
rg_cgr.clean2d()
|
| 121 |
+
|
| 122 |
+
# Capture the first SVG (RG-CGR depiction)
|
| 123 |
+
svg1 = rg_cgr.depict()
|
| 124 |
+
|
| 125 |
+
# Capture the second SVG (Route depiction)
|
| 126 |
+
svg2 = get_route_svg(tree, route_id)
|
| 127 |
+
|
| 128 |
+
# Capture output of extract_strategic_bonds
|
| 129 |
+
buf_extract = io.StringIO()
|
| 130 |
+
sys.stdout = buf_extract
|
| 131 |
+
extract_strategic_bonds(rg_cgr)
|
| 132 |
+
sys.stdout = old_stdout
|
| 133 |
+
extract_output = buf_extract.getvalue()
|
| 134 |
+
|
| 135 |
+
# Store text + SVGs in HTML format
|
| 136 |
+
html_parts.append(f"""
|
| 137 |
+
<div class="route-section">
|
| 138 |
+
<pre>{buf.getvalue()}</pre>
|
| 139 |
+
<div class="svg1">{svg1}</div>
|
| 140 |
+
<pre>{extract_output}</pre>
|
| 141 |
+
<div class="svg2">{svg2}</div>
|
| 142 |
+
</div>
|
| 143 |
+
""")
|
| 144 |
+
buf.truncate(0) # Clear buffer for next route
|
| 145 |
+
buf.seek(0)
|
| 146 |
+
|
| 147 |
+
# Process the first route
|
| 148 |
+
capture_route_info(row_index)
|
| 149 |
+
|
| 150 |
+
# Process the second route
|
| 151 |
+
capture_route_info(col_index)
|
| 152 |
+
|
| 153 |
+
# Capture and store final summary
|
| 154 |
+
buf_summary = io.StringIO()
|
| 155 |
+
sys.stdout = buf_summary
|
| 156 |
+
compare_rg_cgr_by_strategic_bonds(rg_cgrs_dict[row_index], rg_cgrs_dict[col_index])
|
| 157 |
+
sys.stdout = old_stdout
|
| 158 |
+
summary_output = buf_summary.getvalue()
|
| 159 |
+
html_parts.append(f"<h2>Summary</h2><pre>{summary_output}</pre>")
|
| 160 |
+
|
| 161 |
+
# Restore standard stdout
|
| 162 |
+
sys.stdout = old_stdout
|
| 163 |
+
|
| 164 |
+
# Build the full HTML file
|
| 165 |
+
html_content = f"""
|
| 166 |
+
<html>
|
| 167 |
+
<head>
|
| 168 |
+
<meta charset="utf-8">
|
| 169 |
+
<title>Route Dissimilarity Report</title>
|
| 170 |
+
</head>
|
| 171 |
+
<body>
|
| 172 |
+
{''.join(html_parts)}
|
| 173 |
+
</body>
|
| 174 |
+
</html>
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
# Write the HTML file
|
| 178 |
+
with open(output_filename, "w", encoding="utf-8") as f:
|
| 179 |
+
f.write(html_content)
|
| 180 |
+
|
| 181 |
+
print(f"Report saved as {output_filename}")
|
| 182 |
+
|
| 183 |
+
def pie_chart(cluster_sizes, sub='', input_cluster_num=1, input_step_nums=None):
|
| 184 |
+
labels = [f'{sub}Cluster {i+1}' for i in range(len(cluster_sizes))]
|
| 185 |
|
| 186 |
sns.set_style("whitegrid")
|
| 187 |
|
|
|
|
| 190 |
cluster_sizes, labels=None, autopct='%1.1f%%', colors=sns.color_palette("pastel"),
|
| 191 |
startangle=140, wedgeprops={'edgecolor': 'black'}
|
| 192 |
)
|
| 193 |
+
ax.legend(wedges, labels, title=f"{sub}Clusters", loc="center left", bbox_to_anchor=(1, 0.5))
|
| 194 |
+
if sub == '':
|
| 195 |
+
ax.set_title(f"{sub}Cluster Size Distribution for {sum(cluster_sizes)} routes")
|
| 196 |
+
else:
|
| 197 |
+
ax.set_title(f"{sub}cluster Size Distribution for {sum(cluster_sizes)} routes in cluster {input_cluster_num} with number of steps {input_step_nums}")
|
| 198 |
+
|
| 199 |
+
plt.close(fig)
|
| 200 |
|
| 201 |
# plt.show()
|
| 202 |
return fig
|
| 203 |
|
| 204 |
+
def save_dendrogram(df, Z, mol_id, config):
|
| 205 |
+
plt.figure(figsize=(14, 7)) # figsize=(14, 7)
|
| 206 |
+
|
| 207 |
+
dendrogram(Z, labels=df.columns, leaf_rotation=90)
|
| 208 |
+
plt.title(f"Hierarchical Clustering Dendrogram for routes generated for molecule #{mol_id}")
|
| 209 |
+
plt.xlabel("Route node id")
|
| 210 |
+
plt.ylabel("Distance (1 - Similarity)")
|
| 211 |
+
|
| 212 |
+
# Get current y-axis limits and add a gap below zero
|
| 213 |
+
ax = plt.gca()
|
| 214 |
+
ymin, ymax = ax.get_ylim()
|
| 215 |
+
# Add a gap that is 5% of the current y-range below zero
|
| 216 |
+
gap = 0.05 * (ymax - ymin)
|
| 217 |
+
ax.set_ylim(ymin - gap, ymax)
|
| 218 |
+
ax.grid(False)
|
| 219 |
+
ax.autoscale(enable=None, axis="x", tight=True)
|
| 220 |
+
|
| 221 |
+
plt.tight_layout()
|
| 222 |
+
plt.savefig(f'dendrograms/av_link_mol{mol_id}_{config}.png', dpi=100)
|
| 223 |
+
|
| 224 |
|
| 225 |
def distribution_by_depth(tree, complex_cgr_dict):
|
| 226 |
if len(complex_cgr_dict) == 0:
|
|
|
|
| 232 |
depths[n] = len(reactions)
|
| 233 |
return depths
|
| 234 |
|
| 235 |
+
def histogram_by_depth(depths, mol_id=1, config=1, save=False):
|
| 236 |
|
| 237 |
if len(depths) == 0:
|
| 238 |
print('Error no depths')
|
|
|
|
| 247 |
plt.ylabel('Frequency')
|
| 248 |
plt.title(f'Frequency Histogram of Number of reactions in one tree of total {len(depths)}')
|
| 249 |
plt.xticks(bins)
|
| 250 |
+
if save:
|
| 251 |
+
plt.savefig(f'histograms/by_depth_mol{mol_id}_{config}.png', dpi=100)
|
| 252 |
+
else:
|
| 253 |
+
plt.show()
|
| 254 |
|
| 255 |
|
| 256 |
def group_routes_by_depth(depths):
|
|
|
|
| 446 |
|
| 447 |
print(f"Saved: {path_name}")
|
| 448 |
|
| 449 |
+
def save_route_images(tree, depths, mol_id, config, cluster_dict=None):
|
| 450 |
"""
|
| 451 |
Save route images grouped by depth and/or cluster.
|
| 452 |
|