import torch
import torch.nn as nn
import torch.nn.functional as F


# MLP encoder
# input:  batch x num_genes (2446)
# output: mean and std of the posterior, each batch x latent_dim
class Encoder(nn.Module):
    def __init__(self, latent_dim, hidden_dims, num_genes=2446):
        super().__init__()
        layers = []
        prev_dim = num_genes
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.BatchNorm1d(h_dim))
            layers.append(nn.ELU())
            layers.append(nn.Dropout(0.4))
            prev_dim = h_dim
        self.enc_net = nn.Sequential(*layers)
        self.fc_mean = nn.Linear(prev_dim, latent_dim)
        self.fc_std = nn.Linear(prev_dim, latent_dim)

    def forward(self, x_t):
        h = self.enc_net(x_t)
        mean = self.fc_mean(h)
        # Floor the std at 1e-3 so it cannot collapse to near-zero,
        # which guards against posterior collapse
        std = F.softplus(self.fc_std(h)) + 1e-3
        return mean, std


# Corresponding MLP decoder
# input:  batch x (latent_dim + rnn_hidden_dim)
# output: batch x num_genes
class Decoder(nn.Module):
    def __init__(self, latent_dim, rnn_hidden_dim, hidden_dims, num_genes=2446):
        super().__init__()
        layers = []
        prev_dim = latent_dim + rnn_hidden_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.BatchNorm1d(h_dim))
            layers.append(nn.ELU())
            layers.append(nn.Dropout(0.4))
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, num_genes))
        self.dec_net = nn.Sequential(*layers)

    def forward(self, z, h):
        inps = torch.cat([z, h], dim=1)
        return self.dec_net(inps)


# GRU-based RSSM (recurrent state-space model)
# input:  batch x latent_dim at step t
# output: next hidden state, plus batch x 2*latent_dim split into
#         the mean and std of the prior over the next latent
class RSSM(nn.Module):
    def __init__(self, latent_dim, rnn_hidden_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = rnn_hidden_dim
        self.gru = nn.GRUCell(latent_dim, rnn_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim),
            nn.LayerNorm(rnn_hidden_dim),
            nn.ELU(),
            nn.Linear(rnn_hidden_dim, 2 * latent_dim),
        )
        # Small-gain Xavier init keeps the output layer near zero at the
        # start of training: the velocity mean starts near zero and the
        # std near softplus(0) ~ 0.69, a broad (not weak) initial prior,
        # while Xavier scaling preserves gradient flow
        nn.init.xavier_uniform_(self.mlp[3].weight, gain=0.1)
        nn.init.zeros_(self.mlp[3].bias)

    def forward(self, prev_r, prev_h):
        h_next = self.gru(prev_r, prev_h)
        stats = self.mlp(h_next)
        mean, std = torch.chunk(stats, 2, dim=1)
        # Same std floor as the encoder
        std = F.softplus(std) + 1e-3
        return h_next, mean, std


# Joint training architecture (Dreamer-style)
class CellDreamer(nn.Module):
    def __init__(
        self,
        device,
        latent_dim=20,
        rnn_dim=64,
        enc_hidden_dims=(128, 64, 32),
        dec_hidden_dims=(32, 64, 128),
        num_genes=2446,
    ):
        super().__init__()
        self.encoder = Encoder(latent_dim, enc_hidden_dims, num_genes)
        self.decoder = Decoder(latent_dim, rnn_dim, dec_hidden_dims, num_genes)
        self.rssm = RSSM(latent_dim, rnn_dim)
        self.rnn_dim = rnn_dim
        self.latent_dim = latent_dim
        self.input_dim = num_genes
        self.device = device

    def reparametrize(self, mean, std):
        # z = mean + eps * std, with eps ~ N(0, I)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x_t):
        # Posterior over the current latent state
        post_mean, post_std = self.encoder(x_t)
        z_t = self.reparametrize(post_mean, post_std)

        # Roll the RSSM one step forward from a zero hidden state; the
        # MLP head predicts a latent velocity, so the prior mean for the
        # next step is z_t plus that velocity
        h_prev = torch.zeros(x_t.size(0), self.rnn_dim, device=self.device)
        h_next, velocity_mean, velocity_std = self.rssm(z_t, h_prev)
        prior_next_mean = z_t + velocity_mean
        prior_next_std = velocity_std

        # Reconstruct the current observation from (z_t, h_next)
        rec_x = self.decoder(z_t, h_next)
        return {
            "recon_x": rec_x,
            "post_mean": post_mean,
            "post_std": post_std,
            "prior_next_mean": prior_next_mean,
            "prior_next_std": prior_next_std,
            "z_t": z_t,
            "h_next": h_next,
        }
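

# ---------------------------------------------------------------------
# Usage sketch (not part of the model definition): a minimal smoke test
# plus one hypothetical training step. The module itself defines no
# loss, so the objective below (MSE reconstruction of x_t, plus the KL
# between the next-step posterior q(z_{t+1} | x_{t+1}) and the RSSM
# prior returned by forward()) is an assumption about how this model
# would be trained; the optimizer, learning rate, batch size, and beta
# weight are all illustrative.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CellDreamer(device).to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Fake consecutive expression profiles standing in for x_t, x_{t+1}
    batch = 16
    x_t = torch.randn(batch, model.input_dim, device=device)
    x_next = torch.randn(batch, model.input_dim, device=device)

    out = model(x_t)
    assert out["recon_x"].shape == (batch, model.input_dim)

    optimizer.zero_grad()

    # Reconstruction of the current observation
    recon_loss = F.mse_loss(out["recon_x"], x_t)

    # KL( q(z_{t+1} | x_{t+1}) || p(z_{t+1} | z_t, h_t) ) in closed
    # form for two diagonal Gaussians via torch.distributions
    post_next = torch.distributions.Normal(*model.encoder(x_next))
    prior_next = torch.distributions.Normal(
        out["prior_next_mean"], out["prior_next_std"]
    )
    kl = torch.distributions.kl_divergence(post_next, prior_next)
    kl_loss = kl.sum(dim=1).mean()

    beta = 1.0  # hypothetical KL weight
    loss = recon_loss + beta * kl_loss
    loss.backward()
    optimizer.step()
    print(f"recon={recon_loss.item():.4f} kl={kl_loss.item():.4f}")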