sparse-cafm / src /models /autoencoder.py
leharris3's picture
Minimal HF Space deployment with gradio 5.x fix
0917e8d
import torch
import torch.nn as nn
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for square-ish images.

    The encoder is four ``Conv2d(kernel_size=3, stride=2, padding=1)`` + ReLU
    layers (``channels -> 16 -> 32 -> 64 -> 128``); each halves the spatial
    size of an even-sized input. The decoder mirrors it with four
    ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)`` layers, each of
    which exactly doubles the spatial size, and ends in ``Tanh``.

    Consequently the reconstruction has the same height/width as the input
    only when both are divisible by 16 (e.g. 64x64 -> 4x4 latent -> 64x64).
    """

    def __init__(self, channels=3):
        """Create the encoder/decoder stacks.

        Args:
            channels: Number of channels in the input image and in the
                reconstructed output (default 3).
        """
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            # Input: (channels, 64, 64)
            nn.Conv2d(channels, 16, kernel_size=3, stride=2, padding=1),  # (16, 32, 32)
            nn.ReLU(True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # (32, 16, 16)
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # (64, 8, 8)
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # (128, 4, 4)
            nn.ReLU(True),
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # (64, 8, 8)
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # (32, 16, 16)
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # (16, 32, 32)
            nn.ReLU(True),
            nn.ConvTranspose2d(16, channels, kernel_size=4, stride=2, padding=1),  # (channels, 64, 64)
            # Tanh squashes the output into (-1, 1) (NOT [0, 1]); targets are
            # presumably normalized to [-1, 1] -- TODO confirm against the
            # training/data pipeline.
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` to the 128-channel latent map and decode it back.

        Args:
            x: Image batch of shape ``(N, channels, H, W)``.

        Returns:
            The reconstruction, with values in (-1, 1).
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

    @staticmethod
    def get(weights=None, *, channels=3):
        """Build an ``Autoencoder``, optionally loading pretrained weights.

        Args:
            weights: ``None`` (default) for a freshly initialized model;
                a state-dict mapping; or anything ``torch.load`` accepts
                (a file path or binary file-like object) that deserializes
                to a state dict.
            channels: Number of image channels; must match the checkpoint.

        Returns:
            The constructed ``Autoencoder``.

        Raises:
            RuntimeError: From ``load_state_dict`` when the weights' keys or
                tensor shapes do not match this architecture.
        """
        from collections.abc import Mapping

        model = Autoencoder(channels=channels)
        if weights is None:
            return model
        if isinstance(weights, Mapping):
            state_dict = weights
        else:
            # NOTE(security): torch.load unpickles its input; only pass
            # checkpoints from trusted sources.
            state_dict = torch.load(weights, map_location="cpu")
        model.load_state_dict(state_dict)
        return model
if __name__ == "__main__":
    # This module only defines the model; executing it directly is a no-op.
    ...