File size: 4,329 Bytes
8d1bbd4
 
 
 
 
 
990ad35
 
8d1bbd4
 
 
 
 
964ddf9
8d1bbd4
 
990ad35
964ddf9
 
 
 
e411cee
8d1bbd4
964ddf9
8d1bbd4
 
 
 
c271c72
 
8d1bbd4
964ddf9
8d1bbd4
 
 
 
c271c72
 
 
e411cee
 
 
d317049
 
 
 
e411cee
4c79e2a
 
e411cee
4c79e2a
d317049
4c79e2a
d317049
 
e411cee
c271c72
 
 
 
 
 
 
d317049
 
 
 
4c79e2a
 
 
d317049
c271c72
 
 
 
 
 
 
 
e411cee
 
d317049
 
 
 
e411cee
4c79e2a
 
e411cee
4c79e2a
d317049
4c79e2a
d317049
 
c271c72
e411cee
 
 
 
 
 
 
8d1bbd4
847c80d
8d1bbd4
 
 
 
 
 
 
990ad35
8d1bbd4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
CI quality gate for TFT-ASRO training.

Reads tft_metadata.json written by trainer.py and exits non-zero when
metrics fall below deployment thresholds.

Thin wrapper that delegates threshold logic to `app.quality_gate` so that
GitHub Actions CI and the FastAPI runtime always agree on the rules.
"""

from __future__ import annotations

import json
import os
import pathlib
import sys

# Put the directory two levels above this file at the front of sys.path
# (unless it is already listed) so the `app` package import below resolves.
BACKEND_ROOT = pathlib.Path(__file__).resolve().parents[1]
_backend_root_entry = str(BACKEND_ROOT)
if _backend_root_entry not in sys.path:
    sys.path.insert(0, _backend_root_entry)

from app.quality_gate import evaluate_quality_gate, evaluate_quality_gate_warnings

# Metadata file to evaluate; the TFT_METADATA_PATH environment variable
# overrides the built-in default location.
_DEFAULT_META_PATH = "/tmp/models/tft/tft_metadata.json"
META_PATH = pathlib.Path(os.environ.get("TFT_METADATA_PATH", _DEFAULT_META_PATH))


def main() -> int:
    """Run the quality gate against the metadata file at ``META_PATH``.

    Loads the JSON file, reads its ``test_metrics`` mapping, prints the
    headline and weekly metrics, and delegates the pass/fail decision to
    ``evaluate_quality_gate``. Messages returned by
    ``evaluate_quality_gate_warnings`` are printed but do not change the
    result.

    The gate fails closed with a diagnostic line instead of an uncaught
    exception when the metadata file is missing, unreadable or not valid
    JSON, when ``test_metrics`` is not a JSON object, or when one of the
    three core metrics (``directional_accuracy``, ``sharpe_ratio``,
    ``variance_ratio``) is present but not a number.

    Returns:
        0 when the gate passes, 1 otherwise.
    """
    if not META_PATH.exists():
        print("No metadata file found - quality gate cannot evaluate training output")
        return 1

    try:
        data = json.loads(META_PATH.read_text(encoding="utf-8-sig"))
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(
            "Metadata file could not be read as JSON - "
            f"quality gate cannot evaluate training output: {exc}"
        )
        return 1

    # A missing "test_metrics" key still falls back to an empty mapping, but
    # an explicit null / list / scalar cannot be queried with .get().
    metrics = data.get("test_metrics", {}) if isinstance(data, dict) else None
    if not isinstance(metrics, dict):
        print(
            "Metadata has no 'test_metrics' object - "
            "quality gate cannot evaluate training output"
        )
        return 1

    da = metrics.get("directional_accuracy", 0.5)
    sharpe = metrics.get("sharpe_ratio", 0.0)
    vr = metrics.get("variance_ratio", 1.0)

    # dict.get only substitutes the default when the key is absent; a JSON
    # null (or a string) comes back as-is and would break the ":.4f"
    # formatting below, so reject it up front.
    non_numeric = [
        key
        for key, value in (
            ("directional_accuracy", da),
            ("sharpe_ratio", sharpe),
            ("variance_ratio", vr),
        )
        if not isinstance(value, (int, float))
    ]
    if non_numeric:
        print(f"QUALITY GATE: FAILED — non-numeric core metrics: {', '.join(non_numeric)}")
        print("Model checkpoint will NOT be promoted. Previous checkpoint retained.")
        return 1

    # Optional metrics: None is forwarded when the key is absent.
    tail_capture = metrics.get("tail_capture_rate")
    quantile_crossing = metrics.get("quantile_crossing_rate")
    median_gap_max = metrics.get("median_sort_gap_max")
    pi80_width = metrics.get("pi80_width")
    pi96_width = metrics.get("pi96_width")
    mae_vs_naive_zero = metrics.get("mae_vs_naive_zero")
    weekly_da = metrics.get("weekly_directional_accuracy")
    weekly_mr = metrics.get("weekly_magnitude_ratio")
    weekly_tail = metrics.get("weekly_tail_capture_rate")
    weekly_pi80 = metrics.get("weekly_pi80_coverage")
    weekly_pi80_width = metrics.get("weekly_pi80_width")
    weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
    weekly_pi96 = metrics.get("weekly_pi96_coverage")
    weekly_pi96_width = metrics.get("weekly_pi96_width")
    weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
    weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
    weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
    weekly_gap = metrics.get("weekly_median_sort_gap_max")
    weekly_samples = metrics.get("weekly_sample_count")
    weekly_mae_vs_naive_zero = metrics.get("weekly_mae_vs_naive_zero")

    print(
        "Quality gate metrics: "
        f"DA={da:.4f} Sharpe={sharpe:.4f} VR={vr:.4f} "
        f"Tail={tail_capture if tail_capture is not None else 'n/a'} "
        f"QCross={quantile_crossing if quantile_crossing is not None else 'n/a'}"
    )
    print(
        "Weekly gate metrics: "
        f"WeeklyDA={weekly_da} WeeklyMR={weekly_mr} "
        f"WeeklyTail={weekly_tail} WeeklyPI80={weekly_pi80} "
        f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio} "
        f"WeeklyQCross={weekly_qcross} WeeklySortedQCross={weekly_sorted_qcross} "
        f"WeeklyN={weekly_samples}"
    )

    # Threshold logic lives in app.quality_gate; this script only forwards
    # the values read from the metadata file.
    passed, reasons = evaluate_quality_gate(
        da,
        sharpe,
        vr,
        tail_capture=tail_capture,
        quantile_crossing_rate=quantile_crossing,
        median_sort_gap_max=median_gap_max,
        pi80_width=pi80_width,
        pi96_width=pi96_width,
        weekly_directional_accuracy=weekly_da,
        weekly_magnitude_ratio=weekly_mr,
        weekly_tail_capture_rate=weekly_tail,
        weekly_pi80_coverage=weekly_pi80,
        weekly_pi80_width=weekly_pi80_width,
        weekly_pi80_width_ratio=weekly_pi80_width_ratio,
        weekly_pi96_coverage=weekly_pi96,
        weekly_pi96_width=weekly_pi96_width,
        weekly_pi96_width_ratio=weekly_pi96_width_ratio,
        weekly_quantile_crossing_rate=weekly_qcross,
        weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
        weekly_median_sort_gap_max=weekly_gap,
        weekly_sample_count=weekly_samples,
    )
    warnings = evaluate_quality_gate_warnings(
        vr=vr,
        mae_vs_naive_zero=mae_vs_naive_zero,
        weekly_mae_vs_naive_zero=weekly_mae_vs_naive_zero,
    )
    # Warnings are informational only; they never change the exit status.
    for warning in warnings:
        print(f"QUALITY GATE WARNING: {warning}")

    if passed:
        print("QUALITY GATE: PASSED")
        return 0

    print(f"QUALITY GATE: FAILED — {reasons}")
    print("Model checkpoint will NOT be promoted. Previous checkpoint retained.")
    return 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())