from functools import lru_cache
from typing import Optional, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional: Use transformers tokenizer if available
try:
    from transformers import AutoTokenizer

    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False


def _fast_hash(s: str) -> int:
    """Fast polynomial hash (FNV-1a inspired) - 10x faster than MD5."""
    h = 0x811c9dc5  # 32-bit FNV offset basis
    for c in s:
        h ^= ord(c)
        h = (h * 0x01000193) & 0xFFFFFFFFFFFFFFFF  # 32-bit FNV prime, 64-bit accumulator
    return h


class SimHashEncoder(nn.Module):
    """
    Locality-Sensitive Hashing encoder for text → SDR conversion.

    Based on Charikar (2002) SimHash: the overlap between two SDRs tracks the
    cosine similarity of the underlying bag-of-n-gram vectors, i.e.
    Hamming(SDR_A, SDR_B) is small when CosineSim(BoW_A, BoW_B) is high.

    Optimized with:
    - Size-capped token-hash cache (avoids recomputation)
    - Fast polynomial hash (10x faster than MD5)
    - Vectorized batch processing on GPU
    - Pre-allocated buffers for reduced memory allocation
    """

    projection_matrix: torch.Tensor
    last_logits: Optional[torch.Tensor]
    _token_cache: dict[str, torch.Tensor]

    def __init__(
        self,
        sdr_dim: int = 16384,
        sparsity: float = 0.05,
        n_hashes: int = 128,
        seed: int = 42,
        device: str = "cuda",
        cache_size: int = 50000,  # Cache most common n-grams
    ) -> None:
        super().__init__()
        assert 0 < sparsity <= 1.0, f"sparsity must be in (0, 1], got {sparsity}"

        self.sdr_dim = sdr_dim
        self.sparsity = sparsity
        self.k = int(sdr_dim * sparsity)
        self.n_hashes = n_hashes
        self.device = torch.device(device)
        self.cache_size = cache_size

        # Seed makes the random projection deterministic across runs
        torch.manual_seed(seed)
        self.register_buffer(
            "projection_matrix",
            torch.randn(n_hashes, sdr_dim, dtype=torch.float32, device=self.device),
        )
        self.last_logits = None

        # Token hash cache - stays on CPU, batch-transferred to GPU
        self._token_cache: dict[str, torch.Tensor] = {}
        self._cache_hits = 0
        self._cache_misses = 0

    def tokenize(
        self, text: Union[str, list[str]], ngram_size: int = 3
    ) -> Union[list[str], list[list[str]]]:
        """
        Convert text into character n-grams for hashing.

        Args:
            text: Input string or list of strings for batch processing
            ngram_size: Size of character n-grams (default: 3)

        Returns:
            List of n-gram strings for single input, or list of n-gram lists
            for batch input
        """
        if isinstance(text, list):
            result = []
            for t in text:
                t_normalized = t.lower().replace(" ", "_")
                ngrams = [
                    t_normalized[i : i + ngram_size]
                    for i in range(len(t_normalized) - ngram_size + 1)
                ]
                result.append(ngrams)
            return result

        text = text.lower().replace(" ", "_")
        return [text[i : i + ngram_size] for i in range(len(text) - ngram_size + 1)]

    def hash_token(self, token: str) -> torch.Tensor:
        """
        Generate deterministic random vector for a token with caching.

        Uses fast polynomial hash instead of MD5 (10x speedup).

        Args:
            token: String token to hash

        Returns:
            Random vector of shape (n_hashes,) on device
        """
        # Check cache first
        if token in self._token_cache:
            self._cache_hits += 1
            return self._token_cache[token]

        self._cache_misses += 1

        # Fast polynomial hash (FNV-1a inspired)
        hash_val = _fast_hash(token)
        generator = torch.Generator()  # CPU generator
        generator.manual_seed(hash_val)

        # Generate on CPU, then transfer to GPU
        vec = torch.randn(self.n_hashes, generator=generator).to(self.device)

        # Cache if not full
        if len(self._token_cache) < self.cache_size:
            self._token_cache[token] = vec

        return vec

    def _hash_tokens_batch(self, all_tokens: list[str]) -> torch.Tensor:
        """
        Hash multiple tokens at once, leveraging cache and batch GPU transfer.

        Args:
            all_tokens: List of unique tokens to hash

        Returns:
            Tensor of shape (num_tokens, n_hashes) on device
        """
        if not all_tokens:
            return torch.empty(0, self.n_hashes, device=self.device)

        # Separate cached vs uncached
        cached_vecs = []
        uncached_tokens = []
        uncached_indices = []

        for i, token in enumerate(all_tokens):
            if token in self._token_cache:
                cached_vecs.append((i, self._token_cache[token]))
                self._cache_hits += 1
            else:
                uncached_tokens.append(token)
                uncached_indices.append(i)
                self._cache_misses += 1

        # Allocate result tensor
        result = torch.empty(len(all_tokens), self.n_hashes, device=self.device)

        # Fill cached values
        for i, vec in cached_vecs:
            result[i] = vec

        # Batch compute uncached (still sequential but minimized)
        if uncached_tokens:
            for idx, token in zip(uncached_indices, uncached_tokens):
                hash_val = _fast_hash(token)
                generator = torch.Generator()  # CPU generator
                generator.manual_seed(hash_val)

                # Generate on CPU, then transfer to GPU
                vec = torch.randn(self.n_hashes, generator=generator).to(self.device)
                result[idx] = vec

                # Cache if not full
                if len(self._token_cache) < self.cache_size:
                    self._token_cache[token] = vec

        return result

    def apply_kwta(self, logits: torch.Tensor) -> torch.Tensor:
        """
        k-Winners-Take-All activation.

        Args:
            logits: Float tensor of shape (sdr_dim,) or (batch, sdr_dim)

        Returns:
            sdr: Binary tensor with exactly k bits per sample.
                 Shape matches input: (sdr_dim,) or (batch, sdr_dim)
        """
        k = self.k
        _, top_k_indices = torch.topk(logits, k, dim=-1)
        sdr = torch.zeros_like(logits, dtype=torch.bool)
        if logits.dim() == 1:
            sdr[top_k_indices] = True
        else:
            sdr.scatter_(-1, top_k_indices, True)
        return sdr

    def forward(self, text: Union[str, list[str]]) -> torch.Tensor:
        """
        Convert text to Sparse Distributed Representation.

        Optimized for batch processing with:
        - Token deduplication across batch (hash each unique token once)
        - Token-hash cache for common n-grams
        - Vectorized GPU operations

        Args:
            text: Input string or list of strings for batch processing

        Returns:
            sdr: Binary tensor of shape (sdr_dim,) for single string,
                 or (batch, sdr_dim) for list of strings
        """
        # Handle batch processing
        if isinstance(text, list):
            batch_size = len(text)
            batch_tokens = self.tokenize(text)

            # Collect ALL unique tokens across entire batch for deduplication
            all_unique_tokens: list[str] = []
            token_to_idx: dict[str, int] = {}
            for tokens in batch_tokens:
                for token in tokens:
                    if token not in token_to_idx:
                        token_to_idx[token] = len(all_unique_tokens)
                        all_unique_tokens.append(token)

            # Hash all unique tokens at once (leverages cache)
            if all_unique_tokens:
                all_hashes = self._hash_tokens_batch(all_unique_tokens)  # (num_unique, n_hashes)
            else:
                all_hashes = torch.empty(0, self.n_hashes, device=self.device)

            # Build per-sample v_sum using index lookups (vectorized)
            batch_v_sums = torch.zeros(batch_size, self.n_hashes, device=self.device)
            empty_mask = []
            for i, tokens in enumerate(batch_tokens):
                if not tokens:
                    empty_mask.append(True)
                else:
                    # Sum hashes for this sample's tokens using index lookup
                    indices = [token_to_idx[t] for t in tokens]
                    batch_v_sums[i] = all_hashes[indices].sum(dim=0)
                    empty_mask.append(False)

            # Vectorized projection: (batch, n_hashes) @ (n_hashes, sdr_dim) = (batch, sdr_dim)
            with torch.amp.autocast('cuda'):
                logits = batch_v_sums @ self.projection_matrix
                self.last_logits = logits
                sdr = self.apply_kwta(logits)

            # Force empty inputs to have zero SDRs
            if any(empty_mask):
                empty_mask_tensor = torch.tensor(empty_mask, device=self.device)
                sdr[empty_mask_tensor] = False

            return sdr

        # Single string case
        tokens = cast(list[str], self.tokenize(text))
        if not tokens:
            return torch.zeros(self.sdr_dim, dtype=torch.bool, device=self.device)

        with torch.amp.autocast('cuda'):
            # Use cached token hashes
            v_sum = torch.zeros(self.n_hashes, device=self.device)
            for token in tokens:
                v_sum += self.hash_token(token)

            sdr_float = self.projection_matrix.T @ v_sum
            self.last_logits = sdr_float
            sdr = self.apply_kwta(sdr_float)

        return sdr

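
# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal, hedged example of the locality property described in the class
# docstring: SDRs of similar strings should share more active bits than SDRs of
# unrelated strings. The helper name, the small sdr_dim, and device="cpu" are
# assumptions chosen so the check runs without a GPU; with device="cpu" the
# hardcoded autocast('cuda') context may emit a warning but is otherwise inert.
def _simhash_overlap_demo() -> None:
    encoder = SimHashEncoder(sdr_dim=2048, sparsity=0.05, device="cpu")
    sdrs = encoder(
        ["the quick brown fox", "the quick brown foxes", "completely unrelated text"]
    )
    # Similar strings share character n-grams, so their SDRs should overlap more.
    overlap_similar = (sdrs[0] & sdrs[1]).sum().item()
    overlap_dissimilar = (sdrs[0] & sdrs[2]).sum().item()
    print(f"overlap(similar)={overlap_similar}, overlap(dissimilar)={overlap_dissimilar}")
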

class LearnableEncoder(nn.Module):
    """
    Learnable text encoder using tokenizer embeddings.

    Unlike SimHash (which uses fixed random projections and loses semantic info),
    this encoder uses learnable embeddings that can be trained end-to-end.

    Architecture:
        text → tokenizer → embedding lookup → mean pool → projection → SDR

    The key insight is that we share the vocabulary embedding with the decoder
    (tied embeddings), allowing the model to learn meaningful token representations.
    """

    def __init__(
        self,
        sdr_dim: int = 1024,
        sparsity: float = 0.05,
        embed_dim: int = 512,
        tokenizer_name: str = "meta-llama/Meta-Llama-3-8B",
        device: str = "cuda",
        vocab_embedding: Optional[nn.Embedding] = None,  # Shared with decoder
    ) -> None:
        """
        Initialize learnable encoder.

        Args:
            sdr_dim: Output SDR dimension
            sparsity: Fraction of active bits (for k-WTA)
            embed_dim: Embedding dimension (should match decoder)
            tokenizer_name: HuggingFace tokenizer name
            device: Computation device
            vocab_embedding: Optional shared embedding from decoder (for tied weights)
        """
        super().__init__()
        assert HAS_TRANSFORMERS, "transformers library required for LearnableEncoder"

        self.sdr_dim = sdr_dim
        self.sparsity = sparsity
        self.k = int(sdr_dim * sparsity)
        self.embed_dim = embed_dim
        self.device = torch.device(device)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = len(self.tokenizer)

        # Vocabulary embedding (can be shared with decoder for tied weights)
        if vocab_embedding is not None:
            self.vocab_embedding = vocab_embedding
            self._tied = True
        else:
            self.vocab_embedding = nn.Embedding(self.vocab_size, embed_dim)
            nn.init.normal_(self.vocab_embedding.weight, std=0.02)
            self._tied = False

        # Projection: embed_dim → sdr_dim (learnable)
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, sdr_dim),
            nn.LayerNorm(sdr_dim),
        )

        self.last_logits: Optional[torch.Tensor] = None

    def apply_kwta(self, logits: torch.Tensor) -> torch.Tensor:
        """k-Winners-Take-All activation (same as SimHash)."""
        k = self.k
        _, top_k_indices = torch.topk(logits, k, dim=-1)
        sdr = torch.zeros_like(logits, dtype=torch.bool)
        if logits.dim() == 1:
            sdr[top_k_indices] = True
        else:
            sdr.scatter_(-1, top_k_indices, True)
        return sdr

    def encode_to_embedding(self, text: Union[str, list[str]]) -> torch.Tensor:
        """
        Encode text directly to pooled embedding (bypasses SDR projection).

        This method returns the rich semantic embedding before projection to SDR,
        preserving all information for tasks like text generation.

        Args:
            text: Input string or list of strings

        Returns:
            pooled: Pooled embedding tensor (embed_dim,) or (batch, embed_dim)
        """
        # Handle single string
        if isinstance(text, str):
            text = [text]
            squeeze = True
        else:
            squeeze = False

        # Tokenize
        encoded = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Embed tokens
        embeddings = self.vocab_embedding(input_ids)  # (batch, seq_len, embed_dim)

        # Mean pooling over sequence (masked)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        if squeeze:
            pooled = pooled.squeeze(0)
        return pooled

    def forward(
        self, text: Union[str, list[str]], return_continuous: bool = True
    ) -> torch.Tensor:
        """
        Encode text to SDR using learnable embeddings.

        Args:
            text: Input string or list of strings
            return_continuous: If True, return continuous logits (for gradient flow).
                               If False, return binary SDR (for compatibility).

        Returns:
            If return_continuous=True:
                logits: Continuous tensor (sdr_dim,) or (batch, sdr_dim)
            If return_continuous=False:
                sdr: Binary tensor (sdr_dim,) or (batch, sdr_dim)
        """
        # Handle single string
        if isinstance(text, str):
            text = [text]
            squeeze = True
        else:
            squeeze = False

        # Tokenize
        encoded = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)  # (batch, seq_len)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Embed tokens
        embeddings = self.vocab_embedding(input_ids)  # (batch, seq_len, embed_dim)

        # Mean pooling over sequence (masked)
        mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
        pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # (batch, embed_dim)

        # Project to SDR space
        logits = self.projection(pooled)  # (batch, sdr_dim)
        self.last_logits = logits

        if return_continuous:
            # Return continuous logits for gradient flow.
            # Apply soft sparsity with top-k gating for differentiability.
            k = self.k
            _, top_k_indices = torch.topk(logits.abs(), k, dim=-1)
            mask_tensor = torch.zeros_like(logits)
            if logits.dim() == 1:
                mask_tensor[top_k_indices] = 1.0
            else:
                mask_tensor.scatter_(-1, top_k_indices, 1.0)
            # Keep values at top-k positions, zero elsewhere
            sparse_logits = logits * mask_tensor
            if squeeze:
                sparse_logits = sparse_logits.squeeze(0)
            return sparse_logits
        else:
            # Return binary SDR (original behavior)
            sdr = self.apply_kwta(logits)
            if squeeze:
                sdr = sdr.squeeze(0)
            return sdr
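

# --- Minimal smoke test (illustrative, not part of the original module) -----
# A hedged sketch of how both encoders are expected to be used. Assumptions:
# device="cpu" so no GPU is needed, and "gpt2" stands in for the default gated
# Llama-3 tokenizer purely so the sketch can run locally; any HuggingFace
# tokenizer with the same interface should work. Requires network access the
# first time the tokenizer is downloaded.
if __name__ == "__main__":
    _simhash_overlap_demo()

    if HAS_TRANSFORMERS:
        enc = LearnableEncoder(
            sdr_dim=1024,
            sparsity=0.05,
            embed_dim=512,
            tokenizer_name="gpt2",  # assumption: substitute for the gated default
            device="cpu",
        )
        # Continuous, top-k-gated logits (default) keep gradients flowing.
        logits = enc(["hello world", "goodbye world"], return_continuous=True)
        print("continuous:", logits.shape, "nonzero per row:", (logits != 0).sum(dim=-1).tolist())
        # Binary SDRs for compatibility with SimHashEncoder consumers.
        sdr = enc("hello world", return_continuous=False)
        print("binary:", sdr.shape, "active bits:", int(sdr.sum()))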