# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.

"""
Quantum-Enhanced CST Module
Integrates quantum computing into the core CST architecture
Fully standalone - no classical dependencies
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Tuple
import hashlib
import json
from collections import OrderedDict
import time
import logging

from .quantum_information_fuser import QuantumInformationFuser
from .quantum_fragment_encoder import QuantumFragmentEncoder
from .quantum_cst_config import QuantumConfig

logger = logging.getLogger(__name__)


class LRUCache:
    """Enhanced LRU cache with quantum state caching support.

    Maintains two independent LRU stores (classical and, optionally, quantum)
    that share a single per-store capacity, plus hit/miss counters for each.
    Values are tensors; they are cloned on both put and get so callers never
    share storage with the cache.
    """

    def __init__(self, capacity: int, enable_quantum_cache: bool = True):
        self.capacity = capacity
        self.cache = OrderedDict()
        # Quantum store is None (disabled) rather than empty when not enabled.
        self.quantum_cache = OrderedDict() if enable_quantum_cache else None
        self.hits = 0
        self.misses = 0
        self.quantum_hits = 0
        self.quantum_misses = 0

    def get(self, key: str, is_quantum: bool = False) -> Optional[torch.Tensor]:
        """Return a clone of the cached tensor, or None on miss.

        A hit refreshes the entry's LRU position in whichever store was probed.
        """
        cache_dict = self.quantum_cache if (is_quantum and self.quantum_cache is not None) else self.cache
        if key in cache_dict:
            cache_dict.move_to_end(key)
            if is_quantum:
                self.quantum_hits += 1
            else:
                self.hits += 1
            # Clone so the caller cannot mutate the cached copy.
            return cache_dict[key].clone()
        else:
            if is_quantum:
                self.quantum_misses += 1
            else:
                self.misses += 1
            return None

    def put(self, key: str, value: torch.Tensor, is_quantum: bool = False):
        """Insert/update an entry, evicting the least-recently-used on overflow."""
        cache_dict = self.quantum_cache if (is_quantum and self.quantum_cache is not None) else self.cache
        if key in cache_dict:
            cache_dict.move_to_end(key)
        else:
            if len(cache_dict) >= self.capacity:
                cache_dict.popitem(last=False)  # evict oldest
        # Detach so cached values do not keep autograd graphs alive.
        cache_dict[key] = value.clone().detach()

    def clear(self):
        """Drop all entries and reset every counter."""
        self.cache.clear()
        if self.quantum_cache is not None:
            self.quantum_cache.clear()
        self.hits = 0
        self.misses = 0
        self.quantum_hits = 0
        self.quantum_misses = 0

    def stats(self):
        """Return a dict of sizes, counters, and derived hit rates."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0.0
        quantum_total = self.quantum_hits + self.quantum_misses
        quantum_hit_rate = self.quantum_hits / quantum_total if quantum_total > 0 else 0.0
        return {
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
            'cache_size': len(self.cache),
            'capacity': self.capacity,
            'quantum_hits': self.quantum_hits,
            'quantum_misses': self.quantum_misses,
            'quantum_hit_rate': quantum_hit_rate,
            # Explicit None check (an empty-but-enabled store must still report 0,
            # and must not be confused with a disabled store).
            'quantum_cache_size': len(self.quantum_cache) if self.quantum_cache is not None else 0
        }


class QuantumAmbiguityClassifier(nn.Module):
    """Quantum-enhanced ambiguity classification (optional).

    Scores each fragment as ambiguous/unambiguous by combining three signals
    with learned weights: membership in a known-ambiguous vocabulary, a
    context-based classifier (quantum or classical), and a frequency-based
    classifier. Returns a boolean mask (score > threshold).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.use_quantum = config.quantum_config.quantum_ambiguity_classifier

        # Classical ambiguity detection: buffer of known-ambiguous token ids.
        self.register_buffer(
            'ambiguous_vocab',
            torch.tensor(config.ambiguous_word_ids if config.ambiguous_word_ids else [])
        )

        # Context-based classifier over [fragment_encoding ; context_features].
        context_input_dim = config.fragment_encoding_dim + config.context_feature_dim
        self.context_classifier = nn.Sequential(
            nn.Linear(context_input_dim, config.hidden_dim),
            nn.LayerNorm(config.hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.hidden_dim, config.hidden_dim // 2),
            nn.GELU(),
            nn.Linear(config.hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Quantum enhancement (if enabled); replaces the classical context path.
        if self.use_quantum:
            from .quantum_information_fuser import HybridQuantumClassical
            self.quantum_classifier = HybridQuantumClassical(
                input_dim=context_input_dim,
                output_dim=1,
                quantum_config=config.quantum_config
            )

        self.frequency_classifier = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Learned mixing weights for the three signals (vocab, context, frequency).
        self.combination_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3]))
        self.ambiguity_threshold = config.ambiguity_threshold

    def forward(self,
                fragment_ids: torch.Tensor,
                context_features: torch.Tensor,
                fragment_frequencies: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a boolean tensor of shape (batch,) marking ambiguous fragments."""
        batch_size = fragment_ids.size(0)
        ambiguity_scores = torch.zeros(batch_size, device=fragment_ids.device)

        # 1. Vocabulary-based signal.
        if len(self.ambiguous_vocab) > 0:
            vocab_ambiguous = torch.isin(fragment_ids, self.ambiguous_vocab).float()
            ambiguity_scores += self.combination_weights[0] * vocab_ambiguous

        # 2. Context-based signal (quantum or classical path).
        if context_features.size(1) >= self.config.context_feature_dim:
            # NOTE(review): fragment encoding is zero-filled here, so the context
            # classifier currently sees only the context half of its input.
            fragment_encoding = torch.zeros(batch_size, self.config.fragment_encoding_dim,
                                            device=fragment_ids.device)
            combined_features = torch.cat(
                [fragment_encoding, context_features[:, :self.config.context_feature_dim]], dim=1)
            if self.use_quantum:
                context_scores = self.quantum_classifier(combined_features).squeeze(-1)
            else:
                context_scores = self.context_classifier(combined_features).squeeze(-1)
            ambiguity_scores += self.combination_weights[1] * context_scores

        # 3. Frequency-based signal.
        if fragment_frequencies is not None:
            freq_scores = self.frequency_classifier(fragment_frequencies.unsqueeze(-1)).squeeze(-1)
            ambiguity_scores += self.combination_weights[2] * freq_scores

        return ambiguity_scores > self.ambiguity_threshold


class ProjectionHead(nn.Module):
    """Projects fused representation to transformer embedding dimension.

    Optionally adds a residual connection: identity when fused_dim == d_model,
    or a learned linear projection when the config opts in via
    `enable_projection_residual`.
    """

    def __init__(self, config):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(config.fused_dim, config.d_model),
            nn.LayerNorm(config.d_model),
            nn.Tanh(),
            nn.Dropout(0.1)
        )
        # Identity residual only possible when dimensions already match.
        self.use_residual = config.fused_dim == config.d_model
        if not self.use_residual and hasattr(config, 'enable_projection_residual'):
            self.residual_proj = nn.Linear(config.fused_dim, config.d_model)
            self.use_residual = config.enable_projection_residual

    def forward(self, fused_representation: torch.Tensor) -> torch.Tensor:
        output = self.projection(fused_representation)
        if self.use_residual:
            if hasattr(self, 'residual_proj'):
                residual = self.residual_proj(fused_representation)
            else:
                residual = fused_representation
            output = output + residual
        return output
class QuantumEnhancedCSTModule(nn.Module):
    """
    Quantum-Enhanced Contextual Spectrum Tokenization Module.

    Integrates quantum computing for enhanced information fusion. Tokens
    flagged ambiguous by the classifier are embedded via the dynamic
    (encoder -> fuser -> projection) pipeline with LRU caching; all other
    tokens use a static embedding table.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.quantum_enabled = config.quantum_config.enable_quantum

        # Core components - fully quantum standalone.
        self.fragment_encoder = QuantumFragmentEncoder(config)

        # Information Fuser - Quantum only (no classical fallback).
        logger.info("Initializing Quantum Information Fuser (Standalone)")
        self.information_fuser = QuantumInformationFuser(config, config.quantum_config)

        self.projection_head = ProjectionHead(config)
        self.ambiguity_classifier = QuantumAmbiguityClassifier(config)

        # Static embeddings fallback for unambiguous tokens.
        self.static_embeddings = nn.Embedding(config.vocab_size, config.d_model)
        nn.init.normal_(self.static_embeddings.weight, mean=0.0, std=0.02)

        # Enhanced caching with quantum support.
        self.cache = LRUCache(
            config.cache_size,
            enable_quantum_cache=self.quantum_enabled
        )

        # Performance tracking; counters accumulate only while profiling is on.
        self.enable_profiling = False
        self.profile_stats = {
            'cache_hits': 0,
            'cache_misses': 0,
            'ambiguous_tokens': 0,
            'static_tokens': 0,
            'quantum_processed_tokens': 0,
            'classical_processed_tokens': 0,
            'total_forward_time': 0.0,
            'quantum_forward_time': 0.0,
            'classical_forward_time': 0.0,
            'num_forward_calls': 0
        }

    def _compute_cache_key(self,
                           fragment_data: Dict[str, Any],
                           context_data: Dict[str, Any],
                           use_quantum: bool = False) -> str:
        """Compute a deterministic cache key (md5 hex digest) with quantum indicator."""
        key_components = {
            'fragment_id': fragment_data.get('fragment_id', '').item()
            if torch.is_tensor(fragment_data.get('fragment_id'))
            else str(fragment_data.get('fragment_id', '')),
            'context_hash': self._hash_context(context_data),
            'quantum': use_quantum
        }
        key_string = json.dumps(key_components, sort_keys=True)
        return hashlib.md5(key_string.encode()).hexdigest()

    def _hash_context(self, context_data: Dict[str, Any]) -> str:
        """Create a short hash of context data for caching.

        Tensors are summarized by shape/mean/std (a lossy fingerprint), nested
        dicts are hashed recursively, and everything else is stringified.
        """
        context_summary = {}
        for key, value in context_data.items():
            if isinstance(value, torch.Tensor):
                context_summary[key] = {
                    'shape': list(value.shape),
                    'mean': float(value.mean().item()) if value.numel() > 0 else 0.0,
                    'std': float(value.std().item()) if value.numel() > 0 else 0.0
                }
            elif isinstance(value, dict):
                context_summary[key] = self._hash_context(value)
            else:
                context_summary[key] = str(value)
        return hashlib.md5(json.dumps(context_summary, sort_keys=True).encode()).hexdigest()[:16]

    def _compute_dynamic_embedding(self,
                                   fragment_data: Dict[str, Any],
                                   context_data: Dict[str, Any]) -> Tuple[torch.Tensor, bool]:
        """
        Compute dynamic embedding using quantum-enhanced or classical pipeline.

        Returns:
            Tuple of (embedding, used_quantum)
        """
        # BUGFIX: original read the unbound name `config` here (NameError);
        # the quantum settings live on self.config.
        use_quantum = (self.quantum_enabled and
                       self.config.quantum_config.quantum_information_fuser and
                       self.training)  # Use quantum mainly during training

        # Time tracking.
        start_time = time.time() if self.enable_profiling else 0

        # Extract fragment encoding.
        fragment_encoding = self.fragment_encoder(
            fragment_data['fragment_chars'],
            fragment_data['context_chars'],
            fragment_data.get('fragment_positions')
        )

        # Fuse with contextual information (quantum or classical).
        fused_representation = self.information_fuser(fragment_encoding, context_data)

        # Project to output space.
        output_embedding = self.projection_head(fused_representation)

        # Track timing, attributed to the quantum or classical bucket.
        if self.enable_profiling:
            elapsed = time.time() - start_time
            if use_quantum:
                self.profile_stats['quantum_forward_time'] += elapsed
            else:
                self.profile_stats['classical_forward_time'] += elapsed

        return output_embedding, use_quantum

    def forward(self,
                text_fragments: torch.Tensor,
                context_data: Dict[str, Any],
                fragment_chars: Optional[torch.Tensor] = None,
                context_chars: Optional[torch.Tensor] = None,
                fragment_frequencies: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Quantum-enhanced forward pass.

        Args:
            text_fragments: (batch, seq) fragment ids.
            context_data: dict of contextual tensors/metadata shared per batch.
            fragment_chars / context_chars: optional per-position char tensors.
            fragment_frequencies: optional (batch, seq) frequency scores.

        Returns:
            Tuple of (output_vectors, quantum_metrics)
        """
        start_time = time.time() if self.enable_profiling else 0
        batch_size, seq_len = text_fragments.shape
        device = text_fragments.device

        # Initialize output.
        output_vectors = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)

        # Track quantum usage for this batch.
        quantum_tokens_processed = 0
        classical_tokens_processed = 0

        for i in range(seq_len):
            fragment_ids = text_fragments[:, i]

            # Prepare fragment data for this position.
            fragment_data = {
                'fragment_id': fragment_ids,
                'fragment_chars': fragment_chars[:, i] if fragment_chars is not None else None,
                'context_chars': context_chars[:, i] if context_chars is not None else None,
                'fragment_positions': torch.full((batch_size,), i, device=device)
            }

            # Prepare context features (truncated/padded document embedding).
            context_features = torch.zeros(batch_size, self.config.context_feature_dim, device=device)
            if 'document_embedding' in context_data:
                doc_emb = context_data['document_embedding']
                feature_dim = min(self.config.context_feature_dim, doc_emb.size(-1))
                context_features[:, :feature_dim] = doc_emb[:, :feature_dim]

            # Determine ambiguity for all samples at this position.
            freqs = fragment_frequencies[:, i] if fragment_frequencies is not None else None
            is_ambiguous = self.ambiguity_classifier(fragment_ids, context_features, freqs)

            # Process each sample individually.
            for b in range(batch_size):
                if is_ambiguous[b]:
                    # Try cache first.
                    sample_fragment_data = {k: v[b] if v is not None else None
                                            for k, v in fragment_data.items()}
                    sample_context_data = {k: v[b] if isinstance(v, torch.Tensor) else v
                                           for k, v in context_data.items()}
                    use_quantum = (self.quantum_enabled and
                                   self.config.quantum_config.quantum_information_fuser)
                    cache_key = self._compute_cache_key(sample_fragment_data,
                                                        sample_context_data,
                                                        use_quantum)
                    cached_vector = self.cache.get(cache_key, is_quantum=use_quantum)
                    if cached_vector is not None:
                        output_vectors[b, i] = cached_vector
                        if self.enable_profiling:
                            self.profile_stats['cache_hits'] += 1
                    else:
                        # Compute dynamic embedding.
                        dynamic_vector, used_quantum = self._compute_dynamic_embedding(
                            sample_fragment_data, sample_context_data
                        )
                        output_vectors[b, i] = (dynamic_vector.squeeze(0)
                                                if dynamic_vector.dim() > 1 else dynamic_vector)
                        # Cache the result.
                        # BUGFIX: store under the same quantum/classical store used for
                        # the lookup above. Original used `is_quantum=used_quantum`,
                        # which diverges from `use_quantum` in eval mode (the training
                        # flag), making eval-mode entries permanently unreachable.
                        self.cache.put(cache_key, output_vectors[b, i], is_quantum=use_quantum)
                        if self.enable_profiling:
                            self.profile_stats['cache_misses'] += 1
                            self.profile_stats['ambiguous_tokens'] += 1
                        if used_quantum:
                            quantum_tokens_processed += 1
                        else:
                            classical_tokens_processed += 1
                else:
                    # Use static embedding for unambiguous tokens.
                    output_vectors[b, i] = self.static_embeddings(fragment_ids[b])
                    if self.enable_profiling:
                        self.profile_stats['static_tokens'] += 1

        # Update statistics.
        if self.enable_profiling:
            self.profile_stats['total_forward_time'] += time.time() - start_time
            self.profile_stats['num_forward_calls'] += 1
            self.profile_stats['quantum_processed_tokens'] += quantum_tokens_processed
            self.profile_stats['classical_processed_tokens'] += classical_tokens_processed

        # Quantum metrics returned alongside the embeddings.
        quantum_metrics = {
            'quantum_tokens_in_batch': quantum_tokens_processed,
            'classical_tokens_in_batch': classical_tokens_processed,
            # Epsilon guards the all-static (0/0) case.
            'quantum_ratio': quantum_tokens_processed /
                             (quantum_tokens_processed + classical_tokens_processed + 1e-10)
        }
        if hasattr(self.information_fuser, 'get_quantum_circuit_info'):
            quantum_metrics.update(self.information_fuser.get_quantum_circuit_info())

        return output_vectors, quantum_metrics

    def enable_profiling_mode(self, enable: bool = True):
        """Enable or disable performance profiling; enabling resets all counters."""
        self.enable_profiling = enable
        if enable:
            # Zero every numeric counter/timer, keep any non-numeric entries.
            self.profile_stats = {k: 0 if isinstance(v, (int, float)) else v
                                  for k, v in self.profile_stats.items()}

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get comprehensive performance statistics (raw counters + derived ratios)."""
        stats = self.profile_stats.copy()
        cache_stats = self.cache.stats()
        stats.update(cache_stats)

        # Add derived metrics.
        if stats['num_forward_calls'] > 0:
            stats['avg_forward_time'] = stats['total_forward_time'] / stats['num_forward_calls']
            stats['avg_quantum_time'] = stats['quantum_forward_time'] / stats['num_forward_calls']
            stats['avg_classical_time'] = stats['classical_forward_time'] / stats['num_forward_calls']

        total_tokens = (stats['ambiguous_tokens'] + stats['static_tokens'])
        if total_tokens > 0:
            stats['ambiguous_ratio'] = stats['ambiguous_tokens'] / total_tokens
            stats['static_ratio'] = stats['static_tokens'] / total_tokens

        if stats['ambiguous_tokens'] > 0:
            quantum_processed = stats['quantum_processed_tokens']
            classical_processed = stats['classical_processed_tokens']
            total_processed = quantum_processed + classical_processed
            if total_processed > 0:
                stats['quantum_usage_ratio'] = quantum_processed / total_processed

        return stats

    def get_quantum_info(self) -> Dict[str, Any]:
        """Get quantum-specific information from the fuser, if it exposes any."""
        if hasattr(self.information_fuser, 'get_quantum_circuit_info'):
            return self.information_fuser.get_quantum_circuit_info()
        return {'quantum_enabled': False}

    def clear_cache(self):
        """Clear all caches."""
        self.cache.clear()


# Alias for backward compatibility
CSTModule = QuantumEnhancedCSTModule
def test_quantum_cst_module():
    """Smoke-test the quantum-enhanced CST module end to end.

    Builds a small config, runs one forward pass on random inputs, and prints
    shapes, quantum metrics, performance stats, and circuit info.
    """
    from .quantum_cst_config import CSTConfig, QuantumConfig

    config = CSTConfig()
    config.quantum_config = QuantumConfig()
    config.quantum_config.enable_quantum = True
    config.quantum_config.n_qubits = 6
    config.quantum_config.n_layers = 2
    config.ambiguous_word_ids = [1, 5, 10, 15, 20]

    cst = QuantumEnhancedCSTModule(config)
    cst.enable_profiling_mode(True)

    batch_size = 2
    seq_len = 8

    # Sample input
    text_fragments = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    fragment_chars = torch.randint(0, config.char_vocab_size, (batch_size, seq_len, 32))
    context_chars = torch.randint(0, config.char_vocab_size, (batch_size, seq_len, 64))
    context_data = {
        'document_embedding': torch.randn(batch_size, config.raw_doc_dim),
        'metadata': {
            'author': torch.randint(0, config.num_authors, (batch_size,)),
            'domain': torch.randint(0, config.num_domains, (batch_size,)),
        }
    }

    # Forward pass
    output, quantum_metrics = cst(text_fragments, context_data, fragment_chars, context_chars)

    print(f"Input shape: {text_fragments.shape}")
    print(f"Output shape: {output.shape}")
    print(f"\nQuantum Metrics:")
    for key, value in quantum_metrics.items():
        print(f"  {key}: {value}")

    # Performance stats
    stats = cst.get_performance_stats()
    print("\nPerformance Statistics:")
    for key, value in list(stats.items())[:10]:
        print(f"  {key}: {value}")

    # Quantum info
    quantum_info = cst.get_quantum_info()
    print("\nQuantum Circuit Info:")
    for key, value in quantum_info.items():
        if key != 'fragment_circuit':
            print(f"  {key}: {value}")

    print("\n✅ Quantum-Enhanced CST Module test passed!")


if __name__ == "__main__":
    test_quantum_cst_module()