| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the CC-by-NC license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Optional | |
| import torch | |
| from torch import Tensor | |
def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
    """
    Add singleton dimensions to ``source`` until it has as many dims as ``target``.

    Args:
        source (Tensor): Tensor that receives the extra size-1 dimensions.
        target (Tensor): Tensor whose number of dimensions ``source`` should reach.
        how (str, optional): ``"prefix"`` inserts the new axes at the front,
            ``"suffix"`` appends them at the back. Defaults to ``"suffix"``.

    Returns:
        Tensor: ``source`` with ``target.dim() - source.dim()`` new size-1 axes.
        If ``source`` already has at least as many dims as ``target``, it is
        returned unchanged (the same object).

    Raises:
        AssertionError: If ``how`` is neither ``"prefix"`` nor ``"suffix"``
            (only while Python assertions are enabled).
    """
    assert how in (
        "prefix",
        "suffix",
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    # A non-positive count makes the loops below no-ops.
    n_new_axes = target.dim() - source.dim()

    # Branch once, outside the loop: the insertion position never changes.
    if how == "prefix":
        for _ in range(n_new_axes):
            source = source.unsqueeze(0)
    elif how == "suffix":
        for _ in range(n_new_axes):
            source = source.unsqueeze(-1)
    return source
def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
    """Broadcast a per-sample 1d vector across every trailing dim of ``expand_to``.

    Entry ``i`` of ``input_tensor`` is repeated over all non-batch positions of
    sample ``i`` in ``expand_to``.

    Args:
        input_tensor (Tensor): (batch_size,).
        expand_to (Tensor): (batch_size, ...).

    Returns:
        Tensor: (batch_size, ...), an expanded view of a *copy* of
        ``input_tensor`` (the ``clone`` keeps it from sharing storage with
        the caller's tensor).

    Raises:
        AssertionError: If ``input_tensor`` is not 1d, or the leading
            (batch) dimensions differ (only while assertions are enabled).
    """
    assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
    assert (
        input_tensor.shape[0] == expand_to.shape[0]
    ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

    # One trailing singleton per extra dim of `expand_to`, so expand_as can broadcast.
    singleton_axes = [1] * (expand_to.ndim - input_tensor.ndim)
    column = input_tensor.clone().reshape(-1, *singleton_axes)
    return column.expand_as(expand_to)
def gradient(
    output: Tensor,
    x: Tensor,
    grad_outputs: Optional[Tensor] = None,
    create_graph: bool = False,
) -> Tensor:
    """
    Vector-Jacobian product of ``output`` w.r.t. ``x``, i.e. the gradient of
    the inner product ``<output, grad_outputs>`` with respect to :math:`x`.

    Args:
        output (Tensor): [N, D] Output of the function.
        x (Tensor): [N, d_1, d_2, ... ] input the gradient is taken against.
        grad_outputs (Optional[Tensor]): [N, D] Vector multiplied against the
            Jacobian. When ``None``, a tensor of ones shaped like ``output``
            is used.
        create_graph (bool): If True, the derivative's own graph is built so
            higher-order derivatives can be taken. Defaults to False.

    Returns:
        Tensor: [N, d_1, d_2, ... ] the gradient w.r.t. ``x``.
    """
    seed = torch.ones_like(output).detach() if grad_outputs is None else grad_outputs
    grads = torch.autograd.grad(
        outputs=output,
        inputs=x,
        grad_outputs=seed,
        create_graph=create_graph,
    )
    # autograd.grad returns one gradient per input; only `x` was requested.
    return grads[0]
Xet Storage Details
- Size:
- 2.98 kB
- Xet hash:
- 85246f42a03894cd7626703b099d6a3bf4ea25e2bfd8f39440fdcc4b74f75d15
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.