KarlQuant commited on
Commit
2ad6d19
Β·
verified Β·
1 Parent(s): 8033109

Upload Quasar_axrvi_ranker.py

Browse files
Files changed (1) hide show
  1. Quasar_axrvi_ranker.py +34 -0
Quasar_axrvi_ranker.py CHANGED
@@ -3870,6 +3870,7 @@ class QCSAMCrossAssetLayer(nn.Module):
3870
  # Diagnostics
3871
  self.last_align_loss: float = 0.0
3872
  self.last_qmha_diagnostics: dict = {}
 
3873
 
3874
  def forward(
3875
  self,
@@ -3946,8 +3947,27 @@ class QCSAMCrossAssetLayer(nn.Module):
3946
  h_out = h + gate * delta # (B, N, d_model) float32
3947
 
3948
  # ── Mark first forward complete ───────────────────────────────────
 
3949
  QCSAMCrossAssetLayer._first_forward_complete = True
3950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3951
  return h_out, align_loss
3952
 
3953
 
@@ -4244,6 +4264,20 @@ class AXRVINet(nn.Module):
4244
  # Average align loss over MC samples
4245
  mean_align = sum(mc_align) / len(mc_align) if mc_align else torch.tensor(0.0)
4246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4247
  return {
4248
  "significance_weight": mean_sig,
4249
  "epistemic_variance": epistemic_var,
 
3870
  # Diagnostics
3871
  self.last_align_loss: float = 0.0
3872
  self.last_qmha_diagnostics: dict = {}
3873
+ self._fwd_count: int = 0 # counts forward calls for periodic logging
3874
 
3875
  def forward(
3876
  self,
 
3947
  h_out = h + gate * delta # (B, N, d_model) float32
3948
 
3949
  # ── Mark first forward complete ───────────────────────────────────
3950
+ self._fwd_count += 1
3951
  QCSAMCrossAssetLayer._first_forward_complete = True
3952
 
3953
+ # ── QCSAM diagnostic logging ──────────────────────────────────────
3954
+ # Log on the very first forward pass, then every 100 calls.
3955
+ # (10 MC-samples Γ— ~5s rank cycle β†’ every ~50s in steady state.)
3956
+ if self._fwd_count == 1 or self._fwd_count % 100 == 0:
3957
+ gate_val = torch.sigmoid(self.residual_gate).item()
3958
+ heads_info = " | ".join(
3959
+ f"h{i}: attn_norm={hd.get('attn_matrix_norm', 0.0):.4f} "
3960
+ f"align_h={hd.get('align_loss_h', 0.0):.6f}"
3961
+ for i, hd in enumerate(self.last_qmha_diagnostics.get("heads", []))
3962
+ )
3963
+ logger.info(
3964
+ f"πŸ”¬ [QCSAM] fwd#{self._fwd_count} | "
3965
+ f"B={B} N={N} hilbert_dim={self.hilbert_dim} | "
3966
+ f"gate={gate_val:.4f} | "
3967
+ f"align_loss={self.last_align_loss:.6f} | "
3968
+ f"{heads_info}"
3969
+ )
3970
+
3971
  return h_out, align_loss
3972
 
3973
 
 
4264
  # Average align loss over MC samples
4265
  mean_align = sum(mc_align) / len(mc_align) if mc_align else torch.tensor(0.0)
4266
 
4267
+ # ── QCSAM inference summary (every 20 mc_forward calls) ───────────
4268
+ if not hasattr(self, "_mc_fwd_count"):
4269
+ self._mc_fwd_count = 0
4270
+ self._mc_fwd_count += 1
4271
+ if self._mc_fwd_count % 20 == 1:
4272
+ align_val = mean_align.item() if hasattr(mean_align, "item") else float(mean_align)
4273
+ gate_val = torch.sigmoid(self.qcsam_layer.residual_gate).item()
4274
+ logger.info(
4275
+ f"πŸ”¬ [QCSAM/Inference] mc_fwd#{self._mc_fwd_count} | "
4276
+ f"samples={mc_samples} | "
4277
+ f"mean_align_loss={align_val:.6f} | "
4278
+ f"residual_gate={gate_val:.4f}"
4279
+ )
4280
+
4281
  return {
4282
  "significance_weight": mean_sig,
4283
  "epistemic_variance": epistemic_var,