import os

import torch
import torch.optim as optim
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm

import option
from dataset import Dataset
from model import Model
from test import test

torch.set_default_tensor_type('torch.FloatTensor')


def sparsity(arr, lamda2):
    """Return ``lamda2`` times the mean of the norm of ``arr`` taken along dim 0."""
    loss = torch.mean(torch.norm(arr, dim=0))
    return lamda2 * loss


def smooth(arr, lamda1):
    """Return ``lamda1`` times the summed squared difference of neighbours.

    ``arr2`` is ``arr`` shifted by one position along dim 0, with the last
    entry repeated so that the final difference is zero.
    """
    arr2 = torch.zeros_like(arr)
    arr2[:-1] = arr[1:]
    arr2[-1] = arr[-1]
    loss = torch.sum((arr2 - arr) ** 2)
    return lamda1 * loss


class SigmoidMAELoss(torch.nn.Module):
    """Thin wrapper that applies the stored criterion to ``(pred, target)``.

    NOTE(review): despite the class and attribute names, the stored
    criterion is ``MSELoss`` and ``forward`` never applies the sigmoid.
    The numerics are left untouched here; confirm the intended loss before
    renaming or changing it.
    """

    def __init__(self):
        super(SigmoidMAELoss, self).__init__()
        from torch.nn import Sigmoid
        self.__sigmoid__ = Sigmoid()
        self.__l1_loss__ = MSELoss()

    def forward(self, pred, target):
        """Return the stored criterion evaluated on ``pred`` and ``target``."""
        return self.__l1_loss__(pred, target)


class RTFM_loss(torch.nn.Module):
    """Cross-entropy on scores plus an ``alpha``-weighted feature-magnitude term."""

    def __init__(self, alpha, margin):
        """Store the weighting ``alpha`` and the magnitude ``margin``."""
        super(RTFM_loss, self).__init__()
        self.alpha = alpha
        self.margin = margin
        # The next two attributes are not used by forward(); they are kept
        # so the object's attribute surface stays unchanged.
        self.sigmoid = torch.nn.Sigmoid()
        self.mae_criterion = SigmoidMAELoss()
        self.criterion = torch.nn.CrossEntropyLoss()  # multi class

    def forward(self, score_normal, score_abnormal, nlabel, alabel, feat_n, feat_a):
        """Compute the total loss.

        Normal and abnormal labels/scores are concatenated (normal first)
        and fed to cross-entropy. The feature term averages ``feat_a`` and
        ``feat_n`` over dim 1, takes the L2 norm over the following dim,
        and penalises ``(|margin - ||feat_a||| + ||feat_n||) ** 2``.

        Returns:
            ``loss_cls + alpha * loss_rtfm``.
        """
        labels = torch.cat((nlabel, alabel), 0)
        scores = torch.cat((score_normal, score_abnormal), 0)
        # Follow the device of the scores instead of forcing CUDA, so the
        # loss also works when training falls back to the CPU.
        labels = labels.to(scores.device)

        loss_cls = self.criterion(scores, labels)  # CE loss in the score space

        loss_abn = torch.abs(self.margin - torch.norm(torch.mean(feat_a, dim=1), p=2, dim=1))
        loss_nor = torch.norm(torch.mean(feat_n, dim=1), p=2, dim=1)
        loss_rtfm = torch.mean((loss_abn + loss_nor) ** 2)

        loss_total = loss_cls + self.alpha * loss_rtfm
        return loss_total


def train(nloader, aloader, model, batch_size, seg_num, optimizer, device):
    """Run one optimisation step on one normal batch and one abnormal batch.

    Args:
        nloader: iterator yielding ``(input1, input2, input3, label)`` for
            normal samples.
        aloader: iterator yielding the same 4-tuple for abnormal samples.
        model: network called as ``model(input1, input2, input3)``; it must
            return five values (abnormal score, normal score, selected
            abnormal features, selected normal features, per-row scores).
        batch_size: number of samples taken from each loader.
        seg_num: segments per sample, used to flatten the per-row scores.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the concatenated inputs are moved to.
    """
    with torch.set_grad_enabled(True):
        model.train()

        ninput1, ninput2, ninput3, nlabel = next(nloader)
        ainput1, ainput2, ainput3, alabel = next(aloader)

        # Normal samples first, abnormal second; the slicing below relies
        # on this ordering.
        input1 = torch.cat((ninput1, ainput1), 0).to(device)
        input2 = torch.cat((ninput2, ainput2), 0).to(device)
        input3 = torch.cat((ninput3, ainput3), 0).to(device)

        score_abnormal, score_normal, feat_select_abn, feat_select_normal, scores = model(
            input1, input2, input3)

        scores = scores.view(batch_size * seg_num * 2, -1)  # (B * seg_num * 2, classes)

        # The abnormal half starts after batch_size * seg_num rows. The
        # offset was previously hard-coded as batch_size * 32, which is only
        # right when seg_num == 32.
        abn_scores, _ = torch.max(scores[batch_size * seg_num:], dim=1)

        nlabel = nlabel[0:batch_size]
        alabel = alabel[0:batch_size]

        loss_criterion = RTFM_loss(0.0001, 100)
        loss_sparse = sparsity(abn_scores, 8e-3)
        loss_smooth = smooth(abn_scores, 8e-4)
        loss_RTFM = loss_criterion(score_normal, score_abnormal, nlabel, alabel,
                                   feat_select_normal, feat_select_abn)
        cost = loss_RTFM + loss_smooth + loss_sparse

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()


def main():
    """Parse options, build loaders/model/optimizer, then train and evaluate.

    Every fifth step after step 5 the model is evaluated; whenever the
    accuracy beats the best seen so far, the weights are saved and the
    latest metrics are written to ``<output_dir>/<step>-step-ACC.txt``.
    """
    args = option.train_parser.parse_args()
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_nloader = DataLoader(Dataset(args, test_mode=False, is_normal=True),
                               batch_size=args.batch_size, shuffle=True,
                               num_workers=args.workers, pin_memory=True, drop_last=True)
    train_aloader = DataLoader(Dataset(args, test_mode=False, is_normal=False),
                               batch_size=args.batch_size, shuffle=True,
                               num_workers=args.workers, pin_memory=True, drop_last=True)
    test_loader = DataLoader(Dataset(args, test_mode=True),
                             batch_size=1, shuffle=False,
                             num_workers=args.workers, pin_memory=True)

    os.makedirs(args.save_models, exist_ok=True)

    feature_size = args.feature_size
    model = Model(feature_size, args.batch_size, args.seg_num)
    # Inputs are moved to `device` in train(); the model has to be there
    # too. This is a no-op if Model already placed itself on that device.
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.005)

    test_info = {"epoch": [], "TOP-1 ACC": []}
    best_ACC = -1
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Evaluation of the untrained model; the result is not recorded.
    acc, _ = test(dataloader=test_loader, model=model, device=device,
                  test_dataset=args.test_dataset)

    for step in tqdm(range(1, args.max_epoch + 1), total=args.max_epoch, dynamic_ncols=True):
        # Restart each iterator once it has yielded all of its batches.
        if (step - 1) % len(train_nloader) == 0:
            loadern_iter = iter(train_nloader)

        if (step - 1) % len(train_aloader) == 0:
            loadera_iter = iter(train_aloader)

        train(nloader=loadern_iter, aloader=loadera_iter, model=model,
              batch_size=args.batch_size, seg_num=args.seg_num,
              optimizer=optimizer, device=device)

        if step % 5 == 0 and step > 5:
            acc, _ = test(dataloader=test_loader, model=model, device=device,
                          test_dataset=args.test_dataset)
            test_info["epoch"].append(step)
            test_info["TOP-1 ACC"].append(acc)

            if test_info["TOP-1 ACC"][-1] > best_ACC:
                best_ACC = test_info["TOP-1 ACC"][-1]
                torch.save(model.state_dict(),
                           os.path.join(args.save_models,
                                        args.model_name + '-{}.pkl'.format(step)))

                file_path = os.path.join(output_dir, '{}-step-ACC.txt'.format(step))
                with open(file_path, "w") as fo:
                    for key in test_info:
                        fo.write("{}: {}\n".format(key, test_info[key][-1]))


if __name__ == '__main__':
    main()