klemenk committed on
Commit
b56f0a4
·
verified ·
1 Parent(s): 0afa744

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +21 -1
modeling_auristream.py CHANGED
@@ -587,4 +587,24 @@ class RMSNorm(nn.Module):
587
  output = self._norm(x.float()).type_as(x)
588
  if self.weight is not None:
589
  return output * self.weight
590
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
  output = self._norm(x.float()).type_as(x)
588
  if self.weight is not None:
589
  return output * self.weight
590
+ return output
591
+
592
class DWA(nn.Module):
    """Depth Weighted Average (DWA) layer.

    Mixes the current representation with every representation stored since
    the last ``init_accumulators`` call, using learned scalar weights taken
    from one column of ``alphas``.

    From: https://arxiv.org/pdf/2402.02622

    Usage contract, as far as this class shows: call ``init_accumulators``
    once with the first representation, then call the module once per later
    representation. Each call stores its input, so after
    ``init_accumulators`` at most ``n_layers - 1`` calls are possible before
    the column index runs past ``alphas`` and indexing raises ``IndexError``.

    Args:
        n_layers: Side length of the square ``alphas`` weight matrix.
    """

    def __init__(self, n_layers: int):
        super().__init__()
        # alphas[i, k] weights the i-th stored representation when the k-th
        # output is produced. Starts as the identity matrix.
        self.alphas = nn.Parameter(torch.eye(n_layers))
        # Plain Python list (not a buffer): holds references to the tensors
        # seen since the last init_accumulators call.
        self.accumulators = []

    def init_accumulators(self, x: torch.Tensor) -> torch.Tensor:
        """Discard previously stored representations and start over with ``x``.

        Args:
            x: First representation of a new pass (presumably the embedding
                output -- confirm against the caller).

        Returns:
            ``x`` scaled by ``alphas[0, 0]``.
        """
        self.accumulators = [x]
        return x * self.alphas[0, 0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Store ``x`` and return it plus the weighted sum of all stored tensors.

        ``x`` is appended before the sum is taken, so it contributes twice:
        once directly and once scaled by the diagonal weight. With the
        identity initialisation this returns ``2 * x``.

        Args:
            x: Current representation; must broadcast against the stored ones.

        Returns:
            ``x + sum_i alphas[i, k] * accumulators[i]`` where ``k`` is the
            index at which ``x`` was just stored.

        Raises:
            IndexError: If more representations have been stored than
                ``alphas`` has columns.
        """
        self.accumulators.append(x)
        # Column to read: the index at which x was just stored.
        col = len(self.accumulators) - 1
        # Accumulate sequentially so the floating-point summation order stays
        # fixed from call to call.
        for row, stored in enumerate(self.accumulators):
            x = x + self.alphas[row, col] * stored
        return x