import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import streamlit as st # Cần thiết để báo lỗi nếu cột không tồn tại
# --- HÀM 1: Biểu đồ suy giảm hiệu suất (Theo Checklist mục 5) ---
def plot_performance_degradation(df, metric_column, metric_name, color='blue'):
"""
Tạo biểu đồ đường (line plot) cho thấy một chỉ số (metric) thay đổi
như thế nào qua 5 ngày dự báo.
Args:
df (pd.DataFrame): DataFrame được tải từ 'final_5_day_results_df.csv'.
metric_column (str): Tên cột chính xác trong CSV (ví dụ: 'RMSE (Absolute Error)').
metric_name (str): Tên hiển thị đẹp cho trục Y (ví dụ: 'RMSE (Temperature °C)').
color (str): Tên màu cho đường line.
Returns:
plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
"""
# --- TÙY CHỈNH QUAN TRỌNG (ĐÃ SỬA) ---
# Cột chứa "Day 1", "Day 2",... là 'Horizon'
DAY_AHEAD_COLUMN = 'Horizon'
# ---------------------------
if DAY_AHEAD_COLUMN not in df.columns:
st.error(f"Lỗi plot: Không tìm thấy cột '{DAY_AHEAD_COLUMN}' trong dữ liệu. "
f"Vui lòng kiểm tra file `src/diagnostic_plots.py`.")
return go.Figure()
if metric_column not in df.columns:
st.error(f"Lỗi plot: Không tìm thấy cột '{metric_column}' trong dữ liệu. "
f"Vui lòng kiểm tra file `src/diagnostic_plots.py`.")
return go.Figure()
# --- SỬA LỖI LOGIC: Chuyển "Day 1" thành số 1 ---
# Tạo một bản copy để tránh cảnh báo
plot_df = df.copy()
# Trích xuất số từ cột 'Horizon' (ví dụ: 'Day 1' -> 1)
# và tạo cột mới 'day_num'
plot_df['day_num'] = plot_df[DAY_AHEAD_COLUMN].str.extract(r'(\d+)').astype(int)
plot_df = plot_df.sort_values(by='day_num')
# ---------------------------------------------
fig = go.Figure()
fig.add_trace(go.Scatter(
x=plot_df['day_num'], # Dùng cột số 'day_num' mới cho trục X
y=plot_df[metric_column],
mode='lines+markers',
name=metric_name,
line=dict(color=color, width=3),
marker=dict(size=8)
))
fig.update_layout(
title=f"{metric_name} vs. Forecast Horizon",
xaxis_title="Day Ahead (Horizon)",
yaxis_title=metric_name,
title_x=0.5, # Căn giữa tiêu đề
template="plotly_white",
xaxis = dict(tickmode = 'linear', tick0 = 1, dtick = 1) # Đảm bảo trục X là 1, 2, 3, 4, 5
)
# Nếu là R2, đặt giới hạn trục y từ 0 đến 1 cho dễ nhìn
if "R2" in metric_name or "R-squared" in metric_name:
fig.update_layout(yaxis_range=[0, 1])
return fig
# --- HÀM 2: Biểu đồ Dự báo vs. Thực tế (Theo Checklist mục 5) ---
def plot_forecast_vs_actual(y_true, y_pred, day_ahead_title):
"""
Tạo biểu đồ phân tán (scatter plot) so sánh giá trị dự báo và giá trị thực tế.
Args:
y_true (array-like): Mảng chứa các giá trị thực tế.
y_pred (array-like): Mảng chứa các giá trị dự báo.
day_ahead_title (str): Tiêu đề phụ (ví dụ: "Day 1" hoặc "Day 5").
Returns:
plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
"""
# Tạo DataFrame tạm thời để vẽ
plot_df = pd.DataFrame({
'Actual': y_true,
'Predicted': y_pred
})
fig = px.scatter(
plot_df,
x='Actual',
y='Predicted',
title=f"Forecast vs. Actual - {day_ahead_title}",
opacity=0.7,
hover_data={'Actual': ':.2f', 'Predicted': ':.2f'}
)
# Thêm đường chéo (y=x) thể hiện dự báo hoàn hảo
min_val = min(plot_df['Actual'].min(), plot_df['Predicted'].min())
max_val = max(plot_df['Actual'].max(), plot_df['Predicted'].max())
fig.add_trace(go.Scatter(
x=[min_val, max_val],
y=[min_val, max_val],
mode='lines',
name='Perfect Prediction',
line=dict(color='red', dash='dash', width=2)
))
fig.update_layout(
title_x=0.5,
xaxis_title="Actual Temperature (°C)",
yaxis_title="Predicted Temperature (°C)",
template="plotly_white"
)
return fig
# --- CÁC HÀM 3 & 4: Biểu đồ "Deep Dive" (Theo Checklist mục 5 - Tùy chọn) ---
def plot_residuals_vs_time(y_true, y_pred, dates, day_ahead_title):
"""
Tạo biểu đồ phân tán của phần dư (residuals) theo thời gian.
Args:
y_true (array-like): Mảng giá trị thực tế.
y_pred (array-like): Mảng giá trị dự báo.
dates (array-like): Mảng chứa ngày tháng tương ứng.
day_ahead_title (str): Tiêu đề phụ (ví dụ: "Day 1").
Returns:
plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
"""
residuals = y_true - y_pred
plot_df = pd.DataFrame({
'Date': dates,
'Residual': residuals
})
fig = px.scatter(
plot_df,
x='Date',
y='Residual',
title=f"Residuals vs. Time - {day_ahead_title}",
opacity=0.7
)
# Thêm đường y=0 (lỗi bằng 0)
fig.add_hline(y=0, line=dict(color='red', dash='dash', width=2))
fig.update_layout(
title_x=0.5,
yaxis_title="Residual (Actual - Predicted)",
template="plotly_white"
)
return fig
def plot_residuals_distribution(y_true, y_pred, day_ahead_title):
"""
Tạo biểu đồ histogram phân phối của phần dư (residuals).
Args:
y_true (array-like): Mảng giá trị thực tế.
y_pred (array-like): Mảng giá trị dự báo.
day_ahead_title (str): Tiêu đề phụ (ví dụ: "Day 1").
Returns:
plotly.graph_objects.Figure: Một đối tượng biểu đồ Plotly.
"""
residuals = y_true - y_pred
fig = px.histogram(
residuals,
nbins=50,
title=f"Residuals Distribution - {day_ahead_title}"
)
fig.update_layout(
title_x=0.5,
xaxis_title="Residual (Error)",
yaxis_title="Count",
template="plotly_white"
)
return fig