| import os |
| from langchain_groq import ChatGroq |
| from langchain_community.utilities import SQLDatabase |
| from langchain_core.prompts import PromptTemplate |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnableLambda |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
| def load_sql_agent(db_path: str = "data/database/banking.db"): |
| """Custom SQL chain using LCEL — more reliable than agent for Llama models.""" |
|
|
| db = SQLDatabase.from_uri(f"sqlite:///{db_path}") |
|
|
| llm = ChatGroq( |
| api_key=os.getenv("GROQ_API_KEY"), |
| model_name="llama-3.3-70b-versatile", |
| temperature=0 |
| ) |
|
|
| |
| sql_generation_prompt = PromptTemplate( |
| template="""You are a SQL expert. Given the database schema and question, write a SQLite SQL query. |
| Return ONLY the SQL query, nothing else. No explanation, no markdown, no backticks. |
| |
| Database Schema: |
| {schema} |
| |
| Question: {question} |
| |
| SQL Query:""", |
| input_variables=["schema", "question"] |
| ) |
|
|
| |
| answer_prompt = PromptTemplate( |
| template="""Given the question, SQL query, and result, write a clear answer. |
| |
| Question: {question} |
| SQL Query: {query} |
| SQL Result: {result} |
| |
| Answer:""", |
| input_variables=["question", "query", "result"] |
| ) |
|
|
| def run_sql_chain(question: str) -> str: |
| try: |
| |
| schema = db.get_table_info() |
|
|
| |
| sql_chain = sql_generation_prompt | llm | StrOutputParser() |
| sql_query = sql_chain.invoke({ |
| "schema": schema, |
| "question": question |
| }).strip() |
|
|
| |
| sql_query = sql_query.replace("```sql", "").replace("```", "").strip() |
|
|
| |
| result = db.run(sql_query) |
|
|
| |
| answer_chain = answer_prompt | llm | StrOutputParser() |
| answer = answer_chain.invoke({ |
| "question": question, |
| "query": sql_query, |
| "result": result |
| }) |
|
|
| return answer |
|
|
| except Exception as e: |
| return f"I encountered an error processing your query: {str(e)}" |
|
|
| |
| class SQLChainWrapper: |
| def invoke(self, input_dict): |
| question = input_dict.get("input", "") |
| output = run_sql_chain(question) |
| return {"output": output} |
|
|
| return SQLChainWrapper() |