# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.
"""
Core CST Module Implementation
Main module that orchestrates fragment encoding, information fusion, and caching
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Tuple
import hashlib
import json
from collections import OrderedDict
import time
from fragment_encoder import FragmentEncoder
from information_fuser import InformationFuser
class LRUCache:
"""Simple LRU cache implementation for embedding caching"""
def __init__(self, capacity: int):
self.capacity = capacity
self.cache = OrderedDict()
self.hits = 0
self.misses = 0
def get(self, key: str) -> Optional[torch.Tensor]:
if key in self.cache:
# Move to end (most recently used)
self.cache.move_to_end(key)
self.hits += 1
return self.cache[key].clone() # Clone to avoid in-place modifications
else:
self.misses += 1
return None
def put(self, key: str, value: torch.Tensor):
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.capacity:
# Remove least recently used item
self.cache.popitem(last=False)
self.cache[key] = value.clone().detach()
def clear(self):
self.cache.clear()
self.hits = 0
self.misses = 0
def stats(self):
total = self.hits + self.misses
hit_rate = self.hits / total if total > 0 else 0.0
return {
'hits': self.hits,
'misses': self.misses,
'hit_rate': hit_rate,
'cache_size': len(self.cache),
'capacity': self.capacity
}
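
# Minimal usage sketch for LRUCache (illustrative only: the capacity of 2 and
# the string keys below are arbitrary, not values used by the CST pipeline).
def _demo_lru_cache():
    cache = LRUCache(capacity=2)
    cache.put('a', torch.ones(3))
    cache.put('b', torch.zeros(3))
    _ = cache.get('a')                     # hit: 'a' becomes most recently used
    cache.put('c', torch.full((3,), 2.0))  # evicts 'b', the least recently used
    assert cache.get('b') is None          # miss: 'b' was evicted
    print(cache.stats())                   # hits=1, misses=1, cache_size=2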
class AmbiguityClassifier(nn.Module):
"""Determines whether dynamic processing is needed for each fragment"""
def __init__(self, config):
super().__init__()
self.config = config
# Pre-computed ambiguous word vocabulary (loaded during training)
        self.register_buffer(
            'ambiguous_vocab',
            torch.tensor(config.ambiguous_word_ids if config.ambiguous_word_ids else [],
                         dtype=torch.long)  # explicit dtype: an empty list would otherwise default to float
        )
# Context-based ambiguity classifier
context_input_dim = config.fragment_encoding_dim + config.context_feature_dim
self.context_classifier = nn.Sequential(
nn.Linear(context_input_dim, config.hidden_dim),
nn.LayerNorm(config.hidden_dim),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(config.hidden_dim, config.hidden_dim // 2),
nn.GELU(),
nn.Linear(config.hidden_dim // 2, 1),
nn.Sigmoid()
)
# Frequency-based classifier (learns from data)
self.frequency_classifier = nn.Sequential(
nn.Linear(1, 32), # Input: log frequency
nn.ReLU(),
nn.Linear(32, 1),
nn.Sigmoid()
)
# Combination weights
self.combination_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3])) # vocab, context, frequency
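        # Note: these weights are learned but not constrained to sum to 1, so the
        # combined score in forward() is an unnormalized weighted sum that is
        # thresholded directly.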
self.ambiguity_threshold = config.ambiguity_threshold
def forward(self,
fragment_ids: torch.Tensor,
context_features: torch.Tensor,
fragment_frequencies: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Determine ambiguity for each fragment
Args:
fragment_ids: [batch_size] - Fragment token IDs
context_features: [batch_size, context_feature_dim] - Context features
            fragment_frequencies: [batch_size] - Log frequencies of fragments
        Returns:
            [batch_size] bool tensor - True where dynamic processing is needed
        """
batch_size = fragment_ids.size(0)
ambiguity_scores = torch.zeros(batch_size, device=fragment_ids.device)
# 1. Vocabulary-based ambiguity
if len(self.ambiguous_vocab) > 0:
vocab_ambiguous = torch.isin(fragment_ids, self.ambiguous_vocab).float()
ambiguity_scores += self.combination_weights[0] * vocab_ambiguous
# 2. Context-based ambiguity
if context_features.size(1) >= self.config.context_feature_dim:
            # Zero placeholder for the fragment encoding: this gate runs before
            # the fragment encoder, so only the context features carry signal here
            fragment_encoding = torch.zeros(batch_size, self.config.fragment_encoding_dim,
                                            device=fragment_ids.device)
            combined_features = torch.cat(
                [fragment_encoding, context_features[:, :self.config.context_feature_dim]], dim=1)
context_scores = self.context_classifier(combined_features).squeeze(-1)
ambiguity_scores += self.combination_weights[1] * context_scores
        # 3. Frequency-based ambiguity (high-frequency words tend to be more polysemous)
if fragment_frequencies is not None:
freq_scores = self.frequency_classifier(fragment_frequencies.unsqueeze(-1)).squeeze(-1)
ambiguity_scores += self.combination_weights[2] * freq_scores
# Return binary decisions
return ambiguity_scores > self.ambiguity_threshold
def update_ambiguous_vocab(self, new_ambiguous_ids: List[int]):
"""Update the ambiguous vocabulary during training"""
        self.ambiguous_vocab = torch.tensor(new_ambiguous_ids, dtype=torch.long,
                                            device=self.ambiguous_vocab.device)
class ProjectionHead(nn.Module):
"""Projects fused representation to transformer embedding dimension"""
def __init__(self, config):
super().__init__()
self.projection = nn.Sequential(
nn.Linear(config.fused_dim, config.d_model),
nn.LayerNorm(config.d_model),
nn.Tanh(), # Bounded output for stability
nn.Dropout(0.1)
)
# Residual connection option
self.use_residual = config.fused_dim == config.d_model
if not self.use_residual and hasattr(config, 'enable_projection_residual'):
self.residual_proj = nn.Linear(config.fused_dim, config.d_model)
self.use_residual = config.enable_projection_residual
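        # When a residual path is active, the Tanh inside the projection no longer
        # bounds the final output, since the residual term is added after it.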
def forward(self, fused_representation: torch.Tensor) -> torch.Tensor:
output = self.projection(fused_representation)
if self.use_residual:
if hasattr(self, 'residual_proj'):
residual = self.residual_proj(fused_representation)
else:
residual = fused_representation
output = output + residual
return output
class CSTModule(nn.Module):
"""
Main Contextual Spectrum Tokenization Module
Integrates fragment encoding, information fusion, ambiguity detection, and caching
"""
def __init__(self, config):
super().__init__()
self.config = config
# Core components
self.fragment_encoder = FragmentEncoder(config)
self.information_fuser = InformationFuser(config)
self.projection_head = ProjectionHead(config)
self.ambiguity_classifier = AmbiguityClassifier(config)
# Static embeddings fallback
self.static_embeddings = nn.Embedding(config.vocab_size, config.d_model)
# Initialize static embeddings with reasonable values
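        # (std=0.02 follows the common transformer embedding-initialization convention)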
nn.init.normal_(self.static_embeddings.weight, mean=0.0, std=0.02)
# Caching system
self.cache = LRUCache(config.cache_size)
# Performance tracking
self.enable_profiling = False
self.profile_stats = {
'cache_hits': 0,
'cache_misses': 0,
'ambiguous_tokens': 0,
'static_tokens': 0,
'total_forward_time': 0.0,
'num_forward_calls': 0
}
def _compute_cache_key(self, fragment_data: Dict[str, Any], context_data: Dict[str, Any]) -> str:
"""Compute a hash key for caching"""
# Create a simplified representation for hashing
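        # MD5 serves only as a fast, stable key generator here, not a security measure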
        fragment_id = fragment_data.get('fragment_id', '')
        key_components = {
            'fragment_id': fragment_id.item() if torch.is_tensor(fragment_id) else str(fragment_id),
            'context_hash': self._hash_context(context_data)
        }
key_string = json.dumps(key_components, sort_keys=True)
return hashlib.md5(key_string.encode()).hexdigest()
def _hash_context(self, context_data: Dict[str, Any]) -> str:
"""Create a hash of context data for caching"""
context_summary = {}
for key, value in context_data.items():
if isinstance(value, torch.Tensor):
# Use tensor statistics for hashing
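                # Statistics-based hashing is approximate: distinct contexts with
                # identical shape/mean/std would collide and share a cache entry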
context_summary[key] = {
'shape': list(value.shape),
'mean': float(value.mean().item()) if value.numel() > 0 else 0.0,
'std': float(value.std().item()) if value.numel() > 0 else 0.0
}
elif isinstance(value, dict):
context_summary[key] = self._hash_context(value)
else:
context_summary[key] = str(value)
return hashlib.md5(json.dumps(context_summary, sort_keys=True).encode()).hexdigest()[:16]
def _compute_dynamic_embedding(self, fragment_data: Dict[str, Any], context_data: Dict[str, Any]) -> torch.Tensor:
"""Compute dynamic embedding using the full CST pipeline"""
# Extract fragment encoding
fragment_encoding = self.fragment_encoder(
fragment_data['fragment_chars'],
fragment_data['context_chars'],
fragment_data.get('fragment_positions')
)
# Fuse with contextual information
fused_representation = self.information_fuser(fragment_encoding, context_data)
# Project to output space
output_embedding = self.projection_head(fused_representation)
return output_embedding
def forward(self,
text_fragments: torch.Tensor,
context_data: Dict[str, Any],
fragment_chars: Optional[torch.Tensor] = None,
context_chars: Optional[torch.Tensor] = None,
fragment_frequencies: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Main forward pass of CST module
Args:
text_fragments: [batch_size, seq_len] - Token IDs
context_data: Dictionary of contextual information
fragment_chars: [batch_size, seq_len, char_len] - Character-level data
context_chars: [batch_size, seq_len, context_char_len] - Context characters
            fragment_frequencies: [batch_size, seq_len] - Fragment frequencies
        Returns:
            [batch_size, seq_len, d_model] - Contextual embeddings for each token
        """
start_time = time.time() if self.enable_profiling else 0
batch_size, seq_len = text_fragments.shape
device = text_fragments.device
# Initialize output
output_vectors = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)
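        # NOTE: positions and batch samples are processed in Python-level loops;
        # this trades vectorized speed for per-token routing and per-sample caching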
for i in range(seq_len):
fragment_ids = text_fragments[:, i]
# Prepare fragment data
fragment_data = {
'fragment_id': fragment_ids,
'fragment_chars': fragment_chars[:, i] if fragment_chars is not None else None,
'context_chars': context_chars[:, i] if context_chars is not None else None,
'fragment_positions': torch.full((batch_size,), i, device=device)
}
# Prepare context features for ambiguity classification
context_features = torch.zeros(batch_size, self.config.context_feature_dim, device=device)
if 'document_embedding' in context_data:
doc_emb = context_data['document_embedding']
feature_dim = min(self.config.context_feature_dim, doc_emb.size(-1))
context_features[:, :feature_dim] = doc_emb[:, :feature_dim]
# Determine if dynamic processing is needed
freqs = fragment_frequencies[:, i] if fragment_frequencies is not None else None
is_ambiguous = self.ambiguity_classifier(fragment_ids, context_features, freqs)
# Process each sample in the batch
for b in range(batch_size):
if is_ambiguous[b]:
# Try cache first
sample_fragment_data = {k: v[b] if v is not None else None for k, v in fragment_data.items()}
sample_context_data = {k: v[b] if isinstance(v, torch.Tensor) else v for k, v in context_data.items()}
cache_key = self._compute_cache_key(sample_fragment_data, sample_context_data)
cached_vector = self.cache.get(cache_key)
if cached_vector is not None:
output_vectors[b, i] = cached_vector
if self.enable_profiling:
self.profile_stats['cache_hits'] += 1
else:
# Compute dynamic embedding
dynamic_vector = self._compute_dynamic_embedding(sample_fragment_data, sample_context_data)
output_vectors[b, i] = dynamic_vector.squeeze(0) if dynamic_vector.dim() > 1 else dynamic_vector
# Cache the result
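                        # (put() stores a detached clone, so caching this view into
                        # output_vectors is safe against later in-place writes)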
self.cache.put(cache_key, output_vectors[b, i])
if self.enable_profiling:
self.profile_stats['cache_misses'] += 1
self.profile_stats['ambiguous_tokens'] += 1
else:
# Use static embedding
output_vectors[b, i] = self.static_embeddings(fragment_ids[b])
if self.enable_profiling:
self.profile_stats['static_tokens'] += 1
if self.enable_profiling:
self.profile_stats['total_forward_time'] += time.time() - start_time
self.profile_stats['num_forward_calls'] += 1
return output_vectors
def encode_single_fragment(self, fragment_text: str, context_data: Dict[str, Any]) -> torch.Tensor:
"""Encode a single text fragment (useful for inference)"""
# This would need proper text preprocessing - simplified for now
        # Simplified tokenization; note that Python's str hash is salted per
        # process, so this ID is not stable across runs
        fragment_id = hash(fragment_text) % self.config.vocab_size
fragment_tensor = torch.tensor([[fragment_id]], dtype=torch.long)
        return self(fragment_tensor, context_data).squeeze()
def enable_profiling_mode(self, enable: bool = True):
"""Enable or disable performance profiling"""
self.enable_profiling = enable
if enable:
# Reset stats
self.profile_stats = {k: 0 if isinstance(v, (int, float)) else v for k, v in self.profile_stats.items()}
def get_performance_stats(self) -> Dict[str, Any]:
"""Get performance statistics"""
stats = self.profile_stats.copy()
cache_stats = self.cache.stats()
stats.update(cache_stats)
# Add derived metrics
if stats['num_forward_calls'] > 0:
stats['avg_forward_time'] = stats['total_forward_time'] / stats['num_forward_calls']
total_tokens = stats['ambiguous_tokens'] + stats['static_tokens']
if total_tokens > 0:
stats['ambiguous_ratio'] = stats['ambiguous_tokens'] / total_tokens
stats['static_ratio'] = stats['static_tokens'] / total_tokens
return stats
def clear_cache(self):
"""Clear the embedding cache"""
self.cache.clear()
def save_ambiguous_vocab(self, filepath: str):
"""Save the current ambiguous vocabulary"""
        vocab_list = self.ambiguity_classifier.ambiguous_vocab.cpu().tolist()
with open(filepath, 'w') as f:
json.dump(vocab_list, f)
def load_ambiguous_vocab(self, filepath: str):
"""Load ambiguous vocabulary from file"""
with open(filepath, 'r') as f:
vocab_list = json.load(f)
self.ambiguity_classifier.update_ambiguous_vocab(vocab_list)
def test_cst_module():
"""Test the complete CST module"""
from config import CSTConfig
config = CSTConfig()
config.ambiguous_word_ids = [1, 5, 10, 15, 20] # Sample ambiguous words
cst = CSTModule(config)
cst.enable_profiling_mode(True)
batch_size = 2
seq_len = 8
# Sample input
text_fragments = torch.randint(0, config.vocab_size, (batch_size, seq_len))
fragment_chars = torch.randint(0, config.char_vocab_size, (batch_size, seq_len, 32))
context_chars = torch.randint(0, config.char_vocab_size, (batch_size, seq_len, 64))
context_data = {
'document_embedding': torch.randn(batch_size, config.raw_doc_dim),
'metadata': {
'author': torch.randint(0, config.num_authors, (batch_size,)),
'domain': torch.randint(0, config.num_domains, (batch_size,)),
}
}
# Forward pass
output = cst(text_fragments, context_data, fragment_chars, context_chars)
print(f"Input shape: {text_fragments.shape}")
print(f"Output shape: {output.shape}")
print(f"Expected output shape: {(batch_size, seq_len, config.d_model)}")
# Print performance stats
stats = cst.get_performance_stats()
print("\nPerformance Statistics:")
for key, value in stats.items():
print(f" {key}: {value}")
# Test caching
print("\nTesting caching...")
output2 = cst(text_fragments, context_data, fragment_chars, context_chars)
cache_stats = cst.get_performance_stats()
print(f"Cache hit rate after second pass: {cache_stats['hit_rate']:.2%}")
assert output.shape == (batch_size, seq_len, config.d_model), \
f"Expected {(batch_size, seq_len, config.d_model)}, got {output.shape}"
print("CST Module test passed!")
if __name__ == "__main__":
test_cst_module()