Spaces:
Sleeping
Sleeping
| """ | |
| Context Management Module for Advanced RAG | |
| Handles context creation and management for LLM generation. | |
| """ | |
| from typing import List, Dict | |
| from collections import defaultdict | |
| from config.config import MAX_CONTEXT_LENGTH | |
class ContextManager:
    """Manages context creation for LLM generation."""

    def __init__(self):
        """Initialize the context manager."""
        print("β Context Manager initialized")

    def create_enhanced_context(
        self,
        question: str,
        results: List[Dict],
        max_length: int = MAX_CONTEXT_LENGTH,
    ) -> str:
        """Create enhanced context ensuring each query contributes equally.

        Results are grouped by the expanded query that scored them best,
        ordered by relevance inside each group, and then each group is given
        an equal share of ``len(results)`` slots. Any remainder is handed out
        one slot at a time to the first groups in ascending query-index order.

        Args:
            question: The user question. Not used when building the context;
                kept so the call signature stays stable for callers.
            results: Retrieved chunks. Each chunk is a dict that must contain
                a ``payload`` dict (its ``text`` entry is the chunk body) and
                may contain ``contributing_queries``, ``rerank_score``,
                ``final_score`` and ``score``.
            max_length: Upper bound, in characters, on the returned string,
                including the newlines inserted between chunks.

        Returns:
            The selected chunks, each prefixed with a ``[Query N Doc M]``
            header, joined by newlines. An empty string when ``results`` is
            empty.
        """
        query_to_chunks = self._group_by_query(results)
        if not query_to_chunks:
            return ""

        num_queries = len(query_to_chunks)
        chunks_per_query, extra_chunks = divmod(len(results), num_queries)
        print(f"π Context Creation: {num_queries} queries, {chunks_per_query} chunks per query (+{extra_chunks} extra)")

        separator = "\n"
        context_parts = []
        current_length = 0

        # The remainder slots are assigned by *position* in the sorted key
        # order, not by the raw query index: the indices that actually occur
        # may be sparse or may not start at 0, and comparing the raw index
        # against the remainder would then leave slots unassigned.
        for position, q_idx in enumerate(sorted(query_to_chunks)):
            query_chunk_limit = chunks_per_query + (1 if position < extra_chunks else 0)
            query_chunks_added = 0
            print(f" Query {q_idx+1}: Adding up to {query_chunk_limit} chunks")

            for result in query_to_chunks[q_idx][:query_chunk_limit]:
                doc_text = self._format_chunk(q_idx, len(context_parts) + 1, result)

                # The final join puts one separator between consecutive
                # chunks, so every chunk after the first costs one more char.
                separator_cost = len(separator) if context_parts else 0
                if current_length + separator_cost + len(doc_text) > max_length:
                    print(f" β οΈ Context length limit reached at {current_length} chars")
                    break

                context_parts.append(doc_text)
                current_length += separator_cost + len(doc_text)
                query_chunks_added += 1

            print(f" Query {q_idx+1}: Added {query_chunks_added} chunks")

        print(f"π Final context: {len(context_parts)} chunks, {current_length} chars")
        return separator.join(context_parts)

    @staticmethod
    def _group_by_query(results: List[Dict]) -> Dict:
        """Group results by best contributing query, best-scored chunk first."""
        query_to_chunks = defaultdict(list)
        for result in results:
            contributing = result.get('contributing_queries')
            if contributing:
                # Use the highest scoring contributing query.
                best_contrib = max(
                    contributing,
                    key=lambda cq: cq.get('semantic_score', cq.get('bm25_score', 0)),
                )
                # A record without an index falls back to the first query,
                # the same as a result with no contributing queries at all.
                query_idx = best_contrib.get('query_idx', 0)
            else:
                query_idx = 0  # fallback to first query
            query_to_chunks[query_idx].append(result)

        for chunks in query_to_chunks.values():
            # Stable sort: chunks with equal scores keep retrieval order.
            chunks.sort(
                key=lambda r: r.get('rerank_score', r.get('final_score', r.get('score', 0))),
                reverse=True,
            )
        return query_to_chunks

    @staticmethod
    def _format_chunk(q_idx: int, doc_number: int, result: Dict) -> str:
        """Render one chunk as a header line followed by its text."""
        text = result['payload'].get('text', '')
        relevance_info = ""
        if 'rerank_score' in result:
            relevance_info = f" [Relevance: {result['rerank_score']:.2f}]"
        elif 'final_score' in result:
            relevance_info = f" [Score: {result['final_score']:.2f}]"
        return f"[Query {q_idx+1} Doc {doc_number}]{relevance_info}\n{text}\n"