Spaces:
Running
Running
Upload Quasar_axrvi_ranker.py
Browse files- Quasar_axrvi_ranker.py +34 -0
Quasar_axrvi_ranker.py
CHANGED
|
@@ -3870,6 +3870,7 @@ class QCSAMCrossAssetLayer(nn.Module):
|
|
| 3870 |
# Diagnostics
|
| 3871 |
self.last_align_loss: float = 0.0
|
| 3872 |
self.last_qmha_diagnostics: dict = {}
|
|
|
|
| 3873 |
|
| 3874 |
def forward(
|
| 3875 |
self,
|
|
@@ -3946,8 +3947,27 @@ class QCSAMCrossAssetLayer(nn.Module):
|
|
| 3946 |
h_out = h + gate * delta # (B, N, d_model) float32
|
| 3947 |
|
| 3948 |
# ── Mark first forward complete ───────────────────────────────────
|
|
|
|
| 3949 |
QCSAMCrossAssetLayer._first_forward_complete = True
|
| 3950 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3951 |
return h_out, align_loss
|
| 3952 |
|
| 3953 |
|
|
@@ -4244,6 +4264,20 @@ class AXRVINet(nn.Module):
|
|
| 4244 |
# Average align loss over MC samples
|
| 4245 |
mean_align = sum(mc_align) / len(mc_align) if mc_align else torch.tensor(0.0)
|
| 4246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4247 |
return {
|
| 4248 |
"significance_weight": mean_sig,
|
| 4249 |
"epistemic_variance": epistemic_var,
|
|
|
|
| 3870 |
# Diagnostics
|
| 3871 |
self.last_align_loss: float = 0.0
|
| 3872 |
self.last_qmha_diagnostics: dict = {}
|
| 3873 |
+
self._fwd_count: int = 0 # counts forward calls for periodic logging
|
| 3874 |
|
| 3875 |
def forward(
|
| 3876 |
self,
|
|
|
|
| 3947 |
h_out = h + gate * delta # (B, N, d_model) float32
|
| 3948 |
|
| 3949 |
# ── Mark first forward complete ───────────────────────────────────
|
| 3950 |
+
self._fwd_count += 1
|
| 3951 |
QCSAMCrossAssetLayer._first_forward_complete = True
|
| 3952 |
|
| 3953 |
+
# ── QCSAM diagnostic logging ──────────────────────────────────────
|
| 3954 |
+
# Log on the very first forward pass, then every 100 calls.
|
| 3955 |
+
# (10 MC-samples × ~5s rank cycle → every ~50s in steady state.)
|
| 3956 |
+
if self._fwd_count == 1 or self._fwd_count % 100 == 0:
|
| 3957 |
+
gate_val = torch.sigmoid(self.residual_gate).item()
|
| 3958 |
+
heads_info = " | ".join(
|
| 3959 |
+
f"h{i}: attn_norm={hd.get('attn_matrix_norm', 0.0):.4f} "
|
| 3960 |
+
f"align_h={hd.get('align_loss_h', 0.0):.6f}"
|
| 3961 |
+
for i, hd in enumerate(self.last_qmha_diagnostics.get("heads", []))
|
| 3962 |
+
)
|
| 3963 |
+
logger.info(
|
| 3964 |
+
f"π¬ [QCSAM] fwd#{self._fwd_count} | "
|
| 3965 |
+
f"B={B} N={N} hilbert_dim={self.hilbert_dim} | "
|
| 3966 |
+
f"gate={gate_val:.4f} | "
|
| 3967 |
+
f"align_loss={self.last_align_loss:.6f} | "
|
| 3968 |
+
f"{heads_info}"
|
| 3969 |
+
)
|
| 3970 |
+
|
| 3971 |
return h_out, align_loss
|
| 3972 |
|
| 3973 |
|
|
|
|
| 4264 |
# Average align loss over MC samples
|
| 4265 |
mean_align = sum(mc_align) / len(mc_align) if mc_align else torch.tensor(0.0)
|
| 4266 |
|
| 4267 |
+
# ── QCSAM inference summary (every 20 mc_forward calls) ───────────
|
| 4268 |
+
if not hasattr(self, "_mc_fwd_count"):
|
| 4269 |
+
self._mc_fwd_count = 0
|
| 4270 |
+
self._mc_fwd_count += 1
|
| 4271 |
+
if self._mc_fwd_count % 20 == 1:
|
| 4272 |
+
align_val = mean_align.item() if hasattr(mean_align, "item") else float(mean_align)
|
| 4273 |
+
gate_val = torch.sigmoid(self.qcsam_layer.residual_gate).item()
|
| 4274 |
+
logger.info(
|
| 4275 |
+
f"π¬ [QCSAM/Inference] mc_fwd#{self._mc_fwd_count} | "
|
| 4276 |
+
f"samples={mc_samples} | "
|
| 4277 |
+
f"mean_align_loss={align_val:.6f} | "
|
| 4278 |
+
f"residual_gate={gate_val:.4f}"
|
| 4279 |
+
)
|
| 4280 |
+
|
| 4281 |
return {
|
| 4282 |
"significance_weight": mean_sig,
|
| 4283 |
"epistemic_variance": epistemic_var,
|