# Uploaded by studyOverflow using the upload-large-folder tool (commit b171568, verified).
#This code file is from [https://github.com/hao-ai-lab/FastVideo], which is licensed under Apache License 2.0.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging
from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class PCMFMSchedulerOutput(BaseOutput):
    """Container for the result of a single `PCMFMScheduler.step` call."""

    # Sample after one scheduler step; `step` casts it to the dtype of the
    # model output before returning it.
    prev_sample: torch.FloatTensor
def extract_into_tensor(a, t, x_shape):
    """Look up per-sample values in `a` and shape them for broadcasting.

    Args:
        a: Tensor of values that is gathered along its last dimension.
        t: Index tensor; its leading dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be combined with. Only
            its length (the rank) is used.

    Returns:
        The gathered values reshaped to `(batch, 1, ..., 1)` with
        `len(x_shape)` dimensions in total.
    """
    batch_size, *_unused = t.shape
    picked = a.gather(-1, t)
    # One singleton axis for every non-batch dimension of the target shape.
    trailing_ones = (1, ) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *trailing_ones)
class PCMFMScheduler(SchedulerMixin, ConfigMixin):
    """Flow-matching Euler scheduler that samples on a fixed grid of PCM sigmas.

    At construction a table of `num_train_timesteps` sigmas is built and
    `pcm_timesteps` of its entries are kept in `self.sigmas`. `set_timesteps`
    then selects `num_inference_steps` of those entries (stored in
    `self.sigmas_`, with a trailing zero) and `step` performs one Euler update
    between consecutive entries.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        pcm_timesteps: int = 50,
        linear_quadratic=False,
        linear_quadratic_threshold=0.025,
        linear_range=0.5,
    ):
        """Build the PCM sigma table.

        Args:
            num_train_timesteps: Length of the full sigma table.
            shift: Warp factor applied to the linear sigma table as
                `shift * s / (1 + (shift - 1) * s)`.
            pcm_timesteps: Number of entries kept from the full table.
            linear_quadratic: If true, the full table comes from
                `linear_quadratic_schedule` instead of a linear ramp.
            linear_quadratic_threshold: Forwarded to
                `linear_quadratic_schedule`.
            linear_range: Fraction of `num_train_timesteps` passed to
                `linear_quadratic_schedule` as the number of linear steps.
        """
        if linear_quadratic:
            linear_steps = int(num_train_timesteps * linear_range)
            sigmas = linear_quadratic_schedule(num_train_timesteps,
                                               linear_quadratic_threshold,
                                               linear_steps)
            sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
        else:
            # Linear ramp N, N-1, ..., 1 divided by N: sigmas run from 1 down
            # to 1/N.
            timesteps = np.linspace(1,
                                    num_train_timesteps,
                                    num_train_timesteps,
                                    dtype=np.float32)[::-1].copy()
            timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
            sigmas = timesteps / num_train_timesteps
            # NOTE(review): the shift warp is assumed to belong to the linear
            # branch only; confirm against the upstream FastVideo source.
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        # Indices k * (N // pcm_timesteps) - 1 for k = 1..pcm_timesteps.
        self.euler_timesteps = (np.arange(1, pcm_timesteps + 1) *
                                (num_train_timesteps //
                                 pcm_timesteps)).round().astype(np.int64) - 1
        # Index into the reversed table, then reverse back so that the kept
        # entries have the same orientation as the full table.
        self.sigmas = sigmas.numpy()[::-1][self.euler_timesteps]
        self.sigmas = torch.from_numpy((self.sigmas[::-1].copy()))
        self.timesteps = self.sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = self.sigmas.to(
            "cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blend `sample` with `noise` at the sigma of `timestep`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep; used to initialise the step index when it is not set yet.
            noise (`torch.FloatTensor`):
                The noise to mix in. Required despite the `None` default.

        Returns:
            `torch.FloatTensor`:
                `sigma * noise + (1 - sigma) * sample`.

        Raises:
            `TypeError`: If `noise` is `None`.
        """
        if noise is None:
            raise TypeError(
                "`noise` must be a tensor; `scale_noise` cannot mix `None` into the sample."
            )
        if self.step_index is None:
            self._init_step_index(timestep)

        # The step index is a position in `self.timesteps`. Once
        # `set_timesteps` has run, those timesteps correspond to `self.sigmas_`
        # (the inference schedule), not to the full PCM table in `self.sigmas`.
        schedule = getattr(self, "sigmas_", None)
        if schedule is None:
            schedule = self.sigmas
        sigma = schedule[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample
        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self,
                      num_inference_steps: int,
                      device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        # Evenly spaced positions in the PCM table, rounded down to integers.
        inference_indices = np.linspace(0,
                                        self.config.pcm_timesteps,
                                        num=num_inference_steps,
                                        endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = torch.from_numpy(inference_indices).long()

        self.sigmas_ = self.sigmas[inference_indices]
        timesteps = self.sigmas_ * self.config.num_train_timesteps
        self.timesteps = timesteps.to(device=device)
        # Trailing zero so that the last `step` call has a sigma to move to.
        self.sigmas_ = torch.cat(
            [self.sigmas_,
             torch.zeros(1, device=self.sigmas_.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of `timestep` in `schedule_timesteps`.

        Args:
            timestep: Value to look up by exact equality.
            schedule_timesteps: Tensor to search; defaults to `self.timesteps`.

        Returns:
            The matching index as a Python int (the second match when the
            value occurs more than once).
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise the step index from `begin_index`, or by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[PCMFMSchedulerOutput, Tuple]:
        """
        Advance the sample by one Euler step from the current sigma to the next entry of the inference schedule.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep; must be one of `scheduler.timesteps`, not an integer index.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                Accepted for interface compatibility; not used by this method.
            return_dict (`bool`):
                Whether or not to return a [`PCMFMSchedulerOutput`] or tuple.

        Returns:
            [`PCMFMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`PCMFMSchedulerOutput`] is returned, otherwise a tuple is returned
                where the first element is the sample tensor.

        Raises:
            `ValueError`: If `timestep` is an `int` or an integer tensor.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError((
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `PCMFMScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."), )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Do the update in float32; the result is cast back below.
        sample = sample.to(torch.float32)

        sigma = self.sigmas_[self.step_index]

        denoised = sample - model_output * sigma
        derivative = (sample - denoised) / sigma

        # Signed distance to the next sigma of the inference schedule.
        dt = self.sigmas_[self.step_index + 1] - sigma
        prev_sample = sample + derivative * dt
        prev_sample = prev_sample.to(model_output.dtype)

        self._step_index += 1

        if not return_dict:
            return (prev_sample, )

        return PCMFMSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
class EulerSolver:
    """Euler updates on a sub-sampled grid of sigmas.

    The grid keeps `euler_timesteps` entries of the `sigmas` array; for each
    kept entry the solver also stores the entry one grid position earlier
    (`sigmas_prev`), with `sigmas[0]` used for the first position.
    """

    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        """Pre-compute the grid indices and the sigma lookup tables.

        Args:
            sigmas: Array indexed with NumPy integer arrays and converted with
                `torch.from_numpy` (presumably a 1-D NumPy array of length
                `timesteps`; confirm against the caller).
            timesteps: Length of the full schedule.
            euler_timesteps: Number of grid points to keep.
        """
        self.step_ratio = timesteps // euler_timesteps

        # Grid indices k * step_ratio - 1 for k = 1..euler_timesteps.
        multiples = np.arange(1, euler_timesteps + 1) * self.step_ratio
        grid_idx = multiples.round().astype(np.int64) - 1
        # Same indices shifted by one position, starting from index 0.
        grid_idx_prev = np.asarray([0] + grid_idx[:-1].tolist())

        grid_sigmas = sigmas[grid_idx]
        # Built from a Python list on purpose: the resulting dtype is whatever
        # NumPy infers for the mixed list, which callers may rely on.
        grid_sigmas_prev = np.asarray(
            [sigmas[0]] + sigmas[grid_idx[:-1]].tolist()
        )  # either use sigma0 or 0

        self.euler_timesteps = torch.from_numpy(grid_idx).long()
        self.euler_timesteps_prev = torch.from_numpy(grid_idx_prev).long()
        self.sigmas = torch.from_numpy(grid_sigmas)
        self.sigmas_prev = torch.from_numpy(grid_sigmas_prev)

    def to(self, device):
        """Move all lookup tensors to `device` and return `self`."""
        for attr_name in ("euler_timesteps", "euler_timesteps_prev", "sigmas",
                          "sigmas_prev"):
            setattr(self, attr_name, getattr(self, attr_name).to(device))
        return self

    def euler_step(self, sample, model_pred, timestep_index):
        """Take one Euler step from each grid position to the one before it.

        Args:
            sample: Current sample tensor.
            model_pred: Model prediction multiplied by the sigma difference.
            timestep_index: Per-sample grid positions used to look up sigmas.

        Returns:
            `sample + (sigma_prev - sigma) * model_pred`.
        """
        target_shape = model_pred.shape
        sigma_now = extract_into_tensor(self.sigmas, timestep_index,
                                        target_shape)
        sigma_before = extract_into_tensor(self.sigmas_prev, timestep_index,
                                           target_shape)
        return sample + (sigma_before - sigma_now) * model_pred

    def euler_style_multiphase_pred(
        self,
        sample,
        model_pred,
        timestep_index,
        multiphase,
        is_target=False,
    ):
        """Jump from each grid position to the start of the phase containing it.

        The grid is split into `multiphase` phases whose start positions are
        `floor(linspace(0, len(grid), multiphase, endpoint=False))`.

        Args:
            sample: Current sample tensor.
            model_pred: Model prediction multiplied by the sigma difference.
            timestep_index: Per-sample grid positions (assumed non-negative so
                that the first phase start always qualifies; confirm with the
                caller).
            multiphase: Number of phases.
            is_target: If true, the starting sigma is read from `sigmas_prev`
                instead of `sigmas`.

        Returns:
            A tuple `(x_prev, timestep_index_end)`: the updated sample and the
            phase-start position chosen for each batch element.
        """
        phase_starts = np.floor(
            np.linspace(0,
                        len(self.euler_timesteps),
                        num=multiphase,
                        endpoint=False)).astype(np.int64)
        phase_starts = (torch.from_numpy(phase_starts).long().to(
            self.euler_timesteps.device))
        num_phases = phase_starts.size(0)

        # Compare every batch position against every phase start.
        index_grid = timestep_index.unsqueeze(1).expand(-1, num_phases)
        reached = index_grid >= phase_starts
        # argmax on the flipped mask counts from the end to the last phase
        # start that has been reached.
        offset_from_end = reached.flip(dims=[1]).long().argmax(dim=1)
        timestep_index_end = phase_starts[num_phases - 1 - offset_from_end]

        start_table = self.sigmas_prev if is_target else self.sigmas
        sigma = extract_into_tensor(start_table, timestep_index, sample.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index_end,
                                         sample.shape)
        x_prev = sample + (sigma_prev - sigma) * model_pred
        return x_prev, timestep_index_end