KarlQuant commited on
Commit
23e3118
Β·
verified Β·
1 Parent(s): 97e19c3

Upload Quasar_axrvi_ranker.py

Browse files
Files changed (1) hide show
  1. Quasar_axrvi_ranker.py +40 -13
Quasar_axrvi_ranker.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
  Created on Fri Apr 3 13:46:18 2026
@@ -2411,10 +2418,26 @@ def crps_loss(quantiles: torch.Tensor, quantile_levels: torch.Tensor,
2411
  y_true: torch.Tensor) -> torch.Tensor:
2412
  """
2413
  Continuous Ranked Probability Score (proper scoring rule for distributions).
 
 
 
 
 
 
 
 
 
 
2414
  """
2415
- y_exp = y_true.unsqueeze(-1)
2416
- errors = y_exp - quantiles
2417
- tau = quantile_levels.view(1, 1, -1)
 
 
 
 
 
 
2418
  pinball = torch.where(errors >= 0, tau * errors, (tau - 1) * errors)
2419
  return pinball.mean()
2420
 
@@ -3474,8 +3497,12 @@ class HybridTrainer:
3474
  l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
3475
 
3476
  # L_crps β€” distributional calibration (CRPS proper scoring rule) [v8]
 
 
 
 
3477
  ql_levels = self.model.distributional.quantile_levels
3478
- l_crps = crps_loss(quantiles, ql_levels, rewards)
3479
 
3480
  # L_rent β€” regime-router entropy regularisation [v8]
3481
  r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
@@ -4550,7 +4577,7 @@ class QuasarAXRVIBridge:
4550
  reward_strategy: str = "simple",
4551
  hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
4552
  enable_logging: bool = True,
4553
- checkpoint_dir: str = "./Ranker4", # folder for full-state checkpoints
4554
  resume: bool = False, # start afresh by default; set True to resume
4555
  ):
4556
  self.config = config or AssetRankerConfig()
@@ -6204,7 +6231,7 @@ class RankerCheckpointManager:
6204
  On load, missing local files are pulled from the repo first.
6205
  Requires HF_TOKEN env-var with write permission.
6206
 
6207
- Fallback : Local disk ./Ranker4/
6208
  Used when HF_TOKEN is absent or upload/download fails.
6209
  All saves still succeed locally even without network access.
6210
 
@@ -6258,7 +6285,7 @@ class RankerCheckpointManager:
6258
 
6259
  Usage
6260
  ─────
6261
- mgr = RankerCheckpointManager(checkpoint_dir="./Ranker4",
6262
  save_interval_seconds=1800)
6263
  mgr.load(bridge) # once after initialize()
6264
  mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
@@ -6274,7 +6301,7 @@ class RankerCheckpointManager:
6274
 
6275
  def __init__(
6276
  self,
6277
- checkpoint_dir: str = "./Ranker4",
6278
  save_interval_seconds: float = 1800.0, # 30 minutes
6279
  ):
6280
  import pathlib
@@ -6328,7 +6355,7 @@ class RankerCheckpointManager:
6328
  )
6329
  logger.info(
6330
  f"[RankerCheckpoint] ☁️ Uploaded {local_path.name} β†’ "
6331
- f"hf://{self.HF_REPO_ID}/Ranker4/{local_path.name}"
6332
  )
6333
  return True
6334
  except Exception as exc:
@@ -6359,7 +6386,7 @@ class RankerCheckpointManager:
6359
  )
6360
  logger.info(
6361
  f"[RankerCheckpoint] ⬇️ Downloaded {filename} from "
6362
- f"hf://{self.HF_REPO_ID}/Ranker4/"
6363
  )
6364
  return True
6365
  except Exception as exc:
@@ -6900,7 +6927,7 @@ async def run_live_trading_system(
6900
  hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
6901
  enable_logging: bool = True,
6902
  shreve_config: Optional[ShreveConfig] = None,
6903
- checkpoint_dir: str = "./Ranker4",
6904
  resume: bool = False, # start fresh by default
6905
  ) -> None:
6906
  config = AssetRankerConfig(
@@ -7212,8 +7239,8 @@ def _parse_args():
7212
  help="[S6/S8] Trade horizon Ο„ in seconds (default 60)")
7213
  parser.add_argument("--martingale-epsilon", type=float, default=0.05,
7214
  help="[S7] Gate E martingale deviation threshold (default 0.05)")
7215
- parser.add_argument("--checkpoint-dir", default="./Ranker4",
7216
- help="Directory for full-state checkpoints (default ./Ranker4)")
7217
  parser.add_argument("--resume", action="store_true",
7218
  help="Resume training from the latest saved checkpoint (default: start fresh)")
7219
  return parser.parse_args(filtered)
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Apr 3 21:26:48 2026
4
+
5
+ @author: taten
6
+ """
7
+
8
  # -*- coding: utf-8 -*-
9
  """
10
  Created on Fri Apr 3 13:46:18 2026
 
2418
  y_true: torch.Tensor) -> torch.Tensor:
2419
  """
2420
  Continuous Ranked Probability Score (proper scoring rule for distributions).
2421
+
2422
+ Supports two calling conventions:
2423
+ β€’ 2-D quantiles (B, Q) β€” e.g. selected-asset slice, y_true shape (B,)
2424
+ β€’ 3-D quantiles (B, N, Q) β€” e.g. full asset roster, y_true shape (B,)
2425
+
2426
+ The previous implementation used `y_true.unsqueeze(-1)` β†’ (B, 1) for both
2427
+ cases. When quantiles is 3-D, PyTorch left-pads (B, 1) to (1, B, 1) before
2428
+ broadcasting against (B, N, Q), producing a dim-1 mismatch of B vs N
2429
+ (e.g. "size 4 must match 9 at non-singleton dimension 1"). Fixed below by
2430
+ branching on ndim.
2431
  """
2432
+ if quantiles.dim() == 2:
2433
+ # (B, Q) β€” selected-asset slice or standalone call
2434
+ y_exp = y_true.unsqueeze(-1) # (B, 1)
2435
+ tau = quantile_levels.view(1, -1) # (1, Q)
2436
+ else:
2437
+ # (B, N, Q) β€” full asset roster
2438
+ y_exp = y_true.view(-1, 1, 1) # (B, 1, 1)
2439
+ tau = quantile_levels.view(1, 1, -1) # (1, 1, Q)
2440
+ errors = y_exp - quantiles
2441
  pinball = torch.where(errors >= 0, tau * errors, (tau - 1) * errors)
2442
  return pinball.mean()
2443
 
 
3497
  l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
3498
 
3499
  # L_crps β€” distributional calibration (CRPS proper scoring rule) [v8]
3500
+ # Use sel_q (B, n_quantiles) β€” the selected asset's quantile predictions β€”
3501
+ # not the full (B, N, n_quantiles) tensor. Calibrating every asset's
3502
+ # distribution against a single scalar reward is semantically incorrect;
3503
+ # only the asset that was actually traded has a matching observed return.
3504
  ql_levels = self.model.distributional.quantile_levels
3505
+ l_crps = crps_loss(sel_q, ql_levels, rewards)
3506
 
3507
  # L_rent β€” regime-router entropy regularisation [v8]
3508
  r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
 
4577
  reward_strategy: str = "simple",
4578
  hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
4579
  enable_logging: bool = True,
4580
+ checkpoint_dir: str = "./Ranker5", # folder for full-state checkpoints
4581
  resume: bool = False, # start afresh by default; set True to resume
4582
  ):
4583
  self.config = config or AssetRankerConfig()
 
6231
  On load, missing local files are pulled from the repo first.
6232
  Requires HF_TOKEN env-var with write permission.
6233
 
6234
+ Fallback : Local disk ./Ranker5/
6235
  Used when HF_TOKEN is absent or upload/download fails.
6236
  All saves still succeed locally even without network access.
6237
 
 
6285
 
6286
  Usage
6287
  ─────
6288
+ mgr = RankerCheckpointManager(checkpoint_dir="./Ranker5",
6289
  save_interval_seconds=1800)
6290
  mgr.load(bridge) # once after initialize()
6291
  mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
 
6301
 
6302
  def __init__(
6303
  self,
6304
+ checkpoint_dir: str = "./Ranker5",
6305
  save_interval_seconds: float = 1800.0, # 30 minutes
6306
  ):
6307
  import pathlib
 
6355
  )
6356
  logger.info(
6357
  f"[RankerCheckpoint] ☁️ Uploaded {local_path.name} β†’ "
6358
+ f"hf://{self.HF_REPO_ID}/Ranker5/{local_path.name}"
6359
  )
6360
  return True
6361
  except Exception as exc:
 
6386
  )
6387
  logger.info(
6388
  f"[RankerCheckpoint] ⬇️ Downloaded {filename} from "
6389
+ f"hf://{self.HF_REPO_ID}/Ranker5/"
6390
  )
6391
  return True
6392
  except Exception as exc:
 
6927
  hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
6928
  enable_logging: bool = True,
6929
  shreve_config: Optional[ShreveConfig] = None,
6930
+ checkpoint_dir: str = "./Ranker5",
6931
  resume: bool = False, # start fresh by default
6932
  ) -> None:
6933
  config = AssetRankerConfig(
 
7239
  help="[S6/S8] Trade horizon Ο„ in seconds (default 60)")
7240
  parser.add_argument("--martingale-epsilon", type=float, default=0.05,
7241
  help="[S7] Gate E martingale deviation threshold (default 0.05)")
7242
+ parser.add_argument("--checkpoint-dir", default="./Ranker5",
7243
+ help="Directory for full-state checkpoints (default ./Ranker5)")
7244
  parser.add_argument("--resume", action="store_true",
7245
  help="Resume training from the latest saved checkpoint (default: start fresh)")
7246
  return parser.parse_args(filtered)