File size: 993 Bytes
5a87d8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import sys
sys.path.append("./BranchSBM")
import torch
import torch.nn as nn
from typing import List, Optional
from networks.mlp_base import SimpleDenseNet
class GrowthNet(SimpleDenseNet):
    """Time-conditioned network producing a sign-constrained scalar output.

    The time ``t`` is concatenated in front of the state ``x`` and the result
    is passed through the dense network built by ``SimpleDenseNet`` (requested
    here with ``input_size=dim + 1`` and ``target_size=1``). The raw network
    output goes through a softplus, and is negated when ``negative`` is set,
    so the returned value has a fixed sign.

    NOTE(review): ``self.model`` is assumed to be created by
    ``SimpleDenseNet.__init__``; confirm against ``networks.mlp_base``.
    """

    def __init__(
        self,
        dim: int,
        activation: str,
        hidden_dims: Optional[List[int]] = None,
        batch_norm: bool = False,
        negative: bool = False,
    ):
        """Build the underlying dense network and the output transform.

        Args:
            dim: Feature size of the state ``x``; the network input is
                ``dim + 1`` wide because the time is prepended.
            activation: Activation identifier forwarded to ``SimpleDenseNet``.
            hidden_dims: Hidden layer sizes forwarded to ``SimpleDenseNet``;
                ``None`` is passed through unchanged.
            batch_norm: Forwarded to ``SimpleDenseNet``.
            negative: When true, ``forward`` returns the negated softplus
                output instead of the softplus output itself.
        """
        super().__init__(
            input_size=dim + 1,
            target_size=1,
            activation=activation,
            batch_norm=batch_norm,
            hidden_dims=hidden_dims,
        )
        self.softplus = nn.Softplus()
        self.negative = negative

    def forward(self, t, x):
        """Evaluate the network at time ``t`` for the batch of states ``x``.

        Args:
            t: Time input. May be a Python number, a tensor holding a single
                value (any rank, e.g. shape ``()``, ``(1,)`` or ``(1, 1)``),
                a 1-D tensor with one entry per row of ``x``, or a 2-D tensor
                whose first dimension matches ``x``.
            x: State tensor; ``x.shape[0]`` is used as the batch size and the
                time column is concatenated along the last dimension.

        Returns:
            ``softplus(model([t, x]))``, negated when ``self.negative`` is
            true.
        """
        batch_size = x.shape[0]

        # Accept plain Python numbers by moving them onto x's dtype/device.
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)

        if t.dim() < 1 or t.shape[0] != batch_size:
            # A single shared time value is flattened first: Tensor.repeat
            # refuses fewer repeat dims than tensor dims, so a (1, 1) time
            # would otherwise raise instead of being broadcast.
            if t.numel() == 1:
                t = t.reshape(1)
            t = t.repeat(batch_size)[:, None]
        if t.dim() < 2:
            # One time per sample given as a flat vector -> column vector.
            t = t[:, None]

        features = torch.cat([t, x], dim=-1)
        out = self.softplus(self.model(features))
        return -out if self.negative else out