# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
# Adapted from DDIM to flow matching.
import math
from typing import Optional, Union

import torch
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


def _as_timestep_list(timestep):
    """Return ``timestep`` as a list of per-sample timesteps.

    Any iterable (e.g. a 1-d tensor or a list) is expanded element-wise, exactly
    as iterating it directly would. A scalar (Python number or 0-d tensor) becomes
    a one-element list, so its sigma later broadcasts across the whole batch.
    """
    try:
        return list(timestep)
    except TypeError:
        # Scalars are not iterable: floats, ints and 0-d tensors all raise here.
        return [timestep]


def _prepare_step(scheduler, model_output, sample, timestep, prev_sample):
    """Cast inputs to fp32 and look up the current and next sigma for each sample.

    Returns ``(model_output, sample, prev_sample, sigma, sigma_prev)``. ``sigma`` and
    ``sigma_prev`` have shape ``(B, 1, ..., 1)``, where ``B`` is the number of
    timesteps (1 for a scalar timestep), so they broadcast over every non-batch
    dimension of ``sample``.
    """
    # bf16 can overflow when computing prev_sample_mean, so all math is done in fp32.
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_index = [scheduler.index_for_timestep(t) for t in _as_timestep_list(timestep)]
    # The sigma being stepped to is the next entry of the schedule.
    prev_step_index = [step + 1 for step in step_index]

    broadcast_shape = (-1,) + (1,) * (sample.ndim - 1)
    # `.to` is a no-op when the schedule already lives on sample's device; otherwise
    # it prevents a CPU/GPU mismatch in the batched arithmetic below.
    sigma = scheduler.sigmas[step_index].view(broadcast_shape).to(sample.device)
    sigma_prev = scheduler.sigmas[prev_step_index].view(broadcast_shape).to(sample.device)
    return model_output, sample, prev_sample, sigma, sigma_prev


def sde_step_with_logprob(
    self: FlowMatchEulerDiscreteScheduler,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """Take one reverse Flow-SDE step and return the log-probability of its result.

    The mean moves ``sample`` along the model's velocity prediction over
    ``dt = sigma_prev - sigma``, with correction terms proportional to
    ``std_dev_t**2``. Gaussian noise with standard deviation
    ``std_dev_t * sqrt(-dt)`` is then added.

    Args:
        self: The flow-matching scheduler; its ``sigmas`` and ``index_for_timestep``
            are used to locate the current and next noise levels.
        model_output (`torch.FloatTensor`): Direct output of the learned flow model
            (most often the predicted velocity).
        timestep (`float` or `torch.FloatTensor`): Current timestep(s). Either one
            value per batch element, or a single scalar shared by the whole batch.
            Each value is resolved with ``self.index_for_timestep``.
        sample (`torch.FloatTensor`): Current sample, batch dimension first.
        noise_level (`float`): Multiplier on the SDE diffusion coefficient. With
            ``0`` the step is deterministic, but ``log_prob`` then divides by zero.
        prev_sample (`torch.FloatTensor`, *optional*): If given, this sample is
            scored instead of drawing a new one.
        generator (`torch.Generator`, *optional*): RNG passed to ``randn_tensor``.

    Returns:
        A tuple ``(prev_sample, log_prob, prev_sample_mean, std_dev_t)``:

        - ``prev_sample``: fp32 sample at the next sigma.
        - ``log_prob``: shape ``(B,)``. The Gaussian log-density of ``prev_sample``
          under ``N(prev_sample_mean, (std_dev_t * sqrt(-dt))**2)``, averaged (not
          summed) over non-batch dimensions. ``prev_sample`` is detached, so
          gradients flow only through ``prev_sample_mean``.
        - ``prev_sample_mean``: mean of the Gaussian transition.
        - ``std_dev_t``: diffusion coefficient of shape ``(B, 1, ..., 1)``. The
          per-step standard deviation is ``std_dev_t * sqrt(-dt)``.
    """
    model_output, sample, prev_sample, sigma, sigma_prev = _prepare_step(
        self, model_output, sample, timestep, prev_sample
    )
    sigma_max = self.sigmas[1].item()
    dt = sigma_prev - sigma

    # Where sigma == 1, the denominator 1 - sigma would be zero, so sigmas[1] is used
    # in its place.
    std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level

    # Flow-SDE mean update.
    prev_sample_mean = (
        sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
        + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
    )

    # Standard deviation of this step's Gaussian transition (dt <= 0).
    step_std = std_dev_t * torch.sqrt(-1 * dt)

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + step_std * variance_noise

    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (step_std**2))
        - torch.log(step_std)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )

    # Mean over every dimension except the batch dimension.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample, log_prob, prev_sample_mean, std_dev_t


def sde_step_with_logprob_new(
    self: FlowMatchEulerDiscreteScheduler,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """Take one Flow-CPS step and return an unnormalized log-probability score.

    The velocity prediction ``v`` is turned into estimates of the clean sample,
    ``x_0 = sample - sigma * v``, and of the noise, ``x_1 = sample + (1 - sigma) * v``.
    The mean recombines them at ``sigma_prev`` as
    ``x_0 * (1 - sigma_prev) + x_1 * sqrt(sigma_prev**2 - std_dev_t**2)``. Fresh
    noise of scale ``std_dev_t = sigma_prev * sin(noise_level * pi / 2)`` is then
    added. The two noise terms therefore combine in quadrature to ``sigma_prev``.

    Args:
        self: The flow-matching scheduler; its ``sigmas`` and ``index_for_timestep``
            are used to locate the current and next noise levels.
        model_output (`torch.FloatTensor`): Direct output of the learned flow model
            (most often the predicted velocity).
        timestep (`float` or `torch.FloatTensor`): Current timestep(s). Either one
            value per batch element, or a single scalar shared by the whole batch.
            Each value is resolved with ``self.index_for_timestep``.
        sample (`torch.FloatTensor`): Current sample, batch dimension first.
        noise_level (`float`): Sets the injected noise to
            ``sigma_prev * sin(noise_level * pi / 2)``; ``0`` gives a deterministic
            step.
        prev_sample (`torch.FloatTensor`, *optional*): If given, this sample is
            scored instead of drawing a new one.
        generator (`torch.Generator`, *optional*): RNG passed to ``randn_tensor``.

    Returns:
        A tuple ``(prev_sample, log_prob, prev_sample_mean, std_dev_t)``:

        - ``prev_sample``: fp32 sample at the next sigma.
        - ``log_prob``: shape ``(B,)``. The negative squared error between
          ``prev_sample`` (detached) and ``prev_sample_mean``, averaged over non-batch
          dimensions. Gaussian constants and the ``1 / (2 * std**2)`` scaling are
          dropped, so this is NOT a normalized log-density.
          NOTE(review): confirm that callers do not need the true log-density.
        - ``prev_sample_mean``: mean of the transition.
        - ``std_dev_t``: per-step standard deviation of shape ``(B, 1, ..., 1)``.
    """
    model_output, sample, prev_sample, sigma, sigma_prev = _prepare_step(
        self, model_output, sample, timestep, prev_sample
    )

    # Flow-CPS update.
    std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)  # sigma_t in paper
    pred_original_sample = sample - sigma * model_output  # predicted x_0 in paper
    noise_estimate = sample + model_output * (1 - sigma)  # predicted x_1 in paper
    prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(
        sigma_prev**2 - std_dev_t**2
    )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # All constants removed (see docstring).
    log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
    # Mean over every dimension except the batch dimension.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample, log_prob, prev_sample_mean, std_dev_t