""" Context Scaling System Handles length scaling (millions of tokens) and multi-modal/structural scaling Implements advanced attention methods and memory techniques from the article """ import logging from typing import Dict, List, Any, Optional, Tuple from dataclasses import dataclass import numpy as np from datetime import datetime import heapq logger = logging.getLogger(__name__) @dataclass class ScaledContext: """Context that can scale to millions of tokens""" segments: List[str] # Segmented content attention_map: np.ndarray # Attention weights for segments token_count: int compression_level: int # 0=none, 1=light, 2=medium, 3=heavy modalities: Dict[str, Any] # Different context modalities class AttentionOptimizer: """ Advanced attention methods for handling extremely long contexts Implements sliding window, sparse attention, and hierarchical attention """ def __init__(self, window_size: int = 512, stride: int = 256): self.window_size = window_size self.stride = stride def sliding_window_attention( self, context: str, query: str, max_windows: int = 10 ) -> List[Tuple[str, float]]: """ Process context using sliding window attention Returns relevant windows with attention scores """ tokens = context.split() windows = [] # Create sliding windows for i in range(0, len(tokens) - self.window_size + 1, self.stride): window = ' '.join(tokens[i:i + self.window_size]) score = self._calculate_attention_score(window, query) windows.append((window, score)) # Return top windows windows.sort(key=lambda x: x[1], reverse=True) return windows[:max_windows] def hierarchical_attention( self, context: str, query: str, levels: int = 3 ) -> Dict[int, List[str]]: """ Multi-level hierarchical attention Higher levels = more compressed/abstract """ hierarchy = {} current_text = context for level in range(levels): if level == 0: # Finest level - full detail hierarchy[level] = self._segment_text(current_text, 500) elif level == 1: # Middle level - paragraphs/sections hierarchy[level] = self._extract_key_sentences(current_text) else: # Highest level - summary hierarchy[level] = [self._generate_summary(current_text)] # Compress for next level current_text = ' '.join(hierarchy[level]) return hierarchy def sparse_attention( self, context: str, query: str, sparsity: float = 0.1 ) -> List[str]: """ Sparse attention - only attend to most relevant tokens Reduces computation from O(n²) to O(n*k) """ tokens = context.split() query_tokens = set(query.lower().split()) # Calculate relevance for each token token_scores = [] for i, token in enumerate(tokens): score = 1.0 if token.lower() in query_tokens else np.random.random() * 0.5 token_scores.append((i, token, score)) # Keep only top k% tokens k = int(len(tokens) * sparsity) top_tokens = heapq.nlargest(k, token_scores, key=lambda x: x[2]) # Sort by original position to maintain order top_tokens.sort(key=lambda x: x[0]) # Reconstruct sparse context sparse_context = [] last_idx = -1 for idx, token, score in top_tokens: if idx > last_idx + 1: sparse_context.append("...") sparse_context.append(token) last_idx = idx return sparse_context def _calculate_attention_score(self, window: str, query: str) -> float: """Calculate attention score between window and query""" window_words = set(window.lower().split()) query_words = set(query.lower().split()) if not query_words: return 0.0 overlap = len(window_words & query_words) return overlap / len(query_words) def _segment_text(self, text: str, segment_size: int) -> List[str]: """Segment text into chunks""" words = text.split() segments = [] for i in range(0, len(words), segment_size): segments.append(' '.join(words[i:i + segment_size])) return segments def _extract_key_sentences(self, text: str) -> List[str]: """Extract key sentences (simplified)""" sentences = text.split('.') # Keep sentences with more than 10 words (likely more informative) key_sentences = [s.strip() + '.' for s in sentences if len(s.split()) > 10] return key_sentences[:10] # Top 10 sentences def _generate_summary(self, text: str) -> str: """Generate summary (simplified - would use LLM in production)""" sentences = text.split('.')[:3] # First 3 sentences as summary return '. '.join(sentences) + '.' class LengthScaler: """ Handle context scaling from thousands to millions of tokens Maintains coherence across long documents """ def __init__(self, max_tokens: int = 1000000): self.max_tokens = max_tokens self.attention_optimizer = AttentionOptimizer() def scale_context( self, context: str, query: str, target_tokens: int = 2000 ) -> ScaledContext: """Scale context to target token count while maintaining relevance""" tokens = context.split() current_tokens = len(tokens) # Determine compression level needed compression_ratio = current_tokens / target_tokens if compression_ratio <= 1: # No compression needed return ScaledContext( segments=[context], attention_map=np.array([1.0]), token_count=current_tokens, compression_level=0, modalities={} ) # Apply appropriate scaling strategy if compression_ratio < 5: # Light compression - sliding window segments = self._light_compression(context, query, target_tokens) compression_level = 1 elif compression_ratio < 20: # Medium compression - hierarchical segments = self._medium_compression(context, query, target_tokens) compression_level = 2 else: # Heavy compression - sparse attention segments = self._heavy_compression(context, query, target_tokens) compression_level = 3 # Calculate attention map attention_map = self._calculate_attention_map(segments, query) return ScaledContext( segments=segments, attention_map=attention_map, token_count=sum(len(s.split()) for s in segments), compression_level=compression_level, modalities={} ) def _light_compression( self, context: str, query: str, target_tokens: int ) -> List[str]: """Light compression using sliding windows""" windows = self.attention_optimizer.sliding_window_attention( context, query, max_windows=target_tokens // 100 ) return [w for w, _ in windows] def _medium_compression( self, context: str, query: str, target_tokens: int ) -> List[str]: """Medium compression using hierarchical attention""" hierarchy = self.attention_optimizer.hierarchical_attention(context, query) segments = [] remaining_tokens = target_tokens # Add from each level based on available tokens for level in sorted(hierarchy.keys()): level_segments = hierarchy[level] for segment in level_segments: segment_tokens = len(segment.split()) if segment_tokens <= remaining_tokens: segments.append(segment) remaining_tokens -= segment_tokens if remaining_tokens <= 0: break return segments def _heavy_compression( self, context: str, query: str, target_tokens: int ) -> List[str]: """Heavy compression using sparse attention""" sparsity = target_tokens / len(context.split()) sparse_tokens = self.attention_optimizer.sparse_attention( context, query, sparsity=min(sparsity, 0.3) ) # Group sparse tokens into segments segments = [] current_segment = [] for token in sparse_tokens: if token == "...": if current_segment: segments.append(' '.join(current_segment)) current_segment = [] segments.append("...") else: current_segment.append(token) if current_segment: segments.append(' '.join(current_segment)) return segments def _calculate_attention_map( self, segments: List[str], query: str ) -> np.ndarray: """Calculate attention weights for each segment""" query_words = set(query.lower().split()) attention_scores = [] for segment in segments: if segment == "...": attention_scores.append(0.0) else: segment_words = set(segment.lower().split()) overlap = len(query_words & segment_words) score = overlap / max(len(query_words), 1) attention_scores.append(score) # Normalize scores = np.array(attention_scores) if scores.sum() > 0: scores = scores / scores.sum() return scores class MultiModalScaler: """ Handle multi-modal and structural context scaling Temporal, spatial, participant states, intentional, cultural """ def __init__(self): self.modality_handlers = { 'temporal': self._scale_temporal, 'spatial': self._scale_spatial, 'participant': self._scale_participant, 'intentional': self._scale_intentional, 'cultural': self._scale_cultural } def scale_multimodal( self, modalities: Dict[str, Any], importance_weights: Optional[Dict[str, float]] = None ) -> Dict[str, Any]: """Scale multiple modalities based on importance""" if importance_weights is None: importance_weights = { 'temporal': 0.3, 'spatial': 0.1, 'participant': 0.3, 'intentional': 0.2, 'cultural': 0.1 } scaled = {} for modality, data in modalities.items(): if modality in self.modality_handlers: weight = importance_weights.get(modality, 0.1) scaled[modality] = self.modality_handlers[modality](data, weight) return scaled def _scale_temporal(self, data: List[Dict], weight: float) -> List[Dict]: """Scale temporal context - keep most recent and important events""" # Sort by timestamp sorted_data = sorted(data, key=lambda x: x.get('timestamp', datetime.min), reverse=True) # Keep based on weight (more weight = more events kept) keep_count = max(1, int(len(sorted_data) * weight)) return sorted_data[:keep_count] def _scale_spatial(self, data: Dict, weight: float) -> Dict: """Scale spatial context - simplify based on importance""" if weight < 0.3: # Low importance - just keep basic location return {'location': data.get('primary_location', 'unknown')} else: # Higher importance - keep more detail return data def _scale_participant(self, data: Dict, weight: float) -> Dict: """Scale participant states - keep most active participants""" if not data: return {} # Sort by activity level (approximated by state changes) participants = [] for pid, pdata in data.items(): activity = len(pdata.get('history', [])) participants.append((pid, pdata, activity)) participants.sort(key=lambda x: x[2], reverse=True) # Keep based on weight keep_count = max(1, int(len(participants) * weight)) return {pid: pdata for pid, pdata, _ in participants[:keep_count]} def _scale_intentional(self, data: Dict, weight: float) -> Dict: """Scale intentional context - keep high priority goals""" if not data: return {} # Sort by priority goals = [(k, v) for k, v in data.items()] goals.sort(key=lambda x: x[1].get('priority', 0), reverse=True) # Keep based on weight keep_count = max(1, int(len(goals) * weight)) return {k: v for k, v in goals[:keep_count]} def _scale_cultural(self, data: Dict, weight: float) -> Dict: """Scale cultural context - keep if important""" if weight < 0.2: return {} # Skip if low importance return data class ContextScalingOrchestrator: """ Main orchestrator for context scaling Combines length and multi-modal scaling """ def __init__(self, max_context_tokens: int = 100000): self.length_scaler = LengthScaler(max_context_tokens) self.multimodal_scaler = MultiModalScaler() def scale_complete_context( self, text_context: str, multimodal_context: Dict[str, Any], query: str, target_tokens: int = 2000, modality_weights: Optional[Dict[str, float]] = None ) -> Dict[str, Any]: """ Scale both text and multi-modal context Returns optimally scaled context """ # Scale text context scaled_text = self.length_scaler.scale_context( text_context, query, target_tokens ) # Scale multi-modal context scaled_multimodal = self.multimodal_scaler.scale_multimodal( multimodal_context, modality_weights ) # Combine result = { 'text': { 'segments': scaled_text.segments, 'attention_map': scaled_text.attention_map.tolist(), 'token_count': scaled_text.token_count, 'compression_level': scaled_text.compression_level }, 'multimodal': scaled_multimodal, 'metadata': { 'original_tokens': len(text_context.split()), 'scaled_tokens': scaled_text.token_count, 'compression_ratio': len(text_context.split()) / max(scaled_text.token_count, 1), 'modalities_preserved': list(scaled_multimodal.keys()) } } return result # Demo usage def demo_context_scaling(): """Demonstrate context scaling capabilities""" # Create a very long context long_context = " ".join([ f"Sentence {i} about various topics including AI, engineering, and software development." for i in range(10000) ]) # ~100k tokens # Multi-modal context multimodal = { 'temporal': [ {'event': f'Event {i}', 'timestamp': datetime.now()} for i in range(50) ], 'participant': { f'person_{i}': {'state': 'active', 'history': []} for i in range(20) }, 'intentional': { f'goal_{i}': {'priority': np.random.random()} for i in range(10) } } # Scale the context orchestrator = ContextScalingOrchestrator() scaled = orchestrator.scale_complete_context( text_context=long_context, multimodal_context=multimodal, query="AI engineering position requirements", target_tokens=2000 ) print(f"Scaling Results:") print(f"Original tokens: {scaled['metadata']['original_tokens']}") print(f"Scaled tokens: {scaled['metadata']['scaled_tokens']}") print(f"Compression ratio: {scaled['metadata']['compression_ratio']:.2f}x") print(f"Compression level: {scaled['text']['compression_level']}") print(f"Modalities preserved: {scaled['metadata']['modalities_preserved']}") print(f"Text segments: {len(scaled['text']['segments'])}") print(f"Temporal events kept: {len(scaled['multimodal'].get('temporal', []))}") if __name__ == "__main__": demo_context_scaling()