Spaces:
Sleeping
Sleeping
import torch
import torch.nn.functional as F
from celldreamer.models.networks import CellDreamer
class ClassCellDreamer:
    """Training wrapper around the CellDreamer VAE-style dynamics model.

    Owns the model and its Adam optimizer, and implements the per-step loss:
    reconstruction MSE + dynamics KL + posterior KL, with free-bits clamping
    and a linear KL warmup over the first half of training.
    """

    def __init__(self, args):
        # args is expected to carry: device, latent_dim, rnn_dim,
        # enc_hidden_dims, dec_hidden_dims, num_genes, learning_rate,
        # weight_decay, and optionally kl_scale.
        self.args = args
        self.device = args.device
        self.model = CellDreamer(
            device=torch.device(args.device),
            latent_dim=args.latent_dim,
            rnn_dim=args.rnn_dim,
            enc_hidden_dims=args.enc_hidden_dims,
            dec_hidden_dims=args.dec_hidden_dims,
            num_genes=args.num_genes,
        )
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
        # Overall scale applied to both KL terms; defaults to 0.1 when the
        # args object does not define kl_scale.
        self.kl_scale = getattr(args, 'kl_scale', 0.1)

    def get_kl_loss(self, mean1, std1, mean2, std2):
        """Return KL(N(mean1, std1) || N(mean2, std2)).

        Summed over the feature dimension (dim=1), then averaged over the
        batch. Assumes 2-D (batch, features) inputs — TODO confirm callers
        never pass higher-rank tensors.
        """
        dist1 = torch.distributions.Normal(mean1, std1)
        dist2 = torch.distributions.Normal(mean2, std2)
        return torch.distributions.kl_divergence(dist1, dist2).sum(dim=1).mean()

    def train_step(self, x_t, x_next, current_epoch, total_epochs):
        """Run one optimization step on the pair (x_t, x_next).

        Returns a dict of scalar metrics: total loss, reconstruction loss,
        dynamics KL, posterior KL, and the effective KL weight used.
        """
        self.model.train()
        self.optimizer.zero_grad()
        # Linear KL warmup over the first half of training.
        # BUGFIX: guard against ZeroDivisionError — total_epochs // 2 is 0
        # whenever total_epochs < 2.
        warmup_period = max(1, total_epochs // 2)
        kl_weight = min(1.0, current_epoch / warmup_period)
        effective_kl = self.kl_scale * kl_weight
        outputs = self.model(x_t)
        # Encode the next observation without tracking gradients: the target
        # distribution is treated as a fixed target for the dynamics KL.
        with torch.no_grad():
            target_mean, target_std = self.model.encoder(x_next)
        recon_loss = F.mse_loss(outputs["recon_x"], x_t)
        # Dynamics KL: KL(posterior(x_next) || prior_next)
        dynamics_loss = self.get_kl_loss(
            target_mean, target_std,
            outputs["prior_next_mean"], outputs["prior_next_std"]
        )
        # Posterior regularization to prevent posterior collapse:
        # KL(posterior(x_t) || N(0, 1)) — standard VAE term.
        zeros = torch.zeros_like(outputs["post_mean"])
        ones = torch.ones_like(outputs["post_std"])
        posterior_kl = self.get_kl_loss(
            outputs["post_mean"], outputs["post_std"],
            zeros, ones
        )
        # Free bits: enforce a minimum KL so the model keeps some information
        # capacity. NOTE(review): the clamp is applied to the batch-mean
        # total KL, not per dimension — this matches the original behavior,
        # though canonical free bits clamps per-dimension before summing.
        free_bits_per_dim = 0.1  # minimum nats per latent dimension
        min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
        posterior_kl = torch.clamp(posterior_kl, min=min_kl)
        dynamics_loss = torch.clamp(dynamics_loss, min=min_kl)
        # Total loss: reconstruction + warmed-up dynamics KL + posterior KL.
        total_loss = recon_loss + (effective_kl * dynamics_loss) + (effective_kl * posterior_kl)
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        return {
            "loss": total_loss.item(),
            "recon_loss": recon_loss.item(),
            "dynamics_loss": dynamics_loss.item(),
            "posterior_kl": posterior_kl.item(),
            "kl_weight": effective_kl,
        }

    def save(self, path):
        """Serialize the model's state_dict to `path`."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load model weights from `path` onto the configured device."""
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints (weights_only
        # requires torch >= 1.13, so it is not forced here).
        self.model.load_state_dict(torch.load(path, map_location=self.device))