File size: 3,973 Bytes
e8051be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""

Context Management Module for Advanced RAG

Handles context creation and management for LLM generation.

"""

from typing import List, Dict
from collections import defaultdict
from config.config import MAX_CONTEXT_LENGTH


class ContextManager:
    """Manages context creation for LLM generation.

    Retrieved chunks are grouped by the expanded query that contributed them
    most strongly, and every query group receives a near-equal number of
    slots in the final context string.
    """

    def __init__(self):
        """Initialize the context manager (holds no state; only prints a notice)."""
        print("✅ Context Manager initialized")

    def create_enhanced_context(self, question: str, results: List[Dict], max_length: int = MAX_CONTEXT_LENGTH) -> str:
        """Create enhanced context ensuring each query contributes equally.

        Args:
            question: The user question. Not used by the selection logic;
                kept in the signature so existing callers keep working.
            results: Retrieved chunks. Each dict must contain a ``payload``
                mapping (its ``text`` entry is emitted, defaulting to ``""``)
                and may contain ``contributing_queries`` (a list of dicts with
                ``query_idx`` and optionally ``semantic_score`` /
                ``bm25_score``) as well as ``rerank_score``, ``final_score``
                or ``score`` used for ordering.
            max_length: Upper bound, in characters, on the returned string,
                including the newline separators between chunks.

        Returns:
            The selected chunks, each prefixed with a
            ``[Query N Doc M]`` header, joined with newlines. An empty string
            when ``results`` is empty.

        Raises:
            KeyError: If a result has no ``payload`` key or a contributing
                query has no ``query_idx`` key.
        """
        separator = "\n"

        # Group results by the index of their best-scoring expanded query.
        query_to_chunks = defaultdict(list)
        for i, result in enumerate(results):
            contributing = result.get('contributing_queries')
            if contributing:
                # Use the highest scoring contributing query
                best_contrib = max(contributing, key=lambda cq: cq.get('semantic_score', cq.get('bm25_score', 0)))
                query_idx = best_contrib['query_idx']
            else:
                query_idx = 0  # fallback to first query
            query_to_chunks[query_idx].append((i, result))

        num_queries = len(query_to_chunks)
        if num_queries == 0:
            return ""

        # Best chunks first inside every group.
        for chunks in query_to_chunks.values():
            chunks.sort(key=lambda x: x[1].get('rerank_score', x[1].get('final_score', x[1].get('score', 0))), reverse=True)

        # Split the slots evenly; the remainder goes to the first groups.
        chunks_per_query, extra_chunks = divmod(len(results), num_queries)

        print(f"📊 Context Creation: {num_queries} queries, {chunks_per_query} chunks per query (+{extra_chunks} extra)")

        context_parts = []
        current_length = 0

        # `rank` is the group's position among the groups actually present.
        # The remainder must be handed out by rank, not by the raw query
        # index, otherwise non-contiguous indices (e.g. only queries 1 and 2
        # contributed) would silently lose the extra slots.
        for rank, q_idx in enumerate(sorted(query_to_chunks)):
            query_chunk_limit = chunks_per_query + (1 if rank < extra_chunks else 0)
            query_chunks_added = 0

            print(f"   Query {q_idx+1}: Adding up to {query_chunk_limit} chunks")

            for _, result in query_to_chunks[q_idx]:
                if query_chunks_added >= query_chunk_limit:
                    break

                text = result['payload'].get('text', '')
                relevance_info = ""
                if 'rerank_score' in result:
                    relevance_info = f" [Relevance: {result['rerank_score']:.2f}]"
                elif 'final_score' in result:
                    relevance_info = f" [Score: {result['final_score']:.2f}]"
                doc_text = f"[Query {q_idx+1} Doc {len(context_parts)+1}]{relevance_info}\n{text}\n"

                # The final join inserts one separator before every part
                # except the first; count it so the result honours max_length.
                sep_len = len(separator) if context_parts else 0
                if current_length + sep_len + len(doc_text) > max_length:
                    print(f"   ⚠️ Context length limit reached at {current_length} chars")
                    break

                context_parts.append(doc_text)
                current_length += sep_len + len(doc_text)
                query_chunks_added += 1

            print(f"   Query {q_idx+1}: Added {query_chunks_added} chunks")

        print(f"📝 Final context: {len(context_parts)} chunks, {current_length} chars")
        return separator.join(context_parts)