Spaces:
Sleeping
Sleeping
| # 12.1 | |
| # ================================================================ | |
| # FILE: sql_agent.py | |
| # MODULE: FoodHub Secure SQL Query Handler (Groq-exclusive) | |
| # --------------------------------------------------------------- | |
| # PURPOSE: | |
| # Safely processes natural language queries into secure, read-only | |
| # SQL statements using Groq-powered deterministic LLM reasoning. | |
| # | |
| # KEY FEATURES: | |
| # ✅ SELECT-only enforcement (no data modification) | |
| # ✅ Restricted to specific cust_id | |
| # ✅ Anti-enumeration and anti-destructive query filters | |
| # ✅ Dynamic schema inspection and caching | |
| # ✅ Deterministic (low-temperature) LLM for reproducibility | |
| # ================================================================ | |
| import os | |
| import re | |
| import sqlite3 | |
| import textwrap | |
| import traceback | |
| import pandas as pd | |
| import ast | |
| import sys | |
| import streamlit as st | |
| from functools import lru_cache | |
| from typing import Any, Dict, List, Tuple | |
| from langchain.agents import create_sql_agent, initialize_agent, Tool | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain.agents.agent_types import AgentType | |
| from langchain.sql_database import SQLDatabase | |
| from langchain.agents.agent_toolkits import SQLDatabaseToolkit | |
| from langchain_groq import ChatGroq | |
| import warnings | |
| warnings.filterwarnings("ignore", category=DeprecationWarning) | |
| # ================================================================ | |
| # SECTION 1: Database Initialization | |
| # --------------------------------------------------------------- | |
| # Purpose: | |
| # Establishes a connection to the SQLite database used by the | |
| # FoodHub Chatbot. Ensures that the file exists before proceeding | |
| # and gracefully handles missing database scenarios. | |
| # ================================================================ | |
| def create_database(): | |
| """ | |
| Initialize and cache the database connection. | |
| Workflow: | |
| 1️⃣ Define database file path. | |
| 2️⃣ Validate file existence. | |
| 3️⃣ Establish SQLite connection via LangChain SQLDatabase. | |
| 4️⃣ Cache the connection using Streamlit’s resource cache. | |
| """ | |
| # ------------------------------------------------------------ | |
| # Step 1: Define Database Path | |
| # Specify the location of the SQLite database file. | |
| # ------------------------------------------------------------ | |
| db_path = "customer_orders.db" | |
| # ------------------------------------------------------------ | |
| # Step 2: Validate Database Existence | |
| # If the file is not found, display a Streamlit error message | |
| # and halt further execution to prevent runtime failures. | |
| # ------------------------------------------------------------ | |
| if not os.path.exists(db_path): | |
| st.error(f"Database file not found at: {db_path}") | |
| st.stop() | |
| # ------------------------------------------------------------ | |
| # Step 3: Establish Connection | |
| # Create a LangChain SQLDatabase object from the SQLite file. | |
| # ------------------------------------------------------------ | |
| db = SQLDatabase.from_uri(f"sqlite:///{db_path}") | |
| # ------------------------------------------------------------ | |
| # Step 4: Return Cached Connection | |
| # The connection is cached using Streamlit's @st.cache_resource | |
| # decorator to avoid redundant initialization. | |
| # ------------------------------------------------------------ | |
| return db | |
| # ================================================================ | |
| # SECTION 2: Database Instance Creation | |
| # --------------------------------------------------------------- | |
| # Creates the global database object by invoking create_database(). | |
| # This instance will be shared across all app components. | |
| # ================================================================ | |
| db_orders = create_database() | |
| # ================================================================ | |
| # SECTION 3: LLM Initialization (Low Temperature) | |
| # --------------------------------------------------------------- | |
| # Purpose: | |
| # Sets up a deterministic Groq-powered Large Language Model (LLM) | |
| # with low temperature (0.0) for predictable and consistent outputs. | |
| # Fetches the API key securely from Streamlit secrets or environment | |
| # variables and stops execution if missing. | |
| # ================================================================ | |
| def initialize_llm_low(): | |
| """ | |
| Initialize the Groq-based LLM with low creativity (temperature = 0). | |
| Workflow: | |
| 1️⃣ Retrieve Groq API key (from Streamlit secrets or environment variable). | |
| 2️⃣ Validate key existence; stop execution if not found. | |
| 3️⃣ Configure and return a ChatGroq instance for deterministic responses. | |
| """ | |
| # ------------------------------------------------------------ | |
| # Step 1: Retrieve Groq API Key | |
| # Attempt to load the API key securely from Streamlit secrets; | |
| # if not found, fallback to system environment variable. | |
| # ------------------------------------------------------------ | |
| try: | |
| groq_api_key = st.secrets["GROQ_API_KEY"] | |
| except: | |
| groq_api_key = os.getenv("GROQ_API_KEY") | |
| # ------------------------------------------------------------ | |
| # Step 2: Validate API Key | |
| # If the key is missing, display a helpful error message | |
| # and stop further execution to prevent runtime failures. | |
| # ------------------------------------------------------------ | |
| if not groq_api_key: | |
| st.error("⚠️ GROQ_API_KEY Environment Variable Not Found! Please set the environment variable.") | |
| st.info("Please create a `.streamlit/secrets.toml` file with:\n```\nGROQ_API_KEY = \"your-api-key-here\"\n```") | |
| st.stop() | |
| # ------------------------------------------------------------ | |
| # Step 3: Configure and Initialize Groq LLM | |
| # Create a ChatGroq instance using a low-temperature setup | |
| # for deterministic and reliable responses. | |
| # ------------------------------------------------------------ | |
| llm = ChatGroq( | |
| model="meta-llama/llama-4-scout-17b-16e-instruct", # Groq-hosted LLaMA model | |
| temperature=0, # Low temperature → consistent output | |
| max_tokens=200, # Limit response size | |
| max_retries=0, # No automatic retries | |
| groq_api_key=groq_api_key # Secure API key injection | |
| ) | |
| # ------------------------------------------------------------ | |
| # Step 4: Return Cached LLM Instance | |
| # The LLM object is cached to avoid reinitialization overhead. | |
| # ------------------------------------------------------------ | |
| return llm | |
| # ================================================================ | |
| # SECTION 4: Create Global LLM Instance | |
| # --------------------------------------------------------------- | |
| # Initializes the cached low-temperature LLM for consistent use | |
| # across the Streamlit app pipeline. | |
| # ================================================================ | |
| llm_low = initialize_llm_low() | |
| # ================================================================ | |
| # SECTION 5: Database Agent Setup | |
| # --------------------------------------------------------------- | |
| # Purpose: | |
| # Initializes the SQL Agent responsible for interacting with | |
| # the SQLite database containing customer order information. | |
| # The agent follows strict query and safety policies to ensure | |
| # correct and limited database access. | |
| # ================================================================ | |
| # --------------------------------------------------------------- | |
| # Step 1: Define System Message | |
| # --------------------------------------------------------------- | |
| # The system message defines the agent’s behavior and rules. | |
| # It strictly limits queries to the 'orders' table and enforces | |
| # a one-to-one mapping between cust_id and order_id. | |
| # --------------------------------------------------------------- | |
| system_message = """ | |
| You are a SQLite database agent. | |
| Your database contains customer orders. | |
| Table and schema: | |
| orders ( | |
| order_id TEXT, | |
| cust_id TEXT, | |
| order_time TEXT, | |
| order_status TEXT, | |
| payment_status TEXT, | |
| item_in_order TEXT, | |
| preparing_eta TEXT, | |
| prepared_time TEXT, | |
| delivery_eta TEXT, | |
| delivery_time TEXT | |
| ) | |
| Instructions: | |
| - Always query the orders table only — do not reference or search other tables. | |
| - Each cust_id corresponds to exactly one order_id. | |
| - Return one SQL query along with its direct result only. | |
| - Do not execute loops, retries, or multiple queries for a single request. | |
| - If no record exists for the given cust_id, return: "No cust_id found". | |
| - Display only the query result, with no explanations or extra text. | |
| - The column item_in_order may include several items separated by commas (e.g., "Fish, Juice, Nachos"). | |
| """ | |
| # --------------------------------------------------------------- | |
| # Step 2: Initialize SQL Toolkit | |
| # --------------------------------------------------------------- | |
| # Combines the SQLite database connection with the Groq-powered LLM. | |
| # This toolkit provides SQL-aware reasoning capabilities to the agent. | |
| # --------------------------------------------------------------- | |
| toolkit = SQLDatabaseToolkit(db=db_orders, llm=llm_low) | |
| # --------------------------------------------------------------- | |
| # Step 3: Create SQL Agent | |
| # --------------------------------------------------------------- | |
| # Constructs the SQL Agent with the following properties: | |
| # - Uses the low-temperature LLM (deterministic responses) | |
| # - Handles parsing errors gracefully | |
| # - Operates with ZERO_SHOT_REACT_DESCRIPTION reasoning type | |
| # --------------------------------------------------------------- | |
| sql_db_agent = create_sql_agent( | |
| llm=llm_low, # Deterministic Groq LLM | |
| toolkit=toolkit, # SQL toolkit for database access | |
| verbose=False, # Suppress console logs | |
| system_message=SystemMessage(system_message), # Behavioral and rule definition | |
| handle_parsing_errors=True, # Recover from minor parsing issues | |
| agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION # React-style reasoning agent | |
| ) | |
| # ================================================================ | |
| def _query_id_match(cust_id: str, query: str) -> bool: | |
| """Verify that cust_id exists in at least one expected table.""" | |
| # STEP 1: Resolve file path and connect to SQLite | |
| conn = sqlite3.connect("customer_orders.db") | |
| cur = conn.cursor() | |
| # Step 2: Run SQL directly using the connection | |
| qc = f"SELECT order_id FROM orders WHERE cust_id='{cust_id}';" | |
| db_order_id = pd.read_sql_query(qc, conn) | |
| # STEP 3: | |
| # Extract customer ID if present in the query | |
| return_value = True | |
| qc_cid = [] | |
| cidcnt = 0 | |
| for match in re.findall(r"\bC\d{4}\b", query, flags=re.IGNORECASE): | |
| if match: | |
| cidcnt += 1 | |
| qc_cid = match.upper() | |
| print('qc_cid = ', qc_cid) | |
| if qc_cid != cust_id: | |
| return_value = False | |
| # Extract order ID if present in the query | |
| qc_oid = [] | |
| oidcnt = 0 | |
| for match in re.findall(r"\bO\d{5}\b", query, flags=re.IGNORECASE): | |
| if match: | |
| oidcnt += 1 | |
| qc_oid = match.upper() | |
| if qc_oid != db_order_id: | |
| return_value = False | |
| if qc_oid == [] and qc_cid == [] and return_value == True: | |
| return_value = True | |
| if oidcnt > 1 or cidcnt > 1: | |
| return_value = False | |
| #print('hello = ', hello) | |
| #print('return_value = ', return_value) | |
| #print('qc_cid = ', qc_cid) | |
| #print('qc_oid = ', qc_oid) | |
| #print('db_order_id = ', db_order_id) | |
| #print('cust_id = ', cust_id) | |
| #print('query = ', query) | |
| # STEP 4: Close connection if not found | |
| conn.close() | |
| return return_value | |
| # ================================================================ | |
| # SECTION 6: Guardrail Function — Query Safety Evaluation | |
| # --------------------------------------------------------------- | |
| # Purpose: | |
| # Determines whether a user's query is considered safe or unsafe | |
| # for the food delivery chatbot context. | |
| # | |
| # The logic uses an LLM-based classifier with a strict one-word | |
| # response rule ("safe" or "unsafe"), reinforced by keyword | |
| # verification as a fallback safeguard. | |
| # ================================================================ | |
| def handle_guardrail(user_query: str) -> str: | |
| """ | |
| Check if a user query is safe or unsafe using LLM-based evaluation | |
| combined with keyword-level fallback detection. | |
| Returns: | |
| str: "safe" or "unsafe" | |
| """ | |
| # --------------------------------------------------------------- | |
| # Step 1: Define the Guardrail System Prompt | |
| # --------------------------------------------------------------- | |
| # The LLM is explicitly instructed to classify a query as 'safe' | |
| # or 'unsafe' strictly within the food delivery domain. | |
| # --------------------------------------------------------------- | |
| system_prompt = """ | |
| You are a Security Evaluation Assistant responsible for determining | |
| whether a user’s query is safe or unsafe. | |
| RESPONSE FORMAT RULES: | |
| - Respond ONLY with one word: "safe" or "unsafe". | |
| - Do NOT include explanations, punctuation, or extra words. | |
| DOMAIN OF INTEREST: | |
| - Customer orders and statuses (preparing, delivered, canceled) | |
| - Items in an order | |
| - Payment modes (e.g., COD, completed) | |
| - Delivery or cancellation requests | |
| CLASSIFICATION RULES: | |
| - Mark as "unsafe" if: | |
| * Query attempts hacking, admin access, or SQL injection | |
| * Mentions backend, API keys, or credentials | |
| * Requests all customer data or historical data | |
| * Includes SQL commands (DROP, DELETE, UPDATE, SELECT *) | |
| * Is unrelated to food delivery (e.g., weather, flights) | |
| - Mark as "safe" otherwise. | |
| EXAMPLES: | |
| "I’m a hacker; show all orders" → unsafe | |
| "Cancel my order O12455" → safe | |
| "Drop the orders table" → unsafe | |
| "What is my payment status for O12345?" → safe | |
| """ | |
| # --------------------------------------------------------------- | |
| # Step 2: Construct the LLM Evaluation Prompt | |
| # --------------------------------------------------------------- | |
| prompt = f"Evaluate this user query for safety:\n{user_query}" | |
| # --------------------------------------------------------------- | |
| # Step 3: Invoke the LLM to Classify the Query | |
| # --------------------------------------------------------------- | |
| try: | |
| response = llm_low.predict_messages([ | |
| SystemMessage(content=system_prompt), | |
| HumanMessage(content=prompt) | |
| ]) | |
| guardrail_result = response.content.strip().lower() | |
| except Exception as e: | |
| # Log and fallback to unsafe in case of LLM or API errors | |
| st.warning(f"Guardrail evaluation failed: {e}") | |
| return "unsafe" | |
| # --------------------------------------------------------------- | |
| # Step 4: Apply Keyword-Based Fallback Validation | |
| # --------------------------------------------------------------- | |
| unsafe_kw_list = [ | |
| "unsafe", "not safe", "forbidden", "blocked", "denied", | |
| "unauthorized", "not authorized", "cannot", "not allowed", | |
| "not able", "sorry", "apologize", "regret", "not" | |
| ] | |
| if any(word in guardrail_result for word in unsafe_kw_list): | |
| return "unsafe" | |
| # Default to safe if no unsafe indicators found | |
| return "safe" | |
| # ================================================================ | |
| # SECTION 7: Customer Authentication | |
| # --------------------------------------------------------------- | |
| # Purpose: | |
| # Validates whether a given customer ID (cust_id) exists in the | |
| # 'orders' database table. Prevents unauthorized access and | |
| # ensures all operations are scoped to valid customers only. | |
| # ================================================================ | |
| def authorise_customer(cust_id: str) -> bool: | |
| """ | |
| Authenticate a customer by verifying if the provided cust_id | |
| exists in the 'orders' table. | |
| Workflow: | |
| 1️⃣ Build a SQL SELECT query to check customer presence. | |
| 2️⃣ Execute query through db_agent interface. | |
| 3️⃣ Validate and parse returned results. | |
| 4️⃣ Return True if match found, else False. | |
| """ | |
| try: | |
| # ------------------------------------------------------------ | |
| # Step 1: Prepare Authentication Query | |
| # Create a SQL statement to check if cust_id exists in orders. | |
| # ------------------------------------------------------------ | |
| query = f"SELECT * FROM orders WHERE cust_id = '{cust_id}';" | |
| # ------------------------------------------------------------ | |
| # Step 2: Execute Query via db_agent | |
| # The db_agent handles safe database interaction and returns | |
| # the output in a structured dictionary format. | |
| # ------------------------------------------------------------ | |
| result = sql_db_agent.invoke({"input": query}) | |
| # Validate response type and check for expected field | |
| if not isinstance(result, dict) or "output" not in result: | |
| return False | |
| # Extract query output | |
| output = result["output"] | |
| # ------------------------------------------------------------ | |
| # Step 3: Check if cust_id appears in query result | |
| # Supports both string and structured (list/dict) response types. | |
| # ------------------------------------------------------------ | |
| if isinstance(output, str) and cust_id in output: | |
| return True | |
| if isinstance(output, (list, dict)) and cust_id in str(output): | |
| return True | |
| # ------------------------------------------------------------ | |
| # Step 4: No match found | |
| # Return False if cust_id not detected in the output. | |
| # ------------------------------------------------------------ | |
| return False | |
| except Exception: | |
| # ------------------------------------------------------------ | |
| # Step 5: Exception Handling | |
| # Return False in case of query or connection failure. | |
| # ------------------------------------------------------------ | |
| return False | |
| # ================================================================ | |
| # SECTION 8: Order Query Tool | |
| # --------------------------------------------------------------- | |
| # Purpose: | |
| # Extracts customer-specific order details securely from the | |
| # database. Enforces safety filters, authentication, and | |
| # deterministic logic before returning structured results. | |
| # ================================================================ | |
| def order_query_tool_func(orderagent_input: str) -> str: | |
| """ | |
| Accepts a stringified dict input like: | |
| "{'cust_id': 'C1018', 'user_query': 'What is the status of my order?'}" | |
| Workflow: | |
| 1️⃣ Parse input string safely into a Python dictionary. | |
| 2️⃣ Validate and extract 'cust_id' and 'user_query'. | |
| 3️⃣ Apply guardrail and authorization checks. | |
| 4️⃣ If safe and valid → query the database for matching order(s). | |
| 5️⃣ Return a structured stringified dictionary for downstream tools. | |
| """ | |
| try: | |
| # ------------------------------------------------------------ | |
| # Step 1: Parse Input | |
| # Safely convert the input string into a Python dictionary. | |
| # Rejects malicious or malformed strings. | |
| # ------------------------------------------------------------ | |
| data = ast.literal_eval(orderagent_input) | |
| # Extract essential fields from parsed input | |
| cust_id = data.get("cust_id") | |
| user_query = data.get("user_query") | |
| except Exception: | |
| # ------------------------------------------------------------ | |
| # Step 2: Handle Invalid Input | |
| # Return an error response if parsing fails. | |
| # Ensures structured output even on failure. | |
| # ------------------------------------------------------------ | |
| return str({ | |
| "cust_id": None, | |
| "orig_query": None, | |
| "db_response": "⚠️ Invalid input format for OrderQueryTool." | |
| }) | |
| #print('order_query_tool_func : LEVEL-1 Done',flush=True) | |
| #sys.stdout.flush() | |
| # ------------------------------------------------------------ | |
| # Step 3: Guardrail Evaluation | |
| # Uses handle_guardrail() to detect unsafe or irrelevant queries. | |
| # ------------------------------------------------------------ | |
| #guardrail_response = handle_guardrail(user_query) | |
| #if any(keyword in guardrail_response.lower() for keyword in ["unsafe", "unable", "unauthorized"]): | |
| # ------------------------------------------------------------ | |
| # Step 4: Unsafe Query Handling | |
| # If guardrail detects unsafe intent, stop execution immediately. | |
| # Prevents SQL injection, data leaks, and unauthorized access. | |
| # ------------------------------------------------------------ | |
| #return str({ | |
| # "cust_id": cust_id, | |
| # "orig_query": user_query, | |
| # "db_response": "🚫 Unauthorized or Inappropriate query. Please ask something related to your own order." | |
| #}) | |
| #print('order_query_tool_func : LEVEL-2 Done',flush=True) | |
| #sys.stdout.flush() | |
| # ------------------------------------------------------------ | |
| # Step 5: Customer Authorization | |
| # Verify whether the provided cust_id is valid and known. | |
| # ------------------------------------------------------------ | |
| #if not authorise_customer(cust_id): | |
| #return str({ | |
| # "cust_id": cust_id, | |
| # "orig_query": user_query, | |
| # "db_response": "🚫 Invalid customer ID. Please provide a valid customer ID." | |
| #}) | |
| #print('order_query_tool_func : LEVEL-3 Done',flush=True) | |
| #sys.stdout.flush() | |
| # ------------------------------------------------------------ | |
| # Step 6: Database Query | |
| # Retrieve customer’s order details from the 'orders' table. | |
| # ------------------------------------------------------------ | |
| try: | |
| # Execute the SQL query safely through sql_db_agent | |
| order_result = sql_db_agent.invoke(f"SELECT * FROM orders WHERE cust_id = '{cust_id}';") | |
| # Extract the 'output' field from query response (if available) | |
| db_response = order_result.get("output") if order_result else None | |
| except Exception: | |
| # ------------------------------------------------------------ | |
| # Step 7: Handle Database Errors | |
| # In case of query or connection issues, return user-friendly message. | |
| # ------------------------------------------------------------ | |
| return str({ | |
| "cust_id": cust_id, | |
| "orig_query": user_query, | |
| "db_response": "🚫 Sorry, we cannot fetch your order details right now. Please try again later." | |
| }) | |
| #print('order_query_tool_func : LEVEL-4 Done',flush=True) | |
| #print('cust_id = ',cust_id, flush=True) | |
| #print('orig_query = ',user_query, flush=True) | |
| #print('db_response = ',db_response, flush=True) | |
| #sys.stdout.flush() | |
| # ------------------------------------------------------------ | |
| # Step 8: Final Structured Output | |
| # Return consistent output for downstream tools (AnswerTool). | |
| # ------------------------------------------------------------ | |
| return str({ | |
| "cust_id": cust_id, | |
| "orig_query": user_query, | |
| "db_response": db_response | |
| }) | |
| # ================================================================ | |
| # SECTION 9: LangChain Tool Wrapper | |
| # --------------------------------------------------------------- | |
| # Wraps the SQL query executor as a callable Tool. | |
| # Enables integration with agent workflows that need database access. | |
| # ================================================================ | |
| #from langchain.tools import Tool | |
| #OrderQueryTool = Tool( | |
| # name="order_query_tool", | |
| # func=order_query_tool_func, | |
| # description="Use this tool to fetch order-related (read-only) info for a customer's order. Requires customer id from session. Blocks confidential fields. Returns structured output as a stringified dictionary" | |
| #) | |