Update dnaflash.py
Browse files- dnaflash.py +4 -4
dnaflash.py
CHANGED
|
@@ -118,12 +118,12 @@ class T5RelativePositionBias(nn.Module):
|
|
| 118 |
class OffsetScale(nn.Module):
|
| 119 |
def __init__(self, dim, heads = 1):
|
| 120 |
super().__init__()
|
| 121 |
-
self.
|
| 122 |
-
self.
|
| 123 |
-
nn.init.normal_(self.
|
| 124 |
|
| 125 |
def forward(self, x):
|
| 126 |
-
out = einsum('... d, h d -> ... h d', x, self.
|
| 127 |
return out.unbind(dim = -2)
|
| 128 |
|
| 129 |
# activation functions
|
|
|
|
| 118 |
class OffsetScale(nn.Module):
|
| 119 |
def __init__(self, dim, heads = 1):
|
| 120 |
super().__init__()
|
| 121 |
+
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
| 122 |
+
self.beta = nn.Parameter(torch.zeros(heads, dim))
|
| 123 |
+
nn.init.normal_(self.gamma, std = 0.02)
|
| 124 |
|
| 125 |
def forward(self, x):
|
| 126 |
+
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
|
| 127 |
return out.unbind(dim = -2)
|
| 128 |
|
| 129 |
# activation functions
|