File size: 8,497 Bytes
417d8a6
 
 
 
559d037
a36f392
559d037
 
508be90
c443e24
417d8a6
e89e6f7
417d8a6
e89e6f7
 
 
 
 
a36f392
e89e6f7
 
 
 
 
 
417d8a6
 
559d037
508be90
0779acb
508be90
559d037
508be90
559d037
 
c443e24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a36f392
417d8a6
 
 
 
05d8272
0779acb
508be90
 
 
9dd7961
508be90
a36f392
508be90
559d037
0779acb
 
508be90
0779acb
508be90
 
 
 
9dd7961
508be90
9dd7961
508be90
 
0779acb
508be90
 
 
 
 
 
 
 
 
0779acb
508be90
 
 
 
9dd7961
e89e6f7
508be90
0779acb
508be90
 
 
 
e89e6f7
0779acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c443e24
 
 
 
 
 
a36f392
 
 
 
 
 
 
 
 
 
 
 
 
b5e2763
0779acb
 
 
 
 
417d8a6
b5e2763
417d8a6
0779acb
 
 
 
 
 
 
 
e89e6f7
 
b5e2763
 
e89e6f7
 
 
 
 
 
 
 
 
 
 
b5e2763
 
e89e6f7
b5e2763
e89e6f7
 
b5e2763
e89e6f7
 
 
b5e2763
e89e6f7
 
 
a36f392
c443e24
417d8a6
a36f392
0779acb
a36f392
417d8a6
e89e6f7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline
import numpy as np

# Load the NLP models once and reuse them. Streamlit re-executes this script
# on every widget interaction, so building the pipelines directly at module
# level would reload both BART models on each rerun.
# NOTE(review): st.cache_resource needs Streamlit >= 1.18 - confirm the
# deployed version. The cached pipeline objects are shared by all sessions.
@st.cache_resource(show_spinner=False)  # no spinner: keeps set_page_config() in main() the first UI command
def _load_pipelines():
    """Build and return the (zero-shot classifier, summarizer) pipelines."""
    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
    summarization = pipeline("summarization", model="facebook/bart-large-cnn")  # Model for summarization
    return classifier, summarization


nlp, summarizer = _load_pipelines()

# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load an uploaded CSV or Excel (.xlsx) file into a DataFrame.

    The reader is chosen from the upload's MIME type first. Browsers are not
    consistent about the MIME type they report for spreadsheet files, so when
    the MIME type is not one of the two recognised values the file-name
    extension is used as a fallback.

    Args:
        uploaded_file: File-like object exposing ``type`` (MIME string) and
            ``name`` attributes, e.g. the value returned by
            ``st.file_uploader``.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` after reporting the
        problem with ``st.error`` when the type is unsupported or parsing
        fails.
    """
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    mime_type = getattr(uploaded_file, "type", None)
    file_name = str(getattr(uploaded_file, "name", "") or "").strip().lower()

    # Exact MIME matches keep priority so previously accepted uploads are
    # parsed exactly as before; the extension only rescues unknown MIME types.
    if mime_type == "text/csv":
        reader = pd.read_csv
    elif mime_type == xlsx_mime:
        reader = pd.read_excel
    elif file_name.endswith(".csv"):
        reader = pd.read_csv
    elif file_name.endswith(".xlsx"):
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    try:
        return reader(uploaded_file)
    except Exception as e:  # UI boundary: pandas can raise many error types
        st.error(f"Error loading file: {e}")
        return None

# Function to classify the user query
def classify_query(query, candidate_labels):
    """Classify the user query into one of ``candidate_labels``.

    Uses the module-level zero-shot pipeline ``nlp`` and returns the first
    entry of the ``labels`` list in its result (the pipeline orders labels by
    descending score).

    Args:
        query: The user's free-text query.
        candidate_labels: Sequence of label strings to choose from.

    Returns:
        The top label, or ``None`` when the query is blank, no labels are
        supplied, or the pipeline result carries no labels.
    """
    # The pipeline rejects empty input; short-circuit so callers get the
    # documented ``None`` fallback instead of an exception.
    if not candidate_labels or not str(query or "").strip():
        return None

    results = nlp(query, candidate_labels)
    labels = results.get("labels") if isinstance(results, dict) else None
    return labels[0] if labels else None

# Function to summarize document content
def summarize_document(data):
    """Generate a short natural-language summary of a loaded dataset.

    The DataFrame is rendered as plain text (without the index) and passed to
    the module-level ``summarizer`` pipeline.

    Args:
        data: The loaded ``pandas.DataFrame``, or ``None``.

    Returns:
        The summary text; ``"No data loaded."`` when ``data`` is ``None`` or
        empty; or a message starting with ``"Error summarizing document:"``
        if the summarizer raises.
    """
    # An empty frame has nothing to summarize; avoid feeding the model the
    # "Empty DataFrame" placeholder text.
    if data is None or data.empty:
        return "No data loaded."

    # Convert the dataframe to a string representation for summarization
    document_text = data.to_string(index=False)

    try:
        # truncation=True asks the tokenizer to cut the input to the model's
        # maximum length, so large tables are summarized from their leading
        # rows instead of overflowing the model's input window.
        summary = summarizer(
            document_text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,
        )
        return summary[0]['summary_text']
    except Exception as e:  # report model/tokenizer failures as text for the UI
        return f"Error summarizing document: {e}"

# Function to generate a graph based on user query
def generate_graph(data, query):
    """Render a chart or a text answer for a free-text query about ``data``.

    A literal "what is this document" request (case, surrounding whitespace
    and trailing punctuation ignored) is answered with ``summarize_document``
    before any classification happens. Every other query is classified with
    ``classify_query`` and dispatched on the predicted label:

    * "bar chart": sum of a chosen numeric column per chosen categorical value.
    * "line chart": sum of a chosen numeric column per chosen datetime value.
    * "scatter plot": two chosen numeric columns against each other.
    * "histogram": 20-bin histogram with KDE of a chosen numeric column.
    * "sales question": the department with the largest total sales.
    * "general question": for queries mentioning "sales", the country of the
      row holding the largest sales value.

    Columns are picked by the user through ``st.selectbox`` widgets. A chart
    branch is only taken when the dataset has the column types it needs;
    otherwise an "unsupported" message is shown.

    Args:
        data: ``pandas.DataFrame`` to analyse. It is not modified.
        query: The user's query string.

    Returns:
        None. Output is written to the Streamlit page; any exception is
        reported with ``st.error`` instead of propagating.
    """
    fig = None  # created only by chart branches; always closed in ``finally``
    try:
        # Handle the summary request first so the classifier's predicted
        # label cannot shadow it, and accept the "?" form shown in the UI hint.
        if query.strip().lower().rstrip("?.! ") == "what is this document":
            response = "Analyzing the document content."
            document_summary = summarize_document(data)
            response += f" Here is a summary of the document: {document_summary}"
            st.write(response)
            return

        # Infer column types
        numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
        categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
        datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()

        # Define possible graph types
        candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question", "general question"]
        query_type = classify_query(query, candidate_labels)

        # Provide text-based query response
        response = ""
        if query_type == "bar chart" and categorical_columns and numerical_columns:
            response = f"Generating a bar chart for {query}"
            x_col = st.selectbox("Select the categorical column:", categorical_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color='skyblue')
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
            st.pyplot(fig)

        elif query_type == "line chart" and datetime_columns and numerical_columns:
            response = f"Generating a line chart for {query}"
            x_col = st.selectbox("Select the datetime column:", datetime_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            # Work on a copy so the caller's DataFrame is left untouched.
            frame = data[[x_col, y_col]].copy()
            frame[x_col] = pd.to_datetime(frame[x_col])
            trend_data = frame.groupby(x_col)[y_col].sum().reset_index()
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
            ax.set_title(f"Line Chart: {y_col} Over Time")
            st.pyplot(fig)

        elif query_type == "scatter plot" and len(numerical_columns) >= 2:
            response = f"Generating a scatter plot for {query}"
            x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
            # Default the y axis to the second column so the initial plot is
            # not a column plotted against itself.
            y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
            ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
            st.pyplot(fig)

        elif query_type == "histogram" and numerical_columns:
            response = f"Generating a histogram for {query}"
            hist_col = st.selectbox("Select the numerical column:", numerical_columns)
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
            ax.set_title(f"Histogram of {hist_col}")
            st.pyplot(fig)

        elif query_type == "sales question":
            # General sales-related question (e.g., "Which department has the most sales?")
            response = "Analyzing the sales data for your query."
            department_column = infer_column(data, ["department", "dept"])
            sales_column = infer_column(data, ["sales", "revenue"])

            if department_column and sales_column:
                # Aggregate once and read both the winner and its total from it.
                totals = data.groupby(department_column)[sales_column].sum()
                top_department = totals.idxmax()
                top_sales = totals.max()
                response += f" The department with the most sales is {top_department} with total sales of {top_sales:.2f}."
            else:
                response += " Could not find relevant 'department' or 'sales' columns in the dataset."

        elif query_type == "general question":
            # Handle general questions
            response = "Analyzing the data for your general question."
            if "sales" in query.lower():
                response += " Checking for the highest sales..."
                sales_column = infer_column(data, ["sales", "revenue"])
                # Look the country column up by name instead of assuming a
                # column spelled exactly 'country' exists.
                country_column = infer_column(data, ["country"])
                if sales_column and country_column:
                    top_country = data.loc[data[sales_column].idxmax(), country_column]
                    response += f" The country with the highest sales is {top_country}."
                elif sales_column:
                    response += " Could not find a 'country' column."
                else:
                    response += " Could not find a 'sales' column."

        else:
            response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."

        # Show text-based response
        st.write(response)

    except Exception as e:  # UI boundary: surface any failure on the page
        st.error(f"Error generating graph: {e}")
    finally:
        # Release the figure; pyplot otherwise keeps every figure alive and
        # the app creates a new one on each rerun.
        if fig is not None:
            plt.close(fig)

# Helper function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Return the first column of ``data`` whose name matches a synonym.

    Matching ignores case and surrounding whitespace on both sides. Column
    labels that are not strings (e.g. integer headers) are compared through
    their ``str()`` form instead of raising.

    Args:
        data: Object exposing a ``columns`` iterable, e.g. a DataFrame.
        synonyms: Accepted column names; a single string is treated as a
            one-element list.

    Returns:
        The matching column label exactly as it appears in ``data.columns``,
        or ``None`` when no column matches.
    """
    if isinstance(synonyms, str):
        synonyms = [synonyms]
    wanted = {str(name).strip().lower() for name in synonyms}

    for column in data.columns:
        if str(column).strip().lower() in wanted:
            return column
    return None

# Streamlit App Interface
def main():
    """Run the Streamlit page: upload a file, preview it, answer one query.

    Configures the page and background, shows the title and author line, then
    walks the user through uploading a CSV/Excel file and typing a query.
    Each step that has nothing to work with yet simply ends the run.
    """
    st.set_page_config(layout="wide", page_icon="📊", page_title="Data Visualization App")

    # Full-page background image injected as raw CSS
    background_css = """
        <style>
        .stApp {
            background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
            background-size: cover;
        }
        </style>
        """
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # Step 1: wait for an upload
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return

    # Step 2: parse it; load_file reports its own errors and returns None
    data = load_file(uploaded_file)
    if data is None:
        return

    st.write("Dataset preview:", data.head())

    # Step 3: wait for a query, then chart or answer it
    user_query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and sales', or 'What is this document?')")
    if not user_query:
        return

    generate_graph(data, user_query)


if __name__ == "__main__":
    main()