# Source: Hugging Face Spaces file viewer (Space status: Runtime error).
# File size: 2,973 bytes; revision 875baeb.
#! /usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy
from tuneThreshold import tuneThresholdfromScore
import random
class LossFunction(nn.Module):
    """Triplet loss over (anchor, positive) pairs with in-batch negative mining.

    Each row of the input supplies one anchor and one positive embedding.
    The negative for a row is picked from the positives of the *other* rows
    by ``mineHardNegative``.

    Args:
        hard_rank: If negative, semi-hard mining is used. Otherwise it is the
            highest (0-based) rank among the most similar negatives that may
            be picked when a hard negative is drawn.
        hard_prob: Probability of drawing a hard (rank based) negative rather
            than a uniformly random one. Only used when ``hard_rank >= 0``.
        margin: Margin added inside the hinge of the triplet loss; also the
            width of the semi-hard similarity window.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, hard_rank=0, hard_prob=0, margin=0, **kwargs):
        super(LossFunction, self).__init__()

        # Not read anywhere in this class; presumably consumed by the
        # surrounding training/evaluation code -- confirm against callers.
        self.test_normalize = True

        self.hard_rank = hard_rank
        self.hard_prob = hard_prob
        self.margin = margin

        print('Initialised Triplet Loss')

    @staticmethod
    def _pairwise_similarity(out_anchor, out_positive):
        """Return the (N, N) matrix of negated squared Euclidean distances.

        Entry ``[i, j]`` compares anchor ``i`` with positive ``j``. The
        reduction axis is spelled out explicitly because the axis reduced by
        ``F.pairwise_distance`` on broadcast 3-D inputs is not the same across
        PyTorch releases, which can produce an (N, D) result instead.
        """
        diff = out_anchor.unsqueeze(1) - out_positive.unsqueeze(0)  # (N, N, D)
        return -diff.pow(2).sum(dim=-1)

    def forward(self, x, label=None):
        """Compute the triplet loss and an error metric for one batch.

        Args:
            x: Tensor of shape (batch, 2, dim); ``x[:, 0]`` holds the anchors
                and ``x[:, 1]`` the positives.
            label: Unused.

        Returns:
            Tuple ``(nloss, err)`` where ``nloss`` is the mean hinge loss and
            ``err`` is element 1 of the value returned by
            ``tuneThresholdfromScore`` (presumably the equal error rate --
            confirm against the ``tuneThreshold`` module).
        """
        assert x.size()[1] == 2, 'expected x.size(1) == 2 (anchor, positive)'

        out_anchor = F.normalize(x[:, 0, :], p=2, dim=1)
        out_positive = F.normalize(x[:, 1, :], p=2, dim=1)

        # Mining works on a detached copy so that it never affects gradients.
        similarity = self._pairwise_similarity(out_anchor, out_positive)
        negidx = self.mineHardNegative(similarity.detach())

        out_negative = out_positive[negidx, :]

        # 1 marks an (anchor, positive) score, 0 an (anchor, negative) score.
        labelnp = numpy.array([1] * len(out_positive) + [0] * len(out_negative))

        ## calculate distances
        pos_dist = F.pairwise_distance(out_anchor, out_positive)
        neg_dist = F.pairwise_distance(out_anchor, out_negative)

        ## loss function
        nloss = torch.mean(F.relu(torch.pow(pos_dist, 2) - torch.pow(neg_dist, 2) + self.margin))

        # Negated distances, so that a larger score means "more similar".
        scores = -1 * torch.cat([pos_dist, neg_dist], dim=0).detach().cpu().numpy()

        errors = tuneThresholdfromScore(scores, labelnp, [])

        return nloss, errors[1]

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Hard negative mining
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def mineHardNegative(self, output):
        """Pick one negative index per row of a square similarity matrix.

        Args:
            output: (N, N) tensor where ``output[i, j]`` is the similarity of
                anchor ``i`` to positive ``j`` (larger means more similar).

        Returns:
            List of N zero-dimensional index tensors; entry ``i`` is never
            ``i`` itself.

        Raises:
            IndexError: If a row has no candidate other than its own positive
                (i.e. the batch holds fewer than two samples).
        """
        negidx = []

        for idx, similarity in enumerate(output):
            simval, simidx = torch.sort(similarity, descending=True)

            # Everything except the row's own positive, most similar first.
            candidates = simidx[simidx != idx]
            if len(candidates) == 0:
                raise IndexError('Cannot mine a negative from a batch with fewer than two samples')

            if self.hard_rank < 0:
                ## Semi hard negative mining
                # Negatives less similar than the positive, but by less than
                # the margin.
                semihardidx = simidx[(similarity[idx] - self.margin < simval) & (simval < similarity[idx])]

                if len(semihardidx) == 0:
                    # Fall back to any negative; the row's own positive is
                    # excluded so it can never be returned as a negative.
                    negidx.append(random.choice(candidates))
                else:
                    negidx.append(random.choice(semihardidx))

            else:
                ## Rank based negative mining
                if random.random() < self.hard_prob:
                    # randint is inclusive on both ends, so clamp the rank to
                    # the last valid candidate position.
                    top = min(self.hard_rank, len(candidates) - 1)
                    negidx.append(candidates[random.randint(0, top)])
                else:
                    negidx.append(random.choice(candidates))

        return negidx
if __name__ == "__main__":
    # Smoke test: run the loss once on random embeddings shaped
    # (batch=32, 2 embeddings per row, dim=512) and print both results.
    x = torch.randn(32, 2, 512)
    loss = LossFunction()
    nloss, errors = loss(x)
    print(nloss, errors)