File size: 6,868 Bytes
082f3f8
edde763
 
082f3f8
edde763
082f3f8
 
edde763
082f3f8
edde763
082f3f8
 
 
edde763
 
 
 
 
 
 
082f3f8
 
 
 
 
 
 
edde763
082f3f8
 
edde763
 
 
 
 
082f3f8
edde763
 
 
 
 
 
 
 
 
082f3f8
edde763
 
082f3f8
 
 
 
edde763
 
 
 
082f3f8
 
 
edde763
082f3f8
edde763
082f3f8
 
edde763
082f3f8
edde763
 
 
082f3f8
edde763
 
 
082f3f8
edde763
082f3f8
edde763
 
 
 
082f3f8
 
 
edde763
 
 
 
 
 
082f3f8
 
edde763
 
 
 
082f3f8
 
edde763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
082f3f8
edde763
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173


# app/graph/nodes/router.py
from app.core.rag_service import get_rag_context
from app.core.llm_engine import llm  # βœ… Use main llm, not eval_llm
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
import ast

# ✅ IMPROVED: Multi-strategy expansion
# Prompt asking the model for four differently-angled rewrites of the user's
# query (synonyms, acronym expansion, a broader concept, a technical detail),
# to be returned as a Python list literal.
expansion_prompt = PromptTemplate(
    template="\n".join(
        [
            "Generate 4 diverse search queries for: '{query}'",
            "",
            "1. Synonym variation (different words, same meaning)",
            "2. Acronym/abbreviation expansion (if applicable)",
            "3. Broader concept query",
            "4. Technical detail query",
            "",
            "Return ONLY a Python list of 4 strings.",
            "Example: ['machine learning algorithms', 'ML techniques', 'artificial intelligence methods', 'supervised learning models']",
            "",
            "List:",
        ]
    ),
    input_variables=["query"],
)

# Pipeline: fill the prompt, send it to `llm`, and run the reply through
# StrOutputParser so callers receive plain text.
expansion_chain = expansion_prompt | llm | StrOutputParser()

def _parse_expansions(raw: str) -> list:
    """Extract candidate query strings from a raw LLM reply.

    Two reply shapes are understood:

    * A Python list literal, possibly surrounded by other text such as a
      markdown code fence or a short preamble. The span from the first ``[``
      to the last ``]`` is parsed with ``ast.literal_eval`` (which only
      evaluates literals, never arbitrary code).
    * One query per line, optionally prefixed with a bullet or numbering.

    Args:
        raw: The model reply, already stripped of surrounding whitespace.

    Returns:
        A list of non-empty, whitespace-trimmed strings. Empty when nothing
        usable could be extracted.
    """
    import re  # function-scope import: no new file-level dependency needed

    start, end = raw.find("["), raw.rfind("]")
    if start != -1 and end > start:
        try:
            parsed = ast.literal_eval(raw[start:end + 1])
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, (list, tuple)):
            # Drop non-string items (e.g. numbers) so only real queries are
            # handed on to retrieval.
            return [
                item.strip()
                for item in parsed
                if isinstance(item, str) and item.strip()
            ]

    if raw.startswith("["):
        # The reply was meant to be a list literal but is malformed or
        # truncated; splitting it into lines would only yield fragments.
        return []

    # Line-by-line fallback. The marker pattern requires whitespace after the
    # bullet/number so queries such as "3D printing" or "1.5 GHz" stay intact.
    marker = re.compile(r"^(?:[-*\u2022]|\d+[.)])\s+")
    candidates = []
    for line in raw.splitlines():
        text = line.strip()
        if not text or text.startswith("```"):
            continue  # skip blank lines and markdown code-fence lines
        text = marker.sub("", text).strip("\"'-, ").strip()
        if text:
            candidates.append(text)
    return candidates


def expand_query(query: str) -> list:
    """Generate diverse query variations for better retrieval.

    Invokes ``expansion_chain`` with the query and parses the reply into
    alternative phrasings.

    Args:
        query: The user's original query.

    Returns:
        A list whose first element is always ``query``, followed by up to
        four distinct expansions. If the chain call or parsing fails, or no
        expansion is produced, the result is just ``[query]``.
    """
    try:
        raw = expansion_chain.invoke({"query": query}).strip()
        candidates = _parse_expansions(raw)
    except Exception as e:
        # Expansion is best-effort: any failure degrades to the original query
        # instead of aborting the caller.
        print(f"⚠️ QUERY EXPANSION FAILED β†’ {e}")
        return [query]

    variations = [query]
    for candidate in candidates:
        if candidate not in variations:
            variations.append(candidate)
            if len(variations) >= 5:  # original + 4 expansions
                break
    return variations


def router_node(state):
    """
    Route a query to "general", "hybrid" or "rag" from its best retrieval score.

    The routing decision itself is purely score-based. Only the "rag" zone
    performs an LLM call, via ``expand_query``, to widen retrieval coverage.

    Zones (by the maximum score returned by ``get_rag_context``; 0.0 when no
    scores are returned):
      * score < 0.28          -> "general": empty context and sources.
      * 0.28 <= score < 0.50  -> "hybrid": the originally retrieved context
                                 and sources are passed through.
      * score >= 0.50         -> "rag": the query is expanded, results are
                                 fused by ``_collect_chunks_with_rrf`` and the
                                 top 10 chunks become the context.

    Args:
        state: Mapping read with ``.get("query")`` and ``.get("doc_id")``.

    Returns:
        A new dict containing everything in ``state`` plus the keys
        ``route``, ``context``, ``sources`` and ``score``.
    """
    HIGH_THRESHOLD = 0.50    # lowered from 0.55; at or above -> full RAG
    HYBRID_THRESHOLD = 0.28  # lowered from 0.30; below -> general answer

    query = state.get("query")
    doc_id = state.get("doc_id")

    # Initial retrieval (top_k raised from 3 to 5).
    original_context, original_sources, original_scores = get_rag_context(
        query, doc_id, top_k=5
    )
    original_max_score = max(original_scores) if original_scores else 0.0
    print(f"πŸ“Š ORIGINAL QUERY SCORE β†’ {original_max_score:.3f}")

    # Zone 1: pure general (score below HYBRID_THRESHOLD).
    if original_max_score < HYBRID_THRESHOLD:
        print(f"πŸ”€ ROUTER β†’ general | score: {original_max_score:.3f}")
        return {
            **state,
            "route": "general",
            "context": "",
            "sources": [],
            "score": original_max_score,
        }

    # Zone 2: hybrid. The lower bound is already guaranteed by the check above.
    if original_max_score < HIGH_THRESHOLD:
        print(f"πŸ”€ ROUTER β†’ hybrid | score: {original_max_score:.3f}")
        # Pass along the best chunks found so far.
        return {
            **state,
            "route": "hybrid",
            "context": original_context,
            "sources": original_sources,
            "score": original_max_score,
        }

    # Zone 3: full RAG, using query expansion plus rank fusion.
    print(f"πŸš€ HIGH SCORE β†’ Expanding query for better coverage...")
    expanded_queries = expand_query(query)
    print(f"πŸ“ EXPANDED QUERIES β†’ {expanded_queries}")

    # Only the fused chunk texts are needed here; the per-chunk scores are unused.
    fused_chunks, _ = _collect_chunks_with_rrf(expanded_queries, doc_id)

    if fused_chunks:
        top_chunks = fused_chunks[:10]  # cap at 10 chunks
        merged = "\n\n---\n\n".join(top_chunks)
        sources = top_chunks
        chunk_count = len(fused_chunks)
    else:
        # Fusion produced nothing: fall back to the original retrieval.
        # The context must come from original_context; joining the sources
        # list would not reproduce the retrieved text.
        merged = original_context
        sources = original_sources
        chunk_count = len(original_sources)

    print(f"🎯 ROUTER β†’ rag | score: {original_max_score:.3f} | chunks: {chunk_count}")
    return {
        **state,
        "route": "rag",
        "context": merged,
        "sources": sources,
        "score": original_max_score,
    }


def _collect_chunks_with_rrf(queries: list, doc_id: str, k: int = 60) -> tuple:
    """
    Fuse retrieval results from several queries with Reciprocal Rank Fusion.

    Each query is sent to ``get_rag_context`` with ``top_k=8``. The returned
    context string is split on the ``"\\n\\n---\\n\\n"`` separator and the
    pieces are paired positionally with the returned scores.

    RRF formula: score(chunk) = sum over queries of 1 / (k + rank), where
    rank is the chunk's 1-based position in that query's results.

    Args:
        queries: Query strings to retrieve for.
        doc_id: Document identifier forwarded to ``get_rag_context``.
        k: RRF smoothing constant; larger values flatten rank differences.

    Returns:
        ``(texts, scores)``: unique chunk texts ordered by descending RRF
        score (highest similarity as tiebreaker), and for each text the
        highest similarity score seen across all queries.
    """
    separator = "\n\n---\n\n"
    fused = {}  # chunk text -> {'text', 'rrf_score', 'max_sim'}

    for query in queries:
        context, _sources, similarities = get_rag_context(query, doc_id, top_k=8)

        if not context:
            continue

        ranked_this_query = set()
        for rank, (chunk, similarity) in enumerate(
            zip(context.split(separator), similarities), start=1
        ):
            chunk = chunk.strip()
            if not chunk:
                continue

            # Key on the full text: a truncated prefix would merge distinct
            # chunks that merely share the same opening characters.
            entry = fused.setdefault(
                chunk, {'text': chunk, 'rrf_score': 0.0, 'max_sim': 0.0}
            )
            entry['max_sim'] = max(entry['max_sim'], similarity)

            # A chunk contributes once per query ranking, at its best
            # (first) position, even if the result list repeats it.
            if chunk not in ranked_this_query:
                ranked_this_query.add(chunk)
                entry['rrf_score'] += 1.0 / (k + rank)

    # Sort by RRF score (primary) and max similarity (tiebreaker).
    ordered = sorted(
        fused.values(),
        key=lambda item: (item['rrf_score'], item['max_sim']),
        reverse=True,
    )

    texts = [item['text'] for item in ordered]
    max_sims = [item['max_sim'] for item in ordered]

    print(f"βœ… RRF FUSION β†’ {len(texts)} unique chunks from {len(queries)} queries")

    return texts, max_sims