File size: 5,687 Bytes
b6ff324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Simple implementation of continuous flow matching schedulers."""

import dataclasses
import math

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import BaseOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin


@dataclasses.dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """Container returned by the scheduler's `step` method when `return_dict` is true.

    Attributes:
        prev_sample: The updated sample produced by one `step` call, i.e. the
            input sample advanced to the next position of the sigma schedule.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """First-order Euler scheduler for continuous flow matching.

    Noisy samples are formed as ``sigma * noise + (1 - sigma) * sample`` and
    integrated with ``sample + (sigma_next - sigma) * model_output``.

    State layout:
        * After construction, ``timesteps`` and ``sigmas`` are tensors holding
          the full descending training schedule (consumed by `add_noise`).
        * After `set_timesteps`, ``timesteps`` is a NumPy array and ``sigmas``
          is a Python list with a trailing ``0`` (consumed by `scale_noise`
          and `step`). `add_noise` relies on the tensor form and therefore
          cannot be used once `set_timesteps` has been called.

    Args:
        num_train_timesteps: Number of discrete training timesteps.
        shift: Static shift factor, applied as
            ``shift * s / (1 + (shift - 1) * s)``.
        use_dynamic_shifting: When true, the static shift is skipped and
            `set_timesteps` applies `time_shift` with a caller-provided ``mu``.
    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        self._shift = shift
        # Descending timesteps N, N-1, ..., 1 mapped to sigmas in (0, 1].
        timesteps = np.arange(1, num_train_timesteps + 1, dtype="float32")[::-1]
        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = torch.as_tensor(sigmas * num_train_timesteps)
        self.sigmas = torch.as_tensor(sigmas)
        # The schedule is descending, so the last entry is the minimum.
        self.sigma_min, self.sigma_max = float(sigmas[-1]), float(sigmas[0])
        self.timestep = self.sigma = None  # Training states.
        self._begin_index = self._step_index = None  # Inference counters.

    @property
    def shift(self):
        """The value used for shifting."""
        return self._shift

    @property
    def step_index(self):
        """The index counter for current timestep."""
        return self._step_index

    @property
    def begin_index(self):
        """The index for the first timestep."""
        return self._begin_index

    def _sigma_to_t(self, sigma):
        """Convert a sigma to its timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def _init_step_index(self, timestep):
        """Initialize the step counter from `begin_index`, or by looking up `timestep`."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

        Args:
            mu: Shift amount in log space.
            sigma: Exponent applied to ``1 / t - 1``.
            t: Values to shift. `set_timesteps` passes a NumPy array here; only
                elementwise arithmetic is used, so tensors work as well.
        """
        exp_mu = math.exp(mu)
        return exp_mu / (exp_mu + (1 / t - 1) ** sigma)

    def set_shift(self, shift: float):
        """Update the static shift used by subsequent `set_timesteps` calls.

        The schedule already stored in ``sigmas``/``timesteps`` is not recomputed.
        """
        self._shift = shift

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index whose value equals `timestep`.

        Args:
            timestep: Timestep value to look up.
            schedule_timesteps: Schedule to search; defaults to ``self.timesteps``.

        Returns:
            The matching index as a Python int.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        indices = (schedule_timesteps == timestep).nonzero()
        # For a tensor schedule with several matches the second one is taken
        # (presumably mirroring the diffusers convention for mid-schedule starts).
        # NOTE(review): for a NumPy schedule `nonzero()` returns a 1-tuple, so this
        # always selects element 0 and `.item()` requires exactly one match.
        return indices[1 if len(indices) > 1 else 0].item()

    def sample_timesteps(self, size, device=None):
        """Sample the discrete timesteps used for training.

        Draws standard-normal values, squashes them with a sigmoid (a
        logit-normal distribution over (0, 1)), scales by `num_train_timesteps`
        and truncates to ``int64`` indices.
        """
        dist = torch.normal(0, 1, size, device=device).sigmoid_()
        return dist.mul_(self.config.num_train_timesteps).to(dtype=torch.int64)

    def set_timesteps(self, num_inference_steps, mu=None):
        """Sets the discrete timesteps used for the diffusion chain.

        Args:
            num_inference_steps: Number of points in the inference schedule.
            mu: Shift passed to `time_shift`; required when the scheduler was
                configured with ``use_dynamic_shifting=True``, ignored otherwise.

        Raises:
            ValueError: If dynamic shifting is enabled and `mu` is ``None``.
        """
        dynamic = self.config.use_dynamic_shifting
        # Validate before touching any state so a failed call leaves the
        # scheduler unchanged (previously this surfaced as a TypeError from
        # `math.exp(None)` after `num_inference_steps` was already overwritten).
        if dynamic and mu is None:
            raise ValueError("`mu` must be provided when `use_dynamic_shifting` is enabled.")
        self.num_inference_steps = num_inference_steps
        t_max, t_min = self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min)
        timesteps = np.linspace(t_max, t_min, num_inference_steps, dtype="float32")
        sigmas = timesteps / self.config.num_train_timesteps
        if dynamic:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        # Trailing zero lets `step` read `sigmas[i + 1]` on the final step.
        self.sigmas = sigmas.tolist() + [0]
        self.timesteps = sigmas * self.config.num_train_timesteps
        self._begin_index = self._step_index = None

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ):
        """Add forward noise to samples for training.

        Args:
            original_samples: Clean samples.
            noise: Noise to blend in; its rank determines how far sigma is
                unsqueezed for broadcasting.
            timesteps: Integer indices into the training schedule.

        Returns:
            ``sigma * noise + (1 - sigma) * original_samples``. The gathered
            timestep values and broadcast sigmas are also stored on
            ``self.timestep`` and ``self.sigma``.
        """
        dtype, device = original_samples.dtype, original_samples.device
        self.timestep = self.timesteps.to(device=device)[timesteps]
        self.sigma = self.sigmas.to(device=device, dtype=dtype)[timesteps]
        # Append singleton dims so sigma broadcasts against `noise`.
        self.sigma = self.sigma.view(timesteps.shape + (1,) * (noise.dim() - timesteps.dim()))
        return self.sigma * noise + (1.0 - self.sigma) * original_samples

    def scale_noise(self, sample: torch.Tensor, timestep: float, noise: torch.Tensor):
        """Add forward noise to samples for inference.

        Uses the sigma at the current step index, initializing the index from
        `timestep` on first use. The step index is not advanced.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        return sigma * noise + (1.0 - sigma) * sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: float,
        sample: torch.FloatTensor,
        generator: torch.Generator = None,
        return_dict=True,
    ):
        """Predict the sample from the previous timestep.

        Args:
            model_output: Model prediction used as the Euler update direction.
            timestep: Current timestep; only used to initialize the step index
                on the first call.
            sample: Current sample.
            generator: Not used by this method.
            return_dict: If false, return a 1-tuple instead of the output class.

        Returns:
            `FlowMatchEulerDiscreteSchedulerOutput` holding ``prev_sample``, or
            ``(prev_sample,)`` when `return_dict` is false.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        # `mul` allocates a new tensor, so the in-place add leaves the inputs intact.
        prev_sample = model_output.mul(dt).add_(sample)
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)