manbeast3b committed on
Commit
f8f95c3
·
verified ·
1 Parent(s): 692691e

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -1
src/pipeline.py CHANGED
@@ -923,7 +923,7 @@ def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
923
  def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
924
  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
925
  device = query.device
926
- query, key, value, attn_mask = query.cpu(), key.cpu(), value.cpu(), None if attn_mask is None else attn_mask
927
  with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
928
  # Dynamically handle dimensions
929
  if query.ndim == 2:
@@ -959,6 +959,7 @@ def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
959
  )
960
 
961
  if attn_mask is not None:
 
962
  if attn_mask.ndim == 2:
963
  attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
964
  elif attn_mask.ndim == 3:
 
923
  def xattn1(query, key, value, attn_mask=None, dropout_p=0.0,
924
  is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
925
  device = query.device
926
+ query, key, value = query.cpu(), key.cpu(), value.cpu()
927
  with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
928
  # Dynamically handle dimensions
929
  if query.ndim == 2:
 
959
  )
960
 
961
  if attn_mask is not None:
962
+ attn_mask = attn_mask.cpu()
963
  if attn_mask.ndim == 2:
964
  attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
965
  elif attn_mask.ndim == 3: