Spaces:
Sleeping
Sleeping
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
# ===============================================================
# Beta-VAE (Frame Compression / Reconstruction)
# ===============================================================
class BetaVAE(nn.Module):
    """Beta-Variational Autoencoder for frame compression.

    Each of the four stride-2 convolutions halves the spatial size, and the
    linear heads expect a 256 x 4 x 4 feature map, so inputs must have shape
    ``(B, input_channels, 64, 64)``. The decoder mirrors the encoder and ends
    in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        input_channels: Number of channels in the input (and output) frames.
        latent_dim: Size of the latent vector.
        beta: Weight applied to the KL term in :meth:`loss_function`.
    """

    # Shape of the encoder's final feature map for a 64x64 input.
    _FEATURE_CHANNELS = 256
    _FEATURE_SIZE = 4

    def __init__(self, input_channels: int = 1, latent_dim: int = 64, beta: float = 4.0):
        super().__init__()
        self.latent_dim = latent_dim
        self.beta = beta

        flat_features = self._FEATURE_CHANNELS * self._FEATURE_SIZE * self._FEATURE_SIZE

        # Encoder: 64x64 -> 4x4
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1),  # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16 -> 8
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 8 -> 4
            nn.ReLU(),
        )

        # Latent space: two heads over the flattened encoder features.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Decoder input: project the latent vector back to a 256x4x4 map.
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        # Decoder: 4x4 -> 64x64
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 4 -> 8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, input_channels, kernel_size=4, stride=2, padding=1),  # 32 -> 64
            nn.Sigmoid(),  # output in [0, 1]
        )

    # -----------------------------------------------------------
    # Encoder / Decoder Logic
    # -----------------------------------------------------------
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode frames into the mean and log-variance of the latent Gaussian.

        Args:
            x: Input frames of shape ``(B, input_channels, 64, 64)``.

        Returns:
            ``(mu, logvar)``, each of shape ``(B, latent_dim)``.
        """
        h = self.encoder(x)
        # flatten() falls back to a copy when the conv output is not
        # contiguous (e.g. channels_last), where .view() would raise.
        h = torch.flatten(h, start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Noise is drawn on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors ``(B, latent_dim)`` into frames in [0, 1]."""
        h = self.decoder_input(z)
        h = h.view(h.size(0), self._FEATURE_CHANNELS, self._FEATURE_SIZE, self._FEATURE_SIZE)
        return self.decoder(h)

    # -----------------------------------------------------------
    # Forward + Loss
    # -----------------------------------------------------------
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run encode -> sample -> decode.

        Returns:
            ``(recon, mu, logvar)``: the reconstruction plus the latent
            statistics needed by :meth:`loss_function`.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss_function(self, recon_x: torch.Tensor, x: torch.Tensor,
                      mu: torch.Tensor, logvar: torch.Tensor) -> dict:
        """Compute the beta-VAE loss.

        Both terms are summed over every element (batch included), not
        averaged.

        Returns:
            Dict with ``'total_loss'`` (``recon + beta * kl``),
            ``'recon_loss'`` (summed squared error) and ``'kl_loss'``
            (closed-form KL divergence to a standard normal prior).
        """
        recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        total_loss = recon_loss + self.beta * kl_loss
        return {
            'total_loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
        }
# ===============================================================
# ConvLSTM (Frame Prediction)
# ===============================================================
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces all four gates at once.

    Args:
        input_dim: Number of channels in the input tensor.
        hidden_dim: Number of channels in the hidden and cell states.
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that ``kernel_size // 2`` padding keeps
            the spatial size unchanged.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, input_dim: int, hidden_dim: int, kernel_size: int):
        super().__init__()
        # An even kernel with k // 2 padding grows the output by one pixel,
        # which would make the new state incompatible with the previous one.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, x, hidden_state):
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(B, input_dim, H, W)``.
            hidden_state: Tuple ``(h_prev, c_prev)``, each of shape
                ``(B, hidden_dim, H, W)``.

        Returns:
            Tuple ``(h_next, c_next)`` with the same shapes as the inputs.
        """
        h_prev, c_prev = hidden_state
        combined = torch.cat([x, h_prev], dim=1)
        conv_out = self.conv(combined)

        # Channel order is input, forget, output, candidate. Changing it
        # would silently break previously trained weights.
        i, f, o, g = torch.split(conv_out, self.hidden_dim, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)

        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, spatial_size, device, dtype=None):
        """Create zero-filled hidden and cell states.

        Args:
            batch_size: Batch dimension ``B``.
            spatial_size: Tuple ``(H, W)``.
            device: Device on which to allocate the states.
            dtype: Optional dtype; ``None`` uses the torch default dtype.

        Returns:
            Tuple ``(h, c)`` of two distinct zero tensors of shape
            ``(B, hidden_dim, H, W)``.
        """
        H, W = spatial_size
        shape = (batch_size, self.hidden_dim, H, W)
        h = torch.zeros(shape, device=device, dtype=dtype)
        c = torch.zeros(shape, device=device, dtype=dtype)
        return h, c
class ConvLSTM(nn.Module):
    """Multi-layer ConvLSTM network for next-frame prediction.

    Args:
        input_channels: Number of channels in each input frame.
        hidden_channels: Hidden width of each stacked ConvLSTM layer.
            ``None`` selects the default ``[64, 64, 64]``.
        kernel_size: Convolution kernel size used by every cell.
        output_channels: Number of channels in the predicted frame.

    Raises:
        ValueError: If ``hidden_channels`` is empty.
    """

    def __init__(self,
                 input_channels: int = 1,
                 hidden_channels: Optional[List[int]] = None,
                 kernel_size: int = 3,
                 output_channels: int = 1):
        super().__init__()
        # A None default (instead of a list literal) plus a defensive copy
        # keeps instances from sharing, or mutating, one another's list.
        if hidden_channels is None:
            hidden_channels = [64, 64, 64]
        else:
            hidden_channels = list(hidden_channels)
        if not hidden_channels:
            raise ValueError("hidden_channels must contain at least one layer width")

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_layers = len(hidden_channels)

        # Stack multiple ConvLSTM cells; each layer consumes the hidden
        # state of the layer below it.
        layer_inputs = [input_channels] + hidden_channels[:-1]
        self.cells = nn.ModuleList([
            ConvLSTMCell(
                input_dim=in_dim,
                hidden_dim=hidden_dim,
                kernel_size=kernel_size,
            )
            for in_dim, hidden_dim in zip(layer_inputs, hidden_channels)
        ])

        # Final 1x1 conv to generate the output frame
        self.output_conv = nn.Conv2d(hidden_channels[-1], output_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict the frame that follows an input sequence.

        Args:
            x: Input sequence of shape ``(B, T, C, H, W)``.

        Returns:
            Predicted next frame of shape ``(B, output_channels, H, W)``,
            computed from the top layer's hidden state after the last step.
        """
        B, T, _, H, W = x.size()
        device = x.device

        # Initialize hidden states, one (h, c) pair per layer
        states = [cell.init_hidden(B, (H, W), device) for cell in self.cells]

        # Process input sequence
        for t in range(T):
            layer_input = x[:, t]
            for i, cell in enumerate(self.cells):
                h, c = cell(layer_input, states[i])
                states[i] = (h, c)
                layer_input = h  # feeds the next layer at this time step

        # Output predicted frame
        return self.output_conv(states[-1][0])