"""Convolutional autoencoder wrapped as a Hugging Face ``PreTrainedModel``."""

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_autoencoder import AutoencoderConfig

# Shape of the feature map at the bottleneck. The encoder's linear layer
# consumes it and the decoder's linear layer reproduces it, so both classes
# share these values instead of repeating the literals.
_FEATURE_CHANNELS = 128
_FEATURE_SIZE = 32
_FLAT_FEATURES = _FEATURE_CHANNELS * _FEATURE_SIZE * _FEATURE_SIZE


class Encoder(nn.Module):
    """Compress a 3-channel image batch into a latent vector.

    Three stride-2 convolutions each halve the spatial resolution, and the
    final linear layer expects a 128 x 32 x 32 feature map, so inputs must be
    256 x 256 pixels.

    Args:
        latent_dim: Size of the latent vector produced per sample.
    """

    def __init__(self, latent_dim: int = 256) -> None:
        super().__init__()
        # Conv2d positional args: in_channels, out_channels, kernel_size,
        # stride, padding.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # 256 -> 128
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # 128 -> 64
            nn.ReLU(),
            nn.Conv2d(64, _FEATURE_CHANNELS, 4, 2, 1),  # 64 -> 32
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(_FLAT_FEATURES, latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into latent vectors.

        Args:
            x: Image batch of shape ``(batch, 3, 256, 256)``.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)


class Decoder(nn.Module):
    """Reconstruct a 3-channel image batch from latent vectors.

    Mirrors ``Encoder``: a linear layer expands the latent vector to a
    128 x 32 x 32 feature map, then three stride-2 transposed convolutions
    each double the spatial resolution.

    Args:
        latent_dim: Size of the latent vector consumed per sample.
    """

    def __init__(self, latent_dim: int = 256) -> None:
        super().__init__()
        self.fc = nn.Linear(latent_dim, _FLAT_FEATURES)
        # The reshape stays in forward() rather than becoming a module inside
        # this Sequential, so the layer indices (and therefore the state_dict
        # keys) are unchanged.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(_FEATURE_CHANNELS, 64, 4, 2, 1),  # 32 -> 64
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # 64 -> 128
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),  # 128 -> 256
            nn.Sigmoid(),  # squashes outputs into (0, 1)
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors ``z`` into images.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, 256, 256)`` with values in (0, 1).
        """
        x = self.fc(z)
        # Restore the (channels, height, width) layout the encoder flattened.
        x = x.view(-1, _FEATURE_CHANNELS, _FEATURE_SIZE, _FEATURE_SIZE)
        return self.deconv(x)


class Autoencoder(PreTrainedModel):
    """Encoder/decoder pair exposed through the ``PreTrainedModel`` API.

    Args:
        config: An ``AutoencoderConfig``; its ``latent_dim`` attribute sets
            the bottleneck width of both the encoder and the decoder.
    """

    config_class = AutoencoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Encoder(config.latent_dim)
        self.decoder = Decoder(config.latent_dim)
        # Standard PreTrainedModel finalisation step (see the transformers
        # docs); called last so every submodule already exists.
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction from the decoder.

        Args:
            x: Image batch of shape ``(batch, 3, 256, 256)``.

        Returns:
            Reconstructed batch of shape ``(batch, 3, 256, 256)``.
        """
        z = self.encoder(x)
        return self.decoder(z)