wangleiofficial commited on
Commit
9b747e5
·
verified ·
1 Parent(s): bac07df

Update dnaflash.py

Browse files
Files changed (1) hide show
  1. dnaflash.py +4 -4
dnaflash.py CHANGED
@@ -118,12 +118,12 @@ class T5RelativePositionBias(nn.Module):
118
  class OffsetScale(nn.Module):
119
  def __init__(self, dim, heads = 1):
120
  super().__init__()
121
- self.weight = nn.Parameter(torch.ones(heads, dim))
122
- self.bias = nn.Parameter(torch.zeros(heads, dim))
123
- nn.init.normal_(self.weight, std = 0.02)
124
 
125
  def forward(self, x):
126
- out = einsum('... d, h d -> ... h d', x, self.weight) + self.bias
127
  return out.unbind(dim = -2)
128
 
129
  # activation functions
 
118
  class OffsetScale(nn.Module):
119
  def __init__(self, dim, heads = 1):
120
  super().__init__()
121
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
122
+ self.beta = nn.Parameter(torch.zeros(heads, dim))
123
+ nn.init.normal_(self.gamma, std = 0.02)
124
 
125
  def forward(self, x):
126
+ out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
127
  return out.unbind(dim = -2)
128
 
129
  # activation functions