AbstractPhil committed on
Commit
279d83a
·
verified ·
1 Parent(s): 6565ed3

Create pentachora_stabilizer.py

Browse files
Files changed (1) hide show
  1. pentachora_stabilizer.py +219 -0
pentachora_stabilizer.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict, List, Optional, Union
4
+
5
+
6
def get_parameter_groups(model: nn.Module,
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """Split a model's parameters into decayed and decay-exempt optimizer groups.

    Biases and normalization weights are exempted from weight decay — the
    standard recipe for transformer-style training.

    Args:
        model: Any module exposing ``named_parameters()``. (Originally
            annotated as the project's ``PentachoraViT``, which is undefined
            in this module and would raise ``NameError`` at import time;
            ``nn.Module`` is backward compatible and keeps the file importable.)
        weight_decay: Decay coefficient applied to the non-exempt group.

    Returns:
        Two optimizer parameter-group dicts: one with ``weight_decay`` applied,
        one with decay disabled.
    """
    no_decay = ['bias', 'norm', 'LayerNorm']

    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        # Frozen parameters are not handed to the optimizer at all.
        if not param.requires_grad:
            continue

        # Substring match: any exemption token anywhere in the name disables decay.
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
27
+
28
+
29
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Tally total, trainable, and frozen parameter counts for *model*."""
    trainable = 0
    frozen = 0
    # Single pass over the parameters, bucketing by requires_grad.
    for p in model.parameters():
        n = p.numel()
        if p.requires_grad:
            trainable += n
        else:
            frozen += n
    return {
        'total': trainable + frozen,
        'trainable': trainable,
        'non_trainable': frozen,
    }
38
+
39
+
40
+
41
def get_default_device() -> torch.device:
    """Return the preferred compute device: CUDA when available, else CPU."""
    # NOTE: removed a leftover generation-instruction comment that preceded
    # this function in the original file.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
+
46
class PentachoronStabilizer:
    """
    Geometric constraint utilities for a pentachoron (the 4-simplex: 5 vertices).
    Includes Rose scoring for semantic alignment.

    All methods are static and operate on batched vertex tensors of shape
    [B, 5, D], or equivalently on dicts with the five named roles
    'anchor', 'need', 'relation', 'purpose', 'observer'.
    """

    @staticmethod
    def vertices_to_tensor(vertices):
        """Convert a role dict to a [B, 5, D] tensor (pass tensors through unchanged)."""
        if isinstance(vertices, dict):
            # Fixed role order defines the vertex index convention used everywhere.
            return torch.stack([
                vertices['anchor'], vertices['need'],
                vertices['relation'], vertices['purpose'],
                vertices['observer']
            ], dim=1)  # [B, 5, D]
        return vertices

    @staticmethod
    def tensor_to_dict(verts):
        """Convert a [B, 5, D] vertex tensor back to the role dict."""
        return {
            'anchor': verts[:, 0],
            'need': verts[:, 1],
            'relation': verts[:, 2],
            'purpose': verts[:, 3],
            'observer': verts[:, 4]
        }

    @staticmethod
    def rose_score_magnitude(
        x: torch.Tensor,
        vertices: Union[Dict[str, torch.Tensor], torch.Tensor],
        eps: float = 1e-6
    ) -> torch.Tensor:
        """
        Compute Rose similarity score between x and pentachoron vertices.

        The score is the mean cosine alignment of x against the 'need',
        'relation', and 'purpose' vertices, scaled by the magnitude of x.

        Args:
            x: Query tensor [B, T, D] or [B, D]
            vertices: Either dict or tensor [B, 5, D]
            eps: Small value for numerical stability

        Returns:
            scores: [B, T] or [B] depending on input shape
        """
        # Promote [B, D] to [B, 1, D] so the sequence-dim code path is shared;
        # remember to squeeze the result back.
        squeeze_output = False
        if x.dim() == 2:
            x = x.unsqueeze(1)  # [B, 1, D]
            squeeze_output = True

        # Normalize vertex representation to the role dict.
        if not isinstance(vertices, dict):
            vertices = PentachoronStabilizer.tensor_to_dict(vertices)

        # Broadcast each role vertex across the sequence dimension.
        B, T, D = x.shape
        need = vertices['need'].unsqueeze(1).expand(-1, T, -1)
        relation = vertices['relation'].unsqueeze(1).expand(-1, T, -1)
        purpose = vertices['purpose'].unsqueeze(1).expand(-1, T, -1)

        # Unit-normalize all inputs before the cosine comparisons.
        x_n = F.normalize(x, dim=-1, eps=eps)
        n_n = F.normalize(need, dim=-1, eps=eps)
        r_n = F.normalize(relation, dim=-1, eps=eps)
        p_n = F.normalize(purpose, dim=-1, eps=eps)

        # Core directional cosine components against each of the three roles.
        a_n = torch.cosine_similarity(x_n, n_n, dim=-1)
        a_r = torch.cosine_similarity(x_n, r_n, dim=-1)
        a_p = torch.cosine_similarity(x_n, p_n, dim=-1)

        # Triadic mean alignment, scaled by the (un-normalized) magnitude of x.
        r7 = (a_n + a_r + a_p) / 3.0
        r8 = x.norm(dim=-1)

        score = r7 * r8

        return score.squeeze(1) if squeeze_output else score

    @staticmethod
    def compute_gram_matrix(verts):
        """Compute the [B, 5, 5] Gram matrix (pairwise inner products) of the vertices."""
        return torch.bmm(verts, verts.transpose(-2, -1))

    @staticmethod
    def cayley_menger_determinant(verts):
        """
        Compute the Cayley-Menger determinant for each batch element (vectorized).

        Builds the bordered 6x6 squared-distance matrix and returns its
        determinant, shape [B].
        """
        B = verts.shape[0]

        # Squared pairwise distances via the Gram matrix identity
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>.
        gram = torch.bmm(verts, verts.transpose(-2, -1))
        diag = gram.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)
        dist_sq = diag + diag.transpose(-2, -1) - 2 * gram

        # FIX: match the input dtype as well as the device — the original
        # allocated float32 unconditionally, silently down-casting float64
        # inputs before torch.det.
        cm = torch.zeros(B, 6, 6, device=verts.device, dtype=verts.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = dist_sq

        return torch.det(cm)

    @staticmethod
    def enforce_regular_simplex(verts):
        """Return the per-batch variance of the 10 edge lengths (0 for a regular simplex)."""
        diff = verts.unsqueeze(2) - verts.unsqueeze(1)
        dist = torch.norm(diff, dim=-1)

        # Upper triangle (offset=1) selects each of the 10 edges exactly once.
        # FIX: allocate the index tensor on the vertices' device to avoid
        # cross-device indexing when verts lives on GPU.
        triu_indices = torch.triu_indices(5, 5, offset=1, device=verts.device)
        edges = dist[:, triu_indices[0], triu_indices[1]]

        return torch.var(edges, dim=-1)

    @staticmethod
    def orthoplex_projection(verts):
        """Project vertices to the unit hypersphere, re-centered about their mean."""
        verts_norm = F.normalize(verts, dim=-1)
        center = verts_norm.mean(dim=1, keepdim=True)
        verts_centered = verts_norm - center
        return F.normalize(verts_centered, dim=-1)

    @staticmethod
    def apply(
        vertices,
        cayley_target: float = 1.0,
        return_dict: bool = False,
        compute_rose_scores: Optional[torch.Tensor] = None
    ):
        """
        Apply all constraints and return stable vertices + losses.

        Args:
            vertices: Either dict or tensor [B, 5, D]
            cayley_target: Target Cayley-Menger determinant
            return_dict: If True and input was dict, return dict
            compute_rose_scores: Optional tensor to compute Rose scores against

        Returns:
            vertices_stable: Stabilized vertices
            losses: Dict of loss components (includes rose_scores if requested)
        """
        was_dict = isinstance(vertices, dict)
        verts = PentachoronStabilizer.vertices_to_tensor(vertices)

        # Geometric losses are computed on the raw (pre-projection) vertices.
        cm_det = PentachoronStabilizer.cayley_menger_determinant(verts)
        validity_loss = torch.abs(cm_det - cayley_target).mean()
        regularity_loss = PentachoronStabilizer.enforce_regular_simplex(verts).mean()

        # Stabilize: centered unit-sphere projection.
        verts_stable = PentachoronStabilizer.orthoplex_projection(verts)

        # Gram entropy over the stabilized vertices, averaged over the
        # B * 25 entries of the [B, 5, 5] Gram matrix.
        gram = PentachoronStabilizer.compute_gram_matrix(verts_stable)
        gram_entropy = -torch.sum(gram * torch.log(torch.abs(gram) + 1e-8)) / (verts.shape[0] * 25)

        losses = {
            'validity': validity_loss,
            'regularity': regularity_loss,
            'gram_entropy': gram_entropy
        }

        # Rose scores are measured against the *stabilized* vertices.
        if compute_rose_scores is not None:
            rose_scores = PentachoronStabilizer.rose_score_magnitude(
                compute_rose_scores,
                verts_stable
            )
            losses['rose_scores'] = rose_scores

        # Round-trip back to the dict representation only when the caller
        # supplied a dict AND asked for one back.
        if was_dict and return_dict:
            verts_stable = PentachoronStabilizer.tensor_to_dict(verts_stable)

        return verts_stable, losses