| | import itertools |
| | import math |
| | import torch |
| | import torch.nn.functional as F |
| | import pytorch_lightning as L |
| | import torchmetrics |
| | from dataclasses import dataclass |
| | from models import dit, ema |
| | import noise_schedule |
| |
|
# Natural logarithm of 2; dividing a quantity in nats by this yields bits.
LOG2 = math.log(2.0)
| |
|
@dataclass
class Loss:
    """Bundle of tensors describing one loss evaluation.

    The producer of these values is not visible in this part of the file;
    the per-field notes below are inferred from the field names only.
    """

    # Loss value (presumably the scalar that is back-propagated -- confirm).
    loss: torch.FloatTensor
    # Negative log-likelihood values (presumably per token -- confirm).
    nlls: torch.FloatTensor
    # Mask over tokens (presumably selects tokens that count -- confirm).
    token_mask: torch.FloatTensor
| |
|
class NLL(torchmetrics.MeanMetric):
    """Mean metric used for negative log-likelihood values.

    Adds nothing to ``torchmetrics.MeanMetric``; it exists as a named base
    for the metrics derived from it in this module.
    """
| |
|
class BPD(NLL):
    """Bits-per-dimension view of the accumulated mean."""

    def compute(self) -> torch.Tensor:
        """Return the accumulated mean divided by ln(2).

        Returns:
            ``mean_value / weight / LOG2`` -- the running mean converted
            from nats to bits.
        """
        mean_nats = self.mean_value / self.weight
        return mean_nats / LOG2
| |
|
class Perplexity(NLL):
    """Perplexity view of the accumulated mean."""

    def compute(self) -> torch.Tensor:
        """Return the exponential of the accumulated mean.

        Returns:
            ``exp(mean_value / weight)``.
        """
        mean_nll = self.mean_value / self.weight
        return mean_nll.exp()
| |
|
class Diffusion(L.LightningModule):
    """Lightning module that trains a ``dit.DIT`` backbone to denoise latents.

    Training adds Gaussian noise of a randomly drawn scale to each sample and
    regresses the backbone output onto the clean input with an MSE loss.
    """

    def __init__(self, config, latent_dim):
        """Build the backbone, metric collections and hyper-parameters.

        Args:
            config: Configuration object. The constructor reads ``config.T``,
                ``config.subs_masking``, ``config.optim["lr"]``,
                ``config.training.get("sampling_eps", 1e-5)`` and
                ``config.get("time_conditioning", True)``; it is also passed
                to ``dit.DIT`` and ``noise_schedule.get_noise``.
            latent_dim: Forwarded to ``dit.DIT`` as ``vocab_size``.
        """
        super().__init__()
        self.config = config
        self.latent_dim = latent_dim

        self.backbone = dit.DIT(config, vocab_size=self.latent_dim)
        self.T = self.config.T
        self.subs_masking = self.config.subs_masking

        self.softplus = torch.nn.Softplus()
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        # Metric states are kept in float64.
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
        self.lr = self.config.optim["lr"]
        self.sampling_eps = self.config.training.get("sampling_eps", 1e-5)
        self.time_conditioning = self.config.get("time_conditioning", True)
        # NOTE(review): presumably a finite stand-in for -inf; it is not used
        # anywhere in the visible code.
        self.neg_infinity = -1000000.0

    def forward(self, latents, sigma):
        """Forward diffusion process, adds Gaussian noise to the latents.

        Args:
            latents: Tensor to corrupt; the first dimension is treated as the
                batch dimension when ``sigma`` is per-sample.
            sigma: Noise scale. A 1-D tensor whose length equals
                ``latents.size(0)`` is interpreted as one scale per sample
                and reshaped to ``(batch, 1, ..., 1)``. Any other value
                (Python scalar, 0-d tensor, already-broadcastable tensor) is
                multiplied as given.

        Returns:
            ``latents + sigma * eps`` with ``eps`` drawn from a standard
            normal of the same shape as ``latents``.
        """
        # A (batch,) sigma would otherwise broadcast against the LAST
        # dimension of ``latents``: a shape error in general, or silent
        # per-feature scaling when that dimension equals the batch size.
        if (
            torch.is_tensor(sigma)
            and sigma.dim() == 1
            and latents.dim() > 1
            and sigma.size(0) == latents.size(0)
        ):
            sigma = sigma.reshape(-1, *([1] * (latents.dim() - 1)))
        noise = sigma * torch.randn_like(latents)
        noisy_latents = latents + noise
        return noisy_latents

    def reverse_diffusion(self, noisy_latents, sigma):
        """Reverse diffusion process, denoises the latents.

        Args:
            noisy_latents: Corrupted latents, passed to the backbone as-is.
            sigma: Noise scale, passed to the backbone as-is.

        Returns:
            The output of ``self.backbone(noisy_latents, sigma)``.
        """
        denoised_latents = self.backbone(noisy_latents, sigma)
        return denoised_latents

    def training_step(self, batch, batch_idx):
        """Run one denoising training step and return the MSE loss.

        Args:
            batch: Tensor of clean latents; ``batch.size(0)`` is the batch
                size.
            batch_idx: Index of the batch (unused).

        Returns:
            Mean-squared error between the backbone output and ``batch``.
        """
        # One noise scale per sample, drawn uniformly from [0, 1).
        sigma = torch.rand(batch.size(0), device=self.device)
        noisy_latents = self.forward(batch, sigma)
        denoised_latents = self.reverse_diffusion(noisy_latents, sigma)
        loss = F.mse_loss(denoised_latents, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
| |
|