# FoodHub / agent_backend.py
# Uploaded by wankhedes27 with huggingface_hub (commit 550c125, verified).
import logging
import os
import sqlite3

from langchain.agents import create_sql_agent, AgentType, initialize_agent
from langchain.tools import Tool
from langchain.llms import Groq
from langchain.sql_database import SQLDatabase
# --------------------------
# Load DB
# --------------------------
# SQLite file name, resolved relative to the process's current working directory.
DB_PATH = "customer_orders.db"
# NOTE(review): this raw connection is not used below and is never closed;
# sqlite3.connect also creates an empty file when DB_PATH does not exist.
conn = sqlite3.connect(DB_PATH)
# LangChain wrapper over the same file; the SQL agent below queries through it.
db = SQLDatabase.from_uri("sqlite:///{}".format(DB_PATH))
# --------------------------
# Initialize LLM using HF secret
# --------------------------
# NOTE(review): `Groq` does not appear to be exported by `langchain.llms`; Groq
# chat models usually come from `langchain_groq.ChatGroq`. The `.content`
# accesses elsewhere in this file imply a chat model — confirm the import.
_groq_api_key = os.getenv("GROQ_API_KEY")  # None when the secret is not set
llm = Groq(
    groq_api_key=_groq_api_key,
    model="llama-3.3-70b-versatile",
)
# --------------------------
# Create SQL agent
# --------------------------
# ReAct-style agent that turns natural-language requests into SQL over `db`.
# OrderQueryTool sends its normalized queries here.
sql_agent = create_sql_agent(
    db=db,
    llm=llm,
    verbose=False,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
# --------------------------
# Safety wrapper
# --------------------------
def safe_generate(prompt, *, llm_client=None):
    """Answer ``prompt`` through the LLM, wrapped in FoodHub support guardrails.

    The prompt is embedded under a fixed instruction preamble (no hallucinated
    order details, no SQL logs or internal reasoning, polite tone) and sent to
    the model in a single call.

    Args:
        prompt: User- or tool-supplied text to answer.
        llm_client: Optional model object exposing ``invoke(str)``; defaults
            to the module-level ``llm``.

    Returns:
        The model's reply text, or ``"Sorry, something went wrong."`` if the
        model call raises.
    """
    client = llm if llm_client is None else llm_client
    guardrail = f"""
You are a FoodHub support chatbot.
Rules:
- No hallucinating order details.
- No SQL logs or internal reasoning.
- Answer politely and clearly.
USER:
{prompt}
"""
    try:
        reply = client.invoke(guardrail)
    except Exception:
        # Best-effort: record the failure for operators, but never show a
        # traceback to the customer. (Ctrl-C / SystemExit still propagate.)
        logging.getLogger(__name__).exception("safe_generate: LLM call failed")
        return "Sorry, something went wrong."
    # Chat models return a message object with `.content`; plain LLMs return str.
    return getattr(reply, "content", reply)
# --------------------------
# Intent Classifier
# --------------------------
# Routing labels the classifier prompt asks the model to choose from.
_INTENT_LABELS = ("needs_order_id", "sql_query", "general_query")


def intent_classifier(query, *, llm_client=None):
    """Ask the LLM to label ``query`` with one routing intent.

    Args:
        query: The raw user message.
        llm_client: Optional model object exposing ``invoke(str)``; defaults
            to the module-level ``llm``.

    Returns:
        ``"needs_order_id"``, ``"sql_query"`` or ``"general_query"`` when the
        model's reply contains one of them (the earliest occurrence wins, so
        replies like ``"Label: sql_query."`` still parse). If no known label
        appears, the stripped, lower-cased reply is returned unchanged.

    Raises:
        Whatever the model's ``invoke`` raises; errors are not caught here.
    """
    client = llm if llm_client is None else llm_client
    prompt = f"""
Classify the user's intent. Respond with EXACTLY one label:
needs_order_id
sql_query
general_query
User query: {query}
"""
    reply = client.invoke(prompt)
    # Chat models return a message with `.content`; plain LLMs return str.
    text = str(getattr(reply, "content", reply)).strip().lower()
    # Models often wrap the label in quotes/backticks, add punctuation or a
    # "Label:" prefix; an exact match alone would silently misroute those.
    found = [(text.find(label), label) for label in _INTENT_LABELS if label in text]
    return min(found)[1] if found else text
# --------------------------
# Query Normalizer
# --------------------------
def normalize_query(user_query, *, llm_client=None):
    """Rewrite ``user_query`` into a short retrieval instruction for the SQL agent.

    The model is told not to write SQL, only to describe what to retrieve.

    Args:
        user_query: The raw user message.
        llm_client: Optional model object exposing ``invoke(str)``; defaults
            to the module-level ``llm``.

    Returns:
        The stripped rewrite, or ``user_query`` unchanged when the model call
        raises or returns blank text.
    """
    client = llm if llm_client is None else llm_client
    prompt = f"""
Rewrite the user's message into a short instruction for retrieving order info.
Do NOT write SQL. Only describe what to retrieve.
User:
{user_query}
"""
    try:
        reply = client.invoke(prompt)
    except Exception:
        # Best-effort rewrite: fall back to the original wording.
        logging.getLogger(__name__).exception(
            "normalize_query: LLM call failed; using raw query"
        )
        return user_query
    # Chat models return a message with `.content`; plain LLMs return str.
    rewritten = str(getattr(reply, "content", reply)).strip()
    # An empty rewrite would give the SQL agent nothing to work with.
    return rewritten or user_query
# --------------------------
# Order Query Tool
# --------------------------
def order_query_tool_function(query):
    """Look up order information for ``query`` via the SQL agent.

    The query is first rewritten by ``normalize_query`` and then sent to
    ``sql_agent``.

    Returns:
        The agent's answer text. Returns ``"The order ID you provided does not
        exist."`` when the answer is empty or contains "i don't know" / "no
        results" (case-insensitive). Returns ``"Error retrieving data. Please
        check the order ID."`` if any step raises.
    """
    try:
        cleaned = normalize_query(query)
        result = sql_agent.invoke({"input": cleaned})
        # `or ""` also covers an explicit None output, which would break .strip().
        raw = str(result.get("output") or "").strip()
    except Exception:
        # Tools must return text to the calling agent, so report instead of raising.
        logging.getLogger(__name__).exception("OrderQueryTool: SQL agent failed")
        return "Error retrieving data. Please check the order ID."
    lowered = raw.lower()
    if not raw or "i don't know" in lowered or "no results" in lowered:
        return "The order ID you provided does not exist."
    return raw
# Expose the SQL-backed order lookup to the chat agent as a LangChain tool.
order_query_tool = Tool(
    func=order_query_tool_function,
    name="OrderQueryTool",
    description="Fetch raw order details.",
)
# --------------------------
# Answer Refiner
# --------------------------
def answer_refiner_tool_function(raw):
    """Turn raw order data into a customer-friendly reply.

    Text containing "does not exist" (case-insensitive) is returned untouched.
    Anything else is rephrased by the guarded LLM call in ``safe_generate``.
    """
    already_final = "does not exist" in raw.lower()
    if not already_final:
        rewrite_request = f"""
Rewrite the raw order data in a simple, polite, customer-friendly way.
RAW:
{raw}
"""
        return safe_generate(rewrite_request)
    # Not-found messages are passed through verbatim.
    return raw
# Expose the answer rewriter to the chat agent as a LangChain tool.
answer_refiner_tool = Tool(
    name="AnswerRefinerTool",
    func=answer_refiner_tool_function,
    # `description` is a required Tool argument (omitting it raises TypeError),
    # and the ReAct agent reads it to decide when to call this tool.
    description=(
        "Rewrite raw order details returned by OrderQueryTool into a polite, "
        "customer-friendly answer. Input should be the raw OrderQueryTool output."
    ),
)
# --------------------------
# Chat Agent (tools + LLM)
# --------------------------
# ReAct agent that can call OrderQueryTool and AnswerRefinerTool.
chat_agent = initialize_agent(
    tools=[order_query_tool, answer_refiner_tool],
    llm=llm,
    # initialize_agent selects the agent via `agent=`; `agent_type=` is not one
    # of its parameters and was only forwarded as an unknown AgentExecutor kwarg.
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    # Feed malformed ReAct output back to the model instead of raising
    # OutputParserException out of the chat router.
    handle_parsing_errors=True,
    verbose=False,
)
# --------------------------
# Smart Router
# --------------------------
def smart_chat_router(user_input):
    """Route one user message to the right handler and return the reply text.

    - ``needs_order_id``: ask the user for their order ID.
    - ``general_query``: answer directly via ``safe_generate``.
    - anything else (including ``sql_query``): run the tool-using ``chat_agent``.

    If the intent classifier or the agent raises, the error is logged and a
    generic apology is returned so the chat UI never receives a traceback.
    Agent replies that are empty or contain "i don't know" become a not-found
    message.
    """
    log = logging.getLogger(__name__)
    try:
        intent = intent_classifier(user_input)
    except Exception:
        log.exception("smart_chat_router: intent classification failed")
        return "Sorry, something went wrong."

    if intent == "needs_order_id":
        return "Could you please provide the order ID?"
    if intent == "general_query":
        return safe_generate(user_input)

    # Order lookups go through the ReAct agent (OrderQueryTool + AnswerRefinerTool).
    try:
        result = chat_agent.invoke({"input": user_input})
        # `or ""` also covers an explicit None output, which would break .lower().
        output = str(result.get("output") or "")
    except Exception:
        log.exception("smart_chat_router: chat agent failed")
        return "Sorry, something went wrong."
    if not output or "i don't know" in output.lower():
        return "Sorry, I could not find that order."
    return output
def process_message(msg):
    """Public entry point: route one chat message and return the bot's reply."""
    reply = smart_chat_router(msg)
    return reply