File size: 292 Bytes
b55bace | 1 2 3 4 5 6 7 8 9 10 11 12 13 | import sys
import torch
class flow_model_torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format.

    The wrapped ``model`` is called as ``model(t, x)``. Any further
    positional or keyword arguments given to :meth:`forward` are accepted
    and discarded, so a caller that supplies extras does not break a model
    that only takes ``(t, x)``.
    """

    def __init__(self, model) -> None:
        """Store ``model`` as ``self.model``.

        Args:
            model: Callable invoked as ``model(t, x)``. If it is a
                ``torch.nn.Module``, attribute assignment registers it as a
                submodule of this wrapper.
        """
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Return ``self.model(t, x)``.

        Args:
            t: First argument forwarded to the wrapped model
                (presumably the ODE time -- confirm against the caller).
            x: Second argument forwarded to the wrapped model
                (presumably the state -- confirm against the caller).
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            Whatever the wrapped model returns for ``(t, x)``.
        """
        # Extras are intentionally dropped: only (t, x) reach the model.
        return self.model(t, x)
|