File size: 1,590 Bytes
ed83042
 
00588fa
ed83042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e07cf
c4f79b7
ed83042
 
 
 
 
 
 
 
 
 
 
c9f1804
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
from .configuration_autoencoder import AutoencoderConfig
from transformers import PreTrainedModel


class Encoder(nn.Module):
    """Compress a batch of square images into flat latent vectors.

    Three stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size, so an ``image_size`` x ``image_size`` input becomes a
    ``128 x (image_size // 8) x (image_size // 8)`` feature map. That map is
    flattened and linearly projected to ``latent_dim`` features.

    The defaults (3 channels, 256x256) reproduce the original fixed
    architecture exactly.

    Args:
        latent_dim: Size of the output latent vector.
        in_channels: Number of channels in the input images.
        image_size: Side length of the square input images. Must be a positive
            multiple of 8 (three halvings) so that the mirrored decoder
            reconstructs the same size.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=256, in_channels=3, image_size=256):
        super().__init__()
        if image_size <= 0 or image_size % 8:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Spatial side length after the three stride-2 convolutions.
        feature_size = image_size // 8

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # S   -> S/2  (256 -> 128 by default)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),           # S/2 -> S/4  (128 -> 64)
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),          # S/4 -> S/8  (64 -> 32)
            nn.ReLU(),
        )

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(128 * feature_size * feature_size, latent_dim)

    def forward(self, x):
        """Encode ``x`` into a ``(batch, latent_dim)`` tensor.

        ``x`` must be batched, ``(batch, in_channels, image_size, image_size)``:
        ``nn.Flatten`` keeps dim 0 as the batch, and ``fc`` expects exactly
        ``128 * (image_size // 8) ** 2`` features per sample.
        """
        x = self.conv(x)
        x = self.flatten(x)
        return self.fc(x)


class Decoder(nn.Module):
    """Reconstruct square images from latent vectors.

    A linear layer expands the latent vector into a
    ``128 x (image_size // 8) x (image_size // 8)`` feature map. Three
    stride-2 transposed convolutions (kernel 4, padding 1) then each double the
    spatial size back to ``image_size``. A final sigmoid bounds every output
    value to ``[0, 1]``.

    The defaults (3 channels, 256x256) reproduce the original fixed
    architecture exactly.

    Args:
        latent_dim: Size of the input latent vector.
        out_channels: Number of channels in the reconstructed images.
        image_size: Side length of the square output images. Must be a
            positive multiple of 8 (three doublings from the feature map).

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=256, out_channels=3, image_size=256):
        super().__init__()
        if image_size <= 0 or image_size % 8:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Spatial side length of the feature map fed to the first deconvolution.
        self._feature_size = image_size // 8

        self.fc = nn.Linear(latent_dim, 128 * self._feature_size * self._feature_size)

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),           # S/8 -> S/4  (32 -> 64 by default)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),            # S/4 -> S/2  (64 -> 128)
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, 2, 1),  # S/2 -> S    (128 -> 256)
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode ``z`` into ``(batch, out_channels, image_size, image_size)`` images.

        ``view(-1, ...)`` folds every leading dimension of the linear output
        into the batch dimension, so a single unbatched latent vector yields a
        batch of one.
        """
        side = self._feature_size
        x = self.fc(z)
        x = x.view(-1, 128, side, side)
        return self.deconv(x)


class Autoencoder(PreTrainedModel):
    """Convolutional autoencoder exposed as a Hugging Face ``PreTrainedModel``.

    Both halves are sized from ``config.latent_dim``. The encoder maps images
    to latent vectors, and the decoder maps them back to images.
    """

    config_class = AutoencoderConfig

    def __init__(self, config):
        super().__init__(config)

        # Build the encoder first, then the decoder (registration order is kept).
        width = config.latent_dim
        self.encoder, self.decoder = Encoder(width), Decoder(width)

        # Let PreTrainedModel run its end-of-construction hook.
        self.post_init()

    def forward(self, x):
        """Return the reconstruction of ``x``: encode it, then decode the latent."""
        return self.decoder(self.encoder(x))