ash12321 committed on
Commit
2bf4ec2
·
verified ·
1 Parent(s): f652fef

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +182 -0
model.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Residual Convolutional Autoencoder for Image Reconstruction
3
+ Architecture: 6-layer encoder/decoder with residual blocks
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class AEResidualBlock(nn.Module):
12
+ """Residual block with batch normalization and dropout"""
13
+ def __init__(self, channels, dropout=0.1):
14
+ super().__init__()
15
+ self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
16
+ self.bn1 = nn.BatchNorm2d(channels)
17
+ self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
18
+ self.bn2 = nn.BatchNorm2d(channels)
19
+ self.relu = nn.ReLU(inplace=True)
20
+ self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
21
+
22
+ def forward(self, x):
23
+ residual = x
24
+ out = self.relu(self.bn1(self.conv1(x)))
25
+ out = self.dropout(out)
26
+ out = self.bn2(self.conv2(out))
27
+ out += residual
28
+ return self.relu(out)
29
+
30
+
31
+ class ResidualConvAutoencoder(nn.Module):
32
+ """
33
+ Deep Convolutional Autoencoder with Residual Connections
34
+
35
+ Args:
36
+ latent_dim (int): Dimension of latent space (512 or 768)
37
+ dropout (float): Dropout rate for regularization (0.15 or 0.20)
38
+
39
+ Input: (B, 3, 256, 256) RGB images
40
+ Output: (B, 3, 256, 256) Reconstructed images + (B, latent_dim) latent codes
41
+ """
42
+ def __init__(self, latent_dim=512, dropout=0.15):
43
+ super().__init__()
44
+
45
+ self.latent_dim = latent_dim
46
+ self.dropout = dropout
47
+
48
+ # Encoder: 256x256 -> 4x4
49
+ self.encoder = nn.Sequential(
50
+ # 256 -> 128
51
+ nn.Conv2d(3, 64, 4, stride=2, padding=1),
52
+ nn.BatchNorm2d(64),
53
+ nn.ReLU(inplace=True),
54
+ AEResidualBlock(64, dropout),
55
+
56
+ # 128 -> 64
57
+ nn.Conv2d(64, 128, 4, stride=2, padding=1),
58
+ nn.BatchNorm2d(128),
59
+ nn.ReLU(inplace=True),
60
+ AEResidualBlock(128, dropout),
61
+
62
+ # 64 -> 32
63
+ nn.Conv2d(128, 256, 4, stride=2, padding=1),
64
+ nn.BatchNorm2d(256),
65
+ nn.ReLU(inplace=True),
66
+ AEResidualBlock(256, dropout),
67
+
68
+ # 32 -> 16
69
+ nn.Conv2d(256, 512, 4, stride=2, padding=1),
70
+ nn.BatchNorm2d(512),
71
+ nn.ReLU(inplace=True),
72
+ AEResidualBlock(512, dropout),
73
+
74
+ # 16 -> 8
75
+ nn.Conv2d(512, 512, 4, stride=2, padding=1),
76
+ nn.BatchNorm2d(512),
77
+ nn.ReLU(inplace=True),
78
+ AEResidualBlock(512, dropout),
79
+
80
+ # 8 -> 4
81
+ nn.Conv2d(512, 512, 4, stride=2, padding=1),
82
+ nn.BatchNorm2d(512),
83
+ nn.ReLU(inplace=True),
84
+ )
85
+
86
+ # Latent space projection
87
+ self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
88
+ self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)
89
+
90
+ # Decoder: 4x4 -> 256x256
91
+ self.decoder = nn.Sequential(
92
+ # 4 -> 8
93
+ nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
94
+ nn.BatchNorm2d(512),
95
+ nn.ReLU(inplace=True),
96
+ AEResidualBlock(512, dropout),
97
+
98
+ # 8 -> 16
99
+ nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
100
+ nn.BatchNorm2d(512),
101
+ nn.ReLU(inplace=True),
102
+ AEResidualBlock(512, dropout),
103
+
104
+ # 16 -> 32
105
+ nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
106
+ nn.BatchNorm2d(256),
107
+ nn.ReLU(inplace=True),
108
+ AEResidualBlock(256, dropout),
109
+
110
+ # 32 -> 64
111
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
112
+ nn.BatchNorm2d(128),
113
+ nn.ReLU(inplace=True),
114
+ AEResidualBlock(128, dropout),
115
+
116
+ # 64 -> 128
117
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
118
+ nn.BatchNorm2d(64),
119
+ nn.ReLU(inplace=True),
120
+ AEResidualBlock(64, dropout),
121
+
122
+ # 128 -> 256
123
+ nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
124
+ nn.Tanh() # Output in [-1, 1]
125
+ )
126
+
127
+ def forward(self, x):
128
+ """
129
+ Forward pass
130
+
131
+ Args:
132
+ x: Input tensor (B, 3, 256, 256) in range [-1, 1]
133
+
134
+ Returns:
135
+ reconstructed: Reconstructed tensor (B, 3, 256, 256)
136
+ latent: Latent representation (B, latent_dim)
137
+ """
138
+ # Encode
139
+ x = self.encoder(x)
140
+ x = x.view(x.size(0), -1)
141
+ latent = self.fc_encoder(x)
142
+
143
+ # Decode
144
+ x = self.fc_decoder(latent)
145
+ x = x.view(x.size(0), 512, 4, 4)
146
+ reconstructed = self.decoder(x)
147
+
148
+ return reconstructed, latent
149
+
150
+ def encode(self, x):
151
+ """Get latent representation only"""
152
+ x = self.encoder(x)
153
+ x = x.view(x.size(0), -1)
154
+ return self.fc_encoder(x)
155
+
156
+ def decode(self, latent):
157
+ """Reconstruct from latent code"""
158
+ x = self.fc_decoder(latent)
159
+ x = x.view(x.size(0), 512, 4, 4)
160
+ return self.decoder(x)
161
+
162
+
163
+ def load_model(checkpoint_path, latent_dim=512, dropout=0.15, device='cuda'):
164
+ """
165
+ Load a trained model from checkpoint
166
+
167
+ Args:
168
+ checkpoint_path: Path to .pth checkpoint file
169
+ latent_dim: Latent dimension (512 for Model A, 768 for Model B)
170
+ dropout: Dropout rate (0.15 for Model A, 0.20 for Model B)
171
+ device: Device to load model on
172
+
173
+ Returns:
174
+ model: Loaded model in eval mode
175
+ checkpoint: Full checkpoint dict with metadata
176
+ """
177
+ model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout)
178
+ checkpoint = torch.load(checkpoint_path, map_location=device)
179
+ model.load_state_dict(checkpoint['model_state_dict'])
180
+ model.eval()
181
+ model.to(device)
182
+ return model, checkpoint