xiaoxuezi committed on
Commit
875baeb
·
1 Parent(s): c608f1c
lossfunction/.DS_Store ADDED
Binary file (6.15 kB).
 
lossfunction/AdditiveAngularMargin.py ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.acc import accuracy

class AdditiveAngularMargin(nn.Module):
    def __init__(self,
                 feature_dim=256,
                 n_classes=1000,
                 margin=0.2,
                 scale=30,
                 easy_margin=False):
        super(AdditiveAngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        self.w = nn.Parameter(torch.FloatTensor(feature_dim, n_classes))
        nn.init.xavier_normal_(self.w)
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin
        self.nll_loss = nn.NLLLoss()
        self.n_classes = n_classes
        self.test_normalize = True

    def forward(self, logits, targets):
        logits = F.normalize(logits, p=2, dim=1, eps=1e-8)
        wn = F.normalize(self.w, p=2, dim=0, eps=1e-8)

        cosine = logits @ wn

        # Clamp before the square root: float error can push 1 - cos^2
        # slightly below zero, which would produce NaNs.
        sine = torch.sqrt((1.0 - torch.square(cosine)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        target_one_hot = F.one_hot(targets, self.n_classes)
        outputs = (target_one_hot * phi) + ((1.0 - target_one_hot) * cosine)
        outputs = self.scale * outputs
        pred = F.log_softmax(outputs, dim=-1)
        nloss = self.nll_loss(pred, targets)
        prec1 = accuracy(pred.detach(), targets.detach(), topk=(1,))[0]

        return nloss, prec1
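A minimal smoke test for the head above, not part of the commit: it assumes `utils.acc.accuracy` is importable from the repository root, and the batch size, embedding dimension, and class count are illustrative.

import torch
from lossfunction.AdditiveAngularMargin import AdditiveAngularMargin

head = AdditiveAngularMargin(feature_dim=256, n_classes=1000, margin=0.2, scale=30)
emb = torch.randn(32, 256)            # (batch, feature_dim) embeddings
labels = torch.randint(1000, (32,))   # class indices in [0, n_classes)
loss, top1 = head(emb, labels)        # scalar loss, top-1 accuracy
print(loss.item(), top1)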
lossfunction/Unetloss.py ADDED
@@ -0,0 +1,87 @@
from lossfunction.softmaxproto import SoftmaxProto
import torch.nn as nn
import lossfunction.softmax as softmax
import torch
import torch.nn.functional as F

class Unetloss(nn.Module):
    def __init__(self, nOut, nClasses):
        super(Unetloss, self).__init__()

        self.test_normalize = True
        self.softmax = SoftmaxProto(nOut, nClasses)
        self.mseloss = nn.MSELoss()
        print('Initialised Unet Loss')

    def forward(self, emb, spectrogram, x, label=None):

        nlossE, prec1 = self.softmax(emb, label)
        nlossS = self.mseloss(spectrogram, x)
        # Typical magnitudes observed in training:
        # nlossE: 13.1695, nlossS: 0.8902

        return nlossE + 10 * nlossS, prec1


class UnetMaskloss(nn.Module):
    def __init__(self, nOut, nClasses):
        super(UnetMaskloss, self).__init__()

        self.test_normalize = True
        self.softmax = softmax.Softmax(nOut, nClasses)
        self.mseloss = nn.MSELoss(reduction='sum')
        self.criterion = torch.nn.CrossEntropyLoss()
        print('Initialised UnetMask Loss')

    def forward(self, emb, spectrogram, label=None):

        # Views 0..5 are indexed below, so six views are required.
        assert emb.size()[1] >= 6
        nlossEd1 = self.mseloss(emb[:, 0, :], emb[:, 1, :]) + self.mseloss(emb[:, 0, :], emb[:, 2, :])
        nlossEd2 = self.mseloss(emb[:, 3, :], emb[:, 4, :]) + self.mseloss(emb[:, 3, :], emb[:, 5, :])

        emb_anchor = torch.mean(emb[:, 0:3, :], 1)
        emb_positive = torch.mean(emb[:, 3:6, :], 1)
        stepsize = emb_anchor.size()[0]
        output = -1 * (F.pairwise_distance(emb_positive.unsqueeze(-1), emb_anchor.unsqueeze(-1).transpose(0, 2)) ** 2)
        label0 = torch.arange(stepsize, device=emb.device)  # device-aware instead of a hard .cuda()
        nlossEP = self.criterion(output, label0)

        nlossEC, prec1 = self.softmax(emb.reshape(-1, emb.size()[-1]), label.repeat_interleave(emb.size()[1]))

        nlossSd1 = self.mseloss(spectrogram[:, 0, :, :], spectrogram[:, 1, :, :]) + self.mseloss(spectrogram[:, 0, :, :], spectrogram[:, 2, :, :])
        nlossSd2 = self.mseloss(spectrogram[:, 3, :, :], spectrogram[:, 4, :, :]) + self.mseloss(
            spectrogram[:, 3, :, :], spectrogram[:, 5, :, :])

        spec_anchor = torch.mean(spectrogram[:, 0:3, :, :], 1)
        spec_positive = torch.mean(spectrogram[:, 3:6, :, :], 1)

        nlossS = self.mseloss(spec_anchor, spec_positive)
        # Typical magnitudes observed in training:
        # nlossEd1: 3.9563, nlossEd2: 3.5833, nlossEP: 0.6218, nlossEC: 8.7362
        # nlossSd1: 3.4339, nlossSd2: 30.1156, nlossS: 2.2820

        loss = 100 * (nlossEd1 + nlossEd2) + 10 * nlossEP + nlossEC + nlossSd1 + nlossSd2 + 10 * nlossS
        return loss, prec1


if __name__ == "__main__":
    # pairwise_distance needs floating-point inputs, hence the .float()
    a = torch.randint(10, (1, 2, 3)).float()
    b = torch.randint(10, (1, 2, 3)).float()
    print(a)
    print(b)
    print(a.shape, b.shape)
    distance = F.pairwise_distance(a, b)
    print(distance.shape)
    print(distance)
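A hedged usage sketch for UnetMaskloss: the six-view layout is inferred from the indexing above (views 0-2 form the anchor group, views 3-5 the positive group), while the batch size, mel-bin, and frame counts are illustrative assumptions.

import torch
from lossfunction.Unetloss import UnetMaskloss

crit = UnetMaskloss(nOut=512, nClasses=1000)
emb = torch.randn(4, 6, 512)          # (batch, views, embedding dim); 6 views assumed
spec = torch.randn(4, 6, 80, 200)     # (batch, views, mel bins, frames); sizes assumed
labels = torch.randint(1000, (4,))
loss, top1 = crit(emb, spec, labels)
print(loss.item(), top1)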
lossfunction/__init__.py ADDED
@@ -0,0 +1,7 @@
from .softmaxproto import SoftmaxProto
from .Unetloss import Unetloss, UnetMaskloss
from .softmax import Softmax
from .proto import proto
from .AdditiveAngularMargin import AdditiveAngularMargin
from .aamsoftmax import AamSoftmax
from .aamsoftmaxproto import AamSoftmaxProto
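The package exports above make the losses importable directly from `lossfunction`; for example:

from lossfunction import Softmax, AamSoftmax, SoftmaxProto, AdditiveAngularMargin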
lossfunction/aamsoftmax.py ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.acc import accuracy

class AamSoftmax(nn.Module):
    def __init__(self, nOut, nClasses, margin=0.2, scale=30, easy_margin=False, **kwargs):
        super(AamSoftmax, self).__init__()

        self.test_normalize = True

        self.m = margin
        self.s = scale
        self.in_feats = nOut
        self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.weight, gain=1)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)

        # Thresholds that keep cos(theta + m) monotonically decreasing
        # for theta in [0, pi]
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        print('Initialised AAMSoftmax margin %.3f scale %.3f' % (self.m, self.s))

    def forward(self, x, label=None):

        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.in_feats

        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # sin(theta), clamped so the sqrt stays real
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        # phi = cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        loss = self.ce(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        return loss, prec1


if __name__ == "__main__":
    x = torch.randn(32, 512)
    y = torch.randint(1000, size=(32,))
    print(x.shape, y.shape)
    loss = AamSoftmax(512, 1000)
    nloss, prec1 = loss(x, y)
    print(nloss, prec1)
lossfunction/aamsoftmaxproto.py ADDED
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn
import lossfunction.aamsoftmax as aamsoftmax
import lossfunction.angleproto as angleproto


class AamSoftmaxProto(nn.Module):

    def __init__(self, nOut, nClasses, margin, scale):
        super(AamSoftmaxProto, self).__init__()

        self.test_normalize = True

        self.aamsoftmax = aamsoftmax.AamSoftmax(nOut, nClasses, margin, scale)
        self.angleproto = angleproto.AngleProto()

        print('Initialised AamSoftmaxPrototypical Loss')

    def forward(self, x, label=None):

        assert x.size()[1] == 2

        nlossS, prec1 = self.aamsoftmax(x.reshape(-1, x.size()[-1]), label.repeat_interleave(2))

        nlossP, _ = self.angleproto(x, None)
        # Typical magnitudes observed in training:
        # nlossP: 0.6678, nlossS: 13.6913

        return nlossS + nlossP, prec1
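A hedged usage sketch for the combined loss, assuming paired utterances per speaker; sizes are illustrative.

import torch
from lossfunction.aamsoftmaxproto import AamSoftmaxProto

crit = AamSoftmaxProto(nOut=512, nClasses=1000, margin=0.2, scale=30)
x = torch.randn(32, 2, 512)           # (batch, utterances per speaker, dim)
y = torch.randint(1000, (32,))
loss, top1 = crit(x, y)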
lossfunction/amsoftmax.py ADDED
@@ -0,0 +1,39 @@
import torch
import torch.nn as nn
from utils.acc import accuracy

class AmSoftmax(nn.Module):
    def __init__(self, nOut, nClasses, margin=0.3, scale=15, **kwargs):
        super(AmSoftmax, self).__init__()

        self.test_normalize = True

        self.m = margin
        self.s = scale
        self.in_feats = nOut
        self.W = torch.nn.Parameter(torch.randn(nOut, nClasses), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

        print('Initialised AMSoftmax m=%.3f s=%.3f' % (self.m, self.s))

    def forward(self, x, label=None):

        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.in_feats

        # Cosine similarity between L2-normalised embeddings and weights
        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(x, x_norm)
        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
        w_norm = torch.div(self.W, w_norm)
        costh = torch.mm(x_norm, w_norm)
        # Subtract the margin m from the target-class cosine only
        label_view = label.view(-1, 1)
        if label_view.is_cuda: label_view = label_view.cpu()
        delt_costh = torch.zeros(costh.size()).scatter_(1, label_view, self.m)
        if x.is_cuda: delt_costh = delt_costh.cuda()
        costh_m = costh - delt_costh
        costh_m_s = self.s * costh_m
        loss = self.ce(costh_m_s, label)
        prec1 = accuracy(costh_m_s.detach(), label.detach(), topk=(1,))[0]
        return loss, prec1
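AmSoftmax has no __main__ smoke test like aamsoftmax.py, so here is an equivalent hedged one; sizes are illustrative.

import torch
from lossfunction.amsoftmax import AmSoftmax

head = AmSoftmax(nOut=512, nClasses=1000, margin=0.3, scale=15)
x = torch.randn(32, 512)
y = torch.randint(1000, (32,))
loss, top1 = head(x, y)
print(loss.item(), top1)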
lossfunction/angleproto.py ADDED
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.acc import accuracy

class AngleProto(nn.Module):

    def __init__(self, init_w=10.0, init_b=-5.0):
        super(AngleProto, self).__init__()

        self.test_normalize = True

        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss()

        print('Initialised AngleProto')

    def forward(self, x, label=None):

        assert x.size()[1] >= 2

        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        stepsize = out_anchor.size()[0]

        cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1), out_anchor.unsqueeze(-1).transpose(0, 2))
        # Keep the learned scale positive; clamp in place so the constraint
        # actually takes effect (a bare torch.clamp would discard its result).
        with torch.no_grad():
            self.w.clamp_(min=1e-6)
        cos_sim_matrix = cos_sim_matrix * self.w + self.b

        label = torch.arange(stepsize, device=x.device)  # device-aware instead of a hard .cuda()
        # Note: adds an MSE term on top of the standard AngleProto objective:
        # nloss = self.criterion(cos_sim_matrix, label)
        nloss = self.criterion(cos_sim_matrix, label) + self.mse(out_positive, out_anchor)
        prec1 = accuracy(cos_sim_matrix.detach(), label.detach(), topk=(1,))[0]

        return nloss, prec1
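AngleProto needs no labels: the positive of row i must match anchor i, so the target is simply the batch index. A hedged sketch with illustrative sizes:

import torch
from lossfunction.angleproto import AngleProto

crit = AngleProto()
x = torch.randn(32, 2, 512)           # (batch, utterances per speaker, dim)
loss, top1 = crit(x)                  # implicit labels 0..31
print(loss.item(), top1)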
lossfunction/ge2e.py ADDED
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.acc import accuracy

class Ge2e(nn.Module):

    def __init__(self, init_w=10.0, init_b=-5.0, **kwargs):
        super(Ge2e, self).__init__()

        self.test_normalize = True

        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()

        print('Initialised GE2E')

    def forward(self, x, label=None):

        assert x.size()[1] >= 2

        gsize = x.size()[1]
        centroids = torch.mean(x, 1)
        stepsize = x.size()[0]

        cos_sim_matrix = []

        for ii in range(0, gsize):
            idx = [*range(0, gsize)]
            idx.remove(ii)
            # Centroid of each speaker excluding the current utterance, (32, 512)
            exc_centroids = torch.mean(x[:, idx, :], 1)
            cos_sim_diag = F.cosine_similarity(x[:, ii, :], exc_centroids)
            cos_sim = F.cosine_similarity(x[:, ii, :].unsqueeze(-1), centroids.unsqueeze(-1).transpose(0, 2))
            cos_sim[range(0, stepsize), range(0, stepsize)] = cos_sim_diag
            cos_sim_matrix.append(torch.clamp(cos_sim, 1e-6))

        cos_sim_matrix = torch.stack(cos_sim_matrix, dim=1)

        # Keep the learned scale positive; clamp in place so the constraint
        # actually takes effect (a bare torch.clamp would discard its result).
        with torch.no_grad():
            self.w.clamp_(min=1e-6)
        cos_sim_matrix = cos_sim_matrix * self.w + self.b

        label = torch.arange(stepsize, device=x.device)  # device-aware instead of a hard .cuda()
        label_rep = torch.repeat_interleave(label, repeats=gsize, dim=0)
        nloss = self.criterion(cos_sim_matrix.view(-1, stepsize), label_rep)
        prec1 = accuracy(cos_sim_matrix.view(-1, stepsize).detach(), label_rep.detach(), topk=(1,))[0]

        return nloss, prec1


if __name__ == "__main__":
    x = torch.randn(32, 10, 512).cuda()
    y = torch.randint(1000, size=(32,)).cuda()
    print(x.shape, y.shape)
    loss = Ge2e().cuda()  # keep the learnable w, b on the same device as x
    nloss, prec1 = loss(x, y)
    print(nloss, prec1)
lossfunction/proto.py ADDED
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.acc import accuracy

class proto(nn.Module):

    def __init__(self, **kwargs):
        super(proto, self).__init__()

        self.test_normalize = False

        self.criterion = torch.nn.CrossEntropyLoss()

        print('Initialised Prototypical Loss')

    def forward(self, x, label=None):

        assert x.size()[1] >= 2

        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        stepsize = out_anchor.size()[0]
        # Shapes (N, D, 1) and (1, D, N) broadcast into an N x N matrix of
        # negative squared distances, so matching pairs are pulled together
        # and the diagonal entries should dominate.
        output = -1 * (F.pairwise_distance(out_positive.unsqueeze(-1), out_anchor.unsqueeze(-1).transpose(0, 2)) ** 2)
        label = torch.arange(stepsize, device=x.device)  # device-aware instead of a hard .cuda()
        nloss = self.criterion(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]

        return nloss, prec1


if __name__ == "__main__":
    # pairwise_distance needs floating-point inputs, hence the .float()
    x = torch.randint(10, (10, 512, 10)).float()
    y = torch.randint(10, (10, 512, 10)).float()
    d = F.pairwise_distance(x, y)
    print(d)
    print(d.shape)
lossfunction/softmax.py ADDED
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
from utils.acc import accuracy


class Softmax(nn.Module):
    def __init__(self, nOut, nClasses):
        super(Softmax, self).__init__()

        self.test_normalize = True

        self.criterion = torch.nn.CrossEntropyLoss()
        self.fc = nn.Linear(nOut, nClasses)

        print('Initialised Softmax Loss')

    def forward(self, x, label=None):
        x = self.fc(x)
        nloss = self.criterion(x, label)
        prec1 = accuracy(x.detach(), label.detach(), topk=(1,))[0]

        return nloss, prec1
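A hedged smoke test for the plain softmax head; sizes are illustrative.

import torch
from lossfunction.softmax import Softmax

head = Softmax(nOut=512, nClasses=1000)
x = torch.randn(32, 512)
y = torch.randint(1000, (32,))
loss, top1 = head(x, y)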
lossfunction/softmaxproto.py ADDED
@@ -0,0 +1,37 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import lossfunction.softmax as softmax
import lossfunction.angleproto as angleproto

class SoftmaxProto(nn.Module):

    def __init__(self, nOut, nClasses):
        super(SoftmaxProto, self).__init__()

        self.test_normalize = True

        self.softmax = softmax.Softmax(nOut, nClasses)
        self.angleproto = angleproto.AngleProto()

        print('Initialised SoftmaxPrototypical Loss')

    def forward(self, x, label=None):

        if x.size()[1] != 2:
            # 2 is nPerSpeaker: regroup flat embeddings into pairs
            x = x.reshape(-1, 2, x.size()[-1])

        assert x.size()[1] == 2

        nlossS, prec1 = self.softmax(x.reshape(-1, x.size()[-1]), label.repeat_interleave(2))

        nlossP, _ = self.angleproto(x, None)
        # Typical magnitudes observed in training:
        # nlossP: 0.6678, nlossS: 13.6913

        # return nlossS + nlossP, prec1
        return nlossS + nlossP
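A hedged usage sketch; note that, unlike the other losses in this commit, SoftmaxProto returns only the loss (the accuracy return is commented out above).

import torch
from lossfunction.softmaxproto import SoftmaxProto

crit = SoftmaxProto(nOut=512, nClasses=1000)
x = torch.randn(32, 2, 512)           # (batch, nPerSpeaker=2, dim)
y = torch.randint(1000, (32,))
loss = crit(x, y)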
lossfunction/triplet.py ADDED
@@ -0,0 +1,101 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
from tuneThreshold import tuneThresholdfromScore
import random

class LossFunction(nn.Module):

    def __init__(self, hard_rank=0, hard_prob=0, margin=0, **kwargs):
        super(LossFunction, self).__init__()

        self.test_normalize = True

        self.hard_rank = hard_rank
        self.hard_prob = hard_prob
        self.margin = margin

        print('Initialised Triplet Loss')

    def forward(self, x, label=None):

        assert x.size()[1] == 2

        out_anchor = F.normalize(x[:, 0, :], p=2, dim=1)
        out_positive = F.normalize(x[:, 1, :], p=2, dim=1)
        stepsize = out_anchor.size()[0]

        # N x N matrix of negative squared distances (used as similarities)
        output = -1 * (F.pairwise_distance(out_anchor.unsqueeze(-1), out_positive.unsqueeze(-1).transpose(0, 2)) ** 2)

        negidx = self.mineHardNegative(output.detach())

        out_negative = out_positive[negidx, :]

        labelnp = numpy.array([1] * len(out_positive) + [0] * len(out_negative))

        ## calculate distances
        pos_dist = F.pairwise_distance(out_anchor, out_positive)
        neg_dist = F.pairwise_distance(out_anchor, out_negative)

        ## loss function
        nloss = torch.mean(F.relu(torch.pow(pos_dist, 2) - torch.pow(neg_dist, 2) + self.margin))

        scores = -1 * torch.cat([pos_dist, neg_dist], dim=0).detach().cpu().numpy()

        errors = tuneThresholdfromScore(scores, labelnp, [])

        return nloss, errors[1]

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Hard negative mining
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def mineHardNegative(self, output):

        negidx = []

        for idx, similarity in enumerate(output):

            simval, simidx = torch.sort(similarity, descending=True)

            if self.hard_rank < 0:

                ## Semi hard negative mining
                semihardidx = simidx[(similarity[idx] - self.margin < simval) & (simval < similarity[idx])]

                if len(semihardidx) == 0:
                    negidx.append(random.choice(simidx))
                else:
                    negidx.append(random.choice(semihardidx))

            else:

                ## Rank based negative mining
                simidx = simidx[simidx != idx]

                if random.random() < self.hard_prob:
                    negidx.append(simidx[random.randint(0, self.hard_rank)])
                else:
                    negidx.append(random.choice(simidx))

        return negidx


if __name__ == "__main__":
    x = torch.randn(32, 2, 512)
    loss = LossFunction()
    nloss, errors = loss(x)
    print(nloss, errors)
net/.DS_Store ADDED
Binary file (6.15 kB).
 
net/ECAPATDNN.py ADDED
@@ -0,0 +1,955 @@
"""A popular speaker recognition and diarization model.

Authors
 * Hwidong Na 2020
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchaudio

def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Creates a binary mask for each sequence.

    Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

    Arguments
    ---------
    length : torch.LongTensor
        Containing the length of each sequence in the batch. Must be 1D.
    max_len : int
        Max length for the mask, also the size of the second dimension.
    dtype : torch.dtype, default: None
        The dtype of the generated mask.
    device : torch.device, default: None
        The device to put the mask variable.

    Returns
    -------
    mask : tensor
        The binary mask.

    Example
    -------
    >>> length=torch.Tensor([1,2,3])
    >>> mask=length_to_mask(length)
    >>> mask
    tensor([[1., 0., 0.],
            [1., 1., 0.],
            [1., 1., 1.]])
    """
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().long().item()  # using arange to generate mask
    mask = torch.arange(
        max_len, device=length.device, dtype=length.dtype
    ).expand(len(length), max_len) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype

    if device is None:
        device = length.device

    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    return mask

def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for zero-padding.

    Arguments
    ---------
    L_in : int
    stride : int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
        L_out = stride * (n_steps - 1) + kernel_size * dilation
        padding = [kernel_size // 2, kernel_size // 2]
    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding

class _Conv1d(nn.Module):
    """This class implements 1d convolution.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    padding_mode : str
        This flag specifies the type of padding. See torch.nn documentation
        for more information.
    skip_transpose : bool
        If False, uses batch x time x channel convention of SpeakerRec.
        If True, uses batch x channel x time convention.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected.
        """

        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        """This function performs zero-padding on the time axis
        such that the length is unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.
        """

        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError(
                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )
        return in_channels

class _BatchNorm1d(nn.Module):
    """Applies 1d batch normalization to the input tensor.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    affine : bool
        When set to True, the affine parameters are learned.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    combine_batch_time : bool
        When true, it combines batch and time axis.

    Example
    -------
    >>> input = torch.randn(100, 10)
    >>> norm = BatchNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=False,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None and skip_transpose:
            input_size = input_shape[1]
        elif input_size is None:
            input_size = input_shape[-1]

        self.norm = nn.BatchNorm1d(
            input_size,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            input to normalize. 2d or 3d tensors are expected in input;
            4d tensors can be used when combine_dims=True.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(
                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
                )
        elif not self.skip_transpose:
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n

class Linear(torch.nn.Module):
    """Computes a linear transformation y = wx + b.

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e, the dimensionality of the
        output).
    input_shape : tuple
        It is the shape of the input tensor.
    input_size : int
        Size of the input tensor.
    bias : bool
        If True, the additive bias b is adopted.
    combine_dims : bool
        If True and the input is 4D, combine 3rd and 4th dimensions of input.

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
    """

    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        combine_dims=False,
    ):
        super().__init__()
        self.combine_dims = combine_dims

        if input_shape is None and input_size is None:
            raise ValueError("Expected one of input_shape or input_size")

        if input_size is None:
            input_size = input_shape[-1]
            if len(input_shape) == 4 and self.combine_dims:
                input_size = input_shape[2] * input_shape[3]

        # Weights are initialized following pytorch approach
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Returns the linear transformation of input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.
        """
        if x.ndim == 4 and self.combine_dims:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        wx = self.w(x)

        return wx

class Conv1d(_Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(skip_transpose=True, *args, **kwargs)


class BatchNorm1d(_BatchNorm1d):
    def __init__(self, *args, **kwargs):
        super().__init__(skip_transpose=True, *args, **kwargs)


class TDNNBlock(nn.Module):
    """An implementation of TDNN.

    Arguments
    ---------
    in_channels : int
        Number of input channels.
    out_channels : int
        The number of output channels.
    kernel_size : int
        The kernel size of the TDNN blocks.
    dilation : int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        return self.norm(self.activation(self.conv(x)))


class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size : int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
    ):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.ModuleList(
            [
                TDNNBlock(
                    in_channel,
                    hidden_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for i in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y


class SEBlock(nn.Module):
    """An implementation of squeeze-and-excitation block.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
            mask = mask.unsqueeze(1)
            total = mask.sum(dim=2, keepdim=True)
            s = (x * mask).sum(dim=2, keepdim=True) / total
        else:
            s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x


class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels : int
        The number of input channels.
    attention_channels : int
        The number of attention channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        """
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
            )
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        # (with full lengths the mask is simply all ones)
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats


class SERes2NetBlock(nn.Module):
    """An implementation of the building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SEBlock.

    Arguments
    ---------
    out_channels : int
        The number of output channels.
    res2net_scale : int
        The scale of the Res2Net block.
    kernel_size : int
        The kernel size of the TDNN blocks.
    dilation : int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.

    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
        )
        self.res2net_block = Res2NetBlock(
            out_channels, out_channels, res2net_scale, kernel_size, dilation
        )
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual


class ECAPATDNN(torch.nn.Module):
    """An implementation of the speaker embedding model in the paper
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    This variant takes raw waveforms and computes log-mel features
    internally via torchaudio.

    Arguments
    ---------
    device : str
        Device used, e.g., "cpu" or "cuda".
    activation : torch class
        A class for constructing the activation layers.
    channels : list of ints
        Output channels for TDNN/SERes2Net layer.
    kernel_sizes : list of ints
        List of kernel sizes for each layer.
    dilations : list of ints
        List of dilations for kernels in each layer.
    lin_neurons : int
        Number of neurons in linear layers.

    Example
    -------
    >>> input_audio = torch.rand([5, 32240])  # raw waveform (batch, samples)
    >>> compute_embedding = ECAPATDNN(80, lin_neurons=192)
    >>> outputs = compute_embedding(input_audio)
    >>> outputs.shape
    torch.Size([5, 192])
    """

    def __init__(
        self,
        input_size,
        device="cpu",
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
    ):

        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400,
                                                            hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=80)
        self.instancenorm = nn.InstanceNorm1d(80)  # num_features matches n_mels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
            )
        )

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                )
            )

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        """Returns the embedding vector.

        Arguments
        ---------
        x : torch.Tensor
            Raw waveform tensor of shape (batch, samples).
        """
        # Log-mel front end
        x = self.torchfb(x) + 1e-6
        x = x.log()
        x = self.instancenorm(x)

        xl = []
        for layer in self.blocks:
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        x = x.transpose(1, 2).squeeze(1)
        return x


class Classifier(torch.nn.Module):
    """This class implements the cosine similarity on the top of features.

    Arguments
    ---------
    device : str
        Device used, e.g., "cpu" or "cuda".
    lin_blocks : int
        Number of linear layers.
    lin_neurons : int
        Number of neurons in linear layers.
    out_neurons : int
        Number of classes.

    Example
    -------
    >>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2)
    >>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ])
    >>> outputs = outputs.unsqueeze(1)
    >>> cos = classify(outputs)
    >>> (cos < -1.0).long().sum()
    tensor(0)
    >>> (cos > 1.0).long().sum()
    tensor(0)
    """

    def __init__(
        self,
        input_size,
        device="cpu",
        lin_blocks=0,
        lin_neurons=192,
        out_neurons=1211,
    ):

        super().__init__()
        self.blocks = nn.ModuleList()

        for block_index in range(lin_blocks):
            self.blocks.extend(
                [
                    _BatchNorm1d(input_size),
                    Linear(input_size=input_size, n_neurons=lin_neurons),
                ]
            )
            input_size = lin_neurons

        # Final Layer
        self.weight = nn.Parameter(
            torch.FloatTensor(out_neurons, input_size, device=device)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Returns the output probabilities over speakers.

        Arguments
        ---------
        x : torch.Tensor
            Torch tensor.
        """
        for layer in self.blocks:
            x = layer(x)

        # Need to be normalized
        x = F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight))
        return x.unsqueeze(1)


if __name__ == '__main__':
    x = torch.zeros(32, 32240)
    model = ECAPATDNN(80, lin_neurons=192)
    out = model(x)
    print(out.shape)  # should be [32, 192]
net/ECAPA_TDNN.py ADDED
@@ -0,0 +1,246 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from torchinfo import summary
6
+
7
+
8
+
9
+ ''' Res2Conv1d + BatchNorm1d + ReLU
10
+ '''
11
+ class Res2Conv1dReluBn(nn.Module):
12
+ '''
13
+ in_channels == out_channels == channels
14
+ '''
15
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False, scale=4):
16
+ super().__init__()
17
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
18
+ self.scale = scale
19
+ self.width = channels // scale
20
+ self.nums = scale if scale == 1 else scale - 1
21
+
22
+ self.convs = []
23
+ self.bns = []
24
+ for i in range(self.nums):
25
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
26
+ self.bns.append(nn.BatchNorm1d(self.width))
27
+ self.convs = nn.ModuleList(self.convs)
28
+ self.bns = nn.ModuleList(self.bns)
29
+
30
+ def forward(self, x):
31
+ out = []
32
+ spx = torch.split(x, self.width, 1)
33
+ for i in range(self.nums):
34
+ if i == 0:
35
+ sp = spx[i]
36
+ else:
37
+ sp = sp + spx[i]
38
+ # Order: conv -> relu -> bn
39
+ sp = self.convs[i](sp)
40
+ sp = self.bns[i](F.relu(sp))
41
+ out.append(sp)
42
+ if self.scale != 1:
43
+ out.append(spx[self.nums])
44
+ out = torch.cat(out, dim=1)
45
+ return out
46
+
47
+
48
+
49
+ ''' Conv1d + BatchNorm1d + ReLU
50
+ '''
51
+ class Conv1dReluBn(nn.Module):
52
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
53
+ super().__init__()
54
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
55
+ self.bn = nn.BatchNorm1d(out_channels)
56
+
57
+ def forward(self, x):
58
+ return self.bn(F.relu(self.conv(x)))
59
+
60
+
61
+
62
+ ''' The SE connection of 1D case.
63
+ '''
64
+ class SE_Connect(nn.Module):
65
+ def __init__(self, channels, s=2):
66
+ super().__init__()
67
+ assert channels % s == 0, "{} % {} != 0".format(channels, s)
68
+ self.linear1 = nn.Linear(channels, channels // s)
69
+ self.linear2 = nn.Linear(channels // s, channels)
70
+
71
+ def forward(self, x):
72
+ out = x.mean(dim=2)
73
+ out = F.relu(self.linear1(out))
74
+ out = torch.sigmoid(self.linear2(out))
75
+ out = x * out.unsqueeze(2)
76
+ return out
77
+
78
+
79
+
80
+ ''' SE-Res2Block.
81
+ Note: residual connection is implemented in the ECAPA_TDNN.yaml model, not here.
82
+ '''
83
+
84
+ class SE_Res2Block(nn.Module):
85
+ def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
86
+ super().__init__()
87
+ self.block = nn.Sequential(
88
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
89
+ Res2Conv1dReluBn(channels, kernel_size, stride, padding, dilation, scale=scale),
90
+ Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
91
+ SE_Connect(channels)
92
+ )
93
+
94
+ def forward(self, x):
95
+ out = self.block(x)
96
+ return out + x
97
+
98
+
99
+
100
+ ''' Attentive weighted mean and standard deviation pooling.
101
+ '''
102
+ class AttentiveStatsPool(nn.Module):
103
+ def __init__(self, in_dim, bottleneck_dim):
104
+ super().__init__()
105
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
106
+ self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
107
+ self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
108
+
109
+ def forward(self, x):
110
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
111
+ alpha = torch.tanh(self.linear1(x))
112
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
113
+ mean = torch.sum(alpha * x, dim=2)
114
+ residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
115
+ std = torch.sqrt(residuals.clamp(min=1e-9))
116
+ return torch.cat([mean, std], dim=1)
117
+
118
+
119
+
120
+ ''' Implementation of
121
+ "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification".
122
+ Note that we DON'T concatenate the last frame-wise layer with non-weighted mean and standard deviation,
123
+ because it brings little improvment but significantly increases model parameters.
124
+ As a result, this implementation basically equals the A.2 of Table 2 in the paper.
125
+ '''
126
+ class ECAPA_TDNN(nn.Module):
127
+ def __init__(self, in_channels=80, channels=512, embd_dim=192):
128
+ super().__init__()
129
+ self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400,
130
+ hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=80)
131
+ self.instancenorm = nn.InstanceNorm1d(80)
132
+ self.layer1 = Conv1dReluBn(in_channels, channels, kernel_size=5, padding=2)
133
+ self.layer2 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
134
+ self.layer3 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8)
135
+ self.layer4 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8)
136
+
137
+ cat_channels = channels * 3
138
+ self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
139
+ self.pooling = AttentiveStatsPool(cat_channels, 128)
140
+ self.bn1 = nn.BatchNorm1d(cat_channels * 2)
141
+ self.linear = nn.Linear(cat_channels * 2, embd_dim)
142
+ self.bn2 = nn.BatchNorm1d(embd_dim)
143
+
144
+ def forward(self, x):
145
+ x = self.torchfb(x) + 1e-6
146
+ x = x.log()
147
+ x = self.instancenorm(x)
148
+ # print(x.shape)
149
+ # x = x.transpose(1, 2)
150
+ out1 = self.layer1(x)
151
+ out2 = self.layer2(out1) + out1
152
+ out3 = self.layer3(out1 + out2) + out1 + out2
153
+ out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3
154
+
155
+ out = torch.cat([out2, out3, out4], dim=1)
156
+ out = F.relu(self.conv(out))
157
+ # print(out.shape)
158
+ out = self.bn1(self.pooling(out))
159
+ # print(out.shape)
160
+ out = self.bn2(self.linear(out))
161
+ return out
162
+
163
+
+ class ECAPA_TDNN_ks5(nn.Module):
+     def __init__(self, in_channels=80, channels=512, embd_dim=192):
+         super().__init__()
+         self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400,
+                                                             hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=80)
+         self.instancenorm = nn.InstanceNorm1d(80)  # was 40; must match the 80 mel bins above
+         self.layer1 = Conv1dReluBn(in_channels, channels, kernel_size=7, padding=3)
+         self.layer2 = SE_Res2Block(channels, kernel_size=5, stride=1, padding=4, dilation=2, scale=8)
+         self.layer3 = SE_Res2Block(channels, kernel_size=5, stride=1, padding=6, dilation=3, scale=8)
+         self.layer4 = SE_Res2Block(channels, kernel_size=5, stride=1, padding=8, dilation=4, scale=8)
+
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
+         self.pooling = AttentiveStatsPool(cat_channels, 128)
+         self.bn1 = nn.BatchNorm1d(cat_channels * 2)
+         self.linear = nn.Linear(cat_channels * 2, embd_dim)
+         self.bn2 = nn.BatchNorm1d(embd_dim)
+
+     def forward(self, x):
+         x = self.torchfb(x) + 1e-6
+         x = x.log()
+         x = self.instancenorm(x)
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1) + out1
+         out3 = self.layer3(out1 + out2) + out1 + out2
+         out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn1(self.pooling(out))
+         out = self.bn2(self.linear(out))
+
+         return out
+
+
+ class ECAPA_TDNN_L2(nn.Module):
+     def __init__(self, in_channels=80, channels=512, embd_dim=192):
+         super().__init__()
+         self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400,
+                                                             hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=80)
+         self.instancenorm = nn.InstanceNorm1d(80)  # was 40; must match the 80 mel bins above
+         self.layer1 = Conv1dReluBn(in_channels, channels, kernel_size=5, padding=2)
+         self.layer2 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
+         self.layer3 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8)
+         self.layer4 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8)
+
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
+         self.pooling = AttentiveStatsPool(cat_channels, 128)
+         self.bn1 = nn.BatchNorm1d(cat_channels * 2)
+         self.linear = nn.Linear(cat_channels * 2, embd_dim)
+         self.bn2 = nn.BatchNorm1d(embd_dim)
+
+     def forward(self, x):
+         x = self.torchfb(x) + 1e-6
+         x = x.log()
+         x = self.instancenorm(x)
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1) + out1
+         out3 = self.layer3(out1 + out2) + out1 + out2
+         out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn1(self.pooling(out))
+         out = self.bn2(self.linear(out))
+         out_l2 = out / torch.norm(out, dim=1, keepdim=True)  # L2-normalise each embedding
+         return out_l2 * 512
+
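The `_L2` variant ends by L2-normalising each embedding and rescaling it to a fixed norm of 512 (the author's constant above, not a standard value). A minimal sketch of just that step, on made-up embeddings:

    import torch
    emb = torch.randn(4, 192)
    emb = emb / torch.norm(emb, dim=1, keepdim=True) * 512
    print(torch.norm(emb, dim=1))   # every row now has L2 norm ~= 512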
+ if __name__ == '__main__':
+     # Input size: batch_size * n_samples (raw 16 kHz waveform); 32240 samples => 202 frames, 35760 => 224
+     x = torch.zeros(32, 35760).cuda()
+     model = ECAPA_TDNN(in_channels=80, channels=512, embd_dim=192).cuda()  # put the model on the same device as x
+     summary(model, input_size=tuple(x.shape))
+     out = model(x)
+     print(out.shape)  # should be [32, 192]
+
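The sample-to-frame numbers in the comment above follow from hop_length=160 with torchaudio's default centre padding (assumed here): n_frames = n_samples // hop_length + 1.

    print(32240 // 160 + 1, 35760 // 160 + 1)   # 202 224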
net/ECAPA_TDNN_br.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio
+ from torchinfo import summary
+
+
+ ''' Res2Conv1d + BatchNorm1d + ReLU
+ '''
+ class Res2Conv1dReluBn(nn.Module):
+     '''
+     in_channels == out_channels == channels
+     '''
+     def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False, scale=4):
+         super().__init__()
+         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+         self.scale = scale
+         self.width = channels // scale
+         self.nums = scale if scale == 1 else scale - 1
+
+         self.convs = []
+         self.bns = []
+         for i in range(self.nums):
+             self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
+             self.bns.append(nn.BatchNorm1d(self.width))
+         self.convs = nn.ModuleList(self.convs)
+         self.bns = nn.ModuleList(self.bns)
+
+     def forward(self, x):
+         out = []
+         spx = torch.split(x, self.width, 1)
+         for i in range(self.nums):
+             if i == 0:
+                 sp = spx[i]
+             else:
+                 sp = sp + spx[i]
+             # Order: conv -> bn -> relu
+             sp = self.convs[i](sp)
+             sp = F.relu(self.bns[i](sp))
+             out.append(sp)
+         if self.scale != 1:
+             out.append(spx[self.nums])
+         out = torch.cat(out, dim=1)
+         return out
+
+
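With the values used later (channels=512, scale=8), this block splits the input into 8 groups of width 64, passes 7 of them through conv branches (each branch adds the previous branch's output to its split first), and forwards the last group untouched before re-concatenating. A minimal shape check, with made-up batch and frame sizes:

    block = Res2Conv1dReluBn(channels=512, kernel_size=3, padding=1, scale=8)
    y = block(torch.randn(4, 512, 200))
    print(y.shape)   # torch.Size([4, 512, 200]): 8 groups of 64 channels re-concatenated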
+ ''' Conv1d + BatchNorm1d + ReLU
+ '''
+ class Conv1dReluBn(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
+         super().__init__()
+         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+         self.bn = nn.BatchNorm1d(out_channels)
+
+     def forward(self, x):
+         return F.relu(self.bn(self.conv(x)))
+
+
+ ''' The SE connection for the 1D case.
+ '''
+ class SE_Connect(nn.Module):
+     def __init__(self, channels, s=2):
+         super().__init__()
+         assert channels % s == 0, "{} % {} != 0".format(channels, s)
+         self.linear1 = nn.Linear(channels, channels // s)
+         self.linear2 = nn.Linear(channels // s, channels)
+
+     def forward(self, x):
+         out = x.mean(dim=2)                 # squeeze: average each channel over time
+         out = F.relu(self.linear1(out))
+         out = torch.sigmoid(self.linear2(out))
+         out = x * out.unsqueeze(2)          # excite: rescale channels by the gates
+         return out
+
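The squeeze-and-excitation step averages each channel over time, bottlenecks to channels // s, and emits per-channel sigmoid gates that rescale the input. A minimal sketch with made-up sizes:

    se = SE_Connect(channels=512)    # s=2 -> 256-dim bottleneck
    y = se(torch.randn(4, 512, 200))
    print(y.shape)                   # torch.Size([4, 512, 200]); each channel scaled by a gate in (0, 1)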
+ ''' SE-Res2Block.
+ Note: a residual connection is applied inside this block (out + x); the model's
+ forward pass adds skip connections across blocks as well.
+ '''
+ class SE_Res2Block(nn.Module):
+     def __init__(self, channels, kernel_size, stride, padding, dilation, scale):
+         super().__init__()
+         self.block = nn.Sequential(
+             Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
+             Res2Conv1dReluBn(channels, kernel_size, stride, padding, dilation, scale=scale),
+             Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0),
+             SE_Connect(channels)
+         )
+
+     def forward(self, x):
+         out = self.block(x)
+         return out + x
+
+
+ ''' Attentive weighted mean and standard deviation pooling.
+ '''
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, in_dim, bottleneck_dim):
+         super().__init__()
+         # Use Conv1d with stride == 1 rather than Linear, so inputs need no transpose.
+         self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
+         self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper
+
+     def forward(self, x):
+         # DON'T use ReLU here: in experiments it made convergence harder.
+         alpha = torch.tanh(self.linear1(x))
+         alpha = torch.softmax(self.linear2(alpha), dim=2)
+         mean = torch.sum(alpha * x, dim=2)
+         residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
+         std = torch.sqrt(residuals.clamp(min=1e-9))
+         return torch.cat([mean, std], dim=1)
+
+
+ ''' Implementation of
+ "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification".
+ Note that we DON'T concatenate the last frame-wise layer with non-weighted mean and standard deviation,
+ because it brings little improvement but significantly increases model parameters.
+ As a result, this implementation basically equals the A.2 of Table 2 in the paper.
+ '''
+ class ECAPA_TDNN_br(nn.Module):
+     def __init__(self, in_channels=80, channels=512, embd_dim=192):
+         super().__init__()
+         self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400,
+                                                             hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=80)
+         self.instancenorm = nn.InstanceNorm1d(80)  # was 40; must match the 80 mel bins above
+         self.layer1 = Conv1dReluBn(in_channels, channels, kernel_size=5, padding=2)
+         self.layer2 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8)
+         self.layer3 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8)
+         self.layer4 = SE_Res2Block(channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8)
+
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, cat_channels, kernel_size=1)
+         self.pooling = AttentiveStatsPool(cat_channels, 128)
+         self.bn1 = nn.BatchNorm1d(cat_channels * 2)
+         self.linear = nn.Linear(cat_channels * 2, embd_dim)
+         self.bn2 = nn.BatchNorm1d(embd_dim)
+
+     def forward(self, x):
+         x = self.torchfb(x) + 1e-6
+         x = x.log()
+         x = self.instancenorm(x)
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1) + out1
+         out3 = self.layer3(out1 + out2) + out1 + out2
+         out4 = self.layer4(out1 + out2 + out3) + out1 + out2 + out3
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn1(self.pooling(out))
+         out = self.bn2(self.linear(out))
+         return out
+
+
+ if __name__ == '__main__':
+     # Input size: batch_size * n_samples (raw 16 kHz waveform)
+     x = torch.zeros(32, 32240).cuda()
+     model = ECAPA_TDNN_br(in_channels=80, channels=512, embd_dim=192).cuda()  # put the model on the same device as x
+     summary(model, input_size=tuple(x.shape))
+     out = model(x)
+     print(out.shape)  # should be [32, 192]
+
net/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .VGGVox import Vgg
+ from .vggvox1 import vgg
+ from .u_net import UNetVgg, UNetVggMask
+ from .ECAPA_TDNN import ECAPA_TDNN, ECAPA_TDNN_ks5, ECAPA_TDNN_L2
+ from .ECAPATDNN import ECAPATDNN
+ from .hrnet import hrnet
+ from .VGG_TDNN import Vggtdnn
+ from .ResNetSE34V2 import MainModel as ResNetSE34V2
+ from .ECAPA_TDNN_br import ECAPA_TDNN_br
+ from .hrtdnn import hrtdnn
+ from .ResTDNN import MainModel as ResTDNN
+ from .TDNN_VGG import TDNN_VGG
+ from .ResNet_TDNN import MainModel as ResNet_TDNN
+ from .TDNN_ResNet import TDNN_ResNet
+ from .hr_tdnn import hr_tdnn
+ from .swin_transformer import SwinTransformer
utils/.DS_Store ADDED
Binary file (6.15 kB). View file