aromidvar commited on
Commit
0fd9ef8
·
verified ·
1 Parent(s): b1ccdc7

Update core/plot.py

Browse files
Files changed (1) hide show
  1. core/plot.py +102 -64
core/plot.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import plotly.graph_objects as go
2
  import plotly.express as px
3
  import pandas as pd
@@ -39,20 +40,20 @@ def plot_indicators(df, ticker):
39
  logging.warning("Missing columns for candlestick: 'Date', 'Open', 'High', 'Low', 'value'")
40
 
41
  for ma in ['sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50']:
42
- if ma in df:
43
  logging.debug(f"Adding {ma} trace")
44
  fig.add_trace(
45
  go.Scatter(x=df['Date'], y=df[ma], name=ma.upper(), line=dict(width=1.5)),
46
  row=1, col=1
47
  )
48
  else:
49
- logging.warning(f"{ma} not found in DataFrame")
50
 
51
  # Bollinger Bands
52
  bb_u = 'bbu_20_2.0' if 'bbu_20_2.0' in df else ('bbu_20_2' if 'bbu_20_2' in df else None)
53
  bb_m = 'bbm_20_2.0' if 'bbm_20_2.0' in df else ('bbm_20_2' if 'bbm_20_2' in df else None)
54
  bb_l = 'bbl_20_2.0' if 'bbl_20_2.0' in df else ('bbl_20_2' if 'bbl_20_2' in df else None)
55
- if bb_u and bb_m and bb_l:
56
  logging.debug("Adding Bollinger Bands traces")
57
  fig.add_trace(
58
  go.Scatter(x=df['Date'], y=df[bb_u], name='BB Upper', line=dict(color='gray', dash='dot')),
@@ -67,40 +68,52 @@ def plot_indicators(df, ticker):
67
  row=1, col=1
68
  )
69
  else:
70
- logging.warning(f"Bollinger Bands columns missing: {bb_u}, {bb_m}, {bb_l}")
71
 
72
  # Signals
73
- if 'Signal' in df:
74
  logging.debug("Adding signal traces")
75
  buy_signals = df[df['Signal'] == 'Buy']
76
  sell_signals = df[df['Signal'] == 'Sell']
77
  hold_signals = df[df['Signal'] == 'Hold']
78
- fig.add_trace(
79
- go.Scatter(
80
- x=buy_signals['Date'], y=buy_signals['value'], mode='markers+text',
81
- name='Buy', marker=dict(symbol='triangle-up', size=12, color='green'),
82
- text=['Buy'] * len(buy_signals), textposition='top center'
83
- ), row=1, col=1
84
- )
85
- fig.add_trace(
86
- go.Scatter(
87
- x=sell_signals['Date'], y=sell_signals['value'], mode='markers+text',
88
- name='Sell', marker=dict(symbol='triangle-down', size=12, color='red'),
89
- text=['Sell'] * len(sell_signals), textposition='bottom center'
90
- ), row=1, col=1
91
- )
92
- fig.add_trace(
93
- go.Scatter(
94
- x=hold_signals['Date'], y=hold_signals['value'], mode='markers',
95
- name='Hold', marker=dict(symbol='circle', size=8, color='gray'),
96
- opacity=0.5
97
- ), row=1, col=1
98
- )
 
 
 
 
 
 
 
 
 
 
 
 
99
  else:
100
- logging.warning("Signal column not found in DataFrame")
101
 
102
  # Position Size and Risk Annotation
103
- if 'atr_14' in df:
104
  atr = df['atr_14'].iloc[-1]
105
  stop_distance = atr * 2
106
  position_size = (10000 * 0.01) / stop_distance if stop_distance != 0 else 0
@@ -111,20 +124,20 @@ def plot_indicators(df, ticker):
111
  font=dict(color="black", size=12)
112
  )
113
  else:
114
- logging.warning("atr_14 not found for position size annotation")
115
 
116
  # Volume
117
- if 'Volume' in df:
118
  logging.debug("Adding volume trace")
119
  fig.add_trace(
120
  go.Bar(x=df['Date'], y=df['Volume'], name='Volume', marker_color='blue', opacity=0.5),
121
  row=2, col=1
122
  )
123
  else:
124
- logging.warning("Volume column not found in DataFrame")
125
 
126
  # MACD & RSI
127
- if 'macd_12_26_9' in df and 'macds_12_26_9' in df and 'macdh_12_26_9' in df:
128
  logging.debug("Adding MACD traces")
129
  fig.add_trace(
130
  go.Scatter(x=df['Date'], y=df['macd_12_26_9'], name='MACD', line=dict(color='blue')),
@@ -139,9 +152,9 @@ def plot_indicators(df, ticker):
139
  row=3, col=1
140
  )
141
  else:
142
- logging.warning("MACD columns (macd_12_26_9, macds_12_26_9, macdh_12_26_9) not found")
143
 
144
- if 'rsi_14' in df:
145
  logging.debug("Adding RSI 14 trace")
146
  fig.add_trace(
147
  go.Scatter(x=df['Date'], y=df['rsi_14'], name='RSI 14', line=dict(color='purple')),
@@ -152,20 +165,20 @@ def plot_indicators(df, ticker):
152
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[30, 30],
153
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=3, col=1)
154
  else:
155
- logging.warning("rsi_14 not found in DataFrame")
156
 
157
  for rsi in ['rsi_21', 'rsi_50']:
158
- if rsi in df:
159
  logging.debug(f"Adding {rsi} trace")
160
  fig.add_trace(
161
  go.Scatter(x=df['Date'], y=df[rsi], name=rsi.upper(), line=dict(color='magenta' if rsi == 'rsi_21' else 'cyan', dash='dash' if rsi == 'rsi_21' else 'dot')),
162
  row=3, col=1
163
  )
164
  else:
165
- logging.warning(f"{rsi} not found in DataFrame")
166
 
167
  # Stochastic & Williams %R
168
- if 'stochk_14_3_3' in df and 'stochd_14_3_3' in df:
169
  logging.debug("Adding Stochastic traces")
170
  fig.add_trace(
171
  go.Scatter(x=df['Date'], y=df['stochk_14_3_3'], name='Stoch %K', line=dict(color='blue')),
@@ -180,9 +193,9 @@ def plot_indicators(df, ticker):
180
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[20, 20],
181
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=4, col=1)
182
  else:
183
- logging.warning("Stochastic columns (stochk_14_3_3, stochd_14_3_3) not found")
184
 
185
- if 'willr_14' in df:
186
  logging.debug("Adding Williams %R trace")
187
  fig.add_trace(
188
  go.Scatter(x=df['Date'], y=df['willr_14'], name='Williams %R', line=dict(color='green')),
@@ -193,10 +206,10 @@ def plot_indicators(df, ticker):
193
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[-80, -80],
194
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=4, col=1)
195
  else:
196
- logging.warning("willr_14 not found in DataFrame")
197
 
198
  # ADX & DI
199
- if all(col in df for col in ['adx_14', 'pdi_14', 'mdi_14']):
200
  logging.debug("Adding ADX and DI traces")
201
  fig.add_trace(
202
  go.Scatter(x=df['Date'], y=df['adx_14'], name='ADX', line=dict(color='blue')),
@@ -213,19 +226,19 @@ def plot_indicators(df, ticker):
213
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[25, 25],
214
  showlegend=False, line=dict(color='black', dash='dash', width=1)), row=5, col=1)
215
  else:
216
- logging.warning("ADX/DI columns (adx_14, pdi_14, mdi_14) not found")
217
 
218
  # ATR & CCI
219
- if 'atr_14' in df:
220
  logging.debug("Adding ATR trace")
221
  fig.add_trace(
222
  go.Scatter(x=df['Date'], y=df['atr_14'], name='ATR', line=dict(color='orange')),
223
  row=6, col=1
224
  )
225
  else:
226
- logging.warning("atr_14 not found in DataFrame")
227
 
228
- if 'cci_20' in df:
229
  logging.debug("Adding CCI trace")
230
  fig.add_trace(
231
  go.Scatter(x=df['Date'], y=df['cci_20'], name='CCI', line=dict(color='purple')),
@@ -236,10 +249,10 @@ def plot_indicators(df, ticker):
236
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[-100, -100],
237
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=6, col=1)
238
  else:
239
- logging.warning("cci_20 not found in DataFrame")
240
 
241
  # Signal Strength
242
- if all(col in df for col in ['RSI_Signal', 'MACD_Signal', 'ADX_Signal', 'Sentiment_Signal', 'Model_Signal']):
243
  logging.debug("Adding signal strength trace")
244
  signal_strength = (
245
  df['RSI_Signal'].abs() +
@@ -257,7 +270,7 @@ def plot_indicators(df, ticker):
257
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[3, 3],
258
  showlegend=False, line=dict(color='orange', dash='dash', width=1)), row=7, col=1)
259
  else:
260
- logging.warning("Signal strength columns (RSI_Signal, MACD_Signal, ADX_Signal, Sentiment_Signal, Model_Signal) not found")
261
 
262
  fig.update_layout(
263
  title=f"{ticker} Price and Technical Indicators",
@@ -282,11 +295,11 @@ def plot_future_forecast(df, result, timeframe):
282
  try:
283
  logging.debug(f"Starting plot_future_forecast for timeframe: {timeframe}")
284
  fig = go.Figure()
285
- if 'Date' in df and 'value' in df:
286
  logging.debug("Adding historical close trace")
287
  fig.add_trace(go.Scatter(x=df['Date'], y=df['value'], name='Historical Close', line=dict(color='blue')))
288
  else:
289
- logging.warning("Missing 'Date' or 'value' columns for historical close")
290
 
291
  if "latest_prediction" in result:
292
  last_date = df['Date'].iloc[-1]
@@ -507,12 +520,12 @@ def plot_model_architecture(result):
507
  dummy_input = torch.randn(1, result['arch']['window'], result['arch']['input_dim'])
508
  graph = make_dot(model(dummy_input), params=dict(model.named_parameters()))
509
  graph.format = 'png'
510
- graph.render("model_arch", cleanup=True)
511
  logging.debug("Model architecture graph rendered")
512
  fig = go.Figure()
513
  fig.add_layout_image(
514
  dict(
515
- source="data:image/png;base64," + base64.b64encode(open("model_arch.png", "rb").read()).decode(),
516
  xref="paper", yref="paper",
517
  x=0, y=1,
518
  sizex=1, sizey=1,
@@ -522,7 +535,9 @@ def plot_model_architecture(result):
522
  fig.update_layout(
523
  title="Model Architecture Graph",
524
  template="plotly_dark",
525
- showlegend=False
 
 
526
  )
527
  logging.info("Model architecture plot generated")
528
  return fig
@@ -536,11 +551,11 @@ def plot_signals(signals_df, ticker):
536
  logging.debug(f"Signals DataFrame columns: {signals_df.columns.tolist()}")
537
  fig = go.Figure()
538
  x_col = 'Date' if 'Date' in signals_df.columns else signals_df.index
539
- if 'Price' in signals_df:
540
  logging.debug("Adding price trace")
541
  fig.add_trace(go.Scatter(x=signals_df[x_col], y=signals_df['Price'], mode='lines', name='Price', line=dict(color='blue')))
542
  else:
543
- logging.warning("Price column not found in signals_df")
544
  buy_signals = signals_df[signals_df['Signal'] == 'Buy']
545
  sell_signals = signals_df[signals_df['Signal'] == 'Sell']
546
  if not buy_signals.empty:
@@ -575,20 +590,43 @@ def plot_backtest(result, df, ticker):
575
  logging.warning("Actual or forecast data missing")
576
  return None
577
  logging.debug(f"Actual length: {len(actual)}, Forecast length: {len(forecast)}")
578
- last_historical_date = df['Date'].iloc[-len(actual) - 1]
579
- historical_dates = df['Date'].iloc[-len(actual) - 1: -len(actual) + len(actual)]
580
- forecast_dates = pd.date_range(start=last_historical_date + timedelta(days=1), periods=len(forecast))
581
- historical_values = df['value'].iloc[-len(actual) - 1: -len(actual) + len(actual)]
582
- logging.debug("Adding historical and forecast traces")
 
 
 
 
 
 
583
  fig = go.Figure()
584
- fig.add_trace(go.Scatter(x=historical_dates, y=historical_values, mode='lines', name='Historical', line=dict(color='blue')))
585
- fig.add_trace(go.Scatter(x=forecast_dates, y=forecast, mode='lines', name='Forecast', line=dict(color='orange', dash='dash')))
 
 
 
 
 
 
 
 
 
 
 
 
586
  fig.update_layout(
587
  title=f"{ticker} Backtest: Historical and Prediction",
588
  xaxis_title="Date",
589
  yaxis_title="Price",
590
  template="plotly_dark",
591
- showlegend=True
 
 
 
 
 
592
  )
593
  logging.info(f"Backtest plot generated for {ticker}")
594
  return fig
 
1
+ # core/plot.py
2
  import plotly.graph_objects as go
3
  import plotly.express as px
4
  import pandas as pd
 
40
  logging.warning("Missing columns for candlestick: 'Date', 'Open', 'High', 'Low', 'value'")
41
 
42
  for ma in ['sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50']:
43
+ if ma in df and not df[ma].isna().all():
44
  logging.debug(f"Adding {ma} trace")
45
  fig.add_trace(
46
  go.Scatter(x=df['Date'], y=df[ma], name=ma.upper(), line=dict(width=1.5)),
47
  row=1, col=1
48
  )
49
  else:
50
+ logging.warning(f"{ma} not found or all NaN in DataFrame")
51
 
52
  # Bollinger Bands
53
  bb_u = 'bbu_20_2.0' if 'bbu_20_2.0' in df else ('bbu_20_2' if 'bbu_20_2' in df else None)
54
  bb_m = 'bbm_20_2.0' if 'bbm_20_2.0' in df else ('bbm_20_2' if 'bbm_20_2' in df else None)
55
  bb_l = 'bbl_20_2.0' if 'bbl_20_2.0' in df else ('bbl_20_2' if 'bbl_20_2' in df else None)
56
+ if bb_u and bb_m and bb_l and not df[bb_u].isna().all() and not df[bb_m].isna().all() and not df[bb_l].isna().all():
57
  logging.debug("Adding Bollinger Bands traces")
58
  fig.add_trace(
59
  go.Scatter(x=df['Date'], y=df[bb_u], name='BB Upper', line=dict(color='gray', dash='dot')),
 
68
  row=1, col=1
69
  )
70
  else:
71
+ logging.warning(f"Bollinger Bands columns missing or all NaN: {bb_u}, {bb_m}, {bb_l}")
72
 
73
  # Signals
74
+ if 'Signal' in df and not df['Signal'].isna().all():
75
  logging.debug("Adding signal traces")
76
  buy_signals = df[df['Signal'] == 'Buy']
77
  sell_signals = df[df['Signal'] == 'Sell']
78
  hold_signals = df[df['Signal'] == 'Hold']
79
+ if not buy_signals.empty:
80
+ logging.debug(f"Adding {len(buy_signals)} buy signal markers")
81
+ fig.add_trace(
82
+ go.Scatter(
83
+ x=buy_signals['Date'], y=buy_signals['value'], mode='markers+text',
84
+ name='Buy', marker=dict(symbol='triangle-up', size=12, color='green'),
85
+ text=['Buy'] * len(buy_signals), textposition='top center'
86
+ ), row=1, col=1
87
+ )
88
+ else:
89
+ logging.warning("No buy signals found")
90
+ if not sell_signals.empty:
91
+ logging.debug(f"Adding {len(sell_signals)} sell signal markers")
92
+ fig.add_trace(
93
+ go.Scatter(
94
+ x=sell_signals['Date'], y=sell_signals['value'], mode='markers+text',
95
+ name='Sell', marker=dict(symbol='triangle-down', size=12, color='red'),
96
+ text=['Sell'] * len(sell_signals), textposition='bottom center'
97
+ ), row=1, col=1
98
+ )
99
+ else:
100
+ logging.warning("No sell signals found")
101
+ if not hold_signals.empty:
102
+ logging.debug(f"Adding {len(hold_signals)} hold signal markers")
103
+ fig.add_trace(
104
+ go.Scatter(
105
+ x=hold_signals['Date'], y=hold_signals['value'], mode='markers',
106
+ name='Hold', marker=dict(symbol='circle', size=8, color='gray'),
107
+ opacity=0.5
108
+ ), row=1, col=1
109
+ )
110
+ else:
111
+ logging.warning("No hold signals found")
112
  else:
113
+ logging.warning("Signal column not found or all NaN in DataFrame")
114
 
115
  # Position Size and Risk Annotation
116
+ if 'atr_14' in df and not df['atr_14'].isna().all():
117
  atr = df['atr_14'].iloc[-1]
118
  stop_distance = atr * 2
119
  position_size = (10000 * 0.01) / stop_distance if stop_distance != 0 else 0
 
124
  font=dict(color="black", size=12)
125
  )
126
  else:
127
+ logging.warning("atr_14 not found or all NaN for position size annotation")
128
 
129
  # Volume
130
+ if 'Volume' in df and not df['Volume'].isna().all():
131
  logging.debug("Adding volume trace")
132
  fig.add_trace(
133
  go.Bar(x=df['Date'], y=df['Volume'], name='Volume', marker_color='blue', opacity=0.5),
134
  row=2, col=1
135
  )
136
  else:
137
+ logging.warning("Volume column not found or all NaN in DataFrame")
138
 
139
  # MACD & RSI
140
+ if all(col in df for col in ['macd_12_26_9', 'macds_12_26_9', 'macdh_12_26_9']) and not df['macd_12_26_9'].isna().all():
141
  logging.debug("Adding MACD traces")
142
  fig.add_trace(
143
  go.Scatter(x=df['Date'], y=df['macd_12_26_9'], name='MACD', line=dict(color='blue')),
 
152
  row=3, col=1
153
  )
154
  else:
155
+ logging.warning("MACD columns (macd_12_26_9, macds_12_26_9, macdh_12_26_9) not found or all NaN")
156
 
157
+ if 'rsi_14' in df and not df['rsi_14'].isna().all():
158
  logging.debug("Adding RSI 14 trace")
159
  fig.add_trace(
160
  go.Scatter(x=df['Date'], y=df['rsi_14'], name='RSI 14', line=dict(color='purple')),
 
165
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[30, 30],
166
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=3, col=1)
167
  else:
168
+ logging.warning("rsi_14 not found or all NaN in DataFrame")
169
 
170
  for rsi in ['rsi_21', 'rsi_50']:
171
+ if rsi in df and not df[rsi].isna().all():
172
  logging.debug(f"Adding {rsi} trace")
173
  fig.add_trace(
174
  go.Scatter(x=df['Date'], y=df[rsi], name=rsi.upper(), line=dict(color='magenta' if rsi == 'rsi_21' else 'cyan', dash='dash' if rsi == 'rsi_21' else 'dot')),
175
  row=3, col=1
176
  )
177
  else:
178
+ logging.warning(f"{rsi} not found or all NaN in DataFrame")
179
 
180
  # Stochastic & Williams %R
181
+ if all(col in df for col in ['stochk_14_3_3', 'stochd_14_3_3']) and not df['stochk_14_3_3'].isna().all():
182
  logging.debug("Adding Stochastic traces")
183
  fig.add_trace(
184
  go.Scatter(x=df['Date'], y=df['stochk_14_3_3'], name='Stoch %K', line=dict(color='blue')),
 
193
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[20, 20],
194
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=4, col=1)
195
  else:
196
+ logging.warning("Stochastic columns (stochk_14_3_3, stochd_14_3_3) not found or all NaN")
197
 
198
+ if 'willr_14' in df and not df['willr_14'].isna().all():
199
  logging.debug("Adding Williams %R trace")
200
  fig.add_trace(
201
  go.Scatter(x=df['Date'], y=df['willr_14'], name='Williams %R', line=dict(color='green')),
 
206
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[-80, -80],
207
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=4, col=1)
208
  else:
209
+ logging.warning("willr_14 not found or all NaN in DataFrame")
210
 
211
  # ADX & DI
212
+ if all(col in df for col in ['adx_14', 'pdi_14', 'mdi_14']) and not df['adx_14'].isna().all():
213
  logging.debug("Adding ADX and DI traces")
214
  fig.add_trace(
215
  go.Scatter(x=df['Date'], y=df['adx_14'], name='ADX', line=dict(color='blue')),
 
226
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[25, 25],
227
  showlegend=False, line=dict(color='black', dash='dash', width=1)), row=5, col=1)
228
  else:
229
+ logging.warning("ADX/DI columns (adx_14, pdi_14, mdi_14) not found or all NaN")
230
 
231
  # ATR & CCI
232
+ if 'atr_14' in df and not df['atr_14'].isna().all():
233
  logging.debug("Adding ATR trace")
234
  fig.add_trace(
235
  go.Scatter(x=df['Date'], y=df['atr_14'], name='ATR', line=dict(color='orange')),
236
  row=6, col=1
237
  )
238
  else:
239
+ logging.warning("atr_14 not found or all NaN in DataFrame")
240
 
241
+ if 'cci_20' in df and not df['cci_20'].isna().all():
242
  logging.debug("Adding CCI trace")
243
  fig.add_trace(
244
  go.Scatter(x=df['Date'], y=df['cci_20'], name='CCI', line=dict(color='purple')),
 
249
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[-100, -100],
250
  showlegend=False, line=dict(color='green', dash='dash', width=1)), row=6, col=1)
251
  else:
252
+ logging.warning("cci_20 not found or all NaN in DataFrame")
253
 
254
  # Signal Strength
255
+ if all(col in df for col in ['RSI_Signal', 'MACD_Signal', 'ADX_Signal', 'Sentiment_Signal', 'Model_Signal']) and not df['RSI_Signal'].isna().all():
256
  logging.debug("Adding signal strength trace")
257
  signal_strength = (
258
  df['RSI_Signal'].abs() +
 
270
  fig.add_trace(go.Scatter(x=[df['Date'].min(), df['Date'].max()], y=[3, 3],
271
  showlegend=False, line=dict(color='orange', dash='dash', width=1)), row=7, col=1)
272
  else:
273
+ logging.warning("Signal strength columns (RSI_Signal, MACD_Signal, ADX_Signal, Sentiment_Signal, Model_Signal) not found or all NaN")
274
 
275
  fig.update_layout(
276
  title=f"{ticker} Price and Technical Indicators",
 
295
  try:
296
  logging.debug(f"Starting plot_future_forecast for timeframe: {timeframe}")
297
  fig = go.Figure()
298
+ if 'Date' in df and 'value' in df and not df['value'].isna().all():
299
  logging.debug("Adding historical close trace")
300
  fig.add_trace(go.Scatter(x=df['Date'], y=df['value'], name='Historical Close', line=dict(color='blue')))
301
  else:
302
+ logging.warning("Missing 'Date' or 'value' columns or all NaN for historical close")
303
 
304
  if "latest_prediction" in result:
305
  last_date = df['Date'].iloc[-1]
 
520
  dummy_input = torch.randn(1, result['arch']['window'], result['arch']['input_dim'])
521
  graph = make_dot(model(dummy_input), params=dict(model.named_parameters()))
522
  graph.format = 'png'
523
+ graph.render("/tmp/model_arch", cleanup=True)
524
  logging.debug("Model architecture graph rendered")
525
  fig = go.Figure()
526
  fig.add_layout_image(
527
  dict(
528
+ source="data:image/png;base64," + base64.b64encode(open("/tmp/model_arch.png", "rb").read()).decode(),
529
  xref="paper", yref="paper",
530
  x=0, y=1,
531
  sizex=1, sizey=1,
 
535
  fig.update_layout(
536
  title="Model Architecture Graph",
537
  template="plotly_dark",
538
+ showlegend=False,
539
+ height=800,
540
+ width=1200
541
  )
542
  logging.info("Model architecture plot generated")
543
  return fig
 
551
  logging.debug(f"Signals DataFrame columns: {signals_df.columns.tolist()}")
552
  fig = go.Figure()
553
  x_col = 'Date' if 'Date' in signals_df.columns else signals_df.index
554
+ if 'Price' in signals_df and not signals_df['Price'].isna().all():
555
  logging.debug("Adding price trace")
556
  fig.add_trace(go.Scatter(x=signals_df[x_col], y=signals_df['Price'], mode='lines', name='Price', line=dict(color='blue')))
557
  else:
558
+ logging.warning("Price column not found or all NaN in signals_df")
559
  buy_signals = signals_df[signals_df['Signal'] == 'Buy']
560
  sell_signals = signals_df[signals_df['Signal'] == 'Sell']
561
  if not buy_signals.empty:
 
590
  logging.warning("Actual or forecast data missing")
591
  return None
592
  logging.debug(f"Actual length: {len(actual)}, Forecast length: {len(forecast)}")
593
+ # Use the last N actual points plus forecast
594
+ n_historical = len(actual)
595
+ historical_dates = df['Date'].iloc[-n_historical:]
596
+ historical_values = df['value'].iloc[-n_historical:]
597
+ forecast_dates = pd.date_range(start=historical_dates.iloc[-1] + timedelta(days=1), periods=len(forecast), freq='D')
598
+ # Combine historical and forecast for continuous plot
599
+ all_dates = pd.concat([pd.Series(historical_dates), pd.Series(forecast_dates)]).reset_index(drop=True)
600
+ all_values = np.concatenate([historical_values, forecast])
601
+ all_types = ['Historical'] * len(historical_values) + ['Forecast'] * len(forecast)
602
+ plot_df = pd.DataFrame({'Date': all_dates, 'Price': all_values, 'Type': all_types})
603
+ logging.debug(f"Combined plot data: {plot_df.head().to_dict()}")
604
  fig = go.Figure()
605
+ fig.add_trace(go.Scatter(
606
+ x=plot_df[plot_df['Type'] == 'Historical']['Date'],
607
+ y=plot_df[plot_df['Type'] == 'Historical']['Price'],
608
+ mode='lines',
609
+ name='Historical',
610
+ line=dict(color='blue')
611
+ ))
612
+ fig.add_trace(go.Scatter(
613
+ x=plot_df[plot_df['Type'] == 'Forecast']['Date'],
614
+ y=plot_df[plot_df['Type'] == 'Forecast']['Price'],
615
+ mode='lines',
616
+ name='Forecast',
617
+ line=dict(color='orange', dash='dash')
618
+ ))
619
  fig.update_layout(
620
  title=f"{ticker} Backtest: Historical and Prediction",
621
  xaxis_title="Date",
622
  yaxis_title="Price",
623
  template="plotly_dark",
624
+ showlegend=True,
625
+ height=600,
626
+ xaxis=dict(tickformat="%Y-%m-%d", minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey"),
627
+ yaxis=dict(gridcolor="lightgrey"),
628
+ plot_bgcolor="white",
629
+ paper_bgcolor="white"
630
  )
631
  logging.info(f"Backtest plot generated for {ticker}")
632
  return fig