File size: 6,579 Bytes
ac2243f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | import copy
import inspect
from typing import Any, List, Optional, Union
import torch
class BaseAsyncScheduler:
    """Proxy around a scheduler object that can hand out independent per-request copies.

    Attribute reads that normal lookup cannot resolve on the wrapper are forwarded to the wrapped
    scheduler. Attribute writes are forwarded to the wrapped scheduler when it already has an attribute
    of that name; otherwise they are stored on the wrapper itself. The wrapped object is always stored
    on the wrapper under the name `scheduler`.
    """

    def __init__(self, scheduler: Any):
        self.scheduler = scheduler

    def __getattr__(self, name: str):
        """Forward lookups that failed on the wrapper to the wrapped scheduler.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped scheduler has `name`.
        """
        # `__getattr__` only runs after normal lookup has failed. If the missing name is `scheduler`
        # itself (instance built without `__init__`, e.g. by `copy.deepcopy` or `pickle` before state is
        # restored), reading `self.scheduler` below would re-enter this method until RecursionError.
        if name == "scheduler":
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
        wrapped = self.scheduler
        if hasattr(wrapped, name):
            return getattr(wrapped, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value):
        """Write `name` onto the wrapped scheduler if it already defines it, else onto the wrapper."""
        # `scheduler` is always stored on the wrapper. The `hasattr(self, "scheduler")` probe is safe on a
        # not-yet-initialised instance because `__getattr__` raises AttributeError for that name.
        if name != "scheduler" and hasattr(self, "scheduler") and hasattr(self.scheduler, name):
            setattr(self.scheduler, name, value)
        else:
            super().__setattr__(name, value)

    def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
        """Return a new wrapper around a deep copy of the wrapped scheduler with timesteps already set.

        Args:
            num_inference_steps: forwarded to the copy's `set_timesteps` as `num_inference_steps`.
            device: forwarded to the copy's `set_timesteps` as `device`.
            **kwargs: forwarded unchanged to the copy's `set_timesteps`.

        Returns:
            A new instance of this wrapper's class holding the configured copy. `set_timesteps` is called
            on the copy only, not on the scheduler held by this wrapper.
        """
        local = copy.deepcopy(self.scheduler)
        local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
        return self.__class__(local)

    def __repr__(self):
        return f"BaseAsyncScheduler({repr(self.scheduler)})"

    def __str__(self):
        return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
def async_retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure a scheduler through its `set_timesteps` method and return the timesteps it ends up with.

    By default `set_timesteps` is called on the scheduler that was passed in and the result is
    `(timesteps, num_inference_steps)`. When `return_scheduler=True` is supplied, `set_timesteps` is instead
    called on a separate scheduler object and the result is
    `(timesteps, num_inference_steps, scheduler_in_use)`.

    Args:
        scheduler:
            Object exposing `set_timesteps(...)` and a `timesteps` attribute.
        num_inference_steps (`int`, *optional*):
            Passed positionally to `set_timesteps` when neither `timesteps` nor `sigmas` is given. When a custom
            schedule is given, the returned value is `len(scheduler_in_use.timesteps)` instead.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps, forwarded as `timesteps=`. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas, forwarded as `sigmas=`. Only consulted when `timesteps` is `None`.
        **kwargs:
            Forwarded to `set_timesteps`, except for the control flag below.

    Optional kwargs:
        return_scheduler (`bool`, default `False`):
            If truthy, obtain a separate scheduler first: `scheduler.clone_for_request(num_inference_steps or 0,
            device)` when that attribute exists, falling back to `copy.deepcopy(scheduler)` if it is missing or
            the call raises. `set_timesteps` is then called on that object and it is returned as the third
            tuple element.

    Returns:
        `(timesteps, num_inference_steps)`, or `(timesteps, num_inference_steps, scheduler_in_use)` when
        `return_scheduler=True`.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or if a custom schedule is requested but the
            signature of `set_timesteps` does not list the matching parameter (or cannot be inspected).
    """
    # `return_scheduler` steers this function only; strip it so it is never forwarded to `set_timesteps`.
    want_scheduler = bool(kwargs.pop("return_scheduler", False))

    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Decide which scheduler object `set_timesteps` will be called on.
    if not want_scheduler:
        active = scheduler
    elif not hasattr(scheduler, "clone_for_request"):
        # No dedicated cloning hook: fall back to a plain deep copy.
        active = copy.deepcopy(scheduler)
    else:
        try:
            active = scheduler.clone_for_request(num_inference_steps=num_inference_steps or 0, device=device)
        except Exception:
            # Best effort: any failure in the cloning hook falls back to a plain deep copy.
            active = copy.deepcopy(scheduler)

    def _supports(keyword: str) -> bool:
        """Report whether `active.set_timesteps` names `keyword` as a parameter."""
        try:
            declared = inspect.signature(active.set_timesteps).parameters
        except (ValueError, TypeError):
            # The signature cannot be inspected: treated the same as "parameter not supported".
            return False
        return keyword in declared

    # Work out which custom schedule (if any) was requested; `timesteps` takes precedence over `sigmas`.
    if timesteps is not None:
        custom_key, custom_values, custom_label = "timesteps", timesteps, "timestep"
    elif sigmas is not None:
        custom_key, custom_values, custom_label = "sigmas", sigmas, "sigmas"
    else:
        custom_key = None

    if custom_key is None:
        # Regular path: let the scheduler derive its schedule from the step count.
        active.set_timesteps(num_inference_steps, device=device, **kwargs)
        schedule = active.timesteps
    else:
        if not _supports(custom_key):
            raise ValueError(
                f"The current scheduler class {active.__class__}'s `set_timesteps` does not support custom"
                f" {custom_label} schedules. Please check whether you are using the correct scheduler."
            )
        active.set_timesteps(**{custom_key: custom_values}, device=device, **kwargs)
        schedule = active.timesteps
        # With a custom schedule the step count is whatever the scheduler produced.
        num_inference_steps = len(schedule)

    if want_scheduler:
        return schedule, num_inference_steps, active
    return schedule, num_inference_steps
|