saherPervaiz committed on
Commit
50d5332
·
verified ·
1 Parent(s): 4ccd84b

Update utils/visualizations.py

Browse files
Files changed (1) hide show
  1. utils/visualizations.py +21 -91
utils/visualizations.py CHANGED
@@ -1,127 +1,57 @@
1
  import seaborn as sns
2
  import matplotlib.pyplot as plt
3
- import pandas as pd
4
- import streamlit as st
5
 
6
- # Correlation Heatmap
7
  def plot_correlation_heatmap(df):
8
  """
9
- Plot a correlation heatmap for numeric columns in the dataframe.
10
  """
11
- # Select only numeric columns
12
- numeric_df = df.select_dtypes(include=['float64', 'int64'])
13
-
14
- # Compute the correlation matrix
15
- corr = numeric_df.corr()
16
-
17
- # Plot the heatmap
18
  plt.figure(figsize=(10, 8))
19
- sns.heatmap(corr, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5)
20
  plt.title("Correlation Heatmap")
21
-
22
- return plt
23
-
24
- # Save plot as PNG image
25
- def save_plot_as_png(plot, filename="plot.png"):
26
- """
27
- Save a given plot as a PNG file.
28
- """
29
- plot.savefig(filename, format='png')
30
- return filename
31
 
32
- # Distribution Plot (Histogram)
33
  def plot_histogram(df, column):
34
  """
35
- Plot a histogram for a specified column in the dataframe.
36
  """
37
  plt.figure(figsize=(8, 6))
38
- sns.histplot(df[column], kde=True, color='blue', bins=20)
39
- plt.title(f"Distribution of {column}")
40
  plt.xlabel(column)
41
  plt.ylabel("Frequency")
42
- return plt
43
 
44
- # Box Plot (For Outliers)
45
  def plot_box_plot(df, column):
46
  """
47
- Plot a box plot for a specified column to visualize outliers.
48
  """
49
  plt.figure(figsize=(8, 6))
50
- sns.boxplot(x=df[column], color='orange')
51
  plt.title(f"Box Plot of {column}")
52
- plt.xlabel(column)
53
- return plt
54
 
55
- # Pair Plot (For Visualizing Relationships Between Features)
56
  def plot_pair_plot(df):
57
  """
58
- Plot a pair plot to visualize relationships between numeric columns in the dataframe.
59
  """
60
- numeric_df = df.select_dtypes(include=['float64', 'int64'])
61
- pair_plot = sns.pairplot(numeric_df)
62
- pair_plot.fig.set_size_inches(10, 8)
63
- return pair_plot
64
 
65
- # Scatter Plot (For Visualizing Relationship Between Two Features)
66
- def plot_scatter_plot(df, x_column, y_column):
67
  """
68
- Plot a scatter plot to visualize the relationship between two features.
69
  """
70
  plt.figure(figsize=(8, 6))
71
- sns.scatterplot(x=df[x_column], y=df[y_column], color='green')
72
- plt.title(f"Scatter Plot between {x_column} and {y_column}")
73
- plt.xlabel(x_column)
74
- plt.ylabel(y_column)
75
- return plt
76
 
77
- # Bar Plot (For Comparing Categorical Data)
78
  def plot_bar_plot(df, column):
79
  """
80
  Plot a bar plot for a categorical column.
81
  """
82
  plt.figure(figsize=(8, 6))
83
- sns.countplot(x=df[column], palette='viridis')
84
  plt.title(f"Bar Plot of {column}")
85
- plt.xlabel(column)
86
- plt.ylabel("Count")
87
- return plt
88
-
89
- st.subheader("Correlation Heatmap")
90
- if st.button("Generate Correlation Heatmap"):
91
- heatmap_plot = plot_correlation_heatmap(df)
92
- st.pyplot(heatmap_plot)
93
-
94
- st.subheader("Histogram")
95
- selected_column_hist = st.selectbox("Select Column for Histogram", df.select_dtypes(include=['float64', 'int64']).columns)
96
- if st.button("Generate Histogram"):
97
- hist_plot = plot_histogram(df, selected_column_hist)
98
- st.pyplot(hist_plot)
99
-
100
- st.subheader("Box Plot")
101
- selected_column_box = st.selectbox("Select Column for Box Plot", df.select_dtypes(include=['float64', 'int64']).columns, key="box")
102
- if st.button("Generate Box Plot"):
103
- box_plot = plot_box_plot(df, selected_column_box)
104
- st.pyplot(box_plot)
105
-
106
- st.subheader("Pair Plot")
107
- if st.button("Generate Pair Plot"):
108
- pair_plot = plot_pair_plot(df)
109
- st.pyplot(pair_plot)
110
-
111
- st.subheader("Scatter Plot")
112
- numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns
113
- x_column = st.selectbox("Select X-Axis Column", numeric_columns, key="scatter_x")
114
- y_column = st.selectbox("Select Y-Axis Column", numeric_columns, key="scatter_y")
115
- if st.button("Generate Scatter Plot"):
116
- scatter_plot = plot_scatter_plot(df, x_column, y_column)
117
- st.pyplot(scatter_plot)
118
-
119
- st.subheader("Bar Plot")
120
- categorical_columns = df.select_dtypes(include=['object']).columns
121
- if not categorical_columns.empty:
122
- selected_column_bar = st.selectbox("Select Column for Bar Plot", categorical_columns)
123
- if st.button("Generate Bar Plot"):
124
- bar_plot = plot_bar_plot(df, selected_column_bar)
125
- st.pyplot(bar_plot)
126
- else:
127
- st.info("No categorical columns available for bar plot.")
 
1
  import seaborn as sns
2
  import matplotlib.pyplot as plt
 
 
3
 
 
4
  def plot_correlation_heatmap(df):
5
  """
6
+ Plot a correlation heatmap for the numeric columns in the dataframe.
7
  """
8
+ corr = df.corr()
 
 
 
 
 
 
9
  plt.figure(figsize=(10, 8))
10
+ heatmap = sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5)
11
  plt.title("Correlation Heatmap")
12
+ return heatmap
 
 
 
 
 
 
 
 
 
13
 
 
14
  def plot_histogram(df, column):
15
  """
16
+ Plot a histogram for a specific column in the dataframe.
17
  """
18
  plt.figure(figsize=(8, 6))
19
+ sns.histplot(df[column], kde=True, bins=30, color="skyblue")
20
+ plt.title(f"Histogram of {column}")
21
  plt.xlabel(column)
22
  plt.ylabel("Frequency")
23
+ return plt.gcf()
24
 
 
25
  def plot_box_plot(df, column):
26
  """
27
+ Plot a box plot for a specific column in the dataframe.
28
  """
29
  plt.figure(figsize=(8, 6))
30
+ sns.boxplot(x=df[column])
31
  plt.title(f"Box Plot of {column}")
32
+ return plt.gcf()
 
33
 
 
34
  def plot_pair_plot(df):
35
  """
36
+ Plot a pair plot for numeric columns in the dataframe.
37
  """
38
+ numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns
39
+ return sns.pairplot(df[numeric_columns])
 
 
40
 
41
def plot_scatter_plot(df, x_col, y_col):
    """
    Draw a green scatter plot of one dataframe column against another.

    Args:
        df: DataFrame holding the data.
        x_col: Label of the column plotted on the x-axis.
        y_col: Label of the column plotted on the y-axis.

    Returns:
        The matplotlib figure the scatter plot was drawn on.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    x_values = df[x_col]
    y_values = df[y_col]
    sns.scatterplot(x=x_values, y=y_values, color="green", ax=ax)
    ax.set_title(f"Scatter Plot between {x_col} and {y_col}")
    return fig
 
 
49
 
 
50
  def plot_bar_plot(df, column):
51
  """
52
  Plot a bar plot for a categorical column.
53
  """
54
  plt.figure(figsize=(8, 6))
55
+ sns.countplot(x=df[column])
56
  plt.title(f"Bar Plot of {column}")
57
+ return plt.gcf()