KarlQuant committed on
Commit
640294f
·
verified ·
1 Parent(s): f1ebef0

Upload Quasar_axrvi_ranker.py

Browse files
Files changed (1) hide show
  1. Quasar_axrvi_ranker.py +257 -104
Quasar_axrvi_ranker.py CHANGED
@@ -392,8 +392,8 @@ GAMMA = 0.99
392
  LAMBDA_RANK = 0.4
393
  LAMBDA_RISK = 0.3
394
  REPLAY_CAPACITY = 10_000
395
- TRAIN_BATCH = 16 # Reduced from 64 β€” fills after ~3–5 min of live trading
396
- TRAIN_EVERY_N = 10 # Reduced from 50 β€” checks buffer every ~2 min
397
 
398
  # Connection
399
  WS_RECONNECT_DELAY = 5
@@ -1751,9 +1751,14 @@ class STDPAdjacencyLayer(nn.Module):
1751
  regime_weights: (B, N) optional multiplier (e.g., crash probability)
1752
  """
1753
  B, N, _ = asset_embeddings.shape
 
 
1754
 
 
 
 
1755
  # Activity proxy: L2 norm normalized across batch
1756
- activity = asset_embeddings.norm(dim=-1)
1757
  mu = activity.mean(dim=0, keepdim=True)
1758
  std = activity.std(dim=0, keepdim=True) + 1e-6
1759
  z_activity = (activity - mu) / std
@@ -1787,8 +1792,16 @@ class STDPAdjacencyLayer(nn.Module):
1787
  self.step_count += 1
1788
 
1789
  def get_adapted_bias(self, base_adj_bias: torch.Tensor, N: int) -> torch.Tensor:
1790
- """Return base bias + accumulated STDP delta."""
1791
- return base_adj_bias[:, :N, :N] + self.stdp_delta[:, :N, :N]
 
 
 
 
 
 
 
 
1792
 
1793
  def reset_plasticity(self) -> None:
1794
  """Reset STDP state (call on regime shift detection)."""
@@ -2015,7 +2028,7 @@ class HyperbolicCrossAssetLayer(nn.Module):
2015
  nn.Linear(d_model, d_model)
2016
  )
2017
 
2018
- # Learnable adjacency bias
2019
  self.adj_bias = nn.Parameter(torch.zeros(num_heads, self.MAX_ASSETS, self.MAX_ASSETS))
2020
 
2021
  # STDP plasticity
@@ -2057,6 +2070,12 @@ class HyperbolicCrossAssetLayer(nn.Module):
2057
 
2058
  adapted_bias = self.stdp.get_adapted_bias(self.adj_bias, N)
2059
 
 
 
 
 
 
 
2060
  attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
2061
  attn = attn + adapted_bias.unsqueeze(0)
2062
 
@@ -2458,8 +2477,25 @@ class AXRVINet(nn.Module):
2458
  [5] KANScoringHead β†’ significance scores
2459
  """
2460
 
2461
- def __init__(self, num_assets: int = 5, config: AXRVIConfig = DEFAULT_CONFIG):
 
 
 
 
 
 
 
 
2462
  super().__init__()
 
 
 
 
 
 
 
 
 
2463
  self.num_assets = num_assets
2464
  self.config = config
2465
  self.d_model = config.d_model
@@ -2809,8 +2845,12 @@ def create_axrvi_v8(num_assets: int = 5,
2809
  def _axrvi_config_from_ranker_config(rc: "AssetRankerConfig") -> AXRVIConfig:
2810
  """
2811
  Derive an ``AXRVIConfig`` from an ``AssetRankerConfig`` so all dimension
2812
- constants live in one place. Used by QuasarAXRVIBridge.initialize() and
2813
- any stand-alone callers that have an AssetRankerConfig available.
 
 
 
 
2814
 
2815
  The mapping is:
2816
  rc.feature_dim β†’ AXRVIConfig.feature_dim (26)
@@ -3325,6 +3365,16 @@ class HybridTrainer:
3325
  self.train_step = 0
3326
  self.loss_history = deque(maxlen=200)
3327
 
 
 
 
 
 
 
 
 
 
 
3328
  def train_on_batch(self, episodes: List[dict]) -> dict:
3329
  """
3330
  v7 training β€” 5-loss objective:
@@ -3369,88 +3419,84 @@ class HybridTrainer:
3369
  imp_w = torch.tensor([ep.get("importance_weight", 1.0) for ep in valid],
3370
  dtype=torch.float32, device=device)
3371
 
3372
- out = self.model(seq_t)
3373
- scores = out["significance_logits"]
3374
- value = out["value"].squeeze(-1) # (B, N) β€” median QΜ‚_{0.5}
3375
- log_var = out["log_var"].squeeze(-1) # (B, N)
3376
- quantiles = out["quantiles"] # (B, N, n_quantiles)
3377
-
3378
- with torch.no_grad():
3379
- next_out = self.model(next_seq_t)
3380
- best_next_v = next_out["value"].squeeze(-1).max(dim=1).values
3381
-
3382
- # L_rl β€” TD error
3383
- selected_v = value.gather(1, selected.unsqueeze(1)).squeeze(1)
3384
- td_target = rewards + self.gamma * best_next_v
3385
- l_rl = (imp_w * F.mse_loss(selected_v, td_target.detach(),
3386
- reduction="none")).mean()
3387
-
3388
- # L_ce β€” [S1] Value-consistency: selected_v β‰ˆ E[R_{t+Ο„} | F_t]
3389
- l_ce = (imp_w * F.mse_loss(selected_v, rewards.detach(),
3390
- reduction="none")).mean()
3391
-
3392
- # L_rank β€” Ranking margin
3393
- best_idx = pnl_arr.argmax(dim=1)
3394
- worst_idx = pnl_arr.argmin(dim=1)
3395
- l_rank = F.relu(
3396
- self.rank_margin
3397
- - (scores.gather(1, best_idx.unsqueeze(1)).squeeze(1)
3398
- - scores.gather(1, worst_idx.unsqueeze(1)).squeeze(1))
3399
- ).mean()
3400
-
3401
- # L_risk β€” Uncertainty penalty
3402
- l_risk = torch.exp(log_var.gather(1, selected.unsqueeze(1)).squeeze(1)).mean()
3403
-
3404
- # L_ql β€” Quantile / pinball loss [v7]: calibrates full return distribution
3405
- # For the selected asset: gather its quantile predictions (B, n_q)
3406
- sel_q = quantiles.gather(
3407
- 1,
3408
- selected.unsqueeze(1).unsqueeze(2).expand(-1, 1, quantiles.shape[-1])
3409
- ).squeeze(1) # (B, n_quantiles)
3410
-
3411
- tau = self.model.distributional.quantile_levels # (n_quantiles,)
3412
- u = rewards.unsqueeze(1) - sel_q # (B, n_quantiles)
3413
- # ρ_Ο„(u) = uΒ·Ο„ if uβ‰₯0 else uΒ·(Ο„βˆ’1)
3414
- l_ql = torch.max(tau * u, (tau - 1.0) * u).mean()
3415
-
3416
- # L_moe β€” MoE load-balance regularisation [v8]
3417
- # Pulls the scalar already computed in MoETemporalEncoder.forward()
3418
- l_moe = out.get("moe_balance_loss", torch.tensor(0.0, device=seq_t.device))
3419
-
3420
- # L_gate β€” DendriticFFN gate-entropy regularisation [v8]
3421
- # Accumulated from every HyperbolicCrossAssetLayer in AXRVINet.forward()
3422
- l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
3423
-
3424
- # L_crps β€” distributional calibration (CRPS proper scoring rule) [v8 / Bug 4 fix]
3425
- # Previously computed only in v8_total_loss() which was never called.
3426
- # Now integrated here so the quantile head is directly calibrated during training.
3427
- ql_levels = self.model.distributional.quantile_levels # (n_quantiles,)
3428
- l_crps = crps_loss(quantiles, ql_levels, rewards)
3429
-
3430
- # L_rent β€” regime-router entropy regularisation [v8 / Bug 4 fix]
3431
- # Maximise regime diversity by minimising the negative entropy.
3432
- r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
3433
- regime_ent = -(r_probs * r_probs.log()).sum(-1).mean() # positive scalar
3434
- l_rent = -regime_ent # negate: minimise loss β†’ maximise entropy
3435
-
3436
- # Total loss β€” 9-component objective [v8 complete]
3437
- # total = L_rl + Ξ»_ceΒ·L_ce + Ξ»_rankΒ·L_rank + Ξ»_riskΒ·L_risk
3438
- # + Ξ»_qlΒ·L_ql + Ξ»_moeΒ·L_moe + Ξ»_gateΒ·L_gate
3439
- # + Ξ»_crpsΒ·L_crps + Ξ»_rentΒ·L_rent
3440
- loss = (l_rl
3441
- + self.lambda_ce * l_ce
3442
- + self.lambda_rank * l_rank
3443
- + self.lambda_risk * l_risk
3444
- + self.lambda_ql * l_ql
3445
- + self.lambda_moe * l_moe
3446
- + self.lambda_gate * l_gate
3447
- + self.lambda_crps * l_crps
3448
- + self.lambda_rent * l_rent)
3449
-
3450
  self.optimizer.zero_grad()
3451
- loss.backward()
 
 
3452
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
3453
- self.optimizer.step()
 
3454
  # scheduler.step() is called externally (once per training epoch / rank
3455
  # cycle) via step_scheduler(), NOT here per batch. Calling it here
3456
  # compressed the entire T_max=1000 cosine schedule into ~14 hours of
@@ -4500,7 +4546,7 @@ class QuasarAXRVIBridge:
4500
  reward_strategy: str = "simple",
4501
  hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
4502
  enable_logging: bool = True,
4503
- checkpoint_dir: str = "./Ranker2", # folder for full-state checkpoints
4504
  resume: bool = False, # start afresh by default; set True to resume
4505
  ):
4506
  self.config = config or AssetRankerConfig()
@@ -4812,6 +4858,15 @@ class QuasarAXRVIBridge:
4812
  with self.position_mgr._lock:
4813
  rejected = self.position_mgr._open_trades.pop(trade_id, None)
4814
  if rejected:
 
 
 
 
 
 
 
 
 
4815
  logger.error(
4816
  f"❌ [{rejected.asset}] BROKER REJECTED buy | "
4817
  f"trade_id={trade_id} | code={code} | {message}"
@@ -4900,8 +4955,16 @@ class QuasarAXRVIBridge:
4900
  else:
4901
  logger.warning(
4902
  f"[Deriv] Buy confirmation β€” no trade_id for req_id={req_id} | "
4903
- f"contract_id={contract_id} (late or orphaned confirmation)"
 
4904
  )
 
 
 
 
 
 
 
4905
 
4906
  def _on_poc_update(self, poc: dict, raw_msg: dict) -> None:
4907
  """
@@ -5257,6 +5320,12 @@ class QuasarAXRVIBridge:
5257
  for asset in departed:
5258
  trade = self.position_mgr.get_open_trade_by_asset(asset)
5259
  if trade:
 
 
 
 
 
 
5260
  streamer = self.price_streamers.get(asset)
5261
  price = streamer.latest_mid if streamer else trade.entry_price
5262
  logger.info(
@@ -5270,11 +5339,18 @@ class QuasarAXRVIBridge:
5270
  # ── Position monitoring ────────────────────────────────────────────────────────────
5271
 
5272
  async def _close_position(self, trade_id: str, exit_price: float) -> None:
5273
- # ── Early-exit sell on Deriv ──────────────────────────────────────────
5274
- # Retrieve contract_id BEFORE close_trade() removes the trade from
5275
- # _open_trades. price=0 means "sell at best available market price".
5276
  with self.position_mgr._lock:
5277
  open_trade = self.position_mgr._open_trades.get(trade_id)
 
 
 
 
 
 
 
 
 
5278
  cid = open_trade.contract_id if open_trade else None
5279
 
5280
  if cid and self.ws_client and self.ws_client.connected:
@@ -5356,13 +5432,75 @@ class QuasarAXRVIBridge:
5356
  After any position is closed, if open_trade_count drops below 2, immediately
5357
  call rank_and_gate() to refill. This ensures the 2-trade minimum is maintained
5358
  continuously without waiting for the next scheduled _rank_loop cycle.
 
 
 
 
5359
  """
5360
  sc = self.config.shreve_config
 
 
5361
  while self.running:
5362
  try:
5363
  closed_any = False # track whether we closed a trade this tick
5364
 
5365
  for trade in self.position_mgr.get_open_trades():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5366
  streamer = self.price_streamers.get(trade.asset)
5367
  if not streamer:
5368
  continue
@@ -5384,8 +5522,6 @@ class QuasarAXRVIBridge:
5384
  raw_log_ret = math.log(price / trade.entry_price)
5385
  sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0
5386
  # Fees for G_t must be on stake, NOT on spot price.
5387
- # price * commission_rate would be e.g. 316738 * 0.001 = $316 (wrong)
5388
- # stake * commission_rate is e.g. 1.0 * 0.001 = $0.001 (correct)
5389
  fees = self.trade_config.amount * self.trade_config.commission_rate
5390
  slippage = self.trade_config.slippage_bps / 10_000.0
5391
  g_t = sign * raw_log_ret - fees / price - slippage
@@ -5932,6 +6068,21 @@ class QuasarAXRVIBridge:
5932
  next_seq_t, _, _ = self._build_input_tensors()
5933
  next_sequences = next_seq_t.squeeze(0).numpy() # (N, T, F)
5934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5935
  pnl_proxy = np.zeros(n_assets, dtype=np.float32)
5936
  if asset_id in self.config.asset_symbols:
5937
  pnl_proxy[self.config.asset_symbols.index(asset_id)] = float(reward)
@@ -5945,6 +6096,8 @@ class QuasarAXRVIBridge:
5945
  # Girsanov priority metadata [S2/GirsanovReplayBuffer]
5946
  "volatility": ep["volatility"],
5947
  "td_error": abs(reward),
 
 
5948
  })
5949
 
5950
  if asset_id in self.config.asset_symbols:
@@ -6038,7 +6191,7 @@ class RankerCheckpointManager:
6038
  On load, missing local files are pulled from the repo first.
6039
  Requires HF_TOKEN env-var with write permission.
6040
 
6041
- Fallback : Local disk ./Ranker2/
6042
  Used when HF_TOKEN is absent or upload/download fails.
6043
  All saves still succeed locally even without network access.
6044
 
@@ -6092,7 +6245,7 @@ class RankerCheckpointManager:
6092
 
6093
  Usage
6094
  ─────
6095
- mgr = RankerCheckpointManager(checkpoint_dir="./Ranker2",
6096
  save_interval_seconds=1800)
6097
  mgr.load(bridge) # once after initialize()
6098
  mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
@@ -6108,7 +6261,7 @@ class RankerCheckpointManager:
6108
 
6109
  def __init__(
6110
  self,
6111
- checkpoint_dir: str = "./Ranker2",
6112
  save_interval_seconds: float = 1800.0, # 30 minutes
6113
  ):
6114
  import pathlib
@@ -6162,7 +6315,7 @@ class RankerCheckpointManager:
6162
  )
6163
  logger.info(
6164
  f"[RankerCheckpoint] ☁️ Uploaded {local_path.name} β†’ "
6165
- f"hf://{self.HF_REPO_ID}/Ranker2/{local_path.name}"
6166
  )
6167
  return True
6168
  except Exception as exc:
@@ -6193,7 +6346,7 @@ class RankerCheckpointManager:
6193
  )
6194
  logger.info(
6195
  f"[RankerCheckpoint] ⬇️ Downloaded {filename} from "
6196
- f"hf://{self.HF_REPO_ID}/Ranker2/"
6197
  )
6198
  return True
6199
  except Exception as exc:
@@ -6734,7 +6887,7 @@ async def run_live_trading_system(
6734
  hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
6735
  enable_logging: bool = True,
6736
  shreve_config: Optional[ShreveConfig] = None,
6737
- checkpoint_dir: str = "./Ranker2",
6738
  resume: bool = False, # start fresh by default
6739
  ) -> None:
6740
  config = AssetRankerConfig(
@@ -7046,8 +7199,8 @@ def _parse_args():
7046
  help="[S6/S8] Trade horizon Ο„ in seconds (default 60)")
7047
  parser.add_argument("--martingale-epsilon", type=float, default=0.05,
7048
  help="[S7] Gate E martingale deviation threshold (default 0.05)")
7049
- parser.add_argument("--checkpoint-dir", default="./Ranker2",
7050
- help="Directory for full-state checkpoints (default ./Ranker2)")
7051
  parser.add_argument("--resume", action="store_true",
7052
  help="Resume training from the latest saved checkpoint (default: start fresh)")
7053
  return parser.parse_args(filtered)
 
392
  LAMBDA_RANK = 0.4
393
  LAMBDA_RISK = 0.3
394
  REPLAY_CAPACITY = 10_000
395
+ TRAIN_BATCH = 8 # Lowered from 16 β€” fills after ~8 closed trades (~2-4 min)
396
+ TRAIN_EVERY_N = 5 # Lowered from 10 β€” checks buffer every ~25s
397
 
398
  # Connection
399
  WS_RECONNECT_DELAY = 5
 
1751
  regime_weights: (B, N) optional multiplier (e.g., crash probability)
1752
  """
1753
  B, N, _ = asset_embeddings.shape
1754
+ # Clamp N to max_assets β€” extra assets are simply not tracked
1755
+ N = min(N, self.max_assets)
1756
 
1757
+ # Use only the first N (clamped) assets
1758
+ asset_emb_n = asset_embeddings[:, :N, :]
1759
+
1760
  # Activity proxy: L2 norm normalized across batch
1761
+ activity = asset_emb_n.norm(dim=-1)
1762
  mu = activity.mean(dim=0, keepdim=True)
1763
  std = activity.std(dim=0, keepdim=True) + 1e-6
1764
  z_activity = (activity - mu) / std
 
1792
  self.step_count += 1
1793
 
1794
  def get_adapted_bias(self, base_adj_bias: torch.Tensor, N: int) -> torch.Tensor:
1795
+ """Return base bias + accumulated STDP delta, clamped to registered max_assets."""
1796
+ N_clamped = min(N, self.max_assets)
1797
+ if N > self.max_assets:
1798
+ # Pad with zeros for the extra assets beyond max_assets
1799
+ pad = N - self.max_assets
1800
+ padded_delta = F.pad(self.stdp_delta[:, :N_clamped, :N_clamped],
1801
+ (0, pad, 0, pad), value=0.0)
1802
+ padded_base = base_adj_bias[:, :N, :N]
1803
+ return padded_base + padded_delta
1804
+ return base_adj_bias[:, :N_clamped, :N_clamped] + self.stdp_delta[:, :N_clamped, :N_clamped]
1805
 
1806
  def reset_plasticity(self) -> None:
1807
  """Reset STDP state (call on regime shift detection)."""
 
2028
  nn.Linear(d_model, d_model)
2029
  )
2030
 
2031
+ # Learnable adjacency bias β€” sized to MAX_ASSETS; dynamically padded in forward()
2032
  self.adj_bias = nn.Parameter(torch.zeros(num_heads, self.MAX_ASSETS, self.MAX_ASSETS))
2033
 
2034
  # STDP plasticity
 
2070
 
2071
  adapted_bias = self.stdp.get_adapted_bias(self.adj_bias, N)
2072
 
2073
+ # If N > MAX_ASSETS the adapted_bias may be smaller; pad with zeros
2074
+ if adapted_bias.shape[-1] < N or adapted_bias.shape[-2] < N:
2075
+ pad_r = N - adapted_bias.shape[-1]
2076
+ pad_c = N - adapted_bias.shape[-2]
2077
+ adapted_bias = F.pad(adapted_bias, (0, pad_r, 0, pad_c), value=0.0)
2078
+
2079
  attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
2080
  attn = attn + adapted_bias.unsqueeze(0)
2081
 
 
2477
  [5] KANScoringHead β†’ significance scores
2478
  """
2479
 
2480
+ def __init__(self, num_assets: int = 5, config: AXRVIConfig = DEFAULT_CONFIG,
2481
+ feature_dim: int = None, seq_len: int = None):
2482
+ """
2483
+ Args:
2484
+ num_assets : number of parallel asset streams (N dimension)
2485
+ config : AXRVIConfig β€” all hyperparameters
2486
+ feature_dim : override config.feature_dim (backward compat)
2487
+ seq_len : override config.seq_len (backward compat)
2488
+ """
2489
  super().__init__()
2490
+ # Apply any direct overrides so callers that pass feature_dim/seq_len
2491
+ # directly (e.g. test_components) still get the right architecture.
2492
+ if feature_dim is not None or seq_len is not None:
2493
+ import copy
2494
+ config = copy.copy(config)
2495
+ if feature_dim is not None:
2496
+ config.feature_dim = feature_dim
2497
+ if seq_len is not None:
2498
+ config.seq_len = seq_len
2499
  self.num_assets = num_assets
2500
  self.config = config
2501
  self.d_model = config.d_model
 
2845
  def _axrvi_config_from_ranker_config(rc: "AssetRankerConfig") -> AXRVIConfig:
2846
  """
2847
  Derive an ``AXRVIConfig`` from an ``AssetRankerConfig`` so all dimension
2848
+ constants live in one place.
2849
+
2850
+ Called by ``QuasarAXRVIBridge.initialize()`` to construct the neural-net
2851
+ config directly from the top-level ranker config, avoiding duplicated
2852
+ constants. Any stand-alone caller that has an ``AssetRankerConfig`` may
2853
+ use this function as well.
2854
 
2855
  The mapping is:
2856
  rc.feature_dim β†’ AXRVIConfig.feature_dim (26)
 
3365
  self.train_step = 0
3366
  self.loss_history = deque(maxlen=200)
3367
 
3368
+ # ── AMP (Automatic Mixed Precision) ──────────────────────────────────
3369
+ # Reads use_amp from the model's AXRVIConfig so the flag is the single
3370
+ # source of truth. Falls back to False on CPU (AMP is CUDA-only).
3371
+ _cfg = getattr(model, "config", None)
3372
+ _use_amp = bool(getattr(_cfg, "use_amp", False))
3373
+ _device = next(model.parameters()).device
3374
+ # AMP is only meaningful on CUDA; silently disable on CPU
3375
+ self.use_amp = _use_amp and _device.type == "cuda"
3376
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
3377
+
3378
  def train_on_batch(self, episodes: List[dict]) -> dict:
3379
  """
3380
  v7 training β€” 5-loss objective:
 
3419
  imp_w = torch.tensor([ep.get("importance_weight", 1.0) for ep in valid],
3420
  dtype=torch.float32, device=device)
3421
 
3422
+ # ── Forward pass (AMP-aware) ──────────────────────────────────────────
3423
+ with torch.cuda.amp.autocast(enabled=self.use_amp):
3424
+ out = self.model(seq_t)
3425
+ scores = out["significance_logits"]
3426
+ value = out["value"].squeeze(-1) # (B, N) β€” median QΜ‚_{0.5}
3427
+ log_var = out["log_var"].squeeze(-1) # (B, N)
3428
+ quantiles = out["quantiles"] # (B, N, n_quantiles)
3429
+
3430
+ with torch.no_grad():
3431
+ next_out = self.model(next_seq_t)
3432
+ best_next_v = next_out["value"].squeeze(-1).max(dim=1).values
3433
+
3434
+ # L_rl β€” TD error
3435
+ selected_v = value.gather(1, selected.unsqueeze(1)).squeeze(1)
3436
+ td_target = rewards + self.gamma * best_next_v
3437
+ l_rl = (imp_w * F.mse_loss(selected_v, td_target.detach(),
3438
+ reduction="none")).mean()
3439
+
3440
+ # L_ce β€” [S1] Value-consistency: selected_v β‰ˆ E[R_{t+Ο„} | F_t]
3441
+ l_ce = (imp_w * F.mse_loss(selected_v, rewards.detach(),
3442
+ reduction="none")).mean()
3443
+
3444
+ # L_rank β€” Ranking margin
3445
+ best_idx = pnl_arr.argmax(dim=1)
3446
+ worst_idx = pnl_arr.argmin(dim=1)
3447
+ l_rank = F.relu(
3448
+ self.rank_margin
3449
+ - (scores.gather(1, best_idx.unsqueeze(1)).squeeze(1)
3450
+ - scores.gather(1, worst_idx.unsqueeze(1)).squeeze(1))
3451
+ ).mean()
3452
+
3453
+ # L_risk β€” Uncertainty penalty
3454
+ l_risk = torch.exp(log_var.gather(1, selected.unsqueeze(1)).squeeze(1)).mean()
3455
+
3456
+ # L_ql β€” Quantile / pinball loss [v7]: calibrates full return distribution
3457
+ sel_q = quantiles.gather(
3458
+ 1,
3459
+ selected.unsqueeze(1).unsqueeze(2).expand(-1, 1, quantiles.shape[-1])
3460
+ ).squeeze(1) # (B, n_quantiles)
3461
+
3462
+ tau = self.model.distributional.quantile_levels # (n_quantiles,)
3463
+ u = rewards.unsqueeze(1) - sel_q # (B, n_quantiles)
3464
+ l_ql = torch.max(tau * u, (tau - 1.0) * u).mean()
3465
+
3466
+ # L_moe β€” MoE load-balance regularisation [v8]
3467
+ l_moe = out.get("moe_balance_loss", torch.tensor(0.0, device=seq_t.device))
3468
+
3469
+ # L_gate β€” DendriticFFN gate-entropy regularisation [v8]
3470
+ l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
3471
+
3472
+ # L_crps β€” distributional calibration (CRPS proper scoring rule) [v8]
3473
+ ql_levels = self.model.distributional.quantile_levels
3474
+ l_crps = crps_loss(quantiles, ql_levels, rewards)
3475
+
3476
+ # L_rent β€” regime-router entropy regularisation [v8]
3477
+ r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
3478
+ regime_ent = -(r_probs * r_probs.log()).sum(-1).mean()
3479
+ l_rent = -regime_ent
3480
+
3481
+ # Total loss β€” 9-component objective [v8 complete]
3482
+ loss = (l_rl
3483
+ + self.lambda_ce * l_ce
3484
+ + self.lambda_rank * l_rank
3485
+ + self.lambda_risk * l_risk
3486
+ + self.lambda_ql * l_ql
3487
+ + self.lambda_moe * l_moe
3488
+ + self.lambda_gate * l_gate
3489
+ + self.lambda_crps * l_crps
3490
+ + self.lambda_rent * l_rent)
3491
+
3492
+ # ── Backward pass (AMP-aware) ─────────────────────────────────────────
 
 
 
 
 
 
 
3493
  self.optimizer.zero_grad()
3494
+ self.scaler.scale(loss).backward()
3495
+ # Unscale before clipping so grad norms are in the original fp32 scale
3496
+ self.scaler.unscale_(self.optimizer)
3497
  torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
3498
+ self.scaler.step(self.optimizer)
3499
+ self.scaler.update()
3500
  # scheduler.step() is called externally (once per training epoch / rank
3501
  # cycle) via step_scheduler(), NOT here per batch. Calling it here
3502
  # compressed the entire T_max=1000 cosine schedule into ~14 hours of
 
4546
  reward_strategy: str = "simple",
4547
  hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
4548
  enable_logging: bool = True,
4549
+ checkpoint_dir: str = "./Ranker3", # folder for full-state checkpoints
4550
  resume: bool = False, # start afresh by default; set True to resume
4551
  ):
4552
  self.config = config or AssetRankerConfig()
 
4858
  with self.position_mgr._lock:
4859
  rejected = self.position_mgr._open_trades.pop(trade_id, None)
4860
  if rejected:
4861
+ # ── Clean up all trade-level state so nothing leaks ──
4862
+ # _pending_episodes: episode was opened by _open_pending_episode
4863
+ # but there will never be a close event β€” discard it.
4864
+ self._pending_episodes.pop(trade_id, None)
4865
+ # _trade_tick_counts: monitor_positions tracks this; clear it.
4866
+ self._trade_tick_counts.pop(trade_id, None)
4867
+ # portfolio_risk_mgr: committed capital was registered in
4868
+ # process_axrvi_signal; release it with 0 PnL.
4869
+ self.portfolio_risk_mgr.register_close(trade_id, 0.0)
4870
  logger.error(
4871
  f"❌ [{rejected.asset}] BROKER REJECTED buy | "
4872
  f"trade_id={trade_id} | code={code} | {message}"
 
4955
  else:
4956
  logger.warning(
4957
  f"[Deriv] Buy confirmation β€” no trade_id for req_id={req_id} | "
4958
+ f"contract_id={contract_id} (late or orphaned confirmation) β€” "
4959
+ f"sending immediate SELL to avoid dangling open contract on broker side"
4960
  )
4961
+ # The contract is live on Deriv but we have no internal tracking for
4962
+ # it. Immediately sell at market so we don't accumulate un-tracked
4963
+ # open contracts that drain the account without appearing in our books.
4964
+ if self.ws_client and self.ws_client.connected:
4965
+ asyncio.get_running_loop().create_task(
4966
+ self.ws_client.send_message({"sell": contract_id, "price": 0})
4967
+ )
4968
 
4969
  def _on_poc_update(self, poc: dict, raw_msg: dict) -> None:
4970
  """
 
5320
  for asset in departed:
5321
  trade = self.position_mgr.get_open_trade_by_asset(asset)
5322
  if trade:
5323
+ # Skip if already closing β€” SELL was already sent to broker
5324
+ if trade.state == PositionState.CLOSING:
5325
+ logger.debug(
5326
+ f"[Rotation] ⏩ {asset} already CLOSING β€” skipping duplicate SELL"
5327
+ )
5328
+ continue
5329
  streamer = self.price_streamers.get(asset)
5330
  price = streamer.latest_mid if streamer else trade.entry_price
5331
  logger.info(
 
5339
  # ── Position monitoring ────────────────────────────────────────────────────────────
5340
 
5341
  async def _close_position(self, trade_id: str, exit_price: float) -> None:
5342
+ # ── Prevent re-sending SELL for a contract already in CLOSING state ─
 
 
5343
  with self.position_mgr._lock:
5344
  open_trade = self.position_mgr._open_trades.get(trade_id)
5345
+ if open_trade is None:
5346
+ return # already closed
5347
+ if open_trade.state == PositionState.CLOSING:
5348
+ # Sell already sent; don't spam the broker β€” just wait for POC
5349
+ logger.debug(
5350
+ f"[{trade_id}] ⏳ Already CLOSING β€” skipping duplicate SELL | "
5351
+ f"contract_id={open_trade.contract_id}"
5352
+ )
5353
+ return
5354
  cid = open_trade.contract_id if open_trade else None
5355
 
5356
  if cid and self.ws_client and self.ws_client.connected:
 
5432
  After any position is closed, if open_trade_count drops below 2, immediately
5433
  call rank_and_gate() to refill. This ensures the 2-trade minimum is maintained
5434
  continuously without waiting for the next scheduled _rank_loop cycle.
5435
+
5436
+ CLOSING TIMEOUT FIX (v6.1):
5437
+ If a trade remains in CLOSING state for > 10 seconds without a terminal
5438
+ event from the broker, force-close it locally to prevent stuck trades.
5439
  """
5440
  sc = self.config.shreve_config
5441
+ CLOSING_TIMEOUT_SECONDS = 10.0 # Maximum time to wait for broker terminal event
5442
+
5443
  while self.running:
5444
  try:
5445
  closed_any = False # track whether we closed a trade this tick
5446
 
5447
  for trade in self.position_mgr.get_open_trades():
5448
+ # ── CLOSING TIMEOUT HANDLER ──────────────────────────────────
5449
+ # If trade has been CLOSING for > CLOSING_TIMEOUT_SECONDS,
5450
+ # force-close it locally (broker never responded)
5451
+ if trade.state == PositionState.CLOSING:
5452
+ closing_duration = time.time() - trade.exit_time if trade.exit_time else 0
5453
+ if closing_duration > CLOSING_TIMEOUT_SECONDS:
5454
+ logger.warning(
5455
+ f"[{trade.asset}] ⚠️ CLOSING TIMEOUT | "
5456
+ f"trade_id={trade.trade_id} | "
5457
+ f"contract_id={trade.contract_id} | "
5458
+ f"stuck in CLOSING for {closing_duration:.1f}s β€” "
5459
+ f"forcing local close"
5460
+ )
5461
+ # Force close locally (broker never responded)
5462
+ # Use current price as exit price
5463
+ streamer = self.price_streamers.get(trade.asset)
5464
+ price = streamer.latest_mid if streamer else trade.entry_price
5465
+
5466
+ # Estimate profit from broker data if available
5467
+ if trade.profit is not None:
5468
+ profit = trade.profit
5469
+ else:
5470
+ # Fallback: estimate from price movement
5471
+ if price > 0 and trade.entry_price > 0:
5472
+ pct_move = (price - trade.entry_price) / trade.entry_price
5473
+ sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0
5474
+ mult = ASSET_MULTIPLIER.get(trade.asset, 50)
5475
+ stake = trade.buy_price if (trade.buy_price and trade.buy_price > 0) else 1.0
5476
+ profit = sign * pct_move * stake * mult
5477
+ else:
5478
+ profit = 0.0
5479
+
5480
+ closed_trade = self.position_mgr.close_trade_from_broker(
5481
+ trade_id=trade.trade_id,
5482
+ status="timeout",
5483
+ profit=profit,
5484
+ sell_price=price,
5485
+ exit_tick=price,
5486
+ )
5487
+
5488
+ if closed_trade:
5489
+ reward = self._reward_from_broker(closed_trade)
5490
+ self.portfolio_risk_mgr.register_close(trade.trade_id, closed_trade.realized_pnl)
5491
+ self._close_pending_episode(trade.trade_id, reward)
5492
+ self._trade_tick_counts.pop(trade.trade_id, None)
5493
+
5494
+ self.stats["trades_closed"] += 1
5495
+ self.stats["total_pnl"] += closed_trade.realized_pnl
5496
+ closed_any = True
5497
+
5498
+ logger.info(
5499
+ f"πŸ’° [{closed_trade.asset}] TRADE FORCE-CLOSED (timeout) | "
5500
+ f"reward={reward:+.6f} | profit={profit:+.4f}"
5501
+ )
5502
+ continue # Skip other checks for CLOSING trades
5503
+
5504
  streamer = self.price_streamers.get(trade.asset)
5505
  if not streamer:
5506
  continue
 
5522
  raw_log_ret = math.log(price / trade.entry_price)
5523
  sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0
5524
  # Fees for G_t must be on stake, NOT on spot price.
 
 
5525
  fees = self.trade_config.amount * self.trade_config.commission_rate
5526
  slippage = self.trade_config.slippage_bps / 10_000.0
5527
  g_t = sign * raw_log_ret - fees / price - slippage
 
6068
  next_seq_t, _, _ = self._build_input_tensors()
6069
  next_sequences = next_seq_t.squeeze(0).numpy() # (N, T, F)
6070
 
6071
+ # ── Episode validity check: s_t and s_{t+1} must differ ────────────
6072
+ # If they are identical the TD error collapses to zero and training
6073
+ # produces no gradient. This happens when a trade opens and closes
6074
+ # within the same rank cycle before any price ticks arrive.
6075
+ # In that case we still push the episode but tag it so HybridTrainer
6076
+ # can apply a lower importance weight.
6077
+ state_diff = float(np.mean(np.abs(next_sequences - sequences)))
6078
+ if state_diff < 1e-6:
6079
+ logger.warning(
6080
+ f"[_close_pending_episode] [{asset_id}] s_t β‰ˆ s_{{t+1}} "
6081
+ f"(diff={state_diff:.2e}) β€” episode may not produce useful gradient. "
6082
+ f"Trade may have closed before a rank cycle completed."
6083
+ )
6084
+ # Still push β€” the buffer needs data; the trainer will use low importance_weight
6085
+
6086
  pnl_proxy = np.zeros(n_assets, dtype=np.float32)
6087
  if asset_id in self.config.asset_symbols:
6088
  pnl_proxy[self.config.asset_symbols.index(asset_id)] = float(reward)
 
6096
  # Girsanov priority metadata [S2/GirsanovReplayBuffer]
6097
  "volatility": ep["volatility"],
6098
  "td_error": abs(reward),
6099
+ # State diversity marker β€” used for importance weighting
6100
+ "state_diff": state_diff,
6101
  })
6102
 
6103
  if asset_id in self.config.asset_symbols:
 
6191
  On load, missing local files are pulled from the repo first.
6192
  Requires HF_TOKEN env-var with write permission.
6193
 
6194
+ Fallback : Local disk ./Ranker3/
6195
  Used when HF_TOKEN is absent or upload/download fails.
6196
  All saves still succeed locally even without network access.
6197
 
 
6245
 
6246
  Usage
6247
  ─────
6248
+ mgr = RankerCheckpointManager(checkpoint_dir="./Ranker3",
6249
  save_interval_seconds=1800)
6250
  mgr.load(bridge) # once after initialize()
6251
  mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
 
6261
 
6262
  def __init__(
6263
  self,
6264
+ checkpoint_dir: str = "./Ranker3",
6265
  save_interval_seconds: float = 1800.0, # 30 minutes
6266
  ):
6267
  import pathlib
 
6315
  )
6316
  logger.info(
6317
  f"[RankerCheckpoint] ☁️ Uploaded {local_path.name} β†’ "
6318
+ f"hf://{self.HF_REPO_ID}/Ranker3/{local_path.name}"
6319
  )
6320
  return True
6321
  except Exception as exc:
 
6346
  )
6347
  logger.info(
6348
  f"[RankerCheckpoint] ⬇️ Downloaded {filename} from "
6349
+ f"hf://{self.HF_REPO_ID}/Ranker3/"
6350
  )
6351
  return True
6352
  except Exception as exc:
 
6887
  hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
6888
  enable_logging: bool = True,
6889
  shreve_config: Optional[ShreveConfig] = None,
6890
+ checkpoint_dir: str = "./Ranker3",
6891
  resume: bool = False, # start fresh by default
6892
  ) -> None:
6893
  config = AssetRankerConfig(
 
7199
  help="[S6/S8] Trade horizon Ο„ in seconds (default 60)")
7200
  parser.add_argument("--martingale-epsilon", type=float, default=0.05,
7201
  help="[S7] Gate E martingale deviation threshold (default 0.05)")
7202
+ parser.add_argument("--checkpoint-dir", default="./Ranker3",
7203
+ help="Directory for full-state checkpoints (default ./Ranker3)")
7204
  parser.add_argument("--resume", action="store_true",
7205
  help="Resume training from the latest saved checkpoint (default: start fresh)")
7206
  return parser.parse_args(filtered)