ash12321 commited on
Commit
57eeb52
·
verified ·
1 Parent(s): 938c9ec

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +76 -0
model.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pytorch_lightning as pl
5
+
6
+ # ----------------------------------------------------
7
+ # A helper block for the Residual Connection
8
+ # ----------------------------------------------------
9
class ResidualBlock(nn.Module):
    """Basic two-layer residual block: conv-bn-relu, conv-bn, skip add, relu.

    If the stride is not 1 or the channel count changes, the skip path is a
    1x1 convolution followed by batch norm so its output matches the main
    branch; otherwise the skip path is the identity.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip connection: identity unless the spatial size or channel
        # count differs, in which case project with a 1x1 conv + BN.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the main branch, add the (possibly projected) input, ReLU."""
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + self.shortcut(x))
31
+
32
+ # ----------------------------------------------------
33
+ # The Main Residual Autoencoder Model
34
+ # ----------------------------------------------------
35
class ResidualConvAutoencoder(pl.LightningModule):
    """Convolutional autoencoder built from residual blocks for 32x32 RGB.

    The encoder downsamples 32x32 -> 4x4 via strided residual blocks and
    projects the flattened features to a `latent_dim` vector (with dropout).
    The decoder mirrors the path with nearest-neighbour upsampling and ends
    in a sigmoid so reconstructed pixel values lie in [0, 1].
    """

    def __init__(self, latent_dim=512, dropout_rate=0.2):
        super().__init__()
        self.latent_dim = latent_dim

        # --- Encoder: 3x32x32 image -> latent vector ---
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 32x32 -> 32x32
            ResidualBlock(64, 128, stride=2),                      # 32x32 -> 16x16
            ResidualBlock(128, 256, stride=2),                     # 16x16 -> 8x8
            ResidualBlock(256, 512, stride=2),                     # 8x8   -> 4x4
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, self.latent_dim),
            nn.Dropout(dropout_rate),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # --- Decoder: latent vector -> 3x32x32 image ---
        decoder_layers = [
            nn.Linear(self.latent_dim, 512 * 4 * 4),
            nn.Unflatten(1, (512, 4, 4)),
            ResidualBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='nearest'),           # 4x4   -> 8x8
            ResidualBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='nearest'),           # 8x8   -> 16x16
            ResidualBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='nearest'),           # 16x16 -> 32x32
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),                                          # Output pixel values between 0 and 1
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode `x` into the latent space, then decode back to an image."""
        latent = self.encoder(x)
        return self.decoder(latent)

    # Placeholder training step (not needed for deployment file, but required for class completeness)
    def training_step(self, batch, batch_idx):
        return torch.tensor(0.0)

    # Placeholder configure_optimizers (not needed for deployment file, but required for class completeness)
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())