Zilong-Zhao committed on
Commit b0d7cdb · 0 Parent(s)

first commit

README.md ADDED
@@ -0,0 +1,28 @@
+ # TabTreeFormer: Tabular Data Generation Using Hybrid Tree-Transformer
+
+ ## Pre-requisites
+
+ - `Python>=3.9` installed.
+ - `pip install -r requirements.txt`.
+
+ ## Usage
+
+ To train, run
+
+ ```shell
+ python main.py train -d DATA_PATH -t TARGET_COLUMN -p TASK_TYPE -o OUT_DIR
+ ```
+
+ After training, to generate synthetic data, run
+
+ ```shell
+ python main.py sample -c OUT_DIR -n N_ROWS -o OUT_CSV_PATH
+ ```
+
+ For instance, to train on and sample from the iris dataset (the dataset can be obtained via `sklearn.datasets.load_iris`),
+ run the following:
+
+ ```shell
+ python main.py train -d iris.csv -t target -p mult -o out
+ python main.py sample -c out -n 150 -o synthetic-iris.csv
+ ```
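+
+ Alternatively, the same workflow is available from Python. A minimal sketch that mirrors what `main.py` does (assuming iris is loaded as a DataFrame whose label column is named `target`):
+
+ ```python
+ from sklearn.datasets import load_iris
+
+ from tabtreeformer import TabTreeFormer
+
+ # Load iris as a DataFrame; sklearn names the label column "target".
+ data = load_iris(as_frame=True).frame
+
+ ttf = TabTreeFormer()
+ ttf.train(data, target="target", ttype="mult", out_dir="out")
+
+ synthetic = ttf.sample(150)
+ synthetic.to_csv("synthetic-iris.csv", index=False)
+ ```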
main.py ADDED
@@ -0,0 +1,50 @@
+ import argparse
+ import os.path
+
+ import pandas as pd
+ import torch
+
+ from tabtreeformer import TabTreeFormer
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     subparsers = parser.add_subparsers(dest="op")
+
+     train_parser = subparsers.add_parser("train")
+     train_parser.add_argument("--data-path", "-d", type=str, required=True,
+                               help="Path to data (.csv file).")
+     train_parser.add_argument("--target", "-t", type=str, required=True,
+                               help="Target column name.")
+     train_parser.add_argument("--ttype", "-p", type=str, required=True, choices=["bin", "mult", "reg"],
+                               help="Task type.")
+     train_parser.add_argument("--out", "-o", type=str, required=True,
+                               help="Path to output directory.")
+
+     sample_parser = subparsers.add_parser("sample")
+     sample_parser.add_argument("--ckpt-path", "-c", type=str, required=True,
+                                help="Path to checkpoint directory (output directory during training).")
+     sample_parser.add_argument("--n-rows", "-n", type=int, required=True,
+                                help="Number of rows to sample.")
+     sample_parser.add_argument("--out", "-o", type=str, required=True,
+                                help="Path to output synthetic data (.csv file).")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     if args.op == "train":
+         data = pd.read_csv(args.data_path)
+         ttf = TabTreeFormer()
+         ttf.train(data, args.target, args.ttype, args.out)
+         # The whole model object (not just a state dict) is pickled alongside the checkpoints.
+         torch.save(ttf, os.path.join(args.out, "ttf.pkl"))
+     elif args.op == "sample":
+         # weights_only=False is required to unpickle a full Python object on torch>=2.6.
+         ttf: TabTreeFormer = torch.load(os.path.join(args.ckpt_path, "ttf.pkl"), weights_only=False)
+         sampled = ttf.sample(args.n_rows)
+         sampled.to_csv(args.out, index=False)
+     else:
+         raise ValueError("Invalid op.")
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ numpy
+ lightgbm
+ optuna
+ pandas
+ scikit-learn
+ torch
+ transformers
tabtreeformer/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import TabTreeFormer
tabtreeformer/data.py ADDED
@@ -0,0 +1,423 @@
+ """Data Handling."""
+
+ from typing import List, Literal, Optional, Tuple, Sequence, Union
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from sklearn.preprocessing import KBinsDiscretizer, OrdinalEncoder, StandardScaler
+ from transformers import BatchEncoding, LogitsProcessor
+
+
+ class Dataset:
+     def __init__(self, data: pd.DataFrame,
+                  target: str,
+                  ttype: Literal["bin", "mult", "reg"]):
+         """
+         Parameters
+         ----------
+         data : pd.DataFrame
+             The dataset data.
+         target : str
+             The target column's name.
+         ttype : "bin" | "mult" | "reg"
+             Task type. Values can be "bin" for binary classification, "mult" for multiclass classification,
+             or "reg" for regression.
+         """
+         n_unique = data[target].nunique()
+
+         self.data = data.copy()
+         self.target = target
+         self.ttype = ttype
+         self.n_classes = n_unique if ttype != "reg" else -1
+
+         # Columns with low cardinality (<= 10 unique values) or a non-numeric dtype
+         # are treated as categorical; all others as numerical.
+         self.num_columns = []
+         self.cat_columns = []
+         for c in self.data.columns:
+             if data[c].nunique() <= 10 or not pd.api.types.is_numeric_dtype(data[c]):
+                 self.cat_columns.append(c)
+             else:
+                 self.num_columns.append(c)
+         self.columns = data.columns.tolist()
+
+         # Categorical columns are ordinal-encoded.
+         cat_data = self.data[self.cat_columns].values
+         self._oe = OrdinalEncoder()
+         self.max_n_categories = 0
+         if cat_data.shape[-1] > 0:
+             self._cat_data = self._oe.fit_transform(cat_data)
+             self.max_n_categories = max(len(x) for x in self._oe.categories_)
+         else:
+             self._cat_data = cat_data
+         # Numerical columns get a coarse k-means binning ...
+         num_data = self.data[self.num_columns].values
+         self.max_n_bins = 0
+         self._kbins = KBinsDiscretizer(strategy="kmeans", n_bins=10, encode="ordinal")
+         if len(self.num_columns) <= 0:
+             self._bin_data = num_data
+         else:
+             self._bin_data = self._kbins.fit_transform(num_data)
+             self.max_n_bins = max(x for x in self._kbins.n_bins_)
+         # ... plus a fine quantile binning and a standardized copy.
+         self.max_n_quantiles = 0
+         self._qbins = KBinsDiscretizer(strategy="quantile", n_bins=1000, encode="ordinal")
+         self._sc = StandardScaler()
+         if len(self.num_columns) <= 0:
+             self._quantiles = num_data
+             self._num_data = num_data
+         else:
+             self._quantiles = self._qbins.fit_transform(num_data)
+             self.max_n_quantiles = max(x for x in self._qbins.n_bins_)
+             self._num_data = self._sc.fit_transform(num_data)
+
+         self._orders = []
+         self._target_descr = None
+         for c in self.columns:
+             if c in self.num_columns:
+                 descr = "num", self.num_columns.index(c)
+             else:
+                 descr = "cat", self.cat_columns.index(c)
+             self._orders.append(descr)
+             if c == self.target:
+                 self._target_descr = descr
+
+         # Build the index matrix: one token per categorical column, and a
+         # (coarse bin, fine quantile) token pair per numerical column.
+         self.index_matrix = []
+         self.index_description = []
+         self._cat_indices = []
+         self._quantile_indices = []
+         idx = 0
+         for dtype, index in self._orders:
+             if dtype == "num":
+                 if self._bin_data is not None:
+                     self.index_matrix.append(self._bin_data[:, index])
+                     if self._kbins is not None:
+                         self.index_description.append(("bin", self._kbins.n_bins_[index]))
+                     else:
+                         self.index_description.append(("bin", 1))
+                     idx += 1
+                 self.index_matrix.append(self._quantiles[:, index])
+                 self.index_description.append(("quantile", self._qbins.n_bins_[index]))
+                 self._quantile_indices.append(idx)
+             else:
+                 self.index_matrix.append(self._cat_data[:, index])
+                 self.index_description.append(("cat", len(self._oe.categories_[index])))
+                 self._cat_indices.append(idx)
+             idx += 1
+         self.index_matrix = np.stack(self.index_matrix, axis=1).astype(np.int32)
+
+         # Split off the target column for the downstream ML model.
+         cat_x = self._cat_data
+         num_x = self._num_data
+         y = None
+         if self._target_descr is not None:
+             tdtype, tindex = self._target_descr
+             if tdtype == "cat":
+                 y = cat_x[:, tindex].astype(np.int32)
+                 cat_x = np.concatenate([cat_x[:, :tindex], cat_x[:, tindex + 1:]], axis=1)
+             else:
+                 y = num_x[:, tindex]
+                 num_x = np.concatenate([num_x[:, :tindex], num_x[:, tindex + 1:]], axis=1)
+         self.transformed = np.concatenate([cat_x, num_x], axis=1)
+         self.y = y
+
+     def get_index_matrix(self, df: pd.DataFrame) -> np.ndarray:
+         """
+         Transform raw data into an index matrix.
+
+         Parameters
+         ----------
+         df : pd.DataFrame
+             The raw data to be converted.
+
+         Returns
+         -------
+         np.ndarray
+             Transformed data.
+         """
+         cat_data = df[self.cat_columns].values
+         if len(self.cat_columns) > 0:
+             cat_data = self._oe.transform(cat_data)
+         num_data = df[self.num_columns].values
+         if self._kbins is not None:
+             if len(self.num_columns) > 0:
+                 bin_data = self._kbins.transform(num_data)
+             else:
+                 bin_data = num_data
+         elif self._bin_data is not None:
+             bin_data = np.zeros_like(num_data, dtype=np.int32)
+         else:
+             bin_data = None
+         if len(self.num_columns) > 0:
+             quantiles = self._qbins.transform(num_data)
+         else:
+             quantiles = num_data
+
+         out = []
+         for dtype, index in self._orders:
+             if dtype == "num":
+                 if bin_data is not None:
+                     out.append(bin_data[:, index])
+                 out.append(quantiles[:, index])
+             else:
+                 out.append(cat_data[:, index])
+         return np.stack(out, axis=1).astype(np.int32)
+
+     def recover_index_matrix(self, data: np.ndarray) -> pd.DataFrame:
+         """
+         Inverse-transform an index matrix back to raw data.
+
+         Parameters
+         ----------
+         data : np.ndarray
+             Transformed data.
+
+         Returns
+         -------
+         pd.DataFrame
+             The recovered raw data.
+         """
+         cat_data = data[:, self._cat_indices]
+         if len(self.cat_columns) > 0:
+             cat_data = self._oe.inverse_transform(cat_data)
+             cat_data = pd.DataFrame(cat_data, columns=self.cat_columns)
+         else:
+             cat_data = pd.DataFrame(index=pd.RangeIndex(data.shape[0]))
+         num_data = data[:, self._quantile_indices]
+         if len(self.num_columns) > 0:
+             num_data = self._qbins.inverse_transform(num_data)
+             num_data = pd.DataFrame(num_data, columns=self.num_columns)
+         else:
+             num_data = pd.DataFrame(index=pd.RangeIndex(data.shape[0]))
+         recovered = pd.concat([cat_data, num_data], axis=1)
+         return recovered[self.columns]
+
+     def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+         """
+         Transform raw data into a matrix suitable for downstream tasks.
+
+         Parameters
+         ----------
+         df : pd.DataFrame
+             The raw data to be transformed.
+
+         Returns
+         -------
+         np.ndarray
+             X of transformed data (features).
+         np.ndarray, optional
+             y of transformed data (targets).
+         """
+         cat_data = df[self.cat_columns].values
+         if len(self.cat_columns) > 0:
+             cat_data = self._oe.transform(cat_data)
+         num_data = df[self.num_columns].values
+         if len(self.num_columns) > 0:
+             num_data = self._sc.transform(num_data)
+         if self._target_descr is None:
+             return np.concatenate([cat_data, num_data], axis=1), None
+         tdtype, tindex = self._target_descr
+         if tdtype == "cat":
+             y = cat_data[:, tindex]
+             cat_data = np.concatenate([cat_data[:, :tindex], cat_data[:, tindex + 1:]], axis=1)
+         else:
+             y = num_data[:, tindex]
+             num_data = np.concatenate([num_data[:, :tindex], num_data[:, tindex + 1:]], axis=1)
+         return np.concatenate([cat_data, num_data], axis=1), y
+
+
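+ # Minimal usage sketch of `Dataset` (hypothetical toy data; column names are
+ # placeholders). "score" has more than 10 unique numeric values, so it is treated
+ # as numerical and contributes a (coarse bin, fine quantile) token pair; "city"
+ # and "label" are low-cardinality and therefore categorical:
+ #
+ #     df = pd.DataFrame({
+ #         "score": np.arange(100, dtype=float),
+ #         "city": ["a", "b", "c", "d"] * 25,
+ #         "label": [0, 1] * 50,
+ #     })
+ #     ds = Dataset(df, target="label", ttype="bin")
+ #     mat = ds.get_index_matrix(df)             # shape (100, 4): score bin, score quantile, city, label
+ #     recovered = ds.recover_index_matrix(mat)  # approximate inverse via quantile bin edges
+
+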
+ class MaskGenerator:
+     def __init__(self, data: Dataset, leaf_dim: int,
+                  tree_mask_ratio: Union[float, Tuple[float, float]],
+                  data_mask_ratio: Union[float, Tuple[float, float]]):
+         """
+         Parameters
+         ----------
+         data : Dataset
+             The tabular dataset.
+         leaf_dim : int
+             The number of leaf dimensions (one per tree).
+         tree_mask_ratio : float | (float, float)
+             Mask ratio of tree leaves.
+         data_mask_ratio : float | (float, float)
+             Mask ratio of data values.
+         """
+         # bos_id: 0, eos_id: 1, mask_id: 2
+         self.leaf_dim = leaf_dim
+         self.data_dim = data.index_matrix.shape[-1]
+         if isinstance(tree_mask_ratio, Sequence):
+             self.min_tree_mask_ratio, self.max_tree_mask_ratio = tree_mask_ratio
+         else:
+             self.min_tree_mask_ratio, self.max_tree_mask_ratio = tree_mask_ratio, tree_mask_ratio
+         if isinstance(data_mask_ratio, Sequence):
+             self.min_data_mask_ratio, self.max_data_mask_ratio = data_mask_ratio
+         else:
+             self.min_data_mask_ratio, self.max_data_mask_ratio = data_mask_ratio, data_mask_ratio
+
+         self._bin_indices = []
+         self._quantile_indices = []
+         for i, (itype, _) in enumerate(data.index_description):
+             pos = i + 1 + self.leaf_dim
+             if itype == "quantile":
+                 self._quantile_indices.append(pos)
+             elif itype == "bin":
+                 self._bin_indices.append(pos)
+         if len(self._quantile_indices) > 0 and len(self._bin_indices) == 0:
+             self._bin_indices = None
+
+     def generate_mask(self, batch_size: int, tree_threshold: Optional[torch.FloatTensor] = None,
+                       data_threshold: Optional[torch.FloatTensor] = None,
+                       prev_mask: Optional[torch.BoolTensor] = None) -> torch.BoolTensor:
+         # Draw a per-sample mask ratio uniformly from the configured range.
+         if tree_threshold is None:
+             tree_threshold = torch.rand(batch_size) * (
+                 self.max_tree_mask_ratio - self.min_tree_mask_ratio
+             ) + self.min_tree_mask_ratio
+         if data_threshold is None:
+             data_threshold = torch.rand(batch_size) * (
+                 self.max_data_mask_ratio - self.min_data_mask_ratio
+             ) + self.min_data_mask_ratio
+         tree_mask = torch.rand(batch_size, self.leaf_dim) < tree_threshold.view(-1, 1)
+         data_mask = torch.rand(batch_size, self.data_dim) < data_threshold.view(-1, 1)
+         if prev_mask is not None:
+             tree_mask = tree_mask.masked_fill(~prev_mask[:, 1:1 + self.leaf_dim], False)
+             data_mask = data_mask.masked_fill(~prev_mask[:, 1 + self.leaf_dim:-1], False)
+         # BOS and EOS positions are never masked.
+         mask = torch.cat([
+             torch.zeros(batch_size, 1).bool(), tree_mask, data_mask, torch.zeros(batch_size, 1).bool()
+         ], dim=-1)
+         if self._bin_indices is not None:
+             # Keep bin/quantile pairs consistent: a coarse bin token may only be
+             # masked if its fine quantile token is masked as well.
+             bin_mask = mask[:, self._bin_indices]
+             quantile_mask = mask[:, self._quantile_indices]
+             need_to_swap = bin_mask & ~quantile_mask
+             bin_mask[need_to_swap] = False
+             quantile_mask[need_to_swap] = True
+             mask[:, self._bin_indices] = bin_mask
+             mask[:, self._quantile_indices] = quantile_mask
+         return mask
+
+
+ class _DataOffsetter:
+     def __init__(self, data: Dataset, leaf_dim: int, max_n_leaves: int):
+         """
+         Parameters
+         ----------
+         data : Dataset
+             The tabular dataset.
+         leaf_dim : int
+             Number of trees.
+         max_n_leaves : int
+             The maximum number of leaves.
+         """
+         self._cat_offset = 3 + max_n_leaves
+         self._bin_offset = self._cat_offset + data.max_n_categories
+         self._quantile_offset = self._bin_offset + data.max_n_bins
+
+         self.offsets = torch.zeros(2 + leaf_dim + data.index_matrix.shape[-1], dtype=torch.long)
+         self.offsets[-1] = 1
+         self.offsets[1:1 + leaf_dim] = 3
+         for i, (itype, _) in enumerate(data.index_description):
+             pos = i + 1 + leaf_dim
+             offset = self._cat_offset if itype == "cat" else self._quantile_offset \
+                 if itype == "quantile" else self._bin_offset
+             self.offsets[pos] = offset
+
+
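+ # Token layout implied by the offsets above: each row is encoded as
+ # [BOS] [leaf tokens] [data tokens] [EOS] over a vocabulary laid out as
+ #
+ #     0                       BOS
+ #     1                       EOS / PAD
+ #     2                       MASK
+ #     3 .. 3 + L - 1          tree-leaf indices (L = max_n_leaves)
+ #     next max_n_categories   categorical values
+ #     next max_n_bins         coarse bins
+ #     next max_n_quantiles    fine quantiles
+
+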
+ class TrainingDataCollator:
+     def __init__(self, data: Dataset, leaf_index_matrix: np.ndarray, max_n_leaves: int,
+                  tree_mask_ratio: Union[float, Tuple[float, float]],
+                  data_mask_ratio: Union[float, Tuple[float, float]]):
+         """
+         Parameters
+         ----------
+         data : Dataset
+             The tabular dataset.
+         leaf_index_matrix : np.ndarray
+             The leaf index matrix from the tree-based model.
+         max_n_leaves : int
+             The maximum number of leaves.
+         tree_mask_ratio, data_mask_ratio : float | (float, float)
+             Arguments of `MaskGenerator`.
+         """
+         # bos_id: 0, eos_id: 1, mask_id: 2
+         leaf_dim = leaf_index_matrix.shape[1]
+         self._mask_generator = MaskGenerator(
+             data=data, leaf_dim=leaf_dim, tree_mask_ratio=tree_mask_ratio, data_mask_ratio=data_mask_ratio
+         )
+         self._offsetter = _DataOffsetter(data=data, leaf_dim=leaf_dim, max_n_leaves=max_n_leaves)
+
+     def __call__(self, batch: List[Tuple[torch.LongTensor, torch.LongTensor]]) -> BatchEncoding:
+         # Assemble [BOS] + leaf tokens + data tokens + [EOS], then shift into vocabulary space.
+         tokens = torch.cat([
+             torch.zeros(len(batch), 1).long(),
+             torch.stack([a for a, b in batch]), torch.stack([b for a, b in batch]),
+             torch.zeros(len(batch), 1).long()
+         ], dim=1) + self._offsetter.offsets
+         mask = self._mask_generator.generate_mask(len(batch))
+         masked = torch.masked_fill(tokens, mask, 2)
+         return BatchEncoding({
+             "input_ids": masked,
+             "attention_mask": torch.ones_like(masked, dtype=torch.bool),
+             "labels": tokens
+         })
+
+
+ class CausalInferenceDataCollator:
+     def __init__(self, leaf_index_matrix: np.ndarray,
+                  tree_mask_ratio: Union[float, Tuple[float, float]]):
+         """
+         Parameters
+         ----------
+         leaf_index_matrix : np.ndarray
+             The leaf index matrix from the tree-based model.
+         tree_mask_ratio : float | (float, float)
+             Mask ratio of tree leaves.
+         """
+         # bos_id: 0, eos_id: 1, mask_id: 2
+         self._leaf_dim = leaf_index_matrix.shape[1]
+         if isinstance(tree_mask_ratio, Sequence):
+             self._min_tree_mask_ratio, self._max_tree_mask_ratio = tree_mask_ratio
+         else:
+             self._min_tree_mask_ratio, self._max_tree_mask_ratio = tree_mask_ratio, tree_mask_ratio
+
+     def __call__(self, batch: List[Tuple[torch.LongTensor, ]]) -> BatchEncoding:
+         # The generation prompt is [BOS] + (partially masked) leaf tokens.
+         tokens = torch.cat([
+             torch.zeros(len(batch), 1).long(),
+             torch.stack([a for a, in batch]) + 3,
+         ], dim=1)
+         tree_threshold = torch.rand(len(batch)) * (
+             self._max_tree_mask_ratio - self._min_tree_mask_ratio
+         ) + self._min_tree_mask_ratio
+         tree_mask = torch.rand(len(batch), self._leaf_dim) < tree_threshold.view(-1, 1)
+         mask = torch.cat([
+             torch.zeros(len(batch), 1).bool(), tree_mask,
+         ], dim=-1)
+
+         masked = torch.masked_fill(tokens, mask, 2)
+         return BatchEncoding({
+             "input_ids": masked,
+             "attention_mask": torch.ones_like(masked, dtype=torch.bool),
+         })
+
+
+ class _DataLogitsProcessor(LogitsProcessor):
+     def __init__(self, data: Dataset, max_n_leaves: int, leaf_dim: int):
+         super().__init__()
+         self._data = data
+         self._max_n_leaves = max_n_leaves
+         self._leaf_dim = leaf_dim
+         self._main_dim = data.index_matrix.shape[-1]
+         self._cat_offset = 3 + max_n_leaves
+         self._bin_offset = self._cat_offset + data.max_n_categories
+         self._quantile_offset = self._bin_offset + data.max_n_bins
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         # Constrain generation so each position can only emit tokens of its kind:
+         # BOS first, then leaf tokens, then the column-specific value range, then EOS.
+         if input_ids.shape[-1] == 0:
+             valid_tokens = torch.tensor([0])
+         elif 1 <= input_ids.shape[-1] < self._leaf_dim + 1:
+             valid_tokens = torch.arange(3, 3 + self._max_n_leaves)
+         elif self._leaf_dim + 1 <= input_ids.shape[-1] < self._leaf_dim + 1 + self._main_dim:
+             main_index = input_ids.shape[-1] - self._leaf_dim - 1
+             itype, n_classes = self._data.index_description[main_index]
+             offset = self._cat_offset if itype == "cat" else self._bin_offset \
+                 if itype == "bin" else self._quantile_offset
+             valid_tokens = torch.arange(offset, offset + n_classes)
+         else:
+             valid_tokens = torch.tensor([1])
+
+         mask = torch.zeros_like(scores, dtype=torch.bool)
+         mask[:, valid_tokens] = True
+         scores = scores.masked_fill(~mask, -1e9)
+         return scores
tabtreeformer/dsml.py ADDED
@@ -0,0 +1,141 @@
+ """Downstream ML Models."""
+ from abc import ABC, abstractmethod
+ from typing import Optional, Type
+
+ import numpy as np
+ import pandas as pd
+ import optuna
+ from lightgbm import LGBMClassifier, LGBMRegressor
+ from sklearn.base import BaseEstimator
+ from sklearn.model_selection import cross_val_score
+
+ from .data import Dataset
+
+
+ eval_metrics = {
+     "bin": "f1_weighted",
+     "mult": "f1_weighted",
+     "reg": "neg_mean_squared_error"
+ }
+
+
+ class MLModel(ABC):
+     def __init__(self, data: Dataset, **kwargs):
+         """
+         Parameters
+         ----------
+         data : Dataset
+             The dataset to be used for training and evaluation.
+         **kwargs
+             Hyperparameter tuning search space. Keys are parameter names, and values are dicts with a key
+             "dtype" taking one of 4 values:
+
+             - "const": fixed constant value, with the value in another key "value",
+             - "categorical": categorical value, other keys and values for `optuna.Trial.suggest_categorical`,
+             - "int": integer value, other keys and values for `optuna.Trial.suggest_int`,
+             - "float": float value, other keys and values for `optuna.Trial.suggest_float`.
+         """
+         self.data = data
+         self._kwargs = kwargs
+         self._fixed_kwargs = {}
+         self._model: Type[BaseEstimator] = self._create_model(data)
+         self._best_params = None
+         self._base_best_model: Optional[BaseEstimator] = None
+
+     def _objective(self, trial: optuna.Trial):
+         # Translate each search-space entry into the corresponding optuna suggestion.
+         params = {}
+         for k, v in self._kwargs.items():
+             v = v.copy()
+             dtype = v.pop("dtype")
+             if dtype == "const":
+                 params[k] = v["value"]
+             elif dtype == "categorical":
+                 params[k] = trial.suggest_categorical(k, **v)
+             elif dtype == "int":
+                 params[k] = trial.suggest_int(k, **v)
+             elif dtype == "float":
+                 params[k] = trial.suggest_float(k, **v)
+             else:
+                 raise ValueError(f"Unrecognized dtype {dtype}")
+         model = self._model(**params, **self._fixed_kwargs)
+         return cross_val_score(
+             model, self.data.transformed, self.data.y,
+             cv=3, scoring=eval_metrics[self.data.ttype]
+         ).mean()
+
+     def fit(self):
+         """
+         Run hyperparameter tuning and fit the best-performing model.
+         """
+         study = optuna.create_study(direction="maximize")
+         study.optimize(self._objective, n_trials=50, n_jobs=10)
+         self._best_params = study.best_params
+         self._base_best_model = self._model(**self._fixed_kwargs, **self._best_params)
+         self._base_best_model.fit(self.data.transformed, self.data.y)
+
+     @classmethod
+     @abstractmethod
+     def _create_model(cls, data: Dataset) -> Type[BaseEstimator]:
+         raise NotImplementedError()
+
+     @abstractmethod
+     def _predict_leaves(self, x: np.ndarray) -> np.ndarray:
+         raise NotImplementedError()
+
+     def apply(self, data: pd.DataFrame) -> np.ndarray:
+         """
+         Apply the model to obtain tree indices.
+
+         Parameters
+         ----------
+         data : pd.DataFrame
+             The data to obtain tree indices from.
+
+         Returns
+         -------
+         np.ndarray
+             Leaf index matrix for the data.
+         """
+         x, _ = self.data.transform(data)
+         return self._predict_leaves(x).astype(np.int32)
+
+     @property
+     def n_leaves(self) -> int:
+         """Maximum number of leaves per tree."""
+         raise NotImplementedError()
+
+
+ class LightGBMModel(MLModel):
+     def __init__(self, data: Dataset):
+         super().__init__(
+             data,
+             learning_rate=dict(dtype="float", low=0.01, high=0.3, log=True),
+             n_estimators=dict(dtype="int", low=50, high=250, step=50),
+             max_depth=dict(dtype="int", low=3, high=10),
+             num_leaves=dict(dtype="int", low=20, high=100, step=5),
+             min_data_in_leaf=dict(dtype="int", low=10, high=50, step=5),
+             feature_fraction=dict(dtype="float", low=0.6, high=1.0),
+             bagging_fraction=dict(dtype="float", low=0.6, high=1.0),
+             lambda_l1=dict(dtype="float", low=0, high=10),
+             lambda_l2=dict(dtype="float", low=0, high=10),
+         )
+         self._fixed_kwargs = {
+             "categorical_feature": [self.data.columns.index(c) for c in self.data.cat_columns],
+             "verbose": -1,
+             "log_level": "error"
+         }
+
+     @classmethod
+     def _create_model(cls, data: Dataset) -> Type[BaseEstimator]:
+         if data.ttype == "reg":
+             return LGBMRegressor
+         else:
+             return LGBMClassifier
+
+     def _predict_leaves(self, x: np.ndarray) -> np.ndarray:
+         # With pred_leaf=True, LightGBM returns one leaf index per tree.
+         return self._base_best_model.predict(x, pred_leaf=True)
+
+     @property
+     def n_leaves(self) -> int:
+         model_dump = self._base_best_model.booster_.dump_model()
+         return max(tree["num_leaves"] for tree in model_dump["tree_info"])
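+
+
+ # Standalone usage sketch (this mirrors how `model.Transformer` consumes the class;
+ # `df`, "label", and "bin" are placeholders):
+ #
+ #     ds = Dataset(df, target="label", ttype="bin")
+ #     lgbm = LightGBMModel(data=ds)
+ #     lgbm.fit()                    # optuna search (50 trials) + refit on best params
+ #     leaves = lgbm.apply(ds.data)  # (n_rows, n_trees) leaf index matrix
+ #     n = lgbm.n_leaves             # maximum number of leaves per tree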
tabtreeformer/model.py ADDED
@@ -0,0 +1,347 @@
+ import os
+ from typing import Literal, Optional
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch import nn
+ from torch.utils.data import DataLoader, TensorDataset
+ from transformers import (
+     AutoConfig, AutoModelForCausalLM, BatchEncoding, PreTrainedModel, PretrainedConfig,
+     Trainer as _Trainer, TrainingArguments,
+ )
+
+ from .data import (
+     CausalInferenceDataCollator, Dataset, TrainingDataCollator,
+     _DataLogitsProcessor, _DataOffsetter
+ )
+ from .dsml import MLModel, LightGBMModel
+
+
+ def _prepare_config(config: PretrainedConfig,
+                     hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None,
+                     n_layers: Optional[int] = None, n_heads: Optional[int] = None) -> PretrainedConfig:
+     # Different backbones name these fields differently (e.g. GPT-2 uses n_embd),
+     # so set whichever attribute the config actually has.
+     if hidden_size is not None:
+         for k in ["hidden_size", "n_embd"]:
+             if hasattr(config, k):
+                 setattr(config, k, hidden_size)
+     if intermediate_size is not None:
+         for k in ["intermediate_size", "n_inner"]:
+             if hasattr(config, k):
+                 setattr(config, k, intermediate_size)
+     if n_layers is not None:
+         for k in ["num_hidden_layers", "n_layer"]:
+             if hasattr(config, k):
+                 setattr(config, k, n_layers)
+     if n_heads is not None:
+         for k in ["num_attention_heads", "n_head"]:
+             if hasattr(config, k):
+                 setattr(config, k, n_heads)
+     config.bos_token_id = 0
+     config.eos_token_id = 1
+     config.masked_token_id = 2
+     config.pad_token_id = 1
+
+     return config
+
+
+ def _update_size(config: PretrainedConfig, vocab_size: int, length: int) -> PretrainedConfig:
+     for k in ["max_position_embeddings", "n_positions", "max_length", "n_ctx"]:
+         if hasattr(config, k):
+             setattr(config, k, length)
+     config.vocab_size = vocab_size
+     return config
+
+
+ class TokenLoss(nn.Module):
+     def __init__(self, quantile_offset: int, n_quantiles: torch.LongTensor, max_n_quantiles: int, data_offset: int):
+         """
+         Parameters
+         ----------
+         quantile_offset : int
+             Quantile tokens offset (first quantile token's ID).
+         n_quantiles : torch.LongTensor
+             Number of quantiles at each position.
+         max_n_quantiles : int
+             Maximum number of quantiles.
+         data_offset : int
+             The offset position for data values.
+         """
+         super().__init__()
+         max_n_quantiles = max(1, max_n_quantiles)
+         self.quantile_offset = quantile_offset
+         self.n_quantiles = torch.cat([
+             torch.zeros(data_offset - 1, dtype=torch.long),
+             n_quantiles, torch.zeros(1, dtype=torch.long)
+         ], dim=0).view(-1, 1)
+         self.max_n_quantiles = max_n_quantiles
+         self.data_offset = data_offset
+         self.ignore_index = -100
+         self._is_quantile = self.n_quantiles > 0
+         # Per-position mask of which quantile tokens are valid.
+         self._quantile_logits_mask = torch.zeros(
+             self.n_quantiles.shape[0], max_n_quantiles, dtype=torch.bool
+         )
+         for i, q in enumerate(self.n_quantiles):
+             self._quantile_logits_mask[i, :q.item()] = True
+         self._quantile_logits_mask = self._quantile_logits_mask.unsqueeze(0).contiguous()
+         # Soft-target weights over quantile bins, as a function of distance to the true bin.
+         self._weight_matrix = torch.zeros(
+             self.n_quantiles.shape[0], max_n_quantiles, max_n_quantiles, dtype=torch.float
+         )
+         se = ((torch.arange(max_n_quantiles) - torch.arange(max_n_quantiles).view(-1, 1)) ** 2).float()
+         for i, q in enumerate(self.n_quantiles):
+             self._weight_matrix[i] = torch.exp(-se / ((q * 0.005) ** 2))
+         self._weight_matrix = 1 + 0.5 - self._weight_matrix
+         self._weight_matrix = self._weight_matrix.masked_fill(torch.isnan(self._weight_matrix), 0)
+         self._weight_matrix = self._weight_matrix.contiguous()
+
+     def __call__(
+         self, model_output: BatchEncoding, labels: torch.LongTensor, shift_labels: bool = False
+     ) -> torch.FloatTensor:
+         logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+         if shift_labels:
+             logits = logits[..., :-1, :].contiguous()
+             labels = labels[..., 1:].contiguous()
+         probs = nn.functional.softmax(logits, dim=-1)
+         log_probs = -torch.log(probs)
+         if labels.dim() == logits.dim() - 1:
+             labels = labels.unsqueeze(-1)
+
+         padding_mask = labels.eq(self.ignore_index)
+         num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+
+         # Hard NLL for every position; quantile positions are replaced below.
+         nll_loss = log_probs.gather(dim=-1, index=labels)
+         self.n_quantiles = self.n_quantiles.to(logits.device)
+         self._is_quantile = self._is_quantile.to(logits.device)
+         self._quantile_logits_mask = self._quantile_logits_mask.to(logits.device)
+         self._weight_matrix = self._weight_matrix.to(logits.device)
+         seq_indices = torch.arange(logits.shape[1], device=logits.device)
+         soft_targets = self._weight_matrix[seq_indices, (labels.squeeze(-1) - self.quantile_offset).clip(0), :]
+         masked = ~self._quantile_logits_mask.repeat((soft_targets.shape[0], 1, 1))
+         soft_targets = soft_targets.masked_fill(masked, 0)
+         soft_targets = soft_targets / soft_targets.sum(dim=-1, keepdim=True)
+         soft_targets = soft_targets.masked_fill(~self._is_quantile.unsqueeze(0), 1 / self.max_n_quantiles)
+         quantile_logits = logits[:, :, self.quantile_offset:self.quantile_offset + self.max_n_quantiles]
+         # Numerically stabilized soft-target cross-entropy over the quantile tokens.
+         if quantile_logits.shape[-1] > 0:
+             max_logits = quantile_logits.max(dim=-1, keepdim=True).values
+         else:
+             max_logits = torch.zeros_like(quantile_logits[:, :, :1])
+         stabilized_logits = quantile_logits - max_logits
+         weighted_sum = torch.sum(soft_targets * torch.exp(stabilized_logits), dim=-1, keepdim=True)
+
+         log_normalized = torch.log(soft_targets) + quantile_logits
+         if quantile_logits.shape[-1] > 0:
+             soft_nll_loss = (-log_normalized).gather(
+                 dim=-1, index=(labels - self.quantile_offset).clamp(0, self.max_n_quantiles)
+             ) + torch.log(weighted_sum) + max_logits
+         else:
+             soft_nll_loss = torch.zeros_like(weighted_sum)
+         soft_nll_loss = soft_nll_loss.masked_fill(~self._is_quantile.unsqueeze(0), 0)
+         # Binary term: total probability mass placed on the valid quantile range.
+         is_quantile = self._is_quantile.unsqueeze(0)
+         is_quantile_labels = (~self._is_quantile).long().unsqueeze(0).repeat((logits.shape[0], 1, 1))
+         quantile_logits_mask = torch.zeros((1, *probs.shape[1:]), dtype=torch.bool, device=logits.device)
+         quantile_logits_mask[
+             0, :, self.quantile_offset:self.quantile_offset + self.max_n_quantiles
+         ] = self._quantile_logits_mask
+         quantile_probs = (probs * quantile_logits_mask).sum(dim=-1).masked_fill(~is_quantile[:, :, 0], 0.5)
+         non_quantile_probs = (probs * ~quantile_logits_mask).sum(dim=-1).masked_fill(~is_quantile[:, :, 0], 0.5)
+         quantile_probs = torch.stack([quantile_probs, non_quantile_probs], dim=-1)
+         quantile_log_probs = -torch.log(quantile_probs)
+         quantile_nll_loss = quantile_log_probs.gather(dim=-1, index=is_quantile_labels)
+         nll_loss = nll_loss * (~is_quantile) + (soft_nll_loss + quantile_nll_loss) * is_quantile
+
+         nll_loss.masked_fill_(padding_mask, 0.0)
+         nll_loss = nll_loss.sum() / num_active_elements
+         return nll_loss
+
+
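+ # Reading `TokenLoss` off the code above: at a quantile position with true bin l
+ # and q valid bins, the hard cross-entropy is replaced by a soft-target form
+ #
+ #     t_j ∝ 1.5 - exp(-(j - l)^2 / (0.005 * q)^2)    (normalized over valid bins j)
+ #     L_quantile = -log( t_l * exp(z_l) / sum_j t_j * exp(z_j) )
+ #
+ # plus a binary term -log(sum of p_j over the valid quantile tokens) penalizing
+ # probability mass placed outside the position's quantile range; all other
+ # positions keep the plain token-level NLL.
+
+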
+ class Trainer(_Trainer):
+     def __init__(self, *args, quantile_offset: int, n_quantiles: torch.LongTensor, max_n_quantiles: int,
+                  data_offset: int,
+                  **kwargs):
+         super().__init__(*args, **kwargs)
+         # Route label smoothing through TokenLoss so quantile tokens get soft targets.
+         self.label_smoother = TokenLoss(
+             quantile_offset, n_quantiles, max_n_quantiles, data_offset,
+         )
+
+
+ class Transformer:
+     def __init__(self, model: MLModel, backbone: str = "distilgpt2"):
+         """
+         Transformer model over tree-based model input and data input.
+
+         Parameters
+         ----------
+         model : MLModel
+             The tree-based model.
+         backbone : str
+             The causal LM backbone from a huggingface pre-trained model.
+         """
+         self.model = model
+         self.data = self.model.data
+
+         self.config = _prepare_config(
+             AutoConfig.from_pretrained(backbone), hidden_size=256, n_heads=8
+         )
+
+         self.leaf_index_matrix: Optional[np.ndarray] = None
+         self.lm: Optional[PreTrainedModel] = None
+         self._batch_size = 0
+         self._seq_len = 0
+
+     def train(self, out_dir: str):
+         """
+         Train the model.
+
+         Parameters
+         ----------
+         out_dir : str
+             Output directory.
+         """
+         leaf_index_matrix = self.model.apply(self.data.data)
+         self.leaf_index_matrix = leaf_index_matrix
+         vocab_size = int(
+             self.model.n_leaves + 3 + self.data.max_n_categories + self.data.max_n_bins + self.data.max_n_quantiles
+         )
+         seq_len = int(2 + leaf_index_matrix.shape[-1] + self.data.index_matrix.shape[-1])
+         self._seq_len = seq_len
+         self.config = _update_size(self.config, vocab_size, seq_len)
+
+         self.lm = AutoModelForCausalLM.from_config(self.config)
+
+         dataset = TensorDataset(
+             torch.from_numpy(leaf_index_matrix).long(),
+             torch.from_numpy(self.data.index_matrix).long(),
+         )
+         os.makedirs(out_dir, exist_ok=True)
+
+         collator = TrainingDataCollator(
+             self.data, leaf_index_matrix, self.model.n_leaves, (0.3, 0.7), (0.1, 0.4),
+         )
+         training_args = TrainingArguments(
+             output_dir=os.path.join(out_dir, "ckpt"),
+             logging_dir=os.path.join(out_dir, "logs"),
+             per_device_train_batch_size=128,
+             # NOTE: max_steps=30 is a very short run, suitable for smoke tests.
+             max_steps=30, fp16=True, learning_rate=5e-4, logging_steps=100
+         )
+         trainer = Trainer(
+             model=self.lm,
+             args=training_args,
+             train_dataset=dataset,
+             data_collator=collator,
+             quantile_offset=3 + self.model.n_leaves + self.data.max_n_categories + self.data.max_n_bins,
+             max_n_quantiles=self.data.max_n_quantiles,
+             n_quantiles=torch.tensor(
+                 [0 if t != "quantile" else x for t, x in self.data.index_description], dtype=torch.long
+             ),
+             data_offset=1 + leaf_index_matrix.shape[-1],
+         )
+         self._batch_size = 128
+
+         trainer.train()
+
+         self.lm.save_pretrained(os.path.join(out_dir, "final"))
+         # Keep only the checkpoint path; the model is reloaded lazily in `sample`.
+         self.lm = os.path.join(out_dir, "final")
+
+     @torch.no_grad()
+     def sample(
+         self, n: int,
+     ) -> torch.LongTensor:
+         """
+         Sample data.
+
+         Parameters
+         ----------
+         n : int
+             Number of rows to be sampled.
+
+         Returns
+         -------
+         torch.LongTensor
+             The generated token IDs.
+         """
+         # Prompt with leaf tokens resampled from the training rows, partially masked.
+         inference_dataset = TensorDataset(torch.from_numpy(
+             self.leaf_index_matrix[np.random.randint(low=0, high=self.leaf_index_matrix.shape[0], size=(n,))]
+         ).long())
+         dataloader = DataLoader(
+             inference_dataset, collate_fn=CausalInferenceDataCollator(self.leaf_index_matrix, (0.3, 0.7)),
+             batch_size=self._batch_size, shuffle=False
+         )
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         if isinstance(self.lm, str):
+             self.lm = AutoModelForCausalLM.from_pretrained(self.lm)
+         self.lm.to(device)
+         self.lm.eval()
+         generated_texts = []
+         logits_processor = _DataLogitsProcessor(
+             self.data, self.model.n_leaves, self.leaf_index_matrix.shape[-1]
+         )
+         for batch in dataloader:
+             batch = batch.to(device)
+             outputs = self.lm.generate(
+                 **batch,
+                 max_length=self._seq_len,
+                 num_return_sequences=1,
+                 do_sample=True,
+                 eos_token_id=1,
+                 bos_token_id=0,
+                 pad_token_id=1,
+                 logits_processor=[logits_processor],
+                 temperature=0.7
+             )
+
+             generated_texts.append(outputs)
+         out = torch.cat(generated_texts, dim=0)
+         return out
+
+
+ class TabTreeFormer:
+     def __init__(self):
+         """
+         TabTreeFormer model.
+         """
+         self.data: Optional[Dataset] = None
+         self.ml_model: Optional[MLModel] = None
+         self.transformer: Optional[Transformer] = None
+         self._offsetter: Optional[_DataOffsetter] = None
+
+     def train(self, data: pd.DataFrame, target: str, ttype: Literal["bin", "mult", "reg"], out_dir: str):
+         """
+         Train a TabTreeFormer.
+
+         Parameters
+         ----------
+         data : pd.DataFrame
+             The data to train the model on.
+         target, ttype
+             Arguments for `data.Dataset`.
+         out_dir : str
+             The output directory.
+         """
+         self.data = Dataset(data, target, ttype)
+         self.ml_model = LightGBMModel(data=self.data)
+         self.ml_model.fit()
+         self.transformer = Transformer(model=self.ml_model)
+         self.transformer.train(out_dir)
+         self._offsetter = _DataOffsetter(
+             self.data, self.transformer.leaf_index_matrix.shape[-1], self.ml_model.n_leaves
+         )
+
+     def sample(self, n: int) -> pd.DataFrame:
+         """
+         Sample data with TabTreeFormer.
+
+         Parameters
+         ----------
+         n : int
+             Number of rows to be sampled.
+
+         Returns
+         -------
+         pd.DataFrame
+             Sampled dataset.
+         """
+         out = self.transformer.sample(n)
+         # Undo the vocabulary offsets, slice out the data-token span, and decode.
+         out = out - self._offsetter.offsets.to(out.device)
+         st = 1 + self.transformer.leaf_index_matrix.shape[-1]
+         ed = st + self.data.index_matrix.shape[-1]
+         return self.data.recover_index_matrix(out[:, st:ed].detach().cpu().numpy())