smithblack-0 commited on
Commit
0bc8a52
·
verified ·
1 Parent(s): d827698

Update architecture and tokenizer

Browse files
Files changed (1) hide show
  1. huggingface.py +102 -16
huggingface.py CHANGED
@@ -44,6 +44,7 @@ from torch import nn
44
  from torch.nn.attention.flex_attention import create_block_mask
45
  from torch.nn.attention.flex_attention import flex_attention
46
  import torch.nn.functional as F
 
47
 
48
 
49
 
@@ -547,8 +548,7 @@ class MoSRAHCache(CacheLayerMixin):
547
  # boolean-mask transfer is correct without any explicit count verification.
548
  self.keys[dest_mask] = key_states[active_mask]
549
  self.values[dest_mask] = value_states[active_mask]
550
-
551
- self._counts = post_counts
552
 
553
  return self.keys, self.values, self._make_active_mask()
554
 
@@ -1405,15 +1405,20 @@ Returns a plain dict with keys:
1405
  """Decoder layer — a single transformer block.
1406
 
1407
  Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
1408
- residual connections around both sublayers:
1409
 
1410
  normed_attn = RMSNorm(x)
1411
  attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
1412
- h = x + attn_out
1413
 
1414
  normed_mlp = RMSNorm(h)
1415
  mlp_out = SwiGLUMLP(normed_mlp)
1416
- out = h + mlp_out
 
 
 
 
 
1417
 
1418
  Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
1419
  through unnormalised residuals at depth, and each sublayer receives a stable,
@@ -2344,7 +2349,7 @@ def setup_packing(
2344
  batch_size,
2345
  sequence_length * num_selected_heads,
2346
  )
2347
-
2348
  permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
2349
  inverse_permutation = torch.argsort(permutation, dim=-1)
2350
 
@@ -2352,6 +2357,7 @@ def setup_packing(
2352
  "flattened_selected_heads": flattened_selected_heads,
2353
  "permutation": permutation,
2354
  "inverse_permutation": inverse_permutation,
 
2355
  }
2356
 
2357
 
@@ -2493,6 +2499,7 @@ def pack_experts(
2493
  (batch_size, num_experts, packed_length, *extra_shape),
2494
  fill_value=padding_value,
2495
  )
 
2496
  packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
2497
  packed_entries[key] = packed_tensor
2498
 
@@ -2537,7 +2544,17 @@ def unpack_experts(
2537
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
2538
  hidden_dim = expert_outputs.shape[-1]
2539
 
2540
- active_outputs = expert_outputs[unpacking_mask]
 
 
 
 
 
 
 
 
 
 
2541
  sorted_token_choice_outputs = active_outputs.reshape(
2542
  batch_size,
2543
  sequence_length * num_selected_heads,
@@ -2547,7 +2564,6 @@ def unpack_experts(
2547
  dim=1,
2548
  index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
2549
  )
2550
-
2551
  return restored_outputs.reshape(
2552
  batch_size,
2553
  sequence_length,
@@ -2753,6 +2769,7 @@ class LoadBalanceLoss(torch.autograd.Function):
2753
 
2754
 
2755
 
 
2756
  class MoSRAHRouter(nn.Module):
2757
  """Token-choice router for MoSRAH sparse attention.
2758
 
@@ -2775,6 +2792,10 @@ class MoSRAHRouter(nn.Module):
2775
  self.num_mosrah_heads = config.num_mosrah_heads
2776
  self.num_selected_heads = config.num_selected_heads
2777
  self.load_balance_p = config.load_balance_p
 
 
 
 
2778
 
2779
  # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
2780
  self.routing_projection = nn.Linear(
@@ -2786,8 +2807,69 @@ class MoSRAHRouter(nn.Module):
2786
  # via the LoadBalanceLoss custom backward.
2787
  self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
2788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2789
  def forward(
2790
- self, x: torch.Tensor, active_mask: torch.Tensor
 
 
 
2791
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
2792
  """Route input tokens to K expert heads each and compute routing probabilities.
2793
 
@@ -2796,7 +2878,7 @@ class MoSRAHRouter(nn.Module):
2796
  active_mask: Current-chunk active mask of shape (batch, seq_len), where
2797
  True means the token is semantically live. Dead tokens do not
2798
  contribute to routing frequencies, load_balance_loss, or max_vio.
2799
-
2800
  Returns:
2801
  selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
2802
  Each token's K selected head indices, determined by TopK on biased scores.
@@ -2821,18 +2903,21 @@ class MoSRAHRouter(nn.Module):
2821
  # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
2822
  # selection. expert_bias is added to logits before softmax so that the bias
2823
  # shifts selection probability without rescaling the unbiased distribution.
 
 
2824
  biased_routing_scores = F.softmax( # R̂, (B, N, L)
2825
- logits + self.expert_bias, dim=-1
2826
  )
2827
 
2828
  # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
 
2829
  selected_heads = biased_routing_scores.topk(K, dim=-1).indices
 
2830
 
2831
  # Routing probabilities P: gathered from unbiased R at selected_heads indices,
2832
  # then renormalized so they sum to 1 per token. Gathering from routing_scores
2833
  # (not biased_routing_scores) is the invariant that keeps the gradient path from
2834
  # the output back to the router weights free of expert_bias influence.
2835
- gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
2836
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
2837
 
2838
  # Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
@@ -3062,8 +3147,9 @@ class MoSRAHLayer(nn.Module):
3062
  # B*N*K True entries) and the packed active mask (live slots only);
3063
  # active_mask is rebound to the packed form after this point.
3064
  # -------------------------------------------------------------------
 
3065
  selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
3066
- hidden_states, active_mask
3067
  )
3068
 
3069
  setup = setup_packing(selected_heads)
@@ -3282,7 +3368,7 @@ class DecoderLayer(nn.Module):
3282
  self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
3283
  self.attention = SHRAMHybridLayer(config)
3284
  self.mlp = SwiGLUMLP(config)
3285
-
3286
  def num_mosrah_parameters(self) -> int:
3287
  """Return the total number of trainable MoSRAH parameters in this decoder layer."""
3288
  return self.attention.num_mosrah_parameters()
@@ -3318,8 +3404,8 @@ class DecoderLayer(nn.Module):
3318
  active_mask=active_mask,
3319
  cache=cache,
3320
  )
3321
- hidden_states = x + attn_out
3322
- output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
3323
  return output, load_balance_loss, max_vio
3324
 
3325
 
 
44
  from torch.nn.attention.flex_attention import create_block_mask
45
  from torch.nn.attention.flex_attention import flex_attention
46
  import torch.nn.functional as F
47
+ from typing import Optional
48
 
49
 
50
 
 
548
  # boolean-mask transfer is correct without any explicit count verification.
549
  self.keys[dest_mask] = key_states[active_mask]
550
  self.values[dest_mask] = value_states[active_mask]
551
+ self._counts[:] = post_counts[:]
 
552
 
553
  return self.keys, self.values, self._make_active_mask()
554
 
 
1405
  """Decoder layer — a single transformer block.
1406
 
1407
  Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
1408
+ gated residual connections around both sublayers:
1409
 
1410
  normed_attn = RMSNorm(x)
1411
  attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
1412
+ h = x + residual_gate * attn_out
1413
 
1414
  normed_mlp = RMSNorm(h)
1415
  mlp_out = SwiGLUMLP(normed_mlp)
1416
+ out = h + residual_gate * mlp_out
1417
+
1418
+ A single shared residual_gate vector (shape: embedding_width, init: zeros) gates
1419
+ both sublayer contributions. At initialisation the layer is a pure identity, which
1420
+ prevents variance explosion through depth regardless of how HuggingFace initialises
1421
+ the projection weights. The gate is a trainable parameter and opens during training.
1422
 
1423
  Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
1424
  through unnormalised residuals at depth, and each sublayer receives a stable,
 
2349
  batch_size,
2350
  sequence_length * num_selected_heads,
2351
  )
2352
+ num_elements = batch_size*sequence_length*num_selected_heads
2353
  permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
2354
  inverse_permutation = torch.argsort(permutation, dim=-1)
2355
 
 
2357
  "flattened_selected_heads": flattened_selected_heads,
2358
  "permutation": permutation,
2359
  "inverse_permutation": inverse_permutation,
2360
+ "num_elements" : num_elements,
2361
  }
2362
 
2363
 
 
2499
  (batch_size, num_experts, packed_length, *extra_shape),
2500
  fill_value=padding_value,
2501
  )
2502
+
2503
  packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
2504
  packed_entries[key] = packed_tensor
2505
 
 
2544
  batch_size, sequence_length, num_selected_heads = selected_heads.shape
2545
  hidden_dim = expert_outputs.shape[-1]
2546
 
2547
+ coords = torch.nonzero_static(
2548
+ unpacking_mask,
2549
+ size=setup["num_elements"],
2550
+ ) # shape: (B*N*K, 3)
2551
+
2552
+ active_outputs = expert_outputs[
2553
+ coords[:, 0],
2554
+ coords[:, 1],
2555
+ coords[:, 2],
2556
+ ] # shape: (B*N*K, d)
2557
+
2558
  sorted_token_choice_outputs = active_outputs.reshape(
2559
  batch_size,
2560
  sequence_length * num_selected_heads,
 
2564
  dim=1,
2565
  index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
2566
  )
 
2567
  return restored_outputs.reshape(
2568
  batch_size,
2569
  sequence_length,
 
2769
 
2770
 
2771
 
2772
+
2773
  class MoSRAHRouter(nn.Module):
2774
  """Token-choice router for MoSRAH sparse attention.
2775
 
 
2792
  self.num_mosrah_heads = config.num_mosrah_heads
2793
  self.num_selected_heads = config.num_selected_heads
2794
  self.load_balance_p = config.load_balance_p
2795
+ if config.use_cache:
2796
+ self.capacity = config.mosrah_cache_length
2797
+ else:
2798
+ self.capacity = config.mosrah_packed_length
2799
 
2800
  # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
2801
  self.routing_projection = nn.Linear(
 
2807
  # via the LoadBalanceLoss custom backward.
2808
  self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
2809
 
2810
+ @staticmethod
2811
+ def balance_capacity(logits: torch.Tensor,
2812
+ used_capacity: torch.Tensor | None,
2813
+ capacity: int,
2814
+ )->torch.Tensor:
2815
+ """
2816
+ Balances capacity limits so that if choosing an
2817
+ expert would go over capacity, the expert is simply
2818
+ not chosen instead
2819
+ :param logits: The logits to balance. (B, N, L)
2820
+ :param used_capacity: The used capacity, if it exists. (B, L)
2821
+ :param capacity: The maximum available capacity. Int.
2822
+ :return: Modified logits.
2823
+ """
2824
+
2825
+ if used_capacity is None:
2826
+ # Presume we are in training mode.
2827
+
2828
+ # Looking up capacity limits only
2829
+ # matters if it is, in fact, possible
2830
+ # to exceed capacity limits.
2831
+ if logits.shape[-2] < capacity:
2832
+ return logits
2833
+
2834
+ # Look up the kthvalue and use that as
2835
+ # the threshold to mask when below.
2836
+ # Note we negate then negate again to sort
2837
+ # in ascending order.
2838
+ response = torch.kthvalue(-logits, capacity, dim=-2)
2839
+ threshold = -response.values
2840
+ threshold = threshold.unsqueeze(-2) #(B, 1, L)
2841
+ else:
2842
+ # We are operating in inference mode.
2843
+ # We have to use padding to accomodate the
2844
+ # response physically not being long enough
2845
+ # to reach capacity
2846
+
2847
+ # Note that padding at zero and shifting
2848
+ # the indexes prevents dereferencing a symint,
2849
+ # as a version that just patted at 0, 1 and set to
2850
+ # length + 1 would do. This prevents a graph break.
2851
+ remaining_capacity = capacity - used_capacity # 0 means all used, can be at most capacity
2852
+ response_length = logits.shape[-2]
2853
+ index = torch.clamp(remaining_capacity, 0, response_length+1)
2854
+
2855
+ # Sort, and add padding. Anything asking for a sequence position
2856
+ # outside the current sequence will get a threshold of -1e8; always include
2857
+ # If we are asking for a value at zero, get 1e8, or full and we include
2858
+ # nothing.
2859
+ ordered_logits = torch.sort(logits, dim=-2, descending=True).values
2860
+ ordered_logits = F.pad(ordered_logits, (0,0, 1, 0), value=1e8)
2861
+ ordered_logits = F.pad(ordered_logits, (0, 0, 0, 1), value=-1e8)
2862
+
2863
+ threshold = ordered_logits.gather(-2, index.unsqueeze(-2)) #(B, 1, L)
2864
+
2865
+ mask = threshold > logits
2866
+ logits = logits.masked_fill(mask, -1e8)
2867
+ return logits
2868
  def forward(
2869
+ self,
2870
+ x: torch.Tensor,
2871
+ active_mask: torch.Tensor,
2872
+ used_capacity: torch.Tensor | None
2873
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
2874
  """Route input tokens to K expert heads each and compute routing probabilities.
2875
 
 
2878
  active_mask: Current-chunk active mask of shape (batch, seq_len), where
2879
  True means the token is semantically live. Dead tokens do not
2880
  contribute to routing frequencies, load_balance_loss, or max_vio.
2881
+ used_capacity: Used for capacity management during inference, missing during training.
2882
  Returns:
2883
  selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
2884
  Each token's K selected head indices, determined by TopK on biased scores.
 
2903
  # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
2904
  # selection. expert_bias is added to logits before softmax so that the bias
2905
  # shifts selection probability without rescaling the unbiased distribution.
2906
+ biased_logits = logits + self.expert_bias
2907
+ biased_logits = self.balance_capacity(biased_logits, used_capacity, self.capacity)
2908
  biased_routing_scores = F.softmax( # R̂, (B, N, L)
2909
+ biased_logits, dim=-1
2910
  )
2911
 
2912
  # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
2913
+ # and routing logits directly
2914
  selected_heads = biased_routing_scores.topk(K, dim=-1).indices
2915
+ gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
2916
 
2917
  # Routing probabilities P: gathered from unbiased R at selected_heads indices,
2918
  # then renormalized so they sum to 1 per token. Gathering from routing_scores
2919
  # (not biased_routing_scores) is the invariant that keeps the gradient path from
2920
  # the output back to the router weights free of expert_bias influence.
 
2921
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
2922
 
2923
  # Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
 
3147
  # B*N*K True entries) and the packed active mask (live slots only);
3148
  # active_mask is rebound to the packed form after this point.
3149
  # -------------------------------------------------------------------
3150
+ used_capacity = cache.get_heads_lengths() if cache is not None else None
3151
  selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
3152
+ hidden_states, active_mask, used_capacity
3153
  )
3154
 
3155
  setup = setup_packing(selected_heads)
 
3368
  self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
3369
  self.attention = SHRAMHybridLayer(config)
3370
  self.mlp = SwiGLUMLP(config)
3371
+ self.residual_gate = nn.Parameter(torch.zeros([config.embedding_width]))
3372
  def num_mosrah_parameters(self) -> int:
3373
  """Return the total number of trainable MoSRAH parameters in this decoder layer."""
3374
  return self.attention.num_mosrah_parameters()
 
3404
  active_mask=active_mask,
3405
  cache=cache,
3406
  )
3407
+ hidden_states = x + self.residual_gate*attn_out
3408
+ output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
3409
  return output, load_balance_loss, max_vio
3410
 
3411