from __future__ import annotations from typing import Dict, List, Sequence MOVEMENT_STRONG_DOWN_THRESHOLD = -0.40 MOVEMENT_DOWN_THRESHOLD = -0.30 MOVEMENT_PUMP_50_THRESHOLD = 0.50 MOVEMENT_PUMP_100_THRESHOLD = 1.00 MOVEMENT_PUMP_300_THRESHOLD = 3.00 MOVEMENT_CLASS_NAMES = [ "strong_down", "down", "flat", "up", "strong_up", "extreme_up", ] MOVEMENT_CLASS_TO_ID = {name: idx for idx, name in enumerate(MOVEMENT_CLASS_NAMES)} MOVEMENT_ID_TO_CLASS = {idx: name for name, idx in MOVEMENT_CLASS_TO_ID.items()} DEFAULT_MOVEMENT_LABEL_CONFIG = { "strong_down_threshold": MOVEMENT_STRONG_DOWN_THRESHOLD, "down_threshold": MOVEMENT_DOWN_THRESHOLD, "pump_50_threshold": MOVEMENT_PUMP_50_THRESHOLD, "pump_100_threshold": MOVEMENT_PUMP_100_THRESHOLD, "pump_300_threshold": MOVEMENT_PUMP_300_THRESHOLD, } def classify_movement_return( return_value: float, movement_label_config: Dict[str, float] | None = None, ) -> int: cfg = dict(DEFAULT_MOVEMENT_LABEL_CONFIG) if movement_label_config: cfg.update({k: float(v) for k, v in movement_label_config.items() if k in cfg}) strong_down_threshold = min(cfg["strong_down_threshold"], cfg["down_threshold"]) down_threshold = cfg["down_threshold"] pump_50_threshold = cfg["pump_50_threshold"] pump_100_threshold = cfg["pump_100_threshold"] pump_300_threshold = cfg["pump_300_threshold"] if return_value <= strong_down_threshold: return MOVEMENT_CLASS_TO_ID["strong_down"] if return_value < down_threshold: return MOVEMENT_CLASS_TO_ID["down"] if return_value < pump_50_threshold: return MOVEMENT_CLASS_TO_ID["flat"] if return_value < pump_100_threshold: return MOVEMENT_CLASS_TO_ID["up"] if return_value < pump_300_threshold: return MOVEMENT_CLASS_TO_ID["strong_up"] return MOVEMENT_CLASS_TO_ID["extreme_up"] def derive_movement_targets( horizon_returns: Sequence[float], horizon_mask: Sequence[float], movement_label_config: Dict[str, float] | None = None, ) -> Dict[str, List[int]]: class_targets: List[int] = [] class_mask: List[int] = [] class_names: List[str] = [] usable = min(len(horizon_returns), len(horizon_mask)) for idx in range(usable): if float(horizon_mask[idx]) <= 0: class_targets.append(MOVEMENT_CLASS_TO_ID["flat"]) class_mask.append(0) class_names.append("masked") continue class_id = classify_movement_return( float(horizon_returns[idx]), movement_label_config=movement_label_config, ) class_targets.append(class_id) class_mask.append(1) class_names.append(MOVEMENT_ID_TO_CLASS[class_id]) return { "movement_class_targets": class_targets, "movement_class_mask": class_mask, "movement_class_names": class_names, } def compute_movement_label_config(valid_returns: Sequence[float]) -> Dict[str, float]: del valid_returns return dict(DEFAULT_MOVEMENT_LABEL_CONFIG)