Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import io | |
| # Import custom functions from your utils | |
| from utils.data_cleaning import preprocess_data, remove_outliers_iqr, cap_extreme_values, convert_string_to_numeric | |
| from utils.model_training import train_all_models | |
| # New Function: Combined Histogram and Bar Plot Comparison | |
def combined_histogram_barplot(df):
    """
    Create one figure that stacks a histogram for every numeric column and a
    bar plot for every categorical column of the dataset.

    Parameters:
        df: pandas DataFrame. Columns with dtype float64/int64 are drawn as
            histograms; columns with dtype object are drawn as bar plots of
            their value counts. Columns of any other dtype are skipped.

    Returns:
        The matplotlib Figure holding one row of axes per plotted column
        (numeric columns first, then categorical ones). When there is nothing
        to plot, a figure containing only an explanatory message is returned.
    """
    numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns
    categorical_columns = df.select_dtypes(include=['object']).columns
    total_plots = len(numeric_columns) + len(categorical_columns)

    # plt.subplots() rejects a zero-row grid, so handle the empty case up front.
    if total_plots == 0:
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.text(0.5, 0.5, "No numeric or categorical columns to plot",
                ha="center", va="center")
        ax.set_axis_off()
        return fig

    # squeeze=False always yields a 2-D array of axes; without it a single
    # subplot comes back as a bare Axes object and indexing it fails.
    fig, axes = plt.subplots(
        total_plots, 1, figsize=(10, 5 * total_plots), squeeze=False
    )
    axes = axes.ravel()

    # Histogram for numeric columns
    for i, col in enumerate(numeric_columns):
        # Drop missing values so they are not handed to the binning step.
        axes[i].hist(df[col].dropna(), bins=20, color='blue', alpha=0.7, edgecolor='black')
        axes[i].set_title(f"Histogram of {col}")
        axes[i].set_xlabel(col)
        axes[i].set_ylabel("Frequency")

    # Bar plots for categorical columns, placed after the numeric rows
    for j, col in enumerate(categorical_columns, start=len(numeric_columns)):
        df[col].value_counts().plot(kind='bar', ax=axes[j], color='orange', alpha=0.7, edgecolor='black')
        axes[j].set_title(f"Bar Plot of {col}")
        axes[j].set_xlabel(col)
        axes[j].set_ylabel("Count")

    fig.tight_layout()
    return fig
| # Plotting Functions | |
def plot_correlation_heatmap(df):
    """
    Plot a correlation heatmap for the numeric columns in the dataframe.

    Parameters:
        df: pandas DataFrame. Only numeric and boolean columns take part in
            the correlation; all other columns are ignored.

    Returns:
        The matplotlib Figure containing the annotated heatmap. If the
        dataframe has no numeric columns, the figure carries a short message
        instead of a heatmap.
    """
    # Select numeric data explicitly: DataFrame.corr() no longer drops
    # non-numeric columns by default in pandas >= 2.0 and raises instead.
    corr = df.select_dtypes(include=["number", "bool"]).corr()

    fig, ax = plt.subplots(figsize=(10, 8))
    if corr.empty:
        # sns.heatmap cannot draw an empty matrix; show a placeholder instead.
        ax.text(0.5, 0.5, "No numeric columns available for correlation",
                ha="center", va="center")
        ax.set_axis_off()
    else:
        sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5, ax=ax)
    ax.set_title("Correlation Heatmap")
    return fig
def save_figure_as_png(fig):
    """
    Render the given figure as PNG into an in-memory byte stream.

    Parameters:
        fig: object exposing ``savefig(file, format=...)``.

    Returns:
        An ``io.BytesIO`` holding the PNG data, rewound to position 0 so it
        can be read from the start.
    """
    png_stream = io.BytesIO()
    fig.savefig(png_stream, format="png")
    # Rewind so consumers start reading at the first byte.
    png_stream.seek(0)
    return png_stream
def plot_histogram(df, column):
    """
    Draw a 30-bin histogram with a KDE overlay for one dataframe column.

    Parameters:
        df: pandas DataFrame containing ``column``.
        column: name of the column to plot.

    Returns:
        The matplotlib Figure the histogram was drawn on.
    """
    figure = plt.figure(figsize=(8, 6))
    # seaborn draws on the current axes of the figure created above
    axis = sns.histplot(df[column], kde=True, bins=30, color="skyblue")
    axis.set_title(f"Histogram of {column}")
    axis.set_xlabel(column)
    axis.set_ylabel("Frequency")
    return figure
def plot_box_plot(df, column):
    """
    Draw a horizontal box plot for one dataframe column.

    Parameters:
        df: pandas DataFrame containing ``column``.
        column: name of the column to plot.

    Returns:
        The matplotlib Figure the box plot was drawn on.
    """
    figure = plt.figure(figsize=(8, 6))
    # seaborn draws on the current axes of the figure created above
    axis = sns.boxplot(x=df[column])
    axis.set_title(f"Box Plot of {column}")
    return figure
def plot_pair_plot(df):
    """
    Build a seaborn pair plot over the float64/int64 columns of the dataframe.

    Parameters:
        df: pandas DataFrame.

    Returns:
        Whatever ``sns.pairplot`` returns for the numeric subset of ``df``.
    """
    numeric_names = df.select_dtypes(include=['float64', 'int64']).columns
    numeric_frame = df[numeric_names]
    grid = sns.pairplot(numeric_frame)
    return grid
def plot_scatter_plot(df, x_col, y_col):
    """
    Draw a green scatter plot of one dataframe column against another.

    Parameters:
        df: pandas DataFrame containing both columns.
        x_col: name of the column plotted on the x-axis.
        y_col: name of the column plotted on the y-axis.

    Returns:
        The matplotlib Figure the scatter plot was drawn on.
    """
    figure = plt.figure(figsize=(8, 6))
    x_values = df[x_col]
    y_values = df[y_col]
    # seaborn draws on the current axes of the figure created above
    axis = sns.scatterplot(x=x_values, y=y_values, color="green")
    axis.set_title(f"Scatter Plot between {x_col} and {y_col}")
    return figure
def plot_bar_plot(df, column):
    """
    Draw a count plot (one bar per distinct value) for one dataframe column.

    Parameters:
        df: pandas DataFrame containing ``column``.
        column: name of the column to plot.

    Returns:
        The matplotlib Figure the bar plot was drawn on.
    """
    figure = plt.figure(figsize=(8, 6))
    # seaborn draws on the current axes of the figure created above
    axis = sns.countplot(x=df[column])
    axis.set_title(f"Bar Plot of {column}")
    return figure
# Streamlit App
#
# Page flow: upload a CSV -> preview it -> clean it -> offer the cleaned data
# for download -> render the visualizations -> let the user pick a target
# column -> train the models and show/download their results.
# Any failure along the way is reported to the user through st.error.
st.title("Data Analysis, Model Training, and Visualization")

# File Uploader
uploaded_file = st.file_uploader("Upload a CSV file for data analysis", type=["csv"])

if uploaded_file is not None:
    try:
        # Load dataset. Parsing happens inside the try block so that an empty
        # or malformed upload is reported via st.error instead of a traceback.
        df = pd.read_csv(uploaded_file)
        st.write("### Dataset Preview")
        st.dataframe(df)

        # Data Cleaning
        st.subheader("Data Cleaning")
        st.write("Handling missing values, removing outliers, and capping extreme values...")
        df_cleaned = preprocess_data(df)
        df_cleaned = remove_outliers_iqr(df_cleaned)
        df_cleaned = cap_extreme_values(df_cleaned)

        # Convert string columns to numeric (if any)
        st.subheader("String to Numeric Conversion")
        st.write("Converting string categorical values to numeric using Label Encoding...")
        df_cleaned = convert_string_to_numeric(df_cleaned)
        st.write("### Cleaned Dataset")
        st.dataframe(df_cleaned)

        # Download option for cleaned dataset
        st.download_button(
            label="Download Cleaned Dataset (CSV)",
            data=df_cleaned.to_csv(index=False),
            file_name="cleaned_dataset.csv",
            mime="text/csv"
        )

        # Correlation Heatmap
        st.subheader("Correlation Heatmap")
        st.write("Visualizing correlations between numeric features...")
        heatmap_fig = plot_correlation_heatmap(df_cleaned)
        st.pyplot(heatmap_fig)  # Display the heatmap using Streamlit

        # Save and download heatmap as PNG
        heatmap_buffer = save_figure_as_png(heatmap_fig)  # Save the figure to buffer
        st.download_button(
            label="Download Correlation Heatmap (PNG)",
            data=heatmap_buffer,
            file_name="correlation_heatmap.png",
            mime="image/png"
        )
        # Streamlit re-runs this script on every interaction and pyplot keeps
        # every figure alive until it is closed, so each figure is released
        # as soon as it has been rendered (and, here, exported).
        plt.close(heatmap_fig)

        # Additional Visualizations
        st.subheader("Additional Visualizations")
        numeric_columns = df_cleaned.select_dtypes(include=['float64', 'int64']).columns.tolist()
        categorical_columns = df_cleaned.select_dtypes(include=['object']).columns.tolist()

        # Combined Histogram and Bar Plot
        st.subheader("Combined Histogram and Bar Plot")
        combined_plot = combined_histogram_barplot(df_cleaned)
        st.pyplot(combined_plot)
        plt.close(combined_plot)

        if numeric_columns:
            # Distribution Plot
            st.write("### Distribution Plots (Histograms)")
            for col in numeric_columns:
                st.write(f"#### {col}")
                hist_plot = plot_histogram(df_cleaned, col)
                st.pyplot(hist_plot)
                plt.close(hist_plot)

            # Box Plot
            st.write("### Box Plots (Outlier Detection)")
            for col in numeric_columns:
                st.write(f"#### {col}")
                box_plot = plot_box_plot(df_cleaned, col)
                st.pyplot(box_plot)
                plt.close(box_plot)

        if len(numeric_columns) > 1:
            # Pair Plot
            st.write("### Pair Plot")
            pair_plot = plot_pair_plot(df_cleaned)
            st.pyplot(pair_plot)
            # The seaborn grid wraps its Figure: newer releases expose it as
            # `.figure`, older ones only as `.fig`.
            pair_figure = getattr(pair_plot, "figure", None) or pair_plot.fig
            plt.close(pair_figure)

            # Scatter Plot
            st.write("### Scatter Plot")
            x_col = st.selectbox("Select X-axis for Scatter Plot", numeric_columns)
            # Default the Y-axis to the second column so the initial plot is
            # not a column plotted against itself.
            y_col = st.selectbox("Select Y-axis for Scatter Plot", numeric_columns, index=1)
            if x_col and y_col:
                scatter_plot = plot_scatter_plot(df_cleaned, x_col, y_col)
                st.pyplot(scatter_plot)
                plt.close(scatter_plot)

        # Bar Plot
        if categorical_columns:
            st.write("### Bar Plots (For Categorical Data)")
            for col in categorical_columns:
                st.write(f"#### {col}")
                bar_plot = plot_bar_plot(df_cleaned, col)
                st.pyplot(bar_plot)
                plt.close(bar_plot)

        # Select Target and Features
        st.subheader("Feature and Target Selection")
        target = st.selectbox("Select Target Variable", df_cleaned.columns)
        features = [col for col in df_cleaned.columns if col != target]

        if not features:
            st.warning("No features available after removing the target variable.")
        else:
            X = df_cleaned[features]
            y = df_cleaned[target]

            # Train and Evaluate Models
            st.subheader("Model Training and Evaluation")
            st.write("Training models and calculating metrics...")
            model_results = train_all_models(X, y)
            st.write("### Model Training Results")
            st.dataframe(model_results)

            # Download option for model results
            st.download_button(
                label="Download Model Results (CSV)",
                data=model_results.to_csv(index=False),
                file_name="model_results.csv",
                mime="text/csv"
            )
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"An error occurred: {e}")
else:
    st.info("Please upload a CSV file to proceed.")