# visualquery / langchain_helper.py
# binaychandra's picture
# modified langchain_helper.py
# 6661c2e
import os
from langchain_openai import AzureOpenAI, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from sqlalchemy import create_engine
from project_prompts import sqlite_prompt
from few_shots import few_shots
import pandas as pd
import chromadb
import plotly
import plotly.express as px
from plotly.express import bar, line, scatter, area, pie
# from dotenv import load_dotenv
# load_dotenv()
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file
# Deployment name passed as `deployment_name` to the AzureOpenAI clients
# created in get_few_shot_db_chain() and get_graph_details().
current_model_id = "gpt-35-turbo-instruct"
def _extract_sql_command(intermediate_steps):
    """Return the SQL statement recorded by the chain, without a ``SQLQuery: `` label.

    The chain is expected to report the statement it ran as a dict with a
    ``sql_cmd`` key at index 2 of its intermediate steps.

    Raises:
        ValueError: if that entry is missing, so callers get an explicit
            message instead of a NameError/IndexError further down.
    """
    label = "SQLQuery: "
    sql_step = intermediate_steps[2] if len(intermediate_steps) > 2 else None
    if not isinstance(sql_step, dict) or 'sql_cmd' not in sql_step:
        raise ValueError(
            "SQLDatabaseChain did not report an executed SQL command "
            "(no 'sql_cmd' entry in intermediate step 2)"
        )

    sql_command = sql_step['sql_cmd']
    print(f"This is the intermediate query : {sql_command}")
    # Drop only the leading label; a blanket str.replace would also rewrite
    # any identical text that happens to appear inside the statement itself.
    if sql_command.startswith(label):
        sql_command = sql_command[len(label):]
    return sql_command


def get_few_shot_db_chain(user_message):
    """Answer *user_message* against ``ecomm.db`` with a few-shot SQL chain.

    Builds an Azure OpenAI LLM, a SQLite-backed ``SQLDatabase`` and a Chroma
    vector store over ``few_shots``, runs a ``SQLDatabaseChain`` with a
    few-shot prompt, then re-runs the generated SQL through pandas to obtain
    a DataFrame of the results.

    Args:
        user_message: The natural-language question passed to the chain.

    Returns:
        dict with keys:
            ``result_df``   - DataFrame produced by running the generated SQL,
            ``sql_command`` - the SQL text that was run,
            ``response``    - the chain's ``result`` value,
            ``input``       - the chain's ``query`` value,
            ``graph_data``  - ``None`` when the DataFrame has fewer than two
                              rows or two columns, otherwise whatever
                              ``get_graph_details`` returns.

    Raises:
        ValueError: if the chain response carries no ``sql_cmd`` step.
    """
    # Reset chromadb's shared client cache before building a new store
    # (presumably to avoid clashes between repeated calls - TODO confirm).
    chromadb.api.client.SharedSystemClient.clear_system_cache()

    llm = AzureOpenAI(deployment_name=current_model_id, temperature=0.2)
    #llm = ChatOpenAI(model = current_model_id)
    print(llm)

    engine = create_engine("sqlite:///ecomm.db")
    db = SQLDatabase(engine=engine, sample_rows_in_table_info=3)

    embeddings = AzureOpenAIEmbeddings(model="text-embedding-3-small")
    print(embeddings)

    # Each example is embedded as the concatenation of all of its field values;
    # the example dict itself rides along as metadata for the selector.
    to_vectorize = [" ".join(example.values()) for example in few_shots]
    vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=few_shots)
    print(vectorstore)

    example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=2)
    example_prompt = PromptTemplate(
        input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
        template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
    )
    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        prefix=sqlite_prompt,
        suffix=PROMPT_SUFFIX,
        input_variables=["input", "table_info", "top_k"],
    )

    chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        verbose=True,
        prompt=few_shot_prompt,
        return_intermediate_steps=True,
    )
    print(chain)

    response_llm = chain.invoke(user_message)
    print("========PRINTING response LLM==========")
    print(response_llm)

    steps = response_llm['intermediate_steps']
    if len(steps) > 1:
        print(f"sql query : {steps[1]}")
    intermediate_sql_query = _extract_sql_command(steps)

    # NOTE(review): this executes SQL written by the LLM a second time, now via
    # pandas. Consider a read-only connection if user_message is untrusted.
    result_df = pd.read_sql_query(intermediate_sql_query, engine)
    print("Printing results")
    print(result_df)

    # A chart is only attempted when there are at least 2 rows and 2 columns.
    row_count, column_count = result_df.shape[0], result_df.shape[1]
    too_small_to_plot = row_count < 2 or column_count < 2

    output_dict = {
        "result_df": result_df,
        "sql_command": intermediate_sql_query,
        "response": response_llm['result'],
        "input": response_llm['query'],
        "graph_data": None if too_small_to_plot else get_graph_details(user_message, result_df),
    }
    return output_dict
def _first_plotly_figure(intermediate_steps):
    """Return the first step observation that is a plotly Figure, or None."""
    for _, observation in intermediate_steps:
        if isinstance(observation, plotly.graph_objects.Figure):
            return observation
    return None


def _run_figure_code(code, df):
    """Execute *code* with ``df`` in scope and return the ``fig`` it defines.

    Returns None when the executed code does not bind a name ``fig``.
    """
    # SECURITY NOTE(review): this runs model-generated Python with exec().
    # It is kept because the feature depends on it, but it must not be exposed
    # to untrusted prompts without sandboxing.
    namespace = {'df': df}
    exec(code, globals(), namespace)
    if 'fig' in namespace:
        print("fig is there returning fig>>>>>")
        return namespace['fig']
    return None


def get_graph_details(usermessage: str, df=None):
    """Ask the LLM for a plotly chart that represents *usermessage* over *df*.

    Two strategies are tried in order:

    1. A pandas dataframe agent is invoked; if any of its intermediate steps
       produced a plotly ``Figure``, that figure is returned.
    2. Otherwise the LLM is prompted (with a short few-shot conversation) for
       a single ``fig = px....`` statement, which is executed with ``df`` in
       scope.

    Args:
        usermessage: The user's question, inserted into both prompts.
        df: The DataFrame to chart. ``None`` yields ``None``.

    Returns:
        The figure object, or ``None`` when no figure could be produced
        (no dataframe, no ``fig = `` in the reply, or the generated code
        failed).
    """
    if df is None:
        # Nothing to chart; the agent constructor cannot work without a frame.
        return None

    llm = AzureOpenAI(deployment_name=current_model_id, temperature=0.15)
    #llm = ChatOpenAI(deployment_name=current_model_id, temperature=0.15)

    # Shared by both prompts below.
    system_prompt = (
        "You are a visualisation expert and plotly developer, your task is to come up with best suitable "
        "chart representing user ask for the given data. please use plotly express library in python for "
        "charting purposes.. and provide code for generating the figure.. there should not be any displaying "
        "instructions..like fig.show() etc.."
    )

    agent_prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", (
            "For the given dataframe below "
            "---------------------------------"
            "Dataframe = {dataframe} "
            "---------------------------------"
            "and user question "
            "---------------------------------"
            "user_ask = {question} "
            "----------------------------------"
            "Please provide the plotly chart which "
            "would be best suitable to represent the user ask graphically "
            "Please double check the code is not having any fig.show() or display commands"
        )),
    ])
    customer_messages = agent_prompt.format_messages(dataframe=df, question=usermessage)

    # NOTE(review): newer langchain_experimental releases may require
    # allow_dangerous_code=True here - confirm against the pinned version.
    agent = create_pandas_dataframe_agent(
        llm,
        df,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        return_intermediate_steps=True,
    )
    agent_response = agent.invoke(customer_messages)

    agent_figure = _first_plotly_figure(agent_response['intermediate_steps'])
    if agent_figure is not None:
        return agent_figure

    # Fallback: ask for a single line of plotly express code and run it.
    # The few-shot correction consistently demonstrates `fig = px.` (the only
    # plotting alias available when the code is executed) and reuses the
    # columns of the example dataframe, so the model is not taught to emit
    # names that do not exist.
    fallback_prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", (
            "For the given dataframe below "
            "---------------------------------\n"
            "df = State Total_GDP\n"
            "0 Florida 7743.0\n"
            "1 Texas 9934.0\n"
            "2 New_York 6634.5\n"
            "3 Denver 4456.0\n"
            "4 Atlanta 993.5 \n"
            "---------------------------------"
            "and user question "
            "---------------------------------"
            "user_ask = What is the distribution of Total_GDP for each state? "
            "----------------------------------"
            "Please provide the code using plotly express in less than 30 words which should clearly satisfy user ask "
            "in terms of best representation of data. please use dataframe variable as 'df' and "
            "strictly output only one line of python code start your code with initializing a figure object \n"
            "like `fig = px.`"
        )),
        ("ai", "bar(df, x='State', y='Total_GDP', title='Distribution of Total_GDP per State')"),
        ("human", (
            "This is incorrect.. the required response should be "
            "`fig = px.bar(df, x='State', y='Total_GDP', title='Distribution of Total_GDP per State')`"
            "as it starts with `fig = px.` as user specified"
        )),
        ("ai", "Sounds good, now I will remember to start with `fig = px.`"),
        ("human", (
            "For the given dataframe below "
            "---------------------------------"
            "df = {dataframe} "
            "---------------------------------"
            "and user question "
            "---------------------------------"
            "user_ask = {question} "
            "----------------------------------"
            "Please provide the code using plotly express in less than 40 words which should clearly satisfy user ask "
            "in terms of best representation of data. please use dataframe variable as 'df' and "
            "strictly output only one line of python code start your code with initializing a figure object \n"
            "like `fig = px.`"
        )),
    ])
    customer_messages = fallback_prompt.format_messages(dataframe=df, question=usermessage)
    print(f"This is the customer message : {customer_messages}")

    code_response_llm = llm.invoke(customer_messages)
    print(f"This is the code returned by LLM : {code_response_llm}")

    try:
        print("## Executing the code line generated by llm ##")
        # A completion model returns a plain string; a chat model would return
        # a message object whose text lives in `.content`.
        generated_code = getattr(code_response_llm, "content", code_response_llm)
        if "fig = " not in generated_code:
            return None
        # Remove the role label, then surrounding whitespace/backticks: a
        # leading space alone would make exec() raise IndentationError.
        generated_code = generated_code.replace("AI: ", "").strip().strip("`").strip()
        return _run_figure_code(generated_code, df)
    except Exception as e:
        # Model-written code can fail in arbitrary ways; treat any failure as
        # "no chart" and report it.
        print(f"Some exception occurred : {str(e)}")
        return None