Spaces:
Sleeping
Sleeping
File size: 2,444 Bytes
bcceb77 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | """Multi-task coordinator for aligning direction, magnitude, and volatility outputs."""
import logging
import numpy as np
from src.models.base import PredictionResult
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class MultiTaskCoordinator:
"""Coordinates direction + magnitude + volatility outputs from the ensemble.
Ensures consistency: if direction is 'down', magnitude should be negative;
if direction is 'up', magnitude should be positive. Volatility is always
non-negative.
"""
def coordinate(
self,
meta_prediction: PredictionResult,
conformal_intervals: dict | None = None,
) -> dict:
"""Return final coordinated prediction with all outputs aligned.
Args:
meta_prediction: PredictionResult from the meta-learner
conformal_intervals: dict from ConformalPredictor.predict() or None
Returns:
dict with keys: direction, direction_proba, magnitude, volatility,
confidence, and optionally interval keys from conformal.
"""
direction = np.asarray(meta_prediction.direction, dtype=int)
magnitude = np.asarray(meta_prediction.magnitude, dtype=float).copy()
volatility = np.asarray(meta_prediction.volatility, dtype=float).copy()
confidence = np.asarray(meta_prediction.confidence, dtype=float)
direction_proba = np.asarray(meta_prediction.direction_proba, dtype=float)
# Enforce sign consistency between direction and magnitude
for i in range(len(direction)):
if direction[i] == 1 and magnitude[i] < 0:
magnitude[i] = abs(magnitude[i])
elif direction[i] == -1 and magnitude[i] > 0:
magnitude[i] = -abs(magnitude[i])
# direction == 0 (flat): leave magnitude as-is (should be near zero)
# Volatility must be non-negative
volatility = np.abs(volatility)
result = {
"direction": direction,
"direction_proba": direction_proba,
"magnitude": magnitude,
"volatility": volatility,
"confidence": confidence,
}
# Attach conformal intervals if provided
if conformal_intervals is not None:
for key, value in conformal_intervals.items():
if key != "point_prediction": # avoid duplicating magnitude
result[key] = value
return result
|