"""VAE-style Manifold Projection Layer for learning behavior manifold.""" from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Dict, Any class MPLEncoder(nn.Module): """Encode input to latent distribution parameters (mu, logvar).""" def __init__( self, input_dim: int = 256, hidden_dim: int = 256, latent_dim: int = 64, ): super().__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.latent_dim = latent_dim self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc_mu = nn.Linear(hidden_dim, latent_dim) self.fc_logvar = nn.Linear(hidden_dim, latent_dim) self._init_weights() def _init_weights(self) -> None: """Initialize weights for stable training.""" nn.init.xavier_uniform_(self.fc1.weight) nn.init.zeros_(self.fc1.bias) nn.init.xavier_uniform_(self.fc_mu.weight) nn.init.zeros_(self.fc_mu.bias) nn.init.xavier_uniform_(self.fc_logvar.weight, gain=0.1) nn.init.zeros_(self.fc_logvar.bias) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Encode input to latent distribution parameters. Args: x: Input features [batch, seq, input_dim] Returns: (mu, logvar): Distribution parameters [batch, seq, latent_dim] """ h = F.gelu(self.fc1(x)) mu = self.fc_mu(h) logvar = self.fc_logvar(h) return mu, logvar class MPLDecoder(nn.Module): """Decode latent samples back to input space.""" def __init__( self, latent_dim: int = 64, hidden_dim: int = 256, output_dim: int = 256, ): super().__init__() self.latent_dim = latent_dim self.hidden_dim = hidden_dim self.output_dim = output_dim self.fc1 = nn.Linear(latent_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) self._init_weights() def _init_weights(self) -> None: """Initialize weights for stable training.""" nn.init.xavier_uniform_(self.fc1.weight) nn.init.zeros_(self.fc1.bias) nn.init.xavier_uniform_(self.fc2.weight) nn.init.zeros_(self.fc2.bias) def forward(self, z: torch.Tensor) -> torch.Tensor: """ Decode latent samples to reconstruction. Args: z: Latent samples [batch, seq, latent_dim] Returns: Reconstruction [batch, seq, output_dim] """ h = F.gelu(self.fc1(z)) return self.fc2(h) class ManifoldProjectionLayer(nn.Module): """ VAE-style manifold projection for learning behavior manifold. Projects high-dimensional behavior to low-dimensional manifold, then reconstructs. KL divergence regularizes latent space. """ def __init__( self, input_dim: int = 256, latent_dim: int = 64, hidden_dim: int = 256, kl_weight: float = 0.001, ): super().__init__() self.input_dim = input_dim self.latent_dim = latent_dim self.hidden_dim = hidden_dim self.kl_weight = kl_weight self.encoder = MPLEncoder( input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim, ) self.decoder = MPLDecoder( latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=input_dim, ) def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: """ Reparameterization trick for backprop through sampling. z = mu + std * eps, where eps ~ N(0, I) Args: mu: Mean of latent distribution [batch, seq, latent_dim] logvar: Log variance of latent distribution [batch, seq, latent_dim] Returns: Sampled latent [batch, seq, latent_dim] """ std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ Forward pass through VAE-style manifold projection. 
Args: x: Input features [batch, seq, input_dim] Returns: Dict with: - "latent": sampled latent [batch, seq, latent_dim] - "mu": mean [batch, seq, latent_dim] - "logvar": log variance [batch, seq, latent_dim] - "reconstruction": reconstructed input [batch, seq, input_dim] - "kl_loss": KL divergence loss scalar """ mu, logvar = self.encoder(x) z = self.reparameterize(mu, logvar) reconstruction = self.decoder(z) # KL(q(z|x) || p(z)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) kl_loss = -0.5 * torch.mean( torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) ) return { "latent": z, "mu": mu, "logvar": logvar, "reconstruction": reconstruction, "kl_loss": kl_loss * self.kl_weight, } @classmethod def from_config(cls, config: Any) -> "ManifoldProjectionLayer": """ Create ManifoldProjectionLayer from ModelConfig. Args: config: ModelConfig instance with mpl parameters Returns: Configured ManifoldProjectionLayer instance """ return cls( input_dim=config.embed_dim, latent_dim=config.latent_dim, hidden_dim=config.mpl_hidden, kl_weight=config.kl_weight, )
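

# --- Usage sketch (illustrative; not part of the module API) ---
# A minimal smoke test showing how the layer is wired into a training step.
# "DemoConfig" is a hypothetical stand-in for the repo's ModelConfig, showing
# only the four fields that from_config reads (embed_dim, latent_dim,
# mpl_hidden, kl_weight). The batch/seq sizes and the MSE reconstruction
# objective are likewise assumptions for illustration, not a prescribed setup.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class DemoConfig:
        embed_dim: int = 256
        latent_dim: int = 64
        mpl_hidden: int = 256
        kl_weight: float = 0.001

    torch.manual_seed(0)
    mpl = ManifoldProjectionLayer.from_config(DemoConfig())

    x = torch.randn(8, 16, 256)  # [batch, seq, input_dim]
    out = mpl(x)
    assert out["latent"].shape == (8, 16, 64)
    assert out["reconstruction"].shape == x.shape

    # "kl_loss" is returned already scaled by kl_weight, so it can be added
    # to the reconstruction term directly without further weighting.
    recon_loss = F.mse_loss(out["reconstruction"], x)
    loss = recon_loss + out["kl_loss"]
    loss.backward()

    print(f"recon={recon_loss.item():.4f}  kl(weighted)={out['kl_loss'].item():.6f}")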