import torch
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
    """Mixin that lets an ``nn.Module`` report the device and dtype it lives on.

    A zero-element placeholder parameter is registered at construction time so
    that ``self.parameters()`` is never empty, even for subclasses that define
    no parameters of their own. ``device()`` and ``dtype()`` read the
    corresponding attribute of the first parameter yielded by
    ``self.parameters()``.

    NOTE(review): ``device`` and ``dtype`` are plain methods here, so callers
    must write ``module.device()``. Plain attribute access (``module.device``)
    returns a bound method, not a ``torch.device``. If attribute-style access
    is intended, both need ``@property``; confirm against callers before
    changing, because that would break existing ``.device()`` calls.
    """

    def __init__(self) -> None:
        super().__init__()
        # Empty placeholder parameter: guarantees parameters() yields at least
        # one tensor, so the lookups below never hit StopIteration.
        self._dummy_variable = nn.Parameter()

    def _first_parameter(self) -> nn.Parameter:
        """Return the first parameter yielded by ``self.parameters()``."""
        # parameters() returns a generator, so next() can consume it directly.
        return next(self.parameters())

    def device(self) -> torch.device:
        """Return the device of this module's first parameter."""
        return self._first_parameter().device

    def dtype(self) -> torch.dtype:
        """Return the dtype of this module's first parameter."""
        return self._first_parameter().dtype