# Page residue from the hosting site (not part of the program), kept as comments:
# jaichhabria's picture
# Rename newapp.py to app.py
# e771c26
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from langchain.agents import create_pandas_dataframe_agent
from langchain.llms import OpenAI
from dotenv import load_dotenv
# Run python-dotenv's loader at import time, before any environment lookups;
# presumably this lets OPENAI_API_KEY (read via os.getenv in query_agent) be
# supplied from a local .env file.
load_dotenv()
def query_agent(df, query):
    """Ask a LangChain pandas-DataFrame agent a question about ``df``.

    Args:
        df: DataFrame handed to ``create_pandas_dataframe_agent``.
        query: Question text forwarded to the agent.

    Returns:
        Whatever ``agent.run(query)`` produces for the question.
    """
    # The OpenAI key is taken from the environment; os.getenv yields None
    # when OPENAI_API_KEY is not set.
    api_key = os.getenv("OPENAI_API_KEY")

    # NOTE(review): per the original author's comment, the agent works through
    # a Python REPL tool that executes generated Python code -- confirm this is
    # acceptable for the data and queries the app is exposed to.
    dataframe_agent = create_pandas_dataframe_agent(
        OpenAI(openai_api_key=api_key),
        df,
        verbose=True,
    )
    return dataframe_agent.run(query)
def plot_bar_chart(df, x_column, y_column):
    """Render a bar chart of ``df[y_column]`` against ``df[x_column]`` in Streamlit.

    Args:
        df: DataFrame holding the data to plot.
        x_column: Name of the column used for the bar positions (x-axis).
        y_column: Name of the column used for the bar heights (y-axis).
    """
    fig, ax = plt.subplots()
    try:
        ax.bar(df[x_column], df[y_column])
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(f"Bar Chart: {x_column} vs {y_column}")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed.
        # Streamlit reruns the script on each interaction, so without this the
        # figures (and their memory) pile up for the life of the process.
        plt.close(fig)
def plot_line_chart(df, x_column, y_column):
    """Render a line chart of ``df[y_column]`` against ``df[x_column]`` in Streamlit.

    Args:
        df: DataFrame holding the data to plot.
        x_column: Name of the column plotted along the x-axis.
        y_column: Name of the column plotted along the y-axis.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(df[x_column], df[y_column])
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(f"Line Chart: {x_column} vs {y_column}")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed.
        # Streamlit reruns the script on each interaction, so without this the
        # figures (and their memory) pile up for the life of the process.
        plt.close(fig)
def plot_scatter_plot(df, x_column, y_column):
    """Render a scatter plot of ``df[y_column]`` against ``df[x_column]`` in Streamlit.

    Args:
        df: DataFrame holding the data to plot.
        x_column: Name of the column plotted along the x-axis.
        y_column: Name of the column plotted along the y-axis.
    """
    fig, ax = plt.subplots()
    try:
        ax.scatter(df[x_column], df[y_column])
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(f"Scatter Plot: {x_column} vs {y_column}")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed.
        # Streamlit reruns the script on each interaction, so without this the
        # figures (and their memory) pile up for the life of the process.
        plt.close(fig)
def plot_heatmap(df):
    """Render a correlation heatmap for the numeric columns of ``df`` in Streamlit.

    Non-numeric columns are left out of the correlation matrix. If ``df`` has
    no numeric (or boolean) columns, a warning is shown instead of a chart.

    Args:
        df: DataFrame whose pairwise column correlations are displayed.
    """
    # Restrict to columns a correlation can be computed for: recent pandas
    # versions raise from DataFrame.corr() when text columns are present,
    # which is the common case for an arbitrary uploaded CSV.
    numeric_df = df.select_dtypes(include=["number", "bool"])
    if numeric_df.shape[1] == 0:
        st.warning("No numeric columns available for a correlation heatmap.")
        return

    fig, ax = plt.subplots()
    try:
        # Pass the axes explicitly rather than relying on pyplot's notion of
        # the "current" axes happening to be the one created above.
        sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm", ax=ax)
        ax.set_title("Heatmap")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates registered until it is closed.
        # Streamlit reruns the script on each interaction, so without this the
        # figures (and their memory) pile up for the life of the process.
        plt.close(fig)
def main():
    """Build the Streamlit page: CSV upload, agent query, and chart options.

    Flow:
      1. Show the title, the CSV uploader, the query text area and the
         "Generate Response" button.
      2. If a CSV was uploaded, read it into a DataFrame.
      3. If the button was pressed and a DataFrame is available, pass the
         query to ``query_agent`` and display the answer. Blank queries are
         rejected with a warning rather than sent to the agent.
      4. Show the chart-type selector. With a DataFrame available, draw the
         selected chart; otherwise ask the user to upload a CSV file.
    """
    st.title("Data Analysis with Streamlit")
    st.header("Please upload the document here:")

    # Capture the file and the user's question.
    data = st.file_uploader("Upload CSV file", type="csv")
    query = st.text_area("Enter your query")
    button = st.button("Generate Response")

    # df stays None until a CSV file has been uploaded.
    df = pd.read_csv(data) if data is not None else None

    if button and df is not None:
        if query and query.strip():
            # Get the response using the query_agent function.
            answer = query_agent(df, query)
            st.write("Response: ", answer)
        else:
            # An empty prompt would only spend an agent call on nothing.
            st.warning("Please enter a query.")

    # Data Visualization: Chart Options
    st.subheader("Data Visualization: Chart Options")
    chart_types = ["Bar Chart", "Line Chart", "Scatter Plot", "Heatmap"]
    selected_chart = st.selectbox("Select Chart Type", chart_types)

    if df is None:
        st.warning("Please upload a CSV file.")
        return

    # The heatmap works on the whole DataFrame, so it must not be gated on
    # the X/Y column choice (both selectors default to the same column).
    if selected_chart == "Heatmap":
        plot_heatmap(df)
        return

    x_column = st.selectbox("Select X-axis column", df.columns)
    y_column = st.selectbox("Select Y-axis column", df.columns)
    if x_column == y_column:
        st.write("Please select different columns for X-axis and Y-axis.")
        return

    # Charts that plot one column against another.
    column_plotters = {
        "Bar Chart": plot_bar_chart,
        "Line Chart": plot_line_chart,
        "Scatter Plot": plot_scatter_plot,
    }
    plotter = column_plotters.get(selected_chart)
    if plotter is not None:
        plotter(df, x_column, y_column)
# Build the page only when this file is run as the entry script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()