Spaces:
Running
Running
File size: 4,329 Bytes
"""
CI quality gate for TFT-ASRO training.
Reads tft_metadata.json written by trainer.py and exits non-zero when
metrics fall below deployment thresholds.
Thin wrapper that delegates threshold logic to `app.quality_gate` so that
GitHub Actions CI and the FastAPI runtime always agree on the rules.
"""
from __future__ import annotations
import json
import os
import pathlib
import sys
BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1]
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
from app.quality_gate import evaluate_quality_gate, evaluate_quality_gate_warnings
META_PATH = pathlib.Path(os.environ.get("TFT_METADATA_PATH", "/tmp/models/tft/tft_metadata.json"))
def main() -> int:
    """Run the CI quality gate against the trainer's metadata output.

    Loads the metrics JSON at ``META_PATH``, prints a human-readable summary,
    and delegates the pass/fail decision to
    ``app.quality_gate.evaluate_quality_gate`` (hard gate) and
    ``evaluate_quality_gate_warnings`` (non-blocking warnings), so CI and the
    FastAPI runtime share one set of rules.

    Returns:
        0 when the gate passes; 1 when it fails or the metadata file is
        missing, unreadable, or not valid JSON.
    """
    if not META_PATH.exists():
        print("No metadata file found - quality gate cannot evaluate training output")
        return 1
    # utf-8-sig tolerates a UTF-8 BOM (some Windows tooling writes one);
    # plain UTF-8 decodes identically. A corrupt or truncated file must fail
    # the gate cleanly instead of crashing with a traceback.
    try:
        data = json.loads(META_PATH.read_text(encoding="utf-8-sig"))
    except (OSError, json.JSONDecodeError) as exc:
        print(f"Metadata file could not be read or parsed: {exc}")
        return 1
    metrics = data.get("test_metrics", {})

    # Core daily metrics carry conservative fallbacks so the gate still
    # evaluates (and fails loudly) when trainer output is incomplete.
    da = metrics.get("directional_accuracy", 0.5)
    sharpe = metrics.get("sharpe_ratio", 0.0)
    vr = metrics.get("variance_ratio", 1.0)
    # Optional metrics default to None; the shared gate treats absent values
    # as "not applicable" rather than failing outright.
    tail_capture = metrics.get("tail_capture_rate")
    quantile_crossing = metrics.get("quantile_crossing_rate")
    median_gap_max = metrics.get("median_sort_gap_max")
    pi80_width = metrics.get("pi80_width")
    pi96_width = metrics.get("pi96_width")
    mae_vs_naive_zero = metrics.get("mae_vs_naive_zero")
    weekly_da = metrics.get("weekly_directional_accuracy")
    weekly_mr = metrics.get("weekly_magnitude_ratio")
    weekly_tail = metrics.get("weekly_tail_capture_rate")
    weekly_pi80 = metrics.get("weekly_pi80_coverage")
    weekly_pi80_width = metrics.get("weekly_pi80_width")
    weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
    weekly_pi96 = metrics.get("weekly_pi96_coverage")
    weekly_pi96_width = metrics.get("weekly_pi96_width")
    weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
    weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
    weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
    weekly_gap = metrics.get("weekly_median_sort_gap_max")
    weekly_samples = metrics.get("weekly_sample_count")
    weekly_mae_vs_naive_zero = metrics.get("weekly_mae_vs_naive_zero")

    # Summary lines make the CI log self-explanatory on failure.
    print(
        "Quality gate metrics: "
        f"DA={da:.4f} Sharpe={sharpe:.4f} VR={vr:.4f} "
        f"Tail={tail_capture if tail_capture is not None else 'n/a'} "
        f"QCross={quantile_crossing if quantile_crossing is not None else 'n/a'}"
    )
    print(
        "Weekly gate metrics: "
        f"WeeklyDA={weekly_da} WeeklyMR={weekly_mr} "
        f"WeeklyTail={weekly_tail} WeeklyPI80={weekly_pi80} "
        f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio} "
        f"WeeklyQCross={weekly_qcross} WeeklySortedQCross={weekly_sorted_qcross} "
        f"WeeklyN={weekly_samples}"
    )
    passed, reasons = evaluate_quality_gate(
        da,
        sharpe,
        vr,
        tail_capture=tail_capture,
        quantile_crossing_rate=quantile_crossing,
        median_sort_gap_max=median_gap_max,
        pi80_width=pi80_width,
        pi96_width=pi96_width,
        weekly_directional_accuracy=weekly_da,
        weekly_magnitude_ratio=weekly_mr,
        weekly_tail_capture_rate=weekly_tail,
        weekly_pi80_coverage=weekly_pi80,
        weekly_pi80_width=weekly_pi80_width,
        weekly_pi80_width_ratio=weekly_pi80_width_ratio,
        weekly_pi96_coverage=weekly_pi96,
        weekly_pi96_width=weekly_pi96_width,
        weekly_pi96_width_ratio=weekly_pi96_width_ratio,
        weekly_quantile_crossing_rate=weekly_qcross,
        weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
        weekly_median_sort_gap_max=weekly_gap,
        weekly_sample_count=weekly_samples,
    )
    # Warnings never block promotion; they surface soft regressions in logs.
    warnings = evaluate_quality_gate_warnings(
        vr=vr,
        mae_vs_naive_zero=mae_vs_naive_zero,
        weekly_mae_vs_naive_zero=weekly_mae_vs_naive_zero,
    )
    for warning in warnings:
        print(f"QUALITY GATE WARNING: {warning}")
    if passed:
        print("QUALITY GATE: PASSED")
        return 0
    print(f"QUALITY GATE: FAILED — {reasons}")
    print("Model checkpoint will NOT be promoted. Previous checkpoint retained.")
    return 1
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status for CI.
    raise SystemExit(main())