File size: 3,333 Bytes
f9ed673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from langchain.agents import create_pandas_dataframe_agent
from langchain.llms import OpenAI
from dotenv import load_dotenv
load_dotenv()


def query_agent(df, query):
    """Answer a natural-language question about a DataFrame via a LangChain agent.

    The OpenAI API key is read from the ``OPENAI_API_KEY`` environment
    variable at call time.

    SECURITY NOTE: the pandas DataFrame agent works through a Python REPL
    tool, which executes Python code generated by the LLM. Treat ``query``
    as able to influence code that runs in this process.

    Args:
        df: The DataFrame the agent is allowed to inspect.
        query: The user's question, passed verbatim to the agent.

    Returns:
        Whatever ``agent.run`` produces for the query (presumably the agent's
        final text answer -- TODO confirm against the installed LangChain).
    """
    model = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))

    # verbose=True asks LangChain to report the agent's intermediate steps.
    pandas_agent = create_pandas_dataframe_agent(model, df, verbose=True)
    return pandas_agent.run(query)

def plot_bar_chart(df, x_column, y_column):
    """Render a bar chart of ``y_column`` against ``x_column`` in the app.

    Args:
        df: DataFrame containing both columns.
        x_column: Name of the column whose values label the bars.
        y_column: Name of the column whose values give the bar heights.
    """
    fig, ax = plt.subplots()
    ax.bar(df[x_column], df[y_column])
    ax.set_xlabel(x_column)
    ax.set_ylabel(y_column)
    ax.set_title(f"Bar Chart: {x_column} vs {y_column}")
    st.pyplot(fig)
    # pyplot keeps every figure created by plt.subplots() registered until it
    # is closed; close it once rendered so repeated calls don't leak figures.
    plt.close(fig)

def plot_line_chart(df, x_column, y_column):
    """Render a line chart of ``y_column`` against ``x_column`` in the app.

    Points are connected in the row order of ``df``; sort ``df`` by
    ``x_column`` beforehand if a monotone x-axis is wanted.

    Args:
        df: DataFrame containing both columns.
        x_column: Name of the column plotted on the x-axis.
        y_column: Name of the column plotted on the y-axis.
    """
    fig, ax = plt.subplots()
    ax.plot(df[x_column], df[y_column])
    ax.set_xlabel(x_column)
    ax.set_ylabel(y_column)
    ax.set_title(f"Line Chart: {x_column} vs {y_column}")
    st.pyplot(fig)
    # pyplot keeps every figure created by plt.subplots() registered until it
    # is closed; close it once rendered so repeated calls don't leak figures.
    plt.close(fig)

def plot_scatter_plot(df, x_column, y_column):
    """Render a scatter plot of ``y_column`` against ``x_column`` in the app.

    Args:
        df: DataFrame containing both columns.
        x_column: Name of the column plotted on the x-axis.
        y_column: Name of the column plotted on the y-axis.
    """
    fig, ax = plt.subplots()
    ax.scatter(df[x_column], df[y_column])
    ax.set_xlabel(x_column)
    ax.set_ylabel(y_column)
    ax.set_title(f"Scatter Plot: {x_column} vs {y_column}")
    st.pyplot(fig)
    # pyplot keeps every figure created by plt.subplots() registered until it
    # is closed; close it once rendered so repeated calls don't leak figures.
    plt.close(fig)

def plot_heatmap(df):
    """Render an annotated correlation heatmap of the numeric columns of ``df``.

    Non-numeric columns are excluded before computing correlations. If no
    numeric data remains, a warning is shown instead of a chart.

    Args:
        df: DataFrame whose numeric columns are correlated pairwise.
    """
    # pandas >= 2.0 raises on DataFrame.corr() when non-numeric columns are
    # present; selecting numeric dtypes explicitly works on old and new pandas.
    numeric = df.select_dtypes(include="number")
    if numeric.empty:
        st.warning("No numeric data available to compute a correlation heatmap.")
        return

    fig, ax = plt.subplots()
    # Pass ax explicitly rather than relying on pyplot's implicit current axes.
    sns.heatmap(numeric.corr(), annot=True, cmap="coolwarm", ax=ax)
    ax.set_title("Heatmap")
    st.pyplot(fig)
    # Close the figure once rendered so repeated calls don't leak figures.
    plt.close(fig)

def _read_uploaded_csv(data):
    """Parse an uploaded CSV file into a DataFrame.

    Args:
        data: The object returned by ``st.file_uploader``, or None when
            nothing has been uploaded.

    Returns:
        The parsed DataFrame, or None when no file was uploaded or the file
        could not be parsed (an error is shown in the app in that case).
    """
    if data is None:
        return None
    try:
        return pd.read_csv(data)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as err:
        # Surface malformed uploads in the UI instead of crashing the app.
        st.error(f"Could not read the uploaded CSV file: {err}")
        return None

def _answer_query(df, query):
    """Send ``query`` about ``df`` to the LLM agent and display the answer.

    Shows a warning instead when there is no data or the query is blank.
    """
    if df is None:
        st.warning("Please upload a CSV file before generating a response.")
    elif not query.strip():
        st.warning("Please enter a query.")
    else:
        answer = query_agent(df, query)
        st.write("Response: ", answer)

def _render_charts(df):
    """Show the chart-type picker and draw the selected chart for ``df``."""
    st.subheader("Data Visualization: Chart Options")
    chart_types = ["Bar Chart", "Line Chart", "Scatter Plot", "Heatmap"]
    selected_chart = st.selectbox("Select Chart Type", chart_types)

    if df is None:
        st.warning("Please upload a CSV file.")
        return

    # The heatmap correlates all numeric columns, so it needs no axis choice
    # and must not be blocked by the X/Y distinctness check below.
    if selected_chart == "Heatmap":
        plot_heatmap(df)
        return

    x_column = st.selectbox("Select X-axis column", df.columns)
    # Default Y to the second column (when there is one) so the first render
    # shows a chart rather than the "different columns" message.
    y_default = 1 if len(df.columns) > 1 else 0
    y_column = st.selectbox("Select Y-axis column", df.columns, index=y_default)

    if x_column == y_column:
        st.write("Please select different columns for X-axis and Y-axis.")
        return

    plotters = {
        "Bar Chart": plot_bar_chart,
        "Line Chart": plot_line_chart,
        "Scatter Plot": plot_scatter_plot,
    }
    plotters[selected_chart](df, x_column, y_column)

def main():
    """Build the app page: CSV upload, LLM question answering, and charts."""
    st.title("Data Analysis with Streamlit")
    st.header("Please upload the document here:")

    # Capture the file; None until the user uploads something.
    data = st.file_uploader("Upload CSV file", type="csv")

    query = st.text_area("Enter your query")
    button = st.button("Generate Response")

    df = _read_uploaded_csv(data)

    if button:
        _answer_query(df, query)

    _render_charts(df)

if __name__ == "__main__":
    main()