| | |
| | |
| | |
| | |
| |
|
| | from torch.autograd import Function |
| | import torch |
| | from torch import nn |
| |
|
| |
|
class GradientReversal(Function):
    """Autograd op that is the identity going forward and flips the gradient going back.

    Forward returns ``x`` unchanged. Backward returns ``-alpha * grad_output``
    for ``x`` and no gradient for ``alpha``.

    Args (to ``apply``):
        x: Input tensor.
        alpha: Gradient scale, either a tensor or a plain Python number.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        if torch.is_tensor(alpha):
            # Tensors go through save_for_backward so autograd can detect
            # in-place modification between forward and backward.
            ctx.save_for_backward(alpha)
            ctx.alpha = None
        else:
            # save_for_backward only accepts tensors; plain numbers are kept
            # directly on ctx instead.
            ctx.alpha = alpha
        # ``x`` is not needed to compute the gradient, so it is not saved.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            alpha = ctx.saved_tensors[0] if ctx.alpha is None else ctx.alpha
            grad_input = -alpha * grad_output
        # alpha never receives a gradient.
        return grad_input, None
| |
|
| |
|
# Functional form of the gradient-reversal op: revgrad(x, alpha).
# This must run before the nn.Module class below reuses the name
# ``GradientReversal``. After that class statement the name refers to the
# module, so this alias is the remaining handle on the autograd Function's
# ``apply`` (the module's forward calls it).
revgrad = GradientReversal.apply
| |
|
| |
|
class GradientReversal(nn.Module):
    """Layer wrapping ``revgrad``, passing the input through unchanged.

    Args:
        alpha: Gradient scale handed to ``revgrad`` on every forward call.
            May be a Python number or a tensor. It is copied, detached from
            any autograd graph, and stored as a (non-persistent) buffer.
    """

    def __init__(self, alpha):
        super().__init__()
        # A buffer, unlike a plain tensor attribute, follows .to() / .cuda() /
        # dtype casts together with the rest of the module. persistent=False
        # keeps it out of state_dict, so checkpoint keys are unchanged.
        # as_tensor(...).detach().clone() also avoids the copy-construct
        # warning torch.tensor() emits when alpha is already a tensor.
        self.register_buffer(
            "alpha",
            torch.as_tensor(alpha).detach().clone(),
            persistent=False,
        )

    def forward(self, x):
        return revgrad(x, self.alpha)

    def extra_repr(self):
        # tolist() yields a plain number for a 0-dim alpha and a list otherwise.
        return f"alpha={self.alpha.tolist()}"
| |
|