# IRG/baselines/ClavaDDPM/eval/eval_detection.py
import numpy as np
import torch
import pandas as pd
import os
import sys
import json
import pickle
# Metrics
from sdmetrics import load_demo
from sdmetrics.single_table import LogisticDetection
from sdv.metadata import SingleTableMetadata
from matplotlib import pyplot as plt
import argparse
import warnings
warnings.filterwarnings("ignore")
def eval_detection(syn_data, real_data, domain_dict):
    # Build SDV metadata from the real table, then override the column types
    # using the domain dictionary so discrete columns are marked categorical
    # and the rest numerical.
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(real_data)

    for col, col_info in domain_dict.items():
        if col_info['type'] == 'discrete':
            metadata.update_column(column_name=col, sdtype='categorical')
        else:
            metadata.update_column(column_name=col, sdtype='numerical')

    # Drop any primary key that detect_from_dataframe may have auto-assigned.
    metadata.remove_primary_key()

    # LogisticDetection fits a classifier to tell real rows from synthetic ones;
    # 1.0 means the tables are indistinguishable, 0.0 means trivially separable.
    score = LogisticDetection.compute(
        real_data=real_data,
        synthetic_data=syn_data,
        metadata=metadata
    )
    print(f'score: {score}')
    return score
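
# Example call (a sketch, not part of the original pipeline): domain_dict is
# assumed to map column names to dicts with at least a 'type' field, which is
# the only key read above. The column names, file paths and model name below
# are illustrative only.
#
#   domain_dict = {
#       'age': {'type': 'continuous'},
#       'workclass': {'type': 'discrete'},
#   }
#   real_df = pd.read_csv('synthetic/adult/real.csv')
#   syn_df = pd.read_csv('synthetic/adult/clavaddpm.csv')
#   eval_detection(syn_df, real_df, domain_dict)
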
def reorder(real_data, syn_data, info):
    # Rearrange both tables so that all numerical columns come first, followed
    # by all categorical columns, and rebuild the metadata to match the new order.
    num_col_idx = info['num_col_idx']
    cat_col_idx = info['cat_col_idx']
    target_col_idx = info['target_col_idx']

    task_type = info['task_type']
    # The target column counts as numerical for regression tasks and as
    # categorical otherwise. Use list concatenation (not +=) so the index
    # lists stored inside info are not mutated in place.
    if task_type == 'regression':
        num_col_idx = num_col_idx + target_col_idx
    else:
        cat_col_idx = cat_col_idx + target_col_idx

    real_num_data = real_data[num_col_idx]
    real_cat_data = real_data[cat_col_idx]
    new_real_data = pd.concat([real_num_data, real_cat_data], axis=1)
    new_real_data.columns = range(len(new_real_data.columns))

    syn_num_data = syn_data[num_col_idx]
    syn_cat_data = syn_data[cat_col_idx]
    new_syn_data = pd.concat([syn_num_data, syn_cat_data], axis=1)
    new_syn_data.columns = range(len(new_syn_data.columns))

    # Remap the per-column metadata to the new positional indices.
    metadata = info['metadata']
    columns = metadata['columns']
    metadata['columns'] = {}

    for i in range(len(new_real_data.columns)):
        if i < len(num_col_idx):
            metadata['columns'][i] = columns[num_col_idx[i]]
        else:
            metadata['columns'][i] = columns[cat_col_idx[i - len(num_col_idx)]]

    return new_real_data, new_syn_data, metadata
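
# For reference, a minimal sketch of the info.json fields this function and the
# __main__ block below rely on (the indices, task type and column entries are
# illustrative; the real file ships with each dataset):
#
#   {
#       "task_type": "binclass",
#       "num_col_idx": [0, 2],
#       "cat_col_idx": [1, 3],
#       "target_col_idx": [4],
#       "metadata": {"columns": {"0": {...}, "1": {...}, ...}}
#   }
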
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataname', type=str, default='adult')
    parser.add_argument('--model', type=str, default='real')

    args = parser.parse_args()

    dataname = args.dataname
    model = args.model

    # Synthetic and real tables are expected as CSV files under synthetic/<dataname>/.
    syn_path = f'synthetic/{dataname}/{model}.csv'
    real_path = f'synthetic/{dataname}/real.csv'
    data_dir = f'data/{dataname}'

    print(syn_path)

    # Dataset description: column index lists, task type and per-column metadata.
    with open(f'{data_dir}/info.json', 'r') as f:
        info = json.load(f)

    syn_data = pd.read_csv(syn_path)
    real_data = pd.read_csv(real_path)

    save_dir = f'eval/density/{dataname}/{model}'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Switch to positional column labels so the index lists in info.json apply directly.
    real_data.columns = range(len(real_data.columns))
    syn_data.columns = range(len(syn_data.columns))

    # JSON keys are strings; convert them back to integer column indices.
    metadata = info['metadata']
    metadata['columns'] = {int(key): value for key, value in metadata['columns'].items()}

    new_real_data, new_syn_data, metadata = reorder(real_data, syn_data, info)

    # qual_report.generate(new_real_data, new_syn_data, metadata)

    score = LogisticDetection.compute(
        real_data=new_real_data,
        synthetic_data=new_syn_data,
        metadata=metadata
    )

    print(f'{dataname}, {model}: {score}')
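
# Typical invocation (the model name is illustrative; it is just the stem of a
# CSV stored under synthetic/<dataname>/):
#
#   python eval_detection.py --dataname adult --model clavaddpm
#
# This compares synthetic/adult/clavaddpm.csv against synthetic/adult/real.csv,
# reads data/adult/info.json, and prints the LogisticDetection score.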