import torch
import torch.nn as nn
import torch.nn.functional as F
# Define an MLP encoder.
# inputs:  batch x num_genes (2446)
# outputs: posterior mean and std, each batch x latent_dim
class Encoder(nn.Module):
    def __init__(self, latent_dim, hidden_dims, num_genes=2446):
        super().__init__()
        layers = []
        prev_dim = num_genes
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.BatchNorm1d(h_dim))
            layers.append(nn.ELU())
            layers.append(nn.Dropout(0.4))
            prev_dim = h_dim
        self.enc_net = nn.Sequential(*layers)
        self.fc_mean = nn.Linear(prev_dim, latent_dim)
        self.fc_std = nn.Linear(prev_dim, latent_dim)

    def forward(self, x_t):
        h = self.enc_net(x_t)
        mean = self.fc_mean(h)
        # Enforce a minimum std to prevent posterior collapse:
        # the 1e-3 floor keeps std from shrinking to near-zero.
        std = F.softplus(self.fc_std(h)) + 1e-3
        return mean, std
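
# A minimal shape sketch (illustrative assumption, not part of the model): with a
# batch of 8 cells and the defaults used below,
#     enc = Encoder(latent_dim=20, hidden_dims=[128, 64, 32])
#     mean, std = enc(torch.randn(8, 2446))
# maps x: (8, 2446) -> mean: (8, 20), std: (8, 20).
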
# Define the corresponding MLP decoder.
# inputs:  batch x (latent_dim + rnn_hidden_dim)
# outputs: batch x num_genes (reconstructed expression)
class Decoder(nn.Module):
    def __init__(self, latent_dim, rnn_hidden_dim, hidden_dims, num_genes=2446):
        super().__init__()
        layers = []
        prev_dim = latent_dim + rnn_hidden_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            layers.append(nn.BatchNorm1d(h_dim))
            layers.append(nn.ELU())
            layers.append(nn.Dropout(0.4))
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, num_genes))
        self.dec_net = nn.Sequential(*layers)

    def forward(self, z, h):
        inps = torch.cat([z, h], dim=1)
        return self.dec_net(inps)
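
# Shape sketch (illustrative): the decoder consumes the latent sample concatenated
# with the RNN hidden state,
#     z: (8, latent_dim), h: (8, rnn_hidden_dim) -> recon: (8, 2446),
# mirroring the encoder's hidden widths in reverse ([32, 64, 128]).
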
# Define a GRU-based RSSM (recurrent state-space model).
# inputs:  latent sample (batch x latent_dim) and previous hidden state at step t
# outputs: next hidden state, plus batch x 2*latent_dim statistics for step t+1,
#          split into the prior mean and standard deviation
class RSSM(nn.Module):
    def __init__(self, latent_dim, rnn_hidden_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = rnn_hidden_dim
        self.gru = nn.GRUCell(latent_dim, rnn_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(rnn_hidden_dim, rnn_hidden_dim),
            nn.LayerNorm(rnn_hidden_dim),
            nn.ELU(),
            nn.Linear(rnn_hidden_dim, 2 * latent_dim)
        )
        # Small-gain Xavier/Glorot initialization on the output layer keeps the
        # initial prior statistics near zero (so std starts around
        # softplus(0) ~ 0.69), giving a stable, non-degenerate prior and
        # better gradient flow at the start of training.
        nn.init.xavier_uniform_(self.mlp[3].weight, gain=0.1)
        nn.init.zeros_(self.mlp[3].bias)

    def forward(self, prev_r, prev_h):
        h_t_1 = self.gru(prev_r, prev_h)
        prev_stats = self.mlp(h_t_1)
        prev_mean, prev_std = torch.chunk(prev_stats, 2, dim=1)
        prev_std = F.softplus(prev_std) + 1e-3
        return h_t_1, prev_mean, prev_std
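
# One-step rollout sketch (illustrative): given z_t (batch x latent_dim) and the
# previous hidden state (batch x rnn_hidden_dim), the RSSM returns h_{t+1} plus a
# prior mean and std, each batch x latent_dim. CellDreamer below interprets the
# mean as a velocity added to z_t.
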
# Create the joint training architecture for the Dreamer-style model.
class CellDreamer(nn.Module):
    def __init__(
        self,
        device,
        latent_dim=20,
        rnn_dim=64,
        enc_hidden_dims=[128, 64, 32],
        dec_hidden_dims=[32, 64, 128],
        num_genes=2446
    ):
        super().__init__()
        self.encoder = Encoder(latent_dim, enc_hidden_dims, num_genes)
        self.decoder = Decoder(latent_dim, rnn_dim, dec_hidden_dims, num_genes)
        self.rssm = RSSM(latent_dim, rnn_dim)
        self.rnn_dim = rnn_dim
        self.latent_dim = latent_dim
        self.input_dim = num_genes
        self.device = device

    def reparametrize(self, mean, std):
        # Reparameterization trick: z = mean + eps * std with eps ~ N(0, I).
        eps = torch.randn_like(std)
        return mean + eps * std
    def forward(self, x_t):
        # Encode the observation into a posterior over the latent state.
        post_mean, post_std = self.encoder(x_t)
        z_t = self.reparametrize(post_mean, post_std)
        h_prev = torch.zeros(x_t.size(0), self.rnn_dim, device=self.device)
        h_next, velocity_mean, velocity_std = self.rssm(z_t, h_prev)
        # The prior over the next latent is the current latent plus a predicted
        # velocity (a residual transition model).
        prior_next_mean = z_t + velocity_mean
        prior_next_std = velocity_std
        rec_x = self.decoder(z_t, h_next)
        return {
            "recon_x": rec_x,
            "post_mean": post_mean,
            "post_std": post_std,
            "prior_next_mean": prior_next_mean,
            "prior_next_std": prior_next_std,
            "z_t": z_t,
            "h_next": h_next
        }
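

# A minimal smoke test and training-loss sketch. This block is an illustration,
# not part of the original module: the MSE reconstruction term and the Gaussian
# KL(posterior at t+1 || prior at t+1) below are assumptions about how the
# returned statistics are meant to be used, following the usual Dreamer/VAE-style
# objective; x_next is a hypothetical next-timepoint batch.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CellDreamer(device).to(device)
    model.train()

    # Fake batch: 8 cells x 2446 genes (random stand-in data).
    x_t = torch.randn(8, 2446, device=device)
    x_next = torch.randn(8, 2446, device=device)

    out = model(x_t)
    print(out["recon_x"].shape, out["z_t"].shape, out["h_next"].shape)

    # Hypothetical loss: reconstruction + closed-form KL between two diagonal
    # Gaussians, KL(q||p) = log(s_p/s_q) + (s_q^2 + (m_q - m_p)^2)/(2 s_p^2) - 1/2.
    post_next_mean, post_next_std = model.encoder(x_next)
    recon_loss = F.mse_loss(out["recon_x"], x_t)
    kl = (
        torch.log(out["prior_next_std"] / post_next_std)
        + (post_next_std ** 2 + (post_next_mean - out["prior_next_mean"]) ** 2)
        / (2 * out["prior_next_std"] ** 2)
        - 0.5
    ).sum(dim=1).mean()
    loss = recon_loss + kl
    loss.backward()
    print(f"recon={recon_loss.item():.4f}  kl={kl.item():.4f}")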