Bailan-Alex's picture
Upload folder using huggingface_hub
4f2b2f4 verified
# Bregman divergences used as loss functions.
import torch
def mse(x, y):
    """Return the per-sample sum of squared differences between ``x`` and ``y``.

    This is the Bregman divergence generated by ``phi(u) = ||u||^2``. Despite
    the name, it is a *sum* over all non-batch elements, not a mean.

    Args:
        x: Tensor whose first dimension is treated as the batch dimension.
        y: Tensor with exactly the same shape as ``x``.

    Returns:
        Tensor of shape ``(x.shape[0],)`` holding, for each batch entry, the
        sum of ``(x - y) ** 2`` over every remaining dimension.

    Raises:
        ValueError: if ``x`` and ``y`` do not have the same shape.
    """
    # Validate before computing: torch would otherwise silently broadcast
    # compatible-but-different shapes, and an ``assert`` is stripped under -O.
    if x.shape != y.shape:
        raise ValueError(
            "x and y must have the same shape, "
            f"got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    sq_diff = (x - y) ** 2
    # Pass the trailing size explicitly instead of -1 so an empty batch
    # (shape[0] == 0) still reshapes; a 1-D input becomes (B, 1).
    n_features = sq_diff.shape[1:].numel()
    return sq_diff.reshape(sq_diff.shape[0], n_features).sum(dim=-1)
# TODO: confirm that the generalized-KL form below (x*log(x/y) - x + y) is the intended jump-kernel ELBO term.
def jump_kernel_elbo(x, y, eps=1e-6):
    """Return the elementwise generalized-KL divergence between ``x`` and ``y``.

    Computes ``y - x + x * (log x - log y)`` after clamping both inputs to at
    least ``eps``. This is the Bregman divergence generated by
    ``phi(u) = u * log(u)``. It is mathematically non-negative and is zero
    wherever the clamped values coincide.

    Args:
        x: Target tensor (the true length).
        y: Predicted tensor (the predicted length); must broadcast with ``x``.
        eps: Lower bound both inputs are clamped to so the logarithms stay
            finite. Entries below ``eps`` are replaced by ``eps``, so the
            result is slightly biased there and carries no gradient.

    Returns:
        Tensor with the broadcast shape of ``x`` and ``y``. No reduction is
        applied, unlike ``mse`` which sums per sample.
    """
    target = x.clamp(min=eps)
    pred = y.clamp(min=eps)
    # log(x) - log(y): kept as two logs (not log(x / y)) so results match bit-for-bit.
    log_ratio = target.log() - pred.log()
    return pred - target + target * log_ratio