kltn21110 commited on
Commit
24bc84b
·
verified ·
1 Parent(s): 6d0dfcd

Update function/analyze/execute_query.py

Browse files
Files changed (1) hide show
  1. function/analyze/execute_query.py +397 -394
function/analyze/execute_query.py CHANGED
@@ -1,394 +1,397 @@
1
- import sys
2
- import os
3
-
4
- # Thêm thư mục gốc (chứa thư mục "function") vào sys.path
5
- sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
6
- import os
7
- from dotenv import load_dotenv
8
-
9
- # Load biến môi trường từ file .env
10
- load_dotenv()
11
-
12
- DB_HOST = os.getenv("DB_HOST")
13
- DB_USER = os.getenv("DB_USER")
14
- DB_PASSWORD = os.getenv("DB_PASSWORD")
15
- DB_NAME = os.getenv("DB_NAME")
16
- DB_PORT = os.getenv("DB_PORT")
17
-
18
- import os
19
- from urllib.parse import quote
20
-
21
- password = os.getenv("DB_PASSWORD") # VD: 'Yahana0509@'
22
- DB_PASSWORD = quote(password)
23
- # Tạo connection string
24
- connection_uri = f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
25
- from langchain_community.utilities.sql_database import SQLDatabase
26
- from langchain_experimental.sql import SQLDatabaseChain
27
- import sys
28
- import os
29
- import pymysql
30
- from fastapi import HTTPException
31
- import sys
32
- import os
33
-
34
-
35
- from fastapi.encoders import jsonable_encoder
36
- import re
37
- from function.prompt import prompt_main as prompt
38
- import function.prompt.prompt_custom as prompt_cus
39
- os.environ["GOOGLE_API_KEY"] = "AIzaSyCO-RlqYewC4e9BEPoC8m-AxHUY7J3_o2E"
40
- from bson import ObjectId
41
- db = SQLDatabase.from_uri(connection_uri)
42
- from dotenv import load_dotenv
43
- import filter.filter_role as filter_role_1
44
- import filter.filter_sql_injection as filter_sql_injection_1
45
- import filter.result as query_result_1
46
- import response.ResponseChat as res_chat
47
- from datetime import datetime
48
- import pytz
49
- from mongoengine import connect
50
- import os
51
- import nltk
52
- import function.agent.pipeline_agent as pipeline_agent
53
- nltk.download('punkt')
54
- from models.Database_Entity import User, ChatHistory, DetailChat
55
- from dotenv import load_dotenv
56
- load_dotenv()
57
- MONGO_URI = os.getenv("MONGO_URI", "")
58
- connect("chatbot_hmdrinks", host=MONGO_URI)
59
- import re
60
-
61
- def contains_delete(sql: str) -> bool:
62
- return bool(re.search(r'\bdelete\b', sql, re.IGNORECASE))
63
- load_dotenv()
64
-
65
- #setup model
66
- from bson import ObjectId
67
- import random
68
- from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
69
- llm1 = ChatGoogleGenerativeAI(model='gemini-2.5-flash-preview-05-20',temperature=0.5)
70
-
71
- #setup kết nối database
72
- db = SQLDatabase.from_uri(connection_uri)
73
- db_chain = SQLDatabaseChain.from_llm(llm=llm1,db=db,prompt= prompt.PROMPT)
74
-
75
- from prompt.prompt_syntax_insert import is_insert_related_to_product_category_variant, filter_syntax_sql
76
- import sqlparse
77
-
78
-
79
- import sys
80
- import os
81
- sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
82
- from prompt import prompt_detail_table
83
-
84
- schema_mapping = {
85
- "payments": prompt_detail_table.prompt_payments,
86
- "user": prompt_detail_table.prompt_users,
87
- "user_voucher": prompt_detail_table.prompt_user_voucher,
88
- "category": prompt_detail_table.prompt_categort,
89
- "category_translation": prompt_detail_table.prompt_category_translation,
90
- "cart": prompt_detail_table.prompt_cart,
91
- "cart_item": prompt_detail_table.prompt_cart_item,
92
- "orders":prompt_detail_table.prompt_orders,
93
- "order_item": prompt_detail_table.prompt_order_item,
94
- "payment": prompt_detail_table.prompt_payments,
95
- "favourite": prompt_detail_table.prompt_favourite,
96
- "favourite_item": prompt_detail_table.prompt_fav_item,
97
- "post": prompt_detail_table.prompt_post,
98
- "post_translation": prompt_detail_table.prompt_post_translation,
99
- "product": prompt_detail_table.prompt_product,
100
- "product_translation": prompt_detail_table.prompt_product_translation,
101
- "shipment": prompt_detail_table.prompt_shipment,
102
- "product_variants": prompt_detail_table.prompt_product_variants,
103
- "review": prompt_detail_table.prompt_review,
104
- "user_coin": prompt_detail_table.prompt_user_coin,
105
- "absence": prompt_detail_table.prompt_absence,
106
- "cart_group": prompt_detail_table.prompt_cart_group,
107
- "cart_item_group": prompt_detail_table.prompt_cartitem_group,
108
- "group_orders": prompt_detail_table.prompt_group_orders,
109
- "payments_group": prompt_detail_table.prompt_payments_group,
110
- "group_order_members":prompt_detail_table.prompt_group_orders_member,
111
- "shipment_group":prompt_detail_table.prompt_shipment_group,
112
- "shipper_attendance":prompt_detail_table.prompt_shipper_attendance,
113
- "shipper_commission_detail":prompt_detail_table.prompt_shipper_commission_detail,
114
- "shipper_salary_summary":prompt_detail_table.prompt_shipper_salary_summary,
115
- "voucher": prompt_detail_table.prompt_voucher
116
- }
117
-
118
-
119
- def get_schemas_from_sql(sql_query: str, schema_mapping: dict):
120
- import sqlglot
121
- print("SQL query:", sql_query)
122
- parsed_query = sqlglot.parse_one(sql_query, read="mysql")
123
-
124
- # Lấy danh sách bảng duy nhất trong query
125
- table_names = list({t.name for t in parsed_query.find_all(sqlglot.exp.Table)})
126
-
127
- schemas_used = {}
128
- for table in table_names:
129
- if table in schema_mapping:
130
- schemas_used[table] = schema_mapping[table]
131
- else:
132
- print(f"⚠️ Warning: Table '{table}' not found in schema_mapping")
133
-
134
- # Gom toàn bộ schema thành 1 chuỗi duy nhất
135
- all_schemas = "\n\n".join(
136
- [f"Schema for table '{table}':\n{schemas_used[table]}" for table in schemas_used]
137
- )
138
-
139
- return all_schemas
140
-
141
-
142
- def build_sql_fix_prompt(schemas_result: dict, sql: str,user_id: int) -> str:
143
-
144
- prompt = f"""
145
- Bạn một chuyên gia sở dữ liệu.
146
-
147
- Dưới đây mô tả schema chi tiết của các bảng có trong hệ thống:
148
-
149
- {schemas_result}
150
-
151
- ---
152
-
153
- Dưới đây là một câu SQL đang bị lỗi do không đúng tên bảng hoặc tên cột:
154
-
155
- ```sql
156
- {sql.strip()}
157
- Yêu cầu của bạn là:
158
- User Id: {user_id}
159
- Dựa trên các schema ở trên, hãy kiểm tra và chỉnh sửa câu SQL sao cho:
160
- Tên bảng, tên cột phải chính xác theo schema.
161
- Logic mục đích của truy vấn được giữ nguyên.
162
- Chỉ trả lại phần SQL đã được chỉnh sửa (không giải thích, không chú thích, không thêm nhận xét).
163
- Loại bỏ các câu comment chỉ trả lại câu SQL chính xác sau khi bạn đã chỉnh sửa.
164
- Vui lòng chỉnh sửa một cách chính xác nhất.
165
- Trả lời dưới dạng một truy vấn SQL duy nhất.
166
- ** Tham khảo thêm các lỗi sau để tránh:
167
- - "- Tránh các lỗi như :\n"
168
- " (1054, \"Unknown column 'oi.pro_id' in 'field list'\")\n . Luôn đảm bảo bạn không bao giờ bị lỗi này"
169
- " (1054, \"Unknown column 'oi.note' in 'field list'\") . Luôn đảm bảo bạn không bao giờ bị lỗi này\n"
170
- " (1054, \"Unknown column 'oi.size' in 'field list'\") . Luôn đảm bảo bạn không bao giờ bị lỗi này \n"
171
- " (1054, \"Unknown column 'c.is_deleted' in 'on clause'\"). Luôn đảm bảo bạn không bao giờ bị lỗi này\n"
172
- """
173
- return prompt
174
-
175
-
176
-
177
-
178
- async def execute_query_user(user_input: str, user_id: int, languages: str, role: str):
179
- PROMPT_CUSTOM = await prompt_cus.get_prompt_custom(user_input)
180
- check_insert = is_insert_related_to_product_category_variant(user_input)
181
-
182
- db_config = {
183
- "host": os.getenv("DB_HOST"),
184
- "user": os.getenv("DB_USER"),
185
- "database": os.getenv("DB_NAME"),
186
- "password": os.getenv("DB_PASSWORD"),
187
- "port": int(os.getenv("DB_PORT", 3306)),
188
- "charset": "utf8mb4",
189
- "cursorclass": pymysql.cursors.DictCursor,
190
- }
191
-
192
-
193
- def regenerate_sql_until_safe():
194
- max_retry = 5
195
- retry_count = 0
196
-
197
- while retry_count < max_retry:
198
- try:
199
- regenerated_data = db_chain.run(f"""
200
- Role: {text_role}
201
- Language: {languages}
202
- Question: {user_input}.
203
- """)
204
- regenerated_sql = extract_sql_from_response(regenerated_data)
205
- if regenerated_sql:
206
- regenerated_sql = clean_sql(regenerated_sql)
207
- if not re.search(r"%{1,2}s", regenerated_sql): # đã sạch
208
- return regenerated_sql
209
- retry_count += 1
210
- except Exception as e:
211
- return f"❌ Lỗi khi tạo lại truy vấn lần {retry_count + 1}: {str(e)}"
212
-
213
- return "❌ Lỗi: Không thể tạo được truy vấn an toàn sau nhiều lần thử."
214
-
215
-
216
- def execute_query_with_pymysql(query, multi=False):
217
- connection = pymysql.connect(**db_config)
218
- try:
219
- with connection.cursor() as cursor:
220
- results = []
221
- if multi:
222
- # Dùng sqlparse để chia câu truy vấn cho an toàn
223
- statements = sqlparse.split(query)
224
- for stmt in statements:
225
- stmt = stmt.strip()
226
- if stmt:
227
- cursor.execute(stmt)
228
- results.append(cursor.fetchall())
229
- else:
230
- cursor.execute(query)
231
- results = cursor.fetchall()
232
- connection.commit()
233
- return results
234
- except pymysql.MySQLError as e:
235
- return str(e)
236
- finally:
237
- connection.close()
238
-
239
- def clean_sql(sql) -> str:
240
- print("Clean sql1:", sql)
241
- if isinstance(sql, dict) and sql:
242
- first_value = next(iter(sql.values()))
243
- sql = first_value
244
- sql = re.sub(r"```sql", "", sql, flags=re.IGNORECASE)
245
- sql = sql.replace("```sql", "")
246
- sql = re.sub(r"```", "", sql)
247
- return sql.strip()
248
- # sql = re.sub(r"```sql", "", sql, flags=re.IGNORECASE)
249
- # sql = sql.replace("```sql", "")
250
- # # sql = re.sub(r'%%s', r'%s', sql)
251
- # sql = re.sub(r"```", "", sql)
252
- # return sql.strip()
253
-
254
- def extract_sql_from_response(data):
255
- match = re.search(r"\[SQL:\s*```sql\s*(.*?)\s*```]", data, re.DOTALL)
256
- if match:
257
- return clean_sql(match.group(1))
258
- match = re.search(r"SQLQuery:\s*(.*)", data, re.DOTALL)
259
- if match:
260
- return clean_sql(match.group(1))
261
-
262
- return None
263
-
264
- def extract_sql_from_error(error_msg):
265
- # Case 1: [SQL: ```sql ... ```]
266
- match = re.search(r"\[SQL:\s*```sql\s*(.*?)\s*```]", error_msg, re.DOTALL)
267
- if match:
268
- return clean_sql(match.group(1))
269
-
270
-
271
- match = re.search(r"```(?:sql)?\s*\r?\n(.*?)```", error_msg, re.DOTALL)
272
- if match:
273
- return clean_sql(match.group(1))
274
-
275
- return None
276
-
277
- def process_and_execute_sql(sql):
278
- print("NHận sql: ",sql )
279
- data = sql
280
- if isinstance(data, dict) and data:
281
- print(" dict")
282
- first_key, first_value = next(iter(data.items()))
283
- sql = first_value
284
- elif isinstance(data, list) and data:
285
- sql = "\n\n".join(data)
286
-
287
- sql_clean = clean_sql(sql)
288
- print("SQL Clean: ", sql_clean)
289
- if re.search(r"%{1,2}s", sql_clean):
290
- regenerated_sql = regenerate_sql_until_safe()
291
- sql_clean = clean_sql(regenerated_sql)
292
- result = get_schemas_from_sql(sql_clean, schema_mapping)
293
- prompt = build_sql_fix_prompt(schemas_result=result,sql = str(sql_clean),user_id =user_id)
294
- from advance_shopping.call_gemini import tool_call
295
- data = tool_call.generate(prompt = prompt)
296
- sql_clean = clean_sql(data)
297
- print("SQL step2: ", sql_clean)
298
- if contains_delete(sql_clean):
299
- return "Lỗi: Bạn không dược phép thực hiện truy vấn DELETE trong hệ thống này."
300
-
301
- if re.search(r"\bIF\b.*\bTHEN\b", sql_clean, re.IGNORECASE):
302
- return "Lỗi: Không được dùng IF...THEN trong SQL. Vui lòng chia nhỏ truy vấn."
303
-
304
- if check_insert:
305
- check_syntax = filter_syntax_sql(sql_clean, PROMPT_CUSTOM, user_input)
306
- if check_syntax is True:
307
- # Tách từng câu và thực thi tuần tự
308
- try:
309
- connection = pymysql.connect(**db_config)
310
- with connection.cursor() as cursor:
311
- statements = sqlparse.split(sql_clean)
312
- results = []
313
- for stmt in statements:
314
- stmt = stmt.strip()
315
- if stmt:
316
- cursor.execute(stmt)
317
- try:
318
- results.append(cursor.fetchall())
319
- except:
320
- results.append("✅ OK") # Có thể là câu SET hoặc INSERT
321
- connection.commit()
322
- return results
323
- except Exception as e:
324
- return f"❌ Lỗi khi thực thi từng truy vấn: {str(e)}"
325
- finally:
326
- connection.close()
327
- else:
328
- try:
329
- regenerated_data = db_chain.run(f"""
330
- Role: {text_role}
331
- Language: {languages}
332
- Question: {user_input}.
333
- """)
334
- regenerated_sql = extract_sql_from_response(regenerated_data)
335
- if regenerated_sql:
336
- return process_and_execute_sql(regenerated_sql)
337
- else:
338
- return "❌ Lỗi: Không thể tạo lại truy vấn hợp lệ."
339
- except Exception as regen_error:
340
- return f"❌ Lỗi khi tạo lại truy vấn: {str(regen_error)}"
341
- else:
342
- return execute_query_with_pymysql(sql_clean, multi=True)
343
-
344
-
345
-
346
- PROMPT_CUSTOM.template += (
347
- "\n\n⚠️ Lưu ý: Do NOT use IF...THEN...ELSE...END in SQL. "
348
- "Only use plain SELECT, INSERT, UPDATE, JOIN, SET statements. "
349
- "Tuyệt đối cấm dùng %%s hoặc %s để lấy giá trị từ biến. "
350
- "SQL phải chứa giá trị cố định, không dùng placeholder hay biến bind."
351
- )
352
-
353
- db_chain = SQLDatabaseChain.from_llm(llm=llm1, db=db, prompt=PROMPT_CUSTOM)
354
- text_role = f"{role} (userId = {user_id})" if role == "ADMIN" else f"{role} (userId = {user_id}), not role ADMIN"
355
-
356
- try:
357
- data = db_chain.run(f"""
358
- Role: {text_role}
359
- Language: {languages}
360
- Question: {user_input}.
361
- """)
362
- extracted_sql = extract_sql_from_response(data)
363
- if extracted_sql:
364
- return process_and_execute_sql(extracted_sql)
365
- else:
366
- return data
367
-
368
- except Exception as e:
369
- error_message = str(e)
370
- extracted_sql = extract_sql_from_error(error_message)
371
-
372
- if extracted_sql:
373
- fix_sql = re.sub(r"```sql", "", extracted_sql)
374
- fix_sql = extracted_sql.replace("```", "")
375
- fix_sql = extracted_sql.replace("```sql", "")
376
- # fix_sql = re.sub(r'%%s', r'%s', fix_sql)
377
- if contains_delete(fix_sql):
378
- return "Lỗi: Bạn không dược phép thực hiện truy vấn DELETE trong hệ thống này."
379
- return process_and_execute_sql(fix_sql)
380
- else:
381
- # Không trích xuất được SQL -> Thử tạo lại truy vấn từ đầu
382
- try:
383
- regenerated_data = db_chain.run(f"""
384
- Role: {text_role}
385
- Language: {languages}
386
- Question: {user_input}.
387
- """)
388
- regenerated_sql = extract_sql_from_response(regenerated_data)
389
- if regenerated_sql:
390
- return process_and_execute_sql(regenerated_sql)
391
- else:
392
- return "❌ Lỗi: Không thể tạo lại truy vấn hợp lệ sau lỗi."
393
- except Exception as regen_error:
394
- return f"❌ Lỗi khi tạo lại truy vấn từ lỗi: {str(regen_error)}"
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ # Thêm thư mục gốc (chứa thư mục "function") vào sys.path
5
+ sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
6
+ import os
7
+ from dotenv import load_dotenv
8
+
9
+ # Load biến môi trường từ file .env
10
+ load_dotenv()
11
+
12
+ DB_HOST = os.getenv("DB_HOST")
13
+ DB_USER = os.getenv("DB_USER")
14
+ DB_PASSWORD = os.getenv("DB_PASSWORD")
15
+ DB_NAME = os.getenv("DB_NAME")
16
+ DB_PORT = os.getenv("DB_PORT")
17
+
18
+ import os
19
+ from urllib.parse import quote
20
+
21
+ password = os.getenv("DB_PASSWORD") # VD: 'Yahana0509@'
22
+ DB_PASSWORD = quote(password)
23
+ # Tạo connection string
24
+ connection_uri = (
25
+ f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
26
+ "?ssl_verify_cert=false&ssl_verify_identity=false"
27
+ )
28
+ from langchain_community.utilities.sql_database import SQLDatabase
29
+ from langchain_experimental.sql import SQLDatabaseChain
30
+ import sys
31
+ import os
32
+ import pymysql
33
+ from fastapi import HTTPException
34
+ import sys
35
+ import os
36
+
37
+
38
+ from fastapi.encoders import jsonable_encoder
39
+ import re
40
+ from function.prompt import prompt_main as prompt
41
+ import function.prompt.prompt_custom as prompt_cus
42
+ os.environ["GOOGLE_API_KEY"] = "AIzaSyCO-RlqYewC4e9BEPoC8m-AxHUY7J3_o2E"
43
+ from bson import ObjectId
44
+ db = SQLDatabase.from_uri(connection_uri)
45
+ from dotenv import load_dotenv
46
+ import filter.filter_role as filter_role_1
47
+ import filter.filter_sql_injection as filter_sql_injection_1
48
+ import filter.result as query_result_1
49
+ import response.ResponseChat as res_chat
50
+ from datetime import datetime
51
+ import pytz
52
+ from mongoengine import connect
53
+ import os
54
+ import nltk
55
+ import function.agent.pipeline_agent as pipeline_agent
56
+ nltk.download('punkt')
57
+ from models.Database_Entity import User, ChatHistory, DetailChat
58
+ from dotenv import load_dotenv
59
+ load_dotenv()
60
+ MONGO_URI = os.getenv("MONGO_URI", "")
61
+ connect("chatbot_hmdrinks", host=MONGO_URI)
62
+ import re
63
+
64
+ def contains_delete(sql: str) -> bool:
65
+ return bool(re.search(r'\bdelete\b', sql, re.IGNORECASE))
66
+ load_dotenv()
67
+
68
+ #setup model
69
+ from bson import ObjectId
70
+ import random
71
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
72
+ llm1 = ChatGoogleGenerativeAI(model='gemini-2.5-flash-preview-05-20',temperature=0.5)
73
+
74
+ #setup kết nối database
75
+ db = SQLDatabase.from_uri(connection_uri)
76
+ db_chain = SQLDatabaseChain.from_llm(llm=llm1,db=db,prompt= prompt.PROMPT)
77
+
78
+ from prompt.prompt_syntax_insert import is_insert_related_to_product_category_variant, filter_syntax_sql
79
+ import sqlparse
80
+
81
+
82
+ import sys
83
+ import os
84
+ sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
85
+ from prompt import prompt_detail_table
86
+
87
+ schema_mapping = {
88
+ "payments": prompt_detail_table.prompt_payments,
89
+ "user": prompt_detail_table.prompt_users,
90
+ "user_voucher": prompt_detail_table.prompt_user_voucher,
91
+ "category": prompt_detail_table.prompt_categort,
92
+ "category_translation": prompt_detail_table.prompt_category_translation,
93
+ "cart": prompt_detail_table.prompt_cart,
94
+ "cart_item": prompt_detail_table.prompt_cart_item,
95
+ "orders":prompt_detail_table.prompt_orders,
96
+ "order_item": prompt_detail_table.prompt_order_item,
97
+ "payment": prompt_detail_table.prompt_payments,
98
+ "favourite": prompt_detail_table.prompt_favourite,
99
+ "favourite_item": prompt_detail_table.prompt_fav_item,
100
+ "post": prompt_detail_table.prompt_post,
101
+ "post_translation": prompt_detail_table.prompt_post_translation,
102
+ "product": prompt_detail_table.prompt_product,
103
+ "product_translation": prompt_detail_table.prompt_product_translation,
104
+ "shipment": prompt_detail_table.prompt_shipment,
105
+ "product_variants": prompt_detail_table.prompt_product_variants,
106
+ "review": prompt_detail_table.prompt_review,
107
+ "user_coin": prompt_detail_table.prompt_user_coin,
108
+ "absence": prompt_detail_table.prompt_absence,
109
+ "cart_group": prompt_detail_table.prompt_cart_group,
110
+ "cart_item_group": prompt_detail_table.prompt_cartitem_group,
111
+ "group_orders": prompt_detail_table.prompt_group_orders,
112
+ "payments_group": prompt_detail_table.prompt_payments_group,
113
+ "group_order_members":prompt_detail_table.prompt_group_orders_member,
114
+ "shipment_group":prompt_detail_table.prompt_shipment_group,
115
+ "shipper_attendance":prompt_detail_table.prompt_shipper_attendance,
116
+ "shipper_commission_detail":prompt_detail_table.prompt_shipper_commission_detail,
117
+ "shipper_salary_summary":prompt_detail_table.prompt_shipper_salary_summary,
118
+ "voucher": prompt_detail_table.prompt_voucher
119
+ }
120
+
121
+
122
+ def get_schemas_from_sql(sql_query: str, schema_mapping: dict):
123
+ import sqlglot
124
+ print("SQL query:", sql_query)
125
+ parsed_query = sqlglot.parse_one(sql_query, read="mysql")
126
+
127
+ # Lấy danh sách bảng duy nhất trong query
128
+ table_names = list({t.name for t in parsed_query.find_all(sqlglot.exp.Table)})
129
+
130
+ schemas_used = {}
131
+ for table in table_names:
132
+ if table in schema_mapping:
133
+ schemas_used[table] = schema_mapping[table]
134
+ else:
135
+ print(f"⚠️ Warning: Table '{table}' not found in schema_mapping")
136
+
137
+ # Gom toàn bộ schema thành 1 chuỗi duy nhất
138
+ all_schemas = "\n\n".join(
139
+ [f"Schema for table '{table}':\n{schemas_used[table]}" for table in schemas_used]
140
+ )
141
+
142
+ return all_schemas
143
+
144
+
145
+ def build_sql_fix_prompt(schemas_result: dict, sql: str,user_id: int) -> str:
146
+
147
+ prompt = f"""
148
+ Bạn là một chuyên gia cơ sở dữ liệu.
149
+
150
+ Dưới đây là mô tả schema chi tiết của các bảng có trong hệ thống:
151
+
152
+ {schemas_result}
153
+
154
+ ---
155
+
156
+ Dưới đây là một câu SQL đang bị lỗi do không đúng tên bảng hoặc tên cột:
157
+
158
+ ```sql
159
+ {sql.strip()}
160
+ Yêu cầu của bạn là:
161
+ User Id: {user_id}
162
+ Dựa trên các schema trên, hãy kiểm tra chỉnh sửa câu SQL sao cho:
163
+ Tên bảng, tên cột phải chính xác theo schema.
164
+ Logic mục đích của truy vấn được giữ nguyên.
165
+ Chỉ trả lại phần SQL đã được chỉnh sửa (không giải thích, không chú thích, không thêm nhận xét).
166
+ Loại bỏ các câu comment chỉ trả lại câu SQL chính xác sau khi bạn đã chỉnh sửa.
167
+ Vui lòng chỉnh sửa một cách chính xác nhất.
168
+ Trả lời dưới dạng một truy vấn SQL duy nhất.
169
+ ** Tham khảo thêm các lỗi sau để tránh:
170
+ - "- Tránh các lỗi như :\n"
171
+ " (1054, \"Unknown column 'oi.pro_id' in 'field list'\")\n . Luôn đảm bảo bạn không bao giờ bị lỗi này"
172
+ " (1054, \"Unknown column 'oi.note' in 'field list'\") . Luôn đảm bảo bạn không bao giờ bị lỗi này\n"
173
+ " (1054, \"Unknown column 'oi.size' in 'field list'\") . Luôn đảm bảo bạn không bao giờ bị lỗi này \n"
174
+ " (1054, \"Unknown column 'c.is_deleted' in 'on clause'\"). Luôn đảm bảo bạn không bao giờ bị lỗi này\n"
175
+ """
176
+ return prompt
177
+
178
+
179
+
180
+
181
+ async def execute_query_user(user_input: str, user_id: int, languages: str, role: str):
182
+ PROMPT_CUSTOM = await prompt_cus.get_prompt_custom(user_input)
183
+ check_insert = is_insert_related_to_product_category_variant(user_input)
184
+
185
+ db_config = {
186
+ "host": os.getenv("DB_HOST"),
187
+ "user": os.getenv("DB_USER"),
188
+ "database": os.getenv("DB_NAME"),
189
+ "password": os.getenv("DB_PASSWORD"),
190
+ "port": int(os.getenv("DB_PORT", 3306)),
191
+ "charset": "utf8mb4",
192
+ "cursorclass": pymysql.cursors.DictCursor,
193
+ }
194
+
195
+
196
+ def regenerate_sql_until_safe():
197
+ max_retry = 5
198
+ retry_count = 0
199
+
200
+ while retry_count < max_retry:
201
+ try:
202
+ regenerated_data = db_chain.run(f"""
203
+ Role: {text_role}
204
+ Language: {languages}
205
+ Question: {user_input}.
206
+ """)
207
+ regenerated_sql = extract_sql_from_response(regenerated_data)
208
+ if regenerated_sql:
209
+ regenerated_sql = clean_sql(regenerated_sql)
210
+ if not re.search(r"%{1,2}s", regenerated_sql): # đã sạch
211
+ return regenerated_sql
212
+ retry_count += 1
213
+ except Exception as e:
214
+ return f"❌ Lỗi khi tạo lại truy vấn lần {retry_count + 1}: {str(e)}"
215
+
216
+ return "❌ Lỗi: Không thể tạo được truy vấn an toàn sau nhiều lần thử."
217
+
218
+
219
+ def execute_query_with_pymysql(query, multi=False):
220
+ connection = pymysql.connect(**db_config)
221
+ try:
222
+ with connection.cursor() as cursor:
223
+ results = []
224
+ if multi:
225
+ # Dùng sqlparse để chia câu truy vấn cho an toàn
226
+ statements = sqlparse.split(query)
227
+ for stmt in statements:
228
+ stmt = stmt.strip()
229
+ if stmt:
230
+ cursor.execute(stmt)
231
+ results.append(cursor.fetchall())
232
+ else:
233
+ cursor.execute(query)
234
+ results = cursor.fetchall()
235
+ connection.commit()
236
+ return results
237
+ except pymysql.MySQLError as e:
238
+ return str(e)
239
+ finally:
240
+ connection.close()
241
+
242
+ def clean_sql(sql) -> str:
243
+ print("Clean sql1:", sql)
244
+ if isinstance(sql, dict) and sql:
245
+ first_value = next(iter(sql.values()))
246
+ sql = first_value
247
+ sql = re.sub(r"```sql", "", sql, flags=re.IGNORECASE)
248
+ sql = sql.replace("```sql", "")
249
+ sql = re.sub(r"```", "", sql)
250
+ return sql.strip()
251
+ # sql = re.sub(r"```sql", "", sql, flags=re.IGNORECASE)
252
+ # sql = sql.replace("```sql", "")
253
+ # # sql = re.sub(r'%%s', r'%s', sql)
254
+ # sql = re.sub(r"```", "", sql)
255
+ # return sql.strip()
256
+
257
+ def extract_sql_from_response(data):
258
+ match = re.search(r"\[SQL:\s*```sql\s*(.*?)\s*```]", data, re.DOTALL)
259
+ if match:
260
+ return clean_sql(match.group(1))
261
+ match = re.search(r"SQLQuery:\s*(.*)", data, re.DOTALL)
262
+ if match:
263
+ return clean_sql(match.group(1))
264
+
265
+ return None
266
+
267
+ def extract_sql_from_error(error_msg):
268
+ # Case 1: [SQL: ```sql ... ```]
269
+ match = re.search(r"\[SQL:\s*```sql\s*(.*?)\s*```]", error_msg, re.DOTALL)
270
+ if match:
271
+ return clean_sql(match.group(1))
272
+
273
+
274
+ match = re.search(r"```(?:sql)?\s*\r?\n(.*?)```", error_msg, re.DOTALL)
275
+ if match:
276
+ return clean_sql(match.group(1))
277
+
278
+ return None
279
+
280
+ def process_and_execute_sql(sql):
281
+ print("NHận sql: ",sql )
282
+ data = sql
283
+ if isinstance(data, dict) and data:
284
+ print("Là dict")
285
+ first_key, first_value = next(iter(data.items()))
286
+ sql = first_value
287
+ elif isinstance(data, list) and data:
288
+ sql = "\n\n".join(data)
289
+
290
+ sql_clean = clean_sql(sql)
291
+ print("SQL Clean: ", sql_clean)
292
+ if re.search(r"%{1,2}s", sql_clean):
293
+ regenerated_sql = regenerate_sql_until_safe()
294
+ sql_clean = clean_sql(regenerated_sql)
295
+ result = get_schemas_from_sql(sql_clean, schema_mapping)
296
+ prompt = build_sql_fix_prompt(schemas_result=result,sql = str(sql_clean),user_id =user_id)
297
+ from advance_shopping.call_gemini import tool_call
298
+ data = tool_call.generate(prompt = prompt)
299
+ sql_clean = clean_sql(data)
300
+ print("SQL step2: ", sql_clean)
301
+ if contains_delete(sql_clean):
302
+ return "Lỗi: Bạn không dược phép thực hiện truy vấn DELETE trong hệ thống này."
303
+
304
+ if re.search(r"\bIF\b.*\bTHEN\b", sql_clean, re.IGNORECASE):
305
+ return "❌ Lỗi: Không được dùng IF...THEN trong SQL. Vui lòng chia nhỏ truy vấn."
306
+
307
+ if check_insert:
308
+ check_syntax = filter_syntax_sql(sql_clean, PROMPT_CUSTOM, user_input)
309
+ if check_syntax is True:
310
+ # Tách từng câu và thực thi tuần tự
311
+ try:
312
+ connection = pymysql.connect(**db_config)
313
+ with connection.cursor() as cursor:
314
+ statements = sqlparse.split(sql_clean)
315
+ results = []
316
+ for stmt in statements:
317
+ stmt = stmt.strip()
318
+ if stmt:
319
+ cursor.execute(stmt)
320
+ try:
321
+ results.append(cursor.fetchall())
322
+ except:
323
+ results.append("✅ OK") # thể là câu SET hoặc INSERT
324
+ connection.commit()
325
+ return results
326
+ except Exception as e:
327
+ return f"❌ Lỗi khi thực thi từng truy vấn: {str(e)}"
328
+ finally:
329
+ connection.close()
330
+ else:
331
+ try:
332
+ regenerated_data = db_chain.run(f"""
333
+ Role: {text_role}
334
+ Language: {languages}
335
+ Question: {user_input}.
336
+ """)
337
+ regenerated_sql = extract_sql_from_response(regenerated_data)
338
+ if regenerated_sql:
339
+ return process_and_execute_sql(regenerated_sql)
340
+ else:
341
+ return "❌ Lỗi: Không thể tạo lại truy vấn hợp lệ."
342
+ except Exception as regen_error:
343
+ return f"❌ Lỗi khi tạo lại truy vấn: {str(regen_error)}"
344
+ else:
345
+ return execute_query_with_pymysql(sql_clean, multi=True)
346
+
347
+
348
+
349
+ PROMPT_CUSTOM.template += (
350
+ "\n\n⚠️ Lưu ý: Do NOT use IF...THEN...ELSE...END in SQL. "
351
+ "Only use plain SELECT, INSERT, UPDATE, JOIN, SET statements. "
352
+ "Tuyệt đối cấm dùng %%s hoặc %s để lấy giá trị từ biến. "
353
+ "SQL phải chứa giá trị cố định, không dùng placeholder hay biến bind."
354
+ )
355
+
356
+ db_chain = SQLDatabaseChain.from_llm(llm=llm1, db=db, prompt=PROMPT_CUSTOM)
357
+ text_role = f"{role} (userId = {user_id})" if role == "ADMIN" else f"{role} (userId = {user_id}), not role ADMIN"
358
+
359
+ try:
360
+ data = db_chain.run(f"""
361
+ Role: {text_role}
362
+ Language: {languages}
363
+ Question: {user_input}.
364
+ """)
365
+ extracted_sql = extract_sql_from_response(data)
366
+ if extracted_sql:
367
+ return process_and_execute_sql(extracted_sql)
368
+ else:
369
+ return data
370
+
371
+ except Exception as e:
372
+ error_message = str(e)
373
+ extracted_sql = extract_sql_from_error(error_message)
374
+
375
+ if extracted_sql:
376
+ fix_sql = re.sub(r"```sql", "", extracted_sql)
377
+ fix_sql = extracted_sql.replace("```", "")
378
+ fix_sql = extracted_sql.replace("```sql", "")
379
+ # fix_sql = re.sub(r'%%s', r'%s', fix_sql)
380
+ if contains_delete(fix_sql):
381
+ return "Lỗi: Bạn không dược phép thực hiện truy vấn DELETE trong hệ thống này."
382
+ return process_and_execute_sql(fix_sql)
383
+ else:
384
+ # Không trích xuất được SQL -> Thử tạo lại truy vấn từ đầu
385
+ try:
386
+ regenerated_data = db_chain.run(f"""
387
+ Role: {text_role}
388
+ Language: {languages}
389
+ Question: {user_input}.
390
+ """)
391
+ regenerated_sql = extract_sql_from_response(regenerated_data)
392
+ if regenerated_sql:
393
+ return process_and_execute_sql(regenerated_sql)
394
+ else:
395
+ return "❌ Lỗi: Không thể tạo lại truy vấn hợp lệ sau lỗi."
396
+ except Exception as regen_error:
397
+ return f"❌ Lỗi khi tạo lại truy vấn từ lỗi: {str(regen_error)}"