klemenk committed
Commit 0df1def · verified · 1 Parent(s): 75983a3

Update modeling_auristream.py

Files changed (1):
  1. modeling_auristream.py +40 -11
modeling_auristream.py CHANGED
@@ -88,6 +88,9 @@ class AuriStream(PreTrainedModel):
             x = self.transformer.drop(tok_emb + pos_emb)
         else:
             x = self.transformer.drop(tok_emb)
+
+        if self.dwa is not None:
+            x = self.dwa.init_accumulators(x)
 
         all_hidden_states = []
         for block_idx, block in enumerate(self.transformer.h):
@@ -97,6 +100,9 @@ class AuriStream(PreTrainedModel):
                 break
             x = block(x)
 
+            if self.dwa is not None:
+                x = self.dwa(x)
+
             # append the last hidden state if we did not exit early
             if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
                 all_hidden_states.append(x)
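
Note on the new `dwa` branches: `init_accumulators` seeds a running list with the embedding output, and the per-block call re-averages everything accumulated so far, per the depth-weighted-average scheme. A minimal, self-contained sketch of this control flow; the `TinyDWA` stub and the `Linear` "blocks" are illustrative stand-ins, not from the commit:

import torch
import torch.nn as nn

class TinyDWA(nn.Module):
    # Stand-in mirroring the DWA class added at the end of this diff
    # (it initializes alphas directly to the identity in one step).
    def __init__(self, n_layers):
        super().__init__()
        self.alphas = nn.Parameter(torch.eye(n_layers))
        self.accumulators = []

    def init_accumulators(self, x):
        self.accumulators = [x]
        return x * self.alphas[0, 0]

    def forward(self, x):
        self.accumulators.append(x)
        k = len(self.accumulators) - 1
        # Weighted average of all layer outputs seen so far, using column k.
        return sum(self.alphas[i, k] * h for i, h in enumerate(self.accumulators))

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])  # toy "transformer.h"
dwa = TinyDWA(n_layers=3)        # embeddings + 2 blocks
x = torch.randn(1, 8, 16)        # (batch, seq, dim)
x = dwa.init_accumulators(x)     # mirrors the first added branch
for block in blocks:
    x = block(x)
    x = dwa(x)                   # mirrors the per-block added branch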
@@ -530,18 +536,18 @@ class MLP(nn.Module):
 
 
 class Rotary(torch.nn.Module):
-    def __init__(self, dim, base=500000, learned=True):
+    def __init__(self, dim, base=10000, learned=False):
         super().__init__()
         # Compute the base inverse frequencies as before.
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        # If learned is True, register as a parameter; otherwise, as a buffer.
-        if learned:
-            # Initialize randomly and register as a parameter.
-            self.inv_freq = torch.nn.Parameter(inv_freq)
-            nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
-        else:
-            self.register_buffer("inv_freq", inv_freq)
-        self.learned = learned  # (optional) Save the flag if needed later
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        # # If learned is True, register as a parameter; otherwise, as a buffer.
+        # if learned:
+        #     # Initialize randomly and register as a parameter.
+        #     self.inv_freq = torch.nn.Parameter(inv_freq)
+        #     nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
+        # else:
+        #     self.register_buffer("inv_freq", inv_freq)
+        # self.learned = learned  # (optional) Save the flag if needed later
 
     def forward(self, x):
         seq_len = x.shape[1]
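
Note on the Rotary change: the learned inverse-frequency parameter (base 500000) is replaced with the standard fixed RoPE frequency table (base 10000). A short sketch of what the new `__init__` computes; the `torch.outer` line is an assumption about how `forward` builds `freqs` from `seq_len` and `inv_freq`, since that part sits outside the hunk:

import torch

dim, base, seq_len = 64, 10000, 8
# Fixed RoPE inverse frequencies, exactly as in the new __init__ (no learned parameter).
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)

# Assumed construction of freqs (position x frequency outer product), then the
# broadcastable cos/sin caches returned by forward, shaped (1, seq, 1, dim/2).
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)
cos_cached = freqs.cos()[None, :, None, :]
sin_cached = freqs.sin()[None, :, None, :]

One thing worth double-checking: `self.inv_freq` is now a plain tensor attribute rather than a registered buffer, so it will not follow the module in `.to(device)` and will not appear in `state_dict`.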
@@ -553,6 +559,7 @@ class Rotary(torch.nn.Module):
         sin_cached = freqs.sin()
         return cos_cached[None, :, None, :], sin_cached[None, :, None, :]
 
+
 def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4  # multihead attention expected
     d = x.shape[3] // 2
@@ -577,4 +584,26 @@ class RMSNorm(nn.Module):
         output = self._norm(x.float()).type_as(x)
         if self.weight is not None:
             return output * self.weight
-        return output
+        return output
+
+
+class DWA(nn.Module):
+    """ Depth Weighted Average layer that averages representations across the layers of a transformer """
+    """ From: https://arxiv.org/pdf/2402.02622 """
+
+    def __init__(self, n_layers: int):
+        super().__init__()
+        self.alphas = nn.Parameter(torch.zeros(n_layers, n_layers))
+        self.alphas.data = torch.eye(n_layers)
+        self.accumulators = []
+
+    def init_accumulators(self, x):
+        self.accumulators = [x]
+        return x * self.alphas[0, 0]
+
+    def forward(self, x):
+        self.accumulators.append(x)
+        x = 0.0
+        for i in range(len(self.accumulators)):
+            x = x + self.alphas[i, len(self.accumulators)-1] * self.accumulators[i]
+        return x
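
Note on the new DWA class: because `alphas` is initialized to the identity, the weighted average initially selects each layer's own output, so the model starts out equivalent to the baseline without DWA. A quick sanity check of that property (assumes the `DWA` class above is in scope):

import torch

dwa = DWA(n_layers=3)
h0, h1 = torch.randn(2, 4), torch.randn(2, 4)
x = dwa.init_accumulators(h0)   # alphas[0, 0] == 1, so this returns h0
out = dwa(h1)                   # column 1 of the identity selects h1 exactly
assert torch.allclose(out, h1)

One caveat worth flagging: `self.accumulators` keeps graph-connected tensors on the module between calls, so `init_accumulators` must run at the start of every forward pass (as the wiring at the top of this diff does) to avoid reusing stale activations.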