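"""Multi-table synthesis pipeline (ClavaDDPM-style).

Stages: (1) cluster each parent-child table pair, (2) train a conditional
diffusion model per relation, (3) synthesize tables in topological order and
match foreign keys across parents, (4) optionally evaluate the result with an
SDV-based multi-table report.
"""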
import argparse
import json
import os
import pickle
import time

import numpy as np
import pandas as pd
from sdv.metadata import MultiTableMetadata

from pipeline_utils import *
from pipeline_modules import *
from tab_ddpm.utils import *
from gen_multi_report import gen_multi_report
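# Note: helpers such as load_multi_table, pair_clustering_keep_id,
# child_training, sample_from_diffusion, conditional_sampling_by_group_size,
# get_df_without_id, and handle_multi_parent are assumed to come from the
# starred pipeline_utils / pipeline_modules imports above.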
def clava_clustering(tables, relation_order, save_dir, configs):
    relation_order_reversed = relation_order[::-1]
    all_group_lengths_prob_dicts = {}

    # Clustering: reuse a checkpoint if one exists, otherwise cluster each
    # parent-child pair, walking the relations from the leaves upward.
    cluster_ckpt_path = os.path.join(save_dir, 'cluster_ckpt.pkl')
    if os.path.exists(cluster_ckpt_path):
        print('Clustering checkpoint found, loading...')
        with open(cluster_ckpt_path, 'rb') as f:
            cluster_ckpt = pickle.load(f)
        tables = cluster_ckpt['tables']
        all_group_lengths_prob_dicts = cluster_ckpt['all_group_lengths_prob_dicts']
    else:
        for parent, child in relation_order_reversed:
            if parent is not None:
                print(f'Clustering {parent} -> {child}')
                # num_clusters may be a single int or a per-child-table dict.
                if isinstance(configs['clustering']['num_clusters'], dict):
                    num_clusters = configs['clustering']['num_clusters'][child]
                else:
                    num_clusters = configs['clustering']['num_clusters']
                parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts = pair_clustering_keep_id(
                    tables[child]['df'],
                    tables[child]['domain'],
                    tables[parent]['df'],
                    tables[parent]['domain'],
                    f'{child}_id',
                    f'{parent}_id',
                    num_clusters,
                    configs['clustering']['parent_scale'],
                    1,  # not used for now
                    parent,
                    child,
                    clustering_method=configs['clustering']['clustering_method'],
                )
                tables[parent]['df'] = parent_df_with_cluster
                tables[child]['df'] = child_df_with_cluster
                all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts

        cluster_ckpt = {
            'tables': tables,
            'all_group_lengths_prob_dicts': all_group_lengths_prob_dicts,
        }
        with open(cluster_ckpt_path, 'wb') as f:
            pickle.dump(cluster_ckpt, f)

    return tables, all_group_lengths_prob_dicts
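# all_group_lengths_prob_dicts maps each (parent, child) pair to the
# group-size probability distributions returned by pair_clustering_keep_id;
# conditional_sampling_by_group_size consumes them during synthesis below.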
def clava_training(tables, relation_order, save_dir, configs):
    models = {}
    for parent, child in relation_order:
        model_ckpt_path = os.path.join(save_dir, f'models/{parent}_{child}_ckpt.pkl')
        if os.path.exists(model_ckpt_path):
            print(f'{parent} -> {child} checkpoint found, loading...')
            with open(model_ckpt_path, 'rb') as f:
                models[(parent, child)] = pickle.load(f)
            continue

        print(f'Training {parent} -> {child}')
        df_with_cluster = tables[child]['df']
        # Drop all id columns before training; the model only sees feature
        # (and cluster) columns.
        id_cols = [col for col in df_with_cluster.columns if '_id' in col]
        df_without_id = df_with_cluster.drop(columns=id_cols)
        result = child_training(
            df_without_id,
            tables[child]['domain'],
            parent,
            child,
            configs,
        )
        models[(parent, child)] = result
        with open(model_ckpt_path, 'wb') as f:
            pickle.dump(result, f)

    # Root tables (no parent) get a placeholder column so downstream code can
    # treat them uniformly.
    for parent, child in relation_order:
        if parent is None:
            tables[child]['df']['placeholder'] = list(range(len(tables[child]['df'])))

    return tables, models
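# `models` is keyed by (parent, child); each value is the dict returned by
# child_training, whose fields consumed below include 'diffusion', 'dataset',
# 'label_encoders', 'classifier', 'df_info', 'model_params', 'T_dict', and
# 'column_orders'.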
def clava_synthesizing(tables, relation_order, save_dir, all_group_lengths_prob_dicts, models, configs, sample_scale=1):
    # If every final synthetic table already exists on disk, load and return
    # them instead of re-synthesizing.
    all_exist = True
    for key in tables:
        table_dir = os.path.join(
            configs['general']['workspace_dir'],
            key,
            f'{configs["general"]["sample_prefix"]}_final'
        )
        if not os.path.exists(os.path.join(table_dir, f'{key}_synthetic.csv')):
            all_exist = False
            break

    if all_exist:
        print('Synthetic tables found, loading...')
        synthetic_tables = {}
        for key in tables:
            table_dir = os.path.join(
                configs['general']['workspace_dir'],
                key,
                f'{configs["general"]["sample_prefix"]}_final'
            )
            synthetic_tables[key] = pd.read_csv(os.path.join(table_dir, f'{key}_synthetic.csv'))
        return synthetic_tables, 0, 0

    synthesizing_start_time = time.time()
    synthetic_tables = {}
    before_matching_path = os.path.join(save_dir, 'before_matching/synthetic_tables.pkl')
    if os.path.exists(before_matching_path):
        print('Synthetic tables found, loading...')
        with open(before_matching_path, 'rb') as f:
            synthetic_tables = pickle.load(f)
    # Synthesize each table in topological order: root tables are sampled
    # unconditionally, child tables are sampled conditioned on their parent's
    # synthetic cluster labels.
    for parent, child in relation_order:
        if (parent, child) in synthetic_tables:
            print(f'{parent} -> {child} synthetic table found, skip...')
            print('---------------------------------------------------')
            continue
        print(f'Generating {parent} -> {child}')
        result = models[(parent, child)]
        df_with_cluster = tables[child]['df']
        df_without_id = get_df_without_id(df_with_cluster)

        if parent is None:
            # Root table: unconditional sampling from the trained diffusion model.
            _, child_generated = sample_from_diffusion(
                df=df_without_id,
                df_info=result['df_info'],
                diffusion=result['diffusion'],
                dataset=result['dataset'],
                label_encoders=result['label_encoders'],
                sample_size=int(sample_scale * len(df_without_id)),
                model_params=result['model_params'],
                T_dict=result['T_dict'],
                sample_batch_size=configs['sampling']['batch_size']
            )
            # Prepend fresh primary keys to the generated rows.
            child_keys = list(range(len(child_generated)))
            generated_final_arr = np.concatenate(
                [
                    np.array(child_keys).reshape(-1, 1),
                    child_generated.to_numpy()
                ],
                axis=1
            )
            generated_final_df = pd.DataFrame(
                generated_final_arr,
                columns=[f'{child}_id'] + result['df_info']['num_cols'] + result['df_info']['cat_cols'] + [result['df_info']['y_col']]
            )
            generated_final_df = generated_final_df[tables[child]['df'].columns]
            synthetic_tables[(parent, child)] = {
                'df': generated_final_df,
                'keys': child_keys
            }
        else:
            # Find the already-synthesized parent table and its model.
            for key, val in synthetic_tables.items():
                if key[1] == parent:
                    parent_synthetic_df = val['df']
                    parent_keys = val['keys']
                    parent_result = models[key]
                    break
            child_result = models[(parent, child)]
            # Index of the parent column holding the child's cluster label,
            # which conditions the child's sampling.
            parent_label_index = parent_result['column_orders'].index(
                child_result['df_info']['y_col']
            )
            parent_synthetic_df_without_id = get_df_without_id(parent_synthetic_df)
            _, child_generated, child_sampled_group_sizes = conditional_sampling_by_group_size(
                df=df_without_id,
                df_info=child_result['df_info'],
                dataset=child_result['dataset'],
                label_encoders=child_result['label_encoders'],
                classifier=child_result['classifier'],
                diffusion=child_result['diffusion'],
                group_labels=parent_synthetic_df_without_id.values[:, parent_label_index].astype(float).astype(int).tolist(),
                group_lengths_prob_dicts=all_group_lengths_prob_dicts[(parent, child)],
                sample_batch_size=configs['sampling']['batch_size'],
                is_y_cond='none',
                classifier_scale=configs['sampling']['classifier_scale'],
            )
            # Repeat each parent key by its sampled group size to build the
            # child's foreign-key column, and assign fresh primary keys.
            child_foreign_keys_arr = np.repeat(parent_keys, child_sampled_group_sizes, axis=0).reshape(-1, 1)
            child_primary_keys_arr = np.arange(len(child_generated)).reshape(-1, 1)
            child_generated_final_arr = np.concatenate(
                [
                    child_primary_keys_arr,
                    child_generated.to_numpy(),
                    child_foreign_keys_arr
                ],
                axis=1
            )
            child_final_columns = [f'{child}_id'] + child_result['df_info']['num_cols'] + \
                child_result['df_info']['cat_cols'] + [child_result['df_info']['y_col']] + [f'{parent}_id']
            child_final_df = pd.DataFrame(
                child_generated_final_arr,
                columns=child_final_columns
            )
            # Restore the original column order, keeping only columns present
            # in the generated table.
            original_columns = [col for col in tables[child]['df'].columns if col in child_final_df.columns]
            child_final_df = child_final_df[original_columns]
            synthetic_tables[(parent, child)] = {
                'df': child_final_df,
                'keys': child_primary_keys_arr.flatten().tolist()
            }

        # Checkpoint the partial results after each table pair.
        with open(before_matching_path, 'wb') as f:
            pickle.dump(synthetic_tables, f)
    synthesizing_end_time = time.time()
    synthesizing_time_spent = synthesizing_end_time - synthesizing_start_time

    matching_start_time = time.time()
    # Matching: resolve foreign keys for tables that have more than one parent.
    final_tables = {}
    for parent, child in relation_order:
        if child not in final_tables:
            if len(tables[child]['parents']) > 1:
                final_tables[child] = handle_multi_parent(
                    child,
                    tables[child]['parents'],
                    synthetic_tables,
                    configs['matching']['num_matching_clusters'],
                    unique_matching=configs['matching']['unique_matching'],
                    batch_size=configs['matching']['matching_batch_size'],
                    no_matching=configs['matching']['no_matching']
                )
            else:
                final_tables[child] = synthetic_tables[(parent, child)]['df']
    matching_end_time = time.time()
    matching_time_spent = matching_end_time - matching_start_time

    # Keep only the original columns and write each table to disk.
    cleaned_tables: dict[str, pd.DataFrame] = {}
    for key, val in final_tables.items():
        cleaned_tables[key] = val[tables[key]['original_cols']]

    for key, val in cleaned_tables.items():
        table_dir = os.path.join(
            configs['general']['workspace_dir'],
            key,
            f'{configs["general"]["sample_prefix"]}_final'
        )
        os.makedirs(table_dir, exist_ok=True)
        val.to_csv(os.path.join(table_dir, f'{key}_synthetic.csv'), index=False)

    return cleaned_tables, synthesizing_time_spent, matching_time_spent
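# Worked example of the foreign-key assembly above: with parent_keys = [0, 1]
# and child_sampled_group_sizes = [2, 3], np.repeat yields [0, 0, 1, 1, 1],
# so parent 0 receives two synthetic children and parent 1 receives three.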
def clava_eval(tables, save_dir, configs, relation_order, synthetic_tables=None):
    # Build SDV multi-table metadata from the real tables.
    metadata = MultiTableMetadata()
    for table_name, val in tables.items():
        df = val['original_df']
        metadata.detect_table_from_dataframe(
            table_name,
            df
        )
        id_cols = [col for col in df.columns if '_id' in col]
        for id_col in id_cols:
            metadata.update_column(
                table_name=table_name,
                column_name=id_col,
                sdtype='id'
            )
        domain = tables[table_name]['domain']
        for col, dom in domain.items():
            if col in df.columns:
                if dom['type'] == 'discrete':
                    metadata.update_column(
                        table_name=table_name,
                        column_name=col,
                        sdtype='categorical',
                    )
                elif dom['type'] == 'continuous':
                    metadata.update_column(
                        table_name=table_name,
                        column_name=col,
                        sdtype='numerical',
                    )
                else:
                    raise ValueError(f'Unknown domain type: {dom["type"]}')
        metadata.set_primary_key(
            table_name=table_name,
            column_name=f'{table_name}_id'
        )

    for parent, child in relation_order:
        if parent is not None:
            metadata.add_relationship(
                parent_table_name=parent,
                child_table_name=child,
                parent_primary_key=f'{parent}_id',
                child_foreign_key=f'{parent}_id'
            )

    if synthetic_tables is None:
        # Load the synthetic tables written by clava_synthesizing.
        synthetic_tables = {}
        for table in tables:
            table_dir = os.path.join(
                configs['general']['workspace_dir'],
                table,
                f'{configs["general"]["sample_prefix"]}_final'
            )
            synthetic_tables[table] = pd.read_csv(os.path.join(table_dir, f'{table}_synthetic.csv'))

    report = gen_multi_report(
        configs['general']['data_dir'],
        configs['general']['workspace_dir'],
        'clava'
    )

    with open(os.path.join(save_dir, 'metadata.pkl'), 'wb') as f:
        pickle.dump(metadata, f)

    return report
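# gen_multi_report builds the report from the data and workspace directories
# rather than from the in-memory tables; the SDV metadata assembled above is
# pickled alongside it for later reuse.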
def load_configs(config_path):
    with open(config_path, 'r') as f:
        configs = json.load(f)
    save_dir = os.path.join(configs['general']['workspace_dir'], configs['general']['exp_name'])
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'models'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'before_matching'), exist_ok=True)

    with open(os.path.join(save_dir, 'args'), 'w') as file:
        json.dump(configs, file, indent=4)
    return configs, save_dir
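# A hypothetical minimal configs/config.json covering the keys this script
# reads (values are illustrative only, not prescriptive):
#
#   {
#     "general": {"data_dir": "data/my_dataset", "workspace_dir": "workspace/my_dataset",
#                 "exp_name": "exp1", "sample_prefix": "sample"},
#     "clustering": {"num_clusters": 50, "parent_scale": 1.0, "clustering_method": "kmeans"},
#     "sampling": {"batch_size": 10000, "classifier_scale": 1.0},
#     "matching": {"num_matching_clusters": 1, "unique_matching": true,
#                  "matching_batch_size": 1000, "no_matching": false},
#     "debug": {"sample_scale": 1}
#   }
#
# "debug" is optional; when present, debug.sample_scale scales the number of
# rows sampled for root tables.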
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='configs/config.json')
    args = parser.parse_args()

    all_start_time = time.time()

    clustering_start_time = time.time()
    configs, save_dir = load_configs(args.config_path)
    tables, relation_order, dataset_meta = load_multi_table(configs['general']['data_dir'])

    # Clustering
    tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs)
    clustering_end_time = time.time()
    clustering_time_spent = clustering_end_time - clustering_start_time

    # Training
    training_start_time = time.time()
    tables, models = clava_training(tables, relation_order, save_dir, configs)
    training_end_time = time.time()
    training_time_spent = training_end_time - training_start_time

    all_end_time = time.time()

    # Record fit time. Start from an empty dict if timing.json does not exist
    # yet, rather than failing on the first run.
    timing_path = os.path.join(os.path.dirname(configs['general']['workspace_dir']), 'timing.json')
    if os.path.exists(timing_path):
        with open(timing_path, 'r') as f:
            timing = json.load(f)
    else:
        timing = {}
    timing['fit'] = all_end_time - all_start_time
    with open(timing_path, 'w') as f:
        json.dump(timing, f, indent=2)

    # Synthesizing
    start_time = time.time()
    cleaned_tables, synthesizing_time_spent, matching_time_spent = clava_synthesizing(
        tables,
        relation_order,
        save_dir,
        all_group_lengths_prob_dicts,
        models,
        configs,
        sample_scale=1 if 'debug' not in configs else configs['debug']['sample_scale']
    )
    end_time = time.time()
    timing['sample'] = end_time - start_time
    with open(timing_path, 'w') as f:
        json.dump(timing, f, indent=2)

    # Eval
    # report = clava_eval(tables, save_dir, configs, relation_order, cleaned_tables)

    print('Time spent: ')
    print('Clustering: ', clustering_time_spent)
    print('Training: ', training_time_spent)
    print('Synthesizing: ', synthesizing_time_spent)
    print('Matching: ', matching_time_spent)
    print('Total: ', clustering_time_spent + training_time_spent + synthesizing_time_spent + matching_time_spent)
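# Usage (assuming this file is saved as pipeline.py):
#   python pipeline.py --config_path configs/config.json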