IRG / baselines /ClavaDDPM /baseline_utils.py
Zilong-Zhao's picture
first commit
c4ac745
import pandas as pd
import os
import random
import numpy as np
from sdv.metadata import SingleTableMetadata
from sdv.metadata import MultiTableMetadata
from sdv.single_table import CTGANSynthesizer
from smote.sample_smote import sample_smote_baseline
def get_group_sizes(child_df, foreign_key):
group_sizes = {}
for group, group_df in child_df.groupby(foreign_key):
group_sizes[group] = len(group_df)
return group_sizes
def get_group_size_prob(group_size_dict):
freqs = {}
for _, freq in group_size_dict.items():
if freq not in freqs:
freqs[freq] = 0
freqs[freq] += 1
probs = {}
for freq, count in freqs.items():
probs[freq] = count / len(group_size_dict)
return probs
def get_multi_metadata(tables, relation_order):
metadata = MultiTableMetadata()
for table_name, val in tables.items():
df = val['original_df']
metadata.detect_table_from_dataframe(
table_name,
df
)
id_cols = [col for col in df.columns if '_id' in col]
for id_col in id_cols:
metadata.update_column(
table_name=table_name,
column_name=id_col,
sdtype='id'
)
domain = tables[table_name]['domain']
for col, dom in domain.items():
if col in df.columns:
if dom['type'] == 'discrete':
metadata.update_column(
table_name=table_name,
column_name=col,
sdtype='categorical',
)
elif dom['type'] == 'continuous':
metadata.update_column(
table_name=table_name,
column_name=col,
sdtype='numerical',
)
else:
raise ValueError(f'Unknown domain type: {dom["type"]}')
metadata.set_primary_key(
table_name=table_name,
column_name=f'{table_name}_id'
)
for parent, child in relation_order:
if parent is not None:
metadata.add_relationship(
parent_table_name=parent,
child_table_name=child,
parent_primary_key=f'{parent}_id',
child_foreign_key=f'{parent}_id'
)
return metadata
def get_merged_metadata(merged_df, parent_domain_dict, child_domain_dict):
metadata = SingleTableMetadata()
df_without_ids = merged_df.drop(columns=[col for col in merged_df.columns if '_id' in col])
metadata.detect_from_dataframe(df_without_ids)
for col in df_without_ids.columns:
domain_dict = None
if col in parent_domain_dict:
domain_dict = parent_domain_dict
elif col in child_domain_dict:
domain_dict = child_domain_dict
if domain_dict is not None:
if domain_dict[col]['type'] == 'discrete':
if domain_dict[col]['size'] < 1000:
metadata.update_column(
column_name=col,
sdtype='categorical',
)
else:
metadata.update_column(
column_name=col,
sdtype='numerical',
)
else:
metadata.update_column(
column_name=col,
sdtype='numerical',
)
metadata.remove_primary_key()
return metadata
def get_metadata(df, domain_dict=None):
metadata = SingleTableMetadata()
df_without_ids = df.drop(columns=[col for col in df.columns if '_id' in col])
metadata.detect_from_dataframe(df_without_ids)
if domain_dict is not None:
for col in df_without_ids.columns:
if domain_dict[col]['type'] == 'discrete':
if domain_dict[col]['size'] < 1000:
metadata.update_column(
column_name=col,
sdtype='categorical',
)
else:
metadata.update_column(
column_name=col,
sdtype='numerical',
)
else:
metadata.update_column(
column_name=col,
sdtype='numerical',
)
metadata.remove_primary_key()
return metadata, df_without_ids
def train_ctgan(df, domain_dict, batch_size):
metadata, df_without_ids = get_metadata(df, domain_dict)
synthesizer = CTGANSynthesizer(metadata, batch_size=batch_size, verbose=True)
synthesizer.fit(df_without_ids)
synthetic_data = synthesizer.sample(num_rows=len(df_without_ids))
return synthetic_data
def baseline_load_synthetic_data(path, tables):
syn = {}
for table, val in tables.items():
syn[table] = {}
syn[table]['df'] = pd.read_csv(os.path.join(
path,
'final',
f'{table}_synthetic.csv'
))
syn[table]['domain'] = val['domain']
return syn
def lava_load_synthetic_data(path, tables):
syn = {}
for table, val in tables.items():
syn[table] = {}
syn[table]['df'] = pd.read_csv(os.path.join(
path,
table,
'_final',
f'{table}_synthetic.csv'
))
syn[table]['domain'] = val['domain']
return syn
def sdv_load_synthetic_data(path, tables):
syn = {}
for table, val in tables.items():
syn[table] = {}
syn[table]['df'] = pd.read_csv(os.path.join(
path,
f'{table}.csv'
))
syn[table]['domain'] = val['domain']
return syn
def get_smote_res(df, domain_dict):
id_cols = [col for col in df.columns if '_id' in col]
df_no_id = df.drop(columns=id_cols)
num_cols = []
cat_cols = []
for col, val in domain_dict.items():
if val['type'] == 'discrete':
cat_cols.append(col)
else:
num_cols.append(col)
all_cols = num_cols + cat_cols
y_col = random.choice(all_cols)
if y_col in num_cols:
num_cols.remove(y_col)
is_regression = True
else:
cat_cols.remove(y_col)
is_regression = False
X_num = {}
X_num['train'] = df[num_cols].values
X_cat = {}
X_cat['train'] = df[cat_cols].values
y = {}
y['train'] = df[y_col].values
syn_x_num, syn_x_cat, res_y = sample_smote_baseline(
'smote_res',
X_num,
X_cat,
y,
eval_type = "synthetic",
k_neighbours = 5,
frac_samples = 1.0,
frac_lam_del = 0.0,
change_val = False,
save = False,
seed = 0,
is_regression=is_regression
)
res = np.concatenate((syn_x_num, syn_x_cat, res_y.reshape((-1, 1))), axis=1)
res_df = pd.DataFrame(res, columns=num_cols + cat_cols + [y_col])
res_df = res_df[df_no_id.columns]
return res_df, cat_cols, num_cols, y_col