File size: 1,590 Bytes
ed83042 00588fa ed83042 40e07cf c4f79b7 ed83042 c9f1804 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import torch
import torch.nn as nn
from .configuration_autoencoder import AutoencoderConfig
from transformers import PreTrainedModel
class Encoder(nn.Module):
    """Convolutional encoder mapping a square RGB image to a flat latent vector.

    Three stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    resolution, so an ``image_size x image_size`` input becomes a
    ``128 x (image_size // 8) x (image_size // 8)`` feature map. That map is
    flattened and linearly projected to ``latent_dim`` features.

    Args:
        latent_dim: Size of the output latent vector.
        image_size: Height/width of the square input image. Must be a positive
            multiple of 8. Defaults to 256, the size the layer shapes were
            originally hard-coded for, so default-constructed modules (and
            their checkpoints) are unchanged.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=256, image_size=256):
        super().__init__()
        # Three exact halvings (and a matching decoder) require divisibility by 8.
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Spatial size of the last conv feature map (256 -> 32 by default).
        self.feature_size = image_size // 8
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # H -> H/2   (256 -> 128 by default)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # H/2 -> H/4 (128 -> 64)
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # H/4 -> H/8 (64 -> 32)
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(
            128 * self.feature_size * self.feature_size, latent_dim
        )

    def forward(self, x):
        """Encode a batch of images into latent vectors.

        Args:
            x: Tensor of shape ``(batch, 3, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector back to a square RGB image.

    A linear layer expands the latent vector into a
    ``128 x (image_size // 8) x (image_size // 8)`` feature map. Three
    stride-2 transposed convolutions (kernel 4, padding 1) then each double
    the spatial resolution. A final sigmoid keeps outputs in ``(0, 1)``.

    Args:
        latent_dim: Size of the input latent vector.
        image_size: Height/width of the square output image. Must be a
            positive multiple of 8. Defaults to 256, the size the layer shapes
            were originally hard-coded for, so default-constructed modules
            (and their checkpoints) are unchanged.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=256, image_size=256):
        super().__init__()
        # Three exact doublings must land on image_size, so it must divide by 8.
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Spatial size of the seed feature map (32 for the default 256).
        self.feature_size = image_size // 8
        self.fc = nn.Linear(
            latent_dim, 128 * self.feature_size * self.feature_size
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # H/8 -> H/4 (32 -> 64)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # H/4 -> H/2 (64 -> 128)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),  # H/2 -> H   (128 -> 256)
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, image_size, image_size)`` with values
            in ``(0, 1)``.
        """
        x = self.fc(z)
        # -1 lets the batch dimension be inferred from the flat fc output.
        x = x.view(-1, 128, self.feature_size, self.feature_size)
        return self.deconv(x)
class Autoencoder(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapping a conv encoder/decoder pair.

    Both halves are sized from ``config.latent_dim``. ``forward`` returns the
    decoder's reconstruction of its input.
    """

    config_class = AutoencoderConfig

    def __init__(self, config):
        super().__init__(config)
        latent_dim = config.latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        # transformers' end-of-construction hook (e.g. weight initialization).
        self.post_init()

    def forward(self, x):
        """Encode ``x`` to a latent vector and decode it back to an image."""
        return self.decoder(self.encoder(x))