Update modeling_llama_nsa.py
Browse files- modeling_llama_nsa.py +1 -1
modeling_llama_nsa.py
CHANGED
|
@@ -307,7 +307,7 @@ class LlamaNSAAttention(nn.Module):
|
|
| 307 |
window_size=self.config.window_size,
|
| 308 |
head_first=False,
|
| 309 |
)
|
| 310 |
-
|
| 311 |
|
| 312 |
sa_loss = 0 #torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
|
| 313 |
|
|
|
|
| 307 |
window_size=self.config.window_size,
|
| 308 |
head_first=False,
|
| 309 |
)
|
| 310 |
+
attn_weights = None
|
| 311 |
|
| 312 |
sa_loss = 0 #torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
|
| 313 |
|