#!/usr/bin/env python3 """ Simple HTML visualization for a SynPlanner MCTS tree. """ from __future__ import annotations import argparse import json import math import pickle from collections import deque from pathlib import Path from typing import Dict, Iterable, List, Optional, Set, Tuple from synplan.mcts.tree import Tree from synplan.chem.utils import mol_from_smiles from synplan.utils.visualisation import get_route_svg def _node_primary_molecule(node) -> Optional[object]: if getattr(node, "curr_precursor", None) and hasattr(node.curr_precursor, "molecule"): return node.curr_precursor.molecule if getattr(node, "new_precursors", None): precursor = node.new_precursors[0] if hasattr(precursor, "molecule"): return precursor.molecule if getattr(node, "precursors_to_expand", None): precursor = node.precursors_to_expand[0] if hasattr(precursor, "molecule"): return precursor.molecule return None def _depict_molecule_svg(molecule) -> Optional[str]: if molecule is None: return None try: molecule.clean2d() return molecule.depict() except Exception: return None def _svg_from_smiles(smiles: str) -> Optional[str]: if not smiles: return None try: molecule = mol_from_smiles(smiles) except Exception: return None return _depict_molecule_svg(molecule) def _build_target_svg(tree: Tree) -> str: target_node = tree.nodes.get(1) if not target_node or not getattr(target_node, "curr_precursor", None): return "" molecule = target_node.curr_precursor.molecule try: molecule.clean2d() except Exception: pass plane = getattr(molecule, "_plane", None) if not plane: return "" xs = [coord[0] for coord in plane.values()] ys = [coord[1] for coord in plane.values()] if not xs or not ys: return "" pad = 0.7 min_x, max_x = min(xs) - pad, max(xs) + pad min_y, max_y = min(ys) - pad, max(ys) + pad min_y_svg = -max_y max_y_svg = -min_y width = max_x - min_x height = max_y_svg - min_y_svg bond_lines = [] bond_offset = 0.18 for a, b, _bond in molecule.bonds(): if a not in plane or b not in plane: continue x1, y1 = plane[a] x2, y2 = plane[b] y1 = -y1 y2 = -y2 dx = x2 - x1 dy = y2 - y1 norm = (dx * dx + dy * dy) ** 0.5 if norm == 0: continue perp_x = -dy / norm perp_y = dx / norm bond_id = f"{a}-{b}" if a < b else f"{b}-{a}" order = getattr(_bond, "order", 1) if order == 2: offsets = [-bond_offset, bond_offset] bond_class = "target-bond bond-double" elif order == 3: offsets = [-bond_offset, 0.0, bond_offset] bond_class = "target-bond bond-triple" elif order == 4: offsets = [0.0] bond_class = "target-bond bond-aromatic" else: offsets = [0.0] bond_class = "target-bond bond-single" for offset in offsets: ox = perp_x * offset oy = perp_y * offset bond_lines.append( f'' ) atom_marks = [] atom_radius = 0.14 label_size = 0.5 atom_colors = { "N": "#2f6fd0", "O": "#e14b4b", "S": "#d3b338", "F": "#2aa84a", "Cl": "#2aa84a", "Br": "#2aa84a", "I": "#2aa84a", } for atom_id, atom in molecule.atoms(): if atom_id not in plane: continue x, y = plane[atom_id] y = -y symbol = getattr(atom, "atomic_symbol", "C") fill = atom_colors.get(symbol, "#1f242a55") if symbol == "C": atom_marks.append( f'' ) else: mask_radius = atom_radius * 1.9 atom_marks.append( f'' ) atom_marks.append( f'{symbol}' ) return f""" {"".join(bond_lines)} {"".join(atom_marks)} """ def _ends_with_pickle_stop(path: Path) -> bool: size = path.stat().st_size if size == 0: return False with path.open("rb") as handle: handle.seek(-1, 2) return handle.read(1) == b"." def _load_tree(tree_pkl: Path) -> Tree: if not tree_pkl.exists(): raise FileNotFoundError(f"Tree pickle not found: {tree_pkl}") if tree_pkl.stat().st_size == 0: raise ValueError(f"Tree pickle is empty: {tree_pkl}") if not _ends_with_pickle_stop(tree_pkl): raise ValueError( "Tree pickle appears truncated (missing STOP opcode). " "Re-save the tree and try again." ) try: with tree_pkl.open("rb") as handle: loaded = pickle.load(handle) except EOFError as exc: raise ValueError( "Tree pickle is incomplete or corrupted (unexpected EOF). " "Re-save the tree and try again." ) from exc except pickle.UnpicklingError as exc: raise ValueError( "Tree pickle could not be unpickled. " "Re-save the tree and try again." ) from exc except ModuleNotFoundError as exc: raise ModuleNotFoundError( "Missing dependency while unpickling. " "Run this in the same environment where the tree was saved." ) from exc if hasattr(loaded, "tree"): return loaded.tree return loaded def _load_clusters(clusters_pkl: Optional[Path]) -> Dict[str, dict]: if not clusters_pkl: return {} clusters_pkl = Path(clusters_pkl) if not clusters_pkl.exists(): raise FileNotFoundError(f"Clusters pickle not found: {clusters_pkl}") if clusters_pkl.stat().st_size == 0: raise ValueError(f"Clusters pickle is empty: {clusters_pkl}") try: with clusters_pkl.open("rb") as handle: loaded = pickle.load(handle) except EOFError as exc: raise ValueError( "Clusters pickle is incomplete or corrupted (unexpected EOF). " "Re-save the clusters and try again." ) from exc except pickle.UnpicklingError as exc: raise ValueError( "Clusters pickle could not be unpickled. " "Re-save the clusters and try again." ) from exc except ModuleNotFoundError as exc: raise ModuleNotFoundError( "Missing dependency while unpickling clusters. " "Run this in the same environment where the clusters were saved." ) from exc if not isinstance(loaded, dict): raise TypeError("Clusters pickle must contain a dict.") return loaded def _route_nodes_by_route( tree: Tree, route_ids: Optional[Iterable[int]] = None ) -> Dict[str, List[int]]: if route_ids is None: route_ids_set = set(tree.winning_nodes) else: route_ids_set = {int(rid) for rid in route_ids} route_nodes: Dict[str, List[int]] = {} for route_id in sorted(route_ids_set): if route_id not in tree.nodes: continue nodes: List[int] = [] current = route_id seen: Set[int] = set() while current and current not in seen: seen.add(current) nodes.append(current) current = tree.parents.get(current) route_nodes[str(route_id)] = nodes return route_nodes def _is_building_block(precursor, tree: Tree) -> bool: if precursor is None: return False try: return precursor.is_building_block( tree.building_blocks, getattr(tree.config, "min_mol_size", 0) ) except Exception: try: return str(precursor) in tree.building_blocks except Exception: return False def _route_extras_by_route( tree: Tree, route_ids: Iterable[int] ) -> Dict[str, Dict[str, object]]: extras: Dict[str, Dict[str, object]] = {} for route_id in sorted(route_ids): if route_id not in tree.nodes: continue path_ids: List[int] = [] current = route_id seen: Set[int] = set() while current and current not in seen: seen.add(current) path_ids.append(current) current = tree.parents.get(current) path_ids = list(reversed(path_ids)) if len(path_ids) < 2: continue smiles_to_node: Dict[str, int] = {} base_smiles: Set[str] = set() for node_id in path_ids: node = tree.nodes.get(node_id) molecule = _node_primary_molecule(node) if molecule is None: continue try: smiles = str(molecule) except Exception: continue if smiles: base_smiles.add(smiles) if smiles not in smiles_to_node: smiles_to_node[smiles] = node_id by_parent: Dict[str, List[Dict[str, str]]] = {} route_seen_smiles: Set[str] = set() for before_id, after_id in zip(path_ids, path_ids[1:]): before = tree.nodes.get(before_id) after = tree.nodes.get(after_id) if before is None or after is None: continue parent_precursor = getattr(before, "curr_precursor", None) if parent_precursor is None: continue try: parent_smiles = str(parent_precursor) except Exception: continue if not parent_smiles: continue extra_items: List[Dict[str, str]] = [] seen_smiles: Set[str] = set() for precursor in getattr(after, "new_precursors", ()) or (): try: child_smiles = str(precursor) except Exception: continue if ( not child_smiles or child_smiles in smiles_to_node or child_smiles in base_smiles or child_smiles in route_seen_smiles ): continue if child_smiles in seen_smiles: continue seen_smiles.add(child_smiles) route_seen_smiles.add(child_smiles) status = "leaf" if _is_building_block(precursor, tree) else "intermediate" extra_items.append( { "smiles": child_smiles, "status": status, "svg": _svg_from_smiles(child_smiles), } ) if extra_items: by_parent[str(before_id)] = extra_items if by_parent: extras[str(route_id)] = {"by_parent": by_parent} return extras def _normalize_strat_bonds( strat_bonds: Optional[Iterable[Iterable[int]]], ) -> List[List[int]]: if not strat_bonds: return [] normalized: List[List[int]] = [] seen: Set[Tuple[int, int]] = set() for bond in strat_bonds: if not bond or len(bond) < 2: continue try: a, b = bond except (TypeError, ValueError): continue try: a_int = int(a) b_int = int(b) except (TypeError, ValueError): continue pair = tuple(sorted((a_int, b_int))) if pair in seen: continue seen.add(pair) normalized.append([pair[0], pair[1]]) normalized.sort() return normalized def _build_cluster_payload(clusters: Dict[str, dict]) -> List[Dict[str, object]]: payload: List[Dict[str, object]] = [] for cluster_id, data in clusters.items(): if not isinstance(data, dict): continue bonds = _normalize_strat_bonds(data.get("strat_bonds")) route_ids_raw = data.get("route_ids") or [] route_ids: List[int] = [] for rid in route_ids_raw: try: route_ids.append(int(rid)) except (TypeError, ValueError): continue payload.append( { "id": str(cluster_id), "bonds": bonds, "route_ids": sorted(set(route_ids)), } ) return payload def _group_nodes_by_depth(nodes_depth: Dict[int, int]) -> Dict[int, list]: by_depth: Dict[int, list] = {} for node_id, depth in nodes_depth.items(): by_depth.setdefault(depth, []).append(node_id) for node_ids in by_depth.values(): node_ids.sort() return by_depth def _build_children_map( tree: Tree, allowed_nodes: Optional[Set[int]] = None ) -> Dict[int, List[int]]: if allowed_nodes is None: allowed_nodes = set(tree.nodes.keys()) else: allowed_nodes = set(allowed_nodes) allowed_nodes.add(1) children_map: Dict[int, List[int]] = {node_id: [] for node_id in allowed_nodes} for child_id, parent_id in tree.parents.items(): if child_id == 1 or not parent_id: continue if parent_id not in allowed_nodes or child_id not in allowed_nodes: continue children_map[parent_id].append(child_id) for node_id, children in children_map.items(): children.sort() return children_map def _sorted_children(children_map: Dict[int, List[int]], node_id: int) -> List[int]: return children_map.get(node_id, []) def _compute_depths( children_map: Dict[int, List[int]], root_id: int = 1 ) -> Dict[int, int]: if root_id not in children_map: return {} depths = {root_id: 0} queue = deque([root_id]) while queue: node_id = queue.popleft() for child_id in children_map.get(node_id, []): if child_id in depths: continue depths[child_id] = depths[node_id] + 1 queue.append(child_id) return depths def _compute_subtree_leaf_counts( children_map: Dict[int, List[int]], root_id: int = 1 ) -> Dict[int, int]: order: List[int] = [] if root_id not in children_map: return {} stack = [root_id] while stack: node_id = stack.pop() order.append(node_id) stack.extend(_sorted_children(children_map, node_id)) leaf_counts: Dict[int, int] = {} for node_id in reversed(order): children = _sorted_children(children_map, node_id) if not children: leaf_counts[node_id] = 1 else: leaf_counts[node_id] = sum(leaf_counts[c] for c in children) return leaf_counts def _assign_subtree_angles( children_map: Dict[int, List[int]], leaf_counts: Dict[int, int], root_id: int = 1, base_gap: float = 0.04, ) -> Dict[int, float]: angles: Dict[int, float] = {} if root_id not in children_map: return angles stack = [(root_id, 0.0, 2.0 * math.pi)] while stack: node_id, start_angle, end_angle = stack.pop() angles[node_id] = (start_angle + end_angle) / 2.0 children = _sorted_children(children_map, node_id) if not children: continue span = max(0.0, end_angle - start_angle) total = sum(leaf_counts.get(child, 1) for child in children) if total <= 0 or span <= 0.0: continue if len(children) > 1: max_gap = span * 0.15 / (len(children) - 1) gap = min(base_gap, max_gap) else: gap = 0.0 span_for_children = span - gap * (len(children) - 1) if span_for_children <= 0.0: gap = 0.0 span_for_children = span cursor = start_angle for child in children: frac = leaf_counts.get(child, 1) / total child_span = span_for_children * frac child_start = cursor child_end = cursor + child_span stack.append((child, child_start, child_end)) cursor = child_end + gap return angles def _compute_radius_scale( by_depth: Dict[int, list], angles: Dict[int, float], radius_step: float, node_radius: float, spacing_factor: float = 2.2, root_gap_factor: float = 2.8, ) -> float: min_distance = node_radius * spacing_factor scale = 1.0 epsilon = 1e-6 scale = max(scale, min_distance / max(radius_step, epsilon)) root_gap = node_radius * root_gap_factor scale = max(scale, root_gap / max(radius_step, epsilon)) for depth, node_ids in by_depth.items(): if depth == 0 or len(node_ids) < 2: continue radius = depth * radius_step depth_angles = sorted(angles.get(node_id, 0.0) for node_id in node_ids) deltas = [] for left, right in zip(depth_angles, depth_angles[1:]): deltas.append(right - left) deltas.append(2.0 * math.pi - depth_angles[-1] + depth_angles[0]) min_delta = max(min(deltas), epsilon) required = min_distance / (radius * min_delta) scale = max(scale, required) return max(scale, 1.0) def _radial_layout( nodes_depth: Dict[int, int], children_map: Dict[int, List[int]], radius_step: float, node_radius: float, spacing_factor: float = 2.2, root_gap_factor: float = 2.8, ) -> Dict[int, Tuple[float, float]]: if not nodes_depth: return {} by_depth = _group_nodes_by_depth(nodes_depth) leaf_counts = _compute_subtree_leaf_counts(children_map) angles = _assign_subtree_angles(children_map, leaf_counts) scale = _compute_radius_scale( by_depth, angles, radius_step, node_radius, spacing_factor=spacing_factor, root_gap_factor=root_gap_factor, ) radius_step *= scale positions: Dict[int, Tuple[float, float]] = {} for node_id, depth in nodes_depth.items(): if depth == 0: positions[node_id] = (0.0, 0.0) continue angle = angles.get(node_id, 0.0) radius = depth * radius_step positions[node_id] = (radius * math.cos(angle), radius * math.sin(angle)) return positions def _node_status(tree: Tree, node_id: int) -> str: if node_id == 1: return "target" node = tree.nodes[node_id] if node.is_solved(): return "leaf" return "intermediate" def _node_metadata( tree: Tree, node_id: int, route_index: Dict[int, int] ) -> Dict[str, object]: node = tree.nodes[node_id] molecule = _node_primary_molecule(node) smiles = str(molecule) if molecule is not None else None return { "node_id": node_id, "route_id": node_id if node_id in route_index else None, "route_index": route_index.get(node_id), "depth": tree.nodes_depth.get(node_id, 0), "visits": tree.nodes_visit.get(node_id, 0), "num_children": len(tree.children.get(node_id, [])), "rule_id": tree.nodes_rules.get(node_id), "rule_label": tree.nodes_rule_label.get(node_id), "solved": bool(node.is_solved()), "smiles": smiles, "svg": _depict_molecule_svg(molecule), "pending_smiles": ( [str(p) for p in node.precursors_to_expand] if node.precursors_to_expand else [] ), "new_smiles": ( [str(p) for p in node.new_precursors] if node.new_precursors else [] ), } def _edges_from_tree( tree: Tree, allowed_nodes: Optional[Set[int]] = None ) -> Iterable[Tuple[int, int]]: allowed = set(allowed_nodes) if allowed_nodes is not None else None for child_id, parent_id in tree.parents.items(): if child_id == 1 or not parent_id: continue if allowed is not None and (child_id not in allowed or parent_id not in allowed): continue yield parent_id, child_id def _winning_route_edges(tree: Tree) -> set: edges = set() for node_id in tree.winning_nodes: current = node_id while current and current in tree.parents: parent = tree.parents.get(current) if not parent: break edges.add((parent, current)) current = parent return edges def _scale_positions( positions: Dict[int, Tuple[float, float]], node_radius: float, render_scale: float, ) -> Tuple[Dict[int, Tuple[float, float]], float]: render_scale = max(render_scale, 0.01) scaled_positions = { node_id: (pos[0] * render_scale, pos[1] * render_scale) for node_id, pos in positions.items() } return scaled_positions, node_radius * render_scale def _render_svg( tree: Tree, positions: Dict[int, Tuple[float, float]], edges: Iterable[Tuple[int, int]], winning_edges: Set[Tuple[int, int]], node_radius: float, nodes_depth: Dict[int, int], radius_step: Optional[float] = None, pad_scale: float = 4.0, ) -> str: if not positions: return "" xs = [pos[0] for pos in positions.values()] ys = [pos[1] for pos in positions.values()] max_radius = node_radius * 2.0 pad = max_radius * pad_scale min_x, max_x = min(xs) - pad, max(xs) + pad min_y, max_y = min(ys) - pad, max(ys) + pad view_w = max_x - min_x view_h = max_y - min_y line_items = [] for parent_id, child_id in edges: if parent_id not in positions or child_id not in positions: continue x1, y1 = positions[parent_id] x2, y2 = positions[child_id] line_class = "edge-winning" if (parent_id, child_id) in winning_edges else "" line_items.append( f'' ) circle_items = [] for node_id, (x, y) in positions.items(): status = _node_status(tree, node_id) delay = min(nodes_depth.get(node_id, 0) * 0.03, 0.6) radius = node_radius * 2.0 if node_id == 1 else node_radius parent_id = tree.parents.get(node_id, 0) circle_items.append( f'' ) depth_circles = [] if radius_step is not None: max_depth = max(nodes_depth.values()) if nodes_depth else 0 for depth in range(1, max_depth + 1): depth_circles.append( f'' ) return f""" {"".join(depth_circles)} {"".join(line_items)} {"".join(circle_items)} """ def _format_point(value: float) -> str: return f"{value:.2f}" def generate_tree_html( tree: Tree, output_path: Path, radius_step: float = 280.0, node_radius: float = 80.0, render_scale: float = 0.25, clusters_pkl: Optional[Path] = None, ) -> None: full_children = _build_children_map(tree) full_nodes_depth = _compute_depths(full_children) full_radius_step = radius_step * 0.4 full_node_radius = node_radius * 10.0 full_render_scale = min(render_scale * 1.0, 1.0) full_positions = _radial_layout( full_nodes_depth, full_children, radius_step=full_radius_step, node_radius=full_node_radius, spacing_factor=3.4, root_gap_factor=4.0, ) full_angles = _assign_subtree_angles( full_children, _compute_subtree_leaf_counts(full_children) ) full_layout_scale = _compute_radius_scale( _group_nodes_by_depth(full_nodes_depth), full_angles, full_radius_step, full_node_radius, spacing_factor=3.4, root_gap_factor=4.0, ) full_edges = list(_edges_from_tree(tree)) winning_edges = _winning_route_edges(tree) if not full_positions: raise ValueError("Tree has no nodes to render.") if winning_edges: solved_nodes: Set[int] = {1} for parent_id, child_id in winning_edges: solved_nodes.add(parent_id) solved_nodes.add(child_id) else: solved_nodes = set(tree.nodes.keys()) solved_children = _build_children_map(tree, allowed_nodes=solved_nodes) solved_nodes_depth = _compute_depths(solved_children) solved_radius_step = radius_step * 0.5 solved_node_radius = node_radius * 1.15 solved_render_scale = min(render_scale * 0.4, 1.0) solved_positions = _radial_layout( solved_nodes_depth, solved_children, radius_step=solved_radius_step, node_radius=solved_node_radius, ) solved_angles = _assign_subtree_angles( solved_children, _compute_subtree_leaf_counts(solved_children) ) solved_layout_scale = _compute_radius_scale( _group_nodes_by_depth(solved_nodes_depth), solved_angles, solved_radius_step, solved_node_radius, ) solved_edges = list(_edges_from_tree(tree, allowed_nodes=solved_nodes)) full_positions, full_render_radius = _scale_positions( full_positions, full_node_radius, full_render_scale ) solved_positions, solved_render_radius = _scale_positions( solved_positions, solved_node_radius, solved_render_scale ) full_render_step = full_radius_step * full_layout_scale * full_render_scale solved_render_step = solved_radius_step * solved_layout_scale * solved_render_scale svg_full = _render_svg( tree, full_positions, full_edges, winning_edges, full_render_radius, full_nodes_depth, radius_step=full_render_step, pad_scale=1.8, ) svg_solved = _render_svg( tree, solved_positions, solved_edges, winning_edges, solved_render_radius, solved_nodes_depth, radius_step=solved_render_step, pad_scale=1.6, ) if not svg_solved: svg_solved = svg_full solved_positions = full_positions node_meta = {} route_index = {node_id: idx for idx, node_id in enumerate(tree.winning_nodes)} for node_id in tree.nodes: node_meta[str(node_id)] = _node_metadata(tree, node_id, route_index) target_svg = _build_target_svg(tree) target_smiles = str(tree.nodes[1].curr_precursor) if tree.nodes.get(1) else "" clusters_payload: List[Dict[str, object]] = [] route_nodes: Dict[str, List[int]] = {} route_extras: Dict[str, Dict[str, object]] = {} route_svgs: Dict[str, str] = {} if clusters_pkl: clusters = _load_clusters(clusters_pkl) clusters_payload = _build_cluster_payload(clusters) cluster_route_ids = { route_id for cluster in clusters_payload for route_id in cluster.get("route_ids", []) } route_nodes = _route_nodes_by_route(tree, cluster_route_ids) route_extras = _route_extras_by_route(tree, cluster_route_ids) for route_id in sorted(cluster_route_ids): if route_id not in tree.winning_nodes: continue svg = get_route_svg(tree, route_id) if svg: route_svgs[str(route_id)] = svg html = f""" SynPlanner Tree Visualization
{svg_full}
{svg_solved}
Legend
Target molecule
Intermediate product
Building block
Solved pathway
Unsolved pathway
Route ID: Total {len(tree.winning_nodes)} Current 0
Target molecule
{target_svg if target_svg else "
Target depiction unavailable.
"}
Selected bonds (CGRtools atom ids): []
Clusters: not loaded
""" output_path = Path(output_path) output_path.write_text(html, encoding="utf-8") def main() -> None: parser = argparse.ArgumentParser( description="Generate a simple HTML visualization for a SynPlanner MCTS tree." ) parser.add_argument( "--tree-pkl", type=Path, required=True, help="Path to a pickled Tree or TreeWrapper object.", ) parser.add_argument( "--out", type=Path, default=Path("tree_visualization.html"), help="Output HTML file path.", ) parser.add_argument( "--radius-step", type=float, default=280.0, help="Radial distance between depth rings.", ) parser.add_argument( "--node-radius", type=float, default=80.0, help="Node radius in SVG units (pixels when not fit-to-screen).", ) parser.add_argument( "--render-scale", type=float, default=0.25, help="Scale factor applied to the final render (1.0 = full size).", ) parser.add_argument( "--clusters-pkl", type=Path, default=None, help="Optional path to a clusters pickle (cluster_routes output).", ) args = parser.parse_args() tree = _load_tree(args.tree_pkl) if not isinstance(tree, Tree): raise TypeError("Loaded object is not a Tree.") generate_tree_html( tree, output_path=args.out, radius_step=args.radius_step, node_radius=args.node_radius, render_scale=args.render_scale, clusters_pkl=args.clusters_pkl, ) print(f"Tree visualization written to {args.out}") if __name__ == "__main__": main()