Rifqi Hafizuddin commited on
Commit ·
220f59e
1
Parent(s): 15cd3a7
[KM-512] add Pydantic model the LLM fills via function calling in sql_query, and add same signature for db and tabular
Browse files- src/models/sql_query.py +8 -0
- src/query/base.py +4 -0
- src/query/executor.py +5 -2
- src/query/executors/tabular.py +3 -0
src/models/sql_query.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Structured output model for LLM-generated SQL queries."""
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SQLQuery(BaseModel):
|
| 7 |
+
sql: str = Field(description="A single SQL SELECT statement. No markdown, no explanation inline.")
|
| 8 |
+
reasoning: str = Field(description="One sentence: what this query answers.")
|
src/query/base.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
|
|
|
|
|
|
|
| 6 |
from src.rag.base import RetrievalResult
|
| 7 |
|
| 8 |
|
|
@@ -15,6 +17,7 @@ class QueryResult:
|
|
| 15 |
rows: list[dict]
|
| 16 |
row_count: int
|
| 17 |
metadata: dict = field(default_factory=dict)
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class BaseExecutor(ABC):
|
|
@@ -23,5 +26,6 @@ class BaseExecutor(ABC):
|
|
| 23 |
self,
|
| 24 |
results: list[RetrievalResult],
|
| 25 |
user_id: str,
|
|
|
|
| 26 |
limit: int = 100,
|
| 27 |
) -> list[QueryResult]: ...
|
|
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
|
| 6 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 7 |
+
|
| 8 |
from src.rag.base import RetrievalResult
|
| 9 |
|
| 10 |
|
|
|
|
| 17 |
rows: list[dict]
|
| 18 |
row_count: int
|
| 19 |
metadata: dict = field(default_factory=dict)
|
| 20 |
+
# metadata should include "column_types": {"col_name": "dtype"} when available
|
| 21 |
|
| 22 |
|
| 23 |
class BaseExecutor(ABC):
|
|
|
|
| 26 |
self,
|
| 27 |
results: list[RetrievalResult],
|
| 28 |
user_id: str,
|
| 29 |
+
db: AsyncSession,
|
| 30 |
limit: int = 100,
|
| 31 |
) -> list[QueryResult]: ...
|
src/query/executor.py
CHANGED
|
@@ -2,6 +2,8 @@
|
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from src.middlewares.logging import get_logger
|
| 6 |
from src.query.base import QueryResult
|
| 7 |
from src.query.executors.db import db_executor
|
|
@@ -16,6 +18,7 @@ class QueryExecutor:
|
|
| 16 |
self,
|
| 17 |
results: list[RetrievalResult],
|
| 18 |
user_id: str,
|
|
|
|
| 19 |
limit: int = 100,
|
| 20 |
) -> list[QueryResult]:
|
| 21 |
db_results = [r for r in results if r.source_type == "database"]
|
|
@@ -29,8 +32,8 @@ class QueryExecutor:
|
|
| 29 |
return []
|
| 30 |
|
| 31 |
batches = await asyncio.gather(
|
| 32 |
-
db_executor.execute(db_results, user_id, limit) if db_results else _empty(),
|
| 33 |
-
tabular_executor.execute(tabular_results, user_id, limit) if tabular_results else _empty(),
|
| 34 |
return_exceptions=True,
|
| 35 |
)
|
| 36 |
|
|
|
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
|
| 5 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 6 |
+
|
| 7 |
from src.middlewares.logging import get_logger
|
| 8 |
from src.query.base import QueryResult
|
| 9 |
from src.query.executors.db import db_executor
|
|
|
|
| 18 |
self,
|
| 19 |
results: list[RetrievalResult],
|
| 20 |
user_id: str,
|
| 21 |
+
db: AsyncSession,
|
| 22 |
limit: int = 100,
|
| 23 |
) -> list[QueryResult]:
|
| 24 |
db_results = [r for r in results if r.source_type == "database"]
|
|
|
|
| 32 |
return []
|
| 33 |
|
| 34 |
batches = await asyncio.gather(
|
| 35 |
+
db_executor.execute(db_results, user_id, db, limit) if db_results else _empty(),
|
| 36 |
+
tabular_executor.execute(tabular_results, user_id, db, limit) if tabular_results else _empty(),
|
| 37 |
return_exceptions=True,
|
| 38 |
)
|
| 39 |
|
src/query/executors/tabular.py
CHANGED
|
@@ -7,6 +7,8 @@ Flow:
|
|
| 7 |
4. Return QueryResult per document.
|
| 8 |
"""
|
| 9 |
|
|
|
|
|
|
|
| 10 |
from src.middlewares.logging import get_logger
|
| 11 |
from src.query.base import BaseExecutor, QueryResult
|
| 12 |
from src.rag.base import RetrievalResult
|
|
@@ -21,6 +23,7 @@ class TabularExecutor(BaseExecutor):
|
|
| 21 |
self,
|
| 22 |
results: list[RetrievalResult],
|
| 23 |
user_id: str,
|
|
|
|
| 24 |
limit: int = 100,
|
| 25 |
) -> list[QueryResult]:
|
| 26 |
# TODO: implement
|
|
|
|
| 7 |
4. Return QueryResult per document.
|
| 8 |
"""
|
| 9 |
|
| 10 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 11 |
+
|
| 12 |
from src.middlewares.logging import get_logger
|
| 13 |
from src.query.base import BaseExecutor, QueryResult
|
| 14 |
from src.rag.base import RetrievalResult
|
|
|
|
| 23 |
self,
|
| 24 |
results: list[RetrievalResult],
|
| 25 |
user_id: str,
|
| 26 |
+
db: AsyncSession,
|
| 27 |
limit: int = 100,
|
| 28 |
) -> list[QueryResult]:
|
| 29 |
# TODO: implement
|