specsGuy committed on
Commit
4199e49
·
verified ·
1 Parent(s): c9d64d3

Update modeling_deepseekv2.py

Browse files
Files changed (1) hide show
  1. modeling_deepseekv2.py +11 -4
modeling_deepseekv2.py CHANGED
@@ -34,10 +34,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
34
  from transformers.activations import ACT2FN
35
  from transformers.cache_utils import Cache, DynamicCache
36
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
- from transformers.models.llama.modeling_llama import (
38
- LlamaAttention,
39
- LlamaFlashAttention2
40
- )
 
 
 
 
 
 
 
41
  from transformers.modeling_outputs import (
42
  BaseModelOutputWithPast,
43
  CausalLMOutputWithPast,
 
34
  from transformers.activations import ACT2FN
35
  from transformers.cache_utils import Cache, DynamicCache
36
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
+ try:
38
+ from transformers.models.llama.modeling_llama import (
39
+ LlamaFlashAttention2,
40
+ LlamaSdpaAttention,
41
+ LlamaAttention,
42
+ )
43
+ except ImportError:
44
+ # Fallback for CPU or environments without flash-attn
45
+ from transformers.models.llama.modeling_llama import LlamaAttention
46
+ LlamaFlashAttention2 = None
47
+ LlamaSdpaAttention = LlamaAttention
48
  from transformers.modeling_outputs import (
49
  BaseModelOutputWithPast,
50
  CausalLMOutputWithPast,