alex committed on
Commit
a8172af
·
1 Parent(s): c4d567b

use flash

Browse files
Files changed (1) hide show
  1. sam2/modeling/sam/transformer.py +60 -44
sam2/modeling/sam/transformer.py CHANGED
@@ -23,6 +23,13 @@ from einops import rearrange
23
  warnings.simplefilter(action="ignore", category=FutureWarning)
24
  # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
25
 
 
 
 
 
 
 
 
26
 
27
  class TwoWayTransformer(nn.Module):
28
  def __init__(
@@ -241,7 +248,22 @@ class Attention(nn.Module):
241
  k = self.k_proj(k)
242
  v = self.v_proj(v)
243
 
244
- if q.device == "cpu":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  # Separate into heads
246
  q = self._separate_heads(q, self.num_heads)
247
  k = self._separate_heads(k, self.num_heads)
@@ -255,17 +277,6 @@ class Attention(nn.Module):
255
 
256
  out = self._recombine_heads(out)
257
 
258
- else:
259
- q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
260
- k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
261
- v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
262
-
263
- out = flash_attn_interface.flash_attn_func(q, k, v) # -> [b, s_q, n, d]
264
-
265
- out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
266
-
267
- out = self.out_proj(out)
268
-
269
  return out
270
 
271
 
@@ -299,40 +310,10 @@ class RoPEAttention(Attention):
299
  k = self.k_proj(k)
300
  v = self.v_proj(v)
301
 
302
- if q.device == "cpu":
303
- # Separate into heads
304
- q = self._separate_heads(q, self.num_heads)
305
- k = self._separate_heads(k, self.num_heads)
306
- v = self._separate_heads(v, self.num_heads)
307
-
308
- # Apply rotary position encoding
309
- w = h = math.sqrt(q.shape[-2])
310
- self.freqs_cis = self.freqs_cis.to(q.device)
311
- if self.freqs_cis.shape[0] != q.shape[-2]:
312
- self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
313
- if q.shape[-2] != k.shape[-2]:
314
- assert self.rope_k_repeat
315
 
316
- num_k_rope = k.size(-2) - num_k_exclude_rope
317
- q, k[:, :, :num_k_rope] = apply_rotary_enc(
318
- q,
319
- k[:, :, :num_k_rope],
320
- freqs_cis=self.freqs_cis,
321
- repeat_freqs_k=self.rope_k_repeat,
322
- )
323
 
324
- dropout_p = self.dropout_p if self.training else 0.0
325
-
326
- #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
327
- out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
328
-
329
- out = self._recombine_heads(out)
330
- out = self.out_proj(out)
331
-
332
- return out
333
-
334
- else:
335
-
336
  # 1) reshape to (B, H, S, D) so RoPE sees the sequence at dim -2
337
  q_hsd = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
338
  k_hsd = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
@@ -365,4 +346,39 @@ class RoPEAttention(Attention):
365
  ) # (B, S, H, D)
366
 
367
  out = rearrange(out, "b s h d -> b s (h d)")
 
368
  return self.out_proj(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  warnings.simplefilter(action="ignore", category=FutureWarning)
24
  # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
25
 
26
+ def _can_use_flash_attn(q: torch.Tensor) -> bool:
27
+ # FlashAttention works on CUDA with fp16/bf16 and (usually) Ampere+ GPUs
28
+ if not q.is_cuda:
29
+ return False
30
+ major, _ = torch.cuda.get_device_capability(q.device)
31
+ return q.dtype in (torch.float16, torch.bfloat16) and major >= 8 # A100/RTX30+ typically
32
+
33
 
34
  class TwoWayTransformer(nn.Module):
35
  def __init__(
 
248
  k = self.k_proj(k)
249
  v = self.v_proj(v)
250
 
251
+ use_flash = _can_use_flash_attn(q)
252
+
253
+ if use_flash:
254
+
255
+ q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
256
+ k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
257
+ v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
258
+
259
+ out = flash_attn_interface.flash_attn_func(q, k, v) # -> [b, s_q, n, d]
260
+
261
+ out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
262
+
263
+ out = self.out_proj(out)
264
+
265
+ else:
266
+
267
  # Separate into heads
268
  q = self._separate_heads(q, self.num_heads)
269
  k = self._separate_heads(k, self.num_heads)
 
277
 
278
  out = self._recombine_heads(out)
279
 
 
 
 
 
 
 
 
 
 
 
 
280
  return out
281
 
282
 
 
310
  k = self.k_proj(k)
311
  v = self.v_proj(v)
312
 
313
+ use_flash = _can_use_flash_attn(q)
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
+ if use_flash:
 
 
 
 
 
 
316
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  # 1) reshape to (B, H, S, D) so RoPE sees the sequence at dim -2
318
  q_hsd = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
319
  k_hsd = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
 
346
  ) # (B, S, H, D)
347
 
348
  out = rearrange(out, "b s h d -> b s (h d)")
349
+
350
  return self.out_proj(out)
351
+
352
+ else:
353
+
354
+ # Separate into heads
355
+ q = self._separate_heads(q, self.num_heads)
356
+ k = self._separate_heads(k, self.num_heads)
357
+ v = self._separate_heads(v, self.num_heads)
358
+
359
+ # Apply rotary position encoding
360
+ w = h = math.sqrt(q.shape[-2])
361
+ self.freqs_cis = self.freqs_cis.to(q.device)
362
+ if self.freqs_cis.shape[0] != q.shape[-2]:
363
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
364
+ if q.shape[-2] != k.shape[-2]:
365
+ assert self.rope_k_repeat
366
+
367
+ num_k_rope = k.size(-2) - num_k_exclude_rope
368
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
369
+ q,
370
+ k[:, :, :num_k_rope],
371
+ freqs_cis=self.freqs_cis,
372
+ repeat_freqs_k=self.rope_k_repeat,
373
+ )
374
+
375
+ dropout_p = self.dropout_p if self.training else 0.0
376
+
377
+ #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
378
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
379
+
380
+ out = self._recombine_heads(out)
381
+ out = self.out_proj(out)
382
+
383
+ return out
384
+