File size: 4,302 Bytes
c4ac745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import json
import os
import shutil
import time
import warnings
from typing import Dict

import pandas as pd
import psutil
import torch
import yaml

from irg import TableConfig, IncrementalRelationalGenerator


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Returns:
        Namespace with the attributes ``config`` (path of the YAML config),
        ``input_data_dir`` (required; directory holding ``TABLE_NAME.csv``
        files), ``output_path`` (output directory) and ``actions`` (a list
        drawn from ``"train"`` and ``"gen"``; empty when the flag is omitted).
    """
    cli = argparse.ArgumentParser()
    # (flags, keyword options) in the order they should appear in --help.
    option_specs = (
        (
            ("--config", "-c"),
            {"default": "config/sample.yaml", "help": "config file, default is a sample"},
        ),
        (
            ("--input-data-dir", "-i"),
            {"required": True, "help": "input data directory, data are in TABLE_NAME.csv"},
        ),
        (
            ("--output-path", "-o"),
            {"default": "./out", "help": "output directory"},
        ),
        (
            ("--actions", "-a"),
            {"default": [], "nargs": "+", "choices": ["train", "gen"]},
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()


def _validate_data_config(path: str, tables: Dict[str, TableConfig], descr: str):
    """Check that the CSV tables under ``path`` satisfy their configured constraints.

    For every table ``<name>.csv`` listed in ``tables`` this verifies:

    * the primary key (when one is configured) contains no duplicated values;
    * for each foreign key, the referenced parent columns are unique and every
      non-null child key also occurs in the parent table;
    * for each configured inequality pair, no row has all columns of the first
      group equal to the corresponding columns of the second group.

    Args:
        path: Directory containing one ``<table name>.csv`` file per table.
        tables: Table configurations keyed by table name.
        descr: Label of the data set being checked, used in error messages.

    Raises:
        ValueError: On the first constraint found to be violated.
    """
    for table_name, table_config in tables.items():
        child = pd.read_csv(os.path.join(path, f"{table_name}.csv"))

        primary_key = table_config.primary_key
        if primary_key is not None and child[primary_key].duplicated().any():
            raise ValueError(
                f"Primary key constraint {primary_key} on {table_name} is not fulfilled for {descr}."
            )

        for fk in table_config.foreign_keys:
            parent = pd.read_csv(os.path.join(path, f"{fk.parent_table_name}.csv"))
            fk_str = f"{fk.child_table_name}{fk.child_column_names} -> {fk.parent_table_name}{fk.parent_column_names}"

            parent_keys = parent[fk.parent_column_names]
            if parent_keys.duplicated().any():
                raise ValueError(f"Foreign key {fk_str} uniqueness on parent is not fulfilled for {descr}.")

            # Rows that show up only on the right side of the outer join are
            # child keys without a matching parent row.
            child_keys = child[fk.child_column_names].dropna()
            joined = parent.merge(
                child_keys,
                left_on=fk.parent_column_names,
                right_on=fk.child_column_names,
                how="outer",
                indicator="_merged",
            )
            dangling = joined["_merged"] == "right_only"
            if dangling.any():
                raise ValueError(f"Foreign key {fk_str} validity is not fulfilled for {descr}.")

        for a, b in table_config.inequality:
            # Relabel the ``b`` columns as ``a`` so both frames compare column by column.
            relabel = dict(zip(b, a))
            left_side = child[a]
            right_side = child[b].rename(columns=relabel)
            fully_equal_rows = (left_side == right_side).all(axis=1)
            if fully_equal_rows.any():
                raise ValueError(f"Inequality [{a}, {b}] on {table_name} is not fulfilled for {descr}.")


def main():
    """Command-line entry point: train and/or sample from the relational generator.

    Steps:

    1. Set process environment variables, run a memory sanity check and
       silence Python warnings.
    2. Parse the CLI arguments and load the YAML config, turning each entry of
       ``config["tables"]`` into a ``TableConfig``.
    3. With the ``train`` action: validate the real data, wipe the output
       directory, fit a new ``IncrementalRelationalGenerator`` and store it as
       ``synthesizer.pt`` next to ``timing.json``. Without it: load the
       previously saved synthesizer and timings from the output directory.
    4. With the ``gen`` action: generate data into ``<output>/generated``,
       record the elapsed time and validate the generated tables.

    Raises:
        RuntimeError: If ``available + used`` memory is below 95% of the total.
        ValueError: If the real or the generated data violate the configured
            constraints (raised by ``_validate_data_config``).
    """
    os.environ["WANDB_DISABLED"] = "true"
    os.environ["PYTHONWARNINGS"] = "ignore"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    mem = psutil.virtual_memory()
    if mem.available + mem.used < 0.95 * mem.total:
        raise RuntimeError(f"Memory not available: {mem.available:,}, {mem.free:,}, {mem.used:,}, {mem.total:,}")
    warnings.filterwarnings("ignore")
    args = parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    tables = {tn: TableConfig.from_dict(ta | {"name": tn}) for tn, ta in config["tables"].items()}

    synthesizer_path = os.path.join(args.output_path, "synthesizer.pt")
    timing_path = os.path.join(args.output_path, "timing.json")
    generated_path = os.path.join(args.output_path, "generated")

    if "train" in args.actions:
        config["tables"] = tables

        # Validate the input first so that a bad data set does not cost us the
        # results of a previous run stored in the output directory.
        _validate_data_config(args.input_data_dir, tables, "real")
        if os.path.exists(args.output_path):
            shutil.rmtree(args.output_path)
        start_time = time.time()
        synthesizer = IncrementalRelationalGenerator(**config)
        table_paths = {
            tn: os.path.join(args.input_data_dir, f"{tn}.csv") for tn in tables
        }
        synthesizer.fit(table_paths, args.output_path)
        end_time = time.time()
        times = {"fit": end_time - start_time}
        # Do not rely on fit() having created the output directory.
        os.makedirs(args.output_path, exist_ok=True)
        with open(timing_path, "w") as f:
            json.dump(times, f, indent=2)
        torch.save(synthesizer, synthesizer_path)
    else:
        # The checkpoint is a whole pickled object, not a state dict, so it
        # cannot be read with weights_only=True (the default since PyTorch 2.6).
        # NOTE: this unpickles arbitrary code; only load files written by this script.
        synthesizer = torch.load(synthesizer_path, weights_only=False)
        with open(timing_path, "r") as f:
            times = json.load(f)

    if "gen" in args.actions:
        start_time = time.time()
        synthesizer.generate(generated_path, os.path.join(args.output_path, "model"))
        end_time = time.time()
        times["sample"] = end_time - start_time
        with open(timing_path, "w") as f:
            json.dump(times, f, indent=2)
        _validate_data_config(generated_path, tables, "synthetic")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()