File size: 3,654 Bytes
38f7d61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.CPICANN import CPICANN
from model.dataset import XrdDataset


def get_cs_anno(anno_path=None):
    """Build a lookup from class id to its group id read from the structure annotation CSV.

    Each row of the CSV maps the value in column 1 to the value in column 6.
    The ``csAnno`` naming in callers suggests column 6 holds a crystal-system id;
    NOTE(review): confirm against the annotation file. ``run_one_epoch`` uses these
    values as list indices, so they must be small non-negative integers.

    Args:
        anno_path: path to the annotation CSV. ``None`` (the default) falls back to
            the ``--anno_struc`` command-line option (``args.anno_struc``).

    Returns:
        dict mapping column-1 value -> column-6 value. If a key repeats, the last
        row wins.
    """
    if anno_path is None:
        anno_path = args.anno_struc
    rows = pd.read_csv(anno_path).values
    return {row[1]: row[6] for row in rows}


def get_acc(cls, label):
    correct_cnt = sum(cls.argmax(1) == label.int())
    cls_acc = correct_cnt / cls.shape[0]
    return cls_acc, correct_cnt


def run_one_epoch(model, dataloader, num_cs=7):
    """Evaluate ``model`` over every batch of ``dataloader`` and tally accuracy.

    In addition to overall top-1 accuracy, each sample is grouped by the id that
    ``get_cs_anno`` assigns to its class: ``cMtrx[i][j]`` counts samples whose
    ground-truth class maps to group ``i`` and whose predicted class maps to
    group ``j``.

    Args:
        model: classifier whose output is reduced with ``argmax(1)``, so dim 1 is
            assumed to index classes.
        dataloader: yields batches whose element 0 is the model input and element 1
            the class labels.
        num_cs: number of groups; every id returned by ``get_cs_anno`` must lie in
            ``range(num_cs)``. Defaults to 7, the value previously hard-coded.

    Returns:
        ``(loss, acc, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal)`` where
        ``loss`` is always 0.0 (no loss is computed during evaluation; kept for
        interface compatibility) and ``acc`` is the sample-weighted top-1 accuracy
        in percent (0.0 for an empty loader).
    """
    model.eval()

    csAnno = get_cs_anno()

    csCorrect = [0] * num_cs
    csTotal = [0] * num_cs
    cMtrx = [[0] * num_cs for _ in range(num_cs)]
    correct_cnt, total_cnt = 0, 0

    # The context manager guarantees the progress bar is closed.
    with tqdm(total=len(dataloader.dataset), desc='Evaluating... ', unit='data') as pbar, \
            torch.no_grad():
        for batch in dataloader:
            data = batch[0].to(args.device)
            label_cls = batch[1].to(args.device)

            logits = model(data)
            pbar.update(len(data))

            _, correct = get_acc(logits, label_cls)
            correct_cnt += correct.item()
            total_cnt += len(data)

            # One device->host transfer per batch instead of two .item() calls per sample.
            preds = logits.argmax(1).tolist()
            labels = label_cls.tolist()
            for gt, pred in zip(labels, preds):
                cs_gt = csAnno[gt]
                cMtrx[cs_gt][csAnno[pred]] += 1
                csTotal[cs_gt] += 1
                if gt == pred:
                    csCorrect[cs_gt] += 1

    # Weight every sample equally; averaging per-batch accuracies would give the
    # (smaller) final batch as much weight as a full one.
    acc = correct_cnt * 100 / total_cnt if total_cnt else 0.0
    return 0.0, acc, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal


def _format_confusion_rows(cMtrx):
    """Render each confusion-matrix row as ``count(row-percent%)`` cells; empty rows show 0.0%."""
    lines = []
    for row in cMtrx:
        row_total = sum(row)
        cells = []
        for v in row:
            # Guard against a group with no ground-truth samples (would divide by zero).
            pct = round(v / row_total * 100, 2) if row_total else 0.0
            cells.append("{}({}%)".format(v, pct))
        lines.append(" ".join(cells))
    return lines


def main():
    """Load the pretrained checkpoint, evaluate it on the validation set, and print metrics."""
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    # Load onto the CPU first so a checkpoint saved on a GPU also loads when
    # --device is the CPU or a different GPU; the model is moved to args.device below.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    loaded = torch.load(args.load_path, map_location='cpu')
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))

    print(model)

    valset = XrdDataset(args.data_dir, args.anno_val)
    # Evaluation does not need shuffling; a fixed order keeps runs reproducible.
    val_loader = DataLoader(valset, batch_size=128, num_workers=16, pin_memory=True, shuffle=False)

    loss_val, acc_val, correct_cnt, total_cnt, cMtrx, csCorrect, csTotal = run_one_epoch(model, val_loader)

    print("loss_val: ", loss_val)
    print("acc_val: ", acc_val)
    # Format the percentage directly; round(x, 5) * 100 leaves float artifacts like 95.12300000000001.
    pct = correct_cnt / total_cnt * 100 if total_cnt else 0.0
    print("{:.3f}%  ({}/{})".format(pct, correct_cnt, total_cnt))

    for line in _format_confusion_rows(cMtrx):
        print(line)

    print("csCorrect: ", csCorrect)
    print("csTotal: ", csTotal)


if __name__ == '__main__':
    # (flag, add_argument keyword arguments) for every command-line option.
    # main() and the helpers above read the parsed values through the
    # module-level `args` bound below.
    _cli_options = [
        ('--device', dict(type=str, default='cuda:0')),
        ('--data_dir', dict(type=str, default='data/val/')),
        ('--load_path', dict(type=str, default='pretrained/single-phase_checkpoint_0200.pth',
                             help='path to load pretrained single-phase identification model')),
        ('--anno_struc', dict(type=str, default='annotation/anno_struc.csv',
                              help='path to annotation file for training data')),
        ('--anno_val', dict(type=str, default='annotation/anno_val.csv',
                            help='path to annotation file for validation data')),
        ('--num_classes', dict(type=int, default=23073, metavar='N')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in _cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main()
    print('THE END')