zen-E committed · verified
Commit 91d7770 · Parent: 5d22a99

Update modeling_llama_nsa.py

Files changed (1):
  modeling_llama_nsa.py (+35, -25)
modeling_llama_nsa.py CHANGED
@@ -279,8 +279,9 @@ class LlamaNSAAttention(nn.Module):
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output_mha, attn_weights = attention_interface(
+
+        if self.training or do_mha:
+            attn_output_mha, attn_weights = attention_interface(
             self,
             query_states.transpose(1,2),
             key_states.transpose(1,2),
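
Note on the hunk above: the dense multi-head attention branch is now guarded by "if self.training or do_mha:", so when sparse inference is selected the dense path is skipped entirely. For reference, a minimal sketch of the backend dispatch the context lines rely on, assuming a recent transformers release where ALL_ATTENTION_FUNCTIONS maps implementation names to attention callables:

    from typing import Callable

    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    def resolve_attention(config, eager_fn: Callable) -> Callable:
        # Fall back to eager unless the config selects a registered
        # backend such as "sdpa" or "flash_attention_2".
        if config._attn_implementation == "eager":
            return eager_fn
        return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
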
@@ -289,26 +290,29 @@ class LlamaNSAAttention(nn.Module):
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             **kwargs,
-        )
-        attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
-
-        # new for NSA
-        attn_output, _ = parallel_nsa(
-            q=query_states,
-            k=key_states,
-            v=value_states,
-            g_cmp=0,
-            g_slc=g_slc,
-            g_swa=0,
-            block_size=self.config.block_size,
-            block_counts=self.config.block_counts,
-            window_size=self.config.window_size,
-            head_first=False,
-        )
+            )
+            attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
+
+        if self.training or not do_mha:
+            # new for Sparse Attention
+            attn_output, _ = parallel_nsa(
+                q=query_states,
+                k=key_states,
+                v=value_states,
+                g_cmp=0,
+                g_slc=g_slc,
+                g_swa=0,
+                block_size=self.config.block_size,
+                block_counts=self.config.block_counts,
+                window_size=self.config.window_size,
+                head_first=False,
+            )
         attn_weights = None
 
-        sa_loss = torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
-
+        if self.training:
+            sa_loss = torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
+        else:
+            sa_loss = 0
         if do_mha:
             attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
             attn_output_mha = self.o_proj(attn_output_mha)
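
Note on the hunk above: sa_loss is a symmetric self-distillation term between the dense output and the NSA output. Each SmoothL1 term detaches one side, so each branch chases a frozen copy of the other instead of one acting as a fixed teacher. Only g_slc carries a learned gate here; g_cmp and g_swa are passed as 0. A standalone sketch of the same computation in torch.nn.functional form:

    import torch
    import torch.nn.functional as F

    def mutual_distill_loss(dense: torch.Tensor, sparse: torch.Tensor) -> torch.Tensor:
        # Each term treats the *other* branch as a frozen target: gradients
        # reach both branches but never flow through a target.
        return (F.smooth_l1_loss(dense, sparse.detach())
                + F.smooth_l1_loss(dense.detach(), sparse))
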
@@ -434,8 +438,13 @@ class LlamaNSAModel(LlamaNSAPreTrainedModel):
         use_cache: Optional[bool] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
-        do_mha = random.random() > 0.5
-
+        if self.training:
+            do_mha = random.random() > 0.5
+        else:
+            if self.config.inference_mode not in ["sparse", "full"]:
+                raise ValueError
+            do_mha = False if self.config.inference_mode == "sparse" else True
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
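
Note on the hunk above: during training both attention paths run and do_mha is a fair coin flip that only decides which output feeds the rest of the network; at inference the branch is fixed by config.inference_mode. The bare raise ValueError carries no message; a sketch of the same selection with an explicit error (hypothetical helper, assuming inference_mode is a plain string field on the config):

    import random

    def select_branch(training: bool, inference_mode: str) -> bool:
        # Returns do_mha: True -> dense multi-head attention, False -> NSA.
        if training:
            return random.random() > 0.5
        if inference_mode not in ("sparse", "full"):
            raise ValueError(
                f"inference_mode must be 'sparse' or 'full', got {inference_mode!r}"
            )
        return inference_mode == "full"
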
@@ -560,12 +569,13 @@ class LlamaNSAForCausalLM(LlamaNSAPreTrainedModel, GenerationMixin):
             #loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
             loss = ForCausalLMLoss(hidden_states=hidden_states[:, slice_indices, :], labels=labels, lm_head_weights=self.lm_head.weight, hidden_size=self.config.hidden_size, vocab_size=self.config.vocab_size, **kwargs)
 
-        outputs.sa_loss = outputs.sa_loss*10
         if self.training:
+            outputs.sa_loss = outputs.sa_loss*10
+            loss = loss + outputs.sa_loss
             print(f"main={loss.item():.4f}, sa={outputs.sa_loss.item():.4f}")
 
         return CausalLMOutputWithPast(
-            loss=loss + outputs.sa_loss,
+            loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
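
Note on the hunk above: scaling and adding sa_loss now happens only under self.training, which also avoids the eval-time failure where sa_loss is the plain int 0 and has no .item(). A minimal sketch with the hard-coded weight of 10 lifted into a parameter:

    import torch

    def combine_losses(main_loss: torch.Tensor, sa_loss, sa_weight: float = 10.0):
        # sa_loss is a tensor during training and the int 0 at eval time,
        # so fold it in only when it is a real tensor.
        if isinstance(sa_loss, torch.Tensor):
            return main_loss + sa_weight * sa_loss
        return main_loss
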
@@ -590,4 +600,4 @@ __all__ = [
     "LlamaNSAForSequenceClassification",
     "LlamaNSAForQuestionAnswering",
     "LlamaNSAForTokenClassification",
-]
+]
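
Taken together, the commit makes the sparse/dense choice explicit at inference time. A hypothetical usage sketch; the repo id and the inference_mode field are assumptions read off this diff:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo = "zen-E/llama-nsa"  # hypothetical repo id
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    config.inference_mode = "sparse"  # or "full" for the dense branch

    model = AutoModelForCausalLM.from_pretrained(repo, config=config, trust_remote_code=True)
    model.eval()  # in eval mode, do_mha is derived from config.inference_mode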
 