Spaces:
Sleeping
Sleeping
File size: 3,973 Bytes
e8051be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
"""
Context Management Module for Advanced RAG
Handles context creation and management for LLM generation.
"""
from typing import List, Dict
from collections import defaultdict
from config.config import MAX_CONTEXT_LENGTH
class ContextManager:
    """Manages context creation for LLM generation.

    The class keeps no instance state; it groups retrieved chunks by the
    expanded query that matched them best and formats them into one string.
    """

    def __init__(self):
        """Initialize the context manager."""
        print("✅ Context Manager initialized")

    def create_enhanced_context(self, question: str, results: List[Dict], max_length: int = MAX_CONTEXT_LENGTH) -> str:
        """Create enhanced context ensuring each query contributes equally.

        Results are grouped by their best contributing query, ranked inside
        each group, and then every group is given an equal share of the
        ``len(results)`` slots (the remainder goes to the first groups in
        sorted query order). Slots a group cannot fill are not handed to
        other groups.

        Args:
            question: The user question. Not used by the current
                implementation; kept for interface compatibility.
            results: Retrieved chunks. Keys read from each dict:
                ``contributing_queries`` (optional list of dicts with
                ``query_idx`` and ``semantic_score`` / ``bm25_score``),
                ``payload`` (dict whose ``text`` is the chunk body) and
                ``rerank_score`` / ``final_score`` / ``score`` for ranking.
            max_length: Upper bound, in characters, on the returned string,
                including the newlines inserted between chunks.

        Returns:
            The formatted chunks joined by newlines, or ``""`` when
            ``results`` is empty.
        """
        # Group results by the expanded query that matched each chunk best.
        query_to_chunks = defaultdict(list)
        for result in results:
            contributing = result.get('contributing_queries')
            if contributing:
                # Use the highest scoring contributing query
                best_contrib = max(
                    contributing,
                    key=lambda cq: cq.get('semantic_score', cq.get('bm25_score', 0)),
                )
                # A contribution without an index falls back to the first
                # query, same as a result with no contributions at all.
                query_idx = best_contrib.get('query_idx', 0)
            else:
                query_idx = 0  # fallback to first query
            query_to_chunks[query_idx].append(result)

        num_queries = len(query_to_chunks)
        if num_queries == 0:
            return ""

        # Best chunk first inside each group; the sort is stable, so ties
        # keep their retrieval order.
        for chunks in query_to_chunks.values():
            chunks.sort(
                key=lambda r: r.get('rerank_score', r.get('final_score', r.get('score', 0))),
                reverse=True,
            )

        # Equal share per query; the remainder is handed out one slot at a
        # time to the first groups.
        chunks_per_query, extra_chunks = divmod(len(results), num_queries)
        print(f"📊 Context Creation: {num_queries} queries, {chunks_per_query} chunks per query (+{extra_chunks} extra)")

        separator = "\n"
        context_parts = []
        current_length = 0

        # The remainder is distributed by *position* in the sorted group
        # list, not by the raw query id: ids may be sparse or not 0-based,
        # and comparing the id itself would leave extra slots unassigned.
        for position, q_idx in enumerate(sorted(query_to_chunks)):
            query_chunk_limit = chunks_per_query + (1 if position < extra_chunks else 0)
            query_chunks_added = 0
            print(f" Query {q_idx+1}: Adding up to {query_chunk_limit} chunks")

            for result in query_to_chunks[q_idx][:query_chunk_limit]:
                payload = result.get('payload') or {}
                text = payload.get('text', '')

                relevance_info = ""
                if 'rerank_score' in result:
                    relevance_info = f" [Relevance: {result['rerank_score']:.2f}]"
                elif 'final_score' in result:
                    relevance_info = f" [Score: {result['final_score']:.2f}]"

                doc_text = f"[Query {q_idx+1} Doc {len(context_parts)+1}]{relevance_info}\n{text}\n"

                # Count the newline that join() inserts before every part
                # but the first, so the returned string never exceeds
                # max_length.
                joiner_cost = len(separator) if context_parts else 0
                if current_length + joiner_cost + len(doc_text) > max_length:
                    print(f" ⚠️ Context length limit reached at {current_length} chars")
                    # Only this query stops; later queries may still fit
                    # shorter chunks.
                    break

                context_parts.append(doc_text)
                current_length += joiner_cost + len(doc_text)
                query_chunks_added += 1

            print(f" Query {q_idx+1}: Added {query_chunks_added} chunks")

        print(f"📝 Final context: {len(context_parts)} chunks, {current_length} chars")
        return separator.join(context_parts)
|