Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import os | |
| from io import StringIO | |
| from transformers import pipeline | |
# Zero-shot text classifier used to map a free-text query to a chart type.
# facebook/bart-large-mnli is a large model (hundreds of millions of
# parameters). Streamlit re-executes this script on every user interaction, so
# the pipeline is built once per server process via st.cache_resource instead
# of being reloaded on each rerun. show_spinner=False keeps the cached call from
# rendering a UI element before main() calls st.set_page_config.
@st.cache_resource(show_spinner=False)
def _load_zero_shot_classifier():
    """Return the shared zero-shot-classification pipeline (built once, then cached)."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


nlp = _load_zero_shot_classifier()
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load an uploaded CSV or Excel (.xlsx) file into a DataFrame.

    The file kind is decided primarily from the file name's extension, because
    the browser-reported MIME type is unreliable (for example, CSV uploads are
    commonly reported as ``application/vnd.ms-excel`` on Windows machines with
    Excel installed). The MIME type is only consulted when the name has no
    recognised extension.

    Args:
        uploaded_file: A Streamlit ``UploadedFile`` or any readable file-like
            object exposing ``name`` and ``type`` attributes.

    Returns:
        The loaded ``pandas.DataFrame``, or ``None`` when the file type is
        unsupported or parsing fails; in both cases an error is shown with
        ``st.error``.
    """
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    name = (getattr(uploaded_file, "name", "") or "").lower()
    mime = getattr(uploaded_file, "type", "") or ""

    if name.endswith(".csv"):
        reader = pd.read_csv
    elif name.endswith(".xlsx"):
        reader = pd.read_excel
    elif mime == "text/csv":
        reader = pd.read_csv
    elif mime == xlsx_mime:
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    try:
        return reader(uploaded_file)
    except Exception as e:  # UI boundary: report any parse failure to the user
        st.error(f"Error loading file: {e}")
        return None
# Function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Return the first column of ``data`` whose name matches one of ``synonyms``.

    Matching is exact but case-insensitive on both sides. Non-string column
    labels (e.g. integer headers) are compared through ``str()`` rather than
    raising. A single string is treated as one synonym, not as a sequence of
    characters.

    Args:
        data: DataFrame whose columns are searched in order.
        synonyms: A candidate column name, or an iterable of them.

    Returns:
        The matching column label exactly as it appears in ``data``, or
        ``None`` if no column matches.
    """
    if isinstance(synonyms, str):
        synonyms = [synonyms]
    wanted = {str(name).lower() for name in synonyms}
    return next((column for column in data.columns if str(column).lower() in wanted), None)
# Function to classify the user query
def classify_query(query, candidate_labels):
    """Return the candidate label that best matches ``query``.

    Uses the module-level zero-shot ``nlp`` pipeline. Its result dict lists
    ``labels`` ordered from most to least likely, so the first label is the
    prediction.

    Args:
        query: Free-text user request.
        candidate_labels: Labels to choose from (e.g. chart type names).

    Returns:
        The top-ranked label, or ``None`` when the query is blank, no labels
        are given, or the pipeline returns no labels.
    """
    # The zero-shot pipeline rejects an empty sequence or empty label list, and
    # a whitespace-only query carries no intent, so skip the model entirely.
    if not candidate_labels or not query or not query.strip():
        return None
    results = nlp(query, candidate_labels)
    labels = results.get("labels") if isinstance(results, dict) else None
    return labels[0] if labels else None
# Function to generate graph based on user query
def generate_graph(data, query):
    """Render a chart for ``data`` whose type is chosen by classifying ``query``.

    ``classify_query`` maps the query to "bar chart", "line chart",
    "scatter plot" or "histogram". The user then picks the columns with
    Streamlit select boxes, and the figure is rendered with ``st.pyplot``.
    Unsupported requests and plotting errors are reported with ``st.error``
    instead of being raised. ``data`` is not modified.

    Args:
        data: The uploaded dataset.
        query: Free-text description of the desired chart.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        numerical_columns = data.select_dtypes(include=["number"]).columns.tolist()
        categorical_columns = data.select_dtypes(include=["object", "category"]).columns.tolist()
        datetime_columns = _datetime_candidate_columns(data)

        candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram"]
        query_type = classify_query(query, candidate_labels)

        if query_type == "bar chart" and categorical_columns and numerical_columns:
            _draw_bar_chart(data, ax, categorical_columns, numerical_columns)
        elif query_type == "line chart" and datetime_columns and numerical_columns:
            _draw_line_chart(data, ax, datetime_columns, numerical_columns)
        elif query_type == "scatter plot" and len(numerical_columns) >= 2:
            _draw_scatter_plot(data, ax, numerical_columns)
        elif query_type == "histogram" and numerical_columns:
            _draw_histogram(data, ax, numerical_columns)
        else:
            st.error("Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, or histogram.")
            return
        st.pyplot(fig)
    except Exception as e:  # UI boundary: show the failure instead of crashing the page
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed, and this function
        # runs again on each Streamlit rerun, so always release the figure.
        plt.close(fig)


def _datetime_candidate_columns(data, sample_size=100):
    """Return columns usable as a time axis.

    These are datetime dtypes plus text columns whose first ``sample_size``
    non-null values all parse as dates. The text check is needed because
    ``pd.read_csv`` does not parse dates unless asked to.
    """
    import warnings

    candidates = data.select_dtypes(include=["datetime", "datetimetz"]).columns.tolist()
    for column in data.select_dtypes(include=["object"]).columns:
        sample = data[column].dropna().head(sample_size)
        if sample.empty:
            continue
        try:
            with warnings.catch_warnings():
                # pandas warns when it must fall back to per-element parsing.
                warnings.simplefilter("ignore")
                parsed = pd.to_datetime(sample, errors="coerce")
        except (TypeError, ValueError, OverflowError):
            continue
        if parsed.notna().all():
            candidates.append(column)
    return candidates


def _draw_bar_chart(data, ax, categorical_columns, numerical_columns):
    """Draw the per-category sum of a user-selected numerical column."""
    x_col = st.selectbox("Select the categorical column:", categorical_columns)
    y_col = st.selectbox("Select the numerical column:", numerical_columns)
    aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
    sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color="skyblue")
    # Rotate long category labels in place; unlike
    # set_xticklabels(get_xticklabels()), this does not pin a FixedFormatter.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_title(f"Bar Chart: {x_col} vs {y_col}")


def _draw_line_chart(data, ax, datetime_columns, numerical_columns):
    """Draw the per-timestamp sum of a numerical column over a datetime column."""
    x_col = st.selectbox("Select the datetime column:", datetime_columns)
    y_col = st.selectbox("Select the numerical column:", numerical_columns)
    # Convert a copy so the caller's DataFrame is untouched. Unparseable values
    # become NaT, and groupby drops NaT keys by default.
    trend_data = data[[x_col, y_col]].copy()
    trend_data[x_col] = pd.to_datetime(trend_data[x_col], errors="coerce")
    trend_data = trend_data.groupby(x_col)[y_col].sum().reset_index()
    sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
    ax.set_title(f"Line Chart: {y_col} Over Time")


def _draw_scatter_plot(data, ax, numerical_columns):
    """Draw one user-selected numerical column against another."""
    x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
    # Default y to the second numerical column (the caller guarantees at least
    # two), so the initial plot is not a column against itself.
    y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
    sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
    ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")


def _draw_histogram(data, ax, numerical_columns):
    """Draw a 20-bin histogram with a KDE overlay for a user-selected numerical column."""
    hist_col = st.selectbox("Select the numerical column:", numerical_columns)
    sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color="green")
    ax.set_title(f"Histogram of {hist_col}")
# Streamlit App Interface
def main():
    """Lay out the page, accept an upload, preview it, and chart it from a text query.

    Each stage only proceeds when the previous one produced input: an uploaded
    file, a successfully loaded DataFrame, and a non-empty query.
    """
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Full-page background image injected as raw CSS.
    background_css = """
    <style>
    .stApp {
        background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
        background-size: cover;
    }
    </style>
    """
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    upload = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if upload is None:
        return

    frame = load_file(upload)
    if frame is None:
        return
    st.write("Dataset preview:", frame.head())

    user_query = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and gross sales')")
    if not user_query:
        return
    generate_graph(frame, user_query)


if __name__ == "__main__":
    # Build the UI when this file is executed as the main script.
    main()