InsafQ commited on
Commit
0e1204d
·
verified ·
1 Parent(s): 61f0faf

Add tabgan/sklearn_transformer.py

Browse files
Files changed (1) hide show
  1. tabgan/sklearn_transformer.py +149 -0
tabgan/sklearn_transformer.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ sklearn-compatible transformer for TabGAN data augmentation.
4
+
5
+ Allows inserting synthetic data generation into a ``sklearn.pipeline.Pipeline``.
6
+ """
7
+
8
+ import logging
9
+ from typing import List, Optional, Type
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ from sklearn.base import BaseEstimator, TransformerMixin
14
+
15
+ from tabgan.sampler import GANGenerator
16
+
17
+ __all__ = ["TabGANTransformer"]
18
+
19
+
20
+ class TabGANTransformer(BaseEstimator, TransformerMixin):
21
+ """Augment training data with TabGAN synthetic rows inside an sklearn Pipeline.
22
+
23
+ During ``fit`` the generator is trained and synthetic data produced.
24
+ ``transform`` returns the augmented DataFrame (original + synthetic).
25
+
26
+ Because sklearn's ``transform`` only returns X, the augmented target
27
+ is available via :meth:`get_augmented_target` after ``fit_transform``.
28
+
29
+ Args:
30
+ generator_class: A TabGAN generator class (e.g. ``GANGenerator``).
31
+ gen_x_times: Multiplier for synthetic sample count.
32
+ cat_cols: Categorical column names.
33
+ gen_params: Generator-specific hyperparameters.
34
+ only_generated_data: If True, return only synthetic rows.
35
+ constraints: Optional list of ``Constraint`` instances.
36
+ use_adversarial: Whether to use adversarial filtering.
37
+ **generator_kwargs: Extra keyword arguments forwarded to the generator.
38
+
39
+ Example::
40
+
41
+ from sklearn.pipeline import Pipeline
42
+ from sklearn.ensemble import RandomForestClassifier
43
+ from tabgan.sklearn_transformer import TabGANTransformer
44
+
45
+ pipe = Pipeline([
46
+ ("augment", TabGANTransformer(gen_x_times=1.5)),
47
+ ("model", RandomForestClassifier()),
48
+ ])
49
+ pipe.fit(X_train, y_train)
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ generator_class: Type = None,
55
+ gen_x_times: float = 1.1,
56
+ cat_cols: Optional[List[str]] = None,
57
+ gen_params: Optional[dict] = None,
58
+ only_generated_data: bool = False,
59
+ constraints: Optional[list] = None,
60
+ use_adversarial: bool = True,
61
+ **generator_kwargs,
62
+ ):
63
+ self.generator_class = generator_class
64
+ self.gen_x_times = gen_x_times
65
+ self.cat_cols = cat_cols
66
+ self.gen_params = gen_params
67
+ self.only_generated_data = only_generated_data
68
+ self.constraints = constraints
69
+ self.use_adversarial = use_adversarial
70
+ self.generator_kwargs = generator_kwargs
71
+
72
+ # Internal state (set after fit)
73
+ self._augmented_X: Optional[pd.DataFrame] = None
74
+ self._augmented_y: Optional[pd.Series] = None
75
+
76
+ def fit(self, X, y=None):
77
+ """Train the generator and produce synthetic data.
78
+
79
+ Args:
80
+ X: Training features (DataFrame or ndarray).
81
+ y: Target variable (Series, DataFrame, ndarray, or None).
82
+ """
83
+ gen_cls = self.generator_class or GANGenerator
84
+
85
+ X_df = pd.DataFrame(X).copy() if not isinstance(X, pd.DataFrame) else X.copy()
86
+
87
+ target_df = None
88
+ if y is not None:
89
+ if isinstance(y, pd.DataFrame):
90
+ target_df = y.copy()
91
+ elif isinstance(y, pd.Series):
92
+ target_df = y.to_frame().copy()
93
+ else:
94
+ target_df = pd.DataFrame(y, columns=["target"])
95
+
96
+ gen_kwargs = dict(
97
+ gen_x_times=self.gen_x_times,
98
+ cat_cols=self.cat_cols,
99
+ only_generated_data=self.only_generated_data,
100
+ )
101
+ if self.gen_params is not None:
102
+ gen_kwargs["gen_params"] = self.gen_params
103
+ gen_kwargs.update(self.generator_kwargs)
104
+
105
+ generator = gen_cls(**gen_kwargs)
106
+
107
+ new_train, new_target = generator.generate_data_pipe(
108
+ X_df,
109
+ target_df,
110
+ X_df, # use train as test for distribution alignment
111
+ use_adversarial=self.use_adversarial,
112
+ constraints=self.constraints,
113
+ )
114
+
115
+ self._augmented_X = new_train
116
+ if new_target is not None and not new_target.isna().all():
117
+ self._augmented_y = (
118
+ new_target.iloc[:, 0] if isinstance(new_target, pd.DataFrame) else new_target
119
+ )
120
+ else:
121
+ self._augmented_y = None
122
+
123
+ return self
124
+
125
+ def transform(self, X, y=None):
126
+ """Return the augmented training data.
127
+
128
+ During training (when ``_augmented_X`` is available), returns the
129
+ augmented data. At inference time, returns X unchanged.
130
+ """
131
+ if self._augmented_X is not None:
132
+ result = self._augmented_X
133
+ # Clear after first transform to avoid leaking into predict
134
+ self._augmented_X = None
135
+ return result
136
+ return X
137
+
138
+ def fit_transform(self, X, y=None, **fit_params):
139
+ """Fit and return augmented data in one step."""
140
+ self.fit(X, y)
141
+ return self.transform(X, y)
142
+
143
+ def get_augmented_target(self) -> Optional[pd.Series]:
144
+ """Return the augmented target produced during ``fit``.
145
+
146
+ Call this after ``fit`` or ``fit_transform`` to get the target
147
+ values corresponding to the augmented training data.
148
+ """
149
+ return self._augmented_y