|
|
import argparse
|
|
|
import os
|
|
|
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import torch
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from model.CPICANN import CPICANN
|
|
|
|
|
|
|
|
|
def getAnnoMap():
    """Index the structure annotation table by the value in its column 1.

    Reads the CSV file located at ``args.anno_struc`` (``args`` is the
    module-level namespace produced by the argument parser).

    Returns:
        A pair ``(annos, elems)`` of dicts sharing the same keys (the value
        found in column 1 of each row):

        * ``annos[key]`` is the complete row as returned by
          ``DataFrame.values``.
        * ``elems[key]`` is the set of tokens obtained by splitting the
          row's column 3 on single spaces.

        When several rows carry the same key, the last one wins.
    """
    rows = pd.read_csv(args.anno_struc).values

    annos = {row[1]: row for row in rows}
    elems = {row[1]: set(row[3].split(' ')) for row in rows}

    return annos, elems
|
|
|
|
|
|
|
|
|
def filter_by_elem(logits, elemMap, elem):
    """Suppress every class whose element set is not covered by ``elem``.

    Args:
        logits: tensor indexed as ``logits[:, class_index]``; it is modified
            in place.
        elemMap: mapping from class index to the set of elements associated
            with that class.
        elem: set of elements that are allowed.

    Returns:
        The same ``logits`` object, where each column whose element set is
        not a subset of ``elem`` has been overwritten with ``-10 ** 9``.
    """
    suppressed = -(10 ** 9)  # large negative score, same value as -10 ** 9

    for class_idx, required in elemMap.items():
        if required <= elem:
            # Every element of this class is allowed: leave its score alone.
            continue
        logits[:, class_idx] = suppressed

    return logits
|
|
|
|
|
|
|
|
|
def _load_model():
    """Build the CPICANN network and restore the weights at ``args.load_path``."""
    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    # Deserialize onto the CPU first so that a checkpoint written on a GPU can
    # still be opened when ``args.device`` is a CPU (or a different GPU); the
    # weights are moved to the requested device right after.
    loaded = torch.load(args.load_path, map_location='cpu')
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    return model


def _predict_top10(model, pattern, elemMap, elem):
    """Rescale ``pattern`` to the range 0..100 and return the model's top-10 classes.

    ``pattern`` is a numpy array; it is min-max normalised, flattened into a
    ``(1, 1, -1)`` float32 tensor and sent through ``model`` without gradient
    tracking. When ``args.elem_filtration`` is set, classes whose elements are
    not all contained in ``elem`` are suppressed before the softmax.

    Returns the result of ``topk(10)`` on the softmax probabilities.
    """
    lowest = pattern.min()
    span = pattern.max() - lowest
    scaled = (pattern - lowest) / span * 100

    v = torch.tensor(scaled, dtype=torch.float32).reshape(1, 1, -1)
    v = v.to(args.device)
    with torch.no_grad():
        logits = model(v)

    if args.elem_filtration:
        logits = filter_by_elem(logits, elemMap, elem)

    probs = torch.nn.functional.softmax(logits.squeeze(), dim=0)
    return probs.topk(10)


def _record_hits(top10Hits, m):
    """Credit ``top10Hits`` from the first rank at which both sources are found.

    ``m`` holds one marker per predicted rank: 1 for the first source entry,
    2 for the second one, 0 otherwise. If both markers occur, every counter
    from the later of their first positions onwards is incremented in place.
    """
    if 1 in m and 2 in m:
        first_joint_rank = max(m.index(1), m.index(2))
        top10Hits[first_joint_rank:] += 1


def main():
    """Evaluate the model on randomly mixed pairs of validation patterns.

    For each of ``args.infTimes`` trials, two rows of the validation
    annotation whose column 6 differs are drawn at random, their data files
    are blended with integer weights ``w`` and ``100 - w`` (``w`` drawn from
    20..80), and the blend is classified. ``top10Hits[k]`` counts the trials
    in which both source entries appear among the first ``k + 1`` predictions;
    the percentages for top-2 through top-10 are printed at the end.
    """
    annoMap, elemMap = getAnnoMap()

    model = _load_model()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    if args.elem_filtration:
        print('elem_filtration activated!')
    else:
        print('elem_filtration deactivated!')

    lst = pd.read_csv(args.anno_val).values

    top10Hits = np.array([0] * 10, dtype=np.int32)

    dataLen = len(lst)
    for _ in tqdm(range(args.infTimes)):
        # Draw two rows whose column 6 differs.
        # NOTE(review): this never terminates if every row of the validation
        # annotation carries the same value in column 6.
        while True:
            c1, c2 = np.random.randint(0, dataLen, 2)
            anno1, anno2 = lst[c1], lst[c2]
            if anno1[6] != anno2[6]:
                break

        # Column 0 names the data file of each entry.
        data1 = pd.read_csv(os.path.join(args.data_dir, f'{anno1[0]}.csv')).values
        data2 = pd.read_csv(os.path.join(args.data_dir, f'{anno2[0]}.csv')).values

        mixRate1 = np.random.randint(20, 81)
        mixRate2 = 100 - mixRate1

        data = mixRate1 * data1 + mixRate2 * data2
        # Union of the space-separated element tokens (column 3) of both entries.
        elem = set(anno2[3].strip().split(' ')) | set(anno1[3].strip().split(' '))

        top10 = _predict_top10(model, data, elemMap, elem)

        # The first 7 characters of column 0 are compared, as an integer,
        # with column 0 of the predicted structure annotation row.
        id1, id2 = int(anno1[0][:7]), int(anno2[0][:7])

        m = [0] * 10
        for no, class_idx in enumerate(top10.indices.tolist()):
            pred = annoMap[class_idx]
            if pred[0] == id1:
                m[no] = 1
            elif pred[0] == id2:
                m[no] = 2

        _record_hits(top10Hits, m)

    # Index 0 is never credited (two sources cannot both sit at rank 1), so
    # reporting starts at top-2.
    for i in range(1, 10):
        print('top{}Hits: {}%'.format(i + 1, round(top10Hits[i] / args.infTimes * 100, 2)))
|
|
|
|
|
|
|
|
|
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text, including ``"False"``, becomes ``True``. This parser
    reads the text instead.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"True"``, ``"false"``, ``"1"``, ``"0"``, ``"yes"``, ``"no"``.
            The empty string is treated as ``False``.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--infTimes', default=1000, type=int, help='number of mixed pattern to be inferenced')
    parser.add_argument('--load_path', default='pretrained/bi-phase_checkpoint_2000.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    # Accepts "--elem_filtration True/False" as before (but now honours
    # "False"), and also a bare "--elem_filtration", which enables the option.
    parser.add_argument('--elem_filtration', default=False, type=str2bool, nargs='?', const=True)

    # ``args`` stays a module-level name: the functions of this script read it
    # as a global.
    args = parser.parse_args()

    main()
    print('THE END')
|
|
|
|