KarlQuant commited on
Commit
4965acd
·
verified ·
1 Parent(s): fd2fb28

Upload Quasar_axrvi_ranker.py

Browse files
Files changed (1) hide show
  1. Quasar_axrvi_ranker.py +6 -4
Quasar_axrvi_ranker.py CHANGED
@@ -2254,7 +2254,8 @@ class MoETemporalEncoder(nn.Module):
2254
  self.router = nn.Linear(input_dim, n_experts)
2255
 
2256
  def forward(self, x: torch.Tensor) -> torch.Tensor:
2257
- BN, T, F = x.shape
 
2258
 
2259
  # Routing from first bar
2260
  router_logits = self.router(x[:, 0, :])
@@ -2528,10 +2529,11 @@ class AXRVINet(nn.Module):
2528
  Returns:
2529
  Dictionary with all outputs
2530
  """
2531
- B, N, T, F = sequences.shape
2532
-
 
2533
  # Temporal encoding
2534
- h = self.temporal(sequences.view(B * N, T, F)).view(B, N, self.d_model)
2535
 
2536
  # Cross-asset processing
2537
  total_gate_entropy = torch.tensor(0.0, device=h.device)
 
2254
  self.router = nn.Linear(input_dim, n_experts)
2255
 
2256
  def forward(self, x: torch.Tensor) -> torch.Tensor:
2257
+
2258
+ BN, T, feat_dim = x.shape
2259
 
2260
  # Routing from first bar
2261
  router_logits = self.router(x[:, 0, :])
 
2529
  Returns:
2530
  Dictionary with all outputs
2531
  """
2532
+
2533
+ B, N, T, feat_dim = sequences.shape
2534
+
2535
  # Temporal encoding
2536
+ h = self.temporal(sequences.view(B * N, T, feat_dim)).view(B, N, self.d_model)
2537
 
2538
  # Cross-asset processing
2539
  total_gate_entropy = torch.tensor(0.0, device=h.device)