File size: 2,604 Bytes
c62ce64 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | import os
from langchain_groq import ChatGroq
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from dotenv import load_dotenv
load_dotenv()
def load_sql_agent(db_path: str = "data/database/banking.db"):
"""Custom SQL chain using LCEL — more reliable than agent for Llama models."""
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
llm = ChatGroq(
api_key=os.getenv("GROQ_API_KEY"),
model_name="llama-3.3-70b-versatile",
temperature=0
)
# Step 1 — Generate SQL from question
sql_generation_prompt = PromptTemplate(
template="""You are a SQL expert. Given the database schema and question, write a SQLite SQL query.
Return ONLY the SQL query, nothing else. No explanation, no markdown, no backticks.
Database Schema:
{schema}
Question: {question}
SQL Query:""",
input_variables=["schema", "question"]
)
# Step 2 — Generate final answer from SQL result
answer_prompt = PromptTemplate(
template="""Given the question, SQL query, and result, write a clear answer.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:""",
input_variables=["question", "query", "result"]
)
def run_sql_chain(question: str) -> str:
try:
# Get schema
schema = db.get_table_info()
# Generate SQL
sql_chain = sql_generation_prompt | llm | StrOutputParser()
sql_query = sql_chain.invoke({
"schema": schema,
"question": question
}).strip()
# Clean up query if needed
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
# Execute SQL
result = db.run(sql_query)
# Generate answer
answer_chain = answer_prompt | llm | StrOutputParser()
answer = answer_chain.invoke({
"question": question,
"query": sql_query,
"result": result
})
return answer
except Exception as e:
return f"I encountered an error processing your query: {str(e)}"
# Wrap as a dict-compatible interface to match orchestrator
class SQLChainWrapper:
def invoke(self, input_dict):
question = input_dict.get("input", "")
output = run_sql_chain(question)
return {"output": output}
return SQLChainWrapper() |