import sys
from typing import List, Optional

import torch
import torch.nn as nn

# Presumably makes the BranchSBM source tree importable so that
# `networks.mlp_base` resolves below; the path is relative to the current
# working directory, not to this file — confirm this is intended.
sys.path.append("./BranchSBM")

from networks.mlp_base import SimpleDenseNet  # noqa: E402


class GrowthNet(SimpleDenseNet):
    """Time-conditioned MLP producing one sign-constrained value per sample.

    The underlying ``SimpleDenseNet`` takes ``dim + 1`` inputs (the time ``t``
    concatenated in front of the state ``x``) and has ``target_size=1``. Its
    raw output is passed through ``Softplus``, so the result is non-negative;
    with ``negative=True`` the result is negated, so it is non-positive.
    """

    def __init__(
        self,
        dim: int,
        activation: str,
        hidden_dims: Optional[List[int]] = None,
        batch_norm: bool = False,
        negative: bool = False,
    ):
        """Build the network.

        Args:
            dim: Feature dimension of ``x``. The network input is ``dim + 1``
                wide because ``t`` is concatenated in front of ``x``.
            activation: Activation name, forwarded to ``SimpleDenseNet``.
            hidden_dims: Hidden layer widths, forwarded as-is to
                ``SimpleDenseNet`` (how ``None`` is handled is up to the base
                class).
            batch_norm: Forwarded to ``SimpleDenseNet``.
            negative: If True, ``forward`` returns the negated softplus output.
        """
        super().__init__(
            input_size=dim + 1,
            target_size=1,
            activation=activation,
            batch_norm=batch_norm,
            hidden_dims=hidden_dims,
        )
        self.softplus = nn.Softplus()
        self.negative = negative

    def forward(self, t, x):
        """Evaluate the network at time(s) ``t`` and states ``x``.

        Args:
            t: Time. Accepted forms:
                - a Python number, a 0-d tensor, or any single-element tensor,
                  which is shared by every row of ``x``;
                - a per-sample tensor of shape ``(batch,)`` or ``(batch, 1)``.
            x: State tensor; assumed to be ``(batch, dim)`` (TODO confirm that
                callers never pass extra leading dims).

        Returns:
            ``softplus(model([t, x]))`` (presumably shape ``(batch, 1)``, since
            ``target_size=1``). The result is non-negative, or non-positive
            when ``self.negative`` is set.
        """
        batch = x.shape[0]
        # Convert t to x's dtype and device so torch.cat neither promotes the
        # input (e.g. float64 t with float32 x, which would then mismatch the
        # weights) nor fails on a device mismatch. This also accepts plain
        # Python numbers.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.numel() == 1:
            # A single time value (0-d, (1,), (1, 1), ...) is broadcast to all
            # rows. Reshaping first avoids repeat() failing on 2-D inputs.
            t = t.reshape(1, 1).expand(batch, 1)
        elif t.dim() == 1:
            # Per-sample times of shape (batch,) become a column vector.
            t = t[:, None]
        out = self.softplus(self.model(torch.cat([t, x], dim=-1)))
        return -out if self.negative else out