ProtoBind-Diff / protobind_diff /noise_schedule.py
vladimir.manuylov
initial commit
bd082dc
import abc
import torch
import torch.nn as nn
# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
def _sample_categorical(categorical_probs):
gumbel_norm = (
1e-10
- (torch.rand_like(categorical_probs) + 1e-10).log())
return (categorical_probs / gumbel_norm).argmax(dim=-1)
def _unsqueeze(x, reference):
return x.view(
* x.shape,
* ((1,) * (len(reference.shape) - len(x.shape))))
def _sample_t(n, device, antithetic_sampling=True, sampling_eps=1e-3):
_eps_t = torch.rand(n, device=device)
if antithetic_sampling:
offset = torch.arange(n, device=device) / n
_eps_t = (_eps_t / n + offset) % 1
t = (1 - sampling_eps) * _eps_t + sampling_eps
return t
def q_xt( x, move_chance, mask_index):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
move_chance: float torch.Tensor with shape (batch_size, 1).
"""
move_indices = torch.rand(
* x.shape, device=x.device) < move_chance
xt = torch.where(move_indices, mask_index, x)
return xt
def get_noise(config, dtype=torch.float32):
if config.noise.type == 'geometric':
return GeometricNoise(config.noise.sigma_min,
config.noise.sigma_max)
elif config.noise.type == 'loglinear':
return LogLinearNoise()
elif config.noise.type == 'cosine':
return CosineNoise()
elif config.noise.type == 'cosinesqr':
return CosineSqrNoise()
elif config.noise.type == 'linear':
return Linear(config.noise.sigma_min,
config.noise.sigma_max,
dtype)
else:
raise ValueError(f'{config.noise.type} is not a valid noise')
def binary_discretization(z):
z_hard = torch.sign(z)
z_soft = z / torch.norm(z, dim=-1, keepdim=True)
return z_soft + (z_hard - z_soft).detach()
class Noise(abc.ABC, nn.Module):
"""
Baseline forward method to get the total + rate of noise at a timestep
"""
def forward(self, t):
# Assume time goes from 0 to 1
return self.total_noise(t), self.rate_noise(t)
@abc.abstractmethod
def rate_noise(self, t):
"""
Rate of change of noise ie g(t)
"""
pass
@abc.abstractmethod
def total_noise(self, t):
"""
Total noise ie \int_0^t g(t) dt + g(0)
"""
pass
class CosineNoise(Noise):
def __init__(self, eps=1e-3):
super().__init__()
self.eps = eps
def rate_noise(self, t):
cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
scale = torch.pi / 2
return scale * sin / (cos + self.eps)
def total_noise(self, t):
cos = torch.cos(t * torch.pi / 2)
return - torch.log(self.eps + (1 - self.eps) * cos)
class CosineSqrNoise(Noise):
def __init__(self, eps=1e-3):
super().__init__()
self.eps = eps
def rate_noise(self, t):
cos = (1 - self.eps) * (
torch.cos(t * torch.pi / 2) ** 2)
sin = (1 - self.eps) * torch.sin(t * torch.pi)
scale = torch.pi / 2
return scale * sin / (cos + self.eps)
def total_noise(self, t):
cos = torch.cos(t * torch.pi / 2) ** 2
return - torch.log(self.eps + (1 - self.eps) * cos)
class Linear(Noise):
def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
super().__init__()
self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
def rate_noise(self, t):
return self.sigma_max - self.sigma_min
def total_noise(self, t):
return self.sigma_min + t * (self.sigma_max - self.sigma_min)
def importance_sampling_transformation(self, t):
f_T = torch.log1p(- torch.exp(- self.sigma_max))
f_0 = torch.log1p(- torch.exp(- self.sigma_min))
sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
return (sigma_t - self.sigma_min) / (
self.sigma_max - self.sigma_min)
class GeometricNoise(Noise):
def __init__(self, sigma_min=1e-3, sigma_max=1):
super().__init__()
self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
def rate_noise(self, t):
return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
self.sigmas[1].log() - self.sigmas[0].log())
def total_noise(self, t):
return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
class LogLinearNoise(Noise):
"""Log Linear noise schedule.
Built such that 1 - 1/e^(n(t)) interpolates between 0 and 1.
"""
def __init__(self, eps=1e-3):
super().__init__()
self.eps = eps
self.sigma_max = self.total_noise(torch.tensor(1.0))
self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
def rate_noise(self, t):
return (1 - self.eps) / (1 - (1 - self.eps) * t)
def total_noise(self, t):
return -torch.log1p(-(1 - self.eps) * t)
def importance_sampling_transformation(self, t):
f_T = torch.log1p(- torch.exp(- self.sigma_max))
f_0 = torch.log1p(- torch.exp(- self.sigma_min))
sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
t = - torch.expm1(- sigma_t) / (1 - self.eps)
return t