|
|
import logging
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from dotenv import load_dotenv
from langchain.agents import create_pandas_dataframe_agent
from langchain.llms import OpenAI
|
|
# Populate os.environ from a local .env file (if one exists) so that
# query_agent() can read OPENAI_API_KEY with os.getenv().
load_dotenv()
|
|
|
|
|
|
|
|
def query_agent(df, query):
    """Answer a natural-language question about ``df`` with a LangChain pandas agent.

    A fresh OpenAI LLM and pandas-DataFrame agent are built on every call, and
    the agent is run once on ``query``.

    Args:
        df: The DataFrame the agent may inspect (the uploaded CSV in ``main``).
        query: The user's question, passed verbatim to the agent.

    Returns:
        Whatever ``agent.run(query)`` returns (the agent's final answer).

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is unset or
            empty. This is checked up front so the user gets a clear message
            instead of a failure from deep inside the LLM client.
    """
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        raise ValueError(
            "OPENAI_API_KEY is not set; add it to the environment or to a .env file."
        )

    llm = OpenAI(openai_api_key=openai_api_key)

    # SECURITY NOTE(review): per the LangChain docs, the pandas DataFrame agent
    # answers questions by having the LLM write Python that is then executed
    # against ``df`` via a Python REPL tool. Treat the query (and the CSV) as
    # untrusted, and only deploy this where arbitrary code execution on the
    # server is acceptable, e.g. inside a sandbox.
    agent = create_pandas_dataframe_agent(llm, df, verbose=True)

    return agent.run(query)
|
|
|
|
|
def plot_bar_chart(df, x_column, y_column):
    """Render a bar chart of ``y_column`` against ``x_column`` in the Streamlit page.

    Args:
        df: Source DataFrame.
        x_column: Column whose values label the bars.
        y_column: Column whose values give the bar heights.
    """
    fig, ax = plt.subplots()
    try:
        ax.bar(df[x_column], df[y_column])
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(f"Bar Chart: {x_column} vs {y_column}")
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure created by plt.subplots() alive until it is
        # explicitly closed, and st.pyplot(fig) does not close it. Streamlit
        # reruns the script on each interaction, so without this the number of
        # open figures (and memory use) grows with every rerun.
        plt.close(fig)
|
|
|
|
|
def plot_line_chart(df, x_column, y_column):
    """Render a line chart of ``y_column`` against ``x_column`` in the Streamlit page.

    Points are connected in the DataFrame's row order.

    Args:
        df: Source DataFrame.
        x_column: Column plotted on the X axis.
        y_column: Column plotted on the Y axis.
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(df[x_column], df[y_column])
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(f"Line Chart: {x_column} vs {y_column}")
        st.pyplot(fig)
    finally:
        # Close the figure so pyplot's figure registry does not accumulate one
        # open figure per Streamlit rerun (st.pyplot(fig) does not close it).
        plt.close(fig)
|
|
|
|
|
def plot_scatter_plot(df, x_column, y_column):
    """Render a scatter plot of ``y_column`` against ``x_column`` in the Streamlit page.

    Args:
        df: Source DataFrame.
        x_column: Column plotted on the X axis.
        y_column: Column plotted on the Y axis.
    """
    fig, ax = plt.subplots()
    try:
        ax.scatter(df[x_column], df[y_column])
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.set_title(f"Scatter Plot: {x_column} vs {y_column}")
        st.pyplot(fig)
    finally:
        # Close the figure so pyplot's figure registry does not accumulate one
        # open figure per Streamlit rerun (st.pyplot(fig) does not close it).
        plt.close(fig)
|
|
|
|
|
def plot_heatmap(df):
    """Render an annotated correlation heatmap of the numeric columns of ``df``.

    Non-numeric columns (e.g. text) are dropped before computing correlations.
    If nothing numeric remains, a Streamlit warning is shown instead of a chart.

    Args:
        df: Source DataFrame; may mix numeric and non-numeric columns.
    """
    # In pandas >= 2.0, DataFrame.corr() defaults to numeric_only=False and
    # raises on text columns, which uploaded CSVs routinely contain.
    # select_dtypes behaves the same on old and new pandas versions.
    numeric = df.select_dtypes(include=["number", "bool"])
    if numeric.columns.empty:
        st.warning("Heatmap requires at least one numeric column.")
        return

    fig, ax = plt.subplots()
    try:
        # Pass ax explicitly so seaborn draws on this figure rather than on
        # whatever pyplot currently considers the "current" axes.
        sns.heatmap(numeric.corr(), annot=True, cmap="coolwarm", ax=ax)
        ax.set_title("Heatmap")
        st.pyplot(fig)
    finally:
        # Close the figure so one open figure is not leaked per Streamlit rerun.
        plt.close(fig)
|
|
|
|
|
def main():
    """Build the Streamlit page: CSV upload, LLM question answering, and charts.

    Streamlit re-executes the script (and therefore this function) top to
    bottom on widget interactions, so everything below is recomputed on each
    rerun.
    """
    st.title("Data Analysis with Streamlit")
    st.header("Please upload the document here:")

    data = st.file_uploader("Upload CSV file", type="csv")
    query = st.text_area("Enter your query")
    button = st.button("Generate Response")

    df = _load_dataframe(data) if data is not None else None

    if button and df is not None:
        _show_query_response(df, query)

    st.subheader("Data Visualization: Chart Options")
    chart_types = ["Bar Chart", "Line Chart", "Scatter Plot", "Heatmap"]
    selected_chart = st.selectbox("Select Chart Type", chart_types)

    if df is not None:
        _show_chart(df, selected_chart)
    elif data is None:
        # If a file was uploaded but failed to parse, _load_dataframe has
        # already shown the specific error; don't also ask for an upload.
        st.warning("Please upload a CSV file.")


def _load_dataframe(uploaded_file):
    """Parse an uploaded CSV into a DataFrame, or show an error and return None."""
    try:
        return pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the uploaded CSV file: {exc}")
        return None


def _show_query_response(df, query):
    """Run the LLM agent on ``query`` and write its answer, or the failure, to the page."""
    if not (query or "").strip():
        st.warning("Please enter a query.")
        return
    try:
        with st.spinner("Generating response..."):
            answer = query_agent(df, query)
    except Exception as exc:  # UI boundary: surface any LLM/agent failure to the user
        logging.getLogger(__name__).exception("query_agent failed")
        st.error(f"Could not generate a response: {exc}")
        return
    st.write("Response: ", answer)


def _show_chart(df, chart_type):
    """Render ``chart_type`` for ``df``, asking for X/Y columns only where the chart uses them."""
    if chart_type == "Heatmap":
        # The heatmap is computed from the whole DataFrame and ignores any X/Y
        # choice, so it must not be blocked by the "different columns" check.
        plot_heatmap(df)
        return

    x_column = st.selectbox("Select X-axis column", df.columns)
    y_column = st.selectbox("Select Y-axis column", df.columns)
    if x_column == y_column:
        st.write("Please select different columns for X-axis and Y-axis.")
        return

    xy_plotters = {
        "Bar Chart": plot_bar_chart,
        "Line Chart": plot_line_chart,
        "Scatter Plot": plot_scatter_plot,
    }
    plotter = xy_plotters.get(chart_type)
    if plotter is not None:
        plotter(df, x_column, y_column)
|
|
|
|
|
# Script entry point. Streamlit executes the script with __name__ set to
# "__main__", so `streamlit run` on this file also reaches main().
if __name__ == "__main__":
    main()