Spaces:
Running
Running
Upload Quasar_axrvi_ranker.py
Browse files- Quasar_axrvi_ranker.py +40 -13
Quasar_axrvi_ranker.py
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
Created on Fri Apr 3 13:46:18 2026
|
|
@@ -2411,10 +2418,26 @@ def crps_loss(quantiles: torch.Tensor, quantile_levels: torch.Tensor,
|
|
| 2411 |
y_true: torch.Tensor) -> torch.Tensor:
|
| 2412 |
"""
|
| 2413 |
Continuous Ranked Probability Score (proper scoring rule for distributions).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2414 |
"""
|
| 2415 |
-
|
| 2416 |
-
|
| 2417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2418 |
pinball = torch.where(errors >= 0, tau * errors, (tau - 1) * errors)
|
| 2419 |
return pinball.mean()
|
| 2420 |
|
|
@@ -3474,8 +3497,12 @@ class HybridTrainer:
|
|
| 3474 |
l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
|
| 3475 |
|
| 3476 |
# L_crps β distributional calibration (CRPS proper scoring rule) [v8]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3477 |
ql_levels = self.model.distributional.quantile_levels
|
| 3478 |
-
l_crps = crps_loss(
|
| 3479 |
|
| 3480 |
# L_rent β regime-router entropy regularisation [v8]
|
| 3481 |
r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
|
|
@@ -4550,7 +4577,7 @@ class QuasarAXRVIBridge:
|
|
| 4550 |
reward_strategy: str = "simple",
|
| 4551 |
hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
|
| 4552 |
enable_logging: bool = True,
|
| 4553 |
-
checkpoint_dir: str = "./
|
| 4554 |
resume: bool = False, # start afresh by default; set True to resume
|
| 4555 |
):
|
| 4556 |
self.config = config or AssetRankerConfig()
|
|
@@ -6204,7 +6231,7 @@ class RankerCheckpointManager:
|
|
| 6204 |
On load, missing local files are pulled from the repo first.
|
| 6205 |
Requires HF_TOKEN env-var with write permission.
|
| 6206 |
|
| 6207 |
-
Fallback : Local disk ./
|
| 6208 |
Used when HF_TOKEN is absent or upload/download fails.
|
| 6209 |
All saves still succeed locally even without network access.
|
| 6210 |
|
|
@@ -6258,7 +6285,7 @@ class RankerCheckpointManager:
|
|
| 6258 |
|
| 6259 |
Usage
|
| 6260 |
βββββ
|
| 6261 |
-
mgr = RankerCheckpointManager(checkpoint_dir="./
|
| 6262 |
save_interval_seconds=1800)
|
| 6263 |
mgr.load(bridge) # once after initialize()
|
| 6264 |
mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
|
|
@@ -6274,7 +6301,7 @@ class RankerCheckpointManager:
|
|
| 6274 |
|
| 6275 |
def __init__(
|
| 6276 |
self,
|
| 6277 |
-
checkpoint_dir: str = "./
|
| 6278 |
save_interval_seconds: float = 1800.0, # 30 minutes
|
| 6279 |
):
|
| 6280 |
import pathlib
|
|
@@ -6328,7 +6355,7 @@ class RankerCheckpointManager:
|
|
| 6328 |
)
|
| 6329 |
logger.info(
|
| 6330 |
f"[RankerCheckpoint] βοΈ Uploaded {local_path.name} β "
|
| 6331 |
-
f"hf://{self.HF_REPO_ID}/
|
| 6332 |
)
|
| 6333 |
return True
|
| 6334 |
except Exception as exc:
|
|
@@ -6359,7 +6386,7 @@ class RankerCheckpointManager:
|
|
| 6359 |
)
|
| 6360 |
logger.info(
|
| 6361 |
f"[RankerCheckpoint] β¬οΈ Downloaded {filename} from "
|
| 6362 |
-
f"hf://{self.HF_REPO_ID}/
|
| 6363 |
)
|
| 6364 |
return True
|
| 6365 |
except Exception as exc:
|
|
@@ -6900,7 +6927,7 @@ async def run_live_trading_system(
|
|
| 6900 |
hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
|
| 6901 |
enable_logging: bool = True,
|
| 6902 |
shreve_config: Optional[ShreveConfig] = None,
|
| 6903 |
-
checkpoint_dir: str = "./
|
| 6904 |
resume: bool = False, # start fresh by default
|
| 6905 |
) -> None:
|
| 6906 |
config = AssetRankerConfig(
|
|
@@ -7212,8 +7239,8 @@ def _parse_args():
|
|
| 7212 |
help="[S6/S8] Trade horizon Ο in seconds (default 60)")
|
| 7213 |
parser.add_argument("--martingale-epsilon", type=float, default=0.05,
|
| 7214 |
help="[S7] Gate E martingale deviation threshold (default 0.05)")
|
| 7215 |
-
parser.add_argument("--checkpoint-dir", default="./
|
| 7216 |
-
help="Directory for full-state checkpoints (default ./
|
| 7217 |
parser.add_argument("--resume", action="store_true",
|
| 7218 |
help="Resume training from the latest saved checkpoint (default: start fresh)")
|
| 7219 |
return parser.parse_args(filtered)
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Fri Apr 3 21:26:48 2026
|
| 4 |
+
|
| 5 |
+
@author: taten
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
# -*- coding: utf-8 -*-
|
| 9 |
"""
|
| 10 |
Created on Fri Apr 3 13:46:18 2026
|
|
|
|
| 2418 |
y_true: torch.Tensor) -> torch.Tensor:
|
| 2419 |
"""
|
| 2420 |
Continuous Ranked Probability Score (proper scoring rule for distributions).
|
| 2421 |
+
|
| 2422 |
+
Supports two calling conventions:
|
| 2423 |
+
β’ 2-D quantiles (B, Q) β e.g. selected-asset slice, y_true shape (B,)
|
| 2424 |
+
β’ 3-D quantiles (B, N, Q) β e.g. full asset roster, y_true shape (B,)
|
| 2425 |
+
|
| 2426 |
+
The previous implementation used `y_true.unsqueeze(-1)` β (B, 1) for both
|
| 2427 |
+
cases. When quantiles is 3-D, PyTorch left-pads (B, 1) to (1, B, 1) before
|
| 2428 |
+
broadcasting against (B, N, Q), producing a dim-1 mismatch of B vs N
|
| 2429 |
+
(e.g. "size 4 must match 9 at non-singleton dimension 1"). Fixed below by
|
| 2430 |
+
branching on ndim.
|
| 2431 |
"""
|
| 2432 |
+
if quantiles.dim() == 2:
|
| 2433 |
+
# (B, Q) β selected-asset slice or standalone call
|
| 2434 |
+
y_exp = y_true.unsqueeze(-1) # (B, 1)
|
| 2435 |
+
tau = quantile_levels.view(1, -1) # (1, Q)
|
| 2436 |
+
else:
|
| 2437 |
+
# (B, N, Q) β full asset roster
|
| 2438 |
+
y_exp = y_true.view(-1, 1, 1) # (B, 1, 1)
|
| 2439 |
+
tau = quantile_levels.view(1, 1, -1) # (1, 1, Q)
|
| 2440 |
+
errors = y_exp - quantiles
|
| 2441 |
pinball = torch.where(errors >= 0, tau * errors, (tau - 1) * errors)
|
| 2442 |
return pinball.mean()
|
| 2443 |
|
|
|
|
| 3497 |
l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
|
| 3498 |
|
| 3499 |
# L_crps β distributional calibration (CRPS proper scoring rule) [v8]
|
| 3500 |
+
# Use sel_q (B, n_quantiles) β the selected asset's quantile predictions β
|
| 3501 |
+
# not the full (B, N, n_quantiles) tensor. Calibrating every asset's
|
| 3502 |
+
# distribution against a single scalar reward is semantically incorrect;
|
| 3503 |
+
# only the asset that was actually traded has a matching observed return.
|
| 3504 |
ql_levels = self.model.distributional.quantile_levels
|
| 3505 |
+
l_crps = crps_loss(sel_q, ql_levels, rewards)
|
| 3506 |
|
| 3507 |
# L_rent β regime-router entropy regularisation [v8]
|
| 3508 |
r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
|
|
|
|
| 4577 |
reward_strategy: str = "simple",
|
| 4578 |
hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
|
| 4579 |
enable_logging: bool = True,
|
| 4580 |
+
checkpoint_dir: str = "./Ranker5", # folder for full-state checkpoints
|
| 4581 |
resume: bool = False, # start afresh by default; set True to resume
|
| 4582 |
):
|
| 4583 |
self.config = config or AssetRankerConfig()
|
|
|
|
| 6231 |
On load, missing local files are pulled from the repo first.
|
| 6232 |
Requires HF_TOKEN env-var with write permission.
|
| 6233 |
|
| 6234 |
+
Fallback : Local disk ./Ranker5/
|
| 6235 |
Used when HF_TOKEN is absent or upload/download fails.
|
| 6236 |
All saves still succeed locally even without network access.
|
| 6237 |
|
|
|
|
| 6285 |
|
| 6286 |
Usage
|
| 6287 |
βββββ
|
| 6288 |
+
mgr = RankerCheckpointManager(checkpoint_dir="./Ranker5",
|
| 6289 |
save_interval_seconds=1800)
|
| 6290 |
mgr.load(bridge) # once after initialize()
|
| 6291 |
mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
|
|
|
|
| 6301 |
|
| 6302 |
def __init__(
|
| 6303 |
self,
|
| 6304 |
+
checkpoint_dir: str = "./Ranker5",
|
| 6305 |
save_interval_seconds: float = 1800.0, # 30 minutes
|
| 6306 |
):
|
| 6307 |
import pathlib
|
|
|
|
| 6355 |
)
|
| 6356 |
logger.info(
|
| 6357 |
f"[RankerCheckpoint] βοΈ Uploaded {local_path.name} β "
|
| 6358 |
+
f"hf://{self.HF_REPO_ID}/Ranker5/{local_path.name}"
|
| 6359 |
)
|
| 6360 |
return True
|
| 6361 |
except Exception as exc:
|
|
|
|
| 6386 |
)
|
| 6387 |
logger.info(
|
| 6388 |
f"[RankerCheckpoint] β¬οΈ Downloaded {filename} from "
|
| 6389 |
+
f"hf://{self.HF_REPO_ID}/Ranker5/"
|
| 6390 |
)
|
| 6391 |
return True
|
| 6392 |
except Exception as exc:
|
|
|
|
| 6927 |
hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
|
| 6928 |
enable_logging: bool = True,
|
| 6929 |
shreve_config: Optional[ShreveConfig] = None,
|
| 6930 |
+
checkpoint_dir: str = "./Ranker5",
|
| 6931 |
resume: bool = False, # start fresh by default
|
| 6932 |
) -> None:
|
| 6933 |
config = AssetRankerConfig(
|
|
|
|
| 7239 |
help="[S6/S8] Trade horizon Ο in seconds (default 60)")
|
| 7240 |
parser.add_argument("--martingale-epsilon", type=float, default=0.05,
|
| 7241 |
help="[S7] Gate E martingale deviation threshold (default 0.05)")
|
| 7242 |
+
parser.add_argument("--checkpoint-dir", default="./Ranker5",
|
| 7243 |
+
help="Directory for full-state checkpoints (default ./Ranker5)")
|
| 7244 |
parser.add_argument("--resume", action="store_true",
|
| 7245 |
help="Resume training from the latest saved checkpoint (default: start fresh)")
|
| 7246 |
return parser.parse_args(filtered)
|