import torch
import torch.nn as nn
|
class ConvAutoencoder_2D(nn.Module):
    """2-D convolutional autoencoder with a linear bottleneck.

    Inputs are batches of multichannel 1-D signals shaped ``(B, c_in, L)``.
    Internally each sample is lifted to a 4-D image ``(B, 1, c_in, L)`` so
    2-D convolutions can mix channels (height axis) and time (width axis).

    Architecture:
        Encoder:
            - A dedicated Conv2d whose kernel height equals ``c_in``; it
              collapses the channel axis to height 1 while halving L.
            - A loop of Conv2d layers that keep height 1 and halve L again.
        Bottleneck:
            - Flatten -> Linear -> ReLU -> Linear(latent_dim)
        Decoder:
            - Linear -> ReLU -> Linear -> Unflatten back to the encoder's
              output shape.
            - A loop of ConvTranspose2d layers that double L.
            - A dedicated final ConvTranspose2d (kernel height ``c_in``) that
              restores the original ``(c_in, L)`` shape.

    NOTE: the Linear bottleneck ties the model to a fixed input length
    (``input_len``); ``forward`` trims the reconstruction to the caller's
    length, but the encoder itself only accepts length ``input_len``.
    """

    def __init__(self, c_in: int, latent_dim: int, input_len: int = 10_000) -> None:
        """Build the autoencoder.

        Args:
            c_in: Number of input channels (height of the 2-D conv kernels
                in the first/last layers).
            latent_dim: Size of the latent vector produced by ``encode``.
            input_len: Signal length L the model is sized for. Kept as a
                default of 10_000 for backward compatibility with the
                previous hard-coded probe length. Exact length round-trips
                require ``input_len`` to be divisible by 2**len(strides).
        """
        super().__init__()
        self.c_in: int = c_in
        self.latent_dim: int = latent_dim
        self.input_len: int = input_len

        # channels[0] == 1 because the raw input is unsqueezed to one
        # "image" channel; each stride halves only the length axis.
        channels: list[int] = [1, 8, 16, 32]
        strides: list[tuple[int, int]] = [(1, 2), (1, 2), (1, 2)]

        enc_layers: list[nn.Module] = [
            # Kernel height c_in with zero height-padding collapses the
            # channel axis to 1 in a single layer.
            nn.Conv2d(
                in_channels=channels[0],
                out_channels=channels[1],
                kernel_size=(self.c_in, 3),
                stride=strides[0],
                padding=(0, 1),
            ),
            nn.ReLU(inplace=True),
        ]
        for i in range(1, len(strides)):
            enc_layers += [
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    kernel_size=(1, 3),
                    stride=strides[i],
                    padding=(0, 1),
                ),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*enc_layers)

        # Probe the encoder once with a dummy batch to discover the
        # bottleneck shape instead of hand-deriving the conv arithmetic.
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.c_in, self.input_len)
            h = self.encoder(dummy_input.unsqueeze(1))
        self.bottleneck_shape = h.shape[1:]
        # Per-sample feature count (exclude the batch dim of the probe).
        flat_bottleneck = int(h[0].numel())

        self.to_latent = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_bottleneck, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, latent_dim),
        )

        self.from_latent = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, flat_bottleneck),
            nn.Unflatten(dim=1, unflattened_size=self.bottleneck_shape),
        )

        # Decoder mirrors the encoder: reversed channel widths, strides
        # applied back-to-front, output_padding=(0, 1) so each transpose
        # layer exactly doubles the (even) length axis.
        dec_channels = list(reversed(channels))
        dec_layers: list[nn.Module] = []
        for i in range(len(strides) - 1):
            dec_layers += [
                nn.ConvTranspose2d(
                    in_channels=dec_channels[i],
                    out_channels=dec_channels[i + 1],
                    kernel_size=(1, 3),
                    stride=strides[-(i + 1)],
                    padding=(0, 1),
                    output_padding=(0, 1),
                ),
                nn.ReLU(inplace=True),
            ]
        dec_layers += [
            # Kernel height c_in expands height 1 back to c_in, undoing the
            # first encoder layer. PReLU (not ReLU) so reconstructions can
            # take negative values.
            nn.ConvTranspose2d(
                in_channels=dec_channels[-2],
                out_channels=dec_channels[-1],
                kernel_size=(self.c_in, 3),
                stride=strides[0],
                padding=(0, 1),
                output_padding=(0, 1),
            ),
            nn.PReLU(),
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch ``(B, c_in, L)`` to latent vectors ``(B, latent_dim)``."""
        h = x.unsqueeze(1)  # (B, 1, c_in, L): add the image-channel axis
        h = self.encoder(h)
        z = self.to_latent(h)
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors ``(B, latent_dim)`` back to ``(B, c_in, L')``."""
        h = self.from_latent(z)
        x_hat = self.decoder(h)
        x_hat = x_hat.squeeze(1)  # drop the image-channel axis
        return x_hat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder pass; output is trimmed to the input's length."""
        original_len = x.shape[2]
        z = self.encode(x)
        x_hat = self.decode(z)
        # Transposed convs can overshoot by a few samples when L is not a
        # multiple of the total stride; trim back to the caller's length.
        return x_hat[:, :, :original_len]