Update src/pipeline.py
Browse files — src/pipeline.py: +2 additions, −1 deletion
src/pipeline.py
CHANGED
|
@@ -923,7 +923,7 @@ def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
|
|
| 923 |
def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
|
| 924 |
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
|
| 925 |
device = query.device
|
| 926 |
-
query, key, value
|
| 927 |
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
| 928 |
# Dynamically handle dimensions
|
| 929 |
if query.ndim == 2:
|
|
@@ -959,6 +959,7 @@ def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
|
|
| 959 |
)
|
| 960 |
|
| 961 |
if attn_mask is not None:
|
|
|
|
| 962 |
if attn_mask.ndim == 2:
|
| 963 |
attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
|
| 964 |
elif attn_mask.ndim == 3:
|
|
|
|
| 923 |
def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
|
| 924 |
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
|
| 925 |
device = query.device
|
| 926 |
+
query, key, value = query.cpu(), key.cpu(), value.cpu()
|
| 927 |
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
| 928 |
# Dynamically handle dimensions
|
| 929 |
if query.ndim == 2:
|
|
|
|
| 959 |
)
|
| 960 |
|
| 961 |
if attn_mask is not None:
|
| 962 |
+
attn_mask = attn_mask.cpu()
|
| 963 |
if attn_mask.ndim == 2:
|
| 964 |
attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
|
| 965 |
elif attn_mask.ndim == 3:
|