|
|
from operator import itemgetter |
|
|
|
|
|
import pandas as pd |
|
|
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI |
|
|
from langchain.prompts import ChatPromptTemplate |
|
|
from langchain.schema.output_parser import StrOutputParser |
|
|
from langchain.schema.runnable import RunnableLambda, RunnableMap |
|
|
from langchain.utilities import SQLDatabase |
|
|
|
|
|
from taskweaver.plugin import Plugin, register_plugin |
|
|
|
|
|
|
|
|
@register_plugin
class SqlPullData(Plugin):
    """Answer a natural-language question by generating and running a SQL query.

    A chat model (Azure OpenAI or OpenAI, selected by the ``api_type`` config
    entry) is prompted with the database's table schema and the user's
    question.  The SQL it returns is executed against the database whose URI
    is stored under the ``sqlite_db_path`` config entry.
    """

    # Prompt sent to the model; ``{schema}`` and ``{question}`` are filled in
    # by the chain built in ``_generate_sql``.
    _PROMPT_TEMPLATE = (
        "Based on the table schema below, write a SQL query that would answer the user's question:\n"
        "{schema}\n"
        "\n"
        "Question: {question}\n"
        "SQL Query:"
    )

    # Maximum number of result rows rendered into the textual summary.
    _PREVIEW_ROWS = 5

    def __call__(self, query: str):
        """Translate *query* into SQL, run it, and return the results.

        Args:
            query: The user's question in natural language.

        Returns:
            A ``(df, description)`` tuple. ``df`` is a ``pandas.DataFrame`` built
            from the rows returned by the query. ``description`` is a
            human-readable summary containing the generated SQL, the row count
            and a Markdown preview of the first rows.

        Raises:
            ValueError: If ``api_type`` is not ``"azure"`` or ``"openai"``, or if
                ``sqlite_db_path`` is missing from the plugin config.
        """
        model = self._create_model()

        db_uri = self.config.get("sqlite_db_path")
        if not db_uri:
            # Fail with a clear message instead of an opaque error from from_uri(None).
            raise ValueError("Missing 'sqlite_db_path'. Please check your config file.")
        db = SQLDatabase.from_uri(db_uri)

        sql = self._generate_sql(model, db, query)

        # NOTE(review): the SQL comes straight from the LLM and is executed
        # without restriction, so a crafted question could produce writes or
        # DDL. Consider a read-only database connection if the data must not change.
        # `_execute` is a private SQLDatabase method. It is used here to get row
        # objects for the DataFrame, rather than the formatted result of the
        # public API. Re-check it on LangChain upgrades.
        result = db._execute(sql, fetch="all")
        df = pd.DataFrame(result)

        return df, self._describe_result(query, sql, df)

    def _create_model(self):
        """Build the chat model described by the plugin config.

        Raises:
            ValueError: If ``api_type`` is neither ``"azure"`` nor ``"openai"``.
        """
        api_type = self.config.get("api_type", "azure")
        if api_type == "azure":
            return AzureChatOpenAI(
                azure_endpoint=self.config.get("api_base"),
                openai_api_key=self.config.get("api_key"),
                openai_api_version=self.config.get("api_version"),
                azure_deployment=self.config.get("deployment_name"),
                temperature=0,
                verbose=True,
            )
        if api_type == "openai":
            return ChatOpenAI(
                openai_api_key=self.config.get("api_key"),
                model_name=self.config.get("deployment_name"),
                temperature=0,
                verbose=True,
            )
        raise ValueError("Invalid API type. Please check your config file.")

    def _generate_sql(self, model, db, question: str) -> str:
        """Ask *model* for a SQL query that answers *question* over *db*'s schema."""
        prompt = ChatPromptTemplate.from_template(self._PROMPT_TEMPLATE)
        inputs = {
            # The schema is read from the database when the chain runs; the
            # lambda's input (the chain's input dict) is ignored.
            "schema": RunnableLambda(lambda _: db.get_table_info()),
            "question": itemgetter("question"),
        }
        chain = (
            RunnableMap(inputs)
            | prompt
            # Presumably stops the model from continuing past the query into an
            # invented "SQLResult:" section.
            | model.bind(stop=["\nSQLResult:"])
            | StrOutputParser()
        )
        raw_sql = chain.invoke({"question": question})
        return self._clean_sql(raw_sql)

    @staticmethod
    def _clean_sql(raw: str) -> str:
        """Strip surrounding whitespace and an optional Markdown code fence.

        Chat models frequently wrap SQL in ```sql ... ``` fences, which the
        database cannot execute. Output without a fence is returned as-is,
        apart from whitespace trimming.
        """
        sql = raw.strip()
        if sql.startswith("```"):
            # Drop the opening fence line (e.g. "```sql"); if the whole reply
            # is on one line, drop just the backticks.
            newline_at = sql.find("\n")
            sql = sql[newline_at + 1:] if newline_at != -1 else sql[3:]
            sql = sql.rstrip()
            if sql.endswith("```"):
                sql = sql[:-3]
            sql = sql.strip()
        return sql

    def _describe_result(self, query: str, sql: str, df) -> str:
        """Summarize the generated SQL and its result for the caller."""
        header = f"I have generated a SQL query based on `{query}`.\nThe SQL query is {sql}.\n"
        row_count = len(df)
        if row_count == 0:
            return header + "The result is empty."
        preview_rows = min(self._PREVIEW_ROWS, row_count)
        return (
            header
            + f"There are {row_count} rows in the result.\n"
            + f"The first {preview_rows} rows are:\n{df.head(preview_rows).to_markdown()}"
        )
|
|
|