# NOTE(review): the lines that used to be here were page-scrape residue from the
# hosting UI ("Spaces: Running on Zero", file size 5,633 Bytes, commit 9294bc7,
# and a line-number gutter). They were not Python and broke parsing; the metadata
# is preserved in this comment to keep the file valid.
# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
# We adapt it from DDIM (diffusion) sampling to flow matching.
import math
from typing import Optional, Union
import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
def sde_step_with_logprob(
self: FlowMatchEulerDiscreteScheduler,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
noise_level: float = 0.7,
prev_sample: Optional[torch.FloatTensor] = None,
generator: Optional[torch.Generator] = None,
):
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
process from the learned model outputs (most often the predicted velocity).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned flow model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
"""
# bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32
model_output=model_output.float()
sample=sample.float()
if prev_sample is not None:
prev_sample=prev_sample.float()
step_index = [self.index_for_timestep(t) for t in timestep]
prev_step_index = [step+1 for step in step_index]
sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
sigma_max = self.sigmas[1].item()
dt = sigma_prev - sigma
std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level
# import pdb; pdb.set_trace()
# our sde
prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
if prev_sample is None:
variance_noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
prev_sample = prev_sample_mean + std_dev_t * torch.sqrt(-1*dt) * variance_noise
log_prob = (
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * ((std_dev_t * torch.sqrt(-1*dt))**2))
- torch.log(std_dev_t * torch.sqrt(-1*dt))
- torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
)
# mean along all but batch dimension
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
return prev_sample, log_prob, prev_sample_mean, std_dev_t
def sde_step_with_logprob_new(
self: FlowMatchEulerDiscreteScheduler,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
noise_level: float = 0.7,
prev_sample: Optional[torch.FloatTensor] = None,
generator: Optional[torch.Generator] = None,
):
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
process from the learned model outputs (most often the predicted velocity).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned flow model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
"""
# bf16 can overflow here when compute prev_sample_mean, we must convert all variable to fp32
model_output=model_output.float()
sample=sample.float()
if prev_sample is not None:
prev_sample=prev_sample.float()
step_index = [self.index_for_timestep(t) for t in timestep]
prev_step_index = [step+1 for step in step_index]
sigma = self.sigmas[step_index].view(-1, *([1] * (len(sample.shape) - 1)))
sigma_prev = self.sigmas[prev_step_index].view(-1, *([1] * (len(sample.shape) - 1)))
sigma_max = self.sigmas[1].item()
dt = sigma_prev - sigma
# Flow-SDE
#std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma)))*noise_level * torch.sqrt(-1*dt)
# prev_sample_mean = sample*(1+std_dev_t**2/(2*sigma)*dt)+model_output*(1+std_dev_t**2*(1-sigma)/(2*sigma))*dt
# Flow-CPS
std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2) # sigma_t in paper
pred_original_sample = sample - sigma * model_output # predicted x_0 in paper
noise_estimate = sample + model_output * (1 - sigma) # predicted x_1 in paper
prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)
# import pdb; pdb.set_trace()
if prev_sample is None:
variance_noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
prev_sample = prev_sample_mean + std_dev_t * variance_noise
# remove all constants
log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
# mean along all but batch dimension
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
return prev_sample, log_prob, prev_sample_mean, std_dev_t |