File size: 528 Bytes
4f2b2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Bregman divergences, implemented with PyTorch tensor operations.
import torch


def mse(x, y):
    """Return the per-sample sum of squared differences between ``x`` and ``y``.

    This is the Bregman divergence generated by ``phi(v) = ||v||^2``. Despite
    the name, squared errors are summed over the non-batch dimensions, not
    averaged.

    Args:
        x: Tensor whose first dimension is treated as the batch dimension.
        y: Tensor with exactly the same shape as ``x``.

    Returns:
        Tensor of shape ``(x.shape[0],)``. Each entry is the squared error of
        one sample, summed over all remaining dimensions. A 1-D input is
        treated as a batch of scalars.

    Raises:
        ValueError: if ``x`` and ``y`` do not have the same shape.
    """
    # Validate before computing. Otherwise broadcastable-but-different shapes
    # would silently produce a broadcast result, and a bare `assert` would be
    # stripped entirely when Python runs with -O.
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")
    sq_diff = (x - y) ** 2
    # Flatten everything after the batch dimension. The explicit trailing size
    # (1 for 1-D inputs) replaces -1 so that an empty batch (size 0), where -1
    # cannot be inferred, still reshapes cleanly.
    batch_size = sq_diff.shape[0]
    per_sample = sq_diff.reshape(batch_size, sq_diff.shape[1:].numel())
    return per_sample.sum(dim=-1)


# TODO: check if this formulation is correct
def jump_kernel_elbo(x, y, eps=1e-6):
    # x_safe: true length
    # y_safe: predicted length
    x_safe = torch.clamp(x, min=eps)
    y_safe = torch.clamp(y, min=eps)

    return y_safe - x_safe + x_safe * (torch.log(x_safe) - torch.log(y_safe))