Aryan Jain committed on
Commit
53d3e55
·
1 Parent(s): 83e6c59

feat: enhance SQL query generation and validation; add detailed descriptions to Pydantic models and enforce SELECT-only rule in query verification

Browse files
src/prompts/_pydantic_agent.py CHANGED
@@ -1 +1,115 @@
1
- SQL_QUERY_EXTRACTOR_PROMPT = """"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SQL_QUERY_EXTRACTOR_PROMPT = """
2
+ # ROLE
3
+ You are an expert SQL query planner and generator. Your task is to analyze a natural language user query along with provided table schemas and produce an accurate, executable SQL SELECT query.
4
+
5
+ ---
6
+
7
+ # INPUT
8
+
9
+ You will receive:
10
+
11
+ 1. **User Query** — Natural language request describing the data to retrieve.
12
+ 2. **Table Schema** — A subset of tables with their available columns. This is NOT the full database schema.
13
+
14
+ > ⚠️ You MUST only use the tables and columns explicitly provided in the schema. Do not infer, hallucinate, or reference any table or column not listed.
15
+
16
+ ---
17
+
18
+ # OBJECTIVE
19
+
20
+ Analyze the user query and produce:
21
+
22
+ 1. Required tables and their necessary columns
23
+ 2. Join conditions (if multiple tables are involved)
24
+ 3. Filter conditions (if the query implies filtering)
25
+ 4. Order conditions (if the query implies sorting)
26
+ 5. A final, valid SQL SELECT query
27
+
28
+ ---
29
+
30
+ # RULES
31
+
32
+ ## Table & Column Selection
33
+ - Use ONLY tables and columns present in the provided schema.
34
+ - Select ONLY columns needed to answer the query — never use `SELECT *`.
35
+ - If a required column does not exist in the schema, do not fabricate it.
36
+
37
+ ---
38
+
39
+ ## Join Detection
40
+
41
+ If more than one table is required:
42
+ - Set `is_join_required = true`
43
+ - Populate `joins_required` with join conditions
44
+
45
+ Supported join types: `INNER`, `LEFT`, `RIGHT`, `FULL`
46
+
47
+ If only one table is needed:
48
+ - Set `is_join_required = false`
49
+ - Set `joins_required = []`
50
+
51
+ ---
52
+
53
+ ## Filter Detection
54
+
55
+ If the user query contains filtering intent such as:
56
+ - comparisons (`greater than`, `less than`, `equal to`)
57
+ - date ranges (`before`, `after`, `between`)
58
+ - pattern matching (`like`, `contains`, `starts with`)
59
+ - existence checks (`where`, `only`, `exclude`)
60
+
61
+ Then:
62
+ - Set `is_filter_required = true`
63
+ - Populate `filters_required` with each condition
64
+
65
+ Supported operators: `=`, `!=`, `>`, `<`, `>=`, `<=`, `LIKE`, `BETWEEN`, `IN`, `IS NULL`, `IS NOT NULL`
66
+
67
+ If no filtering is needed:
68
+ - Set `is_filter_required = false`
69
+ - Set `filters_required = []`
70
+
71
+ ---
72
+
73
+ ## Order Detection
74
+
75
+ If the query implies sorting such as:
76
+ - `latest`, `oldest`, `most recent`
77
+ - `highest`, `lowest`, `top N`
78
+ - `alphabetical`, `sorted by`
79
+
80
+ Then:
81
+ - Set `is_order_required = true`
82
+ - Populate `orders_required`
83
+ - Use `DESC` for latest/highest, `ASC` for oldest/lowest/alphabetical
84
+
85
+ If no ordering is needed:
86
+ - Set `is_order_required = false`
87
+ - Set `orders_required = []`
88
+
89
+ ---
90
+
91
+ ## SQL Generation Rules
92
+
93
+ The generated SQL query MUST:
94
+ - Use only tables and columns from the provided schema
95
+ - Use correct JOIN syntax with explicit ON conditions
96
+ - Include WHERE clause only if filters exist
97
+ - Include ORDER BY clause only if ordering is required
98
+ - Be valid, executable ANSI SQL
99
+ - Never use `SELECT *`
100
+
101
+ ---
102
+
103
+ ## Safety Rules — STRICT
104
+
105
+ Only generate **SELECT** queries. Never generate:
106
+ - `DELETE`, `DROP`, `UPDATE`, `TRUNCATE`, `ALTER`, `INSERT`
107
+
108
+ If the user query implies a destructive operation, return an error message in `sql_query` explaining that only SELECT queries are supported.
109
+
110
+ ---
111
+
112
+ # TOOL VERIFICATION
113
+
114
+ After generating the SQL query, you must simulate verification using the provided tools.
115
+ """
src/schemas/_pydantic_agent.py CHANGED
@@ -1,71 +1,215 @@
1
  from typing import Optional
2
  from pydantic import BaseModel, Field, model_validator
3
 
 
4
  class Message(BaseModel):
5
- role: str = Field(..., description="Role of the message")
6
- content: str = Field(..., description="Content of the message")
 
 
 
 
7
 
8
  class TableDetails(BaseModel):
9
- table_name: str = Field(..., description="Name of the table")
 
 
 
 
 
 
10
  column_names: list[str] = Field(
11
- ..., description="List of column required for the query"
 
 
 
 
 
 
12
  )
13
 
14
 
15
  class JoinConditions(BaseModel):
16
- join_type: str = Field(..., description="Type of join")
17
- left_table: str = Field(..., description="Name of the left table")
18
- left_table_column: str = Field(..., description="Name of the left table column")
19
- right_table: str = Field(..., description="Name of the right table")
20
- right_table_column: str = Field(..., description="Name of the right table column")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  class FilterConditions(BaseModel):
24
- table_name: str = Field(..., description="Name of the table")
25
- column_name: str = Field(..., description="Name of the column")
26
- operator: str = Field(..., description="Operator for the filter")
27
- value: str = Field(..., description="Value for the filter")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  class OrderConditions(BaseModel):
31
- table_name: str = Field(..., description="Name of the table")
32
- column_name: str = Field(..., description="Name of the column")
33
- order: str = Field(..., description="Order for the filter")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  class SQLQueryExtractor(BaseModel):
37
  tables_required: list[TableDetails] = Field(
38
- ..., description="List of tables required for the query"
 
 
 
 
 
39
  )
40
 
41
  is_join_required: bool = Field(
42
- ..., description="Indicates if join is required for the query"
 
 
 
 
 
43
  )
44
  joins_required: Optional[list[JoinConditions]] = Field(
45
- [], description="List of joins required for the query"
 
 
 
 
 
46
  )
47
 
48
  is_filter_required: bool = Field(
49
- ..., description="Indicates if filter is required for the query"
 
 
 
 
 
 
50
  )
51
  filters_required: Optional[list[FilterConditions]] = Field(
52
- [], description="List of filters required for the query"
 
 
 
 
 
53
  )
54
 
55
  is_order_required: bool = Field(
56
- ..., description="Indicates if order is required for the query"
 
 
 
 
 
57
  )
58
  orders_required: Optional[list[OrderConditions]] = Field(
59
- [], description="List of orders required for the query"
 
 
 
 
 
60
  )
61
 
62
- sql_query: str = Field(..., description="SQL query")
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  is_sql_query_verified_using_provided_tool: bool = Field(
65
- ..., description="Indicates if SQL query is verified using provided tools"
 
 
 
 
 
66
  )
67
  is_sql_query_mark_as_safe_using_provided_tool: bool = Field(
68
- ..., description="Indicates if SQL query is marked as safe using provided tools"
 
 
 
 
69
  )
70
 
71
  @model_validator(mode="after")
 
1
  from typing import Optional
2
  from pydantic import BaseModel, Field, model_validator
3
 
4
+
5
  class Message(BaseModel):
6
+ role: str = Field(
7
+ ...,
8
+ description="Role of the message sender. Must be one of: 'user', 'assistant', or 'system'.",
9
+ )
10
+ content: str = Field(..., description="The text content of the message.")
11
+
12
 
13
  class TableDetails(BaseModel):
14
+ table_name: str = Field(
15
+ ...,
16
+ description=(
17
+ "Exact name of the table as provided in the input schema. "
18
+ "Do NOT use tables that are not explicitly listed in the provided schema."
19
+ ),
20
+ )
21
  column_names: list[str] = Field(
22
+ ...,
23
+ description=(
24
+ "List of column names from this table that are required to answer the user query. "
25
+ "Only include columns that exist in the provided schema for this table. "
26
+ "Do NOT use SELECT * — always list specific columns. "
27
+ "Do NOT fabricate or infer column names not present in the schema."
28
+ ),
29
  )
30
 
31
 
32
  class JoinConditions(BaseModel):
33
+ join_type: str = Field(
34
+ ...,
35
+ description=(
36
+ "Type of SQL JOIN to apply. Must be one of: 'INNER', 'LEFT', 'RIGHT', 'FULL'. "
37
+ "Choose based on the relationship implied by the user query."
38
+ ),
39
+ )
40
+ left_table: str = Field(
41
+ ...,
42
+ description="Name of the left (primary/driving) table in the JOIN. Must exist in the provided schema.",
43
+ )
44
+ left_table_column: str = Field(
45
+ ...,
46
+ description=(
47
+ "Column from the left table used in the JOIN ON condition. "
48
+ "Typically a primary key or foreign key. Must exist in the provided schema."
49
+ ),
50
+ )
51
+ right_table: str = Field(
52
+ ...,
53
+ description="Name of the right (joined) table in the JOIN. Must exist in the provided schema.",
54
+ )
55
+ right_table_column: str = Field(
56
+ ...,
57
+ description=(
58
+ "Column from the right table used in the JOIN ON condition. "
59
+ "Must match the left_table_column semantically (e.g., foreign key relationship). "
60
+ "Must exist in the provided schema."
61
+ ),
62
+ )
63
 
64
 
65
  class FilterConditions(BaseModel):
66
+ table_name: str = Field(
67
+ ...,
68
+ description="Name of the table that contains the column to filter on. Must exist in the provided schema.",
69
+ )
70
+ column_name: str = Field(
71
+ ...,
72
+ description=(
73
+ "Name of the column to apply the filter on. "
74
+ "Must exist in the provided schema for the specified table."
75
+ ),
76
+ )
77
+ operator: str = Field(
78
+ ...,
79
+ description=(
80
+ "SQL comparison operator to use in the WHERE clause. "
81
+ "Allowed values: '=', '!=', '>', '<', '>=', '<=', 'LIKE', 'IN', 'BETWEEN', 'IS NULL', 'IS NOT NULL'. "
82
+ "Choose based on the condition described in the user query."
83
+ ),
84
+ )
85
+ value: str = Field(
86
+ ...,
87
+ description=(
88
+ "The filter value to compare against. "
89
+ "For string values, include surrounding quotes (e.g., \"'active'\"). "
90
+ "For dates, use ISO format (e.g., \"'2024-01-01'\"). "
91
+ "For IN operator, use comma-separated values in parentheses (e.g., \"('a', 'b')\"). "
92
+ "For BETWEEN, use format \"'value1' AND 'value2'\". "
93
+ "For IS NULL / IS NOT NULL, set value to empty string ''."
94
+ ),
95
+ )
96
 
97
 
98
  class OrderConditions(BaseModel):
99
+ table_name: str = Field(
100
+ ...,
101
+ description="Name of the table that contains the column to sort by. Must exist in the provided schema.",
102
+ )
103
+ column_name: str = Field(
104
+ ...,
105
+ description=(
106
+ "Name of the column to sort by. "
107
+ "Must exist in the provided schema for the specified table. "
108
+ "Choose the column that best represents the sorting intent "
109
+ "(e.g., 'created_at' for latest, 'price' for highest/lowest)."
110
+ ),
111
+ )
112
+ order: str = Field(
113
+ ...,
114
+ description=(
115
+ "Sort direction. Must be either 'ASC' or 'DESC'. "
116
+ "Use 'DESC' for: latest, newest, highest, most recent, top. "
117
+ "Use 'ASC' for: oldest, lowest, alphabetical, earliest."
118
+ ),
119
+ )
120
 
121
 
122
  class SQLQueryExtractor(BaseModel):
123
  tables_required: list[TableDetails] = Field(
124
+ ...,
125
+ description=(
126
+ "List of tables needed to answer the user query. "
127
+ "Only include tables from the provided schema that are directly necessary. "
128
+ "Do NOT include unrelated tables. Each entry must specify only the required columns."
129
+ ),
130
  )
131
 
132
  is_join_required: bool = Field(
133
+ ...,
134
+ description=(
135
+ "Set to true if the query requires data from more than one table. "
136
+ "Set to false if a single table is sufficient. "
137
+ "If true, joins_required must be populated."
138
+ ),
139
  )
140
  joins_required: Optional[list[JoinConditions]] = Field(
141
+ [],
142
+ description=(
143
+ "List of JOIN conditions required to combine tables. "
144
+ "Must be populated when is_join_required is true. "
145
+ "Must be empty ([]) when is_join_required is false."
146
+ ),
147
  )
148
 
149
  is_filter_required: bool = Field(
150
+ ...,
151
+ description=(
152
+ "Set to true if the user query contains any filtering intent such as: "
153
+ "comparisons (greater than, less than, equal to), date ranges (before, after, between), "
154
+ "pattern matching (like, contains, starts with), or conditional constraints (only, exclude, where). "
155
+ "Set to false if no filtering is needed. If true, filters_required must be populated."
156
+ ),
157
  )
158
  filters_required: Optional[list[FilterConditions]] = Field(
159
+ [],
160
+ description=(
161
+ "List of WHERE clause conditions extracted from the user query. "
162
+ "Must be populated when is_filter_required is true. "
163
+ "Must be empty ([]) when is_filter_required is false."
164
+ ),
165
  )
166
 
167
  is_order_required: bool = Field(
168
+ ...,
169
+ description=(
170
+ "Set to true if the user query implies sorting, such as: "
171
+ "latest, newest, oldest, highest, lowest, top N, alphabetical, sorted by. "
172
+ "Set to false if no ordering is implied. If true, orders_required must be populated."
173
+ ),
174
  )
175
  orders_required: Optional[list[OrderConditions]] = Field(
176
+ [],
177
+ description=(
178
+ "List of ORDER BY conditions derived from the user query. "
179
+ "Must be populated when is_order_required is true. "
180
+ "Must be empty ([]) when is_order_required is false."
181
+ ),
182
  )
183
 
184
+ sql_query: str = Field(
185
+ ...,
186
+ description=(
187
+ "The final, executable SQL SELECT query that answers the user query. "
188
+ "Rules: "
189
+ "(1) Use only tables and columns from the provided schema. "
190
+ "(2) Never use SELECT * — always list specific columns. "
191
+ "(3) Use table-qualified column names (e.g., table.column) to avoid ambiguity. "
192
+ "(4) Include JOIN clauses if is_join_required is true. "
193
+ "(5) Include WHERE clause if is_filter_required is true. "
194
+ "(6) Include ORDER BY clause if is_order_required is true. "
195
+ "(7) Only generate SELECT queries — never DELETE, DROP, UPDATE, TRUNCATE, ALTER, or INSERT."
196
+ ),
197
+ )
198
 
199
  is_sql_query_verified_using_provided_tool: bool = Field(
200
+ ...,
201
+ description=(
202
+ "Indicates whether the generated sql_query has been verified using the provided verification tool. "
203
+ "This MUST be set true if you call the provided tool to verify the query. "
204
+ "This must be set to false if you do not call the provided tool to verify the query."
205
+ ),
206
  )
207
  is_sql_query_mark_as_safe_using_provided_tool: bool = Field(
208
+ ...,
209
+ description=(
210
+ "Indicates whether the generated sql_query has been marked as safe by the provided safety-check tool. "
211
+ "This must be set True if provided tool return true else set False. "
212
+ ),
213
  )
214
 
215
  @model_validator(mode="after")
src/utils/_pydantic_agent.py CHANGED
@@ -38,6 +38,22 @@ class PydanticAgent:
38
 
39
  async def _verify_sql_query(self, sql_query):
40
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  async with DatabaseConfig.async_session() as session:
42
  await session.execute(sql_query)
43
  except Exception as e:
 
38
 
39
  async def _verify_sql_query(self, sql_query):
40
  try:
41
+ words_shoould_not_present_in_sql_query = [
42
+ "DELETE",
43
+ "DROP",
44
+ "UPDATE",
45
+ "TRUNCATE",
46
+ "ALTER",
47
+ "INSERT"
48
+ ]
49
+ sql_query = sql_query.lower().strip()
50
+ if any(
51
+ word.lower() in sql_query
52
+ for word in words_shoould_not_present_in_sql_query
53
+ ):
54
+ raise Exception(
55
+ f"SQL query contains a destructive operation: {sql_query}. Only SELECT queries are allowed."
56
+ )
57
  async with DatabaseConfig.async_session() as session:
58
  await session.execute(sql_query)
59
  except Exception as e: