Gilmullin Almaz committed on
Commit
88aa6df
·
1 Parent(s): 2df9441

Refactoring till the unabstracted subclustering

Browse files
app.py CHANGED
@@ -20,8 +20,8 @@ from synplan.utils.loading import load_reaction_rules, load_building_blocks
20
  from synplan.utils.visualisation import generate_results_html, get_route_svg
21
 
22
 
23
- from cluster.super_cgr import *
24
- from cluster.rs_cgr import *
25
  from cluster.clustering import *
26
  from cluster.visualize import *
27
  from cluster.utils import *
@@ -119,15 +119,29 @@ if "tree" not in st.session_state:
119
  st.session_state.tree = None
120
  if "res" not in st.session_state:
121
  st.session_state.res = None
122
- if "num_clusters" not in st.session_state:
123
- st.session_state.num_clusters = 10
124
-
125
- if 'clustering_started' not in st.session_state:
126
- st.session_state.clustering_started = False
127
- if 'clusters_downloaded' not in st.session_state:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  st.session_state.clusters_downloaded = False
129
 
130
-
131
  intro_text = '''
132
  This is a demo of the graphical user interface of
133
  [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
@@ -154,26 +168,41 @@ molecule = st.text_input("SMILES:", DEFAULT_MOL)
154
  smile_code = st_ketcher(molecule)
155
  target_molecule = mol_from_smiles(smile_code)
156
 
157
- building_blocks_path = hf_hub_download(
158
- repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
159
- filename="building_blocks_em_sa_ln.smi",
160
- subfolder="building_blocks",
161
- local_dir="."
162
- )
163
-
164
- ranking_policy_weights_path = hf_hub_download(
165
- repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
166
- filename="ranking_policy_network.ckpt",
167
- subfolder="uspto/weights",
168
- local_dir="."
169
- )
170
-
171
- reaction_rules_path = hf_hub_download(
172
- repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
173
- filename="uspto_reaction_rules.pickle",
174
- subfolder="uspto",
175
- local_dir="."
176
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  st.header('Launch calculation')
179
  st.markdown(
@@ -198,7 +227,7 @@ with col_options_1:
198
  c_ucb = st.number_input("C coefficient of UCB", value=0.1, placeholder="Type a number...")
199
 
200
  with col_options_2:
201
- max_iterations = st.slider('Total number of MCTS iterations', min_value=50, max_value=300, value=100)
202
  max_depth = st.slider('Maximal number of reaction steps', min_value=3, max_value=9, value=6)
203
  min_mol_size = st.slider('Minimum size of a molecule to be precursor', min_value=0, max_value=7, value=0)
204
 
@@ -210,85 +239,147 @@ search_strategy = search_strategy_translator[search_strategy_input]
210
 
211
  submit_planning = st.button('Start retrosynthetic planning')
212
 
213
- # if submit_planning:
214
- if submit_planning and not st.session_state.planning_done:
215
- with st.status("Downloading data"):
216
- st.write("Downloading building blocks")
217
- building_blocks = load_building_blocks(building_blocks_path, standardize=False)
218
- st.write('Downloading reaction rules')
219
- reaction_rules = load_reaction_rules(reaction_rules_path)
220
- st.write('Loading policy network')
221
- policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
222
- policy_function = PolicyNetworkFunction(policy_config=policy_config)
223
-
224
- tree_config = TreeConfig(
225
- search_strategy=search_strategy,
226
- evaluation_type="rollout",
227
- max_iterations=max_iterations,
228
- max_depth=max_depth,
229
- min_mol_size=min_mol_size,
230
- init_node_value=0.5,
231
- ucb_type=ucb_type,
232
- c_ucb=c_ucb,
233
- silent=True
234
- )
235
-
236
- tree = Tree(
237
- target=target_molecule,
238
- config=tree_config,
239
- reaction_rules=reaction_rules,
240
- building_blocks=building_blocks,
241
- expansion_function=policy_function,
242
- evaluation_function=None,
243
- )
244
-
245
- mcts_progress_text = "Running retrosynthetic planning"
246
- mcts_bar = st.progress(0, text=mcts_progress_text)
247
- for step, (solved, node_id) in enumerate(tree):
248
- mcts_bar.progress(step / max_iterations, text=mcts_progress_text)
249
-
250
- res = extract_tree_stats(tree, target_molecule)
251
-
252
- # Store planning outputs in session_state so they persist
253
- st.session_state['tree'] = tree
254
- st.session_state['res'] = res
255
- st.session_state.planning_done = True
256
-
257
- # Display results if planning has been completed
258
- if st.session_state.planning_done and st.session_state.res is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  res = st.session_state.res
260
  tree = st.session_state.tree
261
 
262
- st.header('Results')
263
- if res["solved"]:
264
- # st.balloons()
265
-
 
 
 
 
 
266
  st.subheader("Examples of found retrosynthetic routes")
267
  image_counter = 0
268
  visualised_node_ids = set()
269
- for n, node_id in enumerate(sorted(set(tree.winning_nodes))):
270
- if image_counter == 3:
271
- break
272
- if n % 2 == 0 and node_id not in visualised_node_ids:
273
- visualised_node_ids.add(node_id)
274
- image_counter += 1
275
- num_steps = len(tree.synthesis_route(node_id))
276
- route_score = round(tree.route_score(node_id), 3)
277
- st.image(get_route_svg(tree, node_id),
278
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
279
-
 
 
 
 
 
 
 
 
 
 
 
 
280
  stat_col, download_col = st.columns(2, gap="medium")
281
  with stat_col:
282
  st.subheader("Statistics")
283
- df = pd.DataFrame(res, index=[0])
284
- st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])
 
 
 
 
 
 
 
 
 
 
285
  with download_col:
286
  st.subheader("Downloads")
287
- html_body = generate_results_html(tree, html_path=None, extended=True)
288
- dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
289
- dl_csv = download_button(pd.DataFrame(res, index=[0]),
290
- 'results_synplanner.csv', 'Download statistics as a csv file')
291
- st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  st.header("Clustering the retrosynthetic routes")
294
 
@@ -297,92 +388,361 @@ if st.session_state.planning_done and st.session_state.res is not None:
297
 
298
  cluster_box, _ = st.columns(2, gap="medium")
299
  with cluster_box:
300
- num_clusters = st.slider(
301
- 'Number of clusters to display',
302
  min_value=2,
303
- max_value=10,
304
- value=st.session_state['num_clusters'],
305
  key='cluster_slider'
306
  )
307
- st.session_state['num_clusters'] = num_clusters
308
-
309
- if st.button('Start clustering', key='submit_clustering'):
310
- st.session_state.clustering_started = True
311
- st.subheader("Examples of clusters")
312
- super_cgrs_dict = reassign_nums(tree)
313
-
314
- reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
315
-
316
- mfp = MorganFingerprint()
317
-
318
- results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
319
-
320
- clusters = results['clusters_dict']
321
-
322
- for cluster_num, node_id_list in clusters.items():
323
- st.markdown(f"Cluster's number: {cluster_num}; Size {len(node_id_list)}")
324
- node_id = node_id_list[0]
325
- num_steps = len(tree.synthesis_route(node_id))
326
- route_score = round(tree.route_score(node_id), 3)
327
- st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
328
-
329
- cluster_sizes = [len(cluster) for cluster in clusters.values()]
330
-
331
- cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
332
- with cluster_stat_col:
333
- st.subheader("Statistics")
334
- cluster_df = pd.DataFrame({'Cluster': range(len(cluster_sizes)), 'Routes': cluster_sizes})
335
- st.write(cluster_df)
336
-
337
- def on_download_click():
338
- st.session_state.clusters_downloaded = True
339
- st.write("Download clusters button pressed via on_click. Updated session state:", dict(st.session_state))
340
- save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
341
-
342
- with cluster_download_col:
343
- st.subheader("Downloads: Don't work. Resets evey time")
344
- reactions_dict = extract_reactions(tree)
345
- clusters_converted = {int(key): value for key, value in clusters.items()} if clusters else clusters
346
-
347
- # Use on_click to capture the click event reliably.
348
- st.button('Download clusters', key='download_clusters_button', on_click=on_download_click)
349
-
350
- st.write("Clusters downloaded flag (from session_state):", st.session_state.get("clusters_downloaded"))
351
-
352
- col1, _ = st.columns([.2, .8])
353
- with col1:
354
- fig = pie_chart(cluster_sizes)
355
- st.pyplot(fig)
356
- st.header("Sub Clustering the retrosynthetic routes - Resets every time when i interact with input widget")
357
- sub = sublcuster_all(clusters, reactions_dict)
358
- col2, _ = st.columns([.2, .8])
359
- with col2:
360
- user_input_cluster_num = st.number_input("Enter a number:", min_value=1,
361
- max_value=max(clusters.keys()), value=1, step=1)
362
-
363
- st.write(f"You entered the # cluster: {user_input_cluster_num}")
364
- sub_step_cluster = sub[user_input_cluster_num]
365
- allowed_numbers = sub_step_cluster.keys()
366
- selected_number = st.selectbox("Choose a number:", allowed_numbers)
367
- st.write(f"You entered number of steps: {selected_number}")
368
- subclusters = sub_step_cluster[selected_number]
369
-
370
- st.subheader(f"Found number of subclusters: {len(subclusters)}")
371
- for subcluster_num, subcluster_set in enumerate(subclusters):
372
- st.write(f"Subcluster #: {subcluster_num + 1}")
373
- for route_id in subcluster_set:
374
- st.write(f"Node_ID: {route_id}")
375
- st.image(get_route_svg(tree, route_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
376
-
377
- else:
378
- st.write("Found no reaction path.")
379
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  st.divider()
381
- st.header('Restart from the beginning?')
382
- if st.button("Restart"):
383
- st.session_state.planning_done = False
384
- st.session_state.tree = None
385
- st.session_state.res = None
386
- st.session_state.clustering_started = False
387
- st.session_state.clusters_downloaded = False
 
 
 
 
 
 
 
388
  st.rerun()
 
20
  from synplan.utils.visualisation import generate_results_html, get_route_svg
21
 
22
 
23
+ from cluster.generalized_cgr import *
24
+ from cluster.reduced_g_cgr import *
25
  from cluster.clustering import *
26
  from cluster.visualize import *
27
  from cluster.utils import *
 
119
  st.session_state.tree = None
120
  if "res" not in st.session_state:
121
  st.session_state.res = None
122
+ if "target_smiles" not in st.session_state:
123
+ st.session_state.target_smiles = ''
124
+
125
+ # Clustering state
126
+ if "clustering_done" not in st.session_state:
127
+ st.session_state.clustering_done = False
128
+ if "clusters" not in st.session_state:
129
+ st.session_state.clusters = None
130
+ if "reactions_dict" not in st.session_state:
131
+ st.session_state.reactions_dict = None
132
+ if "num_clusters_setting" not in st.session_state: # Store the setting used
133
+ st.session_state.num_clusters_setting = 10
134
+
135
+ # Subclustering state
136
+ if "subclustering_done" not in st.session_state:
137
+ st.session_state.subclustering_done = False
138
+ if "sub" not in st.session_state:
139
+ st.session_state.sub = None
140
+
141
+ # Download state (less critical now with direct download links)
142
+ if 'clusters_downloaded' not in st.session_state: # Example, might not be needed
143
  st.session_state.clusters_downloaded = False
144
 
 
145
  intro_text = '''
146
  This is a demo of the graphical user interface of
147
  [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
 
168
  smile_code = st_ketcher(molecule)
169
  target_molecule = mol_from_smiles(smile_code)
170
 
171
+ if 'target_smiles' in st.session_state and smile_code != st.session_state.target_smiles:
172
+ # If the SMILES changes, invalidate previous results
173
+ st.warning("Molecule structure changed. Please re-run planning.")
174
+ st.session_state.planning_done = False
175
+ st.session_state.clustering_done = False
176
+ st.session_state.subclustering_done = False
177
+ st.session_state.tree = None
178
+ st.session_state.res = None
179
+ st.session_state.clusters = None
180
+ st.session_state.reactions_dict = None
181
+ st.session_state.sub = None
182
+
183
+ @st.cache_resource
184
+ def load_planning_resources():
185
+ building_blocks_path = hf_hub_download(
186
+ repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
187
+ filename="building_blocks_em_sa_ln.smi",
188
+ subfolder="building_blocks",
189
+ local_dir="."
190
+ )
191
+ ranking_policy_weights_path = hf_hub_download(
192
+ repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
193
+ filename="ranking_policy_network.ckpt",
194
+ subfolder="uspto/weights",
195
+ local_dir="."
196
+ )
197
+ reaction_rules_path = hf_hub_download(
198
+ repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
199
+ filename="uspto_reaction_rules.pickle",
200
+ subfolder="uspto",
201
+ local_dir="."
202
+ )
203
+ return building_blocks_path, ranking_policy_weights_path, reaction_rules_path
204
+
205
+ building_blocks_path, ranking_policy_weights_path, reaction_rules_path = load_planning_resources()
206
 
207
  st.header('Launch calculation')
208
  st.markdown(
 
227
  c_ucb = st.number_input("C coefficient of UCB", value=0.1, placeholder="Type a number...")
228
 
229
  with col_options_2:
230
+ max_iterations = st.slider('Total number of MCTS iterations', min_value=50, max_value=1000, value=300)
231
  max_depth = st.slider('Maximal number of reaction steps', min_value=3, max_value=9, value=6)
232
  min_mol_size = st.slider('Minimum size of a molecule to be precursor', min_value=0, max_value=7, value=0)
233
 
 
239
 
240
  submit_planning = st.button('Start retrosynthetic planning')
241
 
242
+ if submit_planning:
243
+ # Reset downstream states if replanning
244
+ st.session_state.planning_done = False
245
+ st.session_state.clustering_done = False
246
+ st.session_state.subclustering_done = False
247
+ st.session_state.tree = None
248
+ st.session_state.res = None
249
+ st.session_state.clusters = None
250
+ st.session_state.reactions_dict = None
251
+ st.session_state.sub = None
252
+ st.session_state.target_smiles = smile_code # Store the SMILES used for this run
253
+
254
+ try:
255
+ target_molecule = mol_from_smiles(smile_code)
256
+ if target_molecule is None:
257
+ st.error(f"Could not parse the input SMILES: {smile_code}")
258
+ else:
259
+ with st.spinner("Running retrosynthetic planning..."):
260
+ with st.status("Loading resources...", expanded=False) as status:
261
+ st.write("Loading building blocks...")
262
+ building_blocks = load_building_blocks(building_blocks_path, standardize=False)
263
+ st.write('Loading reaction rules...')
264
+ reaction_rules = load_reaction_rules(reaction_rules_path)
265
+ st.write('Loading policy network...')
266
+ policy_config = PolicyNetworkConfig(weights_path=ranking_policy_weights_path)
267
+ policy_function = PolicyNetworkFunction(policy_config=policy_config)
268
+ status.update(label="Resources loaded!", state="complete")
269
+
270
+ tree_config = TreeConfig(
271
+ search_strategy=search_strategy,
272
+ evaluation_type="rollout",
273
+ max_iterations=max_iterations,
274
+ max_depth=max_depth,
275
+ min_mol_size=min_mol_size,
276
+ init_node_value=0.5,
277
+ ucb_type=ucb_type,
278
+ c_ucb=c_ucb,
279
+ silent=True
280
+ )
281
+
282
+ tree = Tree(
283
+ target=target_molecule,
284
+ config=tree_config,
285
+ reaction_rules=reaction_rules,
286
+ building_blocks=building_blocks,
287
+ expansion_function=policy_function,
288
+ evaluation_function=None,
289
+ )
290
+
291
+ mcts_progress_text = "Running MCTS iterations..."
292
+ mcts_bar = st.progress(0, text=mcts_progress_text)
293
+ for step, (solved, node_id) in enumerate(tree):
294
+ progress_value = min(1.0, (step + 1) / max_iterations)
295
+ mcts_bar.progress(progress_value, text=f"{mcts_progress_text} ({step+1}/{max_iterations})")
296
+
297
+ res = extract_tree_stats(tree, target_molecule)
298
+
299
+ # Store planning outputs in session_state
300
+ st.session_state['tree'] = tree
301
+ st.session_state['res'] = res
302
+ st.session_state.planning_done = True
303
+ st.rerun() # Rerun to display results cleanly
304
+
305
+ except Exception as e:
306
+ st.error(f"An error occurred during planning: {e}")
307
+ st.session_state.planning_done = False # Ensure state reflects failure
308
+
309
+ # Display results if planning has been completed
310
+ if st.session_state.get('planning_done', False):
311
  res = st.session_state.res
312
  tree = st.session_state.tree
313
 
314
+ if res is None or tree is None:
315
+ st.error("Planning results are missing from session state. Please re-run planning.")
316
+ st.session_state.planning_done = False # Reset state
317
+ elif res["solved"]:
318
+ st.header('Planning Results')
319
+ # st.balloons() # Optional fun
320
+ winning_nodes = sorted(set(tree.winning_nodes)) if hasattr(tree, 'winning_nodes') and tree.winning_nodes else []
321
+ st.subheader(f"Number of unique routes found: {len(winning_nodes)}")
322
+
323
  st.subheader("Examples of found retrosynthetic routes")
324
  image_counter = 0
325
  visualised_node_ids = set()
326
+ # Ensure winning_nodes is iterable and not empty
327
+
328
+ if not winning_nodes:
329
+ st.warning("Planning solved, but no winning nodes found in the tree object.")
330
+ else:
331
+ for n, node_id in enumerate(winning_nodes):
332
+ if image_counter >= 3: # Use >= for clarity
333
+ break
334
+ # Simple display logic: show first 3 unique routes
335
+ if node_id not in visualised_node_ids:
336
+ try:
337
+ visualised_node_ids.add(node_id)
338
+ num_steps = len(tree.synthesis_route(node_id))
339
+ route_score = round(tree.route_score(node_id), 3)
340
+ svg = get_route_svg(tree, node_id)
341
+ if svg:
342
+ st.image(svg, caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
343
+ image_counter += 1
344
+ else:
345
+ st.warning(f"Could not generate SVG for route {node_id}.")
346
+ except Exception as e:
347
+ st.error(f"Error displaying route {node_id}: {e}")
348
+
349
  stat_col, download_col = st.columns(2, gap="medium")
350
  with stat_col:
351
  st.subheader("Statistics")
352
+ try:
353
+ # Ensure 'target_smiles' exists in res, if not, use the stored one
354
+ if 'target_smiles' not in res:
355
+ res['target_smiles'] = st.session_state.target_smiles
356
+ # Select only existing columns safely
357
+ cols_to_show = [col for col in ["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"] if col in res]
358
+ df = pd.DataFrame(res, index=[0])[cols_to_show]
359
+ st.dataframe(df) # Use dataframe for better display
360
+ except Exception as e:
361
+ st.error(f"Error displaying statistics: {e}")
362
+ st.write(res) # Show raw dict if DataFrame fails
363
+
364
  with download_col:
365
  st.subheader("Downloads")
366
+ try:
367
+ html_body = generate_results_html(tree, html_path=None, extended=True)
368
+ dl_html = download_button(html_body, 'results_synplanner.html', 'Download results (HTML)')
369
+ if dl_html: st.markdown(dl_html, unsafe_allow_html=True)
370
+
371
+ # Ensure res is suitable for DataFrame before creating/downloading
372
+ try:
373
+ res_df = pd.DataFrame(res, index=[0])
374
+ dl_csv = download_button(res_df, 'results_synplanner.csv', 'Download statistics (CSV)')
375
+ if dl_csv: st.markdown(dl_csv, unsafe_allow_html=True)
376
+ except Exception as e:
377
+ st.error(f"Could not prepare statistics CSV for download: {e}")
378
+
379
+ except Exception as e:
380
+ st.error(f"Error generating download links: {e}")
381
+
382
+ st.divider()
383
 
384
  st.header("Clustering the retrosynthetic routes")
385
 
 
388
 
389
  cluster_box, _ = st.columns(2, gap="medium")
390
  with cluster_box:
391
+ num_clusters_input = st.slider(
392
+ 'Max number of clusters to generate',
393
  min_value=2,
394
+ max_value=min(50, res.get("num_routes", 50)), # Sensible max based on routes found
395
+ value=st.session_state.num_clusters_setting,
396
  key='cluster_slider'
397
  )
398
+ if st.button('Run Clustering', key='submit_clustering'):
399
+ # Update the setting in session state when the button is clicked
400
+ st.session_state.num_clusters_setting = num_clusters_input
401
+ # Reset downstream states
402
+ st.session_state.clustering_done = False
403
+ st.session_state.subclustering_done = False
404
+ st.session_state.clusters = None
405
+ st.session_state.reactions_dict = None
406
+ st.session_state.sub = None
407
+
408
+ with st.spinner("Performing clustering..."):
409
+ try:
410
+ # Ensure tree is available from session state
411
+ current_tree = st.session_state.tree
412
+ if not current_tree:
413
+ st.error("Tree object not found. Please re-run planning.")
414
+ else:
415
+ st.write("Calculating Generalized CGRs...")
416
+ g_cgrs_dict = reassign_nums(current_tree) # Assuming this needs the tree
417
+ st.write("Processing RG-CGRs...")
418
+ reduced_g_cgrs_dict = process_all_rg_cgrs(g_cgrs_dict) # Assuming this uses the previous output
419
+
420
+ mfp = MorganFingerprint()
421
+ st.write(f"Clustering into max {st.session_state.num_clusters_setting} clusters...")
422
+ results = cluster_molecules(reduced_g_cgrs_dict, mfp, max_clusters=st.session_state.num_clusters_setting)
423
+
424
+ st.session_state.clusters = results.get('clusters_dict')
425
+ st.session_state.rg_cgrs_dict = reduced_g_cgrs_dict
426
+ # Extract reactions *after* clustering if needed, ensure tree is passed
427
+ st.write("Extracting reactions...")
428
+ st.session_state.reactions_dict = extract_reactions(current_tree)
429
+
430
+ if st.session_state.clusters and st.session_state.reactions_dict:
431
+ st.session_state.clustering_done = True
432
+ st.success(f"Clustering complete. Found {len(st.session_state.clusters)} clusters.")
433
+ else:
434
+ st.error("Clustering failed or returned empty results.")
435
+ st.session_state.clustering_done = False
436
+
437
+ # Clean up large intermediate objects if possible
438
+ del g_cgrs_dict
439
+ del results
440
+ gc.collect()
441
+
442
+ st.rerun() # Rerun to display clustering results cleanly
443
+ except Exception as e:
444
+ st.error(f"An error occurred during clustering: {e}")
445
+ st.session_state.clustering_done = False
446
+
447
+ # --- Display Clustering Results (if done) ---
448
+ if st.session_state.get('clustering_done', False):
449
+ clusters = st.session_state.clusters
450
+ reactions_dict = st.session_state.reactions_dict # Needed for download
451
+
452
+ tree = st.session_state.tree # Needed for display and download
453
+
454
+ if not clusters or not reactions_dict or not tree:
455
+ st.error("Clustering results are missing from session state. Please re-run clustering.")
456
+ st.session_state.clustering_done = False # Reset flag
457
+ else:
458
+ st.subheader(f"Best routes from {len(clusters)} Found Clusters")
459
+ # Display first route from first few clusters
460
+ displayed_clusters = 0
461
+ for cluster_num, node_id_list in clusters.items():
462
+ if displayed_clusters >= 10: # Limit displayed clusters
463
+ st.write(f"... and {len(clusters) - displayed_clusters} more clusters.")
464
+ break
465
+ if not node_id_list: continue # Skip empty clusters
466
+
467
+ st.markdown(f"**Cluster {cluster_num}** (Size: {len(node_id_list)}) - Example Route:")
468
+ node_id = node_id_list[0] # Display the first route as example
469
+ try:
470
+ num_steps = len(tree.synthesis_route(node_id))
471
+ route_score = round(tree.route_score(node_id), 3)
472
+ svg = get_route_svg(tree, node_id)
473
+ if svg:
474
+ st.image(svg, caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
475
+ else:
476
+ st.warning(f"Could not generate SVG for route {node_id}.")
477
+ displayed_clusters += 1
478
+ except Exception as e:
479
+ st.error(f"Error displaying route {node_id} for cluster {cluster_num+1}: {e}")
480
+
481
+ cluster_sizes = [len(cluster) for cluster in clusters.values()]
482
+ cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
483
+
484
+ with cluster_stat_col:
485
+ st.subheader("Cluster Statistics")
486
+ if cluster_sizes:
487
+ cluster_df = pd.DataFrame({'Cluster': range(1, len(cluster_sizes) + 1), 'Number of Routes': cluster_sizes})
488
+ st.dataframe(cluster_df)
489
+
490
+ # Display Pie Chart using Matplotlib
491
+ # try:
492
+ # fig, ax = plt.subplots(figsize=(5, 4)) # Adjust size if needed
493
+ # ax.pie(cluster_sizes, labels=[f'C{i+1}' for i in range(len(cluster_sizes))], autopct='%1.1f%%', startangle=90)
494
+ # ax.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle.
495
+ # st.pyplot(fig)
496
+ # plt.close(fig) # Close the figure to free memory
497
+ # except Exception as e:
498
+ # st.error(f"Could not generate pie chart: {e}")
499
+ else:
500
+ st.write("No cluster data to display statistics for.")
501
+
502
+ with cluster_download_col:
503
+ st.subheader("Cluster Reports") # Changed subheader
504
+
505
+ # Retrieve necessary data from session state
506
+ tree_for_html = st.session_state.get('tree')
507
+ clusters_for_html = st.session_state.get('clusters')
508
+ rg_cgrs_for_html = st.session_state.get('rg_cgrs_dict')
509
+
510
+ if not tree_for_html:
511
+ st.warning("MCTS Tree data not found. Cannot generate reports.")
512
+ elif not clusters_for_html:
513
+ st.warning("Cluster data not found. Cannot generate reports.")
514
+ else:
515
+ st.write("Generate downloadable HTML reports for each cluster:")
516
+
517
+ # Limit the number of download links shown directly if there are many clusters
518
+ MAX_DOWNLOAD_LINKS_DISPLAYED = 15 # Adjust as needed
519
+ num_clusters_total = len(clusters_for_html)
520
+ clusters_items = list(clusters_for_html.items()) # Get items to slice
521
+
522
+ for i, (cluster_num_idx, node_ids) in enumerate(clusters_items):
523
+ if i >= MAX_DOWNLOAD_LINKS_DISPLAYED:
524
+ st.caption(f"... plus {num_clusters_total - MAX_DOWNLOAD_LINKS_DISPLAYED} more cluster reports available.")
525
+ # Consider adding a button to download all as a zip if needed
526
+ break
527
+
528
+ cluster_num_display = int(cluster_num_idx) # Cluster key cast to int; no offset is applied
529
+
530
+ if not node_ids: # Skip empty clusters
531
+ st.caption(f"Cluster {cluster_num_display} is empty, no report generated.")
532
+ continue
533
+
534
+ try:
535
+ try:
536
+ cluster_html_content = generate_cluster_html(
537
+ tree=tree_for_html,
538
+ cluster_node_ids=node_ids,
539
+ cluster_num=cluster_num_display,
540
+ rg_cgrs_dict=rg_cgrs_for_html,
541
+ aam=False
542
+ )
543
+ except Exception as e:
544
+ st.error(f"Error generating report/link for Cluster {cluster_num_display}: {e}")
545
+
546
+
547
+ # --- Create the download button using the existing function ---
548
+ download_filename = f"cluster_{cluster_num_display}_report.html"
549
+ button_text = f"Cluster {cluster_num_display} Report (HTML)"
550
+ dl_button_html = download_button(
551
+ object_to_download=cluster_html_content,
552
+ download_filename=download_filename,
553
+ button_text=button_text
554
+ # pickle_it=False # Ensure it's not pickled
555
+ )
556
+
557
+ if dl_button_html:
558
+ st.markdown(dl_button_html, unsafe_allow_html=True)
559
+ else:
560
+ # Error message if button creation failed (e.g., encoding error)
561
+ st.error(f"Failed to create download link for Cluster {cluster_num_display}.")
562
+
563
+ except Exception as e:
564
+ # Catch errors during HTML generation or button creation for a specific cluster
565
+ st.error(f"Error generating report/link for Cluster {cluster_num_display}: {e}")
566
+ # Optionally add more detailed logging here:
567
+ # import traceback
568
+ # st.error(traceback.format_exc())
569
+
570
+ if num_clusters_total > MAX_DOWNLOAD_LINKS_DISPLAYED:
571
+ # Optional: Add a button here to generate and download a ZIP file
572
+ # containing all cluster reports. This requires more implementation
573
+ # (using zipfile library in memory).
574
+ # e.g., if st.button("Download All Reports as ZIP"): ...
575
+ pass
576
+
577
+
578
+ st.divider()
579
+
580
+ # --- Subclustering Section ---
581
+ st.header("Sub-Clustering within a selected Cluster")
582
+
583
+ # Button to trigger the subclustering calculation
584
+ if st.button("Run Subclustering Analysis", key="submit_subclustering"):
585
+ st.session_state.subclustering_done = False # Reset flag
586
+ st.session_state.sub = None # Clear old results
587
+ with st.spinner("Performing subclustering analysis..."):
588
+ try:
589
+ # Retrieve necessary data from session state
590
+ clusters_for_sub = st.session_state.get('clusters')
591
+ reactions_for_sub = st.session_state.get('reactions_dict')
592
+ if clusters_for_sub and reactions_for_sub:
593
+ sub = sublcuster_all(clusters_for_sub, reactions_for_sub)
594
+ st.session_state.sub = sub
595
+ st.session_state.subclustering_done = True
596
+ st.success("Subclustering analysis complete.")
597
+ # Clean up intermediates if possible
598
+ gc.collect()
599
+ st.rerun() # Rerun to display results/inputs cleanly
600
+ else:
601
+ st.error("Missing cluster or reaction data needed for subclustering.")
602
+ except Exception as e:
603
+ st.error(f"An error occurred during subclustering: {e}")
604
+ st.session_state.subclustering_done = False
605
+
606
+
607
+ # Display subclustering inputs and results ONLY if subclustering is done
608
+ if st.session_state.get('subclustering_done', False):
609
+ sub = st.session_state.sub
610
+ tree = st.session_state.tree
611
+ clusters_for_sub = st.session_state.get('clusters')
612
+
613
+ if not sub or not tree:
614
+ st.error("Subclustering results are missing from session state. Please re-run subclustering.")
615
+ st.session_state.subclustering_done = False
616
+ else:
617
+ sub_input_col, sub_display_col = st.columns([0.2, 0.8]) # Adjust column ratio if needed
618
+
619
+ with sub_input_col:
620
+ st.subheader("Select Cluster and Step")
621
+ # Cluster selection (use cluster numbers as displayed, usually 1-based)
622
+ available_cluster_nums = [int(k) for k in sub.keys()] # Keys of 'sub' cast to int; no offset is applied
623
+ if not available_cluster_nums:
624
+ st.warning("No clusters available in subclustering results.")
625
+ else:
626
+ # Key is essential here to maintain state across reruns
627
+ user_input_cluster_num_display = st.selectbox(
628
+ "Select Cluster #:",
629
+ options=sorted(available_cluster_nums),
630
+ key='subcluster_num_select'
631
+ )
632
+
633
+ # No index conversion: the selected cluster number is used directly as the key into 'sub'
634
+ selected_cluster_idx = user_input_cluster_num_display
635
+
636
+ if selected_cluster_idx in sub:
637
+ sub_step_cluster = sub[selected_cluster_idx]
638
+ allowed_step_numbers = sorted(list(sub_step_cluster.keys()))
639
+
640
+ if not allowed_step_numbers:
641
+ st.warning(f"No reaction steps found for Cluster {user_input_cluster_num_display}.")
642
+ else:
643
+ # Key is essential here
644
+ selected_step_number = st.selectbox(
645
+ "Select Number of Steps:",
646
+ options=allowed_step_numbers,
647
+ key='subcluster_steps_select'
648
+ )
649
+ # --- Display logic moved to the right column ---
650
+ rg_cgrs = st.session_state.get('rg_cgrs_dict')
651
+ cluster_rg_cgr = rg_cgrs[clusters_for_sub[user_input_cluster_num_display][0]]
652
+ cluster_rg_cgr.clean2d()
653
+ st.image(cluster_rg_cgr.depict(), caption=f"RG-CGR of cluster")
654
+
655
+
656
+ else:
657
+ st.warning(f"Selected cluster {user_input_cluster_num_display} (index {selected_cluster_idx}) not found in subclustering results.")
658
+
659
+
660
+
661
+ with sub_display_col:
662
+ st.subheader("Subcluster Results")
663
+ # Check if inputs are valid before trying to display
664
+ if 'user_input_cluster_num_display' in locals() and \
665
+ 'selected_cluster_idx' in locals() and \
666
+ selected_cluster_idx in sub and \
667
+ 'selected_step_number' in locals() and \
668
+ selected_step_number in sub[selected_cluster_idx]:
669
+
670
+ subclusters = sub[selected_cluster_idx][selected_step_number]
671
+ st.write(f"Displaying **{len(subclusters)}** subclusters for **Cluster {user_input_cluster_num_display}** with **{selected_step_number} steps**:")
672
+
673
+ if not subclusters:
674
+ st.info("No subclusters found for this selection.")
675
+ else:
676
+ # Limit the display if there are too many subclusters/routes
677
+ MAX_DISPLAY_SUBCLUSTERS = 20
678
+ MAX_ROUTES_PER_SUBCLUSTER = 10
679
+
680
+ for subcluster_num, subcluster_set in enumerate(subclusters):
681
+ if subcluster_num >= MAX_DISPLAY_SUBCLUSTERS:
682
+ st.write(f"... and {len(subclusters) - MAX_DISPLAY_SUBCLUSTERS} more subclusters.")
683
+ break
684
+
685
+ st.markdown(f"--- \n**Subcluster {subcluster_num + 1}** (Size: {len(subcluster_set)})")
686
+ routes_shown = 0
687
+ for route_id in subcluster_set:
688
+ if routes_shown >= MAX_ROUTES_PER_SUBCLUSTER:
689
+ st.write(f"(Showing first {MAX_ROUTES_PER_SUBCLUSTER} routes)")
690
+ break
691
+ try:
692
+ # Need num_steps and route_score for caption (optional but nice)
693
+ num_steps_sub = len(tree.synthesis_route(route_id))
694
+ route_score_sub = round(tree.route_score(route_id), 3)
695
+ svg_sub = get_route_svg(tree, route_id)
696
+ if svg_sub:
697
+ st.image(svg_sub, caption=f"Route {route_id}; {num_steps_sub} steps; Score: {route_score_sub}")
698
+ else:
699
+ st.warning(f"Could not generate SVG for route {route_id}.")
700
+ routes_shown += 1
701
+ except Exception as e:
702
+ st.error(f"Error displaying route {route_id} in subcluster {subcluster_num+1}: {e}")
703
+ else:
704
+ st.info("Select a cluster and step number to view subclusters.")
705
+
706
+
707
+ # --- Handling No Solution Case ---
708
+ elif not st.session_state.get('planning_done', False):
709
+ # Only show this if planning was attempted but failed (or not run yet)
710
+ # Avoid showing it if just molecule changed
711
+ if submit_planning: # Check if the button was just pressed
712
+ st.warning("Planning did not complete successfully or is still running.")
713
+
714
+ else: # Planning done, but res["solved"] is False
715
+ st.header('Planning Results')
716
+ st.warning("No reaction path found for the target molecule with the current settings.")
717
+ st.write("Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity).")
718
+ # Optionally display basic stats even if not solved
719
+ stat_col, _ = st.columns(2)
720
+ with stat_col:
721
+ st.subheader("Run Statistics (No Solution)")
722
+ try:
723
+ if 'target_smiles' not in res: res['target_smiles'] = st.session_state.target_smiles
724
+ cols_to_show = [col for col in ["target_smiles", "num_nodes", "num_iter", "search_time"] if col in res]
725
+ df = pd.DataFrame(res, index=[0])[cols_to_show]
726
+ st.dataframe(df)
727
+ except Exception as e:
728
+ st.error(f"Error displaying statistics: {e}")
729
+ st.write(res)
730
+
731
+
732
+ # --- Restart Button ---
733
  st.divider()
734
+ st.header('Restart Application State')
735
+ if st.button("Clear All Results & Restart"):
736
+ # Clear all relevant session state keys
737
+ keys_to_clear = [
738
+ "planning_done", "tree", "res", "target_smiles",
739
+ "clustering_done", "clusters", "reactions_dict", "num_clusters_setting", "rg_cgrs_dict",
740
+ "subclustering_done", "sub",
741
+ "clusters_downloaded" # Add any other state keys you use
742
+ ]
743
+ for key in keys_to_clear:
744
+ if key in st.session_state:
745
+ del st.session_state[key]
746
+ # Clear ketcher state by assigning a default value (or empty string)
747
+ st.session_state.ketcher = DEFAULT_MOL # Reset ketcher to default
748
  st.rerun()
{cluster → cluster copy}/__init__.py RENAMED
File without changes
{cluster → cluster copy}/clustering.py RENAMED
File without changes
{cluster → cluster copy}/rs_cgr.py RENAMED
File without changes
{cluster → cluster copy}/subcluster.py RENAMED
File without changes
{cluster → cluster copy}/super_cgr.py RENAMED
File without changes
{cluster → cluster copy}/utils.py RENAMED
File without changes
{cluster → cluster copy}/visualize.py RENAMED
File without changes