|
|
import numpy as np |
|
|
import yaml |
|
|
import pickle |
|
|
import scipy.stats as stats |
|
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
class ProbPostProcessor:
    """Flag out-of-distribution (OOD) test samples from their ensemble variance.

    A reference distribution is built from the ``'ens_var'`` values of every
    loaded sample whose token is NOT in the test-token list (the "corner"
    cases).  Each test sample is then compared against that reference and
    written out together with an ``ood_flag``.
    """

    # Parametric scipy distributions selectable through ``method``.
    _PARAMETRIC_DISTS = {
        'pareto': stats.pareto,
        'lognorm': stats.lognorm,
        'powerlaw': stats.powerlaw,
    }

    def __init__(
        self,
        orig_pkl,
        test_token_pkl,
        save_pkl,
        method='lognorm',
        ood_percentile=5,
    ):
        """Load the samples, split them, and fit the reference distribution.

        Args:
            orig_pkl: Path to a pickle holding an indexable sequence of
                per-sample dicts; the keys read by this class are 'token',
                'ens_var', 'chosen_ind' and (in variance_cal)
                'orig_var_score'.
            test_token_pkl: Path to a YAML file (despite the parameter name)
                whose 'tokens' entry lists the test tokens.
            save_pkl: Output pickle path written by main_process.
            method: One of 'pareto', 'lognorm', 'powerlaw', 'kde' or
                'variance'; see build_corner_dist.
            ood_percentile: Percentage passed to build_corner_dist.
        """
        self.save_pkl = save_pkl
        self.method = method

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file -- only point this at result files from a trusted source.
        with open(orig_pkl, 'rb') as f:
            self.full_data = pickle.load(f)

        with open(test_token_pkl, 'r') as f:
            test_token = yaml.safe_load(f)['tokens']
        # An empty 'tokens:' entry is parsed as None; treat it as "no tokens".
        self.test_token = set(test_token or ())

        test_idx, corner_idx = [], []
        for i, d in enumerate(self.full_data):
            if d['token'] in self.test_token:
                test_idx.append(i)
            else:
                corner_idx.append(i)
        self.test_idx = test_idx

        print(f'Collecting test case: {len(test_idx)}; Corner case: {len(corner_idx)}')

        variances = [self.full_data[i]['ens_var'] for i in corner_idx]
        self.pdf_1, self.threshold_1 = self.build_corner_dist(
            variances, method, ood_percentile
        )

    def variance_cal(self, i):
        """Return the 'orig_var_score' entry of sample ``i``."""
        return self.full_data[i]['orig_var_score']

    def build_corner_dist(self, variances, method, ood_percentile=5):
        """Build the reference distribution and derive the OOD threshold.

        Args:
            variances: Non-empty sequence of reference variance values.
            method: 'pareto', 'lognorm' or 'powerlaw' (scipy distribution
                fitted to ``variances``), 'kde' (Gaussian kernel density
                estimate) or 'variance' (threshold on the raw values).
            ood_percentile: Percentile (0-100) used to place the threshold.

        Returns:
            Tuple ``(pdf, threshold)``.  For the density-based methods ``pdf``
            is a callable evaluating the fitted density and ``threshold`` is
            the ``ood_percentile``-th percentile of the densities of
            ``variances``.  For 'variance', ``pdf`` is ``variances`` itself
            and ``threshold`` is the ``(100 - ood_percentile)``-th percentile
            of the raw values.

        Raises:
            ValueError: If ``variances`` is empty or ``method`` is unknown.
        """
        if len(variances) == 0:
            raise ValueError(
                'Cannot build a reference distribution from zero samples.'
            )

        if method == 'variance':
            # No density is fitted: the threshold lives on the raw values.
            threshold = np.percentile(variances, 100 - ood_percentile)
            return variances, threshold

        if method == 'kde':
            pdf = stats.gaussian_kde(variances)
        elif method in self._PARAMETRIC_DISTS:
            dist = self._PARAMETRIC_DISTS[method]
            params = dist.fit(variances)

            def pdf(x, _dist=dist, _params=params):
                return _dist.pdf(x, *_params)
        else:
            raise ValueError(
                "Invalid method. Choose 'pareto', 'lognorm', 'powerlaw', "
                "'kde', or 'variance'."
            )

        known_densities = pdf(variances)
        threshold = np.percentile(known_densities, ood_percentile)
        return pdf, threshold

    def judge_ood(self, value):
        """Return a truthy result when ``value`` is out-of-distribution.

        For 'variance' the value is OOD when it exceeds the threshold; for
        the density-based methods it is OOD when its fitted density falls
        below the threshold.
        """
        # The fitted objects are stored as pdf_1 / threshold_1 by __init__;
        # the previous self.pdf / self.threshold names were never assigned.
        if self.method == 'variance':
            return value > self.threshold_1
        return self.pdf_1(value) < self.threshold_1

    def _evaluate_sample(self, idx):
        """Build the output record (token, score, choice, flag) for one sample."""
        sample = self.full_data[idx]
        score = sample['ens_var']
        record = {'token': sample['token']}

        if self.method == 'variance':
            # The threshold is the (100 - ood_percentile)-th percentile of the
            # reference values, so OOD is the upper tail (same as judge_ood).
            is_ood = score > self.threshold_1
        else:
            # Evaluate the density once and reuse it for flag and score.
            density = self.pdf_1(score)
            is_ood = density < self.threshold_1
            record['gpd_score'] = density

        record['chosen_ind'] = sample['chosen_ind']
        record['ood_flag'] = 1 if is_ood else 0
        return record

    def main_process(self):
        """Score every test sample, print the OOD rate, and pickle the records.

        The list of per-sample dicts is written to ``self.save_pkl``.
        """
        ret_data = [self._evaluate_sample(idx) for idx in tqdm(self.test_idx)]
        ood_rate = [record['ood_flag'] for record in ret_data]

        # np.mean of an empty list warns; report NaN explicitly instead.
        print(np.mean(ood_rate) if ood_rate else float('nan'))

        with open(self.save_pkl, 'wb') as f:
            pickle.dump(ret_data, f)

        print(f'Plan result Saved at {self.save_pkl}')
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: fit a Pareto reference distribution and write the
    # flagged test results.  NOTE(review): the '/xxx/...' paths look like
    # placeholders -- replace them with real locations before running.
    post_processor = ProbPostProcessor(
        orig_pkl='/xxx/r2se_test_result.pkl',
        test_token_pkl='/xxx/scene_filter/navtest.yaml',
        save_pkl='/xxx/output_test_result.pkl',
        method='pareto',
        ood_percentile=75,
    )
    post_processor.main_process()
|
|
|
|
|
|