BranchSBM / networks /growth_mlp.py
sophtang's picture
update
670c065 verified
raw
history blame contribute delete
993 Bytes
import sys
sys.path.append("./BranchSBM")
import torch
import torch.nn as nn
from typing import List, Optional
from networks.mlp_base import SimpleDenseNet
class GrowthNet(SimpleDenseNet):
    """Time-conditioned dense network with a sign-constrained scalar output.

    The underlying ``SimpleDenseNet`` is built with ``dim + 1`` input features
    (one time column concatenated in front of the ``dim`` state features) and a
    single output feature. The raw output is passed through a softplus, so the
    result is positive; with ``negative=True`` the sign is flipped.
    """

    def __init__(
        self,
        dim: int,
        activation: str,
        hidden_dims: Optional[List[int]] = None,
        batch_norm: bool = False,
        negative: bool = False,
    ):
        """Build the network.

        Args:
            dim: Number of state features in ``x`` (the time column is added
                on top of this).
            activation: Activation identifier, forwarded unchanged to
                ``SimpleDenseNet``.
            hidden_dims: Hidden layer widths, forwarded unchanged to
                ``SimpleDenseNet``; ``None`` defers to whatever the base class
                does for a missing value.
            batch_norm: Forwarded unchanged to ``SimpleDenseNet``.
            negative: If true, ``forward`` returns the negated softplus output.
        """
        super().__init__(
            input_size=dim + 1,
            target_size=1,
            activation=activation,
            batch_norm=batch_norm,
            hidden_dims=hidden_dims,
        )
        self.softplus = nn.Softplus()
        self.negative = negative

    def forward(self, t, x):
        """Evaluate the network at time ``t`` and state ``x``.

        Args:
            t: Time input. Accepted forms are a Python number, a tensor holding
                a single element (any number of dims; it is broadcast over the
                batch), a 1-D tensor with one entry per sample, or a tensor
                that is already two-dimensional, which is used as-is.
            x: State tensor whose first dimension is the batch dimension.
                Presumably of shape ``(batch, dim)`` -- confirm against callers.

        Returns:
            ``softplus(model([t, x]))``, negated when ``self.negative`` is set.
        """
        batch_size = x.shape[0]

        # Plain Python numbers have no tensor API; lift them onto x's
        # dtype/device so they can be concatenated with x below.
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)

        if t.numel() == 1:
            # A single time value shared by the whole batch. Going through
            # reshape keeps this working for shapes such as (), (1,) and
            # (1, 1); ``Tensor.repeat`` with a single size argument rejects
            # tensors that have more than one dimension.
            t = t.reshape(1, 1).expand(batch_size, 1)
        elif t.dim() == 1:
            # One time value per sample: add the feature axis.
            t = t[:, None]
        # Otherwise t already carries a feature axis and is left untouched;
        # a genuine batch-size mismatch surfaces in torch.cat below.

        # self.model is expected to be provided by SimpleDenseNet.
        out = self.softplus(self.model(torch.cat([t, x], dim=-1)))
        return -out if self.negative else out