# stock_v2/utils/predict.py
# Copied from the Hugging Face file viewer: uploaded by jonathanjordan21,
# commit "init" (139173f), file size 8.43 kB.
import uuid
import numpy as np
import pandas as pd
from pandas.tseries.offsets import CustomBusinessDay
from pandas.tseries.holiday import USFederalHolidayCalendar
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# from prepare import make_sequences
def sharpe(r):
    """Return the annualized Sharpe ratio of a series of per-period returns.

    The ratio is ``sqrt(252) * mean(r) / std(r)`` using the population
    standard deviation (NumPy's default ``ddof=0``); no risk-free rate is
    subtracted. The factor 252 presumably stands for trading days per year,
    i.e. ``r`` is expected to hold daily returns.

    Args:
        r: Array-like of returns.

    Returns:
        The annualized ratio, or ``0.0`` when ``r`` is empty or has zero
        dispersion (the ratio is undefined in both cases).
    """
    returns = np.asarray(r, dtype=float)
    if returns.size == 0:
        # Without this guard NumPy emits warnings and the result is NaN.
        return 0.0
    # Compute the deviation once; it feeds both the guard and the ratio.
    spread = np.std(returns)
    if spread == 0:
        return 0.0
    return np.sqrt(252) * np.mean(returns) / spread
def cagr(equity):
    """Return the compound annual growth rate implied by an equity curve.

    The last value of ``equity`` is used directly as the total growth factor,
    so the curve is implicitly treated as having started from 1.0. The number
    of years is ``len(equity) / 252`` (252 presumably being trading days per
    year).

    Args:
        equity: Array-like equity curve, one value per period.

    Returns:
        ``equity[-1] ** (1 / years) - 1``, or ``0.0`` for an empty curve.
    """
    curve = np.asarray(equity, dtype=float)
    if curve.size == 0:
        # No observations: avoid indexing an empty array / dividing by zero years.
        return 0.0
    years = len(curve) / 252
    return curve[-1] ** (1 / years) - 1
def maxdd(equity):
    """Return the maximum drawdown of an equity curve.

    Each point is compared with the running maximum seen so far; the
    drawdown is ``equity / running_peak - 1`` and the most negative value is
    returned.

    Args:
        equity: Array-like equity curve.

    Returns:
        The worst drawdown as a non-positive fraction (e.g. ``-0.25`` for a
        25% fall from a peak), or ``0.0`` for an empty curve.
    """
    curve = np.asarray(equity, dtype=float)
    if curve.size == 0:
        # ``min`` of an empty array raises, so report "no drawdown" instead.
        return 0.0
    running_peak = np.maximum.accumulate(curve)
    drawdown = curve / running_peak - 1
    return drawdown.min()
def var95(r):
    """Return the 5th percentile of ``r`` (reported by callers as "VaR 95%").

    Args:
        r: Array-like of returns.

    Returns:
        ``np.percentile(r, 5)``; ``nan`` when ``r`` is empty, since a
        percentile of no observations is undefined.
    """
    returns = np.asarray(r, dtype=float)
    if returns.size == 0:
        # Make the empty case explicit instead of relying on NumPy's
        # behaviour for zero-length input.
        return float("nan")
    return np.percentile(returns, 5)
def kelly(r):
    """Return the mean-over-variance ratio of ``r`` (a Kelly-style fraction).

    Args:
        r: Array-like of returns.

    Returns:
        ``mean(r) / var(r)`` using the population variance (NumPy's default
        ``ddof=0``), or ``0.0`` when ``r`` is empty or has zero variance.
    """
    returns = np.asarray(r, dtype=float)
    if returns.size == 0:
        # Without this guard NumPy emits warnings and the result is NaN.
        return 0.0
    variance = np.var(returns)
    if variance == 0:
        return 0.0
    return np.mean(returns) / variance
# def one_day_future(df, conditions, transition):
# # First, generate predictions for all historical data using the best model
# X_full, y_full = make_sequences(
# df,
# feature_cols,
# "target",
# best_seq_len_multiclass
# )
# X_full_flat = X_full.reshape(X_full.shape[0], -1)
# pred_full = selected_model['model'].predict(X_full_flat)
# pred_full = pred_full.astype(int).ravel()
# # Create tmp_current dataframe using predicted bins and actual historical returns
# tmp_current = pd.DataFrame({
# "pred": pred_full,
# "ret": df_processed["target_ret"].iloc[best_seq_len_multiclass:].values
# })
# # Calculate rolling metrics on tmp_current
# for val in range(N_QUANTILES):
# conditional_ret_current = tmp_current['ret'].where(tmp_current['pred'] == val)
# tmp_current[f"rolling_ret_{TREND_WINDOW}_mean_pred_{val}"] = conditional_ret_current.rolling(window=TREND_WINDOW, min_periods=1).mean()
# # Calculate the trend signal for each day in tmp_current
# trend_signal_current = []
# for idx, row in tmp_current.fillna(0).iterrows():
# res = 0
# for n in range(N_QUANTILES):
# res += row[f"rolling_ret_{TREND_WINDOW}_mean_pred_{n}"]
# trend_signal_current.append(np.sign(res))
# # The trend signal for tomorrow is the last calculated signal
# trend_signal_for_tomorrow = trend_signal_current[-1]
# signal_forecast = np.where(
# conds_forecast, 1, 0 # 1 for bullish (trade), 0 for neutral/bearish (no trade)
# )
# # US market holidays (NYSE close)
# us_cal = CustomBusinessDay(calendar=USFederalHolidayCalendar())
# # print("Signal:", signal_forecast)
# # print("Current Date", df.index[-1])
# # print("Next Date", df.index[-1] + us_cal)#df.index[-1] + pd.Timedelta(days=1))
# # # print("Last Return Bin", last_ret_bin)
# # # print("Last Volatility Bin", last_vol_bin)
# # print(f"Current State (based on last available data): {current_state}")
# # print(f"Predicted Return Bin for the Next Day: {next_day_prediction}")
# ret, vol = current_state.split("_")
# return {
# "current_date": df.index[-1],
# "forecast_date": df.index[-1] + us_cal,
# "current_ret_state": ret,
# "current_vol_state": vol,
# "forecast_state": next_day_prediction,
# "signal": signal_forecast,
# }
def get_trend_signal(pred, ret_test, q, window=15):
    """Derive a per-row trend sign from rolling returns grouped by predicted bin.

    For every bin ``0 .. q-1`` the returns of rows predicted as that bin are
    averaged over a rolling window (rows predicted as another bin are ignored;
    a window without any matching row contributes 0). The per-bin averages
    are summed and the sign of the sum is the trend signal for that row.

    Args:
        pred: Predicted bin label per row (same length as ``ret_test``).
        ret_test: Return per row.
        q: Number of bins; labels ``0 .. q-1`` are considered.
        window: Rolling window length in rows. The first ``window`` entries of
            the result are forced to 0 as a warm-up period.

    Returns:
        A float NumPy array with one entry per input row, each -1, 0 or 1.
        The output always has the same length as the input, even when the
        input is shorter than ``window`` (it is then all zeros).
    """
    tmp = pd.DataFrame({"pred": pred, "ret": ret_test})

    # Accumulate the per-bin rolling means in bin order, mirroring a plain
    # left-to-right sum over the bins.
    combined = np.zeros(len(tmp), dtype=float)
    for val in range(q):
        # Keep 'ret' only where the prediction equals this bin, NaN elsewhere,
        # so the rolling mean only sees returns attributed to the bin.
        conditional_ret = tmp["ret"].where(tmp["pred"] == val)
        rolling_mean = conditional_ret.rolling(
            window=window, min_periods=1
        ).mean()
        combined = combined + rolling_mean.fillna(0).to_numpy()

    trend_signal = np.sign(combined)
    # Warm-up: slice assignment on an array never changes its length, unlike
    # the same assignment on a list, which would grow a short input.
    trend_signal[:window] = 0
    return trend_signal
def _summarize_performance(returns, equity):
    """Bundle the headline statistics for one return stream and its equity curve."""
    return {
        "Final Equity": equity[-1],
        "Sharpe": sharpe(returns),
        "CAGR": cagr(equity),
        "MaxDD": maxdd(equity),
        "VaR 95%": var95(returns),
        "Kelly": kelly(returns),
    }


def _save_equity_chart(index, curves):
    """Plot the labelled equity curves against ``index`` and return the PNG path."""
    chart_path = f"/tmp/{uuid.uuid4()}_chart.png"
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for label, curve in curves.items():
            ax.plot(index, curve, label=label)
        ax.legend()
        ax.grid(True)
        fig.savefig(chart_path, format="png", dpi=100, bbox_inches="tight")
    finally:
        # pyplot keeps every open figure alive until it is closed, so release
        # it here to avoid accumulating figures across repeated calls.
        plt.close(fig)
    return chart_path


def forecast(
    test_df,
    q,
    selected_seq,
    selected_model,
    fee=0.003,
    leverage=2,
    include_short=False,
    borrowing_rate=0.000208,
):
    """Backtest a long-only signal built from bin predictions and chart it.

    A position is taken on rows where the predicted bin is above
    ``q - TOP_BINS`` and the trend signal from ``get_trend_signal`` is
    non-negative. Three return streams are evaluated: the leveraged strategy
    without costs, the same strategy net of trading and borrowing costs, and
    buy & hold. Equity curves are built as ``exp(cumsum(returns))``.

    Args:
        test_df: DataFrame with a ``"target_ret"`` column; rows from
            ``selected_seq`` onwards are used.
        q: Number of prediction bins.
        selected_seq: Number of leading rows to skip (presumably the model's
            sequence length — confirm with the caller).
        selected_model: Mapping whose ``"pred"`` entry holds one predicted bin
            per evaluated row. No other key is read here.
        fee: Cost charged per unit of change in the leveraged position.
        leverage: Multiplier applied to the position.
        include_short: Currently unused; the signal is long-only (1 or 0).
        borrowing_rate: Per-row cost charged on the part of the leveraged
            position that exceeds 1 (presumably a daily rate — confirm).

    Returns:
        Dict with metric dicts under ``"STRATEGY"``, ``"STRATEGY WITH FEE"``
        and ``"BUY HOLD"``, plus ``"leverage"``, ``"fee"`` and ``"chart_path"``
        (path of the saved PNG).
    """
    target = test_df["target_ret"].iloc[selected_seq:]
    ret_test = target.values
    ret_idx = target.index

    # Coerce to an array so the element-wise comparison below also works when
    # the predictions arrive as a plain list.
    pred = np.asarray(selected_model["pred"])

    # NOTE(review): the trend signal for a row is computed from a window that
    # includes that row's own return, which is then traded on the same row.
    # If "target_ret" is the return being predicted, this may introduce
    # look-ahead bias — confirm before trusting the backtest numbers.
    trend_signal = get_trend_signal(pred, ret_test, q)

    # NOTE(review): with a strict ">" only bins above q - TOP_BINS qualify
    # (e.g. just bin 2 when q == 3), i.e. TOP_BINS - 1 bins. Left unchanged.
    TOP_BINS = 2
    signal = np.where((pred > q - TOP_BINS) & (trend_signal >= 0), 1, 0)

    # Gross leveraged return, before any costs.
    levered_signal = leverage * signal
    strategy_ret = levered_signal * ret_test

    # Trading cost on every change of the leveraged position; the position
    # before the first row is taken to be flat.
    prev_pos = np.concatenate(([0], levered_signal[:-1]))
    trade_cost = fee * np.abs(levered_signal - prev_pos)

    # Borrowing cost on the notional above 1x.
    borrowed_notional = np.maximum(0, np.abs(levered_signal) - 1)
    interest_cost = borrowing_rate * borrowed_notional

    strategy_ret_with_fee = strategy_ret - trade_cost - interest_cost

    equity = np.exp(np.cumsum(strategy_ret))
    equity_with_fee = np.exp(np.cumsum(strategy_ret_with_fee))
    buy_hold = np.exp(np.cumsum(ret_test))

    # Compute the metrics before touching the filesystem so that a failure
    # here does not leave a stray chart file behind.
    strategy_stats = _summarize_performance(strategy_ret, equity)
    net_stats = _summarize_performance(strategy_ret_with_fee, equity_with_fee)
    buy_hold_stats = _summarize_performance(ret_test, buy_hold)

    chart_path = _save_equity_chart(
        ret_idx,
        {
            "Strategy": equity,
            "Strategy with fee": equity_with_fee,
            "Buy & Hold": buy_hold,
        },
    )

    return {
        "STRATEGY": strategy_stats,
        "STRATEGY WITH FEE": net_stats,
        "BUY HOLD": buy_hold_stats,
        "leverage": leverage,
        "fee": fee,
        "chart_path": chart_path,
    }
# print("Bullish Trade Signal\t", (signal > 0).sum(), "\nNeutral/Exit Signal\t", (signal < 1).sum())
# print("Total Days", len(signal))
# print("Num Buy/Sell", (np.diff(signal) != 0).sum())
# print(np.unique(preds, return_counts=True))