Spaces:
Sleeping
Sleeping
| import matplotlib.dates as mdates | |
| import matplotlib.pyplot as plt | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import seaborn as sns | |
| from matplotlib import rc_file_defaults | |
| from matplotlib.figure import Figure | |
| from pandas import DataFrame, to_datetime | |
| from src.utils.theme import apply_custom_palette, custom_palette | |
| def plot_revenue_by_month_year(df: DataFrame, year: int) -> Figure: | |
| """ | |
| Generate a matplotlib figure showing monthly revenue for a given year, | |
| using consistent color styling. | |
| """ | |
| # Set the theme | |
| apply_custom_palette() | |
| # Clear any previous settings and set seaborn style | |
| sns.set_style("whitegrid") | |
| fig, ax1 = plt.subplots(figsize=(12, 4)) | |
| # Line plot for revenue trend | |
| sns.lineplot( | |
| data=df[f"Year{year}"], | |
| marker="o", | |
| sort=False, | |
| linewidth=2, | |
| ax=ax1, | |
| label=f"Line: Revenue {year}", | |
| ) | |
| # Bar plot with light transparency | |
| ax2 = ax1.twinx() | |
| sns.barplot( | |
| data=df, | |
| x="month", | |
| y=f"Year{year}", | |
| alpha=0.4, | |
| ax=ax2, | |
| label=f"Bar: Revenue {year}", | |
| ) | |
| # Beautify axes | |
| ax1.set_ylabel("Revenue") | |
| ax1.set_xlabel("Month") | |
| ax1.grid(True, linestyle="--", alpha=0.5) | |
| # Optional: display value annotations on bars | |
| for i, value in enumerate(df[f"Year{year}"]): | |
| ax2.text( | |
| i, | |
| value + value * 0.02, # small offset above bar | |
| f"{int(value):,}", | |
| ha="center", | |
| va="bottom", | |
| fontsize=8, | |
| color="black", | |
| ) | |
| # Remove default plot title (you handle titles in Marimo) | |
| ax1.set_title("") | |
| fig.tight_layout() | |
| return fig | |
| def plot_real_vs_predicted_delivered_time(df: DataFrame, year: int) -> Figure: | |
| """ | |
| Create a line plot comparing real vs. estimated delivery time | |
| by month for a given year. | |
| """ | |
| rc_file_defaults() | |
| sns.set_style("whitegrid") # Use light grid for clarity | |
| fig, ax = plt.subplots(figsize=(12, 4)) | |
| # Plot each line with explicit color and label | |
| sns.lineplot( | |
| x=df["month"], | |
| y=df[f"Year{year}_real_time"], | |
| marker="o", | |
| label="Real Time", | |
| color=custom_palette[0], | |
| ax=ax, | |
| ) | |
| sns.lineplot( | |
| x=df["month"], | |
| y=df[f"Year{year}_estimated_time"], | |
| marker="s", | |
| label="Estimated Time", | |
| color=custom_palette[1], | |
| ax=ax, | |
| ) | |
| # Axis labeling and ticks | |
| ax.set_xlabel("Month") | |
| ax.set_ylabel("Average Days to Deliver") | |
| ax.set_xticks(range(len(df))) | |
| ax.set_xticklabels(df["month"].values, rotation=45) | |
| # Legend configuration | |
| ax.legend(title="", loc="upper right") | |
| # Improve spacing | |
| fig.tight_layout() | |
| return fig | |
| def plot_global_amount_order_status(df: DataFrame) -> Figure: | |
| """ | |
| Create a horizontal bar chart showing the global amount per order status. | |
| Args: | |
| df (DataFrame): DataFrame with: | |
| - 'order_status': Status labels (e.g., 'order delivered') | |
| - 'Amount': Count or value per status | |
| Returns: | |
| Figure: A matplotlib bar chart figure. | |
| """ | |
| rc_file_defaults() | |
| fig, ax = plt.subplots(figsize=(10, 5)) | |
| df = df.copy() | |
| df["short_status"] = df["order_status"].apply(lambda x: x.split()[-1].capitalize()) | |
| sorted_df = df.sort_values("Amount", ascending=True) | |
| colors = custom_palette[: len(sorted_df)] | |
| bars = ax.barh( | |
| sorted_df["short_status"], sorted_df["Amount"], color=colors, edgecolor="black" | |
| ) | |
| # Add value labels | |
| for bar in bars: | |
| width = bar.get_width() | |
| ax.text( | |
| width + 50, | |
| bar.get_y() + bar.get_height() / 2, | |
| f"{int(width):,}", | |
| va="center", | |
| fontsize=9, | |
| color="black", | |
| ) | |
| ax.set_xlabel("Amount") | |
| ax.set_ylabel("Order Status") | |
| ax.grid(axis="x", linestyle="--", alpha=0.4) | |
| fig.tight_layout() | |
| return fig | |
| def plot_revenue_per_state(df: DataFrame) -> go.Figure: | |
| """ | |
| Create a Plotly treemap to visualize revenue per customer state, | |
| using a consistent custom color palette. | |
| """ | |
| fig = px.treemap( | |
| df, | |
| path=["customer_state"], | |
| values="Revenue", | |
| color="customer_state", # Important to trigger color mapping | |
| color_discrete_sequence=custom_palette, | |
| width=800, | |
| height=300, | |
| ) | |
| # Add label customization | |
| fig.update_traces( | |
| textinfo="label+percent entry+value", # show label, percentage, and raw value | |
| textfont_size=14, | |
| marker=dict( | |
| line=dict(color="#FFFFFF", width=1) | |
| ), # white borders between blocks | |
| ) | |
| fig.update_layout( | |
| margin=dict(t=20, l=20, r=20, b=20), | |
| uniformtext=dict(minsize=12, mode="hide"), | |
| ) | |
| return fig | |
| def plot_top_10_least_revenue_categories(df: DataFrame) -> Figure: | |
| """ | |
| Create a horizontal bar chart showing the top 10 least revenue categories. | |
| Args: | |
| df (DataFrame): DataFrame with columns: | |
| - 'Category': Category name | |
| - 'Revenue': Corresponding revenue values | |
| Returns: | |
| Figure: A matplotlib figure with a horizontal bar chart. | |
| """ | |
| rc_file_defaults() | |
| fig, ax = plt.subplots(figsize=(10, 6)) | |
| # Sort and plot | |
| sorted_df = df.sort_values("Revenue", ascending=True) | |
| colors = custom_palette[: len(sorted_df)] | |
| bars = ax.barh( | |
| sorted_df["Category"], sorted_df["Revenue"], color=colors, edgecolor="black" | |
| ) | |
| # Add value labels | |
| for bar in bars: | |
| width = bar.get_width() | |
| ax.text( | |
| width + 100, # shift label to the right of the bar | |
| bar.get_y() + bar.get_height() / 2, | |
| f"${int(width):,}", | |
| va="center", | |
| fontsize=9, | |
| color="black", | |
| ) | |
| ax.set_xlabel("Revenue") | |
| ax.set_ylabel("Category") | |
| ax.grid(axis="x", linestyle="--", alpha=0.4) | |
| fig.tight_layout() | |
| return fig | |
| def plot_top_10_revenue_categories_amount(df: DataFrame) -> Figure: | |
| """ | |
| Create a horizontal bar chart showing the revenue of the top 10 categories. | |
| Args: | |
| df (DataFrame): DataFrame with columns: | |
| - 'Category': Category name | |
| - 'Revenue': Revenue amount | |
| Returns: | |
| Figure: A matplotlib figure object. | |
| """ | |
| rc_file_defaults() | |
| fig, ax = plt.subplots(figsize=(10, 6)) | |
| sorted_df = df.sort_values("Revenue", ascending=True) | |
| colors = custom_palette[: len(sorted_df)] | |
| bars = ax.barh( | |
| sorted_df["Category"], sorted_df["Revenue"], color=colors, edgecolor="black" | |
| ) | |
| # Add value labels on the right | |
| for bar in bars: | |
| width = bar.get_width() | |
| ax.text( | |
| width + 100, | |
| bar.get_y() + bar.get_height() / 2, | |
| f"${int(width):,}", | |
| va="center", | |
| fontsize=9, | |
| color="black", | |
| ) | |
| ax.set_xlabel("Revenue") | |
| ax.set_ylabel("Category") | |
| ax.grid(axis="x", linestyle="--", alpha=0.4) | |
| fig.tight_layout() | |
| return fig | |
| def plot_top_10_revenue_categories(df: DataFrame) -> go.Figure: | |
| """ | |
| Create a Plotly treemap showing the number of orders for the top 10 revenue categories. | |
| Args: | |
| df (DataFrame): DataFrame with columns: | |
| - 'Category': Category name | |
| - 'Num_order': Number of orders per category | |
| Returns: | |
| go.Figure: A Plotly treemap figure object. | |
| """ | |
| fig = px.treemap( | |
| df, | |
| path=["Category"], | |
| values="Num_order", | |
| color="Num_order", | |
| color_continuous_scale=custom_palette, # Optional for consistency | |
| hover_data={"Num_order": ":,"}, # Adds commas to values | |
| width=800, | |
| height=400, | |
| ) | |
| fig.update_layout( | |
| margin=dict(t=40, l=30, r=30, b=30), | |
| coloraxis_showscale=False, # Optional: hides legend bar | |
| ) | |
| return fig | |
| def plot_freight_value_weight_relationship(df: DataFrame) -> Figure: | |
| """ | |
| Plot the relationship between product weight and freight value using a scatter plot. | |
| Args: | |
| df (DataFrame): DataFrame with columns: | |
| - 'product_weight_g': Weight of the product in grams | |
| - 'freight_value': Freight value in dollars | |
| Returns: | |
| Figure: A matplotlib figure object. | |
| """ | |
| rc_file_defaults() | |
| fig, ax = plt.subplots(figsize=(10, 5)) | |
| sns.scatterplot( | |
| data=df, | |
| x="product_weight_g", | |
| y="freight_value", | |
| color=custom_palette[2], | |
| edgecolor="white", | |
| alpha=0.7, | |
| s=50, | |
| ax=ax, | |
| ) | |
| ax.set_xlabel("Product Weight (grams)") | |
| ax.set_ylabel("Freight Value ($)") | |
| ax.grid(True, linestyle="--", alpha=0.5) | |
| fig.tight_layout() | |
| return fig | |
| def plot_delivery_date_difference(df: DataFrame) -> Figure: | |
| """ | |
| Plot the difference between estimated and actual delivery dates, grouped by state. | |
| Args: | |
| df (DataFrame): DataFrame with columns: | |
| - 'Delivery_Difference': Difference in days | |
| - 'State': Destination state | |
| Returns: | |
| Figure: A matplotlib figure object. | |
| """ | |
| rc_file_defaults() | |
| fig, ax = plt.subplots(figsize=(10, 6)) | |
| sns.barplot( | |
| data=df, x="Delivery_Difference", y="State", color=custom_palette[0], ax=ax | |
| ) | |
| ax.set_title( | |
| "Difference Between Estimated and Actual Delivery Dates by State", | |
| fontsize=12, | |
| weight="bold", | |
| ) | |
| ax.set_xlabel("Delivery Difference (Days)") | |
| ax.set_ylabel("State") | |
| ax.grid(True, linestyle="--", alpha=0.4, axis="x") | |
| fig.tight_layout() | |
| return fig | |
| def plot_order_amount_per_day_with_holidays(df: DataFrame) -> Figure: | |
| """ | |
| Plot the number of orders per day, highlighting holidays with vertical lines. | |
| Args: | |
| df (DataFrame): DataFrame with columns: | |
| - 'date': Timestamp in milliseconds | |
| - 'order_count': Number of orders on that date | |
| - 'holiday': Boolean indicating if the date is a holiday | |
| Returns: | |
| Figure: A matplotlib figure object. | |
| """ | |
| rc_file_defaults() | |
| df = df.copy() | |
| df["date"] = to_datetime(df["date"], unit="ms") | |
| df = df.sort_values("date") | |
| fig, ax = plt.subplots(figsize=(12, 4)) | |
| ax.plot(df["date"], df["order_count"], color=custom_palette[2], label="Order Count") | |
| for holiday_date in df[df["holiday"]]["date"]: | |
| ax.axvline( | |
| holiday_date, | |
| color=custom_palette[3], | |
| linestyle="--", | |
| alpha=0.4, | |
| label="Holiday", | |
| ) | |
| ax.set_xlabel("Date") | |
| ax.set_ylabel("Order Count") | |
| ax.xaxis.set_major_locator(mdates.MonthLocator()) | |
| ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %Y")) | |
| ax.tick_params(axis="x", rotation=45) | |
| ax.grid(True, linestyle="--", alpha=0.5) | |
| handles, labels = ax.get_legend_handles_labels() | |
| by_label = dict(zip(labels, handles)) # avoid duplicate "Holiday" entries | |
| ax.legend(by_label.values(), by_label.keys(), loc="upper left") | |
| fig.tight_layout() | |
| return fig | |