# TabGAN — tabgan/cli.py
# Committed by InsafQ: "Add tabgan/cli.py" (commit 89f29f8, verified).
import argparse
import logging
from typing import List, Optional
import pandas as pd
from tabgan.sampler import (
OriginalGenerator,
GANGenerator,
ForestDiffusionGenerator,
LLMGenerator,
)
def _parse_cat_cols(raw: Optional[str]) -> Optional[List[str]]:
    """Turn the raw ``--cat-cols`` string into a list of column names.

    ``None`` or an empty string yields ``None``. Otherwise the value is
    split on commas, each piece is stripped of surrounding whitespace, and
    empty pieces are discarded. The result keeps the original order and may
    be an empty list (e.g. for ``" , "``).
    """
    if raw:
        return list(filter(None, map(str.strip, raw.split(","))))
    return None
def _positive_float(raw: str) -> float:
    """Argparse ``type=`` converter accepting only finite floats greater than zero."""
    try:
        value = float(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {raw!r}") from None
    # Chained comparison is False for NaN as well as for <= 0 and +inf.
    if not 0 < value < float("inf"):
        raise argparse.ArgumentTypeError(
            f"must be a finite number greater than 0, got {raw!r}"
        )
    return value


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser used by :func:`main`."""
    parser = argparse.ArgumentParser(
        description="Generate synthetic tabular data using tabgan samplers."
    )
    parser.add_argument(
        "--input-csv",
        required=True,
        help="Path to input CSV file containing training data (with or without target column).",
    )
    parser.add_argument(
        "--target-col",
        default=None,
        help="Name of the target column in the CSV (optional).",
    )
    parser.add_argument(
        "--output-csv",
        required=True,
        help="Path to write the generated synthetic dataset as CSV.",
    )
    parser.add_argument(
        "--generator",
        choices=["original", "gan", "diffusion", "llm"],
        default="gan",
        help="Which sampler to use for generation.",
    )
    parser.add_argument(
        "--gen-x-times",
        type=_positive_float,
        default=1.1,
        help="Factor controlling how many synthetic samples to generate relative to the training size.",
    )
    parser.add_argument(
        "--cat-cols",
        default=None,
        help="Comma-separated list of categorical column names (e.g. 'year,gender').",
    )
    parser.add_argument(
        "--only-generated",
        action="store_true",
        help="If set, output only synthetic rows instead of original + synthetic.",
    )
    return parser


def _attach_target(features: pd.DataFrame, target, target_col: str) -> pd.DataFrame:
    """Return a copy of *features* with *target* added as column *target_col*.

    *target* may be a single-column DataFrame, a Series, or an array-like.
    It is assigned by position, not by index label. Label alignment would
    silently insert NaN (or raise on duplicate labels) whenever the two
    indexes differ. This assumes the generator returns features and target
    in matching row order.

    Raises:
        ValueError: If *target* is a DataFrame with more than one column, or
            (from pandas) if its length differs from ``len(features)``.
    """
    out = features.copy()
    if isinstance(target, pd.DataFrame):
        if target.shape[1] != 1:
            raise ValueError(
                f"Expected a single target column, got {target.shape[1]} columns."
            )
        target = target.iloc[:, 0]
    if isinstance(target, pd.Series):
        # Drop the index so pandas assigns positionally.
        target = target.to_numpy()
    out[target_col] = target
    return out


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line interface for generating synthetic tabular data with tabgan.

    Example:
        tabgan-generate \\
            --input-csv train.csv \\
            --target-col target \\
            --generator gan \\
            --gen-x-times 1.5 \\
            --cat-cols year,gender \\
            --output-csv synthetic_train.csv

    Args:
        argv: Argument list to parse instead of the process command line.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--target-col`` or any ``--cat-cols`` entry is not a
            column of the input CSV (the target is excluded from the feature
            columns that ``--cat-cols`` may name).
        SystemExit: On invalid command-line arguments (raised by argparse).
    """
    args = _build_parser().parse_args(argv)
    logging.basicConfig(level=logging.INFO)

    logging.info("Reading input CSV from %s", args.input_csv)
    df = pd.read_csv(args.input_csv)

    # Split the target column off so the generator receives features and target separately.
    target_df = None
    train_df = df
    if args.target_col is not None:
        if args.target_col not in df.columns:
            raise ValueError(f"Target column '{args.target_col}' not found in input CSV.")
        target_df = df[[args.target_col]]
        train_df = df.drop(columns=[args.target_col])

    cat_cols = _parse_cat_cols(args.cat_cols)
    if cat_cols:
        # Fail early with a clear message instead of deep inside the sampler.
        missing = [col for col in cat_cols if col not in train_df.columns]
        if missing:
            raise ValueError(
                f"Categorical column(s) {missing} not found among the feature "
                "columns of the input CSV (the target column is excluded)."
            )

    generator_map = {
        "original": OriginalGenerator,
        "gan": GANGenerator,
        "diffusion": ForestDiffusionGenerator,
        "llm": LLMGenerator,
    }
    generator_cls = generator_map[args.generator]
    logging.info("Initializing %s generator", generator_cls.__name__)
    generator = generator_cls(
        gen_x_times=args.gen_x_times,
        cat_cols=cat_cols,
        only_generated_data=args.only_generated,
    )

    # Use train_df itself as test_df when a dedicated hold-out set is not provided.
    logging.info("Generating synthetic data...")
    new_train, new_target = generator.generate_data_pipe(
        train_df, target_df, train_df
    )

    if new_target is not None and args.target_col is not None:
        out_df = _attach_target(new_train, new_target, args.target_col)
    else:
        out_df = new_train

    logging.info("Writing synthetic data to %s", args.output_csv)
    out_df.to_csv(args.output_csv, index=False)
# Run the CLI when this module is executed directly (e.g. ``python -m tabgan.cli``).
if __name__ == "__main__":
    main()