saherPervaiz committed on
Commit
eee2b22
·
verified ·
1 Parent(s): 557bc92

Update utils/visualizations.py

Browse files
Files changed (1) hide show
  1. utils/visualizations.py +60 -85
utils/visualizations.py CHANGED
@@ -1,86 +1,61 @@
1
- import seaborn as sns
2
- import matplotlib.pyplot as plt
3
  import pandas as pd
4
-
5
- # Correlation Heatmap
6
def plot_correlation_heatmap(df):
    """
    Plot a correlation heatmap for the numeric columns of a dataframe.

    Args:
        df: pandas DataFrame to analyse.

    Returns:
        The ``matplotlib.pyplot`` module, with the heatmap drawn on a newly
        created current figure.

    Raises:
        ValueError: If ``df`` contains no numeric columns.
    """
    # 'number' selects every numeric dtype (e.g. int32, float32), not just the
    # float64/int64 pair, so narrower numeric columns are no longer dropped.
    numeric_df = df.select_dtypes(include='number')
    if numeric_df.shape[1] == 0:
        # Fail before a figure is created instead of deep inside seaborn.
        raise ValueError("DataFrame has no numeric columns to correlate.")

    # Compute the correlation matrix
    corr = numeric_df.corr()

    # Plot the heatmap
    plt.figure(figsize=(10, 8))
    sns.heatmap(corr, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5)
    plt.title("Correlation Heatmap")

    return plt
22
-
23
- # Save plot as PNG image
24
def save_plot_as_png(plot, filename="plot.png"):
    """
    Save a given plot as PNG.

    Args:
        plot: Any object exposing ``savefig(target, format=...)``.
        filename: Destination passed straight to ``plot.savefig``. Pass
            ``None`` to render into a new in-memory buffer instead.

    Returns:
        ``filename`` unchanged when one is given; otherwise an
        ``io.BytesIO`` holding the PNG data, rewound to position 0.
    """
    if filename is None:
        import io

        buffer = io.BytesIO()
        plot.savefig(buffer, format='png')
        # Rewind so the caller can read the image from the start.
        buffer.seek(0)
        return buffer

    plot.savefig(filename, format='png')
    return filename
30
-
31
- # Distribution Plot (Histogram)
32
def plot_histogram(df, column):
    """
    Draw a 20-bin histogram with a KDE overlay for one dataframe column.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to plot.

    Returns:
        The ``matplotlib.pyplot`` module with the histogram on a new figure.
    """
    plt.figure(figsize=(8, 6))
    hist_options = {"kde": True, "color": 'blue', "bins": 20}
    # seaborn draws on (and returns) the current axes of the figure above.
    axes = sns.histplot(df[column], **hist_options)
    axes.set_title(f"Distribution of {column}")
    axes.set_xlabel(column)
    axes.set_ylabel("Frequency")
    return plt
42
-
43
- # Box Plot (For Outliers)
44
def plot_box_plot(df, column):
    """
    Draw a horizontal box plot of one column to make outliers visible.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to plot.

    Returns:
        The ``matplotlib.pyplot`` module with the box plot on a new figure.
    """
    plt.figure(figsize=(8, 6))
    # seaborn draws on (and returns) the current axes of the figure above.
    axes = sns.boxplot(x=df[column], color='orange')
    axes.set_title(f"Box Plot of {column}")
    axes.set_xlabel(column)
    return plt
53
-
54
- # Pair Plot (For Visualizing Relationships Between Features)
55
def plot_pair_plot(df, hue='target'):
    """
    Plot a pair plot of the relationships between numeric columns.

    Args:
        df: DataFrame holding the data.
        hue: Name of the column used to colour the points. Defaults to
            ``'target'``. Colouring is skipped when it is ``None`` or the
            column is not present in ``df``.

    Returns:
        The seaborn grid object produced by ``sns.pairplot``, resized to
        10 x 8 inches.
    """
    numeric_df = df.select_dtypes(include='number')

    if hue is not None and hue in df.columns:
        if hue not in numeric_df.columns:
            # A non-numeric label column is removed by the numeric filter;
            # add it back so seaborn can colour by it.
            numeric_df = df[list(numeric_df.columns) + [hue]]
        pair_plot = sns.pairplot(numeric_df, hue=hue, palette='coolwarm')
    else:
        # No usable hue column: plot the plain grid instead of raising.
        pair_plot = sns.pairplot(numeric_df)

    # Prefer the newer `figure` attribute, falling back to `fig` where only
    # the older name is available.
    figure = getattr(pair_plot, 'figure', None)
    if figure is None:
        figure = pair_plot.fig
    figure.set_size_inches(10, 8)
    return pair_plot
63
-
64
- # Scatter Plot (For Visualizing Relationship Between Two Features)
65
def plot_scatter_plot(df, x_column, y_column):
    """
    Draw a scatter plot of one column against another.

    Args:
        df: DataFrame holding the data.
        x_column: Name of the column plotted on the x axis.
        y_column: Name of the column plotted on the y axis.

    Returns:
        The ``matplotlib.pyplot`` module with the scatter plot on a new figure.
    """
    plt.figure(figsize=(8, 6))
    x_values = df[x_column]
    y_values = df[y_column]
    # seaborn draws on (and returns) the current axes of the figure above.
    axes = sns.scatterplot(x=x_values, y=y_values, color='green')
    axes.set_title(f"Scatter Plot between {x_column} and {y_column}")
    axes.set_xlabel(x_column)
    axes.set_ylabel(y_column)
    return plt
75
-
76
- # Bar Plot (For Comparing Categorical Data)
77
def plot_bar_plot(df, column):
    """
    Draw a bar plot of value counts for a categorical column.

    Args:
        df: DataFrame holding the data.
        column: Name of the categorical column to count.

    Returns:
        The ``matplotlib.pyplot`` module with the bar plot on a new figure.
    """
    plt.figure(figsize=(8, 6))
    # seaborn draws on (and returns) the current axes of the figure above.
    axes = sns.countplot(x=df[column], palette='viridis')
    axes.set_title(f"Bar Plot of {column}")
    axes.set_xlabel(column)
    axes.set_ylabel("Count")
    return plt
 
1
import io

import pandas as pd
import streamlit as st

from utils.visualizations import plot_correlation_heatmap, save_plot_as_png

# NOTE(review): the diff header shows this script being committed as
# utils/visualizations.py, i.e. over the module it imports its helpers from.
# Confirm it is meant to live in the app entry point instead.

# File uploader
st.title("Model Training with Metrics and Correlation Heatmap")
uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])

if uploaded_file is not None:
    df = pd.read_csv(uploaded_file)

    # Show the dataset
    st.write("Dataset:")
    st.dataframe(df)

    # Clean data: Missing values, outliers, and extreme values (You can add the functions like handle_missing_values, etc.)
    # df = handle_missing_values(df)  # Un-comment when cleaning functions are added
    # df = remove_outliers_iqr(df)  # Un-comment when cleaning functions are added
    # df = cap_extreme_values(df)  # Un-comment when cleaning functions are added

    st.write("Cleaned Dataset (after applying any cleaning steps):")
    st.dataframe(df)

    # Add clean data download option
    st.subheader("Download Cleaned Dataset")
    st.download_button(
        label="Download Cleaned Dataset (CSV)",
        data=df.to_csv(index=False),
        file_name="cleaned_dataset.csv",
        mime="text/csv",
    )

    # Correlation Heatmap
    st.subheader("Correlation Heatmap")
    corr_plot = plot_correlation_heatmap(df)
    # Assumes the helper returns the pyplot module with the heatmap as the
    # current figure - TODO confirm. Grab the figure so it can be displayed,
    # saved and closed explicitly.
    heatmap_fig = corr_plot.gcf()
    st.pyplot(heatmap_fig)  # Display the heatmap in Streamlit

    # Render the PNG into memory: save_plot_as_png returns whatever target it
    # was given, so passing a buffer yields the image bytes rather than a
    # filename string, and nothing is written to disk on the server.
    heatmap_buf = save_plot_as_png(heatmap_fig, io.BytesIO())
    st.download_button(
        label="Download Correlation Heatmap as PNG",
        data=heatmap_buf.getvalue(),
        file_name="correlation_heatmap.png",
        mime="image/png",
    )
    # The script runs again on every interaction; close the figure so
    # figures do not accumulate across reruns.
    corr_plot.close(heatmap_fig)

    # Target and features selection
    target = st.selectbox("Select Target Variable", df.columns)
    features = [col for col in df.columns if col != target]
    X = df[features]
    y = df[target]

    # Assuming model training and evaluation functions (train_classification_model, etc.) are implemented and imported
    if y.dtype == 'object' or len(y.unique()) <= 10:  # Classification
        st.subheader("Classification Model Training")
        # Example: metrics_df = train_classification_model(X, y)
        # st.dataframe(metrics_df)
    else:  # Regression
        st.subheader("Regression Model Training")
        # Example: regression_metrics_df = train_regression_model(X, y)
        # st.dataframe(regression_metrics_df)