| """ inverse matrix """ |
|
|
| import torch |
|
|
|
|
def batch_inverse(x):
    """ M(n) -> M(n); x -> x^-1

    Invert every matrix in a batch.

    Args:
        x: tensor of shape (batch, n, n).

    Returns:
        Tensor with the same shape, dtype and device as ``x`` whose i-th
        entry is the inverse of ``x[i]``.

    Raises:
        ValueError: if the matrices are not square.
    """
    # Unpacking also rejects inputs that are not 3-D.
    _, h, w = x.size()
    if h != w:
        raise ValueError(
            f"batch_inverse expects square matrices, got {h}x{w}")
    # torch.inverse handles leading batch dimensions natively, so no
    # per-matrix Python loop (or preallocated output buffer) is needed.
    return torch.inverse(x)
|
|
def batch_inverse_dx(y):
    """ backward

    Jacobian of the matrix inverse, written in terms of the inverse itself.

    Given y = x^-1 for a batch of square matrices, returns ``dy`` with

        dy[b, i, j, k, l] = d y[b, i, j] / d x[b, k, l]
                          = -y[b, i, k] * y[b, l, j]

    Args:
        y: tensor of shape (batch, n, n) holding the inverses.

    Returns:
        Contiguous tensor of shape (batch, n, n, n, n).
    """
    n_batch, rows, cols = y.size()
    assert rows == cols

    # left[b, i, _, k, _] = y[b, i, k]
    left = y.unsqueeze(2).unsqueeze(4)
    # right[b, _, j, _, l] = y[b, l, j]
    right = y.transpose(1, 2).unsqueeze(1).unsqueeze(3)

    # Broadcasting forms the outer product over (i, j, k, l). Force a
    # contiguous layout because callers reshape the result with .view().
    jacobian = -(left * right)
    return jacobian.contiguous()
|
|
|
|
def batch_pinv_dx(x):
    """ returns y = (x'*x)^-1 * x' and dy/dx.

    Args:
        x: tensor of shape (batch, h, w). Each x^T x (w x w) must be
            invertible, since it is passed through batch_inverse.

    Returns:
        (y, dy_dx) where y has shape (batch, w, h) and
        dy_dx[n, m, a, k, l] = d y[n, m, a] / d x[n, k, l] has shape
        (batch, w, h, h, w).
    """
    n_batch, h, w = x.size()

    # Forward value: y = (x^T x)^-1 x^T.
    x_t = x.transpose(1, 2)
    gram = x_t.bmm(x)                   # (B, w, w)
    gram_inv = batch_inverse(gram)      # (B, w, w)
    y = gram_inv.bmm(x_t)               # (B, w, h)

    # eye5[0, i, j, k, l] = d x_ij / d x_kl = delta_ik * delta_jl
    eye5 = torch.eye(h*w).to(x).unsqueeze(0).view(1, h, w, h, w)

    # d(x^T x)/dx by the product rule on s_mj = sum_i x_im x_ij:
    #   term[m, j, k, l]              = x_km * delta_jl
    #   term.transpose(1, 2)[m, j, ..] = x_kj * delta_ml
    term = x.transpose(1, 2).matmul(eye5.view(1, h, w*h*w))
    term = term.view(n_batch, w, w, h, w)
    dgram_dx = term.transpose(1, 2) + term

    # Chain rule through the inverse: d(gram_inv)/dx = d(gram_inv)/d(gram) . d(gram)/dx
    dinv_dgram = batch_inverse_dx(gram_inv)
    dinv_dx = dinv_dgram.view(n_batch, w*w, w*w).bmm(
        dgram_dx.view(n_batch, w*w, h*w))
    dinv_dx = dinv_dx.view(n_batch, w, w, h, w)

    # Product rule on y = gram_inv @ x^T, part 1: (d gram_inv / dx) @ x^T.
    # After the matmul, part_inv[a, m, k, l] = sum_n x_an * dinv[m, n, k, l];
    # transposing dims 1 and 2 below yields the (m, a) ordering of y.
    dinv_cols = dinv_dx.transpose(1, 2).contiguous().view(n_batch, w, w*h*w)
    part_inv = x.matmul(dinv_cols).view(n_batch, h, w, h, w)

    # Part 2: gram_inv @ (d x^T / dx), with d x^T_ji / d x_kl = delta_ik * delta_jl.
    dxt_dx = eye5.transpose(1, 2).contiguous().view(1, w, h*h*w)
    part_xt = gram_inv.matmul(dxt_dx).view(n_batch, w, h, h, w)

    dy_dx = part_inv.transpose(1, 2) + part_xt
    return y, dy_dx
|
|
|
|
class InvMatrix(torch.autograd.Function):
    """ M(n) -> M(n); x -> x^-1.

    Batched matrix inverse as a custom autograd Function. The input is a
    tensor of shape (batch, n, n); the output has the same shape.
    """
    @staticmethod
    def forward(ctx, x):
        y = batch_inverse(x)
        # Only the inverse is needed to form the gradient.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Vector-Jacobian product of the matrix inverse.

        With y = x^-1, d y_ij / d x_kl = -y_ik * y_lj, so contracting with
        g = dL/dy gives dL/dx = -y^T @ g @ y^T. Using this closed form costs
        O(n^3) per matrix, instead of building and contracting the full
        (batch, n, n, n, n) Jacobian (O(n^4) memory and time).
        """
        y, = ctx.saved_tensors

        _, h, w = y.size()
        # Internal invariant: forward only saves square inverses.
        assert h == w

        y_t = y.transpose(1, 2)
        grad_input = -y_t.bmm(grad_output).bmm(y_t)
        return grad_input
|
|
|
|
|
|
if __name__ == '__main__':
    def test():
        """Check batch_pinv_dx against autograd through InvMatrix.

        Both paths compute d(sum(y))/dx for y = (x^T x)^-1 x^T and should
        agree up to floating-point error.
        """
        x = torch.randn(2, 3, 2)

        # Autograd path. Use a separate leaf: x.requires_grad_() would flip
        # ``x`` in place (and return the same object), which would make the
        # analytic path below build an unnecessary autograd graph.
        x_val = x.clone().requires_grad_()
        s_val = x_val.transpose(1, 2).bmm(x_val)
        s_inv = InvMatrix.apply(s_val)
        y_val = s_inv.bmm(x_val.transpose(1, 2))
        y_val.sum().backward()
        t1 = x_val.grad

        # Analytic path: the gradient of sum(y) is the Jacobian summed over
        # the output indices (dims 1 and 2 of the (B, w, h, h, w) tensor).
        with torch.no_grad():
            _, dy_dx = batch_pinv_dx(x)
        t2 = dy_dx.sum(1).sum(1)

        print(t1)
        print(t2)
        print(t1 - t2)
        print('max abs diff:', (t1 - t2).abs().max().item())

    test()
|
|
| |
|
|