Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from math import gcd | |
def mmd_linear(f_of_X, f_of_Y):
    """Return the linear-kernel MMD loss between two equally sized batches.

    The loss is the mean of every entry of ``delta @ delta.T`` where
    ``delta = f_of_X - f_of_Y``.  That mean equals the squared Euclidean
    norm of the column-wise mean of ``delta``, which is what is computed
    here: it needs O(n * d) work and never materialises the n-by-n matrix.

    Args:
        f_of_X: 2-D floating-point tensor of shape ``(n, d)``.
        f_of_Y: 2-D floating-point tensor that can be subtracted from
            ``f_of_X`` (normally the same shape).

    Returns:
        A 0-dimensional tensor holding the loss.
    """
    delta = f_of_X - f_of_Y
    # mean_ij(delta_i . delta_j) == || mean_i(delta_i) ||^2
    mean_delta = delta.mean(dim=0)
    # torch.dot only accepts 1-D operands, so non-2-D inputs still fail loudly.
    return torch.dot(mean_delta, mean_delta)
class MMDLinear(nn.Module):
    """Module wrapper around ``mmd_linear`` that accepts unequal batch sizes."""

    def __init__(self):
        super().__init__()

    def forward(self, fea_source, fea_target):
        """Compute the linear MMD loss between source and target features.

        When the two batches have different numbers of rows, each one is
        tiled along dim 0 until both reach the least common multiple of the
        two row counts, so they can be subtracted row by row.

        Args:
            fea_source: 2-D tensor of shape ``(n_s, d)``.
            fea_target: 2-D tensor of shape ``(n_t, d)``.

        Returns:
            Whatever ``mmd_linear`` returns for the (possibly tiled) inputs.

        Raises:
            AssertionError: if the feature dimensions differ.
        """
        n_s, d_s = fea_source.size()
        n_t, d_t = fea_target.size()
        assert d_s == d_t, (
            f"feature dimensions differ: source has {d_s}, target has {d_t}"
        )
        if n_s != n_t:
            # Least common multiple of the two batch sizes.  Integer
            # arithmetic keeps it exact; float division can round for
            # large products.
            n = n_s * n_t // gcd(n_s, n_t)
            fea_source = fea_source.repeat((n // n_s, 1))
            fea_target = fea_target.repeat((n // n_t, 1))
        return mmd_linear(fea_source, fea_target)