klemenk committed on
Commit
3006964
·
verified ·
1 Parent(s): 0df1def

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +5 -0
modeling_auristream.py CHANGED
@@ -54,6 +54,11 @@ class AuriStream(PreTrainedModel):
54
  if pn.endswith('c_proj.weight'):
55
  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
56
 
 
 
 
 
 
57
  def get_num_params(self, non_embedding=True):
58
  """
59
  Return the number of parameters in the model.
 
54
  if pn.endswith('c_proj.weight'):
55
  torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
56
 
57
+ self.dwa = None
58
+ if config.skip_connections:
59
+ self.dwa = DWA(config.n_layer + 1)
60
+
61
+
62
  def get_num_params(self, non_embedding=True):
63
  """
64
  Return the number of parameters in the model.