Spaces:
Runtime error
Runtime error
| # # Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of code is borrowed form ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/). | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid. | |
| To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory. | |
| For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/). | |
| #### Dependencies | |
| In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly). | |
| We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2. | |
| #### Overview | |
| - For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids | |
| for input views. Then, use `slice` function to obtain transformed RGB output and the corresponding affine transformations. | |
| - For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`. | |
| #### Examples | |
| - Bilateral grid for approximating ISP: | |
| <a target="_blank" href="https://colab.research.google.com/drive/1tx2qKtsHH9deDDnParMWrChcsa9i7Prr?usp=sharing"> | |
| <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> | |
| - Low-rank 4D bilateral grid for MR enhancement: | |
| <a target="_blank" href="https://colab.research.google.com/drive/17YOjQqgWFT3QI1vysOIH494rMYtt_mHL?usp=sharing"> | |
| <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> | |
| Below is the API reference. | |
| """ | |
| import tensorly as tl | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| tl.set_backend("pytorch") | |
| def color_correct( | |
| img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255 | |
| ) -> torch.Tensor: | |
| """ | |
| Warp `img` to match the colors in `ref_img` using iterative color matching. | |
| This function performs color correction by warping the colors of the input image | |
| to match those of a reference image. It uses a least squares method to find a | |
| transformation that maps the input image's colors to the reference image's colors. | |
| The algorithm iteratively solves a system of linear equations, updating the set of | |
| unsaturated pixels in each iteration. This approach helps handle non-linear color | |
| transformations and reduces the impact of clipping. | |
| Args: | |
| img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels] | |
| ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels] | |
| num_iters (int, optional): Number of iterations for the color matching process. | |
| Default is 5. | |
| eps (float, optional): Small value to determine the range of unclipped pixels. | |
| Default is 0.5 / 255. | |
| Returns: | |
| torch.Tensor: Color corrected image with the same shape as the input image. | |
| Note: | |
| - Both input and reference images should be in the range [0, 1]. | |
| - The function works with any number of channels, but typically used with 3 (RGB). | |
| """ | |
| if img.shape[-1] != ref.shape[-1]: | |
| raise ValueError( | |
| f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match" | |
| ) | |
| num_channels = img.shape[-1] | |
| img_mat = img.reshape([-1, num_channels]) | |
| ref_mat = ref.reshape([-1, num_channels]) | |
| def is_unclipped(z): | |
| return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps]. | |
| mask0 = is_unclipped(img_mat) | |
| # Because the set of saturated pixels may change after solving for a | |
| # transformation, we repeatedly solve a system `num_iters` times and update | |
| # our estimate of which pixels are saturated. | |
| for _ in range(num_iters): | |
| # Construct the left hand side of a linear system that contains a quadratic | |
| # expansion of each pixel of `img`. | |
| a_mat = [] | |
| for c in range(num_channels): | |
| a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term. | |
| a_mat.append(img_mat) # Linear term. | |
| a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term. | |
| a_mat = torch.cat(a_mat, dim=-1) | |
| warp = [] | |
| for c in range(num_channels): | |
| # Construct the right hand side of a linear system containing each color | |
| # of `ref`. | |
| b = ref_mat[:, c] | |
| # Ignore rows of the linear system that were saturated in the input or are | |
| # saturated in the current corrected color estimate. | |
| mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) | |
| ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat)) | |
| mb = torch.where(mask, b, torch.zeros_like(b)) | |
| w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0] | |
| assert torch.all(torch.isfinite(w)) | |
| warp.append(w) | |
| warp = torch.stack(warp, dim=-1) | |
| # Apply the warp to update img_mat. | |
| img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1) | |
| corrected_img = torch.reshape(img_mat, img.shape) | |
| return corrected_img | |
| def bilateral_grid_tv_loss(model, config): | |
| """Computes total variations of bilateral grids.""" | |
| total_loss = 0.0 | |
| for bil_grids in model.bil_grids: | |
| total_loss += config.bilgrid_tv_loss_mult * total_variation_loss( | |
| bil_grids.grids | |
| ) | |
| return total_loss | |
| def color_affine_transform(affine_mats, rgb): | |
| """Applies color affine transformations. | |
| Args: | |
| affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$. | |
| rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$. | |
| Returns: | |
| Output transformed colors of shape $(..., 3)$. | |
| """ | |
| return ( | |
| torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1) | |
| + affine_mats[..., 3] | |
| ) | |
| def _num_tensor_elems(t): | |
| return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0) | |
| def total_variation_loss(x): # noqa: F811 | |
| """Returns total variation on multi-dimensional tensors. | |
| Args: | |
| x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size. | |
| """ | |
| batch_size = x.shape[0] | |
| tv = 0 | |
| for i in range(2, len(x.shape)): | |
| n_res = x.shape[i] | |
| idx1 = torch.arange(1, n_res, device=x.device) | |
| idx2 = torch.arange(0, n_res - 1, device=x.device) | |
| x1 = x.index_select(i, idx1) | |
| x2 = x.index_select(i, idx2) | |
| count = _num_tensor_elems(x1) | |
| tv += torch.pow((x1 - x2), 2).sum() / count | |
| return tv / batch_size | |
| def slice(bil_grids, xy, rgb, grid_idx): | |
| """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`. | |
| Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size | |
| and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`. | |
| The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and | |
| the output color `rgb_out` after applying the afffine transformations. | |
| In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor. | |
| Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`. | |
| For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same with the 2-D case. | |
| .. note:: | |
| This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement. | |
| When `grid_idx` contains a unique index, only a single bilateral grid will used during the slicing. In this case, this function will not | |
| perform tensor indexing to avoid data copy and extra memory | |
| (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)). | |
| Args: | |
| bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids. | |
| xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$. | |
| rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$. | |
| grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$. | |
| Returns: | |
| A dictionary with keys and values as follows: | |
| ``` | |
| { | |
| "rgb": Transformed RGB colors. Shape: (..., 3), | |
| "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4) | |
| } | |
| ``` | |
| """ | |
| sh_ = rgb.shape | |
| grid_idx_unique = torch.unique(grid_idx) | |
| if len(grid_idx_unique) == 1: | |
| # All pixels are from a single view. | |
| grid_idx = grid_idx_unique # (1,) | |
| xy = xy.unsqueeze(0) # (1, ..., 2) | |
| rgb = rgb.unsqueeze(0) # (1, ..., 3) | |
| else: | |
| # Pixels are randomly sampled from different views. | |
| if len(grid_idx.shape) == 4: | |
| grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,) | |
| elif len(grid_idx.shape) == 3: | |
| grid_idx = grid_idx[:, 0, 0] # (chunk_size,) | |
| elif len(grid_idx.shape) == 2: | |
| grid_idx = grid_idx[:, 0] # (chunk_size,) | |
| else: | |
| raise ValueError( | |
| "The input to bilateral grid slicing is not supported yet." | |
| ) | |
| affine_mats = bil_grids(xy, rgb, grid_idx) | |
| rgb = color_affine_transform(affine_mats, rgb) | |
| return { | |
| "rgb": rgb.reshape(*sh_), | |
| "rgb_affine_mats": affine_mats.reshape( | |
| *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1] | |
| ), | |
| } | |
| class BilateralGrid(nn.Module): | |
| """Class for 3D bilateral grids. | |
| Holds one or more than one bilateral grids. | |
| """ | |
| def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8): | |
| """ | |
| Args: | |
| num (int): The number of bilateral grids (i.e., the number of views). | |
| grid_X (int): Defines grid width $W$. | |
| grid_Y (int): Defines grid height $H$. | |
| grid_W (int): Defines grid guidance dimension $L$. | |
| """ | |
| super(BilateralGrid, self).__init__() | |
| self.grid_width = grid_X | |
| """Grid width. Type: int.""" | |
| self.grid_height = grid_Y | |
| """Grid height. Type: int.""" | |
| self.grid_guidance = grid_W | |
| """Grid guidance dimension. Type: int.""" | |
| # Initialize grids. | |
| grid = self._init_identity_grid() | |
| self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W) | |
| """ A 5-D tensor of shape $(N, 12, L, H, W)$.""" | |
| # Weights of BT601 RGB-to-gray. | |
| self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])) | |
| self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 | |
| """ A function that converts RGB to gray-scale guidance in $[-1, 1]$.""" | |
| def _init_identity_grid(self): | |
| grid = torch.tensor( | |
| [ | |
| 1.0, | |
| 0, | |
| 0, | |
| 0, | |
| 0, | |
| 1.0, | |
| 0, | |
| 0, | |
| 0, | |
| 0, | |
| 1.0, | |
| 0, | |
| ] | |
| ).float() | |
| grid = grid.repeat( | |
| [self.grid_guidance * self.grid_height * self.grid_width, 1] | |
| ) # (L * H * W, 12) | |
| grid = grid.reshape( | |
| 1, self.grid_guidance, self.grid_height, self.grid_width, -1 | |
| ) # (1, L, H, W, 12) | |
| grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W) | |
| return grid | |
| def tv_loss(self): | |
| """Computes and returns total variation loss on the bilateral grids.""" | |
| return total_variation_loss(self.grids) | |
| def forward(self, grid_xy, rgb, idx=None): | |
| """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input. | |
| For the 2-D, 3-D, and 4-D cases, please refer to `slice`. | |
| For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be | |
| equal to the number of bilateral grids. Then this function becomes PyTorch's | |
| [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). | |
| Args: | |
| grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$. | |
| rgb (torch.Tensor): The RGB values in the range of $[0,1]$. | |
| idx (torch.Tensor): The bilateral grid indices. | |
| Returns: | |
| Sliced affine matrices of shape $(..., 3, 4)$. | |
| """ | |
| grids = self.grids | |
| input_ndims = len(grid_xy.shape) | |
| assert len(rgb.shape) == input_ndims | |
| if input_ndims > 1 and input_ndims < 5: | |
| # Convert input into 5D | |
| for i in range(5 - input_ndims): | |
| grid_xy = grid_xy.unsqueeze(1) | |
| rgb = rgb.unsqueeze(1) | |
| assert idx is not None | |
| elif input_ndims != 5: | |
| raise ValueError( | |
| "Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs" | |
| ) | |
| grids = self.grids | |
| if idx is not None: | |
| grids = grids[idx] | |
| assert grids.shape[0] == grid_xy.shape[0] | |
| # Generate slicing coordinates. | |
| grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1]. | |
| grid_z = self.rgb2gray(rgb) | |
| # print(grid_xy.shape, grid_z.shape) | |
| # exit() | |
| grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3) | |
| affine_mats = F.grid_sample( | |
| grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border" | |
| ) # (N, 12, m, h, w) | |
| affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12) | |
| affine_mats = affine_mats.reshape( | |
| *affine_mats.shape[:-1], 3, 4 | |
| ) # (N, m, h, w, 3, 4) | |
| for _ in range(5 - input_ndims): | |
| affine_mats = affine_mats.squeeze(1) | |
| return affine_mats | |
| def slice4d(bil_grid4d, xyz, rgb): | |
| """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`. | |
| Args: | |
| bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid. | |
| xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. | |
| rgb (torch.Tensor): The RGB values with shape $(..., 3)$. | |
| Returns: | |
| A dictionary with keys and values as follows: | |
| ``` | |
| { | |
| "rgb": Transformed radiance RGB colors. Shape: (..., 3), | |
| "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4) | |
| } | |
| ``` | |
| """ | |
| affine_mats = bil_grid4d(xyz, rgb) | |
| rgb = color_affine_transform(affine_mats, rgb) | |
| return {"rgb": rgb, "rgb_affine_mats": affine_mats} | |
| class _ScaledTanh(nn.Module): | |
| def __init__(self, s=2.0): | |
| super().__init__() | |
| self.scaler = s | |
| def forward(self, x): | |
| return torch.tanh(self.scaler * x) | |
| class BilateralGridCP4D(nn.Module): | |
| """Class for low-rank 4D bilateral grids.""" | |
| def __init__( | |
| self, | |
| grid_X=16, | |
| grid_Y=16, | |
| grid_Z=16, | |
| grid_W=8, | |
| rank=5, | |
| learn_gray=True, | |
| gray_mlp_width=8, | |
| gray_mlp_depth=2, | |
| init_noise_scale=1e-6, | |
| bound=2.0, | |
| ): | |
| """ | |
| Args: | |
| grid_X (int): Defines grid width. | |
| grid_Y (int): Defines grid height. | |
| grid_Z (int): Defines grid depth. | |
| grid_W (int): Defines grid guidance dimension. | |
| rank (int): Rank of the 4D bilateral grid. | |
| learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances. | |
| gray_mlp_width (int): The MLP width for learnable guidance. | |
| gray_mlp_depth (int): The number of MLP layers for learnable guidance. | |
| init_noise_scale (float): The noise scale of the initialized factors. | |
| bound (float): The bound of the xyz coordinates. | |
| """ | |
| super(BilateralGridCP4D, self).__init__() | |
| self.grid_X = grid_X | |
| """Grid width. Type: int.""" | |
| self.grid_Y = grid_Y | |
| """Grid height. Type: int.""" | |
| self.grid_Z = grid_Z | |
| """Grid depth. Type: int.""" | |
| self.grid_W = grid_W | |
| """Grid guidance dimension. Type: int.""" | |
| self.rank = rank | |
| """Rank of the 4D bilateral grid. Type: int.""" | |
| self.learn_gray = learn_gray | |
| """Flags of learnable guidance is used. Type: bool.""" | |
| self.gray_mlp_width = gray_mlp_width | |
| """The MLP width for learnable guidance. Type: int.""" | |
| self.gray_mlp_depth = gray_mlp_depth | |
| """The MLP depth for learnable guidance. Type: int.""" | |
| self.init_noise_scale = init_noise_scale | |
| """The noise scale of the initialized factors. Type: float.""" | |
| self.bound = bound | |
| """The bound of the xyz coordinates. Type: float.""" | |
| self._init_cp_factors_parafac() | |
| self.rgb2gray = None | |
| """ A function that converts RGB to gray-scale guidances in $[-1, 1]$. | |
| If `learn_gray` is True, this will be an MLP network.""" | |
| if self.learn_gray: | |
| def rgb2gray_mlp_linear(layer): | |
| return nn.Linear( | |
| self.gray_mlp_width, | |
| self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1, | |
| ) | |
| def rgb2gray_mlp_actfn(_): | |
| return nn.ReLU(inplace=True) | |
| self.rgb2gray = nn.Sequential( | |
| *( | |
| [nn.Linear(3, self.gray_mlp_width)] | |
| + [ | |
| nn_module(layer) | |
| for layer in range(1, self.gray_mlp_depth) | |
| for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear] | |
| ] | |
| + [_ScaledTanh(2.0)] | |
| ) | |
| ) | |
| else: | |
| # Weights of BT601/BT470 RGB-to-gray. | |
| self.register_buffer( | |
| "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]) | |
| ) | |
| self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 | |
| def _init_identity_grid(self): | |
| grid = torch.tensor( | |
| [ | |
| 1.0, | |
| 0, | |
| 0, | |
| 0, | |
| 0, | |
| 1.0, | |
| 0, | |
| 0, | |
| 0, | |
| 0, | |
| 1.0, | |
| 0, | |
| ] | |
| ).float() | |
| grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1]) | |
| grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1) | |
| grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X) | |
| return grid | |
| def _init_cp_factors_parafac(self): | |
| # Initialize identity grids. | |
| init_grids = self._init_identity_grid() | |
| # Random noises are added to avoid singularity. | |
| init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids | |
| from tensorly.decomposition import parafac | |
| # Initialize grid CP factors | |
| _, facs = parafac(init_grids.clone().detach(), rank=self.rank) | |
| self.num_facs = len(facs) | |
| self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False) | |
| self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank) | |
| for i in range(1, self.num_facs): | |
| fac = facs[i].T # (rank, grid_size) | |
| fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1) | |
| self.register_buffer(f"fac_{i}_init", fac) | |
| fac_resid = torch.zeros_like(fac) | |
| self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid)) | |
| def tv_loss(self): | |
| """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids.""" | |
| total_loss = 0 | |
| for i in range(1, self.num_facs): | |
| fac = self.get_parameter(f"fac_{i}") | |
| total_loss += total_variation_loss(fac) | |
| return total_loss | |
| def forward(self, xyz, rgb): | |
| """Low-rank 4D bilateral grid slicing. | |
| Args: | |
| xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. | |
| rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$. | |
| Returns: | |
| Sliced affine matrices with shape $(..., 3, 4)$. | |
| """ | |
| sh_ = xyz.shape | |
| xyz = xyz.reshape(-1, 3) # flatten (N, 3) | |
| rgb = rgb.reshape(-1, 3) # flatten (N, 3) | |
| xyz = xyz / self.bound | |
| assert self.rgb2gray is not None | |
| gray = self.rgb2gray(rgb) | |
| xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4) | |
| xyzw = xyzw.transpose(0, 1) # (4, N) | |
| coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2) | |
| coords = coords.unsqueeze(1) # (4, 1, N, 2) | |
| coef = 1.0 | |
| for i in range(1, self.num_facs): | |
| fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init") | |
| coef = coef * F.grid_sample( | |
| fac, coords[[i - 1]], align_corners=True, padding_mode="border" | |
| ) # [1, rank, 1, N] | |
| coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore | |
| mat = self.fac_0(coef) | |
| return mat.reshape(*sh_[:-1], 3, 4) | |