| |
|
|
| import torch |
|
|
def weighting_function(x: torch.Tensor, samples: torch.Tensor, gamma) -> torch.Tensor:
    """Return Gaussian (RBF) kernel weights between query points and samples.

    For every pair (b, n) this evaluates
    ``exp(-||x[b] - samples[n]||**2 / (2 * gamma**2))``.

    Args:
        x: query points; indexed as ``x[:, None, :]``, so presumably (B, D)
            -- confirm with callers.
        samples: reference points; indexed as ``samples[None, :, :]``, so
            presumably (N, D).
        gamma: kernel bandwidth (scalar or broadcastable tensor).

    Returns:
        Pairwise weights, shape (B, N) for 2-D inputs.
    """
    # Broadcast to (B, N, D), then reduce the trailing feature axis.
    displacement = x[:, None, :] - samples[None, :, :]
    squared_distance = displacement.pow(2).sum(dim=-1)
    bandwidth_term = 2 * gamma**2
    return torch.exp(-squared_distance / bandwidth_term)
|
|
|
|
def land_metric_tensor(x: torch.Tensor, samples: torch.Tensor, gamma, rho) -> torch.Tensor:
    """Return the diagonal of the locally adaptive metric at each query point.

    For query point b and feature d the result is
    ``1 / (sum_n w[b, n] * (samples[n, d] - x[b, d])**2 + rho)``,
    where ``w`` are the kernel weights from ``weighting_function``. Only the
    diagonal is returned, as a dense tensor; no full matrix is built. (The
    name suggests the LAND metric of Arvanitidis et al. -- confirm.)

    Args:
        x: query points, presumably (B, D) -- confirm with callers.
        samples: data points, presumably (N, D).
        gamma: kernel bandwidth forwarded to ``weighting_function``.
        rho: value added to every diagonal entry before inversion; with
            rho == 0 and vanishing weights the division yields inf.

    Returns:
        Diagonal metric entries, shape (B, D) for 2-D inputs.
    """
    kernel = weighting_function(x, samples, gamma)
    offsets_sq = (samples[None, :, :] - x[:, None, :]) ** 2
    # Kernel-weighted spread per feature: contract over the sample axis n.
    local_spread = torch.einsum("bn,bnd->bd", kernel, offsets_sq)
    inverse_metric_diag = local_spread + rho
    return 1.0 / inverse_metric_diag
|
|
|
|
def weighting_function_dt(x: torch.Tensor, dx_dt: torch.Tensor, samples: torch.Tensor, gamma, weights: torch.Tensor) -> torch.Tensor:
    """Return the time derivative of the Gaussian kernel weights along a curve x(t).

    With ``w[b, n] = exp(-||x[b] - samples[n]||**2 / (2 * gamma**2))`` (the
    form computed by ``weighting_function``), the chain rule gives
    ``dw[b, n]/dt = -w[b, n] * <x[b] - samples[n], dx_dt[b]> / gamma**2``.

    Args:
        x: current curve points, presumably (B, D) -- confirm with callers.
        dx_dt: velocity at those points; indexed like ``x``, so presumably
            the same shape.
        samples: reference points, presumably (N, D).
        gamma: kernel bandwidth that produced ``weights``.
        weights: precomputed kernel weights, broadcastable against (B, N).

    Returns:
        Weight derivatives, shape (B, N) for 2-D inputs.
    """
    # Inner product of each displacement (x - s) with the velocity, per pair.
    projected = ((x[:, None, :] - samples[None, :, :]) * dx_dt[:, None, :]).sum(-1)
    return -projected * weights / (gamma**2)
|
|