# Source: E-Commerce_ELT / src/plots.py
# Last commit: 4523429 - "refactor: Add theme and improve plots" (iBrokeTheCode)
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from matplotlib import rc_file_defaults
from matplotlib.figure import Figure
from pandas import DataFrame, to_datetime
from src.utils.theme import apply_custom_palette, custom_palette
def plot_revenue_by_month_year(df: DataFrame, year: int) -> Figure:
    """
    Generate a matplotlib figure showing monthly revenue for a given year.

    The ``Year{year}`` column is drawn twice: as a line with markers on the
    primary axis and as semi-transparent bars on a twin y-axis. The integer
    value of every non-missing month is annotated just above its bar.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'month': Month label, used as the bar categories
            - f'Year{year}': Revenue per month for the requested year
        year (int): Year used to select the ``Year{year}`` column.

    Returns:
        Figure: A matplotlib figure combining the line and bar plots.
    """
    # Local import: the module-level pandas import does not expose `isna`.
    from pandas import isna

    revenue_col = f"Year{year}"

    # Apply the project theme, then layer the seaborn grid style on top
    apply_custom_palette()
    sns.set_style("whitegrid")

    fig, ax1 = plt.subplots(figsize=(12, 4))

    # seaborn plots a bare Series against its index, while the bars below are
    # placed at positions 0..n-1. Dropping the index keeps both layers aligned
    # even when `df` does not carry a default RangeIndex.
    revenue = df[revenue_col].reset_index(drop=True)

    # Line plot for revenue trend
    sns.lineplot(
        data=revenue,
        marker="o",
        sort=False,
        linewidth=2,
        ax=ax1,
        label=f"Line: Revenue {year}",
    )

    # Bar plot with light transparency on a twin y-axis
    ax2 = ax1.twinx()
    sns.barplot(
        data=df,
        x="month",
        y=revenue_col,
        alpha=0.4,
        ax=ax2,
        label=f"Bar: Revenue {year}",
    )

    # Beautify axes
    ax1.set_ylabel("Revenue")
    ax1.set_xlabel("Month")
    ax1.grid(True, linestyle="--", alpha=0.5)

    # Value annotations on bars
    for position, value in enumerate(revenue):
        if isna(value):
            # int(NaN) raises ValueError; a missing month gets no annotation
            continue
        ax2.text(
            position,
            value + value * 0.02,  # small offset above bar
            f"{int(value):,}",
            ha="center",
            va="bottom",
            fontsize=8,
            color="black",
        )

    # Keep the title empty (titles are handled in Marimo)
    ax1.set_title("")
    fig.tight_layout()
    return fig
def plot_real_vs_predicted_delivered_time(df: DataFrame, year: int) -> Figure:
    """
    Draw real and estimated delivery time per month as two lines.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'month': Month label used on the x-axis
            - f'Year{year}_real_time': Real delivery time per month
            - f'Year{year}_estimated_time': Estimated delivery time per month
        year (int): Year used to select the two ``Year{year}_*`` columns.

    Returns:
        Figure: A matplotlib figure with both lines on a single axis.
    """
    # Start from matplotlib's rc-file defaults, then use a light grid
    rc_file_defaults()
    sns.set_style("whitegrid")

    fig, ax = plt.subplots(figsize=(12, 4))
    months = df["month"]

    # (column, marker, legend label, palette index) for each line, in draw order
    line_specs = (
        (f"Year{year}_real_time", "o", "Real Time", 0),
        (f"Year{year}_estimated_time", "s", "Estimated Time", 1),
    )
    for column, marker_shape, legend_label, palette_idx in line_specs:
        sns.lineplot(
            x=months,
            y=df[column],
            marker=marker_shape,
            label=legend_label,
            color=custom_palette[palette_idx],
            ax=ax,
        )

    # Axis labels, then one rotated tick per row of the frame
    ax.set_xlabel("Month")
    ax.set_ylabel("Average Days to Deliver")
    tick_positions = range(len(df))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(months.values, rotation=45)

    # Untitled legend in the top-right corner
    ax.legend(title="", loc="upper right")

    fig.tight_layout()
    return fig
def plot_global_amount_order_status(df: DataFrame) -> Figure:
    """
    Create a horizontal bar chart showing the global amount per order status.

    Bars are sorted by amount (smallest at the bottom) and labelled with the
    last word of each status, capitalized. Each bar carries its integer value
    as a text label placed just past the bar end.

    Args:
        df (DataFrame): DataFrame with:
            - 'order_status': Status labels (e.g., 'order delivered')
            - 'Amount': Count or value per status

    Returns:
        Figure: A matplotlib bar chart figure.
    """
    rc_file_defaults()
    fig, ax = plt.subplots(figsize=(10, 5))

    # Work on a copy so the caller's frame does not gain the helper column
    df = df.copy()
    df["short_status"] = df["order_status"].apply(lambda x: x.split()[-1].capitalize())
    sorted_df = df.sort_values("Amount", ascending=True)

    colors = custom_palette[: len(sorted_df)]
    bars = ax.barh(
        sorted_df["short_status"], sorted_df["Amount"], color=colors, edgecolor="black"
    )

    # Add value labels. `bar_label` pads in points rather than data units, so
    # the gap between bar and label stays the same whatever the scale of
    # 'Amount' is (a fixed data-unit offset vanishes on large values and
    # overshoots on small ones).
    ax.bar_label(
        bars,
        labels=[f"{int(bar.get_width()):,}" for bar in bars],
        padding=3,
        fontsize=9,
        color="black",
    )

    ax.set_xlabel("Amount")
    ax.set_ylabel("Order Status")
    ax.grid(axis="x", linestyle="--", alpha=0.4)
    fig.tight_layout()
    return fig
def plot_revenue_per_state(df: DataFrame) -> go.Figure:
    """
    Build a Plotly treemap of revenue per customer state.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'customer_state': State label, used as the treemap path and
              as the color key
            - 'Revenue': Value that sizes each tile

    Returns:
        go.Figure: A Plotly treemap figure object.
    """
    treemap_options = {
        "path": ["customer_state"],
        "values": "Revenue",
        # Coloring by the state column is what triggers the discrete mapping
        "color": "customer_state",
        "color_discrete_sequence": custom_palette,
        "width": 800,
        "height": 300,
    }
    fig = px.treemap(df, **treemap_options)

    # Tile text: label, percentage and raw value; white borders between tiles
    tile_border = {"color": "#FFFFFF", "width": 1}
    fig.update_traces(
        textinfo="label+percent entry+value",
        textfont_size=14,
        marker={"line": tile_border},
    )

    fig.update_layout(
        margin={"t": 20, "l": 20, "r": 20, "b": 20},
        uniformtext={"minsize": 12, "mode": "hide"},
    )
    return fig
def plot_top_10_least_revenue_categories(df: DataFrame) -> Figure:
    """
    Create a horizontal bar chart showing the top 10 least revenue categories.

    Bars are sorted by revenue (smallest at the bottom) and each one carries
    its integer dollar value as a text label placed just past the bar end.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'Category': Category name
            - 'Revenue': Corresponding revenue values

    Returns:
        Figure: A matplotlib figure with a horizontal bar chart.
    """
    rc_file_defaults()
    fig, ax = plt.subplots(figsize=(10, 6))

    # Sort and plot
    sorted_df = df.sort_values("Revenue", ascending=True)
    colors = custom_palette[: len(sorted_df)]
    bars = ax.barh(
        sorted_df["Category"], sorted_df["Revenue"], color=colors, edgecolor="black"
    )

    # Add value labels. `bar_label` pads in points rather than data units, so
    # the gap between bar and label does not depend on how large the revenue
    # values are (a fixed data-unit offset is huge for small revenues and
    # invisible for large ones).
    ax.bar_label(
        bars,
        labels=[f"${int(bar.get_width()):,}" for bar in bars],
        padding=3,
        fontsize=9,
        color="black",
    )

    ax.set_xlabel("Revenue")
    ax.set_ylabel("Category")
    ax.grid(axis="x", linestyle="--", alpha=0.4)
    fig.tight_layout()
    return fig
def plot_top_10_revenue_categories_amount(df: DataFrame) -> Figure:
    """
    Create a horizontal bar chart showing the revenue of the top 10 categories.

    Bars are sorted by revenue (smallest at the bottom) and each one carries
    its integer dollar value as a text label placed just past the bar end.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'Category': Category name
            - 'Revenue': Revenue amount

    Returns:
        Figure: A matplotlib figure object.
    """
    rc_file_defaults()
    fig, ax = plt.subplots(figsize=(10, 6))

    sorted_df = df.sort_values("Revenue", ascending=True)
    colors = custom_palette[: len(sorted_df)]
    bars = ax.barh(
        sorted_df["Category"], sorted_df["Revenue"], color=colors, edgecolor="black"
    )

    # Add value labels on the right. `bar_label` pads in points rather than
    # data units, so the label keeps a visible gap from the bar even when the
    # revenue values are several orders of magnitude above a fixed offset.
    ax.bar_label(
        bars,
        labels=[f"${int(bar.get_width()):,}" for bar in bars],
        padding=3,
        fontsize=9,
        color="black",
    )

    ax.set_xlabel("Revenue")
    ax.set_ylabel("Category")
    ax.grid(axis="x", linestyle="--", alpha=0.4)
    fig.tight_layout()
    return fig
def plot_top_10_revenue_categories(df: DataFrame) -> go.Figure:
    """
    Build a Plotly treemap of order counts for the top 10 revenue categories.

    Tiles are sized and colored by 'Num_order'; the color bar is hidden.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'Category': Category name
            - 'Num_order': Number of orders per category

    Returns:
        go.Figure: A Plotly treemap figure object.
    """
    order_count_col = "Num_order"

    treemap_options = {
        "path": ["Category"],
        "values": order_count_col,
        "color": order_count_col,
        # Reuse the project palette as the continuous scale for consistency
        "color_continuous_scale": custom_palette,
        # Thousands separators in the hover text
        "hover_data": {order_count_col: ":,"},
        "width": 800,
        "height": 400,
    }
    fig = px.treemap(df, **treemap_options)

    fig.update_layout(
        margin={"t": 40, "l": 30, "r": 30, "b": 30},
        coloraxis_showscale=False,  # hide the color bar
    )
    return fig
def plot_freight_value_weight_relationship(df: DataFrame) -> Figure:
    """
    Draw a scatter plot of freight value against product weight.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'product_weight_g': Weight of the product in grams
            - 'freight_value': Freight value in dollars

    Returns:
        Figure: A matplotlib figure object.
    """
    rc_file_defaults()
    fig, ax = plt.subplots(figsize=(10, 5))

    # Marker appearance shared by every point
    point_style = {
        "color": custom_palette[2],
        "edgecolor": "white",
        "alpha": 0.7,
        "s": 50,
    }
    sns.scatterplot(
        data=df,
        x="product_weight_g",
        y="freight_value",
        ax=ax,
        **point_style,
    )

    ax.set_xlabel("Product Weight (grams)")
    ax.set_ylabel("Freight Value ($)")
    ax.grid(True, linestyle="--", alpha=0.5)

    fig.tight_layout()
    return fig
def plot_delivery_date_difference(df: DataFrame) -> Figure:
    """
    Draw a horizontal bar plot of the delivery date difference per state.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'Delivery_Difference': Difference in days
            - 'State': Destination state

    Returns:
        Figure: A matplotlib figure object.
    """
    rc_file_defaults()
    fig, ax = plt.subplots(figsize=(10, 6))

    # Single-color bars: states on the y-axis, difference on the x-axis
    bar_color = custom_palette[0]
    sns.barplot(
        data=df,
        x="Delivery_Difference",
        y="State",
        color=bar_color,
        ax=ax,
    )

    chart_title = "Difference Between Estimated and Actual Delivery Dates by State"
    ax.set_title(chart_title, fontsize=12, weight="bold")
    ax.set_xlabel("Delivery Difference (Days)")
    ax.set_ylabel("State")

    # Dashed vertical grid lines only
    ax.grid(True, linestyle="--", alpha=0.4, axis="x")

    fig.tight_layout()
    return fig
def plot_order_amount_per_day_with_holidays(df: DataFrame) -> Figure:
    """
    Draw the daily order count as a line, with a dashed vertical line on
    every holiday.

    Args:
        df (DataFrame): DataFrame with columns:
            - 'date': Timestamp in milliseconds
            - 'order_count': Number of orders on that date
            - 'holiday': Boolean indicating if the date is a holiday

    Returns:
        Figure: A matplotlib figure object.
    """
    rc_file_defaults()

    # Convert and sort on a copy so the caller's frame is left untouched
    daily = df.copy()
    daily["date"] = to_datetime(daily["date"], unit="ms")
    daily = daily.sort_values("date")

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(
        daily["date"],
        daily["order_count"],
        color=custom_palette[2],
        label="Order Count",
    )

    # One dashed marker per holiday row; every marker shares the same label
    holiday_rows = daily[daily["holiday"]]
    for holiday_date in holiday_rows["date"]:
        ax.axvline(
            holiday_date,
            color=custom_palette[3],
            linestyle="--",
            alpha=0.4,
            label="Holiday",
        )

    ax.set_xlabel("Date")
    ax.set_ylabel("Order Count")

    # Monthly ticks formatted as e.g. "Jan 2018", rotated for readability
    ax.xaxis.set_major_locator(mdates.MonthLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %Y"))
    ax.tick_params(axis="x", rotation=45)
    ax.grid(True, linestyle="--", alpha=0.5)

    # Collapse the repeated "Holiday" handles into a single legend entry
    legend_entries = {}
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries[label] = handle
    ax.legend(legend_entries.values(), legend_entries.keys(), loc="upper left")

    fig.tight_layout()
    return fig