"""Contains utility math functions.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import Any, Callable, Literal, NamedTuple, Tuple, Union
import torch
from torch import autograd
ActivationType = Literal[
"linear",
"exp",
"sigmoid",
"softplus",
"relu_with_pushback",
"hard_sigmoid_with_pushback",
]

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]


class ActivationPair(NamedTuple):
    """A pair of forward and inverse activation functions."""

    forward: ActivationFunction
    inverse: ActivationFunction


def create_activation_pair(activation_type: ActivationType) -> ActivationPair:
    """Create an activation function and its corresponding inverse function.

    Args:
        activation_type: The activation type to create.

    Returns:
        The activation function and the corresponding inverse function.
    """
    if activation_type == "linear":
        return ActivationPair(lambda x: x, lambda x: x)
    elif activation_type == "exp":
        return ActivationPair(torch.exp, torch.log)
    elif activation_type == "sigmoid":
        return ActivationPair(torch.sigmoid, inverse_sigmoid)
    elif activation_type == "softplus":
        return ActivationPair(torch.nn.functional.softplus, inverse_softplus)
    elif activation_type == "relu_with_pushback":
        # relu_with_pushback has no exact inverse; the identity is returned instead.
        return ActivationPair(relu_with_pushback, lambda x: x)
    elif activation_type == "hard_sigmoid_with_pushback":
        # Inverse of the unclamped hard sigmoid x / 6 + 0.5, i.e. 6 * x - 3.
        return ActivationPair(hard_sigmoid_with_pushback, lambda x: 6.0 * x - 3.0)
    else:
        raise ValueError(f"Unsupported activation function: {activation_type}.")
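

# Usage sketch (illustrative, not part of the original module): for invertible
# activation types the returned pair recovers the input up to floating-point error.
#
#     forward_fn, inverse_fn = create_activation_pair("sigmoid")
#     x = torch.linspace(-3.0, 3.0, steps=5)
#     torch.testing.assert_close(inverse_fn(forward_fn(x)), x)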


def inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
    """Compute inverse sigmoid."""
    return torch.log(tensor / (1.0 - tensor))


def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
    """Compute inverse softplus."""
    tensor = tensor.clamp_min(eps)
    # Numerically stable evaluation of log(exp(y) - 1) = y + log(1 - exp(-y)),
    # where exp(-y) is obtained as sigmoid(-y) / (1 - sigmoid(-y)).
    sigmoid = torch.sigmoid(-tensor)
    exp = sigmoid / (1.0 - sigmoid)
    return tensor + torch.log(-exp + 1.0)


# The first value describes the threshold from where clamping will be applied, while
# the second value describes the value to clamp with.
SoftClampRange = Tuple[Union[torch.Tensor, float], Union[torch.Tensor, float]]


def softclamp(
    tensor: torch.Tensor,
    min: SoftClampRange | None = None,
    max: SoftClampRange | None = None,
) -> torch.Tensor:
    """Clamp tensor to min/max in a differentiable way.

    Args:
        tensor: The tensor to clamp.
        min: Pair of threshold to start clamping and value to clamp to.
            The first value should be larger than the second.
        max: Pair of threshold to start clamping and value to clamp to.
            The first value should be smaller than the second.

    Returns:
        The clamped tensor.
    """

    def normalize(clamp_range: SoftClampRange) -> torch.Tensor:
        value0, value1 = clamp_range
        return value0 + (value1 - value0) * torch.tanh((tensor - value0) / (value1 - value0))

    tensor_clamped = tensor
    if min is not None:
        tensor_clamped = torch.maximum(tensor_clamped, normalize(min))
    if max is not None:
        tensor_clamped = torch.minimum(tensor_clamped, normalize(max))
    return tensor_clamped
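

# Usage sketch (illustrative, not part of the original module): values below the
# threshold -1.0 saturate smoothly towards -2.0 and values above 1.0 towards 2.0,
# so gradients in the clamped regions stay non-zero (unlike a hard clamp).
#
#     x = torch.linspace(-5.0, 5.0, steps=11, requires_grad=True)
#     y = softclamp(x, min=(-1.0, -2.0), max=(1.0, 2.0))
#     y.sum().backward()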


class ClampWithPushback(autograd.Function):
    """Implementation of the clamp_with_pushback function."""

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        min: float | None,
        max: float | None,
        pushback: float,
    ) -> torch.Tensor:
        """Apply clamp."""
        if min is not None and max is not None and min >= max:
            raise ValueError("Only min < max is supported.")
        ctx.save_for_backward(tensor)
        ctx.min = min
        ctx.max = max
        ctx.pushback = pushback
        return torch.clamp(tensor, min=min, max=max)

    @staticmethod
    def backward(  # type: ignore[override] # Deal with buggy torch annotations.
        ctx: Any, grad_in: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None]:
        """Compute gradient of clamp with pushback."""
        grad_out = grad_in.clone()
        (tensor,) = ctx.saved_tensors
        if ctx.min is not None:
            # Below the lower bound, replace the zero clamp gradient with a constant
            # that (under gradient descent) pushes the value back up towards min.
            mask_min = tensor < ctx.min
            grad_out[mask_min] = -ctx.pushback
        if ctx.max is not None:
            # Above the upper bound, push the value back down towards max.
            mask_max = tensor > ctx.max
            grad_out[mask_max] = ctx.pushback
        return grad_out, None, None, None


def clamp_with_pushback(
    tensor: torch.Tensor,
    min: float | None = None,
    max: float | None = None,
    pushback: float = 1e-2,
) -> torch.Tensor:
    """Variant of the clamp function which avoids the vanishing gradient problem.

    This function is equivalent to adding a regularizer of the form

        pushback * sum_i (
            relu(min - preactivation_i) + relu(preactivation_i - max)
        )

    to the full loss function, which pushes clamped values back.

    When used in minimization problems, pushback should be greater than
    zero. In maximization problems, pushback should be smaller than zero.
    """
    output = ClampWithPushback.apply(tensor, min, max, pushback)
    assert isinstance(output, torch.Tensor)
    return output
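

# Gradient sketch (illustrative, not part of the original module): outside the
# clamp range the gradient becomes +/- pushback instead of zero, so out-of-range
# values keep receiving a restoring signal.
#
#     x = torch.tensor([-2.0, 0.5, 2.0], requires_grad=True)
#     clamp_with_pushback(x, min=0.0, max=1.0, pushback=1e-2).sum().backward()
#     # x.grad == [-0.01, 1.0, 0.01] (up to float precision).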


def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor:
    """Apply hard sigmoid with pushback.

    For compatibility reasons, we follow the default PyTorch implementation with a
    default slope of 1/6:
    https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    """
    return clamp_with_pushback(slope * x + 0.5, min=0.0, max=1.0)


def relu_with_pushback(x: torch.Tensor) -> torch.Tensor:
    """Compute relu with pushback."""
    return clamp_with_pushback(x, min=0.0)
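

# Usage sketch (illustrative, not part of the original module): the pushback
# variants match their standard counterparts in the forward pass, but clamped
# inputs still receive a small corrective gradient.
#
#     x = torch.tensor([-2.0, 3.0], requires_grad=True)
#     relu_with_pushback(x).sum().backward()
#     # x.grad == [-0.01, 1.0]; a plain relu would give [0.0, 1.0].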