import base64
import pickle
import re
import uuid
import io
import zipfile
import pandas as pd
import streamlit as st
from CGRtools.files import SMILESRead
from streamlit_ketcher import st_ketcher
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.mcts.search import extract_tree_stats
from synplan.mcts.tree import Tree
from synplan.chem.utils import mol_from_smiles
from synplan.chem.reaction_routes.route_cgr import *
from synplan.chem.reaction_routes.clustering import *
from synplan.utils.visualisation import (
routes_clustering_report,
routes_subclustering_report,
generate_results_html,
html_top_routes_cluster,
get_route_svg,
get_route_svg_from_json
)
from synplan.utils.config import TreeConfig, PolicyNetworkConfig
from synplan.utils.loading import load_reaction_rules, load_building_blocks
import psutil
import gc
disable_progress_bars("huggingface_hub")
smiles_parser = SMILESRead.create_parser(ignore=True)
DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
# --- Helper Functions ---
def download_button(
object_to_download, download_filename, button_text, pickle_it=False
):
"""
Issued from
Generates a link to download the given object_to_download.
Params:
------
object_to_download: The object to be downloaded.
download_filename (str): filename and extension of file. e.g. mydata.csv,
some_txt_output.txt download_link_text (str): Text to display for download
link.
button_text (str): Text to display on download button (e.g. 'click here to download file')
pickle_it (bool): If True, pickle file.
Returns:
-------
(str): the anchor tag to download object_to_download
Examples:
--------
download_link(your_df, 'YOUR_DF.csv', 'Click to download data!')
download_link(your_str, 'YOUR_STRING.txt', 'Click to download text!')
"""
if pickle_it:
try:
object_to_download = pickle.dumps(object_to_download)
except pickle.PicklingError as e:
st.write(e)
return None
else:
if isinstance(object_to_download, bytes):
pass
elif isinstance(object_to_download, pd.DataFrame):
object_to_download = object_to_download.to_csv(index=False).encode("utf-8")
try:
b64 = base64.b64encode(object_to_download.encode()).decode()
except AttributeError:
b64 = base64.b64encode(object_to_download).decode()
button_uuid = str(uuid.uuid4()).replace("-", "")
button_id = re.sub("\d+", "", button_uuid)
custom_css = f"""
"""
dl_link = (
custom_css
+ f'{button_text}
'
)
return dl_link
@st.cache_resource
def load_planning_resources_cached(): # Renamed to avoid conflict if main calls it directly
building_blocks_path = hf_hub_download(
repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
filename="building_blocks_em_sa_ln.smi",
subfolder="building_blocks",
local_dir=".",
)
ranking_policy_weights_path = hf_hub_download(
repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
filename="ranking_policy_network.ckpt",
subfolder="uspto/weights",
local_dir=".",
)
reaction_rules_path = hf_hub_download(
repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
filename="uspto_reaction_rules.pickle",
subfolder="uspto",
local_dir=".",
)
return building_blocks_path, ranking_policy_weights_path, reaction_rules_path
# --- GUI Sections ---
def initialize_app():
"""1. Initialization: Setting up the main window, layout, and initial widgets."""
st.set_page_config(page_title="SynPlanner GUI", page_icon="๐งช", layout="wide")
# Initialize session state variables if they don't exist.
if "planning_done" not in st.session_state:
st.session_state.planning_done = False
if "tree" not in st.session_state:
st.session_state.tree = None
if "res" not in st.session_state:
st.session_state.res = None
if "target_smiles" not in st.session_state:
st.session_state.target_smiles = (
"" # Initial value, might be overwritten by ketcher
)
# Clustering state
if "clustering_done" not in st.session_state:
st.session_state.clustering_done = False
if "clusters" not in st.session_state:
st.session_state.clusters = None
if "reactions_dict" not in st.session_state:
st.session_state.reactions_dict = None
if "num_clusters_setting" not in st.session_state: # Store the setting used
st.session_state.num_clusters_setting = 10
if "route_cgrs_dict" not in st.session_state:
st.session_state.route_cgrs_dict = None
if "sb_cgrs_dict" not in st.session_state:
st.session_state.sb_cgrs_dict = None
if "route_json" not in st.session_state:
st.session_state.route_json = None
# Subclustering state
if "subclustering_done" not in st.session_state:
st.session_state.subclustering_done = False
if "subclusters" not in st.session_state: # Renamed from 'sub' for clarity
st.session_state.subclusters = None
# Download state (less critical now with direct download links)
if "clusters_downloaded" not in st.session_state: # Example, might not be needed
st.session_state.clusters_downloaded = False
if "ketcher" not in st.session_state: # For ketcher persistence
st.session_state.ketcher = DEFAULT_MOL
intro_text = """
This is a demo of the graphical user interface of
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.
More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
"""
st.title("`SynPlanner GUI`")
st.write(intro_text)
def setup_sidebar():
"""2. Sidebar: Handling the widgets and logic within the sidebar area."""
# st.sidebar.image("img/logo.png") # Assuming img/logo.png is available
st.sidebar.title("Docs")
st.sidebar.markdown("https://synplanner.readthedocs.io/en/latest/")
st.sidebar.title("Tutorials")
st.sidebar.markdown(
"https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials"
)
st.sidebar.title("Paper")
st.sidebar.markdown(
"https://chemrxiv.org/engage/chemrxiv/article-details/66add90bc9c6a5c07ae65796"
)
st.sidebar.title("Issues")
st.sidebar.markdown(
"[Report a bug ๐](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)"
)
def handle_molecule_input():
"""3. Molecule Input: Managing the input area for molecule data with two-way synchronization."""
st.header("Molecule input")
st.markdown(
"""
You can provide a molecular structure by either providing:
* SMILES string + Enter
* Draw it + Apply
"""
)
if "shared_smiles" not in st.session_state:
st.session_state.shared_smiles = st.session_state.get("ketcher", DEFAULT_MOL)
if "ketcher_render_count" not in st.session_state:
st.session_state.ketcher_render_count = 0
def text_input_changed_callback():
new_text_value = (
st.session_state.smiles_text_input_key_for_sync
) # Key of the text_input
if new_text_value != st.session_state.shared_smiles:
st.session_state.shared_smiles = new_text_value
st.session_state.ketcher = new_text_value
st.session_state.ketcher_render_count += 1
# SMILES Text Input
st.text_input(
"SMILES:",
value=st.session_state.shared_smiles,
key="smiles_text_input_key_for_sync", # Unique key for this widget
on_change=text_input_changed_callback,
help="Enter SMILES string and press Enter. The drawing will update, and vice-versa.",
)
ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
smile_code_output_from_ketcher = st_ketcher(
st.session_state.shared_smiles, key=ketcher_key
)
if smile_code_output_from_ketcher != st.session_state.shared_smiles:
st.session_state.shared_smiles = smile_code_output_from_ketcher
st.session_state.ketcher = smile_code_output_from_ketcher
st.rerun()
current_smiles_for_planning = st.session_state.shared_smiles
last_planned_smiles = st.session_state.get("target_smiles")
if (
last_planned_smiles
and current_smiles_for_planning != last_planned_smiles
and st.session_state.get("planning_done", False)
):
st.warning(
"Molecule structure has changed since the last successful planning run. "
"Results shown below (if any) are for the previous molecule. "
"Please re-run planning for the current structure."
)
# Ensure st.session_state.ketcher is consistent for other parts of the app
if st.session_state.get("ketcher") != current_smiles_for_planning:
st.session_state.ketcher = current_smiles_for_planning
return current_smiles_for_planning
def setup_planning_options():
"""4. Planning: Encapsulating the logic related to the "planning" functionality."""
st.header("Launch calculation")
st.markdown(
"""If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
)
st.markdown(
f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
)
st.subheader("Planning options")
st.markdown(
"""
The description of each option can be found in the
[Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
"""
)
col_options_1, col_options_2 = st.columns(2, gap="medium")
with col_options_1:
search_strategy_input = st.selectbox(
label="Search strategy",
options=(
"Expansion first",
"Evaluation first",
),
index=0,
key="search_strategy_input",
)
ucb_type = st.selectbox(
label="UCB type",
options=("uct", "puct", "value"),
index=0,
key="ucb_type_input",
)
c_ucb = st.number_input(
"C coefficient of UCB",
value=0.1,
placeholder="Type a number...",
key="c_ucb_input",
)
with col_options_2:
max_iterations = st.slider(
"Total number of MCTS iterations",
min_value=50,
max_value=1000,
value=300,
key="max_iterations_slider",
)
max_depth = st.slider(
"Maximal number of reaction steps",
min_value=3,
max_value=9,
value=6,
key="max_depth_slider",
)
min_mol_size = st.slider(
"Minimum size of a molecule to be precursor",
min_value=0,
max_value=7,
value=0,
key="min_mol_size_slider",
help="Number of non-hydrogen atoms in molecule",
)
search_strategy_translator = {
"Expansion first": "expansion_first",
"Evaluation first": "evaluation_first",
}
search_strategy = search_strategy_translator[search_strategy_input]
planning_params = {
"search_strategy": search_strategy,
"ucb_type": ucb_type,
"c_ucb": c_ucb,
"max_iterations": max_iterations,
"max_depth": max_depth,
"min_mol_size": min_mol_size,
}
if st.button("Start retrosynthetic planning", key="submit_planning_button"):
# Reset downstream states if replanning
st.session_state.planning_done = False
st.session_state.clustering_done = False
st.session_state.subclustering_done = False
st.session_state.tree = None
st.session_state.res = None
st.session_state.clusters = None
st.session_state.reactions_dict = None
st.session_state.subclusters = None
st.session_state.route_cgrs_dict = None
st.session_state.sb_cgrs_dict = None
st.session_state.route_json = None
active_smile_code = st.session_state.get(
"ketcher", DEFAULT_MOL
) # Get current SMILES
st.session_state.target_smiles = (
active_smile_code # Store the SMILES used for this run
)
try:
target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
if target_molecule is None:
raise ValueError(f"Could not parse the input SMILES: {active_smile_code}")
(
building_blocks_path,
ranking_policy_weights_path,
reaction_rules_path,
) = load_planning_resources_cached()
with st.spinner("Running retrosynthetic planning..."):
with st.status("Loading resources...", expanded=False) as status:
st.write("Loading building blocks...")
building_blocks = load_building_blocks(
building_blocks_path, standardize=False
)
st.write("Loading reaction rules...")
reaction_rules = load_reaction_rules(reaction_rules_path)
st.write("Loading policy network...")
policy_config = PolicyNetworkConfig(
weights_path=ranking_policy_weights_path
)
policy_function = PolicyNetworkFunction(
policy_config=policy_config
)
status.update(label="Resources loaded!", state="complete")
tree_config = TreeConfig(
search_strategy=planning_params["search_strategy"],
evaluation_type="rollout",
max_iterations=planning_params["max_iterations"],
max_depth=planning_params["max_depth"],
min_mol_size=planning_params["min_mol_size"],
init_node_value=0.5,
ucb_type=planning_params["ucb_type"],
c_ucb=planning_params["c_ucb"],
silent=True,
)
tree = Tree(
target=target_molecule,
config=tree_config,
reaction_rules=reaction_rules,
building_blocks=building_blocks,
expansion_function=policy_function,
evaluation_function=None,
)
mcts_progress_text = "Running MCTS iterations..."
mcts_bar = st.progress(0, text=mcts_progress_text)
for step, (solved, route_id) in enumerate(tree):
progress_value = min(
1.0, (step + 1) / planning_params["max_iterations"]
)
mcts_bar.progress(
progress_value,
text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
)
res = extract_tree_stats(tree, target_molecule)
st.session_state["tree"] = tree
st.session_state["res"] = res
st.session_state.planning_done = True
st.rerun()
except (ValueError, KeyError, FileNotFoundError, TypeError) as e:
st.error(f"An error occurred during planning: {e}")
st.session_state.planning_done = False
def display_planning_results():
"""5. Planning Results Display: Handling the presentation of results."""
if st.session_state.get("planning_done", False):
res = st.session_state.res
tree = st.session_state.tree
if res is None or tree is None:
st.error(
"Planning results are missing from session state. Please re-run planning."
)
st.session_state.planning_done = False # Reset state
return # Exit this function if no results
if res.get("solved", False): # Use .get for safety
st.header("Planning Results")
winning_nodes = (
sorted(set(tree.winning_nodes))
if hasattr(tree, "winning_nodes") and tree.winning_nodes
else []
)
st.subheader(f"Number of unique routes found: {len(winning_nodes)}")
st.subheader("Examples of found retrosynthetic routes")
image_counter = 0
visualised_route_ids = set()
if not winning_nodes:
st.warning(
"Planning solved, but no winning nodes found in the tree object."
)
else:
for n, route_id in enumerate(winning_nodes):
if image_counter >= 3:
break
if route_id not in visualised_route_ids:
try:
visualised_route_ids.add(route_id)
num_steps = len(tree.synthesis_route(route_id))
route_score = round(tree.route_score(route_id), 3)
svg = get_route_svg(tree, route_id)
# svg = get_route_svg_from_json(st.session_state.route_json, route_id)
if svg:
st.image(
svg,
caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
)
image_counter += 1
else:
st.warning(
f"Could not generate SVG for route {route_id}."
)
except Exception as e:
st.error(f"Error displaying route {route_id}: {e}")
else: # Not solved
st.header("Planning Results")
st.warning(
"No reaction path found for the target molecule with the current settings."
)
st.write(
"Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
)
stat_col, _ = st.columns(2)
with stat_col:
st.subheader("Run Statistics (No Solution)")
try:
if (
"target_smiles" not in res
and "target_smiles" in st.session_state
):
res["target_smiles"] = st.session_state.target_smiles
cols_to_show = [
col
for col in [
"target_smiles",
"num_nodes",
"num_iter",
"search_time",
]
if col in res
]
if cols_to_show:
df = pd.DataFrame(res, index=[0])[cols_to_show]
st.dataframe(df)
else:
st.write("No statistics to display for the unsuccessful run.")
except Exception as e:
st.error(f"Error displaying statistics: {e}")
st.write(res)
def download_planning_results():
"""6. Planning Results Download: Providing functionality to download."""
if (
st.session_state.get("planning_done", False)
and st.session_state.res
and st.session_state.res.get("solved", False)
):
res = st.session_state.res
tree = st.session_state.tree
# This section is usually placed within a column in the original script
# We'll assume it's called after display_planning_results and can use a new column or area.
# For proper layout, this should be integrated with display_planning_results' columns.
# For now, creating a placeholder or separate section for downloads:
# st.subheader("Downloads") # This might be redundant if called within a layout context.
# The original code places downloads in the second column of planning results.
# To replicate, we'd need to pass the column object or call this within that context.
# Simulating this by just creating the download links:
try:
html_body = generate_results_html(tree, html_path=None, extended=True)
dl_html = download_button(
html_body,
f"results_synplanner_{st.session_state.target_smiles}.html",
"Download results (HTML)",
)
if dl_html:
st.markdown(dl_html, unsafe_allow_html=True)
try:
res_df = pd.DataFrame(res, index=[0])
dl_csv = download_button(
res_df,
f"stats_synplanner_{st.session_state.target_smiles}.csv",
"Download statistics (CSV)",
)
if dl_csv:
st.markdown(dl_csv, unsafe_allow_html=True)
except Exception as e:
st.error(f"Could not prepare statistics CSV for download: {e}")
except Exception as e:
st.error(f"Error generating download links for planning results: {e}")
def setup_clustering():
"""7. Clustering: Encapsulating the logic related to the "clustering" functionality."""
if (
st.session_state.get("planning_done", False)
and st.session_state.res
and st.session_state.res.get("solved", False)
):
st.divider()
st.header("Clustering the retrosynthetic routes")
if st.button("Run Clustering", key="submit_clustering_button"):
# st.session_state.num_clusters_setting = num_clusters_input
st.session_state.clustering_done = False
st.session_state.subclustering_done = False
st.session_state.clusters = None
st.session_state.reactions_dict = None
st.session_state.subclusters = None
st.session_state.route_cgrs_dict = None
st.session_state.sb_cgrs_dict = None
st.session_state.route_json = None
with st.spinner("Performing clustering..."):
try:
current_tree = st.session_state.tree
if not current_tree:
st.error("Tree object not found. Please re-run planning.")
return
st.write("Calculating RoutesCGRs...")
route_cgrs_dict = compose_all_route_cgrs(current_tree)
st.write("Processing SB-CGRs...")
sb_cgrs_dict = compose_all_sb_cgrs(route_cgrs_dict)
results = cluster_routes(
sb_cgrs_dict, use_strat=False
) # num_clusters was removed from args
results = dict(sorted(results.items(), key=lambda x: float(x[0])))
st.session_state.clusters = results
st.session_state.route_cgrs_dict = route_cgrs_dict
st.session_state.sb_cgrs_dict = sb_cgrs_dict
st.write("Extracting reactions...")
st.session_state.reactions_dict = extract_reactions(current_tree)
st.session_state.route_json = make_json(st.session_state.reactions_dict)
if (
st.session_state.clusters is not None
and st.session_state.reactions_dict is not None
): # Check for None explicitly
st.session_state.clustering_done = True
st.success(
f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
)
else:
st.error("Clustering failed or returned empty results.")
st.session_state.clustering_done = False
del results # route_cgrs_dict, sb_cgrs_dict are stored
gc.collect()
st.rerun()
except Exception as e:
st.error(f"An error occurred during clustering: {e}")
st.session_state.clustering_done = False
def display_clustering_results():
"""8. Clustering Results Display: Handling the presentation of results."""
if st.session_state.get("clustering_done", False):
clusters = st.session_state.clusters
# reactions_dict = st.session_state.reactions_dict # Needed for download, not directly for display here
tree = st.session_state.tree
MAX_DISPLAY_CLUSTERS_DATA = 10
if (
clusters is None or tree is None
): # reactions_dict removed as not critical for display part
st.error(
"Clustering results (clusters or tree) are missing. Please re-run clustering."
)
st.session_state.clustering_done = False
return
st.subheader(f"Best routes from {len(clusters)} Found Clusters")
clusters_items = list(clusters.items())
first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]
for cluster_num, group_data in first_items:
if (
not group_data
or "route_ids" not in group_data
or not group_data["route_ids"]
):
st.warning(f"Cluster {cluster_num} has no data or route_ids.")
continue
st.markdown(
f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
)
route_id = group_data["route_ids"][0]
try:
num_steps = len(tree.synthesis_route(route_id))
route_score = round(tree.route_score(route_id), 3)
# svg = get_route_svg(tree, route_id)
svg = get_route_svg_from_json(st.session_state.route_json, route_id)
sb_cgr = group_data.get("sb_cgr") # Safely get sb_cgr
sb_cgr_svg = None
if sb_cgr:
sb_cgr.clean2d()
sb_cgr_svg = cgr_display(sb_cgr)
if svg and sb_cgr_svg:
col1, col2 = st.columns([0.2, 0.8])
with col1:
st.image(sb_cgr_svg, caption="SB-CGR")
with col2:
st.image(
svg,
caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
)
elif svg: # Only route SVG available
st.image(
svg,
caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
)
st.warning(
f"SB-CGR could not be displayed for cluster {cluster_num}."
)
else:
st.warning(
f"Could not generate SVG for route {route_id} or its SB-CGR."
)
except Exception as e:
st.error(
f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
)
if remaining_items:
with st.expander(f"... and {len(remaining_items)} more clusters"):
for cluster_num, group_data in remaining_items:
if (
not group_data
or "route_ids" not in group_data
or not group_data["route_ids"]
):
st.warning(
f"Cluster {cluster_num} in expansion has no data or route_ids."
)
continue
st.markdown(
f"**Cluster {cluster_num}** (Size: {group_data.get('group_size', 'N/A')})"
)
route_id = group_data["route_ids"][0]
try:
num_steps = len(tree.synthesis_route(route_id))
route_score = round(tree.route_score(route_id), 3)
# svg = get_route_svg(tree, route_id)
svg = get_route_svg_from_json(st.session_state.route_json, route_id)
sb_cgr = group_data.get("sb_cgr")
sb_cgr_svg = None
if sb_cgr:
sb_cgr.clean2d()
sb_cgr_svg = cgr_display(sb_cgr)
if svg and sb_cgr_svg:
col1, col2 = st.columns([0.2, 0.8])
with col1:
st.image(sb_cgr_svg, caption="SB-CGR")
with col2:
st.image(
svg,
caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
)
elif svg:
st.image(
svg,
caption=f"Route {route_id}; {num_steps} steps; Route score: {route_score}",
)
st.warning(
f"SB-CGR could not be displayed for cluster {cluster_num}."
)
else:
st.warning(
f"Could not generate SVG for route {route_id} or its SB-CGR."
)
except Exception as e:
st.error(
f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
)
def download_clustering_results():
"""10. Clustering Results Download: Providing functionality to download."""
if st.session_state.get("clustering_done", False):
tree_for_html = st.session_state.get("tree")
clusters_for_html = st.session_state.get("clusters")
sb_cgrs_for_html = st.session_state.get(
"sb_cgrs_dict"
) # This was used instead of reactions_dict in the original for report
if not tree_for_html:
st.warning("MCTS Tree data not found. Cannot generate cluster reports.")
return
if not clusters_for_html:
st.warning("Cluster data not found. Cannot generate cluster reports.")
return
# sb_cgrs_for_html is optional for routes_clustering_report if not essential
st.subheader("Cluster Reports") # Changed subheader in original
st.write("Generate downloadable HTML reports for each cluster:")
MAX_DOWNLOAD_LINKS_DISPLAYED = 10
num_clusters_total = len(clusters_for_html)
clusters_items = list(clusters_for_html.items())
for i, (cluster_idx, group_data) in enumerate(
clusters_items
): # group_data might not be needed here if report uses cluster_idx
if i >= MAX_DOWNLOAD_LINKS_DISPLAYED:
break
try:
html_content = routes_clustering_report(
tree_for_html,
clusters_for_html, # Pass the whole dict
str(cluster_idx), # Pass the key of the cluster
sb_cgrs_for_html, # Pass the sb_cgrs dict
aam=False,
)
st.download_button(
label=f"Download report for cluster {cluster_idx}",
data=html_content,
file_name=f"cluster_{cluster_idx}_{st.session_state.target_smiles}.html",
mime="text/html",
key=f"download_cluster_{cluster_idx}",
)
except Exception as e:
st.error(f"Error generating report for cluster {cluster_idx}: {e}")
if num_clusters_total > MAX_DOWNLOAD_LINKS_DISPLAYED:
remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
remaining_count = len(remaining_items)
expander_label = f"Show remaining {remaining_count} cluster reports"
with st.expander(expander_label):
for (
group_index,
_,
) in remaining_items: # group_data not needed here either
try:
html_content = routes_clustering_report(
tree_for_html,
clusters_for_html,
str(group_index),
sb_cgrs_for_html,
aam=False,
)
st.download_button(
label=f"Download report for cluster {group_index}",
data=html_content,
file_name=f"cluster_{group_index}_{st.session_state.target_smiles}.html",
mime="text/html",
key=f"download_cluster_expanded_{group_index}",
)
except Exception as e:
st.error(
f"Error generating report for cluster {group_index} (expanded): {e}"
)
try:
buffer = io.BytesIO()
with zipfile.ZipFile(
buffer, mode="w", compression=zipfile.ZIP_DEFLATED
) as zf:
for idx, _ in clusters_items: # group_data not needed
html_content_zip = routes_clustering_report(
tree_for_html,
clusters_for_html,
str(idx),
sb_cgrs_for_html,
aam=False,
)
filename = f"cluster_{idx}_{st.session_state.target_smiles}.html"
zf.writestr(filename, html_content_zip)
buffer.seek(0)
st.download_button(
label="๐ฆ Download all cluster reports as ZIP",
data=buffer,
file_name=f"all_cluster_reports_{st.session_state.target_smiles}.zip",
mime="application/zip",
key="download_all_clusters_zip",
)
except Exception as e:
st.error(f"Error generating ZIP file for cluster reports: {e}")
def setup_subclustering():
"""11. Subclustering: Encapsulating the logic related to the "subclustering" functionality."""
if st.session_state.get(
"clustering_done", False
): # Subclustering depends on clustering being done
st.divider()
st.header("Sub-Clustering within a selected Cluster")
if st.button("Run Subclustering Analysis", key="submit_subclustering_button"):
st.session_state.subclustering_done = False
st.session_state.subclusters = None
with st.spinner("Performing subclustering analysis..."):
try:
clusters_for_sub = st.session_state.get("clusters")
sb_cgrs_dict_for_sub = st.session_state.get("sb_cgrs_dict")
route_cgrs_dict_for_sub = st.session_state.get("route_cgrs_dict")
if (
clusters_for_sub
and sb_cgrs_dict_for_sub
and route_cgrs_dict_for_sub
): # Ensure all are present
all_subgroups = subcluster_all_clusters(
clusters_for_sub,
sb_cgrs_dict_for_sub,
route_cgrs_dict_for_sub,
)
st.session_state.subclusters = all_subgroups
st.session_state.subclustering_done = True
st.success("Subclustering analysis complete.")
gc.collect()
st.rerun()
else:
missing = []
if not clusters_for_sub:
missing.append("clusters")
if not sb_cgrs_dict_for_sub:
missing.append("SB-CGRs dictionary")
if not route_cgrs_dict_for_sub:
missing.append("RouteCGRs dictionary")
st.error(
f"Cannot run subclustering. Missing data: {', '.join(missing)}. Please ensure clustering ran successfully."
)
st.session_state.subclustering_done = False
except Exception as e:
st.error(f"An error occurred during subclustering: {e}")
st.session_state.subclustering_done = False
def display_subclustering_results():
"""12. Subclustering Results Display: Handling the presentation of results."""
if st.session_state.get("subclustering_done", False):
sub = st.session_state.get("subclusters")
tree = st.session_state.get("tree")
# clusters_for_sub_display = st.session_state.get('clusters') # Not directly used in display logic from original code snippet
if not sub or not tree:
st.error(
"Subclustering results (subclusters or tree) are missing. Please re-run subclustering."
)
st.session_state.subclustering_done = False
return
sub_input_col, sub_display_col = st.columns([0.25, 0.75])
with sub_input_col:
st.subheader("Select Cluster and Subcluster")
available_cluster_nums = list(sub.keys())
if not available_cluster_nums:
st.warning("No clusters available in subclustering results.")
return # Exit if no clusters to select
user_input_cluster_num_display = st.selectbox(
"Select Cluster #:",
options=sorted(available_cluster_nums),
key="subcluster_num_select_key",
)
selected_subcluster_idx = 0
if user_input_cluster_num_display in sub:
sub_step_cluster = sub[user_input_cluster_num_display]
allowed_subclusters_indices = sorted(list(sub_step_cluster.keys()))
if not allowed_subclusters_indices:
st.warning(
f"No reaction steps (subclusters) found for Cluster {user_input_cluster_num_display}."
)
else:
selected_subcluster_idx = st.selectbox(
"Select Subcluster Index:",
options=allowed_subclusters_indices,
key="subcluster_index_select_key",
)
if selected_subcluster_idx in sub[user_input_cluster_num_display]:
current_subcluster_data = sub[user_input_cluster_num_display][
selected_subcluster_idx
]
if "sb_cgr" in current_subcluster_data:
cluster_sb_cgr_display = current_subcluster_data["sb_cgr"]
cluster_sb_cgr_display.clean2d()
st.image(
cluster_sb_cgr_display.depict(),
caption=f"SB-CGR of parent Cluster {user_input_cluster_num_display}",
)
else:
st.warning("SB-CGR for this subcluster not found.")
else:
st.warning(
f"Selected cluster {user_input_cluster_num_display} not found in subclustering results."
)
return
with sub_display_col:
st.subheader("Subcluster Details")
if (
user_input_cluster_num_display in sub
and selected_subcluster_idx in sub[user_input_cluster_num_display]
):
subcluster_content = sub[user_input_cluster_num_display][
selected_subcluster_idx
]
# subcluster_to_display = post_process_subgroup(subcluster_content) #Under development
subcluster_to_display = subcluster_content
if (
not subcluster_to_display
or "routes_data" not in subcluster_to_display
or not subcluster_to_display["routes_data"]
):
st.info("No routes or data found for this subcluster selection.")
else:
MAX_ROUTES_PER_SUBCLUSTER = 5
all_route_ids_in_subcluster = list(
subcluster_to_display["routes_data"].keys()
)
routes_to_display_direct = all_route_ids_in_subcluster[
:MAX_ROUTES_PER_SUBCLUSTER
]
remaining_routes_sub = all_route_ids_in_subcluster[
MAX_ROUTES_PER_SUBCLUSTER:
]
st.markdown(
f"--- \n**Subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}** (Size: {len(all_route_ids_in_subcluster)})"
)
if "synthon_reaction" in subcluster_to_display:
synthon_reaction = subcluster_to_display["synthon_reaction"]
try:
synthon_reaction.clean2d()
st.image(
depict_custom_reaction(synthon_reaction),
caption=f"Markush-like pseudo reaction of subcluster",
) # Assuming depict_custom_reaction
except Exception as e_depict:
st.warning(f"Could not depict synthon reaction: {e_depict}")
else:
st.info("No synthon reaction data for this subcluster.")
with st.container(height=500):
for route_id in routes_to_display_direct:
try:
route_score_sub = round(tree.route_score(route_id), 3)
# svg_sub = get_route_svg(tree, route_id)
svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
if svg_sub:
st.image(
svg_sub,
caption=f"Route {route_id}; Score: {route_score_sub}",
)
else:
st.warning(
f"Could not generate SVG for route {route_id}."
)
except Exception as e:
st.error(
f"Error displaying route {route_id} in subcluster: {e}"
)
if remaining_routes_sub:
with st.expander(
f"... and {len(remaining_routes_sub)} more routes in this subcluster"
):
for route_id in remaining_routes_sub:
try:
route_score_sub = round(
tree.route_score(route_id), 3
)
# svg_sub = get_route_svg(tree, route_id)
svg_sub = get_route_svg_from_json(st.session_state.route_json, route_id)
if svg_sub:
st.image(
svg_sub,
caption=f"Route {route_id}; Score: {route_score_sub}",
)
else:
st.warning(
f"Could not generate SVG for route {route_id}."
)
except Exception as e:
st.error(
f"Error displaying route {route_id} in subcluster (expanded): {e}"
)
else:
st.info("Select a valid cluster and subcluster index to see details.")
def download_subclustering_results():
"""13. Subclustering Results Download: Providing functionality to download."""
if (
st.session_state.get("subclustering_done", False)
and "subcluster_num_select_key" in st.session_state
and "subcluster_index_select_key" in st.session_state
):
sub = st.session_state.get("subclusters")
tree = st.session_state.get("tree")
sb_cgrs_for_report = st.session_state.get(
"sb_cgrs_dict"
) # Used by routes_subclustering_report
user_input_cluster_num_display = st.session_state.subcluster_num_select_key
selected_subcluster_idx = st.session_state.subcluster_index_select_key
if not tree or not sub or not sb_cgrs_for_report:
st.warning(
"Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
)
return
if (
user_input_cluster_num_display in sub
and selected_subcluster_idx in sub[user_input_cluster_num_display]
):
subcluster_data_for_report = sub[user_input_cluster_num_display][
selected_subcluster_idx
]
# Apply the same post-processing as in display
processed_subcluster_data = post_process_subgroup(
subcluster_data_for_report
)
if "routes_data" in subcluster_data_for_report and isinstance(
subcluster_data_for_report["routes_data"], dict
):
processed_subcluster_data["group_lgs"] = group_by_identical_values(
subcluster_data_for_report["routes_data"]
)
else:
processed_subcluster_data["group_lgs"] = {}
try:
subcluster_html_content = routes_subclustering_report(
tree,
processed_subcluster_data, # Pass the specific post-processed subcluster data
user_input_cluster_num_display,
selected_subcluster_idx,
sb_cgrs_for_report, # Pass the whole sb_cgrs dict
if_lg_group=True, # This parameter was in the original call
)
st.download_button(
label=f"Download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}",
data=subcluster_html_content,
file_name=f"subcluster_{user_input_cluster_num_display}.{selected_subcluster_idx}_{st.session_state.target_smiles}.html",
mime="text/html",
key=f"download_subcluster_{user_input_cluster_num_display}_{selected_subcluster_idx}",
)
except Exception as e:
st.error(
f"Error generating download report for subcluster {user_input_cluster_num_display}.{selected_subcluster_idx}: {e}"
)
# else:
# This case is handled by the display logic mostly, download button just won't appear or will be for previous valid selection.
def implement_restart():
"""14. Restart: Implementing the logic to reset or restart the application state."""
st.divider()
st.header("Restart Application State")
if st.button("Clear All Results & Restart", key="restart_button"):
keys_to_clear = [
"planning_done",
"tree",
"res",
"target_smiles",
"clustering_done",
"clusters",
"reactions_dict",
"num_clusters_setting",
"route_cgrs_dict",
"sb_cgrs_dict",
"route_json",
"subclustering_done",
"subclusters", # "sub" was renamed
"clusters_downloaded",
# Potentially ketcher related keys if they need manual reset beyond new input
"ketcher_widget",
"smiles_text_input_key", # Keys for widgets
"subcluster_num_select_key",
"subcluster_index_select_key",
]
for key in keys_to_clear:
if key in st.session_state:
del st.session_state[key]
# Reset ketcher input to default by resetting its session state variable
st.session_state.ketcher = DEFAULT_MOL
# Also explicitly set target_smiles to empty or default to avoid stale data
st.session_state.target_smiles = ""
# It's generally better to let Streamlit manage widget state if possible,
# but for a full reset, clearing their explicit session state keys might be needed.
st.rerun()
# --- Main Application Flow ---
def main():
initialize_app()
setup_sidebar()
current_smile_code = handle_molecule_input()
# Update session_state.ketcher if current_smile_code has changed from ketcher output
if st.session_state.get("ketcher") != current_smile_code:
st.session_state.ketcher = current_smile_code
# No rerun here, let the flow continue. handle_molecule_input already warns.
setup_planning_options() # This function now also handles the button press and logic for planning
# Display planning results and download options together
if st.session_state.get("planning_done", False):
display_planning_results() # Displays stats and routes
if st.session_state.res and st.session_state.res.get("solved", False):
stat_col, download_col = st.columns(
2, gap="medium"
) # Placeholder for download column
with stat_col:
st.subheader("Statistics")
try:
res = st.session_state.res
if (
"target_smiles" not in res
and "target_smiles" in st.session_state
):
res["target_smiles"] = st.session_state.target_smiles
cols_to_show = [
col
for col in [
"target_smiles",
"num_routes",
"num_nodes",
"num_iter",
"search_time",
]
if col in res
]
if cols_to_show: # Ensure there are columns to show
df = pd.DataFrame(res, index=[0])[cols_to_show]
st.dataframe(df)
else:
st.write("No statistics to display from planning results.")
except Exception as e:
st.error(f"Error displaying statistics: {e}")
st.write(res) # Show raw dict if DataFrame fails
with download_col:
st.subheader("Planning Downloads") # Adding a subheader for clarity
download_planning_results()
# Clustering section (setup button, display, download)
if (
st.session_state.get("planning_done", False)
and st.session_state.res
and st.session_state.res.get("solved", False)
):
setup_clustering() # Contains the "Run Clustering" button and logic
if st.session_state.get("clustering_done", False):
display_clustering_results() # Displays cluster routes and stats
cluster_stat_col, cluster_download_col = st.columns(2, gap="medium")
with cluster_stat_col:
clusters = st.session_state.clusters
cluster_sizes = [
cluster.get("group_size", 0)
for cluster in clusters.values()
if cluster
] # Safe get
st.subheader("Cluster Statistics")
if cluster_sizes:
cluster_df = pd.DataFrame(
{
"Cluster": [
k for k, v in clusters.items() if v
], # Filter out empty clusters
"Number of Routes": [
v["group_size"] for v in clusters.values() if v
],
}
)
if not cluster_df.empty:
cluster_df.index += 1
st.dataframe(cluster_df)
best_route_html = html_top_routes_cluster(
clusters,
st.session_state.tree,
st.session_state.target_smiles,
)
st.download_button(
label=f"Download best route from each cluster",
data=best_route_html,
file_name=f"cluster_best_{st.session_state.target_smiles}.html",
mime="text/html",
key=f"download_cluster_best",
)
else:
st.write("No valid cluster data to display statistics for.")
# download_top_routes_cluster()
else:
st.write("No cluster data to display statistics for.")
with cluster_download_col:
download_clustering_results()
# Subclustering section (setup button, display, download)
if st.session_state.get("clustering_done", False): # Depends on clustering
setup_subclustering() # Contains "Run Subclustering" button
if st.session_state.get("subclustering_done", False):
display_subclustering_results() # Displays subcluster details and routes
download_subclustering_results() # This needs to be called after selections are made in display.
implement_restart()
if __name__ == "__main__":
main()