File size: 475 Bytes
b55bace | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import sys
import torch
from .mlp_base import SimpleDenseNet
class VelocityNet(SimpleDenseNet):
    """Dense network conditioned on a time input.

    The time value is prepended to ``x`` as one extra feature column, so the
    underlying dense net is constructed with ``input_size=dim + 1`` and
    ``target_size=dim``.
    """

    def __init__(self, dim: int, *args, **kwargs):
        """Build the network for ``dim``-feature inputs.

        Args:
            dim: Feature size of ``x``. The base network is created with
                ``dim + 1`` inputs (time plus ``x``) and ``dim`` targets.
            *args, **kwargs: Forwarded unchanged to ``SimpleDenseNet``.
        """
        # NOTE(review): *args are always bound positionally before the
        # keywords below, regardless of where they appear in the call. If
        # SimpleDenseNet's leading positional parameters are input_size /
        # target_size, any extra positional argument collides with them
        # ("multiple values for argument") -- pass extra options by keyword.
        super().__init__(*args, input_size=dim + 1, target_size=dim, **kwargs)

    def forward(self, t, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network at time ``t`` for the batch ``x``.

        Args:
            t: Time, as a Python number or a tensor. Either a single value
                shared by the whole batch (any shape with one element, e.g.
                0-d, ``(1,)`` or ``(1, 1)``) or one value per sample (``batch``
                elements, e.g. ``(batch,)`` or ``(batch, 1)``).
            x: Input batch; assumed to be ``(batch, dim)`` -- TODO confirm
                no callers pass extra leading dimensions.

        Returns:
            ``self.model`` (presumably created by ``SimpleDenseNet``) applied
            to ``t`` and ``x`` concatenated along the last axis.

        Raises:
            ValueError: If ``t`` has neither 1 nor ``batch`` elements.
        """
        # Align t with x so the concatenated input has x's dtype/device: a
        # float64 t would otherwise promote the whole input, and a t on a
        # different device would make torch.cat fail. For tensor inputs this
        # behaves like t.to(...), so autograd history of t is kept.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        batch = x.shape[0]

        if t.numel() == 1:
            # Single shared time value: broadcast to a (batch, 1) column.
            t = t.reshape(1, 1).expand(batch, 1)
        elif t.numel() == batch:
            # One time value per sample.
            t = t.reshape(batch, 1)
        else:
            raise ValueError(
                f"Time t with {t.numel()} elements cannot be broadcast to "
                f"batch size {batch}; expected a scalar or one value per sample."
            )

        return self.model(torch.cat([t, x], dim=-1))
|