Spaces:
Running
Running
| import pandas as pd | |
| import os | |
| from sdv.metadata import MultiTableMetadata | |
def get_multi_metadata(tables, relation_order):
    """Build an SDV ``MultiTableMetadata`` object describing ``tables``.

    Args:
        tables: Mapping of table name to a dict holding at least
            ``'original_df'`` (the table's DataFrame) and ``'domain'``
            (mapping of column name to a dict with a ``'type'`` entry of
            either ``'discrete'`` or ``'continuous'``).
        relation_order: Iterable of ``(parent, child)`` table-name pairs.
            Pairs whose parent is ``None`` are skipped.

    Returns:
        The populated ``MultiTableMetadata`` instance.

    Raises:
        ValueError: If a domain entry for a column present in the
            DataFrame has a type other than ``'discrete'`` or
            ``'continuous'``.
    """
    metadata = MultiTableMetadata()

    for table_name, val in tables.items():
        df = val['original_df']
        metadata.detect_table_from_dataframe(table_name, df)

        # Any column whose name contains '_id' is marked as an id column.
        id_columns = [name for name in df.columns if '_id' in name]
        for name in id_columns:
            metadata.update_column(
                table_name=table_name,
                column_name=name,
                sdtype='id',
            )

        # Override the detected sdtype with the one implied by the domain.
        for column, spec in val['domain'].items():
            if column not in df.columns:
                continue
            kind = spec['type']
            if kind == 'discrete':
                sdtype = 'categorical'
            elif kind == 'continuous':
                sdtype = 'numerical'
            else:
                raise ValueError(f'Unknown domain type: {kind}')
            metadata.update_column(
                table_name=table_name,
                column_name=column,
                sdtype=sdtype,
            )

        # The primary key is always the column named '<table_name>_id'.
        metadata.set_primary_key(
            table_name=table_name,
            column_name=f'{table_name}_id',
        )

    for parent, child in relation_order:
        if parent is None:
            continue
        # The child's foreign key has the same name as the parent's
        # primary key: '<parent>_id'.
        key = f'{parent}_id'
        metadata.add_relationship(
            parent_table_name=parent,
            child_table_name=child,
            parent_primary_key=key,
            child_foreign_key=key,
        )

    return metadata
def baseline_load_synthetic_data(path, tables):
    """Load synthetic tables from ``<path>/final/<table>_synthetic.csv``.

    Args:
        path: Directory containing the ``final`` subdirectory.
        tables: Mapping of table name to a dict with a ``'domain'`` entry.

    Returns:
        A dict mapping each table name to ``{'df': DataFrame, 'domain': ...}``,
        where ``'domain'`` is taken unchanged from ``tables``.
    """
    final_dir = os.path.join(path, 'final')
    return {
        name: {
            'df': pd.read_csv(
                os.path.join(final_dir, f'{name}_synthetic.csv')
            ),
            'domain': info['domain'],
        }
        for name, info in tables.items()
    }
def clava_load_synthetic_data(path, tables):
    """Load synthetic tables from ``<path>/<table>/_final/<table>_synthetic.csv``.

    Args:
        path: Root directory containing one subdirectory per table.
        tables: Mapping of table name to a dict with a ``'domain'`` entry.

    Returns:
        A dict mapping each table name to ``{'df': DataFrame, 'domain': ...}``,
        where ``'domain'`` is taken unchanged from ``tables``.
    """
    return {
        name: {
            'df': pd.read_csv(
                os.path.join(path, name, '_final', f'{name}_synthetic.csv')
            ),
            'domain': info['domain'],
        }
        for name, info in tables.items()
    }
def sdv_load_synthetic_data(path, tables):
    """Load synthetic tables from ``<path>/<table>.csv``.

    Args:
        path: Directory containing one CSV file per table.
        tables: Mapping of table name to a dict with a ``'domain'`` entry.

    Returns:
        A dict mapping each table name to ``{'df': DataFrame, 'domain': ...}``,
        where ``'domain'`` is taken unchanged from ``tables``.
    """
    return {
        name: {
            'df': pd.read_csv(os.path.join(path, f'{name}.csv')),
            'domain': info['domain'],
        }
        for name, info in tables.items()
    }