import re
import contextlib
import numpy as np
import torch
import warnings
from modules.eg3ds import dnnlib

# Cached construction of constant tensors. Avoids an extra CPU=>GPU copy when
# the same constant is requested multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor
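
# Usage sketch: repeated calls with identical arguments return the same cached
# tensor object, so the host->device copy happens only once.
#
#   eye = constant(np.eye(3), device=torch.device('cuda'))       # assumes a CUDA device
#   assert constant(np.eye(3), device=torch.device('cuda')) is eye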

# Replace NaN/Inf with specified numerical values. Uses torch.nan_to_num when
# available and falls back to a manual implementation on older PyTorch versions.

try:
    nan_to_num = torch.nan_to_num
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
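
# Usage sketch: sanitize a tensor so later reductions do not propagate NaN/Inf,
# e.g. nan_to_num(torch.tensor([float('nan'), float('inf'), 1.0])) yields
# [0.0, finfo.max, 1.0] on the fallback path (NaN -> 0 via nansum, Inf clamped).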

# Symbolic assert: use torch._assert when available so the check also works
# inside torch.jit.trace(); fall back to torch.Assert on older PyTorch.

try:
    symbolic_assert = torch._assert
except AttributeError:
    symbolic_assert = torch.Assert

# Context manager to temporarily suppress known warnings in torch.jit.trace().
# The filter is inserted and removed manually in warnings.filters.

@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)

# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings():
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings():
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
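
# Usage sketch (hypothetical tensor name): check a batch of RGB images where the
# batch size may vary but the remaining dimensions are fixed.
#
#   img = torch.zeros([8, 3, 256, 256])
#   assert_shape(img, [None, 3, 256, 256])   # passes
#   assert_shape(img, [None, 1, 256, 256])   # raises AssertionError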

# Function decorator that wraps the call in torch.autograd.profiler.record_function(),
# so it shows up as a named range in profiler traces.

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator
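
# Usage sketch (hypothetical function/module names): the decorated call appears
# as a named block in torch.profiler / torch.autograd.profiler output.
#
#   @profiled_function
#   def run_generator(z):
#       return G(z)          # G is a placeholder for some torch.nn.Module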

# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1
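
# Usage sketch (hypothetical dataset/loader names): yields indices forever, so
# the DataLoader never stops between "epochs"; each rank sees its own slice.
#
#   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=world_size, seed=0)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
#   batch = next(iter(loader))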

# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
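
# Usage sketch (hypothetical module names): keep an EMA copy of a network in
# sync with the training copy; G and G_ema are assumed to be two instances of
# the same architecture, and the in-place copies run under torch.no_grad().
#
#   with torch.no_grad():
#       copy_params_and_buffers(G, G_ema, require_all=True)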

# Context manager for easily enabling/disabling DistributedDataParallel
# gradient synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield
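
# Usage sketch (hypothetical names): skip the gradient all-reduce on every
# micro-batch except the last when accumulating gradients under DDP.
#
#   for i, micro_batch in enumerate(micro_batches):
#       with ddp_sync(ddp_module, sync=(i == len(micro_batches) - 1)):
#           compute_loss(micro_batch).backward()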

# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
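
# Usage sketch: call from every rank after initialization or a training step to
# verify that all replicas hold identical parameters/buffers (rank 0 broadcasts
# the reference values). The regex below is only an example pattern.
#
#   check_ddp_consistency(module, ignore_regex=r'.*\.noise_const')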

# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks that record each submodule's outputs up to max_nesting.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run the module once, then remove the hooks.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct the table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print the table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
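
# Usage sketch (hypothetical module/input names): runs one forward pass with
# hooks attached and prints per-submodule parameter counts and output shapes.
#
#   z = torch.randn([1, 512])
#   c = torch.zeros([1, 25])
#   print_module_summary(G, [z, c])    # G is a placeholder generator taking (z, c)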