# StockPredict / core/plot.py
# Last change by aromidvar1355: "Update core/plot.py" (commit 00dca38, verified)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
def plot_forecast(result):
    """Plot actual values against the forecast, plus any next-step predictions.

    Args:
        result: Mapping holding the series to draw. ``result["actual"]`` and
            ``result["forecast"]`` are passed straight to ``Axes.plot``.
            The optional ``result["latest_prediction"]`` may be a scalar, a
            list, or an array of any shape; it is flattened and drawn as red
            points (with value labels) right after the last actual point.

    Returns:
        matplotlib.figure.Figure: The newly created figure.
    """
    forecast = result["forecast"]
    actual = result["actual"]

    # Flatten to 1-D so scalars, lists and (1, horizon)-shaped arrays all work,
    # and so the emptiness check below never hits numpy's "truth value of an
    # array with more than one element is ambiguous" error.
    next_pred = result.get("latest_prediction")
    next_pred = np.ravel(next_pred) if next_pred is not None else np.array([])

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(actual, label="Actual", color="blue", linewidth=2)
    ax.plot(forecast, label="Forecast", color="orange", linestyle="--", linewidth=2)

    if next_pred.size > 0:
        # Predictions continue the integer x-axis right after the last actual point.
        # NOTE(review): assumes `actual` is drawn at positions 0..len(actual)-1; a
        # pandas Series with a non-default index is plotted against that index
        # instead, and these points would not line up — confirm with callers.
        first_x = len(actual)
        pred_x = range(first_x, first_x + next_pred.size)
        ax.scatter(pred_x, next_pred, color="red", label="Next Prediction(s)", zorder=5)
        for x, val in zip(pred_x, next_pred):
            # Boxed label keeps the value readable on top of the lines and grid.
            ax.text(
                x, val, f"{val:.2f}",
                color="red", fontsize=8, ha='center', va='bottom',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='red', boxstyle='round,pad=0.3'),
            )

    ax.legend()
    ax.set_title("Actual vs Forecasted Values")
    ax.set_xlabel("Time Index")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.tight_layout()
    return fig
def plot_future_forecast(df, result):
    """Plot the historical series followed by the forecast on future business days.

    Args:
        df: DataFrame with a ``'Date'`` column (x-axis) and a ``'value'``
            column (y-axis); its last ``'Date'`` anchors the forecast dates.
        result: Mapping that may hold ``"latest_prediction"``: a scalar, list
            or array of any shape, flattened to one value per future step.
            A missing, ``None`` or empty prediction draws no forecast line.

    Returns:
        matplotlib.figure.Figure: The newly created figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(df['Date'], df['value'], label="Historical Data", color="blue", linewidth=2)

    preds = result.get("latest_prediction")
    # Flatten so a (1, horizon) model output yields `horizon` points, and so a
    # None value is skipped instead of crashing on len(None).
    preds = np.ravel(preds) if preds is not None else np.array([])

    if preds.size > 0:
        last_date = df['Date'].iloc[-1]
        # freq='B' produces business days only; a start date that falls on a
        # weekend is rolled forward to the next business day.
        future_dates = pd.date_range(
            start=last_date + pd.Timedelta(days=1), periods=preds.size, freq='B'
        )
        ax.plot(future_dates, preds, label="Forecast", color="orange", linestyle="--", linewidth=2)
        for date, val in zip(future_dates, preds):
            ax.text(date, val, f"{val:.2f}", color="orange")

    ax.legend()
    ax.set_title("Historical Data and Future Forecast")
    ax.set_xlabel("Date")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.tight_layout()
    return fig
def plot_metrics_r2(result):
    """Draw a bar chart of the ``'R2'`` and ``'MAPE'`` entries of ``result['metrics']``.

    Other metrics are ignored. The y-axis spans [-1, 1] (the useful range for
    R²) unless a value lies outside it. In that case the axis is widened, with
    a small margin, so that no bar is clipped.

    Args:
        result: Mapping whose ``'metrics'`` entry maps metric names to numbers.

    Returns:
        matplotlib.figure.Figure: The newly created figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    metrics = {k: v for k, v in result['metrics'].items() if k in ['R2', 'MAPE']}
    sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), ax=ax, palette="Blues_d")
    ax.set_title("R² and MAPE Metrics")

    # A fixed (-1, 1) window silently hid bars for MAPE > 1 (e.g. in percent)
    # or R² < -1. Non-finite values are skipped because set_ylim rejects NaN/inf.
    finite = [float(v) for v in metrics.values() if np.isfinite(v)]
    lower = min([-1.0, *finite])
    upper = max([1.0, *finite])
    margin = 0.05 * (upper - lower)
    ax.set_ylim(
        lower - margin if lower < -1.0 else lower,
        upper + margin if upper > 1.0 else upper,
    )

    ax.grid(True)
    plt.tight_layout()
    return fig
def plot_metrics_errors(result):
    """Draw a bar chart of the RMSE and MAE entries found in ``result['metrics']``.

    Every other metric is ignored. Bars appear in the order the selected keys
    occur in ``result['metrics']``.

    Args:
        result: Mapping whose ``'metrics'`` entry maps metric names to numbers.

    Returns:
        matplotlib.figure.Figure: The newly created figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    wanted = ('RMSE', 'MAE')
    selected = [(name, score) for name, score in result['metrics'].items() if name in wanted]
    names = [name for name, _ in selected]
    scores = [score for _, score in selected]

    sns.barplot(x=names, y=scores, ax=ax, palette="Reds_d")
    ax.set_title("RMSE and MAE Metrics")
    ax.grid(True)

    plt.tight_layout()
    return fig
def plot_loss_curve(result):
    """Plot per-epoch training loss and, if present, validation loss.

    Args:
        result: Mapping that may hold ``'train_loss'`` and ``'val_loss'``
            sequences (lists or 1-D arrays), each plotted against its
            position index. Missing or ``None`` entries are treated as empty;
            an empty validation series is not drawn.

    Returns:
        matplotlib.figure.Figure: The newly created figure.
    """
    train_losses = result.get('train_loss')
    val_losses = result.get('val_loss')
    # None cannot be plotted; treat it like a missing key.
    if train_losses is None:
        train_losses = []

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(train_losses, label='Train Loss', color='blue', linewidth=2)
    # Explicit len() check: `if val_losses:` raises ValueError for a numpy
    # array with more than one element.
    if val_losses is not None and len(val_losses) > 0:
        ax.plot(val_losses, label='Validation Loss', color='orange', linewidth=2)

    ax.set_title('Training and Validation Loss Curve')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MSE)')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()
    return fig