Spaces:
Running
Running
| # DEPENDENCIES | |
| from typing import List | |
| from typing import Optional | |
| from collections import defaultdict | |
| from config.settings import get_settings | |
| from config.models import ChunkWithScore | |
| from config.models import DocumentChunk | |
| from config.logging_config import get_logger | |
| from utils.error_handler import handle_errors | |
| from chunking.token_counter import get_token_counter | |
| from utils.error_handler import ContextAssemblyError | |
| from retrieval.citation_tracker import CitationTracker | |
# Setup Settings and Logging
# Module-level singletons shared by every assembler instance in this module
settings = get_settings()
logger = get_logger(__name__)
class ContextAssembler:
    """
    Context assembly and optimization: Assembles retrieved chunks into optimal context
    for LLM processing with token limits and quality optimization
    """
    def __init__(self, max_context_tokens: Optional[int] = None, strategy: str = "score_based"):
        """
        Initialize context assembler
        Arguments:
        ----------
        max_context_tokens { int } : Maximum tokens for assembled context
                                     (default: settings.CONTEXT_WINDOW - 1000)
        strategy { str } : Assembly strategy ('score_based', 'diversity', 'sequential');
                           unknown values fall back to 'score_based' config downstream
        """
        self.logger = logger
        self.settings = get_settings()
        # Reserve ~1000 tokens of the model context window for prompt/answer overhead
        self.max_context_tokens = max_context_tokens or (settings.CONTEXT_WINDOW - 1000)
        self.strategy = strategy
        self.citation_tracker = CitationTracker()
        # Per-strategy knobs: diversity_penalty discourages many chunks from one
        # document; min_chunk_score drops low-relevance chunks during filtering
        self.strategy_configs = {"score_based" : {"diversity_penalty": 0.1, "min_chunk_score": 0.1},
                                 "diversity"   : {"diversity_penalty": 0.3, "min_chunk_score": 0.05},
                                 "sequential"  : {"diversity_penalty": 0.0, "min_chunk_score": 0.0},
                                 }
        self.logger.info(f"ContextAssembler initialized: max_tokens={self.max_context_tokens}, strategy={strategy}")
| def assemble_context(self, chunks: List[ChunkWithScore], query: str = "", include_citations: bool = True, format_for_llm: bool = True) -> str: | |
| """ | |
| Assemble context from retrieved chunks | |
| Arguments: | |
| ---------- | |
| chunks { list } : List of retrieved chunks | |
| query { str } : Original query (for relevance optimization) | |
| include_citations { bool } : Include citation markers in context | |
| format_for_llm { bool } : Format context for LLM consumption | |
| Returns: | |
| -------- | |
| { str } : Assembled context string | |
| """ | |
| if not chunks: | |
| self.logger.warning("No chunks provided for context assembly") | |
| return "" | |
| try: | |
| self.logger.info(f"Starting context assembly with {len(chunks)} chunks") | |
| # Filter and sort chunks based on strategy | |
| filtered_chunks = self._filter_chunks(chunks = chunks) | |
| self.logger.info(f"After filtering: {len(filtered_chunks)} chunks") | |
| # Ensure we have chunks after filtering | |
| if not filtered_chunks: | |
| self.logger.error("All chunks filtered out - using top chunk from original list") | |
| filtered_chunks = [chunks[0]] | |
| sorted_chunks = self._sort_chunks(chunks = filtered_chunks, | |
| query = query, | |
| ) | |
| # Select chunks within token limit | |
| selected_chunks = self._select_chunks_by_tokens(chunks = sorted_chunks) | |
| self.logger.info(f"After token selection: {len(selected_chunks)} chunks selected") | |
| # Validate selection results with progressive fallback | |
| if not selected_chunks: | |
| self.logger.warning("Token selection returned 0 chunks - using progressive fallback") | |
| # Fallback 1 - Try with higher token budget (allow 10% overflow) | |
| if sorted_chunks: | |
| overflow_budget = int(self.max_context_tokens * 1.10) | |
| self.logger.info(f"Fallback 1: Allowing {overflow_budget} tokens (10% overflow)") | |
| old_budget = self.max_context_tokens | |
| self.max_context_tokens = overflow_budget | |
| selected_chunks = self._select_chunks_by_tokens(chunks = sorted_chunks) | |
| self.max_context_tokens = old_budget | |
| # Fallback 2 - Force include top chunks regardless of token count | |
| if not selected_chunks and sorted_chunks: | |
| self.logger.warning("Fallback 2: Force including top 3 chunks") | |
| selected_chunks = sorted_chunks[:min(3, len(sorted_chunks))] | |
| # Fallback 3 - Use first available chunk | |
| if not selected_chunks and chunks: | |
| self.logger.error("Fallback 3: Emergency - using only first chunk") | |
| selected_chunks = [chunks[0]] | |
| # Complete failure | |
| if not selected_chunks: | |
| self.logger.error("All fallbacks exhausted - no chunks available") | |
| raise ContextAssemblyError("No valid chunks available after all selection attempts") | |
| # Assemble context | |
| if format_for_llm: | |
| context = self._format_for_llm(chunks = selected_chunks, | |
| include_citations = include_citations, | |
| ) | |
| else: | |
| context = self._format_simple(chunks = selected_chunks, | |
| include_citations = include_citations, | |
| ) | |
| # Validate assembled context | |
| if not context or not context.strip(): | |
| self.logger.error("Assembled context is empty") | |
| raise ContextAssemblyError("Context assembly produced empty result") | |
| context_tokens = self._count_tokens(text = context) | |
| self.logger.info(f"Assembled context: {len(selected_chunks)} chunks, {context_tokens} tokens, {len(context)} chars") | |
| # Warn if context exceeds limit | |
| if (context_tokens > self.max_context_tokens): | |
| overflow_pct = ((context_tokens - self.max_context_tokens) / self.max_context_tokens) * 100 | |
| self.logger.warning(f"Context exceeds limit by {overflow_pct:.1f}% ({context_tokens} > {self.max_context_tokens})") | |
| return context | |
| except ContextAssemblyError: | |
| # Re-raise context assembly errors | |
| raise | |
| except Exception as e: | |
| self.logger.error(f"Context assembly failed with unexpected error: {repr(e)}", exc_info = True) | |
| # Emergency fallback: return first chunk text only | |
| if (chunks and len(chunks) > 0): | |
| self.logger.warning("Emergency fallback: returning first chunk text only") | |
| return chunks[0].chunk.text | |
| # No chunks available at all | |
| raise ContextAssemblyError(f"Context assembly failed with no fallback available: {repr(e)}") | |
| def _filter_chunks(self, chunks: List[ChunkWithScore]) -> List[ChunkWithScore]: | |
| """ | |
| Filter chunks based on quality and strategy | |
| """ | |
| if not chunks: | |
| return [] | |
| strategy_config = self.strategy_configs.get(self.strategy, self.strategy_configs["score_based"]) | |
| min_score = strategy_config["min_chunk_score"] | |
| # Don't filter if all scores are below threshold | |
| all_below_threshold = all(chunk.score < min_score for chunk in chunks) | |
| if all_below_threshold: | |
| self.logger.warning(f"All chunks below min_score {min_score}, keeping top chunks anyway") | |
| # Keep at least top 5 chunks regardless of score | |
| filtered = sorted(chunks, key = lambda x: x.score, reverse = True)[:5] | |
| else: | |
| filtered = [chunk for chunk in chunks if chunk.score >= min_score] | |
| self.logger.info(f"Filtered {len(chunks)} -> {len(filtered)} chunks (min_score={min_score})") | |
| # Remove very similar chunks if diversity is important | |
| if (strategy_config["diversity_penalty"] > 0): | |
| filtered = self._apply_diversity_filter(filtered, strategy_config["diversity_penalty"]) | |
| return filtered | |
| def _apply_diversity_filter(self, chunks: List[ChunkWithScore], diversity_penalty: float) -> List[ChunkWithScore]: | |
| """ | |
| Apply diversity filtering to reduce redundancy - FIXED | |
| Arguments: | |
| ---------- | |
| chunks { list } : Chunks to filter | |
| diversity_penalty { float } : Penalty factor for same-document chunks | |
| Returns: | |
| -------- | |
| { list } : Diversified chunks | |
| """ | |
| if (len(chunks) <= 1): | |
| return chunks | |
| # Simple diversity: penalize chunks from same document | |
| document_scores = dict() | |
| for chunk in chunks: | |
| doc_id = chunk.chunk.document_id | |
| if doc_id not in document_scores: | |
| document_scores[doc_id] = [] | |
| document_scores[doc_id].append(chunk.score) | |
| # Adjust scores based on document diversity | |
| diversified = list() | |
| for chunk in chunks: | |
| doc_id = chunk.chunk.document_id | |
| doc_chunk_count = len(document_scores[doc_id]) | |
| # Penalize if multiple chunks from same document | |
| penalty = diversity_penalty * (doc_chunk_count - 1) | |
| adjusted_score = max(0.0, chunk.score - penalty) | |
| if (adjusted_score > 0): | |
| diversified_chunk = ChunkWithScore(chunk = chunk.chunk, | |
| score = adjusted_score, | |
| rank = chunk.rank, | |
| retrieval_method = chunk.retrieval_method + "_diversified" | |
| ) | |
| diversified.append(diversified_chunk) | |
| # Re-sort by adjusted scores | |
| diversified.sort(key = lambda x: x.score, reverse = True) | |
| # Re-normalize scores to maintain 0-1 range after penalty | |
| if diversified: | |
| max_score = max(chunk.score for chunk in diversified) | |
| if (max_score > 0): | |
| for chunk in diversified: | |
| chunk.score = chunk.score / max_score | |
| return diversified | |
| def _sort_chunks(self, chunks: List[ChunkWithScore], query: str) -> List[ChunkWithScore]: | |
| """ | |
| Sort chunks based on strategy | |
| """ | |
| if (self.strategy == "sequential"): | |
| # Sort by document order and position | |
| chunks.sort(key = lambda x: (x.chunk.document_id, | |
| x.chunk.page_number or 0, | |
| x.chunk.chunk_index | |
| ) | |
| ) | |
| else: | |
| # Default: sort by score (already sorted by retrieval) | |
| chunks.sort(key = lambda x: x.score, reverse = True) | |
| return chunks | |
| def _select_chunks_by_tokens(self, chunks: List[ChunkWithScore]) -> List[ChunkWithScore]: | |
| """ | |
| Select chunks that fit within token limit - IMPROVED VERSION | |
| Arguments: | |
| ---------- | |
| chunks { list } : Chunks to select from | |
| Returns: | |
| -------- | |
| { list } : Selected chunks within token limit | |
| """ | |
| if not chunks: | |
| self.logger.error("No chunks provided to _select_chunks_by_tokens") | |
| return [] | |
| selected = list() | |
| total_tokens = 0 | |
| self.logger.info(f"Token selection: {len(chunks)} chunks, max={self.max_context_tokens} tokens") | |
| # Reserve tokens for system prompt and formatting overhead | |
| formatting_overhead = 200 | |
| available_tokens = self.max_context_tokens - formatting_overhead | |
| self.logger.debug(f"Available tokens after overhead: {available_tokens}") | |
| for i, chunk in enumerate(chunks): | |
| # Get or calculate chunk tokens | |
| chunk_tokens = chunk.chunk.token_count | |
| if (chunk_tokens is None) or (chunk_tokens <= 0): | |
| chunk_text = chunk.chunk.text if chunk.chunk.text else "" | |
| if not chunk_text: | |
| self.logger.warning(f"Chunk {i} has no text content, skipping") | |
| continue | |
| chunk_tokens = self._count_tokens(text = chunk_text) | |
| self.logger.debug(f"Chunk {i} calculated: {chunk_tokens} tokens from {len(chunk_text)} chars") | |
| # Reserve tokens for chunk separators and citations | |
| chunk_formatting = 25 | |
| total_needed = total_tokens + chunk_tokens + chunk_formatting | |
| if (total_needed <= available_tokens): | |
| selected.append(chunk) | |
| total_tokens += chunk_tokens + chunk_formatting | |
| self.logger.debug(f"Chunk {i}: score={chunk.score:.3f}, tokens={chunk_tokens}, total={total_tokens}/{available_tokens}") | |
| else: | |
| # Calculate remaining space | |
| remaining = available_tokens - total_tokens - chunk_formatting | |
| self.logger.debug(f"Chunk {i} exceeds limit: needs {chunk_tokens}, have {remaining} remaining") | |
| # Try partial chunk if we have reasonable space (at least 200 tokens) | |
| if (remaining >= 200): | |
| self.logger.info(f"Attempting partial chunk {i} with {remaining} available tokens") | |
| partial_chunk = self._create_partial_chunk(chunk = chunk, | |
| available_tokens = remaining, | |
| ) | |
| if partial_chunk: | |
| selected.append(partial_chunk) | |
| partial_tokens = self._count_tokens(text = partial_chunk.chunk.text) | |
| total_tokens += partial_tokens + chunk_formatting | |
| self.logger.info(f"Added partial chunk {i}: {partial_tokens} tokens") | |
| # Stop adding chunks - no more space | |
| self.logger.info(f"Stopping chunk selection at index {i}") | |
| break | |
| # Log selection summary | |
| utilization = (total_tokens / self.max_context_tokens * 100) if (self.max_context_tokens > 0) else 0 | |
| self.logger.info(f"Token selection complete:") | |
| self.logger.info(f"- Selected: {len(selected)}/{len(chunks)} chunks") | |
| self.logger.info(f"- Tokens: {total_tokens}/{self.max_context_tokens} ({utilization:.1f}% utilization)") | |
| # Warning if selection is poor | |
| if ((len(selected) == 0) and (len(chunks) > 0)): | |
| self.logger.error(f"- ZERO chunks selected from {len(chunks)} available!") | |
| self.logger.error(f"- Max tokens: {self.max_context_tokens}") | |
| self.logger.error(f"- Available after overhead: {available_tokens}") | |
| self.logger.error(f"- First chunk tokens: {chunks[0].chunk.token_count or 'unknown'}") | |
| # Diagnostic: check if first chunk is too large | |
| if chunks[0].chunk.text: | |
| first_chunk_tokens = self._count_tokens(chunks[0].chunk.text) | |
| self.logger.error(f"- First chunk actual tokens: {first_chunk_tokens}") | |
| if first_chunk_tokens > available_tokens: | |
| self.logger.error(f"- First chunk ({first_chunk_tokens} tokens) exceeds available space ({available_tokens} tokens)") | |
| return selected | |
    def _create_partial_chunk(self, chunk: ChunkWithScore, available_tokens: int) -> Optional[ChunkWithScore]:
        """
        Create a partial chunk that fits within available tokens
        Arguments:
        ----------
        chunk { ChunkWithScore } : Chunk whose text should be truncated
        available_tokens { int } : Token budget the truncated text must fit in
        Returns:
        --------
        { ChunkWithScore | None } : Truncated copy with score discounted to 80%,
                                    or None when not even one sentence fits
        """
        full_text = chunk.chunk.text
        # Try to truncate at sentence boundary
        # NOTE(review): naive '. ' split - abbreviations like "e.g. " also split here
        sentences = full_text.split('. ')
        partial_text = ""
        # Greedily extend one sentence at a time while the budget allows
        for sentence in sentences:
            test_text = partial_text + sentence + ". "
            test_tokens = self._count_tokens(text = test_text)
            if (test_tokens <= available_tokens):
                partial_text = test_text
            else:
                break
        if partial_text:
            # Clone the source chunk's metadata with the truncated text;
            # end_char is recomputed from the truncated length
            partial_chunk_obj = DocumentChunk(chunk_id = chunk.chunk.chunk_id + "_partial",
                                              document_id = chunk.chunk.document_id,
                                              text = partial_text.strip(),
                                              embedding = chunk.chunk.embedding,
                                              chunk_index = chunk.chunk.chunk_index,
                                              start_char = chunk.chunk.start_char,
                                              end_char = chunk.chunk.start_char + len(partial_text.strip()),
                                              page_number = chunk.chunk.page_number,
                                              section_title = chunk.chunk.section_title,
                                              token_count = self._count_tokens(text = partial_text.strip()),
                                              metadata = chunk.chunk.metadata,
                                              )
            # Create partial chunk
            # Score discounted by 0.8 because the evidence is incomplete
            partial_chunk = ChunkWithScore(chunk = partial_chunk_obj,
                                           score = chunk.score * 0.8,
                                           rank = chunk.rank,
                                           retrieval_method = chunk.retrieval_method + "_partial",
                                           )
            return partial_chunk
        return None
| def _format_for_llm(self, chunks: List[ChunkWithScore], include_citations: bool) -> str: | |
| """ | |
| Format context for LLM consumption with citations | |
| """ | |
| context_parts = list() | |
| for i, chunk_with_score in enumerate(chunks, 1): | |
| chunk = chunk_with_score.chunk | |
| # Build citation marker | |
| citation_marker = f"[{i}]" if include_citations else "" | |
| # Build source info | |
| source_info = list() | |
| if chunk.page_number: | |
| source_info.append(f"Page {chunk.page_number}") | |
| if chunk.section_title: | |
| source_info.append(f"Section: {chunk.section_title}") | |
| source_str = f"({', '.join(source_info)})" if source_info else "" | |
| # Format chunk | |
| if include_citations and source_info: | |
| chunk_text = f"{citation_marker} {source_str}\n{chunk.text}" | |
| elif include_citations: | |
| chunk_text = f"{citation_marker}\n{chunk.text}" | |
| else: | |
| chunk_text = chunk.text | |
| context_parts.append(chunk_text) | |
| return "\n\n".join(context_parts) | |
| def _format_simple(self, chunks: List[ChunkWithScore], include_citations: bool) -> str: | |
| """ | |
| Simple formatting without extensive metadata | |
| """ | |
| context_parts = list() | |
| for i, chunk_with_score in enumerate(chunks, 1): | |
| chunk = chunk_with_score.chunk | |
| if include_citations: | |
| context_parts.append(f"[{i}] {chunk.text}") | |
| else: | |
| context_parts.append(chunk.text) | |
| return "\n\n".join(context_parts) | |
| def _count_tokens(self, text: str) -> int: | |
| """ | |
| Count tokens in text with conservative fallback | |
| Arguments: | |
| ---------- | |
| text { str } : Text to count tokens for | |
| Returns: | |
| -------- | |
| { int } : Token count | |
| """ | |
| if not text: | |
| return 0 | |
| try: | |
| token_counter = get_token_counter() | |
| return token_counter.count_tokens(text) | |
| except Exception as e: | |
| # Conservative fallback calculation for technical text | |
| self.logger.debug(f"Token counter error, using conservative approximation: {repr(e)}") | |
| # More accurate approximation for technical/scientific text: Count words (split by whitespace) | |
| words = text.split() | |
| # Technical text has more subword tokenization: Average: 1 word ≈ 1.8 tokens for technical English (conservative) | |
| estimated_tokens = int(len(words) * 1.8) | |
| # Add overhead for punctuation, numbers, special chars (15%) | |
| estimated_tokens = int(estimated_tokens * 1.15) | |
| # Add safety margin (10%) | |
| estimated_tokens = int(estimated_tokens * 1.10) | |
| # Ensure minimum reasonable value | |
| estimated_tokens = max(10, estimated_tokens) | |
| self.logger.debug(f"Conservative estimate: {len(words)} words → {estimated_tokens} tokens") | |
| return estimated_tokens | |
| def optimize_context_quality(self, context: str, chunks: List[ChunkWithScore]) -> str: | |
| """ | |
| Optimize context quality by removing redundancies and improving flow | |
| """ | |
| # Remove duplicate sentences | |
| sentences = context.split('. ') | |
| unique_sentences = list() | |
| for sentence in sentences: | |
| sentence_clean = sentence.strip() | |
| if (sentence_clean and (sentence_clean not in unique_sentences)): | |
| unique_sentences.append(sentence_clean) | |
| optimized = '. '.join(unique_sentences) | |
| # Ensure proper citation consistency | |
| if '[' in optimized: | |
| optimized = self.citation_tracker.ensure_citation_consistency(optimized, chunks) | |
| return optimized | |
| def get_context_statistics(self, context: str, chunks: List[ChunkWithScore]) -> dict: | |
| """ | |
| Get statistics about assembled context | |
| """ | |
| token_count = self._count_tokens(text = context) | |
| char_count = len(context) | |
| # Citation statistics | |
| citation_stats = self.citation_tracker.get_citation_statistics(context, chunks) | |
| # Source diversity | |
| source_docs = set(chunk.chunk.document_id for chunk in chunks) | |
| source_pages = set(chunk.chunk.page_number for chunk in chunks if chunk.chunk.page_number) | |
| return {"total_tokens" : token_count, | |
| "total_chars" : char_count, | |
| "chunk_count" : len(chunks), | |
| "source_documents" : len(source_docs), | |
| "source_pages" : len(source_pages), | |
| "token_utilization" : (token_count / self.max_context_tokens * 100) if self.max_context_tokens > 0 else 0, | |
| "citation_stats" : citation_stats, | |
| "strategy" : self.strategy, | |
| } | |
# Lazily-created process-wide assembler singleton
_context_assembler = None
def get_context_assembler() -> ContextAssembler:
    """
    Return the shared ContextAssembler, constructing it on first use.
    """
    global _context_assembler
    if _context_assembler is not None:
        return _context_assembler
    _context_assembler = ContextAssembler()
    return _context_assembler
def assemble_context_for_llm(chunks: List[ChunkWithScore], query: str = "", **kwargs) -> str:
    """
    Convenience wrapper: assemble a context string via the shared
    ContextAssembler singleton. Extra keyword arguments are forwarded
    to ContextAssembler.assemble_context.
    """
    return get_context_assembler().assemble_context(chunks, query, **kwargs)