"""Command-line entry point for training a TabTreeFormer and sampling from it.

Two subcommands are provided:

* ``train``  -- read a CSV file, train a ``TabTreeFormer`` on it, and save the
  trained object as ``ttf.pkl`` inside the output directory.
* ``sample`` -- load ``ttf.pkl`` from a checkpoint directory, sample the
  requested number of rows, and write them to a CSV file.
"""

import argparse
import os.path
from typing import Optional, Sequence

import pandas as pd
import torch

from tabtreeformer import TabTreeFormer


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of the plain script.

    Returns:
        The parsed namespace. ``op`` holds the chosen subcommand (``"train"``
        or ``"sample"``); the remaining attributes depend on the subcommand.
    """
    parser = argparse.ArgumentParser()
    # Make the subcommand mandatory so that a missing one is reported by
    # argparse as a usage error instead of surfacing later as ``op is None``.
    subparsers = parser.add_subparsers(dest="op", required=True)

    train_parser = subparsers.add_parser("train")
    train_parser.add_argument(
        "--data-path", "-d", type=str, required=True,
        help="Path to data (.csv file).",
    )
    train_parser.add_argument(
        "--target", "-t", type=str, required=True,
        help="Target column name.",
    )
    train_parser.add_argument(
        "--ttype", "-p", type=str, required=True,
        choices=["bin", "mult", "reg"],
        help="Task type.",
    )
    train_parser.add_argument(
        "--out", "-o", type=str, required=True,
        help="Path to output directory.",
    )

    sample_parser = subparsers.add_parser("sample")
    sample_parser.add_argument(
        "--ckpt-path", "-c", type=str, required=True,
        help="Path to checkpoint directory (output directory during training).",
    )
    sample_parser.add_argument(
        "--n-rows", "-n", type=int, required=True,
        help="Number of rows to sample.",
    )
    sample_parser.add_argument(
        "--out", "-o", type=str, required=True,
        help="Path to output synthetic data (.csv file).",
    )

    return parser.parse_args(argv)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run the ``train`` or ``sample`` subcommand.

    Args:
        argv: Argument list forwarded to :func:`parse_args`. ``None`` (the
            default) uses the process command line.

    Raises:
        ValueError: If the parsed operation is neither ``"train"`` nor
            ``"sample"``.
    """
    args = parse_args(argv)
    if args.op == "train":
        data = pd.read_csv(args.data_path)
        ttf = TabTreeFormer()
        ttf.train(data, args.target, args.ttype, args.out)
        # The checkpoint is written into the output directory; make sure the
        # directory exists in case training did not create it.
        os.makedirs(args.out, exist_ok=True)
        torch.save(ttf, os.path.join(args.out, "ttf.pkl"))
    elif args.op == "sample":
        # The checkpoint is a fully pickled TabTreeFormer object, not a plain
        # state dict, so it cannot be read with ``weights_only=True`` (the
        # default since PyTorch 2.6).
        # SECURITY: unpickling can execute arbitrary code -- only load
        # checkpoints that come from a trusted source.
        ttf: TabTreeFormer = torch.load(
            os.path.join(args.ckpt_path, "ttf.pkl"),
            weights_only=False,
        )
        sampled = ttf.sample(args.n_rows)
        sampled.to_csv(args.out, index=False)
    else:
        # Defensive: argparse already rejects unknown or missing subcommands.
        raise ValueError("Invalid op.")


if __name__ == "__main__":
    main()