Spaces:
Running
Running
Upload Quasar_axrvi_ranker.py
Browse files- Quasar_axrvi_ranker.py +257 -104
Quasar_axrvi_ranker.py
CHANGED
|
@@ -392,8 +392,8 @@ GAMMA = 0.99
|
|
| 392 |
LAMBDA_RANK = 0.4
|
| 393 |
LAMBDA_RISK = 0.3
|
| 394 |
REPLAY_CAPACITY = 10_000
|
| 395 |
-
TRAIN_BATCH =
|
| 396 |
-
TRAIN_EVERY_N =
|
| 397 |
|
| 398 |
# Connection
|
| 399 |
WS_RECONNECT_DELAY = 5
|
|
@@ -1751,9 +1751,14 @@ class STDPAdjacencyLayer(nn.Module):
|
|
| 1751 |
regime_weights: (B, N) optional multiplier (e.g., crash probability)
|
| 1752 |
"""
|
| 1753 |
B, N, _ = asset_embeddings.shape
|
|
|
|
|
|
|
| 1754 |
|
|
|
|
|
|
|
|
|
|
| 1755 |
# Activity proxy: L2 norm normalized across batch
|
| 1756 |
-
activity =
|
| 1757 |
mu = activity.mean(dim=0, keepdim=True)
|
| 1758 |
std = activity.std(dim=0, keepdim=True) + 1e-6
|
| 1759 |
z_activity = (activity - mu) / std
|
|
@@ -1787,8 +1792,16 @@ class STDPAdjacencyLayer(nn.Module):
|
|
| 1787 |
self.step_count += 1
|
| 1788 |
|
| 1789 |
def get_adapted_bias(self, base_adj_bias: torch.Tensor, N: int) -> torch.Tensor:
|
| 1790 |
-
"""Return base bias + accumulated STDP delta."""
|
| 1791 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1792 |
|
| 1793 |
def reset_plasticity(self) -> None:
|
| 1794 |
"""Reset STDP state (call on regime shift detection)."""
|
|
@@ -2015,7 +2028,7 @@ class HyperbolicCrossAssetLayer(nn.Module):
|
|
| 2015 |
nn.Linear(d_model, d_model)
|
| 2016 |
)
|
| 2017 |
|
| 2018 |
-
# Learnable adjacency bias
|
| 2019 |
self.adj_bias = nn.Parameter(torch.zeros(num_heads, self.MAX_ASSETS, self.MAX_ASSETS))
|
| 2020 |
|
| 2021 |
# STDP plasticity
|
|
@@ -2057,6 +2070,12 @@ class HyperbolicCrossAssetLayer(nn.Module):
|
|
| 2057 |
|
| 2058 |
adapted_bias = self.stdp.get_adapted_bias(self.adj_bias, N)
|
| 2059 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2060 |
attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 2061 |
attn = attn + adapted_bias.unsqueeze(0)
|
| 2062 |
|
|
@@ -2458,8 +2477,25 @@ class AXRVINet(nn.Module):
|
|
| 2458 |
[5] KANScoringHead β significance scores
|
| 2459 |
"""
|
| 2460 |
|
| 2461 |
-
def __init__(self, num_assets: int = 5, config: AXRVIConfig = DEFAULT_CONFIG
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2462 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2463 |
self.num_assets = num_assets
|
| 2464 |
self.config = config
|
| 2465 |
self.d_model = config.d_model
|
|
@@ -2809,8 +2845,12 @@ def create_axrvi_v8(num_assets: int = 5,
|
|
| 2809 |
def _axrvi_config_from_ranker_config(rc: "AssetRankerConfig") -> AXRVIConfig:
|
| 2810 |
"""
|
| 2811 |
Derive an ``AXRVIConfig`` from an ``AssetRankerConfig`` so all dimension
|
| 2812 |
-
constants live in one place.
|
| 2813 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2814 |
|
| 2815 |
The mapping is:
|
| 2816 |
rc.feature_dim β AXRVIConfig.feature_dim (26)
|
|
@@ -3325,6 +3365,16 @@ class HybridTrainer:
|
|
| 3325 |
self.train_step = 0
|
| 3326 |
self.loss_history = deque(maxlen=200)
|
| 3327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3328 |
def train_on_batch(self, episodes: List[dict]) -> dict:
|
| 3329 |
"""
|
| 3330 |
v7 training β 5-loss objective:
|
|
@@ -3369,88 +3419,84 @@ class HybridTrainer:
|
|
| 3369 |
imp_w = torch.tensor([ep.get("importance_weight", 1.0) for ep in valid],
|
| 3370 |
dtype=torch.float32, device=device)
|
| 3371 |
|
| 3372 |
-
|
| 3373 |
-
|
| 3374 |
-
|
| 3375 |
-
|
| 3376 |
-
|
| 3377 |
-
|
| 3378 |
-
|
| 3379 |
-
|
| 3380 |
-
|
| 3381 |
-
|
| 3382 |
-
|
| 3383 |
-
|
| 3384 |
-
|
| 3385 |
-
|
| 3386 |
-
|
| 3387 |
-
|
| 3388 |
-
|
| 3389 |
-
|
| 3390 |
-
|
| 3391 |
-
|
| 3392 |
-
|
| 3393 |
-
|
| 3394 |
-
|
| 3395 |
-
|
| 3396 |
-
|
| 3397 |
-
|
| 3398 |
-
|
| 3399 |
-
|
| 3400 |
-
|
| 3401 |
-
|
| 3402 |
-
|
| 3403 |
-
|
| 3404 |
-
|
| 3405 |
-
|
| 3406 |
-
|
| 3407 |
-
|
| 3408 |
-
|
| 3409 |
-
|
| 3410 |
-
|
| 3411 |
-
|
| 3412 |
-
|
| 3413 |
-
|
| 3414 |
-
|
| 3415 |
-
|
| 3416 |
-
|
| 3417 |
-
|
| 3418 |
-
|
| 3419 |
-
|
| 3420 |
-
|
| 3421 |
-
|
| 3422 |
-
|
| 3423 |
-
|
| 3424 |
-
|
| 3425 |
-
|
| 3426 |
-
|
| 3427 |
-
|
| 3428 |
-
|
| 3429 |
-
|
| 3430 |
-
|
| 3431 |
-
|
| 3432 |
-
|
| 3433 |
-
|
| 3434 |
-
|
| 3435 |
-
|
| 3436 |
-
|
| 3437 |
-
|
| 3438 |
-
|
| 3439 |
-
|
| 3440 |
-
|
| 3441 |
-
|
| 3442 |
-
|
| 3443 |
-
+ self.lambda_risk * l_risk
|
| 3444 |
-
+ self.lambda_ql * l_ql
|
| 3445 |
-
+ self.lambda_moe * l_moe
|
| 3446 |
-
+ self.lambda_gate * l_gate
|
| 3447 |
-
+ self.lambda_crps * l_crps
|
| 3448 |
-
+ self.lambda_rent * l_rent)
|
| 3449 |
-
|
| 3450 |
self.optimizer.zero_grad()
|
| 3451 |
-
loss.backward()
|
|
|
|
|
|
|
| 3452 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 3453 |
-
self.
|
|
|
|
| 3454 |
# scheduler.step() is called externally (once per training epoch / rank
|
| 3455 |
# cycle) via step_scheduler(), NOT here per batch. Calling it here
|
| 3456 |
# compressed the entire T_max=1000 cosine schedule into ~14 hours of
|
|
@@ -4500,7 +4546,7 @@ class QuasarAXRVIBridge:
|
|
| 4500 |
reward_strategy: str = "simple",
|
| 4501 |
hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
|
| 4502 |
enable_logging: bool = True,
|
| 4503 |
-
checkpoint_dir: str = "./
|
| 4504 |
resume: bool = False, # start afresh by default; set True to resume
|
| 4505 |
):
|
| 4506 |
self.config = config or AssetRankerConfig()
|
|
@@ -4812,6 +4858,15 @@ class QuasarAXRVIBridge:
|
|
| 4812 |
with self.position_mgr._lock:
|
| 4813 |
rejected = self.position_mgr._open_trades.pop(trade_id, None)
|
| 4814 |
if rejected:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4815 |
logger.error(
|
| 4816 |
f"β [{rejected.asset}] BROKER REJECTED buy | "
|
| 4817 |
f"trade_id={trade_id} | code={code} | {message}"
|
|
@@ -4900,8 +4955,16 @@ class QuasarAXRVIBridge:
|
|
| 4900 |
else:
|
| 4901 |
logger.warning(
|
| 4902 |
f"[Deriv] Buy confirmation β no trade_id for req_id={req_id} | "
|
| 4903 |
-
f"contract_id={contract_id} (late or orphaned confirmation)"
|
|
|
|
| 4904 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4905 |
|
| 4906 |
def _on_poc_update(self, poc: dict, raw_msg: dict) -> None:
|
| 4907 |
"""
|
|
@@ -5257,6 +5320,12 @@ class QuasarAXRVIBridge:
|
|
| 5257 |
for asset in departed:
|
| 5258 |
trade = self.position_mgr.get_open_trade_by_asset(asset)
|
| 5259 |
if trade:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5260 |
streamer = self.price_streamers.get(asset)
|
| 5261 |
price = streamer.latest_mid if streamer else trade.entry_price
|
| 5262 |
logger.info(
|
|
@@ -5270,11 +5339,18 @@ class QuasarAXRVIBridge:
|
|
| 5270 |
# ββ Position monitoring ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 5271 |
|
| 5272 |
async def _close_position(self, trade_id: str, exit_price: float) -> None:
|
| 5273 |
-
# ββ
|
| 5274 |
-
# Retrieve contract_id BEFORE close_trade() removes the trade from
|
| 5275 |
-
# _open_trades. price=0 means "sell at best available market price".
|
| 5276 |
with self.position_mgr._lock:
|
| 5277 |
open_trade = self.position_mgr._open_trades.get(trade_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5278 |
cid = open_trade.contract_id if open_trade else None
|
| 5279 |
|
| 5280 |
if cid and self.ws_client and self.ws_client.connected:
|
|
@@ -5356,13 +5432,75 @@ class QuasarAXRVIBridge:
|
|
| 5356 |
After any position is closed, if open_trade_count drops below 2, immediately
|
| 5357 |
call rank_and_gate() to refill. This ensures the 2-trade minimum is maintained
|
| 5358 |
continuously without waiting for the next scheduled _rank_loop cycle.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5359 |
"""
|
| 5360 |
sc = self.config.shreve_config
|
|
|
|
|
|
|
| 5361 |
while self.running:
|
| 5362 |
try:
|
| 5363 |
closed_any = False # track whether we closed a trade this tick
|
| 5364 |
|
| 5365 |
for trade in self.position_mgr.get_open_trades():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5366 |
streamer = self.price_streamers.get(trade.asset)
|
| 5367 |
if not streamer:
|
| 5368 |
continue
|
|
@@ -5384,8 +5522,6 @@ class QuasarAXRVIBridge:
|
|
| 5384 |
raw_log_ret = math.log(price / trade.entry_price)
|
| 5385 |
sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0
|
| 5386 |
# Fees for G_t must be on stake, NOT on spot price.
|
| 5387 |
-
# price * commission_rate would be e.g. 316738 * 0.001 = $316 (wrong)
|
| 5388 |
-
# stake * commission_rate is e.g. 1.0 * 0.001 = $0.001 (correct)
|
| 5389 |
fees = self.trade_config.amount * self.trade_config.commission_rate
|
| 5390 |
slippage = self.trade_config.slippage_bps / 10_000.0
|
| 5391 |
g_t = sign * raw_log_ret - fees / price - slippage
|
|
@@ -5932,6 +6068,21 @@ class QuasarAXRVIBridge:
|
|
| 5932 |
next_seq_t, _, _ = self._build_input_tensors()
|
| 5933 |
next_sequences = next_seq_t.squeeze(0).numpy() # (N, T, F)
|
| 5934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5935 |
pnl_proxy = np.zeros(n_assets, dtype=np.float32)
|
| 5936 |
if asset_id in self.config.asset_symbols:
|
| 5937 |
pnl_proxy[self.config.asset_symbols.index(asset_id)] = float(reward)
|
|
@@ -5945,6 +6096,8 @@ class QuasarAXRVIBridge:
|
|
| 5945 |
# Girsanov priority metadata [S2/GirsanovReplayBuffer]
|
| 5946 |
"volatility": ep["volatility"],
|
| 5947 |
"td_error": abs(reward),
|
|
|
|
|
|
|
| 5948 |
})
|
| 5949 |
|
| 5950 |
if asset_id in self.config.asset_symbols:
|
|
@@ -6038,7 +6191,7 @@ class RankerCheckpointManager:
|
|
| 6038 |
On load, missing local files are pulled from the repo first.
|
| 6039 |
Requires HF_TOKEN env-var with write permission.
|
| 6040 |
|
| 6041 |
-
Fallback : Local disk ./
|
| 6042 |
Used when HF_TOKEN is absent or upload/download fails.
|
| 6043 |
All saves still succeed locally even without network access.
|
| 6044 |
|
|
@@ -6092,7 +6245,7 @@ class RankerCheckpointManager:
|
|
| 6092 |
|
| 6093 |
Usage
|
| 6094 |
βββββ
|
| 6095 |
-
mgr = RankerCheckpointManager(checkpoint_dir="./
|
| 6096 |
save_interval_seconds=1800)
|
| 6097 |
mgr.load(bridge) # once after initialize()
|
| 6098 |
mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
|
|
@@ -6108,7 +6261,7 @@ class RankerCheckpointManager:
|
|
| 6108 |
|
| 6109 |
def __init__(
|
| 6110 |
self,
|
| 6111 |
-
checkpoint_dir: str = "./
|
| 6112 |
save_interval_seconds: float = 1800.0, # 30 minutes
|
| 6113 |
):
|
| 6114 |
import pathlib
|
|
@@ -6162,7 +6315,7 @@ class RankerCheckpointManager:
|
|
| 6162 |
)
|
| 6163 |
logger.info(
|
| 6164 |
f"[RankerCheckpoint] βοΈ Uploaded {local_path.name} β "
|
| 6165 |
-
f"hf://{self.HF_REPO_ID}/
|
| 6166 |
)
|
| 6167 |
return True
|
| 6168 |
except Exception as exc:
|
|
@@ -6193,7 +6346,7 @@ class RankerCheckpointManager:
|
|
| 6193 |
)
|
| 6194 |
logger.info(
|
| 6195 |
f"[RankerCheckpoint] β¬οΈ Downloaded {filename} from "
|
| 6196 |
-
f"hf://{self.HF_REPO_ID}/
|
| 6197 |
)
|
| 6198 |
return True
|
| 6199 |
except Exception as exc:
|
|
@@ -6734,7 +6887,7 @@ async def run_live_trading_system(
|
|
| 6734 |
hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
|
| 6735 |
enable_logging: bool = True,
|
| 6736 |
shreve_config: Optional[ShreveConfig] = None,
|
| 6737 |
-
checkpoint_dir: str = "./
|
| 6738 |
resume: bool = False, # start fresh by default
|
| 6739 |
) -> None:
|
| 6740 |
config = AssetRankerConfig(
|
|
@@ -7046,8 +7199,8 @@ def _parse_args():
|
|
| 7046 |
help="[S6/S8] Trade horizon Ο in seconds (default 60)")
|
| 7047 |
parser.add_argument("--martingale-epsilon", type=float, default=0.05,
|
| 7048 |
help="[S7] Gate E martingale deviation threshold (default 0.05)")
|
| 7049 |
-
parser.add_argument("--checkpoint-dir", default="./
|
| 7050 |
-
help="Directory for full-state checkpoints (default ./
|
| 7051 |
parser.add_argument("--resume", action="store_true",
|
| 7052 |
help="Resume training from the latest saved checkpoint (default: start fresh)")
|
| 7053 |
return parser.parse_args(filtered)
|
|
|
|
| 392 |
LAMBDA_RANK = 0.4
|
| 393 |
LAMBDA_RISK = 0.3
|
| 394 |
REPLAY_CAPACITY = 10_000
|
| 395 |
+
TRAIN_BATCH = 8 # Lowered from 16 β fills after ~8 closed trades (~2-4 min)
|
| 396 |
+
TRAIN_EVERY_N = 5 # Lowered from 10 β checks buffer every ~25s
|
| 397 |
|
| 398 |
# Connection
|
| 399 |
WS_RECONNECT_DELAY = 5
|
|
|
|
| 1751 |
regime_weights: (B, N) optional multiplier (e.g., crash probability)
|
| 1752 |
"""
|
| 1753 |
B, N, _ = asset_embeddings.shape
|
| 1754 |
+
# Clamp N to max_assets β extra assets are simply not tracked
|
| 1755 |
+
N = min(N, self.max_assets)
|
| 1756 |
|
| 1757 |
+
# Use only the first N (clamped) assets
|
| 1758 |
+
asset_emb_n = asset_embeddings[:, :N, :]
|
| 1759 |
+
|
| 1760 |
# Activity proxy: L2 norm normalized across batch
|
| 1761 |
+
activity = asset_emb_n.norm(dim=-1)
|
| 1762 |
mu = activity.mean(dim=0, keepdim=True)
|
| 1763 |
std = activity.std(dim=0, keepdim=True) + 1e-6
|
| 1764 |
z_activity = (activity - mu) / std
|
|
|
|
| 1792 |
self.step_count += 1
|
| 1793 |
|
| 1794 |
def get_adapted_bias(self, base_adj_bias: torch.Tensor, N: int) -> torch.Tensor:
|
| 1795 |
+
"""Return base bias + accumulated STDP delta, clamped to registered max_assets."""
|
| 1796 |
+
N_clamped = min(N, self.max_assets)
|
| 1797 |
+
if N > self.max_assets:
|
| 1798 |
+
# Pad with zeros for the extra assets beyond max_assets
|
| 1799 |
+
pad = N - self.max_assets
|
| 1800 |
+
padded_delta = F.pad(self.stdp_delta[:, :N_clamped, :N_clamped],
|
| 1801 |
+
(0, pad, 0, pad), value=0.0)
|
| 1802 |
+
padded_base = base_adj_bias[:, :N, :N]
|
| 1803 |
+
return padded_base + padded_delta
|
| 1804 |
+
return base_adj_bias[:, :N_clamped, :N_clamped] + self.stdp_delta[:, :N_clamped, :N_clamped]
|
| 1805 |
|
| 1806 |
def reset_plasticity(self) -> None:
|
| 1807 |
"""Reset STDP state (call on regime shift detection)."""
|
|
|
|
| 2028 |
nn.Linear(d_model, d_model)
|
| 2029 |
)
|
| 2030 |
|
| 2031 |
+
# Learnable adjacency bias β sized to MAX_ASSETS; dynamically padded in forward()
|
| 2032 |
self.adj_bias = nn.Parameter(torch.zeros(num_heads, self.MAX_ASSETS, self.MAX_ASSETS))
|
| 2033 |
|
| 2034 |
# STDP plasticity
|
|
|
|
| 2070 |
|
| 2071 |
adapted_bias = self.stdp.get_adapted_bias(self.adj_bias, N)
|
| 2072 |
|
| 2073 |
+
# If N > MAX_ASSETS the adapted_bias may be smaller; pad with zeros
|
| 2074 |
+
if adapted_bias.shape[-1] < N or adapted_bias.shape[-2] < N:
|
| 2075 |
+
pad_r = N - adapted_bias.shape[-1]
|
| 2076 |
+
pad_c = N - adapted_bias.shape[-2]
|
| 2077 |
+
adapted_bias = F.pad(adapted_bias, (0, pad_r, 0, pad_c), value=0.0)
|
| 2078 |
+
|
| 2079 |
attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
|
| 2080 |
attn = attn + adapted_bias.unsqueeze(0)
|
| 2081 |
|
|
|
|
| 2477 |
[5] KANScoringHead β significance scores
|
| 2478 |
"""
|
| 2479 |
|
| 2480 |
+
def __init__(self, num_assets: int = 5, config: AXRVIConfig = DEFAULT_CONFIG,
|
| 2481 |
+
feature_dim: int = None, seq_len: int = None):
|
| 2482 |
+
"""
|
| 2483 |
+
Args:
|
| 2484 |
+
num_assets : number of parallel asset streams (N dimension)
|
| 2485 |
+
config : AXRVIConfig β all hyperparameters
|
| 2486 |
+
feature_dim : override config.feature_dim (backward compat)
|
| 2487 |
+
seq_len : override config.seq_len (backward compat)
|
| 2488 |
+
"""
|
| 2489 |
super().__init__()
|
| 2490 |
+
# Apply any direct overrides so callers that pass feature_dim/seq_len
|
| 2491 |
+
# directly (e.g. test_components) still get the right architecture.
|
| 2492 |
+
if feature_dim is not None or seq_len is not None:
|
| 2493 |
+
import copy
|
| 2494 |
+
config = copy.copy(config)
|
| 2495 |
+
if feature_dim is not None:
|
| 2496 |
+
config.feature_dim = feature_dim
|
| 2497 |
+
if seq_len is not None:
|
| 2498 |
+
config.seq_len = seq_len
|
| 2499 |
self.num_assets = num_assets
|
| 2500 |
self.config = config
|
| 2501 |
self.d_model = config.d_model
|
|
|
|
| 2845 |
def _axrvi_config_from_ranker_config(rc: "AssetRankerConfig") -> AXRVIConfig:
|
| 2846 |
"""
|
| 2847 |
Derive an ``AXRVIConfig`` from an ``AssetRankerConfig`` so all dimension
|
| 2848 |
+
constants live in one place.
|
| 2849 |
+
|
| 2850 |
+
Called by ``QuasarAXRVIBridge.initialize()`` to construct the neural-net
|
| 2851 |
+
config directly from the top-level ranker config, avoiding duplicated
|
| 2852 |
+
constants. Any stand-alone caller that has an ``AssetRankerConfig`` may
|
| 2853 |
+
use this function as well.
|
| 2854 |
|
| 2855 |
The mapping is:
|
| 2856 |
rc.feature_dim β AXRVIConfig.feature_dim (26)
|
|
|
|
| 3365 |
self.train_step = 0
|
| 3366 |
self.loss_history = deque(maxlen=200)
|
| 3367 |
|
| 3368 |
+
# ββ AMP (Automatic Mixed Precision) ββββββββββββββββββββββββββββββββββ
|
| 3369 |
+
# Reads use_amp from the model's AXRVIConfig so the flag is the single
|
| 3370 |
+
# source of truth. Falls back to False on CPU (AMP is CUDA-only).
|
| 3371 |
+
_cfg = getattr(model, "config", None)
|
| 3372 |
+
_use_amp = bool(getattr(_cfg, "use_amp", False))
|
| 3373 |
+
_device = next(model.parameters()).device
|
| 3374 |
+
# AMP is only meaningful on CUDA; silently disable on CPU
|
| 3375 |
+
self.use_amp = _use_amp and _device.type == "cuda"
|
| 3376 |
+
self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
|
| 3377 |
+
|
| 3378 |
def train_on_batch(self, episodes: List[dict]) -> dict:
|
| 3379 |
"""
|
| 3380 |
v7 training β 5-loss objective:
|
|
|
|
| 3419 |
imp_w = torch.tensor([ep.get("importance_weight", 1.0) for ep in valid],
|
| 3420 |
dtype=torch.float32, device=device)
|
| 3421 |
|
| 3422 |
+
# ββ Forward pass (AMP-aware) ββββββββββββββββββββββββββββββββββββββββββ
|
| 3423 |
+
with torch.cuda.amp.autocast(enabled=self.use_amp):
|
| 3424 |
+
out = self.model(seq_t)
|
| 3425 |
+
scores = out["significance_logits"]
|
| 3426 |
+
value = out["value"].squeeze(-1) # (B, N) β median QΜ_{0.5}
|
| 3427 |
+
log_var = out["log_var"].squeeze(-1) # (B, N)
|
| 3428 |
+
quantiles = out["quantiles"] # (B, N, n_quantiles)
|
| 3429 |
+
|
| 3430 |
+
with torch.no_grad():
|
| 3431 |
+
next_out = self.model(next_seq_t)
|
| 3432 |
+
best_next_v = next_out["value"].squeeze(-1).max(dim=1).values
|
| 3433 |
+
|
| 3434 |
+
# L_rl β TD error
|
| 3435 |
+
selected_v = value.gather(1, selected.unsqueeze(1)).squeeze(1)
|
| 3436 |
+
td_target = rewards + self.gamma * best_next_v
|
| 3437 |
+
l_rl = (imp_w * F.mse_loss(selected_v, td_target.detach(),
|
| 3438 |
+
reduction="none")).mean()
|
| 3439 |
+
|
| 3440 |
+
# L_ce β [S1] Value-consistency: selected_v β E[R_{t+Ο} | F_t]
|
| 3441 |
+
l_ce = (imp_w * F.mse_loss(selected_v, rewards.detach(),
|
| 3442 |
+
reduction="none")).mean()
|
| 3443 |
+
|
| 3444 |
+
# L_rank β Ranking margin
|
| 3445 |
+
best_idx = pnl_arr.argmax(dim=1)
|
| 3446 |
+
worst_idx = pnl_arr.argmin(dim=1)
|
| 3447 |
+
l_rank = F.relu(
|
| 3448 |
+
self.rank_margin
|
| 3449 |
+
- (scores.gather(1, best_idx.unsqueeze(1)).squeeze(1)
|
| 3450 |
+
- scores.gather(1, worst_idx.unsqueeze(1)).squeeze(1))
|
| 3451 |
+
).mean()
|
| 3452 |
+
|
| 3453 |
+
# L_risk β Uncertainty penalty
|
| 3454 |
+
l_risk = torch.exp(log_var.gather(1, selected.unsqueeze(1)).squeeze(1)).mean()
|
| 3455 |
+
|
| 3456 |
+
# L_ql β Quantile / pinball loss [v7]: calibrates full return distribution
|
| 3457 |
+
sel_q = quantiles.gather(
|
| 3458 |
+
1,
|
| 3459 |
+
selected.unsqueeze(1).unsqueeze(2).expand(-1, 1, quantiles.shape[-1])
|
| 3460 |
+
).squeeze(1) # (B, n_quantiles)
|
| 3461 |
+
|
| 3462 |
+
tau = self.model.distributional.quantile_levels # (n_quantiles,)
|
| 3463 |
+
u = rewards.unsqueeze(1) - sel_q # (B, n_quantiles)
|
| 3464 |
+
l_ql = torch.max(tau * u, (tau - 1.0) * u).mean()
|
| 3465 |
+
|
| 3466 |
+
# L_moe β MoE load-balance regularisation [v8]
|
| 3467 |
+
l_moe = out.get("moe_balance_loss", torch.tensor(0.0, device=seq_t.device))
|
| 3468 |
+
|
| 3469 |
+
# L_gate β DendriticFFN gate-entropy regularisation [v8]
|
| 3470 |
+
l_gate = out.get("gate_entropy_loss", torch.tensor(0.0, device=seq_t.device))
|
| 3471 |
+
|
| 3472 |
+
# L_crps β distributional calibration (CRPS proper scoring rule) [v8]
|
| 3473 |
+
ql_levels = self.model.distributional.quantile_levels
|
| 3474 |
+
l_crps = crps_loss(quantiles, ql_levels, rewards)
|
| 3475 |
+
|
| 3476 |
+
# L_rent β regime-router entropy regularisation [v8]
|
| 3477 |
+
r_probs = out["regime_probs"] + 1e-8 # (B, N, n_regimes)
|
| 3478 |
+
regime_ent = -(r_probs * r_probs.log()).sum(-1).mean()
|
| 3479 |
+
l_rent = -regime_ent
|
| 3480 |
+
|
| 3481 |
+
# Total loss β 9-component objective [v8 complete]
|
| 3482 |
+
loss = (l_rl
|
| 3483 |
+
+ self.lambda_ce * l_ce
|
| 3484 |
+
+ self.lambda_rank * l_rank
|
| 3485 |
+
+ self.lambda_risk * l_risk
|
| 3486 |
+
+ self.lambda_ql * l_ql
|
| 3487 |
+
+ self.lambda_moe * l_moe
|
| 3488 |
+
+ self.lambda_gate * l_gate
|
| 3489 |
+
+ self.lambda_crps * l_crps
|
| 3490 |
+
+ self.lambda_rent * l_rent)
|
| 3491 |
+
|
| 3492 |
+
# ββ Backward pass (AMP-aware) βββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3493 |
self.optimizer.zero_grad()
|
| 3494 |
+
self.scaler.scale(loss).backward()
|
| 3495 |
+
# Unscale before clipping so grad norms are in the original fp32 scale
|
| 3496 |
+
self.scaler.unscale_(self.optimizer)
|
| 3497 |
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 3498 |
+
self.scaler.step(self.optimizer)
|
| 3499 |
+
self.scaler.update()
|
| 3500 |
# scheduler.step() is called externally (once per training epoch / rank
|
| 3501 |
# cycle) via step_scheduler(), NOT here per batch. Calling it here
|
| 3502 |
# compressed the entire T_max=1000 cosine schedule into ~14 hours of
|
|
|
|
| 4546 |
reward_strategy: str = "simple",
|
| 4547 |
hub_ws_url: str = os.environ.get("QUASAR_HUB_URL", "ws://localhost:7860/ws/subscribe"),
|
| 4548 |
enable_logging: bool = True,
|
| 4549 |
+
checkpoint_dir: str = "./Ranker3", # folder for full-state checkpoints
|
| 4550 |
resume: bool = False, # start afresh by default; set True to resume
|
| 4551 |
):
|
| 4552 |
self.config = config or AssetRankerConfig()
|
|
|
|
| 4858 |
with self.position_mgr._lock:
|
| 4859 |
rejected = self.position_mgr._open_trades.pop(trade_id, None)
|
| 4860 |
if rejected:
|
| 4861 |
+
# ββ Clean up all trade-level state so nothing leaks ββ
|
| 4862 |
+
# _pending_episodes: episode was opened by _open_pending_episode
|
| 4863 |
+
# but there will never be a close event β discard it.
|
| 4864 |
+
self._pending_episodes.pop(trade_id, None)
|
| 4865 |
+
# _trade_tick_counts: monitor_positions tracks this; clear it.
|
| 4866 |
+
self._trade_tick_counts.pop(trade_id, None)
|
| 4867 |
+
# portfolio_risk_mgr: committed capital was registered in
|
| 4868 |
+
# process_axrvi_signal; release it with 0 PnL.
|
| 4869 |
+
self.portfolio_risk_mgr.register_close(trade_id, 0.0)
|
| 4870 |
logger.error(
|
| 4871 |
f"β [{rejected.asset}] BROKER REJECTED buy | "
|
| 4872 |
f"trade_id={trade_id} | code={code} | {message}"
|
|
|
|
| 4955 |
else:
|
| 4956 |
logger.warning(
|
| 4957 |
f"[Deriv] Buy confirmation β no trade_id for req_id={req_id} | "
|
| 4958 |
+
f"contract_id={contract_id} (late or orphaned confirmation) β "
|
| 4959 |
+
f"sending immediate SELL to avoid dangling open contract on broker side"
|
| 4960 |
)
|
| 4961 |
+
# The contract is live on Deriv but we have no internal tracking for
|
| 4962 |
+
# it. Immediately sell at market so we don't accumulate un-tracked
|
| 4963 |
+
# open contracts that drain the account without appearing in our books.
|
| 4964 |
+
if self.ws_client and self.ws_client.connected:
|
| 4965 |
+
asyncio.get_running_loop().create_task(
|
| 4966 |
+
self.ws_client.send_message({"sell": contract_id, "price": 0})
|
| 4967 |
+
)
|
| 4968 |
|
| 4969 |
def _on_poc_update(self, poc: dict, raw_msg: dict) -> None:
|
| 4970 |
"""
|
|
|
|
| 5320 |
for asset in departed:
|
| 5321 |
trade = self.position_mgr.get_open_trade_by_asset(asset)
|
| 5322 |
if trade:
|
| 5323 |
+
# Skip if already closing β SELL was already sent to broker
|
| 5324 |
+
if trade.state == PositionState.CLOSING:
|
| 5325 |
+
logger.debug(
|
| 5326 |
+
f"[Rotation] β© {asset} already CLOSING β skipping duplicate SELL"
|
| 5327 |
+
)
|
| 5328 |
+
continue
|
| 5329 |
streamer = self.price_streamers.get(asset)
|
| 5330 |
price = streamer.latest_mid if streamer else trade.entry_price
|
| 5331 |
logger.info(
|
|
|
|
| 5339 |
# ββ Position monitoring ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 5340 |
|
| 5341 |
async def _close_position(self, trade_id: str, exit_price: float) -> None:
|
| 5342 |
+
# ββ Prevent re-sending SELL for a contract already in CLOSING state β
|
|
|
|
|
|
|
| 5343 |
with self.position_mgr._lock:
|
| 5344 |
open_trade = self.position_mgr._open_trades.get(trade_id)
|
| 5345 |
+
if open_trade is None:
|
| 5346 |
+
return # already closed
|
| 5347 |
+
if open_trade.state == PositionState.CLOSING:
|
| 5348 |
+
# Sell already sent; don't spam the broker β just wait for POC
|
| 5349 |
+
logger.debug(
|
| 5350 |
+
f"[{trade_id}] β³ Already CLOSING β skipping duplicate SELL | "
|
| 5351 |
+
f"contract_id={open_trade.contract_id}"
|
| 5352 |
+
)
|
| 5353 |
+
return
|
| 5354 |
cid = open_trade.contract_id if open_trade else None
|
| 5355 |
|
| 5356 |
if cid and self.ws_client and self.ws_client.connected:
|
|
|
|
| 5432 |
After any position is closed, if open_trade_count drops below 2, immediately
|
| 5433 |
call rank_and_gate() to refill. This ensures the 2-trade minimum is maintained
|
| 5434 |
continuously without waiting for the next scheduled _rank_loop cycle.
|
| 5435 |
+
|
| 5436 |
+
CLOSING TIMEOUT FIX (v6.1):
|
| 5437 |
+
If a trade remains in CLOSING state for > 10 seconds without a terminal
|
| 5438 |
+
event from the broker, force-close it locally to prevent stuck trades.
|
| 5439 |
"""
|
| 5440 |
sc = self.config.shreve_config
|
| 5441 |
+
CLOSING_TIMEOUT_SECONDS = 10.0 # Maximum time to wait for broker terminal event
|
| 5442 |
+
|
| 5443 |
while self.running:
|
| 5444 |
try:
|
| 5445 |
closed_any = False # track whether we closed a trade this tick
|
| 5446 |
|
| 5447 |
for trade in self.position_mgr.get_open_trades():
|
| 5448 |
+
# ββ CLOSING TIMEOUT HANDLER ββββββββββββββββββββββββββββββββββ
|
| 5449 |
+
# If trade has been CLOSING for > CLOSING_TIMEOUT_SECONDS,
|
| 5450 |
+
# force-close it locally (broker never responded)
|
| 5451 |
+
if trade.state == PositionState.CLOSING:
|
| 5452 |
+
closing_duration = time.time() - trade.exit_time if trade.exit_time else 0
|
| 5453 |
+
if closing_duration > CLOSING_TIMEOUT_SECONDS:
|
| 5454 |
+
logger.warning(
|
| 5455 |
+
f"[{trade.asset}] β οΈ CLOSING TIMEOUT | "
|
| 5456 |
+
f"trade_id={trade.trade_id} | "
|
| 5457 |
+
f"contract_id={trade.contract_id} | "
|
| 5458 |
+
f"stuck in CLOSING for {closing_duration:.1f}s β "
|
| 5459 |
+
f"forcing local close"
|
| 5460 |
+
)
|
| 5461 |
+
# Force close locally (broker never responded)
|
| 5462 |
+
# Use current price as exit price
|
| 5463 |
+
streamer = self.price_streamers.get(trade.asset)
|
| 5464 |
+
price = streamer.latest_mid if streamer else trade.entry_price
|
| 5465 |
+
|
| 5466 |
+
# Estimate profit from broker data if available
|
| 5467 |
+
if trade.profit is not None:
|
| 5468 |
+
profit = trade.profit
|
| 5469 |
+
else:
|
| 5470 |
+
# Fallback: estimate from price movement
|
| 5471 |
+
if price > 0 and trade.entry_price > 0:
|
| 5472 |
+
pct_move = (price - trade.entry_price) / trade.entry_price
|
| 5473 |
+
sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0
|
| 5474 |
+
mult = ASSET_MULTIPLIER.get(trade.asset, 50)
|
| 5475 |
+
stake = trade.buy_price if (trade.buy_price and trade.buy_price > 0) else 1.0
|
| 5476 |
+
profit = sign * pct_move * stake * mult
|
| 5477 |
+
else:
|
| 5478 |
+
profit = 0.0
|
| 5479 |
+
|
| 5480 |
+
closed_trade = self.position_mgr.close_trade_from_broker(
|
| 5481 |
+
trade_id=trade.trade_id,
|
| 5482 |
+
status="timeout",
|
| 5483 |
+
profit=profit,
|
| 5484 |
+
sell_price=price,
|
| 5485 |
+
exit_tick=price,
|
| 5486 |
+
)
|
| 5487 |
+
|
| 5488 |
+
if closed_trade:
|
| 5489 |
+
reward = self._reward_from_broker(closed_trade)
|
| 5490 |
+
self.portfolio_risk_mgr.register_close(trade.trade_id, closed_trade.realized_pnl)
|
| 5491 |
+
self._close_pending_episode(trade.trade_id, reward)
|
| 5492 |
+
self._trade_tick_counts.pop(trade.trade_id, None)
|
| 5493 |
+
|
| 5494 |
+
self.stats["trades_closed"] += 1
|
| 5495 |
+
self.stats["total_pnl"] += closed_trade.realized_pnl
|
| 5496 |
+
closed_any = True
|
| 5497 |
+
|
| 5498 |
+
logger.info(
|
| 5499 |
+
f"π° [{closed_trade.asset}] TRADE FORCE-CLOSED (timeout) | "
|
| 5500 |
+
f"reward={reward:+.6f} | profit={profit:+.4f}"
|
| 5501 |
+
)
|
| 5502 |
+
continue # Skip other checks for CLOSING trades
|
| 5503 |
+
|
| 5504 |
streamer = self.price_streamers.get(trade.asset)
|
| 5505 |
if not streamer:
|
| 5506 |
continue
|
|
|
|
| 5522 |
raw_log_ret = math.log(price / trade.entry_price)
|
| 5523 |
sign = 1.0 if trade.direction == TradeDirection.LONG else -1.0
|
| 5524 |
# Fees for G_t must be on stake, NOT on spot price.
|
|
|
|
|
|
|
| 5525 |
fees = self.trade_config.amount * self.trade_config.commission_rate
|
| 5526 |
slippage = self.trade_config.slippage_bps / 10_000.0
|
| 5527 |
g_t = sign * raw_log_ret - fees / price - slippage
|
|
|
|
| 6068 |
next_seq_t, _, _ = self._build_input_tensors()
|
| 6069 |
next_sequences = next_seq_t.squeeze(0).numpy() # (N, T, F)
|
| 6070 |
|
| 6071 |
+
# ββ Episode validity check: s_t and s_{t+1} must differ ββββββββββββ
|
| 6072 |
+
# If they are identical the TD error collapses to zero and training
|
| 6073 |
+
# produces no gradient. This happens when a trade opens and closes
|
| 6074 |
+
# within the same rank cycle before any price ticks arrive.
|
| 6075 |
+
# In that case we still push the episode but tag it so HybridTrainer
|
| 6076 |
+
# can apply a lower importance weight.
|
| 6077 |
+
state_diff = float(np.mean(np.abs(next_sequences - sequences)))
|
| 6078 |
+
if state_diff < 1e-6:
|
| 6079 |
+
logger.warning(
|
| 6080 |
+
f"[_close_pending_episode] [{asset_id}] s_t β s_{{t+1}} "
|
| 6081 |
+
f"(diff={state_diff:.2e}) β episode may not produce useful gradient. "
|
| 6082 |
+
f"Trade may have closed before a rank cycle completed."
|
| 6083 |
+
)
|
| 6084 |
+
# Still push β the buffer needs data; the trainer will use low importance_weight
|
| 6085 |
+
|
| 6086 |
pnl_proxy = np.zeros(n_assets, dtype=np.float32)
|
| 6087 |
if asset_id in self.config.asset_symbols:
|
| 6088 |
pnl_proxy[self.config.asset_symbols.index(asset_id)] = float(reward)
|
|
|
|
| 6096 |
# Girsanov priority metadata [S2/GirsanovReplayBuffer]
|
| 6097 |
"volatility": ep["volatility"],
|
| 6098 |
"td_error": abs(reward),
|
| 6099 |
+
# State diversity marker β used for importance weighting
|
| 6100 |
+
"state_diff": state_diff,
|
| 6101 |
})
|
| 6102 |
|
| 6103 |
if asset_id in self.config.asset_symbols:
|
|
|
|
| 6191 |
On load, missing local files are pulled from the repo first.
|
| 6192 |
Requires HF_TOKEN env-var with write permission.
|
| 6193 |
|
| 6194 |
+
Fallback : Local disk ./Ranker3/
|
| 6195 |
Used when HF_TOKEN is absent or upload/download fails.
|
| 6196 |
All saves still succeed locally even without network access.
|
| 6197 |
|
|
|
|
| 6245 |
|
| 6246 |
Usage
|
| 6247 |
βββββ
|
| 6248 |
+
mgr = RankerCheckpointManager(checkpoint_dir="./Ranker3",
|
| 6249 |
save_interval_seconds=1800)
|
| 6250 |
mgr.load(bridge) # once after initialize()
|
| 6251 |
mgr.maybe_save(bridge) # call frequently; respects save_interval_seconds
|
|
|
|
| 6261 |
|
| 6262 |
def __init__(
|
| 6263 |
self,
|
| 6264 |
+
checkpoint_dir: str = "./Ranker3",
|
| 6265 |
save_interval_seconds: float = 1800.0, # 30 minutes
|
| 6266 |
):
|
| 6267 |
import pathlib
|
|
|
|
| 6315 |
)
|
| 6316 |
logger.info(
|
| 6317 |
f"[RankerCheckpoint] βοΈ Uploaded {local_path.name} β "
|
| 6318 |
+
f"hf://{self.HF_REPO_ID}/Ranker3/{local_path.name}"
|
| 6319 |
)
|
| 6320 |
return True
|
| 6321 |
except Exception as exc:
|
|
|
|
| 6346 |
)
|
| 6347 |
logger.info(
|
| 6348 |
f"[RankerCheckpoint] β¬οΈ Downloaded {filename} from "
|
| 6349 |
+
f"hf://{self.HF_REPO_ID}/Ranker3/"
|
| 6350 |
)
|
| 6351 |
return True
|
| 6352 |
except Exception as exc:
|
|
|
|
| 6887 |
hub_ws_url: str = "ws://localhost:7860/ws/subscribe",
|
| 6888 |
enable_logging: bool = True,
|
| 6889 |
shreve_config: Optional[ShreveConfig] = None,
|
| 6890 |
+
checkpoint_dir: str = "./Ranker3",
|
| 6891 |
resume: bool = False, # start fresh by default
|
| 6892 |
) -> None:
|
| 6893 |
config = AssetRankerConfig(
|
|
|
|
| 7199 |
help="[S6/S8] Trade horizon Ο in seconds (default 60)")
|
| 7200 |
parser.add_argument("--martingale-epsilon", type=float, default=0.05,
|
| 7201 |
help="[S7] Gate E martingale deviation threshold (default 0.05)")
|
| 7202 |
+
parser.add_argument("--checkpoint-dir", default="./Ranker3",
|
| 7203 |
+
help="Directory for full-state checkpoints (default ./Ranker3)")
|
| 7204 |
parser.add_argument("--resume", action="store_true",
|
| 7205 |
help="Resume training from the latest saved checkpoint (default: start fresh)")
|
| 7206 |
return parser.parse_args(filtered)
|