# Hugging Face Space snapshot (status: Sleeping); file size 1,949 bytes; commit 2875fe6.
import torch
from torch import nn
def cacf_torch(x, max_lag, dim=(0, 1)):
    """Compute lagged cross-correlations between every pair of channels.

    ``x`` is first standardised (minus its mean, divided by its standard
    deviation) over the axes in ``dim``. Then, for every lag ``k`` in
    ``range(max_lag)`` and every lower-triangular channel pair ``(p, q)``
    (``p >= q``, in ``torch.tril_indices`` order), the product
    ``x[:, t + k, p] * x[:, t, q]`` is averaged over the time axis (axis 1).

    Args:
        x: Tensor indexed as ``(batch, time, channel)``. The time axis is
            sliced as axis 1 and channel pairs are taken from the last axis,
            whose size is read from ``x.shape[2]``.
        max_lag: Number of lags to evaluate (lags ``0 .. max_lag - 1``).
            Must satisfy ``1 <= max_lag <= x.shape[1]``.
        dim: Axes over which the standardising mean and std are computed.

    Returns:
        Tensor of shape ``(batch, max_lag, C * (C + 1) // 2)``, where
        ``C = x.shape[2]``.

    Raises:
        ValueError: If ``max_lag`` is outside ``[1, x.shape[1]]``. Previously,
            0 raised an opaque ``torch.cat`` error and values above the
            sequence length silently produced NaN (mean of an empty slice).
    """
    n_steps = x.shape[1]
    if not 1 <= max_lag <= n_steps:
        raise ValueError(
            f"max_lag must be between 1 and the number of time steps "
            f"({n_steps}), got {max_lag}"
        )
    n_channels = x.shape[2]
    # Row/column indices of the lower triangle (row >= col), used directly as
    # index tensors rather than as Python lists of 0-d tensors.
    rows, cols = torch.tril_indices(n_channels, n_channels, device=x.device)
    x = (x - x.mean(dim, keepdim=True)) / x.std(dim, keepdim=True)
    x_l = x[..., rows]
    x_r = x[..., cols]
    per_lag = []
    for lag in range(max_lag):
        # Pair x_l at time t + lag with x_r at time t; lag 0 is the
        # contemporaneous product (the slice end n_steps - 0 keeps all steps).
        y = x_l[:, lag:] * x_r[:, : n_steps - lag]
        per_lag.append(y.mean(1))
    # Stacking on axis 1 gives (batch, max_lag, n_pairs), the same layout as
    # concatenating lag-major and reshaping.
    return torch.stack(per_lag, 1)
class Loss(nn.Module):
    """Base module for losses built from a tensor of per-component values.

    Subclasses override :meth:`compute`. Calling the module runs ``compute``,
    keeps the result on ``self.loss_componentwise`` and returns its mean
    multiplied by ``reg``.

    Args:
        name: Identifier stored on the instance.
        reg: Weight multiplying the mean of the component losses.
        transform: Callable stored as ``self.transform`` (identity by default).
            The base class never calls it; subclasses may.
        threshold: Upper bound checked by :attr:`success`.
        backward: Flag stored as ``self.backward``; not read by this class.
        norm_foo: Callable stored as ``self.norm_foo`` (identity by default).
            The base class never calls it; subclasses may.
    """

    def __init__(
        self,
        name,
        reg=1.0,
        transform=lambda x: x,
        threshold=10.0,
        backward=False,
        norm_foo=lambda x: x,
    ):
        super().__init__()
        # Assigned through nn.Module.__setattr__ in the original order, so any
        # nn.Module passed as a callable is registered exactly as before.
        settings = (
            ("name", name),
            ("reg", reg),
            ("transform", transform),
            ("threshold", threshold),
            ("backward", backward),
            ("norm_foo", norm_foo),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

    def forward(self, x_fake):
        """Evaluate the loss on ``x_fake`` and cache the per-component values."""
        per_component = self.compute(x_fake)
        self.loss_componentwise = per_component
        return self.reg * per_component.mean()

    def compute(self, x_fake):
        """Return the per-component loss tensor; subclasses must override."""
        raise NotImplementedError()

    @property
    def success(self):
        """Whether every cached component loss is ``<= threshold``.

        Returns a 0-d boolean tensor. Reads ``loss_componentwise``, which is
        only set by :meth:`forward`, so the module must be called first.
        """
        within_bound = self.loss_componentwise <= self.threshold
        return within_bound.all()
class CrossCorrelLoss(Loss):
    """Penalise mismatch between real and generated lag-0 cross-correlations.

    At construction, the lag-0 channel cross-correlations of
    ``transform(x_real)`` are computed with ``cacf_torch(..., 1)`` and averaged
    over the batch. Each evaluation computes the same statistic for the
    generated batch and applies ``norm_foo`` to the difference.

    Args:
        x_real: Reference samples. Assumed ``(batch, time, channel)``, as
            ``cacf_torch`` indexes them (TODO confirm with callers).
        **kwargs: Forwarded to :class:`Loss` (``name`` is required there).
            ``norm_foo`` defaults to the L1 norm over the first axis,
            ``torch.abs(x).sum(0)``, and may be overridden by the caller.
    """

    def __init__(self, x_real, **kwargs):
        # setdefault instead of a hard-coded keyword: passing norm_foo used to
        # raise "got multiple values for keyword argument 'norm_foo'".
        kwargs.setdefault("norm_foo", lambda x: torch.abs(x).sum(0))
        super().__init__(**kwargs)
        # Lag-0 slice of the batch-averaged cacf: one value per channel pair.
        self.cross_correl_real = cacf_torch(self.transform(x_real), 1).mean(0)[0]

    def compute(self, x_fake):
        """Return ``norm_foo`` of the cross-correlation difference, divided by 10.

        The divisor 10.0 is a fixed scaling factor; its origin is not
        documented here.
        """
        cross_correl_fake = cacf_torch(self.transform(x_fake), 1).mean(0)[0]
        # The reference statistic is a plain attribute (not a registered
        # buffer), so it is moved to x_fake's device on every call.
        diff = cross_correl_fake - self.cross_correl_real.to(x_fake.device)
        loss = self.norm_foo(diff)
        return loss / 10.0