aromidvar commited on
Commit
f4998b5
·
verified ·
1 Parent(s): 0a0f956

Update core/plot.py

Browse files
Files changed (1) hide show
  1. core/plot.py +57 -1
core/plot.py CHANGED
@@ -32,4 +32,60 @@ def plot_future_forecast(df, result, timeframe):
32
  if "latest_prediction" in result:
33
  last_date = df['Date'].iloc[-1]
34
  horizon = len(result["latest_prediction"])
35
- freq_map = {'1m':'T', '2m':'2T', '5m':'5T
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if "latest_prediction" in result:
33
  last_date = df['Date'].iloc[-1]
34
  horizon = len(result["latest_prediction"])
35
+ freq_map = {'1m':'T', '2m':'2T', '5m':'5T', '15m':'15T', '30m':'30T', '60m':'H', '90m':'90T', '1h':'H', '1d':'D', '5d':'5D', '1wk':'W', '1mo':'M', '3mo':'3M'}
36
+ freq = freq_map.get(timeframe, 'D')
37
+ future_dates = pd.date_range(start=last_date + pd.to_timedelta(1, unit=freq.replace('T', 'min').replace('M', 'ME')), periods=horizon, freq=freq)
38
+ fig.add_trace(go.Scatter(x=future_dates, y=result["latest_prediction"], name="Forecast", mode="lines", line=dict(color="orange", dash="dash")))
39
+ fig.update_layout(title="Historical Data and Future Forecast", xaxis_title="Date", yaxis_title="Value")
40
+ return fig
41
+ except Exception as e:
42
+ logging.error(f"Plot future forecast error: {e}")
43
+ return None
44
+
45
+ def plot_indicators(df):
46
+ try:
47
+ fig = go.Figure()
48
+ fig.add_trace(go.Candlestick(x=df['Date'], open=df['Open'], high=df['High'], low=df['Low'], close=df['value'], name='Price'))
49
+ for col in df.columns:
50
+ if col not in ['Date', 'Open', 'High', 'Low', 'value', 'Volume']:
51
+ fig.add_trace(go.Scatter(x=df['Date'], y=df[col], name=col))
52
+ fig.update_layout(title="Stock Price with Indicators", yaxis_title="Value", xaxis_rangeslider_visible=True)
53
+ return fig
54
+ except Exception as e:
55
+ logging.error(f"Plot indicators error: {e}")
56
+ return None
57
+
58
+ def plot_metrics_r2(result):
59
+ if 'metrics' in result:
60
+ metrics = {k: v for k, v in result['metrics'].items() if k in ['R2', 'MAPE']}
61
+ df = pd.DataFrame(list(metrics.items()), columns=['Metric', 'Value'])
62
+ fig = px.bar(df, x='Metric', y='Value', title="R² and MAPE Metrics")
63
+ return fig
64
+ return None
65
+
66
+ def plot_metrics_errors(result):
67
+ if 'metrics' in result:
68
+ metrics = {k: v for k, v in result['metrics'].items() if k in ['RMSE', 'MAE']}
69
+ df = pd.DataFrame(list(metrics.items()), columns=['Metric', 'Value'])
70
+ fig = px.bar(df, x='Metric', y='Value', title="RMSE and MAE Metrics")
71
+ return fig
72
+ return None
73
+
74
+ def plot_loss_curve(result):
75
+ train_losses = result.get('train_loss', [])
76
+ val_losses = result.get('val_loss', [])
77
+ fig = go.Figure()
78
+ fig.add_trace(go.Scatter(y=train_losses, name='Train Loss', mode="lines", line=dict(color="blue")))
79
+ if val_losses:
80
+ fig.add_trace(go.Scatter(y=val_losses, name='Val Loss', mode="lines", line=dict(color="orange")))
81
+ fig.update_layout(title="Training Loss Curve", xaxis_title="Epoch", yaxis_title="Loss")
82
+ return fig
83
+
84
+ def plot_model_architecture(result):
85
+ if "model_summary" in result:
86
+ fig, ax = plt.subplots(figsize=(10, 6))
87
+ ax.text(0, 1, result["model_summary"], fontsize=8, family='monospace', va='top')
88
+ ax.axis('off')
89
+ plt.tight_layout()
90
+ return fig
91
+ return None