# Repository page header (viewer residue kept for provenance):
#   Data-Excel / app.py
#   SHAMIL SHAHBAZ AWAN
#   Update app.py
#   559d037 verified
#   raw | history | blame
#   5.41 kB
import os
import re
from io import StringIO

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
from transformers import pipeline
# Load a lightweight NLP model for query understanding.
# The pipeline is built as a module-level statement, i.e. every time this
# script is executed from the top.
# NOTE(review): "distilbert-base-uncased" looks like a base checkpoint rather
# than one fine-tuned to classify chart requests -- confirm that the labels it
# returns are meaningful before relying on them.
nlp = pipeline("text-classification", model="distilbert-base-uncased", tokenizer="distilbert-base-uncased")
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load tabular data from an uploaded CSV or Excel (.xlsx) file.

    The format is detected from the file name's extension first and from the
    reported MIME type second.  The MIME type alone is not reliable: some
    browser/OS combinations report CSV uploads as something other than
    ``text/csv`` (for example ``application/vnd.ms-excel``), which previously
    caused valid CSV files to be rejected as unsupported.

    Args:
        uploaded_file: A readable file-like object exposing a ``type``
            attribute (MIME string) and, optionally, a ``name`` attribute
            (original file name).

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the format is not
        supported or parsing fails.  In both failure cases a message is shown
        with ``st.error``.
    """
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

    # Missing attributes are treated as empty strings so that objects carrying
    # only one of ``name`` / ``type`` are still handled.
    file_name = str(getattr(uploaded_file, "name", "") or "").lower()
    mime_type = str(getattr(uploaded_file, "type", "") or "").lower()

    try:
        if file_name.endswith(".csv") or mime_type == "text/csv":
            return pd.read_csv(uploaded_file)
        if file_name.endswith(".xlsx") or mime_type == xlsx_mime:
            return pd.read_excel(uploaded_file)
    except Exception as e:
        # Top-level UI boundary: report any parser failure instead of crashing.
        st.error(f"Error loading file: {e}")
        return None

    st.error("Unsupported file type.")
    return None
# Function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Return the first column of *data* whose normalised name is in *synonyms*.

    A column label is normalised by converting it to ``str``, stripping
    surrounding whitespace and lower-casing it, so non-string labels (e.g.
    integers) no longer raise and padded headers such as ``" Country "`` still
    match.

    Args:
        data: Object with a ``columns`` iterable (e.g. a ``pandas.DataFrame``).
        synonyms: Container of lower-case candidate names to look for.

    Returns:
        The original (un-normalised) column label of the first match, or
        ``None`` when no column matches.
    """
    for column in data.columns:
        if str(column).strip().lower() in synonyms:
            return column
    return None
# Function to classify the user query
def classify_query(query):
    """Run the module-level ``nlp`` pipeline on *query* and return its top label.

    Returns the ``'label'`` entry of the first prediction, or ``None`` when
    the pipeline yields an empty result.
    """
    predictions = nlp(query)
    if not predictions:
        return None
    return predictions[0]['label']
# Function to generate graph based on user query
def generate_graph(data, query):
    """Render a chart for *data*, chosen by keywords found in *query*.

    The lower-cased query is checked for these keywords, in this order:

    * ``"bar"``       -- sum of the sales column per country (requires a
      country-like and a sales-like column);
    * ``"line"``      -- sum of the sales column per date (requires a
      date-like and a sales-like column);
    * ``"scatter"``   -- expects ``... between <x> and <y>``;
    * ``"histogram"`` -- expects ``... for <column>``.

    Anything else produces an "unsupported graph type" message.  All problems
    are reported with ``st.error``; nothing is raised to the caller.

    Args:
        data: ``pandas.DataFrame`` to plot.  It is not modified.
        query: Free-text request typed by the user.

    Returns:
        None.  Output is produced through ``st.pyplot`` / ``st.error``.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))
        q = query.lower()

        # Infer column names
        country_col = infer_column(data, {"country", "countries"})
        sales_col = infer_column(data, {"gross_sales", "sales", "revenue"})
        date_col = infer_column(data, {"date", "time"})

        # Compare against None so that falsy-but-valid labels (e.g. 0) count.
        has_country = country_col is not None
        has_sales = sales_col is not None
        has_date = date_col is not None

        if "bar" in q and has_country and has_sales:
            # Bar chart for countries and gross sales
            country_data = data[[country_col, sales_col]].groupby(country_col).sum().reset_index()
            sns.barplot(x=country_col, y=sales_col, data=country_data, ax=ax, color='skyblue')
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_title(f"Bar Chart: {country_col} vs {sales_col}")
            st.pyplot(fig)
        elif "line" in q and has_date and has_sales:
            # Line chart for sales trend over time.  Work on a copy so the
            # caller's DataFrame keeps its original date column.
            trend_source = data[[date_col, sales_col]].copy()
            trend_source[date_col] = pd.to_datetime(trend_source[date_col])
            sales_trend = trend_source.groupby(date_col)[sales_col].sum().reset_index()
            sns.lineplot(x=date_col, y=sales_col, data=sales_trend, ax=ax)
            ax.set_title(f"Line Chart: {sales_col} Over Time")
            st.pyplot(fig)
        elif "scatter" in q:
            # Scatter plot for relationships.  Whole-word matching keeps
            # column names such as "brand" or "demand" from being split on
            # the "and" they contain.
            after_between = re.split(r"\bbetween\b", q)
            if len(after_between) > 1:
                columns = re.split(r"\band\b", after_between[-1])
                if len(columns) == 2:
                    x_col = infer_column(data, {columns[0].strip()})
                    y_col = infer_column(data, {columns[1].strip()})
                    if x_col is not None and y_col is not None:
                        sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
                        ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
                        st.pyplot(fig)
                        return
            st.error("Please specify valid columns for the scatter plot.")
        elif "histogram" in q:
            # Histogram for a specified column.  Whole-word matching keeps
            # names such as "performance" from being split on "for".
            after_for = re.split(r"\bfor\b", q)
            if len(after_for) > 1:
                hist_col = infer_column(data, {after_for[-1].strip()})
                if hist_col is not None:
                    sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
                    ax.set_title(f"Histogram of {hist_col}")
                    st.pyplot(fig)
                    return
            st.error("Please specify a valid column for the histogram.")
        else:
            st.error("Unsupported graph type. Try asking for a bar chart, line chart, scatter plot, or histogram.")
    except Exception as e:
        # Top-level UI boundary: surface plotting failures to the user.
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        if fig is not None:
            plt.close(fig)
# Streamlit App Interface
def main():
    """Lay out the page: styling, title, file upload, preview and query box.

    Once a file has been uploaded and loaded, its first rows are shown and a
    non-empty query is passed to ``generate_graph``.
    """
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Set background image
    background_css = """
<style>
.stApp {
background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
background-size: cover;
}
</style>
"""
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # File upload section -- nothing else to show until a file is provided
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return

    # Load and display data; load_file reports its own errors
    data = load_file(uploaded_file)
    if data is None:
        return
    st.write("Dataset preview:", data.head())

    # User input for graph generation
    query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales')")
    if query:
        # Generate the graph based on the query
        generate_graph(data, query)
# Entry point: build the UI only when this file is run as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()