Spaces:
Build error
Build error
| import torchvision.transforms.functional as TF | |
| from src.utils.iimage import IImage | |
| import torch | |
| import sys | |
| from .utils import * | |
# Shared module-level state, all initially unset (None).
# `input_mask` is assigned by set_mask(), `input_shape` by set_shape(), and
# `timestep` / `timestep_index` are updated by DDIMIterator as it is iterated.
input_mask = input_shape = timestep = timestep_index = None
class Seed:
    """Deterministic, index-addressable table of pseudo-random seeds.

    ``Seed()[i]`` is ``12345 ** i % 54321``; lists, tuples and bounded slices
    of indices return a list of the corresponding seeds.
    """

    _BASE = 12345
    _MODULUS = 54321

    def __getitem__(self, idx):
        """Return the seed(s) at ``idx``.

        Args:
            idx: an int; a list/tuple of indices (returns a list, recursively);
                or a slice with an explicit ``stop`` (returns a list). The slice
                is resolved with ``slice.indices(stop)``, i.e. as if the table
                had length ``stop``.

        Returns:
            An int for a non-negative int index, otherwise whatever
            ``12345 ** idx % 54321`` evaluates to (a float for negative ints),
            or a list of such values for list/tuple/slice indices.

        Raises:
            TypeError: for a slice without ``stop``; the table is unbounded,
                so an open-ended slice cannot be materialised.
        """
        if isinstance(idx, slice):
            if idx.stop is None:
                raise TypeError(
                    "Seed slices need an explicit stop (the seed table is unbounded)"
                )
            idx = list(range(*idx.indices(idx.stop)))
        if isinstance(idx, (list, tuple)):
            return [self[i] for i in idx]
        if isinstance(idx, int) and idx >= 0:
            # Modular exponentiation: identical value to 12345 ** idx % 54321,
            # but without building the huge intermediate power first.
            return pow(self._BASE, idx, self._MODULUS)
        # Other inputs (e.g. negative ints) keep the original expression.
        return self._BASE ** idx % self._MODULUS
class DDIMIterator:
    """Iterate over a sequence of timesteps while publishing progress globally.

    Every value produced is stored in the module-level ``timestep`` and
    ``timestep_index`` is incremented. ``timestep_index`` is reset to 0 when
    iteration starts, so after the k-th value has been yielded it equals k.
    Presumably this lets other code in the package read the current step
    without it being passed around — confirm against callers.
    """

    def __init__(self, iterator):
        # `iterator` may be any iterable. It is kept untouched so that each
        # new `for` loop restarts from the beginning when it is re-iterable
        # (e.g. a list or tuple of timesteps).
        self.iterator = iterator
        self._active = None  # iterator for the pass in progress, set by __iter__

    def __iter__(self):
        global timestep_index
        # Start a fresh pass from the source. Overwriting `self.iterator` here
        # (as before) would leave an exhausted iterator behind, making a
        # second loop over a list-backed instance silently yield nothing.
        self._active = iter(self.iterator)
        timestep_index = 0
        return self

    def __next__(self):
        global timestep, timestep_index
        # Fall back to `self.iterator` so calling next() without iter() keeps
        # working when the wrapped object is itself an iterator.
        source = self._active if self._active is not None else self.iterator
        timestep = next(source)  # StopIteration propagates; counter untouched
        timestep_index += 1
        return timestep
# Handle to this module object; exists() uses it to look up module
# globals via hasattr/getattr.
self = sys.modules[__name__]
# Module-wide seed table: `seed[i]`, `seed[a:b]`, `seed[[i, j]]`.
seed = Seed()
def reshape(x):
    """Delegate to ``reshape`` on the module-level ``input_shape``.

    ``input_shape`` must have been set (via set_shape(), which set_mask() also
    calls); while it is still None this raises AttributeError.
    """
    current = input_shape
    return current.reshape(x)
def set_shape(image_or_shape):
    """Record the working input shape in the module-level ``input_shape``.

    Inputs are checked in this order:
      * ``torch.Tensor`` -- uses its last two dims reversed,
        i.e. ``(shape[-1], shape[-2])``;
      * ``list`` / ``tuple`` -- passed to ``InputShape`` unchanged;
      * any other object with a ``size`` attribute (presumably an IImage --
        confirm what its ``size`` holds) -- ``InputShape(obj.size)``.
    Any other input leaves ``input_shape`` unchanged.
    """
    global input_shape
    if isinstance(image_or_shape, torch.Tensor):
        # Checked first: torch.Tensor also has a `.size` *method*, and the
        # generic branch below would otherwise hand InputShape that bound
        # method before this result overwrote it.
        input_shape = InputShape(image_or_shape.shape[-2:][::-1])
    elif isinstance(image_or_shape, (list, tuple)):
        input_shape = InputShape(image_or_shape)
    elif hasattr(image_or_shape, 'size'):
        input_shape = InputShape(image_or_shape.size)
def set_mask(mask):
    """Install ``mask`` as the module's active mask.

    Sets the module globals:
      * ``input_mask`` and ``painta_mask`` -- two separate ``InputMask(mask)``
        objects;
      * ``mask64`` / ``mask32`` / ``mask16`` / ``mask8`` -- the ``[0, 0]``
        element of ``input_mask.val64`` / ``val32`` / ``val16`` / ``val8``
        (presumably per-resolution copies of the mask -- confirm in InputMask);
    and finally records the shape via ``set_shape(mask)``.
    """
    global input_mask, painta_mask
    global mask64, mask32, mask16, mask8
    # Two distinct objects are built on purpose; presumably painta_mask is
    # used or modified independently elsewhere -- confirm before merging them.
    input_mask = InputMask(mask)
    painta_mask = InputMask(mask)
    current = input_mask
    mask64 = current.val64[0, 0]
    mask32 = current.val32[0, 0]
    mask16 = current.val16[0, 0]
    mask8 = current.val8[0, 0]
    set_shape(mask)
def exists(name):
    """Return True if this module has an attribute ``name`` whose value is not None.

    A missing attribute and one explicitly set to None both count as "not existing".
    """
    return getattr(self, name, None) is not None