File size: 286 Bytes
3650b90 |
1 2 3 4 5 6 7 8 9 10 11 |
from .functional import revgrad
import torch
from torch import nn
class GradientReversal(nn.Module):
    """Module wrapper around the ``revgrad`` autograd function.

    The forward pass simply calls ``revgrad(x, self.alpha)``. Judging by the
    name, ``revgrad`` is presumably identity in the forward pass and scales
    the incoming gradient by ``-alpha`` in the backward pass; it is defined
    in ``.functional`` (not shown here), so confirm the exact contract there.

    Args:
        alpha: Factor handed to ``revgrad`` on every call. Anything accepted by
            ``torch.as_tensor`` works (Python number, sequence, NumPy array, or
            an existing tensor). It is stored as a detached, independent copy
            that never requires grad.
    """

    def __init__(self, alpha):
        super().__init__()
        # ``torch.tensor(existing_tensor)`` warns and recommends
        # ``clone().detach()``. ``as_tensor`` followed by ``detach().clone()``
        # yields the same detached copy for tensors, and the same dtype
        # inference as ``torch.tensor`` for Python numbers, without the warning.
        # NOTE(review): this is a plain attribute, not a registered buffer.
        # ``Module.to()``/``.cuda()`` therefore do not move it, and it is not
        # included in ``state_dict()``.
        self.alpha = torch.as_tensor(alpha).detach().clone()

    def forward(self, x):
        """Apply ``revgrad`` to ``x`` with the stored ``alpha``.

        Args:
            x: Input tensor.

        Returns:
            Whatever ``revgrad(x, self.alpha)`` returns.
        """
        return revgrad(x, self.alpha)

    def extra_repr(self) -> str:
        """Expose ``alpha`` in ``repr(module)`` / ``print(model)`` output."""
        if self.alpha.dim() == 0:
            # Scalar case: ``:g`` hides float32 round-off (0.1 -> "0.1").
            return f"alpha={self.alpha.item():g}"
        return f"alpha={self.alpha.tolist()}"