ishaq101's picture
feat/Catalog Retrieval System (#1)
6bff5d9
"""QueryPlannerService — single LLM call: question + catalog → JSON IR.
Prompt: src/config/prompts/query_planner.md (system) + the human content
built by `prompt.build_planner_prompt(...)`.
Output: a QueryIR ready for the IRValidator. Validation + retry are the
caller's concern (`QueryService` orchestrates that loop).
"""
from __future__ import annotations
from pathlib import Path
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_openai import AzureChatOpenAI
from src.middlewares.logging import get_logger
from ...catalog.models import Catalog
from ..ir.models import QueryIR
from .prompt import build_planner_prompt
logger = get_logger("query_planner")

# Absolute path to the planner's system prompt. parents[2] is the third
# ancestor of this (resolved) file; per the module docstring the result is
# src/config/prompts/query_planner.md.
_PROMPT_PATH = Path(__file__).resolve().parents[2].joinpath(
    "config", "prompts", "query_planner.md"
)
def _load_prompt_text() -> str:
    """Read the planner system prompt from `_PROMPT_PATH` as UTF-8 text.

    The file is read on every call; nothing is cached here.
    """
    with _PROMPT_PATH.open(encoding="utf-8") as prompt_file:
        return prompt_file.read()
def _build_default_chain() -> Runnable:
    """Build the production planner chain: prompt template -> Azure chat model.

    The model output is parsed into a `QueryIR` via `with_structured_output`.
    The template expects one input variable, `human_content`.
    """
    # Imported here rather than at module top, presumably so settings are
    # only loaded when the default chain is actually needed (e.g. not in
    # tests that inject their own chain).
    from src.config.settings import settings

    chat_model = AzureChatOpenAI(
        azure_deployment=settings.azureai_deployment_name_4o,
        openai_api_version=settings.azureai_api_version_4o,
        azure_endpoint=settings.azureai_endpoint_url_4o,
        api_key=settings.azureai_api_key_4o,
        temperature=0,  # minimise sampling variance for planning
    )
    system_prompt = SystemMessage(content=_load_prompt_text())
    planner_template = ChatPromptTemplate.from_messages(
        [
            # A concrete message object (not a template tuple): the prompt
            # file's text is used verbatim, not parsed for `{variables}`.
            system_prompt,
            ("human", "{human_content}"),
        ]
    )
    ir_model = chat_model.with_structured_output(QueryIR)
    return planner_template | ir_model
# Module-wide default chain, built lazily on first request.
_default_chain: Runnable | None = None


def _get_default_chain() -> Runnable:
    """Return the shared default planner chain, building it on first use.

    There is no locking: concurrent first calls from multiple threads could
    each build a chain, and the last assignment wins.
    """
    global _default_chain
    if _default_chain is not None:
        return _default_chain
    _default_chain = _build_default_chain()
    return _default_chain
class QueryPlannerService:
    """Wraps the LLM call with structured-output parsing into QueryIR.

    Inject `structured_chain` for tests. Otherwise the shared module-level
    default chain is built lazily on the first `plan` call.

    The planner prompt is composed by
    `build_planner_prompt(question, catalog, previous_error)`. Retry callers
    can therefore pass the prior error text to nudge the LLM toward a valid
    plan.
    """

    def __init__(self, structured_chain: Runnable | None = None) -> None:
        # None means "fall back to the lazily built default chain".
        self._chain = structured_chain

    def _ensure_chain(self) -> Runnable:
        """Return the injected chain, resolving and caching the default if unset."""
        if self._chain is None:
            self._chain = _get_default_chain()
        return self._chain

    async def plan(
        self,
        question: str,
        catalog: Catalog,
        previous_error: str | None = None,
    ) -> QueryIR:
        """Ask the LLM to translate `question` into a QueryIR over `catalog`.

        Args:
            question: The user's natural-language question.
            catalog: Catalog the planner may reference.
            previous_error: Error text from a prior failed attempt, if any.
                When not None, it is forwarded to the prompt builder and the
                call is logged as a retry.

        Returns:
            The QueryIR produced by the structured chain. It is not validated
            here; validation is the caller's concern.

        Raises:
            ValueError: if the chain produced no structured output (None).
        """
        human_content = build_planner_prompt(question, catalog, previous_error)
        chain = self._ensure_chain()
        ir: QueryIR | None = await chain.ainvoke({"human_content": human_content})
        if ir is None:
            # Some structured-output configurations return None when the
            # model emits no parseable tool call/JSON (verify against the
            # langchain-openai version in use). Fail with a clear message
            # instead of an AttributeError on the log line below.
            raise ValueError(
                "query planner returned no structured output (expected QueryIR)"
            )
        logger.info(
            "query planned",
            source_id=ir.source_id,
            table_id=ir.table_id,
            select_n=len(ir.select),
            filters_n=len(ir.filters),
            retry=previous_error is not None,
        )
        return ir