| import math |
| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import craftsman |
| from craftsman.utils.typing import * |
|
|
|
|
def dot(x, y):
    """Return the inner product of ``x`` and ``y`` over their last dimension.

    The reduced dimension is kept with size 1 (``keepdim=True``), so the result
    broadcasts back against the inputs.
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
|
|
|
|
def reflect(x, n):
    """Mirror ``x`` across the direction ``n``, computing ``2 * <x, n> * n - x``.

    The inner product is taken over the last dimension via ``dot`` and keeps that
    dimension, so it broadcasts against ``n``. ``n`` is not normalized here, so the
    result is a true reflection only when ``n`` has unit length.
    """
    projection = dot(x, n)
    return 2 * projection * n - x
|
|
|
|
# A (start, end) range accepted by ``scale_tensor``: either a plain 2-tuple of
# floats or a tensor annotated as "2 D" (row 0 = start, row 1 = end, one column
# per feature dimension, as indexed by ``scale_tensor``).
ValidScale = Union[
    Tuple[float, float],
    Num[Tensor, "2 D"],
]
|
|
|
|
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap ``dat`` from the range ``inp_scale`` to the range ``tgt_scale``.

    Each scale is a ``(start, end)`` pair read as ``scale[0]`` / ``scale[1]``:
    either a 2-tuple of scalars or a tensor with one column per feature
    dimension. ``None`` is treated as ``(0, 1)`` for either scale.

    The mapping is ``(dat - in0) / (in1 - in0) * (out1 - out0) + out0``. The
    result is not clamped, and a zero-width input range is not guarded against.

    Raises:
        AssertionError: if ``tgt_scale`` is a tensor whose last dimension differs
            from that of ``dat`` (only ``tgt_scale`` is checked).
    """
    inp_scale = (0, 1) if inp_scale is None else inp_scale
    tgt_scale = (0, 1) if tgt_scale is None else tgt_scale

    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]

    src_lo, src_hi = inp_scale[0], inp_scale[1]
    dst_lo, dst_hi = tgt_scale[0], tgt_scale[1]

    unit = (dat - src_lo) / (src_hi - src_lo)
    return unit * (dst_hi - dst_lo) + dst_lo
|
|
|
|
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Call ``func`` on consecutive slices of a batch and concatenate the results.

    The batch size ``B`` is read from dim 0 of the first ``torch.Tensor`` found in
    ``args`` (then ``kwargs``). Every tensor argument is sliced along dim 0 into
    pieces of at most ``chunk_size`` rows. Non-tensor arguments are passed
    unchanged to every call.

    Args:
        func: callable invoked once per chunk. It must return ``None``, a tensor,
            a tuple/list, or a dict whose entries are tensors or ``None``.
        chunk_size: rows per chunk. A value ``<= 0`` disables chunking and returns
            ``func(*args, **kwargs)`` directly.

    Returns:
        ``None`` if every call returned ``None``. Otherwise the merged result has
        the same kind as ``func``'s (last non-``None``) return value: a tensor, a
        tuple/list, or a dict. Each tensor entry is concatenated along dim 0, and
        an entry that is ``None`` in every chunk stays ``None``. Tuple/list/dict
        subclasses come back as the plain base container. When autograd is
        disabled, per-chunk tensors are detached before concatenation.

    Raises:
        AssertionError: if no tensor argument is present (``B`` is unknown).
        TypeError: if ``func`` returns an unsupported type, or an output entry
            mixes tensors with other values across chunks.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)

    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."

    out = defaultdict(list)
    out_type = None
    chunk_length = 0

    # max(1, B): func is still called once (on empty slices) when B == 0.
    for i in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[
                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif not isinstance(out_chunk, dict):
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
        grad_enabled = torch.is_grad_enabled()
        for k, v in out_chunk.items():
            # Only tensors can be detached. None (or any other value) is kept
            # as-is so the merge step below can keep it or reject it.
            if isinstance(v, torch.Tensor) and not grad_enabled:
                v = v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            # An output slot that is None in every chunk stays None.
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    # issubclass rather than identity, so that subclasses (nn.Parameter,
    # namedtuple, OrderedDict, ...) do not fall through to an implicit None.
    if issubclass(out_type, torch.Tensor):
        return out_merged[0]
    if issubclass(out_type, (tuple, list)):
        values = [out_merged[idx] for idx in range(chunk_length)]
        if out_type in (tuple, list):
            return out_type(values)
        return tuple(values) if issubclass(out_type, tuple) else list(values)
    return out_merged
|
|
|
|
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a standard-normal random tensor on ``device`` with ``dtype``.

    When a list of generators is passed, each batch element is seeded
    individually. One ``(1, *shape[1:])`` sample is drawn per generator, for the
    first ``shape[0]`` generators, and the samples are concatenated along dim 0.
    If the generator(s) live on the CPU while ``device`` does not, the tensor is
    drawn on the CPU and then moved to ``device``.

    Args:
        shape: output shape (tuple or list); ``shape[0]`` is the batch size.
        generator: a ``torch.Generator``, a list of them (one per batch element),
            or ``None``.
        device: target device (``torch.device`` or device string). Defaults to CPU.
        dtype: dtype forwarded to ``torch.randn``.
        layout: layout forwarded to ``torch.randn``. Defaults to ``torch.strided``.

    Returns:
        A tensor of shape ``shape`` located on ``device``.

    Raises:
        ValueError: if a CUDA generator is given but ``device`` is not CUDA.
    """
    import logging

    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")
    # Accept device strings such as "cuda" too, so that ``.type`` works below.
    if not isinstance(device, torch.device):
        device = torch.device(device)

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device does not compare equal
            # to the plain string "mps".
            if device.type != "mps":
                logging.getLogger(__name__).info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # A one-element generator list behaves exactly like a single generator.
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # tuple(...) so that list-valued shapes work as well as tuples.
        sample_shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(sample_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
|
|
|
|
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_depth: int,
    indexing: str = "ij"
):
    """Sample a regular 3D lattice of points spanning an axis-aligned box.

    Each of the three axes is split into ``int(2 ** octree_depth)`` cells, which
    gives ``N = int(2 ** octree_depth) + 1`` evenly spaced float32 samples per
    axis, endpoints included.

    Args:
        bbox_min: lower corner of the box; entries 0..2 are used.
        bbox_max: upper corner of the box; entries 0..2 are used.
        octree_depth: subdivision depth; each axis gets ``2 ** octree_depth`` cells.
        indexing: forwarded to ``np.meshgrid`` ("ij" or "xy"); controls the order
            in which points are flattened.

    Returns:
        ``(xyz, grid_size, length)``: ``xyz`` is an ``(N ** 3, 3)`` float32 array
        of point coordinates, ``grid_size`` is the list ``[N, N, N]``, and
        ``length`` is ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    samples = int(np.exp2(octree_depth)) + 1

    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], samples, dtype=np.float32)
        for axis in range(3)
    ]
    grids = np.meshgrid(*axes, indexing=indexing)
    xyz = np.stack(grids, axis=-1).reshape(-1, 3)

    return xyz, [samples] * 3, length