Update modeling_llama_nsa.py
Browse files- modeling_llama_nsa.py +35 -25
modeling_llama_nsa.py
CHANGED
|
@@ -279,8 +279,9 @@ class LlamaNSAAttention(nn.Module):
|
|
| 279 |
attention_interface: Callable = eager_attention_forward
|
| 280 |
if self.config._attn_implementation != "eager":
|
| 281 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 282 |
-
|
| 283 |
-
|
|
|
|
| 284 |
self,
|
| 285 |
query_states.transpose(1,2),
|
| 286 |
key_states.transpose(1,2),
|
|
@@ -289,26 +290,29 @@ class LlamaNSAAttention(nn.Module):
|
|
| 289 |
dropout=0.0 if not self.training else self.attention_dropout,
|
| 290 |
scaling=self.scaling,
|
| 291 |
**kwargs,
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
| 308 |
attn_weights = None
|
| 309 |
|
| 310 |
-
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
if do_mha:
|
| 313 |
attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
|
| 314 |
attn_output_mha = self.o_proj(attn_output_mha)
|
|
@@ -434,8 +438,13 @@ class LlamaNSAModel(LlamaNSAPreTrainedModel):
|
|
| 434 |
use_cache: Optional[bool] = None,
|
| 435 |
**kwargs: Unpack[TransformersKwargs],
|
| 436 |
) -> BaseModelOutputWithPast:
|
| 437 |
-
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 440 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 441 |
|
|
@@ -560,12 +569,13 @@ class LlamaNSAForCausalLM(LlamaNSAPreTrainedModel, GenerationMixin):
|
|
| 560 |
#loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 561 |
loss = ForCausalLMLoss(hidden_states=hidden_states[:, slice_indices, :], labels=labels, lm_head_weights=self.lm_head.weight, hidden_size=self.config.hidden_size, vocab_size=self.config.vocab_size, **kwargs)
|
| 562 |
|
| 563 |
-
outputs.sa_loss = outputs.sa_loss*10
|
| 564 |
if self.training:
|
|
|
|
|
|
|
| 565 |
print(f"main={loss.item():.4f}, sa={outputs.sa_loss.item():.4f}")
|
| 566 |
|
| 567 |
return CausalLMOutputWithPast(
|
| 568 |
-
loss=loss
|
| 569 |
logits=logits,
|
| 570 |
past_key_values=outputs.past_key_values,
|
| 571 |
hidden_states=outputs.hidden_states,
|
|
@@ -590,4 +600,4 @@ __all__ = [
|
|
| 590 |
"LlamaNSAForSequenceClassification",
|
| 591 |
"LlamaNSAForQuestionAnswering",
|
| 592 |
"LlamaNSAForTokenClassification",
|
| 593 |
-
]
|
|
|
|
| 279 |
attention_interface: Callable = eager_attention_forward
|
| 280 |
if self.config._attn_implementation != "eager":
|
| 281 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 282 |
+
|
| 283 |
+
if self.training or do_mha:
|
| 284 |
+
attn_output_mha, attn_weights = attention_interface(
|
| 285 |
self,
|
| 286 |
query_states.transpose(1,2),
|
| 287 |
key_states.transpose(1,2),
|
|
|
|
| 290 |
dropout=0.0 if not self.training else self.attention_dropout,
|
| 291 |
scaling=self.scaling,
|
| 292 |
**kwargs,
|
| 293 |
+
)
|
| 294 |
+
attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
|
| 295 |
+
|
| 296 |
+
if self.training or not do_mha:
|
| 297 |
+
# new for Sparse Attention
|
| 298 |
+
attn_output, _ = parallel_nsa(
|
| 299 |
+
q=query_states,
|
| 300 |
+
k=key_states,
|
| 301 |
+
v=value_states,
|
| 302 |
+
g_cmp=0,
|
| 303 |
+
g_slc=g_slc,
|
| 304 |
+
g_swa=0,
|
| 305 |
+
block_size=self.config.block_size,
|
| 306 |
+
block_counts=self.config.block_counts,
|
| 307 |
+
window_size=self.config.window_size,
|
| 308 |
+
head_first=False,
|
| 309 |
+
)
|
| 310 |
attn_weights = None
|
| 311 |
|
| 312 |
+
if self.training:
|
| 313 |
+
sa_loss = torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
|
| 314 |
+
else:
|
| 315 |
+
sa_loss = 0
|
| 316 |
if do_mha:
|
| 317 |
attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
|
| 318 |
attn_output_mha = self.o_proj(attn_output_mha)
|
|
|
|
| 438 |
use_cache: Optional[bool] = None,
|
| 439 |
**kwargs: Unpack[TransformersKwargs],
|
| 440 |
) -> BaseModelOutputWithPast:
|
| 441 |
+
if self.training:
|
| 442 |
+
do_mha = random.random() > 0.5
|
| 443 |
+
else:
|
| 444 |
+
if self.config.inference_mode not in ["sparse", "full"]:
|
| 445 |
+
raise ValueError
|
| 446 |
+
do_mha = False if self.config.inference_mode == "sparse" else True
|
| 447 |
+
|
| 448 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 449 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 450 |
|
|
|
|
| 569 |
#loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
| 570 |
loss = ForCausalLMLoss(hidden_states=hidden_states[:, slice_indices, :], labels=labels, lm_head_weights=self.lm_head.weight, hidden_size=self.config.hidden_size, vocab_size=self.config.vocab_size, **kwargs)
|
| 571 |
|
|
|
|
| 572 |
if self.training:
|
| 573 |
+
outputs.sa_loss = outputs.sa_loss*10
|
| 574 |
+
loss = loss + outputs.sa_loss
|
| 575 |
print(f"main={loss.item():.4f}, sa={outputs.sa_loss.item():.4f}")
|
| 576 |
|
| 577 |
return CausalLMOutputWithPast(
|
| 578 |
+
loss=loss,
|
| 579 |
logits=logits,
|
| 580 |
past_key_values=outputs.past_key_values,
|
| 581 |
hidden_states=outputs.hidden_states,
|
|
|
|
| 600 |
"LlamaNSAForSequenceClassification",
|
| 601 |
"LlamaNSAForQuestionAnswering",
|
| 602 |
"LlamaNSAForTokenClassification",
|
| 603 |
+
]
|