PhysioJEPA / src /physiojepa /monitor.py
guychuk's picture
Upload folder using huggingface_hub
31e2456 verified
"""Collapse monitor: track latent variance, effective rank, cross-modal cosine sim.
Hard-stop criterion (per RESEARCH_DEVELOPMENT.md Pitfall 3):
mean cosine sim > 0.99 for 500 consecutive logged steps -> abort
"""
from __future__ import annotations
import collections
from dataclasses import dataclass, field
import torch
def effective_rank(z: torch.Tensor, eps: float = 1e-9) -> float:
    """Entropy-based effective rank of the covariance matrix of ``z``.

    Computes ``exp(H(p))``, where ``p`` is the eigenvalue spectrum of the
    sample covariance normalized to sum to one. The result ranges from 1
    (all variance along one direction, or no variance at all) up to the
    feature dimension (variance spread evenly over all directions).

    Args:
        z: Embeddings with samples along the leading dimension(s) and features
            along the last one. 2-D input is treated as ``(N, D)``; higher-rank
            input (e.g. ``(B, T, D)``) is flattened to ``(B*T, D)``.
        eps: Small constant guarding the spectrum normalization and the log.

    Returns:
        The effective rank as a Python float.

    Raises:
        ValueError: If ``z`` has fewer than two dimensions.
    """
    if z.dim() < 2:
        raise ValueError(
            f"effective_rank expects at least 2-D input, got shape {tuple(z.shape)}"
        )
    # Monitoring only: detach so no autograd graph is built through eigvalsh,
    # and upcast *before* the matmul so fp16/bf16 embeddings don't lose
    # precision (or overflow) while the covariance is being formed.
    z = z.detach().reshape(-1, z.shape[-1])
    if z.dtype != torch.float64:
        z = z.float()
    z = z - z.mean(dim=0, keepdim=True)
    cov = (z.t() @ z) / max(z.shape[0] - 1, 1)
    eig = torch.linalg.eigvalsh(cov)
    # A PSD matrix can still yield tiny negative eigenvalues from round-off.
    eig = torch.clamp(eig, min=0)
    p = eig / (eig.sum() + eps)
    entropy = -(p * torch.log(p + eps)).sum()
    return float(torch.exp(entropy).item())
def cross_modal_cosine(z_a: torch.Tensor, z_b: torch.Tensor) -> float:
    """Mean cosine similarity between paired rows of two embedding tensors.

    Each vector along the last dimension of ``z_a`` is compared with the
    vector at the same position in ``z_b``; the per-pair similarities are
    then averaged into a single Python float.
    """
    to_unit = torch.nn.functional.normalize
    unit_a = to_unit(z_a, dim=-1)
    unit_b = to_unit(z_b, dim=-1)
    per_pair = torch.sum(unit_a * unit_b, dim=-1)
    return float(per_pair.mean().item())
@dataclass
class CollapseMonitor:
    """Hard-stop detector for representational collapse.

    Feed it one cross-modal cosine similarity per logged step via
    :meth:`update`. It signals an abort once the most recent ``window``
    logged values have all been strictly greater than ``threshold``, i.e.
    ``window`` consecutive collapsed steps.

    Attributes:
        window: Number of consecutive logged steps that must exceed
            ``threshold`` before aborting. Must be >= 1.
        threshold: Cosine-similarity level above which a step counts as
            collapsed.
        history: The most recent logged values, bounded to ``window`` entries.
    """

    window: int = 500
    threshold: float = 0.99
    # Re-bounded to ``window`` entries in __post_init__.
    history: collections.deque = field(default_factory=collections.deque)

    def __post_init__(self) -> None:
        if self.window < 1:
            raise ValueError(f"window must be >= 1, got {self.window}")
        # The buffer length must equal ``window``: a longer buffer would demand
        # more than ``window`` consecutive hits, and a shorter one could never
        # fill up, so the hard stop would never fire.
        if getattr(self.history, "maxlen", None) != self.window:
            self.history = collections.deque(self.history, maxlen=self.window)

    def update(self, cosine: float) -> bool:
        """Record one logged cosine similarity and report whether to abort.

        Args:
            cosine: Mean cross-modal cosine similarity for this logged step.

        Returns:
            True once the last ``window`` values are all strictly greater than
            ``threshold``; False otherwise, including while fewer than
            ``window`` values have been recorded.
        """
        self.history.append(cosine)
        if len(self.history) < self.window:
            return False
        return all(c > self.threshold for c in self.history)