Spaces:
Running
Running
| import argparse | |
| import json | |
| import os | |
| import shutil | |
| import time | |
| import warnings | |
| from typing import Dict | |
| import pandas as pd | |
| import psutil | |
| import torch | |
| import yaml | |
| from irg import TableConfig, IncrementalRelationalGenerator | |
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the train/generate script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        Namespace with ``config``, ``input_data_dir``, ``output_path`` and
        ``actions`` (a list of ``"train"`` / ``"gen"``; empty if the flag is omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", "-c", default="config/sample.yaml", help="config file, default is a sample"
    )
    parser.add_argument(
        "--input-data-dir", "-i", required=True, help="input data directory, data are in TABLE_NAME.csv"
    )
    parser.add_argument("--output-path", "-o", default="./out", help="output directory")
    parser.add_argument(
        "--actions", "-a", default=[], nargs="+", choices=["train", "gen"],
        help="steps to run: train, gen, or both",
    )
    return parser.parse_args(argv)
| def _validate_data_config(path: str, tables: Dict[str, TableConfig], descr: str): | |
| for tn, tc in tables.items(): | |
| table = pd.read_csv(os.path.join(path, f"{tn}.csv")) | |
| if tc.primary_key is not None: | |
| if table[tc.primary_key].duplicated().any(): | |
| raise ValueError(f"Primary key constraint {tc.primary_key} on {tn} is not fulfilled for {descr}.") | |
| for fk in tc.foreign_keys: | |
| parent = pd.read_csv(os.path.join(path, f"{fk.parent_table_name}.csv")) | |
| fk_str = f"{fk.child_table_name}{fk.child_column_names} -> {fk.parent_table_name}{fk.parent_column_names}" | |
| if parent[fk.parent_column_names].duplicated().any(): | |
| raise ValueError(f"Foreign key {fk_str} uniqueness on parent is not fulfilled for {descr}.") | |
| if (parent.merge( | |
| table[fk.child_column_names].dropna(), | |
| left_on=fk.parent_column_names, right_on=fk.child_column_names, how="outer", indicator="_merged" | |
| )["_merged"] == "right_only").any(): | |
| raise ValueError(f"Foreign key {fk_str} validity is not fulfilled for {descr}.") | |
| for a, b in tc.inequality: | |
| if (table[a] == table[b].rename(columns={bb: aa for bb, aa in zip(b, a)})).all(axis=1).any(): | |
| raise ValueError(f"Inequality [{a}, {b}] on {tn} is not fulfilled for {descr}.") | |
def main():
    """Run the requested ``train`` and/or ``gen`` actions from the command line.

    ``train`` validates the input CSVs, clears the output directory, fits an
    ``IncrementalRelationalGenerator`` and saves it as ``synthesizer.pt`` together
    with ``timing.json``. Without ``train``, the previously saved synthesizer and
    timings are loaded from the output directory instead. ``gen`` then samples
    into ``<output>/generated``, records the sampling time and validates the
    generated tables against the same constraints.

    Raises:
        RuntimeError: If the startup memory sanity check fails.
        ValueError: If the real or generated data violates a table constraint.
    """
    os.environ["WANDB_DISABLED"] = "true"
    os.environ["PYTHONWARNINGS"] = "ignore"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Abort when available + used memory accounts for less than 95% of the total.
    # NOTE(review): the motivation for this threshold is not visible here -- confirm.
    mem = psutil.virtual_memory()
    if mem.available + mem.used < 0.95 * mem.total:
        raise RuntimeError(f"Memory not available: {mem.available:,}, {mem.free:,}, {mem.used:,}, {mem.total:,}")
    warnings.filterwarnings("ignore")

    args = parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    tables = {tn: TableConfig.from_dict(ta | {"name": tn}) for tn, ta in config["tables"].items()}

    if "train" in args.actions:
        config["tables"] = tables
        # Validate the inputs BEFORE wiping the output directory, so that bad
        # input data cannot destroy the results of a previous run.
        _validate_data_config(args.input_data_dir, tables, "real")
        if os.path.exists(args.output_path):
            shutil.rmtree(args.output_path)

        start_time = time.time()
        synthesizer = IncrementalRelationalGenerator(**config)
        table_paths = {
            tn: os.path.join(args.input_data_dir, f"{tn}.csv") for tn in tables
        }
        synthesizer.fit(table_paths, args.output_path)
        end_time = time.time()

        times = {"fit": end_time - start_time}
        with open(os.path.join(args.output_path, "timing.json"), "w") as f:
            json.dump(times, f, indent=2)
        torch.save(synthesizer, os.path.join(args.output_path, "synthesizer.pt"))
    else:
        # The file holds a whole pickled Python object, not a plain state dict, so
        # it cannot be read with ``weights_only=True`` (the default since
        # PyTorch 2.6). It was written by this script's own train step, so full
        # unpickling is acceptable; do not point this at untrusted files.
        synthesizer = torch.load(os.path.join(args.output_path, "synthesizer.pt"), weights_only=False)
        with open(os.path.join(args.output_path, "timing.json"), "r") as f:
            times = json.load(f)

    if "gen" in args.actions:
        start_time = time.time()
        synthesizer.generate(os.path.join(args.output_path, "generated"), os.path.join(args.output_path, "model"))
        end_time = time.time()

        times["sample"] = end_time - start_time
        with open(os.path.join(args.output_path, "timing.json"), "w") as f:
            json.dump(times, f, indent=2)
        _validate_data_config(os.path.join(args.output_path, "generated"), tables, "synthetic")
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()