Spaces:
Runtime error
Runtime error
| import os | |
| import sqlite3 | |
| from langchain.agents import create_sql_agent, AgentType, initialize_agent | |
| from langchain.tools import Tool | |
| from langchain.llms import Groq | |
| from langchain.sql_database import SQLDatabase | |
# --------------------------
# Load DB
# --------------------------
DB_PATH = "customer_orders.db"

# sqlite3.connect() silently creates an empty database file when the path does
# not exist, which would leave the SQL agent with no tables to query. Fail fast
# with an actionable message instead. The path is relative, so it is resolved
# against the current working directory.
if not os.path.isfile(DB_PATH):
    raise FileNotFoundError(
        f"Database file {DB_PATH!r} not found in the current working directory."
    )

# Raw sqlite3 connection, kept as a module-level name for backward compatibility.
conn = sqlite3.connect(DB_PATH)
# LangChain SQLDatabase wrapper handed to the SQL agent below.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
# --------------------------
# Initialize LLM using HF secret
# --------------------------
# NOTE(review): ``langchain.llms`` does not appear to export a ``Groq`` class.
# The Groq chat model is usually ``ChatGroq`` from the ``langchain_groq``
# package, whose replies carry the ``.content`` attribute used elsewhere in
# this file. Confirm the import at the top of the file.
_groq_api_key = os.getenv("GROQ_API_KEY")
if not _groq_api_key:
    # Fail at startup with an actionable message instead of handing None to
    # the client.
    raise RuntimeError(
        "GROQ_API_KEY is not set; add it as a secret / environment variable."
    )

llm = Groq(
    model="llama-3.3-70b-versatile",
    groq_api_key=_groq_api_key,
)
# --------------------------
# Create SQL agent
# --------------------------
# ReAct-style agent that answers order questions by querying ``db``.
sql_agent = create_sql_agent(
    llm=llm,
    db=db,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    # Feed malformed ReAct output back to the model so it can retry, instead
    # of raising an output-parsing exception out of ``invoke``.
    agent_executor_kwargs={"handle_parsing_errors": True},
)
# --------------------------
# Safety wrapper
# --------------------------
def safe_generate(prompt):
    """Answer *prompt* with the LLM, prefixed by FoodHub support guardrails.

    Args:
        prompt: The user's message, or an internal instruction, to answer.

    Returns:
        The model's reply text, or "Sorry, something went wrong." if the
        LLM call raises.
    """
    guardrail = f"""
You are a FoodHub support chatbot.
Rules:
- No hallucinating order details.
- No SQL logs or internal reasoning.
- Answer politely and clearly.
USER:
{prompt}
"""
    try:
        response = llm.invoke(guardrail)
    except Exception:
        # Catch Exception, not everything, so KeyboardInterrupt/SystemExit
        # still propagate. Log the failure so it is diagnosable, not silent.
        import logging
        logging.getLogger(__name__).exception("safe_generate: LLM call failed")
        return "Sorry, something went wrong."
    # Chat models return a message object with ``.content``; plain LLMs
    # return the text itself.
    return getattr(response, "content", response)
| # -------------------------- | |
| # Intent Classifier | |
| # -------------------------- | |
| def intent_classifier(query): | |
| prompt = f""" | |
| Classify the user's intent. Respond with EXACTLY one label: | |
| needs_order_id | |
| sql_query | |
| general_query | |
| User query: {query} | |
| """ | |
| return llm.invoke(prompt).content.strip().lower() | |
# --------------------------
# Query Normalizer
# --------------------------
def normalize_query(user_query):
    """Rewrite *user_query* as a short description of the order info to retrieve.

    Args:
        user_query: The raw user message.

    Returns:
        The model's rewrite. Returns *user_query* unchanged if the LLM call
        raises or the rewrite is empty.
    """
    prompt = f"""
Rewrite the user's message into a short instruction for retrieving order info.
Do NOT write SQL. Only describe what to retrieve.
User:
{user_query}
"""
    try:
        response = llm.invoke(prompt)
    except Exception:
        import logging
        logging.getLogger(__name__).exception("normalize_query: LLM call failed")
        return user_query
    # Chat models return a message with ``.content``; plain LLMs return a str.
    rewritten = getattr(response, "content", response)
    # An empty rewrite would leave nothing to retrieve; keep the original text.
    if not isinstance(rewritten, str) or not rewritten.strip():
        return user_query
    return rewritten
# --------------------------
# Order Query Tool
# --------------------------
def order_query_tool_function(query):
    """Look up order information for *query* via ``normalize_query`` + ``sql_agent``.

    Args:
        query: The text the calling agent passes to the tool.

    Returns:
        The SQL agent's answer text. Returns "The order ID you provided does
        not exist." when the answer is empty or contains "i don't know" /
        "no results". Returns "Error retrieving data. Please check the order
        ID." if the lookup raises.
    """
    try:
        cleaned = normalize_query(query)
        result = sql_agent.invoke({"input": cleaned})
    except Exception:
        # Catch Exception, not everything, and log instead of swallowing.
        import logging
        logging.getLogger(__name__).exception("order_query_tool_function failed")
        return "Error retrieving data. Please check the order ID."
    # ``output`` may be missing or None; coerce before string operations.
    raw = str(result.get("output") or "").strip()
    lowered = raw.lower()
    if not raw or "i don't know" in lowered or "no results" in lowered:
        return "The order ID you provided does not exist."
    return raw
# LangChain Tool wrapper around the order lookup function.
order_query_tool = Tool(
    name="OrderQueryTool",
    description="Fetch raw order details.",
    func=order_query_tool_function,
)
# --------------------------
# Answer Refiner
# --------------------------
def answer_refiner_tool_function(raw):
    """Turn raw order data into a polite, customer-facing reply.

    Text containing "does not exist" (case-insensitive) is already a final
    message and is returned as-is. Anything else is rewritten through
    ``safe_generate``.
    """
    already_final = "does not exist" in raw.lower()
    if already_final:
        return raw
    return safe_generate(f"""
Rewrite the raw order data in a simple, polite, customer-friendly way.
RAW:
{raw}
""")
# ``Tool`` requires a ``description`` (omitting it raises a TypeError when the
# tool is constructed). A ReAct agent also reads it when choosing a tool.
answer_refiner_tool = Tool(
    name="AnswerRefinerTool",
    func=answer_refiner_tool_function,
    description=(
        "Rewrite raw order details into a simple, polite, "
        "customer-friendly reply."
    ),
)
# --------------------------
# Chat Agent (tools + LLM)
# --------------------------
chat_agent = initialize_agent(
    tools=[order_query_tool, answer_refiner_tool],
    llm=llm,
    # ``initialize_agent`` selects the agent via ``agent=``. An
    # ``agent_type=`` keyword is not one of its parameters and is forwarded
    # to the executor instead of selecting anything.
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=False,
    # Feed malformed ReAct output back to the model rather than raising.
    handle_parsing_errors=True,
)
# --------------------------
# Smart Router
# --------------------------
def smart_chat_router(user_input):
    """Route *user_input* to a handler based on ``intent_classifier``'s label.

    - ``needs_order_id``: ask the user for their order ID.
    - ``general_query``: answer directly via ``safe_generate``.
    - anything else: run the tool-using ``chat_agent``.

    Returns:
        The reply text to show the user. Returns "Sorry, something went
        wrong." if the chat agent raises, and "Sorry, I could not find that
        order." if its answer is empty or contains "i don't know".
    """
    intent = intent_classifier(user_input)
    # Substring checks tolerate decorated labels such as "needs_order_id.".
    if "needs_order_id" in intent:
        return "Could you please provide the order ID?"
    if "general_query" in intent:
        return safe_generate(user_input)
    try:
        result = chat_agent.invoke({"input": user_input})
    except Exception:
        # Don't let an agent failure crash the caller; log it and apologize.
        import logging
        logging.getLogger(__name__).exception("smart_chat_router: chat agent failed")
        return "Sorry, something went wrong."
    # ``output`` may be missing or None; coerce before string operations.
    output = str(result.get("output") or "")
    if not output or "i don't know" in output.lower():
        return "Sorry, I could not find that order."
    return output
def process_message(msg):
    """Return the chatbot's reply to the single user message *msg*.

    Thin public wrapper that delegates to ``smart_chat_router``.
    """
    reply = smart_chat_router(msg)
    return reply