# hyperopt-gbt / hyperopt_gbt/inference.py
# Uploaded by erinkhoo: "Upload hyperopt_gbt/inference.py" (commit a294dd9, verified)
"""
Inference engine optimizations for HyperOpt-GBT.
Implements multiple inference strategies inspired by YDF's engine compilation:
1. Naive tree traversal (baseline)
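2. QuickScorer-inspired engine for small trees (bit-mask scoring is not
   implemented yet; small trees are scored by flat-array traversal)
3. Batched prediction over fixed-size sample chunks (Numba-compiled; no
   explicit SIMD instructions)
4. Compiled flat trees (pre-order node arrays traversed by a Numba kernel)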
References:
- YDF Inference Engine (arXiv:2212.02934, Section 3.7)
- QuickScorer (Lucchese et al., CIKM 2015)
"""
import numpy as np
from numba import njit, prange
class InferenceEngine:
    """Common interface implemented by every inference engine in this module.

    Concrete engines override :meth:`predict`; this base class carries no
    state of its own.
    """

    def predict(self, X_binned):
        """Return one prediction per row of ``X_binned``.

        The base implementation is abstract and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
class NaiveEngine(InferenceEngine):
    """Baseline engine: pure-Python traversal of the original tree objects.

    For every tree and every sample, walks from ``tree.root`` to a leaf,
    going to ``left_child`` when ``x[node.feature] <= node.threshold`` and to
    ``right_child`` otherwise, and adds the leaf's ``value`` to that sample's
    prediction. Intended as a slow correctness reference.
    """

    def __init__(self, trees):
        self.trees = trees

    def predict(self, X_binned):
        """Sum the leaf values reached in every tree for each row.

        Args:
            X_binned: 2-D array-like indexed as ``[sample, feature]``. It is
                converted with ``np.asarray`` so that rows, not columns, are
                iterated even for array-likes such as DataFrames.

        Returns:
            float64 array with one prediction per row.
        """
        X_binned = np.asarray(X_binned)
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)
        for tree in self.trees:
            for i, row in enumerate(X_binned):
                predictions[i] += self._predict_single(tree, row)
        return predictions

    def _predict_single(self, tree, x):
        """Return the ``value`` of the leaf that sample ``x`` reaches in ``tree``."""
        node = tree.root
        while not node.is_leaf:
            if x[node.feature] <= node.threshold:
                node = node.left_child
            else:
                node = node.right_child
        return node.value
class FlatTreeEngine(InferenceEngine):
    """Flat array trees traversed by a parallel Numba kernel.

    Every tree is flattened in pre-order (root at row 0, left subtree numbered
    before the right subtree) into aligned arrays:

    - ``flat_trees[t]``: int32 array of shape ``(n_nodes, 5)`` with columns
      ``[feature_idx, threshold, left_child_idx, right_child_idx, value]``.
      Leaf rows have ``feature_idx == -1`` and child indices ``-1``.
      Because the array is int32, column 4 only holds the leaf value cast to
      int32 (truncated toward zero). It is kept for layout compatibility and
      is not used by :meth:`predict`.
    - ``node_values[t]``: float64 array of shape ``(n_nodes,)`` holding the
      exact leaf value on leaf rows and 0.0 on internal rows.
    - ``leaf_values[t]``: float64 array of leaf values in left-to-right order.

    NOTE(review): thresholds are also stored as int32, which assumes they are
    integer bin indices. TODO: confirm against the tree builder.
    """

    def __init__(self, trees, n_bins):
        self.n_bins = n_bins
        self.flat_trees = []
        self.node_values = []
        self.leaf_values = []
        for tree in trees:
            flat, node_values, leaves = self._flatten_tree(tree)
            self.flat_trees.append(flat)
            self.node_values.append(node_values)
            self.leaf_values.append(leaves)

    def _flatten_tree(self, tree):
        """Flatten ``tree`` into ``(structure, node_values, leaf_values)``.

        Uses an explicit stack instead of recursion, so deep trees cannot hit
        Python's recursion limit. The numbering is the same pre-order that a
        recursive left-then-right traversal produces.
        """
        rows = []
        node_values = []
        leaves = []
        # Work items are (node, parent_row, parent_column). The right child is
        # pushed before the left one so the left subtree gets numbered first.
        stack = [(tree.root, -1, -1)]
        while stack:
            node, parent_row, parent_col = stack.pop()
            idx = len(rows)
            if parent_row >= 0:
                rows[parent_row][parent_col] = idx
            if node.is_leaf:
                rows.append([-1, -1, -1, -1, node.value])
                node_values.append(node.value)
                leaves.append(node.value)
            else:
                rows.append([node.feature, node.threshold, -1, -1, 0.0])
                node_values.append(0.0)
                stack.append((node.right_child, idx, 3))
                stack.append((node.left_child, idx, 2))
        return (
            np.array(rows, dtype=np.int32),
            np.array(node_values, dtype=np.float64),
            np.array(leaves, dtype=np.float64),
        )

    def predict(self, X_binned):
        """Sum the per-tree predictions for every row of ``X_binned``.

        Args:
            X_binned: 2-D array indexed as ``X_binned[sample, feature]``.

        Returns:
            float64 array of shape ``(X_binned.shape[0],)``.
        """
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)
        for flat, node_values in zip(self.flat_trees, self.node_values):
            predictions += self._predict_flat_values(X_binned, flat, node_values)
        return predictions

    @staticmethod
    @njit(parallel=True, fastmath=True, cache=True)
    def _predict_flat_values(X_binned, flat_tree, node_values):
        """Traverse one flat tree per sample (parallel over samples).

        Structure comes from the int32 ``flat_tree``; the leaf value comes
        from the float64 ``node_values`` row of the reached leaf.
        """
        n_samples = X_binned.shape[0]
        predictions = np.empty(n_samples, dtype=np.float64)
        for i in prange(n_samples):
            node_idx = 0
            feature = flat_tree[0, 0]
            while feature >= 0:  # negative feature index marks a leaf
                if X_binned[i, feature] <= flat_tree[node_idx, 1]:
                    node_idx = flat_tree[node_idx, 2]
                else:
                    node_idx = flat_tree[node_idx, 3]
                feature = flat_tree[node_idx, 0]
            predictions[i] = node_values[node_idx]
        return predictions

    @staticmethod
    @njit(parallel=True, fastmath=True, cache=True)
    def _predict_flat_batch(X_binned, flat_tree):
        """Legacy kernel that reads leaf values from int32 column 4.

        That column holds truncated leaf values, so results lose any
        fractional part. It is kept only for existing callers; ``predict``
        uses ``_predict_flat_values`` instead.
        """
        n_samples = X_binned.shape[0]
        predictions = np.empty(n_samples, dtype=np.float64)
        for i in prange(n_samples):
            node_idx = 0
            while True:
                feature = flat_tree[node_idx, 0]
                if feature < 0:  # Leaf node
                    predictions[i] = flat_tree[node_idx, 4]
                    break
                threshold = flat_tree[node_idx, 1]
                if X_binned[i, feature] <= threshold:
                    node_idx = flat_tree[node_idx, 2]
                else:
                    node_idx = flat_tree[node_idx, 3]
        return predictions
class BatchedSIMDEngine(InferenceEngine):
    """Batched inference over the flat pre-order tree layout.

    Samples are scored in consecutive chunks of ``batch_size`` rows. Each
    chunk goes to a single-threaded Numba kernel (``@njit`` without
    ``parallel``). Despite the class name, the kernel uses ordinary
    per-sample branching traversal: there is no explicit SIMD/AVX code or
    branchless selection here.

    Tree structure comes from :class:`FlatTreeEngine` (``flat_trees``, int32).
    Leaf values are kept separately in ``node_values`` (float64, one entry per
    flat row), because the int32 structure array cannot hold fractional leaf
    values without truncating them.
    """

    def __init__(self, trees, n_bins, batch_size=16):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # The trees are walked twice (structure + values), so materialize
        # them to keep a one-shot iterator from being exhausted by the first pass.
        trees = list(trees)
        self.n_bins = n_bins
        self.batch_size = batch_size
        self.flat_engine = FlatTreeEngine(trees, n_bins)
        self.flat_trees = self.flat_engine.flat_trees
        self.node_values = []
        for flat_tree, tree in zip(self.flat_trees, trees):
            values = self._preorder_node_values(tree)
            if values.shape[0] != flat_tree.shape[0]:
                raise RuntimeError(
                    "node count mismatch between flat tree and pre-order walk"
                )
            self.node_values.append(values)

    @staticmethod
    def _preorder_node_values(tree):
        """Return float64 leaf values (0.0 for internal nodes) in pre-order.

        The order is left subtree before right subtree, matching the row
        numbering of the flat tree arrays.
        """
        values = []
        stack = [tree.root]
        while stack:
            node = stack.pop()
            if node.is_leaf:
                values.append(node.value)
            else:
                values.append(0.0)
                # Push right first so the left subtree is visited first.
                stack.append(node.right_child)
                stack.append(node.left_child)
        return np.array(values, dtype=np.float64)

    def predict(self, X_binned):
        """Sum per-tree predictions, scoring ``batch_size`` rows at a time.

        Args:
            X_binned: 2-D array indexed as ``X_binned[sample, feature]``.

        Returns:
            float64 array of shape ``(X_binned.shape[0],)``.
        """
        n_samples = X_binned.shape[0]
        predictions = np.zeros(n_samples, dtype=np.float64)
        step = self.batch_size
        for flat_tree, values in zip(self.flat_trees, self.node_values):
            for start in range(0, n_samples, step):
                end = min(start + step, n_samples)
                predictions[start:end] += self._predict_batch(
                    X_binned[start:end], flat_tree, values
                )
        return predictions

    @staticmethod
    @njit(fastmath=True, cache=True)
    def _predict_batch(X_binned, flat_tree, node_values):
        """Score one chunk of rows against one flat tree (single-threaded)."""
        n_samples = X_binned.shape[0]
        predictions = np.empty(n_samples, dtype=np.float64)
        for i in range(n_samples):
            node_idx = 0
            feature = flat_tree[0, 0]
            while feature >= 0:  # negative feature index marks a leaf
                if X_binned[i, feature] <= flat_tree[node_idx, 1]:
                    node_idx = flat_tree[node_idx, 2]
                else:
                    node_idx = flat_tree[node_idx, 3]
                feature = flat_tree[node_idx, 0]
            predictions[i] = node_values[node_idx]
        return predictions
class QuickScorerEngine(InferenceEngine):
    """Engine intended for QuickScorer-style scoring of small trees.

    QuickScorer (Lucchese et al., CIKM 2015) encodes tree conditions as
    bitmasks and evaluates them with bitwise operations. That scheme is NOT
    implemented here yet. Trees with at most ``max_nodes`` nodes are compiled
    to a flat pre-order array and scored by a parallel Numba kernel. Larger
    trees are delegated to a :class:`FlatTreeEngine` fallback.

    Attributes:
        small_trees: list of ``(flat_tree, node_values)`` pairs. ``flat_tree``
            is the int32 array from FlatTreeEngine (columns 0-3: feature,
            threshold, left, right). ``node_values`` is a float64 array
            aligned with its rows, holding the exact leaf values.
        large_trees: trees with more than ``max_nodes`` nodes.
        fallback_engine: FlatTreeEngine over ``large_trees``, or None.
    """

    def __init__(self, trees, n_bins, max_nodes=64):
        self.n_bins = n_bins
        self.max_nodes = max_nodes
        self.small_trees = []
        self.large_trees = []
        for tree in trees:
            if self._count_nodes(tree.root) <= max_nodes:
                self.small_trees.append(self._compile_quickscorer(tree))
            else:
                self.large_trees.append(tree)
        if self.large_trees:
            self.fallback_engine = FlatTreeEngine(self.large_trees, n_bins)
        else:
            self.fallback_engine = None

    def _count_nodes(self, node):
        """Count nodes under ``node``; leaves (truthy ``is_leaf``) count as one.

        Children of a node marked as a leaf are not visited. This matches the
        module-level node counter and avoids relying on leaves having
        ``left_child``/``right_child`` attributes.
        """
        count = 0
        stack = [node]
        while stack:
            current = stack.pop()
            if current is None:
                continue
            count += 1
            if getattr(current, 'is_leaf', False):
                continue
            stack.append(current.left_child)
            stack.append(current.right_child)
        return count

    @staticmethod
    def _preorder_node_values(tree):
        """Return float64 leaf values (0.0 for internal nodes) in pre-order.

        The order is left subtree before right subtree, matching the row
        numbering of FlatTreeEngine's flat arrays.
        """
        values = []
        stack = [tree.root]
        while stack:
            node = stack.pop()
            if node.is_leaf:
                values.append(node.value)
            else:
                values.append(0.0)
                stack.append(node.right_child)
                stack.append(node.left_child)
        return np.array(values, dtype=np.float64)

    def _compile_quickscorer(self, tree):
        """Compile ``tree`` into a ``(flat_tree, node_values)`` pair."""
        flat_tree = FlatTreeEngine([tree], self.n_bins).flat_trees[0]
        node_values = self._preorder_node_values(tree)
        if node_values.shape[0] != flat_tree.shape[0]:
            raise RuntimeError(
                "node count mismatch between flat tree and pre-order walk"
            )
        return flat_tree, node_values

    def _get_leaves(self, node):
        """Return the leaf nodes under ``node`` in left-to-right order (unused by predict)."""
        if node.is_leaf:
            return [node]
        return self._get_leaves(node.left_child) + self._get_leaves(node.right_child)

    @staticmethod
    @njit(parallel=True, fastmath=True, cache=True)
    def _predict_compiled(X_binned, flat_tree, node_values):
        """Traverse one compiled small tree per sample (parallel over samples)."""
        n_samples = X_binned.shape[0]
        predictions = np.empty(n_samples, dtype=np.float64)
        for i in prange(n_samples):
            node_idx = 0
            feature = flat_tree[0, 0]
            while feature >= 0:  # negative feature index marks a leaf
                if X_binned[i, feature] <= flat_tree[node_idx, 1]:
                    node_idx = flat_tree[node_idx, 2]
                else:
                    node_idx = flat_tree[node_idx, 3]
                feature = flat_tree[node_idx, 0]
            predictions[i] = node_values[node_idx]
        return predictions

    def predict(self, X_binned):
        """Sum small-tree and fallback predictions for each row of ``X_binned``."""
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)
        for flat_tree, node_values in self.small_trees:
            predictions += self._predict_compiled(X_binned, flat_tree, node_values)
        if self.fallback_engine is not None:
            predictions += self.fallback_engine.predict(X_binned)
        return predictions
def compile_inference_engine(model, engine_type='auto'):
    """Build an inference engine for ``model`` (YDF-style engine compilation).

    Args:
        model: trained HyperOpt-GBT model exposing ``trees_`` and ``n_bins``.
        engine_type: one of 'naive', 'flat', 'simd' (alias 'batched'),
            'quickscorer', or 'auto'. 'auto' picks 'quickscorer' when the
            model has fewer than 1000 nodes in total, otherwise 'simd'.

    Returns:
        An InferenceEngine instance.

    Raises:
        ValueError: if ``engine_type`` is not recognised.
    """
    trees = model.trees_
    n_bins = model.n_bins

    if engine_type == 'auto':
        def root_of(entry):
            # Accept a tree object, a list whose first element is a tree,
            # or a bare node.
            if hasattr(entry, 'root'):
                return entry.root
            if isinstance(entry, list):
                return entry[0].root
            return entry

        total_nodes = sum(_count_nodes(root_of(entry)) for entry in trees)
        engine_type = 'quickscorer' if total_nodes < 1000 else 'simd'

    factories = (
        (('naive',), lambda: NaiveEngine(trees)),
        (('flat',), lambda: FlatTreeEngine(trees, n_bins)),
        (('simd', 'batched'), lambda: BatchedSIMDEngine(trees, n_bins)),
        (('quickscorer',), lambda: QuickScorerEngine(trees, n_bins)),
    )
    for aliases, build in factories:
        if engine_type in aliases:
            return build()
    raise ValueError(f"Unknown engine_type: {engine_type}")
def _count_nodes(node):
if node is None:
return 0
if hasattr(node, 'is_leaf') and node.is_leaf:
return 1
return 1 + _count_nodes(node.left_child) + _count_nodes(node.right_child)