# IRG / main.py
# First commit (c4ac745) by Zilong-Zhao.
import argparse
import json
import os
import shutil
import time
import warnings
from typing import Dict
import pandas as pd
import psutil
import torch
import yaml
from irg import TableConfig, IncrementalRelationalGenerator
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the IRG train/generate pipeline.

    Returns:
        argparse.Namespace with attributes ``config``, ``input_data_dir``,
        ``output_path`` and ``actions`` (a list drawn from {"train", "gen"};
        empty when no action flag is given).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--config", "-c", default="config/sample.yaml", help="config file, default is a sample"
    )
    ap.add_argument(
        "--input-data-dir", "-i", required=True, help="input data directory, data are in TABLE_NAME.csv"
    )
    ap.add_argument("--output-path", "-o", default="./out", help="output directory")
    # One or more pipeline stages; argparse replaces (never mutates) the [] default.
    ap.add_argument("--actions", "-a", default=[], nargs="+", choices=["train", "gen"])
    return ap.parse_args()
def _validate_data_config(path: str, tables: Dict[str, TableConfig], descr: str):
    """Check primary-key, foreign-key and inequality constraints on CSV tables.

    Reads ``{path}/{table_name}.csv`` for every configured table and raises
    ``ValueError`` on the first violated constraint. ``descr`` labels the data
    set ("real" or "synthetic") in the error message.

    Assumes each ``TableConfig`` exposes ``primary_key``, ``foreign_keys``
    (with ``parent_table_name``, ``child_table_name``, ``parent_column_names``,
    ``child_column_names``) and ``inequality`` (pairs of column-name lists) —
    matching how this script uses them; confirm against the irg package.
    """
    for name, cfg in tables.items():
        df = pd.read_csv(os.path.join(path, f"{name}.csv"))
        # Primary key: no duplicated key values allowed.
        if cfg.primary_key is not None:
            if df[cfg.primary_key].duplicated().any():
                raise ValueError(f"Primary key constraint {cfg.primary_key} on {name} is not fulfilled for {descr}.")
        for fk in cfg.foreign_keys:
            parent_df = pd.read_csv(os.path.join(path, f"{fk.parent_table_name}.csv"))
            link = f"{fk.child_table_name}{fk.child_column_names} -> {fk.parent_table_name}{fk.parent_column_names}"
            # Referenced columns must be unique on the parent side.
            if parent_df[fk.parent_column_names].duplicated().any():
                raise ValueError(f"Foreign key {link} uniqueness on parent is not fulfilled for {descr}.")
            # Outer merge with indicator: a "right_only" row is a non-null child
            # key with no matching parent row, i.e. a dangling reference.
            merged = parent_df.merge(
                df[fk.child_column_names].dropna(),
                left_on=fk.parent_column_names,
                right_on=fk.child_column_names,
                how="outer",
                indicator="_merged",
            )
            if merged["_merged"].eq("right_only").any():
                raise ValueError(f"Foreign key {link} validity is not fulfilled for {descr}.")
        # Inequality: the two column groups must not be entirely equal in any row.
        for a, b in cfg.inequality:
            aligned = df[b].rename(columns=dict(zip(b, a)))
            if (df[a] == aligned).all(axis=1).any():
                raise ValueError(f"Inequality [{a}, {b}] on {name} is not fulfilled for {descr}.")
def main():
    """Run the IRG pipeline: optionally train a synthesizer, optionally generate.

    With ``train`` in --actions the output directory is wiped, the input data is
    validated, a new ``IncrementalRelationalGenerator`` is fitted and saved;
    otherwise a previously saved synthesizer and its timing file are loaded from
    the output directory. With ``gen`` in --actions, synthetic tables are
    generated and validated. Timings are recorded in ``timing.json``.
    """
    # Disable wandb and silence warnings; CUDA settings aid debugging and
    # reduce allocator fragmentation.
    os.environ.update({
        "WANDB_DISABLED": "true",
        "PYTHONWARNINGS": "ignore",
        "CUDA_LAUNCH_BLOCKING": "1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    })
    # Sanity check on system memory accounting before doing heavy work.
    vm = psutil.virtual_memory()
    if vm.available + vm.used < 0.95 * vm.total:
        raise RuntimeError(f"Memory not available: {vm.available:,}, {vm.free:,}, {vm.used:,}, {vm.total:,}")
    warnings.filterwarnings("ignore")

    args = parse_args()
    with open(args.config, "r") as cfg_file:
        raw_config = yaml.safe_load(cfg_file)
    # Build one TableConfig per table, injecting the table name into its spec.
    table_configs = {
        name: TableConfig.from_dict(attrs | {"name": name})
        for name, attrs in raw_config["tables"].items()
    }
    timing_file = os.path.join(args.output_path, "timing.json")
    model_file = os.path.join(args.output_path, "synthesizer.pt")

    if "train" in args.actions:
        raw_config["tables"] = table_configs
        # Start from a clean output directory.
        if os.path.exists(args.output_path):
            shutil.rmtree(args.output_path)
        _validate_data_config(args.input_data_dir, table_configs, "real")
        fit_start = time.time()
        synthesizer = IncrementalRelationalGenerator(**raw_config)
        csv_paths = {
            name: os.path.join(args.input_data_dir, f"{name}.csv") for name in table_configs
        }
        synthesizer.fit(csv_paths, args.output_path)
        times = {"fit": time.time() - fit_start}
        with open(timing_file, "w") as out:
            json.dump(times, out, indent=2)
        torch.save(synthesizer, model_file)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints produced by a trusted run of this script.
        synthesizer = torch.load(model_file)
        with open(timing_file, "r") as inp:
            times = json.load(inp)

    if "gen" in args.actions:
        gen_start = time.time()
        synthesizer.generate(os.path.join(args.output_path, "generated"), os.path.join(args.output_path, "model"))
        times["sample"] = time.time() - gen_start
        with open(timing_file, "w") as out:
            json.dump(times, out, indent=2)
        # Generated data must satisfy the same relational constraints as the real data.
        _validate_data_config(os.path.join(args.output_path, "generated"), table_configs, "synthetic")
# Script entry point.
if __name__ == "__main__":
    main()