Gumball2k5 commited on
Commit
0c89273
·
verified ·
1 Parent(s): c934d79

Create diagnostic_plots.py

Browse files
Files changed (1) hide show
  1. src/diagnostic_plots.py +179 -0
src/diagnostic_plots.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import plotly.graph_objects as go
3
+ import plotly.express as px
4
+ import streamlit as st # Cần thiết để báo lỗi nếu cột không tồn tại
5
+
6
+ # --- HÀM 1: Biểu đồ suy giảm hiệu suất (Theo Checklist mục 5) ---
7
+
8
+ def plot_performance_degradation(df, metric_column, metric_name, color='blue'):
9
+ """
10
+ Tạo biểu đồ đường (line plot) cho thấy một chỉ số (metric) thay đổi
11
+ như thế nào qua 5 ngày dự báo.
12
+
13
+ Args:
14
+ df (pd.DataFrame): DataFrame được tải từ 'final_5_day_results_df.csv'.
15
+ metric_column (str): Tên cột chính xác trong CSV (ví dụ: 'RMSE', 'R2').
16
+ metric_name (str): Tên hiển thị đẹp cho trục Y (ví dụ: 'RMSE (Temperature °C)').
17
+ color (str): Tên màu cho đường line.
18
+
19
+ Returns:
20
+ plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
21
+ """
22
+
23
+ # --- TÙY CHỈNH QUAN TRỌNG ---
24
+ # Đảm bảo tên cột 'day_ahead' (chứa 1, 2, 3, 4, 5) là chính xác
25
+ DAY_AHEAD_COLUMN = 'day_ahead'
26
+ # ---------------------------
27
+
28
+ if DAY_AHEAD_COLUMN not in df.columns:
29
+ st.error(f"Lỗi plot: Không tìm thấy cột '{DAY_AHEAD_COLUMN}' trong dữ liệu. "
30
+ f"Vui lòng kiểm tra file `src/diagnostic_plots.py`.")
31
+ return go.Figure()
32
+ if metric_column not in df.columns:
33
+ st.error(f"Lỗi plot: Không tìm thấy cột '{metric_column}' trong dữ liệu. "
34
+ f"Vui lòng kiểm tra file `src/diagnostic_plots.py`.")
35
+ return go.Figure()
36
+
37
+ fig = go.Figure()
38
+
39
+ fig.add_trace(go.Scatter(
40
+ x=df[DAY_AHEAD_COLUMN],
41
+ y=df[metric_column],
42
+ mode='lines+markers',
43
+ name=metric_name,
44
+ line=dict(color=color, width=3),
45
+ marker=dict(size=8)
46
+ ))
47
+
48
+ fig.update_layout(
49
+ title=f"<b>{metric_name} vs. Forecast Horizon</b>",
50
+ xaxis_title="Day Ahead (Horizon)",
51
+ yaxis_title=metric_name,
52
+ title_x=0.5, # Căn giữa tiêu đề
53
+ template="plotly_white"
54
+ )
55
+
56
+ # Nếu là R2, đặt giới hạn trục y từ 0 đến 1 cho dễ nhìn
57
+ if "R2" in metric_name or "R-squared" in metric_name:
58
+ fig.update_layout(yaxis_range=[0, 1])
59
+
60
+ return fig
61
+
62
+ # --- HÀM 2: Biểu đồ Dự báo vs. Thực tế (Theo Checklist mục 5) ---
63
+
64
+ def plot_forecast_vs_actual(y_true, y_pred, day_ahead_title):
65
+ """
66
+ Tạo biểu đồ phân tán (scatter plot) so sánh giá trị dự báo và giá trị thực tế.
67
+
68
+ Args:
69
+ y_true (array-like): Mảng chứa các giá trị thực tế.
70
+ y_pred (array-like): Mảng chứa các giá trị dự báo.
71
+ day_ahead_title (str): Tiêu đề phụ (ví dụ: "Day 1" hoặc "Day 5").
72
+
73
+ Returns:
74
+ plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
75
+ """
76
+
77
+ # Tạo DataFrame tạm thời để vẽ
78
+ plot_df = pd.DataFrame({
79
+ 'Actual': y_true,
80
+ 'Predicted': y_pred
81
+ })
82
+
83
+ fig = px.scatter(
84
+ plot_df,
85
+ x='Actual',
86
+ y='Predicted',
87
+ title=f"<b>Forecast vs. Actual - {day_ahead_title}</b>",
88
+ opacity=0.7,
89
+ hover_data={'Actual': ':.2f', 'Predicted': ':.2f'}
90
+ )
91
+
92
+ # Thêm đường chéo (y=x) thể hiện dự báo hoàn hảo
93
+ min_val = min(plot_df['Actual'].min(), plot_df['Predicted'].min())
94
+ max_val = max(plot_df['Actual'].max(), plot_df['Predicted'].max())
95
+
96
+ fig.add_trace(go.Scatter(
97
+ x=[min_val, max_val],
98
+ y=[min_val, max_val],
99
+ mode='lines',
100
+ name='Perfect Prediction',
101
+ line=dict(color='red', dash='dash', width=2)
102
+ ))
103
+
104
+ fig.update_layout(
105
+ title_x=0.5,
106
+ xaxis_title="Actual Temperature (°C)",
107
+ yaxis_title="Predicted Temperature (°C)",
108
+ template="plotly_white"
109
+ )
110
+ return fig
111
+
112
+ # --- CÁC HÀM 3 & 4: Biểu đồ "Deep Dive" (Theo Checklist mục 5 - Tùy chọn) ---
113
+
114
+ def plot_residuals_vs_time(y_true, y_pred, dates, day_ahead_title):
115
+ """
116
+ Tạo biểu đồ phân tán của phần dư (residuals) theo thời gian.
117
+
118
+ Args:
119
+ y_true (array-like): Mảng giá trị thực tế.
120
+ y_pred (array-like): Mảng giá trị dự báo.
121
+ dates (array-like): Mảng chứa ngày tháng tương ứng.
122
+ day_ahead_title (str): Tiêu đề phụ (ví dụ: "Day 1").
123
+
124
+ Returns:
125
+ plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
126
+ """
127
+ residuals = y_true - y_pred
128
+
129
+ plot_df = pd.DataFrame({
130
+ 'Date': dates,
131
+ 'Residual': residuals
132
+ })
133
+
134
+ fig = px.scatter(
135
+ plot_df,
136
+ x='Date',
137
+ y='Residual',
138
+ title=f"<b>Residuals vs. Time - {day_ahead_title}</b>",
139
+ opacity=0.7
140
+ )
141
+
142
+ # Thêm đường y=0 (lỗi bằng 0)
143
+ fig.add_hline(y=0, line=dict(color='red', dash='dash', width=2))
144
+
145
+ fig.update_layout(
146
+ title_x=0.5,
147
+ yaxis_title="Residual (Actual - Predicted)",
148
+ template="plotly_white"
149
+ )
150
+ return fig
151
+
152
+
153
+ def plot_residuals_distribution(y_true, y_pred, day_ahead_title):
154
+ """
155
+ Tạo biểu đồ histogram phân phối của phần dư (residuals).
156
+
157
+ Args:
158
+ y_true (array-like): Mảng giá trị thực tế.
159
+ y_pred (array-like): Mảng giá trị dự báo.
160
+ day_ahead_title (str): Tiêu đề phụ (ví dụ: "Day 1").
161
+
162
+ Returns:
163
+ plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
164
+ """
165
+ residuals = y_true - y_pred
166
+
167
+ fig = px.histogram(
168
+ residuals,
169
+ nbins=50,
170
+ title=f"<b>Residuals Distribution - {day_ahead_title}</b>"
171
+ )
172
+
173
+ fig.update_layout(
174
+ title_x=0.5,
175
+ xaxis_title="Residual (Error)",
176
+ yaxis_title="Count",
177
+ template="plotly_white"
178
+ )
179
+ return fig