Spaces:
Sleeping
Sleeping
| # chatbot_logic.py | |
| import os | |
| import re | |
| import json | |
| from typing import Optional | |
| import pandas as pd | |
| from sqlalchemy import create_engine | |
| # LangChain imports | |
| from langchain_groq import ChatGroq | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain.utilities import SQLDatabase | |
| from langchain.agents import create_sql_agent | |
| from langchain.memory import ConversationBufferMemory | |
# -------------------------------
# STEP 1: Initialize LLM
# -------------------------------
# Load the Groq API key from the environment and fail fast at import time,
# so a misconfigured Space shows a clear error instead of failing on the
# first chat request.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("Please add GROQ_API_KEY as an environment variable.")

# Deterministic model (temperature=0.0) for consistent support answers;
# max_tokens=512 keeps replies short and bounds per-request cost.
model_name = "llama-3.3-70b-versatile"
llm = ChatGroq(
    model=model_name,
    temperature=0.0,
    max_tokens=512,
    api_key=GROQ_API_KEY
)
# -------------------------------
# STEP 2: Load Database & SQL Agent
# -------------------------------
# Path to the SQLite database file shipped alongside this module in the
# Hugging Face Space repo (relative path, resolved from the working dir).
db_path = "customer_orders.db"  # place your DB file in the Space repo

# LangChain wrapper around the SQLite DB; used both by the agent and as a
# raw-query fallback in order_query_tool().
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# Agent that lets the LLM plan and execute SQL against `db`.
sql_agent = create_sql_agent(
    llm=llm,
    db=db,
    verbose=False  # Set True if you want SQL queries printed
)
# -------------------------------
# STEP 3: Memory for conversation
# -------------------------------
# Whole-conversation buffer, surfaced to prompts under the "chat_history"
# key; return_messages=True yields message objects (not a joined string),
# which run_chat_agent() splats directly into the LLM message list.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)
| # ------------------------------- | |
| # STEP 4: Helper Functions | |
| # ------------------------------- | |
def extract_order_id(text: str) -> Optional[str]:
    """Return the first order id (e.g. ``O12345``) found in *text*, else None.

    An order id is an 'O' or 'o' followed by at least three digits,
    delimited by word boundaries.
    """
    match = re.search(r"\b([Oo]\d{3,})\b", text)
    if match is None:
        return None
    return match.group(1)
def safe_llm_generate(messages):
    """Run the LLM over *messages* and return stripped text.

    Prefers the batch ``generate`` API; if that fails for any reason,
    falls back to ``invoke`` and, failing a ``.content`` attribute,
    to ``str()`` — so the caller always receives a string.
    """
    try:
        result = llm.generate([messages])
        text = result.generations[0][0].text
    except Exception:
        response = llm.invoke(messages)
        try:
            text = response.content
        except Exception:
            text = str(response)
    return text.strip()
def order_query_tool(order_id: str) -> str:
    """Fetch raw order data for *order_id* and return it as a string.

    Asks the SQL agent first; if that fails, runs the query directly
    against the database and JSON-encodes the rows. Any failure is
    reported as an ``ERROR …`` string rather than raised, so the caller
    can still produce a customer-facing reply.
    """
    if not order_id:
        return "ERROR: No order_id provided."
    # SECURITY: order_id is interpolated into SQL text below, so reject
    # anything that is not the expected O<digits> shape to prevent SQL
    # injection. This is a strict superset check of extract_order_id().
    if not re.fullmatch(r"[Oo]\d{3,}", order_id):
        return f"ERROR: Invalid order_id format: {order_id!r}"
    try:
        query = f"SELECT * FROM orders WHERE order_id = '{order_id}';"
        raw_text = str(sql_agent.run(query))
    except Exception as e:
        # Fallback: bypass the agent and hit the DB wrapper directly.
        try:
            rows = db.run(f"SELECT * FROM orders WHERE order_id = '{order_id}'")
            raw_text = json.dumps(rows, default=str, indent=2)
        except Exception as e2:
            raw_text = f"ERROR fetching order {order_id}: {e} / fallback: {e2}"
    return raw_text
def answer_tool(raw_order_context: str, user_question: str) -> str:
    """Turn raw order data into a short, polite customer-facing reply."""
    system_text = (
        "You are a polite, formal customer support assistant for a food delivery app. "
        "Do NOT reveal sensitive backend details or PII. Use the order context to answer the user's question clearly and briefly."
    )
    human_text = (
        f"Order context (raw):\n{raw_order_context}\n\n"
        f"Customer question:\n{user_question}\n\n"
        "Instructions:\n"
        "1) Answer politely in 2-4 sentences.\n"
        "2) If order found, state status & ETA and next steps.\n"
        "3) If not found, ask politely for order id.\n"
        "4) Do NOT include raw DB or internal SQL."
    )
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content=human_text),
    ]
    return safe_llm_generate(messages)
# -------------------------------
# STEP 5: Main Chat Agent Function
# -------------------------------
def run_chat_agent(user_input: str) -> dict:
    """
    Main agent entrypoint called by app.py
    Returns a dict with keys: user_input, order_id, raw_context, final_answer
    """
    chat_history = memory.load_memory_variables({}).get("chat_history", [])

    # If the message carries an order id, look the order up and answer
    # from that context; otherwise fall back to a guarded generic chat.
    order_id = extract_order_id(user_input)
    raw_context = None

    if order_id:
        raw_context = order_query_tool(order_id)
        final_answer = answer_tool(raw_context, user_input)
    else:
        guardrail = SystemMessage(
            content=(
                "You are a helpful food delivery assistant. Refuse requests for sensitive data. "
                "If user asks about an order but provides no order id, ask politely for the order id."
            )
        )
        final_answer = safe_llm_generate(
            [guardrail, *chat_history, HumanMessage(content=user_input)]
        )

    # Persist the exchange so later turns see it in chat_history.
    memory.save_context({"input": user_input}, {"output": final_answer})

    return {
        "user_input": user_input,
        "order_id": order_id,
        "raw_context": raw_context,
        "final_answer": final_answer,
    }