| """Collapse monitor: track latent variance, effective rank, cross-modal cosine sim. |
| |
| Hard-stop criterion (per RESEARCH_DEVELOPMENT.md Pitfall 3): |
| mean cosine sim > 0.99 for 500 consecutive logged steps -> abort |
| """ |
| from __future__ import annotations |
|
|
| import collections |
| from dataclasses import dataclass, field |
|
|
| import torch |
|
|
|
|
def effective_rank(z: torch.Tensor, eps: float = 1e-9) -> float:
    """Entropy-based effective rank of the sample covariance of ``z``.

    The covariance eigenvalues are clamped at zero and normalized into a
    probability vector ``p``. The result is ``exp(H(p))``, where ``H`` is the
    Shannon entropy. The value is 1 when all variance lies along one direction
    (or when there is no variance at all). It grows toward ``dim`` as variance
    spreads evenly.

    Args:
        z: Sample matrix. Rows are treated as observations, since the mean
            is taken over dim 0. Presumably shaped (num_samples, dim).
        eps: Small constant added to the eigenvalue total and inside the log.
            It keeps both finite when eigenvalues are zero.

    Returns:
        The effective rank as a Python float.
    """
    # Detach: this is a monitoring metric, so no autograd graph is needed.
    # Upcast to at least float32 *before* centering and the matmul. Otherwise
    # fp16/bf16 inputs accumulate the covariance in low precision and can
    # overflow to inf. float32/float64 inputs keep their dtype unchanged.
    work_dtype = torch.promote_types(z.dtype, torch.float32)
    z = z.detach().to(work_dtype)
    z = z - z.mean(dim=0, keepdim=True)
    # Unbiased (n - 1) normalization, guarded so a single sample avoids /0.
    cov = (z.t() @ z) / max(z.shape[0] - 1, 1)
    # The covariance is symmetric, so use eigvalsh. Round-off can yield tiny
    # negative eigenvalues; clamp them so the distribution below is valid.
    eig = torch.linalg.eigvalsh(cov.float())
    eig = torch.clamp(eig, min=0)
    total = eig.sum() + eps
    p = eig / total
    entropy = -(p * torch.log(p + eps)).sum()
    return float(torch.exp(entropy).item())
|
|
|
|
def cross_modal_cosine(z_a: torch.Tensor, z_b: torch.Tensor) -> float:
    """Average cosine similarity between paired vectors of two tensors.

    Both tensors are L2-normalized along their last dimension. Corresponding
    vectors are dotted, and the per-pair similarities are averaged into a
    single Python float. Zero vectors normalize to zero and so contribute a
    similarity of 0.

    Args:
        z_a: First set of vectors, compared along the last dimension.
        z_b: Second set of vectors. Presumably the same shape as ``z_a``;
            mismatched shapes follow ordinary broadcasting rules.

    Returns:
        The mean cosine similarity over all paired positions.
    """
    nnf = torch.nn.functional
    unit_a = nnf.normalize(z_a, dim=-1)
    unit_b = nnf.normalize(z_b, dim=-1)
    per_pair = torch.sum(unit_a * unit_b, dim=-1)
    mean_sim = per_pair.mean()
    return float(mean_sim.item())
|
|
|
|
@dataclass
class CollapseMonitor:
    """Detect collapse from a stream of cosine-similarity readings.

    ``update`` returns True once the most recent ``window`` readings are all
    strictly greater than ``threshold``. The defaults are 500 readings above
    0.99. A reading at or below the threshold breaks the streak, and so does
    NaN, since ``NaN > x`` is False. That reading stays in the window until
    ``window`` newer readings push it out.
    """

    # Number of consecutive readings that must exceed ``threshold``.
    window: int = 500
    # A reading must be strictly greater than this to count toward collapse.
    threshold: float = 0.99
    # Most recent readings; bounded to ``window`` entries in __post_init__.
    history: collections.deque = field(default_factory=collections.deque)

    def __post_init__(self) -> None:
        if self.window < 1:
            # window=0 would make ``all([])`` flag collapse on every step.
            raise ValueError(f"window must be >= 1, got {self.window}")
        # The deque bound must equal ``window``. A bound smaller than
        # ``window`` can never fill far enough to trigger. A larger bound
        # demands a longer streak than configured. Rebuilding also trims any
        # caller-supplied history to its newest ``window`` entries.
        self.history = collections.deque(self.history, maxlen=self.window)

    def update(self, cosine: float) -> bool:
        """Record one reading and report whether the collapse criterion holds.

        Args:
            cosine: Latest cosine-similarity reading.

        Returns:
            True if ``window`` readings have been collected and every one of
            them exceeds ``threshold``; otherwise False.
        """
        self.history.append(cosine)
        if len(self.history) < self.window:
            return False
        return all(c > self.threshold for c in self.history)
|
|