File size: 1,051 Bytes
352cafd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import os
import numpy as np
import pandas as pd
class SIMIaccess:
def __init__(self, path=None):
assert os.path.exists(path), 'similarity matrix {} is not exists.'.format(path)
df_sim = pd.read_csv(path, index_col=0)
self.matrix = df_sim.values
self.labels = list(df_sim.columns)
self.label_to_index = dict()
for i, label in enumerate(self.labels):
self.label_to_index.update({label: i})
def findSimiElement(self, dt_label, gt_label):
dt_label = str(dt_label)
gt_label = str(gt_label)
assert dt_label in self.labels, "Category id not belong to this dataset."
assert gt_label in self.labels, "Category id not belong to this dataset."
if dt_label == gt_label:
return 1
else:
dt_index = self.label_to_index.get(dt_label)
gt_index = self.label_to_index.get(gt_label)
simi = self.matrix[dt_index, gt_index]
return simi
def findSimiMatrix(self):
return self.matrix |