| |
|
| | """ |
| | A collection of common activation functions. |
| | """ |
| |
|
| | import torch |
| |
|
class LearnedActivation(torch.nn.Module):
    """Element-wise activation implemented by a tiny 1 -> hidden -> 1 MLP.

    Every scalar of the input is passed independently through
    ``fc2(relu(fc1(x)))``, so the output has the same shape as the input.

    With the default ``hidden_size=10`` the layers are initialised from a set
    of hard-coded pretrained weights. Those weights only fit a hidden width of
    10; for any other width PyTorch's default ``Linear`` initialisation is kept.

    Args:
        hidden_size: Width of the hidden layer.
    """

    # Hard-coded pretrained parameters for the hidden_size=10 configuration.
    # Tuples (not lists) so the shared class-level constants cannot be mutated.
    _FC1_WEIGHT = (
        (-0.3478,), (-0.3444,), (-0.9863,), (-0.8657,), (-0.0148,),
        (0.1085,), (-0.5282,), (-0.1138,), (-1.1070,), (-0.1035,),
    )
    _FC1_BIAS = (
        1.4480, 1.4610, -0.8526, 0.0151, -0.1249,
        -0.7658, 2.2386, -0.8884, 1.0032, -0.6235,
    )
    _FC2_WEIGHT = (
        (-0.4762, -1.2194, 0.4155, 0.3927, -0.2778,
         0.0986, -0.9284, 0.2070, 0.3586, -0.2143),
    )
    _FC2_BIAS = (4.1740,)

    def __init__(self, hidden_size: int = 10):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, 1)
        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Load the pretrained weights when the hidden width matches them.

        The values are copied in place under ``torch.no_grad()``. This keeps
        the existing ``Parameter`` objects and their shape, dtype and device.
        Rebinding ``.data`` instead could silently change the layer width.
        When the hidden width is not the pretrained width, this does nothing
        and the default initialisation is kept.
        """
        if self.fc1.out_features != len(self._FC1_BIAS):
            return
        with torch.no_grad():
            self.fc1.weight.copy_(torch.tensor(self._FC1_WEIGHT))
            self.fc1.bias.copy_(torch.tensor(self._FC1_BIAS))
            self.fc2.weight.copy_(torch.tensor(self._FC2_WEIGHT))
            self.fc2.bias.copy_(torch.tensor(self._FC2_BIAS))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the learned scalar function to every element of ``x``.

        Args:
            x: Tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
        flat = x.reshape(-1, 1)
        out = self.fc2(torch.relu(self.fc1(flat)))
        return out.reshape(orig_shape)
| |
|
| |
|
# Registry mapping lower-case activation names to ready-made module instances.
# NOTE(review): each entry is a single instance created once at import time.
# The "learned" entry carries trainable parameters, so reusing this exact
# object in several layers would tie their weights together.
ACTIVATIONS_DICT = dict(
    gelu=torch.nn.GELU(),
    relu=torch.nn.ReLU(),
    leakyrelu=torch.nn.LeakyReLU(),
    tanh=torch.nn.Tanh(),
    sigmoid=torch.nn.Sigmoid(),
    silu=torch.nn.SiLU(),
    learned=LearnedActivation(hidden_size=10),
    none=torch.nn.Identity(),
)
| |
|
| |
|
def build_activation(activation_name: str) -> torch.nn.Module:
    """Build a fresh activation module from its name.

    The lookup is case-insensitive. Each call returns an independent deep
    copy of the entry registered in ``ACTIVATIONS_DICT``. As a result, modules
    with parameters (e.g. ``"learned"``) are not shared between the layers
    that request them.

    Args:
        activation_name: Name of the activation, e.g. ``"relu"`` or ``"GELU"``.

    Returns:
        A new ``torch.nn.Module`` implementing the requested activation.

    Raises:
        KeyError: If ``activation_name`` is not a registered activation.
    """
    import copy

    try:
        prototype = ACTIVATIONS_DICT[activation_name.lower()]
    except KeyError:
        available = ", ".join(sorted(ACTIVATIONS_DICT))
        raise KeyError(
            f"Unknown activation {activation_name!r}; expected one of: {available}"
        ) from None
    # Deep-copy so every caller owns its module (and its parameters) instead of
    # receiving the single instance stored in the module-level registry.
    return copy.deepcopy(prototype)
| |
|