# motionReFit / src/model/gaussian_diffusion.py
# Uploaded by: Yzy00518
# Upload message: "Upload src/model/gaussian_diffusion.py with huggingface_hub"
# Commit: ed8c297
import math
import torch
import torch.nn.functional as F
import math
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linear DDPM noise schedule.

    The betas are spaced evenly from ``beta_start`` to ``beta_end``
    (inclusive). The endpoints are NOT rescaled by ``timesteps``. With short
    chains, such as the 100-step default of ``GaussianDiffusion``, callers
    that want a different noise level must pass their own endpoints.

    Args:
        timesteps: number of diffusion steps (length of the schedule).
        beta_start: beta at the first step.
        beta_end: beta at the last step.

    Returns:
        1-D ``torch.float64`` tensor of length ``timesteps``.
    """
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
class GaussianDiffusion:
    """Gaussian diffusion (DDPM-style) over motion sequences.

    Precomputes the noise-schedule tables for ``timesteps`` steps. Provides:

    * the forward process (``q_sample``, ``q_mean_variance``,
      ``q_posterior_mean_variance``);
    * the reverse / sampling process (``p_mean_variance``, ``p_sample``,
      ``p_sample_loop``, ``sample``);
    * an epsilon-prediction training loss (``train_losses``).

    Args:
        device: device on which all schedule tables are stored.
        fix_mode: if True, the first ``fixed_frames`` frames are overwritten
            with caller-provided ``fixed_points`` during sampling, and
            ``train_losses`` requires a ``mask``.
        text_emb: stored as ``self.text_emb``; not read by this class.
        fixed_frames: number of leading frames kept fixed in ``fix_mode``.
        seq_len: stored as ``self.seq_len``; not read by this class.
        timesteps: number of diffusion steps.
        beta_schedule: only ``'linear'`` is implemented; ``'cosine'`` raises
            ``NotImplementedError`` and anything else raises ``ValueError``.
    """

    def __init__(
        self,
        device,
        fix_mode=False,
        text_emb=False,
        fixed_frames=2,
        seq_len=16,
        timesteps=100,
        beta_schedule='linear',
    ):
        self.device = device
        self.fix_mode = fix_mode  # autoregressive: clamp leading frames while sampling
        self.fixed_frames = fixed_frames  # number of frames to fix
        self.timesteps = timesteps
        self.text_emb = text_emb
        self.seq_len = seq_len
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            raise NotImplementedError('cosine schedule is not implemented yet!')
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        # Every table below is derived from self.betas, so it is already on self.device.
        self.betas = betas.to(self.device)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)

        # tables for q(x_t | x_0) and for inverting it
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)

        # tables for the posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The posterior variance is 0 at t == 0, so clamp before taking the log.
        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` for each batch element.

        The result is float32, shaped ``(B, 1, ..., 1)`` so it broadcasts
        against a tensor of shape ``x_shape``, and moved to ``self.device``.
        """
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(self.device)

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) in closed form.

        Computes ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
        ``noise`` defaults to fresh standard-normal noise.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def q_mean_variance(self, x_start, t):
        """Return ``(mean, variance, log_variance)`` of q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return ``(mean, variance, clipped_log_variance)`` of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and a noise estimate; the inverse of ``q_sample``."""
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def p_mean_variance(self, model, x_t, t, clip_denoised=True, **kwargs):
        """Compute the model's mean and variance for p(x_{t-1} | x_t).

        ``model(x_t, t, text_emb=..., prog_ind=..., joints_orig=...)`` is
        treated as a noise (epsilon) predictor.

        Required kwargs:
            text, prog_ind, joints_orig: forwarded to ``model``.

        Optional kwargs:
            use_cfg: if truthy, apply classifier-free guidance. The
                unconditional branch uses an all-zero text embedding.
                Defaults to False.
            cfg_alpha: guidance scale; required when ``use_cfg`` is truthy.

        When ``clip_denoised`` is True, the predicted x_0 is clamped to
        [-1, 1] (presumably the data normalization range; confirm upstream).

        Returns:
            ``(mean, variance, clipped_log_variance)`` of the posterior
            q(x_{t-1} | x_t, x_0_hat).
        """
        assert 'text' in kwargs, 'text is required'
        assert 'prog_ind' in kwargs, 'prog_ind is required'
        assert 'joints_orig' in kwargs, 'joints_orig is required'
        cond = dict(prog_ind=kwargs['prog_ind'], joints_orig=kwargs['joints_orig'])
        pred_noise = model(x_t, t, text_emb=kwargs['text'], **cond)
        if kwargs.get('use_cfg', False):
            pred_noise_empty = model(x_t, t, text_emb=torch.zeros_like(kwargs['text']), **cond)
            pred_noise = pred_noise_empty + kwargs['cfg_alpha'] * (pred_noise - pred_noise_empty)
        # Predict x_0 first, then plug it into the analytic posterior.
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        return self.q_posterior_mean_variance(x_recon, x_t, t)

    @staticmethod
    def _guidance_settings(kwargs):
        """Return ``(cg_alpha, cg_diffusion_steps)`` taken from ``kwargs``.

        Both keys must be present. Otherwise a notice is printed and both
        fall back to ``(1.0, 5)``.
        """
        if 'cg_alpha' in kwargs and 'cg_diffusion_steps' in kwargs:
            return kwargs['cg_alpha'], kwargs['cg_diffusion_steps']
        print("cg_alpha and cg_diffusion_steps are not provided!")
        print("Using default values: cg_alpha=1.0, cg_diffusion_steps=5")
        return 1.0, 5

    @staticmethod
    def _disc_gradient(disc_model, x, t):
        """Return the gradient of ``disc_model(x, t).sum()`` with respect to ``x``.

        ``torch.autograd.grad`` is used, so nothing is accumulated into the
        discriminator parameters' ``.grad`` fields. Autograd is re-enabled
        locally, so this also works when sampling runs under
        ``torch.no_grad()``.
        """
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            score = disc_model(x, t)
            (grad,) = torch.autograd.grad(score.sum(), x)
        return grad

    def p_sample(self, model, x_t, t, clip_denoised=True, **kwargs):
        """Draw x_{t-1} from p(x_{t-1} | x_t), optionally discriminator-guided.

        Guidance is enabled when ``kwargs['disc_model']`` is given (not None).
        For every batch element with ``t < cg_diffusion_steps``, the mean is
        moved against ``cg_alpha * d disc_model(mean, t) / d mean``, scaled by
        the posterior std. ``cg_alpha`` and ``cg_diffusion_steps`` are read
        from kwargs and default to 1.0 and 5.

        No noise and no guidance step are applied where ``t == 0``.
        The returned tensor carries no autograd history.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, **kwargs)
        # Detach so the autograd graph never chains across denoising steps.
        model_mean = model_mean.detach()
        noise = torch.randn_like(x_t)
        bcast = (-1, *([1] * (x_t.dim() - 1)))
        nonzero_mask = (t != 0).float().view(bcast)
        std = (0.5 * model_log_variance).exp()

        disc_model = kwargs.get('disc_model')
        if disc_model is not None:
            cg_alpha, cg_diffusion_steps = self._guidance_settings(kwargs)
            # Per-sample mask, so batches with mixed timesteps work (t.item() did not).
            guide_mask = (t < cg_diffusion_steps).float().view(bcast)
            if guide_mask.any():
                grad = self._disc_gradient(disc_model, model_mean, t) * cg_alpha
                model_mean = model_mean - guide_mask * nonzero_mask * std * grad

        return model_mean + nonzero_mask * std * noise

    def p_sample_loop(self, model, shape, fixed_points=None, **kwargs):
        """Run the full reverse chain from pure noise of ``shape``.

        The device is taken from ``model``'s first parameter. In ``fix_mode``,
        ``img[:, :fixed_frames, :]`` is overwritten with ``fixed_points``
        before the first step and after every step.

        Returns:
            list of ``self.timesteps`` tensors, one per step, ordered from
            t = T-1 down to t = 0 (the last entry is the final sample).
        """
        batch_size = shape[0]
        device = next(model.parameters()).device
        img = torch.randn(shape, device=device)
        if self.fix_mode:
            assert fixed_points is not None, 'fixed_points is required for fixed mode'
            img[:, :self.fixed_frames, :] = fixed_points
        imgs = []
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            img = self.p_sample(model, img, t, **kwargs)
            if self.fix_mode:
                img[:, :self.fixed_frames, :] = fixed_points
            imgs.append(img)
        return imgs

    def sample(self, model, batch_size=1, seq_len=16, channels=135,
               fixed_points=None, clip_denoised=True, **kwargs):
        """Generate samples of shape ``(batch_size, seq_len, channels)``.

        Returns the per-step list produced by ``p_sample_loop``.
        """
        return self.p_sample_loop(model, shape=(batch_size, seq_len, channels),
                                  fixed_points=fixed_points, clip_denoised=clip_denoised, **kwargs)

    def train_losses(self, model, x_start, t, mask=None, **kwargs):
        """Return the smooth-L1 epsilon-prediction loss at timesteps ``t``.

        ``mask`` is a boolean tensor broadcast-compatible with ``x_start`` as
        an index. True marks entries that receive no noise and are excluded
        from the loss. It is required in ``fix_mode``. Otherwise it defaults
        to all-False, meaning every entry is noised and scored.

        Required kwargs: ``text``, ``prog_ind``, ``joints_orig`` (forwarded
        to ``model``).
        """
        # Bug fix: the attribute is ``fix_mode``; ``fixed_mode`` raised AttributeError.
        assert not (mask is None and self.fix_mode), 'mask is required for fixed mode'
        if mask is None:
            mask = torch.zeros_like(x_start, dtype=torch.bool, device=self.device)
        mask_inv = torch.logical_not(mask)
        noise = torch.randn_like(x_start).to(device=self.device)
        # NOTE(review): masked entries get zero noise, so x_noisy there equals
        # sqrt(alpha_bar_t) * x_start, while p_sample_loop overwrites fixed frames
        # with the clean fixed_points -- confirm this asymmetry is intended.
        noise[mask] = 0.
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t, text_emb=kwargs['text'],
                                prog_ind=kwargs['prog_ind'], joints_orig=kwargs['joints_orig'])
        return F.smooth_l1_loss(noise[mask_inv], predicted_noise[mask_inv])