File size: 698 Bytes
c5f4ee2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from math import gcd


def mmd_linear(f_of_X, f_of_Y):
    """Return the linear MMD statistic between two feature batches.

    The statistic is the mean of every entry of ``delta @ delta.T`` where
    ``delta = f_of_X - f_of_Y``. That mean is algebraically equal to the
    squared Euclidean norm of the row-wise mean of ``delta``, so it is
    computed in that form to avoid materialising the ``n x n`` Gram matrix.

    Args:
        f_of_X: 2-D tensor of features, one sample per row.
        f_of_Y: 2-D tensor that can be subtracted from ``f_of_X``.

    Returns:
        A 0-dimensional tensor holding the loss.
    """
    delta = f_of_X - f_of_Y
    # mean_{i,j} <delta_i, delta_j> == || mean_i(delta_i) ||^2
    mean_delta = delta.mean(dim=0)
    return torch.dot(mean_delta, mean_delta)

class MMDLinear(nn.Module):
    """Linear MMD loss between a source and a target feature batch.

    The two batches must share the feature dimension but may contain a
    different number of samples.
    """

    def __init__(self):
        super().__init__()

    def forward(self, fea_source, fea_target):
        """Compute the linear MMD loss.

        Args:
            fea_source: 2-D tensor of source features, one sample per row.
            fea_target: 2-D tensor of target features, one sample per row.

        Returns:
            A 0-dimensional tensor: the squared Euclidean distance between
            the mean source feature and the mean target feature.

        Raises:
            AssertionError: if the feature dimensions differ.
            ValueError: if either input is not 2-D (size unpacking fails).
        """
        _, d_s = fea_source.size()
        _, d_t = fea_target.size()

        assert d_s == d_t, f"feature dimensions differ: {d_s} != {d_t}"

        # Tiling both batches up to lcm(n_s, n_t) rows and averaging the Gram
        # matrix of their difference gives exactly ||mean_s - mean_t||^2,
        # because repeating rows does not change a batch mean. Computing the
        # means directly avoids the O(lcm^2) memory of the tiled form and the
        # float-division rounding of the LCM for large batch sizes.
        mean_gap = fea_source.mean(dim=0) - fea_target.mean(dim=0)
        return torch.dot(mean_gap, mean_gap)