Spaces:
Sleeping
Sleeping
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
class OnesLayer(nn.Module):
    """Produce a single-channel tensor of ones shaped after the input.

    The output copies the input's shape, except that dimension 1 is forced
    to 1 and, when ``size`` is given, the dimensions starting at index 2 are
    replaced by the entries of ``size``. Only the input's shape and device
    are used; its values are ignored.
    """

    # Class-level fallback so instances pickled before ``dtype`` existed
    # still produce the historical float32 output.
    dtype = torch.float32

    def __init__(self, size=None, *, dtype=torch.float32):
        """
        Args:
            size: optional sequence of sizes for the dimensions following
                the channel dimension (e.g. ``(H, W)`` for a 4-D input, or
                ``(D, H, W)`` for a 5-D input). ``None`` keeps the input's
                sizes for those dimensions.
            dtype: dtype of the returned tensor; defaults to
                ``torch.float32``, the previously hard-coded value.
        """
        super().__init__()
        self.size = size
        self.dtype = dtype

    def forward(self, tensor):
        """Return a tensor of ones with one channel, shaped after ``tensor``.

        Args:
            tensor: input tensor with at least two dimensions; only its
                shape and device are read.

        Returns:
            A tensor filled with ones, of dtype ``self.dtype``, on
            ``tensor.device``.

        Raises:
            ValueError: if ``size`` names more dimensions than ``tensor``
                has after its channel dimension.
        """
        shape = list(tensor.shape)
        shape[1] = 1  # return only one channel
        if self.size is not None:
            size = tuple(self.size)
            n_dims_needed = 2 + len(size)
            if len(shape) < n_dims_needed:
                raise ValueError(
                    f"size {size} requires an input with at least "
                    f"{n_dims_needed} dims, got shape {tuple(tensor.shape)}"
                )
            # Overwrite only the leading dims after the channel dim; any
            # further dims are kept, so a 2-tuple on a 4-D input behaves
            # exactly like the original ``shape[2], shape[3] = size``.
            shape[2:n_dims_needed] = size
        return torch.ones(shape, dtype=self.dtype, device=tensor.device)
class UninformativeFeatures(torch.nn.Sequential):
    """Feature extractor whose output does not depend on its input's values.

    It is a ``Sequential`` holding one ``OnesLayer``, registered under the
    name ``'ones'`` and configured with ``size=(1, 1)``. That layer returns a
    one-channel tensor of ones whose dimensions 2 and 3 are both 1.
    """

    def __init__(self):
        # Build the named-module mapping first, then hand it to Sequential.
        layers = OrderedDict()
        layers['ones'] = OnesLayer(size=(1, 1))
        super().__init__(layers)