# src/orchestrator.py
from typing import List, Dict, Any, Optional, Literal
import re
import src.config as cfg
from src.ingestion.embeddings import get_embedding # provider-agnostic embedding fn used for ingestion
from src.retrieval.retriever import query_pinecone as pinecone_search, deterministic_embedding
from src.retrieval.hybrid import hybrid_search, HybridSearchResult
from src.retrieval.reranker import rerank_chunks, RerankResult
from src.prompts.rag_prompt import build_rag_prompt
from src.query.rewriter import rewrite_query, QueryRewriteResult
from src.context.shaper import shape_context, ContextShapeResult
from src.reasoning.analyzer import analyze_query, QueryAnalysis
from src.reasoning.chain import reason_over_evidence, ReasoningResult
from src.evaluation.tracer import PipelineTracer, format_trace_summary
# -------------------------
# Citation snippet enrichment
# -------------------------
import json
from pathlib import Path as _Path
def _enrich_citations_with_snippets(result: dict, chunk_map: dict) -> dict:
    """
    Mutates `result` in-place: for each citation or source whose snippet is
    empty, set snippet to chunk_map[entry["id"]] if available. Returns
    `result` for convenience.
    """
if not isinstance(result, dict):
return result
for c in result.get("citations", []) + result.get("sources", []):
if not isinstance(c, dict):
continue
if not c.get("snippet"):
s = chunk_map.get(c.get("id"), "")
if s:
c["snippet"] = s
return result
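
# Illustrative enrichment (fills only empty snippets, leaving populated ones untouched):
#   r = {"citations": [{"id": "doc.pdf::0", "snippet": ""}], "sources": []}
#   _enrich_citations_with_snippets(r, {"doc.pdf::0": "Chunk text..."})
#   r["citations"][0]["snippet"]  ->  "Chunk text..."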
def _load_chunks_map(path: str = "data/chunks.jsonl") -> dict:
"""
Load chunks map from JSONL file for citation enrichment.
Args:
path: Path to JSONL file containing chunk data
Returns:
Dictionary mapping chunk IDs to text content
"""
m = {}
pth = _Path(path)
if not pth.exists():
return m
try:
with pth.open("r", encoding="utf-8") as fh:
            for line in fh:
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError:
# Skip malformed JSON lines
continue
                # Prefer explicit "id"; else composite "filename::chunk_id"; else bare chunk_id.
                cid = obj.get("id")
                if not cid and "filename" in obj and "chunk_id" in obj:
                    cid = f"{obj['filename']}::{obj['chunk_id']}"
                if not cid:
                    cid = obj.get("chunk_id")
text = obj.get("text") or obj.get("chunk") or obj.get("content") or ""
if cid:
m[str(cid)] = text
    except Exception:
        # Don't fail if the chunks file can't be loaded
        pass
return m
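
# Example JSONL lines this loader accepts (illustrative):
#   {"id": "report.pdf::0", "text": "First chunk..."}               -> key "report.pdf::0"
#   {"filename": "report.pdf", "chunk_id": 1, "content": "..."}     -> key "report.pdf::1"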
_CURRENT_CHUNKS_PATH = "data/chunks.jsonl"
_CHUNKS_MAP = _load_chunks_map(_CURRENT_CHUNKS_PATH)
def set_chunks_path(path: str) -> int:
"""
Set the chunks path and reload the chunks map.
Args:
path: Path to the chunks JSONL file
Returns:
Number of chunks loaded
"""
global _CHUNKS_MAP, _CURRENT_CHUNKS_PATH
_CURRENT_CHUNKS_PATH = path
_CHUNKS_MAP = _load_chunks_map(path)
return len(_CHUNKS_MAP)
def get_current_chunks_path() -> str:
"""Get the current chunks path."""
return _CURRENT_CHUNKS_PATH
def get_chunks_count() -> int:
"""Get the number of chunks currently loaded."""
return len(_CHUNKS_MAP)
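
# Usage sketch (illustrative; "data/session_chunks.jsonl" is a hypothetical path):
#   n = set_chunks_path("data/session_chunks.jsonl")
#   print(f"loaded {n} chunks from {get_current_chunks_path()}")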
# Try to import a provider wrapper; if missing, use a local deterministic fallback for offline tests.
try:
from src.llm_providers import call_llm # thin wrapper that chooses Gemini/Groq/OpenRouter per config
except Exception:
def call_llm(prompt: str, temperature: float = 0.0, max_tokens: int = 512, **kwargs):
"""
Simple deterministic offline responder: return prompt summary-like text and meta.
Args:
prompt: Prompt text
temperature: Sampling temperature (ignored in fallback)
max_tokens: Maximum tokens (ignored in fallback)
**kwargs: Additional arguments (ignored in fallback)
Returns:
Dict with 'text' and 'meta' keys
"""
        # Echo the last line of the prompt so offline tests get a stable, deterministic reply.
        summary = prompt.strip().split("\n")[-1] if prompt.strip() else "No prompt provided"
        resp_text = f"[offline-llm] Answer based on provided context: {summary}"
return {"text": resp_text, "meta": {"provider": "local-fallback", "temperature": temperature}}
# Legacy prompt template (replaced by src.prompts.rag_prompt)
# Kept for reference during migration
_LEGACY_PROMPT_TEMPLATE = """
You are given a user query and a set of context chunks. Use the context to answer concisely.
Provide a short answer and list the ids of chunks used as citations.
User query:
{query}
Context chunks (top {k} by score):
{context}
Answer:
"""
def _extract_cited_ids_from_llm(text: str) -> List[str]:
"""
Extract cited chunk IDs from LLM response text.
Supports both formats:
- [ID:chunk_id] (new structured prompt format)
- ID:chunk_id (legacy format)
Args:
text: LLM response text
Returns:
List of cited chunk IDs (deduplicated, preserving order)
"""
if not text or not isinstance(text, str):
return []
# Match both [ID:chunk_id] and plain ID:chunk_id formats
# The bracketed format takes precedence in new prompts
bracketed = re.findall(r"\[ID:([A-Za-z0-9_\-:.]+)\]", text)
plain = re.findall(r"(?<!\[)ID:([A-Za-z0-9_\-:.]+)", text)
# Combine and deduplicate while preserving order
seen = set()
result = []
for cid in bracketed + plain:
if cid not in seen:
seen.add(cid)
result.append(cid)
return result
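
# Illustrative extraction (bracketed IDs collected first, then plain, deduped in order):
#   _extract_cited_ids_from_llm("See [ID:doc1::0] and ID:doc2::3, also [ID:doc1::0]")
#   -> ["doc1::0", "doc2::3"]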
def _merge_chunks(chunk_lists: List[List[Dict[str, Any]]], top_k: int) -> List[Dict[str, Any]]:
"""
Merge and deduplicate chunks from multiple query results.
Uses reciprocal rank fusion (RRF) to combine rankings from different queries.
Chunks appearing in multiple query results get boosted scores.
Args:
chunk_lists: List of chunk lists from different queries
top_k: Maximum number of chunks to return
Returns:
Merged and deduplicated list of chunks, sorted by combined score
"""
# RRF constant (standard value)
k = 60
# Track scores by chunk ID
chunk_scores: Dict[str, float] = {}
chunk_data: Dict[str, Dict[str, Any]] = {}
for chunks in chunk_lists:
for rank, chunk in enumerate(chunks):
chunk_id = chunk.get("id", "")
if not chunk_id:
continue
# RRF score contribution
rrf_score = 1.0 / (k + rank + 1)
chunk_scores[chunk_id] = chunk_scores.get(chunk_id, 0) + rrf_score
# Keep the chunk data (use first occurrence)
if chunk_id not in chunk_data:
chunk_data[chunk_id] = chunk
# Sort by combined RRF score
sorted_ids = sorted(chunk_scores.keys(), key=lambda x: chunk_scores[x], reverse=True)
# Build result with updated scores
merged = []
for chunk_id in sorted_ids[:top_k]:
chunk = chunk_data[chunk_id].copy()
chunk["rrf_score"] = chunk_scores[chunk_id]
merged.append(chunk)
return merged
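
# Worked RRF example (illustrative): with k = 60, a chunk ranked first in two
# query results scores 1/61 + 1/61 ≈ 0.0328, while a chunk ranked first in only
# one result scores 1/61 ≈ 0.0164, so agreement across query variants outranks
# a single strong placement.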
def orchestrate_query(
query: str,
top_k: int = 3,
    llm_params: Optional[Dict[str, Any]] = None,
rewrite_strategy: Optional[Literal["expand", "multi", "decompose", "auto", "none"]] = "auto",
use_hybrid: bool = True,
use_reranking: bool = True,
semantic_weight: float = 0.7,
keyword_weight: float = 0.3,
    chunks_path: Optional[str] = None
) -> Dict[str, Any]:
"""
Orchestrate the full RAG query pipeline: query rewriting → hybrid retrieval → reranking → LLM generation.
Args:
query: User query string
top_k: Number of top chunks to retrieve
llm_params: Parameters for LLM call (temperature, max_tokens, etc.)
rewrite_strategy: Query rewriting strategy
- "expand": Rule-based synonym expansion (fast, no LLM)
- "multi": LLM generates multiple query variants
- "decompose": LLM breaks complex query into sub-queries
- "auto": Automatically choose based on query complexity
- "none": Disable query rewriting
use_hybrid: Enable hybrid search (semantic + keyword)
use_reranking: Enable cross-encoder reranking
semantic_weight: Weight for semantic search in hybrid mode (0-1)
keyword_weight: Weight for keyword search in hybrid mode (0-1)
chunks_path: Path to chunks JSONL file (uses default if None)
Returns:
Dict with answer, sources, citations, retrieval info, and metadata
    Note:
        Pipeline failures (rewrite, retrieval, LLM call) are reported via the
        "error" key of llm_meta in the returned dict rather than raised.
    """
if not query or not isinstance(query, str):
return {"answer": "", "sources": [], "citations": [], "llm_meta": {"error": "invalid_query"}}
if llm_params is None:
llm_params = {"temperature": 0.0, "max_tokens": 512}
# Use default chunks path if not specified
if chunks_path is None:
chunks_path = _CURRENT_CHUNKS_PATH
# Validate top_k
if not isinstance(top_k, int) or top_k <= 0:
top_k = 3
# Track retrieval metadata
retrieval_meta = {
"hybrid_enabled": use_hybrid,
"reranking_enabled": use_reranking
}
# 1) Query rewriting
rewrite_result: Optional[QueryRewriteResult] = None
queries_to_search = [query]
if rewrite_strategy and rewrite_strategy != "none":
try:
rewrite_result = rewrite_query(
query=query,
num_variants=3,
strategy=rewrite_strategy,
use_llm=(rewrite_strategy != "expand")
)
            queries_to_search = rewrite_result.rewritten_queries or [query]
except Exception:
queries_to_search = [query]
# 2) Retrieval: hybrid or semantic-only
all_chunk_lists = []
for q in queries_to_search:
try:
if use_hybrid:
# Fetch more for reranking
fetch_k = top_k * 3 if use_reranking else top_k * 2
hybrid_result = hybrid_search(
query=q,
top_k=fetch_k,
semantic_weight=semantic_weight,
keyword_weight=keyword_weight,
chunks_path=chunks_path
)
if hybrid_result.chunks:
all_chunk_lists.append(hybrid_result.chunks)
retrieval_meta["hybrid_strategy"] = hybrid_result.strategy
retrieval_meta["semantic_count"] = hybrid_result.semantic_count
retrieval_meta["keyword_count"] = hybrid_result.keyword_count
else:
# Semantic-only search
chunks = pinecone_search(q, top_k=top_k * 2 if use_reranking else top_k)
if chunks:
all_chunk_lists.append(chunks)
except Exception:
continue
if not all_chunk_lists:
return {"answer": "", "sources": [], "citations": [], "llm_meta": {"error": "retrieval_failed"}}
# 3) Merge if multiple queries were used
if len(all_chunk_lists) == 1:
chunks = all_chunk_lists[0]
else:
chunks = _merge_chunks(all_chunk_lists, top_k=top_k * 2 if use_reranking else top_k)
if not chunks:
return {"answer": "", "sources": [], "citations": [], "llm_meta": {"error": "no_retrieval_results"}}
# 4) Reranking (optional)
if use_reranking and len(chunks) > 1:
try:
rerank_result = rerank_chunks(query=query, chunks=chunks, top_k=top_k)
chunks = rerank_result.chunks
retrieval_meta["rerank_model"] = rerank_result.model_used
retrieval_meta["reranked"] = rerank_result.reranked
except Exception as e:
retrieval_meta["rerank_error"] = str(e)[:100]
chunks = chunks[:top_k]
else:
chunks = chunks[:top_k]
# 5) build prompt using structured prompt template
prompt = build_rag_prompt(query=query, chunks=chunks, k=top_k)
# 6) call LLM via unified provider wrapper
try:
llm_resp = call_llm(prompt=prompt, **llm_params)
except Exception as e:
return {"answer": "", "sources": [], "citations": [], "llm_meta": {"error": f"llm_call_failed: {str(e)}"}}
    # 7) Build sources (snippet from chunk text, falling back to the local chunks map)
    sources: List[Dict[str, Any]] = []
    for c in chunks:
        if not isinstance(c, dict):
            continue
        # Prefer chunk text from the retrieval result; fall back to the local chunk map.
        text_from_chunk = c.get("text") or ""
        if not text_from_chunk:
            text_from_chunk = (
                _CHUNKS_MAP.get(str(c.get("id")), "")
                or _CHUNKS_MAP.get(str(c.get("chunk_id")), "")
            )
        sources.append({
            "id": c.get("id"),
            "score": float(c.get("score", 0.0)),
            "snippet": text_from_chunk[:400]
        })
# 8) Build citations: prefer explicit IDs listed by LLM, else fallback to top sources
cited_ids = _extract_cited_ids_from_llm(llm_resp.get("text", ""))
citations: List[Dict[str, Any]] = []
if cited_ids:
id_map = {s["id"]: s for s in sources if s["id"]}
for cid in cited_ids:
if cid in id_map:
citations.append(id_map[cid])
if not citations:
citations = sources
# 9) assemble result dict, then enrich citation snippets in-place (best-effort)
result = {
"answer": llm_resp.get("text", "").strip() if isinstance(llm_resp, dict) else "",
"sources": sources,
"citations": citations,
"llm_meta": llm_resp.get("meta", {}) if isinstance(llm_resp, dict) else {},
"query_rewrite": {
"original": query,
"rewritten": queries_to_search,
"strategy": rewrite_result.strategy_used if rewrite_result else "none"
} if rewrite_result else None,
"retrieval_meta": retrieval_meta
}
# Best-effort: enrich any empty snippets from the canonical _CHUNKS_MAP
try:
_enrich_citations_with_snippets(result, _CHUNKS_MAP)
except Exception:
# don't fail the whole call if enrichment breaks
pass
return result
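
# Minimal usage sketch (illustrative; assumes a populated index and chunks file):
#   result = orchestrate_query("What does the contract say about termination?", top_k=3)
#   print(result["answer"])
#   for c in result["citations"]:
#       print(c["id"], c["snippet"][:80])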
def orchestrate_advanced(
query: str,
top_k: int = 5,
    llm_params: Optional[Dict[str, Any]] = None,
token_budget: int = 3000,
enable_reasoning: bool = True,
enable_tracing: bool = True,
    chunks_path: Optional[str] = None
) -> Dict[str, Any]:
"""
Advanced RAG orchestration with full context engineering pipeline.
Pipeline:
1. Query Analysis - Classify and decompose query
2. Query Rewriting - Expand/reformulate for better retrieval
3. Hybrid Retrieval - Semantic + keyword search
4. Reranking - Cross-encoder precision boost
5. Context Shaping - Dedup, prune, budget
6. Reasoning - Chain-of-thought synthesis
7. Generation - Produce final answer
Args:
query: User query
top_k: Number of chunks to retrieve
llm_params: LLM parameters
token_budget: Max tokens for context
enable_reasoning: Use reasoning-aware generation
enable_tracing: Capture pipeline trace
chunks_path: Path to chunks JSONL file (uses default if None)
Returns:
        Dict with answer, sources, reasoning, trace, and metadata. Stage errors
        are returned in the result (with trace) when tracing is enabled and
        re-raised otherwise.
"""
if llm_params is None:
llm_params = {"temperature": 0.0, "max_tokens": 800}
# Use default chunks path if not specified
if chunks_path is None:
chunks_path = _CURRENT_CHUNKS_PATH
# Initialize tracer
tracer = PipelineTracer(query) if enable_tracing else None
try:
# 1. Query Analysis
if tracer:
with tracer.trace_stage("query_analysis") as stage:
analysis = analyze_query(query, use_llm=False)
stage.metadata = {
"query_type": analysis.query_type,
"sub_queries": len(analysis.sub_queries),
"reasoning_required": analysis.reasoning_required
}
else:
analysis = analyze_query(query, use_llm=False)
# 2. Query Rewriting
if tracer:
with tracer.trace_stage("query_rewrite") as stage:
rewrite_result = rewrite_query(
query=query,
strategy="auto",
use_llm=True
)
stage.metadata = {
"strategy": rewrite_result.strategy_used,
"variants": len(rewrite_result.rewritten_queries)
}
else:
rewrite_result = rewrite_query(query, strategy="auto", use_llm=True)
queries_to_search = rewrite_result.rewritten_queries
# 3. Hybrid Retrieval
all_chunks = []
if tracer:
with tracer.trace_stage("retrieval") as stage:
for q in queries_to_search:
result = hybrid_search(q, top_k=top_k * 2, chunks_path=chunks_path)
all_chunks.extend(result.chunks)
stage.metadata = {
"queries_searched": len(queries_to_search),
"chunks_found": len(all_chunks)
}
else:
for q in queries_to_search:
result = hybrid_search(q, top_k=top_k * 2, chunks_path=chunks_path)
all_chunks.extend(result.chunks)
if not all_chunks:
return {"answer": "", "error": "No chunks retrieved"}
# Deduplicate across queries
seen_ids = set()
unique_chunks = []
for c in all_chunks:
cid = c.get("id")
if cid and cid not in seen_ids:
seen_ids.add(cid)
unique_chunks.append(c)
# 4. Reranking
if tracer:
with tracer.trace_stage("reranking") as stage:
rerank_result = rerank_chunks(query, unique_chunks, top_k=top_k * 2)
stage.metadata = {
"input_chunks": len(unique_chunks),
"output_chunks": len(rerank_result.chunks),
"model": rerank_result.model_used
}
else:
rerank_result = rerank_chunks(query, unique_chunks, top_k=top_k * 2)
chunks = rerank_result.chunks
# 5. Context Shaping
if tracer:
with tracer.trace_stage("context_shaping") as stage:
shape_result = shape_context(
chunks=chunks,
query=query,
token_budget=token_budget,
enable_pruning=True,
enable_compression=True
)
stage.metadata = {
"original_tokens": shape_result.original_tokens,
"final_tokens": shape_result.final_tokens,
"chunks_removed": shape_result.chunks_removed,
"compression": shape_result.compression_applied
}
else:
shape_result = shape_context(
chunks=chunks,
query=query,
token_budget=token_budget
)
shaped_chunks = shape_result.chunks[:top_k]
# 6. Reasoning or Standard Generation
if enable_reasoning and analysis.reasoning_required:
if tracer:
with tracer.trace_stage("reasoning") as stage:
reasoning_result = reason_over_evidence(
query=query,
chunks=shaped_chunks,
query_type=analysis.query_type
)
stage.metadata = {
"reasoning_type": reasoning_result.reasoning_type,
"evidence_used": len(reasoning_result.evidence_used),
"confidence": reasoning_result.confidence
}
else:
reasoning_result = reason_over_evidence(
query=query,
chunks=shaped_chunks,
query_type=analysis.query_type
)
answer = reasoning_result.answer
reasoning_info = {
"steps": reasoning_result.reasoning_steps,
"evidence_used": reasoning_result.evidence_used,
"confidence": reasoning_result.confidence,
"type": reasoning_result.reasoning_type
}
else:
# Standard generation
if tracer:
with tracer.trace_stage("generation") as stage:
prompt = build_rag_prompt(query, shaped_chunks, k=top_k)
llm_resp = call_llm(prompt=prompt, **llm_params)
answer = llm_resp.get("text", "").strip()
stage.metadata = {
"prompt_length": len(prompt),
"answer_length": len(answer)
}
else:
prompt = build_rag_prompt(query, shaped_chunks, k=top_k)
llm_resp = call_llm(prompt=prompt, **llm_params)
answer = llm_resp.get("text", "").strip()
reasoning_info = None
        # Build sources (snippet from chunk text, falling back to the local chunks map)
        sources = []
        for c in shaped_chunks:
            text = c.get("text", "") or _CHUNKS_MAP.get(str(c.get("id")), "")
            sources.append({
                "id": c.get("id"),
                "score": c.get("score", 0),
                "snippet": text[:300]
            })
# Build result
result = {
"answer": answer,
"sources": sources,
"query_analysis": {
"type": analysis.query_type,
"sub_queries": analysis.sub_queries,
"reasoning_required": analysis.reasoning_required
},
"context_shaping": {
"original_tokens": shape_result.original_tokens,
"final_tokens": shape_result.final_tokens,
"compression_applied": shape_result.compression_applied
}
}
if reasoning_info:
result["reasoning"] = reasoning_info
if tracer:
tracer.set_answer(answer)
result["trace"] = tracer.to_dict()
return result
except Exception as e:
if tracer:
tracer.set_error(str(e))
return {
"answer": "",
"error": str(e),
"trace": tracer.to_dict()
}
raise
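
if __name__ == "__main__":
    # Illustrative smoke test, not part of the pipeline: runs both entry points
    # against whatever index/chunks are configured, using the offline LLM
    # fallback when src.llm_providers is unavailable.
    demo_query = "Summarize the key obligations in the uploaded documents."
    basic = orchestrate_query(demo_query, top_k=3)
    print("basic answer:", basic.get("answer", "")[:200])
    advanced = orchestrate_advanced(demo_query, top_k=5, enable_tracing=True)
    print("advanced answer:", advanced.get("answer", "")[:200])
    if advanced.get("trace"):
        # Assumes format_trace_summary accepts the dict from PipelineTracer.to_dict();
        # adjust if its signature differs.
        print(format_trace_summary(advanced["trace"]))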