File size: 2,067 Bytes
f3b5cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84c306c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
    
class GroupedAutoEncoder(nn.Module):
    """Linear autoencoder with a group-wise encoder and a single dense decoder.

    The last dimension of the input (``input_dim`` features) is split into
    ``num_groups`` contiguous, equally sized chunks. Each chunk is projected by
    its own bias-free ``nn.Linear`` to ``hidden_dim // num_groups`` features.
    The per-group codes are concatenated into a ``hidden_dim`` code, which one
    bias-free ``nn.Linear`` maps back to ``input_dim`` features.

    Args:
        input_dim: Number of input features; must be divisible by ``num_groups``.
        hidden_dim: Total code size across all groups; must be divisible by
            ``num_groups``.
        num_groups: Number of independent encoder groups; must be positive.

    Raises:
        ValueError: If ``num_groups`` is not positive, or if either dimension
            is not divisible by ``num_groups``.
    """

    def __init__(self, input_dim, hidden_dim, num_groups):
        super().__init__()

        # Validate before dividing, so a bad num_groups gives a clear message
        # instead of ZeroDivisionError. Use explicit raises because ``assert``
        # statements are stripped under ``python -O``.
        if num_groups <= 0:
            raise ValueError("Number of groups must be positive.")
        if input_dim % num_groups != 0:
            raise ValueError("Input dimension must be divisible by the number of groups.")
        if hidden_dim % num_groups != 0:
            raise ValueError("Hidden dimension must be divisible by the number of groups.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_groups = num_groups
        self.group_input_dim = input_dim // num_groups
        self.group_hidden_dim = hidden_dim // num_groups

        # One encoder per group; together they act as a block-diagonal projection.
        self.encoders = nn.ModuleList([
            nn.Linear(self.group_input_dim, self.group_hidden_dim, bias=False)
            for _ in range(num_groups)
        ])
        # A single dense decoder mixes information across all groups.
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise every encoder weight and the decoder weight."""
        for encoder in self.encoders:
            nn.init.xavier_uniform_(encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, x):
        """Encode each feature group independently, then decode the joint code.

        Args:
            x: Tensor whose last dimension has size ``input_dim``. Any leading
                (e.g. batch) dimensions are preserved.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``input_dim``.
        """
        # Without this check, a wider input would be split into extra chunks
        # that ``zip`` silently drops. A narrower one would fail deep inside a matmul.
        if x.shape[-1] != self.input_dim:
            raise ValueError(
                f"Expected last dimension {self.input_dim}, got {x.shape[-1]}."
            )

        # Split along the feature (last) axis, the same axis nn.Linear acts on.
        group_inputs = torch.split(x, self.group_input_dim, dim=-1)
        encoded_groups = [
            encoder(group) for group, encoder in zip(group_inputs, self.encoders)
        ]
        return self.decoder(torch.cat(encoded_groups, dim=-1))

# Example configuration: 40 groups, each mapping 5120 / 40 = 128 input
# features to a 320 / 40 = 8 feature code.
input_dim, hidden_dim, num_groups = 5120, 320, 40

# Built at import time and moved to the current CUDA device via .cuda(),
# so importing this module requires a CUDA-capable GPU.
model = GroupedAutoEncoder(input_dim, hidden_dim, num_groups).cuda()