# Author: phirni
# Commit: cb80641 "Create model.py" (verified)
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
# ===============================================================
# 🧠 Beta-VAE (Frame Compression / Reconstruction)
# ===============================================================
class BetaVAE(nn.Module):
    """β-Variational Autoencoder that compresses 64x64 frames into a latent vector.

    The encoder applies four stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4)
    while widening the channels to 256. Two linear heads then produce the mean
    and log-variance of the latent Gaussian. The decoder mirrors this path with
    transposed convolutions and ends in a sigmoid, so reconstructed pixels lie
    in [0, 1].

    Args:
        input_channels: number of channels in each input frame.
        latent_dim: size of the latent vector ``z``.
        beta: weight of the KL term in ``loss_function``.
    """

    def __init__(self, input_channels: int = 1, latent_dim: int = 64, beta: float = 4.0):
        super().__init__()
        self.latent_dim = latent_dim
        self.beta = beta

        # Channel widths along the encoder; the decoder walks the same list backwards.
        widths = [input_channels, 32, 64, 128, 256]
        # Flattened size of the final 256 x 4 x 4 encoder feature map.
        flat_size = widths[-1] * 4 * 4

        # Encoder: every stride-2 conv halves height and width (64x64 -> 4x4).
        down_layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            down_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            down_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*down_layers)

        # Gaussian posterior heads and the projection back into feature-map space.
        self.fc_mu = nn.Linear(flat_size, latent_dim)
        self.fc_logvar = nn.Linear(flat_size, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, flat_size)

        # Decoder: every stride-2 transposed conv doubles height and width (4x4 -> 64x64).
        up_widths = widths[::-1]
        stages = list(zip(up_widths[:-1], up_widths[1:]))
        up_layers = []
        for idx, (c_in, c_out) in enumerate(stages):
            up_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # The last stage squashes pixels into [0, 1]; earlier stages use ReLU.
            up_layers.append(nn.Sigmoid() if idx == len(stages) - 1 else nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    # -----------------------------------------------------------
    # Encoder / Decoder Logic
    # -----------------------------------------------------------
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map a batch of frames to the posterior parameters ``(mu, logvar)``.

        ``x`` must be (B, input_channels, 64, 64): the linear heads expect
        exactly 256 * 4 * 4 features after the four stride-2 convolutions.
        """
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``, keeping ``mu``/``logvar`` differentiable."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors (B, latent_dim) to frames (B, input_channels, 64, 64) in [0, 1]."""
        projected = self.decoder_input(z)
        feature_map = projected.view(projected.size(0), 256, 4, 4)
        return self.decoder(feature_map)

    # -----------------------------------------------------------
    # Forward + Loss
    # -----------------------------------------------------------
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample a latent, and decode. Returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mu, logvar))
        return recon, mu, logvar

    def loss_function(self, recon_x: torch.Tensor, x: torch.Tensor,
                      mu: torch.Tensor, logvar: torch.Tensor) -> dict:
        """Compute the β-VAE objective.

        The loss is the summed squared reconstruction error plus ``beta`` times
        the KL divergence of N(mu, exp(logvar)) from N(0, I). Both terms are
        summed (not averaged) over the batch and all elements.

        Returns:
            dict with tensors under ``'total_loss'``, ``'recon_loss'`` and ``'kl_loss'``.
        """
        recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_loss = -0.5 * kl_terms.sum()
        total_loss = recon_loss + self.beta * kl_loss
        return {
            'total_loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
        }
# ===============================================================
# 🔁 ConvLSTM (Frame Prediction)
# ===============================================================
class ConvLSTMCell(nn.Module):
    """A single convolutional LSTM cell.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces all four gate pre-activations at once. They
    are split into input, forget, output and candidate gates, giving the LSTM
    recurrence over spatial feature maps. ``padding = kernel_size // 2`` keeps
    H and W unchanged, which requires an odd ``kernel_size``.

    Args:
        input_dim: channels of the input tensor ``x``.
        hidden_dim: channels of the hidden and cell states.
        kernel_size: side length of the square gate convolution (must be odd).

    Raises:
        ValueError: if ``kernel_size`` is even. The gate maps would then be one
            pixel larger than the cell state and could not be combined with it.
    """

    def __init__(self, input_dim: int, hidden_dim: int, kernel_size: int):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve spatial size, got {kernel_size}"
            )
        padding = kernel_size // 2
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size=kernel_size,
            padding=padding
        )

    def forward(self, x, hidden_state=None):
        """Advance the cell by one time step.

        Args:
            x: input of shape (B, input_dim, H, W).
            hidden_state: ``(h_prev, c_prev)``, each (B, hidden_dim, H, W).
                If ``None``, a zero state matching ``x``'s batch size, spatial
                size, device and dtype is used.

        Returns:
            ``(h_next, c_next)``, each (B, hidden_dim, H, W).
        """
        if hidden_state is None:
            hidden_state = self.init_hidden(x.size(0), x.shape[-2:], x.device, dtype=x.dtype)
        h_prev, c_prev = hidden_state
        gates = self.conv(torch.cat([x, h_prev], dim=1))
        # Channel blocks in order: input, forget, output, candidate.
        i, f, o, g = torch.split(gates, self.hidden_dim, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c_prev + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, spatial_size, device, dtype=None):
        """Return zero ``(h, c)`` states, each of shape (batch_size, hidden_dim, H, W).

        Args:
            batch_size: number of sequences B.
            spatial_size: ``(H, W)`` of the feature maps.
            device: device to allocate the states on.
            dtype: state dtype. Defaults to the dtype of the gate convolution's
                weights, so a cell converted with ``.half()`` / ``.double()``
                gets matching states instead of default-dtype zeros.
        """
        if dtype is None:
            dtype = self.conv.weight.dtype
        H, W = spatial_size
        h = torch.zeros(batch_size, self.hidden_dim, H, W, device=device, dtype=dtype)
        c = torch.zeros(batch_size, self.hidden_dim, H, W, device=device, dtype=dtype)
        return h, c
class ConvLSTM(nn.Module):
    """Multi-layer ConvLSTM network for next-frame prediction.

    At each time step the frame is passed up the stack: layer ``i``'s hidden
    state is the input of layer ``i + 1``. After the last time step, a 1x1
    convolution maps the top layer's hidden state to the predicted frame.

    Args:
        input_channels: channels C of each input frame.
        hidden_channels: hidden width of each layer, bottom to top; its length
            sets the number of layers. Defaults to ``[64, 64, 64]``. The
            sequence is copied, so later changes to the caller's list do not
            affect the model.
        kernel_size: gate convolution kernel size shared by all cells.
        output_channels: channels of the predicted frame.

    Raises:
        ValueError: if ``hidden_channels`` is empty.
    """

    def __init__(self,
                 input_channels: int = 1,
                 hidden_channels: Optional[List[int]] = None,
                 kernel_size: int = 3,
                 output_channels: int = 1):
        super().__init__()
        # A literal list default would be one object shared by every instance;
        # build a fresh copy instead.
        hidden_channels = [64, 64, 64] if hidden_channels is None else list(hidden_channels)
        if not hidden_channels:
            raise ValueError("hidden_channels must contain at least one layer width")
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_layers = len(hidden_channels)
        # Layer 0 reads the frame; every higher layer reads the hidden state below it.
        layer_inputs = [input_channels] + hidden_channels[:-1]
        self.cells = nn.ModuleList([
            ConvLSTMCell(
                input_dim=in_dim,
                hidden_dim=hid_dim,
                kernel_size=kernel_size
            )
            for in_dim, hid_dim in zip(layer_inputs, hidden_channels)
        ])
        # Final 1x1 conv to generate output frame
        self.output_conv = nn.Conv2d(hidden_channels[-1], output_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict the frame that follows the input sequence.

        Args:
            x: (B, T, C, H, W) sequence of frames, with C == input_channels.

        Returns:
            Predicted next frame: (B, output_channels, H, W).

        Raises:
            ValueError: if ``x`` is not 5-D or has no time steps.
        """
        if x.dim() != 5:
            raise ValueError(f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}")
        B, T, _, H, W = x.size()
        if T == 0:
            raise ValueError("input sequence must contain at least one time step")
        # Zero initial states, cast to x's dtype so half/double inputs are not
        # concatenated with default-dtype (float32) state tensors.
        states = [
            tuple(s.to(dtype=x.dtype) for s in cell.init_hidden(B, (H, W), x.device))
            for cell in self.cells
        ]
        for t in range(T):
            layer_input = x[:, t]
            for i, cell in enumerate(self.cells):
                states[i] = cell(layer_input, states[i])
                layer_input = states[i][0]
        # Project the top layer's final hidden state to the output frame.
        return self.output_conv(states[-1][0])