Spaces:
Runtime error
Runtime error
File size: 9,781 Bytes
ed8c297 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import math
import torch
import torch.nn.functional as F
import math
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas spaced linearly from 1e-4 to 2e-2 (inclusive).

    The result is a 1-D float64 tensor.

    NOTE(review): these endpoints are used unscaled. With the class default of
    100 timesteps the betas sum to about 1.0, so prod(1 - beta) only falls to
    roughly exp(-1) and x_T still carries noticeable signal. Confirm this is
    intended before reusing the schedule with another step count.
    """
    multiplier = 1.0  # originally annotated "for 100 steps"
    first, last = multiplier * 0.0001, multiplier * 0.02
    return torch.linspace(first, last, timesteps, dtype=torch.float64)
class GaussianDiffusion:
    """DDPM-style Gaussian diffusion helper for a noise-predicting model.

    Precomputes per-timestep schedule tensors for the forward process
    q(x_t | x_0) and the posterior q(x_{t-1} | x_t, x_0). Provides ancestral
    sampling (``sample`` / ``p_sample_loop`` / ``p_sample``) and the training
    loss (``train_losses``).

    Args:
        device: device the schedule tensors are moved to.
        fix_mode: autoregressive mode. When True, the first ``fixed_frames``
            frames (dim 1) are overwritten with ``fixed_points`` during
            sampling, and ``train_losses`` requires a mask.
        text_emb: stored on the instance; not read by this class.
        fixed_frames: number of leading frames held fixed in ``fix_mode``.
        seq_len: stored on the instance; not read by this class
            (``sample`` takes its own ``seq_len``).
        timesteps: number of diffusion steps T.
        beta_schedule: only ``'linear'`` is implemented; ``'cosine'`` raises
            NotImplementedError and anything else raises ValueError.
    """

    def __init__(
        self,
        device,
        fix_mode=False,
        text_emb=False,
        fixed_frames=2,
        seq_len=16,
        timesteps=100,
        beta_schedule='linear',
    ):
        self.device = device
        self.fix_mode = fix_mode  # autoregressive
        self.fixed_frames = fixed_frames  # number of frames to fix
        self.timesteps = timesteps
        self.text_emb = text_emb
        self.seq_len = seq_len

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            raise NotImplementedError('cosine schedule is not implemented yet!')
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        self.betas = betas.to(self.device)
        self.alphas = (1. - self.betas).to(self.device)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(self.device)
        # alpha_bar_{t-1}; the prepended 1.0 stands in for alpha_bar_{-1}
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.).to(self.device)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(self.device)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod).to(self.device)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod).to(self.device)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod).to(self.device)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1).to(self.device)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        ).to(self.device)
        # log is clipped because the posterior variance is 0 at the start of the chain
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(min=1e-20)
        ).to(self.device)
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        ).to(self.device)
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        ).to(self.device)

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` for each batch element, shaped to broadcast over ``x_shape``.

        Returns a float32 tensor of shape ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dims, on ``self.device``.
        """
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(self.device)

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) in closed form.

        Computes sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with
        fresh standard-normal noise when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def q_mean_variance(self, x_start, t):
        """Return ``(mean, variance, log_variance)`` of q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return ``(mean, variance, clipped log variance)`` of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x_t, t, noise):
        """Invert ``q_sample``: recover x_0 from x_t and the (predicted) noise."""
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def p_mean_variance(self, model, x_t, t, clip_denoised=True, **kwargs):
        """Predicted mean and variance of p(x_{t-1} | x_t).

        Required kwargs: ``text``, ``prog_ind`` and ``joints_orig``, which are
        forwarded to ``model`` as ``text_emb``/``prog_ind``/``joints_orig``.

        Optional kwargs: ``use_cfg`` (default False) enables classifier-free
        guidance against an all-zero text embedding, scaled by ``cfg_alpha``.
        ``cfg_alpha`` is required when ``use_cfg`` is true. Other kwargs are
        ignored.

        Returns ``(model_mean, posterior_variance, posterior_log_variance)``.
        """
        assert 'text' in kwargs, 'text is required'
        assert 'prog_ind' in kwargs, 'prog_ind is required'
        assert 'joints_orig' in kwargs, 'joints_orig is required'
        pred_noise = model(x_t, t,
                           text_emb=kwargs['text'],
                           prog_ind=kwargs['prog_ind'],
                           joints_orig=kwargs['joints_orig'])
        if kwargs.get('use_cfg', False):
            pred_noise_empty = model(x_t, t,
                                     text_emb=torch.zeros_like(kwargs['text']),
                                     prog_ind=kwargs['prog_ind'],
                                     joints_orig=kwargs['joints_orig'])
            pred_noise = pred_noise_empty + kwargs['cfg_alpha'] * (pred_noise - pred_noise_empty)
        # Predict x_0 first (optionally clipped to [-1, 1]), then take the
        # posterior mean; this differs from the direct epsilon-form mean of
        # DDPM Algorithm 2.
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        return self.q_posterior_mean_variance(x_recon, x_t, t)

    @staticmethod
    def _disc_grad(disc_model, mean, t):
        """Return d(sum(disc_model(mean, t))) / d(mean), even under ``no_grad``.

        Only the gradient w.r.t. ``mean`` is computed (``torch.autograd.grad``),
        so the ``.grad`` fields of disc_model's parameters are left untouched.
        """
        with torch.enable_grad():
            mean_in = mean.detach().requires_grad_(True)
            score = disc_model(mean_in, t)  # y = f(theta, x), theta fixed
            (grad,) = torch.autograd.grad(score.sum(), mean_in)
        return grad

    def p_sample(self, model, x_t, t, clip_denoised=True, **kwargs):
        """Draw x_{t-1} ~ p(x_{t-1} | x_t), with optional discriminator guidance.

        Besides the kwargs of ``p_mean_variance``, you may pass ``disc_model``.
        When present, each batch element with ``0 < t < cg_diffusion_steps``
        has its mean shifted by ``-sigma_t * cg_alpha * grad``. Here ``grad``
        is the gradient of ``disc_model(mean, t)`` (summed) w.r.t. the mean.
        If either ``cg_alpha`` or ``cg_diffusion_steps`` is missing, both fall
        back to 1.0 / 5.

        No noise is added where ``t == 0``. The result is detached from autograd.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, **kwargs)
        # Detach so the returned sample does not keep the model's graph alive.
        model_mean = model_mean.detach()
        noise = torch.randn_like(x_t)
        std = (0.5 * model_log_variance).exp()
        # per-sample 0/1 masks that broadcast against x_t
        bcast = (-1,) + (1,) * (x_t.dim() - 1)
        nonzero_mask = (t != 0).float().view(*bcast)  # no noise when t == 0

        disc_model = kwargs.get('disc_model')
        if disc_model is not None:
            try:
                cg_alpha = kwargs['cg_alpha']  # default 1.0
                cg_diffusion_steps = kwargs['cg_diffusion_steps']
            except KeyError:
                print("cg_alpha and cg_diffusion_steps are not provided!")
                print("Using default values: cg_alpha=1.0, cg_diffusion_steps=5")
                cg_alpha = 1.0
                cg_diffusion_steps = 5
            # elementwise comparison: works for any batch size (t.item() did not)
            guide_mask = (t < cg_diffusion_steps).float().view(*bcast)
            if bool(guide_mask.any()):
                grad = self._disc_grad(disc_model, model_mean, t) * cg_alpha
                model_mean = model_mean - nonzero_mask * guide_mask * std * grad

        # compute x_{t-1}
        return model_mean + nonzero_mask * std * noise

    def p_sample_loop(self, model, shape, fixed_points=None, **kwargs):
        """Run the full reverse chain from pure noise; return every intermediate sample.

        Starts from standard-normal noise of ``shape`` on the model's device
        and applies ``p_sample`` for t = T-1 .. 0. Returns the list of the T
        samples in that order, so the last element is the final sample. In
        ``fix_mode`` the first ``fixed_frames`` frames are reset to
        ``fixed_points`` before the loop and after every step.
        """
        batch_size = shape[0]
        device = next(model.parameters()).device
        img = torch.randn(shape, device=device)
        if self.fix_mode:
            assert fixed_points is not None, 'fixed_points is required for fixed mode'
            img[:, :self.fixed_frames, :] = fixed_points
        imgs = []
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            img = self.p_sample(model, img, t, **kwargs)
            if self.fix_mode:
                img[:, :self.fixed_frames, :] = fixed_points
            imgs.append(img)
        return imgs

    def sample(self, model, batch_size=1, seq_len=16, channels=135,
               fixed_points=None, clip_denoised=True, **kwargs):
        """Sample ``(batch_size, seq_len, channels)`` sequences; see ``p_sample_loop``."""
        return self.p_sample_loop(model, shape=(batch_size, seq_len, channels),
                                  fixed_points=fixed_points, clip_denoised=clip_denoised, **kwargs)

    def train_losses(self, model, x_start, t, mask=None, **kwargs):
        """Smooth-L1 noise-prediction loss at timesteps ``t``.

        ``mask`` is a boolean tensor indexing ``x_start``. True entries get
        zero noise and are excluded from the loss. It is required in
        ``fix_mode``; otherwise it defaults to all-False.

        Required kwargs ``text``, ``prog_ind`` and ``joints_orig`` are
        forwarded to ``model``.
        """
        assert not (mask is None and self.fix_mode), 'mask is required for fixed mode'
        if mask is None:
            mask = torch.zeros_like(x_start, dtype=torch.bool).to(device=self.device)
        mask_inv = torch.logical_not(mask)
        noise = torch.randn_like(x_start).to(device=self.device)
        noise[mask] = 0.
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t,
                                text_emb=kwargs['text'],
                                prog_ind=kwargs['prog_ind'],
                                joints_orig=kwargs['joints_orig'])
        return F.smooth_l1_loss(noise[mask_inv], predicted_noise[mask_inv])