ifieryarrows commited on
Commit
8d1bbd4
·
verified ·
1 Parent(s): 18d4089

Sync from GitHub (tests passed)

Browse files
Files changed (1) hide show
  1. scripts/tft_quality_gate.py +51 -0
scripts/tft_quality_gate.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CI quality gate for TFT-ASRO training.
3
+
4
+ Reads tft_metadata.json written by trainer.py and exits non-zero when
5
+ metrics fall below deployment thresholds.
6
+
7
+ Used by .github/workflows/tft-training.yml (YAML cannot embed indented
8
+ Python multiline strings without breaking the workflow parser).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import pathlib
15
+ import sys
16
+
17
+ META_PATH = pathlib.Path("/tmp/models/tft/tft_metadata.json")
18
+
19
+
20
+ def main() -> int:
21
+ if not META_PATH.exists():
22
+ print("No metadata file found — skipping quality gate")
23
+ return 0
24
+
25
+ data = json.loads(META_PATH.read_text(encoding="utf-8"))
26
+ metrics = data.get("test_metrics", {})
27
+ da = metrics.get("directional_accuracy", 0.5)
28
+ sharpe = metrics.get("sharpe_ratio", 0.0)
29
+ vr = metrics.get("variance_ratio", 1.0)
30
+
31
+ print(f"Quality gate metrics: DA={da:.4f} Sharpe={sharpe:.4f} VR={vr:.4f}")
32
+
33
+ reasons: list[str] = []
34
+ if da < 0.49:
35
+ reasons.append(f"DA={da:.4f} < 0.49")
36
+ if sharpe < -0.30:
37
+ reasons.append(f"Sharpe={sharpe:.4f} < -0.30")
38
+ if vr < 0.2 or vr > 2.5:
39
+ reasons.append(f"VR={vr:.4f} outside [0.2, 2.5]")
40
+
41
+ if not reasons:
42
+ print("QUALITY GATE: PASSED")
43
+ return 0
44
+
45
+ print(f"QUALITY GATE: FAILED — {reasons}")
46
+ print("Model checkpoint will NOT be promoted. Previous checkpoint retained.")
47
+ return 1
48
+
49
+
50
+ if __name__ == "__main__":
51
+ sys.exit(main())