klemenk commited on
Commit
d533175
·
verified ·
1 Parent(s): 2141c91

Update modeling_auristream.py

Browse files
Files changed (1) hide show
  1. modeling_auristream.py +22 -0
modeling_auristream.py CHANGED
@@ -577,3 +577,25 @@ class RMSNorm(nn.Module):
577
  if self.weight is not None:
578
  return output * self.weight
579
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  if self.weight is not None:
578
  return output * self.weight
579
  return output
580
+
581
+ class DWA(nn.Module):
582
+ """ Depth Weighted Average layer that averages representations across the layers of a transformer """
583
+ """ From: https://arxiv.org/pdf/2402.02622"""
584
+
585
+ def __init__(self, n_layers: int):
586
+ super().__init__()
587
+ self.alphas = nn.Parameter(torch.zeros(n_layers, n_layers))
588
+ self.alphas.data = torch.eye(n_layers)
589
+ self.accumulators = []
590
+
591
+ def init_accumulators(self, x):
592
+ self.accumulators = [x]
593
+ return x * self.alphas[0, 0]
594
+
595
+ def forward(self, x):
596
+ self.accumulators.append(x)
597
+ x = 0.0
598
+ for i in range(len(self.accumulators)):
599
+ x = x + self.alphas[i, len(self.accumulators)-1] * self.accumulators[i]
600
+ return x
601
+