"""Figure builders for forecast results.

Every public function creates a new figure through ``plt.subplots`` and
returns it without showing or closing it; what happens to the figure
afterwards is up to the caller.
"""
import math

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def _has_values(seq):
    """Return True when *seq* is not None and holds at least one element.

    A plain ``if seq:`` raises ``ValueError`` for NumPy arrays with more than
    one element, so emptiness is decided with ``len`` instead.
    """
    return seq is not None and len(seq) > 0


def _limits_covering(values, low, high, pad=0.05):
    """Return ``(low, high)`` widened just enough to contain *values*.

    The default window is kept whenever every finite value already fits in
    it; otherwise the offending bound is moved past the extreme value by a
    *pad* fraction so the bar is not clipped. None and non-finite entries
    are ignored.
    """
    finite = [float(v) for v in values if v is not None and math.isfinite(v)]
    if not finite:
        return low, high
    smallest, largest = min(finite), max(finite)
    if smallest < low:
        low = smallest - pad * abs(smallest)
    if largest > high:
        high = largest + pad * abs(largest)
    return low, high


def _plot_metric_bars(result, keys, palette, title, ylim=None):
    """Draw a bar chart of the entries of ``result['metrics']`` named in *keys*.

    Bars follow the order in which the metrics appear in the dict. When
    *ylim* is given it is the default y-window, widened if a value falls
    outside it. Returns the figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    metrics = {k: v for k, v in result['metrics'].items() if k in keys}
    names = list(metrics.keys())
    values = list(metrics.values())
    # With none of the requested metrics present there is nothing to draw;
    # the titled, empty axes is still returned.
    if names:
        sns.barplot(x=names, y=values, ax=ax, palette=palette)
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*_limits_covering(values, *ylim))
    ax.grid(True)
    plt.tight_layout()
    return fig


def plot_forecast(result):
    """Plot actual vs. forecast values, plus any next-step predictions.

    Args:
        result: Mapping with ``"forecast"`` and ``"actual"`` sequences and an
            optional ``"latest_prediction"`` sequence. Predictions are drawn
            as red markers at the indices following the last actual value,
            each labelled with its value to two decimals.

    Returns:
        The matplotlib figure.
    """
    forecast = result["forecast"]
    actual = result["actual"]
    next_pred = result.get("latest_prediction")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(actual, label="Actual", color="blue", linewidth=2)
    ax.plot(forecast, label="Forecast", color="orange", linestyle="--",
            linewidth=2)

    if _has_values(next_pred):
        start = len(actual)
        positions = range(start, start + len(next_pred))
        ax.scatter(positions, next_pred, color="red",
                   label="Next Prediction(s)", zorder=5)
        for x_pos, val in zip(positions, next_pred):
            # White rounded box behind the label keeps it readable on top
            # of the plotted lines.
            ax.text(
                x_pos,
                val,
                f"{val:.2f}",
                color="red",
                fontsize=8,
                ha='center',
                va='bottom',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='red',
                          boxstyle='round,pad=0.3'),
            )

    ax.legend()
    ax.set_title("Actual vs Forecasted Values")
    ax.set_xlabel("Time Index")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.tight_layout()
    return fig


def plot_future_forecast(df, result):
    """Plot the historical series and the forecast that extends past it.

    Args:
        df: DataFrame with ``'Date'`` and ``'value'`` columns. The last date
            has ``pd.Timedelta(days=1)`` added to it, so the column is
            presumably datetime-like -- TODO confirm against the caller.
        result: Mapping with an optional ``"latest_prediction"`` sequence.
            Its values are placed on consecutive business days (``freq='B'``)
            starting the day after the last historical date. A missing, None
            or empty prediction leaves only the historical line.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(df['Date'], df['value'], label="Historical Data", color="blue",
            linewidth=2)

    predictions = result.get("latest_prediction")
    if _has_values(predictions):
        last_date = df['Date'].iloc[-1]
        future_dates = pd.date_range(
            start=last_date + pd.Timedelta(days=1),
            periods=len(predictions),
            freq='B',
        )
        ax.plot(future_dates, predictions, label="Forecast", color="orange",
                linestyle="--", linewidth=2)
        for when, val in zip(future_dates, predictions):
            ax.text(when, val, f"{val:.2f}", color="orange")

    ax.legend()
    ax.set_title("Historical Data and Future Forecast")
    ax.set_xlabel("Date")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.tight_layout()
    return fig


def plot_metrics_r2(result):
    """Bar chart of the ``'R2'`` and ``'MAPE'`` entries of ``result['metrics']``.

    The y-axis defaults to [-1, 1] and is widened only when a metric falls
    outside that window (for example a strongly negative R², or a MAPE
    above 1), so no bar is clipped.

    Returns:
        The matplotlib figure.
    """
    return _plot_metric_bars(
        result,
        keys=('R2', 'MAPE'),
        palette="Blues_d",
        title="R² and MAPE Metrics",
        ylim=(-1, 1),
    )


def plot_metrics_errors(result):
    """Bar chart of the ``'RMSE'`` and ``'MAE'`` entries of ``result['metrics']``.

    The y-axis is left to matplotlib's autoscaling.

    Returns:
        The matplotlib figure.
    """
    return _plot_metric_bars(
        result,
        keys=('RMSE', 'MAE'),
        palette="Reds_d",
        title="RMSE and MAE Metrics",
    )


def plot_loss_curve(result):
    """Plot the training loss and, when available, the validation loss.

    Args:
        result: Mapping with optional ``'train_loss'`` and ``'val_loss'``
            sequences, plotted against their index (labelled "Epoch"). A
            missing or None training loss is drawn as an empty line; the
            validation line is omitted when it has no values.

    Returns:
        The matplotlib figure.
    """
    train_losses = result.get('train_loss')
    if train_losses is None:
        train_losses = []
    val_losses = result.get('val_loss')

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    if _has_values(val_losses):
        ax.plot(val_losses, label='Validation Loss', color='orange',
                linewidth=2)

    ax.set_title('Training and Validation Loss Curve')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MSE)')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()
    return fig