"""Base Module."""
# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#
# flake8: noqa
# pylint: skip-file
# type: ignore
# pydocstyle: noqa
from typing import Iterable, List, Tuple

import torch.nn as nn
from torch import Tensor


class InvertibleModule(nn.Module):
    r"""Base class for all invertible modules in FrEIA.

    Let ``module`` be an instance of some ``InvertibleModule``.
    This ``module`` shall be invertible in its input dimensions,
    so that the input can be recovered by applying the module
    in backward mode (``rev=True``), not to be confused with
    ``torch.Tensor.backward()``, which computes gradients::

        x = torch.randn(BATCH_SIZE, DIM_COUNT)
        c = torch.randn(BATCH_SIZE, CONDITION_DIM)

        # Forward mode
        z, jac = module([x], [c], jac=True)

        # Backward mode
        x_rev, jac_rev = module(z, [c], rev=True)

    The ``module`` returns the log absolute Jacobian determinant
    :math:`\log | \det J | = \log \left| \det \frac{\partial f}{\partial x} \right|`
    of the operation in forward mode, and
    :math:`-\log | \det J | = \log \left| \det \frac{\partial f^{-1}}{\partial z} \right| = -\log \left| \det \frac{\partial f}{\partial x} \right|`
    in backward mode (``rev=True``).

    Since outputs are returned as tuples of tensors,
    ``torch.allclose(x, x_rev[0]) == True`` and ``torch.allclose(jac, -jac_rev) == True``.
    """

    def __init__(self, dims_in: Iterable[Tuple[int, ...]], dims_c: Iterable[Tuple[int, ...]] = None):
        """Initialize.

        Args:
            dims_in: list of tuples specifying the shape of the inputs to this
                operator, excluding the batch dimension:
                ``dims_in = [shape_x_0, shape_x_1, ...]``
            dims_c: list of tuples specifying the shape of the conditions to
                this operator, excluding the batch dimension. Defaults to no
                conditions.
        """
        super().__init__()
        if dims_c is None:
            dims_c = []
        self.dims_in = list(dims_in)
        self.dims_c = list(dims_c)

    def forward(
        self, x_or_z: Iterable[Tensor], c: Iterable[Tensor] = None, rev: bool = False, jac: bool = True
    ) -> Tuple[Tuple[Tensor, ...], Tensor]:
        r"""Forward/Backward Pass.

        Perform a forward (default, ``rev=False``) or backward pass (``rev=True``) through this module/operator.

        **Note to implementers:**

        - Subclasses MUST return a Jacobian when ``jac=True``. They MAY also
          return a valid Jacobian when ``jac=False``; this is only recommended
          if computing the Jacobian is trivial.
        - Subclasses MUST return a Jacobian that is consistent with the
          evaluation direction. More precisely, let :math:`f` be the function
          that the subclass represents. Then:

          .. math::

              J &= \log \left| \det \frac{\partial f}{\partial x} \right| \\
              -J &= \log \left| \det \frac{\partial f^{-1}}{\partial z} \right|.

          Any subclass MUST return :math:`J` for forward evaluation (``rev=False``),
          and :math:`-J` for backward evaluation (``rev=True``).

        Args:
            x_or_z: input data (array-like of one or more tensors)
            c: conditioning data (array-like of zero or more tensors)
            rev: perform backward pass
            jac: return the log-Jacobian determinant associated with the
                evaluation direction

        Returns:
            A tuple ``(out, jac)``: the tuple of output tensors and the
            log-Jacobian determinant for the chosen direction.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not provide forward(...) method")

    def log_jacobian(self, *args, **kwargs):
        """Deprecated: calling this method raises a ``DeprecationWarning``."""
        raise DeprecationWarning(
            "module.log_jacobian(...) is deprecated. "
            "module.forward(..., jac=True) returns a "
            "tuple (out, jacobian) now."
        )

    def output_dims(self, input_dims: List[Tuple[int, ...]]) -> List[Tuple[int, ...]]:
        """Used for shape inference during construction of the graph.

        MUST be implemented for each subclass of ``InvertibleModule``.

        Args:
            input_dims: A list with one entry for each input to the module.
                Even if the module only has one input, this must be a list with
                one entry. Each entry is a tuple giving the shape of that input,
                excluding the batch dimension. For example, for a module with one
                input that receives a 32x32 pixel RGB image, ``input_dims`` would
                be ``[(3, 32, 32)]``.

        Returns:
            A list structured in the same way as ``input_dims``. Each entry
            represents one output of the module, and the entry is a tuple giving
            the shape of that output. For example, if the module splits the image
            into a left and a right half along the width, the return value should
            be ``[(3, 32, 16), (3, 32, 16)]``. It is up to the implementer of the
            subclass to ensure that the total number of elements in all inputs
            and all outputs is consistent.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not provide output_dims(...)")
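

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, NOT part of FrEIA's API).
#
# The two private functions below restate the contracts documented in
# ``InvertibleModule`` using plain Python floats instead of tensors, so the
# conventions can be checked without torch. Assumptions made for illustration
# only: an elementwise affine map y = s * x + t, whose log-Jacobian determinant
# is sum(log|s_i|), and a hypothetical module that splits its single input in
# half along the last axis.
# ---------------------------------------------------------------------------
import math


def _toy_affine(xs, scales, shifts, rev=False):
    """Hypothetical elementwise affine map following the sign convention.

    Forward returns ``(y, J)`` with ``J = sum(log|s_i|)``; backward
    (``rev=True``) inverts the map and returns ``(x, -J)``.
    """
    log_det = sum(math.log(abs(s)) for s in scales)
    if not rev:
        ys = [s * x + t for x, s, t in zip(xs, scales, shifts)]
        return ys, log_det
    # Invert y = s * x + t elementwise; the log-determinant flips sign.
    xs_rec = [(y - t) / s for y, s, t in zip(xs, scales, shifts)]
    return xs_rec, -log_det


def _toy_split_output_dims(input_dims):
    """Hypothetical ``output_dims`` for a module halving the last axis."""
    (shape,) = input_dims  # exactly one input, batch dimension excluded
    *lead, width = shape
    if width % 2:
        raise ValueError(f"last axis must be even to split in half, got {width}")
    half = (*lead, width // 2)
    out = [half, half]
    # The total number of elements must be preserved across the split.
    assert sum(math.prod(s) for s in out) == math.prod(shape)
    return out


# Usage: _toy_split_output_dims([(3, 32, 32)]) gives [(3, 32, 16), (3, 32, 16)],
# and feeding the output of _toy_affine(..., rev=False) back with rev=True
# recovers the input while returning the negated log-determinant.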