# Grouped autoencoder demo: independent per-group linear encoders, shared linear decoder.
class GroupedAutoEncoder(nn.Module):
    """Autoencoder that encodes disjoint slices of the input independently.

    The input is split into ``num_groups`` equal slices along the feature
    dimension; each slice gets its own bias-free linear encoder. A single
    shared bias-free linear decoder maps the concatenated codes back to the
    full input dimension, so information can mix across groups only at
    reconstruction time.

    Args:
        input_dim:  total input feature dimension; must be divisible by
            ``num_groups``.
        hidden_dim: total bottleneck dimension; must be divisible by
            ``num_groups``.
        num_groups: number of independent encoder groups.

    Raises:
        ValueError: if ``input_dim`` or ``hidden_dim`` is not divisible by
            ``num_groups``.
    """
    def __init__(self, input_dim, hidden_dim, num_groups):
        super().__init__()
        # Validate BEFORE deriving the per-group sizes, and raise instead of
        # assert so the check survives `python -O`.
        if input_dim % num_groups != 0:
            raise ValueError("Input dimension must be divisible by the number of groups.")
        if hidden_dim % num_groups != 0:
            raise ValueError("Hidden dimension must be divisible by the number of groups.")
        self.num_groups = num_groups
        self.group_input_dim = input_dim // num_groups
        self.group_hidden_dim = hidden_dim // num_groups
        # One independent (bias-free) linear encoder per input slice.
        self.encoders = nn.ModuleList([
            nn.Linear(self.group_input_dim, self.group_hidden_dim, bias=False)
            for _ in range(num_groups)
        ])
        # Single shared decoder over the concatenated group codes.
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
        self.init_weights()

    def init_weights(self):
        """Apply Xavier-uniform initialization to every weight matrix."""
        for encoder in self.encoders:
            nn.init.xavier_uniform_(encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, x):
        """Encode each feature slice with its own encoder, then decode jointly.

        Args:
            x: tensor of shape ``(batch, input_dim)``.

        Returns:
            Reconstruction tensor of shape ``(batch, input_dim)``.
        """
        # Split along the feature axis into num_groups equal slices.
        group_inputs = torch.split(x, self.group_input_dim, dim=1)
        encoded_groups = [
            encoder(group) for group, encoder in zip(group_inputs, self.encoders)
        ]
        return self.decoder(torch.cat(encoded_groups, dim=1))
# Demo configuration: 40 groups, each mapping 128 input features to 8 code features.
input_dim = 5120
hidden_dim = 320
num_groups = 40
# Select the GPU only when one is present instead of crashing on CPU-only
# machines with an unconditional .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GroupedAutoEncoder(input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups).to(device)