# forma-3d-review-api / src/loader/assembly_tree.py
# Synced from forma-3d-review@b6d4687f5d0f2e5303758c97095ea7e38e740723
# (commit 182efca by lomit, verified)
"""Assembly tree extraction from XDE document or shape topology."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from OCP.Bnd import Bnd_Box
from OCP.BRepBndLib import BRepBndLib
from OCP.TDF import TDF_Label, TDF_LabelSequence
from OCP.TDataStd import TDataStd_Name
from OCP.TopAbs import TopAbs_COMPOUND, TopAbs_SOLID
from OCP.TopExp import TopExp_Explorer
from OCP.TopoDS import TopoDS, TopoDS_Shape
from OCP.XCAFDoc import XCAFDoc_ShapeTool
logger = logging.getLogger(__name__)
@dataclass
class AssemblyNode:
    """One node of an assembly hierarchy.

    A node is either an assembly (``is_assembly``; owns ``children``) or a
    leaf part (``is_leaf``). ``shape`` holds the OCC geometry when known and
    is deliberately left out of ``to_dict`` so the result stays
    JSON-serializable.
    """

    id: str  # hierarchical ID, e.g. "0_2_1" (parent ID + "_" + sibling index)
    name: str  # display name taken from XDE labels / STEP products / fallback
    is_assembly: bool = False
    is_leaf: bool = False
    shape: TopoDS_Shape | None = None  # OCC geometry; never serialized
    children: list[AssemblyNode] = field(default_factory=list)
    classification: str = "unknown"  # exterior, interior, unknown
    # Geometry summary. num_solids / bounding_box are filled while the tree is
    # built; num_faces is presumably filled later (e.g. after tessellation).
    num_faces: int = 0
    num_solids: int = 0
    bounding_box: dict | None = None

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict describing this node and its subtree.

        ``bounding_box`` and ``children`` are emitted only when non-empty;
        ``shape`` is never emitted.
        """
        payload = {
            attr: getattr(self, attr)
            for attr in (
                "id",
                "name",
                "is_assembly",
                "is_leaf",
                "classification",
                "num_faces",
                "num_solids",
            )
        }
        if self.bounding_box:
            payload["bounding_box"] = self.bounding_box
        if self.children:
            payload["children"] = [child.to_dict() for child in self.children]
        return payload

    def iter_leaves(self):
        """Yield every node flagged ``is_leaf`` in depth-first pre-order."""
        pending = [self]
        while pending:
            node = pending.pop()
            if node.is_leaf:
                yield node
            # Reverse so the first child is popped (visited) first.
            pending.extend(reversed(node.children))

    def find_by_id(self, node_id: str) -> AssemblyNode | None:
        """Return the first node (pre-order) whose ``id`` equals *node_id*, else None."""
        pending = [self]
        while pending:
            node = pending.pop()
            if node.id == node_id:
                return node
            pending.extend(reversed(node.children))
        return None
def _get_label_name(label: TDF_Label) -> str:
    """Return the TDataStd_Name stored on *label*, or ``"unnamed"`` if absent."""
    attr = TDataStd_Name()
    has_name = label.FindAttribute(TDataStd_Name.GetID_s(), attr)
    return attr.Get().ToExtString() if has_name else "unnamed"
def _count_solids(shape: TopoDS_Shape) -> int:
    """Count TopAbs_SOLID sub-shapes of *shape*.

    Walks the shape with a TopExp_Explorer and counts one per visited item.
    """
    explorer = TopExp_Explorer(shape, TopAbs_SOLID)
    total = 0
    while explorer.More():
        explorer.Next()
        total += 1
    return total
def _compute_bounding_box(shape: TopoDS_Shape) -> dict | None:
    """Compute the axis-aligned bounding box of *shape*.

    Args:
        shape: Shape to measure.

    Returns:
        A dict with ``"min"``, ``"max"``, ``"size"`` and ``"center"``, each an
        ``[x, y, z]`` list, or None when the box is void or OCC raises.
        The bounding box is optional metadata on AssemblyNode, so failures
        are logged at debug level instead of being propagated.
    """
    try:
        bbox = Bnd_Box()
        BRepBndLib.Add_s(shape, bbox)
        if bbox.IsVoid():
            return None
        xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
    except Exception:
        # Best-effort, but leave a trace instead of failing silently.
        logger.debug("Bounding box computation failed", exc_info=True)
        return None
    return {
        "min": [xmin, ymin, zmin],
        "max": [xmax, ymax, zmax],
        "size": [xmax - xmin, ymax - ymin, zmax - zmin],
        "center": [(xmin + xmax) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2],
    }
def _build_tree(
    shape_tool: XCAFDoc_ShapeTool,
    label: TDF_Label,
    parent_id: str = "",
    index: int = 0,
    location=None,
) -> AssemblyNode:
    """Recursively build an assembly tree from an XDE shape label.

    Args:
        shape_tool: ShapeTool of the document that owns *label*.
        label: Shape label to convert (a free shape or a referred prototype).
        parent_id: ID of the parent node; empty for the root.
        index: 1-based position among siblings (0 for the root).
        location: Accumulated ``TopLoc_Location`` of all enclosing component
            instances, or None at the top level. It is applied to this
            label's shape so geometry (and bounding boxes) end up in assembly
            coordinates rather than in the prototype's local frame.

    Returns:
        The AssemblyNode for *label* with all descendants attached.
    """
    node_id = f"{parent_id}_{index}" if parent_id else str(index)
    # OCP exposes OCCT *static* methods with an ``_s`` suffix (as with
    # BRepBndLib.Add_s / TopoDS.Solid_s elsewhere in this module).
    is_assembly = shape_tool.IsAssembly_s(label)
    shape = shape_tool.GetShape_s(label)
    if shape.IsNull():
        shape = None
    elif location is not None:
        # A prototype is stored once in local coordinates; place this
        # particular occurrence where its enclosing instances put it.
        shape = shape.Moved(location)

    node = AssemblyNode(
        id=node_id,
        name=_get_label_name(label),
        is_assembly=is_assembly,
        shape=shape,
    )

    if is_assembly:
        components = TDF_LabelSequence()
        shape_tool.GetComponents_s(label, components)
        for i in range(1, components.Length() + 1):
            component = components.Value(i)
            referred = TDF_Label()
            if shape_tool.GetReferredShape_s(component, referred):
                # Compose parent placement with this instance's own placement.
                comp_loc = shape_tool.GetLocation_s(component)
                child_loc = comp_loc if location is None else location.Multiplied(comp_loc)
                child = _build_tree(shape_tool, referred, node_id, i, child_loc)
            else:
                # Not a reference (unusual): build from the component label
                # itself, carrying only the parent's accumulated placement.
                child = _build_tree(shape_tool, component, node_id, i, location)
            node.children.append(child)
    else:
        node.is_leaf = True
        if shape is not None:
            node.num_solids = _count_solids(shape)
            node.bounding_box = _compute_bounding_box(shape)
    return node
def extract_assembly_tree(shape_tool: XCAFDoc_ShapeTool) -> AssemblyNode:
    """Build the complete assembly tree from an XDE ShapeTool.

    A single free shape becomes the root directly (ID ``"0"``). Several free
    shapes are wrapped in a synthetic ``"root"`` assembly named "Assembly".

    Args:
        shape_tool: XCAFDoc_ShapeTool from the loaded document.

    Returns:
        Root AssemblyNode with the complete tree.

    Raises:
        RuntimeError: If the document has no free shapes.
    """
    free_labels = TDF_LabelSequence()
    shape_tool.GetFreeShapes(free_labels)
    count = free_labels.Length()
    if count == 0:
        raise RuntimeError("No free shapes found in XDE document")

    if count > 1:
        root = AssemblyNode(id="root", name="Assembly", is_assembly=True)
        root.children.extend(
            _build_tree(shape_tool, free_labels.Value(pos), "root", pos)
            for pos in range(1, count + 1)
        )
    else:
        root = _build_tree(shape_tool, free_labels.Value(1), "", 0)

    n_leaves = sum(1 for _ in root.iter_leaves())
    logger.info("Assembly tree (XDE): %d leaves, root name='%s'", n_leaves, root.name)
    return root
def _extract_product_names(reader) -> list[str]:
    """Collect candidate part names from the STEP model's PRODUCT entities.

    Walks every entity of ``reader.WS().Model()`` in order and keeps the
    ``Name`` of each ``StepBasic_Product``. It drops names that are empty,
    names that are a bare topology type (``SOLID``, ``SHELL``, ...), and
    names containing the Open CASCADE translator banner.

    Best-effort: any failure (OCP module unavailable, reader without a
    model, ...) is logged at debug level and yields an empty list.

    Args:
        reader: A STEPControl_Reader, presumably after the file has been
            read (TODO confirm with callers).

    Returns:
        Filtered product names in model entity order.
    """
    skip = {"SOLID", "SHELL", "COMPOUND", "COMPSOLID", "WIRE", "EDGE", "VERTEX", "FACE"}
    # Substring test: the banner usually carries a version suffix.
    translator_marker = "Open CASCADE STEP translator"
    try:
        from OCP.StepBasic import StepBasic_Product

        model = reader.WS().Model()
        names: list[str] = []
        for i in range(1, model.NbEntities() + 1):
            ent = model.Entity(i)
            if not isinstance(ent, StepBasic_Product):
                continue
            raw = ent.Name()
            name = raw.ToCString() if raw else None
            if name and name not in skip and translator_marker not in name:
                names.append(name)
        return names
    except Exception:
        logger.debug("Could not extract STEP product names", exc_info=True)
        return []
def _extract_xde_label_names(doc) -> list[str]:
    """Collect shape names from the direct children of the XDE "Shapes" label.

    The Shapes label is taken to be the first child of ``doc.Main()``, which
    is assumed to hold as in a standard XCAF document. Only its direct
    sub-labels are visited. Names containing the Open CASCADE translator
    banner are skipped.

    Best-effort: on any failure the error is logged at debug level and an
    empty list is returned.

    Args:
        doc: A TDocStd_Document loaded through XDE.

    Returns:
        Names in label-iteration order.
    """
    try:
        from OCP.TDF import TDF_ChildIterator

        main_children = TDF_ChildIterator(doc.Main(), False)
        if not main_children.More():
            return []
        shapes_label = main_children.Value()

        names: list[str] = []
        it = TDF_ChildIterator(shapes_label, False)
        while it.More():
            name_attr = TDataStd_Name()
            if it.Value().FindAttribute(TDataStd_Name.GetID_s(), name_attr):
                raw = name_attr.Get().ToExtString()
                # Skip translator meta-names.
                if "Open CASCADE STEP translator" not in raw:
                    names.append(raw)
            it.Next()
        return names
    except Exception:
        logger.debug("Could not extract XDE label names", exc_info=True)
        return []
def _load_names_json(step_path) -> list[str]:
    """Load positional part names from a ``.names.json`` sidecar file.

    Two locations are tried in order:

    1. ``<step_path>.names.json`` (suffix appended, e.g. ``part.step.names.json``)
    2. *step_path* with its extension replaced (e.g. ``part.names.json``)

    The sidecar must contain a JSON array. Every entry is converted with
    ``str()`` so the result is always ``list[str]`` while positions stay
    aligned with the solids they name.

    Args:
        step_path: Path (str or os.PathLike) of the STEP file, or None.

    Returns:
        The names, or ``[]`` in any of these cases:

        - *step_path* is None.
        - No sidecar file exists.
        - The sidecar cannot be read or decoded.
        - The sidecar is not a JSON array.

        The last two cases are logged as warnings.
    """
    import json
    from pathlib import Path

    if step_path is None:
        return []

    names_path = Path(str(step_path) + ".names.json")
    if not names_path.is_file():
        # Also try replacing the extension.
        names_path = Path(step_path).with_suffix(".names.json")
        if not names_path.is_file():
            return []

    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(names_path, encoding="utf-8") as f:
            names = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
        logger.warning("Could not read part-name sidecar %s", names_path, exc_info=True)
        return []

    if not isinstance(names, list):
        logger.warning(
            "Ignoring part-name sidecar %s: expected a JSON array, got %s",
            names_path,
            type(names).__name__,
        )
        return []
    return [str(n) for n in names]
def _resolve_part_names(reader, doc) -> list[str]:
    """Pick positional part names: STEP products first, then XDE shape labels; [] if neither yields any."""
    if reader is not None:
        product_names = _extract_product_names(reader)
        if product_names:
            logger.info("Found %d part names from STEP products", len(product_names))
            return product_names
    if doc is not None:
        # Bare topology type names carry no part identity.
        skip = {"SOLID", "SHELL", "COMPOUND", "COMPSOLID", "WIRE", "EDGE", "VERTEX", "FACE"}
        xde_names = [n for n in _extract_xde_label_names(doc) if n not in skip]
        if xde_names:
            logger.info("Found %d part names from XDE labels", len(xde_names))
            return xde_names
    return []


def extract_assembly_from_shape(
    shape: TopoDS_Shape,
    name: str = "Assembly",
    reader=None,
    doc=None,
    step_path=None,
    id_prefix: str = "",
) -> AssemblyNode:
    """Fallback: build a flat assembly tree from shape topology.

    Every solid found in *shape* becomes one leaf under a single root. Leaf
    names come, in priority order, from STEP PRODUCT entities (via *reader*),
    then from XDE shape labels (via *doc*), and otherwise default to
    ``Part_<n>``. Names are matched to solids purely by position. A warning
    is logged when the counts differ, because the mapping is then likely off.

    If *shape* contains no solid, the root itself is turned into a leaf with
    the bounding box of the whole shape.

    Args:
        shape: The TopoDS_Shape to decompose.
        name: Root node name.
        reader: Optional STEPControl_Reader for product name extraction.
        doc: Optional TDocStd_Document for XDE label name extraction.
        step_path: Path to the STEP file. Accepted for API compatibility; not
            read by this function.
        id_prefix: Prefix for node IDs to avoid collisions between files.

    Returns:
        Root AssemblyNode with solids as leaves.
    """
    part_names = _resolve_part_names(reader, doc)

    root_id = f"{id_prefix}0" if id_prefix else "0"
    root = AssemblyNode(id=root_id, name=name, is_assembly=True, shape=shape)

    solid_exp = TopExp_Explorer(shape, TopAbs_SOLID)
    while solid_exp.More():
        solid = TopoDS.Solid_s(solid_exp.Current())
        idx = len(root.children)
        solid_name = part_names[idx] if idx < len(part_names) else f"Part_{idx + 1}"
        root.children.append(
            AssemblyNode(
                id=f"{root_id}_{idx + 1}",
                name=solid_name,
                is_leaf=True,
                shape=solid,
                num_solids=1,
                bounding_box=_compute_bounding_box(solid),
            )
        )
        solid_exp.Next()

    leaf_count = len(root.children)
    if leaf_count and part_names and len(part_names) != leaf_count:
        # Positional mapping: a count mismatch means parts may be mislabelled.
        logger.warning(
            "Part-name count (%d) differs from solid count (%d); "
            "positional name mapping may be wrong",
            len(part_names),
            leaf_count,
        )

    if leaf_count == 0:
        root.is_leaf = True
        root.is_assembly = False
        root.num_solids = 0
        root.bounding_box = _compute_bounding_box(shape)
    else:
        logger.info("Assembly tree (shape): %d solids from topology", leaf_count)
    return root