Maxtimer97 commited on
Commit
bb26ab9
·
1 Parent(s): c8b397a

Removed wrong edge case assertion

Browse files
Files changed (1) hide show
  1. compressed_attention.py +1 -1
compressed_attention.py CHANGED
@@ -541,7 +541,7 @@ def _compressed_attention_fwd(
541
  k_len, num_k_heads, head_dim = k.shape
542
  v_len, num_v_heads, head_dim = v.shape
543
  batch_size = cu_seqlens_q.shape[0] - 1
544
- assert k_len == v_len and q_len > k_len
545
  # gqa
546
  assert num_k_heads == num_v_heads
547
  assert num_q_heads % num_k_heads == 0
 
541
  k_len, num_k_heads, head_dim = k.shape
542
  v_len, num_v_heads, head_dim = v.shape
543
  batch_size = cu_seqlens_q.shape[0] - 1
544
+ assert k_len == v_len and q_len >= k_len
545
  # gqa
546
  assert num_k_heads == num_v_heads
547
  assert num_q_heads % num_k_heads == 0