updated title of plots
Browse files- make_plot.py +51 -42
make_plot.py
CHANGED
|
@@ -21,37 +21,41 @@ def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
|
|
| 21 |
fig = go.Figure()
|
| 22 |
|
| 23 |
# Add the first scatter plot with steelblue color
|
| 24 |
-
fig.add_trace(
|
|
|
|
| 25 |
x=df1.index,
|
| 26 |
y=df1.iloc[:, 0],
|
| 27 |
-
mode=
|
| 28 |
-
name=
|
| 29 |
-
line=dict(color=
|
| 30 |
-
marker=dict(color=
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
# Add the second scatter plot with yellow color
|
| 34 |
-
fig.add_trace(
|
|
|
|
| 35 |
x=df2.index,
|
| 36 |
y=df2.iloc[:, 0],
|
| 37 |
-
mode=
|
| 38 |
-
name=
|
| 39 |
-
line=dict(color=
|
| 40 |
-
marker=dict(color=
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
# Customize the layout
|
| 44 |
fig.update_layout(
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
return fig
|
| 52 |
|
| 53 |
|
| 54 |
-
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame])
|
| 55 |
"""
|
| 56 |
Plot the true values and forecasts using Plotly.
|
| 57 |
|
|
@@ -67,48 +71,53 @@ def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
|
|
| 67 |
fig = go.Figure()
|
| 68 |
|
| 69 |
# Add the true values trace
|
| 70 |
-
fig.add_trace(
|
|
|
|
| 71 |
x=pd.to_datetime(df.index),
|
| 72 |
y=df.iloc[:, 0],
|
| 73 |
-
mode=
|
| 74 |
-
name=
|
| 75 |
-
line=dict(color=
|
| 76 |
-
|
|
|
|
| 77 |
|
| 78 |
# Add the forecast traces
|
| 79 |
colors = ["green", "blue", "purple"]
|
| 80 |
for i, forecast in enumerate(forecasts):
|
| 81 |
color = colors[i]
|
| 82 |
for sample in forecast.samples:
|
| 83 |
-
fig.add_trace(
|
|
|
|
| 84 |
x=forecast.index.to_timestamp(),
|
| 85 |
y=sample,
|
| 86 |
-
mode=
|
| 87 |
opacity=0.15, # Adjust opacity to control visibility of individual samples
|
| 88 |
-
name=f
|
| 89 |
showlegend=False, # Hide the individual forecast series from the legend
|
| 90 |
-
hoverinfo=
|
| 91 |
-
line=dict(color=color)
|
| 92 |
-
|
|
|
|
| 93 |
# Add the average
|
| 94 |
mean_forecast = np.mean(forecast.samples, axis=0)
|
| 95 |
-
fig.add_trace(
|
|
|
|
| 96 |
x=forecast.index.to_timestamp(),
|
| 97 |
y=mean_forecast,
|
| 98 |
-
mode=
|
| 99 |
-
name=
|
| 100 |
-
line=dict(color=
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
# Customize the layout
|
| 104 |
fig.update_layout(
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
)
|
| 112 |
|
| 113 |
# Return the figure
|
| 114 |
return fig
|
|
|
|
| 21 |
fig = go.Figure()
|
| 22 |
|
| 23 |
# Add the first scatter plot with steelblue color
|
| 24 |
+
fig.add_trace(
|
| 25 |
+
go.Scatter(
|
| 26 |
x=df1.index,
|
| 27 |
y=df1.iloc[:, 0],
|
| 28 |
+
mode="lines",
|
| 29 |
+
name="Training Data",
|
| 30 |
+
line=dict(color="steelblue"),
|
| 31 |
+
marker=dict(color="steelblue"),
|
| 32 |
+
)
|
| 33 |
+
)
|
| 34 |
|
| 35 |
# Add the second scatter plot with yellow color
|
| 36 |
+
fig.add_trace(
|
| 37 |
+
go.Scatter(
|
| 38 |
x=df2.index,
|
| 39 |
y=df2.iloc[:, 0],
|
| 40 |
+
mode="lines",
|
| 41 |
+
name="Test Data",
|
| 42 |
+
line=dict(color="gold"),
|
| 43 |
+
marker=dict(color="gold"),
|
| 44 |
+
)
|
| 45 |
+
)
|
| 46 |
|
| 47 |
# Customize the layout
|
| 48 |
fig.update_layout(
|
| 49 |
+
title="Univariate Time Series",
|
| 50 |
+
xaxis=dict(title="Date"),
|
| 51 |
+
yaxis=dict(title="Value"),
|
| 52 |
+
showlegend=True,
|
| 53 |
+
template="plotly_white",
|
| 54 |
+
)
|
| 55 |
return fig
|
| 56 |
|
| 57 |
|
| 58 |
+
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]):
|
| 59 |
"""
|
| 60 |
Plot the true values and forecasts using Plotly.
|
| 61 |
|
|
|
|
| 71 |
fig = go.Figure()
|
| 72 |
|
| 73 |
# Add the true values trace
|
| 74 |
+
fig.add_trace(
|
| 75 |
+
go.Scatter(
|
| 76 |
x=pd.to_datetime(df.index),
|
| 77 |
y=df.iloc[:, 0],
|
| 78 |
+
mode="lines",
|
| 79 |
+
name="True values",
|
| 80 |
+
line=dict(color="black"),
|
| 81 |
+
)
|
| 82 |
+
)
|
| 83 |
|
| 84 |
# Add the forecast traces
|
| 85 |
colors = ["green", "blue", "purple"]
|
| 86 |
for i, forecast in enumerate(forecasts):
|
| 87 |
color = colors[i]
|
| 88 |
for sample in forecast.samples:
|
| 89 |
+
fig.add_trace(
|
| 90 |
+
go.Scatter(
|
| 91 |
x=forecast.index.to_timestamp(),
|
| 92 |
y=sample,
|
| 93 |
+
mode="lines",
|
| 94 |
opacity=0.15, # Adjust opacity to control visibility of individual samples
|
| 95 |
+
name=f"Forecast {i + 1}",
|
| 96 |
showlegend=False, # Hide the individual forecast series from the legend
|
| 97 |
+
hoverinfo="none", # Disable hover information for the forecast series
|
| 98 |
+
line=dict(color=color),
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
# Add the average
|
| 102 |
mean_forecast = np.mean(forecast.samples, axis=0)
|
| 103 |
+
fig.add_trace(
|
| 104 |
+
go.Scatter(
|
| 105 |
x=forecast.index.to_timestamp(),
|
| 106 |
y=mean_forecast,
|
| 107 |
+
mode="lines",
|
| 108 |
+
name="Mean Forecast",
|
| 109 |
+
line=dict(color="red", dash="dash"),
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
|
| 113 |
# Customize the layout
|
| 114 |
fig.update_layout(
|
| 115 |
+
title=f"{df.columns[0]} Forecast",
|
| 116 |
+
yaxis=dict(title=df.columns[0]),
|
| 117 |
+
showlegend=True,
|
| 118 |
+
legend=dict(x=0, y=1, font=dict(size=16)),
|
| 119 |
+
hovermode="x", # Enable x-axis hover for better interactivity
|
| 120 |
+
)
|
|
|
|
| 121 |
|
| 122 |
# Return the figure
|
| 123 |
return fig
|