|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
import plotly.io as pio |
|
|
from plotly.subplots import make_subplots |
|
|
import io |
|
|
|
|
|
|
|
|
# Attribution metadata, rendered in the page header and in the footer.
AUTHOR = "Eduardo Nacimiento García"
EMAIL = "enacimie@ull.edu.es"
LICENSE = "Apache 2.0"

# Browser-tab title/icon and page layout for the app.
st.set_page_config(page_title="SimpleViz", page_icon="🎨",
                   layout="wide", initial_sidebar_state="expanded")

st.title("🎨 SimpleViz")
st.markdown(" | ".join([
    f"**Author:** {AUTHOR}",
    f"**Email:** {EMAIL}",
    f"**License:** {LICENSE}",
]))
st.write("""


Upload a CSV or use the demo dataset to create beautiful, interactive visualizations in seconds.


""")
|
|
|
|
|
|
|
|
@st.cache_data
def create_demo_data():
    """Return a reproducible synthetic demo dataset.

    The frame has 500 rows with the columns Age, Income, Satisfaction (1-10),
    City, Gender, Purchase (0/1) and a daily Date range starting 2023-01-01.
    Exactly 15 distinct rows have Income set to NaN, so the missing-value
    metric and NaN handling have something to show.
    """
    # A private RandomState seeded with 42 yields the same draws the former
    # ``np.random.seed(42)`` + ``np.random.*`` calls did, without resetting
    # NumPy's global generator for the rest of the process.
    rng = np.random.RandomState(42)
    n = 500

    data = {
        # Clip before truncating so the normal draw can never yield a negative age.
        "Age": np.clip(rng.normal(35, 12, n), 0, None).astype(int),
        "Income": rng.normal(45000, 15000, n),
        "Satisfaction": rng.randint(1, 11, n),
        "City": rng.choice(["Madrid", "Barcelona", "Valencia", "Seville"], n),
        "Gender": rng.choice(["M", "F"], n, p=[0.6, 0.4]),
        "Purchase": rng.choice([0, 1], n, p=[0.7, 0.3]),
        "Date": pd.date_range(start="2023-01-01", periods=n, freq="D"),
    }
    df = pd.DataFrame(data)

    # Sample without replacement so exactly 15 distinct rows lose their Income
    # (sampling with replacement could pick the same row twice).
    df.loc[rng.choice(df.index, 15, replace=False), "Income"] = np.nan
    return df
|
|
|
|
|
|
|
|
# --- Data source: demo dataset or uploaded CSV, kept in session_state['df'] ---
if st.button("🧪 Load Demo Dataset"):
    st.session_state['df'] = create_demo_data()
    st.success("✅ Demo dataset loaded!")

uploaded_file = st.file_uploader("📂 Upload your CSV file", type=["csv"])

if uploaded_file:
    # The script reruns on every interaction and the uploader keeps returning
    # the same file, so only (re)read it when a *different* file arrives.
    # Re-reading unconditionally would overwrite a demo dataset that was
    # just loaded by the button above.
    upload_id = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get('uploaded_id') != upload_id:
        try:
            st.session_state['df'] = pd.read_csv(uploaded_file)
        except ValueError as err:
            # pandas' ParserError / EmptyDataError and UnicodeDecodeError are
            # all ValueError subclasses: report bad files instead of crashing.
            st.error(f"❌ Could not read the CSV file: {err}")
        else:
            st.session_state['uploaded_id'] = upload_id
            st.success("✅ File uploaded successfully.")
else:
    # File removed from the uploader: allow re-uploading the same file later.
    st.session_state.pop('uploaded_id', None)

if 'df' not in st.session_state:
    st.info("👆 Upload a CSV or click 'Load Demo Dataset' to begin.")
    st.stop()

df = st.session_state['df']
|
|
|
|
|
|
|
|
# Peek at the first rows of whatever dataset is active.
with st.expander("🔍 Data Preview (first 10 rows)"):
    st.dataframe(df.head(10))

# Three headline metrics: row count, column count, total missing cells.
st.subheader("📌 Dataset Info")
col1, col2, col3 = st.columns(3)
for box, label, value in (
    (col1, "Rows", df.shape[0]),
    (col2, "Columns", df.shape[1]),
    (col3, "Missing Values", df.isnull().sum().sum()),
):
    box.metric(label, value)
|
|
|
|
|
|
|
|
st.header("✨ Smart Visualization Suggestions")

# Bucket columns by dtype family; later sections choose plot inputs from these lists.
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
datetime_cols = df.select_dtypes(include=['datetime', 'datetime64']).columns.tolist()

# The first datetime column (if there is one) drives the time-series suggestion.
date_col = datetime_cols[0] if datetime_cols else None
|
|
|
|
|
|
|
|
# Plot ideas derived purely from which kinds of columns the dataset has.
# Each entry is a dict with a display "name", a "description", a "plot_type"
# tag and the column(s) the plot should use ("x", "y", "color" or "cols").
suggestions = []

if len(numeric_cols) >= 2:
    suggestions.append({
        "name": "Scatter Plot",
        "description": "Visualize relationship between two numeric variables",
        "plot_type": "scatter",
        "x": numeric_cols[0],
        "y": numeric_cols[1],
    })

if numeric_cols:
    suggestions.append({
        "name": "Histogram",
        "description": "Distribution of a numeric variable",
        "plot_type": "histogram",
        "x": numeric_cols[0],
    })

if categorical_cols and numeric_cols:
    suggestions.append({
        "name": "Bar Plot (Mean)",
        "description": "Compare mean of numeric variable across categories",
        "plot_type": "bar",
        "x": categorical_cols[0],
        "y": numeric_cols[0],
    })

if len(categorical_cols) >= 2:
    suggestions.append({
        "name": "Stacked Bar Plot",
        "description": "Relationship between two categorical variables",
        "plot_type": "stacked_bar",
        "x": categorical_cols[0],
        "color": categorical_cols[1],
    })

# Compare against None: a column label such as 0 is valid but falsy.
if date_col is not None and numeric_cols:
    suggestions.append({
        "name": "Time Series Line Plot",
        "description": "Trend of numeric variable over time",
        "plot_type": "line",
        "x": date_col,
        "y": numeric_cols[0],
    })

if len(numeric_cols) >= 3:
    suggestions.append({
        "name": "Scatter Plot with Color",
        "description": "Scatter plot with third variable as color",
        "plot_type": "scatter_color",
        "x": numeric_cols[0],
        "y": numeric_cols[1],
        "color": numeric_cols[2],
    })

# A box plot needs a single numeric column; grouping by a category is optional.
if numeric_cols:
    suggestions.append({
        "name": "Box Plot",
        "description": (
            "Distribution and outliers of numeric variable by category"
            if categorical_cols
            else "Distribution and outliers of a numeric variable"
        ),
        "plot_type": "box",
        "x": categorical_cols[0] if categorical_cols else None,
        "y": numeric_cols[0],
    })

if len(numeric_cols) >= 2:
    suggestions.append({
        "name": "Correlation Heatmap",
        "description": "Correlation matrix of numeric variables",
        "plot_type": "heatmap",
        "cols": numeric_cols[:5],
    })
|
|
|
|
|
|
|
|
def _suggestion_figure(data, suggestion):
    """Build the Plotly figure described by one ``suggestions`` entry.

    ``suggestion["plot_type"]`` selects the chart; the other keys name the
    columns of ``data`` to use. Returns None for an unknown plot type.
    """
    kind = suggestion["plot_type"]
    title = suggestion["name"]
    if kind == "scatter":
        return px.scatter(data, x=suggestion["x"], y=suggestion["y"], title=title)
    if kind == "scatter_color":
        return px.scatter(data, x=suggestion["x"], y=suggestion["y"],
                          color=suggestion["color"], title=title)
    if kind == "histogram":
        return px.histogram(data, x=suggestion["x"], title=title)
    if kind == "bar":
        # One bar per category holding the mean (px.bar on raw rows would stack a sum).
        means = data.groupby(suggestion["x"], observed=True)[suggestion["y"]].mean().reset_index()
        return px.bar(means, x=suggestion["x"], y=suggestion["y"], title=title)
    if kind == "stacked_bar":
        # Row counts per x category, split by the second categorical column.
        return px.histogram(data, x=suggestion["x"], color=suggestion["color"],
                            barmode="stack", title=title)
    if kind == "line":
        # px.line joins points in row order, so sort chronologically first.
        ordered = data.sort_values(suggestion["x"], kind="stable")
        return px.line(ordered, x=suggestion["x"], y=suggestion["y"], title=title)
    if kind == "box":
        return px.box(data, x=suggestion["x"], y=suggestion["y"], title=title)
    if kind == "heatmap":
        corr_matrix = data[suggestion["cols"]].corr()
        return px.imshow(corr_matrix, text_auto=".2f", aspect="auto", title=title,
                         color_continuous_scale='RdBu_r', labels=dict(color="Correlation"))
    return None


for i, suggestion in enumerate(suggestions):
    with st.expander(f"🎨 Suggestion {i+1}: {suggestion['name']}"):
        st.write(suggestion["description"])
        if st.button(f"Create {suggestion['name']}", key=f"sug_{i}"):
            st.session_state['selected_suggestion'] = suggestion
        # Render the remembered choice. Comparing with the freshly built entry
        # means a stale selection (e.g. after switching datasets) is ignored.
        if st.session_state.get('selected_suggestion') == suggestion:
            try:
                suggestion_fig = _suggestion_figure(df, suggestion)
            except Exception as err:  # plotting user data: report, don't crash the page
                st.error(f"Could not create {suggestion['name']}: {err}")
            else:
                if suggestion_fig is not None:
                    st.plotly_chart(suggestion_fig, use_container_width=True)
|
|
|
|
|
|
|
|
st.header("🛠️ Custom Visualization Builder")

# Chart kinds offered by the builder dropdown, in display order.
plot_types = list((
    "Scatter Plot",
    "Line Plot",
    "Bar Plot",
    "Histogram",
    "Box Plot",
    "Violin Plot",
    "Pie Chart",
    "Heatmap (Correlation)",
))
selected_plot = st.selectbox("Choose plot type:", plot_types)

# Remains None unless one of the builder's "Generate ..." buttons is pressed this run.
fig = None
|
|
|
|
|
# One branch per builder plot type: render its controls and, when its
# "Generate" button is pressed, assign the resulting figure to ``fig``.
if selected_plot == "Scatter Plot":
    if not numeric_cols:
        st.warning("No numeric columns available for scatter plot.")
    else:
        col1, col2 = st.columns(2)
        with col1:
            x_col = st.selectbox("X-axis:", numeric_cols)
        with col2:
            # Offer the other numeric columns; reuse the only one if there is just one.
            y_options = [col for col in numeric_cols if col != x_col] if len(numeric_cols) > 1 else numeric_cols
            y_col = st.selectbox("Y-axis:", y_options)

        color_col = st.selectbox("Color by (optional):", [None] + categorical_cols + numeric_cols, key="scatter_color")
        size_col = st.selectbox("Size by (optional):", [None] + numeric_cols, key="scatter_size")

        title = st.text_input("Plot title:", f"{y_col} vs {x_col}")

        if st.button("Generate Scatter Plot"):
            plot_df = df
            if size_col is not None:
                # Plotly rejects NaN and negative marker sizes, so keep only usable rows.
                usable = df[size_col].notna() & (df[size_col] >= 0)
                n_dropped = int((~usable).sum())
                if n_dropped:
                    st.caption(f"ℹ️ {n_dropped} row(s) with missing or negative '{size_col}' omitted (marker size must be ≥ 0).")
                plot_df = df[usable]
            fig = px.scatter(plot_df, x=x_col, y=y_col, color=color_col, size=size_col, title=title)

elif selected_plot == "Line Plot":
    if not datetime_cols and not categorical_cols:
        st.warning("No suitable columns for line plot. Need datetime or categorical x-axis.")
    elif not numeric_cols:
        st.warning("No numeric columns available for line plot.")
    else:
        # Datetime columns come first so a time axis is the default choice.
        available_x = datetime_cols + categorical_cols
        col1, col2 = st.columns(2)
        with col1:
            x_col = st.selectbox("X-axis:", available_x)
        with col2:
            y_col = st.selectbox("Y-axis:", numeric_cols)

        color_col = st.selectbox("Color by (optional):", [None] + categorical_cols, key="line_color")
        title = st.text_input("Plot title:", f"{y_col} over {x_col}")

        if st.button("Generate Line Plot"):
            # px.line connects points in row order; sort by x so unsorted data does not zig-zag.
            ordered = df.sort_values(x_col, kind="stable")
            fig = px.line(ordered, x=x_col, y=y_col, color=color_col, title=title, markers=True)

elif selected_plot == "Bar Plot":
    if not categorical_cols:
        st.warning("No categorical columns available for bar plot.")
    else:
        col1, col2 = st.columns(2)
        with col1:
            x_col = st.selectbox("Category column:", categorical_cols)
        with col2:
            y_col = st.selectbox("Value column:", numeric_cols)

        # Without any numeric column only row counts can be shown.
        agg_options = ["Mean", "Sum", "Count", "Median"] if numeric_cols else ["Count"]
        agg_func = st.selectbox("Aggregation:", agg_options)
        color_col = st.selectbox("Color by (optional):", [None] + categorical_cols, key="bar_color")
        title = st.text_input("Plot title:", f"{agg_func} of {y_col} by {x_col}")

        if st.button("Generate Bar Plot"):
            if agg_func == "Count":
                fig = px.histogram(df, x=x_col, color=color_col, title=title)
            else:
                # Group by the colour column too (when it differs from x) so the
                # aggregated frame still has the column px.bar colours by.
                group_keys = [x_col] if color_col in (None, x_col) else [x_col, color_col]
                fig_data = (
                    df.groupby(group_keys, observed=True)[y_col]
                    .agg(agg_func.lower())  # "mean" / "sum" / "median"
                    .reset_index()
                )
                # Side-by-side bars for a colour split: stacking means/medians is meaningless.
                fig = px.bar(fig_data, x=x_col, y=y_col, color=color_col, title=title,
                             barmode="group" if len(group_keys) > 1 else "relative")

elif selected_plot == "Histogram":
    if not numeric_cols:
        st.warning("No numeric columns available for histogram.")
    else:
        col1, col2 = st.columns(2)
        with col1:
            x_col = st.selectbox("Variable:", numeric_cols)
        with col2:
            nbins = st.slider("Number of bins:", min_value=5, max_value=100, value=30)

        color_col = st.selectbox("Color by (optional):", [None] + categorical_cols, key="hist_color")
        title = st.text_input("Plot title:", f"Distribution of {x_col}")

        if st.button("Generate Histogram"):
            fig = px.histogram(df, x=x_col, nbins=nbins, color=color_col, title=title, marginal="box")

elif selected_plot == "Box Plot":
    if not numeric_cols:
        st.warning("No numeric columns available for box plot.")
    else:
        col1, col2 = st.columns(2)
        with col1:
            y_col = st.selectbox("Numeric variable:", numeric_cols)
        with col2:
            x_col = st.selectbox("Group by (optional):", [None] + categorical_cols)

        title = st.text_input("Plot title:", f"Box plot of {y_col}" + (f" by {x_col}" if x_col else ""))

        if st.button("Generate Box Plot"):
            fig = px.box(df, x=x_col, y=y_col, title=title)

elif selected_plot == "Violin Plot":
    if not numeric_cols:
        st.warning("No numeric columns available for violin plot.")
    else:
        col1, col2 = st.columns(2)
        with col1:
            y_col = st.selectbox("Numeric variable:", numeric_cols)
        with col2:
            x_col = st.selectbox("Group by (optional):", [None] + categorical_cols)

        title = st.text_input("Plot title:", f"Violin plot of {y_col}" + (f" by {x_col}" if x_col else ""))

        if st.button("Generate Violin Plot"):
            fig = px.violin(df, x=x_col, y=y_col, box=True, points="outliers", title=title)

elif selected_plot == "Pie Chart":
    if not categorical_cols:
        st.warning("No categorical columns available for pie chart.")
    else:
        col_to_plot = st.selectbox("Category column:", categorical_cols)
        title = st.text_input("Plot title:", f"Distribution of {col_to_plot}")

        if st.button("Generate Pie Chart"):
            fig = px.pie(df, names=col_to_plot, title=title)

elif selected_plot == "Heatmap (Correlation)":
    if len(numeric_cols) < 2:
        st.warning("Need at least 2 numeric columns for correlation heatmap.")
    else:
        # Pre-select at most five columns to keep the default matrix readable.
        selected_cols = st.multiselect("Select columns for correlation:", numeric_cols, default=numeric_cols[:5])

        if len(selected_cols) < 2:
            st.warning("Please select at least 2 columns.")
        else:
            title = st.text_input("Plot title:", "Correlation Heatmap")

            if st.button("Generate Heatmap"):
                corr_matrix = df[selected_cols].corr()
                fig = px.imshow(corr_matrix,
                                text_auto=".2f",
                                aspect="auto",
                                title=title,
                                color_continuous_scale='RdBu_r',
                                labels=dict(color="Correlation"))
|
|
|
|
|
|
|
|
# Remember the last generated builder figure so it survives later reruns
# (for example, clicking a download button reruns the script), but only while
# the same plot type is still selected.
if fig is not None:
    st.session_state['custom_fig'] = (selected_plot, fig)
else:
    saved_fig = st.session_state.get('custom_fig')
    if saved_fig is not None and saved_fig[0] == selected_plot:
        fig = saved_fig[1]

if fig is not None:
    st.plotly_chart(fig, use_container_width=True)

    st.subheader("💾 Download Plot")
    col1, col2 = st.columns(2)

    with col1:
        try:
            # Static image export relies on Plotly's optional Kaleido engine,
            # which may be missing or misconfigured; degrade gracefully.
            png_data = fig.to_image(format="png", width=1200, height=800, scale=2)
        except Exception as err:
            st.info(f"PNG export is unavailable: {err}")
        else:
            st.download_button(
                label="Download as PNG",
                data=png_data,
                file_name="plot.png",
                mime="image/png",
            )

    with col2:
        # Self-contained HTML that loads plotly.js from a CDN.
        html_data = fig.to_html(include_plotlyjs="cdn")
        st.download_button(
            label="Download as HTML",
            data=html_data,
            file_name="plot.html",
            mime="text/html",
        )
|
|
|
|
|
|
|
|
st.header("⚖️ Compare Multiple Plots")

num_plots = st.slider("Number of plots to compare:", min_value=1, max_value=4, value=2)

# Only chart kinds that can be drawn as a single trace on an x/y subplot are
# offered; pie charts and heatmaps need a figure of their own.
compare_plot_types = ["Scatter Plot", "Histogram", "Bar Plot", "Box Plot"]

if num_plots < 2:
    st.info("Select at least 2 plots to compare them side by side.")
else:
    # One (trace, subplot title, x-axis label, y-axis label) tuple per valid plot.
    panels = []
    plot_success = True

    for i in range(num_plots):
        st.markdown(f"### Plot {i+1}")
        plot_type = st.selectbox("Plot type:", compare_plot_types, key=f"compare_type_{i}")

        try:
            if plot_type == "Scatter Plot" and len(numeric_cols) >= 2:
                x_col = st.selectbox("X-axis:", numeric_cols, key=f"compare_x_{i}")
                y_col = st.selectbox("Y-axis:", [col for col in numeric_cols if col != x_col], key=f"compare_y_{i}")
                label = f"{y_col} vs {x_col}"
                trace = go.Scatter(x=df[x_col], y=df[y_col], mode='markers', name=label)
                panels.append((trace, label, x_col, y_col))

            elif plot_type == "Histogram" and numeric_cols:
                x_col = st.selectbox("Variable:", numeric_cols, key=f"compare_hist_{i}")
                label = f"Distribution of {x_col}"
                panels.append((go.Histogram(x=df[x_col], name=label), label, x_col, "Count"))

            elif plot_type == "Bar Plot" and categorical_cols and numeric_cols:
                x_col = st.selectbox("Category:", categorical_cols, key=f"compare_bar_x_{i}")
                y_col = st.selectbox("Value:", numeric_cols, key=f"compare_bar_y_{i}")
                # Aggregate to one bar per category; raw rows would draw many
                # overlapping bars at the same x position.
                means = df.groupby(x_col, observed=True)[y_col].mean()
                label = f"Mean {y_col} by {x_col}"
                trace = go.Bar(x=means.index, y=means.to_numpy(), name=label)
                panels.append((trace, label, x_col, f"Mean {y_col}"))

            elif plot_type == "Box Plot" and numeric_cols:
                y_col = st.selectbox("Value:", numeric_cols, key=f"compare_box_y_{i}")
                x_col = st.selectbox("Group by (optional):", [None] + categorical_cols, key=f"compare_box_x_{i}")
                label = f"{y_col} by {x_col}" if x_col else f"Distribution of {y_col}"
                trace = go.Box(x=df[x_col] if x_col else None, y=df[y_col], name=label)
                panels.append((trace, label, x_col or "", y_col))

            else:
                st.warning(f"Plot {i+1}: Invalid combination for {plot_type}")
                plot_success = False

        except Exception as e:  # UI boundary: report the failing panel and keep the page alive
            st.error(f"Error in Plot {i+1}: {e}")
            plot_success = False

    if plot_success and st.button("Generate Comparison Plot"):
        fig_compare = make_subplots(
            rows=1, cols=num_plots,
            subplot_titles=[panel_title for _, panel_title, _, _ in panels],
            shared_yaxes=False,
        )
        for col_idx, (trace, _, x_label, y_label) in enumerate(panels, start=1):
            fig_compare.add_trace(trace, row=1, col=col_idx)
            fig_compare.update_xaxes(title_text=x_label, row=1, col=col_idx)
            fig_compare.update_yaxes(title_text=y_label, row=1, col=col_idx)
        fig_compare.update_layout(height=600, showlegend=True, title_text="Comparison of Multiple Plots")
        st.plotly_chart(fig_compare, use_container_width=True)
|
|
|
|
|
|
|
|
# Footer: a horizontal rule followed by the attribution line.
st.markdown("---")
st.caption("© {} | License {} | Contact: {}".format(AUTHOR, LICENSE, EMAIL))