WaveLSFromer / results_analysis.py
ducheng678
Initial WaveLSFromer project
093b0a5
Raw
History Blame Contribute Delete
8.23 kB
import os
from pprint import pprint
from typing import Any
import pandas as pd
import numpy as np
from utils.stock_metrics import (
LogPctProfitDirection,
LogPctProfitTanhV1,
LogPctProfitTanhV2,
apply_threshold_metric,
pct_direction,
)
from utils.tools import dotdict
def open_results(
log_dir: str, args: dotdict, df: pd.DataFrame
) -> dict[str, dict[str, Any]]:
"""Function to open an experiment and return its tpd_dict"""
tpd_dict_tuple: dict[str, tuple[Any, Any, Any]] = {}
for data_group in ["train", "val", "test"]:
device = 0
while True: # Device Loop
preds_path = os.path.join(
log_dir, f"results/pred_{data_group}_{device}.npy"
)
trues_path = os.path.join(
log_dir, f"results/true_{data_group}_{device}.npy"
)
dates_path = os.path.join(
log_dir, f"results/date_{data_group}_{device}.npy"
)
if (
os.path.exists(preds_path)
and os.path.exists(trues_path)
and os.path.exists(dates_path)
):
dp = [
np.load(trues_path)[:, 0, 0],
np.load(preds_path)[:, 0, 0],
np.load(dates_path)[:, 0],
]
tpd_dict_tuple[data_group] = (
dp
if data_group not in tpd_dict_tuple
else [
np.append(tpdfi, dpi, axis=0)
for tpdfi, dpi in zip(tpd_dict_tuple[data_group], dp)
]
)
s = np.argsort(tpd_dict_tuple[data_group][2], axis=None)
tpd_dict_tuple[data_group] = list(
map(lambda x: x[s], tpd_dict_tuple[data_group])
)
tpd_dict_tuple[data_group][2] = pd.DatetimeIndex(
tpd_dict_tuple[data_group][2], tz="UTC"
)
# Override trues with df target data to get original numerical precision
if (
not ("mse" in args.loss and not args.inverse_output)
and df is not None
):
print("OVERRIDING trues with df target")
df_data_group = df.loc[tpd_dict_tuple[data_group][2]]
t = args.target.split("_")
df_target = df_data_group[t[0]][t[1]].to_numpy()
tpd_dict_tuple[data_group][0] = df_target
else:
# Done searching for devices
break
device += 1
tpd_dict: dict[str, dict[str, Any]] = {}
for data_group in tpd_dict_tuple:
tpd_dict[data_group] = {
"trues": tpd_dict_tuple[data_group][0],
"preds": tpd_dict_tuple[data_group][1],
"dates": tpd_dict_tuple[data_group][2],
}
return tpd_dict
def get_metrics(
args: dotdict | None, pred: np.ndarray, true: np.ndarray, thresh: float = 0.0
) -> dict:
"""Function to return metrics based on outputs and labels"""
# Filter by a threshold
pred_f, true_f = apply_threshold_metric(pred, true, thresh)
# df_f = df.loc[date[np.abs(pred) >= thresh]]
# Percent direction correct, ie up or down
pct_dir_correct = pct_direction(pred_f, true_f)
# Percent profit all in
pct_profit_dir = LogPctProfitDirection.metric(pred_f, true_f, short_filter=None)
pct_profit_dir_nshort = LogPctProfitDirection.metric(
pred_f, true_f, short_filter="ns"
)
pct_profit_dir_oshort = LogPctProfitDirection.metric(
pred_f, true_f, short_filter="os"
)
# Percent profit with tanh partial purchase
pct_profit_tanhv1 = LogPctProfitTanhV1.metric(pred_f, true_f, short_filter=None)[0].cpu().numpy().tolist()
pct_profit_tanhv1_nshort = LogPctProfitTanhV1.metric(
pred_f, true_f, short_filter="ns"
)[0].cpu().numpy().tolist()
pct_profit_tanhv1_oshort = LogPctProfitTanhV1.metric(
pred_f, true_f, short_filter="os"
)[0].cpu().numpy().tolist()
pct_profit_tanhv2 = LogPctProfitTanhV2.metric(pred_f, true_f, short_filter=None)
pct_profit_tanhv2_nshort = LogPctProfitTanhV2.metric(
pred_f, true_f, short_filter="ns"
)
pct_profit_tanhv2_oshort = LogPctProfitTanhV2.metric(
pred_f, true_f, short_filter="os"
)
# Optimal percent profit
pct_profit_dir_opt = LogPctProfitDirection.metric(true_f, true_f)
# Always 1 direction
avg_pct_profit_always_short = np.power(
LogPctProfitDirection.metric(-np.ones(pred_f.shape), true_f, short_filter=None),
(1 / len(pred_f)),
)
avg_pct_profit_always_buy = np.power(
LogPctProfitDirection.metric(np.ones(pred_f.shape), true_f, short_filter=None),
(1 / len(pred_f)),
)
# Return metrics
metrics = {
"avg_pct_profit_tanhv1": np.power(pct_profit_tanhv1, (1 / len(pred_f))),
"pct_profit_dir": pct_profit_dir,
"pct_profit_dir_nshort": pct_profit_dir_nshort,
"pct_profit_dir_oshort": pct_profit_dir_oshort,
"pct_profit_tanhv1": pct_profit_tanhv1,
"pct_profit_tanhv1_nshort": pct_profit_tanhv1_nshort,
"pct_profit_tanhv1_oshort": pct_profit_tanhv1_oshort,
"pct_profit_tanhv2": pct_profit_tanhv2,
"pct_profit_tanhv2_nshort": pct_profit_tanhv2_nshort,
"pct_profit_tanhv2_oshort": pct_profit_tanhv2_oshort,
"pct_excluded": (len(pred) - len(pred_f)) / len(pred),
"pct_excluded_nshort": (len(pred) - len(pred_f[pred_f > 0])) / len(pred),
"pct_excluded_oshort": (len(pred) - len(pred_f[pred_f < 0])) / len(pred),
"pct_dir_correct": pct_dir_correct,
"pct_profit_dir_opt": pct_profit_dir_opt,
"avg_pct_profit_always_short": avg_pct_profit_always_short,
"avg_pct_profit_always_buy": avg_pct_profit_always_buy,
}
return metrics
def get_tuned_metrics(args: dotdict, tpd_dict: dict):
# df = read_data(os.path.join(args.root_path, args.data_path))
# Get the percentile to check thresh until
what_percentile = 50
percentile_value = 0.0
for data_group in ["train"]: # tpd_dict:
preds = tpd_dict[data_group]["preds"]
percentile_value += np.percentile(np.abs(preds), what_percentile)
# percentile_value /= len(tpd_dict)
print(f"{what_percentile}'th percentile: {percentile_value}")
# Safety check
ticker, field = args.target.split("_")
assert field == "pctchange" or field == "logpctchange"
# Tune threshhold based off of train's metric we care about
if args.loss == "stock_tanhv1":
tune_metric = "pct_profit_tanhv1"
elif args.loss == "stock_tanhv2":
tune_metric = "pct_profit_tanhv2"
elif args.loss == "stock_tanhv4":
tune_metric = "pct_profit_tanhv1" # this is on purpose
elif args.loss == "stock_tanh":
tune_metric = "pct_profit_tanh"
else:
tune_metric = "pct_profit_dir"
max_tracker = (0, 0)
# Tracks results
tracker = {}
# Try all of the thresholds
for thresh in np.linspace(0, percentile_value, 501):
tracker[thresh] = {}
for data_group in tpd_dict:
true = tpd_dict[data_group]["trues"]
pred = tpd_dict[data_group]["preds"]
date = tpd_dict[data_group]["dates"]
tracker[thresh][data_group] = get_metrics(args, pred, true, thresh)
# Tune threshhold based off of train's metric we care about
tune_value = tracker[thresh][data_group][tune_metric]
if tune_value > max_tracker[0] and data_group == "train":
max_tracker = (tune_value, thresh)
best_thresh = max_tracker[1]
print("zeros thresh:")
for data_group in tracker[best_thresh]:
print(data_group, end="\t")
pprint(tracker[0.0][data_group], indent=3)
print("\n\n")
print("best thresh:", best_thresh)
for data_group in tracker[best_thresh]:
print(data_group, end="\t")
pprint(tracker[best_thresh][data_group], indent=3)
return best_thresh, tracker[best_thresh], tracker[0.0]