AbstractPhil commited on
Commit
1fe1ab5
·
verified ·
1 Parent(s): a52a380

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +409 -0
model.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Superposition Patch Classifier — Standalone Inference Module
3
+ =============================================================
4
+ Two-tier gated geometric transformer that extracts structural
5
+ properties from (8, 16, 16) latent patches.
6
+
7
+ No dependencies beyond PyTorch. All grid/gate constants inlined.
8
+
9
+ Input: (B, 8, 16, 16) — adapted latent patches
10
+ Output: gate_vectors (B, 64, 17), patch_features (B, 64, 256), logits
11
+
12
+ Usage:
13
+ from geometric_model import SuperpositionPatchClassifier, load_from_hub
14
+
15
+ model = load_from_hub() # downloads from AbstractPhil/geovocab-patch-maker
16
+ out = model(patches)
17
+
18
+ # Gate vectors: explicit geometric properties per patch
19
+ local_gates = torch.cat([
20
+ F.softmax(out["local_dim_logits"], dim=-1), # 4d: dimensionality
21
+ F.softmax(out["local_curv_logits"], dim=-1), # 3d: curvature class
22
+ torch.sigmoid(out["local_bound_logits"]), # 1d: boundary flag
23
+ torch.sigmoid(out["local_axis_logits"]), # 3d: active axes
24
+ ], dim=-1) # (B, 64, 11)
25
+
26
+ structural_gates = torch.cat([
27
+ F.softmax(out["struct_topo_logits"], dim=-1), # 2d: topology
28
+ torch.sigmoid(out["struct_neighbor_logits"]), # 1d: neighbor density
29
+ F.softmax(out["struct_role_logits"], dim=-1), # 3d: surface role
30
+ ], dim=-1) # (B, 64, 6)
31
+
32
+ gate_vectors = torch.cat([local_gates, structural_gates], dim=-1) # (B, 64, 17)
33
+ patch_features = out["patch_features"] # (B, 64, embed_dim)
34
+ """
35
+
36
+ import math
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
+
42
+ # ══════════════════════════════════════════════════════════════════════════════
43
+ # Grid Constants (inlined from generator β€” no dependency needed)
44
+ # ══════════════════════════════════════════════════════════════════════════════
45
+
46
# Extent of the full latent grid along (z, y, x).
GZ, GY, GX = 8, 16, 16
# Extent of one patch along (z, y, x).
PATCH_Z, PATCH_Y, PATCH_X = 2, 4, 4
PATCH_VOL = PATCH_Z * PATCH_Y * PATCH_X  # 32 values per patch
# Number of patches along each axis of the macro grid.
MACRO_Z, MACRO_Y, MACRO_X = GZ // PATCH_Z, GY // PATCH_Y, GX // PATCH_X  # 4, 4, 4
MACRO_N = MACRO_Z * MACRO_Y * MACRO_X  # 64 patches per sample

# Local gates: intrinsic per-patch (no cross-patch info)
NUM_LOCAL_DIMS = 4       # 0D point, 1D line, 2D surface, 3D volume
NUM_LOCAL_CURVS = 3      # rigid, curved, combined
NUM_LOCAL_BOUNDARY = 1   # partial fill flag
NUM_LOCAL_AXES = 3       # which axes have extent > 1
LOCAL_GATE_DIM = sum(
    (NUM_LOCAL_DIMS, NUM_LOCAL_CURVS, NUM_LOCAL_BOUNDARY, NUM_LOCAL_AXES)
)  # 11

# Structural gates: relational (require neighborhood context)
NUM_STRUCT_TOPO = 2      # open / closed
NUM_STRUCT_NEIGHBOR = 1  # normalized neighbor count
NUM_STRUCT_ROLE = 3      # isolated / boundary / interior
STRUCTURAL_GATE_DIM = sum(
    (NUM_STRUCT_TOPO, NUM_STRUCT_NEIGHBOR, NUM_STRUCT_ROLE)
)  # 6

# Width of the concatenated (local + structural) gate vector.
TOTAL_GATE_DIM = LOCAL_GATE_DIM + STRUCTURAL_GATE_DIM  # 17

# Shape classes (27 geometric primitives); list order defines the class index.
CLASS_NAMES = [
    "point",
    "line",
    "corner",
    "cross",
    "arc",
    "helix",
    "circle",
    "triangle",
    "quad",
    "plane",
    "disc",
    "tetrahedron",
    "cube",
    "pyramid",
    "prism",
    "octahedron",
    "pentachoron",
    "wedge",
    "sphere",
    "hemisphere",
    "torus",
    "bowl",
    "saddle",
    "capsule",
    "cylinder",
    "cone",
    "channel",
]
NUM_CLASSES = len(CLASS_NAMES)

# Legacy gate names
GATES = ["rigid", "curved", "combined", "open", "closed"]
NUM_GATES = len(GATES)
81
+
82
+
83
+ # ══════════════════════════════════════════════════════════════════════════════
84
+ # Patch Embedding
85
+ # ══════════════════════════════════════════════════════════════════════════════
86
+
87
class PatchEmbedding3D(nn.Module):
    """Cut a latent volume into non-overlapping 3D patches and embed each one.

    The input is split into ``MACRO_Z x MACRO_Y x MACRO_X`` patches of
    ``PATCH_Z x PATCH_Y x PATCH_X`` values. Each flattened patch is linearly
    projected to ``patch_dim`` and summed with a learned projection of the
    patch's normalised (z, y, x) grid coordinate.

    Args:
        patch_dim: Width of the per-patch embedding.
    """

    def __init__(self, patch_dim=64):
        super().__init__()
        self.proj = nn.Linear(PATCH_VOL, patch_dim)
        # Patch coordinates along each axis, normalised to [0, 1) as index / count.
        pz = torch.arange(MACRO_Z).float() / MACRO_Z
        py = torch.arange(MACRO_Y).float() / MACRO_Y
        px = torch.arange(MACRO_X).float() / MACRO_X
        # One (z, y, x) triple per patch, in the same z-major order that
        # forward() flattens the patches into.
        pos = torch.stack(torch.meshgrid(pz, py, px, indexing='ij'), dim=-1).reshape(MACRO_N, 3)
        # Buffer, not parameter: follows the module across devices and is saved
        # in the state dict, but is never trained.
        self.register_buffer('pos_embed', pos)
        self.pos_proj = nn.Linear(3, patch_dim)

    def forward(self, x):
        """Embed a batch of latent volumes.

        Args:
            x: Tensor with ``GZ * GY * GX`` values per sample, laid out as
                (B, GZ, GY, GX).

        Returns:
            Tensor of shape (B, MACRO_N, patch_dim).
        """
        B = x.shape[0]
        # reshape() instead of view(): view() raises on inputs whose strides are
        # incompatible (permuted, sliced or channels-last tensors), whereas
        # reshape() copies only when it has to and is otherwise identical.
        patches = x.reshape(B, MACRO_Z, PATCH_Z, MACRO_Y, PATCH_Y, MACRO_X, PATCH_X)
        # Bring the three macro-grid axes to the front and flatten each patch.
        patches = patches.permute(0, 1, 3, 5, 2, 4, 6).reshape(B, MACRO_N, PATCH_VOL)
        # pos_proj output is (MACRO_N, patch_dim) and broadcasts over the batch.
        return self.proj(patches) + self.pos_proj(self.pos_embed)
103
+
104
+
105
+ # ══════════════════════════════════════════════════════════════════════════════
106
+ # Transformer Blocks
107
+ # ══════════════════════════════════════════════════════════════════════════════
108
+
109
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a 4x feed-forward.

    Both sub-layers are wrapped in a residual connection and read a
    LayerNorm-ed copy of the running activations.

    Args:
        dim: Token embedding width.
        n_heads: Number of attention heads.
        dropout: Dropout probability for attention and feed-forward layers.
    """

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout)
        )
        self.ln1, self.ln2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, N, dim); the shape is preserved."""
        # Normalise once and reuse the same tensor for query, key and value.
        # Calling ln1 three times tripled the LayerNorm cost and handed the
        # attention module three distinct tensors, which stops it from using
        # its packed self-attention projection.
        normed = self.ln1(x)
        # The attention weights are never used, so skip computing them.
        x = x + self.attn(normed, normed, normed, need_weights=False)[0]
        return x + self.ff(self.ln2(x))
122
+
123
+
124
class GatedGeometricAttention(nn.Module):
    """
    Multi-head attention with two-tier gate modulation.

    - Q and K are projected from the token features concatenated with the
      gate vector, so both see local and structural gates.
    - V is projected from the token features alone and scaled element-wise
      by a sigmoid gate computed from the gate vector.
    - A per-head compatibility bias, the product of two linear maps of the
      gate vectors of the query and key tokens, is added to the scaled
      dot-product scores before the softmax.

    Args:
        embed_dim: Token feature width; must be divisible by ``n_heads``.
        gate_dim: Width of the gate vector attached to every token.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1):
        super().__init__()
        # Fail at construction time: with a non-divisible width the module
        # would build fine and only die inside forward() with an opaque
        # view() size-mismatch error.
        if embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        self.q_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim + gate_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # One scalar per head for each token; their outer product over token
        # pairs forms the compatibility bias.
        self.gate_q = nn.Linear(gate_dim, n_heads)
        self.gate_k = nn.Linear(gate_dim, n_heads)
        self.v_gate = nn.Sequential(nn.Linear(gate_dim, embed_dim), nn.Sigmoid())

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, h, gate_features):
        """Attend over tokens.

        Args:
            h: Token features, shape (B, N, embed_dim).
            gate_features: Gate vectors, shape (B, N, gate_dim).

        Returns:
            Tensor of shape (B, N, embed_dim).
        """
        B, N, _ = h.shape
        hg = torch.cat([h, gate_features], dim=-1)
        # (B, N, embed_dim) -> (B, n_heads, N, head_dim)
        Q = self.q_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(hg).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        V = self.v_proj(h)
        V = (V * self.v_gate(gate_features)).view(B, N, self.n_heads, self.head_dim).transpose(1, 2)

        content_scores = (Q @ K.transpose(-2, -1)) / self.scale
        gq = self.gate_q(gate_features)
        gk = self.gate_k(gate_features)
        # compat[b, h, i, j] = gq[b, i, h] * gk[b, j, h]
        compat = torch.einsum('bih,bjh->bhij', gq, gk)

        attn = F.softmax(content_scores + compat, dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into a single embed_dim-wide feature.
        out = (attn @ V).transpose(1, 2).reshape(B, N, self.embed_dim)
        return self.out_proj(out)
169
+
170
+
171
class GeometricTransformerBlock(nn.Module):
    """Pre-norm transformer block whose attention layer is gate-conditioned.

    Args:
        embed_dim: Token feature width.
        gate_dim: Width of the gate vector passed to the attention layer.
        n_heads: Number of attention heads.
        dropout: Dropout probability for attention and feed-forward layers.
        ff_mult: Feed-forward hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim, gate_dim, n_heads, dropout=0.1, ff_mult=4):
        super().__init__()
        hidden_dim = embed_dim * ff_mult
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = GatedGeometricAttention(embed_dim, gate_dim, n_heads, dropout)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, h, gate_features):
        """Run gated attention then the feed-forward, each with a residual."""
        # Only the token features are normalised; the gate vector is passed
        # through to the attention layer untouched.
        attended = self.attn(self.ln1(h), gate_features)
        h = h + attended
        return h + self.ff(self.ln2(h))
186
+
187
+
188
+ # ══════════════════════════════════════════════════════════════════════════════
189
+ # Main Model
190
+ # ══════════════════════════════════════════════════════════════════════════════
191
+
192
class SuperpositionPatchClassifier(nn.Module):
    """
    Two-tier gated geometric transformer.

    Pipeline, in execution order:
      Stage 0:   local gate logits from the raw patch embeddings
                 (what IS in this patch)
      Stage 1:   bootstrap attention over embeddings + local gates
      Stage 1.5: structural gate logits from the post-attention features
                 (what ROLE this patch plays)
      Stage 2:   geometric gated attention conditioned on both gate tiers
      Stage 3:   per-patch and pooled (global) classification heads

    For feature extraction (no classification), use outputs:
      - gate vectors:    cat(local_gates, structural_gates) → (B, 64, 17)
      - patch_features:  out["patch_features"]  → (B, 64, embed_dim)
      - global_features: out["global_features"] → (B, embed_dim)

    Args:
        embed_dim: Transformer token width.
        patch_dim: Width of the raw patch embedding.
        n_bootstrap: Number of plain transformer blocks (stage 1).
        n_geometric: Number of gated geometric blocks (stage 2).
        n_heads: Attention heads per block.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, embed_dim=128, patch_dim=64, n_bootstrap=2, n_geometric=2,
                 n_heads=4, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim

        # Raw patches -> patch_dim embeddings.
        self.patch_embed = PatchEmbedding3D(patch_dim)

        # Stage 0: per-patch MLP plus one linear head per local gate group.
        encoder_width = 2 * patch_dim
        self.local_encoder = nn.Sequential(
            nn.Linear(patch_dim, encoder_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(encoder_width, encoder_width),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.local_dim_head = nn.Linear(encoder_width, NUM_LOCAL_DIMS)
        self.local_curv_head = nn.Linear(encoder_width, NUM_LOCAL_CURVS)
        self.local_bound_head = nn.Linear(encoder_width, NUM_LOCAL_BOUNDARY)
        self.local_axis_head = nn.Linear(encoder_width, NUM_LOCAL_AXES)

        # Lift (embedding ++ local gates) into the transformer width.
        self.proj = nn.Linear(patch_dim + LOCAL_GATE_DIM, embed_dim)

        # Stage 1: plain self-attention blocks.
        bootstrap = [TransformerBlock(embed_dim, n_heads, dropout) for _ in range(n_bootstrap)]
        self.bootstrap_blocks = nn.ModuleList(bootstrap)

        # Stage 1.5: one linear head per structural gate group.
        self.struct_topo_head = nn.Linear(embed_dim, NUM_STRUCT_TOPO)
        self.struct_neighbor_head = nn.Linear(embed_dim, NUM_STRUCT_NEIGHBOR)
        self.struct_role_head = nn.Linear(embed_dim, NUM_STRUCT_ROLE)

        # Stage 2: attention blocks conditioned on the full gate vector.
        geometric = [
            GeometricTransformerBlock(embed_dim, TOTAL_GATE_DIM, n_heads, dropout)
            for _ in range(n_geometric)
        ]
        self.geometric_blocks = nn.ModuleList(geometric)

        # Stage 3: heads read the token features concatenated with the gates.
        head_in = embed_dim + TOTAL_GATE_DIM

        self.patch_shape_head = nn.Sequential(
            nn.Linear(head_in, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, NUM_CLASSES),
        )

        self.global_pool = nn.Sequential(
            nn.Linear(head_in, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.global_gate_head = nn.Linear(embed_dim, NUM_GATES)
        self.global_shape_head = nn.Linear(embed_dim, NUM_CLASSES)

    @staticmethod
    def _local_gate_vector(dim_logits, curv_logits, bound_logits, axis_logits):
        """Turn local logits into the local gate vector (softmax/sigmoid per group)."""
        pieces = (
            F.softmax(dim_logits, dim=-1),
            F.softmax(curv_logits, dim=-1),
            torch.sigmoid(bound_logits),
            torch.sigmoid(axis_logits),
        )
        return torch.cat(pieces, dim=-1)

    @staticmethod
    def _structural_gate_vector(topo_logits, neighbor_logits, role_logits):
        """Turn structural logits into the structural gate vector."""
        pieces = (
            F.softmax(topo_logits, dim=-1),
            torch.sigmoid(neighbor_logits),
            F.softmax(role_logits, dim=-1),
        )
        return torch.cat(pieces, dim=-1)

    def forward(self, x):
        """Run the full pipeline.

        Args:
            x: Latent patches accepted by ``PatchEmbedding3D``.

        Returns:
            Dict holding the raw logits of every gate head, the per-patch
            shape logits, ``patch_features`` (token features after stage 2),
            ``global_features`` (pooled features) and the two global heads'
            outputs.
        """
        tokens = self.patch_embed(x)

        # Stage 0: local gates
        local_hidden = self.local_encoder(tokens)
        local_dim_logits = self.local_dim_head(local_hidden)
        local_curv_logits = self.local_curv_head(local_hidden)
        local_bound_logits = self.local_bound_head(local_hidden)
        local_axis_logits = self.local_axis_head(local_hidden)
        local_gates = self._local_gate_vector(
            local_dim_logits, local_curv_logits, local_bound_logits, local_axis_logits
        )

        # Stage 1: bootstrap
        h = self.proj(torch.cat([tokens, local_gates], dim=-1))
        for block in self.bootstrap_blocks:
            h = block(h)

        # Stage 1.5: structural gates
        struct_topo_logits = self.struct_topo_head(h)
        struct_neighbor_logits = self.struct_neighbor_head(h)
        struct_role_logits = self.struct_role_head(h)
        structural_gates = self._structural_gate_vector(
            struct_topo_logits, struct_neighbor_logits, struct_role_logits
        )

        gates = torch.cat([local_gates, structural_gates], dim=-1)

        # Stage 2: geometric routing
        for block in self.geometric_blocks:
            h = block(h, gates)

        # Stage 3: classification
        gated = torch.cat([h, gates], dim=-1)
        patch_shape_logits = self.patch_shape_head(gated)
        # Mean over the patch axis, then the pooling MLP.
        pooled = self.global_pool(gated.mean(dim=1))

        outputs = {
            "local_dim_logits": local_dim_logits,
            "local_curv_logits": local_curv_logits,
            "local_bound_logits": local_bound_logits,
            "local_axis_logits": local_axis_logits,
            "struct_topo_logits": struct_topo_logits,
            "struct_neighbor_logits": struct_neighbor_logits,
            "struct_role_logits": struct_role_logits,
            "patch_shape_logits": patch_shape_logits,
            "patch_features": h,
            "global_features": pooled,
        }
        outputs["global_gates"] = self.global_gate_head(pooled)
        outputs["global_shapes"] = self.global_shape_head(pooled)
        return outputs
321
+
322
+
323
+ # ══════════════════════════════════════════════════════════════════════════════
324
+ # Hub Loading
325
+ # ══════════════════════════════════════════════════════════════════════════════
326
+
327
def load_from_hub(
    repo_id="AbstractPhil/geovocab-patch-maker",
    filename="model.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Load a pretrained model from the HuggingFace Hub.

    Downloads ``filename`` from ``repo_id``, rebuilds the architecture from the
    checkpoint's ``"config"`` entry and loads ``"model_state_dict"`` into it.

    Args:
        repo_id: Hub repository to download from.
        filename: Checkpoint file name inside the repository.
        device: Device the model is placed on. The default is chosen once, at
            import time, from ``torch.cuda.is_available()``.

    Returns:
        A ``SuperpositionPatchClassifier`` on ``device``, in eval mode, built
        with ``dropout=0.0``.
    """
    # Imported lazily so the module itself only needs PyTorch.
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code embedded in the file. Only point this at
    # repositories you trust.
    # TODO(review): switch to weights_only=True if the checkpoint holds only
    # tensors and plain Python containers/scalars.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    cfg = ckpt["config"]

    model = SuperpositionPatchClassifier(
        embed_dim=cfg["embed_dim"],
        patch_dim=cfg["patch_dim"],
        n_bootstrap=cfg["n_bootstrap"],
        n_geometric=cfg["n_geometric"],
        n_heads=cfg["n_heads"],
        dropout=0.0,  # inference only
    ).to(device).eval()

    model.load_state_dict(ckpt["model_state_dict"])
    # The check mark had been corrupted into mojibake; restored here.
    print(f"✓ Loaded {repo_id} (epoch {ckpt.get('epoch', '?')})")
    return model
351
+
352
+
353
@torch.no_grad()
def extract_features(model, patches, batch_size=256):
    """
    Convenience: patches → (gate_vectors, patch_features)

    Runs ``model`` over ``patches`` in chunks of ``batch_size`` with gradients
    disabled, converts the gate logits to probabilities and gathers the
    results on the CPU.

    Args:
        model: SuperpositionPatchClassifier (eval mode)
        patches: (N, 8, 16, 16) tensor
        batch_size: inference batch size (must be >= 1)

    Returns:
        gate_vectors: (N, 64, 17) — explicit geometric properties
        patch_features: (N, 64, embed_dim) — learned representations

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A negative step made range() empty and crashed later in torch.cat with
    # an unrelated-looking message; reject bad values up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n_total = patches.shape[0]
    if n_total == 0:
        # torch.cat rejects an empty list, so return correctly shaped empty
        # tensors instead of crashing on an empty input.
        return (
            torch.empty(0, MACRO_N, TOTAL_GATE_DIM),
            torch.empty(0, MACRO_N, model.embed_dim),
        )

    # Run each chunk on whatever device the model's weights are on.
    device = next(model.parameters()).device
    gate_chunks, feature_chunks = [], []

    for start in range(0, n_total, batch_size):
        batch = patches[start:start + batch_size].to(device)
        out = model(batch)

        # Same per-group activations the model applies to its own gates.
        local = torch.cat([
            F.softmax(out["local_dim_logits"], dim=-1),
            F.softmax(out["local_curv_logits"], dim=-1),
            torch.sigmoid(out["local_bound_logits"]),
            torch.sigmoid(out["local_axis_logits"]),
        ], dim=-1)

        struct = torch.cat([
            F.softmax(out["struct_topo_logits"], dim=-1),
            torch.sigmoid(out["struct_neighbor_logits"]),
            F.softmax(out["struct_role_logits"], dim=-1),
        ], dim=-1)

        # Move results to the CPU per chunk so device memory use stays bounded
        # by one batch.
        gate_chunks.append(torch.cat([local, struct], dim=-1).cpu())
        feature_chunks.append(out["patch_features"].cpu())

    return torch.cat(gate_chunks), torch.cat(feature_chunks)
391
+
392
+
393
+ # ══════════════════════════════════════════════════════════════════════════════
394
+ # Quick test
395
+ # ══════════════════════════════════════════════════════════════════════════════
396
+
397
if __name__ == "__main__":
    # Smoke test: build a default model, report its size and push one random
    # batch through to print the shapes of the main outputs.
    model = SuperpositionPatchClassifier()
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    print(f"SuperpositionPatchClassifier: {n_params:,} parameters")

    x = torch.randn(2, 8, 16, 16)
    out = model(x)
    print(f" Input: {x.shape}")
    # (printed label, key in the output dict)
    report = (
        ("patch_features", "patch_features"),
        ("local_dim", "local_dim_logits"),
        ("struct_topo", "struct_topo_logits"),
        ("patch_shapes", "patch_shape_logits"),
        ("global_features", "global_features"),
    )
    for label, key in report:
        print(f" {label}: {out[key].shape}")