File size: 294 Bytes
af8092c
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

def top1_accuracy(
    qf: torch.Tensor,
    ql: torch.Tensor,
    gf: torch.Tensor,
    gl: torch.Tensor,
) -> float:
    """Return the fraction of queries whose best-scoring gallery item shares their label.

    Scores are the inner products between every query feature row and every
    gallery feature row; for each query the gallery row with the highest
    score is taken as its top-1 match.

    Args:
        qf: 2-D query feature matrix, one row per query.
        ql: Query labels; presumably 1-D with one entry per row of ``qf``
            (TODO confirm against whatever writes the feature file).
        gf: 2-D gallery feature matrix with the same feature width as ``qf``.
        gl: Gallery labels, indexed by gallery row.

    Returns:
        Number of correct top-1 matches divided by the number of queries.

    Raises:
        ValueError: If there are no queries (accuracy would be 0/0) or no
            gallery rows (there is nothing to match against).
    """
    num_queries = ql.size(0)
    if num_queries == 0:
        raise ValueError("query set is empty; top-1 accuracy is undefined")
    if gf.size(0) == 0:
        raise ValueError("gallery set is empty; nothing to match against")

    # (num_queries, num_gallery) matrix of inner-product scores.
    scores = qf.mm(gf.t())
    # Only the single best match is needed, so take the argmax directly
    # instead of asking for a top-5 and discarding four of them; this also
    # works when the gallery holds fewer than five entries.
    best_match = scores.argmax(dim=1)
    num_correct = gl[best_match].eq(ql).sum().item()
    return num_correct / num_queries


def main() -> None:
    """Load the saved features and print the top-1 accuracy."""
    # NOTE(review): torch.load unpickles the file; only run this on feature
    # files from a trusted source.
    features = torch.load("features.pth")
    accuracy = top1_accuracy(
        features["qf"],
        features["ql"],
        features["gf"],
        features["gl"],
    )
    print(f"Acc top1:{accuracy:.3f}")


if __name__ == "__main__":
    main()