"""
A collection of common activation functions.
"""

import torch
class LearnedActivation(torch.nn.Module):
    """A learned activation function: a one-hidden-layer MLP with ReLU,
    applied element-wise to the input tensor."""

    def __init__(self, hidden_size=10):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, 1)
        self.initialize_weights()

    def initialize_weights(self):
        # Hard-coded weights and biases; the tensor shapes below assume
        # hidden_size=10.
        self.fc1.weight.data = torch.tensor([[-0.3478],
                                             [-0.3444],
                                             [-0.9863],
                                             [-0.8657],
                                             [-0.0148],
                                             [ 0.1085],
                                             [-0.5282],
                                             [-0.1138],
                                             [-1.1070],
                                             [-0.1035]])
        self.fc1.bias.data = torch.tensor([ 1.4480,  1.4610, -0.8526,  0.0151, -0.1249,
                                           -0.7658,  2.2386, -0.8884,  1.0032, -0.6235])
        self.fc2.weight.data = torch.tensor([[-0.4762, -1.2194,  0.4155,  0.3927, -0.2778,
                                               0.0986, -0.9284,  0.2070,  0.3586, -0.2143]])
        self.fc2.bias.data = torch.tensor([4.1740])

    def forward(self, x):
        # Treat each element of x as an independent scalar: flatten to a
        # column vector, run it through fc1 -> ReLU -> fc2, then restore
        # the original shape.
        orig_shape = x.shape
        x = x.reshape(-1, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x.view(orig_shape)
|
# Shared activation instances, constructed once at import time. In particular,
# "learned" always maps to the same LearnedActivation object, so its parameters
# are shared by every caller that requests it.
ACTIVATIONS_DICT = {
    "gelu": torch.nn.GELU(),
    "relu": torch.nn.ReLU(),
    "leakyrelu": torch.nn.LeakyReLU(),
    "tanh": torch.nn.Tanh(),
    "sigmoid": torch.nn.Sigmoid(),
    "silu": torch.nn.SiLU(),
    "learned": LearnedActivation(hidden_size=10),
    "none": torch.nn.Identity(),
}
|
def build_activation(activation_name: str) -> torch.nn.Module:
    """
    Look up an activation module by name (case-insensitive).

    Args:
        activation_name: Key into ACTIVATIONS_DICT, e.g. "gelu", "relu",
            "learned", or "none".
    Returns:
        activation: The corresponding torch.nn.Module instance.
    Raises:
        KeyError: If the lowercased name is not a key of ACTIVATIONS_DICT.
    """
    return ACTIVATIONS_DICT[activation_name.lower()]
|
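# Minimal usage sketch (illustrative only, not part of the module's API): build
# activations by name and apply them to a tensor. The shapes and names below
# are arbitrary example inputs.
if __name__ == "__main__":
    x = torch.randn(2, 3)

    gelu = build_activation("gelu")
    print(gelu(x).shape)     # torch.Size([2, 3])

    learned = build_activation("learned")
    print(learned(x).shape)  # torch.Size([2, 3]); output keeps the input shape

    # Lookups return the shared instance stored in ACTIVATIONS_DICT.
    assert build_activation("learned") is learned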