# Source: oracle/data/context_targets.py
# Uploaded by zirobtc via huggingface_hub (revision 7064310).
from __future__ import annotations

import math
from typing import Dict, List, Sequence
# Default cut points applied to a horizon's return value by
# ``classify_movement_return``. NOTE(review): the names suggest fractional
# returns (0.50 ~ +50%, 3.00 ~ +300%) -- confirm against the code that
# produces the horizon returns.
MOVEMENT_STRONG_DOWN_THRESHOLD = -0.40
MOVEMENT_DOWN_THRESHOLD = -0.30
MOVEMENT_PUMP_50_THRESHOLD = 0.50
MOVEMENT_PUMP_100_THRESHOLD = 1.00
MOVEMENT_PUMP_300_THRESHOLD = 3.00

# Movement classes ordered from most negative to most positive band (the
# same order in which ``classify_movement_return`` tests them). A class's
# integer id is its position in this list.
MOVEMENT_CLASS_NAMES = [
    "strong_down",
    "down",
    "flat",
    "up",
    "strong_up",
    "extreme_up",
]

# Forward (name -> id) and reverse (id -> name) lookups over the list above.
MOVEMENT_CLASS_TO_ID = dict(zip(MOVEMENT_CLASS_NAMES, range(len(MOVEMENT_CLASS_NAMES))))
MOVEMENT_ID_TO_CLASS = dict(enumerate(MOVEMENT_CLASS_NAMES))

# Threshold set used when no override is supplied; only these keys are
# honoured as overrides by ``classify_movement_return``.
DEFAULT_MOVEMENT_LABEL_CONFIG = dict(
    strong_down_threshold=MOVEMENT_STRONG_DOWN_THRESHOLD,
    down_threshold=MOVEMENT_DOWN_THRESHOLD,
    pump_50_threshold=MOVEMENT_PUMP_50_THRESHOLD,
    pump_100_threshold=MOVEMENT_PUMP_100_THRESHOLD,
    pump_300_threshold=MOVEMENT_PUMP_300_THRESHOLD,
)
def classify_movement_return(
    return_value: float,
    movement_label_config: Dict[str, float] | None = None,
) -> int:
    """Map a single return value onto a movement class id.

    Thresholds start from ``DEFAULT_MOVEMENT_LABEL_CONFIG``. Any entry of
    ``movement_label_config`` whose key is one of the default threshold names
    overrides that default (after ``float()`` conversion); other keys are
    ignored. The strong-down cut is clamped so it never lies above the down
    cut.

    Bands are tested in this order; the first match wins:

    ==========  ===========================================
    strong_down ``return_value <= strong_down_threshold``
    down        ``return_value <  down_threshold``
    flat        ``return_value <  pump_50_threshold``
    up          ``return_value <  pump_100_threshold``
    strong_up   ``return_value <  pump_300_threshold``
    extreme_up  everything else
    ==========  ===========================================

    NaN compares false against every threshold, so a NaN ``return_value``
    falls through to ``extreme_up``; callers should filter NaNs first if
    that is not wanted.

    Args:
        return_value: The return to classify.
        movement_label_config: Optional per-threshold overrides.

    Returns:
        The id from ``MOVEMENT_CLASS_TO_ID`` of the matching band.
    """
    thresholds = dict(DEFAULT_MOVEMENT_LABEL_CONFIG)
    for key, raw in (movement_label_config or {}).items():
        if key in thresholds:
            thresholds[key] = float(raw)

    down_cut = thresholds["down_threshold"]
    # Clamp so a misconfigured strong-down cut cannot swallow the "down" band.
    strong_down_cut = min(thresholds["strong_down_threshold"], down_cut)
    if return_value <= strong_down_cut:
        return MOVEMENT_CLASS_TO_ID["strong_down"]
    if return_value < down_cut:
        return MOVEMENT_CLASS_TO_ID["down"]

    # Upper bands: each is bounded above (exclusively) by its threshold.
    ascending_bands = (
        ("pump_50_threshold", "flat"),
        ("pump_100_threshold", "up"),
        ("pump_300_threshold", "strong_up"),
    )
    for threshold_key, band_name in ascending_bands:
        if return_value < thresholds[threshold_key]:
            return MOVEMENT_CLASS_TO_ID[band_name]
    return MOVEMENT_CLASS_TO_ID["extreme_up"]
def derive_movement_targets(
    horizon_returns: Sequence[float],
    horizon_mask: Sequence[float],
    movement_label_config: Dict[str, float] | None = None,
) -> Dict[str, List[int]]:
    """Turn per-horizon returns into movement-class training targets.

    Returns and mask values are paired positionally. If the two sequences
    differ in length, the trailing entries of the longer one are ignored.

    A horizon is *masked* when its mask value is not strictly positive
    (NaN included) or when its return is NaN. A masked horizon gets:
    - target: the ``flat`` id, as a placeholder
    - mask: ``0``
    - name: ``"masked"``

    NaN returns are masked because NaN compares false against every
    threshold in ``classify_movement_return``. Left unmasked, they would be
    labelled ``extreme_up``. Infinite returns are ordered values and are
    classified normally (``-inf`` -> strong_down, ``+inf`` -> extreme_up).

    Every other horizon is classified with ``classify_movement_return`` and
    gets mask ``1``.

    Args:
        horizon_returns: Return value per horizon.
        horizon_mask: Validity value per horizon; ``> 0`` means usable.
        movement_label_config: Optional threshold overrides forwarded to
            ``classify_movement_return``.

    Returns:
        Dict of three equal-length lists:
        - ``"movement_class_targets"``: class ids.
        - ``"movement_class_mask"``: 0/1 ints.
        - ``"movement_class_names"``: class-name strings, or ``"masked"``.

        NOTE: the names list holds ``str``, despite the ``List[int]`` value
        type in the return annotation.
    """
    flat_id = MOVEMENT_CLASS_TO_ID["flat"]
    class_targets: List[int] = []
    class_mask: List[int] = []
    class_names: List[str] = []
    for raw_return, raw_mask in zip(horizon_returns, horizon_mask):
        # ``x > 0`` (instead of testing ``x <= 0`` for masking) makes a NaN
        # mask count as masked, since every comparison with NaN is False.
        is_usable = float(raw_mask) > 0
        # Convert the return only for usable horizons, so a masked horizon's
        # return is never passed through float().
        return_value = float(raw_return) if is_usable else math.nan
        if not is_usable or math.isnan(return_value):
            class_targets.append(flat_id)
            class_mask.append(0)
            class_names.append("masked")
            continue
        class_id = classify_movement_return(
            return_value,
            movement_label_config=movement_label_config,
        )
        class_targets.append(class_id)
        class_mask.append(1)
        class_names.append(MOVEMENT_ID_TO_CLASS[class_id])
    return {
        "movement_class_targets": class_targets,
        "movement_class_mask": class_mask,
        "movement_class_names": class_names,
    }
def compute_movement_label_config(valid_returns: Sequence[float]) -> Dict[str, float]:
    """Return the movement-label thresholds to use for a set of returns.

    The thresholds are currently fixed. ``valid_returns`` is accepted but
    ignored; presumably it is a hook for data-driven thresholds (confirm
    with the callers). The result is a fresh shallow copy of
    ``DEFAULT_MOVEMENT_LABEL_CONFIG``, so mutating it does not alter the
    module-level defaults.
    """
    _ = valid_returns  # intentionally unused
    return DEFAULT_MOVEMENT_LABEL_CONFIG.copy()