File size: 18,543 Bytes
94c2e42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
# CST / QCST Dual License
# Non-commercial research use only.
# Commercial use requires explicit permission.
# Copyright (c) 2025 Mohamed Mohamed Elhelbawi
# All rights reserved.
# See LICENSE file in the project root for full license information.

"""
Core CST Module Implementation
Main module that orchestrates fragment encoding, information fusion, and caching
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Any, Tuple
import hashlib
import json
from collections import OrderedDict
import time

from fragment_encoder import FragmentEncoder
from information_fuser import InformationFuser


class LRUCache:
    """Least-recently-used cache for tensor embeddings with hit/miss accounting."""

    def __init__(self, capacity: int):
        # OrderedDict preserves insertion order; the front entry is the LRU one.
        self.capacity = capacity
        self.cache = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return a cloned copy of the cached tensor, or None on a miss."""
        entry = self.cache.get(key)
        if entry is None:
            self.misses += 1
            return None
        self.hits += 1
        self.cache.move_to_end(key)  # promote to most-recently-used
        return entry.clone()  # clone so callers can't mutate the cached copy

    def put(self, key: str, value: torch.Tensor):
        """Insert or refresh an entry, evicting the LRU entry when at capacity."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)  # drop least-recently-used entry
        # Detach so cached values never keep an autograd graph alive.
        self.cache[key] = value.clone().detach()

    def clear(self):
        """Drop all entries and reset the hit/miss counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0

    def stats(self):
        """Return a summary dict: hits, misses, hit_rate, cache_size, capacity."""
        lookups = self.hits + self.misses
        return {
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': self.hits / lookups if lookups > 0 else 0.0,
            'cache_size': len(self.cache),
            'capacity': self.capacity,
        }


class AmbiguityClassifier(nn.Module):
    """Determines whether dynamic processing is needed for each fragment.

    Combines three signals into a weighted ambiguity score that is thresholded
    into a binary decision:
      1. membership in a precomputed ambiguous-word vocabulary,
      2. a learned classifier over context features,
      3. a learned classifier over (log) fragment frequency.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Pre-computed ambiguous word vocabulary (loaded during training).
        # Fix: pin dtype to long so an initially-empty vocabulary does not
        # default to float32 (which would dtype-mismatch the integer fragment
        # ids used with torch.isin and checkpoint round-trips). Also guard the
        # attribute so a config without the field does not raise.
        vocab_ids = getattr(config, 'ambiguous_word_ids', None) or []
        self.register_buffer(
            'ambiguous_vocab',
            torch.tensor(vocab_ids, dtype=torch.long)
        )

        # Context-based ambiguity classifier: sees a (currently zero-padded)
        # fragment-encoding slot concatenated with the context features.
        context_input_dim = config.fragment_encoding_dim + config.context_feature_dim
        self.context_classifier = nn.Sequential(
            nn.Linear(context_input_dim, config.hidden_dim),
            nn.LayerNorm(config.hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.hidden_dim, config.hidden_dim // 2),
            nn.GELU(),
            nn.Linear(config.hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Frequency-based classifier (learns from data)
        self.frequency_classifier = nn.Sequential(
            nn.Linear(1, 32),  # Input: log frequency
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Learnable combination weights: [vocab, context, frequency]
        self.combination_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.3]))
        self.ambiguity_threshold = config.ambiguity_threshold

    def forward(self,
                fragment_ids: torch.Tensor,
                context_features: torch.Tensor,
                fragment_frequencies: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Determine ambiguity for each fragment.

        Args:
            fragment_ids: [batch_size] - Fragment token IDs
            context_features: [batch_size, context_feature_dim] - Context features
            fragment_frequencies: [batch_size] - Log frequencies of fragments

        Returns:
            [batch_size] bool tensor - True where the fragment is ambiguous.
        """
        batch_size = fragment_ids.size(0)
        ambiguity_scores = torch.zeros(batch_size, device=fragment_ids.device)

        # 1. Vocabulary-based ambiguity (hard membership test)
        if len(self.ambiguous_vocab) > 0:
            vocab_ambiguous = torch.isin(fragment_ids, self.ambiguous_vocab).float()
            ambiguity_scores += self.combination_weights[0] * vocab_ambiguous

        # 2. Context-based ambiguity — only when enough context features exist.
        if context_features.size(1) >= self.config.context_feature_dim:
            # The fragment-encoding slot is zero-padded here; the classifier
            # effectively conditions on context features only for now.
            fragment_encoding = torch.zeros(batch_size, self.config.fragment_encoding_dim,
                                          device=fragment_ids.device)
            combined_features = torch.cat(
                [fragment_encoding, context_features[:, :self.config.context_feature_dim]], dim=1)
            context_scores = self.context_classifier(combined_features).squeeze(-1)
            ambiguity_scores += self.combination_weights[1] * context_scores

        # 3. Frequency-based ambiguity (high frequency words are more likely ambiguous)
        if fragment_frequencies is not None:
            freq_scores = self.frequency_classifier(fragment_frequencies.unsqueeze(-1)).squeeze(-1)
            ambiguity_scores += self.combination_weights[2] * freq_scores

        # Binary decision per fragment
        return ambiguity_scores > self.ambiguity_threshold

    def update_ambiguous_vocab(self, new_ambiguous_ids: List[int]):
        """Replace the ambiguous vocabulary during training (stays a long buffer)."""
        self.ambiguous_vocab = torch.tensor(
            new_ambiguous_ids, dtype=torch.long, device=self.ambiguous_vocab.device)


class ProjectionHead(nn.Module):
    """Maps the fused CST representation into the transformer embedding space."""

    def __init__(self, config):
        super().__init__()

        # Tanh bounds the projected values, which helps downstream stability.
        self.projection = nn.Sequential(
            nn.Linear(config.fused_dim, config.d_model),
            nn.LayerNorm(config.d_model),
            nn.Tanh(),
            nn.Dropout(0.1)
        )

        # Residual path: identity when the dims already match; otherwise a
        # learned projection, gated by config.enable_projection_residual when
        # that attribute exists on the config.
        self.use_residual = config.fused_dim == config.d_model
        if not self.use_residual and hasattr(config, 'enable_projection_residual'):
            self.residual_proj = nn.Linear(config.fused_dim, config.d_model)
            self.use_residual = config.enable_projection_residual

    def forward(self, fused_representation: torch.Tensor) -> torch.Tensor:
        """Project to d_model, optionally adding a (projected) residual."""
        projected = self.projection(fused_representation)

        if not self.use_residual:
            return projected

        skip = getattr(self, 'residual_proj', None)
        residual = fused_representation if skip is None else skip(fused_representation)
        return projected + residual


class CSTModule(nn.Module):
    """
    Main Contextual Spectrum Tokenization Module

    Integrates fragment encoding, information fusion, ambiguity detection, and caching.
    Ambiguous tokens receive a dynamically computed, context-aware embedding (served
    through an LRU cache); unambiguous tokens fall back to a static embedding table.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Core components
        self.fragment_encoder = FragmentEncoder(config)
        self.information_fuser = InformationFuser(config)
        self.projection_head = ProjectionHead(config)
        self.ambiguity_classifier = AmbiguityClassifier(config)

        # Static embeddings fallback for unambiguous tokens
        self.static_embeddings = nn.Embedding(config.vocab_size, config.d_model)
        nn.init.normal_(self.static_embeddings.weight, mean=0.0, std=0.02)

        # Caching system. NOTE(review): cached vectors are detached, so cache
        # hits do not carry gradients — confirm this is intended during training.
        self.cache = LRUCache(config.cache_size)

        # Performance tracking (counters only updated when profiling is on)
        self.enable_profiling = False
        self.profile_stats = self._fresh_profile_stats()

    @staticmethod
    def _fresh_profile_stats() -> Dict[str, Any]:
        """Return a zeroed profiling-counter dict (keeps float counters float)."""
        return {
            'cache_hits': 0,
            'cache_misses': 0,
            'ambiguous_tokens': 0,
            'static_tokens': 0,
            'total_forward_time': 0.0,
            'num_forward_calls': 0
        }

    def _compute_cache_key(self, fragment_data: Dict[str, Any], context_data: Dict[str, Any]) -> str:
        """Compute a hash key for caching a (fragment, context) pair."""
        fragment_id = fragment_data.get('fragment_id')
        key_components = {
            # Scalar tensors reduce to their python value; anything else is stringified.
            'fragment_id': fragment_id.item() if torch.is_tensor(fragment_id)
                           else str(fragment_id if fragment_id is not None else ''),
            'context_hash': self._hash_context(context_data)
        }

        key_string = json.dumps(key_components, sort_keys=True)
        # md5 is used as a fast non-cryptographic fingerprint, not for security.
        return hashlib.md5(key_string.encode()).hexdigest()

    def _hash_context(self, context_data: Dict[str, Any]) -> str:
        """Create a short hash of context data for caching.

        Tensors are summarized by shape/mean/std rather than full contents —
        a deliberate speed trade-off that can, in principle, collide for
        distinct contexts with identical statistics.
        """
        context_summary = {}

        for key, value in context_data.items():
            if isinstance(value, torch.Tensor):
                context_summary[key] = {
                    'shape': list(value.shape),
                    'mean': float(value.mean().item()) if value.numel() > 0 else 0.0,
                    'std': float(value.std().item()) if value.numel() > 0 else 0.0
                }
            elif isinstance(value, dict):
                # Recurse into nested context (e.g. metadata sub-dicts).
                context_summary[key] = self._hash_context(value)
            else:
                context_summary[key] = str(value)

        return hashlib.md5(json.dumps(context_summary, sort_keys=True).encode()).hexdigest()[:16]

    def _compute_dynamic_embedding(self, fragment_data: Dict[str, Any], context_data: Dict[str, Any]) -> torch.Tensor:
        """Compute a dynamic embedding via encode -> fuse -> project."""
        fragment_encoding = self.fragment_encoder(
            fragment_data['fragment_chars'],
            fragment_data['context_chars'],
            fragment_data.get('fragment_positions')
        )

        fused_representation = self.information_fuser(fragment_encoding, context_data)

        return self.projection_head(fused_representation)

    def forward(self,
                text_fragments: torch.Tensor,
                context_data: Dict[str, Any],
                fragment_chars: Optional[torch.Tensor] = None,
                context_chars: Optional[torch.Tensor] = None,
                fragment_frequencies: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Main forward pass of CST module.

        Args:
            text_fragments: [batch_size, seq_len] - Token IDs
            context_data: Dictionary of contextual information
            fragment_chars: [batch_size, seq_len, char_len] - Character-level data
            context_chars: [batch_size, seq_len, context_char_len] - Context characters
            fragment_frequencies: [batch_size, seq_len] - Fragment frequencies

        Returns:
            [batch_size, seq_len, d_model] embedding tensor.
        """
        start_time = time.time() if self.enable_profiling else 0

        batch_size, seq_len = text_fragments.shape
        device = text_fragments.device

        output_vectors = torch.zeros(batch_size, seq_len, self.config.d_model, device=device)

        for i in range(seq_len):
            fragment_ids = text_fragments[:, i]

            # Per-position fragment data (character tensors are optional).
            fragment_data = {
                'fragment_id': fragment_ids,
                'fragment_chars': fragment_chars[:, i] if fragment_chars is not None else None,
                'context_chars': context_chars[:, i] if context_chars is not None else None,
                'fragment_positions': torch.full((batch_size,), i, device=device)
            }

            # Context features for ambiguity classification: a truncated slice
            # of the document embedding, zero-padded to context_feature_dim.
            context_features = torch.zeros(batch_size, self.config.context_feature_dim, device=device)
            if 'document_embedding' in context_data:
                doc_emb = context_data['document_embedding']
                feature_dim = min(self.config.context_feature_dim, doc_emb.size(-1))
                context_features[:, :feature_dim] = doc_emb[:, :feature_dim]

            # Decide per sample whether the expensive dynamic path is needed.
            freqs = fragment_frequencies[:, i] if fragment_frequencies is not None else None
            is_ambiguous = self.ambiguity_classifier(fragment_ids, context_features, freqs)

            for b in range(batch_size):
                if is_ambiguous[b]:
                    # Slice out this sample's fragment/context views for hashing.
                    sample_fragment_data = {k: v[b] if v is not None else None for k, v in fragment_data.items()}
                    sample_context_data = {k: v[b] if isinstance(v, torch.Tensor) else v for k, v in context_data.items()}

                    cache_key = self._compute_cache_key(sample_fragment_data, sample_context_data)
                    cached_vector = self.cache.get(cache_key)

                    if cached_vector is not None:
                        output_vectors[b, i] = cached_vector
                        if self.enable_profiling:
                            self.profile_stats['cache_hits'] += 1
                            # Fix: cache hits are still ambiguous tokens; counting
                            # them keeps ambiguous_ratio correct once the cache warms.
                            self.profile_stats['ambiguous_tokens'] += 1
                    else:
                        dynamic_vector = self._compute_dynamic_embedding(sample_fragment_data, sample_context_data)
                        output_vectors[b, i] = dynamic_vector.squeeze(0) if dynamic_vector.dim() > 1 else dynamic_vector

                        # Cache the result for identical (fragment, context) pairs.
                        self.cache.put(cache_key, output_vectors[b, i])

                        if self.enable_profiling:
                            self.profile_stats['cache_misses'] += 1
                            self.profile_stats['ambiguous_tokens'] += 1
                else:
                    # Cheap path: static embedding lookup.
                    output_vectors[b, i] = self.static_embeddings(fragment_ids[b])
                    if self.enable_profiling:
                        self.profile_stats['static_tokens'] += 1

        if self.enable_profiling:
            self.profile_stats['total_forward_time'] += time.time() - start_time
            self.profile_stats['num_forward_calls'] += 1

        return output_vectors

    def encode_single_fragment(self, fragment_text: str, context_data: Dict[str, Any]) -> torch.Tensor:
        """Encode a single text fragment (useful for inference).

        NOTE(review): tokenization here is a placeholder — Python's hash() is
        salted per process (PYTHONHASHSEED), so ids are not stable across runs.
        """
        fragment_id = hash(fragment_text) % self.config.vocab_size  # Simplified tokenization
        fragment_tensor = torch.tensor([[fragment_id]], dtype=torch.long)

        return self.forward(fragment_tensor, context_data).squeeze()

    def enable_profiling_mode(self, enable: bool = True):
        """Enable or disable performance profiling (enabling resets counters)."""
        self.enable_profiling = enable
        if enable:
            # Fix: rebuild the dict so total_forward_time stays a float.
            self.profile_stats = self._fresh_profile_stats()

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics, merged with cache stats and derived ratios."""
        stats = self.profile_stats.copy()
        stats.update(self.cache.stats())

        if stats['num_forward_calls'] > 0:
            stats['avg_forward_time'] = stats['total_forward_time'] / stats['num_forward_calls']

        total_tokens = stats['ambiguous_tokens'] + stats['static_tokens']
        if total_tokens > 0:
            stats['ambiguous_ratio'] = stats['ambiguous_tokens'] / total_tokens
            stats['static_ratio'] = stats['static_tokens'] / total_tokens

        return stats

    def clear_cache(self):
        """Clear the embedding cache."""
        self.cache.clear()

    def save_ambiguous_vocab(self, filepath: str):
        """Save the current ambiguous vocabulary as a JSON list of token ids."""
        # Fix: the vocabulary buffer lives on the classifier, not on this module
        # (the old code read self.ambiguous_vocab and always raised AttributeError).
        vocab_list = self.ambiguity_classifier.ambiguous_vocab.cpu().tolist()
        with open(filepath, 'w') as f:
            json.dump(vocab_list, f)

    def load_ambiguous_vocab(self, filepath: str):
        """Load ambiguous vocabulary from a JSON file written by save_ambiguous_vocab."""
        with open(filepath, 'r') as f:
            vocab_list = json.load(f)
        self.ambiguity_classifier.update_ambiguous_vocab(vocab_list)


def test_cst_module():
    """Smoke-test the full CST module: output shapes, profiling stats, caching."""
    from config import CSTConfig

    cfg = CSTConfig()
    cfg.ambiguous_word_ids = [1, 5, 10, 15, 20]  # Sample ambiguous words

    module = CSTModule(cfg)
    module.enable_profiling_mode(True)

    batch_size, seq_len = 2, 8

    # Randomly generated sample inputs
    text_fragments = torch.randint(0, cfg.vocab_size, (batch_size, seq_len))
    frag_chars = torch.randint(0, cfg.char_vocab_size, (batch_size, seq_len, 32))
    ctx_chars = torch.randint(0, cfg.char_vocab_size, (batch_size, seq_len, 64))

    context_data = {
        'document_embedding': torch.randn(batch_size, cfg.raw_doc_dim),
        'metadata': {
            'author': torch.randint(0, cfg.num_authors, (batch_size,)),
            'domain': torch.randint(0, cfg.num_domains, (batch_size,)),
        }
    }

    # First forward pass
    embeddings = module(text_fragments, context_data, frag_chars, ctx_chars)

    print(f"Input shape: {text_fragments.shape}")
    print(f"Output shape: {embeddings.shape}")
    print(f"Expected output shape: {(batch_size, seq_len, cfg.d_model)}")

    # Dump profiling counters collected during the first pass
    print("\nPerformance Statistics:")
    for name, value in module.get_performance_stats().items():
        print(f"  {name}: {value}")

    # A second identical pass should be served (at least partly) from the cache
    print("\nTesting caching...")
    module(text_fragments, context_data, frag_chars, ctx_chars)

    cache_stats = module.get_performance_stats()
    print(f"Cache hit rate after second pass: {cache_stats['hit_rate']:.2%}")

    expected = (batch_size, seq_len, cfg.d_model)
    assert embeddings.shape == expected, \
        f"Expected {expected}, got {embeddings.shape}"

    print("CST Module test passed!")


# Run the smoke test when this module is executed as a script.
if __name__ == "__main__":
    test_cst_module()