# sem-v6-training/src/sem_v6/modules/module_a.py
from typing import Optional, Union, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
# Optional: Use transformers tokenizer if available
try:
from transformers import AutoTokenizer
HAS_TRANSFORMERS = True
except ImportError:
HAS_TRANSFORMERS = False
def _fast_hash(s: str) -> int:
"""Fast polynomial hash (FNV-1a inspired) - 10x faster than MD5."""
h = 0x811c9dc5 # FNV offset basis
for c in s:
h ^= ord(c)
        h = (h * 0x01000193) & 0xFFFFFFFFFFFFFFFF  # 32-bit FNV prime; result kept in 64 bits
return h
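# A minimal determinism sketch (illustrative only, not part of the original file):
#   assert _fast_hash("abc") == _fast_hash("abc")   # same token -> same seed
#   assert 0 <= _fast_hash("abc") < 2 ** 64         # masked to 64 bits above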
class SimHashEncoder(nn.Module):
"""
Locality-Sensitive Hashing encoder for text → SDR conversion.
    Based on Charikar (2002) SimHash: the Hamming distance between two SDRs
    approximately tracks the cosine distance between the underlying
    bag-of-n-gram vectors, so similar texts yield overlapping SDRs.
    Optimized with:
    - Bounded token-hash cache (avoids recomputing common n-grams)
    - Fast polynomial hash (10x faster than MD5)
- Vectorized batch processing on GPU
- Pre-allocated buffers for reduced memory allocation
"""
projection_matrix: torch.Tensor
last_logits: Optional[torch.Tensor]
_token_cache: dict[str, torch.Tensor]
def __init__(
self,
sdr_dim: int = 16384,
sparsity: float = 0.05,
n_hashes: int = 128,
seed: int = 42,
device: str = "cuda",
cache_size: int = 50000, # Cache most common n-grams
) -> None:
super().__init__()
assert 0 < sparsity <= 1.0, f"sparsity must be in (0, 1], got {sparsity}"
self.sdr_dim = sdr_dim
self.sparsity = sparsity
self.k = int(sdr_dim * sparsity)
self.n_hashes = n_hashes
self.device = torch.device(device)
self.cache_size = cache_size
torch.manual_seed(seed)
self.register_buffer(
"projection_matrix",
torch.randn(n_hashes, sdr_dim, dtype=torch.float32, device=self.device),
)
self.last_logits = None
        # Token-hash cache: bounded (not a true LRU); cached vectors already live on the target device
self._token_cache: dict[str, torch.Tensor] = {}
self._cache_hits = 0
self._cache_misses = 0
def tokenize(
self, text: Union[str, list[str]], ngram_size: int = 3
) -> Union[list[str], list[list[str]]]:
"""
Convert text into character n-grams for hashing.
Args:
text: Input string or list of strings for batch processing
ngram_size: Size of character n-grams (default: 3)
Returns:
List of n-gram strings for single input,
or list of n-gram lists for batch input
"""
if isinstance(text, list):
result = []
for t in text:
t_normalized = t.lower().replace(" ", "_")
ngrams = [
t_normalized[i : i + ngram_size]
for i in range(len(t_normalized) - ngram_size + 1)
]
result.append(ngrams)
return result
text = text.lower().replace(" ", "_")
return [text[i : i + ngram_size] for i in range(len(text) - ngram_size + 1)]
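    # Illustration of tokenize() above (not part of the original file):
    #   tokenize("Hello World") normalizes to "hello_world" and yields
    #   ['hel', 'ell', 'llo', 'lo_', 'o_w', '_wo', 'wor', 'orl', 'rld']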
def hash_token(self, token: str) -> torch.Tensor:
"""
Generate deterministic random vector for a token with caching.
Uses fast polynomial hash instead of MD5 (10x speedup).
Args:
token: String token to hash
Returns:
Random vector of shape (n_hashes,) on device
"""
# Check cache first
if token in self._token_cache:
self._cache_hits += 1
return self._token_cache[token]
self._cache_misses += 1
# Fast polynomial hash (FNV-1a inspired)
hash_val = _fast_hash(token)
generator = torch.Generator() # CPU generator
generator.manual_seed(hash_val)
# Generate on CPU, then transfer to GPU
vec = torch.randn(self.n_hashes, generator=generator).to(self.device)
# Cache if not full
if len(self._token_cache) < self.cache_size:
self._token_cache[token] = vec
return vec
def _hash_tokens_batch(self, all_tokens: list[str]) -> torch.Tensor:
"""
        Hash multiple tokens at once, reusing cached vectors and computing only cache misses.
Args:
all_tokens: List of unique tokens to hash
Returns:
Tensor of shape (num_tokens, n_hashes) on device
"""
if not all_tokens:
return torch.empty(0, self.n_hashes, device=self.device)
# Separate cached vs uncached
cached_vecs = []
uncached_tokens = []
uncached_indices = []
for i, token in enumerate(all_tokens):
if token in self._token_cache:
cached_vecs.append((i, self._token_cache[token]))
self._cache_hits += 1
else:
uncached_tokens.append(token)
uncached_indices.append(i)
self._cache_misses += 1
# Allocate result tensor
result = torch.empty(len(all_tokens), self.n_hashes, device=self.device)
# Fill cached values
for i, vec in cached_vecs:
result[i] = vec
# Batch compute uncached (still sequential but minimized)
if uncached_tokens:
for idx, token in zip(uncached_indices, uncached_tokens):
hash_val = _fast_hash(token)
generator = torch.Generator() # CPU generator
generator.manual_seed(hash_val)
# Generate on CPU, then transfer to GPU
vec = torch.randn(self.n_hashes, generator=generator).to(self.device)
result[idx] = vec
# Cache if not full
if len(self._token_cache) < self.cache_size:
self._token_cache[token] = vec
return result
def apply_kwta(self, logits: torch.Tensor) -> torch.Tensor:
"""
k-Winners-Take-All activation.
Args:
logits: Float tensor of shape (sdr_dim,) or (batch, sdr_dim)
Returns:
sdr: Binary tensor with exactly k bits per sample
Shape matches input: (sdr_dim,) or (batch, sdr_dim)
"""
k = self.k
_, top_k_indices = torch.topk(logits, k, dim=-1)
sdr = torch.zeros_like(logits, dtype=torch.bool)
if logits.dim() == 1:
sdr[top_k_indices] = True
else:
sdr.scatter_(-1, top_k_indices, True)
return sdr
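    # Illustration of apply_kwta() above (not part of the original file): with the
    # defaults sdr_dim=16384 and sparsity=0.05, k = int(16384 * 0.05) = 819, so the
    # returned boolean tensor has exactly 819 True entries per sample.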
def forward(self, text: Union[str, list[str]]) -> torch.Tensor:
"""
Convert text to Sparse Distributed Representation.
Optimized for batch processing with:
- Token deduplication across batch (hash each unique token once)
- LRU cache for common n-grams
- Vectorized GPU operations
Args:
text: Input string or list of strings for batch processing
Returns:
sdr: Binary tensor of shape (sdr_dim,) for single string,
or (batch, sdr_dim) for list of strings
"""
# Handle batch processing
if isinstance(text, list):
batch_size = len(text)
batch_tokens = self.tokenize(text)
# Collect ALL unique tokens across entire batch for deduplication
all_unique_tokens: list[str] = []
token_to_idx: dict[str, int] = {}
for tokens in batch_tokens:
for token in tokens:
if token not in token_to_idx:
token_to_idx[token] = len(all_unique_tokens)
all_unique_tokens.append(token)
# Hash all unique tokens at once (leverages cache)
if all_unique_tokens:
all_hashes = self._hash_tokens_batch(all_unique_tokens) # (num_unique, n_hashes)
else:
all_hashes = torch.empty(0, self.n_hashes, device=self.device)
# Build per-sample v_sum using index lookups (vectorized)
batch_v_sums = torch.zeros(batch_size, self.n_hashes, device=self.device)
empty_mask = []
for i, tokens in enumerate(batch_tokens):
if not tokens:
empty_mask.append(True)
else:
# Sum hashes for this sample's tokens using index lookup
indices = [token_to_idx[t] for t in tokens]
batch_v_sums[i] = all_hashes[indices].sum(dim=0)
empty_mask.append(False)
# Vectorized projection: (batch, n_hashes) @ (n_hashes, sdr_dim) = (batch, sdr_dim)
        with torch.amp.autocast(self.device.type, enabled=self.device.type == "cuda"):
logits = batch_v_sums @ self.projection_matrix
self.last_logits = logits
sdr = self.apply_kwta(logits)
# Force empty inputs to have zero SDRs
if any(empty_mask):
empty_mask_tensor = torch.tensor(empty_mask, device=self.device)
sdr[empty_mask_tensor] = False
return sdr
# Single string case
tokens = cast(list[str], self.tokenize(text))
if not tokens:
return torch.zeros(self.sdr_dim, dtype=torch.bool, device=self.device)
        with torch.amp.autocast(self.device.type, enabled=self.device.type == "cuda"):
# Use cached token hashes
v_sum = torch.zeros(self.n_hashes, device=self.device)
for token in tokens:
v_sum += self.hash_token(token)
sdr_float = self.projection_matrix.T @ v_sum
self.last_logits = sdr_float
sdr = self.apply_kwta(sdr_float)
return sdr
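# Usage sketch for SimHashEncoder (illustrative only; the strings and settings below
# are assumptions, not from the source). Pass device="cuda" when a GPU is available:
#   enc = SimHashEncoder(sdr_dim=16384, sparsity=0.05, device="cpu")
#   a = enc("the quick brown fox")
#   b = enc("the quick brown dog")
#   c = enc("quarterly revenue forecast")
#   # SimHash property: similar texts share more active bits (in expectation)
#   overlap_similar = (a & b).sum().item()
#   overlap_dissimilar = (a & c).sum().item()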
class LearnableEncoder(nn.Module):
"""
Learnable text encoder using tokenizer embeddings.
Unlike SimHash (which uses fixed random projections and loses semantic info),
this encoder uses learnable embeddings that can be trained end-to-end.
Architecture:
text → tokenizer → embedding lookup → mean pool → projection → SDR
    The vocabulary embedding can optionally be shared with the decoder (tied
    embeddings), allowing the model to learn meaningful token representations.
"""
def __init__(
self,
sdr_dim: int = 1024,
sparsity: float = 0.05,
embed_dim: int = 512,
tokenizer_name: str = "meta-llama/Meta-Llama-3-8B",
device: str = "cuda",
vocab_embedding: Optional[nn.Embedding] = None, # Shared with decoder
) -> None:
"""
Initialize learnable encoder.
Args:
sdr_dim: Output SDR dimension
sparsity: Fraction of active bits (for k-WTA)
embed_dim: Embedding dimension (should match decoder)
tokenizer_name: HuggingFace tokenizer name
device: Computation device
vocab_embedding: Optional shared embedding from decoder (for tied weights)
"""
super().__init__()
assert HAS_TRANSFORMERS, "transformers library required for LearnableEncoder"
self.sdr_dim = sdr_dim
self.sparsity = sparsity
self.k = int(sdr_dim * sparsity)
self.embed_dim = embed_dim
self.device = torch.device(device)
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=True,
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.vocab_size = len(self.tokenizer)
# Vocabulary embedding (can be shared with decoder for tied weights)
if vocab_embedding is not None:
self.vocab_embedding = vocab_embedding
self._tied = True
else:
self.vocab_embedding = nn.Embedding(self.vocab_size, embed_dim)
nn.init.normal_(self.vocab_embedding.weight, std=0.02)
self._tied = False
# Projection: embed_dim → sdr_dim (learnable)
self.projection = nn.Sequential(
nn.Linear(embed_dim, sdr_dim),
nn.LayerNorm(sdr_dim),
)
        # Keep parameters on the configured device (forward() moves inputs there)
        self.to(self.device)
        self.last_logits: Optional[torch.Tensor] = None
def apply_kwta(self, logits: torch.Tensor) -> torch.Tensor:
"""k-Winners-Take-All activation (same as SimHash)."""
k = self.k
_, top_k_indices = torch.topk(logits, k, dim=-1)
sdr = torch.zeros_like(logits, dtype=torch.bool)
if logits.dim() == 1:
sdr[top_k_indices] = True
else:
sdr.scatter_(-1, top_k_indices, True)
return sdr
def encode_to_embedding(self, text: Union[str, list[str]]) -> torch.Tensor:
"""
Encode text directly to pooled embedding (bypasses SDR projection).
        This method returns the pooled embedding before the SDR projection,
        retaining more information than the binary SDR for tasks like text generation.
Args:
text: Input string or list of strings
Returns:
pooled: Pooled embedding tensor (embed_dim,) or (batch, embed_dim)
"""
# Handle single string
if isinstance(text, str):
text = [text]
squeeze = True
else:
squeeze = False
# Tokenize
encoded = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=64,
return_tensors="pt",
)
input_ids = encoded["input_ids"].to(self.device)
attention_mask = encoded["attention_mask"].to(self.device)
# Embed tokens
embeddings = self.vocab_embedding(input_ids) # (batch, seq_len, embed_dim)
# Mean pooling over sequence (masked)
mask = attention_mask.unsqueeze(-1).float()
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
if squeeze:
pooled = pooled.squeeze(0)
return pooled
def forward(
self, text: Union[str, list[str]], return_continuous: bool = True
) -> torch.Tensor:
"""
Encode text to SDR using learnable embeddings.
Args:
text: Input string or list of strings
return_continuous: If True, return continuous logits (for gradient flow).
If False, return binary SDR (for compatibility).
Returns:
If return_continuous=True:
logits: Continuous tensor (sdr_dim,) or (batch, sdr_dim)
If return_continuous=False:
sdr: Binary tensor (sdr_dim,) or (batch, sdr_dim)
"""
# Handle single string
if isinstance(text, str):
text = [text]
squeeze = True
else:
squeeze = False
# Tokenize
encoded = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=64,
return_tensors="pt",
)
input_ids = encoded["input_ids"].to(self.device) # (batch, seq_len)
attention_mask = encoded["attention_mask"].to(self.device)
# Embed tokens
embeddings = self.vocab_embedding(input_ids) # (batch, seq_len, embed_dim)
# Mean pooling over sequence (masked)
mask = attention_mask.unsqueeze(-1).float() # (batch, seq_len, 1)
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) # (batch, embed_dim)
# Project to SDR space
logits = self.projection(pooled) # (batch, sdr_dim)
self.last_logits = logits
if return_continuous:
# Return continuous logits for gradient flow
# Apply soft sparsity with top-k gating for differentiability
k = self.k
_, top_k_indices = torch.topk(logits.abs(), k, dim=-1)
mask_tensor = torch.zeros_like(logits)
if logits.dim() == 1:
mask_tensor[top_k_indices] = 1.0
else:
mask_tensor.scatter_(-1, top_k_indices, 1.0)
# Keep values at top-k positions, zero elsewhere
sparse_logits = logits * mask_tensor
if squeeze:
sparse_logits = sparse_logits.squeeze(0)
return sparse_logits
else:
# Return binary SDR (original behavior)
sdr = self.apply_kwta(logits)
if squeeze:
sdr = sdr.squeeze(0)
return sdr
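# Usage sketch for LearnableEncoder (illustrative only; not from the source). Assumes
# the transformers library is installed and the configured tokenizer is accessible
# (Meta-Llama-3-8B is gated, so substitute any available tokenizer if needed):
#   enc = LearnableEncoder(sdr_dim=1024, embed_dim=512, device="cpu")
#   sparse_logits = enc(["a cat sat", "a dog ran"])           # (2, 1024), differentiable
#   binary_sdr = enc("a cat sat", return_continuous=False)    # (1024,), boolean SDR
#   pooled = enc.encode_to_embedding("a cat sat")             # (512,) pooled embedding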