Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from transformers import pipeline | |
| import numpy as np | |
# Model loading.
#
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so building the pipelines directly at module level would
# re-instantiate both BART models on each rerun. The loaders below are wrapped
# in Streamlit's resource cache (when the installed version provides it) so the
# pipelines are constructed once and then reused.
def _cache_resource(func):
    """Wrap *func* in ``st.cache_resource`` if available, else return it unchanged.

    The spinner is disabled so that no UI element is emitted while the models
    load; ``main()`` calls ``st.set_page_config`` later, and that call is
    expected to precede other Streamlit output.
    """
    decorator = getattr(st, "cache_resource", None)
    if decorator is None:
        # Older Streamlit without cache_resource: fall back to plain loading.
        return func
    return decorator(show_spinner=False)(func)


@_cache_resource
def _load_query_classifier():
    """Build the zero-shot classification pipeline used for query understanding."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


@_cache_resource
def _load_summarizer():
    """Build the summarization pipeline used to describe an uploaded dataset."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


# Module-level handles used by classify_query() and summarize_document().
nlp = _load_query_classifier()
summarizer = _load_summarizer()
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load tabular data from an uploaded CSV or Excel (.xlsx) file.

    The file format is chosen from the reported MIME type and, as a fallback,
    from the file name's extension. Browsers do not report CSV files
    consistently (a common value is ``application/vnd.ms-excel``, or an empty
    string), so relying on the MIME type alone rejects valid ``.csv`` uploads.

    Args:
        uploaded_file: A file-like object exposing ``type`` (MIME type) and/or
            ``name`` (file name) attributes, such as the object returned by
            ``st.file_uploader``.

    Returns:
        A ``pandas.DataFrame`` with the file contents, or ``None`` when the
        format is unsupported or the file cannot be parsed. In both failure
        cases an error message is shown through ``st.error``.
    """
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    mime_type = getattr(uploaded_file, "type", None) or ""
    file_name = str(getattr(uploaded_file, "name", None) or "").lower()

    if mime_type == "text/csv" or file_name.endswith(".csv"):
        reader = pd.read_csv
    elif mime_type == xlsx_mime or file_name.endswith(".xlsx"):
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    # Only the parse itself is guarded: this is the UI boundary, so any parser
    # failure is reported to the user instead of crashing the app.
    try:
        return reader(uploaded_file)
    except Exception as e:
        st.error(f"Error loading file: {e}")
        return None
# Function to classify the user query
def classify_query(query, candidate_labels):
    """Classify the user query into one of the candidate labels.

    Args:
        query: The free-text query entered by the user.
        candidate_labels: The labels to choose from (graph types or general
            analysis categories).

    Returns:
        The first entry of the classifier's ``labels`` result (the pipeline
        orders them best match first), or ``None`` when the query is blank, no
        labels are supplied, or the classifier returns nothing usable.
    """
    # Skip the model call for input it cannot classify meaningfully; the
    # zero-shot pipeline rejects an empty label list outright.
    if not query or not query.strip() or not candidate_labels:
        return None

    results = nlp(query, candidate_labels)
    if not results:
        return None

    labels = results["labels"]
    return labels[0] if labels else None
# Function to summarize document content
def summarize_document(data, max_rows=100):
    """Generate a short natural-language summary of the loaded dataset.

    The DataFrame is rendered as plain text and passed to the summarization
    pipeline. Only the first ``max_rows`` rows are rendered, and the pipeline
    is asked to truncate its input, because the model accepts a limited number
    of tokens and fails on over-long text.

    Args:
        data: The loaded ``pandas.DataFrame``, or ``None`` if nothing is loaded.
        max_rows: Maximum number of leading rows included in the text that is
            summarized.

    Returns:
        The summary text, or a human-readable message when there is no data or
        summarization fails.
    """
    if data is None:
        return "No data loaded."
    if data.empty:
        # Rendering an empty frame yields only "Empty DataFrame ..." boilerplate.
        return "The dataset is empty."

    # Convert (a bounded slice of) the dataframe to text for the summarizer.
    document_text = data.head(max_rows).to_string(index=False)

    try:
        summary = summarizer(
            document_text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,  # clip input that still exceeds the model limit
        )
        return summary[0]['summary_text']
    except Exception as e:
        return f"Error summarizing document: {e}"
# Text answer for queries classified as a sales question
def _answer_sales_question(data):
    """Return a text answer naming the department with the highest total sales."""
    response = "Analyzing the sales data for your query."
    department_column = infer_column(data, ["department", "dept"])
    sales_column = infer_column(data, ["sales", "revenue"])
    if department_column is None or sales_column is None:
        return response + " Could not find relevant 'department' or 'sales' columns in the dataset."

    # Aggregate once and read both the winning group and its total from it.
    totals = data.groupby(department_column)[sales_column].sum()
    return response + (
        f" The department with the most sales is {totals.idxmax()}"
        f" with total sales of {totals.max():.2f}."
    )


# Text answer for queries classified as a general question
def _answer_general_question(data, query):
    """Return a text answer for a general question (highest-sales country if asked)."""
    response = "Analyzing the data for your general question."
    if "sales" not in query.lower():
        return response

    response += " Checking for the highest sales..."
    sales_column = infer_column(data, ["sales", "revenue"])
    if sales_column is None:
        return response + " Could not find a 'sales' column."

    # Look the country column up instead of assuming a literal 'country' label.
    country_column = infer_column(data, ["country"])
    if country_column is None:
        return response + " Could not find a 'country' column."

    top_country = data.loc[data[sales_column].idxmax(), country_column]
    return response + f" The country with the highest sales is {top_country}."


# Function to generate a graph based on user query
def generate_graph(data, query):
    """Render a chart or a text answer for the user's query.

    The query is classified into a chart type ("bar chart", "line chart",
    "scatter plot", "histogram") or a question type ("sales question",
    "general question"). Chart queries let the user pick the columns through
    select boxes and draw the figure; question queries produce a text answer.
    The literal question "what is this document" (case-insensitive, trailing
    punctuation ignored) is answered with a generated summary and is checked
    before classification so the classifier cannot route it elsewhere.

    Args:
        data: The loaded ``pandas.DataFrame``. It is not modified.
        query: The free-text query entered by the user.

    Returns:
        None. All output goes to the Streamlit page; any exception is reported
        through ``st.error``.
    """
    fig = None
    try:
        # Document summary request: handled up front, without a model call.
        if query.strip().lower().rstrip("?!. ") == "what is this document":
            document_summary = summarize_document(data)
            st.write(
                "Analyzing the document content."
                f" Here is a summary of the document: {document_summary}"
            )
            return

        # Infer column types
        numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
        categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
        datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()

        # Define possible graph types
        candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question", "general question"]
        query_type = classify_query(query, candidate_labels)

        # A figure is only created inside the branch that actually draws one.
        if query_type == "bar chart" and categorical_columns and numerical_columns:
            response = f"Generating a bar chart for {query}"
            x_col = st.selectbox("Select the categorical column:", categorical_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color='skyblue')
            # Rotate the existing tick labels in place rather than re-setting
            # them, which avoids matplotlib's fixed-locator warning.
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
        elif query_type == "line chart" and datetime_columns and numerical_columns:
            response = f"Generating a line chart for {query}"
            x_col = st.selectbox("Select the datetime column:", datetime_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            # Work on a copy so the caller's DataFrame is left untouched.
            trend_source = data[[x_col, y_col]].copy()
            trend_source[x_col] = pd.to_datetime(trend_source[x_col])
            trend_data = trend_source.groupby(x_col)[y_col].sum().reset_index()
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
            ax.set_title(f"Line Chart: {y_col} Over Time")
        elif query_type == "scatter plot" and len(numerical_columns) >= 2:
            response = f"Generating a scatter plot for {query}"
            x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
            # Preselect the second column so the default plot is not a column
            # against itself; at least two columns exist in this branch.
            y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
            ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
        elif query_type == "histogram" and numerical_columns:
            response = f"Generating a histogram for {query}"
            hist_col = st.selectbox("Select the numerical column:", numerical_columns)
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
            ax.set_title(f"Histogram of {hist_col}")
        elif query_type == "sales question":
            # General sales-related question (e.g., "Which department has the most sales?")
            response = _answer_sales_question(data)
        elif query_type == "general question":
            response = _answer_general_question(data, query)
        else:
            response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."

        if fig is not None:
            st.pyplot(fig)

        # Show text-based response
        st.write(response)
    except Exception as e:
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # rerun of the app would leave another figure in memory.
        if fig is not None:
            plt.close(fig)
# Helper function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Find the first column whose name matches one of the given synonyms.

    Matching ignores letter case and surrounding whitespace on both the column
    labels and the synonyms. Column labels that are not strings (for example
    integer headers) are compared by their string form instead of raising.

    Args:
        data: An object exposing ``columns`` (e.g. a ``pandas.DataFrame``).
        synonyms: An iterable of acceptable column names.

    Returns:
        The matching column label exactly as it appears in ``data.columns``,
        or ``None`` if no column matches.
    """
    wanted = {str(synonym).strip().lower() for synonym in synonyms}
    for column in data.columns:
        if str(column).strip().lower() in wanted:
            return column
    return None
# Streamlit App Interface
def main():
    """Lay out the page, accept a file upload and a query, and dispatch the query.

    Nothing below the uploader is rendered until a file has been uploaded and
    parsed; the query is only dispatched once the user has typed something.
    """
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Page background. The CSS literal is emitted verbatim, so its whitespace
    # is part of the string sent to the page.
    background_css = """
<style>
.stApp {
background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
background-size: cover;
}
</style>
"""
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # File upload section
    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return

    # Parse the upload; load_file reports its own errors and yields None.
    data = load_file(uploaded_file)
    if data is None:
        return

    preview = data.head()
    st.write("Dataset preview:", preview)

    # Free-text query; an empty box means there is nothing to answer yet.
    query_prompt = "Enter your query (e.g., 'Generate a bar chart for countries and sales', or 'What is this document?')"
    user_query = st.text_input(query_prompt)
    if not user_query:
        return

    # Draw the requested chart or answer the question.
    generate_graph(data, user_query)


if __name__ == "__main__":
    main()