# full-pegs-autoencoder-2Dconv-v1 / convautoencoder_2d.py
# Uploaded via huggingface_hub (revision 5d118c4).
import torch
import torch.nn as nn
class ConvAutoencoder_2D(nn.Module):
    """
    2-D convolutional autoencoder with a linear bottleneck.

    A (N, C_in, L) input is treated as a single-channel 2-D image of shape
    (N, 1, C_in, L): a first Conv2d whose kernel height equals ``c_in``
    collapses the channel axis to height 1, subsequent (1, 3) convolutions
    downsample the length axis, and a Flatten -> Linear stack produces the
    latent vector. The decoder mirrors this with ConvTranspose2d layers and
    a final kernel of height ``c_in`` that restores the original shape.

    Args:
        c_in: number of input channels (treated as a spatial height).
        latent_dim: size of the latent vector.
        input_len: sequence length L the bottleneck Linear layers are sized
            for. Defaults to 10_000 (the original hard-coded value), so
            existing two-argument callers are unaffected.
    """

    def __init__(self, c_in: int, latent_dim: int, input_len: int = 10_000) -> None:
        super().__init__()
        self.c_in: int = c_in
        self.latent_dim: int = latent_dim
        self.input_len: int = input_len

        # ----- Encoder -----
        # The channel plan starts from 1 because the real input channels are
        # treated as a spatial (height) dimension, not conv channels.
        channels: list[int] = [1, 8, 16, 32]
        strides: list[tuple[int, int]] = [(1, 2), (1, 2), (1, 2)]
        enc_layers: list[nn.Module] = [
            # Dedicated first layer: kernel height == c_in collapses the
            # height axis to 1 while the (1, 2) stride halves the length.
            nn.Conv2d(
                in_channels=channels[0],
                out_channels=channels[1],
                kernel_size=(self.c_in, 3),
                stride=strides[0],
                padding=(0, 1),
            ),
            nn.ReLU(inplace=True),
        ]
        # Remaining layers operate on height-1 feature maps.
        for i in range(1, len(strides)):
            enc_layers += [
                nn.Conv2d(
                    in_channels=channels[i],
                    out_channels=channels[i + 1],
                    kernel_size=(1, 3),
                    stride=strides[i],
                    padding=(0, 1),
                ),
                nn.ReLU(inplace=True),
            ]
        self.encoder = nn.Sequential(*enc_layers)

        # ----- Dynamic shape probe for the bottleneck -----
        # A dummy forward pass discovers the flattened encoder output size,
        # so the Linear layers adapt automatically to c_in and input_len.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, self.c_in, self.input_len)
            h = self.encoder(dummy)
        self.bottleneck_shape = h.shape[1:]  # (C, H, W)
        flat_bottleneck = h.numel()  # batch is 1, so numel == C*H*W

        # ----- Linear bottleneck layers -----
        self.to_latent = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_bottleneck, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, latent_dim),
        )
        self.from_latent = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, flat_bottleneck),
            nn.Unflatten(dim=1, unflattened_size=self.bottleneck_shape),
        )

        # ----- Decoder (mirrors the encoder) -----
        dec_channels = list(reversed(channels))
        dec_layers: list[nn.Module] = []
        # Mirror of the encoder's main loop, strides applied in reverse.
        for i in range(len(strides) - 1):
            dec_layers += [
                nn.ConvTranspose2d(
                    in_channels=dec_channels[i],
                    out_channels=dec_channels[i + 1],
                    kernel_size=(1, 3),
                    stride=strides[-(i + 1)],
                    padding=(0, 1),
                    output_padding=(0, 1),  # corrects output length for stride 2
                ),
                nn.ReLU(inplace=True),
            ]
        # Dedicated final layer restores the original c_in height.
        dec_layers += [
            nn.ConvTranspose2d(
                in_channels=dec_channels[-2],
                out_channels=dec_channels[-1],
                kernel_size=(self.c_in, 3),
                stride=strides[0],
                padding=(0, 1),
                output_padding=(0, 1),
            ),
            nn.PReLU(),  # parametric ReLU as the final activation
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a (N, C_in, L) batch into (N, latent_dim) vectors."""
        # Insert a singleton conv-channel so C_in acts as a spatial height.
        h = self.encoder(x.unsqueeze(1))
        return self.to_latent(h)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode (N, latent_dim) vectors back to (N, C_in, L) signals."""
        x_hat = self.decoder(self.from_latent(z))
        # Drop the singleton conv-channel added in encode().
        return x_hat.squeeze(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full autoencoder pass; output is cropped to the input length."""
        original_len = x.shape[2]
        z = self.encode(x)
        x_hat = self.decode(z)
        # Transposed convolutions can overshoot the length slightly; crop.
        return x_hat[:, :, :original_len]