iBrokeTheCode committed on
Commit
142679c
·
1 Parent(s): 9894e45

refactor: Return the figure from each plot function instead of calling plt.show()

Browse files
Files changed (1) hide show
  1. src/plots.py +149 -98
src/plots.py CHANGED
@@ -1,5 +1,6 @@
1
  import matplotlib.pyplot as plt
2
  import plotly.express as px
 
3
  import seaborn as sns
4
  from matplotlib import rc_file_defaults
5
  from matplotlib.figure import Figure
@@ -35,42 +36,61 @@ def plot_revenue_by_month_year(df: DataFrame, year: int) -> Figure:
35
  return fig
36
 
37
 
38
- def plot_real_vs_predicted_delivered_time(df: DataFrame, year: int) -> None:
39
  """
40
- Plot the real vs predicted delivered time
 
 
 
 
41
 
42
  Args:
43
- df (DataFrame): The dataframe
44
- year (int): The year
 
 
 
 
 
 
45
  """
46
  rc_file_defaults()
47
  sns.set_style(style=None, rc=None)
48
 
49
- _, ax1 = plt.subplots(figsize=(12, 6))
50
 
51
  sns.lineplot(data=df[f"Year{year}_real_time"], marker="o", sort=False, ax=ax1)
52
- ax1.twinx()
53
- g = sns.lineplot(
54
- data=df[f"Year{year}_estimated_time"], marker="o", sort=False, ax=ax1
55
- )
56
- g.set_xticks(range(len(df)))
57
- g.set_xticklabels(df.month.values)
58
- g.set(xlabel="month", ylabel="Average days delivery time", title="some title")
59
- ax1.set_title(f"Average days delivery time by month in {year}")
60
- ax1.legend(["Real time", "Estimated time"])
61
 
62
- plt.show()
 
 
 
 
 
 
 
63
 
64
 
65
- def plot_global_amount_order_status(df: DataFrame) -> None:
 
 
 
 
66
  """
67
- Plot global amount of order status
68
 
69
  Args:
70
- df (DataFrame): The dataframe
 
 
 
 
 
71
  """
72
- _, ax = plt.subplots(figsize=(8, 3), subplot_kw=dict(aspect="equal"))
73
 
 
74
  elements = [x.split()[-1] for x in df["order_status"]]
75
 
76
  wedges, autotexts = ax.pie(df["Amount"], textprops=dict(color="w"))
@@ -84,42 +104,51 @@ def plot_global_amount_order_status(df: DataFrame) -> None:
84
  )
85
 
86
  plt.setp(autotexts, size=8, weight="bold")
87
-
88
  ax.set_title("Order Status Total")
89
 
90
- my_circle = plt.Circle((0, 0), 0.7, color="white")
91
- p = plt.gcf()
92
- p.gca().add_artist(my_circle)
93
 
94
- plt.show()
95
 
96
 
97
- def plot_revenue_per_state(df: DataFrame) -> None:
98
  """
99
- Plot revenue per state
100
 
101
  Args:
102
- df (DataFrame): The dataframe
 
 
 
 
 
103
  """
104
  fig = px.treemap(
105
  df, path=["customer_state"], values="Revenue", width=800, height=300
106
  )
107
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
108
- fig.show()
109
 
110
 
111
- def plot_top_10_least_revenue_categories(df: DataFrame) -> None:
112
  """
113
- Plot top 10 least revenue categories
114
 
115
  Args:
116
- df (DataFrame): The dataframe
 
 
 
 
 
117
  """
118
- _, ax = plt.subplots(figsize=(6, 3), subplot_kw=dict(aspect="equal"))
119
 
120
  elements = [x.split()[-1] for x in df["Category"]]
121
-
122
  revenue = df["Revenue"]
 
123
  wedges, autotexts = ax.pie(revenue, textprops=dict(color="w"))
124
 
125
  ax.legend(
@@ -131,27 +160,31 @@ def plot_top_10_least_revenue_categories(df: DataFrame) -> None:
131
  )
132
 
133
  plt.setp(autotexts, size=8, weight="bold")
134
- my_circle = plt.Circle((0, 0), 0.7, color="white")
135
- p = plt.gcf()
136
- p.gca().add_artist(my_circle)
137
-
138
  ax.set_title("Top 10 Least Revenue Categories Amount")
139
 
140
- plt.show()
 
 
 
141
 
142
 
143
- def plot_top_10_revenue_categories_amount(df: DataFrame) -> None:
144
- """Plot top 10 revenue categories
 
145
 
146
  Args:
147
- df (DataFrame): Dataframe with top 10 revenue categories query result
 
 
 
 
 
148
  """
149
- # Plotting the top 10 revenue categories amount
150
- _, ax = plt.subplots(figsize=(6, 3), subplot_kw=dict(aspect="equal"))
151
 
152
  elements = [x.split()[-1] for x in df["Category"]]
153
-
154
  revenue = df["Revenue"]
 
155
  wedges, autotexts = ax.pie(revenue, textprops=dict(color="w"))
156
 
157
  ax.legend(
@@ -163,89 +196,107 @@ def plot_top_10_revenue_categories_amount(df: DataFrame) -> None:
163
  )
164
 
165
  plt.setp(autotexts, size=8, weight="bold")
166
- my_circle = plt.Circle((0, 0), 0.7, color="white")
167
- p = plt.gcf()
168
- p.gca().add_artist(my_circle)
169
 
170
  ax.set_title("Top 10 Revenue Categories Amount")
171
 
172
- plt.show()
 
 
 
173
 
174
 
175
- def plot_top_10_revenue_categories(df: DataFrame) -> None:
176
- """Plot top 10 revenue categories
 
177
 
178
  Args:
179
- df (DataFrame): Dataframe with top 10 revenue categories query result
 
 
 
 
 
180
  """
181
  fig = px.treemap(df, path=["Category"], values="Num_order", width=800, height=400)
182
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
183
- fig.show()
184
 
185
 
186
- def plot_freight_value_weight_relationship(df: DataFrame) -> None:
187
- """Plot freight value weight relationship
 
188
 
189
  Args:
190
- df (DataFrame): Dataframe with freight value weight relationship query result
 
 
 
 
 
191
  """
192
- # Set the figure size
193
- plt.figure(figsize=(8, 4))
194
 
195
- # Scatter plot: x=product weight, y=freight value
196
  sns.scatterplot(
197
- data=df,
198
- x="product_weight_g",
199
- y="freight_value",
200
- edgecolor="white",
201
  )
202
 
203
- # Customize chart
204
- plt.title("Freight Value vs Product Weight")
205
- plt.xlabel("Product Weight (g)")
206
- plt.ylabel("Freight Value ($)")
207
- plt.tight_layout()
208
- plt.show()
209
 
 
210
 
211
- def plot_delivery_date_difference(df: DataFrame) -> None:
212
- """Plot delivery date difference
 
 
213
 
214
  Args:
215
- df (DataFrame): Dataframe with delivery date difference query result
 
 
 
 
 
216
  """
217
- plt.figure(figsize=(12, 6))
218
- sns.barplot(data=df, x="Delivery_Difference", y="State").set(
219
- title="Difference Between Delivery Estimate Date and Delivery Date"
220
- )
221
- plt.show()
222
 
 
 
 
 
223
 
224
- def plot_order_amount_per_day_with_holidays(df: DataFrame) -> None:
225
- """Plot order amount per day with holidays
226
 
227
- Args:
228
- df (DataFrame): Dataframe with order amount per day with holidays query result
229
  """
 
230
 
231
- # Convert timestamp in milliseconds to datetime
232
- df["date"] = to_datetime(df["date"], unit="ms")
 
 
 
233
 
234
- # Sort by date
 
 
 
 
235
  df = df.sort_values("date")
236
 
237
- # Plot the line chart for order count
238
- plt.figure(figsize=(9, 4))
239
- plt.plot(df["date"], df["order_count"], color="green")
240
-
241
- # Add vertical lines for holidays
242
- holidays = df[df["holiday"] == True]
243
- for holiday_date in holidays["date"]:
244
- plt.axvline(holiday_date, color="blue", linestyle="dotted", alpha=0.6)
245
-
246
- # Customize chart
247
- plt.title("Order Amount per Day with Holidays")
248
- plt.xlabel("Date")
249
- plt.ylabel("Order Count")
250
- plt.tight_layout()
251
- plt.show()
 
1
  import matplotlib.pyplot as plt
2
  import plotly.express as px
3
+ import plotly.graph_objects as go
4
  import seaborn as sns
5
  from matplotlib import rc_file_defaults
6
  from matplotlib.figure import Figure
 
36
  return fig
37
 
38
 
39
+ def plot_real_vs_predicted_delivered_time(df: DataFrame, year: int) -> Figure:
40
  """
41
+ Generate and return a matplotlib figure comparing real vs. estimated delivery time
42
+ by month for a specific year.
43
+
44
+ Intended for interactive environments like Marimo where returning the figure
45
+ automatically renders the plot.
46
 
47
  Args:
48
+ df (DataFrame): DataFrame with columns:
49
+ - 'month': Month names or numbers.
50
+ - f'Year{year}_real_time': Real average delivery time.
51
+ - f'Year{year}_estimated_time': Estimated average delivery time.
52
+ year (int): The year to visualize (e.g., 2018).
53
+
54
+ Returns:
55
+ Figure: A matplotlib figure with two overlaid line plots.
56
  """
57
  rc_file_defaults()
58
  sns.set_style(style=None, rc=None)
59
 
60
+ fig, ax1 = plt.subplots(figsize=(12, 6))
61
 
62
  sns.lineplot(data=df[f"Year{year}_real_time"], marker="o", sort=False, ax=ax1)
63
+ sns.lineplot(data=df[f"Year{year}_estimated_time"], marker="o", sort=False, ax=ax1)
 
 
 
 
 
 
 
 
64
 
65
+ ax1.set_xticks(range(len(df)))
66
+ ax1.set_xticklabels(df["month"].values)
67
+ ax1.set_xlabel("Month")
68
+ ax1.set_ylabel("Average Days to Deliver")
69
+ ax1.set_title(f"Average Delivery Time (Real vs Estimated) in {year}")
70
+ ax1.legend(["Real Time", "Estimated Time"])
71
+
72
+ return fig
73
 
74
 
75
+ from matplotlib.figure import Figure
76
+ from pandas import DataFrame
77
+
78
+
79
+ def plot_global_amount_order_status(df: DataFrame) -> Figure:
80
  """
81
+ Create and return a donut pie chart showing the global amount per order status.
82
 
83
  Args:
84
+ df (DataFrame): DataFrame containing:
85
+ - 'order_status': Status labels (e.g., 'order delivered').
86
+ - 'Amount': Corresponding counts or totals per status.
87
+
88
+ Returns:
89
+ Figure: A matplotlib figure containing a pie (donut) chart with legend.
90
  """
91
+ fig, ax = plt.subplots(figsize=(8, 3), subplot_kw=dict(aspect="equal"))
92
 
93
+ # Extract last word of each status for cleaner labels
94
  elements = [x.split()[-1] for x in df["order_status"]]
95
 
96
  wedges, autotexts = ax.pie(df["Amount"], textprops=dict(color="w"))
 
104
  )
105
 
106
  plt.setp(autotexts, size=8, weight="bold")
 
107
  ax.set_title("Order Status Total")
108
 
109
+ # Add donut center
110
+ center_circle = plt.Circle((0, 0), 0.7, color="white")
111
+ ax.add_artist(center_circle)
112
 
113
+ return fig
114
 
115
 
116
+ def plot_revenue_per_state(df: DataFrame) -> go.Figure:
117
  """
118
+ Create a Plotly treemap to visualize revenue per customer state.
119
 
120
  Args:
121
+ df (DataFrame): DataFrame with columns:
122
+ - 'customer_state': State or region
123
+ - 'Revenue': Revenue value per state
124
+
125
+ Returns:
126
+ go.Figure: A Plotly treemap figure object.
127
  """
128
  fig = px.treemap(
129
  df, path=["customer_state"], values="Revenue", width=800, height=300
130
  )
131
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
132
+ return fig
133
 
134
 
135
+ def plot_top_10_least_revenue_categories(df: DataFrame) -> Figure:
136
  """
137
+ Create a donut pie chart showing the top 10 least revenue categories.
138
 
139
  Args:
140
+ df (DataFrame): DataFrame with columns:
141
+ - 'Category': Category name
142
+ - 'Revenue': Corresponding revenue values
143
+
144
+ Returns:
145
+ Figure: A matplotlib figure with a donut chart and legend.
146
  """
147
+ fig, ax = plt.subplots(figsize=(6, 3), subplot_kw=dict(aspect="equal"))
148
 
149
  elements = [x.split()[-1] for x in df["Category"]]
 
150
  revenue = df["Revenue"]
151
+
152
  wedges, autotexts = ax.pie(revenue, textprops=dict(color="w"))
153
 
154
  ax.legend(
 
160
  )
161
 
162
  plt.setp(autotexts, size=8, weight="bold")
 
 
 
 
163
  ax.set_title("Top 10 Least Revenue Categories Amount")
164
 
165
+ center_circle = plt.Circle((0, 0), 0.7, color="white")
166
+ ax.add_artist(center_circle)
167
+
168
+ return fig
169
 
170
 
171
+ def plot_top_10_revenue_categories_amount(df: DataFrame) -> Figure:
172
+ """
173
+ Create a donut pie chart showing the revenue distribution of the top 10 categories.
174
 
175
  Args:
176
+ df (DataFrame): DataFrame with columns:
177
+ - 'Category': Category name
178
+ - 'Revenue': Revenue amount
179
+
180
+ Returns:
181
+ Figure: A matplotlib figure object.
182
  """
183
+ fig, ax = plt.subplots(figsize=(6, 3), subplot_kw=dict(aspect="equal"))
 
184
 
185
  elements = [x.split()[-1] for x in df["Category"]]
 
186
  revenue = df["Revenue"]
187
+
188
  wedges, autotexts = ax.pie(revenue, textprops=dict(color="w"))
189
 
190
  ax.legend(
 
196
  )
197
 
198
  plt.setp(autotexts, size=8, weight="bold")
 
 
 
199
 
200
  ax.set_title("Top 10 Revenue Categories Amount")
201
 
202
+ center_circle = plt.Circle((0, 0), 0.7, color="white")
203
+ ax.add_artist(center_circle)
204
+
205
+ return fig
206
 
207
 
208
+ def plot_top_10_revenue_categories(df: DataFrame) -> go.Figure:
209
+ """
210
+ Create a Plotly treemap showing the number of orders for the top 10 revenue categories.
211
 
212
  Args:
213
+ df (DataFrame): DataFrame with columns:
214
+ - 'Category': Category name
215
+ - 'Num_order': Number of orders per category
216
+
217
+ Returns:
218
+ go.Figure: A Plotly treemap figure object.
219
  """
220
  fig = px.treemap(df, path=["Category"], values="Num_order", width=800, height=400)
221
  fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
222
+ return fig
223
 
224
 
225
+ def plot_freight_value_weight_relationship(df: DataFrame) -> Figure:
226
+ """
227
+ Plot the relationship between product weight and freight value using a scatter plot.
228
 
229
  Args:
230
+ df (DataFrame): DataFrame with columns:
231
+ - 'product_weight_g': Weight of the product in grams
232
+ - 'freight_value': Freight value in dollars
233
+
234
+ Returns:
235
+ Figure: A matplotlib figure object.
236
  """
237
+ fig, ax = plt.subplots(figsize=(8, 4))
 
238
 
 
239
  sns.scatterplot(
240
+ data=df, x="product_weight_g", y="freight_value", edgecolor="white", ax=ax
 
 
 
241
  )
242
 
243
+ ax.set_title("Freight Value vs Product Weight")
244
+ ax.set_xlabel("Product Weight (g)")
245
+ ax.set_ylabel("Freight Value ($)")
246
+ fig.tight_layout()
 
 
247
 
248
+ return fig
249
 
250
+
251
+ def plot_delivery_date_difference(df: DataFrame) -> Figure:
252
+ """
253
+ Plot the difference between estimated and actual delivery dates, grouped by state.
254
 
255
  Args:
256
+ df (DataFrame): DataFrame with columns:
257
+ - 'Delivery_Difference': Difference in days
258
+ - 'State': Destination state
259
+
260
+ Returns:
261
+ Figure: A matplotlib figure object.
262
  """
263
+ fig, ax = plt.subplots(figsize=(12, 6))
 
 
 
 
264
 
265
+ sns.barplot(data=df, x="Delivery_Difference", y="State", ax=ax)
266
+ ax.set_title("Difference Between Delivery Estimate Date and Delivery Date")
267
+ ax.set_xlabel("Delivery Difference (days)")
268
+ ax.set_ylabel("State")
269
 
270
+ fig.tight_layout()
271
+ return fig
272
 
273
+
274
+ def plot_order_amount_per_day_with_holidays(df: DataFrame) -> Figure:
275
  """
276
+ Plot the number of orders per day, highlighting holidays with vertical lines.
277
 
278
+ Args:
279
+ df (DataFrame): DataFrame with columns:
280
+ - 'date': Timestamp in milliseconds
281
+ - 'order_count': Number of orders on that date
282
+ - 'holiday': Boolean indicating if the date is a holiday
283
 
284
+ Returns:
285
+ Figure: A matplotlib figure object.
286
+ """
287
+ df = df.copy()
288
+ df["date"] = to_datetime(df["date"], unit="ms")
289
  df = df.sort_values("date")
290
 
291
+ fig, ax = plt.subplots(figsize=(9, 4))
292
+ ax.plot(df["date"], df["order_count"], color="green")
293
+
294
+ for holiday_date in df[df["holiday"]]["date"]:
295
+ ax.axvline(holiday_date, color="blue", linestyle="dotted", alpha=0.6)
296
+
297
+ ax.set_title("Order Amount per Day with Holidays")
298
+ ax.set_xlabel("Date")
299
+ ax.set_ylabel("Order Count")
300
+ fig.tight_layout()
301
+
302
+ return fig