Hilbertmeng committed
Commit f6626e7 · Parent: b946238

remove stack_hidden

config.json CHANGED
@@ -25,7 +25,6 @@
   "rope_base": 10000,
   "round64": true,
   "sepln": true,
-  "stack_hidden": false,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.35.0",
configuration_muddformer.py CHANGED
@@ -33,7 +33,6 @@ class MUDDFormerConfig(PretrainedConfig):
         eos_token_id: int =2,
         tie_word_embeddings: bool =False,
         use_layer_cache: bool = True,
-        stack_hidden: bool = False,
         dense: bool = True,
         dynamic_dense: bool = True,
         sepln: bool = True,
@@ -57,7 +56,6 @@ class MUDDFormerConfig(PretrainedConfig):
         self.use_qk_norm=use_qk_norm
 
         self.use_layer_cache= use_layer_cache
-        self.stack_hidden= stack_hidden
         self.dense= dense
         self.dynamic_dense= dynamic_dense
         self.sepln= sepln
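With the parameter gone from the signature, passing stack_hidden no longer binds to a named argument. A hedged sketch of the caller-visible effect, assuming MUDDFormerConfig forwards unrecognized keyword arguments to PretrainedConfig via **kwargs, as transformers config classes conventionally do:

from configuration_muddformer import MUDDFormerConfig

# The remaining flags from the diff still construct normally.
cfg = MUDDFormerConfig(use_layer_cache=True, dense=True, dynamic_dense=True, sepln=True)

# A stale stack_hidden kwarg (e.g. from an old config dict) would fall through
# to PretrainedConfig, which stores unknown keys as plain attributes; the
# modeling code no longer reads it either way.
legacy_cfg = MUDDFormerConfig(stack_hidden=False)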
modeling_muddformer.py CHANGED
@@ -96,7 +96,6 @@ class MUDDFormer(PreTrainedModel):
 
         self.layer_cache = None
         self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
-        self.stack_hidden = self.config.stack_hidden
 
         self.dynamic = self.config.dynamic_dense
         self.dense = self.config.dense
@@ -178,11 +177,11 @@
                 _hidden = self.layer_cache.update(x, i+1) # LBTD
             else:
                 hiddens.append(x)
-                _hidden = hiddens if not self.stack_hidden else torch.stack(hiddens)
+                _hidden = torch.stack(hiddens)
             if self.dynamic and self.dense:
                 dw = self.dynamic_dense[i](x) # BTD -> CBTL
                 dw = dw + self.dense_bs[i][:,None,None,:] # CBTL
-                if self.stack_hidden:
+                if seqlen > 1:
                     x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
                 else:
                     x = self.dynamic_dense[i].layer_mix(_hidden, dw)
@@ -216,7 +215,7 @@ class TransformerBlock(nn.Module):
             normed_x = self.attention_norm(x)
         elif self.config.dense_type == 'qkvr':
             res = x[-1] # for mlp
-            if self.config.stack_hidden or not self.config.sepln:
+            if not self.config.sepln:
                 normed_x = self.attention_norm(x[:3])
             else:
                 normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
@@ -266,10 +265,7 @@ class Attention(nn.Module):
         if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
             bsz, seqlen, _ = x.shape
         else:
-            if self.config.stack_hidden:
-                C, bsz, seqlen, _ = x.shape
-            else:
-                C, (bsz, seqlen, _) = len(x), x[0].shape
+            C, (bsz, seqlen, _) = len(x), x[0].shape
         kv_size = self.n_local_heads * self.head_dim
 
         if self.config.dense_type == 'l' or not self.config.dense:
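Net effect of the modeling change: hidden states are now always stacked into an LBTD tensor, and the choice between the einsum and layer_mix paths is gated on seqlen > 1 (prefill vs. single-token decode) instead of the removed flag. A self-contained sketch of why the two paths can coexist, with layer_mix written here as an assumed stand-in for dynamic_dense[i].layer_mix (a per-layer weighted sum; not the repository's exact code):

import torch

def layer_mix(hidden, dw):
    # Assumed behavior of dynamic_dense[i].layer_mix: weight each cached layer
    # l by dw[..., l] and sum. hidden: LBTD, dw: CBTL -> CBTD.
    return sum(dw[..., l, None] * hidden[l] for l in range(hidden.shape[0]))

L, B, T, D, C = 4, 2, 5, 8, 3          # layers cached, batch, time, dim, streams
hiddens = [torch.randn(B, T, D) for _ in range(L)]
dw = torch.randn(C, B, T, L)           # dynamic dense weights, as in the diff

_hidden = torch.stack(hiddens)         # LBTD, now built unconditionally
x_prefill = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)  # seqlen > 1 path
x_decode = layer_mix(_hidden, dw)      # seqlen == 1 path, same weighted sum
print(torch.allclose(x_prefill, x_decode, atol=1e-5))        # True: paths agree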