aromidvar commited on
Commit
0b9fb08
·
verified ·
1 Parent(s): 858646a

Update core/plot.py

Browse files
Files changed (1) hide show
  1. core/plot.py +60 -49
core/plot.py CHANGED
@@ -6,7 +6,7 @@ import logging
6
 
7
  logging.basicConfig(level=logging.INFO)
8
 
9
- def plot_indicators(df, ticker, account_size=10000, risk_percent=1):
10
  try:
11
  fig = make_subplots(
12
  rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.02,
@@ -16,9 +16,9 @@ def plot_indicators(df, ticker, account_size=10000, risk_percent=1):
16
  fig.add_trace(go.Candlestick(x=df['Date'], open=df['Open'], high=df['High'], low=df['Low'], close=df['value'],
17
  name='Price', increasing_line_color='#00CC96', decreasing_line_color='#EF553B'), row=1, col=1)
18
  for ma in ['sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50']:
19
- if ma in df.columns:
20
  fig.add_trace(go.Scatter(x=df['Date'], y=df[ma], name=ma.upper(), line=dict(width=1)), row=1, col=1)
21
- if 'bbu_20_2.0' in df.columns:
22
  fig.add_trace(go.Scatter(x=df['Date'], y=df['bbu_20_2.0'], name='BB Upper', line=dict(color='gray', dash='dot')), row=1, col=1)
23
  fig.add_trace(go.Scatter(x=df['Date'], y=df['bbm_20_2.0'], name='BB Middle', line=dict(color='gray')), row=1, col=1)
24
  fig.add_trace(go.Scatter(x=df['Date'], y=df['bbl_20_2.0'], name='BB Lower', line=dict(color='gray', dash='dot')), row=1, col=1)
@@ -26,43 +26,66 @@ def plot_indicators(df, ticker, account_size=10000, risk_percent=1):
26
  sell_signals = df[df['Signal'] == 'Sell']
27
  fig.add_trace(go.Scatter(x=buy_signals['Date'], y=buy_signals['value'], mode='markers', name='Buy', marker=dict(symbol='triangle-up', size=12, color='green')), row=1, col=1)
28
  fig.add_trace(go.Scatter(x=sell_signals['Date'], y=sell_signals['value'], mode='markers', name='Sell', marker=dict(symbol='triangle-down', size=12, color='red')), row=1, col=1)
29
- if 'atr_14' in df.columns:
30
  atr = df['atr_14'].iloc[-1]
31
  stop_distance = atr * 2
32
- position_size = (account_size * (risk_percent / 100)) / stop_distance
33
- fig.add_annotation(text=f"Suggested Position Size: {position_size:.0f} shares (Risk {risk_percent}%, ATR {atr:.2f})",
34
  xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(color="black", size=12))
35
  fig.add_trace(go.Bar(x=df['Date'], y=df['Volume'], name='Volume', marker_color='blue'), row=2, col=1)
36
- if 'macd_12_26_9' in df.columns:
37
  fig.add_trace(go.Scatter(x=df['Date'], y=df['macd_12_26_9'], name='MACD', line=dict(color='blue')), row=3, col=1)
38
  fig.add_trace(go.Scatter(x=df['Date'], y=df['macds_12_26_9'], name='MACD Signal', line=dict(color='orange')), row=3, col=1)
39
  fig.add_trace(go.Bar(x=df['Date'], y=df['macdh_12_26_9'], name='MACD Hist', marker_color='gray'), row=3, col=1)
40
- if 'rsi_14' in df.columns:
41
  fig.add_trace(go.Scatter(x=df['Date'], y=df['rsi_14'], name='RSI 14', line=dict(color='purple')), row=3, col=1)
42
- fig.add_hline(y=70, line_dash="dash", line_color="red", row=3, col=1, annotation_text="Overbought")
43
- fig.add_hline(y=30, line_dash="dash", line_color="green", row=3, col=1, annotation_text="Oversold")
44
- if 'stochk_14_3_3' in df.columns:
45
  fig.add_trace(go.Scatter(x=df['Date'], y=df['stochk_14_3_3'], name='Stoch %K', line=dict(color='blue')), row=4, col=1)
46
  fig.add_trace(go.Scatter(x=df['Date'], y=df['stochd_14_3_3'], name='Stoch %D', line=dict(color='orange')), row=4, col=1)
47
- if 'willr_14' in df.columns:
48
  fig.add_trace(go.Scatter(x=df['Date'], y=df['willr_14'], name='Williams %R', line=dict(color='green')), row=4, col=1)
49
  fig.add_hline(y=-20, line_dash="dash", line_color="red", row=4, col=1)
50
  fig.add_hline(y=-80, line_dash="dash", line_color="green", row=4, col=1)
51
- if 'adx_14' in df.columns:
52
  fig.add_trace(go.Scatter(x=df['Date'], y=df['adx_14'], name='ADX', line=dict(color='blue')), row=5, col=1)
53
- if 'pdi_14' in df.columns:
54
- fig.add_trace(go.Scatter(x=df['Date'], y=df['pdi_14'], name='+DI', line=dict(color='green')), row=5, col=1)
55
- if 'mdi_14' in df.columns:
56
- fig.add_trace(go.Scatter(x=df['Date'], y=df['mdi_14'], name='-DI', line=dict(color='red')), row=5, col=1)
57
- if 'atr_14' in df.columns:
58
  fig.add_trace(go.Scatter(x=df['Date'], y=df['atr_14'], name='ATR', line=dict(color='orange')), row=5, col=1)
59
- if 'cci_20' in df.columns:
60
  fig.add_trace(go.Scatter(x=df['Date'], y=df['cci_20'], name='CCI', line=dict(color='purple')), row=5, col=1)
61
  fig.add_hline(y=100, line_dash="dash", line_color="red", row=5, col=1)
62
  fig.add_hline(y=-100, line_dash="dash", line_color="green", row=5, col=1)
63
- fig.update_layout(title=f"{ticker} Advanced Trading Chart", height=1200, width=1400, template="plotly_white", showlegend=True, xaxis_rangeslider_visible=False)
64
- fig.update_xaxes(tickformat="%Y-%m-%d %H:%M", rangeslider_visible=False, tickangle=45, automargin=True)
65
- fig.update_yaxes(automargin=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  return fig
67
  except Exception as e:
68
  logging.error(f"Plot indicators error: {e}")
@@ -75,27 +98,24 @@ def plot_forecast(result, df):
75
  if len(actual) == 0 or len(forecast) == 0:
76
  logging.warning("No actual or forecast data provided for plotting.")
77
  return None
78
- dates = df['Date'].iloc[-len(actual):]
 
 
 
79
  logging.info(f"Plotting forecast with date range: {dates.iloc[0]} to {dates.iloc[-1]}")
80
  if len(dates) != len(actual):
81
  logging.warning(f"Date length ({len(dates)}) does not match actual data length ({len(actual)}).")
82
  return None
83
  fig = go.Figure()
84
  fig.add_trace(go.Scatter(x=dates, y=actual, mode='lines', name='Actual', line=dict(color='#00CC96')))
85
- prev = forecast[0]
86
- for i in range(1, len(forecast)):
87
- color = '#00FF00' if forecast[i] > prev else '#FF0000'
88
- fig.add_trace(go.Scatter(x=dates[i-1:i+1], y=[forecast[i-1], forecast[i]], mode='lines', showlegend=False, line=dict(color=color, width=2)))
89
- prev = forecast[i]
90
- trend = 'Rising' if forecast[-1] > forecast[0] else 'Falling'
91
- fig.add_annotation(text=f"Trend: {trend}", xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(size=14, color="blue"))
92
  fig.update_layout(
93
  title="Actual vs Forecast",
94
  template="plotly_white",
95
  height=600,
96
  xaxis_title="Date",
97
  yaxis_title="Price",
98
- xaxis=dict(tickformat="%Y-%m-%d %H:%M", tickangle=45, automargin=True, minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey", minor_gridcolor="whitesmoke"),
99
  yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
100
  plot_bgcolor="white",
101
  paper_bgcolor="white"
@@ -114,27 +134,15 @@ def plot_future_forecast(df, result, timeframe):
114
  freq_map = {'1m': 'T', '2m': '2T', '5m': '5T', '15m': '15T', '30m': '30T', '60m': 'H', '90m': '90T', '1h': 'H', '1d': 'D', '5d': '5D', '1wk': 'W', '1mo': 'M', '3mo': '3M'}
115
  future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=len(latest_pred), freq=freq_map.get(timeframe, 'D'))
116
  fig = go.Figure()
117
- hist_y = df['value'].tail(30).values
118
- hist_x = df['Date'].tail(30)
119
- fig.add_trace(go.Scatter(x=hist_x, y=hist_y, mode='lines', name='Historical', line=dict(color='#00CC96')))
120
- prev = latest_pred[0]
121
- for i in range(1, len(latest_pred)):
122
- color = '#00FF00' if latest_pred[i] > prev else '#FF0000'
123
- fig.add_trace(go.Scatter(x=[future_dates[i-1], future_dates[i]], y=[latest_pred[i-1], latest_pred[i]], mode='lines', showlegend=False, line=dict(color=color, width=2)))
124
- prev = latest_pred[i]
125
- ci_lower, ci_upper = result["metrics"].get("CI_Lower", latest_pred[0] - 1), result["metrics"].get("CI_Upper", latest_pred[0] + 1)
126
- ci_width = (ci_upper - ci_lower) / 2
127
- fig.add_trace(go.Scatter(x=future_dates, y=[p + ci_width for p in latest_pred], mode='lines', line=dict(color='transparent'), showlegend=False, name='CI Upper'))
128
- fig.add_trace(go.Scatter(x=future_dates, y=[p - ci_width for p in latest_pred], fill='tonexty', mode='lines', line=dict(color='transparent'), fillcolor='rgba(0,100,80,0.2)', name='Confidence Interval'))
129
- trend = 'Rising' if latest_pred[-1] > latest_pred[0] else 'Falling'
130
- fig.add_annotation(text=f"Trend: {trend}", xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(size=14, color="blue"))
131
  fig.update_layout(
132
  title="Future Forecast",
133
  template="plotly_white",
134
  height=600,
135
  xaxis_title="Date",
136
  yaxis_title="Price",
137
- xaxis=dict(tickformat="%Y-%m-%d %H:%M", tickangle=45, automargin=True, minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey", minor_gridcolor="whitesmoke"),
138
  yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
139
  plot_bgcolor="white",
140
  paper_bgcolor="white"
@@ -174,9 +182,12 @@ def plot_metrics_errors(result):
174
  return None
175
  fig = go.Figure()
176
  fig.add_trace(go.Bar(
177
- x=['RMSE', 'MAE'],
178
- y=[metrics.get('RMSE', 0), metrics.get('MAE', 0)],
179
- marker_color=['#1E90FF', '#FF6347']
 
 
 
180
  ))
181
  fig.update_layout(
182
  title="Error Metrics",
 
6
 
7
  logging.basicConfig(level=logging.INFO)
8
 
9
+ def plot_indicators(df, ticker):
10
  try:
11
  fig = make_subplots(
12
  rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.02,
 
16
  fig.add_trace(go.Candlestick(x=df['Date'], open=df['Open'], high=df['High'], low=df['Low'], close=df['value'],
17
  name='Price', increasing_line_color='#00CC96', decreasing_line_color='#EF553B'), row=1, col=1)
18
  for ma in ['sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50']:
19
+ if ma in df:
20
  fig.add_trace(go.Scatter(x=df['Date'], y=df[ma], name=ma.upper(), line=dict(width=1)), row=1, col=1)
21
+ if 'bbu_20_2.0' in df:
22
  fig.add_trace(go.Scatter(x=df['Date'], y=df['bbu_20_2.0'], name='BB Upper', line=dict(color='gray', dash='dot')), row=1, col=1)
23
  fig.add_trace(go.Scatter(x=df['Date'], y=df['bbm_20_2.0'], name='BB Middle', line=dict(color='gray')), row=1, col=1)
24
  fig.add_trace(go.Scatter(x=df['Date'], y=df['bbl_20_2.0'], name='BB Lower', line=dict(color='gray', dash='dot')), row=1, col=1)
 
26
  sell_signals = df[df['Signal'] == 'Sell']
27
  fig.add_trace(go.Scatter(x=buy_signals['Date'], y=buy_signals['value'], mode='markers', name='Buy', marker=dict(symbol='triangle-up', size=12, color='green')), row=1, col=1)
28
  fig.add_trace(go.Scatter(x=sell_signals['Date'], y=sell_signals['value'], mode='markers', name='Sell', marker=dict(symbol='triangle-down', size=12, color='red')), row=1, col=1)
29
+ if 'atr_14' in df:
30
  atr = df['atr_14'].iloc[-1]
31
  stop_distance = atr * 2
32
+ position_size = (10000 * (1 / 100)) / stop_distance
33
+ fig.add_annotation(text=f"Suggested Position Size: {position_size:.0f} shares (Risk 1%, ATR {atr:.2f})",
34
  xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(color="black", size=12))
35
  fig.add_trace(go.Bar(x=df['Date'], y=df['Volume'], name='Volume', marker_color='blue'), row=2, col=1)
36
+ if 'macd_12_26_9' in df:
37
  fig.add_trace(go.Scatter(x=df['Date'], y=df['macd_12_26_9'], name='MACD', line=dict(color='blue')), row=3, col=1)
38
  fig.add_trace(go.Scatter(x=df['Date'], y=df['macds_12_26_9'], name='MACD Signal', line=dict(color='orange')), row=3, col=1)
39
  fig.add_trace(go.Bar(x=df['Date'], y=df['macdh_12_26_9'], name='MACD Hist', marker_color='gray'), row=3, col=1)
40
+ if 'rsi_14' in df:
41
  fig.add_trace(go.Scatter(x=df['Date'], y=df['rsi_14'], name='RSI 14', line=dict(color='purple')), row=3, col=1)
42
+ fig.add_hline(y=70, line_dash="dash", line_color="red", row=3, col=1)
43
+ fig.add_hline(y=30, line_dash="dash", line_color="green", row=3, col=1)
44
+ if 'stochk_14_3_3' in df:
45
  fig.add_trace(go.Scatter(x=df['Date'], y=df['stochk_14_3_3'], name='Stoch %K', line=dict(color='blue')), row=4, col=1)
46
  fig.add_trace(go.Scatter(x=df['Date'], y=df['stochd_14_3_3'], name='Stoch %D', line=dict(color='orange')), row=4, col=1)
47
+ if 'willr_14' in df:
48
  fig.add_trace(go.Scatter(x=df['Date'], y=df['willr_14'], name='Williams %R', line=dict(color='green')), row=4, col=1)
49
  fig.add_hline(y=-20, line_dash="dash", line_color="red", row=4, col=1)
50
  fig.add_hline(y=-80, line_dash="dash", line_color="green", row=4, col=1)
51
+ if 'adx_14' in df:
52
  fig.add_trace(go.Scatter(x=df['Date'], y=df['adx_14'], name='ADX', line=dict(color='blue')), row=5, col=1)
53
+ fig.add_trace(go.Scatter(x=df['Date'], y=df.get('pdi_14'), name='+DI', line=dict(color='green')), row=5, col=1)
54
+ fig.add_trace(go.Scatter(x=df['Date'], y=df.get('mdi_14'), name='-DI', line=dict(color='red')), row=5, col=1)
55
+ if 'atr_14' in df:
 
 
56
  fig.add_trace(go.Scatter(x=df['Date'], y=df['atr_14'], name='ATR', line=dict(color='orange')), row=5, col=1)
57
+ if 'cci_20' in df:
58
  fig.add_trace(go.Scatter(x=df['Date'], y=df['cci_20'], name='CCI', line=dict(color='purple')), row=5, col=1)
59
  fig.add_hline(y=100, line_dash="dash", line_color="red", row=5, col=1)
60
  fig.add_hline(y=-100, line_dash="dash", line_color="green", row=5, col=1)
61
+ fig.update_layout(title=f"{ticker} Price and Technical Indicators",
62
+ template="plotly_white",
63
+ height=2000,
64
+ width=1400,
65
+ showlegend=True,
66
+ xaxis_rangeslider_visible=False,
67
+ margin=dict(l=40, r=40, t=80, b=40),
68
+ xaxis=dict(
69
+ tickformat="%Y-%m-%d",
70
+ minor=dict(ticks="inside", showgrid=True),
71
+ gridcolor="lightgrey",
72
+ minor_gridcolor="whitesmoke",
73
+ showticklabels=True
74
+ ),
75
+ yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
76
+ title_font=dict(size=20, color="black"),
77
+ plot_bgcolor="white",
78
+ paper_bgcolor="white"
79
+ )
80
+ for i in range(2, 6):
81
+ fig.update_xaxes(
82
+ tickformat="%Y-%m-%d",
83
+ minor=dict(ticks="inside", showgrid=True),
84
+ gridcolor="lightgrey",
85
+ minor_gridcolor="whitesmoke",
86
+ showticklabels=True if i == 5 else False,
87
+ row=i, col=1
88
+ )
89
  return fig
90
  except Exception as e:
91
  logging.error(f"Plot indicators error: {e}")
 
98
  if len(actual) == 0 or len(forecast) == 0:
99
  logging.warning("No actual or forecast data provided for plotting.")
100
  return None
101
+ if 'Date' not in df.columns:
102
+ logging.error("DataFrame missing 'Date' column.")
103
+ return None
104
+ dates = df['Date'].iloc[-len(actual):] # Align with actual data length
105
  logging.info(f"Plotting forecast with date range: {dates.iloc[0]} to {dates.iloc[-1]}")
106
  if len(dates) != len(actual):
107
  logging.warning(f"Date length ({len(dates)}) does not match actual data length ({len(actual)}).")
108
  return None
109
  fig = go.Figure()
110
  fig.add_trace(go.Scatter(x=dates, y=actual, mode='lines', name='Actual', line=dict(color='#00CC96')))
111
+ fig.add_trace(go.Scatter(x=dates, y=forecast, mode='lines', name='Forecast', line=dict(color='#EF553B')))
 
 
 
 
 
 
112
  fig.update_layout(
113
  title="Actual vs Forecast",
114
  template="plotly_white",
115
  height=600,
116
  xaxis_title="Date",
117
  yaxis_title="Price",
118
+ xaxis=dict(tickformat="%Y-%m-%d", minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey", minor_gridcolor="whitesmoke"),
119
  yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
120
  plot_bgcolor="white",
121
  paper_bgcolor="white"
 
134
  freq_map = {'1m': 'T', '2m': '2T', '5m': '5T', '15m': '15T', '30m': '30T', '60m': 'H', '90m': '90T', '1h': 'H', '1d': 'D', '5d': '5D', '1wk': 'W', '1mo': 'M', '3mo': '3M'}
135
  future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=len(latest_pred), freq=freq_map.get(timeframe, 'D'))
136
  fig = go.Figure()
137
+ fig.add_trace(go.Scatter(x=df['Date'].tail(30), y=df['value'].tail(30), mode='lines', name='Historical', line=dict(color='#00CC96')))
138
+ fig.add_trace(go.Scatter(x=future_dates, y=latest_pred, mode='lines', name='Forecast', line=dict(color='#EF553B')))
 
 
 
 
 
 
 
 
 
 
 
 
139
  fig.update_layout(
140
  title="Future Forecast",
141
  template="plotly_white",
142
  height=600,
143
  xaxis_title="Date",
144
  yaxis_title="Price",
145
+ xaxis=dict(tickformat="%Y-%m-%d", minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey", minor_gridcolor="whitesmoke"),
146
  yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
147
  plot_bgcolor="white",
148
  paper_bgcolor="white"
 
182
  return None
183
  fig = go.Figure()
184
  fig.add_trace(go.Bar(
185
+ x=['RMSE', 'MAE', 'CI_Lower', 'CI_Upper'],
186
+ y=[
187
+ metrics.get('RMSE', 0), metrics.get('MAE', 0),
188
+ metrics.get('CI_Lower', 0), metrics.get('CI_Upper', 0)
189
+ ],
190
+ marker_color=['#1E90FF', '#FF6347', '#00CED1', '#FF4500']
191
  ))
192
  fig.update_layout(
193
  title="Error Metrics",