"""
Inference engine optimizations for HyperOpt-GBT.

Implements multiple inference strategies inspired by YDF's engine compilation:
1. Naive tree traversal (baseline)
2. QuickScorer-style engine for small trees (small trees are currently
   scored with the flat-tree traversal kernel, not bit-mask evaluation)
3. SIMD-style batched prediction (fixed-size sample batches scored by a
   Numba-compiled kernel)
4. Compiled flat trees (array-based node layout)

References:
- YDF Inference Engine (arXiv:2212.02934, Section 3.7)
- QuickScorer (Lucchese et al., CIKM 2015)
"""
|
|
| import numpy as np |
| from numba import njit, prange |
|
|
|
|
class InferenceEngine:
    """Common interface for the inference engines in this module.

    Concrete engines subclass this type and override :meth:`predict`.
    """

    def predict(self, X_binned):
        """Score ``X_binned``; must be overridden by a concrete engine.

        Args:
            X_binned: Input samples to score.

        Raises:
            NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError()
|
|
|
|
class NaiveEngine(InferenceEngine):
    """Naive tree traversal - baseline.

    Walks every tree from root to leaf for every sample in pure Python and
    sums the leaf values. Intended as a reference implementation to compare
    the other engines against.
    """

    def __init__(self, trees):
        """Store the trees to score.

        Args:
            trees: Iterable of tree objects exposing ``root``. Nodes expose
                ``is_leaf``, ``feature``, ``threshold``, ``left_child``,
                ``right_child`` and ``value``.
        """
        self.trees = trees

    def predict(self, X_binned):
        """Return the sum of all tree outputs for every row of ``X_binned``.

        Args:
            X_binned: 2-D array-like indexed as ``X_binned[i]`` per sample.

        Returns:
            float64 array of length ``X_binned.shape[0]``.
        """
        n_samples = X_binned.shape[0]
        predictions = np.zeros(n_samples, dtype=np.float64)

        # Tree-major order keeps the per-sample accumulation order equal to
        # the order of ``self.trees``.
        for tree in self.trees:
            for i in range(n_samples):
                predictions[i] += self._predict_single(tree, X_binned[i])

        return predictions

    def _predict_single(self, tree, x):
        """Descend ``tree`` for one sample ``x`` and return the leaf value.

        A sample goes left when ``x[node.feature] <= node.threshold``.
        """
        node = tree.root
        while not node.is_leaf:
            if x[node.feature] <= node.threshold:
                node = node.left_child
            else:
                node = node.right_child
        return node.value
|
|
|
|
class FlatTreeEngine(InferenceEngine):
    """Array-based tree representation scored by a Numba kernel.

    Each tree is flattened in pre-order into an int32 array of shape
    ``(n_nodes, 5)`` with columns
    ``[feature_idx, threshold, left_child_idx, right_child_idx, value]``.
    Leaf rows have ``feature_idx = -1``.

    Because the node table is int32, column 4 can only hold an integer cast
    of a node value. Predicting from that column truncates fractional leaf
    values, so ``predict`` instead resolves each sample to a leaf *position*
    and gathers the exact float64 value from ``leaf_values``.

    Attributes:
        n_bins: Stored as given; not used by this class itself.
        flat_trees: One int32 node table per tree (layout above). Column 4
            is lossy and kept only for layout compatibility.
        leaf_values: One float64 array per tree holding the leaf values in
            the order the leaf rows appear in the node table.
    """

    def __init__(self, trees, n_bins):
        """Flatten every tree in ``trees``.

        Args:
            trees: Iterable of tree objects exposing ``root``.
            n_bins: Number of bins; stored on the instance.
        """
        self.n_bins = n_bins
        self.flat_trees = []
        self.leaf_values = []
        # Private copies of the node tables whose column 4 holds the leaf
        # position (index into the matching ``leaf_values`` array).
        self._leaf_lookup_trees = []

        for tree in trees:
            flat, leaves = self._flatten_tree(tree)
            self.flat_trees.append(flat)
            self.leaf_values.append(leaves)
            self._leaf_lookup_trees.append(self._with_leaf_positions(flat))

    def _flatten_tree(self, tree):
        """Convert a recursive tree into ``(node_table, leaf_values)``.

        Returns:
            Tuple of an int32 ``(n_nodes, 5)`` node table and a float64 array
            of leaf values. Leaves are appended to both in the same
            traversal step, so the k-th leaf row of the table corresponds to
            ``leaf_values[k]``.
        """
        nodes = []
        leaves = []

        def traverse(node):
            idx = len(nodes)
            if node.is_leaf:
                nodes.append([-1, -1, -1, -1, node.value])
                leaves.append(node.value)
                return idx
            # Child indices are unknown until the subtrees are emitted, so
            # reserve the row first and patch columns 2 and 3 afterwards.
            nodes.append([node.feature, node.threshold, -1, -1, 0.0])
            nodes[idx][2] = traverse(node.left_child)
            nodes[idx][3] = traverse(node.right_child)
            return idx

        traverse(tree.root)
        # NOTE(review): the int32 cast assumes thresholds are integer bin
        # indices - confirm against the binning code.
        return np.array(nodes, dtype=np.int32), np.array(leaves, dtype=np.float64)

    @staticmethod
    def _with_leaf_positions(flat):
        """Return a copy of ``flat`` whose leaf rows store their leaf position.

        The k-th leaf row (in row order) gets ``k`` in column 4, which is an
        exact int32 and indexes the tree's ``leaf_values`` array.
        """
        lookup = flat.copy()
        is_leaf = lookup[:, 0] < 0
        lookup[is_leaf, 4] = np.arange(np.count_nonzero(is_leaf), dtype=lookup.dtype)
        return lookup

    def predict(self, X_binned):
        """Return the sum of all tree outputs for every row of ``X_binned``.

        Args:
            X_binned: 2-D array indexed as ``X_binned[i, feature]``.

        Returns:
            float64 array of length ``X_binned.shape[0]``.
        """
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)

        for lookup, leaves in zip(self._leaf_lookup_trees, self.leaf_values):
            # The kernel returns column 4 of the reached leaf row; for the
            # lookup tables that is the leaf position.
            leaf_pos = self._predict_flat_batch(X_binned, lookup)
            predictions += leaves[leaf_pos.astype(np.intp)]

        return predictions

    @staticmethod
    @njit(parallel=True, fastmath=True, cache=True)
    def _predict_flat_batch(X_binned, flat_tree):
        """Traverse ``flat_tree`` for every row of ``X_binned``.

        Returns a float64 array holding, per sample, the column-4 entry of
        the leaf row that the sample reaches.
        """
        n_samples = X_binned.shape[0]
        predictions = np.empty(n_samples, dtype=np.float64)

        for i in prange(n_samples):
            node_idx = 0
            while True:
                feature = flat_tree[node_idx, 0]
                if feature < 0:
                    # Negative feature index marks a leaf row.
                    predictions[i] = flat_tree[node_idx, 4]
                    break
                threshold = flat_tree[node_idx, 1]
                if X_binned[i, feature] <= threshold:
                    node_idx = flat_tree[node_idx, 2]
                else:
                    node_idx = flat_tree[node_idx, 3]

        return predictions
|
|
|
|
class BatchedSIMDEngine(InferenceEngine):
    """Batched inference engine.

    Splits the input into consecutive chunks of ``batch_size`` samples and
    scores each chunk with a Numba-compiled traversal of the flat node
    tables built by ``FlatTreeEngine``.

    The node tables are int32, so their value column cannot represent
    fractional leaf values. This engine therefore traverses private copies
    whose column 4 holds the leaf position and gathers the exact float64
    value from the flat engine's ``leaf_values``.
    """

    def __init__(self, trees, n_bins, batch_size=16):
        """Flatten ``trees`` and prepare the per-tree lookup tables.

        Args:
            trees: Iterable of tree objects exposing ``root``.
            n_bins: Number of bins; forwarded to ``FlatTreeEngine``.
            batch_size: Number of samples scored per kernel call.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1. A non-positive
                size would otherwise yield all-zero predictions or a
                division error at prediction time.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.n_bins = n_bins
        self.batch_size = batch_size
        self.flat_engine = FlatTreeEngine(trees, n_bins)
        self.flat_trees = self.flat_engine.flat_trees

        # Leaf rows and leaf values are emitted together during flattening,
        # so the k-th leaf row of a table matches leaf_values[k].
        self._leaf_values = self.flat_engine.leaf_values
        self._leaf_lookup_trees = [
            self._with_leaf_positions(flat) for flat in self.flat_trees
        ]

    @staticmethod
    def _with_leaf_positions(flat):
        """Return a copy of ``flat`` whose leaf rows store their leaf position."""
        lookup = flat.copy()
        is_leaf = lookup[:, 0] < 0
        lookup[is_leaf, 4] = np.arange(np.count_nonzero(is_leaf), dtype=lookup.dtype)
        return lookup

    def predict(self, X_binned):
        """Return the sum of all tree outputs for every row of ``X_binned``.

        Args:
            X_binned: 2-D array indexed as ``X_binned[i, feature]``.

        Returns:
            float64 array of length ``X_binned.shape[0]``.
        """
        n_samples = X_binned.shape[0]
        predictions = np.zeros(n_samples, dtype=np.float64)

        for lookup, leaves in zip(self._leaf_lookup_trees, self._leaf_values):
            for start in range(0, n_samples, self.batch_size):
                end = min(start + self.batch_size, n_samples)
                # The kernel yields the leaf position stored in column 4.
                leaf_pos = self._predict_batch(X_binned[start:end], lookup)
                predictions[start:end] += leaves[leaf_pos.astype(np.intp)]

        return predictions

    @staticmethod
    @njit(fastmath=True, cache=True)
    def _predict_batch(X_binned, flat_tree):
        """Traverse ``flat_tree`` for every row of one batch.

        Returns a float64 array holding, per sample, the column-4 entry of
        the leaf row that the sample reaches.
        """
        n_samples = X_binned.shape[0]
        predictions = np.empty(n_samples, dtype=np.float64)

        for i in range(n_samples):
            node_idx = 0
            while True:
                feature = flat_tree[node_idx, 0]
                if feature < 0:
                    # Negative feature index marks a leaf row.
                    predictions[i] = flat_tree[node_idx, 4]
                    break
                threshold = flat_tree[node_idx, 1]
                if X_binned[i, feature] <= threshold:
                    node_idx = flat_tree[node_idx, 2]
                else:
                    node_idx = flat_tree[node_idx, 3]

        return predictions
|
|
|
|
class QuickScorerEngine(InferenceEngine):
    """Engine that separates small trees (<= ``max_nodes`` nodes) from large ones.

    The design is motivated by QuickScorer (Lucchese et al., CIKM 2015),
    but no bit-mask evaluation is performed here: small trees are flattened
    and scored with ``FlatTreeEngine._predict_flat_batch``, and larger trees
    are delegated to a ``FlatTreeEngine`` instance.

    The flat node tables are int32, so their value column cannot represent
    fractional leaf values. Small trees are therefore compiled into a table
    whose column 4 holds the leaf position, and the exact float64 value is
    gathered from the tree's leaf-value array.

    Attributes:
        small_trees: List of ``(lookup_table, leaf_values)`` pairs.
        large_trees: Trees with more than ``max_nodes`` nodes.
        fallback_engine: ``FlatTreeEngine`` over ``large_trees``, or ``None``
            when there are none.
    """

    def __init__(self, trees, n_bins, max_nodes=64):
        """Partition ``trees`` by node count and compile the small ones.

        Args:
            trees: Iterable of tree objects exposing ``root``.
            n_bins: Number of bins; forwarded to ``FlatTreeEngine``.
            max_nodes: Largest node count still treated as a small tree.
        """
        self.n_bins = n_bins
        self.max_nodes = max_nodes
        self.small_trees = []
        self.large_trees = []

        for tree in trees:
            if self._count_nodes(tree.root) <= max_nodes:
                self.small_trees.append(self._compile_quickscorer(tree))
            else:
                self.large_trees.append(tree)

        if self.large_trees:
            self.fallback_engine = FlatTreeEngine(self.large_trees, n_bins)
        else:
            self.fallback_engine = None

    def _count_nodes(self, node):
        """Count the nodes of the subtree rooted at ``node``.

        Stops at nodes flagged ``is_leaf``, matching the module-level
        ``_count_nodes`` helper, so leaves need not define child attributes.
        """
        if node is None:
            return 0
        if getattr(node, 'is_leaf', False):
            return 1
        return 1 + self._count_nodes(node.left_child) + self._count_nodes(node.right_child)

    def _compile_quickscorer(self, tree):
        """Compile one small tree into ``(lookup_table, leaf_values)``.

        ``lookup_table`` is a copy of the flat node table whose leaf rows
        carry their leaf position in column 4; ``leaf_values`` is the
        float64 array of leaf values in the same order.
        """
        flat_engine = FlatTreeEngine([tree], self.n_bins)

        lookup = flat_engine.flat_trees[0].copy()
        # Leaf rows and leaf values are emitted together during flattening,
        # so numbering leaf rows in row order indexes leaf_values directly.
        is_leaf = lookup[:, 0] < 0
        lookup[is_leaf, 4] = np.arange(np.count_nonzero(is_leaf), dtype=lookup.dtype)

        return lookup, flat_engine.leaf_values[0]

    def _get_leaves(self, node):
        """Return the leaf nodes below ``node`` in left-to-right order."""
        if node.is_leaf:
            return [node]
        return self._get_leaves(node.left_child) + self._get_leaves(node.right_child)

    def predict(self, X_binned):
        """Return the sum of all tree outputs for every row of ``X_binned``.

        Args:
            X_binned: 2-D array indexed as ``X_binned[i, feature]``.

        Returns:
            float64 array of length ``X_binned.shape[0]``.
        """
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)

        for lookup, leaf_values in self.small_trees:
            # The kernel returns column 4 of the reached leaf row, which
            # for lookup tables is the leaf position.
            leaf_pos = FlatTreeEngine._predict_flat_batch(X_binned, lookup)
            predictions += leaf_values[leaf_pos.astype(np.intp)]

        # Large trees are delegated unchanged to the flat engine.
        if self.fallback_engine is not None:
            predictions += self.fallback_engine.predict(X_binned)

        return predictions
|
|
|
|
def compile_inference_engine(model, engine_type='auto'):
    """Build an inference engine for a trained model.

    Args:
        model: Object exposing ``trees_`` and ``n_bins``.
        engine_type: ``'naive'``, ``'flat'``, ``'simd'`` (alias
            ``'batched'``), ``'quickscorer'`` or ``'auto'``. With ``'auto'``
            the total node count decides: fewer than 1000 nodes selects
            ``'quickscorer'``, otherwise ``'simd'``.

    Returns:
        InferenceEngine instance.

    Raises:
        ValueError: If ``engine_type`` is not one of the names above.
    """
    trees = model.trees_
    n_bins = model.n_bins

    if engine_type == 'auto':
        total_nodes = 0
        for entry in trees:
            # An entry is a tree with ``root``, a list whose first element
            # is such a tree, or is itself treated as the root node.
            if hasattr(entry, 'root'):
                start = entry.root
            elif isinstance(entry, list):
                start = entry[0].root
            else:
                start = entry
            total_nodes += _count_nodes(start)

        engine_type = 'quickscorer' if total_nodes < 1000 else 'simd'

    if engine_type == 'naive':
        return NaiveEngine(trees)
    if engine_type == 'flat':
        return FlatTreeEngine(trees, n_bins)
    if engine_type in ('simd', 'batched'):
        return BatchedSIMDEngine(trees, n_bins)
    if engine_type == 'quickscorer':
        return QuickScorerEngine(trees, n_bins)

    raise ValueError(f"Unknown engine_type: {engine_type}")
|
|
|
|
| def _count_nodes(node): |
| if node is None: |
| return 0 |
| if hasattr(node, 'is_leaf') and node.is_leaf: |
| return 1 |
| return 1 + _count_nodes(node.left_child) + _count_nodes(node.right_child) |
|
|