Safetensors
English
llama-nsa
custom_code
zen-E committed on
Commit
0b72ebe
·
verified ·
1 Parent(s): 4fb685a

Update modeling_llama_nsa.py

Browse files
Files changed (1) hide show
  1. modeling_llama_nsa.py +44 -28
modeling_llama_nsa.py CHANGED
@@ -279,33 +279,35 @@ class LlamaNSAAttention(nn.Module):
279
  attention_interface: Callable = eager_attention_forward
280
  if self.config._attn_implementation != "eager":
281
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
282
-
283
- #attn_output_mha, attn_weights = attention_interface(
284
- # self,
285
- # query_states.transpose(1,2),
286
- # key_states.transpose(1,2),
287
- # value_states.transpose(1,2),
288
- # attention_mask,
289
- # dropout=0.0 if not self.training else self.attention_dropout,
290
- # scaling=self.scaling,
291
- # **kwargs,
292
- #)
293
- #attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
294
 
295
- # new for NSA
296
- attn_output, _ = parallel_nsa(
297
- q=query_states,
298
- k=key_states,
299
- v=value_states,
300
- g_cmp=0,
301
- g_slc=g_slc,
302
- g_swa=0,
303
- block_size=self.config.block_size,
304
- block_counts=self.config.block_counts,
305
- window_size=self.config.window_size,
306
- head_first=False,
307
- )
308
- attn_weights = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  sa_loss = 0#torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
311
 
@@ -313,10 +315,19 @@ class LlamaNSAAttention(nn.Module):
313
  # attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
314
  # attn_output_mha = self.o_proj(attn_output_mha)
315
  # return attn_output_mha, attn_weights, sa_loss
316
- if True:
317
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
318
  attn_output = self.o_proj(attn_output)
319
  return attn_output, attn_weights, sa_loss
 
 
 
 
 
 
 
 
 
320
 
321
 
322
  class LlamaNSADecoderLayer(GradientCheckpointingLayer):
@@ -434,7 +445,12 @@ class LlamaNSAModel(LlamaNSAPreTrainedModel):
434
  use_cache: Optional[bool] = None,
435
  **kwargs: Unpack[TransformersKwargs],
436
  ) -> BaseModelOutputWithPast:
437
- do_mha = False #random.random() > 0.5
 
 
 
 
 
438
 
439
  if (input_ids is None) ^ (inputs_embeds is not None):
440
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
279
  attention_interface: Callable = eager_attention_forward
280
  if self.config._attn_implementation != "eager":
281
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
+ if not self.training and do_mha:
284
+ attn_output_mha, attn_weights = attention_interface(
285
+ self,
286
+ query_states.transpose(1,2),
287
+ key_states.transpose(1,2),
288
+ value_states.transpose(1,2),
289
+ attention_mask,
290
+ dropout=0.0 if not self.training else self.attention_dropout,
291
+ scaling=self.scaling,
292
+ **kwargs,
293
+ )
294
+ attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
295
+
296
+ if self.training or not do_mha:
297
+ # new for NSA
298
+ attn_output, _ = parallel_nsa(
299
+ q=query_states,
300
+ k=key_states,
301
+ v=value_states,
302
+ g_cmp=0,
303
+ g_slc=g_slc,
304
+ g_swa=0,
305
+ block_size=self.config.block_size,
306
+ block_counts=self.config.block_counts,
307
+ window_size=self.config.window_size,
308
+ head_first=False,
309
+ )
310
+ attn_weights = None
311
 
312
  sa_loss = 0#torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
313
 
 
315
  # attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
316
  # attn_output_mha = self.o_proj(attn_output_mha)
317
  # return attn_output_mha, attn_weights, sa_loss
318
+ if self.training:
319
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
320
  attn_output = self.o_proj(attn_output)
321
  return attn_output, attn_weights, sa_loss
322
+ else:
323
+ if not do_mha:
324
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
325
+ attn_output = self.o_proj(attn_output)
326
+ return attn_output, attn_weights, sa_loss
327
+ else:
328
+ attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
329
+ attn_output_mha = self.o_proj(attn_output_mha)
330
+ return attn_output_mha, attn_weights, sa_loss
331
 
332
 
333
  class LlamaNSADecoderLayer(GradientCheckpointingLayer):
 
445
  use_cache: Optional[bool] = None,
446
  **kwargs: Unpack[TransformersKwargs],
447
  ) -> BaseModelOutputWithPast:
448
+ if self.training:
449
+ do_mha = False
450
+ else:
451
+ if self.config.inference_mode not in ["sparse", "full"]:
452
+ raise ValueError
453
+ do_mha = False if self.config.inference_mode == "sparse" else True
454
 
455
  if (input_ids is None) ^ (inputs_embeds is not None):
456
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")