Gilmullin Almaz commited on
Commit
81a78e7
·
1 Parent(s): 59ff193

Refactor molecule input handling to implement two-way synchronization and replace ReducedRouteCGR with SB-CGR, enhancing state management and visualization consistency.

Browse files
Files changed (1) hide show
  1. synplan/interfaces/gui.py +190 -171
synplan/interfaces/gui.py CHANGED
@@ -26,6 +26,7 @@ from synplan.utils.visualisation import (
26
  generate_results_html,
27
  html_top_routes_cluster,
28
  get_route_svg,
 
29
  )
30
  from synplan.utils.config import TreeConfig, PolicyNetworkConfig
31
  from synplan.utils.loading import load_reaction_rules, load_building_blocks
@@ -169,8 +170,10 @@ def initialize_app():
169
  st.session_state.num_clusters_setting = 10
170
  if "route_cgrs_dict" not in st.session_state:
171
  st.session_state.route_cgrs_dict = None
172
- if "r_route_cgrs_dict" not in st.session_state:
173
- st.session_state.r_route_cgrs_dict = None
 
 
174
 
175
  # Subclustering state
176
  if "subclustering_done" not in st.session_state:
@@ -219,7 +222,7 @@ def setup_sidebar():
219
 
220
 
221
  def handle_molecule_input():
222
- """3. Molecule Input: Managing the input area for molecule data."""
223
  st.header("Molecule input")
224
  st.markdown(
225
  """
@@ -228,42 +231,60 @@ def handle_molecule_input():
228
  * Draw it + Apply
229
  """
230
  )
231
- # Use st.session_state.ketcher to persist drawn molecule
232
- molecule_text_input = st.text_input(
233
- "SMILES:", value=st.session_state.ketcher, key="smiles_text_input_key"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  )
235
 
236
- smile_code_ketcher = st_ketcher(molecule_text_input, key="ketcher_widget")
237
- # col_kethcer, col_info = st.columns([0.8, 0.2])
238
- # with col_kethcer:
239
- # smile_code_ketcher = st_ketcher(molecule_text_input, key="ketcher_widget")
240
- # with col_info:
241
- # st.subheader("Synthetic Complexity")
242
- # sascore = ()
243
- # st.markdown(f"SAScore: {sascore}")
244
- # syba_score = ()
245
- # st.markdown(f"SYBA: {sascore}")
246
-
247
- current_smile_code = (
248
- smile_code_ketcher # The output from ketcher is the definitive SMILES
249
  )
250
 
 
 
 
 
 
 
 
 
251
  if (
252
- "target_smiles" in st.session_state
253
- and current_smile_code != st.session_state.target_smiles
 
254
  ):
255
- st.warning("Molecule structure changed. Please re-run planning.")
256
- st.session_state.planning_done = False
257
- st.session_state.clustering_done = False
258
- st.session_state.subclustering_done = False
259
- st.session_state.tree = None
260
- st.session_state.res = None
261
- st.session_state.clusters = None
262
- st.session_state.reactions_dict = None
263
- st.session_state.subclusters = None
264
- st.session_state.ketcher = current_smile_code
265
 
266
- return current_smile_code
 
 
 
 
267
 
268
 
269
  def setup_planning_options():
@@ -272,9 +293,7 @@ def setup_planning_options():
272
  st.markdown(
273
  """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
274
  )
275
- # This smile_code display will be updated if handle_molecule_input has run and returned a new smile_code
276
- # However, to display it correctly, we need the current smile_code from the session or input handler.
277
- # For simplicity, let's assume handle_molecule_input has updated st.session_state.ketcher
278
  st.markdown(
279
  f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
280
  )
@@ -361,7 +380,8 @@ def setup_planning_options():
361
  st.session_state.reactions_dict = None
362
  st.session_state.subclusters = None
363
  st.session_state.route_cgrs_dict = None
364
- st.session_state.r_route_cgrs_dict = None
 
365
  active_smile_code = st.session_state.get(
366
  "ketcher", DEFAULT_MOL
367
  ) # Get current SMILES
@@ -370,7 +390,7 @@ def setup_planning_options():
370
  )
371
 
372
  try:
373
- target_molecule = mol_from_smiles(active_smile_code)
374
  if target_molecule is None:
375
  st.error(f"Could not parse the input SMILES: {active_smile_code}")
376
  else:
@@ -419,7 +439,7 @@ def setup_planning_options():
419
 
420
  mcts_progress_text = "Running MCTS iterations..."
421
  mcts_bar = st.progress(0, text=mcts_progress_text)
422
- for step, (solved, node_id) in enumerate(tree):
423
  progress_value = min(
424
  1.0, (step + 1) / planning_params["max_iterations"]
425
  )
@@ -464,34 +484,35 @@ def display_planning_results():
464
 
465
  st.subheader("Examples of found retrosynthetic routes")
466
  image_counter = 0
467
- visualised_node_ids = set()
468
 
469
  if not winning_nodes:
470
  st.warning(
471
  "Planning solved, but no winning nodes found in the tree object."
472
  )
473
  else:
474
- for n, node_id in enumerate(winning_nodes):
475
  if image_counter >= 3:
476
  break
477
- if node_id not in visualised_node_ids:
478
  try:
479
- visualised_node_ids.add(node_id)
480
- num_steps = len(tree.synthesis_route(node_id))
481
- route_score = round(tree.route_score(node_id), 3)
482
- svg = get_route_svg(tree, node_id)
 
483
  if svg:
484
  st.image(
485
  svg,
486
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
487
  )
488
  image_counter += 1
489
  else:
490
  st.warning(
491
- f"Could not generate SVG for route {node_id}."
492
  )
493
  except Exception as e:
494
- st.error(f"Error displaying route {node_id}: {e}")
495
  else: # Not solved
496
  st.header("Planning Results")
497
  st.warning(
@@ -583,12 +604,6 @@ def setup_clustering():
583
  st.divider()
584
  st.header("Clustering the retrosynthetic routes")
585
 
586
- # num_clusters_input = st.number_input( # This input was removed in the final user code, so omitting.
587
- # "Desired Number of Clusters (approximate):",
588
- # min_value=2, max_value=50, value=st.session_state.get("num_clusters_setting", 10),
589
- # key="num_clusters_input_key"
590
- # )
591
-
592
  if st.button("Run Clustering", key="submit_clustering_button"):
593
  # st.session_state.num_clusters_setting = num_clusters_input
594
  st.session_state.clustering_done = False
@@ -597,7 +612,8 @@ def setup_clustering():
597
  st.session_state.reactions_dict = None
598
  st.session_state.subclusters = None
599
  st.session_state.route_cgrs_dict = None
600
- st.session_state.r_route_cgrs_dict = None
 
601
 
602
  with st.spinner("Performing clustering..."):
603
  try:
@@ -608,19 +624,20 @@ def setup_clustering():
608
 
609
  st.write("Calculating RoutesCGRs...")
610
  route_cgrs_dict = compose_all_route_cgrs(current_tree)
611
- st.write("Processing ReducedRoutesCGRs...")
612
- r_route_cgrs_dict = compose_all_reduced_route_cgrs(route_cgrs_dict)
613
 
614
  results = cluster_routes(
615
- r_route_cgrs_dict, use_strat=False
616
  ) # num_clusters was removed from args
617
  results = dict(sorted(results.items(), key=lambda x: float(x[0])))
618
 
619
  st.session_state.clusters = results
620
  st.session_state.route_cgrs_dict = route_cgrs_dict
621
- st.session_state.r_route_cgrs_dict = r_route_cgrs_dict
622
  st.write("Extracting reactions...")
623
  st.session_state.reactions_dict = extract_reactions(current_tree)
 
624
 
625
  if (
626
  st.session_state.clusters is not None
@@ -634,7 +651,7 @@ def setup_clustering():
634
  st.error("Clustering failed or returned empty results.")
635
  st.session_state.clustering_done = False
636
 
637
- del results # route_cgrs_dict, r_route_cgrs_dict are stored
638
  gc.collect()
639
  st.rerun()
640
  except Exception as e:
@@ -667,49 +684,50 @@ def display_clustering_results():
667
  for cluster_num, group_data in first_items:
668
  if (
669
  not group_data
670
- or "node_ids" not in group_data
671
- or not group_data["node_ids"]
672
  ):
673
- st.warning(f"Cluster {cluster_num} has no data or node_ids.")
674
  continue
675
  st.markdown(
676
  f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
677
  )
678
- node_id = group_data["node_ids"][0]
679
  try:
680
- num_steps = len(tree.synthesis_route(node_id))
681
- route_score = round(tree.route_score(node_id), 3)
682
- svg = get_route_svg(tree, node_id)
683
- r_route_cgr = group_data.get("r_route_cgr") # Safely get r_route_cgr
684
- r_route_cgr_svg = None
685
- if r_route_cgr:
686
- r_route_cgr.clean2d()
687
- r_route_cgr_svg = cgr_display(r_route_cgr)
688
-
689
- if svg and r_route_cgr_svg:
 
690
  col1, col2 = st.columns([0.2, 0.8])
691
  with col1:
692
- st.image(r_route_cgr_svg, caption="ReducedRouteCGR")
693
  with col2:
694
  st.image(
695
  svg,
696
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
697
  )
698
  elif svg: # Only route SVG available
699
  st.image(
700
  svg,
701
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
702
  )
703
  st.warning(
704
- f"ReducedRouteCGR could not be displayed for cluster {cluster_num}."
705
  )
706
  else:
707
  st.warning(
708
- f"Could not generate SVG for route {node_id} or its ReducedRouteCGR."
709
  )
710
  except Exception as e:
711
  st.error(
712
- f"Error displaying route {node_id} for cluster {cluster_num}: {e}"
713
  )
714
 
715
  if remaining_items:
@@ -717,51 +735,52 @@ def display_clustering_results():
717
  for cluster_num, group_data in remaining_items:
718
  if (
719
  not group_data
720
- or "node_ids" not in group_data
721
- or not group_data["node_ids"]
722
  ):
723
  st.warning(
724
- f"Cluster {cluster_num} in expansion has no data or node_ids."
725
  )
726
  continue
727
  st.markdown(
728
  f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
729
  )
730
- node_id = group_data["node_ids"][0]
731
  try:
732
- num_steps = len(tree.synthesis_route(node_id))
733
- route_score = round(tree.route_score(node_id), 3)
734
- svg = get_route_svg(tree, node_id)
735
- r_route_cgr = group_data.get("r_route_cgr")
736
- r_route_cgr_svg = None
737
- if r_route_cgr:
738
- r_route_cgr.clean2d()
739
- r_route_cgr_svg = cgr_display(r_route_cgr)
740
-
741
- if svg and r_route_cgr_svg:
 
742
  col1, col2 = st.columns([0.2, 0.8])
743
  with col1:
744
- st.image(r_route_cgr_svg, caption="ReducedRouteCGR")
745
  with col2:
746
  st.image(
747
  svg,
748
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
749
  )
750
  elif svg:
751
  st.image(
752
  svg,
753
- caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
754
  )
755
  st.warning(
756
- f"ReducedRouteCGR could not be displayed for cluster {cluster_num}."
757
  )
758
  else:
759
  st.warning(
760
- f"Could not generate SVG for route {node_id} or its ReducedRouteCGR."
761
  )
762
  except Exception as e:
763
  st.error(
764
- f"Error displaying route {node_id} for cluster {cluster_num}: {e}"
765
  )
766
 
767
 
@@ -770,8 +789,8 @@ def download_clustering_results():
770
  if st.session_state.get("clustering_done", False):
771
  tree_for_html = st.session_state.get("tree")
772
  clusters_for_html = st.session_state.get("clusters")
773
- r_route_cgrs_for_html = st.session_state.get(
774
- "r_route_cgrs_dict"
775
  ) # This was used instead of reactions_dict in the original for report
776
 
777
  if not tree_for_html:
@@ -780,7 +799,7 @@ def download_clustering_results():
780
  if not clusters_for_html:
781
  st.warning("Cluster data not found. Cannot generate cluster reports.")
782
  return
783
- # r_route_cgrs_for_html is optional for routes_clustering_report if not essential
784
 
785
  st.subheader("Cluster Reports") # Changed subheader in original
786
  st.write("Generate downloadable HTML reports for each cluster:")
@@ -799,7 +818,7 @@ def download_clustering_results():
799
  tree_for_html,
800
  clusters_for_html, # Pass the whole dict
801
  str(cluster_idx), # Pass the key of the cluster
802
- r_route_cgrs_for_html, # Pass the r_route_cgrs dict
803
  aam=False,
804
  )
805
  st.download_button(
@@ -826,7 +845,7 @@ def download_clustering_results():
826
  tree_for_html,
827
  clusters_for_html,
828
  str(group_index),
829
- r_route_cgrs_for_html,
830
  aam=False,
831
  )
832
  st.download_button(
@@ -851,7 +870,7 @@ def download_clustering_results():
851
  tree_for_html,
852
  clusters_for_html,
853
  str(idx),
854
- r_route_cgrs_for_html,
855
  aam=False,
856
  )
857
  filename = f"cluster_{idx}_{st.session_state.target_smiles}.html"
@@ -883,19 +902,17 @@ def setup_subclustering():
883
  with st.spinner("Performing subclustering analysis..."):
884
  try:
885
  clusters_for_sub = st.session_state.get("clusters")
886
- r_route_cgrs_dict_for_sub = st.session_state.get(
887
- "r_route_cgrs_dict"
888
- )
889
  route_cgrs_dict_for_sub = st.session_state.get("route_cgrs_dict")
890
 
891
  if (
892
  clusters_for_sub
893
- and r_route_cgrs_dict_for_sub
894
  and route_cgrs_dict_for_sub
895
  ): # Ensure all are present
896
  all_subgroups = subcluster_all_clusters(
897
  clusters_for_sub,
898
- r_route_cgrs_dict_for_sub,
899
  route_cgrs_dict_for_sub,
900
  )
901
  st.session_state.subclusters = all_subgroups
@@ -907,8 +924,8 @@ def setup_subclustering():
907
  missing = []
908
  if not clusters_for_sub:
909
  missing.append("clusters")
910
- if not r_route_cgrs_dict_for_sub:
911
- missing.append("ReducedRouteCGRs dictionary")
912
  if not route_cgrs_dict_for_sub:
913
  missing.append("RouteCGRs dictionary")
914
  st.error(
@@ -970,17 +987,15 @@ def display_subclustering_results():
970
  current_subcluster_data = sub[user_input_cluster_num_display][
971
  selected_subcluster_idx
972
  ]
973
- if "r_route_cgr" in current_subcluster_data:
974
- cluster_r_route_cgr_display = current_subcluster_data[
975
- "r_route_cgr"
976
- ]
977
- cluster_r_route_cgr_display.clean2d()
978
  st.image(
979
- cluster_r_route_cgr_display.depict(),
980
- caption=f"ReducedRouteCGR of parent Cluster {user_input_cluster_num_display}",
981
  )
982
  else:
983
- st.warning("ReducedRouteCGR for this subcluster not found.")
984
  else:
985
  st.warning(
986
  f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
@@ -1002,14 +1017,14 @@ def display_subclustering_results():
1002
  subcluster_to_display = subcluster_content
1003
  if (
1004
  not subcluster_to_display
1005
- or "nodes_data" not in subcluster_to_display
1006
- or not subcluster_to_display["nodes_data"]
1007
  ):
1008
  st.info("No routes or data found for this subcluster selection.")
1009
  else:
1010
  MAX_ROUTES_PER_SUBCLUSTER = 5
1011
  all_route_ids_in_subcluster = list(
1012
- subcluster_to_display["nodes_data"].keys()
1013
  )
1014
  routes_to_display_direct = all_route_ids_in_subcluster[
1015
  :MAX_ROUTES_PER_SUBCLUSTER
@@ -1025,6 +1040,7 @@ def display_subclustering_results():
1025
  if "synthon_reaction" in subcluster_to_display:
1026
  synthon_reaction = subcluster_to_display["synthon_reaction"]
1027
  try:
 
1028
  st.image(
1029
  depict_custom_reaction(synthon_reaction),
1030
  caption=f"Markush-like pseudo reaction of subcluster",
@@ -1033,48 +1049,50 @@ def display_subclustering_results():
1033
  st.warning(f"Could not depict synthon reaction: {e_depict}")
1034
  else:
1035
  st.info("No synthon reaction data for this subcluster.")
1036
-
1037
- for route_id in routes_to_display_direct:
1038
- try:
1039
- route_score_sub = round(tree.route_score(route_id), 3)
1040
- svg_sub = get_route_svg(tree, route_id)
1041
- if svg_sub:
1042
- st.image(
1043
- svg_sub,
1044
- caption=f"Route {route_id}; Score: {route_score_sub}",
1045
- )
1046
- else:
1047
- st.warning(
1048
- f"Could not generate SVG for route {route_id}."
 
 
 
 
 
1049
  )
1050
- except Exception as e:
1051
- st.error(
1052
- f"Error displaying route {route_id} in subcluster: {e}"
1053
- )
1054
 
1055
- if remaining_routes_sub:
1056
- with st.expander(
1057
- f"... and {len(remaining_routes_sub)} more routes in this subcluster"
1058
- ):
1059
- for route_id in remaining_routes_sub:
1060
- try:
1061
- route_score_sub = round(
1062
- tree.route_score(route_id), 3
1063
- )
1064
- svg_sub = get_route_svg(tree, route_id)
1065
- if svg_sub:
1066
- st.image(
1067
- svg_sub,
1068
- caption=f"Route {route_id}; Score: {route_score_sub}",
1069
  )
1070
- else:
1071
- st.warning(
1072
- f"Could not generate SVG for route {route_id}."
 
 
 
 
 
 
 
 
 
 
 
1073
  )
1074
- except Exception as e:
1075
- st.error(
1076
- f"Error displaying route {route_id} in subcluster (expanded): {e}"
1077
- )
1078
  else:
1079
  st.info("Select a valid cluster and subcluster index to see details.")
1080
 
@@ -1089,16 +1107,16 @@ def download_subclustering_results():
1089
 
1090
  sub = st.session_state.get("subclusters")
1091
  tree = st.session_state.get("tree")
1092
- r_route_cgrs_for_report = st.session_state.get(
1093
- "r_route_cgrs_dict"
1094
  ) # Used by routes_subclustering_report
1095
 
1096
  user_input_cluster_num_display = st.session_state.subcluster_num_select_key
1097
  selected_subcluster_idx = st.session_state.subcluster_index_select_key
1098
 
1099
- if not tree or not sub or not r_route_cgrs_for_report:
1100
  st.warning(
1101
- "Missing data for subclustering report generation (tree, subclusters, or ReducedRouteCGRs)."
1102
  )
1103
  return
1104
 
@@ -1114,11 +1132,11 @@ def download_subclustering_results():
1114
  processed_subcluster_data = post_process_subgroup(
1115
  subcluster_data_for_report
1116
  )
1117
- if "nodes_data" in subcluster_data_for_report and isinstance(
1118
- subcluster_data_for_report["nodes_data"], dict
1119
  ):
1120
  processed_subcluster_data["group_lgs"] = group_by_identical_values(
1121
- subcluster_data_for_report["nodes_data"]
1122
  )
1123
  else:
1124
  processed_subcluster_data["group_lgs"] = {}
@@ -1129,7 +1147,7 @@ def download_subclustering_results():
1129
  processed_subcluster_data, # Pass the specific post-processed subcluster data
1130
  user_input_cluster_num_display,
1131
  selected_subcluster_idx,
1132
- r_route_cgrs_for_report, # Pass the whole r_route_cgrs dict
1133
  if_lg_group=True, # This parameter was in the original call
1134
  )
1135
  st.download_button(
@@ -1162,7 +1180,8 @@ def implement_restart():
1162
  "reactions_dict",
1163
  "num_clusters_setting",
1164
  "route_cgrs_dict",
1165
- "r_route_cgrs_dict",
 
1166
  "subclustering_done",
1167
  "subclusters", # "sub" was renamed
1168
  "clusters_downloaded",
 
26
  generate_results_html,
27
  html_top_routes_cluster,
28
  get_route_svg,
29
+ get_route_svg_from_json
30
  )
31
  from synplan.utils.config import TreeConfig, PolicyNetworkConfig
32
  from synplan.utils.loading import load_reaction_rules, load_building_blocks
 
170
  st.session_state.num_clusters_setting = 10
171
  if "route_cgrs_dict" not in st.session_state:
172
  st.session_state.route_cgrs_dict = None
173
+ if "sb_cgrs_dict" not in st.session_state:
174
+ st.session_state.sb_cgrs_dict = None
175
+ if "route_json" not in st.session_state:
176
+ st.session_state.route_json = None
177
 
178
  # Subclustering state
179
  if "subclustering_done" not in st.session_state:
 
222
 
223
 
224
  def handle_molecule_input():
225
+ """3. Molecule Input: Managing the input area for molecule data with two-way synchronization."""
226
  st.header("Molecule input")
227
  st.markdown(
228
  """
 
231
  * Draw it + Apply
232
  """
233
  )
234
+
235
+ if "shared_smiles" not in st.session_state:
236
+ st.session_state.shared_smiles = st.session_state.get("ketcher", DEFAULT_MOL)
237
+
238
+ if "ketcher_render_count" not in st.session_state:
239
+ st.session_state.ketcher_render_count = 0
240
+
241
+ def text_input_changed_callback():
242
+ new_text_value = (
243
+ st.session_state.smiles_text_input_key_for_sync
244
+ ) # Key of the text_input
245
+ if new_text_value != st.session_state.shared_smiles:
246
+ st.session_state.shared_smiles = new_text_value
247
+ st.session_state.ketcher = new_text_value
248
+ st.session_state.ketcher_render_count += 1
249
+
250
+ # SMILES Text Input
251
+ st.text_input(
252
+ "SMILES:",
253
+ value=st.session_state.shared_smiles,
254
+ key="smiles_text_input_key_for_sync", # Unique key for this widget
255
+ on_change=text_input_changed_callback,
256
+ help="Enter SMILES string and press Enter. The drawing will update, and vice-versa.",
257
  )
258
 
259
+ ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
260
+ smile_code_output_from_ketcher = st_ketcher(
261
+ st.session_state.shared_smiles, key=ketcher_key
 
 
 
 
 
 
 
 
 
 
262
  )
263
 
264
+ if smile_code_output_from_ketcher != st.session_state.shared_smiles:
265
+ st.session_state.shared_smiles = smile_code_output_from_ketcher
266
+ st.session_state.ketcher = smile_code_output_from_ketcher
267
+ st.rerun()
268
+
269
+ current_smiles_for_planning = st.session_state.shared_smiles
270
+
271
+ last_planned_smiles = st.session_state.get("target_smiles")
272
  if (
273
+ last_planned_smiles
274
+ and current_smiles_for_planning != last_planned_smiles
275
+ and st.session_state.get("planning_done", False)
276
  ):
277
+ st.warning(
278
+ "Molecule structure has changed since the last successful planning run. "
279
+ "Results shown below (if any) are for the previous molecule. "
280
+ "Please re-run planning for the current structure."
281
+ )
 
 
 
 
 
282
 
283
+ # Ensure st.session_state.ketcher is consistent for other parts of the app
284
+ if st.session_state.get("ketcher") != current_smiles_for_planning:
285
+ st.session_state.ketcher = current_smiles_for_planning
286
+
287
+ return current_smiles_for_planning
288
 
289
 
290
  def setup_planning_options():
 
293
  st.markdown(
294
  """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
295
  )
296
+
 
 
297
  st.markdown(
298
  f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
299
  )
 
380
  st.session_state.reactions_dict = None
381
  st.session_state.subclusters = None
382
  st.session_state.route_cgrs_dict = None
383
+ st.session_state.sb_cgrs_dict = None
384
+ st.session_state.route_json = None
385
  active_smile_code = st.session_state.get(
386
  "ketcher", DEFAULT_MOL
387
  ) # Get current SMILES
 
390
  )
391
 
392
  try:
393
+ target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
394
  if target_molecule is None:
395
  st.error(f"Could not parse the input SMILES: {active_smile_code}")
396
  else:
 
439
 
440
  mcts_progress_text = "Running MCTS iterations..."
441
  mcts_bar = st.progress(0, text=mcts_progress_text)
442
+ for step, (solved, route_id) in enumerate(tree):
443
  progress_value = min(
444
  1.0, (step + 1) / planning_params["max_iterations"]
445
  )
 
484
 
485
  st.subheader("Examples of found retrosynthetic routes")
486
  image_counter = 0
487
+ visualised_route_ids = set()
488
 
489
  if not winning_nodes:
490
  st.warning(
491
  "Planning solved, but no winning nodes found in the tree object."
492
  )
493
  else:
494
+ for n, route_id in enumerate(winning_nodes):
495
  if image_counter >= 3:
496
  break
497
+ if route_id not in visualised_route_ids:
498
  try:
499
+ visualised_route_ids.add(route_id)
500
+ num_steps = len(tree.synthesis_route(route_id))
501
+ route_score = round(tree.route_score(route_id), 3)
502
+ svg = get_route_svg(tree, route_id)
503
+ # svg = get_route_svg_from_json(st.session_state.route_json, route_id)
504
  if svg:
505
  st.image(
506
  svg,
507
+ caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
508
  )
509
  image_counter += 1
510
  else:
511
  st.warning(
512
+ f"Could not generate SVG for route {route_id}."
513
  )
514
  except Exception as e:
515
+ st.error(f"Error displaying route {route_id}: {e}")
516
  else: # Not solved
517
  st.header("Planning Results")
518
  st.warning(
 
604
  st.divider()
605
  st.header("Clustering the retrosynthetic routes")
606
 
 
 
 
 
 
 
607
  if st.button("Run Clustering", key="submit_clustering_button"):
608
  # st.session_state.num_clusters_setting = num_clusters_input
609
  st.session_state.clustering_done = False
 
612
  st.session_state.reactions_dict = None
613
  st.session_state.subclusters = None
614
  st.session_state.route_cgrs_dict = None
615
+ st.session_state.sb_cgrs_dict = None
616
+ st.session_state.route_json = None
617
 
618
  with st.spinner("Performing clustering..."):
619
  try:
 
624
 
625
  st.write("Calculating RoutesCGRs...")
626
  route_cgrs_dict = compose_all_route_cgrs(current_tree)
627
+ st.write("Processing SB-CGRs...")
628
+ sb_cgrs_dict = compose_all_sb_cgrs(route_cgrs_dict)
629
 
630
  results = cluster_routes(
631
+ sb_cgrs_dict, use_strat=False
632
  ) # num_clusters was removed from args
633
  results = dict(sorted(results.items(), key=lambda x: float(x[0])))
634
 
635
  st.session_state.clusters = results
636
  st.session_state.route_cgrs_dict = route_cgrs_dict
637
+ st.session_state.sb_cgrs_dict = sb_cgrs_dict
638
  st.write("Extracting reactions...")
639
  st.session_state.reactions_dict = extract_reactions(current_tree)
640
+ st.session_state.route_json = make_json(st.session_state.reactions_dict)
641
 
642
  if (
643
  st.session_state.clusters is not None
 
651
  st.error("Clustering failed or returned empty results.")
652
  st.session_state.clustering_done = False
653
 
654
+ del results # route_cgrs_dict, sb_cgrs_dict are stored
655
  gc.collect()
656
  st.rerun()
657
  except Exception as e:
 
684
  for cluster_num, group_data in first_items:
685
  if (
686
  not group_data
687
+ or "route_ids" not in group_data
688
+ or not group_data["route_ids"]
689
  ):
690
+ st.warning(f"Cluster {cluster_num} has no data or route_ids.")
691
  continue
692
  st.markdown(
693
  f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
694
  )
695
+ route_id = group_data["route_ids"][0]
696
  try:
697
+ num_steps = len(tree.synthesis_route(route_id))
698
+ route_score = round(tree.route_score(route_id), 3)
699
+ # svg = get_route_svg(tree, route_id)
700
+ svg = get_route_svg_from_json(st.session_state.route_json, route_id)
701
+ sb_cgr = group_data.get("sb_cgr") # Safely get sb_cgr
702
+ sb_cgr_svg = None
703
+ if sb_cgr:
704
+ sb_cgr.clean2d()
705
+ sb_cgr_svg = cgr_display(sb_cgr)
706
+
707
+ if svg and sb_cgr_svg:
708
  col1, col2 = st.columns([0.2, 0.8])
709
  with col1:
710
+ st.image(sb_cgr_svg, caption="SB-CGR")
711
  with col2:
712
  st.image(
713
  svg,
714
+ caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
715
  )
716
  elif svg: # Only route SVG available
717
  st.image(
718
  svg,
719
+ caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
720
  )
721
  st.warning(
722
+ f"SB-CGR could not be displayed for cluster {cluster_num}."
723
  )
724
  else:
725
  st.warning(
726
+ f"Could not generate SVG for route {route_id} or its SB-CGR."
727
  )
728
  except Exception as e:
729
  st.error(
730
+ f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
731
  )
732
 
733
  if remaining_items:
 
735
  for cluster_num, group_data in remaining_items:
736
  if (
737
  not group_data
738
+ or "route_ids" not in group_data
739
+ or not group_data["route_ids"]
740
  ):
741
  st.warning(
742
+ f"Cluster {cluster_num} in expansion has no data or route_ids."
743
  )
744
  continue
745
  st.markdown(
746
  f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
747
  )
748
+ route_id = group_data["route_ids"][0]
749
  try:
750
+ num_steps = len(tree.synthesis_route(route_id))
751
+ route_score = round(tree.route_score(route_id), 3)
752
+ # svg = get_route_svg(tree, route_id)
753
+ svg = get_route_svg_from_json(st.session_state.route_json, route_id)
754
+ sb_cgr = group_data.get("sb_cgr")
755
+ sb_cgr_svg = None
756
+ if sb_cgr:
757
+ sb_cgr.clean2d()
758
+ sb_cgr_svg = cgr_display(sb_cgr)
759
+
760
+ if svg and sb_cgr_svg:
761
  col1, col2 = st.columns([0.2, 0.8])
762
  with col1:
763
+ st.image(sb_cgr_svg, caption="SB-CGR")
764
  with col2:
765
  st.image(
766
  svg,
767
+ caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
768
  )
769
  elif svg:
770
  st.image(
771
  svg,
772
+ caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
773
  )
774
  st.warning(
775
+ f"SB-CGR could not be displayed for cluster {cluster_num}."
776
  )
777
  else:
778
  st.warning(
779
+ f"Could not generate SVG for route {route_id} or its SB-CGR."
780
  )
781
  except Exception as e:
782
  st.error(
783
+ f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
784
  )
785
 
786
 
 
789
  if st.session_state.get("clustering_done", False):
790
  tree_for_html = st.session_state.get("tree")
791
  clusters_for_html = st.session_state.get("clusters")
792
+ sb_cgrs_for_html = st.session_state.get(
793
+ "sb_cgrs_dict"
794
  ) # This was used instead of reactions_dict in the original for report
795
 
796
  if not tree_for_html:
 
799
  if not clusters_for_html:
800
  st.warning("Cluster data not found. Cannot generate cluster reports.")
801
  return
802
+ # sb_cgrs_for_html is optional for routes_clustering_report if not essential
803
 
804
  st.subheader("Cluster Reports") # Changed subheader in original
805
  st.write("Generate downloadable HTML reports for each cluster:")
 
818
  tree_for_html,
819
  clusters_for_html, # Pass the whole dict
820
  str(cluster_idx), # Pass the key of the cluster
821
+ sb_cgrs_for_html, # Pass the sb_cgrs dict
822
  aam=False,
823
  )
824
  st.download_button(
 
845
  tree_for_html,
846
  clusters_for_html,
847
  str(group_index),
848
+ sb_cgrs_for_html,
849
  aam=False,
850
  )
851
  st.download_button(
 
870
  tree_for_html,
871
  clusters_for_html,
872
  str(idx),
873
+ sb_cgrs_for_html,
874
  aam=False,
875
  )
876
  filename = f"cluster_{idx}_{st.session_state.target_smiles}.html"
 
902
  with st.spinner("Performing subclustering analysis..."):
903
  try:
904
  clusters_for_sub = st.session_state.get("clusters")
905
+ sb_cgrs_dict_for_sub = st.session_state.get("sb_cgrs_dict")
 
 
906
  route_cgrs_dict_for_sub = st.session_state.get("route_cgrs_dict")
907
 
908
  if (
909
  clusters_for_sub
910
+ and sb_cgrs_dict_for_sub
911
  and route_cgrs_dict_for_sub
912
  ): # Ensure all are present
913
  all_subgroups = subcluster_all_clusters(
914
  clusters_for_sub,
915
+ sb_cgrs_dict_for_sub,
916
  route_cgrs_dict_for_sub,
917
  )
918
  st.session_state.subclusters = all_subgroups
 
924
  missing = []
925
  if not clusters_for_sub:
926
  missing.append("clusters")
927
+ if not sb_cgrs_dict_for_sub:
928
+ missing.append("SB-CGRs dictionary")
929
  if not route_cgrs_dict_for_sub:
930
  missing.append("RouteCGRs dictionary")
931
  st.error(
 
987
  current_subcluster_data = sub[user_input_cluster_num_display][
988
  selected_subcluster_idx
989
  ]
990
+ if "sb_cgr" in current_subcluster_data:
991
+ cluster_sb_cgr_display = current_subcluster_data["sb_cgr"]
992
+ cluster_sb_cgr_display.clean2d()
 
 
993
  st.image(
994
+ cluster_sb_cgr_display.depict(),
995
+ caption=f"SB-CGR of parent Cluster {user_input_cluster_num_display}",
996
  )
997
  else:
998
+ st.warning("SB-CGR for this subcluster not found.")
999
  else:
1000
  st.warning(
1001
  f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
 
1017
  subcluster_to_display = subcluster_content
1018
  if (
1019
  not subcluster_to_display
1020
+ or "routes_data" not in subcluster_to_display
1021
+ or not subcluster_to_display["routes_data"]
1022
  ):
1023
  st.info("No routes or data found for this subcluster selection.")
1024
  else:
1025
  MAX_ROUTES_PER_SUBCLUSTER = 5
1026
  all_route_ids_in_subcluster = list(
1027
+ subcluster_to_display["routes_data"].keys()
1028
  )
1029
  routes_to_display_direct = all_route_ids_in_subcluster[
1030
  :MAX_ROUTES_PER_SUBCLUSTER
 
1040
  if "synthon_reaction" in subcluster_to_display:
1041
  synthon_reaction = subcluster_to_display["synthon_reaction"]
1042
  try:
1043
+ synthon_reaction.clean2d()
1044
  st.image(
1045
  depict_custom_reaction(synthon_reaction),
1046
  caption=f"Markush-like pseudo reaction of subcluster",
 
1049
  st.warning(f"Could not depict synthon reaction: {e_depict}")
1050
  else:
1051
  st.info("No synthon reaction data for this subcluster.")
1052
+ with st.container(height=500):
1053
+ for route_id in routes_to_display_direct:
1054
+ try:
1055
+ route_score_sub = round(tree.route_score(route_id), 3)
1056
+ # svg_sub = get_route_svg(tree, route_id)
1057
+ svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
1058
+ if svg_sub:
1059
+ st.image(
1060
+ svg_sub,
1061
+ caption=f"Route {route_id}; Score: {route_score_sub}",
1062
+ )
1063
+ else:
1064
+ st.warning(
1065
+ f"Could not generate SVG for route {route_id}."
1066
+ )
1067
+ except Exception as e:
1068
+ st.error(
1069
+ f"Error displaying route {route_id} in subcluster: {e}"
1070
  )
 
 
 
 
1071
 
1072
+ if remaining_routes_sub:
1073
+ with st.expander(
1074
+ f"... and {len(remaining_routes_sub)} more routes in this subcluster"
1075
+ ):
1076
+ for route_id in remaining_routes_sub:
1077
+ try:
1078
+ route_score_sub = round(
1079
+ tree.route_score(route_id), 3
 
 
 
 
 
 
1080
  )
1081
+ # svg_sub = get_route_svg(tree, route_id)
1082
+ svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
1083
+ if svg_sub:
1084
+ st.image(
1085
+ svg_sub,
1086
+ caption=f"Route {route_id}; Score: {route_score_sub}",
1087
+ )
1088
+ else:
1089
+ st.warning(
1090
+ f"Could not generate SVG for route {route_id}."
1091
+ )
1092
+ except Exception as e:
1093
+ st.error(
1094
+ f"Error displaying route {route_id} in subcluster (expanded): {e}"
1095
  )
 
 
 
 
1096
  else:
1097
  st.info("Select a valid cluster and subcluster index to see details.")
1098
 
 
1107
 
1108
  sub = st.session_state.get("subclusters")
1109
  tree = st.session_state.get("tree")
1110
+ sb_cgrs_for_report = st.session_state.get(
1111
+ "sb_cgrs_dict"
1112
  ) # Used by routes_subclustering_report
1113
 
1114
  user_input_cluster_num_display = st.session_state.subcluster_num_select_key
1115
  selected_subcluster_idx = st.session_state.subcluster_index_select_key
1116
 
1117
+ if not tree or not sub or not sb_cgrs_for_report:
1118
  st.warning(
1119
+ "Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
1120
  )
1121
  return
1122
 
 
1132
  processed_subcluster_data = post_process_subgroup(
1133
  subcluster_data_for_report
1134
  )
1135
+ if "routes_data" in subcluster_data_for_report and isinstance(
1136
+ subcluster_data_for_report["routes_data"], dict
1137
  ):
1138
  processed_subcluster_data["group_lgs"] = group_by_identical_values(
1139
+ subcluster_data_for_report["routes_data"]
1140
  )
1141
  else:
1142
  processed_subcluster_data["group_lgs"] = {}
 
1147
  processed_subcluster_data, # Pass the specific post-processed subcluster data
1148
  user_input_cluster_num_display,
1149
  selected_subcluster_idx,
1150
+ sb_cgrs_for_report, # Pass the whole sb_cgrs dict
1151
  if_lg_group=True, # This parameter was in the original call
1152
  )
1153
  st.download_button(
 
1180
  "reactions_dict",
1181
  "num_clusters_setting",
1182
  "route_cgrs_dict",
1183
+ "sb_cgrs_dict",
1184
+ "route_json",
1185
  "subclustering_done",
1186
  "subclusters", # "sub" was renamed
1187
  "clusters_downloaded",