Add tabgan/sklearn_transformer.py
Browse files- tabgan/sklearn_transformer.py +149 -0
tabgan/sklearn_transformer.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
sklearn-compatible transformer for TabGAN data augmentation.
|
| 4 |
+
|
| 5 |
+
Allows inserting synthetic data generation into a ``sklearn.pipeline.Pipeline``.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Optional, Type
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from sklearn.base import BaseEstimator, TransformerMixin
|
| 14 |
+
|
| 15 |
+
from tabgan.sampler import GANGenerator
|
| 16 |
+
|
| 17 |
+
__all__ = ["TabGANTransformer"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TabGANTransformer(BaseEstimator, TransformerMixin):
    """Augment training data with TabGAN synthetic rows inside an sklearn Pipeline.

    During ``fit`` the generator is trained and synthetic data produced.
    ``transform`` returns the augmented DataFrame (original + synthetic).

    Because sklearn's ``transform`` only returns X, the augmented target
    is available via :meth:`get_augmented_target` after ``fit_transform``.

    Args:
        generator_class: A TabGAN generator class (e.g. ``GANGenerator``).
            Defaults to ``GANGenerator`` when None.
        gen_x_times: Multiplier for synthetic sample count.
        cat_cols: Categorical column names.
        gen_params: Generator-specific hyperparameters.
        only_generated_data: If True, return only synthetic rows.
        constraints: Optional list of ``Constraint`` instances.
        use_adversarial: Whether to use adversarial filtering.
        **generator_kwargs: Extra keyword arguments forwarded to the generator.

    Note:
        NOTE(review): accepting ``**generator_kwargs`` in ``__init__`` breaks
        ``sklearn.base.clone`` / ``get_params`` introspection (BaseEstimator
        rejects VAR_KEYWORD parameters). Prefer passing extras via
        ``gen_params`` when cloning/grid-search support is needed.

    Example::

        from sklearn.pipeline import Pipeline
        from sklearn.ensemble import RandomForestClassifier
        from tabgan.sklearn_transformer import TabGANTransformer

        pipe = Pipeline([
            ("augment", TabGANTransformer(gen_x_times=1.5)),
            ("model", RandomForestClassifier()),
        ])
        pipe.fit(X_train, y_train)
    """

    def __init__(
        self,
        generator_class: Type = None,
        gen_x_times: float = 1.1,
        cat_cols: Optional[List[str]] = None,
        gen_params: Optional[dict] = None,
        only_generated_data: bool = False,
        constraints: Optional[list] = None,
        use_adversarial: bool = True,
        **generator_kwargs,
    ):
        # Per sklearn convention, __init__ only stores its arguments verbatim.
        self.generator_class = generator_class
        self.gen_x_times = gen_x_times
        self.cat_cols = cat_cols
        self.gen_params = gen_params
        self.only_generated_data = only_generated_data
        self.constraints = constraints
        self.use_adversarial = use_adversarial
        self.generator_kwargs = generator_kwargs

        # Internal state (set after fit)
        self._augmented_X: Optional[pd.DataFrame] = None
        self._augmented_y: Optional[pd.Series] = None

    @staticmethod
    def _coerce_features(X) -> pd.DataFrame:
        """Return a defensive DataFrame copy of *X* (DataFrame or array-like)."""
        if isinstance(X, pd.DataFrame):
            return X.copy()
        return pd.DataFrame(X).copy()

    @staticmethod
    def _coerce_target(y) -> Optional[pd.DataFrame]:
        """Return *y* as a DataFrame copy, or None when no target is given."""
        if y is None:
            return None
        if isinstance(y, pd.DataFrame):
            return y.copy()
        if isinstance(y, pd.Series):
            return y.to_frame().copy()
        # Assumes a 1-D array-like target here — TODO confirm for 2-D y.
        return pd.DataFrame(y, columns=["target"])

    @staticmethod
    def _extract_target_series(new_target) -> Optional[pd.Series]:
        """Normalize the generator's target output to a Series (or None).

        Bug fix: the previous code evaluated ``not new_target.isna().all()``
        *before* reducing a DataFrame to a Series. For a multi-column
        DataFrame ``isna().all()`` is a per-column Series whose truth value
        raises ``ValueError``; for a one-column DataFrame it relies on
        implicit ``bool()`` of a length-1 Series, which modern pandas
        rejects. Reduce to the first column first, then test NaN-ness.
        """
        if new_target is None:
            return None
        series = (
            new_target.iloc[:, 0]
            if isinstance(new_target, pd.DataFrame)
            else new_target
        )
        if series.isna().all():
            return None
        return series

    def fit(self, X, y=None):
        """Train the generator and produce synthetic data.

        Args:
            X: Training features (DataFrame or ndarray).
            y: Target variable (Series, DataFrame, ndarray, or None).

        Returns:
            self, per sklearn convention.
        """
        gen_cls = self.generator_class or GANGenerator

        X_df = self._coerce_features(X)
        target_df = self._coerce_target(y)

        gen_kwargs = dict(
            gen_x_times=self.gen_x_times,
            cat_cols=self.cat_cols,
            only_generated_data=self.only_generated_data,
        )
        if self.gen_params is not None:
            gen_kwargs["gen_params"] = self.gen_params
        # User-supplied extras deliberately override the explicit kwargs above.
        gen_kwargs.update(self.generator_kwargs)

        generator = gen_cls(**gen_kwargs)

        new_train, new_target = generator.generate_data_pipe(
            X_df,
            target_df,
            X_df,  # use train as test for distribution alignment
            use_adversarial=self.use_adversarial,
            constraints=self.constraints,
        )

        self._augmented_X = new_train
        self._augmented_y = self._extract_target_series(new_target)

        return self

    def transform(self, X, y=None):
        """Return the augmented training data.

        During training (when ``_augmented_X`` is available), returns the
        augmented data. At inference time, returns X unchanged.
        """
        if self._augmented_X is not None:
            result = self._augmented_X
            # Clear after first transform to avoid leaking into predict
            self._augmented_X = None
            return result
        return X

    def fit_transform(self, X, y=None, **fit_params):
        """Fit and return augmented data in one step."""
        self.fit(X, y)
        return self.transform(X, y)

    def get_augmented_target(self) -> Optional[pd.Series]:
        """Return the augmented target produced during ``fit``.

        Call this after ``fit`` or ``fit_transform`` to get the target
        values corresponding to the augmented training data.
        """
        return self._augmented_y