Update modeling_auristream.py
Browse files- modeling_auristream.py +21 -1
modeling_auristream.py
CHANGED
|
@@ -587,4 +587,24 @@ class RMSNorm(nn.Module):
|
|
| 587 |
output = self._norm(x.float()).type_as(x)
|
| 588 |
if self.weight is not None:
|
| 589 |
return output * self.weight
|
| 590 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
output = self._norm(x.float()).type_as(x)
|
| 588 |
if self.weight is not None:
|
| 589 |
return output * self.weight
|
| 590 |
+
return output
|
| 591 |
+
|
| 592 |
+
class DWA(nn.Module):
    """Depth Weighted Average layer that mixes representations across the
    layers of a transformer.

    From: https://arxiv.org/pdf/2402.02622

    The module keeps a running list of every tensor it has been given
    (``accumulators``) and a learnable ``n_layers x n_layers`` matrix
    ``alphas``. Column ``k`` of ``alphas`` holds the mixing weights applied
    when the list contains ``k + 1`` tensors.

    Usage: call :meth:`init_accumulators` once with the first representation
    to reset the list, then call the module once per subsequent
    representation.

    Because ``alphas`` has ``n_layers`` columns, at most ``n_layers`` tensors
    (counting the one passed to :meth:`init_accumulators`) can be accumulated
    before the column index runs out of range.

    NOTE(review): ``forward`` adds the weighted sum *on top of* its input,
    and the weighted sum already contains that input, so with the identity
    initialisation the output is ``2 * x``. This is kept as-is since any
    trained ``alphas`` depend on it -- confirm against the reference
    formulation.
    """

    def __init__(self, n_layers: int):
        """Create the mixing weights, initialised to the identity matrix.

        Args:
            n_layers: Side length of the square ``alphas`` matrix.
        """
        super().__init__()
        # Identity init: each step initially weights only its own input.
        self.alphas = nn.Parameter(torch.eye(n_layers))
        # Plain Python list (not a buffer): holds references to every tensor
        # seen since the last init_accumulators() call.
        self.accumulators = []

    def init_accumulators(self, x: torch.Tensor) -> torch.Tensor:
        """Reset the history to contain only ``x``.

        Args:
            x: First representation of a new pass.

        Returns:
            ``x`` scaled by ``alphas[0, 0]``.
        """
        self.accumulators = [x]
        return x * self.alphas[0, 0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Record ``x`` and return it plus the weighted sum of the history.

        Args:
            x: Newest representation; appended to ``accumulators``.

        Returns:
            ``x + sum_i alphas[i, k] * accumulators[i]`` where ``k`` is the
            index of ``x`` in ``accumulators`` (so ``x`` itself is one of the
            summed terms).
        """
        self.accumulators.append(x)
        # Column of alphas used at this depth; fixed for the whole loop.
        col = len(self.accumulators) - 1
        # Accumulate in a separate name so the stored history entry for x
        # is the unmixed input; terms are added in history order.
        mixed = x
        for row, past in enumerate(self.accumulators):
            mixed = mixed + self.alphas[row, col] * past
        return mixed
|