BlackSamorez committed
Commit 63c64c0 · verified · 1 parent: 370157d

Upload modeling_cloverlm.py with huggingface_hub

Files changed (1)
  1. modeling_cloverlm.py +15 -3
modeling_cloverlm.py CHANGED
@@ -111,15 +111,27 @@ class MHSA(nn.Module):
 
         dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
         if attn_backend == "flash2":
-            import flash_attn
+            try:
+                import flash_attn
+            except ImportError as e:
+                e.add_note("Can't run `attn_backend=flash2` because flash_attn could not be imported")
+                raise e
             Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
         elif attn_backend == "flash3":
             import importlib
-            _fa3 = importlib.import_module("flash_attn_interface")
+            try:
+                _fa3 = importlib.import_module("flash_attn_interface")
+            except ImportError as e:
+                e.add_note("Can't run `attn_backend=flash3` because flash_attn_interface could not be imported")
+                raise e
             Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
         elif attn_backend == "flash4":
             import importlib
-            _fa4 = importlib.import_module("flash_attn.cute")
+            try:
+                _fa4 = importlib.import_module("flash_attn.cute")
+            except ImportError as e:
+                e.add_note("Can't run `attn_backend=flash4` because flash_attn.cute could not be imported")
+                raise e
             Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
         Y = Y.to(Q.dtype).flatten(-2, -1)
 
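A note for readers of this diff: `Exception.add_note` (PEP 678) requires Python 3.11 or newer. Below is a minimal, self-contained sketch of the error-reporting pattern the commit introduces, with `attn_backend` hard-coded purely for illustration:

# Sketch of the pattern added above: attach context to an ImportError with
# Exception.add_note (Python 3.11+, PEP 678) before re-raising, so the
# traceback shows which backend selection triggered the failed import.
# `attn_backend` is a hard-coded stand-in for the value used in the model code.
import importlib

attn_backend = "flash3"

try:
    _fa3 = importlib.import_module("flash_attn_interface")
except ImportError as e:
    # The note is printed beneath the exception message in the traceback.
    e.add_note(f"Can't run `attn_backend={attn_backend}` because flash_attn_interface could not be imported")
    raise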