Spaces:
Runtime error
Runtime error
"""
Context Scaling System

Handles length scaling (millions of tokens) and multi-modal/structural scaling.
Implements advanced attention methods and memory techniques from the article.
"""
import heapq
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)
# eq=False keeps the default identity-based equality/hash: a generated __eq__
# would compare the ndarray field element-wise and raise on truth-testing.
@dataclass(eq=False)
class ScaledContext:
    """Context that can scale to millions of tokens.

    Plain record produced by the scaling code. It is declared as a dataclass
    so that it can be built with keyword arguments, e.g.
    ``ScaledContext(segments=[...], attention_map=..., token_count=...,
    compression_level=..., modalities={})``; without the decorator the class
    has only annotations and no ``__init__`` accepting those arguments.
    """

    segments: List[str]  # Segmented content
    attention_map: np.ndarray  # Attention weights for segments
    token_count: int
    compression_level: int  # 0=none, 1=light, 2=medium, 3=heavy
    modalities: Dict[str, Any]  # Different context modalities
class AttentionOptimizer:
    """
    Advanced attention methods for handling extremely long contexts.

    Implements sliding window, sparse attention, and hierarchical attention.
    Throughout this class a "token" is a whitespace-separated word and
    relevance is plain word overlap with the query; no model is involved.
    """

    def __init__(self, window_size: int = 512, stride: int = 256):
        """
        Args:
            window_size: Number of tokens in each sliding window.
            stride: Token distance between the starts of consecutive windows.
        """
        self.window_size = window_size
        self.stride = stride

    def sliding_window_attention(
        self,
        context: str,
        query: str,
        max_windows: int = 10
    ) -> List[Tuple[str, float]]:
        """
        Process context using sliding window attention.

        Windows start every ``stride`` tokens. A context shorter than
        ``window_size`` yields one window holding the whole context, and a
        final window is added when the regular stride would otherwise leave
        trailing tokens uncovered.

        Args:
            context: Text to split into windows.
            query: Text whose words are matched against each window.
            max_windows: Maximum number of windows to return; values <= 0
                return an empty list.

        Returns:
            ``(window_text, score)`` pairs sorted by descending score; ties
            keep their original left-to-right order.
        """
        tokens = context.split()
        if not tokens:
            return []

        # Start index of the last window that still fits inside the context
        # (0 when the whole context is shorter than one window).
        last_start = max(len(tokens) - self.window_size, 0)
        starts = list(range(0, last_start + 1, self.stride))
        if not starts or starts[-1] != last_start:
            # The stride did not land on the end: add a window covering the tail.
            starts.append(last_start)

        windows = []
        for start in starts:
            window = ' '.join(tokens[start:start + self.window_size])
            windows.append((window, self._calculate_attention_score(window, query)))

        # Return top windows
        windows.sort(key=lambda pair: pair[1], reverse=True)
        return windows[:max(max_windows, 0)]

    def hierarchical_attention(
        self,
        context: str,
        query: str,
        levels: int = 3
    ) -> Dict[int, List[str]]:
        """
        Multi-level hierarchical attention.

        Higher levels = more compressed/abstract. Level 0 holds 500-word
        segments, level 1 key sentences, and every further level a one-item
        list with a summary. Each level is computed from the joined output of
        the level before it.

        Args:
            context: Text to build the hierarchy from.
            query: Accepted for interface symmetry; not used by this method.
            levels: Number of levels to build; values <= 0 give an empty dict.

        Returns:
            Mapping from level index to that level's list of text pieces.
        """
        hierarchy = {}
        current_text = context
        for level in range(levels):
            if level == 0:
                # Finest level - full detail
                hierarchy[level] = self._segment_text(current_text, 500)
            elif level == 1:
                # Middle level - paragraphs/sections
                hierarchy[level] = self._extract_key_sentences(current_text)
            else:
                # Highest level - summary
                hierarchy[level] = [self._generate_summary(current_text)]
            # Compress for next level
            current_text = ' '.join(hierarchy[level])
        return hierarchy

    def sparse_attention(
        self,
        context: str,
        query: str,
        sparsity: float = 0.1
    ) -> List[str]:
        """
        Sparse attention - only keep the most relevant tokens.

        (The technique this imitates reduces attention cost from O(n²) to
        O(n*k); here it is a simple top-k token filter.)

        Tokens that appear in the query (case-insensitively) score 1.0; all
        other tokens get a random score below 0.5, so which non-matching
        tokens survive differs between calls.

        Args:
            context: Text to filter.
            query: Text whose words mark tokens as relevant.
            sparsity: Fraction of tokens to keep (``int(n * sparsity)``).

        Returns:
            Kept tokens in original order, with ``"..."`` inserted wherever
            one or more tokens were skipped before a kept token.
        """
        tokens = context.split()
        query_tokens = set(query.lower().split())

        # Calculate relevance for each token
        token_scores = []
        for i, token in enumerate(tokens):
            score = 1.0 if token.lower() in query_tokens else np.random.random() * 0.5
            token_scores.append((i, token, score))

        # Keep only top k% tokens
        k = int(len(tokens) * sparsity)
        top_tokens = heapq.nlargest(k, token_scores, key=lambda item: item[2])
        # Sort by original position to maintain order
        top_tokens.sort(key=lambda item: item[0])

        # Reconstruct sparse context
        sparse_context = []
        last_idx = -1
        for idx, token, _score in top_tokens:
            if idx > last_idx + 1:
                sparse_context.append("...")
            sparse_context.append(token)
            last_idx = idx
        return sparse_context

    def _calculate_attention_score(self, window: str, query: str) -> float:
        """Return the fraction of distinct query words present in the window.

        Comparison is case-insensitive; an empty query scores 0.0.
        """
        window_words = set(window.lower().split())
        query_words = set(query.lower().split())
        if not query_words:
            return 0.0
        overlap = len(window_words & query_words)
        return overlap / len(query_words)

    def _segment_text(self, text: str, segment_size: int) -> List[str]:
        """Split text into consecutive chunks of at most ``segment_size`` words."""
        words = text.split()
        return [
            ' '.join(words[i:i + segment_size])
            for i in range(0, len(words), segment_size)
        ]

    def _extract_key_sentences(self, text: str) -> List[str]:
        """Extract key sentences (simplified).

        Sentences are the pieces between ``.`` characters; only those with
        more than 10 words are kept, and at most the first 10 are returned.
        """
        sentences = text.split('.')
        # Keep sentences with more than 10 words (likely more informative)
        key_sentences = [s.strip() + '.' for s in sentences if len(s.split()) > 10]
        return key_sentences[:10]  # Top 10 sentences

    def _generate_summary(self, text: str) -> str:
        """Generate summary (simplified - would use LLM in production).

        Returns the first three non-empty sentences joined by ``". "`` and
        terminated with a period, or an empty string when there are none.
        """
        # Strip each piece and drop empties so the result has no doubled
        # spaces and no dangling " ." from the text's own final period.
        sentences = [piece.strip() for piece in text.split('.')]
        leading = [sentence for sentence in sentences if sentence][:3]
        if not leading:
            return ''
        return '. '.join(leading) + '.'
class LengthScaler:
    """
    Handle context scaling from thousands to millions of tokens.

    Picks a compression strategy from the ratio between the context's size
    and the requested target. A "token" is a whitespace-separated word.
    """

    def __init__(self, max_tokens: int = 1000000):
        """
        Args:
            max_tokens: Stored on the instance; not consulted by the methods
                of this class.
        """
        self.max_tokens = max_tokens
        self.attention_optimizer = AttentionOptimizer()

    def scale_context(
        self,
        context: str,
        query: str,
        target_tokens: int = 2000
    ) -> "ScaledContext":
        """Scale context to target token count while maintaining relevance.

        Args:
            context: Text to compress.
            query: Text used to judge which parts of the context are relevant.
            target_tokens: Desired size of the result in tokens; must be > 0.

        Returns:
            A ``ScaledContext`` whose ``compression_level`` is 0 when the
            context already fits, 1 for a ratio below 5, 2 for a ratio below
            20 and 3 otherwise.

        Raises:
            ValueError: If ``target_tokens`` is not positive.
        """
        if target_tokens <= 0:
            # Guard the division below, which would otherwise fail with
            # ZeroDivisionError for 0.
            raise ValueError("target_tokens must be a positive integer")

        current_tokens = len(context.split())

        # Determine compression level needed
        compression_ratio = current_tokens / target_tokens
        if compression_ratio <= 1:
            # No compression needed
            return ScaledContext(
                segments=[context],
                attention_map=np.array([1.0]),
                token_count=current_tokens,
                compression_level=0,
                modalities={}
            )

        # Apply appropriate scaling strategy
        if compression_ratio < 5:
            # Light compression - sliding window
            segments = self._light_compression(context, query, target_tokens)
            compression_level = 1
        elif compression_ratio < 20:
            # Medium compression - hierarchical
            segments = self._medium_compression(context, query, target_tokens)
            compression_level = 2
        else:
            # Heavy compression - sparse attention
            segments = self._heavy_compression(context, query, target_tokens)
            compression_level = 3

        # Calculate attention map
        attention_map = self._calculate_attention_map(segments, query)
        return ScaledContext(
            segments=segments,
            attention_map=attention_map,
            token_count=sum(len(s.split()) for s in segments),
            compression_level=compression_level,
            modalities={}
        )

    def _light_compression(
        self,
        context: str,
        query: str,
        target_tokens: int
    ) -> List[str]:
        """Light compression using sliding windows.

        Requests only as many windows as the token budget can hold and trims
        the last one, so the result never exceeds ``target_tokens``.
        """
        window_size = max(self.attention_optimizer.window_size, 1)
        # Ceiling division: enough windows to fill the budget, at least one.
        max_windows = max(1, -(-target_tokens // window_size))
        windows = self.attention_optimizer.sliding_window_attention(
            context, query, max_windows=max_windows
        )

        segments = []
        remaining_tokens = target_tokens
        for window, _score in windows:
            if remaining_tokens <= 0:
                break
            words = window.split()[:remaining_tokens]
            segments.append(' '.join(words))
            remaining_tokens -= len(words)
        return segments

    def _medium_compression(
        self,
        context: str,
        query: str,
        target_tokens: int
    ) -> List[str]:
        """Medium compression using hierarchical attention.

        Walks the hierarchy from level 0 upward and keeps every segment that
        still fits in the remaining token budget.
        """
        hierarchy = self.attention_optimizer.hierarchical_attention(context, query)
        segments = []
        remaining_tokens = target_tokens

        # Add from each level based on available tokens
        for level in sorted(hierarchy):
            for segment in hierarchy[level]:
                if remaining_tokens <= 0:
                    # Budget spent: stop scanning all remaining levels.
                    return segments
                segment_tokens = len(segment.split())
                if segment_tokens <= remaining_tokens:
                    segments.append(segment)
                    remaining_tokens -= segment_tokens
        return segments

    def _heavy_compression(
        self,
        context: str,
        query: str,
        target_tokens: int
    ) -> List[str]:
        """Heavy compression using sparse attention.

        Keeps a ``target_tokens / total`` fraction of tokens (capped at 0.3)
        and groups consecutive kept tokens into segments separated by
        ``"..."`` marker segments.
        """
        total_tokens = len(context.split())
        if total_tokens == 0:
            # Nothing to compress; also avoids dividing by zero below.
            return []

        sparsity = target_tokens / total_tokens
        sparse_tokens = self.attention_optimizer.sparse_attention(
            context, query, sparsity=min(sparsity, 0.3)
        )

        # Group sparse tokens into segments
        segments = []
        current_segment = []
        for token in sparse_tokens:
            if token == "...":
                if current_segment:
                    segments.append(' '.join(current_segment))
                    current_segment = []
                segments.append("...")
            else:
                current_segment.append(token)
        if current_segment:
            segments.append(' '.join(current_segment))
        return segments

    def _calculate_attention_map(
        self,
        segments: List[str],
        query: str
    ) -> np.ndarray:
        """Calculate attention weights for each segment.

        Each weight is the fraction of distinct query words found in the
        segment (case-insensitive); ``"..."`` markers get 0. Weights are
        normalised to sum to 1 unless they are all zero.
        """
        query_words = set(query.lower().split())
        attention_scores = []
        for segment in segments:
            if segment == "...":
                attention_scores.append(0.0)
            else:
                segment_words = set(segment.lower().split())
                overlap = len(query_words & segment_words)
                attention_scores.append(overlap / max(len(query_words), 1))

        # Normalize
        scores = np.array(attention_scores)
        if scores.sum() > 0:
            scores = scores / scores.sum()
        return scores
class MultiModalScaler:
    """
    Scaler for the non-text ("multi-modal"/structural) parts of a context.

    Supported modality names: temporal, spatial, participant, intentional
    and cultural. Each one is reduced by its own handler according to an
    importance weight.
    """

    def __init__(self):
        # Dispatch table: modality name -> bound handler(data, weight).
        self.modality_handlers = {
            'temporal': self._scale_temporal,
            'spatial': self._scale_spatial,
            'participant': self._scale_participant,
            'intentional': self._scale_intentional,
            'cultural': self._scale_cultural
        }

    def scale_multimodal(
        self,
        modalities: Dict[str, Any],
        importance_weights: Optional[Dict[str, float]] = None
    ) -> Dict[str, Any]:
        """Reduce every known modality according to its importance weight.

        Modalities without a registered handler are left out of the result.
        A modality missing from ``importance_weights`` is weighted 0.1; when
        no weights are passed at all, a built-in default table is used.
        """
        weights = importance_weights
        if weights is None:
            weights = {
                'temporal': 0.3,
                'spatial': 0.1,
                'participant': 0.3,
                'intentional': 0.2,
                'cultural': 0.1
            }

        handlers = self.modality_handlers
        return {
            name: handlers[name](payload, weights.get(name, 0.1))
            for name, payload in modalities.items()
            if name in handlers
        }

    def _scale_temporal(self, data: List[Dict], weight: float) -> List[Dict]:
        """Return the newest events; a larger weight keeps more (minimum 1).

        Events lacking a 'timestamp' key are ordered as ``datetime.min``,
        i.e. last.
        """
        newest_first = list(data)
        newest_first.sort(
            key=lambda event: event.get('timestamp', datetime.min),
            reverse=True,
        )
        limit = max(1, int(len(newest_first) * weight))
        return newest_first[:limit]

    def _scale_spatial(self, data: Dict, weight: float) -> Dict:
        """Return full spatial data, or only the primary location if weight < 0.3."""
        if weight < 0.3:
            # Low importance: collapse to the basic location only.
            return {'location': data.get('primary_location', 'unknown')}
        return data

    def _scale_participant(self, data: Dict, weight: float) -> Dict:
        """Keep the participants with the longest 'history' lists.

        The fraction kept is given by ``weight``; at least one participant is
        kept whenever ``data`` is non-empty.
        """
        if not data:
            return {}

        # History length serves as the activity measure.
        by_activity = sorted(
            data.items(),
            key=lambda entry: len(entry[1].get('history', [])),
            reverse=True,
        )
        limit = max(1, int(len(by_activity) * weight))
        return dict(by_activity[:limit])

    def _scale_intentional(self, data: Dict, weight: float) -> Dict:
        """Keep the goals with the highest 'priority' (missing priority = 0).

        The fraction kept is given by ``weight``; at least one goal is kept
        whenever ``data`` is non-empty.
        """
        if not data:
            return {}

        by_priority = sorted(
            data.items(),
            key=lambda entry: entry[1].get('priority', 0),
            reverse=True,
        )
        limit = max(1, int(len(by_priority) * weight))
        return dict(by_priority[:limit])

    def _scale_cultural(self, data: Dict, weight: float) -> Dict:
        """Return the cultural data unchanged, or nothing if weight < 0.2."""
        return {} if weight < 0.2 else data
class ContextScalingOrchestrator:
    """
    Entry point that combines text-length scaling with multi-modal scaling.
    """

    def __init__(self, max_context_tokens: int = 100000):
        """Create the two scalers; ``max_context_tokens`` goes to the length scaler."""
        self.length_scaler = LengthScaler(max_context_tokens)
        self.multimodal_scaler = MultiModalScaler()

    def scale_complete_context(
        self,
        text_context: str,
        multimodal_context: Dict[str, Any],
        query: str,
        target_tokens: int = 2000,
        modality_weights: Optional[Dict[str, float]] = None
    ) -> Dict[str, Any]:
        """
        Scale the text and the multi-modal context and bundle the results.

        Returns a dict with three keys:
            'text'       - segments, attention map (as a list), token count
                           and compression level of the scaled text
            'multimodal' - the scaled modalities
            'metadata'   - original/scaled token counts, their ratio, and
                           the names of the modalities that were kept
        """
        text_result = self.length_scaler.scale_context(
            text_context, query, target_tokens
        )
        modal_result = self.multimodal_scaler.scale_multimodal(
            multimodal_context, modality_weights
        )

        text_part = {
            'segments': text_result.segments,
            'attention_map': text_result.attention_map.tolist(),
            'token_count': text_result.token_count,
            'compression_level': text_result.compression_level
        }

        original_tokens = len(text_context.split())
        scaled_tokens = text_result.token_count
        metadata = {
            'original_tokens': original_tokens,
            'scaled_tokens': scaled_tokens,
            # max(..., 1) guards against an empty scaled result.
            'compression_ratio': original_tokens / max(scaled_tokens, 1),
            'modalities_preserved': list(modal_result.keys())
        }

        return {
            'text': text_part,
            'multimodal': modal_result,
            'metadata': metadata
        }
# Demo usage
def demo_context_scaling():
    """Run the orchestrator on synthetic data and print a short report."""
    # Synthetic long text: 10000 sentences of 11 words each (~110k tokens).
    long_context = " ".join(
        f"Sentence {i} about various topics including AI, engineering, and software development."
        for i in range(10000)
    )

    # Synthetic multi-modal context, built one modality at a time.
    temporal_events = [
        {'event': f'Event {i}', 'timestamp': datetime.now()}
        for i in range(50)
    ]
    participants = {
        f'person_{i}': {'state': 'active', 'history': []}
        for i in range(20)
    }
    goals = {
        f'goal_{i}': {'priority': np.random.random()}
        for i in range(10)
    }
    multimodal = {
        'temporal': temporal_events,
        'participant': participants,
        'intentional': goals
    }

    # Scale the context
    orchestrator = ContextScalingOrchestrator()
    scaled = orchestrator.scale_complete_context(
        text_context=long_context,
        multimodal_context=multimodal,
        query="AI engineering position requirements",
        target_tokens=2000
    )

    metadata = scaled['metadata']
    text = scaled['text']
    report_lines = [
        "Scaling Results:",
        f"Original tokens: {metadata['original_tokens']}",
        f"Scaled tokens: {metadata['scaled_tokens']}",
        f"Compression ratio: {metadata['compression_ratio']:.2f}x",
        f"Compression level: {text['compression_level']}",
        f"Modalities preserved: {metadata['modalities_preserved']}",
        f"Text segments: {len(text['segments'])}",
        f"Temporal events kept: {len(scaled['multimodal'].get('temporal', []))}",
    ]
    for line in report_lines:
        print(line)


if __name__ == "__main__":
    demo_context_scaling()