| import torch |
| import torch.nn as nn |
| from .configuration_autoencoder import AutoencoderConfig |
| from transformers import PreTrainedModel |
|
|
|
|
class Encoder(nn.Module):
    """Convolutional encoder mapping square images to a flat latent vector.

    Three stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size. An ``image_size x image_size`` input therefore becomes a
    ``128 x (image_size // 8) x (image_size // 8)`` feature map. That map is
    flattened and linearly projected to ``latent_dim`` features.

    Args:
        latent_dim: Size of the output latent vector.
        image_size: Height/width of the square input images. Must be a
            positive multiple of 8. The default of 256 reproduces the
            previously hard-coded ``128 * 32 * 32`` projection.
        in_channels: Number of channels in the input images.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=256, image_size=256, in_channels=3):
        super().__init__()
        if image_size <= 0 or image_size % 8:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Spatial side length after the three halvings in the conv stack.
        feature_size = image_size // 8

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128 * feature_size * feature_size, latent_dim)

    def forward(self, x):
        """Encode a batch of images into latent vectors.

        Args:
            x: Tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)
|
|
|
|
class Decoder(nn.Module):
    """Convolutional decoder mapping latent vectors back to square images.

    The latent vector is projected and reshaped into a
    ``128 x (image_size // 8) x (image_size // 8)`` feature map. Three
    stride-2 transposed convolutions (kernel 4, padding 1) then each double
    the spatial size. A final sigmoid squashes outputs into ``[0, 1]``.

    Args:
        latent_dim: Size of the input latent vector.
        image_size: Height/width of the square output images. Must be a
            positive multiple of 8 so the upsampling stack can reproduce it
            exactly. The default of 256 matches the previously hard-coded
            ``128 x 32 x 32`` feature map.
        out_channels: Number of channels in the reconstructed images.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=256, image_size=256, out_channels=3):
        super().__init__()
        if image_size <= 0 or image_size % 8:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Side length of the feature map fed into the transposed-conv stack.
        self._feature_size = image_size // 8

        self.fc = nn.Linear(latent_dim, 128 * self._feature_size * self._feature_size)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channels, image_size, image_size)``
            with values in ``[0, 1]``.
        """
        x = self.fc(z)
        # -1 lets the batch dimension be inferred from the projected tensor.
        x = x.view(-1, 128, self._feature_size, self._feature_size)
        return self.deconv(x)
|
|
|
|
class Autoencoder(PreTrainedModel):
    """Encoder/decoder pair wrapped as a Hugging Face ``PreTrainedModel``.

    Both halves are sized from ``config.latent_dim``. ``forward`` returns the
    decoder's reconstruction of its input.
    """

    config_class = AutoencoderConfig

    def __init__(self, config):
        super().__init__(config)
        latent_dim = config.latent_dim

        # Registration order (encoder, then decoder) determines state_dict key order.
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

        # Hugging Face hook that finalizes model setup after submodules exist.
        self.post_init()

    def forward(self, x):
        """Encode ``x`` to a latent vector and decode it back to an image."""
        return self.decoder(self.encoder(x))