Spaces:
Sleeping
Sleeping
File size: 698 Bytes
c5f4ee2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn as nn
from math import gcd
def mmd_linear(f_of_X, f_of_Y):
    """Return the linear-kernel MMD estimate between two feature batches.

    Computes ``mean((X - Y) @ (X - Y).T)``. This is algebraically equal to
    the squared Euclidean norm of the row-mean of ``X - Y``, i.e.
    ``||mean(X, 0) - mean(Y, 0)||^2``.

    Args:
        f_of_X: 2-D tensor; rows are treated as samples.
        f_of_Y: 2-D tensor with the same shape as ``f_of_X`` (or one that
            broadcasts against it under subtraction).

    Returns:
        A 0-dim tensor holding the loss (differentiable w.r.t. both inputs).
    """
    delta = f_of_X - f_of_Y
    # mean_{i,j} <delta_i, delta_j> == <mean_i delta_i, mean_j delta_j>, so
    # the n x n Gram matrix never needs to be built: O(n*d) instead of
    # O(n^2*d) time, and O(d) instead of O(n^2) extra memory.
    mean_delta = delta.mean(dim=0)
    return torch.dot(mean_delta, mean_delta)
class MMDLinear(nn.Module):
    """Linear-kernel Maximum Mean Discrepancy loss as an ``nn.Module``.

    The loss is the squared Euclidean distance between the mean source
    feature vector and the mean target feature vector. The two batches may
    contain different numbers of rows.

    An earlier version tiled both batches up to the least common multiple
    of their sizes before evaluating the kernel. That tiling leaves both
    means unchanged, so it never affected the result. However, for coprime
    batch sizes it could allocate an enormous LCM x LCM matrix.
    """

    def __init__(self):
        super().__init__()

    def forward(self, fea_source, fea_target):
        """Compute the linear MMD between two 2-D feature batches.

        Args:
            fea_source: tensor of shape ``(n_s, d)``.
            fea_target: tensor of shape ``(n_t, d)``; ``n_t`` may differ
                from ``n_s``.

        Returns:
            A 0-dim tensor, ``||mean(fea_source, 0) - mean(fea_target, 0)||^2``.

        Raises:
            ValueError: if the feature dimensions differ or either batch
                is empty.
        """
        n_s, d_s = fea_source.size()
        n_t, d_t = fea_target.size()
        if d_s != d_t:
            raise ValueError(
                f"feature dimensions differ: source has {d_s}, target has {d_t}"
            )
        if n_s == 0 or n_t == 0:
            raise ValueError(
                f"empty batch: source has {n_s} rows, target has {n_t} rows"
            )
        # Linear MMD depends only on the two row means, so batches of
        # different sizes can be compared directly without tiling them.
        delta = fea_source.mean(dim=0) - fea_target.mean(dim=0)
        return torch.dot(delta, delta)
|