# Uploaded by LimmeDev: "Initial MANIFOLD upload - CS2 cheat detection training"
# (commit 454ecdd, verified).
"""Temporal Invariant Verifier (TIV) for MANIFOLD.
Domain adversarial network for temporal invariance via gradient reversal.
Ensures features are invariant to temporal context (early/late game, etc.).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, TYPE_CHECKING
from torch.autograd import Function
if TYPE_CHECKING:
from manifold.config import ModelConfig
class GradientReversalFunction(Function):
    """Autograd op that is the identity going forward and flips gradients going back.

    Used for domain-adversarial training. The forward pass returns an unchanged
    copy of ``x``. The backward pass multiplies the incoming gradient by
    ``-lambda_``. ``lambda_`` is treated as a constant and receives no gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambda_: float) -> torch.Tensor:
        # Stash the scale so backward can use it.
        ctx.lambda_ = lambda_
        copied = x.clone()
        return copied

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        reversed_grad = grad_output * -ctx.lambda_
        # Return one gradient slot per forward input: x gets the reversed
        # gradient, and the non-tensor lambda_ gets None.
        return reversed_grad, None
class GradientReversalLayer(nn.Module):
    """``nn.Module`` front-end for ``GradientReversalFunction``.

    Passes values through unchanged on the forward pass. On the backward pass
    the gradient is multiplied by ``-lambda_``.

    Args:
        lambda_: Gradient scale applied (negated) in the backward pass.
            Stored as a plain attribute and read on every forward call.
    """

    def __init__(self, lambda_: float = 1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        reverse = GradientReversalFunction.apply
        return reverse(x, self.lambda_)
class DomainClassifier(nn.Module):
    """MLP head that predicts a temporal-domain label from a feature vector.

    Examples of domains are early/mid/late game or round type. The architecture
    is ``Linear(input_dim, hidden_dim) -> GELU -> Linear(hidden_dim, num_domains)``.
    The layers are held in ``self.net`` (an ``nn.Sequential``), so the parameter
    names are ``net.0.*`` and ``net.2.*``.

    Args:
        input_dim: Size of each input feature vector (default 256).
        hidden_dim: Width of the single hidden layer (default 128).
        num_domains: Number of domain classes, i.e. output logits (default 4).
    """

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        num_domains: int = 4,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_domains = num_domains
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, num_domains),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized domain scores.

        Args:
            x: Features of shape [batch, input_dim].

        Returns:
            Logits of shape [batch, num_domains].
        """
        logits = self.net(x)
        return logits
class TemporalInvariantVerifier(nn.Module):
    """Domain-adversarial network for temporal invariance.

    Pushes features towards invariance to temporal context (early/late game,
    etc.) using gradient reversal. The input is transformed by a linear layer,
    mean-pooled over the sequence axis, and passed through a gradient reversal
    layer into a domain classifier. The classifier learns to predict the
    domain. The reversed gradient pushes ``feature_transform`` (and anything
    upstream of it) towards features the classifier cannot separate.

    Args:
        input_dim: Dimension of input features (default 256).
        hidden_dim: Hidden dimension for the domain classifier (default 128).
        num_domains: Number of temporal domains (default 4).
        adversarial_lambda: Scaling factor for gradient reversal (default 0.1).
            It can be reassigned after construction, e.g. by a lambda schedule
            (``module.adversarial_lambda = new_value``). The new value is used
            by the gradient reversal layer on the next forward pass.
    """

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 128,
        num_domains: int = 4,
        adversarial_lambda: float = 0.1,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_domains = num_domains
        self.feature_transform = nn.Linear(input_dim, input_dim)
        # The reversal layer is the single owner of lambda. The
        # ``adversarial_lambda`` property below reads and writes it, so the
        # two can never drift apart.
        self.gradient_reversal = GradientReversalLayer(lambda_=adversarial_lambda)
        self.domain_classifier = DomainClassifier(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_domains=num_domains,
        )

    @property
    def adversarial_lambda(self) -> float:
        """Current gradient-reversal scale (delegates to ``gradient_reversal.lambda_``)."""
        return self.gradient_reversal.lambda_

    @adversarial_lambda.setter
    def adversarial_lambda(self, value: float) -> None:
        self.gradient_reversal.lambda_ = value

    def forward(
        self,
        x: torch.Tensor,
        domain_labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward pass through the temporal invariant verifier.

        Args:
            x: Input features [batch, seq, input_dim].
            domain_labels: Ground-truth domain indices [batch], used only for
                the loss computation.

        Returns:
            Dict with:
                - "output": transformed features (same shape as input)
                - "domain_logits": domain predictions [batch, num_domains]
                - "domain_loss": cross-entropy loss if labels are provided,
                  else a scalar zero tensor on ``x.device``

        Raises:
            ValueError: If ``x`` is not 3-D, or its sequence length is 0
                (mean-pooling an empty sequence would yield NaN).
        """
        if x.dim() != 3:
            raise ValueError(
                "TemporalInvariantVerifier expects input of shape "
                f"[batch, seq, input_dim], got {tuple(x.shape)}"
            )
        if x.size(1) == 0:
            raise ValueError(
                "TemporalInvariantVerifier cannot pool an empty sequence "
                f"(seq == 0), got input of shape {tuple(x.shape)}"
            )

        output = self.feature_transform(x)
        # Pool over the sequence so the domain is predicted per sample.
        pooled = output.mean(dim=1)
        # Identity going forward; negates and scales gradients going back.
        reversed_features = self.gradient_reversal(pooled)
        domain_logits = self.domain_classifier(reversed_features)

        if domain_labels is not None:
            domain_loss = F.cross_entropy(domain_logits, domain_labels)
        else:
            domain_loss = torch.tensor(0.0, device=x.device)

        return {
            "output": output,
            "domain_logits": domain_logits,
            "domain_loss": domain_loss,
        }

    @classmethod
    def from_config(cls, config: "ModelConfig") -> "TemporalInvariantVerifier":
        """Create a TemporalInvariantVerifier from a ModelConfig.

        Reads ``embed_dim``, ``num_domains`` and ``adversarial_lambda`` from the
        config. The domain classifier's hidden width is ``embed_dim // 2``.

        Args:
            config: Model configuration object.

        Returns:
            Configured TemporalInvariantVerifier instance.
        """
        return cls(
            input_dim=config.embed_dim,
            hidden_dim=config.embed_dim // 2,
            num_domains=config.num_domains,
            adversarial_lambda=config.adversarial_lambda,
        )