File size: 2,444 Bytes
bcceb77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Multi-task coordinator for aligning direction, magnitude, and volatility outputs."""

import logging

import numpy as np

from src.models.base import PredictionResult

logger = logging.getLogger(__name__)


class MultiTaskCoordinator:
    """Coordinate direction, magnitude and volatility outputs from the ensemble.

    Enforces consistency between the heads of a multi-task prediction:

    * where direction is ``1`` (up), a negative magnitude is made positive;
    * where direction is ``-1`` (down), a positive magnitude is made negative;
    * any other direction value (e.g. ``0`` for flat) leaves magnitude as-is;
    * volatility is made non-negative via absolute value.
    """

    def coordinate(
        self,
        meta_prediction: PredictionResult,
        conformal_intervals: dict | None = None,
    ) -> dict:
        """Return the final coordinated prediction with all outputs aligned.

        Args:
            meta_prediction: PredictionResult from the meta-learner. Its
                ``direction`` (as int), ``magnitude``, ``volatility``,
                ``confidence`` and ``direction_proba`` (as float) attributes are
                converted with ``np.asarray``; scalars and arrays are accepted.
            conformal_intervals: dict from ConformalPredictor.predict(), or None.
                Every entry except ``"point_prediction"`` is copied into the
                result unchanged.

        Returns:
            dict with keys ``direction``, ``direction_proba``, ``magnitude``,
            ``volatility``, ``confidence``, plus any interval keys taken from
            ``conformal_intervals``. ``magnitude`` and ``volatility`` are new
            arrays; the caller's inputs are never mutated.

        Raises:
            ValueError: if ``direction`` and ``magnitude`` hold different
                numbers of elements.
        """
        direction = np.asarray(meta_prediction.direction, dtype=int)
        magnitude = np.asarray(meta_prediction.magnitude, dtype=float)
        volatility = np.asarray(meta_prediction.volatility, dtype=float)
        confidence = np.asarray(meta_prediction.confidence, dtype=float)
        direction_proba = np.asarray(meta_prediction.direction_proba, dtype=float)

        result = {
            "direction": direction,
            "direction_proba": direction_proba,
            "magnitude": self._align_magnitude_sign(direction, magnitude),
            # Volatility must be non-negative; np.abs returns a fresh array.
            "volatility": np.abs(volatility),
            "confidence": confidence,
        }

        if conformal_intervals is not None:
            # "point_prediction" would duplicate magnitude, so it is skipped.
            result.update(
                (key, value)
                for key, value in conformal_intervals.items()
                if key != "point_prediction"
            )

        return result

    @staticmethod
    def _align_magnitude_sign(
        direction: np.ndarray, magnitude: np.ndarray
    ) -> np.ndarray:
        """Return a new magnitude array whose signs agree with ``direction``.

        Entries with direction ``1`` and a negative magnitude, or direction
        ``-1`` and a positive magnitude, are negated; all others are kept.

        Raises:
            ValueError: if the two arrays hold different numbers of elements.
        """
        if direction.size != magnitude.size:
            raise ValueError(
                "direction and magnitude must have the same number of elements "
                f"(got {direction.size} and {magnitude.size})"
            )
        # Pair elements positionally (e.g. (n, 1) vs (n,)) rather than letting
        # numpy broadcast the two shapes into an (n, n) result.
        direction = direction.reshape(magnitude.shape)
        flip = ((direction == 1) & (magnitude < 0)) | (
            (direction == -1) & (magnitude > 0)
        )
        # Negate only the disagreeing entries so every other value (including
        # NaN and 0.0) is returned exactly as given.
        return np.where(flip, -magnitude, magnitude)