Nirupam Biswas committed on
Commit
f583b83
·
1 Parent(s): 0cfe2a2

Handle LlamaFlashAttention2 import for compatibility with newer transformers versions

Browse files
Files changed (1) hide show
  1. modeling_deepseekv2.py +12 -5
modeling_deepseekv2.py CHANGED
@@ -34,10 +34,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
34
  from transformers.activations import ACT2FN
35
  from transformers.cache_utils import Cache, DynamicCache
36
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
- from transformers.models.llama.modeling_llama import (
38
- LlamaAttention,
39
- LlamaFlashAttention2
40
- )
 
 
 
 
 
 
 
41
  from transformers.modeling_outputs import (
42
  BaseModelOutputWithPast,
43
  CausalLMOutputWithPast,
@@ -1235,7 +1242,7 @@ ATTENTION_CLASSES = {
1235
  "mla_flash_attention_2": DeepseekV2FlashAttention2,
1236
 
1237
  "mha_eager": LlamaAttention,
1238
- "mha_flash_attention_2": LlamaFlashAttention2
1239
  }
1240
 
1241
 
 
34
  from transformers.activations import ACT2FN
35
  from transformers.cache_utils import Cache, DynamicCache
36
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
+ # Handle different transformers versions
38
+ try:
39
+ from transformers.models.llama.modeling_llama import (
40
+ LlamaAttention,
41
+ LlamaFlashAttention2
42
+ )
43
+ except ImportError:
44
+ # Newer transformers versions (4.47+) don't have LlamaFlashAttention2
45
+ from transformers.models.llama.modeling_llama import LlamaAttention
46
+ LlamaFlashAttention2 = None # Will use fallback
47
+
48
  from transformers.modeling_outputs import (
49
  BaseModelOutputWithPast,
50
  CausalLMOutputWithPast,
 
1242
  "mla_flash_attention_2": DeepseekV2FlashAttention2,
1243
 
1244
  "mha_eager": LlamaAttention,
1245
+ "mha_flash_attention_2": LlamaFlashAttention2 if LlamaFlashAttention2 is not None else LlamaAttention
1246
  }
1247
 
1248