import torch
import torch.nn as nn


class ConvAutoencoder_2D(nn.Module):
    """2-D convolutional autoencoder with a linear bottleneck.

    The ``(C_in, L)`` input is treated as a single-channel 2-D image of
    height ``C_in`` and width ``L``.

    Architecture:
        Encoder:
            - A dedicated Conv2d whose kernel height equals ``C_in``
              collapses the height dimension to 1.
            - A loop of Conv2d layers further downsamples the width.
        Bottleneck:
            - Flatten -> Linear -> ReLU -> Linear(latent_dim)
        Decoder:
            - Linear -> ReLU -> Linear -> Unflatten
            - A loop of ConvTranspose2d layers upsamples the width.
            - A dedicated final ConvTranspose2d restores the original
              ``(C_in, L)`` shape.

    NOTE: because the bottleneck is fully connected, inputs must have
    exactly ``seq_len`` time steps (default 10_000, matching the
    original hard-coded value).

    Args:
        c_in: number of input channels (treated as spatial height).
        latent_dim: size of the latent vector produced by ``encode``.
        seq_len: fixed input length the bottleneck is sized for.
    """

    def __init__(self, c_in: int, latent_dim: int, seq_len: int = 10_000) -> None:
        super().__init__()
        self.c_in: int = c_in
        self.latent_dim: int = latent_dim
        self.seq_len: int = seq_len

        # ----- Encoder -----
        # Channel plan starts at 1: the input channels are treated as a
        # spatial (height) dimension, not as conv channels.
        channels: list[int] = [1, 8, 16, 32]
        strides: list[tuple[int, int]] = [(1, 2), (1, 2), (1, 2)]

        enc_layers: list[nn.Module] = [
            # 1. Dedicated first layer: kernel height == c_in collapses the
            #    height dimension from c_in down to 1.
            nn.Conv2d(
                in_channels=channels[0],
                out_channels=channels[1],
                kernel_size=(self.c_in, 3),  # kernel height matches input height
                stride=strides[0],
                padding=(0, 1),
            ),
            nn.ReLU(inplace=True),
        ]
        # 2. Subsequent layers operate on height-1 feature maps.
        for i in range(1, len(strides)):
            enc_layers += [
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    kernel_size=(1, 3),  # height is already 1
                    stride=strides[i],
                    padding=(0, 1),
                ),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*enc_layers)

        # ----- Dynamic shape calculation for the bottleneck -----
        # Run a dummy tensor through the encoder to discover the flattened
        # size automatically instead of deriving it from the stride plan.
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.c_in, self.seq_len)
            h = self.encoder(dummy_input.unsqueeze(1))
        self.bottleneck_shape = h.shape[1:]  # (C, H, W)
        flat_bottleneck = h.numel()  # batch dim is 1, so numel == C*H*W

        # ----- Linear bottleneck layers -----
        self.to_latent = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_bottleneck, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, latent_dim),
        )
        self.from_latent = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, flat_bottleneck),
            nn.Unflatten(dim=1, unflattened_size=self.bottleneck_shape),
        )

        # ----- Decoder (mirrors the encoder) -----
        dec_layers: list[nn.Module] = []
        # Invert the channel plan for the decoder: [32, 16, 8, 1]
        dec_channels = list(reversed(channels))

        # 1. Mirror of the encoder's main loop (height stays 1).
        for i in range(len(strides) - 1):
            dec_layers += [
                nn.ConvTranspose2d(
                    in_channels=dec_channels[i],
                    out_channels=dec_channels[i + 1],
                    kernel_size=(1, 3),
                    stride=strides[-(i + 1)],
                    padding=(0, 1),
                    output_padding=(0, 1),  # needed to invert stride-2 halving exactly
                ),
                nn.ReLU(inplace=True),
            ]
        # 2. Dedicated final layer restores the original c_in height:
        #    output height = (1 - 1) * 1 + c_in = c_in.
        dec_layers += [
            nn.ConvTranspose2d(
                in_channels=dec_channels[-2],
                out_channels=dec_channels[-1],
                kernel_size=(self.c_in, 3),
                stride=strides[0],
                padding=(0, 1),
                output_padding=(0, 1),
            ),
            nn.PReLU(),  # learnable final activation (parametric ReLU)
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a ``(N, c_in, seq_len)`` input to a ``(N, latent_dim)`` vector."""
        # Reshape to (N, 1, c_in, seq_len) so c_in acts as a spatial dimension.
        h = x.unsqueeze(1)
        h = self.encoder(h)
        z = self.to_latent(h)
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a latent vector back to the input space."""
        h = self.from_latent(z)
        x_hat = self.decoder(h)
        # Reshape back to (N, c_in, L) by dropping the singleton channel dim.
        x_hat = x_hat.squeeze(1)
        return x_hat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder forward pass; output shape matches the input."""
        original_len = x.shape[2]
        z = self.encode(x)
        x_hat = self.decode(z)
        # Crop output to match input dimensions (transposed convs may overshoot).
        return x_hat[:, :, :original_len]