AbstractPhil commited on
Commit
4706821
·
verified ·
1 Parent(s): c73866e

Create data_generator.py

Browse files
Files changed (1) hide show
  1. data_generator.py +856 -0
data_generator.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 3D Voxel Shape Classifier — Complete Geometric Primitive Vocabulary
3
+ 5×5×5 binary voxel grid → rigid cascade → curvature analysis → classify
4
+
5
+ 38 shape classes covering:
6
+ - Rigid 0D-3D: points, lines, joints, triangles, quads, polyhedra, prisms
7
+ - Curved 1D: arcs, helices
8
+ - Curved 2D: circles, ellipses, discs
9
+ - Curved 3D solid: sphere, hemisphere, cylinder, cone, capsule, torus
10
+ - Curved 3D hollow: shell, tube
11
+ - Curved 3D open: bowl (concave), saddle (hyperbolic)
12
+
13
+ Curvature types: none, convex, concave, cylindrical, conical, toroidal, hyperbolic, helical
14
+ """
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from typing import Optional
21
+ import math
22
+ from itertools import combinations
23
+
24
+
25
+ # === SwiGLU Activation =======================================================
26
+
27
+ class SwiGLU(nn.Module):
28
+ """
29
+ SwiGLU activation: out = (x @ W1) * SiLU(x @ W2)
30
+
31
+ SiLU(x) = x * sigmoid(x), aka Swish — the "Swi" in SwiGLU.
32
+ Unlike plain sigmoid gating, SiLU preserves gradient magnitude
33
+ through the gate branch while maintaining sharp gating behavior.
34
+
35
+ Used at geometric decision points where crisp on/off transitions
36
+ matter more than smooth interpolation.
37
+ """
38
+
39
+ def __init__(self, in_dim, out_dim):
40
+ super().__init__()
41
+ self.w1 = nn.Linear(in_dim, out_dim)
42
+ self.w2 = nn.Linear(in_dim, out_dim)
43
+
44
+ def forward(self, x):
45
+ return self.w1(x) * F.silu(self.w2(x))
46
+
47
+
48
+ # === Shape Catalog ===========================================================
49
+
50
+ SHAPE_CATALOG = {
51
+ # ---- Rigid 0D ----
52
+ "point": {"dim": 0, "curved": False, "curvature": "none"},
53
+
54
+ # ---- Rigid 1D: lines ----
55
+ "line_x": {"dim": 1, "curved": False, "curvature": "none"},
56
+ "line_y": {"dim": 1, "curved": False, "curvature": "none"},
57
+ "line_z": {"dim": 1, "curved": False, "curvature": "none"},
58
+ "line_diag": {"dim": 1, "curved": False, "curvature": "none"},
59
+
60
+ # ---- Rigid 1D: compounds ----
61
+ "cross": {"dim": 1, "curved": False, "curvature": "none"},
62
+ "l_shape": {"dim": 1, "curved": False, "curvature": "none"},
63
+ "collinear": {"dim": 1, "curved": False, "curvature": "none"},
64
+
65
+ # ---- Rigid 2D: triangles ----
66
+ "triangle_xy": {"dim": 2, "curved": False, "curvature": "none"},
67
+ "triangle_xz": {"dim": 2, "curved": False, "curvature": "none"},
68
+ "triangle_3d": {"dim": 2, "curved": False, "curvature": "none"},
69
+
70
+ # ---- Rigid 2D: quads ----
71
+ "square_xy": {"dim": 2, "curved": False, "curvature": "none"},
72
+ "square_xz": {"dim": 2, "curved": False, "curvature": "none"},
73
+ "rectangle": {"dim": 2, "curved": False, "curvature": "none"},
74
+ "coplanar": {"dim": 2, "curved": False, "curvature": "none"},
75
+
76
+ # ---- Rigid 2D: filled ----
77
+ "plane": {"dim": 2, "curved": False, "curvature": "none"},
78
+
79
+ # ---- Rigid 3D: simplices ----
80
+ "tetrahedron": {"dim": 3, "curved": False, "curvature": "none"},
81
+ "pyramid": {"dim": 3, "curved": False, "curvature": "none"},
82
+ "pentachoron": {"dim": 3, "curved": False, "curvature": "none"},
83
+
84
+ # ---- Rigid 3D: prisms/polyhedra ----
85
+ "cube": {"dim": 3, "curved": False, "curvature": "none"},
86
+ "cuboid": {"dim": 3, "curved": False, "curvature": "none"},
87
+ "triangular_prism": {"dim": 3, "curved": False, "curvature": "none"},
88
+ "octahedron": {"dim": 3, "curved": False, "curvature": "none"},
89
+
90
+ # ---- Curved 1D ----
91
+ "arc": {"dim": 1, "curved": True, "curvature": "convex"},
92
+ "helix": {"dim": 1, "curved": True, "curvature": "helical"},
93
+
94
+ # ---- Curved 2D: outlines ----
95
+ "circle": {"dim": 2, "curved": True, "curvature": "convex"},
96
+ "ellipse": {"dim": 2, "curved": True, "curvature": "convex"},
97
+
98
+ # ---- Curved 2D: filled ----
99
+ "disc": {"dim": 2, "curved": True, "curvature": "convex"},
100
+
101
+ # ---- Curved 3D: solid ----
102
+ "sphere": {"dim": 3, "curved": True, "curvature": "convex"},
103
+ "hemisphere": {"dim": 3, "curved": True, "curvature": "convex"},
104
+ "cylinder": {"dim": 3, "curved": True, "curvature": "cylindrical"},
105
+ "cone": {"dim": 3, "curved": True, "curvature": "conical"},
106
+ "capsule": {"dim": 3, "curved": True, "curvature": "convex"},
107
+ "torus": {"dim": 3, "curved": True, "curvature": "toroidal"},
108
+
109
+ # ---- Curved 3D: hollow ----
110
+ "shell": {"dim": 3, "curved": True, "curvature": "convex"},
111
+ "tube": {"dim": 3, "curved": True, "curvature": "cylindrical"},
112
+
113
+ # ---- Curved 3D: open surfaces ----
114
+ "bowl": {"dim": 3, "curved": True, "curvature": "concave"},
115
+ "saddle": {"dim": 3, "curved": True, "curvature": "hyperbolic"},
116
+ }
117
+
118
+ NUM_CLASSES = len(SHAPE_CATALOG)
119
+ CLASS_NAMES = list(SHAPE_CATALOG.keys())
120
+ CLASS_TO_IDX = {name: i for i, name in enumerate(CLASS_NAMES)}
121
+
122
+ CURVATURE_TYPES = ["none", "convex", "concave", "cylindrical", "conical",
123
+ "toroidal", "hyperbolic", "helical"]
124
+ CURV_TO_IDX = {c: i for i, c in enumerate(CURVATURE_TYPES)}
125
+ NUM_CURVATURES = len(CURVATURE_TYPES)
126
+
127
+ GS = 5 # grid size
128
+
129
+
130
+ # === Cayley-Menger Utilities =================================================
131
+
132
+ def cayley_menger_det(points: np.ndarray) -> float:
133
+ n = len(points)
134
+ D = np.zeros((n, n))
135
+ for i in range(n):
136
+ for j in range(n):
137
+ D[i, j] = np.sum((points[i] - points[j]) ** 2)
138
+ CM = np.zeros((n + 1, n + 1))
139
+ CM[0, 1:] = 1
140
+ CM[1:, 0] = 1
141
+ CM[1:, 1:] = D
142
+ return np.linalg.det(CM)
143
+
144
+
145
+ def simplex_volume(points: np.ndarray) -> float:
146
+ k = len(points)
147
+ if k < 2: return 0.0
148
+ cm = cayley_menger_det(points)
149
+ sign = (-1) ** k
150
+ denom = (2 ** (k - 1)) * (math.factorial(k - 1) ** 2)
151
+ v_sq = sign * cm / denom
152
+ return np.sqrt(max(0, v_sq))
153
+
154
+
155
+ def effective_volume(points: np.ndarray) -> float:
156
+ k = len(points)
157
+ if k < 2: return 0.0
158
+ if k == 2: return np.linalg.norm(points[0] - points[1])
159
+ if k >= 3:
160
+ max_a = 0
161
+ for idx in combinations(range(min(k, 8)), 3):
162
+ max_a = max(max_a, simplex_volume(points[list(idx)]))
163
+ if k < 4: return max_a
164
+ if k >= 4:
165
+ max_v = 0
166
+ for idx in combinations(range(min(k, 8)), 4):
167
+ max_v = max(max_v, simplex_volume(points[list(idx)]))
168
+ return max_v
169
+ return 0.0
170
+
171
+
172
+ # === Shape Generator =========================================================
173
+
174
+ class ShapeGenerator:
175
+ def __init__(self, seed=42):
176
+ self.rng = np.random.RandomState(seed)
177
+
178
+ def generate(self, n_samples: int) -> list:
179
+ samples = []
180
+ per_class = n_samples // NUM_CLASSES
181
+ for name in CLASS_NAMES:
182
+ count = 0
183
+ attempts = 0
184
+ while count < per_class and attempts < per_class * 5:
185
+ s = self._make(name)
186
+ attempts += 1
187
+ if s is not None:
188
+ samples.append(s)
189
+ count += 1
190
+ while len(samples) < n_samples:
191
+ name = self.rng.choice(CLASS_NAMES)
192
+ s = self._make(name)
193
+ if s is not None:
194
+ samples.append(s)
195
+ self.rng.shuffle(samples)
196
+ return samples[:n_samples]
197
+
198
+ def _make(self, name: str) -> Optional[dict]:
199
+ info = SHAPE_CATALOG[name]
200
+ if info["curved"]:
201
+ voxels = self._curved(name)
202
+ else:
203
+ voxels = self._rigid(name)
204
+ if voxels is None: return None
205
+ voxels = np.clip(voxels, 0, GS - 1).astype(int)
206
+ voxels = np.unique(voxels, axis=0)
207
+ if len(voxels) < 1: return None
208
+ return self._build(name, info, voxels)
209
+
210
+ # === Rigid Generators ===
211
+
212
+ def _rigid(self, name):
213
+ rng = self.rng
214
+
215
+ if name == "point":
216
+ return rng.randint(0, GS, size=(1, 3))
217
+
218
+ elif name == "line_x":
219
+ y, z = rng.randint(0, GS, size=2)
220
+ x1, x2 = sorted(rng.choice(GS, 2, replace=False))
221
+ return np.array([[x1, y, z], [x2, y, z]])
222
+
223
+ elif name == "line_y":
224
+ x, z = rng.randint(0, GS, size=2)
225
+ y1, y2 = sorted(rng.choice(GS, 2, replace=False))
226
+ return np.array([[x, y1, z], [x, y2, z]])
227
+
228
+ elif name == "line_z":
229
+ x, y = rng.randint(0, GS, size=2)
230
+ z1, z2 = sorted(rng.choice(GS, 2, replace=False))
231
+ return np.array([[x, y, z1], [x, y, z2]])
232
+
233
+ elif name == "line_diag":
234
+ p1 = rng.randint(0, 3, size=3)
235
+ step = rng.randint(1, 3)
236
+ direction = rng.choice([-1, 1], size=3)
237
+ if np.sum(direction != 0) < 2:
238
+ direction[rng.randint(3)] = rng.choice([-1, 1])
239
+ p2 = np.clip(p1 + step * direction, 0, GS - 1)
240
+ if np.array_equal(p1, p2):
241
+ p2 = np.clip(p1 + np.array([1, 1, 0]), 0, GS - 1)
242
+ return np.array([p1, p2])
243
+
244
+ elif name == "cross":
245
+ # Two perpendicular lines intersecting at a point
246
+ cx, cy, cz = rng.randint(1, GS - 1, size=3)
247
+ length = rng.randint(1, 3)
248
+ axis1, axis2 = rng.choice(3, 2, replace=False)
249
+ pts = [[cx, cy, cz]] # center
250
+ for sign in [-1, 1]:
251
+ p = [cx, cy, cz]
252
+ p[axis1] = np.clip(p[axis1] + sign * length, 0, GS - 1)
253
+ pts.append(list(p))
254
+ for sign in [-1, 1]:
255
+ p = [cx, cy, cz]
256
+ p[axis2] = np.clip(p[axis2] + sign * length, 0, GS - 1)
257
+ pts.append(list(p))
258
+ return np.array(pts)
259
+
260
+ elif name == "l_shape":
261
+ # Two lines meeting at a vertex (right angle)
262
+ corner = rng.randint(1, GS - 1, size=3)
263
+ axis1, axis2 = rng.choice(3, 2, replace=False)
264
+ len1 = rng.randint(1, 3)
265
+ len2 = rng.randint(1, 3)
266
+ dir1 = rng.choice([-1, 1])
267
+ dir2 = rng.choice([-1, 1])
268
+ pts = [list(corner)]
269
+ for i in range(1, len1 + 1):
270
+ p = list(corner)
271
+ p[axis1] = np.clip(p[axis1] + dir1 * i, 0, GS - 1)
272
+ pts.append(p)
273
+ for i in range(1, len2 + 1):
274
+ p = list(corner)
275
+ p[axis2] = np.clip(p[axis2] + dir2 * i, 0, GS - 1)
276
+ pts.append(p)
277
+ return np.array(pts)
278
+
279
+ elif name == "collinear":
280
+ axis = rng.randint(3)
281
+ fixed = rng.randint(0, GS, size=2)
282
+ vals = sorted(rng.choice(GS, 3, replace=False))
283
+ pts = np.zeros((3, 3), dtype=int)
284
+ for i, v in enumerate(vals):
285
+ pts[i, axis] = v
286
+ pts[i, (axis + 1) % 3] = fixed[0]
287
+ pts[i, (axis + 2) % 3] = fixed[1]
288
+ return pts
289
+
290
+ elif name == "triangle_xy":
291
+ z = rng.randint(0, GS)
292
+ pts = self._rand_pts_2d(3, min_dist=1)
293
+ if pts is None: return None
294
+ return np.column_stack([pts, np.full(3, z)])
295
+
296
+ elif name == "triangle_xz":
297
+ y = rng.randint(0, GS)
298
+ pts = self._rand_pts_2d(3, min_dist=1)
299
+ if pts is None: return None
300
+ return np.column_stack([pts[:, 0], np.full(3, y), pts[:, 1]])
301
+
302
+ elif name == "triangle_3d":
303
+ return self._rand_pts_3d(3, min_dist=1)
304
+
305
+ elif name == "square_xy":
306
+ z = rng.randint(0, GS)
307
+ x1, y1 = rng.randint(0, 3, size=2)
308
+ s = rng.randint(1, 3)
309
+ pts = np.array([[x1, y1, z], [x1 + s, y1, z],
310
+ [x1, y1 + s, z], [x1 + s, y1 + s, z]])
311
+ return np.clip(pts, 0, GS - 1)
312
+
313
+ elif name == "square_xz":
314
+ y = rng.randint(0, GS)
315
+ x1, z1 = rng.randint(0, 3, size=2)
316
+ s = rng.randint(1, 3)
317
+ pts = np.array([[x1, y, z1], [x1 + s, y, z1],
318
+ [x1, y, z1 + s], [x1 + s, y, z1 + s]])
319
+ return np.clip(pts, 0, GS - 1)
320
+
321
+ elif name == "rectangle":
322
+ axis = rng.randint(3)
323
+ val = rng.randint(0, GS)
324
+ a1, a2 = rng.randint(0, 3), rng.randint(0, 3)
325
+ w, h = rng.randint(1, 4), rng.randint(1, 3)
326
+ if w == h: w = min(GS - 1, w + 1)
327
+ c = np.array([[a1, a2], [a1 + w, a2], [a1, a2 + h], [a1 + w, a2 + h]])
328
+ c = np.clip(c, 0, GS - 1)
329
+ if axis == 0: return np.column_stack([np.full(4, val), c])
330
+ elif axis == 1: return np.column_stack([c[:, 0], np.full(4, val), c[:, 1]])
331
+ else: return np.column_stack([c, np.full(4, val)])
332
+
333
+ elif name == "coplanar":
334
+ pts = self._rand_pts_3d(4, min_dist=1)
335
+ if pts is None: return None
336
+ pts[:, rng.randint(3)] = pts[0, rng.randint(3)]
337
+ return pts
338
+
339
+ elif name == "plane":
340
+ # Filled rectangular slab, 1 voxel thick
341
+ axis = rng.randint(3)
342
+ val = rng.randint(0, GS)
343
+ a_start = rng.randint(0, 2)
344
+ b_start = rng.randint(0, 2)
345
+ a_size = rng.randint(2, GS - a_start + 1)
346
+ b_size = rng.randint(2, GS - b_start + 1)
347
+ pts = []
348
+ for a in range(a_start, min(GS, a_start + a_size)):
349
+ for b in range(b_start, min(GS, b_start + b_size)):
350
+ p = [0, 0, 0]
351
+ p[axis] = val
352
+ p[(axis + 1) % 3] = a
353
+ p[(axis + 2) % 3] = b
354
+ pts.append(p)
355
+ return np.array(pts) if len(pts) >= 4 else None
356
+
357
+ elif name == "tetrahedron":
358
+ pts = self._rand_pts_3d(4, min_dist=1)
359
+ if pts is None: return None
360
+ centered = pts - pts.mean(axis=0)
361
+ _, s, _ = np.linalg.svd(centered.astype(float))
362
+ if s[-1] < 0.5:
363
+ pts[rng.randint(4), rng.randint(3)] = (pts[0, 0] + 2) % GS
364
+ return pts
365
+
366
+ elif name == "pyramid":
367
+ z_base = rng.randint(0, 3)
368
+ x1, y1 = rng.randint(0, 3), rng.randint(0, 3)
369
+ s = rng.randint(1, 3)
370
+ base = np.array([[x1, y1, z_base], [x1 + s, y1, z_base],
371
+ [x1, y1 + s, z_base], [x1 + s, y1 + s, z_base]])
372
+ apex = np.array([[x1 + s // 2, y1 + s // 2, z_base + rng.randint(1, 3)]])
373
+ return np.clip(np.vstack([base, apex]), 0, GS - 1)
374
+
375
+ elif name == "pentachoron":
376
+ return self._rand_pts_3d(5, min_dist=1)
377
+
378
+ elif name == "cube":
379
+ x1, y1, z1 = rng.randint(0, 3, size=3)
380
+ s = rng.randint(1, 3)
381
+ pts = []
382
+ for dx in [0, s]:
383
+ for dy in [0, s]:
384
+ for dz in [0, s]:
385
+ pts.append([x1 + dx, y1 + dy, z1 + dz])
386
+ return np.clip(np.array(pts), 0, GS - 1)
387
+
388
+ elif name == "cuboid":
389
+ x1, y1, z1 = rng.randint(0, 2, size=3)
390
+ sx, sy, sz = rng.randint(1, 4, size=3)
391
+ # Ensure not a cube: at least 2 different edge lengths
392
+ if sx == sy == sz:
393
+ sx = min(GS - 1, sx + 1)
394
+ pts = []
395
+ for dx in [0, sx]:
396
+ for dy in [0, sy]:
397
+ for dz in [0, sz]:
398
+ pts.append([x1 + dx, y1 + dy, z1 + dz])
399
+ return np.clip(np.array(pts), 0, GS - 1)
400
+
401
+ elif name == "triangular_prism":
402
+ # Triangle in one plane, extruded along the other axis
403
+ axis = rng.randint(3) # extrusion axis
404
+ ext_start = rng.randint(0, 3)
405
+ ext_len = rng.randint(1, 3)
406
+ tri = self._rand_pts_2d(3, min_dist=1)
407
+ if tri is None: return None
408
+ pts = []
409
+ for e in range(ext_start, min(GS, ext_start + ext_len + 1)):
410
+ for t in tri:
411
+ p = [0, 0, 0]
412
+ p[axis] = e
413
+ p[(axis + 1) % 3] = t[0]
414
+ p[(axis + 2) % 3] = t[1]
415
+ pts.append(p)
416
+ return np.clip(np.array(pts), 0, GS - 1) if len(pts) >= 6 else None
417
+
418
+ elif name == "octahedron":
419
+ # 6 vertices: ±1 along each axis from center
420
+ cx, cy, cz = rng.randint(1, GS - 1, size=3)
421
+ s = rng.randint(1, 3)
422
+ pts = [[cx, cy, cz + s], [cx, cy, cz - s],
423
+ [cx + s, cy, cz], [cx - s, cy, cz],
424
+ [cx, cy + s, cz], [cx, cy - s, cz]]
425
+ return np.clip(np.array(pts), 0, GS - 1)
426
+
427
+ return None
428
+
429
+ # === Curved Generators ===
430
+
431
+ def _curved(self, name):
432
+ rng = self.rng
433
+ cx, cy, cz = rng.uniform(1.0, 3.0, size=3)
434
+
435
+ if name == "arc":
436
+ r = rng.uniform(1.2, 2.2)
437
+ plane = rng.choice(["xy", "xz", "yz"])
438
+ start = rng.uniform(0, 2 * np.pi)
439
+ span = rng.uniform(np.pi * 0.4, np.pi * 1.2)
440
+ n = rng.randint(6, 12)
441
+ angles = np.linspace(start, start + span, n)
442
+ pts = []
443
+ for a in angles:
444
+ if plane == "xy":
445
+ pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
446
+ elif plane == "xz":
447
+ pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
448
+ else:
449
+ pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
450
+ pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
451
+ return pts if len(pts) >= 3 else None
452
+
453
+ elif name == "helix":
454
+ # Spiral through 3D: parametric curve
455
+ r = rng.uniform(0.8, 1.8)
456
+ axis = rng.randint(3)
457
+ pitch = rng.uniform(0.3, 0.8) # rise per radian
458
+ n = rng.randint(15, 30)
459
+ t = np.linspace(0, 2 * np.pi * rng.uniform(1.0, 2.5), n)
460
+ pts = []
461
+ center = [cx, cy, cz]
462
+ axes = [i for i in range(3) if i != axis]
463
+ start_h = rng.uniform(0, 1.0)
464
+ for ti in t:
465
+ p = [0.0, 0.0, 0.0]
466
+ p[axes[0]] = center[axes[0]] + r * np.cos(ti)
467
+ p[axes[1]] = center[axes[1]] + r * np.sin(ti)
468
+ p[axis] = start_h + pitch * ti
469
+ pts.append(p)
470
+ pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
471
+ return pts if len(pts) >= 5 else None
472
+
473
+ elif name == "circle":
474
+ r = rng.uniform(1.0, 2.0)
475
+ plane = rng.choice(["xy", "xz", "yz"])
476
+ n = rng.randint(12, 20)
477
+ angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
478
+ pts = []
479
+ for a in angles:
480
+ if plane == "xy":
481
+ pts.append([cx + r * np.cos(a), cy + r * np.sin(a), cz])
482
+ elif plane == "xz":
483
+ pts.append([cx + r * np.cos(a), cy, cz + r * np.sin(a)])
484
+ else:
485
+ pts.append([cx, cy + r * np.cos(a), cz + r * np.sin(a)])
486
+ pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
487
+ return pts if len(pts) >= 5 else None
488
+
489
+ elif name == "ellipse":
490
+ rx, ry = rng.uniform(0.8, 2.0), rng.uniform(0.8, 2.0)
491
+ if abs(rx - ry) < 0.3: rx *= 1.4
492
+ plane = rng.choice(["xy", "xz", "yz"])
493
+ n = rng.randint(12, 20)
494
+ angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
495
+ pts = []
496
+ for a in angles:
497
+ if plane == "xy":
498
+ pts.append([cx + rx * np.cos(a), cy + ry * np.sin(a), cz])
499
+ elif plane == "xz":
500
+ pts.append([cx + rx * np.cos(a), cy, cz + ry * np.sin(a)])
501
+ else:
502
+ pts.append([cx, cy + rx * np.cos(a), cz + ry * np.sin(a)])
503
+ pts = np.unique(np.round(np.clip(pts, 0, GS - 1)).astype(int), axis=0)
504
+ return pts if len(pts) >= 5 else None
505
+
506
+ elif name == "disc":
507
+ # Filled circle in a plane (not just outline)
508
+ r = rng.uniform(1.0, 2.2)
509
+ axis = rng.randint(3)
510
+ val = round(rng.uniform(0.5, 3.5))
511
+ center = [cx, cy, cz]
512
+ axes = [i for i in range(3) if i != axis]
513
+ pts = []
514
+ for x in range(GS):
515
+ for y in range(GS):
516
+ p = [0, 0, 0]
517
+ p[axis] = val
518
+ p[axes[0]] = x
519
+ p[axes[1]] = y
520
+ dist = np.sqrt((x - center[axes[0]])**2 + (y - center[axes[1]])**2)
521
+ if dist <= r:
522
+ pts.append(p)
523
+ return np.array(pts) if len(pts) >= 4 else None
524
+
525
+ elif name == "sphere":
526
+ r = rng.uniform(1.0, 2.2)
527
+ pts = []
528
+ for x in range(GS):
529
+ for y in range(GS):
530
+ for z in range(GS):
531
+ if (x - cx)**2 + (y - cy)**2 + (z - cz)**2 <= r**2:
532
+ pts.append([x, y, z])
533
+ return np.array(pts) if len(pts) >= 4 else None
534
+
535
+ elif name == "hemisphere":
536
+ r = rng.uniform(1.0, 2.2)
537
+ cut_axis = rng.randint(3)
538
+ center = [cx, cy, cz]
539
+ pts = []
540
+ for x in range(GS):
541
+ for y in range(GS):
542
+ for z in range(GS):
543
+ p = [x, y, z]
544
+ if (x - cx)**2 + (y - cy)**2 + (z - cz)**2 <= r**2:
545
+ if p[cut_axis] >= center[cut_axis]:
546
+ pts.append(p)
547
+ return np.array(pts) if len(pts) >= 3 else None
548
+
549
+ elif name == "cylinder":
550
+ r = rng.uniform(0.8, 1.8)
551
+ axis = rng.randint(3)
552
+ length = rng.randint(2, 5)
553
+ start = rng.randint(0, GS - length + 1)
554
+ center = [cx, cy, cz]
555
+ axes = [i for i in range(3) if i != axis]
556
+ pts = []
557
+ for x in range(GS):
558
+ for y in range(GS):
559
+ for z in range(GS):
560
+ p = [x, y, z]
561
+ if p[axis] < start or p[axis] >= start + length: continue
562
+ dist_sq = sum((p[a] - center[a])**2 for a in axes)
563
+ if dist_sq <= r**2:
564
+ pts.append(p)
565
+ return np.array(pts) if len(pts) >= 4 else None
566
+
567
+ elif name == "cone":
568
+ r_base = rng.uniform(1.0, 2.0)
569
+ axis = rng.randint(3)
570
+ height = rng.randint(2, 5)
571
+ base_pos = rng.randint(0, GS - height + 1)
572
+ center = [cx, cy, cz]
573
+ axes = [i for i in range(3) if i != axis]
574
+ pts = []
575
+ for x in range(GS):
576
+ for y in range(GS):
577
+ for z in range(GS):
578
+ p = [x, y, z]
579
+ along = p[axis] - base_pos
580
+ if along < 0 or along >= height: continue
581
+ t = along / (height - 1 + 1e-6)
582
+ r_at = r_base * (1.0 - t)
583
+ dist_sq = sum((p[a] - center[a])**2 for a in axes)
584
+ if dist_sq <= r_at**2:
585
+ pts.append(p)
586
+ return np.array(pts) if len(pts) >= 4 else None
587
+
588
+ elif name == "capsule":
589
+ # Cylinder with hemispherical caps
590
+ r = rng.uniform(0.8, 1.5)
591
+ axis = rng.randint(3)
592
+ body_len = rng.randint(1, 3)
593
+ center = [cx, cy, cz]
594
+ axes = [i for i in range(3) if i != axis]
595
+ body_start = round(center[axis] - body_len / 2)
596
+ body_end = body_start + body_len
597
+ pts = []
598
+ for x in range(GS):
599
+ for y in range(GS):
600
+ for z in range(GS):
601
+ p = [x, y, z]
602
+ radial_sq = sum((p[a] - center[a])**2 for a in axes)
603
+ along = p[axis]
604
+ # Body
605
+ if body_start <= along <= body_end and radial_sq <= r**2:
606
+ pts.append(p)
607
+ # Bottom cap
608
+ elif along < body_start:
609
+ cap_center = list(center)
610
+ cap_center[axis] = body_start
611
+ dist_sq = sum((p[i] - cap_center[i])**2 for i in range(3))
612
+ if dist_sq <= r**2:
613
+ pts.append(p)
614
+ # Top cap
615
+ elif along > body_end:
616
+ cap_center = list(center)
617
+ cap_center[axis] = body_end
618
+ dist_sq = sum((p[i] - cap_center[i])**2 for i in range(3))
619
+ if dist_sq <= r**2:
620
+ pts.append(p)
621
+ return np.array(pts) if len(pts) >= 5 else None
622
+
623
+ elif name == "torus":
624
+ R = rng.uniform(1.2, 2.0)
625
+ r = rng.uniform(0.5, 0.9)
626
+ axis = rng.randint(3)
627
+ center = [cx, cy, cz]
628
+ ring_axes = [i for i in range(3) if i != axis]
629
+ pts = []
630
+ for x in range(GS):
631
+ for y in range(GS):
632
+ for z in range(GS):
633
+ p = [x, y, z]
634
+ dist_in_plane = np.sqrt(
635
+ sum((p[a] - center[a])**2 for a in ring_axes))
636
+ dist_from_ring = np.sqrt(
637
+ (dist_in_plane - R)**2 + (p[axis] - center[axis])**2)
638
+ if dist_from_ring <= r:
639
+ pts.append(p)
640
+ return np.array(pts) if len(pts) >= 4 else None
641
+
642
+ elif name == "shell":
643
+ # Hollow sphere: outer radius - inner radius
644
+ r_out = rng.uniform(1.5, 2.3)
645
+ r_in = r_out - rng.uniform(0.4, 0.8)
646
+ if r_in < 0.3: r_in = 0.3
647
+ pts = []
648
+ for x in range(GS):
649
+ for y in range(GS):
650
+ for z in range(GS):
651
+ d_sq = (x - cx)**2 + (y - cy)**2 + (z - cz)**2
652
+ if r_in**2 <= d_sq <= r_out**2:
653
+ pts.append([x, y, z])
654
+ return np.array(pts) if len(pts) >= 4 else None
655
+
656
+ elif name == "tube":
657
+ # Hollow cylinder
658
+ r_out = rng.uniform(1.0, 2.0)
659
+ r_in = r_out - rng.uniform(0.3, 0.7)
660
+ if r_in < 0.2: r_in = 0.2
661
+ axis = rng.randint(3)
662
+ length = rng.randint(2, 5)
663
+ start = rng.randint(0, GS - length + 1)
664
+ center = [cx, cy, cz]
665
+ axes = [i for i in range(3) if i != axis]
666
+ pts = []
667
+ for x in range(GS):
668
+ for y in range(GS):
669
+ for z in range(GS):
670
+ p = [x, y, z]
671
+ if p[axis] < start or p[axis] >= start + length: continue
672
+ dist_sq = sum((p[a] - center[a])**2 for a in axes)
673
+ if r_in**2 <= dist_sq <= r_out**2:
674
+ pts.append(p)
675
+ return np.array(pts) if len(pts) >= 4 else None
676
+
677
+ elif name == "bowl":
678
+ # Paraboloid: concave surface, open on top
679
+ r = rng.uniform(1.2, 2.2)
680
+ axis = rng.randint(3)
681
+ center = [cx, cy, cz]
682
+ axes = [i for i in range(3) if i != axis]
683
+ thickness = 0.6
684
+ pts = []
685
+ for x in range(GS):
686
+ for y in range(GS):
687
+ for z in range(GS):
688
+ p = [x, y, z]
689
+ dist_planar = np.sqrt(
690
+ sum((p[a] - center[a])**2 for a in axes))
691
+ if dist_planar > r: continue
692
+ # Paraboloid surface: h = k * dist^2
693
+ k = 1.0 / (r + 1e-6)
694
+ expected_h = center[axis] + k * dist_planar**2
695
+ actual_h = p[axis]
696
+ if abs(actual_h - expected_h) <= thickness:
697
+ pts.append(p)
698
+ return np.array(pts) if len(pts) >= 4 else None
699
+
700
+ elif name == "saddle":
701
+ # Hyperbolic paraboloid: z = k*(x^2 - y^2)
702
+ axis = rng.randint(3)
703
+ center = [cx, cy, cz]
704
+ axes = [i for i in range(3) if i != axis]
705
+ k = rng.uniform(0.3, 0.8)
706
+ thickness = 0.7
707
+ pts = []
708
+ for x in range(GS):
709
+ for y in range(GS):
710
+ for z in range(GS):
711
+ p = [x, y, z]
712
+ da = p[axes[0]] - center[axes[0]]
713
+ db = p[axes[1]] - center[axes[1]]
714
+ expected_h = center[axis] + k * (da**2 - db**2)
715
+ if abs(p[axis] - expected_h) <= thickness:
716
+ # Limit radius so it doesn't fill everything
717
+ dist_sq = da**2 + db**2
718
+ if dist_sq <= 4.0:
719
+ pts.append(p)
720
+ return np.array(pts) if len(pts) >= 4 else None
721
+
722
+ return None
723
+
724
+ # === Helpers ===
725
+
726
+ def _rand_pts_2d(self, n, min_dist=0):
727
+ for _ in range(50):
728
+ pts = set()
729
+ while len(pts) < n:
730
+ pts.add((self.rng.randint(0, GS), self.rng.randint(0, GS)))
731
+ pts = np.array(list(pts)[:n])
732
+ if min_dist <= 0 or self._check_dist(pts, min_dist):
733
+ return pts
734
+ return None
735
+
736
+ def _rand_pts_3d(self, n, min_dist=0):
737
+ for _ in range(100):
738
+ pts = set()
739
+ while len(pts) < n:
740
+ pts.add(tuple(self.rng.randint(0, GS, size=3)))
741
+ pts = np.array(list(pts)[:n])
742
+ if min_dist <= 0 or self._check_dist(pts, min_dist):
743
+ return pts
744
+ return None
745
+
746
+ def _check_dist(self, pts, min_dist):
747
+ for i in range(len(pts)):
748
+ for j in range(i + 1, len(pts)):
749
+ if np.sum(np.abs(pts[i] - pts[j])) < min_dist:
750
+ return False
751
+ return True
752
+
753
+ def _build(self, name, info, voxels):
754
+ n = len(voxels)
755
+ sub = voxels[:6].astype(float) if n > 6 else voxels.astype(float)
756
+ cm_det = cayley_menger_det(sub)
757
+ volume = effective_volume(sub)
758
+
759
+ dim_conf = np.zeros(4, dtype=np.float32)
760
+ dim_conf[0] = 1.0
761
+ if n >= 2: dim_conf[1] = 1.0
762
+ if info["dim"] >= 2: dim_conf[2] = 1.0
763
+ if info["dim"] >= 3: dim_conf[3] = 1.0
764
+
765
+ grid = np.zeros((GS, GS, GS), dtype=np.float32)
766
+ for v in voxels:
767
+ grid[v[0], v[1], v[2]] = 1.0
768
+
769
+ return {
770
+ "grid": grid, "label": CLASS_TO_IDX[name], "class_name": name,
771
+ "n_points": n, "n_occupied": int(grid.sum()),
772
+ "cm_det": float(cm_det), "volume": float(volume),
773
+ "peak_dim": info["dim"], "dim_confidence": dim_conf,
774
+ "is_curved": info["curved"], "curvature": CURV_TO_IDX[info["curvature"]],
775
+ }
776
+
777
+
778
+ # === Dataset =================================================================
779
+
780
+ def _generate_chunk(args):
781
+ """Worker function for parallel shape generation."""
782
+ class_assignments, seed, start_idx = args
783
+ gen = ShapeGenerator(seed=seed)
784
+ samples = []
785
+ for ci in class_assignments:
786
+ name = CLASS_NAMES[ci]
787
+ for attempt in range(10):
788
+ s = gen._make(name)
789
+ if s is not None:
790
+ samples.append(s)
791
+ break
792
+ else:
793
+ s = gen._make("cube")
794
+ if s is not None:
795
+ samples.append(s)
796
+ return samples
797
+
798
+
799
+ def generate_parallel(n_samples, seed=42, n_workers=8):
800
+ """Pre-generate all samples using multiprocessing."""
801
+ import multiprocessing as mp
802
+ per_class = n_samples // NUM_CLASSES
803
+ class_assignments = []
804
+ for ci in range(NUM_CLASSES):
805
+ class_assignments.extend([ci] * per_class)
806
+ rng = np.random.RandomState(seed)
807
+ while len(class_assignments) < n_samples:
808
+ class_assignments.append(rng.randint(0, NUM_CLASSES))
809
+ rng.shuffle(class_assignments)
810
+ class_assignments = class_assignments[:n_samples]
811
+
812
+ # Split into chunks per worker
813
+ chunk_size = (n_samples + n_workers - 1) // n_workers
814
+ chunks = []
815
+ for i in range(n_workers):
816
+ start = i * chunk_size
817
+ end = min(start + chunk_size, n_samples)
818
+ if start >= n_samples:
819
+ break
820
+ chunks.append((class_assignments[start:end], seed + i * 1000000, start))
821
+
822
+ print(f"Generating {n_samples} shapes across {len(chunks)} workers...")
823
+ import time; t0 = time.time()
824
+ with mp.Pool(n_workers) as pool:
825
+ results = pool.map(_generate_chunk, chunks)
826
+ samples = []
827
+ for r in results:
828
+ samples.extend(r)
829
+ rng.shuffle(samples)
830
+ dt = time.time() - t0
831
+ print(f"Generated {len(samples)} samples in {dt:.1f}s ({len(samples)/dt:.0f} samples/s)")
832
+ return samples
833
+
834
+
835
+ class ShapeDataset(torch.utils.data.Dataset):
836
+ def __init__(self, samples):
837
+ self.grids = torch.tensor(np.stack([s["grid"] for s in samples]), dtype=torch.float32)
838
+ self.labels = torch.tensor([s["label"] for s in samples], dtype=torch.long)
839
+ self.dim_conf = torch.tensor(np.stack([s["dim_confidence"] for s in samples]), dtype=torch.float32)
840
+ self.peak_dim = torch.tensor([s["peak_dim"] for s in samples], dtype=torch.long)
841
+ self.volume = torch.tensor([s["volume"] for s in samples], dtype=torch.float32)
842
+ self.cm_det = torch.tensor([s["cm_det"] for s in samples], dtype=torch.float32)
843
+ self.is_curved = torch.tensor([s["is_curved"] for s in samples], dtype=torch.float32)
844
+ self.curvature = torch.tensor([s["curvature"] for s in samples], dtype=torch.long)
845
+
846
+ def __len__(self):
847
+ return len(self.labels)
848
+
849
+ def __getitem__(self, idx):
850
+ return (self.grids[idx], self.labels[idx], self.dim_conf[idx],
851
+ self.peak_dim[idx], self.volume[idx], self.cm_det[idx],
852
+ self.is_curved[idx], self.curvature[idx])
853
+
854
+
855
+
856
+ print(f'Loaded {NUM_CLASSES} shape classes, GS={GS}')