# R2SE_model / ood_inference.py
# Uploaded by unknownuser6666 with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 663494c (verified).
import numpy as np
import yaml
import pickle
import scipy.stats as stats
from tqdm import tqdm
class ProbPostProcessor:
    """Flag test samples whose ensemble variance is out-of-distribution w.r.t. corner cases.

    Samples loaded from ``orig_pkl`` whose ``'token'`` appears in the YAML token
    list are treated as test cases; every other sample is a "corner case". A
    reference distribution is fitted to the corner cases' ``'ens_var'`` values,
    and each test case is flagged as OOD when its own ``'ens_var'`` falls
    outside that distribution (see :meth:`build_corner_dist`).
    """

    def __init__(
        self,
        orig_pkl,
        test_token_pkl,
        save_pkl,
        method='lognorm',
        ood_percentile=5,
    ):
        """Load data, split it into test / corner cases and fit the corner distribution.

        Args:
            orig_pkl: Path to a pickle holding a sequence of dicts; each dict is
                read for ``'token'``, ``'ens_var'`` and (in :meth:`main_process`)
                ``'chosen_ind'``.
            test_token_pkl: Path to a YAML file (despite the name) whose top-level
                ``'tokens'`` list identifies the test samples.
            save_pkl: Output path written by :meth:`main_process`.
            method: ``'pareto'``, ``'lognorm'``, ``'powerlaw'``, ``'kde'`` or
                ``'variance'``; see :meth:`build_corner_dist`.
            ood_percentile: Percentile used to place the OOD threshold.
        """
        self.save_pkl = save_pkl
        self.method = method
        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(orig_pkl, 'rb') as f:
            self.full_data = pickle.load(f)
        with open(test_token_pkl, 'r') as f:
            self.test_token = set(yaml.safe_load(f)['tokens'])

        test_idx, corner_idx = [], []
        for i, d in enumerate(self.full_data):
            if d['token'] in self.test_token:
                test_idx.append(i)
            else:
                corner_idx.append(i)
        self.test_idx = test_idx
        print(f'Collecting test case: {len(test_idx)}; Corner case: {len(corner_idx)}')

        # The reference distribution is fitted on corner-case variances only.
        variances = [self.full_data[i]['ens_var'] for i in corner_idx]
        self.pdf_1, self.threshold_1 = self.build_corner_dist(variances, method, ood_percentile)

    def variance_cal(self, i):
        """Return the stored ``'orig_var_score'`` of sample ``i`` (not called within this file)."""
        return self.full_data[i]['orig_var_score']

    def build_corner_dist(self, variances, method, ood_percentile=5):
        """Fit a reference distribution to corner-case variances and derive an OOD threshold.

        Args:
            variances: Non-empty sequence of corner-case ``'ens_var'`` values.
            method: ``'pareto'`` / ``'lognorm'`` / ``'powerlaw'`` fit the matching
                ``scipy.stats`` distribution; ``'kde'`` fits a Gaussian KDE;
                ``'variance'`` thresholds the raw values without fitting anything.
            ood_percentile: For density methods, the threshold is this percentile
                of the fitted density evaluated at the corner samples (lower
                density => OOD). For ``'variance'`` it is the
                ``100 - ood_percentile`` percentile of the raw values (higher
                value => OOD). In both cases roughly ``ood_percentile`` % of the
                corner samples themselves fall on the OOD side.

        Returns:
            ``(pdf, threshold)``. ``pdf`` is a callable returning densities, except
            in ``'variance'`` mode where the input ``variances`` are returned as-is.

        Raises:
            ValueError: If ``variances`` is empty or ``method`` is unknown.
        """
        if len(variances) == 0:
            raise ValueError('Cannot build a corner distribution from an empty set of variances.')

        if method == 'variance':
            return variances, np.percentile(variances, 100 - ood_percentile)

        parametric = {
            'pareto': stats.pareto,
            'lognorm': stats.lognorm,
            'powerlaw': stats.powerlaw,
        }
        if method in parametric:
            dist = parametric[method]
            params = dist.fit(variances)

            def pdf(x):
                return dist.pdf(x, *params)
        elif method == 'kde':
            pdf = stats.gaussian_kde(variances)  # callable: returns densities directly
        else:
            raise ValueError(
                "Invalid method. Choose 'pareto', 'lognorm', 'powerlaw', 'kde', or 'variance'."
            )

        known_densities = pdf(variances)
        threshold = np.percentile(known_densities, ood_percentile)
        return pdf, threshold

    def _ood_statistic(self, value):
        """Return the quantity compared against ``self.threshold_1`` for ``value``.

        This is the raw value in ``'variance'`` mode and the fitted density otherwise.
        """
        if self.method == 'variance':
            return value
        return self.pdf_1(value)

    def _is_ood_statistic(self, stat):
        """Apply the threshold: high raw variance, or low density, means OOD."""
        if self.method == 'variance':
            return stat > self.threshold_1
        return stat < self.threshold_1

    def judge_ood(self, value):
        """Return whether ``value`` is OOD w.r.t. the fitted corner distribution.

        The result is a boolean, or a boolean array when ``value`` is an array or
        the ``'kde'`` method is used.
        """
        return self._is_ood_statistic(self._ood_statistic(value))

    def main_process(self):
        """Flag every test sample as OOD or not and pickle the results to ``self.save_pkl``.

        Each output record holds ``'token'``, ``'gpd_score'`` (the statistic that
        was thresholded: the fitted density, or the raw ``'ens_var'`` in
        ``'variance'`` mode), ``'chosen_ind'`` (copied from the input) and
        ``'ood_flag'`` (1 = OOD, 0 = in-distribution). The mean OOD rate is printed.
        """
        ret_data = []
        ood_flags = []
        for idx in tqdm(self.test_idx):
            sample = self.full_data[idx]
            stat = self._ood_statistic(sample['ens_var'])
            ood_flag = 1 if self._is_ood_statistic(stat) else 0
            # NOTE(review): OOD and in-distribution samples currently keep the same
            # 'chosen_ind'; an earlier comment said OOD cases should "use original
            # pretrained" -- confirm whether a different key was intended.
            ret_data.append({
                'token': sample['token'],
                'gpd_score': stat,
                'chosen_ind': sample['chosen_ind'],
                'ood_flag': ood_flag,
            })
            ood_flags.append(ood_flag)
        # np.mean([]) warns and yields nan; report nan explicitly for an empty test set.
        print(np.mean(ood_flags) if ood_flags else float('nan'))
        with open(self.save_pkl, 'wb') as f:
            pickle.dump(ret_data, f)
        print(f'Plan result Saved at {self.save_pkl}')
if __name__ == '__main__':
    # Script entry point: paths and method are CLI options; the defaults reproduce
    # the previously hard-coded (placeholder) configuration exactly.
    import argparse

    parser = argparse.ArgumentParser(
        description='Flag OOD test samples against the corner-case ens_var distribution.'
    )
    parser.add_argument('--orig-pkl', default='/xxx/r2se_test_result.pkl',
                        help='pickle with per-sample results (token, ens_var, chosen_ind)')
    parser.add_argument('--test-token-yaml', default='/xxx/scene_filter/navtest.yaml',
                        help="YAML file whose 'tokens' list marks the test samples")
    parser.add_argument('--save-pkl', default='/xxx/output_test_result.pkl',
                        help='output pickle path')
    parser.add_argument('--method', default='pareto',
                        choices=['pareto', 'lognorm', 'powerlaw', 'kde', 'variance'])
    parser.add_argument('--ood-percentile', type=float, default=75)
    args = parser.parse_args()

    processer = ProbPostProcessor(
        args.orig_pkl,
        args.test_token_yaml,
        args.save_pkl,
        method=args.method,
        ood_percentile=args.ood_percentile,
    )
    processer.main_process()