File size: 7,640 Bytes
cf812a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | import torch
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .basic_flowmatch import FlowMatchScheduler
from .flowmatch_pusa import FlowMatchSchedulerPusa
from .flowmatch_res_multistep import FlowMatchSchedulerResMultistep
from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
from .fm_sa_ode import FlowMatchSAODEStableScheduler
from ...utils import log
try:
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DEISMultistepScheduler
except ImportError:
FlowMatchEulerDiscreteScheduler = None
DEISMultistepScheduler = None
# Scheduler names offered by this module, in display order.
# NOTE(review): get_scheduler in this file has no branch for "multitalk" --
# presumably that mode is handled by the caller; verify before passing it in.
scheduler_list = [
    # UniPC multistep
    "unipc",
    "unipc/beta",
    # DPM-Solver++ (ODE and SDE variants)
    "dpm++",
    "dpm++/beta",
    "dpm++_sde",
    "dpm++_sde/beta",
    # Euler flow matching
    "euler",
    "euler/beta",
    # DEIS
    "deis",
    # LCM
    "lcm",
    "lcm/beta",
    # Remaining flow-match variants and special modes
    "res_multistep",
    "flowmatch_causvid",
    "flowmatch_distill",
    "flowmatch_pusa",
    "multitalk",
    "sa_ode_stable",
]
def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, **kwargs):
    """Build the sampling scheduler selected by ``scheduler`` and slice its schedule.

    Args:
        scheduler: Scheduler name. Most families are matched by substring
            ('unipc', 'dpm', 'lcm', 'flowmatch_causvid', 'flowmatch_distill',
            'flowmatch_pusa', 'sa_ode_stable'); 'euler', 'euler/beta', 'deis'
            and 'res_multistep' are matched exactly. A 'beta' suffix enables
            ``use_beta_sigmas`` where the branch supports it.
        steps: Number of inference steps handed to the scheduler
            ('flowmatch_pusa' receives ``steps + 1``).
        start_step: Where to start. An int > 0 is used directly as the first
            timestep index; a float selects the first index whose sigma is
            <= that value.
        end_step: Where to stop. An int other than -1 is an exclusive end
            index; a float selects the last index whose sigma is >= that
            value. -1 keeps the schedule to its end.
        shift: Shift value forwarded to the scheduler constructor.
        device: Device for the timestep/sigma tensors.
        transformer_dim: Only read by 'flowmatch_causvid': 5120 selects the
            9-entry timestep list, any other value selects the 4-entry list
            and requires ``steps == 4``.
        flowedit_args: Only its truthiness is checked, in the euler branch,
            where it switches to ``retrieve_timesteps`` with
            ``get_sampling_sigmas(steps, shift)`` (custom ``sigmas`` are not
            used on that path).
        denoise_strength: Forwarded as ``denoising_strength`` to the
            'flowmatch_pusa' and 'res_multistep' schedulers.
        sigmas: Optional custom sigmas; assumed to be a torch tensor since it
            is used via ``.to(device)`` and ``[:-1].tolist()``.
        log_timesteps: When true, log the full and the sliced schedule.
        **kwargs: Forwarded to ``FlowMatchSAODEStableScheduler`` only.

    Returns:
        Tuple ``(sample_scheduler, timesteps, start_idx, end_idx)``.
        ``timesteps`` is the sliced schedule, ``start_idx``/``end_idx`` are the
        first and last kept indices as computed (``end_idx`` is not clamped to
        the schedule length). The scheduler's ``sigmas`` are sliced to match
        and the unsliced copy is stored on ``sample_scheduler.full_sigmas``.

    Raises:
        ValueError: Unknown scheduler name, unsupported step count for the
            fixed 4-step schedules, or an inconsistent start/end pair.
        NotImplementedError: Custom ``sigmas`` passed to 'flowmatch_causvid'
            or 'flowmatch_distill'.
        ImportError: 'euler'/'deis' requested but the diffusers scheduler
            classes could not be imported.
    """

    def _sigma_list():
        # Custom sigmas as a plain list without the final entry, or None when absent.
        return sigmas[:-1].tolist() if sigmas is not None else None

    def _install_custom_sigmas(target):
        # Bypass set_timesteps: install the given sigmas and derive integer timesteps.
        target.sigmas = sigmas.to(device)
        target.timesteps = (target.sigmas[:-1] * 1000).to(torch.int64).to(device)
        target.num_inference_steps = len(target.timesteps)

    timesteps = None
    if 'unipc' in scheduler:
        sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
        if sigmas is None:
            sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
        else:
            _install_custom_sigmas(sample_scheduler)
    elif scheduler in ['euler/beta', 'euler']:
        if FlowMatchEulerDiscreteScheduler is None:
            # The module-level import fallback sets the class to None when diffusers is missing.
            raise ImportError(f"Scheduler '{scheduler}' requires diffusers (FlowMatchEulerDiscreteScheduler could not be imported)")
        sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
        if flowedit_args:  # seems to work better
            timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift))
        else:
            sample_scheduler.set_timesteps(steps, device=device, sigmas=_sigma_list())
    elif 'dpm' in scheduler:
        algorithm_type = "sde-dpmsolver++" if 'sde' in scheduler else "dpmsolver++"
        sample_scheduler = FlowDPMSolverMultistepScheduler(shift=shift, algorithm_type=algorithm_type)
        if sigmas is None:
            sample_scheduler.set_timesteps(steps, device=device, use_beta_sigmas=('beta' in scheduler))
        else:
            _install_custom_sigmas(sample_scheduler)
    elif scheduler == 'deis':
        if DEISMultistepScheduler is None:
            raise ImportError("Scheduler 'deis' requires diffusers (DEISMultistepScheduler could not be imported)")
        sample_scheduler = DEISMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction", flow_shift=shift)
        sample_scheduler.set_timesteps(steps, device=device)
        # Overwrite the final sigma with a tiny positive value.
        sample_scheduler.sigmas[-1] = 1e-6
    elif 'lcm' in scheduler:
        sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
        sample_scheduler.set_timesteps(steps, device=device, sigmas=_sigma_list())
    elif 'flowmatch_causvid' in scheduler:
        if sigmas is not None:
            raise NotImplementedError("This scheduler does not support custom sigmas")
        if transformer_dim == 5120:
            denoising_list = [999, 934, 862, 756, 603, 410, 250, 140, 74]
        else:
            if steps != 4:
                raise ValueError("CausVid 1.3B schedule is only for 4 steps")
            denoising_list = [1000, 750, 500, 250]
        sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True)
        # Fixed timestep table, truncated to the requested number of steps.
        sample_scheduler.timesteps = torch.tensor(denoising_list)[:steps].to(device)
        sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)])
    elif 'flowmatch_distill' in scheduler:
        if sigmas is not None:
            raise NotImplementedError("This scheduler does not support custom sigmas")
        sample_scheduler = FlowMatchScheduler(shift=shift, sigma_min=0.0, extra_one_step=True)
        sample_scheduler.set_timesteps(1000, training=True)
        denoising_step_list = torch.tensor([999, 750, 500, 250], dtype=torch.long)
        # Look the four fixed steps up in the 1000-step schedule (with a trailing 0 appended).
        temp_timesteps = torch.cat((sample_scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
        denoising_step_list = temp_timesteps[1000 - denoising_step_list]
        if steps != 4:
            raise ValueError("This scheduler is only for 4 steps")
        sample_scheduler.timesteps = denoising_step_list[:steps].clone().detach().to(device)
        sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)])
    elif 'flowmatch_pusa' in scheduler:
        sample_scheduler = FlowMatchSchedulerPusa(shift=shift, sigma_min=0.0, extra_one_step=True)
        sample_scheduler.set_timesteps(steps + 1, denoising_strength=denoise_strength, shift=shift, sigmas=_sigma_list())
    elif scheduler == 'res_multistep':
        sample_scheduler = FlowMatchSchedulerResMultistep(shift=shift)
        sample_scheduler.set_timesteps(steps, denoising_strength=denoise_strength, sigmas=_sigma_list())
    elif "sa_ode_stable" in scheduler:
        sample_scheduler = FlowMatchSAODEStableScheduler(shift=shift, **kwargs)
        sample_scheduler.set_timesteps(steps, device=device, sigmas=_sigma_list())
    else:
        # Previously fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unknown scheduler: {scheduler!r}")

    if timesteps is None:
        timesteps = sample_scheduler.timesteps

    # An int start_step is compared against end_step as a step index (end_step == -1
    # disables the check). Any other start_step is treated as a sigma threshold, for
    # which end_step must be strictly lower (sigmas are presumably descending).
    int_range_invalid = isinstance(start_step, int) and end_step != -1 and start_step >= end_step
    sigma_range_invalid = not isinstance(start_step, int) and start_step != -1 and end_step >= start_step
    if int_range_invalid or sigma_range_invalid:
        raise ValueError("start_step must be less than end_step")

    # Defaults: keep the whole schedule.
    start_idx = 0
    end_idx = len(timesteps) - 1

    if log_timesteps:
        log.info("------- Scheduler info -------")
        log.info(f"Total timesteps: {timesteps}")

    if isinstance(start_step, float):
        # First position whose sigma has dropped to start_step or below.
        idxs = (sample_scheduler.sigmas <= start_step).nonzero(as_tuple=True)[0]
        if len(idxs) > 0:
            start_idx = idxs[0].item()
    elif isinstance(start_step, int):
        if start_step > 0:
            start_idx = start_step

    if isinstance(end_step, float):
        # Last position whose sigma is still at or above end_step.
        idxs = (sample_scheduler.sigmas >= end_step).nonzero(as_tuple=True)[0]
        if len(idxs) > 0:
            end_idx = idxs[-1].item()
    elif isinstance(end_step, int):
        if end_step != -1:
            end_idx = end_step - 1

    # Slice timesteps and sigmas once, based on the resolved indices.
    timesteps = timesteps[start_idx:end_idx + 1]
    sample_scheduler.full_sigmas = sample_scheduler.sigmas.clone()
    # Sigmas are always kept one entry longer than the timesteps.
    sample_scheduler.sigmas = sample_scheduler.sigmas[start_idx:start_idx + len(timesteps) + 1]

    if log_timesteps:
        log.info(f"Using timesteps: {timesteps}")
        log.info(f"Using sigmas: {sample_scheduler.sigmas}")
        log.info("------------------------------")

    if hasattr(sample_scheduler, 'timesteps'):
        sample_scheduler.timesteps = timesteps

    return sample_scheduler, timesteps, start_idx, end_idx