File size: 2,542 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from operator import itemgetter
import pandas as pd
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnableMap
from langchain.utilities import SQLDatabase
from taskweaver.plugin import Plugin, register_plugin
@register_plugin
class SqlPullData(Plugin):
    """Plugin that turns a natural-language question into SQL and runs it.

    Configuration keys read from ``self.config``:
        api_type: ``"azure"`` (default) or ``"openai"``; selects the chat model class.
        api_base, api_version: passed to ``AzureChatOpenAI`` (azure only).
        api_key: API key passed to either chat model.
        deployment_name: Azure deployment name, or the OpenAI model name.
        sqlite_db_path: database URI handed to ``SQLDatabase.from_uri``.
    """

    def __call__(self, query: str):
        """Generate a SQL query for ``query``, execute it, and describe the result.

        Args:
            query: The user's question in natural language.

        Returns:
            A ``(df, description)`` tuple: ``df`` is a ``pandas.DataFrame`` built
            from the fetched rows, and ``description`` is a text summary holding
            the question, the generated SQL, and a preview of up to five rows.

        Raises:
            ValueError: If the configured ``api_type`` is neither ``"azure"``
                nor ``"openai"``.
        """
        model = self._build_model()

        template = (
            "Based on the table schema below, write a SQL query that would answer the user's question:\n"
            "{schema}\n"
            "Question: {question}\n"
            "SQL Query:"
        )
        prompt = ChatPromptTemplate.from_template(template)
        db = SQLDatabase.from_uri(self.config.get("sqlite_db_path"))

        def get_schema(_):
            # The chain passes its input dict here; the schema does not depend on it.
            return db.get_table_info()

        inputs = {
            "schema": RunnableLambda(get_schema),
            "question": itemgetter("question"),
        }
        sql_response = RunnableMap(inputs) | prompt | model.bind(stop=["\nSQLResult:"]) | StrOutputParser()

        # Chat models frequently wrap code in a Markdown fence; executing the
        # fence together with the query would fail as invalid SQL.
        sql = self._strip_code_fence(sql_response.invoke({"question": query}))

        # NOTE(security): the SQL below is model-generated and executed verbatim.
        # Only point `sqlite_db_path` at a database where arbitrary statements
        # are acceptable (ideally a read-only copy).
        # NOTE(review): `_execute` is a private SQLDatabase method and may change
        # between LangChain releases.
        result = db._execute(sql, fetch="all")
        df = pd.DataFrame(result)

        summary = f"I have generated a SQL query based on `{query}`.\nThe SQL query is {sql}.\n"
        if len(df) == 0:
            return df, summary + "The result is empty."

        # head(5) already returns fewer rows when the frame is shorter.
        preview = df.head(5)
        return df, (
            summary
            + f"There are {len(df)} rows in the result.\n"
            + f"The first {len(preview)} rows are:\n{preview.to_markdown()}"
        )

    def _build_model(self):
        """Create the chat model selected by the ``api_type`` config entry."""
        api_type = self.config.get("api_type", "azure")
        if api_type == "azure":
            return AzureChatOpenAI(
                azure_endpoint=self.config.get("api_base"),
                openai_api_key=self.config.get("api_key"),
                openai_api_version=self.config.get("api_version"),
                azure_deployment=self.config.get("deployment_name"),
                temperature=0,
                verbose=True,
            )
        if api_type == "openai":
            return ChatOpenAI(
                openai_api_key=self.config.get("api_key"),
                model_name=self.config.get("deployment_name"),
                temperature=0,
                verbose=True,
            )
        raise ValueError("Invalid API type. Please check your config file.")

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without surrounding whitespace or an enclosing Markdown code fence."""
        cleaned = text.strip()
        if not cleaned.startswith("```"):
            return cleaned

        cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]

        # The rest of the opening fence line is either empty or a bare language
        # tag such as "sql"; in both cases it is not part of the query.
        first_line, newline, remainder = cleaned.partition("\n")
        tag = first_line.strip()
        if newline and (not tag or tag.isalnum()):
            cleaned = remainder
        return cleaned.strip()
|