import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        sparsity_alpha=0.00004,
        decoder_norm_range=(0.05, 1.0),
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_alpha = sparsity_alpha

        # Encoder: computes the pre-activation W_enc x + b_enc; the ReLU is
        # applied in encode().
        self.enc_bias = nn.Parameter(torch.zeros(hidden_dim))
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)

        # Decoder: reconstructs x as W_dec f + b_dec; each column of W_dec is
        # one learned dictionary vector.
        self.dec_bias = nn.Parameter(torch.zeros(input_dim))
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self._initialize_weights(decoder_norm_range)

    def forward(self, x):
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded, encoded

    def encode(self, x):
        return F.relu(self.encoder(x) + self.enc_bias)

    def decode(self, x):
        return self.decoder(x) + self.dec_bias

    def loss(self, x, decoded, encoded):
        reconstruction_loss = F.mse_loss(decoded, x)
        # L1 penalty on the feature activations, weighting each feature by the
        # L2 norm of its decoder column so the penalty cannot be dodged by
        # scaling a feature up while shrinking its decoder column.
        sparsity_loss = self.sparsity_alpha * torch.sum(
            encoded.abs() * self.decoder.weight.norm(p=2, dim=0)
        )
        total_loss = reconstruction_loss + sparsity_loss
        return total_loss

    def _initialize_weights(self, decoder_norm_range):
        # Initialize the decoder columns (the dictionary vectors) as random
        # directions, then rescale each column to an L2 norm drawn uniformly
        # from [norm_min, norm_max].
        norm_min, norm_max = decoder_norm_range
        norm_range = norm_max - norm_min
        self.decoder.weight.data.normal_(0, 1)
        self.decoder.weight.data /= self.decoder.weight.data.norm(
            p=2, dim=0, keepdim=True
        )
        self.decoder.weight.data *= (
            norm_min + norm_range * torch.rand(1, self.hidden_dim)
        ).expand_as(self.decoder.weight.data)

        # Initialize the encoder as the transpose of the (already initialized)
        # decoder; clone so the two weight matrices do not share storage.
        self.encoder.weight.data = self.decoder.weight.data.t().clone()
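
# ---------------------------------------------------------------------------
# Minimal usage sketch. The activation width, number of dictionary features,
# batch size, learning rate, and choice of Adam below are illustrative
# assumptions, not values taken from the module above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical sizes: 512-d activations, 4096 dictionary features.
    sae = SparseAutoencoder(input_dim=512, hidden_dim=4096)
    optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

    # Stand-in batch of activations; in practice these would come from a
    # hooked layer of the model under analysis.
    x = torch.randn(64, 512)

    # One training step: forward pass, combined loss, gradient update.
    decoded, encoded = sae(x)
    total_loss = sae.loss(x, decoded, encoded)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    print(f"loss: {total_loss.item():.4f}")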