Spaces:
Running
Running
File size: 6,868 Bytes
082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 082f3f8 edde763 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# app/graph/nodes/router.py
from app.core.rag_service import get_rag_context
from app.core.llm_engine import llm  # Use main llm, not eval_llm
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
import ast
# IMPROVED: Multi-strategy expansion
# Instruction text sent to the LLM: asks for four differently-angled rewrites
# of the user query, returned as a Python list literal.
_EXPANSION_TEMPLATE = (
    "Generate 4 diverse search queries for: '{query}'\n\n"
    "1. Synonym variation (different words, same meaning)\n"
    "2. Acronym/abbreviation expansion (if applicable)\n"
    "3. Broader concept query\n"
    "4. Technical detail query\n\n"
    "Return ONLY a Python list of 4 strings.\n"
    "Example: ['machine learning algorithms', 'ML techniques', 'artificial intelligence methods', 'supervised learning models']\n\n"
    "List:"
)

expansion_prompt = PromptTemplate(
    template=_EXPANSION_TEMPLATE,
    input_variables=["query"],
)

# Prompt -> main LLM -> plain-string output, composed into a single runnable.
expansion_chain = expansion_prompt | llm | StrOutputParser()
def expand_query(query: str, max_expansions: int = 4) -> list:
    """Generate diverse query variations for better retrieval.

    Invokes ``expansion_chain`` with the query and parses the reply into a
    list of alternative phrasings.

    Args:
        query: The user's original query.
        max_expansions: Maximum number of variations appended after the
            original query (default 4, matching the prompt).

    Returns:
        A list whose first element is always ``query``, followed by up to
        ``max_expansions`` unique, non-empty variations. If the query is
        blank, the LLM call fails, or the reply cannot be parsed, the result
        is ``[query]``.
    """
    # Nothing to expand: skip the LLM call entirely.
    if not query or not query.strip():
        return [query]

    try:
        raw = expansion_chain.invoke({"query": query}).strip()
        candidates = _parse_expansion_output(raw)
    except Exception as e:
        # Deliberately broad: expansion is best-effort and must never break
        # routing, whatever the LLM client or the parser raises.
        print(f"β οΈ QUERY EXPANSION FAILED β {e}")
        return [query]

    variations = [query]
    for candidate in candidates:
        if len(variations) > max_expansions:  # original + max_expansions
            break
        if candidate and candidate != query and candidate not in variations:
            variations.append(candidate)
    return variations


def _parse_expansion_output(raw: str) -> list:
    """Extract candidate query strings from the stripped LLM reply.

    Two reply shapes are handled:

    * a Python list literal, possibly surrounded by extra text or a markdown
      code fence -- the outermost ``[...]`` span is evaluated with
      ``ast.literal_eval`` and only its string items are kept;
    * one query per line, optionally numbered ("1. ...", "2) ...") or
      bulleted/quoted.

    Raises:
        ValueError, SyntaxError: if the reply starts with ``[`` but is not a
            valid literal (the caller treats this as a failed expansion).
    """
    opening = raw.find("[")
    closing = raw.rfind("]")
    if opening != -1 and closing > opening:
        try:
            parsed = ast.literal_eval(raw[opening:closing + 1])
        except (ValueError, SyntaxError):
            # A reply that is itself a malformed list cannot be salvaged line
            # by line; brackets that merely occur inside line-formatted text
            # are ignored and the line parser below takes over.
            if raw.startswith("["):
                raise
        else:
            if isinstance(parsed, (list, tuple)):
                found = [item.strip() for item in parsed if isinstance(item, str)]
                if found or raw.startswith("["):
                    return found

    candidates = []
    for line in raw.splitlines():
        entry = line.strip()
        if not entry or entry.startswith("```"):
            continue
        # Drop a leading "1." / "1)" list number, but only when followed by
        # whitespace so that queries such as "3.5 GPA" stay intact.
        for marker in (".", ")"):
            head, sep, tail = entry.partition(marker)
            if sep and head.isdigit() and tail[:1].isspace():
                entry = tail
                break
        entry = entry.strip(" \t\"'-*,")
        if entry:
            candidates.append(entry)
    return candidates
def router_node(state):
    """
    Score-based routing with three threshold zones.

    Reads ``query`` and ``doc_id`` from ``state``, retrieves the top 5 chunks
    with ``get_rag_context`` and compares the best score with two thresholds:

    * score < 0.28          -> route "general": empty context, no sources.
    * 0.28 <= score < 0.50  -> route "hybrid": retrieved context and sources
      are passed through unchanged.
    * score >= 0.50         -> route "rag": the query is expanded with
      ``expand_query`` and the chunks retrieved for every variation are fused
      by ``_collect_chunks_with_rrf`` (capped at 10 chunks).

    The routing decision itself uses no LLM call; only the "rag" zone invokes
    the LLM, through ``expand_query``.

    Returns:
        A new dict containing everything in ``state`` plus ``route``,
        ``context``, ``sources`` and ``score`` (the best original score).
    """
    query = state.get("query")
    doc_id = state.get("doc_id")

    # Initial retrieval with relaxed threshold (top_k increased from 3 to 5).
    original_context, original_sources, original_scores = get_rag_context(
        query, doc_id, top_k=5
    )
    original_max_score = max(original_scores) if original_scores else 0.0
    print(f"π ORIGINAL QUERY SCORE β {original_max_score:.3f}")

    # Improved thresholds.
    HIGH_THRESHOLD = 0.50    # Lowered from 0.55: strong match -> RAG
    HYBRID_THRESHOLD = 0.28  # Lowered from 0.30: weak match -> hybrid

    # ------------------------------------------------
    # ZONE 1: Pure General (< 0.28)
    # ------------------------------------------------
    if original_max_score < HYBRID_THRESHOLD:
        print(f"π ROUTER β general | score: {original_max_score:.3f}")
        return {
            **state,
            "route": "general",
            "context": "",
            "sources": [],
            "score": original_max_score,
        }

    # ------------------------------------------------
    # ZONE 2: Hybrid (0.28 - 0.50)
    # ------------------------------------------------
    if original_max_score < HIGH_THRESHOLD:
        print(f"π ROUTER β hybrid | score: {original_max_score:.3f}")
        # Pass the best chunks we have; the synthesizer will supplement them
        # with general knowledge.
        return {
            **state,
            "route": "hybrid",
            "context": original_context,
            "sources": original_sources,
            "score": original_max_score,
        }

    # ------------------------------------------------
    # ZONE 3: Full RAG (>= 0.50): Query Expansion + Fusion
    # ------------------------------------------------
    print(f"π HIGH SCORE β Expanding query for better coverage...")
    expanded_queries = expand_query(query)
    print(f"π EXPANDED QUERIES β {expanded_queries}")

    # Collect chunks from all query variations; the fused scores are not
    # needed here because the reported score stays the original best score.
    fused_chunks, _ = _collect_chunks_with_rrf(expanded_queries, doc_id)

    if fused_chunks:
        sources = fused_chunks[:10]  # Cap at 10 chunks
        merged = "\n\n---\n\n".join(sources)
        chunk_count = len(fused_chunks)
    else:
        # Fusion produced nothing: fall back to the original retrieval,
        # passing its context and sources exactly as the hybrid zone does
        # (previously the source list was joined and used as the context).
        merged = original_context
        sources = original_sources
        chunk_count = len(original_sources)

    print(f"π― ROUTER β rag | score: {original_max_score:.3f} | chunks: {chunk_count}")
    return {
        **state,
        "route": "rag",
        "context": merged,
        "sources": sources,
        "score": original_max_score,
    }
def _collect_chunks_with_rrf(queries: list, doc_id: str, k: int = 60):
    """
    Reciprocal Rank Fusion across multiple queries.

    For every query the top 8 chunks are retrieved with ``get_rag_context``;
    each chunk accumulates ``1 / (k + rank)`` per query in which it appears
    (rank is 1-based), and its highest similarity score is remembered.

    Args:
        queries: Query strings to retrieve for.
        doc_id: Document identifier forwarded to ``get_rag_context``.
        k: RRF damping constant added to the rank (default 60).

    Returns:
        ``(texts, scores)``: the unique chunk texts ordered by RRF score
        (ties broken by highest similarity), and the matching list of each
        chunk's maximum similarity score. Both are empty when nothing was
        retrieved.
    """
    from collections import defaultdict

    fused = defaultdict(lambda: {'rrf_score': 0.0, 'max_sim': 0.0})

    for query in queries:
        context, _sources, scores = get_rag_context(query, doc_id, top_k=8)
        if not context:
            continue
        # Assumes get_rag_context joins chunks with this separator, in the
        # same order as ``scores`` -- TODO confirm against rag_service.
        chunks = context.split("\n\n---\n\n")
        for rank, (chunk, score) in enumerate(zip(chunks, scores), start=1):
            chunk = chunk.strip()
            if not chunk:
                continue
            # Key on the full chunk text: a 100-character prefix would merge
            # distinct chunks that merely share a common header or preamble.
            entry = fused[chunk]
            entry['rrf_score'] += 1.0 / (k + rank)
            entry['max_sim'] = max(entry['max_sim'], score)

    # Sort by RRF score (primary) and max similarity (tiebreaker).
    ranked = sorted(
        fused.items(),
        key=lambda item: (item[1]['rrf_score'], item[1]['max_sim']),
        reverse=True,
    )
    texts = [text for text, _ in ranked]
    scores = [stats['max_sim'] for _, stats in ranked]
    print(f"β RRF FUSION β {len(texts)} unique chunks from {len(queries)} queries")
    return texts, scores
|