# Depression / app.py
# saherPervaiz's picture
# Update app.py
# 5cb3e7f verified
import streamlit as st
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import io
# Import custom functions from your utils
from utils.data_cleaning import preprocess_data, remove_outliers_iqr, cap_extreme_values, convert_string_to_numeric
from utils.model_training import train_all_models
# New Function: Combined Histogram and Bar Plot Comparison
def combined_histogram_barplot(df):
    """
    Create one figure stacking a histogram for every numeric column and a bar
    plot of value counts for every categorical column.

    Numeric columns are those with dtype float64 or int64; categorical columns
    are those with dtype object.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.

    Returns
    -------
    matplotlib.figure.Figure
        One subplot row per plotted column, numeric columns first. When ``df``
        has no matching columns, a small placeholder figure carrying a note is
        returned instead of raising.
    """
    numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns
    categorical_columns = df.select_dtypes(include=['object']).columns
    n_plots = len(numeric_columns) + len(categorical_columns)

    if n_plots == 0:
        # plt.subplots(0, 1) raises ValueError, so return a placeholder instead.
        fig = plt.figure(figsize=(10, 2))
        fig.text(0.5, 0.5, "No numeric or categorical columns to plot",
                 ha="center", va="center")
        return fig

    # squeeze=False always yields a 2-D array of Axes. With the default
    # squeeze=True a single subplot comes back as a bare Axes, and indexing it
    # fails for a frame that has exactly one plottable column.
    fig, axes = plt.subplots(n_plots, 1, figsize=(10, 5 * n_plots), squeeze=False)
    axes = axes[:, 0]

    # Histograms for numeric columns occupy the first rows.
    for ax, col in zip(axes, numeric_columns):
        ax.hist(df[col], bins=20, color='blue', alpha=0.7, edgecolor='black')
        ax.set_title(f"Histogram of {col}")
        ax.set_xlabel(col)
        ax.set_ylabel("Frequency")

    # Bar plots for categorical columns fill the remaining rows.
    for ax, col in zip(axes[len(numeric_columns):], categorical_columns):
        df[col].value_counts().plot(kind='bar', ax=ax, color='orange', alpha=0.7, edgecolor='black')
        ax.set_title(f"Bar Plot of {col}")
        ax.set_xlabel(col)
        ax.set_ylabel("Count")

    fig.tight_layout()
    return fig
# Plotting Functions
def plot_correlation_heatmap(df):
    """
    Plot a correlation heatmap for the numeric columns in the dataframe.

    Non-numeric columns are dropped before computing correlations. Since
    pandas 2.0, ``DataFrame.corr()`` no longer ignores them by default and
    raises instead.

    Parameters
    ----------
    df : pandas.DataFrame
        Data whose numeric columns are correlated pairwise.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with an annotated heatmap (2 decimals, "coolwarm" colormap).
        If there are no numeric columns, the figure carries a short note
        instead of a heatmap.
    """
    corr = df.select_dtypes(include='number').corr()

    fig, ax = plt.subplots(figsize=(10, 8))
    if corr.empty:
        # sns.heatmap cannot draw a zero-size matrix; show a note instead.
        ax.text(0.5, 0.5, "No numeric columns to correlate", ha="center", va="center")
        ax.set_axis_off()
    else:
        sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5, ax=ax)
    ax.set_title("Correlation Heatmap")
    return fig
def save_figure_as_png(fig):
    """
    Render ``fig`` as PNG into an in-memory byte stream.

    The returned ``io.BytesIO`` is rewound to offset 0, so it can be read
    directly or handed to a download widget.
    """
    png_out = io.BytesIO()
    fig.savefig(png_out, format="png")
    png_out.seek(0)  # rewind: savefig leaves the cursor at the end of the data
    return png_out
def plot_histogram(df, column):
    """
    Draw a 30-bin histogram with a KDE overlay for one dataframe column.

    Returns the Matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.histplot(df[column], kde=True, bins=30, color="skyblue", ax=ax)
    ax.set_title(f"Histogram of {column}")
    ax.set_xlabel(column)
    ax.set_ylabel("Frequency")
    return fig
def plot_box_plot(df, column):
    """
    Draw a seaborn box plot of one dataframe column (values on the x-axis).

    Returns the Matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.boxplot(x=df[column], ax=ax)
    ax.set_title(f"Box Plot of {column}")
    return fig
def plot_pair_plot(df):
    """
    Build a seaborn pair plot over the float64/int64 columns of ``df``.

    Returns the object produced by ``sns.pairplot`` (a seaborn grid, not a bare
    Matplotlib figure).
    """
    number_cols = df.select_dtypes(include=['float64', 'int64']).columns
    grid = sns.pairplot(df[number_cols])
    return grid
def plot_scatter_plot(df, x_col, y_col):
    """
    Scatter ``df[y_col]`` against ``df[x_col]`` using green markers.

    Returns the Matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.scatterplot(x=df[x_col], y=df[y_col], color="green", ax=ax)
    ax.set_title(f"Scatter Plot between {x_col} and {y_col}")
    return fig
def plot_bar_plot(df, column):
    """
    Draw a seaborn count plot (one bar per distinct value) for one column.

    Returns the Matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.countplot(x=df[column], ax=ax)
    ax.set_title(f"Bar Plot of {column}")
    return fig
def _show_and_close(fig):
    """
    Render a Matplotlib figure in the Streamlit page, then close it.

    Streamlit re-executes this script on every widget interaction, and every
    plotting helper above opens a new pyplot figure. Closing each figure once
    it has been passed to ``st.pyplot`` stops open figures from piling up
    across columns and reruns. Matplotlib warns once more than
    ``rcParams['figure.max_open_warning']`` (default 20) figures are open.
    """
    st.pyplot(fig)
    plt.close(fig)


# Streamlit App Title
st.title("Data Analysis, Model Training, and Visualization")

# File Uploader
uploaded_file = st.file_uploader("Upload a CSV file for data analysis", type=["csv"])

if uploaded_file is not None:
    try:
        # Load dataset inside the try so a malformed CSV is reported via
        # st.error like every other failure, instead of as a raw traceback.
        df = pd.read_csv(uploaded_file)
        st.write("### Dataset Preview")
        st.dataframe(df)

        # Data Cleaning
        st.subheader("Data Cleaning")
        st.write("Handling missing values, removing outliers, and capping extreme values...")
        df_cleaned = preprocess_data(df)
        df_cleaned = remove_outliers_iqr(df_cleaned)
        df_cleaned = cap_extreme_values(df_cleaned)

        # Convert string columns to numeric (if any)
        st.subheader("String to Numeric Conversion")
        st.write("Converting string categorical values to numeric using Label Encoding...")
        df_cleaned = convert_string_to_numeric(df_cleaned)
        st.write("### Cleaned Dataset")
        st.dataframe(df_cleaned)

        # Download option for cleaned dataset
        st.download_button(
            label="Download Cleaned Dataset (CSV)",
            data=df_cleaned.to_csv(index=False),
            file_name="cleaned_dataset.csv",
            mime="text/csv",
        )

        # Correlation Heatmap
        st.subheader("Correlation Heatmap")
        st.write("Visualizing correlations between numeric features...")
        heatmap_fig = plot_correlation_heatmap(df_cleaned)
        st.pyplot(heatmap_fig)
        # Render the PNG bytes before closing the figure; the buffer stays valid.
        heatmap_buffer = save_figure_as_png(heatmap_fig)
        plt.close(heatmap_fig)
        st.download_button(
            label="Download Correlation Heatmap (PNG)",
            data=heatmap_buffer,
            file_name="correlation_heatmap.png",
            mime="image/png",
        )

        # Additional Visualizations
        st.subheader("Additional Visualizations")
        numeric_columns = df_cleaned.select_dtypes(include=['float64', 'int64']).columns.tolist()
        categorical_columns = df_cleaned.select_dtypes(include=['object']).columns.tolist()

        # Combined Histogram and Bar Plot. Skipped when there is nothing to plot,
        # so an all-other-dtype frame does not abort the rest of the page.
        if numeric_columns or categorical_columns:
            st.subheader("Combined Histogram and Bar Plot")
            _show_and_close(combined_histogram_barplot(df_cleaned))

        # Distribution Plot
        if numeric_columns:
            st.write("### Distribution Plots (Histograms)")
            for col in numeric_columns:
                st.write(f"#### {col}")
                _show_and_close(plot_histogram(df_cleaned, col))

        # Box Plot
        if numeric_columns:
            st.write("### Box Plots (Outlier Detection)")
            for col in numeric_columns:
                st.write(f"#### {col}")
                _show_and_close(plot_box_plot(df_cleaned, col))

        if len(numeric_columns) > 1:
            # Pair Plot: sns.pairplot returns a seaborn grid; pass Streamlit
            # the underlying Matplotlib figure rather than the grid object.
            st.write("### Pair Plot")
            pair_plot = plot_pair_plot(df_cleaned)
            _show_and_close(pair_plot.fig)

            # Scatter Plot. Y defaults to the second column so the initial
            # plot is not a column against itself.
            st.write("### Scatter Plot")
            x_col = st.selectbox("Select X-axis for Scatter Plot", numeric_columns)
            y_col = st.selectbox("Select Y-axis for Scatter Plot", numeric_columns, index=1)
            _show_and_close(plot_scatter_plot(df_cleaned, x_col, y_col))

        # Bar Plot
        if categorical_columns:
            st.write("### Bar Plots (For Categorical Data)")
            for col in categorical_columns:
                st.write(f"#### {col}")
                _show_and_close(plot_bar_plot(df_cleaned, col))

        # Select Target and Features
        st.subheader("Feature and Target Selection")
        target = st.selectbox("Select Target Variable", df_cleaned.columns)
        features = [col for col in df_cleaned.columns if col != target]

        if not features:
            st.warning("No features available after removing the target variable.")
        else:
            X = df_cleaned[features]
            y = df_cleaned[target]

            # Train and Evaluate Models
            st.subheader("Model Training and Evaluation")
            st.write("Training models and calculating metrics...")
            model_results = train_all_models(X, y)
            st.write("### Model Training Results")
            st.dataframe(model_results)

            # Download option for model results
            st.download_button(
                label="Download Model Results (CSV)",
                data=model_results.to_csv(index=False),
                file_name="model_results.csv",
                mime="text/csv",
            )
    except Exception as e:
        # Top-level UI boundary: surface any failure in the page.
        st.error(f"An error occurred: {e}")
else:
    st.info("Please upload a CSV file to proceed.")