Rifqi Hafizuddin commited on
Commit
220f59e
·
1 Parent(s): 15cd3a7

[KM-512] add Pydantic model the LLM fills via function calling in sql_query, and add same signature for db and tabular

Browse files
src/models/sql_query.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Structured output model for LLM-generated SQL queries."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
+ class SQLQuery(BaseModel):
7
+ sql: str = Field(description="A single SQL SELECT statement. No markdown, no explanation inline.")
8
+ reasoning: str = Field(description="One sentence: what this query answers.")
src/query/base.py CHANGED
@@ -3,6 +3,8 @@
3
  from abc import ABC, abstractmethod
4
  from dataclasses import dataclass, field
5
 
 
 
6
  from src.rag.base import RetrievalResult
7
 
8
 
@@ -15,6 +17,7 @@ class QueryResult:
15
  rows: list[dict]
16
  row_count: int
17
  metadata: dict = field(default_factory=dict)
 
18
 
19
 
20
  class BaseExecutor(ABC):
@@ -23,5 +26,6 @@ class BaseExecutor(ABC):
23
  self,
24
  results: list[RetrievalResult],
25
  user_id: str,
 
26
  limit: int = 100,
27
  ) -> list[QueryResult]: ...
 
3
  from abc import ABC, abstractmethod
4
  from dataclasses import dataclass, field
5
 
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+
8
  from src.rag.base import RetrievalResult
9
 
10
 
 
17
  rows: list[dict]
18
  row_count: int
19
  metadata: dict = field(default_factory=dict)
20
+ # metadata should include "column_types": {"col_name": "dtype"} when available
21
 
22
 
23
  class BaseExecutor(ABC):
 
26
  self,
27
  results: list[RetrievalResult],
28
  user_id: str,
29
+ db: AsyncSession,
30
  limit: int = 100,
31
  ) -> list[QueryResult]: ...
src/query/executor.py CHANGED
@@ -2,6 +2,8 @@
2
 
3
  import asyncio
4
 
 
 
5
  from src.middlewares.logging import get_logger
6
  from src.query.base import QueryResult
7
  from src.query.executors.db import db_executor
@@ -16,6 +18,7 @@ class QueryExecutor:
16
  self,
17
  results: list[RetrievalResult],
18
  user_id: str,
 
19
  limit: int = 100,
20
  ) -> list[QueryResult]:
21
  db_results = [r for r in results if r.source_type == "database"]
@@ -29,8 +32,8 @@ class QueryExecutor:
29
  return []
30
 
31
  batches = await asyncio.gather(
32
- db_executor.execute(db_results, user_id, limit) if db_results else _empty(),
33
- tabular_executor.execute(tabular_results, user_id, limit) if tabular_results else _empty(),
34
  return_exceptions=True,
35
  )
36
 
 
2
 
3
  import asyncio
4
 
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
  from src.middlewares.logging import get_logger
8
  from src.query.base import QueryResult
9
  from src.query.executors.db import db_executor
 
18
  self,
19
  results: list[RetrievalResult],
20
  user_id: str,
21
+ db: AsyncSession,
22
  limit: int = 100,
23
  ) -> list[QueryResult]:
24
  db_results = [r for r in results if r.source_type == "database"]
 
32
  return []
33
 
34
  batches = await asyncio.gather(
35
+ db_executor.execute(db_results, user_id, db, limit) if db_results else _empty(),
36
+ tabular_executor.execute(tabular_results, user_id, db, limit) if tabular_results else _empty(),
37
  return_exceptions=True,
38
  )
39
 
src/query/executors/tabular.py CHANGED
@@ -7,6 +7,8 @@ Flow:
7
  4. Return QueryResult per document.
8
  """
9
 
 
 
10
  from src.middlewares.logging import get_logger
11
  from src.query.base import BaseExecutor, QueryResult
12
  from src.rag.base import RetrievalResult
@@ -21,6 +23,7 @@ class TabularExecutor(BaseExecutor):
21
  self,
22
  results: list[RetrievalResult],
23
  user_id: str,
 
24
  limit: int = 100,
25
  ) -> list[QueryResult]:
26
  # TODO: implement
 
7
  4. Return QueryResult per document.
8
  """
9
 
10
+ from sqlalchemy.ext.asyncio import AsyncSession
11
+
12
  from src.middlewares.logging import get_logger
13
  from src.query.base import BaseExecutor, QueryResult
14
  from src.rag.base import RetrievalResult
 
23
  self,
24
  results: list[RetrievalResult],
25
  user_id: str,
26
+ db: AsyncSession,
27
  limit: int = 100,
28
  ) -> list[QueryResult]:
29
  # TODO: implement