Spaces:
Sleeping
Sleeping
File size: 4,635 Bytes
3b362b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# agent/tools.py
import pandas as pd
import re
from backend.db import get_connection, run_sql
from backend import queries as q
from agent.responses import format_response
# Hardcoded roles
# The only name accepted for the "admin" role; validate_identity compares
# against it case-insensitively.
VALID_ADMIN = "John"
# --------------------------
# Predefined helper functions
# --------------------------
def get_total_orders():
    """Run ``q.TOTAL_ORDERS`` and return the first column of its first row."""
    cursor = get_connection().execute(q.TOTAL_ORDERS)
    first_row = cursor.fetchone()
    return first_row[0]
def get_cancelled_orders():
    """Run ``q.CANCELLED_ORDERS`` and return the first column of its first row."""
    cursor = get_connection().execute(q.CANCELLED_ORDERS)
    first_row = cursor.fetchone()
    return first_row[0]
def get_order_status(order_id, customer_id):
    """Run ``q.ORDER_STATUS`` for one order/customer pair and return its first row.

    The two ids are bound as query parameters in the order
    ``(order_id, customer_id)``. The return value is whatever the cursor's
    ``fetchone()`` yields (presumably ``None`` when nothing matches -- confirm
    against the DB driver in use).
    """
    params = [order_id, customer_id]
    cursor = get_connection().execute(q.ORDER_STATUS, params)
    return cursor.fetchone()
def update_order_status(order_id, customer_id, new_status):
    """Run ``q.UPDATE_STATUS`` for one order and return a confirmation message.

    Parameters are bound in the order ``(new_status, order_id, customer_id)``.
    The confirmation string is returned unconditionally once the statement has
    executed; the number of affected rows is not inspected.
    """
    params = [new_status, order_id, customer_id]
    connection = get_connection()
    connection.execute(q.UPDATE_STATUS, params)
    confirmation = f"✅ Order {order_id} for customer {customer_id} updated to {new_status}."
    return confirmation
# --------------------------
# Identity + validation
# --------------------------
def get_valid_customers():
    """Return the distinct customers found in ``my_db.main.orders``.

    One record is produced per distinct
    ``(customer_id, customer_name, customer_email)`` combination. The result of
    ``run_sql`` (presumably a pandas DataFrame) is converted with
    ``to_dict(orient="records")``, i.e. one dict per row keyed by column name.
    """
    customers_query = """
SELECT customer_id, customer_name, customer_email
FROM my_db.main.orders
GROUP BY customer_id, customer_name, customer_email;
"""
    customers_frame = run_sql(customers_query)
    records = customers_frame.to_dict(orient="records")
    return records
def validate_identity(role: str, name: str, customer_id: str = None, email: str = None) -> bool:
    """Check a claimed identity against the hardcoded admin or the customers in the DB.

    Args:
        role: ``"admin"`` or ``"customer"``; any other value is rejected.
        name: Claimed name. Compared case-insensitively, ignoring surrounding
            whitespace, for both roles.
        customer_id: Claimed customer id (customers only). Compared as a
            string, ignoring surrounding whitespace.
        email: Claimed e-mail address (customers only). Compared
            case-insensitively, ignoring surrounding whitespace.

    Returns:
        True only when every credential required for the role matches.
        False otherwise, including when a required credential is missing,
        blank, or not a string.
    """
    # A missing or blank name can never identify anyone.
    if not isinstance(name, str):
        return False
    claimed_name = name.strip().lower()
    if not claimed_name:
        return False

    if role == "admin":
        return claimed_name == VALID_ADMIN.lower()

    if role != "customer":
        return False

    # Customers must supply both an id and an e-mail. Reject before hitting the
    # DB: a missing e-mail can never match (and calling .lower() on None would
    # raise), and a missing id must not match a NULL id via str(None).
    if customer_id is None or not isinstance(email, str):
        return False
    claimed_id = str(customer_id).strip()
    claimed_email = email.strip().lower()
    if not claimed_id or not claimed_email:
        return False

    def _normalized(value):
        # DB columns may be NULL / non-string; such values never match.
        return value.strip().lower() if isinstance(value, str) else None

    return any(
        _normalized(record["customer_name"]) == claimed_name
        and str(record["customer_id"]).strip() == claimed_id
        and _normalized(record["customer_email"]) == claimed_email
        for record in get_valid_customers()
    )
# --------------------------
# SQL Safety Guardrails
# --------------------------
# Destructive keywords rejected for every role. Matched on word boundaries so
# they are also caught when glued to punctuation (";DROP", "(DELETE").
_FORBIDDEN_SQL_RE = re.compile(r"\b(?:DROP|TRUNCATE|ALTER|DELETE|INSERT)\b")

# Splits a normalized (upper-cased, single-spaced) UPDATE statement into its
# SET assignments and its WHERE predicate.
_ADMIN_UPDATE_RE = re.compile(
    r"^UPDATE\s+.+?\s+SET\s+(?P<assignments>.+?)\s+WHERE\s+(?P<where>.+)$"
)

# Exactly one assignment, to STATUS, of a quoted literal or a bare token.
_STATUS_ONLY_RE = re.compile(r"^STATUS\s*=\s*(?:'[^']*'|\w+)$")


def _is_status_only_update(statement: str) -> bool:
    """Return True if *statement* only sets STATUS and filters by order + customer."""
    parsed = _ADMIN_UPDATE_RE.match(statement)
    if parsed is None:
        # Not an UPDATE ... SET ... WHERE ... shape (e.g. no WHERE clause).
        return False
    if not _STATUS_ONLY_RE.match(parsed.group("assignments")):
        # More than one column, or a column other than STATUS, is being set.
        return False
    where = parsed.group("where")
    has_order_id = re.search(r"\bORDER_ID\b", where) is not None
    has_customer_filter = (
        re.search(r"\b(?:CUSTOMER_ID|CUSTOMER_EMAIL)\b", where) is not None
    )
    return has_order_id and has_customer_filter


def is_safe_sql(sql: str, role: str = "customer") -> bool:
    """Heuristically validate SQL to prevent destructive queries.

    Rules:
        - Only a single statement is allowed; a ``;`` anywhere except at the
          very end is rejected (blocks stacked queries such as
          ``SELECT 1;DROP TABLE x``).
        - DROP / TRUNCATE / ALTER / DELETE / INSERT are rejected for everyone,
          wherever they appear as whole words.
        - Customers: only SELECT.
        - Admin: SELECT, or an UPDATE that sets only ``status`` and whose WHERE
          clause references ``order_id`` and ``customer_id`` or
          ``customer_email``.
        - Any other role is rejected.

    This is a conservative text check, not a SQL parser: a forbidden keyword
    or a semicolon inside a string literal also causes rejection.

    Args:
        sql: The SQL text to check.
        role: ``"customer"`` or ``"admin"``.

    Returns:
        True if the statement is allowed for the role, False otherwise
        (including when *sql* is not a string or is empty).
    """
    if not isinstance(sql, str):
        return False

    sql_clean = " ".join(sql.strip().upper().split())  # normalize spaces/case

    # Debug print
    print("\n--- SAFETY CHECK DEBUG ---")
    print(f"SQL being checked: {sql_clean}")
    print(f"Role: {role}")
    print("---------------------------\n")

    # A trailing terminator is harmless; anything after a ';' is a second
    # statement and is never allowed.
    statement = sql_clean.rstrip("; ")
    if not statement or ";" in statement:
        return False

    # Globally forbidden statements
    if _FORBIDDEN_SQL_RE.search(statement):
        return False

    is_select = re.match(r"SELECT\b", statement) is not None

    if role == "customer":
        # Customers can only run SELECT queries
        return is_select

    if role == "admin":
        if is_select:
            return True
        # Guardrails: only update status on a specific order with tight WHERE
        return _is_status_only_update(statement)

    return False
# --------------------------
# SQL Execution
# --------------------------
def execute_sql(sql: str, role: str = "customer") -> str:
    """Clean, vet and execute an LLM-produced SQL query.

    The query is stripped of markdown code fences (with or without a ``sql``
    language tag) and of a leading ``"sql "`` prefix, checked with
    ``is_safe_sql`` for the given role, executed with ``run_sql`` and rendered
    with ``format_response``.

    Args:
        sql: Raw SQL text, possibly wrapped in markdown fences.
        role: ``"customer"`` or ``"admin"``.

    Returns:
        Always a human-friendly string: the formatted result, an
        "unauthorized or unsafe" notice (also used for empty or non-string
        input), or a "query error" notice carrying the exception text.
    """
    # Non-string input cannot be cleaned or executed; report it instead of
    # raising so the "always returns a string" contract holds.
    if not isinstance(sql, str):
        return f"⚠️ Unauthorized or unsafe query attempted: {sql}"

    # Normalize query: strip markdown fences + "sql " prefix if LLM added them.
    # The opening fence may be a bare ``` or carry a "sql" language tag.
    sql = sql.strip()
    sql = re.sub(r"^```(?:sql\b)?", "", sql, flags=re.IGNORECASE).strip()
    sql = re.sub(r"```$", "", sql).strip()
    if sql.lower().startswith("sql "):
        sql = sql[4:].strip()

    # 🔍 Debugging logs
    print("\n--- SQL DEBUG LOG ---")
    print(f"Cleaned SQL from LLM: {sql}")
    print(f"Role: {role}")
    print("---------------------\n")

    # ✅ Pass cleaned SQL to safety check (nothing left to run counts as unsafe)
    if not sql or not is_safe_sql(sql, role):
        return f"⚠️ Unauthorized or unsafe query attempted: {sql}"

    # Top-level boundary: any execution/formatting failure is reported to the
    # caller as text rather than propagated.
    try:
        df = run_sql(sql)
        print(f"✅ Executed successfully. Rows returned: {len(df)}")
        return format_response(df, role)
    except Exception as e:
        print(f"⚠️ Execution error: {e}")
        return f"⚠️ Query error: {e}"
|