| import torch | |
| # from https://github.com/napsternxg/pytorch-practice/blob/master/Pytorch%20-%20MMD%20VAE.ipynb | |
def compute_kernel(x, y):
    """Return the Gaussian kernel matrix between the rows of ``x`` and ``y``.

    Entry ``[i, j]`` equals ``exp(-mean_k((x[i, k] - y[j, k]) ** 2) / d)``
    with ``d = x.size(1)``, which is ``exp(-||x[i] - y[j]||^2 / d**2)``.
    The extra division by ``d`` is the bandwidth used by the reference
    implementation this was copied from.

    Args:
        x: 2-D tensor of shape ``(n, d)``.
        y: 2-D tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)``.
    """
    n_rows, n_cols, feat = x.size(0), y.size(0), x.size(1)
    grid_shape = (n_rows, n_cols, feat)
    # Lay both inputs out on a shared (n, m, d) grid so that every row of x
    # is paired with every row of y without copying data (expand is a view).
    left = x.unsqueeze(1).expand(*grid_shape)
    right = y.unsqueeze(0).expand(*grid_shape)
    sq_dist = (left - right).pow(2).mean(dim=2)
    return (-(sq_dist / float(feat))).exp()
def compute_mmd(x, y):
    """Return the squared Maximum Mean Discrepancy between samples ``x`` and ``y``.

    Averages ``compute_kernel`` over every pair of rows, including the
    ``i == j`` diagonal terms (the biased estimate):

        MMD^2 = mean K(x, x) + mean K(y, y) - 2 * mean K(x, y)

    Args:
        x: 2-D tensor of samples, shape ``(n, d)``.
        y: 2-D tensor of samples, shape ``(m, d)``.

    Returns:
        0-dim tensor holding the MMD^2 estimate.
    """
    within_x = compute_kernel(x, x).mean()
    within_y = compute_kernel(y, y).mean()
    cross = compute_kernel(x, y).mean()
    return within_x + within_y - 2 * cross