Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files- app/quality_gate.py +49 -13
- deep_learning/config.py +3 -8
- deep_learning/models/hub.py +112 -0
- deep_learning/models/monotonic_quantiles.py +63 -0
- deep_learning/models/tft_copper.py +173 -118
- deep_learning/training/callbacks.py +41 -0
- deep_learning/training/hyperopt.py +70 -36
- deep_learning/training/metrics.py +55 -14
- deep_learning/training/trainer.py +66 -52
- scripts/tft_quality_gate.py +18 -1
app/quality_gate.py
CHANGED
|
@@ -21,12 +21,16 @@ def evaluate_quality_gate(
|
|
| 21 |
tail_capture: Optional[float] = None,
|
| 22 |
quantile_crossing_rate: Optional[float] = None,
|
| 23 |
median_sort_gap_max: Optional[float] = None,
|
|
|
|
|
|
|
| 24 |
weekly_directional_accuracy: Optional[float] = None,
|
| 25 |
weekly_magnitude_ratio: Optional[float] = None,
|
| 26 |
weekly_tail_capture_rate: Optional[float] = None,
|
| 27 |
weekly_pi80_coverage: Optional[float] = None,
|
|
|
|
| 28 |
weekly_pi80_width_ratio: Optional[float] = None,
|
| 29 |
weekly_pi96_coverage: Optional[float] = None,
|
|
|
|
| 30 |
weekly_pi96_width_ratio: Optional[float] = None,
|
| 31 |
weekly_quantile_crossing_rate: Optional[float] = None,
|
| 32 |
weekly_sorted_quantile_crossing_rate: Optional[float] = None,
|
|
@@ -74,6 +78,8 @@ def evaluate_quality_gate(
|
|
| 74 |
reasons.append(
|
| 75 |
f"WeeklyPI80Overwide={weekly_pi80_width_ratio:.4f} with coverage={weekly_pi80_coverage:.4f}"
|
| 76 |
)
|
|
|
|
|
|
|
| 77 |
|
| 78 |
if weekly_pi96_coverage is None:
|
| 79 |
reasons.append("Missing weekly_pi96_coverage")
|
|
@@ -82,33 +88,63 @@ def evaluate_quality_gate(
|
|
| 82 |
reasons.append("Missing weekly_pi96_width_ratio")
|
| 83 |
elif weekly_pi96_width_ratio > 3.0:
|
| 84 |
reasons.append(f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio:.4f} > 3.0")
|
|
|
|
|
|
|
| 85 |
|
| 86 |
if weekly_quantile_crossing_rate is None:
|
| 87 |
reasons.append("Missing weekly_quantile_crossing_rate")
|
| 88 |
-
elif weekly_quantile_crossing_rate > 0.
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
|
| 91 |
if weekly_sorted_quantile_crossing_rate is None:
|
| 92 |
reasons.append("Missing weekly_sorted_quantile_crossing_rate")
|
| 93 |
-
elif weekly_sorted_quantile_crossing_rate > 0.
|
| 94 |
-
|
| 95 |
-
f"
|
| 96 |
)
|
| 97 |
|
| 98 |
-
if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
|
| 101 |
if sharpe < -0.30:
|
| 102 |
reasons.append(f"Sharpe={sharpe:.4f} < -0.30")
|
| 103 |
-
if vr < 0.2 or vr > 3.0:
|
| 104 |
-
reasons.append(f"VR={vr:.4f} outside [0.2, 3.0]")
|
| 105 |
if tail_capture is not None and tail_capture < 0.35:
|
| 106 |
reasons.append(f"TailCapture={tail_capture:.4f} < 0.35")
|
| 107 |
if quantile_crossing_rate is None:
|
| 108 |
reasons.append("Missing quantile_crossing_rate")
|
| 109 |
-
elif quantile_crossing_rate > 0.
|
| 110 |
-
|
| 111 |
-
if median_sort_gap_max is not None and median_sort_gap_max > 0.
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
return len(reasons) == 0, reasons
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
tail_capture: Optional[float] = None,
|
| 22 |
quantile_crossing_rate: Optional[float] = None,
|
| 23 |
median_sort_gap_max: Optional[float] = None,
|
| 24 |
+
pi80_width: Optional[float] = None,
|
| 25 |
+
pi96_width: Optional[float] = None,
|
| 26 |
weekly_directional_accuracy: Optional[float] = None,
|
| 27 |
weekly_magnitude_ratio: Optional[float] = None,
|
| 28 |
weekly_tail_capture_rate: Optional[float] = None,
|
| 29 |
weekly_pi80_coverage: Optional[float] = None,
|
| 30 |
+
weekly_pi80_width: Optional[float] = None,
|
| 31 |
weekly_pi80_width_ratio: Optional[float] = None,
|
| 32 |
weekly_pi96_coverage: Optional[float] = None,
|
| 33 |
+
weekly_pi96_width: Optional[float] = None,
|
| 34 |
weekly_pi96_width_ratio: Optional[float] = None,
|
| 35 |
weekly_quantile_crossing_rate: Optional[float] = None,
|
| 36 |
weekly_sorted_quantile_crossing_rate: Optional[float] = None,
|
|
|
|
| 78 |
reasons.append(
|
| 79 |
f"WeeklyPI80Overwide={weekly_pi80_width_ratio:.4f} with coverage={weekly_pi80_coverage:.4f}"
|
| 80 |
)
|
| 81 |
+
if weekly_pi80_width is not None and weekly_pi80_width < 0.0:
|
| 82 |
+
reasons.append(f"WeeklyPI80Width={weekly_pi80_width:.4f} < 0.0")
|
| 83 |
|
| 84 |
if weekly_pi96_coverage is None:
|
| 85 |
reasons.append("Missing weekly_pi96_coverage")
|
|
|
|
| 88 |
reasons.append("Missing weekly_pi96_width_ratio")
|
| 89 |
elif weekly_pi96_width_ratio > 3.0:
|
| 90 |
reasons.append(f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio:.4f} > 3.0")
|
| 91 |
+
if weekly_pi96_width is not None and weekly_pi96_width < 0.0:
|
| 92 |
+
reasons.append(f"WeeklyPI96Width={weekly_pi96_width:.4f} < 0.0")
|
| 93 |
|
| 94 |
if weekly_quantile_crossing_rate is None:
|
| 95 |
reasons.append("Missing weekly_quantile_crossing_rate")
|
| 96 |
+
elif weekly_quantile_crossing_rate > 0.001:
|
| 97 |
+
raise AssertionError(
|
| 98 |
+
f"WeeklyPublicQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.001"
|
| 99 |
+
)
|
| 100 |
|
| 101 |
if weekly_sorted_quantile_crossing_rate is None:
|
| 102 |
reasons.append("Missing weekly_sorted_quantile_crossing_rate")
|
| 103 |
+
elif weekly_sorted_quantile_crossing_rate > 0.001:
|
| 104 |
+
raise AssertionError(
|
| 105 |
+
f"WeeklyOrderedQuantileCrossing={weekly_sorted_quantile_crossing_rate:.4f} > 0.001"
|
| 106 |
)
|
| 107 |
|
| 108 |
+
if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.001:
|
| 109 |
+
raise AssertionError(
|
| 110 |
+
f"WeeklyOrderedMedianSortGapMax={weekly_median_sort_gap_max:.4f} > 0.001"
|
| 111 |
+
)
|
| 112 |
|
| 113 |
if sharpe < -0.30:
|
| 114 |
reasons.append(f"Sharpe={sharpe:.4f} < -0.30")
|
|
|
|
|
|
|
| 115 |
if tail_capture is not None and tail_capture < 0.35:
|
| 116 |
reasons.append(f"TailCapture={tail_capture:.4f} < 0.35")
|
| 117 |
if quantile_crossing_rate is None:
|
| 118 |
reasons.append("Missing quantile_crossing_rate")
|
| 119 |
+
elif quantile_crossing_rate > 0.001:
|
| 120 |
+
raise AssertionError(f"PublicQuantileCrossing={quantile_crossing_rate:.4f} > 0.001")
|
| 121 |
+
if median_sort_gap_max is not None and median_sort_gap_max > 0.001:
|
| 122 |
+
raise AssertionError(f"OrderedMedianSortGapMax={median_sort_gap_max:.4f} > 0.001")
|
| 123 |
+
if pi80_width is not None and pi80_width < 0.0:
|
| 124 |
+
reasons.append(f"PI80Width={pi80_width:.4f} < 0.0")
|
| 125 |
+
if pi96_width is not None and pi96_width < 0.0:
|
| 126 |
+
reasons.append(f"PI96Width={pi96_width:.4f} < 0.0")
|
| 127 |
|
| 128 |
return len(reasons) == 0, reasons
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def evaluate_quality_gate_warnings(
|
| 132 |
+
vr: float,
|
| 133 |
+
mae_vs_naive_zero: Optional[float] = None,
|
| 134 |
+
weekly_mae_vs_naive_zero: Optional[float] = None,
|
| 135 |
+
) -> List[str]:
|
| 136 |
+
"""Return stabilization warnings that do not fail promotion yet."""
|
| 137 |
+
warnings: list[str] = []
|
| 138 |
+
if vr > 2.5:
|
| 139 |
+
warnings.append(f"VR={vr:.4f} > 2.5 - model overdispersed")
|
| 140 |
+
if vr < 0.4:
|
| 141 |
+
warnings.append(f"VR={vr:.4f} < 0.4 - model underdispersed")
|
| 142 |
+
if mae_vs_naive_zero is not None and mae_vs_naive_zero > 1.25:
|
| 143 |
+
warnings.append(
|
| 144 |
+
f"MAEvsNaiveZero={mae_vs_naive_zero:.4f} > 1.25 - worse than warning baseline"
|
| 145 |
+
)
|
| 146 |
+
if weekly_mae_vs_naive_zero is not None and weekly_mae_vs_naive_zero > 1.25:
|
| 147 |
+
warnings.append(
|
| 148 |
+
f"WeeklyMAEvsNaiveZero={weekly_mae_vs_naive_zero:.4f} > 1.25 - worse than warning baseline"
|
| 149 |
+
)
|
| 150 |
+
return warnings
|
deep_learning/config.py
CHANGED
|
@@ -136,15 +136,10 @@ class ASROConfig:
|
|
| 136 |
|
| 137 |
@dataclass(frozen=True)
|
| 138 |
class WeeklyLossConfig:
|
| 139 |
-
lambda_weekly_quantile: float = 0.
|
| 140 |
-
lambda_t1_quantile: float = 0.
|
|
|
|
| 141 |
lambda_directional: float = 0.10
|
| 142 |
-
lambda_magnitude: float = 0.55
|
| 143 |
-
lambda_vol: float = 0.35
|
| 144 |
-
lambda_crossing: float = 7.0
|
| 145 |
-
lambda_sanity: float = 0.20
|
| 146 |
-
lambda_width: float = 0.50
|
| 147 |
-
lambda_tail_width: float = 0.30
|
| 148 |
|
| 149 |
|
| 150 |
@dataclass(frozen=True)
|
|
|
|
| 136 |
|
| 137 |
@dataclass(frozen=True)
|
| 138 |
class WeeklyLossConfig:
|
| 139 |
+
lambda_weekly_quantile: float = 0.55
|
| 140 |
+
lambda_t1_quantile: float = 0.15
|
| 141 |
+
lambda_dispersion: float = 0.20
|
| 142 |
lambda_directional: float = 0.10
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
@dataclass(frozen=True)
|
deep_learning/models/hub.py
CHANGED
|
@@ -48,6 +48,108 @@ def _sha256_file(path: Path) -> str:
|
|
| 48 |
return digest.hexdigest()
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
def build_artifact_manifest(local_dir: str | Path) -> dict:
|
| 52 |
"""Build a SHA256 manifest for every present TFT artifact except itself."""
|
| 53 |
local_dir = Path(local_dir)
|
|
@@ -68,6 +170,7 @@ def build_artifact_manifest(local_dir: str | Path) -> dict:
|
|
| 68 |
"manifest_version": 1,
|
| 69 |
"generated_at": datetime.now(timezone.utc).isoformat(),
|
| 70 |
"artifacts": artifacts,
|
|
|
|
| 71 |
}
|
| 72 |
|
| 73 |
|
|
@@ -180,6 +283,12 @@ def validate_tft_artifact_set(local_dir: str | Path) -> bool:
|
|
| 180 |
return True
|
| 181 |
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
def upload_tft_artifacts(
|
| 184 |
local_dir: str | Path,
|
| 185 |
repo_id: str,
|
|
@@ -208,6 +317,9 @@ def upload_tft_artifacts(
|
|
| 208 |
if not validate_tft_artifact_set(local_dir):
|
| 209 |
logger.warning("TFT artifact manifest validation failed before upload")
|
| 210 |
return False
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
files_to_upload = [
|
| 213 |
local_dir / name for name in _ARTIFACTS if (local_dir / name).exists()
|
|
|
|
| 48 |
return digest.hexdigest()
|
| 49 |
|
| 50 |
|
| 51 |
+
def _load_json(path: Path) -> dict:
|
| 52 |
+
if not path.exists():
|
| 53 |
+
return {}
|
| 54 |
+
try:
|
| 55 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 56 |
+
except Exception as exc:
|
| 57 |
+
logger.warning("Could not read JSON artifact %s: %s", path, exc)
|
| 58 |
+
return {}
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def build_artifact_health(local_dir: str | Path) -> dict:
|
| 62 |
+
"""Build promotion/inference health metadata for the TFT artifact set."""
|
| 63 |
+
local_dir = Path(local_dir)
|
| 64 |
+
metadata_path = local_dir / "tft_metadata.json"
|
| 65 |
+
checkpoint_present = (local_dir / "best_tft_asro.ckpt").exists()
|
| 66 |
+
metadata_present = metadata_path.exists()
|
| 67 |
+
conformal_present = (local_dir / "conformal_calibration.json").exists()
|
| 68 |
+
|
| 69 |
+
metadata = _load_json(metadata_path)
|
| 70 |
+
config = metadata.get("config") or {}
|
| 71 |
+
metrics = metadata.get("test_metrics") or {}
|
| 72 |
+
optuna = _load_json(local_dir / "optuna_results.json")
|
| 73 |
+
structural_report = optuna.get("structural_invalidity_report") or {}
|
| 74 |
+
best_preflight = optuna.get("best_trial_preflight") or {}
|
| 75 |
+
|
| 76 |
+
quality_gate_passed = False
|
| 77 |
+
gate_error = None
|
| 78 |
+
if metrics:
|
| 79 |
+
try:
|
| 80 |
+
from app.quality_gate import evaluate_quality_gate
|
| 81 |
+
|
| 82 |
+
quality_gate_passed, reasons = evaluate_quality_gate(
|
| 83 |
+
da=float(metrics.get("directional_accuracy", 0.5)),
|
| 84 |
+
sharpe=float(metrics.get("sharpe_ratio", 0.0)),
|
| 85 |
+
vr=float(metrics.get("variance_ratio", 1.0)),
|
| 86 |
+
tail_capture=metrics.get("tail_capture_rate"),
|
| 87 |
+
quantile_crossing_rate=metrics.get("quantile_crossing_rate"),
|
| 88 |
+
median_sort_gap_max=metrics.get("median_sort_gap_max"),
|
| 89 |
+
pi80_width=metrics.get("pi80_width"),
|
| 90 |
+
pi96_width=metrics.get("pi96_width"),
|
| 91 |
+
weekly_directional_accuracy=metrics.get("weekly_directional_accuracy"),
|
| 92 |
+
weekly_magnitude_ratio=metrics.get("weekly_magnitude_ratio"),
|
| 93 |
+
weekly_tail_capture_rate=metrics.get("weekly_tail_capture_rate"),
|
| 94 |
+
weekly_pi80_coverage=metrics.get("weekly_pi80_coverage"),
|
| 95 |
+
weekly_pi80_width=metrics.get("weekly_pi80_width"),
|
| 96 |
+
weekly_pi80_width_ratio=metrics.get("weekly_pi80_width_ratio"),
|
| 97 |
+
weekly_pi96_coverage=metrics.get("weekly_pi96_coverage"),
|
| 98 |
+
weekly_pi96_width=metrics.get("weekly_pi96_width"),
|
| 99 |
+
weekly_pi96_width_ratio=metrics.get("weekly_pi96_width_ratio"),
|
| 100 |
+
weekly_quantile_crossing_rate=metrics.get("weekly_quantile_crossing_rate"),
|
| 101 |
+
weekly_sorted_quantile_crossing_rate=metrics.get(
|
| 102 |
+
"weekly_sorted_quantile_crossing_rate"
|
| 103 |
+
),
|
| 104 |
+
weekly_median_sort_gap_max=metrics.get("weekly_median_sort_gap_max"),
|
| 105 |
+
weekly_sample_count=metrics.get("weekly_sample_count"),
|
| 106 |
+
)
|
| 107 |
+
if not quality_gate_passed:
|
| 108 |
+
gate_error = "; ".join(reasons)
|
| 109 |
+
except Exception as exc:
|
| 110 |
+
gate_error = str(exc)
|
| 111 |
+
quality_gate_passed = False
|
| 112 |
+
else:
|
| 113 |
+
gate_error = "missing test_metrics"
|
| 114 |
+
|
| 115 |
+
safe = bool(quality_gate_passed and checkpoint_present and metadata_present)
|
| 116 |
+
next_required_action = "No action required; artifact is promotable."
|
| 117 |
+
if not safe:
|
| 118 |
+
next_required_action = (
|
| 119 |
+
gate_error
|
| 120 |
+
or "Run deterministic validation and pass the weekly quality gate before upload."
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
return {
|
| 124 |
+
"forecast_contract_version": (
|
| 125 |
+
metadata.get("forecast_contract_version")
|
| 126 |
+
or config.get("forecast_contract_version")
|
| 127 |
+
),
|
| 128 |
+
"monotonic_quantile_transform": bool(
|
| 129 |
+
config.get("monotonic_quantile_transform")
|
| 130 |
+
or metadata.get("monotonic_quantile_transform")
|
| 131 |
+
),
|
| 132 |
+
"checkpoint_present": checkpoint_present,
|
| 133 |
+
"metadata_present": metadata_present,
|
| 134 |
+
"conformal_present": conformal_present,
|
| 135 |
+
"quality_gate_passed": quality_gate_passed,
|
| 136 |
+
"best_trial_preflight_passed": bool(best_preflight.get("preflight_passed", False)),
|
| 137 |
+
"structural_invalidity_verdict": structural_report.get("verdict", "UNKNOWN"),
|
| 138 |
+
"safe_to_upload_to_hub": safe,
|
| 139 |
+
"safe_for_inference": safe,
|
| 140 |
+
"raw_quantile_crossing_rate": metrics.get("raw_quantile_crossing_rate"),
|
| 141 |
+
"ordered_quantile_crossing_rate": metrics.get("ordered_quantile_crossing_rate"),
|
| 142 |
+
"public_quantile_crossing_rate": metrics.get(
|
| 143 |
+
"public_quantile_crossing_rate",
|
| 144 |
+
metrics.get("quantile_crossing_rate"),
|
| 145 |
+
),
|
| 146 |
+
"variance_ratio": metrics.get("variance_ratio"),
|
| 147 |
+
"mae_vs_naive_zero": metrics.get("mae_vs_naive_zero"),
|
| 148 |
+
"weekly_mae_vs_naive_zero": metrics.get("weekly_mae_vs_naive_zero"),
|
| 149 |
+
"next_required_action": next_required_action,
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
def build_artifact_manifest(local_dir: str | Path) -> dict:
|
| 154 |
"""Build a SHA256 manifest for every present TFT artifact except itself."""
|
| 155 |
local_dir = Path(local_dir)
|
|
|
|
| 170 |
"manifest_version": 1,
|
| 171 |
"generated_at": datetime.now(timezone.utc).isoformat(),
|
| 172 |
"artifacts": artifacts,
|
| 173 |
+
"artifact_health": build_artifact_health(local_dir),
|
| 174 |
}
|
| 175 |
|
| 176 |
|
|
|
|
| 283 |
return True
|
| 284 |
|
| 285 |
|
| 286 |
+
def _manifest_safe_to_upload(local_dir: str | Path) -> bool:
|
| 287 |
+
manifest = _load_json(Path(local_dir) / "artifact_manifest.json")
|
| 288 |
+
health = manifest.get("artifact_health") or {}
|
| 289 |
+
return bool(health.get("safe_to_upload_to_hub"))
|
| 290 |
+
|
| 291 |
+
|
| 292 |
def upload_tft_artifacts(
|
| 293 |
local_dir: str | Path,
|
| 294 |
repo_id: str,
|
|
|
|
| 317 |
if not validate_tft_artifact_set(local_dir):
|
| 318 |
logger.warning("TFT artifact manifest validation failed before upload")
|
| 319 |
return False
|
| 320 |
+
if not _manifest_safe_to_upload(local_dir):
|
| 321 |
+
logger.warning("TFT artifact health is not safe for Hub upload; upload skipped")
|
| 322 |
+
return False
|
| 323 |
|
| 324 |
files_to_upload = [
|
| 325 |
local_dir / name for name in _ARTIFACTS if (local_dir / name).exists()
|
deep_learning/models/monotonic_quantiles.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def enforce_monotonic_quantiles(
|
| 8 |
+
y_pred: torch.Tensor,
|
| 9 |
+
median_idx: int = 3,
|
| 10 |
+
min_gap: float = 1e-5,
|
| 11 |
+
gap_scale: float = 0.01,
|
| 12 |
+
init_bias: float = -3.0,
|
| 13 |
+
) -> torch.Tensor:
|
| 14 |
+
"""
|
| 15 |
+
Transform unconstrained quantile outputs into structurally monotonic
|
| 16 |
+
quantile outputs.
|
| 17 |
+
|
| 18 |
+
The median dimension is preserved exactly. Lower/upper quantile distances
|
| 19 |
+
are positive by construction and scaled for log-return targets.
|
| 20 |
+
"""
|
| 21 |
+
base = y_pred[..., median_idx]
|
| 22 |
+
|
| 23 |
+
lower_raw = y_pred[..., :median_idx]
|
| 24 |
+
upper_raw = y_pred[..., median_idx + 1 :]
|
| 25 |
+
|
| 26 |
+
lower_steps = min_gap + gap_scale * F.softplus(
|
| 27 |
+
torch.flip(lower_raw, dims=[-1]) + init_bias
|
| 28 |
+
)
|
| 29 |
+
upper_steps = min_gap + gap_scale * F.softplus(upper_raw + init_bias)
|
| 30 |
+
|
| 31 |
+
lower_from_median = torch.cumsum(lower_steps, dim=-1)
|
| 32 |
+
upper_from_median = torch.cumsum(upper_steps, dim=-1)
|
| 33 |
+
|
| 34 |
+
lower = base.unsqueeze(-1) - lower_from_median
|
| 35 |
+
lower = torch.flip(lower, dims=[-1])
|
| 36 |
+
upper = base.unsqueeze(-1) + upper_from_median
|
| 37 |
+
|
| 38 |
+
ordered = torch.cat([lower, base.unsqueeze(-1), upper], dim=-1)
|
| 39 |
+
|
| 40 |
+
assert ordered.shape == y_pred.shape, (
|
| 41 |
+
f"Monotonic transform output shape {ordered.shape} "
|
| 42 |
+
f"does not match input shape {y_pred.shape}"
|
| 43 |
+
)
|
| 44 |
+
return ordered
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def validate_monotonicity(
|
| 48 |
+
y_pred: torch.Tensor,
|
| 49 |
+
tolerance: float = 1e-6,
|
| 50 |
+
) -> dict:
|
| 51 |
+
"""Return crossing diagnostics for an ordered quantile tensor."""
|
| 52 |
+
diffs = y_pred[..., 1:] - y_pred[..., :-1]
|
| 53 |
+
violations = diffs < -tolerance
|
| 54 |
+
crossing_rate = violations.float().mean().item()
|
| 55 |
+
max_violation = (
|
| 56 |
+
(-diffs[violations]).max().item() if violations.any().item() else 0.0
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return {
|
| 60 |
+
"crossing_rate": crossing_rate,
|
| 61 |
+
"max_violation": max_violation,
|
| 62 |
+
"is_valid": crossing_rate == 0.0,
|
| 63 |
+
}
|
deep_learning/models/tft_copper.py
CHANGED
|
@@ -15,10 +15,15 @@ from pathlib import Path
|
|
| 15 |
from typing import Any, Dict, Optional, Sequence
|
| 16 |
|
| 17 |
import torch
|
|
|
|
| 18 |
import numpy as np
|
| 19 |
|
| 20 |
from deep_learning.contract import RETURN_SPACE, log_to_simple_return
|
| 21 |
from deep_learning.config import TFTASROConfig, get_tft_config
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from deep_learning.models.losses import (
|
| 23 |
AdaptiveSharpeRatioLoss,
|
| 24 |
CombinedQuantileLoss,
|
|
@@ -131,38 +136,76 @@ try:
|
|
| 131 |
def __init__(
|
| 132 |
self,
|
| 133 |
quantiles: list,
|
| 134 |
-
lambda_weekly_quantile: float = 0.
|
| 135 |
-
lambda_t1_quantile: float = 0.
|
|
|
|
| 136 |
lambda_directional: float = 0.10,
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
lambda_crossing: float = 7.0,
|
| 140 |
-
lambda_sanity: float = 0.20,
|
| 141 |
-
lambda_width: float = 0.50,
|
| 142 |
-
lambda_tail_width: float = 0.30,
|
| 143 |
-
sharpe_eps: float = 1e-6,
|
| 144 |
-
daily_log_return_bound: float = 0.08,
|
| 145 |
-
weekly_log_return_bound: float = 0.20,
|
| 146 |
):
|
| 147 |
super().__init__(quantiles=quantiles)
|
| 148 |
self.lambda_weekly_quantile = lambda_weekly_quantile
|
| 149 |
self.lambda_t1_quantile = lambda_t1_quantile
|
|
|
|
| 150 |
self.lambda_directional = lambda_directional
|
| 151 |
-
self.lambda_magnitude = lambda_magnitude
|
| 152 |
-
self.lambda_vol = lambda_vol
|
| 153 |
-
self.lambda_crossing = lambda_crossing
|
| 154 |
-
self.lambda_sanity = lambda_sanity
|
| 155 |
-
self.lambda_width = lambda_width
|
| 156 |
-
self.lambda_tail_width = lambda_tail_width
|
| 157 |
self.sharpe_eps = sharpe_eps
|
| 158 |
-
self.
|
| 159 |
-
self.weekly_log_return_bound = weekly_log_return_bound
|
| 160 |
self.median_idx = len(quantiles) // 2
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
self.
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
def _pinball(self, pred: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
|
| 168 |
q = torch.tensor(self.quantiles, device=pred.device, dtype=pred.dtype).view(1, -1)
|
|
@@ -178,70 +221,48 @@ try:
|
|
| 178 |
y_actual = y_actual.float()
|
| 179 |
y_pred = y_pred.float()
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
actual_weekly = y_actual.sum(dim=1)
|
| 184 |
|
| 185 |
weekly_q_loss = self._pinball(pred_weekly_quantiles, actual_weekly)
|
| 186 |
-
t1_q_loss = super().loss(
|
| 187 |
|
| 188 |
pred_weekly_median = median_path.sum(dim=1)
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
)
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
(pred_weekly_median.abs() + self.sharpe_eps)
|
| 199 |
-
/ (actual_weekly.abs() + self.sharpe_eps)
|
| 200 |
-
)
|
| 201 |
-
).mean()
|
| 202 |
-
if material_mask.any():
|
| 203 |
-
pred_abs = pred_weekly_median[material_mask].abs()
|
| 204 |
-
true_abs = actual_weekly[material_mask].abs()
|
| 205 |
-
material_magnitude_loss = torch.abs(
|
| 206 |
-
torch.log((pred_abs + self.sharpe_eps) / (true_abs + self.sharpe_eps))
|
| 207 |
-
).mean()
|
| 208 |
-
else:
|
| 209 |
-
material_magnitude_loss = y_pred.new_tensor(0.0)
|
| 210 |
-
magnitude_loss = 0.5 * global_magnitude_loss + 0.5 * material_magnitude_loss
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
)
|
| 216 |
-
actual_weekly_std = actual_weekly.std() + self.sharpe_eps
|
| 217 |
-
target_spread = 2.56 * actual_weekly_std
|
| 218 |
-
mean_weekly_spread = weekly_spread.mean()
|
| 219 |
-
vol_loss = torch.abs(mean_weekly_spread - target_spread)
|
| 220 |
-
width_ratio = mean_weekly_spread / (target_spread + self.sharpe_eps)
|
| 221 |
-
safe_width_ratio = torch.clamp(width_ratio + self.sharpe_eps, min=1e-6)
|
| 222 |
-
width_loss = torch.abs(torch.log(safe_width_ratio))
|
| 223 |
-
width_loss = width_loss + torch.relu(width_ratio - 2.0).pow(2)
|
| 224 |
-
|
| 225 |
-
weekly_tail_spread = (
|
| 226 |
-
pred_weekly_quantiles[:, self._q98_idx]
|
| 227 |
-
- pred_weekly_quantiles[:, self._q02_idx]
|
| 228 |
-
)
|
| 229 |
-
target_tail_spread = 4.10 * actual_weekly_std
|
| 230 |
-
tail_width_ratio = weekly_tail_spread.mean() / (target_tail_spread + self.sharpe_eps)
|
| 231 |
-
safe_tail_width_ratio = torch.clamp(tail_width_ratio + self.sharpe_eps, min=1e-6)
|
| 232 |
-
tail_width_loss = torch.abs(torch.log(safe_tail_width_ratio))
|
| 233 |
-
tail_width_loss = tail_width_loss + torch.relu(tail_width_ratio - 3.0).pow(2)
|
| 234 |
-
daily_crossing_loss = quantile_crossing_penalty(y_pred)
|
| 235 |
-
weekly_crossing_loss = quantile_crossing_penalty(pred_weekly_quantiles.unsqueeze(1))
|
| 236 |
-
crossing_loss = daily_crossing_loss + weekly_crossing_loss
|
| 237 |
-
|
| 238 |
-
daily_bound_loss = torch.relu(
|
| 239 |
-
median_path.abs() - self.daily_log_return_bound
|
| 240 |
-
).pow(2).mean()
|
| 241 |
-
weekly_bound_loss = torch.relu(
|
| 242 |
-
pred_weekly_median.abs() - self.weekly_log_return_bound
|
| 243 |
-
).pow(2).mean()
|
| 244 |
-
sanity_loss = daily_bound_loss + weekly_bound_loss
|
| 245 |
|
| 246 |
def _to_scalar(x: torch.Tensor) -> torch.Tensor:
|
| 247 |
# pytorch_forecasting metrics can return per-sample tensors;
|
|
@@ -249,17 +270,24 @@ try:
|
|
| 249 |
# boolean comparisons in tests and stable optimizer behaviour.
|
| 250 |
return x.mean() if x.ndim > 0 else x
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
self.lambda_weekly_quantile * _to_scalar(weekly_q_loss)
|
| 254 |
+ self.lambda_t1_quantile * _to_scalar(t1_q_loss)
|
| 255 |
-
+ self.
|
| 256 |
-
+ self.
|
| 257 |
-
+ self.lambda_vol * _to_scalar(vol_loss)
|
| 258 |
-
+ self.lambda_width * _to_scalar(width_loss)
|
| 259 |
-
+ self.lambda_tail_width * _to_scalar(tail_width_loss)
|
| 260 |
-
+ self.lambda_crossing * _to_scalar(crossing_loss)
|
| 261 |
-
+ self.lambda_sanity * _to_scalar(sanity_loss)
|
| 262 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
except ImportError:
|
| 265 |
ASROPFLoss = None # type: ignore[assignment,misc]
|
|
@@ -295,25 +323,15 @@ def create_tft_model(
|
|
| 295 |
quantiles=quantiles,
|
| 296 |
lambda_weekly_quantile=cfg.weekly_loss.lambda_weekly_quantile,
|
| 297 |
lambda_t1_quantile=cfg.weekly_loss.lambda_t1_quantile,
|
|
|
|
| 298 |
lambda_directional=cfg.weekly_loss.lambda_directional,
|
| 299 |
-
lambda_magnitude=cfg.weekly_loss.lambda_magnitude,
|
| 300 |
-
lambda_vol=cfg.weekly_loss.lambda_vol,
|
| 301 |
-
lambda_crossing=cfg.weekly_loss.lambda_crossing,
|
| 302 |
-
lambda_sanity=cfg.weekly_loss.lambda_sanity,
|
| 303 |
-
lambda_width=cfg.weekly_loss.lambda_width,
|
| 304 |
-
lambda_tail_width=cfg.weekly_loss.lambda_tail_width,
|
| 305 |
)
|
| 306 |
logger.info(
|
| 307 |
-
"Using weekly ASRO loss | weekly_q=%.2f t1_q=%.2f
|
| 308 |
cfg.weekly_loss.lambda_weekly_quantile,
|
| 309 |
cfg.weekly_loss.lambda_t1_quantile,
|
|
|
|
| 310 |
cfg.weekly_loss.lambda_directional,
|
| 311 |
-
cfg.weekly_loss.lambda_magnitude,
|
| 312 |
-
cfg.weekly_loss.lambda_vol,
|
| 313 |
-
cfg.weekly_loss.lambda_width,
|
| 314 |
-
cfg.weekly_loss.lambda_tail_width,
|
| 315 |
-
cfg.weekly_loss.lambda_crossing,
|
| 316 |
-
cfg.weekly_loss.lambda_sanity,
|
| 317 |
)
|
| 318 |
elif use_asro and ASROPFLoss is not None:
|
| 319 |
loss = ASROPFLoss(
|
|
@@ -490,20 +508,28 @@ def _format_prediction_legacy_simple_return(
|
|
| 490 |
quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
|
| 491 |
crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
|
| 492 |
quantile_crossing_detected = bool(crossing_mask.any())
|
| 493 |
-
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
median_sort_gap = float(
|
| 496 |
np.max(np.abs(raw_pred[..., median_idx] - sorted_pred[..., median_idx]))
|
| 497 |
)
|
| 498 |
if quantile_crossing_detected:
|
| 499 |
logger.error(
|
| 500 |
"format_prediction: non-monotonic quantiles detected "
|
| 501 |
-
"(
|
| 502 |
-
"
|
| 503 |
-
|
| 504 |
median_sort_gap,
|
| 505 |
)
|
| 506 |
-
pred = sorted_pred
|
| 507 |
|
| 508 |
if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
|
| 509 |
logger.warning(
|
|
@@ -643,18 +669,42 @@ def format_prediction(
|
|
| 643 |
quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
|
| 644 |
crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
|
| 645 |
quantile_crossing_detected = bool(crossing_mask.any())
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
if quantile_crossing_detected:
|
| 650 |
logger.error(
|
| 651 |
"format_prediction: non-monotonic quantiles detected "
|
| 652 |
-
"(
|
| 653 |
-
"
|
| 654 |
-
|
| 655 |
-
|
| 656 |
)
|
| 657 |
-
pred = sorted_pred
|
| 658 |
|
| 659 |
if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
|
| 660 |
logger.warning(
|
|
@@ -750,8 +800,13 @@ def format_prediction(
|
|
| 750 |
"quantiles_log": {f"q{q:.2f}": float(bounded_pred[0, i]) for i, q in enumerate(q_list)},
|
| 751 |
"raw_quantiles": {f"q{q:.2f}": float(raw_pred[0, i]) for i, q in enumerate(q_list)},
|
| 752 |
"quantile_crossing_detected": quantile_crossing_detected,
|
| 753 |
-
"quantile_crossing_rate":
|
| 754 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
"weekly_return": log_to_simple_return(weekly_log_return),
|
| 756 |
"weekly_log_return": weekly_log_return,
|
| 757 |
"weekly_price": _price(weekly_log_return),
|
|
|
|
| 15 |
from typing import Any, Dict, Optional, Sequence
|
| 16 |
|
| 17 |
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
import numpy as np
|
| 20 |
|
| 21 |
from deep_learning.contract import RETURN_SPACE, log_to_simple_return
|
| 22 |
from deep_learning.config import TFTASROConfig, get_tft_config
|
| 23 |
+
from deep_learning.models.monotonic_quantiles import (
|
| 24 |
+
enforce_monotonic_quantiles,
|
| 25 |
+
validate_monotonicity,
|
| 26 |
+
)
|
| 27 |
from deep_learning.models.losses import (
|
| 28 |
AdaptiveSharpeRatioLoss,
|
| 29 |
CombinedQuantileLoss,
|
|
|
|
| 136 |
def __init__(
|
| 137 |
self,
|
| 138 |
quantiles: list,
|
| 139 |
+
lambda_weekly_quantile: float = 0.55,
|
| 140 |
+
lambda_t1_quantile: float = 0.15,
|
| 141 |
+
lambda_dispersion: float = 0.20,
|
| 142 |
lambda_directional: float = 0.10,
|
| 143 |
+
sharpe_eps: float = 1e-8,
|
| 144 |
+
debug_mode: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
):
|
| 146 |
super().__init__(quantiles=quantiles)
|
| 147 |
self.lambda_weekly_quantile = lambda_weekly_quantile
|
| 148 |
self.lambda_t1_quantile = lambda_t1_quantile
|
| 149 |
+
self.lambda_dispersion = lambda_dispersion
|
| 150 |
self.lambda_directional = lambda_directional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
self.sharpe_eps = sharpe_eps
|
| 152 |
+
self.debug_mode = debug_mode
|
|
|
|
| 153 |
self.median_idx = len(quantiles) // 2
|
| 154 |
+
self.reset_component_accumulators()
|
| 155 |
+
|
| 156 |
+
def reset_component_accumulators(self) -> None:
|
| 157 |
+
self._component_sums = {
|
| 158 |
+
"weekly_q": 0.0,
|
| 159 |
+
"t1_q": 0.0,
|
| 160 |
+
"dispersion": 0.0,
|
| 161 |
+
"directional": 0.0,
|
| 162 |
+
"total": 0.0,
|
| 163 |
+
}
|
| 164 |
+
self._component_batches = 0
|
| 165 |
+
|
| 166 |
+
def _record_components(
|
| 167 |
+
self,
|
| 168 |
+
weekly_q_loss: torch.Tensor,
|
| 169 |
+
t1_q_loss: torch.Tensor,
|
| 170 |
+
dispersion_loss: torch.Tensor,
|
| 171 |
+
directional_loss: torch.Tensor,
|
| 172 |
+
total_loss: torch.Tensor,
|
| 173 |
+
) -> None:
|
| 174 |
+
self._component_sums["weekly_q"] += float(weekly_q_loss.detach().mean().cpu())
|
| 175 |
+
self._component_sums["t1_q"] += float(t1_q_loss.detach().mean().cpu())
|
| 176 |
+
self._component_sums["dispersion"] += float(dispersion_loss.detach().mean().cpu())
|
| 177 |
+
self._component_sums["directional"] += float(directional_loss.detach().mean().cpu())
|
| 178 |
+
self._component_sums["total"] += float(total_loss.detach().mean().cpu())
|
| 179 |
+
self._component_batches += 1
|
| 180 |
+
|
| 181 |
+
def component_means(self) -> dict:
|
| 182 |
+
n_batches = self._component_batches
|
| 183 |
+
if n_batches <= 0:
|
| 184 |
+
return {
|
| 185 |
+
"n_batches": 0,
|
| 186 |
+
"weekly_q_loss_mean": 0.0,
|
| 187 |
+
"t1_q_loss_mean": 0.0,
|
| 188 |
+
"dispersion_loss_mean": 0.0,
|
| 189 |
+
"directional_loss_mean": 0.0,
|
| 190 |
+
"total_loss_mean": 0.0,
|
| 191 |
+
"dominant_component": None,
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
components = {
|
| 195 |
+
"weekly_q": self._component_sums["weekly_q"],
|
| 196 |
+
"t1_q": self._component_sums["t1_q"],
|
| 197 |
+
"dispersion": self._component_sums["dispersion"],
|
| 198 |
+
"directional": self._component_sums["directional"],
|
| 199 |
+
}
|
| 200 |
+
return {
|
| 201 |
+
"n_batches": n_batches,
|
| 202 |
+
"weekly_q_loss_mean": self._component_sums["weekly_q"] / n_batches,
|
| 203 |
+
"t1_q_loss_mean": self._component_sums["t1_q"] / n_batches,
|
| 204 |
+
"dispersion_loss_mean": self._component_sums["dispersion"] / n_batches,
|
| 205 |
+
"directional_loss_mean": self._component_sums["directional"] / n_batches,
|
| 206 |
+
"total_loss_mean": self._component_sums["total"] / n_batches,
|
| 207 |
+
"dominant_component": max(components, key=components.get),
|
| 208 |
+
}
|
| 209 |
|
| 210 |
def _pinball(self, pred: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
|
| 211 |
q = torch.tensor(self.quantiles, device=pred.device, dtype=pred.dtype).view(1, -1)
|
|
|
|
| 221 |
y_actual = y_actual.float()
|
| 222 |
y_pred = y_pred.float()
|
| 223 |
|
| 224 |
+
ordered_pred = enforce_monotonic_quantiles(
|
| 225 |
+
y_pred,
|
| 226 |
+
median_idx=self.median_idx,
|
| 227 |
+
min_gap=1e-5,
|
| 228 |
+
gap_scale=0.01,
|
| 229 |
+
init_bias=-3.0,
|
| 230 |
+
)
|
| 231 |
+
if self.debug_mode:
|
| 232 |
+
ordered_diagnostics = validate_monotonicity(ordered_pred)
|
| 233 |
+
assert ordered_diagnostics["is_valid"], (
|
| 234 |
+
f"Monotonic transform produced crossings: "
|
| 235 |
+
f"rate={ordered_diagnostics['crossing_rate']}, "
|
| 236 |
+
f"max_violation={ordered_diagnostics['max_violation']}"
|
| 237 |
+
)
|
| 238 |
+
assert torch.allclose(
|
| 239 |
+
ordered_pred[..., self.median_idx],
|
| 240 |
+
y_pred[..., self.median_idx],
|
| 241 |
+
rtol=1e-6,
|
| 242 |
+
atol=1e-7,
|
| 243 |
+
), "Monotonic transform must preserve the median quantile exactly"
|
| 244 |
+
|
| 245 |
+
median_path = ordered_pred[..., self.median_idx]
|
| 246 |
+
pred_weekly_quantiles = ordered_pred.sum(dim=1)
|
| 247 |
actual_weekly = y_actual.sum(dim=1)
|
| 248 |
|
| 249 |
weekly_q_loss = self._pinball(pred_weekly_quantiles, actual_weekly)
|
| 250 |
+
t1_q_loss = super().loss(ordered_pred[:, 0:1, :], y_actual[:, 0:1])
|
| 251 |
|
| 252 |
pred_weekly_median = median_path.sum(dim=1)
|
| 253 |
+
eps = self.sharpe_eps
|
| 254 |
+
pred_std = pred_weekly_median.std() + eps
|
| 255 |
+
actual_std = actual_weekly.std() + eps
|
| 256 |
+
dispersion_loss = torch.abs(torch.log(pred_std / actual_std))
|
| 257 |
|
| 258 |
+
pred_abs_med = pred_weekly_median.abs().median() + eps
|
| 259 |
+
actual_abs_med = actual_weekly.abs().median() + eps
|
| 260 |
+
magnitude_loss = torch.abs(torch.log(pred_abs_med / actual_abs_med))
|
| 261 |
+
combined_calibration_loss = 0.5 * dispersion_loss + 0.5 * magnitude_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
+
pred_direction = torch.tanh(median_path * 10.0)
|
| 264 |
+
actual_direction = torch.sign(y_actual)
|
| 265 |
+
directional_loss = F.mse_loss(pred_direction, actual_direction.float())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
def _to_scalar(x: torch.Tensor) -> torch.Tensor:
|
| 268 |
# pytorch_forecasting metrics can return per-sample tensors;
|
|
|
|
| 270 |
# boolean comparisons in tests and stable optimizer behaviour.
|
| 271 |
return x.mean() if x.ndim > 0 else x
|
| 272 |
|
| 273 |
+
weekly_q_loss = _to_scalar(weekly_q_loss)
|
| 274 |
+
t1_q_loss = _to_scalar(t1_q_loss)
|
| 275 |
+
combined_calibration_loss = _to_scalar(combined_calibration_loss)
|
| 276 |
+
directional_loss = _to_scalar(directional_loss)
|
| 277 |
+
total_loss = (
|
| 278 |
self.lambda_weekly_quantile * _to_scalar(weekly_q_loss)
|
| 279 |
+ self.lambda_t1_quantile * _to_scalar(t1_q_loss)
|
| 280 |
+
+ self.lambda_dispersion * _to_scalar(combined_calibration_loss)
|
| 281 |
+
+ self.lambda_directional * _to_scalar(directional_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
)
|
| 283 |
+
self._record_components(
|
| 284 |
+
weekly_q_loss,
|
| 285 |
+
t1_q_loss,
|
| 286 |
+
combined_calibration_loss,
|
| 287 |
+
directional_loss,
|
| 288 |
+
total_loss,
|
| 289 |
+
)
|
| 290 |
+
return total_loss
|
| 291 |
|
| 292 |
except ImportError:
|
| 293 |
ASROPFLoss = None # type: ignore[assignment,misc]
|
|
|
|
| 323 |
quantiles=quantiles,
|
| 324 |
lambda_weekly_quantile=cfg.weekly_loss.lambda_weekly_quantile,
|
| 325 |
lambda_t1_quantile=cfg.weekly_loss.lambda_t1_quantile,
|
| 326 |
+
lambda_dispersion=cfg.weekly_loss.lambda_dispersion,
|
| 327 |
lambda_directional=cfg.weekly_loss.lambda_directional,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
)
|
| 329 |
logger.info(
|
| 330 |
+
"Using weekly ASRO loss | weekly_q=%.2f t1_q=%.2f dispersion=%.2f dir=%.2f monotonic_transform=true",
|
| 331 |
cfg.weekly_loss.lambda_weekly_quantile,
|
| 332 |
cfg.weekly_loss.lambda_t1_quantile,
|
| 333 |
+
cfg.weekly_loss.lambda_dispersion,
|
| 334 |
cfg.weekly_loss.lambda_directional,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
)
|
| 336 |
elif use_asro and ASROPFLoss is not None:
|
| 337 |
loss = ASROPFLoss(
|
|
|
|
| 508 |
quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
|
| 509 |
crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
|
| 510 |
quantile_crossing_detected = bool(crossing_mask.any())
|
| 511 |
+
raw_quantile_crossing_rate = float(crossing_mask.mean()) if crossing_mask.size else 0.0
|
| 512 |
+
ordered_tensor = enforce_monotonic_quantiles(
|
| 513 |
+
torch.as_tensor(raw_pred, dtype=torch.float64),
|
| 514 |
+
median_idx=median_idx,
|
| 515 |
+
min_gap=1e-5,
|
| 516 |
+
gap_scale=0.01,
|
| 517 |
+
init_bias=-3.0,
|
| 518 |
+
)
|
| 519 |
+
pred = ordered_tensor.detach().cpu().numpy()
|
| 520 |
+
quantile_crossing_rate = 0.0
|
| 521 |
+
sorted_pred = pred
|
| 522 |
median_sort_gap = float(
|
| 523 |
np.max(np.abs(raw_pred[..., median_idx] - sorted_pred[..., median_idx]))
|
| 524 |
)
|
| 525 |
if quantile_crossing_detected:
|
| 526 |
logger.error(
|
| 527 |
"format_prediction: non-monotonic quantiles detected "
|
| 528 |
+
"(raw_crossing_rate=%.3f, max_median_sort_gap=%.4f); public output "
|
| 529 |
+
"uses the structural monotonic transform and exposes raw_quantiles for audit.",
|
| 530 |
+
raw_quantile_crossing_rate,
|
| 531 |
median_sort_gap,
|
| 532 |
)
|
|
|
|
| 533 |
|
| 534 |
if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
|
| 535 |
logger.warning(
|
|
|
|
| 669 |
quantile_diffs = np.diff(raw_pred, axis=-1) if raw_pred.shape[-1] > 1 else np.array([])
|
| 670 |
crossing_mask = quantile_diffs < -1e-12 if quantile_diffs.size else np.array([], dtype=bool)
|
| 671 |
quantile_crossing_detected = bool(crossing_mask.any())
|
| 672 |
+
raw_quantile_crossing_rate = float(crossing_mask.mean()) if crossing_mask.size else 0.0
|
| 673 |
+
ordered_tensor = enforce_monotonic_quantiles(
|
| 674 |
+
torch.as_tensor(raw_pred, dtype=torch.float64),
|
| 675 |
+
median_idx=median_idx,
|
| 676 |
+
min_gap=1e-5,
|
| 677 |
+
gap_scale=0.01,
|
| 678 |
+
init_bias=-3.0,
|
| 679 |
+
)
|
| 680 |
+
pred = ordered_tensor.detach().cpu().numpy()
|
| 681 |
+
ordered_diffs = np.diff(pred, axis=-1) if pred.shape[-1] > 1 else np.array([])
|
| 682 |
+
ordered_crossing_mask = (
|
| 683 |
+
ordered_diffs < -1e-12 if ordered_diffs.size else np.array([], dtype=bool)
|
| 684 |
+
)
|
| 685 |
+
ordered_quantile_crossing_rate = (
|
| 686 |
+
float(ordered_crossing_mask.mean()) if ordered_crossing_mask.size else 0.0
|
| 687 |
+
)
|
| 688 |
+
if ordered_quantile_crossing_rate > 0.0:
|
| 689 |
+
raise AssertionError(
|
| 690 |
+
"Monotonic quantile transform produced public crossings: "
|
| 691 |
+
f"{ordered_quantile_crossing_rate:.6f}"
|
| 692 |
+
)
|
| 693 |
+
sorted_raw = np.sort(raw_pred, axis=-1)
|
| 694 |
+
raw_median_sort_gap = float(
|
| 695 |
+
np.max(np.abs(raw_pred[..., median_idx] - sorted_raw[..., median_idx]))
|
| 696 |
+
)
|
| 697 |
+
ordered_median_sort_gap = float(
|
| 698 |
+
np.max(np.abs(pred[..., median_idx] - np.sort(pred, axis=-1)[..., median_idx]))
|
| 699 |
+
)
|
| 700 |
if quantile_crossing_detected:
|
| 701 |
logger.error(
|
| 702 |
"format_prediction: non-monotonic quantiles detected "
|
| 703 |
+
"(raw_crossing_rate=%.3f, raw_max_median_sort_gap=%.4f); public output "
|
| 704 |
+
"uses the structural monotonic transform and exposes raw_quantiles for audit.",
|
| 705 |
+
raw_quantile_crossing_rate,
|
| 706 |
+
raw_median_sort_gap,
|
| 707 |
)
|
|
|
|
| 708 |
|
| 709 |
if _math.isnan(baseline_price) or _math.isinf(baseline_price) or baseline_price <= 0:
|
| 710 |
logger.warning(
|
|
|
|
| 800 |
"quantiles_log": {f"q{q:.2f}": float(bounded_pred[0, i]) for i, q in enumerate(q_list)},
|
| 801 |
"raw_quantiles": {f"q{q:.2f}": float(raw_pred[0, i]) for i, q in enumerate(q_list)},
|
| 802 |
"quantile_crossing_detected": quantile_crossing_detected,
|
| 803 |
+
"quantile_crossing_rate": ordered_quantile_crossing_rate,
|
| 804 |
+
"raw_quantile_crossing_rate": raw_quantile_crossing_rate,
|
| 805 |
+
"ordered_quantile_crossing_rate": ordered_quantile_crossing_rate,
|
| 806 |
+
"public_quantile_crossing_rate": ordered_quantile_crossing_rate,
|
| 807 |
+
"median_sort_gap": ordered_median_sort_gap,
|
| 808 |
+
"raw_median_sort_gap": raw_median_sort_gap,
|
| 809 |
+
"ordered_median_sort_gap": ordered_median_sort_gap,
|
| 810 |
"weekly_return": log_to_simple_return(weekly_log_return),
|
| 811 |
"weekly_log_return": weekly_log_return,
|
| 812 |
"weekly_price": _price(weekly_log_return),
|
deep_learning/training/callbacks.py
CHANGED
|
@@ -81,6 +81,47 @@ class CurriculumLossScheduler(pl.Callback):
|
|
| 81 |
)
|
| 82 |
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
class SWACallback(pl.Callback):
|
| 85 |
"""
|
| 86 |
Stochastic Weight Averaging over the last ``swa_pct`` of training.
|
|
|
|
| 81 |
)
|
| 82 |
|
| 83 |
|
| 84 |
+
class WeeklyLossComponentLogger(pl.Callback):
|
| 85 |
+
"""Log weekly loss component scales at validation epoch boundaries."""
|
| 86 |
+
|
| 87 |
+
def on_validation_epoch_start(self, trainer, pl_module):
|
| 88 |
+
loss = getattr(pl_module, "loss", None)
|
| 89 |
+
if hasattr(loss, "reset_component_accumulators"):
|
| 90 |
+
loss.reset_component_accumulators()
|
| 91 |
+
|
| 92 |
+
def on_validation_epoch_end(self, trainer, pl_module):
|
| 93 |
+
loss = getattr(pl_module, "loss", None)
|
| 94 |
+
if not hasattr(loss, "component_means"):
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
stats = loss.component_means()
|
| 98 |
+
if not stats.get("n_batches"):
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
epoch = getattr(trainer, "current_epoch", 0)
|
| 102 |
+
logger.info(
|
| 103 |
+
"Weekly loss components | epoch=%s weekly_q=%.6f t1_q=%.6f "
|
| 104 |
+
"dispersion=%.6f directional=%.6f total=%.6f dominant=%s",
|
| 105 |
+
epoch,
|
| 106 |
+
stats["weekly_q_loss_mean"],
|
| 107 |
+
stats["t1_q_loss_mean"],
|
| 108 |
+
stats["dispersion_loss_mean"],
|
| 109 |
+
stats["directional_loss_mean"],
|
| 110 |
+
stats["total_loss_mean"],
|
| 111 |
+
stats["dominant_component"],
|
| 112 |
+
)
|
| 113 |
+
if stats["dispersion_loss_mean"] > 3.0 * max(stats["weekly_q_loss_mean"], 1e-12):
|
| 114 |
+
logger.warning(
|
| 115 |
+
"Weekly dispersion loss is dominating weekly quantile loss; "
|
| 116 |
+
"lambda_dispersion may need to be reduced."
|
| 117 |
+
)
|
| 118 |
+
if stats["directional_loss_mean"] < 0.05 * max(stats["total_loss_mean"], 1e-12):
|
| 119 |
+
logger.warning(
|
| 120 |
+
"Weekly directional loss is below 5%% of total loss; "
|
| 121 |
+
"lambda_directional may need to increase."
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
class SWACallback(pl.Callback):
|
| 126 |
"""
|
| 127 |
Stochastic Weight Averaging over the last ``swa_pct`` of training.
|
deep_learning/training/hyperopt.py
CHANGED
|
@@ -13,6 +13,7 @@ from __future__ import annotations
|
|
| 13 |
import argparse
|
| 14 |
import json
|
| 15 |
import logging
|
|
|
|
| 16 |
import warnings
|
| 17 |
from dataclasses import replace
|
| 18 |
from pathlib import Path
|
|
@@ -37,6 +38,16 @@ from deep_learning.config import (
|
|
| 37 |
)
|
| 38 |
from deep_learning.logging_utils import configure_cli_logging, suppress_lightning_noise
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
logger = logging.getLogger(__name__)
|
| 41 |
|
| 42 |
MIN_COMPLETED_TRIALS = 10
|
|
@@ -55,15 +66,10 @@ KNOWN_GOOD_TRIAL_PARAMS = {
|
|
| 55 |
"lambda_vol": 0.30,
|
| 56 |
"lambda_quantile": 0.25,
|
| 57 |
"lambda_madl": 0.40,
|
| 58 |
-
"lambda_weekly_quantile": 0.
|
| 59 |
-
"lambda_t1_quantile": 0.
|
|
|
|
| 60 |
"lambda_directional": 0.10,
|
| 61 |
-
"lambda_magnitude": 0.55,
|
| 62 |
-
"weekly_lambda_vol": 0.35,
|
| 63 |
-
"lambda_width": 0.50,
|
| 64 |
-
"lambda_tail_width": 0.30,
|
| 65 |
-
"lambda_sanity": 0.20,
|
| 66 |
-
"lambda_crossing": 7.0,
|
| 67 |
"batch_size": 32,
|
| 68 |
}
|
| 69 |
|
|
@@ -144,7 +150,9 @@ def _build_prune_diagnostics(study) -> tuple[dict[str, int], list[dict]]:
|
|
| 144 |
"avg_variance_ratio",
|
| 145 |
"avg_directional_accuracy",
|
| 146 |
"avg_val_sharpe",
|
|
|
|
| 147 |
"avg_quantile_crossing_rate",
|
|
|
|
| 148 |
"avg_median_sort_gap",
|
| 149 |
"avg_weekly_magnitude_ratio",
|
| 150 |
"avg_weekly_pi80_coverage",
|
|
@@ -180,6 +188,8 @@ def _build_result_payload(study) -> dict:
|
|
| 180 |
trial_state_counts = _trial_state_counts(study)
|
| 181 |
best = _best_finite_completed_trial(study)
|
| 182 |
prune_reasons, fold_diagnostics = _build_prune_diagnostics(study)
|
|
|
|
|
|
|
| 183 |
|
| 184 |
if best is None:
|
| 185 |
return {
|
|
@@ -191,6 +201,9 @@ def _build_result_payload(study) -> dict:
|
|
| 191 |
"trial_state_counts": trial_state_counts,
|
| 192 |
"prune_reasons": prune_reasons,
|
| 193 |
"fold_diagnostics": fold_diagnostics,
|
|
|
|
|
|
|
|
|
|
| 194 |
"message": (
|
| 195 |
"No Optuna trials completed with a finite objective value; "
|
| 196 |
"final training will use the known-good fallback config "
|
|
@@ -198,6 +211,11 @@ def _build_result_payload(study) -> dict:
|
|
| 198 |
),
|
| 199 |
}
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
return {
|
| 202 |
"status": "completed",
|
| 203 |
"best_trial": best.number,
|
|
@@ -207,6 +225,9 @@ def _build_result_payload(study) -> dict:
|
|
| 207 |
"trial_state_counts": trial_state_counts,
|
| 208 |
"prune_reasons": prune_reasons,
|
| 209 |
"fold_diagnostics": fold_diagnostics,
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
|
| 212 |
|
|
@@ -266,15 +287,10 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
|
| 266 |
)
|
| 267 |
|
| 268 |
weekly_loss_cfg = WeeklyLossConfig(
|
| 269 |
-
lambda_weekly_quantile=trial.suggest_float("lambda_weekly_quantile", 0.
|
| 270 |
-
lambda_t1_quantile=trial.suggest_float("lambda_t1_quantile", 0.05, 0.
|
|
|
|
| 271 |
lambda_directional=trial.suggest_float("lambda_directional", 0.05, 0.12, step=0.01),
|
| 272 |
-
lambda_magnitude=trial.suggest_float("lambda_magnitude", 0.50, 0.80, step=0.05),
|
| 273 |
-
lambda_vol=trial.suggest_float("weekly_lambda_vol", 0.25, 0.45, step=0.05),
|
| 274 |
-
lambda_crossing=trial.suggest_float("lambda_crossing", 5.0, 10.0, step=1.0),
|
| 275 |
-
lambda_sanity=trial.suggest_float("lambda_sanity", 0.10, 0.30, step=0.05),
|
| 276 |
-
lambda_width=trial.suggest_float("lambda_width", 0.40, 0.90, step=0.05),
|
| 277 |
-
lambda_tail_width=trial.suggest_float("lambda_tail_width", 0.25, 0.75, step=0.05),
|
| 278 |
)
|
| 279 |
|
| 280 |
training_cfg = TrainingConfig(
|
|
@@ -343,6 +359,7 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 343 |
from deep_learning.training.callbacks import CurriculumLossScheduler
|
| 344 |
from deep_learning.training.metrics import (
|
| 345 |
compute_weekly_metrics,
|
|
|
|
| 346 |
quantile_crossing_rate,
|
| 347 |
quantile_median_sort_gap,
|
| 348 |
select_prediction_horizon,
|
|
@@ -367,7 +384,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 367 |
fold_sharpe_list: list[float] = []
|
| 368 |
fold_vr_list: list[float] = []
|
| 369 |
fold_crossing_list: list[float] = []
|
|
|
|
| 370 |
fold_median_gap_list: list[float] = []
|
|
|
|
| 371 |
fold_weekly_objectives: list[float] = []
|
| 372 |
fold_weekly_mr_list: list[float] = []
|
| 373 |
fold_weekly_pi80_coverage_list: list[float] = []
|
|
@@ -469,10 +488,14 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 469 |
raise ValueError(
|
| 470 |
f"Prediction horizon too short: {pred_np.shape[1]} < {trial_cfg.forecast.primary_horizon_days}"
|
| 471 |
)
|
| 472 |
-
|
|
|
|
|
|
|
| 473 |
y_pred = pred_t1[:, median_idx]
|
| 474 |
fold_crossing_rate = quantile_crossing_rate(pred_t1)
|
|
|
|
| 475 |
_, fold_median_gap = quantile_median_sort_gap(pred_t1, median_idx)
|
|
|
|
| 476 |
|
| 477 |
y_actual_parts = []
|
| 478 |
for batch in fold_val_dl:
|
|
@@ -510,7 +533,7 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 510 |
)
|
| 511 |
weekly_pinball = _weekly_pinball_loss(
|
| 512 |
y_actual_path[:n_path],
|
| 513 |
-
|
| 514 |
tuple(trial_cfg.model.quantiles),
|
| 515 |
horizon=trial_cfg.forecast.primary_horizon_days,
|
| 516 |
)
|
|
@@ -518,9 +541,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 518 |
fold_weekly_pi80_coverage = float(weekly.get("weekly_pi80_coverage", 0.0))
|
| 519 |
fold_weekly_pi80_width_ratio = float(weekly.get("weekly_pi80_width_ratio", 1.0))
|
| 520 |
fold_weekly_pi96_width_ratio = float(weekly.get("weekly_pi96_width_ratio", 1.0))
|
| 521 |
-
fold_weekly_raw_crossing = float(weekly.get("
|
| 522 |
fold_weekly_sorted_crossing = float(
|
| 523 |
-
weekly.get("
|
| 524 |
)
|
| 525 |
fold_weekly_interval_score_80 = float(weekly.get("weekly_interval_score_80", 0.0))
|
| 526 |
fold_weekly_interval_score_96 = float(weekly.get("weekly_interval_score_96", 0.0))
|
|
@@ -530,7 +553,6 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 530 |
coverage_penalty = abs(fold_weekly_pi80_coverage - 0.80)
|
| 531 |
width_penalty = max(0.0, fold_weekly_pi80_width_ratio - 1.5)
|
| 532 |
tail_width_penalty = max(0.0, fold_weekly_pi96_width_ratio - 3.0)
|
| 533 |
-
raw_crossing_penalty = max(0.0, fold_weekly_raw_crossing - 0.05)
|
| 534 |
fold_weekly_objective = (
|
| 535 |
0.35 * weekly_pinball
|
| 536 |
+ 0.15 * (1.0 - float(weekly.get("weekly_directional_accuracy", 0.5)))
|
|
@@ -540,7 +562,6 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 540 |
+ 0.35 * tail_width_penalty
|
| 541 |
+ 0.10 * interval_score_penalty
|
| 542 |
+ 0.05 * interval_score_96_penalty
|
| 543 |
-
+ 0.50 * raw_crossing_penalty
|
| 544 |
+ 0.25 * fold_weekly_sorted_crossing
|
| 545 |
)
|
| 546 |
except Exception as exc:
|
|
@@ -553,7 +574,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 553 |
fold_da_list.append(fold_da)
|
| 554 |
fold_sharpe_list.append(fold_sharpe)
|
| 555 |
fold_crossing_list.append(fold_crossing_rate)
|
|
|
|
| 556 |
fold_median_gap_list.append(fold_median_gap)
|
|
|
|
| 557 |
fold_weekly_objectives.append(fold_weekly_objective)
|
| 558 |
fold_weekly_mr_list.append(fold_weekly_mr)
|
| 559 |
fold_weekly_pi80_coverage_list.append(fold_weekly_pi80_coverage)
|
|
@@ -671,7 +694,13 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 671 |
avg_sharpe = float(np.mean(fold_sharpe_list)) if fold_sharpe_list else 0.0
|
| 672 |
avg_vr = float(np.mean(fold_vr_list)) if fold_vr_list else 0.0
|
| 673 |
avg_crossing = float(np.mean(fold_crossing_list)) if fold_crossing_list else 0.0
|
|
|
|
|
|
|
|
|
|
| 674 |
avg_median_gap = float(np.mean(fold_median_gap_list)) if fold_median_gap_list else 0.0
|
|
|
|
|
|
|
|
|
|
| 675 |
avg_weekly_mr = float(np.mean(fold_weekly_mr_list)) if fold_weekly_mr_list else 1.0
|
| 676 |
avg_weekly_pi80_coverage = (
|
| 677 |
float(np.mean(fold_weekly_pi80_coverage_list))
|
|
@@ -717,7 +746,9 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 717 |
trial.set_user_attr("avg_variance_ratio", round(avg_vr, 4))
|
| 718 |
trial.set_user_attr("avg_directional_accuracy", round(avg_da, 4))
|
| 719 |
trial.set_user_attr("avg_val_sharpe", round(avg_sharpe, 4))
|
|
|
|
| 720 |
trial.set_user_attr("avg_quantile_crossing_rate", round(avg_crossing, 4))
|
|
|
|
| 721 |
trial.set_user_attr("avg_median_sort_gap", round(avg_median_gap, 4))
|
| 722 |
trial.set_user_attr("avg_weekly_magnitude_ratio", round(avg_weekly_mr, 4))
|
| 723 |
trial.set_user_attr("avg_weekly_pi80_coverage", round(avg_weekly_pi80_coverage, 4))
|
|
@@ -741,21 +772,11 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 741 |
trial.set_user_attr("prune_reason", "sharpe_prune")
|
| 742 |
raise optuna.exceptions.TrialPruned()
|
| 743 |
|
| 744 |
-
if
|
| 745 |
-
|
| 746 |
-
"
|
| 747 |
-
|
| 748 |
-
)
|
| 749 |
-
trial.set_user_attr("prune_reason", "crossing_prune")
|
| 750 |
-
raise optuna.exceptions.TrialPruned()
|
| 751 |
-
|
| 752 |
-
if (avg_weekly_raw_crossing > 0.05 or avg_weekly_sorted_crossing > 0.0) and not protect_trial:
|
| 753 |
-
logger.warning(
|
| 754 |
-
"Trial %d PRUNED: weekly quantile incoherence raw=%.3f sorted=%.3f",
|
| 755 |
-
trial.number, avg_weekly_raw_crossing, avg_weekly_sorted_crossing,
|
| 756 |
)
|
| 757 |
-
trial.set_user_attr("prune_reason", "weekly_raw_crossing_prune")
|
| 758 |
-
raise optuna.exceptions.TrialPruned()
|
| 759 |
|
| 760 |
# Soft penalty: avg DA below coin-flip
|
| 761 |
da_penalty = 2.0 * max(0.0, 0.50 - avg_da) if avg_da < 0.50 else 0.0
|
|
@@ -825,6 +846,15 @@ def run_hyperopt(
|
|
| 825 |
results_path.parent.mkdir(parents=True, exist_ok=True)
|
| 826 |
result = _build_result_payload(study)
|
| 827 |
results_path.write_text(json.dumps(result, indent=2, allow_nan=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 828 |
|
| 829 |
if result["best_trial"] is None:
|
| 830 |
logger.warning(
|
|
@@ -841,6 +871,10 @@ def run_hyperopt(
|
|
| 841 |
)
|
| 842 |
logger.info("Best params: %s", result["best_params"])
|
| 843 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
return result
|
| 845 |
|
| 846 |
|
|
|
|
| 13 |
import argparse
|
| 14 |
import json
|
| 15 |
import logging
|
| 16 |
+
import sys
|
| 17 |
import warnings
|
| 18 |
from dataclasses import replace
|
| 19 |
from pathlib import Path
|
|
|
|
| 38 |
)
|
| 39 |
from deep_learning.logging_utils import configure_cli_logging, suppress_lightning_noise
|
| 40 |
|
| 41 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
| 42 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 43 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 44 |
+
|
| 45 |
+
from scripts.hyperopt_diagnostics import (
|
| 46 |
+
best_trial_preflight_check,
|
| 47 |
+
compute_structural_invalidity_report,
|
| 48 |
+
compute_trial_distribution_summary,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
logger = logging.getLogger(__name__)
|
| 52 |
|
| 53 |
MIN_COMPLETED_TRIALS = 10
|
|
|
|
| 66 |
"lambda_vol": 0.30,
|
| 67 |
"lambda_quantile": 0.25,
|
| 68 |
"lambda_madl": 0.40,
|
| 69 |
+
"lambda_weekly_quantile": 0.55,
|
| 70 |
+
"lambda_t1_quantile": 0.15,
|
| 71 |
+
"lambda_dispersion": 0.20,
|
| 72 |
"lambda_directional": 0.10,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
"batch_size": 32,
|
| 74 |
}
|
| 75 |
|
|
|
|
| 150 |
"avg_variance_ratio",
|
| 151 |
"avg_directional_accuracy",
|
| 152 |
"avg_val_sharpe",
|
| 153 |
+
"avg_raw_quantile_crossing_rate",
|
| 154 |
"avg_quantile_crossing_rate",
|
| 155 |
+
"avg_raw_median_sort_gap",
|
| 156 |
"avg_median_sort_gap",
|
| 157 |
"avg_weekly_magnitude_ratio",
|
| 158 |
"avg_weekly_pi80_coverage",
|
|
|
|
| 188 |
trial_state_counts = _trial_state_counts(study)
|
| 189 |
best = _best_finite_completed_trial(study)
|
| 190 |
prune_reasons, fold_diagnostics = _build_prune_diagnostics(study)
|
| 191 |
+
structural_report = compute_structural_invalidity_report(fold_diagnostics)
|
| 192 |
+
distribution_summary = compute_trial_distribution_summary(fold_diagnostics)
|
| 193 |
|
| 194 |
if best is None:
|
| 195 |
return {
|
|
|
|
| 201 |
"trial_state_counts": trial_state_counts,
|
| 202 |
"prune_reasons": prune_reasons,
|
| 203 |
"fold_diagnostics": fold_diagnostics,
|
| 204 |
+
"structural_invalidity_report": structural_report,
|
| 205 |
+
"trial_distribution_summary": distribution_summary,
|
| 206 |
+
"best_trial_preflight": None,
|
| 207 |
"message": (
|
| 208 |
"No Optuna trials completed with a finite objective value; "
|
| 209 |
"final training will use the known-good fallback config "
|
|
|
|
| 211 |
),
|
| 212 |
}
|
| 213 |
|
| 214 |
+
best_diagnostics = next(
|
| 215 |
+
(d for d in fold_diagnostics if d.get("trial") == best.number),
|
| 216 |
+
{},
|
| 217 |
+
)
|
| 218 |
+
preflight = best_trial_preflight_check(best_diagnostics)
|
| 219 |
return {
|
| 220 |
"status": "completed",
|
| 221 |
"best_trial": best.number,
|
|
|
|
| 225 |
"trial_state_counts": trial_state_counts,
|
| 226 |
"prune_reasons": prune_reasons,
|
| 227 |
"fold_diagnostics": fold_diagnostics,
|
| 228 |
+
"structural_invalidity_report": structural_report,
|
| 229 |
+
"trial_distribution_summary": distribution_summary,
|
| 230 |
+
"best_trial_preflight": preflight,
|
| 231 |
}
|
| 232 |
|
| 233 |
|
|
|
|
| 287 |
)
|
| 288 |
|
| 289 |
weekly_loss_cfg = WeeklyLossConfig(
|
| 290 |
+
lambda_weekly_quantile=trial.suggest_float("lambda_weekly_quantile", 0.45, 0.65, step=0.05),
|
| 291 |
+
lambda_t1_quantile=trial.suggest_float("lambda_t1_quantile", 0.05, 0.20, step=0.05),
|
| 292 |
+
lambda_dispersion=trial.suggest_float("lambda_dispersion", 0.15, 0.35, step=0.05),
|
| 293 |
lambda_directional=trial.suggest_float("lambda_directional", 0.05, 0.12, step=0.01),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
|
| 296 |
training_cfg = TrainingConfig(
|
|
|
|
| 359 |
from deep_learning.training.callbacks import CurriculumLossScheduler
|
| 360 |
from deep_learning.training.metrics import (
|
| 361 |
compute_weekly_metrics,
|
| 362 |
+
monotonic_quantiles_np,
|
| 363 |
quantile_crossing_rate,
|
| 364 |
quantile_median_sort_gap,
|
| 365 |
select_prediction_horizon,
|
|
|
|
| 384 |
fold_sharpe_list: list[float] = []
|
| 385 |
fold_vr_list: list[float] = []
|
| 386 |
fold_crossing_list: list[float] = []
|
| 387 |
+
fold_raw_crossing_list: list[float] = []
|
| 388 |
fold_median_gap_list: list[float] = []
|
| 389 |
+
fold_raw_median_gap_list: list[float] = []
|
| 390 |
fold_weekly_objectives: list[float] = []
|
| 391 |
fold_weekly_mr_list: list[float] = []
|
| 392 |
fold_weekly_pi80_coverage_list: list[float] = []
|
|
|
|
| 488 |
raise ValueError(
|
| 489 |
f"Prediction horizon too short: {pred_np.shape[1]} < {trial_cfg.forecast.primary_horizon_days}"
|
| 490 |
)
|
| 491 |
+
ordered_pred_np = monotonic_quantiles_np(pred_np, median_idx=median_idx)
|
| 492 |
+
raw_pred_t1 = pred_np[:, 0, :]
|
| 493 |
+
pred_t1 = ordered_pred_np[:, 0, :]
|
| 494 |
y_pred = pred_t1[:, median_idx]
|
| 495 |
fold_crossing_rate = quantile_crossing_rate(pred_t1)
|
| 496 |
+
fold_raw_crossing_rate = quantile_crossing_rate(raw_pred_t1)
|
| 497 |
_, fold_median_gap = quantile_median_sort_gap(pred_t1, median_idx)
|
| 498 |
+
_, fold_raw_median_gap = quantile_median_sort_gap(raw_pred_t1, median_idx)
|
| 499 |
|
| 500 |
y_actual_parts = []
|
| 501 |
for batch in fold_val_dl:
|
|
|
|
| 533 |
)
|
| 534 |
weekly_pinball = _weekly_pinball_loss(
|
| 535 |
y_actual_path[:n_path],
|
| 536 |
+
ordered_pred_np[:n_path],
|
| 537 |
tuple(trial_cfg.model.quantiles),
|
| 538 |
horizon=trial_cfg.forecast.primary_horizon_days,
|
| 539 |
)
|
|
|
|
| 541 |
fold_weekly_pi80_coverage = float(weekly.get("weekly_pi80_coverage", 0.0))
|
| 542 |
fold_weekly_pi80_width_ratio = float(weekly.get("weekly_pi80_width_ratio", 1.0))
|
| 543 |
fold_weekly_pi96_width_ratio = float(weekly.get("weekly_pi96_width_ratio", 1.0))
|
| 544 |
+
fold_weekly_raw_crossing = float(weekly.get("weekly_raw_quantile_crossing_rate", 0.0))
|
| 545 |
fold_weekly_sorted_crossing = float(
|
| 546 |
+
weekly.get("weekly_ordered_quantile_crossing_rate", 0.0)
|
| 547 |
)
|
| 548 |
fold_weekly_interval_score_80 = float(weekly.get("weekly_interval_score_80", 0.0))
|
| 549 |
fold_weekly_interval_score_96 = float(weekly.get("weekly_interval_score_96", 0.0))
|
|
|
|
| 553 |
coverage_penalty = abs(fold_weekly_pi80_coverage - 0.80)
|
| 554 |
width_penalty = max(0.0, fold_weekly_pi80_width_ratio - 1.5)
|
| 555 |
tail_width_penalty = max(0.0, fold_weekly_pi96_width_ratio - 3.0)
|
|
|
|
| 556 |
fold_weekly_objective = (
|
| 557 |
0.35 * weekly_pinball
|
| 558 |
+ 0.15 * (1.0 - float(weekly.get("weekly_directional_accuracy", 0.5)))
|
|
|
|
| 562 |
+ 0.35 * tail_width_penalty
|
| 563 |
+ 0.10 * interval_score_penalty
|
| 564 |
+ 0.05 * interval_score_96_penalty
|
|
|
|
| 565 |
+ 0.25 * fold_weekly_sorted_crossing
|
| 566 |
)
|
| 567 |
except Exception as exc:
|
|
|
|
| 574 |
fold_da_list.append(fold_da)
|
| 575 |
fold_sharpe_list.append(fold_sharpe)
|
| 576 |
fold_crossing_list.append(fold_crossing_rate)
|
| 577 |
+
fold_raw_crossing_list.append(fold_raw_crossing_rate)
|
| 578 |
fold_median_gap_list.append(fold_median_gap)
|
| 579 |
+
fold_raw_median_gap_list.append(fold_raw_median_gap)
|
| 580 |
fold_weekly_objectives.append(fold_weekly_objective)
|
| 581 |
fold_weekly_mr_list.append(fold_weekly_mr)
|
| 582 |
fold_weekly_pi80_coverage_list.append(fold_weekly_pi80_coverage)
|
|
|
|
| 694 |
avg_sharpe = float(np.mean(fold_sharpe_list)) if fold_sharpe_list else 0.0
|
| 695 |
avg_vr = float(np.mean(fold_vr_list)) if fold_vr_list else 0.0
|
| 696 |
avg_crossing = float(np.mean(fold_crossing_list)) if fold_crossing_list else 0.0
|
| 697 |
+
avg_raw_crossing = (
|
| 698 |
+
float(np.mean(fold_raw_crossing_list)) if fold_raw_crossing_list else 0.0
|
| 699 |
+
)
|
| 700 |
avg_median_gap = float(np.mean(fold_median_gap_list)) if fold_median_gap_list else 0.0
|
| 701 |
+
avg_raw_median_gap = (
|
| 702 |
+
float(np.mean(fold_raw_median_gap_list)) if fold_raw_median_gap_list else 0.0
|
| 703 |
+
)
|
| 704 |
avg_weekly_mr = float(np.mean(fold_weekly_mr_list)) if fold_weekly_mr_list else 1.0
|
| 705 |
avg_weekly_pi80_coverage = (
|
| 706 |
float(np.mean(fold_weekly_pi80_coverage_list))
|
|
|
|
| 746 |
trial.set_user_attr("avg_variance_ratio", round(avg_vr, 4))
|
| 747 |
trial.set_user_attr("avg_directional_accuracy", round(avg_da, 4))
|
| 748 |
trial.set_user_attr("avg_val_sharpe", round(avg_sharpe, 4))
|
| 749 |
+
trial.set_user_attr("avg_raw_quantile_crossing_rate", round(avg_raw_crossing, 4))
|
| 750 |
trial.set_user_attr("avg_quantile_crossing_rate", round(avg_crossing, 4))
|
| 751 |
+
trial.set_user_attr("avg_raw_median_sort_gap", round(avg_raw_median_gap, 4))
|
| 752 |
trial.set_user_attr("avg_median_sort_gap", round(avg_median_gap, 4))
|
| 753 |
trial.set_user_attr("avg_weekly_magnitude_ratio", round(avg_weekly_mr, 4))
|
| 754 |
trial.set_user_attr("avg_weekly_pi80_coverage", round(avg_weekly_pi80_coverage, 4))
|
|
|
|
| 772 |
trial.set_user_attr("prune_reason", "sharpe_prune")
|
| 773 |
raise optuna.exceptions.TrialPruned()
|
| 774 |
|
| 775 |
+
if avg_crossing > 0.001 or avg_weekly_sorted_crossing > 0.001:
|
| 776 |
+
raise RuntimeError(
|
| 777 |
+
"Monotonic quantile transform produced public crossings: "
|
| 778 |
+
f"daily={avg_crossing:.6f}, weekly={avg_weekly_sorted_crossing:.6f}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
)
|
|
|
|
|
|
|
| 780 |
|
| 781 |
# Soft penalty: avg DA below coin-flip
|
| 782 |
da_penalty = 2.0 * max(0.0, 0.50 - avg_da) if avg_da < 0.50 else 0.0
|
|
|
|
| 846 |
results_path.parent.mkdir(parents=True, exist_ok=True)
|
| 847 |
result = _build_result_payload(study)
|
| 848 |
results_path.write_text(json.dumps(result, indent=2, allow_nan=False))
|
| 849 |
+
logger.info(
|
| 850 |
+
"Optuna structural invalidity report: %s",
|
| 851 |
+
result.get("structural_invalidity_report"),
|
| 852 |
+
)
|
| 853 |
+
logger.info(
|
| 854 |
+
"Optuna trial distribution summary: %s",
|
| 855 |
+
result.get("trial_distribution_summary"),
|
| 856 |
+
)
|
| 857 |
+
logger.info("Optuna best trial preflight: %s", result.get("best_trial_preflight"))
|
| 858 |
|
| 859 |
if result["best_trial"] is None:
|
| 860 |
logger.warning(
|
|
|
|
| 871 |
)
|
| 872 |
logger.info("Best params: %s", result["best_params"])
|
| 873 |
|
| 874 |
+
structural_report = result.get("structural_invalidity_report") or {}
|
| 875 |
+
if structural_report.get("verdict") == "STRUCTURAL_FAILURE":
|
| 876 |
+
raise RuntimeError(structural_report.get("next_action", "Structural failure in hyperopt."))
|
| 877 |
+
|
| 878 |
return result
|
| 879 |
|
| 880 |
|
deep_learning/training/metrics.py
CHANGED
|
@@ -13,6 +13,9 @@ from __future__ import annotations
|
|
| 13 |
|
| 14 |
import numpy as np
|
| 15 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def select_prediction_horizon(values: np.ndarray, horizon_idx: int = 0) -> np.ndarray:
|
|
@@ -53,6 +56,27 @@ def cumulative_quantiles(pred: np.ndarray, horizon: int = 5) -> np.ndarray:
|
|
| 53 |
return arr[:, :horizon, :].sum(axis=1)
|
| 54 |
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
def magnitude_ratio(y_actual: np.ndarray, y_pred: np.ndarray) -> float:
|
| 57 |
"""Median predicted absolute move divided by median actual absolute move."""
|
| 58 |
denom = np.median(np.abs(np.asarray(y_actual, dtype=np.float64)))
|
|
@@ -288,18 +312,24 @@ def compute_all_metrics(
|
|
| 288 |
|
| 289 |
if y_pred_quantiles is not None:
|
| 290 |
q_arr = np.asarray(y_pred_quantiles, dtype=np.float64)
|
| 291 |
-
|
| 292 |
raw_crossing = quantile_crossing_rate(q_arr)
|
| 293 |
-
|
| 294 |
-
metrics["quantile_crossing_rate"] =
|
| 295 |
metrics["raw_quantile_crossing_rate"] = raw_crossing
|
| 296 |
-
metrics["
|
|
|
|
|
|
|
| 297 |
gap_mean, gap_max = quantile_median_sort_gap(q_arr)
|
| 298 |
-
metrics["
|
| 299 |
-
metrics["
|
| 300 |
-
|
| 301 |
-
metrics["
|
| 302 |
-
metrics["
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
return metrics
|
| 305 |
|
|
@@ -317,8 +347,10 @@ def compute_weekly_metrics(
|
|
| 317 |
to simple returns happens only during inference formatting.
|
| 318 |
"""
|
| 319 |
weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
|
| 320 |
-
|
| 321 |
-
|
|
|
|
|
|
|
| 322 |
|
| 323 |
median_idx = len(quantiles) // 2
|
| 324 |
q10_idx = quantiles.index(0.10)
|
|
@@ -340,18 +372,27 @@ def compute_weekly_metrics(
|
|
| 340 |
y_pred_q90=weekly_quantiles[:, q90_idx],
|
| 341 |
y_pred_q02=weekly_quantiles[:, q02_idx],
|
| 342 |
y_pred_q98=weekly_quantiles[:, q98_idx],
|
| 343 |
-
y_pred_quantiles=
|
| 344 |
tail_threshold=tail_threshold,
|
| 345 |
)
|
| 346 |
|
| 347 |
weekly_metrics = {f"weekly_{k}": v for k, v in metrics.items()}
|
| 348 |
weekly_metrics["weekly_interval_quantile_source"] = 1.0
|
| 349 |
weekly_metrics["weekly_approx_quantile_crossing_rate"] = quantile_crossing_rate(
|
| 350 |
-
|
| 351 |
)
|
| 352 |
-
approx_gap_mean, approx_gap_max = quantile_median_sort_gap(
|
| 353 |
weekly_metrics["weekly_approx_median_sort_gap_mean"] = approx_gap_mean
|
| 354 |
weekly_metrics["weekly_approx_median_sort_gap_max"] = approx_gap_max
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
weekly_metrics["weekly_magnitude_ratio"] = magnitude_ratio(weekly_actual, weekly_pred)
|
| 356 |
weekly_metrics["weekly_mean_actual_abs"] = float(np.mean(np.abs(weekly_actual)))
|
| 357 |
weekly_metrics["weekly_mean_pred_abs"] = float(np.mean(np.abs(weekly_pred)))
|
|
|
|
| 13 |
|
| 14 |
import numpy as np
|
| 15 |
import pandas as pd
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from deep_learning.models.monotonic_quantiles import enforce_monotonic_quantiles
|
| 19 |
|
| 20 |
|
| 21 |
def select_prediction_horizon(values: np.ndarray, horizon_idx: int = 0) -> np.ndarray:
|
|
|
|
| 56 |
return arr[:, :horizon, :].sum(axis=1)
|
| 57 |
|
| 58 |
|
| 59 |
+
def monotonic_quantiles_np(
|
| 60 |
+
pred: np.ndarray,
|
| 61 |
+
median_idx: int | None = None,
|
| 62 |
+
) -> np.ndarray:
|
| 63 |
+
"""Apply the production monotonic quantile transform to a numpy tensor."""
|
| 64 |
+
arr = np.asarray(pred, dtype=np.float64)
|
| 65 |
+
if arr.shape[-1] == 0:
|
| 66 |
+
return arr.copy()
|
| 67 |
+
if median_idx is None:
|
| 68 |
+
median_idx = arr.shape[-1] // 2
|
| 69 |
+
tensor = torch.as_tensor(arr, dtype=torch.float64)
|
| 70 |
+
ordered = enforce_monotonic_quantiles(
|
| 71 |
+
tensor,
|
| 72 |
+
median_idx=median_idx,
|
| 73 |
+
min_gap=1e-5,
|
| 74 |
+
gap_scale=0.01,
|
| 75 |
+
init_bias=-3.0,
|
| 76 |
+
)
|
| 77 |
+
return ordered.detach().cpu().numpy()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def magnitude_ratio(y_actual: np.ndarray, y_pred: np.ndarray) -> float:
|
| 81 |
"""Median predicted absolute move divided by median actual absolute move."""
|
| 82 |
denom = np.median(np.abs(np.asarray(y_actual, dtype=np.float64)))
|
|
|
|
| 312 |
|
| 313 |
if y_pred_quantiles is not None:
|
| 314 |
q_arr = np.asarray(y_pred_quantiles, dtype=np.float64)
|
| 315 |
+
ordered_q = monotonic_quantiles_np(q_arr)
|
| 316 |
raw_crossing = quantile_crossing_rate(q_arr)
|
| 317 |
+
ordered_crossing = quantile_crossing_rate(ordered_q)
|
| 318 |
+
metrics["quantile_crossing_rate"] = ordered_crossing
|
| 319 |
metrics["raw_quantile_crossing_rate"] = raw_crossing
|
| 320 |
+
metrics["ordered_quantile_crossing_rate"] = ordered_crossing
|
| 321 |
+
metrics["public_quantile_crossing_rate"] = ordered_crossing
|
| 322 |
+
metrics["sorted_quantile_crossing_rate"] = ordered_crossing
|
| 323 |
gap_mean, gap_max = quantile_median_sort_gap(q_arr)
|
| 324 |
+
metrics["raw_median_sort_gap_mean"] = gap_mean
|
| 325 |
+
metrics["raw_median_sort_gap_max"] = gap_max
|
| 326 |
+
ordered_gap_mean, ordered_gap_max = quantile_median_sort_gap(ordered_q)
|
| 327 |
+
metrics["median_sort_gap_mean"] = ordered_gap_mean
|
| 328 |
+
metrics["median_sort_gap_max"] = ordered_gap_max
|
| 329 |
+
metrics["ordered_median_sort_gap_mean"] = ordered_gap_mean
|
| 330 |
+
metrics["ordered_median_sort_gap_max"] = ordered_gap_max
|
| 331 |
+
metrics["sorted_median_sort_gap_mean"] = ordered_gap_mean
|
| 332 |
+
metrics["sorted_median_sort_gap_max"] = ordered_gap_max
|
| 333 |
|
| 334 |
return metrics
|
| 335 |
|
|
|
|
| 347 |
to simple returns happens only during inference formatting.
|
| 348 |
"""
|
| 349 |
weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
|
| 350 |
+
raw_path = np.asarray(y_pred_quantiles_path, dtype=np.float64)
|
| 351 |
+
ordered_path = monotonic_quantiles_np(raw_path, median_idx=len(quantiles) // 2)
|
| 352 |
+
raw_weekly_quantiles = cumulative_quantiles(raw_path, horizon=horizon)
|
| 353 |
+
weekly_quantiles = cumulative_quantiles(ordered_path, horizon=horizon)
|
| 354 |
|
| 355 |
median_idx = len(quantiles) // 2
|
| 356 |
q10_idx = quantiles.index(0.10)
|
|
|
|
| 372 |
y_pred_q90=weekly_quantiles[:, q90_idx],
|
| 373 |
y_pred_q02=weekly_quantiles[:, q02_idx],
|
| 374 |
y_pred_q98=weekly_quantiles[:, q98_idx],
|
| 375 |
+
y_pred_quantiles=weekly_quantiles,
|
| 376 |
tail_threshold=tail_threshold,
|
| 377 |
)
|
| 378 |
|
| 379 |
weekly_metrics = {f"weekly_{k}": v for k, v in metrics.items()}
|
| 380 |
weekly_metrics["weekly_interval_quantile_source"] = 1.0
|
| 381 |
weekly_metrics["weekly_approx_quantile_crossing_rate"] = quantile_crossing_rate(
|
| 382 |
+
raw_weekly_quantiles
|
| 383 |
)
|
| 384 |
+
approx_gap_mean, approx_gap_max = quantile_median_sort_gap(raw_weekly_quantiles)
|
| 385 |
weekly_metrics["weekly_approx_median_sort_gap_mean"] = approx_gap_mean
|
| 386 |
weekly_metrics["weekly_approx_median_sort_gap_max"] = approx_gap_max
|
| 387 |
+
weekly_metrics["weekly_raw_quantile_crossing_rate"] = quantile_crossing_rate(
|
| 388 |
+
raw_weekly_quantiles
|
| 389 |
+
)
|
| 390 |
+
weekly_metrics["weekly_ordered_quantile_crossing_rate"] = quantile_crossing_rate(
|
| 391 |
+
weekly_quantiles
|
| 392 |
+
)
|
| 393 |
+
weekly_metrics["weekly_public_quantile_crossing_rate"] = weekly_metrics[
|
| 394 |
+
"weekly_ordered_quantile_crossing_rate"
|
| 395 |
+
]
|
| 396 |
weekly_metrics["weekly_magnitude_ratio"] = magnitude_ratio(weekly_actual, weekly_pred)
|
| 397 |
weekly_metrics["weekly_mean_actual_abs"] = float(np.mean(np.abs(weekly_actual)))
|
| 398 |
weekly_metrics["weekly_mean_pred_abs"] = float(np.mean(np.abs(weekly_pred)))
|
deep_learning/training/trainer.py
CHANGED
|
@@ -47,7 +47,7 @@ warnings.filterwarnings(
|
|
| 47 |
logger = logging.getLogger(__name__)
|
| 48 |
|
| 49 |
KNOWN_GOOD_CONFIG = {
|
| 50 |
-
"max_encoder_length":
|
| 51 |
"hidden_size": 48,
|
| 52 |
"attention_head_size": 2,
|
| 53 |
"dropout": 0.30,
|
|
@@ -57,18 +57,15 @@ KNOWN_GOOD_CONFIG = {
|
|
| 57 |
"lambda_vol": 0.30,
|
| 58 |
"lambda_quantile": 0.25,
|
| 59 |
"lambda_madl": 0.40,
|
| 60 |
-
"lambda_weekly_quantile": 0.
|
| 61 |
-
"lambda_t1_quantile": 0.
|
|
|
|
| 62 |
"lambda_directional": 0.10,
|
| 63 |
-
"lambda_magnitude": 0.55,
|
| 64 |
-
"weekly_lambda_vol": 0.35,
|
| 65 |
-
"lambda_width": 0.50,
|
| 66 |
-
"lambda_tail_width": 0.30,
|
| 67 |
-
"lambda_sanity": 0.20,
|
| 68 |
-
"lambda_crossing": 7.0,
|
| 69 |
"batch_size": 32,
|
| 70 |
}
|
| 71 |
|
|
|
|
|
|
|
| 72 |
REQUIRED_PROMOTABLE_METRICS = (
|
| 73 |
"weekly_directional_accuracy",
|
| 74 |
"weekly_magnitude_ratio",
|
|
@@ -126,13 +123,22 @@ def _compute_test_metrics_from_quantiles(
|
|
| 126 |
pred_np: np.ndarray,
|
| 127 |
cfg: TFTASROConfig,
|
| 128 |
) -> dict[str, float]:
|
| 129 |
-
from deep_learning.training.metrics import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
pred_np = np.asarray(pred_np)
|
| 132 |
_validate_quantile_prediction_shape(pred_np, cfg)
|
| 133 |
|
| 134 |
median_idx = len(cfg.model.quantiles) // 2
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
y_pred_median = pred_t1[:, median_idx]
|
| 137 |
y_pred_q10 = pred_t1[:, 1]
|
| 138 |
y_pred_q90 = pred_t1[:, -2]
|
|
@@ -150,6 +156,10 @@ def _compute_test_metrics_from_quantiles(
|
|
| 150 |
y_pred_q98=y_pred_q98[:n],
|
| 151 |
y_pred_quantiles=pred_t1[:n],
|
| 152 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
n_path = min(len(y_actual_path), len(pred_np))
|
| 155 |
weekly_metrics = compute_weekly_metrics(
|
|
@@ -167,6 +177,7 @@ def train_tft_model(
|
|
| 167 |
cfg: Optional[TFTASROConfig] = None,
|
| 168 |
use_asro: bool = True,
|
| 169 |
upload_to_hub: bool = False,
|
|
|
|
| 170 |
) -> dict:
|
| 171 |
"""
|
| 172 |
End-to-end TFT-ASRO training.
|
|
@@ -189,16 +200,23 @@ def train_tft_model(
|
|
| 189 |
from deep_learning.data.feature_store import build_tft_dataframe
|
| 190 |
from deep_learning.data.dataset import build_datasets, create_dataloaders
|
| 191 |
from deep_learning.models.tft_copper import create_tft_model, get_variable_importance, format_prediction
|
| 192 |
-
from deep_learning.training.callbacks import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
if cfg is None:
|
| 195 |
cfg = get_tft_config()
|
| 196 |
|
| 197 |
-
# ---- 0a. Load
|
| 198 |
-
#
|
| 199 |
-
#
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
# ---- 0b. ASRO loss sanity check (runs before any training) ----
|
| 204 |
try:
|
|
@@ -260,19 +278,11 @@ def train_tft_model(
|
|
| 260 |
cfg.training.early_stopping_patience,
|
| 261 |
)
|
| 262 |
logger.info(
|
| 263 |
-
"Weekly loss | weekly_q=%.2f t1_q=%.2f
|
| 264 |
cfg.weekly_loss.lambda_weekly_quantile,
|
| 265 |
cfg.weekly_loss.lambda_t1_quantile,
|
|
|
|
| 266 |
cfg.weekly_loss.lambda_directional,
|
| 267 |
-
cfg.weekly_loss.lambda_magnitude,
|
| 268 |
-
cfg.weekly_loss.lambda_vol,
|
| 269 |
-
)
|
| 270 |
-
logger.info(
|
| 271 |
-
"Weekly guards | width=%.2f tail_width=%.2f crossing=%.2f sanity=%.2f",
|
| 272 |
-
cfg.weekly_loss.lambda_width,
|
| 273 |
-
cfg.weekly_loss.lambda_tail_width,
|
| 274 |
-
cfg.weekly_loss.lambda_crossing,
|
| 275 |
-
cfg.weekly_loss.lambda_sanity,
|
| 276 |
)
|
| 277 |
else:
|
| 278 |
logger.info(
|
|
@@ -308,6 +318,7 @@ def train_tft_model(
|
|
| 308 |
save_top_k=3,
|
| 309 |
save_last=True,
|
| 310 |
),
|
|
|
|
| 311 |
]
|
| 312 |
|
| 313 |
if use_asro and cfg.forecast.primary_horizon_days != 5:
|
|
@@ -430,12 +441,8 @@ def train_tft_model(
|
|
| 430 |
"lambda_weekly_quantile": cfg.weekly_loss.lambda_weekly_quantile,
|
| 431 |
"lambda_t1_quantile": cfg.weekly_loss.lambda_t1_quantile,
|
| 432 |
"lambda_directional": cfg.weekly_loss.lambda_directional,
|
| 433 |
-
"
|
| 434 |
-
"
|
| 435 |
-
"weekly_lambda_crossing": cfg.weekly_loss.lambda_crossing,
|
| 436 |
-
"lambda_sanity": cfg.weekly_loss.lambda_sanity,
|
| 437 |
-
"lambda_width": cfg.weekly_loss.lambda_width,
|
| 438 |
-
"lambda_tail_width": cfg.weekly_loss.lambda_tail_width,
|
| 439 |
"max_encoder_length": cfg.model.max_encoder_length,
|
| 440 |
"max_prediction_length": cfg.model.max_prediction_length,
|
| 441 |
"forecast_contract_version": FORECAST_CONTRACT_VERSION,
|
|
@@ -518,7 +525,11 @@ def _write_conformal_calibration_artifact(
|
|
| 518 |
import torch
|
| 519 |
|
| 520 |
from deep_learning.calibration.conformal import rolling_conformal_adjustment
|
| 521 |
-
from deep_learning.training.metrics import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
y_parts = []
|
| 524 |
for batch in val_dl:
|
|
@@ -534,9 +545,13 @@ def _write_conformal_calibration_artifact(
|
|
| 534 |
return None
|
| 535 |
|
| 536 |
weekly_actual = cumulative_horizon(y_actual_path[:n], horizon=cfg.forecast.primary_horizon_days)
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
)
|
| 541 |
q = tuple(cfg.model.quantiles)
|
| 542 |
q10_idx = q.index(0.10)
|
|
@@ -638,18 +653,10 @@ def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
|
|
| 638 |
params["learning_rate"] = min(float(params["learning_rate"]), 6e-4)
|
| 639 |
if "weight_decay" in params:
|
| 640 |
params["weight_decay"] = min(float(params["weight_decay"]), 5e-4)
|
| 641 |
-
if "lambda_magnitude" in params:
|
| 642 |
-
params["lambda_magnitude"] = max(float(params["lambda_magnitude"]), 0.50)
|
| 643 |
if "lambda_directional" in params:
|
| 644 |
params["lambda_directional"] = min(float(params["lambda_directional"]), 0.12)
|
| 645 |
-
if "
|
| 646 |
-
params["
|
| 647 |
-
if "lambda_tail_width" in params:
|
| 648 |
-
params["lambda_tail_width"] = max(float(params["lambda_tail_width"]), 0.25)
|
| 649 |
-
if "lambda_sanity" in params:
|
| 650 |
-
params["lambda_sanity"] = max(float(params["lambda_sanity"]), 0.10)
|
| 651 |
-
if "lambda_crossing" in params:
|
| 652 |
-
params["lambda_crossing"] = max(float(params["lambda_crossing"]), 5.0)
|
| 653 |
|
| 654 |
logger.info(
|
| 655 |
"Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
|
|
@@ -685,12 +692,9 @@ def _overlay_training_config(cfg: TFTASROConfig, params: dict) -> TFTASROConfig:
|
|
| 685 |
weekly_loss_overrides = {
|
| 686 |
k: params[k] for k in (
|
| 687 |
"lambda_weekly_quantile", "lambda_t1_quantile", "lambda_directional",
|
| 688 |
-
"
|
| 689 |
-
"lambda_width", "lambda_tail_width",
|
| 690 |
) if k in params
|
| 691 |
}
|
| 692 |
-
if "weekly_lambda_vol" in params:
|
| 693 |
-
weekly_loss_overrides["lambda_vol"] = params["weekly_lambda_vol"]
|
| 694 |
|
| 695 |
new_model = replace(cfg.model, **model_overrides) if model_overrides else cfg.model
|
| 696 |
new_asro = replace(cfg.asro, **asro_overrides) if asro_overrides else cfg.asro
|
|
@@ -744,10 +748,20 @@ if __name__ == "__main__":
|
|
| 744 |
parser.add_argument("--symbol", default="HG=F")
|
| 745 |
parser.add_argument("--no-asro", action="store_true", help="Use standard QuantileLoss instead of ASRO")
|
| 746 |
parser.add_argument("--upload-hub", action="store_true", help="Upload artifacts to HF Hub after training")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
args = parser.parse_args()
|
| 748 |
|
| 749 |
cfg = get_tft_config()
|
| 750 |
-
result = train_tft_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
|
| 752 |
print("\n" + "=" * 60)
|
| 753 |
print("TFT-ASRO TRAINING COMPLETE")
|
|
|
|
| 47 |
logger = logging.getLogger(__name__)
|
| 48 |
|
| 49 |
KNOWN_GOOD_CONFIG = {
|
| 50 |
+
"max_encoder_length": 50,
|
| 51 |
"hidden_size": 48,
|
| 52 |
"attention_head_size": 2,
|
| 53 |
"dropout": 0.30,
|
|
|
|
| 57 |
"lambda_vol": 0.30,
|
| 58 |
"lambda_quantile": 0.25,
|
| 59 |
"lambda_madl": 0.40,
|
| 60 |
+
"lambda_weekly_quantile": 0.55,
|
| 61 |
+
"lambda_t1_quantile": 0.15,
|
| 62 |
+
"lambda_dispersion": 0.20,
|
| 63 |
"lambda_directional": 0.10,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
"batch_size": 32,
|
| 65 |
}
|
| 66 |
|
| 67 |
+
DETERMINISTIC_WEEKLY_CONFIG = dict(KNOWN_GOOD_CONFIG)
|
| 68 |
+
|
| 69 |
REQUIRED_PROMOTABLE_METRICS = (
|
| 70 |
"weekly_directional_accuracy",
|
| 71 |
"weekly_magnitude_ratio",
|
|
|
|
| 123 |
pred_np: np.ndarray,
|
| 124 |
cfg: TFTASROConfig,
|
| 125 |
) -> dict[str, float]:
|
| 126 |
+
from deep_learning.training.metrics import (
|
| 127 |
+
compute_all_metrics,
|
| 128 |
+
compute_weekly_metrics,
|
| 129 |
+
monotonic_quantiles_np,
|
| 130 |
+
quantile_crossing_rate,
|
| 131 |
+
quantile_median_sort_gap,
|
| 132 |
+
select_prediction_horizon,
|
| 133 |
+
)
|
| 134 |
|
| 135 |
pred_np = np.asarray(pred_np)
|
| 136 |
_validate_quantile_prediction_shape(pred_np, cfg)
|
| 137 |
|
| 138 |
median_idx = len(cfg.model.quantiles) // 2
|
| 139 |
+
ordered_pred_np = monotonic_quantiles_np(pred_np, median_idx=median_idx)
|
| 140 |
+
raw_pred_t1 = pred_np[:, 0, :]
|
| 141 |
+
pred_t1 = ordered_pred_np[:, 0, :]
|
| 142 |
y_pred_median = pred_t1[:, median_idx]
|
| 143 |
y_pred_q10 = pred_t1[:, 1]
|
| 144 |
y_pred_q90 = pred_t1[:, -2]
|
|
|
|
| 156 |
y_pred_q98=y_pred_q98[:n],
|
| 157 |
y_pred_quantiles=pred_t1[:n],
|
| 158 |
)
|
| 159 |
+
raw_gap_mean, raw_gap_max = quantile_median_sort_gap(raw_pred_t1[:n], median_idx)
|
| 160 |
+
test_metrics["raw_quantile_crossing_rate"] = quantile_crossing_rate(raw_pred_t1[:n])
|
| 161 |
+
test_metrics["raw_median_sort_gap_mean"] = raw_gap_mean
|
| 162 |
+
test_metrics["raw_median_sort_gap_max"] = raw_gap_max
|
| 163 |
|
| 164 |
n_path = min(len(y_actual_path), len(pred_np))
|
| 165 |
weekly_metrics = compute_weekly_metrics(
|
|
|
|
| 177 |
cfg: Optional[TFTASROConfig] = None,
|
| 178 |
use_asro: bool = True,
|
| 179 |
upload_to_hub: bool = False,
|
| 180 |
+
deterministic_weekly_validation: bool = False,
|
| 181 |
) -> dict:
|
| 182 |
"""
|
| 183 |
End-to-end TFT-ASRO training.
|
|
|
|
| 200 |
from deep_learning.data.feature_store import build_tft_dataframe
|
| 201 |
from deep_learning.data.dataset import build_datasets, create_dataloaders
|
| 202 |
from deep_learning.models.tft_copper import create_tft_model, get_variable_importance, format_prediction
|
| 203 |
+
from deep_learning.training.callbacks import (
|
| 204 |
+
CurriculumLossScheduler,
|
| 205 |
+
SWACallback,
|
| 206 |
+
WeeklyLossComponentLogger,
|
| 207 |
+
)
|
| 208 |
|
| 209 |
if cfg is None:
|
| 210 |
cfg = get_tft_config()
|
| 211 |
|
| 212 |
+
# ---- 0a. Load training params ----
|
| 213 |
+
# Deterministic validation bypasses Optuna so structural changes can be
|
| 214 |
+
# measured before investing in search.
|
| 215 |
+
if deterministic_weekly_validation:
|
| 216 |
+
cfg = _overlay_training_config(cfg, DETERMINISTIC_WEEKLY_CONFIG)
|
| 217 |
+
logger.info("Using deterministic weekly validation config: %s", DETERMINISTIC_WEEKLY_CONFIG)
|
| 218 |
+
else:
|
| 219 |
+
cfg = _apply_optuna_results(cfg)
|
| 220 |
|
| 221 |
# ---- 0b. ASRO loss sanity check (runs before any training) ----
|
| 222 |
try:
|
|
|
|
| 278 |
cfg.training.early_stopping_patience,
|
| 279 |
)
|
| 280 |
logger.info(
|
| 281 |
+
"Weekly loss | weekly_q=%.2f t1_q=%.2f dispersion=%.2f directional=%.2f monotonic_transform=true",
|
| 282 |
cfg.weekly_loss.lambda_weekly_quantile,
|
| 283 |
cfg.weekly_loss.lambda_t1_quantile,
|
| 284 |
+
cfg.weekly_loss.lambda_dispersion,
|
| 285 |
cfg.weekly_loss.lambda_directional,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
)
|
| 287 |
else:
|
| 288 |
logger.info(
|
|
|
|
| 318 |
save_top_k=3,
|
| 319 |
save_last=True,
|
| 320 |
),
|
| 321 |
+
WeeklyLossComponentLogger(),
|
| 322 |
]
|
| 323 |
|
| 324 |
if use_asro and cfg.forecast.primary_horizon_days != 5:
|
|
|
|
| 441 |
"lambda_weekly_quantile": cfg.weekly_loss.lambda_weekly_quantile,
|
| 442 |
"lambda_t1_quantile": cfg.weekly_loss.lambda_t1_quantile,
|
| 443 |
"lambda_directional": cfg.weekly_loss.lambda_directional,
|
| 444 |
+
"lambda_dispersion": cfg.weekly_loss.lambda_dispersion,
|
| 445 |
+
"monotonic_quantile_transform": True,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
"max_encoder_length": cfg.model.max_encoder_length,
|
| 447 |
"max_prediction_length": cfg.model.max_prediction_length,
|
| 448 |
"forecast_contract_version": FORECAST_CONTRACT_VERSION,
|
|
|
|
| 525 |
import torch
|
| 526 |
|
| 527 |
from deep_learning.calibration.conformal import rolling_conformal_adjustment
|
| 528 |
+
from deep_learning.training.metrics import (
|
| 529 |
+
cumulative_horizon,
|
| 530 |
+
cumulative_quantiles,
|
| 531 |
+
monotonic_quantiles_np,
|
| 532 |
+
)
|
| 533 |
|
| 534 |
y_parts = []
|
| 535 |
for batch in val_dl:
|
|
|
|
| 545 |
return None
|
| 546 |
|
| 547 |
weekly_actual = cumulative_horizon(y_actual_path[:n], horizon=cfg.forecast.primary_horizon_days)
|
| 548 |
+
ordered_pred_np = monotonic_quantiles_np(
|
| 549 |
+
pred_np[:n],
|
| 550 |
+
median_idx=len(cfg.model.quantiles) // 2,
|
| 551 |
+
)
|
| 552 |
+
weekly_quantiles = cumulative_quantiles(
|
| 553 |
+
ordered_pred_np,
|
| 554 |
+
horizon=cfg.forecast.primary_horizon_days,
|
| 555 |
)
|
| 556 |
q = tuple(cfg.model.quantiles)
|
| 557 |
q10_idx = q.index(0.10)
|
|
|
|
| 653 |
params["learning_rate"] = min(float(params["learning_rate"]), 6e-4)
|
| 654 |
if "weight_decay" in params:
|
| 655 |
params["weight_decay"] = min(float(params["weight_decay"]), 5e-4)
|
|
|
|
|
|
|
| 656 |
if "lambda_directional" in params:
|
| 657 |
params["lambda_directional"] = min(float(params["lambda_directional"]), 0.12)
|
| 658 |
+
if "lambda_dispersion" in params:
|
| 659 |
+
params["lambda_dispersion"] = max(float(params["lambda_dispersion"]), 0.20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
|
| 661 |
logger.info(
|
| 662 |
"Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
|
|
|
|
| 692 |
weekly_loss_overrides = {
|
| 693 |
k: params[k] for k in (
|
| 694 |
"lambda_weekly_quantile", "lambda_t1_quantile", "lambda_directional",
|
| 695 |
+
"lambda_dispersion",
|
|
|
|
| 696 |
) if k in params
|
| 697 |
}
|
|
|
|
|
|
|
| 698 |
|
| 699 |
new_model = replace(cfg.model, **model_overrides) if model_overrides else cfg.model
|
| 700 |
new_asro = replace(cfg.asro, **asro_overrides) if asro_overrides else cfg.asro
|
|
|
|
| 748 |
parser.add_argument("--symbol", default="HG=F")
|
| 749 |
parser.add_argument("--no-asro", action="store_true", help="Use standard QuantileLoss instead of ASRO")
|
| 750 |
parser.add_argument("--upload-hub", action="store_true", help="Upload artifacts to HF Hub after training")
|
| 751 |
+
parser.add_argument(
|
| 752 |
+
"--deterministic-weekly-validation",
|
| 753 |
+
action="store_true",
|
| 754 |
+
help="Bypass Optuna overlays and run the fixed monotonic weekly validation config",
|
| 755 |
+
)
|
| 756 |
args = parser.parse_args()
|
| 757 |
|
| 758 |
cfg = get_tft_config()
|
| 759 |
+
result = train_tft_model(
|
| 760 |
+
cfg,
|
| 761 |
+
use_asro=not args.no_asro,
|
| 762 |
+
upload_to_hub=args.upload_hub,
|
| 763 |
+
deterministic_weekly_validation=args.deterministic_weekly_validation,
|
| 764 |
+
)
|
| 765 |
|
| 766 |
print("\n" + "=" * 60)
|
| 767 |
print("TFT-ASRO TRAINING COMPLETE")
|
scripts/tft_quality_gate.py
CHANGED
|
@@ -19,7 +19,7 @@ BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
|
| 19 |
if str(BACKEND_ROOT) not in sys.path:
|
| 20 |
sys.path.insert(0, str(BACKEND_ROOT))
|
| 21 |
|
| 22 |
-
from app.quality_gate import evaluate_quality_gate
|
| 23 |
|
| 24 |
META_PATH = pathlib.Path(os.environ.get("TFT_METADATA_PATH", "/tmp/models/tft/tft_metadata.json"))
|
| 25 |
|
|
@@ -37,17 +37,23 @@ def main() -> int:
|
|
| 37 |
tail_capture = metrics.get("tail_capture_rate")
|
| 38 |
quantile_crossing = metrics.get("quantile_crossing_rate")
|
| 39 |
median_gap_max = metrics.get("median_sort_gap_max")
|
|
|
|
|
|
|
|
|
|
| 40 |
weekly_da = metrics.get("weekly_directional_accuracy")
|
| 41 |
weekly_mr = metrics.get("weekly_magnitude_ratio")
|
| 42 |
weekly_tail = metrics.get("weekly_tail_capture_rate")
|
| 43 |
weekly_pi80 = metrics.get("weekly_pi80_coverage")
|
|
|
|
| 44 |
weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
|
| 45 |
weekly_pi96 = metrics.get("weekly_pi96_coverage")
|
|
|
|
| 46 |
weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
|
| 47 |
weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
|
| 48 |
weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
|
| 49 |
weekly_gap = metrics.get("weekly_median_sort_gap_max")
|
| 50 |
weekly_samples = metrics.get("weekly_sample_count")
|
|
|
|
| 51 |
|
| 52 |
print(
|
| 53 |
"Quality gate metrics: "
|
|
@@ -71,18 +77,29 @@ def main() -> int:
|
|
| 71 |
tail_capture=tail_capture,
|
| 72 |
quantile_crossing_rate=quantile_crossing,
|
| 73 |
median_sort_gap_max=median_gap_max,
|
|
|
|
|
|
|
| 74 |
weekly_directional_accuracy=weekly_da,
|
| 75 |
weekly_magnitude_ratio=weekly_mr,
|
| 76 |
weekly_tail_capture_rate=weekly_tail,
|
| 77 |
weekly_pi80_coverage=weekly_pi80,
|
|
|
|
| 78 |
weekly_pi80_width_ratio=weekly_pi80_width_ratio,
|
| 79 |
weekly_pi96_coverage=weekly_pi96,
|
|
|
|
| 80 |
weekly_pi96_width_ratio=weekly_pi96_width_ratio,
|
| 81 |
weekly_quantile_crossing_rate=weekly_qcross,
|
| 82 |
weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
|
| 83 |
weekly_median_sort_gap_max=weekly_gap,
|
| 84 |
weekly_sample_count=weekly_samples,
|
| 85 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
if passed:
|
| 88 |
print("QUALITY GATE: PASSED")
|
|
|
|
| 19 |
if str(BACKEND_ROOT) not in sys.path:
|
| 20 |
sys.path.insert(0, str(BACKEND_ROOT))
|
| 21 |
|
| 22 |
+
from app.quality_gate import evaluate_quality_gate, evaluate_quality_gate_warnings
|
| 23 |
|
| 24 |
META_PATH = pathlib.Path(os.environ.get("TFT_METADATA_PATH", "/tmp/models/tft/tft_metadata.json"))
|
| 25 |
|
|
|
|
| 37 |
tail_capture = metrics.get("tail_capture_rate")
|
| 38 |
quantile_crossing = metrics.get("quantile_crossing_rate")
|
| 39 |
median_gap_max = metrics.get("median_sort_gap_max")
|
| 40 |
+
pi80_width = metrics.get("pi80_width")
|
| 41 |
+
pi96_width = metrics.get("pi96_width")
|
| 42 |
+
mae_vs_naive_zero = metrics.get("mae_vs_naive_zero")
|
| 43 |
weekly_da = metrics.get("weekly_directional_accuracy")
|
| 44 |
weekly_mr = metrics.get("weekly_magnitude_ratio")
|
| 45 |
weekly_tail = metrics.get("weekly_tail_capture_rate")
|
| 46 |
weekly_pi80 = metrics.get("weekly_pi80_coverage")
|
| 47 |
+
weekly_pi80_width = metrics.get("weekly_pi80_width")
|
| 48 |
weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
|
| 49 |
weekly_pi96 = metrics.get("weekly_pi96_coverage")
|
| 50 |
+
weekly_pi96_width = metrics.get("weekly_pi96_width")
|
| 51 |
weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
|
| 52 |
weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
|
| 53 |
weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
|
| 54 |
weekly_gap = metrics.get("weekly_median_sort_gap_max")
|
| 55 |
weekly_samples = metrics.get("weekly_sample_count")
|
| 56 |
+
weekly_mae_vs_naive_zero = metrics.get("weekly_mae_vs_naive_zero")
|
| 57 |
|
| 58 |
print(
|
| 59 |
"Quality gate metrics: "
|
|
|
|
| 77 |
tail_capture=tail_capture,
|
| 78 |
quantile_crossing_rate=quantile_crossing,
|
| 79 |
median_sort_gap_max=median_gap_max,
|
| 80 |
+
pi80_width=pi80_width,
|
| 81 |
+
pi96_width=pi96_width,
|
| 82 |
weekly_directional_accuracy=weekly_da,
|
| 83 |
weekly_magnitude_ratio=weekly_mr,
|
| 84 |
weekly_tail_capture_rate=weekly_tail,
|
| 85 |
weekly_pi80_coverage=weekly_pi80,
|
| 86 |
+
weekly_pi80_width=weekly_pi80_width,
|
| 87 |
weekly_pi80_width_ratio=weekly_pi80_width_ratio,
|
| 88 |
weekly_pi96_coverage=weekly_pi96,
|
| 89 |
+
weekly_pi96_width=weekly_pi96_width,
|
| 90 |
weekly_pi96_width_ratio=weekly_pi96_width_ratio,
|
| 91 |
weekly_quantile_crossing_rate=weekly_qcross,
|
| 92 |
weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
|
| 93 |
weekly_median_sort_gap_max=weekly_gap,
|
| 94 |
weekly_sample_count=weekly_samples,
|
| 95 |
)
|
| 96 |
+
warnings = evaluate_quality_gate_warnings(
|
| 97 |
+
vr=vr,
|
| 98 |
+
mae_vs_naive_zero=mae_vs_naive_zero,
|
| 99 |
+
weekly_mae_vs_naive_zero=weekly_mae_vs_naive_zero,
|
| 100 |
+
)
|
| 101 |
+
for warning in warnings:
|
| 102 |
+
print(f"QUALITY GATE WARNING: {warning}")
|
| 103 |
|
| 104 |
if passed:
|
| 105 |
print("QUALITY GATE: PASSED")
|