File size: 3,150 Bytes
57eeb52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

# ----------------------------------------------------
# A helper block for the Residual Connection
# ----------------------------------------------------
class ResidualBlock(nn.Module):
    """Two-convolution residual unit.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution. When it is not 1, or the
            channel count changes, the shortcut becomes a strided 1x1
            conv + BatchNorm projection so the two branches can be summed.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path. The attribute names define the state_dict keys, so keep them stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The identity can only be added directly when the shapes already match.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + self.shortcut(x))

# ----------------------------------------------------
# The Main Residual Autoencoder Model
# ----------------------------------------------------
class ResidualConvAutoencoder(pl.LightningModule):
    """Convolutional autoencoder built from ``ResidualBlock`` stages.

    Encoder: 3x3 conv to 64 channels, then three stride-2 residual stages
    (64 -> 128 -> 256 -> 512 channels, spatial size / 8), flatten, a linear
    projection to ``latent_dim``, and dropout.
    Decoder: linear projection back to ``512 x (image_size/8)^2``, three residual
    stages (512 -> 256 -> 128 -> 64), each followed by a 2x nearest-neighbour
    upsample, then a 3x3 conv and a sigmoid, so outputs lie in (0, 1).

    The ``nn.Sequential`` layouts are unchanged from the original 32x32 RGB
    model, so existing state_dict checkpoints still load with the defaults.

    Args:
        latent_dim: Size of the bottleneck vector.
        dropout_rate: Dropout probability applied to the latent vector
            (only active in training mode).
        in_channels: Channels of both the input images and the reconstruction.
            Defaults to 3.
        image_size: Height/width of the square input images. Must be a positive
            multiple of 8 because of the three stride-2 stages. Defaults to 32.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim=512, dropout_rate=0.2, in_channels=3, image_size=32):
        super().__init__()
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(f"image_size must be a positive multiple of 8, got {image_size}")
        # Record constructor arguments so `load_from_checkpoint` can rebuild the
        # same architecture (e.g. a non-default latent_dim).
        self.save_hyperparameters()
        self.latent_dim = latent_dim

        # Spatial size after three stride-2 stages (32 -> 4 for the default).
        feat_size = image_size // 8
        flat_features = 512 * feat_size * feat_size

        # --- Encoder ---
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),  # H -> H
            ResidualBlock(64, 128, stride=2),    # H -> H/2
            ResidualBlock(128, 256, stride=2),   # H/2 -> H/4
            ResidualBlock(256, 512, stride=2),   # H/4 -> H/8
            nn.Flatten(),
            nn.Linear(flat_features, self.latent_dim),
            nn.Dropout(dropout_rate)
        )

        # --- Decoder ---
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, flat_features),
            nn.Unflatten(1, (512, feat_size, feat_size)),
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='nearest'),  # H/8 -> H/4
            ResidualBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='nearest'),  # H/4 -> H/2
            ResidualBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='nearest'),  # H/2 -> H
            nn.Conv2d(64, in_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # Output pixel values between 0 and 1
        )

    def forward(self, x):
        """Encode ``x`` to the latent vector and decode it back to an image tensor."""
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon

    def training_step(self, batch, batch_idx):
        """Return the pixel-wise MSE reconstruction loss for one batch.

        ``batch`` may be an image tensor or a ``(images, *extras)`` list/tuple,
        for example ``(images, labels)`` from a labelled dataset; extras are ignored.
        NOTE(review): assumes images are scaled to [0, 1] to match the sigmoid
        output — confirm against the data pipeline.

        The previous placeholder returned ``torch.tensor(0.0)``. That tensor has
        no grad_fn, so Lightning's automatic optimisation fails in backward().
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        recon = self(x)
        loss = F.mse_loss(recon, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all model parameters with PyTorch's default hyperparameters."""
        return torch.optim.Adam(self.parameters())