"""
3D Voxel Shape Classifier — Complete Geometric Primitive Vocabulary

5×5×5 binary voxel grid → rigid cascade → curvature analysis → classify

38 shape classes covering:
  - Rigid 0D-3D: points, lines, joints, triangles, quads, polyhedra, prisms
  - Curved 1D: arcs, helices
  - Curved 2D: circles, ellipses, discs
  - Curved 3D solid: sphere, hemisphere, cylinder, cone, capsule, torus
  - Curved 3D hollow: shell, tube
  - Curved 3D open: bowl (concave), saddle (hyperbolic)

Curvature types: none, convex, concave, cylindrical, conical, toroidal,
hyperbolic, helical
"""

import math
from itertools import combinations
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# === SwiGLU Activation =======================================================

class SwiGLU(nn.Module):
    """
    SwiGLU activation: out = (x @ W1 + b1) * SiLU(x @ W2 + b2)

    SiLU(x) = x * sigmoid(x), aka Swish — the "Swi" in SwiGLU.
    Unlike plain sigmoid gating, SiLU preserves gradient magnitude
    through the gate branch while maintaining sharp gating behavior.
    Used at geometric decision points where crisp on/off transitions
    matter more than smooth interpolation.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.w1 = nn.Linear(in_dim, out_dim)  # value branch
        self.w2 = nn.Linear(in_dim, out_dim)  # gate branch (passed through SiLU)

    def forward(self, x):
        return self.w1(x) * F.silu(self.w2(x))


# === Shape Catalog ===========================================================

SHAPE_CATALOG = {
    # ---- Rigid 0D ----
    "point": {"dim": 0, "curved": False, "curvature": "none"},
    # ---- Rigid 1D: lines ----
    "line_x": {"dim": 1, "curved": False, "curvature": "none"},
    "line_y": {"dim": 1, "curved": False, "curvature": "none"},
    "line_z": {"dim": 1, "curved": False, "curvature": "none"},
    "line_diag": {"dim": 1, "curved": False, "curvature": "none"},
    # ---- Rigid 1D: compounds ----
    "cross": {"dim": 1, "curved": False, "curvature": "none"},
    "l_shape": {"dim": 1, "curved": False, "curvature": "none"},
    "collinear": {"dim": 1, "curved": False, "curvature": "none"},
    # ---- Rigid 2D: triangles ----
    "triangle_xy": {"dim": 2, "curved": False, "curvature": "none"},
    "triangle_xz": {"dim": 2, "curved": False, "curvature": "none"},
    "triangle_3d": {"dim": 2, "curved": False, "curvature": "none"},
    # ---- Rigid 2D: quads ----
    "square_xy": {"dim": 2, "curved": False, "curvature": "none"},
    "square_xz": {"dim": 2, "curved": False, "curvature": "none"},
    "rectangle": {"dim": 2, "curved": False, "curvature": "none"},
    "coplanar": {"dim": 2, "curved": False, "curvature": "none"},
    # ---- Rigid 2D: filled ----
    "plane": {"dim": 2, "curved": False, "curvature": "none"},
    # ---- Rigid 3D: simplices ----
    "tetrahedron": {"dim": 3, "curved": False, "curvature": "none"},
    "pyramid": {"dim": 3, "curved": False, "curvature": "none"},
    "pentachoron": {"dim": 3, "curved": False, "curvature": "none"},
    # ---- Rigid 3D: prisms/polyhedra ----
    "cube": {"dim": 3, "curved": False, "curvature": "none"},
    "cuboid": {"dim": 3, "curved": False, "curvature": "none"},
    "triangular_prism": {"dim": 3, "curved": False, "curvature": "none"},
    "octahedron": {"dim": 3, "curved": False, "curvature": "none"},
    # ---- Curved 1D ----
    "arc": {"dim": 1, "curved": True, "curvature": "convex"},
    "helix": {"dim": 1, "curved": True, "curvature": "helical"},
    # ---- Curved 2D: outlines ----
    "circle": {"dim": 2, "curved": True, "curvature": "convex"},
    "ellipse": {"dim": 2, "curved": True, "curvature": "convex"},
    # ---- Curved 2D: filled ----
    "disc": {"dim": 2, "curved": True, "curvature": "convex"},
    # ---- Curved 3D: solid ----
    "sphere": {"dim": 3, "curved": True, "curvature": "convex"},
    "hemisphere": {"dim": 3, "curved": True, "curvature": "convex"},
    "cylinder": {"dim": 3, "curved": True, "curvature": "cylindrical"},
    "cone": {"dim": 3, "curved": True, "curvature": "conical"},
    "capsule": {"dim": 3, "curved": True, "curvature": "convex"},
    "torus": {"dim": 3, "curved": True, "curvature": "toroidal"},
    # ---- Curved 3D: hollow ----
    "shell": {"dim": 3, "curved": True, "curvature": "convex"},
    "tube": {"dim": 3, "curved": True, "curvature": "cylindrical"},
    # ---- Curved 3D: open surfaces ----
    "bowl": {"dim": 3, "curved": True, "curvature": "concave"},
    "saddle": {"dim": 3, "curved": True, "curvature": "hyperbolic"},
}

NUM_CLASSES = len(SHAPE_CATALOG)
CLASS_NAMES = list(SHAPE_CATALOG.keys())
CLASS_TO_IDX = {name: i for i, name in enumerate(CLASS_NAMES)}

CURVATURE_TYPES = ["none", "convex", "concave", "cylindrical", "conical",
                   "toroidal", "hyperbolic", "helical"]
CURV_TO_IDX = {c: i for i, c in enumerate(CURVATURE_TYPES)}
NUM_CURVATURES = len(CURVATURE_TYPES)

GS = 5  # grid size


# === Cayley-Menger Utilities =================================================

def cayley_menger_det(points: np.ndarray) -> float:
    """Return the Cayley-Menger determinant of ``points``.

    Builds the bordered (n+1)×(n+1) matrix ``[[0, 1ᵀ], [1, D]]``, where ``D``
    holds the squared pairwise Euclidean distances, and returns its
    determinant.

    Args:
        points: array of shape (n, d), n >= 1.
    """
    points = np.asarray(points, dtype=float)
    n = len(points)
    diff = points[:, None, :] - points[None, :, :]
    sq_dist = np.sum(diff ** 2, axis=-1)
    cm = np.zeros((n + 1, n + 1))
    cm[0, 1:] = 1
    cm[1:, 0] = 1
    cm[1:, 1:] = sq_dist
    return np.linalg.det(cm)


def simplex_volume(points: np.ndarray) -> float:
    """Return the (k-1)-dimensional content of the simplex on ``k`` points.

    Uses V² = (-1)^k · det(CM) / (2^(k-1) · ((k-1)!)²): length for 2 points,
    area for 3, volume for 4.  Negative V² (numerical noise on degenerate
    input) is clamped to 0.  Returns 0.0 for fewer than 2 points.
    """
    k = len(points)
    if k < 2:
        return 0.0
    cm = cayley_menger_det(points)
    sign = (-1) ** k
    denom = (2 ** (k - 1)) * (math.factorial(k - 1) ** 2)
    v_sq = sign * cm / denom
    return np.sqrt(max(0, v_sq))


def effective_volume(points: np.ndarray) -> float:
    """Return a size measure of ``points`` at the highest simplex order.

    - fewer than 2 points: 0.0
    - 2 points: segment length
    - 3 points: largest triangle area
    - 4+ points: largest tetrahedron volume (0 when all points are coplanar)

    Only the first 8 points are searched to bound the combinatorial cost.
    """
    k = len(points)
    if k < 2:
        return 0.0
    if k == 2:
        return np.linalg.norm(points[0] - points[1])
    order = 3 if k == 3 else 4
    candidates = range(min(k, 8))
    return max(simplex_volume(points[list(idx)])
               for idx in combinations(candidates, order))


# === Shape Generator =========================================================

class ShapeGenerator:
    """Random generator of labelled voxel shapes for every SHAPE_CATALOG class.

    Generators may return ``None`` for a rejected (degenerate) draw; callers
    retry.  All randomness comes from ``self.rng`` so a seed fixes the output.
    """

    def __init__(self, seed=42):
        self.rng = np.random.RandomState(seed)

    def generate(self, n_samples: int) -> list:
        """Generate ``n_samples`` samples, roughly balanced across classes.

        Each class first receives ``n_samples // NUM_CLASSES`` samples (giving
        up after 5× that many attempts); the remainder is filled with
        uniformly random classes.  The list is shuffled with ``self.rng``.
        """
        samples = []
        per_class = n_samples // NUM_CLASSES
        for name in CLASS_NAMES:
            count = 0
            attempts = 0
            while count < per_class and attempts < per_class * 5:
                s = self._make(name)
                attempts += 1
                if s is not None:
                    samples.append(s)
                    count += 1
        while len(samples) < n_samples:
            # rng.choice returns np.str_; store a plain str class name.
            name = str(self.rng.choice(CLASS_NAMES))
            s = self._make(name)
            if s is not None:
                samples.append(s)
        self.rng.shuffle(samples)
        return samples[:n_samples]

    def _make(self, name: str) -> Optional[dict]:
        """Generate one sample of class ``name``; ``None`` if the draw was rejected."""
        info = SHAPE_CATALOG[name]
        if info["curved"]:
            voxels = self._curved(name)
        else:
            voxels = self._rigid(name)
        if voxels is None:
            return None
        voxels = np.clip(voxels, 0, GS - 1).astype(int)
        # np.unique also sorts rows lexicographically.
        voxels = np.unique(voxels, axis=0)
        if len(voxels) < 1:
            return None
        return self._build(name, info, voxels)

    # === Rigid Generators ===

    def _rigid(self, name):
        """Return integer voxel coordinates (n, 3) for a rigid class, or None."""
        rng = self.rng

        if name == "point":
            return rng.randint(0, GS, size=(1, 3))

        elif name == "line_x":
            y, z = rng.randint(0, GS, size=2)
            x1, x2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x1, y, z], [x2, y, z]])

        elif name == "line_y":
            x, z = rng.randint(0, GS, size=2)
            y1, y2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y1, z], [x, y2, z]])

        elif name == "line_z":
            x, y = rng.randint(0, GS, size=2)
            z1, z2 = sorted(rng.choice(GS, 2, replace=False))
            return np.array([[x, y, z1], [x, y, z2]])

        elif name == "line_diag":
            p1 = rng.randint(0, 3, size=3)
            step = rng.randint(1, 3)
            direction = rng.choice([-1, 1], size=3)
            # Flip any component that would leave the grid instead of
            # clipping: clipping could zero out an axis and turn the
            # "diagonal" into an axis-aligned line.  With p1 <= 2 and
            # step <= 2 the +1 direction always stays in bounds.
            direction[p1 + step * direction < 0] = 1
            p2 = p1 + step * direction
            return np.array([p1, p2])

        elif name == "cross":
            # Two perpendicular arms through a shared center point.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            length = rng.randint(1, 3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            pts = [[cx, cy, cz]]
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis1] = np.clip(p[axis1] + sign * length, 0, GS - 1)
                pts.append(list(p))
            for sign in [-1, 1]:
                p = [cx, cy, cz]
                p[axis2] = np.clip(p[axis2] + sign * length, 0, GS - 1)
                pts.append(list(p))
            return np.array(pts)

        elif name == "l_shape":
            # Two filled arms meeting at a right-angle corner.
            corner = rng.randint(1, GS - 1, size=3)
            axis1, axis2 = rng.choice(3, 2, replace=False)
            len1 = rng.randint(1, 3)
            len2 = rng.randint(1, 3)
            dir1 = rng.choice([-1, 1])
            dir2 = rng.choice([-1, 1])
            pts = [list(corner)]
            for i in range(1, len1 + 1):
                p = list(corner)
                p[axis1] = np.clip(p[axis1] + dir1 * i, 0, GS - 1)
                pts.append(p)
            for i in range(1, len2 + 1):
                p = list(corner)
                p[axis2] = np.clip(p[axis2] + dir2 * i, 0, GS - 1)
                pts.append(p)
            return np.array(pts)

        elif name == "collinear":
            axis = rng.randint(3)
            fixed = rng.randint(0, GS, size=2)
            vals = sorted(rng.choice(GS, 3, replace=False))
            pts = np.zeros((3, 3), dtype=int)
            for i, v in enumerate(vals):
                pts[i, axis] = v
                pts[i, (axis + 1) % 3] = fixed[0]
                pts[i, (axis + 2) % 3] = fixed[1]
            return pts

        elif name == "triangle_xy":
            z = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is None or self._affine_rank(pts) < 2:
                return None  # collinear draw is not a triangle
            return np.column_stack([pts, np.full(3, z)])

        elif name == "triangle_xz":
            y = rng.randint(0, GS)
            pts = self._rand_pts_2d(3, min_dist=1)
            if pts is None or self._affine_rank(pts) < 2:
                return None  # collinear draw is not a triangle
            return np.column_stack([pts[:, 0], np.full(3, y), pts[:, 1]])

        elif name == "triangle_3d":
            pts = self._rand_pts_3d(3, min_dist=1)
            if pts is None or self._affine_rank(pts) < 2:
                return None  # collinear draw is not a triangle
            return pts

        elif name == "square_xy":
            z = rng.randint(0, GS)
            x1, y1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y1, z], [x1 + s, y1, z],
                            [x1, y1 + s, z], [x1 + s, y1 + s, z]])
            return np.clip(pts, 0, GS - 1)

        elif name == "square_xz":
            y = rng.randint(0, GS)
            x1, z1 = rng.randint(0, 3, size=2)
            s = rng.randint(1, 3)
            pts = np.array([[x1, y, z1], [x1 + s, y, z1],
                            [x1, y, z1 + s], [x1 + s, y, z1 + s]])
            return np.clip(pts, 0, GS - 1)

        elif name == "rectangle":
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a1, a2 = rng.randint(0, 3), rng.randint(0, 3)
            w, h = rng.randint(1, 4), rng.randint(1, 3)
            if w == h:
                w = min(GS - 1, w + 1)
            # Shift the origin so the whole w×h extent fits.  Clipping it
            # instead could shrink w down to h and emit a square.
            a1 = min(a1, GS - 1 - w)
            a2 = min(a2, GS - 1 - h)
            c = np.array([[a1, a2], [a1 + w, a2], [a1, a2 + h], [a1 + w, a2 + h]])
            if axis == 0:
                return np.column_stack([np.full(4, val), c])
            elif axis == 1:
                return np.column_stack([c[:, 0], np.full(4, val), c[:, 1]])
            else:
                return np.column_stack([c, np.full(4, val)])

        elif name == "coplanar":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            pts[:, rng.randint(3)] = pts[0, rng.randint(3)]
            # Flattening one axis can merge points or line them all up.
            if len(np.unique(pts, axis=0)) < 4 or self._affine_rank(pts) != 2:
                return None
            return pts

        elif name == "plane":
            # Filled rectangular slab, 1 voxel thick.
            axis = rng.randint(3)
            val = rng.randint(0, GS)
            a_start = rng.randint(0, 2)
            b_start = rng.randint(0, 2)
            a_size = rng.randint(2, GS - a_start + 1)
            b_size = rng.randint(2, GS - b_start + 1)
            pts = []
            for a in range(a_start, min(GS, a_start + a_size)):
                for b in range(b_start, min(GS, b_start + b_size)):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[(axis + 1) % 3] = a
                    p[(axis + 2) % 3] = b
                    pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tetrahedron":
            pts = self._rand_pts_3d(4, min_dist=1)
            if pts is None:
                return None
            centered = pts - pts.mean(axis=0)
            _, s, _ = np.linalg.svd(centered.astype(float))
            if s[-1] < 0.5:
                # Nearly flat: nudge one coordinate out of the plane.
                pts[rng.randint(4), rng.randint(3)] = (pts[0, 0] + 2) % GS
            # The nudge is not guaranteed to help; reject if still degenerate.
            if self._affine_rank(pts) < 3:
                return None
            return pts

        elif name == "pyramid":
            z_base = rng.randint(0, 3)
            x1, y1 = rng.randint(0, 3), rng.randint(0, 3)
            s = rng.randint(1, 3)
            base = np.array([[x1, y1, z_base], [x1 + s, y1, z_base],
                             [x1, y1 + s, z_base], [x1 + s, y1 + s, z_base]])
            apex = np.array([[x1 + s // 2, y1 + s // 2, z_base + rng.randint(1, 3)]])
            return np.clip(np.vstack([base, apex]), 0, GS - 1)

        elif name == "pentachoron":
            pts = self._rand_pts_3d(5, min_dist=1)
            if pts is None or self._affine_rank(pts) < 3:
                return None  # all five points coplanar: reject
            return pts

        elif name == "cube":
            x1, y1, z1 = rng.randint(0, 3, size=3)
            s = rng.randint(1, 3)
            pts = []
            for dx in [0, s]:
                for dy in [0, s]:
                    for dz in [0, s]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "cuboid":
            x1, y1, z1 = rng.randint(0, 2, size=3)
            sx, sy, sz = rng.randint(1, 4, size=3)
            # Ensure not a cube.  Grow sx only when it still fits in the grid,
            # otherwise shrink it; clipping would turn 4 back into 3.
            if sx == sy == sz:
                sx = sx + 1 if x1 + sx + 1 <= GS - 1 else sx - 1
            pts = []
            for dx in [0, sx]:
                for dy in [0, sy]:
                    for dz in [0, sz]:
                        pts.append([x1 + dx, y1 + dy, z1 + dz])
            return np.clip(np.array(pts), 0, GS - 1)

        elif name == "triangular_prism":
            # Triangle in one plane, extruded along the remaining axis.
            axis = rng.randint(3)  # extrusion axis
            ext_start = rng.randint(0, 3)
            ext_len = rng.randint(1, 3)
            tri = self._rand_pts_2d(3, min_dist=1)
            if tri is None or self._affine_rank(tri) < 2:
                return None  # collinear cross-section is not a prism
            pts = []
            for e in range(ext_start, min(GS, ext_start + ext_len + 1)):
                for t in tri:
                    p = [0, 0, 0]
                    p[axis] = e
                    p[(axis + 1) % 3] = t[0]
                    p[(axis + 2) % 3] = t[1]
                    pts.append(p)
            return np.clip(np.array(pts), 0, GS - 1) if len(pts) >= 6 else None

        elif name == "octahedron":
            # 6 vertices: ±s along each axis from the center.
            cx, cy, cz = rng.randint(1, GS - 1, size=3)
            s = rng.randint(1, 3)
            pts = [[cx, cy, cz + s], [cx, cy, cz - s],
                   [cx + s, cy, cz], [cx - s, cy, cz],
                   [cx, cy + s, cz], [cx, cy - s, cz]]
            return np.clip(np.array(pts), 0, GS - 1)

        return None

    # === Curved Generators ===

    def _curved(self, name):
        """Return integer voxel coordinates (n, 3) for a curved class, or None.

        Every branch shares a random float center (cx, cy, cz) in [1, 3)^3.
        """
        rng = self.rng
        cx, cy, cz = rng.uniform(1.0, 3.0, size=3)

        if name == "arc":
            r = rng.uniform(1.2, 2.2)
            plane = rng.choice(["xy", "xz", "yz"])
            start = rng.uniform(0, 2 * np.pi)
            span = rng.uniform(np.pi * 0.4, np.pi * 1.2)
            n = rng.randint(6, 12)
            angles = np.linspace(start, start + span, n)
            pts = self._snap_to_grid(self._planar_ring(plane, cx, cy, cz, r, r, angles))
            return pts if len(pts) >= 3 else None

        elif name == "helix":
            # Spiral through 3D: parametric curve around `axis`.
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            pitch = rng.uniform(0.3, 0.8)  # rise per radian
            n = rng.randint(15, 30)
            t = np.linspace(0, 2 * np.pi * rng.uniform(1.0, 2.5), n)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            start_h = rng.uniform(0, 1.0)
            # Cap the pitch so the whole sweep fits inside the grid.  Without
            # this the height reaches ~12, and clipping piles the upper turns
            # onto the boundary face as a flat circle.
            pitch = min(pitch, (GS - 1 - start_h) / t[-1])
            pts = np.zeros((n, 3))
            pts[:, axes[0]] = center[axes[0]] + r * np.cos(t)
            pts[:, axes[1]] = center[axes[1]] + r * np.sin(t)
            pts[:, axis] = start_h + pitch * t
            pts = self._snap_to_grid(pts)
            return pts if len(pts) >= 5 else None

        elif name == "circle":
            r = rng.uniform(1.0, 2.0)
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = self._snap_to_grid(self._planar_ring(plane, cx, cy, cz, r, r, angles))
            return pts if len(pts) >= 5 else None

        elif name == "ellipse":
            rx, ry = rng.uniform(0.8, 2.0), rng.uniform(0.8, 2.0)
            if abs(rx - ry) < 0.3:
                rx *= 1.4  # keep the two radii visibly different
            plane = rng.choice(["xy", "xz", "yz"])
            n = rng.randint(12, 20)
            angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
            pts = self._snap_to_grid(self._planar_ring(plane, cx, cy, cz, rx, ry, angles))
            return pts if len(pts) >= 5 else None

        elif name == "disc":
            # Filled circle in an axis-aligned plane (not just the outline).
            r = rng.uniform(1.0, 2.2)
            axis = rng.randint(3)
            val = round(rng.uniform(0.5, 3.5))
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    p = [0, 0, 0]
                    p[axis] = val
                    p[axes[0]] = x
                    p[axes[1]] = y
                    dist = np.sqrt((x - center[axes[0]]) ** 2 + (y - center[axes[1]]) ** 2)
                    if dist <= r:
                        pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "sphere":
            r = rng.uniform(1.0, 2.2)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "hemisphere":
            r = rng.uniform(1.0, 2.2)
            cut_axis = rng.randint(3)
            center = [cx, cy, cz]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2:
                            if p[cut_axis] >= center[cut_axis]:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 3 else None

        elif name == "cylinder":
            r = rng.uniform(0.8, 1.8)
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if dist_sq <= r ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "cone":
            r_base = rng.uniform(1.0, 2.0)
            axis = rng.randint(3)
            height = rng.randint(2, 5)
            base_pos = rng.randint(0, GS - height + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        along = p[axis] - base_pos
                        if along < 0 or along >= height:
                            continue
                        # Radius shrinks linearly from r_base to ~0 at the top layer.
                        t = along / (height - 1 + 1e-6)
                        r_at = r_base * (1.0 - t)
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if dist_sq <= r_at ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "capsule":
            # Cylinder with hemispherical caps.
            r = rng.uniform(0.8, 1.5)
            axis = rng.randint(3)
            body_len = rng.randint(1, 3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            body_start = round(center[axis] - body_len / 2)
            body_end = body_start + body_len
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        radial_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        along = p[axis]
                        # Body
                        if body_start <= along <= body_end and radial_sq <= r ** 2:
                            pts.append(p)
                        # Bottom cap
                        elif along < body_start:
                            cap_center = list(center)
                            cap_center[axis] = body_start
                            dist_sq = sum((p[i] - cap_center[i]) ** 2 for i in range(3))
                            if dist_sq <= r ** 2:
                                pts.append(p)
                        # Top cap
                        elif along > body_end:
                            cap_center = list(center)
                            cap_center[axis] = body_end
                            dist_sq = sum((p[i] - cap_center[i]) ** 2 for i in range(3))
                            if dist_sq <= r ** 2:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 5 else None

        elif name == "torus":
            R = rng.uniform(1.2, 2.0)  # ring (major) radius
            r = rng.uniform(0.5, 0.9)  # tube (minor) radius
            axis = rng.randint(3)
            center = [cx, cy, cz]
            ring_axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_in_plane = np.sqrt(
                            sum((p[a] - center[a]) ** 2 for a in ring_axes))
                        dist_from_ring = np.sqrt(
                            (dist_in_plane - R) ** 2 + (p[axis] - center[axis]) ** 2)
                        if dist_from_ring <= r:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "shell":
            # Hollow sphere: voxels between inner and outer radius.
            r_out = rng.uniform(1.5, 2.3)
            r_in = r_out - rng.uniform(0.4, 0.8)
            if r_in < 0.3:
                r_in = 0.3
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        d_sq = (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2
                        if r_in ** 2 <= d_sq <= r_out ** 2:
                            pts.append([x, y, z])
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "tube":
            # Hollow cylinder.
            r_out = rng.uniform(1.0, 2.0)
            r_in = r_out - rng.uniform(0.3, 0.7)
            if r_in < 0.2:
                r_in = 0.2
            axis = rng.randint(3)
            length = rng.randint(2, 5)
            start = rng.randint(0, GS - length + 1)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        if p[axis] < start or p[axis] >= start + length:
                            continue
                        dist_sq = sum((p[a] - center[a]) ** 2 for a in axes)
                        if r_in ** 2 <= dist_sq <= r_out ** 2:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "bowl":
            # Paraboloid: concave surface, open on top.
            r = rng.uniform(1.2, 2.2)
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            thickness = 0.6
            k = 1.0 / (r + 1e-6)  # paraboloid h = k * dist^2 (loop invariant)
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        dist_planar = np.sqrt(
                            sum((p[a] - center[a]) ** 2 for a in axes))
                        if dist_planar > r:
                            continue
                        expected_h = center[axis] + k * dist_planar ** 2
                        if abs(p[axis] - expected_h) <= thickness:
                            pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        elif name == "saddle":
            # Hyperbolic paraboloid: h = k * (a^2 - b^2).
            axis = rng.randint(3)
            center = [cx, cy, cz]
            axes = [i for i in range(3) if i != axis]
            k = rng.uniform(0.3, 0.8)
            thickness = 0.7
            pts = []
            for x in range(GS):
                for y in range(GS):
                    for z in range(GS):
                        p = [x, y, z]
                        da = p[axes[0]] - center[axes[0]]
                        db = p[axes[1]] - center[axes[1]]
                        expected_h = center[axis] + k * (da ** 2 - db ** 2)
                        if abs(p[axis] - expected_h) <= thickness:
                            # Limit radius so it doesn't fill everything.
                            dist_sq = da ** 2 + db ** 2
                            if dist_sq <= 4.0:
                                pts.append(p)
            return np.array(pts) if len(pts) >= 4 else None

        return None

    # === Helpers ===

    @staticmethod
    def _planar_ring(plane, cx, cy, cz, ra, rb, angles):
        """Return float points on an axis-aligned ellipse (radii ra, rb) in ``plane``."""
        pts = []
        for a in angles:
            if plane == "xy":
                pts.append([cx + ra * np.cos(a), cy + rb * np.sin(a), cz])
            elif plane == "xz":
                pts.append([cx + ra * np.cos(a), cy, cz + rb * np.sin(a)])
            else:
                pts.append([cx, cy + ra * np.cos(a), cz + rb * np.sin(a)])
        return pts

    @staticmethod
    def _snap_to_grid(pts):
        """Clamp float points to the grid, round to voxel indices, and dedupe rows."""
        return np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)

    @staticmethod
    def _affine_rank(pts):
        """Return the dimension of the affine hull of ``pts`` (0 for < 2 points)."""
        pts = np.asarray(pts, dtype=float)
        if len(pts) < 2:
            return 0
        return int(np.linalg.matrix_rank(pts[1:] - pts[0]))

    def _rand_pts_2d(self, n, min_dist=0):
        """Draw ``n`` distinct integer 2D grid points; None after 50 failed tries."""
        for _ in range(50):
            pts = set()
            while len(pts) < n:
                pts.add((self.rng.randint(0, GS), self.rng.randint(0, GS)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _rand_pts_3d(self, n, min_dist=0):
        """Draw ``n`` distinct integer 3D grid points; None after 100 failed tries."""
        for _ in range(100):
            pts = set()
            while len(pts) < n:
                pts.add(tuple(self.rng.randint(0, GS, size=3)))
            pts = np.array(list(pts)[:n])
            if min_dist <= 0 or self._check_dist(pts, min_dist):
                return pts
        return None

    def _check_dist(self, pts, min_dist):
        """True when every pair of points is at least ``min_dist`` apart (L1)."""
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                if np.sum(np.abs(pts[i] - pts[j])) < min_dist:
                    return False
        return True

    def _build(self, name, info, voxels):
        """Package ``voxels`` into a training-sample dict for class ``name``.

        ``cm_det`` and ``volume`` are computed on at most 6 voxels, which
        bounds the combinatorial cost of ``effective_volume``.  When there
        are more than 6, the sample is spread evenly over the voxel list
        (which ``_make`` sorts lexicographically).  The first 6 rows of that
        list lie on the low-x side and are frequently coplanar for filled
        shapes, which would collapse ``volume`` to 0.
        """
        n = len(voxels)
        if n > 6:
            idx = np.linspace(0, n - 1, 6).round().astype(int)
            sub = voxels[idx].astype(float)
        else:
            sub = voxels.astype(float)
        cm_det = cayley_menger_det(sub)
        volume = effective_volume(sub)

        # Cumulative "has at least dimension d" flags.
        dim_conf = np.zeros(4, dtype=np.float32)
        dim_conf[0] = 1.0
        if n >= 2:
            dim_conf[1] = 1.0
        if info["dim"] >= 2:
            dim_conf[2] = 1.0
        if info["dim"] >= 3:
            dim_conf[3] = 1.0

        grid = np.zeros((GS, GS, GS), dtype=np.float32)
        for v in voxels:
            grid[v[0], v[1], v[2]] = 1.0

        return {
            "grid": grid,
            "label": CLASS_TO_IDX[name],
            "class_name": name,
            "n_points": n,
            "n_occupied": int(grid.sum()),
            "cm_det": float(cm_det),
            "volume": float(volume),
            "peak_dim": info["dim"],
            "dim_confidence": dim_conf,
            "is_curved": info["curved"],
            "curvature": CURV_TO_IDX[info["curvature"]],
        }


# === Dataset =================================================================

def _generate_chunk(args):
    """Worker function for parallel shape generation.

    ``args`` is ``(class_assignments, seed, start_idx)``; ``start_idx`` is
    carried for bookkeeping but not used.  Each class gets up to 10 attempts.
    If all of them are rejected, a "cube" sample (correctly labelled) is
    generated in its place.
    """
    class_assignments, seed, _start_idx = args
    gen = ShapeGenerator(seed=seed)
    samples = []
    for ci in class_assignments:
        name = CLASS_NAMES[ci]
        for _ in range(10):
            s = gen._make(name)
            if s is not None:
                samples.append(s)
                break
        else:
            s = gen._make("cube")
            if s is not None:
                samples.append(s)
    return samples


def generate_parallel(n_samples, seed=42, n_workers=8):
    """Pre-generate all samples using multiprocessing.

    Classes are assigned as evenly as possible (remainder drawn at random),
    shuffled, split into ``n_workers`` contiguous chunks with distinct seeds,
    generated in a process pool, concatenated, and shuffled again.
    """
    import multiprocessing as mp
    import time

    per_class = n_samples // NUM_CLASSES
    class_assignments = []
    for ci in range(NUM_CLASSES):
        class_assignments.extend([ci] * per_class)
    rng = np.random.RandomState(seed)
    while len(class_assignments) < n_samples:
        class_assignments.append(rng.randint(0, NUM_CLASSES))
    rng.shuffle(class_assignments)
    class_assignments = class_assignments[:n_samples]

    # Split into chunks per worker.
    chunk_size = (n_samples + n_workers - 1) // n_workers
    chunks = []
    for i in range(n_workers):
        start = i * chunk_size
        end = min(start + chunk_size, n_samples)
        if start >= n_samples:
            break
        chunks.append((class_assignments[start:end], seed + i * 1000000, start))

    print(f"Generating {n_samples} shapes across {len(chunks)} workers...")
    t0 = time.time()
    with mp.Pool(n_workers) as pool:
        results = pool.map(_generate_chunk, chunks)
    samples = []
    for r in results:
        samples.extend(r)
    rng.shuffle(samples)
    dt = time.time() - t0
    # Guard the rate against a zero elapsed time (e.g. n_samples == 0).
    rate = len(samples) / dt if dt > 0 else float("inf")
    print(f"Generated {len(samples)} samples in {dt:.1f}s ({rate:.0f} samples/s)")
    return samples


class ShapeDataset(torch.utils.data.Dataset):
    """Tensor-backed dataset over sample dicts produced by ``ShapeGenerator``.

    Items are ``(grid, label, dim_conf, peak_dim, volume, cm_det, is_curved,
    curvature)``.
    """

    def __init__(self, samples):
        self.grids = torch.tensor(np.stack([s["grid"] for s in samples]), dtype=torch.float32)
        self.labels = torch.tensor([s["label"] for s in samples], dtype=torch.long)
        self.dim_conf = torch.tensor(np.stack([s["dim_confidence"] for s in samples]),
                                     dtype=torch.float32)
        self.peak_dim = torch.tensor([s["peak_dim"] for s in samples], dtype=torch.long)
        self.volume = torch.tensor([s["volume"] for s in samples], dtype=torch.float32)
        self.cm_det = torch.tensor([s["cm_det"] for s in samples], dtype=torch.float32)
        self.is_curved = torch.tensor([s["is_curved"] for s in samples], dtype=torch.float32)
        self.curvature = torch.tensor([s["curvature"] for s in samples], dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return (self.grids[idx], self.labels[idx], self.dim_conf[idx],
                self.peak_dim[idx], self.volume[idx], self.cm_det[idx],
                self.is_curved[idx], self.curvature[idx])


print(f'Loaded {NUM_CLASSES} shape classes, GS={GS}')