# Sush
# Initial commit: Banking Intelligence Assistant
# c62ce64
import os
import re

from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_groq import ChatGroq

load_dotenv()
def load_sql_agent(db_path: str = "data/database/banking.db"):
    """Build a question-answering chain over a SQLite database.

    Custom SQL chain using LCEL — more reliable than agent for Llama models.

    Each question goes through three steps: the LLM is asked to write a
    SQLite query for it, the query is run against the database, and the LLM
    is asked to phrase the result as an answer.

    Args:
        db_path: Path of the SQLite database file to query.

    Returns:
        An object whose ``invoke({"input": question})`` returns
        ``{"output": answer}``.
    """
    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
    llm = ChatGroq(
        api_key=os.getenv("GROQ_API_KEY"),
        model_name="llama-3.3-70b-versatile",
        temperature=0,
    )

    # Step 1 — Generate SQL from question
    sql_generation_prompt = PromptTemplate(
        template="""You are a SQL expert. Given the database schema and question, write a SQLite SQL query.
Return ONLY the SQL query, nothing else. No explanation, no markdown, no backticks.
Database Schema:
{schema}
Question: {question}
SQL Query:""",
        input_variables=["schema", "question"],
    )

    # Step 2 — Generate final answer from SQL result
    answer_prompt = PromptTemplate(
        template="""Given the question, SQL query, and result, write a clear answer.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:""",
        input_variables=["question", "query", "result"],
    )

    # The chains do not depend on the question, so build them once here
    # instead of on every call.
    sql_chain = sql_generation_prompt | llm | StrOutputParser()
    answer_chain = answer_prompt | llm | StrOutputParser()

    # Matches a markdown fence with an optional SQL language tag, in any
    # letter case. The longer tags come first so that "```sqlite" is removed
    # whole rather than leaving "ite" in front of the query.
    fence_pattern = re.compile(r"```(?:sqlite3?|sql)?", re.IGNORECASE)

    def run_sql_chain(question: str) -> str:
        """Answer ``question`` from the database, or return an error message."""
        # Without a question the LLM could only invent a query; skip the
        # two LLM calls and the database round trip.
        if not str(question or "").strip():
            return "Please provide a question to answer."

        try:
            # Read the schema per call rather than caching it at build time.
            schema = db.get_table_info()

            raw_sql = sql_chain.invoke({
                "schema": schema,
                "question": question,
            })
            # The model sometimes wraps the query in a code fence despite
            # the prompt telling it not to.
            sql_query = fence_pattern.sub("", raw_sql).strip()

            # NOTE(security): this statement was written by the LLM from the
            # user's question and is executed as-is. Nothing here limits it
            # to read-only SQL — consider a read-only database connection or
            # validating the statement before running it.
            result = db.run(sql_query)

            return answer_chain.invoke({
                "question": question,
                "query": sql_query,
                "result": result,
            })
        except Exception as e:
            # Failures are reported as text instead of being raised, so the
            # caller always receives a string.
            return f"I encountered an error processing your query: {str(e)}"

    # Wrap as a dict-compatible interface to match orchestrator
    class SQLChainWrapper:
        """Adapter exposing the chain through a dict-in/dict-out ``invoke``."""

        def invoke(self, input_dict):
            """Answer ``input_dict["input"]`` and return it under ``"output"``."""
            question = input_dict.get("input", "")
            return {"output": run_sql_chain(question)}

    return SQLChainWrapper()