# CellDreamer-API / celldreamer/models/class_celldreamer.py
# (Hugging Face upload metadata: RobroKools, "Upload 44 files", commit e59f78e verified)
import torch
import torch.nn.functional as F
from celldreamer.models.networks import CellDreamer
class ClassCellDreamer:
    """Training wrapper around the CellDreamer model.

    Bundles the model, its Adam optimizer, and the composite VAE-style
    training loss: reconstruction + dynamics KL + posterior KL, with a
    free-bits floor on both KL terms and a linear KL warm-up schedule
    over the first half of training.
    """

    def __init__(self, args):
        # args must provide: device, latent_dim, rnn_dim, enc_hidden_dims,
        # dec_hidden_dims, num_genes, learning_rate, weight_decay, and
        # optionally kl_scale (defaults to 0.1).
        self.args = args
        self.device = args.device
        self.model = CellDreamer(
            device=torch.device(args.device),
            latent_dim=args.latent_dim,
            rnn_dim=args.rnn_dim,
            enc_hidden_dims=args.enc_hidden_dims,
            dec_hidden_dims=args.dec_hidden_dims,
            num_genes=args.num_genes,
        )
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
        # Global scale applied to both KL terms of the loss.
        self.kl_scale = getattr(args, 'kl_scale', 0.1)  # default 0.1

    def get_kl_loss(self, mean1, std1, mean2, std2):
        """Return KL(N(mean1, std1) || N(mean2, std2)).

        The divergence is computed element-wise, summed over the latent
        dimension (dim=1), and averaged over the batch. All arguments are
        assumed to be (batch, latent_dim) tensors with strictly positive
        std1/std2 — TODO confirm against the encoder's output contract.
        """
        dist1 = torch.distributions.Normal(mean1, std1)
        dist2 = torch.distributions.Normal(mean2, std2)
        return torch.distributions.kl_divergence(dist1, dist2).sum(dim=1).mean()

    def train_step(self, x_t, x_next, current_epoch, total_epochs):
        """Run one optimization step on a consecutive (x_t, x_next) pair.

        Returns a dict of scalar diagnostics: total loss, reconstruction
        loss, dynamics KL, posterior KL, and the effective KL weight.
        """
        self.model.train()
        self.optimizer.zero_grad()
        # Linear KL warm-up over the first half of training.
        # max(1, ...) guards against ZeroDivisionError when total_epochs < 2
        # (the original `total_epochs // 2` is 0 in that case).
        warmup_period = max(1, total_epochs // 2)
        kl_weight = min(1.0, current_epoch / warmup_period)
        effective_kl = self.kl_scale * kl_weight
        outputs = self.model(x_t)
        # Target posterior for the next time step; treated as a fixed
        # target, so no gradients flow through the encoder here.
        with torch.no_grad():
            target_mean, target_std = self.model.encoder(x_next)
        recon_loss = F.mse_loss(outputs["recon_x"], x_t)
        # Dynamics KL: KL(posterior(x_next) || prior_next)
        dynamics_loss = self.get_kl_loss(
            target_mean, target_std,
            outputs["prior_next_mean"], outputs["prior_next_std"]
        )
        # CRITICAL: Add posterior-prior KL to prevent posterior collapse
        # KL(posterior(x_t) || N(0,1)) - standard VAE regularization
        zeros = torch.zeros_like(outputs["post_mean"])
        ones = torch.ones_like(outputs["post_std"])
        posterior_kl = self.get_kl_loss(
            outputs["post_mean"], outputs["post_std"],
            zeros, ones
        )
        # Free bits: clamp each KL term to a minimum total so the
        # optimizer cannot drive it to zero — ensures the model keeps
        # using at least some latent information capacity.
        free_bits_per_dim = 0.1  # minimum nats per dimension
        min_kl = free_bits_per_dim * outputs["post_mean"].shape[1]
        posterior_kl = torch.clamp(posterior_kl, min=min_kl)
        dynamics_loss = torch.clamp(dynamics_loss, min=min_kl)
        # Total Loss: reconstruction + dynamics KL + posterior regularization
        total_loss = recon_loss + (effective_kl * dynamics_loss) + (effective_kl * posterior_kl)
        total_loss.backward()
        # Gradient clipping for training stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        return {
            "loss": total_loss.item(),
            "recon_loss": recon_loss.item(),
            "dynamics_loss": dynamics_loss.item(),
            "posterior_kl": posterior_kl.item(),
            "kl_weight": effective_kl,
        }

    def save(self, path):
        """Serialize the model's state dict to `path`."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a state dict from `path`, mapping tensors onto self.device."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))