# HyperOpt-GBT
**HyperOptimized Gradient Boosted Trees** — a scikit-learn compatible library that combines the best innovations from XGBoost, LightGBM, CatBoost, and YDF into one implementation.
## Key Innovations
| Innovation | Source | Effect |
|---|---|---|
| **GOSS** (Gradient-based One-Side Sampling) | LightGBM | 2-5× faster training, often *better* accuracy |
| **Weighted Quantile Sketch** | XGBoost | +15-19 points of AUC (absolute) on skewed distributions |
| **Ordered Boosting** | CatBoost | Eliminates prediction shift → unbiased residuals |
| **Ordered Target Statistics** | CatBoost | Handles categoricals without target leakage |
| **Histogram-based Splits** | LightGBM | O(k) split scan over k bins per feature vs O(n log n) sorting over n rows |
| **Compiled Inference Engines** | YDF | 5-100× faster prediction |
| **Oblivious Trees** | CatBoost | Regularization + SIMD-friendly structure |
| **Cache-aware Column Blocks** | XGBoost | Cache-friendly memory access |
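The histogram row is what makes split search cheap. Once a feature is discretized into at most k bins, one O(n) pass accumulates per-bin gradient and hessian sums, and a single O(k) prefix-sum scan then scores every candidate threshold. Below is a minimal NumPy sketch. It assumes the usual second-order gain with an L2 term `lam`, playing the role of `l2_reg`, and it assumes `n_bins >= 2`. `best_split_for_feature` is a hypothetical helper, not part of the HyperOpt-GBT API.

```python
import numpy as np

def best_split_for_feature(bins, grad, hess, n_bins, lam=1.0, min_child_weight=1.0):
    """Return (best_gain, best_bin) for one pre-binned feature.

    Rows with bin index <= best_bin go left. Hypothetical helper, not the library API.
    """
    # O(n): accumulate gradient / hessian sums per bin
    g_hist = np.bincount(bins, weights=grad, minlength=n_bins)
    h_hist = np.bincount(bins, weights=hess, minlength=n_bins)
    G, H = g_hist.sum(), h_hist.sum()

    # O(k): prefix sums give the left-child totals for every threshold at once
    g_left = np.cumsum(g_hist)[:-1]
    h_left = np.cumsum(h_hist)[:-1]
    g_right, h_right = G - g_left, H - h_left

    # second-order split gain (XGBoost-style, without the gamma penalty)
    gain = 0.5 * (g_left**2 / (h_left + lam)
                  + g_right**2 / (h_right + lam)
                  - G**2 / (H + lam))

    # reject splits that leave a child with too little hessian mass
    valid = (h_left >= min_child_weight) & (h_right >= min_child_weight)
    gain = np.where(valid, gain, -np.inf)
    best = int(np.argmax(gain))
    return float(gain[best]), best
```

For example, suppose the bins are `[0, 0, 1, 1, 2, 2, 3, 3]`, the gradients are `-1` for the first four rows and `+1` for the rest, and the hessians are all 1. The sketch then returns bin 1 with gain 3.2, so rows in bins 0 and 1 go left.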
## Quick Start
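```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from hyperopt_gbt import HyperOptGradientBoostedClassifier

# Synthetic binary task so the snippet runs end to end
X, y = make_classification(n_samples=10_000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

clf = HyperOptGradientBoostedClassifier(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=6,
    use_goss=True,               # LightGBM: gradient-based sampling
    binning='quantile_sketch',   # XGBoost: adaptive bin boundaries
    n_bins=255,
)
clf.fit(X_train, y_train)
proba = clf.predict_proba(X_test)
```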
## Installation
```bash
# From source
pip install -e .
# With benchmark dependencies
pip install -e ".[benchmark]"
# Build Rust backend (optional, for maximum speed)
cd rust_gbt && pip install maturin && maturin develop --release
```
## Benchmark Results
### Binary Classification (80K train, 20K test, 50 trees)
| Library | AUC | Train Time |
|---|---|---|
| **HyperOpt-GBT (GOSS)** | 0.9691 | 2.5s |
| XGBoost (hist) | 0.9661 | 1.3s |
| LightGBM | 0.9659 | 1.0s |
| CatBoost | **0.9756** | 1.5s |
### GOSS: Faster AND More Accurate
| Rows Used per Tree | AUC | Speedup |
|---|---|---|
| 100% (no GOSS) | 0.9659 | 1.0× |
| 40% (GOSS) | 0.9717 | 2.4× |
| **15% (GOSS)** | **0.9740** | **5.3×** |
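GOSS reaches these data fractions by keeping every row with a large gradient and randomly sampling the rest. The sketch below follows the LightGBM paper's convention:

- Keep the top `a` fraction of rows by |gradient|.
- Draw `b × n` rows uniformly from the remainder.
- Up-weight the sampled rows by `(1 - a) / b` so that gradient sums stay roughly unbiased.

Under that convention, `a=0.2, b=0.1` keeps about 30% of the rows per tree. Whether HyperOpt-GBT's `goss_a` / `goss_b` follow exactly this convention is an assumption, and `goss_sample` is a hypothetical name. The sketch also assumes `b > 0`.

```python
import numpy as np

def goss_sample(grad, a=0.2, b=0.1, rng=None):
    """Gradient-based One-Side Sampling, LightGBM-paper convention.

    Returns (row_indices, row_weights). Hypothetical helper, not the library API.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = len(grad)
    n_top = int(a * n)
    n_rand = int(b * n)

    # rank rows by absolute gradient, largest first
    order = np.argsort(-np.abs(grad))
    top = order[:n_top]
    rest = order[n_top:]

    # uniformly sample b*n rows (without replacement) from the small-gradient remainder
    sampled = rng.choice(rest, size=min(n_rand, len(rest)), replace=False)

    idx = np.concatenate([top, sampled])
    weights = np.ones(len(idx))
    # amplify the sampled small-gradient rows to compensate for undersampling them
    weights[n_top:] = (1.0 - a) / b
    return idx, weights
```

The returned weights multiply each row's gradient and hessian when building histograms. That is why a tree fit on 30% of the rows can still estimate split gains for the full dataset.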
### Quantile Sketch vs Uniform (Skewed Data)
| Bins | Uniform AUC | Quantile AUC | Gain (AUC points) |
|---|---|---|---|
| 63 | 0.6426 | 0.8306 | **+18.8** |
| 255 | 0.6775 | 0.8295 | **+15.2** |
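The gap comes from where the bin edges land. On a heavy-tailed feature, equal-width bins spend most of their edges on the sparse tail. Quantile edges instead put roughly equal (weighted) mass in each bin, and XGBoost's weighted quantile sketch uses each row's hessian as its weight. This sketch computes exact weighted quantiles with NumPy rather than a mergeable streaming sketch, so it illustrates edge placement only, not the approximation algorithm. The function names are hypothetical.

```python
import numpy as np

def uniform_edges(x, n_bins):
    # equal-width interior edges between the feature's min and max
    return np.linspace(x.min(), x.max(), n_bins + 1)[1:-1]

def weighted_quantile_edges(x, n_bins, weights=None):
    # exact weighted quantiles: each bin holds about 1/n_bins of the total weight
    w = np.ones_like(x, dtype=float) if weights is None else np.asarray(weights, dtype=float)
    order = np.argsort(x)
    xs, ws = x[order], w[order]
    cum = np.cumsum(ws) / ws.sum()
    targets = np.arange(1, n_bins) / n_bins
    pos = np.searchsorted(cum, targets)   # first row whose cumulative weight reaches each target
    return np.unique(xs[pos])

def bin_counts(x, edges):
    # np.digitize maps each value to a bin index in [0, len(edges)]
    return np.bincount(np.digitize(x, edges), minlength=len(edges) + 1)
```

On lognormal data, for instance, `bin_counts(x, uniform_edges(x, 8))` puts the large majority of rows in the first bin. By contrast, `bin_counts(x, weighted_quantile_edges(x, 8))` spreads them almost evenly, which leaves split search far more usable thresholds.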
## API Reference
### Classifier
```python
HyperOptGradientBoostedClassifier(
    # Core
    n_estimators=100,         # Number of boosting rounds
    learning_rate=0.1,        # Shrinkage
    max_depth=6,              # Maximum tree depth

    # Accuracy innovations
    ordered_boosting=False,   # CatBoost: unbiased boosting
    ordered_ts=True,          # CatBoost: ordered target statistics
    oblivious_trees=False,    # CatBoost: balanced trees

    # Speed innovations
    use_goss=True,            # LightGBM: gradient sampling
    goss_a=0.2,               # Keep top 20% by gradient magnitude
    goss_b=0.1,               # Sample 10% from rest
    n_bins=255,               # Histogram bins
    binning='uniform',        # 'uniform' or 'quantile_sketch'

    # Regularization
    l2_reg=1.0,               # L2 on leaf weights
    min_child_weight=1.0,     # Min hessian sum in leaf
    subsample=1.0,            # Row subsampling
    colsample_bytree=1.0,     # Column subsampling
)
```
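The `ordered_ts` flag refers to CatBoost-style ordered target statistics. Each row's categorical value is encoded using only the targets of rows that come before it in a random permutation, so a row's own label never leaks into its feature. Below is a minimal sketch with a smoothing prior. The prior weight `alpha`, the use of the global target mean as the prior, and the function name are assumptions, not HyperOpt-GBT parameters.

```python
import numpy as np

def ordered_target_statistics(cats, y, alpha=1.0, rng=None):
    """Encode one categorical column with ordered target statistics (CatBoost-style).

    encoded[i] = (sum of y over earlier rows with the same category + alpha * prior)
                 / (count of earlier rows with the same category + alpha),
    where "earlier" means earlier in a random permutation. Hypothetical helper.
    """
    rng = np.random.default_rng() if rng is None else rng
    y = np.asarray(y, dtype=float)
    prior = y.mean()                     # global target mean used as the smoothing prior
    perm = rng.permutation(len(y))

    sums, counts = {}, {}
    encoded = np.empty(len(y))
    for i in perm:
        c = cats[i]
        s, n = sums.get(c, 0.0), counts.get(c, 0)
        # only rows visited before i contribute, so y[i] never influences encoded[i]
        encoded[i] = (s + alpha * prior) / (n + alpha)
        sums[c] = s + y[i]
        counts[c] = n + 1
    return encoded
```

The first occurrence of each category in the permutation falls back to the prior, and later occurrences move toward that category's running mean.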
### Regressor
```python
HyperOptGradientBoostedRegressor(
    # Same parameters as classifier
)
```
### Inference Engines
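```python
from hyperopt_gbt import compile_inference_engine

# model: a fitted HyperOptGradientBoostedClassifier or HyperOptGradientBoostedRegressor
engine = compile_inference_engine(model, engine_type='auto')
# engine_type options: 'naive', 'flat', 'simd', 'quickscorer', 'auto'

# X_binned: input features already discretized into histogram bin indices
predictions = engine.predict(X_binned)
```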
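Oblivious trees fit vectorized engines well because every node at depth j tests the same (feature, threshold) pair. A whole batch can therefore be scored with `depth` vectorized comparisons that build each row's leaf index bit by bit, with no per-row branching. The NumPy sketch below shows the idea for one tree. The array layout (`features`, `thresholds`, `leaf_values`) is an assumption, not the library's internal format.

```python
import numpy as np

def predict_oblivious(X, features, thresholds, leaf_values):
    """Evaluate one oblivious tree of depth d on a batch of rows.

    features[j], thresholds[j]: the split shared by every node at level j.
    leaf_values: array of length 2**d. Hypothetical layout, not the library's format.
    """
    depth = len(features)
    idx = np.zeros(X.shape[0], dtype=np.int64)
    for j in range(depth):
        # one vectorized comparison per level; the result becomes bit (depth-1-j) of the leaf index
        bit = (X[:, features[j]] > thresholds[j]).astype(np.int64)
        idx |= bit << (depth - 1 - j)
    return leaf_values[idx]
```

A full ensemble would sum `predict_oblivious` over all trees, scaled by the learning rate. Because the loop runs over levels rather than rows, the cost per batch is `depth` comparisons plus one gather.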
## Rust Backend
The optional Rust backend provides the fastest training via:
- **Rayon** parallelism for histogram building across features
- **Flat tree arrays** (`Vec<TreeNode>`) — no pointer chasing
- **Zero-copy NumPy interop** via PyO3
- **LTO + native CPU** in release mode
```python
import rust_gbt

model = rust_gbt.PyRustGBT()
model.fit(X_train, y_train,
          n_estimators=50, learning_rate=0.1, max_depth=6,
          use_goss=True, goss_a=0.2, goss_b=0.1,
          binning="quantile", task="classification")
proba = model.predict_proba(X_test)
```
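The flat-tree-array point refers to storing nodes contiguously and linking them by integer index rather than by heap pointers. The Python sketch below uses a struct-of-arrays layout. The Rust backend's `Vec<TreeNode>` is an array of structs, but the index-based traversal is the same idea. The field names and the tree itself are illustrative assumptions, not the backend's actual struct.

```python
import numpy as np

# Hypothetical flat layout for one binary tree: node i's fields live at index i.
# Internal nodes have feature >= 0; leaves have feature == -1 and store their prediction in value.
feature = np.array([0, 1, -1, -1, -1])
threshold = np.array([0.5, 2.0, 0.0, 0.0, 0.0])
left = np.array([1, 3, -1, -1, -1])
right = np.array([2, 4, -1, -1, -1])
value = np.array([0.0, 0.0, 7.0, 1.0, 2.0])

def predict_row(x):
    node = 0
    # follow integer child indices; all node data sits in contiguous arrays
    while feature[node] >= 0:
        node = left[node] if x[feature[node]] <= threshold[node] else right[node]
    return value[node]

def predict(X):
    return np.array([predict_row(x) for x in X])
```

Here node 0 splits on feature 0 at 0.5. Its left child (node 1) splits on feature 1 at 2.0, and its right child (node 2) is a leaf.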
## Run Benchmarks
```bash
python benchmark_quick.py
```
## Architecture
- See [ARCHITECTURE.md](ARCHITECTURE.md) for the full technical design.
- See [RESULTS.md](RESULTS.md) for detailed benchmark results.
- See [WHY_HYPEROPT_GBT.md](WHY_HYPEROPT_GBT.md) for the motivation.
## License
Apache 2.0