Spaces:
Runtime error
Runtime error
Update inference.py
Browse files- inference.py +3 -46
inference.py
CHANGED
|
@@ -21,43 +21,6 @@ else:
|
|
| 21 |
XLA_AVAILABLE = False
|
| 22 |
|
| 23 |
|
| 24 |
-
def retrieve_timesteps(
|
| 25 |
-
scheduler,
|
| 26 |
-
num_inference_steps: Optional[int] = None,
|
| 27 |
-
device: Optional[Union[str, torch.device]] = None,
|
| 28 |
-
timesteps: Optional[List[int]] = None,
|
| 29 |
-
sigmas: Optional[List[float]] = None,
|
| 30 |
-
**kwargs,
|
| 31 |
-
):
|
| 32 |
-
if timesteps is not None and sigmas is not None:
|
| 33 |
-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 34 |
-
if timesteps is not None:
|
| 35 |
-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 36 |
-
if not accepts_timesteps:
|
| 37 |
-
raise ValueError(
|
| 38 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 39 |
-
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 40 |
-
)
|
| 41 |
-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 42 |
-
timesteps = scheduler.timesteps
|
| 43 |
-
num_inference_steps = len(timesteps)
|
| 44 |
-
elif sigmas is not None:
|
| 45 |
-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 46 |
-
if not accept_sigmas:
|
| 47 |
-
raise ValueError(
|
| 48 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 49 |
-
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 50 |
-
)
|
| 51 |
-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 52 |
-
timesteps = scheduler.timesteps
|
| 53 |
-
num_inference_steps = len(timesteps)
|
| 54 |
-
else:
|
| 55 |
-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 56 |
-
timesteps = scheduler.timesteps
|
| 57 |
-
|
| 58 |
-
return timesteps, num_inference_steps
|
| 59 |
-
|
| 60 |
-
|
| 61 |
@torch.no_grad()
|
| 62 |
def run(
|
| 63 |
self,
|
|
@@ -68,6 +31,7 @@ def run(
|
|
| 68 |
width: Optional[int] = None,
|
| 69 |
num_inference_steps: int = 28,
|
| 70 |
sigmas: Optional[List[float]] = None,
|
|
|
|
| 71 |
scales: List[float] = None,
|
| 72 |
guidance_scale: float = 7.0,
|
| 73 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
|
@@ -196,13 +160,6 @@ def run(
|
|
| 196 |
scheduler_kwargs["mu"] = mu
|
| 197 |
elif mu is not None:
|
| 198 |
scheduler_kwargs["mu"] = mu
|
| 199 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
| 200 |
-
self.scheduler,
|
| 201 |
-
num_inference_steps,
|
| 202 |
-
device,
|
| 203 |
-
sigmas=sigmas,
|
| 204 |
-
**scheduler_kwargs,
|
| 205 |
-
)
|
| 206 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 207 |
self._num_timesteps = len(timesteps)
|
| 208 |
|
|
@@ -269,8 +226,8 @@ def run(
|
|
| 269 |
|
| 270 |
# compute the previous noisy sample x_t -> x_t-1
|
| 271 |
latents_dtype = latents.dtype
|
| 272 |
-
sigma =
|
| 273 |
-
sigma_next =
|
| 274 |
x0_pred = (latents - sigma * noise_pred)
|
| 275 |
try:
|
| 276 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1])
|
|
|
|
| 21 |
XLA_AVAILABLE = False
|
| 22 |
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
@torch.no_grad()
|
| 25 |
def run(
|
| 26 |
self,
|
|
|
|
| 31 |
width: Optional[int] = None,
|
| 32 |
num_inference_steps: int = 28,
|
| 33 |
sigmas: Optional[List[float]] = None,
|
| 34 |
+
timesteps: Optional[List[float]] = None,
|
| 35 |
scales: List[float] = None,
|
| 36 |
guidance_scale: float = 7.0,
|
| 37 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
|
|
|
| 160 |
scheduler_kwargs["mu"] = mu
|
| 161 |
elif mu is not None:
|
| 162 |
scheduler_kwargs["mu"] = mu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 164 |
self._num_timesteps = len(timesteps)
|
| 165 |
|
|
|
|
| 226 |
|
| 227 |
# compute the previous noisy sample x_t -> x_t-1
|
| 228 |
latents_dtype = latents.dtype
|
| 229 |
+
sigma = sigmas[i]
|
| 230 |
+
sigma_next = sigmas[i + 1]
|
| 231 |
x0_pred = (latents - sigma * noise_pred)
|
| 232 |
try:
|
| 233 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1])
|