Depression / utils /visualizations.py
saherPervaiz's picture
Update utils/visualizations.py
50d5332 verified
import seaborn as sns
import matplotlib.pyplot as plt
def plot_correlation_heatmap(df):
    """
    Plot an annotated correlation heatmap for the numeric columns of ``df``.

    Non-numeric columns are skipped when computing the correlation matrix.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; may mix numeric and non-numeric columns.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on (the value returned by
        ``sns.heatmap``), inside a new 10x8 inch figure.

    Raises
    ------
    ValueError
        If ``df`` has no numeric columns to correlate.
    """
    # numeric_only=True keeps only numeric columns; without it pandas >= 2.0
    # raises on string/categorical columns instead of ignoring them.
    corr = df.corr(numeric_only=True)
    if corr.empty:
        # Fail with a clear message before creating a figure that would
        # otherwise be left behind in pyplot's figure registry.
        raise ValueError("plot_correlation_heatmap: dataframe has no numeric columns")

    _, ax = plt.subplots(figsize=(10, 8))
    heatmap = sns.heatmap(
        corr, annot=True, cmap="coolwarm", fmt=".2f", linewidths=0.5, ax=ax
    )
    ax.set_title("Correlation Heatmap")
    return heatmap
def plot_histogram(df, column):
    """
    Draw the distribution of ``df[column]`` as a 30-bin histogram with a KDE curve.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    column : hashable
        Label of the column to plot; also used as the x-axis label.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created 8x6 inch figure containing the histogram.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    values = df[column]
    sns.histplot(values, kde=True, bins=30, color="skyblue", ax=ax)
    # Override seaborn's default axis labels with the column name and "Frequency".
    ax.set(title=f"Histogram of {column}", xlabel=column, ylabel="Frequency")
    return fig
def plot_box_plot(df, column):
    """
    Draw a horizontal box plot of ``df[column]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    column : hashable
        Label of the column to summarise.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created 8x6 inch figure containing the box plot.
    """
    figure, axes = plt.subplots(figsize=(8, 6))
    sns.boxplot(x=df[column], ax=axes)
    axes.set_title(f"Box Plot of {column}")
    return figure
def plot_pair_plot(df):
    """
    Plot a seaborn pair plot over every numeric column of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; non-numeric columns are ignored.

    Returns
    -------
    seaborn.PairGrid
        The grid object produced by ``sns.pairplot``.

    Raises
    ------
    ValueError
        If ``df`` has no numeric columns.
    """
    # "number" matches all numeric dtypes (e.g. int32, float32, uint8),
    # not only float64/int64, so narrower numeric columns are not dropped.
    numeric_df = df.select_dtypes(include="number")
    if numeric_df.columns.empty:
        raise ValueError("plot_pair_plot: dataframe has no numeric columns")
    return sns.pairplot(numeric_df)
def plot_scatter_plot(df, x_col, y_col):
    """
    Draw a green scatter plot of ``df[y_col]`` against ``df[x_col]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    x_col, y_col : hashable
        Labels of the columns used for the x and y axes.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created 8x6 inch figure containing the scatter plot.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    x_values, y_values = df[x_col], df[y_col]
    sns.scatterplot(x=x_values, y=y_values, color="green", ax=ax)
    ax.set_title(f"Scatter Plot between {x_col} and {y_col}")
    return fig
def plot_bar_plot(df, column):
    """
    Draw a bar chart of how often each value appears in ``df[column]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    column : hashable
        Label of the (typically categorical) column to count.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created 8x6 inch figure containing the count plot.
    """
    figure, axes = plt.subplots(figsize=(8, 6))
    sns.countplot(x=df[column], ax=axes)
    axes.set_title(f"Bar Plot of {column}")
    return figure