| """Post-processing functions for segment predictions.""" | |
| import numpy as np | |
def snap_to_point_cloud(vertices, xyz, class_id, snap_radius=0.5,
                        target_classes=None, min_points=2):
    """Snap vertices to nearby point-cloud points of selected semantic classes.

    A vertex is moved to the centroid of the target-class points lying
    strictly within ``snap_radius`` (Euclidean distance) of it. This only
    happens when at least ``min_points`` such points exist; otherwise the
    vertex is returned unchanged.

    Args:
        vertices: Array-like of vertex positions, shape (V, D).
        xyz: Array-like of point positions, shape (P, D).
        class_id: Per-point class labels aligned with the rows of ``xyz``.
        snap_radius: Neighbourhood radius. Points at exactly this distance
            are excluded.
        target_classes: Class labels eligible for snapping. Defaults to
            ``[1, 2]`` (apex, eave_end_point).
        min_points: Minimum number of neighbouring target points required
            before a vertex is snapped. Must be >= 1. Defaults to 2.

    Returns:
        A new floating-point array of snapped vertices. The input is not
        modified.

    Raises:
        ValueError: If ``min_points`` is less than 1.
    """
    if min_points < 1:
        # With 0 a vertex could be "snapped" to the mean of no points (NaN).
        raise ValueError(f"min_points must be >= 1, got {min_points}")
    if target_classes is None:
        target_classes = [1, 2]  # apex, eave_end_point
    original = np.asarray(vertices)
    snapped = np.array(original)  # np.array always copies
    # Writing a centroid into an integer array would silently truncate it.
    if not np.issubdtype(snapped.dtype, np.floating):
        snapped = snapped.astype(float)
    mask = np.isin(class_id, target_classes)
    if mask.sum() < min_points:
        # No vertex can gather enough neighbours; skip the per-vertex loop.
        return snapped
    target_pts = np.asarray(xyz)[mask]
    # One vertex at a time keeps memory at O(P) rather than building a full
    # (V, P) distance matrix for large point clouds.
    for i, v in enumerate(original):
        dists = np.linalg.norm(target_pts - v, axis=-1)
        close = dists < snap_radius
        if close.sum() >= min_points:
            snapped[i] = target_pts[close].mean(axis=0)
    return snapped
def snap_horizontal(vertices, edges, max_slope=0.05, min_length=0.1):
    """Snap near-horizontal edges so that they become exactly horizontal.

    Axis 1 (y) is treated as vertical and axes 0 and 2 as the horizontal
    plane. An edge ``(a, b)`` qualifies when both of these hold, evaluated on
    the input positions:

    - its horizontal length exceeds ``min_length``;
    - its slope ``|dy| / horizontal_length`` is below ``max_slope``.

    Qualifying edges that share vertices are merged into connected
    components. Every vertex in a component receives the component's mean y.
    As a result, all qualifying edges end up exactly horizontal regardless of
    edge order. Averaging edge by edge in place would let a later edge re-lift
    a vertex that an earlier edge had already levelled.

    Args:
        vertices: Array-like of vertex positions with at least 3 columns.
        edges: Iterable of ``(a, b)`` vertex-index pairs.
        max_slope: Edges with a slope strictly below this are snapped.
        min_length: Edges whose horizontal length is not strictly greater
            than this are ignored. Defaults to 0.1.

    Returns:
        A new floating-point array in which only the y column can differ
        from the input. The input is not modified.
    """
    verts = np.array(vertices)  # np.array always copies
    if not np.issubdtype(verts.dtype, np.floating):
        verts = verts.astype(float)  # avoid truncating the averaged heights
    n = len(verts)
    if n == 0:
        return verts
    parent = list(range(n))

    def find(i):
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Decide every edge from the unmodified positions, so the result does not
    # depend on edge order.
    for a, b in edges:
        a, b = int(a), int(b)
        dy = abs(verts[a, 1] - verts[b, 1])
        dxz = np.hypot(verts[a, 0] - verts[b, 0], verts[a, 2] - verts[b, 2])
        if dxz > min_length and dy / dxz < max_slope:
            root_a, root_b = find(a), find(b)
            if root_a != root_b:
                parent[root_a] = root_b

    roots = np.array([find(i) for i in range(n)], dtype=np.intp)
    # A vertex touched by no qualifying edge forms a singleton component.
    # Its mean is its own y, so it keeps its original height.
    sums = np.bincount(roots, weights=verts[:, 1], minlength=n)
    counts = np.bincount(roots, minlength=n)
    verts[:, 1] = sums[roots] / counts[roots]
    return verts