| """
|
| Batch Processing and GPU Acceleration (Phase 3.5.5)
|
| ===================================================
|
| Implementation of batch operations for HDVs, leveraging PyTorch for GPU acceleration
|
| if available, with fallback to NumPy (CPU).
|
|
|
| Designed to scale comfortably from Raspberry Pi (CPU) to dedicated AI rigs (CUDA).
|
| """
|
|
|
| import multiprocessing
|
| from concurrent.futures import ProcessPoolExecutor, as_completed
|
| from typing import List, Tuple, Optional
|
|
|
| import numpy as np
|
| from loguru import logger
|
|
|
|
|
| try:
|
| import torch
|
| TORCH_AVAILABLE = True
|
| except ImportError:
|
| torch = None
|
| TORCH_AVAILABLE = False
|
|
|
| from .binary_hdv import BinaryHDV, TextEncoder, batch_hamming_distance
|
|
|
|
|
def _encode_single_worker(args: tuple) -> bytes:
    """Encode one text into a serialized HDV; runs in a worker process.

    Module-level (not nested) so ProcessPoolExecutor can pickle it.

    Args:
        args: ``(text, dimension)`` pair — the text to encode and the
            HDV dimensionality in bits.

    Returns:
        The encoded hypervector serialized to bytes (picklable payload
        for crossing the process boundary).
    """
    text, dimension = args
    return TextEncoder(dimension=dimension).encode(text).to_bytes()
|
|
|
|
|
class BatchProcessor:
    """
    Handles batched operations for HDV encoding and search.

    Automatically selects the best available backend (CUDA > MPS > CPU).
    Encoding is parallelized across CPU processes (it is CPU-bound Python
    work); Hamming-distance search is vectorized with NumPy on CPU or
    offloaded to a GPU via PyTorch when available.
    """

    def __init__(self, use_gpu: bool = True, num_workers: Optional[int] = None):
        """
        Args:
            use_gpu: Whether to attempt using GPU acceleration.
            num_workers: Number of CPU workers for encoding (defaults to CPU count).
        """
        self.device = self._detect_device(use_gpu)
        self.num_workers = num_workers or multiprocessing.cpu_count()
        # 256-entry bits-set lookup tables, built lazily per backend.
        self.popcount_table_gpu = None
        self._popcount_table_cpu = None

        logger.info(f"BatchProcessor initialized on device: {self.device}")

    def _detect_device(self, use_gpu: bool) -> str:
        """Detect the best available compute device (cuda > mps > cpu)."""
        if not use_gpu or not TORCH_AVAILABLE:
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds lack torch.backends.mps entirely — guard the
        # attribute lookup instead of assuming it exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _ensure_gpu_table(self):
        """Initialize the 256-entry bits-set lookup table on GPU if needed."""
        if self.device == "cpu" or self.popcount_table_gpu is not None:
            return

        self.popcount_table_gpu = torch.tensor(
            [bin(i).count("1") for i in range(256)],
            dtype=torch.int32,
            device=self.device,
        )

    def encode_batch(self, texts: List[str], dimension: int = 16384) -> List[BinaryHDV]:
        """
        Encode a batch of texts into BinaryHDVs using parallel CPU processing.

        Encoding logic is strictly CPU-bound (tokenization + python loops),
        so we use ProcessPoolExecutor to bypass the GIL.

        Args:
            texts: Texts to encode; result order matches input order.
            dimension: HDV dimensionality in bits.

        Returns:
            One BinaryHDV per input text. A zero vector is substituted for
            any item whose encoding raised (best-effort batch semantics).
        """
        if not texts:
            return []

        results: List[Optional[BinaryHDV]] = [None] * len(texts)

        with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
            future_to_idx = {
                executor.submit(_encode_single_worker, (text, dimension)): i
                for i, text in enumerate(texts)
            }

            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    raw_bytes = future.result()
                    results[idx] = BinaryHDV.from_bytes(raw_bytes, dimension=dimension)
                except Exception as e:
                    # Deliberate best-effort: log and keep going with a
                    # neutral zero vector in the failed slot.
                    logger.error(f"Encoding failed for item {idx}: {e}")
                    results[idx] = BinaryHDV.zeros(dimension)

        return results

    def search_batch(
        self,
        queries: List[BinaryHDV],
        targets: List[BinaryHDV],
    ) -> np.ndarray:
        """
        Compute Hamming distance matrix between queries and targets.

        Args:
            queries: List of M query vectors.
            targets: List of N target vectors.

        Returns:
            np.ndarray of shape (M, N), dtype int32, with Hamming distances.
        """
        if not queries or not targets:
            # Preserve the documented (M, N) shape and int dtype even when
            # one side is empty (previously returned a (1, 0) float array).
            return np.zeros((len(queries), len(targets)), dtype=np.int32)

        # Stack packed byte arrays: (M, B) and (N, B), B = dimension // 8.
        query_arr = np.stack([q.data for q in queries])
        target_arr = np.stack([t.data for t in targets])

        if self.device == "cpu":
            return self._search_cpu(query_arr, target_arr)
        return self._search_gpu(query_arr, target_arr)

    def _search_cpu(self, query_arr: np.ndarray, target_arr: np.ndarray) -> np.ndarray:
        """NumPy-based batch Hamming distance.

        Args:
            query_arr: (M, B) uint8 array of packed query bits.
            target_arr: (N, B) uint8 array of packed target bits.

        Returns:
            (M, N) int32 distance matrix.
        """
        M = query_arr.shape[0]
        N = target_arr.shape[0]

        dists = np.zeros((M, N), dtype=np.int32)

        # Build the 256-entry popcount table once and cache it across calls
        # (same construction as the GPU table in _ensure_gpu_table).
        if self._popcount_table_cpu is None:
            self._popcount_table_cpu = np.array(
                [bin(i).count("1") for i in range(256)], dtype=np.int32
            )
        popcount_table = self._popcount_table_cpu

        # One query row at a time keeps peak memory at O(N * B) rather than
        # materializing the full (M, N, B) XOR tensor.
        for i in range(M):
            xor_result = np.bitwise_xor(target_arr, query_arr[i])
            dists[i] = popcount_table[xor_result].sum(axis=1)

        return dists

    def _search_gpu(self, query_arr: np.ndarray, target_arr: np.ndarray) -> np.ndarray:
        """PyTorch-based batch Hamming distance on CUDA/MPS.

        Args:
            query_arr: (M, B) uint8 array of packed query bits.
            target_arr: (N, B) uint8 array of packed target bits.

        Returns:
            (M, N) int32 distance matrix, copied back to host memory.
        """
        self._ensure_gpu_table()

        q_tensor = torch.from_numpy(query_arr).to(self.device)
        t_tensor = torch.from_numpy(target_arr).to(self.device)

        M = q_tensor.shape[0]
        N = t_tensor.shape[0]

        dists = torch.zeros((M, N), dtype=torch.int32, device=self.device)

        for i in range(M):
            # XOR marks differing bits; the lookup table counts them per byte.
            xor_result = torch.bitwise_xor(t_tensor, q_tensor[i])
            # Advanced indexing requires int64 indices.
            counts = self.popcount_table_gpu[xor_result.long()]
            dists[i] = counts.sum(dim=1)

        return dists.cpu().numpy()
|
|
|