File size: 281 Bytes
f4cade0
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch


class _WrapperModule(torch.nn.Module):
    """Adapt an arbitrary callable into a ``torch.nn.Module``.

    The callable is kept on the ``f`` attribute, and ``forward`` delegates to
    it. Every positional and keyword argument is passed through unchanged,
    and the callable's return value is returned as-is.
    """

    def __init__(self, f):  # type: ignore[no-untyped-def]
        # nn.Module's own initialiser must run before any attribute is set,
        # because nn.Module intercepts attribute assignment (e.g. to register
        # submodules).
        super(_WrapperModule, self).__init__()
        # When ``f`` is itself an nn.Module, this assignment registers it as a
        # child module named "f".
        self.f = f

    def forward(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        """Call the wrapped callable with the given arguments and return its result."""
        wrapped = self.f
        result = wrapped(*args, **kwargs)
        return result