Spaces:
Running on Zero
Running on Zero
"""VAE-style Manifold Projection Layer for learning behavior manifold."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
| class MPLEncoder(nn.Module): | |
| """Encode input to latent distribution parameters (mu, logvar).""" | |
| def __init__( | |
| self, | |
| input_dim: int = 256, | |
| hidden_dim: int = 256, | |
| latent_dim: int = 64, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.hidden_dim = hidden_dim | |
| self.latent_dim = latent_dim | |
| self.fc1 = nn.Linear(input_dim, hidden_dim) | |
| self.fc_mu = nn.Linear(hidden_dim, latent_dim) | |
| self.fc_logvar = nn.Linear(hidden_dim, latent_dim) | |
| self._init_weights() | |
| def _init_weights(self) -> None: | |
| """Initialize weights for stable training.""" | |
| nn.init.xavier_uniform_(self.fc1.weight) | |
| nn.init.zeros_(self.fc1.bias) | |
| nn.init.xavier_uniform_(self.fc_mu.weight) | |
| nn.init.zeros_(self.fc_mu.bias) | |
| nn.init.xavier_uniform_(self.fc_logvar.weight, gain=0.1) | |
| nn.init.zeros_(self.fc_logvar.bias) | |
| def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Encode input to latent distribution parameters. | |
| Args: | |
| x: Input features [batch, seq, input_dim] | |
| Returns: | |
| (mu, logvar): Distribution parameters [batch, seq, latent_dim] | |
| """ | |
| h = F.gelu(self.fc1(x)) | |
| mu = self.fc_mu(h) | |
| logvar = self.fc_logvar(h) | |
| return mu, logvar | |
| class MPLDecoder(nn.Module): | |
| """Decode latent samples back to input space.""" | |
| def __init__( | |
| self, | |
| latent_dim: int = 64, | |
| hidden_dim: int = 256, | |
| output_dim: int = 256, | |
| ): | |
| super().__init__() | |
| self.latent_dim = latent_dim | |
| self.hidden_dim = hidden_dim | |
| self.output_dim = output_dim | |
| self.fc1 = nn.Linear(latent_dim, hidden_dim) | |
| self.fc2 = nn.Linear(hidden_dim, output_dim) | |
| self._init_weights() | |
| def _init_weights(self) -> None: | |
| """Initialize weights for stable training.""" | |
| nn.init.xavier_uniform_(self.fc1.weight) | |
| nn.init.zeros_(self.fc1.bias) | |
| nn.init.xavier_uniform_(self.fc2.weight) | |
| nn.init.zeros_(self.fc2.bias) | |
| def forward(self, z: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Decode latent samples to reconstruction. | |
| Args: | |
| z: Latent samples [batch, seq, latent_dim] | |
| Returns: | |
| Reconstruction [batch, seq, output_dim] | |
| """ | |
| h = F.gelu(self.fc1(z)) | |
| return self.fc2(h) | |
| class ManifoldProjectionLayer(nn.Module): | |
| """ | |
| VAE-style manifold projection for learning behavior manifold. | |
| Projects high-dimensional behavior to low-dimensional manifold, | |
| then reconstructs. KL divergence regularizes latent space. | |
| """ | |
| def __init__( | |
| self, | |
| input_dim: int = 256, | |
| latent_dim: int = 64, | |
| hidden_dim: int = 256, | |
| kl_weight: float = 0.001, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.latent_dim = latent_dim | |
| self.hidden_dim = hidden_dim | |
| self.kl_weight = kl_weight | |
| self.encoder = MPLEncoder( | |
| input_dim=input_dim, | |
| hidden_dim=hidden_dim, | |
| latent_dim=latent_dim, | |
| ) | |
| self.decoder = MPLDecoder( | |
| latent_dim=latent_dim, | |
| hidden_dim=hidden_dim, | |
| output_dim=input_dim, | |
| ) | |
| def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Reparameterization trick for backprop through sampling. | |
| z = mu + std * eps, where eps ~ N(0, I) | |
| Args: | |
| mu: Mean of latent distribution [batch, seq, latent_dim] | |
| logvar: Log variance of latent distribution [batch, seq, latent_dim] | |
| Returns: | |
| Sampled latent [batch, seq, latent_dim] | |
| """ | |
| std = torch.exp(0.5 * logvar) | |
| eps = torch.randn_like(std) | |
| return mu + eps * std | |
| def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: | |
| """ | |
| Forward pass through VAE-style manifold projection. | |
| Args: | |
| x: Input features [batch, seq, input_dim] | |
| Returns: | |
| Dict with: | |
| - "latent": sampled latent [batch, seq, latent_dim] | |
| - "mu": mean [batch, seq, latent_dim] | |
| - "logvar": log variance [batch, seq, latent_dim] | |
| - "reconstruction": reconstructed input [batch, seq, input_dim] | |
| - "kl_loss": KL divergence loss scalar | |
| """ | |
| mu, logvar = self.encoder(x) | |
| z = self.reparameterize(mu, logvar) | |
| reconstruction = self.decoder(z) | |
| # KL(q(z|x) || p(z)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)) | |
| kl_loss = -0.5 * torch.mean( | |
| torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1) | |
| ) | |
| return { | |
| "latent": z, | |
| "mu": mu, | |
| "logvar": logvar, | |
| "reconstruction": reconstruction, | |
| "kl_loss": kl_loss * self.kl_weight, | |
| } | |
| def from_config(cls, config: Any) -> "ManifoldProjectionLayer": | |
| """ | |
| Create ManifoldProjectionLayer from ModelConfig. | |
| Args: | |
| config: ModelConfig instance with mpl parameters | |
| Returns: | |
| Configured ManifoldProjectionLayer instance | |
| """ | |
| return cls( | |
| input_dim=config.embed_dim, | |
| latent_dim=config.latent_dim, | |
| hidden_dim=config.mpl_hidden, | |
| kl_weight=config.kl_weight, | |
| ) | |