File size: 3,070 Bytes
7064310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations

import math
from typing import Dict, List, Sequence


# Default cut-offs separating the movement classes.
# NOTE(review): the PUMP_50 / PUMP_100 / PUMP_300 names paired with 0.5 / 1.0 /
# 3.0 suggest returns are expressed as fractions (0.5 == +50%) -- confirm.
MOVEMENT_STRONG_DOWN_THRESHOLD = -0.4
MOVEMENT_DOWN_THRESHOLD = -0.3
MOVEMENT_PUMP_50_THRESHOLD = 0.5
MOVEMENT_PUMP_100_THRESHOLD = 1.0
MOVEMENT_PUMP_300_THRESHOLD = 3.0

# Class labels, ordered from the lowest return band to the highest; a label's
# position in this list is its integer class id.
MOVEMENT_CLASS_NAMES = [
    "strong_down",
    "down",
    "flat",
    "up",
    "strong_up",
    "extreme_up",
]
# Forward (label -> id) and reverse (id -> label) lookups built from the list.
MOVEMENT_CLASS_TO_ID = dict(
    zip(MOVEMENT_CLASS_NAMES, range(len(MOVEMENT_CLASS_NAMES)))
)
MOVEMENT_ID_TO_CLASS = dict(enumerate(MOVEMENT_CLASS_NAMES))

# Threshold configuration used when a caller supplies no overrides.
DEFAULT_MOVEMENT_LABEL_CONFIG = dict(
    strong_down_threshold=MOVEMENT_STRONG_DOWN_THRESHOLD,
    down_threshold=MOVEMENT_DOWN_THRESHOLD,
    pump_50_threshold=MOVEMENT_PUMP_50_THRESHOLD,
    pump_100_threshold=MOVEMENT_PUMP_100_THRESHOLD,
    pump_300_threshold=MOVEMENT_PUMP_300_THRESHOLD,
)


def classify_movement_return(
    return_value: float,
    movement_label_config: Dict[str, float] | None = None,
) -> int:
    """Return the movement class id for a single return value.

    The thresholds start as a copy of ``DEFAULT_MOVEMENT_LABEL_CONFIG``; any
    key of ``movement_label_config`` that also exists in the defaults
    replaces the default after being coerced with ``float``. Unknown keys
    are ignored.

    Bands are tested from lowest to highest and the first match wins:

    * ``return_value <= strong-down bound`` -> ``"strong_down"`` (inclusive)
    * ``return_value < down_threshold`` -> ``"down"``
    * ``return_value < pump_50_threshold`` -> ``"flat"``
    * ``return_value < pump_100_threshold`` -> ``"up"``
    * ``return_value < pump_300_threshold`` -> ``"strong_up"``
    * anything else -> ``"extreme_up"``

    The strong-down bound is the smaller of ``strong_down_threshold`` and
    ``down_threshold``. No ordering is enforced among the pump thresholds.

    Note:
        A NaN ``return_value`` compares false against every bound and
        therefore falls through to ``"extreme_up"``.

    Args:
        return_value: The return to classify.
        movement_label_config: Optional threshold overrides.

    Returns:
        The integer id from ``MOVEMENT_CLASS_TO_ID`` for the matched class.
    """
    thresholds = {**DEFAULT_MOVEMENT_LABEL_CONFIG}
    for key, value in (movement_label_config or {}).items():
        if key in thresholds:
            thresholds[key] = float(value)

    # The strong-down bound is clamped so it never sits above the down bound.
    strong_down_bound = min(
        thresholds["strong_down_threshold"],
        thresholds["down_threshold"],
    )
    if return_value <= strong_down_bound:
        return MOVEMENT_CLASS_TO_ID["strong_down"]

    # Exclusive upper bounds in ascending band order.
    upper_bounds = (
        ("down", thresholds["down_threshold"]),
        ("flat", thresholds["pump_50_threshold"]),
        ("up", thresholds["pump_100_threshold"]),
        ("strong_up", thresholds["pump_300_threshold"]),
    )
    for label, bound in upper_bounds:
        if return_value < bound:
            return MOVEMENT_CLASS_TO_ID[label]

    return MOVEMENT_CLASS_TO_ID["extreme_up"]


def derive_movement_targets(
    horizon_returns: Sequence[float],
    horizon_mask: Sequence[float],
    movement_label_config: Dict[str, float] | None = None,
) -> Dict[str, List[int]]:
    """Build per-horizon movement class targets, masks and class names.

    The two inputs are walked in lockstep and truncated to the shorter one.
    A position is treated as masked when its mask value is ``<= 0`` or when
    its return is NaN. A NaN return would otherwise compare false against
    every threshold in ``classify_movement_return`` and be emitted as a
    valid ``"extreme_up"`` target.

    Masked positions get the ``"flat"`` class id as a placeholder target, a
    mask of ``0`` and the name ``"masked"``. All other positions get the id
    returned by ``classify_movement_return``, a mask of ``1`` and the
    matching name from ``MOVEMENT_ID_TO_CLASS``. Infinite returns are still
    classified (``+inf`` -> ``"extreme_up"``, ``-inf`` -> ``"strong_down"``).

    Args:
        horizon_returns: Return value for each horizon.
        horizon_mask: Validity indicator for each horizon; values ``<= 0``
            mark the horizon as unusable.
        movement_label_config: Optional threshold overrides forwarded to
            ``classify_movement_return``.

    Returns:
        A dict with three equally long lists under the keys
        ``"movement_class_targets"``, ``"movement_class_mask"`` and
        ``"movement_class_names"``.
        NOTE(review): the last list holds ``str`` values, which is wider
        than the declared ``List[int]`` value type.
    """
    class_targets: List[int] = []
    class_mask: List[int] = []
    class_names: List[str] = []

    placeholder_id = MOVEMENT_CLASS_TO_ID["flat"]

    # zip() stops at the shorter input, matching the previous
    # range(min(len(returns), len(mask))) truncation.
    for raw_return, raw_mask in zip(horizon_returns, horizon_mask):
        if float(raw_mask) <= 0:
            # Do not convert the return at masked positions: the slot may
            # hold a placeholder that float() cannot handle.
            value = None
        else:
            value = float(raw_return)

        if value is None or math.isnan(value):
            class_targets.append(placeholder_id)
            class_mask.append(0)
            class_names.append("masked")
            continue

        class_id = classify_movement_return(
            value,
            movement_label_config=movement_label_config,
        )
        class_targets.append(class_id)
        class_mask.append(1)
        class_names.append(MOVEMENT_ID_TO_CLASS[class_id])

    return {
        "movement_class_targets": class_targets,
        "movement_class_mask": class_mask,
        "movement_class_names": class_names,
    }


def compute_movement_label_config(valid_returns: Sequence[float]) -> Dict[str, float]:
    """Return a fresh copy of ``DEFAULT_MOVEMENT_LABEL_CONFIG``.

    ``valid_returns`` is accepted but never read: the result does not depend
    on the data passed in. A new dict is returned on every call, so callers
    can mutate it without touching the module-level defaults.

    Args:
        valid_returns: Unused.

    Returns:
        A copy of the default threshold configuration.
    """
    del valid_returns  # explicitly discard the unused argument
    return DEFAULT_MOVEMENT_LABEL_CONFIG.copy()