# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py
# We adapt it from DDIM sampling to flow matching.
import math
from typing import Optional, Union

import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
def sde_step_with_logprob(
    self: "FlowMatchEulerDiscreteScheduler",
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Run one reverse SDE step of the flow-matching sampler and return the Gaussian
    log-probability of the transition.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned flow model (most often the predicted
            velocity).
        timestep (`float` or `torch.FloatTensor`):
            The current discrete timestep(s) in the diffusion chain; iterated
            element-wise, so a batch of per-sample timesteps is expected.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        noise_level (`float`, *optional*, defaults to 0.7):
            Multiplier on the SDE noise scale.
        prev_sample (`torch.FloatTensor`, *optional*):
            If given, the log-probability is evaluated at this sample instead of
            drawing a fresh one.
        generator (`torch.Generator`, *optional*):
            A random number generator, used only when `prev_sample` is not supplied.

    Returns:
        Tuple of `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`, where
        `log_prob` is the per-sample Gaussian log-density averaged over all
        non-batch dimensions.
    """
    # bf16 can overflow when computing prev_sample_mean, so promote everything to fp32.
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_index = [self.index_for_timestep(t) for t in timestep]
    prev_step_index = [step + 1 for step in step_index]
    view_shape = (-1, *([1] * (sample.ndim - 1)))
    sigma = self.sigmas[step_index].view(view_shape)
    sigma_prev = self.sigmas[prev_step_index].view(view_shape)
    # Substituted for sigma where sigma == 1 to avoid a division by zero below.
    sigma_max = self.sigmas[1].item()
    dt = sigma_prev - sigma  # negative: sigmas decrease along the chain

    std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level
    # Mean of our SDE transition kernel.
    prev_sample_mean = (
        sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
        + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
    )

    # Standard deviation of the transition kernel; dt < 0, so -dt is the positive step size.
    std = std_dev_t * torch.sqrt(-dt)
    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std * variance_noise

    # Gaussian log-density of prev_sample under N(prev_sample_mean, std^2),
    # then mean along all but the batch dimension.
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * std**2)
        - torch.log(std)
        - 0.5 * math.log(2 * math.pi)
    )
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob, prev_sample_mean, std_dev_t
def sde_step_with_logprob_new(
    self: "FlowMatchEulerDiscreteScheduler",
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Run one reverse step using the Flow-CPS formulation (interpolating between the
    predicted endpoints x_0 and x_1) and return an unnormalized Gaussian
    log-probability of the transition.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from the learned flow model (most often the predicted
            velocity).
        timestep (`float` or `torch.FloatTensor`):
            The current discrete timestep(s) in the diffusion chain; iterated
            element-wise, so a batch of per-sample timesteps is expected.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        noise_level (`float`, *optional*, defaults to 0.7):
            Controls the injected noise via sin(noise_level * pi / 2); 0 gives a
            deterministic step.
        prev_sample (`torch.FloatTensor`, *optional*):
            If given, the log-probability is evaluated at this sample instead of
            drawing a fresh one.
        generator (`torch.Generator`, *optional*):
            A random number generator, used only when `prev_sample` is not supplied.

    Returns:
        Tuple of `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`. Unlike
        `sde_step_with_logprob`, `log_prob` drops all additive and multiplicative
        constants and is just the negative squared error, averaged over all
        non-batch dimensions.
    """
    # bf16 can overflow when computing prev_sample_mean, so promote everything to fp32.
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    step_index = [self.index_for_timestep(t) for t in timestep]
    prev_step_index = [step + 1 for step in step_index]
    view_shape = (-1, *([1] * (sample.ndim - 1)))
    sigma = self.sigmas[step_index].view(view_shape)
    sigma_prev = self.sigmas[prev_step_index].view(view_shape)

    # Flow-CPS step.
    std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)  # sigma_t in the paper
    pred_original_sample = sample - sigma * model_output          # predicted x_0 in the paper
    noise_estimate = sample + model_output * (1 - sigma)          # predicted x_1 in the paper
    prev_sample_mean = (
        pred_original_sample * (1 - sigma_prev)
        + noise_estimate * torch.sqrt(sigma_prev**2 - std_dev_t**2)
    )

    if prev_sample is None:
        variance_noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=model_output.dtype,
        )
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # Unnormalized log-density: constants removed, only the squared error remains.
    log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
    # Mean along all but the batch dimension.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample, log_prob, prev_sample_mean, std_dev_t