# Data-Excel / app.py
# SHAMIL SHAHBAZ AWAN
# Update app.py
# c443e24 verified
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline
import numpy as np
# Zero-shot classifier used to map a free-form user query onto one of the
# request kinds the app supports.
# NOTE(review): both models are loaded at import time; Streamlit re-executes
# the script on each interaction, so consider caching these loads - confirm
# against the deployed Streamlit version before changing.
nlp = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# Summarization model used to describe the uploaded document.
summarizer = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
)
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load an uploaded CSV or Excel (.xlsx) file into a DataFrame.

    The format is chosen from the file name's extension first and from the
    reported MIME type second. The MIME type is supplied by the uploading
    browser and is not a reliable format indicator on its own (a ``.csv``
    file may be labelled with a spreadsheet MIME type), so checking only the
    MIME type rejects valid uploads.

    Args:
        uploaded_file: File-like object exposing ``name`` and/or ``type``
            attributes, such as the object returned by ``st.file_uploader``.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the format is not
        supported or parsing fails. In both failure cases a message is shown
        with ``st.error``.
    """
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    # Either attribute may be absent or None on a generic file-like object.
    file_name = str(getattr(uploaded_file, "name", "") or "").lower()
    mime_type = getattr(uploaded_file, "type", "") or ""

    if file_name.endswith(".csv"):
        reader = pd.read_csv
    elif file_name.endswith(".xlsx"):
        reader = pd.read_excel
    elif mime_type == "text/csv":
        reader = pd.read_csv
    elif mime_type == xlsx_mime:
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    # Parsing is the only step expected to fail on malformed input; report the
    # problem in the UI instead of crashing the app.
    try:
        return reader(uploaded_file)
    except Exception as e:
        st.error(f"Error loading file: {e}")
        return None
# Function to classify the user query
def classify_query(query, candidate_labels):
    """Pick the label the zero-shot classifier lists first for *query*.

    Args:
        query: Free-form text typed by the user.
        candidate_labels: Labels the classifier chooses between.

    Returns:
        The first entry of the classifier's ``'labels'`` output, or ``None``
        when the classifier returns a falsy result.
    """
    prediction = nlp(query, candidate_labels)
    if not prediction:
        return None
    return prediction['labels'][0]
# Function to summarize document content
def summarize_document(data, max_rows=50):
    """Generate a short text summary of the loaded dataset.

    Only the first ``max_rows`` rows are rendered to text, and the summarizer
    is asked to truncate over-long input. The summarization model accepts a
    limited input length, so passing the full text of a large table makes the
    call fail instead of producing a summary.

    Args:
        data: DataFrame to summarize, or ``None`` when nothing is loaded.
        max_rows: Maximum number of leading rows rendered into the text that
            is handed to the summarizer.

    Returns:
        The summary text, ``"No data loaded."`` when ``data`` is ``None``, or
        an ``"Error summarizing document: ..."`` message when the summarizer
        raises.
    """
    if data is None:
        return "No data loaded."
    # Convert a bounded slice of the dataframe to text for summarization.
    document_text = data.head(max_rows).to_string(index=False)
    try:
        summary = summarizer(
            document_text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,  # cut the input to the model's limit rather than fail
        )
        return summary[0]['summary_text']
    except Exception as e:
        return f"Error summarizing document: {e}"
def _answer_sales_question(data):
    """Build the text answer naming the department with the highest total sales."""
    response = "Analyzing the sales data for your query."
    department_column = infer_column(data, ["department", "dept"])
    sales_column = infer_column(data, ["sales", "revenue"])
    if department_column and sales_column:
        # Aggregate once and read both the winning label and its total from it.
        totals = data.groupby(department_column)[sales_column].sum()
        top_department = totals.idxmax()
        top_sales = totals.max()
        response += f" The department with the most sales is {top_department} with total sales of {top_sales:.2f}."
    else:
        response += " Could not find relevant 'department' or 'sales' columns in the dataset."
    return response


def _answer_general_question(data, query):
    """Build the text answer for a general question (highest-sales country lookup)."""
    response = "Analyzing the data for your general question."
    if "sales" not in query.lower():
        return response
    response += " Checking for the highest sales..."
    sales_column = infer_column(data, ["sales", "revenue"])
    if not sales_column:
        return response + " Could not find a 'sales' column."
    # Look the country column up by name instead of assuming a lowercase
    # 'country' label, which raises KeyError for e.g. 'Country'.
    country_column = infer_column(data, ["country"])
    if not country_column:
        return response + " Could not find a 'country' column."
    top_country = data.loc[data[sales_column].idxmax(), country_column]
    return response + f" The country with the highest sales is {top_country}."


# Function to generate a graph based on user query
def generate_graph(data, query):
    """Render a chart or a text answer for a free-form user query.

    The literal request "what is this document" (case-insensitive, trailing
    ``?``/``.``/``!`` ignored) is answered with a document summary. Any other
    query is classified into a chart type or question kind: chart branches
    let the user pick columns through select boxes and draw with seaborn,
    question branches produce a short text answer. The response text is
    written with ``st.write``.

    Args:
        data: DataFrame holding the uploaded dataset. It is not modified.
        query: Free-form text typed by the user.

    Returns:
        None. All output goes to the Streamlit page; any exception is reported
        with ``st.error`` instead of propagating.
    """
    fig = None  # created only by chart branches; closed in ``finally``
    try:
        # Handle the summary request before classification: it is matched
        # literally, and the example shown in the UI ends with a question mark.
        normalized_query = query.strip().rstrip("?.!").strip().lower()
        if normalized_query == "what is this document":
            response = "Analyzing the document content."
            document_summary = summarize_document(data)
            response += f" Here is a summary of the document: {document_summary}"
            st.write(response)
            return

        # Infer column types
        numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
        categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
        datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()

        # Define possible graph types
        candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question", "general question"]
        query_type = classify_query(query, candidate_labels)

        if query_type == "bar chart" and categorical_columns and numerical_columns:
            response = f"Generating a bar chart for {query}"
            x_col = st.selectbox("Select the categorical column:", categorical_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color='skyblue')
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
            st.pyplot(fig)
        elif query_type == "line chart" and datetime_columns and numerical_columns:
            response = f"Generating a line chart for {query}"
            x_col = st.selectbox("Select the datetime column:", datetime_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            # Work on a copy so the caller's DataFrame is left untouched.
            trend_source = data[[x_col, y_col]].copy()
            trend_source[x_col] = pd.to_datetime(trend_source[x_col])
            trend_data = trend_source.groupby(x_col)[y_col].sum().reset_index()
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
            ax.set_title(f"Line Chart: {y_col} Over Time")
            st.pyplot(fig)
        elif query_type == "scatter plot" and len(numerical_columns) >= 2:
            response = f"Generating a scatter plot for {query}"
            x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
            # Default the y-axis to the second column so the initial plot is
            # not a column plotted against itself.
            y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
            ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
            st.pyplot(fig)
        elif query_type == "histogram" and numerical_columns:
            response = f"Generating a histogram for {query}"
            hist_col = st.selectbox("Select the numerical column:", numerical_columns)
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
            ax.set_title(f"Histogram of {hist_col}")
            st.pyplot(fig)
        elif query_type == "sales question":
            # General sales-related question (e.g., "Which department has the most sales?")
            response = _answer_sales_question(data)
        elif query_type == "general question":
            response = _answer_general_question(data, query)
        else:
            response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."

        # Show text-based response
        st.write(response)
    except Exception as e:
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it once it has been handed to Streamlit.
        if fig is not None:
            plt.close(fig)
# Helper function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Find the first column whose name matches one of the given synonyms.

    Names are compared case-insensitively after converting to ``str`` and
    stripping surrounding whitespace, so non-string column labels (for example
    integer headers) are skipped instead of raising, and a header such as
    ``" Sales "`` still matches ``"sales"``.

    Args:
        data: Object with a ``columns`` iterable, e.g. a ``pandas.DataFrame``.
        synonyms: Iterable of acceptable column names.

    Returns:
        The matching column label exactly as it appears in ``data.columns``,
        or ``None`` when no column matches.
    """
    wanted = {str(name).strip().lower() for name in synonyms}
    for column in data.columns:
        if str(column).strip().lower() in wanted:
            return column
    return None
# Streamlit App Interface
def main():
    """Build the Streamlit page: styling, file upload, preview and query box.

    Once a file has been uploaded and loaded successfully, the first rows are
    shown and any query the user enters is passed to ``generate_graph``.
    """
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Page-wide background image, injected as raw CSS.
    background_css = (
        "\n"
        "<style>\n"
        ".stApp {\n"
        "background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');\n"
        "background-size: cover;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # Nothing below applies until a file has been provided and parsed.
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return

    data = load_file(uploaded_file)
    if data is None:
        return

    st.write("Dataset preview:", data.head())

    # Free-form request: a chart description or a question about the data.
    user_query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and sales', or 'What is this document?')")
    if user_query:
        generate_graph(data, user_query)


if __name__ == "__main__":
    main()