#!/usr/bin/env python3
"""
Simple HTML visualization for a SynPlanner MCTS tree.
"""
from __future__ import annotations
import argparse
import json
import math
import pickle
from collections import deque
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple
from synplan.mcts.tree import Tree
from synplan.chem.utils import mol_from_smiles
from synplan.utils.visualisation import get_route_svg
def _node_primary_molecule(node) -> Optional[object]:
if getattr(node, "curr_precursor", None) and hasattr(node.curr_precursor, "molecule"):
return node.curr_precursor.molecule
if getattr(node, "new_precursors", None):
precursor = node.new_precursors[0]
if hasattr(precursor, "molecule"):
return precursor.molecule
if getattr(node, "precursors_to_expand", None):
precursor = node.precursors_to_expand[0]
if hasattr(precursor, "molecule"):
return precursor.molecule
return None
def _depict_molecule_svg(molecule) -> Optional[str]:
if molecule is None:
return None
try:
molecule.clean2d()
return molecule.depict()
except Exception:
return None
def _svg_from_smiles(smiles: str) -> Optional[str]:
if not smiles:
return None
try:
molecule = mol_from_smiles(smiles)
except Exception:
return None
return _depict_molecule_svg(molecule)
def _build_target_svg(tree: Tree) -> str:
    """Hand-draw the target molecule (node 1) as a custom SVG fragment.

    Coordinates come from the molecule's private ``_plane`` mapping
    (``atom_id -> (x, y)``) after a best-effort ``clean2d()``. The y values
    are negated because the SVG y axis points down. Bonds of order 2 and 3
    are drawn as 2 or 3 parallel strokes offset perpendicular to the bond.
    Order 4 (aromatic, per its CSS class) and any other order get a single
    stroke. Carbon atoms get one mark. Other elements get a larger mask
    circle plus a label, presumably coloured from ``atom_colors``.

    Returns:
        The SVG markup, or ``""`` when node 1, its ``curr_precursor`` or the
        2D coordinates are unavailable.

    NOTE(review): every markup template in this copy of the block is an
    empty f-string. As a result the computed geometry, CSS classes and
    colours (``width``, ``height``, ``bond_class``, ``fill``, ...) are never
    emitted, and the function returns a bare newline. Confirm that the
    templates were not lost.
    """
    target_node = tree.nodes.get(1)
    if not target_node or not getattr(target_node, "curr_precursor", None):
        return ""
    molecule = target_node.curr_precursor.molecule
    # Layout is best-effort: fall back to whatever coordinates already exist.
    try:
        molecule.clean2d()
    except Exception:
        pass
    plane = getattr(molecule, "_plane", None)
    if not plane:
        return ""
    xs = [coord[0] for coord in plane.values()]
    ys = [coord[1] for coord in plane.values()]
    if not xs or not ys:
        return ""
    # Bounding box in molecule units, padded so atoms at the edge are not clipped.
    pad = 0.7
    min_x, max_x = min(xs) - pad, max(xs) + pad
    min_y, max_y = min(ys) - pad, max(ys) + pad
    # Flip the y range to match the negated y coordinates used below.
    min_y_svg = -max_y
    max_y_svg = -min_y
    width = max_x - min_x
    height = max_y_svg - min_y_svg
    bond_lines = []
    # Perpendicular distance between the parallel strokes of multiple bonds.
    bond_offset = 0.18
    for a, b, _bond in molecule.bonds():
        if a not in plane or b not in plane:
            continue
        x1, y1 = plane[a]
        x2, y2 = plane[b]
        y1 = -y1
        y2 = -y2
        dx = x2 - x1
        dy = y2 - y1
        norm = (dx * dx + dy * dy) ** 0.5
        # Coincident atoms: no direction to draw along.
        if norm == 0:
            continue
        # Unit vector perpendicular to the bond, used to offset parallel strokes.
        perp_x = -dy / norm
        perp_y = dx / norm
        # Order-independent id so (a, b) and (b, a) name the same bond.
        bond_id = f"{a}-{b}" if a < b else f"{b}-{a}"
        order = getattr(_bond, "order", 1)
        if order == 2:
            offsets = [-bond_offset, bond_offset]
            bond_class = "target-bond bond-double"
        elif order == 3:
            offsets = [-bond_offset, 0.0, bond_offset]
            bond_class = "target-bond bond-triple"
        elif order == 4:
            offsets = [0.0]
            bond_class = "target-bond bond-aromatic"
        else:
            offsets = [0.0]
            bond_class = "target-bond bond-single"
        for offset in offsets:
            ox = perp_x * offset
            oy = perp_y * offset
            bond_lines.append(
                f''
            )
    atom_marks = []
    atom_radius = 0.14
    label_size = 0.5
    # Element colours for heteroatoms; anything else uses the default fill below.
    atom_colors = {
        "N": "#2f6fd0",
        "O": "#e14b4b",
        "S": "#d3b338",
        "F": "#2aa84a",
        "Cl": "#2aa84a",
        "Br": "#2aa84a",
        "I": "#2aa84a",
    }
    for atom_id, atom in molecule.atoms():
        if atom_id not in plane:
            continue
        x, y = plane[atom_id]
        y = -y
        symbol = getattr(atom, "atomic_symbol", "C")
        fill = atom_colors.get(symbol, "#1f242a55")
        if symbol == "C":
            atom_marks.append(
                f''
            )
        else:
            # A larger circle, presumably to mask bond strokes behind the label.
            mask_radius = atom_radius * 1.9
            atom_marks.append(
                f''
            )
            atom_marks.append(
                f'{symbol}'
            )
    return f"""
"""
def _ends_with_pickle_stop(path: Path) -> bool:
size = path.stat().st_size
if size == 0:
return False
with path.open("rb") as handle:
handle.seek(-1, 2)
return handle.read(1) == b"."
def _load_tree(tree_pkl: Path) -> Tree:
if not tree_pkl.exists():
raise FileNotFoundError(f"Tree pickle not found: {tree_pkl}")
if tree_pkl.stat().st_size == 0:
raise ValueError(f"Tree pickle is empty: {tree_pkl}")
if not _ends_with_pickle_stop(tree_pkl):
raise ValueError(
"Tree pickle appears truncated (missing STOP opcode). "
"Re-save the tree and try again."
)
try:
with tree_pkl.open("rb") as handle:
loaded = pickle.load(handle)
except EOFError as exc:
raise ValueError(
"Tree pickle is incomplete or corrupted (unexpected EOF). "
"Re-save the tree and try again."
) from exc
except pickle.UnpicklingError as exc:
raise ValueError(
"Tree pickle could not be unpickled. "
"Re-save the tree and try again."
) from exc
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Missing dependency while unpickling. "
"Run this in the same environment where the tree was saved."
) from exc
if hasattr(loaded, "tree"):
return loaded.tree
return loaded
def _load_clusters(clusters_pkl: Optional[Path]) -> Dict[str, dict]:
if not clusters_pkl:
return {}
clusters_pkl = Path(clusters_pkl)
if not clusters_pkl.exists():
raise FileNotFoundError(f"Clusters pickle not found: {clusters_pkl}")
if clusters_pkl.stat().st_size == 0:
raise ValueError(f"Clusters pickle is empty: {clusters_pkl}")
try:
with clusters_pkl.open("rb") as handle:
loaded = pickle.load(handle)
except EOFError as exc:
raise ValueError(
"Clusters pickle is incomplete or corrupted (unexpected EOF). "
"Re-save the clusters and try again."
) from exc
except pickle.UnpicklingError as exc:
raise ValueError(
"Clusters pickle could not be unpickled. "
"Re-save the clusters and try again."
) from exc
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Missing dependency while unpickling clusters. "
"Run this in the same environment where the clusters were saved."
) from exc
if not isinstance(loaded, dict):
raise TypeError("Clusters pickle must contain a dict.")
return loaded
def _route_nodes_by_route(
tree: Tree, route_ids: Optional[Iterable[int]] = None
) -> Dict[str, List[int]]:
if route_ids is None:
route_ids_set = set(tree.winning_nodes)
else:
route_ids_set = {int(rid) for rid in route_ids}
route_nodes: Dict[str, List[int]] = {}
for route_id in sorted(route_ids_set):
if route_id not in tree.nodes:
continue
nodes: List[int] = []
current = route_id
seen: Set[int] = set()
while current and current not in seen:
seen.add(current)
nodes.append(current)
current = tree.parents.get(current)
route_nodes[str(route_id)] = nodes
return route_nodes
def _is_building_block(precursor, tree: Tree) -> bool:
if precursor is None:
return False
try:
return precursor.is_building_block(
tree.building_blocks, getattr(tree.config, "min_mol_size", 0)
)
except Exception:
try:
return str(precursor) in tree.building_blocks
except Exception:
return False
def _route_extras_by_route(
tree: Tree, route_ids: Iterable[int]
) -> Dict[str, Dict[str, object]]:
extras: Dict[str, Dict[str, object]] = {}
for route_id in sorted(route_ids):
if route_id not in tree.nodes:
continue
path_ids: List[int] = []
current = route_id
seen: Set[int] = set()
while current and current not in seen:
seen.add(current)
path_ids.append(current)
current = tree.parents.get(current)
path_ids = list(reversed(path_ids))
if len(path_ids) < 2:
continue
smiles_to_node: Dict[str, int] = {}
base_smiles: Set[str] = set()
for node_id in path_ids:
node = tree.nodes.get(node_id)
molecule = _node_primary_molecule(node)
if molecule is None:
continue
try:
smiles = str(molecule)
except Exception:
continue
if smiles:
base_smiles.add(smiles)
if smiles not in smiles_to_node:
smiles_to_node[smiles] = node_id
by_parent: Dict[str, List[Dict[str, str]]] = {}
route_seen_smiles: Set[str] = set()
for before_id, after_id in zip(path_ids, path_ids[1:]):
before = tree.nodes.get(before_id)
after = tree.nodes.get(after_id)
if before is None or after is None:
continue
parent_precursor = getattr(before, "curr_precursor", None)
if parent_precursor is None:
continue
try:
parent_smiles = str(parent_precursor)
except Exception:
continue
if not parent_smiles:
continue
extra_items: List[Dict[str, str]] = []
seen_smiles: Set[str] = set()
for precursor in getattr(after, "new_precursors", ()) or ():
try:
child_smiles = str(precursor)
except Exception:
continue
if (
not child_smiles
or child_smiles in smiles_to_node
or child_smiles in base_smiles
or child_smiles in route_seen_smiles
):
continue
if child_smiles in seen_smiles:
continue
seen_smiles.add(child_smiles)
route_seen_smiles.add(child_smiles)
status = "leaf" if _is_building_block(precursor, tree) else "intermediate"
extra_items.append(
{
"smiles": child_smiles,
"status": status,
"svg": _svg_from_smiles(child_smiles),
}
)
if extra_items:
by_parent[str(before_id)] = extra_items
if by_parent:
extras[str(route_id)] = {"by_parent": by_parent}
return extras
def _normalize_strat_bonds(
strat_bonds: Optional[Iterable[Iterable[int]]],
) -> List[List[int]]:
if not strat_bonds:
return []
normalized: List[List[int]] = []
seen: Set[Tuple[int, int]] = set()
for bond in strat_bonds:
if not bond or len(bond) < 2:
continue
try:
a, b = bond
except (TypeError, ValueError):
continue
try:
a_int = int(a)
b_int = int(b)
except (TypeError, ValueError):
continue
pair = tuple(sorted((a_int, b_int)))
if pair in seen:
continue
seen.add(pair)
normalized.append([pair[0], pair[1]])
normalized.sort()
return normalized
def _build_cluster_payload(clusters: Dict[str, dict]) -> List[Dict[str, object]]:
payload: List[Dict[str, object]] = []
for cluster_id, data in clusters.items():
if not isinstance(data, dict):
continue
bonds = _normalize_strat_bonds(data.get("strat_bonds"))
route_ids_raw = data.get("route_ids") or []
route_ids: List[int] = []
for rid in route_ids_raw:
try:
route_ids.append(int(rid))
except (TypeError, ValueError):
continue
payload.append(
{
"id": str(cluster_id),
"bonds": bonds,
"route_ids": sorted(set(route_ids)),
}
)
return payload
def _group_nodes_by_depth(nodes_depth: Dict[int, int]) -> Dict[int, list]:
by_depth: Dict[int, list] = {}
for node_id, depth in nodes_depth.items():
by_depth.setdefault(depth, []).append(node_id)
for node_ids in by_depth.values():
node_ids.sort()
return by_depth
def _build_children_map(
tree: Tree, allowed_nodes: Optional[Set[int]] = None
) -> Dict[int, List[int]]:
if allowed_nodes is None:
allowed_nodes = set(tree.nodes.keys())
else:
allowed_nodes = set(allowed_nodes)
allowed_nodes.add(1)
children_map: Dict[int, List[int]] = {node_id: [] for node_id in allowed_nodes}
for child_id, parent_id in tree.parents.items():
if child_id == 1 or not parent_id:
continue
if parent_id not in allowed_nodes or child_id not in allowed_nodes:
continue
children_map[parent_id].append(child_id)
for node_id, children in children_map.items():
children.sort()
return children_map
def _sorted_children(children_map: Dict[int, List[int]], node_id: int) -> List[int]:
return children_map.get(node_id, [])
def _compute_depths(
children_map: Dict[int, List[int]], root_id: int = 1
) -> Dict[int, int]:
if root_id not in children_map:
return {}
depths = {root_id: 0}
queue = deque([root_id])
while queue:
node_id = queue.popleft()
for child_id in children_map.get(node_id, []):
if child_id in depths:
continue
depths[child_id] = depths[node_id] + 1
queue.append(child_id)
return depths
def _compute_subtree_leaf_counts(
children_map: Dict[int, List[int]], root_id: int = 1
) -> Dict[int, int]:
order: List[int] = []
if root_id not in children_map:
return {}
stack = [root_id]
while stack:
node_id = stack.pop()
order.append(node_id)
stack.extend(_sorted_children(children_map, node_id))
leaf_counts: Dict[int, int] = {}
for node_id in reversed(order):
children = _sorted_children(children_map, node_id)
if not children:
leaf_counts[node_id] = 1
else:
leaf_counts[node_id] = sum(leaf_counts[c] for c in children)
return leaf_counts
def _assign_subtree_angles(
children_map: Dict[int, List[int]],
leaf_counts: Dict[int, int],
root_id: int = 1,
base_gap: float = 0.04,
) -> Dict[int, float]:
angles: Dict[int, float] = {}
if root_id not in children_map:
return angles
stack = [(root_id, 0.0, 2.0 * math.pi)]
while stack:
node_id, start_angle, end_angle = stack.pop()
angles[node_id] = (start_angle + end_angle) / 2.0
children = _sorted_children(children_map, node_id)
if not children:
continue
span = max(0.0, end_angle - start_angle)
total = sum(leaf_counts.get(child, 1) for child in children)
if total <= 0 or span <= 0.0:
continue
if len(children) > 1:
max_gap = span * 0.15 / (len(children) - 1)
gap = min(base_gap, max_gap)
else:
gap = 0.0
span_for_children = span - gap * (len(children) - 1)
if span_for_children <= 0.0:
gap = 0.0
span_for_children = span
cursor = start_angle
for child in children:
frac = leaf_counts.get(child, 1) / total
child_span = span_for_children * frac
child_start = cursor
child_end = cursor + child_span
stack.append((child, child_start, child_end))
cursor = child_end + gap
return angles
def _compute_radius_scale(
by_depth: Dict[int, list],
angles: Dict[int, float],
radius_step: float,
node_radius: float,
spacing_factor: float = 2.2,
root_gap_factor: float = 2.8,
) -> float:
min_distance = node_radius * spacing_factor
scale = 1.0
epsilon = 1e-6
scale = max(scale, min_distance / max(radius_step, epsilon))
root_gap = node_radius * root_gap_factor
scale = max(scale, root_gap / max(radius_step, epsilon))
for depth, node_ids in by_depth.items():
if depth == 0 or len(node_ids) < 2:
continue
radius = depth * radius_step
depth_angles = sorted(angles.get(node_id, 0.0) for node_id in node_ids)
deltas = []
for left, right in zip(depth_angles, depth_angles[1:]):
deltas.append(right - left)
deltas.append(2.0 * math.pi - depth_angles[-1] + depth_angles[0])
min_delta = max(min(deltas), epsilon)
required = min_distance / (radius * min_delta)
scale = max(scale, required)
return max(scale, 1.0)
def _radial_layout(
nodes_depth: Dict[int, int],
children_map: Dict[int, List[int]],
radius_step: float,
node_radius: float,
spacing_factor: float = 2.2,
root_gap_factor: float = 2.8,
) -> Dict[int, Tuple[float, float]]:
if not nodes_depth:
return {}
by_depth = _group_nodes_by_depth(nodes_depth)
leaf_counts = _compute_subtree_leaf_counts(children_map)
angles = _assign_subtree_angles(children_map, leaf_counts)
scale = _compute_radius_scale(
by_depth,
angles,
radius_step,
node_radius,
spacing_factor=spacing_factor,
root_gap_factor=root_gap_factor,
)
radius_step *= scale
positions: Dict[int, Tuple[float, float]] = {}
for node_id, depth in nodes_depth.items():
if depth == 0:
positions[node_id] = (0.0, 0.0)
continue
angle = angles.get(node_id, 0.0)
radius = depth * radius_step
positions[node_id] = (radius * math.cos(angle), radius * math.sin(angle))
return positions
def _node_status(tree: Tree, node_id: int) -> str:
if node_id == 1:
return "target"
node = tree.nodes[node_id]
if node.is_solved():
return "leaf"
return "intermediate"
def _node_metadata(
    tree: Tree, node_id: int, route_index: Dict[int, int]
) -> Dict[str, object]:
    """Assemble the per-node record that is serialized into the page.

    Reads the following from the tree's bookkeeping dicts:

    * depth (default 0);
    * visit count (default 0);
    * number of children (default 0);
    * rule id and rule label (default None).

    Also records whether the node is solved, the SMILES and SVG of its
    primary molecule (None when absent), and the SMILES of its pending
    (``precursors_to_expand``) and new (``new_precursors``) precursors.
    ``route_id`` is the node's own id when it is a key of ``route_index``,
    otherwise None.
    """
    def smiles_list(precursors) -> list:
        return [str(p) for p in precursors] if precursors else []

    node = tree.nodes[node_id]
    molecule = _node_primary_molecule(node)
    smiles = None if molecule is None else str(molecule)
    is_route = node_id in route_index
    return {
        "node_id": node_id,
        "route_id": node_id if is_route else None,
        "route_index": route_index.get(node_id),
        "depth": tree.nodes_depth.get(node_id, 0),
        "visits": tree.nodes_visit.get(node_id, 0),
        "num_children": len(tree.children.get(node_id, [])),
        "rule_id": tree.nodes_rules.get(node_id),
        "rule_label": tree.nodes_rule_label.get(node_id),
        "solved": bool(node.is_solved()),
        "smiles": smiles,
        "svg": _depict_molecule_svg(molecule),
        "pending_smiles": smiles_list(node.precursors_to_expand),
        "new_smiles": smiles_list(node.new_precursors),
    }
def _edges_from_tree(
tree: Tree, allowed_nodes: Optional[Set[int]] = None
) -> Iterable[Tuple[int, int]]:
allowed = set(allowed_nodes) if allowed_nodes is not None else None
for child_id, parent_id in tree.parents.items():
if child_id == 1 or not parent_id:
continue
if allowed is not None and (child_id not in allowed or parent_id not in allowed):
continue
yield parent_id, child_id
def _winning_route_edges(tree: Tree) -> set:
edges = set()
for node_id in tree.winning_nodes:
current = node_id
while current and current in tree.parents:
parent = tree.parents.get(current)
if not parent:
break
edges.add((parent, current))
current = parent
return edges
def _scale_positions(
positions: Dict[int, Tuple[float, float]],
node_radius: float,
render_scale: float,
) -> Tuple[Dict[int, Tuple[float, float]], float]:
render_scale = max(render_scale, 0.01)
scaled_positions = {
node_id: (pos[0] * render_scale, pos[1] * render_scale)
for node_id, pos in positions.items()
}
return scaled_positions, node_radius * render_scale
def _render_svg(
    tree: Tree,
    positions: Dict[int, Tuple[float, float]],
    edges: Iterable[Tuple[int, int]],
    winning_edges: Set[Tuple[int, int]],
    node_radius: float,
    nodes_depth: Dict[int, int],
    radius_step: Optional[float] = None,
    pad_scale: float = 4.0,
) -> str:
    """Render a positioned tree as an SVG fragment.

    The fragment is assembled from three parts:

    * an optional depth ring per level when ``radius_step`` is given (ring
      spacing in the same units as ``positions``);
    * one line per edge whose endpoints both have positions, tagged
      ``edge-winning`` when the edge is in ``winning_edges``;
    * one circle per node, styled by ``_node_status``, with node 1 drawn at
      double radius and an animation delay of ``0.03`` per depth level
      capped at ``0.6`` (units presumably seconds — TODO confirm).

    The view box is the bounding box of ``positions`` padded by
    ``2 * node_radius * pad_scale``.

    Returns:
        The SVG markup, or ``""`` when ``positions`` is empty.

    NOTE(review): every markup template in this copy of the block is an
    empty f-string, so ``view_w``/``view_h``, the edge coordinates and
    classes, and the node status/delay/radius are computed but never
    emitted. Confirm that the templates were not lost.
    """
    if not positions:
        return ""
    xs = [pos[0] for pos in positions.values()]
    ys = [pos[1] for pos in positions.values()]
    # The root is drawn at twice the node radius, so pad for the largest circle.
    max_radius = node_radius * 2.0
    pad = max_radius * pad_scale
    min_x, max_x = min(xs) - pad, max(xs) + pad
    min_y, max_y = min(ys) - pad, max(ys) + pad
    view_w = max_x - min_x
    view_h = max_y - min_y
    line_items = []
    for parent_id, child_id in edges:
        # Edges into nodes that are not part of this layout are not drawn.
        if parent_id not in positions or child_id not in positions:
            continue
        x1, y1 = positions[parent_id]
        x2, y2 = positions[child_id]
        line_class = "edge-winning" if (parent_id, child_id) in winning_edges else ""
        line_items.append(
            f''
        )
    circle_items = []
    for node_id, (x, y) in positions.items():
        status = _node_status(tree, node_id)
        # Deeper nodes appear later; the cap keeps big trees from lagging.
        delay = min(nodes_depth.get(node_id, 0) * 0.03, 0.6)
        radius = node_radius * 2.0 if node_id == 1 else node_radius
        parent_id = tree.parents.get(node_id, 0)
        circle_items.append(
            f''
        )
    depth_circles = []
    if radius_step is not None:
        max_depth = max(nodes_depth.values()) if nodes_depth else 0
        for depth in range(1, max_depth + 1):
            depth_circles.append(
                f''
            )
    return f"""
"""
def _format_point(value: float) -> str:
return f"{value:.2f}"
def generate_tree_html(
    tree: Tree,
    output_path: Path,
    radius_step: float = 280.0,
    node_radius: float = 80.0,
    render_scale: float = 0.25,
    clusters_pkl: Optional[Path] = None,
) -> None:
    """Render ``tree`` into a standalone HTML file at ``output_path``.

    Two radial SVG layouts are produced:

    * the full tree;
    * a "solved" view restricted to nodes on winning-route edges, or every
      node when there are no winning edges.

    Each layout applies its own multipliers to ``radius_step``,
    ``node_radius`` and ``render_scale``. Per-node metadata and the target
    molecule drawing are also gathered. When ``clusters_pkl`` is given, the
    function additionally gathers:

    * cluster payloads;
    * per-route node lists;
    * side-product extras;
    * route SVGs, for winning routes only.

    Args:
        tree: the search tree to draw.
        output_path: destination HTML file, written as UTF-8.
        radius_step: base distance between depth rings.
        node_radius: base node radius.
        render_scale: base scale applied to the final coordinates.
        clusters_pkl: optional clustering pickle (see ``_load_clusters``).

    Raises:
        ValueError: the full layout has no positions.
        Any error raised by ``_load_clusters`` when ``clusters_pkl`` is set.

    NOTE(review): in this copy the HTML template only interpolates
    ``svg_full``, ``svg_solved``, the winning-route count and ``target_svg``.
    The values ``node_meta``, ``target_smiles``, ``clusters_payload``,
    ``route_nodes``, ``route_extras`` and ``route_svgs`` are built but never
    used. The conditional near the end of the template also contains a
    double-quoted string literal spanning several lines, which Python
    rejects. Confirm the template was not truncated.
    """
    # --- Full-tree layout ---------------------------------------------------
    full_children = _build_children_map(tree)
    full_nodes_depth = _compute_depths(full_children)
    full_radius_step = radius_step * 0.4
    full_node_radius = node_radius * 10.0
    full_render_scale = min(render_scale * 1.0, 1.0)
    full_positions = _radial_layout(
        full_nodes_depth,
        full_children,
        radius_step=full_radius_step,
        node_radius=full_node_radius,
        spacing_factor=3.4,
        root_gap_factor=4.0,
    )
    # _radial_layout does not return the ring scale it applied, so it is
    # recomputed here (same inputs) to size the depth rings in _render_svg.
    full_angles = _assign_subtree_angles(
        full_children, _compute_subtree_leaf_counts(full_children)
    )
    full_layout_scale = _compute_radius_scale(
        _group_nodes_by_depth(full_nodes_depth),
        full_angles,
        full_radius_step,
        full_node_radius,
        spacing_factor=3.4,
        root_gap_factor=4.0,
    )
    full_edges = list(_edges_from_tree(tree))
    winning_edges = _winning_route_edges(tree)
    if not full_positions:
        raise ValueError("Tree has no nodes to render.")
    # --- Solved-subtree layout: nodes touched by any winning-route edge ----
    if winning_edges:
        solved_nodes: Set[int] = {1}
        for parent_id, child_id in winning_edges:
            solved_nodes.add(parent_id)
            solved_nodes.add(child_id)
    else:
        solved_nodes = set(tree.nodes.keys())
    solved_children = _build_children_map(tree, allowed_nodes=solved_nodes)
    solved_nodes_depth = _compute_depths(solved_children)
    solved_radius_step = radius_step * 0.5
    solved_node_radius = node_radius * 1.15
    solved_render_scale = min(render_scale * 0.4, 1.0)
    solved_positions = _radial_layout(
        solved_nodes_depth,
        solved_children,
        radius_step=solved_radius_step,
        node_radius=solved_node_radius,
    )
    # Recomputed for the same reason as full_layout_scale above.
    solved_angles = _assign_subtree_angles(
        solved_children, _compute_subtree_leaf_counts(solved_children)
    )
    solved_layout_scale = _compute_radius_scale(
        _group_nodes_by_depth(solved_nodes_depth),
        solved_angles,
        solved_radius_step,
        solved_node_radius,
    )
    solved_edges = list(_edges_from_tree(tree, allowed_nodes=solved_nodes))
    # --- Scale both layouts to render units and draw them ------------------
    full_positions, full_render_radius = _scale_positions(
        full_positions, full_node_radius, full_render_scale
    )
    solved_positions, solved_render_radius = _scale_positions(
        solved_positions, solved_node_radius, solved_render_scale
    )
    # Rendered ring spacing = base step * layout scale * render scale.
    full_render_step = full_radius_step * full_layout_scale * full_render_scale
    solved_render_step = solved_radius_step * solved_layout_scale * solved_render_scale
    svg_full = _render_svg(
        tree,
        full_positions,
        full_edges,
        winning_edges,
        full_render_radius,
        full_nodes_depth,
        radius_step=full_render_step,
        pad_scale=1.8,
    )
    svg_solved = _render_svg(
        tree,
        solved_positions,
        solved_edges,
        winning_edges,
        solved_render_radius,
        solved_nodes_depth,
        radius_step=solved_render_step,
        pad_scale=1.6,
    )
    # Fall back to the full drawing when the solved view is empty.
    if not svg_solved:
        svg_solved = svg_full
        solved_positions = full_positions
    # --- Per-node metadata and target depiction ---------------------------
    node_meta = {}
    # Position of each winning node in tree.winning_nodes.
    route_index = {node_id: idx for idx, node_id in enumerate(tree.winning_nodes)}
    for node_id in tree.nodes:
        node_meta[str(node_id)] = _node_metadata(tree, node_id, route_index)
    target_svg = _build_target_svg(tree)
    target_smiles = str(tree.nodes[1].curr_precursor) if tree.nodes.get(1) else ""
    # --- Optional cluster data --------------------------------------------
    clusters_payload: List[Dict[str, object]] = []
    route_nodes: Dict[str, List[int]] = {}
    route_extras: Dict[str, Dict[str, object]] = {}
    route_svgs: Dict[str, str] = {}
    if clusters_pkl:
        clusters = _load_clusters(clusters_pkl)
        clusters_payload = _build_cluster_payload(clusters)
        cluster_route_ids = {
            route_id
            for cluster in clusters_payload
            for route_id in cluster.get("route_ids", [])
        }
        route_nodes = _route_nodes_by_route(tree, cluster_route_ids)
        route_extras = _route_extras_by_route(tree, cluster_route_ids)
        for route_id in sorted(cluster_route_ids):
            # Route SVGs are only requested for winning (solved) routes.
            if route_id not in tree.winning_nodes:
                continue
            svg = get_route_svg(tree, route_id)
            if svg:
                route_svgs[str(route_id)] = svg
    # --- Page assembly (template lines are kept at column 0 on purpose) ---
    html = f"""
SynPlanner Tree Visualization
{svg_full}
{svg_solved}
Legend
Target molecule
Intermediate product
Building block
Solved pathway
Unsolved pathway
Route ID: Total {len(tree.winning_nodes)} Current 0
Route
Target molecule
{target_svg if target_svg else "
Target depiction unavailable.
"}
Selected bonds (CGRtools atom ids): []
Clusters: not loaded
"""
    output_path = Path(output_path)
    output_path.write_text(html, encoding="utf-8")
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: load a tree pickle and write the HTML page.

    Args:
        argv: argument list to parse. When None, ``argparse`` reads
            ``sys.argv[1:]``. Passing a list lets the CLI be driven from code
            and tests.

    Raises:
        SystemExit: on invalid arguments (raised by ``argparse``), including
            a non-positive ``--radius-step`` or ``--node-radius``.
        TypeError: the pickle does not contain a ``Tree``.
    """
    parser = argparse.ArgumentParser(
        description="Generate a simple HTML visualization for a SynPlanner MCTS tree."
    )
    parser.add_argument(
        "--tree-pkl",
        type=Path,
        required=True,
        help="Path to a pickled Tree or TreeWrapper object.",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=Path("tree_visualization.html"),
        help="Output HTML file path.",
    )
    parser.add_argument(
        "--radius-step",
        type=float,
        default=280.0,
        help="Radial distance between depth rings.",
    )
    parser.add_argument(
        "--node-radius",
        type=float,
        default=80.0,
        help="Node radius in SVG units (pixels when not fit-to-screen).",
    )
    parser.add_argument(
        "--render-scale",
        type=float,
        default=0.25,
        help="Scale factor applied to the final render (1.0 = full size).",
    )
    parser.add_argument(
        "--clusters-pkl",
        type=Path,
        default=None,
        help="Optional path to a clusters pickle (cluster_routes output).",
    )
    args = parser.parse_args(argv)
    # Reject degenerate geometry up front: with a zero step every ring lands
    # on the root, and a negative radius cannot be drawn.
    if args.radius_step <= 0:
        parser.error("--radius-step must be positive.")
    if args.node_radius <= 0:
        parser.error("--node-radius must be positive.")
    tree = _load_tree(args.tree_pkl)
    if not isinstance(tree, Tree):
        raise TypeError("Loaded object is not a Tree.")
    generate_tree_html(
        tree,
        output_path=args.out,
        radius_step=args.radius_step,
        node_radius=args.node_radius,
        render_scale=args.render_scale,
        clusters_pkl=args.clusters_pkl,
    )
    print(f"Tree visualization written to {args.out}")


if __name__ == "__main__":
    main()