import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline
import numpy as np

# MIME types a browser may report for the two supported upload formats.
CSV_MIME_TYPE = "text/csv"
XLSX_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

# Labels the zero-shot classifier chooses between when routing a user query.
QUERY_LABELS = [
    "bar chart",
    "line chart",
    "scatter plot",
    "histogram",
    "sales question",
    "general question",
]


# Streamlit re-executes this whole script on every widget interaction, so the
# pipelines are wrapped in st.cache_resource: each model is loaded once per
# server process instead of on every rerun. show_spinner=False keeps these
# module-level calls from rendering anything before st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_query_classifier():
    """Return the zero-shot classification pipeline used to route queries."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


@st.cache_resource(show_spinner=False)
def _load_summarizer():
    """Return the summarization pipeline used for document summaries."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


# Zero-shot classifier that maps a free-text query onto one of QUERY_LABELS.
nlp = _load_query_classifier()
# Abstractive summarizer used by summarize_document().
summarizer = _load_summarizer()


def load_file(uploaded_file):
    """Load an uploaded CSV or Excel (.xlsx) file into a DataFrame.

    The reader is chosen from the file-name extension first and the
    browser-reported MIME type second. Relying on the MIME type alone rejects
    valid uploads, because some browsers/OSes report CSV files under other
    MIME types (e.g. "application/vnd.ms-excel").

    Args:
        uploaded_file: The object returned by ``st.file_uploader`` (anything
            with ``name``/``type`` attributes that pandas can read from).

    Returns:
        The loaded DataFrame, or None after showing an ``st.error`` message
        when the type is unsupported or parsing fails.
    """
    file_name = str(getattr(uploaded_file, "name", "") or "").lower()
    mime_type = getattr(uploaded_file, "type", None)

    if file_name.endswith(".csv"):
        reader = pd.read_csv
    elif file_name.endswith(".xlsx"):
        reader = pd.read_excel
    elif mime_type == CSV_MIME_TYPE:
        reader = pd.read_csv
    elif mime_type == XLSX_MIME_TYPE:
        reader = pd.read_excel
    else:
        st.error("Unsupported file type.")
        return None

    try:
        return reader(uploaded_file)
    except Exception as e:  # UI boundary: report any parse error to the user
        st.error(f"Error loading file: {e}")
        return None


@st.cache_data(show_spinner=False)
def _classify_cached(query, candidate_labels):
    """Return the top zero-shot label for *query* (memoized across reruns).

    *candidate_labels* is a tuple so the arguments form a stable cache key.
    Caching matters because every widget change reruns the script, and
    zero-shot classification runs the model once per candidate label.
    """
    results = nlp(query, list(candidate_labels))
    if results and results.get("labels"):
        return results["labels"][0]
    return None


def classify_query(query, candidate_labels):
    """Classify the user query into one of *candidate_labels*.

    Args:
        query: Free-text question typed by the user.
        candidate_labels: Iterable of label strings (graph types or question
            kinds) to score the query against.

    Returns:
        The highest-scoring label, or None if the pipeline produced no labels.
    """
    return _classify_cached(query, tuple(candidate_labels))


def summarize_document(data, max_rows=100):
    """Generate a short natural-language summary of a DataFrame.

    Args:
        data: The loaded DataFrame, or None.
        max_rows: Maximum number of leading rows rendered into the text given
            to the summarizer. The model only accepts a bounded number of
            input tokens (1024 for bart-large-cnn), so rows beyond that would
            be cut anyway. Rendering a huge frame to text is also slow.

    Returns:
        The summary text, or a human-readable message when there is nothing
        to summarize or the summarizer fails.
    """
    if data is None:
        return "No data loaded."
    if data.empty:
        return "The document contains no rows."

    document_text = data.head(max_rows).to_string(index=False)
    try:
        # truncation=True lets the tokenizer cut over-long input instead of
        # the model failing on sequences longer than its position limit.
        summary = summarizer(
            document_text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,
        )
        return summary[0]["summary_text"]
    except Exception as e:  # model/runtime errors are reported as text
        return f"Error summarizing document: {e}"


def _is_document_question(query):
    """Return True if *query* asks "what is this document" (case, spacing and trailing punctuation ignored)."""
    normalized = " ".join(query.lower().split()).rstrip("?!. ")
    return normalized == "what is this document"


def _answer_sales_question(data):
    """Return answer text naming the department with the highest total sales."""
    department_column = infer_column(data, ["department", "dept"])
    sales_column = infer_column(data, ["sales", "revenue"])
    if department_column is None or sales_column is None:
        return " Could not find relevant 'department' or 'sales' columns in the dataset."

    # Coerce to numbers so text/mixed columns neither concatenate strings on
    # sum() nor break the :.2f formatting below.
    sales = pd.to_numeric(data[sales_column], errors="coerce")
    totals = sales.groupby(data[department_column], observed=True).sum()
    if totals.empty:
        return " The dataset has no department/sales rows to analyze."

    top_department = totals.idxmax()
    top_sales = totals.max()
    return (
        f" The department with the most sales is {top_department}"
        f" with total sales of {top_sales:.2f}."
    )


def _answer_general_question(data, query):
    """Return extra answer text for a general question (only sales questions are answered)."""
    if "sales" not in query.lower():
        return ""

    answer = " Checking for the highest sales..."
    sales_column = infer_column(data, ["sales", "revenue"])
    if sales_column is None:
        return answer + " Could not find a 'sales' column."
    country_column = infer_column(data, ["country", "nation"])
    if country_column is None:
        return answer + " Could not find a 'country' column."

    sales = pd.to_numeric(data[sales_column], errors="coerce").dropna()
    if sales.empty:
        return answer + f" The '{sales_column}' column has no numeric values."

    top_country = data.loc[sales.idxmax(), country_column]
    return answer + f" The country with the highest sales is {top_country}."


def _answer_with_chart_or_text(data, query, fig, ax):
    """Classify *query*, draw the matching chart on *ax*/*fig* if any, and return the text response."""
    numerical_columns = data.select_dtypes(include=["number"]).columns.tolist()
    categorical_columns = data.select_dtypes(include=["object", "category"]).columns.tolist()
    datetime_columns = data.select_dtypes(include=["datetime"]).columns.tolist()

    query_type = classify_query(query, QUERY_LABELS)

    if query_type == "bar chart" and categorical_columns and numerical_columns:
        response = f"Generating a bar chart for {query}"
        x_col = st.selectbox("Select the categorical column:", categorical_columns)
        y_col = st.selectbox("Select the numerical column:", numerical_columns)
        aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
        sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color="skyblue")
        # Rotate the existing tick labels in place so long names don't overlap.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
        st.pyplot(fig)
        return response

    if query_type == "line chart" and datetime_columns and numerical_columns:
        response = f"Generating a line chart for {query}"
        x_col = st.selectbox("Select the datetime column:", datetime_columns)
        y_col = st.selectbox("Select the numerical column:", numerical_columns)
        # x_col already has a datetime dtype (it came from datetime_columns),
        # so no conversion is needed and the caller's DataFrame is untouched.
        trend_data = data.groupby(x_col)[y_col].sum().reset_index()
        sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
        ax.set_title(f"Line Chart: {y_col} Over Time")
        st.pyplot(fig)
        return response

    if query_type == "scatter plot" and len(numerical_columns) >= 2:
        response = f"Generating a scatter plot for {query}"
        x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
        # Default to the second numeric column so the initial plot is not a
        # column plotted against itself.
        y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
        sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
        ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
        st.pyplot(fig)
        return response

    if query_type == "histogram" and numerical_columns:
        response = f"Generating a histogram for {query}"
        hist_col = st.selectbox("Select the numerical column:", numerical_columns)
        sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color="green")
        ax.set_title(f"Histogram of {hist_col}")
        st.pyplot(fig)
        return response

    if query_type == "sales question":
        # e.g. "Which department has the most sales?"
        return "Analyzing the sales data for your query." + _answer_sales_question(data)

    if query_type == "general question":
        return "Analyzing the data for your general question." + _answer_general_question(data, query)

    return (
        "Unsupported graph type or insufficient data. Try asking for a bar chart, "
        "line chart, scatter plot, histogram, or sales-related question."
    )


def generate_graph(data, query):
    """Answer a user query about *data* with a chart and/or a text response.

    Routing:
      * "what is this document" (case, spacing and trailing punctuation
        ignored) produces a text summary via summarize_document();
      * anything else is classified into one of QUERY_LABELS. A chart is drawn
        (column choices come from st.selectbox widgets) or a text answer is
        produced for sales/general questions.

    Output is rendered directly into the Streamlit page; nothing is returned.
    Exceptions are reported with st.error instead of propagating. The
    Matplotlib figure is always closed so figures do not accumulate across
    Streamlit reruns.
    """
    fig = None
    try:
        if _is_document_question(query):
            response = (
                "Analyzing the document content."
                f" Here is a summary of the document: {summarize_document(data)}"
            )
        else:
            fig, ax = plt.subplots(figsize=(10, 6))
            response = _answer_with_chart_or_text(data, query, fig, ax)

        st.write(response)
    except Exception as e:  # UI boundary: show the failure instead of crashing the app
        st.error(f"Error generating graph: {e}")
    finally:
        if fig is not None:
            plt.close(fig)


def infer_column(data, synonyms):
    """Return the first column of *data* whose name matches one of *synonyms*.

    Matching ignores case and surrounding whitespace. Non-string column labels
    (e.g. integers) are compared via str() instead of raising.

    Args:
        data: DataFrame to search.
        synonyms: Iterable of candidate column names.

    Returns:
        The original column label, or None if nothing matches.
    """
    wanted = {str(name).strip().lower() for name in synonyms}
    for column in data.columns:
        if str(column).strip().lower() in wanted:
            return column
    return None


def main():
    """Render the Streamlit page: upload a file, preview it, and answer queries about it."""
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Set background image.
    # NOTE(review): the CSS payload here is empty, so this call renders
    # nothing. Add the intended <style> block or remove the call.
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    uploaded_file = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if uploaded_file is None:
        return

    data = load_file(uploaded_file)
    if data is None:
        return

    st.write("Dataset preview:", data.head())

    user_query = st.text_input(
        "Enter your query (e.g., 'Generate a bar chart for countries and sales', "
        "or 'What is this document?')"
    )
    if user_query:
        # Draw a graph or answer a general question based on the query.
        generate_graph(data, user_query)


if __name__ == "__main__":
    main()