File size: 5,858 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""Contains utility math functions.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import Any, Callable, Literal, NamedTuple, Tuple, Union

import torch
from torch import autograd

# Names of the supported activations; ``create_activation_pair`` maps each
# name to its forward/inverse function pair.
ActivationType = Literal[
    "linear",
    "exp",
    "sigmoid",
    "softplus",
    "relu_with_pushback",
    "hard_sigmoid_with_pushback",
]
# An activation maps a tensor to a tensor (element-wise, for all types above).
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]


class ActivationPair(NamedTuple):
    """Bundle of an activation function and its inverse.

    ``inverse(forward(x))`` is expected to recover ``x`` (exactly, or only
    approximately for the pushback variants — see ``create_activation_pair``).
    """

    # Maps pre-activations to activations.
    forward: ActivationFunction
    # Maps activations back to pre-activations.
    inverse: ActivationFunction


def create_activation_pair(activation_type: ActivationType) -> ActivationPair:
    """Build the forward/inverse function pair for a named activation.

    Args:
        activation_type: Name of the activation to construct.

    Returns:
        An ``ActivationPair`` holding the activation and its inverse.

    Raises:
        ValueError: If ``activation_type`` is not a supported name.
    """
    # Each entry is a zero-argument builder so that name lookup happens only
    # for the branch actually requested, matching the original if/elif chain.
    builders = {
        "linear": lambda: ActivationPair(lambda x: x, lambda x: x),
        "exp": lambda: ActivationPair(torch.exp, torch.log),
        "sigmoid": lambda: ActivationPair(torch.sigmoid, inverse_sigmoid),
        "softplus": lambda: ActivationPair(
            torch.nn.functional.softplus, inverse_softplus
        ),
        # NOTE(review): identity / plain affine maps are only approximate
        # inverses of the pushback activations (they ignore the clamping).
        "relu_with_pushback": lambda: ActivationPair(relu_with_pushback, lambda x: x),
        "hard_sigmoid_with_pushback": lambda: ActivationPair(
            hard_sigmoid_with_pushback, lambda x: 6.0 * x - 3.0
        ),
    }
    if activation_type not in builders:
        raise ValueError(f"Unsupported activation function: {activation_type}.")
    return builders[activation_type]()


def inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
    """Apply the logit function, the inverse of ``torch.sigmoid``.

    No clamping is performed; the caller must ensure values lie in (0, 1).
    """
    odds = tensor / (1.0 - tensor)
    return torch.log(odds)


def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
    """Invert ``torch.nn.functional.softplus``.

    Uses the identity softplus^{-1}(y) = y + log(1 - e^{-y}), where e^{-y}
    is obtained from the sigmoid odds sigmoid(-y) / (1 - sigmoid(-y)).

    Args:
        tensor: Softplus outputs; clamped below at ``eps`` before inversion
            (softplus outputs are strictly positive).
        eps: Lower bound applied to ``tensor``.

    Returns:
        The pre-activation values.
    """
    clamped = tensor.clamp_min(eps)
    neg_sigmoid = torch.sigmoid(-clamped)
    exp_neg = neg_sigmoid / (1.0 - neg_sigmoid)
    return clamped + torch.log(1.0 - exp_neg)


# A (threshold, limit) pair describing one side of a soft clamp: values past
# the threshold are smoothly squashed toward the limit (see ``softclamp``).
SoftClampRange = Tuple[Union[torch.Tensor, float], Union[torch.Tensor, float]]


def softclamp(
    tensor: torch.Tensor,
    min: SoftClampRange | None = None,
    max: SoftClampRange | None = None,
) -> torch.Tensor:
    """Clamp tensor to min/max in differentiable way.

    Args:
        tensor: The tensor to clamp.
        min: Pair of threshold to start clamping and value to clamp to.
            The first value should be larger than the second.
        max: Pair of threshold to start clamping and value to clamp to.
            The first value should be smaller than the second.

    Returns:
        The clamped tensor.
    """

    def normalize(clamp_range: SoftClampRange) -> torch.Tensor:
        value0, value1 = clamp_range
        return value0 + (value1 - value0) * torch.tanh((tensor - value0) / (value1 - value0))

    tensor_clamped = tensor
    if min is not None:
        tensor_clamped = torch.maximum(tensor_clamped, normalize(min))
    if max is not None:
        tensor_clamped = torch.minimum(tensor_clamped, normalize(max))

    return tensor_clamped


class ClampWithPushback(autograd.Function):
    """Implementation of clamp_with_pushback function."""

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        min: float | None,
        max: float | None,
        pushback: float,
    ) -> torch.Tensor:
        """Apply clamp."""
        if min is not None and max is not None and min >= max:
            raise ValueError("Only min < max is supported.")

        ctx.save_for_backward(tensor)
        ctx.min = min
        ctx.max = max
        ctx.pushback = pushback
        return torch.clamp(tensor, min=min, max=max)

    @staticmethod
    def backward(  # type: ignore[override] # Deal with buggy torch annotations.
        ctx: Any, grad_in: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None]:
        """Compute gradient of clamp with pushback."""
        grad_out = grad_in.clone()
        (tensor,) = ctx.saved_tensors

        if ctx.min is not None:
            mask_min = tensor < ctx.min
            grad_out[mask_min] = -ctx.pushback

        if ctx.max is not None:
            mask_max = tensor > ctx.max
            grad_out[mask_max] = ctx.pushback

        return grad_out, None, None, None


def clamp_with_pushback(
    tensor: torch.Tensor,
    min: float | None = None,
    max: float | None = None,
    pushback: float = 1e-2,
) -> torch.Tensor:
    """Clamp without the vanishing-gradient problem of a hard clamp.

    The forward pass behaves exactly like ``torch.clamp``; the backward pass
    replaces the zero gradient in clamped regions with a constant of
    magnitude ``pushback``. This is equivalent to adding a regularizer

        pushback * sum_i (
            relu(min - preactivation_i) + relu(preactivation_i - max)
        )

    to the full loss function, which pushes clamped values back.

    When used in minimization problems, pushback should be greater than
    zero. In maximization problems, pushback should be smaller than zero.
    """
    result = ClampWithPushback.apply(tensor, min, max, pushback)
    # Narrow autograd's loosely-typed ``apply`` return for type checkers.
    assert isinstance(result, torch.Tensor)
    return result


def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor:
    """Hard sigmoid whose clamped regions still receive a pushback gradient.

    For compatibility reasons, we follow the default PyTorch parameterization
    (slope 1/6, offset 0.5):

        https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    """
    affine = slope * x + 0.5
    return clamp_with_pushback(affine, min=0.0, max=1.0)


def relu_with_pushback(x: torch.Tensor) -> torch.Tensor:
    """ReLU variant whose negative region receives a constant pushback gradient."""
    return clamp_with_pushback(x, min=0.0)