"""VAE-style Manifold Projection Layer for learning behavior manifold."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any


class MPLEncoder(nn.Module):
    """Encode input to latent distribution parameters (mu, logvar)."""

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 256,
        latent_dim: int = 64,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize weights for stable training."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc_mu.weight)
        nn.init.zeros_(self.fc_mu.bias)
        nn.init.xavier_uniform_(self.fc_logvar.weight, gain=0.1)
        nn.init.zeros_(self.fc_logvar.bias)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode input to latent distribution parameters.

        Args:
            x: Input features [batch, seq, input_dim]

        Returns:
            (mu, logvar): Distribution parameters [batch, seq, latent_dim]
        """
        h = F.gelu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar


class MPLDecoder(nn.Module):
    """Decode latent samples back to input space."""

    def __init__(
        self,
        latent_dim: int = 64,
        hidden_dim: int = 256,
        output_dim: int = 256,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize weights for stable training."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode latent samples to reconstruction.

        Args:
            z: Latent samples [batch, seq, latent_dim]

        Returns:
            Reconstruction [batch, seq, output_dim]
        """
        h = F.gelu(self.fc1(z))
        return self.fc2(h)


class ManifoldProjectionLayer(nn.Module):
    """
    VAE-style manifold projection for learning the behavior manifold.

    Projects high-dimensional behavior features onto a low-dimensional
    manifold, then reconstructs them. The KL divergence term regularizes
    the latent space.
    """

    def __init__(
        self,
        input_dim: int = 256,
        latent_dim: int = 64,
        hidden_dim: int = 256,
        kl_weight: float = 0.001,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.kl_weight = kl_weight

        self.encoder = MPLEncoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
        )
        self.decoder = MPLDecoder(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            output_dim=input_dim,
        )

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick for backprop through sampling.

        z = mu + std * eps, where eps ~ N(0, I)

        Args:
            mu: Mean of latent distribution [batch, seq, latent_dim]
            logvar: Log variance of latent distribution [batch, seq, latent_dim]

        Returns:
            Sampled latent [batch, seq, latent_dim]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the VAE-style manifold projection.

        Args:
            x: Input features [batch, seq, input_dim]

        Returns:
            Dict with:
                - "latent": sampled latent [batch, seq, latent_dim]
                - "mu": mean [batch, seq, latent_dim]
                - "logvar": log variance [batch, seq, latent_dim]
                - "reconstruction": reconstructed input [batch, seq, input_dim]
                - "kl_loss": KL divergence loss scalar
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(z)

        # KL(q(z|x) || p(z)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
        kl_loss = -0.5 * torch.mean(
            torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
        )

        return {
            "latent": z,
            "mu": mu,
            "logvar": logvar,
            "reconstruction": reconstruction,
            "kl_loss": kl_loss * self.kl_weight,
        }

    @classmethod
    def from_config(cls, config: Any) -> "ManifoldProjectionLayer":
        """
        Create ManifoldProjectionLayer from ModelConfig.

        Args:
            config: ModelConfig instance with MPL parameters

        Returns:
            Configured ManifoldProjectionLayer instance
        """
        return cls(
            input_dim=config.embed_dim,
            latent_dim=config.latent_dim,
            hidden_dim=config.mpl_hidden,
            kl_weight=config.kl_weight,
        )
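

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): push a random feature batch
    # through the layer and combine the VAE losses. The [batch, seq, input_dim]
    # shape and the MSE reconstruction loss are assumptions for this demo, not
    # something this module mandates.
    mpl = ManifoldProjectionLayer(
        input_dim=256, latent_dim=64, hidden_dim=256, kl_weight=0.001
    )
    x = torch.randn(8, 32, 256)  # [batch, seq, input_dim]
    out = mpl(x)

    recon_loss = F.mse_loss(out["reconstruction"], x)
    total_loss = recon_loss + out["kl_loss"]  # "kl_loss" is already scaled by kl_weight
    print(
        f"latent {tuple(out['latent'].shape)} | "
        f"recon {recon_loss.item():.4f} | "
        f"kl {out['kl_loss'].item():.4f} | "
        f"total {total_loss.item():.4f}"
    )

    # from_config demo with a hypothetical stand-in for ModelConfig; the real
    # config class lives elsewhere and is assumed to expose these attributes.
    class _DemoConfig:
        embed_dim = 256
        latent_dim = 64
        mpl_hidden = 256
        kl_weight = 0.001

    mpl_from_cfg = ManifoldProjectionLayer.from_config(_DemoConfig())
    assert mpl_from_cfg.latent_dim == 64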