from __future__ import print_function
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from datasets.hkpoly_test import hktest
from datasets.original_combined_train import Combined_original
from datasets.rb_loader import RB_loader
from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
import timm
from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
from pprint import pprint
import numpy as np
from tqdm import tqdm
from combined_sampler import BalancedSampler
from torch.utils.data.sampler import BatchSampler
from torch.nn.parallel import DataParallel
from model import SwinModel_domain_agnostic as Model
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import json
from torch.utils.tensorboard import SummaryWriter
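
# Fine-tunes the Swin-based RidgeFormer model on combined contactless /
# contact-based fingerprint data and evaluates CB2CL and CL2CL retrieval on
# the HKPolyU test split after every epoch.
# Example invocation (script name and flag values are illustrative):
#   python train_combined_hkpoly.py --batch-size 32 --lr_swin 1e-5 --lr_linear 1e-4 --epochs 50
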
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer):
    model.train()
    steploss = list()
    for batch_idx, (x_cl, x_cb, target, category_cl, category_cb) in enumerate(pbar := tqdm(train_loader)):
        x_cl, x_cb, target, category_cl, category_cb = x_cl.to(device), x_cb.to(device), target.to(device), category_cl.to(device), category_cb.to(device)
        for optimizer in optimizers:
            optimizer.zero_grad()
        x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb = model(x_cl, x_cb)
        loss = loss_func(x_cl, x_cb, x_cl_tokens, x_cb_tokens, target, device, domain_class_cl, domain_class_cb, category_cl, category_cb)
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        if batch_idx % args.log_interval == 0 and args.dry_run:
            break
        pbar.set_description(f"Loss {loss.item():.6f}")
        # Store the scalar loss rather than the tensor so per-step graph/device
        # memory is not kept alive across the epoch.
        steploss.append(loss.item())
    # Guard against an empty list when --dry-run exits before any step is recorded.
    return (sum(steploss) / len(steploss) if steploss else 0.0), stepping
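
# Note: this local definition shadows the l2_norm imported from utils above.
# It rescales each row of a 2-D batch of embeddings to unit L2 length.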
def l2_norm(input):
    input_size = input.size()
    buffer = torch.pow(input, 2)
    normp = torch.sum(buffer, 1).add_(1e-12)
    norm = torch.sqrt(normp)
    _output = torch.div(input, norm.view(-1, 1).expand_as(input))
    output = _output.view(input_size)
    return output
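
def tar_at_far(fpr, tpr, far):
    # Small helper factored out of the repeated TAR@FAR computations below:
    # approximate the TPR at a target FAR by averaging the TPR of the two
    # ROC points that straddle the requested false-accept rate.
    lower_idx = max(i for i, val in enumerate(fpr) if val < far)
    upper_idx = min(i for i, val in enumerate(fpr) if val >= far)
    return (tpr[lower_idx] + tpr[upper_idx]) / 2
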
def hkpoly_test_fn(model, device, test_loader, epoch, plot_argument):
    model.eval()
    cl_feats, cb_feats, cl_labels, cb_labels = list(), list(), list(), list()
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(test_loader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            x_cl_feat, x_cl_token = model.get_embeddings(x_cl, 'contactless')
            x_cb_feat, x_cb_token = model.get_embeddings(x_cb, 'contactbased')
            x_cl_feat = l2_norm(x_cl_feat).cpu().detach().numpy()
            x_cb_feat = l2_norm(x_cb_feat).cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            cl_feats.append(x_cl_feat)
            cb_feats.append(x_cb_feat)
            cl_labels.append(label)
            cb_labels.append(label)
    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))
    suffix = "_".join(plot_argument)

    # CB2CL: contact-based gallery matched against contactless probes.
    scores_mat = np.dot(cl_feats, np.transpose(cb_feats))
    np.save("combined_models_scores/task1_cb2cl_score_matrix_" + str(epoch) + "_" + suffix + ".npy", scores_mat)
    scores = scores_mat.flatten().tolist()
    labels = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()
    fpr, tpr, _ = roc_curve(labels, scores)
    tar_far_102 = tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = tar_at_far(fpr, tpr, 0.0001)
    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute(fnr - fpr))]
    roc_auc = auc(fpr, tpr)
    plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig("combined_models_scores/roc_curve_cb2cl_task1_" + suffix + "_" + str(epoch) + ".png", dpi=300, bbox_inches='tight')
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    eer_cb2cl = EER * 100
    print(f"EER for CB2CL: {eer_cb2cl} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {cbcltf102} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {cbcltf103} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {cbcltf104} %")
    cl_label_np = cl_label.cpu().detach().numpy()
    cb_label_np = cb_label.cpu().detach().numpy()
    recall_score = Prev_RetMetric([cb_feats, cl_feats], [cb_label_np, cl_label_np], cl2cl=False)
    cl2cbk1 = recall_score.recall_k(k=1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    print(f"R@10 for CB2CL: {recall_score.recall_k(k=10) * 100} %")
    print(f"R@50 for CB2CL: {recall_score.recall_k(k=50) * 100} %")
    print(f"R@100 for CB2CL: {recall_score.recall_k(k=100) * 100} %")

    # CL2CL: contactless matched against contactless (assumed symmetric to the
    # CB2CL protocol above), with the trivial self-match pairs excluded.
    # main() unpacks these metrics alongside the CB2CL ones.
    scores_mat = np.dot(cl_feats, np.transpose(cl_feats))
    mask = ~np.eye(scores_mat.shape[0], dtype=bool)  # drop self-match scores
    scores = scores_mat[mask].tolist()
    labels = torch.eq(cl_label.view(-1, 1) - cl_label.view(1, -1), 0.0).numpy()[mask].tolist()
    fpr, tpr, _ = roc_curve(labels, scores)
    clcltf102 = tar_at_far(fpr, tpr, 0.01) * 100
    clcltf103 = tar_at_far(fpr, tpr, 0.001) * 100
    clcltf104 = tar_at_far(fpr, tpr, 0.0001) * 100
    fnr = 1 - tpr
    eer_cl2cl = fpr[np.nanargmin(np.absolute(fnr - fpr))] * 100
    print(f"EER for CL2CL: {eer_cl2cl} %")
    print(f"TAR@FAR=10^-2 for CL2CL: {clcltf102} %")
    print(f"TAR@FAR=10^-3 for CL2CL: {clcltf103} %")
    print(f"TAR@FAR=10^-4 for CL2CL: {clcltf104} %")
    # cl2cl=True flags same-domain retrieval to the metric helper.
    recall_score = Prev_RetMetric([cl_feats, cl_feats], [cl_label_np, cl_label_np], cl2cl=True)
    cl2clk1 = recall_score.recall_k(k=1) * 100
    print(f"R@1 for CL2CL: {cl2clk1} %")
    return cl2clk1, cl2cbk1, eer_cb2cl, eer_cl2cl, cbcltf102, cbcltf103, cbcltf104, clcltf102, clcltf103, clcltf104

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='RidgeFormer combined training with HKPolyU evaluation')
    parser.add_argument('--manifest-list', type=str, nargs='+', default=None,
                        help='manifest files from the different datasets to train on (one path per entry)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr_linear', type=float, default=1.0, metavar='LR',
                        help='learning rate for the linear heads (default: 1.0)')
    parser.add_argument('--lr_swin', type=float, default=1.0, metavar='LR',
                        help='learning rate for the Swin encoder (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='learning rate step gamma (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='number of warm-up epochs during which the encoder stays frozen')
    parser.add_argument('--model-name', type=str, default="ridgeformer",
                        help='name of the model for checkpointing')
    args = parser.parse_args()
    checkpoint_save_path = "ridgeformer_checkpoints/"
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # Create the log directory (including parents) if it does not exist.
    os.makedirs("experiment_logs/" + args.model_name, exist_ok=True)
    log_writer = SummaryWriter("experiment_logs/" + args.model_name + "/", comment=str(args.batch_size) + str(args.lr_linear) + str(args.lr_swin))
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list, split="train")
    val_dataset = hktest(split="test")
    # Balanced sampling keeps a fixed number of images per identity in each
    # batch, so every batch contains positive pairs for the pair-based loss.
    balanced_sampler = BalancedSampler(train_dataset, batch_size=args.batch_size, images_per_class=2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size=args.batch_size, drop_last=True)
    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)
    model = Model().to(device)
    # Warm-start from the phase-1 checkpoint; strict=False tolerates missing or extra keys.
    ckpt = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    model.load_state_dict(ckpt, strict=False)
    print("Number of Trainable Parameters: ", count_parameters(model))
    loss_func = DualMSLoss_FineGrained_domain_agnostic()
    # loss_func = DualMSLoss_FineGrained_domain_agnostic_ft()
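    # Single AdamW optimizer with separate parameter groups: the Swin backbone
    # trains with --lr_swin while the classifier and projection heads train
    # with --lr_linear.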
    optimizer_swin = optim.AdamW(
        [
            {"params": model.swin_cl.parameters(), "lr": args.lr_swin},
            {"params": model.classify.parameters(), "lr": args.lr_linear},
            {"params": model.linear_cl.parameters(), "lr": args.lr_linear},
            {"params": model.linear_cb.parameters(), "lr": args.lr_linear},
        ],
        weight_decay=1e-6,
        lr=args.lr_swin)
    scheduler_swin = MultiStepLR(optimizer_swin, milestones=[100], gamma=0.7)
    cl2cl_lst = list()
    cb2cl_lst = list()
    eer_cl2cl_lst = list()
    eer_cb2cl_lst = list()
    cbcltf102_lst, cbcltf103_lst, cbcltf104_lst, clcltf102_lst, clcltf103_lst, clcltf104_lst = list(), list(), list(), list(), list(), list()
    stepping = 1
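    # Warm-up schedule: train only the linear heads for the first --warmup
    # epochs, then unfreeze the Swin encoder for full fine-tuning.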
    for epoch in range(1, args.epochs + 1):
        print(f"Running epoch {epoch}")
        if epoch > args.warmup:
            print("Training with Swin")
            model.unfreeze_encoder()
        else:
            print("Training only linear")
            model.freeze_encoder()
        avg_step_loss, stepping = train(args, model, device, train_loader, test_loader, [optimizer_swin], epoch, loss_func, [args.model_name, str(args.batch_size), str(args.lr_linear), str(args.lr_swin)], stepping, log_writer)
        print(f"Learning rate at epoch {epoch} for swin = {scheduler_swin.get_last_lr()}")
        log_writer.add_scalar('Swin_LR/epoch', scheduler_swin.get_last_lr()[0], epoch)
        if epoch > args.warmup:
            scheduler_swin.step()
        cl2clk1, cl2cbk1, eer_cb2cl, eer_cl2cl, cbcltf102, cbcltf103, cbcltf104, clcltf102, clcltf103, clcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, [args.model_name, str(args.batch_size), str(args.lr_linear), str(args.lr_swin)])
        cl2cl_lst.append(cl2clk1)
        cb2cl_lst.append(cl2cbk1)
        eer_cl2cl_lst.append(eer_cl2cl)
        eer_cb2cl_lst.append(eer_cb2cl)
        cbcltf102_lst.append(cbcltf102)
        cbcltf103_lst.append(cbcltf103)
        cbcltf104_lst.append(cbcltf104)
        clcltf102_lst.append(clcltf102)
        clcltf103_lst.append(clcltf103)
        clcltf104_lst.append(clcltf104)
        log_writer.add_scalars('recall@1/epoch', {'CL2CL': cl2clk1, 'CB2CL': cl2cbk1}, epoch)
        log_writer.add_scalars('EER/epoch', {'CL2CL': eer_cl2cl, 'CB2CL': eer_cb2cl}, epoch)
        log_writer.add_scalars('TARFAR10^-2/epoch', {'CL2CL': clcltf102, 'CB2CL': cbcltf102}, epoch)
        log_writer.add_scalars('TARFAR10^-3/epoch', {'CL2CL': clcltf103, 'CB2CL': cbcltf103}, epoch)
        log_writer.add_scalars('TARFAR10^-4/epoch', {'CL2CL': clcltf104, 'CB2CL': cbcltf104}, epoch)
        log_writer.add_scalar('AvgLoss/epoch', avg_step_loss, epoch)
        torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name + "_" + str(args.lr_linear) + "_" + str(args.lr_swin) + "_" + str(args.batch_size) + "_" + str(epoch) + "_" + str(cl2clk1) + "_" + str(cl2cbk1) + ".pt")
    log_writer.close()
print(f"Maximum recall@1 for CL2CL: {max(cl2cl_lst)} at epoch {cl2cl_lst.index(max(cl2cl_lst))+1}")
print(f"Maximum recall@1 for CB2CL: {max(cb2cl_lst)} at epoch {cb2cl_lst.index(max(cb2cl_lst))+1}")
print(f"Minimum EER for CL2CL: {min(eer_cl2cl_lst)} at epoch {eer_cl2cl_lst.index(min(eer_cl2cl_lst))+1}")
print(f"Minimum EER for CB2CL: {min(eer_cb2cl_lst)} at epoch {eer_cb2cl_lst.index(min(eer_cb2cl_lst))+1}")
print(f"Maximum TAR@FAR=10^-2 for CB2CL: {max(cbcltf102_lst)} at epoch {cbcltf102_lst.index(max(cbcltf102_lst))+1}")
print(f"Maximum TAR@FAR=10^-3 for CB2CL: {max(cbcltf103_lst)} at epoch {cbcltf103_lst.index(max(cbcltf103_lst))+1}")
print(f"Maximum TAR@FAR=10^-4 for CB2CL: {max(cbcltf104_lst)} at epoch {cbcltf104_lst.index(max(cbcltf104_lst))+1}")
print(f"Maximum TAR@FAR=10^-2 for CL2CL: {max(clcltf102_lst)} at epoch {clcltf102_lst.index(max(clcltf102_lst))+1}")
print(f"Maximum TAR@FAR=10^-3 for CL2CL: {max(clcltf103_lst)} at epoch {clcltf103_lst.index(max(clcltf103_lst))+1}")
print(f"Maximum TAR@FAR=10^-4 for CL2CL: {max(clcltf104_lst)} at epoch {clcltf104_lst.index(max(clcltf104_lst))+1}")
if __name__ == '__main__':
main()