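# NOTE: the following header lines are file-viewer page text, kept as comments
# so that the module remains valid Python.
# Data-Excel / app.py
# SHAMIL SHAHBAZ AWAN
# Update app.py
# 508be90 verified
# raw
# history blame
# 5.52 kB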
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from io import StringIO
from transformers import pipeline
# Load the zero-shot classification model used for query understanding.
# Streamlit re-executes this script on every widget interaction, and
# facebook/bart-large-mnli is a BART-large checkpoint, so the pipeline is built
# once per server process through st.cache_resource instead of on every rerun.
# show_spinner=False keeps the cached call from emitting any UI element before
# st.set_page_config() is called in main().
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build the zero-shot classification pipeline (cached across reruns)."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


nlp = _load_classifier()
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load tabular data from an uploaded CSV or Excel (.xlsx) file.

    The format is chosen from the file name extension first and from the
    browser-reported MIME type second. The MIME type alone is unreliable:
    some browsers report ``.csv`` uploads as ``application/vnd.ms-excel``,
    which previously made valid CSV files fail as "unsupported".

    Args:
        uploaded_file: File-like object exposing optional ``name`` and
            ``type`` attributes (as returned by ``st.file_uploader``).

    Returns:
        A ``pandas.DataFrame`` on success, or ``None`` after reporting the
        problem with ``st.error`` (unsupported type or read failure).
    """
    csv_mime = "text/csv"
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

    # Either attribute may be missing or None on a generic file-like object.
    file_name = (getattr(uploaded_file, "name", "") or "").lower()
    mime_type = getattr(uploaded_file, "type", "") or ""

    try:
        if file_name.endswith(".csv") or mime_type == csv_mime:
            return pd.read_csv(uploaded_file)
        if file_name.endswith(".xlsx") or mime_type == xlsx_mime:
            return pd.read_excel(uploaded_file)
    except Exception as e:
        # UI boundary: surface any parser failure to the user instead of crashing.
        st.error(f"Error loading file: {e}")
        return None

    st.error("Unsupported file type.")
    return None
# Function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Return the first column of *data* whose name matches one of *synonyms*.

    Matching ignores case and surrounding whitespace on both sides, and
    column labels are converted with ``str()`` first so that non-string
    labels (e.g. integer column names) do not raise.

    Args:
        data: Object with a ``columns`` iterable (a ``pandas.DataFrame``).
        synonyms: Iterable of accepted names; a single string is treated as
            one name. ``None`` matches nothing.

    Returns:
        The original (un-normalised) column label, or ``None`` if no column
        matches.
    """
    if synonyms is None:
        return None
    if isinstance(synonyms, str):
        # A bare string would otherwise be iterated character by character.
        synonyms = [synonyms]

    # Normalise once so each column check is a single set lookup.
    wanted = {str(name).strip().lower() for name in synonyms}
    for column in data.columns:
        if str(column).strip().lower() in wanted:
            return column
    return None
# Function to classify the user query
def classify_query(query, candidate_labels):
    """Classify the user query into one of the candidate graph types.

    Args:
        query: Free-text request typed by the user.
        candidate_labels: Labels to choose from (e.g. ``["bar chart", ...]``).

    Returns:
        The first entry of the ``labels`` list returned by the module-level
        ``nlp`` pipeline, or ``None`` when the query is blank, no labels are
        given, or the pipeline result carries no labels.
    """
    # Skip the model call for input it cannot classify meaningfully.
    if not query or not str(query).strip() or not candidate_labels:
        return None

    results = nlp(query, candidate_labels)
    labels = results.get("labels") if isinstance(results, dict) else None
    return labels[0] if labels else None
# Helper to find columns that can serve as the time axis of a line chart
def _detect_datetime_columns(data, sample_size=200):
    """Return columns that are datetime-typed or text that parses as dates.

    CSV files loaded with ``pd.read_csv`` keep date columns as ``object``
    dtype, so a dtype check alone never finds them. For every ``object``
    column, up to *sample_size* non-null values are test-parsed; the column
    qualifies only if all sampled values are strings and all of them parse.
    """
    import warnings

    detected = data.select_dtypes(include=['datetime', 'datetimetz']).columns.tolist()
    for column in data.select_dtypes(include=['object']).columns:
        sample = data[column].dropna().head(sample_size)
        # Only consider pure-text columns; numbers stored as objects would be
        # read as epoch offsets by pd.to_datetime.
        if sample.empty or not sample.map(lambda value: isinstance(value, str)).all():
            continue
        with warnings.catch_warnings():
            # Silence pandas' per-column "could not infer format" notices.
            warnings.simplefilter("ignore")
            try:
                parsed = pd.to_datetime(sample, errors="coerce")
            except (ValueError, TypeError, OverflowError):
                continue
        if parsed.notna().all():
            detected.append(column)
    return detected


# Function to generate graph based on user query
def generate_graph(data, query):
    """Generate and display a graph based on the user query.

    The query is classified into one of "bar chart", "line chart",
    "scatter plot" or "histogram" via ``classify_query``; the user then picks
    the columns through Streamlit select boxes and the figure is rendered with
    ``st.pyplot``. If the chart type is not recognised, or the data lacks the
    column kinds that chart needs, an error message is shown instead.

    Args:
        data: ``pandas.DataFrame`` to plot. It is not modified.
        query: Free-text request typed by the user.

    Returns:
        None. All output goes to the Streamlit page; any exception is
        reported with ``st.error`` rather than raised.
    """
    fig = None
    try:
        # Infer column names dynamically
        numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
        categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
        datetime_columns = _detect_datetime_columns(data)

        # Define possible graph types
        candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram"]
        query_type = classify_query(query, candidate_labels)

        fig, ax = plt.subplots(figsize=(10, 6))

        if query_type == "bar chart" and categorical_columns and numerical_columns:
            # Bar chart for categorical vs numerical
            x_col = st.selectbox("Select the categorical column:", categorical_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
            sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color='skyblue')
            # Restyle the existing tick labels in place rather than re-setting their text.
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
            st.pyplot(fig)
        elif query_type == "line chart" and datetime_columns and numerical_columns:
            # Line chart for numerical trend over time
            x_col = st.selectbox("Select the datetime column:", datetime_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            # Parse into a separate frame so the caller's DataFrame is untouched;
            # values that fail to parse become NaT and are dropped.
            timestamps = pd.to_datetime(data[x_col], errors="coerce")
            trend_source = pd.DataFrame({x_col: timestamps, y_col: data[y_col]})
            trend_data = (
                trend_source.dropna(subset=[x_col])
                .groupby(x_col)[y_col]
                .sum()
                .reset_index()
            )
            sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
            ax.set_title(f"Line Chart: {y_col} Over Time")
            st.pyplot(fig)
        elif query_type == "scatter plot" and len(numerical_columns) >= 2:
            # Scatter plot for numerical relationships
            x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
            # Default the y-axis to the second column so the initial plot is
            # not a column plotted against itself.
            y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
            sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
            ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
            st.pyplot(fig)
        elif query_type == "histogram" and numerical_columns:
            # Histogram for a numerical column
            hist_col = st.selectbox("Select the numerical column:", numerical_columns)
            sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
            ax.set_title(f"Histogram of {hist_col}")
            st.pyplot(fig)
        else:
            st.error("Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, or histogram.")
    except Exception as e:
        # UI boundary: show the failure on the page instead of crashing the app.
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each rerun of the script would leave another figure in memory.
        if fig is not None:
            plt.close(fig)
# Streamlit App Interface
def main():
    """Render the page: styling, file upload, data preview and graph query."""
    st.set_page_config(layout="wide", page_icon="📊", page_title="Data Visualization App")

    # Page-wide background image, injected as raw CSS
    background_css = (
        "\n"
        "<style>\n"
        ".stApp {\n"
        "background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');\n"
        "background-size: cover;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # Nothing more to show until a file has been uploaded
    upload = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if upload is None:
        return

    # load_file reports its own errors; stop here if it produced no data
    frame = load_file(upload)
    if frame is None:
        return

    st.write("Dataset preview:", frame.head())

    # Only attempt a graph once the user has typed a request
    user_query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales')")
    if not user_query:
        return

    generate_graph(frame, user_query)
# Entry point: build the page only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()