smithblack-0 commited on
Commit
78610c2
·
verified ·
1 Parent(s): 7bf638f

Update architecture and tokenizer

Browse files
__attention__mosrah.py CHANGED
@@ -45,6 +45,10 @@ class MoSRAHLayer(nn.Module):
45
  self.positions = SparseMoSRAHPositions(config)
46
  self.bea = BottleneckedEnsembleAttention(config)
47
 
 
 
 
 
48
  def forward(
49
  self,
50
  hidden_states: torch.Tensor,
 
45
  self.positions = SparseMoSRAHPositions(config)
46
  self.bea = BottleneckedEnsembleAttention(config)
47
 
48
+ def num_mosrah_parameters(self) -> int:
49
+ """Return the total number of trainable parameters in this MoSRAH layer."""
50
+ return sum(p.numel() for p in self.parameters())
51
+
52
  def forward(
53
  self,
54
  hidden_states: torch.Tensor,
__attention__shram.py CHANGED
@@ -35,6 +35,10 @@ class SHRAMHybridLayer(nn.Module):
35
  self.local_attention = SlidingWindowAttention(config)
36
  self.sparse_attention = MoSRAHLayer(config)
37
 
 
 
 
 
38
  def forward(
39
  self,
40
  hidden_states: torch.Tensor,
 
35
  self.local_attention = SlidingWindowAttention(config)
36
  self.sparse_attention = MoSRAHLayer(config)
37
 
38
+ def num_mosrah_parameters(self) -> int:
39
+ """Return the total number of trainable parameters in the MoSRAH sparse path."""
40
+ return self.sparse_attention.num_mosrah_parameters()
41
+
42
  def forward(
43
  self,
44
  hidden_states: torch.Tensor,
decoder_layer.py CHANGED
@@ -51,6 +51,10 @@ class DecoderLayer(nn.Module):
51
  self.attention = SHRAMHybridLayer(config)
52
  self.mlp = SwiGLUMLP(config)
53
 
 
 
 
 
54
  def forward(
55
  self,
56
  x: torch.Tensor,
 
51
  self.attention = SHRAMHybridLayer(config)
52
  self.mlp = SwiGLUMLP(config)
53
 
54
+ def num_mosrah_parameters(self) -> int:
55
+ """Return the total number of trainable MoSRAH parameters in this decoder layer."""
56
+ return self.attention.num_mosrah_parameters()
57
+
58
  def forward(
59
  self,
60
  x: torch.Tensor,
huggingface.py CHANGED
@@ -90,6 +90,18 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
90
  else:
91
  self._tied_weights_keys = {}
92
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def get_input_embeddings(self) -> nn.Embedding:
94
  """Return the token embedding matrix."""
95
  return self.embed_tokens
 
90
  else:
91
  self._tied_weights_keys = {}
92
 
93
+ def num_mosrah_parameters(self) -> int:
94
+ """Return the total number of trainable parameters belonging to MoSRAH layers.
95
+
96
+ Aggregates across all decoder layers. Excludes sliding-window path parameters,
97
+ FFN parameters, norms, and embeddings. Use this for experimental plotting of
98
+ MoSRAH parameter count versus performance.
99
+
100
+ Returns:
101
+ Total count of trainable MoSRAH parameters.
102
+ """
103
+ return self.model.num_mosrah_parameters()
104
+
105
  def get_input_embeddings(self) -> nn.Embedding:
106
  """Return the token embedding matrix."""
107
  return self.embed_tokens
model.py CHANGED
@@ -62,6 +62,10 @@ class ShramModel(nn.Module):
62
  )
63
  self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
64
 
 
 
 
 
65
  def forward(
66
  self,
67
  inputs_embeds: torch.Tensor,
 
62
  )
63
  self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
64
 
65
+ def num_mosrah_parameters(self) -> int:
66
+ """Return the total number of trainable MoSRAH parameters across all decoder layers."""
67
+ return sum(layer.num_mosrah_parameters() for layer in self.layers)
68
+
69
  def forward(
70
  self,
71
  inputs_embeds: torch.Tensor,