File size: 3,826 Bytes
89f29f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import argparse
import logging
from typing import List, Optional

import pandas as pd

from tabgan.sampler import (
    OriginalGenerator,
    GANGenerator,
    ForestDiffusionGenerator,
    LLMGenerator,
)


def _parse_cat_cols(raw: Optional[str]) -> Optional[List[str]]:
    if not raw:
        return None
    return [c.strip() for c in raw.split(",") if c.strip()]


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the synthetic-data generation CLI."""
    parser = argparse.ArgumentParser(
        description="Generate synthetic tabular data using tabgan samplers."
    )
    parser.add_argument(
        "--input-csv",
        required=True,
        help="Path to input CSV file containing training data (with or without target column).",
    )
    parser.add_argument(
        "--target-col",
        default=None,
        help="Name of the target column in the CSV (optional).",
    )
    parser.add_argument(
        "--output-csv",
        required=True,
        help="Path to write the generated synthetic dataset as CSV.",
    )
    parser.add_argument(
        "--generator",
        choices=["original", "gan", "diffusion", "llm"],
        default="gan",
        help="Which sampler to use for generation.",
    )
    parser.add_argument(
        "--gen-x-times",
        type=float,
        default=1.1,
        help="Factor controlling how many synthetic samples to generate relative to the training size.",
    )
    parser.add_argument(
        "--cat-cols",
        default=None,
        help="Comma-separated list of categorical column names (e.g. 'year,gender').",
    )
    parser.add_argument(
        "--only-generated",
        action="store_true",
        help="If set, output only synthetic rows instead of original + synthetic.",
    )
    return parser


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line interface for generating synthetic tabular data with tabgan.

    Reads the input CSV, optionally splits off the target column, runs the
    selected generator's ``generate_data_pipe`` and writes the result to the
    output CSV (without the index).

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read the process command line, so existing ``main()`` callers
            and console-script entry points are unaffected.

    Raises:
        ValueError: If ``--target-col`` names a column that is not present
            in the input CSV.
        SystemExit: Raised by argparse on ``--help`` or invalid arguments.

    Example:
        tabgan-generate \\
            --input-csv train.csv \\
            --target-col target \\
            --generator gan \\
            --gen-x-times 1.5 \\
            --cat-cols year,gender \\
            --output-csv synthetic_train.csv
    """
    args = _build_arg_parser().parse_args(argv)

    logging.basicConfig(level=logging.INFO)

    logging.info("Reading input CSV from %s", args.input_csv)
    df = pd.read_csv(args.input_csv)

    target_df = None
    train_df = df
    if args.target_col is not None:
        if args.target_col not in df.columns:
            raise ValueError(f"Target column '{args.target_col}' not found in input CSV.")
        # Keep the target as a single-column DataFrame, features as the rest.
        target_df = df[[args.target_col]]
        train_df = df.drop(columns=[args.target_col])

    cat_cols = _parse_cat_cols(args.cat_cols)

    generator_map = {
        "original": OriginalGenerator,
        "gan": GANGenerator,
        "diffusion": ForestDiffusionGenerator,
        "llm": LLMGenerator,
    }
    generator_cls = generator_map[args.generator]

    logging.info("Initializing %s generator", generator_cls.__name__)
    generator = generator_cls(
        gen_x_times=args.gen_x_times,
        cat_cols=cat_cols,
        only_generated_data=args.only_generated,
    )

    # Use train_df itself as test_df when a dedicated hold-out set is not provided.
    logging.info("Generating synthetic data...")
    new_train, new_target = generator.generate_data_pipe(
        train_df, target_df, train_df
    )

    if new_target is None or args.target_col is None:
        out_df = new_train
    else:
        out_df = new_train.copy()
        # Attach the target positionally. Assigning a Series directly would
        # align on index labels and silently insert NaN wherever the target's
        # index differs from new_train's, so always reduce to raw values first.
        target_values = getattr(new_target, "values", new_target)
        if getattr(target_values, "ndim", 1) > 1:
            # Single-column DataFrame -> flat 1D array.
            target_values = target_values.ravel()
        out_df[args.target_col] = target_values

    logging.info("Writing synthetic data to %s", args.output_csv)
    out_df.to_csv(args.output_csv, index=False)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()