import torch
from torch.autograd import Function
class GradientReversal(Function):
    """Autograd function that is the identity forward and flips gradients backward.

    The forward pass returns ``x`` unchanged. The backward pass multiplies the
    incoming gradient by ``-alpha``, so the gradient reaching ``x`` is reversed
    and scaled. No gradient is produced for ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context object supplied by ``Function.apply``.
            x: Input tensor; passed through as-is.
            alpha: Scale applied to the reversed gradient. May be a tensor or a
                plain Python number.

        Returns:
            The input ``x``.
        """
        if torch.is_tensor(alpha):
            # Tensor inputs must go through save_for_backward so autograd can
            # track them; x itself is not needed in backward, so it is not saved.
            ctx.save_for_backward(alpha)
            ctx.alpha_scalar = None
        else:
            # save_for_backward only accepts tensors, so a plain number is
            # kept as an ordinary attribute on the context instead.
            ctx.alpha_scalar = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and ``None`` for ``alpha``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient flowing back into this function's output.

        Returns:
            A 2-tuple ``(grad_input, None)``. ``grad_input`` is ``None`` when
            ``x`` does not require a gradient.
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            if ctx.alpha_scalar is None:
                (alpha,) = ctx.saved_tensors
            else:
                alpha = ctx.alpha_scalar
            grad_input = -alpha * grad_output
        return grad_input, None
# Functional entry point: ``revgrad(x, alpha)`` calls ``GradientReversal.apply``.
revgrad = GradientReversal.apply