Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import os | |
| from io import StringIO | |
| from transformers import pipeline | |
# Load a lightweight NLP model for query understanding.
@st.cache_resource(show_spinner=False)
def _load_query_classifier():
    """Build the text-classification pipeline once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction, so a
    bare module-level ``pipeline(...)`` call would rebuild the model each time.
    ``st.cache_resource`` keeps one shared instance instead. The spinner is
    disabled so that no UI element is emitted before ``st.set_page_config``
    runs in ``main()``.
    """
    # NOTE(review): "distilbert-base-uncased" is a base checkpoint; its
    # classification head is presumably untrained, so the predicted labels are
    # likely generic (e.g. LABEL_0/LABEL_1) — confirm before relying on them.
    return pipeline(
        "text-classification",
        model="distilbert-base-uncased",
        tokenizer="distilbert-base-uncased",
    )


nlp = _load_query_classifier()
# Function to load the uploaded file (CSV or Excel)
_READERS_BY_EXTENSION = {
    ".csv": pd.read_csv,
    ".xlsx": pd.read_excel,
}
_READERS_BY_MIME = {
    "text/csv": pd.read_csv,
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": pd.read_excel,
}


def load_file(uploaded_file):
    """Load data from an uploaded CSV or XLSX file into a DataFrame.

    The parser is chosen from the file-name extension first and from the
    reported MIME type only as a fallback. Browsers and operating systems do
    not report MIME types consistently: a ``.csv`` file can arrive as, e.g.,
    ``application/vnd.ms-excel``. Trusting the MIME type alone would reject
    valid uploads.

    Args:
        uploaded_file: File-like upload object exposing ``name`` and ``type``
            attributes (such as the object returned by ``st.file_uploader``).

    Returns:
        The parsed DataFrame, or ``None`` when the format is unsupported or
        parsing fails. In both failure cases an error is shown via ``st.error``.
    """
    file_name = getattr(uploaded_file, "name", None) or ""
    extension = os.path.splitext(file_name)[1].lower()
    mime_type = getattr(uploaded_file, "type", None)
    reader = _READERS_BY_EXTENSION.get(extension) or _READERS_BY_MIME.get(mime_type)

    if reader is None:
        st.error("Unsupported file type.")
        return None

    try:
        # Rewind in case this upload object was already read earlier;
        # harmless when the buffer is already at offset 0.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)
        return reader(uploaded_file)
    except Exception as e:  # UI boundary: report any parse failure instead of crashing
        st.error(f"Error loading file: {e}")
        return None
# Function to infer column names based on synonyms
def _normalize_name(name):
    """Return *name* lower-cased, with underscores and whitespace runs collapsed to single spaces."""
    return " ".join(str(name).lower().replace("_", " ").split())


def infer_column(data, synonyms):
    """Infer a column name based on synonyms.

    Matching is case-insensitive and treats underscores and any run of
    whitespace as equivalent. A phrase typed in a query such as
    ``"gross sales"`` therefore finds a ``Gross_Sales`` column. Non-string
    column labels (e.g. integer headers) are compared via ``str()`` instead of
    raising ``AttributeError``. Empty phrases never match.

    Args:
        data: DataFrame whose ``columns`` are searched, in order.
        synonyms: Iterable of candidate names.

    Returns:
        The original label of the first matching column, or ``None``.
    """
    wanted = {_normalize_name(s) for s in synonyms}
    wanted.discard("")  # an empty phrase (e.g. nothing typed after "for") is not a column
    return next(
        (column for column in data.columns if _normalize_name(column) in wanted),
        None,
    )
# Function to classify the user query
def classify_query(query):
    """Return the label of the first prediction the ``nlp`` pipeline makes for *query*.

    Returns ``None`` when the pipeline produces no predictions.
    """
    # NOTE(review): with a base (not fine-tuned) checkpoint these labels are
    # presumably generic (LABEL_0/LABEL_1), not chart types — confirm.
    predictions = nlp(query)
    return predictions[0]["label"] if predictions else None
# Function to generate graph based on user query
def _words_after_last(words, keyword):
    """Return the words after the last standalone *keyword* in *words*, or None if it is absent."""
    if keyword not in words:
        return None
    last_index = len(words) - 1 - words[::-1].index(keyword)
    return words[last_index + 1:]


def _as_phrase(words):
    """Join *words* with single spaces and trim surrounding punctuation and quotes."""
    return " ".join(words).strip(" .,;:!?\"'")


def _resolve_column_pair(data, words):
    """Find two columns named on either side of a standalone "and" in *words*.

    Each "and" token is tried from left to right, so a column name that itself
    contains the word "and" can still be matched. Substrings such as the "and"
    in "brand" are never treated as separators. Returns ``(x_col, y_col)``, or
    ``None`` when no split names two existing columns.
    """
    for index, word in enumerate(words):
        if word != "and":
            continue
        left, right = _as_phrase(words[:index]), _as_phrase(words[index + 1:])
        if not left or not right:
            continue
        x_col = infer_column(data, {left})
        y_col = infer_column(data, {right})
        if x_col is not None and y_col is not None:
            return x_col, y_col
    return None


def generate_graph(data, query):
    """Generate a graph based on user query and render it with ``st.pyplot``.

    The chart type is picked by keyword, checked in this order:

    * "bar"       - sum of the sales column per country (needs both columns);
    * "line"      - sum of the sales column per date (needs both columns);
    * "scatter"   - two columns named after the last "between", split on "and";
    * "histogram" - the column named after the last standalone "for".

    Country/sales/date columns are located through fixed synonym sets.
    Problems are reported with ``st.error``; nothing is returned, the caller's
    DataFrame is not modified, and the figure is always closed afterwards.
    """
    fig = None
    try:
        lowered = query.lower()
        # Whole-word tokens, so keywords like "for"/"and" never split a column
        # name that merely contains them (e.g. "performance", "brand").
        words = lowered.split()
        fig, ax = plt.subplots(figsize=(10, 6))

        # Infer column names
        country_col = infer_column(data, {"country", "countries"})
        sales_col = infer_column(data, {"gross_sales", "sales", "revenue"})
        date_col = infer_column(data, {"date", "time"})

        if "bar" in lowered and country_col is not None and sales_col is not None:
            # Bar chart for countries and gross sales
            country_data = data[[country_col, sales_col]].groupby(country_col).sum().reset_index()
            sns.barplot(x=country_col, y=sales_col, data=country_data, ax=ax, color='skyblue')
            # Rotate the existing tick labels in place; re-setting them via
            # set_xticklabels(get_xticklabels()) is discouraged by matplotlib.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            ax.set_title(f"Bar Chart: {country_col} vs {sales_col}")
            st.pyplot(fig)
        elif "line" in lowered and date_col is not None and sales_col is not None:
            # Line chart for sales trend over time; convert a copy so the
            # caller's DataFrame keeps its original dtypes.
            trend_source = data[[date_col, sales_col]].copy()
            trend_source[date_col] = pd.to_datetime(trend_source[date_col])
            sales_trend = trend_source.groupby(date_col)[sales_col].sum().reset_index()
            sns.lineplot(x=date_col, y=sales_col, data=sales_trend, ax=ax)
            ax.set_title(f"Line Chart: {sales_col} Over Time")
            st.pyplot(fig)
        elif "scatter" in lowered:
            # Scatter plot for relationships: "... between <x> and <y>"
            between_words = _words_after_last(words, "between")
            pair = _resolve_column_pair(data, between_words) if between_words else None
            if pair is None:
                st.error("Please specify valid columns for the scatter plot.")
                return
            x_col, y_col = pair
            sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
            ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
            st.pyplot(fig)
        elif "histogram" in lowered:
            # Histogram for a specified column: "... for <column>"
            for_words = _words_after_last(words, "for")
            column_phrase = _as_phrase(for_words) if for_words else ""
            hist_col = infer_column(data, {column_phrase}) if column_phrase else None
            if hist_col is None:
                st.error("Please specify a valid column for the histogram.")
                return
            sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
            ax.set_title(f"Histogram of {hist_col}")
            st.pyplot(fig)
        elif "bar" in lowered or "line" in lowered:
            # The chart type was recognised, but its required columns are missing.
            st.error(
                "Could not find the columns needed for that chart. Bar charts need a "
                "country column and a sales column; line charts need a date column "
                "and a sales column."
            )
        else:
            st.error("Unsupported graph type. Try asking for a bar chart, line chart, scatter plot, or histogram.")
    except Exception as e:  # UI boundary: surface plotting errors to the user
        st.error(f"Error generating graph: {e}")
    finally:
        # Release the figure; otherwise every Streamlit rerun leaks one.
        if fig is not None:
            plt.close(fig)
# Streamlit App Interface
def main():
    """Build the page: header, file upload, dataset preview and query-driven chart."""
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Full-page background image, injected as raw CSS.
    background_css = """
        <style>
        .stApp {
        background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
        background-size: cover;
        }
        </style>
        """
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return  # nothing more to render until a file is uploaded

    data = load_file(uploaded_file)
    if data is None:
        return  # load_file has already reported the problem via st.error

    st.write("Dataset preview:", data.head())

    query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales')")
    if not query:
        return
    generate_graph(data, query)
# Entry point: build the app when this file is executed as a script.
if __name__ == "__main__":
    main()