vehitv / modeling_autoencoder.py
quebeccyb's picture
Fix: import config
00588fa
Raw
History Blame Contribute Delete
1.59 kB
import torch
import torch.nn as nn
from .configuration_autoencoder import AutoencoderConfig
from transformers import PreTrainedModel
class Encoder(nn.Module):
    """Convolutional encoder mapping a square 3-channel image to a latent vector.

    Three stride-2 convolutions each halve the spatial resolution (8x in
    total). The resulting 128-channel feature map is flattened and projected
    to ``latent_dim`` values by a linear layer.

    Args:
        latent_dim: Size of the latent vector produced per input image.
        image_size: Height and width of the (square) input images. Must be a
            positive multiple of 8. The default of 256 reproduces the
            original ``128 * 32 * 32`` flattened feature size.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    _FEATURE_CHANNELS = 128  # channels produced by the last convolution
    _DOWNSAMPLE = 8  # three stride-2 convolutions: 2 * 2 * 2

    def __init__(self, latent_dim: int = 256, image_size: int = 256) -> None:
        super().__init__()
        if image_size <= 0 or image_size % self._DOWNSAMPLE != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {self._DOWNSAMPLE}, "
                f"got {image_size}"
            )
        feature_size = image_size // self._DOWNSAMPLE

        # Attribute names and Sequential indices are kept as-is so that
        # state_dict keys stay compatible with existing checkpoints.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # S -> S/2 (256 -> 128 by default)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # S/2 -> S/4 (128 -> 64)
            nn.ReLU(),
            nn.Conv2d(64, self._FEATURE_CHANNELS, 4, 2, 1),  # S/4 -> S/8 (64 -> 32)
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(
            self._FEATURE_CHANNELS * feature_size * feature_size, latent_dim
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into latent vectors.

        Args:
            x: Tensor of shape ``(batch, 3, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector into a 128-channel feature map
    of side ``image_size // 8``. Three stride-2 transposed convolutions then
    each double the spatial resolution, and a final sigmoid bounds the
    3-channel output to the range [0, 1].

    Args:
        latent_dim: Size of the incoming latent vector.
        image_size: Height and width of the (square) images to produce. Must
            be a positive multiple of 8. The default of 256 reproduces the
            original ``128 * 32 * 32`` feature map.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    _FEATURE_CHANNELS = 128  # channels of the feature map fed to the deconvs
    _UPSAMPLE = 8  # three stride-2 transposed convolutions: 2 * 2 * 2

    def __init__(self, latent_dim: int = 256, image_size: int = 256) -> None:
        super().__init__()
        if image_size <= 0 or image_size % self._UPSAMPLE != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {self._UPSAMPLE}, "
                f"got {image_size}"
            )
        # Side length of the square feature map the linear layer produces.
        self.feature_size = image_size // self._UPSAMPLE

        # Attribute names and Sequential indices are kept as-is so that
        # state_dict keys stay compatible with existing checkpoints.
        self.fc = nn.Linear(
            latent_dim,
            self._FEATURE_CHANNELS * self.feature_size * self.feature_size,
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(self._FEATURE_CHANNELS, 64, 4, 2, 1),  # S/8 -> S/4
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # S/4 -> S/2
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),  # S/2 -> S
            nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into images.

        Args:
            z: Tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(N, 3, image_size, image_size)`` with values in
            [0, 1]; ``N`` is the product of all leading dimensions of ``z``.
        """
        x = self.fc(z)
        # -1 folds every leading dimension into the batch axis.
        x = x.view(-1, self._FEATURE_CHANNELS, self.feature_size, self.feature_size)
        return self.deconv(x)
class Autoencoder(PreTrainedModel):
    """Image autoencoder exposed through the ``PreTrainedModel`` interface.

    Pairs an ``Encoder`` with a ``Decoder``; both are built with the latent
    width read from ``config.latent_dim``.
    """

    config_class = AutoencoderConfig

    def __init__(self, config):
        """Build the encoder/decoder pair described by ``config``.

        Args:
            config: Configuration object providing ``latent_dim``.
        """
        super().__init__(config)
        latent_dim = config.latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        # Called last, once every submodule exists.
        self.post_init()

    def forward(self, x):
        """Encode ``x`` to a latent code and return the decoded reconstruction."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction