Gilmullin Almaz commited on
Commit
57a9d9a
·
1 Parent(s): c45df67

added module codes to subcluster

Browse files
cluster/clustering.py CHANGED
@@ -29,6 +29,9 @@ def tanimoto_similarity_continuous(matrix_1, matrix_2):
29
  result = x_dot / (np.array([x2] * len_y2).T + np.array([y2] * len_x2) - x_dot)
30
  result[np.isnan(result)] = 0
31
 
 
 
 
32
  return result
33
 
34
  def calculate_fingerprints(cgrs, fingerprint_method):
 
29
  result = x_dot / (np.array([x2] * len_y2).T + np.array([y2] * len_x2) - x_dot)
30
  result[np.isnan(result)] = 0
31
 
32
+ if matrix_1.shape == matrix_2.shape:
33
+ np.fill_diagonal(result, 1.0)
34
+
35
  return result
36
 
37
  def calculate_fingerprints(cgrs, fingerprint_method):
cluster/subcluster.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
+ def split_ids_by_length(ids, data):
4
+ length_to_ids = defaultdict(list)
5
+
6
+ for id_ in ids:
7
+ if id_ in data:
8
+ length_to_ids[len(data[id_])].append(id_)
9
+ return length_to_ids
10
+
11
+
12
+ def group_ids_by_intermediate_products(ids, reactions_dict):
13
+ groups = defaultdict(list)
14
+ for id_ in ids:
15
+ # Build a key: a tuple of the first product for each reaction.
16
+ # This assumes that reactions_dict[id_] is a tuple of Reaction objects
17
+ # and each Reaction object has an attribute 'products' that is indexable.
18
+ key = tuple(reaction.products[0] for reaction in reactions_dict[id_])
19
+ groups[key].append(id_)
20
+ return list(groups.values())
21
+
22
+
23
+ def sublcuster_all(cluster_dict, reactions_dict):
24
+ subcluster_dict = {}
25
+ for num, cluster in cluster_dict.items():
26
+ step_split_dict = split_ids_by_length(cluster, reactions_dict)
27
+ subcluster = {}
28
+ for steps in step_split_dict.keys():
29
+ ids_to_group = step_split_dict[steps]
30
+ grouped_ids = group_ids_by_intermediate_products(ids_to_group, reactions_dict)
31
+ subcluster[steps] = grouped_ids
32
+ subcluster_dict[num] = subcluster
33
+ return subcluster_dict
cluster/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def extract_reactions(tree):
2
+ reactions_dict = {}
3
+ for node_id in set(tree.winning_nodes):
4
+ reactions = tree.synthesis_route(node_id)
5
+ reactions_dict[node_id] = reactions
6
+ return reactions_dict
7
+
8
+
9
+ class TreeWrapper:
10
+ def __init__(self, tree):
11
+ self.tree = tree
12
+
13
+ def __getstate__(self):
14
+ state = self.__dict__.copy()
15
+ # Save the state of the tree
16
+ tree_state = self.tree.__dict__.copy()
17
+ # Reset or remove non-pickleable attributes (e.g., _tqdm, policy_network, value_network)
18
+ if '_tqdm' in tree_state:
19
+ tree_state['_tqdm'] = True # Reset to a simple flag
20
+ for attr in ['policy_network', 'value_network']:
21
+ if attr in tree_state:
22
+ tree_state[attr] = None
23
+ # Store tree state separately
24
+ state['tree_state'] = tree_state
25
+ # Remove the actual tree instance from the state
26
+ del state['tree']
27
+ return state
28
+
29
+ def __setstate__(self, state):
30
+ # Retrieve the stored tree state
31
+ tree_state = state.pop('tree_state')
32
+ # Update the instance state
33
+ self.__dict__.update(state)
34
+ # Create a new Tree instance without calling __init__
35
+ new_tree = Tree.__new__(Tree)
36
+ new_tree.__dict__.update(tree_state)
37
+ self.tree = new_tree
cluster/visualize.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from collections import Counter
6
+ from synplan.utils.visualisation import get_route_svg
7
+ import seaborn as sns
8
+
9
+ def pie_chart(cluster_sizes):
10
+ labels = [f'Cluster {i+1}' for i in range(len(cluster_sizes))]
11
+
12
+ sns.set_style("whitegrid")
13
+
14
+ fig, ax = plt.subplots(figsize=(6, 6))
15
+ wedges, texts, autotexts = ax.pie(
16
+ cluster_sizes, labels=None, autopct='%1.1f%%', colors=sns.color_palette("pastel"),
17
+ startangle=140, wedgeprops={'edgecolor': 'black'}
18
+ )
19
+ ax.legend(wedges, labels, title="Clusters", loc="center left", bbox_to_anchor=(1, 0.5))
20
+
21
+ # plt.show()
22
+ return fig
23
+
24
+
25
+ def distribution_by_depth(tree, complex_cgr_dict):
26
+ if len(complex_cgr_dict) == 0:
27
+ print('Error: Empty dictionary')
28
+ return None
29
+ depths = np.zeros(len(complex_cgr_dict))
30
+ for n, node in enumerate(complex_cgr_dict.keys()):
31
+ reactions = tree.synthesis_route(node)
32
+ depths[n] = len(reactions)
33
+ return depths
34
+
35
+ def histogram_by_depth(depths, mol_id, config):
36
+
37
+ if len(depths) == 0:
38
+ print('Error no depths')
39
+ return None
40
+ # Count frequency of each depth
41
+ counter = Counter(depths)
42
+ bins, counts = zip(*sorted(counter.items()))
43
+
44
+ # Plot the histogram
45
+ plt.bar(bins, counts, width=0.5, color='skyblue', edgecolor='black')
46
+ plt.xlabel('Number of reactions')
47
+ plt.ylabel('Frequency')
48
+ plt.title(f'Frequency Histogram of Number of reactions in one tree of total {len(depths)}')
49
+ plt.xticks(bins)
50
+ # plt.show()
51
+ plt.savefig(f'histograms/by_depth_mol{mol_id}_{config}.png', dpi=100)
52
+
53
+
54
+ def group_routes_by_depth(depths):
55
+ """
56
+ Group route IDs by their reaction count (depth).
57
+
58
+ Args:
59
+ depths: Dictionary with node_ids as keys and reaction tuples as values
60
+
61
+ Returns:
62
+ dict: Dictionary with depths as keys and lists of node_ids as values
63
+ """
64
+ depth_groups = {}
65
+ for node_id, reactions in depths.items():
66
+ depth = len(reactions)
67
+ if depth not in depth_groups:
68
+ depth_groups[depth] = []
69
+ depth_groups[depth].append(node_id)
70
+ return depth_groups
71
+
72
+ def create_route_svg(tree, node_ids, mol_id, config, depths, depth=None):
73
+ """Create SVG file for specified routes with optimized spacing."""
74
+
75
+ # First pass: analyze all SVGs to find maximum width
76
+ max_width_cm = 0
77
+ all_route_svgs = [] # Store SVGs to avoid calling get_route_svg twice
78
+
79
+ for g in node_ids:
80
+ route_svg = get_route_svg(tree, g)
81
+ all_route_svgs.append(route_svg)
82
+
83
+ # Extract the actual SVG content
84
+ svg_match = re.search(r'<svg[^>]*>', route_svg)
85
+ if svg_match:
86
+ svg_header = svg_match.group(0)
87
+
88
+ # Try to get width from cm attribute
89
+ width_match = re.search(r'width="([0-9.]+)cm"', svg_header)
90
+ if width_match:
91
+ try:
92
+ width_cm = float(width_match.group(1))
93
+ max_width_cm = max(max_width_cm, width_cm)
94
+ except ValueError:
95
+ pass
96
+
97
+ # Convert cm to pixels (1cm ≈ 37.8 pixels)
98
+ CM_TO_PX = 37.8
99
+ max_width_px = max_width_cm * CM_TO_PX
100
+
101
+ # Add margins
102
+ left_margin = 50
103
+ right_margin = 100
104
+ composite_width = max_width_px + left_margin + right_margin
105
+
106
+ # Continue with SVG creation using calculated width
107
+ vertical_spacing = 20
108
+ text_height = 20
109
+ route_spacing = 250
110
+ current_y = 30
111
+ entries = []
112
+
113
+ size = len(node_ids)
114
+
115
+ for num, (g, route_svg_str) in enumerate(zip(node_ids, all_route_svgs), 1):
116
+ # Calculate dimensions
117
+ route_px_height = 200
118
+
119
+ # Create entry with optimized spacing
120
+ entry_parts = []
121
+ entry_parts.append(f'<g transform="translate({left_margin}, {current_y})">')
122
+ entry_parts.append(f' <text x="0" y="{text_height}" font-size="12" fill="black">{num} (Node ID: {g}, Number of reactions: {len(depths[g])})</text>')
123
+
124
+ inner_y = text_height + 25
125
+ entry_parts.append(f' <g transform="translate(0, {inner_y})">{route_svg_str}</g>')
126
+
127
+ total_entry_height = inner_y + route_px_height + 250
128
+ entry_parts.append('</g>')
129
+
130
+ entry_block = "\n".join(entry_parts)
131
+ entry_bottom_y = current_y + total_entry_height
132
+ entries.append((entry_block, entry_bottom_y))
133
+
134
+ current_y = entry_bottom_y + route_spacing - 50
135
+
136
+ # Create master SVG with adjusted dimensions
137
+ master_width = composite_width
138
+ master_height = current_y + vertical_spacing
139
+
140
+ final_parts = []
141
+ for entry_block, bottom_y in entries:
142
+ final_parts.append(entry_block)
143
+ final_parts.append(f'<line x1="0" y1="{bottom_y}" x2="{master_width}" y2="{bottom_y}" stroke="black" stroke-width="1" />')
144
+
145
+ master_svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="{master_width}" height="{master_height}" viewBox="0 0 {master_width} {master_height}">\n'
146
+ master_svg += "\n".join(final_parts)
147
+ master_svg += "\n</svg>"
148
+
149
+ # Save file with appropriate name
150
+ if depth is None:
151
+ path_name = f"./routes_img/mol_{mol_id}/mol{mol_id}_{config}_all_{size}.svg"
152
+ else:
153
+ path_name = f"./routes_img/mol_{mol_id}/mol{mol_id}_{config}_depth_{depth}_{size}.svg"
154
+
155
+ with open(path_name, "w") as f:
156
+ f.write(master_svg)
157
+
158
+ print(f"Saved: {path_name}")
159
+
160
+
161
+ def create_route_svg_cluster(tree, node_ids, mol_id, config, depths, cluster_num):
162
+ """
163
+ Create SVG file for specified routes with optimized spacing, grouped by cluster.
164
+ """
165
+ # First pass: analyze all SVGs to find maximum width
166
+ max_width_cm = 0
167
+ all_route_svgs = [] # Store SVGs to avoid calling get_route_svg twice
168
+
169
+ for g in node_ids:
170
+ route_svg = get_route_svg(tree, g)
171
+ all_route_svgs.append(route_svg)
172
+
173
+ # Extract the actual SVG content
174
+ svg_match = re.search(r'<svg[^>]*>', route_svg)
175
+ if svg_match:
176
+ svg_header = svg_match.group(0)
177
+
178
+ # Try to get width from cm attribute
179
+ width_match = re.search(r'width="([0-9.]+)cm"', svg_header)
180
+ if width_match:
181
+ try:
182
+ width_cm = float(width_match.group(1))
183
+ max_width_cm = max(max_width_cm, width_cm)
184
+ except ValueError:
185
+ pass
186
+
187
+ # Convert cm to pixels (1cm ≈ 37.8 pixels)
188
+ CM_TO_PX = 37.8
189
+ max_width_px = max_width_cm * CM_TO_PX
190
+
191
+ # Add margins
192
+ left_margin = 50
193
+ right_margin = 100
194
+ composite_width = max_width_px + left_margin + right_margin
195
+
196
+ # Continue with SVG creation using calculated width
197
+ vertical_spacing = 20
198
+ text_height = 20
199
+ route_spacing = 250
200
+ current_y = 30
201
+ entries = []
202
+
203
+ size = len(node_ids)
204
+
205
+ for num, (g, route_svg_str) in enumerate(zip(node_ids, all_route_svgs), 1):
206
+ # Calculate dimensions
207
+ route_px_height = 200
208
+
209
+ # Create entry with optimized spacing
210
+ entry_parts = []
211
+ entry_parts.append(f'<g transform="translate({left_margin}, {current_y})">')
212
+ entry_parts.append(f' <text x="0" y="{text_height}" font-size="12" fill="black">{num} (Node ID: {g}, Number of reactions: {len(depths[g])})</text>')
213
+
214
+ inner_y = text_height + 25
215
+ entry_parts.append(f' <g transform="translate(0, {inner_y})">{route_svg_str}</g>')
216
+
217
+ total_entry_height = inner_y + route_px_height + 350
218
+ entry_parts.append('</g>')
219
+
220
+ entry_block = "\n".join(entry_parts)
221
+ entry_bottom_y = current_y + total_entry_height
222
+ entries.append((entry_block, entry_bottom_y))
223
+
224
+ current_y = entry_bottom_y + route_spacing - 50
225
+
226
+ # Create master SVG with adjusted dimensions
227
+ master_width = composite_width
228
+ master_height = current_y + vertical_spacing
229
+
230
+ final_parts = []
231
+ for entry_block, bottom_y in entries:
232
+ final_parts.append(entry_block)
233
+ final_parts.append(f'<line x1="0" y1="{bottom_y}" x2="{master_width}" y2="{bottom_y}" stroke="black" stroke-width="1" />')
234
+
235
+ master_svg = f'<svg xmlns="http://www.w3.org/2000/svg" width="{master_width}" height="{master_height}" viewBox="0 0 {master_width} {master_height}">\n'
236
+ master_svg += "\n".join(final_parts)
237
+ master_svg += "\n</svg>"
238
+
239
+ # Save file with cluster-specific name
240
+ path_name = f"./routes_img/mol_{mol_id}/mol{mol_id}_{config}_cluster_{cluster_num}_{size}.svg"
241
+
242
+ with open(path_name, "w") as f:
243
+ f.write(master_svg)
244
+
245
+ print(f"Saved: {path_name}")
246
+
247
+ def save_route_images(tree, depths, mol_id=1, config=1, cluster_dict=None):
248
+ """
249
+ Save route images grouped by depth and/or cluster.
250
+
251
+ Args:
252
+ tree: Synthesis tree
253
+ routes: Dictionary of routes
254
+ depths: Dictionary of reaction depths
255
+ mol_id: Molecule ID
256
+ config: Configuration value
257
+ cluster_dict: Optional dictionary mapping cluster numbers to lists of node_ids
258
+ """
259
+ # Create directory if it doesn't exist
260
+ os.makedirs("./routes_img", exist_ok=True)
261
+ os.makedirs(f"./routes_img/mol_{mol_id}", exist_ok=True)
262
+
263
+ # Save complete image with all routes
264
+ all_node_ids = sorted(depths.keys())
265
+ create_route_svg(tree, all_node_ids, mol_id, config, depths)
266
+
267
+ # Group routes by depth and save separate images
268
+ depth_groups = group_routes_by_depth(depths)
269
+ for depth, node_ids in depth_groups.items():
270
+ create_route_svg(tree, sorted(node_ids), mol_id, config, depths, depth)
271
+
272
+ # If cluster dictionary is provided, save routes grouped by cluster
273
+ if cluster_dict is not None:
274
+ for cluster_num, node_ids in cluster_dict.items():
275
+ # Filter node_ids to only include those that exist in routes
276
+ valid_node_ids = [nid for nid in node_ids if nid in depths]
277
+ if valid_node_ids:
278
+ create_route_svg_cluster(tree, sorted(valid_node_ids),
279
+ mol_id, config, depths, cluster_num)