"""β-VAE frame compressor and ConvLSTM next-frame predictor (PyTorch)."""

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn


# ===============================================================
# 🧠 Beta-VAE (Frame Compression / Reconstruction)
# ===============================================================
class BetaVAE(nn.Module):
    """β-Variational Autoencoder for frame compression.

    The encoder halves the spatial resolution four times and the linear
    heads expect a 256 x 4 x 4 feature map, so inputs must be 64 x 64.
    The decoder mirrors the encoder and ends in a sigmoid, so
    reconstructions lie in [0, 1].

    Args:
        input_channels: Number of channels in the input/output frames.
        latent_dim: Size of the latent vector.
        beta: Weight applied to the KL term in ``loss_function``.
    """

    # Shape of the encoder output / decoder input feature map.
    _FEATURE_CHANNELS = 256
    _FEATURE_SIZE = 4

    def __init__(self, input_channels: int = 1, latent_dim: int = 64,
                 beta: float = 4.0):
        super().__init__()
        self.latent_dim = latent_dim
        self.beta = beta

        flat_features = (
            self._FEATURE_CHANNELS * self._FEATURE_SIZE * self._FEATURE_SIZE
        )

        # Encoder: 64x64 → 4x4
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1),  # 64→32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32→16
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16→8
            nn.ReLU(),
            nn.Conv2d(128, self._FEATURE_CHANNELS, kernel_size=4, stride=2, padding=1),  # 8→4
            nn.ReLU(),
        )

        # Latent space
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder input
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        # Decoder: 4x4 → 64x64
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self._FEATURE_CHANNELS, 128, kernel_size=4, stride=2, padding=1),  # 4→8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8→16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 16→32
            nn.ReLU(),
            nn.ConvTranspose2d(32, input_channels, kernel_size=4, stride=2, padding=1),  # 32→64
            nn.Sigmoid(),  # output [0,1]
        )

    # -----------------------------------------------------------
    # Encoder / Decoder Logic
    # -----------------------------------------------------------
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map a batch of frames to the posterior parameters.

        Args:
            x: Frames of shape (B, input_channels, 64, 64).

        Returns:
            ``(mu, logvar)``, each of shape (B, latent_dim).
        """
        features = self.encoder(x)
        # flatten() falls back to a copy when the feature map is not
        # contiguous (e.g. channels_last inputs), where .view() would raise.
        features = features.flatten(start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu: torch.Tensor,
                       logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Noise is drawn on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Reconstruct frames from latent vectors of shape (B, latent_dim)."""
        h = self.decoder_input(z)
        h = h.view(
            h.size(0),
            self._FEATURE_CHANNELS,
            self._FEATURE_SIZE,
            self._FEATURE_SIZE,
        )
        return self.decoder(h)

    # -----------------------------------------------------------
    # Forward + Loss
    # -----------------------------------------------------------
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample and decode.

        Returns:
            ``(recon, mu, logvar)`` where ``recon`` has the shape of ``x``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

    def loss_function(self, recon_x: torch.Tensor, x: torch.Tensor,
                      mu: torch.Tensor,
                      logvar: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute β-VAE loss.

        Both terms are summed over the whole batch (no averaging):
        a sum-reduced MSE reconstruction term plus ``beta`` times the
        closed-form KL divergence to a standard normal prior.

        Returns:
            Dict with scalar tensors under ``'total_loss'``,
            ``'recon_loss'`` and ``'kl_loss'``.
        """
        recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        total_loss = recon_loss + self.beta * kl_loss
        return {
            'total_loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
        }


# ===============================================================
# 🔁 ConvLSTM (Frame Prediction)
# ===============================================================
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell.

    One convolution over the concatenated input and previous hidden state
    produces all four gates at once.

    Args:
        input_dim: Channels of the input tensor.
        hidden_dim: Channels of the hidden and cell states.
        kernel_size: Convolution kernel size. Padding is
            ``kernel_size // 2``, which preserves the spatial size only
            for odd kernel sizes.
    """

    def __init__(self, input_dim: int, hidden_dim: int, kernel_size: int):
        super().__init__()
        padding = kernel_size // 2
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x, hidden_state):
        """Advance the cell by one time step.

        Args:
            x: Input of shape (B, input_dim, H, W).
            hidden_state: ``(h_prev, c_prev)``, each (B, hidden_dim, H, W).

        Returns:
            ``(h_next, c_next)`` with the same shapes as the previous state.
        """
        h_prev, c_prev = hidden_state
        combined = torch.cat([x, h_prev], dim=1)
        conv_out = self.conv(combined)

        # Channel order of the fused convolution: input, forget, output
        # gates, then the candidate cell update.
        i, f, o, g = torch.split(conv_out, self.hidden_dim, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)

        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, spatial_size, device, dtype=None):
        """Create zero-filled hidden and cell states.

        Args:
            batch_size: Batch dimension B.
            spatial_size: ``(H, W)`` of the feature maps.
            device: Device on which to allocate the states.
            dtype: Optional dtype for the states; ``None`` uses the
                global default dtype.

        Returns:
            ``(h, c)``, each of shape (B, hidden_dim, H, W).
        """
        H, W = spatial_size
        shape = (batch_size, self.hidden_dim, H, W)
        h = torch.zeros(shape, device=device, dtype=dtype)
        c = torch.zeros(shape, device=device, dtype=dtype)
        return h, c


class ConvLSTM(nn.Module):
    """Multi-layer ConvLSTM network for next-frame prediction.

    Args:
        input_channels: Channels of each input frame.
        hidden_channels: Hidden channels per stacked layer; defaults to
            ``[64, 64, 64]`` when ``None``.
        kernel_size: Kernel size shared by all cells.
        output_channels: Channels of the predicted frame.
    """

    def __init__(self, input_channels: int = 1,
                 hidden_channels: Optional[List[int]] = None,
                 kernel_size: int = 3, output_channels: int = 1):
        super().__init__()
        # A None sentinel plus a copy avoids sharing one list between
        # instances (mutable default) or aliasing the caller's list.
        if hidden_channels is None:
            hidden_channels = [64, 64, 64]
        hidden_channels = list(hidden_channels)

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_layers = len(hidden_channels)

        # Stack multiple ConvLSTM cells; each layer consumes the hidden
        # state of the layer below it.
        layer_inputs = [input_channels] + hidden_channels[:-1]
        self.cells = nn.ModuleList([
            ConvLSTMCell(
                input_dim=in_dim,
                hidden_dim=hid_dim,
                kernel_size=kernel_size,
            )
            for in_dim, hid_dim in zip(layer_inputs, hidden_channels)
        ])

        # Final 1×1 conv to generate output frame
        self.output_conv = nn.Conv2d(hidden_channels[-1], output_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        """Predict the frame following an input sequence.

        Args:
            x: (B, T, C, H, W)

        Returns:
            Predicted next frame: (B, output_channels, H, W)
        """
        B, _, _, H, W = x.size()
        device = x.device

        # Initialize hidden states in each cell's parameter dtype so that
        # half/bfloat16/double networks get matching states.
        states = [
            cell.init_hidden(B, (H, W), device, dtype=cell.conv.weight.dtype)
            for cell in self.cells
        ]

        # Process input sequence one frame at a time.
        for frame in x.unbind(dim=1):
            layer_input = frame
            for idx, cell in enumerate(self.cells):
                states[idx] = cell(layer_input, states[idx])
                layer_input = states[idx][0]

        # Output predicted frame from the top layer's last hidden state.
        return self.output_conv(states[-1][0])