Spaces:
Sleeping
Sleeping
Update core/plot.py
Browse files- core/plot.py +60 -49
core/plot.py
CHANGED
|
@@ -6,7 +6,7 @@ import logging
|
|
| 6 |
|
| 7 |
logging.basicConfig(level=logging.INFO)
|
| 8 |
|
| 9 |
-
def plot_indicators(df, ticker
|
| 10 |
try:
|
| 11 |
fig = make_subplots(
|
| 12 |
rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.02,
|
|
@@ -16,9 +16,9 @@ def plot_indicators(df, ticker, account_size=10000, risk_percent=1):
|
|
| 16 |
fig.add_trace(go.Candlestick(x=df['Date'], open=df['Open'], high=df['High'], low=df['Low'], close=df['value'],
|
| 17 |
name='Price', increasing_line_color='#00CC96', decreasing_line_color='#EF553B'), row=1, col=1)
|
| 18 |
for ma in ['sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50']:
|
| 19 |
-
if ma in df
|
| 20 |
fig.add_trace(go.Scatter(x=df['Date'], y=df[ma], name=ma.upper(), line=dict(width=1)), row=1, col=1)
|
| 21 |
-
if 'bbu_20_2.0' in df
|
| 22 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['bbu_20_2.0'], name='BB Upper', line=dict(color='gray', dash='dot')), row=1, col=1)
|
| 23 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['bbm_20_2.0'], name='BB Middle', line=dict(color='gray')), row=1, col=1)
|
| 24 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['bbl_20_2.0'], name='BB Lower', line=dict(color='gray', dash='dot')), row=1, col=1)
|
|
@@ -26,43 +26,66 @@ def plot_indicators(df, ticker, account_size=10000, risk_percent=1):
|
|
| 26 |
sell_signals = df[df['Signal'] == 'Sell']
|
| 27 |
fig.add_trace(go.Scatter(x=buy_signals['Date'], y=buy_signals['value'], mode='markers', name='Buy', marker=dict(symbol='triangle-up', size=12, color='green')), row=1, col=1)
|
| 28 |
fig.add_trace(go.Scatter(x=sell_signals['Date'], y=sell_signals['value'], mode='markers', name='Sell', marker=dict(symbol='triangle-down', size=12, color='red')), row=1, col=1)
|
| 29 |
-
if 'atr_14' in df
|
| 30 |
atr = df['atr_14'].iloc[-1]
|
| 31 |
stop_distance = atr * 2
|
| 32 |
-
position_size = (
|
| 33 |
-
fig.add_annotation(text=f"Suggested Position Size: {position_size:.0f} shares (Risk
|
| 34 |
xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(color="black", size=12))
|
| 35 |
fig.add_trace(go.Bar(x=df['Date'], y=df['Volume'], name='Volume', marker_color='blue'), row=2, col=1)
|
| 36 |
-
if 'macd_12_26_9' in df
|
| 37 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['macd_12_26_9'], name='MACD', line=dict(color='blue')), row=3, col=1)
|
| 38 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['macds_12_26_9'], name='MACD Signal', line=dict(color='orange')), row=3, col=1)
|
| 39 |
fig.add_trace(go.Bar(x=df['Date'], y=df['macdh_12_26_9'], name='MACD Hist', marker_color='gray'), row=3, col=1)
|
| 40 |
-
if 'rsi_14' in df
|
| 41 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['rsi_14'], name='RSI 14', line=dict(color='purple')), row=3, col=1)
|
| 42 |
-
fig.add_hline(y=70, line_dash="dash", line_color="red", row=3, col=1
|
| 43 |
-
fig.add_hline(y=30, line_dash="dash", line_color="green", row=3, col=1
|
| 44 |
-
if 'stochk_14_3_3' in df
|
| 45 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['stochk_14_3_3'], name='Stoch %K', line=dict(color='blue')), row=4, col=1)
|
| 46 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['stochd_14_3_3'], name='Stoch %D', line=dict(color='orange')), row=4, col=1)
|
| 47 |
-
if 'willr_14' in df
|
| 48 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['willr_14'], name='Williams %R', line=dict(color='green')), row=4, col=1)
|
| 49 |
fig.add_hline(y=-20, line_dash="dash", line_color="red", row=4, col=1)
|
| 50 |
fig.add_hline(y=-80, line_dash="dash", line_color="green", row=4, col=1)
|
| 51 |
-
if 'adx_14' in df
|
| 52 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['adx_14'], name='ADX', line=dict(color='blue')), row=5, col=1)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
fig.add_trace(go.Scatter(x=df['Date'], y=df['mdi_14'], name='-DI', line=dict(color='red')), row=5, col=1)
|
| 57 |
-
if 'atr_14' in df.columns:
|
| 58 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['atr_14'], name='ATR', line=dict(color='orange')), row=5, col=1)
|
| 59 |
-
if 'cci_20' in df
|
| 60 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['cci_20'], name='CCI', line=dict(color='purple')), row=5, col=1)
|
| 61 |
fig.add_hline(y=100, line_dash="dash", line_color="red", row=5, col=1)
|
| 62 |
fig.add_hline(y=-100, line_dash="dash", line_color="green", row=5, col=1)
|
| 63 |
-
fig.update_layout(title=f"{ticker}
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return fig
|
| 67 |
except Exception as e:
|
| 68 |
logging.error(f"Plot indicators error: {e}")
|
|
@@ -75,27 +98,24 @@ def plot_forecast(result, df):
|
|
| 75 |
if len(actual) == 0 or len(forecast) == 0:
|
| 76 |
logging.warning("No actual or forecast data provided for plotting.")
|
| 77 |
return None
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
| 79 |
logging.info(f"Plotting forecast with date range: {dates.iloc[0]} to {dates.iloc[-1]}")
|
| 80 |
if len(dates) != len(actual):
|
| 81 |
logging.warning(f"Date length ({len(dates)}) does not match actual data length ({len(actual)}).")
|
| 82 |
return None
|
| 83 |
fig = go.Figure()
|
| 84 |
fig.add_trace(go.Scatter(x=dates, y=actual, mode='lines', name='Actual', line=dict(color='#00CC96')))
|
| 85 |
-
|
| 86 |
-
for i in range(1, len(forecast)):
|
| 87 |
-
color = '#00FF00' if forecast[i] > prev else '#FF0000'
|
| 88 |
-
fig.add_trace(go.Scatter(x=dates[i-1:i+1], y=[forecast[i-1], forecast[i]], mode='lines', showlegend=False, line=dict(color=color, width=2)))
|
| 89 |
-
prev = forecast[i]
|
| 90 |
-
trend = 'Rising' if forecast[-1] > forecast[0] else 'Falling'
|
| 91 |
-
fig.add_annotation(text=f"Trend: {trend}", xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(size=14, color="blue"))
|
| 92 |
fig.update_layout(
|
| 93 |
title="Actual vs Forecast",
|
| 94 |
template="plotly_white",
|
| 95 |
height=600,
|
| 96 |
xaxis_title="Date",
|
| 97 |
yaxis_title="Price",
|
| 98 |
-
xaxis=dict(tickformat="%Y-%m-%d
|
| 99 |
yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
|
| 100 |
plot_bgcolor="white",
|
| 101 |
paper_bgcolor="white"
|
|
@@ -114,27 +134,15 @@ def plot_future_forecast(df, result, timeframe):
|
|
| 114 |
freq_map = {'1m': 'T', '2m': '2T', '5m': '5T', '15m': '15T', '30m': '30T', '60m': 'H', '90m': '90T', '1h': 'H', '1d': 'D', '5d': '5D', '1wk': 'W', '1mo': 'M', '3mo': '3M'}
|
| 115 |
future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=len(latest_pred), freq=freq_map.get(timeframe, 'D'))
|
| 116 |
fig = go.Figure()
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
fig.add_trace(go.Scatter(x=hist_x, y=hist_y, mode='lines', name='Historical', line=dict(color='#00CC96')))
|
| 120 |
-
prev = latest_pred[0]
|
| 121 |
-
for i in range(1, len(latest_pred)):
|
| 122 |
-
color = '#00FF00' if latest_pred[i] > prev else '#FF0000'
|
| 123 |
-
fig.add_trace(go.Scatter(x=[future_dates[i-1], future_dates[i]], y=[latest_pred[i-1], latest_pred[i]], mode='lines', showlegend=False, line=dict(color=color, width=2)))
|
| 124 |
-
prev = latest_pred[i]
|
| 125 |
-
ci_lower, ci_upper = result["metrics"].get("CI_Lower", latest_pred[0] - 1), result["metrics"].get("CI_Upper", latest_pred[0] + 1)
|
| 126 |
-
ci_width = (ci_upper - ci_lower) / 2
|
| 127 |
-
fig.add_trace(go.Scatter(x=future_dates, y=[p + ci_width for p in latest_pred], mode='lines', line=dict(color='transparent'), showlegend=False, name='CI Upper'))
|
| 128 |
-
fig.add_trace(go.Scatter(x=future_dates, y=[p - ci_width for p in latest_pred], fill='tonexty', mode='lines', line=dict(color='transparent'), fillcolor='rgba(0,100,80,0.2)', name='Confidence Interval'))
|
| 129 |
-
trend = 'Rising' if latest_pred[-1] > latest_pred[0] else 'Falling'
|
| 130 |
-
fig.add_annotation(text=f"Trend: {trend}", xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(size=14, color="blue"))
|
| 131 |
fig.update_layout(
|
| 132 |
title="Future Forecast",
|
| 133 |
template="plotly_white",
|
| 134 |
height=600,
|
| 135 |
xaxis_title="Date",
|
| 136 |
yaxis_title="Price",
|
| 137 |
-
xaxis=dict(tickformat="%Y-%m-%d
|
| 138 |
yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
|
| 139 |
plot_bgcolor="white",
|
| 140 |
paper_bgcolor="white"
|
|
@@ -174,9 +182,12 @@ def plot_metrics_errors(result):
|
|
| 174 |
return None
|
| 175 |
fig = go.Figure()
|
| 176 |
fig.add_trace(go.Bar(
|
| 177 |
-
x=['RMSE', 'MAE'],
|
| 178 |
-
y=[
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
| 180 |
))
|
| 181 |
fig.update_layout(
|
| 182 |
title="Error Metrics",
|
|
|
|
| 6 |
|
| 7 |
logging.basicConfig(level=logging.INFO)
|
| 8 |
|
| 9 |
+
def plot_indicators(df, ticker):
|
| 10 |
try:
|
| 11 |
fig = make_subplots(
|
| 12 |
rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.02,
|
|
|
|
| 16 |
fig.add_trace(go.Candlestick(x=df['Date'], open=df['Open'], high=df['High'], low=df['Low'], close=df['value'],
|
| 17 |
name='Price', increasing_line_color='#00CC96', decreasing_line_color='#EF553B'), row=1, col=1)
|
| 18 |
for ma in ['sma_10', 'sma_20', 'sma_50', 'ema_12', 'ema_26', 'ema_50']:
|
| 19 |
+
if ma in df:
|
| 20 |
fig.add_trace(go.Scatter(x=df['Date'], y=df[ma], name=ma.upper(), line=dict(width=1)), row=1, col=1)
|
| 21 |
+
if 'bbu_20_2.0' in df:
|
| 22 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['bbu_20_2.0'], name='BB Upper', line=dict(color='gray', dash='dot')), row=1, col=1)
|
| 23 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['bbm_20_2.0'], name='BB Middle', line=dict(color='gray')), row=1, col=1)
|
| 24 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['bbl_20_2.0'], name='BB Lower', line=dict(color='gray', dash='dot')), row=1, col=1)
|
|
|
|
| 26 |
sell_signals = df[df['Signal'] == 'Sell']
|
| 27 |
fig.add_trace(go.Scatter(x=buy_signals['Date'], y=buy_signals['value'], mode='markers', name='Buy', marker=dict(symbol='triangle-up', size=12, color='green')), row=1, col=1)
|
| 28 |
fig.add_trace(go.Scatter(x=sell_signals['Date'], y=sell_signals['value'], mode='markers', name='Sell', marker=dict(symbol='triangle-down', size=12, color='red')), row=1, col=1)
|
| 29 |
+
if 'atr_14' in df:
|
| 30 |
atr = df['atr_14'].iloc[-1]
|
| 31 |
stop_distance = atr * 2
|
| 32 |
+
position_size = (10000 * (1 / 100)) / stop_distance
|
| 33 |
+
fig.add_annotation(text=f"Suggested Position Size: {position_size:.0f} shares (Risk 1%, ATR {atr:.2f})",
|
| 34 |
xref="paper", yref="paper", x=0.05, y=0.95, showarrow=False, font=dict(color="black", size=12))
|
| 35 |
fig.add_trace(go.Bar(x=df['Date'], y=df['Volume'], name='Volume', marker_color='blue'), row=2, col=1)
|
| 36 |
+
if 'macd_12_26_9' in df:
|
| 37 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['macd_12_26_9'], name='MACD', line=dict(color='blue')), row=3, col=1)
|
| 38 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['macds_12_26_9'], name='MACD Signal', line=dict(color='orange')), row=3, col=1)
|
| 39 |
fig.add_trace(go.Bar(x=df['Date'], y=df['macdh_12_26_9'], name='MACD Hist', marker_color='gray'), row=3, col=1)
|
| 40 |
+
if 'rsi_14' in df:
|
| 41 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['rsi_14'], name='RSI 14', line=dict(color='purple')), row=3, col=1)
|
| 42 |
+
fig.add_hline(y=70, line_dash="dash", line_color="red", row=3, col=1)
|
| 43 |
+
fig.add_hline(y=30, line_dash="dash", line_color="green", row=3, col=1)
|
| 44 |
+
if 'stochk_14_3_3' in df:
|
| 45 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['stochk_14_3_3'], name='Stoch %K', line=dict(color='blue')), row=4, col=1)
|
| 46 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['stochd_14_3_3'], name='Stoch %D', line=dict(color='orange')), row=4, col=1)
|
| 47 |
+
if 'willr_14' in df:
|
| 48 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['willr_14'], name='Williams %R', line=dict(color='green')), row=4, col=1)
|
| 49 |
fig.add_hline(y=-20, line_dash="dash", line_color="red", row=4, col=1)
|
| 50 |
fig.add_hline(y=-80, line_dash="dash", line_color="green", row=4, col=1)
|
| 51 |
+
if 'adx_14' in df:
|
| 52 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['adx_14'], name='ADX', line=dict(color='blue')), row=5, col=1)
|
| 53 |
+
fig.add_trace(go.Scatter(x=df['Date'], y=df.get('pdi_14'), name='+DI', line=dict(color='green')), row=5, col=1)
|
| 54 |
+
fig.add_trace(go.Scatter(x=df['Date'], y=df.get('mdi_14'), name='-DI', line=dict(color='red')), row=5, col=1)
|
| 55 |
+
if 'atr_14' in df:
|
|
|
|
|
|
|
| 56 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['atr_14'], name='ATR', line=dict(color='orange')), row=5, col=1)
|
| 57 |
+
if 'cci_20' in df:
|
| 58 |
fig.add_trace(go.Scatter(x=df['Date'], y=df['cci_20'], name='CCI', line=dict(color='purple')), row=5, col=1)
|
| 59 |
fig.add_hline(y=100, line_dash="dash", line_color="red", row=5, col=1)
|
| 60 |
fig.add_hline(y=-100, line_dash="dash", line_color="green", row=5, col=1)
|
| 61 |
+
fig.update_layout(title=f"{ticker} Price and Technical Indicators",
|
| 62 |
+
template="plotly_white",
|
| 63 |
+
height=2000,
|
| 64 |
+
width=1400,
|
| 65 |
+
showlegend=True,
|
| 66 |
+
xaxis_rangeslider_visible=False,
|
| 67 |
+
margin=dict(l=40, r=40, t=80, b=40),
|
| 68 |
+
xaxis=dict(
|
| 69 |
+
tickformat="%Y-%m-%d",
|
| 70 |
+
minor=dict(ticks="inside", showgrid=True),
|
| 71 |
+
gridcolor="lightgrey",
|
| 72 |
+
minor_gridcolor="whitesmoke",
|
| 73 |
+
showticklabels=True
|
| 74 |
+
),
|
| 75 |
+
yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
|
| 76 |
+
title_font=dict(size=20, color="black"),
|
| 77 |
+
plot_bgcolor="white",
|
| 78 |
+
paper_bgcolor="white"
|
| 79 |
+
)
|
| 80 |
+
for i in range(2, 6):
|
| 81 |
+
fig.update_xaxes(
|
| 82 |
+
tickformat="%Y-%m-%d",
|
| 83 |
+
minor=dict(ticks="inside", showgrid=True),
|
| 84 |
+
gridcolor="lightgrey",
|
| 85 |
+
minor_gridcolor="whitesmoke",
|
| 86 |
+
showticklabels=True if i == 5 else False,
|
| 87 |
+
row=i, col=1
|
| 88 |
+
)
|
| 89 |
return fig
|
| 90 |
except Exception as e:
|
| 91 |
logging.error(f"Plot indicators error: {e}")
|
|
|
|
| 98 |
if len(actual) == 0 or len(forecast) == 0:
|
| 99 |
logging.warning("No actual or forecast data provided for plotting.")
|
| 100 |
return None
|
| 101 |
+
if 'Date' not in df.columns:
|
| 102 |
+
logging.error("DataFrame missing 'Date' column.")
|
| 103 |
+
return None
|
| 104 |
+
dates = df['Date'].iloc[-len(actual):] # Align with actual data length
|
| 105 |
logging.info(f"Plotting forecast with date range: {dates.iloc[0]} to {dates.iloc[-1]}")
|
| 106 |
if len(dates) != len(actual):
|
| 107 |
logging.warning(f"Date length ({len(dates)}) does not match actual data length ({len(actual)}).")
|
| 108 |
return None
|
| 109 |
fig = go.Figure()
|
| 110 |
fig.add_trace(go.Scatter(x=dates, y=actual, mode='lines', name='Actual', line=dict(color='#00CC96')))
|
| 111 |
+
fig.add_trace(go.Scatter(x=dates, y=forecast, mode='lines', name='Forecast', line=dict(color='#EF553B')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
fig.update_layout(
|
| 113 |
title="Actual vs Forecast",
|
| 114 |
template="plotly_white",
|
| 115 |
height=600,
|
| 116 |
xaxis_title="Date",
|
| 117 |
yaxis_title="Price",
|
| 118 |
+
xaxis=dict(tickformat="%Y-%m-%d", minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey", minor_gridcolor="whitesmoke"),
|
| 119 |
yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
|
| 120 |
plot_bgcolor="white",
|
| 121 |
paper_bgcolor="white"
|
|
|
|
| 134 |
freq_map = {'1m': 'T', '2m': '2T', '5m': '5T', '15m': '15T', '30m': '30T', '60m': 'H', '90m': '90T', '1h': 'H', '1d': 'D', '5d': '5D', '1wk': 'W', '1mo': 'M', '3mo': '3M'}
|
| 135 |
future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=len(latest_pred), freq=freq_map.get(timeframe, 'D'))
|
| 136 |
fig = go.Figure()
|
| 137 |
+
fig.add_trace(go.Scatter(x=df['Date'].tail(30), y=df['value'].tail(30), mode='lines', name='Historical', line=dict(color='#00CC96')))
|
| 138 |
+
fig.add_trace(go.Scatter(x=future_dates, y=latest_pred, mode='lines', name='Forecast', line=dict(color='#EF553B')))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
fig.update_layout(
|
| 140 |
title="Future Forecast",
|
| 141 |
template="plotly_white",
|
| 142 |
height=600,
|
| 143 |
xaxis_title="Date",
|
| 144 |
yaxis_title="Price",
|
| 145 |
+
xaxis=dict(tickformat="%Y-%m-%d", minor=dict(ticks="inside", showgrid=True), gridcolor="lightgrey", minor_gridcolor="whitesmoke"),
|
| 146 |
yaxis=dict(gridcolor="lightgrey", minor=dict(ticks="inside", showgrid=True), minor_gridcolor="whitesmoke"),
|
| 147 |
plot_bgcolor="white",
|
| 148 |
paper_bgcolor="white"
|
|
|
|
| 182 |
return None
|
| 183 |
fig = go.Figure()
|
| 184 |
fig.add_trace(go.Bar(
|
| 185 |
+
x=['RMSE', 'MAE', 'CI_Lower', 'CI_Upper'],
|
| 186 |
+
y=[
|
| 187 |
+
metrics.get('RMSE', 0), metrics.get('MAE', 0),
|
| 188 |
+
metrics.get('CI_Lower', 0), metrics.get('CI_Upper', 0)
|
| 189 |
+
],
|
| 190 |
+
marker_color=['#1E90FF', '#FF6347', '#00CED1', '#FF4500']
|
| 191 |
))
|
| 192 |
fig.update_layout(
|
| 193 |
title="Error Metrics",
|