policylens-rag-api / rag_engine /services /query_service.py
devjhawar's picture
Upload folder using huggingface_hub
6d03b1a verified
"""Handles end-to-end query: preprocess β†’ embed β†’ retrieve β†’ rerank β†’ build context β†’ LLM β†’ format response."""
from __future__ import annotations
from rag_engine.prompts.response_format import ResponseFormatter
from rag_engine.utils.logger import get_logger
logger = get_logger(__name__)
class QueryService:
"""End-to-end query pipeline: retrieve β†’ rerank β†’ context β†’ LLM β†’ format."""
def __init__(self) -> None:
from rag_engine.embeddings.embedding_factory import get_embedder
from rag_engine.llm.llm_factory import get_llm
from rag_engine.retrieval.context_builder import ContextBuilder
from rag_engine.retrieval.reranker import CrossEncoderReranker
from rag_engine.retrieval.retriever import PolicyRetriever
from rag_engine.vector_store.store_factory import get_vector_store
self._embedder = get_embedder()
self._store = get_vector_store()
self._retriever = PolicyRetriever(self._store, self._embedder)
self._reranker = CrossEncoderReranker()
self._context_builder = ContextBuilder()
self._llm = get_llm()
self._formatter = ResponseFormatter()
logger.info("QueryService initialized β€” all components ready")
# ------------------------------------------------------------------ #
# single-policy query
# ------------------------------------------------------------------ #
def query(
self,
question: str,
policy_id: str,
k: int = 5, # reduced from 8 β€” fetch 5 candidates
top_k_rerank: int = 3, # reduced from 5 β€” keep top 3 after rerank
) -> dict:
"""Answer *question* using chunks from *policy_id*."""
logger.info(
"QueryService.query START | policy_id=%s | q='%s'",
policy_id,
question[:60],
)
# Step 1 β€” retrieve
raw_results = self._retriever.retrieve(question, policy_id, k=k)
# Step 2 β€” rerank
reranked = self._reranker.rerank(question, raw_results, top_k=top_k_rerank)
# Step 3 β€” build context
context = self._context_builder.build(reranked)
# Step 4 β€” build prompt
from rag_engine.prompts.context_template import build_query_prompt
from rag_engine.prompts.system_prompt import SYSTEM_PROMPT
prompt = build_query_prompt(question, context, policy_id)
# Step 5 β€” LLM
raw_answer = self._llm.complete(prompt, system=SYSTEM_PROMPT)
# Step 6 β€” format
result = self._formatter.format_answer(raw_answer, question, reranked)
logger.info(
"QueryService.query COMPLETE | sources=%d",
result["source_count"],
)
return result
def stream_query(
self,
question: str,
policy_id: str,
k: int = 5, # reduced from 8
top_k_rerank: int = 3, # reduced from 5
):
"""Answer *question* via stream using chunks from *policy_id*."""
logger.info(
"QueryService.stream_query START | policy_id=%s | q='%s'",
policy_id,
question[:60],
)
# Step 1 β€” retrieve
raw_results = self._retriever.retrieve(question, policy_id, k=k)
# Step 2 β€” rerank
reranked = self._reranker.rerank(question, raw_results, top_k=top_k_rerank)
# Step 3 β€” build context
context = self._context_builder.build(reranked)
# Step 4 β€” build prompt
from rag_engine.prompts.context_template import build_query_prompt
from rag_engine.prompts.system_prompt import SYSTEM_PROMPT
prompt = build_query_prompt(question, context, policy_id)
# Step 5 β€” LLM Stream
for token in self._llm.stream(prompt, system=SYSTEM_PROMPT):
yield token
# ------------------------------------------------------------------ #
# multi-policy query
# ------------------------------------------------------------------ #
def query_multi_policy(
self, question: str, policy_ids: list[str]
) -> dict:
"""Answer *question* across multiple policies."""
raw_results = self._retriever.retrieve_multi_policy(
question, policy_ids, k_per_policy=3 # reduced from 4
)
reranked = self._reranker.rerank(question, raw_results, top_k=3) # reduced from 5
context = self._context_builder.build(reranked)
from rag_engine.prompts.context_template import build_query_prompt
from rag_engine.prompts.system_prompt import SYSTEM_PROMPT
prompt = build_query_prompt(
question, context, f"Multi-policy: {', '.join(policy_ids)}"
)
raw_answer = self._llm.complete(prompt, system=SYSTEM_PROMPT)
return self._formatter.format_answer(raw_answer, question, reranked)