"""Inference engine optimizations for HyperOpt-GBT.

Implements multiple inference strategies inspired by YDF's engine compilation:

1. Naive tree traversal (baseline).
2. QuickScorer-style leaf-elimination scoring (for small trees).
3. Batched prediction over fixed-size row blocks with Numba kernels.
4. Compiled flat trees (array-backed preorder layout, Numba-parallel).

References:
    - YDF Inference Engine (arXiv:2212.02934, Section 3.7)
    - QuickScorer (Lucchese et al., SIGIR 2015)
"""

import numpy as np

try:
    from numba import njit, prange
except ImportError:  # pragma: no cover - depends on the environment
    # Pure-Python fallback: the module stays importable and produces the same
    # predictions without Numba, only much more slowly.
    prange = range

    def njit(*args, **kwargs):
        """No-op stand-in for ``numba.njit`` supporting both decorator forms."""
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(fastmath=True, cache=True)
def _flat_leaf_value(x_row, flat_tree):
    """Walk one flattened tree for a single sample and return its leaf value.

    ``flat_tree`` uses the row layout documented on :class:`FlatTreeEngine`;
    column 0 is ``-1`` for leaves. Index columns are stored as floats, so
    they are converted with ``int()`` before being used as row/feature indices.
    """
    node = 0
    feature = int(flat_tree[0, 0])
    while feature >= 0:
        if x_row[feature] <= flat_tree[node, 1]:
            node = int(flat_tree[node, 2])
        else:
            node = int(flat_tree[node, 3])
        feature = int(flat_tree[node, 0])
    return flat_tree[node, 4]


@njit(parallel=True, fastmath=True, cache=True)
def _predict_flat_parallel(X_binned, flat_tree):
    """Score every row of ``X_binned`` against one flat tree (rows in parallel)."""
    n_samples = X_binned.shape[0]
    predictions = np.empty(n_samples, dtype=np.float64)
    for i in prange(n_samples):
        predictions[i] = _flat_leaf_value(X_binned[i], flat_tree)
    return predictions


@njit(fastmath=True, cache=True)
def _predict_flat_serial(X_binned, flat_tree):
    """Score every row of ``X_binned`` against one flat tree (single thread)."""
    n_samples = X_binned.shape[0]
    predictions = np.empty(n_samples, dtype=np.float64)
    for i in range(n_samples):
        predictions[i] = _flat_leaf_value(X_binned[i], flat_tree)
    return predictions


class InferenceEngine:
    """Base class for inference engines."""

    def predict(self, X_binned):
        raise NotImplementedError


class NaiveEngine(InferenceEngine):
    """Naive tree traversal - the baseline.

    Follows node object references from the root to a leaf, one sample and
    one tree at a time. Slow due to interpreter overhead, unpredictable
    branches and pointer chasing.
    """

    def __init__(self, trees):
        self.trees = trees

    def predict(self, X_binned):
        """Return, for each row of ``X_binned``, the sum of leaf values over all trees."""
        n_samples = X_binned.shape[0]
        predictions = np.zeros(n_samples, dtype=np.float64)
        for tree in self.trees:
            for i in range(n_samples):
                predictions[i] += self._predict_single(tree, X_binned[i])
        return predictions

    def _predict_single(self, tree, x):
        """Return the value of the leaf that sample ``x`` reaches in ``tree``."""
        node = tree.root
        while not node.is_leaf:
            if x[node.feature] <= node.threshold:
                node = node.left_child
            else:
                node = node.right_child
        return node.value


class FlatTreeEngine(InferenceEngine):
    """Array-backed ("flat") tree representation.

    Each tree is stored as a 2-D ``float64`` array in depth-first preorder
    (root at row 0). Traversal then follows integer row indices instead of
    Python object references, which lets it run inside a Numba kernel that
    scores rows in parallel.

    Structure:
        - nodes[i]: [feature_idx, threshold, left_child_idx, right_child_idx, value]
        - Leaf nodes have feature_idx = -1 and hold their value in column 4.

    The array is ``float64`` rather than an integer dtype so thresholds and
    leaf values are kept exactly. The index columns hold exact integers and
    are converted back with ``int()`` during traversal.
    """

    # Kept as a static method for callers that used the original name.
    _predict_flat_batch = staticmethod(_predict_flat_parallel)

    def __init__(self, trees, n_bins):
        self.n_bins = n_bins
        self.flat_trees = []
        self.leaf_values = []
        for tree in trees:
            flat, leaves = self._flatten_tree(tree)
            self.flat_trees.append(flat)
            self.leaf_values.append(leaves)

    def _flatten_tree(self, tree):
        """Convert a linked tree into ``(flat_nodes, leaf_values)`` arrays.

        Uses an explicit stack instead of recursion so very deep trees do not
        hit Python's recursion limit. Pushing the right child before the left
        one yields depth-first preorder (node, left subtree, right subtree).
        ``leaf_values`` lists the leaves in left-to-right order.
        """
        nodes = []
        leaves = []
        # Each entry: (node, parent row or -1 for the root, parent column to patch).
        stack = [(tree.root, -1, 0)]
        while stack:
            node, parent, column = stack.pop()
            idx = len(nodes)
            if parent >= 0:
                nodes[parent][column] = idx
            if node.is_leaf:
                nodes.append([-1, -1, -1, -1, node.value])
                leaves.append(node.value)
            else:
                nodes.append([node.feature, node.threshold, -1, -1, 0.0])
                stack.append((node.right_child, idx, 3))
                stack.append((node.left_child, idx, 2))
        return np.array(nodes, dtype=np.float64), np.array(leaves, dtype=np.float64)

    def predict(self, X_binned):
        """Return, for each row of ``X_binned``, the sum of leaf values over all trees."""
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)
        for flat in self.flat_trees:
            predictions += self._predict_flat_batch(X_binned, flat)
        return predictions


class BatchedSIMDEngine(InferenceEngine):
    """Batched inference over fixed-size row blocks.

    Reuses :class:`FlatTreeEngine`'s flattened trees and scores the input in
    contiguous blocks of ``batch_size`` rows, calling a single-threaded Numba
    kernel per block and tree. No SIMD instructions are emitted explicitly;
    any vectorization is left to the Numba/LLVM compiler.
    """

    # Kept as a static method for callers that used the original name.
    _predict_batch = staticmethod(_predict_flat_serial)

    def __init__(self, trees, n_bins, batch_size=16):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.n_bins = n_bins
        self.batch_size = batch_size
        self.flat_engine = FlatTreeEngine(trees, n_bins)
        self.flat_trees = self.flat_engine.flat_trees

    def predict(self, X_binned):
        """Return, for each row of ``X_binned``, the sum of leaf values over all trees."""
        n_samples = X_binned.shape[0]
        predictions = np.zeros(n_samples, dtype=np.float64)
        for flat_tree in self.flat_trees:
            for start in range(0, n_samples, self.batch_size):
                end = min(start + self.batch_size, n_samples)
                predictions[start:end] += self._predict_batch(X_binned[start:end], flat_tree)
        return predictions


class QuickScorerEngine(InferenceEngine):
    """QuickScorer-style scoring for small trees, flat traversal for the rest.

    Idea (Lucchese et al., SIGIR 2015):
        - Number the leaves of a tree left to right.
        - When an internal node's test ``x[feature] <= threshold`` is false
          (the sample goes right), no leaf in that node's left subtree can be
          the exit leaf, so those leaves are eliminated.
        - The exit leaf is the leftmost leaf that survives every elimination.
        - The node tests are independent of each other, so each one becomes a
          single vectorized comparison over the whole batch.

    This implementation keeps the surviving-leaf set as a boolean
    ``(n_samples, n_leaves)`` NumPy array rather than packed 64-bit masks.
    Trees with more than ``max_nodes`` nodes are scored by a
    :class:`FlatTreeEngine` fallback.
    """

    def __init__(self, trees, n_bins, max_nodes=64):
        self.n_bins = n_bins
        self.max_nodes = max_nodes
        self.small_trees = []
        self.large_trees = []
        for tree in trees:
            if self._count_nodes(tree.root) <= max_nodes:
                self.small_trees.append(self._compile_quickscorer(tree))
            else:
                self.large_trees.append(tree)
        if self.large_trees:
            self.fallback_engine = FlatTreeEngine(self.large_trees, n_bins)
        else:
            self.fallback_engine = None

    def _count_nodes(self, node):
        """Count the nodes under ``node`` (delegates to the module-level helper)."""
        return _count_nodes(node)

    def _compile_quickscorer(self, tree):
        """Compile a small tree into ``(features, thresholds, left_lo, left_hi, leaf_values)``.

        Internal nodes are listed in preorder. For the k-th internal node,
        ``left_lo[k]:left_hi[k]`` is the contiguous range of leaf numbers
        that lie in its left subtree. ``leaf_values`` is in left-to-right
        order. Recursion depth is at most the node count, which the caller
        has already limited to ``max_nodes``.
        """
        features, thresholds, left_lo, left_hi, leaf_values = [], [], [], [], []

        def visit(node):
            if node.is_leaf:
                leaf_values.append(node.value)
                return
            slot = len(features)
            features.append(node.feature)
            thresholds.append(node.threshold)
            left_lo.append(len(leaf_values))
            left_hi.append(-1)
            visit(node.left_child)
            left_hi[slot] = len(leaf_values)
            visit(node.right_child)

        visit(tree.root)
        return (
            np.array(features, dtype=np.intp),
            np.array(thresholds, dtype=np.float64),
            np.array(left_lo, dtype=np.intp),
            np.array(left_hi, dtype=np.intp),
            np.array(leaf_values, dtype=np.float64),
        )

    @staticmethod
    def _score_quickscorer(X_binned, compiled):
        """Score every row of ``X_binned`` against one compiled small tree."""
        features, thresholds, left_lo, left_hi, leaf_values = compiled
        alive = np.ones((X_binned.shape[0], leaf_values.shape[0]), dtype=bool)
        for feature, threshold, lo, hi in zip(features, thresholds, left_lo, left_hi):
            # ``~(x <= t)`` instead of ``x > t`` so NaN routes right, exactly as
            # the traversal engines do.
            went_right = ~(X_binned[:, feature] <= threshold)
            alive[went_right, lo:hi] = False
        # argmax of a boolean row is the first True, i.e. the leftmost survivor.
        return leaf_values[alive.argmax(axis=1)]

    def _get_leaves(self, node):
        """Return the leaf nodes under ``node`` in left-to-right order."""
        if node.is_leaf:
            return [node]
        return self._get_leaves(node.left_child) + self._get_leaves(node.right_child)

    def predict(self, X_binned):
        """Return, for each row of ``X_binned``, the sum of leaf values over all trees."""
        predictions = np.zeros(X_binned.shape[0], dtype=np.float64)
        for compiled in self.small_trees:
            predictions += self._score_quickscorer(X_binned, compiled)
        if self.fallback_engine is not None:
            predictions += self.fallback_engine.predict(X_binned)
        return predictions


def compile_inference_engine(model, engine_type='auto'):
    """Compile model into optimized inference engine (YDF-style).

    Args:
        model: Trained HyperOpt-GBT model exposing ``trees_`` and ``n_bins``.
        engine_type: 'naive', 'flat', 'simd' (alias 'batched'), 'quickscorer',
            or 'auto'. 'auto' picks 'quickscorer' when the ensemble has fewer
            than 1000 nodes in total and 'simd' otherwise.

    Returns:
        InferenceEngine instance.

    Raises:
        ValueError: If ``engine_type`` is not one of the names above.
    """
    trees = model.trees_
    n_bins = model.n_bins

    if engine_type == 'auto':
        total_nodes = sum(_count_nodes(_tree_root(t)) for t in trees)
        engine_type = 'quickscorer' if total_nodes < 1000 else 'simd'

    if engine_type == 'naive':
        return NaiveEngine(trees)
    if engine_type == 'flat':
        return FlatTreeEngine(trees, n_bins)
    if engine_type in ('simd', 'batched'):
        return BatchedSIMDEngine(trees, n_bins)
    if engine_type == 'quickscorer':
        return QuickScorerEngine(trees, n_bins)
    raise ValueError(f"Unknown engine_type: {engine_type}")


def _tree_root(tree):
    """Best-effort root lookup used by the 'auto' node count.

    Accepts a tree object with ``.root``, a list whose first element has
    ``.root``, or a bare node (returned unchanged).
    """
    if hasattr(tree, 'root'):
        return tree.root
    if isinstance(tree, list):
        return tree[0].root
    return tree


def _count_nodes(node):
    """Count the nodes in the subtree rooted at ``node`` (``None`` counts as 0).

    Iterative, so arbitrarily deep trees do not hit the recursion limit.
    Nodes whose ``is_leaf`` is true are counted without touching their
    children.
    """
    count = 0
    stack = [node]
    while stack:
        current = stack.pop()
        if current is None:
            continue
        count += 1
        if getattr(current, 'is_leaf', False):
            continue
        stack.append(current.left_child)
        stack.append(current.right_child)
    return count