Gilmullin Almaz committed on
Commit
27a7101
·
1 Parent(s): 57a9d9a

draft subclustering - need to solve resetting

Browse files
Files changed (1) hide show
  1. app.py +147 -315
app.py CHANGED
@@ -23,6 +23,9 @@ from synplan.utils.visualisation import generate_results_html, get_route_svg
23
  from cluster.super_cgr import *
24
  from cluster.rs_cgr import *
25
  from cluster.clustering import *
 
 
 
26
  from StructureFingerprint import MorganFingerprint
27
 
28
  import psutil
@@ -33,83 +36,6 @@ disable_progress_bars("huggingface_hub")
33
 
34
  smiles_parser = SMILESRead.create_parser(ignore=True)
35
 
36
-
37
- def reassign_nums_chunk(route_dict):
38
- """Process a chunk of routes for reassigning numbers"""
39
- return {k: reassign_nums(v) for k, v in route_dict.items()}
40
-
41
- def cluster_molecules_optimized(fingerprints_dict, max_clusters):
42
- """Memory-optimized version of cluster_molecules.
43
-
44
- Args:
45
- fingerprints_dict (dict): Dictionary of pre-computed fingerprints
46
- max_clusters (int): Maximum number of clusters
47
-
48
- Returns:
49
- dict: Clustering results containing clusters_dict and cluster_labels
50
- """
51
- try:
52
- # Convert dictionary to arrays for efficient processing
53
- labels = np.array(list(fingerprints_dict.keys()))
54
- fingerprints = np.array(list(fingerprints_dict.values()))
55
-
56
- # Calculate similarity matrix in chunks to save memory
57
- chunk_size = 100
58
- n_samples = len(fingerprints)
59
- similarity_matrix = np.zeros((n_samples, n_samples))
60
-
61
- for i in range(0, n_samples, chunk_size):
62
- chunk_end = min(i + chunk_size, n_samples)
63
- chunk = fingerprints[i:chunk_end]
64
-
65
- # Calculate similarity for this chunk against all fingerprints
66
- similarity_chunk = tanimoto_similarity_continuous(chunk, fingerprints)
67
- similarity_matrix[i:chunk_end] = similarity_chunk
68
-
69
- # Clear memory
70
- del similarity_chunk
71
- gc.collect()
72
-
73
- # Convert to distance matrix
74
- distance_matrix = 1 - similarity_matrix
75
-
76
- # Free memory
77
- del similarity_matrix
78
- gc.collect()
79
-
80
- # Calculate condensed distance matrix
81
- condensed_distance = squareform(distance_matrix)
82
-
83
- # Free memory
84
- del distance_matrix
85
- gc.collect()
86
-
87
- # Calculate linkage
88
- Z = fastcluster.linkage(condensed_distance, method='average')
89
-
90
- # Free memory
91
- del condensed_distance
92
- gc.collect()
93
-
94
- # Perform clustering
95
- cluster_labels = fcluster(Z, max_clusters, criterion='maxclust')
96
-
97
- # Create clusters dictionary
98
- clusters_dict = {}
99
- for cluster in range(1, max_clusters + 1):
100
- cluster_indices = np.where(cluster_labels == cluster)[0]
101
- clusters_dict[cluster] = list(labels[cluster_indices])
102
-
103
- return {
104
- 'clusters_dict': clusters_dict,
105
- 'cluster_labels': cluster_labels,
106
- 'linkage_matrix': Z
107
- }
108
-
109
- except Exception as e:
110
- print(f"Error in cluster_molecules_optimized: {str(e)}")
111
- raise e
112
-
113
  def download_button(object_to_download, download_filename, button_text, pickle_it=False):
114
  """
115
  Issued from
@@ -186,6 +112,23 @@ def download_button(object_to_download, download_filename, button_text, pickle_i
186
 
187
  st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  intro_text = '''
190
  This is a demo of the graphical user interface of
191
  [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
@@ -268,7 +211,8 @@ search_strategy = search_strategy_translator[search_strategy_input]
268
 
269
  submit_planning = st.button('Start retrosynthetic planning')
270
 
271
- if submit_planning:
 
272
  with st.status("Downloading data"):
273
  st.write("Downloading building blocks")
274
  building_blocks = load_building_blocks(building_blocks_path, standardize=False)
@@ -306,12 +250,21 @@ if submit_planning:
306
 
307
  res = extract_tree_stats(tree, target_molecule)
308
 
 
 
 
 
 
 
 
 
 
 
309
  st.header('Results')
310
  if res["solved"]:
311
- st.balloons()
312
-
313
  st.subheader("Examples of found retrosynthetic routes")
314
-
315
  image_counter = 0
316
  visualised_node_ids = set()
317
  for n, node_id in enumerate(sorted(set(tree.winning_nodes))):
@@ -322,255 +275,134 @@ if submit_planning:
322
  image_counter += 1
323
  num_steps = len(tree.synthesis_route(node_id))
324
  route_score = round(tree.route_score(node_id), 3)
325
- st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
 
326
 
327
-
328
- ### Modified part
329
- # cluster_box, z = st.columns(2, gap="medium")
330
- # with cluster_box:
331
- # num_clusters = st.slider('Number of clusters to display', min_value=2, max_value=10, value=2)
332
-
333
- # submit_clustering = st.button('Start clustering')
334
-
335
- # if submit_clustering:
336
- # st.subheader("Examples of clusters")
337
- # super_cgrs_dict = reassign_nums(tree)
338
-
339
- # reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
340
-
341
- # mfp = MorganFingerprint()
342
-
343
- # results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
344
- # cluster_box, z = st.columns(2, gap="medium")
345
- # with cluster_box:
346
- # # Initialize session state if not exists
347
- # if 'memory_warning_shown' not in st.session_state:
348
- # st.session_state.memory_warning_shown = False
349
-
350
- # current_memory = psutil.Process().memory_info().rss / 1024 / 1024
351
- # st.write(f"Current memory usage: {current_memory:.2f} MB")
352
- # st.write(f"Number of winning nodes: {len(tree.winning_nodes)}")
353
-
354
- # # Memory warning
355
- # if current_memory > 1000 and not st.session_state.memory_warning_shown:
356
- # st.warning("Memory usage is high. Consider reducing the number of routes or clearing cache.")
357
- # st.session_state.memory_warning_shown = True
358
-
359
- # # Store the previous value in session state
360
- # if 'prev_num_clusters' not in st.session_state:
361
- # st.session_state.prev_num_clusters = 2
362
-
363
- # num_clusters = st.slider(
364
- # 'Number of clusters to display',
365
- # min_value=2,
366
- # max_value=min(10, len(tree.winning_nodes)),
367
- # value=st.session_state.prev_num_clusters
368
- # )
369
-
370
- # # Update the stored value only if it changed
371
- # if num_clusters != st.session_state.prev_num_clusters:
372
- # st.session_state.prev_num_clusters = num_clusters
373
-
374
- # submit_clustering = st.button('Start clustering')
375
-
376
- # if submit_clustering:
377
- # try:
378
- # with st.spinner("Processing clusters..."):
379
- # # Clear memory before starting
380
- # gc.collect()
381
-
382
- # st.write("Starting clustering process...")
383
- # memory_before = psutil.Process().memory_info().rss / 1024 / 1024
384
- # st.write(f"Memory before clustering: {memory_before:.2f} MB")
385
-
386
- # super_cgrs_dict = reassign_nums(tree)
387
- # del tree # Free up memory from the tree object since we don't need it anymore
388
- # gc.collect()
389
-
390
- # reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
391
- # del super_cgrs_dict # Free up memory
392
- # gc.collect()
393
-
394
- # memory_after = psutil.Process().memory_info().rss / 1024 / 1024
395
- # st.write(f"Memory after CGR processing: {memory_after:.2f} MB")
396
-
397
- # mfp = MorganFingerprint()
398
- # results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
399
- # del reduced_super_cgrs_dict # Free up memory
400
- # gc.collect()
401
-
402
- # st.write("Clustering completed")
403
-
404
- # except Exception as e:
405
- # st.error(f"Clustering failed with error: {str(e)}")
406
- # st.write(f"Memory at error: {psutil.Process().memory_info().rss / 1024 / 1024:.2f} MB")
407
- # raise e
408
-
409
-
410
- # Access results
411
- # clusters = results['clusters_dict']
412
-
413
- # for cluster_num, node_id_list in clusters.items():
414
- # st.markdown(f"Cluster's number: ``{cluster_num}``")
415
- # node_id = node_id_list[0]
416
- # num_steps = len(tree.synthesis_route(node_id))
417
- # route_score = round(tree.route_score(node_id), 3)
418
- # st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
419
-
420
- @st.cache_data(hash_funcs={Tree: lambda _: None})
421
- def prepare_clustering_data(tree):
422
- try:
423
- # Log the start and basic info from the Tree
424
- print("Starting clustering data preparation.")
425
- total_nodes = len(tree.winning_nodes)
426
- print(f"Total winning nodes: {total_nodes}")
427
- print(f"Tree id: {id(tree)}")
428
-
429
- chunk_size = 10
430
- super_cgrs_dict = {}
431
-
432
- # Process winning nodes in chunks
433
- for i in range(0, total_nodes, chunk_size):
434
- current_chunk = list(tree.winning_nodes)[i:i+chunk_size]
435
- print(f"Processing chunk {i // chunk_size + 1}: Nodes {current_chunk}")
436
-
437
- temp_dict = {}
438
- for node in current_chunk:
439
- try:
440
- # Log before processing each node
441
- print(f"Processing node {node}")
442
- route = tree.synthesis_route(node)
443
- temp_dict[node] = route
444
- print(f"Node {node} processed successfully (route length: {len(route)}).")
445
- except Exception as e:
446
- print(f"Error processing node {node}: {e}")
447
-
448
- # Log before calling reassign_nums_chunk
449
- print(f"Calling reassign_nums_chunk for nodes: {list(temp_dict.keys())}")
450
- chunk_super_cgrs = reassign_nums_chunk(temp_dict)
451
- super_cgrs_dict.update(chunk_super_cgrs)
452
- print(f"Chunk {i // chunk_size + 1} processed. Keys: {list(chunk_super_cgrs.keys())}")
453
-
454
- del temp_dict
455
- gc.collect()
456
-
457
- # Process reduced CGRs in chunks
458
- reduced_super_cgrs_dict = {}
459
- for i in range(0, len(super_cgrs_dict), chunk_size):
460
- keys = list(super_cgrs_dict.keys())[i:i+chunk_size]
461
- chunk_dict = {k: super_cgrs_dict[k] for k in keys}
462
- print(f"Reducing chunk for keys: {keys}")
463
- reduced_chunk = process_all_rs_cgrs(chunk_dict)
464
- reduced_super_cgrs_dict.update(reduced_chunk)
465
- print(f"Reduced chunk processed for keys: {list(reduced_chunk.keys())}")
466
-
467
- del chunk_dict
468
- gc.collect()
469
-
470
- print("Clustering data preparation complete.")
471
- return reduced_super_cgrs_dict
472
- except Exception as e:
473
- print(f"Error in prepare_clustering_data: {str(e)}")
474
- st.error(f"Error in prepare_clustering_data: {str(e)}")
475
- return None
476
-
477
-
478
- def memory_status():
479
- """Get current memory status"""
480
- process = psutil.Process()
481
- memory = process.memory_info().rss / 1024 / 1024
482
- return f"Memory usage: {memory:.2f} MB"
483
-
484
- # Initialize session state for tree and clustering data
485
- if 'tree_data' not in st.session_state:
486
- st.session_state.tree_data = tree
487
- if 'clustering_state' not in st.session_state:
488
- st.session_state.clustering_state = {
489
- 'prepared': False,
490
- 'data': None,
491
- 'num_clusters': 2
492
- }
493
-
494
- cluster_box, z = st.columns(2, gap="medium")
495
- with cluster_box:
496
- st.write(memory_status())
497
- st.write(f"Number of winning nodes: {len(st.session_state.tree_data.winning_nodes)}")
498
-
499
- # Step 1: Prepare Data Button
500
- if not st.session_state.clustering_state['prepared']:
501
- if st.button('Step 1: Prepare clustering data'):
502
- with st.spinner("Preparing data..."):
503
- try:
504
- st.session_state.clustering_state['data'] = prepare_clustering_data(st.session_state.tree_data)
505
- st.session_state.clustering_state['prepared'] = True
506
- st.success("Data prepared! Now you can proceed to Step 2.")
507
- except Exception as e:
508
- st.error(f"Preparation failed: {str(e)}")
509
-
510
- # Step 2: Only show clustering controls if data is prepared
511
- if st.session_state.clustering_state['prepared']:
512
- st.markdown("### Step 2: Select number of clusters")
513
- # Store slider value in session state
514
- st.session_state.clustering_state['num_clusters'] = st.slider(
515
- 'Number of clusters',
516
- min_value=2,
517
- max_value=min(10, len(st.session_state.tree_data.winning_nodes)),
518
- value=st.session_state.clustering_state['num_clusters']
519
- )
520
-
521
- # Step 3: Generate Clusters Button
522
- if st.button('Step 3: Generate clusters'):
523
- with st.spinner("Clustering..."):
524
- try:
525
- results = perform_clustering(
526
- st.session_state.clustering_state['data'],
527
- st.session_state.clustering_state['num_clusters']
528
- )
529
-
530
- if results:
531
- st.success("Clustering complete!")
532
- for cluster_num, node_ids in results['clusters_dict'].items():
533
- with st.expander(f"Cluster {cluster_num}"):
534
- if node_ids:
535
- node_id = node_ids[0]
536
- num_steps = len(st.session_state.tree_data.synthesis_route(node_id))
537
- route_score = round(st.session_state.tree_data.route_score(node_id), 3)
538
- st.image(
539
- get_route_svg(st.session_state.tree_data, node_id),
540
- caption=f"Route {node_id}; {num_steps} steps; Score: {route_score}"
541
- )
542
- except Exception as e:
543
- st.error(f"Clustering failed: {str(e)}")
544
-
545
- # Clear memory button
546
- if st.button('Clear memory and start over'):
547
- st.cache_data.clear()
548
- del st.session_state.clustering_state
549
- del st.session_state.tree_data
550
- gc.collect()
551
- st.success("Memory cleared! Please refresh the page to start over.")
552
- st.rerun()
553
-
554
  stat_col, download_col = st.columns(2, gap="medium")
555
-
556
  with stat_col:
557
  st.subheader("Statistics")
558
  df = pd.DataFrame(res, index=[0])
559
  st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])
560
-
561
  with download_col:
562
  st.subheader("Downloads")
563
  html_body = generate_results_html(tree, html_path=None, extended=True)
564
  dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
565
-
566
- dl_csv = download_button(pd.DataFrame(res, index=[0]), 'results_synplanner.csv',
567
- 'Download statistics as a csv file')
568
  st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
569
 
570
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  st.write("Found no reaction path.")
572
 
573
  st.divider()
574
  st.header('Restart from the beginning?')
575
  if st.button("Restart"):
 
576
  st.rerun()
 
23
  from cluster.super_cgr import *
24
  from cluster.rs_cgr import *
25
  from cluster.clustering import *
26
+ from cluster.visualize import *
27
+ from cluster.utils import *
28
+ from cluster.subcluster import *
29
  from StructureFingerprint import MorganFingerprint
30
 
31
  import psutil
 
36
 
37
  smiles_parser = SMILESRead.create_parser(ignore=True)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def download_button(object_to_download, download_filename, button_text, pickle_it=False):
40
  """
41
  Issued from
 
112
 
113
  st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")
114
 
115
+ # Initialize session state variables if they don't exist.
116
+ if "planning_done" not in st.session_state:
117
+ st.session_state.planning_done = False
118
+ if "tree" not in st.session_state:
119
+ st.session_state.tree = None
120
+ if "res" not in st.session_state:
121
+ st.session_state.res = None
122
+ if "num_clusters" not in st.session_state:
123
+ st.session_state.num_clusters = 10
124
+
125
+ if 'clustering_started' not in st.session_state:
126
+ st.session_state.clustering_started = False
127
+ if 'clusters_downloaded' not in st.session_state:
128
+ st.session_state.clusters_downloaded = False
129
+
130
+ # st.write("Initial session state:", dict(st.session_state))
131
+
132
  intro_text = '''
133
  This is a demo of the graphical user interface of
134
  [SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
 
211
 
212
  submit_planning = st.button('Start retrosynthetic planning')
213
 
214
+ # if submit_planning:
215
+ if submit_planning and not st.session_state.planning_done:
216
  with st.status("Downloading data"):
217
  st.write("Downloading building blocks")
218
  building_blocks = load_building_blocks(building_blocks_path, standardize=False)
 
250
 
251
  res = extract_tree_stats(tree, target_molecule)
252
 
253
+ # Store planning outputs in session_state so they persist
254
+ st.session_state['tree'] = tree
255
+ st.session_state['res'] = res
256
+ st.session_state.planning_done = True
257
+
258
+ # Display results if planning has been completed
259
+ if st.session_state.planning_done and st.session_state.res is not None and st.session_state.clustering_started:
260
+ res = st.session_state.res
261
+ tree = st.session_state.tree
262
+
263
  st.header('Results')
264
  if res["solved"]:
265
+ # st.balloons()
266
+
267
  st.subheader("Examples of found retrosynthetic routes")
 
268
  image_counter = 0
269
  visualised_node_ids = set()
270
  for n, node_id in enumerate(sorted(set(tree.winning_nodes))):
 
275
  image_counter += 1
276
  num_steps = len(tree.synthesis_route(node_id))
277
  route_score = round(tree.route_score(node_id), 3)
278
+ st.image(get_route_svg(tree, node_id),
279
+ caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  stat_col, download_col = st.columns(2, gap="medium")
 
282
  with stat_col:
283
  st.subheader("Statistics")
284
  df = pd.DataFrame(res, index=[0])
285
  st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])
 
286
  with download_col:
287
  st.subheader("Downloads")
288
  html_body = generate_results_html(tree, html_path=None, extended=True)
289
  dl_html = download_button(html_body, 'results_synplanner.html', 'Download results as a HTML file')
290
+ dl_csv = download_button(pd.DataFrame(res, index=[0]),
291
+ 'results_synplanner.csv', 'Download statistics as a csv file')
 
292
  st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
293
 
294
+ st.header("Clustering the retrosynthetic routes")
295
+
296
+ # Initialize slider state if not already set
297
+ if 'num_clusters' not in st.session_state:
298
+ st.session_state['num_clusters'] = 10
299
+
300
+ cluster_box, _ = st.columns(2, gap="medium")
301
+ with cluster_box:
302
+ num_clusters = st.slider(
303
+ 'Number of clusters to display',
304
+ min_value=2,
305
+ max_value=10,
306
+ value=st.session_state['num_clusters'],
307
+ key='cluster_slider'
308
+ )
309
+ # Save the current slider value to session_state
310
+ st.session_state['num_clusters'] = num_clusters
311
+
312
+ if st.button('Start clustering', key='submit_clustering'):
313
+ st.session_state.clustering_started = True
314
+ # st.write("Clustering started; session state now:", dict(st.session_state))
315
+ # st.write("Clustering started!")
316
+ st.subheader("Examples of clusters")
317
+ super_cgrs_dict = reassign_nums(tree)
318
+
319
+ reduced_super_cgrs_dict = process_all_rs_cgrs(super_cgrs_dict)
320
+
321
+ mfp = MorganFingerprint()
322
+
323
+ results = cluster_molecules(reduced_super_cgrs_dict, mfp, max_clusters=num_clusters)
324
+
325
+ clusters = results['clusters_dict']
326
+
327
+ for cluster_num, node_id_list in clusters.items():
328
+ st.markdown(f"Cluster's number: {cluster_num}; Size {len(node_id_list)}")
329
+ node_id = node_id_list[0]
330
+ num_steps = len(tree.synthesis_route(node_id))
331
+ route_score = round(tree.route_score(node_id), 3)
332
+ st.image(get_route_svg(tree, node_id), caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}")
333
+
334
+ cluster_sizes = [len(cluster) for cluster in clusters.values()]
335
+
336
+ cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
337
+ with cluster_stat_col:
338
+ st.subheader("Statistics")
339
+ # st.write(cluster_sizes)
340
+ cluster_df = pd.DataFrame({'Cluster': range(len(cluster_sizes)), 'Routes': cluster_sizes})
341
+ # cluster_df = pd.DataFrame(cluster_sizes, index=[0])
342
+ st.write(cluster_df)
343
+
344
+ def on_download_click():
345
+ st.session_state.clusters_downloaded = True
346
+ st.write("Download clusters button pressed via on_click. Updated session state:", dict(st.session_state))
347
+ save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
348
+ # Here you can call save_route_images(...) if desired.
349
+
350
+ with cluster_download_col:
351
+ st.subheader("Downloads: Don't work. Resets evey time")
352
+ reactions_dict = extract_reactions(tree)
353
+ clusters_converted = {int(key): value for key, value in clusters.items()} if clusters else clusters
354
+
355
+ # Use on_click to capture the click event reliably.
356
+ st.button('Download clusters', key='download_clusters_button', on_click=on_download_click)
357
+
358
+ # Log whether the flag has been set after the button definition.
359
+ st.write("Clusters downloaded flag (from session_state):", st.session_state.get("clusters_downloaded"))
360
+
361
+ # # save_route_images(tree, reactions_dict, cluster_dict=clusters_converted)
362
+ # with cluster_download_col:
363
+ # st.subheader("Downloads")
364
+ # reactions_dict = extract_reactions(tree)
365
+ # clusters_converted = {int(key): value for key, value in clusters.items()} if clusters else clusters
366
+
367
+ # if st.session_state.clustering_started:
368
+ # st.write("Rendering download clusters button. Session state:", dict(st.session_state))
369
+ # # Use a more unique key for the download button.
370
+ # download_clusters = st.button('Download clusters', key='download_clusters_button')
371
+ # st.write("download_clusters value:", download_clusters)
372
+ # if download_clusters:
373
+ # st.session_state.clusters_downloaded = True
374
+ # st.write("Download clusters button pressed. Updated session state:", dict(st.session_state))
375
+
376
+ col1, _ = st.columns([.2, .8])
377
+ with col1:
378
+ fig = pie_chart(cluster_sizes)
379
+ st.pyplot(fig)
380
+ st.header("Sub Clustering the retrosynthetic routes - Resets every time when i interact with input widget")
381
+ sub = sublcuster_all(clusters, reactions_dict)
382
+ col2, _ = st.columns([.2, .8])
383
+ with col2:
384
+ user_input_cluster_num = st.number_input("Enter a number:", min_value=1,
385
+ max_value=max(clusters.keys()), value=1, step=1)
386
+
387
+ st.write(f"You entered the # cluster: {user_input_cluster_num}")
388
+ sub_step_cluster = sub[user_input_cluster_num]
389
+ allowed_numbers = sub_step_cluster.keys()
390
+ selected_number = st.selectbox("Choose a number:", allowed_numbers)
391
+ st.write(f"You entered number of steps: {selected_number}")
392
+ subclusters = sub_step_cluster[selected_number]
393
+
394
+ st.subheader(f"Found number of subclusters: {len(subclusters)}")
395
+ for subcluster_num, subcluster_set in enumerate(subclusters):
396
+ st.write(f"Subcluster #: {subcluster_num + 1}")
397
+ for route_id in subcluster_set:
398
+ st.write(f"Node_ID: {route_id}")
399
+ st.image(get_route_svg(tree, route_id), caption=f"Route {route_id}; {len(tree.synthesis_route(route_id))} steps; Route score: {round(tree.route_score(route_id), 3)}")  # caption is computed for this route, not left over from the cluster loop
400
+
401
+ else:
402
  st.write("Found no reaction path.")
403
 
404
  st.divider()
405
  st.header('Restart from the beginning?')
406
  if st.button("Restart"):
407
+ st.session_state.update(planning_done=False, tree=None, res=None, clustering_started=False, clusters_downloaded=False)  # clear every planning/clustering flag so a restart does not reuse the previous tree or auto-run clustering
408
  st.rerun()