Spaces:
Runtime error
Runtime error
| from typing import Union, List | |
| import numpy as np | |
| import torch | |
def extract_generator_seed(generator: Union[torch.Generator, List[torch.Generator]]) -> List[int]:
    """Return the seed(s) the given generator(s) were initialised with.

    Args:
        generator: A single generator or a list of generators.

    Returns:
        A list with one integer seed per generator, in input order. A single
        generator yields a one-element list.
    """

    def _seed_of(gen) -> int:
        # ``torch.Generator.seed()`` does NOT read the seed: it draws a fresh
        # non-deterministic value and re-seeds the generator with it, which
        # throws away a caller's ``manual_seed(...)`` and mutates their object.
        # ``initial_seed()`` is the read-only accessor.
        read_seed = getattr(gen, "initial_seed", None)
        if callable(read_seed):
            return read_seed()
        # Fallback for generator-like objects that only expose ``seed()``.
        return gen.seed()

    generators = generator if isinstance(generator, list) else [generator]
    return [_seed_of(gen) for gen in generators]
def randn_tensor(shape, dtype: np.dtype, generator: Union[torch.Generator, List[torch.Generator], int, List[int]]):
    """Draw standard-normal noise of the given shape as a NumPy array.

    Args:
        shape: Shape of the array to generate.
        dtype: NumPy dtype the result is cast to.
        generator: Source of the random seed. Accepted forms:
            * an object with a ``seed`` attribute (e.g. ``torch.Generator``) or
              a list of them - seeds are obtained via ``extract_generator_seed``;
            * an int seed or a list of int seeds - used as-is;
            * ``None`` or an empty list - NumPy picks fresh entropy.

    Returns:
        ``np.ndarray`` of shape ``shape`` and dtype ``dtype``.

    Note:
        When several seeds are given they are combined into a single NumPy
        generator that produces the whole array; samples are not drawn from
        separate per-seed streams.
    """
    if isinstance(generator, list):
        if not generator:
            # Nothing to seed from; avoid indexing into an empty list.
            seed = None
        elif hasattr(generator[0], "seed"):
            seed = extract_generator_seed(generator)
        else:
            seed = generator
        # A single seed is passed as a scalar rather than a one-element list.
        if seed is not None and len(seed) == 1:
            seed = seed[0]
    elif hasattr(generator, "seed"):
        seed = extract_generator_seed(generator)[0]
    else:
        # Plain int seed or None: hand straight to NumPy without calling len().
        seed = generator

    return np.random.default_rng(seed).standard_normal(shape).astype(dtype)
def prepare_latents(
    init_noise_sigma: float,
    batch_size: int,
    height: int,
    width: int,
    dtype: np.dtype,
    generator: Union[torch.Generator, List[torch.Generator]],
    latents: Union[np.ndarray, None] = None,
    num_channels_latents=4,
    vae_scale_factor=8,
):
    """Create (or validate) the initial latents and scale them by ``init_noise_sigma``.

    Args:
        init_noise_sigma: Factor the latents are multiplied by (the standard
            deviation required by the scheduler).
        batch_size: Number of samples; first dimension of the latents.
        height: Image height; divided by ``vae_scale_factor`` for the latent height.
        width: Image width; divided by ``vae_scale_factor`` for the latent width.
        dtype: dtype used when new latents are generated.
        generator: Seed source forwarded to ``randn_tensor``. If it is a list,
            its length must equal ``batch_size``.
        latents: Optional pre-made latents; must already have the expected shape.
        num_channels_latents: Second dimension of the latents.
        vae_scale_factor: Integer divisor applied to ``height`` and ``width``.

    Returns:
        The scaled latents, shaped
        ``(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)``.
        Floating-point latents keep their own dtype.

    Raises:
        ValueError: If a generator list does not match ``batch_size`` or the
            supplied ``latents`` have an unexpected shape.
    """
    shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, dtype, generator)
    elif latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")

    # scale the initial noise by the standard deviation required by the scheduler
    sigma = np.float64(init_noise_sigma)
    if np.issubdtype(latents.dtype, np.floating):
        # Under NumPy 2 promotion rules a float64 scalar would upcast
        # float16/float32 latents to float64. Casting sigma to the latents'
        # own dtype keeps the result dtype stable across NumPy versions.
        sigma = latents.dtype.type(sigma)
    return latents * sigma