Spaces:
Sleeping
Sleeping
"""Ask-IDF: a Streamlit chat front-end over an FX-trades SQLite database."""
import streamlit as st
import numpy as np
import pandas as pd
import altair as alt
import chat as idf_chat
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain import OpenAI
from langchain import PromptTemplate, OpenAI, LLMChain
from langchain.chains import SimpleSequentialChain
from langchain_experimental.sql import SQLDatabaseChain
from langchain_experimental.agents.agent_toolkits import create_python_agent
from langchain_experimental.tools import PythonREPLTool
from langchain_experimental.utilities import PythonREPL

# Session-state key under which the most recent JSON-formatted result is kept
# so that plot_chart() can pick it up on a later interaction.
JSON_DATA_LABEL = 'json_data'

# Deterministic completions (temperature=0) over the local trades database.
llm = OpenAI(temperature=0)
db = SQLDatabase.from_uri("sqlite:///FXTrades.db")

# Make sure the JSON slot exists before any handler reads it.
if JSON_DATA_LABEL not in st.session_state:
    st.session_state[JSON_DATA_LABEL] = {}
def get_db_chain():
    """Build the SQL question-answering chain.

    Returns an SQLDatabaseChain wired to the module-level ``llm`` and ``db``
    with a custom prompt that names the assistant "IDF" and adds
    domain-specific hints (Client joins, volume/Trades semantics, the
    ccyBought/ccySold column mapping).
    """
    # NOTE: fixed the misspelling "recieve" -> "receive" in the prompt; the
    # text is sent to the LLM verbatim, so spelling matters.
    template = """Your name is IDF. If you receive a "Hey IDF" salute reply by saying, "Hey BBHer!". Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
If someone asks any question involving client name, you need to join with Client table
volume: you need to count records
Trades: you need to get volume of trades
Currency Bought: you need to use ccyBought
Currency Sold: you need to use ccySold
Question: {input}"""
    prompt = PromptTemplate(
        input_variables=["input", "table_info", "dialect"],
        template=template,
    )
    return SQLDatabaseChain.from_llm(llm, db, prompt=prompt, verbose=True)
def get_json_chain():
    """Return an LLMChain that reformats a previous chain's output as JSON."""
    reformat_prompt = PromptTemplate.from_template(
        "Reformat this {result} in JSON format"
    )
    return LLMChain(llm=llm, prompt=reformat_prompt)
def plot_chart():
    """Plot the most recent JSON query result using a Python-REPL agent.

    Reads the payload cached in ``st.session_state`` by ``from_gpt``. If
    nothing has been cached yet (the slot still holds its initial ``{}``),
    returns a short message instead of invoking the agent — this also avoids
    the TypeError the original raised when concatenating a str with a dict.
    """
    json_data = st.session_state[JSON_DATA_LABEL]
    if not json_data:
        return "no data to plot"
    agent_executor = create_python_agent(
        llm=llm,
        tool=PythonREPLTool(),
        verbose=True,
    )
    # str() guards against a non-string payload ending up in session state.
    question = "Plot these results: " + str(json_data)
    # TODO: add an OpenAI call to format the outcome as a table.
    return agent_executor.run(question)
# Build both chains once at module import; they are reused for every request.
db_chain = get_db_chain()
json_chain = get_json_chain()
def from_gpt(query: str, plot: bool):
    """Answer a natural-language question via the SQL chain.

    When ``plot`` is True, the SQL answer is additionally reformatted as JSON
    and cached in session state for ``plot_chart``; otherwise the cache slot
    is cleared. On any failure, a friendly fallback message is returned.
    """
    try:
        chains = [db_chain, json_chain] if plot else [db_chain]
        main_chain = SimpleSequentialChain(chains=chains, verbose=True)
        ans = main_chain.run(query)
        # Cache the JSON payload only when the caller intends to plot it.
        st.session_state[JSON_DATA_LABEL] = ans if plot else {}
        return ans
    except Exception as e:
        # Keep the user-facing message friendly, but don't swallow the real
        # error silently — surface it in the server log for debugging.
        print(f"from_gpt failed for query {query!r}: {e}")
        return "Please ask a proper business question related to selected datasets"
def circle_chart():
    """Return an Altair scatter chart of 200 random (a, b, c) points."""
    data = pd.DataFrame(np.random.randn(200, 3), columns=['a', 'b', 'c'])
    chart = alt.Chart(data).mark_circle()
    return chart.encode(
        x='a', y='b', size='c', color='c', tooltip=['a', 'b', 'c']
    )
| # Parse the prompt to pick an example and a render function | |
# Parse the prompt to pick an example or GPT handler and a render function.
def get_response(prompt: str, *kargs):
    """Dispatch a chat prompt to the matching handler.

    Returns a ``(response, on_render)`` pair where ``on_render`` is the
    Streamlit function used to display the response. Extra positional args
    (e.g. the llm that chat.process passes through) are accepted and ignored.
    """
    prompt_lower = prompt.lower()
    if prompt_lower == 'line chart':
        return np.random.randn(30, 30), st.line_chart
    if prompt_lower == 'circle chart':
        return circle_chart(), st.write
    if prompt_lower.startswith('json'):
        # Strip the 'json' keyword; the original split('json ')[1] raised
        # IndexError when the prompt was exactly 'json' (no trailing space).
        p = prompt_lower[len('json'):].strip()
        return from_gpt(p, plot=True), st.write
    if prompt_lower == 'plot':
        return plot_chart(), st.write
    return from_gpt(prompt, plot=False), st.write
chat = idf_chat.Chat()

# Sidebar copy: example prompts a user can paste into the chat box.
sidebar_text = """
# Ask IDF
## Example prompts
Replace `{client}` with a client name.
Hint: you can ask IDF to tell you what clients are available
```
What is the total USD Amount for client {client} for Jan 2022
```
```
What is the average USD Amount per month for client {client} for 2022
```
```
Trades volume for {client} in May 2022
```
```
Get the total USD amount where Currency Bought is {Currency (e.g. GBP)} for {client} in Jun 2022
```
```
Get the Currency Bought/Sold with least total USD amount for {client} in Mar 2022
```
"""

with st.sidebar:
    st.markdown(sidebar_text)

# Hand control to the chat loop, which calls get_response for each prompt.
chat.process(get_response, llm)