""" Hierarchical Shape Generator - Two-Tier Gate Version ====================================================== Generates grids only. Patch analysis split into: - Local properties: intrinsic to each patch's voxels (no cross-patch info) - Structural properties: relational, require neighborhood context Colab Cell 1 of 3 - runs first, populates shared namespace. """ import numpy as np from typing import Dict, Optional from itertools import combinations # === Grid Constants =========================================================== GZ, GY, GX = 8, 16, 16 GRID_SHAPE = (GZ, GY, GX) GRID_VOLUME = GZ * GY * GX PATCH_Z, PATCH_Y, PATCH_X = 2, 4, 4 PATCH_VOL = PATCH_Z * PATCH_Y * PATCH_X MACRO_Z, MACRO_Y, MACRO_X = GZ // PATCH_Z, GY // PATCH_Y, GX // PATCH_X MACRO_N = MACRO_Z * MACRO_Y * MACRO_X # Worker budget (A100 Colab limit) MAX_WORKERS = 10 _COORDS = np.mgrid[0:GZ, 0:GY, 0:GX].reshape(3, -1).T.astype(np.float64) # === Classes ================================================================== CLASS_NAMES = [ "point", "line", "corner", "cross", "arc", "helix", "circle", "triangle", "quad", "plane", "disc", "tetrahedron", "cube", "pyramid", "prism", "octahedron", "pentachoron", "wedge", "sphere", "hemisphere", "torus", "bowl", "saddle", "capsule", "cylinder", "cone", "channel" ] NUM_CLASSES = len(CLASS_NAMES) CLASS_TO_IDX = {n: i for i, n in enumerate(CLASS_NAMES)} # === Two-Tier Gate Constants ================================================== # Local gates: intrinsic to each patch, no cross-patch info needed # dims: 4 classes (0D point, 1D line, 2D surface, 3D volume) # curvature: 3 classes (rigid, curved, combined) # boundary: 1 binary (partial fill = surface patch) # axis_active: 3 binary (which axes have extent > 1 voxel) NUM_LOCAL_DIMS = 4 NUM_LOCAL_CURVS = 3 NUM_LOCAL_BOUNDARY = 1 NUM_LOCAL_AXES = 3 LOCAL_GATE_DIM = NUM_LOCAL_DIMS + NUM_LOCAL_CURVS + NUM_LOCAL_BOUNDARY + NUM_LOCAL_AXES # 11 # Structural gates: relational, require neighborhood context (post-attention) # topology: 2 classes (open / closed based on neighbor count) # neighbor_ct: 1 continuous (normalized 0-1, raw count / 6) # surface_role: 3 classes (isolated 0-1 neighbors, boundary 2-4, interior 5-6) NUM_STRUCT_TOPO = 2 NUM_STRUCT_NEIGHBOR = 1 NUM_STRUCT_ROLE = 3 STRUCTURAL_GATE_DIM = NUM_STRUCT_TOPO + NUM_STRUCT_NEIGHBOR + NUM_STRUCT_ROLE # 6 TOTAL_GATE_DIM = LOCAL_GATE_DIM + STRUCTURAL_GATE_DIM # 17 # Legacy compat GATES = ["rigid", "curved", "combined", "open", "closed"] NUM_GATES = len(GATES) # === Rasterization ============================================================ def rasterize_line(p1, p2): p1, p2 = np.array(p1, dtype=float), np.array(p2, dtype=float) n = max(int(np.max(np.abs(p2 - p1))) + 1, 2) t = np.linspace(0, 1, n)[:, None] pts = np.round(p1 + t * (p2 - p1)).astype(int) return np.clip(pts, [0, 0, 0], [GZ-1, GY-1, GX-1]) def rasterize_edges(verts, edges): pts = [] for i, j in edges: pts.append(rasterize_line(verts[i], verts[j])) return np.concatenate(pts) def rasterize_faces(verts, faces, density=1.0): pts = [] for f in faces: v0, v1, v2 = [np.array(verts[i], dtype=float) for i in f[:3]] e1, e2 = v1 - v0, v2 - v0 n = max(int(max(np.linalg.norm(e1), np.linalg.norm(e2)) * density) + 1, 3) for u in np.linspace(0, 1, n): for v in np.linspace(0, 1 - u, max(int(n * (1 - u)), 1)): p = np.round(v0 + u * e1 + v * e2).astype(int) pts.append(p) return np.clip(np.array(pts), [0, 0, 0], [GZ-1, GY-1, GX-1]) def rasterize_sphere(c, r, fill=True, half=False, zmin=None, zmax=None): c = np.array(c, dtype=float) pts = [] zr = range(max(0, int(c[0] - r)), min(GZ, int(c[0] + r) + 1)) for z in zr: for y in range(max(0, int(c[1] - r)), min(GY, int(c[1] + r) + 1)): for x in range(max(0, int(c[2] - r)), min(GX, int(c[2] + r) + 1)): d = np.sqrt((z - c[0])**2 + (y - c[1])**2 + (x - c[2])**2) if fill and d <= r: if zmin is not None and z < zmin: continue if zmax is not None and z > zmax: continue pts.append([z, y, x]) elif not fill and abs(d - r) < 0.8: if half and z < c[0]: continue pts.append([z, y, x]) return np.array(pts) if pts else np.zeros((0, 3), dtype=int) # === Shape Generators ========================================================= class HierarchicalShapeGenerator: def __init__(self, seed=42): self.rng = np.random.RandomState(seed) def _random_center(self, margin=3): return [self.rng.randint(max(1, margin//2), max(2, GZ - margin//2)), self.rng.randint(margin, GY - margin), self.rng.randint(margin, GX - margin)] def _to_grid(self, pts): if len(pts) == 0: return None, None grid = np.zeros(GRID_SHAPE, dtype=np.float32) pts = np.clip(np.array(pts).astype(int), [0, 0, 0], [GZ-1, GY-1, GX-1]) grid[pts[:, 0], pts[:, 1], pts[:, 2]] = 1.0 return grid, pts def generate(self, name): r = self.rng c = self._random_center() try: if name == "point": pts = [c] elif name == "line": axis = r.randint(0, 3) p1, p2 = list(c), list(c) L = r.randint(4, [GZ, GY, GX][axis]) p1[axis] = max(0, c[axis] - L//2) p2[axis] = min([GZ, GY, GX][axis] - 1, c[axis] + L//2) pts = rasterize_line(p1, p2) elif name == "corner": L = r.randint(3, 7) p1, p2 = list(c), list(c) p1[1] = max(0, c[1] - L) p2[2] = min(GX - 1, c[2] + L) pts = np.concatenate([rasterize_line(c, p1), rasterize_line(c, p2)]) elif name == "cross": L = r.randint(2, 5) pts = [] for d in range(3): p1, p2 = list(c), list(c) p1[d] = max(0, c[d] - L) p2[d] = min([GZ, GY, GX][d] - 1, c[d] + L) pts.append(rasterize_line(p1, p2)) pts = np.concatenate(pts) elif name == "arc": R = r.uniform(2, 5) t = np.linspace(0, np.pi * r.uniform(0.4, 0.9), 30) pts = np.round(np.column_stack([c[0] + np.zeros_like(t), c[1] + R*np.cos(t), c[2] + R*np.sin(t)])).astype(int) elif name == "helix": R, H = r.uniform(2, 4), r.uniform(3, GZ - 2) t = np.linspace(0, 4*np.pi, 60) pts = np.round(np.column_stack([c[0] - H/2 + t/(4*np.pi)*H, c[1] + R*np.cos(t), c[2] + R*np.sin(t)])).astype(int) elif name == "circle": R = r.uniform(2, 5) t = np.linspace(0, 2*np.pi, 40) pts = np.round(np.column_stack([np.full_like(t, c[0]), c[1] + R*np.cos(t), c[2] + R*np.sin(t)])).astype(int) elif name == "triangle": s = r.uniform(3, 6) v = [[c[0], c[1] - s, c[2]], [c[0], c[1] + s//2, c[2] - s], [c[0], c[1] + s//2, c[2] + s]] pts = rasterize_edges(v, [(0,1),(1,2),(2,0)]) elif name == "quad": s = r.randint(2, 5) v = [[c[0], c[1]-s, c[2]-s], [c[0], c[1]-s, c[2]+s], [c[0], c[1]+s, c[2]+s], [c[0], c[1]+s, c[2]-s]] pts = rasterize_edges(v, [(0,1),(1,2),(2,3),(3,0)]) elif name == "plane": s = r.randint(2, 5) pts = rasterize_faces([[c[0],c[1]-s,c[2]-s],[c[0],c[1]-s,c[2]+s],[c[0],c[1]+s,c[2]+s],[c[0],c[1]+s,c[2]-s]], [(0,1,2),(0,2,3)]) elif name == "disc": R = r.uniform(2, 5) pts = rasterize_sphere(c, R, fill=True) pts = pts[pts[:, 0] == c[0]] if len(pts) > 0 else pts elif name == "tetrahedron": s = r.uniform(3, 5) v = [[c[0]+s,c[1],c[2]], [c[0]-s//2,c[1]+s,c[2]], [c[0]-s//2,c[1]-s//2,c[2]+s], [c[0]-s//2,c[1]-s//2,c[2]-s]] pts = rasterize_edges(v, [(0,1),(0,2),(0,3),(1,2),(1,3),(2,3)]) elif name == "cube": s = r.randint(2, 4) v = [[c[0]+d[0]*s, c[1]+d[1]*s, c[2]+d[2]*s] for d in [(-1,-1,-1),(-1,-1,1),(-1,1,1),(-1,1,-1),(1,-1,-1),(1,-1,1),(1,1,1),(1,1,-1)]] pts = rasterize_edges(v, [(0,1),(1,2),(2,3),(3,0),(4,5),(5,6),(6,7),(7,4),(0,4),(1,5),(2,6),(3,7)]) elif name == "pyramid": s = r.randint(2, 4) base = [[c[0]-s,c[1]-s,c[2]-s],[c[0]-s,c[1]-s,c[2]+s],[c[0]-s,c[1]+s,c[2]+s],[c[0]-s,c[1]+s,c[2]-s]] apex = [c[0]+s, c[1], c[2]] v = base + [apex] pts = rasterize_edges(v, [(0,1),(1,2),(2,3),(3,0),(0,4),(1,4),(2,4),(3,4)]) elif name == "prism": s, h = r.randint(2, 4), r.randint(2, 4) bottom = [[c[0]-h,c[1]-s,c[2]], [c[0]-h,c[1]+s//2,c[2]-s], [c[0]-h,c[1]+s//2,c[2]+s]] top = [[b[0]+2*h, b[1], b[2]] for b in bottom] v = bottom + top pts = rasterize_edges(v, [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3),(0,3),(1,4),(2,5)]) elif name == "octahedron": s = r.uniform(2, 4) v = [[c[0]+s,c[1],c[2]],[c[0]-s,c[1],c[2]],[c[0],c[1]+s,c[2]],[c[0],c[1]-s,c[2]],[c[0],c[1],c[2]+s],[c[0],c[1],c[2]-s]] pts = rasterize_edges(v, [(0,2),(0,3),(0,4),(0,5),(1,2),(1,3),(1,4),(1,5),(2,4),(2,5),(3,4),(3,5)]) elif name == "pentachoron": s = r.uniform(2, 4) v = [[c[0]+s,c[1],c[2]],[c[0]-s//2,c[1]+s,c[2]],[c[0]-s//2,c[1]-s//2,c[2]+s],[c[0]-s//2,c[1]-s//2,c[2]-s],[c[0],c[1],c[2]]] pts = rasterize_edges(v, [(i,j) for i in range(5) for j in range(i+1,5)]) elif name == "wedge": s = r.randint(2, 4) v = [[c[0]-s,c[1]-s,c[2]-s],[c[0]-s,c[1]+s,c[2]-s],[c[0]-s,c[1],c[2]+s],[c[0]+s,c[1]-s,c[2]-s],[c[0]+s,c[1]+s,c[2]-s],[c[0]+s,c[1],c[2]+s]] pts = rasterize_edges(v, [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3),(0,3),(1,4),(2,5)]) elif name == "sphere": R = r.uniform(2, min(3.5, GZ//2 - 1)) pts = rasterize_sphere(c, R, fill=False) elif name == "hemisphere": R = r.uniform(2, min(3.5, GZ//2 - 1)) pts = rasterize_sphere(c, R, fill=False, half=True) elif name == "torus": R, rr = r.uniform(3, 5), r.uniform(1, 2) t = np.linspace(0, 2*np.pi, 40) p = np.linspace(0, 2*np.pi, 20) T, P = np.meshgrid(t, p) pts = np.round(np.column_stack([c[0] + rr*np.sin(P.ravel()), c[1] + (R+rr*np.cos(P.ravel()))*np.cos(T.ravel()), c[2] + (R+rr*np.cos(P.ravel()))*np.sin(T.ravel())])).astype(int) elif name == "bowl": R = r.uniform(2, 4) pts = rasterize_sphere(c, R, fill=False) pts = pts[pts[:, 0] >= c[0]] if len(pts) > 0 else pts elif name == "saddle": s = r.uniform(2, 4) Y, X = np.mgrid[-s:s:0.5, -s:s:0.5] Z = (Y**2 - X**2) / (2*s) pts = np.round(np.column_stack([c[0] + Z.ravel(), c[1] + Y.ravel(), c[2] + X.ravel()])).astype(int) elif name == "capsule": R, H = r.uniform(1.5, 3), r.uniform(2, 4) shell = rasterize_sphere(c, R, fill=False) body = [] for z in range(max(0, int(c[0]-H//2)), min(GZ, int(c[0]+H//2)+1)): for y in range(GY): for x in range(GX): if abs(np.sqrt((y-c[1])**2 + (x-c[2])**2) - R) < 0.8: body.append([z, y, x]) pts = np.concatenate([shell, np.array(body) if body else np.zeros((0,3), dtype=int)]) elif name == "cylinder": R, H = r.uniform(2, 4), r.uniform(3, GZ - 2) pts = [] for z in range(max(0, int(c[0]-H/2)), min(GZ, int(c[0]+H/2)+1)): for y in range(GY): for x in range(GX): d = np.sqrt((y-c[1])**2 + (x-c[2])**2) if abs(d - R) < 0.8: pts.append([z, y, x]) pts = np.array(pts) if pts else np.zeros((0,3), dtype=int) elif name == "cone": R, H = r.uniform(2, 4), r.uniform(3, GZ - 2) pts = [] for z in range(max(0, int(c[0]-H/2)), min(GZ, int(c[0]+H/2)+1)): frac = 1 - (z - (c[0]-H/2)) / H cr = R * frac for y in range(GY): for x in range(GX): d = np.sqrt((y-c[1])**2 + (x-c[2])**2) if abs(d - cr) < 0.8 and cr > 0.3: pts.append([z, y, x]) pts = np.array(pts) if pts else np.zeros((0,3), dtype=int) elif name == "channel": R = r.uniform(2, 4) L = r.randint(6, GX - 2) pts = [] for z in range(GZ): for x in range(max(0, c[2]-L//2), min(GX, c[2]+L//2)): for y in range(GY): d = np.sqrt((z - c[0])**2 + (y - c[1])**2) if abs(d - R) < 0.8: pts.append([z, y, x]) pts = np.array(pts) if pts else np.zeros((0,3), dtype=int) else: return None except Exception: return None grid, pts = self._to_grid(pts) if grid is not None and pts is not None and len(pts) > 0: return {"grid": grid, "class_idx": CLASS_TO_IDX[name]} return None def generate_multi(self, n_shapes: int = None) -> Optional[Dict]: if n_shapes is None: n_shapes = self.rng.randint(2, 5) names = list(self.rng.choice(CLASS_NAMES, size=n_shapes, replace=False)) shapes = [s for s in [self.generate(n) for n in names] if s is not None] if len(shapes) < 2: return None grid = np.zeros(GRID_SHAPE, dtype=np.float32) membership = np.zeros((MACRO_N, NUM_CLASSES), dtype=np.float32) for s in shapes: pts = np.argwhere(s["grid"] > 0.5) grid[pts[:, 0], pts[:, 1], pts[:, 2]] = 1.0 patch_idx = (pts[:, 0]//PATCH_Z) * (MACRO_Y*MACRO_X) + (pts[:, 1]//PATCH_Y) * MACRO_X + (pts[:, 2]//PATCH_X) np.add.at(membership[:, s["class_idx"]], patch_idx, 1.0) return {"grid": grid, "membership": (membership > 0).astype(np.float32), "n_shapes": len(shapes)} def _worker(args): seed, min_s, max_s = args gen = HierarchicalShapeGenerator(seed) return gen.generate_multi(gen.rng.randint(min_s, max_s + 1)) def generate_dataset(n_samples: int, seed: int = 42, num_workers: int = MAX_WORKERS) -> Dict: from multiprocessing import Pool try: from tqdm import tqdm use_tqdm = True except ImportError: use_tqdm = False tasks = [(seed * 10000 + i, 2, 4) for i in range(n_samples * 2)] grids, memberships, n_shapes = [], [], [] with Pool(num_workers) as pool: pbar = tqdm(total=n_samples, desc="Generating") if use_tqdm else None for r in pool.imap_unordered(_worker, tasks): if r is not None and len(grids) < n_samples: grids.append(r["grid"]) memberships.append(r["membership"]) n_shapes.append(r["n_shapes"]) if pbar: pbar.update(1) if len(grids) >= n_samples: break if pbar: pbar.close() return {"grids": np.array(grids), "memberships": np.array(memberships), "n_shapes": np.array(n_shapes)} # === Patch Analysis: Two-Tier ================================================= def analyze_local_patches(grids): """ Local patch properties — intrinsic to each patch's voxels. No cross-patch information. Computable from raw patch data. Returns: occupancy: (N, 64) float — mean voxel density dims: (N, 64) long — 0-3 (axis extent counting) curvature: (N, 64) long — 0=rigid, 1=curved, 2=combined boundary: (N, 64) float — 1.0 if partial fill (surface patch) axis_active: (N, 64, 3) float — which axes have extent > 1 fill_ratio: (N, 64) float — voxels / bounding_box_volume """ import torch if isinstance(grids, np.ndarray): grids = torch.from_numpy(grids).float() device, N = grids.device, grids.shape[0] patches = grids.view(N, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y, MACRO_X, PATCH_X) patches = patches.permute(0, 1, 3, 5, 2, 4, 6).contiguous().view(N, MACRO_N, PATCH_Z, PATCH_Y, PATCH_X) occupancy = patches.sum(dim=(2, 3, 4)) / PATCH_VOL occ_mask = occupancy > 0.01 occ = patches > 0.5 z_c = torch.arange(PATCH_Z, device=device).view(1, 1, PATCH_Z, 1, 1).float() y_c = torch.arange(PATCH_Y, device=device).view(1, 1, 1, PATCH_Y, 1).float() x_c = torch.arange(PATCH_X, device=device).view(1, 1, 1, 1, PATCH_X).float() INF = 1000.0 z_ext = torch.where(occ, z_c.expand_as(patches), torch.full_like(patches, -INF)).amax(dim=(2,3,4)) - torch.where(occ, z_c.expand_as(patches), torch.full_like(patches, INF)).amin(dim=(2,3,4)) y_ext = torch.where(occ, y_c.expand_as(patches), torch.full_like(patches, -INF)).amax(dim=(2,3,4)) - torch.where(occ, y_c.expand_as(patches), torch.full_like(patches, INF)).amin(dim=(2,3,4)) x_ext = torch.where(occ, x_c.expand_as(patches), torch.full_like(patches, -INF)).amax(dim=(2,3,4)) - torch.where(occ, x_c.expand_as(patches), torch.full_like(patches, INF)).amin(dim=(2,3,4)) ext_sorted, _ = torch.stack([z_ext, y_ext, x_ext], dim=-1).clamp(min=0).sort(dim=-1, descending=True) dims = torch.zeros(N, MACRO_N, dtype=torch.long, device=device) dims = torch.where(ext_sorted[..., 0] >= 1, torch.tensor(1, device=device), dims) dims = torch.where(ext_sorted[..., 1] >= 1, torch.tensor(2, device=device), dims) dims = torch.where(ext_sorted[..., 2] >= 1, torch.tensor(3, device=device), dims) dims = torch.where(~occ_mask, torch.tensor(-1, device=device), dims) voxels = patches.sum(dim=(2, 3, 4)) bb_vol = ((z_ext + 1) * (y_ext + 1) * (x_ext + 1)).clamp(min=1) fill_ratio = voxels / bb_vol curvature = torch.where(fill_ratio > 0.6, 0, torch.where(fill_ratio < 0.3, 1, 2)).long() boundary = ((occupancy > 0.01) & (occupancy < 0.9)).float() axis_active = torch.stack([ (z_ext.clamp(min=0) >= 1).float(), (y_ext.clamp(min=0) >= 1).float(), (x_ext.clamp(min=0) >= 1).float(), ], dim=-1) return { "occupancy": occupancy, "dims": dims, "curvature": curvature, "boundary": boundary, "axis_active": axis_active, "fill_ratio": fill_ratio, } def analyze_structural_patches(grids, local_data): """ Structural patch properties — relational, require neighborhood context. Ground truth targets for post-attention heads. Returns: topology: (N, 64) long — 0=open (<= 3 neighbors), 1=closed (> 3) neighbor_count: (N, 64) float — normalized 0-1 (raw count / 6) surface_role: (N, 64) long — 0=isolated (0-1), 1=boundary (2-4), 2=interior (5-6) """ import torch import torch.nn.functional as F if isinstance(grids, np.ndarray): grids = torch.from_numpy(grids).float() device, N = grids.device, grids.shape[0] occ_mask = local_data["occupancy"] > 0.01 occ_3d = occ_mask.float().view(N, 1, MACRO_Z, MACRO_Y, MACRO_X) kernel = torch.zeros(1, 1, 3, 3, 3, device=device) kernel[0, 0, 1, 1, 0] = kernel[0, 0, 1, 1, 2] = 1 kernel[0, 0, 1, 0, 1] = kernel[0, 0, 1, 2, 1] = 1 kernel[0, 0, 0, 1, 1] = kernel[0, 0, 2, 1, 1] = 1 raw_count = F.conv3d(occ_3d, kernel, padding=1).view(N, MACRO_N) topology = (raw_count > 3).long() neighbor_count = raw_count / 6.0 surface_role = torch.zeros(N, MACRO_N, dtype=torch.long, device=device) surface_role = torch.where(raw_count >= 2, torch.tensor(1, device=device), surface_role) surface_role = torch.where(raw_count >= 5, torch.tensor(2, device=device), surface_role) return { "topology": topology, "neighbor_count": neighbor_count, "surface_role": surface_role, } def analyze_patches_torch(grids): """Combined analysis — returns both local and structural properties.""" local_data = analyze_local_patches(grids) struct_data = analyze_structural_patches(grids, local_data) import torch N = local_data["occupancy"].shape[0] device = local_data["occupancy"].device labels = torch.zeros(N, MACRO_N, NUM_GATES, device=device) labels[..., 0] = (local_data["curvature"] == 0).float() labels[..., 1] = (local_data["curvature"] == 1).float() labels[..., 2] = (local_data["curvature"] == 2).float() labels[..., 3] = (struct_data["topology"] == 0).float() labels[..., 4] = (struct_data["topology"] == 1).float() return { # Local "patch_occupancy": local_data["occupancy"], "patch_dims": local_data["dims"], "patch_curvature": local_data["curvature"], "patch_boundary": local_data["boundary"], "patch_axis_active": local_data["axis_active"], "patch_fill_ratio": local_data["fill_ratio"], # Structural "patch_topology": struct_data["topology"], "patch_neighbor_count": struct_data["neighbor_count"], "patch_surface_role": struct_data["surface_role"], # Legacy "patch_labels": labels, } # === Dataset ================================================================== import torch from torch.utils.data import Dataset class ShapeDataset(Dataset): def __init__(self, grids, memberships, patch_data): self.grids = grids self.memberships = memberships # Local self.patch_occupancy = patch_data["patch_occupancy"] self.patch_dims = patch_data["patch_dims"] self.patch_curvature = patch_data["patch_curvature"] self.patch_boundary = patch_data["patch_boundary"] self.patch_axis_active = patch_data["patch_axis_active"] self.patch_fill_ratio = patch_data["patch_fill_ratio"] # Structural self.patch_topology = patch_data["patch_topology"] self.patch_neighbor_count = patch_data["patch_neighbor_count"] self.patch_surface_role = patch_data["patch_surface_role"] # Legacy self.patch_labels = patch_data["patch_labels"] # Derived global targets self.patch_shape_count = (memberships > 0).sum(dim=-1).long() self.global_shapes = (memberships.sum(dim=1) > 0).float() occ_mask = self.patch_occupancy > 0.01 occ_count = occ_mask.sum(dim=1, keepdim=True).clamp(min=1) self.global_gates = (self.patch_labels * occ_mask.unsqueeze(-1)).sum(dim=1) / occ_count def __len__(self): return len(self.grids) def __getitem__(self, idx): return { "grid": self.grids[idx], "patch_shape_membership": self.memberships[idx], "patch_shape_count": self.patch_shape_count[idx], # Local "patch_occupancy": self.patch_occupancy[idx], "patch_dims": self.patch_dims[idx], "patch_curvature": self.patch_curvature[idx], "patch_boundary": self.patch_boundary[idx], "patch_axis_active": self.patch_axis_active[idx], "patch_fill_ratio": self.patch_fill_ratio[idx], # Structural "patch_topology": self.patch_topology[idx], "patch_neighbor_count": self.patch_neighbor_count[idx], "patch_surface_role": self.patch_surface_role[idx], # Legacy "patch_labels": self.patch_labels[idx], # Global "global_shapes": self.global_shapes[idx], "global_gates": self.global_gates[idx], } def collate_fn(batch): return {k: torch.stack([b[k] for b in batch]) for k in batch[0].keys()} print(f"✓ Generator ready | Local: {LOCAL_GATE_DIM}d | Structural: {STRUCTURAL_GATE_DIM}d | Total: {TOTAL_GATE_DIM}d")