|
|
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
|
|
|
def get_parameter_groups(model: nn.Module,
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """Build optimizer parameter groups with selective weight decay.

    Parameters whose names contain 'bias', 'norm', or 'LayerNorm' are
    exempt from weight decay (standard practice for transformer training);
    everything else gets the supplied decay. Frozen parameters
    (``requires_grad == False``) are excluded entirely.

    Args:
        model: Any ``nn.Module`` whose parameters should be grouped.
        weight_decay: Decay applied to the non-exempt group.

    Returns:
        Two param-group dicts suitable for a torch optimizer:
        ``[{'params': ..., 'weight_decay': weight_decay},
           {'params': ..., 'weight_decay': 0.0}]``
    """
    # Substring match on parameter names — covers e.g. 'blocks.0.norm1.weight'.
    no_decay = ['bias', 'norm', 'LayerNorm']

    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
|
|
|
|
|
|
|
|
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Tally the model's parameters, split by trainability.

    Returns:
        Dict with keys 'total', 'trainable', and 'non_trainable'
        (element counts, not tensor counts).
    """
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_default_device():
    """Return the preferred compute device: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
|
|
|
|
|
class PentachoronStabilizer:
    """
    Geometric constraint utilities for a pentachoron (the 4-simplex:
    5 vertices, 10 edges).

    All methods are stateless ``@staticmethod``s operating on batched
    vertex sets, given either as a dict of five named vertices
    ('anchor', 'need', 'relation', 'purpose', 'observer'), each [B, D],
    or as a single tensor [B, 5, D]. Includes Rose scoring for semantic
    alignment against the 'need'/'relation'/'purpose' vertices.
    """

    @staticmethod
    def vertices_to_tensor(vertices):
        """Convert a vertex dict to a [B, 5, D] tensor; pass tensors through."""
        if isinstance(vertices, dict):
            # Fixed canonical order; tensor_to_dict inverts this exactly.
            return torch.stack([
                vertices['anchor'], vertices['need'],
                vertices['relation'], vertices['purpose'],
                vertices['observer']
            ], dim=1)
        return vertices

    @staticmethod
    def tensor_to_dict(verts):
        """Convert tensor [B, 5, D] back to the named-vertex dict."""
        return {
            'anchor': verts[:, 0],
            'need': verts[:, 1],
            'relation': verts[:, 2],
            'purpose': verts[:, 3],
            'observer': verts[:, 4]
        }

    @staticmethod
    def rose_score_magnitude(
        x: torch.Tensor,
        vertices: Union[Dict[str, torch.Tensor], torch.Tensor],
        eps: float = 1e-6
    ) -> torch.Tensor:
        """
        Compute Rose similarity score between x and pentachoron vertices.

        The score is the mean cosine alignment of x with the 'need',
        'relation' and 'purpose' vertices, scaled by the (unnormalized)
        magnitude of x.

        Args:
            x: Query tensor [B, T, D] or [B, D]
            vertices: Either dict or tensor [B, 5, D]
            eps: Small value for numerical stability

        Returns:
            scores: [B, T] or [B] depending on input shape
        """
        squeeze_output = False
        if x.dim() == 2:
            # Treat [B, D] as a length-1 sequence; squeeze back on return.
            x = x.unsqueeze(1)
            squeeze_output = True

        if not isinstance(vertices, dict):
            vertices = PentachoronStabilizer.tensor_to_dict(vertices)

        B, T, D = x.shape
        # Broadcast each [B, D] vertex across the sequence dimension.
        need = vertices['need'].unsqueeze(1).expand(-1, T, -1)
        relation = vertices['relation'].unsqueeze(1).expand(-1, T, -1)
        purpose = vertices['purpose'].unsqueeze(1).expand(-1, T, -1)

        x_n = F.normalize(x, dim=-1, eps=eps)
        n_n = F.normalize(need, dim=-1, eps=eps)
        r_n = F.normalize(relation, dim=-1, eps=eps)
        p_n = F.normalize(purpose, dim=-1, eps=eps)

        a_n = torch.cosine_similarity(x_n, n_n, dim=-1)
        a_r = torch.cosine_similarity(x_n, r_n, dim=-1)
        a_p = torch.cosine_similarity(x_n, p_n, dim=-1)

        # r7: mean directional alignment in [-1, 1]; r8: raw magnitude of x.
        r7 = (a_n + a_r + a_p) / 3.0
        r8 = x.norm(dim=-1)

        score = r7 * r8

        return score.squeeze(1) if squeeze_output else score

    @staticmethod
    def compute_gram_matrix(verts):
        """Batched Gram matrix [B, 5, 5] of the vertex rows."""
        return torch.bmm(verts, verts.transpose(-2, -1))

    @staticmethod
    def cayley_menger_determinant(verts):
        """Compute the Cayley-Menger determinant for each batch element.

        Builds the bordered squared-distance matrix (6x6 for 5 points)
        and returns its determinant [B]. The determinant encodes the
        (squared) content of the simplex, so it vanishes for degenerate
        vertex sets.
        """
        B = verts.shape[0]

        gram = torch.bmm(verts, verts.transpose(-2, -1))
        diag = gram.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)
        # ||vi - vj||^2 = ||vi||^2 + ||vj||^2 - 2<vi, vj>
        dist_sq = diag + diag.transpose(-2, -1) - 2 * gram

        # Preserve float64 precision when given double inputs; anything
        # else (e.g. half, which torch.det does not support) is computed
        # in float32, matching the original implicit upcast behavior.
        cm_dtype = verts.dtype if verts.dtype in (torch.float32, torch.float64) else torch.float32
        cm = torch.zeros(B, 6, 6, device=verts.device, dtype=cm_dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = dist_sq

        return torch.det(cm)

    @staticmethod
    def enforce_regular_simplex(verts):
        """Edge-length variance [B]; zero iff all 10 edges are equal."""
        # Pairwise differences [B, 5, 5, D] -> distances [B, 5, 5].
        diff = verts.unsqueeze(2) - verts.unsqueeze(1)
        dist = torch.norm(diff, dim=-1)

        # Keep each edge once: strict upper triangle of the 5x5 matrix.
        triu_indices = torch.triu_indices(5, 5, offset=1)
        edges = dist[:, triu_indices[0], triu_indices[1]]

        return torch.var(edges, dim=-1)

    @staticmethod
    def orthoplex_projection(verts):
        """Project vertices onto the unit hypersphere, re-centered at the origin."""
        verts_norm = F.normalize(verts, dim=-1)
        # Subtract the centroid so the vertex cloud is origin-centered,
        # then renormalize each vertex back to unit length.
        center = verts_norm.mean(dim=1, keepdim=True)
        verts_centered = verts_norm - center
        return F.normalize(verts_centered, dim=-1)

    @staticmethod
    def apply(
        vertices,
        cayley_target: float = 1.0,
        return_dict: bool = False,
        compute_rose_scores: Optional[torch.Tensor] = None
    ):
        """
        Apply all constraints and return stable vertices + losses.

        Args:
            vertices: Either dict or tensor [B, 5, D]
            cayley_target: Target Cayley-Menger determinant
            return_dict: If True and input was dict, return dict
            compute_rose_scores: Optional tensor to compute Rose scores against

        Returns:
            vertices_stable: Stabilized vertices
            losses: Dict of loss components (includes rose_scores if requested)
        """
        was_dict = isinstance(vertices, dict)
        verts = PentachoronStabilizer.vertices_to_tensor(vertices)

        # Geometric penalties are measured on the raw, pre-projection vertices.
        cm_det = PentachoronStabilizer.cayley_menger_determinant(verts)
        validity_loss = torch.abs(cm_det - cayley_target).mean()
        regularity_loss = PentachoronStabilizer.enforce_regular_simplex(verts).mean()

        verts_stable = PentachoronStabilizer.orthoplex_projection(verts)

        # Entropy-style penalty on the stabilized Gram matrix
        # (normalizer 25 = the 5x5 Gram entries per batch element).
        gram = PentachoronStabilizer.compute_gram_matrix(verts_stable)
        gram_entropy = -torch.sum(gram * torch.log(torch.abs(gram) + 1e-8)) / (verts.shape[0] * 25)

        losses = {
            'validity': validity_loss,
            'regularity': regularity_loss,
            'gram_entropy': gram_entropy
        }

        if compute_rose_scores is not None:
            # Rose scores are measured against the stabilized vertices.
            rose_scores = PentachoronStabilizer.rose_score_magnitude(
                compute_rose_scores,
                verts_stable
            )
            losses['rose_scores'] = rose_scores

        if was_dict and return_dict:
            verts_stable = PentachoronStabilizer.tensor_to_dict(verts_stable)

        return verts_stable, losses