Spaces:
Runtime error
Runtime error
| """ | |
| Context Engineering System | |
| Implements the complete context engineering framework for optimal LLM performance | |
| Based on the three-step evolution: Retrieval/Generation → Processing → Management | |
| """ | |
| import json | |
| import logging | |
| from typing import Dict, List, Any, Optional, Tuple | |
| from datetime import datetime, timedelta | |
| from dataclasses import dataclass, field | |
| import hashlib | |
| from collections import deque | |
| import numpy as np | |
| from pathlib import Path | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class ContextChunk:
    """A unit of context with metadata.

    The class must be a dataclass: the fields below rely on
    ``field(default_factory=...)`` and instances are built with keyword
    arguments (``ContextChunk(content=..., source=..., ...)``), neither of
    which works on a plain class.

    Attributes:
        content: The chunk text.
        source: Name of the source the chunk came from.
        timestamp: When the chunk was created.
        relevance_score: Relevance of the chunk to a query (default 0.0).
        token_count: Number of tokens in ``content`` (default 0).
        embedding: Optional vector representation of the chunk.
        metadata: Free-form per-chunk dictionary; each instance gets its own.
        compression_ratio: Compressed length / original length (default 1.0).
        access_count: How many times ``update_access`` has been called.
        last_accessed: Time of the most recent ``update_access`` call.
    """

    content: str
    source: str
    timestamp: datetime
    relevance_score: float = 0.0
    token_count: int = 0
    # Excluded from the generated __eq__/__repr__: comparing numpy arrays
    # inside the dataclass tuple comparison raises "truth value of an array
    # is ambiguous", and dumping a whole vector makes the repr unreadable.
    embedding: Optional[np.ndarray] = field(default=None, compare=False, repr=False)
    metadata: Dict = field(default_factory=dict)
    compression_ratio: float = 1.0
    access_count: int = 0
    last_accessed: Optional[datetime] = None

    def update_access(self):
        """Update access statistics: bump the counter and stamp the time."""
        self.access_count += 1
        self.last_accessed = datetime.now()
class DataFlywheel:
    """
    NVIDIA's concept: continuous improvement through input/output pairing.

    Learns from successful context usage to optimize future retrievals.
    Every recorded success is kept in memory and persisted as JSON at
    ``storage_path``; a pattern cache maps a coarse key derived from the
    input text to the sources that were used for it.
    """

    def __init__(self, storage_path: str = "flywheel_data.json"):
        """Create the flywheel and load any previously persisted data.

        Args:
            storage_path: Location of the JSON file used for persistence.
        """
        self.storage_path = Path(storage_path)
        self.successful_contexts: List[Dict] = []
        self.feedback_pairs: List[Tuple[str, str, float]] = []  # (input, output, score)
        self.pattern_cache: Dict[str, List[str]] = {}
        self.load()

    def record_success(
        self,
        input_context: str,
        output: str,
        success_score: float,
        context_chunks: List["ContextChunk"]
    ):
        """Record successful context usage for learning.

        Args:
            input_context: Text the pattern key is derived from.
            output: The output that was produced.
            success_score: Score assigned to this interaction.
            context_chunks: Chunks that were used; only ``source`` and
                ``relevance_score`` are read from each.
        """
        sources = [c.source for c in context_chunks]
        # np.mean([]) yields nan (with a RuntimeWarning), and nan is written
        # to disk as the non-standard JSON token NaN. Use 0.0 for "no chunks".
        if context_chunks:
            avg_relevance = float(np.mean([c.relevance_score for c in context_chunks]))
        else:
            avg_relevance = 0.0

        self.successful_contexts.append({
            'timestamp': datetime.now().isoformat(),
            'input': input_context[:500],  # Truncate for storage
            'output': output[:500],
            # float() so numpy scalars (e.g. float32) cannot break json.dump
            'score': float(success_score),
            'chunks_used': sources,
            'avg_relevance': avg_relevance
        })

        # Update pattern cache
        key = self._generate_pattern_key(input_context)
        self.pattern_cache.setdefault(key, []).extend(sources)
        self.save()

    def get_recommended_sources(self, query: str) -> List[str]:
        """Get recommended context sources based on past successes.

        Returns:
            Up to five sources most frequently recorded under the pattern
            key of ``query``, most common first; an empty list if the key
            is unknown.
        """
        from collections import Counter

        sources = self.pattern_cache.get(self._generate_pattern_key(query))
        if not sources:
            return []
        return [source for source, _ in Counter(sources).most_common(5)]

    def _generate_pattern_key(self, text: str) -> str:
        """Generate the pattern key used for caching.

        The key is the first 8 hex digits of the MD5 of the sorted, unique,
        lower-cased first ten whitespace-separated words, so word order and
        case do not matter. MD5 is used only as a cheap fingerprint here.
        """
        keywords = sorted(set(text.lower().split()[:10]))
        return hashlib.md5('_'.join(keywords).encode()).hexdigest()[:8]

    def save(self):
        """Persist flywheel data (last 100 contexts, last 20 sources per pattern)."""
        data = {
            'successful_contexts': self.successful_contexts[-100:],
            'pattern_cache': {k: v[-20:] for k, v in self.pattern_cache.items()}
        }
        with open(self.storage_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def load(self):
        """Load flywheel data; a missing or unreadable file leaves state empty."""
        try:
            with open(self.storage_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return  # Nothing persisted yet.
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error loading flywheel data: {e}")
            return

        if not isinstance(data, dict):
            logger.error("Error loading flywheel data: expected a JSON object")
            return

        # Ignore sections with an unexpected shape instead of storing values
        # that would break record_success / get_recommended_sources later.
        contexts = data.get('successful_contexts', [])
        cache = data.get('pattern_cache', {})
        self.successful_contexts = contexts if isinstance(contexts, list) else []
        self.pattern_cache = cache if isinstance(cache, dict) else {}
class ContextProcessor:
    """
    Step 2: Process and refine raw context.

    Handles chunking, relevance scoring, and compression. Text is split on
    whitespace into words; a "token" here is one such word.
    """

    # Punctuation stripped from the ends of words before keyword matching,
    # so that "AI," in a document still matches "ai" in a query.
    _TOKEN_STRIP = '.,;:!?"\'()[]{}<>'

    def __init__(self, max_chunk_size: int = 500, overlap: int = 50):
        """Configure chunking.

        Args:
            max_chunk_size: Maximum number of words per chunk.
            overlap: Number of words shared by consecutive chunks.

        Raises:
            ValueError: If ``max_chunk_size`` is not positive or ``overlap``
                is not in ``[0, max_chunk_size)``. Such settings would give
                a zero or negative chunk step, which either fails inside
                ``range`` or silently drops all text.
        """
        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be positive")
        if not 0 <= overlap < max_chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < max_chunk_size")
        self.max_chunk_size = max_chunk_size
        self.overlap = overlap

    def process_context(
        self,
        raw_context: str,
        query: str,
        source: str = "unknown"
    ) -> List["ContextChunk"]:
        """Process raw context into optimized chunks.

        Args:
            raw_context: Text to split into chunks.
            query: Query the chunks are scored against.
            source: Label stored on every produced chunk.

        Returns:
            The chunks sorted by descending relevance score.
        """
        context_chunks = []
        for chunk_text in self._chunk_text(raw_context):
            chunk = ContextChunk(
                content=chunk_text,
                source=source,
                timestamp=datetime.now(),
                token_count=len(chunk_text.split()),
                relevance_score=self._calculate_relevance(chunk_text, query)
            )
            # Apply compression only to larger chunks
            if chunk.token_count > 100:
                chunk.content, chunk.compression_ratio = self._compress_text(chunk_text)
                # Keep the count in sync with the (possibly shorter) content.
                chunk.token_count = len(chunk.content.split())
            context_chunks.append(chunk)

        context_chunks.sort(key=lambda c: c.relevance_score, reverse=True)
        return context_chunks

    def _chunk_text(self, text: str) -> List[str]:
        """Split ``text`` into overlapping word windows.

        Consecutive chunks share ``overlap`` words. Iteration stops as soon
        as a chunk reaches the end of the text, so no trailing chunk made up
        solely of words already present in the previous chunk is emitted.
        """
        words = text.split()
        # Guard against attributes changed after construction.
        step = max(1, self.max_chunk_size - self.overlap)
        chunks = []
        start = 0
        while start < len(words):
            end = start + self.max_chunk_size
            chunks.append(' '.join(words[start:end]))
            if end >= len(words):
                break
            start += step
        return chunks

    def _tokenize(self, text: str) -> set:
        """Return the set of lower-cased words with edge punctuation removed."""
        stripped = (word.strip(self._TOKEN_STRIP) for word in text.lower().split())
        return {word for word in stripped if word}

    def _calculate_relevance(self, chunk: str, query: str) -> float:
        """Calculate relevance score between chunk and query.

        Simple keyword overlap (would use embeddings in production): the
        fraction of distinct query words that also occur in the chunk.
        Returns 0.0 when the query contains no words.
        """
        query_words = self._tokenize(query)
        if not query_words:
            return 0.0
        chunk_words = self._tokenize(chunk)
        return len(query_words & chunk_words) / len(query_words)

    def _compress_text(self, text: str) -> Tuple[str, float]:
        """Compress text by removing repeated sentences.

        Sentences are split at whitespace that follows ``.``, ``!`` or
        ``?``, so decimals such as ``3.14`` stay intact. Comparison is
        case-insensitive and ignores a trailing period. Abbreviations like
        "Dr." are still treated as sentence ends.

        Returns:
            ``(compressed_text, ratio)`` where ``ratio`` is compressed
            length / original length. If nothing was removed the original
            text is returned unchanged with a ratio of 1.0, so the ratio
            never exceeds 1.0.
        """
        import re

        if not text:
            return text, 1.0

        seen = set()
        kept = []
        for sentence in re.split(r'(?<=[.!?])\s+', text.strip()):
            sentence = sentence.strip()
            key = sentence.rstrip('.').strip().lower()
            if key and key not in seen:
                seen.add(key)
                kept.append(sentence)

        compressed = ' '.join(kept)
        if len(compressed) >= len(text):
            return text, 1.0
        return compressed, len(compressed) / len(text)
class MemoryHierarchy:
    """
    Hierarchical memory system with different levels.

    L1: Hot cache (chunks with relevance above 0.8)
    L2: Working memory (chunks with relevance above 0.5)
    L3: Long-term index persisted as JSON (a summary of every chunk)

    L1 and L2 are bounded deques, so the oldest entries are evicted first.
    """

    # Recency decays linearly to zero over this many hours (one week).
    _RECENCY_WINDOW_HOURS = 168

    def __init__(
        self,
        l1_size: int = 10,
        l2_size: int = 100,
        l3_path: str = "long_term_memory.json"
    ):
        """Create the memory levels and load the persisted L3 index.

        Args:
            l1_size: Maximum number of chunks kept in L1.
            l2_size: Maximum number of chunks kept in L2.
            l3_path: Location of the JSON file holding the L3 index.
        """
        self.l1_cache: deque = deque(maxlen=l1_size)  # Most recent/relevant
        self.l2_memory: deque = deque(maxlen=l2_size)  # Working memory
        self.l3_storage_path = Path(l3_path)
        self.l3_index: Dict[str, Dict] = {}  # Index for long-term storage
        self.load_l3()

    def add_context(self, chunk: "ContextChunk"):
        """Add context to the appropriate memory level.

        Chunks scoring above 0.8 go to L1, above 0.5 to L2; lower-scoring
        chunks are only indexed in L3. Every chunk is indexed in L3.
        """
        if chunk.relevance_score > 0.8:
            self.l1_cache.append(chunk)
        elif chunk.relevance_score > 0.5:
            self.l2_memory.append(chunk)
        self._add_to_l3(chunk)

    def retrieve(
        self,
        query: str,
        max_chunks: int = 10,
        recency_weight: float = 0.3
    ) -> List["ContextChunk"]:
        """Retrieve context from L1 and L2, ranked by relevance and recency.

        Args:
            query: Currently unused; ranking relies on the relevance score
                already stored on each chunk.
            max_chunks: Maximum number of chunks to return.
            recency_weight: Share of the combined score given to recency;
                the remainder is given to relevance.

        Returns:
            Up to ``max_chunks`` chunks, best combined score first. The
            score is stored in ``chunk.metadata['combined_score']`` and
            ``update_access`` is called on every returned chunk.
        """
        candidates = list(self.l1_cache)
        candidates.extend(self.l2_memory)

        now = datetime.now()
        for chunk in candidates:
            age_hours = (now - chunk.timestamp).total_seconds() / 3600
            # Clamp to [0, 1]: a timestamp in the future gives a negative
            # age, which would otherwise push the recency score above 1 and
            # let such a chunk outrank everything regardless of relevance.
            recency_score = min(1.0, max(0.0, 1 - age_hours / self._RECENCY_WINDOW_HOURS))
            chunk.metadata['combined_score'] = (
                chunk.relevance_score * (1 - recency_weight) +
                recency_score * recency_weight
            )

        candidates.sort(
            key=lambda c: c.metadata.get('combined_score', 0),
            reverse=True
        )

        selected = candidates[:max_chunks]
        for chunk in selected:
            chunk.update_access()
        return selected

    def _add_to_l3(self, chunk: "ContextChunk"):
        """Add a summary of ``chunk`` to the long-term storage index.

        Entries are keyed by the first 16 hex digits of the MD5 of the
        content, so re-adding identical content overwrites the entry.
        """
        key = hashlib.md5(chunk.content.encode()).hexdigest()[:16]
        self.l3_index[key] = {
            'source': chunk.source,
            'timestamp': chunk.timestamp.isoformat(),
            'relevance': chunk.relevance_score,
            'summary': chunk.content[:100],  # Store summary only
            'access_count': chunk.access_count
        }
        # Periodically save: whenever the index size is a multiple of ten.
        if len(self.l3_index) % 10 == 0:
            self.save_l3()

    def save_l3(self):
        """Save long-term memory to disk."""
        with open(self.l3_storage_path, 'w', encoding='utf-8') as f:
            json.dump(self.l3_index, f, indent=2)

    def load_l3(self):
        """Load long-term memory from disk; failures leave the index empty."""
        try:
            with open(self.l3_storage_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return  # Nothing persisted yet.
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error loading L3 memory: {e}")
            return

        if isinstance(data, dict):
            self.l3_index = data
        else:
            # A list or scalar would make _add_to_l3's item assignment fail.
            logger.error("Error loading L3 memory: expected a JSON object")
class MultiModalContext:
    """
    Holder for kinds of context that go beyond plain text:
    temporal, spatial, participant states, intentional, and cultural.
    """

    def __init__(self):
        """Start with every context store empty."""
        self.temporal_context: List[Dict] = []  # events ordered by timestamp
        self.spatial_context: Dict = {}  # location/geometry
        self.participant_states: Dict[str, Dict] = {}  # per-entity state tracking
        self.intentional_context: Dict = {}  # goals and motivations
        self.cultural_context: Dict = {}  # social/cultural nuances

    def add_temporal_context(
        self,
        event: str,
        timestamp: datetime,
        duration: Optional[timedelta] = None,
        related_events: List[str] = None
    ):
        """Record a timed event and keep the event list sorted by timestamp."""
        entry = {
            'event': event,
            'timestamp': timestamp,
            'duration': duration,
            'related': related_events or []
        }
        self.temporal_context.append(entry)
        self.temporal_context.sort(key=lambda item: item['timestamp'])

    def add_participant_state(
        self,
        participant_id: str,
        state: Dict,
        timestamp: Optional[datetime] = None
    ):
        """Set a participant's current state, archiving any previous one.

        The first state for a participant starts an empty history. Later
        calls push the replaced state onto the history, stamped with
        ``timestamp`` (or the current time when none is given).
        """
        record = self.participant_states.get(participant_id)
        if record is None:
            self.participant_states[participant_id] = {
                'current': state,
                'history': []
            }
            return

        archived = {
            'state': record['current'],
            'timestamp': timestamp or datetime.now()
        }
        record['history'].append(archived)
        record['current'] = state

    def add_intentional_context(
        self,
        goal: str,
        motivation: str,
        constraints: List[str] = None,
        priority: float = 0.5
    ):
        """Store a goal with its motivation, constraints, and priority."""
        details = {
            'motivation': motivation,
            'constraints': constraints or [],
            'priority': priority,
            'added': datetime.now()
        }
        self.intentional_context[goal] = details

    def get_multimodal_summary(self) -> Dict:
        """Return counts and presence flags for each context type."""
        summary = {}
        summary['temporal_events'] = len(self.temporal_context)
        summary['tracked_participants'] = len(self.participant_states)
        summary['active_goals'] = len(self.intentional_context)
        summary['has_spatial'] = bool(self.spatial_context)
        summary['has_cultural'] = bool(self.cultural_context)
        return summary
class ContextEngineer:
    """
    Main context engineering orchestrator.

    Implements the complete 3-step framework by wiring together a flywheel
    (source recommendations and feedback), a processor (chunking and
    scoring), a memory hierarchy (storage and ranking) and a multimodal
    context store.
    """

    def __init__(
        self,
        flywheel: Optional["DataFlywheel"] = None,
        processor: Optional["ContextProcessor"] = None,
        memory: Optional["MemoryHierarchy"] = None,
        multimodal: Optional["MultiModalContext"] = None
    ):
        """Create the orchestrator.

        Each collaborator may be supplied by the caller (for example to use
        custom storage paths); any that is omitted is built with its
        default constructor, which matches the previous no-argument
        behaviour.
        """
        self.flywheel = flywheel if flywheel is not None else DataFlywheel()
        self.processor = processor if processor is not None else ContextProcessor()
        self.memory = memory if memory is not None else MemoryHierarchy()
        self.multimodal = multimodal if multimodal is not None else MultiModalContext()

    def engineer_context(
        self,
        query: str,
        raw_sources: List[Tuple[str, str]],  # (source_name, content)
        multimodal_data: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """
        Complete context engineering pipeline.

        Step 1: Retrieval & Generation - ask the flywheel which sources
                worked before and give those a 2x relevance boost.
        Step 2: Processing - chunk and score every source, then store the
                chunks in the memory hierarchy.
        Step 3: Management - retrieve the top chunks from memory and build
                the final context.

        Args:
            query: The request the context is being built for.
            raw_sources: ``(source_name, content)`` pairs.
            multimodal_data: Optional dict with any of the keys
                ``'temporal'`` (list of keyword dicts for
                ``add_temporal_context``), ``'participants'`` (id -> state)
                and ``'goals'`` (goal -> keyword dict for
                ``add_intentional_context``). Other keys are ignored.

        Returns:
            A dict with ``'query'``, ``'primary_context'`` (top five
            chunks), ``'supporting_context'`` (next five), ``'metadata'``
            and ``'chunks'`` (the selected chunk objects, for feedback).
        """
        # Step 1: Retrieval & Generation
        recommended = set(self.flywheel.get_recommended_sources(query))

        # Step 2: Processing
        all_chunks = []
        for source_name, content in raw_sources:
            priority = 2.0 if source_name in recommended else 1.0
            chunks = self.processor.process_context(content, query, source_name)
            for chunk in chunks:
                chunk.relevance_score *= priority  # Apply priority boost
            all_chunks.extend(chunks)

        for chunk in all_chunks:
            self.memory.add_context(chunk)

        # Step 3: Management
        final_chunks = self.memory.retrieve(query, max_chunks=10)

        if multimodal_data:
            self._ingest_multimodal(multimodal_data)

        if final_chunks:
            # float() keeps numpy scalars out of the returned metadata.
            avg_relevance = float(np.mean([c.relevance_score for c in final_chunks]))
            compression_ratio = float(np.mean([c.compression_ratio for c in final_chunks]))
        else:
            avg_relevance = 0
            compression_ratio = 1

        return {
            # Kept so record_feedback can key the flywheel on the same text
            # that get_recommended_sources is later queried with.
            'query': query,
            'primary_context': '\n\n'.join([c.content for c in final_chunks[:5]]),
            'supporting_context': '\n'.join([c.content for c in final_chunks[5:10]]),
            'metadata': {
                'total_chunks': len(all_chunks),
                'selected_chunks': len(final_chunks),
                'avg_relevance': avg_relevance,
                'compression_ratio': compression_ratio,
                # Sorted so the output is deterministic across runs.
                'sources_used': sorted(set(c.source for c in final_chunks)),
                'multimodal': self.multimodal.get_multimodal_summary()
            },
            'chunks': final_chunks  # For feedback loop
        }

    def _ingest_multimodal(self, multimodal_data: Dict):
        """Feed the recognised sections of ``multimodal_data`` into the store."""
        for key, value in multimodal_data.items():
            if key == 'temporal':
                for event in value:
                    self.multimodal.add_temporal_context(**event)
            elif key == 'participants':
                for pid, state in value.items():
                    self.multimodal.add_participant_state(pid, state)
            elif key == 'goals':
                for goal, details in value.items():
                    self.multimodal.add_intentional_context(goal, **details)

    def record_feedback(
        self,
        context: Dict,
        output: str,
        success_score: float
    ):
        """Record feedback for continuous improvement.

        The flywheel derives its pattern key from the first argument and is
        later looked up by query text, so the query must be used here;
        keying on the assembled context would never match a future query.
        Contexts without a ``'query'`` entry fall back to
        ``'primary_context'``.

        Args:
            context: A dict as returned by ``engineer_context``.
            output: The output generated from that context.
            success_score: Score assigned to the output.
        """
        pattern_input = context.get('query') or context['primary_context']
        self.flywheel.record_success(
            pattern_input,
            output,
            success_score,
            context['chunks']
        )

    def optimize_memory(self):
        """Optimize memory by removing low-value chunks (not implemented yet).

        Intended to prune based on access frequency, age, relevance scores
        and compression potential; currently a no-op.
        """
        pass
| # Demo usage | |
def demo_context_engineering():
    """Run a small end-to-end example: build context, print it, record feedback."""
    pipeline = ContextEngineer()

    # Sample (source_name, content) pairs
    raw_sources = [
        ("resume", "10 years experience in Python, AI, Machine Learning..."),
        ("job_description", "Looking for senior AI engineer with Python skills..."),
        ("company_research", "TechCorp is a leading AI company focused on NLP...")
    ]

    # Multimodal context, assembled section by section
    deadline_event = {
        'event': 'Application deadline',
        'timestamp': datetime.now() + timedelta(days=7)
    }
    participant_states = {
        'applicant': {'status': 'preparing', 'confidence': 0.8}
    }
    goals = {
        'get_interview': {
            'motivation': 'Career advancement',
            'constraints': ['Remote only'],
            'priority': 0.9
        }
    }
    extra_context = {
        'temporal': [deadline_event],
        'participants': participant_states,
        'goals': goals
    }

    # Engineer context
    result = pipeline.engineer_context(
        query="Write a cover letter for AI engineer position",
        raw_sources=raw_sources,
        multimodal_data=extra_context
    )

    print("Engineered Context:")
    primary_preview = result['primary_context'][:200]
    print(f"Primary: {primary_preview}...")
    print(f"Metadata: {result['metadata']}")

    # Simulate success and record feedback
    pipeline.record_feedback(result, "Generated cover letter...", 0.9)
    print("\nFlywheel learned patterns for future use!")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo_context_engineering()