Upload modeling_cloverlm.py with huggingface_hub
Browse files- modeling_cloverlm.py +15 -3
modeling_cloverlm.py
CHANGED
|
@@ -111,15 +111,27 @@ class MHSA(nn.Module):
|
|
| 111 |
|
| 112 |
dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
|
| 113 |
if attn_backend == "flash2":
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 116 |
elif attn_backend == "flash3":
|
| 117 |
import importlib
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 120 |
elif attn_backend == "flash4":
|
| 121 |
import importlib
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
|
| 124 |
Y = Y.to(Q.dtype).flatten(-2, -1)
|
| 125 |
|
|
|
|
| 111 |
|
| 112 |
dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
|
| 113 |
if attn_backend == "flash2":
|
| 114 |
+
try:
|
| 115 |
+
import flash_attn
|
| 116 |
+
except ImportError as e:
|
| 117 |
+
e.add_note(f"Can't run `attn_backend=flash2` because can't import flash_attn")
|
| 118 |
+
raise e
|
| 119 |
Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 120 |
elif attn_backend == "flash3":
|
| 121 |
import importlib
|
| 122 |
+
try:
|
| 123 |
+
_fa3 = importlib.import_module("flash_attn_interface")
|
| 124 |
+
except ImportError as e:
|
| 125 |
+
e.add_note(f"Can't run `attn_backend=flash3` because can't import flash_attn_interface")
|
| 126 |
+
raise e
|
| 127 |
Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 128 |
elif attn_backend == "flash4":
|
| 129 |
import importlib
|
| 130 |
+
try:
|
| 131 |
+
_fa4 = importlib.import_module("flash_attn.cute")
|
| 132 |
+
except ImportError as e:
|
| 133 |
+
e.add_note(f"Can't run `attn_backend=flash4` because can't import flash_attn.cute")
|
| 134 |
+
raise e
|
| 135 |
Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
|
| 136 |
Y = Y.to(Q.dtype).flatten(-2, -1)
|
| 137 |
|