m
Initial deployment: ensemble stock predictor with trained models
bcceb77
"""Multi-task coordinator for aligning direction, magnitude, and volatility outputs."""
import logging
import numpy as np
from src.models.base import PredictionResult
logger = logging.getLogger(__name__)
class MultiTaskCoordinator:
"""Coordinates direction + magnitude + volatility outputs from the ensemble.
Ensures consistency: if direction is 'down', magnitude should be negative;
if direction is 'up', magnitude should be positive. Volatility is always
non-negative.
"""
def coordinate(
self,
meta_prediction: PredictionResult,
conformal_intervals: dict | None = None,
) -> dict:
"""Return final coordinated prediction with all outputs aligned.
Args:
meta_prediction: PredictionResult from the meta-learner
conformal_intervals: dict from ConformalPredictor.predict() or None
Returns:
dict with keys: direction, direction_proba, magnitude, volatility,
confidence, and optionally interval keys from conformal.
"""
direction = np.asarray(meta_prediction.direction, dtype=int)
magnitude = np.asarray(meta_prediction.magnitude, dtype=float).copy()
volatility = np.asarray(meta_prediction.volatility, dtype=float).copy()
confidence = np.asarray(meta_prediction.confidence, dtype=float)
direction_proba = np.asarray(meta_prediction.direction_proba, dtype=float)
# Enforce sign consistency between direction and magnitude
for i in range(len(direction)):
if direction[i] == 1 and magnitude[i] < 0:
magnitude[i] = abs(magnitude[i])
elif direction[i] == -1 and magnitude[i] > 0:
magnitude[i] = -abs(magnitude[i])
# direction == 0 (flat): leave magnitude as-is (should be near zero)
# Volatility must be non-negative
volatility = np.abs(volatility)
result = {
"direction": direction,
"direction_proba": direction_proba,
"magnitude": magnitude,
"volatility": volatility,
"confidence": confidence,
}
# Attach conformal intervals if provided
if conformal_intervals is not None:
for key, value in conformal_intervals.items():
if key != "point_prediction": # avoid duplicating magnitude
result[key] = value
return result