Spaces:
Running
Running
Sync from GitHub (tests passed)
Browse files- Dockerfile +1 -1
- app/main.py +63 -5
- app/quality_gate.py +28 -2
- app/settings.py +17 -0
- deep_learning/config.py +9 -9
- deep_learning/data/feature_store.py +18 -1
- deep_learning/data/sentiment_market_date.py +55 -4
- deep_learning/inference/predictor.py +14 -0
- deep_learning/models/hub.py +92 -1
- deep_learning/models/tft_copper.py +24 -14
- deep_learning/training/hyperopt.py +106 -24
- deep_learning/training/metrics.py +66 -3
- deep_learning/training/trainer.py +47 -5
- pyproject.toml +1 -0
- scripts/tft_quality_gate.py +11 -1
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM python:3.11-slim
|
| 2 |
|
| 3 |
WORKDIR /code
|
| 4 |
|
|
|
|
| 1 |
+
FROM python:3.11-slim@sha256:9a7765b36773a37061455b332f18e265e7f58f6fea9c419a550d2a8b0e9db834
|
| 2 |
|
| 3 |
WORKDIR /code
|
| 4 |
|
app/main.py
CHANGED
|
@@ -8,6 +8,7 @@ Endpoints:
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import logging
|
|
|
|
| 11 |
from dataclasses import dataclass
|
| 12 |
|
| 13 |
# Suppress httpx request logging to prevent API keys in URLs from appearing in logs
|
|
@@ -18,7 +19,7 @@ from datetime import datetime, timedelta, timezone
|
|
| 18 |
from pathlib import Path
|
| 19 |
from typing import Optional
|
| 20 |
|
| 21 |
-
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect, Depends, Header, BackgroundTasks
|
| 22 |
from fastapi.middleware.cors import CORSMiddleware
|
| 23 |
from sqlalchemy import func
|
| 24 |
|
|
@@ -100,10 +101,18 @@ app = FastAPI(
|
|
| 100 |
lifespan=lifespan,
|
| 101 |
)
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# CORS configuration
|
| 104 |
app.add_middleware(
|
| 105 |
CORSMiddleware,
|
| 106 |
-
allow_origins=
|
| 107 |
allow_credentials=True,
|
| 108 |
allow_methods=["*"],
|
| 109 |
allow_headers=["*"],
|
|
@@ -1150,7 +1159,32 @@ async def api_root():
|
|
| 1150 |
# =============================================================================
|
| 1151 |
|
| 1152 |
|
| 1153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1154 |
"""
|
| 1155 |
Verify the pipeline trigger secret from Authorization header.
|
| 1156 |
|
|
@@ -1158,9 +1192,12 @@ def verify_pipeline_secret(authorization: Optional[str] = Header(None)) -> None:
|
|
| 1158 |
"""
|
| 1159 |
settings = get_settings()
|
| 1160 |
|
|
|
|
|
|
|
| 1161 |
# If no secret is configured, reject all requests (fail secure)
|
| 1162 |
if not settings.pipeline_trigger_secret:
|
| 1163 |
logger.warning("Pipeline trigger attempted but PIPELINE_TRIGGER_SECRET not configured")
|
|
|
|
| 1164 |
raise HTTPException(
|
| 1165 |
status_code=401,
|
| 1166 |
detail="Pipeline trigger authentication not configured. Set PIPELINE_TRIGGER_SECRET."
|
|
@@ -1168,6 +1205,7 @@ def verify_pipeline_secret(authorization: Optional[str] = Header(None)) -> None:
|
|
| 1168 |
|
| 1169 |
# Check Authorization header
|
| 1170 |
if not authorization:
|
|
|
|
| 1171 |
raise HTTPException(
|
| 1172 |
status_code=401,
|
| 1173 |
detail="Missing Authorization header. Expected: Bearer <token>"
|
|
@@ -1176,6 +1214,7 @@ def verify_pipeline_secret(authorization: Optional[str] = Header(None)) -> None:
|
|
| 1176 |
# Parse Bearer token
|
| 1177 |
parts = authorization.split(" ", 1)
|
| 1178 |
if len(parts) != 2 or parts[0].lower() != "bearer":
|
|
|
|
| 1179 |
raise HTTPException(
|
| 1180 |
status_code=401,
|
| 1181 |
detail="Invalid Authorization format. Expected: Bearer <token>"
|
|
@@ -1187,11 +1226,13 @@ def verify_pipeline_secret(authorization: Optional[str] = Header(None)) -> None:
|
|
| 1187 |
import secrets
|
| 1188 |
if not secrets.compare_digest(token, settings.pipeline_trigger_secret):
|
| 1189 |
logger.warning("Pipeline trigger attempted with invalid token")
|
|
|
|
| 1190 |
raise HTTPException(
|
| 1191 |
status_code=401,
|
| 1192 |
detail="Invalid pipeline trigger token"
|
| 1193 |
)
|
| 1194 |
-
|
|
|
|
| 1195 |
logger.info("Pipeline trigger authorized successfully")
|
| 1196 |
|
| 1197 |
|
|
@@ -1208,7 +1249,12 @@ def verify_pipeline_secret(authorization: Optional[str] = Header(None)) -> None:
|
|
| 1208 |
)
|
| 1209 |
async def trigger_pipeline(
|
| 1210 |
train_model: bool = Query(default=False, description="Train/retrain XGBoost model"),
|
| 1211 |
-
trigger_source: str = Query(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1212 |
_auth: None = Depends(verify_pipeline_secret),
|
| 1213 |
):
|
| 1214 |
"""
|
|
@@ -1386,7 +1432,11 @@ async def get_tft_summary(
|
|
| 1386 |
weekly_mr = metrics.get("weekly_magnitude_ratio")
|
| 1387 |
weekly_tail = metrics.get("weekly_tail_capture_rate")
|
| 1388 |
weekly_pi80 = metrics.get("weekly_pi80_coverage")
|
|
|
|
|
|
|
|
|
|
| 1389 |
weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
|
|
|
|
| 1390 |
weekly_gap = metrics.get("weekly_median_sort_gap_max")
|
| 1391 |
weekly_samples = metrics.get("weekly_sample_count")
|
| 1392 |
|
|
@@ -1401,7 +1451,11 @@ async def get_tft_summary(
|
|
| 1401 |
weekly_magnitude_ratio=weekly_mr,
|
| 1402 |
weekly_tail_capture_rate=weekly_tail,
|
| 1403 |
weekly_pi80_coverage=weekly_pi80,
|
|
|
|
|
|
|
|
|
|
| 1404 |
weekly_quantile_crossing_rate=weekly_qcross,
|
|
|
|
| 1405 |
weekly_median_sort_gap_max=weekly_gap,
|
| 1406 |
weekly_sample_count=weekly_samples,
|
| 1407 |
)
|
|
@@ -1422,7 +1476,11 @@ async def get_tft_summary(
|
|
| 1422 |
"weekly_magnitude_ratio": weekly_mr,
|
| 1423 |
"weekly_tail_capture_rate": weekly_tail,
|
| 1424 |
"weekly_pi80_coverage": weekly_pi80,
|
|
|
|
|
|
|
|
|
|
| 1425 |
"weekly_quantile_crossing_rate": weekly_qcross,
|
|
|
|
| 1426 |
"weekly_median_sort_gap_max": weekly_gap,
|
| 1427 |
"weekly_sample_count": weekly_samples,
|
| 1428 |
}.items():
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import logging
|
| 11 |
+
from collections import defaultdict
|
| 12 |
from dataclasses import dataclass
|
| 13 |
|
| 14 |
# Suppress httpx request logging to prevent API keys in URLs from appearing in logs
|
|
|
|
| 19 |
from pathlib import Path
|
| 20 |
from typing import Optional
|
| 21 |
|
| 22 |
+
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect, Depends, Header, BackgroundTasks, Request
|
| 23 |
from fastapi.middleware.cors import CORSMiddleware
|
| 24 |
from sqlalchemy import func
|
| 25 |
|
|
|
|
| 101 |
lifespan=lifespan,
|
| 102 |
)
|
| 103 |
|
| 104 |
+
def _resolve_cors_origins() -> list[str]:
|
| 105 |
+
settings = get_settings()
|
| 106 |
+
origins = settings.cors_allowed_origins_list
|
| 107 |
+
if "*" in origins and settings.environment.lower() in {"prod", "production"}:
|
| 108 |
+
raise RuntimeError("CORS wildcard is forbidden in production")
|
| 109 |
+
return origins
|
| 110 |
+
|
| 111 |
+
|
| 112 |
# CORS configuration
|
| 113 |
app.add_middleware(
|
| 114 |
CORSMiddleware,
|
| 115 |
+
allow_origins=_resolve_cors_origins(),
|
| 116 |
allow_credentials=True,
|
| 117 |
allow_methods=["*"],
|
| 118 |
allow_headers=["*"],
|
|
|
|
| 1159 |
# =============================================================================
|
| 1160 |
|
| 1161 |
|
| 1162 |
+
_PIPELINE_AUTH_FAILURES: dict[str, list[datetime]] = defaultdict(list)
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
def _pipeline_auth_key(request: Request) -> str:
|
| 1166 |
+
if request.client and request.client.host:
|
| 1167 |
+
return request.client.host
|
| 1168 |
+
return "unknown"
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
def _record_pipeline_auth_failure(key: str) -> None:
|
| 1172 |
+
now = datetime.now(timezone.utc)
|
| 1173 |
+
cutoff = now - timedelta(minutes=10)
|
| 1174 |
+
recent = [ts for ts in _PIPELINE_AUTH_FAILURES[key] if ts >= cutoff]
|
| 1175 |
+
recent.append(now)
|
| 1176 |
+
_PIPELINE_AUTH_FAILURES[key] = recent
|
| 1177 |
+
if len(recent) > 5:
|
| 1178 |
+
raise HTTPException(
|
| 1179 |
+
status_code=429,
|
| 1180 |
+
detail="Too many invalid pipeline trigger attempts",
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
|
| 1184 |
+
def verify_pipeline_secret(
|
| 1185 |
+
request: Request,
|
| 1186 |
+
authorization: Optional[str] = Header(None),
|
| 1187 |
+
) -> None:
|
| 1188 |
"""
|
| 1189 |
Verify the pipeline trigger secret from Authorization header.
|
| 1190 |
|
|
|
|
| 1192 |
"""
|
| 1193 |
settings = get_settings()
|
| 1194 |
|
| 1195 |
+
auth_key = _pipeline_auth_key(request)
|
| 1196 |
+
|
| 1197 |
# If no secret is configured, reject all requests (fail secure)
|
| 1198 |
if not settings.pipeline_trigger_secret:
|
| 1199 |
logger.warning("Pipeline trigger attempted but PIPELINE_TRIGGER_SECRET not configured")
|
| 1200 |
+
_record_pipeline_auth_failure(auth_key)
|
| 1201 |
raise HTTPException(
|
| 1202 |
status_code=401,
|
| 1203 |
detail="Pipeline trigger authentication not configured. Set PIPELINE_TRIGGER_SECRET."
|
|
|
|
| 1205 |
|
| 1206 |
# Check Authorization header
|
| 1207 |
if not authorization:
|
| 1208 |
+
_record_pipeline_auth_failure(auth_key)
|
| 1209 |
raise HTTPException(
|
| 1210 |
status_code=401,
|
| 1211 |
detail="Missing Authorization header. Expected: Bearer <token>"
|
|
|
|
| 1214 |
# Parse Bearer token
|
| 1215 |
parts = authorization.split(" ", 1)
|
| 1216 |
if len(parts) != 2 or parts[0].lower() != "bearer":
|
| 1217 |
+
_record_pipeline_auth_failure(auth_key)
|
| 1218 |
raise HTTPException(
|
| 1219 |
status_code=401,
|
| 1220 |
detail="Invalid Authorization format. Expected: Bearer <token>"
|
|
|
|
| 1226 |
import secrets
|
| 1227 |
if not secrets.compare_digest(token, settings.pipeline_trigger_secret):
|
| 1228 |
logger.warning("Pipeline trigger attempted with invalid token")
|
| 1229 |
+
_record_pipeline_auth_failure(auth_key)
|
| 1230 |
raise HTTPException(
|
| 1231 |
status_code=401,
|
| 1232 |
detail="Invalid pipeline trigger token"
|
| 1233 |
)
|
| 1234 |
+
|
| 1235 |
+
_PIPELINE_AUTH_FAILURES.pop(auth_key, None)
|
| 1236 |
logger.info("Pipeline trigger authorized successfully")
|
| 1237 |
|
| 1238 |
|
|
|
|
| 1249 |
)
|
| 1250 |
async def trigger_pipeline(
|
| 1251 |
train_model: bool = Query(default=False, description="Train/retrain XGBoost model"),
|
| 1252 |
+
trigger_source: str = Query(
|
| 1253 |
+
default="api",
|
| 1254 |
+
max_length=32,
|
| 1255 |
+
pattern="^(api|cron|manual|github-actions)$",
|
| 1256 |
+
description="Source of trigger (api, cron, manual, github-actions)",
|
| 1257 |
+
),
|
| 1258 |
_auth: None = Depends(verify_pipeline_secret),
|
| 1259 |
):
|
| 1260 |
"""
|
|
|
|
| 1432 |
weekly_mr = metrics.get("weekly_magnitude_ratio")
|
| 1433 |
weekly_tail = metrics.get("weekly_tail_capture_rate")
|
| 1434 |
weekly_pi80 = metrics.get("weekly_pi80_coverage")
|
| 1435 |
+
weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
|
| 1436 |
+
weekly_pi96 = metrics.get("weekly_pi96_coverage")
|
| 1437 |
+
weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
|
| 1438 |
weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
|
| 1439 |
+
weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
|
| 1440 |
weekly_gap = metrics.get("weekly_median_sort_gap_max")
|
| 1441 |
weekly_samples = metrics.get("weekly_sample_count")
|
| 1442 |
|
|
|
|
| 1451 |
weekly_magnitude_ratio=weekly_mr,
|
| 1452 |
weekly_tail_capture_rate=weekly_tail,
|
| 1453 |
weekly_pi80_coverage=weekly_pi80,
|
| 1454 |
+
weekly_pi80_width_ratio=weekly_pi80_width_ratio,
|
| 1455 |
+
weekly_pi96_coverage=weekly_pi96,
|
| 1456 |
+
weekly_pi96_width_ratio=weekly_pi96_width_ratio,
|
| 1457 |
weekly_quantile_crossing_rate=weekly_qcross,
|
| 1458 |
+
weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
|
| 1459 |
weekly_median_sort_gap_max=weekly_gap,
|
| 1460 |
weekly_sample_count=weekly_samples,
|
| 1461 |
)
|
|
|
|
| 1476 |
"weekly_magnitude_ratio": weekly_mr,
|
| 1477 |
"weekly_tail_capture_rate": weekly_tail,
|
| 1478 |
"weekly_pi80_coverage": weekly_pi80,
|
| 1479 |
+
"weekly_pi80_width_ratio": weekly_pi80_width_ratio,
|
| 1480 |
+
"weekly_pi96_coverage": weekly_pi96,
|
| 1481 |
+
"weekly_pi96_width_ratio": weekly_pi96_width_ratio,
|
| 1482 |
"weekly_quantile_crossing_rate": weekly_qcross,
|
| 1483 |
+
"weekly_sorted_quantile_crossing_rate": weekly_sorted_qcross,
|
| 1484 |
"weekly_median_sort_gap_max": weekly_gap,
|
| 1485 |
"weekly_sample_count": weekly_samples,
|
| 1486 |
}.items():
|
app/quality_gate.py
CHANGED
|
@@ -25,7 +25,11 @@ def evaluate_quality_gate(
|
|
| 25 |
weekly_magnitude_ratio: Optional[float] = None,
|
| 26 |
weekly_tail_capture_rate: Optional[float] = None,
|
| 27 |
weekly_pi80_coverage: Optional[float] = None,
|
|
|
|
|
|
|
|
|
|
| 28 |
weekly_quantile_crossing_rate: Optional[float] = None,
|
|
|
|
| 29 |
weekly_median_sort_gap_max: Optional[float] = None,
|
| 30 |
weekly_sample_count: Optional[int] = None,
|
| 31 |
) -> Tuple[bool, List[str]]:
|
|
@@ -64,10 +68,32 @@ def evaluate_quality_gate(
|
|
| 64 |
elif weekly_pi80_coverage < 0.74 or weekly_pi80_coverage > 0.86:
|
| 65 |
reasons.append(f"WeeklyPI80={weekly_pi80_coverage:.4f} outside [0.74, 0.86]")
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
if weekly_quantile_crossing_rate is None:
|
| 68 |
reasons.append("Missing weekly_quantile_crossing_rate")
|
| 69 |
-
elif weekly_quantile_crossing_rate > 0.
|
| 70 |
-
reasons.append(f"WeeklyQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.005:
|
| 73 |
reasons.append(f"WeeklyMedianSortGapMax={weekly_median_sort_gap_max:.4f} > 0.005")
|
|
|
|
| 25 |
weekly_magnitude_ratio: Optional[float] = None,
|
| 26 |
weekly_tail_capture_rate: Optional[float] = None,
|
| 27 |
weekly_pi80_coverage: Optional[float] = None,
|
| 28 |
+
weekly_pi80_width_ratio: Optional[float] = None,
|
| 29 |
+
weekly_pi96_coverage: Optional[float] = None,
|
| 30 |
+
weekly_pi96_width_ratio: Optional[float] = None,
|
| 31 |
weekly_quantile_crossing_rate: Optional[float] = None,
|
| 32 |
+
weekly_sorted_quantile_crossing_rate: Optional[float] = None,
|
| 33 |
weekly_median_sort_gap_max: Optional[float] = None,
|
| 34 |
weekly_sample_count: Optional[int] = None,
|
| 35 |
) -> Tuple[bool, List[str]]:
|
|
|
|
| 68 |
elif weekly_pi80_coverage < 0.74 or weekly_pi80_coverage > 0.86:
|
| 69 |
reasons.append(f"WeeklyPI80={weekly_pi80_coverage:.4f} outside [0.74, 0.86]")
|
| 70 |
|
| 71 |
+
if weekly_pi80_width_ratio is None:
|
| 72 |
+
reasons.append("Missing weekly_pi80_width_ratio")
|
| 73 |
+
elif weekly_pi80_width_ratio > 2.0 and weekly_pi80_coverage is not None and weekly_pi80_coverage > 0.86:
|
| 74 |
+
reasons.append(
|
| 75 |
+
f"WeeklyPI80Overwide={weekly_pi80_width_ratio:.4f} with coverage={weekly_pi80_coverage:.4f}"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
if weekly_pi96_coverage is None:
|
| 79 |
+
reasons.append("Missing weekly_pi96_coverage")
|
| 80 |
+
|
| 81 |
+
if weekly_pi96_width_ratio is None:
|
| 82 |
+
reasons.append("Missing weekly_pi96_width_ratio")
|
| 83 |
+
elif weekly_pi96_width_ratio > 3.0:
|
| 84 |
+
reasons.append(f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio:.4f} > 3.0")
|
| 85 |
+
|
| 86 |
if weekly_quantile_crossing_rate is None:
|
| 87 |
reasons.append("Missing weekly_quantile_crossing_rate")
|
| 88 |
+
elif weekly_quantile_crossing_rate > 0.05:
|
| 89 |
+
reasons.append(f"WeeklyQuantileCrossing={weekly_quantile_crossing_rate:.4f} > 0.05")
|
| 90 |
+
|
| 91 |
+
if weekly_sorted_quantile_crossing_rate is None:
|
| 92 |
+
reasons.append("Missing weekly_sorted_quantile_crossing_rate")
|
| 93 |
+
elif weekly_sorted_quantile_crossing_rate > 0.0:
|
| 94 |
+
reasons.append(
|
| 95 |
+
f"WeeklySortedQuantileCrossing={weekly_sorted_quantile_crossing_rate:.4f} > 0.0"
|
| 96 |
+
)
|
| 97 |
|
| 98 |
if weekly_median_sort_gap_max is not None and weekly_median_sort_gap_max > 0.005:
|
| 99 |
reasons.append(f"WeeklyMedianSortGapMax={weekly_median_sort_gap_max:.4f} > 0.005")
|
app/settings.py
CHANGED
|
@@ -63,6 +63,13 @@ class Settings(BaseSettings):
|
|
| 63 |
# API settings
|
| 64 |
analysis_ttl_minutes: int = 30
|
| 65 |
log_level: str = "INFO"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# NOTE: `futures_spot_adjustment` was removed 2026-04.
|
| 68 |
# It was an unused 1:1 scaling constant between HG=F and XCU/USD which
|
|
@@ -204,6 +211,16 @@ class Settings(BaseSettings):
|
|
| 204 |
Always uses env variable (14 symbols).
|
| 205 |
"""
|
| 206 |
return [s.strip() for s in self.yfinance_symbols.split(",") if s.strip()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
@property
|
| 209 |
def target_symbol(self) -> str:
|
|
|
|
| 63 |
# API settings
|
| 64 |
analysis_ttl_minutes: int = 30
|
| 65 |
log_level: str = "INFO"
|
| 66 |
+
environment: str = "development"
|
| 67 |
+
cors_allowed_origins: str = (
|
| 68 |
+
"http://localhost:3000,"
|
| 69 |
+
"http://localhost:5173,"
|
| 70 |
+
"http://127.0.0.1:3000,"
|
| 71 |
+
"http://127.0.0.1:5173"
|
| 72 |
+
)
|
| 73 |
|
| 74 |
# NOTE: `futures_spot_adjustment` was removed 2026-04.
|
| 75 |
# It was an unused 1:1 scaling constant between HG=F and XCU/USD which
|
|
|
|
| 211 |
Always uses env variable (14 symbols).
|
| 212 |
"""
|
| 213 |
return [s.strip() for s in self.yfinance_symbols.split(",") if s.strip()]
|
| 214 |
+
|
| 215 |
+
@property
|
| 216 |
+
def cors_allowed_origins_list(self) -> list[str]:
|
| 217 |
+
"""Parse CORS origins from comma-separated environment configuration."""
|
| 218 |
+
origins = [
|
| 219 |
+
origin.strip()
|
| 220 |
+
for origin in self.cors_allowed_origins.split(",")
|
| 221 |
+
if origin.strip()
|
| 222 |
+
]
|
| 223 |
+
return origins or ["http://localhost:3000", "http://localhost:5173"]
|
| 224 |
|
| 225 |
@property
|
| 226 |
def target_symbol(self) -> str:
|
deep_learning/config.py
CHANGED
|
@@ -136,15 +136,15 @@ class ASROConfig:
|
|
| 136 |
|
| 137 |
@dataclass(frozen=True)
|
| 138 |
class WeeklyLossConfig:
|
| 139 |
-
lambda_weekly_quantile: float = 0.
|
| 140 |
lambda_t1_quantile: float = 0.10
|
| 141 |
-
lambda_directional: float = 0.
|
| 142 |
-
lambda_magnitude: float = 0.
|
| 143 |
-
lambda_vol: float = 0.
|
| 144 |
-
lambda_crossing: float =
|
| 145 |
-
lambda_sanity: float = 0.
|
| 146 |
-
lambda_width: float = 0.
|
| 147 |
-
lambda_tail_width: float = 0.
|
| 148 |
|
| 149 |
|
| 150 |
@dataclass(frozen=True)
|
|
@@ -167,7 +167,7 @@ class TrainingConfig:
|
|
| 167 |
num_workers: int = 0
|
| 168 |
# 25→15: CI budget fix. 15 trials × 3 folds × 25 epochs ≈ 108 min;
|
| 169 |
# final trainer adds ~40-50 min → total ~155 min < 180 min limit.
|
| 170 |
-
optuna_n_trials: int =
|
| 171 |
# Walk-Forward temporal CV folds for hyperopt (REG-2026-001 P2).
|
| 172 |
# Set to 1 to disable CV and fall back to single-split behaviour.
|
| 173 |
cv_n_folds: int = 3
|
|
|
|
| 136 |
|
| 137 |
@dataclass(frozen=True)
|
| 138 |
class WeeklyLossConfig:
|
| 139 |
+
lambda_weekly_quantile: float = 0.60
|
| 140 |
lambda_t1_quantile: float = 0.10
|
| 141 |
+
lambda_directional: float = 0.10
|
| 142 |
+
lambda_magnitude: float = 0.55
|
| 143 |
+
lambda_vol: float = 0.35
|
| 144 |
+
lambda_crossing: float = 7.0
|
| 145 |
+
lambda_sanity: float = 0.20
|
| 146 |
+
lambda_width: float = 0.50
|
| 147 |
+
lambda_tail_width: float = 0.30
|
| 148 |
|
| 149 |
|
| 150 |
@dataclass(frozen=True)
|
|
|
|
| 167 |
num_workers: int = 0
|
| 168 |
# 25→15: CI budget fix. 15 trials × 3 folds × 25 epochs ≈ 108 min;
|
| 169 |
# final trainer adds ~40-50 min → total ~155 min < 180 min limit.
|
| 170 |
+
optuna_n_trials: int = 30
|
| 171 |
# Walk-Forward temporal CV folds for hyperopt (REG-2026-001 P2).
|
| 172 |
# Set to 1 to disable CV and fall back to single-split behaviour.
|
| 173 |
cv_n_folds: int = 3
|
deep_learning/data/feature_store.py
CHANGED
|
@@ -223,14 +223,21 @@ def _build_daily_embedding_features(
|
|
| 223 |
from app.models import NewsEmbedding, NewsProcessed, NewsRaw
|
| 224 |
from deep_learning.data.embeddings import bytes_to_embedding, aggregate_daily_embeddings
|
| 225 |
from pipelines.market_calendar import assign_market_date
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
rows = (
|
| 228 |
session.query(
|
| 229 |
NewsRaw.published_at,
|
|
|
|
| 230 |
NewsEmbedding.embedding_pca,
|
| 231 |
)
|
| 232 |
.join(NewsProcessed, NewsEmbedding.news_processed_id == NewsProcessed.id)
|
| 233 |
.join(NewsRaw, NewsProcessed.raw_id == NewsRaw.id)
|
|
|
|
| 234 |
.order_by(NewsRaw.published_at.asc())
|
| 235 |
.all()
|
| 236 |
)
|
|
@@ -241,7 +248,13 @@ def _build_daily_embedding_features(
|
|
| 241 |
|
| 242 |
date_groups: dict[str, list[np.ndarray]] = {}
|
| 243 |
for r in rows:
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
vec = bytes_to_embedding(r.embedding_pca, dim=pca_dim)
|
| 246 |
# bytes_to_embedding now always returns dim-length arrays, but
|
| 247 |
# guard against any future shape surprises to keep stack safe.
|
|
@@ -260,6 +273,10 @@ def _build_daily_embedding_features(
|
|
| 260 |
record[f"emb_pca_{i}"] = float(v)
|
| 261 |
records.append(record)
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
emb_df = pd.DataFrame(records).set_index("date").sort_index()
|
| 264 |
emb_df.index = pd.to_datetime(emb_df.index)
|
| 265 |
|
|
|
|
| 223 |
from app.models import NewsEmbedding, NewsProcessed, NewsRaw
|
| 224 |
from deep_learning.data.embeddings import bytes_to_embedding, aggregate_daily_embeddings
|
| 225 |
from pipelines.market_calendar import assign_market_date
|
| 226 |
+
from sqlalchemy import func
|
| 227 |
+
|
| 228 |
+
start_date = index.min().to_pydatetime()
|
| 229 |
+
end_date = index.max().to_pydatetime()
|
| 230 |
+
available_expr = func.coalesce(NewsRaw.fetched_at, NewsRaw.published_at)
|
| 231 |
|
| 232 |
rows = (
|
| 233 |
session.query(
|
| 234 |
NewsRaw.published_at,
|
| 235 |
+
NewsRaw.fetched_at,
|
| 236 |
NewsEmbedding.embedding_pca,
|
| 237 |
)
|
| 238 |
.join(NewsProcessed, NewsEmbedding.news_processed_id == NewsProcessed.id)
|
| 239 |
.join(NewsRaw, NewsProcessed.raw_id == NewsRaw.id)
|
| 240 |
+
.filter(NewsRaw.published_at <= end_date, available_expr <= end_date)
|
| 241 |
.order_by(NewsRaw.published_at.asc())
|
| 242 |
.all()
|
| 243 |
)
|
|
|
|
| 248 |
|
| 249 |
date_groups: dict[str, list[np.ndarray]] = {}
|
| 250 |
for r in rows:
|
| 251 |
+
published_market_date = assign_market_date(r.published_at)
|
| 252 |
+
available_at = getattr(r, "fetched_at", None) or r.published_at
|
| 253 |
+
available_market_date = assign_market_date(available_at)
|
| 254 |
+
market_date = max(published_market_date, available_market_date)
|
| 255 |
+
if not (pd.Timestamp(start_date).date() <= market_date <= pd.Timestamp(end_date).date()):
|
| 256 |
+
continue
|
| 257 |
+
d = market_date.isoformat()
|
| 258 |
vec = bytes_to_embedding(r.embedding_pca, dim=pca_dim)
|
| 259 |
# bytes_to_embedding now always returns dim-length arrays, but
|
| 260 |
# guard against any future shape surprises to keep stack safe.
|
|
|
|
| 273 |
record[f"emb_pca_{i}"] = float(v)
|
| 274 |
records.append(record)
|
| 275 |
|
| 276 |
+
if not records:
|
| 277 |
+
cols = [f"emb_pca_{i}" for i in range(pca_dim)]
|
| 278 |
+
return pd.DataFrame(0.0, index=index, columns=cols)
|
| 279 |
+
|
| 280 |
emb_df = pd.DataFrame(records).set_index("date").sort_index()
|
| 281 |
emb_df.index = pd.to_datetime(emb_df.index)
|
| 282 |
|
deep_learning/data/sentiment_market_date.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import pandas as pd
|
|
|
|
| 6 |
|
| 7 |
from app.models import NewsProcessed, NewsRaw, NewsSentimentV2
|
| 8 |
from pipelines.market_calendar import assign_market_date, is_after_close_news
|
|
@@ -12,6 +13,30 @@ MATERIAL_RELEVANCE_MIN = 0.60
|
|
| 12 |
MATERIAL_CONFIDENCE_MIN = 0.55
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def build_market_date_sentiment_frame(session, start_date, end_date) -> pd.DataFrame:
|
| 16 |
"""Return daily sentiment indexed by market date, not publication date."""
|
| 17 |
columns = [
|
|
@@ -29,6 +54,7 @@ def build_market_date_sentiment_frame(session, start_date, end_date) -> pd.DataF
|
|
| 29 |
session.query(
|
| 30 |
NewsRaw.published_at,
|
| 31 |
NewsRaw.fetched_at,
|
|
|
|
| 32 |
NewsSentimentV2.final_score,
|
| 33 |
NewsSentimentV2.confidence_calibrated,
|
| 34 |
NewsSentimentV2.relevance_score,
|
|
@@ -36,7 +62,15 @@ def build_market_date_sentiment_frame(session, start_date, end_date) -> pd.DataF
|
|
| 36 |
)
|
| 37 |
.join(NewsProcessed, NewsProcessed.raw_id == NewsRaw.id)
|
| 38 |
.join(NewsSentimentV2, NewsSentimentV2.news_processed_id == NewsProcessed.id)
|
| 39 |
-
.filter(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
.all()
|
| 41 |
)
|
| 42 |
|
|
@@ -45,7 +79,9 @@ def build_market_date_sentiment_frame(session, start_date, end_date) -> pd.DataF
|
|
| 45 |
|
| 46 |
records = []
|
| 47 |
for r in rows:
|
| 48 |
-
market_date =
|
|
|
|
|
|
|
| 49 |
relevance = float(r.relevance_score or 0.0)
|
| 50 |
confidence = float(r.confidence_calibrated or 0.0)
|
| 51 |
material = relevance >= MATERIAL_RELEVANCE_MIN and confidence >= MATERIAL_CONFIDENCE_MIN
|
|
@@ -61,6 +97,8 @@ def build_market_date_sentiment_frame(session, start_date, end_date) -> pd.DataF
|
|
| 61 |
)
|
| 62 |
|
| 63 |
raw = pd.DataFrame(records)
|
|
|
|
|
|
|
| 64 |
|
| 65 |
def _weighted_sentiment(g: pd.DataFrame) -> float:
|
| 66 |
denom = g["weight"].sum()
|
|
@@ -105,11 +143,21 @@ def build_market_date_event_counts_from_db(session, start_date, end_date) -> pd.
|
|
| 105 |
rows = (
|
| 106 |
session.query(
|
| 107 |
NewsRaw.published_at,
|
|
|
|
|
|
|
| 108 |
NewsSentimentV2.event_type,
|
| 109 |
)
|
| 110 |
.join(NewsProcessed, NewsSentimentV2.news_processed_id == NewsProcessed.id)
|
| 111 |
.join(NewsRaw, NewsProcessed.raw_id == NewsRaw.id)
|
| 112 |
-
.filter(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
.all()
|
| 114 |
)
|
| 115 |
if not rows:
|
|
@@ -117,12 +165,15 @@ def build_market_date_event_counts_from_db(session, start_date, end_date) -> pd.
|
|
| 117 |
|
| 118 |
records = [
|
| 119 |
{
|
| 120 |
-
"market_date":
|
| 121 |
"event_type": r.event_type,
|
| 122 |
"count": 1,
|
| 123 |
}
|
| 124 |
for r in rows
|
|
|
|
| 125 |
]
|
|
|
|
|
|
|
| 126 |
df = pd.DataFrame(records)
|
| 127 |
pivot = df.pivot_table(index="market_date", columns="event_type", values="count", aggfunc="sum", fill_value=0)
|
| 128 |
pivot.index = pd.to_datetime(pivot.index)
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import pandas as pd
|
| 6 |
+
from sqlalchemy import func
|
| 7 |
|
| 8 |
from app.models import NewsProcessed, NewsRaw, NewsSentimentV2
|
| 9 |
from pipelines.market_calendar import assign_market_date, is_after_close_news
|
|
|
|
| 13 |
MATERIAL_CONFIDENCE_MIN = 0.55
|
| 14 |
|
| 15 |
|
| 16 |
+
def _effective_available_at(row) -> object:
|
| 17 |
+
return (
|
| 18 |
+
getattr(row, "available_at", None)
|
| 19 |
+
or getattr(row, "fetched_at", None)
|
| 20 |
+
or getattr(row, "published_at", None)
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _effective_market_date(row):
|
| 25 |
+
"""Map news to the later of publication market date and availability date."""
|
| 26 |
+
published_market_date = assign_market_date(row.published_at)
|
| 27 |
+
available_at = _effective_available_at(row)
|
| 28 |
+
if available_at is None:
|
| 29 |
+
return published_market_date
|
| 30 |
+
available_market_date = assign_market_date(available_at)
|
| 31 |
+
return max(published_market_date, available_market_date)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _within_window(market_date, start_date, end_date) -> bool:
|
| 35 |
+
start = pd.Timestamp(start_date).date()
|
| 36 |
+
end = pd.Timestamp(end_date).date()
|
| 37 |
+
return start <= market_date <= end
|
| 38 |
+
|
| 39 |
+
|
| 40 |
def build_market_date_sentiment_frame(session, start_date, end_date) -> pd.DataFrame:
|
| 41 |
"""Return daily sentiment indexed by market date, not publication date."""
|
| 42 |
columns = [
|
|
|
|
| 54 |
session.query(
|
| 55 |
NewsRaw.published_at,
|
| 56 |
NewsRaw.fetched_at,
|
| 57 |
+
NewsSentimentV2.available_at,
|
| 58 |
NewsSentimentV2.final_score,
|
| 59 |
NewsSentimentV2.confidence_calibrated,
|
| 60 |
NewsSentimentV2.relevance_score,
|
|
|
|
| 62 |
)
|
| 63 |
.join(NewsProcessed, NewsProcessed.raw_id == NewsRaw.id)
|
| 64 |
.join(NewsSentimentV2, NewsSentimentV2.news_processed_id == NewsProcessed.id)
|
| 65 |
+
.filter(
|
| 66 |
+
NewsRaw.published_at <= end_date,
|
| 67 |
+
func.coalesce(
|
| 68 |
+
NewsSentimentV2.available_at,
|
| 69 |
+
NewsRaw.fetched_at,
|
| 70 |
+
NewsRaw.published_at,
|
| 71 |
+
)
|
| 72 |
+
<= end_date,
|
| 73 |
+
)
|
| 74 |
.all()
|
| 75 |
)
|
| 76 |
|
|
|
|
| 79 |
|
| 80 |
records = []
|
| 81 |
for r in rows:
|
| 82 |
+
market_date = _effective_market_date(r)
|
| 83 |
+
if not _within_window(market_date, start_date, end_date):
|
| 84 |
+
continue
|
| 85 |
relevance = float(r.relevance_score or 0.0)
|
| 86 |
confidence = float(r.confidence_calibrated or 0.0)
|
| 87 |
material = relevance >= MATERIAL_RELEVANCE_MIN and confidence >= MATERIAL_CONFIDENCE_MIN
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
raw = pd.DataFrame(records)
|
| 100 |
+
if raw.empty:
|
| 101 |
+
return pd.DataFrame(columns=columns)
|
| 102 |
|
| 103 |
def _weighted_sentiment(g: pd.DataFrame) -> float:
|
| 104 |
denom = g["weight"].sum()
|
|
|
|
| 143 |
rows = (
|
| 144 |
session.query(
|
| 145 |
NewsRaw.published_at,
|
| 146 |
+
NewsRaw.fetched_at,
|
| 147 |
+
NewsSentimentV2.available_at,
|
| 148 |
NewsSentimentV2.event_type,
|
| 149 |
)
|
| 150 |
.join(NewsProcessed, NewsSentimentV2.news_processed_id == NewsProcessed.id)
|
| 151 |
.join(NewsRaw, NewsProcessed.raw_id == NewsRaw.id)
|
| 152 |
+
.filter(
|
| 153 |
+
NewsRaw.published_at <= end_date,
|
| 154 |
+
func.coalesce(
|
| 155 |
+
NewsSentimentV2.available_at,
|
| 156 |
+
NewsRaw.fetched_at,
|
| 157 |
+
NewsRaw.published_at,
|
| 158 |
+
)
|
| 159 |
+
<= end_date,
|
| 160 |
+
)
|
| 161 |
.all()
|
| 162 |
)
|
| 163 |
if not rows:
|
|
|
|
| 165 |
|
| 166 |
records = [
|
| 167 |
{
|
| 168 |
+
"market_date": _effective_market_date(r),
|
| 169 |
"event_type": r.event_type,
|
| 170 |
"count": 1,
|
| 171 |
}
|
| 172 |
for r in rows
|
| 173 |
+
if _within_window(_effective_market_date(r), start_date, end_date)
|
| 174 |
]
|
| 175 |
+
if not records:
|
| 176 |
+
return pd.DataFrame()
|
| 177 |
df = pd.DataFrame(records)
|
| 178 |
pivot = df.pivot_table(index="market_date", columns="event_type", values="count", aggfunc="sum", fill_value=0)
|
| 179 |
pivot.index = pd.to_datetime(pivot.index)
|
deep_learning/inference/predictor.py
CHANGED
|
@@ -134,6 +134,20 @@ class TFTPredictor:
|
|
| 134 |
return
|
| 135 |
|
| 136 |
metadata_path = Path(self._checkpoint_path).parent / "tft_metadata.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
if not metadata_path.exists():
|
| 138 |
raise IncompatibleTFTCheckpointError(
|
| 139 |
"Incompatible TFT checkpoint: missing weekly_log_v1 metadata. Retraining required."
|
|
|
|
| 134 |
return
|
| 135 |
|
| 136 |
metadata_path = Path(self._checkpoint_path).parent / "tft_metadata.json"
|
| 137 |
+
try:
|
| 138 |
+
from deep_learning.models.hub import validate_artifact_manifest
|
| 139 |
+
|
| 140 |
+
if not validate_artifact_manifest(metadata_path.parent):
|
| 141 |
+
raise IncompatibleTFTCheckpointError(
|
| 142 |
+
"Incompatible TFT checkpoint: missing or invalid artifact manifest. Retraining required."
|
| 143 |
+
)
|
| 144 |
+
except IncompatibleTFTCheckpointError:
|
| 145 |
+
raise
|
| 146 |
+
except Exception as exc:
|
| 147 |
+
raise IncompatibleTFTCheckpointError(
|
| 148 |
+
f"Incompatible TFT checkpoint: artifact manifest validation failed ({exc}). Retraining required."
|
| 149 |
+
) from exc
|
| 150 |
+
|
| 151 |
if not metadata_path.exists():
|
| 152 |
raise IncompatibleTFTCheckpointError(
|
| 153 |
"Incompatible TFT checkpoint: missing weekly_log_v1 metadata. Retraining required."
|
deep_learning/models/hub.py
CHANGED
|
@@ -11,6 +11,8 @@ from __future__ import annotations
|
|
| 11 |
import json
|
| 12 |
import logging
|
| 13 |
import os
|
|
|
|
|
|
|
| 14 |
from pathlib import Path
|
| 15 |
from typing import Optional
|
| 16 |
|
|
@@ -24,11 +26,13 @@ _ARTIFACTS = [
|
|
| 24 |
"conformal_calibration.json",
|
| 25 |
"pca_finbert.joblib",
|
| 26 |
"optuna_results.json",
|
|
|
|
| 27 |
]
|
| 28 |
|
| 29 |
_REQUIRED_ARTIFACTS = [
|
| 30 |
"best_tft_asro.ckpt",
|
| 31 |
"tft_metadata.json",
|
|
|
|
| 32 |
]
|
| 33 |
|
| 34 |
|
|
@@ -36,6 +40,82 @@ def _get_token() -> Optional[str]:
|
|
| 36 |
return os.environ.get(_HF_TOKEN_ENV)
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def _metadata_contract_valid(metadata_path: Path) -> bool:
|
| 40 |
"""Return True when metadata proves the current weekly TFT contract."""
|
| 41 |
if not metadata_path.exists():
|
|
@@ -94,6 +174,9 @@ def validate_tft_artifact_set(local_dir: str | Path) -> bool:
|
|
| 94 |
logger.warning("TFT artifact set has incompatible metadata in %s", local_dir)
|
| 95 |
return False
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
return True
|
| 98 |
|
| 99 |
|
|
@@ -114,13 +197,18 @@ def upload_tft_artifacts(
|
|
| 114 |
return False
|
| 115 |
|
| 116 |
local_dir = Path(local_dir)
|
| 117 |
-
if not
|
| 118 |
logger.warning(
|
| 119 |
"TFT artifact set in %s is not contract-complete; upload skipped",
|
| 120 |
local_dir,
|
| 121 |
)
|
| 122 |
return False
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
files_to_upload = [
|
| 125 |
local_dir / name for name in _ARTIFACTS if (local_dir / name).exists()
|
| 126 |
]
|
|
@@ -175,6 +263,9 @@ def download_tft_artifacts(
|
|
| 175 |
metadata_path = local_dir / "tft_metadata.json"
|
| 176 |
if metadata_path.exists() and not _metadata_contract_valid(metadata_path):
|
| 177 |
force_download.add("tft_metadata.json")
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
try:
|
| 180 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 11 |
import json
|
| 12 |
import logging
|
| 13 |
import os
|
| 14 |
+
import hashlib
|
| 15 |
+
from datetime import datetime, timezone
|
| 16 |
from pathlib import Path
|
| 17 |
from typing import Optional
|
| 18 |
|
|
|
|
| 26 |
"conformal_calibration.json",
|
| 27 |
"pca_finbert.joblib",
|
| 28 |
"optuna_results.json",
|
| 29 |
+
"artifact_manifest.json",
|
| 30 |
]
|
| 31 |
|
| 32 |
_REQUIRED_ARTIFACTS = [
|
| 33 |
"best_tft_asro.ckpt",
|
| 34 |
"tft_metadata.json",
|
| 35 |
+
"artifact_manifest.json",
|
| 36 |
]
|
| 37 |
|
| 38 |
|
|
|
|
| 40 |
return os.environ.get(_HF_TOKEN_ENV)
|
| 41 |
|
| 42 |
|
| 43 |
+
def _sha256_file(path: Path) -> str:
|
| 44 |
+
digest = hashlib.sha256()
|
| 45 |
+
with path.open("rb") as fh:
|
| 46 |
+
for chunk in iter(lambda: fh.read(1024 * 1024), b""):
|
| 47 |
+
digest.update(chunk)
|
| 48 |
+
return digest.hexdigest()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def build_artifact_manifest(local_dir: str | Path) -> dict:
    """Describe every TFT artifact present in *local_dir* as a manifest dict.

    Each name in ``_ARTIFACTS`` (other than the manifest file itself) that
    exists on disk gets an entry holding its SHA-256 digest, its size in
    bytes, and whether it is one of the two required artifacts. Names that
    are not present are left out.

    Returns:
        A dict with ``manifest_version``, a UTC ISO-8601 ``generated_at``
        timestamp, and the per-file ``artifacts`` mapping.
    """
    base = Path(local_dir)
    required_names = {"best_tft_asro.ckpt", "tft_metadata.json"}

    # The manifest cannot contain its own hash, so it is excluded up front.
    candidates = (
        (name, base / name)
        for name in _ARTIFACTS
        if name != "artifact_manifest.json"
    )
    entries = {
        name: {
            "sha256": _sha256_file(path),
            "size_bytes": path.stat().st_size,
            "required": name in required_names,
        }
        for name, path in candidates
        if path.exists()
    }

    return {
        "manifest_version": 1,
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "artifacts": entries,
    }
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def write_artifact_manifest(local_dir: str | Path) -> Path:
    """Generate the integrity manifest and save it beside the TFT artifacts.

    The manifest produced by ``build_artifact_manifest`` is serialized as
    indented JSON with sorted keys and written, UTF-8 encoded, to
    ``artifact_manifest.json`` inside *local_dir*.

    Returns:
        The path of the manifest file that was written.
    """
    artifact_dir = Path(local_dir)
    target = artifact_dir / "artifact_manifest.json"
    payload = json.dumps(
        build_artifact_manifest(artifact_dir),
        indent=2,
        sort_keys=True,
    )
    target.write_text(payload, encoding="utf-8")
    return target
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def validate_artifact_manifest(local_dir: str | Path) -> bool:
    """Check the artifacts in *local_dir* against ``artifact_manifest.json``.

    Validation succeeds only when the manifest exists, lists both required
    artifacts (checkpoint and metadata), and every listed entry is a plain
    file name that exists in *local_dir* with the recorded size (when one
    is recorded) and SHA-256 digest.

    Every failure is logged as a warning and reported as ``False``; this
    function does not raise.

    NOTE(review): the manifest is stored next to the files it describes and
    is not signed, so it detects corruption or partial downloads rather than
    deliberate tampering by someone who can rewrite the manifest as well.

    Args:
        local_dir: Directory holding the artifacts and their manifest.

    Returns:
        ``True`` when all checks pass, otherwise ``False``.
    """
    local_dir = Path(local_dir)
    manifest_path = local_dir / "artifact_manifest.json"
    if not manifest_path.exists():
        logger.warning("TFT artifact manifest missing in %s", local_dir)
        return False

    try:
        data = json.loads(manifest_path.read_text(encoding="utf-8"))
        artifacts = data.get("artifacts") or {}
        for required in ("best_tft_asro.ckpt", "tft_metadata.json"):
            if required not in artifacts:
                logger.warning("TFT artifact manifest missing required entry: %s", required)
                return False

        for name, meta in artifacts.items():
            # Entry names come from the manifest file, not from this module.
            # Accept plain file names only so an entry such as "../x" or an
            # absolute path cannot make us read outside local_dir.
            if Path(name).name != name:
                logger.warning("TFT artifact manifest has unsafe entry name: %s", name)
                return False

            path = local_dir / name
            # is_file() also rejects directories, which could not be hashed.
            if not path.is_file():
                logger.warning("TFT artifact listed in manifest is missing: %s", name)
                return False

            # Compare sizes first: it is a single stat() call, whereas the
            # digest requires reading the whole (possibly large) file.
            expected_size = meta.get("size_bytes")
            if expected_size is not None and int(expected_size) != path.stat().st_size:
                logger.warning("TFT artifact size mismatch for %s", name)
                return False

            expected = str(meta.get("sha256", "")).lower()
            actual = _sha256_file(path).lower()
            if not expected or actual != expected:
                logger.warning("TFT artifact hash mismatch for %s", name)
                return False
        return True
    except Exception as exc:
        # Deliberately broad: malformed JSON, unexpected value types and I/O
        # errors all mean the artifact set cannot be trusted.
        logger.warning("TFT artifact manifest validation failed: %s", exc)
        return False
|
| 117 |
+
|
| 118 |
+
|
| 119 |
def _metadata_contract_valid(metadata_path: Path) -> bool:
|
| 120 |
"""Return True when metadata proves the current weekly TFT contract."""
|
| 121 |
if not metadata_path.exists():
|
|
|
|
| 174 |
logger.warning("TFT artifact set has incompatible metadata in %s", local_dir)
|
| 175 |
return False
|
| 176 |
|
| 177 |
+
if not validate_artifact_manifest(local_dir):
|
| 178 |
+
return False
|
| 179 |
+
|
| 180 |
return True
|
| 181 |
|
| 182 |
|
|
|
|
| 197 |
return False
|
| 198 |
|
| 199 |
local_dir = Path(local_dir)
|
| 200 |
+
if not _metadata_contract_valid(local_dir / "tft_metadata.json"):
|
| 201 |
logger.warning(
|
| 202 |
"TFT artifact set in %s is not contract-complete; upload skipped",
|
| 203 |
local_dir,
|
| 204 |
)
|
| 205 |
return False
|
| 206 |
|
| 207 |
+
write_artifact_manifest(local_dir)
|
| 208 |
+
if not validate_tft_artifact_set(local_dir):
|
| 209 |
+
logger.warning("TFT artifact manifest validation failed before upload")
|
| 210 |
+
return False
|
| 211 |
+
|
| 212 |
files_to_upload = [
|
| 213 |
local_dir / name for name in _ARTIFACTS if (local_dir / name).exists()
|
| 214 |
]
|
|
|
|
| 263 |
metadata_path = local_dir / "tft_metadata.json"
|
| 264 |
if metadata_path.exists() and not _metadata_contract_valid(metadata_path):
|
| 265 |
force_download.add("tft_metadata.json")
|
| 266 |
+
manifest_path = local_dir / "artifact_manifest.json"
|
| 267 |
+
if manifest_path.exists() and not validate_artifact_manifest(local_dir):
|
| 268 |
+
force_download.update(_ARTIFACTS)
|
| 269 |
|
| 270 |
try:
|
| 271 |
from huggingface_hub import hf_hub_download
|
deep_learning/models/tft_copper.py
CHANGED
|
@@ -131,15 +131,15 @@ try:
|
|
| 131 |
def __init__(
|
| 132 |
self,
|
| 133 |
quantiles: list,
|
| 134 |
-
lambda_weekly_quantile: float = 0.
|
| 135 |
lambda_t1_quantile: float = 0.10,
|
| 136 |
-
lambda_directional: float = 0.
|
| 137 |
-
lambda_magnitude: float = 0.
|
| 138 |
-
lambda_vol: float = 0.
|
| 139 |
-
lambda_crossing: float =
|
| 140 |
-
lambda_sanity: float = 0.
|
| 141 |
-
lambda_width: float = 0.
|
| 142 |
-
lambda_tail_width: float = 0.
|
| 143 |
sharpe_eps: float = 1e-6,
|
| 144 |
daily_log_return_bound: float = 0.08,
|
| 145 |
weekly_log_return_bound: float = 0.20,
|
|
@@ -193,14 +193,21 @@ try:
|
|
| 193 |
|
| 194 |
abs_actual = actual_weekly.abs()
|
| 195 |
material_mask = abs_actual > (abs_actual.median() + self.sharpe_eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
if material_mask.any():
|
| 197 |
pred_abs = pred_weekly_median[material_mask].abs()
|
| 198 |
true_abs = actual_weekly[material_mask].abs()
|
| 199 |
-
|
| 200 |
torch.log((pred_abs + self.sharpe_eps) / (true_abs + self.sharpe_eps))
|
| 201 |
).mean()
|
| 202 |
else:
|
| 203 |
-
|
|
|
|
| 204 |
|
| 205 |
weekly_spread = (
|
| 206 |
pred_weekly_quantiles[:, self._q90_idx]
|
|
@@ -211,16 +218,19 @@ try:
|
|
| 211 |
mean_weekly_spread = weekly_spread.mean()
|
| 212 |
vol_loss = torch.abs(mean_weekly_spread - target_spread)
|
| 213 |
width_ratio = mean_weekly_spread / (target_spread + self.sharpe_eps)
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
|
| 216 |
weekly_tail_spread = (
|
| 217 |
pred_weekly_quantiles[:, self._q98_idx]
|
| 218 |
- pred_weekly_quantiles[:, self._q02_idx]
|
| 219 |
)
|
| 220 |
target_tail_spread = 4.10 * actual_weekly_std
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
)
|
|
|
|
| 224 |
daily_crossing_loss = quantile_crossing_penalty(y_pred)
|
| 225 |
weekly_crossing_loss = quantile_crossing_penalty(pred_weekly_quantiles.unsqueeze(1))
|
| 226 |
crossing_loss = daily_crossing_loss + weekly_crossing_loss
|
|
|
|
| 131 |
def __init__(
|
| 132 |
self,
|
| 133 |
quantiles: list,
|
| 134 |
+
lambda_weekly_quantile: float = 0.60,
|
| 135 |
lambda_t1_quantile: float = 0.10,
|
| 136 |
+
lambda_directional: float = 0.10,
|
| 137 |
+
lambda_magnitude: float = 0.55,
|
| 138 |
+
lambda_vol: float = 0.35,
|
| 139 |
+
lambda_crossing: float = 7.0,
|
| 140 |
+
lambda_sanity: float = 0.20,
|
| 141 |
+
lambda_width: float = 0.50,
|
| 142 |
+
lambda_tail_width: float = 0.30,
|
| 143 |
sharpe_eps: float = 1e-6,
|
| 144 |
daily_log_return_bound: float = 0.08,
|
| 145 |
weekly_log_return_bound: float = 0.20,
|
|
|
|
| 193 |
|
| 194 |
abs_actual = actual_weekly.abs()
|
| 195 |
material_mask = abs_actual > (abs_actual.median() + self.sharpe_eps)
|
| 196 |
+
global_magnitude_loss = torch.abs(
|
| 197 |
+
torch.log(
|
| 198 |
+
(pred_weekly_median.abs() + self.sharpe_eps)
|
| 199 |
+
/ (actual_weekly.abs() + self.sharpe_eps)
|
| 200 |
+
)
|
| 201 |
+
).mean()
|
| 202 |
if material_mask.any():
|
| 203 |
pred_abs = pred_weekly_median[material_mask].abs()
|
| 204 |
true_abs = actual_weekly[material_mask].abs()
|
| 205 |
+
material_magnitude_loss = torch.abs(
|
| 206 |
torch.log((pred_abs + self.sharpe_eps) / (true_abs + self.sharpe_eps))
|
| 207 |
).mean()
|
| 208 |
else:
|
| 209 |
+
material_magnitude_loss = y_pred.new_tensor(0.0)
|
| 210 |
+
magnitude_loss = 0.5 * global_magnitude_loss + 0.5 * material_magnitude_loss
|
| 211 |
|
| 212 |
weekly_spread = (
|
| 213 |
pred_weekly_quantiles[:, self._q90_idx]
|
|
|
|
| 218 |
mean_weekly_spread = weekly_spread.mean()
|
| 219 |
vol_loss = torch.abs(mean_weekly_spread - target_spread)
|
| 220 |
width_ratio = mean_weekly_spread / (target_spread + self.sharpe_eps)
|
| 221 |
+
safe_width_ratio = torch.clamp(width_ratio + self.sharpe_eps, min=1e-6)
|
| 222 |
+
width_loss = torch.abs(torch.log(safe_width_ratio))
|
| 223 |
+
width_loss = width_loss + torch.relu(width_ratio - 2.0).pow(2)
|
| 224 |
|
| 225 |
weekly_tail_spread = (
|
| 226 |
pred_weekly_quantiles[:, self._q98_idx]
|
| 227 |
- pred_weekly_quantiles[:, self._q02_idx]
|
| 228 |
)
|
| 229 |
target_tail_spread = 4.10 * actual_weekly_std
|
| 230 |
+
tail_width_ratio = weekly_tail_spread.mean() / (target_tail_spread + self.sharpe_eps)
|
| 231 |
+
safe_tail_width_ratio = torch.clamp(tail_width_ratio + self.sharpe_eps, min=1e-6)
|
| 232 |
+
tail_width_loss = torch.abs(torch.log(safe_tail_width_ratio))
|
| 233 |
+
tail_width_loss = tail_width_loss + torch.relu(tail_width_ratio - 3.0).pow(2)
|
| 234 |
daily_crossing_loss = quantile_crossing_penalty(y_pred)
|
| 235 |
weekly_crossing_loss = quantile_crossing_penalty(pred_weekly_quantiles.unsqueeze(1))
|
| 236 |
crossing_loss = daily_crossing_loss + weekly_crossing_loss
|
deep_learning/training/hyperopt.py
CHANGED
|
@@ -38,7 +38,7 @@ from deep_learning.config import (
|
|
| 38 |
|
| 39 |
logger = logging.getLogger(__name__)
|
| 40 |
|
| 41 |
-
MIN_COMPLETED_TRIALS =
|
| 42 |
SHARPE_PRUNE_THRESHOLD = -0.3
|
| 43 |
FOLD_SHARPE_PRUNE_THRESHOLD = -1.0
|
| 44 |
|
|
@@ -55,10 +55,14 @@ KNOWN_GOOD_TRIAL_PARAMS = {
|
|
| 55 |
"lambda_quantile": 0.25,
|
| 56 |
"lambda_madl": 0.40,
|
| 57 |
"lambda_weekly_quantile": 0.60,
|
| 58 |
-
"lambda_t1_quantile": 0.
|
| 59 |
"lambda_directional": 0.10,
|
| 60 |
-
"lambda_magnitude": 0.
|
| 61 |
"weekly_lambda_vol": 0.35,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
"batch_size": 32,
|
| 63 |
}
|
| 64 |
|
|
@@ -129,6 +133,8 @@ def _build_prune_diagnostics(study) -> tuple[dict[str, int], list[dict]]:
|
|
| 129 |
"weekly_magnitude_collapse": 0,
|
| 130 |
"weekly_magnitude_explosion": 0,
|
| 131 |
"weekly_interval_width_explosion": 0,
|
|
|
|
|
|
|
| 132 |
"weekly_overcoverage_width_explosion": 0,
|
| 133 |
"error": 0,
|
| 134 |
}
|
|
@@ -142,7 +148,11 @@ def _build_prune_diagnostics(study) -> tuple[dict[str, int], list[dict]]:
|
|
| 142 |
"avg_weekly_magnitude_ratio",
|
| 143 |
"avg_weekly_pi80_coverage",
|
| 144 |
"avg_weekly_pi80_width_ratio",
|
|
|
|
|
|
|
|
|
|
| 145 |
"avg_weekly_interval_score_80",
|
|
|
|
| 146 |
"fold_score_std",
|
| 147 |
)
|
| 148 |
|
|
@@ -217,32 +227,32 @@ def _enqueue_known_good_trial(study, base_cfg: TFTASROConfig) -> bool:
|
|
| 217 |
def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
| 218 |
"""Map an Optuna trial to a TFT-ASRO configuration."""
|
| 219 |
model_cfg = TFTModelConfig(
|
| 220 |
-
max_encoder_length=trial.
|
| 221 |
max_prediction_length=base_cfg.model.max_prediction_length,
|
| 222 |
# Post-MRMR pruning (~60-80 features), smaller models generalise better.
|
| 223 |
# 24 is viable now that feature count dropped from 200+ to ~60-80.
|
| 224 |
-
hidden_size=trial.
|
| 225 |
-
attention_head_size=trial.
|
| 226 |
# Floor at 0.20: 313 samples with dropout<0.20 causes co-adaptation
|
| 227 |
# and memorization (REG-2026-001). Cap at 0.35: dropout>0.35 with
|
| 228 |
# small hidden_size collapses the output range.
|
| 229 |
-
dropout=trial.
|
| 230 |
# Paired reduction: with hidden=24-48 and ~60-80 features,
|
| 231 |
# 8-16 is the sweet spot for continuous variable processing.
|
| 232 |
-
hidden_continuous_size=trial.
|
| 233 |
quantiles=base_cfg.model.quantiles,
|
| 234 |
# Range [1e-4, 1e-3]: LR < 1e-4 produces near-zero pred_std (VR=0.14);
|
| 235 |
# LR > 1e-3 causes 1-epoch divergence. This band is the stable zone.
|
| 236 |
-
learning_rate=trial.suggest_float("learning_rate", 1e-4,
|
| 237 |
reduce_on_plateau_patience=4,
|
| 238 |
-
gradient_clip_val=trial.
|
| 239 |
-
weight_decay=trial.suggest_float("weight_decay", 1e-5,
|
| 240 |
)
|
| 241 |
|
| 242 |
asro_cfg = ASROConfig(
|
| 243 |
# Floor at 0.25: three Optuna runs consistently selected 0.30-0.35.
|
| 244 |
# Lower values let the model collapse to near-zero pred_std.
|
| 245 |
-
lambda_vol=trial.suggest_float("lambda_vol", 0.
|
| 246 |
# lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q)
|
| 247 |
# Capped at 0.40 to ensure Sharpe (directional) component always has
|
| 248 |
# ≥60% weight. Higher values caused the "perfect calibration, coin-flip
|
|
@@ -250,20 +260,20 @@ def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
|
| 250 |
# expense of directional signal.
|
| 251 |
lambda_quantile=trial.suggest_float("lambda_quantile", 0.25, 0.4, step=0.05),
|
| 252 |
# MADL weight: how much the directional loss contributes relative to Sharpe.
|
| 253 |
-
lambda_madl=trial.suggest_float("lambda_madl", 0.
|
| 254 |
risk_free_rate=0.0,
|
| 255 |
)
|
| 256 |
|
| 257 |
weekly_loss_cfg = WeeklyLossConfig(
|
| 258 |
-
lambda_weekly_quantile=trial.suggest_float("lambda_weekly_quantile", 0.
|
| 259 |
-
lambda_t1_quantile=trial.suggest_float("lambda_t1_quantile", 0.
|
| 260 |
-
lambda_directional=trial.suggest_float("lambda_directional", 0.05, 0.
|
| 261 |
-
lambda_magnitude=trial.suggest_float("lambda_magnitude", 0.
|
| 262 |
lambda_vol=trial.suggest_float("weekly_lambda_vol", 0.25, 0.45, step=0.05),
|
| 263 |
-
lambda_crossing=
|
| 264 |
-
lambda_sanity=
|
| 265 |
-
lambda_width=
|
| 266 |
-
lambda_tail_width=
|
| 267 |
)
|
| 268 |
|
| 269 |
training_cfg = TrainingConfig(
|
|
@@ -360,7 +370,11 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 360 |
fold_weekly_mr_list: list[float] = []
|
| 361 |
fold_weekly_pi80_coverage_list: list[float] = []
|
| 362 |
fold_weekly_pi80_width_ratio_list: list[float] = []
|
|
|
|
|
|
|
|
|
|
| 363 |
fold_weekly_interval_score_80_list: list[float] = []
|
|
|
|
| 364 |
|
| 365 |
for fold_idx, (fold_train_ds, fold_val_ds) in enumerate(cv_folds):
|
| 366 |
# ---- setup ----
|
|
@@ -430,7 +444,11 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 430 |
fold_weekly_mr = 1.0
|
| 431 |
fold_weekly_pi80_coverage = 0.0
|
| 432 |
fold_weekly_pi80_width_ratio = 1.0
|
|
|
|
|
|
|
|
|
|
| 433 |
fold_weekly_interval_score_80 = 0.0
|
|
|
|
| 434 |
|
| 435 |
try:
|
| 436 |
pred_tensor = model.predict(fold_val_dl, mode="quantiles")
|
|
@@ -494,19 +512,31 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 494 |
fold_weekly_mr = float(weekly.get("weekly_magnitude_ratio", 1.0))
|
| 495 |
fold_weekly_pi80_coverage = float(weekly.get("weekly_pi80_coverage", 0.0))
|
| 496 |
fold_weekly_pi80_width_ratio = float(weekly.get("weekly_pi80_width_ratio", 1.0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
fold_weekly_interval_score_80 = float(weekly.get("weekly_interval_score_80", 0.0))
|
|
|
|
| 498 |
weekly_actual_std = float(weekly.get("weekly_actual_std", 0.0))
|
| 499 |
interval_score_penalty = fold_weekly_interval_score_80 / (weekly_actual_std + 1e-8)
|
|
|
|
| 500 |
coverage_penalty = abs(fold_weekly_pi80_coverage - 0.80)
|
| 501 |
-
width_penalty = max(0.0, fold_weekly_pi80_width_ratio -
|
|
|
|
|
|
|
| 502 |
fold_weekly_objective = (
|
| 503 |
0.35 * weekly_pinball
|
| 504 |
+ 0.15 * (1.0 - float(weekly.get("weekly_directional_accuracy", 0.5)))
|
| 505 |
-
+ 0.
|
| 506 |
+ 0.20 * coverage_penalty
|
| 507 |
+ 0.25 * width_penalty
|
|
|
|
| 508 |
+ 0.10 * interval_score_penalty
|
| 509 |
-
+ 0.
|
|
|
|
|
|
|
| 510 |
)
|
| 511 |
except Exception as exc:
|
| 512 |
logger.warning(
|
|
@@ -523,7 +553,11 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 523 |
fold_weekly_mr_list.append(fold_weekly_mr)
|
| 524 |
fold_weekly_pi80_coverage_list.append(fold_weekly_pi80_coverage)
|
| 525 |
fold_weekly_pi80_width_ratio_list.append(fold_weekly_pi80_width_ratio)
|
|
|
|
|
|
|
|
|
|
| 526 |
fold_weekly_interval_score_80_list.append(fold_weekly_interval_score_80)
|
|
|
|
| 527 |
|
| 528 |
# Incorporate DA directly into fold_score as a reward (not just penalty).
|
| 529 |
# DA > 50% (coin-flip) is rewarded, < 50% penalised.
|
|
@@ -584,6 +618,22 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 584 |
trial.set_user_attr("prune_reason", "weekly_interval_width_explosion")
|
| 585 |
raise optuna.exceptions.TrialPruned()
|
| 586 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
if (
|
| 588 |
fold_weekly_pi80_coverage >= 0.98
|
| 589 |
and fold_weekly_pi80_width_ratio > 3.0
|
|
@@ -628,11 +678,31 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 628 |
if fold_weekly_pi80_width_ratio_list
|
| 629 |
else 1.0
|
| 630 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
avg_weekly_interval_score_80 = (
|
| 632 |
float(np.mean(fold_weekly_interval_score_80_list))
|
| 633 |
if fold_weekly_interval_score_80_list
|
| 634 |
else 0.0
|
| 635 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
# High fold-score variance = trial is unreliable (works in one regime, fails in another)
|
| 638 |
consistency_penalty = (
|
|
@@ -647,7 +717,11 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 647 |
trial.set_user_attr("avg_weekly_magnitude_ratio", round(avg_weekly_mr, 4))
|
| 648 |
trial.set_user_attr("avg_weekly_pi80_coverage", round(avg_weekly_pi80_coverage, 4))
|
| 649 |
trial.set_user_attr("avg_weekly_pi80_width_ratio", round(avg_weekly_pi80_width_ratio, 4))
|
|
|
|
|
|
|
|
|
|
| 650 |
trial.set_user_attr("avg_weekly_interval_score_80", round(avg_weekly_interval_score_80, 4))
|
|
|
|
| 651 |
trial.set_user_attr(
|
| 652 |
"fold_score_std",
|
| 653 |
round(float(np.std(fold_scores)) if len(fold_scores) > 1 else 0.0, 4),
|
|
@@ -670,6 +744,14 @@ def _objective(trial, base_cfg: TFTASROConfig, master_data: tuple) -> float:
|
|
| 670 |
trial.set_user_attr("prune_reason", "crossing_prune")
|
| 671 |
raise optuna.exceptions.TrialPruned()
|
| 672 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
# Soft penalty: avg DA below coin-flip
|
| 674 |
da_penalty = 2.0 * max(0.0, 0.50 - avg_da) if avg_da < 0.50 else 0.0
|
| 675 |
|
|
|
|
| 38 |
|
| 39 |
logger = logging.getLogger(__name__)
|
| 40 |
|
| 41 |
+
MIN_COMPLETED_TRIALS = 10
|
| 42 |
SHARPE_PRUNE_THRESHOLD = -0.3
|
| 43 |
FOLD_SHARPE_PRUNE_THRESHOLD = -1.0
|
| 44 |
|
|
|
|
| 55 |
"lambda_quantile": 0.25,
|
| 56 |
"lambda_madl": 0.40,
|
| 57 |
"lambda_weekly_quantile": 0.60,
|
| 58 |
+
"lambda_t1_quantile": 0.10,
|
| 59 |
"lambda_directional": 0.10,
|
| 60 |
+
"lambda_magnitude": 0.55,
|
| 61 |
"weekly_lambda_vol": 0.35,
|
| 62 |
+
"lambda_width": 0.50,
|
| 63 |
+
"lambda_tail_width": 0.30,
|
| 64 |
+
"lambda_sanity": 0.20,
|
| 65 |
+
"lambda_crossing": 7.0,
|
| 66 |
"batch_size": 32,
|
| 67 |
}
|
| 68 |
|
|
|
|
| 133 |
"weekly_magnitude_collapse": 0,
|
| 134 |
"weekly_magnitude_explosion": 0,
|
| 135 |
"weekly_interval_width_explosion": 0,
|
| 136 |
+
"weekly_tail_width_explosion": 0,
|
| 137 |
+
"weekly_raw_crossing_prune": 0,
|
| 138 |
"weekly_overcoverage_width_explosion": 0,
|
| 139 |
"error": 0,
|
| 140 |
}
|
|
|
|
| 148 |
"avg_weekly_magnitude_ratio",
|
| 149 |
"avg_weekly_pi80_coverage",
|
| 150 |
"avg_weekly_pi80_width_ratio",
|
| 151 |
+
"avg_weekly_pi96_width_ratio",
|
| 152 |
+
"avg_weekly_raw_crossing_rate",
|
| 153 |
+
"avg_weekly_sorted_crossing_rate",
|
| 154 |
"avg_weekly_interval_score_80",
|
| 155 |
+
"avg_weekly_interval_score_96",
|
| 156 |
"fold_score_std",
|
| 157 |
)
|
| 158 |
|
|
|
|
| 227 |
def create_trial_config(trial, base_cfg: TFTASROConfig) -> TFTASROConfig:
|
| 228 |
"""Map an Optuna trial to a TFT-ASRO configuration."""
|
| 229 |
model_cfg = TFTModelConfig(
|
| 230 |
+
max_encoder_length=trial.suggest_categorical("max_encoder_length", [40, 50, 60, 75, 90]),
|
| 231 |
max_prediction_length=base_cfg.model.max_prediction_length,
|
| 232 |
# Post-MRMR pruning (~60-80 features), smaller models generalise better.
|
| 233 |
# 24 is viable now that feature count dropped from 200+ to ~60-80.
|
| 234 |
+
hidden_size=trial.suggest_categorical("hidden_size", [24, 32, 48]),
|
| 235 |
+
attention_head_size=trial.suggest_categorical("attention_head_size", [1, 2]),
|
| 236 |
# Floor at 0.20: 313 samples with dropout<0.20 causes co-adaptation
|
| 237 |
# and memorization (REG-2026-001). Cap at 0.35: dropout>0.35 with
|
| 238 |
# small hidden_size collapses the output range.
|
| 239 |
+
dropout=trial.suggest_categorical("dropout", [0.20, 0.25, 0.30, 0.35]),
|
| 240 |
# Paired reduction: with hidden=24-48 and ~60-80 features,
|
| 241 |
# 8-16 is the sweet spot for continuous variable processing.
|
| 242 |
+
hidden_continuous_size=trial.suggest_categorical("hidden_continuous_size", [8, 16]),
|
| 243 |
quantiles=base_cfg.model.quantiles,
|
| 244 |
# Range [1e-4, 1e-3]: LR < 1e-4 produces near-zero pred_std (VR=0.14);
|
| 245 |
# LR > 1e-3 causes 1-epoch divergence. This band is the stable zone.
|
| 246 |
+
learning_rate=trial.suggest_float("learning_rate", 1e-4, 6e-4, log=True),
|
| 247 |
reduce_on_plateau_patience=4,
|
| 248 |
+
gradient_clip_val=trial.suggest_categorical("gradient_clip_val", [0.5, 1.0, 1.5]),
|
| 249 |
+
weight_decay=trial.suggest_float("weight_decay", 1e-5, 5e-4, log=True),
|
| 250 |
)
|
| 251 |
|
| 252 |
asro_cfg = ASROConfig(
|
| 253 |
# Floor at 0.25: three Optuna runs consistently selected 0.30-0.35.
|
| 254 |
# Lower values let the model collapse to near-zero pred_std.
|
| 255 |
+
lambda_vol=trial.suggest_float("lambda_vol", 0.25, 0.40, step=0.05),
|
| 256 |
# lambda_quantile is the explicit w_quantile weight (w_sharpe = 1 - w_q)
|
| 257 |
# Capped at 0.40 to ensure Sharpe (directional) component always has
|
| 258 |
# ≥60% weight. Higher values caused the "perfect calibration, coin-flip
|
|
|
|
| 260 |
# expense of directional signal.
|
| 261 |
lambda_quantile=trial.suggest_float("lambda_quantile", 0.25, 0.4, step=0.05),
|
| 262 |
# MADL weight: how much the directional loss contributes relative to Sharpe.
|
| 263 |
+
lambda_madl=trial.suggest_float("lambda_madl", 0.35, 0.60, step=0.05),
|
| 264 |
risk_free_rate=0.0,
|
| 265 |
)
|
| 266 |
|
| 267 |
weekly_loss_cfg = WeeklyLossConfig(
|
| 268 |
+
lambda_weekly_quantile=trial.suggest_float("lambda_weekly_quantile", 0.60, 0.75, step=0.05),
|
| 269 |
+
lambda_t1_quantile=trial.suggest_float("lambda_t1_quantile", 0.05, 0.15, step=0.05),
|
| 270 |
+
lambda_directional=trial.suggest_float("lambda_directional", 0.05, 0.12, step=0.01),
|
| 271 |
+
lambda_magnitude=trial.suggest_float("lambda_magnitude", 0.50, 0.80, step=0.05),
|
| 272 |
lambda_vol=trial.suggest_float("weekly_lambda_vol", 0.25, 0.45, step=0.05),
|
| 273 |
+
lambda_crossing=trial.suggest_float("lambda_crossing", 5.0, 10.0, step=1.0),
|
| 274 |
+
lambda_sanity=trial.suggest_float("lambda_sanity", 0.10, 0.30, step=0.05),
|
| 275 |
+
lambda_width=trial.suggest_float("lambda_width", 0.40, 0.90, step=0.05),
|
| 276 |
+
lambda_tail_width=trial.suggest_float("lambda_tail_width", 0.25, 0.75, step=0.05),
|
| 277 |
)
|
| 278 |
|
| 279 |
training_cfg = TrainingConfig(
|
|
|
|
| 370 |
fold_weekly_mr_list: list[float] = []
|
| 371 |
fold_weekly_pi80_coverage_list: list[float] = []
|
| 372 |
fold_weekly_pi80_width_ratio_list: list[float] = []
|
| 373 |
+
fold_weekly_pi96_width_ratio_list: list[float] = []
|
| 374 |
+
fold_weekly_raw_crossing_list: list[float] = []
|
| 375 |
+
fold_weekly_sorted_crossing_list: list[float] = []
|
| 376 |
fold_weekly_interval_score_80_list: list[float] = []
|
| 377 |
+
fold_weekly_interval_score_96_list: list[float] = []
|
| 378 |
|
| 379 |
for fold_idx, (fold_train_ds, fold_val_ds) in enumerate(cv_folds):
|
| 380 |
# ---- setup ----
|
|
|
|
| 444 |
fold_weekly_mr = 1.0
|
| 445 |
fold_weekly_pi80_coverage = 0.0
|
| 446 |
fold_weekly_pi80_width_ratio = 1.0
|
| 447 |
+
fold_weekly_pi96_width_ratio = 1.0
|
| 448 |
+
fold_weekly_raw_crossing = 0.0
|
| 449 |
+
fold_weekly_sorted_crossing = 0.0
|
| 450 |
fold_weekly_interval_score_80 = 0.0
|
| 451 |
+
fold_weekly_interval_score_96 = 0.0
|
| 452 |
|
| 453 |
try:
|
| 454 |
pred_tensor = model.predict(fold_val_dl, mode="quantiles")
|
|
|
|
| 512 |
fold_weekly_mr = float(weekly.get("weekly_magnitude_ratio", 1.0))
|
| 513 |
fold_weekly_pi80_coverage = float(weekly.get("weekly_pi80_coverage", 0.0))
|
| 514 |
fold_weekly_pi80_width_ratio = float(weekly.get("weekly_pi80_width_ratio", 1.0))
|
| 515 |
+
fold_weekly_pi96_width_ratio = float(weekly.get("weekly_pi96_width_ratio", 1.0))
|
| 516 |
+
fold_weekly_raw_crossing = float(weekly.get("weekly_quantile_crossing_rate", 0.0))
|
| 517 |
+
fold_weekly_sorted_crossing = float(
|
| 518 |
+
weekly.get("weekly_sorted_quantile_crossing_rate", 0.0)
|
| 519 |
+
)
|
| 520 |
fold_weekly_interval_score_80 = float(weekly.get("weekly_interval_score_80", 0.0))
|
| 521 |
+
fold_weekly_interval_score_96 = float(weekly.get("weekly_interval_score_96", 0.0))
|
| 522 |
weekly_actual_std = float(weekly.get("weekly_actual_std", 0.0))
|
| 523 |
interval_score_penalty = fold_weekly_interval_score_80 / (weekly_actual_std + 1e-8)
|
| 524 |
+
interval_score_96_penalty = fold_weekly_interval_score_96 / (weekly_actual_std + 1e-8)
|
| 525 |
coverage_penalty = abs(fold_weekly_pi80_coverage - 0.80)
|
| 526 |
+
width_penalty = max(0.0, fold_weekly_pi80_width_ratio - 1.5)
|
| 527 |
+
tail_width_penalty = max(0.0, fold_weekly_pi96_width_ratio - 3.0)
|
| 528 |
+
raw_crossing_penalty = max(0.0, fold_weekly_raw_crossing - 0.05)
|
| 529 |
fold_weekly_objective = (
|
| 530 |
0.35 * weekly_pinball
|
| 531 |
+ 0.15 * (1.0 - float(weekly.get("weekly_directional_accuracy", 0.5)))
|
| 532 |
+
+ 0.50 * abs(np.log(fold_weekly_mr + 1e-8))
|
| 533 |
+ 0.20 * coverage_penalty
|
| 534 |
+ 0.25 * width_penalty
|
| 535 |
+
+ 0.35 * tail_width_penalty
|
| 536 |
+ 0.10 * interval_score_penalty
|
| 537 |
+
+ 0.05 * interval_score_96_penalty
|
| 538 |
+
+ 0.50 * raw_crossing_penalty
|
| 539 |
+
+ 0.25 * fold_weekly_sorted_crossing
|
| 540 |
)
|
| 541 |
except Exception as exc:
|
| 542 |
logger.warning(
|
|
|
|
| 553 |
fold_weekly_mr_list.append(fold_weekly_mr)
|
| 554 |
fold_weekly_pi80_coverage_list.append(fold_weekly_pi80_coverage)
|
| 555 |
fold_weekly_pi80_width_ratio_list.append(fold_weekly_pi80_width_ratio)
|
| 556 |
+
fold_weekly_pi96_width_ratio_list.append(fold_weekly_pi96_width_ratio)
|
| 557 |
+
fold_weekly_raw_crossing_list.append(fold_weekly_raw_crossing)
|
| 558 |
+
fold_weekly_sorted_crossing_list.append(fold_weekly_sorted_crossing)
|
| 559 |
fold_weekly_interval_score_80_list.append(fold_weekly_interval_score_80)
|
| 560 |
+
fold_weekly_interval_score_96_list.append(fold_weekly_interval_score_96)
|
| 561 |
|
| 562 |
# Incorporate DA directly into fold_score as a reward (not just penalty).
|
| 563 |
# DA > 50% (coin-flip) is rewarded, < 50% penalised.
|
|
|
|
| 618 |
trial.set_user_attr("prune_reason", "weekly_interval_width_explosion")
|
| 619 |
raise optuna.exceptions.TrialPruned()
|
| 620 |
|
| 621 |
+
if fold_weekly_pi96_width_ratio > 3.0 and fold_idx >= 1 and not protect_trial:
|
| 622 |
+
logger.warning(
|
| 623 |
+
"Trial %d PRUNED at fold %d: weekly_pi96_width_ratio=%.4f > 3.0",
|
| 624 |
+
trial.number, fold_idx + 1, fold_weekly_pi96_width_ratio,
|
| 625 |
+
)
|
| 626 |
+
trial.set_user_attr("prune_reason", "weekly_tail_width_explosion")
|
| 627 |
+
raise optuna.exceptions.TrialPruned()
|
| 628 |
+
|
| 629 |
+
if fold_weekly_raw_crossing > 0.05 and fold_idx >= 1 and not protect_trial:
|
| 630 |
+
logger.warning(
|
| 631 |
+
"Trial %d PRUNED at fold %d: weekly raw crossing=%.4f > 0.05",
|
| 632 |
+
trial.number, fold_idx + 1, fold_weekly_raw_crossing,
|
| 633 |
+
)
|
| 634 |
+
trial.set_user_attr("prune_reason", "weekly_raw_crossing_prune")
|
| 635 |
+
raise optuna.exceptions.TrialPruned()
|
| 636 |
+
|
| 637 |
if (
|
| 638 |
fold_weekly_pi80_coverage >= 0.98
|
| 639 |
and fold_weekly_pi80_width_ratio > 3.0
|
|
|
|
| 678 |
if fold_weekly_pi80_width_ratio_list
|
| 679 |
else 1.0
|
| 680 |
)
|
| 681 |
+
avg_weekly_pi96_width_ratio = (
|
| 682 |
+
float(np.mean(fold_weekly_pi96_width_ratio_list))
|
| 683 |
+
if fold_weekly_pi96_width_ratio_list
|
| 684 |
+
else 1.0
|
| 685 |
+
)
|
| 686 |
+
avg_weekly_raw_crossing = (
|
| 687 |
+
float(np.mean(fold_weekly_raw_crossing_list))
|
| 688 |
+
if fold_weekly_raw_crossing_list
|
| 689 |
+
else 0.0
|
| 690 |
+
)
|
| 691 |
+
avg_weekly_sorted_crossing = (
|
| 692 |
+
float(np.mean(fold_weekly_sorted_crossing_list))
|
| 693 |
+
if fold_weekly_sorted_crossing_list
|
| 694 |
+
else 0.0
|
| 695 |
+
)
|
| 696 |
avg_weekly_interval_score_80 = (
|
| 697 |
float(np.mean(fold_weekly_interval_score_80_list))
|
| 698 |
if fold_weekly_interval_score_80_list
|
| 699 |
else 0.0
|
| 700 |
)
|
| 701 |
+
avg_weekly_interval_score_96 = (
|
| 702 |
+
float(np.mean(fold_weekly_interval_score_96_list))
|
| 703 |
+
if fold_weekly_interval_score_96_list
|
| 704 |
+
else 0.0
|
| 705 |
+
)
|
| 706 |
|
| 707 |
# High fold-score variance = trial is unreliable (works in one regime, fails in another)
|
| 708 |
consistency_penalty = (
|
|
|
|
| 717 |
trial.set_user_attr("avg_weekly_magnitude_ratio", round(avg_weekly_mr, 4))
|
| 718 |
trial.set_user_attr("avg_weekly_pi80_coverage", round(avg_weekly_pi80_coverage, 4))
|
| 719 |
trial.set_user_attr("avg_weekly_pi80_width_ratio", round(avg_weekly_pi80_width_ratio, 4))
|
| 720 |
+
trial.set_user_attr("avg_weekly_pi96_width_ratio", round(avg_weekly_pi96_width_ratio, 4))
|
| 721 |
+
trial.set_user_attr("avg_weekly_raw_crossing_rate", round(avg_weekly_raw_crossing, 4))
|
| 722 |
+
trial.set_user_attr("avg_weekly_sorted_crossing_rate", round(avg_weekly_sorted_crossing, 4))
|
| 723 |
trial.set_user_attr("avg_weekly_interval_score_80", round(avg_weekly_interval_score_80, 4))
|
| 724 |
+
trial.set_user_attr("avg_weekly_interval_score_96", round(avg_weekly_interval_score_96, 4))
|
| 725 |
trial.set_user_attr(
|
| 726 |
"fold_score_std",
|
| 727 |
round(float(np.std(fold_scores)) if len(fold_scores) > 1 else 0.0, 4),
|
|
|
|
| 744 |
trial.set_user_attr("prune_reason", "crossing_prune")
|
| 745 |
raise optuna.exceptions.TrialPruned()
|
| 746 |
|
| 747 |
+
if (avg_weekly_raw_crossing > 0.05 or avg_weekly_sorted_crossing > 0.0) and not protect_trial:
|
| 748 |
+
logger.warning(
|
| 749 |
+
"Trial %d PRUNED: weekly quantile incoherence raw=%.3f sorted=%.3f",
|
| 750 |
+
trial.number, avg_weekly_raw_crossing, avg_weekly_sorted_crossing,
|
| 751 |
+
)
|
| 752 |
+
trial.set_user_attr("prune_reason", "weekly_raw_crossing_prune")
|
| 753 |
+
raise optuna.exceptions.TrialPruned()
|
| 754 |
+
|
| 755 |
# Soft penalty: avg DA below coin-flip
|
| 756 |
da_penalty = 2.0 * max(0.0, 0.50 - avg_da) if avg_da < 0.50 else 0.0
|
| 757 |
|
deep_learning/training/metrics.py
CHANGED
|
@@ -139,6 +139,32 @@ def directional_accuracy(
|
|
| 139 |
return float(matches.mean())
|
| 140 |
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
def tail_capture_rate(
|
| 143 |
y_actual: np.ndarray,
|
| 144 |
y_pred: np.ndarray,
|
|
@@ -217,14 +243,27 @@ def compute_all_metrics(
|
|
| 217 |
# This is the correct series to compute Sharpe/Sortino on — not the raw predictions.
|
| 218 |
# Using y_pred_median directly produces an inflated ratio because pred_std << actual_std.
|
| 219 |
strategy_returns = np.sign(y_pred_median) * y_actual
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
metrics: dict[str, float] = {
|
| 222 |
"mae": float(np.abs(y_actual - y_pred_median).mean()),
|
| 223 |
"rmse": float(np.sqrt(((y_actual - y_pred_median) ** 2).mean())),
|
| 224 |
"directional_accuracy": directional_accuracy(y_actual, y_pred_median),
|
|
|
|
|
|
|
|
|
|
| 225 |
"tail_capture_rate": tail_capture_rate(y_actual, y_pred_median, tail_threshold),
|
| 226 |
"sharpe_ratio": sharpe_ratio(strategy_returns),
|
| 227 |
"sortino_ratio": sortino_ratio(strategy_returns),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
}
|
| 229 |
|
| 230 |
pred_std = float(y_pred_median.std())
|
|
@@ -238,19 +277,29 @@ def compute_all_metrics(
|
|
| 238 |
q90 = np.asarray(y_pred_q90, dtype=np.float64)
|
| 239 |
metrics["pi80_coverage"] = prediction_interval_coverage(y_actual, q10, q90)
|
| 240 |
metrics["pi80_width"] = prediction_interval_width(q10, q90)
|
|
|
|
| 241 |
|
| 242 |
if y_pred_q02 is not None and y_pred_q98 is not None:
|
| 243 |
q02 = np.asarray(y_pred_q02, dtype=np.float64)
|
| 244 |
q98 = np.asarray(y_pred_q98, dtype=np.float64)
|
| 245 |
metrics["pi96_coverage"] = prediction_interval_coverage(y_actual, q02, q98)
|
| 246 |
metrics["pi96_width"] = prediction_interval_width(q02, q98)
|
|
|
|
| 247 |
|
| 248 |
if y_pred_quantiles is not None:
|
| 249 |
q_arr = np.asarray(y_pred_quantiles, dtype=np.float64)
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
gap_mean, gap_max = quantile_median_sort_gap(q_arr)
|
| 252 |
metrics["median_sort_gap_mean"] = gap_mean
|
| 253 |
metrics["median_sort_gap_max"] = gap_max
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
return metrics
|
| 256 |
|
|
@@ -268,7 +317,8 @@ def compute_weekly_metrics(
|
|
| 268 |
to simple returns happens only during inference formatting.
|
| 269 |
"""
|
| 270 |
weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
|
| 271 |
-
|
|
|
|
| 272 |
|
| 273 |
median_idx = len(quantiles) // 2
|
| 274 |
q10_idx = quantiles.index(0.10)
|
|
@@ -290,11 +340,18 @@ def compute_weekly_metrics(
|
|
| 290 |
y_pred_q90=weekly_quantiles[:, q90_idx],
|
| 291 |
y_pred_q02=weekly_quantiles[:, q02_idx],
|
| 292 |
y_pred_q98=weekly_quantiles[:, q98_idx],
|
| 293 |
-
y_pred_quantiles=
|
| 294 |
tail_threshold=tail_threshold,
|
| 295 |
)
|
| 296 |
|
| 297 |
weekly_metrics = {f"weekly_{k}": v for k, v in metrics.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
weekly_metrics["weekly_magnitude_ratio"] = magnitude_ratio(weekly_actual, weekly_pred)
|
| 299 |
weekly_metrics["weekly_mean_actual_abs"] = float(np.mean(np.abs(weekly_actual)))
|
| 300 |
weekly_metrics["weekly_mean_pred_abs"] = float(np.mean(np.abs(weekly_pred)))
|
|
@@ -311,5 +368,11 @@ def compute_weekly_metrics(
|
|
| 311 |
weekly_quantiles[:, q90_idx],
|
| 312 |
alpha=0.20,
|
| 313 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
weekly_metrics["weekly_sample_count"] = int(len(weekly_actual))
|
| 315 |
return weekly_metrics
|
|
|
|
| 139 |
return float(matches.mean())
|
| 140 |
|
| 141 |
|
| 142 |
+
def directional_accuracy_count(
    y_actual: np.ndarray,
    y_pred: np.ndarray,
) -> tuple[int, int]:
    """Count how many predictions share the sign of the realised value.

    Args:
        y_actual: Realised values.
        y_pred: Predicted values, compared element-wise against ``y_actual``.

    Returns:
        ``(hits, total)`` where ``hits`` is the number of positions whose
        signs agree and ``total`` is the number of compared positions. The
        pair is the raw input for a binomial confidence interval.
    """
    sign_true = np.sign(y_actual)
    sign_hat = np.sign(y_pred)

    same_sign = sign_true == sign_hat
    # Zero-vs-zero is spelled out separately as an explicit "flat call on a
    # flat outcome" hit; it is already implied by the equality test above.
    both_flat = (sign_true == 0) & (sign_hat == 0)
    agree = same_sign | both_flat

    hits = int(agree.sum())
    total = int(agree.size)
    return hits, total
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def wilson_interval(
    successes: int,
    n: int,
    z: float = 1.959963984540054,
) -> tuple[float, float]:
    """Two-sided Wilson score confidence interval for a binomial proportion.

    Args:
        successes: Number of observed successes.
        n: Number of trials.
        z: Standard-normal critical value. The default is the 97.5th
            percentile, which gives a 95% two-sided interval.

    Returns:
        ``(low, high)`` bounds of the interval, each clamped to ``[0, 1]``.
        ``(0.0, 0.0)`` is returned when ``n <= 0`` because no proportion can
        be estimated from zero trials.
    """
    if n <= 0:
        return 0.0, 0.0

    phat = successes / n
    z_sq = z * z
    denom = 1.0 + z_sq / n
    centre = phat + z_sq / (2.0 * n)
    margin = z * np.sqrt((phat * (1.0 - phat) + z_sq / (4.0 * n)) / n)

    low = float((centre - margin) / denom)
    high = float((centre + margin) / denom)

    # At phat == 0 (or 1) the exact bound is 0 (or 1), but centre and margin
    # are computed separately in floating point, so the raw result can land a
    # few ULPs outside [0, 1]. Clamp so callers always get a valid proportion.
    low = min(max(low, 0.0), 1.0)
    high = min(max(high, 0.0), 1.0)
    return low, high
|
| 166 |
+
|
| 167 |
+
|
| 168 |
def tail_capture_rate(
|
| 169 |
y_actual: np.ndarray,
|
| 170 |
y_pred: np.ndarray,
|
|
|
|
| 243 |
# This is the correct series to compute Sharpe/Sortino on — not the raw predictions.
|
| 244 |
# Using y_pred_median directly produces an inflated ratio because pred_std << actual_std.
|
| 245 |
strategy_returns = np.sign(y_pred_median) * y_actual
|
| 246 |
+
direction_hits, direction_n = directional_accuracy_count(y_actual, y_pred_median)
|
| 247 |
+
da_ci_low, da_ci_high = wilson_interval(direction_hits, direction_n)
|
| 248 |
+
zero_mae = float(np.abs(y_actual).mean())
|
| 249 |
+
zero_rmse = float(np.sqrt((y_actual ** 2).mean()))
|
| 250 |
|
| 251 |
metrics: dict[str, float] = {
|
| 252 |
"mae": float(np.abs(y_actual - y_pred_median).mean()),
|
| 253 |
"rmse": float(np.sqrt(((y_actual - y_pred_median) ** 2).mean())),
|
| 254 |
"directional_accuracy": directional_accuracy(y_actual, y_pred_median),
|
| 255 |
+
"directional_accuracy_ci_low": da_ci_low,
|
| 256 |
+
"directional_accuracy_ci_high": da_ci_high,
|
| 257 |
+
"directional_accuracy_n": float(direction_n),
|
| 258 |
"tail_capture_rate": tail_capture_rate(y_actual, y_pred_median, tail_threshold),
|
| 259 |
"sharpe_ratio": sharpe_ratio(strategy_returns),
|
| 260 |
"sortino_ratio": sortino_ratio(strategy_returns),
|
| 261 |
+
"naive_zero_mae": zero_mae,
|
| 262 |
+
"naive_zero_rmse": zero_rmse,
|
| 263 |
+
"mae_vs_naive_zero": float(np.abs(y_actual - y_pred_median).mean() / (zero_mae + 1e-12)),
|
| 264 |
+
"rmse_vs_naive_zero": float(
|
| 265 |
+
np.sqrt(((y_actual - y_pred_median) ** 2).mean()) / (zero_rmse + 1e-12)
|
| 266 |
+
),
|
| 267 |
}
|
| 268 |
|
| 269 |
pred_std = float(y_pred_median.std())
|
|
|
|
| 277 |
q90 = np.asarray(y_pred_q90, dtype=np.float64)
|
| 278 |
metrics["pi80_coverage"] = prediction_interval_coverage(y_actual, q10, q90)
|
| 279 |
metrics["pi80_width"] = prediction_interval_width(q10, q90)
|
| 280 |
+
metrics["pi80_interval_score"] = interval_score(y_actual, q10, q90, alpha=0.20)
|
| 281 |
|
| 282 |
if y_pred_q02 is not None and y_pred_q98 is not None:
|
| 283 |
q02 = np.asarray(y_pred_q02, dtype=np.float64)
|
| 284 |
q98 = np.asarray(y_pred_q98, dtype=np.float64)
|
| 285 |
metrics["pi96_coverage"] = prediction_interval_coverage(y_actual, q02, q98)
|
| 286 |
metrics["pi96_width"] = prediction_interval_width(q02, q98)
|
| 287 |
+
metrics["pi96_interval_score"] = interval_score(y_actual, q02, q98, alpha=0.04)
|
| 288 |
|
| 289 |
if y_pred_quantiles is not None:
|
| 290 |
q_arr = np.asarray(y_pred_quantiles, dtype=np.float64)
|
| 291 |
+
sorted_q = np.sort(q_arr, axis=-1)
|
| 292 |
+
raw_crossing = quantile_crossing_rate(q_arr)
|
| 293 |
+
sorted_crossing = quantile_crossing_rate(sorted_q)
|
| 294 |
+
metrics["quantile_crossing_rate"] = raw_crossing
|
| 295 |
+
metrics["raw_quantile_crossing_rate"] = raw_crossing
|
| 296 |
+
metrics["sorted_quantile_crossing_rate"] = sorted_crossing
|
| 297 |
gap_mean, gap_max = quantile_median_sort_gap(q_arr)
|
| 298 |
metrics["median_sort_gap_mean"] = gap_mean
|
| 299 |
metrics["median_sort_gap_max"] = gap_max
|
| 300 |
+
sorted_gap_mean, sorted_gap_max = quantile_median_sort_gap(sorted_q)
|
| 301 |
+
metrics["sorted_median_sort_gap_mean"] = sorted_gap_mean
|
| 302 |
+
metrics["sorted_median_sort_gap_max"] = sorted_gap_max
|
| 303 |
|
| 304 |
return metrics
|
| 305 |
|
|
|
|
| 317 |
to simple returns happens only during inference formatting.
|
| 318 |
"""
|
| 319 |
weekly_actual = cumulative_horizon(y_actual_path, horizon=horizon)
|
| 320 |
+
approx_weekly_quantiles = cumulative_quantiles(y_pred_quantiles_path, horizon=horizon)
|
| 321 |
+
weekly_quantiles = np.sort(approx_weekly_quantiles, axis=-1)
|
| 322 |
|
| 323 |
median_idx = len(quantiles) // 2
|
| 324 |
q10_idx = quantiles.index(0.10)
|
|
|
|
| 340 |
y_pred_q90=weekly_quantiles[:, q90_idx],
|
| 341 |
y_pred_q02=weekly_quantiles[:, q02_idx],
|
| 342 |
y_pred_q98=weekly_quantiles[:, q98_idx],
|
| 343 |
+
y_pred_quantiles=approx_weekly_quantiles,
|
| 344 |
tail_threshold=tail_threshold,
|
| 345 |
)
|
| 346 |
|
| 347 |
weekly_metrics = {f"weekly_{k}": v for k, v in metrics.items()}
|
| 348 |
+
weekly_metrics["weekly_interval_quantile_source"] = 1.0
|
| 349 |
+
weekly_metrics["weekly_approx_quantile_crossing_rate"] = quantile_crossing_rate(
|
| 350 |
+
approx_weekly_quantiles
|
| 351 |
+
)
|
| 352 |
+
approx_gap_mean, approx_gap_max = quantile_median_sort_gap(approx_weekly_quantiles)
|
| 353 |
+
weekly_metrics["weekly_approx_median_sort_gap_mean"] = approx_gap_mean
|
| 354 |
+
weekly_metrics["weekly_approx_median_sort_gap_max"] = approx_gap_max
|
| 355 |
weekly_metrics["weekly_magnitude_ratio"] = magnitude_ratio(weekly_actual, weekly_pred)
|
| 356 |
weekly_metrics["weekly_mean_actual_abs"] = float(np.mean(np.abs(weekly_actual)))
|
| 357 |
weekly_metrics["weekly_mean_pred_abs"] = float(np.mean(np.abs(weekly_pred)))
|
|
|
|
| 368 |
weekly_quantiles[:, q90_idx],
|
| 369 |
alpha=0.20,
|
| 370 |
)
|
| 371 |
+
weekly_metrics["weekly_interval_score_96"] = interval_score(
|
| 372 |
+
weekly_actual,
|
| 373 |
+
weekly_quantiles[:, q02_idx],
|
| 374 |
+
weekly_quantiles[:, q98_idx],
|
| 375 |
+
alpha=0.04,
|
| 376 |
+
)
|
| 377 |
weekly_metrics["weekly_sample_count"] = int(len(weekly_actual))
|
| 378 |
return weekly_metrics
|
deep_learning/training/trainer.py
CHANGED
|
@@ -46,6 +46,7 @@ warnings.filterwarnings(
|
|
| 46 |
logger = logging.getLogger(__name__)
|
| 47 |
|
| 48 |
KNOWN_GOOD_CONFIG = {
|
|
|
|
| 49 |
"hidden_size": 48,
|
| 50 |
"attention_head_size": 2,
|
| 51 |
"dropout": 0.30,
|
|
@@ -56,12 +57,14 @@ KNOWN_GOOD_CONFIG = {
|
|
| 56 |
"lambda_quantile": 0.25,
|
| 57 |
"lambda_madl": 0.40,
|
| 58 |
"lambda_weekly_quantile": 0.60,
|
| 59 |
-
"lambda_t1_quantile": 0.
|
| 60 |
"lambda_directional": 0.10,
|
| 61 |
-
"lambda_magnitude": 0.
|
| 62 |
"weekly_lambda_vol": 0.35,
|
| 63 |
-
"lambda_width": 0.
|
| 64 |
-
"lambda_tail_width": 0.
|
|
|
|
|
|
|
| 65 |
"batch_size": 32,
|
| 66 |
}
|
| 67 |
|
|
@@ -70,9 +73,14 @@ REQUIRED_PROMOTABLE_METRICS = (
|
|
| 70 |
"weekly_magnitude_ratio",
|
| 71 |
"weekly_tail_capture_rate",
|
| 72 |
"weekly_pi80_coverage",
|
|
|
|
|
|
|
|
|
|
| 73 |
"weekly_sample_count",
|
| 74 |
"weekly_quantile_crossing_rate",
|
|
|
|
| 75 |
"quantile_crossing_rate",
|
|
|
|
| 76 |
)
|
| 77 |
|
| 78 |
|
|
@@ -451,8 +459,17 @@ def train_tft_model(
|
|
| 451 |
# Write metadata JSON to disk for CI quality gate
|
| 452 |
meta_json_path = Path(cfg.training.best_model_path).parent / "tft_metadata.json"
|
| 453 |
try:
|
|
|
|
| 454 |
meta_json_path.write_text(json.dumps(result, indent=2, default=str))
|
| 455 |
logger.info("Training metadata written to %s", meta_json_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
except Exception as exc:
|
| 457 |
logger.warning("Could not write metadata JSON: %s", exc)
|
| 458 |
|
|
@@ -515,7 +532,10 @@ def _write_conformal_calibration_artifact(
|
|
| 515 |
return None
|
| 516 |
|
| 517 |
weekly_actual = cumulative_horizon(y_actual_path[:n], horizon=cfg.forecast.primary_horizon_days)
|
| 518 |
-
weekly_quantiles =
|
|
|
|
|
|
|
|
|
|
| 519 |
q = tuple(cfg.model.quantiles)
|
| 520 |
q10_idx = q.index(0.10)
|
| 521 |
q90_idx = q.index(0.90)
|
|
@@ -606,6 +626,28 @@ def _apply_optuna_results(cfg: TFTASROConfig) -> TFTASROConfig:
|
|
| 606 |
params["lambda_quantile"] = min(max(float(params["lambda_quantile"]), 0.25), 0.40)
|
| 607 |
if "lambda_madl" in params:
|
| 608 |
params["lambda_madl"] = max(float(params["lambda_madl"]), 0.30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
|
| 610 |
logger.info(
|
| 611 |
"Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
|
|
|
|
| 46 |
logger = logging.getLogger(__name__)
|
| 47 |
|
| 48 |
KNOWN_GOOD_CONFIG = {
|
| 49 |
+
"max_encoder_length": 60,
|
| 50 |
"hidden_size": 48,
|
| 51 |
"attention_head_size": 2,
|
| 52 |
"dropout": 0.30,
|
|
|
|
| 57 |
"lambda_quantile": 0.25,
|
| 58 |
"lambda_madl": 0.40,
|
| 59 |
"lambda_weekly_quantile": 0.60,
|
| 60 |
+
"lambda_t1_quantile": 0.10,
|
| 61 |
"lambda_directional": 0.10,
|
| 62 |
+
"lambda_magnitude": 0.55,
|
| 63 |
"weekly_lambda_vol": 0.35,
|
| 64 |
+
"lambda_width": 0.50,
|
| 65 |
+
"lambda_tail_width": 0.30,
|
| 66 |
+
"lambda_sanity": 0.20,
|
| 67 |
+
"lambda_crossing": 7.0,
|
| 68 |
"batch_size": 32,
|
| 69 |
}
|
| 70 |
|
|
|
|
| 73 |
"weekly_magnitude_ratio",
|
| 74 |
"weekly_tail_capture_rate",
|
| 75 |
"weekly_pi80_coverage",
|
| 76 |
+
"weekly_pi80_width_ratio",
|
| 77 |
+
"weekly_pi96_coverage",
|
| 78 |
+
"weekly_pi96_width_ratio",
|
| 79 |
"weekly_sample_count",
|
| 80 |
"weekly_quantile_crossing_rate",
|
| 81 |
+
"weekly_sorted_quantile_crossing_rate",
|
| 82 |
"quantile_crossing_rate",
|
| 83 |
+
"sorted_quantile_crossing_rate",
|
| 84 |
)
|
| 85 |
|
| 86 |
|
|
|
|
| 459 |
# Write metadata JSON to disk for CI quality gate
|
| 460 |
meta_json_path = Path(cfg.training.best_model_path).parent / "tft_metadata.json"
|
| 461 |
try:
|
| 462 |
+
result["artifact_manifest_path"] = str(meta_json_path.parent / "artifact_manifest.json")
|
| 463 |
meta_json_path.write_text(json.dumps(result, indent=2, default=str))
|
| 464 |
logger.info("Training metadata written to %s", meta_json_path)
|
| 465 |
+
try:
|
| 466 |
+
from deep_learning.models.hub import write_artifact_manifest
|
| 467 |
+
|
| 468 |
+
manifest_path = write_artifact_manifest(meta_json_path.parent)
|
| 469 |
+
result["artifact_manifest_path"] = str(manifest_path)
|
| 470 |
+
logger.info("Artifact manifest written to %s", manifest_path)
|
| 471 |
+
except Exception as exc:
|
| 472 |
+
logger.warning("Could not write artifact manifest: %s", exc)
|
| 473 |
except Exception as exc:
|
| 474 |
logger.warning("Could not write metadata JSON: %s", exc)
|
| 475 |
|
|
|
|
| 532 |
return None
|
| 533 |
|
| 534 |
weekly_actual = cumulative_horizon(y_actual_path[:n], horizon=cfg.forecast.primary_horizon_days)
|
| 535 |
+
weekly_quantiles = np.sort(
|
| 536 |
+
cumulative_quantiles(pred_np[:n], horizon=cfg.forecast.primary_horizon_days),
|
| 537 |
+
axis=-1,
|
| 538 |
+
)
|
| 539 |
q = tuple(cfg.model.quantiles)
|
| 540 |
q10_idx = q.index(0.10)
|
| 541 |
q90_idx = q.index(0.90)
|
|
|
|
| 626 |
params["lambda_quantile"] = min(max(float(params["lambda_quantile"]), 0.25), 0.40)
|
| 627 |
if "lambda_madl" in params:
|
| 628 |
params["lambda_madl"] = max(float(params["lambda_madl"]), 0.30)
|
| 629 |
+
if "max_encoder_length" in params and int(params["max_encoder_length"]) < 40:
|
| 630 |
+
logger.warning(
|
| 631 |
+
"Optuna max_encoder_length=%s is below weekly-safe floor; clamping to 40",
|
| 632 |
+
params["max_encoder_length"],
|
| 633 |
+
)
|
| 634 |
+
params["max_encoder_length"] = 40
|
| 635 |
+
if "learning_rate" in params:
|
| 636 |
+
params["learning_rate"] = min(float(params["learning_rate"]), 6e-4)
|
| 637 |
+
if "weight_decay" in params:
|
| 638 |
+
params["weight_decay"] = min(float(params["weight_decay"]), 5e-4)
|
| 639 |
+
if "lambda_magnitude" in params:
|
| 640 |
+
params["lambda_magnitude"] = max(float(params["lambda_magnitude"]), 0.50)
|
| 641 |
+
if "lambda_directional" in params:
|
| 642 |
+
params["lambda_directional"] = min(float(params["lambda_directional"]), 0.12)
|
| 643 |
+
if "lambda_width" in params:
|
| 644 |
+
params["lambda_width"] = max(float(params["lambda_width"]), 0.40)
|
| 645 |
+
if "lambda_tail_width" in params:
|
| 646 |
+
params["lambda_tail_width"] = max(float(params["lambda_tail_width"]), 0.25)
|
| 647 |
+
if "lambda_sanity" in params:
|
| 648 |
+
params["lambda_sanity"] = max(float(params["lambda_sanity"]), 0.10)
|
| 649 |
+
if "lambda_crossing" in params:
|
| 650 |
+
params["lambda_crossing"] = max(float(params["lambda_crossing"]), 5.0)
|
| 651 |
|
| 652 |
logger.info(
|
| 653 |
"Loaded Optuna best params (trial #%d, weekly_objective=%.4f): %s",
|
pyproject.toml
CHANGED
|
@@ -58,6 +58,7 @@ dev = [
|
|
| 58 |
"pytest>=7.4.3",
|
| 59 |
"pytest-asyncio>=0.21.1",
|
| 60 |
"httpx>=0.25.2",
|
|
|
|
| 61 |
]
|
| 62 |
|
| 63 |
[tool.setuptools.packages.find]
|
|
|
|
| 58 |
"pytest>=7.4.3",
|
| 59 |
"pytest-asyncio>=0.21.1",
|
| 60 |
"httpx>=0.25.2",
|
| 61 |
+
"pip-audit>=2.7.0",
|
| 62 |
]
|
| 63 |
|
| 64 |
[tool.setuptools.packages.find]
|
scripts/tft_quality_gate.py
CHANGED
|
@@ -41,7 +41,11 @@ def main() -> int:
|
|
| 41 |
weekly_mr = metrics.get("weekly_magnitude_ratio")
|
| 42 |
weekly_tail = metrics.get("weekly_tail_capture_rate")
|
| 43 |
weekly_pi80 = metrics.get("weekly_pi80_coverage")
|
|
|
|
|
|
|
|
|
|
| 44 |
weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
|
|
|
|
| 45 |
weekly_gap = metrics.get("weekly_median_sort_gap_max")
|
| 46 |
weekly_samples = metrics.get("weekly_sample_count")
|
| 47 |
|
|
@@ -55,7 +59,9 @@ def main() -> int:
|
|
| 55 |
"Weekly gate metrics: "
|
| 56 |
f"WeeklyDA={weekly_da} WeeklyMR={weekly_mr} "
|
| 57 |
f"WeeklyTail={weekly_tail} WeeklyPI80={weekly_pi80} "
|
| 58 |
-
f"
|
|
|
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
passed, reasons = evaluate_quality_gate(
|
|
@@ -69,7 +75,11 @@ def main() -> int:
|
|
| 69 |
weekly_magnitude_ratio=weekly_mr,
|
| 70 |
weekly_tail_capture_rate=weekly_tail,
|
| 71 |
weekly_pi80_coverage=weekly_pi80,
|
|
|
|
|
|
|
|
|
|
| 72 |
weekly_quantile_crossing_rate=weekly_qcross,
|
|
|
|
| 73 |
weekly_median_sort_gap_max=weekly_gap,
|
| 74 |
weekly_sample_count=weekly_samples,
|
| 75 |
)
|
|
|
|
| 41 |
weekly_mr = metrics.get("weekly_magnitude_ratio")
|
| 42 |
weekly_tail = metrics.get("weekly_tail_capture_rate")
|
| 43 |
weekly_pi80 = metrics.get("weekly_pi80_coverage")
|
| 44 |
+
weekly_pi80_width_ratio = metrics.get("weekly_pi80_width_ratio")
|
| 45 |
+
weekly_pi96 = metrics.get("weekly_pi96_coverage")
|
| 46 |
+
weekly_pi96_width_ratio = metrics.get("weekly_pi96_width_ratio")
|
| 47 |
weekly_qcross = metrics.get("weekly_quantile_crossing_rate")
|
| 48 |
+
weekly_sorted_qcross = metrics.get("weekly_sorted_quantile_crossing_rate")
|
| 49 |
weekly_gap = metrics.get("weekly_median_sort_gap_max")
|
| 50 |
weekly_samples = metrics.get("weekly_sample_count")
|
| 51 |
|
|
|
|
| 59 |
"Weekly gate metrics: "
|
| 60 |
f"WeeklyDA={weekly_da} WeeklyMR={weekly_mr} "
|
| 61 |
f"WeeklyTail={weekly_tail} WeeklyPI80={weekly_pi80} "
|
| 62 |
+
f"WeeklyPI96WidthRatio={weekly_pi96_width_ratio} "
|
| 63 |
+
f"WeeklyQCross={weekly_qcross} WeeklySortedQCross={weekly_sorted_qcross} "
|
| 64 |
+
f"WeeklyN={weekly_samples}"
|
| 65 |
)
|
| 66 |
|
| 67 |
passed, reasons = evaluate_quality_gate(
|
|
|
|
| 75 |
weekly_magnitude_ratio=weekly_mr,
|
| 76 |
weekly_tail_capture_rate=weekly_tail,
|
| 77 |
weekly_pi80_coverage=weekly_pi80,
|
| 78 |
+
weekly_pi80_width_ratio=weekly_pi80_width_ratio,
|
| 79 |
+
weekly_pi96_coverage=weekly_pi96,
|
| 80 |
+
weekly_pi96_width_ratio=weekly_pi96_width_ratio,
|
| 81 |
weekly_quantile_crossing_rate=weekly_qcross,
|
| 82 |
+
weekly_sorted_quantile_crossing_rate=weekly_sorted_qcross,
|
| 83 |
weekly_median_sort_gap_max=weekly_gap,
|
| 84 |
weekly_sample_count=weekly_samples,
|
| 85 |
)
|