File size: 8,435 Bytes
6189cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py, which is licensed under MIT license.
from typing import Optional, Tuple, Union
import math
import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
def _left_broadcast(t, shape):
assert t.ndim <= len(shape)
return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
def _get_variance(self, timestep, prev_timestep):
## a_t
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
## a_t-1
alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0,self.alphas_cumprod.gather(0, prev_timestep.cpu()),self.final_alpha_cumprod,).to(timestep.device)
## b_t
beta_prod_t = 1 - alpha_prod_t
## b_t-1
beta_prod_t_prev = 1 - alpha_prod_t_prev
## (b_t-1 / b_t) * (1 - a_t/a_t-1)
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Union[DDIMSchedulerOutput, Tuple]:
    """
    Perform one reverse DDIM step and also return the log-probability of the
    previous sample under the step's Gaussian transition kernel.

    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).
    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
            NOTE(review): despite the `int` annotation, the body calls
            `timestep.cpu()` and uses it in tensor arithmetic — a tensor
            timestep appears to be expected; confirm against callers.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step.
        use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
            predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
            `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
            coincide with the one provided as input and `use_clipped_model_output` will have not effect.
        generator: random number generator used to draw the step noise when
            `prev_sample` is not supplied. Mutually exclusive with `prev_sample`.
        prev_sample (`torch.FloatTensor`, *optional*): a precomputed x_{t-1}. When
            given, no noise is sampled; only the log-probability of this sample
            under N(prev_sample_mean, std_dev_t^2) is computed.
    Returns:
        `tuple(prev_sample, log_prob)`: the previous sample cast to
        `sample.dtype`, and a per-batch-element log-probability obtained by
        averaging over all non-batch dimensions.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )
    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding
    # Notation (<variable name> -> <name in paper>
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"
    ## t-1: step back through the training schedule with a uniform stride
    prev_timestep = (
        timestep - self.config.num_train_timesteps // self.num_inference_steps
    )
    # NOTE(review): clamping to >= 0 here means the `prev_timestep >= 0` check
    # below always takes the gather branch, so `final_alpha_cumprod` is never
    # selected — confirm this is the intended deviation from stock diffusers.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
    ## a_t (gathered on CPU, where alphas_cumprod lives)
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
    ## a_t-1
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0,
        self.alphas_cumprod.gather(0, prev_timestep.cpu()),
        self.final_alpha_cumprod,
    )
    ## experimental eta bound left by the authors; s0:(2.1924) s5: (2.3384), s15: (2.6422) s24:(2.8335)
    # eta_bound = (((1-alpha_prod_t) * alpha_prod_t_prev) / (alpha_prod_t_prev - alpha_prod_t)) ** (0.5)
    ## a_t broadcast to the sample's full shape, e.g. torch.Size([4, 4, 64, 64])
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
    ## a_t-1
    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
        sample.device
    )
    ## b_t
    beta_prod_t = 1 - alpha_prod_t
    ## 3. recover pred_x_0 and pred_epsilon according to the parametrization the
    ## model was trained with
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (
            sample - beta_prod_t ** (0.5) * model_output
        ) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (
            beta_prod_t**0.5
        ) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (
            beta_prod_t**0.5
        ) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )
    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )
    # 5. compute variance: "sigma_t(η)"
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    ## var = (b_t-1 / b_t) * (1 - a_t/a_t-1)
    variance = _get_variance(self, timestep, prev_timestep)
    ## std = eta * sqrt(var)
    # NOTE(review): with eta == 0 (or variance == 0) std_dev_t is 0, making the
    # log-prob below divide by zero / take log(0) — callers presumably use eta > 0.
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample
        ) / beta_prod_t ** (0.5)
    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
    # 7. deterministic part of x_t-1 (the Gaussian mean, before noise is added)
    prev_sample_mean = (
        alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
    )
    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
            " `prev_sample` stays `None`."
        )
    if prev_sample is None:
        # no precomputed x_{t-1}: sample fresh step noise
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        # Commented-out experiment (antithetic/mixed noise variants), kept as-is:
        # alpha = 1
        # scale = 1.0 / (1 + 2*alpha + 2*alpha**2) ** 0.5
        # new_noise_1 = variance_noise[[0]] + alpha * (variance_noise[[0]]-variance_noise[[1]])
        # new_noise_2 = variance_noise[[1]] + alpha * (variance_noise[[1]]-variance_noise[[0]])
        # new_noise_1 = new_noise_1 * scale
        # new_noise_2 = new_noise_2 * scale
        # new_noise = torch.cat((variance_noise[[0]], variance_noise[[1]], new_noise_1, new_noise_2), dim=0)
        # prev_sample = prev_sample_mean + std_dev_t * new_noise
        ## x_t-1 = x_t-1_mean + std * noise
        prev_sample = prev_sample_mean + std_dev_t * variance_noise
    ## one x_t can be mapped to multiple sampled x_t-1
    # log prob of prev_sample given prev_sample_mean and std_dev_t
    # (elementwise Gaussian log-density; prev_sample is detached so gradients
    # flow only through the mean and std)
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample.type(sample.dtype), log_prob
|