import argparse
import os
import re
import shutil
import time
from collections import defaultdict

import numpy as np
import pandas as pd
# json_tricks stands in for the stdlib json so that numpy types stored in the
# dumped metadata (e.g. the "nullable" flags computed below) serialize cleanly.
import json_tricks as json
from preprocess_utils import topological_sort
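
# Expected SDV-style schema shape, as read by this script (keys inferred from
# the accesses below; any other fields in the schema file are ignored):
#   {
#     "tables": {
#       "<table>": {
#         "primary_key": "<column>",              # optional
#         "columns": {"<column>": {"sdtype": "id" | "categorical" | ...}}
#       }
#     },
#     "relationships": [
#       {"parent_table_name": ..., "child_table_name": ..., "child_foreign_key": ...}
#     ]
#   }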


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--sdv-schema", "-s", type=str, required=True, help="SDV schema file")
    parser.add_argument("--output-dir", "-o", type=str, required=True, help="Output directory")
    subparsers = parser.add_subparsers(dest="op")
    pre_parser = subparsers.add_parser("pre")
    pre_parser.add_argument("--dataset-dir", "-d", type=str, required=True, help="Directory containing datasets")
    pre_parser.add_argument("--dataset-name", "-n", type=str, required=True, help="Dataset name")
    pre_parser.add_argument("--fast", "-f", default=False, action="store_true", help="Fast experiment mode")
    subparsers.add_parser("post")
    return parser.parse_args()
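
# Example invocations (paths are illustrative; the flags mirror parse_args above):
#   python process.py -s sdv_schema.json -o output pre -d datasets -n movie_lens --fast
#   python process.py -s sdv_schema.json -o output post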


def main():
    args = parse_args()
    if args.op == "pre":
        preprocess(args)
    elif args.op == "post":
        postprocess(args)
    else:
        raise NotImplementedError(f"Unknown op: {args.op}")


def preprocess(args):
    with open(args.sdv_schema, "r") as f:
        schema = json.load(f)
    start_time = time.time()
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "data"), exist_ok=True)
    # Count foreign keys per (parent, child) pair and group each child's FKs.
    n_copies = defaultdict(int)
    parents_fk = defaultdict(list)
    for fk in schema["relationships"]:
        n_copies[(fk["parent_table_name"], fk["child_table_name"])] += 1
        parents_fk[fk["child_table_name"]].append(fk)
    # A parent needs as many copies as the largest number of FKs any single child has to it.
    sum_n_copies = defaultdict(int)
    for (p, c), v in n_copies.items():
        sum_n_copies[p] = max(sum_n_copies[p], v)
    # First pass: plan column renames and detect nullable foreign keys.
    need_null_pk = defaultdict(list)
    renames = {}
    for table_name, table_args in schema["tables"].items():
        table = pd.read_csv(os.path.join(args.dataset_dir, f"{table_name}.csv"))
        # Rename "*_id" columns that are neither FKs nor the PK so they cannot
        # collide with the "{table}_id" key columns generated below.
        table_renames = {}
        for c in table.columns:
            if (c.endswith("_id") and c not in [fk["child_foreign_key"] for fk in parents_fk[table_name]]
                    and c != table_args.get("primary_key")):
                table_renames[c] = c.replace("_id", "_iidd")
        renames[table_name] = table_renames
        for fk in parents_fk[table_name]:
            fk["nullable"] = table[fk["child_foreign_key"]].isna().any()
            if fk["nullable"]:
                # Record group-size quantiles and the NULL count; these bound how
                # many placeholder parent rows to fabricate in the second pass.
                sizes = table[fk["child_foreign_key"]].value_counts(dropna=True)
                need_null_pk[fk["parent_table_name"]].append((
                    sizes.quantile(0.05), sizes.quantile(0.95), table[fk["child_foreign_key"]].isna().sum()
                ))
    # Second pass: build per-table domains and rewrite the tables themselves.
    parents = defaultdict(list)
    domains = {}
    n_na_all = {}
    for table_name, table_args in schema["tables"].items():
        domains[table_name] = {}
        table = pd.read_csv(os.path.join(args.dataset_dir, f"{table_name}.csv"))
        table_renames = renames[table_name]
        table = table.rename(columns=table_renames)
        for c, a in table_args["columns"].items():
            if a["sdtype"] == "id":
                continue
            domains[table_name][table_renames.get(c, c)] = {
                "size": len(table[table_renames.get(c, c)].unique()),
                "type": "discrete" if a["sdtype"] == "categorical" else "continuous",
            }
        primary_key = table_args.get("primary_key")
        if table_name in need_null_pk:
            # Infer a feasible number of placeholder rows standing in for NULL FKs:
            # each nullable child bounds it to [na_size / max_size, na_size / min_size].
            min_n_na = 0
            max_n_na = np.inf
            for min_size, max_size, na_size in need_null_pk[table_name]:
                min_n_na = max(min_n_na, na_size / max_size)
                max_n_na = min(max_n_na, na_size / min_size)
            if max_n_na < min_n_na:
                raise RuntimeError("Number of NULLs cannot be inferred.")
            n_na = np.random.randint(np.floor(min_n_na), np.ceil(max_n_na) + 1)
            n_na_all[table_name] = n_na
            # Fabricate placeholder rows by sampling each column independently and
            # give them synthetic "NULL-KEY-*" primary keys.
            sampled_df = pd.DataFrame({
                c: table[c].sample(n=n_na).values for c in table.columns
            })
            sampled_df[primary_key] = np.char.add("NULL-KEY-", np.arange(n_na).astype(str))
            table["_isna_key"] = "notna"
            sampled_df["_isna_key"] = "isna"
            table = pd.concat([table, sampled_df[table.columns]], axis=0, ignore_index=True)
            domains[table_name]["_isna_key"] = {"size": 2, "type": "discrete"}
        index = defaultdict(int)
        for fk in parents_fk[table_name]:
            # The i-th FK into the same parent is redirected to parent copy "{parent}{i}".
            current_index = index[(fk["parent_table_name"], fk["child_table_name"])] + 1
            parent_name = fk["parent_table_name"]
            if current_index > 1:
                parent_name = f"{parent_name}{current_index}"
            if fk["nullable"]:
                # Point NULL FKs at the placeholder parent rows written earlier; this
                # relies on parents appearing before their children in schema["tables"].
                isna = table[fk["child_foreign_key"]].isna()
                keys = pd.read_csv(os.path.join(args.output_dir, "data", f"{fk['parent_table_name']}.csv"))
                na_keys = keys[f"{fk['parent_table_name']}_id"][keys["_isna_key"] == "isna"]
                sampled_na_keys = na_keys.sample(n=isna.sum(), replace=True)
                table.loc[isna, fk["child_foreign_key"]] = sampled_na_keys.values
            table = table.rename(columns={fk["child_foreign_key"]: f"{parent_name}_id"})
            index[(fk["parent_table_name"], fk["child_table_name"])] += 1
            parents[table_name].append(parent_name)
        if primary_key is not None:
            if sum_n_copies[table_name] > 1:
                for i in range(sum_n_copies[table_name]):
                    if i == 0:
                        new_table = table.rename(columns={primary_key: f"{table_name}_id"})
                        new_table.to_csv(os.path.join(args.output_dir, "data", f"{table_name}.csv"), index=False)
                    else:
                        # Copy i is a pure key-mapping table from the base PK to itself.
                        new_table = pd.DataFrame({
                            f"{table_name}_id": table[primary_key],
                            f"{table_name}{i + 1}_id": table[primary_key],
                        })
                        new_table.to_csv(os.path.join(args.output_dir, "data", f"{table_name}{i + 1}.csv"), index=False)
                        domains[f"{table_name}{i + 1}"] = {}
                        parents[f"{table_name}{i + 1}"].append(table_name)
            else:
                table = table.rename(columns={primary_key: f"{table_name}_id"})
                table.to_csv(os.path.join(args.output_dir, "data", f"{table_name}.csv"), index=False)
        else:
            # No primary key: synthesize one from the row number (mirrors the branch above).
            for i in range(max(1, sum_n_copies[table_name])):
                if i == 0:
                    new_table = table.copy()
                    new_table[f"{table_name}_id"] = np.arange(new_table.shape[0])
                    new_table.to_csv(os.path.join(args.output_dir, "data", f"{table_name}.csv"), index=False)
                else:
                    new_table = pd.DataFrame({
                        f"{table_name}_id": np.arange(table.shape[0]),
                        f"{table_name}{i + 1}_id": np.arange(table.shape[0]),
                    })
                    new_table.to_csv(os.path.join(args.output_dir, "data", f"{table_name}{i + 1}.csv"), index=False)
                    domains[f"{table_name}{i + 1}"] = {}
                    parents[f"{table_name}{i + 1}"].append(table_name)
    for table_name, table_domain in domains.items():
        with open(os.path.join(args.output_dir, "data", f"{table_name}_domain.json"), "w") as f:
            json.dump(table_domain, f, indent=2)
    children = defaultdict(list)
    for table_name, table_parents in parents.items():
        for parent in table_parents:
            children[parent].append(table_name)
    # Remap non-numeric primary keys to integer row indices, both in the parent
    # and in every child table that references it.
    for table_name, child_tables in children.items():
        table = pd.read_csv(os.path.join(args.output_dir, "data", f"{table_name}.csv"))
        if not pd.api.types.is_numeric_dtype(table[f"{table_name}_id"].dtype):
            mapper = table[f"{table_name}_id"].reset_index().set_index(f"{table_name}_id")["index"].to_dict()
            table[f"{table_name}_id"] = table[f"{table_name}_id"].map(mapper)
            table.to_csv(os.path.join(args.output_dir, "data", f"{table_name}.csv"), index=False)
            for child in child_tables:
                child_table = pd.read_csv(os.path.join(args.output_dir, "data", f"{child}.csv"))
                child_table[f"{table_name}_id"] = child_table[f"{table_name}_id"].map(mapper)
                child_table.to_csv(os.path.join(args.output_dir, "data", f"{child}.csv"), index=False)
    all_tables = []
    for table_name in schema["tables"]:
        all_tables.append(table_name)
        if sum_n_copies[table_name] > 1:
            for i in range(1, sum_n_copies[table_name]):
                all_tables.append(f"{table_name}{i + 1}")
    meta = {"tables": {table_name: {"parents": [], "children": []} for table_name in all_tables}}
    for table_name, table_parents in parents.items():
        meta["tables"][table_name]["parents"] = table_parents
    for table_name, table_children in children.items():
        meta["tables"][table_name]["children"] = table_children
    sorted_order = topological_sort(meta["tables"])
    meta["relation_order"] = sorted_order
    with open(os.path.join(args.output_dir, "data", "dataset_meta.json"), "w") as f:
        json.dump(meta, f, indent=2)
    # The movie_lens config shipped next to this script serves as the template.
    with open(os.path.join(os.path.dirname(__file__), "configs", "movie_lens.json"), "r") as f:
        configs = json.load(f)
    configs["general"]["data_dir"] = os.path.join(args.output_dir, "data")
    configs["general"]["exp_name"] = args.dataset_name
    configs["general"]["workspace_dir"] = os.path.join(args.output_dir, "workspace")
    if args.fast:
        configs["diffusion"]["iterations"] = 200
        configs["classifier"]["iterations"] = 200
        configs["diffusion"]["num_timesteps"] = 20
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(configs, f, indent=2)
    # Persist everything postprocess() needs to undo the preprocessing.
    with open(os.path.join(args.output_dir, "process-config.json"), "w") as f:
        json.dump({
            "parent_fks": parents_fk,
            "has_null_pk": [*need_null_pk],
            "sum_n_copies": sum_n_copies,
            "renames": renames
        }, f, indent=2)
    end_time = time.time()
    with open(os.path.join(args.output_dir, "timing.json"), "w") as f:
        json.dump({"preprocess": end_time - start_time}, f, indent=2)


def postprocess(args):
    with open(args.sdv_schema, "r") as f:
        schema = json.load(f)
    start_time = time.time()
    with open(os.path.join(args.output_dir, "data", "dataset_meta.json"), "r") as f:
        dataset_meta = json.load(f)
    relation_order = dataset_meta["relation_order"]
    os.makedirs(os.path.join(args.output_dir, "generated"), exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "intermediate-generated"), exist_ok=True)
    for table_name in dataset_meta["tables"]:
        shutil.copyfile(
            os.path.join(args.output_dir, "workspace", table_name, "_final", f"{table_name}_synthetic.csv"),
            os.path.join(args.output_dir, "intermediate-generated", f"{table_name}.csv"),
        )
    # Fold numbered parent copies (e.g. "movies2") back into references to the base table.
    for parent, child in relation_order:
        if parent is not None and re.fullmatch(r".*\d+", parent):
            base_parent = parent
            while re.fullmatch(r".*\d+", base_parent):
                base_parent = base_parent[:-1]
            child_table = pd.read_csv(os.path.join(args.output_dir, "intermediate-generated", f"{child}.csv"))
            parent_table = pd.read_csv(os.path.join(args.output_dir, "intermediate-generated", f"{parent}.csv"))
            base_parent_table = pd.read_csv(
                os.path.join(args.output_dir, "intermediate-generated", f"{base_parent}.csv")
            )
            # Translate the copy's key into the base key via the copy's mapping table.
            child_table = child_table.merge(
                parent_table.rename(columns={f"{base_parent}_id": f"___{base_parent}_id"}),
                on=f"{parent}_id", how="left"
            )
            # Join against the base table's keys, then expose the base key under the copy's column name.
            child_table = child_table.merge(
                base_parent_table[[f"{base_parent}_id"]].rename(columns={f"{base_parent}_id": f"___{base_parent}_id"}),
                on=f"___{base_parent}_id", how="left"
            ).drop(columns=[f"{parent}_id"]).rename(columns={f"___{base_parent}_id": f"{parent}_id"})
            child_table.to_csv(os.path.join(args.output_dir, "intermediate-generated", f"{child}.csv"), index=False)
    with open(os.path.join(args.output_dir, "process-config.json"), "r") as f:
        loaded = json.load(f)
    parents_fk = loaded["parent_fks"]
    has_null_pk = loaded["has_null_pk"]
    renames = loaded["renames"]
    for table_name, table_args in schema["tables"].items():
        table = pd.read_csv(os.path.join(args.output_dir, "intermediate-generated", f"{table_name}.csv"))
        primary_key = table_args.get("primary_key")
        if primary_key is not None:
            table = table.rename(columns={f"{table_name}_id": primary_key})
        if table_name in has_null_pk:
            # Drop the placeholder rows that stood in for NULL foreign keys.
            table = table[table["_isna_key"] == "notna"].reset_index(drop=True).drop(columns=["_isna_key"])
        index = defaultdict(int)
        for fk in parents_fk[table_name]:
            current_index = index[(fk["parent_table_name"], fk["child_table_name"])] + 1
            parent_name = fk["parent_table_name"]
            if current_index > 1:
                parent_name = f"{parent_name}{current_index}"
            table = table.rename(columns={f"{parent_name}_id": fk["child_foreign_key"]})
            if fk["nullable"]:
                # Turn references to placeholder parents back into real NULLs.
                parent_na_key = pd.read_csv(
                    os.path.join(args.output_dir, "intermediate-generated", f"{fk['parent_table_name']}.csv")
                )
                parent_na_key = parent_na_key[parent_na_key["_isna_key"] == "isna"][f"{fk['parent_table_name']}_id"]
                table[fk["child_foreign_key"]] = table[fk["child_foreign_key"]].replace(parent_na_key.tolist(), np.nan)
            index[(fk["parent_table_name"], fk["child_table_name"])] += 1
        # Undo the "_id" -> "_iidd" renames from preprocessing.
        table = table.rename(columns={v: k for k, v in renames[table_name].items()})
        table.to_csv(os.path.join(args.output_dir, "generated", f"{table_name}.csv"), index=False)
    end_time = time.time()
    with open(os.path.join(args.output_dir, "timing.json"), "r") as f:
        timing = json.load(f)
    timing["postprocess"] = end_time - start_time
    with open(os.path.join(args.output_dir, "timing.json"), "w") as f:
        json.dump(timing, f, indent=2)


if __name__ == "__main__":
    main()