| """ |
| Hilbert Order |
| Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve |
| |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu |
| Please cite our work if the code is helpful to you. |
| """ |
|
|
| import torch |
|
|
|
|
def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.

    Parameters:
    -----------
    binary: A tensor of binary values (any dtype; the result keeps it).

    k: The number of bits to shift. Default 1. A single-element integer
       tensor is also accepted. Must be non-negative.

    axis: The axis along which to shift. Default -1.

    Returns:
    --------
    A new tensor with the same shape and dtype as ``binary``, with ``k``
    zeros prepended and the last ``k`` entries truncated along ``axis``.
    If ``k`` is at least the length of ``axis`` the result is all zeros;
    if ``k`` is 0 the result is a copy of ``binary``.

    Raises:
    -------
    ValueError: if ``k`` is negative.
    """
    # Callers may pass the shift as a one-element integer tensor; a plain int
    # is needed for shape arithmetic below.
    k = int(k)
    if k < 0:
        raise ValueError("right_shift requires a non-negative shift, got %d" % k)

    length = binary.shape[axis]
    if k == 0:
        # Slicing with ``-0`` would select nothing, so handle the no-op shift
        # explicitly.
        return binary.clone()
    if length <= k:
        return torch.zeros_like(binary)

    # Keep the leading ``length - k`` entries along ``axis`` ...
    kept = binary.narrow(axis, 0, length - k)

    # ... and prepend ``k`` zeros along that same axis (F.pad would always pad
    # the last dimension, regardless of ``axis``).
    zeros_shape = list(binary.shape)
    zeros_shape[axis] = k
    zeros = torch.zeros(zeros_shape, dtype=binary.dtype, device=binary.device)
    return torch.cat((zeros, kept), dim=axis)
|
|
|
|
def binary2gray(binary, axis=-1):
    """Gray-encode an array of binary digits.

    Applies the classic ``X ^ (X >> 1)`` identity digit-wise: each output
    digit is the logical XOR of an input digit with the value produced by
    ``right_shift`` (one-position shift, zero-filled) along ``axis``.

    Parameters:
    -----------
    binary: A tensor of binary digits.

    axis: The axis holding the digits. Default=-1.

    Returns:
    --------
    A boolean tensor of Gray-code digits (the result of ``torch.logical_xor``).
    """
    return torch.logical_xor(binary, right_shift(binary, axis=axis))
|
|
|
|
def gray2binary(gray, axis=-1):
    """Convert an array of Gray codes back into binary values.

    Decoding a Gray code is a prefix XOR: binary digit ``i`` (counting from
    the most-significant end of ``axis``) is the XOR of Gray digits
    ``0..i``, which equals the parity of their running sum. Computing it
    with ``cumsum`` works for any length along any axis, including length 0
    and 1.

    Parameters:
    -----------
    gray: A tensor of Gray-code digits; any non-zero entry counts as 1.

    axis: The axis along which to perform Gray decoding. Default=-1.

    Returns:
    --------
    A boolean tensor of binary digits with the same shape as ``gray``.
    """
    # Normalise to 0/1 integers so the running sum's parity is the prefix XOR.
    digits = (gray != 0).long()
    return digits.cumsum(dim=axis).remainder(2).bool()
|
|
|
|
def encode(locs, num_dims, num_bits):
    """Encode an array of locations in a hypercube into Hilbert integers.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    locs - A tensor of integer locations in a hypercube of num_dims
           dimensions, in which each dimension runs from 0 to 2**num_bits-1.
           The shape can be arbitrary, as long as the last dimension has
           size num_dims. Values are cast to int64 and only their lowest
           num_bits bits are used.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    An int64 tensor of Hilbert integers with shape ``locs.shape[:-1]``.

    Raises:
    -------
    ValueError - if the last dimension of locs is not num_dims, or if
                 num_dims * num_bits exceeds the 63 value bits of an int64.
    """
    orig_shape = locs.shape

    if orig_shape[-1] != num_dims:
        raise ValueError(
            """
        The shape of locs was surprising in that the last dimension was of size
        %d, but num_dims=%d. These need to be equal.
        """
            % (orig_shape[-1], num_dims)
        )

    if num_dims * num_bits > 63:
        raise ValueError(
            """
        num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
        into a int64. Are you sure you need that many points on your Hilbert
        curve?
        """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Powers of two 1, 2, ..., 128 used to unpack / repack single bytes.
    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    # View each int64 coordinate as 8 raw bytes and flip them so the most
    # significant byte comes first (assumes a little-endian host). A dtype
    # view needs unit stride on the last dim, hence .contiguous().
    locs_uint8 = (
        locs.long().contiguous().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
    )

    # Unpack to bits (most significant first) and keep the lowest num_bits of
    # each coordinate: shape (num_points, num_dims, num_bits).
    gray = (
        locs_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[..., -num_bits:]
    )

    # Skilling's axes-to-transpose step, walking bits and dims forwards.
    for bit in range(num_bits):
        for dim in range(num_dims):
            # Points that have this bit set in this dimension.
            mask = gray[:, dim, bit]

            # Where the bit is on, invert the lower bits of dimension 0.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits of this dimension
            # with those of dimension 0 (flip only the bits that differ).
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Interleave the bits (bit-major, then dimension) into one Gray code per
    # point and decode it to plain binary.
    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
    hh_bin = gray2binary(gray)

    # Left-pad to 64 bits so the result can be repacked as 8 bytes.
    extra_dims = 64 - num_bits * num_dims
    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

    # Repack bits into bytes, least significant byte first to match the
    # little-endian byte view used above: shape (num_points, 8).
    hh_uint8 = (
        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
        .sum(2)
        .type(torch.uint8)
    )

    # Reinterpret each row of 8 bytes as one int64, shape (num_points, 1),
    # and restore the caller's leading dimensions (no squeeze, so a single
    # point keeps its batch axis instead of collapsing to a 0-d tensor).
    hh_int64 = hh_uint8.view(torch.int64)
    return hh_int64.reshape(orig_shape[:-1])
|
|
|
|
def decode(hilberts, num_dims, num_bits):
    """Decode an array of Hilbert integers into locations in a hypercube.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    hilberts - A tensor of Hilbert integers. Must be an integer dtype and
               cannot have fewer bits than num_dims * num_bits. It is cast to
               int64 before its bits are read.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    An int64 tensor of shape ``(*hilberts.shape, num_dims)`` holding the
    coordinates of each point (a 0-d input is treated as shape ``(1,)``).

    Raises:
    -------
    ValueError - if num_dims * num_bits exceeds 64.
    """
    if num_dims * num_bits > 64:
        # The message has three %d placeholders, so all three values must be
        # supplied; otherwise a TypeError would mask the intended ValueError.
        raise ValueError(
            """
        num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
        into a uint64. Are you sure you need that many points on your Hilbert
        curve?
        """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Treat scalars as a batch of one.
    hilberts = torch.atleast_1d(hilberts)

    orig_shape = hilberts.shape
    # Powers of two 1, 2, ..., 128 used to unpack / repack single bytes.
    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    # View each int64 as 8 raw bytes, most significant byte first after the
    # flip (assumes a little-endian host): shape (num_points, 8).
    hh_uint8 = (
        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
    )

    # Unpack to bits (most significant first) and keep only the lowest
    # num_dims * num_bits of them.
    hh_bits = (
        hh_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[:, -num_dims * num_bits :]
    )

    # Undo the final binary-to-Gray step of encoding.
    gray = binary2gray(hh_bits)

    # De-interleave into (num_points, num_dims, num_bits).
    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)

    # Inverse of the encoding transform: walk bits and dims backwards.
    for bit in range(num_bits - 1, -1, -1):
        for dim in range(num_dims - 1, -1, -1):
            # Points that have this bit set in this dimension.
            mask = gray[:, dim, bit]

            # Where the bit is on, invert the lower bits of dimension 0.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits of this dimension
            # with those of dimension 0 (flip only the bits that differ).
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Left-pad each coordinate to 64 bits so it can be repacked as 8 bytes.
    extra_dims = 64 - num_bits
    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)

    # Group into bytes, least significant byte first: (num_points, num_dims, 8, 8).
    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))

    # Pack each group of 8 bits into a byte: (num_points, num_dims, 8).
    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).type(torch.uint8)

    # Reinterpret each row of 8 bytes as one int64 coordinate.
    flat_locs = locs_uint8.view(torch.int64)

    return flat_locs.reshape((*orig_shape, num_dims))
|
|