| import sys | |
| sys.path.append("./BranchSBM") | |
| import torch | |
| import torch.nn as nn | |
| from typing import List, Optional | |
| from networks.mlp_base import SimpleDenseNet | |
class GrowthNet(SimpleDenseNet):
    """Time-conditioned MLP producing a sign-constrained scalar output.

    The network is fed ``[t, x]`` (time prepended to the state) and its raw
    output is passed through a softplus, so the result is non-negative; with
    ``negative=True`` the sign is flipped so the result is non-positive.
    NOTE(review): presumably used as a growth (or, when negative, decay) rate
    in BranchSBM — confirm against callers.
    """

    def __init__(
        self,
        dim: int,
        activation: str,
        hidden_dims: Optional[List[int]] = None,
        batch_norm: bool = False,
        negative: bool = False
    ):
        """Build the network.

        Args:
            dim: Dimensionality of the state ``x``. The underlying MLP takes
                ``dim + 1`` inputs because the time ``t`` is concatenated.
            activation: Activation name forwarded to ``SimpleDenseNet``.
            hidden_dims: Hidden layer sizes forwarded to ``SimpleDenseNet``;
                ``None`` defers to that class's handling.
            batch_norm: Forwarded to ``SimpleDenseNet``.
            negative: If True, negate the softplus output so it is <= 0.
        """
        super().__init__(input_size=dim + 1, target_size=1,
                         activation=activation,
                         batch_norm=batch_norm,
                         hidden_dims=hidden_dims)
        self.softplus = nn.Softplus()
        self.negative = negative

    def forward(self, t, x):
        """Evaluate the network at time ``t`` for the batch of states ``x``.

        Args:
            t: Time. Either a Python number, a single-element tensor of any
                shape (``()``, ``(1,)``, ``(1, 1)``, ...), which is shared by
                every sample, or a per-sample tensor of shape ``(batch,)`` or
                ``(batch, 1)``.
            x: State tensor; assumed to be ``(batch, dim)`` (implied by the
                ``dim + 1`` input size and the last-axis concatenation).

        Returns:
            Softplus of the MLP output (non-negative), negated when
            ``self.negative`` is set. Presumably shaped ``(batch, 1)`` given
            ``target_size=1`` — depends on ``SimpleDenseNet``.
        """
        # Accept Python numbers and align dtype/device with ``x`` so the
        # concatenation below cannot fail (or silently promote) on mismatch.
        # ``as_tensor`` keeps the autograd graph if ``t`` requires grad.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        batch = x.shape[0]
        if t.numel() == 1:
            # One time value shared by the whole batch; reshaping first makes
            # this work for scalar, (1,), (1, 1), ... inputs alike.
            t = t.reshape(1, 1).repeat(batch, 1)
        elif t.dim() < 2:
            # Per-sample times given as a flat vector -> column vector.
            t = t[:, None]
        x = torch.cat([t, x], dim=-1)
        x = self.softplus(self.model(x))
        return -x if self.negative else x