Fabrice-TIERCELIN commited on
Commit
1d692ce
·
verified ·
1 Parent(s): 7d74ebe

Upload rf.py

Browse files
Files changed (1) hide show
  1. ltx_video/schedulers/rf.py +386 -0
ltx_video/schedulers/rf.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional, Tuple, Union
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
12
+ from diffusers.utils import BaseOutput
13
+ from torch import Tensor
14
+ from safetensors import safe_open
15
+
16
+
17
+ from ltx_video.utils.torch_utils import append_dims
18
+
19
+ from ltx_video.utils.diffusers_config_mapping import (
20
+ diffusers_and_ours_config_mapping,
21
+ make_hashable_key,
22
+ )
23
+
24
+
25
+ def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
26
+ if num_steps == 1:
27
+ return torch.tensor([1.0])
28
+ if linear_steps is None:
29
+ linear_steps = num_steps // 2
30
+ linear_sigma_schedule = [
31
+ i * threshold_noise / linear_steps for i in range(linear_steps)
32
+ ]
33
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
34
+ quadratic_steps = num_steps - linear_steps
35
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
36
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (
37
+ quadratic_steps**2
38
+ )
39
+ const = quadratic_coef * (linear_steps**2)
40
+ quadratic_sigma_schedule = [
41
+ quadratic_coef * (i**2) + linear_coef * i + const
42
+ for i in range(linear_steps, num_steps)
43
+ ]
44
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
45
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
46
+ return torch.tensor(sigma_schedule[:-1])
47
+
48
+
49
+ def simple_diffusion_resolution_dependent_timestep_shift(
50
+ samples_shape: torch.Size,
51
+ timesteps: Tensor,
52
+ n: int = 32 * 32,
53
+ ) -> Tensor:
54
+ if len(samples_shape) == 3:
55
+ _, m, _ = samples_shape
56
+ elif len(samples_shape) in [4, 5]:
57
+ m = math.prod(samples_shape[2:])
58
+ else:
59
+ raise ValueError(
60
+ "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
61
+ )
62
+ snr = (timesteps / (1 - timesteps)) ** 2
63
+ shift_snr = torch.log(snr) + 2 * math.log(m / n)
64
+ shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
65
+
66
+ return shifted_timesteps
67
+
68
+
69
+ def time_shift(mu: float, sigma: float, t: Tensor):
70
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
71
+
72
+
73
+ def get_normal_shift(
74
+ n_tokens: int,
75
+ min_tokens: int = 1024,
76
+ max_tokens: int = 4096,
77
+ min_shift: float = 0.95,
78
+ max_shift: float = 2.05,
79
+ ) -> Callable[[float], float]:
80
+ m = (max_shift - min_shift) / (max_tokens - min_tokens)
81
+ b = min_shift - m * min_tokens
82
+ return m * n_tokens + b
83
+
84
+
85
+ def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1):
86
+ """
87
+ Stretch a function (given as sampled shifts) so that its final value matches the given terminal value
88
+ using the provided formula.
89
+
90
+ Parameters:
91
+ - shifts (Tensor): The samples of the function to be stretched (PyTorch Tensor).
92
+ - terminal (float): The desired terminal value (value at the last sample).
93
+
94
+ Returns:
95
+ - Tensor: The stretched shifts such that the final value equals `terminal`.
96
+ """
97
+ if shifts.numel() == 0:
98
+ raise ValueError("The 'shifts' tensor must not be empty.")
99
+
100
+ # Ensure terminal value is valid
101
+ if terminal <= 0 or terminal >= 1:
102
+ raise ValueError("The terminal value must be between 0 and 1 (exclusive).")
103
+
104
+ # Transform the shifts using the given formula
105
+ one_minus_z = 1 - shifts
106
+ scale_factor = one_minus_z[-1] / (1 - terminal)
107
+ stretched_shifts = 1 - (one_minus_z / scale_factor)
108
+
109
+ return stretched_shifts
110
+
111
+
112
+ def sd3_resolution_dependent_timestep_shift(
113
+ samples_shape: torch.Size,
114
+ timesteps: Tensor,
115
+ target_shift_terminal: Optional[float] = None,
116
+ ) -> Tensor:
117
+ """
118
+ Shifts the timestep schedule as a function of the generated resolution.
119
+
120
+ In the SD3 paper, the authors empirically how to shift the timesteps based on the resolution of the target images.
121
+ For more details: https://arxiv.org/pdf/2403.03206
122
+
123
+ In Flux they later propose a more dynamic resolution dependent timestep shift, see:
124
+ https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66
125
+
126
+
127
+ Args:
128
+ samples_shape (torch.Size): The samples batch shape (batch_size, channels, height, width) or
129
+ (batch_size, channels, frame, height, width).
130
+ timesteps (Tensor): A batch of timesteps with shape (batch_size,).
131
+ target_shift_terminal (float): The target terminal value for the shifted timesteps.
132
+
133
+ Returns:
134
+ Tensor: The shifted timesteps.
135
+ """
136
+ if len(samples_shape) == 3:
137
+ _, m, _ = samples_shape
138
+ elif len(samples_shape) in [4, 5]:
139
+ m = math.prod(samples_shape[2:])
140
+ else:
141
+ raise ValueError(
142
+ "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
143
+ )
144
+
145
+ shift = get_normal_shift(m)
146
+ time_shifts = time_shift(shift, 1, timesteps)
147
+ if target_shift_terminal is not None: # Stretch the shifts to the target terminal
148
+ time_shifts = strech_shifts_to_terminal(time_shifts, target_shift_terminal)
149
+ return time_shifts
150
+
151
+
152
+ class TimestepShifter(ABC):
153
+ @abstractmethod
154
+ def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor:
155
+ pass
156
+
157
+
158
+ @dataclass
159
+ class RectifiedFlowSchedulerOutput(BaseOutput):
160
+ """
161
+ Output class for the scheduler's step function output.
162
+
163
+ Args:
164
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
165
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
166
+ denoising loop.
167
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
168
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
169
+ `pred_original_sample` can be used to preview progress or for guidance.
170
+ """
171
+
172
+ prev_sample: torch.FloatTensor
173
+ pred_original_sample: Optional[torch.FloatTensor] = None
174
+
175
+
176
+ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
177
+ order = 1
178
+
179
+ @register_to_config
180
+ def __init__(
181
+ self,
182
+ num_train_timesteps=1000,
183
+ shifting: Optional[str] = None,
184
+ base_resolution: int = 32**2,
185
+ target_shift_terminal: Optional[float] = None,
186
+ sampler: Optional[str] = "Uniform",
187
+ shift: Optional[float] = None,
188
+ ):
189
+ super().__init__()
190
+ self.init_noise_sigma = 1.0
191
+ self.num_inference_steps = None
192
+ self.sampler = sampler
193
+ self.shifting = shifting
194
+ self.base_resolution = base_resolution
195
+ self.target_shift_terminal = target_shift_terminal
196
+ self.timesteps = self.sigmas = self.get_initial_timesteps(
197
+ num_train_timesteps, shift=shift
198
+ )
199
+ self.shift = shift
200
+
201
+ def get_initial_timesteps(
202
+ self, num_timesteps: int, shift: Optional[float] = None
203
+ ) -> Tensor:
204
+ if self.sampler == "Uniform":
205
+ return torch.linspace(1, 1 / num_timesteps, num_timesteps)
206
+ elif self.sampler == "LinearQuadratic":
207
+ return linear_quadratic_schedule(num_timesteps)
208
+ elif self.sampler == "Constant":
209
+ assert (
210
+ shift is not None
211
+ ), "Shift must be provided for constant time shift sampler."
212
+ return time_shift(
213
+ shift, 1, torch.linspace(1, 1 / num_timesteps, num_timesteps)
214
+ )
215
+
216
+ def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor:
217
+ if self.shifting == "SD3":
218
+ return sd3_resolution_dependent_timestep_shift(
219
+ samples_shape, timesteps, self.target_shift_terminal
220
+ )
221
+ elif self.shifting == "SimpleDiffusion":
222
+ return simple_diffusion_resolution_dependent_timestep_shift(
223
+ samples_shape, timesteps, self.base_resolution
224
+ )
225
+ return timesteps
226
+
227
+ def set_timesteps(
228
+ self,
229
+ num_inference_steps: Optional[int] = None,
230
+ samples_shape: Optional[torch.Size] = None,
231
+ timesteps: Optional[Tensor] = None,
232
+ device: Union[str, torch.device] = None,
233
+ ):
234
+ """
235
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
236
+ If `timesteps` are provided, they will be used instead of the scheduled timesteps.
237
+
238
+ Args:
239
+ num_inference_steps (`int` *optional*): The number of diffusion steps used when generating samples.
240
+ samples_shape (`torch.Size` *optional*): The samples batch shape, used for shifting.
241
+ timesteps ('torch.Tensor' *optional*): Specific timesteps to use instead of scheduled timesteps.
242
+ device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
243
+ """
244
+ if timesteps is not None and num_inference_steps is not None:
245
+ raise ValueError(
246
+ "You cannot provide both `timesteps` and `num_inference_steps`."
247
+ )
248
+ if timesteps is None:
249
+ num_inference_steps = min(
250
+ self.config.num_train_timesteps, num_inference_steps
251
+ )
252
+ timesteps = self.get_initial_timesteps(
253
+ num_inference_steps, shift=self.shift
254
+ ).to(device)
255
+ timesteps = self.shift_timesteps(samples_shape, timesteps)
256
+ else:
257
+ timesteps = torch.Tensor(timesteps).to(device)
258
+ num_inference_steps = len(timesteps)
259
+ self.timesteps = timesteps
260
+ self.num_inference_steps = num_inference_steps
261
+ self.sigmas = self.timesteps
262
+
263
+ @staticmethod
264
+ def from_pretrained(pretrained_model_path: Union[str, os.PathLike]):
265
+ pretrained_model_path = Path(pretrained_model_path)
266
+ if pretrained_model_path.is_file():
267
+ comfy_single_file_state_dict = {}
268
+ with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
269
+ metadata = f.metadata()
270
+ for k in f.keys():
271
+ comfy_single_file_state_dict[k] = f.get_tensor(k)
272
+ configs = json.loads(metadata["config"])
273
+ config = configs["scheduler"]
274
+ del comfy_single_file_state_dict
275
+
276
+ elif pretrained_model_path.is_dir():
277
+ diffusers_noise_scheduler_config_path = (
278
+ pretrained_model_path / "scheduler" / "scheduler_config.json"
279
+ )
280
+
281
+ with open(diffusers_noise_scheduler_config_path, "r") as f:
282
+ scheduler_config = json.load(f)
283
+ hashable_config = make_hashable_key(scheduler_config)
284
+ if hashable_config in diffusers_and_ours_config_mapping:
285
+ config = diffusers_and_ours_config_mapping[hashable_config]
286
+ return RectifiedFlowScheduler.from_config(config)
287
+
288
+ def scale_model_input(
289
+ self, sample: torch.FloatTensor, timestep: Optional[int] = None
290
+ ) -> torch.FloatTensor:
291
+ # pylint: disable=unused-argument
292
+ """
293
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
294
+ current timestep.
295
+
296
+ Args:
297
+ sample (`torch.FloatTensor`): input sample
298
+ timestep (`int`, optional): current timestep
299
+
300
+ Returns:
301
+ `torch.FloatTensor`: scaled input sample
302
+ """
303
+ return sample
304
+
305
+ def step(
306
+ self,
307
+ model_output: torch.FloatTensor,
308
+ timestep: torch.FloatTensor,
309
+ sample: torch.FloatTensor,
310
+ return_dict: bool = True,
311
+ stochastic_sampling: Optional[bool] = False,
312
+ **kwargs,
313
+ ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
314
+ """
315
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
316
+ process from the learned model outputs (most often the predicted noise).
317
+ z_{t_1} = z_t - \Delta_t * v
318
+ The method finds the next timestep that is lower than the input timestep(s) and denoises the latents
319
+ to that level. The input timestep(s) are not required to be one of the predefined timesteps.
320
+
321
+ Args:
322
+ model_output (`torch.FloatTensor`):
323
+ The direct output from learned diffusion model - the velocity,
324
+ timestep (`float`):
325
+ The current discrete timestep in the diffusion chain (global or per-token).
326
+ sample (`torch.FloatTensor`):
327
+ A current latent tokens to be de-noised.
328
+ return_dict (`bool`, *optional*, defaults to `True`):
329
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
330
+ stochastic_sampling (`bool`, *optional*, defaults to `False`):
331
+ Whether to use stochastic sampling for the sampling process.
332
+
333
+ Returns:
334
+ [`~schedulers.scheduling_utils.RectifiedFlowSchedulerOutput`] or `tuple`:
335
+ If return_dict is `True`, [`~schedulers.rf_scheduler.RectifiedFlowSchedulerOutput`] is returned,
336
+ otherwise a tuple is returned where the first element is the sample tensor.
337
+ """
338
+ if self.num_inference_steps is None:
339
+ raise ValueError(
340
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
341
+ )
342
+ t_eps = 1e-6 # Small epsilon to avoid numerical issues in timestep values
343
+
344
+ timesteps_padded = torch.cat(
345
+ [self.timesteps, torch.zeros(1, device=self.timesteps.device)]
346
+ )
347
+
348
+ # Find the next lower timestep(s) and compute the dt from the current timestep(s)
349
+ if timestep.ndim == 0:
350
+ # Global timestep case
351
+ lower_mask = timesteps_padded < timestep - t_eps
352
+ lower_timestep = timesteps_padded[lower_mask][0] # Closest lower timestep
353
+ dt = timestep - lower_timestep
354
+
355
+ else:
356
+ # Per-token case
357
+ assert timestep.ndim == 2
358
+ lower_mask = timesteps_padded[:, None, None] < timestep[None] - t_eps
359
+ lower_timestep = lower_mask * timesteps_padded[:, None, None]
360
+ lower_timestep, _ = lower_timestep.max(dim=0)
361
+ dt = (timestep - lower_timestep)[..., None]
362
+
363
+ # Compute previous sample
364
+ if stochastic_sampling:
365
+ x0 = sample - timestep[..., None] * model_output
366
+ next_timestep = timestep[..., None] - dt
367
+ prev_sample = self.add_noise(x0, torch.randn_like(sample), next_timestep)
368
+ else:
369
+ prev_sample = sample - dt * model_output
370
+
371
+ if not return_dict:
372
+ return (prev_sample,)
373
+
374
+ return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
375
+
376
+ def add_noise(
377
+ self,
378
+ original_samples: torch.FloatTensor,
379
+ noise: torch.FloatTensor,
380
+ timesteps: torch.FloatTensor,
381
+ ) -> torch.FloatTensor:
382
+ sigmas = timesteps
383
+ sigmas = append_dims(sigmas, original_samples.ndim)
384
+ alphas = 1 - sigmas
385
+ noisy_samples = alphas * original_samples + sigmas * noise
386
+ return noisy_samples