Spaces:
Running on Zero
Running on Zero
| """Temporal Invariant Verifier (TIV) for MANIFOLD. | |
| Domain adversarial network for temporal invariance via gradient reversal. | |
| Ensures features are invariant to temporal context (early/late game, etc.). | |
| """ | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Dict, Any, TYPE_CHECKING | |
| from torch.autograd import Function | |
| if TYPE_CHECKING: | |
| from manifold.config import ModelConfig | |
class GradientReversalFunction(Function):
    """Autograd function implementing gradient reversal.

    The forward pass is an identity (it returns a copy of the input); the
    backward pass multiplies the incoming gradient by ``-lambda_``. Use it
    through ``GradientReversalFunction.apply(x, lambda_)``.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # ``ctx`` is supplied by autograd, not bound as an instance.
    @staticmethod
    def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor:
        """Return a copy of ``x`` and remember ``lambda_`` for the backward pass.

        Args:
            ctx: Autograd context object used to carry state to ``backward``.
            x: Input tensor.
            lambda_: Scale applied (with a sign flip) to the gradient in
                ``backward``.

        Returns:
            A clone of ``x``.
        """
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Reverse and scale the incoming gradient.

        Args:
            ctx: Autograd context holding ``lambda_`` stored by ``forward``.
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            A pair ``(grad_x, None)``: ``-lambda_ * grad_output`` for ``x``
            and ``None`` for ``lambda_``, which receives no gradient.
        """
        return -ctx.lambda_ * grad_output, None
class GradientReversalLayer(nn.Module):
    """``nn.Module`` front-end for ``GradientReversalFunction``.

    Args:
        lambda_: Value forwarded to ``GradientReversalFunction.apply`` as its
            second argument on every call (default 1.0).
    """

    def __init__(self, lambda_: float = 1.0):
        super().__init__()
        # Plain attribute (not a Parameter/buffer), so it can be reassigned
        # between steps without touching the state dict.
        self.lambda_ = lambda_

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the gradient reversal function.

        Args:
            x: Input tensor.

        Returns:
            The result of ``GradientReversalFunction.apply(x, self.lambda_)``.
        """
        scale = self.lambda_
        reversed_x = GradientReversalFunction.apply(x, scale)
        return reversed_x
class DomainClassifier(nn.Module):
    """Two-layer MLP head that predicts the temporal domain of a feature vector.

    The network is ``Linear -> GELU -> Linear`` and emits unnormalized logits.

    Args:
        input_dim: Dimension of input features (default 256)
        hidden_dim: Width of the hidden layer (default 128)
        num_domains: Number of temporal domains to classify (default 4)
    """

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        num_domains: int = 4,
    ):
        super().__init__()

        # Keep the constructor arguments available for introspection.
        self.input_dim, self.hidden_dim, self.num_domains = (
            input_dim,
            hidden_dim,
            num_domains,
        )

        # Layers are instantiated in forward order so that parameter
        # initialization draws from the RNG in a fixed sequence.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, num_domains),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute domain logits for a batch of features.

        Args:
            x: Input features [batch, input_dim]

        Returns:
            Domain logits [batch, num_domains]
        """
        logits = self.net(x)
        return logits
class TemporalInvariantVerifier(nn.Module):
    """Domain adversarial network for temporal invariance.

    Features are passed through a linear transform, mean-pooled over the
    sequence axis, sent through a gradient reversal layer and then through a
    domain classifier. Because the gradient is reversed between the pooled
    features and the classifier, minimizing the domain loss trains the
    classifier normally while pushing the transformed features toward being
    uninformative about the temporal domain (early/late game, etc.).

    Args:
        input_dim: Dimension of input features (default 256)
        hidden_dim: Hidden dimension for domain classifier (default 128)
        num_domains: Number of temporal domains (default 4)
        adversarial_lambda: Scaling factor for gradient reversal (default 0.1)
    """

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        num_domains: int = 4,
        adversarial_lambda: float = 0.1,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_domains = num_domains
        self.adversarial_lambda = adversarial_lambda

        # Shape-preserving projection applied to every position of the input.
        self.feature_transform = nn.Linear(input_dim, input_dim)
        # NOTE: the layer keeps its own copy of lambda; changing
        # ``self.adversarial_lambda`` later does not update it.
        self.gradient_reversal = GradientReversalLayer(lambda_=adversarial_lambda)
        self.domain_classifier = DomainClassifier(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_domains=num_domains,
        )

    def forward(
        self,
        x: torch.Tensor,
        domain_labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward pass through temporal invariant verifier.

        Args:
            x: Input features [batch, seq, input_dim]
            domain_labels: Ground truth domains [batch] (for loss computation)

        Returns:
            Dict with:
            - "output": transformed features (same shape as input)
            - "domain_logits": domain predictions [batch, num_domains]
            - "domain_loss": cross-entropy loss if labels provided, else 0

        Raises:
            ValueError: If ``x`` is not a 3-dimensional tensor.
        """
        # Explicit rank check: mean-pooling over dim=1 below is only
        # meaningful for [batch, seq, dim] inputs.
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape [batch, seq, input_dim], "
                f"got tensor with shape {tuple(x.shape)}"
            )

        output = self.feature_transform(x)

        # Collapse the sequence axis so the classifier sees one vector per sample.
        pooled = output.mean(dim=1)

        # Reversal sits between the features and the classifier, so only the
        # gradient flowing back into the features is flipped.
        reversed_features = self.gradient_reversal(pooled)
        domain_logits = self.domain_classifier(reversed_features)

        if domain_labels is not None:
            domain_loss = F.cross_entropy(domain_logits, domain_labels)
        else:
            # Scalar zero on the input's device so callers can always add it
            # to a total loss.
            domain_loss = torch.tensor(0.0, device=x.device)

        return {
            "output": output,
            "domain_logits": domain_logits,
            "domain_loss": domain_loss,
        }

    @classmethod
    def from_config(cls, config: "ModelConfig") -> "TemporalInvariantVerifier":
        """Create TemporalInvariantVerifier from ModelConfig.

        Reads ``embed_dim``, ``num_domains`` and ``adversarial_lambda`` from
        ``config``; the classifier hidden size is ``embed_dim // 2``.

        Args:
            config: Model configuration object

        Returns:
            Configured TemporalInvariantVerifier instance
        """
        return cls(
            input_dim=config.embed_dim,
            hidden_dim=config.embed_dim // 2,
            num_domains=config.num_domains,
            adversarial_lambda=config.adversarial_lambda,
        )