class GroupedAutoEncoder(nn.Module):
    """Linear autoencoder with a grouped (block-diagonal) encoder and a dense decoder.

    The last dimension of the input is partitioned into ``num_groups``
    contiguous, equally sized slices. Each slice is projected by its own
    bias-free ``nn.Linear``, so the encoder as a whole is block-diagonal. The
    per-group codes are concatenated into a ``hidden_dim`` code, and a single
    bias-free dense ``nn.Linear`` maps that code back to ``input_dim`` features.

    Args:
        input_dim: Number of input features; must be divisible by ``num_groups``.
        hidden_dim: Total code size; must be divisible by ``num_groups``.
        num_groups: Number of independent encoder groups; must be positive.

    Raises:
        ValueError: If ``num_groups`` is not positive, or if either dimension
            is not divisible by ``num_groups``.
    """

    def __init__(self, input_dim, hidden_dim, num_groups):
        super().__init__()

        # Validate with explicit exceptions: ``assert`` is stripped under
        # ``python -O``, and these checks must run before the divisions below.
        if num_groups <= 0:
            raise ValueError("Number of groups must be positive.")
        if input_dim % num_groups != 0:
            raise ValueError("Input dimension must be divisible by the number of groups.")
        if hidden_dim % num_groups != 0:
            raise ValueError("Hidden dimension must be divisible by the number of groups.")

        self.num_groups = num_groups
        self.group_input_dim = input_dim // num_groups
        self.group_hidden_dim = hidden_dim // num_groups

        # One independent projection per feature group (block-diagonal encoder).
        self.encoders = nn.ModuleList([
            nn.Linear(self.group_input_dim, self.group_hidden_dim, bias=False)
            for _ in range(num_groups)
        ])
        # A single dense decoder that mixes information across all groups.
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.init_weights()

    def init_weights(self):
        """Re-initialize every encoder weight and the decoder weight with Xavier-uniform."""
        for encoder in self.encoders:
            nn.init.xavier_uniform_(encoder.weight)
        nn.init.xavier_uniform_(self.decoder.weight)

    def forward(self, x):
        """Encode each feature group independently, then decode the joint code.

        Args:
            x: Tensor whose last dimension has size ``input_dim``. Any leading
                dimensions are carried through unchanged.

        Returns:
            The reconstruction; its last dimension has size ``input_dim`` and
            its leading dimensions match ``x``.

        Raises:
            ValueError: If ``x`` is 0-d or its last dimension is not ``input_dim``.
        """
        expected_dim = self.group_input_dim * self.num_groups
        # Without this check, a wider input is split into extra chunks that
        # zip() silently drops, so trailing features would be ignored.
        if x.dim() == 0 or x.shape[-1] != expected_dim:
            raise ValueError(
                f"Expected last dimension of size {expected_dim}, "
                f"got input of shape {tuple(x.shape)}."
            )

        group_inputs = torch.split(x, self.group_input_dim, dim=-1)
        encoded_groups = [
            encoder(group) for group, encoder in zip(group_inputs, self.encoders)
        ]
        return self.decoder(torch.cat(encoded_groups, dim=-1))
input_dim = 5120
hidden_dim = 320
num_groups = 40
# Fall back to CPU so the model can still be built on machines without a
# CUDA device; an unconditional ``.cuda()`` raises there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GroupedAutoEncoder(
    input_dim=input_dim, hidden_dim=hidden_dim, num_groups=num_groups
).to(device)