""" Context shaping module for optimizing retrieved chunks. Performs: - Deduplication: Remove semantically similar chunks - Token budgeting: Allocate tokens based on relevance - Pruning: Remove irrelevant sentences within chunks - Compression: Summarize if over budget """ from dataclasses import dataclass from typing import List, Dict, Any, Optional, Tuple import re # Lazy imports _sentence_model = None @dataclass class ContextShapeResult: """Result of context shaping.""" chunks: List[Dict[str, Any]] original_tokens: int final_tokens: int chunks_removed: int compression_applied: bool def _estimate_tokens(text: str) -> int: """Estimate token count (rough: 1 token ≈ 4 chars).""" return len(text) // 4 def _get_sentence_model(): """Lazy load sentence transformer for similarity.""" global _sentence_model if _sentence_model is None: try: from sentence_transformers import SentenceTransformer _sentence_model = SentenceTransformer("all-MiniLM-L6-v2") except ImportError: return None return _sentence_model def _compute_similarity(text1: str, text2: str) -> float: """Compute cosine similarity between two texts.""" model = _get_sentence_model() if model is None: return 0.0 try: import numpy as np embeddings = model.encode([text1, text2]) cos_sim = np.dot(embeddings[0], embeddings[1]) / ( np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]) ) return float(cos_sim) except Exception: return 0.0 def _split_sentences(text: str) -> List[str]: """Split text into sentences.""" # Simple sentence splitter sentences = re.split(r'(?<=[.!?])\s+', text) return [s.strip() for s in sentences if s.strip()] def deduplicate_chunks( chunks: List[Dict[str, Any]], threshold: float = 0.85 ) -> Tuple[List[Dict[str, Any]], int]: """ Remove chunks with high semantic similarity. Args: chunks: List of chunks threshold: Similarity threshold for deduplication Returns: Tuple of (deduplicated chunks, count removed) """ if len(chunks) <= 1: return chunks, 0 # Keep track of which chunks to keep keep_indices = [] removed = 0 for i, chunk in enumerate(chunks): text_i = chunk.get("text", "") is_duplicate = False # Compare with already kept chunks for j in keep_indices: text_j = chunks[j].get("text", "") similarity = _compute_similarity(text_i, text_j) if similarity >= threshold: is_duplicate = True removed += 1 break if not is_duplicate: keep_indices.append(i) return [chunks[i] for i in keep_indices], removed def budget_chunks( chunks: List[Dict[str, Any]], token_budget: int, min_tokens_per_chunk: int = 50 ) -> List[Dict[str, Any]]: """ Allocate token budget across chunks based on relevance scores. 
def budget_chunks(
    chunks: List[Dict[str, Any]],
    token_budget: int,
    min_tokens_per_chunk: int = 50,
) -> List[Dict[str, Any]]:
    """
    Allocate token budget across chunks based on relevance scores.

    Args:
        chunks: List of chunks with scores
        token_budget: Total token budget
        min_tokens_per_chunk: Minimum tokens to keep per chunk

    Returns:
        Chunks with text trimmed to fit budget
    """
    if not chunks:
        return []

    # Calculate total relevance for weighting.
    total_score = sum(c.get("score", 0.5) for c in chunks)
    if total_score == 0:
        total_score = len(chunks)  # Equal weight.

    budgeted = []
    remaining_budget = token_budget

    for chunk in chunks:
        text = chunk.get("text", "")
        score = chunk.get("score", 0.5)

        # Allocate budget proportionally to score, with a per-chunk floor,
        # capped at whatever budget remains.
        chunk_budget = int((score / total_score) * token_budget)
        chunk_budget = max(chunk_budget, min_tokens_per_chunk)
        chunk_budget = min(chunk_budget, remaining_budget)

        if chunk_budget <= 0:
            continue

        # Trim text if needed: truncate to the budget, cutting at a word boundary.
        if _estimate_tokens(text) > chunk_budget:
            char_limit = chunk_budget * 4
            text = text[:char_limit].rsplit(" ", 1)[0] + "..."

        new_chunk = chunk.copy()
        new_chunk["text"] = text
        new_chunk["budget_allocated"] = chunk_budget
        budgeted.append(new_chunk)

        remaining_budget -= _estimate_tokens(text)
        if remaining_budget <= 0:
            break

    return budgeted
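# Worked example of the proportional allocation above (hypothetical numbers):
# with token_budget=300 and chunk scores [0.6, 0.3, 0.1] (total 1.0), the raw
# shares are int(0.6 * 300) = 180, int(0.3 * 300) = 90, and int(0.1 * 300) = 30
# tokens; the last share is then raised to min_tokens_per_chunk (50), and every
# share is capped by the budget still remaining, so low-score chunks late in
# the list may be truncated or skipped once the budget runs out.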
def prune_irrelevant_sentences(
    chunk: Dict[str, Any],
    query: str,
    relevance_threshold: float = 0.3,
) -> Dict[str, Any]:
    """
    Remove sentences within a chunk that are not relevant to the query.

    Args:
        chunk: Chunk to prune
        query: Query for relevance comparison
        relevance_threshold: Minimum similarity to keep sentence

    Returns:
        Chunk with irrelevant sentences removed
    """
    text = chunk.get("text", "")
    if not text:
        return chunk

    # Without an embedding model, _compute_similarity always returns 0.0 and
    # every scored sentence would be dropped; skip pruning in that case.
    if _get_sentence_model() is None:
        return chunk

    sentences = _split_sentences(text)
    if len(sentences) <= 1:
        return chunk

    # Score each sentence against the query.
    relevant_sentences = []
    for sentence in sentences:
        if len(sentence) < 10:  # Keep short fragments.
            relevant_sentences.append(sentence)
            continue

        if _compute_similarity(query, sentence) >= relevance_threshold:
            relevant_sentences.append(sentence)

    if not relevant_sentences:
        # Keep at least the first sentence.
        relevant_sentences = sentences[:1]

    new_chunk = chunk.copy()
    new_chunk["text"] = " ".join(relevant_sentences)
    new_chunk["sentences_pruned"] = len(sentences) - len(relevant_sentences)
    return new_chunk


def compress_with_llm(
    chunks: List[Dict[str, Any]],
    query: str,
    target_tokens: int,
) -> List[Dict[str, Any]]:
    """
    Compress chunks using LLM summarization.

    Args:
        chunks: Chunks to compress
        query: Query for context-aware compression
        target_tokens: Target token count

    Returns:
        Compressed chunks, or the original chunks if compression is
        unavailable or fails
    """
    try:
        from src.llm_providers import call_llm
    except ImportError:
        return chunks

    # Combine all chunk texts before summarizing.
    combined = "\n\n".join(c.get("text", "") for c in chunks)
    if _estimate_tokens(combined) <= target_tokens:
        return chunks

    prompt = f"""Summarize the following context to approximately {target_tokens * 4} characters.
Preserve all key facts relevant to this query: {query}
Keep specific names, numbers, and dates.

Context:
{combined}

Summary:"""

    try:
        response = call_llm(prompt=prompt, temperature=0.0, max_tokens=target_tokens)
        summary = response.get("text", "").strip()
        if not summary:
            return chunks
        # Return a single compressed chunk carrying the best original score.
        return [{
            "id": "compressed_context",
            "text": summary,
            "score": max(c.get("score", 0) for c in chunks),
            "compressed_from": len(chunks),
        }]
    except Exception:
        return chunks


def shape_context(
    chunks: List[Dict[str, Any]],
    query: str,
    token_budget: int = 3000,
    dedup_threshold: float = 0.85,
    enable_pruning: bool = True,
    enable_compression: bool = True,
    relevance_threshold: float = 0.3,
) -> ContextShapeResult:
    """
    Shape context by deduplicating, pruning, compressing, and budgeting chunks.

    Args:
        chunks: Retrieved chunks
        query: User query for relevance
        token_budget: Maximum tokens for final context
        dedup_threshold: Similarity threshold for deduplication
        enable_pruning: Whether to prune irrelevant sentences
        enable_compression: Whether to compress if over budget
        relevance_threshold: Minimum relevance for sentence pruning

    Returns:
        ContextShapeResult with shaped chunks and metadata
    """
    if not chunks:
        return ContextShapeResult(
            chunks=[],
            original_tokens=0,
            final_tokens=0,
            chunks_removed=0,
            compression_applied=False,
        )

    # Calculate original token count.
    original_tokens = sum(_estimate_tokens(c.get("text", "")) for c in chunks)

    # Step 1: Deduplicate.
    deduped, removed = deduplicate_chunks(chunks, threshold=dedup_threshold)

    # Step 2: Prune irrelevant sentences (optional).
    if enable_pruning:
        deduped = [
            prune_irrelevant_sentences(c, query, relevance_threshold)
            for c in deduped
        ]

    # Step 3: Compress if still well over budget (20% slack). This must be
    # checked before budgeting: budget_chunks enforces the budget by
    # truncation, so measuring afterwards would never trigger compression.
    current_tokens = sum(_estimate_tokens(c.get("text", "")) for c in deduped)
    compression_applied = False
    if enable_compression and current_tokens > token_budget * 1.2:
        compressed = compress_with_llm(deduped, query, token_budget)
        # compress_with_llm returns its input list unchanged on any failure.
        compression_applied = compressed is not deduped
        deduped = compressed

    # Step 4: Budget allocation.
    budgeted = budget_chunks(deduped, token_budget)

    final_tokens = sum(_estimate_tokens(c.get("text", "")) for c in budgeted)

    return ContextShapeResult(
        chunks=budgeted,
        original_tokens=original_tokens,
        final_tokens=final_tokens,
        chunks_removed=removed,
        compression_applied=compression_applied,
    )
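# A minimal end-to-end sketch with hypothetical data; run as a script rather
# than on import. Compression is disabled so the demo stays offline and
# deterministic (compress_with_llm would no-op anyway if src.llm_providers
# is absent).
if __name__ == "__main__":
    sample_chunks = [
        {"id": "doc1", "text": "The Eiffel Tower was completed in 1889.", "score": 0.9},
        {"id": "doc2", "text": "The Eiffel Tower opened in 1889 in Paris.", "score": 0.8},
        {"id": "doc3", "text": "Bananas are rich in potassium.", "score": 0.2},
    ]
    result = shape_context(
        sample_chunks,
        query="When was the Eiffel Tower built?",
        token_budget=100,
        enable_compression=False,
    )
    print(
        f"{result.original_tokens} -> {result.final_tokens} tokens, "
        f"{result.chunks_removed} duplicate(s) removed"
    )
    for chunk in result.chunks:
        print(f"- {chunk.get('id')}: {chunk.get('text')!r}")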