File size: 3,474 Bytes
e59f78e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn.functional as F

from celldreamer.models.networks import CellDreamer

class ClassCellDreamer:
    """Training wrapper around a ``CellDreamer`` model.

    Owns the model, its Adam optimizer and the loss weighting used by
    :meth:`train_step`.  ``args`` must provide ``device``, ``latent_dim``,
    ``rnn_dim``, ``enc_hidden_dims``, ``dec_hidden_dims``, ``num_genes``,
    ``learning_rate`` and ``weight_decay``; ``kl_scale`` (default 0.1) and
    ``free_bits_per_dim`` (default 0.1) are optional.
    """

    def __init__(self, args):
        self.args = args
        self.device = args.device

        self.model = CellDreamer(
            device=torch.device(args.device),
            latent_dim=args.latent_dim,
            rnn_dim=args.rnn_dim,
            enc_hidden_dims=args.enc_hidden_dims,
            dec_hidden_dims=args.dec_hidden_dims,
            num_genes=args.num_genes,
        )
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
        # KL multiplier reached once the warm-up in train_step has finished.
        self.kl_scale = getattr(args, "kl_scale", 0.1)
        # Free-bits threshold in nats per latent dimension (was hard-coded
        # to 0.1 inside train_step; still defaults to that value).
        self.free_bits_per_dim = getattr(args, "free_bits_per_dim", 0.1)

    def get_kl_loss(self, mean1, std1, mean2, std2):
        """Return KL(N(mean1, std1) || N(mean2, std2)) as a scalar tensor.

        The element-wise KL is summed over dimension 1 and averaged over the
        remaining entries.  Assumes (batch, latent_dim) inputs — TODO confirm.
        """
        dist1 = torch.distributions.Normal(mean1, std1)
        dist2 = torch.distributions.Normal(mean2, std2)
        return torch.distributions.kl_divergence(dist1, dist2).sum(dim=1).mean()

    def train_step(self, x_t, x_next, current_epoch, total_epochs):
        """Run one optimisation step on a pair of consecutive observations.

        Args:
            x_t: Input at the current step; also the reconstruction target.
            x_next: Input at the next step; it is only encoded (under
                ``torch.no_grad``) to build the target of the dynamics KL.
            current_epoch: Epoch index driving the KL warm-up.
            total_epochs: Total number of epochs; the KL weight ramps
                linearly from 0 to ``kl_scale`` over the first half of them.

        Returns:
            dict with ``loss`` (total loss), ``recon_loss``, the clamped
            ``dynamics_loss`` / ``posterior_kl`` terms that entered the loss,
            their unclamped counterparts ``dynamics_loss_raw`` /
            ``posterior_kl_raw`` (for monitoring collapse), and
            ``kl_weight`` (the effective KL multiplier, a float).
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Linear KL warm-up over the first half of training.  max(1, ...)
        # prevents a ZeroDivisionError when total_epochs < 2.
        warmup_period = max(1, total_epochs // 2)
        kl_weight = min(1.0, current_epoch / warmup_period)
        effective_kl = self.kl_scale * kl_weight

        outputs = self.model(x_t)
        # The posterior of x_next is only a target: no gradient flows into
        # the encoder through it.
        with torch.no_grad():
            target_mean, target_std = self.model.encoder(x_next)

        recon_loss = F.mse_loss(outputs["recon_x"], x_t)

        # Dynamics KL: KL(posterior(x_next) || predicted prior for next step).
        dynamics_raw = self.get_kl_loss(
            target_mean, target_std,
            outputs["prior_next_mean"], outputs["prior_next_std"],
        )

        # Posterior regularisation: KL(posterior(x_t) || N(0, I)).
        posterior_raw = self.get_kl_loss(
            outputs["post_mean"], outputs["post_std"],
            torch.zeros_like(outputs["post_mean"]),
            torch.ones_like(outputs["post_std"]),
        )

        # Free bits: clamp the batch-averaged *total* KL (not each dimension)
        # from below at free_bits_per_dim * latent_dim.  Below the threshold
        # the clamped term is constant (zero gradient), so the KL is no
        # longer pushed towards zero; it is not forced upwards either.
        min_kl = self.free_bits_per_dim * outputs["post_mean"].shape[1]
        posterior_kl = torch.clamp(posterior_raw, min=min_kl)
        dynamics_loss = torch.clamp(dynamics_raw, min=min_kl)

        # Total loss: reconstruction + weighted dynamics KL + weighted
        # posterior regularisation.
        total_loss = recon_loss + (effective_kl * dynamics_loss) + (effective_kl * posterior_kl)

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {
            "loss": total_loss.item(),
            "recon_loss": recon_loss.item(),
            "dynamics_loss": dynamics_loss.item(),
            "posterior_kl": posterior_kl.item(),
            "dynamics_loss_raw": dynamics_raw.item(),
            "posterior_kl_raw": posterior_raw.item(),
            "kl_weight": effective_kl,
        }

    def save(self, path):
        """Write the model's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` from ``path`` onto ``self.device``.

        NOTE: ``torch.load`` unpickles the file — only load trusted checkpoints.
        """
        self.model.load_state_dict(torch.load(path, map_location=self.device))