Update modeling_llama_nsa.py
Browse files- modeling_llama_nsa.py +44 -28
modeling_llama_nsa.py
CHANGED
|
@@ -279,33 +279,35 @@ class LlamaNSAAttention(nn.Module):
|
|
| 279 |
attention_interface: Callable = eager_attention_forward
|
| 280 |
if self.config._attn_implementation != "eager":
|
| 281 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 282 |
-
|
| 283 |
-
#attn_output_mha, attn_weights = attention_interface(
|
| 284 |
-
# self,
|
| 285 |
-
# query_states.transpose(1,2),
|
| 286 |
-
# key_states.transpose(1,2),
|
| 287 |
-
# value_states.transpose(1,2),
|
| 288 |
-
# attention_mask,
|
| 289 |
-
# dropout=0.0 if not self.training else self.attention_dropout,
|
| 290 |
-
# scaling=self.scaling,
|
| 291 |
-
# **kwargs,
|
| 292 |
-
#)
|
| 293 |
-
#attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
|
| 294 |
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
sa_loss = 0#torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
|
| 311 |
|
|
@@ -313,10 +315,19 @@ class LlamaNSAAttention(nn.Module):
|
|
| 313 |
# attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
|
| 314 |
# attn_output_mha = self.o_proj(attn_output_mha)
|
| 315 |
# return attn_output_mha, attn_weights, sa_loss
|
| 316 |
-
if
|
| 317 |
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 318 |
attn_output = self.o_proj(attn_output)
|
| 319 |
return attn_output, attn_weights, sa_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
|
| 322 |
class LlamaNSADecoderLayer(GradientCheckpointingLayer):
|
|
@@ -434,7 +445,12 @@ class LlamaNSAModel(LlamaNSAPreTrainedModel):
|
|
| 434 |
use_cache: Optional[bool] = None,
|
| 435 |
**kwargs: Unpack[TransformersKwargs],
|
| 436 |
) -> BaseModelOutputWithPast:
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 440 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
|
|
| 279 |
attention_interface: Callable = eager_attention_forward
|
| 280 |
if self.config._attn_implementation != "eager":
|
| 281 |
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
+
if not self.training and do_mha:
|
| 284 |
+
attn_output_mha, attn_weights = attention_interface(
|
| 285 |
+
self,
|
| 286 |
+
query_states.transpose(1,2),
|
| 287 |
+
key_states.transpose(1,2),
|
| 288 |
+
value_states.transpose(1,2),
|
| 289 |
+
attention_mask,
|
| 290 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 291 |
+
scaling=self.scaling,
|
| 292 |
+
**kwargs,
|
| 293 |
+
)
|
| 294 |
+
attn_output_mha = attn_output_mha * g_slc.unsqueeze(-1) # also gated
|
| 295 |
+
|
| 296 |
+
if self.training or not do_mha:
|
| 297 |
+
# new for NSA
|
| 298 |
+
attn_output, _ = parallel_nsa(
|
| 299 |
+
q=query_states,
|
| 300 |
+
k=key_states,
|
| 301 |
+
v=value_states,
|
| 302 |
+
g_cmp=0,
|
| 303 |
+
g_slc=g_slc,
|
| 304 |
+
g_swa=0,
|
| 305 |
+
block_size=self.config.block_size,
|
| 306 |
+
block_counts=self.config.block_counts,
|
| 307 |
+
window_size=self.config.window_size,
|
| 308 |
+
head_first=False,
|
| 309 |
+
)
|
| 310 |
+
attn_weights = None
|
| 311 |
|
| 312 |
sa_loss = 0#torch.nn.SmoothL1Loss()(attn_output_mha, attn_output.detach()) + torch.nn.SmoothL1Loss()(attn_output_mha.detach(), attn_output)
|
| 313 |
|
|
|
|
| 315 |
# attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
|
| 316 |
# attn_output_mha = self.o_proj(attn_output_mha)
|
| 317 |
# return attn_output_mha, attn_weights, sa_loss
|
| 318 |
+
if self.training:
|
| 319 |
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 320 |
attn_output = self.o_proj(attn_output)
|
| 321 |
return attn_output, attn_weights, sa_loss
|
| 322 |
+
else:
|
| 323 |
+
if not do_mha:
|
| 324 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 325 |
+
attn_output = self.o_proj(attn_output)
|
| 326 |
+
return attn_output, attn_weights, sa_loss
|
| 327 |
+
else:
|
| 328 |
+
attn_output_mha = attn_output_mha.reshape(*input_shape, -1).contiguous()
|
| 329 |
+
attn_output_mha = self.o_proj(attn_output_mha)
|
| 330 |
+
return attn_output_mha, attn_weights, sa_loss
|
| 331 |
|
| 332 |
|
| 333 |
class LlamaNSADecoderLayer(GradientCheckpointingLayer):
|
|
|
|
| 445 |
use_cache: Optional[bool] = None,
|
| 446 |
**kwargs: Unpack[TransformersKwargs],
|
| 447 |
) -> BaseModelOutputWithPast:
|
| 448 |
+
if self.training:
|
| 449 |
+
do_mha = False
|
| 450 |
+
else:
|
| 451 |
+
if self.config.inference_mode not in ["sparse", "full"]:
|
| 452 |
+
raise ValueError
|
| 453 |
+
do_mha = False if self.config.inference_mode == "sparse" else True
|
| 454 |
|
| 455 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 456 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|