Spaces:
Sleeping
Sleeping
File size: 8,497 Bytes
417d8a6 559d037 a36f392 559d037 508be90 c443e24 417d8a6 e89e6f7 417d8a6 e89e6f7 a36f392 e89e6f7 417d8a6 559d037 508be90 0779acb 508be90 559d037 508be90 559d037 c443e24 a36f392 417d8a6 05d8272 0779acb 508be90 9dd7961 508be90 a36f392 508be90 559d037 0779acb 508be90 0779acb 508be90 9dd7961 508be90 9dd7961 508be90 0779acb 508be90 0779acb 508be90 9dd7961 e89e6f7 508be90 0779acb 508be90 e89e6f7 0779acb c443e24 a36f392 b5e2763 0779acb 417d8a6 b5e2763 417d8a6 0779acb e89e6f7 b5e2763 e89e6f7 b5e2763 e89e6f7 b5e2763 e89e6f7 b5e2763 e89e6f7 b5e2763 e89e6f7 a36f392 c443e24 417d8a6 a36f392 0779acb a36f392 417d8a6 e89e6f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline
import numpy as np
# Load the NLP models once per server process. Streamlit re-executes this
# script from top to bottom on every widget interaction, so without caching
# both BART checkpoints would be re-instantiated on each rerun.
# show_spinner=False keeps the cached loader from emitting a UI element at
# import time, i.e. before main() gets to call st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_pipeline(task, model_name):
    """Build a Hugging Face pipeline for ``task`` backed by ``model_name``.

    The result is cached by Streamlit, so each (task, model_name) pair is
    only constructed once and then shared across reruns and sessions.
    """
    return pipeline(task, model=model_name)


# Zero-shot classifier used to map a free-text query onto a graph/analysis type.
nlp = _load_pipeline("zero-shot-classification", "facebook/bart-large-mnli")
# Summarization model used to describe the uploaded document.
summarizer = _load_pipeline("summarization", "facebook/bart-large-cnn")
# Function to load the uploaded file (CSV or Excel)
def load_file(uploaded_file):
    """Load tabular data from an uploaded CSV or Excel (.xlsx) file.

    The format is picked from the upload's MIME type first. When the MIME
    type is not one of the two recognised values, the file-name extension is
    used instead, because a CSV may be reported under a different MIME type
    (for example ``application/vnd.ms-excel``) depending on the client.

    Args:
        uploaded_file: A file-like object exposing ``type`` (MIME type) and
            ``name`` attributes, such as the object returned by
            ``st.file_uploader``.

    Returns:
        A ``pandas.DataFrame`` with the file contents, or ``None`` when the
        format is unsupported or parsing fails (an error is shown via
        ``st.error`` in both cases).
    """
    csv_mime = "text/csv"
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    try:
        mime_type = getattr(uploaded_file, "type", None)
        file_name = str(getattr(uploaded_file, "name", "") or "").lower()

        # A recognised MIME type always wins; the extension is only a fallback.
        if mime_type == csv_mime:
            return pd.read_csv(uploaded_file)
        if mime_type == xlsx_mime:
            return pd.read_excel(uploaded_file)
        if file_name.endswith(".csv"):
            return pd.read_csv(uploaded_file)
        if file_name.endswith(".xlsx"):
            return pd.read_excel(uploaded_file)

        st.error("Unsupported file type.")
        return None
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the app.
        st.error(f"Error loading file: {e}")
        return None
# Function to classify the user query
def classify_query(query, candidate_labels):
    """Classify the user query into one of ``candidate_labels``.

    Args:
        query: Free-text query entered by the user.
        candidate_labels: Sequence of label strings to choose from.

    Returns:
        The first entry of the classifier's ``labels`` result, or ``None``
        when the query is empty/whitespace-only, no labels were supplied, or
        the classifier returned nothing.
    """
    # Nothing to classify: skip the model call rather than let it fail on
    # an empty sequence or an empty label list.
    if not query or not query.strip() or not candidate_labels:
        return None

    results = nlp(query, candidate_labels)
    if not results:
        return None

    labels = results['labels']
    return labels[0] if labels else None
# Function to summarize document content
def summarize_document(data, max_rows=50):
    """Generate a short text summary of a DataFrame.

    Only the first ``max_rows`` rows are rendered to text, and the
    summarizer is asked to truncate its input, so that large datasets do not
    exceed the model's input limit.

    Args:
        data: The loaded ``pandas.DataFrame``, or ``None`` if nothing was
            loaded.
        max_rows: Maximum number of leading rows to include in the text
            handed to the summarizer.

    Returns:
        The summary text, or a human-readable message when there is no data
        or the summarizer raised an error.
    """
    if data is None:
        return "No data loaded."
    if data.empty:
        return "The dataset is empty."

    # Convert (a bounded slice of) the dataframe to a string representation.
    document_text = data.head(max_rows).to_string(index=False)

    # Summarize the document text using the summarizer
    try:
        summary = summarizer(
            document_text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,  # clip over-long input instead of erroring out
        )
        return summary[0]['summary_text']
    except Exception as e:
        return f"Error summarizing document: {e}"
# Helper: recognise the "what is this document" question regardless of case,
# surrounding whitespace, or trailing punctuation such as "?".
def _is_document_query(query):
    """Return True if ``query`` asks what the uploaded document is."""
    return query.strip().rstrip("?.! ").lower() == "what is this document"


# Function to generate a graph based on user query
def generate_graph(data, query):
    """Render a chart or a text answer for ``query`` in the Streamlit app.

    The query is first checked for the document-summary question; otherwise
    it is classified into a chart type, a sales question, or a general
    question, and the matching branch is executed. Chart branches ask the
    user to pick columns via select boxes and draw onto a matplotlib figure.

    Args:
        data: ``pandas.DataFrame`` holding the uploaded dataset. It is not
            modified.
        query: Free-text query entered by the user.

    Returns:
        None. Output is written to the page with ``st.pyplot`` / ``st.write``;
        any exception is caught and reported with ``st.error``.
    """
    fig = None  # only created for chart branches; closed in ``finally``
    try:
        # Handle the summary request up front so it cannot be shadowed by
        # whatever label the classifier would have assigned to it.
        if _is_document_query(query):
            response = "Analyzing the document content."
            document_summary = summarize_document(data)
            response += f" Here is a summary of the document: {document_summary}"
            st.write(response)
            return

        # Infer column types
        numerical_columns = data.select_dtypes(include=['number']).columns.tolist()
        categorical_columns = data.select_dtypes(include=['object', 'category']).columns.tolist()
        datetime_columns = data.select_dtypes(include=['datetime']).columns.tolist()

        # Define possible graph types
        candidate_labels = ["bar chart", "line chart", "scatter plot", "histogram", "sales question", "general question"]
        query_type = classify_query(query, candidate_labels)

        figure_size = (10, 6)
        response = ""

        if query_type == "bar chart" and categorical_columns and numerical_columns:
            response = f"Generating a bar chart for {query}"
            x_col = st.selectbox("Select the categorical column:", categorical_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            aggregated_data = data[[x_col, y_col]].groupby(x_col).sum().reset_index()
            fig, ax = plt.subplots(figsize=figure_size)
            sns.barplot(x=x_col, y=y_col, data=aggregated_data, ax=ax, color='skyblue')
            # Rotate the existing tick labels in place (no re-setting of labels).
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_title(f"Bar Chart: {x_col} vs {y_col}")
            st.pyplot(fig)

        elif query_type == "line chart" and datetime_columns and numerical_columns:
            response = f"Generating a line chart for {query}"
            x_col = st.selectbox("Select the datetime column:", datetime_columns)
            y_col = st.selectbox("Select the numerical column:", numerical_columns)
            # Group on a converted copy of the column so the caller's frame
            # is left untouched; the grouping key keeps the column's name.
            time_key = pd.to_datetime(data[x_col])
            trend_data = data.groupby(time_key)[y_col].sum().reset_index()
            fig, ax = plt.subplots(figsize=figure_size)
            sns.lineplot(x=x_col, y=y_col, data=trend_data, ax=ax)
            ax.set_title(f"Line Chart: {y_col} Over Time")
            st.pyplot(fig)

        elif query_type == "scatter plot" and len(numerical_columns) >= 2:
            response = f"Generating a scatter plot for {query}"
            x_col = st.selectbox("Select the x-axis numerical column:", numerical_columns)
            # Default the y-axis to the second column so the initial plot is
            # not a column plotted against itself.
            y_col = st.selectbox("Select the y-axis numerical column:", numerical_columns, index=1)
            fig, ax = plt.subplots(figsize=figure_size)
            sns.scatterplot(x=x_col, y=y_col, data=data, ax=ax)
            ax.set_title(f"Scatter Plot: {x_col} vs {y_col}")
            st.pyplot(fig)

        elif query_type == "histogram" and numerical_columns:
            response = f"Generating a histogram for {query}"
            hist_col = st.selectbox("Select the numerical column:", numerical_columns)
            fig, ax = plt.subplots(figsize=figure_size)
            sns.histplot(data[hist_col], bins=20, kde=True, ax=ax, color='green')
            ax.set_title(f"Histogram of {hist_col}")
            st.pyplot(fig)

        elif query_type == "sales question":
            # General sales-related question (e.g., "Which department has the most sales?")
            response = "Analyzing the sales data for your query."
            department_column = infer_column(data, ["department", "dept"])
            sales_column = infer_column(data, ["sales", "revenue"])
            if department_column and sales_column:
                # Aggregate once and read both the winner and its total from it.
                totals = data.groupby(department_column)[sales_column].sum()
                top_department = totals.idxmax()
                top_sales = totals.max()
                response += f" The department with the most sales is {top_department} with total sales of {top_sales:.2f}."
            else:
                response += " Could not find relevant 'department' or 'sales' columns in the dataset."

        elif query_type == "general question":
            # Handle general questions
            response = "Analyzing the data for your general question."
            if "sales" in query.lower():
                response += " Checking for the highest sales..."
                sales_column = infer_column(data, ["sales", "revenue"])
                country_column = infer_column(data, ["country", "countries", "nation"])
                if sales_column is None:
                    response += " Could not find a 'sales' column."
                elif country_column is None:
                    # Previously a missing 'country' column surfaced as a KeyError.
                    response += " Could not find a 'country' column."
                else:
                    top_country = data.loc[data[sales_column].idxmax(), country_column]
                    response += f" The country with the highest sales is {top_country}."

        else:
            response = "Unsupported graph type or insufficient data. Try asking for a bar chart, line chart, scatter plot, histogram, or sales-related question."

        # Show text-based response
        st.write(response)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        st.error(f"Error generating graph: {e}")
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # figures do not accumulate across Streamlit reruns.
        if fig is not None:
            plt.close(fig)
# Helper function to infer column names based on synonyms
def infer_column(data, synonyms):
    """Find the first column whose name matches one of ``synonyms``.

    Matching ignores case and surrounding whitespace. Column labels that are
    not strings (e.g. integer labels) are compared by their string form, so
    they never raise.

    Args:
        data: Object with a ``columns`` iterable (a ``pandas.DataFrame``).
        synonyms: Iterable of accepted column names.

    Returns:
        The matching column label exactly as it appears in ``data.columns``,
        or ``None`` if no column matches.
    """
    # Normalise once so each column check is a set lookup.
    wanted = {str(name).strip().lower() for name in synonyms}
    for column in data.columns:
        if str(column).strip().lower() in wanted:
            return column
    return None
# Streamlit App Interface
def main():
    """Lay out the page, accept a file upload, and answer the user's query.

    Configures the page and its background, then shows the uploader. Once a
    file has been loaded successfully, a preview of the data is displayed
    and any query the user types is passed on to ``generate_graph``.
    """
    st.set_page_config(page_title="Data Visualization App", page_icon="📊", layout="wide")

    # Set background image
    background_css = """
<style>
.stApp {
background-image: url('https://cdn.pixabay.com/photo/2016/06/02/02/33/triangles-1430105_1280.png');
background-size: cover;
}
</style>
"""
    st.markdown(background_css, unsafe_allow_html=True)

    st.title("Data Visualization App")
    st.markdown("Created by: Shamil Shahbaz", unsafe_allow_html=True)

    # File upload section
    upload = st.file_uploader("Upload a CSV or Excel file", type=["csv", "xlsx"])
    if upload is None:
        return

    # Load and display data
    frame = load_file(upload)
    if frame is None:
        return
    st.write("Dataset preview:", frame.head())

    # User input for query
    prompt = st.text_input("Enter your query (e.g., 'Generate a bar chart for countries and sales', or 'What is this document?')")
    if not prompt:
        return

    # Generate the graph based on the query or handle general questions
    generate_graph(frame, prompt)


if __name__ == "__main__":
    main()
|