smithblack-0 committed on
Commit
46c9e9f
·
verified ·
1 Parent(s): e15c0d5

Update architecture and tokenizer

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. huggingface.py +108 -32
config.json CHANGED
@@ -24,7 +24,7 @@
24
  "rope_mode": "main_sequence",
25
  "tie_word_embeddings": false,
26
  "training_sequence_length": 1024,
27
- "transformers_version": "5.10.1",
28
  "use_cache": true,
29
  "vocab_size": 50277,
30
  "window_size": 128
 
24
  "rope_mode": "main_sequence",
25
  "tie_word_embeddings": false,
26
  "training_sequence_length": 1024,
27
+ "transformers_version": "5.10.2",
28
  "use_cache": true,
29
  "vocab_size": 50277,
30
  "window_size": 128
huggingface.py CHANGED
@@ -1458,6 +1458,10 @@ Returns a plain dict with keys:
1458
  - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
1459
  - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
1460
  - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
 
 
 
 
1461
  """
1462
 
1463
 
@@ -1474,7 +1478,7 @@ Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
1474
  gated residual connections around both sublayers:
1475
 
1476
  normed_attn = RMSNorm(x)
1477
- attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
1478
  h = x + residual_gate * attn_out
1479
 
1480
  normed_mlp = RMSNorm(h)
@@ -3094,7 +3098,7 @@ class MoSRAHRouter(nn.Module):
3094
  x: torch.Tensor,
3095
  active_mask: torch.Tensor,
3096
  used_capacity: torch.Tensor | None
3097
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
3098
  """Route input tokens to K expert heads each and compute routing probabilities.
3099
 
3100
  Args:
@@ -3103,17 +3107,23 @@ class MoSRAHRouter(nn.Module):
3103
  True means the token is semantically live. Dead tokens do not
3104
  contribute to routing frequencies, load_balance_loss, or max_vio.
3105
  used_capacity: Used for capacity management during inference, missing during training.
 
3106
  Returns:
3107
  selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
3108
  Each token's K selected head indices, determined by TopK on biased scores.
3109
  routing_probs: Routing probabilities P of shape (batch, seq_len,
3110
  num_selected_heads). Gathered from unbiased scores at selected_heads
3111
  indices and renormalized to sum to 1 per token.
3112
- load_balance_loss: Scalar load balance imbalance loss for this forward pass.
3113
- Training loop scales this by a weight and adds it to the main loss.
3114
- max_vio: Detached scalar routing-imbalance summary for this forward pass.
3115
- Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
3116
- never contributes gradients.
 
 
 
 
 
3117
  """
3118
  B, N, _ = x.shape
3119
  L = self.num_mosrah_heads
@@ -3122,6 +3132,17 @@ class MoSRAHRouter(nn.Module):
3122
  # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
3123
  # compute routing_probs — expert_bias must not influence them.
3124
  logits = self.routing_projection(x) # (B, N, L)
 
 
 
 
 
 
 
 
 
 
 
3125
  routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
3126
 
3127
  # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
@@ -3177,7 +3198,15 @@ class MoSRAHRouter(nn.Module):
3177
  # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
3178
  max_vio = self._compute_max_vio(routing_freqs, L)
3179
 
3180
- return selected_heads, routing_probs, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
3181
 
3182
  @staticmethod
3183
  def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
@@ -3322,8 +3351,9 @@ class MoSRAHLayer(nn.Module):
3322
 
3323
  The MoSRAH path consumes model-space hidden states together with
3324
  authoritative per-token positions and returns the model-space sparse-path
3325
- contribution, the router's load-balance loss, and the router's MaxVio
3326
- routing-imbalance scalar.
 
3327
  """
3328
 
3329
  def __init__(self, config: ShramConfig) -> None:
@@ -3348,7 +3378,7 @@ class MoSRAHLayer(nn.Module):
3348
  position_ids: torch.Tensor,
3349
  active_mask: torch.Tensor,
3350
  cache: MoSRAHCache | None,
3351
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3352
  """Run the full MoSRAH sparse path.
3353
 
3354
  Args:
@@ -3364,9 +3394,10 @@ class MoSRAHLayer(nn.Module):
3364
 
3365
  Returns:
3366
  sparse_output: Model-space sparse-path output of shape (B, N, d).
3367
- load_balance_loss: Scalar router load-balance loss.
3368
- max_vio: Detached scalar routing-imbalance summary. Passed through
3369
- unchanged from the router; see MoSRAHRouter for semantics.
 
3370
  """
3371
 
3372
  # -------------------------------------------------------------------
@@ -3381,7 +3412,7 @@ class MoSRAHLayer(nn.Module):
3381
  # active_mask is rebound to the packed form after this point.
3382
  # -------------------------------------------------------------------
3383
  used_capacity = cache.get_heads_lengths() if cache is not None else None
3384
- selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
3385
  hidden_states, active_mask, used_capacity
3386
  )
3387
 
@@ -3434,7 +3465,7 @@ class MoSRAHLayer(nn.Module):
3434
  token_choice_outputs * routing_probs.unsqueeze(-1)
3435
  ).sum(dim=2)
3436
 
3437
- return final_output, load_balance_loss, max_vio
3438
 
3439
 
3440
 
@@ -3463,7 +3494,7 @@ class SHRAMHybridLayer(nn.Module):
3463
  position_ids: torch.Tensor,
3464
  active_mask: torch.Tensor,
3465
  cache: ShramLayerCache | None,
3466
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3467
  """Apply the SHRAM hybrid attention layer.
3468
 
3469
  Args:
@@ -3478,8 +3509,7 @@ class SHRAMHybridLayer(nn.Module):
3478
 
3479
  Returns:
3480
  hybrid_output: Model-space hybrid attention output of shape (B, N, d).
3481
- load_balance_loss: Scalar sparse-path load-balance loss.
3482
- max_vio: Detached scalar routing-imbalance summary. Passed through
3483
  unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
3484
  """
3485
  # -------------------------------------------------------------------
@@ -3507,7 +3537,7 @@ class SHRAMHybridLayer(nn.Module):
3507
  active_mask=active_mask,
3508
  cache=sliding_window_cache,
3509
  )
3510
- sparse_output, load_balance_loss, max_vio = self.sparse_attention(
3511
  hidden_states=hidden_states,
3512
  position_ids=position_ids,
3513
  active_mask=active_mask,
@@ -3522,7 +3552,7 @@ class SHRAMHybridLayer(nn.Module):
3522
  # -------------------------------------------------------------------
3523
  hybrid_output = local_output + sparse_output
3524
 
3525
- return hybrid_output, load_balance_loss, max_vio
3526
 
3527
 
3528
  # -----------
@@ -3612,7 +3642,7 @@ class DecoderLayer(nn.Module):
3612
  position_ids: torch.Tensor,
3613
  active_mask: torch.Tensor,
3614
  cache: ShramLayerCache | None = None,
3615
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3616
  """Apply one decoder block to the input.
3617
 
3618
  Args:
@@ -3626,12 +3656,10 @@ class DecoderLayer(nn.Module):
3626
 
3627
  Returns:
3628
  output: Tensor of shape (batch, seq_len, hidden_size).
3629
- load_balance_loss: Scalar sparse-path load-balance loss propagated
3630
- from SHRAMHybridLayer.
3631
- max_vio: Detached scalar routing-imbalance summary. Passed through
3632
  unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
3633
  """
3634
- attn_out, load_balance_loss, max_vio = self.attention(
3635
  hidden_states=self.attn_norm(x),
3636
  position_ids=position_ids,
3637
  active_mask=active_mask,
@@ -3639,7 +3667,7 @@ class DecoderLayer(nn.Module):
3639
  )
3640
  hidden_states = x + self.residual_gate*attn_out
3641
  output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
3642
- return output, load_balance_loss, max_vio
3643
 
3644
 
3645
  class ShramModel(nn.Module):
@@ -3708,27 +3736,51 @@ class ShramModel(nn.Module):
3708
  - ``"max_vio"``: detached scalar maximum routing-imbalance across
3709
  all decoder layers. Zero means perfectly balanced routing across
3710
  every layer; higher values identify the worst-case head imbalance.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3711
  """
3712
  hidden_states = inputs_embeds
3713
  all_hidden_states = (hidden_states,) if output_hidden_states else None
3714
  total_load_balance_loss = inputs_embeds.new_zeros(())
3715
  max_vio = inputs_embeds.new_zeros(())
 
 
 
 
3716
 
3717
  for layer_idx, layer in enumerate(self.layers):
3718
  layer_cache = None if cache is None else cache.layers[layer_idx]
3719
- hidden_states, layer_load_balance_loss, layer_max_vio = layer(
3720
  hidden_states,
3721
  position_ids,
3722
  active_mask,
3723
  cache=layer_cache,
3724
  )
3725
- total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss
3726
- max_vio = torch.maximum(max_vio, layer_max_vio)
 
 
 
 
3727
 
3728
  if output_hidden_states:
3729
  all_hidden_states = all_hidden_states + (hidden_states,)
3730
 
3731
  hidden_states = self.norm(hidden_states)
 
3732
 
3733
  return {
3734
  "last_hidden_state": hidden_states,
@@ -3736,6 +3788,10 @@ class ShramModel(nn.Module):
3736
  "hidden_states": all_hidden_states,
3737
  "load_balance_loss": total_load_balance_loss,
3738
  "max_vio": max_vio,
 
 
 
 
3739
  }
3740
 
3741
 
@@ -3749,10 +3805,20 @@ class ShramCausalLMOutput(CausalLMOutputWithPast):
3749
  only the SHRAM-specific wrapper outputs.
3750
  """
3751
 
 
 
 
 
 
 
 
3752
  ce_loss: torch.FloatTensor | None = None
3753
  load_balance_loss: torch.FloatTensor | None = None
3754
  max_vio: torch.FloatTensor | None = None
3755
-
 
 
 
3756
 
3757
  class ShramForCausalLM(PreTrainedModel, GenerationMixin):
3758
  """HuggingFace-facing causal language model wrapper for SHRAM.
@@ -4181,6 +4247,9 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4181
  output_hidden_states: Whether to return backbone hidden states.
4182
  Defaults to ``config.output_hidden_states``.
4183
  labels: Optional target token IDs of shape ``(batch, seq_len)``.
 
 
 
4184
  return_dict: Must be ``True`` or ``None``.
4185
  ce_weight: Weight applied to the cross-entropy loss when combining with
4186
  the load-balance loss. Default 1.0.
@@ -4197,7 +4266,10 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4197
  - ``past_key_values`` as the active ``ShramCache`` or ``None``,
4198
  - ``hidden_states`` when requested,
4199
  - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
4200
- - detached ``max_vio`` from the backbone.
 
 
 
4201
  """
4202
  use_cache = use_cache if use_cache is not None else self.config.use_cache
4203
  output_hidden_states = (
@@ -4304,4 +4376,8 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4304
  hidden_states=backbone_outputs["hidden_states"],
4305
  load_balance_loss=backbone_outputs["load_balance_loss"],
4306
  max_vio=backbone_outputs["max_vio"],
 
 
 
 
4307
  )
 
1458
  - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
1459
  - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
1460
  - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
1461
+ - "bias_std": detached scalar mean per-layer std of the expert bias vector
1462
+ - "raw_logit_std": detached scalar mean per-layer per-token routing logit spread
1463
+ - "logit_std": detached scalar mean per-layer per-token combined (logit + bias) spread
1464
+ - "bias_alignment": detached scalar mean per-layer cosine similarity of bias vs logits
1465
  """
1466
 
1467
 
 
1478
  gated residual connections around both sublayers:
1479
 
1480
  normed_attn = RMSNorm(x)
1481
+ attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...)
1482
  h = x + residual_gate * attn_out
1483
 
1484
  normed_mlp = RMSNorm(h)
 
3098
  x: torch.Tensor,
3099
  active_mask: torch.Tensor,
3100
  used_capacity: torch.Tensor | None
3101
+ ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
3102
  """Route input tokens to K expert heads each and compute routing probabilities.
3103
 
3104
  Args:
 
3107
  True means the token is semantically live. Dead tokens do not
3108
  contribute to routing frequencies, load_balance_loss, or max_vio.
3109
  used_capacity: Used for capacity management during inference, missing during training.
3110
+
3111
  Returns:
3112
  selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
3113
  Each token's K selected head indices, determined by TopK on biased scores.
3114
  routing_probs: Routing probabilities P of shape (batch, seq_len,
3115
  num_selected_heads). Gathered from unbiased scores at selected_heads
3116
  indices and renormalized to sum to 1 per token.
3117
+ router_diagnostics: Dict of routing feedback scalars. Keys:
3118
+ - ``load_balance_loss``: scalar load-balance loss with gradient.
3119
+ - ``max_vio``: detached scalar routing-imbalance summary.
3120
+ - ``bias_std``: std of expert_bias; near-zero means corrections have not built up.
3121
+ - ``raw_logit_std``: mean per-token std of unbiased logits; the natural routing scale.
3122
+ - ``logit_std``: mean per-token std of (logits + expert_bias); lower than
3123
+ raw_logit_std means bias is flattening preferences (healthy correction).
3124
+ - ``bias_alignment``: mean cosine similarity of expert_bias against per-token
3125
+ logits. Negative means bias opposes routing direction (healthy correction);
3126
+ positive means runaway reinforcement.
3127
  """
3128
  B, N, _ = x.shape
3129
  L = self.num_mosrah_heads
 
3132
  # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
3133
  # compute routing_probs — expert_bias must not influence them.
3134
  logits = self.routing_projection(x) # (B, N, L)
3135
+
3136
+ # Diagnostic scalars characterising the load-balance mechanism. Must be
3137
+ # computed here — before balance_capacity injects -1e8 sentinels that
3138
+ # would corrupt std and cosine similarity.
3139
+ bias_std = self.expert_bias.std().detach()
3140
+ raw_logit_std = logits.std(dim=-1).mean().detach()
3141
+ logit_std = (logits + self.expert_bias).std(dim=-1).mean().detach()
3142
+ bias_alignment = F.cosine_similarity(
3143
+ logits, self.expert_bias.expand_as(logits), dim=-1
3144
+ ).mean().detach()
3145
+
3146
  routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
3147
 
3148
  # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
 
3198
  # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
3199
  max_vio = self._compute_max_vio(routing_freqs, L)
3200
 
3201
+ router_diagnostics = {
3202
+ "load_balance_loss": load_balance_loss,
3203
+ "max_vio": max_vio,
3204
+ "bias_std": bias_std,
3205
+ "raw_logit_std": raw_logit_std,
3206
+ "logit_std": logit_std,
3207
+ "bias_alignment": bias_alignment,
3208
+ }
3209
+ return selected_heads, routing_probs, router_diagnostics
3210
 
3211
  @staticmethod
3212
  def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
 
3351
 
3352
  The MoSRAH path consumes model-space hidden states together with
3353
  authoritative per-token positions and returns the model-space sparse-path
3354
+ contribution and a diagnostics dict from the router containing
3355
+ load-balance loss, routing-imbalance scalar, and load-balance health
3356
+ scalars.
3357
  """
3358
 
3359
  def __init__(self, config: ShramConfig) -> None:
 
3378
  position_ids: torch.Tensor,
3379
  active_mask: torch.Tensor,
3380
  cache: MoSRAHCache | None,
3381
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
3382
  """Run the full MoSRAH sparse path.
3383
 
3384
  Args:
 
3394
 
3395
  Returns:
3396
  sparse_output: Model-space sparse-path output of shape (B, N, d).
3397
+ router_diagnostics: Dict of router feedback scalars. Keys:
3398
+ ``load_balance_loss`` (has grad), ``max_vio``, ``bias_std``,
3399
+ ``raw_logit_std``, ``logit_std``, ``bias_alignment`` (all
3400
+ detached). See MoSRAHRouter for semantics.
3401
  """
3402
 
3403
  # -------------------------------------------------------------------
 
3412
  # active_mask is rebound to the packed form after this point.
3413
  # -------------------------------------------------------------------
3414
  used_capacity = cache.get_heads_lengths() if cache is not None else None
3415
+ selected_heads, routing_probs, router_diagnostics = self.router(
3416
  hidden_states, active_mask, used_capacity
3417
  )
3418
 
 
3465
  token_choice_outputs * routing_probs.unsqueeze(-1)
3466
  ).sum(dim=2)
3467
 
3468
+ return final_output, router_diagnostics
3469
 
3470
 
3471
 
 
3494
  position_ids: torch.Tensor,
3495
  active_mask: torch.Tensor,
3496
  cache: ShramLayerCache | None,
3497
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
3498
  """Apply the SHRAM hybrid attention layer.
3499
 
3500
  Args:
 
3509
 
3510
  Returns:
3511
  hybrid_output: Model-space hybrid attention output of shape (B, N, d).
3512
+ router_diagnostics: Dict of router feedback scalars passed through
 
3513
  unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
3514
  """
3515
  # -------------------------------------------------------------------
 
3537
  active_mask=active_mask,
3538
  cache=sliding_window_cache,
3539
  )
3540
+ sparse_output, router_diagnostics = self.sparse_attention(
3541
  hidden_states=hidden_states,
3542
  position_ids=position_ids,
3543
  active_mask=active_mask,
 
3552
  # -------------------------------------------------------------------
3553
  hybrid_output = local_output + sparse_output
3554
 
3555
+ return hybrid_output, router_diagnostics
3556
 
3557
 
3558
  # -----------
 
3642
  position_ids: torch.Tensor,
3643
  active_mask: torch.Tensor,
3644
  cache: ShramLayerCache | None = None,
3645
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
3646
  """Apply one decoder block to the input.
3647
 
3648
  Args:
 
3656
 
3657
  Returns:
3658
  output: Tensor of shape (batch, seq_len, hidden_size).
3659
+ router_diagnostics: Dict of router feedback scalars passed through
 
 
3660
  unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
3661
  """
3662
+ attn_out, router_diagnostics = self.attention(
3663
  hidden_states=self.attn_norm(x),
3664
  position_ids=position_ids,
3665
  active_mask=active_mask,
 
3667
  )
3668
  hidden_states = x + self.residual_gate*attn_out
3669
  output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
3670
+ return output, router_diagnostics
3671
 
3672
 
3673
  class ShramModel(nn.Module):
 
3736
  - ``"max_vio"``: detached scalar maximum routing-imbalance across
3737
  all decoder layers. Zero means perfectly balanced routing across
3738
  every layer; higher values identify the worst-case head imbalance.
3739
+ - ``"bias_std"``: detached scalar — mean across layers of the std
3740
+ of each layer's expert bias vector. Near-zero means corrections
3741
+ have not built up; large relative to ``raw_logit_std`` means the
3742
+ bias dominates routing.
3743
+ - ``"raw_logit_std"``: detached scalar — mean across layers of the
3744
+ per-token routing logit spread before bias addition. Baseline
3745
+ natural routing preference scale.
3746
+ - ``"logit_std"``: detached scalar — mean across layers of the
3747
+ per-token combined (logit + bias) spread. Lower than
3748
+ ``raw_logit_std`` indicates healthy flattening; higher indicates
3749
+ amplification.
3750
+ - ``"bias_alignment"``: detached scalar — mean across layers of the
3751
+ per-token cosine similarity between the expert bias vector and the
3752
+ routing logits. Negative is healthy correction; positive is
3753
+ runaway feedback.
3754
  """
3755
  hidden_states = inputs_embeds
3756
  all_hidden_states = (hidden_states,) if output_hidden_states else None
3757
  total_load_balance_loss = inputs_embeds.new_zeros(())
3758
  max_vio = inputs_embeds.new_zeros(())
3759
+ total_bias_std = inputs_embeds.new_zeros(())
3760
+ total_raw_logit_std = inputs_embeds.new_zeros(())
3761
+ total_logit_std = inputs_embeds.new_zeros(())
3762
+ total_bias_alignment = inputs_embeds.new_zeros(())
3763
 
3764
  for layer_idx, layer in enumerate(self.layers):
3765
  layer_cache = None if cache is None else cache.layers[layer_idx]
3766
+ hidden_states, layer_diagnostics = layer(
3767
  hidden_states,
3768
  position_ids,
3769
  active_mask,
3770
  cache=layer_cache,
3771
  )
3772
+ total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"]
3773
+ max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"])
3774
+ total_bias_std = total_bias_std + layer_diagnostics["bias_std"]
3775
+ total_raw_logit_std = total_raw_logit_std + layer_diagnostics["raw_logit_std"]
3776
+ total_logit_std = total_logit_std + layer_diagnostics["logit_std"]
3777
+ total_bias_alignment = total_bias_alignment + layer_diagnostics["bias_alignment"]
3778
 
3779
  if output_hidden_states:
3780
  all_hidden_states = all_hidden_states + (hidden_states,)
3781
 
3782
  hidden_states = self.norm(hidden_states)
3783
+ num_layers = len(self.layers)
3784
 
3785
  return {
3786
  "last_hidden_state": hidden_states,
 
3788
  "hidden_states": all_hidden_states,
3789
  "load_balance_loss": total_load_balance_loss,
3790
  "max_vio": max_vio,
3791
+ "bias_std": total_bias_std / num_layers,
3792
+ "raw_logit_std": total_raw_logit_std / num_layers,
3793
+ "logit_std": total_logit_std / num_layers,
3794
+ "bias_alignment": total_bias_alignment / num_layers,
3795
  }
3796
 
3797
 
 
3805
  only the SHRAM-specific wrapper outputs.
3806
  """
3807
 
3808
+ ## Python dataclass inheritance violation: CausalLMOutputWithPast defaults all
3809
+ ## fields to None, which forces every subclass field to also carry a default.
3810
+ ## The = None below is a language constraint, not a semantic statement. In
3811
+ ## practice, load_balance_loss, max_vio, bias_std, raw_logit_std, logit_std,
3812
+ ## and bias_alignment are always populated by ShramForCausalLM.forward().
3813
+ ## ce_loss is genuinely optional — present only when labels are supplied.
3814
+
3815
  ce_loss: torch.FloatTensor | None = None
3816
  load_balance_loss: torch.FloatTensor | None = None
3817
  max_vio: torch.FloatTensor | None = None
3818
+ bias_std: torch.Tensor | None = None
3819
+ raw_logit_std: torch.Tensor | None = None
3820
+ logit_std: torch.Tensor | None = None
3821
+ bias_alignment: torch.Tensor | None = None
3822
 
3823
  class ShramForCausalLM(PreTrainedModel, GenerationMixin):
3824
  """HuggingFace-facing causal language model wrapper for SHRAM.
 
4247
  output_hidden_states: Whether to return backbone hidden states.
4248
  Defaults to ``config.output_hidden_states``.
4249
  labels: Optional target token IDs of shape ``(batch, seq_len)``.
4250
+ Pass unshifted labels (same alignment as ``input_ids``). This
4251
+ wrapper shifts internally: ``logits[:, :-1]`` is compared
4252
+ against ``labels[:, 1:]``. Do not pre-shift on the caller side.
4253
  return_dict: Must be ``True`` or ``None``.
4254
  ce_weight: Weight applied to the cross-entropy loss when combining with
4255
  the load-balance loss. Default 1.0.
 
4266
  - ``past_key_values`` as the active ``ShramCache`` or ``None``,
4267
  - ``hidden_states`` when requested,
4268
  - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
4269
+ - ``max_vio`` — detached worst-case routing imbalance across layers,
4270
+ - ``bias_std``, ``raw_logit_std``, ``logit_std``, ``bias_alignment`` —
4271
+ detached load-balance health scalars averaged across decoder layers;
4272
+ see ``ShramModel`` for interpretation.
4273
  """
4274
  use_cache = use_cache if use_cache is not None else self.config.use_cache
4275
  output_hidden_states = (
 
4376
  hidden_states=backbone_outputs["hidden_states"],
4377
  load_balance_loss=backbone_outputs["load_balance_loss"],
4378
  max_vio=backbone_outputs["max_vio"],
4379
+ bias_std=backbone_outputs["bias_std"],
4380
+ raw_logit_std=backbone_outputs["raw_logit_std"],
4381
+ logit_std=backbone_outputs["logit_std"],
4382
+ bias_alignment=backbone_outputs["bias_alignment"],
4383
  )