Spaces:
Running
Running
| import argparse | |
| import os.path | |
| import pandas as pd | |
| import torch | |
| from tabtreeformer import TabTreeFormer | |
def parse_args() -> argparse.Namespace:
    """Parse the command line for the ``train`` and ``sample`` subcommands.

    ``train`` takes a data CSV (``--data-path``), a target column
    (``--target``), a task type (``--ttype``: ``bin``, ``mult`` or ``reg``)
    and an output directory (``--out``).
    ``sample`` takes a checkpoint directory (``--ckpt-path``, i.e. the output
    directory used during training), a row count (``--n-rows``) and an output
    CSV path (``--out``).

    Returns:
        The parsed namespace. ``op`` holds the chosen subcommand name; the
        other attributes depend on which subcommand was used.

    Exits via ``SystemExit`` (argparse usage error, status 2) when no
    subcommand is given or an argument is missing/invalid.
    """
    parser = argparse.ArgumentParser()
    # required=True: otherwise an empty command line parses successfully with
    # op=None, and the caller can only fail later with an unhelpful error.
    subparsers = parser.add_subparsers(dest="op", required=True)

    train_parser = subparsers.add_parser("train", help="Train a model on a CSV file.")
    train_parser.add_argument("--data-path", "-d", required=True,
                              help="Path to data (.csv file).")
    train_parser.add_argument("--target", "-t", required=True,
                              help="Target column name.")
    train_parser.add_argument("--ttype", "-p", required=True, choices=["bin", "mult", "reg"],
                              help="Task type.")
    train_parser.add_argument("--out", "-o", required=True,
                              help="Path to output directory.")

    sample_parser = subparsers.add_parser("sample", help="Sample synthetic rows from a trained checkpoint.")
    sample_parser.add_argument("--ckpt-path", "-c", required=True,
                               help="Path to checkpoint directory (output directory during training).")
    sample_parser.add_argument("--n-rows", "-n", type=int, required=True,
                               help="Number of rows to sample.")
    sample_parser.add_argument("--out", "-o", required=True,
                               help="Path to output synthetic data (.csv file).")

    return parser.parse_args()
def main() -> None:
    """Run the CLI: dispatch to the ``train`` or ``sample`` subcommand.

    ``train``: read the CSV at ``--data-path``, call ``TabTreeFormer.train``
    with the target column, task type and output directory, then save the
    whole model object to ``<out>/ttf.pkl``.

    ``sample``: load ``<ckpt-path>/ttf.pkl``, call ``sample(n_rows)`` on it
    and write the result to the CSV at ``--out`` without the index column.

    Raises:
        ValueError: if ``args.op`` is neither ``"train"`` nor ``"sample"``.
    """
    args = parse_args()
    if args.op == "train":
        data = pd.read_csv(args.data_path)
        ttf = TabTreeFormer()
        ttf.train(data, args.target, args.ttype, args.out)
        # torch.save does not create missing directories; exist_ok makes this a
        # no-op when train() has already created args.out.
        os.makedirs(args.out, exist_ok=True)
        torch.save(ttf, os.path.join(args.out, "ttf.pkl"))
    elif args.op == "sample":
        # The checkpoint is a fully pickled TabTreeFormer object (saved above),
        # not a state_dict, so weights_only must be False: since PyTorch 2.6 it
        # defaults to True and this load would be rejected.
        # SECURITY: this unpickles arbitrary objects -- only load trusted checkpoints.
        ttf: TabTreeFormer = torch.load(
            os.path.join(args.ckpt_path, "ttf.pkl"), weights_only=False
        )
        sampled = ttf.sample(args.n_rows)
        sampled.to_csv(args.out, index=False)
    else:
        raise ValueError("Invalid op.")
# Only run the CLI when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()