"""Multi-task coordinator for aligning direction, magnitude, and volatility outputs."""

import logging

import numpy as np

from src.models.base import PredictionResult

logger = logging.getLogger(__name__)


class MultiTaskCoordinator:
    """Coordinates direction + magnitude + volatility outputs from the ensemble.

    Ensures consistency: if direction is 'down' (-1), magnitude should be
    negative; if direction is 'up' (1), magnitude should be positive.
    Direction 0 ('flat') leaves magnitude untouched. Volatility is always
    non-negative.
    """

    def coordinate(
        self,
        meta_prediction: PredictionResult,
        conformal_intervals: dict | None = None,
    ) -> dict:
        """Return final coordinated prediction with all outputs aligned.

        Args:
            meta_prediction: PredictionResult from the meta-learner. Its
                ``direction``, ``magnitude``, ``volatility``, ``confidence``
                and ``direction_proba`` attributes are converted to NumPy
                arrays. Scalars (0-d) and arrays are both accepted.
            conformal_intervals: dict from ConformalPredictor.predict() or None.
                Its entries are attached to the result, except
                ``"point_prediction"`` and any key that would overwrite one of
                the coordinated outputs.

        Returns:
            dict with keys: direction, direction_proba, magnitude, volatility,
            confidence, and optionally interval keys from conformal. All array
            values are fresh copies, so mutating the result does not affect
            ``meta_prediction``.

        Raises:
            ValueError: if ``direction`` and ``magnitude`` do not have the same
                shape, because elementwise sign alignment would be ill-defined.
        """
        # np.array copies by default, so the returned arrays never alias the
        # caller's PredictionResult buffers.
        direction = np.array(meta_prediction.direction, dtype=int)
        magnitude = np.array(meta_prediction.magnitude, dtype=float)
        volatility = np.array(meta_prediction.volatility, dtype=float)
        confidence = np.array(meta_prediction.confidence, dtype=float)
        direction_proba = np.array(meta_prediction.direction_proba, dtype=float)

        if direction.shape != magnitude.shape:
            raise ValueError(
                "direction and magnitude must have the same shape, got "
                f"{direction.shape} and {magnitude.shape}"
            )

        # Enforce sign consistency between direction and magnitude. Flip only
        # the entries whose sign contradicts the predicted direction; flat
        # (direction == 0) entries are left as-is (should be near zero).
        # Vectorised so it also works for 0-d (single-sample) predictions.
        contradicts = ((direction == 1) & (magnitude < 0)) | (
            (direction == -1) & (magnitude > 0)
        )
        magnitude = np.where(contradicts, -magnitude, magnitude)

        # Volatility must be non-negative
        volatility = np.abs(volatility)

        result = {
            "direction": direction,
            "direction_proba": direction_proba,
            "magnitude": magnitude,
            "volatility": volatility,
            "confidence": confidence,
        }

        # Attach conformal intervals if provided
        if conformal_intervals is not None:
            for key, value in conformal_intervals.items():
                if key == "point_prediction":
                    continue  # avoid duplicating magnitude
                if key in result:
                    # Overwriting a coordinated output would silently undo the
                    # alignment done above, so keep ours and report the clash.
                    logger.warning(
                        "Ignoring conformal key %r: it collides with a "
                        "coordinated output",
                        key,
                    )
                    continue
                result[key] = value

        return result