File size: 1,998 Bytes
b0d7cdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os.path
from typing import Optional, Sequence

import pandas as pd
import torch

from tabtreeformer import TabTreeFormer


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command line for the ``train`` and ``sample`` sub-commands.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) means ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        The parsed namespace. ``op`` holds the chosen sub-command (``"train"``
        or ``"sample"``); the other attributes are that sub-command's options.

    Invalid arguments, including a missing sub-command, make argparse print a
    usage message and raise ``SystemExit`` with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Train TabTreeFormer on tabular data or sample synthetic rows from a trained checkpoint."
    )
    # `required=True` makes argparse reject a missing sub-command with a usage
    # error instead of silently leaving `op` as None.
    subparsers = parser.add_subparsers(dest="op", required=True)

    train_parser = subparsers.add_parser("train", help="Train a model on a .csv file.")
    train_parser.add_argument("--data-path", "-d", type=str, required=True,
                              help="Path to data (.csv file).")
    train_parser.add_argument("--target", "-t", type=str, required=True,
                              help="Target column name.")
    train_parser.add_argument("--ttype", "-p", type=str, required=True, choices=["bin", "mult", "reg"],
                              help="Task type.")
    train_parser.add_argument("--out", "-o", type=str, required=True,
                              help="Path to output directory.")

    sample_parser = subparsers.add_parser("sample", help="Sample synthetic rows from a trained checkpoint.")
    sample_parser.add_argument("--ckpt-path", "-c", type=str, required=True,
                               help="Path to checkpoint directory (output directory during training).")
    sample_parser.add_argument("--n-rows", "-n", type=int, required=True,
                               help="Number of rows to sample.")
    sample_parser.add_argument("--out", "-o", type=str, required=True,
                               help="Path to output synthetic data (.csv file).")
    return parser.parse_args(argv)


def _checkpoint_path(directory: str) -> str:
    """Return the path of the pickled model file inside a training output directory."""
    return os.path.join(directory, "ttf.pkl")


def _load_checkpoint(path: str) -> "TabTreeFormer":
    """Load a model object that was written with ``torch.save`` in ``_train``.

    The checkpoint is a whole pickled Python object, not a state dict, so it
    must be loaded with ``weights_only=False``. Since PyTorch 2.6, ``torch.load``
    defaults to ``weights_only=True``, which refuses arbitrary objects. Older
    PyTorch releases do not have the parameter at all, so it is only passed
    when the installed version supports it.

    SECURITY: unpickling can execute arbitrary code. Only load checkpoints you
    produced yourself or otherwise trust.
    """
    import inspect

    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def _train(args: argparse.Namespace) -> None:
    """Fit a TabTreeFormer on ``args.data_path`` and pickle it into ``args.out``."""
    data = pd.read_csv(args.data_path)
    ttf = TabTreeFormer()
    ttf.train(data, args.target, args.ttype, args.out)
    # Ensure the directory exists before writing the pickle; nothing here
    # guarantees that `TabTreeFormer.train` created it.
    os.makedirs(args.out, exist_ok=True)
    torch.save(ttf, _checkpoint_path(args.out))


def _sample(args: argparse.Namespace) -> None:
    """Load the pickled model from ``args.ckpt_path`` and write sampled rows to ``args.out``."""
    ttf: TabTreeFormer = _load_checkpoint(_checkpoint_path(args.ckpt_path))
    sampled = ttf.sample(args.n_rows)
    sampled.to_csv(args.out, index=False)


def main():
    """Command-line entry point: dispatch to the ``train`` or ``sample`` sub-command.

    ``train`` reads a CSV file, trains a ``TabTreeFormer`` and pickles the whole
    model object to ``<out>/ttf.pkl``. ``sample`` loads ``<ckpt-path>/ttf.pkl``
    and writes ``--n-rows`` sampled rows (no index column) to the ``--out`` CSV.

    Raises:
        ValueError: if ``args.op`` is not a known sub-command.
    """
    args = parse_args()
    if args.op == "train":
        _train(args)
    elif args.op == "sample":
        _sample(args)
    else:
        # Defensive guard: `op` can be None if argparse does not enforce a
        # sub-command.
        raise ValueError("Invalid op.")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()