File size: 323 Bytes
5a87d8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import sys
sys.path.append("./BranchSBM")
import torch
class flow_model_torch_wrapper(torch.nn.Module):
    """Adapt a ``model(t, x)`` callable to a torchdyn-compatible vector field.

    The wrapped callable is stored on ``self.model``. Because assignment goes
    through ``nn.Module.__setattr__``, a ``model`` that is itself an
    ``nn.Module`` is registered as a submodule, and its parameters are
    visible via ``self.parameters()``.
    """

    def __init__(self, model):
        """Store ``model``; it must be callable as ``model(t, x)``."""
        super(flow_model_torch_wrapper, self).__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the wrapped model at time ``t`` and state ``x``.

        Any additional positional or keyword arguments are accepted and then
        discarded. Only ``(t, x)`` reach the wrapped model. Presumably the
        solver supplies such extras; confirm against the torchdyn version in use.
        """
        wrapped_model = self.model
        return wrapped_model(t, x)
|