File size: 2,995 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""QueryPlannerService — single LLM call: question + catalog → JSON IR.

Prompt: src/config/prompts/query_planner.md (system) + the human content
built by `prompt.build_planner_prompt(...)`.

Output: a QueryIR ready for the IRValidator. Validation + retry are the
caller's concern (`QueryService` orchestrates that loop).
"""

from __future__ import annotations

from pathlib import Path

from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_openai import AzureChatOpenAI

from src.middlewares.logging import get_logger

from ...catalog.models import Catalog
from ..ir.models import QueryIR
from .prompt import build_planner_prompt

logger = get_logger("query_planner")

# Location of the planner system prompt: start from the directory two levels
# above the one containing this (resolved) module file, then descend into
# config/prompts/query_planner.md.
_PROMPT_PATH = Path(__file__).resolve().parent.parent.parent.joinpath(
    "config", "prompts", "query_planner.md"
)


def _load_prompt_text() -> str:
    """Return the full contents of the file at ``_PROMPT_PATH``, decoded as UTF-8."""
    with _PROMPT_PATH.open(encoding="utf-8") as prompt_file:
        return prompt_file.read()


def _build_default_chain() -> Runnable:
    """Assemble the default planner chain.

    The chain is a two-message chat prompt (system prompt text loaded by
    ``_load_prompt_text`` plus a ``{human_content}`` human slot) piped into an
    Azure OpenAI chat model configured for structured ``QueryIR`` output.

    Returns:
        The composed ``prompt | structured-output model`` runnable.
    """
    # Settings are imported inside the function, so they are only loaded when
    # a default chain is actually built.
    from src.config.settings import settings

    azure_kwargs = {
        "azure_deployment": settings.azureai_deployment_name_4o,
        "openai_api_version": settings.azureai_api_version_4o,
        "azure_endpoint": settings.azureai_endpoint_url_4o,
        "api_key": settings.azureai_api_key_4o,
        "temperature": 0,
    }
    model = AzureChatOpenAI(**azure_kwargs)

    # The system prompt is supplied as a concrete message instance rather than
    # a ("system", text) template tuple; only the human entry carries a
    # template variable.
    system_message = SystemMessage(content=_load_prompt_text())
    human_slot = ("human", "{human_content}")
    planner_prompt = ChatPromptTemplate.from_messages([system_message, human_slot])

    structured_model = model.with_structured_output(QueryIR)
    return planner_prompt | structured_model


# Module-level cache for the default chain; remains None until the first
# call to _get_default_chain().
_default_chain: Runnable | None = None


def _get_default_chain() -> Runnable:
    """Return the cached default chain, building it on the first call."""
    global _default_chain
    chain = _default_chain
    if chain is None:
        # First use: build once and remember the result for later callers.
        chain = _default_chain = _build_default_chain()
    return chain


class QueryPlannerService:
    """Plan a query by asking an LLM chain for a structured ``QueryIR``.

    Pass ``structured_chain`` to substitute the chain (e.g. in tests); when it
    is omitted, the shared default chain is fetched on first use. The human
    message is produced by ``build_planner_prompt(question, catalog,
    previous_error)``, so a retrying caller can forward the prior error text.
    """

    def __init__(self, structured_chain: Runnable | None = None) -> None:
        """Store the injected chain; ``None`` defers to the default chain."""
        self._chain = structured_chain

    def _ensure_chain(self) -> Runnable:
        """Return the chain in use, fetching and caching the default if unset."""
        chain = self._chain
        if chain is None:
            chain = self._chain = _get_default_chain()
        return chain

    async def plan(
        self,
        question: str,
        catalog: Catalog,
        previous_error: str | None = None,
    ) -> QueryIR:
        """Invoke the chain once and return the resulting IR.

        Args:
            question: The user's question, passed to the prompt builder.
            catalog: The catalog passed to the prompt builder.
            previous_error: Error text from an earlier attempt, or ``None`` on
                the first attempt; also drives the ``retry`` log field.

        Returns:
            The value produced by the chain's ``ainvoke``, returned unchanged.
        """
        # Build the prompt before resolving the chain, so a prompt-building
        # failure never triggers construction of the default chain.
        payload = {
            "human_content": build_planner_prompt(question, catalog, previous_error)
        }
        planned: QueryIR = await self._ensure_chain().ainvoke(payload)

        is_retry = previous_error is not None
        logger.info(
            "query planned",
            source_id=planned.source_id,
            table_id=planned.table_id,
            select_n=len(planned.select),
            filters_n=len(planned.filters),
            retry=is_retry,
        )
        return planned