# File size: 5,687 Bytes (b6ff324)
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Simple implementation of continuous flow matching schedulers."""
import dataclasses
import math
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import BaseOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclasses.dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """Result container returned by the scheduler's `step` function."""

    # The sample computed by `step` for the next position in the chain.
    prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """Euler scheduler for continuous flow matching.

    The scheduler keeps two kinds of state:

    * Training tables, built once in ``__init__``: one ``(timestep, sigma)``
      pair per training step, ordered from the largest sigma to the smallest.
      ``add_noise`` indexes into these tables.
    * Inference tables, rebuilt by every ``set_timesteps`` call and exposed as
      ``self.timesteps`` (NumPy array) and ``self.sigmas`` (Python list with a
      trailing ``0``). ``scale_noise`` and ``step`` walk through them using
      an internal step counter.
    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        """Build the training tables.

        Args:
            num_train_timesteps: Number of discrete training timesteps.
            shift: Static shift applied to the sigmas when
                ``use_dynamic_shifting`` is false.
            use_dynamic_shifting: When true, no static shift is applied here
                and ``set_timesteps`` shifts with ``time_shift`` instead.
        """
        # Descending timesteps: num_train_timesteps, ..., 1.
        timesteps = np.arange(1, num_train_timesteps + 1, dtype="float32")[::-1]
        sigmas = timesteps / num_train_timesteps
        self._shift = shift
        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = torch.as_tensor(sigmas * num_train_timesteps)
        self.sigmas = torch.as_tensor(sigmas)
        # `set_timesteps` rebinds `self.timesteps` / `self.sigmas` to
        # non-tensor inference tables; keep the training tables reachable so
        # that `add_noise` still works afterwards.
        self._train_timesteps, self._train_sigmas = self.timesteps, self.sigmas
        self.sigma_min, self.sigma_max = float(sigmas[-1]), float(sigmas[0])
        self.timestep = self.sigma = None  # Training states.
        self._begin_index = self._step_index = None  # Inference counters.

    @property
    def shift(self):
        """The value used for shifting."""
        return self._shift

    @property
    def step_index(self):
        """The index counter for current timestep."""
        return self._step_index

    @property
    def begin_index(self):
        """The index for the first timestep."""
        return self._begin_index

    def _sigma_to_t(self, sigma):
        """Convert a sigma into the matching timestep value."""
        return sigma * self.config.num_train_timesteps

    def _init_step_index(self, timestep):
        """Initialize the step counter for the first `scale_noise`/`step` call.

        Uses ``begin_index`` when one has been set, otherwise looks the
        timestep up in the current schedule.
        """
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """Set the index the step counter starts from.

        Must be called after ``set_timesteps``, which resets the begin index
        to ``None``.
        """
        self._begin_index = begin_index

    def time_shift(self, mu: float, sigma: float, t: "np.ndarray | torch.Tensor"):
        """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

        ``set_timesteps`` calls this with a NumPy array for ``t``; any type
        supporting the arithmetic above works.
        """
        exp_mu = math.exp(mu)
        return exp_mu / (exp_mu + (1 / t - 1) ** sigma)

    def set_shift(self, shift: float):
        """Override the shift used by subsequent ``set_timesteps`` calls.

        The training tables built in ``__init__`` are not recomputed.
        """
        self._shift = shift

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of ``timestep`` inside the schedule.

        Args:
            timestep: Value to look up (Python/NumPy scalar or tensor).
            schedule_timesteps: Schedule to search; defaults to
                ``self.timesteps``. May be a tensor, NumPy array or list.

        Returns:
            The integer index of the match. When several entries match, the
            second one is returned.

        Raises:
            IndexError: If ``timestep`` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        # After `set_timesteps` the schedule is a NumPy array, whose
        # `nonzero()` returns a tuple instead of an (N, 1) tensor. Normalize
        # both operands so the selection below behaves the same for either.
        schedule_timesteps = torch.as_tensor(schedule_timesteps)
        timestep = torch.as_tensor(timestep, device=schedule_timesteps.device)
        indices = (schedule_timesteps == timestep).nonzero()
        if len(indices) == 0:
            raise IndexError(f"Timestep {timestep} is not part of the schedule.")
        # NOTE(review): picking the second match mirrors the diffusers
        # schedulers, presumably so a chain started mid-schedule does not
        # skip a sigma; confirm against callers.
        return indices[1 if len(indices) > 1 else 0].item()

    def sample_timesteps(self, size, device=None):
        """Sample the discrete timesteps used for training.

        Draws standard-normal values, squashes them with a sigmoid and scales
        by ``num_train_timesteps``. The truncated int64 result is intended as
        an index into the training tables (see ``add_noise``).
        """
        num_train_timesteps = self.config.num_train_timesteps
        dist = torch.normal(0, 1, size, device=device).sigmoid_()
        indices = dist.mul_(num_train_timesteps).to(dtype=torch.int64)
        # A sigmoid saturated to exactly 1.0 would give an out-of-range index.
        return indices.clamp_(max=num_train_timesteps - 1)

    def set_timesteps(self, num_inference_steps, mu=None):
        """Sets the discrete timesteps used for the diffusion chain.

        Args:
            num_inference_steps: Number of steps in the inference schedule.
            mu: Shift parameter forwarded to ``time_shift``; required when
                the scheduler was configured with ``use_dynamic_shifting``.

        Raises:
            TypeError: If dynamic shifting is enabled and ``mu`` is ``None``.

        After this call ``self.sigmas`` is a list with a trailing ``0`` and
        ``self.timesteps`` is a NumPy array; the step counters are reset.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise TypeError("`mu` must be provided when `use_dynamic_shifting` is enabled.")
        self.num_inference_steps = num_inference_steps
        t_max, t_min = self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min)
        timesteps = np.linspace(t_max, t_min, num_inference_steps, dtype="float32")
        sigmas = timesteps / self.config.num_train_timesteps
        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        # The trailing 0 lets `step` read `sigmas[step_index + 1]` on the
        # final step.
        self.sigmas = sigmas.tolist() + [0]
        self.timesteps = sigmas * self.config.num_train_timesteps
        self._begin_index = self._step_index = None

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ):
        """Add forward noise to samples for training.

        Args:
            original_samples: Clean samples.
            noise: Noise to blend in.
            timesteps: Integer indices into the training tables, e.g. the
                output of ``sample_timesteps``.

        Returns:
            ``sigma * noise + (1 - sigma) * original_samples``.

        Side effects:
            Stores the selected timesteps in ``self.timestep`` and the
            selected sigmas (reshaped to broadcast against ``noise``) in
            ``self.sigma``.
        """
        dtype, device = original_samples.dtype, original_samples.device
        # Read the training tables, not `self.timesteps` / `self.sigmas`,
        # which `set_timesteps` replaces with inference schedules.
        self.timestep = self._train_timesteps.to(device=device)[timesteps]
        sigma = self._train_sigmas.to(device=device, dtype=dtype)[timesteps]
        # Append singleton dims so sigma broadcasts over the sample dims.
        self.sigma = sigma.view(timesteps.shape + (1,) * (noise.dim() - timesteps.dim()))
        return self.sigma * noise + (1.0 - self.sigma) * original_samples

    def scale_noise(self, sample: torch.Tensor, timestep: float, noise: torch.Tensor):
        """Add forward noise to samples for inference.

        Initializes the step counter from ``timestep`` if needed, then blends
        with the sigma at the current step index. The counter is not advanced.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        return sigma * noise + (1.0 - sigma) * sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: float,
        sample: torch.FloatTensor,
        generator: torch.Generator = None,
        return_dict=True,
    ):
        """Predict the sample from the previous timestep.

        Args:
            model_output: Output of the model at the current step.
            timestep: Current timestep; only used to initialize the step
                counter on the first call.
            sample: Current sample.
            generator: Unused by this scheduler.
            return_dict: When false, return a 1-tuple instead of a
                ``FlowMatchEulerDiscreteSchedulerOutput``.

        Returns:
            The updated sample ``sample + model_output * dt``, where ``dt`` is
            the difference between the next and the current sigma.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        prev_sample = model_output.mul(dt).add_(sample)
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|