# penta-vit-experiments / pentachora_stabilizer.py
# Author: AbstractPhil — commit 279d83a ("Create pentachora_stabilizer.py")
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict, List, Optional, Union
def get_parameter_groups(model: "PentachoraViT",
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """Split model parameters into decay / no-decay optimizer groups.

    Biases and normalization weights are conventionally exempt from weight
    decay; every other trainable parameter decays at ``weight_decay``.

    Args:
        model: Model whose parameters are grouped. The annotation is quoted
            so the project type need not be defined at import time.
        weight_decay: Decay coefficient applied to the decaying group.

    Returns:
        Two optimizer param-group dicts: the first with ``weight_decay``
        applied, the second with weight decay disabled. Frozen parameters
        (``requires_grad == False``) are skipped entirely.
    """
    # Substring match on parameter names marks the decay-exempt set.
    no_decay = ('bias', 'norm', 'LayerNorm')
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen params never reach the optimizer
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Tally total, trainable, and frozen parameter counts for *model*.

    Returns:
        Dict with keys ``'total'``, ``'trainable'``, and ``'non_trainable'``.
    """
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable,
    }
# Utility: default compute-device selection.
def get_default_device():
    """Return the preferred compute device: CUDA when present, else CPU."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
class PentachoronStabilizer:
    """
    Geometric constraint utilities for a pentachoron (the 4-simplex: five
    vertices embedded in a D-dimensional feature space).

    All methods are stateless static methods operating on batched vertex
    tensors of shape [B, 5, D], or on the equivalent role-keyed dict with
    keys 'anchor', 'need', 'relation', 'purpose', 'observer'.
    Includes Rose scoring for semantic alignment.
    """

    # Canonical ordering of the five semantic vertex roles; fixes the
    # dict <-> tensor index correspondence used throughout.
    _VERTEX_KEYS = ('anchor', 'need', 'relation', 'purpose', 'observer')

    @staticmethod
    def vertices_to_tensor(vertices):
        """Convert a role-keyed dict of [B, D] tensors into one [B, 5, D]
        tensor; tensor inputs pass through unchanged."""
        if isinstance(vertices, dict):
            return torch.stack(
                [vertices[k] for k in PentachoronStabilizer._VERTEX_KEYS],
                dim=1,
            )  # [B, 5, D]
        return vertices

    @staticmethod
    def tensor_to_dict(verts):
        """Convert a [B, 5, D] vertex tensor back to the role-keyed dict."""
        keys = PentachoronStabilizer._VERTEX_KEYS
        return {key: verts[:, i] for i, key in enumerate(keys)}

    @staticmethod
    def rose_score_magnitude(
        x: torch.Tensor,
        vertices: Union[Dict[str, torch.Tensor], torch.Tensor],
        eps: float = 1e-6
    ) -> torch.Tensor:
        """
        Compute Rose similarity score between x and pentachoron vertices.

        The score is the mean cosine alignment of x with the 'need',
        'relation', and 'purpose' vertices, scaled by the unnormalized
        magnitude of x.

        Args:
            x: Query tensor [B, T, D] or [B, D]
            vertices: Either dict or tensor [B, 5, D]
            eps: Small value for numerical stability in normalization

        Returns:
            scores: [B, T] or [B] depending on input shape
        """
        # Promote [B, D] to [B, 1, D] so one code path handles both shapes.
        squeeze_output = False
        if x.dim() == 2:
            x = x.unsqueeze(1)  # [B, 1, D]
            squeeze_output = True
        # Work in dict form so vertices can be addressed by role.
        if not isinstance(vertices, dict):
            vertices = PentachoronStabilizer.tensor_to_dict(vertices)
        # Broadcast each role vertex across the sequence dimension.
        B, T, D = x.shape
        need = vertices['need'].unsqueeze(1).expand(-1, T, -1)
        relation = vertices['relation'].unsqueeze(1).expand(-1, T, -1)
        purpose = vertices['purpose'].unsqueeze(1).expand(-1, T, -1)
        # Normalize all inputs before the cosine comparisons.
        x_n = F.normalize(x, dim=-1, eps=eps)
        n_n = F.normalize(need, dim=-1, eps=eps)
        r_n = F.normalize(relation, dim=-1, eps=eps)
        p_n = F.normalize(purpose, dim=-1, eps=eps)
        # Core directional cosine components.
        a_n = torch.cosine_similarity(x_n, n_n, dim=-1)
        a_r = torch.cosine_similarity(x_n, r_n, dim=-1)
        a_p = torch.cosine_similarity(x_n, p_n, dim=-1)
        # Triadic mean alignment, scaled by the raw magnitude of x.
        r7 = (a_n + a_r + a_p) / 3.0
        r8 = x.norm(dim=-1)
        score = r7 * r8
        return score.squeeze(1) if squeeze_output else score

    @staticmethod
    def compute_gram_matrix(verts):
        """Batched Gram matrix <v_i, v_j>: [B, 5, D] -> [B, 5, 5]."""
        return torch.bmm(verts, verts.transpose(-2, -1))

    @staticmethod
    def cayley_menger_determinant(verts):
        """Compute the Cayley-Menger determinant for each batch element.

        The 6x6 bordered matrix is built from squared pairwise distances;
        its determinant is proportional to the squared 4-volume of the
        simplex, so it vanishes for degenerate configurations.
        """
        B = verts.shape[0]
        gram = torch.bmm(verts, verts.transpose(-2, -1))
        diag = gram.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)
        # ||v_i - v_j||^2 = ||v_i||^2 + ||v_j||^2 - 2 <v_i, v_j>
        dist_sq = diag + diag.transpose(-2, -1) - 2 * gram
        # Match the input dtype so float64 / half vertices don't hit a
        # dtype mismatch (the original always allocated float32 here).
        cm = torch.zeros(B, 6, 6, device=verts.device, dtype=verts.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = dist_sq
        return torch.det(cm)

    @staticmethod
    def enforce_regular_simplex(verts):
        """Variance of the 10 pairwise edge lengths (0 for a regular simplex)."""
        diff = verts.unsqueeze(2) - verts.unsqueeze(1)  # [B, 5, 5, D]
        dist = torch.norm(diff, dim=-1)                 # [B, 5, 5]
        # Upper-triangular indices select each edge once; build them on the
        # vertices' device to avoid an implicit host->device transfer.
        triu = torch.triu_indices(5, 5, offset=1, device=verts.device)
        edges = dist[:, triu[0], triu[1]]               # [B, 10]
        return torch.var(edges, dim=-1)

    @staticmethod
    def orthoplex_projection(verts):
        """Project vertices onto the unit hypersphere, re-centered at the mean."""
        verts_norm = F.normalize(verts, dim=-1)
        center = verts_norm.mean(dim=1, keepdim=True)
        verts_centered = verts_norm - center
        return F.normalize(verts_centered, dim=-1)

    @staticmethod
    def apply(
        vertices,
        cayley_target: float = 1.0,
        return_dict: bool = False,
        compute_rose_scores: Optional[torch.Tensor] = None
    ):
        """
        Apply all constraints and return stable vertices + losses.

        Args:
            vertices: Either dict or tensor [B, 5, D]
            cayley_target: Target Cayley-Menger determinant
            return_dict: If True and input was dict, return dict
            compute_rose_scores: Optional tensor to compute Rose scores against

        Returns:
            vertices_stable: Stabilized vertices
            losses: Dict of loss components (includes rose_scores if requested)
        """
        was_dict = isinstance(vertices, dict)
        verts = PentachoronStabilizer.vertices_to_tensor(vertices)
        # Geometric losses are measured on the *raw* vertices.
        cm_det = PentachoronStabilizer.cayley_menger_determinant(verts)
        validity_loss = torch.abs(cm_det - cayley_target).mean()
        regularity_loss = PentachoronStabilizer.enforce_regular_simplex(verts).mean()
        # Stabilize vertices onto the centered unit hypersphere.
        verts_stable = PentachoronStabilizer.orthoplex_projection(verts)
        # Entropy-style penalty over the Gram matrix of the stabilized
        # vertices, averaged over batch and the 25 Gram entries.
        gram = PentachoronStabilizer.compute_gram_matrix(verts_stable)
        gram_entropy = -torch.sum(gram * torch.log(torch.abs(gram) + 1e-8)) / (verts.shape[0] * 25)
        losses = {
            'validity': validity_loss,
            'regularity': regularity_loss,
            'gram_entropy': gram_entropy
        }
        # Rose scores are computed against the stabilized vertices.
        if compute_rose_scores is not None:
            losses['rose_scores'] = PentachoronStabilizer.rose_score_magnitude(
                compute_rose_scores,
                verts_stable
            )
        # Round-trip back to dict form only when the caller asked for it
        # AND originally supplied a dict.
        if was_dict and return_dict:
            verts_stable = PentachoronStabilizer.tensor_to_dict(verts_stable)
        return verts_stable, losses