# tskwvr/tests/unit_tests/data/plugins/sql_pull_data.py
# TRaw's picture
# Upload 297 files
# 3d3d712
from operator import itemgetter
import pandas as pd
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnableMap
from langchain.utilities import SQLDatabase
from taskweaver.plugin import Plugin, register_plugin
@register_plugin
class SqlPullData(Plugin):
    """Plugin that answers a natural-language question with data pulled via SQL.

    A chat model is prompted with the database schema and the user's question
    to produce a SQL query; that query is then executed against the database
    configured under ``sqlite_db_path`` and the rows are returned as a
    DataFrame together with a textual summary.
    """

    def _create_chat_model(self):
        """Build the chat model selected by the ``api_type`` config entry.

        Returns:
            An ``AzureChatOpenAI`` instance when ``api_type`` is ``"azure"``
            (the default), or a ``ChatOpenAI`` instance when it is ``"openai"``.

        Raises:
            ValueError: If ``api_type`` is neither ``"azure"`` nor ``"openai"``.
        """
        api_type = self.config.get("api_type", "azure")
        if api_type == "azure":
            return AzureChatOpenAI(
                azure_endpoint=self.config.get("api_base"),
                openai_api_key=self.config.get("api_key"),
                openai_api_version=self.config.get("api_version"),
                azure_deployment=self.config.get("deployment_name"),
                temperature=0,
                verbose=True,
            )
        if api_type == "openai":
            return ChatOpenAI(
                openai_api_key=self.config.get("api_key"),
                model_name=self.config.get("deployment_name"),
                temperature=0,
                verbose=True,
            )
        raise ValueError("Invalid API type. Please check your config file.")

    def __call__(self, query: str):
        """Generate a SQL query for *query*, run it, and describe the result.

        Args:
            query: The user's question in natural language.

        Returns:
            A ``(df, description)`` tuple. ``df`` is a ``pandas.DataFrame``
            built from the fetched rows; ``description`` is a message that
            quotes the question and the generated SQL and then either states
            that the result is empty or reports the row count followed by a
            markdown rendering of at most the first five rows.

        Raises:
            ValueError: If ``api_type`` is not supported, or if
                ``sqlite_db_path`` is missing from the plugin configuration.
        """
        model = self._create_chat_model()

        template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
Question: {question}
SQL Query:"""
        prompt = ChatPromptTemplate.from_template(template)

        # Fail early with a clear message instead of letting an unset path
        # surface as an opaque error from the database layer.
        db_uri = self.config.get("sqlite_db_path")
        if not db_uri:
            raise ValueError("Missing `sqlite_db_path`. Please check your config file.")
        db = SQLDatabase.from_uri(db_uri)

        def get_schema(_):
            # The chain input is ignored; the schema comes from the database.
            return db.get_table_info()

        inputs = {
            "schema": RunnableLambda(get_schema),
            "question": itemgetter("question"),
        }
        # The stop sequence ends generation before the model can go on to
        # write a "SQLResult:" section after the query.
        sql_response = RunnableMap(inputs) | prompt | model.bind(stop=["\nSQLResult:"]) | StrOutputParser()
        sql = sql_response.invoke({"question": query})

        # NOTE(security): `sql` is model-generated text and is executed
        # verbatim, with whatever privileges the configured connection has.
        # Point `sqlite_db_path` at data that is safe to expose/modify, or
        # add validation here before using this with untrusted questions.
        # NOTE(review): `_execute` is a private SQLDatabase method —
        # re-check its signature and return type when upgrading langchain.
        result = db._execute(sql, fetch="all")
        df = pd.DataFrame(result)

        summary = f"I have generated a SQL query based on `{query}`.\nThe SQL query is {sql}.\n"
        if len(df) == 0:
            return df, summary + "The result is empty."

        preview_rows = min(5, len(df))
        return df, (
            summary
            + f"There are {len(df)} rows in the result.\n"
            + f"The first {preview_rows} rows are:\n{df.head(preview_rows).to_markdown()}"
        )