File size: 1,145 Bytes
5a87d8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import sys
sys.path.append("./BranchSBM")
import torch.nn as nn
import torch
from typing import List, Optional
class swish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(x), elementwise (a.k.a. SiLU)."""

    def forward(self, x):
        # Sigmoid acts as a smooth, self-referential gate on the input.
        gate = torch.sigmoid(x)
        return gate * x
# Maps a config string to an activation *class* (not an instance);
# callers such as SimpleDenseNet instantiate the looked-up class once
# per hidden layer. "swish" points at the local swish module above —
# nn.SiLU ("silu") computes the same function.
ACTIVATION_MAP = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "elu": nn.ELU,
    "lrelu": nn.LeakyReLU,
    "softplus": nn.Softplus,
    "silu": nn.SiLU,
    "swish": swish,
}
class SimpleDenseNet(nn.Module):
    """Fully-connected feed-forward network (MLP).

    Builds ``Linear -> [BatchNorm1d] -> activation`` for each hidden layer,
    followed by a final ``Linear`` projection to ``target_size`` with no
    activation (suitable as a raw-output / logit head).

    Args:
        input_size: Dimensionality of the input features.
        target_size: Dimensionality of the output.
        activation: Key into ``ACTIVATION_MAP`` selecting the hidden-layer
            activation class. Raises ``KeyError`` for unknown names.
        batch_norm: If True, insert ``BatchNorm1d`` after each hidden linear
            layer (before the activation).
        hidden_dims: Sizes of the hidden layers. ``None`` (the default) means
            no hidden layers, i.e. a single input->target linear map.
    """

    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = False,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        # Fix: the original unpacked `hidden_dims` unconditionally, so the
        # declared default of None crashed with TypeError. Treat None as
        # "no hidden layers".
        if hidden_dims is None:
            hidden_dims = []
        dims = [input_size, *hidden_dims, target_size]
        layers: List[nn.Module] = []
        # One (Linear, optional BatchNorm, activation) group per hidden layer.
        for i in range(len(dims) - 2):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if batch_norm:
                layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(ACTIVATION_MAP[activation]())
        # Output projection: no activation, callers get raw values.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` of shape (batch, input_size)."""
        return self.model(x)
|