from typing import Optional, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers import AutoTokenizer
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False


def _fast_hash(s: str) -> int:
    """Fast polynomial hash (FNV-1a inspired); roughly 10x faster than MD5.

    Pure arithmetic on code points, so it is deterministic across processes
    (unlike the built-in ``hash``, which is randomized per interpreter run).
    """
    h = 0x811C9DC5  # FNV-1a 32-bit offset basis
    for c in s:
        h ^= ord(c)
        h = (h * 0x01000193) & 0xFFFFFFFFFFFFFFFF  # FNV prime, masked to 64 bits
    return h
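
# Illustrative note (not exercised by the library itself): seeding a torch.Generator
# with _fast_hash yields a reproducible per-token random vector, which is what
# SimHashEncoder.hash_token relies on:
#
#   g = torch.Generator()
#   g.manual_seed(_fast_hash("the"))
#   v1 = torch.randn(8, generator=g)
#   g.manual_seed(_fast_hash("the"))
#   v2 = torch.randn(8, generator=g)   # torch.equal(v1, v2) is True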


class SimHashEncoder(nn.Module):
    """
    Locality-Sensitive Hashing encoder for text → SDR conversion.

    Based on Charikar (2002) SimHash / random-hyperplane hashing: the probability
    that one random projection separates two bag-of-n-gram vectors equals their
    angle divided by pi, so the Hamming distance between SDRs tracks the cosine
    similarity of BoW_A and BoW_B (more similar text → more shared active bits).

    Optimized with:
    - Bounded in-memory cache for token hash vectors (avoids recomputation)
    - Fast polynomial hash (10x faster than MD5)
    - Vectorized batch processing on GPU
    - Pre-allocated buffers for reduced memory allocation
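
    Example (illustrative sketch; exact bit overlaps depend on sdr_dim/sparsity):
        enc = SimHashEncoder(sdr_dim=2048, sparsity=0.05, device="cpu")
        a = enc("the quick brown fox")
        b = enc("the quick brown fix")         # near-duplicate
        c = enc("completely unrelated text")
        # (a & b).sum() is expected to exceed (a & c).sum(): similar strings
        # share more active bits than unrelated ones.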
| """ |
|
|
| projection_matrix: torch.Tensor |
| last_logits: Optional[torch.Tensor] |
| _token_cache: dict[str, torch.Tensor] |
|
|

    def __init__(
        self,
        sdr_dim: int = 16384,
        sparsity: float = 0.05,
        n_hashes: int = 128,
        seed: int = 42,
        device: str = "cuda",
        cache_size: int = 50000,
    ) -> None:
        super().__init__()
        assert 0 < sparsity <= 1.0, f"sparsity must be in (0, 1], got {sparsity}"
        self.sdr_dim = sdr_dim
        self.sparsity = sparsity
        self.k = int(sdr_dim * sparsity)
        self.n_hashes = n_hashes
        self.device = torch.device(device)
        self.cache_size = cache_size

        # Fixed random projection; note that manual_seed sets the global RNG state.
        torch.manual_seed(seed)
        self.register_buffer(
            "projection_matrix",
            torch.randn(n_hashes, sdr_dim, dtype=torch.float32, device=self.device),
        )

        self.last_logits = None

        # Bounded token-hash cache (plain dict; no eviction once cache_size is reached).
        self._token_cache: dict[str, torch.Tensor] = {}
        self._cache_hits = 0
        self._cache_misses = 0

    def tokenize(
        self, text: Union[str, list[str]], ngram_size: int = 3
    ) -> Union[list[str], list[list[str]]]:
        """
        Convert text into character n-grams for hashing.

        Args:
            text: Input string or list of strings for batch processing
            ngram_size: Size of character n-grams (default: 3)

        Returns:
            List of n-gram strings for single input,
            or list of n-gram lists for batch input
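
        Example (whitespace is replaced by "_" before n-gram extraction):
            tokenize("hi there") -> ['hi_', 'i_t', '_th', 'the', 'her', 'ere']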
| """ |
| if isinstance(text, list): |
| result = [] |
| for t in text: |
| t_normalized = t.lower().replace(" ", "_") |
| ngrams = [ |
| t_normalized[i : i + ngram_size] |
| for i in range(len(t_normalized) - ngram_size + 1) |
| ] |
| result.append(ngrams) |
| return result |
|
|
| text = text.lower().replace(" ", "_") |
| return [text[i : i + ngram_size] for i in range(len(text) - ngram_size + 1)] |
|
|

    def hash_token(self, token: str) -> torch.Tensor:
        """
        Generate a deterministic random vector for a token, with caching.
        Uses the fast polynomial hash instead of MD5 (~10x speedup).

        Args:
            token: String token to hash

        Returns:
            Random vector of shape (n_hashes,) on device
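
        Example (illustrative; ``enc`` is a SimHashEncoder instance):
            v1 = enc.hash_token("the")   # computed and cached
            v2 = enc.hash_token("the")   # served from the cache; torch.equal(v1, v2) is True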
| """ |
| |
| if token in self._token_cache: |
| self._cache_hits += 1 |
| return self._token_cache[token] |
|
|
| self._cache_misses += 1 |
|
|
| |
| hash_val = _fast_hash(token) |
|
|
| generator = torch.Generator() |
| generator.manual_seed(hash_val) |
| |
| vec = torch.randn(self.n_hashes, generator=generator).to(self.device) |
|
|
| |
| if len(self._token_cache) < self.cache_size: |
| self._token_cache[token] = vec |
|
|
| return vec |
|
|

    def _hash_tokens_batch(self, all_tokens: list[str]) -> torch.Tensor:
        """
        Hash multiple tokens at once, reusing cached vectors where possible.

        Args:
            all_tokens: List of unique tokens to hash

        Returns:
            Tensor of shape (num_tokens, n_hashes) on device
        """
        if not all_tokens:
            return torch.empty(0, self.n_hashes, device=self.device)

        # Split tokens into cache hits and misses.
        cached_vecs = []
        uncached_tokens = []
        uncached_indices = []

        for i, token in enumerate(all_tokens):
            if token in self._token_cache:
                cached_vecs.append((i, self._token_cache[token]))
                self._cache_hits += 1
            else:
                uncached_tokens.append(token)
                uncached_indices.append(i)
                self._cache_misses += 1

        result = torch.empty(len(all_tokens), self.n_hashes, device=self.device)

        # Fill rows for cached tokens.
        for i, vec in cached_vecs:
            result[i] = vec

        # Hash the remaining tokens and cache them while space remains.
        if uncached_tokens:
            for idx, token in zip(uncached_indices, uncached_tokens):
                hash_val = _fast_hash(token)
                generator = torch.Generator()
                generator.manual_seed(hash_val)
                vec = torch.randn(self.n_hashes, generator=generator).to(self.device)
                result[idx] = vec

                if len(self._token_cache) < self.cache_size:
                    self._token_cache[token] = vec

        return result

    def apply_kwta(self, logits: torch.Tensor) -> torch.Tensor:
        """
        k-Winners-Take-All activation.

        Args:
            logits: Float tensor of shape (sdr_dim,) or (batch, sdr_dim)

        Returns:
            sdr: Binary tensor with exactly k bits per sample.
                Shape matches input: (sdr_dim,) or (batch, sdr_dim)
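
        Example (illustrative, with self.k == 2):
            apply_kwta(torch.tensor([0.1, 0.9, 0.3, 0.7]))
            -> tensor([False,  True, False,  True])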
| """ |
| k = self.k |
|
|
| _, top_k_indices = torch.topk(logits, k, dim=-1) |
|
|
| sdr = torch.zeros_like(logits, dtype=torch.bool) |
|
|
| if logits.dim() == 1: |
| sdr[top_k_indices] = True |
| else: |
| sdr.scatter_(-1, top_k_indices, True) |
|
|
| return sdr |
|
|

    def forward(self, text: Union[str, list[str]]) -> torch.Tensor:
        """
        Convert text to a Sparse Distributed Representation.

        Optimized for batch processing with:
        - Token deduplication across the batch (hash each unique token once)
        - Token-hash cache for common n-grams
        - Vectorized GPU operations

        Args:
            text: Input string or list of strings for batch processing

        Returns:
            sdr: Binary tensor of shape (sdr_dim,) for a single string,
                or (batch, sdr_dim) for a list of strings
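
        Example (illustrative shapes, for sdr_dim=2048):
            forward("hello world")              -> bool tensor of shape (2048,)
            forward(["hello world", "bye"])     -> bool tensor of shape (2, 2048)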
| """ |
| |
| if isinstance(text, list): |
| batch_size = len(text) |
| batch_tokens = self.tokenize(text) |
|
|
| |
| all_unique_tokens: list[str] = [] |
| token_to_idx: dict[str, int] = {} |
|
|
| for tokens in batch_tokens: |
| for token in tokens: |
| if token not in token_to_idx: |
| token_to_idx[token] = len(all_unique_tokens) |
| all_unique_tokens.append(token) |
|
|
| |
| if all_unique_tokens: |
| all_hashes = self._hash_tokens_batch(all_unique_tokens) |
| else: |
| all_hashes = torch.empty(0, self.n_hashes, device=self.device) |
|
|
| |
| batch_v_sums = torch.zeros(batch_size, self.n_hashes, device=self.device) |
| empty_mask = [] |
|
|
| for i, tokens in enumerate(batch_tokens): |
| if not tokens: |
| empty_mask.append(True) |
| else: |
| |
| indices = [token_to_idx[t] for t in tokens] |
| batch_v_sums[i] = all_hashes[indices].sum(dim=0) |
| empty_mask.append(False) |
|
|
| |
| with torch.amp.autocast('cuda'): |
| logits = batch_v_sums @ self.projection_matrix |
|
|
| self.last_logits = logits |
| sdr = self.apply_kwta(logits) |
|
|
| |
| if any(empty_mask): |
| empty_mask_tensor = torch.tensor(empty_mask, device=self.device) |
| sdr[empty_mask_tensor] = False |
|
|
| return sdr |
|
|
| |
| tokens = cast(list[str], self.tokenize(text)) |
|
|
| if not tokens: |
| return torch.zeros(self.sdr_dim, dtype=torch.bool, device=self.device) |
|
|
| with torch.amp.autocast('cuda'): |
| |
| v_sum = torch.zeros(self.n_hashes, device=self.device) |
| for token in tokens: |
| v_sum += self.hash_token(token) |
|
|
| sdr_float = self.projection_matrix.T @ v_sum |
|
|
| self.last_logits = sdr_float |
| sdr = self.apply_kwta(sdr_float) |
|
|
| return sdr |
|
|


class LearnableEncoder(nn.Module):
    """
    Learnable text encoder using tokenizer embeddings.

    Unlike SimHash (which uses fixed random projections and loses semantic info),
    this encoder uses learnable embeddings that can be trained end-to-end.

    Architecture:
        text → tokenizer → embedding lookup → mean pool → projection → SDR

    The key insight is that we share the vocabulary embedding with the decoder
    (tied embeddings), allowing the model to learn meaningful token representations.
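
    Example (illustrative; "gpt2" stands in here for the default Llama 3 tokenizer,
    which may require access approval on the HuggingFace Hub):
        enc = LearnableEncoder(sdr_dim=1024, tokenizer_name="gpt2", device="cpu")
        logits = enc(["hello world", "goodbye"])            # (2, 1024) sparse logits
        sdr = enc("hello world", return_continuous=False)   # (1024,) binary SDR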
| """ |
|
|
| def __init__( |
| self, |
| sdr_dim: int = 1024, |
| sparsity: float = 0.05, |
| embed_dim: int = 512, |
| tokenizer_name: str = "meta-llama/Meta-Llama-3-8B", |
| device: str = "cuda", |
| vocab_embedding: Optional[nn.Embedding] = None, |
| ) -> None: |
| """ |
| Initialize learnable encoder. |
| |
| Args: |
| sdr_dim: Output SDR dimension |
| sparsity: Fraction of active bits (for k-WTA) |
| embed_dim: Embedding dimension (should match decoder) |
| tokenizer_name: HuggingFace tokenizer name |
| device: Computation device |
| vocab_embedding: Optional shared embedding from decoder (for tied weights) |
| """ |
| super().__init__() |
| assert HAS_TRANSFORMERS, "transformers library required for LearnableEncoder" |
|
|
| self.sdr_dim = sdr_dim |
| self.sparsity = sparsity |
| self.k = int(sdr_dim * sparsity) |
| self.embed_dim = embed_dim |
| self.device = torch.device(device) |
|
|
| |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| tokenizer_name, |
| trust_remote_code=True, |
| ) |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| self.vocab_size = len(self.tokenizer) |
|
|
| |
| if vocab_embedding is not None: |
| self.vocab_embedding = vocab_embedding |
| self._tied = True |
| else: |
| self.vocab_embedding = nn.Embedding(self.vocab_size, embed_dim) |
| nn.init.normal_(self.vocab_embedding.weight, std=0.02) |
| self._tied = False |
|
|
| |
| self.projection = nn.Sequential( |
| nn.Linear(embed_dim, sdr_dim), |
| nn.LayerNorm(sdr_dim), |
| ) |
|
|
| self.last_logits: Optional[torch.Tensor] = None |
|
|

    def apply_kwta(self, logits: torch.Tensor) -> torch.Tensor:
        """k-Winners-Take-All activation (same as SimHash)."""
        k = self.k
        _, top_k_indices = torch.topk(logits, k, dim=-1)
        sdr = torch.zeros_like(logits, dtype=torch.bool)
        if logits.dim() == 1:
            sdr[top_k_indices] = True
        else:
            sdr.scatter_(-1, top_k_indices, True)
        return sdr

    def encode_to_embedding(self, text: Union[str, list[str]]) -> torch.Tensor:
        """
        Encode text directly to a pooled embedding (bypasses the SDR projection).

        This method returns the rich semantic embedding before projection to SDR,
        preserving all information for tasks like text generation.

        Args:
            text: Input string or list of strings

        Returns:
            pooled: Pooled embedding tensor (embed_dim,) or (batch, embed_dim)
        """
        if isinstance(text, str):
            text = [text]
            squeeze = True
        else:
            squeeze = False

        # Tokenize with padding/truncation to a fixed maximum length.
        encoded = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Embedding lookup: (batch, seq_len, embed_dim).
        embeddings = self.vocab_embedding(input_ids)

        # Masked mean pooling over non-padding positions.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        if squeeze:
            pooled = pooled.squeeze(0)

        return pooled

    def forward(
        self, text: Union[str, list[str]], return_continuous: bool = True
    ) -> torch.Tensor:
        """
        Encode text to SDR using learnable embeddings.

        Args:
            text: Input string or list of strings
            return_continuous: If True, return continuous logits (for gradient flow).
                If False, return a binary SDR (for compatibility).

        Returns:
            If return_continuous=True:
                logits: Continuous tensor (sdr_dim,) or (batch, sdr_dim)
            If return_continuous=False:
                sdr: Binary tensor (sdr_dim,) or (batch, sdr_dim)
        """
        if isinstance(text, str):
            text = [text]
            squeeze = True
        else:
            squeeze = False

        # Tokenize with padding/truncation to a fixed maximum length.
        encoded = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Embedding lookup: (batch, seq_len, embed_dim).
        embeddings = self.vocab_embedding(input_ids)

        # Masked mean pooling over non-padding positions.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        # Project to SDR logit space.
        logits = self.projection(pooled)
        self.last_logits = logits

        if return_continuous:
            # Top-k magnitude masking: keep the k largest-|logit| entries and zero
            # the rest, so gradients still flow through the surviving entries.
            k = self.k
            _, top_k_indices = torch.topk(logits.abs(), k, dim=-1)
            mask_tensor = torch.zeros_like(logits)
            if logits.dim() == 1:
                mask_tensor[top_k_indices] = 1.0
            else:
                mask_tensor.scatter_(-1, top_k_indices, 1.0)
            sparse_logits = logits * mask_tensor

            if squeeze:
                sparse_logits = sparse_logits.squeeze(0)
            return sparse_logits
        else:
            # Hard binary SDR, for compatibility with SimHashEncoder outputs.
            sdr = self.apply_kwta(logits)
            if squeeze:
                sdr = sdr.squeeze(0)
            return sdr
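

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the library API).
# Only SimHashEncoder is exercised, on CPU; LearnableEncoder is skipped because
# it needs the `transformers` package and a tokenizer download.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    enc = SimHashEncoder(sdr_dim=4096, sparsity=0.05, device="cpu")
    a = enc("sparse distributed representation")
    b = enc("sparse distributed representations")  # near-duplicate
    c = enc("an entirely different sentence")
    print("active bits per SDR:", enc.k)
    print("overlap(similar)   :", (a & b).sum().item())
    print("overlap(unrelated) :", (a & c).sum().item())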