Protolaw's picture
Fix missing synplan folder
5a1721f verified
import base64
import gc
import io
import os
import pickle
import re
import tempfile
import uuid
import zipfile
from pathlib import Path
from CGRtools.files import SMILESRead
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars
import pandas as pd
import streamlit as st
from streamlit_ketcher import st_ketcher
from synplan.chem.reaction_routes.clustering import *
from synplan.chem.reaction_routes.route_cgr import *
from synplan.chem.utils import mol_from_smiles
from synplan.mcts.search import extract_tree_stats
from synplan.mcts.tree import Tree
from synplan.utils.config import TreeConfig
from synplan.utils.loading import (
load_building_blocks,
load_policy_function,
load_reaction_rules,
)
from synplan.utils.visualisation import (
generate_results_html,
get_route_svg,
get_route_svg_from_json,
html_top_routes_cluster,
routes_clustering_report,
routes_subclustering_report,
)
from synplan.utils.config import RolloutEvaluationConfig
from synplan.utils.loading import load_evaluation_function
disable_progress_bars("huggingface_hub")
smiles_parser = SMILESRead.create_parser(ignore=True)
DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
# --- Helper Functions ---
def download_button(
object_to_download, download_filename, button_text, pickle_it=False
):
"""
Generates a link to download the given object_to_download.
"""
if pickle_it:
try:
object_to_download = pickle.dumps(object_to_download)
except pickle.PicklingError as e:
st.write(e)
return None
else:
if isinstance(object_to_download, bytes):
pass
elif isinstance(object_to_download, pd.DataFrame):
object_to_download = object_to_download.to_csv(index=False).encode("utf-8")
try:
b64 = base64.b64encode(object_to_download.encode()).decode()
except AttributeError:
b64 = base64.b64encode(object_to_download).decode()
button_uuid = str(uuid.uuid4()).replace("-", "")
button_id = re.sub(r"\d+", "", button_uuid)
custom_css = f"""
<style>
#{button_id} {{
background-color: rgb(255, 255, 255);
color: rgb(38, 39, 48);
text-decoration: none;
border-radius: 4px;
border-width: 1px;
border-style: solid;
border-color: rgb(230, 234, 241);
border-image: initial;
}}
#{button_id}:hover {{
border-color: rgb(246, 51, 102);
color: rgb(246, 51, 102);
}}
#{button_id}:active {{
box-shadow: none;
background-color: rgb(246, 51, 102);
color: white;
}}
</style> """
dl_link = (
custom_css
+ f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
)
return dl_link
def load_priority_rules_from_upload(uploaded_file):
if uploaded_file is None:
return tuple()
try:
if isinstance(uploaded_file, (str, Path)):
with open(uploaded_file, "rb") as f:
priority_rules = pickle.load(f)
else:
data = (
uploaded_file.getbuffer()
if hasattr(uploaded_file, "getbuffer")
else uploaded_file.read()
)
priority_rules = pickle.loads(data)
return priority_rules
except Exception as e:
raise ValueError(f"Failed to load priority rules: {e}") from e
@st.cache_resource
def load_planning_resources_cached():
building_blocks_path = hf_hub_download(
repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
filename="building_blocks_em_sa_ln.smi",
subfolder="building_blocks",
local_dir=".",
)
ranking_policy_weights_path = hf_hub_download(
repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
filename="ranking_policy_network.ckpt",
subfolder="uspto/weights",
local_dir=".",
)
reaction_rules_path = hf_hub_download(
repo_id="Laboratoire-De-Chemoinformatique/SynPlanner",
filename="uspto_reaction_rules.pickle",
subfolder="uspto",
local_dir=".",
)
return building_blocks_path, ranking_policy_weights_path, reaction_rules_path
# --- GUI Sections ---
def initialize_app():
st.set_page_config(
page_title="SynPlanner GUI",
page_icon="🧪",
layout="wide",
initial_sidebar_state="expanded",
)
st.markdown(
"""
<style>
:root {
color-scheme: light !important;
}
</style>
""",
unsafe_allow_html=True,
)
# Initialize session state variables if they don't exist.
if "planning_done" not in st.session_state:
st.session_state.planning_done = False
if "tree" not in st.session_state:
st.session_state.tree = None
if "res" not in st.session_state:
st.session_state.res = None
if "target_smiles" not in st.session_state:
st.session_state.target_smiles = ""
# Clustering state
if "clustering_done" not in st.session_state:
st.session_state.clustering_done = False
if "clusters" not in st.session_state:
st.session_state.clusters = None
if "reactions_dict" not in st.session_state:
st.session_state.reactions_dict = None
if "num_clusters_setting" not in st.session_state:
st.session_state.num_clusters_setting = 10
if "route_cgrs_dict" not in st.session_state:
st.session_state.route_cgrs_dict = None
if "sb_cgrs_dict" not in st.session_state:
st.session_state.sb_cgrs_dict = None
if "route_json" not in st.session_state:
st.session_state.route_json = None
# Subclustering state
if "subclustering_done" not in st.session_state:
st.session_state.subclustering_done = False
if "subclusters" not in st.session_state:
st.session_state.subclusters = None
# Download state
if "clusters_downloaded" not in st.session_state:
st.session_state.clusters_downloaded = False
if "ketcher" not in st.session_state:
st.session_state.ketcher = DEFAULT_MOL
if "upload_mode" not in st.session_state:
st.session_state.upload_mode = None
if "uploaded_rules_file" not in st.session_state:
st.session_state.uploaded_rules_file = None
if "selected_routes" not in st.session_state:
st.session_state.selected_routes = []
if "selected_routes_dialog_open" not in st.session_state:
st.session_state.selected_routes_dialog_open = False
if "selected_routes_dialog_minimized" not in st.session_state:
st.session_state.selected_routes_dialog_minimized = False
if "go_selected_routes" not in st.session_state:
st.session_state.go_selected_routes = False
intro_text = """
This is a demo of the graphical user interface of
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.
More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
"""
st.title("`SynPlanner GUI`")
st.write(intro_text)
def setup_sidebar():
st.sidebar.title("Docs")
st.sidebar.markdown("https://synplanner.readthedocs.io/en/latest/")
st.sidebar.title("Tutorials")
st.sidebar.markdown(
"[Link](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/tree/main/tutorials)"
)
st.sidebar.title("Preprint")
st.sidebar.markdown("[Link](https://doi.org/10.26434/chemrxiv-2024-bzpnd)")
st.sidebar.title("Paper")
st.sidebar.markdown("[Link](https://doi.org/10.1021/acs.jcim.4c02004)")
st.sidebar.title("Issues")
st.sidebar.markdown(
"[Report a bug 🐞](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D)"
)
def handle_molecule_input():
col_molecule, col_import = st.columns([0.7, 0.3], gap="large")
with col_import:
st.header("Import In-house dataset - Priority")
if st.button("Upload Reactions", key="upload_reactions_button"):
st.session_state.upload_mode = "reactions"
if st.button("Upload Rules", key="upload_rules_button"):
st.session_state.upload_mode = "rules"
upload_mode = st.session_state.get("upload_mode")
if upload_mode == "reactions":
uploaded_reactions = st.file_uploader(
"Choose reactions file",
key="upload_reactions_file",
)
if uploaded_reactions is not None:
st.session_state.uploaded_reaction_file = uploaded_reactions
st.success(f"Reactions file selected: {uploaded_reactions.name}")
elif upload_mode == "rules":
uploaded_rules = st.file_uploader(
"Choose rules file",
key="upload_rules_file",
)
if uploaded_rules is not None:
st.session_state.uploaded_rules_file = uploaded_rules
st.success(f"Rules file selected: {uploaded_rules.name}")
with col_molecule:
st.header("Molecule input")
st.markdown(
"""
You can provide a molecular structure by either providing:
* SMILES string + Enter
* Draw it + Apply
"""
)
if "shared_smiles" not in st.session_state:
st.session_state.shared_smiles = st.session_state.get(
"ketcher", DEFAULT_MOL
)
if "ketcher_render_count" not in st.session_state:
st.session_state.ketcher_render_count = 0
def text_input_changed_callback():
new_text_value = st.session_state.smiles_text_input_key_for_sync
if new_text_value != st.session_state.shared_smiles:
st.session_state.shared_smiles = new_text_value
st.session_state.ketcher = new_text_value
st.session_state.ketcher_render_count += 1
st.text_input(
"SMILES:",
value=st.session_state.shared_smiles,
key="smiles_text_input_key_for_sync",
on_change=text_input_changed_callback,
help=(
"Enter SMILES string and press Enter. "
"The drawing will update, and vice-versa."
),
)
ketcher_key = f"ketcher_widget_for_sync_{st.session_state.ketcher_render_count}"
smile_code_output_from_ketcher = st_ketcher(
st.session_state.shared_smiles, key=ketcher_key
)
if smile_code_output_from_ketcher != st.session_state.shared_smiles:
st.session_state.shared_smiles = smile_code_output_from_ketcher
st.session_state.ketcher = smile_code_output_from_ketcher
st.rerun()
current_smiles_for_planning = st.session_state.shared_smiles
last_planned_smiles = st.session_state.get("target_smiles")
if (
last_planned_smiles
and current_smiles_for_planning != last_planned_smiles
and st.session_state.get("planning_done", False)
):
st.warning(
"Molecule structure has changed since the last successful planning run. "
"Results shown below (if any) are for the previous molecule. "
"Please re-run planning for the current structure."
)
if st.session_state.get("ketcher") != current_smiles_for_planning:
st.session_state.ketcher = current_smiles_for_planning
return current_smiles_for_planning
def handle_priority_rules():
if st.session_state.upload_mode == "rules":
uploaded_rules_file = st.session_state.get(
"uploaded_rules_file"
)
if uploaded_rules_file is not None:
st.write("Loading priority rules...")
try:
priority_rules = load_priority_rules_from_upload(
uploaded_rules_file
)
except ValueError as e:
st.error(str(e))
priority_rules = tuple()
if st.session_state.upload_mode == "reactions":
uploaded_reaction_file = st.session_state.get(
"uploaded_reaction_file"
)
if uploaded_reaction_file is not None:
st.write("Loading priority rules...")
try:
priority_rules = load_priority_rules_from_upload(
uploaded_reaction_file
)
except ValueError as e:
st.error(str(e))
priority_rules = tuple()
return priority_rules
def setup_planning_options():
st.header("Launch calculation")
st.markdown(
"""If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
)
st.markdown(
f"The molecule SMILES is actually: ``{st.session_state.get('ketcher', DEFAULT_MOL)}``"
)
st.subheader("Planning options")
st.markdown(
"""
The description of each option can be found in the
[Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
"""
)
col_options_1, col_options_2 = st.columns(2, gap="medium")
with col_options_1:
search_strategy_input = st.selectbox(
label="Search strategy",
options=(
"Expansion first",
"Evaluation first",
),
index=0,
key="search_strategy_input",
)
ucb_type = st.selectbox(
label="UCB type",
options=("uct", "puct", "value"),
index=0,
key="ucb_type_input",
)
c_ucb = st.number_input(
"C coefficient of UCB",
value=0.1,
placeholder="Type a number...",
key="c_ucb_input",
)
with col_options_2:
max_iterations = st.slider(
"Total number of MCTS iterations",
min_value=50,
max_value=1000,
value=300,
key="max_iterations_slider",
)
max_depth = st.slider(
"Maximal number of reaction steps",
min_value=3,
max_value=40,
value=6,
key="max_depth_slider",
)
min_mol_size = st.slider(
"Minimum size of a molecule to be precursor",
min_value=0,
max_value=7,
value=0,
key="min_mol_size_slider",
help="Number of non-hydrogen atoms in molecule",
)
search_strategy_translator = {
"Expansion first": "expansion_first",
"Evaluation first": "evaluation_first",
}
search_strategy = search_strategy_translator[search_strategy_input]
planning_params = {
"search_strategy": search_strategy,
"ucb_type": ucb_type,
"c_ucb": c_ucb,
"max_iterations": max_iterations,
"max_depth": max_depth,
"min_mol_size": min_mol_size,
}
if st.button("Start retrosynthetic planning", key="submit_planning_button"):
# Reset downstream states if replanning
st.session_state.planning_done = False
st.session_state.clustering_done = False
st.session_state.subclustering_done = False
st.session_state.tree = None
st.session_state.res = None
st.session_state.clusters = None
st.session_state.reactions_dict = None
st.session_state.subclusters = None
st.session_state.route_cgrs_dict = None
st.session_state.sb_cgrs_dict = None
st.session_state.route_json = None
st.session_state.selected_routes = []
active_smile_code = st.session_state.get("ketcher", DEFAULT_MOL)
st.session_state.target_smiles = active_smile_code
try:
target_molecule = mol_from_smiles(active_smile_code, clean_stereo=True)
if target_molecule is None:
st.error(f"Could not parse the input SMILES: {active_smile_code}")
else:
(
building_blocks_path,
ranking_policy_weights_path,
reaction_rules_path,
) = load_planning_resources_cached()
with st.spinner("Running retrosynthetic planning..."):
with st.status("Loading resources...", expanded=False) as status:
st.write("Loading building blocks...")
building_blocks = load_building_blocks(
building_blocks_path, standardize=False
)
st.write("Loading reaction rules...")
reaction_rules = load_reaction_rules(reaction_rules_path)
priority_rules = tuple()
priority_rules = handle_priority_rules()
st.write("Loading policy network...")
policy_function = load_policy_function(
weights_path=ranking_policy_weights_path
)
status.update(label="Resources loaded!", state="complete")
tree_config = TreeConfig(
search_strategy=planning_params["search_strategy"],
max_iterations=planning_params["max_iterations"],
max_depth=planning_params["max_depth"],
min_mol_size=planning_params["min_mol_size"],
init_node_value=0.5,
ucb_type=planning_params["ucb_type"],
c_ucb=planning_params["c_ucb"],
silent=True,
)
if priority_rules:
tree_config.use_priority = True
eval_config = RolloutEvaluationConfig(
policy_network=policy_function,
reaction_rules=reaction_rules,
building_blocks=building_blocks,
min_mol_size=tree_config.min_mol_size,
max_depth=tree_config.max_depth,
)
evaluator = load_evaluation_function(eval_config)
tree = Tree(
target=target_molecule,
config=tree_config,
reaction_rules=reaction_rules,
building_blocks=building_blocks,
expansion_function=policy_function,
evaluation_function=evaluator,
priority_rules=priority_rules,
)
mcts_progress_text = "Running MCTS iterations..."
mcts_bar = st.progress(0, text=mcts_progress_text)
for step, (solved, route_id) in enumerate(tree):
progress_value = min(
1.0, (step + 1) / planning_params["max_iterations"]
)
mcts_bar.progress(
progress_value,
text=f"{mcts_progress_text} ({step+1}/{planning_params['max_iterations']})",
)
res = extract_tree_stats(tree, target_molecule)
st.session_state["tree"] = tree
st.session_state["res"] = res
st.session_state.planning_done = True
except Exception as e:
st.error(f"An error occurred during planning: {e}")
st.session_state.planning_done = False
def display_planning_results():
"""
Planning results for the NOT-SOLVED case only.
For solved runs, we use the unified planning+clustering section.
"""
if not st.session_state.get("planning_done", False):
return
res = st.session_state.res
tree = st.session_state.tree
if res is None or tree is None:
st.error(
"Planning results are missing from session state. Please re-run planning."
)
st.session_state.planning_done = False
return
if res.get("solved", False):
# Solved case handled in unified section.
return
st.header("Planning results")
st.warning(
"No reaction path found for the target molecule with the current settings."
)
st.write(
"Consider adjusting planning options (e.g., increase iterations, adjust depth, check molecule validity)."
)
stat_col, _ = st.columns(2)
with stat_col:
st.subheader("Run statistics (no solution)")
try:
if "target_smiles" not in res and "target_smiles" in st.session_state:
res["target_smiles"] = st.session_state.target_smiles
cols_to_show = [
col
for col in [
"target_smiles",
"num_nodes",
"num_iter",
"search_time",
]
if col in res
]
if cols_to_show:
df = pd.DataFrame(res, index=[0])[cols_to_show]
st.dataframe(df)
else:
st.write("No statistics to display for the unsuccessful run.")
except Exception as e:
st.error(f"Error displaying statistics: {e}")
st.write(res)
def download_planning_results():
"""
Planning results download (full HTML report).
Only uses internal state; no headers here to keep UX unified.
"""
if (
st.session_state.get("planning_done", False)
and st.session_state.res
and st.session_state.res.get("solved", False)
):
try:
if st.button("Generate full HTML report", key="gen_plan_html"):
with st.spinner("Generating HTML report..."):
st.session_state.planning_report_html = generate_results_html(
st.session_state.tree, html_path=None, extended=True
)
if st.session_state.get("planning_report_html"):
st.download_button(
label="Download full planning report (HTML)",
data=st.session_state.planning_report_html,
file_name=f"full_report_{st.session_state.target_smiles}.html",
mime="text/html",
key="download_full_planning_html",
)
except Exception as e:
st.error(f"Error generating download links for planning results: {e}")
def filter_routes_for_clustering(
clusters: dict,
sb_cgrs_dict: dict,
route_cgrs_dict: dict,
skip_ids: set[int],
):
"""
Remove routes in skip_ids from:
- clusters (route_ids and group_size)
- SB-CGR dictionary
- RouteCGR dictionary
Clusters that become empty are dropped.
"""
if not skip_ids:
return clusters, sb_cgrs_dict, route_cgrs_dict
# Filter SB-CGRs and RouteCGRs
sb_cgrs_filtered = {
rid: cgr for rid, cgr in sb_cgrs_dict.items() if rid not in skip_ids
}
route_cgrs_filtered = {
rid: cgr for rid, cgr in route_cgrs_dict.items() if rid not in skip_ids
}
# Filter clusters and recompute group_size
filtered_clusters = {}
for cid, data in clusters.items():
route_ids = [rid for rid in data.get("route_ids", []) if rid not in skip_ids]
if not route_ids:
# whole cluster becomes empty -> drop it
continue
new_data = dict(data)
new_data["route_ids"] = route_ids
new_data["group_size"] = len(route_ids)
filtered_clusters[cid] = new_data
return filtered_clusters, sb_cgrs_filtered, route_cgrs_filtered
def run_clustering_core():
"""Core clustering logic (no explicit UI button)."""
st.session_state.clustering_done = False
st.session_state.subclustering_done = False
st.session_state.clusters = None
st.session_state.reactions_dict = None
st.session_state.subclusters = None
st.session_state.route_cgrs_dict = None
st.session_state.sb_cgrs_dict = None
st.session_state.route_json = None
with st.spinner("Performing clustering..."):
try:
current_tree = st.session_state.tree
if not current_tree:
st.error("Tree object not found. Please re-run planning.")
return
st.write("Calculating RouteCGRs...")
route_cgrs_dict = compose_all_route_cgrs(current_tree)
st.write("Processing SB-CGRs...")
sb_cgrs_dict = compose_all_sb_cgrs(route_cgrs_dict)
results = cluster_routes(sb_cgrs_dict, use_strat=False)
results = dict(sorted(results.items(), key=lambda x: float(x[0])))
st.session_state.clusters = results
st.session_state.route_cgrs_dict = route_cgrs_dict
st.session_state.sb_cgrs_dict = sb_cgrs_dict
st.write("Extracting reactions...")
st.session_state.reactions_dict = extract_reactions(current_tree)
st.session_state.route_json = make_json(st.session_state.reactions_dict)
if (
st.session_state.clusters is not None
and st.session_state.reactions_dict is not None
):
st.session_state.clustering_done = True
st.success(
f"Clustering complete. Found {len(st.session_state.clusters)} clusters."
)
else:
st.error("Clustering failed or returned empty results.")
st.session_state.clustering_done = False
gc.collect()
except Exception as e:
st.error(f"An error occurred during clustering: {e}")
st.session_state.clustering_done = False
def download_clustering_results():
"""Clustering results download: per-cluster reports + ZIP."""
if st.session_state.get("clustering_done", False):
tree_for_html = st.session_state.get("tree")
clusters_for_html = st.session_state.get("clusters")
sb_cgrs_for_html = st.session_state.get("sb_cgrs_dict")
if not tree_for_html:
st.warning("MCTS tree data not found. Cannot generate cluster reports.")
return
if not clusters_for_html:
st.warning("Cluster data not found. Cannot generate cluster reports.")
return
st.caption("Generate downloadable HTML reports for each cluster.")
MAX_DOWNLOAD_LINKS_DISPLAYED = 10
num_clusters_total = len(clusters_for_html)
clusters_items = list(clusters_for_html.items())
for i, (cluster_idx, group_data) in enumerate(clusters_items):
if i >= MAX_DOWNLOAD_LINKS_DISPLAYED:
break
try:
html_content = routes_clustering_report(
tree_for_html,
clusters_for_html,
str(cluster_idx),
sb_cgrs_for_html,
aam=False,
)
st.download_button(
label=f"Download report for cluster {cluster_idx}",
data=html_content,
file_name=f"cluster_{cluster_idx}_{st.session_state.target_smiles}.html",
mime="text/html",
key=f"download_cluster_{cluster_idx}",
)
except Exception as e:
st.error(f"Error generating report for cluster {cluster_idx}: {e}")
if num_clusters_total > MAX_DOWNLOAD_LINKS_DISPLAYED:
remaining_items = clusters_items[MAX_DOWNLOAD_LINKS_DISPLAYED:]
remaining_count = len(remaining_items)
expander_label = f"Show remaining {remaining_count} cluster reports"
with st.expander(expander_label):
for group_index, _ in remaining_items:
try:
html_content = routes_clustering_report(
tree_for_html,
clusters_for_html,
str(group_index),
sb_cgrs_for_html,
aam=False,
)
st.download_button(
label=f"Download report for cluster {group_index}",
data=html_content,
file_name=f"cluster_{group_index}_{st.session_state.target_smiles}.html",
mime="text/html",
key=f"download_cluster_expanded_{group_index}",
)
except Exception as e:
st.error(
f"Error generating report for cluster {group_index} (expanded): {e}"
)
# ZIP of all clusters
try:
buffer = io.BytesIO()
with zipfile.ZipFile(
buffer, mode="w", compression=zipfile.ZIP_DEFLATED
) as zf:
for idx, _ in clusters_items:
html_content_zip = routes_clustering_report(
tree_for_html,
clusters_for_html,
str(idx),
sb_cgrs_for_html,
aam=False,
)
filename = f"cluster_{idx}_{st.session_state.target_smiles}.html"
zf.writestr(filename, html_content_zip)
buffer.seek(0)
st.download_button(
label="Download all cluster reports as ZIP",
data=buffer,
file_name=f"all_cluster_reports_{st.session_state.target_smiles}.zip",
mime="application/zip",
key="download_all_clusters_zip",
)
except Exception as e:
st.error(f"Error generating ZIP file for cluster reports: {e}")
def display_planning_and_clustering_results_unified(show_rule_labels=False):
"""
Overview tab: planning summary, cluster summary, and best routes from clusters.
"""
res = st.session_state.res
tree = st.session_state.tree
clusters = st.session_state.clusters
route_json = st.session_state.route_json
if not (res and tree and clusters and route_json):
st.error(
"Missing data for unified planning+clustering display. Please re-run planning."
)
return
# --- Compact planning + cluster summaries instead of big tables ---
stat_col, cluster_stat_col = st.columns(2, gap="medium")
with stat_col:
st.subheader("Planning summary")
smi = st.session_state.get("target_smiles", "")
num_routes = res.get("num_routes", "—")
num_nodes = res.get("num_nodes", "—")
num_iter = res.get("num_iter", "—")
search_time = res.get("search_time", "—")
st.markdown(
f"""
- **Target SMILES**: `{smi}`
- **Routes explored**: **{num_routes}**
- **Tree nodes**: {num_nodes}
- **MCTS iterations**: {num_iter}
- **Search time**: {search_time} s
"""
)
with cluster_stat_col:
st.subheader("Cluster summary")
clusters_dict = clusters or {}
non_empty_clusters = [v for v in clusters_dict.values() if v]
n_clusters = len(non_empty_clusters)
route_counts = [v.get("group_size", 0) for v in non_empty_clusters]
if route_counts:
min_routes = min(route_counts)
max_routes = max(route_counts)
avg_routes = sum(route_counts) / len(route_counts)
else:
min_routes = max_routes = avg_routes = 0
st.markdown(
f"""
- **Number of clusters**: **{n_clusters}**
- **Routes per cluster**: min {min_routes}, max {max_routes}, avg {avg_routes:.1f}
"""
)
if clusters_dict:
best_route_html = html_top_routes_cluster(
clusters_dict,
st.session_state.tree,
st.session_state.target_smiles,
)
st.download_button(
label="Download best route from each cluster (HTML)",
data=best_route_html,
file_name=f"cluster_best_{st.session_state.target_smiles}.html",
mime="text/html",
key="download_cluster_best_unified",
)
st.markdown("---")
st.subheader(f"Best routes from {len(clusters)} found clusters")
# --- Best routes from clusters (always shown, not hidden in expanders) ---
MAX_DISPLAY_CLUSTERS_DATA = 10
clusters_items = list(clusters.items())
first_items = clusters_items[:MAX_DISPLAY_CLUSTERS_DATA]
remaining_items = clusters_items[MAX_DISPLAY_CLUSTERS_DATA:]
for cluster_num, group_data in first_items:
if (
not group_data
or "route_ids" not in group_data
or not group_data["route_ids"]
):
st.warning(f"Cluster {cluster_num} has no data or route_ids.")
continue
st.markdown(
f"**Cluster {cluster_num}** (size: {group_data.get('group_size', 'N/A')})"
)
route_id = group_data["route_ids"][0]
try:
num_steps = len(tree.synthesis_route(route_id))
route_score = round(tree.route_score(route_id), 3)
if show_rule_labels:
svg = get_route_svg(tree, route_id, labeled=True)
else:
svg = get_route_svg_from_json(route_json, route_id)
sb_cgr = group_data.get("sb_cgr")
sb_cgr_svg = None
if sb_cgr:
sb_cgr.clean2d()
sb_cgr_svg = cgr_display(sb_cgr)
if svg and sb_cgr_svg:
col1, col2 = st.columns([0.2, 0.8])
with col1:
st.image(sb_cgr_svg, caption="SB-CGR")
with col2:
display_route_with_add_button(
svg,
f"Route {route_id}; {num_steps} steps; route score: {route_score}",
route_id,
key_prefix=f"cluster_{cluster_num}",
)
elif svg:
display_route_with_add_button(
svg,
f"Route {route_id}; {num_steps} steps; route score: {route_score}",
route_id,
key_prefix=f"cluster_{cluster_num}",
)
st.warning(f"SB-CGR could not be displayed for cluster {cluster_num}.")
else:
st.warning(
f"Could not generate SVG for route {route_id} or its SB-CGR."
)
except Exception as e:
st.error(
f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
)
if remaining_items:
with st.expander(f"... and {len(remaining_items)} more clusters"):
for cluster_num, group_data in remaining_items:
if (
not group_data
or "route_ids" not in group_data
or not group_data["route_ids"]
):
st.warning(
f"Cluster {cluster_num} in expansion has no data or route_ids."
)
continue
st.markdown(
f"**Cluster {cluster_num}** (size: {group_data.get('group_size', 'N/A')})"
)
route_id = group_data["route_ids"][0]
try:
num_steps = len(tree.synthesis_route(route_id))
route_score = round(tree.route_score(route_id), 3)
if show_rule_labels:
svg = get_route_svg(tree, route_id, labeled=True)
else:
svg = get_route_svg_from_json(route_json, route_id)
sb_cgr = group_data.get("sb_cgr")
sb_cgr_svg = None
if sb_cgr:
sb_cgr.clean2d()
sb_cgr_svg = cgr_display(sb_cgr)
if svg and sb_cgr_svg:
col1, col2 = st.columns([0.2, 0.8])
with col1:
st.image(sb_cgr_svg, caption="SB-CGR")
with col2:
display_route_with_add_button(
svg,
f"Route {route_id}; {num_steps} steps; route score: {route_score}",
route_id,
key_prefix=f"cluster_expanded_{cluster_num}",
)
elif svg:
display_route_with_add_button(
svg,
f"Route {route_id}; {num_steps} steps; route score: {route_score}",
route_id,
key_prefix=f"cluster_expanded_{cluster_num}",
)
st.warning(
f"SB-CGR could not be displayed for cluster {cluster_num}."
)
else:
st.warning(
f"Could not generate SVG for route {route_id} or its SB-CGR."
)
except Exception as e:
st.error(
f"Error displaying route {route_id} for cluster {cluster_num}: {e}"
)
# --- Subclustering-related cached helpers ---
@st.cache_data
def generate_sb_cgr_image(_cgr):
_cgr.clean2d()
return _cgr.depict()
@st.cache_data
def generate_synthon_reaction_image(_synthon_reaction):
_synthon_reaction.clean2d()
return depict_custom_reaction(_synthon_reaction)
@st.cache_data
def get_cached_route_svg(_route_json, route_id):
return get_route_svg_from_json(_route_json, route_id)
@st.cache_data
def get_route_details(_tree, route_id):
score = round(_tree.route_score(route_id), 3)
length = len(_tree.synthesis_route(route_id))
return {"score": score, "length": length}
def add_route_to_selection(route_id):
selected_routes = st.session_state.get("selected_routes", [])
if route_id not in selected_routes:
selected_routes.append(route_id)
st.session_state.selected_routes = selected_routes
st.session_state.selected_routes_dialog_open = True
st.session_state.selected_routes_dialog_minimized = False
def remove_route_from_selection(route_id):
selected_routes = st.session_state.get("selected_routes", [])
if route_id in selected_routes:
selected_routes.remove(route_id)
st.session_state.selected_routes = selected_routes
def render_selected_routes_dialog(key_prefix="selected"):
selected_routes = st.session_state.get("selected_routes", [])
if not selected_routes or not st.session_state.get(
"selected_routes_dialog_open", False
):
return
@st.dialog("Selected routes")
def show_selected_routes_dialog():
st.markdown(
"""
<style>
div[data-testid="stDialog"] div[data-testid="stButton"] > button {
padding: 2px 6px;
min-height: 18px;
height: 18px;
line-height: 1;
font-size: 10px;
}
</style>
""",
unsafe_allow_html=True,
)
st.markdown(f"**Selected routes** ({len(selected_routes)})")
with st.container(height=300):
for route_id in selected_routes:
row_col, action_col = st.columns([0.7, 0.3])
with row_col:
st.write(f"Route {route_id}")
with action_col:
if st.button(
"Remove", key=f"{key_prefix}_remove_route_{route_id}"
):
remove_route_from_selection(route_id)
st.rerun()
spacer, c1, c2, c3 = st.columns([0.88, 0.04, 0.04, 0.04])
with c1:
if st.button("🔴", key=f"{key_prefix}_dialog_close"):
st.session_state.selected_routes_dialog_open = False
st.session_state.selected_routes_dialog_minimized = False
st.rerun()
with c2:
if st.button("🟡", key=f"{key_prefix}_dialog_minimize"):
st.session_state.selected_routes_dialog_open = False
st.session_state.selected_routes_dialog_minimized = True
st.rerun()
with c3:
if st.button("🟢", key=f"{key_prefix}_dialog_maximize"):
st.session_state.selected_routes_dialog_open = True
st.session_state.selected_routes_dialog_minimized = False
st.rerun()
show_selected_routes_dialog()
def render_selected_routes_minimized_bar(key_prefix="selected"):
if not st.session_state.get("selected_routes_dialog_minimized", False):
return
selected_routes = st.session_state.get("selected_routes", [])
if not selected_routes:
st.session_state.selected_routes_dialog_minimized = False
return
st.markdown("---")
bar_col, _ = st.columns([0.25, 0.75])
with bar_col:
if st.button(
f"🟡 Selected routes ({len(selected_routes)})",
key=f"{key_prefix}_minimized_restore",
):
st.session_state.selected_routes_dialog_open = True
st.session_state.selected_routes_dialog_minimized = False
st.rerun()
def render_selected_routes_navigation():
if st.button("Go to the selected routes", key="go_selected_routes_button"):
st.session_state.go_selected_routes = True
st.rerun()
if st.session_state.get("go_selected_routes"):
st.session_state.go_selected_routes = False
if hasattr(st, "switch_page"):
st.switch_page("pages/gui_2.py")
else:
st.warning(
"Page navigation is unavailable in this Streamlit version. "
"Open `pages/gui_2.py` from the multipage menu."
)
def display_route_with_add_button(svg, caption, route_id, key_prefix):
image_col, button_col = st.columns([0.94, 0.06], gap="small")
with image_col:
st.image(svg, caption=caption)
with button_col:
st.caption("Add to the selected")
if st.button("➕", key=f"{key_prefix}_add_route_{route_id}"):
add_route_to_selection(route_id)
st.rerun()
def display_single_route(
route_id,
route_json,
details,
tree=None,
show_rule_labels=False,
key_prefix="subcluster",
):
try:
if show_rule_labels and tree is not None:
svg_sub = get_route_svg(tree, route_id, labeled=True)
else:
svg_sub = get_cached_route_svg(route_json, route_id)
if svg_sub:
display_route_with_add_button(
svg_sub,
f"Route {route_id}; score: {details['score']}; steps: {details['length']}",
route_id,
key_prefix=key_prefix,
)
else:
st.warning(f"Could not generate SVG for route {route_id}.")
except Exception as e:
st.error(f"Error displaying route {route_id}: {e}")
# --- Subclustering: core + UI + downloads ---
def run_subclustering_core():
"""
Core subclustering logic, with automatic skipping of problematic routes.
Strategy:
- Try subcluster_all_clusters.
- If it fails and the error message contains 'route <id>',
remove that route from all clustering dictionaries and retry.
- Repeat a few times; if still failing, give up with an error.
"""
st.session_state.subclustering_done = False
st.session_state.subclusters = None
clusters = st.session_state.get("clusters")
sb_cgrs_dict = st.session_state.get("sb_cgrs_dict")
route_cgrs_dict = st.session_state.get("route_cgrs_dict")
if not (clusters and sb_cgrs_dict and route_cgrs_dict):
st.error(
"Cannot run subclustering. Missing clusters / SB-CGRs / RouteCGRs. "
"Please ensure clustering ran successfully."
)
return
skipped_routes = set()
# Arbitrary cap to avoid infinite loops if something else is wrong
MAX_ATTEMPTS = 5
for attempt in range(1, MAX_ATTEMPTS + 1):
with st.spinner(f"Performing subclustering analysis (attempt {attempt})..."):
try:
all_subgroups = subcluster_all_clusters(
clusters,
sb_cgrs_dict,
route_cgrs_dict,
)
# Success
st.session_state.clusters = clusters
st.session_state.sb_cgrs_dict = sb_cgrs_dict
st.session_state.route_cgrs_dict = route_cgrs_dict
st.session_state.subclusters = all_subgroups
st.session_state.subclustering_done = True
if skipped_routes:
st.info(
"Subclustering finished after automatically skipping "
f"{len(skipped_routes)} problematic route(s): "
+ ", ".join(str(r) for r in sorted(skipped_routes))
)
gc.collect()
return
except Exception as e:
msg = str(e)
# Look for "... route 1267" in the error text
m = re.search(r"route\s+(\d+)", msg)
if not m:
# No route id -> we cannot automatically fix this
st.error(f"An error occurred during subclustering: {e}")
st.session_state.subclustering_done = False
return
bad_id = int(m.group(1))
if bad_id in skipped_routes:
# We already skipped this one but it still fails -> abort
st.error(
"Subclustering still failing even after skipping route "
f"{bad_id}. Last error: {e}"
)
st.session_state.subclustering_done = False
return
skipped_routes.add(bad_id)
st.warning(
f"Subclustering failed for route {bad_id}; "
"it will be ignored in clusters, Overview and subclustering."
)
# Filter this route out and retry
clusters, sb_cgrs_dict, route_cgrs_dict = filter_routes_for_clustering(
clusters,
sb_cgrs_dict,
route_cgrs_dict,
{bad_id},
)
st.error(
"Subclustering failed after multiple attempts even after skipping "
"problematic routes."
)
st.session_state.subclustering_done = False
def setup_subclustering():
"""Subclustering tab header + optional re-run button."""
if not st.session_state.get("clustering_done", False):
return
st.header("Subclustering within a selected cluster")
with st.expander("What is subclustering?"):
st.markdown(
"The first two numbers define the cluster of interest (e.g., 2.1), "
"while the final designation (such as 3_1) indicates that the selected "
"subcluster contains three leaving groups in the Markush-like representation "
"of the abstracted RouteCGR, with “1” specifying a particular set of leaving groups."
)
st.caption(
"Subclustering is pre-computed after clustering so you can switch tabs without extra waiting."
)
if st.button("Re-run subclustering analysis", key="rerun_subclustering_button"):
run_subclustering_core()
st.rerun()
def display_subclustering_results(show_rule_labels=False):
"""Subclustering results display."""
if not st.session_state.get("subclustering_done", False):
st.info("Subclustering is not available. Please check clustering results.")
return
sub = st.session_state.get("subclusters")
tree = st.session_state.get("tree")
route_json = st.session_state.get("route_json")
if not all([sub, tree, route_json]):
st.error("Subclustering results are missing. Please re-run subclustering.")
st.session_state.subclustering_done = False
return
sub_input_col, sub_display_col = st.columns([0.15, 0.85])
with sub_input_col:
st.subheader("Select cluster")
available_cluster_nums = sorted(list(sub.keys()))
if not available_cluster_nums:
st.warning("No clusters available in subclustering results.")
return
sel_cluster_num = st.selectbox(
"Cluster #",
options=available_cluster_nums,
key="subcluster_num_select_key",
)
sub_step_cluster = sub.get(sel_cluster_num, {})
allowed_subclusters = sorted(list(sub_step_cluster.keys()))
if not allowed_subclusters:
st.warning(f"No subclusters found for cluster {sel_cluster_num}.")
return
sel_subcluster_idx = st.selectbox(
"Subcluster index",
options=allowed_subclusters,
key="subcluster_index_select_key",
)
current_subcluster_data = sub_step_cluster.get(sel_subcluster_idx)
if not current_subcluster_data:
st.warning("Selected subcluster not found.")
return
if "sb_cgr" in current_subcluster_data:
st.image(
generate_sb_cgr_image(current_subcluster_data["sb_cgr"]),
caption=f"SB-CGR of parent cluster {sel_cluster_num}",
)
all_routes_in_subcluster = current_subcluster_data.get("routes_data", {}).keys()
route_details_list = [
get_route_details(tree, rid) for rid in all_routes_in_subcluster
]
if not route_details_list:
min_steps, max_steps = 1, 2
else:
all_steps = [details["length"] for details in route_details_list]
min_steps, max_steps = min(all_steps), max(all_steps)
if min_steps < max_steps:
min_max_step = st.slider(
"Filter by number of steps",
min_value=min_steps,
max_value=max_steps,
value=(min_steps, max_steps),
)
else:
st.write(f"Routes with only one possible number of steps: **{min_steps}**")
min_max_step = (min_steps, max_steps)
with sub_display_col:
st.subheader(
f"Details for subcluster {sel_cluster_num}.{sel_subcluster_idx}: "
f"total {len(all_routes_in_subcluster)} routes"
)
filtered_routes = [
(rid, details)
for rid, details in zip(all_routes_in_subcluster, route_details_list)
if min_max_step[0] <= details["length"] <= min_max_step[1]
]
if not filtered_routes:
st.info("No routes match the current filter settings.")
return
st.markdown(
f"--- \n**Displaying {len(filtered_routes)} routes "
f"(from {min_max_step[0]} to {min_max_step[1]} reaction steps)**"
)
if "synthon_reaction" in current_subcluster_data:
try:
st.image(
generate_synthon_reaction_image(
current_subcluster_data["synthon_reaction"]
),
caption="Markush-like pseudo reaction of subcluster",
)
except Exception as e_depict:
st.warning(f"Could not depict synthon reaction: {e_depict}")
MAX_ROUTES_PER_SUBCLUSTER = 5
routes_to_display_direct = filtered_routes[:MAX_ROUTES_PER_SUBCLUSTER]
remaining_routes = filtered_routes[MAX_ROUTES_PER_SUBCLUSTER:]
with st.container(height=500):
for route_id, details in routes_to_display_direct:
display_single_route(
route_id,
route_json,
details,
tree=tree,
show_rule_labels=show_rule_labels,
key_prefix=f"subcluster_{sel_cluster_num}_{sel_subcluster_idx}",
)
if remaining_routes:
with st.expander(f"... and {len(remaining_routes)} more routes"):
for route_id, details in remaining_routes:
display_single_route(
route_id,
route_json,
details,
tree=tree,
show_rule_labels=show_rule_labels,
key_prefix=(
f"subcluster_expanded_{sel_cluster_num}_{sel_subcluster_idx}"
),
)
def download_subclustering_results():
"""Subclustering results download for the currently selected subcluster."""
if (
st.session_state.get("subclustering_done", False)
and "subcluster_num_select_key" in st.session_state
and "subcluster_index_select_key" in st.session_state
):
sub = st.session_state.get("subclusters")
tree = st.session_state.get("tree")
sb_cgrs_for_report = st.session_state.get("sb_cgrs_dict")
user_input_cluster_num_display = st.session_state.subcluster_num_select_key
selected_subcluster_idx = st.session_state.subcluster_index_select_key
if not tree or not sub or not sb_cgrs_for_report:
st.warning(
"Missing data for subclustering report generation (tree, subclusters, or SB-CGRs)."
)
return
if (
user_input_cluster_num_display in sub
and selected_subcluster_idx in sub[user_input_cluster_num_display]
):
subcluster_data_for_report = sub[user_input_cluster_num_display][
selected_subcluster_idx
]
processed_subcluster_data = post_process_subgroup(
subcluster_data_for_report
)
if "routes_data" in subcluster_data_for_report and isinstance(
subcluster_data_for_report["routes_data"], dict
):
processed_subcluster_data["group_lgs"] = group_by_identical_values(
subcluster_data_for_report["routes_data"]
)
else:
processed_subcluster_data["group_lgs"] = {}
try:
subcluster_html_content = routes_subclustering_report(
tree,
processed_subcluster_data,
user_input_cluster_num_display,
selected_subcluster_idx,
sb_cgrs_for_report,
if_lg_group=True,
)
st.download_button(
label=(
f"Download report for subcluster "
f"{user_input_cluster_num_display}.{selected_subcluster_idx}"
),
data=subcluster_html_content,
file_name=(
f"subcluster_{user_input_cluster_num_display}."
f"{selected_subcluster_idx}_{st.session_state.target_smiles}.html"
),
mime="text/html",
key=(
f"download_subcluster_{user_input_cluster_num_display}_"
f"{selected_subcluster_idx}"
),
)
except Exception as e:
st.error(
f"Error generating download report for subcluster "
f"{user_input_cluster_num_display}.{selected_subcluster_idx}: {e}"
)
def display_downloads_tab():
"""All download actions grouped in one place."""
st.subheader("Planning reports")
download_planning_results()
st.markdown("---")
st.subheader("Cluster reports")
download_clustering_results()
st.markdown("---")
st.subheader("Subclustering reports")
st.caption(
"Select a cluster and subcluster in the Subclustering tab first, "
"then return here to export the corresponding report."
)
download_subclustering_results()
def implement_restart():
"""Restart: reset application state."""
st.divider()
st.header("Restart application state")
if st.button("Clear all results & restart", key="restart_button"):
keys_to_clear = [
"planning_done",
"tree",
"res",
"target_smiles",
"clustering_done",
"clusters",
"reactions_dict",
"num_clusters_setting",
"route_cgrs_dict",
"sb_cgrs_dict",
"route_json",
"subclustering_done",
"subclusters",
"clusters_downloaded",
"ketcher_widget",
"smiles_text_input_key",
"subcluster_num_select_key",
"subcluster_index_select_key",
"planning_report_html",
"selected_routes",
]
for key in keys_to_clear:
if key in st.session_state:
del st.session_state[key]
st.session_state.ketcher = DEFAULT_MOL
st.session_state.target_smiles = ""
st.rerun()
# --- Main Application Flow with tabs ---
def main():
initialize_app()
setup_sidebar()
# --- Top section: Input & Planning ---
current_smile_code = handle_molecule_input()
if st.session_state.get("ketcher") != current_smile_code:
st.session_state.ketcher = current_smile_code
setup_planning_options()
# If nothing has been run yet, stop here
if not st.session_state.get("planning_done", False):
implement_restart()
return
res = st.session_state.res
# Not solved: show simple planning section, no tabs needed
if not (res and res.get("solved", False)):
st.markdown("---")
st.header("Results")
display_planning_results()
implement_restart()
return
# Solved: run clustering (once)
if not st.session_state.get("clustering_done", False):
run_clustering_core()
if not st.session_state.get("clustering_done", False):
# Planning solved but clustering failed
st.markdown("---")
st.header("Results")
st.success("Planning succeeded.")
st.error(
"Clustering did not complete successfully. Please check logs or adjust settings."
)
st.subheader("Planning downloads")
download_planning_results()
implement_restart()
return
# Clustering done; pre-compute subclustering so navigation is instant
if not st.session_state.get("subclustering_done", False):
run_subclustering_core()
# From this point we have: planning_done, solved, clustering_done (+ usually subclustering_done)
# Clear separation between planning controls and results
st.markdown("---")
st.header("Results")
render_selected_routes_dialog(key_prefix="dialog")
render_selected_routes_navigation()
# Small status row
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Solution found", "Yes")
with col2:
st.metric("Routes", st.session_state.res.get("num_routes", "—"))
with col3:
st.metric("Clusters", len(st.session_state.clusters))
# Results tabs
tab_overview, tab_subclustering, tab_downloads = st.tabs(
["Overview", "Subclustering", "Downloads"]
)
with tab_overview:
show_rule_labels = st.checkbox(
"Show rule labels in routes", key="overview_show_rule_labels"
)
display_planning_and_clustering_results_unified(
show_rule_labels=show_rule_labels
)
with tab_subclustering:
setup_subclustering()
if st.session_state.get("subclustering_done", False):
st.caption(
"Select a cluster and subcluster, and optionally filter routes by number of steps."
)
show_rule_labels = st.checkbox(
"Show rule labels in routes", key="subcluster_show_rule_labels"
)
display_subclustering_results(show_rule_labels=show_rule_labels)
with tab_downloads:
display_downloads_tab()
render_selected_routes_minimized_bar(key_prefix="minbar")
implement_restart()
if __name__ == "__main__":
main()