"""Collapse monitor: track latent variance, effective rank, cross-modal cosine sim. Hard-stop criterion (per RESEARCH_DEVELOPMENT.md Pitfall 3): mean cosine sim > 0.99 for 500 consecutive logged steps -> abort """ from __future__ import annotations import collections from dataclasses import dataclass, field import torch def effective_rank(z: torch.Tensor, eps: float = 1e-9) -> float: """Entropy-based effective rank of the covariance matrix.""" z = z - z.mean(dim=0, keepdim=True) cov = (z.t() @ z) / max(z.shape[0] - 1, 1) eig = torch.linalg.eigvalsh(cov.float()) eig = torch.clamp(eig, min=0) total = eig.sum() + eps p = eig / total entropy = -(p * torch.log(p + eps)).sum() return float(torch.exp(entropy).item()) def cross_modal_cosine(z_a: torch.Tensor, z_b: torch.Tensor) -> float: a = torch.nn.functional.normalize(z_a, dim=-1) b = torch.nn.functional.normalize(z_b, dim=-1) return float((a * b).sum(dim=-1).mean().item()) @dataclass class CollapseMonitor: window: int = 500 threshold: float = 0.99 history: collections.deque = field(default_factory=lambda: collections.deque(maxlen=500)) def update(self, cosine: float) -> bool: self.history.append(cosine) if len(self.history) < self.window: return False return all(c > self.threshold for c in self.history)