Update modeling_auristream.py
Browse files
- modeling_auristream.py +5 -0
modeling_auristream.py
CHANGED
|
@@ -54,6 +54,11 @@ class AuriStream(PreTrainedModel):
|
|
| 54 |
if pn.endswith('c_proj.weight'):
|
| 55 |
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def get_num_params(self, non_embedding=True):
|
| 58 |
"""
|
| 59 |
Return the number of parameters in the model.
|
|
|
|
| 54 |
if pn.endswith('c_proj.weight'):
|
| 55 |
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
|
| 56 |
|
| 57 |
+
self.dwa = None
|
| 58 |
+
if config.skip_connections:
|
| 59 |
+
self.dwa = DWA(config.n_layer + 1)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
def get_num_params(self, non_embedding=True):
|
| 63 |
"""
|
| 64 |
Return the number of parameters in the model.
|