Spaces:
Running
Running
Commit ·
b0d7cdb
0
Parent(s):
first commit
Browse files- README.md +28 -0
- main.py +50 -0
- requirements.txt +7 -0
- tabtreeformer/__init__.py +1 -0
- tabtreeformer/data.py +423 -0
- tabtreeformer/dsml.py +141 -0
- tabtreeformer/model.py +347 -0
README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TabTreeFormer: Tabular Data Generation Using Hybrid Tree-Transformer
|
| 2 |
+
|
| 3 |
+
## Pre-requisites
|
| 4 |
+
|
| 5 |
+
- `Python>=3.9` installed.
|
| 6 |
+
- `pip install -r requirements.txt`.
|
| 7 |
+
|
| 8 |
+
## Usage
|
| 9 |
+
|
| 10 |
+
To train,
|
| 11 |
+
|
| 12 |
+
```shell
|
| 13 |
+
python main.py train -d DATA_PATH -t TARGET_COLUMN -p TASK_TYPE -o OUT_DIR
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
After training, to generate,
|
| 17 |
+
|
| 18 |
+
```shell
|
| 19 |
+
python main.py sample -c OUT_DIR -n N_ROWS -o OUT_CSV_PATH
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
For instance, to train on and sample from the iris dataset (one can obtain the dataset via `sklearn.datasets.load_iris`),
run the following:
|
| 24 |
+
|
| 25 |
+
```shell
|
| 26 |
+
python main.py train -d iris.csv -t target -p mult -o out
|
| 27 |
+
python main.py sample -c out -n 150 -o synthetic-iris.csv
|
| 28 |
+
```
|
main.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os.path
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from tabtreeformer import TabTreeFormer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command line for the ``train`` and ``sample`` subcommands.

    Returns
    -------
    argparse.Namespace
        Parsed arguments; ``op`` holds the chosen subcommand name.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="op")

    # "train" subcommand: fit a TabTreeFormer on a CSV dataset.
    train = commands.add_parser("train")
    train.add_argument("--data-path", "-d", type=str, required=True,
                       help="Path to data (.csv file).")
    train.add_argument("--target", "-t", type=str, required=True,
                       help="Target column name.")
    train.add_argument("--ttype", "-p", type=str, required=True, choices=["bin", "mult", "reg"],
                       help="Task type.")
    train.add_argument("--out", "-o", type=str, required=True,
                       help="Path to output directory.")

    # "sample" subcommand: draw synthetic rows from a trained checkpoint.
    sample = commands.add_parser("sample")
    sample.add_argument("--ckpt-path", "-c", type=str, required=True,
                        help="Path to checkpoint directory (output directory during training).")
    sample.add_argument("--n-rows", "-n", type=int, required=True,
                        help="Number of rows to sample.")
    sample.add_argument("--out", "-o", type=str, required=True,
                        help="Path to output synthetic data (.csv file).")
    return parser.parse_args()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main():
    """CLI entry point: train a TabTreeFormer or sample from a trained one."""
    args = parse_args()
    if args.op == "train":
        data = pd.read_csv(args.data_path)
        ttf = TabTreeFormer()
        ttf.train(data, args.target, args.ttype, args.out)
        # The whole model object (not just weights) is pickled into the output dir.
        torch.save(ttf, os.path.join(args.out, "ttf.pkl"))
    elif args.op == "sample":
        # weights_only=False matches the pre-2.6 torch.load default; torch>=2.6
        # defaults to weights_only=True, which would refuse to unpickle a full
        # model object saved above.
        # SECURITY NOTE: this unpickles arbitrary code — only load trusted checkpoints.
        ttf: TabTreeFormer = torch.load(
            os.path.join(args.ckpt_path, "ttf.pkl"), weights_only=False
        )
        sampled = ttf.sample(args.n_rows)
        sampled.to_csv(args.out, index=False)
    else:
        # Reached when no subcommand was supplied (args.op is None).
        raise ValueError("Invalid op.")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
lightgbm
|
| 3 |
+
optuna
|
| 4 |
+
pandas
|
| 5 |
+
scikit-learn
|
| 6 |
+
torch
|
| 7 |
+
transformers
|
tabtreeformer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model import TabTreeFormer
|
tabtreeformer/data.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data Handling."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Literal, Optional, Tuple, Sequence, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
from sklearn.preprocessing import KBinsDiscretizer, OrdinalEncoder, StandardScaler
|
| 9 |
+
from transformers import BatchEncoding, LogitsProcessor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Dataset:
    """Tabular dataset wrapper that fits all encoders/discretizers on construction.

    Builds, from a raw DataFrame:
      - an ordinal encoding of categorical columns (``self._cat_data``),
      - a k-means binning and a quantile binning of numeric columns
        (``self._bin_data`` / ``self._quantiles``),
      - a standard-scaled numeric matrix (``self._num_data``),
      - ``self.index_matrix``: per-row integer tokens, one bin token plus one
        quantile token per numeric column and one category token per
        categorical column, in original column order,
      - ``self.transformed`` / ``self.y``: ML-ready X (categoricals first,
        then numerics, target column removed) and the target vector.
    """

    def __init__(self, data: pd.DataFrame,
                 target: str,
                 ttype: Literal["bin", "mult", "reg"]):
        """
        Parameters
        ----------
        data : pd.DataFrame
            The dataset data.
        target : str
            The target column's name.
        ttype : "bin" | "mult" | "reg"
            Task type. Values can be "bin" for binary classification, "mult" for multiclass classification,
            "reg" for regression.
        """
        n_unique = data[target].nunique()

        self.data = data.copy()
        self.target = target
        self.ttype = ttype
        # -1 signals "not a classification task".
        self.n_classes = n_unique if ttype != "reg" else -1

        # Split columns: low-cardinality or non-numeric columns are treated as categorical.
        # NOTE(review): max(3, 10) is always 10 — looks like two candidate thresholds
        # were left in; confirm the intended cardinality cutoff.
        self.num_columns = []
        self.cat_columns = []
        for c in self.data.columns:
            if data[c].nunique() <= max(3, 10) or not pd.api.types.is_numeric_dtype(data[c]):
                self.cat_columns.append(c)
            else:
                self.num_columns.append(c)
        self.columns = data.columns.tolist()

        # Fit the ordinal encoder only if there is at least one categorical column.
        cat_data = self.data[self.cat_columns].values
        self._oe = OrdinalEncoder()
        self.max_n_categories = 0
        if cat_data.shape[-1] > 0:
            self._cat_data = self._oe.fit_transform(cat_data)
            self.max_n_categories = max(len(x) for x in self._oe.categories_)
        else:
            self._cat_data = cat_data
        # Coarse (k-means, 10-bin) discretization of numeric columns.
        num_data = self.data[self.num_columns].values
        self.max_n_bins = 0
        self._kbins = KBinsDiscretizer(strategy="kmeans", n_bins=10, encode="ordinal", )
        if len(self.num_columns) <= 0:
            self._bin_data = num_data
        else:
            self._bin_data = self._kbins.fit_transform(num_data)
            self.max_n_bins = max(x for x in self._kbins.n_bins_)
        # Fine (quantile, up to 1000-bin) discretization plus standard scaling.
        self.max_n_quantiles = 0
        self._qbins = KBinsDiscretizer(strategy="quantile", n_bins=1000, encode="ordinal", )
        self._sc = StandardScaler()
        if len(self.num_columns) <= 0:
            self._quantiles = num_data
            self._num_data = num_data
        else:
            self._quantiles = self._qbins.fit_transform(num_data)
            self.max_n_quantiles = max(x for x in self._qbins.n_bins_)
            self._num_data = self._sc.fit_transform(num_data)

        # _orders maps each original column to ("num"|"cat", position within its group);
        # _target_descr remembers where the target lives.
        self._orders = []
        self._target_descr = None
        for c in self.columns:
            if c in self.num_columns:
                descr = "num", self.num_columns.index(c)
            else:
                descr = "cat", self.cat_columns.index(c)
            self._orders.append(descr)
            if c == self.target:
                self._target_descr = descr

        # Assemble the token matrix: numeric columns contribute a bin token then a
        # quantile token; categorical columns contribute one category token.
        # idx tracks each token's column position within index_matrix.
        self.index_matrix = []
        self.index_description = []
        self._cat_indices = []
        self._quantile_indices = []
        idx = 0
        for dtype, index in self._orders:
            if dtype == "num":
                if self._bin_data is not None:
                    self.index_matrix.append(self._bin_data[:, index])
                    if self._kbins is not None:
                        self.index_description.append(("bin", self._kbins.n_bins_[index]))
                    else:
                        self.index_description.append(("bin", 1))
                    idx += 1
                self.index_matrix.append(self._quantiles[:, index])
                self.index_description.append(("quantile", self._qbins.n_bins_[index]))
                self._quantile_indices.append(idx)
            else:
                self.index_matrix.append(self._cat_data[:, index])
                self.index_description.append(("cat", len(self._oe.categories_[index])))
                self._cat_indices.append(idx)
            idx += 1
        self.index_matrix = np.stack(self.index_matrix, axis=1).astype(np.int32)

        # Build the ML-ready matrix: categoricals first, then numerics, with the
        # target column extracted into y.
        cat_x = self._cat_data
        num_x = self._num_data
        y = None
        if self._target_descr is not None:
            tdtype, tindex = self._target_descr
            if tdtype == "cat":
                y = cat_x[:, tindex].astype(np.int32)
                cat_x = np.concatenate([cat_x[:, :tindex], cat_x[:, tindex + 1:]], axis=1)
            else:
                y = num_x[:, tindex]
                num_x = np.concatenate([num_x[:, :tindex], num_x[:, tindex + 1:]], axis=1)
        self.transformed = np.concatenate([cat_x, num_x], axis=1)
        self.y = y

    def get_index_matrix(self, df: pd.DataFrame) -> np.ndarray:
        """
        Transform raw data into index matrix.

        Parameters
        ----------
        df : pd.DataFrame
            The raw data to be converted.

        Returns
        -------
        np.ndarray
            Transformed data.
        """
        cat_data = df[self.cat_columns].values
        if len(self.cat_columns) > 0:
            cat_data = self._oe.transform(cat_data)
        num_data = df[self.num_columns].values
        if self._kbins is not None:
            if len(self.num_columns) > 0:
                bin_data = self._kbins.transform(num_data)
            else:
                bin_data = num_data
        elif self._bin_data is not None:
            bin_data = np.zeros_like(num_data, dtype=np.int32)
        else:
            bin_data = None
        if len(self.num_columns) > 0:
            quantiles = self._qbins.transform(num_data)
        else:
            quantiles = num_data

        # Re-assemble tokens in the same (bin, quantile)/(cat) order as __init__.
        out = []
        for dtype, index in self._orders:
            if dtype == "num":
                if bin_data is not None:
                    out.append(bin_data[:, index])
                out.append(quantiles[:, index])
            else:
                out.append(cat_data[:, index])
        return np.stack(out, axis=1).astype(np.int32)

    def recover_index_matrix(self, data: np.ndarray) -> pd.DataFrame:
        """
        Inversely transform index matrix to raw data.

        Parameters
        ----------
        np.ndarray
            Transformed data.

        Returns
        -------
        pd.DataFrame
            The raw data recovered.
        """
        # Categorical values come back exactly; numeric values are recovered from
        # quantile-bin ordinals via KBinsDiscretizer.inverse_transform, so they are
        # bin-representative approximations of the originals.
        cat_data = data[:, self._cat_indices]
        if len(self.cat_columns) > 0:
            cat_data = self._oe.inverse_transform(cat_data)
            cat_data = pd.DataFrame(cat_data, columns=self.cat_columns)
        else:
            cat_data = pd.DataFrame(index=pd.RangeIndex(data.shape[0]))
        num_data = data[:, self._quantile_indices]
        if len(self.num_columns) > 0:
            num_data = self._qbins.inverse_transform(num_data)
            num_data = pd.DataFrame(num_data, columns=self.num_columns)
        else:
            num_data = pd.DataFrame(index=pd.RangeIndex(data.shape[0]))
        recovered = pd.concat([cat_data, num_data], axis=1)
        # Restore the original column order.
        return recovered[self.columns]

    def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Transform raw data to matrix that is friendly to downstream tasks.

        Parameters
        ----------
        df : pd.DataFrame
            The raw data to be transformed.

        Returns
        -------
        np.ndarray
            X of transformed data (features).
        np.ndarray, optional
            y of transformed data (targets).
        """
        cat_data = df[self.cat_columns].values
        if len(self.cat_columns) > 0:
            cat_data = self._oe.transform(cat_data)
        num_data = df[self.num_columns].values
        if len(self.num_columns) > 0:
            num_data = self._sc.transform(num_data)
        if self._target_descr is None:
            return np.concatenate([cat_data, num_data], axis=1), None
        # Pull the target column out of its group before concatenating X.
        tdtype, tindex = self._target_descr
        if tdtype == "cat":
            y = cat_data[:, tindex]
            cat_data = np.concatenate([cat_data[:, :tindex], cat_data[:, tindex + 1:]], axis=1)
        else:
            y = num_data[:, tindex]
            num_data = np.concatenate([num_data[:, :tindex], num_data[:, tindex + 1:]], axis=1)
        return np.concatenate([cat_data, num_data], axis=1), y
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class MaskGenerator:
    """Generates boolean masks over token sequences of shape
    [bos | leaf tokens | data tokens | eos], with per-sample random mask ratios."""

    def __init__(self, data: Dataset, leaf_dim: int,
                 tree_mask_ratio: Union[float, Tuple[float, float]],
                 data_mask_ratio: Union[float, Tuple[float, float]]):
        """
        Parameters
        ----------
        data : Dataset
            The tabular dataset.
        leaf_dim : int
            The number of dimensions for leaf.
        tree_mask_ratio : float | (float, float)
            Mask ratio of tree leaves.
        data_mask_ratio : float | (float, float)
            Mask ratio of data values.
        """
        # bos_id: 0, eos_id: 1, mask_id: 2
        self.leaf_dim = leaf_dim
        self.data_dim = data.index_matrix.shape[-1]
        # A scalar ratio means a fixed ratio; a pair means (min, max) sampled uniformly.
        if isinstance(tree_mask_ratio, Sequence):
            self.min_tree_mask_ratio, self.max_tree_mask_ratio = tree_mask_ratio
        else:
            self.min_tree_mask_ratio, self.max_tree_mask_ratio = tree_mask_ratio, tree_mask_ratio
        if isinstance(data_mask_ratio, Sequence):
            self.min_data_mask_ratio, self.max_data_mask_ratio = data_mask_ratio
        else:
            self.min_data_mask_ratio, self.max_data_mask_ratio = data_mask_ratio, data_mask_ratio

        # Positions (within the full [bos|leaf|data|eos] sequence) of bin and
        # quantile tokens; +1 skips bos, +leaf_dim skips the leaf block.
        self._bin_indices = []
        self._quantile_indices = []
        for i, (itype, _) in enumerate(data.index_description):
            pos = i + 1 + self.leaf_dim
            if itype == "quantile":
                self._quantile_indices.append(pos)
            elif itype == "bin":
                self._bin_indices.append(pos)
        # None disables the bin/quantile swap step in generate_mask.
        if len(self._quantile_indices) > 0 and len(self._bin_indices) == 0:
            self._bin_indices = None

    def generate_mask(self, batch_size: int, tree_threshold: Optional[torch.FloatTensor] = None,
                      data_threshold: Optional[torch.FloatTensor] = None,
                      prev_mask: Optional[torch.BoolTensor] = None) -> torch.BoolTensor:
        # Sample a per-row mask ratio uniformly in [min, max] unless given explicitly.
        if tree_threshold is None:
            tree_threshold = torch.rand(batch_size) * (
                self.max_tree_mask_ratio - self.min_tree_mask_ratio
            ) + self.min_tree_mask_ratio
        if data_threshold is None:
            data_threshold = torch.rand(batch_size) * (
                self.max_data_mask_ratio - self.min_data_mask_ratio
            ) + self.min_data_mask_ratio
        tree_mask = torch.rand(batch_size, self.leaf_dim) < tree_threshold.view(-1, 1)
        data_mask = torch.rand(batch_size, self.data_dim) < data_threshold.view(-1, 1)
        if prev_mask is not None:
            # Only positions that were masked in prev_mask may stay masked.
            tree_mask = tree_mask.masked_fill(~prev_mask[:, 1:1 + self.leaf_dim], False)
            data_mask = data_mask.masked_fill(~prev_mask[:, 1 + self.leaf_dim:-1], False)
        # bos and eos positions are never masked.
        mask = torch.cat([
            torch.zeros(batch_size, 1).bool(), tree_mask, data_mask, torch.zeros(batch_size, 1).bool()
        ], dim=-1)
        if self._bin_indices is not None:
            # Enforce the invariant: a bin token may only be masked if its paired
            # quantile token is also masked. Violations move the mask from the
            # bin position to the quantile position.
            bin_mask = mask[:, self._bin_indices]
            quantile_mask = mask[:, self._quantile_indices]
            need_to_swap = bin_mask & ~quantile_mask
            bin_mask[need_to_swap] = False
            quantile_mask[need_to_swap] = True
            mask[:, self._bin_indices] = bin_mask
            mask[:, self._quantile_indices] = quantile_mask
        return mask
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class _DataOffsetter:
    """Per-position token-id offsets mapping raw indices into the shared vocabulary.

    Vocabulary layout: [bos=0, eos=1, mask=2][leaf ids][category ids][bin ids][quantile ids].
    """

    def __init__(self, data: Dataset, leaf_dim: int, max_n_leaves: int):
        """
        Parameters
        ----------
        data : Dataset
            The tabular dataset.
        leaf_dim : int
            Number of trees.
        max_n_leaves : int
            The maximum number of leaves.
        """
        # Each token family starts right after the previous one.
        self._cat_offset = 3 + max_n_leaves
        self._bin_offset = self._cat_offset + data.max_n_categories
        self._quantile_offset = self._bin_offset + data.max_n_bins
        family_offset = {
            "cat": self._cat_offset,
            "bin": self._bin_offset,
            "quantile": self._quantile_offset,
        }

        # Sequence layout: [bos][leaf_dim leaf slots][data slots][eos].
        seq_len = 2 + leaf_dim + data.index_matrix.shape[-1]
        offsets = torch.zeros(seq_len, dtype=torch.long)
        offsets[-1] = 1                 # eos slot maps raw 0 -> token id 1
        offsets[1:1 + leaf_dim] = 3     # leaf ids start after bos/eos/mask
        for slot, (itype, _) in enumerate(data.index_description):
            offsets[slot + 1 + leaf_dim] = family_offset[itype]
        self.offsets = offsets
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class TrainingDataCollator:
    """Collates (leaf-token, data-token) pairs into masked inputs and labels."""

    def __init__(self, data: Dataset, leaf_index_matrix: np.ndarray, max_n_leaves: int,
                 tree_mask_ratio: Union[float, Tuple[float, float]],
                 data_mask_ratio: Union[float, Tuple[float, float]]):
        """
        Parameters
        ----------
        data : Dataset
            The tabular dataset.
        leaf_index_matrix : np.ndarray
            The leaf index matrix from tree-based model.
        max_n_leaves : int
            The maximum number of leaves.
        tree_mask_ratio, data_mask_ratio : float | (float, float)
            Arguments of `MaskGenerator`.
        """
        # Special token ids: bos=0, eos=1, mask=2.
        n_trees = leaf_index_matrix.shape[1]
        self._mask_generator = MaskGenerator(
            data=data, leaf_dim=n_trees, tree_mask_ratio=tree_mask_ratio, data_mask_ratio=data_mask_ratio
        )
        self._offsetter = _DataOffsetter(data=data, leaf_dim=n_trees, max_n_leaves=max_n_leaves)

    def __call__(self, batch: List[Tuple[torch.LongTensor, torch.LongTensor]]) -> BatchEncoding:
        """Build [bos|leaf|data|eos] token rows, offset into vocab ids, and mask them."""
        n = len(batch)
        leaf_part = torch.stack([leaf for leaf, _ in batch])
        data_part = torch.stack([values for _, values in batch])
        boundary = torch.zeros(n, 1).long()
        # Raw 0 at the boundary slots becomes bos (0) / eos (1) via the offsets.
        tokens = torch.cat([boundary, leaf_part, data_part, boundary.clone()], dim=1)
        tokens = tokens + self._offsetter.offsets
        mask = self._mask_generator.generate_mask(n)
        masked = torch.masked_fill(tokens, mask, 2)
        return BatchEncoding({
            "input_ids": masked,
            "attention_mask": torch.ones_like(masked, dtype=torch.bool),
            "labels": tokens
        })
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class CausalInferenceDataCollator:
    """Collates leaf-token rows into partially masked inputs for generation."""

    def __init__(self, leaf_index_matrix: np.ndarray,
                 tree_mask_ratio: Union[float, Tuple[float, float]]):
        """
        Parameters
        ----------
        leaf_index_matrix : np.ndarray
            The leaf index matrix from tree-based model.
        tree_mask_ratio : float | (float, float)
            Mask ratio of tree leaves.
        """
        # Special token ids: bos=0, eos=1, mask=2.
        self._leaf_dim = leaf_index_matrix.shape[1]
        # Scalar -> fixed ratio; pair -> (min, max) sampled uniformly per row.
        if isinstance(tree_mask_ratio, Sequence):
            low, high = tree_mask_ratio
        else:
            low = high = tree_mask_ratio
        self._min_tree_mask_ratio = low
        self._max_tree_mask_ratio = high

    def __call__(self, batch: List[Tuple[torch.LongTensor, ]]) -> BatchEncoding:
        """Prefix each leaf row with bos, shift ids past the specials, then mask."""
        n = len(batch)
        leaf_rows = torch.stack([row for row, in batch]) + 3  # skip bos/eos/mask ids
        tokens = torch.cat([torch.zeros(n, 1).long(), leaf_rows], dim=1)
        span = self._max_tree_mask_ratio - self._min_tree_mask_ratio
        ratios = torch.rand(n) * span + self._min_tree_mask_ratio
        leaf_mask = torch.rand(n, self._leaf_dim) < ratios.view(-1, 1)
        # bos is never masked.
        mask = torch.cat([torch.zeros(n, 1).bool(), leaf_mask], dim=-1)

        masked = torch.masked_fill(tokens, mask, 2)
        return BatchEncoding({
            "input_ids": masked,
            "attention_mask": torch.ones_like(masked, dtype=torch.bool),
        })
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class _DataLogitsProcessor(LogitsProcessor):
    """Restricts generation so each sequence position can only emit tokens
    that are valid for that slot (bos, leaf ids, column tokens, eos)."""

    def __init__(self, data: Dataset, max_n_leaves: int, leaf_dim: int):
        super().__init__()
        self._data = data
        self._max_n_leaves = max_n_leaves
        self._leaf_dim = leaf_dim
        self._main_dim = data.index_matrix.shape[-1]
        # Vocabulary layout: [bos=0, eos=1, mask=2][leaves][categories][bins][quantiles].
        self._cat_offset = 3 + max_n_leaves
        self._bin_offset = self._cat_offset + data.max_n_categories
        self._quantile_offset = self._bin_offset + data.max_n_bins

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        position = input_ids.shape[-1]
        if position == 0:
            # Nothing generated yet: only bos is allowed.
            valid_tokens = torch.tensor([0])
        elif position < self._leaf_dim + 1:
            # Inside the leaf block: any leaf id.
            valid_tokens = torch.arange(3, 3 + self._max_n_leaves)
        elif position < self._leaf_dim + 1 + self._main_dim:
            # Inside the data block: the token family and size of this column slot.
            slot = position - self._leaf_dim - 1
            itype, n_classes = self._data.index_description[slot]
            if itype == "cat":
                base = self._cat_offset
            elif itype == "bin":
                base = self._bin_offset
            else:
                base = self._quantile_offset
            valid_tokens = torch.arange(base, base + n_classes)
        else:
            # Sequence complete: only eos.
            valid_tokens = torch.tensor([1])

        allowed = torch.zeros_like(scores, dtype=torch.bool)
        allowed[:, valid_tokens] = True
        # Push disallowed logits to effectively -inf.
        return scores.masked_fill(~allowed, -1e9)
|
tabtreeformer/dsml.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Downstream ML Models."""
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Optional, Type
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import optuna
|
| 8 |
+
from lightgbm import LGBMClassifier, LGBMRegressor
|
| 9 |
+
from sklearn.base import BaseEstimator
|
| 10 |
+
from sklearn.model_selection import cross_val_score
|
| 11 |
+
|
| 12 |
+
from .data import Dataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Cross-validation scoring metric (sklearn scorer name) for each task type.
eval_metrics = {
    "bin": "f1_weighted",
    "mult": "f1_weighted",
    "reg": "neg_mean_squared_error"
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MLModel(ABC):
    """Abstract downstream ML model: Optuna-tuned sklearn-style estimator that
    can also export per-tree leaf indices for each row (``apply``)."""

    def __init__(self, data: Dataset, **kwargs):
        """
        Parameters
        ----------
        data : Dataset
            The dataset to be used for training and evaluation.
        **kwargs
            Hyperparameter tuning search space. Keys are parameter names, and values are dict with key "dtype",
            where there are 4 dtypes:

            - "const": fixed constant value, with the value in another key "value",
            - "categorical": categorical value, other keys and values for `optuna.Trial.suggest_categorical`,
            - "int": integer value, other keys and values for `optuna.Trial.suggest_int`,
            - "float": float value, other keys and values for `optuna.Trial.suggest_float`.
        """
        self.data = data
        # Search space (tuned) and fixed kwargs (subclasses may fill after super().__init__).
        self._kwargs = kwargs
        self._fixed_kwargs = {}
        # Estimator CLASS, not instance; instantiated per trial and for the final fit.
        self._model: Type[BaseEstimator] = self._create_model(data)
        self._best_params = None
        self._base_best_model: Optional[BaseEstimator] = None

    def _objective(self, trial: optuna.Trial):
        # Materialize one concrete parameter set from the search space for this trial.
        params = {}
        for k, v in self._kwargs.items():
            v = v.copy()  # don't mutate the stored search-space spec
            dtype = v.pop("dtype")
            if dtype == "const":
                params[k] = v["value"]
            elif dtype == "categorical":
                params[k] = trial.suggest_categorical(k, **v)
            elif dtype == "int":
                params[k] = trial.suggest_int(k, **v)
            elif dtype == "float":
                params[k] = trial.suggest_float(k, **v)
            else:
                raise ValueError(f"Unrecognized dtype {dtype}")
        model = self._model(**params, **self._fixed_kwargs)
        # 3-fold CV mean score; metric chosen by task type (see eval_metrics).
        return cross_val_score(
            model, self.data.transformed, self.data.y,
            cv=3, scoring=eval_metrics[self.data.ttype]
        ).mean()

    def fit(self):
        """
        Do hyper-parameter tuning, and fit a best-performing model.
        """
        study = optuna.create_study(direction="maximize")
        # 50 trials, 10 parallel workers.
        study.optimize(self._objective, n_trials=50, n_jobs=10)
        self._best_params = study.best_params
        # Refit the winning configuration on the full training data.
        self._base_best_model = self._model(**self._fixed_kwargs, **self._best_params)
        self._base_best_model.fit(self.data.transformed, self.data.y)

    @classmethod
    @abstractmethod
    def _create_model(cls, data: Dataset) -> Type[BaseEstimator]:
        # Return the estimator CLASS appropriate for the dataset's task type.
        raise NotImplementedError()

    @abstractmethod
    def _predict_leaves(self, x: np.ndarray) -> np.ndarray:
        # Return the per-tree leaf index matrix for rows of x.
        raise NotImplementedError()

    def apply(self, data: pd.DataFrame) -> np.ndarray:
        """
        Apply the model to obtain tree indices.

        Parameters
        ----------
        data : pd.DataFrame
            The data to obtain tree indices from.

        Returns
        -------
        np.ndarray
            Leaf index matrix for data.
        """
        x, _ = self.data.transform(data)
        return self._predict_leaves(x).astype(np.int32)

    @property
    def n_leaves(self) -> int:
        """Maximum number of leaves per tree."""
        raise NotImplementedError()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class LightGBMModel(MLModel):
    """LightGBM-backed downstream model with an Optuna search space over the
    usual LGBM knobs (learning rate, tree size, sampling, L1/L2)."""

    def __init__(self, data: Dataset, ):
        super().__init__(
            data,
            learning_rate=dict(dtype="float", low=0.01, high=0.3, log=True),
            n_estimators=dict(dtype="int", low=50, high=250, step=50),
            max_depth=dict(dtype="int", low=3, high=10),
            num_leaves=dict(dtype="int", low=20, high=100, step=5),
            min_data_in_leaf=dict(dtype="int", low=10, high=50, step=5),
            feature_fraction=dict(dtype="float", low=0.6, high=1.0),
            bagging_fraction=dict(dtype="float", low=0.6, high=1.0),
            lambda_l1=dict(dtype="float", low=0, high=10),
            lambda_l2=dict(dtype="float", low=0, high=10),
        )
        # BUG FIX: the estimator is fit on `data.transformed`, where
        # `Dataset.transform` places all categorical columns FIRST (with the
        # target column removed). Categorical feature positions are therefore
        # the leading indices of X — NOT the categorical columns' indices in
        # the original column order, which the previous code used.
        n_cat_features = len(self.data.cat_columns)
        if self.data.target in self.data.cat_columns:
            n_cat_features -= 1  # the target column is dropped from X
        self._fixed_kwargs = {
            "categorical_feature": list(range(n_cat_features)),
            # Silence LightGBM's console output during the many tuning fits.
            # (Removed "log_level": LightGBM has no such parameter; verbosity
            # is controlled via "verbose".)
            "verbose": -1,
        }

    @classmethod
    def _create_model(cls, data: Dataset) -> Type[BaseEstimator]:
        """Pick the sklearn-style LightGBM estimator class for the task type."""
        if data.ttype == "reg":
            return LGBMRegressor
        else:
            return LGBMClassifier

    def _predict_leaves(self, x: np.ndarray) -> np.ndarray:
        """Return the per-tree leaf index for each row of `x` (pred_leaf mode)."""
        return self._base_best_model.predict(x, pred_leaf=True)

    @property
    def n_leaves(self) -> int:
        """Maximum number of leaves over all trees of the fitted booster."""
        model_dump = self._base_best_model.booster_.dump_model()
        return max(tree["num_leaves"] for tree in model_dump["tree_info"])
|
tabtreeformer/model.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Literal, Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoConfig, AutoModelForCausalLM, BatchEncoding, PreTrainedModel, PretrainedConfig,
|
| 11 |
+
Trainer as _Trainer, TrainingArguments,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from .data import (
|
| 15 |
+
CausalInferenceDataCollator, Dataset, TrainingDataCollator,
|
| 16 |
+
_DataLogitsProcessor, _DataOffsetter
|
| 17 |
+
)
|
| 18 |
+
from .dsml import MLModel, LightGBMModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _prepare_config(config: PretrainedConfig,
                    hidden_size: Optional[int] = None, intermediate_size: Optional[int] = None,
                    n_layers: Optional[int] = None, n_heads: Optional[int] = None) -> PretrainedConfig:
    """Resize a pre-trained causal-LM config and pin its special-token IDs.

    Parameters
    ----------
    config : PretrainedConfig
        The backbone config to mutate in place (also returned).
    hidden_size, intermediate_size, n_layers, n_heads : Optional[int]
        New architecture sizes; a ``None`` keeps the backbone's default.

    Returns
    -------
    PretrainedConfig
        The same (mutated) config object.
    """
    # Different architectures name the same hyper-parameter differently
    # (e.g. GPT-2 uses n_embd / n_inner / n_layer / n_head), so each value
    # is written to every alias the config actually has.
    alias_table = (
        (hidden_size, ("hidden_size", "n_embd")),
        (intermediate_size, ("intermediate_size", "n_inner")),
        (n_layers, ("num_hidden_layers", "n_layer")),
        (n_heads, ("num_attention_heads", "n_head")),
    )
    for value, keys in alias_table:
        if value is None:
            continue
        for k in keys:
            if hasattr(config, k):
                setattr(config, k, value)

    # Special-token IDs expected by the data collators and generation code.
    config.bos_token_id = 0
    config.eos_token_id = 1
    config.masked_token_id = 2
    config.pad_token_id = 1  # PAD deliberately shares the EOS token ID

    return config
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _update_size(config: PretrainedConfig, vocab_size: int, length: int) -> PretrainedConfig:
    """Set the vocabulary size and maximum sequence length on *config*.

    Architectures name the context-length field differently, so every
    known alias that exists on the config is updated to *length*.
    The config is mutated in place and also returned.
    """
    length_aliases = ("max_position_embeddings", "n_positions", "max_length", "n_ctx")
    for attr in length_aliases:
        if hasattr(config, attr):
            setattr(config, attr, length)
    config.vocab_size = vocab_size
    return config
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TokenLoss(nn.Module):
    """Quantile-aware token loss used in place of the Trainer's label smoother.

    Non-quantile positions get standard NLL; quantile positions get a
    distance-weighted soft cross entropy (near-by quantile bins are partially
    credited) plus a binary quantile-vs-non-quantile term.
    """

    def __init__(self, quantile_offset: int, n_quantiles: torch.LongTensor, max_n_quantiles: int, data_offset: int):
        """
        Parameters
        ----------
        quantile_offset : int
            Quantile tokens offset (first quantile token's ID).
        n_quantiles : torch.LongTensor
            Number of quantiles at each position.
        max_n_quantiles : int
            Maximum number of quantiles.
        data_offset : int
            The offset position for data values.
        """
        super().__init__()
        max_n_quantiles = max(1, max_n_quantiles)  # guard against 0 (empty weight matrices)
        self.quantile_offset = quantile_offset
        # Pad per-position quantile counts with zeros for the prefix before the
        # data segment and one trailing position, shape (seq_len, 1).
        self.n_quantiles = torch.cat([
            torch.zeros(data_offset - 1, dtype=torch.long),
            n_quantiles, torch.zeros(1, dtype=torch.long)
        ], dim=0).view(-1, 1)
        self.max_n_quantiles = max_n_quantiles
        self.data_offset = data_offset
        self.ignore_index = -100  # same sentinel HuggingFace uses for padded labels
        self._is_quantile = self.n_quantiles > 0  # (seq_len, 1) bool mask
        # Per position, mark which of the max_n_quantiles slots are valid bins.
        self._quantile_logits_mask = torch.zeros(
            self.n_quantiles.shape[0], max_n_quantiles, dtype=torch.bool
        )
        for i, q in enumerate(self.n_quantiles):
            self._quantile_logits_mask[i, :q.item()] = True
        self._quantile_logits_mask = self._quantile_logits_mask.unsqueeze(0).contiguous()
        # Soft-target weights: Gaussian in squared bin distance, then flipped
        # so the true bin has weight ~0.5 and far bins ~1.5 (before renorm).
        self._weight_matrix = torch.zeros(
            self.n_quantiles.shape[0], max_n_quantiles, max_n_quantiles, dtype=torch.float
        )
        se = ((torch.arange(max_n_quantiles) - torch.arange(max_n_quantiles).view(-1, 1)) ** 2).float()
        for i, q in enumerate(self.n_quantiles):
            # q == 0 divides by zero -> NaN; cleaned up by the masked_fill below.
            self._weight_matrix[i] = torch.exp(-se / ((q * 0.005) ** 2))
        self._weight_matrix = 1 + 0.5 - self._weight_matrix
        self._weight_matrix = self._weight_matrix.masked_fill(torch.isnan(self._weight_matrix), 0)
        self._weight_matrix = self._weight_matrix.contiguous()

    def __call__(
            self, model_output: BatchEncoding, labels: torch.LongTensor, shift_labels: bool = False
    ) -> torch.FloatTensor:
        """Compute the mean token loss for a batch.

        Called by the Trainer with the model output and the label IDs;
        `shift_labels=True` applies the causal one-step shift.
        """
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
        if shift_labels:
            # Causal LM: predict token t+1 from logits at position t.
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        probs = nn.functional.softmax(logits, dim=-1)
        log_probs = -torch.log(probs)
        if labels.dim() == logits.dim() - 1:
            labels = labels.unsqueeze(-1)

        padding_mask = labels.eq(self.ignore_index)
        num_active_elements = padding_mask.numel() - padding_mask.long().sum()

        # Baseline NLL at every position (used as-is on non-quantile tokens).
        nll_loss = log_probs.gather(dim=-1, index=labels)
        # Lazily migrate cached buffers to the logits' device.
        self.n_quantiles = self.n_quantiles.to(logits.device)
        self._is_quantile = self._is_quantile.to(logits.device)
        self._quantile_logits_mask = self._quantile_logits_mask.to(logits.device)
        self._weight_matrix = self._weight_matrix.to(logits.device)
        seq_indices = torch.arange(logits.shape[1], device=logits.device)
        # Distance-based soft target distribution over quantile bins, indexed
        # by the label's bin (clipped at 0 for non-quantile labels).
        soft_targets = self._weight_matrix[seq_indices, (labels.squeeze(-1) - self.quantile_offset).clip(0), :]
        masked = ~self._quantile_logits_mask.repeat((soft_targets.shape[0], 1, 1))
        soft_targets = soft_targets.masked_fill(masked, 0)
        soft_targets = soft_targets / soft_targets.sum(dim=-1, keepdim=True)
        # Non-quantile positions get a uniform dummy distribution (masked out later).
        soft_targets = soft_targets.masked_fill(~self._is_quantile.unsqueeze(0), 1 / self.max_n_quantiles)
        quantile_logits = logits[:, :, self.quantile_offset:self.quantile_offset + self.max_n_quantiles]
        # Log-sum-exp style stabilization over the quantile slice.
        if quantile_logits.shape[-1] > 0:
            max_logits = quantile_logits.max(dim=-1, keepdim=True).values
        else:
            max_logits = torch.zeros_like(quantile_logits[:, :, :1])
        stabilized_logits = quantile_logits - max_logits
        weighted_sum = torch.sum(soft_targets * torch.exp(stabilized_logits), dim=-1, keepdim=True)

        log_normalized = torch.log(soft_targets) + quantile_logits
        if quantile_logits.shape[-1] > 0:
            # Soft cross entropy over quantile bins: -log p_weighted(label bin).
            soft_nll_loss = (-log_normalized).gather(
                dim=-1, index=(labels - self.quantile_offset).clamp(0, self.max_n_quantiles)
            ) + torch.log(weighted_sum) + max_logits
        else:
            soft_nll_loss = torch.zeros_like(weighted_sum)
        soft_nll_loss = soft_nll_loss.masked_fill(~self._is_quantile.unsqueeze(0), 0)
        is_quantile = self._is_quantile.unsqueeze(0)
        is_quantile_labels = (~self._is_quantile).long().unsqueeze(0).repeat((logits.shape[0], 1, 1))
        # Binary term: mass assigned to quantile tokens vs. all other tokens.
        quantile_logits_mask = torch.zeros((1, *probs.shape[1:]), dtype=torch.bool, device=logits.device)
        quantile_logits_mask[
            0, :, self.quantile_offset:self.quantile_offset + self.max_n_quantiles
        ] = self._quantile_logits_mask
        quantile_probs = (probs * quantile_logits_mask).sum(dim=-1).masked_fill(~is_quantile[:, :, 0], 0.5)
        non_quantile_probs = (probs * ~quantile_logits_mask).sum(dim=-1).masked_fill(~is_quantile[:, :, 0], 0.5)
        quantile_probs = torch.stack([quantile_probs, non_quantile_probs], dim=-1)
        quantile_log_probs = -torch.log(quantile_probs)
        quantile_nll_loss = quantile_log_probs.gather(dim=-1, index=is_quantile_labels)
        # Combine: plain NLL off quantile positions, soft + binary terms on them.
        nll_loss = nll_loss * (~is_quantile) + (soft_nll_loss + quantile_nll_loss) * is_quantile

        # Zero padded positions, then average over active (non-ignored) tokens.
        nll_loss.masked_fill_(padding_mask, 0.0)
        nll_loss = nll_loss.sum() / num_active_elements
        return nll_loss
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class Trainer(_Trainer):
    """HuggingFace Trainer whose label smoother is replaced by ``TokenLoss``."""

    def __init__(self, *args, quantile_offset: int, n_quantiles: torch.LongTensor, max_n_quantiles: int,
                 data_offset: int,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Override the default label smoother so loss computation routes
        # through the quantile-aware token loss instead of plain cross entropy.
        self.label_smoother = TokenLoss(
            quantile_offset=quantile_offset,
            n_quantiles=n_quantiles,
            max_n_quantiles=max_n_quantiles,
            data_offset=data_offset,
        )
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class Transformer:
    def __init__(self, model: MLModel, backbone: str = "distilgpt2"):
        """
        Model of transformer using tree-based model input and data input.

        Parameters
        ----------
        model : MLModel
            The tree-based model.
        backbone : str
            The causal LM backbone from huggingface pre-trained model.
        """
        self.model = model
        self.data = self.model.data

        # Positional args map to: hidden_size=256, intermediate_size=None,
        # n_layers=None (keep backbone default), n_heads=8.
        self.config = _prepare_config(
            AutoConfig.from_pretrained(backbone), 256, None, None, 8
        )

        # Populated by train(); sample() requires train() to have run first.
        self.leaf_index_matrix: Optional[np.ndarray] = None
        self.lm: Optional[PreTrainedModel] = None
        self._batch_size: int = 0
        self._seq_len: int = 0

    def train(self, out_dir: str):
        """
        Train model.

        Parameters
        ----------
        out_dir : str
            Output directory.
        """
        # Leaf indices of every training row across all trees.
        leaf_index_matrix = self.model.apply(self.data.data)
        self.leaf_index_matrix = leaf_index_matrix
        # Vocabulary = 3 special tokens (BOS/EOS+PAD/MASK) + leaf tokens
        # + category tokens + bin tokens + quantile tokens.
        vocab_size = int(
            self.model.n_leaves + 3 + self.data.max_n_categories + self.data.max_n_bins + self.data.max_n_quantiles
        )
        # Sequence = BOS + leaf tokens + data tokens + EOS.
        seq_len = int(2 + leaf_index_matrix.shape[-1] + self.data.index_matrix.shape[-1])
        self._seq_len = seq_len
        self.config = _update_size(self.config, vocab_size, seq_len)

        # Fresh (randomly initialized) causal LM with the resized config.
        self.lm = AutoModelForCausalLM.from_config(self.config)

        dataset = TensorDataset(
            torch.from_numpy(leaf_index_matrix).long(),
            torch.from_numpy(self.data.index_matrix).long(),
        )
        os.makedirs(out_dir, exist_ok=True)

        # Collator masks/perturbs leaf and data tokens; the two tuples are
        # presumably masking-rate ranges — TODO confirm against data.py.
        collator = TrainingDataCollator(
            self.data, leaf_index_matrix, self.model.n_leaves, (0.3, 0.7), (0.1, 0.4),
        )
        training_args = TrainingArguments(
            output_dir=os.path.join(out_dir, "ckpt"),
            logging_dir=os.path.join(out_dir, "logs"),
            per_device_train_batch_size=128,
            max_steps=30, fp16=True, learning_rate=5e-4, logging_steps=100
        )
        trainer = Trainer(
            model=self.lm,
            args=training_args,
            train_dataset=dataset,
            data_collator=collator,
            # Token-ID layout: [specials(3)][leaves][categories][bins][quantiles].
            quantile_offset=3 + self.model.n_leaves + self.data.max_n_categories + self.data.max_n_bins,
            max_n_quantiles=self.data.max_n_quantiles,
            n_quantiles=torch.tensor(
                [0 if t != "quantile" else x for t, x in self.data.index_description], dtype=torch.long
            ),
            data_offset=1 + leaf_index_matrix.shape[-1],
        )
        # Remember the batch size for inference-time DataLoader batching.
        self._batch_size = 128

        trainer.train()

        self.lm.save_pretrained(os.path.join(out_dir, "final"))
        # Replace the in-memory model with its checkpoint path; sample()
        # reloads it lazily. NOTE(review): the f-string prefix on "final"
        # is redundant (no placeholders).
        self.lm = os.path.join(out_dir, f"final")

    @torch.no_grad()
    def sample(
            self, n: int,
    ) -> torch.LongTensor:
        """
        Sample data.

        Parameters
        ----------
        n : int
            Number of rows to be sampled.

        Returns
        -------
        torch.LongTensor
            The generated token IDs.
        """
        # Bootstrap leaf-token prefixes by resampling rows (with replacement)
        # of the training leaf index matrix.
        inference_dataset = TensorDataset(torch.from_numpy(
            self.leaf_index_matrix[np.random.randint(low=0, high=self.leaf_index_matrix.shape[0], size=(n,))]
        ).long())
        dataloader = DataLoader(
            inference_dataset, collate_fn=CausalInferenceDataCollator(self.leaf_index_matrix, (0.3, 0.7)),
            batch_size=self._batch_size, shuffle=False
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Lazily reload the trained LM if train() left a checkpoint path here.
        if isinstance(self.lm, str):
            self.lm = AutoModelForCausalLM.from_pretrained(self.lm)
        self.lm.to(device)
        self.lm.eval()
        generated_texts = []
        # Constrains each position's logits to the tokens valid at that position.
        logits_processor = _DataLogitsProcessor(
            self.data, self.model.n_leaves, self.leaf_index_matrix.shape[-1]
        )
        for batch in dataloader:
            batch = batch.to(device)
            outputs = self.lm.generate(
                **batch,
                max_length=self._seq_len,
                num_return_sequences=1,
                do_sample=True,
                eos_token_id=1,
                bos_token_id=0,
                pad_token_id=1,
                logits_processor=[logits_processor],
                temperature=0.7
            )

            generated_texts.append(outputs)
        out = torch.cat(generated_texts, dim=0)
        return out
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class TabTreeFormer:
    def __init__(self):
        """
        TabTreeFormer model.

        Thin facade over the pipeline: Dataset -> LightGBM tree model ->
        Transformer; all state is created by train().
        """
        self.data: Optional[Dataset] = None
        self.ml_model: Optional[MLModel] = None
        self.transformer: Optional[Transformer] = None
        self._offsetter: Optional[_DataOffsetter] = None

    def train(self, data: pd.DataFrame, target: str, ttype: Literal["bin", "mult", "reg"], out_dir: str):
        """
        Train a TabTreeFormer.

        Parameters
        ----------
        data : pd.DataFrame
            The data to train the model on.
        target, ttype
            Arguments for `data.Dataset`.
        out_dir : str
            The output directory.
        """
        # Pipeline is strictly ordered: dataset, then tree model fit, then
        # transformer training on tree-leaf + data tokens.
        self.data = Dataset(data, target, ttype)
        self.ml_model = LightGBMModel(data=self.data)
        self.ml_model.fit()
        self.transformer = Transformer(model=self.ml_model)
        self.transformer.train(out_dir)
        # Maps generated token IDs back to per-column index space in sample().
        self._offsetter = _DataOffsetter(
            self.data, self.transformer.leaf_index_matrix.shape[-1], self.ml_model.n_leaves
        )

    def sample(self, n: int) -> pd.DataFrame:
        """
        Sample data by TabTreeFormer.

        Parameters
        ----------
        n : int
            Number of rows to be sampled.

        Returns
        -------
        pd.DataFrame
            Sampled dataset.
        """
        out = self.transformer.sample(n)
        # Undo the token-ID offsets so each position holds a raw column index.
        out = out - self._offsetter.offsets.to(out.device)
        # Slice out the data segment: skip BOS and the leaf-token prefix.
        st = 1 + self.transformer.leaf_index_matrix.shape[-1]
        ed = st + self.data.index_matrix.shape[-1]
        return self.data.recover_index_matrix(out[:, st:ed].detach().cpu().numpy())
|