Spaces:
Sleeping
Sleeping
"""Handles end-to-end policy queries.

Pipeline: retrieve → rerank → build context → LLM → format response.
Retrieval is delegated to ``PolicyRetriever``, which is constructed with the
embedder and vector store.
"""

from __future__ import annotations

from rag_engine.prompts.response_format import ResponseFormatter
from rag_engine.utils.logger import get_logger

logger = get_logger(__name__)
class QueryService:
    """End-to-end query pipeline: retrieve → rerank → context → LLM → format.

    Every pipeline component is created once in ``__init__`` and reused for
    all subsequent queries on this instance.
    """

    # Default pipeline widths. ``k`` bounds how many chunks the reranker has
    # to score; ``top_k_rerank`` is passed to the reranker as ``top_k``, and
    # only the reranked chunks reach the context builder / LLM prompt.
    DEFAULT_K: int = 5  # candidates fetched from the vector store
    DEFAULT_TOP_K_RERANK: int = 3  # candidates kept after reranking
    DEFAULT_K_PER_POLICY: int = 3  # candidates per policy (multi-policy mode)

    def __init__(self) -> None:
        """Build every pipeline component via the project factories."""
        # Project imports are kept function-local, as in the rest of this
        # class. NOTE(review): presumably to defer heavy model imports or to
        # avoid an import cycle — confirm before hoisting to module level.
        from rag_engine.embeddings.embedding_factory import get_embedder
        from rag_engine.llm.llm_factory import get_llm
        from rag_engine.retrieval.context_builder import ContextBuilder
        from rag_engine.retrieval.reranker import CrossEncoderReranker
        from rag_engine.retrieval.retriever import PolicyRetriever
        from rag_engine.vector_store.store_factory import get_vector_store

        self._embedder = get_embedder()
        self._store = get_vector_store()
        # The retriever owns both the store and the embedder, so query
        # embedding happens inside retrieve()/retrieve_multi_policy().
        self._retriever = PolicyRetriever(self._store, self._embedder)
        self._reranker = CrossEncoderReranker()
        self._context_builder = ContextBuilder()
        self._llm = get_llm()
        self._formatter = ResponseFormatter()
        logger.info("QueryService initialized — all components ready")

    # ------------------------------------------------------------------ #
    # shared pipeline helpers
    # ------------------------------------------------------------------ #
    def _build_prompt(self, question: str, raw_results, top_k: int, label: str):
        """Rerank *raw_results*, build the context, and render the LLM prompt.

        Shared by every public entry point so the rerank → context → prompt
        steps cannot drift apart between them.

        Args:
            question: The user's question.
            raw_results: Candidates returned by the retriever.
            top_k: Passed to the reranker as ``top_k``.
            label: Third argument to ``build_query_prompt`` — the policy id
                for single-policy queries, a descriptive label otherwise.

        Returns:
            ``(reranked, prompt)`` — the reranked candidates (the formatter
            needs them later) and the rendered prompt.
        """
        from rag_engine.prompts.context_template import build_query_prompt

        reranked = self._reranker.rerank(question, raw_results, top_k=top_k)
        context = self._context_builder.build(reranked)
        prompt = build_query_prompt(question, context, label)
        return reranked, prompt

    @staticmethod
    def _system_prompt():
        """Return the shared ``SYSTEM_PROMPT`` (imported lazily)."""
        from rag_engine.prompts.system_prompt import SYSTEM_PROMPT

        return SYSTEM_PROMPT

    # ------------------------------------------------------------------ #
    # single-policy query
    # ------------------------------------------------------------------ #
    def query(
        self,
        question: str,
        policy_id: str,
        k: int = DEFAULT_K,
        top_k_rerank: int = DEFAULT_TOP_K_RERANK,
    ) -> dict:
        """Answer *question* using chunks from *policy_id*.

        Args:
            question: The user's question.
            policy_id: The policy whose chunks are searched.
            k: Number of candidates to retrieve before reranking.
            top_k_rerank: Number of candidates kept after reranking.

        Returns:
            The dict produced by ``ResponseFormatter.format_answer``. Its
            ``"source_count"`` entry is read here for logging.
        """
        logger.info(
            "QueryService.query START | policy_id=%s | q='%s'",
            policy_id,
            question[:60],
        )
        raw_results = self._retriever.retrieve(question, policy_id, k=k)
        reranked, prompt = self._build_prompt(
            question, raw_results, top_k_rerank, policy_id
        )
        raw_answer = self._llm.complete(prompt, system=self._system_prompt())
        result = self._formatter.format_answer(raw_answer, question, reranked)
        logger.info(
            "QueryService.query COMPLETE | sources=%d",
            result["source_count"],
        )
        return result

    def stream_query(
        self,
        question: str,
        policy_id: str,
        k: int = DEFAULT_K,
        top_k_rerank: int = DEFAULT_TOP_K_RERANK,
    ):
        """Answer *question* via stream using chunks from *policy_id*.

        This is a generator: retrieval, reranking and prompt building only
        run once the caller starts iterating. It yields whatever items
        ``self._llm.stream`` produces (presumably text tokens — confirm
        against the LLM wrapper). The answer is not passed through the
        response formatter.

        Args:
            question: The user's question.
            policy_id: The policy whose chunks are searched.
            k: Number of candidates to retrieve before reranking.
            top_k_rerank: Number of candidates kept after reranking.
        """
        logger.info(
            "QueryService.stream_query START | policy_id=%s | q='%s'",
            policy_id,
            question[:60],
        )
        raw_results = self._retriever.retrieve(question, policy_id, k=k)
        _reranked, prompt = self._build_prompt(
            question, raw_results, top_k_rerank, policy_id
        )
        # `yield from` (rather than a manual for/yield loop) also forwards
        # close() to the underlying stream when the consumer stops early.
        yield from self._llm.stream(prompt, system=self._system_prompt())
        logger.info("QueryService.stream_query COMPLETE | policy_id=%s", policy_id)

    # ------------------------------------------------------------------ #
    # multi-policy query
    # ------------------------------------------------------------------ #
    def query_multi_policy(
        self,
        question: str,
        policy_ids: list[str],
        k_per_policy: int = DEFAULT_K_PER_POLICY,
        top_k_rerank: int = DEFAULT_TOP_K_RERANK,
    ) -> dict:
        """Answer *question* across multiple policies.

        Args:
            question: The user's question.
            policy_ids: The policies whose chunks are searched.
            k_per_policy: Candidates retrieved from each policy.
            top_k_rerank: Number of candidates kept after reranking the
                combined pool.

        Returns:
            The dict produced by ``ResponseFormatter.format_answer``.
        """
        logger.info(
            "QueryService.query_multi_policy START | policy_ids=%s | q='%s'",
            policy_ids,
            question[:60],
        )
        raw_results = self._retriever.retrieve_multi_policy(
            question, policy_ids, k_per_policy=k_per_policy
        )
        label = f"Multi-policy: {', '.join(policy_ids)}"
        reranked, prompt = self._build_prompt(
            question, raw_results, top_k_rerank, label
        )
        raw_answer = self._llm.complete(prompt, system=self._system_prompt())
        result = self._formatter.format_answer(raw_answer, question, reranked)
        logger.info(
            "QueryService.query_multi_policy COMPLETE | sources=%d",
            result["source_count"],
        )
        return result