Update sam2/modeling/sam/transformer.py
Browse files
sam2/modeling/sam/transformer.py
CHANGED
|
@@ -322,8 +322,7 @@ class RoPEAttention(Attention):
|
|
| 322 |
v_bshd = rearrange(v_hsd, "b h s d -> b s h d")
|
| 323 |
|
| 324 |
out = flash_attn_interface.flash_attn_func(
|
| 325 |
-
q_bshd, k_bshd, v_bshd,
|
| 326 |
-
dropout_p=self.dropout_p if self.training else 0.0
|
| 327 |
) # (B, S, H, D)
|
| 328 |
|
| 329 |
out = rearrange(out, "b s h d -> b s (h d)")
|
|
|
|
| 322 |
v_bshd = rearrange(v_hsd, "b h s d -> b s h d")
|
| 323 |
|
| 324 |
out = flash_attn_interface.flash_attn_func(
|
| 325 |
+
q_bshd, k_bshd, v_bshd
|
|
|
|
| 326 |
) # (B, S, H, D)
|
| 327 |
|
| 328 |
out = rearrange(out, "b s h d -> b s (h d)")
|