File size: 9,543 Bytes
adf2969
e70ee47
adf2969
 
 
 
 
 
47d3d15
adf2969
 
 
 
 
 
 
 
e6f9ae7
adf2969
 
 
 
47d3d15
 
 
49c30d8
e6f9ae7
47d3d15
 
 
6661c2e
adf2969
 
49c30d8
dc3cb1d
 
ed60c89
adf2969
 
 
dc3cb1d
ed60c89
adf2969
 
 
ed60c89
adf2969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed60c89
adf2969
7f85b45
 
adf2969
 
 
7f85b45
501f9c8
adf2969
10d2525
 
501f9c8
adf2969
ed60c89
 
adf2969
 
 
 
 
 
 
 
 
 
 
 
dc3cb1d
 
adf2969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import os
from langchain_openai import AzureOpenAI, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt
from sqlalchemy import create_engine
from project_prompts import sqlite_prompt
from few_shots import few_shots
import pandas as pd
import chromadb
import plotly
import plotly.express as px
from plotly.express import bar, line, scatter, area, pie

# Load environment variables (Azure/OpenAI credentials, endpoints) from a
# local .env file so the SDK clients below can authenticate.
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())  # read local .env file

# Azure OpenAI deployment name used by every chain/agent in this module.
current_model_id = "gpt-35-turbo-instruct"

def get_few_shot_db_chain(user_message):
    """Answer a natural-language question against the local ``ecomm.db`` SQLite
    database via a few-shot-prompted SQLDatabaseChain.

    Args:
        user_message: The user's natural-language question.

    Returns:
        dict with keys:
            result_df   -- pandas DataFrame from re-running the generated SQL
            sql_command -- the SQL string extracted from the chain's steps
            response    -- the chain's final natural-language answer
            input       -- the original query echoed back by the chain
            graph_data  -- plotly figure from get_graph_details, or None when
                           the result is too small to chart (< 2 rows or < 2 cols)

    Raises:
        RuntimeError: if the chain's intermediate steps do not expose the
            generated SQL command.
    """
    # Chroma keeps a process-wide shared client; clear its cache so repeated
    # calls do not collide with a stale in-memory collection.
    chromadb.api.client.SharedSystemClient.clear_system_cache()
    llm = AzureOpenAI(deployment_name=current_model_id, temperature=0.2)
    engine = create_engine("sqlite:///ecomm.db")
    db = SQLDatabase(engine=engine, sample_rows_in_table_info=3)

    embeddings = AzureOpenAIEmbeddings(model="text-embedding-3-small")

    # Flatten each few-shot example (question/SQL/result/answer) into a single
    # string so semantically similar examples can be retrieved by embedding.
    to_vectorize = [" ".join(example.values()) for example in few_shots]
    vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=few_shots)
    example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=2)

    example_prompt = PromptTemplate(
        input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
        template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
    )

    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        prefix=sqlite_prompt,
        suffix=PROMPT_SUFFIX,
        input_variables=["input", "table_info", "top_k"],
    )

    chain = SQLDatabaseChain.from_llm(
        llm, db, verbose=True, prompt=few_shot_prompt, return_intermediate_steps=True
    )
    response_llm = chain.invoke(user_message)
    print("========PRINTING response LLM==========")
    print(response_llm)

    # BUG FIX: the original left `intermediate_sql_query` unbound (NameError)
    # when the expected 'sql_cmd' key was absent; fail with a clear error instead.
    step = response_llm['intermediate_steps'][2]
    if 'sql_cmd' not in step:
        raise RuntimeError("SQLDatabaseChain did not return a 'sql_cmd' intermediate step")
    intermediate_sql_query = step['sql_cmd']
    print(f"This is the intermediate query : {intermediate_sql_query}")

    # The model sometimes echoes the "SQLQuery: " label; strip only the
    # leading occurrence (the original replace() removed all occurrences).
    intermediate_sql_query = intermediate_sql_query.removeprefix("SQLQuery: ")

    result_df = pd.read_sql_query(intermediate_sql_query, engine)
    print("Printing results")
    print(result_df)

    # Charting only makes sense with at least 2 rows and 2 columns
    # (logical `or`, not the original bitwise `|`).
    too_small = result_df.shape[0] < 2 or result_df.shape[1] < 2
    return {
        "result_df": result_df,
        "sql_command": intermediate_sql_query,
        "response": response_llm['result'],
        "input": response_llm['query'],
        "graph_data": None if too_small else get_graph_details(user_message, result_df),
    }
    

def get_graph_details(usermessage: str, df=None):
    """Produce a plotly figure visualising *df* according to the user's question.

    Strategy: first let a pandas-dataframe agent try to build the figure; if no
    step of the agent yields a ``plotly.graph_objects.Figure``, fall back to
    asking the LLM for a single ``fig = px....`` line and executing it.

    Args:
        usermessage: The user's natural-language question driving the chart choice.
        df: pandas DataFrame to visualise.

    Returns:
        A plotly Figure, or None when no usable figure could be produced.
    """
    llm = AzureOpenAI(deployment_name=current_model_id, temperature=0.15)
    template = ChatPromptTemplate.from_messages(
        [("system", "You are a visualisation expert and plotly developer, your task is to come up with best suitable \
                     chart representing user ask for the given data. please use plotly express library in python for \
                     charting purposes.. and provide code for generating the figure.. there should not be any displaying \
                     instructions..like fig.show() etc.."), 
         ("human", "For the given dataframe below \
                    ---------------------------------\
                    Dataframe = {dataframe} \
                    ---------------------------------\
                    and user question \
                    ---------------------------------\
                    user_ask =  {question} \
                    ----------------------------------\
                    Please provide the plotly chart which \
                    would be best suitable to represent  the user ask graphically \
                    Please double check the code is not having any fig.show() or display commands"
                    )]
    )
    customer_messages = template.format_messages(dataframe=df, question=usermessage)

    agent = create_pandas_dataframe_agent(
        llm,
        df,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        return_intermediate_steps=True,
    )
    agent_response = agent.invoke(customer_messages)

    # If any intermediate step already produced a plotly Figure, return it
    # directly. (The original used a `for...else`; an explicit early return
    # makes the fallback path obvious.)
    for _, step_output in agent_response['intermediate_steps']:
        if isinstance(step_output, plotly.graph_objects.Figure):
            return step_output

    # Fallback: few-shot prompt the LLM for a single `fig = px....` line.
    # BUG FIX: the original correction turn told the model to start with
    # `fig = plt.`, contradicting the `fig = px.` instruction everywhere
    # else; the example now consistently demands the plotly-express form.
    template = ChatPromptTemplate.from_messages([
        ("system", "You are a visualisation expert and plotly developer, your task is to come up with best suitable \
                 chart representing user ask for the given data. please use plotly express library in python for \
                 charting purposes.. and provide code for generating the figure.. there should not be any displaying \
                 instructions..like fig.show() etc.."), 
        ("human", "For the given dataframe below \
                ---------------------------------\
                df =   State  Total_GDP\
                    0  Florida                7743.0\
                    1  Texas                9934.0\
                    2  New_York                6634.5\
                    3  Denver                4456.0\
                    4  Atlanta                 993.5 \
                ---------------------------------\
                and user question \
                ---------------------------------\
                user_ask = What is the distribution of Total_GDP for each state? \
                ----------------------------------\
                Please provide the code using plotly express in less than 30 words which should clearly satisfy user ask\
                in terms of best representation of data. please use dataframe variable as 'df' and \
                strictly output only one line of python code start your code with initializing a figure object \n\
                like `fig = px.`"),
        ("ai", "bar(df, x='State', y='Total_GDP', title='Distribution of Total_GDP per State')"),
        ("human", "This is incorrect.. the required response should be \
                `fig = px.bar(df, x='State', y='Total_GDP', title='Distribution of Total_GDP per State')`\
                as it starts with `fig = px.` as user specified"),
        ("ai", "Sounds good, now I will remember to start with `fig = px.`"),
        ("human", "For the given dataframe below \
                ---------------------------------\
                df = {dataframe} \
                ---------------------------------\
                and user question \
                ---------------------------------\
                user_ask =  {question} \
                ----------------------------------\
                Please provide the code using plotly express in less than 40 words which should clearly satisfy user ask\
                in terms of best representation of data. please use dataframe variable as 'df' and \
                strictly output only one line of python code start your code with initializing a figure object \n\
                like `fig = px.`"),
    ])
    customer_messages = template.format_messages(dataframe=df, question=usermessage)
    print(f"This is the customer message : {customer_messages}")
    code_response_llm = llm.invoke(customer_messages)
    print(f"This is the code returned by LLM : {code_response_llm}")
    try:
        print("## Executing the code line generated by llm ##")
        if "fig = " not in code_response_llm:
            return None
        code_response_llm = code_response_llm.replace("AI: ", "")
        # SECURITY NOTE: exec() on LLM-generated code runs arbitrary Python in
        # this process. Acceptable only because the prompt tightly constrains
        # the output to a single px.* call; consider sandboxing or an AST
        # allow-list before exposing this to untrusted users.
        namespace = {'df': df}
        exec(code_response_llm, globals(), namespace)
        if 'fig' in namespace:
            print("fig is there returning fig>>>>>")
            return namespace['fig']
    except Exception as e:
        # Best-effort: a malformed code line should degrade to "no chart",
        # not crash the whole answer pipeline.
        print(f"Some exception occurred : {str(e)}")
        return None

    return None