import torch
import torch.nn.functional as F

from celldreamer.models.networks import CellDreamer


class ClassCellDreamer:
    """Training wrapper around the CellDreamer latent-dynamics VAE.

    Owns the model, an Adam optimizer, and a KL-annealed loss used in
    ``train_step``. The training loss combines:

    * MSE reconstruction of the current state ``x_t``;
    * a "dynamics" KL between the encoder posterior of ``x_next`` and the
      model's predicted next-step prior;
    * a standard-normal KL on the posterior of ``x_t`` (classic VAE
      regularization) to prevent posterior collapse, with a free-bits floor
      applied to both KL terms.
    """

    def __init__(self, args):
        """Build the model and optimizer from an ``args`` namespace.

        ``args`` must provide: device, latent_dim, rnn_dim, enc_hidden_dims,
        dec_hidden_dims, num_genes, learning_rate, weight_decay, and
        optionally kl_scale (defaults to 0.1).
        """
        self.args = args
        self.device = args.device
        self.model = CellDreamer(
            device=torch.device(args.device),
            latent_dim=args.latent_dim,
            rnn_dim=args.rnn_dim,
            enc_hidden_dims=args.enc_hidden_dims,
            dec_hidden_dims=args.dec_hidden_dims,
            num_genes=args.num_genes,
        )
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
        # Maximum weight applied to the KL terms once warm-up completes.
        self.kl_scale = getattr(args, 'kl_scale', 0.1)  # default 0.1

    def get_kl_loss(self, mean1, std1, mean2, std2):
        """Return KL(N(mean1, std1) || N(mean2, std2)).

        The divergence is summed over latent dimensions (dim=1) and averaged
        over the batch. ``std1``/``std2`` must be strictly positive, or
        ``torch.distributions.Normal`` will reject them.
        """
        dist1 = torch.distributions.Normal(mean1, std1)
        dist2 = torch.distributions.Normal(mean2, std2)
        return torch.distributions.kl_divergence(dist1, dist2).sum(dim=1).mean()

    def train_step(self, x_t, x_next, current_epoch, total_epochs):
        """Run one optimization step on a ``(x_t, x_next)`` transition pair.

        Parameters
        ----------
        x_t, x_next : torch.Tensor
            Current and next observations (batch-first; exact shape is
            whatever ``CellDreamer`` expects — not visible from this file).
        current_epoch, total_epochs : int
            Used for linear KL warm-up over the first half of training.

        Returns
        -------
        dict
            Scalar diagnostics: total/recon/dynamics losses, posterior KL,
            and the effective KL weight used this step.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Linear KL warm-up over the first half of training.
        # BUGFIX: guard against total_epochs < 2, which made
        # warmup_period == 0 and raised ZeroDivisionError below.
        warmup_period = max(1, total_epochs // 2)
        kl_weight = min(1.0, (current_epoch / warmup_period))
        effective_kl = self.kl_scale * kl_weight

        outputs = self.model(x_t)
        # Target posterior for the next state; detached so no gradients flow
        # into the encoder through the dynamics target.
        with torch.no_grad():
            target_mean, target_std = self.model.encoder(x_next)

        recon_loss = F.mse_loss(outputs["recon_x"], x_t)

        # Dynamics KL: KL(posterior(x_next) || prior_next)
        dynamics_loss = self.get_kl_loss(
            target_mean, target_std,
            outputs["prior_next_mean"], outputs["prior_next_std"],
        )

        # CRITICAL: Add posterior-prior KL to prevent posterior collapse.
        # KL(posterior(x_t) || N(0,1)) - standard VAE regularization.
        zeros = torch.zeros_like(outputs["post_mean"])
        ones = torch.ones_like(outputs["post_std"])
        posterior_kl = self.get_kl_loss(
            outputs["post_mean"], outputs["post_std"], zeros, ones
        )

        # Free bits: ensure a minimum total KL (free_bits_per_dim nats per
        # latent dimension) so the model keeps some information capacity.
        # Clamping zeroes the KL gradient whenever it sits below the floor.
        free_bits_per_dim = 0.1  # minimum nats per dimension
        min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
        posterior_kl = torch.clamp(posterior_kl, min=min_kl)
        dynamics_loss = torch.clamp(dynamics_loss, min=min_kl)

        # Total loss: reconstruction + dynamics KL + posterior regularization.
        total_loss = recon_loss + (effective_kl * dynamics_loss) + (effective_kl * posterior_kl)

        total_loss.backward()
        # Clip gradients to stabilize training of the recurrent dynamics.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {
            "loss": total_loss.item(),
            "recon_loss": recon_loss.item(),
            "dynamics_loss": dynamics_loss.item(),
            "posterior_kl": posterior_kl.item(),
            "kl_weight": effective_kl,
        }

    def save(self, path):
        """Serialize the model's weights (state_dict) to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load model weights from ``path``, remapping onto ``self.device``.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        self.model.load_state_dict(torch.load(path, map_location=self.device))