# File size: 3,654 Bytes | commit 38f7d61
import argparse
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from model.CPICANN import CPICANN
from model.dataset import XrdDataset
def get_cs_anno(anno_path=None):
    """Build a lookup table from the structure annotation CSV.

    Every row of the CSV contributes one entry: the value in the second
    column (index 1) is the key and the value in the seventh column
    (index 6) is the value.  Callers in this file use the values as indices
    into length-7 lists (presumably one slot per crystal system -- TODO
    confirm against the annotation file).

    Args:
        anno_path: Path or file-like object accepted by ``pandas.read_csv``.
            When omitted, the module-level ``args.anno_struc`` (the
            ``--anno_struc`` command-line option) is used, which is what the
            function always did before this parameter existed.

    Returns:
        dict mapping column-1 values to column-6 values.
    """
    if anno_path is None:
        # Only touch the CLI namespace when no explicit source is given, so
        # the function is usable when this module is imported.
        anno_path = args.anno_struc
    table = pd.read_csv(anno_path)
    # Read the two columns individually instead of going through
    # ``table.values``: a whole-frame array is upcast to one common dtype, so a
    # float column elsewhere in the file would turn these integers into floats
    # (and floats cannot be used as list indices by the callers).
    keys = table.iloc[:, 1].tolist()
    values = table.iloc[:, 6].tolist()
    return dict(zip(keys, values))
def get_acc(cls, label):
    """Compute top-1 accuracy for one batch.

    Args:
        cls: Tensor of class scores; the predicted class is the argmax along
            dim 1 (assumes shape ``(batch, num_classes)`` -- TODO confirm).
        label: Tensor of ground-truth class indices, cast to int before the
            comparison with the predictions.

    Returns:
        Tuple ``(cls_acc, correct_cnt)`` of 0-dim tensors: the fraction of
        correct predictions in the batch and the number of correct
        predictions.  Both are zero for an empty batch.
    """
    batch_size = cls.shape[0]
    if batch_size == 0:
        # Keep the return types stable (callers call ``.item()`` on both)
        # instead of dividing by zero.
        return (torch.zeros((), device=cls.device),
                torch.zeros((), dtype=torch.long, device=cls.device))
    # Tensor.sum() reduces in one call; the builtin sum() would iterate the
    # tensor element by element in Python.
    correct_cnt = (cls.argmax(1) == label.int()).sum()
    cls_acc = correct_cnt / batch_size
    return cls_acc, correct_cnt
def run_one_epoch(model, dataloader):
    """Run ``model`` over ``dataloader`` in eval mode and collect statistics.

    For every batch, ``batch[0]`` is fed to the model and ``batch[1]`` is
    used as the ground-truth class labels.  Class labels and predictions are
    additionally translated through the table returned by ``get_cs_anno()``
    into one of 7 groups (presumably crystal systems -- TODO confirm) to fill
    a 7x7 confusion matrix.

    Args:
        model: Callable module; put into eval mode here.
        dataloader: Iterable of batches that also exposes ``.dataset`` and
            ``len()``.

    Returns:
        Tuple ``(loss, acc, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal)``:
          * loss: always 0 -- no loss is computed in this function; the slot
            is kept so the return shape stays unchanged for callers.
          * acc: mean of the per-batch accuracies, in percent.
          * correct_cnt / total_cnt: number of exactly-correct class
            predictions and number of samples seen.
          * cMtrx: ``cMtrx[gt_group][pred_group]`` sample counts.
          * csCorrect: per ground-truth group, samples whose *class* was
            predicted exactly.
          * csTotal: per ground-truth group, number of samples.
    """
    num_groups = 7

    model.eval()
    csAnno = get_cs_anno()
    csCorrect = [0] * num_groups
    csTotal = [0] * num_groups
    cMtrx = [[0] * num_groups for _ in range(num_groups)]
    epoch_loss, cls_acc = 0, 0
    correct_cnt, total_cnt = 0, 0
    iters = len(dataloader)

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=len(dataloader.dataset), desc='Evaluating... ', unit='data') as pbar:
        for batch in dataloader:
            data = batch[0].to(args.device)
            label_cls = batch[1].to(args.device)
            with torch.no_grad():
                logits = model(data)

            _cls_acc, correct = get_acc(logits, label_cls)
            cls_acc += _cls_acc.item()
            correct_cnt += correct.item()
            total_cnt += len(data)

            # Pull labels and predictions to the host once per batch rather
            # than calling .item() (one device sync) for every sample.
            gt_list = label_cls.reshape(-1).tolist()
            pred_list = logits.argmax(1).tolist()
            for gt, pred in zip(gt_list, pred_list):
                cs_gt = csAnno[gt]
                cMtrx[cs_gt][csAnno[pred]] += 1
                csTotal[cs_gt] += 1
                if gt == pred:
                    csCorrect[cs_gt] += 1

            pbar.update(len(data))

    if iters == 0:
        # Empty loader: report zeros instead of dividing by zero.
        return 0.0, 0.0, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
    return epoch_loss / iters, cls_acc * 100 / iters, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal
def main():
    """Load a checkpoint, evaluate it on the validation set and print a report.

    Reads its configuration from the module-level ``args`` namespace
    (``num_classes``, ``load_path``, ``device``, ``data_dir``, ``anno_val``).
    Prints the model, the loss/accuracy returned by ``run_one_epoch``, the
    overall hit rate, the confusion matrix (one row per line, each cell as
    ``count(row-percentage%)``) and the per-group correct/total counters.
    """
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    # map_location lets a checkpoint that was saved on another device (e.g. a
    # GPU) be restored on whatever --device names.
    loaded = torch.load(args.load_path, map_location=args.device)
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    valset = XrdDataset(args.data_dir, args.anno_val)
    # No shuffling for evaluation: the reported accuracy is a mean over
    # batches, so a fixed order keeps it reproducible between runs.
    # Pinned memory only helps host-to-GPU copies.
    val_loader = DataLoader(valset, batch_size=128, num_workers=16,
                            pin_memory=str(args.device).startswith('cuda'),
                            shuffle=False)

    loss_val, acc_val, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal = run_one_epoch(model, val_loader)
    print("loss_val: ", loss_val)
    print("acc_val: ", acc_val)

    # Scale to percent before rounding so no float noise is printed
    # (e.g. 98.76500000000001), and guard against an empty validation set.
    overall_pct = round(correct_cnt / total_cnt * 100, 3) if total_cnt else 0.0
    print("{}% ({}/{})".format(overall_pct, correct_cnt, total_cnt))

    for row in cMtrx:
        row_total = sum(row)
        cells = []
        for v in row:
            # A group with no samples has an all-zero row; show 0.0% rather
            # than dividing by zero.
            pct = round(v / row_total * 100, 2) if row_total else 0.0
            cells.append("{}({}%) ".format(v, pct))
        print("".join(cells))

    print("csCorrect: ", csCorrect)
    print("csTotal: ", csTotal)
if __name__ == '__main__':
    # Command-line options, registered in the order they appear in --help.
    option_table = (
        ('--device', dict(default='cuda:0', type=str)),
        ('--data_dir', dict(default='data/val/', type=str)),
        ('--load_path', dict(default='pretrained/single-phase_checkpoint_0200.pth', type=str,
                             help='path to load pretrained single-phase identification model')),
        ('--anno_struc', dict(default='annotation/anno_struc.csv', type=str,
                              help='path to annotation file for training data')),
        ('--anno_val', dict(default='annotation/anno_val.csv', type=str,
                            help='path to annotation file for validation data')),
        ('--num_classes', dict(default=23073, type=int, metavar='N')),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    # ``args`` has to stay a module-level name: the functions in this file
    # read it as a global.
    args = parser.parse_args()

    main()
    print('THE END')
# End of file.