aliensmn's picture
Mirror from https://github.com/kijai/ComfyUI-WanVideoWrapper
cf812a0 verified
import torch
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .basic_flowmatch import FlowMatchScheduler
from .flowmatch_pusa import FlowMatchSchedulerPusa
from .flowmatch_res_multistep import FlowMatchSchedulerResMultistep
from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
from .fm_sa_ode import FlowMatchSAODEStableScheduler
from ...utils import log
try:
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DEISMultistepScheduler
except ImportError:
FlowMatchEulerDiscreteScheduler = None
DEISMultistepScheduler = None
# Scheduler names exposed by this module. A "/beta" suffix selects the
# beta-sigma variant of the same scheduler (see `use_beta_sigmas` in
# get_scheduler).
# NOTE(review): "multitalk" has no dedicated branch in get_scheduler --
# presumably it is handled by the caller; confirm.
scheduler_list = [
    "unipc",
    "unipc/beta",
    "dpm++",
    "dpm++/beta",
    "dpm++_sde",
    "dpm++_sde/beta",
    "euler",
    "euler/beta",
    "deis",
    "lcm",
    "lcm/beta",
    "res_multistep",
    "flowmatch_causvid",
    "flowmatch_distill",
    "flowmatch_pusa",
    "multitalk",
    "sa_ode_stable",
]
def _apply_custom_sigmas(sample_scheduler, sigmas, device):
    """Install a caller-supplied sigma schedule on ``sample_scheduler``.

    Moves ``sigmas`` to ``device``, derives int64 timesteps as ``sigma * 1000``
    for every sigma except the last one, and updates ``num_inference_steps``
    to the resulting number of timesteps.
    """
    sample_scheduler.sigmas = sigmas.to(device)
    sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device)
    sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps)


def _sigma_list_or_none(sigmas):
    """Return ``sigmas`` without its last entry as a plain list, or ``None``."""
    return sigmas[:-1].tolist() if sigmas is not None else None


def _require_optional_scheduler(scheduler_cls, class_name, scheduler):
    """Raise ``ImportError`` if an optional diffusers scheduler class is missing."""
    if scheduler_cls is None:
        raise ImportError(
            f"Scheduler '{scheduler}' requires diffusers.schedulers.{class_name}, "
            "which could not be imported"
        )


def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, **kwargs):
    """Build a sampling scheduler and trim its schedule to a step window.

    Args:
        scheduler: Scheduler name. Most names are matched by substring
            (e.g. any name containing ``'unipc'``); ``'euler'``,
            ``'euler/beta'``, ``'deis'`` and ``'res_multistep'`` are matched
            exactly. Names containing ``'beta'`` enable ``use_beta_sigmas``
            for unipc/dpm; ``'euler/beta'`` and ``'lcm/beta'`` do so for
            euler/lcm.
        steps: Number of inference steps handed to the scheduler.
        start_step: Start of the window. An ``int`` greater than 0 is used as
            the first timestep index; a ``float`` selects the first index whose
            sigma is ``<= start_step``.
        end_step: End of the window. An ``int`` other than ``-1`` makes
            ``end_step - 1`` the last index; ``-1`` keeps the schedule to its
            end; a ``float`` selects the last index whose sigma is
            ``>= end_step``.
        shift: Shift value forwarded to the scheduler constructor.
        device: Device for the schedule tensors.
        transformer_dim: Only read by ``flowmatch_causvid``: ``5120`` selects
            the 9-entry timestep table, any other value the 4-step table.
        flowedit_args: When truthy and the scheduler is euler, timesteps come
            from ``retrieve_timesteps`` with ``get_sampling_sigmas(steps, shift)``
            and ``sigmas`` is not used.
        denoise_strength: Forwarded to ``flowmatch_pusa`` and
            ``res_multistep`` as ``denoising_strength``.
        sigmas: Optional custom sigma schedule (must support ``.to``, slicing
            and ``.tolist()``). Rejected by ``flowmatch_causvid`` and
            ``flowmatch_distill``; not read by ``deis``.
        log_timesteps: Log the full and the selected schedule via ``log.info``.
        **kwargs: Forwarded to ``FlowMatchSAODEStableScheduler`` only.

    Returns:
        Tuple ``(sample_scheduler, timesteps, start_idx, end_idx)`` where
        ``timesteps`` is the selected slice and ``sample_scheduler.sigmas`` has
        been sliced to match (one entry longer); the untrimmed sigmas are kept
        on ``sample_scheduler.full_sigmas``.

    Raises:
        ValueError: Unknown scheduler name, unsupported step count for a fixed
            schedule, or an invalid start/end combination.
        NotImplementedError: Custom ``sigmas`` passed to a scheduler that does
            not support them.
        ImportError: The scheduler needs a diffusers class that failed to
            import.
    """
    timesteps = None
    if 'unipc' in scheduler:
        sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
        if sigmas is None:
            sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
        else:
            _apply_custom_sigmas(sample_scheduler, sigmas, device)
    elif scheduler in ('euler/beta', 'euler'):
        _require_optional_scheduler(FlowMatchEulerDiscreteScheduler, "FlowMatchEulerDiscreteScheduler", scheduler)
        sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
        if flowedit_args:  # seems to work better
            timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift))
        else:
            sample_scheduler.set_timesteps(steps, device=device, sigmas=_sigma_list_or_none(sigmas))
    elif 'dpm' in scheduler:
        algorithm_type = "sde-dpmsolver++" if 'sde' in scheduler else "dpmsolver++"
        sample_scheduler = FlowDPMSolverMultistepScheduler(shift=shift, algorithm_type=algorithm_type)
        if sigmas is None:
            sample_scheduler.set_timesteps(steps, device=device, use_beta_sigmas=('beta' in scheduler))
        else:
            _apply_custom_sigmas(sample_scheduler, sigmas, device)
    elif scheduler == 'deis':
        _require_optional_scheduler(DEISMultistepScheduler, "DEISMultistepScheduler", scheduler)
        sample_scheduler = DEISMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction", flow_shift=shift)
        sample_scheduler.set_timesteps(steps, device=device)
        # Overwrite the final sigma with a tiny non-zero value.
        sample_scheduler.sigmas[-1] = 1e-6
    elif 'lcm' in scheduler:
        sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
        sample_scheduler.set_timesteps(steps, device=device, sigmas=_sigma_list_or_none(sigmas))
    elif 'flowmatch_causvid' in scheduler:
        if sigmas is not None:
            raise NotImplementedError("This scheduler does not support custom sigmas")
        if transformer_dim == 5120:
            denoising_list = [999, 934, 862, 756, 603, 410, 250, 140, 74]
        else:
            if steps != 4:
                raise ValueError("CausVid 1.3B schedule is only for 4 steps")
            denoising_list = [1000, 750, 500, 250]
        sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True)
        # Fixed timestep table, truncated to the requested number of steps.
        sample_scheduler.timesteps = torch.tensor(denoising_list)[:steps].to(device)
        sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)])
    elif 'flowmatch_distill' in scheduler:
        if sigmas is not None:
            raise NotImplementedError("This scheduler does not support custom sigmas")
        # Reject unsupported step counts before building the 1000-step table.
        if steps != 4:
            raise ValueError("This scheduler is only for 4 steps")
        sample_scheduler = FlowMatchScheduler(shift=shift, sigma_min=0.0, extra_one_step=True)
        sample_scheduler.set_timesteps(1000, training=True)
        denoising_step_list = torch.tensor([999, 750, 500, 250], dtype=torch.long)
        # Append a trailing 0 so that index 1000 is valid, then look up the
        # scheduler's timestep at position (1000 - step) for each listed step.
        temp_timesteps = torch.cat((sample_scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
        denoising_step_list = temp_timesteps[1000 - denoising_step_list]
        sample_scheduler.timesteps = denoising_step_list[:steps].clone().detach().to(device)
        sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)])
    elif 'flowmatch_pusa' in scheduler:
        sample_scheduler = FlowMatchSchedulerPusa(shift=shift, sigma_min=0.0, extra_one_step=True)
        sample_scheduler.set_timesteps(steps + 1, denoising_strength=denoise_strength, shift=shift,
                                       sigmas=_sigma_list_or_none(sigmas))
    elif scheduler == 'res_multistep':
        sample_scheduler = FlowMatchSchedulerResMultistep(shift=shift)
        sample_scheduler.set_timesteps(steps, denoising_strength=denoise_strength, sigmas=_sigma_list_or_none(sigmas))
    elif "sa_ode_stable" in scheduler:
        sample_scheduler = FlowMatchSAODEStableScheduler(shift=shift, **kwargs)
        sample_scheduler.set_timesteps(steps, device=device, sigmas=_sigma_list_or_none(sigmas))
    else:
        # Without this branch an unknown name surfaced later as an
        # UnboundLocalError on `sample_scheduler`.
        raise ValueError(f"Unsupported scheduler '{scheduler}'")

    if timesteps is None:
        timesteps = sample_scheduler.timesteps

    # Integer bounds are step indices (start must come strictly before end).
    # Non-integer bounds are sigma thresholds, where start must be strictly
    # greater than end (presumably because sigmas decrease along the schedule).
    # NOTE(review): mixing an int bound with a float bound compares an index
    # against a sigma here -- confirm whether callers ever do that.
    if (isinstance(start_step, int) and end_step != -1 and start_step >= end_step) or (not isinstance(start_step, int) and start_step != -1 and end_step >= start_step):
        raise ValueError("start_step must be less than end_step")

    # Determine start and end indices for slicing
    start_idx = 0
    end_idx = len(timesteps) - 1

    if log_timesteps:
        log.info("------- Scheduler info -------")
        log.info(f"Total timesteps: {timesteps}")

    if isinstance(start_step, float):
        # First position whose sigma has dropped to the threshold or below.
        idxs = (sample_scheduler.sigmas <= start_step).nonzero(as_tuple=True)[0]
        if len(idxs) > 0:
            start_idx = idxs[0].item()
    elif isinstance(start_step, int):
        if start_step > 0:
            start_idx = start_step

    if isinstance(end_step, float):
        # Last position whose sigma is still at or above the threshold.
        # NOTE(review): sigmas may be one entry longer than timesteps, so this
        # index can equal len(timesteps); the slice below tolerates that, but
        # the returned end_idx is not clamped -- confirm callers expect this.
        idxs = (sample_scheduler.sigmas >= end_step).nonzero(as_tuple=True)[0]
        if len(idxs) > 0:
            end_idx = idxs[-1].item()
    elif isinstance(end_step, int):
        if end_step != -1:
            end_idx = end_step - 1

    # Slice timesteps and sigmas once, based on indices
    timesteps = timesteps[start_idx:end_idx + 1]
    sample_scheduler.full_sigmas = sample_scheduler.sigmas.clone()
    sample_scheduler.sigmas = sample_scheduler.sigmas[start_idx:start_idx + len(timesteps) + 1]  # always one longer

    if log_timesteps:
        log.info(f"Using timesteps: {timesteps}")
        log.info(f"Using sigmas: {sample_scheduler.sigmas}")
        log.info("------------------------------")

    if hasattr(sample_scheduler, 'timesteps'):
        sample_scheduler.timesteps = timesteps

    return sample_scheduler, timesteps, start_idx, end_idx