Spaces:
Sleeping
Sleeping
| """Multi-task coordinator for aligning direction, magnitude, and volatility outputs.""" | |
| import logging | |
| import numpy as np | |
| from src.models.base import PredictionResult | |
| logger = logging.getLogger(__name__) | |
| class MultiTaskCoordinator: | |
| """Coordinates direction + magnitude + volatility outputs from the ensemble. | |
| Ensures consistency: if direction is 'down', magnitude should be negative; | |
| if direction is 'up', magnitude should be positive. Volatility is always | |
| non-negative. | |
| """ | |
| def coordinate( | |
| self, | |
| meta_prediction: PredictionResult, | |
| conformal_intervals: dict | None = None, | |
| ) -> dict: | |
| """Return final coordinated prediction with all outputs aligned. | |
| Args: | |
| meta_prediction: PredictionResult from the meta-learner | |
| conformal_intervals: dict from ConformalPredictor.predict() or None | |
| Returns: | |
| dict with keys: direction, direction_proba, magnitude, volatility, | |
| confidence, and optionally interval keys from conformal. | |
| """ | |
| direction = np.asarray(meta_prediction.direction, dtype=int) | |
| magnitude = np.asarray(meta_prediction.magnitude, dtype=float).copy() | |
| volatility = np.asarray(meta_prediction.volatility, dtype=float).copy() | |
| confidence = np.asarray(meta_prediction.confidence, dtype=float) | |
| direction_proba = np.asarray(meta_prediction.direction_proba, dtype=float) | |
| # Enforce sign consistency between direction and magnitude | |
| for i in range(len(direction)): | |
| if direction[i] == 1 and magnitude[i] < 0: | |
| magnitude[i] = abs(magnitude[i]) | |
| elif direction[i] == -1 and magnitude[i] > 0: | |
| magnitude[i] = -abs(magnitude[i]) | |
| # direction == 0 (flat): leave magnitude as-is (should be near zero) | |
| # Volatility must be non-negative | |
| volatility = np.abs(volatility) | |
| result = { | |
| "direction": direction, | |
| "direction_proba": direction_proba, | |
| "magnitude": magnitude, | |
| "volatility": volatility, | |
| "confidence": confidence, | |
| } | |
| # Attach conformal intervals if provided | |
| if conformal_intervals is not None: | |
| for key, value in conformal_intervals.items(): | |
| if key != "point_prediction": # avoid duplicating magnitude | |
| result[key] = value | |
| return result | |