File size: 7,650 Bytes
007d3b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
from datasets.rb_loader import RB_loader
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
import numpy as np
from tqdm import tqdm
from model import SwinModel_domain_agnostic as Model
from sklearn.metrics import roc_curve, auc
import json
import torch.nn.functional as F
# Path of the phase-1 RidgeFormer checkpoint evaluated by this script.
CHECKPOINT_PATH = "ridgeformer_checkpoints/phase1_scratch.pt"
# (exponent, FAR) pairs reported as "TAR@FAR=10^-<exponent>".
FAR_TARGETS = ((2, 0.01), (3, 0.001), (4, 0.0001))
# Ranks reported as "R@<k>".
RECALL_KS = (1, 10, 50, 100)


def tar_at_far(fpr, tpr, far, thresholds=None, verbose=False):
    """Estimate the true-accept rate at a target false-accept rate.

    The estimate is the mean TPR of the two ROC points that bracket ``far``:
    the last point with FPR < far and the first point with FPR >= far.

    Args:
        fpr: false-positive rates of a ROC curve (as returned by roc_curve).
        tpr: true-positive rates aligned with ``fpr``.
        far: target false-accept rate, e.g. 0.01.
        thresholds: optional score thresholds aligned with ``fpr``; only used
            for the verbose printout.
        verbose: if true, print ``tpr, index, fpr[, threshold]`` for both
            bracketing points.

    Returns:
        The estimated TAR as a float in [0, 1].

    Raises:
        ValueError: if no ROC point lies on one side of ``far`` (e.g. the
            scores contain no impostor pairs, so ``fpr`` is NaN).
    """
    fpr = np.asarray(fpr)
    tpr = np.asarray(tpr)
    below = np.flatnonzero(fpr < far)
    above = np.flatnonzero(fpr >= far)
    if below.size == 0 or above.size == 0:
        raise ValueError(
            f"ROC curve does not bracket FAR={far}; "
            "are both genuine and impostor pairs present?"
        )
    lower, upper = below[-1], above[0]
    if verbose:
        for idx in (lower, upper):
            point = (tpr[idx], idx, fpr[idx])
            if thresholds is not None:
                point += (thresholds[idx],)
            print(*point)
    return float((tpr[lower] + tpr[upper]) / 2)


def class_means(feats, labels):
    """Average the rows of ``feats`` that share a label.

    Args:
        feats: (N, D) tensor of per-sample features.
        labels: (N,) tensor of class labels aligned with ``feats``.

    Returns:
        ``(unique_labels, means)``, where ``unique_labels`` is sorted (as
        produced by ``torch.unique``) and ``means[i]`` is the mean of the rows
        labelled ``unique_labels[i]``.
    """
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    num_classes = len(unique_labels)
    # One scatter-add pass instead of a boolean mask per class (O(N) vs O(N*U)).
    sums = feats.new_zeros((num_classes, feats.shape[1])).index_add_(0, inverse, feats)
    counts = torch.bincount(inverse, minlength=num_classes).to(sums.dtype)
    return unique_labels, sums / counts.unsqueeze(1)


def pairwise_euclidean(a, b):
    """Return the (len(a), len(b)) matrix of Euclidean distances between rows.

    Uses ||a||^2 + ||b||^2 - 2 a.b instead of broadcasting ``a[:, None] - b``,
    which would materialise a len(a) x len(b) x dim intermediate. The
    expansion is evaluated in float64 to limit cancellation error. The result
    is cast back to the inputs' dtype.
    """
    a64 = np.asarray(a, dtype=np.float64)
    b64 = np.asarray(b, dtype=np.float64)
    sq = (
        (a64 ** 2).sum(axis=1)[:, None]
        + (b64 ** 2).sum(axis=1)[None, :]
        - 2.0 * (a64 @ b64.T)
    )
    # Rounding can push tiny squared distances slightly below zero; clamp before sqrt.
    return np.sqrt(np.maximum(sq, 0.0)).astype(np.result_type(a, b), copy=False)


def report_verification(tag, labels, scores, verbose=False):
    """Print ROC-AUC, EER and TAR@FAR for one verification protocol.

    Args:
        tag: protocol name used in the printed lines (e.g. "CB2CL").
        labels: 1-D array of 0/1 genuine-pair indicators.
        scores: 1-D array of similarity scores aligned with ``labels``.
        verbose: also print the ROC points bracketing each FAR target.

    Returns:
        dict with "ROCAUC", "EER" and "TAR@FAR=10^-k" entries, all in percent.
    """
    fpr, tpr, thresholds = roc_curve(labels, scores, drop_intermediate=True)
    tars = {}
    for exponent, far in FAR_TARGETS:
        tars[exponent] = tar_at_far(fpr, tpr, far, thresholds, verbose)
    fnr = 1 - tpr
    # EER: the FPR at the ROC point where FNR and FPR are closest.
    eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for {tag}: {roc_auc * 100} %")
    print(f"EER for {tag}: {eer * 100} %")
    metrics = {"ROCAUC": roc_auc * 100, "EER": eer * 100}
    for exponent, _ in FAR_TARGETS:
        print(f"TAR@FAR=10^-{exponent} for {tag}: {tars[exponent] * 100} %")
        metrics[f"TAR@FAR=10^-{exponent}"] = tars[exponent] * 100
    return metrics


def extract_features(model, dataloader, device):
    """Embed every (contactless, contact-based) pair produced by ``dataloader``.

    Each batch is assumed to be ``(x_cl, x_cb, target, cl_fname, cb_fname)``,
    matching the tuples RB_loader yields here. The filenames are ignored.

    Returns:
        ``(cl_feats, cb_feats, cl_feats_unnormed, cb_feats_unnormed, labels)``
        as CPU tensors. The first two are passed through ``l2_norm``. A single
        ``labels`` tensor serves both modalities because every pair shares one
        ``target``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    cl_normed, cb_normed, cl_raw, cb_raw, labels = [], [], [], [], []
    with torch.no_grad():
        for x_cl, x_cb, target, _cl_fname, _cb_fname in tqdm(dataloader):
            emb_cl, _ = model.get_embeddings(x_cl.to(device), ftype="contactless")
            emb_cb, _ = model.get_embeddings(x_cb.to(device), ftype="contactbased")
            cl_raw.append(emb_cl.cpu())
            cb_raw.append(emb_cb.cpu())
            cl_normed.append(l2_norm(emb_cl).cpu())
            cb_normed.append(l2_norm(emb_cb).cpu())
            labels.append(target.cpu())
    if not labels:
        raise ValueError("test dataloader produced no batches")
    return (
        torch.cat(cl_normed),
        torch.cat(cb_normed),
        torch.cat(cl_raw),
        torch.cat(cb_raw),
        torch.cat(labels),
    )


def main():
    """Evaluate the checkpoint on the RB test split and print/return metrics."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = RB_loader(split='test')
    dataloader = torch.utils.data.DataLoader(
        data, batch_size=16, num_workers=1, pin_memory=device.type == 'cuda'
    )
    model = Model().to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
    incompatible = model.load_state_dict(checkpoint, strict=False)
    # strict=False silently tolerates mismatches; surface them so a wrong
    # checkpoint is not mistaken for a trained model.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: checkpoint/model mismatch: "
            f"{len(incompatible.missing_keys)} missing keys, "
            f"{len(incompatible.unexpected_keys)} unexpected keys"
        )
    model.eval()
    print("Computing Test Recall")
    cl_feats, cb_feats, cl_feats_unnormed, cb_feats_unnormed, cl_labels = extract_features(
        model, dataloader, device
    )

    # CL2CB: contactless probes (rows) vs one mean contact-based template per
    # identity (columns). The printed tag "CB2CL" is kept for log compatibility.
    cb_labels, cb_templates = class_means(cb_feats, cl_labels)
    _, cb_templates_unnormed = class_means(cb_feats_unnormed, cl_labels)
    cl_np = cl_feats.numpy()
    similarities = cl_np @ cb_templates.numpy().T
    distance = -pairwise_euclidean(cl_feats_unnormed.numpy(), cb_templates_unnormed.numpy())
    # Fused score: dot-product similarity penalised by 0.1 x Euclidean distance
    # of the unnormalised embeddings (weight taken as-is from the original script).
    scores_mat = similarities + 0.1 * distance
    genuine = (cl_labels.view(-1, 1) == cb_labels.view(1, -1)).flatten().long().numpy()
    metrics = {"CB2CL": report_verification("CB2CL", genuine, scores_mat.ravel(), verbose=True)}
    scores_tensor = torch.from_numpy(scores_mat)
    for k in RECALL_KS:
        recall = compute_recall_at_k(scores_tensor, cl_labels, cb_labels, k) * 100
        metrics["CB2CL"][f"R@{k}"] = recall
        print(f"R@{k} for CB2CL: {recall} %")

    ################################################################################
    # CL2CL: every unordered pair of distinct contactless samples.
    num_cl = cl_np.shape[0]
    row, col = torch.triu_indices(row=num_cl, col=num_cl, offset=1)
    cl_sim = torch.from_numpy(cl_np @ cl_np.T)
    pair_scores = cl_sim[row, col].numpy()
    pair_labels = (cl_labels[row] == cl_labels[col]).long().numpy()
    metrics["CL2CL"] = report_verification("CL2CL", pair_labels, pair_scores)
    cl_labels_np = cl_labels.numpy()
    recall_score = Prev_RetMetric([cl_np, cl_np], [cl_labels_np, cl_labels_np], cl2cl=True)
    for k in RECALL_KS:
        recall = recall_score.recall_k(k=k) * 100
        metrics["CL2CL"][f"R@{k}"] = recall
        print(f"R@{k} for CL2CL: {recall} %")
    return metrics


if __name__ == '__main__':
    main()