Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from io import BytesIO | |
| import numpy as np | |
# Apply one of Matplotlib's bundled style sheets globally, so every
# figure created later in this script starts from the same look.
_BASE_STYLE = 'fivethirtyeight'
plt.style.use(_BASE_STYLE)
def configure_plot_style(fig, ax):
    """Apply the app's shared styling to a figure and its axes.

    Hides the top and right spines, draws a dashed grid at 70% opacity,
    and paints both the figure background and the axes background white.

    Args:
        fig: Matplotlib ``Figure`` whose background patch is recolored.
        ax: Matplotlib ``Axes`` belonging to ``fig`` that gets styled.
    """
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.grid(True, linestyle='--', alpha=0.7)
    background = 'white'
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
def _draw_line_or_bar(ax, df, plot_type, cmap):
    """Prompt for X/Y columns and draw a line or bar plot of Y against X.

    Returns True if a plot was drawn, False if a warning was shown instead.
    """
    x_column = st.selectbox("Select X-axis column:", df.columns)
    y_column = st.selectbox("Select Y-axis column:", df.columns)
    if not pd.api.types.is_numeric_dtype(df[y_column]):
        st.warning("Y-axis column must be numeric for this plot type.")
        return False
    if plot_type == "Line Plot":
        ax.plot(df[x_column], df[y_column], marker='o', linewidth=2, color=cmap(0.6))
    else:
        ax.bar(df[x_column], df[y_column], color=cmap(0.6))
    ax.set_title(f"{plot_type} of {y_column} vs {x_column}", pad=20, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    # Rotate x tick labels when there are many distinct X values so they don't overlap.
    ax.tick_params(axis='x', labelrotation=45 if df[x_column].nunique() > 10 else 0)
    return True


def _draw_scatter(fig, ax, df, cmap):
    """Prompt for two numeric columns and draw a scatter plot colored by row position.

    Returns True if a plot was drawn, False if a warning was shown instead.
    """
    x_column = st.selectbox("Select X-axis column:", df.columns)
    y_column = st.selectbox("Select Y-axis column:", df.columns)
    if not pd.api.types.is_numeric_dtype(df[x_column]) or not pd.api.types.is_numeric_dtype(df[y_column]):
        st.warning("Both X and Y columns must be numeric for scatter plot.")
        return False
    # Color encodes the 0-based row position, hence the 'Index' colorbar label.
    scatter = ax.scatter(df[x_column], df[y_column],
                         c=np.arange(len(df)), cmap=cmap,
                         alpha=0.6, s=100)
    fig.colorbar(scatter, ax=ax, label='Index')
    ax.set_title(f"Scatter Plot of {y_column} vs {x_column}", pad=20, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    return True


def _draw_histogram(ax, df, cmap):
    """Prompt for a numeric column and bin count, then draw a histogram with graded bar colors.

    Returns True if a plot was drawn, False if a warning was shown instead.
    """
    column = st.selectbox("Select column:", df.columns)
    num_bins = st.slider("Number of bins:", min_value=5, max_value=50, value=20)
    if not pd.api.types.is_numeric_dtype(df[column]):
        st.warning("Column must be numeric for histogram.")
        return False
    # Drop missing values first: with NaNs present NumPy cannot auto-detect a
    # finite bin range and raises ValueError.
    _, _, patches = ax.hist(df[column].dropna(), bins=num_bins, edgecolor='white', linewidth=1)
    for i, patch in enumerate(patches):
        patch.set_facecolor(cmap(i / len(patches)))
    ax.set_title(f"Histogram of {column}", pad=20, fontsize=14)
    ax.set_xlabel(column, fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    return True


def _draw_box_plot(ax, df, cmap):
    """Prompt for a grouping column and a numeric value column, then draw one box per group.

    Returns True if a plot was drawn, False if a warning was shown instead.
    """
    x_column = st.selectbox("Select grouping column:", df.columns)
    y_column = st.selectbox("Select value column:", df.columns)
    if not pd.api.types.is_numeric_dtype(df[y_column]):
        st.warning("Value column must be numeric for box plot.")
        return False
    # Build labels and data in the same groupby pass so each label stays next
    # to its own box (groupby sorts keys and drops NaN keys; unique() does not).
    group_labels = []
    group_values = []
    for key, values in df.groupby(x_column)[y_column]:
        group_labels.append(str(key))
        group_values.append(values.dropna().to_numpy())
    box_plot = ax.boxplot(group_values, patch_artist=True)
    # Boxes sit at x = 1..n. Setting the tick labels directly avoids the
    # boxplot ``labels`` keyword, which Matplotlib 3.9 renamed to ``tick_labels``.
    ax.set_xticks(range(1, len(group_labels) + 1))
    ax.set_xticklabels(group_labels)
    # Color the boxes
    boxes = box_plot['boxes']
    for i, patch in enumerate(boxes):
        patch.set_facecolor(cmap(i / len(boxes)))
        patch.set_alpha(0.7)
    ax.set_title(f"Box Plot of {y_column} grouped by {x_column}", pad=20, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    ax.tick_params(axis='x', labelrotation=45 if len(group_labels) > 10 else 0)
    return True


def _annotation_color(cmap, value):
    """Return 'black' or 'white', whichever contrasts better with a correlation cell.

    ``value`` is a correlation drawn with ``cmap`` over the fixed range [-1, 1].
    NaN cells are painted with the colormap's "bad" color, which is transparent
    by default and so shows the white axes behind it. They therefore get black text.
    """
    if np.isnan(value):
        return 'black'
    red, green, blue, _ = cmap((value + 1) / 2)
    # Rec. 601 luma weights: use dark text on light cells and light text on dark cells.
    return 'black' if 0.299 * red + 0.587 * green + 0.114 * blue > 0.5 else 'white'


def _draw_correlation_matrix(fig, ax, df, cmap):
    """Draw an annotated heatmap of pairwise correlations between all numeric columns.

    Returns True if a plot was drawn, False if a warning was shown instead.
    """
    # 'number' covers every numeric dtype (int32, float32, nullable ints, ...),
    # not only int64/float64.
    numeric_df = df.select_dtypes(include='number')
    if numeric_df.columns.empty:
        st.warning("No numeric columns found in the dataset for correlation matrix.")
        return False
    corr = numeric_df.corr()
    # Pin the scale to the full correlation range so a color means the same
    # correlation whatever the data.
    im = ax.imshow(corr, cmap=cmap, vmin=-1, vmax=1)
    fig.colorbar(im, ax=ax)
    # Add correlation values
    for i, row in enumerate(corr.to_numpy()):
        for j, value in enumerate(row):
            ax.text(j, i, f'{value:.2f}',
                    ha='center', va='center',
                    color=_annotation_color(cmap, value))
    ax.set_xticks(range(len(corr.columns)))
    ax.set_yticks(range(len(corr.columns)))
    ax.set_xticklabels(corr.columns, rotation=45, ha='right')
    ax.set_yticklabels(corr.columns)
    ax.set_title("Correlation Matrix", pad=20, fontsize=14)
    return True


st.title("Interactive Dataset Plotting Tool")

# Upload Dataset
uploaded_file = st.file_uploader("Upload your CSV dataset", type=["csv"])

if uploaded_file:
    try:
        # Load dataset
        df = pd.read_csv(uploaded_file)
        st.write("Dataset Preview:")
        st.dataframe(df)

        # Plot type selection
        plot_types = ["Line Plot", "Bar Plot", "Scatter Plot", "Histogram", "Box Plot", "Correlation Matrix"]
        plot_type = st.selectbox("Select Plot Type:", plot_types)

        # Color scheme selection
        color_schemes = ['viridis', 'magma', 'plasma', 'inferno', 'cividis']
        color_scheme = st.selectbox("Select Color Scheme:", color_schemes)
        # plt.get_cmap rather than plt.cm.get_cmap, which was deprecated in
        # Matplotlib 3.7 and removed in 3.9.
        cmap = plt.get_cmap(color_scheme)

        # Common figure creation
        fig, ax = plt.subplots(figsize=(10, 6))
        configure_plot_style(fig, ax)
        try:
            plotted = False
            if plot_type in ("Line Plot", "Bar Plot"):
                plotted = _draw_line_or_bar(ax, df, plot_type, cmap)
            elif plot_type == "Scatter Plot":
                plotted = _draw_scatter(fig, ax, df, cmap)
            elif plot_type == "Histogram":
                plotted = _draw_histogram(ax, df, cmap)
            elif plot_type == "Box Plot":
                plotted = _draw_box_plot(ax, df, cmap)
            elif plot_type == "Correlation Matrix":
                plotted = _draw_correlation_matrix(fig, ax, df, cmap)

            # After a validation warning the axes are empty; don't display or
            # offer a blank image for download.
            if plotted:
                # Adjust layout and display plot
                fig.tight_layout()
                st.pyplot(fig)
                # Download button: save ``fig`` itself rather than pyplot's
                # implicit "current" figure.
                buffer = BytesIO()
                fig.savefig(buffer, format="png", dpi=300, bbox_inches='tight')
                buffer.seek(0)
                st.download_button(
                    label="Download Plot as PNG",
                    data=buffer,
                    file_name="plot.png",
                    mime="image/png"
                )
        finally:
            # Streamlit reruns this script on every widget change. pyplot keeps
            # each figure registered until closed, so close it to avoid piling
            # up one figure per rerun.
            plt.close(fig)
    except Exception as e:
        # Top-level UI boundary: report any parsing/plotting failure to the user.
        st.error(f"An error occurred: {str(e)}")
        st.info("Please make sure your dataset is properly formatted and contains appropriate data types for the selected plot type.")