| from pathlib import Path |
|
|
| import streamlit as st |
|
|
| from langchain import SQLDatabase |
| from langchain.agents import AgentType |
| from langchain.agents import initialize_agent, Tool |
| from langchain.callbacks import StreamlitCallbackHandler |
| from langchain.chains import LLMMathChain |
| from langchain.llms import OpenAI |
| from langchain.utilities import DuckDuckGoSearchAPIWrapper |
| from langchain.llms import OpenAI |
| from langchain.sql_database import SQLDatabase |
| from langchain_experimental.sql import SQLDatabaseChain |
| from langchain.chat_models import ChatOpenAI |
|
|
|
|
| from streamlit_agent.callbacks.capturing_callback_handler import playback_callbacks |
| from streamlit_agent.clear_results import with_clear_container |
|
|
| from chat2plot import chat2plot |
| from chat2plot.chat2plot import Plot |
|
|
| import pandas as pd |
| import sqlite3 |
| import os |
|
|
# OpenAI key read from the environment; None when OPENAI_API_KEY is unset.
user_openai_api_key = os.getenv("OPENAI_API_KEY")

# SQLite database file that lives next to this script, as an absolute path.
DB_PATH = Path(__file__).parent.joinpath("sitios2.sqlite").absolute()

# Sample questions offered in the form; every entry maps to the same
# saved-session file name.
SAVED_SESSIONS = dict.fromkeys(
    (
        "how many points are in field_id = 29?",
        "what is the proportion of points in point_type?",
    ),
    "alanis.pickle",
)
|
|
# Page-level configuration; this is the first Streamlit call in the script.
st.set_page_config(
    layout="wide",
    page_title="MRKL",
    initial_sidebar_state="collapsed",
    page_icon="🦜",
)

# Page heading, rendered through Streamlit "magic" (a bare string literal).
"# Points and Samples"
|
|
| |
| |
| |
| |
|
|
# Free-text questions are only enabled when a non-empty key was found in the
# environment; otherwise a placeholder string is handed to the OpenAI clients.
enable_custom = bool(user_openai_api_key)
openai_api_key = user_openai_api_key if enable_custom else "not_supplied"
|
|
|
|
# Load the points_and_samples table; the DataFrame is shown on the page and
# handed to chat2plot below.
# try/finally ensures the connection is closed even if the query raises
# (sqlite3's connection context manager only commits/rolls back, it does not close).
conn = sqlite3.connect(DB_PATH)
try:
    df = pd.read_sql_query(
        "SELECT ogc_fid, field_id, point_id, sample_id, label, sample_state, "
        "plot_type, depth_range_shallow_m, depth_range_deep_m, sample_timestamp "
        "FROM points_and_samples",
        conn,
    )
finally:
    conn.close()
|
|
|
|
| |
# Completion LLM shared by the agent and its sub-chains.
llm = OpenAI(temperature=0, openai_api_key=openai_api_key, streaming=True)
search = DuckDuckGoSearchAPIWrapper()
llm_math_chain = LLMMathChain.from_llm(llm)

# SQL chain over the same SQLite file the DataFrame was loaded from.
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH}")
db_chain = SQLDatabaseChain.from_llm(llm, db)

# chat2plot session over the in-memory DataFrame. Its chat model receives the
# same key as `llm`. Otherwise it would look for OPENAI_API_KEY on its own and
# fail at construction when the variable is unset, bypassing the
# "not_supplied" fallback used for `llm`.
c2p = chat2plot(
    df,
    chat=ChatOpenAI(
        temperature=0,
        model_name="gpt-3.5-turbo-0125",
        openai_api_key=openai_api_key,
    ),
)

tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events. You should ask targeted questions",
    ),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
    ),
    Tool(
        name="sitios piloto DB",
        func=db_chain.run,
        description="useful for when you need to answer questions about sitios piloto. Input should be in the form of a question containing full context",
    ),
    Tool(
        name="chat2plot",
        # The chat2plot session object is called directly with the tool input.
        func=c2p,
        description="useful for when you need to create a plot from a table",
        # Return the tool's output (a chat2plot Plot) straight to the caller
        # instead of feeding it back to the agent.
        return_direct=True,
    ),
]
|
|
| |
# MRKL agent: AgentType.ZERO_SHOT_REACT_DESCRIPTION over the tools built above.
mrkl = initialize_agent(
    tools=tools,
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

# Show the loaded table above the question form.
st.write(df)
|
|
|
|
# Question form: a sample-question picker, plus a free-text box when an API
# key is available.
with st.form(key="form"):
    if not enable_custom:
        # Only the environment variable is read for the key (no sidebar input
        # exists in this app), so point the user there.
        "Ask one of the sample questions, or set the OPENAI_API_KEY environment variable to ask your own custom questions."
    prefilled = st.selectbox("Sample questions", sorted(SAVED_SESSIONS.keys())) or ""

    # Bind user_input on every path. Without a key the text box is never
    # rendered, and the check below would otherwise raise NameError.
    user_input = ""
    if enable_custom:
        user_input = st.text_input("Or, ask your own question")

    # Fall back to the selected sample question when nothing was typed.
    if not user_input:
        user_input = prefilled
    submit_clicked = st.form_submit_button("Submit Question")
|
|
# Placeholder slot for the question/answer output. It is filled only when
# with_clear_container(submit_clicked) is truthy (project helper from
# streamlit_agent.clear_results).
output_container = st.empty()
if with_clear_container(submit_clicked):
    # st.empty holds a single element. Replace it once with a container and
    # draw everything inside that, rather than repeatedly overwriting the slot.
    output_container = output_container.container()
    output_container.text("user")
    output_container.write(user_input)

    output_container.text("assistant")
    # The callback handler renders the agent's intermediate steps into its
    # parent, so give it a real container rather than a text element.
    answer_container = output_container.container()
    st_callback = StreamlitCallbackHandler(answer_container)

    answer = mrkl.run(user_input, callbacks=[st_callback])
    if isinstance(answer, Plot):
        # The chat2plot tool is return_direct, so its Plot comes back as-is.
        # NOTE(review): Plot.figure appears to be None when chat2plot could not
        # build a chart; confirm against the installed chat2plot version.
        if answer.figure is not None:
            answer_container.plotly_chart(answer.figure)
        else:
            answer_container.write(answer.explanation)
    else:
        answer_container.write(answer)
|
|