Spaces:
Runtime error
Runtime error
| import gc | |
| from collections import defaultdict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Function | |
| from torch.cuda.amp import custom_bwd, custom_fwd | |
| import tinycudann as tcnn | |
| def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): | |
| B = None | |
| for arg in args: | |
| if isinstance(arg, torch.Tensor): | |
| B = arg.shape[0] | |
| break | |
| out = defaultdict(list) | |
| out_type = None | |
| for i in range(0, B, chunk_size): | |
| out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) | |
| if out_chunk is None: | |
| continue | |
| out_type = type(out_chunk) | |
| if isinstance(out_chunk, torch.Tensor): | |
| out_chunk = {0: out_chunk} | |
| elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): | |
| chunk_length = len(out_chunk) | |
| out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} | |
| elif isinstance(out_chunk, dict): | |
| pass | |
| else: | |
| print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.') | |
| exit(1) | |
| for k, v in out_chunk.items(): | |
| v = v if torch.is_grad_enabled() else v.detach() | |
| v = v.cpu() if move_to_cpu else v | |
| out[k].append(v) | |
| if out_type is None: | |
| return | |
| out = {k: torch.cat(v, dim=0) for k, v in out.items()} | |
| if out_type is torch.Tensor: | |
| return out[0] | |
| elif out_type in [tuple, list]: | |
| return out_type([out[i] for i in range(chunk_length)]) | |
| elif out_type is dict: | |
| return out | |
| class _TruncExp(Function): # pylint: disable=abstract-method | |
| # Implementation from torch-ngp: | |
| # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py | |
| def forward(ctx, x): # pylint: disable=arguments-differ | |
| ctx.save_for_backward(x) | |
| return torch.exp(x) | |
| def backward(ctx, g): # pylint: disable=arguments-differ | |
| x = ctx.saved_tensors[0] | |
| return g * torch.exp(torch.clamp(x, max=15)) | |
| trunc_exp = _TruncExp.apply | |
| def get_activation(name): | |
| if name is None: | |
| return lambda x: x | |
| name = name.lower() | |
| if name == 'none': | |
| return lambda x: x | |
| elif name.startswith('scale'): | |
| scale_factor = float(name[5:]) | |
| return lambda x: x.clamp(0., scale_factor) / scale_factor | |
| elif name.startswith('clamp'): | |
| clamp_max = float(name[5:]) | |
| return lambda x: x.clamp(0., clamp_max) | |
| elif name.startswith('mul'): | |
| mul_factor = float(name[3:]) | |
| return lambda x: x * mul_factor | |
| elif name == 'lin2srgb': | |
| return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) | |
| elif name == 'trunc_exp': | |
| return trunc_exp | |
| elif name.startswith('+') or name.startswith('-'): | |
| return lambda x: x + float(name) | |
| elif name == 'sigmoid': | |
| return lambda x: torch.sigmoid(x) | |
| elif name == 'tanh': | |
| return lambda x: torch.tanh(x) | |
| else: | |
| return getattr(F, name) | |
| def dot(x, y): | |
| return torch.sum(x*y, -1, keepdim=True) | |
| def reflect(x, n): | |
| return 2 * dot(x, n) * n - x | |
| def scale_anything(dat, inp_scale, tgt_scale): | |
| if inp_scale is None: | |
| inp_scale = [dat.min(), dat.max()] | |
| dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) | |
| dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] | |
| return dat | |
| def cleanup(): | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| tcnn.free_temporary_memory() | |