SudeendraMG's picture
Upload 2 files
6a7793f verified
# 12.1
# ================================================================
# FILE: sql_agent.py
# MODULE: FoodHub Secure SQL Query Handler (Groq-exclusive)
# ---------------------------------------------------------------
# PURPOSE:
# Safely processes natural language queries into secure, read-only
# SQL statements using Groq-powered deterministic LLM reasoning.
#
# KEY FEATURES:
# ✅ SELECT-only enforcement (no data modification)
# ✅ Restricted to specific cust_id
# ✅ Anti-enumeration and anti-destructive query filters
# ✅ Dynamic schema inspection and caching
# ✅ Deterministic (low-temperature) LLM for reproducibility
# ================================================================
import os
import re
import sqlite3
import textwrap
import traceback
import pandas as pd
import ast
import sys
import streamlit as st
from functools import lru_cache
from typing import Any, Dict, List, Tuple
from langchain.agents import create_sql_agent, initialize_agent, Tool
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain_groq import ChatGroq
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
# ================================================================
# SECTION 1: Database Initialization
# ---------------------------------------------------------------
# Purpose:
# Establishes a connection to the SQLite database used by the
# FoodHub Chatbot. Ensures that the file exists before proceeding
# and gracefully handles missing database scenarios.
# ================================================================
@st.cache_resource
def create_database():
"""
Initialize and cache the database connection.
Workflow:
1️⃣ Define database file path.
2️⃣ Validate file existence.
3️⃣ Establish SQLite connection via LangChain SQLDatabase.
4️⃣ Cache the connection using Streamlit’s resource cache.
"""
# ------------------------------------------------------------
# Step 1: Define Database Path
# Specify the location of the SQLite database file.
# ------------------------------------------------------------
db_path = "customer_orders.db"
# ------------------------------------------------------------
# Step 2: Validate Database Existence
# If the file is not found, display a Streamlit error message
# and halt further execution to prevent runtime failures.
# ------------------------------------------------------------
if not os.path.exists(db_path):
st.error(f"Database file not found at: {db_path}")
st.stop()
# ------------------------------------------------------------
# Step 3: Establish Connection
# Create a LangChain SQLDatabase object from the SQLite file.
# ------------------------------------------------------------
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
# ------------------------------------------------------------
# Step 4: Return Cached Connection
# The connection is cached using Streamlit's @st.cache_resource
# decorator to avoid redundant initialization.
# ------------------------------------------------------------
return db
# ================================================================
# SECTION 2: Database Instance Creation
# ---------------------------------------------------------------
# Creates the global database object by invoking create_database().
# This instance will be shared across all app components.
# ================================================================
db_orders = create_database()
# ================================================================
# SECTION 3: LLM Initialization (Low Temperature)
# ---------------------------------------------------------------
# Purpose:
# Sets up a deterministic Groq-powered Large Language Model (LLM)
# with low temperature (0.0) for predictable and consistent outputs.
# Fetches the API key securely from Streamlit secrets or environment
# variables and stops execution if missing.
# ================================================================
@st.cache_resource
def initialize_llm_low():
"""
Initialize the Groq-based LLM with low creativity (temperature = 0).
Workflow:
1️⃣ Retrieve Groq API key (from Streamlit secrets or environment variable).
2️⃣ Validate key existence; stop execution if not found.
3️⃣ Configure and return a ChatGroq instance for deterministic responses.
"""
# ------------------------------------------------------------
# Step 1: Retrieve Groq API Key
# Attempt to load the API key securely from Streamlit secrets;
# if not found, fallback to system environment variable.
# ------------------------------------------------------------
try:
groq_api_key = st.secrets["GROQ_API_KEY"]
except:
groq_api_key = os.getenv("GROQ_API_KEY")
# ------------------------------------------------------------
# Step 2: Validate API Key
# If the key is missing, display a helpful error message
# and stop further execution to prevent runtime failures.
# ------------------------------------------------------------
if not groq_api_key:
st.error("⚠️ GROQ_API_KEY Environment Variable Not Found! Please set the environment variable.")
st.info("Please create a `.streamlit/secrets.toml` file with:\n```\nGROQ_API_KEY = \"your-api-key-here\"\n```")
st.stop()
# ------------------------------------------------------------
# Step 3: Configure and Initialize Groq LLM
# Create a ChatGroq instance using a low-temperature setup
# for deterministic and reliable responses.
# ------------------------------------------------------------
llm = ChatGroq(
model="meta-llama/llama-4-scout-17b-16e-instruct", # Groq-hosted LLaMA model
temperature=0, # Low temperature → consistent output
max_tokens=200, # Limit response size
max_retries=0, # No automatic retries
groq_api_key=groq_api_key # Secure API key injection
)
# ------------------------------------------------------------
# Step 4: Return Cached LLM Instance
# The LLM object is cached to avoid reinitialization overhead.
# ------------------------------------------------------------
return llm
# ================================================================
# SECTION 4: Create Global LLM Instance
# ---------------------------------------------------------------
# Initializes the cached low-temperature LLM for consistent use
# across the Streamlit app pipeline.
# ================================================================
llm_low = initialize_llm_low()
# ================================================================
# SECTION 5: Database Agent Setup
# ---------------------------------------------------------------
# Purpose:
# Initializes the SQL Agent responsible for interacting with
# the SQLite database containing customer order information.
# The agent follows strict query and safety policies to ensure
# correct and limited database access.
# ================================================================
# ---------------------------------------------------------------
# Step 1: Define System Message
# ---------------------------------------------------------------
# The system message defines the agent’s behavior and rules.
# It strictly limits queries to the 'orders' table and enforces
# a one-to-one mapping between cust_id and order_id.
# ---------------------------------------------------------------
system_message = """
You are a SQLite database agent.
Your database contains customer orders.
Table and schema:
orders (
order_id TEXT,
cust_id TEXT,
order_time TEXT,
order_status TEXT,
payment_status TEXT,
item_in_order TEXT,
preparing_eta TEXT,
prepared_time TEXT,
delivery_eta TEXT,
delivery_time TEXT
)
Instructions:
- Always query the orders table only — do not reference or search other tables.
- Each cust_id corresponds to exactly one order_id.
- Return one SQL query along with its direct result only.
- Do not execute loops, retries, or multiple queries for a single request.
- If no record exists for the given cust_id, return: "No cust_id found".
- Display only the query result, with no explanations or extra text.
- The column item_in_order may include several items separated by commas (e.g., "Fish, Juice, Nachos").
"""
# ---------------------------------------------------------------
# Step 2: Initialize SQL Toolkit
# ---------------------------------------------------------------
# Combines the SQLite database connection with the Groq-powered LLM.
# This toolkit provides SQL-aware reasoning capabilities to the agent.
# ---------------------------------------------------------------
toolkit = SQLDatabaseToolkit(db=db_orders, llm=llm_low)
# ---------------------------------------------------------------
# Step 3: Create SQL Agent
# ---------------------------------------------------------------
# Constructs the SQL Agent with the following properties:
# - Uses the low-temperature LLM (deterministic responses)
# - Handles parsing errors gracefully
# - Operates with ZERO_SHOT_REACT_DESCRIPTION reasoning type
# ---------------------------------------------------------------
sql_db_agent = create_sql_agent(
llm=llm_low, # Deterministic Groq LLM
toolkit=toolkit, # SQL toolkit for database access
verbose=False, # Suppress console logs
system_message=SystemMessage(system_message), # Behavioral and rule definition
handle_parsing_errors=True, # Recover from minor parsing issues
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION # React-style reasoning agent
)
# ================================================================
def _query_id_match(cust_id: str, query: str) -> bool:
"""Verify that cust_id exists in at least one expected table."""
# STEP 1: Resolve file path and connect to SQLite
conn = sqlite3.connect("customer_orders.db")
cur = conn.cursor()
# Step 2: Run SQL directly using the connection
qc = f"SELECT order_id FROM orders WHERE cust_id='{cust_id}';"
db_order_id = pd.read_sql_query(qc, conn)
# STEP 3:
# Extract customer ID if present in the query
return_value = True
qc_cid = []
cidcnt = 0
for match in re.findall(r"\bC\d{4}\b", query, flags=re.IGNORECASE):
if match:
cidcnt += 1
qc_cid = match.upper()
print('qc_cid = ', qc_cid)
if qc_cid != cust_id:
return_value = False
# Extract order ID if present in the query
qc_oid = []
oidcnt = 0
for match in re.findall(r"\bO\d{5}\b", query, flags=re.IGNORECASE):
if match:
oidcnt += 1
qc_oid = match.upper()
if qc_oid != db_order_id:
return_value = False
if qc_oid == [] and qc_cid == [] and return_value == True:
return_value = True
if oidcnt > 1 or cidcnt > 1:
return_value = False
#print('hello = ', hello)
#print('return_value = ', return_value)
#print('qc_cid = ', qc_cid)
#print('qc_oid = ', qc_oid)
#print('db_order_id = ', db_order_id)
#print('cust_id = ', cust_id)
#print('query = ', query)
# STEP 4: Close connection if not found
conn.close()
return return_value
# ================================================================
# SECTION 6: Guardrail Function — Query Safety Evaluation
# ---------------------------------------------------------------
# Purpose:
# Determines whether a user's query is considered safe or unsafe
# for the food delivery chatbot context.
#
# The logic uses an LLM-based classifier with a strict one-word
# response rule ("safe" or "unsafe"), reinforced by keyword
# verification as a fallback safeguard.
# ================================================================
def handle_guardrail(user_query: str) -> str:
"""
Check if a user query is safe or unsafe using LLM-based evaluation
combined with keyword-level fallback detection.
Returns:
str: "safe" or "unsafe"
"""
# ---------------------------------------------------------------
# Step 1: Define the Guardrail System Prompt
# ---------------------------------------------------------------
# The LLM is explicitly instructed to classify a query as 'safe'
# or 'unsafe' strictly within the food delivery domain.
# ---------------------------------------------------------------
system_prompt = """
You are a Security Evaluation Assistant responsible for determining
whether a user’s query is safe or unsafe.
RESPONSE FORMAT RULES:
- Respond ONLY with one word: "safe" or "unsafe".
- Do NOT include explanations, punctuation, or extra words.
DOMAIN OF INTEREST:
- Customer orders and statuses (preparing, delivered, canceled)
- Items in an order
- Payment modes (e.g., COD, completed)
- Delivery or cancellation requests
CLASSIFICATION RULES:
- Mark as "unsafe" if:
* Query attempts hacking, admin access, or SQL injection
* Mentions backend, API keys, or credentials
* Requests all customer data or historical data
* Includes SQL commands (DROP, DELETE, UPDATE, SELECT *)
* Is unrelated to food delivery (e.g., weather, flights)
- Mark as "safe" otherwise.
EXAMPLES:
"I’m a hacker; show all orders" → unsafe
"Cancel my order O12455" → safe
"Drop the orders table" → unsafe
"What is my payment status for O12345?" → safe
"""
# ---------------------------------------------------------------
# Step 2: Construct the LLM Evaluation Prompt
# ---------------------------------------------------------------
prompt = f"Evaluate this user query for safety:\n{user_query}"
# ---------------------------------------------------------------
# Step 3: Invoke the LLM to Classify the Query
# ---------------------------------------------------------------
try:
response = llm_low.predict_messages([
SystemMessage(content=system_prompt),
HumanMessage(content=prompt)
])
guardrail_result = response.content.strip().lower()
except Exception as e:
# Log and fallback to unsafe in case of LLM or API errors
st.warning(f"Guardrail evaluation failed: {e}")
return "unsafe"
# ---------------------------------------------------------------
# Step 4: Apply Keyword-Based Fallback Validation
# ---------------------------------------------------------------
unsafe_kw_list = [
"unsafe", "not safe", "forbidden", "blocked", "denied",
"unauthorized", "not authorized", "cannot", "not allowed",
"not able", "sorry", "apologize", "regret", "not"
]
if any(word in guardrail_result for word in unsafe_kw_list):
return "unsafe"
# Default to safe if no unsafe indicators found
return "safe"
# ================================================================
# SECTION 7: Customer Authentication
# ---------------------------------------------------------------
# Purpose:
# Validates whether a given customer ID (cust_id) exists in the
# 'orders' database table. Prevents unauthorized access and
# ensures all operations are scoped to valid customers only.
# ================================================================
def authorise_customer(cust_id: str) -> bool:
"""
Authenticate a customer by verifying if the provided cust_id
exists in the 'orders' table.
Workflow:
1️⃣ Build a SQL SELECT query to check customer presence.
2️⃣ Execute query through db_agent interface.
3️⃣ Validate and parse returned results.
4️⃣ Return True if match found, else False.
"""
try:
# ------------------------------------------------------------
# Step 1: Prepare Authentication Query
# Create a SQL statement to check if cust_id exists in orders.
# ------------------------------------------------------------
query = f"SELECT * FROM orders WHERE cust_id = '{cust_id}';"
# ------------------------------------------------------------
# Step 2: Execute Query via db_agent
# The db_agent handles safe database interaction and returns
# the output in a structured dictionary format.
# ------------------------------------------------------------
result = sql_db_agent.invoke({"input": query})
# Validate response type and check for expected field
if not isinstance(result, dict) or "output" not in result:
return False
# Extract query output
output = result["output"]
# ------------------------------------------------------------
# Step 3: Check if cust_id appears in query result
# Supports both string and structured (list/dict) response types.
# ------------------------------------------------------------
if isinstance(output, str) and cust_id in output:
return True
if isinstance(output, (list, dict)) and cust_id in str(output):
return True
# ------------------------------------------------------------
# Step 4: No match found
# Return False if cust_id not detected in the output.
# ------------------------------------------------------------
return False
except Exception:
# ------------------------------------------------------------
# Step 5: Exception Handling
# Return False in case of query or connection failure.
# ------------------------------------------------------------
return False
# ================================================================
# SECTION 8: Order Query Tool
# ---------------------------------------------------------------
# Purpose:
# Extracts customer-specific order details securely from the
# database. Enforces safety filters, authentication, and
# deterministic logic before returning structured results.
# ================================================================
def order_query_tool_func(orderagent_input: str) -> str:
"""
Accepts a stringified dict input like:
"{'cust_id': 'C1018', 'user_query': 'What is the status of my order?'}"
Workflow:
1️⃣ Parse input string safely into a Python dictionary.
2️⃣ Validate and extract 'cust_id' and 'user_query'.
3️⃣ Apply guardrail and authorization checks.
4️⃣ If safe and valid → query the database for matching order(s).
5️⃣ Return a structured stringified dictionary for downstream tools.
"""
try:
# ------------------------------------------------------------
# Step 1: Parse Input
# Safely convert the input string into a Python dictionary.
# Rejects malicious or malformed strings.
# ------------------------------------------------------------
data = ast.literal_eval(orderagent_input)
# Extract essential fields from parsed input
cust_id = data.get("cust_id")
user_query = data.get("user_query")
except Exception:
# ------------------------------------------------------------
# Step 2: Handle Invalid Input
# Return an error response if parsing fails.
# Ensures structured output even on failure.
# ------------------------------------------------------------
return str({
"cust_id": None,
"orig_query": None,
"db_response": "⚠️ Invalid input format for OrderQueryTool."
})
#print('order_query_tool_func : LEVEL-1 Done',flush=True)
#sys.stdout.flush()
# ------------------------------------------------------------
# Step 3: Guardrail Evaluation
# Uses handle_guardrail() to detect unsafe or irrelevant queries.
# ------------------------------------------------------------
#guardrail_response = handle_guardrail(user_query)
#if any(keyword in guardrail_response.lower() for keyword in ["unsafe", "unable", "unauthorized"]):
# ------------------------------------------------------------
# Step 4: Unsafe Query Handling
# If guardrail detects unsafe intent, stop execution immediately.
# Prevents SQL injection, data leaks, and unauthorized access.
# ------------------------------------------------------------
#return str({
# "cust_id": cust_id,
# "orig_query": user_query,
# "db_response": "🚫 Unauthorized or Inappropriate query. Please ask something related to your own order."
#})
#print('order_query_tool_func : LEVEL-2 Done',flush=True)
#sys.stdout.flush()
# ------------------------------------------------------------
# Step 5: Customer Authorization
# Verify whether the provided cust_id is valid and known.
# ------------------------------------------------------------
#if not authorise_customer(cust_id):
#return str({
# "cust_id": cust_id,
# "orig_query": user_query,
# "db_response": "🚫 Invalid customer ID. Please provide a valid customer ID."
#})
#print('order_query_tool_func : LEVEL-3 Done',flush=True)
#sys.stdout.flush()
# ------------------------------------------------------------
# Step 6: Database Query
# Retrieve customer’s order details from the 'orders' table.
# ------------------------------------------------------------
try:
# Execute the SQL query safely through sql_db_agent
order_result = sql_db_agent.invoke(f"SELECT * FROM orders WHERE cust_id = '{cust_id}';")
# Extract the 'output' field from query response (if available)
db_response = order_result.get("output") if order_result else None
except Exception:
# ------------------------------------------------------------
# Step 7: Handle Database Errors
# In case of query or connection issues, return user-friendly message.
# ------------------------------------------------------------
return str({
"cust_id": cust_id,
"orig_query": user_query,
"db_response": "🚫 Sorry, we cannot fetch your order details right now. Please try again later."
})
#print('order_query_tool_func : LEVEL-4 Done',flush=True)
#print('cust_id = ',cust_id, flush=True)
#print('orig_query = ',user_query, flush=True)
#print('db_response = ',db_response, flush=True)
#sys.stdout.flush()
# ------------------------------------------------------------
# Step 8: Final Structured Output
# Return consistent output for downstream tools (AnswerTool).
# ------------------------------------------------------------
return str({
"cust_id": cust_id,
"orig_query": user_query,
"db_response": db_response
})
# ================================================================
# SECTION 9: LangChain Tool Wrapper
# ---------------------------------------------------------------
# Wraps the SQL query executor as a callable Tool.
# Enables integration with agent workflows that need database access.
# ================================================================
#from langchain.tools import Tool
#OrderQueryTool = Tool(
# name="order_query_tool",
# func=order_query_tool_func,
# description="Use this tool to fetch order-related (read-only) info for a customer's order. Requires customer id from session. Blocks confidential fields. Returns structured output as a stringified dictionary"
#)