Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| # Thêm thư mục gốc (chứa thư mục "function") vào sys.path | |
| sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) | |
| import os | |
| from dotenv import load_dotenv | |
| # Load biến môi trường từ file .env | |
| load_dotenv() | |
| DB_HOST = os.getenv("DB_HOST") | |
| DB_USER = os.getenv("DB_USER") | |
| DB_PASSWORD = os.getenv("DB_PASSWORD") | |
| DB_NAME = os.getenv("DB_NAME") | |
| DB_PORT = os.getenv("DB_PORT") | |
| import os | |
| from urllib.parse import quote | |
| password = os.getenv("DB_PASSWORD") # VD: 'Yahana0509@' | |
| DB_PASSWORD = quote(password) | |
| # Tạo connection string | |
| connection_uri = ( | |
| f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" | |
| "?ssl_verify_cert=false&ssl_verify_identity=false" | |
| ) | |
| from langchain_community.utilities.sql_database import SQLDatabase | |
| from langchain_experimental.sql import SQLDatabaseChain | |
| import sys | |
| import os | |
| import pymysql | |
| from fastapi import HTTPException | |
| import sys | |
| import os | |
| from fastapi.encoders import jsonable_encoder | |
| import re | |
| from function.prompt import prompt_main as prompt | |
| import function.prompt.prompt_custom as prompt_cus | |
| os.environ["GOOGLE_API_KEY"] = "AIzaSyCO-RlqYewC4e9BEPoC8m-AxHUY7J3_o2E" | |
| from bson import ObjectId | |
| db = SQLDatabase.from_uri(connection_uri) | |
| from dotenv import load_dotenv | |
| import filter.filter_role as filter_role_1 | |
| import filter.filter_sql_injection as filter_sql_injection_1 | |
| import filter.result as query_result_1 | |
| import response.ResponseChat as res_chat | |
| from datetime import datetime | |
| import pytz | |
| from mongoengine import connect | |
| import os | |
| import nltk | |
| import function.agent.pipeline_agent as pipeline_agent | |
| nltk.download('punkt') | |
| from models.Database_Entity import User, ChatHistory, DetailChat | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| MONGO_URI = os.getenv("MONGO_URI", "") | |
| connect("chatbot_hmdrinks", host=MONGO_URI) | |
| import re | |
| def contains_delete(sql: str) -> bool: | |
| return bool(re.search(r'\bdelete\b', sql, re.IGNORECASE)) | |
| load_dotenv() | |
| #setup model | |
| from bson import ObjectId | |
| import random | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI | |
| llm1 = ChatGoogleGenerativeAI(model='gemini-2.5-flash-preview-05-20',temperature=0.5) | |
| #setup kết nối database | |
| db = SQLDatabase.from_uri(connection_uri) | |
| db_chain = SQLDatabaseChain.from_llm(llm=llm1,db=db,prompt= prompt.PROMPT) | |
| from prompt.prompt_syntax_insert import is_insert_related_to_product_category_variant, filter_syntax_sql | |
| import sqlparse | |
| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) | |
| from prompt import prompt_detail_table | |
| schema_mapping = { | |
| "payments": prompt_detail_table.prompt_payments, | |
| "user": prompt_detail_table.prompt_users, | |
| "user_voucher": prompt_detail_table.prompt_user_voucher, | |
| "category": prompt_detail_table.prompt_categort, | |
| "category_translation": prompt_detail_table.prompt_category_translation, | |
| "cart": prompt_detail_table.prompt_cart, | |
| "cart_item": prompt_detail_table.prompt_cart_item, | |
| "orders":prompt_detail_table.prompt_orders, | |
| "order_item": prompt_detail_table.prompt_order_item, | |
| "payment": prompt_detail_table.prompt_payments, | |
| "favourite": prompt_detail_table.prompt_favourite, | |
| "favourite_item": prompt_detail_table.prompt_fav_item, | |
| "post": prompt_detail_table.prompt_post, | |
| "post_translation": prompt_detail_table.prompt_post_translation, | |
| "product": prompt_detail_table.prompt_product, | |
| "product_translation": prompt_detail_table.prompt_product_translation, | |
| "shipment": prompt_detail_table.prompt_shipment, | |
| "product_variants": prompt_detail_table.prompt_product_variants, | |
| "review": prompt_detail_table.prompt_review, | |
| "user_coin": prompt_detail_table.prompt_user_coin, | |
| "absence": prompt_detail_table.prompt_absence, | |
| "cart_group": prompt_detail_table.prompt_cart_group, | |
| "cart_item_group": prompt_detail_table.prompt_cartitem_group, | |
| "group_orders": prompt_detail_table.prompt_group_orders, | |
| "payments_group": prompt_detail_table.prompt_payments_group, | |
| "group_order_members":prompt_detail_table.prompt_group_orders_member, | |
| "shipment_group":prompt_detail_table.prompt_shipment_group, | |
| "shipper_attendance":prompt_detail_table.prompt_shipper_attendance, | |
| "shipper_commission_detail":prompt_detail_table.prompt_shipper_commission_detail, | |
| "shipper_salary_summary":prompt_detail_table.prompt_shipper_salary_summary, | |
| "voucher": prompt_detail_table.prompt_voucher | |
| } | |
| def get_schemas_from_sql(sql_query: str, schema_mapping: dict): | |
| import sqlglot | |
| print("SQL query:", sql_query) | |
| parsed_query = sqlglot.parse_one(sql_query, read="mysql") | |
| # Lấy danh sách bảng duy nhất trong query | |
| table_names = list({t.name for t in parsed_query.find_all(sqlglot.exp.Table)}) | |
| schemas_used = {} | |
| for table in table_names: | |
| if table in schema_mapping: | |
| schemas_used[table] = schema_mapping[table] | |
| else: | |
| print(f"⚠️ Warning: Table '{table}' not found in schema_mapping") | |
| # Gom toàn bộ schema thành 1 chuỗi duy nhất | |
| all_schemas = "\n\n".join( | |
| [f"Schema for table '{table}':\n{schemas_used[table]}" for table in schemas_used] | |
| ) | |
| return all_schemas | |
| def build_sql_fix_prompt(schemas_result: dict, sql: str,user_id: int) -> str: | |
| prompt = f""" | |
| Bạn là một chuyên gia cơ sở dữ liệu. | |
| Dưới đây là mô tả schema chi tiết của các bảng có trong hệ thống: | |
| {schemas_result} | |
| --- | |
| Dưới đây là một câu SQL đang bị lỗi do không đúng tên bảng hoặc tên cột: | |
| ```sql | |
| {sql.strip()} | |
| Yêu cầu của bạn là: | |
| User Id: {user_id} | |
| Dựa trên các schema ở trên, hãy kiểm tra và chỉnh sửa câu SQL sao cho: | |
| Tên bảng, tên cột phải chính xác theo schema. | |
| Logic và mục đích của truy vấn được giữ nguyên. | |
| Chỉ trả lại phần SQL đã được chỉnh sửa (không giải thích, không chú thích, không thêm nhận xét). | |
| Loại bỏ các câu comment chỉ trả lại câu SQL chính xác sau khi bạn đã chỉnh sửa. | |
| Vui lòng chỉnh sửa một cách chính xác nhất. | |
| Trả lời dưới dạng một truy vấn SQL duy nhất. | |
| ** Tham khảo thêm các lỗi sau để tránh: | |
| - "- Tránh các lỗi như :\n" | |
| " (1054, \"Unknown column 'oi.pro_id' in 'field list'\")\n . Luôn đảm bảo bạn không bao giờ bị lỗi này" | |
| " (1054, \"Unknown column 'oi.note' in 'field list'\") . Luôn đảm bảo bạn không bao giờ bị lỗi này\n" | |
| " (1054, \"Unknown column 'oi.size' in 'field list'\") . Luôn đảm bảo bạn không bao giờ bị lỗi này \n" | |
| " (1054, \"Unknown column 'c.is_deleted' in 'on clause'\"). Luôn đảm bảo bạn không bao giờ bị lỗi này\n" | |
| """ | |
| return prompt | |
| async def execute_query_user(user_input: str, user_id: int, languages: str, role: str): | |
| PROMPT_CUSTOM = await prompt_cus.get_prompt_custom(user_input) | |
| check_insert = is_insert_related_to_product_category_variant(user_input) | |
| db_config = { | |
| "host": os.getenv("DB_HOST"), | |
| "user": os.getenv("DB_USER"), | |
| "database": os.getenv("DB_NAME"), | |
| "password": os.getenv("DB_PASSWORD"), | |
| "port": int(os.getenv("DB_PORT", 3306)), | |
| "charset": "utf8mb4", | |
| "cursorclass": pymysql.cursors.DictCursor, | |
| } | |
| def regenerate_sql_until_safe(): | |
| max_retry = 5 | |
| retry_count = 0 | |
| while retry_count < max_retry: | |
| try: | |
| regenerated_data = db_chain.run(f""" | |
| Role: {text_role} | |
| Language: {languages} | |
| Question: {user_input}. | |
| """) | |
| regenerated_sql = extract_sql_from_response(regenerated_data) | |
| if regenerated_sql: | |
| regenerated_sql = clean_sql(regenerated_sql) | |
| if not re.search(r"%{1,2}s", regenerated_sql): # đã sạch | |
| return regenerated_sql | |
| retry_count += 1 | |
| except Exception as e: | |
| return f"❌ Lỗi khi tạo lại truy vấn lần {retry_count + 1}: {str(e)}" | |
| return "❌ Lỗi: Không thể tạo được truy vấn an toàn sau nhiều lần thử." | |
| def execute_query_with_pymysql(query, multi=False): | |
| connection = pymysql.connect(**db_config) | |
| try: | |
| with connection.cursor() as cursor: | |
| results = [] | |
| if multi: | |
| # Dùng sqlparse để chia câu truy vấn cho an toàn | |
| statements = sqlparse.split(query) | |
| for stmt in statements: | |
| stmt = stmt.strip() | |
| if stmt: | |
| cursor.execute(stmt) | |
| results.append(cursor.fetchall()) | |
| else: | |
| cursor.execute(query) | |
| results = cursor.fetchall() | |
| connection.commit() | |
| return results | |
| except pymysql.MySQLError as e: | |
| return str(e) | |
| finally: | |
| connection.close() | |
| def clean_sql(sql) -> str: | |
| print("Clean sql1:", sql) | |
| if isinstance(sql, dict) and sql: | |
| first_value = next(iter(sql.values())) | |
| sql = first_value | |
| sql = re.sub(r"```sql", "", sql, flags=re.IGNORECASE) | |
| sql = sql.replace("```sql", "") | |
| sql = re.sub(r"```", "", sql) | |
| return sql.strip() | |
| # sql = re.sub(r"```sql", "", sql, flags=re.IGNORECASE) | |
| # sql = sql.replace("```sql", "") | |
| # # sql = re.sub(r'%%s', r'%s', sql) | |
| # sql = re.sub(r"```", "", sql) | |
| # return sql.strip() | |
| def extract_sql_from_response(data): | |
| match = re.search(r"\[SQL:\s*```sql\s*(.*?)\s*```]", data, re.DOTALL) | |
| if match: | |
| return clean_sql(match.group(1)) | |
| match = re.search(r"SQLQuery:\s*(.*)", data, re.DOTALL) | |
| if match: | |
| return clean_sql(match.group(1)) | |
| return None | |
| def extract_sql_from_error(error_msg): | |
| # Case 1: [SQL: ```sql ... ```] | |
| match = re.search(r"\[SQL:\s*```sql\s*(.*?)\s*```]", error_msg, re.DOTALL) | |
| if match: | |
| return clean_sql(match.group(1)) | |
| match = re.search(r"```(?:sql)?\s*\r?\n(.*?)```", error_msg, re.DOTALL) | |
| if match: | |
| return clean_sql(match.group(1)) | |
| return None | |
| def process_and_execute_sql(sql): | |
| print("NHận sql: ",sql ) | |
| data = sql | |
| if isinstance(data, dict) and data: | |
| print("Là dict") | |
| first_key, first_value = next(iter(data.items())) | |
| sql = first_value | |
| elif isinstance(data, list) and data: | |
| sql = "\n\n".join(data) | |
| sql_clean = clean_sql(sql) | |
| print("SQL Clean: ", sql_clean) | |
| if re.search(r"%{1,2}s", sql_clean): | |
| regenerated_sql = regenerate_sql_until_safe() | |
| sql_clean = clean_sql(regenerated_sql) | |
| result = get_schemas_from_sql(sql_clean, schema_mapping) | |
| prompt = build_sql_fix_prompt(schemas_result=result,sql = str(sql_clean),user_id =user_id) | |
| from advance_shopping.call_gemini import tool_call | |
| data = tool_call.generate(prompt = prompt) | |
| sql_clean = clean_sql(data) | |
| print("SQL step2: ", sql_clean) | |
| if contains_delete(sql_clean): | |
| return "Lỗi: Bạn không dược phép thực hiện truy vấn DELETE trong hệ thống này." | |
| if re.search(r"\bIF\b.*\bTHEN\b", sql_clean, re.IGNORECASE): | |
| return "❌ Lỗi: Không được dùng IF...THEN trong SQL. Vui lòng chia nhỏ truy vấn." | |
| if check_insert: | |
| check_syntax = filter_syntax_sql(sql_clean, PROMPT_CUSTOM, user_input) | |
| if check_syntax is True: | |
| # Tách từng câu và thực thi tuần tự | |
| try: | |
| connection = pymysql.connect(**db_config) | |
| with connection.cursor() as cursor: | |
| statements = sqlparse.split(sql_clean) | |
| results = [] | |
| for stmt in statements: | |
| stmt = stmt.strip() | |
| if stmt: | |
| cursor.execute(stmt) | |
| try: | |
| results.append(cursor.fetchall()) | |
| except: | |
| results.append("✅ OK") # Có thể là câu SET hoặc INSERT | |
| connection.commit() | |
| return results | |
| except Exception as e: | |
| return f"❌ Lỗi khi thực thi từng truy vấn: {str(e)}" | |
| finally: | |
| connection.close() | |
| else: | |
| try: | |
| regenerated_data = db_chain.run(f""" | |
| Role: {text_role} | |
| Language: {languages} | |
| Question: {user_input}. | |
| """) | |
| regenerated_sql = extract_sql_from_response(regenerated_data) | |
| if regenerated_sql: | |
| return process_and_execute_sql(regenerated_sql) | |
| else: | |
| return "❌ Lỗi: Không thể tạo lại truy vấn hợp lệ." | |
| except Exception as regen_error: | |
| return f"❌ Lỗi khi tạo lại truy vấn: {str(regen_error)}" | |
| else: | |
| return execute_query_with_pymysql(sql_clean, multi=True) | |
| PROMPT_CUSTOM.template += ( | |
| "\n\n⚠️ Lưu ý: Do NOT use IF...THEN...ELSE...END in SQL. " | |
| "Only use plain SELECT, INSERT, UPDATE, JOIN, SET statements. " | |
| "Tuyệt đối cấm dùng %%s hoặc %s để lấy giá trị từ biến. " | |
| "SQL phải chứa giá trị cố định, không dùng placeholder hay biến bind." | |
| ) | |
| db_chain = SQLDatabaseChain.from_llm(llm=llm1, db=db, prompt=PROMPT_CUSTOM) | |
| text_role = f"{role} (userId = {user_id})" if role == "ADMIN" else f"{role} (userId = {user_id}), not role ADMIN" | |
| try: | |
| data = db_chain.run(f""" | |
| Role: {text_role} | |
| Language: {languages} | |
| Question: {user_input}. | |
| """) | |
| extracted_sql = extract_sql_from_response(data) | |
| if extracted_sql: | |
| return process_and_execute_sql(extracted_sql) | |
| else: | |
| return data | |
| except Exception as e: | |
| error_message = str(e) | |
| extracted_sql = extract_sql_from_error(error_message) | |
| if extracted_sql: | |
| fix_sql = re.sub(r"```sql", "", extracted_sql) | |
| fix_sql = extracted_sql.replace("```", "") | |
| fix_sql = extracted_sql.replace("```sql", "") | |
| # fix_sql = re.sub(r'%%s', r'%s', fix_sql) | |
| if contains_delete(fix_sql): | |
| return "Lỗi: Bạn không dược phép thực hiện truy vấn DELETE trong hệ thống này." | |
| return process_and_execute_sql(fix_sql) | |
| else: | |
| # Không trích xuất được SQL -> Thử tạo lại truy vấn từ đầu | |
| try: | |
| regenerated_data = db_chain.run(f""" | |
| Role: {text_role} | |
| Language: {languages} | |
| Question: {user_input}. | |
| """) | |
| regenerated_sql = extract_sql_from_response(regenerated_data) | |
| if regenerated_sql: | |
| return process_and_execute_sql(regenerated_sql) | |
| else: | |
| return "❌ Lỗi: Không thể tạo lại truy vấn hợp lệ sau lỗi." | |
| except Exception as regen_error: | |
| return f"❌ Lỗi khi tạo lại truy vấn từ lỗi: {str(regen_error)}" | |