|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by a single packed tensor.

    ``parameters`` packs mean and log-variance along dim=1: the first half
    of the channels is the mean, the second half the log-variance.
    With ``deterministic=True`` the distribution collapses onto its mean
    (std/var are zero), so ``sample()`` returns the mean and ``kl()`` /
    ``nll()`` return zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp for numerical stability: exp(logvar) must not over/underflow.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: zero spread around the mean.
            # zeros_like already inherits device/dtype from self.mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one reparameterized sample: mean + std * eps, eps ~ N(0, I)."""
        noise = torch.randn(self.mean.shape, device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None, reduction="sum"):
        """KL divergence to ``other`` (or to the standard normal when None).

        Reduced per sample over all non-batch dimensions with ``reduction``
        ("sum" or "mean"); returns a tensor of shape (batch,).

        Raises:
            ValueError: if ``reduction`` is neither "sum" nor "mean".
        """
        if reduction == "sum":
            reduction_op = torch.sum
        elif reduction == "mean":
            reduction_op = torch.mean
        else:
            # Previously an unknown reduction fell through to a NameError.
            raise ValueError(f"unsupported reduction: {reduction!r}")
        # Reduce over every dimension except the batch dimension. This
        # generalizes the old hard-coded [1,2,3] / [1,2,3,4] choice and is
        # identical for 4-D and 5-D means.
        dims = list(range(1, self.mean.ndim))
        if self.deterministic:
            # Zero KL, created on the same device as the distribution so
            # downstream arithmetic does not hit a device mismatch.
            return torch.zeros(1, device=self.mean.device)
        if other is None:
            # KL( N(mean, var) || N(0, I) )
            return 0.5 * reduction_op(torch.pow(self.mean, 2)
                                      + self.var - 1.0 - self.logvar,
                                      dim=dims)
        # KL( N(mean, var) || N(other.mean, other.var) )
        return 0.5 * reduction_op(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=dims)

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return torch.zeros(1, device=self.mean.device)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Most likely value — the mean for a Gaussian."""
        return self.mean
|
|
|
|
|
|
|
|
|
|
|
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Locate any tensor argument; scalar args are promoted onto its
    # device/dtype below.
    tensor = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2)
         if isinstance(arg, torch.Tensor)),
        None,
    )
    assert tensor is not None, "at least one argument must be a Tensor"

    # Only the log-variances are forced to tensors here: the mean terms are
    # always combined with a tensor expression below, so plain-float means
    # broadcast fine as-is.
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(tensor)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(tensor)

    # KL(N1 || N2) accumulated term by term.
    kl = -1.0 + logvar2 - logvar1
    kl = kl + torch.exp(logvar1 - logvar2)
    kl = kl + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    return 0.5 * kl
|
|
|
|
|
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer (VQ-VAE style).

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each codebook entry.
        beta: weight of the commitment loss.
        entropy_loss_ratio: weight of the entropy regularization loss.
        l2_norm: if True, compare and quantize on the unit sphere.
        show_usage: if True, keep a sliding window of recently used code
            indices in the ``codebook_used`` buffer to report usage.
    """

    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = l2_norm
        self.show_usage = show_usage

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
        if self.show_usage:
            # Buffers must be plain tensors; previously this wrapped the
            # tensor in nn.Parameter, creating a non-trainable Parameter
            # masquerading as a buffer.
            self.register_buffer("codebook_used", torch.zeros(65536))

    def forward(self, z):
        """Quantize ``z`` (b, c, h, w) to its nearest codebook entries.

        Returns:
            z_q: quantized tensor, same shape as ``z``, with a
                straight-through gradient to the encoder.
            (vq_loss, commit_loss, entropy_loss, codebook_usage): losses are
                None outside training; usage is 0 unless tracked.
            (perplexity, min_encodings, min_encoding_indices): the first two
                are always None here; indices are the flat code ids.
        """
        # (b, c, h, w) -> (b, h, w, c), then flatten to (b*h*w, c).
        z = torch.einsum('b c h w -> b h w c', z).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        if self.l2_norm:
            z = F.normalize(z, p=2, dim=-1)
            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight

        # Squared euclidean distances: ||z||^2 + ||e||^2 - 2 z.e
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = embedding[min_encoding_indices].view(z.shape)
        perplexity = None
        min_encodings = None
        vq_loss = None
        commit_loss = None
        entropy_loss = None
        codebook_usage = 0

        if self.show_usage and self.training:
            # Sliding window of the most recent code indices. Clamp cur_len
            # so a batch larger than the buffer keeps only the newest
            # indices instead of raising a shape-mismatch error.
            cur_len = min(min_encoding_indices.shape[0], self.codebook_used.shape[0])
            self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
            self.codebook_used[-cur_len:] = min_encoding_indices[-cur_len:]
            codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e

        if self.training:
            # Codebook loss pulls entries toward encodings; commitment loss
            # (scaled by beta) pulls encodings toward their entries.
            vq_loss = torch.mean((z_q - z.detach()) ** 2)
            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

        # Straight-through estimator: forward value is z_q, gradient flows
        # to z as if quantization were the identity.
        z_q = z + (z_q - z).detach()

        # (b, h, w, c) -> (b, c, h, w)
        z_q = torch.einsum('b h w c -> b c h w', z_q)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape=None, channel_first=True):
        """Look up codebook vectors for flat ``indices``.

        Args:
            indices: LongTensor of code ids.
            shape: optional target shape; interpreted as (b, c, h, w) when
                ``channel_first`` is True, otherwise used directly by view.
            channel_first: whether to return (b, c, h, w) layout.
        """
        if self.l2_norm:
            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
        else:
            embedding = self.embedding.weight
        z_q = embedding[indices]

        if shape is not None:
            if channel_first:
                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
                z_q = z_q.permute(0, 3, 1, 2).contiguous()
            else:
                z_q = z_q.view(shape)
        return z_q
|
|
|
|
|
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regularization for codebook usage.

    Encourages confident per-sample assignments (low sample entropy) while
    spreading assignments across the codebook (high average entropy).

    Args:
        affinity: (..., n_codes) similarity logits (e.g. negative distances).
        loss_type: only "softmax" is supported.
        temperature: softmax temperature applied to the logits.

    Returns:
        Scalar tensor: sample_entropy - avg_entropy.

    Raises:
        ValueError: for an unsupported ``loss_type``.
    """
    # Out-of-place division: the previous in-place `/=` mutated the caller's
    # tensor (reshape can return a view) and raises on views of leaf tensors
    # that require grad.
    flat_affinity = affinity.reshape(-1, affinity.shape[-1]) / temperature
    probs = F.softmax(flat_affinity, dim=-1)
    # NOTE(review): the +1e-5 shifts all logits equally before log_softmax,
    # which leaves the result unchanged; kept for parity with the reference.
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss