File size: 281 Bytes
f4cade0 |
1 2 3 4 5 6 7 8 9 10 11 |
import torch
class _WrapperModule(torch.nn.Module):
    """Adapt an arbitrary callable into a ``torch.nn.Module``.

    The callable is kept on the instance as ``self.f``. Calling the module
    (which dispatches to ``forward``) passes every positional and keyword
    argument straight through to that callable and returns its result
    unchanged.
    """

    def __init__(self, f):  # type: ignore[no-untyped-def]
        super(_WrapperModule, self).__init__()
        # This assignment goes through nn.Module.__setattr__, so a
        # Module-valued ``f`` is registered as a child module (visible to
        # .children()/.parameters()); any other callable is a plain attribute.
        self.f = f

    def forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        """Invoke the wrapped callable with the given arguments."""
        wrapped = self.f
        return wrapped(*args, **kwargs)
|