AbstractPhil commited on
Commit
ce8387e
·
verified ·
1 Parent(s): 813575e

Create shape_generator.py

Browse files
Files changed (1) hide show
  1. shape_generator.py +722 -0
shape_generator.py ADDED
@@ -0,0 +1,722 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 3D Voxel Shape Generator — VAE-Matched Resolution (8×16×16)
3
+ =============================================================
4
+ Adapted from v10.2 for Flux 2 VAE latent geometry:
5
+ - GZ=8 (channel dimension), GY=16, GX=16 (spatial dimensions)
6
+ - 2048 voxels per patch (vs 15,625 at 25³)
7
+ - Aspect ratio 1:2:2 matches VAE latent structure
8
+ - All 38 shape classes preserved
9
+ - Selective rasterization (polyhedra get edges, point-classes stay sparse)
10
+ - Shapes generated in native aspect ratio space
11
+
12
+ Multi-scale usage:
13
+ Patches extracted at any scale (4×8×8, 8×16×16, 16×32×32, 32×64×64)
14
+ are resized to canonical 8×16×16 before classification.
15
+ """
16
+
17
+ import numpy as np
18
+ from itertools import combinations
19
+
20
+ # === Grid dimensions (non-cubic, VAE-matched) ================================
21
+ GZ = 8 # channel dimension (thin)
22
+ GY = 16 # spatial height
23
+ GX = 16 # spatial width
24
+ GRID_SHAPE = (GZ, GY, GX)
25
+ GRID_VOLUME = GZ * GY * GX # 2048
26
+
27
+ # Precompute coordinate grid for vectorized generation
28
+ _COORDS = np.mgrid[0:GZ, 0:GY, 0:GX].reshape(3, -1).T.astype(np.float64)
29
+
30
+ # === Shape classes ============================================================
31
+ CLASS_META = {
32
+ # 0D
33
+ "point": {"dim": 0, "curved": False, "curvature": "none"},
34
+ # 1D
35
+ "line_x": {"dim": 1, "curved": False, "curvature": "none"},
36
+ "line_y": {"dim": 1, "curved": False, "curvature": "none"},
37
+ "line_z": {"dim": 1, "curved": False, "curvature": "none"},
38
+ "line_diag": {"dim": 1, "curved": False, "curvature": "none"},
39
+ "cross": {"dim": 1, "curved": False, "curvature": "none"},
40
+ "l_shape": {"dim": 1, "curved": False, "curvature": "none"},
41
+ "collinear": {"dim": 1, "curved": False, "curvature": "none"},
42
+ # 2D flat
43
+ "triangle_xy": {"dim": 2, "curved": False, "curvature": "none"},
44
+ "triangle_xz": {"dim": 2, "curved": False, "curvature": "none"},
45
+ "triangle_3d": {"dim": 2, "curved": False, "curvature": "none"},
46
+ "square_xy": {"dim": 2, "curved": False, "curvature": "none"},
47
+ "square_xz": {"dim": 2, "curved": False, "curvature": "none"},
48
+ "rectangle": {"dim": 2, "curved": False, "curvature": "none"},
49
+ "coplanar": {"dim": 2, "curved": False, "curvature": "none"},
50
+ "plane": {"dim": 2, "curved": False, "curvature": "none"},
51
+ # 3D flat
52
+ "tetrahedron": {"dim": 3, "curved": False, "curvature": "none"},
53
+ "pyramid": {"dim": 3, "curved": False, "curvature": "none"},
54
+ "pentachoron": {"dim": 3, "curved": False, "curvature": "none"},
55
+ "cube": {"dim": 3, "curved": False, "curvature": "none"},
56
+ "cuboid": {"dim": 3, "curved": False, "curvature": "none"},
57
+ "triangular_prism": {"dim": 3, "curved": False, "curvature": "none"},
58
+ "octahedron": {"dim": 3, "curved": False, "curvature": "none"},
59
+ # 1D curved
60
+ "arc": {"dim": 1, "curved": True, "curvature": "convex"},
61
+ "helix": {"dim": 1, "curved": True, "curvature": "helical"},
62
+ # 2D curved
63
+ "circle": {"dim": 2, "curved": True, "curvature": "convex"},
64
+ "ellipse": {"dim": 2, "curved": True, "curvature": "convex"},
65
+ "disc": {"dim": 2, "curved": True, "curvature": "convex"},
66
+ # 3D curved
67
+ "sphere": {"dim": 3, "curved": True, "curvature": "convex"},
68
+ "hemisphere": {"dim": 3, "curved": True, "curvature": "convex"},
69
+ "cylinder": {"dim": 3, "curved": True, "curvature": "cylindrical"},
70
+ "cone": {"dim": 3, "curved": True, "curvature": "conical"},
71
+ "capsule": {"dim": 3, "curved": True, "curvature": "convex"},
72
+ "torus": {"dim": 3, "curved": True, "curvature": "toroidal"},
73
+ "shell": {"dim": 3, "curved": True, "curvature": "convex"},
74
+ "tube": {"dim": 3, "curved": True, "curvature": "cylindrical"},
75
+ "bowl": {"dim": 3, "curved": True, "curvature": "concave"},
76
+ "saddle": {"dim": 3, "curved": True, "curvature": "hyperbolic"},
77
+ }
78
+
79
+ CLASS_NAMES = list(CLASS_META.keys())
80
+ NUM_CLASSES = len(CLASS_NAMES)
81
+ CLASS_TO_IDX = {n: i for i, n in enumerate(CLASS_NAMES)}
82
+
83
+ CURVATURE_NAMES = ["none", "convex", "concave", "cylindrical",
84
+ "conical", "toroidal", "hyperbolic", "helical"]
85
+ CURV_TO_IDX = {n: i for i, n in enumerate(CURVATURE_NAMES)}
86
+
87
+ # Edge topology for rasterized shapes
88
+ TRIANGLE_EDGES = [(0,1), (1,2), (2,0)]
89
+ QUAD_EDGES = [(0,1), (1,3), (3,2), (2,0)]
90
+ TETRA_EDGES = list(combinations(range(4), 2)) # 6 edges
91
+ CUBE_EDGES = [(0,1),(0,2),(0,4),(1,3),(1,5),(2,3),(2,6),(3,7),(4,5),(4,6),(5,7),(6,7)]
92
+ PYRAMID_EDGES = [(0,1),(1,3),(3,2),(2,0),(0,4),(1,4),(2,4),(3,4)]
93
+ PENTA_EDGES = list(combinations(range(5), 2)) # 10 edges
94
+ OCTA_EDGES = [(0,1),(0,2),(0,3),(0,4),(5,1),(5,2),(5,3),(5,4),(1,2),(2,3),(3,4),(4,1)]
95
+
96
+
97
+ def rasterize_line(p1, p2):
98
+ """Bresenham-style 3D line rasterization between two points."""
99
+ p1 = np.array(p1, dtype=float)
100
+ p2 = np.array(p2, dtype=float)
101
+ diff = p2 - p1
102
+ n_steps = max(int(np.max(np.abs(diff))) + 1, 2)
103
+ t = np.linspace(0, 1, n_steps)
104
+ pts = p1[None, :] + t[:, None] * diff[None, :]
105
+ pts = np.round(pts).astype(int)
106
+ # Clip to grid bounds
107
+ pts[:, 0] = np.clip(pts[:, 0], 0, GZ - 1)
108
+ pts[:, 1] = np.clip(pts[:, 1], 0, GY - 1)
109
+ pts[:, 2] = np.clip(pts[:, 2], 0, GX - 1)
110
+ return np.unique(pts, axis=0)
111
+
112
+
113
+ def rasterize_edges(vertices, edges):
114
+ """Rasterize a complete wireframe from vertex list and edge topology."""
115
+ all_pts = [vertices]
116
+ for i, j in edges:
117
+ all_pts.append(rasterize_line(vertices[i], vertices[j]))
118
+ return np.unique(np.vstack(all_pts), axis=0)
119
+
120
+
121
+ class ShapeGenerator:
122
+ def __init__(self, seed=42):
123
+ self.rng = np.random.RandomState(seed)
124
+
125
+ def _pts_to_result(self, pts):
126
+ """Convert point array to grid + metadata."""
127
+ pts = np.atleast_2d(pts).astype(int)
128
+ # Clip to grid
129
+ pts[:, 0] = np.clip(pts[:, 0], 0, GZ - 1)
130
+ pts[:, 1] = np.clip(pts[:, 1], 0, GY - 1)
131
+ pts[:, 2] = np.clip(pts[:, 2], 0, GX - 1)
132
+ pts = np.unique(pts, axis=0)
133
+ grid = np.zeros(GRID_SHAPE, dtype=np.float32)
134
+ grid[pts[:, 0], pts[:, 1], pts[:, 2]] = 1.0
135
+ return {"grid": grid, "n_occupied": int(pts.shape[0]), "points": pts}
136
+
137
+ def _rand_center(self, margin_z=2, margin_yx=3):
138
+ """Random center respecting aspect ratio margins."""
139
+ cz = self.rng.uniform(margin_z, GZ - margin_z)
140
+ cy = self.rng.uniform(margin_yx, GY - margin_yx)
141
+ cx = self.rng.uniform(margin_yx, GX - margin_yx)
142
+ return np.array([cz, cy, cx])
143
+
144
+ def _rand_pts_2d(self, n, min_dist=2):
145
+ """Random 2D points in YX plane."""
146
+ for _ in range(100):
147
+ pts = np.column_stack([
148
+ self.rng.randint(1, GY - 1, n),
149
+ self.rng.randint(1, GX - 1, n)])
150
+ dists = [np.linalg.norm(pts[i] - pts[j])
151
+ for i in range(n) for j in range(i+1, n)]
152
+ if all(d >= min_dist for d in dists):
153
+ return pts
154
+ return None
155
+
156
+ def _rand_pts_3d(self, n, min_dist=2):
157
+ """Random 3D points respecting grid bounds."""
158
+ for _ in range(100):
159
+ pts = np.column_stack([
160
+ self.rng.randint(0, GZ, n),
161
+ self.rng.randint(1, GY - 1, n),
162
+ self.rng.randint(1, GX - 1, n)])
163
+ dists = [np.linalg.norm(pts[i] - pts[j])
164
+ for i in range(n) for j in range(i+1, n)]
165
+ if all(d >= min_dist for d in dists):
166
+ return pts
167
+ return None
168
+
169
+ def _rigid(self, name):
170
+ """Generate rotation axis for rigid shapes."""
171
+ axes = [(1,0,0), (0,1,0), (0,0,1)]
172
+ return axes[self.rng.randint(len(axes))]
173
+
174
+ def _make(self, name):
175
+ rng = self.rng
176
+
177
+ # === 0D ===
178
+ if name == "point":
179
+ z = rng.randint(0, GZ)
180
+ y = rng.randint(0, GY)
181
+ x = rng.randint(0, GX)
182
+ return self._pts_to_result(np.array([[z, y, x]]))
183
+
184
+ # === 1D lines ===
185
+ elif name == "line_x":
186
+ z = rng.randint(0, GZ)
187
+ y = rng.randint(0, GY)
188
+ x1, x2 = sorted(rng.choice(GX, 2, replace=False))
189
+ pts = np.array([[z, y, x1], [z, y, x2]])
190
+ return self._pts_to_result(rasterize_line(pts[0], pts[1]))
191
+
192
+ elif name == "line_y":
193
+ z = rng.randint(0, GZ)
194
+ x = rng.randint(0, GX)
195
+ y1, y2 = sorted(rng.choice(GY, 2, replace=False))
196
+ pts = np.array([[z, y1, x], [z, y2, x]])
197
+ return self._pts_to_result(rasterize_line(pts[0], pts[1]))
198
+
199
+ elif name == "line_z":
200
+ y = rng.randint(0, GY)
201
+ x = rng.randint(0, GX)
202
+ z1, z2 = sorted(rng.choice(GZ, 2, replace=False))
203
+ pts = np.array([[z1, y, x], [z2, y, x]])
204
+ return self._pts_to_result(rasterize_line(pts[0], pts[1]))
205
+
206
+ elif name == "line_diag":
207
+ p1 = np.array([rng.randint(0, GZ), rng.randint(0, GY), rng.randint(0, GX)])
208
+ p2 = np.array([rng.randint(0, GZ), rng.randint(0, GY), rng.randint(0, GX)])
209
+ if np.linalg.norm(p1 - p2) < 3:
210
+ return None
211
+ return self._pts_to_result(rasterize_line(p1, p2))
212
+
213
+ elif name == "cross":
214
+ c = self._rand_center(margin_z=1, margin_yx=2)
215
+ arm_yx = rng.randint(2, min(6, GY // 2))
216
+ arm_z = rng.randint(1, min(3, GZ // 2))
217
+ pts = []
218
+ ci = np.round(c).astype(int)
219
+ # YX cross
220
+ for dy in range(-arm_yx, arm_yx + 1):
221
+ pts.append([ci[0], np.clip(ci[1] + dy, 0, GY-1), ci[2]])
222
+ for dx in range(-arm_yx, arm_yx + 1):
223
+ pts.append([ci[0], ci[1], np.clip(ci[2] + dx, 0, GX-1)])
224
+ # Z arm
225
+ for dz in range(-arm_z, arm_z + 1):
226
+ pts.append([np.clip(ci[0] + dz, 0, GZ-1), ci[1], ci[2]])
227
+ return self._pts_to_result(np.array(pts))
228
+
229
+ elif name == "l_shape":
230
+ c = self._rand_center(margin_z=1, margin_yx=2)
231
+ arm = rng.randint(2, min(5, GY // 2))
232
+ ci = np.round(c).astype(int)
233
+ pts = []
234
+ for dy in range(arm + 1):
235
+ pts.append([ci[0], np.clip(ci[1] + dy, 0, GY-1), ci[2]])
236
+ for dx in range(1, arm + 1):
237
+ pts.append([ci[0], ci[1], np.clip(ci[2] + dx, 0, GX-1)])
238
+ return self._pts_to_result(np.array(pts))
239
+
240
+ elif name == "collinear":
241
+ # Identity = 3 discrete points on a line, NOT a line segment
242
+ axis = rng.randint(3)
243
+ gs = [GZ, GY, GX]
244
+ vals = sorted(rng.choice(gs[axis], 3, replace=False))
245
+ fixed = [rng.randint(0, gs[(axis+1)%3]), rng.randint(0, gs[(axis+2)%3])]
246
+ pts = np.zeros((3, 3), dtype=int)
247
+ for i, v in enumerate(vals):
248
+ pts[i, axis] = v
249
+ pts[i, (axis + 1) % 3] = fixed[0]
250
+ pts[i, (axis + 2) % 3] = fixed[1]
251
+ return self._pts_to_result(pts)
252
+
253
+ # === 2D flat ===
254
+ elif name == "triangle_xy":
255
+ z = rng.randint(0, GZ)
256
+ pts2d = self._rand_pts_2d(3, min_dist=3)
257
+ if pts2d is None: return None
258
+ return self._pts_to_result(np.column_stack([np.full(3, z), pts2d]))
259
+
260
+ elif name == "triangle_xz":
261
+ y = rng.randint(0, GY)
262
+ for _ in range(50):
263
+ pts = np.column_stack([
264
+ rng.randint(0, GZ, 3),
265
+ np.full(3, y),
266
+ rng.randint(1, GX - 1, 3)])
267
+ dists = [np.linalg.norm(pts[i] - pts[j]) for i in range(3) for j in range(i+1, 3)]
268
+ if all(d >= 2 for d in dists):
269
+ return self._pts_to_result(pts)
270
+ return None
271
+
272
+ elif name == "triangle_3d":
273
+ verts = self._rand_pts_3d(3, min_dist=3)
274
+ if verts is None: return None
275
+ return self._pts_to_result(verts)
276
+
277
+ elif name == "square_xy":
278
+ z = rng.randint(0, GZ)
279
+ s = rng.randint(3, min(7, GY - 2))
280
+ cy, cx = rng.randint(s, GY - s), rng.randint(s, GX - s)
281
+ verts = np.array([
282
+ [z, cy - s, cx - s], [z, cy - s, cx + s],
283
+ [z, cy + s, cx - s], [z, cy + s, cx + s]])
284
+ return self._pts_to_result(rasterize_edges(verts, QUAD_EDGES))
285
+
286
+ elif name == "square_xz":
287
+ y = rng.randint(0, GY)
288
+ s_z = rng.randint(1, min(3, GZ // 2))
289
+ s_x = rng.randint(2, min(6, GX // 2))
290
+ cz, cx = rng.randint(s_z, GZ - s_z), rng.randint(s_x, GX - s_x)
291
+ verts = np.array([
292
+ [cz - s_z, y, cx - s_x], [cz - s_z, y, cx + s_x],
293
+ [cz + s_z, y, cx - s_x], [cz + s_z, y, cx + s_x]])
294
+ return self._pts_to_result(rasterize_edges(verts, QUAD_EDGES))
295
+
296
+ elif name == "rectangle":
297
+ z = rng.randint(0, GZ)
298
+ sy = rng.randint(2, min(6, GY // 2))
299
+ sx = rng.randint(2, min(6, GX // 2))
300
+ while abs(sy - sx) < 2:
301
+ sy = rng.randint(2, min(6, GY // 2))
302
+ sx = rng.randint(2, min(6, GX // 2))
303
+ cy, cx = rng.randint(sy, GY - sy), rng.randint(sx, GX - sx)
304
+ verts = np.array([
305
+ [z, cy - sy, cx - sx], [z, cy - sy, cx + sx],
306
+ [z, cy + sy, cx - sx], [z, cy + sy, cx + sx]])
307
+ return self._pts_to_result(rasterize_edges(verts, QUAD_EDGES))
308
+
309
+ elif name == "coplanar":
310
+ # Identity = 4 discrete coplanar points, NOT a quadrilateral
311
+ pts = self._rand_pts_3d(4, min_dist=2)
312
+ if pts is None: return None
313
+ axis = rng.randint(3)
314
+ pts[:, axis] = pts[0, axis]
315
+ return self._pts_to_result(pts)
316
+
317
+ elif name == "plane":
318
+ axis = rng.randint(3)
319
+ gs = [GZ, GY, GX]
320
+ pos = rng.randint(0, gs[axis])
321
+ thick = rng.randint(1, max(2, gs[axis] // 4) + 1)
322
+ mask = np.zeros(GRID_SHAPE, dtype=np.float32)
323
+ for t in range(thick):
324
+ p = min(pos + t, gs[axis] - 1)
325
+ if axis == 0: mask[p, :, :] = 1
326
+ elif axis == 1: mask[:, p, :] = 1
327
+ else: mask[:, :, p] = 1
328
+ pts = np.argwhere(mask > 0)
329
+ return self._pts_to_result(pts)
330
+
331
+ # === 3D polyhedra (rasterized edges) ===
332
+ elif name == "tetrahedron":
333
+ verts = self._rand_pts_3d(4, min_dist=3)
334
+ if verts is None: return None
335
+ return self._pts_to_result(rasterize_edges(verts, TETRA_EDGES))
336
+
337
+ elif name == "pyramid":
338
+ base_y = rng.randint(2, GY - 2)
339
+ s = rng.randint(2, min(5, GX // 2))
340
+ cy, cx = rng.randint(s + 1, GY - s - 1), rng.randint(s + 1, GX - s - 1)
341
+ base_z = rng.randint(1, GZ - 2)
342
+ apex_z = rng.randint(0, GZ) if rng.random() < 0.5 else base_z + rng.randint(2, min(4, GZ - base_z))
343
+ apex_z = min(apex_z, GZ - 1)
344
+ verts = np.array([
345
+ [base_z, cy - s, cx - s], [base_z, cy - s, cx + s],
346
+ [base_z, cy + s, cx - s], [base_z, cy + s, cx + s],
347
+ [apex_z, cy, cx]])
348
+ return self._pts_to_result(rasterize_edges(verts, PYRAMID_EDGES))
349
+
350
+ elif name == "pentachoron":
351
+ verts = self._rand_pts_3d(5, min_dist=3)
352
+ if verts is None: return None
353
+ return self._pts_to_result(rasterize_edges(verts, PENTA_EDGES))
354
+
355
+ elif name == "cube":
356
+ s_z = rng.randint(1, min(3, GZ // 2))
357
+ s_yx = rng.randint(2, min(5, GY // 2))
358
+ c = self._rand_center(margin_z=s_z + 1, margin_yx=s_yx + 1)
359
+ ci = np.round(c).astype(int)
360
+ verts = np.array([
361
+ [ci[0]-s_z, ci[1]-s_yx, ci[2]-s_yx],
362
+ [ci[0]-s_z, ci[1]-s_yx, ci[2]+s_yx],
363
+ [ci[0]-s_z, ci[1]+s_yx, ci[2]-s_yx],
364
+ [ci[0]-s_z, ci[1]+s_yx, ci[2]+s_yx],
365
+ [ci[0]+s_z, ci[1]-s_yx, ci[2]-s_yx],
366
+ [ci[0]+s_z, ci[1]-s_yx, ci[2]+s_yx],
367
+ [ci[0]+s_z, ci[1]+s_yx, ci[2]-s_yx],
368
+ [ci[0]+s_z, ci[1]+s_yx, ci[2]+s_yx]])
369
+ return self._pts_to_result(rasterize_edges(verts, CUBE_EDGES))
370
+
371
+ elif name == "cuboid":
372
+ sz = rng.randint(1, min(3, GZ // 2))
373
+ sy = rng.randint(2, min(6, GY // 2))
374
+ sx = rng.randint(2, min(6, GX // 2))
375
+ # Ensure at least one dimension differs significantly
376
+ while abs(sy - sx) < 2 and abs(sz * 2 - sy) < 2:
377
+ sy = rng.randint(2, min(6, GY // 2))
378
+ sx = rng.randint(2, min(6, GX // 2))
379
+ c = self._rand_center(margin_z=sz + 1, margin_yx=max(sy, sx) + 1)
380
+ ci = np.round(c).astype(int)
381
+ verts = np.array([
382
+ [ci[0]-sz, ci[1]-sy, ci[2]-sx], [ci[0]-sz, ci[1]-sy, ci[2]+sx],
383
+ [ci[0]-sz, ci[1]+sy, ci[2]-sx], [ci[0]-sz, ci[1]+sy, ci[2]+sx],
384
+ [ci[0]+sz, ci[1]-sy, ci[2]-sx], [ci[0]+sz, ci[1]-sy, ci[2]+sx],
385
+ [ci[0]+sz, ci[1]+sy, ci[2]-sx], [ci[0]+sz, ci[1]+sy, ci[2]+sx]])
386
+ return self._pts_to_result(rasterize_edges(verts, CUBE_EDGES))
387
+
388
+ elif name == "triangular_prism":
389
+ z1, z2 = sorted(rng.choice(GZ, 2, replace=False))
390
+ pts2d = self._rand_pts_2d(3, min_dist=3)
391
+ if pts2d is None: return None
392
+ # Two triangular faces + connecting edges
393
+ top = np.column_stack([np.full(3, z1), pts2d])
394
+ bot = np.column_stack([np.full(3, z2), pts2d])
395
+ verts = np.vstack([top, bot])
396
+ edges = [(0,1),(1,2),(2,0),(3,4),(4,5),(5,3),(0,3),(1,4),(2,5)]
397
+ return self._pts_to_result(rasterize_edges(verts, edges))
398
+
399
+ elif name == "octahedron":
400
+ c = self._rand_center(margin_z=2, margin_yx=3)
401
+ ci = np.round(c).astype(int)
402
+ rz = rng.randint(1, min(3, GZ // 2))
403
+ ryx = rng.randint(2, min(5, GY // 2))
404
+ verts = np.array([
405
+ [ci[0], ci[1] + ryx, ci[2]], [ci[0], ci[1], ci[2] + ryx],
406
+ [ci[0], ci[1] - ryx, ci[2]], [ci[0], ci[1], ci[2] - ryx],
407
+ [ci[0] + rz, ci[1], ci[2]], [ci[0] - rz, ci[1], ci[2]]])
408
+ return self._pts_to_result(rasterize_edges(verts, OCTA_EDGES))
409
+
410
+ # === 1D curved ===
411
+ elif name == "arc":
412
+ plane = rng.randint(3)
413
+ c = self._rand_center(margin_z=1, margin_yx=2)
414
+ r_main = rng.uniform(2.0, min(5.0, GY / 2 - 1))
415
+ r_z = rng.uniform(1.0, min(3.0, GZ / 2 - 1)) if plane != 0 else r_main
416
+ angle_start = rng.uniform(0, np.pi)
417
+ angle_span = rng.uniform(np.pi / 3, np.pi)
418
+ n_pts = max(8, int(angle_span * r_main))
419
+ t = np.linspace(angle_start, angle_start + angle_span, n_pts)
420
+ if plane == 0: # YX plane
421
+ pts = np.column_stack([
422
+ np.full(n_pts, c[0]),
423
+ c[1] + r_main * np.cos(t),
424
+ c[2] + r_main * np.sin(t)])
425
+ elif plane == 1: # ZX plane
426
+ pts = np.column_stack([
427
+ c[0] + r_z * np.cos(t),
428
+ np.full(n_pts, c[1]),
429
+ c[2] + r_main * np.sin(t)])
430
+ else: # ZY plane
431
+ pts = np.column_stack([
432
+ c[0] + r_z * np.cos(t),
433
+ c[1] + r_main * np.sin(t),
434
+ np.full(n_pts, c[2])])
435
+ pts = np.round(pts).astype(int)
436
+ return self._pts_to_result(pts)
437
+
438
+ elif name == "helix":
439
+ c = self._rand_center(margin_z=0, margin_yx=3)
440
+ r = rng.uniform(1.5, min(4.0, GY / 2 - 2))
441
+ turns = rng.uniform(1.0, 2.5)
442
+ n_pts = int(turns * 20)
443
+ t = np.linspace(0, turns * 2 * np.pi, n_pts)
444
+ z_span = GZ - 1
445
+ pts = np.column_stack([
446
+ t / (turns * 2 * np.pi) * z_span,
447
+ c[1] + r * np.cos(t),
448
+ c[2] + r * np.sin(t)])
449
+ pts = np.round(pts).astype(int)
450
+ return self._pts_to_result(pts)
451
+
452
+ # === 2D curved ===
453
+ elif name == "circle":
454
+ plane = rng.randint(3)
455
+ c = self._rand_center(margin_z=1, margin_yx=3)
456
+ r = rng.uniform(2.0, min(5.0, GY / 2 - 1))
457
+ n_pts = max(12, int(2 * np.pi * r))
458
+ t = np.linspace(0, 2 * np.pi, n_pts, endpoint=False)
459
+ if plane == 0:
460
+ pts = np.column_stack([
461
+ np.full(n_pts, c[0]),
462
+ c[1] + r * np.cos(t),
463
+ c[2] + r * np.sin(t)])
464
+ elif plane == 1:
465
+ r_z = min(r, GZ / 2 - 1)
466
+ pts = np.column_stack([
467
+ c[0] + r_z * np.cos(t),
468
+ np.full(n_pts, c[1]),
469
+ c[2] + r * np.sin(t)])
470
+ else:
471
+ r_z = min(r, GZ / 2 - 1)
472
+ pts = np.column_stack([
473
+ c[0] + r_z * np.cos(t),
474
+ c[1] + r * np.sin(t),
475
+ np.full(n_pts, c[2])])
476
+ pts = np.round(pts).astype(int)
477
+ return self._pts_to_result(pts)
478
+
479
+ elif name == "ellipse":
480
+ c = self._rand_center(margin_z=1, margin_yx=3)
481
+ ry = rng.uniform(2.0, min(5.0, GY / 2 - 1))
482
+ ratio = rng.uniform(1.6, 2.5)
483
+ if rng.random() < 0.5:
484
+ rx = ry / ratio
485
+ else:
486
+ rx = ry * ratio
487
+ rx = min(rx, GX / 2 - 1)
488
+ if rx / ry < 1.6: ry = rx / 1.6
489
+ n_pts = max(16, int(2 * np.pi * max(rx, ry)))
490
+ t = np.linspace(0, 2 * np.pi, n_pts, endpoint=False)
491
+ pts = np.column_stack([
492
+ np.full(n_pts, c[0]),
493
+ c[1] + ry * np.cos(t),
494
+ c[2] + rx * np.sin(t)])
495
+ pts = np.round(pts).astype(int)
496
+ return self._pts_to_result(pts)
497
+
498
+ elif name == "disc":
499
+ plane = rng.randint(3)
500
+ c = self._rand_center(margin_z=1, margin_yx=3)
501
+ r = rng.uniform(2.0, min(5.0, GY / 2 - 1))
502
+ if plane == 0:
503
+ mask = ((_COORDS[:, 1] - c[1])**2 + (_COORDS[:, 2] - c[2])**2 <= r**2) & \
504
+ (np.abs(_COORDS[:, 0] - c[0]) < 0.6)
505
+ elif plane == 1:
506
+ r_z = min(r, GZ / 2 - 1)
507
+ mask = ((_COORDS[:, 0] - c[0])**2 / max(r_z, 0.5)**2 +
508
+ (_COORDS[:, 2] - c[2])**2 / r**2 <= 1) & \
509
+ (np.abs(_COORDS[:, 1] - c[1]) < 0.6)
510
+ else:
511
+ r_z = min(r, GZ / 2 - 1)
512
+ mask = ((_COORDS[:, 0] - c[0])**2 / max(r_z, 0.5)**2 +
513
+ (_COORDS[:, 1] - c[1])**2 / r**2 <= 1) & \
514
+ (np.abs(_COORDS[:, 2] - c[2]) < 0.6)
515
+ pts = _COORDS[mask].astype(int)
516
+ if len(pts) < 3: return None
517
+ return self._pts_to_result(pts)
518
+
519
+ # === 3D curved ===
520
+ elif name == "sphere":
521
+ c = self._rand_center(margin_z=2, margin_yx=3)
522
+ r = rng.uniform(2.0, min(3.5, GZ / 2 - 0.5, GY / 2 - 1))
523
+ # Use ellipsoidal check respecting aspect ratio
524
+ d2 = ((_COORDS[:, 0] - c[0]) / r)**2 + \
525
+ ((_COORDS[:, 1] - c[1]) / r)**2 + \
526
+ ((_COORDS[:, 2] - c[2]) / r)**2
527
+ mask = d2 <= 1.0
528
+ pts = _COORDS[mask].astype(int)
529
+ if len(pts) < 4: return None
530
+ return self._pts_to_result(pts)
531
+
532
+ elif name == "hemisphere":
533
+ c = self._rand_center(margin_z=2, margin_yx=3)
534
+ r = rng.uniform(2.0, min(3.5, GZ / 2 - 0.5, GY / 2 - 1))
535
+ cut_axis = rng.randint(3)
536
+ d2 = ((_COORDS[:, 0] - c[0]) / r)**2 + \
537
+ ((_COORDS[:, 1] - c[1]) / r)**2 + \
538
+ ((_COORDS[:, 2] - c[2]) / r)**2
539
+ mask = d2 <= 1.0
540
+ if cut_axis == 0: mask &= _COORDS[:, 0] >= c[0]
541
+ elif cut_axis == 1: mask &= _COORDS[:, 1] >= c[1]
542
+ else: mask &= _COORDS[:, 2] >= c[2]
543
+ pts = _COORDS[mask].astype(int)
544
+ if len(pts) < 3: return None
545
+ return self._pts_to_result(pts)
546
+
547
+ elif name == "cylinder":
548
+ axis = rng.randint(3)
549
+ c = self._rand_center(margin_z=0, margin_yx=3)
550
+ r = rng.uniform(1.5, min(3.0, GY / 2 - 1))
551
+ if axis == 0:
552
+ d2 = (_COORDS[:, 1] - c[1])**2 + (_COORDS[:, 2] - c[2])**2
553
+ mask = d2 <= r**2
554
+ elif axis == 1:
555
+ r_z = min(r, GZ / 2 - 0.5)
556
+ d2 = (_COORDS[:, 0] - c[0])**2 / max(r_z, 0.5)**2 + \
557
+ (_COORDS[:, 2] - c[2])**2 / r**2
558
+ mask = d2 <= 1.0
559
+ else:
560
+ r_z = min(r, GZ / 2 - 0.5)
561
+ d2 = (_COORDS[:, 0] - c[0])**2 / max(r_z, 0.5)**2 + \
562
+ (_COORDS[:, 1] - c[1])**2 / r**2
563
+ mask = d2 <= 1.0
564
+ pts = _COORDS[mask].astype(int)
565
+ if len(pts) < 4: return None
566
+ return self._pts_to_result(pts)
567
+
568
+ elif name == "cone":
569
+ axis = rng.randint(3)
570
+ c = self._rand_center(margin_z=1, margin_yx=3)
571
+ r = rng.uniform(2.0, min(4.0, GY / 2 - 1))
572
+ gs = [GZ, GY, GX]
573
+ h = gs[axis] - 1
574
+ apex_frac = _COORDS[:, axis] / max(h, 1)
575
+ local_r = r * (1.0 - apex_frac)
576
+ if axis == 0:
577
+ d2 = (_COORDS[:, 1] - c[1])**2 + (_COORDS[:, 2] - c[2])**2
578
+ elif axis == 1:
579
+ d2 = (_COORDS[:, 0] - c[0])**2 + (_COORDS[:, 2] - c[2])**2
580
+ else:
581
+ d2 = (_COORDS[:, 0] - c[0])**2 + (_COORDS[:, 1] - c[1])**2
582
+ mask = d2 <= local_r**2
583
+ pts = _COORDS[mask].astype(int)
584
+ if len(pts) < 4: return None
585
+ return self._pts_to_result(pts)
586
+
587
+ elif name == "capsule":
588
+ axis = rng.randint(3)
589
+ c = self._rand_center(margin_z=1, margin_yx=3)
590
+ r = rng.uniform(1.5, min(2.5, GZ / 2 - 0.5, GY / 2 - 1))
591
+ gs = [GZ, GY, GX]
592
+ half_h = rng.uniform(1.0, gs[axis] / 2 - r - 0.5)
593
+ # Cylinder body + spherical caps
594
+ dist_axis = np.abs(_COORDS[:, axis] - c[axis])
595
+ clamped = np.clip(dist_axis - half_h, 0, None)
596
+ perp_axes = [i for i in range(3) if i != axis]
597
+ d2 = clamped**2
598
+ for a in perp_axes:
599
+ d2 += (_COORDS[:, a] - c[a])**2
600
+ mask = d2 <= r**2
601
+ pts = _COORDS[mask].astype(int)
602
+ if len(pts) < 4: return None
603
+ return self._pts_to_result(pts)
604
+
605
+ elif name == "torus":
606
+ c = self._rand_center(margin_z=2, margin_yx=4)
607
+ R = rng.uniform(2.5, min(4.0, GY / 2 - 2))
608
+ r = rng.uniform(0.8, min(1.5, GZ / 2 - 0.5, R * 0.5))
609
+ # Torus in YX plane
610
+ d_yx = np.sqrt((_COORDS[:, 1] - c[1])**2 + (_COORDS[:, 2] - c[2])**2)
611
+ d2 = (d_yx - R)**2 + (_COORDS[:, 0] - c[0])**2
612
+ mask = d2 <= r**2
613
+ pts = _COORDS[mask].astype(int)
614
+ if len(pts) < 4: return None
615
+ return self._pts_to_result(pts)
616
+
617
+ elif name == "shell":
618
+ c = self._rand_center(margin_z=1, margin_yx=3)
619
+ r = rng.uniform(2.0, min(3.5, GZ / 2 - 0.5, GY / 2 - 1))
620
+ thick = rng.uniform(0.4, 0.8)
621
+ d2 = ((_COORDS[:, 0] - c[0]) / r)**2 + \
622
+ ((_COORDS[:, 1] - c[1]) / r)**2 + \
623
+ ((_COORDS[:, 2] - c[2]) / r)**2
624
+ mask = (d2 <= 1.0) & (d2 >= (1.0 - thick)**2)
625
+ pts = _COORDS[mask].astype(int)
626
+ if len(pts) < 4: return None
627
+ return self._pts_to_result(pts)
628
+
629
+ elif name == "tube":
630
+ axis = rng.randint(3)
631
+ c = self._rand_center(margin_z=0, margin_yx=3)
632
+ r_out = rng.uniform(2.0, min(3.5, GY / 2 - 1))
633
+ r_in = r_out * rng.uniform(0.4, 0.7)
634
+ if axis == 0:
635
+ d2 = (_COORDS[:, 1] - c[1])**2 + (_COORDS[:, 2] - c[2])**2
636
+ elif axis == 1:
637
+ d2 = (_COORDS[:, 0] - c[0])**2 + (_COORDS[:, 2] - c[2])**2
638
+ else:
639
+ d2 = (_COORDS[:, 0] - c[0])**2 + (_COORDS[:, 1] - c[1])**2
640
+ mask = (d2 <= r_out**2) & (d2 >= r_in**2)
641
+ pts = _COORDS[mask].astype(int)
642
+ if len(pts) < 4: return None
643
+ return self._pts_to_result(pts)
644
+
645
+ elif name == "bowl":
646
+ c = self._rand_center(margin_z=1, margin_yx=3)
647
+ r = rng.uniform(2.0, min(3.5, GZ / 2 - 0.5, GY / 2 - 1))
648
+ thick = rng.uniform(0.3, 0.7)
649
+ d2 = ((_COORDS[:, 0] - c[0]) / r)**2 + \
650
+ ((_COORDS[:, 1] - c[1]) / r)**2 + \
651
+ ((_COORDS[:, 2] - c[2]) / r)**2
652
+ mask = (d2 <= 1.0) & (d2 >= (1.0 - thick)**2) & (_COORDS[:, 0] <= c[0])
653
+ pts = _COORDS[mask].astype(int)
654
+ if len(pts) < 3: return None
655
+ return self._pts_to_result(pts)
656
+
657
+ elif name == "saddle":
658
+ c = self._rand_center(margin_z=2, margin_yx=4)
659
+ scale = rng.uniform(1.5, 3.0)
660
+ dy = (_COORDS[:, 1] - c[1]) / scale
661
+ dx = (_COORDS[:, 2] - c[2]) / scale
662
+ z_saddle = c[0] + (dy**2 - dx**2)
663
+ mask = np.abs(_COORDS[:, 0] - z_saddle) < 0.8
664
+ pts = _COORDS[mask].astype(int)
665
+ if len(pts) < 4: return None
666
+ return self._pts_to_result(pts)
667
+
668
+ return None
669
+
670
+ def generate(self, name, max_retries=10):
671
+ """Generate one sample with retries."""
672
+ for _ in range(max_retries):
673
+ result = self._make(name)
674
+ if result is not None and result["n_occupied"] > 0:
675
+ return result
676
+ return None
677
+
678
+ def generate_dataset(self, n_per_class, seed=None):
679
+ """Generate balanced dataset."""
680
+ if seed is not None:
681
+ self.rng = np.random.RandomState(seed)
682
+ grids, labels, dims, curveds = [], [], [], []
683
+ for cls_idx, name in enumerate(CLASS_NAMES):
684
+ meta = CLASS_META[name]
685
+ count = 0
686
+ while count < n_per_class:
687
+ self.rng = np.random.RandomState(seed * 1000 + cls_idx * n_per_class + count if seed else None)
688
+ result = self.generate(name)
689
+ if result is not None:
690
+ grids.append(result["grid"])
691
+ labels.append(cls_idx)
692
+ dims.append(meta["dim"])
693
+ curveds.append(1 if meta["curved"] else 0)
694
+ count += 1
695
+ return {
696
+ "grids": np.array(grids),
697
+ "labels": np.array(labels),
698
+ "dims": np.array(dims),
699
+ "curveds": np.array(curveds),
700
+ }
701
+
702
+
703
+ # === Verification =============================================================
704
+ if __name__ == "__main__":
705
+ gen = ShapeGenerator(seed=42)
706
+ print(f"Grid: {GZ}×{GY}×{GX} = {GRID_VOLUME} voxels")
707
+ print(f"Classes: {NUM_CLASSES}")
708
+ print(f"\n{'Shape':20s} {'OK':>4s} {'Avg vox':>8s}")
709
+ print("-" * 36)
710
+ for name in CLASS_NAMES:
711
+ ok = 0; voxels = []
712
+ for trial in range(20):
713
+ gen.rng = np.random.RandomState(trial * 100 + hash(name) % 10000)
714
+ s = gen.generate(name)
715
+ if s:
716
+ ok += 1
717
+ voxels.append(s["n_occupied"])
718
+ avg = np.mean(voxels) if voxels else 0
719
+ status = "✓" if ok >= 15 else "✗"
720
+ print(f" {status} {name:20s} {ok:2d}/20 {avg:7.1f}")
721
+
722
+ print(f'\nLoaded {NUM_CLASSES} shape classes, grid={GZ}×{GY}×{GX}')