File size: 993 Bytes
5a87d8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
sys.path.append("./BranchSBM")
import torch
import torch.nn as nn
from typing import List, Optional

from networks.mlp_base import SimpleDenseNet

class GrowthNet(SimpleDenseNet):
    """Time-conditioned MLP that outputs one sign-constrained scalar per row.

    The time ``t`` is concatenated in front of the state ``x`` (hence
    ``input_size=dim + 1``). The underlying network produces a single value
    per row (``target_size=1``). That value is passed through Softplus, so it
    is non-negative, or non-positive when ``negative=True``.

    Args:
        dim: Number of features in ``x`` (not counting the time feature).
        activation: Activation name forwarded to ``SimpleDenseNet``.
        hidden_dims: Hidden-layer widths forwarded to ``SimpleDenseNet``;
            ``None`` is passed through unchanged.
        batch_norm: Forwarded to ``SimpleDenseNet``.
        negative: If True, the Softplus output is negated before returning.
    """

    def __init__(
        self,
        dim: int,
        activation: str,
        hidden_dims: Optional[List[int]] = None,
        batch_norm: bool = False,
        negative: bool = False,
    ):
        super().__init__(
            input_size=dim + 1,  # +1 for the time column prepended in forward()
            target_size=1,
            activation=activation,
            batch_norm=batch_norm,
            hidden_dims=hidden_dims,
        )

        self.softplus = nn.Softplus()
        self.negative = negative

    def forward(self, t, x):
        """Evaluate the network at time ``t`` and state ``x``.

        Args:
            t: Time. Either a single value (Python number, 0-d tensor, or a
                tensor of shape ``(1,)`` / ``(1, 1)``) that is shared by every
                row of ``x``, or per-row times of shape ``(batch,)`` or
                ``(batch, 1)``. It is cast to ``x``'s dtype and device.
            x: State tensor of shape ``(batch, dim)``. It must be 2-D so it
                can be concatenated with the ``(batch, 1)`` time column.

        Returns:
            Tensor of Softplus outputs (negated when ``self.negative``), with
            one value per row as produced by ``self.model``.
        """
        batch = x.shape[0]

        # Accept Python scalars as well as tensors. Match x's dtype/device so
        # torch.cat and the linear layers do not fail on a mismatch.
        if not torch.is_tensor(t):
            t = torch.as_tensor(t)
        t = t.to(dtype=x.dtype, device=x.device)

        if t.numel() == 1:
            # A single time value (0-d, (1,), or (1, 1)) is shared by every
            # row. reshape + expand handles all of these; Tensor.repeat(batch)
            # would raise for a 2-D (1, 1) input.
            t = t.reshape(1, 1).expand(batch, 1)
        elif t.dim() < 2:
            # Per-row times of shape (batch,) become a (batch, 1) column.
            t = t.reshape(-1, 1)

        h = torch.cat([t, x], dim=-1)
        # self.model is presumably the MLP built by SimpleDenseNet.__init__
        # (not visible here) -- TODO confirm.
        out = self.softplus(self.model(h))
        return -out if self.negative else out