Spaces:
Running
Running
| # app/graph/nodes/router.py | |
| from app.core.rag_service import get_rag_context | |
| from app.core.llm_engine import llm # β Use main llm, not eval_llm | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import PromptTemplate | |
| import ast | |
| # β IMPROVED: Multi-strategy expansion | |
# Prompt text for query expansion. It asks the model for four differently
# angled rewrites of the user's query, returned as a Python list literal.
_EXPANSION_TEMPLATE = (
    "Generate 4 diverse search queries for: '{query}'\n\n"
    "1. Synonym variation (different words, same meaning)\n"
    "2. Acronym/abbreviation expansion (if applicable)\n"
    "3. Broader concept query\n"
    "4. Technical detail query\n\n"
    "Return ONLY a Python list of 4 strings.\n"
    "Example: ['machine learning algorithms', 'ML techniques', 'artificial intelligence methods', 'supervised learning models']\n\n"
    "List:"
)

# The only template variable is the user's query.
expansion_prompt = PromptTemplate(
    template=_EXPANSION_TEMPLATE,
    input_variables=["query"],
)

# prompt -> main LLM -> plain-string output
expansion_chain = expansion_prompt | llm | StrOutputParser()
def _parse_expansions(raw: str) -> list:
    """Extract candidate query strings from the raw LLM reply.

    Accepts either a Python list literal or one query per line. Raises
    whatever ``ast.literal_eval`` raises when a reply that starts with ``[``
    is not a valid literal.
    """
    import re  # local import: only this helper needs it

    # Models often wrap their answer in Markdown code fences; drop those lines.
    body = "\n".join(
        line for line in raw.splitlines() if not line.strip().startswith("```")
    ).strip()

    if body.startswith("["):
        # literal_eval only evaluates literals, so it is safe on model output.
        items = ast.literal_eval(body)
    else:
        # One query per line: remove a leading bullet or "1." / "1)" marker,
        # then any trailing comma and surrounding quotes.
        items = [
            re.sub(r"^\s*(?:[-*]|\d+[.)])\s*", "", line)
            .strip()
            .rstrip(",")
            .strip("\"'")
            for line in body.splitlines()
        ]

    # Keep only non-blank strings; anything else cannot be used as a query.
    return [item.strip() for item in items if isinstance(item, str) and item.strip()]


def expand_query(query: str, max_expansions: int = 4) -> list:
    """Generate diverse query variations for better retrieval.

    Args:
        query: The user's original query.
        max_expansions: Maximum number of variations to add after the
            original query (default 4, matching the prompt).

    Returns:
        A list whose first element is ``query``, followed by up to
        ``max_expansions`` unique variations. If the LLM call or parsing
        fails, or nothing usable is produced, the list is just ``[query]``.
    """
    try:
        raw = expansion_chain.invoke({"query": query}).strip()
        candidates = _parse_expansions(raw)
    except Exception as e:
        # Best-effort step: any failure degrades to the original query only.
        print(f"β οΈ QUERY EXPANSION FAILED β {e}")
        return [query]

    limit = 1 + max(max_expansions, 0)  # original query + expansions
    unique_expansions = [query]
    for candidate in candidates:
        if len(unique_expansions) >= limit:
            break
        if candidate != query and candidate not in unique_expansions:
            unique_expansions.append(candidate)
    return unique_expansions
def router_node(state):
    """
    Route a query to "general", "hybrid" or "rag" from its top retrieval score.

    The original query is retrieved once (top_k=5) and the highest returned
    score selects a zone:

    * score < 0.28          -> "general": empty context, no sources.
    * 0.28 <= score < 0.50  -> "hybrid": the retrieved context and sources as-is.
    * score >= 0.50         -> "rag": the query is expanded, every variation is
      retrieved and fused with RRF, and the top 10 chunks become the context.

    Only the "rag" zone calls the LLM (through expand_query); the other two
    zones are decided from scores alone.

    Args:
        state: Mapping read with ``.get("query")`` and ``.get("doc_id")``.

    Returns:
        A new dict with every key of ``state`` plus "route", "context",
        "sources" and "score" (the top score of the original query).
    """
    HIGH_THRESHOLD = 0.50    # at or above: strong match -> full RAG
    HYBRID_THRESHOLD = 0.28  # below: no useful match -> general answer
    CHUNK_SEPARATOR = "\n\n---\n\n"
    MAX_CHUNKS = 10

    query = state.get("query")
    doc_id = state.get("doc_id")

    # Initial retrieval for the query exactly as given.
    original_context, original_sources, original_scores = get_rag_context(
        query, doc_id, top_k=5
    )
    original_max_score = max(original_scores) if original_scores else 0.0
    print(f"π ORIGINAL QUERY SCORE β {original_max_score:.3f}")

    # ZONE 1: pure general (score below HYBRID_THRESHOLD)
    if original_max_score < HYBRID_THRESHOLD:
        print(f"π ROUTER β general | score: {original_max_score:.3f}")
        return {
            **state,
            "route": "general",
            "context": "",
            "sources": [],
            "score": original_max_score,
        }

    # ZONE 2: hybrid (HYBRID_THRESHOLD <= score < HIGH_THRESHOLD).
    # The lower bound is already guaranteed by the early return above.
    if original_max_score < HIGH_THRESHOLD:
        print(f"π ROUTER β hybrid | score: {original_max_score:.3f}")
        # Hand over the best chunks found; per the original comment the
        # synthesizer supplements them with general knowledge.
        return {
            **state,
            "route": "hybrid",
            "context": original_context,
            "sources": original_sources,
            "score": original_max_score,
        }

    # ZONE 3: full RAG (score >= HIGH_THRESHOLD): query expansion + fusion
    print(f"π HIGH SCORE β Expanding query for better coverage...")
    expanded_queries = expand_query(query)
    print(f"π EXPANDED QUERIES β {expanded_queries}")

    # Collect chunks from all query variations; the fused scores are not used.
    all_contexts, _ = _collect_chunks_with_rrf(expanded_queries, doc_id)

    if not all_contexts:
        # Fall back to the chunks of the original retrieval. The context is
        # split on the same separator _collect_chunks_with_rrf uses, so the
        # fallback yields chunk texts (not source entries) like the fused path.
        all_contexts = [
            chunk.strip()
            for chunk in (original_context or "").split(CHUNK_SEPARATOR)
            if chunk.strip()
        ]

    top_chunks = all_contexts[:MAX_CHUNKS]
    merged = CHUNK_SEPARATOR.join(top_chunks)
    print(f"π― ROUTER β rag | score: {original_max_score:.3f} | chunks: {len(all_contexts)}")
    return {
        **state,
        "route": "rag",
        "context": merged,
        # NOTE(review): in this zone "sources" holds chunk texts, while the
        # hybrid zone passes get_rag_context's sources. Kept as in the
        # original; confirm what downstream nodes expect.
        "sources": top_chunks,
        "score": original_max_score,
    }
def _collect_chunks_with_rrf(queries: list, doc_id: str, k: int = 60):
    """
    Fuse retrieval results from several queries with Reciprocal Rank Fusion.

    RRF formula: score(chunk) = sum over queries of 1 / (k + rank), where rank
    is the chunk's 1-based position in that query's results.

    Args:
        queries: Query strings to retrieve for (top_k=8 each).
        doc_id: Document identifier passed through to get_rag_context.
        k: RRF damping constant; larger values flatten rank differences.

    Returns:
        (texts, scores): unique chunk texts ordered by RRF score, with the
        best similarity score as tiebreaker, and the best similarity score
        seen for each chunk, in the same order.
    """
    separator = "\n\n---\n\n"

    # Keyed by the full chunk text: a truncated prefix would merge distinct
    # chunks that merely start with the same header or boilerplate.
    fused = {}

    for query in queries:
        context, _sources, sim_scores = get_rag_context(query, doc_id, top_k=8)
        if not context:
            continue

        # assumes the i-th chunk of `context` pairs with the i-th score - TODO confirm
        seen_this_query = set()
        for rank, (chunk, sim) in enumerate(
            zip(context.split(separator), sim_scores), start=1
        ):
            chunk = chunk.strip()
            if not chunk:
                continue

            entry = fused.setdefault(chunk, {"text": chunk, "rrf_score": 0.0, "max_sim": 0.0})
            # Each query contributes once per chunk, at its best (first) rank.
            if chunk not in seen_this_query:
                seen_this_query.add(chunk)
                entry["rrf_score"] += 1.0 / (k + rank)
            entry["max_sim"] = max(entry["max_sim"], sim)

    # Sort by RRF score (primary) and best similarity (tiebreaker), descending.
    ranked = sorted(
        fused.values(),
        key=lambda entry: (entry["rrf_score"], entry["max_sim"]),
        reverse=True,
    )
    texts = [entry["text"] for entry in ranked]
    fused_scores = [entry["max_sim"] for entry in ranked]
    print(f"β RRF FUSION β {len(texts)} unique chunks from {len(queries)} queries")
    return texts, fused_scores