Spaces:
Sleeping
Sleeping
File size: 3,228 Bytes
4034469 acf73d3 658a2a8 00dca38 4034469 6b05aaf 5a26413 6b05aaf 658a2a8 6b05aaf 5a26413 00dca38 6b05aaf 658a2a8 6b05aaf 4034469 658a2a8 00dca38 4034469 00dca38 658a2a8 00dca38 7ab37b4 658a2a8 7ab37b4 5a26413 7ab37b4 658a2a8 7ab37b4 658a2a8 5a26413 7ab37b4 5a26413 7ab37b4 5a26413 658a2a8 4034469 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
def plot_forecast(result):
    """Plot actual values against the forecast and mark next-step predictions.

    Args:
        result: Mapping with the keys ``"actual"`` and ``"forecast"`` (sequences
            plotted against their positional index) and, optionally,
            ``"latest_prediction"`` (a sized sequence of numeric values drawn
            immediately after the last actual point).

    Returns:
        The matplotlib figure containing the plot.

    Raises:
        KeyError: If ``"forecast"`` or ``"actual"`` is missing from ``result``.
    """
    forecast = result["forecast"]
    actual = result["actual"]
    next_pred = result.get("latest_prediction")
    # Explicit None/length check: a bare ``if next_pred:`` raises ValueError
    # for NumPy arrays holding more than one element.
    has_next_pred = next_pred is not None and len(next_pred) > 0

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(actual, label="Actual", color="blue", linewidth=2)
    ax.plot(forecast, label="Forecast", color="orange", linestyle="--", linewidth=2)

    if has_next_pred:
        # Predictions continue on the x axis right after the last actual point.
        first_index = len(actual)
        positions = range(first_index, first_index + len(next_pred))
        ax.scatter(positions, next_pred, color="red", label="Next Prediction(s)", zorder=5)
        for x_pos, val in zip(positions, next_pred):
            # White rounded box behind the label keeps it legible over the lines.
            ax.text(
                x_pos, val, f"{val:.2f}",
                color="red", fontsize=8, ha='center', va='bottom',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.3'),
            )

    ax.legend()
    ax.set_title("Actual vs Forecasted Values")
    ax.set_xlabel("Time Index")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.tight_layout()
    return fig
def plot_future_forecast(df, result):
    """Plot the historical series and, when present, the future forecast.

    Args:
        df: DataFrame with a ``'Date'`` column (x axis) and a ``'value'``
            column (y axis).
        result: Mapping that may contain ``"latest_prediction"``, a sized
            sequence of numeric values. They are placed on consecutive
            business days (``freq='B'``) starting one day after the last
            entry of ``df['Date']``.

    Returns:
        The matplotlib figure containing the plot.
    """
    figure, axes = plt.subplots(figsize=(10, 6))
    axes.plot(df['Date'], df['value'], label="Historical Data", color="blue", linewidth=2)

    if "latest_prediction" in result:
        final_date = df['Date'].iloc[-1]
        predictions = result["latest_prediction"]
        n_steps = len(predictions)
        first_future_day = final_date + pd.Timedelta(days=1)
        future_dates = pd.date_range(start=first_future_day, periods=n_steps, freq='B')
        axes.plot(
            future_dates,
            predictions,
            label="Forecast",
            color="orange",
            linestyle="--",
            linewidth=2,
        )
        # Annotate every forecast point with its value.
        for when, val in zip(future_dates, predictions):
            axes.text(when, val, f"{val:.2f}", color="orange")

    axes.legend()
    axes.set_title("Historical Data and Future Forecast")
    axes.set_xlabel("Date")
    axes.set_ylabel("Value")
    axes.grid(True)
    plt.tight_layout()
    return figure
def plot_metrics_r2(result):
    """Draw a bar chart of the ``R2`` and ``MAPE`` entries of the metrics.

    Args:
        result: Mapping whose ``'metrics'`` entry maps metric names to numeric
            values. Only the ``'R2'`` and ``'MAPE'`` entries are plotted.

    Returns:
        The matplotlib figure containing the bar chart.

    Raises:
        KeyError: If ``'metrics'`` is missing from ``result``.
    """
    import math  # local import: only needed for the axis-limit computation

    fig, ax = plt.subplots(figsize=(6, 4))
    metrics = {k: v for k, v in result['metrics'].items() if k in ('R2', 'MAPE')}
    names = list(metrics)
    values = list(metrics.values())
    sns.barplot(x=names, y=values, ax=ax, palette="Blues_d")
    ax.set_title("R² and MAPE Metrics")

    # Keep [-1, 1] as the minimum window, but widen it so that bars outside
    # that range (e.g. MAPE given in percent, or R² below -1) are not clipped.
    finite = []
    for value in values:
        try:
            as_float = float(value)
        except (TypeError, ValueError):
            continue  # non-numeric entries do not influence the limits
        if math.isfinite(as_float):
            finite.append(as_float)
    lower = min([-1.0] + finite)
    upper = max([1.0] + finite)
    # Small head-room so an out-of-range bar does not touch the frame.
    if lower < -1.0:
        lower *= 1.05
    if upper > 1.0:
        upper *= 1.05
    ax.set_ylim(lower, upper)

    ax.grid(True)
    plt.tight_layout()
    return fig
def plot_metrics_errors(result):
    """Draw a bar chart of the ``RMSE`` and ``MAE`` entries of the metrics.

    Args:
        result: Mapping whose ``'metrics'`` entry maps metric names to values.
            Only the ``'RMSE'`` and ``'MAE'`` entries are plotted, in the
            order they appear in that mapping.

    Returns:
        The matplotlib figure containing the bar chart.
    """
    figure, axes = plt.subplots(figsize=(6, 4))

    wanted = ('RMSE', 'MAE')
    labels = []
    heights = []
    for metric_name, score in result['metrics'].items():
        if metric_name in wanted:
            labels.append(metric_name)
            heights.append(score)

    sns.barplot(x=labels, y=heights, ax=axes, palette="Reds_d")
    axes.set_title("RMSE and MAE Metrics")
    axes.grid(True)
    plt.tight_layout()
    return figure
def plot_loss_curve(result):
    """Plot the training loss and, when available, the validation loss.

    Args:
        result: Mapping that may contain ``'train_loss'`` and ``'val_loss'``;
            each defaults to an empty list when absent. The validation curve
            is drawn only when ``'val_loss'`` is truthy.

    Returns:
        The matplotlib figure containing the loss curves.
    """
    training_history = result.get('train_loss', [])
    validation_history = result.get('val_loss', [])

    figure, axes = plt.subplots(figsize=(8, 5))

    # The training curve is always drawn; validation only if data exists.
    curves = [(training_history, 'Train Loss', 'blue')]
    if validation_history:
        curves.append((validation_history, 'Validation Loss', 'orange'))
    for series, caption, colour in curves:
        axes.plot(series, label=caption, color=colour, linewidth=2)

    axes.set_title('Training and Validation Loss Curve')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Loss (MSE)')
    axes.legend()
    axes.grid(True)
    plt.tight_layout()
    return figure