"""Assembly tree extraction from XDE document or shape topology.""" from __future__ import annotations import logging from dataclasses import dataclass, field from OCP.Bnd import Bnd_Box from OCP.BRepBndLib import BRepBndLib from OCP.TDF import TDF_Label, TDF_LabelSequence from OCP.TDataStd import TDataStd_Name from OCP.TopAbs import TopAbs_COMPOUND, TopAbs_SOLID from OCP.TopExp import TopExp_Explorer from OCP.TopoDS import TopoDS, TopoDS_Shape from OCP.XCAFDoc import XCAFDoc_ShapeTool logger = logging.getLogger(__name__) @dataclass class AssemblyNode: """Node in the assembly tree.""" id: str name: str is_assembly: bool = False is_leaf: bool = False shape: TopoDS_Shape | None = None children: list[AssemblyNode] = field(default_factory=list) classification: str = "unknown" # exterior, interior, unknown # Geometry info (populated after tessellation) num_faces: int = 0 num_solids: int = 0 bounding_box: dict | None = None def to_dict(self) -> dict: """Convert to serializable dictionary.""" result = { "id": self.id, "name": self.name, "is_assembly": self.is_assembly, "is_leaf": self.is_leaf, "classification": self.classification, "num_faces": self.num_faces, "num_solids": self.num_solids, } if self.bounding_box: result["bounding_box"] = self.bounding_box if self.children: result["children"] = [c.to_dict() for c in self.children] return result def iter_leaves(self): """Iterate over all leaf nodes.""" if self.is_leaf: yield self for child in self.children: yield from child.iter_leaves() def find_by_id(self, node_id: str) -> AssemblyNode | None: """Find a node by ID.""" if self.id == node_id: return self for child in self.children: found = child.find_by_id(node_id) if found: return found return None def _get_label_name(label: TDF_Label) -> str: """Extract name from a TDF_Label.""" name_attr = TDataStd_Name() if label.FindAttribute(TDataStd_Name.GetID_s(), name_attr): return name_attr.Get().ToExtString() return "unnamed" def _count_solids(shape: TopoDS_Shape) -> int: """Count solid entities in a shape.""" count = 0 exp = TopExp_Explorer(shape, TopAbs_SOLID) while exp.More(): count += 1 exp.Next() return count def _compute_bounding_box(shape: TopoDS_Shape) -> dict | None: """Compute the bounding box of a shape.""" try: bbox = Bnd_Box() BRepBndLib.Add_s(shape, bbox) if bbox.IsVoid(): return None xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() return { "min": [xmin, ymin, zmin], "max": [xmax, ymax, zmax], "size": [xmax - xmin, ymax - ymin, zmax - zmin], "center": [(xmin + xmax) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2], } except Exception: return None def _build_tree( shape_tool: XCAFDoc_ShapeTool, label: TDF_Label, parent_id: str = "", index: int = 0, ) -> AssemblyNode: """Recursively build assembly tree from XDE label.""" node_id = f"{parent_id}_{index}" if parent_id else str(index) name = _get_label_name(label) is_assembly = shape_tool.IsAssembly(label) shape = shape_tool.GetShape(label) node = AssemblyNode( id=node_id, name=name, is_assembly=is_assembly, shape=shape if not shape.IsNull() else None, ) if is_assembly: sub_labels = TDF_LabelSequence() shape_tool.GetComponents(label, sub_labels) for i in range(1, sub_labels.Length() + 1): sub_label = sub_labels.Value(i) ref_label = TDF_Label() if shape_tool.GetReferredShape(sub_label, ref_label): child = _build_tree(shape_tool, ref_label, node_id, i) else: child = _build_tree(shape_tool, sub_label, node_id, i) node.children.append(child) else: node.is_leaf = True if node.shape is not None and not node.shape.IsNull(): node.num_solids = _count_solids(node.shape) node.bounding_box = _compute_bounding_box(node.shape) return node def extract_assembly_tree(shape_tool: XCAFDoc_ShapeTool) -> AssemblyNode: """Extract the full assembly tree from an XDE ShapeTool. Args: shape_tool: XCAFDoc_ShapeTool from the loaded document. Returns: Root AssemblyNode with the complete tree. """ labels = TDF_LabelSequence() shape_tool.GetFreeShapes(labels) if labels.Length() == 0: raise RuntimeError("No free shapes found in XDE document") if labels.Length() == 1: root = _build_tree(shape_tool, labels.Value(1), "", 0) else: root = AssemblyNode(id="root", name="Assembly", is_assembly=True) for i in range(1, labels.Length() + 1): child = _build_tree(shape_tool, labels.Value(i), "root", i) root.children.append(child) leaf_count = sum(1 for _ in root.iter_leaves()) logger.info("Assembly tree (XDE): %d leaves, root name='%s'", leaf_count, root.name) return root def _extract_product_names(reader) -> list[str]: """Extract product names from STEP model entities. Filters out generic translator names and bare shape type names. """ try: from OCP.StepBasic import StepBasic_Product ws = reader.WS() model = ws.Model() skip = {"Open CASCADE STEP translator", "SOLID", "SHELL", "COMPOUND", "COMPSOLID", "WIRE", "EDGE", "VERTEX", "FACE"} names = [] for i in range(1, model.NbEntities() + 1): ent = model.Entity(i) if isinstance(ent, StepBasic_Product): name = ent.Name().ToCString() if ent.Name() else None if name and name not in skip and not any(s in name for s in {"Open CASCADE STEP translator"}): names.append(name) return names except Exception: return [] def _extract_xde_label_names(doc) -> list[str]: """Extract shape names from XDE document's Shapes sublabel tree.""" try: from OCP.TDF import TDF_ChildIterator names = [] # First child of Main is the "Shapes" label it = TDF_ChildIterator(doc.Main(), False) if not it.More(): return [] shapes_label = it.Value() # Iterate shape sublabels it2 = TDF_ChildIterator(shapes_label, False) while it2.More(): label = it2.Value() name_attr = TDataStd_Name() if label.FindAttribute(TDataStd_Name.GetID_s(), name_attr): raw = name_attr.Get().ToExtString() # Skip translator meta-names if "Open CASCADE STEP translator" not in raw: names.append(raw) it2.Next() return names except Exception: return [] def _load_names_json(step_path) -> list[str]: """Load part names from a .names.json sidecar file.""" import json from pathlib import Path names_path = Path(str(step_path) + ".names.json") if not names_path.exists(): # Also try replacing extension names_path = Path(step_path).with_suffix(".names.json") if not names_path.exists(): return [] try: with open(names_path) as f: names = json.load(f) if isinstance(names, list): return names except Exception: pass return [] def extract_assembly_from_shape( shape: TopoDS_Shape, name: str = "Assembly", reader=None, doc=None, step_path=None, id_prefix: str = "", ) -> AssemblyNode: """Fallback: extract assembly tree from shape topology. Enumerates solids from the shape compound to create leaf nodes. Tries to map names from: XDE labels, or STEP products. Args: shape: The TopoDS_Shape to decompose. name: Root node name. reader: Optional STEPControl_Reader for product name extraction. doc: Optional TDocStd_Document for XDE label name extraction. step_path: Path to the STEP file. id_prefix: Prefix for node IDs to avoid collisions between files. Returns: Root AssemblyNode with solids as leaves. """ part_names: list[str] = [] # Source 1: STEP product entities (most reliable for named parts) if reader is not None: product_names = _extract_product_names(reader) if product_names: part_names = product_names logger.info("Found %d part names from STEP products", len(part_names)) # Source 2: XDE document labels (fallback) if not part_names and doc is not None: xde_names = _extract_xde_label_names(doc) # Filter out generic shape type names skip = {"SOLID", "SHELL", "COMPOUND", "COMPSOLID", "WIRE", "EDGE", "VERTEX", "FACE"} xde_names = [n for n in xde_names if n not in skip] if xde_names: part_names = xde_names logger.info("Found %d part names from XDE labels", len(part_names)) root_id = f"{id_prefix}0" if id_prefix else "0" root = AssemblyNode(id=root_id, name=name, is_assembly=True, shape=shape) solid_exp = TopExp_Explorer(shape, TopAbs_SOLID) idx = 0 while solid_exp.More(): solid = TopoDS.Solid_s(solid_exp.Current()) if idx < len(part_names): solid_name = part_names[idx] else: solid_name = f"Part_{idx + 1}" node = AssemblyNode( id=f"{root_id}_{idx + 1}", name=solid_name, is_leaf=True, shape=solid, num_solids=1, bounding_box=_compute_bounding_box(solid), ) root.children.append(node) idx += 1 solid_exp.Next() leaf_count = len(root.children) if leaf_count == 0: root.is_leaf = True root.is_assembly = False root.num_solids = 0 root.bounding_box = _compute_bounding_box(shape) else: logger.info("Assembly tree (shape): %d solids from topology", leaf_count) return root