File size: 5,057 Bytes
0df7f5d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.figure import Figure
from pandas import DataFrame, Series
from src.theme import custom_palette
def plot_target_distribution(df: DataFrame) -> tuple[DataFrame, Figure]:
    """
    Plot the distribution of the 'TARGET' column in a DataFrame.

    Args:
        df (DataFrame): The input DataFrame containing the 'TARGET' column.

    Returns:
        DataFrame: A DataFrame indexed by 'TARGET' class with a 'Count' column
            and a 'Percentage' column (share of non-missing values, rounded
            to 2 decimals).
        Figure: The matplotlib Figure object containing the plot.
    """
    # Name the index explicitly: seaborn resolves x="TARGET" below through the
    # index name, and pandas < 2.0 leaves the value_counts index unnamed.
    target_counts = df["TARGET"].value_counts().rename_axis("TARGET")
    target_percent = (target_counts / target_counts.sum() * 100).round(2)

    # Combine into a DataFrame for clarity
    target_df = target_counts.to_frame(name="Count")
    target_df["Percentage"] = target_percent

    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(
        data=target_df,
        x="TARGET",
        y="Count",
        hue="TARGET",
        palette=custom_palette[:2],
        legend=False,  # hue only colours the bars; it duplicates the x axis
        ax=ax,
    )

    # Axis labels and formatting
    ax.set_xlabel("Payment Difficulties (1 = Yes, 0 = No)", fontsize=12)
    ax.set_ylabel("Count", fontsize=12)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    fig.tight_layout()
    return target_df, fig
def plot_credit_amounts(df: DataFrame) -> Figure:
    """
    Draw a 100-bin histogram of the 'AMT_CREDIT' column with a KDE overlay.

    Args:
        df (DataFrame): Data holding an 'AMT_CREDIT' column.

    Returns:
        Figure: The matplotlib figure the histogram was drawn on.
    """
    figure, axes = plt.subplots(figsize=(10, 6))
    sns.histplot(
        df,
        x="AMT_CREDIT",
        bins=100,
        kde=True,
        color=custom_palette[0],
        ax=axes,
    )
    # Horizontal dashed guides make bar heights easier to read.
    axes.grid(axis="y", linestyle="--", alpha=0.5)
    figure.tight_layout()
    return figure
def plot_education_levels(df: DataFrame) -> tuple[DataFrame, Figure]:
    """
    Plot a bar chart of education levels.

    Args:
        df (DataFrame): The DataFrame containing the 'NAME_EDUCATION_TYPE' column.

    Returns:
        DataFrame: Education levels (index) with their 'Count' and 'Percentage'
            (share of non-missing values, rounded to 2 decimals), sorted by
            count in descending order.
        Figure: The matplotlib figure object containing the plot.
    """
    # value_counts already sorts by count (descending) and drops missing values.
    education_count = df["NAME_EDUCATION_TYPE"].value_counts()
    # Divide by the counted total rather than df.shape[0] so the percentages
    # cover exactly the rows that were counted and sum to 100.
    education_percentage = (education_count / education_count.sum() * 100).round(2)

    education_df = education_count.to_frame(name="Count")
    education_df["Percentage"] = education_percentage

    # Draw bars in the same descending order as the returned table.
    level_order = education_count.index.tolist()

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.countplot(
        data=df,
        y="NAME_EDUCATION_TYPE",
        hue="NAME_EDUCATION_TYPE",
        order=level_order,
        hue_order=level_order,
        palette=custom_palette[: len(level_order)],
        legend=False,  # hue duplicates the y axis; it only colours the bars
        ax=ax,
    )
    ax.set_xlabel("Count")
    ax.set_ylabel("Education Level")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()
    return education_df, fig
def plot_occupation(
    df: DataFrame, *, missing_label: str = "Missing"
) -> tuple[Series, Figure]:
    """
    Plot the distribution of occupations in the dataset.

    Missing occupations are counted (``dropna=False``) and drawn as their own
    bar labelled ``missing_label``. Without this relabelling, seaborn silently
    drops NaN categories from the chart.

    Args:
        df (DataFrame): The DataFrame containing the 'OCCUPATION_TYPE' column.
        missing_label (str): Y-axis label used for the missing-value bar.
            Pick a value that is not already a real occupation.

    Returns:
        Series: Count of each occupation (descending), with missing values
            kept under a NaN index entry; ``missing_label`` only affects the plot.
        Figure: A Matplotlib Figure object containing the plot.
    """
    occupation_df = df["OCCUPATION_TYPE"].value_counts(dropna=False, ascending=False)

    # Relabel NaN for plotting only; seaborn filters out null categories.
    labels = occupation_df.index.astype(object).fillna(missing_label)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        x=occupation_df.values,
        y=labels,
        hue=labels,
        legend=False,
        ax=ax,
    )
    ax.set_xlabel("Number of Applicants")
    ax.set_ylabel("Occupation")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()
    return occupation_df, fig
def plot_family_status(
    df: DataFrame, *, missing_label: str = "Missing"
) -> tuple[Series, Figure]:
    """
    Plot the distribution of family statuses in the dataset.

    Missing statuses are counted (``dropna=False``) and drawn as their own bar
    labelled ``missing_label``. Without this relabelling, seaborn silently
    drops NaN categories from the chart.

    Args:
        df (DataFrame): The DataFrame containing the 'NAME_FAMILY_STATUS' column.
        missing_label (str): Y-axis label used for the missing-value bar.
            Pick a value that is not already a real status.

    Returns:
        Series: Count of each family status (descending), with missing values
            kept under a NaN index entry; ``missing_label`` only affects the plot.
        Figure: A Matplotlib Figure object containing the plot.
    """
    family_status_df = df["NAME_FAMILY_STATUS"].value_counts(
        dropna=False, ascending=False
    )

    # Relabel NaN for plotting only; seaborn filters out null categories.
    labels = family_status_df.index.astype(object).fillna(missing_label)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        x=family_status_df.values,
        y=labels,
        hue=labels,
        # One colour per plotted level (including a missing bar, if any).
        palette=custom_palette[: len(labels)],
        legend=False,
        ax=ax,
    )
    ax.set_xlabel("Number of Applicants")
    ax.set_ylabel("Family Status")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()
    return family_status_df, fig
def plot_income_type(df: DataFrame) -> Figure:
    """
    Draw horizontal count bars of 'NAME_INCOME_TYPE', grouped by 'TARGET'.

    Args:
        df (DataFrame): Data holding 'NAME_INCOME_TYPE' and 'TARGET' columns.

    Returns:
        Figure: The Matplotlib figure with the grouped count plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    target_colours = custom_palette[:2]
    sns.countplot(
        df,
        y="NAME_INCOME_TYPE",
        hue="TARGET",
        palette=target_colours,
        ax=ax,
    )
    ax.set(xlabel="Number of Applicants", ylabel="Income Type")
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    fig.tight_layout()
    return fig
|