# stock_v2/utils/process.py
# Hugging Face file-view residue kept for provenance:
#   author: jonathanjordan21 | commit: "init" (139173f) | size: 2.45 kB
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
def find_best_model_by_metric(models_dict, metric_name, is_smaller=True):
    """Select the entry of ``models_dict`` with the best ``metric_name`` value.

    Every entry is printed as it is inspected. Ties keep the entry that comes
    first in iteration order, and NaN metric values are never selected.

    Args:
        models_dict: Mapping from a model key (printed as ``SEQ_LEN``) to an
            object indexable by ``metric_name``.
        metric_name: Key of the metric to compare within each entry.
        is_smaller: If true the smallest value wins (loss-like metric);
            otherwise the largest value wins (score-like metric).

    Returns:
        A tuple ``(best_key, best_value, models_dict[best_key])``.

    Raises:
        KeyError: If an entry is a dict without a ``metric_name`` key.
        ValueError: If ``models_dict`` is empty or every metric value is NaN,
            i.e. there is nothing that can be ranked.
    """
    # Track selection with an explicit flag rather than a +/-inf sentinel:
    # with a sentinel, a metric that is itself +/-inf (e.g. a diverged loss)
    # can never win the strict comparison and no model would be selected.
    # A flag (instead of ``best_seq_len is None``) also keeps ``None`` usable
    # as a regular dictionary key.
    found = False
    best_seq_len = None
    best_value = None

    print(f"Searching for best model based on validation '{metric_name}' loss:\n")

    for k, v in models_dict.items():
        # current_value = v['best_score']['validation'][metric_name]
        current_value = v[metric_name]
        print(f"Model (SEQ_LEN={k}): Validation {metric_name} = {current_value}")

        # NaN is the only value that is unequal to itself; it cannot be
        # ordered against other values, so it is never a candidate.
        if current_value != current_value:
            continue

        if not found:
            is_better = True
        elif is_smaller:
            is_better = current_value < best_value
        else:
            is_better = current_value > best_value

        if is_better:
            found = True
            best_value = current_value
            best_seq_len = k

    if not found:
        raise ValueError(
            f"No model could be selected for metric '{metric_name}': "
            "models_dict is empty or every metric value is NaN."
        )

    return best_seq_len, best_value, models_dict[best_seq_len]
# def create_majority_pred(train, transition):
# majority_pred = transition.idxmax(axis=1)
# global_majority = train["target"].mode()[0]
# preds = []
# for s in train["state"]:
# if s in majority_pred.index:
# preds.append(int(majority_pred.loc[s]))
# else:
# preds.append(int(global_majority))
# acc = accuracy_score(
# train["target"],
# preds
# )
# # return {"markov_acc": acc, "random_guess_acc":1/q}
# return acc, np.array(preds)
# def evaluate_markov_chain(train, test, transition, q):
# train_res, _ = create_majority_pred(train, transition)
# test_res, _ = create_majority_pred(test, transition)
# return {
# "train_acc": train_res,
# "test_acc": test_res,
# "random_guess_acc": 1/q,
# }
# def get_accuracy_detail(data, preds):
# label_acc = []
# # preds = np.array(preds)
# for label in sorted(np.unique(data["ret_bin"])):
# mask = (data["target"] == label)
# acc = (preds[mask] == data["target"].loc[mask]).mean()
# label_acc.append({
# "label": label,
# "count": mask.sum(),
# "accuracy": acc
# })
# label_acc = pd.DataFrame(label_acc)
# return label_acc
# # return label_acc["accuracy"].sum()