# CellDreamer: VAE encoder/decoder with a GRU-based RSSM world model.
import torch
import torch.nn as nn
import torch.nn.functional as F
# MLP variational encoder.
# Maps a batch of expression profiles (batch x num_genes) to the mean and
# standard deviation of a diagonal Gaussian posterior (each batch x latent_dim).
class Encoder(nn.Module):
    def __init__(self, latent_dim, hidden_dims, num_genes=2446):
        super().__init__()
        blocks = []
        in_features = num_genes
        for width in hidden_dims:
            blocks.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ELU(),
                nn.Dropout(0.4),
            ])
            in_features = width
        self.enc_net = nn.Sequential(*blocks)
        self.fc_mean = nn.Linear(in_features, latent_dim)
        self.fc_std = nn.Linear(in_features, latent_dim)

    def forward(self, x_t):
        """Return (mean, std) of the approximate posterior for x_t."""
        features = self.enc_net(x_t)
        # softplus keeps std positive; the 1e-3 floor stops it from
        # collapsing to near-zero (posterior collapse).
        std = F.softplus(self.fc_std(features)) + 1e-3
        return self.fc_mean(features), std
# MLP decoder, the counterpart of Encoder.
# Reconstructs an expression profile (batch x num_genes) from the
# concatenation of a latent sample z (batch x latent_dim) and an RNN
# hidden state h (batch x rnn_hidden_dim).
class Decoder(nn.Module):
    def __init__(self, latent_dim, rnn_hidden_dim, hidden_dims, num_genes=2446):
        super().__init__()
        widths = [latent_dim + rnn_hidden_dim] + list(hidden_dims)
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules += [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ELU(),
                nn.Dropout(0.4),
            ]
        modules.append(nn.Linear(widths[-1], num_genes))
        self.dec_net = nn.Sequential(*modules)

    def forward(self, z, h):
        """Decode the concatenated [z, h] into a batch x num_genes output."""
        return self.dec_net(torch.cat([z, h], dim=1))
# GRU-based recurrent state-space model (RSSM).
# One transition step: consumes the previous latent sample
# (batch x latent_dim) and hidden state (batch x rnn_hidden_dim) and
# predicts the next hidden state plus the mean/std of the latent prior.
class RSSM(nn.Module):
    def __init__(self, latent_dim, rnn_hidden_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = rnn_hidden_dim
        self.gru = nn.GRUCell(latent_dim, rnn_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim),
            nn.LayerNorm(rnn_hidden_dim),
            nn.ELU(),
            nn.Linear(rnn_hidden_dim, 2 * latent_dim),
        )
        # Small-gain (0.1) Xavier init on the output head keeps the initial
        # prior statistics close to zero at the start of training.
        nn.init.xavier_uniform_(self.mlp[3].weight, gain=0.1)
        nn.init.zeros_(self.mlp[3].bias)

    def forward(self, prev_r, prev_h):
        """Advance one step; return (h_t, prior_mean, prior_std)."""
        h_t = self.gru(prev_r, prev_h)
        mean, raw_std = torch.chunk(self.mlp(h_t), 2, dim=1)
        # softplus plus a 1e-3 floor keeps the prior std strictly positive.
        std = F.softplus(raw_std) + 1e-3
        return h_t, mean, std
# Joint Dreamer-style training architecture combining the Encoder,
# Decoder, and RSSM transition model.
class CellDreamer(nn.Module):
    def __init__(
        self,
        device,
        latent_dim=20,
        rnn_dim=64,
        enc_hidden_dims=None,
        dec_hidden_dims=None,
        num_genes=2446,
    ):
        super().__init__()
        # Avoid mutable default arguments (shared across instances);
        # fall back to the original layer widths when not provided.
        enc_hidden_dims = [128, 64, 32] if enc_hidden_dims is None else enc_hidden_dims
        dec_hidden_dims = [32, 64, 128] if dec_hidden_dims is None else dec_hidden_dims
        self.encoder = Encoder(latent_dim, enc_hidden_dims, num_genes)
        self.decoder = Decoder(latent_dim, rnn_dim, dec_hidden_dims, num_genes)
        self.rssm = RSSM(latent_dim, rnn_dim)
        self.rnn_dim = rnn_dim
        self.latent_dim = latent_dim
        self.input_dim = num_genes
        self.device = device

    def reparametrize(self, mean, std):
        """Sample z = mean + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x_t):
        """Run one model step on x_t (batch x num_genes).

        Encodes x_t, samples the posterior latent, rolls the RSSM forward
        one step from a zero hidden state, and decodes the reconstruction.
        Returns a dict with the reconstruction, posterior stats, next-step
        prior stats, and the latent/hidden states.
        """
        post_mean, post_std = self.encoder(x_t)
        z_t = self.reparametrize(post_mean, post_std)
        # Allocate directly on the target device instead of building a CPU
        # tensor and paying for an extra .to(device) copy.
        h_prev = torch.zeros(x_t.size(0), self.rnn_dim, device=self.device)
        h_next, velocity_mean, velocity_std = self.rssm(z_t, h_prev)
        # The RSSM predicts a latent "velocity"; the next-step prior mean is
        # the current latent plus that predicted displacement.
        prior_next_mean = z_t + velocity_mean
        prior_next_std = velocity_std
        rec_x = self.decoder(z_t, h_next)
        return {
            "recon_x": rec_x,
            "post_mean": post_mean,
            "post_std": post_std,
            "prior_next_mean": prior_next_mean,
            "prior_next_std": prior_next_std,
            "z_t": z_t,
            "h_next": h_next,
        }