# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian built from a tensor of concatenated mean / log-variance.

    Args:
        parameters: Tensor split in half along dim 1 by ``torch.chunk``; the
            first half is the mean, the second half the log-variance.
        deterministic: If True, ``std`` and ``var`` are zero, so ``sample()``
            returns the mean and ``kl()`` / ``nll()`` return a zero tensor.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Presumably clamped to keep exp() below in a numerically safe range.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the mean's device and dtype.
            self.var = self.std = torch.zeros_like(self.mean)

    def _zero(self):
        """Return the 1-element zero tensor used by kl()/nll() in deterministic mode."""
        # Same shape (1,) as before, but on the distribution's device/dtype so
        # it can be combined with the other loss terms without a device error.
        return torch.zeros(1, dtype=self.mean.dtype, device=self.mean.device)

    def sample(self):
        """Draw one reparameterised sample ``mean + std * eps`` with eps ~ N(0, I)."""
        x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
        return x

    def kl(self, other=None, reduction="sum"):
        """KL divergence to ``other``, or to a standard normal if ``other`` is None.

        Args:
            other: Another DiagonalGaussianDistribution, or None for N(0, I).
            reduction: "sum" or "mean" over the non-batch dims.

        Returns:
            Per-batch-element KL tensor (zero tensor of shape (1,) if deterministic).

        Raises:
            ValueError: If ``reduction`` is not "sum" or "mean".
        """
        if reduction == "sum":
            reduction_op = torch.sum
        elif reduction == "mean":
            reduction_op = torch.mean
        else:
            raise ValueError(f"Unsupported reduction {reduction!r}; expected 'sum' or 'mean'")
        # 4-D inputs reduce over (C, H, W); anything else over dims 1..4.
        dims = [1, 2, 3] if self.mean.ndim == 4 else [1, 2, 3, 4]
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * reduction_op(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims,
            )
        return 0.5 * reduction_op(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=dims,
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer over a learnable codebook.

    Args:
        n_e: Number of codebook entries.
        e_dim: Dimension of each code vector; must equal the input channel dim.
        beta: Weight applied to the commitment loss.
        entropy_loss_ratio: Weight applied to ``compute_entropy_loss``.
        l2_norm: If True, inputs and codebook are L2-normalised before matching.
        show_usage: If True, record recent code indices in the
            ``codebook_used`` buffer and report codebook usage while training.
    """

    # Number of most recent code indices tracked for the usage statistic.
    _USAGE_WINDOW = 65536

    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:
            # Plain (non-grad) buffer: it is overwritten in place every training
            # step, which fails (or accumulates autograd history) if it requires grad.
            self.register_buffer("codebook_used", torch.zeros(self._USAGE_WINDOW))

    @torch.no_grad()
    def _record_usage(self, indices):
        """Shift the newest code indices into the tail of the ``codebook_used`` window."""
        window = self.codebook_used.shape[0]
        # A batch larger than the window contributes only its most recent tail.
        recent = indices[-window:]
        n_new = recent.shape[0]
        if n_new == 0:
            return
        self.codebook_used[:-n_new] = self.codebook_used[n_new:].clone()
        self.codebook_used[-n_new:] = recent

    def forward(self, z):
        """Quantise ``z`` (B, C, H, W) with C == e_dim to its nearest code vectors.

        Returns:
            z_q: Quantised tensor in (B, C, H, W) layout, with a straight-through
                gradient to ``z``.
            (vq_loss, commit_loss, entropy_loss, codebook_usage): the losses are
                None outside training; codebook_usage is 0 unless
                ``show_usage`` and training.
            (perplexity, min_encodings, min_encoding_indices): the first two are
                always None; the indices have shape (B*H*W,).
        """
        # (B, C, H, W) -> (B, H, W, C), then one row per spatial position.
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        # Squared distances ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, shape (N, n_e).
        d = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(embedding ** 2, dim=1)
            - 2 * (z_flattened @ embedding.t())
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)

        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None
        codebook_usage = 0

        if self.show_usage and self.training:
            self._record_usage(min_encoding_indices)
            # The buffer starts as zeros, so index 0 counts as "used" until
            # the window has been filled with real indices.
            codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e

        if self.training:
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # Straight-through estimator: forward value is z_q, gradient flows to z.
        z_q = z + (z_q - z).detach()

        # Back to the input's (B, C, H, W) layout.
        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        """Look up code vectors for ``indices``, optionally reshaping the result.

        Args:
            indices: Integer tensor of codebook indices.
            shape: Target shape, (B, C, H, W) if ``channel_first`` else
                (B, H, W, C). If None, the looked-up vectors are returned as-is.
            channel_first: Selects the layout of ``shape`` described above.
        """
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]  # (b*h*w, c)

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                # reshape back to match original input shape
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q


def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regulariser: mean per-sample entropy minus batch-average entropy.

    Args:
        affinity: Scores of shape (..., n_codes); higher means more likely.
            The tensor is not modified.
        loss_type: Only "softmax" is supported.
        temperature: Affinities are divided by this before the softmax.

    Raises:
        ValueError: If ``loss_type`` is not supported.
    """
    if loss_type != "softmax":
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    # Out-of-place division: an in-place `/=` on the reshape view would
    # silently mutate the caller's tensor.
    flat_affinity = affinity.reshape(-1, affinity.shape[-1]) / temperature
    target_probs = F.softmax(flat_affinity, dim=-1)
    # log_softmax is shift-invariant, so no epsilon is needed here.
    log_probs = F.log_softmax(flat_affinity, dim=-1)

    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss