Spaces:
Running
Running
Upload Quasar_axrvi_ranker.py
Browse files- Quasar_axrvi_ranker.py +6 -4
Quasar_axrvi_ranker.py
CHANGED
|
@@ -2254,7 +2254,8 @@ class MoETemporalEncoder(nn.Module):
|
|
| 2254 |
self.router = nn.Linear(input_dim, n_experts)
|
| 2255 |
|
| 2256 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 2257 |
-
|
|
|
|
| 2258 |
|
| 2259 |
# Routing from first bar
|
| 2260 |
router_logits = self.router(x[:, 0, :])
|
|
@@ -2528,10 +2529,11 @@ class AXRVINet(nn.Module):
|
|
| 2528 |
Returns:
|
| 2529 |
Dictionary with all outputs
|
| 2530 |
"""
|
| 2531 |
-
|
| 2532 |
-
|
|
|
|
| 2533 |
# Temporal encoding
|
| 2534 |
-
h = self.temporal(sequences.view(B * N, T,
|
| 2535 |
|
| 2536 |
# Cross-asset processing
|
| 2537 |
total_gate_entropy = torch.tensor(0.0, device=h.device)
|
|
|
|
| 2254 |
self.router = nn.Linear(input_dim, n_experts)
|
| 2255 |
|
| 2256 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 2257 |
+
|
| 2258 |
+
BN, T, feat_dim = x.shape
|
| 2259 |
|
| 2260 |
# Routing from first bar
|
| 2261 |
router_logits = self.router(x[:, 0, :])
|
|
|
|
| 2529 |
Returns:
|
| 2530 |
Dictionary with all outputs
|
| 2531 |
"""
|
| 2532 |
+
|
| 2533 |
+
B, N, T, feat_dim = sequences.shape
|
| 2534 |
+
|
| 2535 |
# Temporal encoding
|
| 2536 |
+
h = self.temporal(sequences.view(B * N, T, feat_dim)).view(B, N, self.d_model)
|
| 2537 |
|
| 2538 |
# Cross-asset processing
|
| 2539 |
total_gate_entropy = torch.tensor(0.0, device=h.device)
|