Update modeling_auristream.py
modeling_auristream.py  +40 -11  CHANGED
@@ -88,6 +88,9 @@ class AuriStream(PreTrainedModel):
             x = self.transformer.drop(tok_emb + pos_emb)
         else:
             x = self.transformer.drop(tok_emb)
+
+        if self.dwa is not None:
+            x = self.dwa.init_accumulators(x)
 
         all_hidden_states = []
         for block_idx, block in enumerate(self.transformer.h):

@@ -97,6 +100,9 @@ class AuriStream(PreTrainedModel):
                 break
             x = block(x)
 
+            if self.dwa is not None:
+                x = self.dwa(x)
+
             # append the last hidden state if we did not exit early
             if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
                 all_hidden_states.append(x)
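Both hunks above guard on self.dwa, which is not created anywhere in this commit. A minimal sketch of how the attribute might be wired into the constructor, assuming a hypothetical config flag; every name below except DWA (the class added at the end of this file) is an assumption, not something this diff shows.

import torch.nn as nn

class AuriStreamInitSketch(nn.Module):
    # Hypothetical skeleton: only illustrates where self.dwa could be created.
    def __init__(self, config):
        super().__init__()
        # ... embeddings, dropout and the transformer blocks are built here ...
        # The forward pass accumulates one representation per block plus the
        # embedding output, so the alpha matrix needs n_layer + 1 slots
        # (an inference from the hunks above, not from this commit).
        self.dwa = DWA(config.n_layer + 1) if getattr(config, "use_dwa", False) else None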
@@ -530,18 +536,18 @@ class MLP(nn.Module):
 
 
 class Rotary(torch.nn.Module):
-    def __init__(self, dim, base=
+    def __init__(self, dim, base=10000, learned=False):
         super().__init__()
         # Compute the base inverse frequencies as before.
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        # If learned is True, register as a parameter; otherwise, as a buffer.
-        if learned:
-            # Initialize randomly and register as a parameter.
-            self.inv_freq = torch.nn.Parameter(inv_freq)
-            nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
-        else:
-            self.register_buffer("inv_freq", inv_freq)
-        self.learned = learned # (optional) Save the flag if needed later
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        # # If learned is True, register as a parameter; otherwise, as a buffer.
+        # if learned:
+        #     # Initialize randomly and register as a parameter.
+        #     self.inv_freq = torch.nn.Parameter(inv_freq)
+        #     nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
+        # else:
+        #     self.register_buffer("inv_freq", inv_freq)
+        # self.learned = learned # (optional) Save the flag if needed later
 
     def forward(self, x):
         seq_len = x.shape[1]
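One consequence of the new constructor worth noting: self.inv_freq is now a plain tensor attribute, with the register_buffer / Parameter paths commented out. Plain attributes are not tracked by nn.Module, so they do not appear in state_dict() and are not moved by .to() / .cuda(). A small, self-contained illustration of that general PyTorch behaviour (not AuriStream code):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.ones(4)
        self.plain = inv_freq                               # plain attribute, untracked
        self.register_buffer("tracked", inv_freq.clone())   # registered buffer, tracked

m = Demo()
print("plain" in m.state_dict())    # False: plain attributes are not serialized
print("tracked" in m.state_dict())  # True: buffers are saved and restored

if torch.cuda.is_available():
    m = m.cuda()
    print(m.plain.device)    # still cpu: .cuda() ignores plain attributes
    print(m.tracked.device)  # cuda:0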
@@ -553,6 +559,7 @@ class Rotary(torch.nn.Module):
         sin_cached = freqs.sin()
         return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
 
+
 def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4 # multihead attention expected
     d = x.shape[3] // 2
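For orientation, the cos/sin tables returned above have shape (1, seq_len, 1, dim/2), which is meant to broadcast against a (batch, seq_len, n_head, head_dim) input whose last dimension apply_rotary_emb splits in half (d = x.shape[3] // 2). A shape-only sketch with random stand-ins for the cached tables, assuming that layout:

import torch

B, T, H, D = 2, 16, 4, 64
x = torch.randn(B, T, H, D)          # (batch, seq, heads, head_dim), ndim == 4
cos = torch.randn(1, T, 1, D // 2)   # stand-in for cos_cached[None, :, None, :]
sin = torch.randn(1, T, 1, D // 2)   # stand-in for sin_cached[None, :, None, :]

x1, x2 = x[..., : D // 2], x[..., D // 2 :]   # split the head dimension in half
print((x1 * cos).shape, (x2 * sin).shape)     # both (2, 16, 4, 32): broadcasting lines up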
@@ -577,4 +584,26 @@ class RMSNorm(nn.Module):
         output = self._norm(x.float()).type_as(x)
         if self.weight is not None:
             return output * self.weight
-        return output
+        return output
+
+
+class DWA(nn.Module):
+    """ Depth Weighted Average layer that averages representations across the layers of a transformer """
+    """ From: https://arxiv.org/pdf/2402.02622"""
+
+    def __init__(self, n_layers: int):
+        super().__init__()
+        self.alphas = nn.Parameter(torch.zeros(n_layers, n_layers))
+        self.alphas.data = torch.eye(n_layers)
+        self.accumulators = []
+
+    def init_accumulators(self, x):
+        self.accumulators = [x]
+        return x * self.alphas[0, 0]
+
+    def forward(self, x):
+        self.accumulators.append(x)
+        x = 0.0
+        for i in range(len(self.accumulators)):
+            x = x + self.alphas[i, len(self.accumulators)-1] * self.accumulators[i]
+        return x
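In the new forward pass, DWA re-weights every representation seen so far: after block k the hidden state becomes sum_i alphas[i, k] * h_i, where h_0 is the (dropped-out) embedding output and h_i is the output of block i. Because alphas is initialized to the identity matrix, the module starts out as a no-op, so the model is initially equivalent to one without DWA. A quick, self-contained sanity check of that property, assuming the DWA class above is in scope and using toy Linear layers as stand-in blocks:

import torch
import torch.nn as nn

torch.manual_seed(0)
n_blocks, batch, seq, dim = 4, 2, 5, 8
dwa = DWA(n_layers=n_blocks + 1)      # one slot per block + one for the embedding output
blocks = [nn.Linear(dim, dim) for _ in range(n_blocks)]

x = torch.randn(batch, seq, dim)
ref = x.clone()                        # baseline path without DWA

x = dwa.init_accumulators(x)           # seeds the running list of layer outputs
for block in blocks:
    x = block(x)
    x = dwa(x)                         # re-weights all outputs seen so far
    ref = block(ref)

print(torch.allclose(x, ref))          # True: identity alphas make DWA a no-op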
|