File size: 4,702 Bytes
38f7d61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import argparse
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from model.CPICANN import CPICANN
def getAnnoMap(anno_path=None):
    """Load the structure annotation table and index it by its second column.

    Args:
        anno_path: Path (or file-like object) of the annotation CSV. When
            omitted, the module-level ``args.anno_struc`` set by the
            command-line parser is used.

    Returns:
        A tuple ``(annos, elems)`` of two dicts sharing the same keys (the
        value in column 1 of each row — presumably the class index predicted
        by the model; confirm against the annotation file):

        * ``annos[key]`` is the full row as returned by ``DataFrame.values``.
        * ``elems[key]`` is the set of whitespace-separated tokens found in
          column 3 of that row (presumably chemical element symbols).
    """
    if anno_path is None:
        # Fall back to the command-line setting parsed in the script entry point.
        anno_path = args.anno_struc

    rows = pd.read_csv(anno_path).values
    annos, elems = {}, {}
    for row in rows:
        key = row[1]
        annos[key] = row
        # split() without a separator ignores leading, trailing and repeated
        # whitespace, so a stray space can never add an empty-string token
        # that would make the subset test in filter_by_elem always fail.
        elems[key] = set(row[3].split())
    return annos, elems
def filter_by_elem(logits, elemMap, elem):
    """Suppress every class whose element set is not contained in ``elem``.

    Args:
        logits: 2-D tensor indexed as ``logits[:, class_index]``. It is
            modified in place.
        elemMap: Mapping from class index to the set of elements of that class.
        elem: Set of elements allowed in the sample.

    Returns:
        The same ``logits`` tensor, with the column of every class that is
        not a subset of ``elem`` overwritten by ``-10 ** 9`` so that it gets
        a negligible probability after softmax.
    """
    rejected = [int(idx) for idx, required in elemMap.items() if not required <= elem]
    if rejected:
        # One indexed write instead of one tensor write per rejected class:
        # with thousands of classes the per-column loop dominated run time.
        columns = torch.tensor(rejected, dtype=torch.long, device=logits.device)
        # Unary minus binds looser than **, so this is -(10 ** 9).
        logits[:, columns] = -10 ** 9
    return logits
def _normalize_pattern(pattern):
    """Linearly rescale ``pattern`` so its minimum maps to 0 and its maximum to 100."""
    pattern = np.asarray(pattern)
    # Array reductions run in native code; the builtin min()/max() would walk
    # the array element by element in Python.
    low = pattern.min()
    span = pattern.max() - low
    return (pattern - low) / span * 100


def _predict_top10(model, pattern, device, elemMap=None, elem=None):
    """Run ``model`` on one pattern and return the top-10 softmax result.

    When ``elemMap`` is given, classes not compatible with ``elem`` are
    suppressed with ``filter_by_elem`` before the softmax.
    """
    signal = torch.tensor(_normalize_pattern(pattern), dtype=torch.float32)
    signal = signal.reshape(1, 1, -1).to(device)
    with torch.no_grad():
        logits = model(signal)
        if elemMap is not None:
            logits = filter_by_elem(logits, elemMap, elem)
        probs = torch.nn.functional.softmax(logits.squeeze(), dim=0)
    return probs.topk(10)


def _joint_hit_rank(flags):
    """Return the 0-based rank by which both phases (1 and 2) have appeared, or None."""
    if 1 not in flags or 2 not in flags:
        return None
    return max(flags.index(1), flags.index(2))


def main():
    """Evaluate two-phase identification on randomly mixed validation patterns.

    For ``args.infTimes`` rounds, two validation samples with different
    values in annotation column 6 are drawn, their pattern files are mixed
    with random integer weights (20-80 and the complement to 100), and the
    model's top-10 predictions for the mixture are checked against both
    source samples. A round counts as a top-k hit when both samples are
    among the first k predictions; the hit rates for k = 2..10 are printed.

    Reads its configuration from the module-level ``args`` namespace.
    """
    annoMap, elemMap = getAnnoMap()

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    # map_location='cpu' lets a checkpoint saved on a GPU be restored on any
    # machine; the model is moved to the requested device afterwards.
    loaded = torch.load(args.load_path, map_location='cpu')
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    if args.elem_filtration:
        print('elem_filtration activated!')
    else:
        print('elem_filtration deactivated!')

    lst = pd.read_csv(args.anno_val).values
    top10Hits = np.zeros(10, dtype=np.int32)
    dataLen = len(lst)

    for _ in tqdm(range(args.infTimes)):
        # Re-draw until the two samples differ in annotation column 6.
        while True:
            c1, c2 = np.random.randint(0, dataLen, 2)
            anno1, anno2 = lst[c1], lst[c2]
            if anno1[6] != anno2[6]:
                break

        data1 = pd.read_csv(os.path.join(args.data_dir, f'{anno1[0]}.csv')).values
        data2 = pd.read_csv(os.path.join(args.data_dir, f'{anno2[0]}.csv')).values

        mixRate1 = np.random.randint(20, 81)
        mixRate2 = 100 - mixRate1
        data = mixRate1 * data1 + mixRate2 * data2

        # Union of the tokens listed in column 3 of both samples.
        elem = set(anno2[3].split()) | set(anno1[3].split())

        top10 = _predict_top10(
            model,
            data,
            args.device,
            elemMap=elemMap if args.elem_filtration else None,
            elem=elem,
        )

        # The first 7 characters of the sample name are compared with
        # column 0 of the predicted class's annotation row.
        id1, id2 = int(anno1[0][:7]), int(anno2[0][:7])
        m = [0] * 10
        for no, cls in enumerate(top10.indices.tolist()):
            pred = annoMap[cls]
            if pred[0] == id1:
                m[no] = 1
            elif pred[0] == id2:
                m[no] = 2

        # Both phases found within the first (rank + 1) predictions counts as
        # a hit for that k and every larger k.
        rank = _joint_hit_rank(m)
        if rank is not None:
            top10Hits[rank:] += 1

    for i in range(1, 10):
        print('top{}Hits: {}%'.format(i + 1, round(top10Hits[i] / args.infTimes * 100, 2)))
def str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value - including ``"False"`` - becomes ``True``. This parser
    maps the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--infTimes', default=1000, type=int, help='number of mixed pattern to be inferenced')
    parser.add_argument('--load_path', default='pretrained/bi-phase_checkpoint_2000.pth', type=str,
                        help='path to load pretrained bi-phase identification model')
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')
    # nargs='?' with const=True also allows the bare flag `--elem_filtration`.
    parser.add_argument('--elem_filtration', default=False, type=str2bool, nargs='?', const=True,
                        help='suppress classes whose elements are not all present in the sample')
    # `args` is deliberately module-level: main() and getAnnoMap() read it.
    args = parser.parse_args()
    main()
    print('THE END')
|