| | import torch |
| | import torch.nn.functional as F |
| | from typing import Dict, Optional, Union |
| |
|
| |
|
| | def get_parameter_groups(model: PentachoraViT, |
| | weight_decay: float = 0.05) -> List[Dict[str, Any]]: |
| | """Get parameter groups for optimizer with weight decay handling.""" |
| | no_decay = ['bias', 'norm', 'LayerNorm'] |
| |
|
| | decay_params = [] |
| | no_decay_params = [] |
| |
|
| | for name, param in model.named_parameters(): |
| | if not param.requires_grad: |
| | continue |
| |
|
| | if any(nd in name for nd in no_decay): |
| | no_decay_params.append(param) |
| | else: |
| | decay_params.append(param) |
| |
|
| | return [ |
| | {'params': decay_params, 'weight_decay': weight_decay}, |
| | {'params': no_decay_params, 'weight_decay': 0.0} |
| | ] |
| |
|
| | |
| | def count_parameters(model: nn.Module) -> Dict[str, int]: |
| | """Count model parameters.""" |
| | total = sum(p.numel() for p in model.parameters()) |
| | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| | return { |
| | 'total': total, |
| | 'trainable': trainable, |
| | 'non_trainable': total - trainable |
| | } |
| |
|
| |
|
| |
|
| | |
| | def get_default_device(): |
| | """Get the default device (CUDA if available, else CPU).""" |
| | return torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| |
|
| | class PentachoronStabilizer: |
| | """ |
| | Geometric constraint utilities for a 5-simplex (pentachoron). |
| | Includes Rose scoring for semantic alignment. |
| | """ |
| | |
| | @staticmethod |
| | def vertices_to_tensor(vertices): |
| | """Convert dict to tensor once, reuse everywhere.""" |
| | if isinstance(vertices, dict): |
| | return torch.stack([ |
| | vertices['anchor'], vertices['need'], |
| | vertices['relation'], vertices['purpose'], |
| | vertices['observer'] |
| | ], dim=1) |
| | return vertices |
| | |
| | @staticmethod |
| | def tensor_to_dict(verts): |
| | """Convert tensor [B, 5, D] back to dict.""" |
| | return { |
| | 'anchor': verts[:, 0], |
| | 'need': verts[:, 1], |
| | 'relation': verts[:, 2], |
| | 'purpose': verts[:, 3], |
| | 'observer': verts[:, 4] |
| | } |
| | |
| | @staticmethod |
| | def rose_score_magnitude( |
| | x: torch.Tensor, |
| | vertices: Union[Dict[str, torch.Tensor], torch.Tensor], |
| | eps: float = 1e-6 |
| | ) -> torch.Tensor: |
| | """ |
| | Compute Rose similarity score between x and pentachoron vertices. |
| | |
| | Args: |
| | x: Query tensor [B, T, D] or [B, D] |
| | vertices: Either dict or tensor [B, 5, D] |
| | eps: Small value for numerical stability |
| | |
| | Returns: |
| | scores: [B, T] or [B] depending on input shape |
| | """ |
| | |
| | squeeze_output = False |
| | if x.dim() == 2: |
| | x = x.unsqueeze(1) |
| | squeeze_output = True |
| | |
| | |
| | if not isinstance(vertices, dict): |
| | vertices = PentachoronStabilizer.tensor_to_dict(vertices) |
| | |
| | |
| | B, T, D = x.shape |
| | need = vertices['need'].unsqueeze(1).expand(-1, T, -1) |
| | relation = vertices['relation'].unsqueeze(1).expand(-1, T, -1) |
| | purpose = vertices['purpose'].unsqueeze(1).expand(-1, T, -1) |
| | |
| | |
| | x_n = F.normalize(x, dim=-1, eps=eps) |
| | n_n = F.normalize(need, dim=-1, eps=eps) |
| | r_n = F.normalize(relation, dim=-1, eps=eps) |
| | p_n = F.normalize(purpose, dim=-1, eps=eps) |
| | |
| | |
| | a_n = torch.cosine_similarity(x_n, n_n, dim=-1) |
| | a_r = torch.cosine_similarity(x_n, r_n, dim=-1) |
| | a_p = torch.cosine_similarity(x_n, p_n, dim=-1) |
| | |
| | |
| | r7 = (a_n + a_r + a_p) / 3.0 |
| | r8 = x.norm(dim=-1) |
| | |
| | score = r7 * r8 |
| | |
| | return score.squeeze(1) if squeeze_output else score |
| | |
| | @staticmethod |
| | def compute_gram_matrix(verts): |
| | """Compute Gram matrix for batch of vertices.""" |
| | return torch.bmm(verts, verts.transpose(-2, -1)) |
| | |
| | @staticmethod |
| | def cayley_menger_determinant(verts): |
| | """Compute Cayley-Menger determinant (vectorized).""" |
| | B = verts.shape[0] |
| | |
| | gram = torch.bmm(verts, verts.transpose(-2, -1)) |
| | diag = gram.diagonal(dim1=-2, dim2=-1).unsqueeze(-1) |
| | dist_sq = diag + diag.transpose(-2, -1) - 2 * gram |
| | |
| | cm = torch.zeros(B, 6, 6, device=verts.device) |
| | cm[:, 0, 1:] = 1 |
| | cm[:, 1:, 0] = 1 |
| | cm[:, 1:, 1:] = dist_sq |
| | |
| | return torch.det(cm) |
| | |
| | @staticmethod |
| | def enforce_regular_simplex(verts): |
| | """Compute edge length variance (fully vectorized).""" |
| | diff = verts.unsqueeze(2) - verts.unsqueeze(1) |
| | dist = torch.norm(diff, dim=-1) |
| | |
| | triu_indices = torch.triu_indices(5, 5, offset=1) |
| | edges = dist[:, triu_indices[0], triu_indices[1]] |
| | |
| | return torch.var(edges, dim=-1) |
| | |
| | @staticmethod |
| | def orthoplex_projection(verts): |
| | """Project to unit hypersphere, centered.""" |
| | verts_norm = F.normalize(verts, dim=-1) |
| | center = verts_norm.mean(dim=1, keepdim=True) |
| | verts_centered = verts_norm - center |
| | return F.normalize(verts_centered, dim=-1) |
| | |
| | @staticmethod |
| | def apply( |
| | vertices, |
| | cayley_target: float = 1.0, |
| | return_dict: bool = False, |
| | compute_rose_scores: Optional[torch.Tensor] = None |
| | ): |
| | """ |
| | Apply all constraints and return stable vertices + losses. |
| | |
| | Args: |
| | vertices: Either dict or tensor [B, 5, D] |
| | cayley_target: Target Cayley-Menger determinant |
| | return_dict: If True and input was dict, return dict |
| | compute_rose_scores: Optional tensor to compute Rose scores against |
| | |
| | Returns: |
| | vertices_stable: Stabilized vertices |
| | losses: Dict of loss components (includes rose_scores if requested) |
| | """ |
| | was_dict = isinstance(vertices, dict) |
| | verts = PentachoronStabilizer.vertices_to_tensor(vertices) |
| | |
| | |
| | cm_det = PentachoronStabilizer.cayley_menger_determinant(verts) |
| | validity_loss = torch.abs(cm_det - cayley_target).mean() |
| | regularity_loss = PentachoronStabilizer.enforce_regular_simplex(verts).mean() |
| | |
| | |
| | verts_stable = PentachoronStabilizer.orthoplex_projection(verts) |
| | |
| | |
| | gram = PentachoronStabilizer.compute_gram_matrix(verts_stable) |
| | gram_entropy = -torch.sum(gram * torch.log(torch.abs(gram) + 1e-8)) / (verts.shape[0] * 25) |
| | |
| | losses = { |
| | 'validity': validity_loss, |
| | 'regularity': regularity_loss, |
| | 'gram_entropy': gram_entropy |
| | } |
| | |
| | |
| | if compute_rose_scores is not None: |
| | rose_scores = PentachoronStabilizer.rose_score_magnitude( |
| | compute_rose_scores, |
| | verts_stable |
| | ) |
| | losses['rose_scores'] = rose_scores |
| | |
| | |
| | if was_dict and return_dict: |
| | verts_stable = PentachoronStabilizer.tensor_to_dict(verts_stable) |
| | |
| | return verts_stable, losses |