Maxtimer97 committed on
Commit
22ba83b
·
1 Parent(s): bb26ab9

Removed assertion

Browse files
Files changed (1) hide show
  1. compressed_attention.py +1 -1
compressed_attention.py CHANGED
@@ -954,7 +954,7 @@ def _get_attention_score(
954
  q_len, num_q_heads, head_dim = q.shape
955
  k_len, num_k_heads, head_dim = k.shape
956
  batch_size = cu_seqlens_q.shape[0] - 1
957
- assert q_len > k_len
958
  if sm_scale is None:
959
  sm_scale = 1 / math.sqrt(head_dim)
960
  # gqa
 
954
  q_len, num_q_heads, head_dim = q.shape
955
  k_len, num_k_heads, head_dim = k.shape
956
  batch_size = cu_seqlens_q.shape[0] - 1
957
+ assert q_len >= k_len
958
  if sm_scale is None:
959
  sm_scale = 1 / math.sqrt(head_dim)
960
  # gqa